From 038957a5b6cde41376a96fcc42eaab617881a4d3 Mon Sep 17 00:00:00 2001 From: Matthias Kramm Date: Tue, 25 Jun 2024 14:34:16 -0700 Subject: [PATCH 1/9] Remove stay type annotation from context(). PiperOrigin-RevId: 646609853 --- tensorflow/python/eager/context.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/eager/context.py b/tensorflow/python/eager/context.py index fee09b7dc4a599..f163eca309db3e 100644 --- a/tensorflow/python/eager/context.py +++ b/tensorflow/python/eager/context.py @@ -2465,7 +2465,7 @@ def _reset_jit_compiler_flags(): pywrap_tfe.TF_ResetJitCompilerFlags() -def context() -> Context: +def context(): """Returns a singleton context object.""" if _context is None: _create_context() From f47f49c708571bb9eae5f2464f2fd0f2ef2ee9f8 Mon Sep 17 00:00:00 2001 From: David Dunleavy Date: Tue, 25 Jun 2024 15:12:37 -0700 Subject: [PATCH 2/9] Properly override repository for JAX builds in `build.py` Also change an `if_static` to an `if_google` to fix JAX builds PiperOrigin-RevId: 646622609 --- third_party/xla/build_tools/build.py | 8 ++++++-- third_party/xla/xla/service/gpu/BUILD | 2 +- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/third_party/xla/build_tools/build.py b/third_party/xla/build_tools/build.py index 6629ea43d9e88a..8b353f3cdbb9e7 100755 --- a/third_party/xla/build_tools/build.py +++ b/third_party/xla/build_tools/build.py @@ -274,7 +274,9 @@ def nvidia_gpu_build_with_compute_capability( JAX_NUM_GENERATED_CASES=25, JAX_SKIP_SLOW_TESTS=1, ), - options=_DEFAULT_BAZEL_OPTIONS, + options=dict( + **_DEFAULT_BAZEL_OPTIONS, override_repository="xla=/github/xla" + ), ) _JAX_GPU_BUILD = Build( @@ -294,7 +296,9 @@ def nvidia_gpu_build_with_compute_capability( TF_CPP_MIN_LOG_LEVEL=0, JAX_EXCLUDE_TEST_TARGETS="PmapTest.testSizeOverflow", ), - options=_DEFAULT_BAZEL_OPTIONS, + options=dict( + **_DEFAULT_BAZEL_OPTIONS, override_repository="xla=/github/xla" + ), ) _KOKORO_JOB_NAME_TO_BUILD_MAP = { diff --git 
a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index 2a6e8de134c2c0..0726033d31e418 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -4796,7 +4796,7 @@ cc_library( ] + if_cuda_is_configured([ "@local_config_cuda//cuda:cuda_headers", "@local_config_cuda//cuda:cudnn_header", - ]) + if_static([ + ]) + if_google([ "@com_google_protobuf//:wrappers_cc_proto", ]), ) From 2d72742d40f1d3121c895f8584ec8882d1e97fc8 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 25 Jun 2024 15:33:22 -0700 Subject: [PATCH 3/9] Add tensorflow support for 16k page sizes on arm64 Tested both libtensorflowlite.so and libtensorflowlite_jni.so to ensure both libraries are 16k ELF aligned with this change: $ objdump -p bazel-bin/tensorflow/lite/libtensorflowlite.so | grep LOAD | awk '{ print $1 " " $NF }' LOAD 2**14 LOAD 2**14 $ objdump -p bazel-bin/tensorflow/lite/java/libtensorflowlite_jni.so | grep LOAD | awk '{ print $1 " " $NF }' LOAD 2**14 LOAD 2**14 PiperOrigin-RevId: 646629366 --- tensorflow/lite/build_def.bzl | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/tensorflow/lite/build_def.bzl b/tensorflow/lite/build_def.bzl index da05c312d3c0d3..8ebe89096a8128 100644 --- a/tensorflow/lite/build_def.bzl +++ b/tensorflow/lite/build_def.bzl @@ -181,13 +181,22 @@ def tflite_linkopts_no_undefined(): }), ) +def tflite_pagesize_linkopts(): + """Defines linker flags for setting the page size.""" + return select({ + clean_dep("//tensorflow:android_arm64"): [ + "-Wl,-z,max-page-size=16384", + ], + "//conditions:default": [], + }) + def tflite_linkopts(): """Defines linker flags for linking TFLite binary.""" - return tflite_linkopts_unstripped() + tflite_symbol_opts() + return tflite_linkopts_unstripped() + tflite_symbol_opts() + tflite_pagesize_linkopts() def tflite_jni_linkopts(): """Defines linker flags for linking TFLite binary with JNI.""" - return 
tflite_jni_linkopts_unstripped() + tflite_symbol_opts() + return tflite_jni_linkopts_unstripped() + tflite_symbol_opts() + tflite_pagesize_linkopts() def tflite_jni_binary( name, From dea0e2a6f8beca190ca138093e669d42ee244056 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Tue, 25 Jun 2024 15:38:57 -0700 Subject: [PATCH 4/9] [xla:cpu] Optimize KernelThunk host kernel loading PiperOrigin-RevId: 646631043 --- .../select_and_scatter_benchmark_test.cc | 3 +- .../xla/service/cpu/runtime/kernel_thunk.cc | 29 ++++++++++--------- .../xla/service/cpu/runtime/kernel_thunk.h | 14 ++++----- 3 files changed, 25 insertions(+), 21 deletions(-) diff --git a/third_party/xla/xla/service/cpu/benchmarks/select_and_scatter_benchmark_test.cc b/third_party/xla/xla/service/cpu/benchmarks/select_and_scatter_benchmark_test.cc index fdfdc7b00b1882..7521dcda5b0f86 100644 --- a/third_party/xla/xla/service/cpu/benchmarks/select_and_scatter_benchmark_test.cc +++ b/third_party/xla/xla/service/cpu/benchmarks/select_and_scatter_benchmark_test.cc @@ -75,6 +75,7 @@ BENCHMARK(BM_SelectAndScatterF32) ->MeasureProcessCPUTime() ->Arg(128) ->Arg(256) - ->Arg(512); + ->Arg(512) + ->Arg(1024); } // namespace xla::cpu diff --git a/third_party/xla/xla/service/cpu/runtime/kernel_thunk.cc b/third_party/xla/xla/service/cpu/runtime/kernel_thunk.cc index 85d65c0ec323fd..247c995649a525 100644 --- a/third_party/xla/xla/service/cpu/runtime/kernel_thunk.cc +++ b/third_party/xla/xla/service/cpu/runtime/kernel_thunk.cc @@ -28,6 +28,7 @@ limitations under the License. #include "absl/memory/memory.h" #include "absl/numeric/bits.h" #include "absl/strings/str_format.h" +#include "absl/synchronization/mutex.h" #include "absl/types/span.h" #include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive #include "xla/runtime/buffer_use.h" @@ -124,29 +125,31 @@ tsl::AsyncValueRef KernelThunk::Execute( // TODO(ezhulenev): Kernel ptr should be loaded as a part of Thunk // initialization stage. 
- SE_HOST_Kernel* kernel_ptr = kernel_ptr_.load(); + se::host::HostKernel* kernel = kernel_ptr_.load(); // Because thunks are owned by a parent CpuExecutable, we can safely assume // that kernel pointer will not change after we find it the first time. - if (kernel_ptr == nullptr) { - TF_ASSIGN_OR_RETURN(kernel_ptr, params.host_kernels->Find(kernel_name_)); - kernel_ptr_.store(kernel_ptr); - } + if (ABSL_PREDICT_FALSE(kernel == nullptr)) { + TF_ASSIGN_OR_RETURN(SE_HOST_Kernel * kernel_fn, + params.host_kernels->Find(kernel_name_)); - se::host::HostKernel kernel(buffers_data.size(), kernel_ptr, nullptr); + absl::MutexLock lock(&mutex_); + kernel_.emplace(buffers_data.size(), kernel_fn, nullptr); + kernel_ptr_.store(kernel = &kernel_.value()); + } // If intra-op thread pool is not nullptr, we launch HostKernel in async mode // by scheduling tasks into it. HostKernel launch completion will // automatically signal KernelThunk execute completion. - if (params.intra_op_threadpool && use_task_runner_) { - return kernel.Launch(thread_dim_, buffers_data, - [&params](se::host::HostKernel::Task task) { - params.intra_op_threadpool->getPool()->Schedule( - ToCopyableTask(std::move(task))); - }); + if (ABSL_PREDICT_FALSE(params.intra_op_threadpool && use_task_runner_)) { + return kernel->Launch(thread_dim_, buffers_data, + [&params](se::host::HostKernel::Task task) { + params.intra_op_threadpool->getPool()->Schedule( + ToCopyableTask(std::move(task))); + }); } - TF_RETURN_IF_ERROR(kernel.Launch(thread_dim_, buffers_data)); + TF_RETURN_IF_ERROR(kernel->Launch(thread_dim_, buffers_data)); return OkExecuteEvent(); } diff --git a/third_party/xla/xla/service/cpu/runtime/kernel_thunk.h b/third_party/xla/xla/service/cpu/runtime/kernel_thunk.h index 72cd1be097ac25..708f918d342c96 100644 --- a/third_party/xla/xla/service/cpu/runtime/kernel_thunk.h +++ b/third_party/xla/xla/service/cpu/runtime/kernel_thunk.h @@ -23,11 +23,13 @@ limitations under the License. 
#include #include +#include "absl/base/thread_annotations.h" #include "absl/status/statusor.h" +#include "absl/synchronization/mutex.h" #include "absl/types/span.h" #include "xla/service/buffer_assignment.h" #include "xla/service/cpu/runtime/thunk.h" -#include "xla/stream_executor/host/host_kernel_c_api.h" +#include "xla/stream_executor/host/host_kernel.h" #include "xla/stream_executor/launch_dim.h" #include "xla/tsl/concurrency/async_value_ref.h" @@ -64,12 +66,10 @@ class KernelThunk final : public Thunk { // launch the kernel directly in the caller thread. bool use_task_runner_; - // Pointer to the host kernel corresponding to `kernel_name_`. Initialized - // lazily at run time by looking it up in the HostKernels passed via params. - // - // TODO(ezhulenev): This should be moved to initialization stage when we'll - // have it for CPU thunks. - std::atomic kernel_ptr_; + // Lazily loaded host kernel corresponding to `kernel_name_`. + absl::Mutex mutex_; + std::optional kernel_ ABSL_GUARDED_BY(mutex_); + std::atomic kernel_ptr_; // pointer to `kernel_` }; } // namespace xla::cpu From 01aeb511e1c2e357c24d6f8f57bdcc46638549fd Mon Sep 17 00:00:00 2001 From: Kyle Lucke Date: Tue, 25 Jun 2024 15:46:23 -0700 Subject: [PATCH 5/9] Remove unused VlogOccupancyInfo calls. 
PiperOrigin-RevId: 646633442 --- .../xla/stream_executor/cuda/cuda_executor.cc | 90 ------------------- .../xla/stream_executor/gpu/gpu_executor.h | 25 ------ .../xla/stream_executor/rocm/rocm_executor.cc | 86 ------------------ 3 files changed, 201 deletions(-) diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc b/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc index 7d978f043a42fd..901f051e290606 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc @@ -465,20 +465,6 @@ absl::Status GpuExecutor::Launch(Stream* stream, const ThreadDim& thread_dims, const GpuKernel* cuda_kernel = AsGpuKernel(&kernel); CUfunction cufunc = cuda_kernel->AsGpuFunctionHandle(); - // Only perform/print the occupancy check once. Even just checking to see - // whether we've done an occupancy check on this kernel before isn't free - // (because we have to synchronize), so we only do this at -v 2+. - if (VLOG_IS_ON(2)) { - absl::MutexLock lock(&launched_kernels_mu_); - if (!launched_kernels_.count(cufunc)) { - VlogOccupancyInfo(stream->parent()->GetDeviceDescription(), kernel, - thread_dims, block_dims); - // TODO(rspringer): Remove elements from launched_kernels_...if we ever - // expose a kernel/module deallocation method. - launched_kernels_.insert(cufunc); - } - } - if (cuda_kernel->cache_config() != KernelCacheConfig::kNoPreference) { TF_RETURN_IF_ERROR(GpuDriver::FuncSetCacheConfig( cufunc, cuda_kernel->GetGpuCacheConfig())); @@ -547,82 +533,6 @@ absl::Status GpuExecutor::Submit(Stream* stream, return GpuDriver::GraphLaunch(exec, AsGpuStreamValue(stream)); } -// This is a non-essential operation; if there's a failure, proceed without -// logging an error. It's nearly certain that in case of failures, we'd never -// get here in the first place; these are very low-impact routines. 
-void GpuExecutor::VlogOccupancyInfo(const DeviceDescription& device_description, - const Kernel& kernel, - const ThreadDim& thread_dims, - const BlockDim& block_dims) { - VLOG(2) << "Computing kernel occupancy for kernel " - << kernel.demangled_name(); - VLOG(2) << "Thread dimensions (" << thread_dims.x << ", " << thread_dims.y - << ", " << thread_dims.z << ")"; - - auto regs_per_thread = kernel.metadata().registers_per_thread(); - auto smem_per_block = kernel.metadata().shared_memory_bytes(); - - if (!regs_per_thread && !smem_per_block) { - return; - } - - const GpuKernel* cuda_kernel = AsGpuKernel(&kernel); - CUfunction cufunc = cuda_kernel->AsGpuFunctionHandle(); - - int blocks_per_sm = CalculateOccupancy(device_description, *regs_per_thread, - *smem_per_block, thread_dims, cufunc); - VLOG(2) << "Resident blocks per SM is " << blocks_per_sm; - - int suggested_threads = - CompareOccupancy(&blocks_per_sm, device_description, *regs_per_thread, - *smem_per_block, thread_dims, cufunc); - if (suggested_threads != 0) { - VLOG(2) << "The cuda occupancy calculator recommends using " - << suggested_threads - << " threads per block to achieve an occupancy of " << blocks_per_sm - << " blocks per SM."; - } -} - -// Compute and return maximum blocks per core (occupancy) based on the -// device description, some kernel characteristics and the number of threads per -// block. If unable to compute occupancy, zero is returned. -int GpuExecutor::CalculateOccupancy(const DeviceDescription& device_description, - uint64_t registers_per_thread, - uint64_t shared_memory_per_block, - const ThreadDim& thread_dims, - CUfunction func) { - int suggested_blocks = 0; - int suggested_threads = 0; - CUresult err = cuOccupancyMaxPotentialBlockSize( - &suggested_blocks, &suggested_threads, func, nullptr, - shared_memory_per_block, 0); - CHECK_EQ(err, CUDA_SUCCESS); - return suggested_blocks; -} - -// Compute and return the suggested thread count to achieve ideal occupancy. 
-// If the provided thread dimensions match this number, zero is returned. -int GpuExecutor::CompareOccupancy(int* initial_blocks, - const DeviceDescription& device_description, - uint64_t registers_per_thread, - uint64_t shared_memory_per_block, - const ThreadDim& thread_dims, - CUfunction func) { - int suggested_blocks = 0; - int suggested_threads = 0; - CUresult err = cuOccupancyMaxPotentialBlockSize( - &suggested_blocks, &suggested_threads, func, nullptr, - shared_memory_per_block, 0); - CHECK_EQ(err, CUDA_SUCCESS); - if (suggested_blocks > *initial_blocks) { - *initial_blocks = suggested_blocks; - return suggested_threads; - } else { - return 0; - } -} - DeviceMemoryBase GpuExecutor::Allocate(uint64_t size, int64_t memory_space) { if (memory_space == 1) { auto result = GpuCollectives::CollectiveMemoryAllocate(context_, size); diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_executor.h b/third_party/xla/xla/stream_executor/gpu/gpu_executor.h index c0a21007e00b5d..d17959e8fd4bf0 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_executor.h +++ b/third_party/xla/xla/stream_executor/gpu/gpu_executor.h @@ -147,17 +147,6 @@ class GpuExecutor : public StreamExecutorCommon { absl::Status Submit(Stream* stream, const CommandBuffer& command_buffer) override; - int CalculateOccupancy(const DeviceDescription& device_description, - uint64_t registers_per_thread, - uint64_t shared_memory_per_block, - const ThreadDim& thread_dims, GpuFunctionHandle func); - - int CompareOccupancy(int* initial_blocks, - const DeviceDescription& device_description, - uint64_t registers_per_thread, - uint64_t shared_memory_per_block, - const ThreadDim& thread_dims, GpuFunctionHandle func); - DeviceMemoryBase Allocate(uint64_t size, int64_t memory_space) override; void Deallocate(DeviceMemoryBase* mem) override; @@ -320,12 +309,6 @@ class GpuExecutor : public StreamExecutorCommon { absl::Status GetKernelMetadata(GpuKernel* cuda_kernel, KernelMetadata* kernel_metadata); - // Prints 
to VLOG(2) information about the kernel's occupancy and how it might - // be improved. - void VlogOccupancyInfo(const DeviceDescription& device_description, - const Kernel& kernel, const ThreadDim& thread_dims, - const BlockDim& block_dims); - // (supported on CUDA only) absl::Status LoadModuleFromCuBin(const char* cubin, GpuModuleHandle* module) TF_EXCLUSIVE_LOCKS_REQUIRED(in_memory_modules_mu_); @@ -377,14 +360,6 @@ class GpuExecutor : public StreamExecutorCommon { std::unordered_map> gpu_binary_to_module_ ABSL_GUARDED_BY(in_memory_modules_mu_); - // Guards the launched kernel set. - absl::Mutex launched_kernels_mu_; - - // Keeps track of the set of launched kernels. Currently used to suppress the - // occupancy check on subsequent launches. - std::set launched_kernels_ - ABSL_GUARDED_BY(launched_kernels_mu_); - // Handle for the CUDA device being operated on. Immutable // post-initialization. GpuDeviceHandle device_; diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_executor.cc b/third_party/xla/xla/stream_executor/rocm/rocm_executor.cc index f405de0c781537..86b0b3574f922f 100644 --- a/third_party/xla/xla/stream_executor/rocm/rocm_executor.cc +++ b/third_party/xla/xla/stream_executor/rocm/rocm_executor.cc @@ -337,20 +337,6 @@ absl::Status GpuExecutor::Launch(Stream* stream, const ThreadDim& thread_dims, const GpuKernel* rocm_kernel = AsGpuKernel(&kernel); hipFunction_t hipfunc = rocm_kernel->AsGpuFunctionHandle(); - // Only perform/print the occupancy check once. Even just checking to see - // whether we've done an occupancy check on this kernel before isn't free - // (because we have to synchronize), so we only do this at -v 2+. 
- if (VLOG_IS_ON(2)) { - absl::MutexLock lock(&launched_kernels_mu_); - if (!launched_kernels_.count(hipfunc)) { - VlogOccupancyInfo(stream->parent()->GetDeviceDescription(), kernel, - thread_dims, block_dims); - // TODO(rspringer): Remove elements from launched_kernels_...if we ever - // expose a kernel/module deallocation method. - launched_kernels_.insert(hipfunc); - } - } - if (rocm_kernel->cache_config() != KernelCacheConfig::kNoPreference) { TF_RETURN_IF_ERROR(GpuDriver::FuncSetCacheConfig( hipfunc, rocm_kernel->GetGpuCacheConfig())); @@ -458,78 +444,6 @@ absl::Status GpuExecutor::LoadModuleFromHsaco(const char* hsaco, return absl::OkStatus(); } -// This is a non-essential operation; if there's a failure, proceed without -// logging an error. It's nearly certain that in case of failures, we'd never -// get here in the first place; these are very low-impact routines. -void GpuExecutor::VlogOccupancyInfo(const DeviceDescription& device_description, - const Kernel& kernel, - const ThreadDim& thread_dims, - const BlockDim& block_dims) { - VLOG(2) << "Computing kernel occupancy for kernel " - << kernel.demangled_name(); - VLOG(2) << "Thread dimensions (" << thread_dims.x << ", " << thread_dims.y - << ", " << thread_dims.z << ")"; - - auto regs_per_thread = kernel.metadata().registers_per_thread(); - auto smem_per_block = kernel.metadata().shared_memory_bytes(); - - if (!regs_per_thread && !smem_per_block) { - return; - } - - const GpuKernel* rocm_kernel = AsGpuKernel(&kernel); - auto hipfunc = rocm_kernel->AsGpuFunctionHandle(); - - int blocks_per_sm = CalculateOccupancy(device_description, *regs_per_thread, - *smem_per_block, thread_dims, hipfunc); - VLOG(2) << "Resident blocks per SM is " << blocks_per_sm; - - int suggested_threads = - CompareOccupancy(&blocks_per_sm, device_description, *regs_per_thread, - *smem_per_block, thread_dims, hipfunc); - if (suggested_threads != 0) { - VLOG(2) << "The rocm occupancy calculator recommends using " - << suggested_threads 
- << " threads per block to achieve an occupancy of " << blocks_per_sm - << " blocks per SM."; - } -} - -// Compute and return maximum blocks per core (occupancy) based on the -// device description, some kernel characteristics and the number of threads per -// block. If unable to compute occupancy, zero is returned. -int GpuExecutor::CalculateOccupancy(const DeviceDescription& device_description, - uint64_t registers_per_thread, - uint64_t shared_memory_per_block, - const ThreadDim& thread_dims, - GpuFunctionHandle func) { - int suggested_blocks = 0; - int suggested_threads = 0; - (void)rocm::OccupancyGetMaxPotentialBlockSize( - &suggested_blocks, &suggested_threads, func, shared_memory_per_block, 0); - return suggested_blocks; -} - -// Compute and return the suggested thread count to achieve ideal occupancy. -// If the provided thread dimensions match this number, zero is returned. -int GpuExecutor::CompareOccupancy(int* initial_blocks, - const DeviceDescription& device_description, - uint64_t registers_per_thread, - uint64_t shared_memory_per_block, - const ThreadDim& thread_dims, - GpuFunctionHandle func) { - int suggested_blocks = 0; - int suggested_threads = 0; - (void)rocm::OccupancyGetMaxPotentialBlockSize( - &suggested_blocks, &suggested_threads, func, shared_memory_per_block, 0); - if (suggested_blocks > *initial_blocks) { - *initial_blocks = suggested_blocks; - return suggested_threads; - } else { - return 0; - } -} - DeviceMemoryBase GpuExecutor::Allocate(uint64_t size, int64_t memory_space) { if (memory_space == static_cast(stream_executor::MemoryType::kHost)) { From 89489e76b02a78963d5739393bdb3b93ae46c082 Mon Sep 17 00:00:00 2001 From: David Dunleavy Date: Tue, 25 Jun 2024 15:47:13 -0700 Subject: [PATCH 6/9] Add `xla/package_groups.bzl` and `xla/tsl/package_groups.bzl` to hold `package_groups` and replace Copybara rules PiperOrigin-RevId: 646633709 --- third_party/xla/opensource_only.files | 2 ++ third_party/xla/xla/BUILD | 39 ++-------------------- 
third_party/xla/xla/package_groups.bzl | 23 +++++++++++++ third_party/xla/xla/tests/BUILD | 11 ++---- third_party/xla/xla/tsl/BUILD | 33 ++---------------- third_party/xla/xla/tsl/package_groups.bzl | 7 ++++ 6 files changed, 41 insertions(+), 74 deletions(-) create mode 100644 third_party/xla/xla/package_groups.bzl create mode 100644 third_party/xla/xla/tsl/package_groups.bzl diff --git a/third_party/xla/opensource_only.files b/third_party/xla/opensource_only.files index 666a858f3e4bd9..baafd35265caaf 100644 --- a/third_party/xla/opensource_only.files +++ b/third_party/xla/opensource_only.files @@ -1,10 +1,12 @@ compiler/xla/mlir_hlo/WORKSPACE: +compiler/xla/package_groups.bzl: compiler/xla/stream_executor/build_defs.bzl: compiler/xla/tsl/cuda/stub.bzl: compiler/xla/tsl/mkl/BUILD: compiler/xla/tsl/mkl/LICENSE: compiler/xla/tsl/mkl/MKL_LICENSE: compiler/xla/tsl/mkl/build_defs.bzl: +compiler/xla/tsl/package_groups.bzl: third_party/BUILD: third_party/__init__:.py third_party/compute_library/BUILD: diff --git a/third_party/xla/xla/BUILD b/third_party/xla/xla/BUILD index 7b8bacdead5b0f..0b58e14e6b0a58 100644 --- a/third_party/xla/xla/BUILD +++ b/third_party/xla/xla/BUILD @@ -7,6 +7,7 @@ load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") load("//third_party/compute_library:build_defs.bzl", "if_enable_acl") # Placeholder: load py_proto_library +load("//xla:package_groups.bzl", "xla_package_groups") load("//xla:xla.bzl", "xla_cc_test", "xla_py_proto_library") load("//xla/tsl:tsl.bzl", "if_google", "internal_visibility") load("//xla/tsl:tsl.default.bzl", "filegroup", "get_compatible_with_portable") @@ -17,46 +18,12 @@ package( licenses = ["notice"], ) -package_group( - name = "friends", - includes = ["//xla:internal"], - packages = [ - # copybara:uncomment "//learning/...", - "//third_party/australis/...", - "//third_party/iree/...", - "//third_party/libxc/...", - "//third_party/mira/...", - "//third_party/mlcompass/...", - 
"//third_party/mlir_edge/model_curriculum/...", - "//third_party/openxla/shardonnay/...", - "//third_party/py/enzyme_ad/...", - "//third_party/py/jax/...", - "//third_party/py/t5x/...", - "//third_party/py/tpu_graphs/...", - "//tensorflow/compiler/...", - "//tensorflow/python/tpu/...", - ], -) - -package_group( - name = "internal", - packages = [ - "//xla/...", - ], -) - -package_group( - name = "runtime", - packages = [ - "//xla/runtime/...", - "//xla/service/gpu/runtime/...", - ], -) - exports_files([ "lit.cfg.py", ]) +xla_package_groups() + # Filegroup used to collect source files for dependency checking. filegroup( name = "c_srcs", diff --git a/third_party/xla/xla/package_groups.bzl b/third_party/xla/xla/package_groups.bzl new file mode 100644 index 00000000000000..ea554b3a8aa09d --- /dev/null +++ b/third_party/xla/xla/package_groups.bzl @@ -0,0 +1,23 @@ +"""XLA package_group definitions""" + +def xla_package_groups(name = "xla_package_groups"): + native.package_group( + name = "friends", + packages = ["//..."], + ) + + native.package_group( + name = "internal", + packages = ["//..."], + ) + + native.package_group( + name = "runtime", + packages = ["//..."], + ) + +def xla_tests_package_groups(name = "xla_tests_package_groups"): + native.package_group( + name = "friends", + packages = ["//..."], + ) diff --git a/third_party/xla/xla/tests/BUILD b/third_party/xla/xla/tests/BUILD index 7426ca4a1f5b41..c6b1a763480b26 100644 --- a/third_party/xla/xla/tests/BUILD +++ b/third_party/xla/xla/tests/BUILD @@ -11,6 +11,7 @@ load( "@local_tsl//tsl/platform/default:cuda_build_defs.bzl", "if_cuda_is_configured", ) +load("//xla:package_groups.bzl", "xla_tests_package_groups") load("//xla:xla.bzl", "tests_build_defs_bzl_deps", "xla_cc_binary", "xla_cc_test") load("//xla/tests:build_defs.bzl", "generate_backend_suites", "generate_backend_test_macros", "xla_test", "xla_test_library") load("//xla/tsl:tsl.bzl", "internal_visibility", "tsl_copts") @@ -22,14 +23,6 @@ package( 
licenses = ["notice"], ) -package_group( - name = "friends", - includes = [ - "//xla:friends", - ], - packages = ["//platforms/testing/tests/..."], -) - # Filegroup used to collect source files for dependency checking. filegroup( name = "c_srcs", @@ -39,6 +32,8 @@ filegroup( ]), ) +xla_tests_package_groups() + # Generate test_suites for all backends, named "${backend}_tests". generate_backend_suites() diff --git a/third_party/xla/xla/tsl/BUILD b/third_party/xla/xla/tsl/BUILD index 4d15019d7586f4..27a3e2ae8ab9bf 100644 --- a/third_party/xla/xla/tsl/BUILD +++ b/third_party/xla/xla/tsl/BUILD @@ -1,11 +1,14 @@ load("@bazel_skylib//:bzl_library.bzl", "bzl_library") load("@bazel_skylib//lib:selects.bzl", "selects") load("@bazel_skylib//rules:common_settings.bzl", "bool_flag", "bool_setting") +load("package_groups.bzl", "tsl_package_groups") load("tsl.bzl", "if_google", "if_oss") load("tsl.default.bzl", "tsl_google_bzl_deps") # copybara:uncomment package(default_applicable_licenses = ["//tensorflow:license"]) +tsl_package_groups() + # Config setting to use in select()s to distinguish open source build from # google internal build on configurable attributes. # @@ -497,36 +500,6 @@ config_setting( visibility = ["//visibility:public"], ) -# DO NOT ADD ANY NEW EXCEPTIONS TO THIS LIST! -# Instead, please use public APIs or public build rules TF provides. -# If you need functionality that is not exposed, we will work with you to expand our public APIs. 
-# TODO(b/173549186): Move Google-internal TF code out of learning/brain -# TODO(jakeharmon): Prune this for use in TSL -package_group( - name = "internal", - packages = [ - "//devtools/python/indexer/...", - "//learning/brain/keras/...", - "//learning/brain/mlir/...", - "//learning/brain/tfrt/...", - "//learning/lib/ami/simple_ml/...", - "//learning/pathways/...", - "//smartass/brain/configure/...", - "//tensorflow/...", - "//tensorflow_decision_forests/...", - "//tensorflow_federated/...", - "//third_party/cloud_tpu/convergence_tools/sdc_monitoring/...", - "//third_party/cloud_tpu/inference_converter/...", - "//third_party/py/cloud_ml_autoflow/...", - "//third_party/py/envlogger/...", - "//third_party/yggdrasil_decision_forests/...", - ] + if_google([ - # Needed in OSS, where bazel won't allow a package group to refer to an - # external repo. - "@local_tsl//tsl/...", - ]), -) - bzl_library( name = "tsl_bzl", srcs = ["tsl.bzl"], diff --git a/third_party/xla/xla/tsl/package_groups.bzl b/third_party/xla/xla/tsl/package_groups.bzl new file mode 100644 index 00000000000000..e4e44a7f4020f5 --- /dev/null +++ b/third_party/xla/xla/tsl/package_groups.bzl @@ -0,0 +1,7 @@ +"""TSL package_group definitions""" + +def tsl_package_groups(name = "tsl_package_groups"): + native.package_group( + name = "internal", + packages = ["//..."], + ) From 18a0eb07ef315a3a59a339b0b182fc05a81c28a3 Mon Sep 17 00:00:00 2001 From: Mohammed Anany Date: Tue, 25 Jun 2024 17:28:44 -0700 Subject: [PATCH 7/9] [Triton] Refactoring condition in autotuner to be more robust. Added test to make sure crashing Triton configurations are actually skipped and to guard against breaking it. 
PiperOrigin-RevId: 646663340 --- .../xla/service/gpu/gemm_fusion_autotuner.cc | 45 ++++++++++--------- .../service/gpu/gemm_fusion_autotuner_test.cc | 28 ++++++++++++ 2 files changed, 51 insertions(+), 22 deletions(-) diff --git a/third_party/xla/xla/service/gpu/gemm_fusion_autotuner.cc b/third_party/xla/xla/service/gpu/gemm_fusion_autotuner.cc index 25c1b5b3894c8f..6359e5f28dbcb5 100644 --- a/third_party/xla/xla/service/gpu/gemm_fusion_autotuner.cc +++ b/third_party/xla/xla/service/gpu/gemm_fusion_autotuner.cc @@ -312,11 +312,6 @@ absl::StatusOr GetLimits(const HloDotInstruction& dot) { const int max_k = tsl::NextPowerOfTwoS64( dot.operand(1)->shape().dimensions(contracting_index)); - // TODO(b/337839570): block_k = 16 is bugged in Triton for dots with 8-bit - // input. Setting minimum to 32 instead of 16 for these cases. - // TODO(b/337838200): Write the restriction on the minimum tile size to be - // generic. Currently we only handle the 8-bit case as this was the bug we - // ran into. return TileSizeLimit{ /*block_m=*/std::max(max_m, kMinTileSize), /*block_n=*/std::max(max_n, kMinTileSize), @@ -634,13 +629,20 @@ absl::StatusOr> GemmFusionAutotunerImpl::GenerateConfigs( absl::StatusOr> GemmFusionAutotunerImpl::GenerateTritonConfigs(const HloDotInstruction& dot) { - bool has_8_bit_operand = HloAnyOf({&dot}, [&](const HloInstruction* node) { - if (node->opcode() != HloOpcode::kConvert) { - return false; - } - auto in_type = node->operand(0)->shape().element_type(); - return primitive_util::BitWidth(in_type) == 8; - }); + // Retrieve the minimum bit-width participating in the dot. This is needed + // to avoid autotuning configurations that are not supported by Triton. This + // is used to restrict the values for tile_k. 
+ std::vector converts = + HloFindAll({&dot}, [&](const HloInstruction* node) { + return node->opcode() == HloOpcode::kConvert; + }); + int minBitWidth = primitive_util::BitWidth(dot.shape().element_type()); + for (auto convert : converts) { + auto in_type = convert->operand(0)->shape().element_type(); + auto out_type = convert->shape().element_type(); + minBitWidth = std::min({minBitWidth, primitive_util::BitWidth(in_type), + primitive_util::BitWidth(out_type)}); + } std::vector result_configs; TF_ASSIGN_OR_RETURN(TileSizeLimit limits, GetLimits(dot)); @@ -690,14 +692,12 @@ GemmFusionAutotunerImpl::GenerateTritonConfigs(const HloDotInstruction& dot) { } config.split_k = std::min(config.split_k, max_split_k); - // TODO(b/337839570): block_k = 16 is bugged in Triton for dots with 8-bit - // input. Setting minimum to 32 instead of 16 for these cases. - // TODO(b/337838200): Write the restriction on the minimum tile size to be - // generic. Currently we only handle the 8-bit case as this was the bug we - // ran into. - if (has_8_bit_operand && config.block_k == kMinTileSize) { - config.block_k *= 2; - } + // TODO(b/337839570): Triton currently has a limitation where it crashes + // on small block_k values depending on the bit-width of the inputs to the + // dot. The logic below accounts for this limitation. + constexpr int kLdmatrixGranularity = 256; + config.block_k = + std::max(config.block_k, kLdmatrixGranularity / minBitWidth); // Sparse meta should have at least one element per thread. // Note: only 2:4 structured sparsity is currently supported. @@ -706,8 +706,9 @@ GemmFusionAutotunerImpl::GenerateTritonConfigs(const HloDotInstruction& dot) { config.block_m = std::max(config.block_m, 64); config.num_warps = std::max(config.num_warps, 4); } - config.block_k = - std::max(config.block_k, kMinTileSize * (has_8_bit_operand ? 
4 : 2)); + config.block_k = std::max( + config.block_k, + 2 * std::max(kMinTileSize, kLdmatrixGranularity / minBitWidth)); int meta_elements = config.block_m * config.block_k / 16; config.num_warps = std::min(config.num_warps, meta_elements / WarpSize()); diff --git a/third_party/xla/xla/service/gpu/gemm_fusion_autotuner_test.cc b/third_party/xla/xla/service/gpu/gemm_fusion_autotuner_test.cc index fa75448abf42db..16410403233987 100644 --- a/third_party/xla/xla/service/gpu/gemm_fusion_autotuner_test.cc +++ b/third_party/xla/xla/service/gpu/gemm_fusion_autotuner_test.cc @@ -845,6 +845,34 @@ ENTRY e { )"); } +// TODO(b/337839570): Triton currently has a limitation where it crashes +// on small block_k values depending on the bit-width of the inputs to the +// dot. For this test case, it should skip any block_k values that are <= 16 +// since the smallest type has a bit-width of 8. +TEST_F(GemmFusionAutotunerExhaustiveTest, SkipsCrashingTileKConfig) { + std::unique_ptr module = ParseAndReturnVerifiedModule(R"( +HloModule module +ENTRY e { + x = s8[33,33]{1,0} parameter(0) + c = f16[33,33]{1,0} convert(x) + y = f16[33,33]{1,0} parameter(1) + ROOT out = f16[33,33]{1,0} dot(c, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} +} +)") + .value(); + const se::CudaComputeCapability compute_capability{ + se::CudaComputeCapability::AMPERE, /*minor=*/0}; + TF_ASSERT_OK_AND_ASSIGN( + const std::vector configs, + GetPossibleMatmulAutotuneConfigs( + *Cast( + module->entry_computation()->root_instruction()), + compute_capability, GetToolkitVersion(), GetDebugOptionsForTest())); + EXPECT_TRUE(std::all_of( + configs.begin(), configs.end(), + [](const TritonGemmConfig& config) { return config.block_k > 16; })); +} + class GemmFusionAutotunerDisableSplitK : public GemmFusionAutotunerTest { public: DebugOptions GetDebugOptionsForTest() override { From 633e9ccc0e93efeadffd572f8b969a54540df00b Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Tue, 25 Jun 2024 17:34:58 -0700 
Subject: [PATCH 8/9] [xla:cpu] Optimize KernelThunk by passing SE_HOST_KernelArg directly to the kernel PiperOrigin-RevId: 646664927 --- .../select_and_scatter_benchmark_test.cc | 1 + .../xla/service/cpu/runtime/kernel_thunk.cc | 30 +++++++++-------- .../service/cpu/runtime/kernel_thunk_test.cc | 4 +-- .../xla/stream_executor/host/host_kernel.cc | 33 ++++++++++++++----- .../xla/stream_executor/host/host_kernel.h | 5 +++ .../stream_executor/host/host_kernel_c_api.h | 2 +- .../stream_executor/host/host_kernel_test.cc | 25 +++++++++----- 7 files changed, 65 insertions(+), 35 deletions(-) diff --git a/third_party/xla/xla/service/cpu/benchmarks/select_and_scatter_benchmark_test.cc b/third_party/xla/xla/service/cpu/benchmarks/select_and_scatter_benchmark_test.cc index 7521dcda5b0f86..bbc32250444b0f 100644 --- a/third_party/xla/xla/service/cpu/benchmarks/select_and_scatter_benchmark_test.cc +++ b/third_party/xla/xla/service/cpu/benchmarks/select_and_scatter_benchmark_test.cc @@ -73,6 +73,7 @@ static void BM_SelectAndScatterF32(benchmark::State& state) { BENCHMARK(BM_SelectAndScatterF32) ->MeasureProcessCPUTime() + ->Arg(64) ->Arg(128) ->Arg(256) ->Arg(512) diff --git a/third_party/xla/xla/service/cpu/runtime/kernel_thunk.cc b/third_party/xla/xla/service/cpu/runtime/kernel_thunk.cc index 247c995649a525..a8d793d1076071 100644 --- a/third_party/xla/xla/service/cpu/runtime/kernel_thunk.cc +++ b/third_party/xla/xla/service/cpu/runtime/kernel_thunk.cc @@ -87,38 +87,40 @@ tsl::AsyncValueRef KernelThunk::Execute( kernel_name_, arguments_buffers_.size(), results_buffers_.size(), thread_dim_.ToString()); - absl::InlinedVector buffers_data; - buffers_data.reserve(arguments_buffers_.size() + results_buffers_.size()); + absl::InlinedVector kernel_args; + kernel_args.reserve(arguments_buffers_.size() + results_buffers_.size()); int64_t arg_num = 0; for (BufferAllocation::Slice& buffer : arguments_buffers_) { - TF_ASSIGN_OR_RETURN(buffers_data.emplace_back(), + 
TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase arg_data, params.buffer_allocations->GetDeviceAddress(buffer)); + kernel_args.push_back( + SE_HOST_KernelArg{arg_data.opaque(), arg_data.size()}); VLOG(3) << absl::StreamFormat(" arg #%d: %s (%p)", arg_num++, - buffer.ToString(), - buffers_data.back().opaque()); + buffer.ToString(), kernel_args.back().data); } int64_t res_num = 0; for (BufferAllocation::Slice& buffer : results_buffers_) { - TF_ASSIGN_OR_RETURN(buffers_data.emplace_back(), + TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase result_data, params.buffer_allocations->GetDeviceAddress(buffer)); + kernel_args.push_back( + SE_HOST_KernelArg{result_data.opaque(), result_data.size()}); VLOG(3) << absl::StreamFormat(" res #%d: %s (%p)", res_num++, - buffer.ToString(), - buffers_data.back().opaque()); + buffer.ToString(), kernel_args.back().data); } // Check that all buffers are aligned to the minimum alignment. We codegen // with the assumption that all buffers are aligned, and if they are not, we // will crash with a segmentation fault, or worse, produce incorrect results. 
if (min_alignment_.has_value()) { - for (int64_t i = 0; i < buffers_data.size(); ++i) { - auto ptr = reinterpret_cast(buffers_data[i].opaque()); + for (int64_t i = 0; i < kernel_args.size(); ++i) { + auto ptr = reinterpret_cast(kernel_args[i].data); if (ABSL_PREDICT_FALSE((ptr & (*min_alignment_ - 1)) != 0)) { return Internal( "Host kernel %s buffer argument #%d (%p) is not aligned to a " "required minimum alignment of %d bytes", - info().op_name, i, buffers_data[i].opaque(), *min_alignment_); + info().op_name, i, kernel_args[i].data, *min_alignment_); } } } @@ -134,7 +136,7 @@ tsl::AsyncValueRef KernelThunk::Execute( params.host_kernels->Find(kernel_name_)); absl::MutexLock lock(&mutex_); - kernel_.emplace(buffers_data.size(), kernel_fn, nullptr); + kernel_.emplace(kernel_args.size(), kernel_fn, nullptr); kernel_ptr_.store(kernel = &kernel_.value()); } @@ -142,14 +144,14 @@ tsl::AsyncValueRef KernelThunk::Execute( // by scheduling tasks into it. HostKernel launch completion will // automatically signal KernelThunk execute completion. 
if (ABSL_PREDICT_FALSE(params.intra_op_threadpool && use_task_runner_)) { - return kernel->Launch(thread_dim_, buffers_data, + return kernel->Launch(thread_dim_, kernel_args, [&params](se::host::HostKernel::Task task) { params.intra_op_threadpool->getPool()->Schedule( ToCopyableTask(std::move(task))); }); } - TF_RETURN_IF_ERROR(kernel->Launch(thread_dim_, buffers_data)); + TF_RETURN_IF_ERROR(kernel->Launch(thread_dim_, kernel_args)); return OkExecuteEvent(); } diff --git a/third_party/xla/xla/service/cpu/runtime/kernel_thunk_test.cc b/third_party/xla/xla/service/cpu/runtime/kernel_thunk_test.cc index a80db35857e86f..9f6993dc9a73a7 100644 --- a/third_party/xla/xla/service/cpu/runtime/kernel_thunk_test.cc +++ b/third_party/xla/xla/service/cpu/runtime/kernel_thunk_test.cc @@ -40,8 +40,8 @@ class AddF32HostKernels : public Thunk::HostKernels { public: absl::StatusOr Find(std::string_view name) override { return +[](const SE_HOST_KernelCallFrame* call_frame) { - SE_HOST_KernelArg& in = call_frame->args[0]; - SE_HOST_KernelArg& out = call_frame->args[1]; + const SE_HOST_KernelArg& in = call_frame->args[0]; + const SE_HOST_KernelArg& out = call_frame->args[1]; float* in_ptr = reinterpret_cast(in.data); float* out_ptr = reinterpret_cast(out.data); diff --git a/third_party/xla/xla/stream_executor/host/host_kernel.cc b/third_party/xla/xla/stream_executor/host/host_kernel.cc index ceb148cdaaf918..04586b5272432b 100644 --- a/third_party/xla/xla/stream_executor/host/host_kernel.cc +++ b/third_party/xla/xla/stream_executor/host/host_kernel.cc @@ -69,7 +69,7 @@ class HostKernelExecuteState HostKernelExecuteState(HostKernel::TaskRunner task_runner, HostKernel::KernelFunction* function, ThreadDim thread_dims, - absl::Span buffers); + absl::Span args); // Notify of a completion of a host kernel task. 
void Notify(absl::Status status); @@ -118,11 +118,19 @@ HostKernel::HostKernel(unsigned arity, SE_HOST_Kernel* kernel, absl::Status HostKernel::Launch( const ThreadDim& thread_dims, absl::Span buffers) const { - SE_HOST_KernelThreadDim kernel_thread_dims = {thread_dims.x, thread_dims.y, - thread_dims.z}; + return Launch(thread_dims, ConvertBuffersToKernelArgs(buffers)); +} + +absl::Status HostKernel::Launch( + const ThreadDim& thread_dims, + absl::Span args) const { + SE_HOST_KernelThreadDim kernel_thread_dims = { + thread_dims.x, + thread_dims.y, + thread_dims.z, + }; SE_HOST_Kernel* kernel = function_->kernel(); - auto args = ConvertBuffersToKernelArgs(buffers); for (uint64_t z = 0; z < thread_dims.z; ++z) { for (uint64_t y = 0; y < thread_dims.y; ++y) { @@ -134,7 +142,7 @@ absl::Status HostKernel::Launch( SE_HOST_KernelError* error = (*kernel)(&call_frame); - if (error != nullptr) { + if (ABSL_PREDICT_FALSE(error != nullptr)) { return absl::InternalError("Failed to call host kernel"); } } @@ -147,12 +155,19 @@ absl::Status HostKernel::Launch( tsl::AsyncValueRef HostKernel::Launch( const ThreadDim& thread_dims, absl::Span buffers, TaskRunner task_runner) const { + return Launch(thread_dims, ConvertBuffersToKernelArgs(buffers), + std::move(task_runner)); +} + +tsl::AsyncValueRef HostKernel::Launch( + const ThreadDim& thread_dims, absl::Span args, + TaskRunner task_runner) const { size_t num_tasks = thread_dims.x * thread_dims.y * thread_dims.z; CHECK_GT(num_tasks, 0) << "Number of tasks must be positive"; // Crash Ok // Short-circuit launch with a single task and run it in the caller thread. if (ABSL_PREDICT_TRUE(num_tasks == 1)) { - absl::Status launched = Launch(thread_dims, buffers); + absl::Status launched = Launch(thread_dims, args); return ABSL_PREDICT_TRUE(launched.ok()) ? 
OkLaunchEvent() : tsl::MakeErrorAsyncValueRef(std::move(launched)); @@ -160,7 +175,7 @@ tsl::AsyncValueRef HostKernel::Launch( // Allocate a control structure that will orchestrate kernel execution. auto state = tsl::MakeRef( - std::move(task_runner), function_.get(), thread_dims, buffers); + std::move(task_runner), function_.get(), thread_dims, args); state->CallAsync(/*start_index=*/0, /*end_index=*/num_tasks); @@ -169,12 +184,12 @@ tsl::AsyncValueRef HostKernel::Launch( HostKernelExecuteState::HostKernelExecuteState( HostKernel::TaskRunner task_runner, HostKernel::KernelFunction* function, - ThreadDim thread_dims, absl::Span buffers) + ThreadDim thread_dims, absl::Span args) : task_runner_(std::move(task_runner)), num_tasks_(thread_dims.x * thread_dims.y * thread_dims.z), kernel_(function->kernel()), thread_dims_({thread_dims.x, thread_dims.y, thread_dims.z}), - args_(ConvertBuffersToKernelArgs(buffers)), + args_(args.begin(), args.end()), abort_(false), counter_(num_tasks_), event_(tsl::MakeConstructedAsyncValueRef()) {} diff --git a/third_party/xla/xla/stream_executor/host/host_kernel.h b/third_party/xla/xla/stream_executor/host/host_kernel.h index e8e040a2d86173..9d278b2b79c357 100644 --- a/third_party/xla/xla/stream_executor/host/host_kernel.h +++ b/third_party/xla/xla/stream_executor/host/host_kernel.h @@ -80,6 +80,8 @@ class HostKernel : public Kernel { // `thread_dims` and calling the kernel function. 
absl::Status Launch(const ThreadDim& thread_dims, absl::Span buffers) const; + absl::Status Launch(const ThreadDim& thread_dims, + absl::Span args) const; // Launches the kernel by iterating over all threads in `thread_dims` and // calling `task_runner` to run individual task (implementation might decide @@ -93,6 +95,9 @@ class HostKernel : public Kernel { tsl::AsyncValueRef Launch( const ThreadDim& thread_dims, absl::Span buffers, TaskRunner task_runner) const; + tsl::AsyncValueRef Launch( + const ThreadDim& thread_dims, absl::Span args, + TaskRunner task_runner) const; // For host platform, we assume that a core is a thread, and we can run at // most one instance of a kernel on a given thread. diff --git a/third_party/xla/xla/stream_executor/host/host_kernel_c_api.h b/third_party/xla/xla/stream_executor/host/host_kernel_c_api.h index 6768706abc2800..30f710cb44b264 100644 --- a/third_party/xla/xla/stream_executor/host/host_kernel_c_api.h +++ b/third_party/xla/xla/stream_executor/host/host_kernel_c_api.h @@ -71,7 +71,7 @@ typedef struct SE_HOST_KernelCallFrame { SE_HOST_KernelThread* thread; size_t num_args; - SE_HOST_KernelArg* args; + const SE_HOST_KernelArg* args; } SE_HOST_KernelCallFrame; // Error reporting for host kernels. NULL means success. 
diff --git a/third_party/xla/xla/stream_executor/host/host_kernel_test.cc b/third_party/xla/xla/stream_executor/host/host_kernel_test.cc index 5a121bf17cb5b7..aff9e1ed19ce7b 100644 --- a/third_party/xla/xla/stream_executor/host/host_kernel_test.cc +++ b/third_party/xla/xla/stream_executor/host/host_kernel_test.cc @@ -53,9 +53,9 @@ static auto ToCopyableTask(HostKernel::Task task) { } static SE_HOST_KernelError* AddI32(const SE_HOST_KernelCallFrame* call_frame) { - SE_HOST_KernelArg& lhs = call_frame->args[0]; - SE_HOST_KernelArg& rhs = call_frame->args[1]; - SE_HOST_KernelArg& out = call_frame->args[2]; + const SE_HOST_KernelArg& lhs = call_frame->args[0]; + const SE_HOST_KernelArg& rhs = call_frame->args[1]; + const SE_HOST_KernelArg& out = call_frame->args[2]; int32_t* lhs_ptr = reinterpret_cast(lhs.data); int32_t* rhs_ptr = reinterpret_cast(rhs.data); @@ -217,7 +217,9 @@ TEST(HostKernelTest, LaunchAsync) { }; HostKernel host_kernel(/*arity=*/0, no_op); - auto event = host_kernel.Launch(ThreadDim(4, 4, 4), {}, std::move(runner)); + auto event = host_kernel.Launch(ThreadDim(4, 4, 4), + absl::Span(), + std::move(runner)); tsl::BlockUntilReady(event); EXPECT_TRUE(event.IsConcrete()); @@ -245,7 +247,9 @@ TEST(HostKernelTest, LaunchAsyncError) { }; HostKernel host_kernel(/*arity=*/0, maybe_error); - auto event = host_kernel.Launch(ThreadDim(4, 4, 4), {}, std::move(runner)); + auto event = host_kernel.Launch(ThreadDim(4, 4, 4), + absl::Span(), + std::move(runner)); tsl::BlockUntilReady(event); ASSERT_TRUE(event.IsError()); @@ -269,7 +273,8 @@ static void BM_HostKernelSyncLaunch(benchmark::State& state) { HostKernel kernel(/*arity=*/0, NoOp); for (auto _ : state) { - benchmark::DoNotOptimize(kernel.Launch(ThreadDim(tdim_x), /*buffers=*/{})); + benchmark::DoNotOptimize(kernel.Launch( + ThreadDim(tdim_x), absl::Span())); } } @@ -281,9 +286,11 @@ static void BM_HostKernelAsyncLaunch(benchmark::State& state) { HostKernel kernel(/*arity=*/0, NoOp); for (auto _ : state) { - 
auto event = kernel.Launch(ThreadDim(tdim_x), {}, [&](auto task) { - thread_pool->Schedule(ToCopyableTask(std::move(task))); - }); + auto event = + kernel.Launch(ThreadDim(tdim_x), absl::Span(), + [&](auto task) { + thread_pool->Schedule(ToCopyableTask(std::move(task))); + }); tsl::BlockUntilReady(event); } } From cde7a2e0906c0606cc60befdfd67cffa77481389 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 17 Jun 2024 23:54:47 -0700 Subject: [PATCH 9/9] Change AllReduceSimplifier to handle trivial cross-partition all-reduces. This CL ensures that AllReduceSimplifier can simplify trivial all-reduces (an all-reduce where each subgroup is formed of a single participant) that are not necessarily cross replica (for example a cross partition all-reduce). We only simplify non cross replica all-reduce when the module is SPMD. FUTURE_COPYBARA_INTEGRATE_REVIEW=https://github.com/openxla/xla/pull/14073 from apivovarov:select_compare_algsimp 6fe68d7319b272ff041b67e038359540cddda489 PiperOrigin-RevId: 644267887 --- .../mlir/lite/quantization/lite/BUILD | 93 ++++++----- .../quantization/lite/quantize_model_test.cc | 65 ++++---- .../lite/quantize_weights_test.cc | 17 +- .../mlir/lite/quantization/lite}/test_util.cc | 10 +- .../mlir/lite/quantization/lite}/test_util.h | 16 +- .../quantization/lite}/testdata/README.md | 0 .../lite}/testdata/add_with_const_input.bin | Bin .../quantization/lite}/testdata/argmax.bin | Bin .../lite}/testdata/broadcast_to.bin | Bin .../quantization/lite}/testdata/concat.bin | Bin .../quantization/lite}/testdata/custom_op.bin | Bin .../lite/quantization/lite}/testdata/fc.bin | Bin .../quantization/lite}/testdata/fc_qat.bin | Bin .../quantization/lite}/testdata/gather_nd.bin | Bin .../lite}/testdata/lstm_calibrated.bin | Bin .../lite}/testdata/lstm_calibrated2.bin | Bin .../lite}/testdata/lstm_quantized.bin | Bin .../lite}/testdata/lstm_quantized2.bin | Bin .../quantization/lite}/testdata/maximum.bin | Bin 
.../quantization/lite}/testdata/minimum.bin | Bin .../quantization/lite}/testdata/mixed.bin | Bin .../quantization/lite}/testdata/mixed16x8.bin | Bin .../testdata/multi_input_add_reshape.bin | Bin .../lite/quantization/lite}/testdata/pack.bin | Bin .../lite}/testdata/quantized_with_gather.bin | Bin .../testdata/resource_vars_calibrated.bin | Bin ...single_avg_pool_min_minus_5_max_plus_5.bin | Bin .../lite}/testdata/single_conv_no_bias.bin | Bin .../single_conv_weights_min_0_max_plus_10.bin | Bin ...onv_weights_min_minus_127_max_plus_127.bin | Bin .../single_softmax_min_minus_5_max_plus_5.bin | Bin .../quantization/lite}/testdata/split.bin | Bin .../lite}/testdata/svdf_calibrated.bin | Bin .../lite}/testdata/svdf_quantized.bin | Bin .../quantization/lite}/testdata/transpose.bin | Bin ...nidirectional_sequence_lstm_calibrated.bin | Bin ...unidirectional_sequence_lstm_quantized.bin | Bin .../quantization/lite}/testdata/unpack.bin | Bin .../testdata/weight_shared_between_convs.bin | Bin .../quantization/lite}/testdata/where.bin | Bin tensorflow/lite/tools/optimize/BUILD | 103 +++++------- .../lite/tools/optimize/model_utils_test.cc | 1 - .../tools/optimize/quantization_utils_test.cc | 4 +- .../tools/optimize/quantize_model_test.cc | 106 ++++++++----- .../tools/optimize/quantize_weights_test.cc | 16 +- .../reduced_precision_support_test.cc | 1 - third_party/xla/xla/service/BUILD | 5 + .../xla/xla/service/algebraic_simplifier.cc | 27 ++++ .../xla/service/algebraic_simplifier_test.cc | 89 +++++++++++ .../xla/xla/service/all_reduce_simplifier.cc | 62 ++++++-- .../xla/service/all_reduce_simplifier_test.cc | 91 +++++++++++ .../xla/xla/service/collective_ops_utils.cc | 8 +- .../xla/service/collective_ops_utils_test.cc | 103 ++++++++++-- .../gpu/gpu_windowed_einsum_handler.cc | 149 +++++++++++++++--- .../gpu/gpu_windowed_einsum_handler_test.cc | 86 +++++++++- 55 files changed, 790 insertions(+), 262 deletions(-) rename tensorflow/{lite/tools/optimize => 
compiler/mlir/lite/quantization/lite}/test_util.cc (95%) rename tensorflow/{lite/tools/optimize => compiler/mlir/lite/quantization/lite}/test_util.h (93%) rename tensorflow/{lite/tools/optimize => compiler/mlir/lite/quantization/lite}/testdata/README.md (100%) rename tensorflow/{lite/tools/optimize => compiler/mlir/lite/quantization/lite}/testdata/add_with_const_input.bin (100%) rename tensorflow/{lite/tools/optimize => compiler/mlir/lite/quantization/lite}/testdata/argmax.bin (100%) rename tensorflow/{lite/tools/optimize => compiler/mlir/lite/quantization/lite}/testdata/broadcast_to.bin (100%) rename tensorflow/{lite/tools/optimize => compiler/mlir/lite/quantization/lite}/testdata/concat.bin (100%) rename tensorflow/{lite/tools/optimize => compiler/mlir/lite/quantization/lite}/testdata/custom_op.bin (100%) rename tensorflow/{lite/tools/optimize => compiler/mlir/lite/quantization/lite}/testdata/fc.bin (100%) rename tensorflow/{lite/tools/optimize => compiler/mlir/lite/quantization/lite}/testdata/fc_qat.bin (100%) rename tensorflow/{lite/tools/optimize => compiler/mlir/lite/quantization/lite}/testdata/gather_nd.bin (100%) rename tensorflow/{lite/tools/optimize => compiler/mlir/lite/quantization/lite}/testdata/lstm_calibrated.bin (100%) rename tensorflow/{lite/tools/optimize => compiler/mlir/lite/quantization/lite}/testdata/lstm_calibrated2.bin (100%) rename tensorflow/{lite/tools/optimize => compiler/mlir/lite/quantization/lite}/testdata/lstm_quantized.bin (100%) rename tensorflow/{lite/tools/optimize => compiler/mlir/lite/quantization/lite}/testdata/lstm_quantized2.bin (100%) rename tensorflow/{lite/tools/optimize => compiler/mlir/lite/quantization/lite}/testdata/maximum.bin (100%) rename tensorflow/{lite/tools/optimize => compiler/mlir/lite/quantization/lite}/testdata/minimum.bin (100%) rename tensorflow/{lite/tools/optimize => compiler/mlir/lite/quantization/lite}/testdata/mixed.bin (100%) rename tensorflow/{lite/tools/optimize => 
compiler/mlir/lite/quantization/lite}/testdata/mixed16x8.bin (100%) rename tensorflow/{lite/tools/optimize => compiler/mlir/lite/quantization/lite}/testdata/multi_input_add_reshape.bin (100%) rename tensorflow/{lite/tools/optimize => compiler/mlir/lite/quantization/lite}/testdata/pack.bin (100%) rename tensorflow/{lite/tools/optimize => compiler/mlir/lite/quantization/lite}/testdata/quantized_with_gather.bin (100%) rename tensorflow/{lite/tools/optimize => compiler/mlir/lite/quantization/lite}/testdata/resource_vars_calibrated.bin (100%) rename tensorflow/{lite/tools/optimize => compiler/mlir/lite/quantization/lite}/testdata/single_avg_pool_min_minus_5_max_plus_5.bin (100%) rename tensorflow/{lite/tools/optimize => compiler/mlir/lite/quantization/lite}/testdata/single_conv_no_bias.bin (100%) rename tensorflow/{lite/tools/optimize => compiler/mlir/lite/quantization/lite}/testdata/single_conv_weights_min_0_max_plus_10.bin (100%) rename tensorflow/{lite/tools/optimize => compiler/mlir/lite/quantization/lite}/testdata/single_conv_weights_min_minus_127_max_plus_127.bin (100%) rename tensorflow/{lite/tools/optimize => compiler/mlir/lite/quantization/lite}/testdata/single_softmax_min_minus_5_max_plus_5.bin (100%) rename tensorflow/{lite/tools/optimize => compiler/mlir/lite/quantization/lite}/testdata/split.bin (100%) rename tensorflow/{lite/tools/optimize => compiler/mlir/lite/quantization/lite}/testdata/svdf_calibrated.bin (100%) rename tensorflow/{lite/tools/optimize => compiler/mlir/lite/quantization/lite}/testdata/svdf_quantized.bin (100%) rename tensorflow/{lite/tools/optimize => compiler/mlir/lite/quantization/lite}/testdata/transpose.bin (100%) rename tensorflow/{lite/tools/optimize => compiler/mlir/lite/quantization/lite}/testdata/unidirectional_sequence_lstm_calibrated.bin (100%) rename tensorflow/{lite/tools/optimize => compiler/mlir/lite/quantization/lite}/testdata/unidirectional_sequence_lstm_quantized.bin (100%) rename tensorflow/{lite/tools/optimize => 
compiler/mlir/lite/quantization/lite}/testdata/unpack.bin (100%) rename tensorflow/{lite/tools/optimize => compiler/mlir/lite/quantization/lite}/testdata/weight_shared_between_convs.bin (100%) rename tensorflow/{lite/tools/optimize => compiler/mlir/lite/quantization/lite}/testdata/where.bin (100%) diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/BUILD b/tensorflow/compiler/mlir/lite/quantization/lite/BUILD index 78ae512eb54202..e57c2b30808d82 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/quantization/lite/BUILD @@ -10,6 +10,10 @@ package( licenses = ["notice"], ) +exports_files(glob([ + "testdata/*.bin", +])) + package_group( name = "friends", packages = [ @@ -123,39 +127,39 @@ tf_cc_test( name = "quantize_model_test", srcs = ["quantize_model_test.cc"], args = [ - "--test_model_file=$(location //tensorflow/lite/tools/optimize:testdata/single_conv_weights_min_0_max_plus_10.bin)", + "--test_model_file=$(location //tensorflow/compiler/mlir/lite/quantization/lite:testdata/single_conv_weights_min_0_max_plus_10.bin)", ], data = [ - "//tensorflow/lite/tools/optimize:testdata/add_with_const_input.bin", - "//tensorflow/lite/tools/optimize:testdata/argmax.bin", - "//tensorflow/lite/tools/optimize:testdata/broadcast_to.bin", - "//tensorflow/lite/tools/optimize:testdata/concat.bin", - "//tensorflow/lite/tools/optimize:testdata/fc.bin", - "//tensorflow/lite/tools/optimize:testdata/fc_qat.bin", - "//tensorflow/lite/tools/optimize:testdata/gather_nd.bin", - "//tensorflow/lite/tools/optimize:testdata/lstm_calibrated.bin", - "//tensorflow/lite/tools/optimize:testdata/lstm_calibrated2.bin", - "//tensorflow/lite/tools/optimize:testdata/lstm_quantized.bin", - "//tensorflow/lite/tools/optimize:testdata/lstm_quantized2.bin", - "//tensorflow/lite/tools/optimize:testdata/maximum.bin", - "//tensorflow/lite/tools/optimize:testdata/minimum.bin", - "//tensorflow/lite/tools/optimize:testdata/mixed.bin", - 
"//tensorflow/lite/tools/optimize:testdata/mixed16x8.bin", - "//tensorflow/lite/tools/optimize:testdata/multi_input_add_reshape.bin", - "//tensorflow/lite/tools/optimize:testdata/pack.bin", - "//tensorflow/lite/tools/optimize:testdata/single_avg_pool_min_minus_5_max_plus_5.bin", - "//tensorflow/lite/tools/optimize:testdata/single_conv_no_bias.bin", - "//tensorflow/lite/tools/optimize:testdata/single_conv_weights_min_0_max_plus_10.bin", - "//tensorflow/lite/tools/optimize:testdata/single_conv_weights_min_minus_127_max_plus_127.bin", - "//tensorflow/lite/tools/optimize:testdata/single_softmax_min_minus_5_max_plus_5.bin", - "//tensorflow/lite/tools/optimize:testdata/split.bin", - "//tensorflow/lite/tools/optimize:testdata/svdf_calibrated.bin", - "//tensorflow/lite/tools/optimize:testdata/svdf_quantized.bin", - "//tensorflow/lite/tools/optimize:testdata/transpose.bin", - "//tensorflow/lite/tools/optimize:testdata/unidirectional_sequence_lstm_calibrated.bin", - "//tensorflow/lite/tools/optimize:testdata/unidirectional_sequence_lstm_quantized.bin", - "//tensorflow/lite/tools/optimize:testdata/unpack.bin", - "//tensorflow/lite/tools/optimize:testdata/where.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/add_with_const_input.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/argmax.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/broadcast_to.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/concat.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/fc.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/fc_qat.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/gather_nd.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/lstm_calibrated.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/lstm_calibrated2.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/lstm_quantized.bin", + 
"//tensorflow/compiler/mlir/lite/quantization/lite:testdata/lstm_quantized2.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/maximum.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/minimum.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/mixed.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/mixed16x8.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/multi_input_add_reshape.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/pack.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/single_avg_pool_min_minus_5_max_plus_5.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/single_conv_no_bias.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/single_conv_weights_min_0_max_plus_10.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/single_conv_weights_min_minus_127_max_plus_127.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/single_softmax_min_minus_5_max_plus_5.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/split.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/svdf_calibrated.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/svdf_quantized.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/transpose.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/unidirectional_sequence_lstm_calibrated.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/unidirectional_sequence_lstm_quantized.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/unpack.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/where.bin", ], tags = [ "tflite_not_portable_android", @@ -163,12 +167,12 @@ tf_cc_test( ], deps = [ ":quantize_model", + ":test_util", "//tensorflow/compiler/mlir/lite/schema:schema_fbs", "//tensorflow/compiler/mlir/lite/schema:schema_utils", 
"//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/lite:framework", - "//tensorflow/lite/tools/optimize:test_util", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/status", "@com_google_googletest//:gtest", @@ -181,13 +185,13 @@ tf_cc_test( name = "quantize_weights_test", srcs = ["quantize_weights_test.cc"], args = [ - "--test_model_file=$(location //tensorflow/lite/tools/optimize:testdata/single_conv_weights_min_0_max_plus_10.bin)", + "--test_model_file=$(location //tensorflow/compiler/mlir/lite/quantization/lite:testdata/single_conv_weights_min_0_max_plus_10.bin)", ], data = [ - "//tensorflow/lite/tools/optimize:testdata/custom_op.bin", - "//tensorflow/lite/tools/optimize:testdata/quantized_with_gather.bin", - "//tensorflow/lite/tools/optimize:testdata/single_conv_weights_min_0_max_plus_10.bin", - "//tensorflow/lite/tools/optimize:testdata/weight_shared_between_convs.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/custom_op.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/quantized_with_gather.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/single_conv_weights_min_0_max_plus_10.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/weight_shared_between_convs.bin", ], tags = [ # TODO(b/327796566): re-enable after the bug is fixed @@ -200,15 +204,28 @@ tf_cc_test( ], deps = [ ":quantize_weights", + ":test_util", "//tensorflow/compiler/mlir/lite/schema:schema_fbs", "//tensorflow/compiler/mlir/lite/schema:schema_utils", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/lite:framework", - "//tensorflow/lite/tools/optimize:test_util", "@com_google_absl//absl/status", "@com_google_googletest//:gtest", "@flatbuffers", "@local_tsl//tsl/platform:logging", ], ) + +cc_library( + name = "test_util", + testonly = 1, + srcs = ["test_util.cc"], + hdrs = ["test_util.h"], + deps = [ + "//tensorflow/lite:framework", + 
"//tensorflow/lite/core/api", + "@com_google_googletest//:gtest", + "@flatbuffers", + ], +) diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model_test.cc b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model_test.cc index e7d5e00b703392..1e7cdcdea07d33 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model_test.cc +++ b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model_test.cc @@ -30,6 +30,7 @@ limitations under the License. #include "absl/status/status.h" #include "flatbuffers/flatbuffer_builder.h" // from @flatbuffers #include "flatbuffers/vector.h" // from @flatbuffers +#include "tensorflow/compiler/mlir/lite/quantization/lite/test_util.h" #include "tensorflow/compiler/mlir/lite/schema/schema_generated.h" #include "tensorflow/compiler/mlir/lite/schema/schema_utils.h" #include "tensorflow/core/platform/init_main.h" @@ -37,7 +38,6 @@ limitations under the License. #include "tensorflow/core/platform/types.h" #include "tensorflow/core/util/command_line_flags.h" #include "tensorflow/lite/model_builder.h" -#include "tensorflow/lite/tools/optimize/test_util.h" #include "tsl/lib/core/status_test_util.h" // Note: branched from tensorflow/lite/tools/optimize/quantize_model_test.cc @@ -192,7 +192,8 @@ void VerifyQuantizationScale( class QuantizeModelTest : public testing::Test { protected: QuantizeModelTest() { - input_model_ = ReadModel(internal::kConvModelWith0Plus10Weights); + input_model_ = + ReadModel(::mlir::lite::internal::kConvModelWith0Plus10Weights); readonly_model_ = input_model_->GetModel(); model_ = UnPackFlatBufferModel(*readonly_model_); } @@ -277,7 +278,8 @@ class QuantizeConvModelTest : public QuantizeModelTest, protected: QuantizeConvModelTest() { tensor_type_ = GetParam(); - input_model_ = ReadModel(internal::kConvModelWith0Plus10Weights); + input_model_ = + ReadModel(::mlir::lite::internal::kConvModelWith0Plus10Weights); readonly_model_ = input_model_->GetModel(); model_ = 
UnPackFlatBufferModel(*readonly_model_); // Flatbuffer is missing calibration data -- add dummy params. @@ -347,7 +349,7 @@ TEST_P(QuantizeConvModelTest, GraphIsFullyQuantized) { class QuantizeConvNoBiasModelTest : public QuantizeModelTest { protected: QuantizeConvNoBiasModelTest() { - input_model_ = ReadModel(internal::kConvModelWithNoBias); + input_model_ = ReadModel(::mlir::lite::internal::kConvModelWithNoBias); readonly_model_ = input_model_->GetModel(); model_ = UnPackFlatBufferModel(*readonly_model_); } @@ -367,7 +369,7 @@ class QuantizeConvNoBiasModelTest : public QuantizeModelTest { class QuantizeSplitModelTest : public QuantizeModelTest { protected: QuantizeSplitModelTest() { - input_model_ = ReadModel(internal::kModelSplit); + input_model_ = ReadModel(::mlir::lite::internal::kModelSplit); readonly_model_ = input_model_->GetModel(); model_ = UnPackFlatBufferModel(*readonly_model_); } @@ -452,7 +454,8 @@ class QuantizeConvModel2Test : public QuantizeModelTest, protected: QuantizeConvModel2Test() { tensor_type_ = GetParam(); - input_model_ = ReadModel(internal::kConvModelWith0Plus10Weights); + input_model_ = + ReadModel(::mlir::lite::internal::kConvModelWith0Plus10Weights); readonly_model_ = input_model_->GetModel(); model_ = UnPackFlatBufferModel(*readonly_model_); auto& subgraph = model_.subgraphs[0]; @@ -690,7 +693,8 @@ TEST_P(QuantizeConvModel2Test, VerifyConvDisablePerChannelQuantization) { class QuantizeSoftmaxTest : public QuantizeModelTest { protected: QuantizeSoftmaxTest() { - input_model_ = ReadModel(internal::kSingleSoftmaxModelMinMinus5MaxPlus5); + input_model_ = + ReadModel(::mlir::lite::internal::kSingleSoftmaxModelMinMinus5MaxPlus5); readonly_model_ = input_model_->GetModel(); model_ = UnPackFlatBufferModel(*readonly_model_); } @@ -753,7 +757,8 @@ TEST_F(QuantizeSoftmaxTest, VerifySoftmaxQuantization) { class QuantizeAvgPoolTest : public QuantizeModelTest { protected: QuantizeAvgPoolTest() { - input_model_ = 
ReadModel(internal::kSingleAvgPoolModelMinMinus5MaxPlus5); + input_model_ = + ReadModel(::mlir::lite::internal::kSingleAvgPoolModelMinMinus5MaxPlus5); readonly_model_ = input_model_->GetModel(); model_ = UnPackFlatBufferModel(*readonly_model_); } @@ -813,7 +818,7 @@ TEST_F(QuantizeAvgPoolTest, VerifyAvgPoolQuantization) { class QuantizeMultiInputAddWithReshapeTest : public QuantizeModelTest { protected: QuantizeMultiInputAddWithReshapeTest() { - input_model_ = ReadModel(internal::kMultiInputAddWithReshape); + input_model_ = ReadModel(::mlir::lite::internal::kMultiInputAddWithReshape); readonly_model_ = input_model_->GetModel(); model_ = UnPackFlatBufferModel(*readonly_model_); } @@ -933,7 +938,7 @@ class QuantizeConstInputTest : public QuantizeModelTest, protected: QuantizeConstInputTest() { tensor_type_ = GetParam(); - input_model_ = ReadModel(internal::kConstInputAddModel); + input_model_ = ReadModel(::mlir::lite::internal::kConstInputAddModel); readonly_model_ = input_model_->GetModel(); model_ = UnPackFlatBufferModel(*readonly_model_); } @@ -980,7 +985,7 @@ TEST_P(QuantizeConstInputTest, VerifyConstOpInput) { class QuantizeArgMaxTest : public QuantizeModelTest { protected: QuantizeArgMaxTest() { - input_model_ = ReadModel(internal::kModelWithArgMaxOp); + input_model_ = ReadModel(::mlir::lite::internal::kModelWithArgMaxOp); readonly_model_ = input_model_->GetModel(); model_ = UnPackFlatBufferModel(*readonly_model_); } @@ -1025,7 +1030,7 @@ TEST_F(QuantizeArgMaxTest, VerifyArgMax) { class QuantizeLSTMTest : public QuantizeModelTest { protected: QuantizeLSTMTest() { - input_model_ = ReadModel(internal::kLstmCalibrated); + input_model_ = ReadModel(::mlir::lite::internal::kLstmCalibrated); readonly_model_ = input_model_->GetModel(); model_ = UnPackFlatBufferModel(*readonly_model_); } @@ -1037,7 +1042,7 @@ TEST_F(QuantizeLSTMTest, VerifyLSTM) { /*allow_float=*/true, TensorType_INT8, output_buffer_)); // Read expected model. 
- auto expected_fb_model = ReadModel(internal::kLstmQuantized); + auto expected_fb_model = ReadModel(::mlir::lite::internal::kLstmQuantized); auto expected_read_only_model = expected_fb_model->GetModel(); ModelT expected_model; expected_read_only_model->UnPackTo(&expected_model); @@ -1048,7 +1053,7 @@ TEST_F(QuantizeLSTMTest, VerifyLSTM) { class QuantizeLSTM2Test : public QuantizeModelTest { protected: QuantizeLSTM2Test() { - input_model_ = ReadModel(internal::kLstmCalibrated2); + input_model_ = ReadModel(::mlir::lite::internal::kLstmCalibrated2); readonly_model_ = input_model_->GetModel(); model_ = UnPackFlatBufferModel(*readonly_model_); } @@ -1061,7 +1066,7 @@ TEST_F(QuantizeLSTM2Test, VerifyLSTM) { /*allow_float=*/false, TensorType_INT8, output_buffer_)); // Read expected model. - auto expected_fb_model = ReadModel(internal::kLstmQuantized2); + auto expected_fb_model = ReadModel(::mlir::lite::internal::kLstmQuantized2); auto expected_read_only_model = expected_fb_model->GetModel(); ModelT expected_model; expected_read_only_model->UnPackTo(&expected_model); @@ -1072,7 +1077,8 @@ TEST_F(QuantizeLSTM2Test, VerifyLSTM) { class QuantizeUnidirectionalSequenceLSTMTest : public QuantizeModelTest { protected: QuantizeUnidirectionalSequenceLSTMTest() { - input_model_ = ReadModel(internal::kUnidirectionalSequenceLstmCalibrated); + input_model_ = ReadModel( + ::mlir::lite::internal::kUnidirectionalSequenceLstmCalibrated); readonly_model_ = input_model_->GetModel(); model_ = UnPackFlatBufferModel(*readonly_model_); } @@ -1086,7 +1092,7 @@ TEST_F(QuantizeUnidirectionalSequenceLSTMTest, // Read expected model. 
auto expected_fb_model = - ReadModel(internal::kUnidirectionalSequenceLstmQuantized); + ReadModel(::mlir::lite::internal::kUnidirectionalSequenceLstmQuantized); auto expected_read_only_model = expected_fb_model->GetModel(); ModelT expected_model; expected_read_only_model->UnPackTo(&expected_model); @@ -1097,7 +1103,7 @@ TEST_F(QuantizeUnidirectionalSequenceLSTMTest, class QuantizeSVDFTest : public QuantizeModelTest { protected: QuantizeSVDFTest() { - input_model_ = ReadModel(internal::kSvdfCalibrated); + input_model_ = ReadModel(::mlir::lite::internal::kSvdfCalibrated); readonly_model_ = input_model_->GetModel(); model_ = UnPackFlatBufferModel(*readonly_model_); } @@ -1110,7 +1116,7 @@ TEST_F(QuantizeSVDFTest, VerifySVDF) { /*allow_float=*/false, TensorType_INT8, output_buffer_)); // Read expected model. - auto expected_fb_model = ReadModel(internal::kSvdfQuantized); + auto expected_fb_model = ReadModel(::mlir::lite::internal::kSvdfQuantized); auto expected_read_only_model = expected_fb_model->GetModel(); ModelT expected_model; expected_read_only_model->UnPackTo(&expected_model); @@ -1123,7 +1129,7 @@ class QuantizeFCTest : public QuantizeModelTest, protected: QuantizeFCTest() { disable_per_channel_quantization_for_dense_ = GetParam(); - input_model_ = ReadModel(internal::kModelWithFCOp); + input_model_ = ReadModel(::mlir::lite::internal::kModelWithFCOp); readonly_model_ = input_model_->GetModel(); model_ = UnPackFlatBufferModel(*readonly_model_); } @@ -1371,7 +1377,7 @@ class QuantizeCustomOpTest public ::testing::WithParamInterface { protected: QuantizeCustomOpTest() { - input_model_ = ReadModel(internal::kModelMixed); + input_model_ = ReadModel(::mlir::lite::internal::kModelMixed); readonly_model_ = input_model_->GetModel(); model_ = UnPackFlatBufferModel(*readonly_model_); } @@ -1409,7 +1415,7 @@ INSTANTIATE_TEST_SUITE_P(QuantizeCustomOpTest, QuantizeCustomOpTest, class QuantizePackTest : public QuantizeModelTest { protected: QuantizePackTest() { - input_model_ 
= ReadModel(internal::kModelPack); + input_model_ = ReadModel(::mlir::lite::internal::kModelPack); readonly_model_ = input_model_->GetModel(); model_ = UnPackFlatBufferModel(*readonly_model_); } @@ -1526,14 +1532,15 @@ TEST_P(QuantizeMinimumMaximumTest, VerifyMinimumMaximum) { Eq(input2->quantization->zero_point)); } -INSTANTIATE_TEST_SUITE_P(MinimumMaximumTestInst, QuantizeMinimumMaximumTest, - testing::ValuesIn({internal::kModelWithMinimumOp, - internal::kModelWithMaximumOp})); +INSTANTIATE_TEST_SUITE_P( + MinimumMaximumTestInst, QuantizeMinimumMaximumTest, + testing::ValuesIn({::mlir::lite::internal::kModelWithMinimumOp, + ::mlir::lite::internal::kModelWithMaximumOp})); class QuantizeUnpackTest : public QuantizeModelTest { protected: QuantizeUnpackTest() { - input_model_ = ReadModel(internal::kModelWithUnpack); + input_model_ = ReadModel(::mlir::lite::internal::kModelWithUnpack); readonly_model_ = input_model_->GetModel(); model_ = UnPackFlatBufferModel(*readonly_model_); } @@ -1583,7 +1590,7 @@ class QuantizeBroadcastToModelTest protected: QuantizeBroadcastToModelTest() { tensor_type_ = GetParam(); - input_model_ = ReadModel(internal::kModelWithBroadcastToOp); + input_model_ = ReadModel(::mlir::lite::internal::kModelWithBroadcastToOp); readonly_model_ = input_model_->GetModel(); model_ = UnPackFlatBufferModel(*readonly_model_); } @@ -1646,7 +1653,7 @@ class QuantizeGatherNDModelTest protected: QuantizeGatherNDModelTest() { tensor_type_ = GetParam(); - input_model_ = ReadModel(internal::kModelWithGatherNDOp); + input_model_ = ReadModel(::mlir::lite::internal::kModelWithGatherNDOp); readonly_model_ = input_model_->GetModel(); model_ = UnPackFlatBufferModel(*readonly_model_); } @@ -1706,7 +1713,7 @@ TEST_P(QuantizeGatherNDModelTest, QuantizeGatherND) { class QuantizeWhereModelTest : public QuantizeModelTest { protected: QuantizeWhereModelTest() { - input_model_ = ReadModel(internal::kModelWithWhereOp); + input_model_ = 
ReadModel(::mlir::lite::internal::kModelWithWhereOp); readonly_model_ = input_model_->GetModel(); model_ = UnPackFlatBufferModel(*readonly_model_); } diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_weights_test.cc b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_weights_test.cc index 2e80bcae7486b4..7a42e74c2619af 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_weights_test.cc +++ b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_weights_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/quantization/lite/quantize_weights.h" #include +#include #include #include #include @@ -26,6 +27,7 @@ limitations under the License. #include "flatbuffers/buffer.h" // from @flatbuffers #include "flatbuffers/flatbuffer_builder.h" // from @flatbuffers #include "flatbuffers/vector.h" // from @flatbuffers +#include "tensorflow/compiler/mlir/lite/quantization/lite/test_util.h" #include "tensorflow/compiler/mlir/lite/schema/schema_generated.h" #include "tensorflow/compiler/mlir/lite/schema/schema_utils.h" #include "tensorflow/core/platform/init_main.h" @@ -33,7 +35,6 @@ limitations under the License. 
#include "tensorflow/core/platform/types.h" #include "tensorflow/core/util/command_line_flags.h" #include "tensorflow/lite/model_builder.h" -#include "tensorflow/lite/tools/optimize/test_util.h" #include "tsl/platform/logging.h" // Note: branched from tensorflow/lite/tools/optimize/quantize_weights_test.cc @@ -59,25 +60,25 @@ std::unique_ptr CreateMutableModelFromFile(const Model* input_model) { std::unique_ptr ReadTestModel() { auto model_path = tensorflow::io::JoinPath( - *g_test_model_dir, internal::kConvModelWith0Plus10Weights); + *g_test_model_dir, ::mlir::lite::internal::kConvModelWith0Plus10Weights); return FlatBufferModel::BuildFromFile(model_path.c_str()); } std::unique_ptr ReadSharedWeightsTestModel() { - auto model_path = tensorflow::io::JoinPath(*g_test_model_dir, - internal::kModelWithSharedWeights); + auto model_path = tensorflow::io::JoinPath( + *g_test_model_dir, ::mlir::lite::internal::kModelWithSharedWeights); return FlatBufferModel::BuildFromFile(model_path.c_str()); } std::unique_ptr ReadGatherTestModel() { - auto model_path = tensorflow::io::JoinPath(*g_test_model_dir, - internal::kQuantizedWithGather); + auto model_path = tensorflow::io::JoinPath( + *g_test_model_dir, ::mlir::lite::internal::kQuantizedWithGather); return FlatBufferModel::BuildFromFile(model_path.c_str()); } std::unique_ptr ReadCustomOpTestModel() { - auto model_path = - tensorflow::io::JoinPath(*g_test_model_dir, internal::kModelWithCustomOp); + auto model_path = tensorflow::io::JoinPath( + *g_test_model_dir, ::mlir::lite::internal::kModelWithCustomOp); return FlatBufferModel::BuildFromFile(model_path.c_str()); } diff --git a/tensorflow/lite/tools/optimize/test_util.cc b/tensorflow/compiler/mlir/lite/quantization/lite/test_util.cc similarity index 95% rename from tensorflow/lite/tools/optimize/test_util.cc rename to tensorflow/compiler/mlir/lite/quantization/lite/test_util.cc index 5ca45326d1dcad..e096868eec8807 100644 --- a/tensorflow/lite/tools/optimize/test_util.cc +++ 
b/tensorflow/compiler/mlir/lite/quantization/lite/test_util.cc @@ -12,12 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/lite/tools/optimize/test_util.h" +#include "tensorflow/compiler/mlir/lite/quantization/lite/test_util.h" #include -namespace tflite { -namespace optimize { +namespace mlir { +namespace lite { namespace internal { const char* kConvModelWithMinus128Plus127Weights = "single_conv_weights_min_minus_127_max_plus_127.bin"; @@ -89,5 +89,5 @@ int FailOnErrorReporter::Report(const char* format, va_list args) { return 0; } } // namespace internal -} // namespace optimize -} // namespace tflite +} // namespace lite +} // namespace mlir diff --git a/tensorflow/lite/tools/optimize/test_util.h b/tensorflow/compiler/mlir/lite/quantization/lite/test_util.h similarity index 93% rename from tensorflow/lite/tools/optimize/test_util.h rename to tensorflow/compiler/mlir/lite/quantization/lite/test_util.h index 11e7ef230910f2..b4e317c131888e 100644 --- a/tensorflow/lite/tools/optimize/test_util.h +++ b/tensorflow/compiler/mlir/lite/quantization/lite/test_util.h @@ -12,13 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef TENSORFLOW_LITE_TOOLS_OPTIMIZE_TEST_UTIL_H_ -#define TENSORFLOW_LITE_TOOLS_OPTIMIZE_TEST_UTIL_H_ +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_LITE_TEST_UTIL_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_LITE_TEST_UTIL_H_ #include "tensorflow/lite/core/api/error_reporter.h" -namespace tflite { -namespace optimize { +namespace mlir { +namespace lite { namespace internal { // Test model with a single convolution. // Floating point weights of the model are all integers and lie in @@ -132,12 +132,12 @@ extern const char* kQatModelWithFc; extern const char* kModelWithResourceVarsCalibrated; // An error reporter that fails on testing. -class FailOnErrorReporter : public ErrorReporter { +class FailOnErrorReporter : public tflite::ErrorReporter { public: int Report(const char* format, va_list args) override; }; } // namespace internal -} // namespace optimize -} // namespace tflite +} // namespace lite +} // namespace mlir -#endif // TENSORFLOW_LITE_TOOLS_OPTIMIZE_TEST_UTIL_H_ +#endif // TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_LITE_TEST_UTIL_H_ diff --git a/tensorflow/lite/tools/optimize/testdata/README.md b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/README.md similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/README.md rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/README.md diff --git a/tensorflow/lite/tools/optimize/testdata/add_with_const_input.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/add_with_const_input.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/add_with_const_input.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/add_with_const_input.bin diff --git a/tensorflow/lite/tools/optimize/testdata/argmax.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/argmax.bin similarity index 100% rename from 
tensorflow/lite/tools/optimize/testdata/argmax.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/argmax.bin diff --git a/tensorflow/lite/tools/optimize/testdata/broadcast_to.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/broadcast_to.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/broadcast_to.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/broadcast_to.bin diff --git a/tensorflow/lite/tools/optimize/testdata/concat.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/concat.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/concat.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/concat.bin diff --git a/tensorflow/lite/tools/optimize/testdata/custom_op.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/custom_op.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/custom_op.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/custom_op.bin diff --git a/tensorflow/lite/tools/optimize/testdata/fc.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/fc.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/fc.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/fc.bin diff --git a/tensorflow/lite/tools/optimize/testdata/fc_qat.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/fc_qat.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/fc_qat.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/fc_qat.bin diff --git a/tensorflow/lite/tools/optimize/testdata/gather_nd.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/gather_nd.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/gather_nd.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/gather_nd.bin diff --git 
a/tensorflow/lite/tools/optimize/testdata/lstm_calibrated.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/lstm_calibrated.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/lstm_calibrated.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/lstm_calibrated.bin diff --git a/tensorflow/lite/tools/optimize/testdata/lstm_calibrated2.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/lstm_calibrated2.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/lstm_calibrated2.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/lstm_calibrated2.bin diff --git a/tensorflow/lite/tools/optimize/testdata/lstm_quantized.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/lstm_quantized.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/lstm_quantized.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/lstm_quantized.bin diff --git a/tensorflow/lite/tools/optimize/testdata/lstm_quantized2.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/lstm_quantized2.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/lstm_quantized2.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/lstm_quantized2.bin diff --git a/tensorflow/lite/tools/optimize/testdata/maximum.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/maximum.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/maximum.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/maximum.bin diff --git a/tensorflow/lite/tools/optimize/testdata/minimum.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/minimum.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/minimum.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/minimum.bin diff --git a/tensorflow/lite/tools/optimize/testdata/mixed.bin 
b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/mixed.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/mixed.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/mixed.bin diff --git a/tensorflow/lite/tools/optimize/testdata/mixed16x8.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/mixed16x8.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/mixed16x8.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/mixed16x8.bin diff --git a/tensorflow/lite/tools/optimize/testdata/multi_input_add_reshape.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/multi_input_add_reshape.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/multi_input_add_reshape.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/multi_input_add_reshape.bin diff --git a/tensorflow/lite/tools/optimize/testdata/pack.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/pack.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/pack.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/pack.bin diff --git a/tensorflow/lite/tools/optimize/testdata/quantized_with_gather.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/quantized_with_gather.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/quantized_with_gather.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/quantized_with_gather.bin diff --git a/tensorflow/lite/tools/optimize/testdata/resource_vars_calibrated.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/resource_vars_calibrated.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/resource_vars_calibrated.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/resource_vars_calibrated.bin diff --git 
a/tensorflow/lite/tools/optimize/testdata/single_avg_pool_min_minus_5_max_plus_5.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/single_avg_pool_min_minus_5_max_plus_5.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/single_avg_pool_min_minus_5_max_plus_5.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/single_avg_pool_min_minus_5_max_plus_5.bin diff --git a/tensorflow/lite/tools/optimize/testdata/single_conv_no_bias.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/single_conv_no_bias.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/single_conv_no_bias.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/single_conv_no_bias.bin diff --git a/tensorflow/lite/tools/optimize/testdata/single_conv_weights_min_0_max_plus_10.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/single_conv_weights_min_0_max_plus_10.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/single_conv_weights_min_0_max_plus_10.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/single_conv_weights_min_0_max_plus_10.bin diff --git a/tensorflow/lite/tools/optimize/testdata/single_conv_weights_min_minus_127_max_plus_127.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/single_conv_weights_min_minus_127_max_plus_127.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/single_conv_weights_min_minus_127_max_plus_127.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/single_conv_weights_min_minus_127_max_plus_127.bin diff --git a/tensorflow/lite/tools/optimize/testdata/single_softmax_min_minus_5_max_plus_5.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/single_softmax_min_minus_5_max_plus_5.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/single_softmax_min_minus_5_max_plus_5.bin rename to 
tensorflow/compiler/mlir/lite/quantization/lite/testdata/single_softmax_min_minus_5_max_plus_5.bin diff --git a/tensorflow/lite/tools/optimize/testdata/split.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/split.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/split.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/split.bin diff --git a/tensorflow/lite/tools/optimize/testdata/svdf_calibrated.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/svdf_calibrated.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/svdf_calibrated.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/svdf_calibrated.bin diff --git a/tensorflow/lite/tools/optimize/testdata/svdf_quantized.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/svdf_quantized.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/svdf_quantized.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/svdf_quantized.bin diff --git a/tensorflow/lite/tools/optimize/testdata/transpose.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/transpose.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/transpose.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/transpose.bin diff --git a/tensorflow/lite/tools/optimize/testdata/unidirectional_sequence_lstm_calibrated.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/unidirectional_sequence_lstm_calibrated.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/unidirectional_sequence_lstm_calibrated.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/unidirectional_sequence_lstm_calibrated.bin diff --git a/tensorflow/lite/tools/optimize/testdata/unidirectional_sequence_lstm_quantized.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/unidirectional_sequence_lstm_quantized.bin 
similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/unidirectional_sequence_lstm_quantized.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/unidirectional_sequence_lstm_quantized.bin diff --git a/tensorflow/lite/tools/optimize/testdata/unpack.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/unpack.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/unpack.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/unpack.bin diff --git a/tensorflow/lite/tools/optimize/testdata/weight_shared_between_convs.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/weight_shared_between_convs.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/weight_shared_between_convs.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/weight_shared_between_convs.bin diff --git a/tensorflow/lite/tools/optimize/testdata/where.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/where.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/where.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/where.bin diff --git a/tensorflow/lite/tools/optimize/BUILD b/tensorflow/lite/tools/optimize/BUILD index 711f97bdfddd16..a05a5cbdb10710 100644 --- a/tensorflow/lite/tools/optimize/BUILD +++ b/tensorflow/lite/tools/optimize/BUILD @@ -14,10 +14,6 @@ package( licenses = ["notice"], ) -exports_files(glob([ - "testdata/*.bin", -])) - cc_library( name = "reduced_precision_support", srcs = [], @@ -39,7 +35,6 @@ tf_cc_test( ], deps = [ ":reduced_precision_support", - ":test_util", "//tensorflow/core/platform:platform_port", "//tensorflow/lite:framework", "//tensorflow/lite/core:framework", @@ -223,7 +218,6 @@ tf_cc_test( ], deps = [ ":model_utils", - ":test_util", "//tensorflow/lite:framework", "//tensorflow/lite/core:framework", "//tensorflow/lite/schema:schema_fbs", @@ -250,10 +244,10 @@ tf_cc_test( name = 
"quantization_utils_test", srcs = ["quantization_utils_test.cc"], args = [ - "--test_model_file=$(location //tensorflow/lite/tools/optimize:testdata/single_conv_weights_min_0_max_plus_10.bin)", + "--test_model_file=$(location //tensorflow/compiler/mlir/lite/quantization/lite:testdata/single_conv_weights_min_0_max_plus_10.bin)", ], data = [ - ":testdata/single_conv_weights_min_0_max_plus_10.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/single_conv_weights_min_0_max_plus_10.bin", ], tags = [ "tflite_not_portable_android", @@ -261,7 +255,7 @@ tf_cc_test( ], deps = [ ":quantization_utils", - ":test_util", + "//tensorflow/compiler/mlir/lite/quantization/lite:test_util", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/lite:framework", @@ -316,13 +310,13 @@ tf_cc_test( name = "quantize_weights_test", srcs = ["quantize_weights_test.cc"], args = [ - "--test_model_file=$(location //tensorflow/lite/tools/optimize:testdata/single_conv_weights_min_0_max_plus_10.bin)", + "--test_model_file=$(location //tensorflow/compiler/mlir/lite/quantization/lite:testdata/single_conv_weights_min_0_max_plus_10.bin)", ], data = [ - "//tensorflow/lite/tools/optimize:testdata/custom_op.bin", - "//tensorflow/lite/tools/optimize:testdata/quantized_with_gather.bin", - "//tensorflow/lite/tools/optimize:testdata/single_conv_weights_min_0_max_plus_10.bin", - "//tensorflow/lite/tools/optimize:testdata/weight_shared_between_convs.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/custom_op.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/quantized_with_gather.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/single_conv_weights_min_0_max_plus_10.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/weight_shared_between_convs.bin", ], tags = [ "tflite_not_portable_android", @@ -330,7 +324,7 @@ tf_cc_test( ], deps = [ ":quantize_weights", - ":test_util", + 
"//tensorflow/compiler/mlir/lite/quantization/lite:test_util", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/lite:framework", @@ -342,19 +336,6 @@ tf_cc_test( ], ) -cc_library( - name = "test_util", - testonly = 1, - srcs = ["test_util.cc"], - hdrs = ["test_util.h"], - deps = [ - "//tensorflow/lite:framework", - "//tensorflow/lite/core/api", - "@com_google_googletest//:gtest", - "@flatbuffers", - ], -) - cc_library( name = "quantize_model", srcs = ["quantize_model.cc"], @@ -379,40 +360,40 @@ tf_cc_test( name = "quantize_model_test", srcs = ["quantize_model_test.cc"], args = [ - "--test_model_file=$(location //tensorflow/lite/tools/optimize:testdata/single_conv_weights_min_0_max_plus_10.bin)", + "--test_model_file=$(location //tensorflow/compiler/mlir/lite/quantization/lite:testdata/single_conv_weights_min_0_max_plus_10.bin)", ], data = [ - "//tensorflow/lite/tools/optimize:testdata/add_with_const_input.bin", - "//tensorflow/lite/tools/optimize:testdata/argmax.bin", - "//tensorflow/lite/tools/optimize:testdata/broadcast_to.bin", - "//tensorflow/lite/tools/optimize:testdata/concat.bin", - "//tensorflow/lite/tools/optimize:testdata/fc.bin", - "//tensorflow/lite/tools/optimize:testdata/fc_qat.bin", - "//tensorflow/lite/tools/optimize:testdata/gather_nd.bin", - "//tensorflow/lite/tools/optimize:testdata/lstm_calibrated.bin", - "//tensorflow/lite/tools/optimize:testdata/lstm_calibrated2.bin", - "//tensorflow/lite/tools/optimize:testdata/lstm_quantized.bin", - "//tensorflow/lite/tools/optimize:testdata/lstm_quantized2.bin", - "//tensorflow/lite/tools/optimize:testdata/maximum.bin", - "//tensorflow/lite/tools/optimize:testdata/minimum.bin", - "//tensorflow/lite/tools/optimize:testdata/mixed.bin", - "//tensorflow/lite/tools/optimize:testdata/mixed16x8.bin", - "//tensorflow/lite/tools/optimize:testdata/multi_input_add_reshape.bin", - "//tensorflow/lite/tools/optimize:testdata/pack.bin", - 
"//tensorflow/lite/tools/optimize:testdata/resource_vars_calibrated.bin", - "//tensorflow/lite/tools/optimize:testdata/single_avg_pool_min_minus_5_max_plus_5.bin", - "//tensorflow/lite/tools/optimize:testdata/single_conv_no_bias.bin", - "//tensorflow/lite/tools/optimize:testdata/single_conv_weights_min_0_max_plus_10.bin", - "//tensorflow/lite/tools/optimize:testdata/single_conv_weights_min_minus_127_max_plus_127.bin", - "//tensorflow/lite/tools/optimize:testdata/single_softmax_min_minus_5_max_plus_5.bin", - "//tensorflow/lite/tools/optimize:testdata/split.bin", - "//tensorflow/lite/tools/optimize:testdata/svdf_calibrated.bin", - "//tensorflow/lite/tools/optimize:testdata/svdf_quantized.bin", - "//tensorflow/lite/tools/optimize:testdata/transpose.bin", - "//tensorflow/lite/tools/optimize:testdata/unidirectional_sequence_lstm_calibrated.bin", - "//tensorflow/lite/tools/optimize:testdata/unidirectional_sequence_lstm_quantized.bin", - "//tensorflow/lite/tools/optimize:testdata/unpack.bin", - "//tensorflow/lite/tools/optimize:testdata/where.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/add_with_const_input.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/argmax.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/broadcast_to.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/concat.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/fc.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/fc_qat.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/gather_nd.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/lstm_calibrated.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/lstm_calibrated2.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/lstm_quantized.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/lstm_quantized2.bin", + 
"//tensorflow/compiler/mlir/lite/quantization/lite:testdata/maximum.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/minimum.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/mixed.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/mixed16x8.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/multi_input_add_reshape.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/pack.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/resource_vars_calibrated.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/single_avg_pool_min_minus_5_max_plus_5.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/single_conv_no_bias.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/single_conv_weights_min_0_max_plus_10.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/single_conv_weights_min_minus_127_max_plus_127.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/single_softmax_min_minus_5_max_plus_5.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/split.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/svdf_calibrated.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/svdf_quantized.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/transpose.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/unidirectional_sequence_lstm_calibrated.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/unidirectional_sequence_lstm_quantized.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/unpack.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/where.bin", ], tags = [ "tflite_not_portable_android", @@ -420,7 +401,7 @@ tf_cc_test( ], deps = [ ":quantize_model", - ":test_util", + "//tensorflow/compiler/mlir/lite/quantization/lite:test_util", "//tensorflow/core:framework_internal", 
"//tensorflow/core:lib", "//tensorflow/lite:framework", diff --git a/tensorflow/lite/tools/optimize/model_utils_test.cc b/tensorflow/lite/tools/optimize/model_utils_test.cc index 65e3afe35e2da2..f702e1fa0a0ddd 100644 --- a/tensorflow/lite/tools/optimize/model_utils_test.cc +++ b/tensorflow/lite/tools/optimize/model_utils_test.cc @@ -21,7 +21,6 @@ limitations under the License. #include #include "tensorflow/lite/core/model.h" #include "tensorflow/lite/schema/schema_generated.h" -#include "tensorflow/lite/tools/optimize/test_util.h" namespace tflite { namespace optimize { diff --git a/tensorflow/lite/tools/optimize/quantization_utils_test.cc b/tensorflow/lite/tools/optimize/quantization_utils_test.cc index a09acef6f4aa3c..a0ab9c43eacb75 100644 --- a/tensorflow/lite/tools/optimize/quantization_utils_test.cc +++ b/tensorflow/lite/tools/optimize/quantization_utils_test.cc @@ -22,6 +22,7 @@ limitations under the License. #include #include #include "absl/memory/memory.h" +#include "tensorflow/compiler/mlir/lite/quantization/lite/test_util.h" #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/util/command_line_flags.h" @@ -29,7 +30,6 @@ limitations under the License. 
#include "tensorflow/lite/schema/schema_generated.h" #include "tensorflow/lite/schema/schema_utils.h" #include "tensorflow/lite/testing/util.h" -#include "tensorflow/lite/tools/optimize/test_util.h" namespace { tensorflow::string* g_test_model_dir = nullptr; @@ -46,7 +46,7 @@ std::unique_ptr ReadModel(const char* model) { } std::unique_ptr ReadConvModel() { - return ReadModel(internal::kConvModelWith0Plus10Weights); + return ReadModel(mlir::lite::internal::kConvModelWith0Plus10Weights); } using ::testing::ElementsAreArray; diff --git a/tensorflow/lite/tools/optimize/quantize_model_test.cc b/tensorflow/lite/tools/optimize/quantize_model_test.cc index 681507c8e0d31d..a7e9115f8bdaaa 100644 --- a/tensorflow/lite/tools/optimize/quantize_model_test.cc +++ b/tensorflow/lite/tools/optimize/quantize_model_test.cc @@ -25,6 +25,7 @@ limitations under the License. #include #include "flatbuffers/flatbuffers.h" // from @flatbuffers #include "flatbuffers/flexbuffers.h" // from @flatbuffers +#include "tensorflow/compiler/mlir/lite/quantization/lite/test_util.h" #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/util/command_line_flags.h" @@ -32,7 +33,6 @@ limitations under the License. 
#include "tensorflow/lite/schema/schema_generated.h" #include "tensorflow/lite/schema/schema_utils.h" #include "tensorflow/lite/testing/util.h" -#include "tensorflow/lite/tools/optimize/test_util.h" // Note: More rigorous model tests can be found in subgraph_quantizer_test.cc @@ -78,7 +78,8 @@ TensorType GetBiasTensorType(TensorType& activation_type) { class QuantizeModelTest : public testing::Test { protected: QuantizeModelTest() - : QuantizeModelTest(ReadModel(internal::kConvModelWith0Plus10Weights)) {} + : QuantizeModelTest( + ReadModel(mlir::lite::internal::kConvModelWith0Plus10Weights)) {} explicit QuantizeModelTest(std::unique_ptr input_model) { input_model_ = std::move(input_model); @@ -91,7 +92,7 @@ class QuantizeModelTest : public testing::Test { const Model* readonly_model_; tflite::ModelT model_; flatbuffers::FlatBufferBuilder builder_; - internal::FailOnErrorReporter error_reporter_; + ::mlir::lite::internal::FailOnErrorReporter error_reporter_; }; void ExpectSameModels(const ModelT& model, const ModelT& expected_model) { @@ -136,7 +137,8 @@ class QuantizeConvModelTest : public QuantizeModelTest, public testing::WithParamInterface { protected: QuantizeConvModelTest() - : QuantizeModelTest(ReadModel(internal::kConvModelWith0Plus10Weights)), + : QuantizeModelTest( + ReadModel(::mlir::lite::internal::kConvModelWith0Plus10Weights)), tensor_type_(GetParam()), bias_type_(GetBiasTensorType(tensor_type_)) {} TensorType tensor_type_; @@ -405,7 +407,8 @@ TEST_P(QuantizeConvModelTest, Uint8InputAndOutput) { class QuantizeConvNoBiasModelTest : public QuantizeModelTest { protected: QuantizeConvNoBiasModelTest() - : QuantizeModelTest(ReadModel(internal::kConvModelWithNoBias)) {} + : QuantizeModelTest( + ReadModel(::mlir::lite::internal::kConvModelWithNoBias)) {} }; TEST_F(QuantizeConvNoBiasModelTest, QuantizationSucceeds) { @@ -422,7 +425,8 @@ class QuantizeConcatModelTest : public QuantizeModelTest, public testing::WithParamInterface { protected: 
QuantizeConcatModelTest() - : QuantizeModelTest(ReadModel(internal::kFloatConcatMax5Max10Max10)) {} + : QuantizeModelTest( + ReadModel(::mlir::lite::internal::kFloatConcatMax5Max10Max10)) {} void SetUp() override { tensor_type_ = GetParam(); @@ -536,7 +540,7 @@ INSTANTIATE_TEST_SUITE_P(QuantizeConcatModelInst, QuantizeConcatModelTest, class QuantizeSplitModelTest : public QuantizeModelTest { protected: QuantizeSplitModelTest() - : QuantizeModelTest(ReadModel(internal::kModelSplit)) {} + : QuantizeModelTest(ReadModel(::mlir::lite::internal::kModelSplit)) {} }; // There are two outputs for split with different scales, the resulting model @@ -601,8 +605,8 @@ TEST_F(QuantizeSplitModelTest, QuantizeSplit) { class QuantizeConvModel1Test : public QuantizeModelTest { protected: QuantizeConvModel1Test() - : QuantizeModelTest( - ReadModel(internal::kConvModelWithMinus128Plus127Weights)) {} + : QuantizeModelTest(ReadModel( + ::mlir::lite::internal::kConvModelWithMinus128Plus127Weights)) {} }; TEST_F(QuantizeConvModel1Test, VerifyConvQuantizationWithUnitScale) { @@ -703,7 +707,8 @@ class QuantizeConvModel2Test : public QuantizeModelTest, public testing::WithParamInterface { protected: QuantizeConvModel2Test() - : QuantizeModelTest(ReadModel(internal::kConvModelWith0Plus10Weights)), + : QuantizeModelTest( + ReadModel(::mlir::lite::internal::kConvModelWith0Plus10Weights)), tensor_type_(GetParam()), bias_type_(GetBiasTensorType(tensor_type_)) {} @@ -925,8 +930,8 @@ TEST_P(QuantizeConvModel2Test, VerifyConvDisablePerChannelQuantization) { class QuantizeSoftmaxTest : public QuantizeModelTest { protected: QuantizeSoftmaxTest() - : QuantizeModelTest( - ReadModel(internal::kSingleSoftmaxModelMinMinus5MaxPlus5)) {} + : QuantizeModelTest(ReadModel( + ::mlir::lite::internal::kSingleSoftmaxModelMinMinus5MaxPlus5)) {} }; TEST_F(QuantizeSoftmaxTest, VerifySoftmaxQuantization) { @@ -985,8 +990,8 @@ TEST_F(QuantizeSoftmaxTest, VerifySoftmaxQuantization) { class QuantizeAvgPoolTest : public 
QuantizeModelTest { protected: QuantizeAvgPoolTest() - : QuantizeModelTest( - ReadModel(internal::kSingleAvgPoolModelMinMinus5MaxPlus5)) {} + : QuantizeModelTest(ReadModel( + ::mlir::lite::internal::kSingleAvgPoolModelMinMinus5MaxPlus5)) {} }; TEST_F(QuantizeAvgPoolTest, VerifyAvgPoolQuantization) { @@ -1045,7 +1050,8 @@ TEST_F(QuantizeAvgPoolTest, VerifyAvgPoolQuantization) { class QuantizeMultiInputAddWithReshapeTest : public QuantizeModelTest { protected: QuantizeMultiInputAddWithReshapeTest() - : QuantizeModelTest(ReadModel(internal::kMultiInputAddWithReshape)) {} + : QuantizeModelTest( + ReadModel(::mlir::lite::internal::kMultiInputAddWithReshape)) {} }; TEST_F(QuantizeMultiInputAddWithReshapeTest, VerifyReshapeQuantization) { @@ -1155,7 +1161,8 @@ class QuantizeConstInputTest : public QuantizeModelTest, public testing::WithParamInterface { protected: QuantizeConstInputTest() - : QuantizeModelTest(ReadModel(internal::kConstInputAddModel)), + : QuantizeModelTest( + ReadModel(::mlir::lite::internal::kConstInputAddModel)), tensor_type_(GetParam()), bias_type_(GetBiasTensorType(tensor_type_)) {} @@ -1213,7 +1220,8 @@ TEST_P(QuantizeConstInputTest, VerifyConstOpInput) { class QuantizeArgMaxTest : public QuantizeModelTest { protected: QuantizeArgMaxTest() - : QuantizeModelTest(ReadModel(internal::kModelWithArgMaxOp)) {} + : QuantizeModelTest( + ReadModel(::mlir::lite::internal::kModelWithArgMaxOp)) {} }; TEST_F(QuantizeArgMaxTest, VerifyArgMax) { @@ -1254,7 +1262,7 @@ TEST_F(QuantizeArgMaxTest, VerifyArgMax) { class QuantizeLSTMTest : public QuantizeModelTest { protected: QuantizeLSTMTest() - : QuantizeModelTest(ReadModel(internal::kLstmCalibrated)) {} + : QuantizeModelTest(ReadModel(::mlir::lite::internal::kLstmCalibrated)) {} }; TEST_F(QuantizeLSTMTest, VerifyLSTM) { @@ -1265,7 +1273,7 @@ TEST_F(QuantizeLSTMTest, VerifyLSTM) { ASSERT_EQ(kTfLiteOk, status); // Read expected model. 
- auto expected_fb_model = ReadModel(internal::kLstmQuantized); + auto expected_fb_model = ReadModel(::mlir::lite::internal::kLstmQuantized); auto expected_read_only_model = expected_fb_model->GetModel(); ModelT expected_model; expected_read_only_model->UnPackTo(&expected_model); @@ -1276,7 +1284,8 @@ TEST_F(QuantizeLSTMTest, VerifyLSTM) { class QuantizeLSTM2Test : public QuantizeModelTest { protected: QuantizeLSTM2Test() - : QuantizeModelTest(ReadModel(internal::kLstmCalibrated2)) {} + : QuantizeModelTest(ReadModel(::mlir::lite::internal::kLstmCalibrated2)) { + } }; TEST_F(QuantizeLSTM2Test, VerifyLSTM) { @@ -1287,7 +1296,7 @@ TEST_F(QuantizeLSTM2Test, VerifyLSTM) { ASSERT_EQ(kTfLiteOk, status); // Read expected model. - auto expected_fb_model = ReadModel(internal::kLstmQuantized2); + auto expected_fb_model = ReadModel(::mlir::lite::internal::kLstmQuantized2); auto expected_read_only_model = expected_fb_model->GetModel(); ModelT expected_model; expected_read_only_model->UnPackTo(&expected_model); @@ -1298,8 +1307,8 @@ TEST_F(QuantizeLSTM2Test, VerifyLSTM) { class QuantizeUnidirectionalSequenceLSTMTest : public QuantizeModelTest { protected: QuantizeUnidirectionalSequenceLSTMTest() - : QuantizeModelTest( - ReadModel(internal::kUnidirectionalSequenceLstmCalibrated)) {} + : QuantizeModelTest(ReadModel( + ::mlir::lite::internal::kUnidirectionalSequenceLstmCalibrated)) {} }; TEST_F(QuantizeUnidirectionalSequenceLSTMTest, @@ -1312,7 +1321,7 @@ TEST_F(QuantizeUnidirectionalSequenceLSTMTest, // Read expected model. 
auto expected_fb_model = - ReadModel(internal::kUnidirectionalSequenceLstmQuantized); + ReadModel(::mlir::lite::internal::kUnidirectionalSequenceLstmQuantized); auto expected_read_only_model = expected_fb_model->GetModel(); ModelT expected_model; expected_read_only_model->UnPackTo(&expected_model); @@ -1323,7 +1332,7 @@ TEST_F(QuantizeUnidirectionalSequenceLSTMTest, class QuantizeSVDFTest : public QuantizeModelTest { protected: QuantizeSVDFTest() - : QuantizeModelTest(ReadModel(internal::kSvdfCalibrated)) {} + : QuantizeModelTest(ReadModel(::mlir::lite::internal::kSvdfCalibrated)) {} }; TEST_F(QuantizeSVDFTest, VerifySVDF) { @@ -1334,7 +1343,7 @@ TEST_F(QuantizeSVDFTest, VerifySVDF) { ASSERT_EQ(kTfLiteOk, status); // Read expected model. - auto expected_fb_model = ReadModel(internal::kSvdfQuantized); + auto expected_fb_model = ReadModel(::mlir::lite::internal::kSvdfQuantized); auto expected_read_only_model = expected_fb_model->GetModel(); ModelT expected_model; expected_read_only_model->UnPackTo(&expected_model); @@ -1379,7 +1388,8 @@ TEST_F(QuantizeSVDFTest, VerifySVDF) { class QuantizeFCTest : public QuantizeModelTest { protected: - QuantizeFCTest() : QuantizeModelTest(ReadModel(internal::kModelWithFCOp)) {} + QuantizeFCTest() + : QuantizeModelTest(ReadModel(::mlir::lite::internal::kModelWithFCOp)) {} }; TEST_F(QuantizeFCTest, VerifyFC) { @@ -1430,7 +1440,7 @@ class QuantizeCustomOpTest public ::testing::WithParamInterface { protected: QuantizeCustomOpTest() - : QuantizeModelTest(ReadModel(internal::kModelMixed)), + : QuantizeModelTest(ReadModel(::mlir::lite::internal::kModelMixed)), tensor_type_(GetParam()), bias_type_(GetBiasTensorType(tensor_type_)) {} @@ -1471,7 +1481,7 @@ INSTANTIATE_TEST_SUITE_P(QuantizeCustomOpTest, QuantizeCustomOpTest, class QuantizeOp16x8Test : public QuantizeModelTest { protected: QuantizeOp16x8Test() - : QuantizeModelTest(ReadModel(internal::kModelMixed16x8)) {} + : QuantizeModelTest(ReadModel(::mlir::lite::internal::kModelMixed16x8)) 
{} }; TEST_F(QuantizeOp16x8Test, VerifyMixedQuantization16x8) { @@ -1502,7 +1512,8 @@ TEST_F(QuantizeOp16x8Test, VerifyMixedQuantization16x8) { class QuantizePackTest : public QuantizeModelTest { protected: - QuantizePackTest() : QuantizeModelTest(ReadModel(internal::kModelPack)) {} + QuantizePackTest() + : QuantizeModelTest(ReadModel(::mlir::lite::internal::kModelPack)) {} }; TEST_F(QuantizePackTest, VerifyPack) { @@ -1628,14 +1639,16 @@ TEST_P(QuantizeMinimumMaximumTest, VerifyMinimumMaximum) { EXPECT_EQ(subgraph->tensors[5]->name, "output"); } -INSTANTIATE_TEST_SUITE_P(MinimumMaximumTestInst, QuantizeMinimumMaximumTest, - testing::ValuesIn({internal::kModelWithMinimumOp, - internal::kModelWithMaximumOp})); +INSTANTIATE_TEST_SUITE_P( + MinimumMaximumTestInst, QuantizeMinimumMaximumTest, + testing::ValuesIn({::mlir::lite::internal::kModelWithMinimumOp, + ::mlir::lite::internal::kModelWithMaximumOp})); class QuantizeUnpackTest : public QuantizeModelTest { protected: QuantizeUnpackTest() - : QuantizeModelTest(ReadModel(internal::kModelWithUnpack)) {} + : QuantizeModelTest(ReadModel(::mlir::lite::internal::kModelWithUnpack)) { + } }; TEST_F(QuantizeUnpackTest, VerifyUnpack) { auto status = QuantizeModel(&builder_, &model_, &error_reporter_); @@ -1680,7 +1693,8 @@ TEST_F(QuantizeUnpackTest, VerifyUnpack) { class QuantizeTransposeTest : public QuantizeModelTest { protected: QuantizeTransposeTest() - : QuantizeModelTest(ReadModel(internal::kModelWithTranspose)) {} + : QuantizeModelTest( + ReadModel(::mlir::lite::internal::kModelWithTranspose)) {} }; TEST_F(QuantizeTransposeTest, VerifyTranspose) { @@ -1720,7 +1734,8 @@ TEST_F(QuantizeTransposeTest, VerifyTranspose) { class QuantizeQatTest : public QuantizeModelTest { protected: - QuantizeQatTest() : QuantizeModelTest(ReadModel(internal::kQatModelWithFc)) {} + QuantizeQatTest() + : QuantizeModelTest(ReadModel(::mlir::lite::internal::kQatModelWithFc)) {} }; TEST_F(QuantizeQatTest, VerifySingleQuantize) { @@ -1777,7 
+1792,8 @@ class QuantizeBroadcastToModelTest public testing::WithParamInterface { protected: QuantizeBroadcastToModelTest() - : QuantizeModelTest(ReadModel(internal::kModelWithBroadcastToOp)), + : QuantizeModelTest( + ReadModel(::mlir::lite::internal::kModelWithBroadcastToOp)), tensor_type_(GetParam()), bias_type_(GetBiasTensorType(tensor_type_)) {} @@ -1844,7 +1860,8 @@ class QuantizeGatherNDModelTest public testing::WithParamInterface { protected: QuantizeGatherNDModelTest() - : QuantizeModelTest(ReadModel(internal::kModelWithGatherNDOp)), + : QuantizeModelTest( + ReadModel(::mlir::lite::internal::kModelWithGatherNDOp)), tensor_type_(GetParam()), bias_type_(GetBiasTensorType(tensor_type_)) {} @@ -1906,7 +1923,8 @@ TEST_P(QuantizeGatherNDModelTest, QuantizeGatherND) { class QuantizeWhereModelTest : public QuantizeModelTest { protected: QuantizeWhereModelTest() - : QuantizeModelTest(ReadModel(internal::kModelWithWhereOp)) {} + : QuantizeModelTest( + ReadModel(::mlir::lite::internal::kModelWithWhereOp)) {} }; TEST_F(QuantizeWhereModelTest, QuantizeWhere) { @@ -1976,8 +1994,8 @@ class QuantizeResourcesModelTest public testing::WithParamInterface { protected: QuantizeResourcesModelTest() - : QuantizeModelTest( - ReadModel(internal::kModelWithResourceVarsCalibrated)) { + : QuantizeModelTest(ReadModel( + ::mlir::lite::internal::kModelWithResourceVarsCalibrated)) { TestType obj = GetParam(); tensor_type_ = obj.tensor_type; modify_range_ = obj.modify_range; @@ -2119,7 +2137,8 @@ class QuantizeConcatConstModelTest public testing::WithParamInterface { protected: QuantizeConcatConstModelTest() - : QuantizeModelTest(ReadModel(internal::kFloatConcatMax5Max10Max10)) { + : QuantizeModelTest( + ReadModel(::mlir::lite::internal::kFloatConcatMax5Max10Max10)) { // Make one of the values constant. 
MakeInputConstant(&model_); } @@ -2224,7 +2243,8 @@ class BiasInputTest : public QuantizeModelTest, public testing::WithParamInterface { protected: BiasInputTest() - : QuantizeModelTest(ReadModel(internal::kConvModelWith0Plus10Weights)) { + : QuantizeModelTest( + ReadModel(::mlir::lite::internal::kConvModelWith0Plus10Weights)) { BiasTestType obj = GetParam(); tensor_type_ = obj.tensor_type; bias_type_ = obj.bias_type; diff --git a/tensorflow/lite/tools/optimize/quantize_weights_test.cc b/tensorflow/lite/tools/optimize/quantize_weights_test.cc index 0e9c3efc17acd9..b2279ed34908f6 100644 --- a/tensorflow/lite/tools/optimize/quantize_weights_test.cc +++ b/tensorflow/lite/tools/optimize/quantize_weights_test.cc @@ -22,13 +22,13 @@ limitations under the License. #include #include "flatbuffers/flatbuffers.h" // from @flatbuffers #include "flatbuffers/flexbuffers.h" // from @flatbuffers +#include "tensorflow/compiler/mlir/lite/quantization/lite/test_util.h" #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/util/command_line_flags.h" #include "tensorflow/lite/core/model.h" #include "tensorflow/lite/schema/schema_generated.h" #include "tensorflow/lite/schema/schema_utils.h" -#include "tensorflow/lite/tools/optimize/test_util.h" namespace { tensorflow::string* g_test_model_dir = nullptr; @@ -40,25 +40,25 @@ namespace { std::unique_ptr ReadTestModel() { auto model_path = tensorflow::io::JoinPath( - *g_test_model_dir, internal::kConvModelWith0Plus10Weights); + *g_test_model_dir, ::mlir::lite::internal::kConvModelWith0Plus10Weights); return FlatBufferModel::BuildFromFile(model_path.c_str()); } std::unique_ptr ReadSharedWeightsTestModel() { - auto model_path = tensorflow::io::JoinPath(*g_test_model_dir, - internal::kModelWithSharedWeights); + auto model_path = tensorflow::io::JoinPath( + *g_test_model_dir, ::mlir::lite::internal::kModelWithSharedWeights); return FlatBufferModel::BuildFromFile(model_path.c_str()); } 
std::unique_ptr ReadGatherTestModel() { - auto model_path = tensorflow::io::JoinPath(*g_test_model_dir, - internal::kQuantizedWithGather); + auto model_path = tensorflow::io::JoinPath( + *g_test_model_dir, ::mlir::lite::internal::kQuantizedWithGather); return FlatBufferModel::BuildFromFile(model_path.c_str()); } std::unique_ptr ReadCustomOpTestModel() { - auto model_path = - tensorflow::io::JoinPath(*g_test_model_dir, internal::kModelWithCustomOp); + auto model_path = tensorflow::io::JoinPath( + *g_test_model_dir, ::mlir::lite::internal::kModelWithCustomOp); return FlatBufferModel::BuildFromFile(model_path.c_str()); } diff --git a/tensorflow/lite/tools/optimize/reduced_precision_support_test.cc b/tensorflow/lite/tools/optimize/reduced_precision_support_test.cc index 6b5cf538b50c43..19400079b17e96 100644 --- a/tensorflow/lite/tools/optimize/reduced_precision_support_test.cc +++ b/tensorflow/lite/tools/optimize/reduced_precision_support_test.cc @@ -25,7 +25,6 @@ limitations under the License. 
#include "tensorflow/lite/schema/schema_generated.h" #include "tensorflow/lite/schema/schema_utils.h" #include "tensorflow/lite/testing/util.h" -#include "tensorflow/lite/tools/optimize/test_util.h" namespace tflite { namespace optimize { diff --git a/third_party/xla/xla/service/BUILD b/third_party/xla/xla/service/BUILD index 553d29e960458e..a4fc2041629600 100644 --- a/third_party/xla/xla/service/BUILD +++ b/third_party/xla/xla/service/BUILD @@ -3057,11 +3057,14 @@ cc_library( srcs = ["all_reduce_simplifier.cc"], hdrs = ["all_reduce_simplifier.h"], deps = [ + ":collective_ops_utils", + ":hlo_module_config", ":hlo_pass", ":hlo_replication_analysis", "//xla:literal_util", "//xla:shape_util", "//xla/hlo/ir:hlo", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", "@com_google_absl//absl/status:statusor", @@ -3076,6 +3079,7 @@ xla_cc_test( srcs = ["all_reduce_simplifier_test.cc"], deps = [ ":all_reduce_simplifier", + ":hlo_module_config", ":hlo_parser", ":pattern_matcher", ":pattern_matcher_gmock", @@ -7254,6 +7258,7 @@ xla_cc_test( "//xla/tests:xla_internal_test_main", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", diff --git a/third_party/xla/xla/service/algebraic_simplifier.cc b/third_party/xla/xla/service/algebraic_simplifier.cc index dd13ffebd9e852..641fedf0c72405 100644 --- a/third_party/xla/xla/service/algebraic_simplifier.cc +++ b/third_party/xla/xla/service/algebraic_simplifier.cc @@ -8188,6 +8188,33 @@ absl::Status AlgebraicSimplifierVisitor::HandleSelect(HloInstruction* select) { select->mutable_operand(0)->shape(), HloOpcode::kNot, select->mutable_operand(0))); } + // select(compare(a, b, GT/GE), a, b) => or(a, b) + // select(compare(a, b, LT/LE), a, b) => and(a, b) + // 
select(compare(a, b, EQ), a, b) => b + // select(compare(a, b, NE), a, b) => a + HloInstruction *compare, *lhs, *rhs; + if (Match(select, m::Select(m::Op(&compare), m::Op(&lhs), m::Op(&rhs))) && + Match(compare, m::Compare(m::Op().Is(lhs), m::Op().Is(rhs)))) { + auto cmp_dir = compare->comparison_direction(); + if (cmp_dir == ComparisonDirection::kGt || + cmp_dir == ComparisonDirection::kGe) { + return ReplaceWithNewInstruction( + select, HloInstruction::CreateBinary(select->shape(), + HloOpcode::kOr, lhs, rhs)); + } + if (cmp_dir == ComparisonDirection::kLt || + cmp_dir == ComparisonDirection::kLe) { + return ReplaceWithNewInstruction( + select, HloInstruction::CreateBinary(select->shape(), + HloOpcode::kAnd, lhs, rhs)); + } + if (cmp_dir == ComparisonDirection::kEq) { + return ReplaceInstruction(select, rhs); + } + if (cmp_dir == ComparisonDirection::kNe) { + return ReplaceInstruction(select, lhs); + } + } } // select(pred, xs, dynamic_update_slice(xs, x, i)) diff --git a/third_party/xla/xla/service/algebraic_simplifier_test.cc b/third_party/xla/xla/service/algebraic_simplifier_test.cc index 00970a51546b1a..921098aa7565e8 100644 --- a/third_party/xla/xla/service/algebraic_simplifier_test.cc +++ b/third_party/xla/xla/service/algebraic_simplifier_test.cc @@ -736,6 +736,95 @@ TEST_F(AlgebraicSimplifierTest, SelectPredPred2) { GmockMatch(m::Not(m::Parameter(0)))); } +// select(compare(a, b, GT/GE), a, b) => or(a, b), a,b ∈ PRED +TEST_F(AlgebraicSimplifierTest, SelectGtCompare) { + for (const auto cmp_dir : {"GT", "GE"}) { + const auto kModuleStr = absl::StrFormat(R"( + HloModule m + test { + p0 = pred[8]{0} parameter(0) + p1 = pred[8]{0} parameter(1) + compare = pred[8]{0} compare(p0, p1), direction=%s + ROOT select = pred[8]{0} select(compare, p0, p1) + } + )", + cmp_dir); + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).value()); + 
EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::Or(m::Parameter(0), m::Parameter(1)))); + } +} + +// select(compare(a, b, LT/LE), a, b) => and(a, b), a,b ∈ PRED +TEST_F(AlgebraicSimplifierTest, SelectLtCompare) { + for (const auto cmp_dir : {"LT", "LE"}) { + const auto kModuleStr = absl::StrFormat(R"( + HloModule m + test { + p0 = pred[8]{0} parameter(0) + p1 = pred[8]{0} parameter(1) + compare = pred[8]{0} compare(p0, p1), direction=%s + ROOT select = pred[8]{0} select(compare, p0, p1) + } + )", + cmp_dir); + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).value()); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::And(m::Parameter(0), m::Parameter(1)))); + } +} + +// select(compare(a, b, EQ), a, b) => b, a,b ∈ PRED +TEST_F(AlgebraicSimplifierTest, SelectEqCompare) { + const char* kModuleStr = R"( + HloModule m + test { + p0 = pred[8]{0} parameter(0) + p1 = pred[8]{0} parameter(1) + compare = pred[8]{0} compare(p0, p1), direction=EQ + ROOT select = pred[8]{0} select(compare, p0, p1) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).value()); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::Parameter(1))); +} + +// select(compare(a, b, NE), a, b) => a, a,b ∈ PRED +TEST_F(AlgebraicSimplifierTest, SelectNeCompare) { + const char* kModuleStr = R"( + HloModule m + test { + p0 = pred[8]{0} parameter(0) + p1 = pred[8]{0} parameter(1) + compare = pred[8]{0} compare(p0, p1), direction=NE + ROOT select = pred[8]{0} select(compare, p0, p1) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).value()); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::Parameter(0))); +} + +// select(compare(a, 
b, NE), b, a) ≠> a - wrong operands order +TEST_F(AlgebraicSimplifierTest, SelectNeCompare_NegativeTestCase) { + const char* kModuleStr = R"( + HloModule m + test { + p0 = pred[8]{0} parameter(0) + p1 = pred[8]{0} parameter(1) + compare = pred[8]{0} compare(p0, p1), direction=NE + ROOT select = pred[8]{0} select(compare, p1, p0) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_FALSE(AlgebraicSimplifier(default_options_).Run(m.get()).value()); +} + // Test that select(pred, xs, dynamic_update_slice(xs, x, i)) is simplified // to dynamic_update_slice(xs, select(pred, dynamic_slice(xs, i), x), i) TEST_F(AlgebraicSimplifierTest, SelectDUSWithShapedPred) { diff --git a/third_party/xla/xla/service/all_reduce_simplifier.cc b/third_party/xla/xla/service/all_reduce_simplifier.cc index 5837ea49da0aae..0760433bda4489 100644 --- a/third_party/xla/xla/service/all_reduce_simplifier.cc +++ b/third_party/xla/xla/service/all_reduce_simplifier.cc @@ -19,14 +19,18 @@ limitations under the License. #include #include +#include "absl/algorithm/container.h" #include "absl/container/flat_hash_set.h" #include "absl/log/log.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/literal_util.h" +#include "xla/service/collective_ops_utils.h" +#include "xla/service/hlo_module_config.h" #include "xla/service/hlo_replication_analysis.h" #include "xla/shape_util.h" #include "tsl/platform/errors.h" @@ -42,22 +46,33 @@ absl::StatusOr AllReduceSimplifier::Run( HloReplicationAnalysis::Run(module, /*cross_partition_spmd=*/false)); std::vector> all_reduces_to_replace; - // Returns the size of a replica group if all groups have the same size, or -1 - // if they have different sizes. 
- auto get_replica_group_size = - [this](const HloInstruction* all_reduce) -> int64_t { - if (all_reduce->replica_groups().empty()) { - return replica_count_; + // Returns the number of participants in a replica group if all groups have + // the same size, or -1 if they have different sizes. + // Number of participants depends on the mode of the collective operation. + auto get_participant_counts_for_replica_group = + [](const HloInstruction* all_reduce) -> absl::StatusOr { + const HloModuleConfig& config = all_reduce->GetModule()->config(); + TF_ASSIGN_OR_RETURN( + CollectiveOpGroupMode group_mode, + GetCollectiveOpGroupMode(all_reduce->channel_id().has_value(), + Cast(all_reduce) + ->use_global_device_ids())); + + int64_t num_devices = config.num_partitions(); + int64_t num_replicas = config.replica_count(); + TF_ASSIGN_OR_RETURN(std::vector participant_counts, + GetPariticipantCountsForReplicaGroups( + num_replicas, num_devices, + all_reduce->replica_groups(), group_mode)); + if (participant_counts.empty()) { + return -1; } - int64_t replica_group_size = -1; - for (const auto& group : all_reduce->replica_groups()) { - if (replica_group_size == -1) { - replica_group_size = group.replica_ids_size(); - } else if (replica_group_size != group.replica_ids_size()) { - return -1; - } + if (!absl::c_all_of(participant_counts, [&](int64_t participant_count) { + return participant_count == participant_counts[0]; + })) { + return -1; } - return replica_group_size; + return participant_counts[0]; }; bool changed = false; @@ -83,11 +98,24 @@ absl::StatusOr AllReduceSimplifier::Run( // optimize out (being fed within a tuple input). 
continue; } - if (!inst->IsCrossReplicaAllReduce()) { + if (!inst->IsCrossReplicaAllReduce() && !inst->IsCrossModuleAllReduce()) { continue; } - int64_t group_size = get_replica_group_size(inst); - if (group_size == -1) { + TF_ASSIGN_OR_RETURN(int64_t group_size, + get_participant_counts_for_replica_group(inst)); + + // We will not simplify this all reduce if any of the following is true: + // 1. All group do not have the same size. + // + // 2. The AllReduce is not cross replica and the group size is not 1. + // Since the replication analysis performed earlier is only for cross + // replica spmd. + // + // 3. The AllReduce is not cross replica and the module is not using spmd. + if (group_size == -1 || + (!inst->IsCrossReplicaAllReduce() && group_size != 1) || + (!inst->IsCrossReplicaAllReduce() && + !module->config().use_spmd_partitioning())) { continue; } if (replication->HloInstructionIsReplicatedAt(inst->operand(0), {}) || diff --git a/third_party/xla/xla/service/all_reduce_simplifier_test.cc b/third_party/xla/xla/service/all_reduce_simplifier_test.cc index 0843fc6df1a87b..e78881a0c19292 100644 --- a/third_party/xla/xla/service/all_reduce_simplifier_test.cc +++ b/third_party/xla/xla/service/all_reduce_simplifier_test.cc @@ -21,6 +21,7 @@ limitations under the License. 
#include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/service/hlo_module_config.h" #include "xla/service/hlo_parser.h" #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" @@ -191,5 +192,95 @@ test { EXPECT_THAT(module->entry_computation()->root_instruction(), GmockMatch(m::Parameter(0))); } + +TEST_F(AllReduceSimplifierTest, TrivialSubgroupNonCrossReplicaAllReduce) { + const char* kModuleStr = R"( +HloModule m + +sum { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT add.2 = f32[] add(a, b) +} + + +test { + p0 = f32[8,16] parameter(0), parameter_replication={false} + ROOT all-reduce = f32[8,16] all-reduce(p0), + channel_id=1, + use_global_device_ids=true, + replica_groups={{0},{1},{2},{3},{4},{5},{6},{7}}, + to_apply=sum +} +)"; + TF_ASSERT_OK_AND_ASSIGN( + auto module, ParseAndReturnVerifiedModule(kModuleStr, /*replica_count=*/1, + /*num_partitions=*/8)); + module->mutable_config().set_use_spmd_partitioning(true); + AllReduceSimplifier simplifier(/*replica_count=*/1); + EXPECT_TRUE(simplifier.Run(module.get()).value()); + EXPECT_THAT(module->entry_computation()->root_instruction(), + GmockMatch(m::Parameter(0))); +} + +TEST_F(AllReduceSimplifierTest, NonCrossReplicaAllReduceAfterAllReduce) { + const char* kModuleStr = R"( +HloModule m + +sum { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT add.2 = f32[] add(a, b) +} + + +test { + p0 = f32[8,16] parameter(0), parameter_replication={false} + all-reduce = f32[8,16] all-reduce(p0), + channel_id=1, + use_global_device_ids=true, + replica_groups={{0,2},{1,3},{4,6},{5,7}}, + to_apply=sum + ROOT all-reduce.1 = f32[8,16] all-reduce(all-reduce), + channel_id=2, + use_global_device_ids=true, + replica_groups={{0,4},{1,5},{2,6},{3,7}}, + to_apply=sum +} +)"; + TF_ASSERT_OK_AND_ASSIGN( + auto module, ParseAndReturnVerifiedModule(kModuleStr, /*replica_count=*/1, + /*num_partitions=*/8)); + 
module->mutable_config().set_use_spmd_partitioning(true); + AllReduceSimplifier simplifier(/*replica_count=*/1); + EXPECT_FALSE(simplifier.Run(module.get()).value()); +} + +TEST_F(AllReduceSimplifierTest, MPMDNonCrossReplicaAllReduce) { + const char* kModuleStr = R"( +HloModule m + +sum { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT add.2 = f32[] add(a, b) +} + +test { + p0 = f32[8,16] parameter(0), parameter_replication={false} + ROOT all-reduce = f32[8,16] all-reduce(p0), + channel_id=1, + replica_groups={{0},{1}}, + to_apply=sum +} +)"; + TF_ASSERT_OK_AND_ASSIGN( + auto module, ParseAndReturnVerifiedModule(kModuleStr, /*replica_count=*/2, + /*num_partitions=*/1)); + // Mark as MPMD. + module->mutable_config().set_use_spmd_partitioning(false); + AllReduceSimplifier simplifier(/*replica_count=*/2); + EXPECT_FALSE(simplifier.Run(module.get()).value()); +} } // namespace } // namespace xla diff --git a/third_party/xla/xla/service/collective_ops_utils.cc b/third_party/xla/xla/service/collective_ops_utils.cc index 01c4bba5abfefb..5bd802c343f523 100644 --- a/third_party/xla/xla/service/collective_ops_utils.cc +++ b/third_party/xla/xla/service/collective_ops_utils.cc @@ -515,8 +515,12 @@ absl::StatusOr> GetPariticipantCountsForReplicaGroups( switch (group_mode) { case CollectiveOpGroupMode::kCrossReplica: { - participant_counts.resize(participating_replica_groups.size(), - num_partitions); + for (const auto& replica_group : participating_replica_groups) { + for (int partition_id = 0; partition_id < num_partitions; + ++partition_id) { + participant_counts.push_back(replica_group.replica_ids().size()); + } + } return participant_counts; } case CollectiveOpGroupMode::kCrossPartition: { diff --git a/third_party/xla/xla/service/collective_ops_utils_test.cc b/third_party/xla/xla/service/collective_ops_utils_test.cc index 9fbcaf5f1f2ba2..c71776323f869f 100644 --- a/third_party/xla/xla/service/collective_ops_utils_test.cc +++ 
b/third_party/xla/xla/service/collective_ops_utils_test.cc @@ -22,6 +22,8 @@ limitations under the License. #include #include +#include +#include #include "absl/algorithm/container.h" #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_computation.h" @@ -129,6 +131,21 @@ TEST(CollectiveOpsUtilsTest, CollectiveWithChannelId2) { EXPECT_EQ(IsOrHasCollectiveWithChannelId(fusion2.get()), nullptr); } +// Creates a container of ReplicaGroups. +std::vector CreateReplicaGroups( + const std::vector> &replica_groups) { + std::vector result; + result.reserve(replica_groups.size()); + for (const auto &replica_group : replica_groups) { + ReplicaGroup group; + for (auto id : replica_group) { + group.add_replica_ids(id); + } + result.push_back(group); + } + return result; +} + } // namespace // Tests for GetCollectOpGroupMode @@ -190,7 +207,7 @@ namespace GetParticipatingDevicesTest { // expected output corresponding to those values. struct TestCase { xla::Array2D device_assignment; - std::vector> replica_groups; + std::vector> replica_groups; bool has_channel_id; std::optional use_global_device_ids; @@ -455,15 +472,8 @@ TEST_P(GetParticipatingDevicesTest, Test) { } } - std::vector replica_groups; - absl::c_transform(tc.replica_groups, std::back_inserter(replica_groups), - [](const std::vector &ids) { - ReplicaGroup group; - for (int id : ids) { - group.add_replica_ids(id); - } - return group; - }); + std::vector replica_groups = + CreateReplicaGroups(tc.replica_groups); absl::StatusOr group_mode = GetCollectiveOpGroupMode(tc.has_channel_id, tc.use_global_device_ids); @@ -518,4 +528,77 @@ INSTANTIATE_TEST_SUITE_P(GetParticipatingDevices, GetParticipatingDevicesTest, testing::ValuesIn(GetTestCases())); } // namespace GetParticipatingDevicesTest + +namespace GetPariticipantCountsForReplicaGroupsTest { + +struct TestCase { + std::string test_name; + std::vector> replica_groups; + CollectiveOpGroupMode group_mode; + int64_t num_replicas; + int64_t num_partitions; + 
std::vector expected; +}; + +class GetPariticipantCountsForReplicaGroupsTest + : public testing::TestWithParam {}; + +TEST_P(GetPariticipantCountsForReplicaGroupsTest, Test) { + const TestCase &tc = GetParam(); + + std::vector replica_groups = + CreateReplicaGroups(tc.replica_groups); + TF_ASSERT_OK_AND_ASSIGN( + std::vector actual, + GetPariticipantCountsForReplicaGroups(tc.num_replicas, tc.num_partitions, + replica_groups, tc.group_mode)); + EXPECT_THAT(actual, testing::ElementsAreArray(tc.expected)); +} + +std::vector GetTestCases() { + return { + { + "CrossReplicaEmptyGroup", + {}, + CollectiveOpGroupMode::kCrossReplica, + 8, + 1, + {8}, + }, + { + "CrossReplicaWithPartitions", + {{0, 1}, {2, 3}}, + CollectiveOpGroupMode::kCrossReplica, + 4, + 2, + {2, 2, 2, 2}, + }, + { + "CrossReplicaAndPartition", + {{0, 1}, {2, 3}}, + CollectiveOpGroupMode::kCrossReplicaAndPartition, + 4, + 2, + {4, 4}, + }, + { + "FlattenedID", + {{0}, {1}, {2}, {3}, {4}, {5}, {6}, {7}}, + CollectiveOpGroupMode::kFlattenedID, + 4, + 2, + {1, 1, 1, 1, 1, 1, 1, 1}, + }, + }; +} +INSTANTIATE_TEST_SUITE_P( + GetPariticipantCountsForReplicaGroups, + GetPariticipantCountsForReplicaGroupsTest, + testing::ValuesIn(GetTestCases()), + [](const testing::TestParamInfo< + GetPariticipantCountsForReplicaGroupsTest::ParamType> &info) { + return info.param.test_name; + }); + +} // namespace GetPariticipantCountsForReplicaGroupsTest } // namespace xla diff --git a/third_party/xla/xla/service/gpu/gpu_windowed_einsum_handler.cc b/third_party/xla/xla/service/gpu/gpu_windowed_einsum_handler.cc index ce1dfa4f1c7863..8f5e26124f24a4 100644 --- a/third_party/xla/xla/service/gpu/gpu_windowed_einsum_handler.cc +++ b/third_party/xla/xla/service/gpu/gpu_windowed_einsum_handler.cc @@ -378,8 +378,16 @@ absl::StatusOr HandleAgWindowedEinsumLoop(HloComputation* comp, return changed; } +static int64_t GetAgActivationCacheIndex(const HloInstruction* while_loop) { + const HloInstruction* loop_tuple = while_loop->operand(0); 
+ const Shape& tuple_shape = loop_tuple->shape(); + CHECK(tuple_shape.IsTuple()); + return tuple_shape.tuple_shapes_size(); +} + absl::Status ProcessWindowedEinsumLoopForActivationCaching( - GpuWindowedEinsumHandler::WindowedEinsumAgLoops& ag_loop) { + GpuWindowedEinsumHandler::WindowedEinsumAgLoops& ag_loop, + HloInstruction* ag_with_shared_operand) { HloInstruction* loop = ag_loop.loop; // Transform the while body to cache the allgathered result in the // output buffer to be consumed by the dot @@ -392,15 +400,61 @@ absl::Status ProcessWindowedEinsumLoopForActivationCaching( } // Get the output operand of the full buffer. HloInstruction* root = while_body->root_instruction(); + // Change loop body to include the new input and output element. + HloInstruction* input_tuple = while_body->parameter_instruction(0); + const Shape& input_shape = input_tuple->shape(); // The full buffer that we will use to cache the accumulated activation - // is the 4th operand in the output tuple. - int64_t full_cache_buffer_index = 3; + // is the last operand in the output tuple. 
+ int64_t full_cache_buffer_index = GetAgActivationCacheIndex(loop); + std::vector new_input_shapes(input_shape.tuple_shapes().begin(), + input_shape.tuple_shapes().end()); + new_input_shapes.push_back(ag_with_shared_operand->shape()); + // Update body input shape + Shape new_input_shape = ShapeUtil::MakeTupleShape(new_input_shapes); + *input_tuple->mutable_shape() = new_input_shape; HloInstruction* full_buffer_output_gte = - root->mutable_operand(full_cache_buffer_index); - HloInstruction* new_full_buffer_output; + while_body->AddInstruction(HloInstruction::CreateGetTupleElement( + ag_with_shared_operand->shape(), input_tuple, + full_cache_buffer_index)); + + // Update condition input shape + HloComputation* cond_comp = loop->while_condition(); + HloInstruction* cond_input_tuple = cond_comp->parameter_instruction(0); + *cond_input_tuple->mutable_shape() = new_input_shape; + + // Update input to the while instruction in parent computation + HloInstruction* original_while_input = loop->mutable_operand(0); + HloComputation* parent_comp = loop->parent(); + std::vector new_operands( + original_while_input->operands().begin(), + original_while_input->operands().end()); + new_operands.push_back( + parent_comp->AddInstruction(HloInstruction::CreateBroadcast( + ag_with_shared_operand->shape(), + parent_comp->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::Zero(new_input_shapes[0].element_type()))), + {}))); + HloInstruction* new_while_input = + parent_comp->AddInstruction(HloInstruction::CreateTuple(new_operands)); + TF_RETURN_IF_ERROR( + loop->ReplaceOperandWithDifferentShape(0, new_while_input)); + TF_RETURN_IF_ERROR(parent_comp->ReplaceInstructionWithDifferentShape( + original_while_input, new_while_input)); + *loop->mutable_shape() = new_input_shape; + + HloInstruction* new_full_buffer_output = nullptr; // Find the DUS in the loop body and re-use the slice indices // This should just be a constant(0) HloInstruction* dus_boundary_constant; + // The slice 
we need this time is the output of the first + // collective-permute + HloInstruction* first_cp_output; + for (HloInstruction* gte_user : input_gte->users()) { + if (gte_user->opcode() == HloOpcode::kCollectivePermute) { + first_cp_output = gte_user; + break; + } + } for (HloInstruction* inst : while_body->MakeInstructionPostOrder()) { HloInstruction* slice_indices; // If we have a DUS(PARAM,DS) pattern, we need to update the output @@ -434,24 +488,68 @@ absl::Status ProcessWindowedEinsumLoopForActivationCaching( dus_boundary_constant->shape(), slice_indices)); VLOG(5) << "Created slice op for second slice: " << slice_indices->ToString(); - // The slice we need this time is the output of the first - // collective-permute - HloInstruction* cp_output; - for (HloInstruction* gte_user : input_gte->users()) { - if (gte_user->opcode() == HloOpcode::kCollectivePermute) { - cp_output = gte_user; - break; - } - } new_full_buffer_output = while_body->AddInstruction(HloInstruction::CreateDynamicUpdateSlice( full_buffer_output_gte->shape(), full_buffer_output_gte, - cp_output, + first_cp_output, {dus_boundary_constant, slice_indices, dus_boundary_constant})); } + + // If we have a Dot(DS(parameter_index1)), then operands are sharded along + // the contracting dim. Slice indices will be the contracting dim's slices. + HloInstruction* slice_index; + HloInstruction* ds_index_constant; + HloInstruction* remainder; + HloInstruction* ds_param; + // There will be 2 dynamic-slices for unrolled loops, match for each one to + // get the slice index which will be used to write the corresponding + // received shard into cached activation buffer. For unrolled loops, we need + // to write to the final buffer twice per iteration, so we need to match for + // the correct slice index based on each DS. 
+ if (Match(inst, m::Dot(m::Op(), m::DynamicSlice(&ds_param))) && + Match(ds_param->operand(0), m::GetTupleElement(m::Parameter(), 1))) { + for (int64_t ds_op_i = 1; ds_op_i < ds_param->operands().size(); + ds_op_i++) { + if (!Match( + ds_param->mutable_operand(ds_op_i), + m::Reshape(&slice_index, m::DynamicSlice(m::Constant(), + m::Op(&remainder)))) && + !Match(ds_param->mutable_operand(ds_op_i), + m::Constant(&ds_index_constant))) { + return absl::OkStatus(); + } + } + // First DS has slice index calculated based on loop iterator + // Remainder(add(gte, partition_id)) + if (Match(remainder, + m::Remainder(m::Add(m::GetTupleElement(), m::Op()), m::Op()))) { + full_buffer_output_gte = + while_body->AddInstruction(HloInstruction::CreateDynamicUpdateSlice( + full_buffer_output_gte->shape(), full_buffer_output_gte, + input_gte, + {ds_index_constant, ds_index_constant, slice_index})); + } + // Second DS has slice index calculated based on loop iterator+1 hence + // Remainder(add(add(gte, 1), partition_id)) + if (Match(remainder, + m::Remainder( + m::Add(m::Add(m::GetTupleElement(), m::Op()), m::Op()), + m::Op()))) { + new_full_buffer_output = + while_body->AddInstruction(HloInstruction::CreateDynamicUpdateSlice( + full_buffer_output_gte->shape(), full_buffer_output_gte, + first_cp_output, + {ds_index_constant, ds_index_constant, slice_index})); + } + } } - TF_RETURN_IF_ERROR(root->ReplaceOperandWith(full_cache_buffer_index, - new_full_buffer_output)); + std::vector original_operands(root->operands().begin(), + root->operands().end()); + original_operands.push_back(new_full_buffer_output); + HloInstruction* new_output_tuple = while_body->AddInstruction( + HloInstruction::CreateTuple(original_operands)); + TF_RETURN_IF_ERROR( + while_body->ReplaceInstructionWithDifferentShape(root, new_output_tuple)); return absl::OkStatus(); } @@ -620,17 +718,20 @@ class WindowedEinsumVisitor : public DfsHloRewriteVisitor { VLOG(5) << "Found all-gather that shares the same operand with 
a " "windowed einsum loop : " << loop->ToString(); + + if (!ag_loop.consumed) { + TF_RETURN_IF_ERROR(ProcessWindowedEinsumLoopForActivationCaching( + ag_loop, ag_with_shared_operand)); + ag_loop.consumed = true; + } int64_t cache_output_index = dot->operand_index(ag_with_shared_operand); - HloInstruction* new_gte = comp->AddInstruction( - HloInstruction::CreateGetTupleElement(loop, 3)); + HloComputation* comp = dot->parent(); + HloInstruction* new_gte = + comp->AddInstruction(HloInstruction::CreateGetTupleElement( + loop, GetAgActivationCacheIndex(loop) - 1)); TF_RETURN_IF_ERROR( dot->ReplaceOperandWith(cache_output_index, new_gte)); TF_RETURN_IF_ERROR(comp->RemoveInstruction(ag_with_shared_operand)); - if (!ag_loop.consumed) { - TF_RETURN_IF_ERROR( - ProcessWindowedEinsumLoopForActivationCaching(ag_loop)); - ag_loop.consumed = true; - } } } // Rewrites an all-to-all+gemm into multiple independent partial a2a+gemms diff --git a/third_party/xla/xla/service/gpu/gpu_windowed_einsum_handler_test.cc b/third_party/xla/xla/service/gpu/gpu_windowed_einsum_handler_test.cc index 23257e1c71a34b..6f23319980e90c 100644 --- a/third_party/xla/xla/service/gpu/gpu_windowed_einsum_handler_test.cc +++ b/third_party/xla/xla/service/gpu/gpu_windowed_einsum_handler_test.cc @@ -269,23 +269,22 @@ ENTRY main.12_spmd { FindInstructionByName(module->entry_computation(), "dot.7"); // dot.7 should now consume output of the windowed einsum while loop. EXPECT_EQ(inst->operand(0)->opcode(), HloOpcode::kGetTupleElement); - EXPECT_EQ(inst->operand(0)->tuple_index(), 3); + EXPECT_EQ(inst->operand(0)->tuple_index(), 5); EXPECT_EQ(inst->operand(0)->operand(0), ag_loop); // while loop's root should now have a chain of DUS. 
HloInstruction* ag_while_root = ag_loop->while_body()->root_instruction(); EXPECT_THAT(ag_while_root, GmockMatch(m::Tuple( - m::Op(), m::Op(), m::Op(), + m::Op(), m::Op(), m::Op(), m::Op(), m::Op(), m::DynamicUpdateSlice( m::DynamicUpdateSlice( m::GetTupleElement(m::Parameter()) .WithPredicate([](const HloInstruction* instr) { - return instr->tuple_index() == 3; + return instr->tuple_index() == 5; }), m::Op(), m::Op(), m::Op(), m::Op()), - m::Op(), m::Op(), m::Op(), m::Op()), - m::Op()))); + m::Op(), m::Op(), m::Op(), m::Op())))); } TEST_F(GpuWindowedEinsumHanlderTest, A2aGemmHaveStreamIds) { constexpr absl::string_view kHloString = R"( @@ -838,5 +837,82 @@ ENTRY main.9_spmd { )"); } +TEST_F(GpuWindowedEinsumHanlderTest, + AgLoopsMultipleConsumersAreChainedWithShardedContratingDim) { + constexpr absl::string_view kHloString = R"( +HloModule pjit__unnamed_wrapped_function_, entry_computation_layout={(bf16[16,2048,512]{2,1,0}, bf16[4096,6288]{1,0}, bf16[16,2048,6288]{2,1,0})->bf16[4096,6288]{1,0}}, num_partitions=8 + +windowed_dot_general_body_ag { + param.195 = (bf16[16,2048,512]{2,1,0}, bf16[4096,6288]{1,0}, bf16[16,2048,6288]{2,1,0}, bf16[16,2048,6288]{2,1,0}, u32[]) parameter(0) + get-tuple-element.588 = bf16[16,2048,512]{2,1,0} get-tuple-element(param.195), index=0 + collective-permute.194 = bf16[16,2048,512]{2,1,0} collective-permute(get-tuple-element.588), channel_id=446, source_target_pairs={{0,7},{1,0},{2,1},{3,2},{4,3},{5,4},{6,5},{7,6}} + collective-permute.195 = bf16[16,2048,512]{2,1,0} collective-permute(collective-permute.194), channel_id=447, source_target_pairs={{0,7},{1,0},{2,1},{3,2},{4,3},{5,4},{6,5},{7,6}} + get-tuple-element.589 = bf16[4096,6288]{1,0} get-tuple-element(param.195), index=1 + get-tuple-element.590 = bf16[16,2048,6288]{2,1,0} get-tuple-element(param.195), index=2 + constant.11432 = s32[8]{0} constant({0, 512, 1024, 1536, 2048, 2560, 3072, 3584}) + get-tuple-element.592 = u32[] get-tuple-element(param.195), index=4 + partition-id.194 
= u32[] partition-id() + add.4309 = u32[] add(get-tuple-element.592, partition-id.194) + constant.11431 = u32[] constant(8) + remainder.194 = u32[] remainder(add.4309, constant.11431) + dynamic-slice.388 = s32[1]{0} dynamic-slice(constant.11432, remainder.194), dynamic_slice_sizes={1} + reshape.12959 = s32[] reshape(dynamic-slice.388) + constant.11433 = s32[] constant(0) + dynamic-slice.389 = bf16[512,6288]{1,0} dynamic-slice(get-tuple-element.589, reshape.12959, constant.11433), dynamic_slice_sizes={512,6288} + dot.244 = bf16[16,2048,6288]{2,1,0} dot(get-tuple-element.588, dynamic-slice.389), lhs_contracting_dims={2}, rhs_contracting_dims={0} + add.4310 = bf16[16,2048,6288]{2,1,0} add(get-tuple-element.590, dot.244) + constant.11434 = u32[] constant(1) + add.4312 = u32[] add(get-tuple-element.592, constant.11434) + add.4313 = u32[] add(add.4312, partition-id.194) + remainder.195 = u32[] remainder(add.4313, constant.11431) + dynamic-slice.390 = s32[1]{0} dynamic-slice(constant.11432, remainder.195), dynamic_slice_sizes={1} + reshape.12960 = s32[] reshape(dynamic-slice.390) + dynamic-slice.391 = bf16[512,6288]{1,0} dynamic-slice(get-tuple-element.589, reshape.12960, constant.11433), dynamic_slice_sizes={512,6288} + dot.245 = bf16[16,2048,6288]{2,1,0} dot(collective-permute.194, dynamic-slice.391), lhs_contracting_dims={2}, rhs_contracting_dims={0} + add.4314 = bf16[16,2048,6288]{2,1,0} add(add.4310, dot.245) + get-tuple-element.591 = bf16[16,2048,6288]{2,1,0} get-tuple-element(param.195), index=3 + add.4315 = u32[] add(add.4312, constant.11434) + ROOT tuple.98 = (bf16[16,2048,512]{2,1,0}, bf16[4096,6288]{1,0}, bf16[16,2048,6288]{2,1,0}, bf16[16,2048,6288]{2,1,0}, u32[]) tuple(collective-permute.195, get-tuple-element.589, add.4314, get-tuple-element.591, add.4315) +} // windowed_dot_general_body_ag + +windowed_dot_general_cond_ag { + param = (bf16[16,2048,512]{2,1,0}, bf16[4096,6288]{1,0}, bf16[16,2048,6288]{2,1,0}, bf16[16,2048,6288]{2,1,0}, u32[]) parameter(0) + 
get-tuple-element = u32[] get-tuple-element(param), index=4 + constant = u32[] constant(4) + ROOT compare = pred[] compare(get-tuple-element, constant), direction=LT +} + +ENTRY main.12_spmd { + param.4 = bf16[16,2048,512]{2,1,0} parameter(0) + param.5 = bf16[4096,6288]{1,0} parameter(1) + constant.22 = bf16[] constant(0) + broadcast = bf16[16,2048,6288]{2,1,0} broadcast(constant.22), dimensions={} + constant.24 = u32[] constant(0) + tuple.2 = (bf16[16,2048,512]{2,1,0}, bf16[4096,6288]{1,0}, bf16[16,2048,6288]{2,1,0}, bf16[16,2048,6288]{2,1,0}, u32[]) tuple(param.4, param.5, broadcast, broadcast, constant.24) + while = (bf16[16,2048,512]{2,1,0}, bf16[4096,6288]{1,0}, bf16[16,2048,6288]{2,1,0}, bf16[16,2048,6288]{2,1,0}, u32[]) while(tuple.2), condition=windowed_dot_general_cond_ag, body=windowed_dot_general_body_ag + get-tuple-element.13 = bf16[16,2048,6288]{2,1,0} get-tuple-element(while), index=2 + all-gather = bf16[16,2048,4096]{2,1,0} all-gather(param.4), channel_id=1, replica_groups={{0,1,2,3,4,5,6,7}}, dimensions={2}, use_global_device_ids=true + param.6 = bf16[16,2048,6288]{2,1,0} parameter(2) + ROOT dot.7 = bf16[4096,6288]{1,0} dot(all-gather, param.6), lhs_contracting_dims={0,1}, rhs_contracting_dims={0,1} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloString)); + + GpuWindowedEinsumHandler gpu_handler; + bool changed; + TF_ASSERT_OK_AND_ASSIGN(changed, gpu_handler.Run(module.get())); + EXPECT_TRUE(changed); + + HloInstruction* ag_loop = + FindInstructionByName(module->entry_computation(), "while"); + HloInstruction* inst = + FindInstructionByName(module->entry_computation(), "dot.7"); + // dot.7 should now consume output of the windowed einsum while loop. + EXPECT_EQ(inst->operand(0)->opcode(), HloOpcode::kGetTupleElement); + EXPECT_EQ(inst->operand(0)->tuple_index(), 5); + EXPECT_EQ(inst->operand(0)->operand(0), ag_loop); +} } // namespace } // namespace xla::gpu