From 654d96257b68f592bcdaf7ed775c1cbba964abe0 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 19 Jun 2024 02:03:43 -0700 Subject: [PATCH 001/256] Update GraphDef version to 1898. PiperOrigin-RevId: 644672001 --- tensorflow/core/public/version.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h index 5f8268ae46be11..535b7ffad284fc 100644 --- a/tensorflow/core/public/version.h +++ b/tensorflow/core/public/version.h @@ -108,7 +108,7 @@ limitations under the License. #define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0 -#define TF_GRAPH_DEF_VERSION 1897 // Updated: 2024/6/18 +#define TF_GRAPH_DEF_VERSION 1898 // Updated: 2024/6/19 // Checkpoint compatibility versions (the versions field in SavedSliceMeta). // From 51f415c7e3f2a52229828414e22574650d333c6a Mon Sep 17 00:00:00 2001 From: Johannes Reifferscheid Date: Wed, 19 Jun 2024 02:30:32 -0700 Subject: [PATCH 002/256] We have received reports of compiler hangs which we are investigating. 
Reverts 38dd164bc36e7a2506235ea2561d71da4185bdca PiperOrigin-RevId: 644678893 --- third_party/xla/xla/BUILD | 5 -- third_party/xla/xla/debug_options_flags.cc | 4 -- .../tests/element_wise_row_vectorization.hlo | 25 +++++++++ .../service/gpu/tests/gpu_unrolling_test.cc | 52 +++++++++++++++++++ 4 files changed, 77 insertions(+), 9 deletions(-) diff --git a/third_party/xla/xla/BUILD b/third_party/xla/xla/BUILD index c8e53111c0fc86..5bb6f608bf0599 100644 --- a/third_party/xla/xla/BUILD +++ b/third_party/xla/xla/BUILD @@ -4,10 +4,6 @@ load( "tf_proto_library", ) load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") -load( - "@local_tsl//tsl/platform/default:cuda_build_defs.bzl", - "if_cuda_is_configured", -) load("//third_party/compute_library:build_defs.bzl", "if_enable_acl") # Placeholder: load py_proto_library @@ -1103,7 +1099,6 @@ cc_library( ], hdrs = ["debug_options_flags.h"], copts = if_enable_acl(["-DXLA_CPU_USE_ACL=1"]), - local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), visibility = internal_visibility([":friends"]), deps = [ diff --git a/third_party/xla/xla/debug_options_flags.cc b/third_party/xla/xla/debug_options_flags.cc index fa4cb08a82a3f7..979e90d3933d60 100644 --- a/third_party/xla/xla/debug_options_flags.cc +++ b/third_party/xla/xla/debug_options_flags.cc @@ -238,11 +238,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_gpu_nccl_collective_max_nchannels(0); opts.set_xla_gpu_nccl_p2p_max_nchannels(0); -#if GOOGLE_CUDA - opts.set_xla_gpu_mlir_emitter_level(1); -#else opts.set_xla_gpu_mlir_emitter_level(0); -#endif opts.set_xla_gpu_max_mlir_kernels(0); opts.set_xla_gpu_skip_mlir_kernels(0); diff --git a/third_party/xla/xla/service/gpu/tests/element_wise_row_vectorization.hlo b/third_party/xla/xla/service/gpu/tests/element_wise_row_vectorization.hlo index 3e75fceb48f530..b49e155da0a685 100644 --- a/third_party/xla/xla/service/gpu/tests/element_wise_row_vectorization.hlo +++ 
b/third_party/xla/xla/service/gpu/tests/element_wise_row_vectorization.hlo @@ -180,6 +180,31 @@ ENTRY main { // ----- +HloModule MOF, is_scheduled=true + +%fused_computation.4 (param_0: f32[672], param_1: f32[512,14,14,672]) -> (f32[512,14,14,672], f32[512,14,14,672]) { + %param_0 = f32[672]{0} parameter(0) + %broadcast = f32[512,14,14,672]{3,2,1,0} broadcast(%param_0), dimensions={3} + %param_1 = f32[512,14,14,672]{3,2,1,0} parameter(1) + %add = f32[512,14,14,672]{3,2,1,0} add(%broadcast, %param_1) + %neg = f32[512,14,14,672]{3,2,1,0} negate(%add) + ROOT tuple = (f32[512,14,14,672]{3,2,1,0}, f32[512,14,14,672]{3,2,1,0}) tuple(%add, %neg) +} + +ENTRY main { + %param_0 = f32[672]{0} parameter(0) + %param_1 = f32[512,14,14,672]{3,2,1,0} parameter(1) + + ROOT %fusion.4 = (f32[512,14,14,672]{3,2,1,0}, f32[512,14,14,672]) fusion(%param_0, %param_1), kind=kLoop, calls=%fused_computation.4 +} + +// Check that we didn't do anything. The block size didn't change. +// CHECK-LABEL: fusion_4 +// CHECK: .reqntid 128, 1, 1 +// CHECK: ld.global.nc.f + +// ----- + HloModule ScalarBroadcasting, is_scheduled=true %fused_computation.5 (param_0: f32[], param_1: f32[512,14,14,672]) -> f32[512,14,14,672] { diff --git a/third_party/xla/xla/service/gpu/tests/gpu_unrolling_test.cc b/third_party/xla/xla/service/gpu/tests/gpu_unrolling_test.cc index 9a140417cb0efe..8d0363478a75b9 100644 --- a/third_party/xla/xla/service/gpu/tests/gpu_unrolling_test.cc +++ b/third_party/xla/xla/service/gpu/tests/gpu_unrolling_test.cc @@ -136,6 +136,58 @@ TEST_F(GpuUnrollingTest, DisabledUnrollUnfusedCosine) { /*match_optimized_ir=*/true); } +TEST_F(GpuUnrollingTest, DisabledUnrollUnfusedPower) { + HloModuleConfig config; + auto debug_options = HloTestBase::GetDebugOptionsForTest(); + config.set_debug_options(debug_options); + + const char *const kUnfusedAddModule = R"( + HloModule test_module + + ENTRY SineFunc { + p0 = f32[1600000]{0} parameter(0) + ROOT s = f32[1600000]{0} power(p0, p0) + })"; + auto 
hlo_module = + ParseAndReturnVerifiedModule(kUnfusedAddModule, config).value(); + + // There is only 1 load, because we pass the `p0` parameter to the kernel only + // once. + CompileAndVerifyIr(std::move(hlo_module), + R"( +; CHECK: load float +; CHECK-NOT: load float +; CHECK: } + )", + /*match_optimized_ir=*/true); +} + +TEST_F(GpuUnrollingTest, DisabledUnrollUnfusedAtan2) { + HloModuleConfig config; + auto debug_options = HloTestBase::GetDebugOptionsForTest(); + config.set_debug_options(debug_options); + + const char *const kUnfusedAddModule = R"( + HloModule test_module + + ENTRY SineFunc { + p0 = f32[16000000]{0} parameter(0) + ROOT s = f32[16000000]{0} atan2(p0, p0) + })"; + auto hlo_module = + ParseAndReturnVerifiedModule(kUnfusedAddModule, config).value(); + + // There is only 1 load, because we pass the `p0` parameter to the kernel only + // once. + CompileAndVerifyIr(std::move(hlo_module), + R"( +; CHECK: load float +; CHECK-NOT: load float +; CHECK: } + )", + /*match_optimized_ir=*/true); +} + TEST_F(GpuUnrollingTest, UnrollMultiOutputFusion) { HloModuleConfig config; auto debug_options = HloTestBase::GetDebugOptionsForTest(); From ca60e393c97b6c4af847dd32534beb1ac146d258 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 19 Jun 2024 02:53:50 -0700 Subject: [PATCH 003/256] Disable `unary_ops_test` on Arm64 machines This is intended to be temporary to unblock the Linux Arm64 GitHub presubmit. 
PiperOrigin-RevId: 644684336 --- tensorflow/compiler/tests/BUILD | 1 + 1 file changed, 1 insertion(+) diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index dea6810b86986e..a0906bfd7b0a59 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -1720,6 +1720,7 @@ tf_xla_py_strict_test( python_version = "PY3", shard_count = 20, tags = [ + "no_aarch64", # TODO(b/348125886) "no_cuda_asan", # times out "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip "noasan", #times out From a2db2afa0413afbcfdf0b9abfe135eb39852ceeb Mon Sep 17 00:00:00 2001 From: Henning Becker Date: Wed, 19 Jun 2024 03:01:09 -0700 Subject: [PATCH 004/256] Make :redzone_allocator_kernel_cuda a cc_library It used to be a `gpu_kernel_library` which means it gets compiled for device using NVCC in OSS. But the source files don't contain any CUDA code, so there is no need for that. https://github.com/openxla/xla/issues/13460 reports that this behaviour even fails with some versions of NVCC because NVCC can't deal with C++ template magic being pulled in by Abseil.
PiperOrigin-RevId: 644686184 --- third_party/xla/xla/service/gpu/BUILD | 1 - third_party/xla/xla/service/gpu/backend_configs.proto | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index afb2944a8e0b59..bf3e191aec5e35 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -1304,7 +1304,6 @@ xla_test( "//xla/stream_executor:device_description", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_googletest//:gtest_main", "@local_tsl//tsl/lib/core:status_test_util", diff --git a/third_party/xla/xla/service/gpu/backend_configs.proto b/third_party/xla/xla/service/gpu/backend_configs.proto index 93454a67aa0190..5bf6fbc19e61ec 100644 --- a/third_party/xla/xla/service/gpu/backend_configs.proto +++ b/third_party/xla/xla/service/gpu/backend_configs.proto @@ -177,7 +177,7 @@ message FusionBackendConfig { // present, we use the default Triton config. AutotuneResult.TritonGemmKey triton_gemm_config = 2; - // Only valid when kind == "__triton_fusion" for now. Code generation of such + // Only valid when kind == "__triton" for now. Code generation of such // fusions will fail if this field is not set. BlockLevelFusionConfig block_level_fusion_config = 6; From eaef53bb90c18fd6698dc207f47d6ef9bf3310e8 Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Wed, 19 Jun 2024 03:04:00 -0700 Subject: [PATCH 006/256] Automated Code Change PiperOrigin-RevId: 644686640 --- third_party/xla/third_party/tsl/tsl/platform/env.cc | 3 ++- third_party/xla/third_party/tsl/tsl/platform/env.h | 4 ++-- third_party/xla/third_party/tsl/tsl/platform/file_system.cc | 6 +++--- third_party/xla/third_party/tsl/tsl/platform/file_system.h | 4 ++-- .../xla/third_party/tsl/tsl/platform/retrying_file_system.h | 4 ++-- 5 files changed, 11 insertions(+), 10 deletions(-) diff --git a/third_party/xla/third_party/tsl/tsl/platform/env.cc b/third_party/xla/third_party/tsl/tsl/platform/env.cc index 77f48b3372e1eb..789725e8856b94 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/env.cc +++ b/third_party/xla/third_party/tsl/tsl/platform/env.cc @@ -328,7 +328,8 @@ absl::Status Env::HasAtomicMove(const string& path, bool* has_atomic_move) { return fs->HasAtomicMove(path, has_atomic_move); } -Status Env::CanCreateTempFile(const string& fname, bool* can_create_temp_file) { +absl::Status Env::CanCreateTempFile(const string& fname, + bool* can_create_temp_file) { FileSystem* fs; TF_RETURN_IF_ERROR(GetFileSystemForFile(fname, &fs)); return fs->CanCreateTempFile(fname, can_create_temp_file); diff --git a/third_party/xla/third_party/tsl/tsl/platform/env.h b/third_party/xla/third_party/tsl/tsl/platform/env.h index 37abc1dee97d54..0952517f9b7f8c 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/env.h +++ b/third_party/xla/third_party/tsl/tsl/platform/env.h @@ -352,8 +352,8 @@ class Env { /// If this returns false, TensorFlow will write directly to output files /// instead of creating a temporary file and swapping it in. This may mean /// that incomplete writes are visible to consumers. - Status CanCreateTempFile(const std::string& fname, - bool* can_create_temp_file); + absl::Status CanCreateTempFile(const std::string& fname, + bool* can_create_temp_file); /// Stores the size of `fname` in `*file_size`. 
absl::Status GetFileSize(const std::string& fname, uint64* file_size); diff --git a/third_party/xla/third_party/tsl/tsl/platform/file_system.cc b/third_party/xla/third_party/tsl/tsl/platform/file_system.cc index 68d0fcf0499ca5..ee385af7354074 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/file_system.cc +++ b/third_party/xla/third_party/tsl/tsl/platform/file_system.cc @@ -95,10 +95,10 @@ absl::Status FileSystem::HasAtomicMove(const string& path, return absl::OkStatus(); } -Status FileSystem::CanCreateTempFile(const std::string& fname, - bool* can_create_temp_file) { +absl::Status FileSystem::CanCreateTempFile(const std::string& fname, + bool* can_create_temp_file) { *can_create_temp_file = true; - return OkStatus(); + return absl::OkStatus(); } void FileSystem::FlushCaches(TransactionToken* token) {} diff --git a/third_party/xla/third_party/tsl/tsl/platform/file_system.h b/third_party/xla/third_party/tsl/tsl/platform/file_system.h index 4b728a42c4d507..67209ed491055f 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/file_system.h +++ b/third_party/xla/third_party/tsl/tsl/platform/file_system.h @@ -392,8 +392,8 @@ class FileSystem { /// to determine if there needs to be a temp location to safely write objects. /// If the file system cannot create a temp file, it's possibile that /// uncomplete result may appear in the given file. - virtual Status CanCreateTempFile(const std::string& fname, - bool* can_create_temp_file); + virtual absl::Status CanCreateTempFile(const std::string& fname, + bool* can_create_temp_file); /// \brief Flushes any cached filesystem objects from memory. 
virtual void FlushCaches() { FlushCaches(nullptr); } diff --git a/third_party/xla/third_party/tsl/tsl/platform/retrying_file_system.h b/third_party/xla/third_party/tsl/tsl/platform/retrying_file_system.h index 88da8787c3d618..a64ecc20e960ff 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/retrying_file_system.h +++ b/third_party/xla/third_party/tsl/tsl/platform/retrying_file_system.h @@ -151,8 +151,8 @@ class RetryingFileSystem : public FileSystem { return base_file_system_->HasAtomicMove(path, has_atomic_move); } - Status CanCreateTempFile(const std::string& fname, - bool* can_create_temp_file) override { + absl::Status CanCreateTempFile(const std::string& fname, + bool* can_create_temp_file) override { // this method does not need to be retried return base_file_system_->CanCreateTempFile(fname, can_create_temp_file); } From cdafce893a83095e93018257db4ec0326f92bec7 Mon Sep 17 00:00:00 2001 From: George Karpenkov Date: Wed, 19 Jun 2024 03:18:55 -0700 Subject: [PATCH 007/256] [XLA:GPU] [NFC] Remove argument which is never passed PiperOrigin-RevId: 644690105 --- .../xla/xla/service/gpu/runtime/sequential_thunk.cc | 2 +- third_party/xla/xla/service/gpu/runtime/thunk.cc | 7 +------ third_party/xla/xla/service/gpu/runtime/thunk.h | 4 +--- 3 files changed, 3 insertions(+), 10 deletions(-) diff --git a/third_party/xla/xla/service/gpu/runtime/sequential_thunk.cc b/third_party/xla/xla/service/gpu/runtime/sequential_thunk.cc index c58f31b5207a7a..a0dc62fbc7155d 100644 --- a/third_party/xla/xla/service/gpu/runtime/sequential_thunk.cc +++ b/third_party/xla/xla/service/gpu/runtime/sequential_thunk.cc @@ -33,7 +33,7 @@ SequentialThunk::SequentialThunk(ThunkInfo thunk_info, ThunkSequence thunks) std::string SequentialThunk::ToStringExtra(int indent) const { std::string result = "\n"; - absl::StrAppend(&result, thunks().ToString(indent + 1, nullptr)); + absl::StrAppend(&result, thunks().ToString(indent + 1)); return result; } diff --git 
a/third_party/xla/xla/service/gpu/runtime/thunk.cc b/third_party/xla/xla/service/gpu/runtime/thunk.cc index 0c95623db733be..877db05fe04e89 100644 --- a/third_party/xla/xla/service/gpu/runtime/thunk.cc +++ b/third_party/xla/xla/service/gpu/runtime/thunk.cc @@ -303,9 +303,7 @@ std::ostream& operator<<(std::ostream& os, Thunk::Kind kind) { return os << Thunk::KindToString(kind); } -std::string ThunkSequence::ToString( - int indent, - std::function get_thunk_annotation) const { +std::string ThunkSequence::ToString(int indent) const { const std::string indent_str(indent * 2, ' '); if (empty()) return indent_str + "No thunks."; @@ -324,9 +322,6 @@ std::string ThunkSequence::ToString( absl::StrAppend(&result, indent_str, kind_str, std::string(max_thunk_kind_len - kind_str.length(), ' '), "\t"); - if (get_thunk_annotation) { - absl::StrAppend(&result, get_thunk_annotation(thunk.get())); - } absl::StrAppend(&result, thunk->ToStringExtra(indent)); absl::StrAppend(&result, "\n"); } diff --git a/third_party/xla/xla/service/gpu/runtime/thunk.h b/third_party/xla/xla/service/gpu/runtime/thunk.h index 6a271ce8f623ba..78698859bf8553 100644 --- a/third_party/xla/xla/service/gpu/runtime/thunk.h +++ b/third_party/xla/xla/service/gpu/runtime/thunk.h @@ -460,9 +460,7 @@ class Thunk { // A sequence of thunks. class ThunkSequence : public std::vector> { public: - std::string ToString(int indent = 0, - std::function - get_thunk_annotation = nullptr) const; + std::string ToString(int indent = 0) const; }; std::ostream& operator<<(std::ostream& os, Thunk::Kind kind); From 160e76085434b413699e6979c800edaef838b302 Mon Sep 17 00:00:00 2001 From: "Dimitar (Mitko) Asenov" Date: Wed, 19 Jun 2024 03:39:07 -0700 Subject: [PATCH 008/256] [XLA:GPU] Simplify TritonSupport tests by providing a standard ENTRY computation. 
PiperOrigin-RevId: 644694885 --- third_party/xla/xla/service/gpu/BUILD | 3 + .../xla/service/gpu/triton_support_test.cc | 120 ++---------------- .../xla/xla/service/gpu/triton_test_utils.cc | 50 ++++++++ .../xla/xla/service/gpu/triton_test_utils.h | 7 + 4 files changed, 72 insertions(+), 108 deletions(-) diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index bf3e191aec5e35..dba824f313e4d6 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -671,6 +671,7 @@ cc_library( deps = [ ":gpu_device_info_for_tests", ":gpu_float_support", + ":ir_emission_utils", ":ir_emitter_triton", ":matmul_utils", "//xla:shape_util", @@ -684,6 +685,7 @@ cc_library( "//xla/stream_executor:device_description", "//xla/tests:filecheck", "//xla/tests:verified_hlo_module", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -693,6 +695,7 @@ cc_library( "@llvm-project//llvm:Support", "@llvm-project//llvm:ir_headers", "@llvm-project//mlir:IR", + "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:statusor", ], ) diff --git a/third_party/xla/xla/service/gpu/triton_support_test.cc b/third_party/xla/xla/service/gpu/triton_support_test.cc index 6c690bad4c47a9..3cf75a4cbad7e9 100644 --- a/third_party/xla/xla/service/gpu/triton_support_test.cc +++ b/third_party/xla/xla/service/gpu/triton_support_test.cc @@ -72,16 +72,9 @@ using BitcastOrReshapeTest = TritonSupportTestWithParam; TEST_P(BitcastOrReshapeTest, IsTritonSupportedBitcastOrReshape) { auto [data_type, opcode] = GetParam(); const std::string kHloTestTemplate = R"( -triton_computation { +ENTRY triton_computation { parameter_0 = $0[1,16,4]{2,1,0} parameter(0) ROOT bitcast_or_reshape = $0[64]{0} $1(parameter_0) -} - -ENTRY e { - parameter_0 = $0[1,16,4]{2,1,0} parameter(0) - ROOT root_op = $0[64]{0} fusion(parameter_0), - kind=kCustom, 
calls=triton_computation, - backend_config={"fusion_backend_config":{"kind":"__triton"}} })"; TF_ASSERT_OK_AND_ASSIGN( TestedInstruction ti, @@ -120,17 +113,10 @@ TEST_P(UnaryElementwiseTest, IsTritonSupportedUnaryElementwise) { } const std::string kHloTestTemplate = R"( -triton_computation { +ENTRY triton_computation { parameter_0 = $0[33,68]{1,0} parameter(0) unary = $0[33,68]{1,0} $1(parameter_0) ROOT convert = f32[33,68]{1,0} convert(unary) -} - -ENTRY e { - parameter_0 = $0[33,68]{1,0} parameter(0) - ROOT root_op = f32[33,68]{1,0} fusion(parameter_0), - kind=kCustom, calls=triton_computation, - backend_config={"fusion_backend_config":{"kind":"__triton"}} })"; TF_ASSERT_OK_AND_ASSIGN( TestedInstruction ti, @@ -184,18 +170,10 @@ TEST_P(BinaryElementwiseTest, IsTritonSupportedBinaryElementwise) { } const std::string kHloTestTemplate = R"( -triton_computation { +ENTRY triton_computation { parameter_0 = $0[11,63]{1,0} parameter(0) parameter_1 = $0[11,63]{1,0} parameter(1) ROOT binary = $0[11,63]{1,0} $1(parameter_0, parameter_1) -} - -ENTRY e { - parameter_0 = $0[11,63]{1,0} parameter(0) - parameter_1 = $0[11,63]{1,0} parameter(1) - ROOT triton_op = $0[11,63]{1,0} fusion(parameter_0, parameter_1), - kind=kCustom, calls=triton_computation, - backend_config={"fusion_backend_config":{"kind":"__triton"}} })"; TF_ASSERT_OK_AND_ASSIGN( TestedInstruction ti, @@ -251,19 +229,11 @@ TEST_P(CompareTest, IsTritonSupportedCompare) { } const std::string kHloTestTemplate = R"( -triton_computation { +ENTRY triton_computation { parameter_0 = $0[11,63]{1,0} parameter(0) parameter_1 = $0[11,63]{1,0} parameter(1) compare = pred[11,63]{1,0} $1(parameter_0, parameter_1), direction=GE ROOT convert = f32[11,63]{1,0} convert(compare) -} - -ENTRY e { - parameter_0 = $0[11,63]{1,0} parameter(0) - parameter_1 = $0[11,63]{1,0} parameter(1) - ROOT triton_op = f32[11,63]{1,0} fusion(parameter_0, parameter_1), - kind=kCustom, calls=triton_computation, - 
backend_config={"fusion_backend_config":{"kind":"__triton"}} })"; TF_ASSERT_OK_AND_ASSIGN( TestedInstruction ti, @@ -298,21 +268,12 @@ TEST_P(TernaryElementwiseTest, IsTritonSupportedTernaryElementwise) { } const std::string kHloTestTemplate = R"( -triton_computation { +ENTRY triton_computation { parameter_0 = $0[13,63]{1,0} parameter(0) parameter_1 = $0[13,63]{1,0} parameter(1) parameter_2 = pred[13,63]{1,0} parameter(2) ternary = $0[13,63]{1,0} $1(parameter_2, parameter_0, parameter_1) ROOT convert = f32[13,63]{1,0} convert(ternary) -} - -ENTRY e { - parameter_0 = $0[13,63]{1,0} parameter(0) - parameter_1 = $0[13,63]{1,0} parameter(1) - parameter_2 = pred[13,63]{1,0} parameter(2) - ROOT triton_op = f32[13,63]{1,0} fusion(parameter_0, parameter_1, parameter_2), - kind=kCustom, calls=triton_computation, - backend_config={"fusion_backend_config":{"kind":"__triton"}} })"; TF_ASSERT_OK_AND_ASSIGN( TestedInstruction ti, @@ -353,18 +314,10 @@ add { ROOT add = $0[] add(Arg_0, Arg_1) } -triton_computation { +ENTRY triton_computation { parameter_0 = $0[125,127]{1,0} parameter(0) constant_0 = $0[] constant(0) ROOT reduce = $0[125]{0} $1(parameter_0, constant_0), dimensions={1}, to_apply=add -} - -ENTRY main { - parameter_0 = $0[125,127]{1,0} parameter(0) - ROOT triton_op = $0[125]{0} fusion(parameter_0), - kind=kCustom, calls=triton_computation, - backend_config={"fusion_backend_config": - {"kind":"__triton"}} })"; TF_ASSERT_OK_AND_ASSIGN( TestedInstruction ti, @@ -406,19 +359,11 @@ add { ROOT add = f32[] add(Arg_0, Arg_1) } -triton_computation { +ENTRY triton_computation { parameter_0 = f32[125,127]{1,0} parameter(0) constant_0 = bf16[] constant(0) convert_0 = f32[] convert(constant_0) ROOT reduce = f32[125]{0} reduce(parameter_0, convert_0), dimensions={1}, to_apply=add -} - -ENTRY main { - parameter_0 = f32[125,127]{1,0} parameter(0) - ROOT triton_op = f32[125]{0} fusion(parameter_0), kind=kCustom, - calls=triton_computation, - backend_config={"fusion_backend_config": - 
{"kind":"__triton"}} })"; TF_ASSERT_OK_AND_ASSIGN(TestedInstruction ti, ParseTemplateAndGetInstruction( @@ -442,18 +387,10 @@ add { ROOT add = f32[] add(Arg_0, Arg_1) } -triton_computation { +ENTRY triton_computation { parameter_0 = f32[2,125,127]{2,1,0} parameter(0) constant_0 = f32[] constant(0) ROOT reduce = f32[2]{0} reduce(parameter_0, constant_0), dimensions={1,2}, to_apply=add -} - -ENTRY main { - parameter_0 = f32[2,125,127]{2,1,0} parameter(0) - ROOT triton_op = f32[2]{0} fusion(parameter_0), - kind=kCustom, calls=triton_computation, - backend_config={"fusion_backend_config": - {"kind":"__triton"}} })"; TF_ASSERT_OK_AND_ASSIGN(TestedInstruction ti, ParseTemplateAndGetInstruction( @@ -479,18 +416,10 @@ add { ROOT add = f32[] add(Arg_0, Arg_1) } -triton_computation { +ENTRY triton_computation { parameter_0 = f32[125,127]{1,0} parameter(0) constant_0 = f32[] constant(0) ROOT reduce = f32[127]{0} reduce(parameter_0, constant_0), dimensions={0}, to_apply=add -} - -ENTRY main { - parameter_0 = f32[125,127]{1,0} parameter(0) - ROOT triton_op = f32[127]{0} fusion(parameter_0), - kind=kCustom, calls=triton_computation, - backend_config={"fusion_backend_config": - {"kind":"__triton"}} })"; TF_ASSERT_OK_AND_ASSIGN(TestedInstruction ti, ParseTemplateAndGetInstruction( @@ -520,19 +449,11 @@ add { ROOT pair = (f32[], f32[]) tuple(add_0, add_1) } -triton_computation { +ENTRY triton_computation { parameter_0 = f32[125,127] parameter(0) constant_0 = f32[] constant(0) tuple_0 = (f32[125]{0}, f32[125]{0}) reduce(parameter_0, parameter_0, constant_0, constant_0), dimensions={1}, to_apply=add ROOT reduce = f32[125]{0} get-tuple-element(tuple_0), index=0 -} - -ENTRY main { - parameter_0 = f32[125,127]{1,0} parameter(0) - ROOT triton_op = f32[125]{0} fusion(parameter_0), - kind=kCustom, calls=triton_computation, - backend_config={"fusion_backend_config": - {"kind":"__triton"}} })"; TF_ASSERT_OK_AND_ASSIGN(TestedInstruction ti, ParseTemplateAndGetInstruction( @@ -557,19 +478,10 
@@ add { ROOT add = f32[] add(Arg_0, Arg_1) } -triton_computation { +ENTRY triton_computation { parameter_0 = f32[125,127]{1,0} parameter(0) init = f32[] parameter(1) ROOT reduce = f32[125]{0} reduce(parameter_0, init), dimensions={1}, to_apply=add -} - -ENTRY main { - parameter_0 = f32[125,127]{1,0} parameter(0) - parameter_1 = f32[] parameter(1) - ROOT triton_op = f32[125]{0} fusion(parameter_0, parameter_1), - kind=kCustom, calls=triton_computation, - backend_config={"fusion_backend_config": - {"kind":"__triton"}} })"; TF_ASSERT_OK_AND_ASSIGN(TestedInstruction ti, ParseTemplateAndGetInstruction( @@ -599,18 +511,10 @@ custom_call { ROOT custom_call = f32[] custom-call(Arg_0, Arg_1), custom_call_target="foo" } -triton_computation { +ENTRY triton_computation { parameter_0 = f32[125,127]{1,0} parameter(0) constant_0 = f32[] constant(0) ROOT reduce = f32[125]{0} reduce(parameter_0, constant_0), dimensions={1}, to_apply=custom_call -} - -ENTRY main { - parameter_0 = f32[125,127]{1,0} parameter(0) - ROOT triton_op = f32[125]{0} fusion(parameter_0), - kind=kCustom, calls=triton_computation, - backend_config={"fusion_backend_config": - {"kind":"__triton"}} })"; TF_ASSERT_OK_AND_ASSIGN(TestedInstruction ti, ParseTemplateAndGetInstruction( diff --git a/third_party/xla/xla/service/gpu/triton_test_utils.cc b/third_party/xla/xla/service/gpu/triton_test_utils.cc index bb5bbe765f858a..97e5925d8a42f2 100644 --- a/third_party/xla/xla/service/gpu/triton_test_utils.cc +++ b/third_party/xla/xla/service/gpu/triton_test_utils.cc @@ -20,8 +20,10 @@ limitations under the License. #include #include #include +#include #include +#include "absl/log/check.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" @@ -41,6 +43,7 @@ limitations under the License. 
#include "xla/service/float_normalization.h" #include "xla/service/gpu/gpu_device_info_for_tests.h" #include "xla/service/gpu/gpu_float_support.h" +#include "xla/service/gpu/ir_emission_utils.h" #include "xla/service/gpu/ir_emitter_triton.h" #include "xla/service/gpu/model/tiled_hlo_computation.h" #include "xla/service/hlo_pass_pipeline.h" @@ -48,6 +51,7 @@ limitations under the License. #include "xla/stream_executor/device_description.h" #include "xla/tests/filecheck.h" #include "xla/tests/verified_hlo_module.h" +#include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" namespace xla::gpu { @@ -140,6 +144,46 @@ std::string TritonSupportTestParamsToString( absl::StrReplaceAll(HloOpcodeString(opcode), {{"-", "_"}})); } +namespace { + +// This function does nothing if the input module already has an entry +// computation whose root is a fusion. Otherwise, creates a new entry +// computation whose root is a fusion instruction that calls the original entry +// computation. The new fusion instruction uses the generic Triton backend kind. 
+absl::Status ConvertEntryToTritonFusion(HloModule* module) { + if (module->entry_computation()->root_instruction()->opcode() == + HloOpcode::kFusion) { + return absl::OkStatus(); + } + auto builder = HloComputation::Builder("entry"); + std::vector params; + for (auto& param : module->entry_computation()->parameter_instructions()) { + TF_ASSIGN_OR_RETURN( + auto param_clone, + builder.AddParameter(HloInstruction::CreateParameter( + param->parameter_number(), param->shape(), + absl::StrCat("param_", param->parameter_number())))); + params.push_back(param_clone); + } + + auto fusion = builder.AddInstruction(HloInstruction::CreateFusion( + module->entry_computation()->root_instruction()->shape(), + HloInstruction::FusionKind::kCustom, params, + module->entry_computation())); + + gpu::GpuBackendConfig gpu_config; + gpu_config.mutable_fusion_backend_config()->set_kind(kTritonFusionKind); + TF_RETURN_IF_ERROR(fusion->set_backend_config(gpu_config)); + + auto new_entry = + module->AddComputationAndUnifyNamesAndIds(builder.Build(), + /*is_entry=*/false); + module->ReplaceEntryComputation(new_entry); + return absl::OkStatus(); +} + +} // namespace + absl::StatusOr TritonSupportTest::ParseTemplateAndGetInstruction( absl::string_view hlo_template, xla::PrimitiveType data_type, @@ -149,8 +193,14 @@ TritonSupportTest::ParseTemplateAndGetInstruction( HloOpcodeString(opcode)); TF_ASSIGN_OR_RETURN(std::unique_ptr module, ParseAndReturnVerifiedModule(hlo_text)); + TF_RETURN_IF_ERROR(ConvertEntryToTritonFusion(module.get())); const HloComputation* computation = module->GetComputationWithName("triton_computation"); + if (computation == module->entry_computation()) { + return absl::InvalidArgumentError( + "The `triton_computation` and the module's entry computation cannot be " + "the same."); + } const HloFusionInstruction* fusion = DynCast( module->entry_computation()->root_instruction()); if (fusion == nullptr) { diff --git a/third_party/xla/xla/service/gpu/triton_test_utils.h 
b/third_party/xla/xla/service/gpu/triton_test_utils.h index de503bf8bd36f8..c7cb16abdf36ec 100644 --- a/third_party/xla/xla/service/gpu/triton_test_utils.h +++ b/third_party/xla/xla/service/gpu/triton_test_utils.h @@ -129,6 +129,13 @@ class TritonSupportTest : public TritonFilecheckTest { // The provided template must contain a computation called // `triton_computation`. If the template contains parameters $0 and $1, they // will be replaced with the data type and opcode respectively. + // If the template's entry computation does not have a root fusion + // instruction, a new entry computation will be created. The new computation + // will have a root fusion instruction that has the same parameters as the + // `triton_computation` and contains a fusion instruction that calls the + // `triton_computation` with the generic Triton emitter. Tests that need + // the `__triton_gemm` backend kind should provide their own ENTRY + // computation. absl::StatusOr ParseTemplateAndGetInstruction( absl::string_view hlo_template, xla::PrimitiveType data_type, xla::HloOpcode opcode); From c72dbfc17b8aba90da3d26036387016cee23bc2f Mon Sep 17 00:00:00 2001 From: George Karpenkov Date: Wed, 19 Jun 2024 03:40:56 -0700 Subject: [PATCH 009/256] [XLA] Remove proto-based communication for service/client Originally, service/client interface in XLA was envisioned as a compilation service interface, with proto-based communication for RPCs. That compilation service never quite worked in that shape, and in the meantime, main compilation API moved to PjRT. Having lots of protos around complicates the XLA compilation API, and serialization/deserialization to protos also hurts performance. Since all Service usages are realistically local, let's just not use protos at the boundary and pass original datastructures. 
PiperOrigin-RevId: 644695252 --- third_party/xla/xla/client/BUILD | 1 - third_party/xla/xla/client/client.cc | 302 +------------- third_party/xla/xla/client/client.h | 27 +- third_party/xla/xla/client/global_data.cc | 73 ---- third_party/xla/xla/client/global_data.h | 33 +- third_party/xla/xla/client/local_client.cc | 7 +- third_party/xla/xla/client/local_client.h | 2 +- third_party/xla/xla/service/BUILD | 3 + .../xla/xla/service/compile_only_service.h | 25 +- third_party/xla/xla/service/service.cc | 369 +++++++++--------- third_party/xla/xla/service/service.h | 120 ++++-- third_party/xla/xla/tests/BUILD | 1 + third_party/xla/xla/tests/client_test.cc | 4 +- .../xla/xla/tests/gather_operation_test.cc | 3 +- third_party/xla/xla/xla.proto | 183 --------- 15 files changed, 325 insertions(+), 828 deletions(-) delete mode 100644 third_party/xla/xla/client/global_data.cc diff --git a/third_party/xla/xla/client/BUILD b/third_party/xla/xla/client/BUILD index e0c53df856f3b1..400eb877cbbce1 100644 --- a/third_party/xla/xla/client/BUILD +++ b/third_party/xla/xla/client/BUILD @@ -29,7 +29,6 @@ filegroup( cc_library( name = "global_data", - srcs = ["global_data.cc"], hdrs = ["global_data.h"], deps = [ "//xla:types", diff --git a/third_party/xla/xla/client/client.cc b/third_party/xla/xla/client/client.cc index a58737a181fd84..6e89947f237ff7 100644 --- a/third_party/xla/xla/client/client.cc +++ b/third_party/xla/xla/client/client.cc @@ -41,129 +41,28 @@ Client::~Client() = default; absl::StatusOr Client::Transfer(const GlobalData& data, const Shape* shape_with_layout) { - TransferToClientRequest request; - *request.mutable_data() = data.handle(); - if (shape_with_layout != nullptr) { - *request.mutable_shape_with_layout() = shape_with_layout->ToProto(); - } - TransferToClientResponse response; - - VLOG(1) << "making transfer request"; - VLOG(3) << "TransferToClientRequest: {" << request.DebugString() << "}"; - absl::Status s = stub_->TransferToClient(&request, &response); - 
VLOG(1) << "done with request"; - - if (!s.ok()) { - return s; - } - VLOG(3) << "TransferToClientResponse: {" << response.DebugString() << "}"; - - if (!response.has_literal()) { - return FailedPrecondition( - "server provided response without a literal in " - "TransferToClient request"); - } - return Literal::CreateFromProto(*response.mutable_literal()); + return stub_->TransferToClient(data, shape_with_layout); } absl::StatusOr> Client::TransferToServer( const LiteralSlice& literal, const DeviceHandle* device_handle) { - TransferToServerRequest request; - *request.mutable_literal() = literal.ToProto(); - if (device_handle) { - *request.mutable_device_handle() = *device_handle; - } - TransferToServerResponse response; - - VLOG(1) << "making transfer to server request"; - VLOG(3) << "TransferToServerRequest: {" << request.DebugString() << "}"; - absl::Status s = stub_->TransferToServer(&request, &response); - VLOG(1) << "done with request"; - - if (!s.ok()) { - return s; - } - VLOG(3) << "TransferToServerResponse: {" << response.DebugString() << "}"; - - if (!response.has_data()) { - return FailedPrecondition( - "server provided response without a data handle in " - "TransferToServer request"); - } - - return std::make_unique(stub_, response.data()); + return stub_->TransferToServer(literal, device_handle); } absl::Status Client::TransferToInfeed(const LiteralSlice& literal, int64_t replica_id, const DeviceHandle* device_handle) { - TransferToInfeedRequest request; - *request.mutable_literal() = literal.ToProto(); - if (device_handle) { - *request.mutable_device_handle() = *device_handle; - } - request.set_replica_id(replica_id); - TransferToInfeedResponse response; - - VLOG(1) << "making transfer to infeed request"; - VLOG(3) << "TransferToInfeedRequest: {" << request.DebugString() << "}"; - absl::Status s = stub_->TransferToInfeed(&request, &response); - VLOG(1) << "done with request"; - - if (!s.ok()) { - return s; - } - VLOG(3) << "TransferToInfeedResponse: {" 
<< response.DebugString() << "}"; - return absl::OkStatus(); + return stub_->TransferToInfeed(literal, replica_id, device_handle); } absl::StatusOr Client::TransferFromOutfeed( const Shape* shape_with_layout, int64_t replica_id, const DeviceHandle* device_handle) { - TransferFromOutfeedRequest request; - if (device_handle) { - *request.mutable_device_handle() = *device_handle; - } - request.set_replica_id(replica_id); - if (shape_with_layout != nullptr) { - *request.mutable_shape_with_layout() = shape_with_layout->ToProto(); - } - TransferFromOutfeedResponse response; - - VLOG(1) << "making transfer from outfeed request"; - VLOG(3) << "TransferFromOutfeedRequest: {" << request.DebugString() << "}"; - absl::Status s = stub_->TransferFromOutfeed(&request, &response); - VLOG(1) << "done with request"; - - if (!s.ok()) { - return s; - } - VLOG(3) << "TransferFromOutfeedResponse: {" << response.DebugString() << "}"; - - if (!response.has_literal()) { - return FailedPrecondition( - "server provided response without a literal in " - "TransferToClient request"); - } - - return Literal::CreateFromProto(response.literal()); + return stub_->TransferFromOutfeed(shape_with_layout, replica_id, + device_handle); } -absl::Status Client::ResetDevice() { - ResetDeviceRequest request; - ResetDeviceResponse response; - - VLOG(1) << "making reset device request"; - VLOG(3) << "ResetDeviceRequest: {" << request.DebugString() << "}"; - absl::Status s = stub_->ResetDevice(&request, &response); - VLOG(1) << "done with request"; - - if (!s.ok()) { - return s; - } - VLOG(3) << "ResetDeviceResponse: {" << response.DebugString() << "}"; - return absl::OkStatus(); -} +absl::Status Client::ResetDevice() { return stub_->ResetDevice(); } absl::StatusOr Client::ExecuteAndTransfer( const XlaComputation& computation, absl::Span arguments, @@ -185,30 +84,7 @@ absl::StatusOr Client::ExecuteAndTransfer( absl::StatusOr Client::ComputeConstant( const XlaComputation& computation, const Layout* 
output_layout) const { - ComputeConstantGraphRequest request; - *request.mutable_computation() = computation.proto(); - if (output_layout != nullptr) { - *request.mutable_output_layout() = output_layout->ToProto(); - } - - ComputeConstantResponse response; - - VLOG(2) << "making compute-constant-graph request"; - absl::Status s = stub_->ComputeConstantGraph(&request, &response); - VLOG(2) << "done with request"; - - if (!s.ok()) { - return s; - } - - VLOG(3) << "ComputeConstant: {" << response.DebugString() << "}"; - - if (!response.has_literal()) { - return Internal( - "no computed literal in the provided response in ComputeConstantGraph " - "request"); - } - return Literal::CreateFromProto(response.literal()); + return stub_->ComputeConstantGraph(computation, output_layout); } absl::StatusOr Client::LoadSnapshot(const HloSnapshot& module) { @@ -219,61 +95,19 @@ absl::StatusOr Client::LoadSnapshot(const HloSnapshot& module) { absl::StatusOr Client::Compile( const XlaComputation& computation, absl::Span argument_shapes, const ExecutionOptions* execution_options) { - CompileRequest request; - *request.mutable_computation() = computation.proto(); - - if (execution_options == nullptr) { - *request.mutable_execution_options() = CreateDefaultExecutionOptions(); - } else { - *request.mutable_execution_options() = *execution_options; - } - if (request.execution_options().device_handles_size() > 1) { - return InvalidArgument( - "Compiling with multiple device handles is not supported. Use " - "'Execute' instead."); - } - - // The argument shapes affect how the computation is compiled. 
- for (const auto& arg_shape : argument_shapes) { - *request.add_input_shape_with_layout() = arg_shape.ToProto(); + std::optional opts; + if (!execution_options) { + opts = CreateDefaultExecutionOptions(); } - CompileResponse response; - VLOG(1) << "making compile request: " << request.ShortDebugString(); - absl::Status s = stub_->Compile(&request, &response); - VLOG(1) << "done with request"; - - if (!s.ok()) { - return s; - } - TF_RET_CHECK(response.has_handle()); - return response.handle(); + return stub_->Compile(computation, argument_shapes, + execution_options ? *execution_options : *opts); } absl::StatusOr> Client::Execute( const ExecutionHandle& handle, absl::Span arguments, ExecutionProfile* execution_profile) { - ExecuteRequest request; - *request.mutable_handle() = handle; - for (GlobalData* argument : arguments) { - CHECK(argument != nullptr) << "Argument pointers must not be null."; - *request.add_arguments() = argument->handle(); - } - - ExecuteResponse response; - VLOG(1) << "making execute request: " << request.ShortDebugString(); - absl::Status s = stub_->Execute(&request, &response); - VLOG(1) << "done with request"; - - if (!s.ok()) { - return s; - } - - if (execution_profile != nullptr) { - *execution_profile = response.profile(); - } - - return std::make_unique(stub_, response.output()); + return stub_->Execute(handle, arguments, execution_profile); } absl::StatusOr> Client::Execute( @@ -329,39 +163,7 @@ absl::StatusOr> Client::Execute( absl::StatusOr>> Client::ExecuteParallel(absl::Span computations) { - ExecuteGraphParallelRequest request; - - for (const XlaComputationInstance& computation : computations) { - ExecuteGraphRequest single_request; - *single_request.mutable_computation() = computation.computation.proto(); - for (GlobalData* argument : computation.arguments) { - *single_request.add_arguments() = argument->handle(); - } - *single_request.mutable_execution_options() = computation.execution_options; - *request.add_requests() = 
single_request; - } - - ExecuteParallelResponse response; - VLOG(1) << "making execute-graph-parallel request: " - << request.ShortDebugString(); - absl::Status s = stub_->ExecuteGraphParallel(&request, &response); - VLOG(1) << "done with request"; - - if (!s.ok()) { - return s; - } - - std::vector> outputs; - for (size_t i = 0, end = response.responses_size(); i < end; ++i) { - outputs.push_back( - std::make_unique(stub_, response.responses(i).output())); - if (i < computations.size() && - computations[i].execution_profile != nullptr) { - *computations[i].execution_profile = response.responses(i).profile(); - } - } - - return std::move(outputs); + return stub_->ExecuteGraphParallel(computations); } absl::StatusOr> Client::GetDeviceHandles( @@ -369,59 +171,17 @@ absl::StatusOr> Client::GetDeviceHandles( if (device_count < 1) { return InvalidArgument("device_count must be greater than 0"); } - GetDeviceHandlesRequest request; - request.set_device_count(device_count); - GetDeviceHandlesResponse response; - VLOG(1) << "making get device request: " << request.ShortDebugString(); - absl::Status s = stub_->GetDeviceHandles(&request, &response); - VLOG(1) << "done with request"; - - if (!s.ok()) { - return s; - } - - std::vector device_handles; - const auto& response_device_handles = response.device_handles(); - device_handles.reserve(response_device_handles.size()); - for (const DeviceHandle& device_handle : response_device_handles) { - device_handles.push_back(device_handle); - } - - return device_handles; + return stub_->GetDeviceHandles(device_count); } absl::Status Client::Unregister(const GlobalData& data) { - UnregisterRequest request; - *request.add_data() = data.handle(); - UnregisterResponse response; - - VLOG(1) << "making unregister request"; - absl::Status s = stub_->Unregister(&request, &response); - VLOG(1) << "done with request"; - - return s; + return stub_->Unregister(data.handle()); } absl::StatusOr>> Client::DeconstructTuple(const GlobalData& data) { - 
DeconstructTupleRequest request; - *request.mutable_tuple_handle() = data.handle(); - DeconstructTupleResponse response; - - VLOG(1) << "making DestructTuple request"; - absl::Status s = stub_->DeconstructTuple(&request, &response); - VLOG(1) << "done with request"; - - if (!s.ok()) { - return s; - } - - std::vector> handles; - for (auto& handle : response.element_handles()) { - handles.push_back(std::make_unique(stub_, handle)); - } - return std::move(handles); + return stub_->DeconstructTuple(data); } absl::StatusOr> Client::GetComputationShape( @@ -431,36 +191,12 @@ absl::StatusOr> Client::GetComputationShape( } absl::StatusOr Client::GetShape(const GlobalData& data) { - GetShapeRequest request; - *request.mutable_data() = data.handle(); - GetShapeResponse response; - - VLOG(1) << "making get shape request"; - absl::Status s = stub_->GetShape(&request, &response); - VLOG(1) << "done with request"; - - if (!s.ok()) { - return s; - } - - return Shape(response.shape()); + return stub_->GetShape(data); } absl::StatusOr Client::CreateChannelHandleByType( ChannelHandle::ChannelType type) { - CreateChannelHandleRequest request; - request.set_channel_type(type); - CreateChannelHandleResponse response; - - VLOG(1) << "making create channel handle request"; - absl::Status s = stub_->CreateChannelHandle(&request, &response); - VLOG(1) << "done with request"; - - if (!s.ok()) { - return s; - } - - return response.channel(); + return stub_->CreateChannelHandle(type); } absl::StatusOr Client::CreateChannelHandle() { diff --git a/third_party/xla/xla/client/client.h b/third_party/xla/xla/client/client.h index 66d82ec79f744e..1ecfcfe6f358eb 100644 --- a/third_party/xla/xla/client/client.h +++ b/third_party/xla/xla/client/client.h @@ -22,7 +22,6 @@ limitations under the License. 
#include #include "absl/types/span.h" -#include "xla/client/global_data.h" #include "xla/client/xla_computation.h" #include "xla/literal.h" #include "xla/service/hlo.pb.h" @@ -41,6 +40,8 @@ class Client { explicit Client(Service* stub); virtual ~Client(); + using XlaComputationInstance = XlaComputationInstance; + // Compile the computation with the given argument shapes and returns the // handle to the compiled executable. The compiled executable is cached on the // service, and the returned handle can be used for execution without @@ -70,7 +71,9 @@ class Client { // will be filled with profile data from the execution. absl::StatusOr> Execute( const ExecutionHandle& handle, absl::Span arguments, - ExecutionProfile* execution_profile = nullptr); + ExecutionProfile* execution_profile = nullptr + + ); // Executes the computation with the given arguments and returns the global // data that was produced from the execution. @@ -93,26 +96,6 @@ class Client { const ExecutionOptions* execution_options = nullptr, ExecutionProfile* execution_profile = nullptr); - // A struct to represent a computation instance to be executed. - // * If execution_options.device_handles is not empty, the computation is - // executed on the devices associated with the handles by partitioning the - // computation based on the attached sharding attributes. Otherwise, a - // device is chosen by the service. - struct XlaComputationInstance { - const XlaComputation& computation; - std::vector arguments; - ExecutionOptions execution_options; - ExecutionProfile* execution_profile; - - XlaComputationInstance(const XlaComputation& computation, - std::vector arguments, - ExecutionOptions execution_options, - ExecutionProfile* execution_profile) - : computation(computation), - arguments(std::move(arguments)), - execution_options(execution_options), - execution_profile(execution_profile) {} - }; // Executes a list XlaComputationInstances and returns global data produced // from each computation. 
diff --git a/third_party/xla/xla/client/global_data.cc b/third_party/xla/xla/client/global_data.cc deleted file mode 100644 index 66a5c5fee61673..00000000000000 --- a/third_party/xla/xla/client/global_data.cc +++ /dev/null @@ -1,73 +0,0 @@ -/* Copyright 2017 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/client/global_data.h" - -#include -#include -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "xla/types.h" -#include "tsl/platform/logging.h" - -namespace xla { -namespace { - -// Releases a set of global data handles owned by the parent service -// interface. 
-void ReleaseHandles(Service* parent, - const absl::Span handles) { - UnregisterRequest request; - for (auto& handle : handles) { - VLOG(1) << "Requesting to unregister " << handle.ShortDebugString(); - *request.add_data() = handle; - } - UnregisterResponse response; - absl::Status status = parent->Unregister(&request, &response); - VLOG(1) << "Done with request"; - if (!status.ok()) { - LOG(WARNING) << "Failed to unregister handles: " << status - << "; continuing anyway..."; - } -} - -} // namespace - -GlobalData::GlobalData(Service* parent, GlobalDataHandle handle) - : handle_(std::move(handle)), parent_(parent) {} - -GlobalData::~GlobalData() { - if (parent_ != nullptr) { - ReleaseHandles(parent_, {handle_}); - } -} - -/* static */ void GlobalData::Release( - std::vector> instances) { - absl::flat_hash_map> - parent_handles_map; - for (auto& instance : instances) { - if (instance->parent_ != nullptr) { - parent_handles_map[instance->parent_].push_back(instance->Release()); - } - } - for (auto& parent_handles : parent_handles_map) { - ReleaseHandles(parent_handles.first, parent_handles.second); - } -} - -} // namespace xla diff --git a/third_party/xla/xla/client/global_data.h b/third_party/xla/xla/client/global_data.h index 790a92a97c26bc..d47209edee377b 100644 --- a/third_party/xla/xla/client/global_data.h +++ b/third_party/xla/xla/client/global_data.h @@ -26,37 +26,8 @@ limitations under the License. namespace xla { -// A GlobalData object represents a globally-accessible allocation of -// data in the associated XLA service. -class GlobalData { - public: - // Gives ownership of the global data handle to this object. - GlobalData(Service* parent, GlobalDataHandle handle); - - // Unregisters the wrapped handle, which causes the service to - // deallocate the associated data. - ~GlobalData(); - - const GlobalDataHandle& handle() const { return handle_; } - - // Releases a set of GlobalData handles. 
A single RPC will be issued - // per unique Service of the given GlobalData objects. - static void Release(std::vector> instances); - - private: - // Detaches the global data handle from the object, such that the destructor - // will not try to release it. - GlobalDataHandle Release() { - parent_ = nullptr; - return handle_; - } - - GlobalDataHandle handle_; // Handle being wrapped. - Service* parent_; // Service used to unregister handle_. - - GlobalData(const GlobalData&) = delete; - GlobalData& operator=(const GlobalData&) = delete; -}; +// TODO(cheshire): Remove. +// Deprecated target for backwards compatibility. } // namespace xla diff --git a/third_party/xla/xla/client/local_client.cc b/third_party/xla/xla/client/local_client.cc index b45cb62eaad3cb..e00f39143bb6ee 100644 --- a/third_party/xla/xla/client/local_client.cc +++ b/third_party/xla/xla/client/local_client.cc @@ -493,7 +493,7 @@ absl::StatusOr LocalClient::ReplicaNumberToDeviceOrdinal( return local_service_->ReplicaNumberToDeviceOrdinal(replica_number); } -absl::StatusOr LocalClient::TransferToLocalServer( +absl::StatusOr LocalClient::TransferToLocalServer( const ::xla::BorrowingLiteral& literal, int device_ordinal) { const ::xla::Shape& shape = literal.shape(); @@ -506,14 +506,13 @@ absl::StatusOr LocalClient::TransferToLocalServer( stream.get(), literal, shaped_buffer)); std::vector<::xla::ScopedShapedBuffer> replicated_buffer; replicated_buffer.emplace_back(std::move(shaped_buffer)); - ::xla::TransferToServerResponse result; - TF_ASSIGN_OR_RETURN(*result.mutable_data(), + TF_ASSIGN_OR_RETURN(GlobalDataHandle data, local_service_->RegisterReplicatedBuffers( std::move(replicated_buffer), absl::StrCat("TransferToServer literal of shape ", ::xla::ShapeUtil::HumanString(shape)))); - return result; + return data; } } // namespace xla diff --git a/third_party/xla/xla/client/local_client.h b/third_party/xla/xla/client/local_client.h index 0d7ee75b07beab..f26c67ced132c8 100644 --- 
a/third_party/xla/xla/client/local_client.h +++ b/third_party/xla/xla/client/local_client.h @@ -171,7 +171,7 @@ class LocalClient : public Client { se::DeviceMemoryAllocator* allocator = nullptr); // Transfer the BorrowingLiteral to the device with the given ordinal. - absl::StatusOr TransferToLocalServer( + absl::StatusOr TransferToLocalServer( const ::xla::BorrowingLiteral& literal, int device_ordinal); // Copy the data from the device contained in the given ShapedBuffer and diff --git a/third_party/xla/xla/service/BUILD b/third_party/xla/xla/service/BUILD index 99bb153d1d32c1..3d1f6e42012fa7 100644 --- a/third_party/xla/xla/service/BUILD +++ b/third_party/xla/xla/service/BUILD @@ -1238,6 +1238,7 @@ cc_library( "//xla:debug_options_flags", "//xla:executable_run_options", "//xla:execution_options_util", + "//xla:literal", "//xla:shape_layout", "//xla:shape_util", "//xla:status_macros", @@ -1245,11 +1246,13 @@ cc_library( "//xla:util", "//xla:xla_data_proto_cc", "//xla:xla_proto_cc", + "//xla/client:xla_computation", "//xla/hlo/evaluator:hlo_evaluator", "//xla/hlo/ir:hlo", "//xla/hlo/ir:hlo_module_group", "//xla/stream_executor", "//xla/stream_executor:device_memory_allocator", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", diff --git a/third_party/xla/xla/service/compile_only_service.h b/third_party/xla/xla/service/compile_only_service.h index 27f4b4c97626f1..0238a16f282946 100644 --- a/third_party/xla/xla/service/compile_only_service.h +++ b/third_party/xla/xla/service/compile_only_service.h @@ -53,30 +53,23 @@ class CompileOnlyService : public Service { const AotCompilationOptions& options, std::unique_ptr* metadata); - absl::Status GetDeviceHandles(const GetDeviceHandlesRequest* arg, - GetDeviceHandlesResponse* result) override { + absl::StatusOr> GetDeviceHandles( + int64_t device_count) override { return 
Unimplemented("CompileOnlyService does not support devices."); } - absl::Status TransferToServer(const TransferToServerRequest* arg, - TransferToServerResponse* result) override { - return Unimplemented( - "CompileOnlyService does not support device data transfers."); - } - absl::Status TransferToInfeed(const TransferToInfeedRequest* arg, - TransferToInfeedResponse* result) override { + + absl::StatusOr> TransferToServer( + const LiteralSlice& literal_slice, + const DeviceHandle* device_handle) override { return Unimplemented( "CompileOnlyService does not support device data transfers."); } - absl::Status TransferFromOutfeed( - const TransferFromOutfeedRequest* arg, - TransferFromOutfeedResponse* result) override { + + absl::Status TransferToInfeed(const LiteralSlice& literal, int64_t replica_id, + const DeviceHandle* device_handle) override { return Unimplemented( "CompileOnlyService does not support device data transfers."); } - absl::Status ResetDevice(const ResetDeviceRequest* arg, - ResetDeviceResponse* result) override { - return Unimplemented("CompileOnlyService does not support devices."); - } private: explicit CompileOnlyService(const ServiceOptions& options, diff --git a/third_party/xla/xla/service/service.cc b/third_party/xla/xla/service/service.cc index 207eae0ab440cc..4a1f323850d9e7 100644 --- a/third_party/xla/xla/service/service.cc +++ b/third_party/xla/xla/service/service.cc @@ -23,6 +23,7 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "xla/debug_options_flags.h" @@ -33,6 +34,7 @@ limitations under the License. 
#include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_module_group.h" #include "xla/layout_util.h" +#include "xla/literal.h" #include "xla/service/backend.h" #include "xla/service/compiler.h" #include "xla/service/computation_layout.h" @@ -160,36 +162,26 @@ Service::Service(const ServiceOptions& options, } } -absl::Status Service::CreateChannelHandle(const CreateChannelHandleRequest* arg, - CreateChannelHandleResponse* result) { - TF_ASSIGN_OR_RETURN(*result->mutable_channel(), - channel_tracker_.NewChannel(arg->channel_type())); - return absl::OkStatus(); +absl::StatusOr Service::CreateChannelHandle( + ChannelHandle::ChannelType type) { + return channel_tracker_.NewChannel(type); } -absl::Status Service::Unregister(const UnregisterRequest* arg, - UnregisterResponse* result) { - absl::Status status; - for (auto& data : arg->data()) { - absl::Status unregister_status = allocation_tracker_.Unregister(data); - if (!unregister_status.ok() && status.ok()) { - status = unregister_status; - } - } - return status; +absl::Status Service::Unregister(const GlobalDataHandle& data) { + return allocation_tracker_.Unregister(data); } // Deconstructs a previously-allocated global handle. 
-absl::Status Service::DeconstructTuple(const DeconstructTupleRequest* arg, - DeconstructTupleResponse* result) { - TF_ASSIGN_OR_RETURN( - std::vector elements, - allocation_tracker_.DeconstructTuple(arg->tuple_handle())); - - for (auto& element : elements) { - *result->add_element_handles() = element; +absl::StatusOr>> +Service::DeconstructTuple(const GlobalData& data) { + TF_ASSIGN_OR_RETURN(std::vector elements, + allocation_tracker_.DeconstructTuple(data.handle())); + std::vector> out; + out.reserve(elements.size()); + for (GlobalDataHandle& element : elements) { + out.push_back(std::make_unique(this, element)); } - return absl::OkStatus(); + return out; } absl::Status Service::ValidateResultShape(const Shape& client_shape, @@ -207,20 +199,14 @@ absl::Status Service::ValidateResultShape(const Shape& client_shape, absl::StatusOr>> Service::ResolveAndValidateArguments( - absl::Span arguments, + absl::Span arguments, absl::Span stream_executors) const { CHECK_EQ(options_.number_of_replicas(), stream_executors.size()); std::vector> replicated_arguments; replicated_arguments.resize(options_.number_of_replicas()); for (size_t i = 0; i < arguments.size(); ++i) { - auto buffer_status = allocation_tracker_.Resolve(*arguments[i]); - if (!buffer_status.ok()) { - return tsl::errors::CreateWithUpdatedMessage( - buffer_status.status(), - StrCat(buffer_status.status().message(), ", ", - "failed to resolve allocation for parameter ", i)); - } - auto replicated_buffers = buffer_status.value(); + TF_ASSIGN_OR_RETURN(std::vector replicated_buffers, + allocation_tracker_.Resolve(arguments[i]->handle())); CHECK_EQ(options_.number_of_replicas(), replicated_buffers.size()); for (int replica = 0; replica < options_.number_of_replicas(); ++replica) { const ShapedBuffer* shaped_buffer = replicated_buffers[replica]; @@ -515,9 +501,8 @@ absl::StatusOr> Service::GetExecutors( } absl::StatusOr>> -Service::GetArguments( - const ExecutionOptions& execution_options, - absl::Span arguments) 
const { +Service::GetArguments(const ExecutionOptions& execution_options, + absl::Span arguments) const { // Resolve the allocations for the arguments of the computation, and create // a vector of device memory offsets for the arguments from the allocations. // In the case of partitioned computations, assume all arguments go on the @@ -531,8 +516,9 @@ Service::GetArguments( return replicated_arguments; } -absl::Status Service::ExecuteGraphParallel( - const ExecuteGraphParallelRequest* arg, ExecuteParallelResponse* result) { +absl::StatusOr>> +Service::ExecuteGraphParallel( + absl::Span computations) { VLOG(1) << "running execute-graph-parallel request"; std::vector>> all_arguments; @@ -543,10 +529,11 @@ absl::Status Service::ExecuteGraphParallel( std::vector device_handles; int num_requested_devices = - std::accumulate(arg->requests().begin(), arg->requests().end(), 0, - [](int a, const ExecuteGraphRequest& r) -> int { - return a + r.execution_options().device_handles_size(); + std::accumulate(computations.begin(), computations.end(), 0, + [](int a, const XlaComputationInstance& r) -> int { + return a + r.execution_options.device_handles_size(); }); + if (num_requested_devices * options_.number_of_replicas() > execute_backend_->device_count()) { return FailedPrecondition( @@ -554,23 +541,24 @@ absl::Status Service::ExecuteGraphParallel( num_requested_devices); } - for (int64_t i = 0; i < arg->requests_size(); ++i) { + for (int64_t i = 0; i < computations.size(); ++i) { + const XlaComputationInstance& computation = computations[i]; + // Get the stream executor for the i'th computation. This stream executor // is one of the executors to run the replicated computation. 
- const ExecutionOptions& execution_options = - arg->requests(i).execution_options(); - const ExecuteGraphRequest& request = arg->requests(i); - TF_RET_CHECK(request.has_computation()) << "computations may not be empty"; - TF_RET_CHECK(request.computation().has_host_program_shape()) + const ExecutionOptions& execution_options = computation.execution_options; + TF_RET_CHECK(computation.computation.proto().has_host_program_shape()) << "program shape may not be empty"; // Get the executors. - TF_ASSIGN_OR_RETURN(auto executors, GetExecutors(execution_options, - arg->requests_size(), i)); + TF_ASSIGN_OR_RETURN( + std::vector executors, + GetExecutors(execution_options, computations.size(), i)); // Get the replicated arguments. - TF_ASSIGN_OR_RETURN(auto replicated_arguments, - GetArguments(execution_options, request.arguments())); + TF_ASSIGN_OR_RETURN( + std::vector> replicated_arguments, + GetArguments(execution_options, computation.arguments)); for (auto& args : replicated_arguments) { for (auto& arg : args) { @@ -596,8 +584,8 @@ absl::Status Service::ExecuteGraphParallel( TF_ASSIGN_OR_RETURN( std::unique_ptr module_config, CreateModuleConfig( - ProgramShape{request.computation().host_program_shape()}, - replicated_arguments.front(), request.execution_options())); + ProgramShape{computation.computation.proto().host_program_shape()}, + replicated_arguments.front(), computation.execution_options)); VLOG(3) << "ExecuteGraphParallel created HloModuleConfig computation layout: " << module_config->entry_computation_layout().ToString(); @@ -605,10 +593,10 @@ absl::Status Service::ExecuteGraphParallel( // Adds to the vectors to build and execute the computations after the loop. 
all_arguments.push_back(replicated_arguments); all_arguments.insert(all_arguments.end(), executors.size() - 1, {{}}); - module_protos.push_back(&request.computation()); + module_protos.push_back(&computation.computation.proto()); module_configs.push_back(std::move(module_config)); computation_names.insert(computation_names.end(), executors.size(), - request.computation().name()); + computation.computation.name()); all_executors.push_back(executors); device_handles.insert(device_handles.end(), execution_options.device_handles().begin(), @@ -680,6 +668,12 @@ absl::Status Service::ExecuteGraphParallel( } } + for (int64_t i = 0; i < computations.size(); ++i) { + if (computations[i].execution_profile != nullptr) { + *computations[i].execution_profile = profile; + } + } + if (!execution_status.ok()) { // Execution failed so we don't have the results. Dump the HLO snapshot // with just the program arguments. @@ -690,11 +684,11 @@ absl::Status Service::ExecuteGraphParallel( TF_RETURN_IF_ERROR(execution_status); - for (const GlobalDataHandle& output : outputs) { - ExecuteResponse response; - *response.mutable_output() = output; - *response.mutable_profile() = profile; - *result->add_responses() = response; + std::vector> out; + + out.reserve(out.size()); + for (GlobalDataHandle& output : outputs) { + out.push_back(std::make_unique(this, output)); } for (int i = 0, end = executable_ptrs.size(); i < end; i++) { @@ -712,31 +706,32 @@ absl::Status Service::ExecuteGraphParallel( } VLOG(1) << "successfully completed 'execute-graph-parallel' request"; - return absl::OkStatus(); + return out; } -absl::Status Service::GetDeviceHandles(const GetDeviceHandlesRequest* arg, - GetDeviceHandlesResponse* result) { +absl::StatusOr> Service::GetDeviceHandles( + int64_t device_count) { const int64_t available_device_count = execute_backend_->device_count(); const int64_t replica_count = options_.number_of_replicas(); if (replica_count <= 0) { return FailedPrecondition("Replica count must be a 
positive integer"); } - if (available_device_count < arg->device_count() * replica_count) { + if (available_device_count < device_count * replica_count) { return ResourceExhausted( "Requested logical device count (%d) with replica count (%d) exceeds " "the number of available physical devices on the target (%d)", - arg->device_count(), replica_count, available_device_count); + device_count, replica_count, available_device_count); } - for (int64_t i = 0; i < arg->device_count(); ++i) { + std::vector out; + + for (int64_t i = 0; i < device_count; ++i) { DeviceHandle device_handle; device_handle.set_handle(i); - device_handle.set_device_count(arg->device_count()); - *result->add_device_handles() = device_handle; + device_handle.set_device_count(device_count); + out.push_back(device_handle); } - - return absl::OkStatus(); + return out; } absl::StatusOr> Service::BuildExecutable( @@ -795,71 +790,66 @@ absl::StatusOr> Service::BuildExecutable( return executable; } -absl::Status Service::Compile(const CompileRequest* arg, - CompileResponse* result) { +absl::StatusOr Service::Compile( + const XlaComputation& computation, absl::Span argument_shapes, + const ExecutionOptions& execution_options) { VLOG(1) << "running compile request"; - if (!arg->has_computation()) { - return InvalidArgument("computations may not be empty"); - } - if (!arg->computation().has_host_program_shape()) { + + if (!computation.proto().has_host_program_shape()) { return InvalidArgument("program shape may not be empty"); } - if (arg->execution_options().device_handles_size() > 1) { + if (execution_options.device_handles_size() > 1) { return InvalidArgument( "The compile request does not support multiple device handles."); } - std::vector argument_shapes; - argument_shapes.reserve(arg->input_shape_with_layout_size()); std::vector argument_shape_ptrs; - for (const ShapeProto& shape_proto : arg->input_shape_with_layout()) { - argument_shapes.push_back(Shape(shape_proto)); - 
argument_shape_ptrs.push_back(&argument_shapes.back()); + for (const Shape& shape : argument_shapes) { + argument_shape_ptrs.push_back(&shape); } + TF_ASSIGN_OR_RETURN( std::unique_ptr module_config, - CreateModuleConfig(ProgramShape{arg->computation().host_program_shape()}, - argument_shape_ptrs, &arg->execution_options())); + CreateModuleConfig(ProgramShape{computation.proto().host_program_shape()}, + argument_shape_ptrs, &execution_options)); VLOG(3) << "Compile created HloModuleConfig computation layout: " << module_config->entry_computation_layout().ToString(); TF_ASSIGN_OR_RETURN( std::unique_ptr executable, - BuildExecutable(arg->computation(), std::move(module_config), + BuildExecutable(computation.proto(), std::move(module_config), execute_backend_.get(), execute_backend_->default_stream_executor(), {/*device_allocator=*/nullptr})); - *result->mutable_handle() = compilation_cache_.Insert(std::move(executable)); - VLOG(1) << "successfully completed 'compile' request"; - return absl::OkStatus(); + return compilation_cache_.Insert(std::move(executable)); } -absl::Status Service::Execute(const ExecuteRequest* arg, - ExecuteResponse* result) { +absl::StatusOr> Service::Execute( + const ExecutionHandle& handle, absl::Span arguments, + ExecutionProfile* execution_profile) { VLOG(1) << "running execute request"; - if (!arg->has_handle()) { - return InvalidArgument("execution handle should not be empty"); - } - TF_ASSIGN_OR_RETURN(auto executable, - compilation_cache_.LookUp(arg->handle())); - TF_ASSIGN_OR_RETURN(auto replicas, Replicas(*execute_backend_, - SingleComputationDeviceHandle())); + TF_ASSIGN_OR_RETURN(std::shared_ptr executable, + compilation_cache_.LookUp(handle)); + + TF_ASSIGN_OR_RETURN( + std::vector replicas, + Replicas(*execute_backend_, SingleComputationDeviceHandle())); TF_ASSIGN_OR_RETURN( std::vector> replicated_arguments, - ResolveAndValidateArguments(arg->arguments(), replicas)); + ResolveAndValidateArguments(arguments, replicas)); // Check 
that the replicated_arguments has the same shape and layout as the // module config used when creating the executable. const int64_t num_module_args = executable->module_config().entry_computation_layout().parameter_count(); - if (num_module_args != arg->arguments_size()) { + if (num_module_args != arguments.size()) { return InvalidArgument( "The executable expects %lld arguments, but sees %lld.", - num_module_args, arg->arguments_size()); + num_module_args, arguments.size()); } for (int64_t i = 0; i < num_module_args; i++) { const Shape& shape_module = @@ -887,35 +877,32 @@ absl::Status Service::Execute(const ExecuteRequest* arg, } TF_ASSIGN_OR_RETURN( - *result->mutable_output(), - ExecuteAndRegisterResult(executable.get(), replicated_arguments, - execute_backend_.get(), - SingleComputationDeviceHandle(), - "result of " + executable->module().name(), - result->mutable_profile())); + GlobalDataHandle output, + ExecuteAndRegisterResult( + executable.get(), replicated_arguments, execute_backend_.get(), + SingleComputationDeviceHandle(), + "result of " + executable->module().name(), execution_profile)); if (executable->dumping_snapshot()) { - TF_ASSIGN_OR_RETURN( - const ShapedBuffer* result_buffer, - allocation_tracker_.ResolveForReplica(result->output(), 0)); + TF_ASSIGN_OR_RETURN(const ShapedBuffer* result_buffer, + allocation_tracker_.ResolveForReplica(output, 0)); TF_RETURN_IF_ERROR(RecordResult(*result_buffer, stream.get(), execute_backend_->transfer_manager(), &snapshot)); DumpHloSnapshotIfEnabled(executable->module(), snapshot); } - VLOG(1) << "successfully completed 'execute' request"; - return absl::OkStatus(); + return std::make_unique(this, output); } -absl::Status Service::TransferToClient(const TransferToClientRequest* arg, - TransferToClientResponse* result) { +absl::StatusOr Service::TransferToClient( + const GlobalData& data, const Shape* shape_with_layout) { TF_ASSIGN_OR_RETURN(const ShapedBuffer* shaped_buffer, - 
allocation_tracker_.ResolveForReplica(arg->data(), 0)); + allocation_tracker_.ResolveForReplica(data.handle(), 0)); Shape return_shape; - if (arg->has_shape_with_layout()) { - return_shape = Shape(arg->shape_with_layout()); + if (shape_with_layout) { + return_shape = Shape(*shape_with_layout); if (!LayoutUtil::HasLayout(return_shape)) { return InvalidArgument("shape_with_layout must have layout if present."); } @@ -943,24 +930,17 @@ absl::Status Service::TransferToClient(const TransferToClientRequest* arg, stream.get(), *shaped_buffer)); if (LayoutUtil::LayoutsInShapesEqual(return_shape, result_literal.shape())) { - *result->mutable_literal() = result_literal.ToProto(); - } else { - *result->mutable_literal() = - result_literal.Relayout(return_shape).ToProto(); + return result_literal; } - return absl::OkStatus(); + return result_literal.Relayout(return_shape); } -absl::Status Service::TransferToServer(const TransferToServerRequest* arg, - TransferToServerResponse* result) { - TF_ASSIGN_OR_RETURN(Literal literal, - Literal::CreateFromProto(arg->literal())); - const Shape& shape = literal.shape(); - +absl::StatusOr> Service::TransferToServer( + const LiteralSlice& literal_slice, const DeviceHandle* device_handle) { + const Shape& shape = literal_slice.shape(); std::vector replicas; - if (arg->has_device_handle()) { - TF_ASSIGN_OR_RETURN(replicas, - Replicas(*execute_backend_, arg->device_handle())); + if (device_handle) { + TF_ASSIGN_OR_RETURN(replicas, Replicas(*execute_backend_, *device_handle)); } else { TF_ASSIGN_OR_RETURN( replicas, Replicas(*execute_backend_, SingleComputationDeviceHandle())); @@ -982,100 +962,93 @@ absl::Status Service::TransferToServer(const TransferToServerRequest* arg, TF_ASSIGN_OR_RETURN(auto stream, execute_backend_->BorrowStream(executor)); TF_RETURN_IF_ERROR( execute_backend_->transfer_manager()->TransferLiteralToDevice( - stream.get(), literal, shaped_buffer)); + stream.get(), literal_slice, shaped_buffer)); 
replicated_buffers.emplace_back(std::move(shaped_buffer)); } - TF_ASSIGN_OR_RETURN(*result->mutable_data(), + + TF_ASSIGN_OR_RETURN(GlobalDataHandle out, allocation_tracker_.RegisterReplicatedBuffers( std::move(replicated_buffers), StrCat("TransferToServer literal of shape ", ShapeUtil::HumanString(shape)))); - return absl::OkStatus(); + return std::make_unique(this, out); } -absl::Status Service::TransferToInfeed(const TransferToInfeedRequest* arg, - TransferToInfeedResponse* result) { +absl::Status Service::TransferToInfeed(const LiteralSlice& literal, + int64_t replica_id, + const DeviceHandle* device_handle) { const int64_t replica_count = options_.number_of_replicas(); - if (arg->replica_id() < 0 || arg->replica_id() >= replica_count) { + if (replica_id < 0 || replica_id >= replica_count) { return FailedPrecondition( "%s", - StrCat("The replica_id=", arg->replica_id(), + StrCat("The replica_id=", replica_id, " on TransferToInfeedRequest not in range [0, replica_count=", replica_count, ").")); } se::StreamExecutor* executor; - if (arg->has_device_handle()) { + if (device_handle) { TF_ASSIGN_OR_RETURN(auto replicas, - Replicas(*execute_backend_, arg->device_handle())); - executor = replicas[arg->replica_id()]; + Replicas(*execute_backend_, *device_handle)); + executor = replicas[replica_id]; } else { TF_ASSIGN_OR_RETURN( auto replicas, Replicas(*execute_backend_, SingleComputationDeviceHandle())); - executor = replicas[arg->replica_id()]; + executor = replicas[replica_id]; } - TF_ASSIGN_OR_RETURN(Literal literal, - Literal::CreateFromProto(arg->literal())); return execute_backend_->transfer_manager()->TransferLiteralToInfeed(executor, literal); } -absl::Status Service::TransferFromOutfeed(const TransferFromOutfeedRequest* arg, - TransferFromOutfeedResponse* result) { +absl::StatusOr Service::TransferFromOutfeed( + const Shape* shape_with_layout, int64_t replica_id, + const DeviceHandle* device_handle) { const int64_t replica_count = 
options_.number_of_replicas(); - if (arg->replica_id() < 0 || arg->replica_id() >= replica_count) { + if (replica_id < 0 || replica_id >= replica_count) { return FailedPrecondition( "The replica_id=%d on TransferFromOutfeedRequest not in range [0, %d)", - arg->replica_id(), replica_count); + replica_id, replica_count); } se::StreamExecutor* executor; - if (arg->has_device_handle()) { + if (device_handle) { TF_ASSIGN_OR_RETURN(auto replicas, - Replicas(*execute_backend_, arg->device_handle())); - executor = replicas[arg->replica_id()]; + Replicas(*execute_backend_, *device_handle)); + executor = replicas[replica_id]; } else { TF_ASSIGN_OR_RETURN( auto replicas, Replicas(*execute_backend_, SingleComputationDeviceHandle())); - executor = replicas[arg->replica_id()]; + executor = replicas[replica_id]; } - auto literal = Literal::CreateFromShape(Shape(arg->shape_with_layout())); + auto literal = Literal::CreateFromShape(*shape_with_layout); TF_RETURN_IF_ERROR( execute_backend_->transfer_manager()->TransferLiteralFromOutfeed( executor, &literal)); - *result->mutable_literal() = literal.ToProto(); - return absl::OkStatus(); + return literal; } -absl::Status Service::ResetDevice(const ResetDeviceRequest* arg, - ResetDeviceResponse* result) { - return execute_backend_->ResetDevices(); -} +absl::Status Service::ResetDevice() { return execute_backend_->ResetDevices(); } -absl::Status Service::ComputeConstantGraph( - const ComputeConstantGraphRequest* arg, ComputeConstantResponse* result) { - if (!arg->has_computation()) { - return InvalidArgument("computations may not be empty"); - } - if (!arg->computation().has_host_program_shape()) { +absl::StatusOr Service::ComputeConstantGraph( + const XlaComputation& computation, const Layout* output_layout) { + if (!computation.proto().has_host_program_shape()) { return InvalidArgument("program shape may not be empty"); } - if (arg->computation().host_program_shape().parameters_size() != 0) { + if 
(computation.proto().host_program_shape().parameters_size() != 0) { return InvalidArgument( "constant computation may not depend on any parameters."); } - ProgramShape program_shape(arg->computation().host_program_shape()); + ProgramShape program_shape(computation.proto().host_program_shape()); TF_DCHECK_OK(ShapeUtil::ValidateShape(program_shape.result())); - std::optional output_layout; - if (arg->has_output_layout()) { - output_layout = Layout::CreateFromProto(arg->output_layout()); + + if (output_layout) { TF_RETURN_IF_ERROR(LayoutUtil::ValidateLayoutForShape( *output_layout, program_shape.result())); } @@ -1083,7 +1056,7 @@ absl::Status Service::ComputeConstantGraph( HloModuleConfig config(program_shape); TF_ASSIGN_OR_RETURN(std::unique_ptr module, - CreateModuleFromProto(arg->computation(), config)); + CreateModuleFromProto(computation.proto(), config)); DynamicPadder dynamic_padder; TF_RETURN_IF_ERROR(dynamic_padder.Run(module.get()).status()); @@ -1112,20 +1085,16 @@ absl::Status Service::ComputeConstantGraph( // relayout here. // // TODO(b/77824332): Make HloEvaluator take care of the re-layout. 
- if (output_layout.has_value()) { + if (output_layout) { result_literal = result_literal.Relayout(*output_layout); } - *result->mutable_literal() = result_literal.ToProto(); - - return absl::OkStatus(); + return result_literal; } -absl::Status Service::GetShape(const GetShapeRequest* arg, - GetShapeResponse* result) { +absl::StatusOr Service::GetShape(const GlobalData& data) { TF_ASSIGN_OR_RETURN(const ShapedBuffer* buffer, - allocation_tracker_.ResolveForReplica(arg->data(), 0)); - *result->mutable_shape() = buffer->on_device_shape().ToProto(); - return absl::OkStatus(); + allocation_tracker_.ResolveForReplica(data.handle(), 0)); + return buffer->on_device_shape(); } DeviceHandle Service::SingleComputationDeviceHandle() const { @@ -1152,4 +1121,46 @@ absl::StatusOr> Service::Replicas( return replicas; } +namespace { + +// Releases a set of global data handles owned by the parent service +// interface. +void ReleaseHandles(Service* parent, + const absl::Span handles) { + for (const GlobalDataHandle& handle : handles) { + VLOG(1) << "Requesting to unregister " << handle.ShortDebugString(); + absl::Status status = parent->Unregister(handle); + if (!status.ok()) { + LOG(WARNING) << "Failed to unregister handles: " << status + << "; continuing anyway..."; + } + } + VLOG(1) << "Done with request"; +} + +} // namespace + +GlobalData::GlobalData(Service* parent, GlobalDataHandle handle) + : handle_(std::move(handle)), parent_(parent) {} + +GlobalData::~GlobalData() { + if (parent_ != nullptr) { + ReleaseHandles(parent_, {handle_}); + } +} + +/* static */ void GlobalData::Release( + std::vector> instances) { + absl::flat_hash_map> + parent_handles_map; + for (auto& instance : instances) { + if (instance->parent_ != nullptr) { + parent_handles_map[instance->parent_].push_back(instance->Release()); + } + } + for (auto& parent_handles : parent_handles_map) { + ReleaseHandles(parent_handles.first, parent_handles.second); + } +} + } // namespace xla diff --git 
a/third_party/xla/xla/service/service.h b/third_party/xla/xla/service/service.h index ff54b36ae04435..3fd7f227c362e8 100644 --- a/third_party/xla/xla/service/service.h +++ b/third_party/xla/xla/service/service.h @@ -24,6 +24,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/types/span.h" +#include "xla/client/xla_computation.h" #include "xla/debug_options_flags.h" #include "xla/executable_run_options.h" #include "xla/hlo/ir/hlo_module.h" @@ -44,6 +45,8 @@ limitations under the License. namespace xla { +class Service; + // Options to configure the service when it is created. class ServiceOptions { public: @@ -73,6 +76,59 @@ class ServiceOptions { std::optional> allowed_devices_; }; +// A GlobalData object represents a globally-accessible allocation of +// data in the associated XLA service. +class GlobalData { + public: + // Gives ownership of the global data handle to this object. + GlobalData(Service* parent, GlobalDataHandle handle); + + // Unregisters the wrapped handle, which causes the service to + // deallocate the associated data. + ~GlobalData(); + + const GlobalDataHandle& handle() const { return handle_; } + + // Releases a set of GlobalData handles. A single RPC will be issued + // per unique Service of the given GlobalData objects. + static void Release(std::vector> instances); + + private: + // Detaches the global data handle from the object, such that the destructor + // will not try to release it. + GlobalDataHandle Release() { + parent_ = nullptr; + return handle_; + } + + GlobalDataHandle handle_; // Handle being wrapped. + Service* parent_; // Service used to unregister handle_. + + GlobalData(const GlobalData&) = delete; + GlobalData& operator=(const GlobalData&) = delete; +}; + +// A struct to represent a computation instance to be executed. 
+// * If execution_options.device_handles is not empty, the computation is +// executed on the devices associated with the handles by partitioning the +// computation based on the attached sharding attributes. Otherwise, a +// device is chosen by the service. +struct XlaComputationInstance { + const XlaComputation& computation; + std::vector arguments; + ExecutionOptions execution_options; + ExecutionProfile* execution_profile; + + XlaComputationInstance(const XlaComputation& computation, + std::vector arguments, + ExecutionOptions execution_options, + ExecutionProfile* execution_profile) + : computation(computation), + arguments(std::move(arguments)), + execution_options(execution_options), + execution_profile(execution_profile) {} +}; + // The XLA service object, which is the same across all platforms. It maintains // the service state of computations and allocations, and delegates // target-specific requests to the target-specific infrastructure @@ -83,30 +139,32 @@ class Service { // // If the handle given is not currently allocated, a NOT_FOUND status is // returned. - virtual absl::Status Unregister(const UnregisterRequest* arg, - UnregisterResponse* result); + virtual absl::Status Unregister(const GlobalDataHandle& data); // Deconstructs a tuple. Returns a newly created GlobalDataHandle for each // element in the tuple. - virtual absl::Status DeconstructTuple(const DeconstructTupleRequest* arg, - DeconstructTupleResponse* result); + virtual absl::StatusOr>> + DeconstructTuple(const GlobalData& data); // Compiles a computation into an executable. The request contains the whole // computation graph. Returns the handle to the executable. - virtual absl::Status Compile(const CompileRequest* arg, - CompileResponse* result); + virtual absl::StatusOr Compile( + const XlaComputation& computation, + absl::Span argument_shapes, + const ExecutionOptions& execution_options); // Executes an executable with the provided global data passes as immutable // arguments. 
The request contains the handle to the executable. Returns // global data output and execution timing. - virtual absl::Status Execute(const ExecuteRequest* arg, - ExecuteResponse* result); + virtual absl::StatusOr> Execute( + const ExecutionHandle& handle, absl::Span arguments, + ExecutionProfile* execution_profile); // Executes one or more computations in parallel with the provided global data // passed as immutable arguments. Returns global data output for each // computation. - virtual absl::Status ExecuteGraphParallel( - const ExecuteGraphParallelRequest* arg, ExecuteParallelResponse* result); + absl::StatusOr>> ExecuteGraphParallel( + absl::Span computations); // Requests one or more device handles from the target. // @@ -116,27 +174,28 @@ class Service { // the first set of replicas, and the next R devices to the second set of // replicas, etc. Each returned device handle represents the device with the // replica id 0. - virtual absl::Status GetDeviceHandles(const GetDeviceHandlesRequest* arg, - GetDeviceHandlesResponse* result); + virtual absl::StatusOr> GetDeviceHandles( + int64_t device_count); // Requests that global data be transferred to the client in literal form. - virtual absl::Status TransferToClient(const TransferToClientRequest* arg, - TransferToClientResponse* result); + virtual absl::StatusOr TransferToClient( + const GlobalData& data, const Shape* shape_with_layout); // Transfers data from a literal provided by the client, into device memory. - virtual absl::Status TransferToServer(const TransferToServerRequest* arg, - TransferToServerResponse* result); + virtual absl::StatusOr> TransferToServer( + const LiteralSlice& literal_slice, const DeviceHandle* device_handle); // Transfers data from a literal provided by the client, into the Infeed // buffer of the device. 
- virtual absl::Status TransferToInfeed(const TransferToInfeedRequest* arg, - TransferToInfeedResponse* result); + virtual absl::Status TransferToInfeed(const LiteralSlice& literal, + int64_t replica_id, + const DeviceHandle* device_handle); // Transfers data from the Outfeed othe device to the literal provided by the // client. - virtual absl::Status TransferFromOutfeed( - const TransferFromOutfeedRequest* arg, - TransferFromOutfeedResponse* result); + virtual absl::StatusOr TransferFromOutfeed( + const Shape* shape_with_layout, int64_t replica_id, + const DeviceHandle* device_handle); // Resets devices, clearing all existing state on all the devices associated // with this service (including memory allocated on the devices). @@ -147,22 +206,19 @@ class Service { // ResetDevice should be called before an Execution that expect the device to // be in the reset state. For example, if the prior Execution modifies device // state (e.g., architectural state) that the next Execution depends on. - virtual absl::Status ResetDevice(const ResetDeviceRequest* arg, - ResetDeviceResponse* result); + virtual absl::Status ResetDevice(); - virtual absl::Status ComputeConstantGraph( - const ComputeConstantGraphRequest* arg, ComputeConstantResponse* result); + virtual absl::StatusOr ComputeConstantGraph( + const XlaComputation& computation, const Layout* output_layout); // Returns the shape (with layout) of an array associated with a given data // handle. - virtual absl::Status GetShape(const GetShapeRequest* arg, - GetShapeResponse* result); + virtual absl::StatusOr GetShape(const GlobalData& data); // Creates a unique channel handle that can be used for Send/Recv // instructions. - virtual absl::Status CreateChannelHandle( - const CreateChannelHandleRequest* arg, - CreateChannelHandleResponse* result); + virtual absl::StatusOr CreateChannelHandle( + ChannelHandle::ChannelType type); // Returns the backend used to execute computations. 
const Backend& backend() const { return *execute_backend_; } @@ -201,7 +257,7 @@ class Service { // Prepare the arguments for executing parallel. absl::StatusOr>> GetArguments( const ExecutionOptions& execution_options, - absl::Span arguments) const; + absl::Span arguments) const; protected: friend class LocalExecutable; @@ -217,7 +273,7 @@ class Service { // the corresponding replica. absl::StatusOr>> ResolveAndValidateArguments( - absl::Span arguments, + absl::Span arguments, absl::Span stream_executors) const; public: diff --git a/third_party/xla/xla/tests/BUILD b/third_party/xla/xla/tests/BUILD index 918992c6aeff99..69c591200c9025 100644 --- a/third_party/xla/xla/tests/BUILD +++ b/third_party/xla/xla/tests/BUILD @@ -1045,6 +1045,7 @@ xla_test( "//xla:status_macros", "//xla:test", "//xla/client:xla_builder", + "//xla/service", ], ) diff --git a/third_party/xla/xla/tests/client_test.cc b/third_party/xla/xla/tests/client_test.cc index 6d67d07eb969da..1adb92207748aa 100644 --- a/third_party/xla/xla/tests/client_test.cc +++ b/third_party/xla/xla/tests/client_test.cc @@ -128,14 +128,14 @@ XLA_TEST_F(ClientTest, // We can't really test parallel execution on CPU since all of the cores in a // CPU are presented as a single device. So for now we test "parallel" // execution on a single device. 
- std::vector computation_instances; + std::vector computation_instances; TF_ASSERT_OK_AND_ASSIGN(std::vector devices, client_->GetDeviceHandles(1)); ASSERT_EQ(devices.size(), 1); ExecutionOptions options = execution_options_; *options.add_device_handles() = devices[0]; - computation_instances.push_back(Client::XlaComputationInstance( + computation_instances.push_back(XlaComputationInstance( add_with_one_arg, {const_arg.get()}, options, nullptr)); TF_ASSERT_OK_AND_ASSIGN(auto results, diff --git a/third_party/xla/xla/tests/gather_operation_test.cc b/third_party/xla/xla/tests/gather_operation_test.cc index b80d1570a9b3b0..94d02abbea3b04 100644 --- a/third_party/xla/xla/tests/gather_operation_test.cc +++ b/third_party/xla/xla/tests/gather_operation_test.cc @@ -17,6 +17,7 @@ limitations under the License. #include "xla/client/xla_builder.h" #include "xla/execution_options_util.h" #include "xla/literal_util.h" +#include "xla/service/service.h" #include "xla/status_macros.h" #include "xla/test.h" #include "xla/tests/client_library_test_base.h" @@ -778,7 +779,7 @@ XLA_TEST_F(GatherClientLibraryTest, xla::ExecutionOptions execution_options = CreateDefaultExecutionOptions(); *execution_options.add_device_handles() = devices[0]; TF_ASSERT_OK_AND_ASSIGN(XlaComputation computation, builder.Build()); - std::vector computation_instances = { + std::vector computation_instances = { {computation, {operand_arg.get(), indices_arg.get()}, execution_options, diff --git a/third_party/xla/xla/xla.proto b/third_party/xla/xla/xla.proto index d9e4594d2b4708..16b1da1999d7c3 100644 --- a/third_party/xla/xla/xla.proto +++ b/third_party/xla/xla/xla.proto @@ -1037,189 +1037,6 @@ message HloModuleProtoWithConfig { HloModuleConfigProto config = 2; } -message GetDeviceHandlesRequest { - int64 device_count = 1; -} - -message GetDeviceHandlesResponse { - repeated DeviceHandle device_handles = 1; -} - -message TransferToClientRequest { - GlobalDataHandle data = 1; - - // This optional field directs 
the service to return the literal in this - // layout. A shape is used to hold the layout to accommodate tuples. - ShapeProto shape_with_layout = 2; -} - -message TransferToClientResponse { - LiteralProto literal = 1; -} - -message TransferToServerRequest { - LiteralProto literal = 1; - DeviceHandle device_handle = 2; -} - -message TransferToServerResponse { - GlobalDataHandle data = 1; -} - -message TransferToInfeedRequest { - LiteralProto literal = 1; - int64 replica_id = 2; - DeviceHandle device_handle = 3; -} - -message TransferToInfeedResponse {} - -message TransferFromOutfeedRequest { - // This optional field directs the service to return the literal in this - // layout. A shape is used to hold the layout to accommodate tuples. - ShapeProto shape_with_layout = 1; - - int64 replica_id = 2; - DeviceHandle device_handle = 3; -} - -message TransferFromOutfeedResponse { - LiteralProto literal = 1; -} - -message ResetDeviceRequest { - DeviceHandle device_handle = 1; -} - -message ResetDeviceResponse {} - -message CreateChannelHandleRequest { - ChannelHandle.ChannelType channel_type = 1; -} - -message CreateChannelHandleResponse { - ChannelHandle channel = 1; -} - -message UnregisterRequest { - repeated GlobalDataHandle data = 1; -} - -message UnregisterResponse {} - -message CompileRequest { - // The graph to be compiled. - HloModuleProto computation = 1; - - // Options that affect how XLA compiles code to service this request. - ExecutionOptions execution_options = 2; - - // The layouts of the input arguments. If not set, the default layout will be - // used. Although the real arguments are not needed in compilation, the - // layouts of the arguments can affect the compilation. - repeated ShapeProto input_shape_with_layout = 3; -} - -message CompileResponse { - // The handle to the executable. 
- ExecutionHandle handle = 1; -} - -message ExecuteRequest { - ExecutionHandle handle = 1; - - // The shape and layout of the arguments must be the same as the those of the - // executable's parameters. - repeated GlobalDataHandle arguments = 2; -} - -// TODO(b/118493728): Remove this and ExecuteGraphParallelRequest and replace -// the uses with calls to Compile and Execute. -message ExecuteGraphRequest { - HloModuleProto computation = 1; - repeated GlobalDataHandle arguments = 2; - - // Options that affect how XLA compiles and runs code to service this request. - ExecutionOptions execution_options = 3; -} - -message ExecuteGraphParallelRequest { - repeated ExecuteGraphRequest requests = 1; -} - -message ExecuteResponse { - GlobalDataHandle output = 1; - ExecutionProfile profile = 2; -} - -message ExecuteParallelResponse { - repeated ExecuteResponse responses = 1; -} - -message ComputeConstantGraphRequest { - HloModuleProto computation = 1; - LayoutProto output_layout = 2; -} - -message ComputeConstantResponse { - // A LiteralProto is returned directly for this request. - LiteralProto literal = 1; -} - -message DeconstructTupleRequest { - GlobalDataHandle tuple_handle = 2; -} - -message DeconstructTupleResponse { - repeated GlobalDataHandle element_handles = 1; -} - -message LoadDataRequest { - // Describes the path of the ColumnIO tablet to load. - string columnio_tablet_path = 1; - - // Describes the field to load within the ColumnIO tablet. - string columnio_field = 2; - - // Individual element shape, excluding rows. - ShapeProto element_shape = 3; - - // Warning: ColumnIO does not support random-access, so use offset with - // caution in performance-critical scenarios. - int64 offset = 4; - - // Maximum number of elements (with shape element_shape) to load. - int64 limit = 5; - - // If more than one item is requested (via limit > 1), then this request - // attribute zips together the produced vectors. 
- bool zip = 6; -} - -message LoadDataResponse { - GlobalDataHandle data = 1; - ShapeProto data_shape = 2; - int64 available_rows = 3; - int64 rows_loaded = 4; - int64 nanoseconds = 5; -} - -message GetShapeRequest { - GlobalDataHandle data = 1; -} - -message GetShapeResponse { - ShapeProto shape = 1; -} - -message UnpackRequest { - GlobalDataHandle data = 1; -} - -message UnpackResponse { - repeated GlobalDataHandle tied_data = 1; -} - // A trace estimated by the Latency Hiding Scheduler. message ScheduleProto { message Instruction { From b1f94db136e67c5d31887f1f00ff4c903202d91c Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 19 Jun 2024 04:04:32 -0700 Subject: [PATCH 010/256] Automated Code Change PiperOrigin-RevId: 644700338 --- .../xla/third_party/tsl/tsl/lib/io/block.cc | 4 +- .../tsl/tsl/lib/io/buffered_file.h | 20 +++---- .../tsl/tsl/lib/io/buffered_inputstream.cc | 55 +++++++++--------- .../tsl/tsl/lib/io/buffered_inputstream.h | 22 +++---- .../tsl/lib/io/buffered_inputstream_test.cc | 8 +-- .../xla/third_party/tsl/tsl/lib/io/format.cc | 17 +++--- .../xla/third_party/tsl/tsl/lib/io/format.h | 8 +-- .../third_party/tsl/tsl/lib/io/inputbuffer.cc | 57 ++++++++++--------- .../third_party/tsl/tsl/lib/io/inputbuffer.h | 33 +++++------ .../tsl/tsl/lib/io/inputstream_interface.cc | 4 +- .../tsl/tsl/lib/io/inputstream_interface.h | 8 +-- .../tsl/lib/io/inputstream_interface_test.cc | 8 +-- .../third_party/tsl/tsl/lib/io/iterator.cc | 10 ++-- .../xla/third_party/tsl/tsl/lib/io/iterator.h | 4 +- .../tsl/tsl/lib/io/random_inputstream.cc | 23 ++++---- .../tsl/tsl/lib/io/random_inputstream.h | 12 ++-- .../tsl/tsl/lib/io/record_reader.cc | 29 +++++----- .../tsl/tsl/lib/io/record_reader.h | 18 +++--- .../tsl/lib/io/record_reader_writer_test.cc | 4 +- .../tsl/tsl/lib/io/record_writer.cc | 30 +++++----- .../tsl/tsl/lib/io/record_writer.h | 8 +-- .../tsl/tsl/lib/io/recordio_test.cc | 30 +++++----- 22 files changed, 209 insertions(+), 203 deletions(-) diff 
--git a/third_party/xla/third_party/tsl/tsl/lib/io/block.cc b/third_party/xla/third_party/tsl/tsl/lib/io/block.cc index 0bc9fa3664c97b..8eefa4b5a3609f 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/block.cc +++ b/third_party/xla/third_party/tsl/tsl/lib/io/block.cc @@ -98,7 +98,7 @@ class Block::Iter : public Iterator { uint32 restart_index_; // Index of restart block in which current_ falls string key_; StringPiece value_; - Status status_; + absl::Status status_; inline int Compare(const StringPiece& a, const StringPiece& b) const { return a.compare(b); @@ -135,7 +135,7 @@ class Block::Iter : public Iterator { } bool Valid() const override { return current_ < restarts_; } - Status status() const override { return status_; } + absl::Status status() const override { return status_; } StringPiece key() const override { assert(Valid()); return key_; diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/buffered_file.h b/third_party/xla/third_party/tsl/tsl/lib/io/buffered_file.h index 5627a7228fb782..69300956d9fe20 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/buffered_file.h +++ b/third_party/xla/third_party/tsl/tsl/lib/io/buffered_file.h @@ -36,7 +36,7 @@ class BufferedWritableFile : public WritableFile { } ~BufferedWritableFile() override { Close().IgnoreError(); } - Status Append(StringPiece str_data) override { + absl::Status Append(StringPiece str_data) override { int64_t bytes_left = str_data.size(); const char* data = str_data.data(); @@ -58,22 +58,22 @@ class BufferedWritableFile : public WritableFile { bytes_left -= append_bytes; } - return OkStatus(); + return absl::OkStatus(); } - Status Append(const absl::Cord& data) override { + absl::Status Append(const absl::Cord& data) override { for (absl::string_view fragment : data.Chunks()) { TF_RETURN_IF_ERROR(Append(fragment)); } - return OkStatus(); + return absl::OkStatus(); } - Status Close() override { + absl::Status Close() override { TF_RETURN_IF_ERROR(Flush()); return file_->Close(); } 
- Status Flush() override { + absl::Status Flush() override { if (buffer_pos_ > 0) { TF_RETURN_IF_ERROR(file_->Append(StringPiece(&buffer_[0], buffer_pos_))); buffer_pos_ = 0; @@ -81,18 +81,18 @@ class BufferedWritableFile : public WritableFile { return file_->Flush(); } - tsl::Status Tell(int64_t* position) override { + absl::Status Tell(int64_t* position) override { int64_t bytes_written; - tsl::Status status = file_->Tell(&bytes_written); + absl::Status status = file_->Tell(&bytes_written); if (status.ok()) { *position = bytes_written + buffer_pos_; - return OkStatus(); + return absl::OkStatus(); } else { return status; } } - Status Sync() override { return file_->Sync(); } + absl::Status Sync() override { return file_->Sync(); } // For compatibilty with the TensorBundle writer, we expose CRC32 checksums. uint32_t crc32() const { return crc32_; } diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/buffered_inputstream.cc b/third_party/xla/third_party/tsl/tsl/lib/io/buffered_inputstream.cc index b3cfdbb20818ec..89ed20757cf093 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/buffered_inputstream.cc +++ b/third_party/xla/third_party/tsl/tsl/lib/io/buffered_inputstream.cc @@ -41,13 +41,13 @@ BufferedInputStream::~BufferedInputStream() { } } -Status BufferedInputStream::FillBuffer() { +absl::Status BufferedInputStream::FillBuffer() { if (!file_status_.ok()) { pos_ = 0; limit_ = 0; return file_status_; } - Status s = input_stream_->ReadNBytes(size_, &buf_); + absl::Status s = input_stream_->ReadNBytes(size_, &buf_); pos_ = 0; limit_ = buf_.size(); if (!s.ok()) { @@ -57,10 +57,10 @@ Status BufferedInputStream::FillBuffer() { } template -Status BufferedInputStream::ReadLineHelper(StringType* result, - bool include_eol) { +absl::Status BufferedInputStream::ReadLineHelper(StringType* result, + bool include_eol) { result->clear(); - Status s; + absl::Status s; size_t start_pos = pos_; while (true) { if (pos_ == limit_) { @@ -79,7 +79,7 @@ Status 
BufferedInputStream::ReadLineHelper(StringType* result, result->append(1, c); } pos_++; - return OkStatus(); + return absl::OkStatus(); } // We don't append '\r' to *result if (c == '\r') { @@ -89,12 +89,13 @@ Status BufferedInputStream::ReadLineHelper(StringType* result, pos_++; } if (absl::IsOutOfRange(s) && !result->empty()) { - return OkStatus(); + return absl::OkStatus(); } return s; } -Status BufferedInputStream::ReadNBytes(int64_t bytes_to_read, tstring* result) { +absl::Status BufferedInputStream::ReadNBytes(int64_t bytes_to_read, + tstring* result) { if (bytes_to_read < 0) { return errors::InvalidArgument("Can't read a negative number of bytes: ", bytes_to_read); @@ -105,7 +106,7 @@ Status BufferedInputStream::ReadNBytes(int64_t bytes_to_read, tstring* result) { } result->reserve(bytes_to_read); - Status s; + absl::Status s; while (result->size() < static_cast(bytes_to_read)) { // Check whether the buffer is fully read or not. if (pos_ == limit_) { @@ -127,12 +128,12 @@ Status BufferedInputStream::ReadNBytes(int64_t bytes_to_read, tstring* result) { // obtained enough data to satisfy the function call. Returning OK then. if (absl::IsOutOfRange(s) && (result->size() == static_cast(bytes_to_read))) { - return OkStatus(); + return absl::OkStatus(); } return s; } -Status BufferedInputStream::SkipNBytes(int64_t bytes_to_skip) { +absl::Status BufferedInputStream::SkipNBytes(int64_t bytes_to_skip) { if (bytes_to_skip < 0) { return errors::InvalidArgument("Can only skip forward, not ", bytes_to_skip); @@ -144,7 +145,7 @@ Status BufferedInputStream::SkipNBytes(int64_t bytes_to_skip) { // Otherwise, we already have read limit_ - pos_, so skip the rest. At this // point we need to get fresh data into the buffer, so reset pos_ and // limit_. 
- Status s = input_stream_->SkipNBytes(bytes_to_skip - (limit_ - pos_)); + absl::Status s = input_stream_->SkipNBytes(bytes_to_skip - (limit_ - pos_)); pos_ = 0; limit_ = 0; if (absl::IsOutOfRange(s)) { @@ -152,14 +153,14 @@ Status BufferedInputStream::SkipNBytes(int64_t bytes_to_skip) { } return s; } - return OkStatus(); + return absl::OkStatus(); } int64_t BufferedInputStream::Tell() const { return input_stream_->Tell() - (limit_ - pos_); } -Status BufferedInputStream::Seek(int64_t position) { +absl::Status BufferedInputStream::Seek(int64_t position) { if (position < 0) { return errors::InvalidArgument("Seeking to a negative position: ", position); @@ -176,7 +177,7 @@ Status BufferedInputStream::Seek(int64_t position) { if (position < Tell()) { // Seek within buffer before 'pos_' pos_ -= Tell() - position; - return OkStatus(); + return absl::OkStatus(); } // Seek after 'pos_' @@ -184,9 +185,9 @@ Status BufferedInputStream::Seek(int64_t position) { } template -Status BufferedInputStream::ReadAll(T* result) { +absl::Status BufferedInputStream::ReadAll(T* result) { result->clear(); - Status status; + absl::Status status; while (status.ok()) { status = FillBuffer(); if (limit_ == 0) { @@ -198,7 +199,7 @@ Status BufferedInputStream::ReadAll(T* result) { if (absl::IsOutOfRange(status)) { file_status_ = status; - return OkStatus(); + return absl::OkStatus(); } return status; } @@ -206,19 +207,19 @@ Status BufferedInputStream::ReadAll(T* result) { template Status BufferedInputStream::ReadAll(std::string* result); template Status BufferedInputStream::ReadAll(tstring* result); -Status BufferedInputStream::Reset() { +absl::Status BufferedInputStream::Reset() { TF_RETURN_IF_ERROR(input_stream_->Reset()); pos_ = 0; limit_ = 0; - file_status_ = OkStatus(); - return OkStatus(); + file_status_ = absl::OkStatus(); + return absl::OkStatus(); } -Status BufferedInputStream::ReadLine(std::string* result) { +absl::Status BufferedInputStream::ReadLine(std::string* result) { return 
ReadLineHelper(result, false); } -Status BufferedInputStream::ReadLine(tstring* result) { +absl::Status BufferedInputStream::ReadLine(tstring* result) { return ReadLineHelper(result, false); } @@ -228,8 +229,8 @@ std::string BufferedInputStream::ReadLineAsString() { return result; } -Status BufferedInputStream::SkipLine() { - Status s; +absl::Status BufferedInputStream::SkipLine() { + absl::Status s; bool skipped = false; while (true) { if (pos_ == limit_) { @@ -242,11 +243,11 @@ Status BufferedInputStream::SkipLine() { char c = buf_[pos_++]; skipped = true; if (c == '\n') { - return OkStatus(); + return absl::OkStatus(); } } if (absl::IsOutOfRange(s) && skipped) { - return OkStatus(); + return absl::OkStatus(); } return s; } diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/buffered_inputstream.h b/third_party/xla/third_party/tsl/tsl/lib/io/buffered_inputstream.h index 5318434c63919c..6681f1bbfbed32 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/buffered_inputstream.h +++ b/third_party/xla/third_party/tsl/tsl/lib/io/buffered_inputstream.h @@ -43,9 +43,9 @@ class BufferedInputStream : public InputStreamInterface { ~BufferedInputStream() override; - Status ReadNBytes(int64_t bytes_to_read, tstring* result) override; + absl::Status ReadNBytes(int64_t bytes_to_read, tstring* result) override; - Status SkipNBytes(int64_t bytes_to_skip) override; + absl::Status SkipNBytes(int64_t bytes_to_skip) override; int64_t Tell() const override; @@ -58,7 +58,7 @@ class BufferedInputStream : public InputStreamInterface { // Note: When seeking backwards in a stream, this implementation uses // Reset() + SkipNBytes(), so its performance will be dependent // largely on the performance of SkipNBytes(). - Status Seek(int64_t position); + absl::Status Seek(int64_t position); // Read one text line of data into "*result" until end-of-file or a // \n is read. (The \n is not included in the result.) 
Overwrites @@ -67,8 +67,8 @@ class BufferedInputStream : public InputStreamInterface { // If successful, returns OK. If we are already at the end of the // file, we return an OUT_OF_RANGE error. Otherwise, we return // some other non-OK status. - Status ReadLine(std::string* result); - Status ReadLine(tstring* result); + absl::Status ReadLine(std::string* result); + absl::Status ReadLine(tstring* result); // Returns one text line of data until end-of-file or a '\n' is read. The '\n' // is included in the result. @@ -83,21 +83,21 @@ class BufferedInputStream : public InputStreamInterface { // If successful, returns OK. If we are already at the end of the // file, we return an OUT_OF_RANGE error. Otherwise, we return // some other non-OK status. - Status SkipLine(); + absl::Status SkipLine(); // Reads the entire contents of the file into *result. // // Note: the amount of memory used by this function call is unbounded, so only // use in ops that expect that behavior. template - Status ReadAll(T* result); + absl::Status ReadAll(T* result); - Status Reset() override; + absl::Status Reset() override; private: - Status FillBuffer(); + absl::Status FillBuffer(); template - Status ReadLineHelper(StringType* result, bool include_eol); + absl::Status ReadLineHelper(StringType* result, bool include_eol); InputStreamInterface* input_stream_; // not owned. size_t size_; // buffer size. @@ -108,7 +108,7 @@ class BufferedInputStream : public InputStreamInterface { bool owns_input_stream_ = false; // When EoF is reached, file_status_ contains the status to skip unnecessary // buffer allocations. 
- Status file_status_ = OkStatus(); + absl::Status file_status_ = absl::OkStatus(); BufferedInputStream(const BufferedInputStream&) = delete; void operator=(const BufferedInputStream&) = delete; diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/buffered_inputstream_test.cc b/third_party/xla/third_party/tsl/tsl/lib/io/buffered_inputstream_test.cc index 56dd88510377bd..ab1f58e0b14a83 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/buffered_inputstream_test.cc +++ b/third_party/xla/third_party/tsl/tsl/lib/io/buffered_inputstream_test.cc @@ -36,7 +36,7 @@ class ReadOnceInputStream : public InputStreamInterface { public: ReadOnceInputStream() : start_(true) {} - virtual Status ReadNBytes(int64_t bytes_to_read, tstring* result) { + virtual absl::Status ReadNBytes(int64_t bytes_to_read, tstring* result) { if (bytes_to_read < 11) { return errors::InvalidArgument("Not reading all bytes: ", bytes_to_read); } @@ -52,9 +52,9 @@ class ReadOnceInputStream : public InputStreamInterface { int64_t Tell() const override { return start_ ? 0 : 10; } // Resets the stream to the beginning. - Status Reset() override { + absl::Status Reset() override { start_ = true; - return OkStatus(); + return absl::OkStatus(); } private: @@ -311,7 +311,7 @@ TEST(BufferedInputStream, OutOfRangeCache) { TF_ASSERT_OK((in.ReadNBytes(7, &read))); EXPECT_EQ(read, "3456789"); EXPECT_EQ(10, in.Tell()); - Status s = in.ReadNBytes(5, &read); + absl::Status s = in.ReadNBytes(5, &read); // Make sure the read is failing with OUT_OF_RANGE error. If it is failing // with other errors, it is not caching the OUT_OF_RANGE properly. 
EXPECT_EQ(error::OUT_OF_RANGE, s.code()) << s; diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/format.cc b/third_party/xla/third_party/tsl/tsl/lib/io/format.cc index bc12656f7fbec7..d0b20da64a385e 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/format.cc +++ b/third_party/xla/third_party/tsl/tsl/lib/io/format.cc @@ -36,9 +36,9 @@ void BlockHandle::EncodeTo(string* dst) const { core::PutVarint64(dst, size_); } -Status BlockHandle::DecodeFrom(StringPiece* input) { +absl::Status BlockHandle::DecodeFrom(StringPiece* input) { if (core::GetVarint64(input, &offset_) && core::GetVarint64(input, &size_)) { - return OkStatus(); + return absl::OkStatus(); } else { return errors::DataLoss("bad block handle"); } @@ -56,7 +56,7 @@ void Footer::EncodeTo(string* dst) const { assert(dst->size() == original_size + kEncodedLength); } -Status Footer::DecodeFrom(StringPiece* input) { +absl::Status Footer::DecodeFrom(StringPiece* input) { const char* magic_ptr = input->data() + kEncodedLength - 8; const uint32 magic_lo = core::DecodeFixed32(magic_ptr); const uint32 magic_hi = core::DecodeFixed32(magic_ptr + 4); @@ -66,7 +66,7 @@ Status Footer::DecodeFrom(StringPiece* input) { return errors::DataLoss("not an sstable (bad magic number)"); } - Status result = metaindex_handle_.DecodeFrom(input); + absl::Status result = metaindex_handle_.DecodeFrom(input); if (result.ok()) { result = index_handle_.DecodeFrom(input); } @@ -78,8 +78,8 @@ Status Footer::DecodeFrom(StringPiece* input) { return result; } -Status ReadBlock(RandomAccessFile* file, const BlockHandle& handle, - BlockContents* result) { +absl::Status ReadBlock(RandomAccessFile* file, const BlockHandle& handle, + BlockContents* result) { result->data = StringPiece(); result->cacheable = false; result->heap_allocated = false; @@ -94,7 +94,8 @@ Status ReadBlock(RandomAccessFile* file, const BlockHandle& handle, char* buf = new char[n + kBlockTrailerSize]; StringPiece contents; - Status s = file->Read(handle.offset(), n + 
kBlockTrailerSize, &contents, buf); + absl::Status s = + file->Read(handle.offset(), n + kBlockTrailerSize, &contents, buf); if (!s.ok()) { delete[] buf; return s; @@ -159,7 +160,7 @@ Status ReadBlock(RandomAccessFile* file, const BlockHandle& handle, return errors::DataLoss("bad block type"); } - return OkStatus(); + return absl::OkStatus(); } } // namespace table diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/format.h b/third_party/xla/third_party/tsl/tsl/lib/io/format.h index cd8e863435f440..ae5bb26b8b8c86 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/format.h +++ b/third_party/xla/third_party/tsl/tsl/lib/io/format.h @@ -46,7 +46,7 @@ class BlockHandle { void set_size(uint64 size) { size_ = size; } void EncodeTo(string* dst) const; - Status DecodeFrom(StringPiece* input); + absl::Status DecodeFrom(StringPiece* input); // Maximum encoding length of a BlockHandle enum { kMaxEncodedLength = 10 + 10 }; @@ -71,7 +71,7 @@ class Footer { void set_index_handle(const BlockHandle& h) { index_handle_ = h; } void EncodeTo(string* dst) const; - Status DecodeFrom(StringPiece* input); + absl::Status DecodeFrom(StringPiece* input); // Encoded length of a Footer. Note that the serialization of a // Footer will always occupy exactly this many bytes. It consists @@ -99,8 +99,8 @@ struct BlockContents { // Read the block identified by "handle" from "file". On failure // return non-OK. On success fill *result and return OK. -extern Status ReadBlock(RandomAccessFile* file, const BlockHandle& handle, - BlockContents* result); +extern absl::Status ReadBlock(RandomAccessFile* file, const BlockHandle& handle, + BlockContents* result); // Implementation details follow. 
Clients should ignore, diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/inputbuffer.cc b/third_party/xla/third_party/tsl/tsl/lib/io/inputbuffer.cc index 1b9e2dc6fe2b21..3c183ee1ae1b3c 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/inputbuffer.cc +++ b/third_party/xla/third_party/tsl/tsl/lib/io/inputbuffer.cc @@ -33,9 +33,9 @@ InputBuffer::InputBuffer(RandomAccessFile* file, size_t buffer_bytes) InputBuffer::~InputBuffer() { delete[] buf_; } -Status InputBuffer::FillBuffer() { +absl::Status InputBuffer::FillBuffer() { StringPiece data; - Status s = file_->Read(file_pos_, size_, &data, buf_); + absl::Status s = file_->Read(file_pos_, size_, &data, buf_); if (data.data() != buf_) { memmove(buf_, data.data(), data.size()); } @@ -46,9 +46,9 @@ Status InputBuffer::FillBuffer() { } template -Status InputBuffer::ReadLine(T* result) { +absl::Status InputBuffer::ReadLine(T* result) { result->clear(); - Status s; + absl::Status s; do { size_t buf_remain = limit_ - pos_; char* newline = static_cast(memchr(pos_, '\n', buf_remain)); @@ -59,7 +59,7 @@ Status InputBuffer::ReadLine(T* result) { if (!result->empty() && result->back() == '\r') { result->resize(result->size() - 1); } - return OkStatus(); + return absl::OkStatus(); } if (buf_remain > 0) result->append(pos_, buf_remain); // Get more data into buffer @@ -70,7 +70,7 @@ Status InputBuffer::ReadLine(T* result) { result->resize(result->size() - 1); } if (errors::IsOutOfRange(s) && !result->empty()) { - return OkStatus(); + return absl::OkStatus(); } return s; } @@ -78,7 +78,8 @@ Status InputBuffer::ReadLine(T* result) { template Status InputBuffer::ReadLine(std::string* result); template Status InputBuffer::ReadLine(tstring* result); -Status InputBuffer::ReadNBytes(int64_t bytes_to_read, std::string* result) { +absl::Status InputBuffer::ReadNBytes(int64_t bytes_to_read, + std::string* result) { result->clear(); if (bytes_to_read < 0) { return errors::InvalidArgument("Can't read a negative number of bytes: ", 
@@ -86,18 +87,18 @@ Status InputBuffer::ReadNBytes(int64_t bytes_to_read, std::string* result) { } result->resize(bytes_to_read); size_t bytes_read = 0; - Status status = ReadNBytes(bytes_to_read, &(*result)[0], &bytes_read); + absl::Status status = ReadNBytes(bytes_to_read, &(*result)[0], &bytes_read); if (bytes_read < bytes_to_read) result->resize(bytes_read); return status; } -Status InputBuffer::ReadNBytes(int64_t bytes_to_read, char* result, - size_t* bytes_read) { +absl::Status InputBuffer::ReadNBytes(int64_t bytes_to_read, char* result, + size_t* bytes_read) { if (bytes_to_read < 0) { return errors::InvalidArgument("Can't read a negative number of bytes: ", bytes_to_read); } - Status status; + absl::Status status; *bytes_read = 0; while (*bytes_read < static_cast(bytes_to_read)) { if (pos_ == limit_) { @@ -117,21 +118,21 @@ Status InputBuffer::ReadNBytes(int64_t bytes_to_read, char* result, } if (errors::IsOutOfRange(status) && (*bytes_read == static_cast(bytes_to_read))) { - return OkStatus(); + return absl::OkStatus(); } return status; } -Status InputBuffer::ReadVarint32Fallback(uint32* result) { - Status s = ReadVarintFallback(result, core::kMaxVarint32Bytes); +absl::Status InputBuffer::ReadVarint32Fallback(uint32* result) { + absl::Status s = ReadVarintFallback(result, core::kMaxVarint32Bytes); if (errors::IsDataLoss(s)) { return errors::DataLoss("Stored data is too large to be a varint32."); } return s; } -Status InputBuffer::ReadVarint64Fallback(uint64* result) { - Status s = ReadVarintFallback(result, core::kMaxVarint64Bytes); +absl::Status InputBuffer::ReadVarint64Fallback(uint64* result) { + absl::Status s = ReadVarintFallback(result, core::kMaxVarint64Bytes); if (errors::IsDataLoss(s)) { return errors::DataLoss("Stored data is too large to be a varint64."); } @@ -139,7 +140,7 @@ Status InputBuffer::ReadVarint64Fallback(uint64* result) { } template -Status InputBuffer::ReadVarintFallback(T* result, int max_bytes) { +absl::Status 
InputBuffer::ReadVarintFallback(T* result, int max_bytes) { uint8 scratch = 0; auto* p = reinterpret_cast(&scratch); size_t unused_bytes_read = 0; @@ -149,18 +150,18 @@ Status InputBuffer::ReadVarintFallback(T* result, int max_bytes) { int shift = 7 * index; TF_RETURN_IF_ERROR(ReadNBytes(1, p, &unused_bytes_read)); *result |= (static_cast(scratch) & 127) << shift; - if (!(scratch & 128)) return OkStatus(); + if (!(scratch & 128)) return absl::OkStatus(); } return errors::DataLoss("Stored data longer than ", max_bytes, " bytes."); } -Status InputBuffer::SkipNBytes(int64_t bytes_to_skip) { +absl::Status InputBuffer::SkipNBytes(int64_t bytes_to_skip) { if (bytes_to_skip < 0) { return errors::InvalidArgument("Can only skip forward, not ", bytes_to_skip); } int64_t bytes_skipped = 0; - Status s; + absl::Status s; while (bytes_skipped < bytes_to_skip) { if (pos_ == limit_) { // Get more data into buffer @@ -175,12 +176,12 @@ Status InputBuffer::SkipNBytes(int64_t bytes_to_skip) { pos_ += bytes_to_advance; } if (errors::IsOutOfRange(s) && bytes_skipped == bytes_to_skip) { - return OkStatus(); + return absl::OkStatus(); } return s; } -Status InputBuffer::Seek(int64_t position) { +absl::Status InputBuffer::Seek(int64_t position) { if (position < 0) { return errors::InvalidArgument("Seeking to a negative position: ", position); @@ -196,10 +197,10 @@ Status InputBuffer::Seek(int64_t position) { pos_ = limit_ = buf_; file_pos_ = position; } - return OkStatus(); + return absl::OkStatus(); } -Status InputBuffer::Hint(int64_t bytes_to_read) { +absl::Status InputBuffer::Hint(int64_t bytes_to_read) { if (bytes_to_read < 0) { return errors::InvalidArgument("Can't read a negative number of bytes: ", bytes_to_read); @@ -207,14 +208,14 @@ Status InputBuffer::Hint(int64_t bytes_to_read) { // The internal buffer is too small. Do nothing. 
if (bytes_to_read > size_) { - return OkStatus(); + return absl::OkStatus(); } const int64_t bytes_remain_in_buf = static_cast(limit_ - pos_); // There are enough data in the buffer. Do nothing. if (bytes_to_read <= bytes_remain_in_buf) { - return OkStatus(); + return absl::OkStatus(); } // Additional read from file is necessary. Make some room. @@ -225,7 +226,7 @@ Status InputBuffer::Hint(int64_t bytes_to_read) { // Read the remaining bytes from file. StringPiece data; - Status s = file_->Read(file_pos_, bytes_to_read, &data, limit_); + absl::Status s = file_->Read(file_pos_, bytes_to_read, &data, limit_); if (data.data() != limit_) { memmove(limit_, data.data(), data.size()); } @@ -233,7 +234,7 @@ Status InputBuffer::Hint(int64_t bytes_to_read) { file_pos_ += data.size(); if (errors::IsOutOfRange(s) && data.size() == bytes_to_read) { - return OkStatus(); + return absl::OkStatus(); } else { return s; } diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/inputbuffer.h b/third_party/xla/third_party/tsl/tsl/lib/io/inputbuffer.h index e357efb5f75b53..57a4a983c11e75 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/inputbuffer.h +++ b/third_party/xla/third_party/tsl/tsl/lib/io/inputbuffer.h @@ -45,38 +45,39 @@ class InputBuffer { // file, we return an OUT_OF_RANGE error. Otherwise, we return // some other non-OK status. template - Status ReadLine(T* result); + absl::Status ReadLine(T* result); // Reads bytes_to_read bytes into *result, overwriting *result. // // If successful, returns OK. If we there are not enough bytes to // read before the end of the file, we return an OUT_OF_RANGE error. // Otherwise, we return some other non-OK status. - Status ReadNBytes(int64_t bytes_to_read, std::string* result); + absl::Status ReadNBytes(int64_t bytes_to_read, std::string* result); // An overload that writes to char*. Caller must ensure result[0, // bytes_to_read) is valid to be overwritten. Returns OK iff "*bytes_read == // bytes_to_read". 
- Status ReadNBytes(int64_t bytes_to_read, char* result, size_t* bytes_read); + absl::Status ReadNBytes(int64_t bytes_to_read, char* result, + size_t* bytes_read); // Reads a single varint32. - Status ReadVarint32(uint32* result); + absl::Status ReadVarint32(uint32* result); // Reads a single varint64. - Status ReadVarint64(uint64* result); + absl::Status ReadVarint64(uint64* result); // Like ReadNBytes() without returning the bytes read. - Status SkipNBytes(int64_t bytes_to_skip); + absl::Status SkipNBytes(int64_t bytes_to_skip); // Seek to this offset within the file. // // If we seek to somewhere within our pre-buffered data, we will re-use what // data we can. Otherwise, Seek() throws out the current buffer and the next // read will trigger a File::Read(). - Status Seek(int64_t position); + absl::Status Seek(int64_t position); // Provides a hint about future reads, which may improve their performance. - Status Hint(int64_t bytes_to_read); + absl::Status Hint(int64_t bytes_to_read); // Returns the position in the file. int64_t Tell() const { return file_pos_ - (limit_ - pos_); } @@ -85,19 +86,19 @@ class InputBuffer { RandomAccessFile* file() const { return file_; } private: - Status FillBuffer(); + absl::Status FillBuffer(); // Internal slow-path routine used by ReadVarint32(). - Status ReadVarint32Fallback(uint32* result); + absl::Status ReadVarint32Fallback(uint32* result); // Internal slow-path routine used by ReadVarint64(). - Status ReadVarint64Fallback(uint64* result); + absl::Status ReadVarint64Fallback(uint64* result); // Helper method for reading a varint which can span at max `max_bytes`. // If the varint is longer, a DataLoss error status is returned. // If end of file is reached while reading, OutOfRange error is returned. 
template - Status ReadVarintFallback(T* result, int max_bytes); + absl::Status ReadVarintFallback(T* result, int max_bytes); RandomAccessFile* file_; // Not owned int64_t file_pos_; // Next position to read from in "file_" @@ -118,28 +119,28 @@ extern template Status InputBuffer::ReadLine(std::string* result); extern template Status InputBuffer::ReadLine(tstring* result); // Inlined for performance. -inline Status InputBuffer::ReadVarint32(uint32* result) { +inline absl::Status InputBuffer::ReadVarint32(uint32* result) { if (pos_ + core::kMaxVarint32Bytes <= limit_) { // Fast path: directly parse from buffered data. // Reads strictly from the range [pos_, limit_). const char* offset = core::GetVarint32Ptr(pos_, limit_, result); if (offset == nullptr) return errors::OutOfRange("Parsed past limit."); pos_ = const_cast(offset); - return OkStatus(); + return absl::OkStatus(); } else { return ReadVarint32Fallback(result); } } // Inlined for performance. -inline Status InputBuffer::ReadVarint64(uint64* result) { +inline absl::Status InputBuffer::ReadVarint64(uint64* result) { if (pos_ + core::kMaxVarint64Bytes <= limit_) { // Fast path: directly parse from buffered data. // Reads strictly from the range [pos_, limit_). const char* offset = core::GetVarint64Ptr(pos_, limit_, result); if (offset == nullptr) return errors::OutOfRange("Parsed past limit."); pos_ = const_cast(offset); - return OkStatus(); + return absl::OkStatus(); } else { return ReadVarint64Fallback(result); } diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/inputstream_interface.cc b/third_party/xla/third_party/tsl/tsl/lib/io/inputstream_interface.cc index 1a2f11d4d2b2b2..6425ff0656b658 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/inputstream_interface.cc +++ b/third_party/xla/third_party/tsl/tsl/lib/io/inputstream_interface.cc @@ -24,7 +24,7 @@ namespace io { // 8MB at a time. 
static constexpr int64_t kMaxSkipSize = 8 * 1024 * 1024; -Status InputStreamInterface::SkipNBytes(int64_t bytes_to_skip) { +absl::Status InputStreamInterface::SkipNBytes(int64_t bytes_to_skip) { if (bytes_to_skip < 0) { return errors::InvalidArgument("Can't skip a negative number of bytes"); } @@ -35,7 +35,7 @@ Status InputStreamInterface::SkipNBytes(int64_t bytes_to_skip) { TF_RETURN_IF_ERROR(ReadNBytes(bytes_to_read, &unused)); bytes_to_skip -= bytes_to_read; } - return OkStatus(); + return absl::OkStatus(); } } // namespace io diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/inputstream_interface.h b/third_party/xla/third_party/tsl/tsl/lib/io/inputstream_interface.h index afe87a4b9cc37e..8eb7f2ad868965 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/inputstream_interface.h +++ b/third_party/xla/third_party/tsl/tsl/lib/io/inputstream_interface.h @@ -35,13 +35,13 @@ class InputStreamInterface { // Reads the next bytes_to_read from the file. Typical return codes: // * OK - in case of success. // * OUT_OF_RANGE - not enough bytes remaining before end of file. - virtual Status ReadNBytes(int64_t bytes_to_read, tstring* result) = 0; + virtual absl::Status ReadNBytes(int64_t bytes_to_read, tstring* result) = 0; #if defined(TF_CORD_SUPPORT) // Reads the next bytes_to_read from the file. Typical return codes: // * OK - in case of success. // * OUT_OF_RANGE - not enough bytes remaining before end of file. - virtual Status ReadNBytes(int64_t bytes_to_read, absl::Cord* cord) { + virtual absl::Status ReadNBytes(int64_t bytes_to_read, absl::Cord* cord) { return errors::Unimplemented( "ReadNBytes(int64, absl::Cord*) is not implemented."); } @@ -51,7 +51,7 @@ class InputStreamInterface { // Typical return codes: // * OK - in case of success. // * OUT_OF_RANGE - not enough bytes remaining before end of file. 
- virtual Status SkipNBytes(int64_t bytes_to_skip); + virtual absl::Status SkipNBytes(int64_t bytes_to_skip); // Return the offset of the current byte relative to the beginning of the // file. @@ -61,7 +61,7 @@ class InputStreamInterface { virtual int64_t Tell() const = 0; // Resets the stream to the beginning. - virtual Status Reset() = 0; + virtual absl::Status Reset() = 0; }; } // namespace io diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/inputstream_interface_test.cc b/third_party/xla/third_party/tsl/tsl/lib/io/inputstream_interface_test.cc index 2f7cda954fd13d..23d4fb0ddf50bc 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/inputstream_interface_test.cc +++ b/third_party/xla/third_party/tsl/tsl/lib/io/inputstream_interface_test.cc @@ -27,21 +27,21 @@ class TestStringStream : public InputStreamInterface { public: explicit TestStringStream(const string& content) : content_(content) {} - Status ReadNBytes(int64_t bytes_to_read, tstring* result) override { + absl::Status ReadNBytes(int64_t bytes_to_read, tstring* result) override { result->clear(); if (pos_ + bytes_to_read > content_.size()) { return errors::OutOfRange("limit reached"); } *result = content_.substr(pos_, bytes_to_read); pos_ += bytes_to_read; - return OkStatus(); + return absl::OkStatus(); } int64_t Tell() const override { return pos_; } - Status Reset() override { + absl::Status Reset() override { pos_ = 0; - return OkStatus(); + return absl::OkStatus(); } private: diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/iterator.cc b/third_party/xla/third_party/tsl/tsl/lib/io/iterator.cc index 4dff9eb4f61761..a02a4254985087 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/iterator.cc +++ b/third_party/xla/third_party/tsl/tsl/lib/io/iterator.cc @@ -53,7 +53,7 @@ void Iterator::RegisterCleanup(CleanupFunction func, void* arg1, void* arg2) { namespace { class EmptyIterator : public Iterator { public: - explicit EmptyIterator(const Status& s) : status_(s) {} + explicit 
EmptyIterator(const absl::Status& s) : status_(s) {} bool Valid() const override { return false; } void Seek(const StringPiece& target) override {} void SeekToFirst() override {} @@ -66,16 +66,16 @@ class EmptyIterator : public Iterator { assert(false); return StringPiece(); } - Status status() const override { return status_; } + absl::Status status() const override { return status_; } private: - Status status_; + absl::Status status_; }; } // namespace -Iterator* NewEmptyIterator() { return new EmptyIterator(OkStatus()); } +Iterator* NewEmptyIterator() { return new EmptyIterator(absl::OkStatus()); } -Iterator* NewErrorIterator(const Status& status) { +Iterator* NewErrorIterator(const absl::Status& status) { return new EmptyIterator(status); } diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/iterator.h b/third_party/xla/third_party/tsl/tsl/lib/io/iterator.h index bb83f41ea47dd9..f0b16943c44b9c 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/iterator.h +++ b/third_party/xla/third_party/tsl/tsl/lib/io/iterator.h @@ -68,7 +68,7 @@ class Iterator { virtual StringPiece value() const = 0; // If an error has occurred, return it. Else return an ok status. - virtual Status status() const = 0; + virtual absl::Status status() const = 0; // Clients are allowed to register function/arg1/arg2 triples that // will be invoked when this iterator is destroyed. @@ -96,7 +96,7 @@ class Iterator { extern Iterator* NewEmptyIterator(); // Return an empty iterator with the specified status. 
-extern Iterator* NewErrorIterator(const Status& status); +extern Iterator* NewErrorIterator(const absl::Status& status); } // namespace table } // namespace tsl diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/random_inputstream.cc b/third_party/xla/third_party/tsl/tsl/lib/io/random_inputstream.cc index 1b5262057771b7..841e3d1bf26f6c 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/random_inputstream.cc +++ b/third_party/xla/third_party/tsl/tsl/lib/io/random_inputstream.cc @@ -30,8 +30,8 @@ RandomAccessInputStream::~RandomAccessInputStream() { } } -Status RandomAccessInputStream::ReadNBytes(int64_t bytes_to_read, - tstring* result) { +absl::Status RandomAccessInputStream::ReadNBytes(int64_t bytes_to_read, + tstring* result) { if (bytes_to_read < 0) { return errors::InvalidArgument("Cannot read negative number of bytes"); } @@ -39,7 +39,7 @@ Status RandomAccessInputStream::ReadNBytes(int64_t bytes_to_read, result->resize_uninitialized(bytes_to_read); char* result_buffer = &(*result)[0]; StringPiece data; - Status s = file_->Read(pos_, bytes_to_read, &data, result_buffer); + absl::Status s = file_->Read(pos_, bytes_to_read, &data, result_buffer); if (data.data() != result_buffer) { memmove(result_buffer, data.data(), data.size()); } @@ -51,13 +51,13 @@ Status RandomAccessInputStream::ReadNBytes(int64_t bytes_to_read, } #if defined(TF_CORD_SUPPORT) -Status RandomAccessInputStream::ReadNBytes(int64_t bytes_to_read, - absl::Cord* result) { +absl::Status RandomAccessInputStream::ReadNBytes(int64_t bytes_to_read, + absl::Cord* result) { if (bytes_to_read < 0) { return errors::InvalidArgument("Cannot read negative number of bytes"); } int64_t current_size = result->size(); - Status s = file_->Read(pos_, bytes_to_read, result); + absl::Status s = file_->Read(pos_, bytes_to_read, result); if (s.ok() || errors::IsOutOfRange(s)) { pos_ += result->size() - current_size; } @@ -69,7 +69,7 @@ Status RandomAccessInputStream::ReadNBytes(int64_t bytes_to_read, // 8MB 
at a time. static constexpr int64_t kMaxSkipSize = 8 * 1024 * 1024; -Status RandomAccessInputStream::SkipNBytes(int64_t bytes_to_skip) { +absl::Status RandomAccessInputStream::SkipNBytes(int64_t bytes_to_skip) { if (bytes_to_skip < 0) { return errors::InvalidArgument("Can't skip a negative number of bytes"); } @@ -78,17 +78,18 @@ Status RandomAccessInputStream::SkipNBytes(int64_t bytes_to_skip) { // not reached yet and we could return. if (bytes_to_skip > 0) { StringPiece data; - Status s = file_->Read(pos_ + bytes_to_skip - 1, 1, &data, scratch.get()); + absl::Status s = + file_->Read(pos_ + bytes_to_skip - 1, 1, &data, scratch.get()); if ((s.ok() || errors::IsOutOfRange(s)) && data.size() == 1) { pos_ += bytes_to_skip; - return OkStatus(); + return absl::OkStatus(); } } // Read kDefaultSkipSize at a time till bytes_to_skip. while (bytes_to_skip > 0) { int64_t bytes_to_read = std::min(kMaxSkipSize, bytes_to_skip); StringPiece data; - Status s = file_->Read(pos_, bytes_to_read, &data, scratch.get()); + absl::Status s = file_->Read(pos_, bytes_to_read, &data, scratch.get()); if (s.ok() || errors::IsOutOfRange(s)) { pos_ += data.size(); } else { @@ -99,7 +100,7 @@ Status RandomAccessInputStream::SkipNBytes(int64_t bytes_to_skip) { } bytes_to_skip -= bytes_to_read; } - return OkStatus(); + return absl::OkStatus(); } int64_t RandomAccessInputStream::Tell() const { return pos_; } diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/random_inputstream.h b/third_party/xla/third_party/tsl/tsl/lib/io/random_inputstream.h index e1608ce3ec2b9b..4d48db62c2b03f 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/random_inputstream.h +++ b/third_party/xla/third_party/tsl/tsl/lib/io/random_inputstream.h @@ -33,22 +33,22 @@ class RandomAccessInputStream : public InputStreamInterface { ~RandomAccessInputStream() override; - Status ReadNBytes(int64_t bytes_to_read, tstring* result) override; + absl::Status ReadNBytes(int64_t bytes_to_read, tstring* result) override; #if 
defined(TF_CORD_SUPPORT) - Status ReadNBytes(int64_t bytes_to_read, absl::Cord* result) override; + absl::Status ReadNBytes(int64_t bytes_to_read, absl::Cord* result) override; #endif - Status SkipNBytes(int64_t bytes_to_skip) override; + absl::Status SkipNBytes(int64_t bytes_to_skip) override; int64_t Tell() const override; - Status Seek(int64_t position) { + absl::Status Seek(int64_t position) { pos_ = position; - return OkStatus(); + return absl::OkStatus(); } - Status Reset() override { return Seek(0); } + absl::Status Reset() override { return Seek(0); } private: RandomAccessFile* file_; // Not owned. diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/record_reader.cc b/third_party/xla/third_party/tsl/tsl/lib/io/record_reader.cc index e267b5cee84dab..8d17c610b09f71 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/record_reader.cc +++ b/third_party/xla/third_party/tsl/tsl/lib/io/record_reader.cc @@ -101,7 +101,8 @@ inline const char* GetChecksumErrorSuffix(uint64 offset) { // and is used only in error messages. For failures at offset 0, // a reminder about the file format is added, because TFRecord files // contain no explicit format marker. 
-Status RecordReader::ReadChecksummed(uint64 offset, size_t n, tstring* result) { +absl::Status RecordReader::ReadChecksummed(uint64 offset, size_t n, + tstring* result) { if (n >= SIZE_MAX - sizeof(uint32)) { return errors::DataLoss("record size too large", GetChecksumErrorSuffix(offset)); @@ -125,10 +126,10 @@ Status RecordReader::ReadChecksummed(uint64 offset, size_t n, tstring* result) { GetChecksumErrorSuffix(offset)); } result->resize(n); - return OkStatus(); + return absl::OkStatus(); } -Status RecordReader::GetMetadata(Metadata* md) { +absl::Status RecordReader::GetMetadata(Metadata* md) { if (!md) { return errors::InvalidArgument( "Metadata object call to GetMetadata() was null"); @@ -148,7 +149,7 @@ Status RecordReader::GetMetadata(Metadata* md) { tstring record; while (true) { // Read header, containing size of data. - Status s = ReadChecksummed(offset, sizeof(uint64), &record); + absl::Status s = ReadChecksummed(offset, sizeof(uint64), &record); if (!s.ok()) { if (errors::IsOutOfRange(s)) { // We should reach out of range when the record file is complete. @@ -178,10 +179,10 @@ Status RecordReader::GetMetadata(Metadata* md) { } md->stats = cached_metadata_->stats; - return OkStatus(); + return absl::OkStatus(); } -Status RecordReader::PositionInputStream(uint64 offset) { +absl::Status RecordReader::PositionInputStream(uint64 offset) { int64_t curr_pos = input_stream_->Tell(); int64_t desired_pos = static_cast(offset); if (curr_pos > desired_pos || curr_pos < 0 /* EOF */ || @@ -193,14 +194,14 @@ Status RecordReader::PositionInputStream(uint64 offset) { TF_RETURN_IF_ERROR(input_stream_->SkipNBytes(desired_pos - curr_pos)); } DCHECK_EQ(desired_pos, input_stream_->Tell()); - return OkStatus(); + return absl::OkStatus(); } -Status RecordReader::ReadRecord(uint64* offset, tstring* record) { +absl::Status RecordReader::ReadRecord(uint64* offset, tstring* record) { TF_RETURN_IF_ERROR(PositionInputStream(*offset)); // Read header data. 
- Status s = ReadChecksummed(*offset, sizeof(uint64), record); + absl::Status s = ReadChecksummed(*offset, sizeof(uint64), record); if (!s.ok()) { last_read_failed_ = true; return s; @@ -220,14 +221,14 @@ Status RecordReader::ReadRecord(uint64* offset, tstring* record) { *offset += kHeaderSize + length + kFooterSize; DCHECK_EQ(*offset, input_stream_->Tell()); - return OkStatus(); + return absl::OkStatus(); } -Status RecordReader::SkipRecords(uint64* offset, int num_to_skip, - int* num_skipped) { +absl::Status RecordReader::SkipRecords(uint64* offset, int num_to_skip, + int* num_skipped) { TF_RETURN_IF_ERROR(PositionInputStream(*offset)); - Status s; + absl::Status s; tstring record; *num_skipped = 0; for (int i = 0; i < num_to_skip; ++i) { @@ -252,7 +253,7 @@ Status RecordReader::SkipRecords(uint64* offset, int num_to_skip, DCHECK_EQ(*offset, input_stream_->Tell()); (*num_skipped)++; } - return OkStatus(); + return absl::OkStatus(); } SequentialRecordReader::SequentialRecordReader( diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/record_reader.h b/third_party/xla/third_party/tsl/tsl/lib/io/record_reader.h index 282c0daff2a5a8..61540a657324c8 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/record_reader.h +++ b/third_party/xla/third_party/tsl/tsl/lib/io/record_reader.h @@ -94,14 +94,14 @@ class RecordReader { // Read the record at "*offset" into *record and update *offset to // point to the offset of the next record. Returns OK on success, // OUT_OF_RANGE for end of file, or something else for an error. - Status ReadRecord(uint64* offset, tstring* record); + absl::Status ReadRecord(uint64* offset, tstring* record); // Skip num_to_skip record starting at "*offset" and update *offset // to point to the offset of the next num_to_skip + 1 record. // Return OK on success, OUT_OF_RANGE for end of file, or something // else for an error. "*num_skipped" records the number of records that // are actually skipped. It should be equal to num_to_skip on success. 
- Status SkipRecords(uint64* offset, int num_to_skip, int* num_skipped); + absl::Status SkipRecords(uint64* offset, int num_to_skip, int* num_skipped); // Return the metadata of the Record file. // @@ -112,11 +112,11 @@ class RecordReader { // so that GetMetadata() could be a const method. // // 'metadata' must not be nullptr. - Status GetMetadata(Metadata* md); + absl::Status GetMetadata(Metadata* md); private: - Status ReadChecksummed(uint64 offset, size_t n, tstring* result); - Status PositionInputStream(uint64 offset); + absl::Status ReadChecksummed(uint64 offset, size_t n, tstring* result); + absl::Status PositionInputStream(uint64 offset); RecordReaderOptions options_; std::unique_ptr input_stream_; @@ -143,7 +143,7 @@ class SequentialRecordReader { // Read the next record in the file into *record. Returns OK on success, // OUT_OF_RANGE for end of file, or something else for an error. - Status ReadRecord(tstring* record) { + absl::Status ReadRecord(tstring* record) { return underlying_.ReadRecord(&offset_, record); } @@ -151,7 +151,7 @@ class SequentialRecordReader { // OUT_OF_RANGE for end of file, or something else for an error. // "*num_skipped" records the number of records that are actually skipped. // It should be equal to num_to_skip on success. - Status SkipRecords(int num_to_skip, int* num_skipped) { + absl::Status SkipRecords(int num_to_skip, int* num_skipped) { return underlying_.SkipRecords(&offset_, num_to_skip, num_skipped); } @@ -160,13 +160,13 @@ class SequentialRecordReader { // Seek to this offset within the file and set this offset as the current // offset. Trying to seek backward will throw error. 
- Status SeekOffset(uint64 offset) { + absl::Status SeekOffset(uint64 offset) { if (offset < offset_) return errors::InvalidArgument( "Trying to seek offset: ", offset, " which is less than the current offset: ", offset_); offset_ = offset; - return OkStatus(); + return absl::OkStatus(); } private: diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/record_reader_writer_test.cc b/third_party/xla/third_party/tsl/tsl/lib/io/record_reader_writer_test.cc index 2497db348a5729..67df783112f9ee 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/record_reader_writer_test.cc +++ b/third_party/xla/third_party/tsl/tsl/lib/io/record_reader_writer_test.cc @@ -226,7 +226,7 @@ TEST(RecordReaderWriterTest, TestSkipOutOfRange) { uint64 offset = 0; int num_skipped; tstring record; - Status s = reader.SkipRecords(&offset, 3, &num_skipped); + absl::Status s = reader.SkipRecords(&offset, 3, &num_skipped); EXPECT_EQ(2, num_skipped); EXPECT_EQ(error::OUT_OF_RANGE, s.code()); } @@ -254,7 +254,7 @@ TEST(RecordReaderWriterTest, TestMalformedInput) { tstring record; // At offset 0, the error message reminds of the file type. 
uint64 offset = 0; - Status s = reader.ReadRecord(&offset, &record); + absl::Status s = reader.ReadRecord(&offset, &record); EXPECT_EQ(error::DATA_LOSS, s.code()); EXPECT_EQ("corrupted record at 0 (Is this even a TFRecord file?)", s.message()); diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/record_writer.cc b/third_party/xla/third_party/tsl/tsl/lib/io/record_writer.cc index aace9f10e14c6d..9a6a932dd77a26 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/record_writer.cc +++ b/third_party/xla/third_party/tsl/tsl/lib/io/record_writer.cc @@ -69,7 +69,7 @@ RecordWriter::RecordWriter(WritableFile* dest, ZlibOutputBuffer* zlib_output_buffer = new ZlibOutputBuffer( dest, options.zlib_options.input_buffer_size, options.zlib_options.output_buffer_size, options.zlib_options); - Status s = zlib_output_buffer->Init(); + absl::Status s = zlib_output_buffer->Init(); if (!s.ok()) { LOG(FATAL) << "Failed to initialize Zlib inputbuffer. Error: " << s.ToString(); @@ -89,17 +89,17 @@ RecordWriter::RecordWriter(WritableFile* dest, RecordWriter::~RecordWriter() { if (dest_ != nullptr) { - Status s = Close(); + absl::Status s = Close(); if (!s.ok()) { LOG(ERROR) << "Could not finish writing file: " << s; } } } -Status RecordWriter::WriteRecord(StringPiece data) { +absl::Status RecordWriter::WriteRecord(StringPiece data) { if (dest_ == nullptr) { - return Status(absl::StatusCode::kFailedPrecondition, - "Writer not initialized or previously closed"); + return absl::Status(absl::StatusCode::kFailedPrecondition, + "Writer not initialized or previously closed"); } // Format of a single record: // uint64 length @@ -116,10 +116,10 @@ Status RecordWriter::WriteRecord(StringPiece data) { } #if defined(TF_CORD_SUPPORT) -Status RecordWriter::WriteRecord(const absl::Cord& data) { +absl::Status RecordWriter::WriteRecord(const absl::Cord& data) { if (dest_ == nullptr) { - return Status(absl::StatusCode::kFailedPrecondition, - "Writer not initialized or previously closed"); + return 
absl::Status(absl::StatusCode::kFailedPrecondition, + "Writer not initialized or previously closed"); } // Format of a single record: // uint64 length @@ -136,21 +136,21 @@ Status RecordWriter::WriteRecord(const absl::Cord& data) { } #endif -Status RecordWriter::Close() { - if (dest_ == nullptr) return OkStatus(); +absl::Status RecordWriter::Close() { + if (dest_ == nullptr) return absl::OkStatus(); if (IsZlibCompressed(options_) || IsSnappyCompressed(options_)) { - Status s = dest_->Close(); + absl::Status s = dest_->Close(); delete dest_; dest_ = nullptr; return s; } - return OkStatus(); + return absl::OkStatus(); } -Status RecordWriter::Flush() { +absl::Status RecordWriter::Flush() { if (dest_ == nullptr) { - return Status(absl::StatusCode::kFailedPrecondition, - "Writer not initialized or previously closed"); + return absl::Status(absl::StatusCode::kFailedPrecondition, + "Writer not initialized or previously closed"); } return dest_->Flush(); } diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/record_writer.h b/third_party/xla/third_party/tsl/tsl/lib/io/record_writer.h index b585cb9b52f70c..06e9a5c847910c 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/record_writer.h +++ b/third_party/xla/third_party/tsl/tsl/lib/io/record_writer.h @@ -77,22 +77,22 @@ class RecordWriter { // implicit Close() call in the destructor. ~RecordWriter(); - Status WriteRecord(StringPiece data); + absl::Status WriteRecord(StringPiece data); #if defined(TF_CORD_SUPPORT) - Status WriteRecord(const absl::Cord& data); + absl::Status WriteRecord(const absl::Cord& data); #endif // Flushes any buffered data held by underlying containers of the // RecordWriter to the WritableFile. Does *not* flush the // WritableFile. - Status Flush(); + absl::Status Flush(); // Writes all output to the file. Does *not* close the WritableFile. // // After calling Close(), any further calls to `WriteRecord()` or `Flush()` // are invalid. 
- Status Close(); + absl::Status Close(); // Utility method to populate TFRecord headers. Populates record-header in // "header[0,kHeaderSize-1]". The record-header is based on data[0, n-1]. diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/recordio_test.cc b/third_party/xla/third_party/tsl/tsl/lib/io/recordio_test.cc index f9702f2ed13997..42adf76f7ef0d3 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/recordio_test.cc +++ b/third_party/xla/third_party/tsl/tsl/lib/io/recordio_test.cc @@ -55,22 +55,22 @@ class StringDest : public WritableFile { public: explicit StringDest(string* contents) : contents_(contents) {} - Status Close() override { return OkStatus(); } - Status Flush() override { return OkStatus(); } - Status Sync() override { return OkStatus(); } - Status Append(StringPiece slice) override { + absl::Status Close() override { return absl::OkStatus(); } + absl::Status Flush() override { return absl::OkStatus(); } + absl::Status Sync() override { return absl::OkStatus(); } + absl::Status Append(StringPiece slice) override { contents_->append(slice.data(), slice.size()); - return OkStatus(); + return absl::OkStatus(); } #if defined(TF_CORD_SUPPORT) - Status Append(const absl::Cord& data) override { + absl::Status Append(const absl::Cord& data) override { contents_->append(std::string(data)); - return OkStatus(); + return absl::OkStatus(); } #endif - Status Tell(int64_t* pos) override { + absl::Status Tell(int64_t* pos) override { *pos = contents_->size(); - return OkStatus(); + return absl::OkStatus(); } private: @@ -82,8 +82,8 @@ class StringSource : public RandomAccessFile { explicit StringSource(string* contents) : contents_(contents), force_error_(false) {} - Status Read(uint64 offset, size_t n, StringPiece* result, - char* scratch) const override { + absl::Status Read(uint64 offset, size_t n, StringPiece* result, + char* scratch) const override { if (force_error_) { force_error_ = false; return errors::DataLoss("read error"); @@ -97,7 +97,7 
@@ class StringSource : public RandomAccessFile { n = contents_->size() - offset; } *result = StringPiece(contents_->data() + offset, n); - return OkStatus(); + return absl::OkStatus(); } void force_error() { force_error_ = true; } @@ -150,7 +150,7 @@ class RecordioTest : public ::testing::Test { reading_ = true; } tstring record; - Status s = reader_->ReadRecord(&readpos_, &record); + absl::Status s = reader_->ReadRecord(&readpos_, &record); if (s.ok()) { return record; } else if (errors::IsOutOfRange(s)) { @@ -184,7 +184,7 @@ class RecordioTest : public ::testing::Test { reading_ = true; uint64 offset = WrittenBytes() + offset_past_end; tstring record; - Status s = reader_->ReadRecord(&offset, &record); + absl::Status s = reader_->ReadRecord(&offset, &record); ASSERT_TRUE(errors::IsOutOfRange(s)) << s; } }; @@ -317,7 +317,7 @@ void TestReadError(const RecordWriterOptions& writer_options, uint64 offset = 0; tstring read; file.force_error(); - Status status = reader.ReadRecord(&offset, &read); + absl::Status status = reader.ReadRecord(&offset, &read); ASSERT_TRUE(errors::IsDataLoss(status)); ASSERT_EQ(0, offset); From 89a47214bc7cd1582547ad81c8ed1af7e6e2cee1 Mon Sep 17 00:00:00 2001 From: Adrian Kuegel Date: Wed, 19 Jun 2024 04:20:29 -0700 Subject: [PATCH 011/256] Refactor llvm_compiler_test. We can run the CpuCompiler and GPUCompiler related tests in separate test targets. 
PiperOrigin-RevId: 644703636 --- third_party/xla/xla/service/BUILD | 1 + third_party/xla/xla/service/llvm_compiler.cc | 7 +- third_party/xla/xla/tests/BUILD | 38 ++- .../xla/xla/tests/llvm_compiler_test.cc | 232 +++++------------- 4 files changed, 80 insertions(+), 198 deletions(-) diff --git a/third_party/xla/xla/service/BUILD b/third_party/xla/xla/service/BUILD index 3d1f6e42012fa7..e1a420ff607158 100644 --- a/third_party/xla/xla/service/BUILD +++ b/third_party/xla/xla/service/BUILD @@ -1654,6 +1654,7 @@ cc_library( local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), deps = [ ":compiler", + "@com_google_absl//absl/status:statusor", "@llvm-project//llvm:Core", "@local_tsl//tsl/platform:denormal", "@local_tsl//tsl/profiler/lib:scoped_annotation", diff --git a/third_party/xla/xla/service/llvm_compiler.cc b/third_party/xla/xla/service/llvm_compiler.cc index 4bbcca0b484cd2..02afa91c5d9ea5 100644 --- a/third_party/xla/xla/service/llvm_compiler.cc +++ b/third_party/xla/xla/service/llvm_compiler.cc @@ -15,6 +15,11 @@ limitations under the License. 
#include "xla/service/llvm_compiler.h" +#include +#include +#include + +#include "absl/status/statusor.h" #include "tsl/platform/denormal.h" #include "tsl/profiler/lib/scoped_annotation.h" @@ -55,6 +60,6 @@ absl::StatusOr>> LLVMCompiler::Compile( result.push_back(std::move(executable)); } - return {std::move(result)}; + return std::move(result); } } // namespace xla diff --git a/third_party/xla/xla/tests/BUILD b/third_party/xla/xla/tests/BUILD index 69c591200c9025..b841b47a70624f 100644 --- a/third_party/xla/xla/tests/BUILD +++ b/third_party/xla/xla/tests/BUILD @@ -12,10 +12,6 @@ load( "if_cuda_is_configured", ) load("//xla:xla.bzl", "tests_build_defs_bzl_deps", "xla_cc_binary", "xla_cc_test") -load( - "//xla/stream_executor:build_defs.bzl", - "if_gpu_is_configured", -) load("//xla/tests:build_defs.bzl", "generate_backend_suites", "generate_backend_test_macros", "xla_test", "xla_test_library") load("//xla/tsl:tsl.bzl", "internal_visibility", "tsl_copts") load("//xla/tsl:tsl.default.bzl", "filegroup") @@ -2462,32 +2458,30 @@ xla_test( xla_test( name = "llvm_compiler_test", srcs = ["llvm_compiler_test.cc"], - backends = ["gpu"], - local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ - "TENSORFLOW_USE_ROCM", - ]), - # TODO(b/317293391) Remove once Bazel test_suite handles tags correctly - tags = ["gpu"], - deps = if_gpu_is_configured([ - ":verified_hlo_module", + backend_tags = { + # TODO(b/317293391) Remove once Bazel test_suite handles tags correctly. 
+ "gpu": ["gpu"], + }, + backends = [ + "cpu", + "gpu", + ], + deps = [ + ":hlo_test_base", "//xla:literal_util", "//xla:test_helpers", "//xla/hlo/ir:hlo", - "@com_google_absl//absl/status", + "//xla/hlo/ir:hlo_module_group", "//xla/service:backend", - "//xla/service:cpu_plugin", "//xla/service:llvm_compiler", - "//xla/service:platform_util", - "//xla/service/cpu:cpu_compiler", - "//xla/service/gpu:gpu_compiler", "//xla/stream_executor", + "@com_google_absl//absl/status", + "@com_google_googletest//:gtest_main", + "@llvm-project//llvm:Core", + "@local_tsl//tsl/platform:casts", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:test_main", - ]) + if_cuda_is_configured([ - "//xla/stream_executor/cuda:cuda_platform_id", - ]) + if_rocm_is_configured([ - "//xla/stream_executor/rocm:rocm_platform_id", - ]), + ], ) xla_test( diff --git a/third_party/xla/xla/tests/llvm_compiler_test.cc b/third_party/xla/xla/tests/llvm_compiler_test.cc index dd15cbbeba4cc2..238482b650c3d6 100644 --- a/third_party/xla/xla/tests/llvm_compiler_test.cc +++ b/third_party/xla/xla/tests/llvm_compiler_test.cc @@ -21,199 +21,81 @@ limitations under the License. 
#include #include +#include #include "absl/status/status.h" +#include "llvm/IR/Module.h" #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module_group.h" #include "xla/literal_util.h" #include "xla/service/backend.h" -#include "xla/service/cpu/cpu_compiler.h" -#include "xla/service/gpu/gpu_compiler.h" -#include "xla/service/platform_util.h" -#if GOOGLE_CUDA -#include "xla/stream_executor/cuda/cuda_platform_id.h" -#elif TENSORFLOW_USE_ROCM -#include "xla/stream_executor/rocm/rocm_platform_id.h" -#endif #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/stream_executor.h" #include "xla/test_helpers.h" -#include "xla/tests/verified_hlo_module.h" +#include "xla/tests/hlo_test_base.h" +#include "tsl/platform/casts.h" #include "tsl/platform/threadpool.h" namespace xla { -namespace gpu { - -// Creating dummy data structure needed to initialize a GpuDummyCompiler -constexpr char kDummyTriple[] = "dummy-triple"; -constexpr char kDummyLayout[] = "e"; -const se::Platform::Id kGpuPlatformId = -#if GOOGLE_CUDA - se::cuda::kCudaPlatformId; -#elif TENSORFLOW_USE_ROCM - se::rocm::kROCmPlatformId; -#endif -// This class is a dummy implementation of GpuCompiler and is targeted for unit -// test only -class GpuDummyCompiler : public GpuCompiler { - public: - GpuDummyCompiler() - : GpuCompiler(kGpuPlatformId, kDummyTriple, kDummyLayout) {} - - int32_t GetToolkitVersion() const override { return 0; } - - absl::Status OptimizeHloConvolutionCanonicalization( - HloModule* hlo_module, se::GpuComputeCapability gpu_version, - se::dnn::VersionInfo dnn_version, - se::DeviceMemoryAllocator* device_allocator) { - return absl::OkStatus(); - } - - absl::Status OptimizeHloPostLayoutAssignment( - HloModule* hlo_module, se::StreamExecutor* stream_executor, - const CompileOptions& options, const TargetConfig& gpu_target_config, - tsl::thread::ThreadPool* thread_pool) override { - return absl::OkStatus(); - } +namespace { - absl::StatusOr 
CompileTargetBinary( - const HloModuleConfig& module_config, llvm::Module* llvm_module, - se::GpuComputeCapability gpu_version, bool relocatable, - const HloModule* debug_module, const CompileOptions& options) override { - return BackendCompileResult{}; - } -}; -} // namespace gpu +using LLVMCompilerTest = HloTestBase; -namespace { +const char* const kHloText = R"( +HloModule Constant -class LLVMCompilerTest : public ::testing::Test { - public: - void SetUp() override { - Platform* platform = FindPlatform(); - ASSERT_NE(platform, nullptr); - - BackendOptions backend_options; - backend_options.set_platform(platform); - absl::StatusOr> backend_or_status = - Backend::CreateBackend(backend_options); - ASSERT_IS_OK(backend_or_status.status()); - backend_ = std::move(backend_or_status).value(); - } - - ~LLVMCompilerTest() override {} - - protected: - using Platform = se::Platform; - - explicit LLVMCompilerTest(std::string platform_name) - : platform_name_(std::move(platform_name)) {} - - void TestCompilerHooks(LLVMCompiler* compiler) { - int pre_opt_hook_call_count = 0; - int post_opt_hook_call_count = 0; - - auto pre_opt_hook = [&pre_opt_hook_call_count](const llvm::Module&) { - ++pre_opt_hook_call_count; - return absl::OkStatus(); - }; - auto post_opt_hook = [&post_opt_hook_call_count](const llvm::Module&) { - ++post_opt_hook_call_count; - return absl::OkStatus(); - }; - - // Create HLO module, and run the compiler. - auto builder = HloComputation::Builder(TestName()); - builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); - - auto hlo_module = CreateNewVerifiedModule(); - hlo_module->AddEntryComputation(builder.Build()); - - compiler->SetPreOptimizationHook(pre_opt_hook); - compiler->SetPostOptimizationHook(post_opt_hook); - - ASSERT_TRUE(compiler - ->RunBackend(std::move(hlo_module), - backend_->default_stream_executor(), - /*device_allocator=*/nullptr) - .ok()); - - // Test that hooks were called. 
- EXPECT_EQ(1, pre_opt_hook_call_count); - EXPECT_EQ(1, post_opt_hook_call_count); - } - - void TestMultiModuleCompilation(LLVMCompiler* compiler) { - HloComputation::Builder builder(TestName()); - builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); - - std::unique_ptr hlo_module = CreateNewVerifiedModule(); - hlo_module->AddEntryComputation(builder.Build()); - - auto module_group = std::make_unique("test_module_group"); - module_group->push_back(hlo_module->Clone()); - module_group->push_back(std::move(hlo_module)); - - std::vector> executors; - executors.push_back({backend_->default_stream_executor()}); - executors.push_back({backend_->default_stream_executor()}); - - EXPECT_IS_OK(compiler->Compile(std::move(module_group), - std::move(executors), - /*device_allocator=*/nullptr)); - } - - private: - Platform* FindPlatform() { - auto status_or_platform = PlatformUtil::GetPlatform(platform_name_); - return status_or_platform.ok() ? status_or_platform.value() : nullptr; - } - - std::string platform_name_; - std::unique_ptr backend_; - - static std::string TestName() { - return ::testing::UnitTest::GetInstance()->current_test_info()->name(); - } - - std::unique_ptr CreateNewVerifiedModule() { - HloModuleConfig config; - config.set_debug_options(GetDebugOptionsFromFlags()); - return std::make_unique( - TestName(), config, /*verifier_layout_sensitive=*/false, - /*allow_mixed_precision_in_hlo_verifier=*/true, - backend_->compiler()->ShapeSizeBytesFunction()); - } -}; - -class CpuCompilerTest : public LLVMCompilerTest { - public: - CpuCompilerTest() : LLVMCompilerTest("Host") {} -}; - -class GpuCompilerTest : public LLVMCompilerTest { - public: - GpuCompilerTest() : LLVMCompilerTest("GPU") {} -}; - -TEST_F(CpuCompilerTest, HooksTest) { - cpu::CpuCompiler compiler; - TestCompilerHooks(&compiler); +ENTRY main { + ROOT constant = f32[] constant(42.0) } +)"; -TEST_F(GpuCompilerTest, HooksTest) { - gpu::GpuDummyCompiler compiler; - 
TestCompilerHooks(&compiler); -} +TEST_F(LLVMCompilerTest, HooksTest) { + int pre_opt_hook_call_count = 0; + int post_opt_hook_call_count = 0; -TEST_F(CpuCompilerTest, CpuMultiModuleCompilation) { - cpu::CpuCompiler compiler; - TestMultiModuleCompilation(&compiler); + auto pre_opt_hook = [&pre_opt_hook_call_count](const llvm::Module&) { + ++pre_opt_hook_call_count; + return absl::OkStatus(); + }; + auto post_opt_hook = [&post_opt_hook_call_count](const llvm::Module&) { + ++post_opt_hook_call_count; + return absl::OkStatus(); + }; + + // Create HLO module, and run the compiler. + auto hlo_module = ParseAndReturnVerifiedModule(kHloText).value(); + LLVMCompiler* compiler = + tensorflow::down_cast(backend().compiler()); + compiler->SetPreOptimizationHook(pre_opt_hook); + compiler->SetPostOptimizationHook(post_opt_hook); + + ASSERT_TRUE(compiler + ->RunBackend(std::move(hlo_module), + backend().default_stream_executor(), + /*device_allocator=*/nullptr) + .ok()); + + // Test that hooks were called. 
+ EXPECT_EQ(1, pre_opt_hook_call_count); + EXPECT_EQ(1, post_opt_hook_call_count); } -TEST_F(GpuCompilerTest, GpuMultModuleCompilation) { - gpu::GpuDummyCompiler compiler; - TestMultiModuleCompilation(&compiler); +TEST_F(LLVMCompilerTest, DISABLED_MultiModuleCompilation) { + auto hlo_module = ParseAndReturnVerifiedModule(kHloText).value(); + auto hlo_module2 = ParseAndReturnVerifiedModule(kHloText).value(); + std::vector> modules; + modules.push_back(std::move(hlo_module)); + modules.push_back(std::move(hlo_module2)); + auto module_group = + std::make_unique("test_module_group", std::move(modules)); + + std::vector> executors; + executors.push_back({backend().default_stream_executor()}); + executors.push_back({backend().default_stream_executor()}); + + EXPECT_IS_OK(backend().compiler()->Compile(std::move(module_group), + std::move(executors), + backend().memory_allocator())); } + } // namespace } // namespace xla From ab3720c821db954be3a8489b6e56014ae1b15f3f Mon Sep 17 00:00:00 2001 From: George Karpenkov Date: Wed, 19 Jun 2024 04:30:59 -0700 Subject: [PATCH 012/256] [XLA:GPU] [NFC] Remove redundant argument to GetKernelAnnotation It's always the current module. 
PiperOrigin-RevId: 644705666 --- third_party/xla/xla/service/gpu/gpu_executable.cc | 8 +++----- .../xla/xla/service/gpu/runtime/annotation.cc | 14 +++++--------- .../xla/xla/service/gpu/runtime/annotation.h | 2 +- .../xla/service/gpu/runtime/command_buffer_cmd.cc | 4 +--- .../service/gpu/runtime/command_buffer_thunk.cc | 4 +--- .../xla/service/gpu/runtime/sequential_thunk.cc | 4 +--- 6 files changed, 12 insertions(+), 24 deletions(-) diff --git a/third_party/xla/xla/service/gpu/gpu_executable.cc b/third_party/xla/xla/service/gpu/gpu_executable.cc index 7a9927e0763660..a0e29b02bb54cb 100644 --- a/third_party/xla/xla/service/gpu/gpu_executable.cc +++ b/third_party/xla/xla/service/gpu/gpu_executable.cc @@ -359,8 +359,7 @@ absl::Status ExecuteThunks( Thunk::ExecutableSource executable_source, const ServiceExecutableRunOptions* run_options, const BufferAllocations& buffer_allocations, bool block_host_until_done, - const absl::flat_hash_set& execution_stream_ids, - const ModuleAnnotations& module_annotations) { + const absl::flat_hash_set& execution_stream_ids) { int64_t collective_max_nchannels = debug_options ? debug_options->xla_gpu_nccl_collective_max_nchannels() : 0; @@ -487,8 +486,7 @@ absl::Status ExecuteThunks( // Annotate execution of this op if tracing was enabled when we started // running this module. If tracing is enabled *while* we're running the // module, we won't get any data, but that's probably an OK trade-off. - auto scoped_annotation = - GetKernelAnnotation(&module_annotations, thunk->profile_annotation()); + auto scoped_annotation = GetKernelAnnotation(thunk->profile_annotation()); VLOG(3) << "Executing the thunk for " << thunk->profile_annotation(); if (NeedsAsyncCommsStream(*thunk)) { for (se::Stream* async_stream : async_comms_streams) { @@ -1022,7 +1020,7 @@ absl::StatusOr GpuExecutable::ExecuteAsyncOnStreamImpl( TF_RETURN_IF_ERROR(ExecuteThunks( has_module() ? 
&module_config().debug_options() : nullptr, module_name_, unique_id, *thunks_, executable_source, run_options, buffer_allocations, - block_host_until_done, execution_stream_ids_, module_annotations_)); + block_host_until_done, execution_stream_ids_)); } TF_RETURN_IF_ERROR( diff --git a/third_party/xla/xla/service/gpu/runtime/annotation.cc b/third_party/xla/xla/service/gpu/runtime/annotation.cc index 99b05d3a6c634a..f1473d476cf982 100644 --- a/third_party/xla/xla/service/gpu/runtime/annotation.cc +++ b/third_party/xla/xla/service/gpu/runtime/annotation.cc @@ -527,7 +527,7 @@ ModuleAnnotations::ModuleAnnotations(const HloModule& mod) : top_level{mod} { // range based on the content of `inst`, including `called` etc. // FIXME: using try_emplace here was sensitive to // https://github.com/abseil/abseil-cpp/issues/388. - kernels.insert({inst->name(), {top_level, *inst}}); + kernels.insert({inst->name(), KernelAnnotation{top_level, *inst}}); } } } @@ -548,19 +548,15 @@ ScopedModuleAnnotations::~ScopedModuleAnnotations() { std::exchange(current_annotations, restore_); } -const ModuleAnnotations* GetCurrentModuleAnnotations() { - return current_annotations; -} - std::optional GetKernelAnnotation( - const ModuleAnnotations* annotations, std::string_view profile_annotation) { + std::string_view profile_annotation) { if (profile_annotation.empty()) { return {}; } - if (annotations) { + if (current_annotations) { // Have a set of pre-prepared thunk/kernel annotations to use - const auto iter = annotations->kernels.find(profile_annotation); - if (iter != annotations->kernels.end()) { + const auto iter = current_annotations->kernels.find(profile_annotation); + if (iter != current_annotations->kernels.end()) { // Have a pre-prepared annotation, use it return std::optional{[&] { return iter->second; }}; } diff --git a/third_party/xla/xla/service/gpu/runtime/annotation.h b/third_party/xla/xla/service/gpu/runtime/annotation.h index 70a4c8df1ddeb1..e5e170891a31c9 100644 --- 
a/third_party/xla/xla/service/gpu/runtime/annotation.h +++ b/third_party/xla/xla/service/gpu/runtime/annotation.h @@ -104,7 +104,7 @@ class ScopedModuleAnnotations { const ModuleAnnotations* GetCurrentModuleAnnotations(); std::optional GetKernelAnnotation( - const ModuleAnnotations* annotations, std::string_view profile_annotation); + std::string_view profile_annotation); } // namespace xla::gpu diff --git a/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd.cc b/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd.cc index f2b40e949a20b3..381ceea2017dc3 100644 --- a/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd.cc +++ b/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd.cc @@ -294,8 +294,6 @@ absl::Status CommandBufferCmdSequence::Record( } } - const ModuleAnnotations* annotations = GetCurrentModuleAnnotations(); - // Track the number of commands recorded between barriers. absl::flat_hash_map num_recorded_commands; @@ -303,7 +301,7 @@ absl::Status CommandBufferCmdSequence::Record( ExecutionScopeId execution_scope_id = command.cmd->GetExecutionScope(record_params); std::optional annotation = - GetKernelAnnotation(annotations, command.cmd->profile_annotation()); + GetKernelAnnotation(command.cmd->profile_annotation()); if (command.requires_barrier) { VLOG(3) << "Add command buffer barrier after " diff --git a/third_party/xla/xla/service/gpu/runtime/command_buffer_thunk.cc b/third_party/xla/xla/service/gpu/runtime/command_buffer_thunk.cc index 8236ad2a65cb08..b551766f43c1a9 100644 --- a/third_party/xla/xla/service/gpu/runtime/command_buffer_thunk.cc +++ b/third_party/xla/xla/service/gpu/runtime/command_buffer_thunk.cc @@ -206,10 +206,8 @@ absl::Status CommandBufferThunk::ExecuteOnStream(const ExecuteParams& params) { if (tsl::profiler::ProfilerLock::HasActiveSession() && thunks_.has_value()) { VLOG(1) << "Execute command buffer thunk as a regular thunk sequence " "because we detected active profiling session"; - const 
ModuleAnnotations* annotations = GetCurrentModuleAnnotations(); for (auto& thunk : *thunks_) { - auto scoped_annotation = - GetKernelAnnotation(annotations, thunk->profile_annotation()); + auto scoped_annotation = GetKernelAnnotation(thunk->profile_annotation()); TF_RETURN_IF_ERROR(thunk->ExecuteOnStream(params)); } return absl::OkStatus(); diff --git a/third_party/xla/xla/service/gpu/runtime/sequential_thunk.cc b/third_party/xla/xla/service/gpu/runtime/sequential_thunk.cc index a0dc62fbc7155d..d874c25ea560f8 100644 --- a/third_party/xla/xla/service/gpu/runtime/sequential_thunk.cc +++ b/third_party/xla/xla/service/gpu/runtime/sequential_thunk.cc @@ -53,10 +53,8 @@ absl::Status SequentialThunk::Initialize(const InitializeParams& params) { } absl::Status SequentialThunk::ExecuteOnStream(const ExecuteParams& params) { - const ModuleAnnotations* annotations = GetCurrentModuleAnnotations(); for (const auto& thunk : thunks_) { - auto annotation = - GetKernelAnnotation(annotations, thunk->profile_annotation()); + auto annotation = GetKernelAnnotation(thunk->profile_annotation()); TF_RETURN_IF_ERROR(thunk->ExecuteOnStream(params)); } return absl::OkStatus(); From 7c487a387b2120d5ddcca6040210ccb4243b7dbd Mon Sep 17 00:00:00 2001 From: Johannes Reifferscheid Date: Wed, 19 Jun 2024 05:23:32 -0700 Subject: [PATCH 013/256] Fix ASAN error with double -> X float conversions. 1 << 63 doesn't fit in an int64, so we have to use a uint64. The generated code should be identical, so this should be unobservable outside of ASAN. 
PiperOrigin-RevId: 644716249 --- .../service/gpu/fusions/mlir/expand_float_ops.cc | 2 +- .../gpu/fusions/mlir/tests/expand_float_ops.mlir | 14 ++++++++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/expand_float_ops.cc b/third_party/xla/xla/service/gpu/fusions/mlir/expand_float_ops.cc index 69994e78bca418..cfbf9a2b8c22d0 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/expand_float_ops.cc +++ b/third_party/xla/xla/service/gpu/fusions/mlir/expand_float_ops.cc @@ -321,7 +321,7 @@ Value EmitFloatConversion(Value value, mlir::FloatType to_ty, from_bits.shrui(value.getType().getIntOrFloatBitWidth() - 1) != 0; from_bits = - from_bits & ((1LL << (value.getType().getIntOrFloatBitWidth() - 1)) - 1); + from_bits & ((1ULL << (value.getType().getIntOrFloatBitWidth() - 1)) - 1); Value result_is_inf = IsInf(value, b); Value input_is_nan = IsNaN(value, b); diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/tests/expand_float_ops.mlir b/third_party/xla/xla/service/gpu/fusions/mlir/tests/expand_float_ops.mlir index 7639ae1c091a96..01fc07daf8c815 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/tests/expand_float_ops.mlir +++ b/third_party/xla/xla/service/gpu/fusions/mlir/tests/expand_float_ops.mlir @@ -92,3 +92,17 @@ module { // correctness. // CHECK-LABEL: @fptoi8 // CHECK-NOT: arith.fptosi {{.*}}f8E5M2 + +// ----- + +module { + func.func @double_to_f8(%arg0: f64) -> f8E5M2 { + %ret = arith.truncf %arg0 : f64 to f8E5M2 + return %ret : f8E5M2 + } +} + +// Just check that this lowers successfully. We have integration tests to verify +// correctness. 
+// CHECK-LABEL: @double_to_f8 +// CHECK-NOT: arith.truncf From 1571f0a6906bebd8dcfcf1a96e664d08fe102cdc Mon Sep 17 00:00:00 2001 From: George Karpenkov Date: Wed, 19 Jun 2024 05:28:54 -0700 Subject: [PATCH 014/256] [XLA:GPU] More consistent error handling for borrowed streams Below we already error out if we can't get a stream, we should just error out if the stream borrower is not available instead of deferring the error. PiperOrigin-RevId: 644717664 --- .../xla/xla/service/gpu/gpu_executable.cc | 30 +++++++------------ .../service/service_executable_run_options.h | 2 ++ 2 files changed, 13 insertions(+), 19 deletions(-) diff --git a/third_party/xla/xla/service/gpu/gpu_executable.cc b/third_party/xla/xla/service/gpu/gpu_executable.cc index a0e29b02bb54cb..f184672150853f 100644 --- a/third_party/xla/xla/service/gpu/gpu_executable.cc +++ b/third_party/xla/xla/service/gpu/gpu_executable.cc @@ -381,21 +381,20 @@ absl::Status ExecuteThunks( // Borrow streams required for NcclCollectiveThunk. absl::InlinedVector async_comms_streams( kAsyncStreamTotal, nullptr); - absl::StatusOr> streams = - run_options->BorrowStreams(executor->device_ordinal(), kAsyncStreamTotal, - stream_priority); - if (streams.ok()) { + se::Stream* command_buffer_trace_stream = nullptr; + if (run_options->HasStreamBorrower()) { + TF_ASSIGN_OR_RETURN( + std::vector async_comms_streams_ownr, + run_options->BorrowStreams(executor->device_ordinal(), + kAsyncStreamTotal, stream_priority)); for (int64_t i = 0; i < kAsyncStreamTotal; ++i) { - async_comms_streams[i] = streams->at(i).get(); + async_comms_streams[i] = async_comms_streams_ownr[i].get(); } - } - // Borrow stream for tracing command buffers. 
- se::Stream* command_buffer_trace_stream = nullptr; - absl::StatusOr borrowed_command_buffer_trace_stream = - run_options->BorrowStream(executor->device_ordinal()); - if (borrowed_command_buffer_trace_stream.ok()) { - command_buffer_trace_stream = borrowed_command_buffer_trace_stream->get(); + // Borrow stream for tracing command buffers. + TF_ASSIGN_OR_RETURN(StreamPool::Ptr borrowed_command_buffer_trace_stream, + run_options->BorrowStream(executor->device_ordinal())); + command_buffer_trace_stream = borrowed_command_buffer_trace_stream.get(); } // Borrow stream for additional compute streams @@ -488,13 +487,6 @@ absl::Status ExecuteThunks( // module, we won't get any data, but that's probably an OK trade-off. auto scoped_annotation = GetKernelAnnotation(thunk->profile_annotation()); VLOG(3) << "Executing the thunk for " << thunk->profile_annotation(); - if (NeedsAsyncCommsStream(*thunk)) { - for (se::Stream* async_stream : async_comms_streams) { - TF_RET_CHECK(async_stream != nullptr) - << "`run_options` must have a stream borrower for async thunks."; - } - } - TF_RETURN_IF_ERROR(thunk->ExecuteOnStream(execute_params)); } return MaybeSyncAndProfile(run_options, std::move(execution_timer), diff --git a/third_party/xla/xla/service/service_executable_run_options.h b/third_party/xla/xla/service/service_executable_run_options.h index 8987b7519ece5b..6a91b5e8be3b45 100644 --- a/third_party/xla/xla/service/service_executable_run_options.h +++ b/third_party/xla/xla/service/service_executable_run_options.h @@ -85,6 +85,8 @@ class ServiceExecutableRunOptions { "No stream borrower"); } + bool HasStreamBorrower() const { return stream_borrower_ != nullptr; } + private: ExecutableRunOptions run_options_; StreamBorrower stream_borrower_; From df2dae48118d675fb92cf51f42aa6abdb391cedc Mon Sep 17 00:00:00 2001 From: George Karpenkov Date: Wed, 19 Jun 2024 06:05:25 -0700 Subject: [PATCH 015/256] [JAX/XLA] Correct the logic for showing stack traces on JAX_TRACEBACK_FILTERING being 
set to off PiperOrigin-RevId: 644725802 --- third_party/xla/xla/pjrt/exceptions.h | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/third_party/xla/xla/pjrt/exceptions.h b/third_party/xla/xla/pjrt/exceptions.h index 6a5865f3cce0ce..cf9696de956549 100644 --- a/third_party/xla/xla/pjrt/exceptions.h +++ b/third_party/xla/xla/pjrt/exceptions.h @@ -53,10 +53,7 @@ class XlaRuntimeError : public std::runtime_error { } static bool ShowStackTraces() { - if (char* value = getenv("JAX_TRACEBACK_FILTERING")) { - return strcmp(value, "off"); - } - return false; + return absl::string_view(getenv("JAX_TRACEBACK_FILTERING")) == "off"; } std::optional status_; From 32aa8bb09bd1eeb73dcdf4f6bc58dc57c79d9de1 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 19 Jun 2024 07:13:08 -0700 Subject: [PATCH 016/256] Integrate LLVM at llvm/llvm-project@99c43e3ce314 Updates LLVM usage to match [99c43e3ce314](https://github.com/llvm/llvm-project/commit/99c43e3ce314) PiperOrigin-RevId: 644739615 --- third_party/llvm/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl index be194471a81f13..002c9c6f710444 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" - LLVM_COMMIT = "b99d0b34400176cb9183113b96b245400caaf8d8" - LLVM_SHA256 = "9ca100ba202a8048ad478149a31cb3f06e164ec4dc49c17fe043807bea608c69" + LLVM_COMMIT = "99c43e3ce3142a93bbad4f9efeace254d9a8442c" + LLVM_SHA256 = "40440e956c5a1b73373311a746a55bed9e719ceaa01f2cdc63ed116d5c01c438" tf_http_archive( name = name, From 542172b0394127cc67fc8130e1b3658cdf0174a5 Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Wed, 19 Jun 2024 07:13:28 -0700 Subject: [PATCH 017/256] [XLA:GPU] Remove dependency from `triton_support_test.cc` on `TritonFusionAnalysis`. 
This requires making all tests in the file follow the "implication test" pattern, which checks that 1. `IsSupportedInstruction(instr)` implies "Triton generates code for `instr` successfully"; 2. `!IsSupportedInstruction(instr)` implies "Triton fails gracefully to generate code for `instr`". We add a `RunSupportTest` util to `TritonSupportTest` to instantiate this test pattern. Unfortunately, there are still issues with checking `2.` in the specific case of `f16` division, which triggers a crash in LLVM. We allow bypassing this issue in the relevant test by skipping the failure implication test when the `skip_failure_branch_to_avoid_crash` parameter of `RunSupportTest` is explicitly set. The change requires some additional modifications to make the tests pass: * `SymbolicTileAnalysis` needs to discard operations that cannot be indexed using `IndexingMap`s; * we need to fork the `IsTritonSupportedElementwise` utils defined in `legacy_triton::` to get proper coverage. The reason for this is that these utils were previously deeply intertwined with the compilation pipeline, and would claim to support codegen'ing certain elementwise operations because `FloatNormalization` would later allow them to be codegen'd correctly by changing their type before codegen. Fixing this in the legacy infrastructure is non-trivial (doing it improperly would mean slowdowns if we fail to make fusions that we previously managed to do), and undesirable since this is code we aim to delete in not-too-long(tm).
PiperOrigin-RevId: 644739699 --- third_party/xla/xla/service/gpu/BUILD | 3 +- .../ir_emitter_triton_parametrized_test.cc | 50 +++-- .../gpu/model/symbolic_tile_analysis.cc | 6 + .../xla/xla/service/gpu/triton_support.cc | 113 +++++++++-- .../xla/xla/service/gpu/triton_support.h | 44 +++- .../service/gpu/triton_support_legacy_test.cc | 11 +- .../xla/service/gpu/triton_support_test.cc | 191 ++++++++---------- .../xla/xla/service/gpu/triton_test_utils.cc | 6 +- .../xla/xla/service/gpu/triton_test_utils.h | 11 +- .../service/gpu/triton_tiling_propagation.cc | 2 +- 10 files changed, 276 insertions(+), 161 deletions(-) diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index dba824f313e4d6..68ec9f6c18a874 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -1248,6 +1248,7 @@ cc_library( "//xla/service:instruction_fusion", "//xla/stream_executor:device_description", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@local_tsl//tsl/platform:tensor_float_32_utils", @@ -1265,12 +1266,12 @@ xla_test( deps = [ ":gpu_device_info_for_tests", ":ir_emitter_triton", - ":triton_fusion_analysis", ":triton_support", ":triton_test_utils", "//xla:xla_data_proto_cc", "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", + "//xla/service/gpu/model:tiled_hlo_computation", "//xla/stream_executor:device_description", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/status", diff --git a/third_party/xla/xla/service/gpu/ir_emitter_triton_parametrized_test.cc b/third_party/xla/xla/service/gpu/ir_emitter_triton_parametrized_test.cc index 146b53547a1852..f5eff6ed86f731 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter_triton_parametrized_test.cc +++ b/third_party/xla/xla/service/gpu/ir_emitter_triton_parametrized_test.cc @@ -247,7 +247,8 @@ INSTANTIATE_TEST_SUITE_P( 
::testing::Combine( ::testing::Values(PRED), ::testing::ValuesIn( - legacy_triton::TritonSupportedUnaryElementwise(PRED)), + legacy_triton:: + TritonSupportedUnaryElementwiseUpToFloatNormalization(PRED)), ::testing::Values(3e-2)), ElementwiseTestParamsToString); @@ -255,40 +256,50 @@ INSTANTIATE_TEST_SUITE_P( ElementwiseTestSuiteS8, UnaryElementwiseTest, ::testing::Combine( ::testing::Values(S8), - ::testing::ValuesIn(legacy_triton::TritonSupportedUnaryElementwise(S8)), + ::testing::ValuesIn( + legacy_triton:: + TritonSupportedUnaryElementwiseUpToFloatNormalization(S8)), ::testing::Values(3e-2)), ElementwiseTestParamsToString); INSTANTIATE_TEST_SUITE_P( ElementwiseTestSuiteS16, UnaryElementwiseTest, - ::testing::Combine(::testing::Values(S16), - ::testing::ValuesIn( - legacy_triton::TritonSupportedUnaryElementwise(S16)), - ::testing::Values(1e-3)), + ::testing::Combine( + ::testing::Values(S16), + ::testing::ValuesIn( + legacy_triton:: + TritonSupportedUnaryElementwiseUpToFloatNormalization(S16)), + ::testing::Values(1e-3)), ElementwiseTestParamsToString); INSTANTIATE_TEST_SUITE_P( ElementwiseTestSuiteS32, UnaryElementwiseTest, - ::testing::Combine(::testing::Values(S32), - ::testing::ValuesIn( - legacy_triton::TritonSupportedUnaryElementwise(S32)), - ::testing::Values(1e-3)), + ::testing::Combine( + ::testing::Values(S32), + ::testing::ValuesIn( + legacy_triton:: + TritonSupportedUnaryElementwiseUpToFloatNormalization(S32)), + ::testing::Values(1e-3)), ElementwiseTestParamsToString); INSTANTIATE_TEST_SUITE_P( ElementwiseTestSuiteF16, UnaryElementwiseTest, - ::testing::Combine(::testing::Values(F16), - ::testing::ValuesIn( - legacy_triton::TritonSupportedUnaryElementwise(F16)), - ::testing::Values(2e-4)), + ::testing::Combine( + ::testing::Values(F16), + ::testing::ValuesIn( + legacy_triton:: + TritonSupportedUnaryElementwiseUpToFloatNormalization(F16)), + ::testing::Values(2e-4)), ElementwiseTestParamsToString); INSTANTIATE_TEST_SUITE_P( ElementwiseTestSuiteF32, 
UnaryElementwiseTest, - ::testing::Combine(::testing::Values(F32), - ::testing::ValuesIn( - legacy_triton::TritonSupportedUnaryElementwise(F32)), - ::testing::Values(1e-6)), + ::testing::Combine( + ::testing::Values(F32), + ::testing::ValuesIn( + legacy_triton:: + TritonSupportedUnaryElementwiseUpToFloatNormalization(F32)), + ::testing::Values(1e-6)), ElementwiseTestParamsToString); using BinaryElementwiseTest = ElementwiseTest; @@ -364,7 +375,8 @@ ENTRY e { std::vector TestedBinaryElementwise(PrimitiveType element_type) { std::vector ret = - legacy_triton::TritonSupportedBinaryElementwise(element_type); + legacy_triton::TritonSupportedBinaryElementwiseUpToFloatNormalization( + element_type); // Comparison requires an additional property. ret.erase(std::remove_if(ret.begin(), ret.end(), HloOpcodeIsComparison), ret.end()); diff --git a/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis.cc b/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis.cc index b2d0847850114b..0f22883914268b 100644 --- a/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis.cc +++ b/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis.cc @@ -247,6 +247,12 @@ absl::StatusOr ComputeBlockIdToTileOffsetIndexing( IndexingMap operand_indexing_map = ComposeIndexingMaps(tiled_hlo_instruction->indexing_map(), *operand_indexing_map_set.begin()); + if (operand_indexing_map.IsUndefined()) { + return FusionDecision{} + << "Couldn't derive indexing map for instruction " + << tiled_hlo_instruction->hlo()->ToString() << " and operand " + << operand.instruction().ToString(); + } operand_indexing_map.Simplify(); operand_indexing_map.RescaleSymbols(); operand_indexing_map.RemoveUnusedSymbols(); diff --git a/third_party/xla/xla/service/gpu/triton_support.cc b/third_party/xla/xla/service/gpu/triton_support.cc index eb51cf4dae8039..7d605a7aa62879 100644 --- a/third_party/xla/xla/service/gpu/triton_support.cc +++ b/third_party/xla/xla/service/gpu/triton_support.cc @@ -21,6 +21,7 @@ 
limitations under the License. #include #include "absl/algorithm/container.h" +#include "absl/container/flat_hash_set.h" #include "absl/log/check.h" #include "absl/log/log.h" #include "xla/hlo/ir/hlo_casting_utils.h" @@ -112,7 +113,7 @@ bool IsTritonSupportedDataType(PrimitiveType type, } } -std::vector TritonSupportedUnaryElementwise( +std::vector TritonSupportedUnaryElementwiseUpToFloatNormalization( PrimitiveType element_type) { std::vector ret = {HloOpcode::kConvert}; if (element_type == PrimitiveType::PRED) { @@ -136,7 +137,7 @@ std::vector TritonSupportedUnaryElementwise( return ret; } -std::vector TritonSupportedBinaryElementwise( +std::vector TritonSupportedBinaryElementwiseUpToFloatNormalization( PrimitiveType element_type) { if (element_type == PrimitiveType::PRED) { return {HloOpcode::kAnd, HloOpcode::kOr, HloOpcode::kXor, @@ -155,19 +156,25 @@ std::vector TritonSupportedBinaryElementwise( return ret; } -std::vector TritonSupportedTernaryElementwise( +std::vector TritonSupportedTernaryElementwiseUpToFloatNormalization( PrimitiveType element_type) { return {HloOpcode::kSelect, HloOpcode::kClamp}; } -bool IsTritonSupportedElementwise(HloOpcode opcode, - PrimitiveType element_type) { - return absl::c_linear_search(TritonSupportedUnaryElementwise(element_type), - opcode) || - absl::c_linear_search(TritonSupportedBinaryElementwise(element_type), - opcode) || - absl::c_linear_search(TritonSupportedTernaryElementwise(element_type), - opcode); +bool IsTritonSupportedElementwiseUpToFloatNormalization( + HloOpcode opcode, PrimitiveType element_type) { + return absl::c_linear_search( + TritonSupportedUnaryElementwiseUpToFloatNormalization( + element_type), + opcode) || + absl::c_linear_search( + TritonSupportedBinaryElementwiseUpToFloatNormalization( + element_type), + opcode) || + absl::c_linear_search( + TritonSupportedTernaryElementwiseUpToFloatNormalization( + element_type), + opcode); } CodegenDecision CanTritonHandleElementwise( @@ -185,7 +192,7 @@ 
CodegenDecision CanTritonHandleElementwise( if (instr.opcode() == HloOpcode::kConstant) { return CodegenDecision{}; - } else if (!IsTritonSupportedElementwise( + } else if (!IsTritonSupportedElementwiseUpToFloatNormalization( instr.opcode(), instr.operand(0)->shape().element_type())) { return "Unsupported elementwise operation."; } @@ -417,6 +424,82 @@ CodegenDecision IsTritonSupportedInstruction( } // namespace legacy_triton +namespace { + +// Set of unary elementwise ops that are genuinely supported by Triton. +// TODO(b/345763510): make sure that this is accurate. At the moment, this is +// mostly a fork of the same code in legacy_triton::. +absl::flat_hash_set TritonSupportedUnaryElementwiseOps( + PrimitiveType element_type) { + if (element_type == PrimitiveType::PRED) { + return {HloOpcode::kConvert, HloOpcode::kNot}; + } + absl::flat_hash_set ret = {HloOpcode::kConvert, HloOpcode::kAbs, + HloOpcode::kNegate}; + if (element_type == PrimitiveType::F32 || + element_type == PrimitiveType::F64) { + absl::flat_hash_set additional_opcodes{ + HloOpcode::kCos, HloOpcode::kExp, HloOpcode::kExpm1, + HloOpcode::kFloor, HloOpcode::kCeil, HloOpcode::kLog, + HloOpcode::kLog1p, HloOpcode::kRsqrt, HloOpcode::kSin, + HloOpcode::kSqrt, HloOpcode::kCbrt, HloOpcode::kTan, + HloOpcode::kTanh, HloOpcode::kErf}; + ret.insert(additional_opcodes.begin(), additional_opcodes.end()); + } + + if (element_type == PrimitiveType::BF16 || + element_type == PrimitiveType::F16) { + absl::flat_hash_set additional_opcodes{HloOpcode::kFloor, + HloOpcode::kCeil}; + ret.insert(additional_opcodes.begin(), additional_opcodes.end()); + } + return ret; +} + +// Set of binary elementwise ops that are genuinely supported by Triton. +// TODO(b/345763510): make sure that this is accurate. At the moment, this is +// mostly a fork of the same code in legacy_triton::. 
+absl::flat_hash_set TritonSupportedBinaryElementwiseOps( + PrimitiveType element_type) { + if (element_type == PrimitiveType::PRED) { + return {HloOpcode::kAnd, HloOpcode::kOr, HloOpcode::kXor, + HloOpcode::kCompare}; + } + absl::flat_hash_set ret = { + HloOpcode::kAdd, HloOpcode::kCompare, HloOpcode::kMaximum, + HloOpcode::kMinimum, HloOpcode::kMultiply, HloOpcode::kSubtract}; + if (element_type == PrimitiveType::F32 || + element_type == PrimitiveType::F64) { + absl::flat_hash_set additional_opcodes{ + HloOpcode::kAtan2, HloOpcode::kDivide, HloOpcode::kPower}; + ret.insert(additional_opcodes.begin(), additional_opcodes.end()); + } else if (element_type == PrimitiveType::BF16) { + ret.insert(HloOpcode::kDivide); + } + return ret; +} + +// Set of ternary elementwise ops that are genuinely supported by Triton. +// TODO(b/345763510): make sure that this is accurate. At the moment, this is +// mostly a fork of the same code in legacy_triton::. +absl::flat_hash_set TritonSupportedTernaryElementwiseOps( + PrimitiveType element_type) { + return {HloOpcode::kSelect, HloOpcode::kClamp}; +} + +// Returns `true` if the given opcode and element type correspond to a n-ary +// elementwise op that is genuinely supported by Triton. The caller is +// responsible for ensuring that the relevant data type is supported on the +// device of interest. 
+bool IsTritonSupportedElementwise(HloOpcode opcode, + PrimitiveType element_type) { + return TritonSupportedUnaryElementwiseOps(element_type).contains(opcode) || + TritonSupportedBinaryElementwiseOps(element_type).contains(opcode) || + TritonSupportedTernaryElementwiseOps(element_type).contains(opcode); +} + +} // namespace + CodegenDecision IsTritonSupportedInstruction( const HloInstruction& instr, const se::GpuComputeCapability& gpu_version) { bool output_type_is_supported = legacy_triton::IsTritonSupportedDataType( @@ -437,7 +520,11 @@ CodegenDecision IsTritonSupportedInstruction( } if (instr.IsElementwise()) { - return legacy_triton::CanTritonHandleElementwise(instr, gpu_version); + if (!IsTritonSupportedElementwise(instr.opcode(), + instr.shape().element_type())) { + return "Unsupported elementwise operation."; + } + return CodegenDecision{}; } // TODO(bchetioui): support kDot, kPad, and kDynamicSlice. diff --git a/third_party/xla/xla/service/gpu/triton_support.h b/third_party/xla/xla/service/gpu/triton_support.h index 137671d066b68b..504c7c1893cf09 100644 --- a/third_party/xla/xla/service/gpu/triton_support.h +++ b/third_party/xla/xla/service/gpu/triton_support.h @@ -38,22 +38,54 @@ bool IsDistributiveOverAddition(const HloInstruction& hlo); // Allowlist of unary elementwise operations supported by the legacy Triton // emitters. -std::vector TritonSupportedUnaryElementwise(PrimitiveType); +// +// Note: this is not an accurate representation of what is actually supported by +// the Triton emitters, because operations affected by FloatNormalization may +// be tagged as "supported" here, even though FloatNormalization is required to +// make them work. We could fix this, but this is code we aim to delete soon, so +// it doesn't seem worth it. We'll revisit this decision if the code doesn't go +// away soon. 
+std::vector TritonSupportedUnaryElementwiseUpToFloatNormalization( + PrimitiveType); // Allowlist of binary elementwise operations supported by the legacy Triton // emitters. -std::vector TritonSupportedBinaryElementwise(PrimitiveType); +// +// Note: this is not an accurate representation of what is actually supported by +// the Triton emitters, because operations affected by FloatNormalization may +// be tagged as "supported" here, even though FloatNormalization is required to +// make them work. We could fix this, but this is code we aim to delete soon, so +// it doesn't seem worth it. We'll revisit this decision if the code doesn't go +// away soon. +std::vector TritonSupportedBinaryElementwiseUpToFloatNormalization( + PrimitiveType); // Allowlist of ternary elementwise operations supported by the legacy Triton // emitters. -std::vector TritonSupportedTernaryElementwise(PrimitiveType); +// +// Note: this is not an accurate representation of what is actually supported by +// the Triton emitters, because operations affected by FloatNormalization may +// be tagged as "supported" here, even though FloatNormalization is required to +// make them work. We could fix this, but this is code we aim to delete soon, so +// it doesn't seem worth it. We'll revisit this decision if the code doesn't go +// away soon. +std::vector TritonSupportedTernaryElementwiseUpToFloatNormalization( + PrimitiveType); // Data types that are supported by the legacy Triton emitters. bool IsTritonSupportedDataType(PrimitiveType, const se::GpuComputeCapability&); // Checks elementwise operation against unary, binary, and ternary elementwise // operations supported by the legacy Triton emitters. 
-bool IsTritonSupportedElementwise(HloOpcode, PrimitiveType); +// +// Note: this is not an accurate representation of what is actually supported by +// the Triton emitters, because operations affected by FloatNormalization may +// be tagged as "supported" here, even though FloatNormalization is required to +// make them work. We could fix this, but this is code we aim to delete soon, so +// it doesn't seem worth it. We'll revisit this decision if the code doesn't go +// away soon. +bool IsTritonSupportedElementwiseUpToFloatNormalization(HloOpcode, + PrimitiveType); CodegenDecision CanTritonHandleGEMM( const HloDotInstruction& dot, const se::GpuComputeCapability& gpu_version); @@ -72,7 +104,9 @@ CodegenDecision IsTritonSupportedDynamicSlice( } // namespace legacy_triton // Return `CodegenDecision`'s equivalent of `true` if the parameter instruction -// is supported by the Triton emitters for the given compute capability. +// is supported by the Triton emitters for the given compute capability. Note +// that this function makes no assumption about what happens if +// `FloatNormalization` is run, unlike the legacy Triton utils. // // Note: this function is entirely dissociated from the legacy Triton emitters. 
// If you intend to add a feature to the legacy Triton emitters (which you diff --git a/third_party/xla/xla/service/gpu/triton_support_legacy_test.cc b/third_party/xla/xla/service/gpu/triton_support_legacy_test.cc index f7ec46ed7a2dc1..d3452ea02b5a32 100644 --- a/third_party/xla/xla/service/gpu/triton_support_legacy_test.cc +++ b/third_party/xla/xla/service/gpu/triton_support_legacy_test.cc @@ -60,7 +60,7 @@ bool CombinationCrashesTriton( return false; } -class DotTest : public TritonSupportTestWithParam { +class DotTest : public TritonSupportTestBaseWithParam { protected: void TestDotWithTypes(PrimitiveType lhs_type, PrimitiveType rhs_type, PrimitiveType output_type) { @@ -185,7 +185,7 @@ std::string DynamicSliceTestParamToString( // ::testing::ConvertGenerator, which broke the build in some OSS // configurations. class DynamicSliceTest - : public TritonSupportTest, + : public TritonSupportTestBase, public ::testing::WithParamInterface {}; TEST_P(DynamicSliceTest, IsTritonSupportedDynamicSlice) { @@ -262,7 +262,8 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Bool()), DynamicSliceTestParamToString); -TEST_F(TritonSupportTest, UnsupportedDotOutputTypeFailsGracefullyWithTriton) { +TEST_F(TritonSupportTestBase, + UnsupportedDotOutputTypeFailsGracefullyWithTriton) { const std::string kHloTest = R"( triton_computation { parameter_0 = f32[92,11]{1,0} parameter(0) @@ -304,7 +305,7 @@ ENTRY e { ::testing::HasSubstr("pm.run(triton_module.get()).succeeded()"))); } -TEST_F(TritonSupportTest, +TEST_F(TritonSupportTestBase, UnsupportedDotWithMultipleBatchDimensionsFailsGracefullyWithTriton) { const std::string kHloTest = R"( triton_computation { @@ -347,7 +348,7 @@ ENTRY e { ::testing::HasSubstr("num_batch_dims <= 1"))); } -TEST_F(TritonSupportTest, +TEST_F(TritonSupportTestBase, UnsupportedDotWithNoNonContractingDimensionsFailsGracefullyWithTriton) { const std::string kHloTest = R"( triton_computation { diff --git a/third_party/xla/xla/service/gpu/triton_support_test.cc 
b/third_party/xla/xla/service/gpu/triton_support_test.cc index 3cf75a4cbad7e9..c3cf0a2e191102 100644 --- a/third_party/xla/xla/service/gpu/triton_support_test.cc +++ b/third_party/xla/xla/service/gpu/triton_support_test.cc @@ -17,9 +17,11 @@ limitations under the License. // has landed. #include "xla/service/gpu/triton_support.h" +#include #include #include #include +#include #include #include @@ -30,7 +32,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/gpu/gpu_device_info_for_tests.h" #include "xla/service/gpu/ir_emitter_triton.h" -#include "xla/service/gpu/triton_fusion_analysis.h" +#include "xla/service/gpu/model/tiled_hlo_computation.h" #include "xla/service/gpu/triton_test_utils.h" #include "xla/stream_executor/device_description.h" #include "xla/xla.pb.h" @@ -64,9 +66,50 @@ auto AllXlaDataTypes() { return ::testing::ValuesIn(xla_data_types); } -// TODO(b/343158720): remove references to TritonFusionAnalysis in this file. -// TODO(b/343158720): factor out implication tests into a util in order to -// simplify the test structure. +class TritonSupportTest : public TritonSupportTestBase { + public: + // Runs a support test for the given `TestedInstruction`. The support test + // verifies that `IsTritonSupportedInstruction` is in sync with the + // implemented Triton emitter, i.e., given an instruction `instr`, either + // - `IsTritonSupportedInstruction(instr)` => Triton lowering is OK + // - `!IsTritonSupportedInstruction(instr)` => Triton lowering is not OK. + // + // In order to make sure that the call succeeds in both cases, the user must + // pass valid output tile sizes for the tested instruction/computation + // as an additional parameter. + // + // In some cases, the Triton lowering is not handled gracefully by the + // lowering code, and the lowering fails with a crash. 
In such cases, the + // user can set `skip_failure_branch_to_avoid_crash` to `true` to skip the + // lowering test when `IsTritonSupportedInstruction` returns `false`. + void RunSupportTest(TestedInstruction ti, + std::vector output_tile_sizes, + bool skip_failure_branch_to_avoid_crash = false) { + BlockLevelParameters block_level_parameters = + FromOutputTileSizes(std::move(output_tile_sizes)); + if (IsTritonSupportedInstruction(ti.Instruction(), + GetCudaComputeCapability())) { + TF_EXPECT_OK(CreateTritonIrAndFileCheck(ti.TritonComputation(), + block_level_parameters, + "CHECK: tt.func @triton_fn")); + } else { + if (!skip_failure_branch_to_avoid_crash) { + const se::DeviceDescription dev_info = + TestGpuDeviceInfo::RTXA6000DeviceInfo(GetCudaComputeCapability()); + EXPECT_THAT( + TritonWrapper("test_fn", &ti.TritonFusion(), + GetCudaComputeCapability(), dev_info, + block_level_parameters, &llvm_module_, mlir_context_), + Not(IsOk())); + } + } + } +}; + +class TritonSupportTestWithParam : public TritonSupportTest, + public ::testing::WithParamInterface< + std::tuple> {}; + using BitcastOrReshapeTest = TritonSupportTestWithParam; TEST_P(BitcastOrReshapeTest, IsTritonSupportedBitcastOrReshape) { @@ -79,20 +122,7 @@ ENTRY triton_computation { TF_ASSERT_OK_AND_ASSIGN( TestedInstruction ti, ParseTemplateAndGetInstruction(kHloTestTemplate, data_type, opcode)); - if (IsTritonSupportedInstruction(ti.Instruction(), - GetCudaComputeCapability())) { - TF_EXPECT_OK(CreateTritonIrAndFileCheck(ti.TritonComputation(), - FromOutputTileSizes({16}), - "CHECK: tt.func @triton_fn")); - } else { - const se::DeviceDescription dev_info = - TestGpuDeviceInfo::RTXA6000DeviceInfo(GetCudaComputeCapability()); - EXPECT_THAT( - TritonWrapper("test_fn", &ti.TritonFusion(), GetCudaComputeCapability(), - dev_info, FromOutputTileSizes({1}), &llvm_module_, - mlir_context_), - Not(IsOk())); - } + RunSupportTest(std::move(ti), /*output_tile_sizes=*/{16}); } INSTANTIATE_TEST_SUITE_P( @@ -121,21 
+151,7 @@ ENTRY triton_computation { TF_ASSERT_OK_AND_ASSIGN( TestedInstruction ti, ParseTemplateAndGetInstruction(kHloTestTemplate, data_type, opcode)); - if (IsTritonSupportedInstruction(ti.Instruction(), - GetCudaComputeCapability())) { - TF_EXPECT_OK(ApplyFloatNormalization(ti.Module().get())); - TF_EXPECT_OK(CreateTritonIrAndFileCheck(ti.TritonComputation(), - FromOutputTileSizes({1, 32}), - "CHECK: tt.func @triton_fn")); - } else { - // TODO(b/331632717): update the check to use SymbolicTileAnalysis to avoid - // tiling failures and check triton emitter fails gracefully. - EXPECT_THAT(TritonFusionAnalysis::Execute(ti.TritonComputation()), - tsl::testing::StatusIs( - absl::StatusCode::kFailedPrecondition, - ::testing::HasSubstr( - "Can not propagate dim orders and requirements"))); - } + RunSupportTest(std::move(ti), /*output_tile_sizes=*/{1, 32}); } INSTANTIATE_TEST_SUITE_P( @@ -178,24 +194,14 @@ ENTRY triton_computation { TF_ASSERT_OK_AND_ASSIGN( TestedInstruction ti, ParseTemplateAndGetInstruction(kHloTestTemplate, data_type, opcode)); - if (IsTritonSupportedInstruction(ti.Instruction(), - GetCudaComputeCapability())) { - TF_EXPECT_OK(ApplyFloatNormalization(ti.Module().get())); - TF_EXPECT_OK(CreateTritonIrAndFileCheck(ti.TritonComputation(), - FromOutputTileSizes({1, 32}), - "CHECK: tt.func @triton_fn")); - } else { - EXPECT_THAT(TritonFusionAnalysis::Execute(ti.TritonComputation()), - ::testing::AnyOf( - tsl::testing::StatusIs( - absl::StatusCode::kInternal, - ::testing::HasSubstr( - "std::holds_alternative")), - tsl::testing::StatusIs( - absl::StatusCode::kFailedPrecondition, - ::testing::HasSubstr( - "Can not propagate dim orders and requirements")))); + + bool skip_failure_branch_to_avoid_crash = false; + if (data_type == F16 && opcode == HloOpcode::kDivide) { + skip_failure_branch_to_avoid_crash = true; } + RunSupportTest(std::move(ti), /*output_tile_sizes=*/{1, 32}, + /*skip_failure_branch_to_avoid_crash=*/ + skip_failure_branch_to_avoid_crash); } 
INSTANTIATE_TEST_SUITE_P( @@ -238,19 +244,7 @@ ENTRY triton_computation { TF_ASSERT_OK_AND_ASSIGN( TestedInstruction ti, ParseTemplateAndGetInstruction(kHloTestTemplate, data_type, opcode)); - if (IsTritonSupportedInstruction(ti.Instruction(), - GetCudaComputeCapability())) { - TF_EXPECT_OK(ApplyFloatNormalization(ti.Module().get())); - TF_EXPECT_OK(CreateTritonIrAndFileCheck(ti.TritonComputation(), - FromOutputTileSizes({1, 32}), - "CHECK: tt.func @triton_fn")); - } else { - EXPECT_THAT( - TritonFusionAnalysis::Execute(ti.TritonComputation()), - tsl::testing::StatusIs( - absl::StatusCode::kInternal, - ::testing::HasSubstr("std::holds_alternative"))); - } + RunSupportTest(std::move(ti), /*output_tile_sizes=*/{1, 32}); } INSTANTIATE_TEST_SUITE_P( @@ -278,19 +272,7 @@ ENTRY triton_computation { TF_ASSERT_OK_AND_ASSIGN( TestedInstruction ti, ParseTemplateAndGetInstruction(kHloTestTemplate, data_type, opcode)); - if (IsTritonSupportedInstruction(ti.Instruction(), - GetCudaComputeCapability())) { - TF_EXPECT_OK(ApplyFloatNormalization(ti.Module().get())); - TF_EXPECT_OK(CreateTritonIrAndFileCheck(ti.TritonComputation(), - FromOutputTileSizes({1, 32}), - "CHECK: tt.func @triton_fn")); - } else { - EXPECT_THAT( - TritonFusionAnalysis::Execute(ti.TritonComputation()), - tsl::testing::StatusIs( - absl::StatusCode::kInternal, - ::testing::HasSubstr("std::holds_alternative"))); - } + RunSupportTest(std::move(ti), /*output_tile_sizes=*/{1, 32}); } INSTANTIATE_TEST_SUITE_P( @@ -300,6 +282,7 @@ INSTANTIATE_TEST_SUITE_P( TritonSupportTestParamsToString); using ReduceConstTest = TritonSupportTestWithParam; + TEST_P(ReduceConstTest, IsTritonSupportedReduceWithConstInit) { auto [data_type, opcode] = GetParam(); if (data_type == BF16 && SkipBF16Tests()) { @@ -322,23 +305,7 @@ ENTRY triton_computation { TF_ASSERT_OK_AND_ASSIGN( TestedInstruction ti, ParseTemplateAndGetInstruction(kHloTestTemplate, data_type, opcode)); - if (IsTritonSupportedInstruction(ti.Instruction(), - 
GetCudaComputeCapability())) { - TF_EXPECT_OK(ApplyFloatNormalization(ti.Module().get())); - TF_EXPECT_OK(CreateTritonIrAndFileCheck(ti.TritonComputation(), - FromOutputTileSizes({1}), - "CHECK: tt.func @triton_fn")); - } else { - const se::DeviceDescription dev_info = - TestGpuDeviceInfo::RTXA6000DeviceInfo(GetCudaComputeCapability()); - EXPECT_THAT( - TritonWrapper("test_fn", &ti.TritonFusion(), GetCudaComputeCapability(), - dev_info, FromOutputTileSizes({1}), &llvm_module_, - mlir_context_), - tsl::testing::StatusIs( - absl::StatusCode::kInternal, - ::testing::HasSubstr("Failed to compile Triton kernel"))); - } + RunSupportTest(std::move(ti), /*output_tile_sizes=*/{1}); } INSTANTIATE_TEST_SUITE_P( @@ -378,7 +345,7 @@ ENTRY triton_computation { } TEST_F( - TritonSupportTest, + TritonSupportTestBase, UnsupportedReduceWithMoreThanOneReduceDimensionsFailsGracefullyWithTriton) { const std::string kHloTest = R"( add { @@ -400,11 +367,13 @@ ENTRY triton_computation { .Explain(), ::testing::HasSubstr( "Reduction is not a row-reduction of a single operand.")); - EXPECT_THAT(TritonFusionAnalysis::Execute(ti.TritonComputation()), - tsl::testing::StatusIs( - absl::StatusCode::kFailedPrecondition, - ::testing::HasSubstr( - "Can not propagate dim orders and requirements"))); + const se::DeviceDescription dev_info = + TestGpuDeviceInfo::RTXA6000DeviceInfo(GetCudaComputeCapability()); + EXPECT_THAT( + TritonWrapper("test_fn", &ti.TritonFusion(), GetCudaComputeCapability(), + dev_info, FromOutputTileSizes({1}), &llvm_module_, + mlir_context_), + Not(IsOk())); } TEST_F(TritonSupportTest, @@ -429,11 +398,13 @@ ENTRY triton_computation { .Explain(), ::testing::HasSubstr( "Reduction is not a row-reduction of a single operand.")); - EXPECT_THAT(TritonFusionAnalysis::Execute(ti.TritonComputation()), - tsl::testing::StatusIs( - absl::StatusCode::kFailedPrecondition, - ::testing::HasSubstr( - "Can not propagate dim orders and requirements"))); + const se::DeviceDescription dev_info = + 
TestGpuDeviceInfo::RTXA6000DeviceInfo(GetCudaComputeCapability()); + EXPECT_THAT( + TritonWrapper("test_fn", &ti.TritonFusion(), GetCudaComputeCapability(), + dev_info, FromOutputTileSizes({1}), &llvm_module_, + mlir_context_), + Not(IsOk())); } TEST_F(TritonSupportTest, @@ -462,11 +433,13 @@ ENTRY triton_computation { IsTritonSupportedInstruction(ti.Instruction(), GetCudaComputeCapability()) .Explain(), ::testing::HasSubstr("Unsupported output data type")); - EXPECT_THAT(TritonFusionAnalysis::Execute(ti.TritonComputation()), - tsl::testing::StatusIs( - absl::StatusCode::kFailedPrecondition, - ::testing::HasSubstr( - "Can not propagate dim orders and requirements"))); + const se::DeviceDescription dev_info = + TestGpuDeviceInfo::RTXA6000DeviceInfo(GetCudaComputeCapability()); + EXPECT_THAT( + TritonWrapper("test_fn", &ti.TritonFusion(), GetCudaComputeCapability(), + dev_info, FromOutputTileSizes({1}), &llvm_module_, + mlir_context_), + Not(IsOk())); } TEST_F(TritonSupportTest, diff --git a/third_party/xla/xla/service/gpu/triton_test_utils.cc b/third_party/xla/xla/service/gpu/triton_test_utils.cc index 97e5925d8a42f2..e9fbd2547412d2 100644 --- a/third_party/xla/xla/service/gpu/triton_test_utils.cc +++ b/third_party/xla/xla/service/gpu/triton_test_utils.cc @@ -125,7 +125,7 @@ absl::Status TritonFilecheckTest::CreateTritonIrAndFileCheckForDot( filecheck_pattern); } -absl::StatusOr TritonSupportTest::ApplyFloatNormalization( +absl::StatusOr TritonSupportTestBase::ApplyFloatNormalization( HloModule* module) { const GpuFloatSupport bf16_support(GetCudaComputeCapability(), BF16); HloPassPipeline pipeline("hlo float normalization"); @@ -184,8 +184,8 @@ absl::Status ConvertEntryToTritonFusion(HloModule* module) { } // namespace -absl::StatusOr -TritonSupportTest::ParseTemplateAndGetInstruction( +absl::StatusOr +TritonSupportTestBase::ParseTemplateAndGetInstruction( absl::string_view hlo_template, xla::PrimitiveType data_type, xla::HloOpcode opcode) { const std::string 
hlo_text = absl::Substitute( diff --git a/third_party/xla/xla/service/gpu/triton_test_utils.h b/third_party/xla/xla/service/gpu/triton_test_utils.h index c7cb16abdf36ec..09184145c32f0f 100644 --- a/third_party/xla/xla/service/gpu/triton_test_utils.h +++ b/third_party/xla/xla/service/gpu/triton_test_utils.h @@ -91,7 +91,7 @@ class TritonFilecheckTest : public TritonTest { } }; -class TritonSupportTest : public TritonFilecheckTest { +class TritonSupportTestBase : public TritonFilecheckTest { protected: // An HLO module together with a reference to the instruction of interest // that's being tested. See ParseTemplateAndGetInstruction for more details. @@ -114,7 +114,7 @@ class TritonSupportTest : public TritonFilecheckTest { const HloInstruction& Instruction() { return instruction_; } private: - friend TritonSupportTest; + friend TritonSupportTestBase; TestedInstruction(std::unique_ptr module, const HloInstruction& instruction) @@ -148,9 +148,10 @@ class TritonSupportTest : public TritonFilecheckTest { TritonGemmConfig config_{16, 32, 512, 1, 4, 8}; }; -class TritonSupportTestWithParam : public TritonSupportTest, - public ::testing::WithParamInterface< - std::tuple> {}; +class TritonSupportTestBaseWithParam + : public TritonSupportTestBase, + public ::testing::WithParamInterface< + std::tuple> {}; std::string TritonSupportTestParamsToString( const ::testing::TestParamInfo>& data); diff --git a/third_party/xla/xla/service/gpu/triton_tiling_propagation.cc b/third_party/xla/xla/service/gpu/triton_tiling_propagation.cc index c9c15ba95b3d48..55de5b9958ed59 100644 --- a/third_party/xla/xla/service/gpu/triton_tiling_propagation.cc +++ b/third_party/xla/xla/service/gpu/triton_tiling_propagation.cc @@ -952,7 +952,7 @@ DimOrderMapOrError GetPropagatedDimOrders(const HloInstruction& hlo, return GetPropagatedDimOrdersForDimAlteringOp(hlo, direction, src_dim_order, properties); } else if (hlo.operand_count() > 0 && - legacy_triton::IsTritonSupportedElementwise( + 
legacy_triton::IsTritonSupportedElementwiseUpToFloatNormalization( hlo.opcode(), hlo.operand(0)->shape().element_type())) { return GetPropagatedDimOrdersForElementwise(hlo, direction, src_dim_order); } else if (hlo.opcode() == HloOpcode::kBitcast) { From 353e39e09dfb5aa58562ad9bb4cb40627a474d33 Mon Sep 17 00:00:00 2001 From: Alexander Lyashuk Date: Wed, 19 Jun 2024 07:34:35 -0700 Subject: [PATCH 018/256] [XLA:GPU] A test used a literal out of bounds. Vector of int8_t was initialized with value of 130 in test. PiperOrigin-RevId: 644744094 --- third_party/xla/xla/service/gpu/buffer_comparator_test.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/third_party/xla/xla/service/gpu/buffer_comparator_test.cc b/third_party/xla/xla/service/gpu/buffer_comparator_test.cc index 56101fc0f61257..5f8a01cb4798e3 100644 --- a/third_party/xla/xla/service/gpu/buffer_comparator_test.cc +++ b/third_party/xla/xla/service/gpu/buffer_comparator_test.cc @@ -234,9 +234,9 @@ TEST_F(BufferComparatorTest, TestNumbers) { EXPECT_TRUE(CompareEqualFloatBuffers({10}, {9})); EXPECT_TRUE(CompareEqualFloatBuffers({100}, {101})); - EXPECT_FALSE(CompareEqualFloatBuffers({100}, {130})); - EXPECT_TRUE(CompareEqualFloatBuffers({100}, {115}, 0.2)); - EXPECT_FALSE(CompareEqualFloatBuffers({100}, {130}, 0.2)); + EXPECT_FALSE(CompareEqualFloatBuffers({100}, {120})); + EXPECT_TRUE(CompareEqualFloatBuffers({100}, {120}, 0.2)); + EXPECT_FALSE(CompareEqualFloatBuffers({90}, {120}, 0.2)); EXPECT_FALSE(CompareEqualFloatBuffers({0}, {10})); EXPECT_TRUE(CompareEqualFloatBuffers({9}, {10})); EXPECT_TRUE(CompareEqualFloatBuffers({90}, {100})); From 3685c0a6f6659e2ba9158e5c8c91f6082d833c22 Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Wed, 19 Jun 2024 07:52:00 -0700 Subject: [PATCH 019/256] [XLA:GPU] Remove all SoftMax-related support from legacy Triton infrastructure. 
Now that non-dot use cases have been redirected to `SymbolicTileAnalysis` and to the new generic Triton emitter, this infrastructure is no longer necessary. PiperOrigin-RevId: 644747550 --- third_party/xla/xla/service/gpu/BUILD | 1 + .../xla/xla/service/gpu/gemm_fusion.cc | 59 ++-- .../xla/service/gpu/triton_fusion_analysis.cc | 39 +-- .../xla/service/gpu/triton_fusion_analysis.h | 15 +- .../gpu/triton_fusion_analysis_test.cc | 309 ------------------ .../service/gpu/triton_tiling_propagation.cc | 186 ++--------- .../service/gpu/triton_tiling_propagation.h | 33 +- 7 files changed, 82 insertions(+), 560 deletions(-) diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index 68ec9f6c18a874..7beb1794d8ba19 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -1370,6 +1370,7 @@ cc_library( "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:statusor", diff --git a/third_party/xla/xla/service/gpu/gemm_fusion.cc b/third_party/xla/xla/service/gpu/gemm_fusion.cc index af5da900c99bf6..6fd59c31b2e54c 100644 --- a/third_party/xla/xla/service/gpu/gemm_fusion.cc +++ b/third_party/xla/xla/service/gpu/gemm_fusion.cc @@ -61,17 +61,16 @@ namespace gpu { namespace { -using triton_fusion::CombineRequirements; +using triton_fusion::CombineDotRequirements; using triton_fusion::DimensionOrder; using triton_fusion::DimOrderMap; using triton_fusion::DimOrdersAndReqs; using triton_fusion::DimOrdersAndReqsOrError; +using triton_fusion::DotProperties; using triton_fusion::DotRequirements; +using triton_fusion::DotRequirementsOrError; using triton_fusion::FusionContext; using triton_fusion::GetPropagatedDimOrdersAndRequirementsIfProfitablyFusible; -using triton_fusion::HeroProperties; -using triton_fusion::Requirements; -using 
triton_fusion::RequirementsOrError; using triton_fusion::TransformDirection; // This represents a directed graph. @@ -139,7 +138,7 @@ struct FusionPlan { struct FusionPlanAndRequirements { FusionPlan fusion_plan; - Requirements requirements; + DotRequirements requirements; }; struct HlosAndRequirements { @@ -154,7 +153,7 @@ struct HlosAndRequirements { // // If we fuse further operations they may have to conform to these // requirements. - Requirements requirements; + DotRequirements requirements; }; // Clones the hero kDot operation into the fusion. @@ -192,32 +191,32 @@ int64_t NumAddedParameters(const HloInstruction& hlo) { // Just a helper to reduce "unwrapping" code where we use this. std::optional GetOperandDimOrdersAndCombinedReqs( const HloInstruction& hlo, const DimensionOrder& dim_order, - const HeroProperties& properties, + const DotProperties& properties, const se::GpuComputeCapability& gpu_version, - const Requirements& requirements) { + const DotRequirements& requirements) { DimOrdersAndReqsOrError dim_orders_and_new_reqs = GetPropagatedDimOrdersAndRequirements( hlo, dim_order, TransformDirection::kOutputToInput, properties); if (!std::holds_alternative(dim_orders_and_new_reqs)) { return std::nullopt; } - RequirementsOrError combined_reqs = CombineRequirements( + DotRequirementsOrError combined_reqs = CombineDotRequirements( requirements, std::get(dim_orders_and_new_reqs).requirements); - if (!std::holds_alternative(combined_reqs)) { + if (!std::holds_alternative(combined_reqs)) { return std::nullopt; } return DimOrdersAndReqs{ std::get(dim_orders_and_new_reqs).dim_orders, - std::get(combined_reqs)}; + std::get(combined_reqs)}; } // Just a helper to reduce "unwrapping" code where we use this. 
std::optional GetOperandDimOrdersAndCombinedReqsIfProfitable( const HloInstruction& hlo, const DimensionOrder& dim_order, - const HeroProperties& properties, + const DotProperties& properties, const se::GpuComputeCapability& gpu_version, - const Requirements& requirements) { + const DotRequirements& requirements) { DimOrdersAndReqsOrError dim_orders_and_new_reqs = GetPropagatedDimOrdersAndRequirementsIfProfitablyFusible( hlo, TransformDirection::kOutputToInput, @@ -226,23 +225,23 @@ std::optional GetOperandDimOrdersAndCombinedReqsIfProfitable( if (!std::holds_alternative(dim_orders_and_new_reqs)) { return std::nullopt; } - RequirementsOrError combined_reqs = CombineRequirements( + DotRequirementsOrError combined_reqs = CombineDotRequirements( requirements, std::get(dim_orders_and_new_reqs).requirements); - if (!std::holds_alternative(combined_reqs)) { + if (!std::holds_alternative(combined_reqs)) { return std::nullopt; } return DimOrdersAndReqs{ std::get(dim_orders_and_new_reqs).dim_orders, - std::get(combined_reqs)}; + std::get(combined_reqs)}; } // Just a helper to reduce "unwrapping" code where we use this. 
std::optional GetUserDimOrdersAndCombinedReqsIfProfitable( const HloInstruction& hlo, const DimensionOrder& hlo_dim_order, - const HloInstruction& user, const HeroProperties& properties, + const HloInstruction& user, const DotProperties& properties, const se::GpuComputeCapability& gpu_version, - const Requirements& requirements) { + const DotRequirements& requirements) { DimOrdersAndReqsOrError dim_orders_and_new_reqs = GetPropagatedDimOrdersAndRequirementsIfProfitablyFusible( user, TransformDirection::kInputToOutput, user.operand_index(&hlo), @@ -250,15 +249,15 @@ std::optional GetUserDimOrdersAndCombinedReqsIfProfitable( if (!std::holds_alternative(dim_orders_and_new_reqs)) { return std::nullopt; } - RequirementsOrError combined_reqs = CombineRequirements( + DotRequirementsOrError combined_reqs = CombineDotRequirements( requirements, std::get(dim_orders_and_new_reqs).requirements); - if (!std::holds_alternative(combined_reqs)) { + if (!std::holds_alternative(combined_reqs)) { return std::nullopt; } return DimOrdersAndReqs{ std::get(dim_orders_and_new_reqs).dim_orders, - std::get(combined_reqs)}; + std::get(combined_reqs)}; } // Builds the fusion map and the requirements which can later be used to @@ -267,7 +266,8 @@ FusionPlanAndRequirements BuildFusionPlanTowardOperands( const HloInstruction& root_hlo, const DimensionOrder& root_dim_order, const std::optional& max_params, const se::GpuComputeCapability& gpu_version, - const HeroProperties& properties, const Requirements& requirements_so_far) { + const DotProperties& properties, + const DotRequirements& requirements_so_far) { CHECK(!max_params.has_value() || max_params.value() >= 1); // The graph describing the structure of the fusion that we build - nodes @@ -290,7 +290,7 @@ FusionPlanAndRequirements BuildFusionPlanTowardOperands( // The requirements imposed by the fusion choices made in this function, // combined with the existing requirements. This is one of the outputs of this // function. 
- Requirements combined_reqs = requirements_so_far; + DotRequirements combined_reqs = requirements_so_far; auto get_or_create_fusion_node = [&](const HloInstruction& hlo, const DimensionOrder& dim_order, @@ -453,7 +453,7 @@ HlosAndRequirements FuseTowardOperands( const HloInstruction& root_hlo, const DimensionOrder& root_dim_order, const std::optional& max_params, const se::GpuComputeCapability& gpu_version, - const HeroProperties& properties, const Requirements& requirements_so_far, + const DotProperties& properties, const DotRequirements& requirements_so_far, HloComputation::Builder& builder, // append std::vector& fusion_params // append ) { @@ -491,7 +491,7 @@ absl::StatusOr FuseDotOperand( const HloInstruction& operand = *dot.operand(operand_index); return FuseTowardOperands(operand, context.dim_orders().at(&operand), TritonFusionAnalysis::kMaxParameterPerDotOperand, - gpu_version, context.hero_properties(), + gpu_version, context.dot_properties(), context.requirements(), builder, fusion_params); } @@ -512,7 +512,7 @@ HlosAndRequirements FuseTowardUsers( const HloInstruction& hlo, const HloInstruction& fused_hlo, const DimensionOrder& hlo_dim_order, const se::GpuComputeCapability& gpu_version, - const HeroProperties& properties, const Requirements& requirements, + const DotProperties& properties, const DotRequirements& requirements, HloComputation::Builder& builder, // append std::vector& fusion_params // append ) { @@ -533,7 +533,7 @@ HlosAndRequirements FuseTowardUsers( return existing_hlos_and_requirements; } DimensionOrder user_dim_order = opt_user_result->dim_orders.at(&user); - Requirements combined_requirements = opt_user_result->requirements; + DotRequirements combined_requirements = opt_user_result->requirements; HloInstruction::InstructionVector new_operands; if (user.operand_count() == 1) { @@ -601,7 +601,7 @@ HlosAndRequirements FuseDotOutput( const auto context = FusionContext::FromDotOutput(dot, /*split_k=*/1, requirements); return 
FuseTowardUsers(dot, fused_dot, context.dim_orders().at(&dot), - gpu_version, context.hero_properties(), + gpu_version, context.dot_properties(), context.requirements(), builder, fusion_params); } @@ -656,8 +656,7 @@ absl::StatusOr CreateDotFusion( // For now the RHS doesn't support splits, so it also doesn't impose any // requirements. HlosAndRequirements fused_output_and_reqs = - FuseDotOutput(dot, fused_dot, gpu_version, - std::get(lhs_hlos_and_reqs.requirements), + FuseDotOutput(dot, fused_dot, gpu_version, lhs_hlos_and_reqs.requirements, builder, fusion_inputs); if (fusion_output_ptr != nullptr) { diff --git a/third_party/xla/xla/service/gpu/triton_fusion_analysis.cc b/third_party/xla/xla/service/gpu/triton_fusion_analysis.cc index 10eb3893e96457..b048c7d4b87d74 100644 --- a/third_party/xla/xla/service/gpu/triton_fusion_analysis.cc +++ b/third_party/xla/xla/service/gpu/triton_fusion_analysis.cc @@ -110,16 +110,6 @@ namespace triton_fusion { return context; } -/*static*/ FusionContext FusionContext::FromSoftmaxRoot( - const HloInstruction& root) { - FusionContext context( - SoftmaxProperties{DimensionOrder::kSoftmaxReductionDimension, - DimensionOrder::kSoftmaxBatchDimension}, - SoftmaxRequirements{}); - context.dim_orders_[&root] = DimensionOrder::FromSoftmaxRoot(root); - return context; -} - namespace { // Tells how many new parameters does a fusion gain by fusing the operation as @@ -146,13 +136,13 @@ bool FusionContext::CombineDimOrdersAndReqs(const DimOrdersAndReqs& update) { } } - RequirementsOrError requirements_or_error = - CombineRequirements(requirements_, update.requirements); + DotRequirementsOrError requirements_or_error = + CombineDotRequirements(requirements_, update.requirements); if (std::holds_alternative(requirements_or_error)) { return false; } - requirements_ = std::move(std::get(requirements_or_error)); + requirements_ = std::move(std::get(requirements_or_error)); dim_orders_.insert(update.dim_orders.begin(), update.dim_orders.end()); 
return true; } @@ -218,26 +208,11 @@ absl::StatusOr TritonFusionAnalysis::Execute( TritonFusionAnalysis analysis; const HloInstruction* dot = hlo_query::GetFirstInstructionWithOpcode(computation, HloOpcode::kDot); - if (dot != nullptr) { - TF_RETURN_IF_ERROR(analysis.ExecuteForDotFusion(*dot, split_k)); - } else { - TF_RETURN_IF_ERROR( - analysis.ExecuteForSoftmaxFusion(*computation.root_instruction())); - } + TF_RET_CHECK(dot != nullptr); + TF_RETURN_IF_ERROR(analysis.ExecuteForDotFusion(*dot, split_k)); return analysis; } -absl::Status TritonFusionAnalysis::ExecuteForSoftmaxFusion( - const HloInstruction& root) { - auto context = FusionContext::FromSoftmaxRoot(root); - // Softmax fusion uses one tiled scope. - TF_RETURN_IF_ERROR(context.PropagateDimensionOrdersToParameters( - root, parameters_[Scope::OUTPUT], iter_specs_[Scope::OUTPUT])); - iter_specs_[Scope::LHS] = {}; - iter_specs_[Scope::RHS] = {}; - return absl::OkStatus(); -} - absl::Status TritonFusionAnalysis::ExecuteForProducerConsumer( const HloInstruction& producer, const HloInstruction& consumer, int split_k) { @@ -286,7 +261,7 @@ absl::Status TritonFusionAnalysis::ExecuteForDotFusion( TF_RETURN_IF_ERROR(context.PropagateDimensionOrdersToParameters( *dot.operand(operand_number), parameters_[scope], iter_specs_[scope])); if (scope == Scope::LHS) { - lhs_requirements = std::get(context.requirements()); + lhs_requirements = context.requirements(); } } @@ -307,7 +282,7 @@ absl::Status TritonFusionAnalysis::ExecuteForDotFusion( output = output->users()[0]; DimOrdersAndReqsOrError result = GetPropagatedDimOrdersAndRequirements( *output, context.dim_orders().at(input), - TransformDirection::kInputToOutput, context.hero_properties()); + TransformDirection::kInputToOutput, context.dot_properties()); TF_RET_CHECK(std::holds_alternative(result)); TF_RET_CHECK( context.CombineDimOrdersAndReqs(std::get(result))); diff --git a/third_party/xla/xla/service/gpu/triton_fusion_analysis.h 
b/third_party/xla/xla/service/gpu/triton_fusion_analysis.h index 69a419d702ccb2..4abad469904a03 100644 --- a/third_party/xla/xla/service/gpu/triton_fusion_analysis.h +++ b/third_party/xla/xla/service/gpu/triton_fusion_analysis.h @@ -18,9 +18,11 @@ limitations under the License. // This file contains TritonFusionAnalysis and FusionContext. #include +#include #include #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "xla/autotuning.pb.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -33,7 +35,6 @@ namespace gpu { // Analysis of tensor iteration orders within tiled fusions. class TritonFusionAnalysis { absl::Status ExecuteForDotFusion(const HloInstruction& dot, int split_k); - absl::Status ExecuteForSoftmaxFusion(const HloInstruction& root); public: // Execute the analysis of a fusion computation. @@ -95,7 +96,7 @@ class TritonFusionAnalysis { // namespace to avoid littering the xla::gpu namespace. namespace triton_fusion { class FusionContext { - FusionContext(HeroProperties properties, Requirements requirements) + FusionContext(DotProperties properties, DotRequirements requirements) : properties_(properties), requirements_(requirements) {} public: @@ -109,8 +110,6 @@ class FusionContext { static FusionContext FromDotOutput(const HloInstruction& dot, int split_k, DotRequirements requirements); - static FusionContext FromSoftmaxRoot(const HloInstruction&); - // Add dimension orders from `update` to `dim_orders_` and update // `requirements_` if all of them are compatible. 
bool CombineDimOrdersAndReqs(const DimOrdersAndReqs& update); @@ -122,13 +121,13 @@ class FusionContext { const HloInstruction& origin, ConstHloInstructionSet& parameters, ConstHloInstructionMap& iter_specs); - const HeroProperties& hero_properties() const { return properties_; } + const DotProperties& dot_properties() const { return properties_; } const DimOrderMap& dim_orders() const { return dim_orders_; } - const Requirements& requirements() const { return requirements_; } + const DotRequirements& requirements() const { return requirements_; } private: - const HeroProperties properties_; - Requirements requirements_; + const DotProperties properties_; + DotRequirements requirements_; DimOrderMap dim_orders_; }; diff --git a/third_party/xla/xla/service/gpu/triton_fusion_analysis_test.cc b/third_party/xla/xla/service/gpu/triton_fusion_analysis_test.cc index 051ae78fa9644b..c911bf1462e353 100644 --- a/third_party/xla/xla/service/gpu/triton_fusion_analysis_test.cc +++ b/third_party/xla/xla/service/gpu/triton_fusion_analysis_test.cc @@ -793,315 +793,6 @@ ENTRY e { } } -using TritonSoftmaxAnalysisTest = HloTestBase; - -TEST_F(TritonSoftmaxAnalysisTest, DegenerateBatchDimensionIsSupported) { - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(R"( -max { - p1 = f32[] parameter(1) - p0 = f32[] parameter(0) - ROOT m = f32[] maximum(p0, p1) -} - -triton_softmax_computation { - p0 = f32[1,97]{1,0} parameter(0) - bitcast = f32[97]{0} bitcast(p0) - constant = f32[] constant(-inf) - reduce = f32[] reduce(bitcast, constant), dimensions={0}, to_apply=max - broadcast = f32[1,97]{1,0} broadcast(reduce), dimensions={} - ROOT subtract = f32[1,97]{1,0} subtract(p0, broadcast) -} - -ENTRY e { - p0 = f32[1,97]{1,0} parameter(0) - ROOT r = f32[1,97]{1,0} fusion(p0), kind=kCustom, - calls=triton_softmax_computation, - backend_config={"kind":"__triton"} -})")); - const HloComputation* computation = - 
module->entry_computation()->root_instruction()->called_computations()[0]; - TF_ASSERT_OK_AND_ASSIGN(const auto analysis, - TritonFusionAnalysis::Execute(*computation)); - EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::OUTPUT, - computation->root_instruction(), 0), - ElementsAre(FieldsAre(/*stride=*/1, /*count=*/97, - /*slice_start=*/0, /*slice_limit=*/97, - /*subfragments=*/ElementsAre(97)))); - EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::OUTPUT, - computation->root_instruction(), 1), - ElementsAre(FieldsAre(/*stride=*/97, /*count=*/1, - /*slice_start=*/0, /*slice_limit=*/1, - /*subfragments=*/ElementsAre(1)))); -} - -TEST_F(TritonSoftmaxAnalysisTest, BroadcastIntoBatchDimensionIsSupported) { - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(R"( -c { - p1 = f32[127]{0} parameter(0) - ROOT b = f32[125,127]{1,0} broadcast(p1), dimensions={1} -} - -ENTRY e { - p0 = f32[127]{0} parameter(0) - ROOT t = f32[125,127]{1,0} fusion(p0), kind=kCustom, calls=c -})")); - const HloComputation* computation = - module->entry_computation()->root_instruction()->called_computations()[0]; - TF_ASSERT_OK_AND_ASSIGN(const auto analysis, - TritonFusionAnalysis::Execute(*computation)); - EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::OUTPUT, - computation->root_instruction(), 0), - ElementsAre(FieldsAre(/*stride=*/1, /*count=*/127, - /*slice_start=*/0, /*slice_limit=*/127, - /*subfragments=*/ElementsAre(127)))); - EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::OUTPUT, - computation->root_instruction(), 1), - ElementsAre(FieldsAre(/*stride=*/127, /*count=*/125, - /*slice_start=*/0, /*slice_limit=*/125, - /*subfragments=*/ElementsAre(125)))); - EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::OUTPUT, - computation->parameter_instruction(0), 0), - ElementsAre(FieldsAre(/*stride=*/1, /*count=*/127, - /*slice_start=*/0, /*slice_limit=*/127, - /*subfragments=*/ElementsAre(127)))); - 
EXPECT_EQ(analysis.IterSpec(TritonFusionAnalysis::Scope::OUTPUT, - computation->parameter_instruction(0), 1), - nullptr); -} - -TEST_F(TritonSoftmaxAnalysisTest, ReduceOfNonRowDimensionIsNotSupported) { - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(R"( -HloModule t -add { - p0 = f32[] parameter(0) - p1 = f32[] parameter(1) - ROOT add = f32[] add(p0, p1) -} - -triton_softmax_computation { - param_0 = f32[8,4,127]{2,1,0} parameter(0) - constant = f32[] constant(0) - ROOT reduce = f32[4,127]{1,0} reduce(param_0, constant), dimensions={0}, to_apply=add -} - -ENTRY main { - param_0 = f32[8,4,127]{2,1,0} parameter(0) - ROOT fusion = f32[4,127]{1,0} fusion(param_0), kind=kCustom, - calls=triton_softmax_computation, - backend_config={"kind":"__triton"} -})")); - - const HloComputation* computation = - module->entry_computation()->root_instruction()->called_computations()[0]; - const auto analysis = TritonFusionAnalysis::Execute(*computation); - EXPECT_FALSE(analysis.ok()); -} - -TEST_F(TritonSoftmaxAnalysisTest, PadWithinTritonSoftmaxIsNotSupported) { - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(R"( -HloModule t - -add { - p0 = f32[] parameter(0) - p1 = f32[] parameter(1) - ROOT add = f32[] add(p0, p1) -} - -triton_softmax_computation { - param_1 = f32[4,127]{1,0} parameter(0) - constant_0 = f32[] constant(0) - reduce = f32[4]{0} reduce(param_1, constant_0), dimensions={1}, to_apply=add - broadcast = f32[4,127]{1,0} broadcast(reduce), dimensions={0} - ROOT pad = f32[8,127]{1,0} pad(broadcast, constant_0), padding=0_4x0_0 -} - -ENTRY main { - param_0 = f32[4,127]{1,0} parameter(0) - ROOT fusion = f32[8,127]{1,0} fusion(param_0), kind=kCustom, - calls=triton_softmax_computation, - backend_config={"kind":"__triton"} -})")); - - const HloComputation* computation = - module->entry_computation()->root_instruction()->called_computations()[0]; - const auto analysis = TritonFusionAnalysis::Execute(*computation); 
- EXPECT_FALSE(analysis.ok()); -} - -TEST_F(TritonSoftmaxAnalysisTest, - BitcastWhichSplitsBatchAndReduceDimensionsIsNotSupported) { - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(R"( -add { - p0 = f32[] parameter(0) - p1 = f32[] parameter(1) - ROOT add = f32[] add(p0, p1) -} - -triton_softmax_computation { - param_0 = f32[8,16129]{1,0} parameter(0) - bitcast = f32[8,127,127]{2,1,0} bitcast(param_0) - constant = f32[] constant(0) - reduce = f32[8,127]{1,0} reduce(bitcast, f32[] constant), dimensions={2}, to_apply=add - ROOT broadcast = f32[8,127,127]{2,1,0} broadcast(reduce), dimensions={0,1} -} - -ENTRY main { - param_1 = f32[8,16129]{1,0} parameter(0) - ROOT fusion = f32[8,127,127]{2,1,0} fusion(param_1), kind=kCustom, - calls=triton_softmax_computation, - backend_config={"kind":"__triton"} -})")); - - const HloComputation* computation = - module->entry_computation()->root_instruction()->called_computations()[0]; - const auto analysis = TritonFusionAnalysis::Execute(*computation); - EXPECT_FALSE(analysis.ok()); -} - -TEST_F(TritonSoftmaxAnalysisTest, - BitcastWhichSplitsReduceDimensionIsSupported) { - // Clone of BitcastWhichSplitsBatchAndReduceDimensionsIsNotSupported, - // But in this case the split dimension can be fully tiled as a reduce dim. 
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(R"( -add { - p0 = f32[] parameter(0) - p1 = f32[] parameter(1) - ROOT add = f32[] add(p0, p1) -} - -triton_softmax_computation { - param_0 = f32[1,8,127,128]{3,2,1,0} parameter(0) - intermediate_bitcast = f32[8,127,2,64]{3,2,1,0} bitcast(param_0) - bitcast = f32[8,127,128]{2,1,0} bitcast(intermediate_bitcast) - constant = f32[] constant(0) - reduce = f32[8,127]{1,0} reduce(bitcast, constant), dimensions={2}, to_apply=add - ROOT broadcast = f32[8,127,128]{2,1 ,0} broadcast(reduce), dimensions={0,1} -} - -ENTRY main { - param_1 = f32[1,8,127,128]{3,2,1,0} parameter(0) - ROOT fusion = f32[8,127,128]{2,1,0} fusion(param_1), kind=kCustom, - calls=triton_softmax_computation, - backend_config={"kind":"__triton"} -})")); - - const HloComputation* computation = - module->entry_computation()->root_instruction()->called_computations()[0]; - TF_ASSERT_OK_AND_ASSIGN(const auto analysis, - TritonFusionAnalysis::Execute(*computation)); -} - -TEST_F(TritonSoftmaxAnalysisTest, - BitcastWhichDoesNotAffectReduceDimIsSupported) { - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(R"( -add { - p0 = f32[] parameter(0) - p1 = f32[] parameter(1) - ROOT add = f32[] add(p0, p1) -} - -triton_softmax_computation { - param_0 = f32[1,2,4,127,128]{4,3,2,1,0} parameter(0) - bitcast = f32[8,127,128]{2,1,0} bitcast(param_0) - constant = f32[] constant(0) - reduce = f32[8,127]{1,0} reduce(bitcast, constant), dimensions={2}, to_apply=add - ROOT broadcast = f32[8,127,128]{2,1,0} broadcast(reduce), dimensions={0,1} -} - -ENTRY main { - param_1 = f32[1,2,4,127,128]{4,3,2,1,0} parameter(0) - ROOT fusion = f32[8,127,128]{2,1,0} fusion(param_1), kind=kCustom, - calls=triton_softmax_computation, - backend_config={"kind":"__triton"} -})")); - - const HloComputation* computation = - module->entry_computation()->root_instruction()->called_computations()[0]; - TF_ASSERT_OK_AND_ASSIGN(const auto 
analysis, - TritonFusionAnalysis::Execute(*computation)); -} - -TEST_F(TritonSoftmaxAnalysisTest, SliceWithinTritonSoftmaxIsNotSupported) { - // Slices cannot yet be tiled into triton softmax (b/316637896) because they - // cannot be emitted. - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(R"( -HloModule t - -add { - p0 = f32[] parameter(0) - p1 = f32[] parameter(1) - ROOT add = f32[] add(p0, p1) -} - -triton_softmax_computation { - param_0 = f32[27,260]{1,0} parameter(0) - slice = f32[4,127]{1,0} slice(param_0), slice={[7:27:5], [6:260:2]} - constant_0 = f32[] constant(0) - reduce = f32[4]{0} reduce(slice, constant_0), dimensions={1}, to_apply=add - ROOT broadcast = f32[4,127]{1,0} broadcast(reduce), dimensions={0} -} - -ENTRY main { - param_0 = f32[27,260]{1,0} parameter(0) - ROOT fusion = f32[4,127]{1,0} fusion(param_0), kind=kCustom, - calls=triton_softmax_computation, - backend_config={"kind":"__triton"} -})")); - - const HloComputation* computation = - module->entry_computation()->root_instruction()->called_computations()[0]; - const auto analysis = TritonFusionAnalysis::Execute(*computation); - EXPECT_FALSE(analysis.ok()); -} - -TEST_F(TritonSoftmaxAnalysisTest, ProducerConsumerFusion) { - TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"( -HloModule t -add { - Arg_0 = f32[] parameter(0) - Arg_1 = f32[] parameter(1) - ROOT add = f32[] add(Arg_0, Arg_1) -} - -producer_computation { - parameter_0 = f32[125] parameter(0) - ROOT broadcast = f32[125,127] broadcast(parameter_0), dimensions={0} -} - -triton_softmax_computation { - parameter_0 = f32[125,127] parameter(0) - multiply_0 = f32[125,127] multiply(parameter_0, parameter_0) - constant_0 = f32[] constant(0) - reduce_0 = f32[125] reduce(multiply_0, constant_0), dimensions={1}, to_apply=add - broadcast_4 = f32[125,127] broadcast(reduce_0), dimensions={0} - ROOT multiply = f32[125,127] multiply(multiply_0, broadcast_4) -} - -ENTRY main { - param_0 = f32[125] 
parameter(0) - param_1 = f32[125,127] parameter(1) - producer_fusion = f32[125,127] fusion(param_0), kind=kLoop, calls=producer_computation - ROOT triton_softmax = f32[125,127] fusion(producer_fusion), kind=kCustom, - calls=triton_softmax_computation, - backend_config={"fusion_backend_config": {"kind":"__triton"}} -})")); - - auto consumer = module->entry_computation()->root_instruction(); - auto producer = consumer->operand(0); - - EXPECT_TRUE( - TritonFusionAnalysis::ExecuteForProducerConsumer(*producer, *consumer) - .ok()); -} - TEST_F(TritonDotAnalysisTest, PadWithTrivialDimension) { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(R"( diff --git a/third_party/xla/xla/service/gpu/triton_tiling_propagation.cc b/third_party/xla/xla/service/gpu/triton_tiling_propagation.cc index 55de5b9958ed59..905adf3c7c57e1 100644 --- a/third_party/xla/xla/service/gpu/triton_tiling_propagation.cc +++ b/third_party/xla/xla/service/gpu/triton_tiling_propagation.cc @@ -173,23 +173,6 @@ using FragmentOrders = DimensionOrder::FragmentOrders; return dim_order; } -/*static*/ DimensionOrder DimensionOrder::FromSoftmaxRoot( - const HloInstruction& hlo) { - DimensionOrder dim_order; - dim_order.tensor_fragments_order_.reserve(hlo.shape().rank()); - dim_order.dim_fragments_orders_[kSoftmaxReductionDimension].push_back( - dim_order.tensor_fragments_order_.size()); - dim_order.tensor_fragments_order_.push_back( - Fragment{kSoftmaxReductionDimension, hlo.shape().dimensions_minor(0)}); - for (int i = 1; i < hlo.shape().rank(); ++i) { - dim_order.dim_fragments_orders_[kSoftmaxBatchDimension].push_back( - dim_order.tensor_fragments_order_.size()); - dim_order.tensor_fragments_order_.push_back( - Fragment{kSoftmaxBatchDimension, hlo.shape().dimensions_minor(i)}); - } - return dim_order; -} - std::string DimensionOrder::Fragment::ToString() const { return absl::StrCat(dst_dim_number_, ":", count_, ":", slice_start_, "-", sliced_count_); @@ -303,8 +286,14 @@ Int64OrError 
CombineSplitDimMajorPartSizeReqs(int64_t a, int64_t b) { return FusionDecision("Conflicting splits of splittable dimension"); } -RequirementsOrError CombineDotRequirements(DotRequirements a, - DotRequirements b) { +} // namespace + +DotRequirementsOrError CombineDotRequirements( + DotRequirements a, DotRequirementsOrError b_or_error) { + if (std::holds_alternative(b_or_error)) { + return b_or_error; + } + const DotRequirements& b = std::get(b_or_error); Int64OrError combined_size_req = CombineSplitDimMajorPartSizeReqs(a.splittable_dimension_major_part_size, b.splittable_dimension_major_part_size); @@ -313,37 +302,14 @@ RequirementsOrError CombineDotRequirements(DotRequirements a, } return DotRequirements(std::get(combined_size_req)); } - -RequirementsOrError CombineSoftmaxRequirements(SoftmaxRequirements a, - SoftmaxRequirements b) { - // SoftmaxRequirements is an empty class for now. - return a; -} - -} // namespace - -RequirementsOrError CombineRequirements(Requirements a, - RequirementsOrError b_or_error) { - if (std::holds_alternative(b_or_error)) { - return b_or_error; - } - const Requirements& b = std::get(b_or_error); - if (std::holds_alternative(b)) { - return CombineDotRequirements(std::get(a), - std::get(b)); - } - return CombineSoftmaxRequirements(std::get(a), - std::get(b)); -} - namespace { // If the dimension order is supported by the triton emitters, this returns // which requirements does this order impose on the fusion. // // All subdimensions within a dimension have to be ordered. 
-RequirementsOrError GetRequirementsIfSupportedOrder( - const DimensionOrder& order, const HeroProperties& properties) { +DotRequirementsOrError GetRequirementsIfSupportedOrder( + const DimensionOrder& order, const DotProperties& properties) { VLOG(8) << order.ToString(); int64_t split_dim_major_part = kNoSplitRequirement; const Fragments& tensor_dim_fragments = order.TensorFragmentsOrder(); @@ -380,14 +346,11 @@ RequirementsOrError GetRequirementsIfSupportedOrder( ++group_counter; if (group_counter > 1) { - if (!std::holds_alternative(properties)) { - return "Splitting a dimension is not supported for Softmax."; - } // Only the dimension indicated by `splittable_dimension_index` (if any) // can be split physically once by other dimensions. Other ones can be // only split logically. const int splittable_dimension_index = - std::get(properties).splittable_dimension_index; + properties.splittable_dimension_index; if (dim_index == splittable_dimension_index) { if (group_counter == 2) { if (split_dim_major_part != kNoSplitRequirement && @@ -408,40 +371,34 @@ RequirementsOrError GetRequirementsIfSupportedOrder( } } - if (std::holds_alternative(properties)) { - return DotRequirements(split_dim_major_part); - } - return SoftmaxRequirements{}; + return DotRequirements(split_dim_major_part); } // Apply GetRequirementsIfSupportedOrder() to all known // dimension orders around `hlo` and combine the result. -RequirementsOrError GetRequirementsIfSupportedOrders( +DotRequirementsOrError GetRequirementsIfSupportedOrders( const HloInstruction& hlo, const DimOrderMap& dim_orders, - const HeroProperties& properties) { - const Requirements empty_requirements = - std::holds_alternative(properties) - ? 
Requirements(DotRequirements(kNoSplitRequirement)) - : Requirements(SoftmaxRequirements{}); + const DotProperties& properties) { + const DotRequirements empty_requirements(kNoSplitRequirement); auto get_requirements = - [&](const HloInstruction& instr) -> RequirementsOrError { + [&](const HloInstruction& instr) -> DotRequirementsOrError { if (auto it = dim_orders.find(&instr); it != dim_orders.end()) { return GetRequirementsIfSupportedOrder(it->second, properties); } return empty_requirements; }; - Requirements requirements = empty_requirements; + DotRequirements requirements = empty_requirements; for (const HloInstruction* operand : hlo.operands()) { - RequirementsOrError requirements_or_error = - CombineRequirements(requirements, get_requirements(*operand)); + DotRequirementsOrError requirements_or_error = + CombineDotRequirements(requirements, get_requirements(*operand)); if (std::holds_alternative(requirements_or_error)) { return requirements_or_error; } - requirements = std::get(requirements_or_error); + requirements = std::get(requirements_or_error); } - return CombineRequirements(requirements, get_requirements(hlo)); + return CombineDotRequirements(requirements, get_requirements(hlo)); } DimOrderMap GetPropagatedDimOrdersForElementwise( @@ -496,7 +453,7 @@ const HloInstruction& GetDestHlo(const HloInstruction& hlo, DimOrderMapOrError GetPropagatedDimOrdersForBitcast( const HloInstruction& hlo, const TransformDirection direction, - const DimensionOrder& src_dim_order, const HeroProperties& properties) { + const DimensionOrder& src_dim_order, const DotProperties& properties) { const HloInstruction& dst = GetDestHlo(hlo, direction); const Shape& dst_shape = dst.shape(); const Fragments& src_fragments_order = src_dim_order.TensorFragmentsOrder(); @@ -504,7 +461,6 @@ DimOrderMapOrError GetPropagatedDimOrdersForBitcast( DimensionOrder& dst_dim_order = dst_dim_orders.insert({&dst, DimensionOrder()}).first->second; Fragments& dst_fragments_order = 
dst_dim_order.TensorFragmentsOrder(); - bool dst_remainder_comes_from_reduce_dim = false; // Size of not yet assigned part of current target dimension. int64_t dst_remaining_size = 1; // Track destination fragments created from a source one. @@ -520,38 +476,6 @@ DimOrderMapOrError GetPropagatedDimOrdersForBitcast( dst_fragments_order.push_back(fragment); src_to_dst[&*src_dim].push_back(dst_fragments_order.size() - 1); }; - if (std::holds_alternative(properties) && - src_dim->dst_dim_number() == - std::get(properties).softmax_batch_dimension) { - // Special handling for softmax batch dimension: allow arbitrary reshapes - // on it because it's guaranteed by the construction of the fusion to have - // no physical alterations like transposes. - // Find a continuous group of fragments corresponding to this dimension in - // the source and assign the corresponding size in fragments of the - // destination ignoring the source ones. - - // If there is dst_remaining_size leftover from our previous src_dim, - // and it came from a reduce dim, we cannot tile it in a batch dim. 
- if (dst_remainder_comes_from_reduce_dim) { - return R"(Unsupported bitcast splits dimension between batch and - reduction dimensions in softmax)"; - } - - dst_remaining_size = src_dim->full_count(); - while (src_dim + 1 != src_fragments_order.cend() && - (src_dim + 1)->dst_dim_number() == src_dim->dst_dim_number()) { - ++src_dim; - dst_remaining_size *= src_dim->full_count(); - } - while (dst_remaining_size > 1) { - CHECK(dst_dim_it != dst_dim_end); - add_new_fragment(Fragment{src_dim->dst_dim_number(), - dst_shape.dimensions(*dst_dim_it)}); - dst_remaining_size /= dst_shape.dimensions(*dst_dim_it); - ++dst_dim_it; - } - continue; - } if (dst_remaining_size >= src_dim->full_count()) { if (dst_remaining_size % src_dim->full_count()) { return "Unsupported bitcast"; @@ -604,16 +528,6 @@ DimOrderMapOrError GetPropagatedDimOrdersForBitcast( ++dst_dim_it; } } - - // We cannot tile a single dim with fragments across both reduce and batch - // dimensions. As such, if we have a dst remainder leftover from tiling a - // src fragment on the reduce dimension in softmax, we must only tile it - // with other src_dim fragments on the reduce dimension. - dst_remainder_comes_from_reduce_dim = - (dst_remaining_size > 1 && - std::holds_alternative(properties) && - src_dim->dst_dim_number() == std::get(properties) - .softmax_reduction_dimension); } CHECK_EQ(dst_remaining_size, 1); @@ -653,7 +567,7 @@ DimOrderMapOrError GetPropagatedDimOrdersForBitcast( // and the way to handle layouts. DimOrderMapOrError GetPropagatedDimOrdersForDimAlteringOp( const HloInstruction& hlo, const TransformDirection direction, - const DimensionOrder& src_dim_order, const HeroProperties& properties) { + const DimensionOrder& src_dim_order, const DotProperties& properties) { // Temporary storage for new fragments local to this function. 
// Please keep this as the first local variable of this function, with type // std::list to make sure that all pointers to elements of this remain valid @@ -742,19 +656,6 @@ DimOrderMapOrError GetPropagatedDimOrdersForDimAlteringOp( reduce->operand(0)->shape().rank() - 1) { return FusionDecision("Only row reductions are supported."); } - for (int i = 0; i < dst_logical.size(); ++i) { - if (i == reduce->dimensions().front()) { - // This way to assign the reduction dimension will only work for - // softmax fusions with known patterns for now. Generally a reduction - // should create a new tiled dimension. - dst_logical[i] = {&new_fragments.emplace_back( - std::get(properties) - .softmax_reduction_dimension, - reduce->operand(0)->shape().dimensions(i))}; - } else { - dst_logical[i] = src_logical[i]; - } - } } else if (hlo.opcode() == HloOpcode::kConcatenate) { dst_logical.resize(src_logical.size()); for (int i = 0; i < src_logical.size(); ++i) { @@ -908,7 +809,7 @@ DimOrderMapOrError GetPropagatedDimOrdersForDimAlteringOp( DimOrderMapOrError GetPropagatedDimOrders(const HloInstruction& hlo, const TransformDirection direction, const DimensionOrder& src_dim_order, - const HeroProperties& properties) { + const DotProperties& properties) { VLOG(7) << "Analyzing " << hlo.ToString(); if (hlo.opcode() != HloOpcode::kParameter && direction == TransformDirection::kOutputToInput && @@ -932,20 +833,9 @@ DimOrderMapOrError GetPropagatedDimOrders(const HloInstruction& hlo, } return GetPropagatedDimOrdersForDimAlteringOp(hlo, direction, src_dim_order, properties); - } else if (hlo.opcode() == HloOpcode::kReduce) { - if (!std::holds_alternative(properties)) { - return "Reductions are not supported in GEMM fusions yet."; - } - if (direction != TransformDirection::kOutputToInput) { - return "Unsupported direction of reduction."; - } - return GetPropagatedDimOrdersForDimAlteringOp(hlo, direction, src_dim_order, - properties); } else if (hlo.opcode() == HloOpcode::kPad) { - if 
(std::holds_alternative(properties)) { - return "Pad ops are only supported when they are generated as part of " - "the split-k transform of dot fusions."; - } + // Pad ops are only supported when they are generated as part of the split-k + // transform of dot fusions. if (direction != TransformDirection::kOutputToInput) { return "Unsupported pad direction."; } @@ -960,9 +850,6 @@ DimOrderMapOrError GetPropagatedDimOrders(const HloInstruction& hlo, properties); } else if (hlo.opcode() == HloOpcode::kSlice) { // TODO(b/316637896) Add support for slices in softmax. - if (std::holds_alternative(properties)) { - return "Slices are not supported in Softmax fusions yet."; - } if (direction != TransformDirection::kOutputToInput) { return "Unsupported slice direction."; } @@ -971,12 +858,6 @@ DimOrderMapOrError GetPropagatedDimOrders(const HloInstruction& hlo, properties); } else if (hlo.opcode() == HloOpcode::kDynamicSlice && direction == TransformDirection::kOutputToInput) { - // We handle the dynamic slice within EmitTensorPointer, which is only - // used for GEMM fusions. 
- if (!std::holds_alternative(properties)) { - return "Dynamic slices for now are only supported in GEMM fusions."; - } - if (CodegenDecision decision = legacy_triton::IsTritonSupportedDynamicSlice( *Cast(&hlo)); !decision.CanFuse()) { @@ -994,12 +875,7 @@ DimOrderMapOrError GetPropagatedDimOrders(const HloInstruction& hlo, properties); } else if (hlo.opcode() == HloOpcode::kConcatenate && direction == TransformDirection::kOutputToInput) { - if (!std::holds_alternative(properties)) { - return "Concatenations for now are only supported in GEMM fusions."; - } - - int64_t noncontracting_dim_label = - std::get(properties).noncontracting_dimension; + int64_t noncontracting_dim_label = properties.noncontracting_dimension; const FragmentOrders& src_dim_fragments_orders = src_dim_order.DimFragmentsOrders(); @@ -1098,7 +974,7 @@ FusionDecision IsConversionWorthFusing(const HloInstruction& input, DimOrdersAndReqsOrError GetPropagatedDimOrdersAndRequirements( const HloInstruction& hlo, const DimensionOrder& src_dim_order, - TransformDirection direction, const HeroProperties& properties) { + TransformDirection direction, const DotProperties& properties) { DimOrderMapOrError propagated_dim_orders_or_error = GetPropagatedDimOrders(hlo, direction, src_dim_order, properties); if (std::holds_alternative(propagated_dim_orders_or_error)) { @@ -1106,13 +982,13 @@ DimOrdersAndReqsOrError GetPropagatedDimOrdersAndRequirements( } DimOrderMap propagated_dim_orders = std::move(std::get(propagated_dim_orders_or_error)); - RequirementsOrError requirements_or_error = + DotRequirementsOrError requirements_or_error = GetRequirementsIfSupportedOrders(hlo, propagated_dim_orders, properties); if (std::holds_alternative(requirements_or_error)) { return std::get(requirements_or_error); } return DimOrdersAndReqs{propagated_dim_orders, - std::get(requirements_or_error)}; + std::get(requirements_or_error)}; } DimOrdersAndReqsOrError @@ -1121,7 +997,7 @@ 
GetPropagatedDimOrdersAndRequirementsIfProfitablyFusible( const std::optional& src_operand_index, const DimensionOrder& src_dim_order, const se::GpuComputeCapability& gpu_version, - const HeroProperties& properties) { + const DotProperties& properties) { CHECK_EQ(transform_direction == TransformDirection::kInputToOutput, src_operand_index.has_value()); diff --git a/third_party/xla/xla/service/gpu/triton_tiling_propagation.h b/third_party/xla/xla/service/gpu/triton_tiling_propagation.h index 87ff11ae7c7415..ac16cdf7a85190 100644 --- a/third_party/xla/xla/service/gpu/triton_tiling_propagation.h +++ b/third_party/xla/xla/service/gpu/triton_tiling_propagation.h @@ -152,19 +152,9 @@ namespace triton_fusion { // instructions between source and target. class DimensionOrder { public: - // Softmax fusions have a fixed tiling scheme. These numbers are chosen to - // reflect that reductions in softmax fusions currently happen on the minor- - // most dimension (dimensions_minor(0)) and the rest (1+) is treated as a - // single non-tiled batch dimension. The numbers have to match those the - // emitter uses in the queries to the analysis. - static constexpr int kSoftmaxReductionDimension = 0; - static constexpr int kSoftmaxBatchDimension = 1; - static DimensionOrder FromDotOperandOrOutput( const HloInstruction& hlo, int split_k_dimension_index = -1); - static DimensionOrder FromSoftmaxRoot(const HloInstruction& hlo); - // Description of a continuous fragment of one dimension of a tensor. class Fragment { public: @@ -238,13 +228,6 @@ struct DotProperties { // Currently typically LHS non-contracting one. const int splittable_dimension_index; }; -struct SoftmaxProperties { - const int softmax_reduction_dimension; - const int softmax_batch_dimension; -}; -// HeroProperties depend only on the hero op and they don't change as we -// change the fusion. -using HeroProperties = std::variant; // A special value for splittable_dimension_major_part_size. 
inline constexpr int kNoSplitRequirement = 1; @@ -258,13 +241,11 @@ struct DotRequirements { // dimension must be the given value. int64_t splittable_dimension_major_part_size; }; -struct SoftmaxRequirements {}; -// Requirements can change depending on what we fuse. -using Requirements = std::variant; -using RequirementsOrError = std::variant; -RequirementsOrError CombineRequirements(Requirements a, - RequirementsOrError b_or_error); +using DotRequirementsOrError = std::variant; + +DotRequirementsOrError CombineDotRequirements( + DotRequirements a, DotRequirementsOrError b_or_error); enum class TransformDirection { kInputToOutput, kOutputToInput }; using DimOrderMap = absl::flat_hash_map; @@ -274,7 +255,7 @@ using DimOrderMapOrError = std::variant; // dimension orders through an HLO. struct DimOrdersAndReqs { DimOrderMap dim_orders; - Requirements requirements; + DotRequirements requirements; }; using DimOrdersAndReqsOrError = std::variant; @@ -284,7 +265,7 @@ using DimOrdersAndReqsOrError = std::variant; // fusion. 
DimOrdersAndReqsOrError GetPropagatedDimOrdersAndRequirements( const HloInstruction& hlo, const DimensionOrder& src_dim_order, - TransformDirection direction, const HeroProperties& properties); + TransformDirection direction, const DotProperties& properties); // If fusing the instruction is possible *and profitable* then it propagates // the `src_dim_order` (describing one side of `hlo`) to the other side and // returns those dim orders and the requirements that they impose on the @@ -298,7 +279,7 @@ GetPropagatedDimOrdersAndRequirementsIfProfitablyFusible( const std::optional& src_operand_index, const DimensionOrder& src_dim_order, const se::GpuComputeCapability& gpu_version, - const HeroProperties& properties); + const DotProperties& properties); } // namespace triton_fusion } // namespace gpu From 60e8d960d3bb8165e24393db3a9d66f4bb5dcf51 Mon Sep 17 00:00:00 2001 From: Kyle Lucke Date: Wed, 19 Jun 2024 07:58:58 -0700 Subject: [PATCH 020/256] Stop using xla/statusor.h now that it just contains an alias for absl::Status. In some situations, this meant also changing unrelated files to directly include tsl/platform/statusor.h to get the definitions for TF_ASSIGN_OR_RETURN, etc., where they were getting transitively included for free. 
PiperOrigin-RevId: 644748920 --- third_party/xla/xla/pjrt/BUILD | 28 +++++++++---------- third_party/xla/xla/pjrt/event_pool.cc | 1 + third_party/xla/xla/pjrt/event_pool.h | 2 +- third_party/xla/xla/pjrt/interpreter_device.h | 2 +- third_party/xla/xla/pjrt/layout_mode.cc | 2 +- third_party/xla/xla/pjrt/layout_mode.h | 2 +- third_party/xla/xla/pjrt/pjrt_api.cc | 2 +- third_party/xla/xla/pjrt/pjrt_api.h | 3 +- third_party/xla/xla/pjrt/pjrt_c_api_client.cc | 1 - third_party/xla/xla/pjrt/pjrt_c_api_client.h | 2 +- third_party/xla/xla/pjrt/pjrt_client.h | 1 - third_party/xla/xla/pjrt/pjrt_client_test.cc | 2 +- .../xla/xla/pjrt/pjrt_compiler_test.cc | 1 - third_party/xla/xla/pjrt/pjrt_executable.cc | 2 +- third_party/xla/xla/pjrt/pjrt_layout.h | 2 +- .../xla/pjrt/pjrt_stream_executor_client.cc | 1 - .../xla/pjrt/pjrt_stream_executor_client.h | 2 +- third_party/xla/xla/pjrt/status_casters.h | 2 +- .../xla/pjrt/stream_executor_executable.cc | 2 +- third_party/xla/xla/pjrt/transpose.cc | 3 +- third_party/xla/xla/pjrt/transpose.h | 2 +- third_party/xla/xla/pjrt/transpose_test.cc | 1 + third_party/xla/xla/pjrt/utils.cc | 2 +- third_party/xla/xla/pjrt/utils.h | 2 +- 24 files changed, 34 insertions(+), 36 deletions(-) diff --git a/third_party/xla/xla/pjrt/BUILD b/third_party/xla/xla/pjrt/BUILD index 3ec761e5502451..aa7dfdbdc39a11 100644 --- a/third_party/xla/xla/pjrt/BUILD +++ b/third_party/xla/xla/pjrt/BUILD @@ -44,10 +44,11 @@ cc_library( hdrs = ["event_pool.h"], deps = [ "//xla:status_macros", - "//xla:statusor", "//xla:types", "//xla/stream_executor", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/synchronization", + "@local_tsl//tsl/platform:statusor", ], ) @@ -147,12 +148,13 @@ cc_library( srcs = ["pjrt_api.cc"], hdrs = ["pjrt_api.h"], deps = [ - "//xla:statusor", "//xla/pjrt/c:pjrt_c_api_hdrs", "//xla/pjrt/c:pjrt_c_api_helpers", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status", + 
"@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@local_tsl//tsl/platform", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", ], @@ -187,7 +189,6 @@ cc_library( ":utils", "//xla:literal", "//xla:shape_util", - "//xla:statusor", "//xla:util", "//xla:xla_data_proto_cc", "//xla/client:xla_computation", @@ -218,13 +219,13 @@ cc_library( deps = [ ":pjrt_client", "//xla:shape_util", - "//xla:statusor", "//xla:test", "//xla:xla_data_proto_cc", "//xla/client:xla_builder", "//xla/client:xla_computation", "//xla/service:hlo_parser", "//xla/tests:literal_test_util", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:statusor", @@ -245,7 +246,6 @@ cc_library( ":pjrt_layout", "//xla:shape_layout", "//xla:shape_util", - "//xla:statusor", "//xla:util", "//xla:xla_data_proto_cc", "//xla:xla_proto_cc", @@ -322,7 +322,6 @@ xla_cc_test( ":pjrt_client", ":pjrt_compiler", ":pjrt_device_description", - "//xla:statusor", "//xla/client:xla_computation", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/log", @@ -352,7 +351,6 @@ cc_library( ":layout_mode", "//xla:shape_util", "//xla:status_macros", - "//xla:statusor", "//xla:util", "//xla:xla_data_proto_cc", "//xla/client:executable_build_options", @@ -363,6 +361,7 @@ cc_library( "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@llvm-project//mlir:FuncDialect", @@ -381,10 +380,10 @@ cc_library( visibility = ["//xla:friends"], deps = [ "//xla:shape_util", - "//xla:statusor", "//xla/service:hlo_parser", "@com_google_absl//absl/hash", "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@local_tsl//tsl/platform:casts", 
"@local_tsl//tsl/platform:statusor", @@ -398,9 +397,9 @@ cc_library( visibility = ["//xla:friends"], deps = [ "//xla:shape_util", - "//xla:statusor", "//xla/service:hlo_parser", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", ], ) @@ -436,7 +435,6 @@ cc_library( ":pjrt_common", ":pjrt_executable", ":stream_executor_executable_proto_cc", - "//xla:statusor", "//xla/hlo/ir:hlo", "//xla/service:compiler", "@com_google_absl//absl/container:flat_hash_map", @@ -473,7 +471,6 @@ cc_library( "//xla:shape_tree", "//xla:shape_util", "//xla:status_macros", - "//xla:statusor", "//xla:util", "//xla:xla_data_proto_cc", "//xla/client:executable_build_options", @@ -555,10 +552,10 @@ cc_library( hdrs = ["interpreter_device.h"], deps = [ ":pjrt_stream_executor_client", - "//xla:statusor", "//xla/client:client_library", "//xla/service:interpreter_plugin", "//xla/service:platform_util", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", ], ) @@ -702,18 +699,19 @@ cc_library( "//xla:compiler_macros", "//xla:ef57", "//xla:permutation_util", - "//xla:statusor", "//xla:util", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", "@com_google_absl//absl/types:variant", "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/profiler/lib:traceme", ], ) @@ -732,6 +730,7 @@ xla_cc_test( "@com_google_absl//absl/numeric:int128", "@eigen_archive//:eigen3", "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test_benchmark", "@local_tsl//tsl/platform:test_main", "@local_tsl//tsl/protobuf:error_codes_proto_impl_cc", 
@@ -755,7 +754,6 @@ cc_library( ":pjrt_layout", "//xla:literal", "//xla:shape_util", - "//xla:statusor", "//xla:util", "//xla:xla_data_proto_cc", "//xla:xla_proto_cc", @@ -947,8 +945,8 @@ cc_library( visibility = [":friends"], deps = [ ":exceptions", - "//xla:statusor", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@local_tsl//tsl/platform:macros", ], ) diff --git a/third_party/xla/xla/pjrt/event_pool.cc b/third_party/xla/xla/pjrt/event_pool.cc index fadfe8bd56ee79..80e343ba9335cf 100644 --- a/third_party/xla/xla/pjrt/event_pool.cc +++ b/third_party/xla/xla/pjrt/event_pool.cc @@ -20,6 +20,7 @@ limitations under the License. #include "absl/synchronization/mutex.h" #include "xla/status_macros.h" +#include "tsl/platform/statusor.h" namespace xla { diff --git a/third_party/xla/xla/pjrt/event_pool.h b/third_party/xla/xla/pjrt/event_pool.h index 2ce9bc7cbcc74e..85594f781a9d26 100644 --- a/third_party/xla/xla/pjrt/event_pool.h +++ b/third_party/xla/xla/pjrt/event_pool.h @@ -19,8 +19,8 @@ limitations under the License. #include #include +#include "absl/status/statusor.h" #include "absl/synchronization/mutex.h" -#include "xla/statusor.h" #include "xla/stream_executor/stream_executor.h" #include "xla/types.h" diff --git a/third_party/xla/xla/pjrt/interpreter_device.h b/third_party/xla/xla/pjrt/interpreter_device.h index 7952f362cf9a6b..6b987a4e39c2e2 100644 --- a/third_party/xla/xla/pjrt/interpreter_device.h +++ b/third_party/xla/xla/pjrt/interpreter_device.h @@ -18,8 +18,8 @@ limitations under the License. #include +#include "absl/status/statusor.h" #include "xla/pjrt/pjrt_stream_executor_client.h" -#include "xla/statusor.h" namespace xla { diff --git a/third_party/xla/xla/pjrt/layout_mode.cc b/third_party/xla/xla/pjrt/layout_mode.cc index 1983d639155c03..84758c6ef7e293 100644 --- a/third_party/xla/xla/pjrt/layout_mode.cc +++ b/third_party/xla/xla/pjrt/layout_mode.cc @@ -20,10 +20,10 @@ limitations under the License. 
#include #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "xla/layout.h" #include "xla/service/hlo_parser.h" -#include "xla/statusor.h" namespace xla { diff --git a/third_party/xla/xla/pjrt/layout_mode.h b/third_party/xla/xla/pjrt/layout_mode.h index 3e8d208fa89859..6ee3f3dc2ca321 100644 --- a/third_party/xla/xla/pjrt/layout_mode.h +++ b/third_party/xla/xla/pjrt/layout_mode.h @@ -18,9 +18,9 @@ limitations under the License. #include +#include "absl/status/statusor.h" #include "xla/layout.h" #include "xla/shape.h" -#include "xla/statusor.h" namespace xla { diff --git a/third_party/xla/xla/pjrt/pjrt_api.cc b/third_party/xla/xla/pjrt/pjrt_api.cc index 3e8e78cbd0c973..73df2b3b6b7f17 100644 --- a/third_party/xla/xla/pjrt/pjrt_api.cc +++ b/third_party/xla/xla/pjrt/pjrt_api.cc @@ -31,9 +31,9 @@ limitations under the License. #include #include "absl/container/flat_hash_map.h" +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "xla/pjrt/c/pjrt_c_api_helpers.h" -#include "xla/statusor.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" diff --git a/third_party/xla/xla/pjrt/pjrt_api.h b/third_party/xla/xla/pjrt/pjrt_api.h index 8c8d807ee4a7d0..d2220627ba0072 100644 --- a/third_party/xla/xla/pjrt/pjrt_api.h +++ b/third_party/xla/xla/pjrt/pjrt_api.h @@ -17,9 +17,10 @@ limitations under the License. #define XLA_PJRT_PJRT_API_H_ #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "xla/pjrt/c/pjrt_c_api.h" -#include "xla/statusor.h" +#include "tsl/platform/platform.h" namespace pjrt { diff --git a/third_party/xla/xla/pjrt/pjrt_c_api_client.cc b/third_party/xla/xla/pjrt/pjrt_c_api_client.cc index 3ff6a75b5955f1..465eb8682f62d8 100644 --- a/third_party/xla/xla/pjrt/pjrt_c_api_client.cc +++ b/third_party/xla/xla/pjrt/pjrt_c_api_client.cc @@ -67,7 +67,6 @@ limitations under the License. 
#include "xla/service/hlo_module_config.h" #include "xla/shape.h" #include "xla/shape_util.h" -#include "xla/statusor.h" #include "xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.h" #include "xla/tsl/framework/allocator.h" #include "xla/util.h" diff --git a/third_party/xla/xla/pjrt/pjrt_c_api_client.h b/third_party/xla/xla/pjrt/pjrt_c_api_client.h index a8d2236930843a..73943657ab4d2e 100644 --- a/third_party/xla/xla/pjrt/pjrt_c_api_client.h +++ b/third_party/xla/xla/pjrt/pjrt_c_api_client.h @@ -32,6 +32,7 @@ limitations under the License. #include "absl/log/check.h" #include "absl/log/log.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" #include "absl/types/span.h" @@ -53,7 +54,6 @@ limitations under the License. #include "xla/service/computation_placer.h" #include "xla/service/hlo_cost_analysis.h" #include "xla/shape.h" -#include "xla/statusor.h" #include "xla/tsl/framework/allocator.h" #include "xla/util.h" #include "xla/xla_data.pb.h" diff --git a/third_party/xla/xla/pjrt/pjrt_client.h b/third_party/xla/xla/pjrt/pjrt_client.h index d9eea97b71cbdb..7e2281d2b08c2a 100644 --- a/third_party/xla/xla/pjrt/pjrt_client.h +++ b/third_party/xla/xla/pjrt/pjrt_client.h @@ -51,7 +51,6 @@ limitations under the License. #include "xla/service/hlo_cost_analysis.h" #include "xla/shape.h" #include "xla/shape_util.h" -#include "xla/statusor.h" #include "xla/tsl/framework/allocator.h" #include "xla/util.h" #include "xla/xla_data.pb.h" diff --git a/third_party/xla/xla/pjrt/pjrt_client_test.cc b/third_party/xla/xla/pjrt/pjrt_client_test.cc index e480c9cbdd98b3..cdaadf57295ca5 100644 --- a/third_party/xla/xla/pjrt/pjrt_client_test.cc +++ b/third_party/xla/xla/pjrt/pjrt_client_test.cc @@ -22,6 +22,7 @@ limitations under the License. 
#include #include +#include "absl/status/statusor.h" #include "absl/synchronization/blocking_counter.h" #include "absl/types/span.h" #include "xla/client/xla_builder.h" @@ -30,7 +31,6 @@ limitations under the License. #include "xla/service/hlo_parser.h" #include "xla/shape.h" #include "xla/shape_util.h" -#include "xla/statusor.h" #include "xla/test.h" #include "xla/tests/literal_test_util.h" #include "xla/xla_data.pb.h" diff --git a/third_party/xla/xla/pjrt/pjrt_compiler_test.cc b/third_party/xla/xla/pjrt/pjrt_compiler_test.cc index 14de671708a528..d73bba22f8277a 100644 --- a/third_party/xla/xla/pjrt/pjrt_compiler_test.cc +++ b/third_party/xla/xla/pjrt/pjrt_compiler_test.cc @@ -30,7 +30,6 @@ limitations under the License. #include "xla/pjrt/metrics.h" #include "xla/pjrt/pjrt_client.h" #include "xla/pjrt/pjrt_device_description.h" -#include "xla/statusor.h" #include "tsl/lib/monitoring/cell_reader.h" #include "tsl/platform/status_matchers.h" diff --git a/third_party/xla/xla/pjrt/pjrt_executable.cc b/third_party/xla/xla/pjrt/pjrt_executable.cc index 43638270db32d3..71b391dd39d21c 100644 --- a/third_party/xla/xla/pjrt/pjrt_executable.cc +++ b/third_party/xla/xla/pjrt/pjrt_executable.cc @@ -26,6 +26,7 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/log/log.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/numbers.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" @@ -43,7 +44,6 @@ limitations under the License. 
#include "xla/shape.h" #include "xla/shape_layout.h" #include "xla/shape_util.h" -#include "xla/statusor.h" #include "xla/util.h" #include "xla/xla.pb.h" #include "xla/xla_data.pb.h" diff --git a/third_party/xla/xla/pjrt/pjrt_layout.h b/third_party/xla/xla/pjrt/pjrt_layout.h index 60d25a3430b42b..e9c01b36c1fed6 100644 --- a/third_party/xla/xla/pjrt/pjrt_layout.h +++ b/third_party/xla/xla/pjrt/pjrt_layout.h @@ -22,10 +22,10 @@ limitations under the License. #include "absl/hash/hash.h" #include "absl/log/check.h" +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "xla/layout.h" #include "xla/service/hlo_parser.h" -#include "xla/statusor.h" #include "tsl/platform/casts.h" #include "tsl/platform/statusor.h" diff --git a/third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc b/third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc index 4c036f2ec74013..99f60350679752 100644 --- a/third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc +++ b/third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc @@ -129,7 +129,6 @@ limitations under the License. #include "xla/shape_tree.h" #include "xla/shape_util.h" #include "xla/status_macros.h" -#include "xla/statusor.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/device_memory_allocator.h" #include "xla/stream_executor/stream.h" diff --git a/third_party/xla/xla/pjrt/pjrt_stream_executor_client.h b/third_party/xla/xla/pjrt/pjrt_stream_executor_client.h index 7aac763b2d8c1a..64b7e60a618f4a 100644 --- a/third_party/xla/xla/pjrt/pjrt_stream_executor_client.h +++ b/third_party/xla/xla/pjrt/pjrt_stream_executor_client.h @@ -34,6 +34,7 @@ limitations under the License. 
#include "absl/container/inlined_vector.h" #include "absl/functional/any_invocable.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" @@ -63,7 +64,6 @@ limitations under the License. #include "xla/service/shaped_buffer.h" #include "xla/shape.h" #include "xla/shape_tree.h" -#include "xla/statusor.h" #include "xla/stream_executor/stream.h" #include "xla/tsl/framework/allocator.h" #include "xla/util.h" diff --git a/third_party/xla/xla/pjrt/status_casters.h b/third_party/xla/xla/pjrt/status_casters.h index 3e997fa98e14e8..5dd48e2f227fd8 100644 --- a/third_party/xla/xla/pjrt/status_casters.h +++ b/third_party/xla/xla/pjrt/status_casters.h @@ -17,8 +17,8 @@ limitations under the License. #define XLA_PJRT_STATUS_CASTERS_H_ #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "xla/pjrt/exceptions.h" -#include "xla/statusor.h" #include "tsl/platform/macros.h" namespace xla { diff --git a/third_party/xla/xla/pjrt/stream_executor_executable.cc b/third_party/xla/xla/pjrt/stream_executor_executable.cc index 766cd0d2147a07..ab82fdaf0c2ec1 100644 --- a/third_party/xla/xla/pjrt/stream_executor_executable.cc +++ b/third_party/xla/xla/pjrt/stream_executor_executable.cc @@ -20,9 +20,9 @@ limitations under the License. #include #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "xla/pjrt/stream_executor_executable.pb.h" #include "xla/service/compiler.h" -#include "xla/statusor.h" #include "tsl/platform/statusor.h" namespace xla { diff --git a/third_party/xla/xla/pjrt/transpose.cc b/third_party/xla/xla/pjrt/transpose.cc index a04e81e0e362ec..0921357d869e22 100644 --- a/third_party/xla/xla/pjrt/transpose.cc +++ b/third_party/xla/xla/pjrt/transpose.cc @@ -88,6 +88,7 @@ limitations under the License. 
#include "absl/base/optimization.h" #include "absl/container/inlined_vector.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" #include "absl/synchronization/blocking_counter.h" @@ -96,9 +97,9 @@ limitations under the License. #include "xla/ef57.h" #include "xla/permutation_util.h" #include "xla/pjrt/transpose_kernels.h" -#include "xla/statusor.h" #include "xla/util.h" #include "tsl/platform/logging.h" +#include "tsl/platform/statusor.h" #include "tsl/profiler/lib/traceme.h" namespace xla { diff --git a/third_party/xla/xla/pjrt/transpose.h b/third_party/xla/xla/pjrt/transpose.h index 0df1f8a7402289..469e4419c53431 100644 --- a/third_party/xla/xla/pjrt/transpose.h +++ b/third_party/xla/xla/pjrt/transpose.h @@ -35,10 +35,10 @@ limitations under the License. #include #include "absl/container/inlined_vector.h" +#include "absl/status/statusor.h" #include "absl/types/span.h" #include "absl/types/variant.h" #include "xla/pjrt/lru_cache.h" -#include "xla/statusor.h" namespace xla { diff --git a/third_party/xla/xla/pjrt/transpose_test.cc b/third_party/xla/xla/pjrt/transpose_test.cc index 011a1ca5eb4780..be702d206131c9 100644 --- a/third_party/xla/xla/pjrt/transpose_test.cc +++ b/third_party/xla/xla/pjrt/transpose_test.cc @@ -34,6 +34,7 @@ limitations under the License. #include "xla/shape_util.h" #include "xla/test.h" #include "xla/util.h" +#include "tsl/platform/statusor.h" #include "tsl/platform/test_benchmark.h" #include "tsl/platform/threadpool.h" #include "tsl/protobuf/error_codes.pb.h" diff --git a/third_party/xla/xla/pjrt/utils.cc b/third_party/xla/xla/pjrt/utils.cc index 4a115e402483f9..0117e2e1c70d73 100644 --- a/third_party/xla/xla/pjrt/utils.cc +++ b/third_party/xla/xla/pjrt/utils.cc @@ -28,6 +28,7 @@ limitations under the License. 
#include "absl/algorithm/container.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/numbers.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" @@ -52,7 +53,6 @@ limitations under the License. #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/status_macros.h" -#include "xla/statusor.h" #include "xla/util.h" #include "xla/xla_data.pb.h" #include "tsl/platform/cpu_info.h" diff --git a/third_party/xla/xla/pjrt/utils.h b/third_party/xla/xla/pjrt/utils.h index face81846e02ef..34c43b70a78c73 100644 --- a/third_party/xla/xla/pjrt/utils.h +++ b/third_party/xla/xla/pjrt/utils.h @@ -24,6 +24,7 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/types/span.h" #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "xla/client/executable_build_options.h" @@ -32,7 +33,6 @@ limitations under the License. #include "xla/pjrt/layout_mode.h" #include "xla/service/computation_placer.h" #include "xla/shape.h" -#include "xla/statusor.h" #include "xla/xla_data.pb.h" namespace xla { From d87bc7ee2412c085e6183e5a7453f0a2dbb70268 Mon Sep 17 00:00:00 2001 From: Goran Flegar Date: Wed, 19 Jun 2024 08:00:43 -0700 Subject: [PATCH 021/256] Avoid underflow in f2reduce PiperOrigin-RevId: 644749268 --- .../temporary/linear_layout_rank_fix.patch | 19 +++++++++++++++++++ third_party/triton/temporary/series.bzl | 4 +++- .../temporary/linear_layout_rank_fix.patch | 19 +++++++++++++++++++ .../third_party/triton/temporary/series.bzl | 4 +++- 4 files changed, 44 insertions(+), 2 deletions(-) create mode 100644 third_party/triton/temporary/linear_layout_rank_fix.patch create mode 100644 third_party/xla/third_party/triton/temporary/linear_layout_rank_fix.patch diff --git a/third_party/triton/temporary/linear_layout_rank_fix.patch b/third_party/triton/temporary/linear_layout_rank_fix.patch new file mode 100644 index 
00000000000000..5c0e821bfa2da5 --- /dev/null +++ b/third_party/triton/temporary/linear_layout_rank_fix.patch @@ -0,0 +1,19 @@ +Fixes for b/348174903 + +I'll upstream this. + +diff --git a/lib/Tools/LinearLayout.cpp b/lib/Tools/LinearLayout.cpp +--- a/lib/Tools/LinearLayout.cpp ++++ b/lib/Tools/LinearLayout.cpp +@@ -119,6 +119,11 @@ getInjectiveMat(const LinearLayout &layo + // outDim as columns. In other words, finds the number of linearly-independent + // bases for this output dimension. + int getMatrixRank(std::unique_ptr m, int numRows, int numCols) { ++ // f2reduce underflows if the number of cols is 0, return the rank early in ++ // this case. ++ if (numCols == 0) { ++ return 0; ++ } + // stride is specified in number of 64-bit words per row, and we pack our + // matrix so that there's only one uint64_t per row. + assert(numCols <= 64); diff --git a/third_party/triton/temporary/series.bzl b/third_party/triton/temporary/series.bzl index 53a5059fe432d7..2e2e98d90bb87b 100644 --- a/third_party/triton/temporary/series.bzl +++ b/third_party/triton/temporary/series.bzl @@ -5,4 +5,6 @@ These are created temporarily and should be moved to the first copybara workflow internal patch during the next triton integration process. """ -temporary_patch_list = [] +temporary_patch_list = [ + "//third_party/triton/temporary:linear_layout_rank_fix.patch", +] diff --git a/third_party/xla/third_party/triton/temporary/linear_layout_rank_fix.patch b/third_party/xla/third_party/triton/temporary/linear_layout_rank_fix.patch new file mode 100644 index 00000000000000..5c0e821bfa2da5 --- /dev/null +++ b/third_party/xla/third_party/triton/temporary/linear_layout_rank_fix.patch @@ -0,0 +1,19 @@ +Fixes for b/348174903 + +I'll upstream this. + +diff --git a/lib/Tools/LinearLayout.cpp b/lib/Tools/LinearLayout.cpp +--- a/lib/Tools/LinearLayout.cpp ++++ b/lib/Tools/LinearLayout.cpp +@@ -119,6 +119,11 @@ getInjectiveMat(const LinearLayout &layo + // outDim as columns. 
In other words, finds the number of linearly-independent + // bases for this output dimension. + int getMatrixRank(std::unique_ptr m, int numRows, int numCols) { ++ // f2reduce underflows if the number of cols is 0, return the rank early in ++ // this case. ++ if (numCols == 0) { ++ return 0; ++ } + // stride is specified in number of 64-bit words per row, and we pack our + // matrix so that there's only one uint64_t per row. + assert(numCols <= 64); diff --git a/third_party/xla/third_party/triton/temporary/series.bzl b/third_party/xla/third_party/triton/temporary/series.bzl index 53a5059fe432d7..2e2e98d90bb87b 100644 --- a/third_party/xla/third_party/triton/temporary/series.bzl +++ b/third_party/xla/third_party/triton/temporary/series.bzl @@ -5,4 +5,6 @@ These are created temporarily and should be moved to the first copybara workflow internal patch during the next triton integration process. """ -temporary_patch_list = [] +temporary_patch_list = [ + "//third_party/triton/temporary:linear_layout_rank_fix.patch", +] From 8e5c47c6994a9c1c3c5a35d1b76f6f0c0b9a6694 Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Wed, 19 Jun 2024 09:01:35 -0700 Subject: [PATCH 022/256] [XLA:GPU][NFC] Add missing `TODO` in `SymbolicTileTest.CanPropagateTileThroughNonTrivialSplitReshapeFromOutputToInput`. The test is missing constraints due to (for now) missing support for disjunctions. 
PiperOrigin-RevId: 644762145 --- third_party/xla/xla/service/gpu/model/symbolic_tile_test.cc | 2 ++ 1 file changed, 2 insertions(+) diff --git a/third_party/xla/xla/service/gpu/model/symbolic_tile_test.cc b/third_party/xla/xla/service/gpu/model/symbolic_tile_test.cc index 34cb38c265303d..92d92013955275 100644 --- a/third_party/xla/xla/service/gpu/model/symbolic_tile_test.cc +++ b/third_party/xla/xla/service/gpu/model/symbolic_tile_test.cc @@ -148,6 +148,8 @@ TEST_F(SymbolicTileTest, TEST_F(SymbolicTileTest, CanPropagateTileThroughNonTrivialSplitReshapeFromOutputToInput) { + // TODO(b/334043867): we need disjunctions here to derive the proper + // constraints for the tile sizes. auto input_indexing = GetOutputToInputIndexing(ParseAndGetRoot(R"( HloModule m ENTRY e { From d1766d9284fa34e52f6b5b63c9af17ccde74552b Mon Sep 17 00:00:00 2001 From: Oleg Shyshkov Date: Wed, 19 Jun 2024 09:08:58 -0700 Subject: [PATCH 023/256] [XLA:GPU] Add a method to Cost Model estimate the best tiling for a fusion. The idea is to iterate over all the tiling options from `SymbolicTileAnalysis::GetGoodTilings()`. Run Cost Model on each tiling and choose the one with the best execution time. This method will be used in TritonSoftMaxRewriter and PriorityFusion to assign block level parameters config to fusions. 
PiperOrigin-RevId: 644763688 --- third_party/xla/xla/service/gpu/model/BUILD | 2 + .../model/gpu_indexing_performance_model.cc | 121 +++++++++++++++--- .../model/gpu_indexing_performance_model.h | 26 ++++ .../gpu_indexing_performance_model_test.cc | 65 ++++++++++ 4 files changed, 196 insertions(+), 18 deletions(-) diff --git a/third_party/xla/xla/service/gpu/model/BUILD b/third_party/xla/xla/service/gpu/model/BUILD index b27fbcc232ba82..0ce770b63c197a 100644 --- a/third_party/xla/xla/service/gpu/model/BUILD +++ b/third_party/xla/xla/service/gpu/model/BUILD @@ -357,6 +357,7 @@ cc_library( "//xla/service/gpu:backend_configs_cc", "//xla/service/gpu:hlo_fusion_analysis", "//xla/service/gpu:hlo_traversal", + "//xla/service/gpu:ir_emission_utils", "//xla/service/gpu:launch_dimensions", "//xla/service/gpu/fusions:triton", "//xla/stream_executor:device_description", @@ -386,6 +387,7 @@ xla_cc_test( "//xla/hlo/ir:hlo", "//xla/service/gpu:backend_configs_cc", "//xla/service/gpu:gpu_device_info_for_tests", + "//xla/service/gpu:hlo_traversal", "//xla/stream_executor:device_description", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", diff --git a/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.cc b/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.cc index 3e583c84f2d461..049cd3aaae9f21 100644 --- a/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.cc +++ b/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include +#include #include #include #include @@ -34,6 +35,7 @@ limitations under the License. 
#include "xla/service/gpu/fusions/triton.h" #include "xla/service/gpu/hlo_fusion_analysis.h" #include "xla/service/gpu/hlo_traversal.h" +#include "xla/service/gpu/ir_emission_utils.h" #include "xla/service/gpu/launch_dimensions.h" #include "xla/service/gpu/model/coalescing_analysis.h" #include "xla/service/gpu/model/gpu_hlo_cost_analysis.h" @@ -250,25 +252,11 @@ GpuPerformanceModelWithIndexingAnalysis::EstimateRunTimes( return {time_unfused, time_fused}; } -absl::StatusOr -GpuPerformanceModelWithIndexingAnalysis::EstimateRunTimeForTiledFusion( +EstimateRunTimeData +GpuPerformanceModelWithIndexingAnalysis::EstimateRunTimeForTiledHloComputation( const HloFusionAdaptor& fusion_adaptor, - const LaunchDimensions& launch_dimensions, - absl::Span tile_sizes) { - // TODO(b/332714755): Add caching for SymbolicTileAnalysis. - SymbolicTileAnalysisOrError analysis_or_error = - SymbolicTileAnalysis::AnalyzeFusion(fusion_adaptor, mlir_context_); - if (!std::holds_alternative(analysis_or_error)) { - return absl::FailedPreconditionError( - absl::StrCat("SymbolicTileAnalysis failed. ", - std::get(analysis_or_error).Explain())); - } - SymbolicTileAnalysis analysis = - std::get(std::move(analysis_or_error)); - - TF_ASSIGN_OR_RETURN(TiledHloComputation tiled_hlo_computation, - analysis.ComputeTiledHloInstructions(tile_sizes)); - + const TiledHloComputation& tiled_hlo_computation, + const LaunchDimensions& launch_dimensions) { absl::flat_hash_map n_bytes_total_map; int64_t flops = 0; @@ -334,6 +322,29 @@ GpuPerformanceModelWithIndexingAnalysis::EstimateRunTimeForTiledFusion( /*exec_time=*/exec_time}; } +absl::StatusOr +GpuPerformanceModelWithIndexingAnalysis::EstimateRunTimeForTiledFusion( + const HloFusionAdaptor& fusion_adaptor, + const LaunchDimensions& launch_dimensions, + absl::Span tile_sizes) { + // TODO(b/332714755): Add caching for SymbolicTileAnalysis. 
+ SymbolicTileAnalysisOrError analysis_or_error = + SymbolicTileAnalysis::AnalyzeFusion(fusion_adaptor, mlir_context_); + if (const auto* fusion_decision = + std::get_if(&analysis_or_error)) { + return absl::FailedPreconditionError(absl::StrCat( + "SymbolicTileAnalysis failed. ", fusion_decision->Explain())); + } + SymbolicTileAnalysis analysis = + std::get(std::move(analysis_or_error)); + + TF_ASSIGN_OR_RETURN(TiledHloComputation tiled_hlo_computation, + analysis.ComputeTiledHloInstructions(tile_sizes)); + + return EstimateRunTimeForTiledHloComputation( + fusion_adaptor, tiled_hlo_computation, launch_dimensions); +} + absl::StatusOr GpuPerformanceModelWithIndexingAnalysis::EstimateRunTimeForTriton( const HloInstruction* producer, const HloInstruction* consumer) { @@ -352,5 +363,79 @@ GpuPerformanceModelWithIndexingAnalysis::EstimateRunTimeForTriton( launch_config->output_tile_sizes); } +// Returns the number of warps to use based on the tile size. The numbers were +// originally selected from Triton SoftMax reduction row length. +// TODO(b/332714755): Make it smarter. 
+int64_t GetNumWarps(int64_t tile_size) { + if (tile_size <= 512) return 1; + if (tile_size <= 1024) return 2; + if (tile_size <= 16384) return 4; + if (tile_size <= 32768) return 8; + if (tile_size <= 65536) return 16; + return 32; +} + +LaunchDimensions GetLaunchDimensionsForTiledFusion( + const TiledHloComputation& tiled_hlo_computation) { + const auto* tiled_root = tiled_hlo_computation.GetRoot(); + int64_t num_blocks = tiled_root->block_id_to_tile_offsets_indexing() + .GetDimensionBound(0) + .GetLoopTripCount(); + + int64_t num_warps = GetNumWarps(Product(tiled_root->tile_sizes())); + + return {static_cast(num_blocks), + static_cast(num_warps * WarpSize())}; +} + +absl::StatusOr> +GpuPerformanceModelWithIndexingAnalysis::TryFindBestTilingForFusion( + const HloFusionAdaptor& fusion_adaptor) { + SymbolicTileAnalysisOrError analysis_or_error = + SymbolicTileAnalysis::AnalyzeFusion(fusion_adaptor, mlir_context_); + + if (const auto* fusion_decision = + std::get_if(&analysis_or_error)) { + return *fusion_decision; + } + + SymbolicTileAnalysis analysis = + std::get(std::move(analysis_or_error)); + + TF_ASSIGN_OR_RETURN(auto tilings, analysis.GetGoodTilings()); + + std::optional best_tiled_run_time_data; + + for (const auto& tiling : tilings) { + TF_ASSIGN_OR_RETURN(TiledHloComputation tiled_hlo_computation, + analysis.ComputeTiledHloInstructions(tiling)); + + LaunchDimensions launch_dimensions = + GetLaunchDimensionsForTiledFusion(tiled_hlo_computation); + + EstimateRunTimeData estimate_run_time_data = + EstimateRunTimeForTiledHloComputation( + fusion_adaptor, tiled_hlo_computation, launch_dimensions); + + if (!best_tiled_run_time_data.has_value() || + estimate_run_time_data.exec_time < + best_tiled_run_time_data->runtime_data.exec_time) { + BlockLevelParameters block_level_parameters; + block_level_parameters.output_tile_sizes = + std::vector(tiling.begin(), tiling.end()); + block_level_parameters.num_warps = + launch_dimensions.num_threads_per_block() / 
WarpSize(); + + best_tiled_run_time_data = + TiledRunTimeData{estimate_run_time_data, block_level_parameters}; + } + } + + if (!best_tiled_run_time_data.has_value()) { + return FusionDecision("No valid tilings found."); + } + return *best_tiled_run_time_data; +} + } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.h b/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.h index b70c29e6b722c3..8d8b2ac109afd9 100644 --- a/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.h +++ b/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.h @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include #include #include "absl/status/statusor.h" @@ -30,13 +31,22 @@ limitations under the License. #include "xla/service/gpu/model/fusion_analysis_cache.h" #include "xla/service/gpu/model/gpu_performance_model_base.h" #include "xla/service/gpu/model/hlo_op_profiles.h" +#include "xla/service/gpu/model/symbolic_tile_analysis.h" +#include "xla/service/gpu/model/tiled_hlo_computation.h" #include "xla/service/hlo_cost_analysis.h" +#include "xla/service/instruction_fusion.h" #include "xla/shape.h" #include "xla/stream_executor/device_description.h" namespace xla { namespace gpu { +// Contains informations about block level parameters and run time of a fusion. +struct TiledRunTimeData { + EstimateRunTimeData runtime_data; + BlockLevelParameters block_level_parameters; +}; + // Implementation of Cost Model that uses indexing analysis to estimate amount // of compute and memory access time. 
class GpuPerformanceModelWithIndexingAnalysis : public GpuPerformanceModelBase { @@ -65,6 +75,11 @@ class GpuPerformanceModelWithIndexingAnalysis : public GpuPerformanceModelBase { const HloInstruction* producer, absl::Span fused_consumers = {}); + EstimateRunTimeData EstimateRunTimeForTiledHloComputation( + const HloFusionAdaptor& fusion_adaptor, + const TiledHloComputation& tiled_hlo_computation, + const LaunchDimensions& launch_dimensions); + // Estimate the run time of the fusion with the given launch dimensions and // output tile sizes. // @@ -82,6 +97,17 @@ class GpuPerformanceModelWithIndexingAnalysis : public GpuPerformanceModelBase { absl::StatusOr EstimateRunTimeForTriton( const HloInstruction* producer, const HloInstruction* consumer = nullptr); + // Estimates the best tile sizes for the given fusion. Iterates over all the + // good tile sizes provided by SymbolicTileAnalysis, estimates the run time + // for each of them. + // + // Returns status if there is an error that we can't recover from. + // Returns FusionDecision if the fusion can't be tiled or there are no valid + // block level parameters. + // Otherwise returns block level parameters that give the best execution time. + absl::StatusOr> + TryFindBestTilingForFusion(const HloFusionAdaptor& fusion_adaptor); + private: // Returns an estimate how many FLOPs will be used to produce one element of // the output. diff --git a/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model_test.cc b/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model_test.cc index 57468cf258d521..a58e23d2f57a8d 100644 --- a/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model_test.cc +++ b/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model_test.cc @@ -17,7 +17,9 @@ limitations under the License. #include #include +#include +#include #include #include "absl/strings/string_view.h" #include "absl/time/time.h" @@ -26,6 +28,7 @@ limitations under the License. 
#include "xla/hlo/ir/hlo_instruction.h" #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/gpu_device_info_for_tests.h" +#include "xla/service/gpu/hlo_traversal.h" #include "xla/service/gpu/model/fusion_analysis_cache.h" #include "xla/service/gpu/model/gpu_hlo_cost_analysis.h" #include "xla/service/gpu/model/gpu_performance_model_base.h" @@ -39,6 +42,8 @@ namespace xla { namespace gpu { namespace { +using ::testing::ElementsAre; + class GpuIndexingPerformanceModelTest : public HloTestBase { GpuHloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const { return [&](const Shape& shape) { @@ -270,6 +275,66 @@ ENTRY main { EXPECT_NEAR(absl::ToDoubleMicroseconds(runtime_data.exec_time), 5, 1); } +TEST_F(GpuIndexingPerformanceModelTest, + EstimateBestTiling_TritonSoftmax_IsSupported) { + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"( +HloModule m + +add { + Arg_0 = f32[] parameter(0) + Arg_1 = f32[] parameter(1) + ROOT add = f32[] add(Arg_0, Arg_1) +} + +triton_softmax_computation { + param_0 = f32[512,911]{1,0} parameter(0) + param_1 = f32[911]{0} parameter(1) + broadcast_0 = f32[512,911]{1,0} broadcast(param_1), dimensions={1} + multiply_0 = f32[512,911]{1,0} multiply(param_0, broadcast_0) + constant_0 = f32[] constant(0) + reduce_0 = f32[512]{0} reduce(multiply_0, constant_0), dimensions={1}, to_apply=add + broadcast_4 = f32[512,911]{1,0} broadcast(reduce_0), dimensions={0} + ROOT multiply = f32[512,911]{1,0} multiply(multiply_0, broadcast_4) +} + +ENTRY main { + param_0 = f32[512,911]{1,0} parameter(0) + param_1 = f32[911]{0} parameter(1) + ROOT triton_softmax = f32[512,911]{1,0} fusion(param_0, param_1), kind=kCustom, calls=triton_softmax_computation, backend_config={"fusion_backend_config": {"kind":"__triton"}} +} +)")); + auto fusion_adaptor = HloFusionAdaptor::ForInstruction( + module->entry_computation()->root_instruction()); + + TF_ASSERT_OK_AND_ASSIGN( + auto tiling_result, + 
indexing_cost_model_.TryFindBestTilingForFusion(*fusion_adaptor)); + + ASSERT_TRUE(std::holds_alternative(tiling_result)); + + auto tiled_runtime_data = std::get(tiling_result); + + constexpr int64_t kParam0SizeBytes = 512 * 911 * 4; + constexpr int64_t kParam1SizeBytes = 911 * 4; + constexpr int64_t kOutputSizeBytes = 512 * 911 * 4; + + // Launch grid consists of 128 blocks. Each block reads 1 tile of shape [4, + // 911] from param_0 and full param_1. In total param_0 is read once and + // param_1 is read 128 times. + constexpr int64_t kExpectedBytesRead = + kParam0SizeBytes + 128 * kParam1SizeBytes; + + EXPECT_THAT(tiled_runtime_data.block_level_parameters.output_tile_sizes, + ElementsAre(4, 911)); + EXPECT_EQ(tiled_runtime_data.block_level_parameters.num_warps, 4); + + EXPECT_EQ(tiled_runtime_data.runtime_data.bytes_read, kExpectedBytesRead); + EXPECT_EQ(tiled_runtime_data.runtime_data.bytes_written, kOutputSizeBytes); + EXPECT_NEAR( + absl::ToDoubleMicroseconds(tiled_runtime_data.runtime_data.exec_time), 5, + 1); +} + } // namespace } // namespace gpu } // namespace xla From 62f235dc90cc202ff9c09eaa59459dd52f071c14 Mon Sep 17 00:00:00 2001 From: George Karpenkov Date: Wed, 19 Jun 2024 09:26:39 -0700 Subject: [PATCH 024/256] [XLA] [NFC] Unify multi-host handling for hlo_runner_main PiperOrigin-RevId: 644767090 --- .../functional_hlo_runner.cc | 30 ++++++++++++++--- .../functional_hlo_runner.h | 17 +++++++--- .../functional_hlo_runner_test.cc | 32 +++++++++---------- .../multihost_hlo_runner/hlo_runner_main.cc | 21 ++++++------ 4 files changed, 65 insertions(+), 35 deletions(-) diff --git a/third_party/xla/xla/tools/multihost_hlo_runner/functional_hlo_runner.cc b/third_party/xla/xla/tools/multihost_hlo_runner/functional_hlo_runner.cc index 38b67bf32c8064..8f2d256128d130 100644 --- a/third_party/xla/xla/tools/multihost_hlo_runner/functional_hlo_runner.cc +++ b/third_party/xla/xla/tools/multihost_hlo_runner/functional_hlo_runner.cc @@ -76,11 +76,12 @@ limitations 
under the License. namespace xla { -absl::StatusOr> GetPjRtClient( +static absl::StatusOr> GetPjRtClient( absl::string_view device_type, absl::string_view address, int node_id, - int num_nodes, bool enable_mock_nccl, + int num_nodes, bool enable_mock_nccl, absl::Duration init_timeout, std::unique_ptr& service, - std::shared_ptr& kv_store) { + std::shared_ptr& kv_store, + std::shared_ptr& distributed_client) { if (device_type == "host") { CHECK_EQ(num_nodes, 1); return xla::FunctionalHloRunner::CreateHostClient(); @@ -97,6 +98,11 @@ absl::StatusOr> GetPjRtClient( if (num_nodes == 1) { return xla::FunctionalHloRunner::CreateGpuClient({}); } else { + TF_RET_CHECK(num_nodes == 1 || !address.empty()); + TF_RET_CHECK(node_id >= 0) + << "Node id is expected to be in range [0, num_nodes)"; + TF_RET_CHECK(node_id < num_nodes) + << "Node id is expected to be in range [0, num_nodes)"; CHECK_GT(address.length(), 0); // Multinode. Start service on task 0. if (node_id == 0) { @@ -111,8 +117,8 @@ absl::StatusOr> GetPjRtClient( } xla::DistributedRuntimeClient::Options options; options.node_id = node_id; - options.init_timeout = absl::Seconds(300); - auto distributed_client = + options.init_timeout = init_timeout; + distributed_client = GetDistributedRuntimeClient(std::string(address), options); TF_QCHECK_OK(distributed_client->Connect()); kv_store = GetDistributedKeyValueStore(distributed_client, @@ -127,6 +133,20 @@ absl::StatusOr> GetPjRtClient( } } +absl::StatusOr GetPjRtClient(absl::string_view device_type, + absl::string_view address, + int node_id, int num_nodes, + bool enable_mock_nccl, + + absl::Duration init_timeout) { + PjRtEnvironment env; + TF_ASSIGN_OR_RETURN(env.client, + GetPjRtClient(device_type, address, node_id, num_nodes, + enable_mock_nccl, init_timeout, env.service, + env.kv_store, env.distributed_client)); + return env; +} + namespace { // Creates an HloModule from the given proto. 
absl::StatusOr> HloTextToModule( diff --git a/third_party/xla/xla/tools/multihost_hlo_runner/functional_hlo_runner.h b/third_party/xla/xla/tools/multihost_hlo_runner/functional_hlo_runner.h index b714e582bdcefe..2329bce1a2fd8e 100644 --- a/third_party/xla/xla/tools/multihost_hlo_runner/functional_hlo_runner.h +++ b/third_party/xla/xla/tools/multihost_hlo_runner/functional_hlo_runner.h @@ -40,11 +40,18 @@ limitations under the License. namespace xla { -absl::StatusOr> GetPjRtClient( - absl::string_view device_type, absl::string_view address, int node_id, - int num_nodes, bool enable_mock_nccl, - std::unique_ptr& service, - std::shared_ptr& kv_store); +struct PjRtEnvironment { + std::unique_ptr client; + std::unique_ptr service; + std::shared_ptr kv_store; + std::shared_ptr distributed_client; +}; + +absl::StatusOr GetPjRtClient(absl::string_view device_type, + absl::string_view address, + int node_id, int num_nodes, + bool enable_mock_nccl, + absl::Duration init_timeout); // Supported input formats for the input HLO module. 
enum class InputFormat { diff --git a/third_party/xla/xla/tools/multihost_hlo_runner/functional_hlo_runner_test.cc b/third_party/xla/xla/tools/multihost_hlo_runner/functional_hlo_runner_test.cc index 8486ba21d4ec29..c1e7e6afdaaceb 100644 --- a/third_party/xla/xla/tools/multihost_hlo_runner/functional_hlo_runner_test.cc +++ b/third_party/xla/xla/tools/multihost_hlo_runner/functional_hlo_runner_test.cc @@ -288,31 +288,31 @@ TEST_F(FunctionalHloRunnerTest, ShardedAutotuningWorks) { absl::Status ShardedAutotuningWorksTestBody(const int node_id) { tsl::setenv("CUDA_VISIBLE_DEVICES", std::to_string(node_id).data(), /*overwrite=*/true); - std::unique_ptr service = nullptr; - std::shared_ptr kv_store = nullptr; - TF_ASSIGN_OR_RETURN(std::unique_ptr client, - GetPjRtClient("gpu", "127.0.0.1:12345", node_id, - kNumNodes, false, service, kv_store)); - CHECK(kv_store != nullptr); + TF_ASSIGN_OR_RETURN( + PjRtEnvironment env, + xla::GetPjRtClient("gpu", "127.0.0.1:12345", node_id, kNumNodes, + /*enable_mock_nccl=*/false, + /*init_timeout=*/absl::Seconds(120))); + CHECK(env.kv_store != nullptr); TF_RETURN_IF_ERROR(FunctionalHloRunner::LoadAndCompile( - *client, GetDebugOptionsFromFlags(), + *env.client, GetDebugOptionsFromFlags(), FunctionalHloRunner::PreprocessingOptions{}, FunctionalHloRunner::RawCompileOptions{}, GetHloPath("multiple_gemm_fusions.hlo"), InputFormat::kText)); if (node_id == 0) { - TF_ASSIGN_OR_RETURN( - std::string results0, - kv_store->Get("gemm_fusion_autotuning_results_1_0", absl::Seconds(1))); + TF_ASSIGN_OR_RETURN(std::string results0, + env.kv_store->Get("gemm_fusion_autotuning_results_1_0", + absl::Seconds(1))); CHECK(absl::StrContains(results0, "run_time")); - TF_ASSIGN_OR_RETURN( - std::string results1, - kv_store->Get("gemm_fusion_autotuning_results_1_1", absl::Seconds(1))); + TF_ASSIGN_OR_RETURN(std::string results1, + env.kv_store->Get("gemm_fusion_autotuning_results_1_1", + absl::Seconds(1))); CHECK(absl::StrContains(results1, "run_time")); // 
First two nodes autotune two different fusions. CHECK_NE(results0, results1); - TF_ASSIGN_OR_RETURN( - std::string results2, - kv_store->Get("gemm_fusion_autotuning_results_1_2", absl::Seconds(1))); + TF_ASSIGN_OR_RETURN(std::string results2, + env.kv_store->Get("gemm_fusion_autotuning_results_1_2", + absl::Seconds(1))); // Third node has nothing to autotune. CHECK(!absl::StrContains(results2, "run_time")); } diff --git a/third_party/xla/xla/tools/multihost_hlo_runner/hlo_runner_main.cc b/third_party/xla/xla/tools/multihost_hlo_runner/hlo_runner_main.cc index a84ee581ab228b..733d42d3e07112 100644 --- a/third_party/xla/xla/tools/multihost_hlo_runner/hlo_runner_main.cc +++ b/third_party/xla/xla/tools/multihost_hlo_runner/hlo_runner_main.cc @@ -66,7 +66,7 @@ If multiple HLOs are launched, we assume that they are encoded in the same format (HLO text by default). Running multiple HLOs is convenient when replaying all HLOs from an execution dump, with e.g.: - bazel run hlo_runner_main -- /dump/*before_optimizations.txt + bazel run hlo_runner_main -- /dump/hlo_*before_optimizations.txt Mock GPU usage: @@ -101,6 +101,7 @@ struct HloRunnerConfig { bool remove_infeed_outfeed = true; int32_t num_repeats = 1; std::string execution_options_path = ""; + int64_t gpu_client_initialization_timeout_sec = 300; }; } // namespace @@ -226,25 +227,23 @@ static absl::Status RunMultihostHloRunner(int argc, char** argv, QCHECK(opts.dump_output_literal_to.empty() || argc == 2) << "Can only dump output literal when single input file is specified"; - std::unique_ptr service; - std::shared_ptr kv_store; - TF_ASSIGN_OR_RETURN( - std::unique_ptr client, + PjRtEnvironment env, GetPjRtClient(opts.device_type_str, opts.address_str, opts.task_id, - opts.num_nodes, opts.enable_mock_nccl, service, kv_store)); + opts.num_nodes, opts.enable_mock_nccl, + absl::Seconds(opts.gpu_client_initialization_timeout_sec))); for (int c = 1; c < argc; c++) { const char* filename = argv[c]; std::cout << "\n** Running 
" << filename << " **\n"; if (opts.should_run) { TF_RETURN_IF_ERROR(xla::FunctionalHloRunner::LoadAndRunAndDump( - *client, GetDebugOptionsFromFlags(), preproc_options, + *env.client, GetDebugOptionsFromFlags(), preproc_options, raw_compile_options, running_options, filename, opts.input_format, opts.dump_output_literal_to, opts.task_id)); } else { TF_RETURN_IF_ERROR(FunctionalHloRunner::LoadAndCompile( - *client, GetDebugOptionsFromFlags(), preproc_options, + *env.client, GetDebugOptionsFromFlags(), preproc_options, raw_compile_options, argv[c], opts.input_format, opts.task_id)); } } @@ -314,7 +313,11 @@ int main(int argc, char** argv) { "Repeatedly execute the HLO for this many times."), tsl::Flag("execution_options_path", &opts.execution_options_path, "A path to a protobuf text file which stores the " - "ExecutionOptions message for this HLO module.")}; + "ExecutionOptions message for this HLO module."), + tsl::Flag("gpu_client_initialization_timeout_sec", + &opts.gpu_client_initialization_timeout_sec, + "A timeout, in seconds, for the GPU client initialization. " + "Only used for multi-node GPU runs")}; xla::AppendDebugOptionsFlags(&flag_list); From 21cb1e0b33803c7c24569fa4b014eb85a30fcd90 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tam=C3=A1s=20Danyluk?= Date: Wed, 19 Jun 2024 09:40:19 -0700 Subject: [PATCH 025/256] [XLA:GPU] Support tiling "softmax and reduce" example To allow this, I've added support for deriving symbolic tiles, when the indexing map has symbols in a sum. 
PiperOrigin-RevId: 644769721 --- .../xla/service/gpu/model/symbolic_tile.cc | 81 +++++++++++++------ .../gpu/model/symbolic_tile_analysis_test.cc | 72 ++++++++++++++++- .../service/gpu/model/symbolic_tile_test.cc | 19 ++++- 3 files changed, 142 insertions(+), 30 deletions(-) diff --git a/third_party/xla/xla/service/gpu/model/symbolic_tile.cc b/third_party/xla/xla/service/gpu/model/symbolic_tile.cc index f4c4bff3c7fca6..6b6c69caa2b732 100644 --- a/third_party/xla/xla/service/gpu/model/symbolic_tile.cc +++ b/third_party/xla/xla/service/gpu/model/symbolic_tile.cc @@ -321,11 +321,37 @@ void SortByStride(std::vector& sizes_and_strides) { }); } +// Returns the range size of the given size expression. +// +// `size` must be a constant or dimension expression. +std::optional TryGetSizeExpressionRangeSize( + AffineExpr size, absl::Span dimension_intervals) { + CHECK(size.getKind() == AffineExprKind::Constant || + size.getKind() == AffineExprKind::DimId); + if (auto dimension = llvm::dyn_cast(size)) { + const Interval& interval = dimension_intervals.at(dimension.getPosition()); + if (interval.lower != 0) { + // TODO(bchetioui): I think we may need to handle this to have reshapes + // working well with concatenations. Nevertheless, we can take a look + // later. + VLOG(1) << "Attempted to combine strides but got dimension " + << AffineMapPrinter().ToString(dimension) << " with lower bound " + << interval.lower << " != 0"; + return std::nullopt; + } + // We need to add 1 to the upper bound of the interval to describe the + // number of elements being captured, since the interval bounds are + // inclusive. + return interval.upper + 1; + } + return llvm::cast(size).getValue(); +}; + // Given a list of sizes and strides, combines the strides into a single // expression if it is possible. // // The current implementation expects that each size captures a single dimension -// parameter. +// parameter or a constant (coming from a RangeVar). 
// // Let s be an n-dimensional shape that we want to fully collapse. In order to // be propagated successfully through the collapse, the pattern of the tiling of @@ -363,10 +389,13 @@ std::optional CombineStrides( // to parameters of the initial indexing map. It follows that if a size // expression is exactly a dimension parameter, we know its exact bounds. // - // If a size is not exactly a dimension parameter, then it is dubious - // whether we know the bounds---and may thus calculate wrong strides. - if (size_and_stride.size.getKind() != AffineExprKind::DimId) { - VLOG(1) << "Attempted to combine strides but got non-dimension size " + // If a size is not a constant and not exactly a dimension parameter, then + // it is dubious whether we know the bounds---and may thus calculate wrong + // strides. + if (size_and_stride.size.getKind() != AffineExprKind::Constant && + size_and_stride.size.getKind() != AffineExprKind::DimId) { + VLOG(1) << "Attempted to combine strides but got non-constant, " + "non-dimension size " << AffineMapPrinter().ToString(size_and_stride.size); return std::nullopt; } @@ -388,31 +417,21 @@ std::optional CombineStrides( if (dim_id > 0) { const SizeAndStrideExpression& previous_size_and_stride = sizes_and_strides[dim_id - 1]; - const auto previous_dimension = - llvm::cast(previous_size_and_stride.size); - const Interval& previous_size_interval = - dimension_intervals[previous_dimension.getPosition()]; - if (previous_size_interval.lower != 0) { - // TODO(bchetioui): I think we may need to handle this to have reshapes - // working well with concatenations. Nevertheless, we can take a look - // later. 
- VLOG(1) << "Attempted to combine strides but got dimension " - << AffineMapPrinter().ToString(previous_dimension) - << " with lower bound " << previous_size_interval.lower - << " != 0"; + std::optional previous_size_expression_range_size = + TryGetSizeExpressionRangeSize(previous_size_and_stride.size, + dimension_intervals); + if (!previous_size_expression_range_size.has_value()) { return std::nullopt; } int64_t previous_stride = llvm::cast(previous_size_and_stride.stride) .getValue(); - // We need to add 1 to the upper bound of the interval to describe the - // number of elements being captured, since the interval bounds are - // inclusive. - if ((previous_size_interval.upper + 1) * previous_stride != stride) { + + if (*previous_size_expression_range_size * previous_stride != stride) { VLOG(1) << "Attempted to combine strides but stride did not grow " << "exactly as expected: got " - << (previous_size_interval.upper + 1) << " * " + << *previous_size_expression_range_size << " * " << previous_stride << " != " << stride; return std::nullopt; } @@ -426,9 +445,12 @@ std::optional CombineStrides( size_and_stride_it != sizes_and_strides.rend(); ++size_and_stride_it) { AffineExpr size = size_and_stride_it->size; AffineExpr stride = size_and_stride_it->stride; - const Interval& size_interval = - dimension_intervals[llvm::cast(size).getPosition()]; - nested_if = IfNeqOne(size, stride, nested_if, size_interval.upper + 1); + std::optional size_expression_range_size = + TryGetSizeExpressionRangeSize(size, dimension_intervals); + if (!size_expression_range_size.has_value()) { + return std::nullopt; + } + nested_if = IfNeqOne(size, stride, nested_if, *size_expression_range_size); } return nested_if; @@ -441,6 +463,15 @@ std::optional CombineSizesAndStrides( absl::Span dimension_intervals) { CHECK(!sizes_and_strides.empty()); + if (VLOG_IS_ON(1)) { + for (const SizeAndStrideExpression& size_and_stride : sizes_and_strides) { + LOG(INFO) << "CombineSizesAndStrides:"; + LOG(INFO) 
<< "size: " << AffineMapPrinter().ToString(size_and_stride.size) + << " stride: " + << AffineMapPrinter().ToString(size_and_stride.stride); + } + } + std::optional maybe_constraints = ConstraintMap(); for (const SizeAndStrideExpression& size_and_stride : sizes_and_strides) { diff --git a/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis_test.cc b/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis_test.cc index 660c686d5d4fec..4ecc81135f2e07 100644 --- a/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis_test.cc +++ b/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis_test.cc @@ -640,13 +640,13 @@ TEST_F(SymbolicTileAnalysisTest, GetGoodTilingsWorksForSoftmaxExample) { ParseAndReturnVerifiedModule(R"( HloModule m -region { +max_computation { param_0 = f32[] parameter(0) param_1 = f32[] parameter(1) ROOT maximum = f32[] maximum(param_0, param_1) } -region.1 { +add_computation { param_0 = f32[] parameter(0) param_1 = f32[] parameter(1) ROOT add = f32[] add(param_0, param_1) @@ -656,13 +656,13 @@ fused_computation { param_0 = f32[8192,50304] parameter(0) bitcast = f32[4,2048,50304] bitcast(param_0) constant = f32[] constant(-inf) - reduce = f32[8192] reduce(param_0, constant), dimensions={1}, to_apply=region + reduce = f32[8192] reduce(param_0, constant), dimensions={1}, to_apply=max_computation bitcast.1 = f32[4,2048] bitcast(reduce) broadcast = f32[4,2048,50304] broadcast(bitcast.1), dimensions={0,1} subtract = f32[4,2048,50304] subtract(bitcast, broadcast) exponential = f32[4,2048,50304] exponential(subtract) constant.1 = f32[] constant(0) - reduce.1 = f32[4,2048] reduce(exponential, constant.1), dimensions={2}, to_apply=region.1 + reduce.1 = f32[4,2048] reduce(exponential, constant.1), dimensions={2}, to_apply=add_computation log = f32[4,2048] log(reduce.1) broadcast.1 = f32[4,2048,50304] broadcast(log), dimensions={0,1} ROOT subtract.1 = f32[4,2048,50304] subtract(subtract, broadcast.1) @@ -686,6 +686,70 @@ ENTRY 
entry_computation { LogTilingsIfVlog1(good_tilings); } +TEST_F(SymbolicTileAnalysisTest, + GetGoodTilingsWorksForSoftmaxAndReduceExample) { + // The example is from + // https://github.com/google/paxml/blob/91893818862645f5e9f23b84f530e611551745f6/paxml/contrib/gpu/scripts_gpu/configs.py#L107-L120. + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( +HloModule m + +max_computation { + param_0 = f32[] parameter(0) + param_1 = f32[] parameter(1) + ROOT maximum = f32[] maximum(param_0, param_1) +} + +add_computation { + param_0 = f32[] parameter(0) + param_1 = f32[] parameter(1) + ROOT add = f32[] add(param_0, param_1) +} + +fused_computation { + param_0 = f32[8192,50304] parameter(0) + param_1 = s32[4,2048] parameter(1) + broadcast = s32[4,2048,50304] broadcast(param_1), dimensions={0,1} + iota = s32[4,2048,50304] iota(), iota_dimension=2 + compare = pred[4,2048,50304] compare(broadcast, iota), direction=EQ + bitcast = f32[4,2048,50304] bitcast(param_0) + constant = f32[] constant(-inf) + reduce = f32[8192] reduce(param_0, constant), dimensions={1}, to_apply=max_computation + bitcast.1 = f32[4,2048] bitcast(reduce) + broadcast.1 = f32[4,2048,50304] broadcast(bitcast.1), dimensions={0,1} + subtract = f32[4,2048,50304] subtract(bitcast, broadcast.1) + exponential = f32[4,2048,50304] exponential(subtract) + constant.1 = f32[] constant(0) + reduce.1 = f32[4,2048] reduce(exponential, constant.1), dimensions={2}, to_apply=add_computation + log = f32[4,2048] log(reduce.1) + broadcast.2 = f32[4,2048,50304] broadcast(log), dimensions={0,1} + subtract.1 = f32[4,2048,50304] subtract(subtract, broadcast.2) + constant.2 = f32[] constant(0) + broadcast.3 = f32[4,2048,50304] broadcast(constant.2), dimensions={} + select = f32[4,2048,50304] select(compare, subtract.1, broadcast.3) + bitcast.2 = f32[4,2048,393,128] bitcast(select) + ROOT reduce.2 = f32[4,2048,393] reduce(bitcast.2, constant.2), dimensions={3}, to_apply=add_computation +} + +ENTRY 
entry_computation { + param_0 = f32[8192,50304] parameter(0) + param_1 = s32[4,2048] parameter(1) + ROOT fusion = f32[4,2048,393] fusion(param_0, param_1), kind=kCustom, calls=fused_computation, backend_config={"fusion_backend_config":{"kind":"__triton_softmax"}} +} +)")); + + std::optional opt_analysis = + TryAnalyzeModule(module.get()); + ASSERT_TRUE(opt_analysis.has_value()); + const SymbolicTileAnalysis& analysis = opt_analysis.value(); + + TF_ASSERT_OK_AND_ASSIGN( + std::vector good_tilings, + analysis.GetGoodTilings()); + EXPECT_THAT(good_tilings, Not(IsEmpty())); + LogTilingsIfVlog1(good_tilings); +} + } // namespace } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/model/symbolic_tile_test.cc b/third_party/xla/xla/service/gpu/model/symbolic_tile_test.cc index 92d92013955275..b825efc4609812 100644 --- a/third_party/xla/xla/service/gpu/model/symbolic_tile_test.cc +++ b/third_party/xla/xla/service/gpu/model/symbolic_tile_test.cc @@ -744,7 +744,7 @@ TEST_F(SymbolicTileTest, } TEST_F(SymbolicTileTest, - CanPropagateTileWhenPreexistingConstraintsCanBeSimplifiedAway) { + CanDeriveTileWhenPreexistingConstraintsCanBeSimplifiedAway) { // The example is from // https://github.com/google/paxml/blob/91893818862645f5e9f23b84f530e611551745f6/paxml/contrib/gpu/scripts_gpu/configs.py#L107-L120. IndexingMap indexing_map = IndexingMap::FromTensorSizes( @@ -765,6 +765,23 @@ TEST_F(SymbolicTileTest, )"))); } +TEST_F(SymbolicTileTest, CanDeriveTileWhenTheIndexingMapHasSymbolsInASum) { + // The example is from + // https://github.com/google/paxml/blob/91893818862645f5e9f23b84f530e611551745f6/paxml/contrib/gpu/scripts_gpu/configs.py#L107-L120. 
+ IndexingMap indexing_map = IndexingMap::FromTensorSizes( + ParseAffineMap("(d0, d1, d2)[s0] -> (d0, d1, d2 * 128 + s0)", + &mlir_context_), + {4, 2048, 393}, {128}); + + EXPECT_THAT(SymbolicTile::FromIndexingMap(indexing_map), + Optional(MatchSymbolicTileString(R"( + Symbolic tile with + offset_map: ()[s0, s1, s2] -> (0, 0, 0) + size_map: ()[s0, s1, s2] -> (s0, s1, s2 * 128) + stride_map: ()[s0, s1, s2] -> (1, 1, 1) + )"))); +} + } // namespace } // namespace gpu } // namespace xla From 4347a69f8985f9777fc9b92a02c86d6a5e23f737 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 19 Jun 2024 10:07:02 -0700 Subject: [PATCH 026/256] Reverts df2dae48118d675fb92cf51f42aa6abdb391cedc PiperOrigin-RevId: 644774653 --- third_party/xla/xla/pjrt/exceptions.h | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/third_party/xla/xla/pjrt/exceptions.h b/third_party/xla/xla/pjrt/exceptions.h index cf9696de956549..6a5865f3cce0ce 100644 --- a/third_party/xla/xla/pjrt/exceptions.h +++ b/third_party/xla/xla/pjrt/exceptions.h @@ -53,7 +53,10 @@ class XlaRuntimeError : public std::runtime_error { } static bool ShowStackTraces() { - return absl::string_view(getenv("JAX_TRACEBACK_FILTERING")) == "off"; + if (char* value = getenv("JAX_TRACEBACK_FILTERING")) { + return strcmp(value, "off"); + } + return false; } std::optional status_; From 1d4b49fb424b21b4274efb0f180f58765fb67518 Mon Sep 17 00:00:00 2001 From: Oleg Shyshkov Date: Wed, 19 Jun 2024 10:47:17 -0700 Subject: [PATCH 027/256] [XLA:GPU] Add num_warps to BlockLevelFusionConfig and a method to convert the struct to proto. 
PiperOrigin-RevId: 644782060 --- .../xla/xla/service/gpu/backend_configs.proto | 3 +++ third_party/xla/xla/service/gpu/model/BUILD | 2 +- .../service/gpu/model/tiled_hlo_computation.h | 14 +++++++++-- .../gpu/model/tiled_hlo_computation_test.cc | 23 ++++++++++++++++--- 4 files changed, 36 insertions(+), 6 deletions(-) diff --git a/third_party/xla/xla/service/gpu/backend_configs.proto b/third_party/xla/xla/service/gpu/backend_configs.proto index 5bf6fbc19e61ec..be10b81cbd1fac 100644 --- a/third_party/xla/xla/service/gpu/backend_configs.proto +++ b/third_party/xla/xla/service/gpu/backend_configs.proto @@ -159,6 +159,9 @@ message BlockLevelFusionConfig { // The output tile sizes of the associated instruction. The length of this // field is expected to be the rank of the output shape. repeated int64 output_tile_sizes = 1; + + // The number of warps to use for the kernel. + int64 num_warps = 2; } message FusionBackendConfig { diff --git a/third_party/xla/xla/service/gpu/model/BUILD b/third_party/xla/xla/service/gpu/model/BUILD index 0ce770b63c197a..cf14c86064b5b2 100644 --- a/third_party/xla/xla/service/gpu/model/BUILD +++ b/third_party/xla/xla/service/gpu/model/BUILD @@ -636,7 +636,6 @@ cc_library( "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", "@local_tsl//tsl/lib/gtl:iterator_range", ], ) @@ -646,6 +645,7 @@ xla_cc_test( srcs = ["tiled_hlo_computation_test.cc"], deps = [ ":tiled_hlo_computation", + "//xla/service/gpu:backend_configs_cc", "@com_google_googletest//:gtest_main", ], ) diff --git a/third_party/xla/xla/service/gpu/model/tiled_hlo_computation.h b/third_party/xla/xla/service/gpu/model/tiled_hlo_computation.h index f937c33cce93c8..ca9ea4cda5ef92 100644 --- a/third_party/xla/xla/service/gpu/model/tiled_hlo_computation.h +++ b/third_party/xla/xla/service/gpu/model/tiled_hlo_computation.h @@ -36,7 +36,7 @@ struct BlockLevelParameters { 
std::vector output_tile_sizes; // Triton-specific parameters. - int num_warps = 1; + int64_t num_warps = 1; int num_ctas = 1; int num_stages = 1; @@ -46,7 +46,17 @@ struct BlockLevelParameters { return BlockLevelParameters{ /*output_tile_sizes=*/ std::vector(config.output_tile_sizes().begin(), - config.output_tile_sizes().end())}; + config.output_tile_sizes().end()), + /*num_warps=*/config.num_warps()}; + } + + // Returns a BlockLevelFusionConfig proto from a BlockLevelParameters struct. + BlockLevelFusionConfig ToBlockLevelFusionConfig() const { + BlockLevelFusionConfig config; + config.mutable_output_tile_sizes()->Add(output_tile_sizes.begin(), + output_tile_sizes.end()); + config.set_num_warps(num_warps); + return config; } }; diff --git a/third_party/xla/xla/service/gpu/model/tiled_hlo_computation_test.cc b/third_party/xla/xla/service/gpu/model/tiled_hlo_computation_test.cc index f09c7aa35f29d0..7a9616478b2ecd 100644 --- a/third_party/xla/xla/service/gpu/model/tiled_hlo_computation_test.cc +++ b/third_party/xla/xla/service/gpu/model/tiled_hlo_computation_test.cc @@ -17,6 +17,7 @@ limitations under the License. 
#include #include +#include "xla/service/gpu/backend_configs.pb.h" namespace xla { namespace gpu { @@ -29,11 +30,27 @@ TEST(BlockLevelParametersTest, BlockLevelFusionConfig block_level_fusion_config; block_level_fusion_config.mutable_output_tile_sizes()->Add(18); block_level_fusion_config.mutable_output_tile_sizes()->Add(19); + block_level_fusion_config.set_num_warps(12); - EXPECT_THAT(BlockLevelParameters::FromBlockLevelFusionConfig( - block_level_fusion_config) - .output_tile_sizes, + BlockLevelParameters block_level_parameters = + BlockLevelParameters::FromBlockLevelFusionConfig( + block_level_fusion_config); + EXPECT_THAT(block_level_parameters.output_tile_sizes, ElementsAre(18, 19)); + EXPECT_THAT(block_level_parameters.num_warps, 12); +} + +TEST(BlockLevelParametersTest, + BlockLevelParametersCanBeConvertedToBlockLevelFusionConfig) { + BlockLevelParameters block_level_parameters; + block_level_parameters.output_tile_sizes = {18, 19}; + block_level_parameters.num_warps = 12; + + BlockLevelFusionConfig block_level_fusion_config = + block_level_parameters.ToBlockLevelFusionConfig(); + + EXPECT_THAT(block_level_fusion_config.output_tile_sizes(), ElementsAre(18, 19)); + EXPECT_THAT(block_level_fusion_config.num_warps(), 12); } } // namespace From c2caec228a89410f18b2f063cad9c8d148b434bc Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Wed, 19 Jun 2024 11:15:59 -0700 Subject: [PATCH 028/256] [xla:cpu] Add support for multi-process all-reduce collective PiperOrigin-RevId: 644786801 --- third_party/xla/xla/pjrt/cpu/cpu_client.cc | 48 ++++++++++++------- third_party/xla/xla/service/cpu/runtime/BUILD | 7 +++ .../service/cpu/runtime/all_reduce_thunk.cc | 41 ++++++++++++---- .../service/cpu/runtime/collective_thunk.cc | 5 ++ .../service/cpu/runtime/collective_thunk.h | 3 ++ .../xla/xla/service/cpu/runtime/thunk.cc | 20 ++++++-- .../xla/xla/service/cpu/runtime/thunk.h | 5 +- .../xla/xla/service/cpu/thunk_emitter.cc | 1 + third_party/xla/xla/tests/all_reduce_test.cc | 
1 + 9 files changed, 102 insertions(+), 29 deletions(-) diff --git a/third_party/xla/xla/pjrt/cpu/cpu_client.cc b/third_party/xla/xla/pjrt/cpu/cpu_client.cc index 3dc7f4208859b4..28f3a92284f4d0 100644 --- a/third_party/xla/xla/pjrt/cpu/cpu_client.cc +++ b/third_party/xla/xla/pjrt/cpu/cpu_client.cc @@ -1582,10 +1582,15 @@ absl::StatusOr TfrtCpuExecutable::ExecuteHelper( } cpu::BufferAllocations allocations(buffer_device_mem); + + TF_ASSIGN_OR_RETURN( + cpu::Thunk::CollectiveExecuteParams collective_params, + cpu::Thunk::CollectiveExecuteParams::Create(&run_options)); + cpu::Thunk::ExecuteParams execute_params = { &cpu_executable->host_kernels(), &allocations, cpu::runtime::GetXfeedManager(run_options.device_ordinal()), - run_options.intra_op_thread_pool()}; + run_options.intra_op_thread_pool(), &collective_params}; auto execute_event = cpu_executable->thunks().Execute( execute_params, [&](cpu::ThunkExecutor::Task task) { @@ -1703,22 +1708,31 @@ absl::StatusOr TfrtCpuExecutable::ExecuteHelper( } cpu::BufferAllocations allocations(buffer_device_mem); - cpu::Thunk::ExecuteParams execute_params = { - &cpu_executable->host_kernels(), &allocations, - cpu::runtime::GetXfeedManager(run_options.device_ordinal()), - run_options.intra_op_thread_pool()}; - - auto execute_event = cpu_executable->thunks().Execute( - execute_params, [&](cpu::ThunkExecutor::Task task) { - eigen_device->getPool()->Schedule( - cpu::ToCopyableTask(std::move(task))); - }); - - tsl::profiler::TraceMe trace( - "ThunkExecutor::Execute (wait for completion)"); - tsl::BlockUntilReady(execute_event); - status = execute_event.IsError() ? 
execute_event.GetError() - : absl::OkStatus(); + + absl::StatusOr + collective_params = + cpu::Thunk::CollectiveExecuteParams::Create(&run_options); + + if (collective_params.ok()) { + cpu::Thunk::ExecuteParams execute_params = { + &cpu_executable->host_kernels(), &allocations, + cpu::runtime::GetXfeedManager(run_options.device_ordinal()), + run_options.intra_op_thread_pool(), &*collective_params}; + + auto execute_event = cpu_executable->thunks().Execute( + execute_params, [&](cpu::ThunkExecutor::Task task) { + eigen_device->getPool()->Schedule( + cpu::ToCopyableTask(std::move(task))); + }); + + tsl::profiler::TraceMe trace( + "ThunkExecutor::Execute (wait for completion)"); + tsl::BlockUntilReady(execute_event); + status = execute_event.IsError() ? execute_event.GetError() + : absl::OkStatus(); + } else { + status = collective_params.status(); + } } else { status = diff --git a/third_party/xla/xla/service/cpu/runtime/BUILD b/third_party/xla/xla/service/cpu/runtime/BUILD index 1e635a1fb75f4c..ddc8670ef7964a 100644 --- a/third_party/xla/xla/service/cpu/runtime/BUILD +++ b/third_party/xla/xla/service/cpu/runtime/BUILD @@ -47,7 +47,10 @@ cc_library( "//xla:util", "//xla/runtime:buffer_use", "//xla/service:global_device_id", + "//xla/service/cpu:collectives_interface", + "//xla/service/cpu:cpu_executable_run_options", "//xla/service/cpu:cpu_runtime", + "//xla/service/cpu:in_process_collectives", "//xla/stream_executor", "//xla/stream_executor/host:host_kernel_c_api", "//xla/tsl/concurrency:async_value", @@ -175,11 +178,13 @@ cc_library( ":collective_thunk", ":thunk", "//xla:shape_util", + "//xla:status_macros", "//xla:util", "//xla:xla_data_proto_cc", "//xla/runtime:buffer_use", "//xla/service:buffer_assignment", "//xla/service:collective_ops_utils", + "//xla/service/cpu:collectives_interface", "//xla/stream_executor", "//xla/tsl/concurrency:async_value", "@com_google_absl//absl/algorithm:container", @@ -190,6 +195,7 @@ cc_library( 
"@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", "@eigen_archive//:eigen3", + "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:numbers", "@local_tsl//tsl/platform:statusor", @@ -213,6 +219,7 @@ cc_library( "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/time", "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:statusor", ], diff --git a/third_party/xla/xla/service/cpu/runtime/all_reduce_thunk.cc b/third_party/xla/xla/service/cpu/runtime/all_reduce_thunk.cc index 0f59a38fd6e645..aa502f45b24c16 100644 --- a/third_party/xla/xla/service/cpu/runtime/all_reduce_thunk.cc +++ b/third_party/xla/xla/service/cpu/runtime/all_reduce_thunk.cc @@ -23,7 +23,6 @@ limitations under the License. #include "absl/container/inlined_vector.h" #include "absl/memory/memory.h" -#include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_format.h" #include "absl/types/span.h" @@ -31,12 +30,16 @@ limitations under the License. #include "xla/runtime/buffer_use.h" #include "xla/service/buffer_assignment.h" #include "xla/service/collective_ops_utils.h" +#include "xla/service/cpu/collectives_interface.h" #include "xla/service/cpu/runtime/collective_thunk.h" #include "xla/service/cpu/runtime/thunk.h" #include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/status_macros.h" #include "xla/stream_executor/device_memory.h" #include "xla/tsl/concurrency/async_value_ref.h" #include "xla/util.h" +#include "tsl/platform/errors.h" #include "tsl/platform/logging.h" #include "tsl/platform/statusor.h" #include "tsl/profiler/lib/traceme.h" @@ -146,14 +149,17 @@ tsl::AsyncValueRef AllReduceThunk::Execute( return OkExecuteEvent(); } - // For multi-replica case, we need collective parameters to be able to - // perform the all-reduce operation collectively with other replicas. 
+ // For multi-replica case, we need access to collectives implementation to be + // able to perform the all-reduce operation collectively with other replicas. CollectiveExecuteParams* collective_params = params.collective_params; - if (collective_params == nullptr) { - return Internal( - "Collective parameters are not set for all-reduce operation"); - } + TF_RET_CHECK(collective_params) + << "Collective parameters are not set for all-reduce operation"; + + CollectivesInterface* collectives = collective_params->collectives; + TF_RET_CHECK(collectives) + << "Collectives interface is not set for all-reduce operation"; + // Find out rendezvous key and rank in global devices for the current device. TF_ASSIGN_OR_RETURN(RendezvousKey key, GetRendezvousKey(*collective_params)); TF_ASSIGN_OR_RETURN( int32_t rank, @@ -161,7 +167,18 @@ tsl::AsyncValueRef AllReduceThunk::Execute( VLOG(3) << absl::StreamFormat(" rank=%d, key=%s", rank, key.ToString()); - return absl::UnimplementedError("AllReduceThunk::Execute not implemented"); + TF_ASSIGN_OR_RETURN(std::shared_ptr communicator, + collectives->GetCommunicator(key.global_devices, rank)); + + for (int32_t i = 0; i < num_srcs; ++i) { + const Shape& shape = source_shapes_[i]; + TF_RETURN_IF_ERROR(communicator->AllReduce( + key, reduction_kind_, shape.element_type(), + ShapeUtil::ElementsIn(shape), source_data[i].opaque(), + destination_data[i].opaque(), DefaultCollectiveTimeout())); + } + + return OkExecuteEvent(); } Thunk::BufferUses AllReduceThunk::buffer_uses() const { @@ -173,6 +190,14 @@ Thunk::BufferUses AllReduceThunk::buffer_uses() const { for (auto& destination_buffer : destination_buffers_) { uses.push_back(BufferUse::Write(destination_buffer)); } + + // TODO(ezhulenev): It is a hack to make sure that we execute all collective + // operations in the same order as in HLO schedule, because otherwise racing + // collectives lead to undefined behavior. Instead we should correctly model + // side effects of Thunks. 
+ static auto* fake_alloc = new BufferAllocation(0, 1, 0); + uses.push_back(BufferUse::Write(BufferAllocation::Slice(fake_alloc, 0, 1))); + return uses; } diff --git a/third_party/xla/xla/service/cpu/runtime/collective_thunk.cc b/third_party/xla/xla/service/cpu/runtime/collective_thunk.cc index 21676087e21f38..3f018567942f5d 100644 --- a/third_party/xla/xla/service/cpu/runtime/collective_thunk.cc +++ b/third_party/xla/xla/service/cpu/runtime/collective_thunk.cc @@ -26,6 +26,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" +#include "absl/time/time.h" #include "xla/service/collective_ops_utils.h" #include "xla/service/computation_placer.h" #include "xla/service/cpu/runtime/thunk.h" @@ -41,6 +42,10 @@ CollectiveThunk::CollectiveThunk(Kind kind, Thunk::Info info, OpParams op_params) : Thunk(kind, info), op_params_(std::move(op_params)) {} +absl::Duration CollectiveThunk::DefaultCollectiveTimeout() { + return absl::Minutes(30); +} + absl::StatusOr CollectiveThunk::GetRendezvousKey( const Thunk::CollectiveExecuteParams& params) { TF_RET_CHECK(params.device_assignment) << "Device assignment is null"; diff --git a/third_party/xla/xla/service/cpu/runtime/collective_thunk.h b/third_party/xla/xla/service/cpu/runtime/collective_thunk.h index 022b65346a26f6..588b3b43f7d3a3 100644 --- a/third_party/xla/xla/service/cpu/runtime/collective_thunk.h +++ b/third_party/xla/xla/service/cpu/runtime/collective_thunk.h @@ -21,6 +21,7 @@ limitations under the License. 
#include #include "absl/status/statusor.h" +#include "absl/time/time.h" #include "xla/service/collective_ops_utils.h" #include "xla/service/cpu/runtime/thunk.h" #include "xla/service/global_device_id.h" @@ -46,6 +47,8 @@ class CollectiveThunk : public Thunk { const OpParams& op_params() const { return op_params_; } protected: + absl::Duration DefaultCollectiveTimeout(); + absl::StatusOr GetRendezvousKey( const Thunk::CollectiveExecuteParams& params); diff --git a/third_party/xla/xla/service/cpu/runtime/thunk.cc b/third_party/xla/xla/service/cpu/runtime/thunk.cc index 6da98ddb4f5b8a..c60495b3779631 100644 --- a/third_party/xla/xla/service/cpu/runtime/thunk.cc +++ b/third_party/xla/xla/service/cpu/runtime/thunk.cc @@ -23,6 +23,9 @@ limitations under the License. #include #include "xla/executable_run_options.h" +#include "xla/service/cpu/collectives_interface.h" +#include "xla/service/cpu/cpu_executable_run_options.h" +#include "xla/service/cpu/in_process_collectives.h" #include "xla/service/global_device_id.h" #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor.h" @@ -70,18 +73,29 @@ Thunk::CollectiveExecuteParams::Create( ? run_options->device_ordinal() : run_options->stream()->parent()->device_ordinal(); + // If CPU executable run options are set, use the collectives interface + // provided by the executable run options. Otherwise, use the in-process + // collectives interface. + static auto* in_process_collectives = new runtime::InProcessCollectives(); + CollectivesInterface* collectives = + run_options->cpu_executable_run_options() + ? 
run_options->cpu_executable_run_options()->collectives() + : in_process_collectives; + return CollectiveExecuteParams{run_options->run_id(), device_ordinal, GlobalDeviceId(run_options->device_ordinal()), - run_options->device_assignment()}; + run_options->device_assignment(), collectives}; } Thunk::CollectiveExecuteParams::CollectiveExecuteParams( RunId run_id, int64_t local_device_ordinal, GlobalDeviceId global_device_id, - const DeviceAssignment* device_assignment) + const DeviceAssignment* device_assignment, + CollectivesInterface* collectives) : run_id(run_id), local_device_ordinal(local_device_ordinal), global_device_id(global_device_id), - device_assignment(device_assignment) {} + device_assignment(device_assignment), + collectives(collectives) {} tsl::AsyncValueRef Thunk::OkExecuteEvent() { static tsl::AsyncValueOwningRef* event = [] { diff --git a/third_party/xla/xla/service/cpu/runtime/thunk.h b/third_party/xla/xla/service/cpu/runtime/thunk.h index 8074412330f873..3d1a7582ba5578 100644 --- a/third_party/xla/xla/service/cpu/runtime/thunk.h +++ b/third_party/xla/xla/service/cpu/runtime/thunk.h @@ -29,6 +29,7 @@ limitations under the License. 
#include "absl/status/statusor.h" #include "xla/executable_run_options.h" #include "xla/runtime/buffer_use.h" +#include "xla/service/cpu/collectives_interface.h" #include "xla/service/cpu/runtime/buffer_allocations.h" #include "xla/service/cpu/xfeed_manager.h" #include "xla/service/global_device_id.h" @@ -126,11 +127,13 @@ class Thunk { GlobalDeviceId global_device_id; const DeviceAssignment* device_assignment = nullptr; + CollectivesInterface* collectives = nullptr; private: CollectiveExecuteParams(RunId run_id, int64_t local_device_ordinal, GlobalDeviceId global_device_id, - const DeviceAssignment* device_assignment); + const DeviceAssignment* device_assignment, + CollectivesInterface* collectives); }; //===--------------------------------------------------------------------===// diff --git a/third_party/xla/xla/service/cpu/thunk_emitter.cc b/third_party/xla/xla/service/cpu/thunk_emitter.cc index fe01f464d28b89..cddcd320aee032 100644 --- a/third_party/xla/xla/service/cpu/thunk_emitter.cc +++ b/third_party/xla/xla/service/cpu/thunk_emitter.cc @@ -165,6 +165,7 @@ absl::StatusOr ThunkEmitter::EmitHloInstruction( case HloOpcode::kExp: case HloOpcode::kExpm1: case HloOpcode::kFloor: + case HloOpcode::kGather: case HloOpcode::kImag: case HloOpcode::kIota: case HloOpcode::kIsFinite: diff --git a/third_party/xla/xla/tests/all_reduce_test.cc b/third_party/xla/xla/tests/all_reduce_test.cc index 2d449f31c61ac2..714ee5fc0c3a94 100644 --- a/third_party/xla/xla/tests/all_reduce_test.cc +++ b/third_party/xla/xla/tests/all_reduce_test.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include #include "xla/literal.h" From 2ab9b0ab560e2b7591ee666a9a87b250c371a065 Mon Sep 17 00:00:00 2001 From: Feng Wang Date: Wed, 19 Jun 2024 11:27:43 -0700 Subject: [PATCH 029/256] Fix the aggregation of power metrics. 
PiperOrigin-RevId: 644788813 --- tensorflow/core/profiler/protobuf/power_metrics.proto | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tensorflow/core/profiler/protobuf/power_metrics.proto b/tensorflow/core/profiler/protobuf/power_metrics.proto index e3360f44298beb..c5066e34485cea 100644 --- a/tensorflow/core/profiler/protobuf/power_metrics.proto +++ b/tensorflow/core/profiler/protobuf/power_metrics.proto @@ -20,6 +20,8 @@ message PowerComponentMetrics { double max_moving_avg_power_10ms = 6; // (FW only) The timescale in us to compute moving averages. uint32 timescale_us = 7; + // The number of samples. + uint64 sample_count = 8; } message PowerMetrics { From b92ced390c24ef0017532a13cc557220ccee45a2 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Wed, 19 Jun 2024 12:14:38 -0700 Subject: [PATCH 030/256] [xla:cpu] Add ReplicaId thunk PiperOrigin-RevId: 644797361 --- third_party/xla/xla/service/cpu/BUILD | 1 + third_party/xla/xla/service/cpu/runtime/BUILD | 24 ++++++ .../service/cpu/runtime/replica_id_thunk.cc | 80 +++++++++++++++++++ .../service/cpu/runtime/replica_id_thunk.h | 45 +++++++++++ .../xla/xla/service/cpu/runtime/thunk.cc | 12 ++- .../xla/xla/service/cpu/runtime/thunk.h | 2 + .../xla/xla/service/cpu/thunk_emitter.cc | 14 ++++ .../xla/xla/service/cpu/thunk_emitter.h | 3 + third_party/xla/xla/tests/BUILD | 2 + .../xla/xla/tests/collective_ops_test.cc | 1 + 10 files changed, 180 insertions(+), 4 deletions(-) create mode 100644 third_party/xla/xla/service/cpu/runtime/replica_id_thunk.cc create mode 100644 third_party/xla/xla/service/cpu/runtime/replica_id_thunk.h diff --git a/third_party/xla/xla/service/cpu/BUILD b/third_party/xla/xla/service/cpu/BUILD index 6753c4595dba50..d42ad330d1b74e 100644 --- a/third_party/xla/xla/service/cpu/BUILD +++ b/third_party/xla/xla/service/cpu/BUILD @@ -821,6 +821,7 @@ cc_library( "//xla/service/cpu/runtime:infeed_thunk", "//xla/service/cpu/runtime:kernel_thunk", "//xla/service/cpu/runtime:outfeed_thunk", + 
"//xla/service/cpu/runtime:replica_id_thunk", "//xla/service/cpu/runtime:rng_state_thunk", "//xla/service/cpu/runtime:thunk", "//xla/service/cpu/runtime:while_thunk", diff --git a/third_party/xla/xla/service/cpu/runtime/BUILD b/third_party/xla/xla/service/cpu/runtime/BUILD index ddc8670ef7964a..5d2938ece9f9a5 100644 --- a/third_party/xla/xla/service/cpu/runtime/BUILD +++ b/third_party/xla/xla/service/cpu/runtime/BUILD @@ -338,6 +338,30 @@ cc_library( ], ) +cc_library( + name = "replica_id_thunk", + srcs = ["replica_id_thunk.cc"], + hdrs = ["replica_id_thunk.h"], + deps = [ + ":thunk", + "//xla:status_macros", + "//xla:util", + "//xla/runtime:buffer_use", + "//xla/service:buffer_assignment", + "//xla/service/cpu:cpu_runtime", + "//xla/stream_executor", + "//xla/tsl/concurrency:async_value", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/profiler/lib:traceme", + ], +) + cc_library( name = "infeed_thunk", srcs = ["infeed_thunk.cc"], diff --git a/third_party/xla/xla/service/cpu/runtime/replica_id_thunk.cc b/third_party/xla/xla/service/cpu/runtime/replica_id_thunk.cc new file mode 100644 index 00000000000000..974e39418f148d --- /dev/null +++ b/third_party/xla/xla/service/cpu/runtime/replica_id_thunk.cc @@ -0,0 +1,80 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/cpu/runtime/replica_id_thunk.h" + +#include +#include +#include +#include + +#include "absl/memory/memory.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "xla/runtime/buffer_use.h" +#include "xla/service/buffer_assignment.h" +#include "xla/service/cpu/runtime/thunk.h" +#include "xla/status_macros.h" +#include "xla/stream_executor/device_memory.h" +#include "xla/tsl/concurrency/async_value_ref.h" +#include "tsl/platform/logging.h" +#include "tsl/platform/statusor.h" +#include "tsl/profiler/lib/traceme.h" + +namespace xla::cpu { + +absl::StatusOr> ReplicaIdThunk::Create( + Info info, BufferAllocation::Slice replica_id_buffer) { + return absl::WrapUnique( + new ReplicaIdThunk(std::move(info), replica_id_buffer)); +} + +ReplicaIdThunk::ReplicaIdThunk(Info info, + BufferAllocation::Slice replica_id_buffer) + : Thunk(Kind::kReplicaId, info), replica_id_buffer_(replica_id_buffer) {} + +tsl::AsyncValueRef ReplicaIdThunk::Execute( + const ExecuteParams& params) { + tsl::profiler::TraceMe trace([&] { return TraceMeEncode(); }); + + TF_ASSIGN_OR_RETURN( + se::DeviceMemoryBase replica_id_data, + params.buffer_allocations->GetDeviceAddress(replica_id_buffer_)); + + TF_RET_CHECK(replica_id_data.size() == sizeof(int32_t)) + << "Replica id buffer must be able to fit replica id value"; + + TF_RET_CHECK(params.collective_params) + << "Replica id requires collective params"; + + TF_ASSIGN_OR_RETURN( + int32_t replica_id, + params.collective_params->device_assignment->ReplicaIdForDevice( + params.collective_params->global_device_id)); + + VLOG(3) << absl::StreamFormat("Replica id: %d", replica_id); + VLOG(3) << absl::StreamFormat(" replica_id: slice %s (%p)", + replica_id_buffer_.ToString(), + replica_id_data.opaque()); + + 
std::memcpy(replica_id_data.opaque(), &replica_id, sizeof(int32_t)); + return OkExecuteEvent(); +} + +ReplicaIdThunk::BufferUses ReplicaIdThunk::buffer_uses() const { + return {BufferUse::Write(replica_id_buffer_)}; +} + +} // namespace xla::cpu diff --git a/third_party/xla/xla/service/cpu/runtime/replica_id_thunk.h b/third_party/xla/xla/service/cpu/runtime/replica_id_thunk.h new file mode 100644 index 00000000000000..22f61d372186c5 --- /dev/null +++ b/third_party/xla/xla/service/cpu/runtime/replica_id_thunk.h @@ -0,0 +1,45 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef XLA_SERVICE_CPU_RUNTIME_REPLICA_ID_THUNK_H_ +#define XLA_SERVICE_CPU_RUNTIME_REPLICA_ID_THUNK_H_ + +#include + +#include "absl/status/statusor.h" +#include "xla/service/buffer_assignment.h" +#include "xla/service/cpu/runtime/thunk.h" +#include "xla/tsl/concurrency/async_value_ref.h" + +namespace xla::cpu { + +class ReplicaIdThunk final : public Thunk { + public: + static absl::StatusOr> Create( + Info info, BufferAllocation::Slice replica_id_buffer); + + tsl::AsyncValueRef Execute(const ExecuteParams& params) final; + + BufferUses buffer_uses() const final; + + private: + ReplicaIdThunk(Info info, BufferAllocation::Slice replica_id_buffer); + + BufferAllocation::Slice replica_id_buffer_; +}; + +} // namespace xla::cpu + +#endif // XLA_SERVICE_CPU_RUNTIME_REPLICA_ID_THUNK_H_ diff --git a/third_party/xla/xla/service/cpu/runtime/thunk.cc b/third_party/xla/xla/service/cpu/runtime/thunk.cc index c60495b3779631..590da950ca9fc7 100644 --- a/third_party/xla/xla/service/cpu/runtime/thunk.cc +++ b/third_party/xla/xla/service/cpu/runtime/thunk.cc @@ -42,22 +42,26 @@ std::string_view Thunk::KindToString(Kind kind) { return "all-reduce"; case Kind::kCall: return "call"; - case Kind::kCopy: - return "copy"; case Kind::kConditional: return "conditional"; + case Kind::kCopy: + return "copy"; case Kind::kDot: return "dot"; case Kind::kFft: return "fft"; case Kind::kInfeed: return "infeed"; - case Kind::kRngGetAndUpdateState: - return "rng-get-and-update-state"; case Kind::kKernel: return "kernel"; case Kind::kOutfeed: return "outfeed"; + case Kind::kPartitionId: + return "partition-id"; + case Kind::kReplicaId: + return "replica-id"; + case Kind::kRngGetAndUpdateState: + return "rng-get-and-update-state"; case Kind::kWhile: return "while"; } diff --git a/third_party/xla/xla/service/cpu/runtime/thunk.h b/third_party/xla/xla/service/cpu/runtime/thunk.h index 3d1a7582ba5578..8cf5fe5af93918 
100644 --- a/third_party/xla/xla/service/cpu/runtime/thunk.h +++ b/third_party/xla/xla/service/cpu/runtime/thunk.h @@ -70,6 +70,8 @@ class Thunk { kInfeed, kKernel, kOutfeed, + kPartitionId, + kReplicaId, kRngGetAndUpdateState, kWhile, }; diff --git a/third_party/xla/xla/service/cpu/thunk_emitter.cc b/third_party/xla/xla/service/cpu/thunk_emitter.cc index cddcd320aee032..d9ceb43aacace3 100644 --- a/third_party/xla/xla/service/cpu/thunk_emitter.cc +++ b/third_party/xla/xla/service/cpu/thunk_emitter.cc @@ -45,6 +45,7 @@ limitations under the License. #include "xla/service/cpu/runtime/infeed_thunk.h" #include "xla/service/cpu/runtime/kernel_thunk.h" #include "xla/service/cpu/runtime/outfeed_thunk.h" +#include "xla/service/cpu/runtime/replica_id_thunk.h" #include "xla/service/cpu/runtime/rng_state_thunk.h" #include "xla/service/cpu/runtime/thunk.h" #include "xla/service/cpu/runtime/while_thunk.h" @@ -199,6 +200,11 @@ absl::StatusOr ThunkEmitter::EmitHloInstruction( case HloOpcode::kXor: return EmitElementalKernelThunk(instruction); + // ReplicaId and PartitionId identify the location of the current device in + // a logical grid of communicating devices. 
+ case HloOpcode::kReplicaId: + return EmitReplicaIdThunk(instruction); + case HloOpcode::kAllReduce: return EmitAllReduceThunk(instruction); @@ -489,6 +495,14 @@ absl::StatusOr ThunkEmitter::EmitDotThunk( } } +absl::StatusOr ThunkEmitter::EmitReplicaIdThunk( + const HloInstruction* instruction) { + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice replica_id_buffer, + GetAllocationSlice(instruction)); + return ThunkSequence::Of(ThunkInfo(instruction), + replica_id_buffer); +} + absl::StatusOr ThunkEmitter::EmitFftThunk( const HloInstruction* instruction) { TF_RETURN_IF_ERROR(ElementTypesSameAndSupported( diff --git a/third_party/xla/xla/service/cpu/thunk_emitter.h b/third_party/xla/xla/service/cpu/thunk_emitter.h index 685dd563b94e77..c1f88e15580676 100644 --- a/third_party/xla/xla/service/cpu/thunk_emitter.h +++ b/third_party/xla/xla/service/cpu/thunk_emitter.h @@ -107,6 +107,9 @@ class ThunkEmitter { absl::StatusOr EmitDotThunk(const HloInstruction* instruction); + absl::StatusOr EmitReplicaIdThunk( + const HloInstruction* instruction); + // Returns the list of buffer allocation slices assigned to the given // instruction that will be passed to the host kernel as arguments: a // flattened list of all the leaf buffers for all operands and result. 
We do diff --git a/third_party/xla/xla/tests/BUILD b/third_party/xla/xla/tests/BUILD index b841b47a70624f..7250e59f99f2a2 100644 --- a/third_party/xla/xla/tests/BUILD +++ b/third_party/xla/xla/tests/BUILD @@ -2167,6 +2167,7 @@ xla_test( "gpu", "cpu", ], + tags = ["test_xla_cpu_thunks"], deps = [ ":hlo_test_base", ":literal_test_util", @@ -2175,6 +2176,7 @@ xla_test( ":xla_internal_test_main", "//xla:literal", "//xla:shape_util", + "//xla/service:hlo_module_config", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@local_tsl//tsl/lib/core:status_test_util", diff --git a/third_party/xla/xla/tests/collective_ops_test.cc b/third_party/xla/xla/tests/collective_ops_test.cc index 6c1aee7c677d76..e8865f43cc429b 100644 --- a/third_party/xla/xla/tests/collective_ops_test.cc +++ b/third_party/xla/xla/tests/collective_ops_test.cc @@ -22,6 +22,7 @@ limitations under the License. #include "absl/types/span.h" #include "xla/literal.h" #include "xla/primitive_util.h" +#include "xla/service/hlo_module_config.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/literal_test_util.h" #include "xla/tests/test_macros.h" From f84430a835333b07ed0c8121ce05dad3a93bca6b Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Wed, 19 Jun 2024 13:18:57 -0700 Subject: [PATCH 031/256] [xla:cpu] Add support for ReduceScatter thunk + refactor collective thunk to extract shared code PiperOrigin-RevId: 644807752 --- third_party/xla/xla/service/cpu/BUILD | 2 + third_party/xla/xla/service/cpu/runtime/BUILD | 44 ++++- .../service/cpu/runtime/all_reduce_thunk.cc | 159 ++++-------------- .../service/cpu/runtime/all_reduce_thunk.h | 24 +-- .../service/cpu/runtime/collective_thunk.cc | 148 +++++++++++++++- .../service/cpu/runtime/collective_thunk.h | 58 ++++++- .../cpu/runtime/reduce_scatter_thunk.cc | 101 +++++++++++ .../cpu/runtime/reduce_scatter_thunk.h | 48 ++++++ .../xla/xla/service/cpu/runtime/thunk.cc | 2 + .../xla/xla/service/cpu/runtime/thunk.h | 1 + 
.../xla/xla/service/cpu/thunk_emitter.cc | 91 +++++++--- .../xla/xla/service/cpu/thunk_emitter.h | 9 +- .../xla/xla/tests/collective_ops_test.cc | 1 + 13 files changed, 513 insertions(+), 175 deletions(-) create mode 100644 third_party/xla/xla/service/cpu/runtime/reduce_scatter_thunk.cc create mode 100644 third_party/xla/xla/service/cpu/runtime/reduce_scatter_thunk.h diff --git a/third_party/xla/xla/service/cpu/BUILD b/third_party/xla/xla/service/cpu/BUILD index d42ad330d1b74e..9b30bc9c159cf1 100644 --- a/third_party/xla/xla/service/cpu/BUILD +++ b/third_party/xla/xla/service/cpu/BUILD @@ -814,6 +814,7 @@ cc_library( "//xla/service/cpu:dot_op_emitter", "//xla/service/cpu/runtime:all_reduce_thunk", "//xla/service/cpu/runtime:call_thunk", + "//xla/service/cpu/runtime:collective_thunk", "//xla/service/cpu/runtime:conditional_thunk", "//xla/service/cpu/runtime:copy_thunk", "//xla/service/cpu/runtime:dot_thunk", @@ -821,6 +822,7 @@ cc_library( "//xla/service/cpu/runtime:infeed_thunk", "//xla/service/cpu/runtime:kernel_thunk", "//xla/service/cpu/runtime:outfeed_thunk", + "//xla/service/cpu/runtime:reduce_scatter_thunk", "//xla/service/cpu/runtime:replica_id_thunk", "//xla/service/cpu/runtime:rng_state_thunk", "//xla/service/cpu/runtime:thunk", diff --git a/third_party/xla/xla/service/cpu/runtime/BUILD b/third_party/xla/xla/service/cpu/runtime/BUILD index 5d2938ece9f9a5..dd55c32e5dda58 100644 --- a/third_party/xla/xla/service/cpu/runtime/BUILD +++ b/third_party/xla/xla/service/cpu/runtime/BUILD @@ -194,7 +194,38 @@ cc_library( "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", - "@eigen_archive//:eigen3", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:numbers", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/profiler/lib:traceme", + ], +) + +cc_library( + name = "reduce_scatter_thunk", + srcs = ["reduce_scatter_thunk.cc"], + hdrs = 
["reduce_scatter_thunk.h"], + deps = [ + ":collective_thunk", + ":thunk", + "//xla:shape_util", + "//xla:status_macros", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/runtime:buffer_use", + "//xla/service:buffer_assignment", + "//xla/service:collective_ops_utils", + "//xla/service/cpu:collectives_interface", + "//xla/stream_executor", + "//xla/tsl/concurrency:async_value", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:numbers", @@ -209,18 +240,29 @@ cc_library( hdrs = ["collective_thunk.h"], deps = [ ":thunk", + "//xla:shape_util", "//xla:status_macros", "//xla:util", "//xla:xla_data_proto_cc", + "//xla/runtime:buffer_use", + "//xla/service:buffer_assignment", "//xla/service:collective_ops_utils", "//xla/service:computation_placer", "//xla/service:global_device_id", + "//xla/service/cpu:collectives_interface", + "//xla/stream_executor:device_memory", + "//xla/tsl/concurrency:async_value", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/time", "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:statusor", ], ) diff --git a/third_party/xla/xla/service/cpu/runtime/all_reduce_thunk.cc b/third_party/xla/xla/service/cpu/runtime/all_reduce_thunk.cc index aa502f45b24c16..d996c70d63edad 100644 --- a/third_party/xla/xla/service/cpu/runtime/all_reduce_thunk.cc 
+++ b/third_party/xla/xla/service/cpu/runtime/all_reduce_thunk.cc @@ -15,7 +15,6 @@ limitations under the License. #include "xla/service/cpu/runtime/all_reduce_thunk.h" -#include #include #include #include @@ -23,11 +22,11 @@ limitations under the License. #include "absl/container/inlined_vector.h" #include "absl/memory/memory.h" +#include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_format.h" #include "absl/types/span.h" #include "xla/primitive_util.h" -#include "xla/runtime/buffer_use.h" #include "xla/service/buffer_assignment.h" #include "xla/service/collective_ops_utils.h" #include "xla/service/cpu/collectives_interface.h" @@ -35,8 +34,6 @@ limitations under the License. #include "xla/service/cpu/runtime/thunk.h" #include "xla/shape.h" #include "xla/shape_util.h" -#include "xla/status_macros.h" -#include "xla/stream_executor/device_memory.h" #include "xla/tsl/concurrency/async_value_ref.h" #include "xla/util.h" #include "tsl/platform/errors.h" @@ -45,160 +42,76 @@ limitations under the License. #include "tsl/profiler/lib/traceme.h" namespace xla::cpu { -namespace { - -static bool IsDataTypeSupportedByCollectiveReduce(PrimitiveType datatype) { - switch (datatype) { - case PRED: - case S8: - case U8: - case S16: - case U16: - case S32: - case U32: - case S64: - case U64: - case F16: - case F32: - case F64: - case C64: - case C128: - return true; - default: - return false; - } -} - -} // namespace absl::StatusOr> AllReduceThunk::Create( Info info, ReductionKind reduction_kind, OpParams op_params, - absl::Span source_buffers, - absl::Span source_shapes, - absl::Span destination_buffers, - absl::Span destination_shapes, bool single_replica) { - auto datatype = source_shapes[0].element_type(); - - // Check that the data types are supported. 
+ OpBuffers op_buffers, bool single_replica) { + auto datatype = op_buffers.source_shapes[0].element_type(); if (!IsDataTypeSupportedByCollectiveReduce(datatype)) { return Unimplemented("AllReduce for datatype '%s' is not supported", primitive_util::LowercasePrimitiveTypeName(datatype)); } - return absl::WrapUnique(new AllReduceThunk( - std::move(info), reduction_kind, op_params, source_buffers, source_shapes, - destination_buffers, destination_shapes, single_replica)); + return absl::WrapUnique(new AllReduceThunk(std::move(info), reduction_kind, + op_params, std::move(op_buffers), + single_replica)); } -AllReduceThunk::AllReduceThunk( - Info info, ReductionKind reduction_kind, OpParams op_params, - absl::Span source_buffers, - absl::Span source_shapes, - absl::Span destination_buffers, - absl::Span destination_shapes, bool single_replica) - : CollectiveThunk(Kind::kAllReduce, info, op_params), +AllReduceThunk::AllReduceThunk(Info info, ReductionKind reduction_kind, + OpParams op_params, OpBuffers op_buffers, + bool single_replica) + : CollectiveThunk(Kind::kAllReduce, info, op_params, std::move(op_buffers)), reduction_kind_(reduction_kind), - source_buffers_(source_buffers.begin(), source_buffers.end()), - source_shapes_(source_shapes.begin(), source_shapes.end()), - destination_buffers_(destination_buffers.begin(), - destination_buffers.end()), - destination_shapes_(destination_shapes.begin(), destination_shapes.end()), single_replica_(single_replica) {} tsl::AsyncValueRef AllReduceThunk::Execute( const ExecuteParams& params) { tsl::profiler::TraceMe trace([&] { return TraceMeEncode(); }); - size_t num_srcs = source_buffers_.size(); - size_t num_dsts = destination_buffers_.size(); - DCHECK_EQ(num_srcs, num_dsts) << "Number of src and dst buffers must match"; + TF_ASSIGN_OR_RETURN(OpDeviceMemory data, GetOpDeviceMemory(params)); VLOG(3) << absl::StreamFormat( "AllReduce: #source_buffers=%d, #destination_buffers=%d, " "reduction_kind=%s, single_replica=%v", - 
num_srcs, num_dsts, ReductionKindToString(reduction_kind_), - single_replica_); - - absl::InlinedVector source_data(num_srcs); - for (int i = 0; i < num_srcs; ++i) { - TF_ASSIGN_OR_RETURN( - source_data[i], - params.buffer_allocations->GetDeviceAddress(source_buffers_[i])); + data.source.size(), data.destination.size(), + ReductionKindToString(reduction_kind_), single_replica_); + + for (int i = 0; i < data.source.size(); ++i) { VLOG(3) << absl::StreamFormat( - " src: %s in slice %s (%p)", source_shapes_[i].ToString(true), - source_buffers_[i].ToString(), source_data[i].opaque()); + " src: %s in slice %s (%p)", source_shape(i).ToString(true), + source_buffer(i).ToString(), data.source[i].opaque()); } - absl::InlinedVector destination_data(num_dsts); - for (int i = 0; i < num_dsts; ++i) { - TF_ASSIGN_OR_RETURN( - destination_data[i], - params.buffer_allocations->GetDeviceAddress(destination_buffers_[i])); + for (int i = 0; i < data.destination.size(); ++i) { VLOG(3) << absl::StreamFormat( - " dst: %s in slice %s (%p)", destination_shapes_[i].ToString(true), - destination_buffers_[i].ToString(), destination_data[i].opaque()); + " dst: %s in slice %s (%p)", destination_shape(i).ToString(true), + destination_buffer(i).ToString(), data.destination[i].opaque()); } // Handle single-replica case by copying the source to the destination. if (single_replica_) { - DCHECK_EQ(source_data.size(), destination_data.size()); - for (int i = 0; i < num_srcs; ++i) { - std::memcpy(destination_data[i].opaque(), source_data[i].opaque(), - destination_data[i].size()); + DCHECK_EQ(data.source.size(), data.destination.size()); + for (int i = 0; i < data.source.size(); ++i) { + std::memcpy(data.destination[i].opaque(), data.source[i].opaque(), + data.destination[i].size()); } return OkExecuteEvent(); } - // For multi-replica case, we need access to collectives implementation to be - // able to perform the all-reduce operation collectively with other replicas. 
- CollectiveExecuteParams* collective_params = params.collective_params; - TF_RET_CHECK(collective_params) - << "Collective parameters are not set for all-reduce operation"; - - CollectivesInterface* collectives = collective_params->collectives; - TF_RET_CHECK(collectives) - << "Collectives interface is not set for all-reduce operation"; - - // Find out rendezvous key and rank in global devices for the current device. - TF_ASSIGN_OR_RETURN(RendezvousKey key, GetRendezvousKey(*collective_params)); - TF_ASSIGN_OR_RETURN( - int32_t rank, - RankInGlobalDevices(key, collective_params->global_device_id)); - - VLOG(3) << absl::StreamFormat(" rank=%d, key=%s", rank, key.ToString()); - - TF_ASSIGN_OR_RETURN(std::shared_ptr communicator, - collectives->GetCommunicator(key.global_devices, rank)); - - for (int32_t i = 0; i < num_srcs; ++i) { - const Shape& shape = source_shapes_[i]; - TF_RETURN_IF_ERROR(communicator->AllReduce( - key, reduction_kind_, shape.element_type(), - ShapeUtil::ElementsIn(shape), source_data[i].opaque(), - destination_data[i].opaque(), DefaultCollectiveTimeout())); - } + return ExecuteWithCommunicator( + params.collective_params, + [&](const RendezvousKey& key, CollectivesCommunicator& comm) { + for (int32_t i = 0; i < data.source.size(); ++i) { + const Shape& shape = destination_shape(i); + TF_RETURN_IF_ERROR(comm.AllReduce( + key, reduction_kind_, shape.element_type(), + ShapeUtil::ElementsIn(shape), data.source[i].opaque(), + data.destination[i].opaque(), DefaultCollectiveTimeout())); + } + return absl::OkStatus(); + }); return OkExecuteEvent(); } -Thunk::BufferUses AllReduceThunk::buffer_uses() const { - BufferUses uses; - uses.reserve(source_buffers_.size() + destination_buffers_.size()); - for (auto& source_buffer : source_buffers_) { - uses.push_back(BufferUse::Read(source_buffer)); - } - for (auto& destination_buffer : destination_buffers_) { - uses.push_back(BufferUse::Write(destination_buffer)); - } - - // TODO(ezhulenev): It is a hack to 
make sure that we execute all collective - // operations in the same order as in HLO schedule, because otherwise racing - // collectives lead to undefined behavior. Instead we should correctly model - // side effects of Thunks. - static auto* fake_alloc = new BufferAllocation(0, 1, 0); - uses.push_back(BufferUse::Write(BufferAllocation::Slice(fake_alloc, 0, 1))); - - return uses; -} - } // namespace xla::cpu diff --git a/third_party/xla/xla/service/cpu/runtime/all_reduce_thunk.h b/third_party/xla/xla/service/cpu/runtime/all_reduce_thunk.h index 4fae44c68bfd12..9af3d72fea4070 100644 --- a/third_party/xla/xla/service/cpu/runtime/all_reduce_thunk.h +++ b/third_party/xla/xla/service/cpu/runtime/all_reduce_thunk.h @@ -17,14 +17,10 @@ limitations under the License. #define XLA_SERVICE_CPU_RUNTIME_ALL_REDUCE_THUNK_H_ #include -#include #include "absl/status/statusor.h" -#include "absl/types/span.h" -#include "xla/service/buffer_assignment.h" #include "xla/service/collective_ops_utils.h" #include "xla/service/cpu/runtime/collective_thunk.h" -#include "xla/shape.h" #include "xla/tsl/concurrency/async_value_ref.h" #include "xla/xla_data.pb.h" @@ -36,31 +32,15 @@ class AllReduceThunk final : public CollectiveThunk { static absl::StatusOr> Create( Info info, ReductionKind reduction_kind, OpParams op_params, - absl::Span source_buffers, - absl::Span source_shapes, - absl::Span destination_buffers, - absl::Span destination_shapes, bool single_replica); + OpBuffers op_buffers, bool single_replica); tsl::AsyncValueRef Execute(const ExecuteParams& params) final; - BufferUses buffer_uses() const final; - private: AllReduceThunk(Info info, ReductionKind reduction_kind, OpParams op_params, - absl::Span source_buffers, - absl::Span source_shapes, - absl::Span destination_buffers, - absl::Span destination_shapes, - bool single_replica); + OpBuffers op_buffers, bool single_replica); ReductionKind reduction_kind_; - - std::vector source_buffers_; - std::vector source_shapes_; - - 
std::vector destination_buffers_; - std::vector destination_shapes_; - bool single_replica_; }; diff --git a/third_party/xla/xla/service/cpu/runtime/collective_thunk.cc b/third_party/xla/xla/service/cpu/runtime/collective_thunk.cc index 3f018567942f5d..75139bc0fd68e9 100644 --- a/third_party/xla/xla/service/cpu/runtime/collective_thunk.cc +++ b/third_party/xla/xla/service/cpu/runtime/collective_thunk.cc @@ -15,32 +15,113 @@ limitations under the License. #include "xla/service/cpu/runtime/collective_thunk.h" +#include #include #include +#include #include #include #include #include #include "absl/algorithm/container.h" +#include "absl/container/inlined_vector.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "absl/strings/str_join.h" #include "absl/time/time.h" +#include "absl/types/span.h" +#include "xla/runtime/buffer_use.h" +#include "xla/service/buffer_assignment.h" #include "xla/service/collective_ops_utils.h" #include "xla/service/computation_placer.h" +#include "xla/service/cpu/collectives_interface.h" #include "xla/service/cpu/runtime/thunk.h" #include "xla/service/global_device_id.h" +#include "xla/shape.h" #include "xla/status_macros.h" +#include "xla/stream_executor/device_memory.h" +#include "xla/tsl/concurrency/async_value_ref.h" #include "xla/util.h" #include "xla/xla_data.pb.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/logging.h" #include "tsl/platform/statusor.h" namespace xla::cpu { CollectiveThunk::CollectiveThunk(Kind kind, Thunk::Info info, - OpParams op_params) - : Thunk(kind, info), op_params_(std::move(op_params)) {} + OpParams op_params, OpBuffers op_buffers) + : Thunk(kind, info), + op_params_(std::move(op_params)), + op_buffers_(std::move(op_buffers)) {} + +Thunk::BufferUses CollectiveThunk::buffer_uses() const { + BufferUses uses; + uses.reserve(source_buffers().size() + destination_buffers().size()); + for (auto& source_buffer : source_buffers()) { + 
uses.push_back(BufferUse::Read(source_buffer)); + } + for (auto& destination_buffer : destination_buffers()) { + uses.push_back(BufferUse::Write(destination_buffer)); + } + + // TODO(ezhulenev): It is a hack to make sure that we execute all collective + // operations in the same order as in HLO schedule, because otherwise racing + // collectives lead to undefined behavior. Instead we should correctly model + // side effects of Thunks. + static auto* fake_alloc = new BufferAllocation(0, 1, 0); + uses.push_back(BufferUse::Write(BufferAllocation::Slice(fake_alloc, 0, 1))); + + return uses; +} + +bool CollectiveThunk::IsDataTypeSupportedByCollectiveReduce( + PrimitiveType datatype) { + switch (datatype) { + case PRED: + case S8: + case U8: + case S16: + case U16: + case S32: + case U32: + case S64: + case U64: + case F16: + case F32: + case F64: + case C64: + case C128: + return true; + default: + return false; + } +} + +absl::StatusOr +CollectiveThunk::GetOpDeviceMemory(const ExecuteParams& params) { + size_t num_srcs = source_buffers().size(); + size_t num_dsts = destination_buffers().size(); + DCHECK_EQ(num_srcs, num_dsts) << "Number of src and dst buffers must match"; + + absl::InlinedVector source_data(num_srcs); + for (int i = 0; i < num_srcs; ++i) { + TF_ASSIGN_OR_RETURN( + source_data[i], + params.buffer_allocations->GetDeviceAddress(source_buffer(i))); + } + + absl::InlinedVector destination_data(num_dsts); + for (int i = 0; i < num_dsts; ++i) { + TF_ASSIGN_OR_RETURN( + destination_data[i], + params.buffer_allocations->GetDeviceAddress(destination_buffer(i))); + } + + return OpDeviceMemory{std::move(source_data), std::move(destination_data)}; +} absl::Duration CollectiveThunk::DefaultCollectiveTimeout() { return absl::Minutes(30); @@ -84,4 +165,67 @@ absl::StatusOr CollectiveThunk::RankInGlobalDevices( return std::distance(key.global_devices.begin(), it); } +tsl::AsyncValueRef +CollectiveThunk::ExecuteWithCommunicator( + const Thunk::CollectiveExecuteParams* 
params, Callback callback) { + // Check that we have access to collectives interface implementation and + // parameters that define our "position" in a collective clique. + TF_RET_CHECK(params) + << "Collective parameters are not set for collective operation"; + + CollectivesInterface* collectives = params->collectives; + TF_RET_CHECK(collectives) + << "Collectives interface is not set for collective operation"; + + // Find out rendezvous key and rank in global devices for the current device. + TF_ASSIGN_OR_RETURN(RendezvousKey key, GetRendezvousKey(*params)); + TF_ASSIGN_OR_RETURN(int32_t rank, + RankInGlobalDevices(key, params->global_device_id)); + + VLOG(3) << absl::StreamFormat(" rank=%d, key=%s", rank, key.ToString()); + + TF_ASSIGN_OR_RETURN(std::shared_ptr communicator, + collectives->GetCommunicator(key.global_devices, rank)); + + TF_RETURN_IF_ERROR(callback(key, *communicator)); + + return OkExecuteEvent(); +} + +const BufferAllocation::Slice& CollectiveThunk::source_buffer( + int64_t index) const { + return op_buffers_.source_buffers[index]; +} + +absl::Span CollectiveThunk::source_buffers() + const { + return op_buffers_.source_buffers; +} + +const Shape& CollectiveThunk::source_shape(int64_t index) const { + return op_buffers_.source_shapes[index]; +} + +absl::Span CollectiveThunk::source_shapes() const { + return op_buffers_.source_shapes; +} + +const BufferAllocation::Slice& CollectiveThunk::destination_buffer( + int64_t index) const { + return op_buffers_.destination_buffers[index]; +} + +absl::Span CollectiveThunk::destination_buffers() + const { + return op_buffers_.destination_buffers; +} + +const Shape& CollectiveThunk::destination_shape(int64_t index) const { + return op_buffers_.destination_shapes[index]; +} + +absl::Span CollectiveThunk::destination_shapes() const { + return op_buffers_.destination_shapes; +} + } // namespace xla::cpu diff --git a/third_party/xla/xla/service/cpu/runtime/collective_thunk.h 
b/third_party/xla/xla/service/cpu/runtime/collective_thunk.h index 588b3b43f7d3a3..f3b27a52b501cc 100644 --- a/third_party/xla/xla/service/cpu/runtime/collective_thunk.h +++ b/third_party/xla/xla/service/cpu/runtime/collective_thunk.h @@ -17,14 +17,24 @@ limitations under the License. #define XLA_SERVICE_CPU_RUNTIME_COLLECTIVE_THUNK_H_ #include +#include #include #include +#include "absl/container/inlined_vector.h" +#include "absl/functional/any_invocable.h" +#include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/time/time.h" +#include "absl/types/span.h" +#include "xla/service/buffer_assignment.h" #include "xla/service/collective_ops_utils.h" +#include "xla/service/cpu/collectives_interface.h" #include "xla/service/cpu/runtime/thunk.h" #include "xla/service/global_device_id.h" +#include "xla/shape.h" +#include "xla/stream_executor/device_memory.h" +#include "xla/tsl/concurrency/async_value_ref.h" namespace xla::cpu { @@ -42,11 +52,39 @@ class CollectiveThunk : public Thunk { std::vector group; }; - CollectiveThunk(Kind kind, Thunk::Info info, OpParams op_params); + // Source and destination buffers for the collective operation. + struct OpBuffers { + std::vector source_buffers; + std::vector source_shapes; + + std::vector destination_buffers; + std::vector destination_shapes; + }; + + // Device memory resolved for the collective operation buffers. + struct OpDeviceMemory { + absl::InlinedVector source; + absl::InlinedVector destination; + }; + + CollectiveThunk(Kind kind, Thunk::Info info, OpParams op_params, + OpBuffers op_buffers); const OpParams& op_params() const { return op_params_; } + const OpBuffers& op_buffers() const { return op_buffers_; } + + // Resolves operation's device memory from the buffers and buffer allocations. + absl::StatusOr GetOpDeviceMemory(const ExecuteParams& params); + + BufferUses buffer_uses() const final; protected: + // Callback for collective thunk implementations. 
+ using Callback = absl::AnyInvocable; + + static bool IsDataTypeSupportedByCollectiveReduce(PrimitiveType datatype); + absl::Duration DefaultCollectiveTimeout(); absl::StatusOr GetRendezvousKey( @@ -55,8 +93,26 @@ class CollectiveThunk : public Thunk { absl::StatusOr RankInGlobalDevices(const RendezvousKey& key, GlobalDeviceId device); + // Acquires collective communicator for the given parameters and executes the + // user provided callback with acquired rendezvous key, rank and communicator. + tsl::AsyncValueRef ExecuteWithCommunicator( + const Thunk::CollectiveExecuteParams* params, Callback callback); + + const BufferAllocation::Slice& source_buffer(int64_t index) const; + absl::Span source_buffers() const; + + const Shape& source_shape(int64_t index) const; + absl::Span source_shapes() const; + + const BufferAllocation::Slice& destination_buffer(int64_t index) const; + absl::Span destination_buffers() const; + + const Shape& destination_shape(int64_t index) const; + absl::Span destination_shapes() const; + private: OpParams op_params_; + OpBuffers op_buffers_; }; } // namespace xla::cpu diff --git a/third_party/xla/xla/service/cpu/runtime/reduce_scatter_thunk.cc b/third_party/xla/xla/service/cpu/runtime/reduce_scatter_thunk.cc new file mode 100644 index 00000000000000..59e7c560f3416a --- /dev/null +++ b/third_party/xla/xla/service/cpu/runtime/reduce_scatter_thunk.cc @@ -0,0 +1,101 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "xla/service/cpu/runtime/reduce_scatter_thunk.h" + +#include +#include +#include + +#include "absl/container/inlined_vector.h" +#include "absl/memory/memory.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "xla/primitive_util.h" +#include "xla/service/buffer_assignment.h" +#include "xla/service/collective_ops_utils.h" +#include "xla/service/cpu/collectives_interface.h" +#include "xla/service/cpu/runtime/collective_thunk.h" +#include "xla/service/cpu/runtime/thunk.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/tsl/concurrency/async_value_ref.h" +#include "xla/util.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/logging.h" +#include "tsl/platform/statusor.h" +#include "tsl/profiler/lib/traceme.h" + +namespace xla::cpu { + +absl::StatusOr> ReduceScatterThunk::Create( + Info info, ReductionKind reduction_kind, OpParams op_params, + OpBuffers op_buffers) { + auto datatype = op_buffers.source_shapes[0].element_type(); + if (!IsDataTypeSupportedByCollectiveReduce(datatype)) { + return Unimplemented("ReduceScatter for datatype '%s' is not supported", + primitive_util::LowercasePrimitiveTypeName(datatype)); + } + + return absl::WrapUnique(new ReduceScatterThunk( + std::move(info), reduction_kind, op_params, std::move(op_buffers))); +} + +ReduceScatterThunk::ReduceScatterThunk(Info info, ReductionKind reduction_kind, + OpParams op_params, OpBuffers op_buffers) + : CollectiveThunk(Kind::kReduceScatter, info, op_params, + std::move(op_buffers)), + reduction_kind_(reduction_kind) {} + +tsl::AsyncValueRef +ReduceScatterThunk::Execute(const ExecuteParams& params) { + tsl::profiler::TraceMe trace([&] { return TraceMeEncode(); }); + + TF_ASSIGN_OR_RETURN(OpDeviceMemory data, GetOpDeviceMemory(params)); + + VLOG(3) << absl::StreamFormat( + "ReduceScatter: #source_buffers=%d, 
#destination_buffers=%d, " + "reduction_kind=%s", + data.source.size(), data.destination.size(), + ReductionKindToString(reduction_kind_)); + + for (int i = 0; i < data.source.size(); ++i) { + VLOG(3) << absl::StreamFormat( + " src: %s in slice %s (%p)", source_shape(i).ToString(true), + source_buffer(i).ToString(), data.source[i].opaque()); + } + + for (int i = 0; i < data.destination.size(); ++i) { + VLOG(3) << absl::StreamFormat( + " dst: %s in slice %s (%p)", destination_shape(i).ToString(true), + destination_buffer(i).ToString(), data.destination[i].opaque()); + } + + return ExecuteWithCommunicator( + params.collective_params, + [&](const RendezvousKey& key, CollectivesCommunicator& comm) { + for (int32_t i = 0; i < data.source.size(); ++i) { + const Shape& shape = destination_shape(i); + TF_RETURN_IF_ERROR(comm.ReduceScatter( + key, reduction_kind_, shape.element_type(), + ShapeUtil::ElementsIn(shape), data.source[i].opaque(), + data.destination[i].opaque(), DefaultCollectiveTimeout())); + } + return absl::OkStatus(); + }); +} + +} // namespace xla::cpu diff --git a/third_party/xla/xla/service/cpu/runtime/reduce_scatter_thunk.h b/third_party/xla/xla/service/cpu/runtime/reduce_scatter_thunk.h new file mode 100644 index 00000000000000..0c01d653eb2d4c --- /dev/null +++ b/third_party/xla/xla/service/cpu/runtime/reduce_scatter_thunk.h @@ -0,0 +1,48 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef XLA_SERVICE_CPU_RUNTIME_REDUCE_SCATTER_THUNK_H_ +#define XLA_SERVICE_CPU_RUNTIME_REDUCE_SCATTER_THUNK_H_ + +#include + +#include "absl/status/statusor.h" +#include "xla/service/collective_ops_utils.h" +#include "xla/service/cpu/runtime/collective_thunk.h" +#include "xla/tsl/concurrency/async_value_ref.h" +#include "xla/xla_data.pb.h" + +namespace xla::cpu { + +class ReduceScatterThunk final : public CollectiveThunk { + public: + using CollectiveThunk::OpParams; + + static absl::StatusOr> Create( + Info info, ReductionKind reduction_kind, OpParams op_params, + OpBuffers op_buffers); + + tsl::AsyncValueRef Execute(const ExecuteParams& params) final; + + private: + ReduceScatterThunk(Info info, ReductionKind reduction_kind, + OpParams op_params, OpBuffers op_buffers); + + ReductionKind reduction_kind_; +}; + +} // namespace xla::cpu + +#endif // XLA_SERVICE_CPU_RUNTIME_REDUCE_SCATTER_THUNK_H_ diff --git a/third_party/xla/xla/service/cpu/runtime/thunk.cc b/third_party/xla/xla/service/cpu/runtime/thunk.cc index 590da950ca9fc7..0dadb5c4fb9a1e 100644 --- a/third_party/xla/xla/service/cpu/runtime/thunk.cc +++ b/third_party/xla/xla/service/cpu/runtime/thunk.cc @@ -58,6 +58,8 @@ std::string_view Thunk::KindToString(Kind kind) { return "outfeed"; case Kind::kPartitionId: return "partition-id"; + case Kind::kReduceScatter: + return "reduce-scatter"; case Kind::kReplicaId: return "replica-id"; case Kind::kRngGetAndUpdateState: diff --git a/third_party/xla/xla/service/cpu/runtime/thunk.h b/third_party/xla/xla/service/cpu/runtime/thunk.h index 8cf5fe5af93918..5dd9d044c14b78 100644 --- a/third_party/xla/xla/service/cpu/runtime/thunk.h +++ b/third_party/xla/xla/service/cpu/runtime/thunk.h @@ -71,6 +71,7 @@ class Thunk { kKernel, kOutfeed, kPartitionId, + kReduceScatter, kReplicaId, kRngGetAndUpdateState, kWhile, diff --git a/third_party/xla/xla/service/cpu/thunk_emitter.cc 
b/third_party/xla/xla/service/cpu/thunk_emitter.cc index d9ceb43aacace3..0cd2277bf95df0 100644 --- a/third_party/xla/xla/service/cpu/thunk_emitter.cc +++ b/third_party/xla/xla/service/cpu/thunk_emitter.cc @@ -38,6 +38,7 @@ limitations under the License. #include "xla/service/cpu/ir_emitter2.h" #include "xla/service/cpu/runtime/all_reduce_thunk.h" #include "xla/service/cpu/runtime/call_thunk.h" +#include "xla/service/cpu/runtime/collective_thunk.h" #include "xla/service/cpu/runtime/conditional_thunk.h" #include "xla/service/cpu/runtime/copy_thunk.h" #include "xla/service/cpu/runtime/dot_thunk.h" @@ -45,6 +46,7 @@ limitations under the License. #include "xla/service/cpu/runtime/infeed_thunk.h" #include "xla/service/cpu/runtime/kernel_thunk.h" #include "xla/service/cpu/runtime/outfeed_thunk.h" +#include "xla/service/cpu/runtime/reduce_scatter_thunk.h" #include "xla/service/cpu/runtime/replica_id_thunk.h" #include "xla/service/cpu/runtime/rng_state_thunk.h" #include "xla/service/cpu/runtime/thunk.h" @@ -207,6 +209,8 @@ absl::StatusOr ThunkEmitter::EmitHloInstruction( case HloOpcode::kAllReduce: return EmitAllReduceThunk(instruction); + case HloOpcode::kReduceScatter: + return EmitReduceScatterThunk(instruction); // TODO(ezhulenev): Port pad optimizations from IrEmitter. case HloOpcode::kPad: @@ -255,24 +259,38 @@ absl::StatusOr ThunkEmitter::EmitHloInstruction( } } -absl::StatusOr ThunkEmitter::EmitAllReduceThunk( - const HloInstruction* instruction) { - auto* all_reduce = Cast(instruction); - - // Check that we recognize the reduction computation attached to a collective. 
- auto reduction_kind = MatchReductionComputation(all_reduce->to_apply()); - if (!reduction_kind.has_value()) { - return Unimplemented("AllReduce for computation '%s' is not supported", - all_reduce->to_apply()->ToString()); +static absl::StatusOr MatchReductionKind( + const HloComputation* computation) { + if (auto reduction_kind = MatchReductionComputation(computation)) { + return reduction_kind.value(); } + return Unimplemented("Unsupported reduction computation: %s", + computation->ToString()); +} +template +static absl::StatusOr GetCollectiveOpParams( + const CollectiveInstruction* instruction) { + return CollectiveThunk::OpParams{ + /*op_id=*/instruction->channel_id().has_value() + ? instruction->channel_id().value() + : instruction->GetModule()->unique_id(), + /*has_channel_id=*/instruction->channel_id().has_value(), + /*use_global_device_ids=*/instruction->use_global_device_ids(), + /*replica_groups=*/instruction->replica_groups(), + }; +} + +static absl::StatusOr GetCollectiveOpBuffers( + const HloInstruction* instruction, + const BufferAssignment& buffer_assignment) { // Collect buffer slices for all operands. 
std::vector source_buffers; std::vector source_shapes; - for (const HloInstruction* operand : all_reduce->operands()) { + for (const HloInstruction* operand : instruction->operands()) { TF_ASSIGN_OR_RETURN(source_buffers.emplace_back(), - GetAllocationSlice(operand)); + buffer_assignment.GetUniqueSlice(operand, {})); source_shapes.push_back(operand->shape()); } @@ -281,27 +299,54 @@ absl::StatusOr ThunkEmitter::EmitAllReduceThunk( std::vector destination_shapes; for (auto& indexed : ShapeUtil::GetLeafShapes(instruction->shape())) { - TF_ASSIGN_OR_RETURN(destination_buffers.emplace_back(), - GetAllocationSlice(instruction, indexed.index)); + TF_ASSIGN_OR_RETURN( + destination_buffers.emplace_back(), + buffer_assignment.GetUniqueSlice(instruction, indexed.index)); destination_shapes.push_back(indexed.shape); } - AllReduceThunk::OpParams op_params = { - /*op_id=*/all_reduce->channel_id().has_value() - ? all_reduce->channel_id().value() - : all_reduce->GetModule()->unique_id(), - /*has_channel_id=*/all_reduce->channel_id().has_value(), - /*use_global_device_ids=*/all_reduce->use_global_device_ids(), - /*replica_groups=*/all_reduce->replica_groups(), + return CollectiveThunk::OpBuffers{ + /*source_buffers=*/std::move(source_buffers), + /*source_shapes=*/std::move(source_shapes), + /*destination_buffers=*/std::move(destination_buffers), + /*destination_shapes=*/std::move(destination_shapes), }; +} + +absl::StatusOr ThunkEmitter::EmitAllReduceThunk( + const HloInstruction* instruction) { + auto* all_reduce = Cast(instruction); + + TF_ASSIGN_OR_RETURN(ReductionKind reduction_kind, + MatchReductionKind(all_reduce->to_apply())); + TF_ASSIGN_OR_RETURN(AllReduceThunk::OpParams op_params, + GetCollectiveOpParams(all_reduce)); + TF_ASSIGN_OR_RETURN(AllReduceThunk::OpBuffers op_buffers, + GetCollectiveOpBuffers(all_reduce, buffer_assignment_)); bool single_replica = hlo_module_config_.replica_count() == 1 && hlo_module_config_.num_partitions() == 1; return ThunkSequence::Of( - 
ThunkInfo(all_reduce), *reduction_kind, std::move(op_params), - source_buffers, source_shapes, destination_buffers, destination_shapes, - single_replica); + ThunkInfo(all_reduce), reduction_kind, std::move(op_params), + std::move(op_buffers), single_replica); +} + +absl::StatusOr ThunkEmitter::EmitReduceScatterThunk( + const HloInstruction* instruction) { + auto* reduce_scatter = Cast(instruction); + + TF_ASSIGN_OR_RETURN(ReductionKind reduction_kind, + MatchReductionKind(reduce_scatter->to_apply())); + TF_ASSIGN_OR_RETURN(ReduceScatterThunk::OpParams op_params, + GetCollectiveOpParams(reduce_scatter)); + TF_ASSIGN_OR_RETURN( + ReduceScatterThunk::OpBuffers op_buffers, + GetCollectiveOpBuffers(reduce_scatter, buffer_assignment_)); + + return ThunkSequence::Of( + ThunkInfo(reduce_scatter), reduction_kind, std::move(op_params), + std::move(op_buffers)); } absl::StatusOr ThunkEmitter::EmitCallThunk( diff --git a/third_party/xla/xla/service/cpu/thunk_emitter.h b/third_party/xla/xla/service/cpu/thunk_emitter.h index c1f88e15580676..c676a61d24acb7 100644 --- a/third_party/xla/xla/service/cpu/thunk_emitter.h +++ b/third_party/xla/xla/service/cpu/thunk_emitter.h @@ -67,9 +67,6 @@ class ThunkEmitter { absl::StatusOr EmitHloInstruction( const HloInstruction* instruction); - absl::StatusOr EmitAllReduceThunk( - const HloInstruction* instruction); - absl::StatusOr EmitCallThunk( const HloInstruction* instruction); @@ -110,6 +107,12 @@ class ThunkEmitter { absl::StatusOr EmitReplicaIdThunk( const HloInstruction* instruction); + absl::StatusOr EmitAllReduceThunk( + const HloInstruction* instruction); + + absl::StatusOr EmitReduceScatterThunk( + const HloInstruction* instruction); + // Returns the list of buffer allocation slices assigned to the given // instruction that will be passed to the host kernel as arguments: a // flattened list of all the leaf buffers for all operands and result. 
We do diff --git a/third_party/xla/xla/tests/collective_ops_test.cc b/third_party/xla/xla/tests/collective_ops_test.cc index e8865f43cc429b..c2ea50132e8df9 100644 --- a/third_party/xla/xla/tests/collective_ops_test.cc +++ b/third_party/xla/xla/tests/collective_ops_test.cc @@ -16,6 +16,7 @@ limitations under the License. #include #include #include +#include #include #include "absl/strings/str_replace.h" From 185483bbfbcafdd83f3e4229039071e5b796b846 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Wed, 19 Jun 2024 14:01:33 -0700 Subject: [PATCH 032/256] [xla:cpu] Add support for AllGather thunk PiperOrigin-RevId: 644814295 --- third_party/xla/xla/service/cpu/BUILD | 1 + third_party/xla/xla/service/cpu/runtime/BUILD | 32 +++++++ .../service/cpu/runtime/all_gather_thunk.cc | 88 +++++++++++++++++++ .../service/cpu/runtime/all_gather_thunk.h | 41 +++++++++ .../service/cpu/runtime/all_reduce_thunk.h | 2 - .../cpu/runtime/reduce_scatter_thunk.h | 2 - .../xla/xla/service/cpu/runtime/thunk.cc | 2 + .../xla/xla/service/cpu/runtime/thunk.h | 1 + .../xla/xla/service/cpu/thunk_emitter.cc | 16 ++++ .../xla/xla/service/cpu/thunk_emitter.h | 3 + .../xla/xla/tests/collective_ops_test.cc | 1 + 11 files changed, 185 insertions(+), 4 deletions(-) create mode 100644 third_party/xla/xla/service/cpu/runtime/all_gather_thunk.cc create mode 100644 third_party/xla/xla/service/cpu/runtime/all_gather_thunk.h diff --git a/third_party/xla/xla/service/cpu/BUILD b/third_party/xla/xla/service/cpu/BUILD index 9b30bc9c159cf1..59488c45da9fe1 100644 --- a/third_party/xla/xla/service/cpu/BUILD +++ b/third_party/xla/xla/service/cpu/BUILD @@ -812,6 +812,7 @@ cc_library( "//xla/service:collective_ops_utils", "//xla/service:hlo_module_config", "//xla/service/cpu:dot_op_emitter", + "//xla/service/cpu/runtime:all_gather_thunk", "//xla/service/cpu/runtime:all_reduce_thunk", "//xla/service/cpu/runtime:call_thunk", "//xla/service/cpu/runtime:collective_thunk", diff --git 
a/third_party/xla/xla/service/cpu/runtime/BUILD b/third_party/xla/xla/service/cpu/runtime/BUILD index dd55c32e5dda58..08de6665a6b0c3 100644 --- a/third_party/xla/xla/service/cpu/runtime/BUILD +++ b/third_party/xla/xla/service/cpu/runtime/BUILD @@ -170,6 +170,38 @@ cc_library( ], ) +cc_library( + name = "all_gather_thunk", + srcs = ["all_gather_thunk.cc"], + hdrs = ["all_gather_thunk.h"], + deps = [ + ":collective_thunk", + ":thunk", + "//xla:shape_util", + "//xla:status_macros", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/runtime:buffer_use", + "//xla/service:buffer_assignment", + "//xla/service:collective_ops_utils", + "//xla/service/cpu:collectives_interface", + "//xla/stream_executor", + "//xla/tsl/concurrency:async_value", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:numbers", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/profiler/lib:traceme", + ], +) + cc_library( name = "all_reduce_thunk", srcs = ["all_reduce_thunk.cc"], diff --git a/third_party/xla/xla/service/cpu/runtime/all_gather_thunk.cc b/third_party/xla/xla/service/cpu/runtime/all_gather_thunk.cc new file mode 100644 index 00000000000000..fecb85a79bebcd --- /dev/null +++ b/third_party/xla/xla/service/cpu/runtime/all_gather_thunk.cc @@ -0,0 +1,88 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/cpu/runtime/all_gather_thunk.h" + +#include +#include +#include + +#include "absl/container/inlined_vector.h" +#include "absl/memory/memory.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "xla/service/buffer_assignment.h" +#include "xla/service/collective_ops_utils.h" +#include "xla/service/cpu/collectives_interface.h" +#include "xla/service/cpu/runtime/collective_thunk.h" +#include "xla/service/cpu/runtime/thunk.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/tsl/concurrency/async_value_ref.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/logging.h" +#include "tsl/platform/statusor.h" +#include "tsl/profiler/lib/traceme.h" + +namespace xla::cpu { + +absl::StatusOr> AllGatherThunk::Create( + Info info, OpParams op_params, OpBuffers op_buffers) { + return absl::WrapUnique( + new AllGatherThunk(std::move(info), op_params, std::move(op_buffers))); +} + +AllGatherThunk::AllGatherThunk(Info info, OpParams op_params, + OpBuffers op_buffers) + : CollectiveThunk(Kind::kAllGather, info, op_params, + std::move(op_buffers)) {} + +tsl::AsyncValueRef AllGatherThunk::Execute( + const ExecuteParams& params) { + tsl::profiler::TraceMe trace([&] { return TraceMeEncode(); }); + + TF_ASSIGN_OR_RETURN(OpDeviceMemory data, GetOpDeviceMemory(params)); + + VLOG(3) << absl::StreamFormat( + "AllGather: #source_buffers=%d, #destination_buffers=%d", + data.source.size(), 
data.destination.size()); + + for (int i = 0; i < data.source.size(); ++i) { + VLOG(3) << absl::StreamFormat( + " src: %s in slice %s (%p)", source_shape(i).ToString(true), + source_buffer(i).ToString(), data.source[i].opaque()); + } + + for (int i = 0; i < data.destination.size(); ++i) { + VLOG(3) << absl::StreamFormat( + " dst: %s in slice %s (%p)", destination_shape(i).ToString(true), + destination_buffer(i).ToString(), data.destination[i].opaque()); + } + + return ExecuteWithCommunicator( + params.collective_params, + [&](const RendezvousKey& key, CollectivesCommunicator& comm) { + for (int32_t i = 0; i < data.source.size(); ++i) { + const Shape& shape = source_shape(i); + TF_RETURN_IF_ERROR(comm.AllGather( + key, ShapeUtil::ByteSizeOf(shape), data.source[i].opaque(), + data.destination[i].opaque(), DefaultCollectiveTimeout())); + } + return absl::OkStatus(); + }); +} + +} // namespace xla::cpu diff --git a/third_party/xla/xla/service/cpu/runtime/all_gather_thunk.h b/third_party/xla/xla/service/cpu/runtime/all_gather_thunk.h new file mode 100644 index 00000000000000..028824559daad2 --- /dev/null +++ b/third_party/xla/xla/service/cpu/runtime/all_gather_thunk.h @@ -0,0 +1,41 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef XLA_SERVICE_CPU_RUNTIME_ALL_GATHER_THUNK_H_ +#define XLA_SERVICE_CPU_RUNTIME_ALL_GATHER_THUNK_H_ + +#include + +#include "absl/status/statusor.h" +#include "xla/service/cpu/runtime/collective_thunk.h" +#include "xla/tsl/concurrency/async_value_ref.h" +#include "xla/xla_data.pb.h" + +namespace xla::cpu { + +class AllGatherThunk final : public CollectiveThunk { + public: + static absl::StatusOr> Create( + Info info, OpParams op_params, OpBuffers op_buffers); + + tsl::AsyncValueRef Execute(const ExecuteParams& params) final; + + private: + AllGatherThunk(Info info, OpParams op_params, OpBuffers op_buffers); +}; + +} // namespace xla::cpu + +#endif // XLA_SERVICE_CPU_RUNTIME_ALL_GATHER_THUNK_H_ diff --git a/third_party/xla/xla/service/cpu/runtime/all_reduce_thunk.h b/third_party/xla/xla/service/cpu/runtime/all_reduce_thunk.h index 9af3d72fea4070..e85e58e87f774e 100644 --- a/third_party/xla/xla/service/cpu/runtime/all_reduce_thunk.h +++ b/third_party/xla/xla/service/cpu/runtime/all_reduce_thunk.h @@ -28,8 +28,6 @@ namespace xla::cpu { class AllReduceThunk final : public CollectiveThunk { public: - using CollectiveThunk::OpParams; - static absl::StatusOr> Create( Info info, ReductionKind reduction_kind, OpParams op_params, OpBuffers op_buffers, bool single_replica); diff --git a/third_party/xla/xla/service/cpu/runtime/reduce_scatter_thunk.h b/third_party/xla/xla/service/cpu/runtime/reduce_scatter_thunk.h index 0c01d653eb2d4c..4ca5dbaa51f3be 100644 --- a/third_party/xla/xla/service/cpu/runtime/reduce_scatter_thunk.h +++ b/third_party/xla/xla/service/cpu/runtime/reduce_scatter_thunk.h @@ -28,8 +28,6 @@ namespace xla::cpu { class ReduceScatterThunk final : public CollectiveThunk { public: - using CollectiveThunk::OpParams; - static absl::StatusOr> Create( Info info, ReductionKind reduction_kind, OpParams op_params, OpBuffers op_buffers); diff --git 
a/third_party/xla/xla/service/cpu/runtime/thunk.cc b/third_party/xla/xla/service/cpu/runtime/thunk.cc index 0dadb5c4fb9a1e..97f6c15dadf85b 100644 --- a/third_party/xla/xla/service/cpu/runtime/thunk.cc +++ b/third_party/xla/xla/service/cpu/runtime/thunk.cc @@ -38,6 +38,8 @@ namespace xla::cpu { std::string_view Thunk::KindToString(Kind kind) { switch (kind) { + case Kind::kAllGather: + return "all-gather"; case Kind::kAllReduce: return "all-reduce"; case Kind::kCall: diff --git a/third_party/xla/xla/service/cpu/runtime/thunk.h b/third_party/xla/xla/service/cpu/runtime/thunk.h index 5dd9d044c14b78..0750267c4f850b 100644 --- a/third_party/xla/xla/service/cpu/runtime/thunk.h +++ b/third_party/xla/xla/service/cpu/runtime/thunk.h @@ -61,6 +61,7 @@ namespace xla::cpu { class Thunk { public: enum class Kind { + kAllGather, kAllReduce, kCall, kCopy, diff --git a/third_party/xla/xla/service/cpu/thunk_emitter.cc b/third_party/xla/xla/service/cpu/thunk_emitter.cc index 0cd2277bf95df0..5ad742c8b4cafb 100644 --- a/third_party/xla/xla/service/cpu/thunk_emitter.cc +++ b/third_party/xla/xla/service/cpu/thunk_emitter.cc @@ -36,6 +36,7 @@ limitations under the License. 
#include "xla/service/collective_ops_utils.h" #include "xla/service/cpu/dot_op_emitter.h" #include "xla/service/cpu/ir_emitter2.h" +#include "xla/service/cpu/runtime/all_gather_thunk.h" #include "xla/service/cpu/runtime/all_reduce_thunk.h" #include "xla/service/cpu/runtime/call_thunk.h" #include "xla/service/cpu/runtime/collective_thunk.h" @@ -207,6 +208,8 @@ absl::StatusOr ThunkEmitter::EmitHloInstruction( case HloOpcode::kReplicaId: return EmitReplicaIdThunk(instruction); + case HloOpcode::kAllGather: + return EmitAllGatherThunk(instruction); case HloOpcode::kAllReduce: return EmitAllReduceThunk(instruction); case HloOpcode::kReduceScatter: @@ -313,6 +316,19 @@ static absl::StatusOr GetCollectiveOpBuffers( }; } +absl::StatusOr ThunkEmitter::EmitAllGatherThunk( + const HloInstruction* instruction) { + auto* all_gather = Cast(instruction); + + TF_ASSIGN_OR_RETURN(AllGatherThunk::OpParams op_params, + GetCollectiveOpParams(all_gather)); + TF_ASSIGN_OR_RETURN(AllGatherThunk::OpBuffers op_buffers, + GetCollectiveOpBuffers(all_gather, buffer_assignment_)); + + return ThunkSequence::Of( + ThunkInfo(all_gather), std::move(op_params), std::move(op_buffers)); +} + absl::StatusOr ThunkEmitter::EmitAllReduceThunk( const HloInstruction* instruction) { auto* all_reduce = Cast(instruction); diff --git a/third_party/xla/xla/service/cpu/thunk_emitter.h b/third_party/xla/xla/service/cpu/thunk_emitter.h index c676a61d24acb7..d4a33366ef931e 100644 --- a/third_party/xla/xla/service/cpu/thunk_emitter.h +++ b/third_party/xla/xla/service/cpu/thunk_emitter.h @@ -107,6 +107,9 @@ class ThunkEmitter { absl::StatusOr EmitReplicaIdThunk( const HloInstruction* instruction); + absl::StatusOr EmitAllGatherThunk( + const HloInstruction* instruction); + absl::StatusOr EmitAllReduceThunk( const HloInstruction* instruction); diff --git a/third_party/xla/xla/tests/collective_ops_test.cc b/third_party/xla/xla/tests/collective_ops_test.cc index c2ea50132e8df9..023df482809936 100644 --- 
a/third_party/xla/xla/tests/collective_ops_test.cc +++ b/third_party/xla/xla/tests/collective_ops_test.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include #include #include From 59d26954d3f13799c76455f0946bf5c799bdb908 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 19 Jun 2024 15:11:34 -0700 Subject: [PATCH 033/256] Copy definition of tflite::ControlEdges to flatbuffer_export.cc. PiperOrigin-RevId: 644825135 --- tensorflow/compiler/mlir/lite/BUILD | 1 - tensorflow/compiler/mlir/lite/flatbuffer_export.cc | 7 ++++++- tensorflow/lite/graph_info.h | 2 ++ 3 files changed, 8 insertions(+), 2 deletions(-) diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD index 7e4afbbe358849..bf23c167467f8c 100644 --- a/tensorflow/compiler/mlir/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/BUILD @@ -1100,7 +1100,6 @@ cc_library( "//tensorflow/core:framework", "//tensorflow/core:portable_gif_internal", "//tensorflow/core:protos_all_cc", - "//tensorflow/lite:graph_info", "//tensorflow/lite:schema_fbs_version", "//tensorflow/lite/core:framework", "//tensorflow/lite/core/c:private_common", diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc index b73a9902377af9..7449a037516a6c 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc @@ -110,7 +110,6 @@ limitations under the License. 
#include "tensorflow/lite/core/macros.h" #include "tensorflow/lite/delegates/flex/allowlisted_flex_ops.h" #include "tensorflow/lite/experimental/remat/metadata_util.h" -#include "tensorflow/lite/graph_info.h" #include "tensorflow/lite/python/metrics/converter_error_data.pb.h" #include "tensorflow/lite/schema/schema_conversion_utils.h" #include "tensorflow/lite/toco/toco_flags.pb.h" @@ -155,6 +154,12 @@ using VectorBufferOffset = flatbuffers::Offset>; using CustomOptionsOffset = VectorBufferOffset; +// LINT.IfChange +// Node edge.second depends on node edge.first. +using ControlEdge = std::pair; +using ControlEdges = std::vector; +// LINT.ThenChange(//tensorflow/lite/graph_info.h) + namespace tfl = mlir::TFL; ABSL_CONST_INIT const absl::string_view kFlexOpNamePrefix = "Flex"; diff --git a/tensorflow/lite/graph_info.h b/tensorflow/lite/graph_info.h index 1093aa8bb7203f..0370a12c5c17e2 100644 --- a/tensorflow/lite/graph_info.h +++ b/tensorflow/lite/graph_info.h @@ -92,9 +92,11 @@ struct NodeSubset { std::vector output_tensors; }; +// LINT.IfChange // Node edge.second depends on node edge.first. using ControlEdge = std::pair; using ControlEdges = std::vector; +// LINT.ThenChange(//tensorflow/compiler/mlir/lite/flatbuffer_export.cc) // Partitions a list of node indices `nodes_to_partition` into node subsets. // Each node subset is in dependency order internally (i.e. all members of the From 3ee9d16fb7f5533bb9c49b0b79964792f9792821 Mon Sep 17 00:00:00 2001 From: Kyle Lucke Date: Wed, 19 Jun 2024 21:34:11 -0700 Subject: [PATCH 034/256] Stop using xla/statusor.h now that it just contains an alias for absl::Status. In some situations, this meant also changing unrelated files to directly include tsl/platform/statusor.h to get the definitions for TF_ASSIGN_OR_RETURN, etc., where they were getting transitively included for free. 
PiperOrigin-RevId: 644887864 --- tensorflow/compiler/mlir/lite/BUILD | 8 ++++---- tensorflow/compiler/mlir/lite/flatbuffer_operator.cc | 2 +- tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc | 2 +- tensorflow/compiler/mlir/lite/utils/convert_type.cc | 2 +- tensorflow/compiler/mlir/lite/utils/convert_type.h | 2 +- tensorflow/core/tpu/kernels/BUILD | 9 ++++----- tensorflow/core/tpu/kernels/tpu_execute_op.cc | 2 +- tensorflow/core/tpu/kernels/tpu_util.h | 2 +- third_party/xla/xla/backends/interpreter/BUILD | 6 +++--- third_party/xla/xla/backends/interpreter/compiler.cc | 2 +- third_party/xla/xla/backends/interpreter/compiler.h | 2 +- third_party/xla/xla/backends/interpreter/executable.h | 2 +- .../xla/xla/backends/interpreter/executable_base.h | 2 +- third_party/xla/xla/hlo/evaluator/BUILD | 4 ++-- third_party/xla/xla/hlo/evaluator/hlo_evaluator.cc | 2 +- third_party/xla/xla/hlo/evaluator/hlo_evaluator.h | 2 +- third_party/xla/xla/hlo/evaluator/hlo_evaluator_test.cc | 2 +- third_party/xla/xla/service/gpu/fusions/BUILD | 5 ++--- third_party/xla/xla/service/gpu/fusions/custom.h | 1 - third_party/xla/xla/service/gpu/fusions/fusion_emitter.h | 2 +- .../service/gpu/fusions/in_place_dynamic_update_slice.h | 2 +- third_party/xla/xla/stream_executor/tpu/BUILD | 5 ++--- .../xla/xla/stream_executor/tpu/c_api_conversions.h | 1 - third_party/xla/xla/stream_executor/tpu/tpu_executable.h | 2 +- .../xla/stream_executor/tpu/tpu_executable_interface.h | 2 +- 25 files changed, 34 insertions(+), 39 deletions(-) diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD index bf23c167467f8c..46ca0c761e2804 100644 --- a/tensorflow/compiler/mlir/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/BUILD @@ -736,6 +736,7 @@ cc_library( "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/memory", "@com_google_absl//absl/numeric:bits", + "@com_google_absl//absl/status:statusor", "@llvm-project//llvm:Support", 
"@llvm-project//mlir:AffineAnalysis", "@llvm-project//mlir:Analysis", @@ -750,7 +751,6 @@ cc_library( "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", "@local_xla//xla:status", - "@local_xla//xla:statusor", "@local_xla//xla/mlir_hlo", "@stablehlo//:stablehlo_ops", ], @@ -1038,6 +1038,7 @@ cc_library( "//tensorflow/lite/core/c:private_common", "//tensorflow/lite/kernels/internal:kernel_utils", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@flatbuffers", "@llvm-project//llvm:Analysis", @@ -1046,7 +1047,6 @@ cc_library( "@llvm-project//mlir:Support", "@llvm-project//mlir:TransformUtils", "@local_tsl//tsl/platform:status", - "@local_xla//xla:statusor", "@stablehlo//:stablehlo_ops", "@stablehlo//:vhlo_ops", "@stablehlo//:vhlo_types", @@ -1200,9 +1200,9 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", "//tensorflow/core:protos_all_cc", "//tensorflow/core/platform:errors", + "@com_google_absl//absl/status:statusor", "@llvm-project//mlir:IR", "@llvm-project//mlir:Support", - "@local_xla//xla:statusor", ], ) @@ -1221,10 +1221,10 @@ cc_library( "//tensorflow/compiler/mlir/lite/schema:schema_fbs", "//tensorflow/core:protos_all_cc", "//tensorflow/lite/toco:toco_flags_proto_cc", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", - "@local_xla//xla:statusor", ], ) diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_operator.cc b/tensorflow/compiler/mlir/lite/flatbuffer_operator.cc index 7b93fc33f06ee5..25bf15f6be61e7 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_operator.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_operator.cc @@ -22,6 +22,7 @@ limitations under the License. 
#include #include +#include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "flatbuffers/buffer.h" // from @flatbuffers #include "flatbuffers/flatbuffer_builder.h" // from @flatbuffers @@ -51,7 +52,6 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/utils/convert_type.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" #include "tensorflow/compiler/mlir/tensorflow/utils/dynamic_shape_utils.h" -#include "xla/statusor.h" #include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/lite/core/c/builtin_op_data.h" diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc index 2011b6d33ccd45..a2ec3624064dd4 100644 --- a/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc @@ -27,6 +27,7 @@ limitations under the License. #include #include +#include "absl/status/statusor.h" #include "llvm/ADT/APInt.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/Hashing.h" @@ -58,7 +59,6 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.h" #include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h" -#include "xla/statusor.h" #include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/types.pb.h" diff --git a/tensorflow/compiler/mlir/lite/utils/convert_type.cc b/tensorflow/compiler/mlir/lite/utils/convert_type.cc index c7f922de39ad81..1edb08cff57423 100644 --- a/tensorflow/compiler/mlir/lite/utils/convert_type.cc +++ b/tensorflow/compiler/mlir/lite/utils/convert_type.cc @@ -15,6 +15,7 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/lite/utils/convert_type.h" +#include "absl/status/statusor.h" #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project @@ -22,7 +23,6 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/schema/schema_generated.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" -#include "xla/statusor.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/platform/errors.h" diff --git a/tensorflow/compiler/mlir/lite/utils/convert_type.h b/tensorflow/compiler/mlir/lite/utils/convert_type.h index 85631dbe258f8e..118f9cd4979058 100644 --- a/tensorflow/compiler/mlir/lite/utils/convert_type.h +++ b/tensorflow/compiler/mlir/lite/utils/convert_type.h @@ -16,11 +16,11 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_LITE_UTILS_CONVERT_TYPE_H_ #define TENSORFLOW_COMPILER_MLIR_LITE_UTILS_CONVERT_TYPE_H_ +#include "absl/status/statusor.h" #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/schema/schema_generated.h" -#include "xla/statusor.h" #include "tensorflow/core/framework/types.pb.h" namespace mlir { diff --git a/tensorflow/core/tpu/kernels/BUILD b/tensorflow/core/tpu/kernels/BUILD index 58b6a891a072b9..482160ad4b0161 100644 --- a/tensorflow/core/tpu/kernels/BUILD +++ b/tensorflow/core/tpu/kernels/BUILD @@ -133,7 +133,6 @@ cc_library( "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:logging", "@local_xla//xla:status_macros", - "@local_xla//xla:statusor", "@local_xla//xla:xla_data_proto_cc", "@local_xla//xla/client:client_library", "@local_xla//xla/client:compile_only_client", @@ -619,9 +618,9 @@ cc_library( "//tensorflow/cc:ops", 
"//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/time", - "@local_xla//xla:statusor", "@local_xla//xla/client:compile_only_client", ] + tf_grpc_cc_dependencies(), ) @@ -653,9 +652,9 @@ cc_library( "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", - "@local_xla//xla:statusor", "@local_xla//xla/client:compile_only_client", "@local_xla//xla/stream_executor/tpu:tpu_api", "@local_xla//xla/stream_executor/tpu:tpu_ops_c_api_hdrs", @@ -878,7 +877,7 @@ cc_library( ":tpu_compile_op_lib", ":tpu_compile_op_options", "//tensorflow/core/protobuf/tpu:compilation_result_proto_cc", - "@local_xla//xla:statusor", + "@com_google_absl//absl/status:statusor", "@local_xla//xla/stream_executor/tpu:tpu_node_context", ], alwayslink = True, @@ -912,6 +911,7 @@ cc_library( "//tensorflow/core/tpu:tpu_execute", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:logging", @@ -921,7 +921,6 @@ cc_library( "@local_xla//xla:shape_tree", "@local_xla//xla:shape_util", "@local_xla//xla:status_macros", - "@local_xla//xla:statusor", "@local_xla//xla:xla_data_proto_cc", "@local_xla//xla/service:backend", "@local_xla//xla/service:computation_placer_hdr", diff --git a/tensorflow/core/tpu/kernels/tpu_execute_op.cc b/tensorflow/core/tpu/kernels/tpu_execute_op.cc index 7a9812be9ae43b..100daa04adb321 100644 --- a/tensorflow/core/tpu/kernels/tpu_execute_op.cc +++ b/tensorflow/core/tpu/kernels/tpu_execute_op.cc @@ -22,6 +22,7 @@ limitations under the License. 
#include "absl/container/flat_hash_map.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/types/span.h" #include "tensorflow/compiler/jit/variable_info.h" @@ -43,7 +44,6 @@ limitations under the License. #include "xla/shape_tree.h" #include "xla/shape_util.h" #include "xla/status_macros.h" -#include "xla/statusor.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/device_memory_allocator.h" #include "xla/stream_executor/event.h" diff --git a/tensorflow/core/tpu/kernels/tpu_util.h b/tensorflow/core/tpu/kernels/tpu_util.h index a44175fbb2c2cd..eaf822df1ae939 100644 --- a/tensorflow/core/tpu/kernels/tpu_util.h +++ b/tensorflow/core/tpu/kernels/tpu_util.h @@ -20,11 +20,11 @@ limitations under the License. #include #include "grpcpp/server_builder.h" +#include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "tensorflow/cc/framework/ops.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "xla/client/compile_only_client.h" -#include "xla/statusor.h" #include "tensorflow/core/protobuf/config.pb.h" #include "tensorflow/core/tpu/kernels/tpu_compilation_cache_key.h" diff --git a/third_party/xla/xla/backends/interpreter/BUILD b/third_party/xla/xla/backends/interpreter/BUILD index a3958434a11d8d..8f5e83d870bd57 100644 --- a/third_party/xla/xla/backends/interpreter/BUILD +++ b/third_party/xla/xla/backends/interpreter/BUILD @@ -31,7 +31,6 @@ cc_library( ":platform_id", "//xla:literal", "//xla:status_macros", - "//xla:statusor", "//xla:util", "//xla/hlo/evaluator:hlo_evaluator", "//xla/hlo/ir:hlo", @@ -56,6 +55,7 @@ cc_library( "//xla/stream_executor", "@com_google_absl//absl/log", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:status", @@ -82,7 +82,6 @@ cc_library( "//xla:literal", "//xla:shape_tree", "//xla:shape_util", - 
"//xla:statusor", "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", "//xla/service:dynamic_dimension_inference", @@ -92,6 +91,7 @@ cc_library( "//xla/service:shaped_buffer", "//xla/service:transfer_manager", "//xla/stream_executor", + "@com_google_absl//absl/status:statusor", "@local_tsl//tsl/platform:statusor", ], ) @@ -106,7 +106,6 @@ cc_library( "//xla:literal", "//xla:shape_util", "//xla:status_macros", - "//xla:statusor", "//xla:types", "//xla:xla_data_proto_cc", "//xla/hlo/evaluator:hlo_evaluator", @@ -119,6 +118,7 @@ cc_library( "//xla/service:shaped_buffer", "//xla/service:transfer_manager", "//xla/stream_executor", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:errors", diff --git a/third_party/xla/xla/backends/interpreter/compiler.cc b/third_party/xla/xla/backends/interpreter/compiler.cc index ae3841a1d4448c..ec0befd3054f70 100644 --- a/third_party/xla/xla/backends/interpreter/compiler.cc +++ b/third_party/xla/xla/backends/interpreter/compiler.cc @@ -22,6 +22,7 @@ limitations under the License. #include "absl/log/log.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/types/span.h" #include "xla/backends/interpreter/executable.h" #include "xla/backends/interpreter/platform_id.h" @@ -47,7 +48,6 @@ limitations under the License. #include "xla/service/topk_rewriter.h" #include "xla/service/triangular_solve_expander.h" #include "xla/status_macros.h" -#include "xla/statusor.h" #include "xla/stream_executor/platform.h" #include "xla/stream_executor/stream_executor.h" #include "xla/util.h" diff --git a/third_party/xla/xla/backends/interpreter/compiler.h b/third_party/xla/xla/backends/interpreter/compiler.h index 51ce9921002776..4c8d5f0ec4c78f 100644 --- a/third_party/xla/xla/backends/interpreter/compiler.h +++ b/third_party/xla/xla/backends/interpreter/compiler.h @@ -20,13 +20,13 @@ limitations under the License. 
#include #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "xla/backends/interpreter/platform_id.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/service/compiler.h" #include "xla/service/executable.h" #include "xla/service/hlo_cost_analysis.h" #include "xla/service/hlo_module_config.h" -#include "xla/statusor.h" #include "xla/stream_executor/stream_executor.h" #include "tsl/platform/status.h" diff --git a/third_party/xla/xla/backends/interpreter/executable.h b/third_party/xla/xla/backends/interpreter/executable.h index 4a66f3bb375e29..72a078b0b0c8e9 100644 --- a/third_party/xla/xla/backends/interpreter/executable.h +++ b/third_party/xla/xla/backends/interpreter/executable.h @@ -18,6 +18,7 @@ limitations under the License. #include +#include "absl/status/statusor.h" #include "absl/types/span.h" #include "xla/backends/interpreter/executable_base.h" #include "xla/hlo/evaluator/hlo_evaluator.h" @@ -28,7 +29,6 @@ limitations under the License. #include "xla/service/hlo_module_config.h" #include "xla/service/service_executable_run_options.h" #include "xla/service/shaped_buffer.h" -#include "xla/statusor.h" #include "xla/stream_executor/stream_executor.h" #include "xla/types.h" #include "xla/xla_data.pb.h" diff --git a/third_party/xla/xla/backends/interpreter/executable_base.h b/third_party/xla/xla/backends/interpreter/executable_base.h index fa55e567464435..21a675d5f7ac0f 100644 --- a/third_party/xla/xla/backends/interpreter/executable_base.h +++ b/third_party/xla/xla/backends/interpreter/executable_base.h @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "absl/status/statusor.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/literal.h" #include "xla/service/dynamic_dimension_inference.h" @@ -26,7 +27,6 @@ limitations under the License. 
#include "xla/service/hlo_execution_profile.h" #include "xla/service/service_executable_run_options.h" #include "xla/shape.h" -#include "xla/statusor.h" #include "xla/xla.pb.h" namespace xla { namespace interpreter { diff --git a/third_party/xla/xla/hlo/evaluator/BUILD b/third_party/xla/xla/hlo/evaluator/BUILD index 222257e8f9b9bb..c6ecac78ce8160 100644 --- a/third_party/xla/xla/hlo/evaluator/BUILD +++ b/third_party/xla/xla/hlo/evaluator/BUILD @@ -48,7 +48,6 @@ cc_library( "//xla:literal_util", "//xla:shape_util", "//xla:status_macros", - "//xla:statusor", "//xla:types", "//xla:util", "//xla:xla_data_proto_cc", @@ -74,6 +73,7 @@ cc_library( "@com_google_absl//absl/memory", "@com_google_absl//absl/numeric:bits", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:cord", "@com_google_absl//absl/strings:str_format", @@ -128,7 +128,6 @@ xla_cc_test( "//xla:literal_util", "//xla:permutation_util", "//xla:shape_util", - "//xla:statusor", "//xla:test", "//xla:types", "//xla:util", @@ -146,6 +145,7 @@ xla_cc_test( "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", diff --git a/third_party/xla/xla/hlo/evaluator/hlo_evaluator.cc b/third_party/xla/xla/hlo/evaluator/hlo_evaluator.cc index 671a0850e682e9..85fe5b90ab897d 100644 --- a/third_party/xla/xla/hlo/evaluator/hlo_evaluator.cc +++ b/third_party/xla/xla/hlo/evaluator/hlo_evaluator.cc @@ -43,6 +43,7 @@ limitations under the License. 
#include "absl/memory/memory.h" #include "absl/numeric/bits.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/cord.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" @@ -76,7 +77,6 @@ limitations under the License. #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/status_macros.h" -#include "xla/statusor.h" #include "xla/types.h" #include "xla/util.h" #include "xla/xla_data.pb.h" diff --git a/third_party/xla/xla/hlo/evaluator/hlo_evaluator.h b/third_party/xla/xla/hlo/evaluator/hlo_evaluator.h index 3aa783cec9bb19..e0871ff82b5d51 100644 --- a/third_party/xla/xla/hlo/evaluator/hlo_evaluator.h +++ b/third_party/xla/xla/hlo/evaluator/hlo_evaluator.h @@ -24,6 +24,7 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/container/node_hash_map.h" +#include "absl/status/statusor.h" #include "absl/types/span.h" #include "xla/array2d.h" #include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" @@ -37,7 +38,6 @@ limitations under the License. #include "xla/service/shape_inference.h" #include "xla/service/tuple_points_to_analysis.h" #include "xla/shape_util.h" -#include "xla/statusor.h" #include "xla/util.h" #include "xla/xla_data.pb.h" diff --git a/third_party/xla/xla/hlo/evaluator/hlo_evaluator_test.cc b/third_party/xla/xla/hlo/evaluator/hlo_evaluator_test.cc index f79183fd33fcf9..6be1d4c72f38ef 100644 --- a/third_party/xla/xla/hlo/evaluator/hlo_evaluator_test.cc +++ b/third_party/xla/xla/hlo/evaluator/hlo_evaluator_test.cc @@ -29,6 +29,7 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/container/flat_hash_set.h" #include "absl/log/check.h" +#include "absl/status/statusor.h" #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" @@ -53,7 +54,6 @@ limitations under the License. 
#include "xla/service/shape_inference.h" #include "xla/shape.h" #include "xla/shape_util.h" -#include "xla/statusor.h" #include "xla/test.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/literal_test_util.h" diff --git a/third_party/xla/xla/service/gpu/fusions/BUILD b/third_party/xla/xla/service/gpu/fusions/BUILD index dde332d86e9042..8af2b159c2967d 100644 --- a/third_party/xla/xla/service/gpu/fusions/BUILD +++ b/third_party/xla/xla/service/gpu/fusions/BUILD @@ -14,7 +14,6 @@ cc_library( hdrs = ["in_place_dynamic_update_slice.h"], deps = [ ":fusion_emitter", - "//xla:statusor", "//xla/hlo/ir:hlo", "//xla/service/gpu:hlo_fusion_analysis", "//xla/service/gpu:ir_emission_utils", @@ -26,6 +25,7 @@ cc_library( "//xla/service/llvm_ir:fused_ir_emitter", "//xla/service/llvm_ir:ir_array", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@llvm-project//llvm:Support", "@llvm-project//llvm:ir_headers", "@llvm-project//mlir:IR", @@ -122,7 +122,6 @@ cc_library( deps = [ ":fusion_emitter", "//xla:shape_util", - "//xla:statusor", "//xla:util", "//xla:xla_data_proto_cc", "//xla/ffi:attribute_map", @@ -209,7 +208,6 @@ cc_library( deps = [ "//xla:shape_util", "//xla:status_macros", - "//xla:statusor", "//xla:util", "//xla/hlo/ir:hlo", "//xla/service/gpu:ir_emitter_context", @@ -226,6 +224,7 @@ cc_library( "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@llvm-project//llvm:Support", diff --git a/third_party/xla/xla/service/gpu/fusions/custom.h b/third_party/xla/xla/service/gpu/fusions/custom.h index 57eb1ab63bd349..c5e758da715a8c 100644 --- a/third_party/xla/xla/service/gpu/fusions/custom.h +++ b/third_party/xla/xla/service/gpu/fusions/custom.h @@ -20,7 +20,6 @@ limitations under the License. 
#include "xla/service/gpu/fusions/fusion_emitter.h" #include "xla/service/gpu/hlo_fusion_analysis.h" #include "xla/service/gpu/ir_emitter_context.h" -#include "xla/statusor.h" namespace xla { namespace gpu { diff --git a/third_party/xla/xla/service/gpu/fusions/fusion_emitter.h b/third_party/xla/xla/service/gpu/fusions/fusion_emitter.h index aec8ff8b4f3814..bcb204bbe34002 100644 --- a/third_party/xla/xla/service/gpu/fusions/fusion_emitter.h +++ b/third_party/xla/xla/service/gpu/fusions/fusion_emitter.h @@ -25,6 +25,7 @@ limitations under the License. #include #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/types/span.h" #include "llvm/IR/Function.h" #include "llvm/IR/IRBuilder.h" @@ -39,7 +40,6 @@ limitations under the License. #include "xla/service/gpu/runtime/thunk.h" #include "xla/service/llvm_ir/ir_array.h" #include "xla/shape.h" -#include "xla/statusor.h" #include "xla/stream_executor/device_description.h" namespace xla { diff --git a/third_party/xla/xla/service/gpu/fusions/in_place_dynamic_update_slice.h b/third_party/xla/xla/service/gpu/fusions/in_place_dynamic_update_slice.h index 010ccde6ccc512..fe89629892e2fe 100644 --- a/third_party/xla/xla/service/gpu/fusions/in_place_dynamic_update_slice.h +++ b/third_party/xla/xla/service/gpu/fusions/in_place_dynamic_update_slice.h @@ -20,6 +20,7 @@ limitations under the License. #include #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "llvm/IR/IRBuilder.h" #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "xla/hlo/ir/hlo_instruction.h" @@ -31,7 +32,6 @@ limitations under the License. 
#include "xla/service/gpu/launch_dimensions.h" #include "xla/service/gpu/model/indexing_map.h" #include "xla/service/llvm_ir/ir_array.h" -#include "xla/statusor.h" namespace xla { namespace gpu { diff --git a/third_party/xla/xla/stream_executor/tpu/BUILD b/third_party/xla/xla/stream_executor/tpu/BUILD index 7ef28f3cc6466b..0b2e21933b9b10 100644 --- a/third_party/xla/xla/stream_executor/tpu/BUILD +++ b/third_party/xla/xla/stream_executor/tpu/BUILD @@ -63,7 +63,6 @@ cc_library( "//xla:shape_layout", "//xla:shape_tree", "//xla:shape_util", - "//xla:statusor", "//xla:util", "//xla:xla_data_proto_cc", "//xla:xla_proto_cc", @@ -584,7 +583,6 @@ cc_library( deps = [ "//xla:shape_util", "//xla:status_macros", - "//xla:statusor", "//xla:util", "//xla/hlo/ir:hlo", "//xla/service:compiler", @@ -597,6 +595,7 @@ cc_library( "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/log", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:errors", @@ -619,7 +618,6 @@ cc_library( ":tpu_executor", ":tpu_executor_api", ":tpu_executor_c_api_hdrs", - "//xla:statusor", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/service:executable", @@ -631,6 +629,7 @@ cc_library( "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/log", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:logging", diff --git a/third_party/xla/xla/stream_executor/tpu/c_api_conversions.h b/third_party/xla/xla/stream_executor/tpu/c_api_conversions.h index 043d860f665c65..2e473e411f1085 100644 --- a/third_party/xla/xla/stream_executor/tpu/c_api_conversions.h +++ b/third_party/xla/xla/stream_executor/tpu/c_api_conversions.h @@ -31,7 +31,6 @@ limitations under the License. 
#include "xla/service/shaped_buffer.h" #include "xla/shape.h" #include "xla/shape_util.h" -#include "xla/statusor.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/device_memory_allocator.h" #include "xla/stream_executor/tpu/c_api_decl.h" diff --git a/third_party/xla/xla/stream_executor/tpu/tpu_executable.h b/third_party/xla/xla/stream_executor/tpu/tpu_executable.h index 34ce516b2fc1b9..c5b639e9bf7c27 100644 --- a/third_party/xla/xla/stream_executor/tpu/tpu_executable.h +++ b/third_party/xla/xla/stream_executor/tpu/tpu_executable.h @@ -24,13 +24,13 @@ limitations under the License. #include "absl/log/log.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/service/executable.h" #include "xla/service/hlo_execution_profile.h" #include "xla/service/service_executable_run_options.h" -#include "xla/statusor.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/tpu/c_api_decl.h" #include "xla/stream_executor/tpu/tpu_executable_interface.h" diff --git a/third_party/xla/xla/stream_executor/tpu/tpu_executable_interface.h b/third_party/xla/xla/stream_executor/tpu/tpu_executable_interface.h index 06082387f4258a..ce9555a1fedf3d 100644 --- a/third_party/xla/xla/stream_executor/tpu/tpu_executable_interface.h +++ b/third_party/xla/xla/stream_executor/tpu/tpu_executable_interface.h @@ -22,6 +22,7 @@ limitations under the License. #include #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/hlo/ir/hlo_input_output_alias_config.h" @@ -30,7 +31,6 @@ limitations under the License. 
#include "xla/service/hlo_execution_profile.h" #include "xla/service/service_executable_run_options.h" #include "xla/shape.h" -#include "xla/statusor.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/device_memory_allocator.h" #include "xla/stream_executor/stream_executor.h" From aaebeb863e75ea007093aaf904ed5b307f3096c8 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 19 Jun 2024 22:23:16 -0700 Subject: [PATCH 035/256] Automated Code Change PiperOrigin-RevId: 644895910 --- tensorflow/core/kernels/gather_nd_op_test.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/kernels/gather_nd_op_test.cc b/tensorflow/core/kernels/gather_nd_op_test.cc index cfbc46a8ad59cb..3212068f389a95 100644 --- a/tensorflow/core/kernels/gather_nd_op_test.cc +++ b/tensorflow/core/kernels/gather_nd_op_test.cc @@ -157,7 +157,7 @@ TEST_F(GatherNdOpConstructionTest, Error_BadIndicesPolicyInvalid) { .Input(FakeInput(DT_INT32)) .Attr("bad_indices_policy", "AN_UNRECOGNIZED_POLICY") .Finalize(node_def())); - EXPECT_NE(InitOp(), OkStatus()); + EXPECT_NE(InitOp(), absl::OkStatus()); } constexpr int kLookups = 2000; From 9b12cd1d4053f8b93128eec529956ea7521fe63d Mon Sep 17 00:00:00 2001 From: Olli Lupton Date: Wed, 19 Jun 2024 22:54:37 -0700 Subject: [PATCH 036/256] PR #13603: NVTX: name threads, CUDA devices and CUDA streams Imported from GitHub PR https://github.com/openxla/xla/pull/13603 This aims to improve the profiling experience. These names are shown in the Nsight Systems UI. 
Device names: ![Screenshot 2024-06-10 at 14 52 37](https://github.com/openxla/xla/assets/6459623/d889d37e-ca2e-4f5e-b5bd-240bbb625b4c) Stream names: ![Screenshot 2024-06-10 at 14 53 25](https://github.com/openxla/xla/assets/6459623/4bfc4ffa-8fdf-4b93-b23e-95bf056799f3) Thread names: ![Screenshot 2024-06-10 at 14 54 04](https://github.com/openxla/xla/assets/6459623/8852ca9e-f2f4-4a45-8334-a18f8ab5ce18) This also provides a missing link between replica IDs in the HLO and the physical devices in the profile. Copybara import of the project: -- 5b3121c58db8aa1b6529f0aeb8573be8bf2cde80 by Olli Lupton : NVTX: name threads, CUDA devices and CUDA streams -- d973674de6218fcee88473d85bb43ba345652fdf by Olli Lupton : Address review comments -- 918cf3e7b87150e9d666b218bbd9aca0cae606a4 by Olli Lupton : Alternative for @jbaiocchi -- 1d1978437e64c0dac97e97ea4320a6dcb3945296 by Olli Lupton : Address more review comments Merging this change closes #13603 PiperOrigin-RevId: 644901234 --- .../tsl/tsl/profiler/lib/nvtx_utils.cc | 35 +++++++++++++++++++ .../tsl/tsl/profiler/lib/nvtx_utils.h | 13 +++++++ .../tsl/tsl/profiler/lib/nvtx_utils_stub.cc | 5 +++ third_party/xla/xla/pjrt/BUILD | 1 + third_party/xla/xla/pjrt/gpu/BUILD | 1 + .../xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc | 18 ++++++++++ .../xla/xla/pjrt/local_device_state.cc | 31 +++++++++++----- .../xla/xla/service/gpu/infeed_manager.cc | 4 ++- third_party/xla/xla/service/stream_pool.cc | 2 ++ third_party/xla/xla/stream_executor/BUILD | 1 + third_party/xla/xla/stream_executor/gpu/BUILD | 3 ++ .../xla/xla/stream_executor/gpu/gpu_stream.cc | 7 ++++ .../xla/xla/stream_executor/gpu/gpu_stream.h | 3 ++ third_party/xla/xla/stream_executor/stream.h | 4 +++ .../xla/xla/stream_executor/stream_common.cc | 2 ++ .../xla/xla/stream_executor/stream_common.h | 6 ++++ .../stream_executor_memory_allocator.cc | 1 + .../trace_command_buffer_factory.cc | 1 + 18 files changed, 128 insertions(+), 10 deletions(-) diff --git 
a/third_party/xla/third_party/tsl/tsl/profiler/lib/nvtx_utils.cc b/third_party/xla/third_party/tsl/tsl/profiler/lib/nvtx_utils.cc index b122c6e12dfc19..703a4b0b3934ca 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/lib/nvtx_utils.cc +++ b/third_party/xla/third_party/tsl/tsl/profiler/lib/nvtx_utils.cc @@ -14,16 +14,37 @@ limitations under the License. ==============================================================================*/ #include "tsl/profiler/lib/nvtx_utils.h" +#include + #include #include +#include #include #include #include #include "nvtx3/nvToolsExt.h" +#include "nvtx3/nvToolsExtCuda.h" +#include "nvtx3/nvToolsExtCudaRt.h" #include "nvtx3/nvToolsExtPayload.h" +#include "third_party/gpus/cuda/include/cuda.h" + +namespace { +// Get the ID of the current thread following the convention for +// nvtxNameOsThreadA: +// https://nvidia.github.io/NVTX/doxygen/group___r_e_s_o_u_r_c_e___n_a_m_i_n_g.html +// This convention may not match the one in tsl::Env::GetCurrentThreadId(). 
+std::optional GetCurrentThreadId() { +#ifdef __linux__ + return syscall(SYS_gettid); +#else + return std::nullopt; +#endif +} +} // namespace namespace tsl::profiler { +static_assert(std::is_pointer_v); static_assert(std::is_pointer_v); static_assert(std::is_pointer_v); @@ -33,6 +54,20 @@ ProfilerDomainHandle DefaultProfilerDomain() { return domain; } +void NameCurrentThread(const std::string& thread_name) { + if (std::optional tid = GetCurrentThreadId(); tid.has_value()) { + nvtxNameOsThreadA(*tid, thread_name.c_str()); + } +} + +void NameDevice(int device_id, const std::string& device_name) { + nvtxNameCudaDeviceA(device_id, device_name.c_str()); +} + +void NameStream(StreamHandle stream, const std::string& stream_name) { + nvtxNameCuStreamA(reinterpret_cast(stream), stream_name.c_str()); +} + void RangePop(ProfilerDomainHandle domain) { nvtxDomainRangePop(reinterpret_cast(domain)); } diff --git a/third_party/xla/third_party/tsl/tsl/profiler/lib/nvtx_utils.h b/third_party/xla/third_party/tsl/tsl/profiler/lib/nvtx_utils.h index 0727072d06390d..4d65c39e4bbd0e 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/lib/nvtx_utils.h +++ b/third_party/xla/third_party/tsl/tsl/profiler/lib/nvtx_utils.h @@ -34,6 +34,19 @@ using ProfilerDomainHandle = ProfilerDomain*; // Get the "TSL" domain if NVTX profiling is enabled, otherwise null ProfilerDomainHandle DefaultProfilerDomain(); +// Assign a human-readable name to the current thread +void NameCurrentThread(const std::string&); + +// Assign a human-readable name to the given local device +void NameDevice(int device_id, const std::string& device_name); + +struct Stream; +// Opaque handle to an execution stream +using StreamHandle = Stream*; + +// Assign a human-readable name to the given execution stream +void NameStream(StreamHandle stream, const std::string& stream_name); + // Register a string with the profiler/NVTX implementation for faster use StringHandle RegisterString(ProfilerDomainHandle, const std::string&); 
diff --git a/third_party/xla/third_party/tsl/tsl/profiler/lib/nvtx_utils_stub.cc b/third_party/xla/third_party/tsl/tsl/profiler/lib/nvtx_utils_stub.cc index c887af77ec8b11..ad2e7d20301969 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/lib/nvtx_utils_stub.cc +++ b/third_party/xla/third_party/tsl/tsl/profiler/lib/nvtx_utils_stub.cc @@ -12,10 +12,15 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + #include "tsl/profiler/lib/nvtx_utils.h" namespace tsl::profiler { ProfilerDomainHandle DefaultProfilerDomain() { return {}; } +void NameCurrentThread(const std::string&) {} +void NameDevice(int, const std::string&) {} +void NameStream(StreamHandle, const std::string&) {} void RangePop(ProfilerDomainHandle) {} void RangePush(ProfilerDomainHandle, const char*) {} namespace detail { diff --git a/third_party/xla/xla/pjrt/BUILD b/third_party/xla/xla/pjrt/BUILD index aa7dfdbdc39a11..eebff42d56ba5f 100644 --- a/third_party/xla/xla/pjrt/BUILD +++ b/third_party/xla/xla/pjrt/BUILD @@ -135,6 +135,7 @@ cc_library( "//xla/stream_executor", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:statusor", diff --git a/third_party/xla/xla/pjrt/gpu/BUILD b/third_party/xla/xla/pjrt/gpu/BUILD index 2d383b4aec8bbc..8ae1dfc355c621 100644 --- a/third_party/xla/xla/pjrt/gpu/BUILD +++ b/third_party/xla/xla/pjrt/gpu/BUILD @@ -116,6 +116,7 @@ cc_library( "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/profiler/lib:connected_traceme", + "@local_tsl//tsl/profiler/lib:nvtx_utils", "@local_tsl//tsl/profiler/lib:traceme", ] + if_cuda_or_rocm([ ":nccl_id_store", diff --git 
a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc index 6e40add20420af..1748f89386df28 100644 --- a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc +++ b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc @@ -88,6 +88,7 @@ limitations under the License. #include "tsl/platform/statusor.h" #include "tsl/platform/threadpool.h" #include "tsl/profiler/lib/connected_traceme.h" +#include "tsl/profiler/lib/nvtx_utils.h" #include "tsl/profiler/lib/traceme.h" #if defined(GOOGLE_CUDA) || defined(TENSORFLOW_USE_ROCM) @@ -1015,6 +1016,23 @@ absl::Status BuildDistributedDevices( TF_RET_CHECK(it->second != nullptr); local_device = std::move(it->second); gpu_device_ids[device_proto.local_device_ordinal()] = global_device_id; + // Assign some descriptive names for profiling tools. + auto suffix = + absl::StrFormat(":#global=%d,local=%d,process=%d,slice=%d#", + device_proto.global_device_id(), + device_proto.local_device_ordinal(), node.node_id(), + device_proto.slice_index()); + // Name the device. + tsl::profiler::NameDevice(device_proto.local_device_ordinal(), + absl::StrCat("Xla", suffix)); + // Name the thread that launches work on this device. This is deferred + // until after ExchangeTopologies has been called so the global device + // id and slice index are known. These are not available when the thread + // is created. + local_device->execute_thread()->Schedule( + [name = absl::StrCat("XlaLauncher", suffix)] { + tsl::profiler::NameCurrentThread(name); + }); } auto device = std::make_unique( device_proto.global_device_id(), std::move(local_device), diff --git a/third_party/xla/xla/pjrt/local_device_state.cc b/third_party/xla/xla/pjrt/local_device_state.cc index bf52adbe6d1173..62eba2b6238098 100644 --- a/third_party/xla/xla/pjrt/local_device_state.cc +++ b/third_party/xla/xla/pjrt/local_device_state.cc @@ -19,10 +19,12 @@ limitations under the License. 
#include #include #include +#include #include #include #include "absl/log/check.h" +#include "absl/strings/str_format.h" #include "absl/synchronization/mutex.h" #include "xla/stream_executor/stream.h" #include "xla/util.h" @@ -59,34 +61,43 @@ LocalDeviceState::LocalDeviceState(se::StreamExecutor* executor, int num_device_to_device_streams = stream_options.has_value() ? stream_options->num_device_to_device_streams : kNumDeviceToDeviceStreams; - auto create_stream = [executor, &stream_options]() { + auto create_stream = [executor, &stream_options](std::string const& name) { + std::unique_ptr stream; if (stream_options.has_value()) { - return executor->CreateStream(stream_options->priority).value(); + stream = executor->CreateStream(stream_options->priority).value(); } else { - return executor->CreateStream().value(); + stream = executor->CreateStream().value(); } + if (stream) { + stream->set_name(name); + } + return stream; }; - compute_stream_ = create_stream(); - host_to_device_stream_ = create_stream(); + compute_stream_ = create_stream("Compute"); + host_to_device_stream_ = create_stream("Host-to-device"); if (use_callback_stream) { callback_stream_map_ = absl::flat_hash_map>(); } device_to_host_streams_.reserve(num_device_to_host_streams); for (int i = 0; i < num_device_to_host_streams; ++i) { - device_to_host_streams_.emplace_back(create_stream()); + device_to_host_streams_.emplace_back( + create_stream(absl::StrFormat("Device-to-host #%d", i))); } device_to_device_streams_.reserve(num_device_to_device_streams); for (int i = 0; i < num_device_to_device_streams; ++i) { - device_to_device_streams_.emplace_back(create_stream()); + device_to_device_streams_.emplace_back( + create_stream(absl::StrFormat("Device-to-device #%d", i))); } fixed_size_pool_usage_streams_.reserve(kNumFixedSizePoolUsageStreams); for (int i = 0; i < kNumFixedSizePoolUsageStreams; ++i) { - fixed_size_pool_usage_streams_.emplace_back(create_stream()); + 
fixed_size_pool_usage_streams_.emplace_back( + create_stream(absl::StrFormat("Fixed size pool #%d", i))); } external_ready_event_streams_.reserve(kNumExternalReadyEventStreams); for (int i = 0; i < kNumExternalReadyEventStreams; ++i) { - external_ready_event_streams_.emplace_back(create_stream()); + external_ready_event_streams_.emplace_back( + create_stream(absl::StrFormat("External ready event #%d", i))); } execute_thread_ = std::make_unique(tsl::Env::Default(), "py_xla_execute"); @@ -143,6 +154,7 @@ absl::Status LocalDeviceState::ThenExecuteCallback( auto callback_stream = callback_stream_map_->find(stream); if (callback_stream == callback_stream_map_->end()) { TF_ASSIGN_OR_RETURN(auto new_stream, executor_->CreateStream()); + new_stream->set_name(absl::StrFormat("Callback for %s", stream->name())); callback_stream = callback_stream_map_->insert({stream, std::move(new_stream)}).first; } @@ -231,6 +243,7 @@ std::unique_ptr LocalDeviceState::BorrowStreamFromPool() { // The stream pool is empty, create a new stream. 
auto stream = compute_stream_->parent()->CreateStream().value(); + stream->set_name("Pool stream"); return stream; } diff --git a/third_party/xla/xla/service/gpu/infeed_manager.cc b/third_party/xla/xla/service/gpu/infeed_manager.cc index cfb49bb8f340c7..2ddc95c3d6fba3 100644 --- a/third_party/xla/xla/service/gpu/infeed_manager.cc +++ b/third_party/xla/xla/service/gpu/infeed_manager.cc @@ -46,7 +46,9 @@ constexpr int kMaxInfeedsInFlight = 8; InfeedManager::InfeedManager(se::StreamExecutor* executor) : BlockingXfeedQueue(/*max_pending_xfeeds=*/kMaxInfeedsInFlight), - stream_(executor->CreateStream().value()) {} + stream_(executor->CreateStream().value()) { + stream_->set_name("Infeed manager"); +} static absl::StatusOr CopyBufferToDevice( se::Stream* stream, int64_t size, const void* source) { diff --git a/third_party/xla/xla/service/stream_pool.cc b/third_party/xla/xla/service/stream_pool.cc index 54f5c773e76138..e94b5867cd8b51 100644 --- a/third_party/xla/xla/service/stream_pool.cc +++ b/third_party/xla/xla/service/stream_pool.cc @@ -51,6 +51,8 @@ StreamPool::Ptr StreamPool::BorrowStream(se::StreamPriority priority) { if (!stream) { // Create a new stream. 
stream = executor_->CreateStream(priority).value(); + stream->set_name(absl::StrFormat("%s pool stream", + se::StreamPriorityToString(priority))); VLOG(1) << absl::StrFormat("Created new stream (%p) with priority = %s", stream.get(), se::StreamPriorityToString(priority)); diff --git a/third_party/xla/xla/stream_executor/BUILD b/third_party/xla/xla/stream_executor/BUILD index bc5bf073667f94..5e3efefff90f20 100644 --- a/third_party/xla/xla/stream_executor/BUILD +++ b/third_party/xla/xla/stream_executor/BUILD @@ -707,6 +707,7 @@ cc_library( "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:errors", diff --git a/third_party/xla/xla/stream_executor/gpu/BUILD b/third_party/xla/xla/stream_executor/gpu/BUILD index 9b9237f12189d2..18805d5d239d79 100644 --- a/third_party/xla/xla/stream_executor/gpu/BUILD +++ b/third_party/xla/xla/stream_executor/gpu/BUILD @@ -310,6 +310,7 @@ gpu_only_cc_library( "//xla/stream_executor:stream", "//xla/stream_executor:stream_common", "@com_google_absl//absl/log:check", + "@com_google_absl//absl/strings:string_view", ], ) @@ -331,7 +332,9 @@ gpu_only_cc_library( "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/profiler/lib:nvtx_utils", ], ) diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_stream.cc b/third_party/xla/xla/stream_executor/gpu/gpu_stream.cc index 536bb1adefc9c4..aafd3ae241ec21 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_stream.cc +++ b/third_party/xla/xla/stream_executor/gpu/gpu_stream.cc @@ -31,6 +31,7 @@ limitations under the License. 
#include "xla/stream_executor/platform.h" #include "xla/stream_executor/stream.h" #include "tsl/platform/errors.h" +#include "tsl/profiler/lib/nvtx_utils.h" namespace stream_executor { namespace gpu { @@ -113,6 +114,12 @@ bool GpuStream::IsIdle() const { return GpuDriver::IsStreamIdle(parent_->gpu_context(), gpu_stream_); } +void GpuStream::set_name(absl::string_view name) { + name_ = name; + tsl::profiler::NameStream( + reinterpret_cast(gpu_stream()), name_); +} + GpuStream* AsGpuStream(Stream* stream) { DCHECK(stream != nullptr); return static_cast(stream); diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_stream.h b/third_party/xla/xla/stream_executor/gpu/gpu_stream.h index d3ac630f214ce1..4a9f2650cde5d9 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_stream.h +++ b/third_party/xla/xla/stream_executor/gpu/gpu_stream.h @@ -22,6 +22,7 @@ limitations under the License. #include #include "absl/log/check.h" +#include "absl/strings/string_view.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/gpu/gpu_executor.h" #include "xla/stream_executor/gpu/gpu_types.h" @@ -100,6 +101,8 @@ class GpuStream : public StreamCommon { absl::Status RecordEvent(Event* event) override; absl::Status MemZero(DeviceMemoryBase* location, uint64_t size) override; + void set_name(absl::string_view name) override; + private: GpuExecutor* parent_; // Executor that spawned this stream. GpuStreamHandle gpu_stream_; // Wrapped CUDA stream handle. 
diff --git a/third_party/xla/xla/stream_executor/stream.h b/third_party/xla/xla/stream_executor/stream.h index 4fd904befad159..f8f03f21cf1983 100644 --- a/third_party/xla/xla/stream_executor/stream.h +++ b/third_party/xla/xla/stream_executor/stream.h @@ -258,6 +258,10 @@ class Stream { virtual absl::Status Launch(const ThreadDim &thread_dims, const BlockDim &block_dims, const Kernel &k, const KernelArgs &args) = 0; + + // Get/set a name for a stream, which can be shown in profiling tools + virtual absl::string_view name() const = 0; + virtual void set_name(absl::string_view name) = 0; }; template diff --git a/third_party/xla/xla/stream_executor/stream_common.cc b/third_party/xla/xla/stream_executor/stream_common.cc index cf4c3ee7efd75e..f9e7b46e9845f4 100644 --- a/third_party/xla/xla/stream_executor/stream_common.cc +++ b/third_party/xla/xla/stream_executor/stream_common.cc @@ -29,6 +29,7 @@ limitations under the License. #include "absl/functional/any_invocable.h" #include "absl/status/status.h" #include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "absl/synchronization/mutex.h" #include "xla/stream_executor/blas.h" #include "xla/stream_executor/device_memory.h" @@ -100,6 +101,7 @@ absl::StatusOr StreamCommon::GetOrCreateSubStream() { // No streams are reusable; create a new stream. 
TF_ASSIGN_OR_RETURN(auto stream, parent_->CreateStream()); Stream *sub_stream = stream.get(); + sub_stream->set_name(absl::StrFormat("Sub-stream of %s", name())); sub_streams_.emplace_back(std::move(stream), false); VLOG(1) << "stream=" << this << " created new sub_stream=" << sub_stream; diff --git a/third_party/xla/xla/stream_executor/stream_common.h b/third_party/xla/xla/stream_executor/stream_common.h index 39e3e025fb2245..66f80fbed05b30 100644 --- a/third_party/xla/xla/stream_executor/stream_common.h +++ b/third_party/xla/xla/stream_executor/stream_common.h @@ -101,6 +101,10 @@ class StreamCommon : public Stream { absl::Status Launch(const ThreadDim &thread_dims, const BlockDim &block_dims, const Kernel &k, const KernelArgs &args) override; + // Doesn't do anything interesting by default; GpuStream connects this to NVTX + absl::string_view name() const override { return name_; } + void set_name(absl::string_view name) override { name_ = name; } + protected: bool InErrorState() const TF_LOCKS_EXCLUDED(mu_) { absl::ReaderMutexLock lock(&mu_); @@ -116,6 +120,8 @@ class StreamCommon : public Stream { void SetError() { CheckError(false /* = operation_retcode */); } + std::string name_; + private: // The StreamExecutor that supports the operation of this stream. 
StreamExecutor *parent_; diff --git a/third_party/xla/xla/stream_executor/stream_executor_memory_allocator.cc b/third_party/xla/xla/stream_executor/stream_executor_memory_allocator.cc index 858d6a3a5c1ebd..c17656c6d34147 100644 --- a/third_party/xla/xla/stream_executor/stream_executor_memory_allocator.cc +++ b/third_party/xla/xla/stream_executor/stream_executor_memory_allocator.cc @@ -105,6 +105,7 @@ absl::StatusOr StreamExecutorMemoryAllocator::GetStream( if (!streams_.count(device_ordinal)) { TF_ASSIGN_OR_RETURN(auto stream, executor->CreateStream()); auto stream_ptr = stream.get(); + stream_ptr->set_name("StreamExecutorMemoryAllocator"); streams_.emplace(device_ordinal, std::move(stream)); return stream_ptr; } diff --git a/third_party/xla/xla/stream_executor/trace_command_buffer_factory.cc b/third_party/xla/xla/stream_executor/trace_command_buffer_factory.cc index 5efd0cc4b2d337..0f18b68a08e343 100644 --- a/third_party/xla/xla/stream_executor/trace_command_buffer_factory.cc +++ b/third_party/xla/xla/stream_executor/trace_command_buffer_factory.cc @@ -34,6 +34,7 @@ TraceCommandBufferFactory::Create( absl::AnyInvocable function, CommandBuffer::Mode mode) { TF_ASSIGN_OR_RETURN(auto stream, executor->CreateStream()); + stream->set_name("Command buffer tracer"); return TraceCommandBufferFactory::Create(executor, stream.get(), std::move(function), mode); } From 464f895b598c9142de1ea1f5754ec9061b7b67e2 Mon Sep 17 00:00:00 2001 From: Henning Becker Date: Thu, 20 Jun 2024 00:49:16 -0700 Subject: [PATCH 037/256] Make GPU PJRT tests xla_tests There are two remaining tests that are still using the `xla_cc_test` macro for tests that require a GPU. I'm migrating those over to the `xla_test` macro so that we can move forward with removing the `xla_cc_test` macro and thus simplify our test infrastructure. 
PiperOrigin-RevId: 644925610 --- third_party/xla/xla/pjrt/gpu/BUILD | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/third_party/xla/xla/pjrt/gpu/BUILD b/third_party/xla/xla/pjrt/gpu/BUILD index 8ae1dfc355c621..85956ebdf13857 100644 --- a/third_party/xla/xla/pjrt/gpu/BUILD +++ b/third_party/xla/xla/pjrt/gpu/BUILD @@ -293,13 +293,12 @@ cc_library( ], ) -xla_cc_test( +xla_test( name = "se_gpu_pjrt_compiler_test", srcs = if_gpu_is_configured(["se_gpu_pjrt_compiler_test.cc"]), + backends = ["gpu"], tags = [ - "gpu", "no_oss", - "requires-gpu-nvidia", ] + if_google(["config-cuda-only"]), deps = [ ":se_gpu_pjrt_client", @@ -309,7 +308,6 @@ xla_cc_test( "//xla/mlir_hlo", "//xla/pjrt:pjrt_client", "//xla/pjrt:pjrt_compiler", - "//xla/service:gpu_plugin", "//xla/service:hlo_parser", "//xla/stream_executor/cuda:cublas_plugin", "//xla/tests:literal_test_util", @@ -323,13 +321,12 @@ xla_cc_test( ], ) -xla_cc_test( +xla_test( name = "se_gpu_pjrt_compiler_aot_test", srcs = if_gpu_is_configured(["se_gpu_pjrt_compiler_aot_test.cc"]), + backends = ["gpu"], tags = [ - "gpu", "no_oss", - "requires-gpu-nvidia", ] + if_google(["config-cuda-only"]), deps = [ ":se_gpu_pjrt_client", @@ -342,7 +339,6 @@ xla_cc_test( "//xla/pjrt:pjrt_compiler", "//xla/pjrt:pjrt_executable", "//xla/service:compiler", - "//xla/service:gpu_plugin", "//xla/service:hlo_parser", "//xla/tests:literal_test_util", "@com_google_absl//absl/memory", From 36947e0dc6cb0b70034384a6ac843fe241577d77 Mon Sep 17 00:00:00 2001 From: Alan Kelly Date: Thu, 20 Jun 2024 01:27:14 -0700 Subject: [PATCH 038/256] XNNPack MEAN F32 supports all reduction types PiperOrigin-RevId: 644935163 --- .../lite/delegates/xnnpack/mean_test.cc | 36 ++--- .../delegates/xnnpack/xnnpack_delegate.cc | 137 ++++++++++-------- 2 files changed, 98 insertions(+), 75 deletions(-) diff --git a/tensorflow/lite/delegates/xnnpack/mean_test.cc b/tensorflow/lite/delegates/xnnpack/mean_test.cc index e5307e684ecd72..62ffdce26f6461 
100644 --- a/tensorflow/lite/delegates/xnnpack/mean_test.cc +++ b/tensorflow/lite/delegates/xnnpack/mean_test.cc @@ -25,7 +25,7 @@ limitations under the License. namespace tflite { namespace xnnpack { -TEST(Mean, DISABLED_4DReduceBatchSqueezeDims) { +TEST(Mean, 4DReduceBatchSqueezeDims) { std::unique_ptr xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), TfLiteXNNPackDelegateDelete); @@ -46,7 +46,7 @@ TEST(Mean, DISABLED_4DReduceBatchSqueezeDims) { .Test(BuiltinOperator_MEAN, xnnpack_delegate.get()); } -TEST(Mean, DISABLED_4DReduceBatchKeepDims) { +TEST(Mean, 4DReduceBatchKeepDims) { std::unique_ptr xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), TfLiteXNNPackDelegateDelete); @@ -67,7 +67,7 @@ TEST(Mean, DISABLED_4DReduceBatchKeepDims) { .Test(BuiltinOperator_MEAN, xnnpack_delegate.get()); } -TEST(Mean, DISABLED_4DReduceHeightSqueezeDims) { +TEST(Mean, 4DReduceHeightSqueezeDims) { std::unique_ptr xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), TfLiteXNNPackDelegateDelete); @@ -88,7 +88,7 @@ TEST(Mean, DISABLED_4DReduceHeightSqueezeDims) { .Test(BuiltinOperator_MEAN, xnnpack_delegate.get()); } -TEST(Mean, DISABLED_4DReduceHeightKeepDims) { +TEST(Mean, 4DReduceHeightKeepDims) { std::unique_ptr xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), TfLiteXNNPackDelegateDelete); @@ -205,7 +205,7 @@ TEST(Mean, 4DReduceHeightWidthKeepDims) { .Test(BuiltinOperator_MEAN, xnnpack_delegate.get()); } -TEST(Mean, DISABLED_4DReduceChannelsSqueezeDims) { +TEST(Mean, 4DReduceChannelsSqueezeDims) { std::unique_ptr xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), TfLiteXNNPackDelegateDelete); @@ -226,7 +226,7 @@ TEST(Mean, DISABLED_4DReduceChannelsSqueezeDims) { .Test(BuiltinOperator_MEAN, xnnpack_delegate.get()); } -TEST(Mean, DISABLED_4DReduceChannelsKeepDims) { +TEST(Mean, 4DReduceChannelsKeepDims) { std::unique_ptr xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), TfLiteXNNPackDelegateDelete); @@ -247,7 +247,7 @@ TEST(Mean, 
DISABLED_4DReduceChannelsKeepDims) { .Test(BuiltinOperator_MEAN, xnnpack_delegate.get()); } -TEST(Mean, DISABLED_3DReduceBatchSqueezeDims) { +TEST(Mean, 3DReduceBatchSqueezeDims) { std::unique_ptr xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), TfLiteXNNPackDelegateDelete); @@ -267,7 +267,7 @@ TEST(Mean, DISABLED_3DReduceBatchSqueezeDims) { .Test(BuiltinOperator_MEAN, xnnpack_delegate.get()); } -TEST(Mean, DISABLED_3DReduceBatchKeepDims) { +TEST(Mean, 3DReduceBatchKeepDims) { std::unique_ptr xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), TfLiteXNNPackDelegateDelete); @@ -287,7 +287,7 @@ TEST(Mean, DISABLED_3DReduceBatchKeepDims) { .Test(BuiltinOperator_MEAN, xnnpack_delegate.get()); } -TEST(Mean, DISABLED_3DReduceWidthSqueezeDims) { +TEST(Mean, 3DReduceWidthSqueezeDims) { std::unique_ptr xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), TfLiteXNNPackDelegateDelete); @@ -307,7 +307,7 @@ TEST(Mean, DISABLED_3DReduceWidthSqueezeDims) { .Test(BuiltinOperator_MEAN, xnnpack_delegate.get()); } -TEST(Mean, DISABLED_3DReduceWidthKeepDims) { +TEST(Mean, 3DReduceWidthKeepDims) { std::unique_ptr xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), TfLiteXNNPackDelegateDelete); @@ -327,7 +327,7 @@ TEST(Mean, DISABLED_3DReduceWidthKeepDims) { .Test(BuiltinOperator_MEAN, xnnpack_delegate.get()); } -TEST(Mean, DISABLED_3DReduceChannelsSqueezeDims) { +TEST(Mean, 3DReduceChannelsSqueezeDims) { std::unique_ptr xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), TfLiteXNNPackDelegateDelete); @@ -347,7 +347,7 @@ TEST(Mean, DISABLED_3DReduceChannelsSqueezeDims) { .Test(BuiltinOperator_MEAN, xnnpack_delegate.get()); } -TEST(Mean, DISABLED_3DReduceChannelsKeepDims) { +TEST(Mean, 3DReduceChannelsKeepDims) { std::unique_ptr xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), TfLiteXNNPackDelegateDelete); @@ -367,7 +367,7 @@ TEST(Mean, DISABLED_3DReduceChannelsKeepDims) { .Test(BuiltinOperator_MEAN, xnnpack_delegate.get()); } -TEST(Mean, 
DISABLED_2DReduceBatchSqueezeDims) { +TEST(Mean, 2DReduceBatchSqueezeDims) { std::unique_ptr xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), TfLiteXNNPackDelegateDelete); @@ -386,7 +386,7 @@ TEST(Mean, DISABLED_2DReduceBatchSqueezeDims) { .Test(BuiltinOperator_MEAN, xnnpack_delegate.get()); } -TEST(Mean, DISABLED_2DReduceBatchKeepDims) { +TEST(Mean, 2DReduceBatchKeepDims) { std::unique_ptr xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), TfLiteXNNPackDelegateDelete); @@ -405,7 +405,7 @@ TEST(Mean, DISABLED_2DReduceBatchKeepDims) { .Test(BuiltinOperator_MEAN, xnnpack_delegate.get()); } -TEST(Mean, DISABLED_2DReduceChannelsSqueezeDims) { +TEST(Mean, 2DReduceChannelsSqueezeDims) { std::unique_ptr xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), TfLiteXNNPackDelegateDelete); @@ -424,7 +424,7 @@ TEST(Mean, DISABLED_2DReduceChannelsSqueezeDims) { .Test(BuiltinOperator_MEAN, xnnpack_delegate.get()); } -TEST(Mean, DISABLED_2DReduceChannelsKeepDims) { +TEST(Mean, 2DReduceChannelsKeepDims) { std::unique_ptr xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), TfLiteXNNPackDelegateDelete); @@ -443,7 +443,7 @@ TEST(Mean, DISABLED_2DReduceChannelsKeepDims) { .Test(BuiltinOperator_MEAN, xnnpack_delegate.get()); } -TEST(Mean, DISABLED_1DSqueezeDims) { +TEST(Mean, 1DSqueezeDims) { std::unique_ptr xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), TfLiteXNNPackDelegateDelete); @@ -458,7 +458,7 @@ TEST(Mean, DISABLED_1DSqueezeDims) { BuiltinOperator_MEAN, xnnpack_delegate.get()); } -TEST(Mean, DISABLED_1DKeepDims) { +TEST(Mean, 1DKeepDims) { std::unique_ptr xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), TfLiteXNNPackDelegateDelete); diff --git a/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc b/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc index e40294c18d08ff..0935d28388d182 100644 --- a/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc +++ b/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc @@ -4793,48 +4793,54 @@ class 
Subgraph { const int32_t* axes_data = reinterpret_cast(axes_tensor.data.data); const int num_reduction_axes = NumElements(&axes_tensor); - switch (num_reduction_axes) { - case 1: - if (axes_data[0] != 2) { - TF_LITE_MAYBE_KERNEL_LOG( - logging_context, - "unsupported MEAN reduction along non-spatial " - "axis %d in node %d", - axes_data[0], node_index); - return kTfLiteError; - } - break; - case 2: - if (std::min(axes_data[0], axes_data[1]) != 1 || - std::max(axes_data[0], axes_data[1]) != 2) { - TF_LITE_MAYBE_KERNEL_LOG( - logging_context, - "unsupported MEAN reduction along non-spatial " - "axes %d and %d in node %d", - std::min(axes_data[0], axes_data[1]), - std::max(axes_data[0], axes_data[1]), node_index); - return kTfLiteError; - } - break; - default: - TF_LITE_MAYBE_KERNEL_LOG( - logging_context, - "unsupported MEAN reduction along %d axes in node %d", - SizeOfDimension(&axes_tensor, 0), node_index); - return kTfLiteError; + bool all_reductions_supported = false; + if (input_tensor.type == kTfLiteFloat32) { + all_reductions_supported = true; } - const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]]; TF_LITE_ENSURE_STATUS( CheckTensorFloat32OrQUInt8Type(delegate, logging_context, output_tensor, node->outputs->data[0], node_index)); - int expected_output_dims = 4; - if (!reducer_params->keep_dims) { - expected_output_dims -= num_reduction_axes; + if (!all_reductions_supported) { + switch (num_reduction_axes) { + case 1: + if (axes_data[0] != 2) { + TF_LITE_MAYBE_KERNEL_LOG( + logging_context, + "unsupported MEAN reduction along non-spatial " + "axis %d in node %d", + axes_data[0], node_index); + return kTfLiteError; + } + break; + case 2: + if (std::min(axes_data[0], axes_data[1]) != 1 || + std::max(axes_data[0], axes_data[1]) != 2) { + TF_LITE_MAYBE_KERNEL_LOG( + logging_context, + "unsupported MEAN reduction along non-spatial " + "axes %d and %d in node %d", + std::min(axes_data[0], axes_data[1]), + std::max(axes_data[0], axes_data[1]), 
node_index); + return kTfLiteError; + } + break; + default: + TF_LITE_MAYBE_KERNEL_LOG( + logging_context, + "unsupported MEAN reduction along %d axes in node %d", + SizeOfDimension(&axes_tensor, 0), node_index); + return kTfLiteError; + } + int expected_output_dims = 4; + if (!reducer_params->keep_dims) { + expected_output_dims -= num_reduction_axes; + } + TF_LITE_ENSURE_STATUS(CheckTensorShape( + logging_context, output_tensor, expected_output_dims, + node->outputs->data[0], BuiltinOperator_MEAN, node_index)); } - TF_LITE_ENSURE_STATUS(CheckTensorShape( - logging_context, output_tensor, expected_output_dims, - node->outputs->data[0], BuiltinOperator_MEAN, node_index)); + TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation( delegate, logging_context, output_tensor, node->outputs->data[0], node_index)); @@ -4842,27 +4848,44 @@ class Subgraph { if (subgraph != nullptr) { uint32_t flags = reducer_params->keep_dims ? XNN_FLAG_KEEP_DIMS : 0; xnn_status status = xnn_status_success; - switch (num_reduction_axes) { - case 1: - status = xnn_define_global_average_pooling_1d( - subgraph, - /*output_min=*/-std::numeric_limits::infinity(), - /*output_max=*/+std::numeric_limits::infinity(), - /*input_id=*/input_output_tensors.at(node->inputs->data[0]), - /*output_id=*/input_output_tensors.at(node->outputs->data[0]), - flags); - break; - case 2: - status = xnn_define_global_average_pooling_2d( - subgraph, - /*output_min=*/-std::numeric_limits::infinity(), - /*output_max=*/+std::numeric_limits::infinity(), - /*input_id=*/input_output_tensors.at(node->inputs->data[0]), - /*output_id=*/input_output_tensors.at(node->outputs->data[0]), - flags); - break; - default: - break; + if (all_reductions_supported) { + std::array reduction_axes; + for (int i = 0; i < num_reduction_axes; ++i) { + if (axes_data[i] < 0) { + reduction_axes[i] = axes_data[i] + NumDimensions(&input_tensor); + } else { + reduction_axes[i] = axes_data[i]; + } + } + std::sort(&reduction_axes[0], 
&reduction_axes[num_reduction_axes]); + status = xnn_define_static_mean( + subgraph, num_reduction_axes, reduction_axes.data(), + /*input_id=*/input_output_tensors.at(node->inputs->data[0]), + /*output_id=*/input_output_tensors.at(node->outputs->data[0]), + flags); + } else { + switch (num_reduction_axes) { + case 1: + status = xnn_define_global_average_pooling_1d( + subgraph, + /*output_min=*/-std::numeric_limits::infinity(), + /*output_max=*/+std::numeric_limits::infinity(), + /*input_id=*/input_output_tensors.at(node->inputs->data[0]), + /*output_id=*/input_output_tensors.at(node->outputs->data[0]), + flags); + break; + case 2: + status = xnn_define_global_average_pooling_2d( + subgraph, + /*output_min=*/-std::numeric_limits::infinity(), + /*output_max=*/+std::numeric_limits::infinity(), + /*input_id=*/input_output_tensors.at(node->inputs->data[0]), + /*output_id=*/input_output_tensors.at(node->outputs->data[0]), + flags); + break; + default: + break; + } } if (status != xnn_status_success) { TF_LITE_KERNEL_LOG(logging_context, "failed to delegate %s node #%d", From 6cf72fef91bdc7ff5113c332a27279ff34eb4af0 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 20 Jun 2024 02:02:14 -0700 Subject: [PATCH 039/256] compat: Update forward compatibility horizon to 2024-06-20 PiperOrigin-RevId: 644943664 --- tensorflow/python/compat/compat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py index 1609786a1c118c..5a4b90b703f387 100644 --- a/tensorflow/python/compat/compat.py +++ b/tensorflow/python/compat/compat.py @@ -29,7 +29,7 @@ # This value changes every day with an automatic CL. It can be modified in code # via `forward_compatibility_horizon()` or with the environment variable # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date. 
-_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2024, 6, 19) +_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2024, 6, 20) _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS" _FORWARD_COMPATIBILITY_DATE_NUMBER = None From 8ee8ff8dc2ad61c81cc3da699a9ea22909d586ec Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 20 Jun 2024 02:02:24 -0700 Subject: [PATCH 040/256] Update GraphDef version to 1899. PiperOrigin-RevId: 644943726 --- tensorflow/core/public/version.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h index 535b7ffad284fc..be635a1e132f54 100644 --- a/tensorflow/core/public/version.h +++ b/tensorflow/core/public/version.h @@ -108,7 +108,7 @@ limitations under the License. #define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0 -#define TF_GRAPH_DEF_VERSION 1898 // Updated: 2024/6/19 +#define TF_GRAPH_DEF_VERSION 1899 // Updated: 2024/6/20 // Checkpoint compatibility versions (the versions field in SavedSliceMeta). // From 75adcb5e1760bb56f53abc580a87a8b884d7bfe1 Mon Sep 17 00:00:00 2001 From: Johannes Reifferscheid Date: Thu, 20 Jun 2024 02:04:32 -0700 Subject: [PATCH 041/256] Clean up some sentinel `-1`s. std::optional is better for this. Also add a TODO to see if we can get rid of the rewrite that increases the number of divisions. 
PiperOrigin-RevId: 644944444 --- .../xla/xla/service/gpu/model/indexing_map.cc | 65 +++++++++++-------- 1 file changed, 37 insertions(+), 28 deletions(-) diff --git a/third_party/xla/xla/service/gpu/model/indexing_map.cc b/third_party/xla/xla/service/gpu/model/indexing_map.cc index 86fff568b3b7ee..3b072b91b77254 100644 --- a/third_party/xla/xla/service/gpu/model/indexing_map.cc +++ b/third_party/xla/xla/service/gpu/model/indexing_map.cc @@ -193,13 +193,13 @@ AffineExpr AffineExprSimplifier::RewriteMod(AffineBinaryOpExpr mod) { new_lhs = new_lhs + (extracted_constant % m); Interval no_multiplier_range{0, 0}; - int64_t multiplier_gcd = -1; + std::optional multiplier_gcd = std::nullopt; VisitSummands(new_lhs, [&](AffineExpr expr) { if (auto multiplier = GetConstantRhs(expr, AffineExprKind::Mul)) { - if (multiplier_gcd == -1) { - multiplier_gcd = *multiplier; + if (multiplier_gcd.has_value()) { + multiplier_gcd = std::gcd(*multiplier_gcd, *multiplier); } else { - multiplier_gcd = std::gcd(multiplier_gcd, *multiplier); + multiplier_gcd = *multiplier; } } else { auto range = range_evaluator_->ComputeExpressionRange(expr); @@ -209,17 +209,20 @@ AffineExpr AffineExprSimplifier::RewriteMod(AffineBinaryOpExpr mod) { }); mlir::AffineExpr extracted = getAffineConstantExpr(0, mod.getContext()); - if (m % multiplier_gcd == 0 && no_multiplier_range.lower >= 0 && - no_multiplier_range.upper < multiplier_gcd) { - // Remove everything that doesn't have a multiplier. - new_lhs = RemoveSummands(new_lhs, [&](AffineExpr expr) { - if (GetConstantRhs(expr, AffineExprKind::Mul)) { - return false; - } - extracted = extracted + expr; - return true; - }); + if (multiplier_gcd.has_value()) { + if (m % *multiplier_gcd == 0 && no_multiplier_range.lower >= 0 && + no_multiplier_range.upper < *multiplier_gcd) { + // Remove everything that doesn't have a multiplier. 
+ new_lhs = RemoveSummands(new_lhs, [&](AffineExpr expr) { + if (GetConstantRhs(expr, AffineExprKind::Mul)) { + return false; + } + extracted = extracted + expr; + return true; + }); + } } + return new_lhs % mod.getRHS() + extracted; } @@ -293,16 +296,17 @@ AffineExpr AffineExprSimplifier::RewriteFloorDiv(AffineBinaryOpExpr div) { return extracted; } - int64_t multiplier_gcd = -1; - // The maximum GCD of any remaining multiplier inside the div and the divisor. - int64_t max_remaining_multiplier_gcd = -1; + std::optional multiplier_gcd = std::nullopt; + // The maximum GCD of (divisor, any multiplier inside the div). + int64_t max_remaining_multiplier_gcd = 1; VisitSummands(new_dividend, [&](AffineExpr summand) { if (auto multiplier = GetConstantRhs(summand, AffineExprKind::Mul)) { - if (multiplier_gcd == -1) { - multiplier_gcd = *multiplier; + if (multiplier_gcd.has_value()) { + multiplier_gcd = std::gcd(*multiplier_gcd, *multiplier); } else { - multiplier_gcd = std::gcd(multiplier_gcd, *multiplier); + multiplier_gcd = multiplier; } + max_remaining_multiplier_gcd = std::max(max_remaining_multiplier_gcd, std::gcd(*multiplier, d)); } else { @@ -312,19 +316,24 @@ AffineExpr AffineExprSimplifier::RewriteFloorDiv(AffineBinaryOpExpr div) { } }); - if ((d % multiplier_gcd) == 0) { - if (no_multiplier_range.lower >= 0 && - no_multiplier_range.upper < multiplier_gcd) { - // Remove everything that doesn't have a multiplier. - new_dividend = RemoveSummands(new_dividend, [&](AffineExpr expr) { - auto mult = GetConstantRhs(expr, AffineExprKind::Mul); - return !mult.has_value(); - }); + if (multiplier_gcd.has_value()) { + if ((d % *multiplier_gcd) == 0) { + if (no_multiplier_range.lower >= 0 && + no_multiplier_range.upper < *multiplier_gcd) { + // Remove everything that doesn't have a multiplier. 
+ new_dividend = RemoveSummands(new_dividend, [&](AffineExpr expr) { + auto mult = GetConstantRhs(expr, AffineExprKind::Mul); + return !mult.has_value(); + }); + } } } // If we have a gcd > 1, we can split the div into two: // (x * 128 + y) // 192 -> (x * 2 + y // 64) // 3 + // TODO(jreiffers): This is currently required for some simplifications, but + // it increases the number of divisions, which is not really a simplification. + // See if we can avoid this rewrite. if (max_remaining_multiplier_gcd > 1) { AffineExpr partially_extracted = getAffineConstantExpr(0, mlir_context); new_dividend = RemoveSummands(new_dividend, [&](AffineExpr expr) { From acc8b5b9069afcbb468f29c6cae51bb7f4a6a801 Mon Sep 17 00:00:00 2001 From: Adrian Kuegel Date: Thu, 20 Jun 2024 03:02:06 -0700 Subject: [PATCH 042/256] Reverts 9b12cd1d4053f8b93128eec529956ea7521fe63d PiperOrigin-RevId: 644957493 --- .../tsl/tsl/profiler/lib/nvtx_utils.cc | 35 ------------------- .../tsl/tsl/profiler/lib/nvtx_utils.h | 13 ------- .../tsl/tsl/profiler/lib/nvtx_utils_stub.cc | 5 --- third_party/xla/xla/pjrt/BUILD | 1 - third_party/xla/xla/pjrt/gpu/BUILD | 1 - .../xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc | 18 ---------- .../xla/xla/pjrt/local_device_state.cc | 31 +++++----------- .../xla/xla/service/gpu/infeed_manager.cc | 4 +-- third_party/xla/xla/service/stream_pool.cc | 2 -- third_party/xla/xla/stream_executor/BUILD | 1 - third_party/xla/xla/stream_executor/gpu/BUILD | 3 -- .../xla/xla/stream_executor/gpu/gpu_stream.cc | 7 ---- .../xla/xla/stream_executor/gpu/gpu_stream.h | 3 -- third_party/xla/xla/stream_executor/stream.h | 4 --- .../xla/xla/stream_executor/stream_common.cc | 2 -- .../xla/xla/stream_executor/stream_common.h | 6 ---- .../stream_executor_memory_allocator.cc | 1 - .../trace_command_buffer_factory.cc | 1 - 18 files changed, 10 insertions(+), 128 deletions(-) diff --git a/third_party/xla/third_party/tsl/tsl/profiler/lib/nvtx_utils.cc 
b/third_party/xla/third_party/tsl/tsl/profiler/lib/nvtx_utils.cc index 703a4b0b3934ca..b122c6e12dfc19 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/lib/nvtx_utils.cc +++ b/third_party/xla/third_party/tsl/tsl/profiler/lib/nvtx_utils.cc @@ -14,37 +14,16 @@ limitations under the License. ==============================================================================*/ #include "tsl/profiler/lib/nvtx_utils.h" -#include - #include #include -#include #include #include #include #include "nvtx3/nvToolsExt.h" -#include "nvtx3/nvToolsExtCuda.h" -#include "nvtx3/nvToolsExtCudaRt.h" #include "nvtx3/nvToolsExtPayload.h" -#include "third_party/gpus/cuda/include/cuda.h" - -namespace { -// Get the ID of the current thread following the convention for -// nvtxNameOsThreadA: -// https://nvidia.github.io/NVTX/doxygen/group___r_e_s_o_u_r_c_e___n_a_m_i_n_g.html -// This convention may not match the one in tsl::Env::GetCurrentThreadId(). -std::optional GetCurrentThreadId() { -#ifdef __linux__ - return syscall(SYS_gettid); -#else - return std::nullopt; -#endif -} -} // namespace namespace tsl::profiler { -static_assert(std::is_pointer_v); static_assert(std::is_pointer_v); static_assert(std::is_pointer_v); @@ -54,20 +33,6 @@ ProfilerDomainHandle DefaultProfilerDomain() { return domain; } -void NameCurrentThread(const std::string& thread_name) { - if (std::optional tid = GetCurrentThreadId(); tid.has_value()) { - nvtxNameOsThreadA(*tid, thread_name.c_str()); - } -} - -void NameDevice(int device_id, const std::string& device_name) { - nvtxNameCudaDeviceA(device_id, device_name.c_str()); -} - -void NameStream(StreamHandle stream, const std::string& stream_name) { - nvtxNameCuStreamA(reinterpret_cast(stream), stream_name.c_str()); -} - void RangePop(ProfilerDomainHandle domain) { nvtxDomainRangePop(reinterpret_cast(domain)); } diff --git a/third_party/xla/third_party/tsl/tsl/profiler/lib/nvtx_utils.h b/third_party/xla/third_party/tsl/tsl/profiler/lib/nvtx_utils.h index 
4d65c39e4bbd0e..0727072d06390d 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/lib/nvtx_utils.h +++ b/third_party/xla/third_party/tsl/tsl/profiler/lib/nvtx_utils.h @@ -34,19 +34,6 @@ using ProfilerDomainHandle = ProfilerDomain*; // Get the "TSL" domain if NVTX profiling is enabled, otherwise null ProfilerDomainHandle DefaultProfilerDomain(); -// Assign a human-readable name to the current thread -void NameCurrentThread(const std::string&); - -// Assign a human-readable name to the given local device -void NameDevice(int device_id, const std::string& device_name); - -struct Stream; -// Opaque handle to an execution stream -using StreamHandle = Stream*; - -// Assign a human-readable name to the given execution stream -void NameStream(StreamHandle stream, const std::string& stream_name); - // Register a string with the profiler/NVTX implementation for faster use StringHandle RegisterString(ProfilerDomainHandle, const std::string&); diff --git a/third_party/xla/third_party/tsl/tsl/profiler/lib/nvtx_utils_stub.cc b/third_party/xla/third_party/tsl/tsl/profiler/lib/nvtx_utils_stub.cc index ad2e7d20301969..c887af77ec8b11 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/lib/nvtx_utils_stub.cc +++ b/third_party/xla/third_party/tsl/tsl/profiler/lib/nvtx_utils_stub.cc @@ -12,15 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include - #include "tsl/profiler/lib/nvtx_utils.h" namespace tsl::profiler { ProfilerDomainHandle DefaultProfilerDomain() { return {}; } -void NameCurrentThread(const std::string&) {} -void NameDevice(int, const std::string&) {} -void NameStream(StreamHandle, const std::string&) {} void RangePop(ProfilerDomainHandle) {} void RangePush(ProfilerDomainHandle, const char*) {} namespace detail { diff --git a/third_party/xla/xla/pjrt/BUILD b/third_party/xla/xla/pjrt/BUILD index eebff42d56ba5f..aa7dfdbdc39a11 100644 --- a/third_party/xla/xla/pjrt/BUILD +++ b/third_party/xla/xla/pjrt/BUILD @@ -135,7 +135,6 @@ cc_library( "//xla/stream_executor", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", - "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:statusor", diff --git a/third_party/xla/xla/pjrt/gpu/BUILD b/third_party/xla/xla/pjrt/gpu/BUILD index 85956ebdf13857..24d90e89faba2c 100644 --- a/third_party/xla/xla/pjrt/gpu/BUILD +++ b/third_party/xla/xla/pjrt/gpu/BUILD @@ -116,7 +116,6 @@ cc_library( "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/profiler/lib:connected_traceme", - "@local_tsl//tsl/profiler/lib:nvtx_utils", "@local_tsl//tsl/profiler/lib:traceme", ] + if_cuda_or_rocm([ ":nccl_id_store", diff --git a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc index 1748f89386df28..6e40add20420af 100644 --- a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc +++ b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc @@ -88,7 +88,6 @@ limitations under the License. 
#include "tsl/platform/statusor.h" #include "tsl/platform/threadpool.h" #include "tsl/profiler/lib/connected_traceme.h" -#include "tsl/profiler/lib/nvtx_utils.h" #include "tsl/profiler/lib/traceme.h" #if defined(GOOGLE_CUDA) || defined(TENSORFLOW_USE_ROCM) @@ -1016,23 +1015,6 @@ absl::Status BuildDistributedDevices( TF_RET_CHECK(it->second != nullptr); local_device = std::move(it->second); gpu_device_ids[device_proto.local_device_ordinal()] = global_device_id; - // Assign some descriptive names for profiling tools. - auto suffix = - absl::StrFormat(":#global=%d,local=%d,process=%d,slice=%d#", - device_proto.global_device_id(), - device_proto.local_device_ordinal(), node.node_id(), - device_proto.slice_index()); - // Name the device. - tsl::profiler::NameDevice(device_proto.local_device_ordinal(), - absl::StrCat("Xla", suffix)); - // Name the thread that launches work on this device. This is deferred - // until after ExchangeTopologies has been called so the global device - // id and slice index are known. These are not available when the thread - // is created. - local_device->execute_thread()->Schedule( - [name = absl::StrCat("XlaLauncher", suffix)] { - tsl::profiler::NameCurrentThread(name); - }); } auto device = std::make_unique( device_proto.global_device_id(), std::move(local_device), diff --git a/third_party/xla/xla/pjrt/local_device_state.cc b/third_party/xla/xla/pjrt/local_device_state.cc index 62eba2b6238098..bf52adbe6d1173 100644 --- a/third_party/xla/xla/pjrt/local_device_state.cc +++ b/third_party/xla/xla/pjrt/local_device_state.cc @@ -19,12 +19,10 @@ limitations under the License. #include #include #include -#include #include #include #include "absl/log/check.h" -#include "absl/strings/str_format.h" #include "absl/synchronization/mutex.h" #include "xla/stream_executor/stream.h" #include "xla/util.h" @@ -61,43 +59,34 @@ LocalDeviceState::LocalDeviceState(se::StreamExecutor* executor, int num_device_to_device_streams = stream_options.has_value() ? 
stream_options->num_device_to_device_streams : kNumDeviceToDeviceStreams; - auto create_stream = [executor, &stream_options](std::string const& name) { - std::unique_ptr stream; + auto create_stream = [executor, &stream_options]() { if (stream_options.has_value()) { - stream = executor->CreateStream(stream_options->priority).value(); + return executor->CreateStream(stream_options->priority).value(); } else { - stream = executor->CreateStream().value(); + return executor->CreateStream().value(); } - if (stream) { - stream->set_name(name); - } - return stream; }; - compute_stream_ = create_stream("Compute"); - host_to_device_stream_ = create_stream("Host-to-device"); + compute_stream_ = create_stream(); + host_to_device_stream_ = create_stream(); if (use_callback_stream) { callback_stream_map_ = absl::flat_hash_map>(); } device_to_host_streams_.reserve(num_device_to_host_streams); for (int i = 0; i < num_device_to_host_streams; ++i) { - device_to_host_streams_.emplace_back( - create_stream(absl::StrFormat("Device-to-host #%d", i))); + device_to_host_streams_.emplace_back(create_stream()); } device_to_device_streams_.reserve(num_device_to_device_streams); for (int i = 0; i < num_device_to_device_streams; ++i) { - device_to_device_streams_.emplace_back( - create_stream(absl::StrFormat("Device-to-device #%d", i))); + device_to_device_streams_.emplace_back(create_stream()); } fixed_size_pool_usage_streams_.reserve(kNumFixedSizePoolUsageStreams); for (int i = 0; i < kNumFixedSizePoolUsageStreams; ++i) { - fixed_size_pool_usage_streams_.emplace_back( - create_stream(absl::StrFormat("Fixed size pool #%d", i))); + fixed_size_pool_usage_streams_.emplace_back(create_stream()); } external_ready_event_streams_.reserve(kNumExternalReadyEventStreams); for (int i = 0; i < kNumExternalReadyEventStreams; ++i) { - external_ready_event_streams_.emplace_back( - create_stream(absl::StrFormat("External ready event #%d", i))); + external_ready_event_streams_.emplace_back(create_stream()); 
} execute_thread_ = std::make_unique(tsl::Env::Default(), "py_xla_execute"); @@ -154,7 +143,6 @@ absl::Status LocalDeviceState::ThenExecuteCallback( auto callback_stream = callback_stream_map_->find(stream); if (callback_stream == callback_stream_map_->end()) { TF_ASSIGN_OR_RETURN(auto new_stream, executor_->CreateStream()); - new_stream->set_name(absl::StrFormat("Callback for %s", stream->name())); callback_stream = callback_stream_map_->insert({stream, std::move(new_stream)}).first; } @@ -243,7 +231,6 @@ std::unique_ptr LocalDeviceState::BorrowStreamFromPool() { // The stream pool is empty, create a new stream. auto stream = compute_stream_->parent()->CreateStream().value(); - stream->set_name("Pool stream"); return stream; } diff --git a/third_party/xla/xla/service/gpu/infeed_manager.cc b/third_party/xla/xla/service/gpu/infeed_manager.cc index 2ddc95c3d6fba3..cfb49bb8f340c7 100644 --- a/third_party/xla/xla/service/gpu/infeed_manager.cc +++ b/third_party/xla/xla/service/gpu/infeed_manager.cc @@ -46,9 +46,7 @@ constexpr int kMaxInfeedsInFlight = 8; InfeedManager::InfeedManager(se::StreamExecutor* executor) : BlockingXfeedQueue(/*max_pending_xfeeds=*/kMaxInfeedsInFlight), - stream_(executor->CreateStream().value()) { - stream_->set_name("Infeed manager"); -} + stream_(executor->CreateStream().value()) {} static absl::StatusOr CopyBufferToDevice( se::Stream* stream, int64_t size, const void* source) { diff --git a/third_party/xla/xla/service/stream_pool.cc b/third_party/xla/xla/service/stream_pool.cc index e94b5867cd8b51..54f5c773e76138 100644 --- a/third_party/xla/xla/service/stream_pool.cc +++ b/third_party/xla/xla/service/stream_pool.cc @@ -51,8 +51,6 @@ StreamPool::Ptr StreamPool::BorrowStream(se::StreamPriority priority) { if (!stream) { // Create a new stream. 
stream = executor_->CreateStream(priority).value(); - stream->set_name(absl::StrFormat("%s pool stream", - se::StreamPriorityToString(priority))); VLOG(1) << absl::StrFormat("Created new stream (%p) with priority = %s", stream.get(), se::StreamPriorityToString(priority)); diff --git a/third_party/xla/xla/stream_executor/BUILD b/third_party/xla/xla/stream_executor/BUILD index 5e3efefff90f20..bc5bf073667f94 100644 --- a/third_party/xla/xla/stream_executor/BUILD +++ b/third_party/xla/xla/stream_executor/BUILD @@ -707,7 +707,6 @@ cc_library( "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:errors", diff --git a/third_party/xla/xla/stream_executor/gpu/BUILD b/third_party/xla/xla/stream_executor/gpu/BUILD index 18805d5d239d79..9b9237f12189d2 100644 --- a/third_party/xla/xla/stream_executor/gpu/BUILD +++ b/third_party/xla/xla/stream_executor/gpu/BUILD @@ -310,7 +310,6 @@ gpu_only_cc_library( "//xla/stream_executor:stream", "//xla/stream_executor:stream_common", "@com_google_absl//absl/log:check", - "@com_google_absl//absl/strings:string_view", ], ) @@ -332,9 +331,7 @@ gpu_only_cc_library( "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/strings:string_view", "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/profiler/lib:nvtx_utils", ], ) diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_stream.cc b/third_party/xla/xla/stream_executor/gpu/gpu_stream.cc index aafd3ae241ec21..536bb1adefc9c4 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_stream.cc +++ b/third_party/xla/xla/stream_executor/gpu/gpu_stream.cc @@ -31,7 +31,6 @@ limitations under the License. 
#include "xla/stream_executor/platform.h" #include "xla/stream_executor/stream.h" #include "tsl/platform/errors.h" -#include "tsl/profiler/lib/nvtx_utils.h" namespace stream_executor { namespace gpu { @@ -114,12 +113,6 @@ bool GpuStream::IsIdle() const { return GpuDriver::IsStreamIdle(parent_->gpu_context(), gpu_stream_); } -void GpuStream::set_name(absl::string_view name) { - name_ = name; - tsl::profiler::NameStream( - reinterpret_cast(gpu_stream()), name_); -} - GpuStream* AsGpuStream(Stream* stream) { DCHECK(stream != nullptr); return static_cast(stream); diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_stream.h b/third_party/xla/xla/stream_executor/gpu/gpu_stream.h index 4a9f2650cde5d9..d3ac630f214ce1 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_stream.h +++ b/third_party/xla/xla/stream_executor/gpu/gpu_stream.h @@ -22,7 +22,6 @@ limitations under the License. #include #include "absl/log/check.h" -#include "absl/strings/string_view.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/gpu/gpu_executor.h" #include "xla/stream_executor/gpu/gpu_types.h" @@ -101,8 +100,6 @@ class GpuStream : public StreamCommon { absl::Status RecordEvent(Event* event) override; absl::Status MemZero(DeviceMemoryBase* location, uint64_t size) override; - void set_name(absl::string_view name) override; - private: GpuExecutor* parent_; // Executor that spawned this stream. GpuStreamHandle gpu_stream_; // Wrapped CUDA stream handle. 
diff --git a/third_party/xla/xla/stream_executor/stream.h b/third_party/xla/xla/stream_executor/stream.h index f8f03f21cf1983..4fd904befad159 100644 --- a/third_party/xla/xla/stream_executor/stream.h +++ b/third_party/xla/xla/stream_executor/stream.h @@ -258,10 +258,6 @@ class Stream { virtual absl::Status Launch(const ThreadDim &thread_dims, const BlockDim &block_dims, const Kernel &k, const KernelArgs &args) = 0; - - // Get/set a name for a stream, which can be shown in profiling tools - virtual absl::string_view name() const = 0; - virtual void set_name(absl::string_view name) = 0; }; template diff --git a/third_party/xla/xla/stream_executor/stream_common.cc b/third_party/xla/xla/stream_executor/stream_common.cc index f9e7b46e9845f4..cf4c3ee7efd75e 100644 --- a/third_party/xla/xla/stream_executor/stream_common.cc +++ b/third_party/xla/xla/stream_executor/stream_common.cc @@ -29,7 +29,6 @@ limitations under the License. #include "absl/functional/any_invocable.h" #include "absl/status/status.h" #include "absl/strings/str_cat.h" -#include "absl/strings/str_format.h" #include "absl/synchronization/mutex.h" #include "xla/stream_executor/blas.h" #include "xla/stream_executor/device_memory.h" @@ -101,7 +100,6 @@ absl::StatusOr StreamCommon::GetOrCreateSubStream() { // No streams are reusable; create a new stream. 
TF_ASSIGN_OR_RETURN(auto stream, parent_->CreateStream()); Stream *sub_stream = stream.get(); - sub_stream->set_name(absl::StrFormat("Sub-stream of %s", name())); sub_streams_.emplace_back(std::move(stream), false); VLOG(1) << "stream=" << this << " created new sub_stream=" << sub_stream; diff --git a/third_party/xla/xla/stream_executor/stream_common.h b/third_party/xla/xla/stream_executor/stream_common.h index 66f80fbed05b30..39e3e025fb2245 100644 --- a/third_party/xla/xla/stream_executor/stream_common.h +++ b/third_party/xla/xla/stream_executor/stream_common.h @@ -101,10 +101,6 @@ class StreamCommon : public Stream { absl::Status Launch(const ThreadDim &thread_dims, const BlockDim &block_dims, const Kernel &k, const KernelArgs &args) override; - // Doesn't do anything interesting by default; GpuStream connects this to NVTX - absl::string_view name() const override { return name_; } - void set_name(absl::string_view name) override { name_ = name; } - protected: bool InErrorState() const TF_LOCKS_EXCLUDED(mu_) { absl::ReaderMutexLock lock(&mu_); @@ -120,8 +116,6 @@ class StreamCommon : public Stream { void SetError() { CheckError(false /* = operation_retcode */); } - std::string name_; - private: // The StreamExecutor that supports the operation of this stream. 
StreamExecutor *parent_; diff --git a/third_party/xla/xla/stream_executor/stream_executor_memory_allocator.cc b/third_party/xla/xla/stream_executor/stream_executor_memory_allocator.cc index c17656c6d34147..858d6a3a5c1ebd 100644 --- a/third_party/xla/xla/stream_executor/stream_executor_memory_allocator.cc +++ b/third_party/xla/xla/stream_executor/stream_executor_memory_allocator.cc @@ -105,7 +105,6 @@ absl::StatusOr StreamExecutorMemoryAllocator::GetStream( if (!streams_.count(device_ordinal)) { TF_ASSIGN_OR_RETURN(auto stream, executor->CreateStream()); auto stream_ptr = stream.get(); - stream_ptr->set_name("StreamExecutorMemoryAllocator"); streams_.emplace(device_ordinal, std::move(stream)); return stream_ptr; } diff --git a/third_party/xla/xla/stream_executor/trace_command_buffer_factory.cc b/third_party/xla/xla/stream_executor/trace_command_buffer_factory.cc index 0f18b68a08e343..5efd0cc4b2d337 100644 --- a/third_party/xla/xla/stream_executor/trace_command_buffer_factory.cc +++ b/third_party/xla/xla/stream_executor/trace_command_buffer_factory.cc @@ -34,7 +34,6 @@ TraceCommandBufferFactory::Create( absl::AnyInvocable function, CommandBuffer::Mode mode) { TF_ASSIGN_OR_RETURN(auto stream, executor->CreateStream()); - stream->set_name("Command buffer tracer"); return TraceCommandBufferFactory::Create(executor, stream.get(), std::move(function), mode); } From d7609726e6399873ec61324b7681e5ed53b8dd88 Mon Sep 17 00:00:00 2001 From: Adrian Kuegel Date: Thu, 20 Jun 2024 03:16:26 -0700 Subject: [PATCH 043/256] Delete flags xla_gpu_max_mlir_kernels and xla_gpu_skip_mlir_kernels These were only needed for debugging. 
PiperOrigin-RevId: 644961292 --- third_party/xla/xla/debug_options_flags.cc | 12 ---------- .../xla/xla/service/gpu/fusions/fusions.cc | 22 ++----------------- .../xla/xla/service/gpu/fusions/fusions.h | 2 +- .../xla/service/gpu/ir_emitter_unnested.cc | 6 ++--- third_party/xla/xla/xla.proto | 4 ++-- 5 files changed, 7 insertions(+), 39 deletions(-) diff --git a/third_party/xla/xla/debug_options_flags.cc b/third_party/xla/xla/debug_options_flags.cc index 979e90d3933d60..b223d7c0c320e7 100644 --- a/third_party/xla/xla/debug_options_flags.cc +++ b/third_party/xla/xla/debug_options_flags.cc @@ -239,8 +239,6 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_gpu_nccl_p2p_max_nchannels(0); opts.set_xla_gpu_mlir_emitter_level(0); - opts.set_xla_gpu_max_mlir_kernels(0); - opts.set_xla_gpu_skip_mlir_kernels(0); opts.set_xla_gpu_multi_streamed_windowed_einsum(false); @@ -1698,16 +1696,6 @@ void MakeDebugOptionsFlags(std::vector* flag_list, debug_options->xla_gpu_mlir_emitter_level(), "Enable new MLIR-based emitters. 
Level 0 means disabled, " "higher levels enable more of the emitters.")); - flag_list->push_back( - tsl::Flag("xla_gpu_max_mlir_kernels", - int64_setter_for(&DebugOptions::set_xla_gpu_max_mlir_kernels), - debug_options->xla_gpu_max_mlir_kernels(), - "Maximum number of kernels to emit with MLIR.")); - flag_list->push_back( - tsl::Flag("xla_gpu_skip_mlir_kernels", - int64_setter_for(&DebugOptions::set_xla_gpu_skip_mlir_kernels), - debug_options->xla_gpu_skip_mlir_kernels(), - "Number of initial kernels to skip MLIR emission for.")); flag_list->push_back(tsl::Flag( "xla_gpu_multi_streamed_windowed_einsum", bool_setter_for( diff --git a/third_party/xla/xla/service/gpu/fusions/fusions.cc b/third_party/xla/xla/service/gpu/fusions/fusions.cc index 700ff2d70f4044..6a241da5af54fe 100644 --- a/third_party/xla/xla/service/gpu/fusions/fusions.cc +++ b/third_party/xla/xla/service/gpu/fusions/fusions.cc @@ -99,8 +99,8 @@ bool HloFusionInfo::CanEmitDynamicUpdateSliceInPlace() const { return ret.ok() && *ret; } -std::unique_ptr GetFusionEmitter(const FusionInfo& fusion_info, - bool is_emission_phase) { +std::unique_ptr GetFusionEmitter( + const FusionInfo& fusion_info) { const auto& analysis = fusion_info.analysis(); const FusionBackendConfig& backend_config = analysis.fusion_backend_config(); @@ -120,24 +120,6 @@ std::unique_ptr GetFusionEmitter(const FusionInfo& fusion_info, << "Unsupported fusion: " << analysis.fusion_root(0).instruction().parent()->ToString(); - static int num_mlir_emitters = 0; - if (is_emission_phase) { - // This kernel can be emitted with MLIR, but we need to check if there are - // limits to how many kernels can be emitted. 
- ++num_mlir_emitters; - if (num_mlir_emitters <= opts.xla_gpu_skip_mlir_kernels()) { - VLOG(5) - << "Skipping MLIR emission because initial skips were requested."; - return false; - } - - int n_emitted = num_mlir_emitters - opts.xla_gpu_skip_mlir_kernels(); - if (opts.xla_gpu_max_mlir_kernels() > 0 && - n_emitted > opts.xla_gpu_max_mlir_kernels()) { - VLOG(5) << "Skipping MLIR emission because max_mlir_emitters was set."; - return false; - } - } VLOG(5) << "Emitting with MLIR."; return true; }; diff --git a/third_party/xla/xla/service/gpu/fusions/fusions.h b/third_party/xla/xla/service/gpu/fusions/fusions.h index 96d7fa547a2aa0..9011c80d7f9f43 100644 --- a/third_party/xla/xla/service/gpu/fusions/fusions.h +++ b/third_party/xla/xla/service/gpu/fusions/fusions.h @@ -89,7 +89,7 @@ class PreBufferAssignmentFusionInfo : public FusionInfo { // Returns the emitter for the given fusion. std::unique_ptr GetFusionEmitter( - const FusionInfo& fusion_info, bool is_emission_phase = false); + const FusionInfo& fusion_info); } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc b/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc index 77ad59390f5873..40c985752ab2ea 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc +++ b/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc @@ -1691,10 +1691,8 @@ absl::Status IrEmitterUnnested::EmitFusion(const HloFusionInstruction* instr) { const HloFusionAnalysis fusion_analysis = HloFusionAnalysis::Create(instr, &device_info); - std::unique_ptr emitter = - GetFusionEmitter(HloFusionInfo(fusion_analysis, instr, - &ir_emitter_context_->buffer_assignment()), - /*is_emission_phase=*/true); + std::unique_ptr emitter = GetFusionEmitter(HloFusionInfo( + fusion_analysis, instr, &ir_emitter_context_->buffer_assignment())); TF_ASSIGN_OR_RETURN(auto result, emitter->Emit(*ir_emitter_context_, *instr)); const ExecutionStreamAssignment& stream_assignment = diff --git 
a/third_party/xla/xla/xla.proto b/third_party/xla/xla/xla.proto index 16b1da1999d7c3..8e92692a535323 100644 --- a/third_party/xla/xla/xla.proto +++ b/third_party/xla/xla/xla.proto @@ -746,10 +746,10 @@ message DebugOptions { // 4: + Reduce int64 xla_gpu_mlir_emitter_level = 303; // The maximum number of kernels to emit with MLIR. Unlimited if 0. - int64 xla_gpu_max_mlir_kernels = 281; + reserved 281; // was xla_gpu_max_mlir_kernels // The number of initial kernels to not emit with MLIR. Only supported kernels // are counted. - int64 xla_gpu_skip_mlir_kernels = 282; + reserved 282; // was xla_gpu_skip_mlir_kernels // Threshold to rewrite matmul to cuBLAS or Triton (minumum combined number of // elements of both matrices in non-batch dimensions to be considered for a From a69eff2ec792e0da2600b6c7f14da4d644f8c3f4 Mon Sep 17 00:00:00 2001 From: Ilia Sergachev Date: Thu, 20 Jun 2024 03:56:44 -0700 Subject: [PATCH 044/256] PR #13831: [GPU] Improve dumping of GEMM fusions. Imported from GitHub PR https://github.com/openxla/xla/pull/13831 - also dump non-optimized fusions to make easier running them in isolation - change the dump file name prefix - cleanup the code - test dump file contents Copybara import of the project: -- e52a1d3530ffaa557d0e26456bd2d0921caf0548 by Ilia Sergachev : [GPU] Improve dumping of GEMM fusions. 
- also dump non-optimized fusions to make easier running them in isolation - change the dump file name prefix - cleanup the code - test dump file contents Merging this change closes #13831 PiperOrigin-RevId: 644970582 --- third_party/xla/xla/service/gpu/BUILD | 2 + .../xla/service/gpu/gemm_fusion_autotuner.cc | 71 +++++++++++------- .../service/gpu/gemm_fusion_autotuner_test.cc | 75 ++++++++++++++----- 3 files changed, 104 insertions(+), 44 deletions(-) diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index 7beb1794d8ba19..509b11d8e609de 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -873,6 +873,7 @@ xla_test( "//xla/hlo/ir:hlo", "//xla/pjrt/distributed:key_value_store_interface", "//xla/service:call_inliner", + "//xla/service:dump", "//xla/service:executable", "//xla/service:hlo_module_config", "//xla/service:hlo_pass_pipeline", @@ -893,6 +894,7 @@ xla_test( "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:path", "@local_tsl//tsl/platform:platform_port", "@local_tsl//tsl/platform:status_matchers", "@local_tsl//tsl/platform:statusor", diff --git a/third_party/xla/xla/service/gpu/gemm_fusion_autotuner.cc b/third_party/xla/xla/service/gpu/gemm_fusion_autotuner.cc index b11d4e4f2ff5a1..25c1b5b3894c8f 100644 --- a/third_party/xla/xla/service/gpu/gemm_fusion_autotuner.cc +++ b/third_party/xla/xla/service/gpu/gemm_fusion_autotuner.cc @@ -421,12 +421,18 @@ absl::StatusOr> CublasGemmAutotuneExtractor( return new_module; } -absl::StatusOr> CudnnGemmAutotuneExtractor( - const AutotuneConfig& autotune_config, const HloFusionInstruction* fusion, - const DebugOptions& debug_opts, const int plan_id) { - std::unique_ptr new_module = - ExtractInstructionIntoNewModule(*fusion); - new_module->mutable_config().set_debug_options(debug_opts); +absl::StatusOr> FusionExtractor( + const HloFusionInstruction& 
fusion, const DebugOptions& debug_opts) { + std::unique_ptr module = ExtractInstructionIntoNewModule(fusion); + module->mutable_config().set_debug_options(debug_opts); + return module; +} + +absl::StatusOr> CuDnnFusionExtractor( + const HloFusionInstruction& fusion, const DebugOptions& debug_opts, + const int plan_id) { + TF_ASSIGN_OR_RETURN(std::unique_ptr module, + FusionExtractor(fusion, debug_opts)); GpuBackendConfig gpu_config; FusionBackendConfig& backend_config = @@ -435,10 +441,9 @@ absl::StatusOr> CudnnGemmAutotuneExtractor( // Provided a plan ID the autotuner just compiles one plan. backend_config.mutable_cudnn_fusion_config()->set_plan_id(plan_id); TF_RETURN_IF_ERROR( - new_module->entry_computation()->root_instruction()->set_backend_config( + module->entry_computation()->root_instruction()->set_backend_config( gpu_config)); - - return new_module; + return module; } bool IsFusionKind(const HloInstruction& hlo, absl::string_view kind) { @@ -476,6 +481,26 @@ AutotuneResult FromConfig(const Config& config) { return res; } +absl::Status DumpOriginalFusion(AutotunerCompileUtil& util, + const HloFusionInstruction& fusion, + int fusion_id) { + TF_ASSIGN_OR_RETURN(std::unique_ptr module, + util.ExtractModule([&](const DebugOptions& debug_opts) { + return FusionExtractor(fusion, debug_opts); + })); + module->set_name(std::string(fusion.name())); + // Using the original module for its debug info and name in the first + // parameter. It's better to include the name of both the original module + // and the extracted module, to avoid name clashes. 
+ DumpToFileInDirOrStdout( + /*module=*/*fusion.GetModule(), + /*file_prefix=*/"", + /*file_suffix=*/ + absl::StrCat("gemm_fusion_", fusion_id, ".", module->name(), ".txt"), + /*contents=*/module->ToString()); + return absl::OkStatus(); +} + absl::Status DumpAutotunedFusion(const AutotuneConfig& autotune_config, const int32_t toolkit_version, AutotunerCompileUtil& util, @@ -483,12 +508,7 @@ absl::Status DumpAutotunedFusion(const AutotuneConfig& autotune_config, const HloFusionInstruction* fusion, int fusion_id) { TritonGemmConfig triton_gemm_config; - if (!result.has_triton()) { - LOG(WARNING) << "Using empty triton GEMM config for op " << fusion->name(); - // Empty TritonGemmConfig has all zero values which is good enough to keep - // fused computation in the dump but illustrate that Triton is not used for - // it after autotuning. - } else { + if (result.has_triton()) { TF_ASSIGN_OR_RETURN(triton_gemm_config, TritonGemmConfig::FromProto(result.triton())); } @@ -498,8 +518,8 @@ absl::Status DumpAutotunedFusion(const AutotuneConfig& autotune_config, std::unique_ptr module, util.ExtractModule([&](const DebugOptions& debug_opts) { if (result.has_algorithm()) { - return CudnnGemmAutotuneExtractor(autotune_config, fusion, debug_opts, - result.algorithm().algo_id()); + return CuDnnFusionExtractor(*fusion, debug_opts, + result.algorithm().algo_id()); } else if (result.has_triton()) { return TritonGemmAutotuneExtractor( triton_gemm_config, device_desc, fusion, debug_opts, @@ -519,7 +539,7 @@ absl::Status DumpAutotunedFusion(const AutotuneConfig& autotune_config, /*module=*/*fusion->GetModule(), /*file_prefix=*/"", /*file_suffix=*/ - absl::StrCat("triton_fusion_", fusion_id, ".", module->name(), + absl::StrCat("gemm_fusion_", fusion_id, ".", module->name(), ".optimized.txt"), /*contents=*/module->ToString()); return absl::OkStatus(); @@ -745,13 +765,13 @@ GemmFusionAutotunerImpl::CompileAll(AutotunerCompileUtil& compile_util, allow_filtering_kernels_spilling_registers); 
})); } else if (std::holds_alternative(config)) { - executable = compile_util - .Compile([&](const DebugOptions& opts) { - return CudnnGemmAutotuneExtractor( - config_, fusion, opts, - std::get(config).plan_id); - }) - .value_or(nullptr); + executable = + compile_util + .Compile([&](const DebugOptions& opts) { + return CuDnnFusionExtractor( + *fusion, opts, std::get(config).plan_id); + }) + .value_or(nullptr); } else if (std::holds_alternative(config)) { TF_ASSIGN_OR_RETURN(executable, compile_util.Compile([&](const DebugOptions& opts) { @@ -1094,6 +1114,7 @@ absl::Status GemmFusionAutotunerImpl::Autotune( << tsl::proto_utils::FromDurationProto(best.run_time()); if (debug_options_.xla_gpu_dump_autotuned_gemm_fusions()) { + TF_RETURN_IF_ERROR(DumpOriginalFusion(compile_util, *fusion, fusion_id)); TF_RETURN_IF_ERROR(DumpAutotunedFusion( config_, toolkit_version_, compile_util, best, fusion, fusion_id++)); } diff --git a/third_party/xla/xla/service/gpu/gemm_fusion_autotuner_test.cc b/third_party/xla/xla/service/gpu/gemm_fusion_autotuner_test.cc index c5195a1506abad..19b0d3877f3490 100644 --- a/third_party/xla/xla/service/gpu/gemm_fusion_autotuner_test.cc +++ b/third_party/xla/xla/service/gpu/gemm_fusion_autotuner_test.cc @@ -22,7 +22,6 @@ limitations under the License. #include #include -#include #include #include "absl/log/check.h" #include "absl/log/log.h" @@ -38,6 +37,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_opcode.h" #include "xla/pjrt/distributed/key_value_store_interface.h" #include "xla/service/call_inliner.h" +#include "xla/service/dump.h" #include "xla/service/executable.h" #include "xla/service/gpu/autotuner_util.h" #include "xla/service/gpu/backend_configs.pb.h" @@ -61,6 +61,7 @@ limitations under the License. 
#include "tsl/platform/cpu_info.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" +#include "tsl/platform/path.h" #include "tsl/platform/status_matchers.h" #include "tsl/platform/statusor.h" #include "tsl/platform/threadpool.h" @@ -558,16 +559,7 @@ ENTRY %e { EXPECT_NE(executable, nullptr); } -class GemmFusionAutotunerDumpTest : public GemmFusionAutotunerTest { - public: - DebugOptions GetDebugOptionsForTest() override { - DebugOptions debug_options = - GemmFusionAutotunerTest::GetDebugOptionsForTest(); - debug_options.set_xla_gpu_cublas_fallback(true); - debug_options.set_xla_gpu_dump_autotuned_gemm_fusions(true); - return debug_options; - } -}; +using GemmFusionAutotunerDumpTest = GemmFusionAutotunerTest; TEST_F(GemmFusionAutotunerDumpTest, Fp8CublasltFallbackSupport) { const std::string kHloText = R"( @@ -639,24 +631,69 @@ ENTRY main { EXPECT_TRUE(filecheck_matches); } -TEST_F(GemmFusionAutotunerDumpTest, DumpingFusionsWorksWithFallback) { +TEST_F(GemmFusionAutotunerDumpTest, DumpingWorks) { + HloModuleConfig config; + DebugOptions options = GetDebugOptionsForTest(); + options.set_xla_gpu_cublas_fallback(true); + options.set_xla_gpu_dump_autotuned_gemm_fusions(true); + std::string output_directory; + if (!tsl::io::GetTestUndeclaredOutputsDir(&output_directory)) { + output_directory = tsl::testing::TmpDir(); + } + options.set_xla_dump_to(output_directory); + config.set_debug_options(options); // Computation is chosen such that relatively heavy math operations before the // GEMM are not worth fusing because they would get duplicated many times and // slow down execution. Therefore autotuning picks cuBLAS here. 
- const std::string kHloText = R"( -ENTRY e { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( +fusion1 { p0 = f32[3333,3333] parameter(0) s = f32[3333,3333] sine(p0) p1 = f32[3333,3333] parameter(1) c = f32[3333,3333] cosine(p1) ROOT dot = f32[3333,3333] dot(s, c), lhs_contracting_dims={1}, rhs_contracting_dims={0} -})"; +} - MatchOptimizedHlo(kHloText, R"( -; CHECK: cublas -; CHECK-NOT: triton -)"); +ENTRY e { + p0 = f32[3333,3333] parameter(0) + p1 = f32[3333,3333] parameter(1) + ROOT rr = f32[3333,3333] fusion(p0, p1), kind=kCustom, calls=fusion1, + backend_config={"fusion_backend_config": {kind: "__triton_gemm"}} +})", + config)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr optimized_module, + GetOptimizedModule(std::move(module))); + + std::string dump; + TF_EXPECT_OK(tsl::ReadFileToString( + tsl::Env::Default(), + tsl::io::JoinPath(output_directory, + FilenameFor(*optimized_module, /*prefix=*/"", + /*suffix=*/"gemm_fusion_0.rr.txt")), + &dump)); + EXPECT_TRUE(*RunFileCheck(dump, R"( +CHECK: HloModule rr +CHECK-NOT: cublas +CHECK: __triton_gemm +CHECK-NOT: block_m +)")); + + dump.clear(); + + TF_EXPECT_OK(tsl::ReadFileToString( + tsl::Env::Default(), + tsl::io::JoinPath( + output_directory, + FilenameFor(*optimized_module, /*prefix=*/"", + /*suffix=*/"gemm_fusion_0.rr.optimized.txt")), + &dump)); + EXPECT_TRUE(*RunFileCheck(dump, R"( +CHECK: HloModule rr +CHECK-NOT: triton +CHECK: cublas +)")); } TEST_F(GemmFusionAutotunerTest, AutotuneCuDnnFusion) { From c72085d4733c4305fb9ce0a2fb086f6c80f01700 Mon Sep 17 00:00:00 2001 From: Ruturaj Vaidya Date: Thu, 20 Jun 2024 03:56:45 -0700 Subject: [PATCH 045/256] PR #13340: [ROCm] Add Swizzle instruction support for mi100+ in reduction fusion shuffle-down function Imported from GitHub PR https://github.com/openxla/xla/pull/13340 Copybara import of the project: -- 0ec69b93a456f91c5c00f6b266958732563af9ce by Ruturaj4 : [ROCm] Add Swizzle instruction support for mi100+ -- 
5a3a1fa9e0ed1ec7f1ae4d2ab925e53b44b3c2cb by Ruturaj4 : [ROCm] Fix tests incorporating swizzle changes Merging this change closes #13340 PiperOrigin-RevId: 644970595 --- .../xla/xla/service/gpu/fusions/reduction.cc | 5 +- .../xla/xla/service/gpu/ir_emission_utils.cc | 46 ++++++++++- .../xla/xla/service/gpu/ir_emission_utils.h | 5 +- .../service/gpu/tests/reduce_atomic_min.hlo | 20 ++--- .../gpu/tests/reduce_column_layout_change.hlo | 10 +-- .../service/gpu/tests/reduce_f64_column.hlo | 20 ++--- .../gpu/tests/reduce_large_row_to_scalar.hlo | 80 +++++++++---------- .../gpu/tests/reduce_row_vectorized.hlo | 20 ++--- .../gpu/tests/reduce_variadic_column.hlo | 20 ++--- .../gpu/tests/reduction_vectorization_test.cc | 2 +- 10 files changed, 134 insertions(+), 94 deletions(-) diff --git a/third_party/xla/xla/service/gpu/fusions/reduction.cc b/third_party/xla/xla/service/gpu/fusions/reduction.cc index e4186401e9a61a..8c72ff84668396 100644 --- a/third_party/xla/xla/service/gpu/fusions/reduction.cc +++ b/third_party/xla/xla/service/gpu/fusions/reduction.cc @@ -556,8 +556,9 @@ void ReductionGroupEmitter::EmitFullWarpShuffleDownLoopForReduce( builder->CreateLoad(shuffled_value_type, partial_result_address, "partial_reduction_result"); builder->CreateStore( - EmitFullWarpShuffleDown(partial_result, builder->getInt32(distance), - builder), + EmitFullWarpShuffleDown( + partial_result, builder->getInt32(distance), builder, + reduction_emitter_.ir_emitter_context_.gpu_device_info()), result_from_other_lane); } diff --git a/third_party/xla/xla/service/gpu/ir_emission_utils.cc b/third_party/xla/xla/service/gpu/ir_emission_utils.cc index 069ae1cf75e6f6..5b73e28aa8f0c6 100644 --- a/third_party/xla/xla/service/gpu/ir_emission_utils.cc +++ b/third_party/xla/xla/service/gpu/ir_emission_utils.cc @@ -226,6 +226,35 @@ llvm::Value* EmitAMDGPUShflDown(llvm::Value* value, llvm::Value* offset, return b->CreateBitCast(result, value->getType()); } +llvm::Value* EmitAMDGPUShflDownSwizzle(llvm::Value* 
value, llvm::Value* offset, + llvm::IRBuilder<>* b) { + llvm::Module* module = b->GetInsertBlock()->getModule(); + CHECK_EQ(value->getType()->getPrimitiveSizeInBits(), 32); + auto* i32_ty = b->getInt32Ty(); + + llvm::Function* intrinsic = llvm::cast( + module + ->getOrInsertFunction( + "llvm.amdgcn.ds.swizzle", + llvm::FunctionType::get(/*Result=*/i32_ty, {i32_ty, i32_ty}, + /*isVarArg=*/false)) + .getCallee()); + + // Ensure that the first argument to the AMDGPU intrinsic is i32. + llvm::Value* bitcast_value = b->CreateBitCast(value, i32_ty); + + // Calculate the control value for the swizzle operation. + llvm::Value* control_value = + b->CreateAdd(b->CreateMul(offset, b->getInt32(0x20)), b->getInt32(0x1f)); + + // Create the call to the intrinsic function. + llvm::Value* result = + b->CreateCall(intrinsic, {bitcast_value, control_value}); + + // Bitcast the result back to the original type of the input value. + return b->CreateBitCast(result, value->getType()); +} + // Helper function to emit call to NVPTX shfl_down intrinsic. 
llvm::Value* EmitNVPTXShflDown(llvm::Value* value, llvm::Value* offset, llvm::IRBuilder<>* b) { @@ -266,8 +295,9 @@ llvm::Value* EmitSPIRShflDown(llvm::Value* value, llvm::Value* offset, } } -llvm::Value* EmitFullWarpShuffleDown(llvm::Value* value, llvm::Value* offset, - llvm::IRBuilder<>* builder) { +llvm::Value* EmitFullWarpShuffleDown( + llvm::Value* value, llvm::Value* offset, llvm::IRBuilder<>* builder, + const se::DeviceDescription& gpu_device_info) { int bit_width = value->getType()->getPrimitiveSizeInBits(); llvm::Module* module = builder->GetInsertBlock()->getModule(); llvm::Triple target_triple = llvm::Triple(module->getTargetTriple()); @@ -277,6 +307,9 @@ llvm::Value* EmitFullWarpShuffleDown(llvm::Value* value, llvm::Value* offset, if (target_triple.isNVPTX()) { return EmitNVPTXShflDown(value, offset, builder); } else if (target_triple.getArch() == llvm::Triple::amdgcn) { + if (gpu_device_info.rocm_compute_capability().gfx9_mi100_or_later()) { + return EmitAMDGPUShflDownSwizzle(value, offset, builder); + } return EmitAMDGPUShflDown(value, offset, builder); } else if (target_triple.isSPIR()) { return EmitSPIRShflDown(value, offset, builder); @@ -299,8 +332,13 @@ llvm::Value* EmitFullWarpShuffleDown(llvm::Value* value, llvm::Value* offset, insert_val = EmitNVPTXShflDown(builder->CreateExtractElement(x, i), offset, builder); } else if (target_triple.getArch() == llvm::Triple::amdgcn) { - insert_val = EmitAMDGPUShflDown(builder->CreateExtractElement(x, i), - offset, builder); + if (gpu_device_info.rocm_compute_capability().gfx9_mi100_or_later()) { + insert_val = EmitAMDGPUShflDownSwizzle( + builder->CreateExtractElement(x, i), offset, builder); + } else { + insert_val = EmitAMDGPUShflDown(builder->CreateExtractElement(x, i), + offset, builder); + } } else if (target_triple.isSPIR()) { insert_val = EmitSPIRShflDown(builder->CreateExtractElement(x, i), offset, builder); diff --git a/third_party/xla/xla/service/gpu/ir_emission_utils.h 
b/third_party/xla/xla/service/gpu/ir_emission_utils.h index 73e640eda511d4..b9ad5e304f7c1f 100644 --- a/third_party/xla/xla/service/gpu/ir_emission_utils.h +++ b/third_party/xla/xla/service/gpu/ir_emission_utils.h @@ -114,8 +114,9 @@ bool IsContiguousSlice(const Shape& orig, const Shape& sliced); // can't correctly do so on both Volta and earlier GPUs. // // https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-shfl-sync -llvm::Value* EmitFullWarpShuffleDown(llvm::Value* value, llvm::Value* offset, - llvm::IRBuilder<>* builder); +llvm::Value* EmitFullWarpShuffleDown( + llvm::Value* value, llvm::Value* offset, llvm::IRBuilder<>* builder, + const se::DeviceDescription& gpu_device_info); // Emits code that determines whether the current thread is thread 0 within // block 0 of the kernel. diff --git a/third_party/xla/xla/service/gpu/tests/reduce_atomic_min.hlo b/third_party/xla/xla/service/gpu/tests/reduce_atomic_min.hlo index 46e1240f947f8f..c30c6591454c81 100644 --- a/third_party/xla/xla/service/gpu/tests/reduce_atomic_min.hlo +++ b/third_party/xla/xla/service/gpu/tests/reduce_atomic_min.hlo @@ -112,7 +112,7 @@ ENTRY reduce.1 { // CHECK: %[[VAL_86:.*]] = load float, ptr{{.*}} %partial_reduction_result, align 4 // CHECK-PTX: %[[VAL_87:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_86]], i32 16, i32 31) // CHECK-GCN: %[[VAL_87_1:.*]] = bitcast float %[[VAL_86]] to i32 -// CHECK-GCN: %[[VAL_87_2:.*]] = call i32 @__ockl_readuplane_i32 +// CHECK-GCN: %[[VAL_87_2:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_87_1]], i32 543) // CHECK-GCN: %[[VAL_87:.*]] = bitcast i32 %[[VAL_87_2]] to float // CHECK: store float %[[VAL_87]], ptr{{.*}} %[[VAL_37]], align 4 // CHECK-GCN: %[[VAL_88_1:.*]] = addrspacecast ptr{{.*}} %partial_reduction_result to ptr @@ -125,7 +125,7 @@ ENTRY reduce.1 { // CHECK: %[[VAL_89:.*]] = load float, ptr{{.*}} %partial_reduction_result, align 4 // CHECK-PTX: %[[VAL_90:.*]] = call float 
@llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_89]], i32 8, i32 31) // CHECK-GCN: %[[VAL_90_1:.*]] = bitcast float %[[VAL_89]] to i32 -// CHECK-GCN: %[[VAL_90_2:.*]] = call i32 @__ockl_readuplane_i32 +// CHECK-GCN: %[[VAL_90_2:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_90_1]], i32 287) // CHECK-GCN: %[[VAL_90:.*]] = bitcast i32 %[[VAL_90_2]] to float // CHECK: store float %[[VAL_90]], ptr{{.*}} %[[VAL_35]], align 4 // CHECK-GCN: %[[VAL_91_1:.*]] = addrspacecast ptr{{.*}} %partial_reduction_result to ptr @@ -138,7 +138,7 @@ ENTRY reduce.1 { // CHECK: %[[VAL_92:.*]] = load float, ptr{{.*}} %partial_reduction_result, align 4 // CHECK-PTX: %[[VAL_93:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_92]], i32 4, i32 31) // CHECK-GCN: %[[VAL_93_1:.*]] = bitcast float %[[VAL_92]] to i32 -// CHECK-GCN: %[[VAL_93_2:.*]] = call i32 @__ockl_readuplane_i32 +// CHECK-GCN: %[[VAL_93_2:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_93_1]], i32 159) // CHECK-GCN: %[[VAL_93:.*]] = bitcast i32 %[[VAL_93_2]] to float // CHECK: store float %[[VAL_93]], ptr{{.*}} %[[VAL_33]], align 4 // CHECK-GCN: %[[VAL_94_1:.*]] = addrspacecast ptr{{.*}} %partial_reduction_result to ptr @@ -151,7 +151,7 @@ ENTRY reduce.1 { // CHECK: %[[VAL_95:.*]] = load float, ptr{{.*}} %partial_reduction_result, align 4 // CHECK-PTX: %[[VAL_96:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_95]], i32 2, i32 31) // CHECK-GCN: %[[VAL_96_1:.*]] = bitcast float %[[VAL_95]] to i32 -// CHECK-GCN: %[[VAL_96_2:.*]] = call i32 @__ockl_readuplane_i32 +// CHECK-GCN: %[[VAL_96_2:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_96_1]], i32 95) // CHECK-GCN: %[[VAL_96:.*]] = bitcast i32 %[[VAL_96_2]] to float // CHECK: store float %[[VAL_96]], ptr{{.*}} %[[VAL_31]], align 4 // CHECK-GCN: %[[VAL_97_1:.*]] = addrspacecast ptr{{.*}} %partial_reduction_result to ptr @@ -164,7 +164,7 @@ ENTRY reduce.1 { // CHECK: %[[VAL_98:.*]] = load float, ptr{{.*}} 
%partial_reduction_result, align 4 // CHECK-PTX: %[[VAL_99:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_98]], i32 1, i32 31) // CHECK-GCN: %[[VAL_99_1:.*]] = bitcast float %[[VAL_98]] to i32 -// CHECK-GCN: %[[VAL_99_2:.*]] = call i32 @__ockl_readuplane_i32 +// CHECK-GCN: %[[VAL_99_2:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_99_1]], i32 63) // CHECK-GCN: %[[VAL_99:.*]] = bitcast i32 %[[VAL_99_2]] to float // CHECK: store float %[[VAL_99]], ptr{{.*}} %[[VAL_29]], align 4 // CHECK-GCN: %[[VAL_100_1:.*]] = addrspacecast ptr{{.*}} %partial_reduction_result to ptr @@ -336,7 +336,7 @@ ENTRY reduce.1 { // CHECK-PTX: %[[VAL_178:.*]] = select i1 %[[VAL_177]], ptr %[[VAL_176]], ptr %[[VAL_27]] // CHECK: %[[VAL_179:.*]] = load float, ptr %[[VAL_178]], align 4 // CHECK-GCN: %[[VAL_179_1:.*]] = bitcast float %[[VAL_179]] to i32 -// CHECK-GCN: %[[VAL_180:.*]] = call i32 @__ockl_readuplane_i32(i32 %[[VAL_179_1]], i32 16) +// CHECK-GCN: %[[VAL_180:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_179_1]], i32 543) // CHECK-GCN: %[[VAL_180_1:.*]] = bitcast i32 %[[VAL_180]] to float // CHECK-GCN: store float %[[VAL_180_1]], ptr{{.*}} %[[VAL_26]], align 4 // CHECK-PTX: %[[VAL_180:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_179]], i32 16, i32 31) @@ -349,7 +349,7 @@ ENTRY reduce.1 { // CHECK: store float %[[VAL_181]], ptr %[[VAL_178]], align 4 // CHECK: %[[VAL_182:.*]] = load float, ptr %[[VAL_178]], align 4 // CHECK-GCN: %[[VAL_182_1:.*]] = bitcast float %[[VAL_182]] to i32 -// CHECK-GCN: %[[VAL_183:.*]] = call i32 @__ockl_readuplane_i32(i32 %[[VAL_182_1]], i32 8) +// CHECK-GCN: %[[VAL_183:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_182_1]], i32 287) // CHECK-GCN: %[[VAL_183_1:.*]] = bitcast i32 %[[VAL_183]] to float // CHECK-GCN: store float %[[VAL_183_1]], ptr{{.*}} %[[VAL_24]], align 4 // CHECK-PTX: %[[VAL_183:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_182]], i32 8, i32 31) @@ -362,7 +362,7 
@@ ENTRY reduce.1 { // CHECK: store float %[[VAL_184]], ptr %[[VAL_178]], align 4 // CHECK: %[[VAL_185:.*]] = load float, ptr %[[VAL_178]], align 4 // CHECK-GCN: %[[VAL_185_1:.*]] = bitcast float %[[VAL_185]] to i32 -// CHECK-GCN: %[[VAL_186:.*]] = call i32 @__ockl_readuplane_i32(i32 %[[VAL_185_1]], i32 4) +// CHECK-GCN: %[[VAL_186:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_185_1]], i32 159) // CHECK-GCN: %[[VAL_186_1:.*]] = bitcast i32 %[[VAL_186]] to float // CHECK-GCN: store float %[[VAL_186_1]], ptr{{.*}} %[[VAL_22]], align 4 // CHECK-PTX: %[[VAL_186:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_185]], i32 4, i32 31) @@ -375,7 +375,7 @@ ENTRY reduce.1 { // CHECK: store float %[[VAL_187]], ptr %[[VAL_178]], align 4 // CHECK: %[[VAL_188:.*]] = load float, ptr %[[VAL_178]], align 4 // CHECK-GCN: %[[VAL_188_1:.*]] = bitcast float %[[VAL_188]] to i32 -// CHECK-GCN: %[[VAL_189:.*]] = call i32 @__ockl_readuplane_i32(i32 %[[VAL_188_1]], i32 2) +// CHECK-GCN: %[[VAL_189:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_188_1]], i32 95) // CHECK-GCN: %[[VAL_189_1:.*]] = bitcast i32 %[[VAL_189]] to float // CHECK-GCN: store float %[[VAL_189_1]], ptr{{.*}} %[[VAL_20]], align 4 // CHECK-PTX: %[[VAL_189:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_188]], i32 2, i32 31) @@ -388,7 +388,7 @@ ENTRY reduce.1 { // CHECK: store float %[[VAL_190]], ptr %[[VAL_178]], align 4 // CHECK: %[[VAL_191:.*]] = load float, ptr %[[VAL_178]], align 4 // CHECK-GCN: %[[VAL_191_1:.*]] = bitcast float %[[VAL_191]] to i32 -// CHECK-GCN: %[[VAL_192:.*]] = call i32 @__ockl_readuplane_i32(i32 %[[VAL_191_1]], i32 1) +// CHECK-GCN: %[[VAL_192:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_191_1]], i32 63) // CHECK-GCN: %[[VAL_192_1:.*]] = bitcast i32 %[[VAL_192]] to float // CHECK-GCN: store float %[[VAL_192_1]], ptr{{.*}} %[[VAL_18]], align 4 // CHECK-PTX: %[[VAL_192:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_191]], 
i32 1, i32 31) diff --git a/third_party/xla/xla/service/gpu/tests/reduce_column_layout_change.hlo b/third_party/xla/xla/service/gpu/tests/reduce_column_layout_change.hlo index 122929f3df280a..e738981f1e2bca 100644 --- a/third_party/xla/xla/service/gpu/tests/reduce_column_layout_change.hlo +++ b/third_party/xla/xla/service/gpu/tests/reduce_column_layout_change.hlo @@ -98,7 +98,7 @@ ENTRY kernel_entry { // CHECK: %[[VAL_51:.*]] = addrspacecast ptr addrspace(3) %[[VAL_50]] to ptr // CHECK: %[[VAL_52:.*]] = load float, ptr %[[VAL_51]], align 4 // CHECK-PTX: %[[VAL_53:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_52]], i32 16, i32 31) -// CHECK-GCN: %[[VAL_53_:.*]] = call i32 @__ockl_readuplane_i32 +// CHECK-GCN: %[[VAL_53_:.*]] = call i32 @llvm.amdgcn.ds.swizzle // CHECK-GCN: %[[VAL_53:.*]] = bitcast i32 // CHECK: store float %[[VAL_53]], ptr{{( addrspace\(5\))?}} %[[VAL_9]], align 4 // CHECK-PTX: call void @[[REDUCTION0:reduction0.*]](ptr %[[VAL_51]], ptr %[[VAL_9]], ptr %[[VAL_8]]) @@ -106,7 +106,7 @@ ENTRY kernel_entry { // CHECK: store float %[[VAL_54]], ptr %[[VAL_51]], align 4 // CHECK: %[[VAL_55:.*]] = load float, ptr{{( addrspace\(5\))?}} %[[VAL_51]], align 4 // CHECK-PTX: %[[VAL_56:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_55]], i32 8, i32 31) -// CHECK-GCN: %[[VAL_56_1_:.*]] = call i32 @__ockl_readuplane_i32 +// CHECK-GCN: %[[VAL_56_1_:.*]] = call i32 @llvm.amdgcn.ds.swizzle // CHECK-GCN: %[[VAL_56:.*]] = bitcast i32 // CHECK: store float %[[VAL_56]], ptr{{( addrspace\(5\))?}} %[[VAL_7]], align 4 // CHECK-PTX: call void @[[REDUCTION0]](ptr %[[VAL_51]], ptr %[[VAL_7]], ptr %[[VAL_6]]) @@ -114,7 +114,7 @@ ENTRY kernel_entry { // CHECK: store float %[[VAL_57]], ptr{{( addrspace\(5\))?}} %[[VAL_51]], align 4 // CHECK: %[[VAL_58:.*]] = load float, ptr{{( addrspace\(5\))?}} %[[VAL_51]], align 4 // CHECK-PTX: %[[VAL_59:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_58]], i32 4, i32 31) -// 
CHECK-GCN: %[[VAL_59_:.*]] = call i32 @__ockl_readuplane_i32 +// CHECK-GCN: %[[VAL_59_:.*]] = call i32 @llvm.amdgcn.ds.swizzle // CHECK-GCN: %[[VAL_59:.*]] = bitcast i32 // CHECK: store float %[[VAL_59]], ptr{{( addrspace\(5\))?}} %[[VAL_5]], align 4 // CHECK-PTX: call void @[[REDUCTION0]](ptr %[[VAL_51]], ptr %[[VAL_5]], ptr %[[VAL_4]]) @@ -122,7 +122,7 @@ ENTRY kernel_entry { // CHECK: store float %[[VAL_60]], ptr{{( addrspace\(5\))?}} %[[VAL_51]], align 4 // CHECK: %[[VAL_61:.*]] = load float, ptr{{( addrspace\(5\))?}} %[[VAL_51]], align 4 // CHECK-PTX: %[[VAL_62:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_61]], i32 2, i32 31) -// CHECK-GCN: %[[VAL_62_:.*]] = call i32 @__ockl_readuplane_i32 +// CHECK-GCN: %[[VAL_62_:.*]] = call i32 @llvm.amdgcn.ds.swizzle // CHECK-GCN: %[[VAL_62:.*]] = bitcast i32 // CHECK: store float %[[VAL_62]], ptr{{( addrspace\(5\))?}} %[[VAL_3]], align 4 // CHECK-PTX: call void @[[REDUCTION0]](ptr %[[VAL_51]], ptr %[[VAL_3]], ptr %[[VAL_2]]) @@ -130,7 +130,7 @@ ENTRY kernel_entry { // CHECK: store float %[[VAL_63]], ptr{{( addrspace\(5\))?}} %[[VAL_51]], align 4 // CHECK: %[[VAL_64:.*]] = load float, ptr{{( addrspace\(5\))?}} %[[VAL_51]], align 4 // CHECK-PTX: %[[VAL_65:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_64]], i32 1, i32 31) -// CHECK-GCN: %[[VAL_65_:.*]] = call i32 @__ockl_readuplane_i32 +// CHECK-GCN: %[[VAL_65_:.*]] = call i32 @llvm.amdgcn.ds.swizzle // CHECK-GCN: %[[VAL_65:.*]] = bitcast i32 // CHECK: store float %[[VAL_65]], ptr{{( addrspace\(5\))?}} %[[VAL_1]], align 4 // CHECK-PTX: call void @[[REDUCTION0]](ptr %[[VAL_51]], ptr %[[VAL_1]], ptr %[[VAL_0]]) diff --git a/third_party/xla/xla/service/gpu/tests/reduce_f64_column.hlo b/third_party/xla/xla/service/gpu/tests/reduce_f64_column.hlo index 982e45863e2547..43c252b0e695b7 100644 --- a/third_party/xla/xla/service/gpu/tests/reduce_f64_column.hlo +++ b/third_party/xla/xla/service/gpu/tests/reduce_f64_column.hlo @@ -108,11 
+108,11 @@ ENTRY e { // CHECK: %[[VAL_54:.*]] = bitcast i64 %[[VAL_53]] to <2 x i32> // CHECK: %[[VAL_55:.*]] = extractelement <2 x i32> %[[VAL_54]], i64 0 // CHECK-PTX: %[[VAL_56:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_55]], i32 16, i32 31) -// CHECK-GCN: %[[VAL_56:.*]] = call i32 @__ockl_readuplane_i32(i32 %[[VAL_55]], i32 16) +// CHECK-GCN: %[[VAL_56:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_55]], i32 543) // CHECK: %[[VAL_57:.*]] = insertelement <2 x i32> %[[VAL_54]], i32 %[[VAL_56]], i64 0 // CHECK: %[[VAL_58:.*]] = extractelement <2 x i32> %[[VAL_57]], i64 1 // CHECK-PTX: %[[VAL_59:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_58]], i32 16, i32 31) -// CHECK-GCN: %[[VAL_59:.*]] = call i32 @__ockl_readuplane_i32(i32 %[[VAL_58]], i32 16) +// CHECK-GCN: %[[VAL_59:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_58]], i32 543) // CHECK: %[[VAL_60:.*]] = insertelement <2 x i32> %[[VAL_57]], i32 %[[VAL_59]], i64 1 // CHECK: %[[VAL_61:.*]] = bitcast <2 x i32> %[[VAL_60]] to i64 // CHECK: %[[VAL_62:.*]] = bitcast i64 %[[VAL_61]] to double @@ -128,11 +128,11 @@ ENTRY e { // CHECK: %[[VAL_66:.*]] = bitcast i64 %[[VAL_65]] to <2 x i32> // CHECK: %[[VAL_67:.*]] = extractelement <2 x i32> %[[VAL_66]], i64 0 // CHECK-PTX: %[[VAL_68:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_67]], i32 8, i32 31) -// CHECK-GCN: %[[VAL_68:.*]] = call i32 @__ockl_readuplane_i32(i32 %[[VAL_67]], i32 8) +// CHECK-GCN: %[[VAL_68:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_67]], i32 287) // CHECK: %[[VAL_69:.*]] = insertelement <2 x i32> %[[VAL_66]], i32 %[[VAL_68]], i64 0 // CHECK: %[[VAL_70:.*]] = extractelement <2 x i32> %[[VAL_69]], i64 1 // CHECK-PTX: %[[VAL_71:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_70]], i32 8, i32 31) -// CHECK-GCN: %[[VAL_71:.*]] = call i32 @__ockl_readuplane_i32(i32 %[[VAL_70]], i32 8) +// CHECK-GCN: %[[VAL_71:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 
%[[VAL_70]], i32 287) // CHECK: %[[VAL_72:.*]] = insertelement <2 x i32> %[[VAL_69]], i32 %[[VAL_71]], i64 1 // CHECK: %[[VAL_73:.*]] = bitcast <2 x i32> %[[VAL_72]] to i64 // CHECK: %[[VAL_74:.*]] = bitcast i64 %[[VAL_73]] to double @@ -148,11 +148,11 @@ ENTRY e { // CHECK: %[[VAL_78:.*]] = bitcast i64 %[[VAL_77]] to <2 x i32> // CHECK: %[[VAL_79:.*]] = extractelement <2 x i32> %[[VAL_78]], i64 0 // CHECK-PTX: %[[VAL_80:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_79]], i32 4, i32 31) -// CHECK-GCN: %[[VAL_80:.*]] = call i32 @__ockl_readuplane_i32(i32 %[[VAL_79]], i32 4) +// CHECK-GCN: %[[VAL_80:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_79]], i32 159) // CHECK: %[[VAL_81:.*]] = insertelement <2 x i32> %[[VAL_78]], i32 %[[VAL_80]], i64 0 // CHECK: %[[VAL_82:.*]] = extractelement <2 x i32> %[[VAL_81]], i64 1 // CHECK-PTX: %[[VAL_83:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_82]], i32 4, i32 31) -// CHECK-GCN: %[[VAL_83:.*]] = call i32 @__ockl_readuplane_i32(i32 %[[VAL_82]], i32 4) +// CHECK-GCN: %[[VAL_83:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_82]], i32 159) // CHECK: %[[VAL_84:.*]] = insertelement <2 x i32> %[[VAL_81]], i32 %[[VAL_83]], i64 1 // CHECK: %[[VAL_85:.*]] = bitcast <2 x i32> %[[VAL_84]] to i64 // CHECK: %[[VAL_86:.*]] = bitcast i64 %[[VAL_85]] to double @@ -168,11 +168,11 @@ ENTRY e { // CHECK: %[[VAL_90:.*]] = bitcast i64 %[[VAL_89]] to <2 x i32> // CHECK: %[[VAL_91:.*]] = extractelement <2 x i32> %[[VAL_90]], i64 0 // CHECK-PTX: %[[VAL_92:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_91]], i32 2, i32 31) -// CHECK-GCN: %[[VAL_92:.*]] = call i32 @__ockl_readuplane_i32(i32 %[[VAL_91]], i32 2) +// CHECK-GCN: %[[VAL_92:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_91]], i32 95) // CHECK: %[[VAL_93:.*]] = insertelement <2 x i32> %[[VAL_90]], i32 %[[VAL_92]], i64 0 // CHECK: %[[VAL_94:.*]] = extractelement <2 x i32> %[[VAL_93]], i64 1 // CHECK-PTX: %[[VAL_95:.*]] = call 
i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_94]], i32 2, i32 31) -// CHECK-GCN: %[[VAL_95:.*]] = call i32 @__ockl_readuplane_i32(i32 %[[VAL_94]], i32 2) +// CHECK-GCN: %[[VAL_95:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_94]], i32 95) // CHECK: %[[VAL_96:.*]] = insertelement <2 x i32> %[[VAL_93]], i32 %[[VAL_95]], i64 1 // CHECK: %[[VAL_97:.*]] = bitcast <2 x i32> %[[VAL_96]] to i64 // CHECK: %[[VAL_98:.*]] = bitcast i64 %[[VAL_97]] to double @@ -188,11 +188,11 @@ ENTRY e { // CHECK: %[[VAL_102:.*]] = bitcast i64 %[[VAL_101]] to <2 x i32> // CHECK: %[[VAL_103:.*]] = extractelement <2 x i32> %[[VAL_102]], i64 0 // CHECK-PTX: %[[VAL_104:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_103]], i32 1, i32 31) -// CHECK-GCN: %[[VAL_104:.*]] = call i32 @__ockl_readuplane_i32(i32 %[[VAL_103]], i32 1) +// CHECK-GCN: %[[VAL_104:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_103]], i32 63) // CHECK: %[[VAL_105:.*]] = insertelement <2 x i32> %[[VAL_102]], i32 %[[VAL_104]], i64 0 // CHECK: %[[VAL_106:.*]] = extractelement <2 x i32> %[[VAL_105]], i64 1 // CHECK-PTX: %[[VAL_107:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_106]], i32 1, i32 31) -// CHECK-GCN: %[[VAL_107:.*]] = call i32 @__ockl_readuplane_i32(i32 %[[VAL_106]], i32 1) +// CHECK-GCN: %[[VAL_107:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_106]], i32 63) // CHECK: %[[VAL_108:.*]] = insertelement <2 x i32> %[[VAL_105]], i32 %[[VAL_107]], i64 1 // CHECK: %[[VAL_109:.*]] = bitcast <2 x i32> %[[VAL_108]] to i64 // CHECK: %[[VAL_110:.*]] = bitcast i64 %[[VAL_109]] to double diff --git a/third_party/xla/xla/service/gpu/tests/reduce_large_row_to_scalar.hlo b/third_party/xla/xla/service/gpu/tests/reduce_large_row_to_scalar.hlo index 35ba85befe94a6..e8173211b45b1d 100644 --- a/third_party/xla/xla/service/gpu/tests/reduce_large_row_to_scalar.hlo +++ b/third_party/xla/xla/service/gpu/tests/reduce_large_row_to_scalar.hlo @@ -90,19 +90,19 @@ ENTRY reduce.1 { // CHECK: 
%[[VAL_51:.*]] = bitcast i128 %[[VAL_50]] to <4 x i32> // CHECK: %[[VAL_52:.*]] = extractelement <4 x i32> %[[VAL_51]], i64 0 // CHECK-PTX: %[[VAL_53:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_52]], i32 16, i32 31) -// CHECK-GCN: %[[VAL_53:.*]] = call i32 @__ockl_readuplane_i32(i32 %[[VAL_52]], i32 16) +// CHECK-GCN: %[[VAL_53:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_52]], i32 543) // CHECK: %[[VAL_54:.*]] = insertelement <4 x i32> %[[VAL_51]], i32 %[[VAL_53]], i64 0 // CHECK: %[[VAL_55:.*]] = extractelement <4 x i32> %[[VAL_54]], i64 1 // CHECK-PTX: %[[VAL_56:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_55]], i32 16, i32 31) -// CHECK-GCN: %[[VAL_56:.*]] = call i32 @__ockl_readuplane_i32(i32 %[[VAL_55]], i32 16) +// CHECK-GCN: %[[VAL_56:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_55]], i32 543) // CHECK: %[[VAL_57:.*]] = insertelement <4 x i32> %[[VAL_54]], i32 %[[VAL_56]], i64 1 // CHECK: %[[VAL_58:.*]] = extractelement <4 x i32> %[[VAL_57]], i64 2 // CHECK-PTX: %[[VAL_59:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_58]], i32 16, i32 31) -// CHECK-GCN: %[[VAL_59:.*]] = call i32 @__ockl_readuplane_i32(i32 %[[VAL_58]], i32 16) +// CHECK-GCN: %[[VAL_59:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_58]], i32 543) // CHECK: %[[VAL_60:.*]] = insertelement <4 x i32> %[[VAL_57]], i32 %[[VAL_59]], i64 2 // CHECK: %[[VAL_61:.*]] = extractelement <4 x i32> %[[VAL_60]], i64 3 // CHECK-PTX: %[[VAL_62:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_61]], i32 16, i32 31) -// CHECK-GCN: %[[VAL_62:.*]] = call i32 @__ockl_readuplane_i32(i32 %[[VAL_61]], i32 16) +// CHECK-GCN: %[[VAL_62:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_61]], i32 543) // CHECK: %[[VAL_63:.*]] = insertelement <4 x i32> %[[VAL_60]], i32 %[[VAL_62]], i64 3 // CHECK: %[[VAL_64:.*]] = bitcast <4 x i32> %[[VAL_63]] to i128 // CHECK: store i128 %[[VAL_64]], ptr{{.*}} %[[VAL_21]], align {{(16|8)}} @@ 
-117,19 +117,19 @@ ENTRY reduce.1 { // CHECK: %[[VAL_67:.*]] = bitcast i128 %[[VAL_66]] to <4 x i32> // CHECK: %[[VAL_68:.*]] = extractelement <4 x i32> %[[VAL_67]], i64 0 // CHECK-PTX: %[[VAL_69:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_68]], i32 8, i32 31) -// CHECK-GCN: %[[VAL_69:.*]] = call i32 @__ockl_readuplane_i32(i32 %[[VAL_68]], i32 8) +// CHECK-GCN: %[[VAL_69:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_68]], i32 287) // CHECK: %[[VAL_70:.*]] = insertelement <4 x i32> %[[VAL_67]], i32 %[[VAL_69]], i64 0 // CHECK: %[[VAL_71:.*]] = extractelement <4 x i32> %[[VAL_70]], i64 1 // CHECK-PTX: %[[VAL_72:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_71]], i32 8, i32 31) -// CHECK-GCN: %[[VAL_72:.*]] = call i32 @__ockl_readuplane_i32(i32 %[[VAL_71]], i32 8) +// CHECK-GCN: %[[VAL_72:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_71]], i32 287) // CHECK: %[[VAL_73:.*]] = insertelement <4 x i32> %[[VAL_70]], i32 %[[VAL_72]], i64 1 // CHECK: %[[VAL_74:.*]] = extractelement <4 x i32> %[[VAL_73]], i64 2 // CHECK-PTX: %[[VAL_75:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_74]], i32 8, i32 31) -// CHECK-GCN: %[[VAL_75:.*]] = call i32 @__ockl_readuplane_i32(i32 %[[VAL_74]], i32 8) +// CHECK-GCN: %[[VAL_75:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_74]], i32 287) // CHECK: %[[VAL_76:.*]] = insertelement <4 x i32> %[[VAL_73]], i32 %[[VAL_75]], i64 2 // CHECK: %[[VAL_77:.*]] = extractelement <4 x i32> %[[VAL_76]], i64 3 // CHECK-PTX: %[[VAL_78:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_77]], i32 8, i32 31) -// CHECK-GCN: %[[VAL_78:.*]] = call i32 @__ockl_readuplane_i32(i32 %[[VAL_77]], i32 8) +// CHECK-GCN: %[[VAL_78:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_77]], i32 287) // CHECK: %[[VAL_79:.*]] = insertelement <4 x i32> %[[VAL_76]], i32 %[[VAL_78]], i64 3 // CHECK: %[[VAL_80:.*]] = bitcast <4 x i32> %[[VAL_79]] to i128 // CHECK: store i128 %[[VAL_80]], ptr{{.*}} 
%[[VAL_19]], align {{(16|8)}} @@ -144,19 +144,19 @@ ENTRY reduce.1 { // CHECK: %[[VAL_83:.*]] = bitcast i128 %[[VAL_82]] to <4 x i32> // CHECK: %[[VAL_84:.*]] = extractelement <4 x i32> %[[VAL_83]], i64 0 // CHECK-PTX: %[[VAL_85:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_84]], i32 4, i32 31) -// CHECK-GCN: %[[VAL_85:.*]] = call i32 @__ockl_readuplane_i32(i32 %[[VAL_84]], i32 4) +// CHECK-GCN: %[[VAL_85:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_84]], i32 159) // CHECK: %[[VAL_86:.*]] = insertelement <4 x i32> %[[VAL_83]], i32 %[[VAL_85]], i64 0 // CHECK: %[[VAL_87:.*]] = extractelement <4 x i32> %[[VAL_86]], i64 1 // CHECK-PTX: %[[VAL_88:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_87]], i32 4, i32 31) -// CHECK-GCN: %[[VAL_88:.*]] = call i32 @__ockl_readuplane_i32(i32 %[[VAL_87]], i32 4) +// CHECK-GCN: %[[VAL_88:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_87]], i32 159) // CHECK: %[[VAL_89:.*]] = insertelement <4 x i32> %[[VAL_86]], i32 %[[VAL_88]], i64 1 // CHECK: %[[VAL_90:.*]] = extractelement <4 x i32> %[[VAL_89]], i64 2 // CHECK-PTX: %[[VAL_91:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_90]], i32 4, i32 31) -// CHECK-GCN: %[[VAL_91:.*]] = call i32 @__ockl_readuplane_i32(i32 %[[VAL_90]], i32 4) +// CHECK-GCN: %[[VAL_91:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_90]], i32 159) // CHECK: %[[VAL_92:.*]] = insertelement <4 x i32> %[[VAL_89]], i32 %[[VAL_91]], i64 2 // CHECK: %[[VAL_93:.*]] = extractelement <4 x i32> %[[VAL_92]], i64 3 // CHECK-PTX: %[[VAL_94:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_93]], i32 4, i32 31) -// CHECK-GCN: %[[VAL_94:.*]] = call i32 @__ockl_readuplane_i32(i32 %[[VAL_93]], i32 4) +// CHECK-GCN: %[[VAL_94:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_93]], i32 159) // CHECK: %[[VAL_95:.*]] = insertelement <4 x i32> %[[VAL_92]], i32 %[[VAL_94]], i64 3 // CHECK: %[[VAL_96:.*]] = bitcast <4 x i32> %[[VAL_95]] to i128 // CHECK: 
store i128 %[[VAL_96]], ptr{{.*}} %[[VAL_17]], align {{(16|8)}} @@ -171,19 +171,19 @@ ENTRY reduce.1 { // CHECK: %[[VAL_99:.*]] = bitcast i128 %[[VAL_98]] to <4 x i32> // CHECK: %[[VAL_100:.*]] = extractelement <4 x i32> %[[VAL_99]], i64 0 // CHECK-PTX: %[[VAL_101:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_100]], i32 2, i32 31) -// CHECK-GCN: %[[VAL_101:.*]] = call i32 @__ockl_readuplane_i32(i32 %[[VAL_100]], i32 2) +// CHECK-GCN: %[[VAL_101:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_100]], i32 95) // CHECK: %[[VAL_102:.*]] = insertelement <4 x i32> %[[VAL_99]], i32 %[[VAL_101]], i64 0 // CHECK: %[[VAL_103:.*]] = extractelement <4 x i32> %[[VAL_102]], i64 1 // CHECK-PTX: %[[VAL_104:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_103]], i32 2, i32 31) -// CHECK-GCN: %[[VAL_104:.*]] = call i32 @__ockl_readuplane_i32(i32 %[[VAL_103]], i32 2) +// CHECK-GCN: %[[VAL_104:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_103]], i32 95) // CHECK: %[[VAL_105:.*]] = insertelement <4 x i32> %[[VAL_102]], i32 %[[VAL_104]], i64 1 // CHECK: %[[VAL_106:.*]] = extractelement <4 x i32> %[[VAL_105]], i64 2 // CHECK-PTX: %[[VAL_107:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_106]], i32 2, i32 31) -// CHECK-GCN: %[[VAL_107:.*]] = call i32 @__ockl_readuplane_i32(i32 %[[VAL_106]], i32 2) +// CHECK-GCN: %[[VAL_107:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_106]], i32 95) // CHECK: %[[VAL_108:.*]] = insertelement <4 x i32> %[[VAL_105]], i32 %[[VAL_107]], i64 2 // CHECK: %[[VAL_109:.*]] = extractelement <4 x i32> %[[VAL_108]], i64 3 // CHECK-PTX: %[[VAL_110:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_109]], i32 2, i32 31) -// CHECK-GCN: %[[VAL_110:.*]] = call i32 @__ockl_readuplane_i32(i32 %[[VAL_109]], i32 2) +// CHECK-GCN: %[[VAL_110:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_109]], i32 95) // CHECK: %[[VAL_111:.*]] = insertelement <4 x i32> %[[VAL_108]], i32 %[[VAL_110]], i64 3 // 
CHECK: %[[VAL_112:.*]] = bitcast <4 x i32> %[[VAL_111]] to i128 // CHECK: store i128 %[[VAL_112]], ptr{{.*}} %[[VAL_15]], align {{(16|8)}} @@ -198,19 +198,19 @@ ENTRY reduce.1 { // CHECK: %[[VAL_115:.*]] = bitcast i128 %[[VAL_114]] to <4 x i32> // CHECK: %[[VAL_116:.*]] = extractelement <4 x i32> %[[VAL_115]], i64 0 // CHECK-PTX: %[[VAL_117:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_116]], i32 1, i32 31) -// CHECK-GCN: %[[VAL_117:.*]] = call i32 @__ockl_readuplane_i32(i32 %[[VAL_116]], i32 1) +// CHECK-GCN: %[[VAL_117:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_116]], i32 63) // CHECK: %[[VAL_118:.*]] = insertelement <4 x i32> %[[VAL_115]], i32 %[[VAL_117]], i64 0 // CHECK: %[[VAL_119:.*]] = extractelement <4 x i32> %[[VAL_118]], i64 1 // CHECK-PTX: %[[VAL_120:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_119]], i32 1, i32 31) -// CHECK-GCN: %[[VAL_120:.*]] = call i32 @__ockl_readuplane_i32(i32 %[[VAL_119]], i32 1) +// CHECK-GCN: %[[VAL_120:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_119]], i32 63) // CHECK: %[[VAL_121:.*]] = insertelement <4 x i32> %[[VAL_118]], i32 %[[VAL_120]], i64 1 // CHECK: %[[VAL_122:.*]] = extractelement <4 x i32> %[[VAL_121]], i64 2 // CHECK-PTX: %[[VAL_123:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_122]], i32 1, i32 31) -// CHECK-GCN: %[[VAL_123:.*]] = call i32 @__ockl_readuplane_i32(i32 %[[VAL_122]], i32 1) +// CHECK-GCN: %[[VAL_123:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_122]], i32 63) // CHECK: %[[VAL_124:.*]] = insertelement <4 x i32> %[[VAL_121]], i32 %[[VAL_123]], i64 2 // CHECK: %[[VAL_125:.*]] = extractelement <4 x i32> %[[VAL_124]], i64 3 // CHECK-PTX: %[[VAL_126:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_125]], i32 1, i32 31) -// CHECK-GCN: %[[VAL_126:.*]] = call i32 @__ockl_readuplane_i32(i32 %[[VAL_125]], i32 1) +// CHECK-GCN: %[[VAL_126:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_125]], i32 63) // CHECK: 
%[[VAL_127:.*]] = insertelement <4 x i32> %[[VAL_124]], i32 %[[VAL_126]], i64 3 // CHECK: %[[VAL_128:.*]] = bitcast <4 x i32> %[[VAL_127]] to i128 // CHECK: store i128 %[[VAL_128]], ptr{{.*}} %[[VAL_13]], align {{(16|8)}} @@ -392,19 +392,19 @@ ENTRY reduce.1 { // CHECK: %[[VAL_210:.*]] = load i128, ptr %[[VAL_209]], align {{(16|8)}} // CHECK: %[[VAL_211:.*]] = bitcast i128 %[[VAL_210]] to <4 x i32> // CHECK: %[[VAL_212:.*]] = extractelement <4 x i32> %[[VAL_211]], i64 0 -// CHECK-GCN: %[[VAL_213:.*]] = call i32 @__ockl_readuplane_i32(i32 %[[VAL_212]], i32 16) +// CHECK-GCN: %[[VAL_213:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_212]], i32 543) // CHECK-PTX: %[[VAL_213:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_212]], i32 16, i32 31) // CHECK: %[[VAL_214:.*]] = insertelement <4 x i32> %[[VAL_211]], i32 %[[VAL_213]], i64 0 // CHECK: %[[VAL_215:.*]] = extractelement <4 x i32> %[[VAL_214]], i64 1 -// CHECK-GCN: %[[VAL_216:.*]] = call i32 @__ockl_readuplane_i32(i32 %[[VAL_215]], i32 16) +// CHECK-GCN: %[[VAL_216:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_215]], i32 543) // CHECK-PTX: %[[VAL_216:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_215]], i32 16, i32 31) // CHECK: %[[VAL_217:.*]] = insertelement <4 x i32> %[[VAL_214]], i32 %[[VAL_216]], i64 1 // CHECK: %[[VAL_218:.*]] = extractelement <4 x i32> %[[VAL_217]], i64 2 -// CHECK-GCN: %[[VAL_219:.*]] = call i32 @__ockl_readuplane_i32(i32 %[[VAL_218]], i32 16) +// CHECK-GCN: %[[VAL_219:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_218]], i32 543) // CHECK-PTX: %[[VAL_219:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_218]], i32 16, i32 31) // CHECK: %[[VAL_220:.*]] = insertelement <4 x i32> %[[VAL_217]], i32 %[[VAL_219]], i64 2 // CHECK: %[[VAL_221:.*]] = extractelement <4 x i32> %[[VAL_220]], i64 3 -// CHECK-GCN: %[[VAL_222:.*]] = call i32 @__ockl_readuplane_i32(i32 %[[VAL_221]], i32 16) +// CHECK-GCN: %[[VAL_222:.*]] = call i32 
@llvm.amdgcn.ds.swizzle(i32 %[[VAL_221]], i32 543) // CHECK-PTX: %[[VAL_222:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_221]], i32 16, i32 31) // CHECK: %[[VAL_223:.*]] = insertelement <4 x i32> %[[VAL_220]], i32 %[[VAL_222]], i64 3 // CHECK: %[[VAL_224:.*]] = bitcast <4 x i32> %[[VAL_223]] to i128 @@ -418,19 +418,19 @@ ENTRY reduce.1 { // CHECK: %[[VAL_226:.*]] = load i128, ptr %[[VAL_209]], align {{(16|8)}} // CHECK: %[[VAL_227:.*]] = bitcast i128 %[[VAL_226]] to <4 x i32> // CHECK: %[[VAL_228:.*]] = extractelement <4 x i32> %[[VAL_227]], i64 0 -// CHECK-GCN: %[[VAL_229:.*]] = call i32 @__ockl_readuplane_i32(i32 %[[VAL_228]], i32 8) +// CHECK-GCN: %[[VAL_229:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_228]], i32 287) // CHECK-PTX: %[[VAL_229:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_228]], i32 8, i32 31) // CHECK: %[[VAL_230:.*]] = insertelement <4 x i32> %[[VAL_227]], i32 %[[VAL_229]], i64 0 // CHECK: %[[VAL_231:.*]] = extractelement <4 x i32> %[[VAL_230]], i64 1 -// CHECK-GCN: %[[VAL_232:.*]] = call i32 @__ockl_readuplane_i32(i32 %[[VAL_231]], i32 8) +// CHECK-GCN: %[[VAL_232:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_231]], i32 287) // CHECK-PTX: %[[VAL_232:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_231]], i32 8, i32 31) // CHECK: %[[VAL_233:.*]] = insertelement <4 x i32> %[[VAL_230]], i32 %[[VAL_232]], i64 1 // CHECK: %[[VAL_234:.*]] = extractelement <4 x i32> %[[VAL_233]], i64 2 -// CHECK-GCN: %[[VAL_235:.*]] = call i32 @__ockl_readuplane_i32(i32 %[[VAL_234]], i32 8) +// CHECK-GCN: %[[VAL_235:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_234]], i32 287) // CHECK-PTX: %[[VAL_235:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_234]], i32 8, i32 31) // CHECK: %[[VAL_236:.*]] = insertelement <4 x i32> %[[VAL_233]], i32 %[[VAL_235]], i64 2 // CHECK: %[[VAL_237:.*]] = extractelement <4 x i32> %[[VAL_236]], i64 3 -// CHECK-GCN: %[[VAL_238:.*]] = call i32 
@__ockl_readuplane_i32(i32 %[[VAL_237]], i32 8) +// CHECK-GCN: %[[VAL_238:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_237]], i32 287) // CHECK-PTX: %[[VAL_238:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_237]], i32 8, i32 31) // CHECK: %[[VAL_239:.*]] = insertelement <4 x i32> %[[VAL_236]], i32 %[[VAL_238]], i64 3 // CHECK: %[[VAL_240:.*]] = bitcast <4 x i32> %[[VAL_239]] to i128 @@ -444,19 +444,19 @@ ENTRY reduce.1 { // CHECK: %[[VAL_242:.*]] = load i128, ptr %[[VAL_209]], align {{(16|8)}} // CHECK: %[[VAL_243:.*]] = bitcast i128 %[[VAL_242]] to <4 x i32> // CHECK: %[[VAL_244:.*]] = extractelement <4 x i32> %[[VAL_243]], i64 0 -// CHECK-GCN: %[[VAL_245:.*]] = call i32 @__ockl_readuplane_i32(i32 %[[VAL_244]], i32 4) +// CHECK-GCN: %[[VAL_245:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_244]], i32 159) // CHECK-PTX: %[[VAL_245:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_244]], i32 4, i32 31) // CHECK: %[[VAL_246:.*]] = insertelement <4 x i32> %[[VAL_243]], i32 %[[VAL_245]], i64 0 // CHECK: %[[VAL_247:.*]] = extractelement <4 x i32> %[[VAL_246]], i64 1 -// CHECK-GCN: %[[VAL_248:.*]] = call i32 @__ockl_readuplane_i32(i32 %[[VAL_247]], i32 4) +// CHECK-GCN: %[[VAL_248:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_247]], i32 159) // CHECK-PTX: %[[VAL_248:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_247]], i32 4, i32 31) // CHECK: %[[VAL_249:.*]] = insertelement <4 x i32> %[[VAL_246]], i32 %[[VAL_248]], i64 1 // CHECK: %[[VAL_250:.*]] = extractelement <4 x i32> %[[VAL_249]], i64 2 -// CHECK-GCN: %[[VAL_251:.*]] = call i32 @__ockl_readuplane_i32(i32 %[[VAL_250]], i32 4) +// CHECK-GCN: %[[VAL_251:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_250]], i32 159) // CHECK-PTX: %[[VAL_251:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_250]], i32 4, i32 31) // CHECK: %[[VAL_252:.*]] = insertelement <4 x i32> %[[VAL_249]], i32 %[[VAL_251]], i64 2 // CHECK: %[[VAL_253:.*]] = 
extractelement <4 x i32> %[[VAL_252]], i64 3 -// CHECK-GCN: %[[VAL_254:.*]] = call i32 @__ockl_readuplane_i32(i32 %[[VAL_253]], i32 4) +// CHECK-GCN: %[[VAL_254:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_253]], i32 159) // CHECK-PTX: %[[VAL_254:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_253]], i32 4, i32 31) // CHECK: %[[VAL_255:.*]] = insertelement <4 x i32> %[[VAL_252]], i32 %[[VAL_254]], i64 3 // CHECK: %[[VAL_256:.*]] = bitcast <4 x i32> %[[VAL_255]] to i128 @@ -470,19 +470,19 @@ ENTRY reduce.1 { // CHECK: %[[VAL_258:.*]] = load i128, ptr %[[VAL_209]], align {{(16|8)}} // CHECK: %[[VAL_259:.*]] = bitcast i128 %[[VAL_258]] to <4 x i32> // CHECK: %[[VAL_260:.*]] = extractelement <4 x i32> %[[VAL_259]], i64 0 -// CHECK-GCN: %[[VAL_261:.*]] = call i32 @__ockl_readuplane_i32(i32 %[[VAL_260]], i32 2) +// CHECK-GCN: %[[VAL_261:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_260]], i32 95) // CHECK-PTX: %[[VAL_261:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_260]], i32 2, i32 31) // CHECK: %[[VAL_262:.*]] = insertelement <4 x i32> %[[VAL_259]], i32 %[[VAL_261]], i64 0 // CHECK: %[[VAL_263:.*]] = extractelement <4 x i32> %[[VAL_262]], i64 1 -// CHECK-GCN: %[[VAL_264:.*]] = call i32 @__ockl_readuplane_i32(i32 %[[VAL_263]], i32 2) +// CHECK-GCN: %[[VAL_264:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_263]], i32 95) // CHECK-PTX: %[[VAL_264:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_263]], i32 2, i32 31) // CHECK: %[[VAL_265:.*]] = insertelement <4 x i32> %[[VAL_262]], i32 %[[VAL_264]], i64 1 // CHECK: %[[VAL_266:.*]] = extractelement <4 x i32> %[[VAL_265]], i64 2 -// CHECK-GCN: %[[VAL_267:.*]] = call i32 @__ockl_readuplane_i32(i32 %[[VAL_266]], i32 2) +// CHECK-GCN: %[[VAL_267:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_266]], i32 95) // CHECK-PTX: %[[VAL_267:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_266]], i32 2, i32 31) // CHECK: %[[VAL_268:.*]] = insertelement 
<4 x i32> %[[VAL_265]], i32 %[[VAL_267]], i64 2 // CHECK: %[[VAL_269:.*]] = extractelement <4 x i32> %[[VAL_268]], i64 3 -// CHECK-GCN: %[[VAL_270:.*]] = call i32 @__ockl_readuplane_i32(i32 %[[VAL_269]], i32 2) +// CHECK-GCN: %[[VAL_270:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_269]], i32 95) // CHECK-PTX: %[[VAL_270:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_269]], i32 2, i32 31) // CHECK: %[[VAL_271:.*]] = insertelement <4 x i32> %[[VAL_268]], i32 %[[VAL_270]], i64 3 // CHECK: %[[VAL_272:.*]] = bitcast <4 x i32> %[[VAL_271]] to i128 @@ -496,19 +496,19 @@ ENTRY reduce.1 { // CHECK: %[[VAL_274:.*]] = load i128, ptr %[[VAL_209]], align {{(16|8)}} // CHECK: %[[VAL_275:.*]] = bitcast i128 %[[VAL_274]] to <4 x i32> // CHECK: %[[VAL_276:.*]] = extractelement <4 x i32> %[[VAL_275]], i64 0 -// CHECK-GCN: %[[VAL_277:.*]] = call i32 @__ockl_readuplane_i32(i32 %[[VAL_276]], i32 1) +// CHECK-GCN: %[[VAL_277:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_276]], i32 63) // CHECK-PTX: %[[VAL_277:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_276]], i32 1, i32 31) // CHECK: %[[VAL_278:.*]] = insertelement <4 x i32> %[[VAL_275]], i32 %[[VAL_277]], i64 0 // CHECK: %[[VAL_279:.*]] = extractelement <4 x i32> %[[VAL_278]], i64 1 -// CHECK-GCN: %[[VAL_280:.*]] = call i32 @__ockl_readuplane_i32(i32 %[[VAL_279]], i32 1) +// CHECK-GCN: %[[VAL_280:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_279]], i32 63) // CHECK-PTX: %[[VAL_280:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_279]], i32 1, i32 31) // CHECK: %[[VAL_281:.*]] = insertelement <4 x i32> %[[VAL_278]], i32 %[[VAL_280]], i64 1 // CHECK: %[[VAL_282:.*]] = extractelement <4 x i32> %[[VAL_281]], i64 2 -// CHECK-GCN: %[[VAL_283:.*]] = call i32 @__ockl_readuplane_i32(i32 %[[VAL_282]], i32 1) +// CHECK-GCN: %[[VAL_283:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_282]], i32 63) // CHECK-PTX: %[[VAL_283:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, 
i32 %[[VAL_282]], i32 1, i32 31) // CHECK: %[[VAL_284:.*]] = insertelement <4 x i32> %[[VAL_281]], i32 %[[VAL_283]], i64 2 // CHECK: %[[VAL_285:.*]] = extractelement <4 x i32> %[[VAL_284]], i64 3 -// CHECK-GCN: %[[VAL_286:.*]] = call i32 @__ockl_readuplane_i32(i32 %[[VAL_285]], i32 1) +// CHECK-GCN: %[[VAL_286:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_285]], i32 63) // CHECK-PTX: %[[VAL_286:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_285]], i32 1, i32 31) // CHECK: %[[VAL_287:.*]] = insertelement <4 x i32> %[[VAL_284]], i32 %[[VAL_286]], i64 3 // CHECK: %[[VAL_288:.*]] = bitcast <4 x i32> %[[VAL_287]] to i128 diff --git a/third_party/xla/xla/service/gpu/tests/reduce_row_vectorized.hlo b/third_party/xla/xla/service/gpu/tests/reduce_row_vectorized.hlo index c4976c5fc2b3a7..3e71ce408b7fab 100644 --- a/third_party/xla/xla/service/gpu/tests/reduce_row_vectorized.hlo +++ b/third_party/xla/xla/service/gpu/tests/reduce_row_vectorized.hlo @@ -100,7 +100,7 @@ ENTRY reduce.1 { // CHECK: loop1.loop_exit: ; preds = %[[VAL_44]] // CHECK: %[[VAL_56:.*]] = load float, ptr{{.*}} %[[VAL_28]], align 4 // CHECK-GCN: %[[VAL_57_1:.*]] = bitcast float %[[VAL_56]] to i32 -// CHECK-GCN: %[[VAL_57_2:.*]] = call i32 @__ockl_readuplane_i32(i32 %[[VAL_57_1]], i32 16) +// CHECK-GCN: %[[VAL_57_2:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_57_1]], i32 543) // CHECK-GCN: %[[VAL_57:.*]] = bitcast i32 %[[VAL_57_2]] to float // CHECK-PTX: %[[VAL_57:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_56]], i32 16, i32 31) // CHECK: store float %[[VAL_57]], ptr{{.*}} %[[VAL_20]], align 4 @@ -113,7 +113,7 @@ ENTRY reduce.1 { // CHECK: store float %[[VAL_58]], ptr{{.*}} %[[VAL_28]], align 4 // CHECK: %[[VAL_59:.*]] = load float, ptr{{.*}} %[[VAL_28]], align 4 // CHECK-GCN: %[[VAL_60_1:.*]] = bitcast float %[[VAL_59]] to i32 -// CHECK-GCN: %[[VAL_60_2:.*]] = call i32 @__ockl_readuplane_i32(i32 %[[VAL_60_1]], i32 8) +// CHECK-GCN: %[[VAL_60_2:.*]] = 
call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_60_1]], i32 287) // CHECK-GCN: %[[VAL_60:.*]] = bitcast i32 %[[VAL_60_2]] to float // CHECK-PTX: %[[VAL_60:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_59]], i32 8, i32 31) // CHECK: store float %[[VAL_60]], ptr{{.*}} %[[VAL_18]], align 4 @@ -126,7 +126,7 @@ ENTRY reduce.1 { // CHECK: store float %[[VAL_61]], ptr{{.*}} %[[VAL_28]], align 4 // CHECK: %[[VAL_62:.*]] = load float, ptr{{.*}} %[[VAL_28]], align 4 // CHECK-GCN: %[[VAL_63_1:.*]] = bitcast float %[[VAL_62]] to i32 -// CHECK-GCN: %[[VAL_63_2:.*]] = call i32 @__ockl_readuplane_i32(i32 %[[VAL_63_1]], i32 4) +// CHECK-GCN: %[[VAL_63_2:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_63_1]], i32 159) // CHECK-GCN: %[[VAL_63:.*]] = bitcast i32 %[[VAL_63_2]] to float // CHECK-PTX: %[[VAL_63:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_62]], i32 4, i32 31) // CHECK: store float %[[VAL_63]], ptr{{.*}} %[[VAL_16]], align 4 @@ -139,7 +139,7 @@ ENTRY reduce.1 { // CHECK: store float %[[VAL_64]], ptr{{.*}} %[[VAL_28]], align 4 // CHECK: %[[VAL_65:.*]] = load float, ptr{{.*}} %[[VAL_28]], align 4 // CHECK-GCN: %[[VAL_66_1:.*]] = bitcast float %[[VAL_65]] to i32 -// CHECK-GCN: %[[VAL_66_2:.*]] = call i32 @__ockl_readuplane_i32(i32 %[[VAL_66_1]], i32 2) +// CHECK-GCN: %[[VAL_66_2:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_66_1]], i32 95) // CHECK-GCN: %[[VAL_66:.*]] = bitcast i32 %[[VAL_66_2]] to float // CHECK-PTX: %[[VAL_66:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_65]], i32 2, i32 31) // CHECK: store float %[[VAL_66]], ptr{{.*}} %[[VAL_14]], align 4 @@ -152,7 +152,7 @@ ENTRY reduce.1 { // CHECK: store float %[[VAL_67]], ptr{{.*}} %[[VAL_28]], align 4 // CHECK: %[[VAL_68:.*]] = load float, ptr{{.*}} %[[VAL_28]], align 4 // CHECK-GCN: %[[VAL_69_1:.*]] = bitcast float %[[VAL_68]] to i32 -// CHECK-GCN: %[[VAL_69_2:.*]] = call i32 @__ockl_readuplane_i32(i32 %[[VAL_69_1]], i32 1) +// CHECK-GCN: 
%[[VAL_69_2:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_69_1]], i32 63) // CHECK-GCN: %[[VAL_69:.*]] = bitcast i32 %[[VAL_69_2]] to float // CHECK-PTX: %[[VAL_69:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_68]], i32 1, i32 31) // CHECK: store float %[[VAL_69]], ptr{{.*}} %[[VAL_12]], align 4 @@ -339,7 +339,7 @@ ENTRY reduce.1 { // CHECK-PTX: %[[VAL_152:.*]] = select i1 %[[VAL_151]], ptr %[[VAL_150]], ptr %[[VAL_10]] // CHECK: %[[VAL_153:.*]] = load float, ptr %[[VAL_152]], align 4 // CHECK-GCN: %[[VAL_154_1:.*]] = bitcast float %[[VAL_153]] to i32 -// CHECK-GCN: %[[VAL_154_2:.*]] = call i32 @__ockl_readuplane_i32(i32 %[[VAL_154_1]], i32 16) +// CHECK-GCN: %[[VAL_154_2:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_154_1]], i32 543) // CHECK-GCN: %[[VAL_154:.*]] = bitcast i32 %[[VAL_154_2]] to float // CHECK-PTX: %[[VAL_154:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_153]], i32 16, i32 31) // CHECK: store float %[[VAL_154]], ptr{{.*}} %[[VAL_9]], align 4 @@ -351,7 +351,7 @@ ENTRY reduce.1 { // CHECK: store float %[[VAL_155]], ptr %[[VAL_152]], align 4 // CHECK: %[[VAL_156:.*]] = load float, ptr %[[VAL_152]], align 4 // CHECK-GCN: %[[VAL_157_1:.*]] = bitcast float %[[VAL_156]] to i32 -// CHECK-GCN: %[[VAL_157_2:.*]] = call i32 @__ockl_readuplane_i32(i32 %[[VAL_157_1]], i32 8) +// CHECK-GCN: %[[VAL_157_2:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_157_1]], i32 287) // CHECK-GCN: %[[VAL_157:.*]] = bitcast i32 %[[VAL_157_2]] to float // CHECK-PTX: %[[VAL_157:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_156]], i32 8, i32 31) // CHECK: store float %[[VAL_157]], ptr{{.*}} %[[VAL_7]], align 4 @@ -363,7 +363,7 @@ ENTRY reduce.1 { // CHECK: store float %[[VAL_158]], ptr %[[VAL_152]], align 4 // CHECK: %[[VAL_159:.*]] = load float, ptr %[[VAL_152]], align 4 // CHECK-GCN: %[[VAL_160_1:.*]] = bitcast float %[[VAL_159]] to i32 -// CHECK-GCN: %[[VAL_160_2:.*]] = call i32 
@__ockl_readuplane_i32(i32 %[[VAL_160_1]], i32 4) +// CHECK-GCN: %[[VAL_160_2:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_160_1]], i32 159) // CHECK-GCN: %[[VAL_160:.*]] = bitcast i32 %[[VAL_160_2]] to float // CHECK-PTX: %[[VAL_160:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_159]], i32 4, i32 31) // CHECK: store float %[[VAL_160]], ptr{{.*}} %[[VAL_5]], align 4 @@ -375,7 +375,7 @@ ENTRY reduce.1 { // CHECK: store float %[[VAL_161]], ptr %[[VAL_152]], align 4 // CHECK: %[[VAL_162:.*]] = load float, ptr %[[VAL_152]], align 4 // CHECK-GCN: %[[VAL_163_1:.*]] = bitcast float %[[VAL_162]] to i32 -// CHECK-GCN: %[[VAL_163_2:.*]] = call i32 @__ockl_readuplane_i32(i32 %[[VAL_163_1]], i32 2) +// CHECK-GCN: %[[VAL_163_2:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_163_1]], i32 95) // CHECK-GCN: %[[VAL_163:.*]] = bitcast i32 %[[VAL_163_2]] to float // CHECK-PTX: %[[VAL_163:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_162]], i32 2, i32 31) // CHECK: store float %[[VAL_163]], ptr{{.*}} %[[VAL_3]], align 4 @@ -387,7 +387,7 @@ ENTRY reduce.1 { // CHECK: store float %[[VAL_164]], ptr %[[VAL_152]], align 4 // CHECK: %[[VAL_165:.*]] = load float, ptr %[[VAL_152]], align 4 // CHECK-GCN: %[[VAL_166_1:.*]] = bitcast float %[[VAL_165]] to i32 -// CHECK-GCN: %[[VAL_166_2:.*]] = call i32 @__ockl_readuplane_i32(i32 %[[VAL_166_1]], i32 1) +// CHECK-GCN: %[[VAL_166_2:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_166_1]], i32 63) // CHECK-GCN: %[[VAL_166:.*]] = bitcast i32 %[[VAL_166_2]] to float // CHECK-PTX: %[[VAL_166:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_165]], i32 1, i32 31) // CHECK: store float %[[VAL_166]], ptr{{.*}} %[[VAL_1]], align 4 diff --git a/third_party/xla/xla/service/gpu/tests/reduce_variadic_column.hlo b/third_party/xla/xla/service/gpu/tests/reduce_variadic_column.hlo index 36008daa5ceda8..75dbf459e9b348 100644 --- 
a/third_party/xla/xla/service/gpu/tests/reduce_variadic_column.hlo +++ b/third_party/xla/xla/service/gpu/tests/reduce_variadic_column.hlo @@ -125,13 +125,13 @@ ENTRY main { // CHECK: %[[VAL_76:.*]] = load float, ptr{{.*}}%[[VAL_39]], align 4 // CHECK-PTX: %[[VAL_77:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_76]], i32 16, i32 31) // CHECK-GCN: %[[VAL_76_1:.*]] = bitcast float %[[VAL_76]] to i32 -// CHECK-GCN: %[[VAL_77_1:.*]] = call i32 @__ockl_readuplane_i32(i32 %[[VAL_76_1:.*]], i32 16) +// CHECK-GCN: %[[VAL_77_1:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_76_1:.*]], i32 543) // CHECK-GCN: %[[VAL_77:.*]] = bitcast i32 %[[VAL_77_1:.*]] to float // CHECK: store float %[[VAL_77]], ptr{{.*}}%[[VAL_26]], align 4 // CHECK: %[[VAL_78:.*]] = load float, ptr{{.*}}%[[VAL_37]], align 4 // CHECK-PTX: %[[VAL_79:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_78]], i32 16, i32 31) // CHECK-GCN: %[[VAL_78_1:.*]] = bitcast float %[[VAL_78]] to i32 -// CHECK-GCN: %[[VAL_79_1:.*]] = call i32 @__ockl_readuplane_i32(i32 %[[VAL_78_1:.*]], i32 16) +// CHECK-GCN: %[[VAL_79_1:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_78_1:.*]], i32 543) // CHECK-GCN: %[[VAL_79:.*]] = bitcast i32 %[[VAL_79_1:.*]] to float // CHECK: store float %[[VAL_79]], ptr{{.*}}%[[VAL_25]], align 4 // CHECK-GCN: %[[VAL_22_1:.*]] = addrspacecast ptr{{.*}}%[[VAL_22]] to ptr @@ -156,13 +156,13 @@ ENTRY main { // CHECK: %[[VAL_84:.*]] = load float, ptr{{.*}}%[[VAL_39]], align 4 // CHECK-PTX: %[[VAL_85:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_84]], i32 8, i32 31) // CHECK-GCN: %[[VAL_84_1:.*]] = bitcast float %[[VAL_84]] to i32 -// CHECK-GCN: %[[VAL_85_1:.*]] = call i32 @__ockl_readuplane_i32(i32 %[[VAL_84_1:.*]], i32 8) +// CHECK-GCN: %[[VAL_85_1:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_84_1:.*]], i32 287) // CHECK-GCN: %[[VAL_85:.*]] = bitcast i32 %[[VAL_85_1:.*]] to float // CHECK: store float %[[VAL_85]], 
ptr{{.*}}%[[VAL_21]], align 4 // CHECK: %[[VAL_86:.*]] = load float, ptr{{.*}}%[[VAL_37]], align 4 // CHECK-PTX: %[[VAL_87:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_86]], i32 8, i32 31) // CHECK-GCN: %[[VAL_86_1:.*]] = bitcast float %[[VAL_86]] to i32 -// CHECK-GCN: %[[VAL_87_1:.*]] = call i32 @__ockl_readuplane_i32(i32 %[[VAL_86_1:.*]], i32 8) +// CHECK-GCN: %[[VAL_87_1:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_86_1:.*]], i32 287) // CHECK-GCN: %[[VAL_87:.*]] = bitcast i32 %[[VAL_87_1:.*]] to float // CHECK: store float %[[VAL_87]], ptr{{.*}}%[[VAL_20]], align 4 // CHECK-GCN: %[[VAL_17_1:.*]] = addrspacecast ptr{{.*}}%[[VAL_17]] to ptr @@ -187,13 +187,13 @@ ENTRY main { // CHECK: %[[VAL_92:.*]] = load float, ptr{{.*}}%[[VAL_39]], align 4 // CHECK-PTX: %[[VAL_93:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_92]], i32 4, i32 31) // CHECK-GCN: %[[VAL_92_1:.*]] = bitcast float %[[VAL_92]] to i32 -// CHECK-GCN: %[[VAL_93_1:.*]] = call i32 @__ockl_readuplane_i32(i32 %[[VAL_92_1:.*]], i32 4) +// CHECK-GCN: %[[VAL_93_1:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_92_1:.*]], i32 159) // CHECK-GCN: %[[VAL_93:.*]] = bitcast i32 %[[VAL_93_1:.*]] to float // CHECK: store float %[[VAL_93]], ptr{{.*}}%[[VAL_16]], align 4 // CHECK: %[[VAL_94:.*]] = load float, ptr{{.*}}%[[VAL_37]], align 4 // CHECK-PTX: %[[VAL_95:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_94]], i32 4, i32 31) // CHECK-GCN: %[[VAL_94_1:.*]] = bitcast float %[[VAL_94]] to i32 -// CHECK-GCN: %[[VAL_95_1:.*]] = call i32 @__ockl_readuplane_i32(i32 %[[VAL_94_1:.*]], i32 4) +// CHECK-GCN: %[[VAL_95_1:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_94_1:.*]], i32 159) // CHECK-GCN: %[[VAL_95:.*]] = bitcast i32 %[[VAL_95_1:.*]] to float // CHECK: store float %[[VAL_95]], ptr{{.*}}%[[VAL_15]], align 4 // CHECK-GCN: %[[VAL_12_1:.*]] = addrspacecast ptr{{.*}}%[[VAL_12]] to ptr @@ -218,13 +218,13 @@ ENTRY main { // CHECK: 
%[[VAL_100:.*]] = load float, ptr{{.*}}%[[VAL_39]], align 4 // CHECK-PTX: %[[VAL_101:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_100]], i32 2, i32 31) // CHECK-GCN: %[[VAL_100_1:.*]] = bitcast float %[[VAL_100]] to i32 -// CHECK-GCN: %[[VAL_101_1:.*]] = call i32 @__ockl_readuplane_i32(i32 %[[VAL_100_1:.*]], i32 2) +// CHECK-GCN: %[[VAL_101_1:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_100_1:.*]], i32 95) // CHECK-GCN: %[[VAL_101:.*]] = bitcast i32 %[[VAL_101_1:.*]] to float // CHECK: store float %[[VAL_101]], ptr{{.*}}%[[VAL_11]], align 4 // CHECK: %[[VAL_102:.*]] = load float, ptr{{.*}}%[[VAL_37]], align 4 // CHECK-PTX: %[[VAL_103:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_102]], i32 2, i32 31) // CHECK-GCN: %[[VAL_102_1:.*]] = bitcast float %[[VAL_102]] to i32 -// CHECK-GCN: %[[VAL_103_1:.*]] = call i32 @__ockl_readuplane_i32(i32 %[[VAL_102_1:.*]], i32 2) +// CHECK-GCN: %[[VAL_103_1:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_102_1:.*]], i32 95) // CHECK-GCN: %[[VAL_103:.*]] = bitcast i32 %[[VAL_103_1:.*]] to float // CHECK: store float %[[VAL_103]], ptr{{.*}}%[[VAL_10]], align 4 // CHECK-GCN: %[[VAL_7_1:.*]] = addrspacecast ptr{{.*}}%[[VAL_7]] to ptr @@ -249,13 +249,13 @@ ENTRY main { // CHECK: %[[VAL_108:.*]] = load float, ptr{{.*}}%[[VAL_39]], align 4 // CHECK-PTX: %[[VAL_109:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_108]], i32 1, i32 31) // CHECK-GCN: %[[VAL_108_1:.*]] = bitcast float %[[VAL_108]] to i32 -// CHECK-GCN: %[[VAL_109_1:.*]] = call i32 @__ockl_readuplane_i32(i32 %[[VAL_108_1:.*]], i32 1) +// CHECK-GCN: %[[VAL_109_1:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_108_1:.*]], i32 63) // CHECK-GCN: %[[VAL_109:.*]] = bitcast i32 %[[VAL_109_1:.*]] to float // CHECK: store float %[[VAL_109]], ptr{{.*}}%[[VAL_6]], align 4 // CHECK: %[[VAL_110:.*]] = load float, ptr{{.*}}%[[VAL_37]], align 4 // CHECK-PTX: %[[VAL_111:.*]] = call float 
@llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_110]], i32 1, i32 31) // CHECK-GCN: %[[VAL_110_1:.*]] = bitcast float %[[VAL_110]] to i32 -// CHECK-GCN: %[[VAL_111_1:.*]] = call i32 @__ockl_readuplane_i32(i32 %[[VAL_110_1:.*]], i32 1) +// CHECK-GCN: %[[VAL_111_1:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_110_1:.*]], i32 63) // CHECK-GCN: %[[VAL_111:.*]] = bitcast i32 %[[VAL_111_1:.*]] to float // CHECK: store float %[[VAL_111]], ptr{{.*}}%[[VAL_5]], align 4 // CHECK-GCN: %[[VAL_2_1:.*]] = addrspacecast ptr{{.*}}%[[VAL_2]] to ptr diff --git a/third_party/xla/xla/service/gpu/tests/reduction_vectorization_test.cc b/third_party/xla/xla/service/gpu/tests/reduction_vectorization_test.cc index 3a4f912a205d3a..e4e5845e018d6e 100644 --- a/third_party/xla/xla/service/gpu/tests/reduction_vectorization_test.cc +++ b/third_party/xla/xla/service/gpu/tests/reduction_vectorization_test.cc @@ -140,7 +140,7 @@ CHECK-NOT: SHUFFLE expected_optimized_llvm_ir, {{"X_THREAD", is_built_with_rocm_ ? "@llvm.amdgcn.workitem.id.x" : "@llvm.nvvm.read.ptx.sreg.tid.x"}, - {"SHUFFLE", is_built_with_rocm_ ? "llvm.amdgcn.ds.bpermute" + {"SHUFFLE", is_built_with_rocm_ ? "@llvm.amdgcn.ds.swizzle" : "llvm.nvvm.shfl.sync.down.f32"}}); CompileAndVerifyIr(hlo_text, expected_optimized_llvm_ir, true); From 03850ed22ff4ea3f3325b82698dd1e311f7146cb Mon Sep 17 00:00:00 2001 From: Greg Olechwierowicz Date: Thu, 20 Jun 2024 04:29:13 -0700 Subject: [PATCH 046/256] [JAX] Fix FDO profile deserialization. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Passing .c_str() to the ParseFromString can lead to inconsistent behavior when c string is not properly null terminated. This diff initializes an std::string explicitly by providing a size of a buffer to be parsed. 
PiperOrigin-RevId: 644979040 --- third_party/xla/xla/python/profiler.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/third_party/xla/xla/python/profiler.cc b/third_party/xla/xla/python/profiler.cc index 5c4ce1b50771a4..66c4492f41654d 100644 --- a/third_party/xla/xla/python/profiler.cc +++ b/third_party/xla/xla/python/profiler.cc @@ -288,7 +288,8 @@ void BuildProfilerSubmodule(nb::module_& m) { fdo_profiles; for (const nb::bytes& profile : profiles) { tensorflow::profiler::ProfiledInstructionsProto profile_proto; - profile_proto.ParseFromString(profile.c_str()); + profile_proto.ParseFromString( + std::string(profile.c_str(), profile.size())); fdo_profiles.push_back(std::move(profile_proto)); } From 5350f43913f187ddb76bed69cbd7216435827895 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Thu, 20 Jun 2024 04:37:53 -0700 Subject: [PATCH 047/256] [xla:cpu] Don't forget to include buffer branch index buffer in conditional thunk buffer uses PiperOrigin-RevId: 644981099 --- third_party/xla/xla/service/cpu/runtime/BUILD | 22 ++++++ .../service/cpu/runtime/conditional_thunk.cc | 3 +- .../cpu/runtime/conditional_thunk_test.cc | 69 +++++++++++++++++++ 3 files changed, 93 insertions(+), 1 deletion(-) create mode 100644 third_party/xla/xla/service/cpu/runtime/conditional_thunk_test.cc diff --git a/third_party/xla/xla/service/cpu/runtime/BUILD b/third_party/xla/xla/service/cpu/runtime/BUILD index 08de6665a6b0c3..994278b8271ee9 100644 --- a/third_party/xla/xla/service/cpu/runtime/BUILD +++ b/third_party/xla/xla/service/cpu/runtime/BUILD @@ -156,6 +156,7 @@ cc_library( ":thunk", ":thunk_executor", "//xla:util", + "//xla/runtime:buffer_use", "//xla/service:buffer_assignment", "//xla/stream_executor", "//xla/tsl/concurrency:async_value", @@ -170,6 +171,27 @@ cc_library( ], ) +xla_cc_test( + name = "conditional_thunk_test", + srcs = ["conditional_thunk_test.cc"], + deps = [ + ":buffer_allocations", + ":conditional_thunk", + ":thunk", + "//xla:shape_util", + 
"//xla/runtime:buffer_use", + "//xla/service:buffer_assignment", + "//xla/service:maybe_owning_device_memory", + "//xla/stream_executor", + "//xla/tsl/concurrency:async_value", + "@com_google_absl//absl/status", + "@local_tsl//tsl/lib/core:status_test_util", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", + ], +) + cc_library( name = "all_gather_thunk", srcs = ["all_gather_thunk.cc"], diff --git a/third_party/xla/xla/service/cpu/runtime/conditional_thunk.cc b/third_party/xla/xla/service/cpu/runtime/conditional_thunk.cc index ff71d0f5a9e647..65fb146686e480 100644 --- a/third_party/xla/xla/service/cpu/runtime/conditional_thunk.cc +++ b/third_party/xla/xla/service/cpu/runtime/conditional_thunk.cc @@ -22,6 +22,7 @@ limitations under the License. #include "absl/memory/memory.h" #include "absl/strings/str_format.h" +#include "xla/runtime/buffer_use.h" #include "xla/service/buffer_assignment.h" #include "xla/service/cpu/runtime/thunk.h" #include "xla/service/cpu/runtime/thunk_executor.h" @@ -96,7 +97,7 @@ tsl::AsyncValueRef ConditionalThunk::Execute( } ConditionalThunk::BufferUses ConditionalThunk::buffer_uses() const { - BufferUses buffer_uses; + BufferUses buffer_uses = {BufferUse::Read(branch_index_buffer_)}; for (const auto& branch_executor : branch_executors_) { BufferUses uses = branch_executor.buffer_uses(); buffer_uses.insert(buffer_uses.end(), uses.begin(), uses.end()); diff --git a/third_party/xla/xla/service/cpu/runtime/conditional_thunk_test.cc b/third_party/xla/xla/service/cpu/runtime/conditional_thunk_test.cc new file mode 100644 index 00000000000000..f3eeb7e3ba4739 --- /dev/null +++ b/third_party/xla/xla/service/cpu/runtime/conditional_thunk_test.cc @@ -0,0 +1,69 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/cpu/runtime/conditional_thunk.h" + +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "xla/runtime/buffer_use.h" +#include "xla/service/buffer_assignment.h" +#include "xla/service/cpu/runtime/thunk.h" +#include "xla/tsl/concurrency/async_value_ref.h" +#include "tsl/platform/statusor.h" +#include "tsl/platform/test.h" + +namespace xla::cpu { +namespace { + +// A test-only thunk to create a Thunk with a specific buffer use. +class TestThunk : public Thunk { + public: + explicit TestThunk(BufferUse buffer_use) + : Thunk(Kind::kKernel, {"test"}), buffer_use_(buffer_use) {} + + tsl::AsyncValueRef Execute(const ExecuteParams&) final { + return absl::UnimplementedError("Unimplemented"); + } + + BufferUses buffer_uses() const final { return {buffer_use_}; } + + private: + BufferUse buffer_use_; +}; + +TEST(ConditionalThunkTest, BufferUses) { + BufferAllocation alloc(0, 1024, 0); + BufferAllocation::Slice branch_index_slice(&alloc, 0, sizeof(int32_t)); + BufferAllocation::Slice read_slice(&alloc, 10, 10); + + std::vector branch_sequences(1); + branch_sequences[0].push_back( + std::make_unique(BufferUse::Read(read_slice))); + + TF_ASSERT_OK_AND_ASSIGN( + auto thunk, ConditionalThunk::Create({"conditional"}, branch_index_slice, + std::move(branch_sequences))); + + EXPECT_EQ(thunk->buffer_uses().size(), 2); + EXPECT_EQ(thunk->buffer_uses()[0], BufferUse::Read(branch_index_slice)); + EXPECT_EQ(thunk->buffer_uses()[1], 
BufferUse::Read(read_slice)); +} + +} // namespace +} // namespace xla::cpu From ed4deb885a3872838b16b5cd6b5f083952e13a6c Mon Sep 17 00:00:00 2001 From: Shraiysh Date: Thu, 20 Jun 2024 04:55:32 -0700 Subject: [PATCH 048/256] PR #13555: Fix _xla_send_recv_validation in collective pipeliner Imported from GitHub PR https://github.com/openxla/xla/pull/13555 This patch fixes the value of `_xla_send_recv_validation` attribute in collective-pipeliner. Collective pipeliner peels one iteration outside loop and updates the initialization variable of the loop. This means that the trip count of the loop reduces. More details about this are added as docstrings. Copybara import of the project: -- e92f6ba150aaffff8777cb6fa2debd4a323e203c by Shraiysh Vaishay : Fix _xla_send_recv_validation attribute when collective pipeliner runs This patch fixes the value of `_xla_send_recv_validation` attribute in collective-pipeliner. Collective pipeliner peels one iteration outside loop and updates the initialization variable of the loop. This means that the trip count of the loop reduces. More details about this are added as docstrings. 
Merging this change closes #13555 PiperOrigin-RevId: 644984505 --- third_party/xla/xla/service/BUILD | 4 + .../xla/xla/service/collective_pipeliner.cc | 143 ++++++++ .../xla/service/collective_pipeliner_test.cc | 310 ++++++++++++++++++ 3 files changed, 457 insertions(+) diff --git a/third_party/xla/xla/service/BUILD b/third_party/xla/xla/service/BUILD index e1a420ff607158..26fd0b20551252 100644 --- a/third_party/xla/xla/service/BUILD +++ b/third_party/xla/xla/service/BUILD @@ -554,6 +554,7 @@ cc_library( srcs = ["collective_pipeliner.cc"], hdrs = ["collective_pipeliner.h"], deps = [ + ":collective_ops_utils", ":constant_value", ":hlo_dce", ":hlo_pass", @@ -566,7 +567,9 @@ cc_library( "//xla:xla_data_proto_cc", "//xla/hlo/evaluator:hlo_evaluator", "//xla/hlo/ir:hlo", + "//xla/hlo/ir:hlo_instruction_utils", "//xla/hlo/utils:hlo_query", + "//xla/service:hlo_parser", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", @@ -593,6 +596,7 @@ xla_cc_test( "//xla:util", "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_matchers", + "//xla/tests:filecheck", "//xla/tests:hlo_test_base", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", diff --git a/third_party/xla/xla/service/collective_pipeliner.cc b/third_party/xla/xla/service/collective_pipeliner.cc index 5feef7bfcf5fe1..998cbc634b0c9e 100644 --- a/third_party/xla/xla/service/collective_pipeliner.cc +++ b/third_party/xla/xla/service/collective_pipeliner.cc @@ -41,6 +41,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instruction_utils.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/utils/hlo_query.h" @@ -48,8 +49,10 @@ limitations under the License. 
#include "xla/literal_util.h" #include "xla/map_util.h" #include "xla/primitive_util.h" +#include "xla/service/collective_ops_utils.h" #include "xla/service/constant_value.h" #include "xla/service/hlo_dce.h" +#include "xla/service/hlo_parser.h" #include "xla/service/value_range.h" #include "xla/shape.h" #include "xla/shape_util.h" @@ -1197,6 +1200,130 @@ HloInstruction* CreateZero(HloComputation* comp, const Shape& shape, } // namespace +using Interval = std::pair; +using Intervals = std::vector; +// Parses a string "{{a,b},{c,d},{e,f},...}" to a vector of pairs. +absl::StatusOr> ParseVectorOfPairs( + absl::string_view str) { + TF_ASSIGN_OR_RETURN(std::vector replica_groups, + ParseReplicaGroupsOnly(str)); + std::vector res; + res.reserve(replica_groups.size()); + for (const ReplicaGroup& replica_group : replica_groups) { + TF_RET_CHECK(replica_group.replica_ids_size() == 2); + int64_t a = replica_group.replica_ids(0); + int64_t b = replica_group.replica_ids(1); + res.emplace_back(a, b); + } + return res; +} + +// If there is a collective-permute instruction with _xla_send_recv_validation +// attribute in the computation, then during pipelining the loop trip count +// changes. This function fixes the attribute for the cloned instruction. +// +// For forward pipelining: A peeled collective permute is executed before the +// loop. This peeled collective permute runs for all the devices that were +// supposed to run on iteration 0. The second execution of collective-permute +// occurs in the second iteration of the old loop, but the first iteration of +// the transformed loop (because of the peeled instruction). Hence, the +// collective permutes inside the loop see iteration bounds reducing by 1. For +// example, let the original trip count be 7, and the attribute be +// {{0,4},{0,5},{1,5},{1,6},{2,6}}. 
For the peeled collective this attribute +// will become {{0,0},{0,0},{1,0},{1,0},{1,0}} and for the internal collective +// this will become {{0,3},{0,4},{0,4},{0,5},{1,5}} +// +// For backward pipelining: A peeled collective permute is executed after the +// loop. This peeled collective permute runs for all devices that were supposed +// to run the last iteration (trip_count-1). All the other executions of the +// collective permute, i.e. every execution except the last one, run +// as-it-is without any change to iteration bounds. So, only the devices that +// were supposed to run the last iteration instance see a change in their bounds +// inside the while loop. For example, let the original trip count be 7 and the +// attribute be {{0,4},{0,4},{1,5},{1,6},{2,6}}. For the peeled collective, this +// attribute will become {{1,0},{1,0},{1,0},{0,0},{0,0}} and for the collective +// inside while loop, this attribute will become +// {{0,4},{0,4},{1,5},{1,5},{2,5}}. +absl::Status UpdateSendRecvValidation( + HloInstruction* instruction, bool is_peeled, + CollectivePipeliner::PipeliningDirection direction, + const WhileLoopAnalysis& loop_analysis) { + if (instruction->opcode() != HloOpcode::kCollectivePermute) { + return absl::OkStatus(); + } + const auto& frontend_attributes = instruction->frontend_attributes().map(); + if (!frontend_attributes.contains(kSendRecvValidationAttr)) { + return absl::OkStatus(); + } + VLOG(3) << "Trip count = " + << (loop_analysis.GetLoopIterationCount() ? loop_analysis.GetLoopIterationCount()->GetSignedValue() : int64_t{-1}); + VLOG(3) << "Collective permute with _xla_send_recv_validation: " + << instruction->ToString(); + TF_ASSIGN_OR_RETURN( + Intervals old_intervals, + ParseVectorOfPairs(frontend_attributes.at(kSendRecvValidationAttr))); + + Intervals intervals; + + if (direction == CollectivePipeliner::kForward) { + // It is a forward pipelining which means that the peeled collective permute + // is before the loop.
It should run once for the devices executing the + // first iteration and the internal collective permute now sees each + // original iteration decreased by one. + // + // peeled collective permute: + // {{0,0} if {a,b} in old and a<=0<=b, {1,0} otherwise} + // internal collective permute: {{max(0,a-1),b-1} if b>=1, else {1,0} | {a,b} in old} + for (auto [a, b] : old_intervals) { + if (is_peeled) { + if (a <= 0 && 0 <= b) { + intervals.push_back({0, 0}); + } else { + intervals.push_back({1, 0}); + } + } else { + intervals.push_back(b < 1 ? Interval{1, 0} : Interval{std::max(int64_t{0}, a - 1), b - 1}); + } + } + } else if (direction == CollectivePipeliner::kBackward) { + // It is a backward pipelining which means that the peeled collective is + // after the loop. It should run once for the devices executing the last + // iteration and the internal collective permute doesn't see the last + // iteration. + // + // peeled collective permute: + // {{0,0} if {a,b} in old and a<=n<=b where n=#last_iteration, {1,0} + // otherwise} + // internal collective permute: + // {{a,min(n-1,b)} | {a,b} in old and n=#last_iteration} + auto trip_count_value = loop_analysis.GetLoopIterationCount(); + if (!trip_count_value) { + return absl::InternalError( + "Unable to deduce loop trip count in collective pipeliner.
This is " + "required for backward pipelining while fixing the " + "_xla_send_recv_validation attribute"); + } + int64_t trip_count = trip_count_value->GetSignedValue(); + int64_t last_iteration = trip_count - 1; + for (auto [a, b] : old_intervals) { + if (is_peeled) { + if (a <= last_iteration && last_iteration <= b) { + intervals.push_back({0, 0}); + } else { + intervals.push_back({1, 0}); + } + } else { + intervals.push_back({a, std::min(last_iteration - 1, b)}); + } + } + } + hlo_instruction_utils::AddOrUpdateVectorOfPairsAsAttribute( + instruction, kSendRecvValidationAttr, intervals); + VLOG(3) << "Updated collective_permute with _xla_send_recv_validation: " + << instruction->ToString(); + return absl::OkStatus(); +} + // Function that does the work of pushing forward instructions that have been // determined that can be pipelined. Rough transformation: while (i < LAYERS) { // p0 = param(0) @@ -1324,6 +1451,9 @@ absl::Status TransformLoopForward( TF_RETURN_IF_ERROR( UpdateControlDependencies(instr, cloned_instr, while_body_to_peeled)); UpdateInstructionChannelId(cloned_instr, next_channel_id); + TF_RETURN_IF_ERROR(UpdateSendRecvValidation( + cloned_instr, true, CollectivePipeliner::PipeliningDirection::kForward, + loop_analysis)); while_body_to_peeled[instr] = cloned_instr; auto output_it = is_output_instruction.find(instr); if (output_it != is_output_instruction.end()) { @@ -1386,6 +1516,11 @@ absl::Status TransformLoopForward( HloComputation* new_while_body = loop_computation->parent()->AddEmbeddedComputation( while_body->CloneWithReplacements(&replacements)); + for (HloInstruction* instruction : new_while_body->instructions()) { + TF_RETURN_IF_ERROR(UpdateSendRecvValidation( + instruction, false, CollectivePipeliner::PipeliningDirection::kForward, + loop_analysis)); + } HloInstruction* new_init = loop_computation->AddInstruction( HloInstruction::CreateTuple(new_init_operands)); while_body_to_peeled[while_body->root_instruction()] = new_init; @@ -2336,6
+2471,11 @@ static absl::Status TransformLoopBackward( TF_RETURN_IF_ERROR(UpdateControlDependencies(while_body->root_instruction(), new_loop_root, while_body_replacement_map)); + for (HloInstruction* instruction : new_while_body->instructions()) { + TF_RETURN_IF_ERROR(UpdateSendRecvValidation( + instruction, false, CollectivePipeliner::PipeliningDirection::kBackward, + loop_analysis)); + } absl::flat_hash_map loop_cond_replacements; auto cond_builder = @@ -2414,6 +2554,9 @@ static absl::Status TransformLoopBackward( TF_RETURN_IF_ERROR(UpdateControlDependencies(instr, cloned_instr, while_body_replacement_map)); UpdateInstructionChannelId(cloned_instr, next_channel_id); + TF_RETURN_IF_ERROR(UpdateSendRecvValidation( + cloned_instr, true, CollectivePipeliner::PipeliningDirection::kBackward, + loop_analysis)); while_body_replacement_map[instr] = cloned_instr; if (instruction_is_output_it != is_output_instruction.end()) { output_tuple_instructions[instruction_is_output_it->second] = diff --git a/third_party/xla/xla/service/collective_pipeliner_test.cc b/third_party/xla/xla/service/collective_pipeliner_test.cc index 87f25d724c54d0..520206a43c142a 100644 --- a/third_party/xla/xla/service/collective_pipeliner_test.cc +++ b/third_party/xla/xla/service/collective_pipeliner_test.cc @@ -33,6 +33,7 @@ limitations under the License. 
#include "xla/hlo/utils/hlo_matchers.h" #include "xla/service/hlo_parser.h" #include "xla/service/hlo_pass_pipeline.h" +#include "xla/tests/filecheck.h" #include "xla/tests/hlo_test_base.h" #include "xla/util.h" @@ -173,6 +174,157 @@ ENTRY entry { EXPECT_EQ(get_tuple_index->tuple_index(), 3); } +TEST_F(CollectivePipelinerTest, TransformIncrementIndexByOneCollectivePermute) { + constexpr absl::string_view hlo_string = R"( +HloModule module + +add { + lhs = bf16[] parameter(0) + rhs = bf16[] parameter(1) + ROOT add = bf16[] add(lhs, rhs) +} + +while_cond { + param = (s32[], bf16[3,8,128], bf16[3,8,128]) parameter(0) + gte = s32[] get-tuple-element(param), index=0 + constant.1 = s32[] constant(14) + ROOT cmp = pred[] compare(gte, constant.1), direction=LT +} + +while_body { + param = (s32[], bf16[3,8,128], bf16[3,8,128]) parameter(0) + get-tuple-element.394 = s32[] get-tuple-element(param), index=0 + get-tuple-element.395 = bf16[3,8,128] get-tuple-element(param), index=1 + get-tuple-element.5 = bf16[3,8,128] get-tuple-element(param), index=2 + cp = bf16[3,8,128] collective-permute(get-tuple-element.5), channel_id=1, source_target_pairs={{0,1},{1,2},{2,3},{3,4},{4,5},{5,6},{6,7},{7,0}}, + frontend_attributes={_xla_send_recv_validation="{{0,6},{1,7},{2,8},{3,9},{4,10},{5,11},{6,12},{7,13}}"} + constant.2557 = s32[] constant(1) + add.230 = s32[] add(get-tuple-element.394, constant.2557) + constant.2559 = s32[] constant(14) + subtract.139 = s32[] subtract(constant.2559, get-tuple-element.394) + constant.2560 = s32[] constant(-1) + add.231 = s32[] add(subtract.139, constant.2560) + constant.2561 = s32[] constant(0) + compare.747 = pred[] compare(add.231, constant.2561), direction=LT + constant.2562 = s32[] constant(13) + add.232 = s32[] add(subtract.139, constant.2562) + select.1348 = s32[] select(compare.747, add.232, add.231) + dynamic-slice.99 = bf16[1,8,128] dynamic-slice(cp, select.1348, constant.2561, constant.2561), dynamic_slice_sizes={1,8,128} + mul = 
bf16[1,8,128] multiply(dynamic-slice.99, dynamic-slice.99) + ar.1 = bf16[1,8,128] all-reduce(mul), replica_groups={}, to_apply=add, channel_id=2 + dynamic-update-slice.35 = bf16[3,8,128] dynamic-update-slice(get-tuple-element.395, ar.1, select.1348, constant.2561, constant.2561) + ROOT tuple = (s32[], bf16[3,8,128], bf16[3,8,128]) tuple(add.230, dynamic-update-slice.35, get-tuple-element.5) +} + +ENTRY entry { + c0 = s32[] constant(0) + p0 = bf16[3,8,128] parameter(0) + tuple = (s32[], bf16[3,8,128], bf16[3,8,128]) tuple(c0, p0, p0) + while = (s32[], bf16[3,8,128], bf16[3,8,128]) while(tuple), condition=while_cond, body=while_body + ROOT gte1 = bf16[3,8,128] get-tuple-element(while), index=1 +} +)"; + config_.set_num_partitions(8); + config_.set_replica_count(1); + auto module = ParseAndReturnUnverifiedModule(hlo_string, config_).value(); + EXPECT_TRUE(RunOptimizer(module.get(), /*last_run=*/true).value()); + XLA_VLOG_LINES(1, module->ToString()); + EXPECT_TRUE(*RunFileCheck(module->ToString(), R"( + // CHECK: HloModule + // CHECK: %while_body + // CHECK: %[[cp:.+]] = {{.+}} collective-permute({{.+}}), {{.+}}_xla_send_recv_validation="{{[{]}}{0,5},{0,6},{1,7},{2,8},{3,9},{4,10},{5,11},{6,12}{{[}]}}" + // CHECK: %[[dus:.+]] = {{.+}} dynamic-slice({{.+}} %[[cp]], {{.+}}) + // CHECK: %[[mul:.+]] = {{.+}} multiply({{.+}} %[[dus]], {{.+}} %[[dus]]) + // CHECK: %[[dus2:.+]] = {{.+}} dynamic-update-slice({{.+}} %[[mul]], {{.+}}) + // CHECK: ROOT {{.+}} = {{.+}} tuple({{.+}} %[[dus2]], {{.+}}) + // CHECK: } + // CHECK: ENTRY %entry + // CHECK: %[[cp:.+]] = {{.+}} collective-permute({{.+}}), {{.+}}{_xla_send_recv_validation="{{[{]}}{0,0},{1,0},{1,0},{1,0},{1,0},{1,0},{1,0},{1,0}{{[}]}}" + // CHECK: %[[ds:.+]] = {{.+}} dynamic-slice({{.+}} %[[cp]], {{.+}}) + // CHECK: %[[mul:.+]] = {{.+}} multiply({{.+}} %[[ds]], {{.+}} %[[ds]]) + // CHECK: %[[dus:.+]] = {{.+}} dynamic-update-slice({{.+}} %[[mul]], {{.+}}) + // CHECK: %[[tuple:.+]] = {{.+}} tuple({{.+}} %[[dus]], {{.+}}) + 
// CHECK: {{.+}} = {{.+}} while({{.+}} %[[tuple]]) + // CHECK: } + )")); +} + +TEST_F(CollectivePipelinerTest, + TransformIncrementIndexByOneCollectivePermuteBackwardCycle) { + constexpr absl::string_view hlo_string = R"( +HloModule module + +add { + lhs = bf16[] parameter(0) + rhs = bf16[] parameter(1) + ROOT add = bf16[] add(lhs, rhs) +} + +while_cond { + param = (s32[], bf16[3,8,128], bf16[3,8,128]) parameter(0) + gte = s32[] get-tuple-element(param), index=0 + constant.1 = s32[] constant(14) + ROOT cmp = pred[] compare(gte, constant.1), direction=LT +} + +while_body { + param = (s32[], bf16[3,8,128], bf16[3,8,128]) parameter(0) + get-tuple-element.394 = s32[] get-tuple-element(param), index=0 + get-tuple-element.395 = bf16[3,8,128] get-tuple-element(param), index=1 + get-tuple-element.5 = bf16[3,8,128] get-tuple-element(param), index=2 + cp = bf16[3,8,128] collective-permute(get-tuple-element.5), channel_id=1, source_target_pairs={{0,7},{1,0},{2,1},{3,2},{4,3},{5,4},{6,5},{7,6}}, + frontend_attributes={_xla_send_recv_validation="{{7,13},{6,12},{5,11},{4,10},{3,9},{2,8},{1,7},{0,6}}"} + constant.2557 = s32[] constant(1) + add.230 = s32[] add(get-tuple-element.394, constant.2557) + constant.2559 = s32[] constant(14) + subtract.139 = s32[] subtract(constant.2559, get-tuple-element.394) + constant.2560 = s32[] constant(-1) + add.231 = s32[] add(subtract.139, constant.2560) + constant.2561 = s32[] constant(0) + compare.747 = pred[] compare(add.231, constant.2561), direction=LT + constant.2562 = s32[] constant(13) + add.232 = s32[] add(subtract.139, constant.2562) + select.1348 = s32[] select(compare.747, add.232, add.231) + dynamic-slice.99 = bf16[1,8,128] dynamic-slice(cp, select.1348, constant.2561, constant.2561), dynamic_slice_sizes={1,8,128} + mul = bf16[1,8,128] multiply(dynamic-slice.99, dynamic-slice.99) + ar.1 = bf16[1,8,128] all-reduce(mul), replica_groups={}, to_apply=add, channel_id=2 + dynamic-update-slice.35 = bf16[3,8,128] 
dynamic-update-slice(get-tuple-element.395, ar.1, select.1348, constant.2561, constant.2561) + ROOT tuple = (s32[], bf16[3,8,128], bf16[3,8,128]) tuple(add.230, dynamic-update-slice.35, get-tuple-element.5) +} + +ENTRY entry { + c0 = s32[] constant(0) + p0 = bf16[3,8,128] parameter(0) + tuple = (s32[], bf16[3,8,128], bf16[3,8,128]) tuple(c0, p0, p0) + while = (s32[], bf16[3,8,128], bf16[3,8,128]) while(tuple), condition=while_cond, body=while_body + ROOT gte1 = bf16[3,8,128] get-tuple-element(while), index=1 +} +)"; + config_.set_num_partitions(8); + config_.set_replica_count(1); + auto module = ParseAndReturnUnverifiedModule(hlo_string, config_).value(); + EXPECT_TRUE(RunOptimizer(module.get(), /*last_run=*/true).value()); + XLA_VLOG_LINES(1, module->ToString()); + EXPECT_TRUE(*RunFileCheck(module->ToString(), R"( + // CHECK: HloModule + // CHECK: %while_body + // CHECK: %[[cp:.+]] = {{.+}} collective-permute({{.+}}), {{.+}}_xla_send_recv_validation="{{[{]}}{6,12},{5,11},{4,10},{3,9},{2,8},{1,7},{0,6},{0,5}{{[}]}}" + // CHECK: %[[dus:.+]] = {{.+}} dynamic-slice({{.+}} %[[cp]], {{.+}}) + // CHECK: %[[mul:.+]] = {{.+}} multiply({{.+}} %[[dus]], {{.+}} %[[dus]]) + // CHECK: %[[dus2:.+]] = {{.+}} dynamic-update-slice({{.+}} %[[mul]], {{.+}}) + // CHECK: ROOT {{.+}} = {{.+}} tuple({{.+}} %[[dus2]], {{.+}}) + // CHECK: } + // CHECK: ENTRY %entry + // CHECK: %[[cp:.+]] = {{.+}} collective-permute({{.+}}), {{.+}}{_xla_send_recv_validation="{{[{]}}{1,0},{1,0},{1,0},{1,0},{1,0},{1,0},{1,0},{0,0}{{[}]}}" + // CHECK: %[[ds:.+]] = {{.+}} dynamic-slice({{.+}} %[[cp]], {{.+}}) + // CHECK: %[[mul:.+]] = {{.+}} multiply({{.+}} %[[ds]], {{.+}} %[[ds]]) + // CHECK: %[[dus:.+]] = {{.+}} dynamic-update-slice({{.+}} %[[mul]], {{.+}}) + // CHECK: %[[tuple:.+]] = {{.+}} tuple({{.+}} %[[dus]], {{.+}}) + // CHECK: {{.+}} = {{.+}} while({{.+}} %[[tuple]]) + // CHECK: } + )")); +} + TEST_F(CollectivePipelinerTest, UpdateSendRecvChannelIdForHostTransfers) { constexpr absl::string_view 
hlo_string = R"( HloModule module @@ -1186,6 +1338,164 @@ ENTRY entry { EXPECT_EQ(add_instr_loop->opcode(), HloOpcode::kAdd); } +TEST_F(CollectivePipelinerTest, + TransformIncrementIndexByOneBackwardsCollectivePermute) { + constexpr absl::string_view hlo_string = R"( +HloModule module +add { + lhs = bf16[] parameter(0) + rhs = bf16[] parameter(1) + ROOT add = bf16[] add(lhs, rhs) +} +while_cond { + param = (s32[], bf16[3,8,128], bf16[3,1,2,128]) parameter(0) + gte = s32[] get-tuple-element(param), index=0 + constant.1 = s32[] constant(14) + ROOT cmp = pred[] compare(gte, constant.1), direction=LT +} +while_body { + param = (s32[], bf16[3,8,128], bf16[3,1,2,128]) parameter(0) + get-tuple-element.394 = s32[] get-tuple-element(param), index=0 + get-tuple-element.395 = bf16[3,8,128] get-tuple-element(param), index=1 + get-tuple-element.k = bf16[3,1,2,128] get-tuple-element(param), index=2 + cp = bf16[3,8,128] collective-permute(get-tuple-element.395), channel_id=1, source_target_pairs={{0,1},{1,2},{2,3},{3,4},{4,5},{5,6},{6,7},{7,0}}, + frontend_attributes={_xla_send_recv_validation="{{0,6},{1,7},{2,8},{3,9},{4,10},{5,11},{6,12},{7,13}}"} + constant.2561 = s32[] constant(0) + constant.2557 = s32[] constant(1) + add.230 = s32[] add(get-tuple-element.394, constant.2557) + constant.2559 = s32[] constant(14) + subtract.139 = s32[] subtract(constant.2559, get-tuple-element.394) + constant.2560 = s32[] constant(-1) + add.231 = s32[] add(subtract.139, constant.2560) + compare.747 = pred[] compare(add.231, constant.2561), direction=LT + constant.2562 = s32[] constant(13) + add.232 = s32[] add(subtract.139, constant.2562) + select.1348 = s32[] select(compare.747, add.232, add.231) + dynamic-slice.k = bf16[1,1,2,128] dynamic-slice(get-tuple-element.k, select.1348, constant.2561, constant.2561, constant.2561), dynamic_slice_sizes={1,1,2,128} + r = bf16[1,2,128] reshape(dynamic-slice.k) + a = bf16[1,2,128] add(r, r), control-predecessors={constant.2559} + ag = bf16[1,8,128] 
all-gather(a), dimensions={1}, replica_groups={} + dynamic-slice.99 = bf16[1,8,128] dynamic-slice(cp, select.1348, constant.2561, constant.2561), dynamic_slice_sizes={1,8,128} + mul = bf16[1,8,128] multiply(dynamic-slice.99, ag) + ar.1 = bf16[1,8,128] all-reduce(mul), replica_groups={}, to_apply=add, channel_id=2 + dynamic-update-slice.35 = bf16[3,8,128] dynamic-update-slice(cp, ar.1, select.1348, constant.2561, constant.2561) + ROOT tuple = (s32[], bf16[3,8,128], bf16[3,1,2,128]) tuple(add.230, dynamic-update-slice.35, get-tuple-element.k), control-predecessors={a} +} +ENTRY entry { + c0 = s32[] constant(0) + p0 = bf16[3,8,128] parameter(0) + p1 = bf16[3,1,2,128] parameter(1) + tuple = (s32[], bf16[3,8,128], bf16[3,1,2,128]) tuple(c0, p0, p1) + while = (s32[], bf16[3,8,128], bf16[3,1,2,128]) while(tuple), condition=while_cond, body=while_body + ROOT gte1 = bf16[3,8,128] get-tuple-element(while), index=1 +} +)"; + config_.set_num_partitions(8); + config_.set_replica_count(4); + auto module = ParseAndReturnUnverifiedModule(hlo_string, config_).value(); + EXPECT_TRUE( + RunOptimizer( + module.get(), /*last_run=*/true, /*level_to_operate_on=*/0, + /*pipeline_use_tree=*/false, + /*process_different_sized_ops=*/false, + /*direction=*/CollectivePipeliner::PipeliningDirection::kBackward, + /*should_process=*/IsAllGather) + .value()); + XLA_VLOG_LINES(1, module->ToString()); + EXPECT_TRUE(*RunFileCheck(module->ToString(), R"( + // CHECK: %while_body + // CHECK: %[[cp:.+]] = {{.+}} collective-permute({{.+}}), {{.+}}_xla_send_recv_validation="{{[{]}}{0,6},{1,7},{2,8},{3,9},{4,10},{5,11},{6,12},{7,12}{{[}]}}"} + // CHECK: %[[dus:.+]] = {{.+}} dynamic-update-slice({{.+}} %[[cp]], {{.+}}) + // CHECK: ROOT {{.+}} = {{.+}} tuple({{.+}} %[[dus]], {{.+}}) + // CHECK: ENTRY %entry + // CHECK: %[[while:.+]] = {{.+}} while({{.+}}) + // CHECK: %[[gte:.+]] = {{.+}} get-tuple-element({{.+}} %[[while]]), index=1 + // CHECK: %[[cp2:.+]] = {{.+}} collective-permute({{.+}} %[[gte]]), 
{{.+}}_xla_send_recv_validation="{{[{]}}{1,0},{1,0},{1,0},{1,0},{1,0},{1,0},{1,0},{0,0}{{[}]}}" + // CHECK: %[[dus:.+]] = {{.+}} dynamic-update-slice({{.+}} %[[cp2]], {{.+}}) + // CHECK: %[[tuple:.+]] = {{.+}} tuple({{.+}} %[[dus]], {{.+}}) + // CHECK: ROOT {{.+}} = {{.+}} get-tuple-element({{.+}} %[[tuple]]), index=1 + )")); +} + +TEST_F(CollectivePipelinerTest, + TransformIncrementIndexByOneBackwardsCollectivePermuteBackwardCycle) { + constexpr absl::string_view hlo_string = R"( +HloModule module +add { + lhs = bf16[] parameter(0) + rhs = bf16[] parameter(1) + ROOT add = bf16[] add(lhs, rhs) +} +while_cond { + param = (s32[], bf16[3,8,128], bf16[3,1,2,128]) parameter(0) + gte = s32[] get-tuple-element(param), index=0 + constant.1 = s32[] constant(14) + ROOT cmp = pred[] compare(gte, constant.1), direction=LT +} +while_body { + param = (s32[], bf16[3,8,128], bf16[3,1,2,128]) parameter(0) + get-tuple-element.394 = s32[] get-tuple-element(param), index=0 + get-tuple-element.395 = bf16[3,8,128] get-tuple-element(param), index=1 + get-tuple-element.k = bf16[3,1,2,128] get-tuple-element(param), index=2 + cp = bf16[3,8,128] collective-permute(get-tuple-element.395), channel_id=1, source_target_pairs={{0,7},{1,0},{2,1},{3,2},{4,3},{5,4},{6,5},{7,6}}, + frontend_attributes={_xla_send_recv_validation="{{7,13},{6,12},{5,11},{4,10},{3,9},{2,8},{1,7},{0,6}}"} + constant.2561 = s32[] constant(0) + constant.2557 = s32[] constant(1) + add.230 = s32[] add(get-tuple-element.394, constant.2557) + constant.2559 = s32[] constant(14) + subtract.139 = s32[] subtract(constant.2559, get-tuple-element.394) + constant.2560 = s32[] constant(-1) + add.231 = s32[] add(subtract.139, constant.2560) + compare.747 = pred[] compare(add.231, constant.2561), direction=LT + constant.2562 = s32[] constant(13) + add.232 = s32[] add(subtract.139, constant.2562) + select.1348 = s32[] select(compare.747, add.232, add.231) + dynamic-slice.k = bf16[1,1,2,128] dynamic-slice(get-tuple-element.k, select.1348, 
constant.2561, constant.2561, constant.2561), dynamic_slice_sizes={1,1,2,128} + r = bf16[1,2,128] reshape(dynamic-slice.k) + a = bf16[1,2,128] add(r, r), control-predecessors={constant.2559} + ag = bf16[1,8,128] all-gather(a), dimensions={1}, replica_groups={} + dynamic-slice.99 = bf16[1,8,128] dynamic-slice(cp, select.1348, constant.2561, constant.2561), dynamic_slice_sizes={1,8,128} + mul = bf16[1,8,128] multiply(dynamic-slice.99, ag) + ar.1 = bf16[1,8,128] all-reduce(mul), replica_groups={}, to_apply=add, channel_id=2 + dynamic-update-slice.35 = bf16[3,8,128] dynamic-update-slice(cp, ar.1, select.1348, constant.2561, constant.2561) + ROOT tuple = (s32[], bf16[3,8,128], bf16[3,1,2,128]) tuple(add.230, dynamic-update-slice.35, get-tuple-element.k), control-predecessors={a} +} +ENTRY entry { + c0 = s32[] constant(0) + p0 = bf16[3,8,128] parameter(0) + p1 = bf16[3,1,2,128] parameter(1) + tuple = (s32[], bf16[3,8,128], bf16[3,1,2,128]) tuple(c0, p0, p1) + while = (s32[], bf16[3,8,128], bf16[3,1,2,128]) while(tuple), condition=while_cond, body=while_body + ROOT gte1 = bf16[3,8,128] get-tuple-element(while), index=1 +} +)"; + config_.set_num_partitions(8); + config_.set_replica_count(4); + auto module = ParseAndReturnUnverifiedModule(hlo_string, config_).value(); + EXPECT_TRUE( + RunOptimizer( + module.get(), /*last_run=*/true, /*level_to_operate_on=*/0, + /*pipeline_use_tree=*/false, + /*process_different_sized_ops=*/false, + /*direction=*/CollectivePipeliner::PipeliningDirection::kBackward, + /*should_process=*/IsAllGather) + .value()); + XLA_VLOG_LINES(1, module->ToString()); + EXPECT_TRUE(*RunFileCheck(module->ToString(), R"( + // CHECK: %while_body + // CHECK: %[[cp:.+]] = {{.+}} collective-permute({{.+}}), {{.+}}_xla_send_recv_validation="{{[{]}}{7,12},{6,12},{5,11},{4,10},{3,9},{2,8},{1,7},{0,6}{{[}]}}"} + // CHECK: %[[dus:.+]] = {{.+}} dynamic-update-slice({{.+}} %[[cp]], {{.+}}) + // CHECK: ROOT {{.+}} = {{.+}} tuple({{.+}} %[[dus]], {{.+}}) + // CHECK: ENTRY 
%entry + // CHECK: %[[while:.+]] = {{.+}} while({{.+}}) + // CHECK: %[[gte:.+]] = {{.+}} get-tuple-element({{.+}} %[[while]]), index=1 + // CHECK: %[[cp2:.+]] = {{.+}} collective-permute({{.+}} %[[gte]]), {{.+}}_xla_send_recv_validation="{{[{]}}{0,0},{1,0},{1,0},{1,0},{1,0},{1,0},{1,0},{1,0}{{[}]}}" + // CHECK: %[[dus:.+]] = {{.+}} dynamic-update-slice({{.+}} %[[cp2]], {{.+}}) + // CHECK: %[[tuple:.+]] = {{.+}} tuple({{.+}} %[[dus]], {{.+}}) + // CHECK: ROOT {{.+}} = {{.+}} get-tuple-element({{.+}} %[[tuple]]), index=1 + )")); +} + TEST_F(CollectivePipelinerTest, TransformIncrementIndexByOneBackwardsModifyOut) { constexpr absl::string_view hlo_string = R"( From 2952f336f14e1c1aca84d46cf3098a34d3254bf3 Mon Sep 17 00:00:00 2001 From: Quentin Khan Date: Thu, 20 Jun 2024 05:51:42 -0700 Subject: [PATCH 049/256] Fix SDPA testing on different devices. PiperOrigin-RevId: 644999007 --- tensorflow/lite/delegates/xnnpack/BUILD | 10 +++++----- tensorflow/lite/delegates/xnnpack/odml_sdpa_tester.cc | 10 ++++++---- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/tensorflow/lite/delegates/xnnpack/BUILD b/tensorflow/lite/delegates/xnnpack/BUILD index eaf105cb86e4b8..45b4edc791c1d8 100644 --- a/tensorflow/lite/delegates/xnnpack/BUILD +++ b/tensorflow/lite/delegates/xnnpack/BUILD @@ -596,11 +596,6 @@ cc_library( testonly = 1, srcs = ["odml_sdpa_tester.cc"], hdrs = ["odml_sdpa_tester.h"], - data = [ - "odml_sdpa_composite_gqa.tflite", - "odml_sdpa_composite_mha.tflite", - "odml_sdpa_composite_mqa.tflite", - ], deps = [ "//tensorflow/compiler/mlir/lite/schema:schema_conversion_utils", "//tensorflow/compiler/mlir/lite/schema:schema_fbs", @@ -1632,6 +1627,11 @@ cc_test( cc_test( name = "odml_sdpa_test", srcs = ["odml_sdpa_test.cc"], + data = [ + ":odml_sdpa_composite_gqa.tflite", + ":odml_sdpa_composite_mha.tflite", + ":odml_sdpa_composite_mqa.tflite", + ], linkopts = select({ "//tensorflow:emscripten": EMSCRIPTEN_LINKOPTS, "//conditions:default": [], diff --git 
a/tensorflow/lite/delegates/xnnpack/odml_sdpa_tester.cc b/tensorflow/lite/delegates/xnnpack/odml_sdpa_tester.cc index b840c889f27181..c7714a816bf953 100644 --- a/tensorflow/lite/delegates/xnnpack/odml_sdpa_tester.cc +++ b/tensorflow/lite/delegates/xnnpack/odml_sdpa_tester.cc @@ -119,11 +119,13 @@ void ODMLSDPATester::Test(TfLiteDelegate* delegate) const { std::vector ODMLSDPATester::CreateTfLiteModel() const { if (!model_name_.empty() && model_name_ != kOdmlSdpaCustom) { const char kTestModelFolder[] = - "/tensorflow/lite/delegates/xnnpack/"; - const std::string test_model = - testing::SrcDir() + kTestModelFolder + model_name_ + ".tflite"; + "third_party/tensorflow/lite/delegates/xnnpack/"; + const std::string test_model = kTestModelFolder + model_name_ + ".tflite"; std::string model_data; - flatbuffers::LoadFile(test_model.c_str(), /*binary=*/true, &model_data); + if (!flatbuffers::LoadFile(test_model.c_str(), /*binary=*/true, + &model_data)) { + ADD_FAILURE() << "file not loaded: " << test_model; + } return std::vector(model_data.begin(), model_data.end()); } else { flatbuffers::FlatBufferBuilder builder; From 350ecac073def5c3f12e8fe95b3957b307f787a2 Mon Sep 17 00:00:00 2001 From: Adrian Kuegel Date: Thu, 20 Jun 2024 05:54:05 -0700 Subject: [PATCH 050/256] Add unit tests for CanEmitFusedDynamicUpdateSliceInPlaceForGpu(). We used to test this indirectly via checking the emitted code (tests/dynamic_update_slice_inplace.hlo). It is better to test it in a unit test. Also delete dead code. 
PiperOrigin-RevId: 644999517 --- third_party/xla/xla/service/gpu/BUILD | 7 +- .../xla/xla/service/gpu/fusions/fusions.cc | 6 +- .../xla/xla/service/gpu/ir_emission_utils.cc | 67 +--- .../xla/xla/service/gpu/ir_emission_utils.h | 26 +- .../xla/service/gpu/ir_emission_utils_test.cc | 378 ++++++++++++++++++ .../xla/xla/service/gpu/priority_fusion.h | 1 + .../gpu/runtime/nccl_collective_thunk.cc | 5 - 7 files changed, 397 insertions(+), 93 deletions(-) diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index 509b11d8e609de..3b39b1f8363f56 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -1091,10 +1091,6 @@ cc_library( "@llvm-project//llvm:Core", "@llvm-project//llvm:Support", "@llvm-project//llvm:TargetParser", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:MemRefDialect", - "@llvm-project//mlir:SideEffectInterfaces", - "@llvm-project//mlir:Support", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:statusor", ], @@ -1110,8 +1106,11 @@ xla_cc_test( "//xla:literal_util", "//xla:types", "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/service:buffer_assignment", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", # fixdeps: keep + "@local_tsl//tsl/platform:status_matchers", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", ], diff --git a/third_party/xla/xla/service/gpu/fusions/fusions.cc b/third_party/xla/xla/service/gpu/fusions/fusions.cc index 6a241da5af54fe..52ddb7246534f0 100644 --- a/third_party/xla/xla/service/gpu/fusions/fusions.cc +++ b/third_party/xla/xla/service/gpu/fusions/fusions.cc @@ -95,7 +95,11 @@ std::optional> HloFusionInfo::GetCopyFusion() bool HloFusionInfo::CanEmitDynamicUpdateSliceInPlace() const { auto ret = CanEmitFusedDynamicUpdateSliceInPlaceForGpu( - instr_, buffer_assignment_, analysis().fusion_roots()); + instr_, + [this](const HloInstruction* instruction, const ShapeIndex& index) { + return 
GetAllocationSlice(*buffer_assignment_, instruction, index); + }, + analysis().fusion_roots()); return ret.ok() && *ret; } diff --git a/third_party/xla/xla/service/gpu/ir_emission_utils.cc b/third_party/xla/xla/service/gpu/ir_emission_utils.cc index 5b73e28aa8f0c6..d740e6a14a8247 100644 --- a/third_party/xla/xla/service/gpu/ir_emission_utils.cc +++ b/third_party/xla/xla/service/gpu/ir_emission_utils.cc @@ -43,15 +43,6 @@ limitations under the License. #include "llvm/IR/Verifier.h" #include "llvm/Support/raw_ostream.h" #include "llvm/TargetParser/Triple.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" // from @llvm-project -#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project -#include "mlir/IR/BuiltinTypes.h" // from @llvm-project -#include "mlir/IR/Location.h" // from @llvm-project -#include "mlir/IR/Operation.h" // from @llvm-project -#include "mlir/IR/Value.h" // from @llvm-project -#include "mlir/IR/Visitors.h" // from @llvm-project -#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project -#include "mlir/Support/LLVM.h" // from @llvm-project #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -90,22 +81,6 @@ bool IsRank1(const Shape& shape, int64_t batch_dimensions_size) { return shape.rank() == batch_dimensions_size + 1; } -Shape GetShapeFromTensorType(mlir::Value value) { - constexpr char kDefaultLayoutAttrName[] = "xla_shape"; - - mlir::Operation* op = value.getDefiningOp(); - CHECK(op); - CHECK(mlir::isa(value.getType())); - Shape shape; - if (auto attr = op->getAttrOfType(kDefaultLayoutAttrName)) { - shape = *xla::ParseShape( - absl::string_view(attr.getValue().data(), attr.getValue().size())); - } else { - shape = TypeToShape(value.getType()); - } - return shape; -} - } // namespace bool IsMatrixMultiplication(const HloInstruction& dot) { @@ -365,16 +340,6 @@ llvm::Value* IsBlock0Thread0(llvm::IRBuilder<>* b) { return b->CreateAnd(is_thread0, is_block0); } -bool 
WritesMlirBuffer(mlir::Operation* op, mlir::Value operand) { - llvm::SmallVector effects; - mlir::cast(op).getEffectsOnValue(operand, - effects); - return absl::c_any_of( - effects, [](const mlir::MemoryEffects::EffectInstance& instance) { - return mlir::isa(instance.getEffect()); - }); -} - absl::StatusOr GetAllocationSlice( const BufferAssignment& buffer_assignment, const HloInstruction* instr, const ShapeIndex& index) { @@ -409,7 +374,9 @@ absl::InlinedVector GetStartIndices(T instr) { absl::StatusOr CanEmitFusedDynamicUpdateSliceInPlaceForGpu( const HloFusionInstruction* fusion, - const BufferAssignment* buffer_assignment, + std::function( + const HloInstruction* instr, const ShapeIndex& index)> + get_allocation_slice, absl::Span roots) { std::vector dus_instrs = GetOutputDefiningDynamicUpdateSlices(roots); @@ -420,7 +387,7 @@ absl::StatusOr CanEmitFusedDynamicUpdateSliceInPlaceForGpu( fusion->shape(), [&](const Shape& shape, const ShapeIndex index) { if (shape.IsArray()) { TF_ASSIGN_OR_RETURN(BufferAllocation::Slice buffer, - buffer_assignment->GetUniqueSlice(fusion, index)); + get_allocation_slice(fusion, index)); output_buffers.push_back(buffer); } return absl::OkStatus(); @@ -545,7 +512,7 @@ absl::StatusOr CanEmitFusedDynamicUpdateSliceInPlaceForGpu( const HloInstruction* lhs = fusion->operand(parameter->parameter_number()); TF_ASSIGN_OR_RETURN(BufferAllocation::Slice lhs_buffer, - buffer_assignment->GetUniqueSlice(lhs, {})); + get_allocation_slice(lhs, {})); BufferAllocation::Slice rhs_buffer = output_buffers[i]; if (lhs_buffer != rhs_buffer) { return false; @@ -555,25 +522,6 @@ absl::StatusOr CanEmitFusedDynamicUpdateSliceInPlaceForGpu( return true; } -Shape GetShape(mlir::Value value) { - Shape shape; - if (mlir::isa(value.getType())) { - shape = TypeToShape(value.getType()); - } else if (mlir::isa(value.getType())) { - shape = GetShapeFromTensorType(value); - } else if (mlir::isa(value.getType())) { - shape = TypeToShape(value.getType()); - } else { - 
LOG(FATAL) << "Unexpected value type to get shape for"; - } - if (primitive_util::IsSubByteNonPredType(shape.element_type())) { - // 4-bit types are always packed on the GPU - shape.mutable_layout()->set_element_size_in_bits( - primitive_util::BitWidth(shape.element_type())); - } - return shape; -} - static std::optional FindTiledTranspose( const HloInstruction& instr) { if (instr.opcode() != HloOpcode::kCopy) { @@ -833,11 +781,6 @@ llvm::Type* GetIndexTypeForKernel(const HloInstruction* hlo, return b->getInt32Ty(); } -std::string GetIrNameFromLoc(mlir::Location loc) { - return llvm_ir::SanitizeConstantName( - mlir::mhlo::GetDebugNameFromLocation(loc)); -} - bool IsAMDGPU(const llvm::Module* module) { return llvm::Triple(module->getTargetTriple()).isAMDGPU(); } diff --git a/third_party/xla/xla/service/gpu/ir_emission_utils.h b/third_party/xla/xla/service/gpu/ir_emission_utils.h index b9ad5e304f7c1f..3fcc31e65df47d 100644 --- a/third_party/xla/xla/service/gpu/ir_emission_utils.h +++ b/third_party/xla/xla/service/gpu/ir_emission_utils.h @@ -29,8 +29,6 @@ limitations under the License. #include "llvm/ADT/SmallVector.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Value.h" -#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project -#include "mlir/IR/Value.h" // from @llvm-project #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/literal.h" @@ -122,22 +120,17 @@ llvm::Value* EmitFullWarpShuffleDown( // block 0 of the kernel. 
llvm::Value* IsBlock0Thread0(llvm::IRBuilder<>* b); -llvm::SmallVector GetHloOperands(mlir::Operation* op); -llvm::SmallVector GetHloOutputs(mlir::Operation* op); - -bool WritesMlirBuffer(mlir::Operation* op, mlir::Value operand); - -absl::StatusOr GetAllocationSlice( - mlir::Value v, absl::Span allocations, - std::string* constant_name = nullptr); - absl::StatusOr GetAllocationSlice( const BufferAssignment& buffer_assignment, const HloInstruction* instr, const ShapeIndex& index); +// Returns whether 'fusion' can be emitted with the dynamic update slice +// in-place emitter. absl::StatusOr CanEmitFusedDynamicUpdateSliceInPlaceForGpu( const HloFusionInstruction* fusion, - const BufferAssignment* buffer_assignment, + std::function( + const HloInstruction* instr, const ShapeIndex& index)> + get_allocation_slice, absl::Span roots); // Returns the dynamic-update-slice instructions defining the results of a @@ -148,8 +141,6 @@ absl::StatusOr CanEmitFusedDynamicUpdateSliceInPlaceForGpu( std::vector GetOutputDefiningDynamicUpdateSlices( absl::Span roots); -Shape GetShape(mlir::Value value); - // Returns the first hero instruction reachable from `instr` as root. Hero // instruction can be in a different computation if the parent HloFusionAdaptor // is a producer-consumer fusion. @@ -209,13 +200,6 @@ void VerifyModule(const llvm::Module& module); llvm::Type* GetIndexTypeForKernel(const HloInstruction* hlo, int64_t launch_size, llvm::IRBuilder<>* b); -// The same as GetIndexTypeForKernel, but works with MLIR ops. -llvm::Type* GetIndexTypeForKernel(mlir::Operation* op, int64_t launch_size, - llvm::IRBuilder<>* b); - -// Returns a sanitized (doesn't need quoting) identifier name from a location. -std::string GetIrNameFromLoc(mlir::Location loc); - // Whether the module's target is an AMD GPU. 
bool IsAMDGPU(const llvm::Module* module); diff --git a/third_party/xla/xla/service/gpu/ir_emission_utils_test.cc b/third_party/xla/xla/service/gpu/ir_emission_utils_test.cc index 37af036f42ece5..4f6abda1058a8a 100644 --- a/third_party/xla/xla/service/gpu/ir_emission_utils_test.cc +++ b/third_party/xla/xla/service/gpu/ir_emission_utils_test.cc @@ -19,18 +19,23 @@ limitations under the License. #include #include +#include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/literal.h" #include "xla/literal_util.h" +#include "xla/service/buffer_assignment.h" #include "xla/service/gpu/hlo_traversal.h" #include "xla/tests/hlo_test_base.h" #include "xla/types.h" #include "xla/util.h" +#include "tsl/platform/status_matchers.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" namespace xla { namespace gpu { +using ::tsl::testing::IsOkAndHolds; + class IrEmissionUtilsTest : public HloTestBase {}; TEST_F(IrEmissionUtilsTest, FindTiledLogicalTranspose) { @@ -654,5 +659,378 @@ TEST_F(IrEmissionUtilsTest, LiteralToAttrToXlaFormat) { } } +TEST_F(IrEmissionUtilsTest, + CanEmitFusedDynamicUpdateSliceInPlaceForGpu_HandlesBitcasts) { + const char* hlo = R"( +HloModule fusion, is_scheduled=true + +fused_computation { + param_0.1 = s32[6]{0} parameter(0) + bitcast = s32[2,3]{1,0} bitcast(param_0.1) + zero = s32[] constant(0) + param_1.1 = s32[] parameter(1) + dynamic-slice = s32[1,1]{1,0} dynamic-slice(bitcast, param_1.1, zero), dynamic_slice_sizes={1,1} + one = s32[] constant(1) + bitcasted_one = s32[1,1]{1,0} bitcast(one) + add = s32[1,1] add(dynamic-slice, bitcasted_one) + dynamic-update-slice = s32[2,3]{1,0} dynamic-update-slice(bitcast, add, param_1.1, zero) + ROOT bitcast.1 = s32[6]{0} bitcast(dynamic-update-slice) +} + +ENTRY main { + param_0 = s32[6]{0} parameter(0) + param_1 = s32[] parameter(1) + ROOT fusion = s32[6]{0} fusion(param_0, param_1), kind=kInput, calls=fused_computation +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + 
ParseAndReturnVerifiedModule(hlo)); + auto fusion = module->entry_computation()->root_instruction(); + BufferAllocation alloc(/*index=*/0, /*size=*/1024, /*color=*/0); + BufferAllocation::Slice slice0(&alloc, 0, 10); + EXPECT_THAT(CanEmitFusedDynamicUpdateSliceInPlaceForGpu( + Cast(fusion), + [&slice0](const HloInstruction*, const ShapeIndex&) { + return slice0; + }, + HloFusionAdaptor::ForInstruction(fusion)->GetRoots()), + IsOkAndHolds(true)); +} + +TEST_F( + IrEmissionUtilsTest, + CanEmitFusedDynamicUpdateSliceInPlaceForGpu_ElementwiseOnPathToParameter) { + const char* hlo = R"( +HloModule fusion, is_scheduled=true + +fused_computation { + param_0.1 = s32[2,3]{1,0} parameter(0) + bitcast = s32[2,3]{1,0} negate(param_0.1) + zero = s32[] constant(0) + param_1.1 = s32[] parameter(1) + dynamic-slice = s32[1,1]{1,0} dynamic-slice(bitcast, param_1.1, zero), dynamic_slice_sizes={1,1} + one = s32[] constant(1) + bitcasted_one = s32[1,1]{1,0} bitcast(one) + add = s32[1,1] add(dynamic-slice, bitcasted_one) + dynamic-update-slice = s32[2,3]{1,0} dynamic-update-slice(bitcast, add, param_1.1, zero) + ROOT bitcast.1 = s32[6]{0} bitcast(dynamic-update-slice) +} + +ENTRY main { + param_0 = s32[2,3]{1,0} parameter(0) + param_1 = s32[] parameter(1) + ROOT fusion = s32[6]{0} fusion(param_0, param_1), kind=kInput, calls=fused_computation +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo)); + auto fusion = module->entry_computation()->root_instruction(); + BufferAllocation alloc(/*index=*/0, /*size=*/1024, /*color=*/0); + BufferAllocation::Slice slice0(&alloc, 0, 10); + EXPECT_THAT(CanEmitFusedDynamicUpdateSliceInPlaceForGpu( + Cast(fusion), + [&slice0](const HloInstruction*, const ShapeIndex&) { + return slice0; + }, + HloFusionAdaptor::ForInstruction(fusion)->GetRoots()), + IsOkAndHolds(false)); +} + +// Same test as above, but different allocation slices for parameter and output. 
+TEST_F(IrEmissionUtilsTest, + CanEmitFusedDynamicUpdateSliceInPlaceForGpu_SlicesDifferent) { + const char* hlo = R"( +HloModule fusion, is_scheduled=true + +fused_computation { + param_0.1 = s32[6]{0} parameter(0) + bitcast = s32[2,3]{1,0} bitcast(param_0.1) + zero = s32[] constant(0) + param_1.1 = s32[] parameter(1) + dynamic-slice = s32[1,1]{1,0} dynamic-slice(bitcast, param_1.1, zero), dynamic_slice_sizes={1,1} + one = s32[] constant(1) + bitcasted_one = s32[1,1]{1,0} bitcast(one) + add = s32[1,1] add(dynamic-slice, bitcasted_one) + dynamic-update-slice = s32[2,3]{1,0} dynamic-update-slice(bitcast, add, param_1.1, zero) + ROOT bitcast.1 = s32[6]{0} bitcast(dynamic-update-slice) +} + +ENTRY main { + param_0 = s32[6]{0} parameter(0) + param_1 = s32[] parameter(1) + ROOT fusion = s32[6]{0} fusion(param_0, param_1), kind=kInput, calls=fused_computation +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo)); + auto fusion = module->entry_computation()->root_instruction(); + BufferAllocation alloc(/*index=*/0, /*size=*/1024, /*color=*/0); + BufferAllocation::Slice slice0(&alloc, 0, 10); + BufferAllocation::Slice slice1(&alloc, 10, 20); + EXPECT_THAT(CanEmitFusedDynamicUpdateSliceInPlaceForGpu( + Cast(fusion), + [fusion, &slice0, &slice1](const HloInstruction* instr, + const ShapeIndex&) { + if (instr == fusion) { + return slice0; + } + return slice1; + }, + HloFusionAdaptor::ForInstruction(fusion)->GetRoots()), + IsOkAndHolds(false)); +} + +TEST_F( + IrEmissionUtilsTest, + CanEmitFusedDynamicUpdateSliceInPlaceForGpu_DynamicUpdateSliceWithDifferentDynamicSliceAccess) { // NOLINT + const char* hlo = R"( +HloModule fusion, input_output_alias={ {}: (0, {}) } + +fused_computation { + param_0.1 = s32[6]{0} parameter(0) + bitcast = s32[2,3]{1,0} bitcast(param_0.1) + zero = s32[] constant(0) + one = s32[] constant(1) + param_1.1 = s32[] parameter(1) + dynamic-slice = s32[2,2]{1,0} dynamic-slice(bitcast, param_1.1, one), 
dynamic_slice_sizes={2,2} + broadcasted_one = s32[2,2]{1,0} broadcast(one), dimensions={} + add = s32[2,2] add(dynamic-slice, broadcasted_one) + dynamic-update-slice = s32[2,3]{1,0} dynamic-update-slice(bitcast, add, param_1.1, zero) + ROOT bitcast.1 = s32[6]{0} bitcast(dynamic-update-slice) +} + +ENTRY main { + param_0 = s32[6]{0} parameter(0) + param_1 = s32[] parameter(1) + ROOT fusion = s32[6]{0} fusion(param_0, param_1), kind=kInput, calls=fused_computation +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo)); + auto fusion = module->entry_computation()->root_instruction(); + BufferAllocation alloc(/*index=*/0, /*size=*/1024, /*color=*/0); + BufferAllocation::Slice slice0(&alloc, 0, 10); + EXPECT_THAT(CanEmitFusedDynamicUpdateSliceInPlaceForGpu( + Cast(fusion), + [&slice0](const HloInstruction*, const ShapeIndex&) { + return slice0; + }, + HloFusionAdaptor::ForInstruction(fusion)->GetRoots()), + IsOkAndHolds(false)); +} + +TEST_F(IrEmissionUtilsTest, + CanEmitFusedDynamicUpdateSliceInPlaceForGpu_HandlesMultiOutputFusion) { + const char* hlo = R"( +HloModule MultipleInplaceDus, is_scheduled=true, input_output_alias={ {0}: (0, {}), {1}: (2, {}) } + +fused_computation { + p0 = bf16[10,11,12] parameter(0) + p1 = bf16[1,11,12] parameter(1) + p2 = bf16[8,11,12] parameter(2) + p3 = bf16[1,11,12] parameter(3) + p4 = s32[] parameter(4) + c0 = s32[] constant(0) + cmp = pred[] compare(p4, c0), direction=EQ + broadcast = pred[1,11,12] broadcast(cmp), dimensions={} + select = bf16[1,11,12] select(broadcast, p1, p3) + dus0 = bf16[10,11,12] dynamic-update-slice(p0, select, c0, c0, c0) + dus1 = bf16[8,11,12] dynamic-update-slice(p2, select, c0, c0, c0) + ROOT tuple = (bf16[10,11,12], bf16[8,11,12]) tuple(dus0, dus1) +} + +ENTRY main { + p0 = bf16[10,11,12] parameter(0) + p1 = bf16[1,11,12] parameter(1) + p2 = bf16[8,11,12] parameter(2) + p3 = bf16[1,11,12] parameter(3) + p4 = s32[] parameter(4) + ROOT fusion_root_multiple = 
(bf16[10,11,12], bf16[8,11,12]) fusion(p0, p1, p2, p3, p4), kind=kLoop, calls=fused_computation +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo)); + auto fusion = module->entry_computation()->root_instruction(); + BufferAllocation alloc(/*index=*/0, /*size=*/1024, /*color=*/0); + BufferAllocation::Slice slice0(&alloc, 0, 10); + EXPECT_THAT(CanEmitFusedDynamicUpdateSliceInPlaceForGpu( + Cast(fusion), + [&slice0](const HloInstruction*, const ShapeIndex&) { + return slice0; + }, + HloFusionAdaptor::ForInstruction(fusion)->GetRoots()), + IsOkAndHolds(true)); +} + +TEST_F( + IrEmissionUtilsTest, + CanEmitFusedDynamicUpdateSliceInPlaceForGpu_HandlesMultiOutputFusionWithTransposeBitcasts) { // NOLINT + const char* hlo = R"( +HloModule MultipleInplaceDusWithTransposeBitcastToTheRoot, is_scheduled=true, input_output_alias={ {0}: (0, {}), {1}: (2, {}) } + +fused_computation { + p0 = bf16[10,11,12] parameter(0) + p1 = bf16[1,11,12] parameter(1) + p2 = bf16[8,11,12] parameter(2) + p3 = bf16[1,11,12] parameter(3) + p4 = s32[] parameter(4) + c0 = s32[] constant(0) + cmp = pred[] compare(p4, c0), direction=EQ + broadcast = pred[1,11,12] broadcast(cmp), dimensions={} + select = bf16[1,11,12] select(broadcast, p1, p3) + dus0 = bf16[10,11,12] dynamic-update-slice(p0, select, c0, c0, c0) + bitcasted_dus0 = bf16[11,10,12] bitcast(dus0) + dus1 = bf16[8,11,12] dynamic-update-slice(p2, select, c0, c0, c0) + ROOT tuple = (bf16[11,10,12], bf16[8,11,12]) tuple(bitcasted_dus0, dus1) +} + +ENTRY main { + p0 = bf16[10,11,12] parameter(0) + p1 = bf16[1,11,12] parameter(1) + p2 = bf16[8,11,12] parameter(2) + p3 = bf16[1,11,12] parameter(3) + p4 = s32[] parameter(4) + ROOT fusion_root_multiple_transpose_bitcast = (bf16[11,10,12], bf16[8,11,12]) fusion(p0, p1, p2, p3, p4), kind=kLoop, calls=fused_computation +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo)); + auto fusion = 
module->entry_computation()->root_instruction(); + BufferAllocation alloc(/*index=*/0, /*size=*/1024, /*color=*/0); + BufferAllocation::Slice slice0(&alloc, 0, 10); + EXPECT_THAT(CanEmitFusedDynamicUpdateSliceInPlaceForGpu( + Cast(fusion), + [&slice0](const HloInstruction*, const ShapeIndex&) { + return slice0; + }, + HloFusionAdaptor::ForInstruction(fusion)->GetRoots()), + IsOkAndHolds(true)); +} + +TEST_F( + IrEmissionUtilsTest, + CanEmitFusedDynamicUpdateSliceInPlaceForGpu_HandlesTransposeBitcastToTheRoot) { // NOLINT + const char* hlo = R"( +HloModule SingleInplaceDusWithTransposeBitcastToTheRoot, is_scheduled=true, input_output_alias={ {}: (0, {}) } + +single_inplace_dus_with_transpose_bitcast { + p0 = bf16[10,11,12] parameter(0) + p1 = bf16[1,11,12] parameter(1) + p2 = bf16[1,11,12] parameter(2) + p3 = s32[] parameter(3) + c0 = s32[] constant(0) + cmp = pred[] compare(p3, c0), direction=EQ + broadcast = pred[1,11,12] broadcast(cmp), dimensions={} + select = bf16[1,11,12] select(broadcast, p1, p2) + dus0 = bf16[10,11,12] dynamic-update-slice(p0, select, c0, c0, c0) + ROOT bitcasted_dus0 = bf16[11,10,12] bitcast(dus0) +} + +ENTRY main { + p0 = bf16[10,11,12] parameter(0) + p1 = bf16[1,11,12] parameter(1) + p2 = bf16[1,11,12] parameter(2) + p3 = s32[] parameter(3) + ROOT fusion_root_transpose_bitcast = bf16[11,10,12] fusion(p0, p1, p2, p3), kind=kLoop, calls=single_inplace_dus_with_transpose_bitcast +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo)); + auto fusion = module->entry_computation()->root_instruction(); + BufferAllocation alloc(/*index=*/0, /*size=*/1024, /*color=*/0); + BufferAllocation::Slice slice0(&alloc, 0, 10); + EXPECT_THAT(CanEmitFusedDynamicUpdateSliceInPlaceForGpu( + Cast(fusion), + [&slice0](const HloInstruction*, const ShapeIndex&) { + return slice0; + }, + HloFusionAdaptor::ForInstruction(fusion)->GetRoots()), + IsOkAndHolds(true)); +} + +TEST_F( + IrEmissionUtilsTest, + 
CanEmitFusedDynamicUpdateSliceInPlaceForGpu_HandlesReshapeBitcastToTheRoot) { // NOLINT + const char* hlo = R"( +HloModule SingleInplaceDusWithReshapeBitcastToTheRoot, is_scheduled=true, input_output_alias={ {}: (0, {}) } + +single_inplace_dus_with_reshape_bitcast { + p0 = bf16[10,11,12] parameter(0) + p1 = bf16[1,11,12] parameter(1) + p2 = bf16[1,11,12] parameter(2) + p3 = s32[] parameter(3) + c0 = s32[] constant(0) + cmp = pred[] compare(p3, c0), direction=EQ + broadcast = pred[1,11,12] broadcast(cmp), dimensions={} + select = bf16[1,11,12] select(broadcast, p1, p2) + dus0 = bf16[10,11,12] dynamic-update-slice(p0, select, c0, c0, c0) + ROOT bitcasted_dus0 = bf16[10,11,6,2] bitcast(dus0) +} + +ENTRY main { + p0 = bf16[10,11,12] parameter(0) + p1 = bf16[1,11,12] parameter(1) + p2 = bf16[1,11,12] parameter(2) + p3 = s32[] parameter(3) + ROOT fusion_root_reshape_bitcast = bf16[10,11,6,2] fusion(p0, p1, p2, p3), kind=kLoop, calls=single_inplace_dus_with_reshape_bitcast +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo)); + auto fusion = module->entry_computation()->root_instruction(); + BufferAllocation alloc(/*index=*/0, /*size=*/1024, /*color=*/0); + BufferAllocation::Slice slice0(&alloc, 0, 10); + EXPECT_THAT(CanEmitFusedDynamicUpdateSliceInPlaceForGpu( + Cast(fusion), + [&slice0](const HloInstruction*, const ShapeIndex&) { + return slice0; + }, + HloFusionAdaptor::ForInstruction(fusion)->GetRoots()), + IsOkAndHolds(true)); +} + +TEST_F( + IrEmissionUtilsTest, + CanEmitFusedDynamicUpdateSliceInPlaceForGpu_HandlesBitcastToTheRootAndFromParameter) { // NOLINT + const char* hlo = R"( +HloModule SingleInplaceDusWithBitcastToTheRootAndFromTheParameter, is_scheduled=true, input_output_alias={ {}: (0, {}) } + +single_inplace_dus_with_bitcast_to_the_root_and_from_the_parameter { + p0 = bf16[10,11,12] parameter(0) + p1 = bf16[1,11,12] parameter(1) + p2 = bf16[1,11,12] parameter(2) + p3 = s32[] parameter(3) + c0 = s32[] constant(0) 
+ cmp = pred[] compare(p3, c0), direction=EQ + broadcast = pred[1,11,12] broadcast(cmp), dimensions={} + select = bf16[1,11,12] select(broadcast, p1, p2) + bitcasted_p0 = bf16[10,6,2,11] bitcast(p0) + bitcasted_select = bf16[1,6,2,11] bitcast(select) + dus0 = bf16[10,6,2,11] dynamic-update-slice(bitcasted_p0, bitcasted_select, c0, c0, c0, c0) + ROOT bitcasted_dus0 = bf16[10,11,6,2] bitcast(dus0) +} + +ENTRY main { + p0 = bf16[10,11,12] parameter(0) + p1 = bf16[1,11,12] parameter(1) + p2 = bf16[1,11,12] parameter(2) + p3 = s32[] parameter(3) + ROOT fusion_root_bitcast_both_ways = bf16[10,11,6,2] fusion(p0, p1, p2, p3), kind=kLoop, calls=single_inplace_dus_with_bitcast_to_the_root_and_from_the_parameter +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo)); + auto fusion = module->entry_computation()->root_instruction(); + BufferAllocation alloc(/*index=*/0, /*size=*/1024, /*color=*/0); + BufferAllocation::Slice slice0(&alloc, 0, 10); + EXPECT_THAT(CanEmitFusedDynamicUpdateSliceInPlaceForGpu( + Cast(fusion), + [&slice0](const HloInstruction*, const ShapeIndex&) { + return slice0; + }, + HloFusionAdaptor::ForInstruction(fusion)->GetRoots()), + IsOkAndHolds(true)); +} + } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/priority_fusion.h b/third_party/xla/xla/service/gpu/priority_fusion.h index 79e2eda54d82fd..61e0d8ff208102 100644 --- a/third_party/xla/xla/service/gpu/priority_fusion.h +++ b/third_party/xla/xla/service/gpu/priority_fusion.h @@ -24,6 +24,7 @@ limitations under the License. 
#include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "mlir/IR/MLIRContext.h" // from @llvm-project #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" diff --git a/third_party/xla/xla/service/gpu/runtime/nccl_collective_thunk.cc b/third_party/xla/xla/service/gpu/runtime/nccl_collective_thunk.cc index ce46c155e4c35c..0c2ab26a1f23a2 100644 --- a/third_party/xla/xla/service/gpu/runtime/nccl_collective_thunk.cc +++ b/third_party/xla/xla/service/gpu/runtime/nccl_collective_thunk.cc @@ -523,11 +523,6 @@ absl::Status NcclCollectiveDoneThunk::ExecuteOnStream( return params.stream->WaitFor(event); } -absl::Status IsValidOperand(mlir::Value operand, Thunk::Kind reduction_op) { - Shape shape = GetShape(operand); - return IsValidOperand(shape, reduction_op); -} - absl::Status IsValidOperand(Shape shape, Thunk::Kind reduction_op) { if (!LayoutUtil::IsDenseArray(shape)) { return absl::AbortedError( From 4ddfa6ab61167e9c995bb2d4b3547006483003f4 Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Thu, 20 Jun 2024 06:06:13 -0700 Subject: [PATCH 051/256] Integrate LLVM at llvm/llvm-project@e5b0c210cc4c Updates LLVM usage to match [e5b0c210cc4c](https://github.com/llvm/llvm-project/commit/e5b0c210cc4c) PiperOrigin-RevId: 645002700 --- third_party/llvm/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl index 002c9c6f710444..18af3c86dfba84 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" - LLVM_COMMIT = "99c43e3ce3142a93bbad4f9efeace254d9a8442c" - LLVM_SHA256 = "40440e956c5a1b73373311a746a55bed9e719ceaa01f2cdc63ed116d5c01c438" + LLVM_COMMIT = "e5b0c210cc4cdaae7075ad2d4aa1efe4eb4cb0c5" + LLVM_SHA256 = "40440422a7e5d0fec35d6b542f4aa5e73af304b029e59dc5516c994696086a70" tf_http_archive( name = name, From c80f6733be89d908137709d8acbc798860539e55 Mon Sep 17 00:00:00 2001 From: George Karpenkov Date: Thu, 20 Jun 2024 06:21:01 -0700 Subject: [PATCH 052/256] Reverts 4347a69f8985f9777fc9b92a02c86d6a5e23f737 PiperOrigin-RevId: 645006266 --- third_party/xla/xla/pjrt/exceptions.h | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/third_party/xla/xla/pjrt/exceptions.h b/third_party/xla/xla/pjrt/exceptions.h index 6a5865f3cce0ce..f2845893c8938d 100644 --- a/third_party/xla/xla/pjrt/exceptions.h +++ b/third_party/xla/xla/pjrt/exceptions.h @@ -26,6 +26,7 @@ limitations under the License. 
#include "absl/log/check.h" #include "absl/status/status.h" #include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" namespace xla { @@ -53,8 +54,8 @@ class XlaRuntimeError : public std::runtime_error { } static bool ShowStackTraces() { - if (char* value = getenv("JAX_TRACEBACK_FILTERING")) { - return strcmp(value, "off"); + if (char* env = getenv("JAX_TRACEBACK_FILTERING")) { + return absl::string_view(env) == "off"; } return false; } From 2c0f46d735bb1a031e865b1ce6443a1db4919008 Mon Sep 17 00:00:00 2001 From: George Karpenkov Date: Thu, 20 Jun 2024 06:27:04 -0700 Subject: [PATCH 053/256] [XLA] Fix up the behavior for grabbing extra streams. Previous lifetimes weren't correct. PiperOrigin-RevId: 645007804 --- third_party/xla/xla/service/gpu/gpu_executable.cc | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/third_party/xla/xla/service/gpu/gpu_executable.cc b/third_party/xla/xla/service/gpu/gpu_executable.cc index f184672150853f..196bc02964b5bd 100644 --- a/third_party/xla/xla/service/gpu/gpu_executable.cc +++ b/third_party/xla/xla/service/gpu/gpu_executable.cc @@ -382,9 +382,11 @@ absl::Status ExecuteThunks( absl::InlinedVector async_comms_streams( kAsyncStreamTotal, nullptr); se::Stream* command_buffer_trace_stream = nullptr; + std::vector async_comms_streams_ownr; + StreamPool::Ptr borrowed_command_buffer_trace_stream; if (run_options->HasStreamBorrower()) { TF_ASSIGN_OR_RETURN( - std::vector async_comms_streams_ownr, + async_comms_streams_ownr, run_options->BorrowStreams(executor->device_ordinal(), kAsyncStreamTotal, stream_priority)); for (int64_t i = 0; i < kAsyncStreamTotal; ++i) { @@ -392,7 +394,7 @@ absl::Status ExecuteThunks( } // Borrow stream for tracing command buffers. 
- TF_ASSIGN_OR_RETURN(StreamPool::Ptr borrowed_command_buffer_trace_stream, + TF_ASSIGN_OR_RETURN(borrowed_command_buffer_trace_stream, run_options->BorrowStream(executor->device_ordinal())); command_buffer_trace_stream = borrowed_command_buffer_trace_stream.get(); } From 453db2224f8340db89f6e95ac5bad48697207870 Mon Sep 17 00:00:00 2001 From: Kyle Lucke Date: Thu, 20 Jun 2024 07:50:12 -0700 Subject: [PATCH 054/256] Stop using xla/statusor.h now that it just contains an alias for absl::Status. In some situations, this meant also changing unrelated files to directly include tsl/platform/statusor.h to get the definitions for TF_ASSIGN_OR_RETURN, etc., where they were getting transitively included for free. PiperOrigin-RevId: 645029366 --- third_party/xla/xla/python/ifrt/BUILD | 3 +-- third_party/xla/xla/python/ifrt/host_callback.h | 2 +- third_party/xla/xla/python/ifrt/shape.cc | 2 +- third_party/xla/xla/python/ifrt/test_util.cc | 2 +- 4 files changed, 4 insertions(+), 5 deletions(-) diff --git a/third_party/xla/xla/python/ifrt/BUILD b/third_party/xla/xla/python/ifrt/BUILD index 639772dc013eae..86b06fc97fffd6 100644 --- a/third_party/xla/xla/python/ifrt/BUILD +++ b/third_party/xla/xla/python/ifrt/BUILD @@ -90,7 +90,6 @@ cc_library( ":sharding_proto_cc", "//xla:shape_util", "//xla:status_macros", - "//xla:statusor", "//xla:util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", @@ -220,8 +219,8 @@ cc_library( hdrs = ["test_util.h"], deps = [ ":ifrt", - "//xla:statusor", "//xla/tsl/concurrency:ref_count", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", diff --git a/third_party/xla/xla/python/ifrt/host_callback.h b/third_party/xla/xla/python/ifrt/host_callback.h index 3c36237a22908d..32e744c03caf6d 100644 --- a/third_party/xla/xla/python/ifrt/host_callback.h +++ b/third_party/xla/xla/python/ifrt/host_callback.h @@ -18,8 +18,8 @@ limitations under the License. 
#include +#include "absl/status/statusor.h" #include "llvm/Support/ExtensibleRTTI.h" -#include "xla/statusor.h" #include "xla/tsl/concurrency/ref_count.h" namespace xla { diff --git a/third_party/xla/xla/python/ifrt/shape.cc b/third_party/xla/xla/python/ifrt/shape.cc index fb66cbe464cf87..a77ae4dbe0c7c1 100644 --- a/third_party/xla/xla/python/ifrt/shape.cc +++ b/third_party/xla/xla/python/ifrt/shape.cc @@ -23,11 +23,11 @@ limitations under the License. #include "absl/container/inlined_vector.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "absl/strings/string_view.h" #include "xla/python/ifrt/shape.pb.h" -#include "xla/statusor.h" #include "xla/util.h" #include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" diff --git a/third_party/xla/xla/python/ifrt/test_util.cc b/third_party/xla/xla/python/ifrt/test_util.cc index 2da9bf8b4d2091..4353d1a15e1108 100644 --- a/third_party/xla/xla/python/ifrt/test_util.cc +++ b/third_party/xla/xla/python/ifrt/test_util.cc @@ -19,12 +19,12 @@ limitations under the License. 
#include #include +#include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/synchronization/mutex.h" #include "absl/types/span.h" #include "xla/python/ifrt/client.h" #include "xla/python/ifrt/device.h" -#include "xla/statusor.h" namespace xla { namespace ifrt { From 93fcc5eeb101bd54a531437f0d0106670b0704bc Mon Sep 17 00:00:00 2001 From: Adrian Kuegel Date: Thu, 20 Jun 2024 08:01:29 -0700 Subject: [PATCH 055/256] [XLA] Remove dead unused pass propagate_static_shapes PiperOrigin-RevId: 645032378 --- third_party/xla/xla/mlir_hlo/BUILD | 1 - .../tests/propagate_static_shapes.mlir | 36 --- .../xla/mlir_hlo/transforms/CMakeLists.txt | 1 - .../xla/xla/mlir_hlo/transforms/passes.h | 10 +- .../xla/xla/mlir_hlo/transforms/passes.td | 9 - .../propagate_static_shapes_to_kernel.cc | 240 ------------------ 6 files changed, 3 insertions(+), 294 deletions(-) delete mode 100644 third_party/xla/xla/mlir_hlo/tests/propagate_static_shapes.mlir delete mode 100644 third_party/xla/xla/mlir_hlo/transforms/propagate_static_shapes_to_kernel.cc diff --git a/third_party/xla/xla/mlir_hlo/BUILD b/third_party/xla/xla/mlir_hlo/BUILD index 2eeefa903ba8b4..4cd1751aea2e60 100644 --- a/third_party/xla/xla/mlir_hlo/BUILD +++ b/third_party/xla/xla/mlir_hlo/BUILD @@ -765,7 +765,6 @@ cc_library( "transforms/generic_host_to_llvm.cc", "transforms/lower_index_cast_pass.cc", "transforms/naive_copy_removal.cc", - "transforms/propagate_static_shapes_to_kernel.cc", "transforms/tile_loops_pass.cc", "transforms/unbufferize_pass.cc", "transforms/unroll_loops.cc", diff --git a/third_party/xla/xla/mlir_hlo/tests/propagate_static_shapes.mlir b/third_party/xla/xla/mlir_hlo/tests/propagate_static_shapes.mlir deleted file mode 100644 index 4113bf5f8e47bd..00000000000000 --- a/third_party/xla/xla/mlir_hlo/tests/propagate_static_shapes.mlir +++ /dev/null @@ -1,36 +0,0 @@ -// RUN: mlir-hlo-opt %s \ -// RUN: -propagate-static-shapes='convert_pointer_args=!llvm.ptr' \ -// RUN: | FileCheck %s - 
-module attributes {gpu.container_module} { - - gpu.module @gpu_module { - // CHECK: llvm.func @kernel(%arg0: f32, %arg1: !llvm.ptr, %arg2: f32) - llvm.func @kernel( - %arg0: f32, - %base: !llvm.ptr, %align: !llvm.ptr, %offset: i64, - %size.x: i64, %size.y: i64, %stride.x: i64, %stride.y: i64, - %argN: f32 - ) attributes {gpu.kernel} { - // CHECK: %[[ptr:.*]] = llvm.getelementptr %arg1[4] - // CHECK: llvm.call @dummy(%[[ptr]]) : (!llvm.ptr) -> () - %ptr = llvm.getelementptr %align[%stride.x] : (!llvm.ptr, i64) -> !llvm.ptr, f32 - llvm.call @dummy(%ptr) : (!llvm.ptr) -> () - llvm.return - } - // CHECK: llvm.func @dummy(%arg0: !llvm.ptr) - llvm.func @dummy(%arg0: !llvm.ptr) attributes {gpu.kernel} { - llvm.return - } - } - - func.func @func(%arg0: f32, %arg1: memref<2x4xf32>) { - %c1 = arith.constant 1 : index - gpu.launch_func @gpu_module::@kernel - blocks in (%c1, %c1, %c1) - threads in (%c1, %c1, %c1) - args(%arg0 : f32, %arg1 : memref<2x4xf32>, %arg0 : f32) - func.return - } - -} diff --git a/third_party/xla/xla/mlir_hlo/transforms/CMakeLists.txt b/third_party/xla/xla/mlir_hlo/transforms/CMakeLists.txt index 18d08f827e60ce..70af79927f7330 100644 --- a/third_party/xla/xla/mlir_hlo/transforms/CMakeLists.txt +++ b/third_party/xla/xla/mlir_hlo/transforms/CMakeLists.txt @@ -31,7 +31,6 @@ add_mlir_library(MLIRBufferTransforms generic_host_to_llvm.cc lower_index_cast_pass.cc naive_copy_removal.cc - propagate_static_shapes_to_kernel.cc tile_loops_pass.cc vectorize_copy.cc unbufferize_pass.cc diff --git a/third_party/xla/xla/mlir_hlo/transforms/passes.h b/third_party/xla/xla/mlir_hlo/transforms/passes.h index b7ac841829a64e..95d0ff41a96ffe 100644 --- a/third_party/xla/xla/mlir_hlo/transforms/passes.h +++ b/third_party/xla/xla/mlir_hlo/transforms/passes.h @@ -45,33 +45,29 @@ using BufferizePatternsCallback = std::function> createLowerIndexCastPass(); // Pass to tranform compute computations (hlo and linalg) on values to their // corresponding counterparts on buffers. 
Also bufferizes function signatures. +// Note: dependency from kernelgen. std::unique_ptr> createComputeOpAndFuncBufferizePass(); // Pass to tranform computations on values to their corresponding parts on // buffers. +// Note: dependency from kernelgen. std::unique_ptr> createFinalBufferizePass(); std::unique_ptr> createFinalBufferizePass( uint64_t alignment, BufferizeDialectsCallback dc = {}, BufferizePatternsCallback pc = {}); -// Pass to propagate static shapes to kernel, reducing the kernel arguments -// from a flattened memref to a single pointer. The pointer is converted to -// `pointer_type`, if provided. -std::unique_ptr> -createPropagateStaticShapesToKernelPass(Type pointerType = {}); - // Creates a pass for collapsing multidimensional parallel loops into 1D loops. std::unique_ptr> createCollapseParallelLoopsTo1DPass(); diff --git a/third_party/xla/xla/mlir_hlo/transforms/passes.td b/third_party/xla/xla/mlir_hlo/transforms/passes.td index 88b0649dab3056..5b7ba03d71ac24 100644 --- a/third_party/xla/xla/mlir_hlo/transforms/passes.td +++ b/third_party/xla/xla/mlir_hlo/transforms/passes.td @@ -81,15 +81,6 @@ def FinalBufferizePass : Pass<"final-bufferize", "ModuleOp"> { ]; } -def PropagateStaticShapesToKernelPass : Pass<"propagate-static-shapes", "ModuleOp"> { - let summary = "Pass to rewrite statically shaped kernel arguments to a pointer."; - let constructor = "createPropagateStaticShapesToKernelPass()"; - let options = [ - Option<"ptr_type_opt", "convert_pointer_args", "std::string", - /*default=*/"", "Pointer type to convert pointer arguments to">, - ]; -} - def GenericHostToLLVMPass : Pass<"generic-host-to-llvm", "ModuleOp"> { let summary = "Pass to lower common dialects resulting from HLO to LLVM."; let constructor = "hlo::createGenericHostToLLVMPass()"; diff --git a/third_party/xla/xla/mlir_hlo/transforms/propagate_static_shapes_to_kernel.cc b/third_party/xla/xla/mlir_hlo/transforms/propagate_static_shapes_to_kernel.cc deleted file mode 100644 index 
26f4412a29ebcd..00000000000000 --- a/third_party/xla/xla/mlir_hlo/transforms/propagate_static_shapes_to_kernel.cc +++ /dev/null @@ -1,240 +0,0 @@ -/* Copyright 2022 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include -#include -#include -#include -#include -#include - -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/STLFunctionalExtras.h" -#include "llvm/ADT/SmallVector.h" -#include "mlir/AsmParser/AsmParser.h" -#include "mlir/Dialect/GPU/IR/GPUDialect.h" -#include "mlir/Dialect/LLVMIR/LLVMDialect.h" -#include "mlir/Dialect/LLVMIR/LLVMTypes.h" -#include "mlir/IR/BuiltinAttributes.h" -#include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/MLIRContext.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/IR/SymbolTable.h" -#include "mlir/IR/TypeRange.h" -#include "mlir/IR/Value.h" -#include "mlir/IR/Visitors.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Rewrite/FrozenRewritePatternSet.h" -#include "mlir/Support/LLVM.h" -#include "mlir/Support/LogicalResult.h" -#include "mlir/Transforms/DialectConversion.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "transforms/passes.h" - -namespace mlir { - -#define GEN_PASS_DEF_PROPAGATESTATICSHAPESTOKERNELPASS -#include "transforms/passes.h.inc" - -namespace { - -// Replaces flattened memref arguments (base, aligned, offset, sizes, strides) -// with base and constants if the 
corresponding launch_func ops argument has -// static shape. Removes all arguments but base. -class PropagateStaticShapesPattern : public OpRewritePattern { - public: - explicit PropagateStaticShapesPattern(MLIRContext* ctx, - SymbolTable& symbolTable, - Type pointerType) - : OpRewritePattern(ctx), - symbolTable(symbolTable), - pointerType(pointerType) {} - - private: - LogicalResult matchAndRewrite(LLVM::LLVMFuncOp funcOp, - PatternRewriter& rewriter) const final; - - SymbolTable& symbolTable; - Type pointerType; -}; - -class PropagateStaticShapesToKernelPass - : public impl::PropagateStaticShapesToKernelPassBase< - PropagateStaticShapesToKernelPass> { - public: - explicit PropagateStaticShapesToKernelPass(Type pointerType) - : pointerType(pointerType) {} - - private: - void runOnOperation() override; - - Type pointerType; -}; - -} // namespace - -// Replaces 'arguments' (containing 'base', 'align', 'offset', 'sizes[rank]', -// 'strides[rank]') corresponding to statically shaped 'memref' with the base -// pointer and constants. The base pointer is changed to 'pointer_type' if -// provided. -static void replaceStaticMemRefArguments(ValueRange arguments, - MemRefType memref, Type pointerType, - PatternRewriter& rewriter) { - assert(arguments.size() >= 3 && "expected at least 3 arguments"); - Value base = arguments[0]; - if (pointerType) { - // Change base to given type, replace with bitcast back to original type. - Type type = base.getType(); - base.setType(pointerType); - auto cast = rewriter.create(base.getLoc(), type, base); - base.replaceAllUsesExcept(/*newValue=*/cast, /*exceptedUser=*/cast); - base = cast.getResult(); - } - - // Replace uses of 'aligned' with 'base'. - arguments[1].replaceAllUsesWith(base); - // Replace uses of 'offset' with constant. 
- arguments[2].replaceAllUsesWith(rewriter.create( - arguments[2].getLoc(), arguments[2].getType(), - rewriter.getIntegerAttr(arguments[2].getType(), 0))); - auto replace = [&](ArrayRef values, ValueRange arguments) { - for (auto valAndArg : llvm::zip_first(values, arguments)) { - auto argument = std::get<1>(valAndArg); - argument.replaceAllUsesWith(rewriter.create( - argument.getLoc(), argument.getType(), - rewriter.getIntegerAttr(argument.getType(), std::get<0>(valAndArg)))); - } - }; - // Replace 'sizes' and 'strides' with constants. - replace(memref.getShape(), arguments.drop_front(3)); - auto strides = llvm::to_vector<4>(memref.getShape()); - std::partial_sum(strides.rbegin(), strides.rend(), strides.rbegin(), - std::multiplies()); - strides.push_back(1); - replace(llvm::ArrayRef(strides).drop_front(), - arguments.drop_front(3 + memref.getRank())); -} - -LogicalResult PropagateStaticShapesPattern::matchAndRewrite( - LLVM::LLVMFuncOp funcOp, PatternRewriter& rewriter) const { - if (funcOp.isExternal()) - return rewriter.notifyMatchFailure(funcOp, "external"); - if (!funcOp->getAttrOfType( - gpu::GPUDialect::getKernelFuncAttrName())) { - return rewriter.notifyMatchFailure(funcOp, "missing gpu.kernel"); - } - - // Collect gpu.launch_func ops which launch the func_op kernel. 
- std::optional symUses = - symbolTable.getSymbolUses(funcOp, symbolTable.getOp()); - if (!symUses) - return rewriter.notifyMatchFailure(funcOp, "failed to find symbol uses"); - auto mapper = [](SymbolTable::SymbolUse symUse) { - return dyn_cast(symUse.getUser()); - }; - auto filter = [](gpu::LaunchFuncOp op) -> bool { return op; }; - auto launchOps = llvm::to_vector( - llvm::make_filter_range(llvm::map_range(*symUses, mapper), filter)); - if (launchOps.empty()) - return rewriter.notifyMatchFailure(funcOp, "no gpu.launch_func uses"); - OperandRange operands = launchOps.begin()->getKernelOperands(); - if (llvm::any_of(launchOps, [&](gpu::LaunchFuncOp op) { - return op.getKernelOperands().getTypes() != operands.getTypes(); - })) { - return rewriter.notifyMatchFailure(funcOp, "operand types mismatch"); - } - - rewriter.setInsertionPointToStart(&funcOp.front()); - BitVector argsToDrop(funcOp.getNumArguments()); - // Loop over the launch_op's 'operands' containing scalars and memrefs and the - // func_ops's 'arguments' containing scalars and flattened memrefs. When an - // operand is a staticlly shaped memref, replace the range of arguments - // corresponding to the flattened memref with just the 'base' pointer. - for (auto arguments = funcOp.getArguments(); !arguments.empty(); - operands = operands.drop_front()) { - auto memref = mlir::dyn_cast(operands.getTypes().front()); - if (!memref) { - // Scalar argument, advance by one. - arguments = arguments.drop_front(); - continue; - } - if (!memref.hasRank()) break; // Bail out if unranked. - // memref is flattened to base, align, offset, strides and sizes. - int64_t numArgs = 3 + memref.getRank() * 2; - auto isPtr = [](BlockArgument arg) { - return mlir::isa(arg.getType()); - }; - auto isInt = [](BlockArgument arg) { - return mlir::isa(arg.getType()); - }; - // Bail out if the next num_args are not the expected type. 
- if (static_cast(arguments.size()) < numArgs) break; - ArrayRef memrefArgs = arguments.take_front(numArgs); - if (!llvm::all_of(memrefArgs.take_front(2), isPtr)) break; - if (!llvm::all_of(memrefArgs.drop_front(2), isInt)) break; - // Replace memref_args with just memref_args[0] if memref has static shape. - if (memref.hasStaticShape() && memref.getLayout().isIdentity()) { - replaceStaticMemRefArguments(memrefArgs, memref, pointerType, rewriter); - unsigned argNumber = arguments.front().getArgNumber(); - // Drop all but 'base' from the flattened memref arguments. - argsToDrop.set(argNumber + 1, argNumber + numArgs); - } - arguments = arguments.drop_front(numArgs); - } - if (argsToDrop.none()) { - return rewriter.notifyMatchFailure(funcOp, "no static shapes"); - } - rewriter.modifyOpInPlace(funcOp, [&] { - SmallVector argTypes; - for (unsigned idx = 0; idx < argsToDrop.size(); ++idx) - if (!argsToDrop[idx]) - argTypes.push_back(funcOp.getArgument(idx).getType()); - auto newFuncType = LLVM::LLVMFunctionType::get( - funcOp.getFunctionType().getReturnType(), argTypes); - function_interface_impl::eraseFunctionArguments(funcOp, argsToDrop, - newFuncType); - }); - return success(); -} - -void PropagateStaticShapesToKernelPass::runOnOperation() { - MLIRContext* ctx = getOperation().getContext(); - auto pointerType = [&]() -> FailureOr { - if (ptr_type_opt.empty()) return this->pointerType; - Type type = parseType(ptr_type_opt, ctx); - if (!type) - return emitError(UnknownLoc::get(ctx), "invalid convert_pointer_args"); - return type; - }(); - if (failed(pointerType)) return signalPassFailure(); - SymbolTable symbolTable(getOperation()); - RewritePatternSet patterns(ctx); - patterns.add(ctx, symbolTable, *pointerType); - FrozenRewritePatternSet frozen(std::move(patterns)); - auto callback = [&](gpu::GPUModuleOp gpuModule) -> WalkResult { - return applyPatternsAndFoldGreedily(gpuModule, frozen); - }; - if (getOperation()->walk(callback).wasInterrupted()) - return 
signalPassFailure(); -} - -std::unique_ptr> -createPropagateStaticShapesToKernelPass(Type pointerType) { - return std::make_unique(pointerType); -} - -} // namespace mlir From f05e4339df2305c987f934d6fdb3b46cf8c72fba Mon Sep 17 00:00:00 2001 From: Oleg Shyshkov Date: Thu, 20 Jun 2024 08:04:06 -0700 Subject: [PATCH 056/256] [XLA:GPU] Use Cost Model to choose tile sizes in SoftmaxRewriterTriton. After this change, all SoftMax fusions create in the rewriter should have BlockLevelFusionConfig with tile parameters to be used in the emitter. PiperOrigin-RevId: 645033335 --- third_party/xla/xla/service/gpu/BUILD | 9 +- .../xla/xla/service/gpu/gpu_compiler.cc | 3 +- .../model/gpu_indexing_performance_model.cc | 2 +- .../model/gpu_indexing_performance_model.h | 6 +- .../service/gpu/softmax_rewriter_triton.cc | 65 +++- .../xla/service/gpu/softmax_rewriter_triton.h | 12 +- .../gpu/softmax_rewriter_triton_test.cc | 335 +++++++++++------- 7 files changed, 284 insertions(+), 148 deletions(-) diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index 3b39b1f8363f56..8899b034110add 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -1555,6 +1555,7 @@ cc_library( hdrs = ["softmax_rewriter_triton.h"], deps = [ ":backend_configs_cc", + ":hlo_traversal", ":ir_emission_utils", ":triton_support", "//xla:shape_util", @@ -1563,9 +1564,13 @@ cc_library( "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_query", + "//xla/service:hlo_cost_analysis", "//xla/service:hlo_pass", "//xla/service:instruction_fusion", + "//xla/service/gpu/model:fusion_analysis_cache", + "//xla/service/gpu/model:gpu_indexing_performance_model", "//xla/service/gpu/model:symbolic_tile_analysis", + "//xla/service/gpu/model:tiled_hlo_computation", "//xla/stream_executor:device_description", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", @@ -2585,12 +2590,14 @@ xla_cc_test( 
name = "softmax_rewriter_triton_test", srcs = ["softmax_rewriter_triton_test.cc"], deps = [ + ":backend_configs_cc", + ":gpu_device_info_for_tests", ":softmax_rewriter_triton", "//xla:shape_util", "//xla/hlo/ir:hlo", - "//xla/service:instruction_fusion", "//xla/service:pattern_matcher", "//xla/service:pattern_matcher_gmock", + "//xla/service/gpu/model:gpu_hlo_cost_analysis", "//xla/stream_executor:device_description", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", # build_cleaner: keep diff --git a/third_party/xla/xla/service/gpu/gpu_compiler.cc b/third_party/xla/xla/service/gpu/gpu_compiler.cc index 440e5ad600cd7c..574e6c239b83d0 100644 --- a/third_party/xla/xla/service/gpu/gpu_compiler.cc +++ b/third_party/xla/xla/service/gpu/gpu_compiler.cc @@ -1380,7 +1380,8 @@ absl::Status GpuCompiler::OptimizeHloPostLayoutAssignment( cuda_cc->IsAtLeast(se::CudaComputeCapability::AMPERE)) { pipeline.AddPass>(simplifier_options, gpu_version); - pipeline.AddPass(gpu_version); + pipeline.AddPass( + gpu_target_config.device_description, ShapeSizeBytesFunction()); } pipeline.AddPass(); diff --git a/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.cc b/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.cc index 049cd3aaae9f21..84cdf525441867 100644 --- a/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.cc +++ b/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.cc @@ -388,7 +388,7 @@ LaunchDimensions GetLaunchDimensionsForTiledFusion( static_cast(num_warps * WarpSize())}; } -absl::StatusOr> +absl::StatusOr GpuPerformanceModelWithIndexingAnalysis::TryFindBestTilingForFusion( const HloFusionAdaptor& fusion_adaptor) { SymbolicTileAnalysisOrError analysis_or_error = diff --git a/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.h b/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.h index 8d8b2ac109afd9..d2a70e87e28600 100644 --- 
a/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.h +++ b/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.h @@ -47,6 +47,8 @@ struct TiledRunTimeData { BlockLevelParameters block_level_parameters; }; +using TiledRunTimeDataOrError = std::variant; + // Implementation of Cost Model that uses indexing analysis to estimate amount // of compute and memory access time. class GpuPerformanceModelWithIndexingAnalysis : public GpuPerformanceModelBase { @@ -105,8 +107,8 @@ class GpuPerformanceModelWithIndexingAnalysis : public GpuPerformanceModelBase { // Returns FusionDecision if the fusion can't be tiled or there are no valid // block level parameters. // Otherwise returns block level parameters that give the best execution time. - absl::StatusOr> - TryFindBestTilingForFusion(const HloFusionAdaptor& fusion_adaptor); + absl::StatusOr TryFindBestTilingForFusion( + const HloFusionAdaptor& fusion_adaptor); private: // Returns an estimate how many FLOPs will be used to produce one element of diff --git a/third_party/xla/xla/service/gpu/softmax_rewriter_triton.cc b/third_party/xla/xla/service/gpu/softmax_rewriter_triton.cc index d270ad1fbe1a66..3d7cd6cac84980 100644 --- a/third_party/xla/xla/service/gpu/softmax_rewriter_triton.cc +++ b/third_party/xla/xla/service/gpu/softmax_rewriter_triton.cc @@ -14,6 +14,7 @@ limitations under the License. #include #include +#include #include #include @@ -35,8 +36,12 @@ limitations under the License. 
#include "xla/hlo/utils/hlo_query.h" #include "xla/layout_util.h" #include "xla/service/gpu/backend_configs.pb.h" +#include "xla/service/gpu/hlo_traversal.h" #include "xla/service/gpu/ir_emission_utils.h" +#include "xla/service/gpu/model/fusion_analysis_cache.h" +#include "xla/service/gpu/model/gpu_indexing_performance_model.h" #include "xla/service/gpu/model/symbolic_tile_analysis.h" +#include "xla/service/gpu/model/tiled_hlo_computation.h" #include "xla/service/gpu/triton_support.h" #include "xla/service/instruction_fusion.h" #include "xla/shape.h" @@ -397,11 +402,35 @@ absl::StatusOr MakeFusionForDiamondChain( return xla::Cast(softmax_fusion); } -absl::Status FuseDiamondChainImpl(const DiamondChainDescriptor& diamond_chain) { +absl::Status FuseDiamondChainImpl( + const DiamondChainDescriptor& diamond_chain, + GpuPerformanceModelWithIndexingAnalysis& indexing_performance_model) { TF_ASSIGN_OR_RETURN(HloFusionInstruction * softmax_fusion, MakeFusionForDiamondChain(diamond_chain)); HloInstruction* root = diamond_chain.root; + auto fusion_adaptor = HloFusionAdaptor::ForInstruction(softmax_fusion); + + TF_ASSIGN_OR_RETURN( + TiledRunTimeDataOrError tiled_runtime_data_or, + indexing_performance_model.TryFindBestTilingForFusion(*fusion_adaptor)); + + if (const auto* fusion_decision = + std::get_if(&tiled_runtime_data_or)) { + return absl::FailedPreconditionError(absl::StrCat( + "SymbolicTileAnalysis failed. 
", fusion_decision->Explain())); + } + + TiledRunTimeData tiled_runtime_data = + std::get(std::move(tiled_runtime_data_or)); + + TF_ASSIGN_OR_RETURN(auto backend_config, + softmax_fusion->backend_config()); + *backend_config.mutable_fusion_backend_config() + ->mutable_block_level_fusion_config() = + tiled_runtime_data.block_level_parameters.ToBlockLevelFusionConfig(); + TF_RETURN_IF_ERROR(softmax_fusion->set_backend_config(backend_config)); + if (root->IsRoot()) { root->parent()->set_root_instruction(softmax_fusion); TF_RETURN_IF_ERROR( @@ -446,7 +475,8 @@ SoftmaxRewriterTriton::MatchesTritonCompatibleClosedReductionDiamond( return "Root is not elementwise binary."; } - if (!legacy_triton::IsTritonSupportedInstruction(*instr, gpu_version_)) { + if (!legacy_triton::IsTritonSupportedInstruction( + *instr, device_info_.gpu_compute_capability())) { return "Root is not supported for Triton instruction."; } @@ -455,12 +485,12 @@ SoftmaxRewriterTriton::MatchesTritonCompatibleClosedReductionDiamond( HloInstruction* reduce; if (!TrivialEdge(&broadcast, instr->mutable_operand(1), HloOpcode::kBroadcast, - gpu_version_)) { + device_info_.gpu_compute_capability())) { return "Could not find a trivial connection from root to a broadcast."; } if (!TrivialEdge(&reduce, broadcast->mutable_operand(0), HloOpcode::kReduce, - gpu_version_)) { + device_info_.gpu_compute_capability())) { return "Could not find a trivial connection from matched broadcast to a " "reduction."; } @@ -471,7 +501,8 @@ SoftmaxRewriterTriton::MatchesTritonCompatibleClosedReductionDiamond( } if (CodegenDecision is_supported = - legacy_triton::IsTritonSupportedInstruction(*reduce, gpu_version_); + legacy_triton::IsTritonSupportedInstruction( + *reduce, device_info_.gpu_compute_capability()); !is_supported) { VLOG(3) << is_supported.Explain(); return is_supported; @@ -488,7 +519,7 @@ SoftmaxRewriterTriton::MatchesTritonCompatibleClosedReductionDiamond( return "Broadcast is not along the reduction dimension."; } - 
while (IsTriviallyFusible(producer, gpu_version_)) { + while (IsTriviallyFusible(producer, device_info_.gpu_compute_capability())) { producer = ChooseOperandForFusionProcessing(producer); } @@ -497,7 +528,7 @@ SoftmaxRewriterTriton::MatchesTritonCompatibleClosedReductionDiamond( } if (!IsTriviallyConnectedProducerOf(producer, instr->mutable_operand(0), - gpu_version_)) { + device_info_.gpu_compute_capability())) { return "Producer is not trivially connected."; } @@ -579,7 +610,8 @@ SoftmaxRewriterTriton::FindAllFusibleDiamondChains( auto last_trivially_fusible_user = [&](HloInstruction* instr) { while (HasOneUse(instr) && !instr->IsRoot() && - IsTriviallyFusible(instr->users().front(), gpu_version_)) { + IsTriviallyFusible(instr->users().front(), + device_info_.gpu_compute_capability())) { instr = instr->users().front(); } @@ -588,7 +620,7 @@ SoftmaxRewriterTriton::FindAllFusibleDiamondChains( // restriction. if (HasOneUse(instr) && !instr->IsRoot() && IsTriviallyFusible( - instr->users().front(), gpu_version_, + instr->users().front(), device_info_.gpu_compute_capability(), /*num_allowed_users=*/instr->users().front()->user_count())) { instr = instr->users().front(); } @@ -615,7 +647,7 @@ SoftmaxRewriterTriton::FindAllFusibleDiamondChains( diamond_chains.reserve(matched_diamonds.size()); HloInstruction* current_fusion_producer = FindFirstNonFusibleDiamondProducer( - matched_diamonds.front().producer, gpu_version_); + matched_diamonds.front().producer, device_info_.gpu_compute_capability()); int current_reduce_dimension_size = reduction_dimension_size_from_diamond_root(matched_diamonds.front().root); @@ -626,7 +658,8 @@ SoftmaxRewriterTriton::FindAllFusibleDiamondChains( matched_diamonds[diamond_idx - 1].root; HloInstruction* first_non_fusible_diamond_producer = - FindFirstNonFusibleDiamondProducer(diamond_producer, gpu_version_); + FindFirstNonFusibleDiamondProducer( + diamond_producer, device_info_.gpu_compute_capability()); int diamond_reduce_dimension_size = 
reduction_dimension_size_from_diamond_root(diamond_root); @@ -678,14 +711,18 @@ SoftmaxRewriterTriton::FindAllFusibleDiamondChains( absl::Status SoftmaxRewriterTriton::FuseDiamondChain( const DiamondChainDescriptor& diamond_chain) { - return FuseDiamondChainImpl(diamond_chain); + HloFusionAnalysisCache fusion_analysis_cache(device_info_); + GpuPerformanceModelWithIndexingAnalysis indexing_performance_model( + &device_info_, &fusion_analysis_cache, shape_size_, &mlir_context_); + + return FuseDiamondChainImpl(diamond_chain, indexing_performance_model); } absl::StatusOr SoftmaxRewriterTriton::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { - auto cuda_compute_capability = - std::get_if(&gpu_version_); + auto cuda_compute_capability = std::get_if( + &device_info_.gpu_compute_capability()); if (!cuda_compute_capability) { return absl::FailedPreconditionError( "Triton support is only enabled for CUDA GPUs."); diff --git a/third_party/xla/xla/service/gpu/softmax_rewriter_triton.h b/third_party/xla/xla/service/gpu/softmax_rewriter_triton.h index b70f9f56a03334..edcb87414d4e6d 100644 --- a/third_party/xla/xla/service/gpu/softmax_rewriter_triton.h +++ b/third_party/xla/xla/service/gpu/softmax_rewriter_triton.h @@ -23,8 +23,10 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "mlir/IR/MLIRContext.h" // from @llvm-project #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/service/hlo_cost_analysis.h" #include "xla/service/hlo_pass_interface.h" #include "xla/service/instruction_fusion.h" #include "xla/stream_executor/device_description.h" @@ -43,8 +45,10 @@ using DiamondMatchingDecision = std::variant; // with the Triton-based Softmax emitter. 
class SoftmaxRewriterTriton : public HloModulePass { public: - explicit SoftmaxRewriterTriton(se::GpuComputeCapability gpu_version) - : gpu_version_(gpu_version) {} + explicit SoftmaxRewriterTriton(const se::DeviceDescription& device_info, + HloCostAnalysis::ShapeSizeFunction shape_size) + : device_info_(device_info), shape_size_(shape_size) {} + absl::string_view name() const override { return "triton-softmax-rewriter"; } using HloPassInterface::Run; @@ -86,7 +90,9 @@ class SoftmaxRewriterTriton : public HloModulePass { HloInstruction* instr) const; private: - se::GpuComputeCapability gpu_version_; + const se::DeviceDescription& device_info_; + const HloCostAnalysis::ShapeSizeFunction shape_size_; + mlir::MLIRContext mlir_context_; }; } // namespace gpu diff --git a/third_party/xla/xla/service/gpu/softmax_rewriter_triton_test.cc b/third_party/xla/xla/service/gpu/softmax_rewriter_triton_test.cc index 70f445de4aac42..679e8416eccfd9 100644 --- a/third_party/xla/xla/service/gpu/softmax_rewriter_triton_test.cc +++ b/third_party/xla/xla/service/gpu/softmax_rewriter_triton_test.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "xla/service/gpu/softmax_rewriter_triton.h" +#include #include #include #include @@ -27,10 +28,15 @@ limitations under the License. 
#include "absl/strings/substitute.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_opcode.h" #include "xla/primitive_util.h" -#include "xla/service/instruction_fusion.h" +#include "xla/service/gpu/backend_configs.pb.h" +#include "xla/service/gpu/gpu_device_info_for_tests.h" +#include "xla/service/gpu/model/gpu_hlo_cost_analysis.h" #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" +#include "xla/shape.h" +#include "xla/shape_util.h" #include "xla/stream_executor/device_description.h" #include "xla/tests/hlo_test_base.h" #include "tsl/lib/core/status_test_util.h" @@ -46,13 +52,30 @@ namespace m = ::xla::match; using ::testing::HasSubstr; +GpuHloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() { + return [&](const Shape& shape) { + constexpr int64_t kPointerSize = 8; + return ShapeUtil::ByteSizeOf(shape, kPointerSize); + }; +} + +bool HasBlockLevelFusionConfig(const HloInstruction* fusion) { + return fusion->opcode() == HloOpcode::kFusion && + fusion->has_backend_config() && + fusion->backend_config().ok() && + fusion->backend_config() + ->fusion_backend_config() + .has_block_level_fusion_config(); +} + // Wrapper around SoftmaxRewriterTriton(gpu_version).Run(module) that finds // and fuses as many diamond chains as possible without invoking any kind of // cost analysis. 
absl::StatusOr SoftmaxRewriterTritonMatchAndRewrite( - se::GpuComputeCapability gpu_version, HloModule* module) { + const se::DeviceDescription& device_info, HloModule* module) { CHECK_NE(module, nullptr); - SoftmaxRewriterTriton softmax_rewriter_triton(gpu_version); + SoftmaxRewriterTriton softmax_rewriter_triton(device_info, + ShapeSizeBytesFunction()); TF_ASSIGN_OR_RETURN(std::vector diamond_chains, softmax_rewriter_triton.FindAllFusibleDiamondChains( *module, /*execution_threads=*/{})); @@ -69,14 +92,8 @@ absl::StatusOr SoftmaxRewriterTritonMatchAndRewrite( class SoftmaxRewriterTritonTest : public HloTestBase, public ::testing::WithParamInterface { - public: - void SetUp() override { - gpu_version_ = se::GpuComputeCapability{ - se::CudaComputeCapability{se::CudaComputeCapability::AMPERE, 0}}; - } - protected: - se::GpuComputeCapability gpu_version_; + se::DeviceDescription device_info_{TestGpuDeviceInfo::RTXA6000DeviceInfo()}; }; TEST_P(SoftmaxRewriterTritonTest, CanFuseExactSoftmax) { @@ -113,7 +130,7 @@ ENTRY main { auto module = ParseAndReturnVerifiedModule(hlo_string).value(); EXPECT_TRUE( - SoftmaxRewriterTritonMatchAndRewrite(gpu_version_, module.get()).value()); + SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value()); EXPECT_TRUE(verifier().Run(module.get()).status().ok()); VLOG(2) << module->ToString(); @@ -121,7 +138,8 @@ ENTRY main { case F32: case BF16: EXPECT_THAT(module->entry_computation()->root_instruction(), - GmockMatch(m::Fusion(m::Parameter()))); + GmockMatch(m::Fusion(m::Parameter()) + .WithPredicate(HasBlockLevelFusionConfig))); break; case F16: EXPECT_THAT(module->entry_computation()->root_instruction(), @@ -155,11 +173,13 @@ ENTRY main { auto module = ParseAndReturnVerifiedModule(hlo_string).value(); EXPECT_TRUE( - SoftmaxRewriterTritonMatchAndRewrite(gpu_version_, module.get()).value()); + SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value()); 
EXPECT_TRUE(verifier().Run(module.get()).status().ok()); VLOG(2) << module->ToString(); - EXPECT_THAT(module->entry_computation()->root_instruction(), - GmockMatch(m::Fusion(m::Parameter()))); + EXPECT_THAT( + module->entry_computation()->root_instruction(), + GmockMatch( + m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig))); } TEST_F(SoftmaxRewriterTritonTest, CanNotFuseExactSoftmaxF64) { @@ -191,7 +211,7 @@ ENTRY main { auto module = ParseAndReturnVerifiedModule(hlo_string).value(); EXPECT_FALSE( - SoftmaxRewriterTritonMatchAndRewrite(gpu_version_, module.get()).value()); + SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value()); } TEST_F(SoftmaxRewriterTritonTest, CanFuseExactSoftmaxBF16) { @@ -223,10 +243,12 @@ ENTRY main { auto module = ParseAndReturnVerifiedModule(hlo_string).value(); EXPECT_TRUE( - SoftmaxRewriterTritonMatchAndRewrite(gpu_version_, module.get()).value()); + SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value()); EXPECT_TRUE(verifier().Run(module.get()).status().ok()); - EXPECT_THAT(module->entry_computation()->root_instruction(), - GmockMatch(m::Fusion(m::Parameter()))); + EXPECT_THAT( + module->entry_computation()->root_instruction(), + GmockMatch( + m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig))); } TEST_P(SoftmaxRewriterTritonTest, @@ -272,7 +294,7 @@ ENTRY main { auto module = ParseAndReturnVerifiedModule(hlo_string).value(); EXPECT_FALSE( - SoftmaxRewriterTritonMatchAndRewrite(gpu_version_, module.get()).value()); + SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value()); } TEST_P(SoftmaxRewriterTritonTest, CanNotFuseSoftmaxDiamondWithWrongLayout) { @@ -298,7 +320,7 @@ ENTRY main { auto module = ParseAndReturnVerifiedModule(hlo_string).value(); EXPECT_FALSE( - SoftmaxRewriterTritonMatchAndRewrite(gpu_version_, module.get()).value()); + SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value()); } TEST_P(SoftmaxRewriterTritonTest, 
@@ -325,7 +347,7 @@ ENTRY main { auto module = ParseAndReturnVerifiedModule(hlo_string).value(); EXPECT_FALSE( - SoftmaxRewriterTritonMatchAndRewrite(gpu_version_, module.get()).value()); + SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value()); } TEST_P(SoftmaxRewriterTritonTest, @@ -352,7 +374,7 @@ ENTRY main { auto module = ParseAndReturnVerifiedModule(hlo_string).value(); EXPECT_FALSE( - SoftmaxRewriterTritonMatchAndRewrite(gpu_version_, module.get()).value()); + SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value()); } // TODO(bchetioui): expand so this can be supported? @@ -382,7 +404,7 @@ ENTRY main { auto module = ParseAndReturnVerifiedModule(hlo_string).value(); EXPECT_FALSE( - SoftmaxRewriterTritonMatchAndRewrite(gpu_version_, module.get()).value()); + SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value()); } TEST_P(SoftmaxRewriterTritonTest, @@ -421,14 +443,15 @@ ENTRY main { auto module = ParseAndReturnVerifiedModule(hlo_string).value(); EXPECT_TRUE( - SoftmaxRewriterTritonMatchAndRewrite(gpu_version_, module.get()).value()); + SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value()); EXPECT_TRUE(verifier().Run(module.get()).status().ok()); switch (data_type) { case F32: case BF16: EXPECT_THAT(module->entry_computation()->root_instruction(), - GmockMatch(m::Fusion(m::Parameter()))); + GmockMatch(m::Fusion(m::Parameter()) + .WithPredicate(HasBlockLevelFusionConfig))); break; case F16: EXPECT_THAT(module->entry_computation()->root_instruction(), @@ -473,14 +496,15 @@ ENTRY main { auto module = ParseAndReturnVerifiedModule(hlo_string).value(); EXPECT_TRUE( - SoftmaxRewriterTritonMatchAndRewrite(gpu_version_, module.get()).value()); + SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value()); EXPECT_TRUE(verifier().Run(module.get()).status().ok()); switch (data_type) { case F32: case BF16: EXPECT_THAT(module->entry_computation()->root_instruction(), - 
GmockMatch(m::Fusion(m::Parameter()))); + GmockMatch(m::Fusion(m::Parameter()) + .WithPredicate(HasBlockLevelFusionConfig))); break; case F16: EXPECT_THAT(module->entry_computation()->root_instruction(), @@ -516,10 +540,12 @@ ENTRY main { auto module = ParseAndReturnVerifiedModule(hlo_string).value(); EXPECT_TRUE( - SoftmaxRewriterTritonMatchAndRewrite(gpu_version_, module.get()).value()); + SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value()); EXPECT_TRUE(verifier().Run(module.get()).status().ok()); - EXPECT_THAT(module->entry_computation()->root_instruction(), - GmockMatch(m::Fusion(m::Parameter()))); + EXPECT_THAT( + module->entry_computation()->root_instruction(), + GmockMatch( + m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig))); } TEST_P(SoftmaxRewriterTritonTest, CanFuseDiamondWithUnaryElementwisePrefix) { @@ -546,10 +572,12 @@ ENTRY main { auto module = ParseAndReturnVerifiedModule(hlo_string).value(); EXPECT_TRUE( - SoftmaxRewriterTritonMatchAndRewrite(gpu_version_, module.get()).value()); + SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value()); EXPECT_TRUE(verifier().Run(module.get()).status().ok()); - EXPECT_THAT(module->entry_computation()->root_instruction(), - GmockMatch(m::Fusion(m::Parameter()))); + EXPECT_THAT( + module->entry_computation()->root_instruction(), + GmockMatch( + m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig))); } TEST_P(SoftmaxRewriterTritonTest, @@ -576,10 +604,12 @@ ENTRY main { auto module = ParseAndReturnVerifiedModule(hlo_string).value(); EXPECT_TRUE( - SoftmaxRewriterTritonMatchAndRewrite(gpu_version_, module.get()).value()); + SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value()); EXPECT_TRUE(verifier().Run(module.get()).status().ok()); - EXPECT_THAT(module->entry_computation()->root_instruction(), - GmockMatch(m::Fusion(m::Parameter()))); + EXPECT_THAT( + module->entry_computation()->root_instruction(), + GmockMatch( + 
m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig))); } TEST_P(SoftmaxRewriterTritonTest, @@ -608,7 +638,7 @@ ENTRY main { auto module = ParseAndReturnVerifiedModule(hlo_string).value(); EXPECT_FALSE( - SoftmaxRewriterTritonMatchAndRewrite(gpu_version_, module.get()).value()); + SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value()); } TEST_P(SoftmaxRewriterTritonTest, @@ -636,7 +666,7 @@ ENTRY main { auto module = ParseAndReturnVerifiedModule(hlo_string).value(); EXPECT_FALSE( - SoftmaxRewriterTritonMatchAndRewrite(gpu_version_, module.get()).value()); + SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value()); } TEST_P(SoftmaxRewriterTritonTest, @@ -666,7 +696,7 @@ ENTRY main { auto module = ParseAndReturnVerifiedModule(hlo_string).value(); EXPECT_FALSE( - SoftmaxRewriterTritonMatchAndRewrite(gpu_version_, module.get()).value()); + SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value()); } TEST_P(SoftmaxRewriterTritonTest, @@ -696,10 +726,12 @@ ENTRY main { auto module = ParseAndReturnVerifiedModule(hlo_string).value(); EXPECT_TRUE( - SoftmaxRewriterTritonMatchAndRewrite(gpu_version_, module.get()).value()); + SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value()); EXPECT_TRUE(verifier().Run(module.get()).status().ok()); - EXPECT_THAT(module->entry_computation()->root_instruction(), - GmockMatch(m::Fusion(m::Parameter()))); + EXPECT_THAT( + module->entry_computation()->root_instruction(), + GmockMatch( + m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig))); } TEST_P(SoftmaxRewriterTritonTest, @@ -729,7 +761,7 @@ ENTRY main { auto module = ParseAndReturnVerifiedModule(hlo_string).value(); EXPECT_FALSE( - SoftmaxRewriterTritonMatchAndRewrite(gpu_version_, module.get()).value()); + SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value()); } TEST_P(SoftmaxRewriterTritonTest, @@ -768,14 +800,18 @@ ENTRY main { auto module = 
ParseAndReturnVerifiedModule(hlo_string).value(); EXPECT_TRUE( - SoftmaxRewriterTritonMatchAndRewrite(gpu_version_, module.get()).value()); + SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value()); EXPECT_TRUE(verifier().Run(module.get()).status().ok()); switch (data_type) { case F32: case BF16: - EXPECT_THAT(module->entry_computation()->root_instruction(), - GmockMatch(m::Fusion(m::Bitcast(m::Fusion(m::Parameter()))))); + EXPECT_THAT( + module->entry_computation()->root_instruction(), + GmockMatch(m::Fusion(m::Bitcast(m::Fusion(m::Parameter()) + .WithPredicate( + HasBlockLevelFusionConfig))) + .WithPredicate(HasBlockLevelFusionConfig))); break; case F16: EXPECT_THAT(module->entry_computation()->root_instruction(), @@ -822,19 +858,25 @@ ENTRY main { auto module = ParseAndReturnVerifiedModule(hlo_string).value(); EXPECT_TRUE( - SoftmaxRewriterTritonMatchAndRewrite(gpu_version_, module.get()).value()); + SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value()); EXPECT_TRUE(verifier().Run(module.get()).status().ok()); switch (data_type) { case F32: case BF16: - EXPECT_THAT(module->entry_computation()->root_instruction(), - GmockMatch(m::Tuple(m::Fusion(m::Fusion()), - m::Fusion(m::Parameter())))); + EXPECT_THAT( + module->entry_computation()->root_instruction(), + GmockMatch(m::Tuple( + m::Fusion(m::Fusion()).WithPredicate(HasBlockLevelFusionConfig), + m::Fusion(m::Parameter()) + .WithPredicate(HasBlockLevelFusionConfig)))); break; case F16: - EXPECT_THAT(module->entry_computation()->root_instruction(), - GmockMatch(m::Tuple(m::Divide(), m::Fusion(m::Parameter())))); + EXPECT_THAT( + module->entry_computation()->root_instruction(), + GmockMatch(m::Tuple(m::Divide(), + m::Fusion(m::Parameter()) + .WithPredicate(HasBlockLevelFusionConfig)))); break; default: ABSL_UNREACHABLE(); @@ -877,15 +919,18 @@ ENTRY main { auto module = ParseAndReturnVerifiedModule(hlo_string).value(); EXPECT_TRUE( - 
SoftmaxRewriterTritonMatchAndRewrite(gpu_version_, module.get()).value()); + SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value()); EXPECT_TRUE(verifier().Run(module.get()).status().ok()); switch (data_type) { case F32: case BF16: - EXPECT_THAT(module->entry_computation()->root_instruction(), - GmockMatch(m::Tuple(m::Fusion(m::Fusion()), - m::Fusion(m::Parameter())))); + EXPECT_THAT( + module->entry_computation()->root_instruction(), + GmockMatch(m::Tuple( + m::Fusion(m::Fusion()).WithPredicate(HasBlockLevelFusionConfig), + m::Fusion(m::Parameter()) + .WithPredicate(HasBlockLevelFusionConfig)))); break; case F16: EXPECT_THAT(module->entry_computation()->root_instruction(), @@ -924,11 +969,12 @@ ENTRY main { auto module = ParseAndReturnVerifiedModule(hlo_string).value(); EXPECT_TRUE( - SoftmaxRewriterTritonMatchAndRewrite(gpu_version_, module.get()).value()); + SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value()); EXPECT_TRUE(verifier().Run(module.get()).status().ok()); VLOG(2) << module->ToString(); EXPECT_THAT(module->entry_computation()->root_instruction(), - GmockMatch(m::Fusion(m::IsFinite(m::Parameter())))); + GmockMatch(m::Fusion(m::IsFinite(m::Parameter())) + .WithPredicate(HasBlockLevelFusionConfig))); } TEST_P(SoftmaxRewriterTritonTest, @@ -959,7 +1005,7 @@ ENTRY main { auto module = ParseAndReturnVerifiedModule(hlo_string).value(); EXPECT_FALSE( - SoftmaxRewriterTritonMatchAndRewrite(gpu_version_, module.get()).value()); + SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value()); } TEST_P(SoftmaxRewriterTritonTest, @@ -991,11 +1037,13 @@ ENTRY main { auto module = ParseAndReturnVerifiedModule(hlo_string).value(); EXPECT_TRUE( - SoftmaxRewriterTritonMatchAndRewrite(gpu_version_, module.get()).value()); + SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value()); EXPECT_TRUE(verifier().Run(module.get()).status().ok()); VLOG(2) << module->ToString(); - 
EXPECT_THAT(module->entry_computation()->root_instruction(), - GmockMatch(m::Fusion(m::Parameter()))); + EXPECT_THAT( + module->entry_computation()->root_instruction(), + GmockMatch( + m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig))); } TEST_P( @@ -1030,7 +1078,7 @@ ENTRY main { auto module = ParseAndReturnVerifiedModule(hlo_string).value(); EXPECT_FALSE( - SoftmaxRewriterTritonMatchAndRewrite(gpu_version_, module.get()).value()); + SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value()); } TEST_P(SoftmaxRewriterTritonTest, CanFuseSoftmaxDiamondWithSmallRows) { @@ -1055,12 +1103,14 @@ ENTRY main { primitive_util::LowercasePrimitiveTypeName(data_type)); auto module = ParseAndReturnVerifiedModule(hlo_string).value(); - EXPECT_THAT(SoftmaxRewriterTritonMatchAndRewrite(gpu_version_, module.get()), + EXPECT_THAT(SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()), tsl::testing::IsOkAndHolds(true)); TF_EXPECT_OK(verifier().Run(module.get()).status()); VLOG(2) << module->ToString(); - EXPECT_THAT(module->entry_computation()->root_instruction(), - GmockMatch(m::Fusion(m::Parameter()))); + EXPECT_THAT( + module->entry_computation()->root_instruction(), + GmockMatch( + m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig))); } TEST_P(SoftmaxRewriterTritonTest, @@ -1089,13 +1139,16 @@ ENTRY main { EXPECT_TRUE( SoftmaxRewriterTritonMatchAndRewrite( - se::CudaComputeCapability{se::CudaComputeCapability::AMPERE, 0}, + TestGpuDeviceInfo::RTXA6000DeviceInfo( + se::CudaComputeCapability{se::CudaComputeCapability::AMPERE, 0}), module.get()) .value()); EXPECT_TRUE(verifier().Run(module.get()).status().ok()); VLOG(2) << module->ToString(); - EXPECT_THAT(module->entry_computation()->root_instruction(), - GmockMatch(m::Fusion(m::Parameter()))); + EXPECT_THAT( + module->entry_computation()->root_instruction(), + GmockMatch( + m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig))); } 
TEST_F(SoftmaxRewriterTritonTest, RewriterBailsOutOnPreAmpereGpu) { @@ -1119,7 +1172,9 @@ ENTRY main { EXPECT_THAT( SoftmaxRewriterTriton( - se::CudaComputeCapability{se::CudaComputeCapability::VOLTA, 0}) + TestGpuDeviceInfo::RTXA6000DeviceInfo( + se::CudaComputeCapability{se::CudaComputeCapability::VOLTA, 0}), + ShapeSizeBytesFunction()) .Run(module.get()), tsl::testing::StatusIs( tsl::error::FAILED_PRECONDITION, @@ -1147,7 +1202,9 @@ ENTRY main { auto module = ParseAndReturnVerifiedModule(hlo_string).value(); EXPECT_THAT( - SoftmaxRewriterTriton(se::RocmComputeCapability{}).Run(module.get()), + SoftmaxRewriterTriton(TestGpuDeviceInfo::AMDMI210DeviceInfo(), + ShapeSizeBytesFunction()) + .Run(module.get()), tsl::testing::StatusIs( tsl::error::FAILED_PRECONDITION, ::testing::StrEq("Triton support is only enabled for CUDA GPUs."))); @@ -1177,11 +1234,13 @@ ENTRY main { auto module = ParseAndReturnVerifiedModule(hlo_string).value(); EXPECT_TRUE( - SoftmaxRewriterTritonMatchAndRewrite(gpu_version_, module.get()).value()); + SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value()); EXPECT_TRUE(verifier().Run(module.get()).status().ok()); VLOG(2) << module->ToString(); - EXPECT_THAT(module->entry_computation()->root_instruction(), - GmockMatch(m::Convert(m::Fusion(m::Parameter())))); + EXPECT_THAT( + module->entry_computation()->root_instruction(), + GmockMatch(m::Convert( + m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig)))); } TEST_P(SoftmaxRewriterTritonTest, DoesNotFuseConvertWithC128DataType) { @@ -1208,11 +1267,13 @@ ENTRY main { auto module = ParseAndReturnVerifiedModule(hlo_string).value(); EXPECT_TRUE( - SoftmaxRewriterTritonMatchAndRewrite(gpu_version_, module.get()).value()); + SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value()); EXPECT_TRUE(verifier().Run(module.get()).status().ok()); VLOG(2) << module->ToString(); - EXPECT_THAT(module->entry_computation()->root_instruction(), - 
GmockMatch(m::Convert(m::Fusion(m::Parameter())))); + EXPECT_THAT( + module->entry_computation()->root_instruction(), + GmockMatch(m::Convert( + m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig)))); } TEST_P(SoftmaxRewriterTritonTest, @@ -1240,11 +1301,13 @@ ENTRY main { auto module = ParseAndReturnVerifiedModule(hlo_string).value(); EXPECT_TRUE( - SoftmaxRewriterTritonMatchAndRewrite(gpu_version_, module.get()).value()); + SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value()); EXPECT_TRUE(verifier().Run(module.get()).status().ok()); VLOG(2) << module->ToString(); - EXPECT_THAT(module->entry_computation()->root_instruction(), - GmockMatch(m::Fusion(m::Parameter()))); + EXPECT_THAT( + module->entry_computation()->root_instruction(), + GmockMatch( + m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig))); } TEST_P( @@ -1273,11 +1336,13 @@ ENTRY main { auto module = ParseAndReturnVerifiedModule(hlo_string).value(); EXPECT_TRUE( - SoftmaxRewriterTritonMatchAndRewrite(gpu_version_, module.get()).value()); + SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value()); EXPECT_TRUE(verifier().Run(module.get()).status().ok()); VLOG(2) << module->ToString(); - EXPECT_THAT(module->entry_computation()->root_instruction(), - GmockMatch(m::Fusion(m::Parameter()))); + EXPECT_THAT( + module->entry_computation()->root_instruction(), + GmockMatch( + m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig))); } TEST_P(SoftmaxRewriterTritonTest, @@ -1314,11 +1379,13 @@ ENTRY main { auto module = ParseAndReturnVerifiedModule(hlo_string).value(); EXPECT_TRUE( - SoftmaxRewriterTritonMatchAndRewrite(gpu_version_, module.get()).value()); + SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value()); EXPECT_TRUE(verifier().Run(module.get()).status().ok()); VLOG(2) << module->ToString(); - EXPECT_THAT(module->entry_computation()->root_instruction(), - GmockMatch(m::Fusion(m::Parameter()))); + EXPECT_THAT( 
+ module->entry_computation()->root_instruction(), + GmockMatch( + m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig))); } TEST_P(SoftmaxRewriterTritonTest, @@ -1351,11 +1418,13 @@ ENTRY main { auto module = ParseAndReturnVerifiedModule(hlo_string).value(); EXPECT_TRUE( - SoftmaxRewriterTritonMatchAndRewrite(gpu_version_, module.get()).value()); + SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value()); EXPECT_TRUE(verifier().Run(module.get()).status().ok()); VLOG(2) << module->ToString(); - EXPECT_THAT(module->entry_computation()->root_instruction(), - GmockMatch(m::Fusion(m::Parameter()))); + EXPECT_THAT( + module->entry_computation()->root_instruction(), + GmockMatch( + m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig))); } TEST_P( @@ -1384,7 +1453,7 @@ ENTRY main { auto module = ParseAndReturnVerifiedModule(hlo_string).value(); EXPECT_FALSE( - SoftmaxRewriterTritonMatchAndRewrite(gpu_version_, module.get()).value()); + SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value()); } TEST_P(SoftmaxRewriterTritonTest, @@ -1422,11 +1491,13 @@ ENTRY main { auto module = ParseAndReturnVerifiedModule(hlo_string).value(); EXPECT_TRUE( - SoftmaxRewriterTritonMatchAndRewrite(gpu_version_, module.get()).value()); + SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value()); EXPECT_TRUE(verifier().Run(module.get()).status().ok()); VLOG(2) << module->ToString(); - EXPECT_THAT(module->entry_computation()->root_instruction(), - GmockMatch(m::Fusion(m::Parameter()))); + EXPECT_THAT( + module->entry_computation()->root_instruction(), + GmockMatch( + m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig))); } TEST_F( @@ -1456,7 +1527,7 @@ ENTRY main { auto module = ParseAndReturnVerifiedModule(hlo_string).value(); EXPECT_FALSE( - SoftmaxRewriterTritonMatchAndRewrite(gpu_version_, module.get()).value()); + SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value()); } TEST_F( 
@@ -1483,9 +1554,9 @@ ENTRY main { )"; auto module = ParseAndReturnVerifiedModule(hlo_string).value(); - SoftmaxRewriterTriton fusion_rewriter(gpu_version_); + SoftmaxRewriterTriton fusion_rewriter(device_info_, ShapeSizeBytesFunction()); EXPECT_FALSE( - SoftmaxRewriterTritonMatchAndRewrite(gpu_version_, module.get()).value()); + SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value()); } TEST_P(SoftmaxRewriterTritonTest, CanFuseRMSNormDiamond) { @@ -1523,16 +1594,17 @@ ENTRY main.30 { case F32: case BF16: EXPECT_TRUE( - SoftmaxRewriterTritonMatchAndRewrite(gpu_version_, module.get()) + SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()) .value()); EXPECT_TRUE(verifier().Run(module.get()).status().ok()); EXPECT_THAT(module->entry_computation()->root_instruction(), - GmockMatch(m::Fusion(m::Parameter()))); + GmockMatch(m::Fusion(m::Parameter()) + .WithPredicate(HasBlockLevelFusionConfig))); break; case F16: // Triton does not support F16 rsqrt. EXPECT_FALSE( - SoftmaxRewriterTritonMatchAndRewrite(gpu_version_, module.get()) + SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()) .value()); break; default: @@ -1572,11 +1644,13 @@ ENTRY main { auto module = ParseAndReturnVerifiedModule(hlo_string).value(); EXPECT_TRUE( - SoftmaxRewriterTritonMatchAndRewrite(gpu_version_, module.get()).value()); + SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value()); EXPECT_TRUE(verifier().Run(module.get()).status().ok()); VLOG(2) << module->ToString(); - EXPECT_THAT(module->entry_computation()->root_instruction(), - GmockMatch(m::Fusion(m::Parameter()))); + EXPECT_THAT( + module->entry_computation()->root_instruction(), + GmockMatch( + m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig))); } TEST_P( @@ -1611,11 +1685,13 @@ ENTRY main { auto module = ParseAndReturnVerifiedModule(hlo_string).value(); EXPECT_TRUE( - SoftmaxRewriterTritonMatchAndRewrite(gpu_version_, module.get()).value()); + 
SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value()); EXPECT_TRUE(verifier().Run(module.get()).status().ok()); VLOG(2) << module->ToString(); - EXPECT_THAT(module->entry_computation()->root_instruction(), - GmockMatch(m::Fusion(m::Parameter()))); + EXPECT_THAT( + module->entry_computation()->root_instruction(), + GmockMatch( + m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig))); } TEST_P( @@ -1646,11 +1722,13 @@ ENTRY main { auto module = ParseAndReturnVerifiedModule(hlo_string).value(); EXPECT_TRUE( - SoftmaxRewriterTritonMatchAndRewrite(gpu_version_, module.get()).value()); + SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value()); EXPECT_TRUE(verifier().Run(module.get()).status().ok()); VLOG(2) << module->ToString(); - EXPECT_THAT(module->entry_computation()->root_instruction(), - GmockMatch(m::Fusion(m::Parameter()))); + EXPECT_THAT( + module->entry_computation()->root_instruction(), + GmockMatch( + m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig))); } TEST_P(SoftmaxRewriterTritonTest, @@ -1680,11 +1758,13 @@ ENTRY main { auto module = ParseAndReturnVerifiedModule(hlo_string).value(); EXPECT_TRUE( - SoftmaxRewriterTritonMatchAndRewrite(gpu_version_, module.get()).value()); + SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value()); EXPECT_TRUE(verifier().Run(module.get()).status().ok()); VLOG(2) << module->ToString(); - EXPECT_THAT(module->entry_computation()->root_instruction(), - GmockMatch(m::Fusion(m::Parameter()))); + EXPECT_THAT( + module->entry_computation()->root_instruction(), + GmockMatch( + m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig))); } TEST_P(SoftmaxRewriterTritonTest, @@ -1716,11 +1796,13 @@ ENTRY main { auto module = ParseAndReturnVerifiedModule(hlo_string).value(); EXPECT_TRUE( - SoftmaxRewriterTritonMatchAndRewrite(gpu_version_, module.get()).value()); + SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value()); 
EXPECT_TRUE(verifier().Run(module.get()).status().ok()); VLOG(2) << module->ToString(); - EXPECT_THAT(module->entry_computation()->root_instruction(), - GmockMatch(m::Fusion(m::Parameter()))); + EXPECT_THAT( + module->entry_computation()->root_instruction(), + GmockMatch( + m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig))); } TEST_F( @@ -1752,7 +1834,7 @@ ENTRY main { auto module = ParseAndReturnVerifiedModule(hlo_string).value(); EXPECT_FALSE( - SoftmaxRewriterTritonMatchAndRewrite(gpu_version_, module.get()).value()); + SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value()); } TEST_F(SoftmaxRewriterTritonTest, FusionDecisionIsCapturedExplicitly) { @@ -1773,7 +1855,8 @@ ENTRY main { )"; auto module = ParseAndReturnVerifiedModule(hlo_string).value(); - SoftmaxRewriterTriton softmax_rewriter_triton(gpu_version_); + SoftmaxRewriterTriton softmax_rewriter_triton(device_info_, + ShapeSizeBytesFunction()); int unmatched = 0, matched = 0; for (HloInstruction* instruction : module->entry_computation()->MakeInstructionPostOrder()) { @@ -1822,7 +1905,7 @@ ENTRY main { })"; auto module = ParseAndReturnVerifiedModule(hlo_string).value(); EXPECT_TRUE( - SoftmaxRewriterTritonMatchAndRewrite(gpu_version_, module.get()).value()); + SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value()); } TEST_F( @@ -1850,7 +1933,7 @@ ENTRY main { })"; auto module = ParseAndReturnVerifiedModule(hlo_string).value(); EXPECT_TRUE( - SoftmaxRewriterTritonMatchAndRewrite(gpu_version_, module.get()).value()); + SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value()); } TEST_F( @@ -1878,7 +1961,7 @@ ENTRY main { })"; auto module = ParseAndReturnVerifiedModule(hlo_string).value(); EXPECT_TRUE( - SoftmaxRewriterTritonMatchAndRewrite(gpu_version_, module.get()).value()); + SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value()); } TEST_F( @@ -1906,7 +1989,7 @@ ENTRY main { })"; auto module = 
ParseAndReturnVerifiedModule(hlo_string).value(); EXPECT_TRUE( - SoftmaxRewriterTritonMatchAndRewrite(gpu_version_, module.get()).value()); + SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value()); } TEST_F( @@ -1934,7 +2017,7 @@ ENTRY main { })"; auto module = ParseAndReturnVerifiedModule(hlo_string).value(); EXPECT_TRUE( - SoftmaxRewriterTritonMatchAndRewrite(gpu_version_, module.get()).value()); + SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value()); } TEST_F( @@ -1962,7 +2045,7 @@ ENTRY main { })"; auto module = ParseAndReturnVerifiedModule(hlo_string).value(); EXPECT_FALSE( - SoftmaxRewriterTritonMatchAndRewrite(gpu_version_, module.get()).value()); + SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value()); } TEST_F( @@ -1990,7 +2073,7 @@ ENTRY main { })"; auto module = ParseAndReturnVerifiedModule(hlo_string).value(); EXPECT_FALSE( - SoftmaxRewriterTritonMatchAndRewrite(gpu_version_, module.get()).value()); + SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value()); } TEST_F( @@ -2019,7 +2102,7 @@ ENTRY main { )"; auto module = ParseAndReturnVerifiedModule(hlo_string).value(); EXPECT_FALSE( - SoftmaxRewriterTritonMatchAndRewrite(gpu_version_, module.get()).value()); + SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value()); } TEST_F( @@ -2047,7 +2130,7 @@ ENTRY main { })"; auto module = ParseAndReturnVerifiedModule(hlo_string).value(); EXPECT_FALSE( - SoftmaxRewriterTritonMatchAndRewrite(gpu_version_, module.get()).value()); + SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value()); } INSTANTIATE_TEST_SUITE_P(SoftmaxRewriterTritonTestSuite, From 45f67b2c266ffd272e96e1207cd9e28ddfb3f070 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 20 Jun 2024 08:07:08 -0700 Subject: [PATCH 057/256] Only split those constants that are shared between manually and automatically sharded regions of the graph. 
PiperOrigin-RevId: 645034610 --- .../xla/hlo/experimental/auto_sharding/BUILD | 2 + .../auto_sharding/auto_sharding.cc | 15 +++++++ .../auto_sharding/auto_sharding_option.h | 3 ++ .../auto_sharding/auto_sharding_test.cc | 27 ++++++++++-- .../auto_sharding/auto_sharding_util.cc | 43 ++++++++++++++++++- .../auto_sharding/auto_sharding_util.h | 4 ++ 6 files changed, 90 insertions(+), 4 deletions(-) diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/BUILD b/third_party/xla/xla/hlo/experimental/auto_sharding/BUILD index c9796e4eff1536..a8a4b7e27fb09c 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/BUILD +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/BUILD @@ -48,6 +48,7 @@ cc_library( "//xla:util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", + "//xla/hlo/transforms:hlo_constant_splitter", "//xla/hlo/utils:hlo_live_range", "//xla/hlo/utils:hlo_sharding_util", "//xla/service:buffer_value", @@ -58,6 +59,7 @@ cc_library( "//xla/service:hlo_alias_analysis", "//xla/service:hlo_buffer", "//xla/service:hlo_cost_analysis", + "//xla/service:hlo_dce", "//xla/service:hlo_memory_scheduler", "//xla/service:hlo_pass", "//xla/service:hlo_value", diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc index 030ae8c296159b..b2ec822336e1cc 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc @@ -67,6 +67,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/ir/hlo_schedule.h" #include "xla/hlo/ir/hlo_sharding.h" +#include "xla/hlo/transforms/hlo_constant_splitter.h" #include "xla/hlo/utils/hlo_live_range.h" #include "xla/hlo/utils/hlo_sharding_util.h" #include "xla/service/buffer_value.h" @@ -76,6 +77,7 @@ limitations under the License. 
#include "xla/service/hlo_alias_analysis.h" #include "xla/service/hlo_buffer.h" #include "xla/service/hlo_cost_analysis.h" +#include "xla/service/hlo_dce.h" #include "xla/service/hlo_memory_scheduler.h" #include "xla/service/hlo_value.h" #include "xla/service/optimize_input_output_buffer_alias.h" @@ -4386,7 +4388,20 @@ absl::StatusOr AutoSharding::Run( } } + // Run HloConstantSplitter for modules with manually partitioned sub-graphs to + // avoid having constant ops that are used as part of such manually + // partitioned sub-graphs, as well as outside those, leading to conflicts + // during sharding. However, constant splitting can cause increased + // auto-sharding times, and hence we enable this only when needed. bool module_is_manually_partitioned = ModuleIsManuallyPartitioned(module); + if (module_is_manually_partitioned) { + HloConstantSplitter constant_splitter( + /*split_expressions=*/option_.enable_expression_constant_splitter, + /*extra_constraints=*/spmd::OpEncountersShardToFull); + CHECK_OK(constant_splitter.Run(module, execution_threads)); + CHECK_OK(HloDCE().Run(module, execution_threads)); + } + std::vector> mesh_shapes; if (option_.try_multiple_mesh_shapes || module_is_manually_partitioned) { bool asymmetrical_mesh_dims = false; diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_option.h b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_option.h index 2165b9e78d880b..b2fbb1eabd551e 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_option.h +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_option.h @@ -206,6 +206,9 @@ struct AutoShardingOption { // number of devices larger than the size of the tensor dimension bool allow_shardings_small_dims_across_many_devices = false; + // Split constant expressions as well when invoking HloConstantSplitter. + bool enable_expression_constant_splitter = false; + // Prints a debug string. 
std::string ToString() const; diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_test.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_test.cc index e10b9dbebec4e2..c7fd9bdc42dc09 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_test.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_test.cc @@ -34,6 +34,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_input_output_alias_config.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/ir/hlo_schedule.h" #include "xla/hlo/ir/hlo_sharding.h" #include "xla/hlo/utils/hlo_live_range.h" @@ -53,9 +54,11 @@ namespace op = xla::testing::opcode_matchers; namespace xla { namespace spmd { namespace { +using ::testing::Contains; using ::testing::Each; using ::testing::ElementsAre; using ::testing::ElementsAreArray; +using ::testing::Eq; using ::testing::FieldsAre; using ::testing::IsEmpty; using ::testing::IsFalse; @@ -432,7 +435,7 @@ ENTRY %RngBitGenerator (p0: u64[2]) -> (u64[2], u32[16,16]) { EXPECT_THAT(instruction, op::Sharding("{replicated}")); } -TEST_F(AutoShardingTest, SPMDShardToFullShapeTest) { +TEST_F(AutoShardingTest, SPMDShardToFullShapeWithConstantTest) { constexpr absl::string_view kHloString = R"( HloModule rng_bit_generator @@ -444,11 +447,16 @@ add.6.clone { ENTRY main { input.1 = bf16[512,512]{1,0} parameter(0) + constant.1 = bf16[] constant(16.7) + broadcast.1 = bf16[128,128]{1,0} broadcast(constant.1), dimensions={} + broadcast.2 = bf16[512,512]{1,0} broadcast(constant.1), dimensions={} custom-call.1 = bf16[512,512]{1,0} custom-call(input.1), custom_call_target="Sharding", sharding={devices=[4,4]<=[16]} custom-call.2 = bf16[128,128]{1,0} custom-call(custom-call.1), custom_call_target="SPMDFullToShardShape", sharding={manual} all-reduce.1 = bf16[128,128]{1,0} all-reduce(custom-call.2), channel_id=621, 
replica_groups={{0,1,2,3},{4,5,6,7},{8,9,10,11},{12,13,14,15}}, use_global_device_ids=true, to_apply=add.6.clone, frontend_attributes={from-cross-replica-sharding="true"}, backend_config={"flag_configs":[],"barrier_config":{"barrier_type":"CUSTOM","id":"9"},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_DEFAULT","device_type":"DEVICE_TYPE_INVALID","used_scoped_memory_configs":[]} - custom-call.3 = bf16[512,512]{1,0} custom-call(all-reduce.1), custom_call_target="SPMDShardToFullShape", sharding={devices=[4,1,4]<=[16]last_tile_dim_replicate} - ROOT copy.1 = copy(custom-call.3) + add.1 = bf16[128,128]{1,0} add(bf16[128,128]{1,0} all-reduce.1, bf16[128,128]{1,0} broadcast.1) + custom-call.3 = bf16[512,512]{1,0} custom-call(add.1), custom_call_target="SPMDShardToFullShape", sharding={devices=[4,1,4]<=[16]last_tile_dim_replicate} + add.2 = bf16[512,512]{1,0} add(bf16[512,512]{1,0} custom-call.3, bf16[512,512]{1,0} broadcast.2) + ROOT copy.1 = bf16[512,512]{1,0} copy(add.2) })"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, @@ -481,6 +489,19 @@ ENTRY main { const HloInstruction* custom_call1 = custom_call2->operand(0); ASSERT_NE(custom_call1, nullptr); EXPECT_THAT(custom_call1, op::Sharding("{devices=[4,4]<=[16]}")); + + // Check that there are two constant instructions as we split one that is + // shared + std::vector instructions( + module->entry_computation()->instructions().begin(), + module->entry_computation()->instructions().end()); + EXPECT_THAT( + module->entry_computation()->instructions(), + Contains(ResultOf( + "opcode", + [](const HloInstruction* ins) { return ins->opcode(); }, + Eq(HloOpcode::kConstant))) + .Times(2)); } TEST_F(AutoShardingTest, SPMDShardToFullShapeMultipleValidMeshShapeTest) { diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc index 9ece09787cc8b2..039fc07706c7a0 100644 --- 
a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc @@ -16,7 +16,6 @@ limitations under the License. #include "xla/hlo/experimental/auto_sharding/auto_sharding_util.h" #include -#include #include #include #include @@ -24,6 +23,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -2373,5 +2373,46 @@ absl::StatusOr GetPartialReduceReductionDim( return parsed_json[kReductionDimKey].asInt64(); } +bool OpEncountersShardToFull(const HloInstruction* op) { + std::queue queue; + queue.push(op); + + absl::flat_hash_set visited; + while (!queue.empty()) { + const HloInstruction* instruction = queue.front(); + queue.pop(); + if (visited.contains(instruction)) { + continue; + } + visited.insert(instruction); + + for (const HloComputation* computation : + instruction->called_computations()) { + for (const HloInstruction* parameter : + computation->parameter_instructions()) { + if (spmd::IsSPMDShardToFullShapeCustomCall(parameter)) { + return true; + } else if (spmd::IsSPMDFullToShardShapeCustomCall(parameter) || + parameter == instruction || visited.contains(parameter)) { + continue; + } + queue.push(parameter); + } + } + + for (const HloInstruction* user : instruction->users()) { + if (spmd::IsSPMDShardToFullShapeCustomCall(user)) { + return true; + } else if (spmd::IsSPMDFullToShardShapeCustomCall(user) || + visited.contains(user)) { + continue; + } + queue.push(user); + } + } + + return false; +} + } // namespace spmd } // namespace xla diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.h b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.h index d06b705233fc85..e86dce3966fca4 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.h +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.h @@ -676,6 +676,10 @@ HloSharding 
ReplaceGivenShardingsWithUnknownForTuple( // Extract the reduction_dim of a PartialReduce custom call absl::StatusOr GetPartialReduceReductionDim(const HloInstruction* ins); +// Returns true if an HLO op flows to a SPMDShardToFullShape custom call without +// encountering a SPMDFullToShardShape custom call on the call. +bool OpEncountersShardToFull(const HloInstruction* op); + } // namespace spmd } // namespace xla From 9e0fce7d28c509c7f4bcd7e1e2a5d07650526e29 Mon Sep 17 00:00:00 2001 From: Kyle Lucke Date: Thu, 20 Jun 2024 08:15:43 -0700 Subject: [PATCH 058/256] Move StreamExecutor::Memset32 into Stream and its derived classes. PiperOrigin-RevId: 645037176 --- tensorflow/c/experimental/stream_executor/BUILD | 2 ++ .../stream_executor/stream_executor.cc | 9 --------- .../stream_executor/stream_executor_internal.h | 17 ++++++++++++++++- .../xla/xla/backends/interpreter/executor.h | 4 ---- .../xla/stream_executor/cuda/cuda_executor.cc | 12 ------------ .../xla/xla/stream_executor/gpu/gpu_executor.h | 2 -- .../xla/xla/stream_executor/gpu/gpu_stream.cc | 10 ++++++++++ .../xla/xla/stream_executor/gpu/gpu_stream.h | 3 +++ .../xla/stream_executor/host/host_executor.cc | 10 ---------- .../xla/stream_executor/host/host_executor.h | 2 -- .../xla/xla/stream_executor/host/host_stream.cc | 13 +++++++++++++ .../xla/xla/stream_executor/host/host_stream.h | 2 ++ .../xla/stream_executor/mock_stream_executor.h | 4 ---- .../xla/stream_executor/rocm/rocm_executor.cc | 12 ------------ third_party/xla/xla/stream_executor/stream.h | 5 ++++- .../xla/xla/stream_executor/stream_common.cc | 5 ----- .../xla/xla/stream_executor/stream_common.h | 2 -- .../xla/xla/stream_executor/stream_executor.h | 7 ------- .../xla/xla/stream_executor/tpu/tpu_executor.h | 4 ---- 19 files changed, 50 insertions(+), 75 deletions(-) diff --git a/tensorflow/c/experimental/stream_executor/BUILD b/tensorflow/c/experimental/stream_executor/BUILD index e14f94ae16f8df..d51568003e16fa 100644 --- 
a/tensorflow/c/experimental/stream_executor/BUILD +++ b/tensorflow/c/experimental/stream_executor/BUILD @@ -71,8 +71,10 @@ cc_library( "//tensorflow/c:c_api_macros", "//tensorflow/c:tf_status", "//tensorflow/c:tf_status_helper", + "@com_google_absl//absl/status", "@local_tsl//tsl/platform:statusor", "@local_xla//xla/stream_executor", + "@local_xla//xla/stream_executor:event", "@local_xla//xla/stream_executor:executor_cache", "@local_xla//xla/stream_executor:stream_common", "@local_xla//xla/stream_executor:stream_executor_h", diff --git a/tensorflow/c/experimental/stream_executor/stream_executor.cc b/tensorflow/c/experimental/stream_executor/stream_executor.cc index ef54b703a68db6..2bca4311fd7cc4 100644 --- a/tensorflow/c/experimental/stream_executor/stream_executor.cc +++ b/tensorflow/c/experimental/stream_executor/stream_executor.cc @@ -312,15 +312,6 @@ class CStreamExecutor : public StreamExecutorCommon { size, c_status.get()); return StatusFromTF_Status(c_status.get()); } - absl::Status Memset32(Stream* stream, DeviceMemoryBase* location, - uint32 pattern, uint64 size) override { - OwnedTFStatus c_status(TF_NewStatus()); - SP_Stream stream_handle = static_cast(stream)->Handle(); - SP_DeviceMemoryBase device_mem = DeviceMemoryBaseToC(location); - stream_executor_->memset32(&device_, stream_handle, &device_mem, pattern, - size, c_status.get()); - return StatusFromTF_Status(c_status.get()); - } absl::Status Memcpy(Stream* stream, void* host_dst, const DeviceMemoryBase& gpu_src, uint64 size) override { OwnedTFStatus c_status(TF_NewStatus()); diff --git a/tensorflow/c/experimental/stream_executor/stream_executor_internal.h b/tensorflow/c/experimental/stream_executor/stream_executor_internal.h index 21447ba6aa16f5..d8e0ec75365b95 100644 --- a/tensorflow/c/experimental/stream_executor/stream_executor_internal.h +++ b/tensorflow/c/experimental/stream_executor/stream_executor_internal.h @@ -1,4 +1,3 @@ -#include "xla/stream_executor/stream.h" /* Copyright 2020 The 
TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); @@ -19,10 +18,18 @@ limitations under the License. #ifndef TENSORFLOW_C_EXPERIMENTAL_STREAM_EXECUTOR_STREAM_EXECUTOR_INTERNAL_H_ #define TENSORFLOW_C_EXPERIMENTAL_STREAM_EXECUTOR_STREAM_EXECUTOR_INTERNAL_H_ +#include +#include + +#include "absl/status/status.h" #include "tensorflow/c/experimental/stream_executor/stream_executor.h" +#include "tensorflow/c/tf_status.h" #include "tensorflow/c/tf_status_helper.h" +#include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/event.h" #include "xla/stream_executor/executor_cache.h" #include "xla/stream_executor/platform.h" +#include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_common.h" #include "xla/stream_executor/stream_executor.h" #include "tsl/platform/statusor.h" @@ -220,6 +227,14 @@ class CStream : public StreamCommon { c_status.get()); return tensorflow::StatusFromTF_Status(c_status.get()); } + absl::Status Memset32(DeviceMemoryBase* location, uint32_t pattern, + uint64_t size) override { + tensorflow::TF_StatusPtr c_status(TF_NewStatus()); + SP_DeviceMemoryBase device_mem = DeviceMemoryBaseToC(location); + stream_executor_->memset32(device_, stream_handle_, &device_mem, pattern, + size, c_status.get()); + return tensorflow::StatusFromTF_Status(c_status.get()); + } SP_Stream Handle() { return stream_handle_; } diff --git a/third_party/xla/xla/backends/interpreter/executor.h b/third_party/xla/xla/backends/interpreter/executor.h index d103d94b99ac47..7a9419c73ef085 100644 --- a/third_party/xla/xla/backends/interpreter/executor.h +++ b/third_party/xla/xla/backends/interpreter/executor.h @@ -104,10 +104,6 @@ class XlaInterpreterExecutor : public StreamExecutorCommon { uint8_t pattern, uint64_t size) override { return absl::InternalError("Interpreter can not memset"); } - absl::Status Memset32(Stream *stream, DeviceMemoryBase *location, - uint32_t pattern, uint64_t size) 
override { - return absl::InternalError("Interpreter can not memset"); - } // No "synchronize all activity" implemented for this platform at the moment. bool SynchronizeAllActivity() override { return true; } diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc b/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc index 68f79ea1f2be9c..1a5f070a3a4d00 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc @@ -681,18 +681,6 @@ absl::Status GpuExecutor::Memset(Stream* stream, DeviceMemoryBase* location, AsGpuStreamValue(stream)); } -absl::Status GpuExecutor::Memset32(Stream* stream, DeviceMemoryBase* location, - uint32_t pattern, uint64_t size) { - VLOG(2) << "enqueueing memset32 operation onto stream " << stream - << " at location " << location << " with size " << size - << " and pattern " << std::hex << pattern; - CHECK(reinterpret_cast(location->opaque()) % 4 == 0 && - size % 4 == 0); - return GpuDriver::AsynchronousMemsetUint32( - context_, AsCudaDevicePtr(location), pattern, size / 4, - AsGpuStreamValue(stream)); -} - absl::Status GpuExecutor::Memcpy(Stream* stream, void* host_dst, const DeviceMemoryBase& gpu_src, uint64_t size) { diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_executor.h b/third_party/xla/xla/stream_executor/gpu/gpu_executor.h index 775ef19365a99f..097c5143e32f26 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_executor.h +++ b/third_party/xla/xla/stream_executor/gpu/gpu_executor.h @@ -210,8 +210,6 @@ class GpuExecutor : public StreamExecutorCommon { absl::Status Memset(Stream* stream, DeviceMemoryBase* location, uint8_t pattern, uint64_t size) override; - absl::Status Memset32(Stream* stream, DeviceMemoryBase* location, - uint32_t pattern, uint64_t size) override; absl::Status Memcpy(Stream* stream, void* host_dst, const DeviceMemoryBase& gpu_src, uint64_t size) override; diff --git 
a/third_party/xla/xla/stream_executor/gpu/gpu_stream.cc b/third_party/xla/xla/stream_executor/gpu/gpu_stream.cc index 536bb1adefc9c4..7130e00a878762 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_stream.cc +++ b/third_party/xla/xla/stream_executor/gpu/gpu_stream.cc @@ -58,6 +58,16 @@ Stream::PlatformSpecificHandle GpuStream::platform_specific_handle() const { return handle; } +absl::Status GpuStream::Memset32(DeviceMemoryBase* location, uint32_t pattern, + uint64_t size) { + CHECK(reinterpret_cast(location->opaque()) % 4 == 0 && + size % 4 == 0); + return GpuDriver::AsynchronousMemsetUint32( + parent_->gpu_context(), + reinterpret_cast(location->opaque()), pattern, size / 4, + gpu_stream()); +} + absl::Status GpuStream::MemZero(DeviceMemoryBase* location, uint64_t size) { if (reinterpret_cast(location->opaque()) % 4 == 0 && size % 4 == 0) { diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_stream.h b/third_party/xla/xla/stream_executor/gpu/gpu_stream.h index d3ac630f214ce1..633a2341dabcd6 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_stream.h +++ b/third_party/xla/xla/stream_executor/gpu/gpu_stream.h @@ -19,6 +19,7 @@ limitations under the License. #ifndef XLA_STREAM_EXECUTOR_GPU_GPU_STREAM_H_ #define XLA_STREAM_EXECUTOR_GPU_GPU_STREAM_H_ +#include #include #include "absl/log/check.h" @@ -99,6 +100,8 @@ class GpuStream : public StreamCommon { absl::Status WaitFor(Event* event) override; absl::Status RecordEvent(Event* event) override; absl::Status MemZero(DeviceMemoryBase* location, uint64_t size) override; + absl::Status Memset32(DeviceMemoryBase* location, uint32_t pattern, + uint64_t size) override; private: GpuExecutor* parent_; // Executor that spawned this stream. 
diff --git a/third_party/xla/xla/stream_executor/host/host_executor.cc b/third_party/xla/xla/stream_executor/host/host_executor.cc index 0d75e7e19962d6..95194dbc397033 100644 --- a/third_party/xla/xla/stream_executor/host/host_executor.cc +++ b/third_party/xla/xla/stream_executor/host/host_executor.cc @@ -188,16 +188,6 @@ absl::Status HostExecutor::Memset(Stream* stream, DeviceMemoryBase* location, return absl::OkStatus(); } -absl::Status HostExecutor::Memset32(Stream* stream, DeviceMemoryBase* location, - uint32_t pattern, uint64_t size) { - void* gpu_mem = location->opaque(); - // Enqueue the [asynchronous] memzero on the stream (HostStream) associated - // with the HostExecutor. - AsHostStream(stream)->EnqueueTask( - [gpu_mem, size, pattern]() { memset(gpu_mem, pattern, size); }); - return absl::OkStatus(); -} - absl::Status HostExecutor::SynchronousMemcpy(DeviceMemoryBase* gpu_dst, const void* host_src, uint64_t size) { diff --git a/third_party/xla/xla/stream_executor/host/host_executor.h b/third_party/xla/xla/stream_executor/host/host_executor.h index 5339315985d1a8..6d9d64a33bfd3e 100644 --- a/third_party/xla/xla/stream_executor/host/host_executor.h +++ b/third_party/xla/xla/stream_executor/host/host_executor.h @@ -100,8 +100,6 @@ class HostExecutor : public StreamExecutorCommon { absl::Status Memset(Stream* stream, DeviceMemoryBase* location, uint8_t pattern, uint64_t size) override; - absl::Status Memset32(Stream* stream, DeviceMemoryBase* location, - uint32_t pattern, uint64_t size) override; // No "synchronize all activity" implemented for this platform at the moment. 
bool SynchronizeAllActivity() override { return true; } diff --git a/third_party/xla/xla/stream_executor/host/host_stream.cc b/third_party/xla/xla/stream_executor/host/host_stream.cc index 3a08013ee1225b..a45f188ee82fc0 100644 --- a/third_party/xla/xla/stream_executor/host/host_stream.cc +++ b/third_party/xla/xla/stream_executor/host/host_stream.cc @@ -17,7 +17,10 @@ limitations under the License. // the HostExecutor implementation. #include "xla/stream_executor/host/host_stream.h" +#include + #include // NOLINT +#include #include #include #include @@ -27,6 +30,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/synchronization/mutex.h" #include "absl/synchronization/notification.h" +#include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/event.h" #include "xla/stream_executor/host/host_event.h" #include "xla/stream_executor/stream.h" @@ -53,6 +57,15 @@ HostStream::~HostStream() { parent()->DeallocateStream(this); } +absl::Status HostStream::Memset32(DeviceMemoryBase* location, uint32_t pattern, + uint64_t size) { + void* gpu_mem = location->opaque(); + // Enqueue the [asynchronous] memzero on the stream (HostStream) associated + // with the HostExecutor. 
+ EnqueueTask([gpu_mem, size, pattern]() { memset(gpu_mem, pattern, size); }); + return absl::OkStatus(); +} + absl::Status HostStream::MemZero(DeviceMemoryBase* location, uint64_t size) { void* gpu_mem = location->opaque(); // Enqueue the [asynchronous] memzero on the stream (HostStream) associated diff --git a/third_party/xla/xla/stream_executor/host/host_stream.h b/third_party/xla/xla/stream_executor/host/host_stream.h index 105b4704f8609b..83dd18be8a9693 100644 --- a/third_party/xla/xla/stream_executor/host/host_stream.h +++ b/third_party/xla/xla/stream_executor/host/host_stream.h @@ -55,6 +55,8 @@ class HostStream : public StreamCommon { absl::Status WaitFor(Event* event) override; absl::Status RecordEvent(Event* event) override; absl::Status MemZero(DeviceMemoryBase* location, uint64_t size) override; + absl::Status Memset32(DeviceMemoryBase* location, uint32_t pattern, + uint64_t size) override; private: bool WorkAvailable() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); diff --git a/third_party/xla/xla/stream_executor/mock_stream_executor.h b/third_party/xla/xla/stream_executor/mock_stream_executor.h index f5b1dc722d98ae..ac19fcf0932476 100644 --- a/third_party/xla/xla/stream_executor/mock_stream_executor.h +++ b/third_party/xla/xla/stream_executor/mock_stream_executor.h @@ -109,10 +109,6 @@ class MockStreamExecutor : public StreamExecutor { (Stream * stream, DeviceMemoryBase* location, uint8_t pattern, uint64_t size), (override)); - MOCK_METHOD(absl::Status, Memset32, - (Stream * stream, DeviceMemoryBase* location, uint32_t pattern, - uint64_t size), - (override)); MOCK_METHOD(absl::Status, Memcpy, (Stream * stream, void* host_dst, const DeviceMemoryBase& device_src, uint64_t size), diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_executor.cc b/third_party/xla/xla/stream_executor/rocm/rocm_executor.cc index 32b9fc4267d694..c8b00b133f898c 100644 --- a/third_party/xla/xla/stream_executor/rocm/rocm_executor.cc +++ 
b/third_party/xla/xla/stream_executor/rocm/rocm_executor.cc @@ -582,18 +582,6 @@ absl::Status GpuExecutor::Memset(Stream* stream, DeviceMemoryBase* location, AsGpuStreamValue(stream)); } -absl::Status GpuExecutor::Memset32(Stream* stream, DeviceMemoryBase* location, - uint32 pattern, uint64_t size) { - VLOG(2) << "enqueueing memset32 operation onto stream " << stream - << " at location " << location << " with size " << size - << " and pattern " << std::hex << pattern; - CHECK(reinterpret_cast(location->opaque()) % 4 == 0 && - size % 4 == 0); - return GpuDriver::AsynchronousMemsetUint32( - context_, AsROCmDevicePtr(location), pattern, size / 4, - AsGpuStreamValue(stream)); -} - absl::Status GpuExecutor::Memcpy(Stream* stream, void* host_dst, const DeviceMemoryBase& gpu_src, uint64_t size) { diff --git a/third_party/xla/xla/stream_executor/stream.h b/third_party/xla/xla/stream_executor/stream.h index 4fd904befad159..2cc531e2ef32bb 100644 --- a/third_party/xla/xla/stream_executor/stream.h +++ b/third_party/xla/xla/stream_executor/stream.h @@ -209,7 +209,10 @@ class Stream { // size bytes, where bytes must be evenly 32-bit sized (i.e. evenly divisible // by 4). The location must not be null. 
virtual absl::Status Memset32(DeviceMemoryBase *location, uint32_t pattern, - uint64_t size) = 0; + uint64_t size) { + return absl::UnimplementedError( + "Memset32 is not supported on this stream."); + } // (Synchronously) block the host code waiting for the operations // entrained on the stream (enqueued to this point in program diff --git a/third_party/xla/xla/stream_executor/stream_common.cc b/third_party/xla/xla/stream_executor/stream_common.cc index cf4c3ee7efd75e..44a09f16f827b2 100644 --- a/third_party/xla/xla/stream_executor/stream_common.cc +++ b/third_party/xla/xla/stream_executor/stream_common.cc @@ -163,11 +163,6 @@ absl::Status StreamCommon::Memcpy(DeviceMemoryBase *gpu_dst, return absl::InternalError("failed to memcpy"); } -absl::Status StreamCommon::Memset32(DeviceMemoryBase *location, - uint32_t pattern, uint64_t size) { - return parent_->Memset32(this, location, pattern, size); -} - absl::Status StreamCommon::DoHostCallback( absl::AnyInvocable callback) { return DoHostCallbackWithStatus([cb = std::move(callback)]() mutable { diff --git a/third_party/xla/xla/stream_executor/stream_common.h b/third_party/xla/xla/stream_executor/stream_common.h index 39e3e025fb2245..c83a0949d476f0 100644 --- a/third_party/xla/xla/stream_executor/stream_common.h +++ b/third_party/xla/xla/stream_executor/stream_common.h @@ -77,8 +77,6 @@ class StreamCommon : public Stream { uint64_t size) override; absl::Status Memcpy(DeviceMemoryBase *gpu_dst, const DeviceMemoryBase &gpu_src, uint64_t size) override; - absl::Status Memset32(DeviceMemoryBase *location, uint32_t pattern, - uint64_t size) override; absl::Status BlockHostUntilDone() override TF_LOCKS_EXCLUDED(mu_); absl::Status DoHostCallback(absl::AnyInvocable callback) override; absl::Status DoHostCallbackWithStatus( diff --git a/third_party/xla/xla/stream_executor/stream_executor.h b/third_party/xla/xla/stream_executor/stream_executor.h index 636dda8581203a..f629c7c2781848 100644 --- 
a/third_party/xla/xla/stream_executor/stream_executor.h +++ b/third_party/xla/xla/stream_executor/stream_executor.h @@ -231,13 +231,6 @@ class StreamExecutor { return absl::InternalError("Not implemented"); } - // Enqueues an operation onto stream to set 32-bit patterns starting at - // location, for byte count given by size. size must be 32-bit quantified - // (i.e. evenly divisible by 4). Returns whether the operation was - // successfully enqueued onto the stream. - virtual absl::Status Memset32(Stream* stream, DeviceMemoryBase* location, - uint32_t pattern, uint64_t size) = 0; - // Enqueues a memcpy operation onto stream, with a host destination location // host_dst and a device memory source, with target size size. virtual absl::Status Memcpy(Stream* stream, void* host_dst, diff --git a/third_party/xla/xla/stream_executor/tpu/tpu_executor.h b/third_party/xla/xla/stream_executor/tpu/tpu_executor.h index 73785f36cf1e37..f6d3352c84b9c5 100644 --- a/third_party/xla/xla/stream_executor/tpu/tpu_executor.h +++ b/third_party/xla/xla/stream_executor/tpu/tpu_executor.h @@ -140,10 +140,6 @@ class TpuExecutor : public tensorflow::tpu::TpuExecutorInterface { // -- Unimplemented (stubbed out) methods. - absl::Status Memset32(Stream* stream, DeviceMemoryBase* location, - uint32_t pattern, uint64_t size) override { - LOG(FATAL) << "not yet implemented"; - } absl::Status EnablePeerAccessTo(StreamExecutor* other) override { LOG(FATAL) << "not yet implemented"; } From 9e43d0798045f1177dc828e2e5a4f5259057361d Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 20 Jun 2024 09:16:00 -0700 Subject: [PATCH 059/256] Prevents linspace from generating nans for F8 types. 
PiperOrigin-RevId: 645056198 --- third_party/xla/xla/BUILD | 1 + third_party/xla/xla/array2d.h | 8 +++--- third_party/xla/xla/array2d_test.cc | 43 +++++++++++++++++++++++++++++ 3 files changed, 48 insertions(+), 4 deletions(-) diff --git a/third_party/xla/xla/BUILD b/third_party/xla/xla/BUILD index 5bb6f608bf0599..5b813d50f79dc2 100644 --- a/third_party/xla/xla/BUILD +++ b/third_party/xla/xla/BUILD @@ -778,6 +778,7 @@ xla_cc_test( deps = [ ":array2d", ":test", + "@local_tsl//tsl/platform:ml_dtypes", "@local_tsl//tsl/platform:test_main", ], ) diff --git a/third_party/xla/xla/array2d.h b/third_party/xla/xla/array2d.h index 2e8c1547a967a3..9afe514f5458c1 100644 --- a/third_party/xla/xla/array2d.h +++ b/third_party/xla/xla/array2d.h @@ -95,14 +95,14 @@ std::unique_ptr> MakeLinspaceArray2D(double from, double to, int64_t n1, int64_t n2) { auto array = std::make_unique>(n1, n2); int64_t count = n1 * n2; - NativeT step = - static_cast((count > 1) ? (to - from) / (count - 1) : 0); + double step = + static_cast((count > 1) ? (to - from) / (count - 1) : 0); + auto set = [&array, n2](int64_t index, NativeT value) { (*array)(index / n2, index % n2) = value; }; for (int64_t i = 0; i < count - 1; ++i) { - set(i, (static_cast(from) + - static_cast(i) * static_cast(step))); + set(i, (static_cast(from + i * step))); } set(count - 1, static_cast(to)); return array; diff --git a/third_party/xla/xla/array2d_test.cc b/third_party/xla/xla/array2d_test.cc index b7052d1c33f725..f3212f6bc20b72 100644 --- a/third_party/xla/xla/array2d_test.cc +++ b/third_party/xla/xla/array2d_test.cc @@ -19,6 +19,7 @@ limitations under the License. 
#include #include "xla/test.h" +#include "tsl/platform/ml_dtypes.h" namespace xla { namespace { @@ -146,6 +147,48 @@ TEST(Array2dTest, Linspace) { EXPECT_FLOAT_EQ((*arr)(2, 1), 3.5); } +TEST(Array2dTest, LinspaceF8E5M2) { + auto arr = MakeLinspaceArray2D(1.0, 3.5, 3, 2); + + EXPECT_EQ(arr->n1(), 3); + EXPECT_EQ(arr->n2(), 2); + + EXPECT_FLOAT_EQ(static_cast((*arr)(0, 0)), 1.0); + EXPECT_FLOAT_EQ(static_cast((*arr)(0, 1)), 1.5); + EXPECT_FLOAT_EQ(static_cast((*arr)(1, 0)), 2.0); + EXPECT_FLOAT_EQ(static_cast((*arr)(1, 1)), 2.5); + EXPECT_FLOAT_EQ(static_cast((*arr)(2, 0)), 3.0); + EXPECT_FLOAT_EQ(static_cast((*arr)(2, 1)), 3.5); +} + +TEST(Array2dTest, LinspaceF8E4M3Fn) { + auto arr = MakeLinspaceArray2D(1.0, 3.5, 3, 2); + + EXPECT_EQ(arr->n1(), 3); + EXPECT_EQ(arr->n2(), 2); + + EXPECT_FLOAT_EQ(static_cast((*arr)(0, 0)), 1.0); + EXPECT_FLOAT_EQ(static_cast((*arr)(0, 1)), 1.5); + EXPECT_FLOAT_EQ(static_cast((*arr)(1, 0)), 2.0); + EXPECT_FLOAT_EQ(static_cast((*arr)(1, 1)), 2.5); + EXPECT_FLOAT_EQ(static_cast((*arr)(2, 0)), 3.0); + EXPECT_FLOAT_EQ(static_cast((*arr)(2, 1)), 3.5); +} + +// We need this test for float8_e4m3fn because it doesn't have a representation +// for infinity. We need to ensure the algorithm used for linspace doesn't +// convert large numbers directly to F8E4M3FN. +TEST(Array2dTest, LinspaceF8E4M3FnNoNan) { + auto arr = MakeLinspaceArray2D(0, 1, 23, 42); + + for (int64_t n1 = 0; n1 < arr->n1(); ++n1) { + for (int64_t n2 = 0; n2 < arr->n2(); ++n2) { + // Check for NaN. 
+ EXPECT_EQ((*arr)(n1, n2), (*arr)(n1, n2)); + } + } +} + TEST(Array2dTest, Stringification) { auto arr = MakeLinspaceArray2D(1.0, 3.5, 3, 2); const std::string expected = R"([[1, 1.5], From a16c577adb042bc6572294f448b70131f2d8129f Mon Sep 17 00:00:00 2001 From: Goran Flegar Date: Thu, 20 Jun 2024 09:25:35 -0700 Subject: [PATCH 060/256] Disable wgmma support in XLA, since it is causing huge compile time regressions in certain benchmarks PiperOrigin-RevId: 645058828 --- third_party/xla/xla/service/gpu/ir_emitter_triton.cc | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/third_party/xla/xla/service/gpu/ir_emitter_triton.cc b/third_party/xla/xla/service/gpu/ir_emitter_triton.cc index 0dcb453d7655b0..371850d884a608 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter_triton.cc +++ b/third_party/xla/xla/service/gpu/ir_emitter_triton.cc @@ -95,6 +95,7 @@ limitations under the License. #include "mlir/Transforms/Passes.h" // from @llvm-project #include "xla/autotuning.pb.h" #include "xla/comparison_util.h" +#include "xla/debug_options_flags.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -2723,6 +2724,14 @@ absl::StatusOr TritonWrapper( } } + // TODO(b/344841434): Remove this once we fixed compile-time regressions on + // multiple benchmarks. We need to disable this for now, since it is + // significantly slowing down folks that are trying to run them. 
+ auto debug_options = GetDebugOptionsFromFlags(); + if (!debug_options.xla_gpu_enable_triton_hopper()) { + tsl::setenv("DISABLE_MMA_V3", "true", true /*overwrite*/); + } + TF_ASSIGN_OR_RETURN(auto triton_module, CreateTritonModule(fn_name, fusion, device_info, block_level_parameters, mlir_context)); From 9024f0262255f626ca621f259273cf2c7f15ea80 Mon Sep 17 00:00:00 2001 From: George Karpenkov Date: Thu, 20 Jun 2024 09:25:42 -0700 Subject: [PATCH 061/256] [XLA] [NFC] Use a single SequentialThunk to communicate a sequence of thunks No need to pass around a vector. PiperOrigin-RevId: 645058862 --- third_party/xla/xla/service/gpu/BUILD | 1 + .../service/gpu/compile_module_to_llvm_ir.h | 3 +- .../xla/xla/service/gpu/gpu_compiler.cc | 14 ++++----- .../xla/xla/service/gpu/gpu_executable.cc | 23 +++++--------- .../xla/xla/service/gpu/gpu_executable.h | 7 ++--- .../xla/service/gpu/ir_emitter_unnested.cc | 22 ++++++-------- .../xla/xla/service/gpu/ir_emitter_unnested.h | 6 ++-- third_party/xla/xla/service/gpu/runtime/BUILD | 1 + .../gpu/runtime/command_buffer_thunk.cc | 21 +++++-------- .../gpu/runtime/command_buffer_thunk.h | 7 +++-- .../xla/service/gpu/runtime/for_all_thunks.cc | 6 ++-- .../gpu/runtime/for_all_thunks_test.cc | 26 ++++++++++------ .../xla/service/gpu/runtime/kernel_thunk.cc | 4 +-- .../xla/service/gpu/runtime/kernel_thunk.h | 4 +-- .../service/gpu/runtime/sequential_thunk.cc | 30 +++++++++++++++---- .../service/gpu/runtime/sequential_thunk.h | 2 +- .../xla/xla/service/gpu/runtime/thunk.cc | 25 ---------------- .../xla/xla/service/gpu/runtime/thunk.h | 7 ++--- .../xla/service/gpu/runtime/while_thunk.cc | 10 +++---- .../xla/xla/service/gpu/runtime/while_thunk.h | 4 +-- 20 files changed, 103 insertions(+), 120 deletions(-) diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index 8899b034110add..98daa26d27e662 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -1003,6 +1003,7 
@@ cc_library( "//xla/service/gpu/runtime:for_all_thunks", "//xla/service/gpu/runtime:nccl_clique", "//xla/service/gpu/runtime:nccl_clique_key", + "//xla/service/gpu/runtime:sequential_thunk", "//xla/service/gpu/runtime:thunk", "//xla/stream_executor", "//xla/stream_executor:device_description", diff --git a/third_party/xla/xla/service/gpu/compile_module_to_llvm_ir.h b/third_party/xla/xla/service/gpu/compile_module_to_llvm_ir.h index 0fdef39f261df2..d7005f879c3994 100644 --- a/third_party/xla/xla/service/gpu/compile_module_to_llvm_ir.h +++ b/third_party/xla/xla/service/gpu/compile_module_to_llvm_ir.h @@ -31,6 +31,7 @@ limitations under the License. #include "xla/service/gpu/executable.pb.h" #include "xla/service/gpu/execution_stream_assignment.h" #include "xla/service/gpu/gpu_executable.h" +#include "xla/service/gpu/runtime/sequential_thunk.h" #include "xla/service/gpu/runtime/thunk.h" #include "xla/service/hlo.pb.h" #include "xla/service/hlo_dataflow_analysis.h" @@ -49,7 +50,7 @@ struct CompileModuleResults { std::unique_ptr buffer_assignment; std::unique_ptr execution_stream_assignment; std::vector allocations; - GpuExecutable::OwnedThunkSequence executable; + std::unique_ptr executable; std::vector constants; absl::flat_hash_map output_info; Shape output_shape; diff --git a/third_party/xla/xla/service/gpu/gpu_compiler.cc b/third_party/xla/xla/service/gpu/gpu_compiler.cc index 574e6c239b83d0..f342425606d158 100644 --- a/third_party/xla/xla/service/gpu/gpu_compiler.cc +++ b/third_party/xla/xla/service/gpu/gpu_compiler.cc @@ -408,8 +408,6 @@ GpuThunkAotCompilationResult::LoadExecutable( auto ir_emitter = IrEmitterUnnested::Create(&ir_emitter_context); TF_RETURN_IF_ERROR( ir_emitter->EmitHloComputation(hlo_module->entry_computation())); - std::unique_ptr thunk_sequence = - ir_emitter->ConsumeThunkSequence(); // Get all other fields required by GpuExecutable. 
std::vector constants = @@ -431,7 +429,7 @@ GpuThunkAotCompilationResult::LoadExecutable( Thunk::BinaryMap(proto_.dnn_compiled_graphs().cbegin(), proto_.dnn_compiled_graphs().cend()), /*gpu_version=*/gpu_device_info.gpu_compute_capability(), - /*executable=*/std::move(thunk_sequence), + /*executable=*/ir_emitter->ConsumeThunkSequence(), /*constants=*/std::move(constants), /*output_info=*/std::move(output_info), /*module_name=*/std::move(hlo_module->name()), @@ -2072,8 +2070,9 @@ GpuCompiler::CompileToBackendResult( } RecordXlaDeviceBinarySize(backend_result.binary.size()); if (DumpingEnabledForHloModule(*module)) { - DumpToFileInDirOrStdout(*module, "", "thunk_sequence.txt", - compile_module_results.executable->ToString()); + DumpToFileInDirOrStdout( + *module, "", "thunk_sequence.txt", + compile_module_results.executable->ToString(/*indent=*/0)); } return CompileResultWithMetadata{std::move(backend_result), @@ -2148,8 +2147,9 @@ absl::StatusOr> GpuCompiler::RunBackend( gpu_device_info)); if (DumpingEnabledForHloModule(*module)) { - DumpToFileInDirOrStdout(*module, "", "thunk_sequence.txt", - res.compile_module_results.executable->ToString()); + DumpToFileInDirOrStdout( + *module, "", "thunk_sequence.txt", + res.compile_module_results.executable->ToString(/*indent=*/0)); } // The module is being moved into the GpuExecutable below and we need to diff --git a/third_party/xla/xla/service/gpu/gpu_executable.cc b/third_party/xla/xla/service/gpu/gpu_executable.cc index 196bc02964b5bd..42f7e6f72c1701 100644 --- a/third_party/xla/xla/service/gpu/gpu_executable.cc +++ b/third_party/xla/xla/service/gpu/gpu_executable.cc @@ -123,7 +123,7 @@ static bool NeedsAsyncCommsStream(Thunk& thunk) { // `GpuExecutable`. At run time `Thunks` may use additional streams to launch // compute operations in parallel. 
static absl::flat_hash_set GetExecutionStreamIds( - const ThunkSequence& thunks) { + const SequentialThunk& thunks) { absl::flat_hash_set stream_ids; ForAllThunks( [&](const Thunk* thunk) { @@ -355,7 +355,7 @@ absl::Status RendezvousAfterInitialization( absl::Status ExecuteThunks( const DebugOptions* debug_options, const std::string& module_name, - ModuleIdentifier module_id, const ThunkSequence& thunk_sequence, + ModuleIdentifier module_id, SequentialThunk& thunk_sequence, Thunk::ExecutableSource executable_source, const ServiceExecutableRunOptions* run_options, const BufferAllocations& buffer_allocations, bool block_host_until_done, @@ -443,9 +443,8 @@ absl::Status ExecuteThunks( Thunk::PrepareParams prepare_params{&collective_params}; tsl::profiler::TraceMe trace([&] { return "Thunks::Prepare"; }); - for (const std::unique_ptr& thunk : thunk_sequence) { - TF_RETURN_IF_ERROR(thunk->Prepare(prepare_params, resource_requests)); - } + TF_RETURN_IF_ERROR( + thunk_sequence.Prepare(prepare_params, resource_requests)); } // Acquire collective cliques requested by thunks. @@ -465,9 +464,7 @@ absl::Status ExecuteThunks( run_options->run_options().ffi_execution_context()}; tsl::profiler::TraceMe trace([&] { return "Thunks::Initialize"; }); - for (const std::unique_ptr& thunk : thunk_sequence) { - TF_RETURN_IF_ERROR(thunk->Initialize(initialize_params)); - } + TF_RETURN_IF_ERROR(thunk_sequence.Initialize(initialize_params)); } // Maybe join a round of rendezvous after thunk initialization. We do this @@ -483,14 +480,8 @@ absl::Status ExecuteThunks( command_buffer_trace_stream, &collective_params, &collective_cliques, std::move(additional_execution_streams)); - for (const std::unique_ptr& thunk : thunk_sequence) { - // Annotate execution of this op if tracing was enabled when we started - // running this module. If tracing is enabled *while* we're running the - // module, we won't get any data, but that's probably an OK trade-off. 
- auto scoped_annotation = GetKernelAnnotation(thunk->profile_annotation()); - VLOG(3) << "Executing the thunk for " << thunk->profile_annotation(); - TF_RETURN_IF_ERROR(thunk->ExecuteOnStream(execute_params)); - } + TF_RETURN_IF_ERROR(thunk_sequence.ExecuteOnStream(execute_params)); + return MaybeSyncAndProfile(run_options, std::move(execution_timer), block_host_until_done ? main_stream : nullptr); } diff --git a/third_party/xla/xla/service/gpu/gpu_executable.h b/third_party/xla/xla/service/gpu/gpu_executable.h index 3a0a87be9009fe..6acdbd59ca662b 100644 --- a/third_party/xla/xla/service/gpu/gpu_executable.h +++ b/third_party/xla/xla/service/gpu/gpu_executable.h @@ -40,6 +40,7 @@ limitations under the License. #include "xla/service/gpu/buffer_allocations.h" #include "xla/service/gpu/ir_emission_utils.h" #include "xla/service/gpu/runtime/annotation.h" +#include "xla/service/gpu/runtime/sequential_thunk.h" #include "xla/service/gpu/runtime/thunk.h" #include "xla/service/hlo_execution_profile.h" #include "xla/service/hlo_module_config.h" @@ -63,8 +64,6 @@ namespace gpu { // This is an immutable data type after initialization, and thus thread safe. class GpuExecutable : public Executable { public: - using OwnedThunkSequence = std::unique_ptr; - struct ConstantInfo { std::string symbol_name; DenseDataIntermediate content; @@ -88,7 +87,7 @@ class GpuExecutable : public Executable { std::vector binary; Thunk::BinaryMap dnn_compiled_graphs; se::GpuComputeCapability gpu_version; - OwnedThunkSequence executable; + std::unique_ptr executable; std::vector constants; absl::flat_hash_map output_info; std::string module_name; @@ -234,7 +233,7 @@ class GpuExecutable : public Executable { // The thunks to be invoked by this GpuExecutable. They are generated by the // IrEmitter (null if XLA:GPU runtime is enabled). - OwnedThunkSequence thunks_; + std::unique_ptr thunks_; // Additional execution streams requested by `thunks_`. 
absl::flat_hash_set execution_stream_ids_; diff --git a/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc b/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc index 40c985752ab2ea..118948792b215a 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc +++ b/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc @@ -225,28 +225,23 @@ absl::Status IrEmitterUnnested::EmitConstant( static ConditionalThunkConfig GetConditionalThunkConfig( const HloInstruction* instr, - std::vector branch_thunk_sequences) { + std::vector> branch_thunk_sequences) { ConditionalThunkConfig config; config.branch_index_is_bool = instr->operand(0)->shape().element_type() == PRED; config.branch_count = instr->branch_count(); - config.branch_thunks.reserve(config.branch_count); - for (auto& branch_thunk_sequence : branch_thunk_sequences) { - config.branch_thunks.emplace_back( - new SequentialThunk(Thunk::ThunkInfo::WithProfileAnnotation(instr), - std::move(branch_thunk_sequence))); - } + config.branch_thunks = std::move(branch_thunk_sequences); return config; } absl::Status IrEmitterUnnested::EmitConditional(const HloInstruction* instr) { - std::vector branch_thunks; + std::vector> branch_thunks; branch_thunks.reserve(instr->branch_count()); for (auto comp : instr->branch_computations()) { auto ir_emitter = IrEmitterUnnested::Create(ir_emitter_context_); TF_RETURN_IF_ERROR(ir_emitter->EmitHloComputation(comp)); - branch_thunks.push_back(std::move(*ir_emitter->ConsumeThunkSequence())); + branch_thunks.push_back(ir_emitter->ConsumeThunkSequence()); } ConditionalThunkConfig config = @@ -569,7 +564,7 @@ absl::Status IrEmitterUnnested::EmitCommandBufferThunk( const HloComputation* command_buffer = instr->called_computations().front(); auto ir_emitter = IrEmitterUnnested::Create(ir_emitter_context_); TF_RETURN_IF_ERROR(ir_emitter->EmitHloComputation(command_buffer)); - std::unique_ptr thunk_sequence = + std::unique_ptr thunk_sequence = ir_emitter->ConsumeThunkSequence(); // Maybe 
serialize all commands in a sequence by forcing barriers between all @@ -581,12 +576,13 @@ absl::Status IrEmitterUnnested::EmitCommandBufferThunk( ? CommandBufferCmdSequence::SynchronizationMode::kAutomatic : CommandBufferCmdSequence::SynchronizationMode::kSerialize; - TF_ASSIGN_OR_RETURN(CommandBufferCmdSequence cmd_sequence, - ConvertToCommands(*thunk_sequence, synchronization_mode)); + TF_ASSIGN_OR_RETURN( + CommandBufferCmdSequence cmd_sequence, + ConvertToCommands(thunk_sequence->thunks(), synchronization_mode)); AddThunkToThunkSequence(std::make_unique( std::move(cmd_sequence), Thunk::ThunkInfo::WithProfileAnnotation(instr), - std::move(*thunk_sequence))); + std::move(thunk_sequence))); return absl::OkStatus(); } diff --git a/third_party/xla/xla/service/gpu/ir_emitter_unnested.h b/third_party/xla/xla/service/gpu/ir_emitter_unnested.h index 5774dbe96ec7a4..ac92bf265ff4b2 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter_unnested.h +++ b/third_party/xla/xla/service/gpu/ir_emitter_unnested.h @@ -44,6 +44,7 @@ limitations under the License. #include "xla/service/gpu/launch_dimensions.h" #include "xla/service/gpu/runtime/copy_thunk.h" #include "xla/service/gpu/runtime/send_recv_thunk.h" +#include "xla/service/gpu/runtime/sequential_thunk.h" #include "xla/service/gpu/runtime/thunk.h" #include "xla/service/llvm_ir/ir_array.h" #include "xla/service/llvm_ir/llvm_util.h" @@ -111,8 +112,9 @@ class IrEmitterUnnested : public IrEmitter { IrEmitterContext* ir_emitter_context); // Transfers the ownship of thunk_sequence_ out. - std::unique_ptr ConsumeThunkSequence() { - return std::make_unique(std::move(thunk_sequence_)); + std::unique_ptr ConsumeThunkSequence() { + return std::make_unique(Thunk::ThunkInfo{}, + std::move(thunk_sequence_)); } // Emits code for the given HLO computation. 
diff --git a/third_party/xla/xla/service/gpu/runtime/BUILD b/third_party/xla/xla/service/gpu/runtime/BUILD index 7ca1b282dc0cbb..5af4d5703fee7f 100644 --- a/third_party/xla/xla/service/gpu/runtime/BUILD +++ b/third_party/xla/xla/service/gpu/runtime/BUILD @@ -422,6 +422,7 @@ cc_library( ":thunk", "//xla/service:buffer_assignment", # build_cleaner: keep "//xla/service/gpu:buffer_allocations", # build_cleaner: keep + "//xla/service/gpu/runtime:sequential_thunk", # build_cleaner: keep "//xla/stream_executor", "//xla/stream_executor:command_buffer", "@com_google_absl//absl/base:core_headers", diff --git a/third_party/xla/xla/service/gpu/runtime/command_buffer_thunk.cc b/third_party/xla/xla/service/gpu/runtime/command_buffer_thunk.cc index b551766f43c1a9..a37913a57e352a 100644 --- a/third_party/xla/xla/service/gpu/runtime/command_buffer_thunk.cc +++ b/third_party/xla/xla/service/gpu/runtime/command_buffer_thunk.cc @@ -56,7 +56,7 @@ CommandBufferThunk::ExecutorCommandBuffer::ExecutorCommandBuffer( CommandBufferThunk::CommandBufferThunk(CommandBufferCmdSequence commands, ThunkInfo thunk_info, - std::optional thunks) + std::unique_ptr thunks) : Thunk(Thunk::kCommandBuffer, std::move(thunk_info)), commands_(std::move(commands)), thunks_(std::move(thunks)), @@ -116,10 +116,8 @@ absl::Status CommandBufferThunk::Prepare(const PrepareParams& params, // Always prepare thunks if they are present so we are ready to fall back // on them if we detect profiling activity. - if (thunks_.has_value()) { - for (auto& thunk : *thunks_) { - TF_RETURN_IF_ERROR(thunk->Prepare(params, resource_requests)); - } + if (thunks_) { + TF_RETURN_IF_ERROR(thunks_->Prepare(params, resource_requests)); } return absl::OkStatus(); @@ -139,10 +137,8 @@ absl::Status CommandBufferThunk::Initialize(const InitializeParams& params) { // Always initialize thunks if they are present so we are ready to fall back // on them if we detect profiling activity. 
- if (thunks_.has_value()) { - for (auto& thunk : *thunks_) { - TF_RETURN_IF_ERROR(thunk->Initialize(params)); - } + if (thunks_) { + TF_RETURN_IF_ERROR(thunks_->Initialize(params)); } // Construct ExecuteParams with empty fields for everything that is not needed @@ -203,13 +199,10 @@ absl::Status CommandBufferThunk::ExecuteOnStream(const ExecuteParams& params) { // TODO(b/290773547): Profiler (CUPTI) + CUDA graphs lead to memory // corruption. As a work around disable command buffers (CUDA graphs) and run // everything in op-by-op mode. - if (tsl::profiler::ProfilerLock::HasActiveSession() && thunks_.has_value()) { + if (tsl::profiler::ProfilerLock::HasActiveSession() && thunks_) { VLOG(1) << "Execute command buffer thunk as a regular thunk sequence " "because we detected active profiling session"; - for (auto& thunk : *thunks_) { - auto scoped_annotation = GetKernelAnnotation(thunk->profile_annotation()); - TF_RETURN_IF_ERROR(thunk->ExecuteOnStream(params)); - } + TF_RETURN_IF_ERROR(thunks_->ExecuteOnStream(params)); return absl::OkStatus(); } diff --git a/third_party/xla/xla/service/gpu/runtime/command_buffer_thunk.h b/third_party/xla/xla/service/gpu/runtime/command_buffer_thunk.h index 3c05aada9607e7..a3cb4672c951b3 100644 --- a/third_party/xla/xla/service/gpu/runtime/command_buffer_thunk.h +++ b/third_party/xla/xla/service/gpu/runtime/command_buffer_thunk.h @@ -27,6 +27,7 @@ limitations under the License. 
#include "absl/status/statusor.h" #include "absl/synchronization/mutex.h" #include "xla/service/gpu/runtime/command_buffer_cmd.h" +#include "xla/service/gpu/runtime/sequential_thunk.h" #include "xla/service/gpu/runtime/thunk.h" #include "xla/stream_executor/command_buffer.h" #include "xla/stream_executor/device_memory.h" @@ -37,9 +38,9 @@ namespace xla::gpu { class CommandBufferThunk : public Thunk { public: CommandBufferThunk(CommandBufferCmdSequence commands, ThunkInfo thunk_info, - std::optional thunks = std::nullopt); + std::unique_ptr thunks = nullptr); - const std::optional& thunks() const { return thunks_; } + const std::unique_ptr& thunks() const { return thunks_; } absl::Status Prepare(const PrepareParams& params, ResourceRequests& resource_requests) override; @@ -125,7 +126,7 @@ class CommandBufferThunk : public Thunk { // Thunk sequence that executes the same commands as in `commands_` but using // thunk mechanism. We use it as a fallback mechanism to work around CUPTI // bugs that lead to memory corruption when CUPTI traces CUDA graph execution. - std::optional thunks_; + std::unique_ptr thunks_; // Command buffer thunk state allocated in heap to allow global (per-process) // management of instantiated command buffers. 
diff --git a/third_party/xla/xla/service/gpu/runtime/for_all_thunks.cc b/third_party/xla/xla/service/gpu/runtime/for_all_thunks.cc index d10d82d31581ef..0ccf5d09bd6f3b 100644 --- a/third_party/xla/xla/service/gpu/runtime/for_all_thunks.cc +++ b/third_party/xla/xla/service/gpu/runtime/for_all_thunks.cc @@ -40,10 +40,10 @@ void ForAllThunks(absl::FunctionRef fn, ->embedded_thunk()); break; case Thunk::kCommandBuffer: - if (const std::optional& sequence = + if (const std::unique_ptr& sequence = tensorflow::down_cast(thunk)->thunks(); - sequence.has_value()) { - ForAllThunks(fn, &sequence.value()); + sequence != nullptr) { + ForAllThunks(fn, sequence.get()); } break; case Thunk::kConditional: diff --git a/third_party/xla/xla/service/gpu/runtime/for_all_thunks_test.cc b/third_party/xla/xla/service/gpu/runtime/for_all_thunks_test.cc index fceb4168396a14..6220e55fbcc4a9 100644 --- a/third_party/xla/xla/service/gpu/runtime/for_all_thunks_test.cc +++ b/third_party/xla/xla/service/gpu/runtime/for_all_thunks_test.cc @@ -81,11 +81,16 @@ TEST(ForAllThunksTest, CommandBufferThunk) { ThunkSequence thunk_sequence; thunk_sequence.push_back(std::move(thunk)); + auto sequential_thunk = std::make_unique( + Thunk::ThunkInfo(), std::move(thunk_sequence)); + Thunk* sequential_thunk_ptr = sequential_thunk.get(); + CommandBufferThunk command_buffer_thunk(CommandBufferCmdSequence(), Thunk::ThunkInfo(), - std::move(thunk_sequence)); + std::move(sequential_thunk)); EXPECT_THAT(GetAllThunks(&command_buffer_thunk), - UnorderedElementsAre(thunk_ptr, &command_buffer_thunk)); + UnorderedElementsAre(thunk_ptr, &command_buffer_thunk, + sequential_thunk_ptr)); } TEST(ForAllThunksTest, ConditionalThunk) { @@ -113,18 +118,21 @@ TEST(ForAllThunksTest, WhileThunk) { auto condition_thunk = std::make_unique(); Thunk* condition_thunk_ptr = condition_thunk.get(); - auto condition_thunk_sequence = std::make_unique(); - condition_thunk_sequence->push_back(std::move(condition_thunk)); + ThunkSequence 
condition_thunk_sequence; + condition_thunk_sequence.push_back(std::move(condition_thunk)); auto body_thunk = std::make_unique(); Thunk* body_thunk_ptr = body_thunk.get(); - auto body_thunk_sequence = std::make_unique(); - body_thunk_sequence->push_back(std::move(body_thunk)); + ThunkSequence body_thunk_sequence; + body_thunk_sequence.push_back(std::move(body_thunk)); - WhileThunk while_thunk(Thunk::ThunkInfo(), BufferAllocation::Slice(), - std::move(condition_thunk_sequence), - std::move(body_thunk_sequence)); + WhileThunk while_thunk( + Thunk::ThunkInfo(), BufferAllocation::Slice(), + std::make_unique(Thunk::ThunkInfo(), + std::move(condition_thunk_sequence)), + std::make_unique(Thunk::ThunkInfo(), + std::move(body_thunk_sequence))); EXPECT_THAT(GetAllThunks(&while_thunk), // `WhileThunk` wraps the `condition_thunk_sequence` and diff --git a/third_party/xla/xla/service/gpu/runtime/kernel_thunk.cc b/third_party/xla/xla/service/gpu/runtime/kernel_thunk.cc index a0343001ac35b7..ad15ee2eb5a6f7 100644 --- a/third_party/xla/xla/service/gpu/runtime/kernel_thunk.cc +++ b/third_party/xla/xla/service/gpu/runtime/kernel_thunk.cc @@ -70,7 +70,7 @@ KernelThunk::KernelThunk(const HloInstruction* instr, std::string kernel_name, } } -std::string KernelThunk::ToStringExtra(int indent) const { +std::string KernelThunk::ToString(int indent) const { return absl::StrFormat( ", kernel = %s, launch dimensions = %s, cluster_dim = %s", kernel_name_, launch_dimensions_.ToString(), @@ -178,7 +178,7 @@ CustomKernelThunk::CustomKernelThunk( } } -std::string CustomKernelThunk::ToStringExtra(int indent) const { +std::string CustomKernelThunk::ToString(int indent) const { return custom_kernel_.ToString(); } diff --git a/third_party/xla/xla/service/gpu/runtime/kernel_thunk.h b/third_party/xla/xla/service/gpu/runtime/kernel_thunk.h index ae5e0da2ea393e..d26e5cab3a182f 100644 --- a/third_party/xla/xla/service/gpu/runtime/kernel_thunk.h +++ b/third_party/xla/xla/service/gpu/runtime/kernel_thunk.h 
@@ -77,7 +77,7 @@ class KernelThunk : public Thunk { KernelThunk& operator=(const KernelThunk&) = delete; ~KernelThunk() override = default; - std::string ToStringExtra(int indent) const override; + std::string ToString(int indent) const override; absl::Status Initialize(const InitializeParams& params) override; absl::Status ExecuteOnStream(const ExecuteParams& params) override; @@ -130,7 +130,7 @@ class CustomKernelThunk : public Thunk { CustomKernelThunk(const HloInstruction* inst, CustomKernel custom_kernel, absl::Span kernel_arguments); - std::string ToStringExtra(int indent) const override; + std::string ToString(int indent) const override; absl::Status Initialize(const InitializeParams& params) override; absl::Status ExecuteOnStream(const ExecuteParams& params) override; diff --git a/third_party/xla/xla/service/gpu/runtime/sequential_thunk.cc b/third_party/xla/xla/service/gpu/runtime/sequential_thunk.cc index d874c25ea560f8..6033ddc052748c 100644 --- a/third_party/xla/xla/service/gpu/runtime/sequential_thunk.cc +++ b/third_party/xla/xla/service/gpu/runtime/sequential_thunk.cc @@ -31,9 +31,28 @@ namespace gpu { SequentialThunk::SequentialThunk(ThunkInfo thunk_info, ThunkSequence thunks) : Thunk(Kind::kSequential, thunk_info), thunks_(std::move(thunks)) {} -std::string SequentialThunk::ToStringExtra(int indent) const { - std::string result = "\n"; - absl::StrAppend(&result, thunks().ToString(indent + 1)); +std::string SequentialThunk::ToString(int indent) const { + const std::string indent_str(indent * 2, ' '); + if (thunks_.empty()) return indent_str + "No thunks."; + + auto thunk_with_longest_kind = absl::c_max_element( + thunks_, + [](const std::unique_ptr& a, const std::unique_ptr& b) { + return Thunk::KindToString(a->kind()).length() < + Thunk::KindToString(b->kind()).length(); + }); + int64_t max_thunk_kind_len = + Thunk::KindToString(thunk_with_longest_kind->get()->kind()).length(); + std::string result; + for (const std::unique_ptr& thunk : thunks_) { + 
// Write out the thunk kind, padded out to max_thunk_kind_len. + absl::string_view kind_str = Thunk::KindToString(thunk->kind()); + absl::StrAppend(&result, indent_str, kind_str, + std::string(max_thunk_kind_len - kind_str.length(), ' '), + "\t"); + absl::StrAppend(&result, thunk->ToString(indent + 1)); + absl::StrAppend(&result, "\n"); + } return result; } @@ -53,8 +72,9 @@ absl::Status SequentialThunk::Initialize(const InitializeParams& params) { } absl::Status SequentialThunk::ExecuteOnStream(const ExecuteParams& params) { - for (const auto& thunk : thunks_) { - auto annotation = GetKernelAnnotation(thunk->profile_annotation()); + for (const std::unique_ptr& thunk : thunks_) { + std::optional annotation = + GetKernelAnnotation(thunk->profile_annotation()); TF_RETURN_IF_ERROR(thunk->ExecuteOnStream(params)); } return absl::OkStatus(); diff --git a/third_party/xla/xla/service/gpu/runtime/sequential_thunk.h b/third_party/xla/xla/service/gpu/runtime/sequential_thunk.h index 4642f08becadb5..d754d42f394865 100644 --- a/third_party/xla/xla/service/gpu/runtime/sequential_thunk.h +++ b/third_party/xla/xla/service/gpu/runtime/sequential_thunk.h @@ -35,7 +35,7 @@ class SequentialThunk : public Thunk { ThunkSequence& thunks() { return thunks_; } const ThunkSequence& thunks() const { return thunks_; } - std::string ToStringExtra(int indent) const override; + std::string ToString(int indent) const override; absl::Status Prepare(const PrepareParams& params, ResourceRequests& resource_requests) override; diff --git a/third_party/xla/xla/service/gpu/runtime/thunk.cc b/third_party/xla/xla/service/gpu/runtime/thunk.cc index 877db05fe04e89..759df65bc28aa6 100644 --- a/third_party/xla/xla/service/gpu/runtime/thunk.cc +++ b/third_party/xla/xla/service/gpu/runtime/thunk.cc @@ -303,31 +303,6 @@ std::ostream& operator<<(std::ostream& os, Thunk::Kind kind) { return os << Thunk::KindToString(kind); } -std::string ThunkSequence::ToString(int indent) const { - const std::string 
indent_str(indent * 2, ' '); - if (empty()) return indent_str + "No thunks."; - - auto thunk_with_longest_kind = absl::c_max_element( - *this, - [](const std::unique_ptr& a, const std::unique_ptr& b) { - return Thunk::KindToString(a->kind()).length() < - Thunk::KindToString(b->kind()).length(); - }); - int64_t max_thunk_kind_len = - Thunk::KindToString(thunk_with_longest_kind->get()->kind()).length(); - std::string result; - for (const std::unique_ptr& thunk : *this) { - // Write out the thunk kind, padded out to max_thunk_kind_len. - absl::string_view kind_str = Thunk::KindToString(thunk->kind()); - absl::StrAppend(&result, indent_str, kind_str, - std::string(max_thunk_kind_len - kind_str.length(), ' '), - "\t"); - absl::StrAppend(&result, thunk->ToStringExtra(indent)); - absl::StrAppend(&result, "\n"); - } - return result; -} - bool IsReductionCollective(Thunk::Kind kind) { return kind == Thunk::kNcclAllReduce || kind == Thunk::kNcclAllReduceStart || kind == Thunk::kNcclReduceScatter || diff --git a/third_party/xla/xla/service/gpu/runtime/thunk.h b/third_party/xla/xla/service/gpu/runtime/thunk.h index 78698859bf8553..483c817b7add93 100644 --- a/third_party/xla/xla/service/gpu/runtime/thunk.h +++ b/third_party/xla/xla/service/gpu/runtime/thunk.h @@ -410,7 +410,7 @@ class Thunk { Thunk(const Thunk&) = delete; Thunk& operator=(const Thunk&) = delete; - virtual std::string ToStringExtra(int indent) const { return ""; } + virtual std::string ToString(int indent) const { return ""; } Kind kind() const { return kind_; } std::string_view profile_annotation() const { return profile_annotation_; } @@ -458,10 +458,7 @@ class Thunk { }; // A sequence of thunks. 
-class ThunkSequence : public std::vector> { - public: - std::string ToString(int indent = 0) const; -}; +using ThunkSequence = std::vector>; std::ostream& operator<<(std::ostream& os, Thunk::Kind kind); diff --git a/third_party/xla/xla/service/gpu/runtime/while_thunk.cc b/third_party/xla/xla/service/gpu/runtime/while_thunk.cc index b5473b230bfb7e..d3ad896b10793b 100644 --- a/third_party/xla/xla/service/gpu/runtime/while_thunk.cc +++ b/third_party/xla/xla/service/gpu/runtime/while_thunk.cc @@ -59,15 +59,13 @@ absl::StatusOr WhileThunk::CurrentLoopIteration(int64_t depth) { WhileThunk::WhileThunk( ThunkInfo thunk_info, const BufferAllocation::Slice& condition_result_buffer_index, - std::unique_ptr condition_thunk_sequence, - std::unique_ptr body_thunk_sequence, + std::unique_ptr condition_thunk_sequence, + std::unique_ptr body_thunk_sequence, std::optional trip_count) : Thunk(Kind::kWhile, thunk_info), condition_result_buffer_index_(condition_result_buffer_index), - condition_thunk_sequence_(std::make_unique( - ThunkInfo(), std::move(*condition_thunk_sequence))), - body_thunk_sequence_(std::make_unique( - ThunkInfo(), std::move(*body_thunk_sequence))), + condition_thunk_sequence_(std::move(condition_thunk_sequence)), + body_thunk_sequence_(std::move(body_thunk_sequence)), trip_count_(trip_count) {} absl::Status WhileThunk::Prepare(const PrepareParams& params, diff --git a/third_party/xla/xla/service/gpu/runtime/while_thunk.h b/third_party/xla/xla/service/gpu/runtime/while_thunk.h index 5cf46876dc9e2b..3ab0069c1a897f 100644 --- a/third_party/xla/xla/service/gpu/runtime/while_thunk.h +++ b/third_party/xla/xla/service/gpu/runtime/while_thunk.h @@ -53,8 +53,8 @@ class WhileThunk : public Thunk { // Constructs a WhileThunk to compute while instruction 'hlo'. 
WhileThunk(ThunkInfo thunk_info, const BufferAllocation::Slice& condition_result_buffer_index, - std::unique_ptr condition_thunk_sequence, - std::unique_ptr body_thunk_sequence, + std::unique_ptr condition_thunk_sequence, + std::unique_ptr body_thunk_sequence, std::optional trip_count = std::nullopt); WhileThunk(const WhileThunk&) = delete; WhileThunk& operator=(const WhileThunk&) = delete; From 546829c2657218df365cfe244c049ed24a879b3f Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 20 Jun 2024 09:30:17 -0700 Subject: [PATCH 062/256] Call ShapeUtil::ByteSizeOfElements instead of a copy of the function modified for auto-sharding. Reusing the original function does not seem to trigger crashes as comments in the code suggest. PiperOrigin-RevId: 645060360 --- .../auto_sharding/auto_sharding_util.h | 19 +------------------ 1 file changed, 1 insertion(+), 18 deletions(-) diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.h b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.h index e86dce3966fca4..c8fd040c7a2107 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.h +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.h @@ -155,27 +155,10 @@ std::string ToString(absl::Span span) { return absl::StrCat("[", absl::StrJoin(span, ", "), "]"); } -// Shape Utility - -// Get the bytes of an array shape without checking its layout. -// This is modified from ShapeUtil::ByteSizeOfElements (shape_util.cc). -inline int64_t ByteSizeOfElementsNoCheck(const Shape& shape) { - TF_DCHECK_OK(ShapeUtil::ValidateShape(shape)); - CHECK(shape.IsArray()); - int64_t allocated_element_count; - - // Disable this check. Otherwise, it raises a fatal error on HloOpcode::kIota - // generated by jax dropout. 
- // CHECK(LayoutUtil::IsDenseArray(shape)) << shape.ShortDebugString(); - allocated_element_count = ShapeUtil::ElementsIn(shape); - return allocated_element_count * - ShapeUtil::ByteSizeOfPrimitiveType(shape.element_type()); -} - // Get the number of bytes of a shape. inline double GetBytes(const Shape& shape) { if (shape.IsArray()) { - return ByteSizeOfElementsNoCheck(shape); + return ShapeUtil::ByteSizeOfElements(shape); } return ShapeUtil::ByteSizeOf(shape, /*pointer_size=*/8); } From af52806c7efb78054cd7d465c22ba53ce6fe3be8 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Thu, 20 Jun 2024 09:52:04 -0700 Subject: [PATCH 063/256] [xla:cpu] Don't run collective test with thunks, not all thunks are ready PiperOrigin-RevId: 645066425 --- third_party/xla/xla/tests/BUILD | 1 - 1 file changed, 1 deletion(-) diff --git a/third_party/xla/xla/tests/BUILD b/third_party/xla/xla/tests/BUILD index 7250e59f99f2a2..f66a6c6099d7f9 100644 --- a/third_party/xla/xla/tests/BUILD +++ b/third_party/xla/xla/tests/BUILD @@ -2167,7 +2167,6 @@ xla_test( "gpu", "cpu", ], - tags = ["test_xla_cpu_thunks"], deps = [ ":hlo_test_base", ":literal_test_util", From a0d4b376ceb97e8c904854eac33eb87a5b7ad0aa Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 20 Jun 2024 10:02:00 -0700 Subject: [PATCH 064/256] Migrate usage of schema_conversion_utils. 
PiperOrigin-RevId: 645069638 --- tensorflow/compiler/mlir/lite/BUILD | 2 +- tensorflow/compiler/mlir/lite/flatbuffer_export.cc | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD index 46ca0c761e2804..32a5bc738faa1a 100644 --- a/tensorflow/compiler/mlir/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/BUILD @@ -1091,6 +1091,7 @@ cc_library( "//tensorflow/compiler/mlir:op_or_arg_name_mapper", "//tensorflow/compiler/mlir/lite/metrics:error_collector_inst", "//tensorflow/compiler/mlir/lite/quantization/ir:QuantOps", + "//tensorflow/compiler/mlir/lite/schema:schema_conversion_utils", "//tensorflow/compiler/mlir/lite/schema:schema_fbs_with_mutable", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:convert_tensor", @@ -1106,7 +1107,6 @@ cc_library( "//tensorflow/lite/delegates/flex:allowlisted_flex_ops_lib", "//tensorflow/lite/experimental/remat:metadata_util", "//tensorflow/lite/python/metrics:converter_error_data_proto_cc", - "//tensorflow/lite/schema:schema_conversion_utils", "//tensorflow/lite/toco:toco_flags_proto_cc", "//tensorflow/lite/tools/versioning", "//tensorflow/lite/tools/versioning:gpu_compatibility", diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc index 7449a037516a6c..0c6c56cf001a0c 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc @@ -86,6 +86,7 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/lite/metrics/error_collector_inst.h" #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" #include "tensorflow/compiler/mlir/lite/schema/mutable/schema_generated.h" +#include "tensorflow/compiler/mlir/lite/schema/schema_conversion_utils.h" #include "tensorflow/compiler/mlir/lite/utils/convert_type.h" #include "tensorflow/compiler/mlir/lite/utils/low_bit_utils.h" #include "tensorflow/compiler/mlir/lite/utils/stateful_ops_utils.h" @@ -111,7 +112,6 @@ limitations under the License. #include "tensorflow/lite/delegates/flex/allowlisted_flex_ops.h" #include "tensorflow/lite/experimental/remat/metadata_util.h" #include "tensorflow/lite/python/metrics/converter_error_data.pb.h" -#include "tensorflow/lite/schema/schema_conversion_utils.h" #include "tensorflow/lite/toco/toco_flags.pb.h" #include "tensorflow/lite/tools/versioning/gpu_compatibility.h" #include "tensorflow/lite/tools/versioning/op_version.h" From 4fafa4eb9c4a58dbca0dd78e09f54af0f2fed4ed Mon Sep 17 00:00:00 2001 From: Kyle Lucke Date: Thu, 20 Jun 2024 10:14:40 -0700 Subject: [PATCH 065/256] Stop using xla/statusor.h now that it just contains an alias for absl::Status. In some situations, this meant also changing unrelated files to directly include tsl/platform/statusor.h to get the definitions for TF_ASSIGN_OR_RETURN, etc., where they were getting transitively included for free. 
PiperOrigin-RevId: 645074582 --- third_party/xla/xla/hlo/experimental/auto_sharding/BUILD | 7 +++---- .../xla/hlo/experimental/auto_sharding/auto_sharding.cc | 2 +- .../experimental/auto_sharding/auto_sharding_strategy.cc | 2 +- .../hlo/experimental/auto_sharding/auto_sharding_test.cc | 2 +- 4 files changed, 6 insertions(+), 7 deletions(-) diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/BUILD b/third_party/xla/xla/hlo/experimental/auto_sharding/BUILD index a8a4b7e27fb09c..44ebc599c7bf19 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/BUILD +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/BUILD @@ -44,7 +44,6 @@ cc_library( "//xla:array", "//xla:shape_tree", "//xla:shape_util", - "//xla:statusor", "//xla:util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", @@ -73,6 +72,7 @@ cc_library( "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/time", @@ -110,7 +110,6 @@ cc_library( ":auto_sharding_proto_cc", ":auto_sharding_strategy", "//xla:status_macros", - "//xla:statusor", "//xla:util", "@com_google_absl//absl/container:btree", "@com_google_absl//absl/container:flat_hash_map", @@ -118,6 +117,7 @@ cc_library( "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/time", "@com_google_ortools//ortools/linear_solver", @@ -264,7 +264,6 @@ cc_library( "//xla:array", "//xla:shape_tree", "//xla:shape_util", - "//xla:statusor", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/ir:ptrvec", @@ -336,7 +335,6 @@ xla_cc_test( ":auto_sharding_option", ":auto_sharding_strategy", ":auto_sharding_util", - "//xla:statusor", "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_live_range", 
"//xla/hlo/utils:hlo_matchers", @@ -351,6 +349,7 @@ xla_cc_test( "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest", "@local_tsl//tsl/lib/core:status_test_util", diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc index b2ec822336e1cc..3b81bc79578a2e 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc @@ -39,6 +39,7 @@ limitations under the License. #include "absl/log/check.h" #include "absl/log/log.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" @@ -85,7 +86,6 @@ limitations under the License. #include "xla/shape.h" #include "xla/shape_tree.h" #include "xla/shape_util.h" -#include "xla/statusor.h" #include "xla/util.h" #include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc index 2fc47e3c206dec..1cc88b3da7fd00 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc @@ -32,6 +32,7 @@ limitations under the License. #include "absl/log/check.h" #include "absl/log/log.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" @@ -52,7 +53,6 @@ limitations under the License. 
#include "xla/service/hlo_cost_analysis.h" #include "xla/service/sharding_propagation.h" #include "xla/shape.h" -#include "xla/statusor.h" #include "xla/util.h" #include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_test.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_test.cc index c7fd9bdc42dc09..4991a67d064ab7 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_test.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_test.cc @@ -27,6 +27,7 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/log/check.h" #include "absl/log/log.h" +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_option.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h" @@ -44,7 +45,6 @@ limitations under the License. #include "xla/service/hlo_memory_scheduler.h" #include "xla/service/hlo_parser.h" #include "xla/service/hlo_value.h" -#include "xla/statusor.h" #include "xla/tests/hlo_test_base.h" #include "tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" From 70d2f87bc0a7908543e81e91981a160ab1b2be72 Mon Sep 17 00:00:00 2001 From: Sandeep Dasgupta Date: Thu, 20 Jun 2024 10:29:29 -0700 Subject: [PATCH 066/256] Support quantized per-tensor type for MHLO Ceil/Floor Ops. 
PiperOrigin-RevId: 645079909 --- third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_base.td | 3 +++ third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops.td | 4 ++-- .../Dialect/mhlo/hlo-legalize-to-stablehlo.mlir | 14 ++++++++++++++ .../tests/Dialect/mhlo/mhlo_quantized.mlir | 16 +++++++++++++++- .../xla/xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir | 2 +- .../Dialect/mhlo/stablehlo-legalize-to-hlo.mlir | 14 ++++++++++++++ 6 files changed, 49 insertions(+), 4 deletions(-) diff --git a/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_base.td b/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_base.td index 992aca563edf97..cfab8f3857ebc8 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_base.td +++ b/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_base.td @@ -105,6 +105,9 @@ defvar MHLO_PredOrIntTensor = HLO_PredOrIntTensor; // Any floating-point or complex tensor types defvar MHLO_FpOrComplexTensor = HLO_FpOrComplexTensor; +// Any floating-point or quantized tensor types +defvar MHLO_FpOrQuantizedIntTensor = HLO_FpOrQuantizedIntTensor; + // Any floating-point, complex or quantized tensor types defvar MHLO_FpComplexOrQuantizedIntTensor = HLO_FpComplexOrQuantizedIntTensor; diff --git a/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops.td b/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops.td index 2dd1b7a33d3129..35eff2edbb202f 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops.td +++ b/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops.td @@ -230,7 +230,7 @@ def MHLO_CbrtOp: MHLO_UnaryElementwiseOp<"cbrt", }]; } def MHLO_CeilOp: MHLO_UnaryElementwiseOp<"ceil", - [Pure, HLO_CompatibleOperandsAndResultType], MHLO_FpTensor> { + [Pure, HLO_CompatibleOperandsAndResultType], MHLO_FpOrQuantizedIntTensor> { let summary = "Ceil operation"; let description = [{ Performs element-wise ceil of `operand` tensor and produces a `result` tensor. 
@@ -357,7 +357,7 @@ def MHLO_Expm1Op: MHLO_UnaryElementwiseOp<"exponential_minus_one", }]; } def MHLO_FloorOp: MHLO_UnaryElementwiseOp<"floor", - [Pure, HLO_CompatibleOperandsAndResultType], MHLO_FpTensor> { + [Pure, HLO_CompatibleOperandsAndResultType], MHLO_FpOrQuantizedIntTensor> { let summary = "Floor operation"; let description = [{ Performs element-wise floor of `operand` tensor and produces a `result` diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo.mlir index cd82c6a66fcbd0..5e7b46629b3491 100644 --- a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo.mlir +++ b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo.mlir @@ -533,6 +533,13 @@ func.func @op_ceil(%arg0: tensor) -> tensor { func.return %0 : tensor } +// CHECK-LABEL: "quantized_op_ceil" +func.func @quantized_op_ceil(%arg0: tensor>) -> tensor> { + // CHECK: "stablehlo.ceil"([[ARG0:%arg[0-9]+]]) : (tensor>) -> tensor> + %0 = "mhlo.ceil"(%arg0) : (tensor>) -> tensor> + func.return %0 : tensor> +} + // CHECK-LABEL: "op_cholesky" func.func @op_cholesky(%arg0: tensor<1x16x16xf32>) -> tensor<1x16x16xf32> { // CHECK: "stablehlo.cholesky"([[ARG0:%arg[0-9]+]]) <{ @@ -956,6 +963,13 @@ func.func @op_floor(%arg0: tensor) -> tensor { func.return %0 : tensor } +// CHECK-LABEL: "quantized_op_floor" +func.func @quantized_op_floor(%arg0: tensor>) -> tensor> { + // CHECK: "stablehlo.floor"([[ARG0:%arg[0-9]+]]) : (tensor>) -> tensor> + %0 = "mhlo.floor"(%arg0) : (tensor>) -> tensor> + func.return %0 : tensor> +} + // FusionOp aka mhlo.fusion is unsupported at the moment (see negative test below). 
// CHECK-LABEL: "op_gather" diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/mhlo_quantized.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/mhlo_quantized.mlir index 6e4171d44d6cbf..64025b79d74363 100644 --- a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/mhlo_quantized.mlir +++ b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/mhlo_quantized.mlir @@ -22,4 +22,18 @@ func.func @uniform_quantized_c1(%arg0: tensor<2x!quant.uniform>) { // expected-error@+1 {{Expressed type of result expected to be 'f32', but got 'f64'}} %0 = "mhlo.uniform_quantize"(%arg0) : (tensor<2x!quant.uniform>) -> tensor<2x!quant.uniform> func.return -} \ No newline at end of file +} + +// ----- + +func.func @quantized_ceil_valid(%arg0: tensor<2x!quant.uniform>) { + %0 = mhlo.ceil %arg0 : tensor<2x!quant.uniform> + func.return +} + +// ----- + +func.func @quantized_floor_valid(%arg0: tensor<2x!quant.uniform>) { + %0 = mhlo.floor %arg0 : tensor<2x!quant.uniform> + func.return +} diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir index 09b25b5a6539f4..839d66f5ea42f6 100644 --- a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir +++ b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir @@ -2716,7 +2716,7 @@ func.func @or_invalid_f32_type(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> te // ----- func.func @floor_invalid_i32_type(%arg0: tensor<4xi32>) -> tensor<4xi32> { - // expected-error@+1 {{op operand #0 must be ranked tensor of f8E4M3B11FNUZ type or f8E4M3FN type or f8E4M3FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type values, but got 'tensor<4xi32>'}} + // expected-error@+1 {{op operand #0 must be ranked tensor of f8E4M3B11FNUZ type or f8E4M3FN type or f8E4M3FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or 4/8/16/32-bit uniform quantized signed integer or 
4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<4xi32>'}} %0 = "mhlo.floor"(%arg0) : (tensor<4xi32>) -> tensor<4xi32> func.return %0 : tensor<4xi32> } diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/stablehlo-legalize-to-hlo.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/stablehlo-legalize-to-hlo.mlir index 9a55189a5b79c8..073b8b8bf77fce 100644 --- a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/stablehlo-legalize-to-hlo.mlir +++ b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/stablehlo-legalize-to-hlo.mlir @@ -520,6 +520,13 @@ func.func @op_ceil(%arg0: tensor) -> tensor { func.return %0 : tensor } +// CHECK-LABEL: "quantized_op_ceil" +func.func @quantized_op_ceil(%arg0: tensor>) -> tensor> { + // CHECK: "mhlo.ceil"([[ARG0:%arg[0-9]+]]) : (tensor>) -> tensor> + %0 = "stablehlo.ceil"(%arg0) : (tensor>) -> tensor> + func.return %0 : tensor> +} + // CHECK-LABEL: "op_cholesky" func.func @op_cholesky(%arg0: tensor<1x16x16xf32>) -> tensor<1x16x16xf32> { // CHECK: "mhlo.cholesky"([[ARG0:%arg[0-9]+]]) <{ @@ -941,6 +948,13 @@ func.func @op_floor(%arg0: tensor) -> tensor { func.return %0 : tensor } +// CHECK-LABEL: "quantized_op_floor" +func.func @quantized_op_floor(%arg0: tensor>) -> tensor> { + // CHECK: "mhlo.floor"([[ARG0:%arg[0-9]+]]) : (tensor>) -> tensor> + %0 = "stablehlo.floor"(%arg0) : (tensor>) -> tensor> + func.return %0 : tensor> +} + // CHECK-LABEL: "op_gather" func.func @op_gather(%arg0: tensor<2x3x4x2xi32>, %arg1: tensor<2x2x3x2xi64>) -> tensor<2x2x3x2x2xi32> { // CHECK: "mhlo.gather"([[ARG0:%arg[0-9]+]], [[ARG1:%arg[0-9]+]]) <{ From 2e76b08127ab52d95b18ae3ac2396b9bccae9f61 Mon Sep 17 00:00:00 2001 From: Hyeontaek Lim Date: Thu, 20 Jun 2024 10:31:02 -0700 Subject: [PATCH 067/256] [IFRT] Add AttributeMap `xla::ifrt::AttributeMap` is a reusable class that wraps around `absl::flat_hash_map`. It provides protobuf serialization and a human-readable debug string method. 
It will be used for various "attributes" methods (`Client`, `Device`, `Topology`). PiperOrigin-RevId: 645080443 --- third_party/xla/xla/python/ifrt/BUILD | 30 ++++ .../xla/xla/python/ifrt/attribute_map.cc | 137 ++++++++++++++++++ .../xla/xla/python/ifrt/attribute_map.h | 103 +++++++++++++ .../xla/xla/python/ifrt/attribute_map.proto | 35 +++++ .../xla/xla/python/ifrt/attribute_map_test.cc | 65 +++++++++ 5 files changed, 370 insertions(+) create mode 100644 third_party/xla/xla/python/ifrt/attribute_map.cc create mode 100644 third_party/xla/xla/python/ifrt/attribute_map.h create mode 100644 third_party/xla/xla/python/ifrt/attribute_map.proto create mode 100644 third_party/xla/xla/python/ifrt/attribute_map_test.cc diff --git a/third_party/xla/xla/python/ifrt/BUILD b/third_party/xla/xla/python/ifrt/BUILD index 86b06fc97fffd6..a2373d224ba2d7 100644 --- a/third_party/xla/xla/python/ifrt/BUILD +++ b/third_party/xla/xla/python/ifrt/BUILD @@ -137,6 +137,31 @@ xla_cc_test( ], ) +cc_library( + name = "attribute_map", + srcs = ["attribute_map.cc"], + hdrs = ["attribute_map.h"], + compatible_with = get_compatible_with_portable(), + deps = [ + ":attribute_map_proto_cc", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + ], +) + +xla_cc_test( + name = "attribute_map_test", + size = "small", + srcs = ["attribute_map_test.cc"], + deps = [ + ":attribute_map", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:statusor", + ], +) + xla_cc_test( name = "future_test", size = "small", @@ -285,6 +310,11 @@ xla_cc_test( ], ) +tf_proto_library( + name = "attribute_map_proto", + srcs = ["attribute_map.proto"], +) + cc_library( name = "client_impl_test_lib", testonly = True, diff --git a/third_party/xla/xla/python/ifrt/attribute_map.cc b/third_party/xla/xla/python/ifrt/attribute_map.cc new file mode 100644 index 00000000000000..d2b28c9368792a --- /dev/null +++ 
b/third_party/xla/xla/python/ifrt/attribute_map.cc @@ -0,0 +1,137 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/python/ifrt/attribute_map.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "xla/python/ifrt/attribute_map.pb.h" + +namespace xla { +namespace ifrt { + +absl::StatusOr AttributeMap::FromProto( + const AttributeMapProto& proto) { + AttributeMap::Map map; + map.reserve(proto.attributes_size()); + for (const auto& [key, value] : proto.attributes()) { + switch (value.value_case()) { + case AttributeMapProto::Value::kStringValue: + map.insert({key, StringValue(value.string_value())}); + break; + case AttributeMapProto::Value::kBoolValue: + map.insert({key, BoolValue(value.bool_value())}); + break; + case AttributeMapProto::Value::kInt64Value: + map.insert({key, Int64Value(value.int64_value())}); + break; + case AttributeMapProto::Value::kInt64ListValue: + map.insert({key, Int64ListValue(std::vector( + value.int64_list_value().elements().begin(), + value.int64_list_value().elements().end()))}); + break; + case AttributeMapProto::Value::kFloatValue: + map.insert({key, FloatValue(value.float_value())}); + break; + default: + return absl::InvalidArgumentError( + 
absl::StrCat("Unsupported value type: ", value.value_case())); + } + } + return AttributeMap(std::move(map)); +} + +AttributeMapProto AttributeMap::ToProto() const { + AttributeMapProto proto; + for (const auto& [key, value] : map_) { + AttributeMapProto::Value value_proto; + std::visit( + [&](const auto& value) { + using T = std::decay_t; + if constexpr (std::is_same_v) { + value_proto.set_string_value(value.value); + } else if constexpr (std::is_same_v) { + value_proto.set_bool_value(value.value); + } else if constexpr (std::is_same_v) { + value_proto.set_int64_value(value.value); + } else if constexpr (std::is_same_v) { + auto* int64_list = value_proto.mutable_int64_list_value(); + int64_list->mutable_elements()->Reserve(value.value.size()); + for (const auto& element : value.value) { + int64_list->add_elements(element); + } + } else if constexpr (std::is_same_v) { + value_proto.set_float_value(value.value); + } + }, + value); + proto.mutable_attributes()->insert({key, std::move(value_proto)}); + } + return proto; +} + +std::string AttributeMap::DebugString(size_t max_string_length, + size_t max_int64_list_size) const { + auto formatter = [=](std::string* out, + const AttributeMap::Map::value_type& key_value) { + absl::StrAppend(out, key_value.first, "="); + std::visit( + [&](const auto& value) { + using T = std::decay_t; + if constexpr (std::is_same_v) { + if (value.value.size() > max_string_length) { + absl::StrAppend( + out, "\"", value.value.substr(0, max_string_length), "...\""); + } else { + absl::StrAppend(out, "\"", value.value, "\""); + } + } else if constexpr (std::is_same_v) { + absl::StrAppend(out, value.value ? 
"true" : "false"); + } else if constexpr (std::is_same_v) { + absl::StrAppend(out, value.value); + } else if constexpr (std::is_same_v) { + if (value.value.size() > max_int64_list_size) { + absl::StrAppend( + out, "[", + absl::StrJoin(value.value.begin(), + value.value.begin() + max_int64_list_size, + ", "), + "...]"); + } else { + absl::StrAppend(out, "[", absl::StrJoin(value.value, ", "), "]"); + } + } else if constexpr (std::is_same_v) { + absl::StrAppend(out, value.value); + } + }, + key_value.second); + }; + + return absl::StrCat("AttributeMap([", absl::StrJoin(map_, ", ", formatter), + "])"); +} + +} // namespace ifrt +} // namespace xla diff --git a/third_party/xla/xla/python/ifrt/attribute_map.h b/third_party/xla/xla/python/ifrt/attribute_map.h new file mode 100644 index 00000000000000..630933d9fce6bb --- /dev/null +++ b/third_party/xla/xla/python/ifrt/attribute_map.h @@ -0,0 +1,103 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_PYTHON_IFRT_ATTRIBUTE_MAP_H_ +#define XLA_PYTHON_IFRT_ATTRIBUTE_MAP_H_ + +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/status/statusor.h" +#include "xla/python/ifrt/attribute_map.pb.h" + +namespace xla { +namespace ifrt { + +// Attribute map that contains UTF-8 keys and variant values. 
+class AttributeMap { + public: + // Supported value types for `AttributeMap`. Modeled after + // `xla::PjRtValueType`, but they add a layer of structs that prevent implicit + // conversion. This ensures that `Value` to be constructed with a correct + // type. See + // https://www.open-std.org/jtc1/sc22/wg21/docs/papers/2018/p0608r3.html + // construction of `Value` with a wrong type. + struct StringValue { + explicit StringValue(std::string value) : value(std::move(value)) {} + std::string value; + bool operator==(const StringValue& other) const { + return value == other.value; + } + }; + struct BoolValue { + explicit BoolValue(bool value) : value(value) {} + bool operator==(const BoolValue& other) const { + return value == other.value; + } + bool value; + }; + struct Int64Value { + explicit Int64Value(int64_t value) : value(value) {} + bool operator==(const Int64Value& other) const { + return value == other.value; + } + int64_t value; + }; + struct Int64ListValue { + explicit Int64ListValue(std::vector value) + : value(std::move(value)) {} + bool operator==(const Int64ListValue& other) const { + return value == other.value; + } + std::vector value; + }; + struct FloatValue { + explicit FloatValue(float value) : value(value) {} + bool operator==(const FloatValue& other) const { + return value == other.value; + } + float value; + }; + using Value = std::variant; + + using Map = absl::flat_hash_map; + + explicit AttributeMap(Map map) : map_(std::move(map)) {} + + const Map& map() const { return map_; } + + // Deserializes `AttributeMapProto` into `AttributeMap`. + static absl::StatusOr FromProto(const AttributeMapProto& proto); + + // Serializes `AttributeMap` into `AttributeMapProto`. 
+ AttributeMapProto ToProto() const; + + std::string DebugString(size_t max_string_length = 64, + size_t max_int64_list_size = 16) const; + + private: + Map map_; +}; + +} // namespace ifrt +} // namespace xla + +#endif // XLA_PYTHON_IFRT_ATTRIBUTE_MAP_H_ diff --git a/third_party/xla/xla/python/ifrt/attribute_map.proto b/third_party/xla/xla/python/ifrt/attribute_map.proto new file mode 100644 index 00000000000000..bd20a0a54e519b --- /dev/null +++ b/third_party/xla/xla/python/ifrt/attribute_map.proto @@ -0,0 +1,35 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto3"; + +package xla.ifrt; + +// Proto equivalent of C++ `AttributeMap`. +message AttributeMapProto { + message Value { + message Int64List { + repeated sfixed64 elements = 1; + } + oneof value { + bytes string_value = 1; + bool bool_value = 2; + sfixed64 int64_value = 3; + Int64List int64_list_value = 4; + float float_value = 5; + } + } + map attributes = 1; +} diff --git a/third_party/xla/xla/python/ifrt/attribute_map_test.cc b/third_party/xla/xla/python/ifrt/attribute_map_test.cc new file mode 100644 index 00000000000000..7c6ed104a6d298 --- /dev/null +++ b/third_party/xla/xla/python/ifrt/attribute_map_test.cc @@ -0,0 +1,65 @@ +/* Copyright 2024 The OpenXLA Authors. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/python/ifrt/attribute_map.h" + +#include +#include +#include + +#include +#include "tsl/platform/statusor.h" + +namespace xla { +namespace ifrt { +namespace { + +TEST(AttributeMapTest, MapElements) { + AttributeMap map({ + {"string", AttributeMap::StringValue("value")}, + {"bool", AttributeMap::BoolValue(true)}, + {"int64", AttributeMap::Int64Value(123)}, + {"int64_list", AttributeMap::Int64ListValue({int64_t{1}, int64_t{2}})}, + {"float", AttributeMap::FloatValue(1.23f)}, + }); + + EXPECT_EQ(map.map(), AttributeMap::Map({ + {"string", AttributeMap::StringValue("value")}, + {"bool", AttributeMap::BoolValue(true)}, + {"int64", AttributeMap::Int64Value(123)}, + {"int64_list", AttributeMap::Int64ListValue( + {int64_t{1}, int64_t{2}})}, + {"float", AttributeMap::FloatValue(1.23f)}, + })) + << map.DebugString(); +} + +TEST(AttributeMapTest, ToFromProto) { + AttributeMap map({ + {"string", AttributeMap::StringValue("value")}, + {"bool", AttributeMap::BoolValue(true)}, + {"int64", AttributeMap::Int64Value(123)}, + {"int64_list", AttributeMap::Int64ListValue({int64_t{1}, int64_t{2}})}, + {"float", AttributeMap::FloatValue(1.23f)}, + }); + + TF_ASSERT_OK_AND_ASSIGN(auto map_copy, + AttributeMap::FromProto(map.ToProto())); + EXPECT_EQ(map_copy.map(), map.map()) << map_copy.DebugString(); +} + +} // namespace +} // namespace ifrt +} // namespace xla From 
61845939869b71e5c4c639aa57dfb03fe7dfa7d6 Mon Sep 17 00:00:00 2001 From: Kyle Lucke Date: Thu, 20 Jun 2024 10:50:02 -0700 Subject: [PATCH 068/256] Move StreamExecutor::Memcpy processing for to-device copies completely into Stream and its derived classes. PiperOrigin-RevId: 645086988 --- .../stream_executor/stream_executor.cc | 12 ----------- .../stream_executor_internal.h | 20 ++++++++++++++++++- .../xla/xla/backends/interpreter/executor.cc | 11 ---------- .../xla/xla/backends/interpreter/executor.h | 2 -- .../xla/stream_executor/cuda/cuda_executor.cc | 13 ------------ .../xla/stream_executor/gpu/gpu_executor.h | 3 --- .../xla/xla/stream_executor/gpu/gpu_stream.cc | 11 ++++++++++ .../xla/xla/stream_executor/gpu/gpu_stream.h | 10 ++++++++++ .../xla/stream_executor/host/host_executor.cc | 10 ---------- .../xla/stream_executor/host/host_executor.h | 2 -- .../xla/stream_executor/host/host_stream.cc | 9 +++++++++ .../xla/stream_executor/host/host_stream.h | 10 ++++++++++ .../stream_executor/mock_stream_executor.h | 4 ---- .../xla/stream_executor/rocm/rocm_executor.cc | 13 ------------ .../xla/xla/stream_executor/stream_common.cc | 5 ----- .../xla/xla/stream_executor/stream_common.h | 3 +-- .../xla/xla/stream_executor/stream_executor.h | 5 ----- .../xla/stream_executor/tpu/tpu_executor.cc | 10 ---------- .../xla/stream_executor/tpu/tpu_executor.h | 4 ---- .../xla/xla/stream_executor/tpu/tpu_stream.h | 19 ++++++++++++++++++ 20 files changed, 79 insertions(+), 97 deletions(-) diff --git a/tensorflow/c/experimental/stream_executor/stream_executor.cc b/tensorflow/c/experimental/stream_executor/stream_executor.cc index 2bca4311fd7cc4..ccf808c7e11504 100644 --- a/tensorflow/c/experimental/stream_executor/stream_executor.cc +++ b/tensorflow/c/experimental/stream_executor/stream_executor.cc @@ -324,18 +324,6 @@ class CStreamExecutor : public StreamExecutorCommon { } return StatusFromTF_Status(c_status.get()); } - absl::Status Memcpy(Stream* stream, DeviceMemoryBase* gpu_dst, - 
const void* host_src, uint64 size) override { - OwnedTFStatus c_status(TF_NewStatus()); - SP_Stream stream_handle = static_cast(stream)->Handle(); - SP_DeviceMemoryBase device_mem_dst = DeviceMemoryBaseToC(gpu_dst); - stream_executor_->memcpy_htod(&device_, stream_handle, &device_mem_dst, - host_src, size, c_status.get()); - if (TF_GetCode(c_status.get()) != TF_OK) { - LOG(ERROR) << TF_Message(c_status.get()); - } - return StatusFromTF_Status(c_status.get()); - } bool MemcpyDeviceToDevice(Stream* stream, DeviceMemoryBase* gpu_dst, const DeviceMemoryBase& gpu_src, uint64 size) override { diff --git a/tensorflow/c/experimental/stream_executor/stream_executor_internal.h b/tensorflow/c/experimental/stream_executor/stream_executor_internal.h index d8e0ec75365b95..02bb318ea33b09 100644 --- a/tensorflow/c/experimental/stream_executor/stream_executor_internal.h +++ b/tensorflow/c/experimental/stream_executor/stream_executor_internal.h @@ -235,7 +235,25 @@ class CStream : public StreamCommon { size, c_status.get()); return tensorflow::StatusFromTF_Status(c_status.get()); } - + absl::Status Memcpy(DeviceMemoryBase* gpu_dst, const void* host_src, + uint64_t size) override { + tensorflow::TF_StatusPtr c_status(TF_NewStatus()); + SP_DeviceMemoryBase device_mem_dst = DeviceMemoryBaseToC(gpu_dst); + stream_executor_->memcpy_htod(device_, stream_handle_, &device_mem_dst, + host_src, size, c_status.get()); + if (TF_GetCode(c_status.get()) != TF_OK) { + LOG(ERROR) << TF_Message(c_status.get()); + } + return tensorflow::StatusFromTF_Status(c_status.get()); + } + absl::Status Memcpy(DeviceMemoryBase* gpu_dst, + const DeviceMemoryBase& gpu_src, uint64_t size) override { + return StreamCommon::Memcpy(gpu_dst, gpu_src, size); + } + absl::Status Memcpy(void* host_dst, const DeviceMemoryBase& gpu_src, + uint64_t size) override { + return StreamCommon::Memcpy(host_dst, gpu_src, size); + } SP_Stream Handle() { return stream_handle_; } private: diff --git 
a/third_party/xla/xla/backends/interpreter/executor.cc b/third_party/xla/xla/backends/interpreter/executor.cc index a73358683cfc4b..bbfe3964b5df8a 100644 --- a/third_party/xla/xla/backends/interpreter/executor.cc +++ b/third_party/xla/xla/backends/interpreter/executor.cc @@ -49,17 +49,6 @@ absl::Status XlaInterpreterExecutor::Memcpy(Stream *stream, void *host_dst, return AsExecutorStream(stream)->BlockUntilDone(); } -absl::Status XlaInterpreterExecutor::Memcpy(Stream *stream, - DeviceMemoryBase *dev_dst, - const void *host_src, - uint64_t size) { - AsExecutorStream(stream)->EnqueueTask([this, dev_dst, host_src, size]() { - // Ignore errors. - absl::Status ok = SynchronousMemcpy(dev_dst, host_src, size); - }); - return AsExecutorStream(stream)->BlockUntilDone(); -} - absl::Status XlaInterpreterExecutor::SynchronousMemcpy( DeviceMemoryBase *dev_dst, const void *host_src, uint64_t size) { memcpy(dev_dst->opaque(), host_src, size); diff --git a/third_party/xla/xla/backends/interpreter/executor.h b/third_party/xla/xla/backends/interpreter/executor.h index 7a9419c73ef085..5a4950552dcd4c 100644 --- a/third_party/xla/xla/backends/interpreter/executor.h +++ b/third_party/xla/xla/backends/interpreter/executor.h @@ -92,8 +92,6 @@ class XlaInterpreterExecutor : public StreamExecutorCommon { absl::Status Memcpy(Stream *stream, void *host_dst, const DeviceMemoryBase &dev_src, uint64_t size) override; - absl::Status Memcpy(Stream *stream, DeviceMemoryBase *dev_dst, - const void *host_src, uint64_t size) override; bool MemcpyDeviceToDevice(Stream *stream, DeviceMemoryBase *pop_dst, const DeviceMemoryBase &host_src, uint64_t size) override { diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc b/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc index 1a5f070a3a4d00..d73c2ed935f3ac 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc @@ -695,19 +695,6 @@ absl::Status 
GpuExecutor::Memcpy(Stream* stream, void* host_dst, return absl::OkStatus(); } -absl::Status GpuExecutor::Memcpy(Stream* stream, DeviceMemoryBase* gpu_dst, - const void* host_src, uint64_t size) { - bool ok = GpuDriver::AsynchronousMemcpyH2D(context_, AsCudaDevicePtr(gpu_dst), - host_src, size, - AsGpuStreamValue(stream)); - // TODO(b/326130105): Change AsynchronousMemcpyD2H calls to return - // absl::Status. - if (!ok) { - return absl::InternalError("Failed to memcpy from device to host."); - } - return absl::OkStatus(); -} - bool GpuExecutor::MemcpyDeviceToDevice(Stream* stream, DeviceMemoryBase* gpu_dst, const DeviceMemoryBase& gpu_src, diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_executor.h b/third_party/xla/xla/stream_executor/gpu/gpu_executor.h index 097c5143e32f26..6c1f89ac00b931 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_executor.h +++ b/third_party/xla/xla/stream_executor/gpu/gpu_executor.h @@ -214,9 +214,6 @@ class GpuExecutor : public StreamExecutorCommon { absl::Status Memcpy(Stream* stream, void* host_dst, const DeviceMemoryBase& gpu_src, uint64_t size) override; - absl::Status Memcpy(Stream* stream, DeviceMemoryBase* gpu_dst, - const void* host_src, uint64_t size) override; - bool MemcpyDeviceToDevice(Stream* stream, DeviceMemoryBase* gpu_dst, const DeviceMemoryBase& gpu_src, uint64_t size) override; diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_stream.cc b/third_party/xla/xla/stream_executor/gpu/gpu_stream.cc index 7130e00a878762..573bb4d6fcb632 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_stream.cc +++ b/third_party/xla/xla/stream_executor/gpu/gpu_stream.cc @@ -77,6 +77,17 @@ absl::Status GpuStream::MemZero(DeviceMemoryBase* location, uint64_t size) { } } +absl::Status GpuStream::Memcpy(DeviceMemoryBase* gpu_dst, const void* host_src, + uint64_t size) { + bool ok = GpuDriver::AsynchronousMemcpyH2D( + parent_->gpu_context(), reinterpret_cast(gpu_dst->opaque()), + host_src, size, gpu_stream()); + if 
(!ok) { + return absl::InternalError("Failed to memcpy from device to host."); + } + return absl::OkStatus(); +} + absl::Status GpuStream::WaitFor(Stream* other) { GpuStream* other_gpu = AsGpuStream(other); GpuEventHandle other_completed_event = *(other_gpu->completed_event()); diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_stream.h b/third_party/xla/xla/stream_executor/gpu/gpu_stream.h index 633a2341dabcd6..84440824e71f5e 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_stream.h +++ b/third_party/xla/xla/stream_executor/gpu/gpu_stream.h @@ -102,6 +102,16 @@ class GpuStream : public StreamCommon { absl::Status MemZero(DeviceMemoryBase* location, uint64_t size) override; absl::Status Memset32(DeviceMemoryBase* location, uint32_t pattern, uint64_t size) override; + absl::Status Memcpy(DeviceMemoryBase* gpu_dst, const void* host_src, + uint64_t size) override; + absl::Status Memcpy(void* host_dst, const DeviceMemoryBase& gpu_src, + uint64_t size) override { + return StreamCommon::Memcpy(host_dst, gpu_src, size); + } + absl::Status Memcpy(DeviceMemoryBase* gpu_dst, + const DeviceMemoryBase& gpu_src, uint64_t size) override { + return StreamCommon::Memcpy(gpu_dst, gpu_src, size); + } private: GpuExecutor* parent_; // Executor that spawned this stream. diff --git a/third_party/xla/xla/stream_executor/host/host_executor.cc b/third_party/xla/xla/stream_executor/host/host_executor.cc index 95194dbc397033..cf125ac7dbc858 100644 --- a/third_party/xla/xla/stream_executor/host/host_executor.cc +++ b/third_party/xla/xla/stream_executor/host/host_executor.cc @@ -154,16 +154,6 @@ absl::Status HostExecutor::Memcpy(Stream* stream, void* host_dst, return absl::OkStatus(); } -absl::Status HostExecutor::Memcpy(Stream* stream, DeviceMemoryBase* gpu_dst, - const void* host_src, uint64_t size) { - void* dst_mem = gpu_dst->opaque(); - // Enqueue the [asynchronous] memcpy on the stream (HostStream) associated - // with the HostExecutor. 
- AsHostStream(stream)->EnqueueTask( - [dst_mem, host_src, size]() { memcpy(dst_mem, host_src, size); }); - return absl::OkStatus(); -} - bool HostExecutor::MemcpyDeviceToDevice(Stream* stream, DeviceMemoryBase* gpu_dst, const DeviceMemoryBase& gpu_src, diff --git a/third_party/xla/xla/stream_executor/host/host_executor.h b/third_party/xla/xla/stream_executor/host/host_executor.h index 6d9d64a33bfd3e..e9102ca83c546a 100644 --- a/third_party/xla/xla/stream_executor/host/host_executor.h +++ b/third_party/xla/xla/stream_executor/host/host_executor.h @@ -92,8 +92,6 @@ class HostExecutor : public StreamExecutorCommon { absl::Status Memcpy(Stream* stream, void* host_dst, const DeviceMemoryBase& gpu_src, uint64_t size) override; - absl::Status Memcpy(Stream* stream, DeviceMemoryBase* gpu_dst, - const void* host_src, uint64_t size) override; bool MemcpyDeviceToDevice(Stream* stream, DeviceMemoryBase* gpu_dst, const DeviceMemoryBase& gpu_src, uint64_t size) override; diff --git a/third_party/xla/xla/stream_executor/host/host_stream.cc b/third_party/xla/xla/stream_executor/host/host_stream.cc index a45f188ee82fc0..57041acaaa13fb 100644 --- a/third_party/xla/xla/stream_executor/host/host_stream.cc +++ b/third_party/xla/xla/stream_executor/host/host_stream.cc @@ -57,6 +57,15 @@ HostStream::~HostStream() { parent()->DeallocateStream(this); } +absl::Status HostStream::Memcpy(DeviceMemoryBase* gpu_dst, const void* host_src, + uint64_t size) { + void* dst_mem = gpu_dst->opaque(); + // Enqueue the [asynchronous] memcpy on the stream (HostStream) associated + // with the HostExecutor. 
+ EnqueueTask([dst_mem, host_src, size]() { memcpy(dst_mem, host_src, size); }); + return absl::OkStatus(); +} + absl::Status HostStream::Memset32(DeviceMemoryBase* location, uint32_t pattern, uint64_t size) { void* gpu_mem = location->opaque(); diff --git a/third_party/xla/xla/stream_executor/host/host_stream.h b/third_party/xla/xla/stream_executor/host/host_stream.h index 83dd18be8a9693..08428c720c7f31 100644 --- a/third_party/xla/xla/stream_executor/host/host_stream.h +++ b/third_party/xla/xla/stream_executor/host/host_stream.h @@ -57,6 +57,16 @@ class HostStream : public StreamCommon { absl::Status MemZero(DeviceMemoryBase* location, uint64_t size) override; absl::Status Memset32(DeviceMemoryBase* location, uint32_t pattern, uint64_t size) override; + absl::Status Memcpy(DeviceMemoryBase* gpu_dst, const void* host_src, + uint64_t size) override; + absl::Status Memcpy(DeviceMemoryBase* gpu_dst, + const DeviceMemoryBase& gpu_src, uint64_t size) override { + return StreamCommon::Memcpy(gpu_dst, gpu_src, size); + } + absl::Status Memcpy(void* host_dst, const DeviceMemoryBase& gpu_src, + uint64_t size) override { + return StreamCommon::Memcpy(host_dst, gpu_src, size); + } private: bool WorkAvailable() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); diff --git a/third_party/xla/xla/stream_executor/mock_stream_executor.h b/third_party/xla/xla/stream_executor/mock_stream_executor.h index ac19fcf0932476..a207ee9ac1b8f0 100644 --- a/third_party/xla/xla/stream_executor/mock_stream_executor.h +++ b/third_party/xla/xla/stream_executor/mock_stream_executor.h @@ -113,10 +113,6 @@ class MockStreamExecutor : public StreamExecutor { (Stream * stream, void* host_dst, const DeviceMemoryBase& device_src, uint64_t size), (override)); - MOCK_METHOD(absl::Status, Memcpy, - (Stream * stream, DeviceMemoryBase* device_dst, - const void* host_src, uint64_t size), - (override)); MOCK_METHOD(bool, MemcpyDeviceToDevice, (Stream * stream, DeviceMemoryBase* device_dst, const DeviceMemoryBase& device_src, 
uint64_t size), diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_executor.cc b/third_party/xla/xla/stream_executor/rocm/rocm_executor.cc index c8b00b133f898c..5c8e30f8830da6 100644 --- a/third_party/xla/xla/stream_executor/rocm/rocm_executor.cc +++ b/third_party/xla/xla/stream_executor/rocm/rocm_executor.cc @@ -597,19 +597,6 @@ absl::Status GpuExecutor::Memcpy(Stream* stream, void* host_dst, return absl::OkStatus(); } -absl::Status GpuExecutor::Memcpy(Stream* stream, DeviceMemoryBase* gpu_dst, - const void* host_src, uint64_t size) { - bool ok = GpuDriver::AsynchronousMemcpyH2D(context_, AsROCmDevicePtr(gpu_dst), - host_src, size, - AsGpuStreamValue(stream)); - // TODO(b/326130105): Change AsynchronousMemcpyD2H calls to return - // absl::Status. - if (!ok) { - return absl::InternalError("Failed to memcpy from device to host."); - } - return absl::OkStatus(); -} - bool GpuExecutor::MemcpyDeviceToDevice(Stream* stream, DeviceMemoryBase* gpu_dst, const DeviceMemoryBase& gpu_src, diff --git a/third_party/xla/xla/stream_executor/stream_common.cc b/third_party/xla/xla/stream_executor/stream_common.cc index 44a09f16f827b2..f171491347f996 100644 --- a/third_party/xla/xla/stream_executor/stream_common.cc +++ b/third_party/xla/xla/stream_executor/stream_common.cc @@ -149,11 +149,6 @@ absl::Status StreamCommon::Memcpy(void *host_dst, return parent_->Memcpy(this, host_dst, gpu_src, size); } -absl::Status StreamCommon::Memcpy(DeviceMemoryBase *gpu_dst, - const void *host_src, uint64_t size) { - return parent_->Memcpy(this, gpu_dst, host_src, size); -} - absl::Status StreamCommon::Memcpy(DeviceMemoryBase *gpu_dst, const DeviceMemoryBase &gpu_src, uint64_t size) { diff --git a/third_party/xla/xla/stream_executor/stream_common.h b/third_party/xla/xla/stream_executor/stream_common.h index c83a0949d476f0..63b5b6d3a48414 100644 --- a/third_party/xla/xla/stream_executor/stream_common.h +++ b/third_party/xla/xla/stream_executor/stream_common.h @@ -73,8 +73,7 @@ class 
StreamCommon : public Stream { void ReturnSubStream(Stream *sub_stream) override TF_LOCKS_EXCLUDED(mu_); absl::Status Memcpy(void *host_dst, const DeviceMemoryBase &gpu_src, uint64_t size) override; - absl::Status Memcpy(DeviceMemoryBase *gpu_dst, const void *host_src, - uint64_t size) override; + absl::Status Memcpy(DeviceMemoryBase *gpu_dst, const DeviceMemoryBase &gpu_src, uint64_t size) override; absl::Status BlockHostUntilDone() override TF_LOCKS_EXCLUDED(mu_); diff --git a/third_party/xla/xla/stream_executor/stream_executor.h b/third_party/xla/xla/stream_executor/stream_executor.h index f629c7c2781848..a314759fe1079a 100644 --- a/third_party/xla/xla/stream_executor/stream_executor.h +++ b/third_party/xla/xla/stream_executor/stream_executor.h @@ -237,11 +237,6 @@ class StreamExecutor { const DeviceMemoryBase& device_src, uint64_t size) = 0; - // Enqueues a memcpy operation onto stream, with a device destination location - // and a host memory source, with target size size. - virtual absl::Status Memcpy(Stream* stream, DeviceMemoryBase* device_dst, - const void* host_src, uint64_t size) = 0; - // Enqueues a memcpy operation onto stream, with a device destination location // and a device source location, with target size size. 
Peer access should // have been enabled between the StreamExecutors owning the device memory diff --git a/third_party/xla/xla/stream_executor/tpu/tpu_executor.cc b/third_party/xla/xla/stream_executor/tpu/tpu_executor.cc index 8d0485097440c4..8b82cd019f928b 100644 --- a/third_party/xla/xla/stream_executor/tpu/tpu_executor.cc +++ b/third_party/xla/xla/stream_executor/tpu/tpu_executor.cc @@ -180,16 +180,6 @@ absl::Status TpuExecutor::Memcpy( return status.status(); } -absl::Status TpuExecutor::Memcpy( - Stream* stream, ::stream_executor::DeviceMemoryBase* device_dst, - const void* host_src, uint64_t size) { - StatusHelper status; - SE_DeviceMemoryBase se_base = ApiConverter::ToC(*device_dst); - ExecutorApiFn()->TpuExecutor_MemcpyFromHostFn( - executor_, get_stream(stream), &se_base, host_src, size, status.c_status); - return status.status(); -} - absl::Status TpuExecutor::SynchronousMemcpy( ::stream_executor::DeviceMemoryBase* device_dst, const void* host_src, uint64_t size) { diff --git a/third_party/xla/xla/stream_executor/tpu/tpu_executor.h b/third_party/xla/xla/stream_executor/tpu/tpu_executor.h index f6d3352c84b9c5..aa64ded8e1b048 100644 --- a/third_party/xla/xla/stream_executor/tpu/tpu_executor.h +++ b/third_party/xla/xla/stream_executor/tpu/tpu_executor.h @@ -101,9 +101,6 @@ class TpuExecutor : public tensorflow::tpu::TpuExecutorInterface { const DeviceMemoryBase& device_src, uint64_t size) override; - absl::Status Memcpy(Stream* stream, DeviceMemoryBase* device_dst, - const void* host_src, uint64_t size) override; - bool MemcpyDeviceToDevice(Stream* stream, DeviceMemoryBase* gpu_dst, const DeviceMemoryBase& host_src, uint64_t size) override; @@ -115,7 +112,6 @@ class TpuExecutor : public tensorflow::tpu::TpuExecutorInterface { absl::Status SynchronousMemcpy(void* host_dst, const DeviceMemoryBase& device_src, uint64_t size) override; - absl::Status UnloadAllPrograms() override; absl::Status EnqueueCompactionOnStreamForHbm( diff --git 
a/third_party/xla/xla/stream_executor/tpu/tpu_stream.h b/third_party/xla/xla/stream_executor/tpu/tpu_stream.h index ee0d9687c1ab66..b8e04925ac4f86 100644 --- a/third_party/xla/xla/stream_executor/tpu/tpu_stream.h +++ b/third_party/xla/xla/stream_executor/tpu/tpu_stream.h @@ -123,6 +123,25 @@ class TpuStream : public tensorflow::tpu::TpuStreamInterface { return status.status(); } + absl::Status Memcpy(stream_executor::DeviceMemoryBase* device_dst, + const void* host_src, uint64_t size) override { + StatusHelper status; + SE_DeviceMemoryBase se_base = ApiConverter::ToC(*device_dst); + stream_executor::tpu::ExecutorApiFn()->TpuExecutor_MemcpyFromHostFn( + se_executor_, stream_, &se_base, host_src, size, status.c_status); + return status.status(); + } + absl::Status Memcpy(stream_executor::DeviceMemoryBase* gpu_dst, + const stream_executor::DeviceMemoryBase& gpu_src, + uint64_t size) override { + return StreamCommon::Memcpy(gpu_dst, gpu_src, size); + } + absl::Status Memcpy(void* host_dst, + const stream_executor::DeviceMemoryBase& gpu_src, + uint64_t size) override { + return StreamCommon::Memcpy(host_dst, gpu_src, size); + } + SE_Stream* se_stream() const { return stream_; } private: From 0939cebb19832b3f430a32ed1c512c3023ad9a0c Mon Sep 17 00:00:00 2001 From: George Karpenkov Date: Thu, 20 Jun 2024 11:00:26 -0700 Subject: [PATCH 069/256] [XLA:GPU] Support mocking away all collectives PiperOrigin-RevId: 645090736 --- .../xla/xla/service/gpu/gpu_executable.cc | 16 +++++-- .../service/gpu/runtime/command_buffer_cmd.cc | 7 +++- .../service/gpu/runtime/sequential_thunk.cc | 3 ++ .../xla/xla/service/gpu/runtime/thunk.cc | 42 +++++++++++++++++-- .../xla/xla/service/gpu/runtime/thunk.h | 8 +++- .../functional_hlo_runner.cc | 3 +- .../functional_hlo_runner_test.cc | 27 ++++++++++++ .../multihost_hlo_runner/hlo_runner_main.cc | 2 +- 8 files changed, 98 insertions(+), 10 deletions(-) diff --git a/third_party/xla/xla/service/gpu/gpu_executable.cc 
b/third_party/xla/xla/service/gpu/gpu_executable.cc index 42f7e6f72c1701..bb943bef4f902e 100644 --- a/third_party/xla/xla/service/gpu/gpu_executable.cc +++ b/third_party/xla/xla/service/gpu/gpu_executable.cc @@ -360,6 +360,13 @@ absl::Status ExecuteThunks( const ServiceExecutableRunOptions* run_options, const BufferAllocations& buffer_allocations, bool block_host_until_done, const absl::flat_hash_set& execution_stream_ids) { + bool mock_collectives = + run_options->run_options().gpu_executable_run_options() + ? run_options->run_options() + .gpu_executable_run_options() + ->enable_mock_nccl_collectives() + : false; + int64_t collective_max_nchannels = debug_options ? debug_options->xla_gpu_nccl_collective_max_nchannels() : 0; @@ -448,9 +455,12 @@ absl::Status ExecuteThunks( } // Acquire collective cliques requested by thunks. - TF_ASSIGN_OR_RETURN( - Thunk::CollectiveCliques collective_cliques, - resource_requests.AcquireCollectiveCliques(collective_params)); + Thunk::CollectiveCliques collective_cliques; + if (!mock_collectives) { + TF_ASSIGN_OR_RETURN( + collective_cliques, + resource_requests.AcquireCollectiveCliques(collective_params)); + } { // Initialize thunks using prepared resources before execution. Thunk::InitializeParams initialize_params{ diff --git a/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd.cc b/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd.cc index 381ceea2017dc3..78c5afc1e8ab61 100644 --- a/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd.cc +++ b/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd.cc @@ -297,7 +297,12 @@ absl::Status CommandBufferCmdSequence::Record( // Track the number of commands recorded between barriers. 
absl::flat_hash_map num_recorded_commands; - for (auto& command : commands_) { + for (CommandInfo& command : commands_) { + if (execute_params.mock_collectives && + dynamic_cast(command.cmd.get())) { + continue; + } + ExecutionScopeId execution_scope_id = command.cmd->GetExecutionScope(record_params); std::optional annotation = diff --git a/third_party/xla/xla/service/gpu/runtime/sequential_thunk.cc b/third_party/xla/xla/service/gpu/runtime/sequential_thunk.cc index 6033ddc052748c..ac2e652dcb1537 100644 --- a/third_party/xla/xla/service/gpu/runtime/sequential_thunk.cc +++ b/third_party/xla/xla/service/gpu/runtime/sequential_thunk.cc @@ -75,6 +75,9 @@ absl::Status SequentialThunk::ExecuteOnStream(const ExecuteParams& params) { for (const std::unique_ptr& thunk : thunks_) { std::optional annotation = GetKernelAnnotation(thunk->profile_annotation()); + if (params.mock_collectives && thunk->IsCollective()) { + continue; + } TF_RETURN_IF_ERROR(thunk->ExecuteOnStream(params)); } return absl::OkStatus(); diff --git a/third_party/xla/xla/service/gpu/runtime/thunk.cc b/third_party/xla/xla/service/gpu/runtime/thunk.cc index 759df65bc28aa6..972f05a9437601 100644 --- a/third_party/xla/xla/service/gpu/runtime/thunk.cc +++ b/third_party/xla/xla/service/gpu/runtime/thunk.cc @@ -186,7 +186,12 @@ Thunk::ExecuteParams Thunk::ExecuteParams::Create( run_options.run_options().send_device_memory_function(), run_options.run_options().recv_device_memory_function(), run_options.run_options().ffi_execution_context(), - additional_compute_streams); + additional_compute_streams, + run_options.run_options().gpu_executable_run_options() + ? 
run_options.run_options() + .gpu_executable_run_options() + ->enable_mock_nccl_collectives() + : false); } Thunk::ExecuteParams Thunk::ExecuteParams::CloneWithNewAllocations( @@ -209,7 +214,7 @@ Thunk::ExecuteParams::ExecuteParams( SendDeviceMemoryFunction* send_device_memory_function, RecvDeviceMemoryFunction* recv_device_memory_function, const ffi::ExecutionContext* ffi_execution_context, - ExecutionStreamIdMap additional_compute_streams) + ExecutionStreamIdMap additional_compute_streams, bool mock_collectives) : buffer_allocations(buffer_allocations), stream(stream), command_buffer_trace_stream(command_buffer_trace_stream), @@ -220,7 +225,8 @@ Thunk::ExecuteParams::ExecuteParams( send_device_memory_function(send_device_memory_function), recv_device_memory_function(recv_device_memory_function), ffi_execution_context(ffi_execution_context), - additional_compute_streams(additional_compute_streams) {} + additional_compute_streams(additional_compute_streams), + mock_collectives(mock_collectives) {} //===----------------------------------------------------------------------===// @@ -322,5 +328,35 @@ Thunk::ThunkInfo Thunk::ThunkInfo::WithProfileAnnotation( return thunk_info; } +bool Thunk::IsCollective() const { + switch (kind()) { + case kNcclAllGather: + case kNcclAllGatherStart: + case kNcclAllGatherDone: + case kNcclAllReduce: + case kNcclAllReduceStart: + case kNcclAllReduceDone: + case kNcclCollectiveBroadcast: + case kNcclCollectiveBroadcastStart: + case kNcclCollectiveBroadcastDone: + case kNcclCollectivePermute: + case kNcclCollectivePermuteStart: + case kNcclCollectivePermuteDone: + case kNcclReduceScatter: + case kNcclReduceScatterStart: + case kNcclReduceScatterDone: + case kNcclAllToAll: + case kNcclAllToAllStart: + case kNcclAllToAllDone: + case kNcclSend: + case kNcclSendDone: + case kNcclRecv: + case kNcclRecvDone: + return true; + default: + return false; + } +} + } // namespace gpu } // namespace xla diff --git 
a/third_party/xla/xla/service/gpu/runtime/thunk.h b/third_party/xla/xla/service/gpu/runtime/thunk.h index 483c817b7add93..0bd548d1218086 100644 --- a/third_party/xla/xla/service/gpu/runtime/thunk.h +++ b/third_party/xla/xla/service/gpu/runtime/thunk.h @@ -382,6 +382,8 @@ class Thunk { // Additional compute streams on which thunks launch operations. ExecutionStreamIdMap additional_compute_streams; + bool mock_collectives = false; + private: friend class CommandBufferThunk; @@ -394,7 +396,8 @@ class Thunk { SendDeviceMemoryFunction* send_device_memory_function, RecvDeviceMemoryFunction* recv_device_memory_function, const ffi::ExecutionContext* ffi_execution_context, - ExecutionStreamIdMap additional_compute_streams = {}); + ExecutionStreamIdMap additional_compute_streams = {}, + bool mock_collectives = false); }; //===--------------------------------------------------------------------===// @@ -451,6 +454,9 @@ class Thunk { static absl::StatusOr GetStreamForExecution( ExecutionStreamId stream_id, const ExecuteParams& params); + // Returns `true` if this thunk requires inter-GPU communication. 
+ bool IsCollective() const; + private: Kind kind_; std::string profile_annotation_; diff --git a/third_party/xla/xla/tools/multihost_hlo_runner/functional_hlo_runner.cc b/third_party/xla/xla/tools/multihost_hlo_runner/functional_hlo_runner.cc index 8f2d256128d130..ab77b7ecd0bd46 100644 --- a/third_party/xla/xla/tools/multihost_hlo_runner/functional_hlo_runner.cc +++ b/third_party/xla/xla/tools/multihost_hlo_runner/functional_hlo_runner.cc @@ -98,11 +98,12 @@ static absl::StatusOr> GetPjRtClient( if (num_nodes == 1) { return xla::FunctionalHloRunner::CreateGpuClient({}); } else { - TF_RET_CHECK(num_nodes == 1 || !address.empty()); + TF_RET_CHECK(!address.empty()); TF_RET_CHECK(node_id >= 0) << "Node id is expected to be in range [0, num_nodes)"; TF_RET_CHECK(node_id < num_nodes) << "Node id is expected to be in range [0, num_nodes)"; + CHECK_GT(address.length(), 0); // Multinode. Start service on task 0. if (node_id == 0) { diff --git a/third_party/xla/xla/tools/multihost_hlo_runner/functional_hlo_runner_test.cc b/third_party/xla/xla/tools/multihost_hlo_runner/functional_hlo_runner_test.cc index c1e7e6afdaaceb..be970c1716c4b8 100644 --- a/third_party/xla/xla/tools/multihost_hlo_runner/functional_hlo_runner_test.cc +++ b/third_party/xla/xla/tools/multihost_hlo_runner/functional_hlo_runner_test.cc @@ -319,6 +319,33 @@ absl::Status ShardedAutotuningWorksTestBody(const int node_id) { return absl::OkStatus(); } +TEST_F(FunctionalHloRunnerTest, CanRunWithMockCollectives) { + // This test corresponds to: + // --use_spmd_partitioning=true --num_replicas=1 --num_partitions=16 + if (IsTestingCpu()) { + GTEST_SKIP() << "GPU-only test"; + } + xla::DebugOptions debug_options; + FunctionalHloRunner::PreprocessingOptions preproc_options; + FunctionalHloRunner::RawCompileOptions raw_compile_options; + raw_compile_options.spmd_mode = + FunctionalHloRunner::SpmdMode::kUseSpmdPartitioning; + raw_compile_options.num_replicas = 1; + raw_compile_options.num_partitions = 16; + + 
FunctionalHloRunner::RunningOptions running_options; + running_options.module_argument_mode = + FunctionalHloRunner::ModuleArgumentMode::kUseZerosAsInput; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr client, + FunctionalHloRunner::CreateMockGpuClient(16)); + + TF_EXPECT_OK(FunctionalHloRunner::LoadAndRunAndDump( + *client, debug_options, preproc_options, raw_compile_options, + running_options, {GetHloPath("sharded_16_devices.hlo")}, + InputFormat::kText)); +} + } // namespace } // namespace xla diff --git a/third_party/xla/xla/tools/multihost_hlo_runner/hlo_runner_main.cc b/third_party/xla/xla/tools/multihost_hlo_runner/hlo_runner_main.cc index 733d42d3e07112..c79cf8493109a8 100644 --- a/third_party/xla/xla/tools/multihost_hlo_runner/hlo_runner_main.cc +++ b/third_party/xla/xla/tools/multihost_hlo_runner/hlo_runner_main.cc @@ -70,7 +70,7 @@ all HLOs from an execution dump, with e.g.: Mock GPU usage: - bazel run hlo_runner_main -- --enable_mock_gpu=true /path/to/hlo_module.hlo + bazel run hlo_runner_main -- --enable_mock_nccl=true /path/to/hlo_module.hlo Tip: If the input generation takes too long or uses too much host memory, consider using --hlo_argument_mode=uninitialized. From 1fa8849677a3519a16c672da539254ff17839c5e Mon Sep 17 00:00:00 2001 From: Alan Kelly Date: Thu, 20 Jun 2024 11:17:11 -0700 Subject: [PATCH 070/256] Don't reject F32 non 4D input tensors for float. 
XNNPack can handle them all PiperOrigin-RevId: 645097491 --- tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc b/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc index 0935d28388d182..547b18320d28cc 100644 --- a/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc +++ b/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc @@ -4773,9 +4773,6 @@ class Subgraph { TF_LITE_ENSURE_STATUS( CheckTensorFloat32OrQUInt8Type(delegate, logging_context, input_tensor, node->inputs->data[0], node_index)); - TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, input_tensor, 4, - node->inputs->data[0], - BuiltinOperator_MEAN, node_index)); TF_LITE_ENSURE_STATUS( CheckTensorNonDynamicAllocation(delegate, logging_context, input_tensor, node->inputs->data[0], node_index)); @@ -4796,6 +4793,10 @@ class Subgraph { bool all_reductions_supported = false; if (input_tensor.type == kTfLiteFloat32) { all_reductions_supported = true; + } else { + TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, input_tensor, 4, + node->inputs->data[0], + BuiltinOperator_MEAN, node_index)); } const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]]; TF_LITE_ENSURE_STATUS( From d792589825a381483b5e116f22c66ba1eed050eb Mon Sep 17 00:00:00 2001 From: Kyle Lucke Date: Thu, 20 Jun 2024 11:34:47 -0700 Subject: [PATCH 071/256] Stop using xla/statusor.h now that it just contains an alias for absl::Status. In some situations, this meant also changing unrelated files to directly include tsl/platform/statusor.h to get the definitions for TF_ASSIGN_OR_RETURN, etc., where they were getting transitively included for free. 
PiperOrigin-RevId: 645103999 --- third_party/xla/xla/client/lib/BUILD | 43 ++++++++++--------- third_party/xla/xla/client/lib/approx_topk.cc | 2 +- .../xla/xla/client/lib/approx_topk_shape.h | 2 +- third_party/xla/xla/client/lib/arithmetic.cc | 2 +- .../xla/xla/client/lib/conv_grad_size_util.cc | 1 + third_party/xla/xla/client/lib/logdet.cc | 2 +- third_party/xla/xla/client/lib/logdet_test.cc | 2 +- third_party/xla/xla/client/lib/loops.h | 2 +- .../xla/xla/client/lib/lu_decomposition.cc | 2 +- third_party/xla/xla/client/lib/matrix.cc | 2 +- third_party/xla/xla/client/lib/matrix.h | 2 +- third_party/xla/xla/client/lib/matrix_test.cc | 2 +- third_party/xla/xla/client/lib/prng.cc | 2 +- third_party/xla/xla/client/lib/prng_test.cc | 2 +- third_party/xla/xla/client/lib/qr.cc | 2 +- third_party/xla/xla/client/lib/qr_test.cc | 2 +- .../xla/xla/client/lib/self_adjoint_eig.cc | 2 +- .../xla/client/lib/self_adjoint_eig_test.cc | 2 +- third_party/xla/xla/client/lib/slicing.cc | 2 +- third_party/xla/xla/client/lib/svd.cc | 2 +- third_party/xla/xla/client/lib/svd_test.cc | 2 +- third_party/xla/xla/client/lib/testing.cc | 2 +- third_party/xla/xla/client/lib/tridiagonal.cc | 2 +- third_party/xla/xla/client/lib/tuple.cc | 2 +- third_party/xla/xla/client/lib/tuple.h | 2 +- 25 files changed, 46 insertions(+), 44 deletions(-) diff --git a/third_party/xla/xla/client/lib/BUILD b/third_party/xla/xla/client/lib/BUILD index 2907f8c21d3b0c..8936e0e850f369 100644 --- a/third_party/xla/xla/client/lib/BUILD +++ b/third_party/xla/xla/client/lib/BUILD @@ -30,10 +30,10 @@ cc_library( deps = [ ":constants", "//xla:shape_util", - "//xla:statusor", "//xla:xla_data_proto_cc", "//xla/client:xla_builder", "//xla/client:xla_computation", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", ], ) @@ -151,6 +151,7 @@ cc_library( "//xla:status_macros", "//xla/client:padding", "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", ], ) @@ -179,9 +180,9 @@ 
cc_library( ":constants", "//xla:shape_util", "//xla:status_macros", - "//xla:statusor", "//xla/client:xla_builder", "//xla/client:xla_computation", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", ], @@ -236,7 +237,6 @@ cc_library( "//xla:literal", "//xla:shape_util", "//xla:status_macros", - "//xla:statusor", "//xla:types", "//xla:util", "//xla:xla_data_proto_cc", @@ -245,6 +245,7 @@ cc_library( "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:errors", @@ -258,7 +259,6 @@ xla_test( ":constants", ":matrix", ":slicing", - "//xla:statusor", "//xla:test", "//xla:types", "//xla/client:xla_builder", @@ -266,6 +266,7 @@ xla_test( "//xla/tests:test_macros_header", "//xla/tests:xla_internal_test_main", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", ], ) @@ -303,10 +304,10 @@ cc_library( deps = [ ":constants", "//xla:shape_util", - "//xla:statusor", "//xla:util", "//xla:xla_data_proto_cc", "//xla/client:xla_builder", + "@com_google_absl//absl/status:statusor", ], ) @@ -317,7 +318,6 @@ xla_test( ":constants", ":prng", "//xla:shape_util", - "//xla:statusor", "//xla:test", "//xla:xla_data_proto_cc", "//xla/client:xla_builder", @@ -325,6 +325,7 @@ xla_test( "//xla/tests:test_macros_header", "//xla/tests:xla_internal_test_main", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", ], ) @@ -342,9 +343,9 @@ cc_library( "//xla:literal_util", "//xla:shape_util", "//xla:status_macros", - "//xla:statusor", "//xla:xla_data_proto_cc", "//xla/client:xla_builder", + "@com_google_absl//absl/status:statusor", "@local_tsl//tsl/platform:errors", ], ) @@ -360,7 +361,6 @@ xla_test( "//xla:array2d", "//xla:array3d", 
"//xla:literal", - "//xla:statusor", "//xla:test", "//xla:types", "//xla:xla_data_proto_cc", @@ -369,6 +369,7 @@ xla_test( "//xla/tests:literal_test_util", "//xla/tests:test_macros_header", "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/status:statusor", "@local_tsl//tsl/lib/core:status_test_util", ], ) @@ -379,10 +380,10 @@ cc_library( hdrs = ["lu_decomposition.h"], deps = [ "//xla:shape_util", - "//xla:statusor", "//xla:types", "//xla:xla_data_proto_cc", "//xla/client:xla_builder", + "@com_google_absl//absl/status:statusor", ], ) @@ -393,11 +394,11 @@ cc_library( deps = [ ":approx_topk_shape", "//xla:shape_util", - "//xla:statusor", "//xla:util", "//xla:xla_data_proto_cc", "//xla/client:xla_builder", "//xla/client:xla_computation", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", ], ) @@ -407,8 +408,8 @@ cc_library( srcs = ["approx_topk_shape.cc"], hdrs = ["approx_topk_shape.h"], deps = [ - "//xla:statusor", "//xla:util", + "@com_google_absl//absl/status:statusor", ], ) @@ -421,10 +422,10 @@ cc_library( ":constants", "//xla:shape_util", "//xla:status_macros", - "//xla:statusor", "//xla:types", "//xla:util", "//xla/client:xla_builder", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/types:span", ], ) @@ -518,7 +519,6 @@ cc_library( "//xla:execution_options_util", "//xla:literal", "//xla:shape_util", - "//xla:statusor", "//xla:types", "//xla:util", "//xla:xla_data_proto_cc", @@ -527,6 +527,7 @@ cc_library( "//xla/client:xla_builder", "//xla/client:xla_computation", "//xla/tests:test_utils", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@local_tsl//tsl/platform:protobuf", ], @@ -547,10 +548,10 @@ cc_library( "//xla:literal_util", "//xla:shape_util", "//xla:status_macros", - "//xla:statusor", "//xla:util", "//xla:xla_data_proto_cc", "//xla/client:xla_builder", + "@com_google_absl//absl/status:statusor", "@local_tsl//tsl/platform:errors", ], ) @@ -571,7 +572,6 
@@ xla_test( "//xla:array2d", "//xla:array3d", "//xla:literal", - "//xla:statusor", "//xla:test", "//xla:xla_data_proto_cc", "//xla/client:xla_builder", @@ -579,6 +579,7 @@ xla_test( "//xla/tests:literal_test_util", "//xla/tests:test_macros_header", "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/status:statusor", "@local_tsl//tsl/lib/core:status_test_util", ], ) @@ -598,9 +599,9 @@ cc_library( "//xla:literal_util", "//xla:shape_util", "//xla:status_macros", - "//xla:statusor", "//xla:xla_data_proto_cc", "//xla/client:xla_builder", + "@com_google_absl//absl/status:statusor", "@local_tsl//tsl/platform:errors", ], ) @@ -621,7 +622,6 @@ xla_test( "//xla:array3d", "//xla:literal", "//xla:shape_util", - "//xla:statusor", "//xla:test", "//xla:xla_data_proto_cc", "//xla/client:xla_builder", @@ -629,6 +629,7 @@ xla_test( "//xla/tests:literal_test_util", "//xla/tests:test_macros_header", "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/status:statusor", "@local_tsl//tsl/lib/core:status_test_util", ], ) @@ -643,10 +644,10 @@ cc_library( ":slicing", "//xla:shape_util", "//xla:status_macros", - "//xla:statusor", "//xla:xla_data_proto_cc", "//xla/client:xla_builder", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/types:span", ], ) @@ -686,8 +687,8 @@ cc_library( "//xla:literal_util", "//xla:shape_util", "//xla:status_macros", - "//xla:statusor", "//xla/client:xla_builder", + "@com_google_absl//absl/status:statusor", "@local_tsl//tsl/platform:errors", ], ) @@ -704,13 +705,13 @@ xla_test( "//xla:array2d", "//xla:array3d", "//xla:literal", - "//xla:statusor", "//xla:test", "//xla/client:xla_builder", "//xla/tests:client_library_test_base", "//xla/tests:literal_test_util", "//xla/tests:test_macros_header", "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/status:statusor", "@local_tsl//tsl/lib/core:status_test_util", ], ) @@ -722,9 +723,9 @@ cc_library( deps = [ "//xla:shape_tree", 
"//xla:shape_util", - "//xla:statusor", "//xla/client:xla_builder", "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/status:statusor", "@local_tsl//tsl/platform:statusor", ], ) diff --git a/third_party/xla/xla/client/lib/approx_topk.cc b/third_party/xla/xla/client/lib/approx_topk.cc index 40b9360976215f..13837dc3047b69 100644 --- a/third_party/xla/xla/client/lib/approx_topk.cc +++ b/third_party/xla/xla/client/lib/approx_topk.cc @@ -20,13 +20,13 @@ limitations under the License. #include #include +#include "absl/status/statusor.h" #include "absl/strings/str_format.h" #include "xla/client/lib/approx_topk_shape.h" #include "xla/client/xla_builder.h" #include "xla/client/xla_computation.h" #include "xla/shape.h" #include "xla/shape_util.h" -#include "xla/statusor.h" #include "xla/util.h" #include "xla/xla_data.pb.h" diff --git a/third_party/xla/xla/client/lib/approx_topk_shape.h b/third_party/xla/xla/client/lib/approx_topk_shape.h index de12c7fc0a3e1c..ef59a604adb7f2 100644 --- a/third_party/xla/xla/client/lib/approx_topk_shape.h +++ b/third_party/xla/xla/client/lib/approx_topk_shape.h @@ -18,7 +18,7 @@ limitations under the License. #include -#include "xla/statusor.h" +#include "absl/status/statusor.h" namespace xla { diff --git a/third_party/xla/xla/client/lib/arithmetic.cc b/third_party/xla/xla/client/lib/arithmetic.cc index e4f365f599f6d4..804face7f38517 100644 --- a/third_party/xla/xla/client/lib/arithmetic.cc +++ b/third_party/xla/xla/client/lib/arithmetic.cc @@ -21,6 +21,7 @@ limitations under the License. #include #include +#include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "xla/client/lib/constants.h" #include "xla/client/xla_builder.h" @@ -28,7 +29,6 @@ limitations under the License. 
#include "xla/primitive_util.h" #include "xla/shape.h" #include "xla/shape_util.h" -#include "xla/statusor.h" #include "xla/xla_data.pb.h" namespace xla { diff --git a/third_party/xla/xla/client/lib/conv_grad_size_util.cc b/third_party/xla/xla/client/lib/conv_grad_size_util.cc index 6d23c71a15f39f..9bfad0a446566c 100644 --- a/third_party/xla/xla/client/lib/conv_grad_size_util.cc +++ b/third_party/xla/xla/client/lib/conv_grad_size_util.cc @@ -19,6 +19,7 @@ limitations under the License. #include "xla/status_macros.h" #include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" namespace xla { diff --git a/third_party/xla/xla/client/lib/logdet.cc b/third_party/xla/xla/client/lib/logdet.cc index 7eaee01274e218..8c41f5d6cce0a7 100644 --- a/third_party/xla/xla/client/lib/logdet.cc +++ b/third_party/xla/xla/client/lib/logdet.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "absl/status/statusor.h" #include "xla/client/lib/arithmetic.h" #include "xla/client/lib/constants.h" #include "xla/client/lib/loops.h" @@ -30,7 +31,6 @@ limitations under the License. #include "xla/literal_util.h" #include "xla/shape_util.h" #include "xla/status_macros.h" -#include "xla/statusor.h" #include "tsl/platform/errors.h" namespace xla { diff --git a/third_party/xla/xla/client/lib/logdet_test.cc b/third_party/xla/xla/client/lib/logdet_test.cc index c9be45c5ed13f1..8c0a8b91d97b18 100644 --- a/third_party/xla/xla/client/lib/logdet_test.cc +++ b/third_party/xla/xla/client/lib/logdet_test.cc @@ -17,12 +17,12 @@ limitations under the License. 
#include +#include "absl/status/statusor.h" #include "xla/array2d.h" #include "xla/array3d.h" #include "xla/client/lib/matrix.h" #include "xla/client/xla_builder.h" #include "xla/literal.h" -#include "xla/statusor.h" #include "xla/test.h" #include "xla/tests/client_library_test_base.h" #include "xla/tests/literal_test_util.h" diff --git a/third_party/xla/xla/client/lib/loops.h b/third_party/xla/xla/client/lib/loops.h index 5cd7426992ceaf..18efcc15d270e4 100644 --- a/third_party/xla/xla/client/lib/loops.h +++ b/third_party/xla/xla/client/lib/loops.h @@ -19,11 +19,11 @@ limitations under the License. #include #include +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/client/xla_builder.h" #include "xla/client/xla_computation.h" -#include "xla/statusor.h" namespace xla { diff --git a/third_party/xla/xla/client/lib/lu_decomposition.cc b/third_party/xla/xla/client/lib/lu_decomposition.cc index cbce5be6b06c88..e7de77d9c5c527 100644 --- a/third_party/xla/xla/client/lib/lu_decomposition.cc +++ b/third_party/xla/xla/client/lib/lu_decomposition.cc @@ -18,8 +18,8 @@ limitations under the License. #include #include +#include "absl/status/statusor.h" #include "xla/shape_util.h" -#include "xla/statusor.h" #include "xla/types.h" namespace xla { diff --git a/third_party/xla/xla/client/lib/matrix.cc b/third_party/xla/xla/client/lib/matrix.cc index 5a27e2ccf922b1..e43f4df437e603 100644 --- a/third_party/xla/xla/client/lib/matrix.cc +++ b/third_party/xla/xla/client/lib/matrix.cc @@ -30,6 +30,7 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/container/inlined_vector.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/ascii.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_split.h" @@ -44,7 +45,6 @@ limitations under the License. 
#include "xla/shape.h" #include "xla/shape_util.h" #include "xla/status_macros.h" -#include "xla/statusor.h" #include "xla/util.h" #include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" diff --git a/third_party/xla/xla/client/lib/matrix.h b/third_party/xla/xla/client/lib/matrix.h index b1b18b1ae9fd82..df3a2e878d88a7 100644 --- a/third_party/xla/xla/client/lib/matrix.h +++ b/third_party/xla/xla/client/lib/matrix.h @@ -21,10 +21,10 @@ limitations under the License. #include #include +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/client/xla_builder.h" -#include "xla/statusor.h" #include "xla/types.h" #include "xla/xla_data.pb.h" diff --git a/third_party/xla/xla/client/lib/matrix_test.cc b/third_party/xla/xla/client/lib/matrix_test.cc index 73a8582292194f..fc6770a6eaf9ba 100644 --- a/third_party/xla/xla/client/lib/matrix_test.cc +++ b/third_party/xla/xla/client/lib/matrix_test.cc @@ -21,11 +21,11 @@ limitations under the License. #include #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "xla/client/lib/constants.h" #include "xla/client/lib/slicing.h" #include "xla/client/xla_builder.h" -#include "xla/statusor.h" #include "xla/test.h" #include "xla/tests/client_library_test_base.h" #include "xla/tests/test_macros.h" diff --git a/third_party/xla/xla/client/lib/prng.cc b/third_party/xla/xla/client/lib/prng.cc index def89a0d40b95c..92cc1ee35ff1f0 100644 --- a/third_party/xla/xla/client/lib/prng.cc +++ b/third_party/xla/xla/client/lib/prng.cc @@ -23,10 +23,10 @@ limitations under the License. 
#include #include +#include "absl/status/statusor.h" #include "xla/client/lib/constants.h" #include "xla/client/xla_builder.h" #include "xla/primitive_util.h" -#include "xla/statusor.h" #include "xla/util.h" namespace xla { diff --git a/third_party/xla/xla/client/lib/prng_test.cc b/third_party/xla/xla/client/lib/prng_test.cc index 424846cc6fc2fd..22241e9fab1da9 100644 --- a/third_party/xla/xla/client/lib/prng_test.cc +++ b/third_party/xla/xla/client/lib/prng_test.cc @@ -18,11 +18,11 @@ limitations under the License. #include #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "xla/client/lib/constants.h" #include "xla/client/xla_builder.h" #include "xla/primitive_util.h" #include "xla/shape.h" -#include "xla/statusor.h" #include "xla/test.h" #include "xla/tests/client_library_test_base.h" #include "xla/tests/test_macros.h" diff --git a/third_party/xla/xla/client/lib/qr.cc b/third_party/xla/xla/client/lib/qr.cc index 39631584fea623..dc0792cb86372b 100644 --- a/third_party/xla/xla/client/lib/qr.cc +++ b/third_party/xla/xla/client/lib/qr.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "absl/status/statusor.h" #include "xla/client/lib/arithmetic.h" #include "xla/client/lib/constants.h" #include "xla/client/lib/loops.h" @@ -29,7 +30,6 @@ limitations under the License. #include "xla/literal_util.h" #include "xla/shape_util.h" #include "xla/status_macros.h" -#include "xla/statusor.h" #include "tsl/platform/errors.h" namespace xla { diff --git a/third_party/xla/xla/client/lib/qr_test.cc b/third_party/xla/xla/client/lib/qr_test.cc index a21932b3e797e3..2bd613e587b475 100644 --- a/third_party/xla/xla/client/lib/qr_test.cc +++ b/third_party/xla/xla/client/lib/qr_test.cc @@ -15,13 +15,13 @@ limitations under the License. 
#include "xla/client/lib/qr.h" +#include "absl/status/statusor.h" #include "xla/array2d.h" #include "xla/array3d.h" #include "xla/client/lib/constants.h" #include "xla/client/lib/matrix.h" #include "xla/client/xla_builder.h" #include "xla/literal.h" -#include "xla/statusor.h" #include "xla/test.h" #include "xla/tests/client_library_test_base.h" #include "xla/tests/literal_test_util.h" diff --git a/third_party/xla/xla/client/lib/self_adjoint_eig.cc b/third_party/xla/xla/client/lib/self_adjoint_eig.cc index 606f968c6648cc..241ca6528fc614 100644 --- a/third_party/xla/xla/client/lib/self_adjoint_eig.cc +++ b/third_party/xla/xla/client/lib/self_adjoint_eig.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "absl/status/statusor.h" #include "xla/client/lib/arithmetic.h" #include "xla/client/lib/comparators.h" #include "xla/client/lib/constants.h" @@ -31,7 +32,6 @@ limitations under the License. #include "xla/primitive_util.h" #include "xla/shape_util.h" #include "xla/status_macros.h" -#include "xla/statusor.h" #include "xla/util.h" #include "tsl/platform/errors.h" diff --git a/third_party/xla/xla/client/lib/self_adjoint_eig_test.cc b/third_party/xla/xla/client/lib/self_adjoint_eig_test.cc index 7be635f2a9a796..9937eee4d40285 100644 --- a/third_party/xla/xla/client/lib/self_adjoint_eig_test.cc +++ b/third_party/xla/xla/client/lib/self_adjoint_eig_test.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "absl/status/statusor.h" #include "xla/array.h" #include "xla/array2d.h" #include "xla/array3d.h" @@ -28,7 +29,6 @@ limitations under the License. 
#include "xla/client/lib/matrix.h" #include "xla/client/xla_builder.h" #include "xla/literal.h" -#include "xla/statusor.h" #include "xla/test.h" #include "xla/tests/client_library_test_base.h" #include "xla/tests/literal_test_util.h" diff --git a/third_party/xla/xla/client/lib/slicing.cc b/third_party/xla/xla/client/lib/slicing.cc index 881181b86c317a..0a8e74b60e18ae 100644 --- a/third_party/xla/xla/client/lib/slicing.cc +++ b/third_party/xla/xla/client/lib/slicing.cc @@ -21,6 +21,7 @@ limitations under the License. #include #include +#include "absl/status/statusor.h" #include "absl/types/span.h" #include "xla/client/lib/arithmetic.h" #include "xla/client/lib/constants.h" @@ -28,7 +29,6 @@ limitations under the License. #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/status_macros.h" -#include "xla/statusor.h" #include "xla/util.h" namespace xla { diff --git a/third_party/xla/xla/client/lib/svd.cc b/third_party/xla/xla/client/lib/svd.cc index 2304ec3b14c253..23ab459b6c3a52 100644 --- a/third_party/xla/xla/client/lib/svd.cc +++ b/third_party/xla/xla/client/lib/svd.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "absl/status/statusor.h" #include "xla/client/lib/arithmetic.h" #include "xla/client/lib/comparators.h" #include "xla/client/lib/constants.h" @@ -30,7 +31,6 @@ limitations under the License. #include "xla/literal_util.h" #include "xla/shape_util.h" #include "xla/status_macros.h" -#include "xla/statusor.h" #include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" diff --git a/third_party/xla/xla/client/lib/svd_test.cc b/third_party/xla/xla/client/lib/svd_test.cc index f27d78974e1dd2..4a2437b6fbeddb 100644 --- a/third_party/xla/xla/client/lib/svd_test.cc +++ b/third_party/xla/xla/client/lib/svd_test.cc @@ -19,6 +19,7 @@ limitations under the License. 
#include #include +#include "absl/status/statusor.h" #include "xla/array2d.h" #include "xla/array3d.h" #include "xla/client/lib/arithmetic.h" @@ -28,7 +29,6 @@ limitations under the License. #include "xla/client/xla_builder.h" #include "xla/literal.h" #include "xla/shape_util.h" -#include "xla/statusor.h" #include "xla/test.h" #include "xla/tests/client_library_test_base.h" #include "xla/tests/literal_test_util.h" diff --git a/third_party/xla/xla/client/lib/testing.cc b/third_party/xla/xla/client/lib/testing.cc index 0461108bb24960..8aab8351f41115 100644 --- a/third_party/xla/xla/client/lib/testing.cc +++ b/third_party/xla/xla/client/lib/testing.cc @@ -18,12 +18,12 @@ limitations under the License. #include #include +#include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "xla/client/xla_builder.h" #include "xla/execution_options_util.h" #include "xla/literal.h" #include "xla/shape_util.h" -#include "xla/statusor.h" #include "xla/tests/test_utils.h" #include "xla/types.h" #include "xla/util.h" diff --git a/third_party/xla/xla/client/lib/tridiagonal.cc b/third_party/xla/xla/client/lib/tridiagonal.cc index 4d6be1d176bc99..128a9810e7c437 100644 --- a/third_party/xla/xla/client/lib/tridiagonal.cc +++ b/third_party/xla/xla/client/lib/tridiagonal.cc @@ -22,6 +22,7 @@ limitations under the License. #include #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/types/span.h" #include "xla/client/lib/constants.h" #include "xla/client/lib/loops.h" @@ -29,7 +30,6 @@ limitations under the License. 
#include "xla/client/xla_builder.h" #include "xla/shape_util.h" #include "xla/status_macros.h" -#include "xla/statusor.h" namespace xla { namespace tridiagonal { diff --git a/third_party/xla/xla/client/lib/tuple.cc b/third_party/xla/xla/client/lib/tuple.cc index ae116fba3ae723..4cefa748bc8d04 100644 --- a/third_party/xla/xla/client/lib/tuple.cc +++ b/third_party/xla/xla/client/lib/tuple.cc @@ -18,11 +18,11 @@ limitations under the License. #include #include "absl/container/inlined_vector.h" +#include "absl/status/statusor.h" #include "xla/client/xla_builder.h" #include "xla/shape.h" #include "xla/shape_tree.h" #include "xla/shape_util.h" -#include "xla/statusor.h" #include "tsl/platform/statusor.h" namespace xla { diff --git a/third_party/xla/xla/client/lib/tuple.h b/third_party/xla/xla/client/lib/tuple.h index 56c8a3f6b99e08..dd8fb3c6ec82bf 100644 --- a/third_party/xla/xla/client/lib/tuple.h +++ b/third_party/xla/xla/client/lib/tuple.h @@ -16,9 +16,9 @@ limitations under the License. #ifndef XLA_CLIENT_LIB_TUPLE_H_ #define XLA_CLIENT_LIB_TUPLE_H_ +#include "absl/status/statusor.h" #include "xla/client/xla_builder.h" #include "xla/shape_tree.h" -#include "xla/statusor.h" namespace xla { From 1dfff9d609ca185733b70a5b26da95f4ab01071a Mon Sep 17 00:00:00 2001 From: Jan Date: Thu, 20 Jun 2024 11:49:08 -0700 Subject: [PATCH 072/256] PR #13985: Removing spurious `option go_package` from autotuning.proto Imported from GitHub PR https://github.com/openxla/xla/pull/13985 Generally, **bazel** uses the `M` option for the Go proto generator -- see https://protobuf.dev/reference/go/go-generated/#invocation Unfortunately, hard-coding `option go_package` here prevents other libraries from using the `.proto` file by compiling the Go bindings in their own package structure. Removing it makes it friendlier for downstream Go users of the XLA library. 
Plus I don't see it used anywhere in the library, so better remove things not used -- and it's not defined in any of the other `.proto` files. Copybara import of the project: -- b6e20a245771f6acd7e966c0318f8321f69f1546 by Jan : Removing spurious go_package autotuning.proto Generally, **bazel** uses the `M` option for the Go proto generator -- see https://protobuf.dev/reference/go/go-generated/#invocation Unfortunately, hard-coding `option go_package` here prevents other libraries from using the `.proto` file by compiling the Go bindings in their own package structure. Removing it makes it friendlier for downstream Go users of the XLA library. Plus I don't see it used anywhere in the library, so better remove things not used -- and it's not defined in any of the other `.proto` files. Merging this change closes #13985 PiperOrigin-RevId: 645108212 --- third_party/xla/xla/autotuning.proto | 2 -- 1 file changed, 2 deletions(-) diff --git a/third_party/xla/xla/autotuning.proto b/third_party/xla/xla/autotuning.proto index 9d6de133f6a93f..a7ffcbb57ae6ef 100644 --- a/third_party/xla/xla/autotuning.proto +++ b/third_party/xla/xla/autotuning.proto @@ -11,8 +11,6 @@ import "google/protobuf/any.proto"; import "google/protobuf/duration.proto"; import "tsl/protobuf/dnn.proto"; -option go_package = "github.com/google/tsl/tsl/go/protobuf/for_core_protos_go_proto"; - message CudnnVersion { int32 major = 1; int32 minor = 2; From 444d8bdfaeee240843181cb33127f02204c9f15f Mon Sep 17 00:00:00 2001 From: Laura Pak Date: Thu, 20 Jun 2024 12:10:02 -0700 Subject: [PATCH 073/256] Return absl::Status not TFLiteStatus from ::tflite::optimize::QuantizeWeights. 
PiperOrigin-RevId: 645115004 --- tensorflow/compiler/mlir/lite/BUILD | 1 - .../mlir/lite/tf_to_tfl_flatbuffer.cc | 11 +-- tensorflow/lite/toco/tflite/export.cc | 9 +- tensorflow/lite/tools/optimize/BUILD | 2 + .../lite/tools/optimize/quantize_weights.cc | 87 ++++++++++++------- .../lite/tools/optimize/quantize_weights.h | 11 +-- .../optimize/quantize_weights_portable.cc | 69 +++++++++------ .../tools/optimize/quantize_weights_test.cc | 80 ++++++++--------- 8 files changed, 154 insertions(+), 116 deletions(-) diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD index 32a5bc738faa1a..940cb0c47ce319 100644 --- a/tensorflow/compiler/mlir/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/BUILD @@ -1442,7 +1442,6 @@ cc_library( "//tensorflow/core:core_cpu_base", "//tensorflow/core:framework", "//tensorflow/core:protos_all_cc", - "//tensorflow/lite/c:c_api_types", "//tensorflow/lite/experimental/remat:metadata_util", "//tensorflow/lite/python/metrics:converter_error_data_proto_cc", "//tensorflow/lite/toco:toco_flags_proto_cc", diff --git a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc index 227de88bb8510b..19f4de9a0395b2 100644 --- a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc +++ b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc @@ -78,7 +78,6 @@ limitations under the License. 
#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/protobuf/meta_graph.pb.h" #include "tensorflow/core/public/session.h" -#include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/experimental/remat/metadata_util.h" #include "tensorflow/lite/python/metrics/converter_error_data.pb.h" #include "tensorflow/lite/tools/optimize/quantize_weights.h" @@ -287,12 +286,10 @@ absl::Status ApplyDynamicRangeQuantizationFromOldQuantizer( } bool use_updated_hybrid_scheme = !quant_specs.disable_per_channel; - if (::tflite::optimize::QuantizeWeights( - &q_builder, input_model, quantized_type, use_updated_hybrid_scheme, - ::tflite::optimize::QuantizerType::OLD_QUANTIZER) != kTfLiteOk) { - return absl::InvalidArgumentError( - "Quantize weights transformation failed."); - } + absl::Status quantize_weights_status = ::tflite::optimize::QuantizeWeights( + &q_builder, input_model, quantized_type, use_updated_hybrid_scheme, + ::tflite::optimize::QuantizerType::OLD_QUANTIZER); + if (!quantize_weights_status.ok()) return quantize_weights_status; const uint8_t* q_buffer = q_builder.GetBufferPointer(); *result = std::string(reinterpret_cast(q_buffer), q_builder.GetSize()); diff --git a/tensorflow/lite/toco/tflite/export.cc b/tensorflow/lite/toco/tflite/export.cc index ec4fcee3c7e23a..74cdeb8c2a0220 100644 --- a/tensorflow/lite/toco/tflite/export.cc +++ b/tensorflow/lite/toco/tflite/export.cc @@ -680,10 +680,11 @@ tensorflow::Status Export( return tensorflow::errors::InvalidArgument( "Quantized type not recognized"); } - if (::tflite::optimize::QuantizeWeights( - &q_builder, input_model, quantized_type, - !params.disable_per_channel, - ::tflite::optimize::QuantizerType::OLD_QUANTIZER) != kTfLiteOk) { + if (!::tflite::optimize::QuantizeWeights( + &q_builder, input_model, quantized_type, + !params.disable_per_channel, + ::tflite::optimize::QuantizerType::OLD_QUANTIZER) + .ok()) { return tensorflow::errors::InvalidArgument( "Quantize weights transformation 
failed."); } diff --git a/tensorflow/lite/tools/optimize/BUILD b/tensorflow/lite/tools/optimize/BUILD index dbdddcedecf06f..711f97bdfddd16 100644 --- a/tensorflow/lite/tools/optimize/BUILD +++ b/tensorflow/lite/tools/optimize/BUILD @@ -290,11 +290,13 @@ cc_library( ":quantization_utils", ":model_utils", "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@flatbuffers", "//tensorflow/lite:framework", + "//tensorflow/lite/c:c_api_types", "//tensorflow/lite/core:framework", # TODO(suharshs): Move the relevant quantization utils to a non-internal location. "//tensorflow/lite/kernels/internal:tensor_utils", diff --git a/tensorflow/lite/tools/optimize/quantize_weights.cc b/tensorflow/lite/tools/optimize/quantize_weights.cc index a42662284c8932..93a65ecc6bfa11 100644 --- a/tensorflow/lite/tools/optimize/quantize_weights.cc +++ b/tensorflow/lite/tools/optimize/quantize_weights.cc @@ -23,8 +23,10 @@ limitations under the License. #include "flatbuffers/flexbuffers.h" #include "absl/container/flat_hash_map.h" #include "absl/memory/memory.h" +#include "absl/status/status.h" #include "tensorflow/compiler/mlir/lite/quantization/lite/quantize_weights.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/context.h" #include "tensorflow/lite/core/model.h" #include "tensorflow/lite/kernels/internal/tensor_utils.h" @@ -207,7 +209,7 @@ bool CheckAllOpInputsQuantized(const SubGraphT* subgraph, const OperatorT* op, // Inserts Tensors for each input tensor of op that should be // quantized into tensor_map. 
-TfLiteStatus InsertQuantizableInputTensorsFromOperator( +absl::Status InsertQuantizableInputTensorsFromOperator( const ModelT* model, OperatorT* op, uint64_t weights_min_num_elements, const CustomOpMap& custom_op_map, absl::flat_hash_map* tensor_map, @@ -234,7 +236,9 @@ TfLiteStatus InsertQuantizableInputTensorsFromOperator( } uint64_t num_elements; - TF_LITE_ENSURE_STATUS(utils::NumElements(*tensor, &num_elements)); + if (utils::NumElements(*tensor, &num_elements) != kTfLiteOk) { + return absl::InternalError("Error in quantization_utils NumElements"); + } if (num_elements < weights_min_num_elements) { LOG(INFO) << "Skipping quantization of tensor " << tensor->name << " because it has fewer than " << weights_min_num_elements @@ -303,7 +307,7 @@ TfLiteStatus InsertQuantizableInputTensorsFromOperator( } } - return kTfLiteOk; + return absl::OkStatus(); } // Updates operator code versions for the operators with INT8 inputs. @@ -392,7 +396,7 @@ inline bool IsOpDenylisted(const flat_hash_set& op_denylist, return op_denylist.find(op_code) != op_denylist.end(); } -TfLiteStatus QuantizeWeightsInt8( +absl::Status QuantizeWeightsInt8( flatbuffers::FlatBufferBuilder* builder, const Model* input_model, bool use_hybrid_evaluation, uint64_t weights_min_num_elements, const CustomOpMap& custom_op_map, bool use_updated_hybrid_scheme, @@ -407,20 +411,26 @@ TfLiteStatus QuantizeWeightsInt8( absl::flat_hash_map tensor_map; for (int i = 0; i < subgraph->operators.size(); ++i) { OperatorT* op = subgraph->operators[i].get(); - TF_LITE_ENSURE_STATUS(InsertQuantizableInputTensorsFromOperator( + absl::Status status = InsertQuantizableInputTensorsFromOperator( model.get(), op, weights_min_num_elements, custom_op_map, &tensor_map, - subgraph_index, use_updated_hybrid_scheme)); + subgraph_index, use_updated_hybrid_scheme); + if (!status.ok()) return status; } for (std::pair tensor_pair : tensor_map) { // Quantize the tensor. 
if (tensor_pair.second.is_per_channel) { - TF_LITE_ENSURE_STATUS(utils::SymmetricQuantizeTensorPerChannel( - model.get(), tensor_pair.second.t, tensor_pair.second.channel_dim, - nullptr)); + if (utils::SymmetricQuantizeTensorPerChannel( + model.get(), tensor_pair.second.t, + tensor_pair.second.channel_dim, nullptr) != kTfLiteOk) { + return absl::InternalError( + "SymmetricQuantizeTensorPerChannel failed"); + } } else { - TF_LITE_ENSURE_STATUS( - utils::SymmetricQuantizeTensor(model.get(), tensor_pair.second.t)); + if (utils::SymmetricQuantizeTensor(model.get(), tensor_pair.second.t) != + kTfLiteOk) { + return absl::InternalError("SymmetricQuantizeTensor failed"); + } } } @@ -436,7 +446,7 @@ TfLiteStatus QuantizeWeightsInt8( consumer_op_infos, custom_op_map); if (tensor_idx < 0) { // Error message is already logged by PassQuantizationAndGetConsumers. - return kTfLiteError; + return absl::InternalError("PassQuantizationAndGetConsumers failed"); } } @@ -517,10 +527,10 @@ TfLiteStatus QuantizeWeightsInt8( Model::Pack(*builder, model.get()); FinishModelBuffer(*builder, output_model_location); - return kTfLiteOk; + return absl::OkStatus(); } -TfLiteStatus QuantizeWeightsFloat16(flatbuffers::FlatBufferBuilder* builder, +absl::Status QuantizeWeightsFloat16(flatbuffers::FlatBufferBuilder* builder, const Model* input_model) { std::unique_ptr model; model.reset(input_model->UnPack()); @@ -540,7 +550,7 @@ TfLiteStatus QuantizeWeightsFloat16(flatbuffers::FlatBufferBuilder* builder, TensorT* tensor = subgraph->tensors[tensor_idx].get(); BufferT* buffer = model->buffers[tensor->buffer].get(); if (buffer == nullptr) { - return kTfLiteError; + return absl::InternalError("Buffer is null"); } // Quantize tensors that have data to quantize. bool is_constant = !model->buffers[tensor->buffer].get()->data.empty(); @@ -553,8 +563,10 @@ TfLiteStatus QuantizeWeightsFloat16(flatbuffers::FlatBufferBuilder* builder, // The hash map ensures that we quantize each tensor exactly once. 
for (std::pair tensor_pair : tensor_map) { // Quantize the tensor. - TF_LITE_ENSURE_STATUS( - utils::QuantizeTensorFloat16(model.get(), tensor_pair.second)); + if (utils::QuantizeTensorFloat16(model.get(), tensor_pair.second) != + kTfLiteOk) { + return absl::InternalError("QuantizeTensorFloat16 failed"); + } int32_t tensor_idx = tensor_pair.first; TensorT* tensor = tensor_pair.second; @@ -593,12 +605,12 @@ TfLiteStatus QuantizeWeightsFloat16(flatbuffers::FlatBufferBuilder* builder, flatbuffers::Offset output_model_location = Model::Pack(*builder, model.get()); FinishModelBuffer(*builder, output_model_location); - return kTfLiteOk; + return absl::OkStatus(); } } // namespace namespace internal { -TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder, +absl::Status QuantizeWeights(flatbuffers::FlatBufferBuilder* builder, const Model* input_model, uint64_t weights_min_num_elements, bool use_hybrid_evaluation, @@ -606,8 +618,11 @@ TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder, // By default we require that only weights with more than // kWeightsMinSizeDefault elements are quantized. if (quantizer_type == QuantizerType::MLIR_QUANTIZER) { - return mlir::lite::QuantizeWeights( - builder, input_model, weights_min_num_elements, use_hybrid_evaluation); + return mlir::lite::QuantizeWeights(builder, input_model, + weights_min_num_elements, + use_hybrid_evaluation) == kTfLiteOk + ? 
absl::OkStatus() + : absl::InternalError("QuantizeWeights failed"); } CustomOpMap custom_op_map; return QuantizeWeightsInt8(builder, input_model, use_hybrid_evaluation, @@ -616,13 +631,15 @@ TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder, } } // namespace internal -TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder, +absl::Status QuantizeWeights(flatbuffers::FlatBufferBuilder* builder, const Model* input_model, uint64_t weights_min_num_elements, QuantizerType quantizer_type) { if (quantizer_type == QuantizerType::MLIR_QUANTIZER) { return mlir::lite::QuantizeWeights(builder, input_model, - weights_min_num_elements); + weights_min_num_elements) == kTfLiteOk + ? absl::OkStatus() + : absl::InternalError("QuantizeWeights failed"); } CustomOpMap custom_op_map; return QuantizeWeightsInt8(builder, input_model, true, @@ -630,7 +647,7 @@ TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder, kUseUpdatedHybridSchemeDefault); } -TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder, +absl::Status QuantizeWeights(flatbuffers::FlatBufferBuilder* builder, const Model* input_model, BufferType quant_type, bool use_updated_hybrid_scheme, QuantizerType quantizer_type) { @@ -639,7 +656,9 @@ TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder, if (quantizer_type == QuantizerType::MLIR_QUANTIZER) { return mlir::lite::QuantizeWeights(builder, input_model, (mlir::lite::BufferType)quant_type, - use_updated_hybrid_scheme); + use_updated_hybrid_scheme) == kTfLiteOk + ? 
absl::OkStatus() + : absl::InternalError("QuantizeWeights failed"); } switch (quant_type) { case BufferType::QUANTIZED_INT8: { @@ -653,7 +672,7 @@ TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder, } } -TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder, +absl::Status QuantizeWeights(flatbuffers::FlatBufferBuilder* builder, const Model* input_model, uint64_t weights_min_num_elements, const CustomOpMap& custom_op_map, @@ -661,15 +680,18 @@ TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder, if (quantizer_type == QuantizerType::MLIR_QUANTIZER) { mlir::lite::CustomOpMap mlir_custom_op_map; ConstructMLIRCustomOpMap(mlir_custom_op_map, custom_op_map); - return mlir::lite::QuantizeWeights( - builder, input_model, weights_min_num_elements, mlir_custom_op_map); + return mlir::lite::QuantizeWeights(builder, input_model, + weights_min_num_elements, + mlir_custom_op_map) == kTfLiteOk + ? absl::OkStatus() + : absl::InternalError("QuantizeWeights failed"); } return QuantizeWeightsInt8(builder, input_model, true, weights_min_num_elements, custom_op_map, kUseUpdatedHybridSchemeDefault); } -TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder, +absl::Status QuantizeWeights(flatbuffers::FlatBufferBuilder* builder, const Model* input_model, uint64_t weights_min_num_elements, const CustomOpMap& custom_op_map, @@ -680,8 +702,11 @@ TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder, mlir::lite::CustomOpMap mlir_custom_op_map; ConstructMLIRCustomOpMap(mlir_custom_op_map, custom_op_map); return mlir::lite::QuantizeWeights( - builder, input_model, weights_min_num_elements, mlir_custom_op_map, - use_updated_hybrid_scheme, op_denylist); + builder, input_model, weights_min_num_elements, + mlir_custom_op_map, use_updated_hybrid_scheme, + op_denylist) == kTfLiteOk + ? 
absl::OkStatus() + : absl::InternalError("QuantizeWeights failed"); } return QuantizeWeightsInt8(builder, input_model, /*use_hybrid_evaluation=*/true, diff --git a/tensorflow/lite/tools/optimize/quantize_weights.h b/tensorflow/lite/tools/optimize/quantize_weights.h index 2b589584acdbec..d2f22e2b4f0ff0 100644 --- a/tensorflow/lite/tools/optimize/quantize_weights.h +++ b/tensorflow/lite/tools/optimize/quantize_weights.h @@ -22,6 +22,7 @@ limitations under the License. #include "flatbuffers/flexbuffers.h" #include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" #include "tensorflow/lite/context.h" #include "tensorflow/lite/core/model.h" #include "tensorflow/lite/schema/schema_generated.h" @@ -59,7 +60,7 @@ constexpr bool kUseUpdatedHybridSchemeDefault = true; // A tflite::Model can be obtained from the builder with: // const uint8_t* buffer = builder->GetBufferPointer(); // tflite::Model* model = GetModel(buffer); -TfLiteStatus QuantizeWeights( +absl::Status QuantizeWeights( flatbuffers::FlatBufferBuilder* builder, const Model* input_model, BufferType quant_type = BufferType::QUANTIZED_INT8, bool use_updated_hybrid_scheme = kUseUpdatedHybridSchemeDefault, @@ -67,13 +68,13 @@ TfLiteStatus QuantizeWeights( // Same as above, but only weights with greater than or equal // weights_min_num_elements elements will be quantized. -TfLiteStatus QuantizeWeights( +absl::Status QuantizeWeights( flatbuffers::FlatBufferBuilder* builder, const Model* input_model, uint64_t weights_min_num_elements, QuantizerType quantizer_type = QuantizerType::OLD_QUANTIZER); // Same as above, but with entry point of quantizing custom ops. 
-TfLiteStatus QuantizeWeights( +absl::Status QuantizeWeights( flatbuffers::FlatBufferBuilder* builder, const Model* input_model, uint64_t weights_min_num_elements, const CustomOpMap& custom_op_map, QuantizerType quantizer_type = QuantizerType::OLD_QUANTIZER); @@ -81,7 +82,7 @@ TfLiteStatus QuantizeWeights( // Same as above, but if use updated_hybrid_scheme is false, // use previous quantization scheme. Optional op_denylist argument // disables hybrid evaluation for provided BuiltinOperators. -TfLiteStatus QuantizeWeights( +absl::Status QuantizeWeights( flatbuffers::FlatBufferBuilder* builder, const Model* input_model, uint64_t weights_min_num_elements, const CustomOpMap& custom_op_map, bool use_updated_hybrid_scheme, @@ -94,7 +95,7 @@ namespace internal { // // We use this internal QuantizeWeights call to test models with hybrid // evaluation disabled. -TfLiteStatus QuantizeWeights( +absl::Status QuantizeWeights( flatbuffers::FlatBufferBuilder* builder, const Model* input_model, uint64_t weights_min_num_elements, bool use_hybrid_evaluation, QuantizerType quantizer_type = QuantizerType::OLD_QUANTIZER); diff --git a/tensorflow/lite/tools/optimize/quantize_weights_portable.cc b/tensorflow/lite/tools/optimize/quantize_weights_portable.cc index 5a77092e1c2e4b..d85d2292217cc0 100644 --- a/tensorflow/lite/tools/optimize/quantize_weights_portable.cc +++ b/tensorflow/lite/tools/optimize/quantize_weights_portable.cc @@ -24,6 +24,7 @@ limitations under the License. 
#include "flatbuffers/flexbuffers.h" #include "absl/container/flat_hash_map.h" #include "absl/memory/memory.h" +#include "absl/status/status.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/lite/context.h" #include "tensorflow/lite/core/model.h" @@ -380,7 +381,7 @@ inline bool IsOpDenylisted(const flat_hash_set& op_denylist, return op_denylist.find(op_code) != op_denylist.end(); } -TfLiteStatus QuantizeWeightsInt8( +absl::Status QuantizeWeightsInt8( flatbuffers::FlatBufferBuilder* builder, const Model* input_model, bool use_hybrid_evaluation, uint64_t weights_min_num_elements, const CustomOpMap& custom_op_map, bool use_updated_hybrid_scheme, @@ -395,20 +396,28 @@ TfLiteStatus QuantizeWeightsInt8( absl::flat_hash_map tensor_map; for (int i = 0; i < subgraph->operators.size(); ++i) { OperatorT* op = subgraph->operators[i].get(); - TF_LITE_ENSURE_STATUS(InsertQuantizableInputTensorsFromOperator( - model.get(), op, weights_min_num_elements, custom_op_map, &tensor_map, - subgraph_index, use_updated_hybrid_scheme)); + if (InsertQuantizableInputTensorsFromOperator( + model.get(), op, weights_min_num_elements, custom_op_map, + &tensor_map, subgraph_index, + use_updated_hybrid_scheme) != kTfLiteOk) { + return absl::InternalError( + "Failed to insert quantizable input tensors from operator"); + } } for (std::pair tensor_pair : tensor_map) { // Quantize the tensor. 
if (tensor_pair.second.is_per_channel) { - TF_LITE_ENSURE_STATUS(utils::SymmetricQuantizeTensorPerChannel( - model.get(), tensor_pair.second.t, tensor_pair.second.channel_dim, - nullptr)); + if (utils::SymmetricQuantizeTensorPerChannel( + model.get(), tensor_pair.second.t, + tensor_pair.second.channel_dim, nullptr) != kTfLiteOk) { + return absl::InternalError("Failed to quantize tensor per channel"); + } } else { - TF_LITE_ENSURE_STATUS( - utils::SymmetricQuantizeTensor(model.get(), tensor_pair.second.t)); + if (utils::SymmetricQuantizeTensor(model.get(), tensor_pair.second.t) != + kTfLiteOk) { + return absl::InternalError("Failed to quantize tensor"); + } } } @@ -424,7 +433,8 @@ TfLiteStatus QuantizeWeightsInt8( consumer_op_infos, custom_op_map); if (tensor_idx < 0) { // Error message is already logged by PassQuantizationAndGetConsumers. - return kTfLiteError; + return absl::InternalError( + "Failed to pass quantization and get consumers"); } } @@ -505,10 +515,10 @@ TfLiteStatus QuantizeWeightsInt8( Model::Pack(*builder, model.get()); FinishModelBuffer(*builder, output_model_location); - return kTfLiteOk; + return absl::OkStatus(); } -TfLiteStatus QuantizeWeightsFloat16(flatbuffers::FlatBufferBuilder* builder, +absl::Status QuantizeWeightsFloat16(flatbuffers::FlatBufferBuilder* builder, const Model* input_model) { std::unique_ptr model; model.reset(input_model->UnPack()); @@ -528,7 +538,7 @@ TfLiteStatus QuantizeWeightsFloat16(flatbuffers::FlatBufferBuilder* builder, TensorT* tensor = subgraph->tensors[tensor_idx].get(); BufferT* buffer = model->buffers[tensor->buffer].get(); if (buffer == nullptr) { - return kTfLiteError; + return absl::InternalError("Buffer is null"); } // Quantize tensors that have data to quantize. bool is_constant = !model->buffers[tensor->buffer].get()->data.empty(); @@ -541,8 +551,10 @@ TfLiteStatus QuantizeWeightsFloat16(flatbuffers::FlatBufferBuilder* builder, // The hash map ensures that we quantize each tensor exactly once. 
for (std::pair tensor_pair : tensor_map) { // Quantize the tensor. - TF_LITE_ENSURE_STATUS( - utils::QuantizeTensorFloat16(model.get(), tensor_pair.second)); + if (utils::QuantizeTensorFloat16(model.get(), tensor_pair.second) != + kTfLiteOk) { + return absl::InternalError("QuantizeTensorFloat16 failed"); + } int32_t tensor_idx = tensor_pair.first; TensorT* tensor = tensor_pair.second; @@ -581,19 +593,20 @@ TfLiteStatus QuantizeWeightsFloat16(flatbuffers::FlatBufferBuilder* builder, flatbuffers::Offset output_model_location = Model::Pack(*builder, model.get()); FinishModelBuffer(*builder, output_model_location); - return kTfLiteOk; + return absl::OkStatus(); } } // namespace namespace internal { -TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder, +absl::Status QuantizeWeights(flatbuffers::FlatBufferBuilder* builder, const Model* input_model, uint64_t weights_min_num_elements, bool use_hybrid_evaluation, QuantizerType quantizer_type) { if (quantizer_type == QuantizerType::MLIR_QUANTIZER) { LOG(ERROR) << "Portable targets cannot use the MLIR quantizer."; - return kTfLiteError; + return absl::InternalError( + "Portable targets cannot use the MLIR quantizer."); } // By default we require that only weights with more than // kWeightsMinSizeDefault elements are quantized. 
@@ -604,13 +617,14 @@ TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder, } } // namespace internal -TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder, +absl::Status QuantizeWeights(flatbuffers::FlatBufferBuilder* builder, const Model* input_model, uint64_t weights_min_num_elements, QuantizerType quantizer_type) { if (quantizer_type == QuantizerType::MLIR_QUANTIZER) { LOG(ERROR) << "Portable targets cannot use the MLIR quantizer."; - return kTfLiteError; + return absl::InternalError( + "Portable targets cannot use the MLIR quantizer."); } CustomOpMap custom_op_map; return QuantizeWeightsInt8(builder, input_model, true, @@ -618,13 +632,14 @@ TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder, kUseUpdatedHybridSchemeDefault); } -TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder, +absl::Status QuantizeWeights(flatbuffers::FlatBufferBuilder* builder, const Model* input_model, BufferType quant_type, bool use_updated_hybrid_scheme, QuantizerType quantizer_type) { if (quantizer_type == QuantizerType::MLIR_QUANTIZER) { LOG(ERROR) << "Portable targets cannot use the MLIR quantizer."; - return kTfLiteError; + return absl::InternalError( + "Portable targets cannot use the MLIR quantizer."); } switch (quant_type) { case BufferType::QUANTIZED_INT8: { @@ -640,21 +655,22 @@ TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder, } } -TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder, +absl::Status QuantizeWeights(flatbuffers::FlatBufferBuilder* builder, const Model* input_model, uint64_t weights_min_num_elements, const CustomOpMap& custom_op_map, QuantizerType quantizer_type) { if (quantizer_type == QuantizerType::MLIR_QUANTIZER) { LOG(ERROR) << "Portable targets cannot use the MLIR quantizer."; - return kTfLiteError; + return absl::InternalError( + "Portable targets cannot use the MLIR quantizer."); } return QuantizeWeightsInt8(builder, input_model, true, 
weights_min_num_elements, custom_op_map, kUseUpdatedHybridSchemeDefault); } -TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder, +absl::Status QuantizeWeights(flatbuffers::FlatBufferBuilder* builder, const Model* input_model, uint64_t weights_min_num_elements, const CustomOpMap& custom_op_map, @@ -663,7 +679,8 @@ TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder, QuantizerType quantizer_type) { if (quantizer_type == QuantizerType::MLIR_QUANTIZER) { LOG(ERROR) << "Portable targets cannot use the MLIR quantizer."; - return kTfLiteError; + return absl::InternalError( + "Portable targets cannot use the MLIR quantizer."); } return QuantizeWeightsInt8(builder, input_model, /*use_hybrid_evaluation=*/true, diff --git a/tensorflow/lite/tools/optimize/quantize_weights_test.cc b/tensorflow/lite/tools/optimize/quantize_weights_test.cc index 8f7a1150bef722..0e9c3efc17acd9 100644 --- a/tensorflow/lite/tools/optimize/quantize_weights_test.cc +++ b/tensorflow/lite/tools/optimize/quantize_weights_test.cc @@ -134,9 +134,8 @@ class QuantizeWeightsTest : public testing::Test { TEST_F(QuantizeWeightsTest, QuantizationSucceeds) { LoadBasicModel(); flatbuffers::FlatBufferBuilder builder; - auto status = - QuantizeWeights(&builder, model_, 0, QuantizerType::OLD_QUANTIZER); - EXPECT_EQ(status, kTfLiteOk); + ASSERT_TRUE( + QuantizeWeights(&builder, model_, 0, QuantizerType::OLD_QUANTIZER).ok()); const uint8_t* buffer = builder.GetBufferPointer(); const Model* output_model = GetModel(buffer); @@ -149,9 +148,9 @@ TEST_F(QuantizeWeightsTest, WeightsMinNumElements) { // happen, i.e. the original model is the same size as the old one. 
flatbuffers::FlatBufferBuilder builder; const uint64_t kWeightsMinNumElements = 1000000; - EXPECT_EQ(QuantizeWeights(&builder, model_, kWeightsMinNumElements, - QuantizerType::OLD_QUANTIZER), - kTfLiteOk); + ASSERT_TRUE(QuantizeWeights(&builder, model_, kWeightsMinNumElements, + QuantizerType::OLD_QUANTIZER) + .ok()); const uint8_t* buffer = builder.GetBufferPointer(); const Model* output_model = GetModel(buffer); @@ -180,9 +179,8 @@ TEST_F(QuantizeWeightsTest, WeightsMinNumElements) { TEST_F(QuantizeWeightsTest, HybridConv) { LoadBasicModel(); flatbuffers::FlatBufferBuilder builder; - auto status = - QuantizeWeights(&builder, model_, 0, QuantizerType::OLD_QUANTIZER); - EXPECT_EQ(status, kTfLiteOk); + ASSERT_TRUE( + QuantizeWeights(&builder, model_, 0, QuantizerType::OLD_QUANTIZER).ok()); const uint8_t* buffer = builder.GetBufferPointer(); const Model* output_model = GetModel(buffer); @@ -236,10 +234,10 @@ TEST_F(QuantizeWeightsTest, HybridConv) { TEST_F(QuantizeWeightsTest, DequantizeConv) { LoadBasicModel(); flatbuffers::FlatBufferBuilder builder; - auto status = internal::QuantizeWeights(&builder, model_, 0, - /*use_hybrid_evaluation=*/false, - QuantizerType::OLD_QUANTIZER); - EXPECT_EQ(status, kTfLiteOk); + ASSERT_TRUE(internal::QuantizeWeights(&builder, model_, 0, + /*use_hybrid_evaluation=*/false, + QuantizerType::OLD_QUANTIZER) + .ok()); const uint8_t* buffer = builder.GetBufferPointer(); const Model* output_model = GetModel(buffer); @@ -296,10 +294,10 @@ TEST_F(QuantizeWeightsTest, DequantizeConv) { TEST_F(QuantizeWeightsTest, DequantizeConvFloat16) { LoadBasicModel(); flatbuffers::FlatBufferBuilder builder; - auto status = tflite::optimize::QuantizeWeights( - &builder, model_, BufferType::QUANTIZED_FLOAT16, - kUseUpdatedHybridSchemeDefault, QuantizerType::OLD_QUANTIZER); - EXPECT_EQ(status, kTfLiteOk); + ASSERT_TRUE(tflite::optimize::QuantizeWeights( + &builder, model_, BufferType::QUANTIZED_FLOAT16, + kUseUpdatedHybridSchemeDefault, 
QuantizerType::OLD_QUANTIZER) + .ok()); const uint8_t* buffer = builder.GetBufferPointer(); const Model* output_model = GetModel(buffer); @@ -356,9 +354,8 @@ TEST_F(QuantizeWeightsTest, DequantizeConvFloat16) { TEST_F(QuantizeWeightsTest, SharedWeights_Hybrid) { LoadSharedWeightsModel(); flatbuffers::FlatBufferBuilder builder; - auto status = - QuantizeWeights(&builder, model_, 0, QuantizerType::OLD_QUANTIZER); - EXPECT_EQ(status, kTfLiteOk); + ASSERT_TRUE( + QuantizeWeights(&builder, model_, 0, QuantizerType::OLD_QUANTIZER).ok()); const uint8_t* buffer = builder.GetBufferPointer(); const Model* output_model = GetModel(buffer); @@ -390,10 +387,10 @@ TEST_F(QuantizeWeightsTest, SharedWeights_Hybrid) { TEST_F(QuantizeWeightsTest, SharedWeights_Dequantize) { LoadSharedWeightsModel(); flatbuffers::FlatBufferBuilder builder; - auto status = internal::QuantizeWeights(&builder, model_, 0, - /*use_hybrid_evaluation*/ false, - QuantizerType::OLD_QUANTIZER); - EXPECT_EQ(status, kTfLiteOk); + ASSERT_TRUE(internal::QuantizeWeights(&builder, model_, 0, + /*use_hybrid_evaluation*/ false, + QuantizerType::OLD_QUANTIZER) + .ok()); const uint8_t* buffer = builder.GetBufferPointer(); const Model* output_model = GetModel(buffer); @@ -433,9 +430,8 @@ TEST_F(QuantizeWeightsTest, SharedWeights_Dequantize) { TEST_F(QuantizeWeightsTest, VerifyGatherQuantization) { LoadGatherTestModel(); flatbuffers::FlatBufferBuilder builder; - auto status = - QuantizeWeights(&builder, model_, 0, QuantizerType::OLD_QUANTIZER); - EXPECT_EQ(status, kTfLiteOk); + ASSERT_TRUE( + QuantizeWeights(&builder, model_, 0, QuantizerType::OLD_QUANTIZER).ok()); const uint8_t* buffer = builder.GetBufferPointer(); const Model* output_model = GetModel(buffer); @@ -472,9 +468,9 @@ TEST_F(QuantizeWeightsTest, VerifyCustomOpQuantizationDequantize) { }; flatbuffers::FlatBufferBuilder builder; - auto status = QuantizeWeights(&builder, model_, 0, custom_op_map, - QuantizerType::OLD_QUANTIZER); - ASSERT_EQ(status, kTfLiteOk); + 
ASSERT_TRUE(QuantizeWeights(&builder, model_, 0, custom_op_map, + QuantizerType::OLD_QUANTIZER) + .ok()); const uint8_t* buffer = builder.GetBufferPointer(); const Model* output_model = GetModel(buffer); @@ -520,9 +516,9 @@ TEST_F(QuantizeWeightsTest, VerifyCustomOpQuantizationHybrid) { }; flatbuffers::FlatBufferBuilder builder; - auto status = QuantizeWeights(&builder, model_, 0, custom_op_map, - QuantizerType::OLD_QUANTIZER); - ASSERT_EQ(status, kTfLiteOk); + ASSERT_TRUE(QuantizeWeights(&builder, model_, 0, custom_op_map, + QuantizerType::OLD_QUANTIZER) + .ok()); const uint8_t* buffer = builder.GetBufferPointer(); const Model* output_model = GetModel(buffer); @@ -553,10 +549,10 @@ TEST_F(QuantizeWeightsTest, VerifyUpdatedHybridSchemeFalseQuantizationHybrid) { LoadBasicModel(); flatbuffers::FlatBufferBuilder builder; const CustomOpMap custom_op_map; - auto status = QuantizeWeights( - &builder, model_, 0, custom_op_map, /*use_updated_hybrid_scheme=*/false, - /*op_denylist=*/{}, QuantizerType::OLD_QUANTIZER); - EXPECT_EQ(status, kTfLiteOk); + ASSERT_TRUE(QuantizeWeights(&builder, model_, 0, custom_op_map, + /*use_updated_hybrid_scheme=*/false, + /*op_denylist=*/{}, QuantizerType::OLD_QUANTIZER) + .ok()); const uint8_t* buffer = builder.GetBufferPointer(); const Model* output_model = GetModel(buffer); @@ -607,11 +603,11 @@ TEST_F(QuantizeWeightsTest, DequantizeConvBlocklisted) { LoadBasicModel(); flatbuffers::FlatBufferBuilder builder; const CustomOpMap custom_op_map; - auto status = QuantizeWeights(&builder, model_, 0, custom_op_map, - /*use_updated_hybrid_scheme=*/true, - /*op_denylist*/ {BuiltinOperator_CONV_2D}, - QuantizerType::OLD_QUANTIZER); - EXPECT_EQ(status, kTfLiteOk); + ASSERT_TRUE(QuantizeWeights(&builder, model_, 0, custom_op_map, + /*use_updated_hybrid_scheme=*/true, + /*op_denylist*/ {BuiltinOperator_CONV_2D}, + QuantizerType::OLD_QUANTIZER) + .ok()); const uint8_t* buffer = builder.GetBufferPointer(); const Model* output_model = GetModel(buffer); 
From b1587f62c2691e1c1e15ce2cdeed10bede4d7bdf Mon Sep 17 00:00:00 2001 From: Kyle Lucke Date: Thu, 20 Jun 2024 12:10:46 -0700 Subject: [PATCH 074/256] Move StreamExecutor::Memcpy to-host processing completely to Stream and its derived classes. PiperOrigin-RevId: 645115168 --- .../stream_executor/stream_executor.cc | 12 ------------ .../stream_executor/stream_executor_internal.h | 9 ++++++++- .../xla/xla/backends/interpreter/executor.cc | 10 ---------- .../xla/xla/backends/interpreter/executor.h | 2 -- .../xla/xla/stream_executor/cuda/cuda_executor.cc | 14 -------------- .../xla/xla/stream_executor/gpu/gpu_executor.h | 3 --- .../xla/xla/stream_executor/gpu/gpu_stream.cc | 13 +++++++++++++ .../xla/xla/stream_executor/gpu/gpu_stream.h | 4 +--- .../xla/xla/stream_executor/host/host_executor.cc | 11 ----------- .../xla/xla/stream_executor/host/host_executor.h | 2 -- .../xla/xla/stream_executor/host/host_stream.cc | 9 +++++++++ .../xla/xla/stream_executor/host/host_stream.h | 4 +--- .../xla/xla/stream_executor/mock_stream_executor.h | 4 ---- .../xla/xla/stream_executor/stream_common.cc | 6 ------ .../xla/xla/stream_executor/stream_common.h | 3 --- .../xla/xla/stream_executor/stream_executor.h | 6 ------ .../xla/xla/stream_executor/tpu/tpu_executor.cc | 10 ---------- .../xla/xla/stream_executor/tpu/tpu_executor.h | 4 ---- .../xla/xla/stream_executor/tpu/tpu_stream.h | 14 +++++++++----- 19 files changed, 41 insertions(+), 99 deletions(-) diff --git a/tensorflow/c/experimental/stream_executor/stream_executor.cc b/tensorflow/c/experimental/stream_executor/stream_executor.cc index ccf808c7e11504..4ad6de2fce2993 100644 --- a/tensorflow/c/experimental/stream_executor/stream_executor.cc +++ b/tensorflow/c/experimental/stream_executor/stream_executor.cc @@ -312,18 +312,6 @@ class CStreamExecutor : public StreamExecutorCommon { size, c_status.get()); return StatusFromTF_Status(c_status.get()); } - absl::Status Memcpy(Stream* stream, void* host_dst, - const DeviceMemoryBase& 
gpu_src, uint64 size) override { - OwnedTFStatus c_status(TF_NewStatus()); - SP_Stream stream_handle = static_cast(stream)->Handle(); - SP_DeviceMemoryBase device_mem_src = DeviceMemoryBaseToC(&gpu_src); - stream_executor_->memcpy_dtoh(&device_, stream_handle, host_dst, - &device_mem_src, size, c_status.get()); - if (TF_GetCode(c_status.get()) != TF_OK) { - LOG(ERROR) << TF_Message(c_status.get()); - } - return StatusFromTF_Status(c_status.get()); - } bool MemcpyDeviceToDevice(Stream* stream, DeviceMemoryBase* gpu_dst, const DeviceMemoryBase& gpu_src, uint64 size) override { diff --git a/tensorflow/c/experimental/stream_executor/stream_executor_internal.h b/tensorflow/c/experimental/stream_executor/stream_executor_internal.h index 02bb318ea33b09..ffdcfe2a64bade 100644 --- a/tensorflow/c/experimental/stream_executor/stream_executor_internal.h +++ b/tensorflow/c/experimental/stream_executor/stream_executor_internal.h @@ -252,7 +252,14 @@ class CStream : public StreamCommon { } absl::Status Memcpy(void* host_dst, const DeviceMemoryBase& gpu_src, uint64_t size) override { - return StreamCommon::Memcpy(host_dst, gpu_src, size); + tensorflow::TF_StatusPtr c_status(TF_NewStatus()); + SP_DeviceMemoryBase device_mem_src = DeviceMemoryBaseToC(&gpu_src); + stream_executor_->memcpy_dtoh(device_, stream_handle_, host_dst, + &device_mem_src, size, c_status.get()); + if (TF_GetCode(c_status.get()) != TF_OK) { + LOG(ERROR) << TF_Message(c_status.get()); + } + return tensorflow::StatusFromTF_Status(c_status.get()); } SP_Stream Handle() { return stream_handle_; } diff --git a/third_party/xla/xla/backends/interpreter/executor.cc b/third_party/xla/xla/backends/interpreter/executor.cc index bbfe3964b5df8a..d7118f92c6ea84 100644 --- a/third_party/xla/xla/backends/interpreter/executor.cc +++ b/third_party/xla/xla/backends/interpreter/executor.cc @@ -39,16 +39,6 @@ void XlaInterpreterExecutor::Deallocate(DeviceMemoryBase *mem) { delete[] static_cast(mem->opaque()); } -absl::Status 
XlaInterpreterExecutor::Memcpy(Stream *stream, void *host_dst, - const DeviceMemoryBase &dev_src, - uint64_t size) { - AsExecutorStream(stream)->EnqueueTask([this, host_dst, dev_src, size]() { - // Ignore errors. - absl::Status ok = SynchronousMemcpy(host_dst, dev_src, size); - }); - return AsExecutorStream(stream)->BlockUntilDone(); -} - absl::Status XlaInterpreterExecutor::SynchronousMemcpy( DeviceMemoryBase *dev_dst, const void *host_src, uint64_t size) { memcpy(dev_dst->opaque(), host_src, size); diff --git a/third_party/xla/xla/backends/interpreter/executor.h b/third_party/xla/xla/backends/interpreter/executor.h index 5a4950552dcd4c..d23fa614048688 100644 --- a/third_party/xla/xla/backends/interpreter/executor.h +++ b/third_party/xla/xla/backends/interpreter/executor.h @@ -90,8 +90,6 @@ class XlaInterpreterExecutor : public StreamExecutorCommon { delete[] static_cast(mem); } - absl::Status Memcpy(Stream *stream, void *host_dst, - const DeviceMemoryBase &dev_src, uint64_t size) override; bool MemcpyDeviceToDevice(Stream *stream, DeviceMemoryBase *pop_dst, const DeviceMemoryBase &host_src, uint64_t size) override { diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc b/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc index d73c2ed935f3ac..129d5de237c795 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc @@ -681,20 +681,6 @@ absl::Status GpuExecutor::Memset(Stream* stream, DeviceMemoryBase* location, AsGpuStreamValue(stream)); } -absl::Status GpuExecutor::Memcpy(Stream* stream, void* host_dst, - const DeviceMemoryBase& gpu_src, - uint64_t size) { - bool ok = GpuDriver::AsynchronousMemcpyD2H(context_, host_dst, - AsCudaDevicePtr(gpu_src), size, - AsGpuStreamValue(stream)); - // TODO(b/326130105): Change AsynchronousMemcpyD2H calls to return - // absl::Status. 
- if (!ok) { - return absl::InternalError("Failed to memcpy from device to host."); - } - return absl::OkStatus(); -} - bool GpuExecutor::MemcpyDeviceToDevice(Stream* stream, DeviceMemoryBase* gpu_dst, const DeviceMemoryBase& gpu_src, diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_executor.h b/third_party/xla/xla/stream_executor/gpu/gpu_executor.h index 6c1f89ac00b931..3d3d4a3eafd436 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_executor.h +++ b/third_party/xla/xla/stream_executor/gpu/gpu_executor.h @@ -211,9 +211,6 @@ class GpuExecutor : public StreamExecutorCommon { absl::Status Memset(Stream* stream, DeviceMemoryBase* location, uint8_t pattern, uint64_t size) override; - absl::Status Memcpy(Stream* stream, void* host_dst, - const DeviceMemoryBase& gpu_src, uint64_t size) override; - bool MemcpyDeviceToDevice(Stream* stream, DeviceMemoryBase* gpu_dst, const DeviceMemoryBase& gpu_src, uint64_t size) override; diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_stream.cc b/third_party/xla/xla/stream_executor/gpu/gpu_stream.cc index 573bb4d6fcb632..fd9415662d6c4a 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_stream.cc +++ b/third_party/xla/xla/stream_executor/gpu/gpu_stream.cc @@ -88,6 +88,19 @@ absl::Status GpuStream::Memcpy(DeviceMemoryBase* gpu_dst, const void* host_src, return absl::OkStatus(); } +absl::Status GpuStream::Memcpy(void* host_dst, const DeviceMemoryBase& gpu_src, + uint64_t size) { + bool ok = GpuDriver::AsynchronousMemcpyD2H( + parent_->gpu_context(), host_dst, + reinterpret_cast(gpu_src.opaque()), size, gpu_stream()); + // TODO(b/326130105): Change AsynchronousMemcpyD2H calls to return + // absl::Status. 
+ if (!ok) { + return absl::InternalError("Failed to memcpy from device to host."); + } + return absl::OkStatus(); +} + absl::Status GpuStream::WaitFor(Stream* other) { GpuStream* other_gpu = AsGpuStream(other); GpuEventHandle other_completed_event = *(other_gpu->completed_event()); diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_stream.h b/third_party/xla/xla/stream_executor/gpu/gpu_stream.h index 84440824e71f5e..d56321c74436a7 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_stream.h +++ b/third_party/xla/xla/stream_executor/gpu/gpu_stream.h @@ -105,9 +105,7 @@ class GpuStream : public StreamCommon { absl::Status Memcpy(DeviceMemoryBase* gpu_dst, const void* host_src, uint64_t size) override; absl::Status Memcpy(void* host_dst, const DeviceMemoryBase& gpu_src, - uint64_t size) override { - return StreamCommon::Memcpy(host_dst, gpu_src, size); - } + uint64_t size) override; absl::Status Memcpy(DeviceMemoryBase* gpu_dst, const DeviceMemoryBase& gpu_src, uint64_t size) override { return StreamCommon::Memcpy(gpu_dst, gpu_src, size); diff --git a/third_party/xla/xla/stream_executor/host/host_executor.cc b/third_party/xla/xla/stream_executor/host/host_executor.cc index cf125ac7dbc858..24c6ee65f934f0 100644 --- a/third_party/xla/xla/stream_executor/host/host_executor.cc +++ b/third_party/xla/xla/stream_executor/host/host_executor.cc @@ -143,17 +143,6 @@ absl::Status HostExecutor::SynchronousMemZero(DeviceMemoryBase* location, return absl::OkStatus(); } -absl::Status HostExecutor::Memcpy(Stream* stream, void* host_dst, - const DeviceMemoryBase& gpu_src, - uint64_t size) { - // Enqueue the [asynchronous] memcpy on the stream (HostStream) associated - // with the HostExecutor. 
- void* src_mem = const_cast(gpu_src.opaque()); - AsHostStream(stream)->EnqueueTask( - [host_dst, src_mem, size]() { memcpy(host_dst, src_mem, size); }); - return absl::OkStatus(); -} - bool HostExecutor::MemcpyDeviceToDevice(Stream* stream, DeviceMemoryBase* gpu_dst, const DeviceMemoryBase& gpu_src, diff --git a/third_party/xla/xla/stream_executor/host/host_executor.h b/third_party/xla/xla/stream_executor/host/host_executor.h index e9102ca83c546a..3ebe6c9a26ad03 100644 --- a/third_party/xla/xla/stream_executor/host/host_executor.h +++ b/third_party/xla/xla/stream_executor/host/host_executor.h @@ -90,8 +90,6 @@ class HostExecutor : public StreamExecutorCommon { delete[] static_cast(mem); } - absl::Status Memcpy(Stream* stream, void* host_dst, - const DeviceMemoryBase& gpu_src, uint64_t size) override; bool MemcpyDeviceToDevice(Stream* stream, DeviceMemoryBase* gpu_dst, const DeviceMemoryBase& gpu_src, uint64_t size) override; diff --git a/third_party/xla/xla/stream_executor/host/host_stream.cc b/third_party/xla/xla/stream_executor/host/host_stream.cc index 57041acaaa13fb..b10699924f7f95 100644 --- a/third_party/xla/xla/stream_executor/host/host_stream.cc +++ b/third_party/xla/xla/stream_executor/host/host_stream.cc @@ -57,6 +57,15 @@ HostStream::~HostStream() { parent()->DeallocateStream(this); } +absl::Status HostStream::Memcpy(void* host_dst, const DeviceMemoryBase& gpu_src, + uint64_t size) { + // Enqueue the [asynchronous] memcpy on the stream (HostStream) associated + // with the HostExecutor. 
+ void* src_mem = const_cast(gpu_src.opaque()); + EnqueueTask([host_dst, src_mem, size]() { memcpy(host_dst, src_mem, size); }); + return absl::OkStatus(); +} + absl::Status HostStream::Memcpy(DeviceMemoryBase* gpu_dst, const void* host_src, uint64_t size) { void* dst_mem = gpu_dst->opaque(); diff --git a/third_party/xla/xla/stream_executor/host/host_stream.h b/third_party/xla/xla/stream_executor/host/host_stream.h index 08428c720c7f31..b97a8d51bce3e6 100644 --- a/third_party/xla/xla/stream_executor/host/host_stream.h +++ b/third_party/xla/xla/stream_executor/host/host_stream.h @@ -64,9 +64,7 @@ class HostStream : public StreamCommon { return StreamCommon::Memcpy(gpu_dst, gpu_src, size); } absl::Status Memcpy(void* host_dst, const DeviceMemoryBase& gpu_src, - uint64_t size) override { - return StreamCommon::Memcpy(host_dst, gpu_src, size); - } + uint64_t size) override; private: bool WorkAvailable() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); diff --git a/third_party/xla/xla/stream_executor/mock_stream_executor.h b/third_party/xla/xla/stream_executor/mock_stream_executor.h index a207ee9ac1b8f0..33eccd6adfb1ae 100644 --- a/third_party/xla/xla/stream_executor/mock_stream_executor.h +++ b/third_party/xla/xla/stream_executor/mock_stream_executor.h @@ -109,10 +109,6 @@ class MockStreamExecutor : public StreamExecutor { (Stream * stream, DeviceMemoryBase* location, uint8_t pattern, uint64_t size), (override)); - MOCK_METHOD(absl::Status, Memcpy, - (Stream * stream, void* host_dst, - const DeviceMemoryBase& device_src, uint64_t size), - (override)); MOCK_METHOD(bool, MemcpyDeviceToDevice, (Stream * stream, DeviceMemoryBase* device_dst, const DeviceMemoryBase& device_src, uint64_t size), diff --git a/third_party/xla/xla/stream_executor/stream_common.cc b/third_party/xla/xla/stream_executor/stream_common.cc index f171491347f996..43db75984efdba 100644 --- a/third_party/xla/xla/stream_executor/stream_common.cc +++ b/third_party/xla/xla/stream_executor/stream_common.cc @@ -143,12 +143,6 
@@ void StreamCommon::ReturnSubStream(Stream *sub_stream) { << sub_stream; } -absl::Status StreamCommon::Memcpy(void *host_dst, - const DeviceMemoryBase &gpu_src, - uint64_t size) { - return parent_->Memcpy(this, host_dst, gpu_src, size); -} - absl::Status StreamCommon::Memcpy(DeviceMemoryBase *gpu_dst, const DeviceMemoryBase &gpu_src, uint64_t size) { diff --git a/third_party/xla/xla/stream_executor/stream_common.h b/third_party/xla/xla/stream_executor/stream_common.h index 63b5b6d3a48414..4278fe7063d550 100644 --- a/third_party/xla/xla/stream_executor/stream_common.h +++ b/third_party/xla/xla/stream_executor/stream_common.h @@ -71,9 +71,6 @@ class StreamCommon : public Stream { absl::StatusOr GetOrCreateSubStream() override TF_LOCKS_EXCLUDED(mu_); void ReturnSubStream(Stream *sub_stream) override TF_LOCKS_EXCLUDED(mu_); - absl::Status Memcpy(void *host_dst, const DeviceMemoryBase &gpu_src, - uint64_t size) override; - absl::Status Memcpy(DeviceMemoryBase *gpu_dst, const DeviceMemoryBase &gpu_src, uint64_t size) override; absl::Status BlockHostUntilDone() override TF_LOCKS_EXCLUDED(mu_); diff --git a/third_party/xla/xla/stream_executor/stream_executor.h b/third_party/xla/xla/stream_executor/stream_executor.h index a314759fe1079a..593c5c33f3833c 100644 --- a/third_party/xla/xla/stream_executor/stream_executor.h +++ b/third_party/xla/xla/stream_executor/stream_executor.h @@ -231,12 +231,6 @@ class StreamExecutor { return absl::InternalError("Not implemented"); } - // Enqueues a memcpy operation onto stream, with a host destination location - // host_dst and a device memory source, with target size size. - virtual absl::Status Memcpy(Stream* stream, void* host_dst, - const DeviceMemoryBase& device_src, - uint64_t size) = 0; - // Enqueues a memcpy operation onto stream, with a device destination location // and a device source location, with target size size. 
Peer access should // have been enabled between the StreamExecutors owning the device memory diff --git a/third_party/xla/xla/stream_executor/tpu/tpu_executor.cc b/third_party/xla/xla/stream_executor/tpu/tpu_executor.cc index 8b82cd019f928b..8bbf2ca1945122 100644 --- a/third_party/xla/xla/stream_executor/tpu/tpu_executor.cc +++ b/third_party/xla/xla/stream_executor/tpu/tpu_executor.cc @@ -170,16 +170,6 @@ absl::Status TpuExecutor::EnqueueInfeed(int32_t infeed_queue_index, return status.status(); } -absl::Status TpuExecutor::Memcpy( - Stream* stream, void* host_dst, - const ::stream_executor::DeviceMemoryBase& device_src, uint64_t size) { - StatusHelper status; - SE_DeviceMemoryBase se_base = ApiConverter::ToC(device_src); - ExecutorApiFn()->TpuExecutor_MemcpyToHostFn( - executor_, get_stream(stream), host_dst, &se_base, size, status.c_status); - return status.status(); -} - absl::Status TpuExecutor::SynchronousMemcpy( ::stream_executor::DeviceMemoryBase* device_dst, const void* host_src, uint64_t size) { diff --git a/third_party/xla/xla/stream_executor/tpu/tpu_executor.h b/third_party/xla/xla/stream_executor/tpu/tpu_executor.h index aa64ded8e1b048..a10efb34f39b91 100644 --- a/third_party/xla/xla/stream_executor/tpu/tpu_executor.h +++ b/third_party/xla/xla/stream_executor/tpu/tpu_executor.h @@ -97,10 +97,6 @@ class TpuExecutor : public tensorflow::tpu::TpuExecutorInterface { bool HostCallback(Stream* stream, absl::AnyInvocable callback) override; - absl::Status Memcpy(Stream* stream, void* host_dst, - const DeviceMemoryBase& device_src, - uint64_t size) override; - bool MemcpyDeviceToDevice(Stream* stream, DeviceMemoryBase* gpu_dst, const DeviceMemoryBase& host_src, uint64_t size) override; diff --git a/third_party/xla/xla/stream_executor/tpu/tpu_stream.h b/third_party/xla/xla/stream_executor/tpu/tpu_stream.h index b8e04925ac4f86..c056a4716e3332 100644 --- a/third_party/xla/xla/stream_executor/tpu/tpu_stream.h +++ 
b/third_party/xla/xla/stream_executor/tpu/tpu_stream.h @@ -131,15 +131,19 @@ class TpuStream : public tensorflow::tpu::TpuStreamInterface { se_executor_, stream_, &se_base, host_src, size, status.c_status); return status.status(); } - absl::Status Memcpy(stream_executor::DeviceMemoryBase* gpu_dst, - const stream_executor::DeviceMemoryBase& gpu_src, + absl::Status Memcpy(stream_executor::DeviceMemoryBase* device_dst, + const stream_executor::DeviceMemoryBase& device_src, uint64_t size) override { - return StreamCommon::Memcpy(gpu_dst, gpu_src, size); + return StreamCommon::Memcpy(device_dst, device_src, size); } absl::Status Memcpy(void* host_dst, - const stream_executor::DeviceMemoryBase& gpu_src, + const stream_executor::DeviceMemoryBase& device_src, uint64_t size) override { - return StreamCommon::Memcpy(host_dst, gpu_src, size); + StatusHelper status; + SE_DeviceMemoryBase se_base = ApiConverter::ToC(device_src); + stream_executor::tpu::ExecutorApiFn()->TpuExecutor_MemcpyToHostFn( + se_executor_, stream_, host_dst, &se_base, size, status.c_status); + return status.status(); } SE_Stream* se_stream() const { return stream_; } From f4d5a7abc1338040129e3ce60ee046b9be950ab1 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Thu, 20 Jun 2024 12:14:47 -0700 Subject: [PATCH 075/256] [xla:cpu] Add support for AllToAll thunk PiperOrigin-RevId: 645116217 --- third_party/xla/xla/service/cpu/BUILD | 1 + third_party/xla/xla/service/cpu/runtime/BUILD | 32 ++++++ .../service/cpu/runtime/all_to_all_thunk.cc | 99 +++++++++++++++++++ .../service/cpu/runtime/all_to_all_thunk.h | 41 ++++++++ .../xla/xla/service/cpu/runtime/thunk.cc | 2 + .../xla/xla/service/cpu/runtime/thunk.h | 1 + .../xla/xla/service/cpu/thunk_emitter.cc | 31 ++++++ .../xla/xla/service/cpu/thunk_emitter.h | 3 + third_party/xla/xla/tests/BUILD | 1 + .../xla/xla/tests/collective_ops_test.cc | 1 + 10 files changed, 212 insertions(+) create mode 100644 third_party/xla/xla/service/cpu/runtime/all_to_all_thunk.cc 
create mode 100644 third_party/xla/xla/service/cpu/runtime/all_to_all_thunk.h diff --git a/third_party/xla/xla/service/cpu/BUILD b/third_party/xla/xla/service/cpu/BUILD index 59488c45da9fe1..7d3457638ad18e 100644 --- a/third_party/xla/xla/service/cpu/BUILD +++ b/third_party/xla/xla/service/cpu/BUILD @@ -814,6 +814,7 @@ cc_library( "//xla/service/cpu:dot_op_emitter", "//xla/service/cpu/runtime:all_gather_thunk", "//xla/service/cpu/runtime:all_reduce_thunk", + "//xla/service/cpu/runtime:all_to_all_thunk", "//xla/service/cpu/runtime:call_thunk", "//xla/service/cpu/runtime:collective_thunk", "//xla/service/cpu/runtime:conditional_thunk", diff --git a/third_party/xla/xla/service/cpu/runtime/BUILD b/third_party/xla/xla/service/cpu/runtime/BUILD index 994278b8271ee9..c4ec5905e90983 100644 --- a/third_party/xla/xla/service/cpu/runtime/BUILD +++ b/third_party/xla/xla/service/cpu/runtime/BUILD @@ -256,6 +256,38 @@ cc_library( ], ) +cc_library( + name = "all_to_all_thunk", + srcs = ["all_to_all_thunk.cc"], + hdrs = ["all_to_all_thunk.h"], + deps = [ + ":collective_thunk", + ":thunk", + "//xla:shape_util", + "//xla:status_macros", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/runtime:buffer_use", + "//xla/service:buffer_assignment", + "//xla/service:collective_ops_utils", + "//xla/service/cpu:collectives_interface", + "//xla/stream_executor", + "//xla/tsl/concurrency:async_value", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:numbers", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/profiler/lib:traceme", + ], +) + cc_library( name = "reduce_scatter_thunk", srcs = ["reduce_scatter_thunk.cc"], diff --git 
a/third_party/xla/xla/service/cpu/runtime/all_to_all_thunk.cc b/third_party/xla/xla/service/cpu/runtime/all_to_all_thunk.cc new file mode 100644 index 00000000000000..f6c066707bb6f4 --- /dev/null +++ b/third_party/xla/xla/service/cpu/runtime/all_to_all_thunk.cc @@ -0,0 +1,99 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/cpu/runtime/all_to_all_thunk.h" + +#include +#include + +#include "absl/container/inlined_vector.h" +#include "absl/memory/memory.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "xla/service/buffer_assignment.h" +#include "xla/service/collective_ops_utils.h" +#include "xla/service/cpu/collectives_interface.h" +#include "xla/service/cpu/runtime/collective_thunk.h" +#include "xla/service/cpu/runtime/thunk.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/tsl/concurrency/async_value_ref.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/logging.h" +#include "tsl/platform/statusor.h" +#include "tsl/profiler/lib/traceme.h" + +namespace xla::cpu { + +absl::StatusOr> AllToAllThunk::Create( + Info info, OpParams op_params, OpBuffers op_buffers) { + return absl::WrapUnique( + new AllToAllThunk(std::move(info), op_params, std::move(op_buffers))); +} + +AllToAllThunk::AllToAllThunk(Info info, OpParams op_params, + OpBuffers 
op_buffers) + : CollectiveThunk(Kind::kAllToAll, info, op_params, std::move(op_buffers)) { +} + +tsl::AsyncValueRef AllToAllThunk::Execute( + const ExecuteParams& params) { + tsl::profiler::TraceMe trace([&] { return TraceMeEncode(); }); + + TF_ASSIGN_OR_RETURN(OpDeviceMemory data, GetOpDeviceMemory(params)); + + VLOG(3) << absl::StreamFormat( + "AllToAll: #source_buffers=%d, #destination_buffers=%d", + data.source.size(), data.destination.size()); + + for (int i = 0; i < data.source.size(); ++i) { + VLOG(3) << absl::StreamFormat( + " src: %s in slice %s (%p)", source_shape(i).ToString(true), + source_buffer(i).ToString(), data.source[i].opaque()); + } + + for (int i = 0; i < data.destination.size(); ++i) { + VLOG(3) << absl::StreamFormat( + " dst: %s in slice %s (%p)", destination_shape(i).ToString(true), + destination_buffer(i).ToString(), data.destination[i].opaque()); + } + + return ExecuteWithCommunicator( + params.collective_params, + [&](const RendezvousKey& key, CollectivesCommunicator& comm) { + const Shape& shape = destination_shape(0); + + absl::InlinedVector input_buffers; + input_buffers.reserve(data.source.size()); + for (int i = 0; i < data.source.size(); ++i) { + input_buffers.push_back(data.source[i].opaque()); + } + + absl::InlinedVector output_buffers; + output_buffers.reserve(data.destination.size()); + for (int i = 0; i < data.destination.size(); ++i) { + output_buffers.push_back(data.destination[i].opaque()); + } + + TF_RETURN_IF_ERROR(comm.AllToAll(key, ShapeUtil::ByteSizeOf(shape), + input_buffers, output_buffers, + DefaultCollectiveTimeout())); + + return absl::OkStatus(); + }); +} + +} // namespace xla::cpu diff --git a/third_party/xla/xla/service/cpu/runtime/all_to_all_thunk.h b/third_party/xla/xla/service/cpu/runtime/all_to_all_thunk.h new file mode 100644 index 00000000000000..94b0eec845bd7c --- /dev/null +++ b/third_party/xla/xla/service/cpu/runtime/all_to_all_thunk.h @@ -0,0 +1,41 @@ +/* Copyright 2024 The OpenXLA Authors. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_CPU_RUNTIME_ALL_TO_ALL_THUNK_H_ +#define XLA_SERVICE_CPU_RUNTIME_ALL_TO_ALL_THUNK_H_ + +#include + +#include "absl/status/statusor.h" +#include "xla/service/cpu/runtime/collective_thunk.h" +#include "xla/tsl/concurrency/async_value_ref.h" +#include "xla/xla_data.pb.h" + +namespace xla::cpu { + +class AllToAllThunk final : public CollectiveThunk { + public: + static absl::StatusOr> Create( + Info info, OpParams op_params, OpBuffers op_buffers); + + tsl::AsyncValueRef Execute(const ExecuteParams& params) final; + + private: + AllToAllThunk(Info info, OpParams op_params, OpBuffers op_buffers); +}; + +} // namespace xla::cpu + +#endif // XLA_SERVICE_CPU_RUNTIME_ALL_TO_ALL_THUNK_H_ diff --git a/third_party/xla/xla/service/cpu/runtime/thunk.cc b/third_party/xla/xla/service/cpu/runtime/thunk.cc index 97f6c15dadf85b..88484ce5416044 100644 --- a/third_party/xla/xla/service/cpu/runtime/thunk.cc +++ b/third_party/xla/xla/service/cpu/runtime/thunk.cc @@ -42,6 +42,8 @@ std::string_view Thunk::KindToString(Kind kind) { return "all-gather"; case Kind::kAllReduce: return "all-reduce"; + case Kind::kAllToAll: + return "all-to-all"; case Kind::kCall: return "call"; case Kind::kConditional: diff --git a/third_party/xla/xla/service/cpu/runtime/thunk.h b/third_party/xla/xla/service/cpu/runtime/thunk.h index 0750267c4f850b..7a21dfa858357e 100644 --- 
a/third_party/xla/xla/service/cpu/runtime/thunk.h +++ b/third_party/xla/xla/service/cpu/runtime/thunk.h @@ -63,6 +63,7 @@ class Thunk { enum class Kind { kAllGather, kAllReduce, + kAllToAll, kCall, kCopy, kConditional, diff --git a/third_party/xla/xla/service/cpu/thunk_emitter.cc b/third_party/xla/xla/service/cpu/thunk_emitter.cc index 5ad742c8b4cafb..f37d7360326630 100644 --- a/third_party/xla/xla/service/cpu/thunk_emitter.cc +++ b/third_party/xla/xla/service/cpu/thunk_emitter.cc @@ -38,6 +38,7 @@ limitations under the License. #include "xla/service/cpu/ir_emitter2.h" #include "xla/service/cpu/runtime/all_gather_thunk.h" #include "xla/service/cpu/runtime/all_reduce_thunk.h" +#include "xla/service/cpu/runtime/all_to_all_thunk.h" #include "xla/service/cpu/runtime/call_thunk.h" #include "xla/service/cpu/runtime/collective_thunk.h" #include "xla/service/cpu/runtime/conditional_thunk.h" @@ -214,6 +215,8 @@ absl::StatusOr ThunkEmitter::EmitHloInstruction( return EmitAllReduceThunk(instruction); case HloOpcode::kReduceScatter: return EmitReduceScatterThunk(instruction); + case HloOpcode::kAllToAll: + return EmitAllToAllThunk(instruction); // TODO(ezhulenev): Port pad optimizations from IrEmitter. case HloOpcode::kPad: @@ -284,6 +287,21 @@ static absl::StatusOr GetCollectiveOpParams( }; } +// TODO(ezhulenev): Figure out why AllToAll instruction does not have +// `use_global_device_ids` field and how to unify it with every other collective +// operation. +static absl::StatusOr GetCollectiveOpParams( + const HloAllToAllInstruction* instruction) { + return CollectiveThunk::OpParams{ + /*op_id=*/instruction->channel_id().has_value() + ? 
instruction->channel_id().value() + : instruction->GetModule()->unique_id(), + /*has_channel_id=*/instruction->channel_id().has_value(), + /*use_global_device_ids=*/std::nullopt, + /*replica_groups=*/instruction->replica_groups(), + }; +} + static absl::StatusOr GetCollectiveOpBuffers( const HloInstruction* instruction, const BufferAssignment& buffer_assignment) { @@ -348,6 +366,19 @@ absl::StatusOr ThunkEmitter::EmitAllReduceThunk( std::move(op_buffers), single_replica); } +absl::StatusOr ThunkEmitter::EmitAllToAllThunk( + const HloInstruction* instruction) { + auto* all_to_all = Cast(instruction); + + TF_ASSIGN_OR_RETURN(AllToAllThunk::OpParams op_params, + GetCollectiveOpParams(all_to_all)); + TF_ASSIGN_OR_RETURN(AllToAllThunk::OpBuffers op_buffers, + GetCollectiveOpBuffers(all_to_all, buffer_assignment_)); + + return ThunkSequence::Of( + ThunkInfo(all_to_all), std::move(op_params), std::move(op_buffers)); +} + absl::StatusOr ThunkEmitter::EmitReduceScatterThunk( const HloInstruction* instruction) { auto* reduce_scatter = Cast(instruction); diff --git a/third_party/xla/xla/service/cpu/thunk_emitter.h b/third_party/xla/xla/service/cpu/thunk_emitter.h index d4a33366ef931e..cc5a5175d6d92f 100644 --- a/third_party/xla/xla/service/cpu/thunk_emitter.h +++ b/third_party/xla/xla/service/cpu/thunk_emitter.h @@ -113,6 +113,9 @@ class ThunkEmitter { absl::StatusOr EmitAllReduceThunk( const HloInstruction* instruction); + absl::StatusOr EmitAllToAllThunk( + const HloInstruction* instruction); + absl::StatusOr EmitReduceScatterThunk( const HloInstruction* instruction); diff --git a/third_party/xla/xla/tests/BUILD b/third_party/xla/xla/tests/BUILD index f66a6c6099d7f9..bf95343ebd314d 100644 --- a/third_party/xla/xla/tests/BUILD +++ b/third_party/xla/xla/tests/BUILD @@ -2174,6 +2174,7 @@ xla_test( ":test_utils", ":xla_internal_test_main", "//xla:literal", + "//xla:literal_util", "//xla:shape_util", "//xla/service:hlo_module_config", "@com_google_absl//absl/strings", diff --git 
a/third_party/xla/xla/tests/collective_ops_test.cc b/third_party/xla/xla/tests/collective_ops_test.cc index 023df482809936..d9509a1eec2fa7 100644 --- a/third_party/xla/xla/tests/collective_ops_test.cc +++ b/third_party/xla/xla/tests/collective_ops_test.cc @@ -23,6 +23,7 @@ limitations under the License. #include "absl/strings/str_replace.h" #include "absl/types/span.h" #include "xla/literal.h" +#include "xla/literal_util.h" #include "xla/primitive_util.h" #include "xla/service/hlo_module_config.h" #include "xla/tests/hlo_test_base.h" From 77fbee0f68cea7da77caf6d21b72da1565981a5a Mon Sep 17 00:00:00 2001 From: Hyeontaek Lim Date: Thu, 20 Jun 2024 13:07:08 -0700 Subject: [PATCH 076/256] [IFRT] Add PjRt<->IFRT attribute map conversion utility functions This change adds utility functions for converting between `xla::ifrt::AttributeMap` and `absl::flat_hash_map` (`xla::PjRtValueType` = `xla::PjRtDeviceAttribute`). This conversion will be used in the code that exports attributes from IFRT while the source of the information comes from a PjRt attribute map. 
PiperOrigin-RevId: 645131451 --- third_party/xla/xla/python/pjrt_ifrt/BUILD | 24 +++++ .../pjrt_ifrt/pjrt_attribute_map_util.cc | 87 +++++++++++++++++++ .../pjrt_ifrt/pjrt_attribute_map_util.h | 39 +++++++++ .../pjrt_ifrt/pjrt_attribute_map_util_test.cc | 75 ++++++++++++++++ 4 files changed, 225 insertions(+) create mode 100644 third_party/xla/xla/python/pjrt_ifrt/pjrt_attribute_map_util.cc create mode 100644 third_party/xla/xla/python/pjrt_ifrt/pjrt_attribute_map_util.h create mode 100644 third_party/xla/xla/python/pjrt_ifrt/pjrt_attribute_map_util_test.cc diff --git a/third_party/xla/xla/python/pjrt_ifrt/BUILD b/third_party/xla/xla/python/pjrt_ifrt/BUILD index 86427ec27a8214..8d5f6adff7c6b1 100644 --- a/third_party/xla/xla/python/pjrt_ifrt/BUILD +++ b/third_party/xla/xla/python/pjrt_ifrt/BUILD @@ -271,6 +271,30 @@ cc_library( alwayslink = True, ) +cc_library( + name = "pjrt_attribute_map_util", + srcs = ["pjrt_attribute_map_util.cc"], + hdrs = ["pjrt_attribute_map_util.h"], + compatible_with = get_compatible_with_portable(), + deps = [ + "//xla/pjrt:pjrt_common", + "//xla/python/ifrt:attribute_map", + "@com_google_absl//absl/container:flat_hash_map", + ], +) + +xla_cc_test( + name = "pjrt_attribute_map_util_test", + srcs = ["pjrt_attribute_map_util_test.cc"], + deps = [ + ":pjrt_attribute_map_util", + "//xla/pjrt:pjrt_common", + "//xla/python/ifrt:attribute_map", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_googletest//:gtest_main", + ], +) + cc_library( name = "basic_string_array", srcs = ["basic_string_array.cc"], diff --git a/third_party/xla/xla/python/pjrt_ifrt/pjrt_attribute_map_util.cc b/third_party/xla/xla/python/pjrt_ifrt/pjrt_attribute_map_util.cc new file mode 100644 index 00000000000000..27b7ab4ce3d19f --- /dev/null +++ b/third_party/xla/xla/python/pjrt_ifrt/pjrt_attribute_map_util.cc @@ -0,0 +1,87 @@ +/* Copyright 2024 The OpenXLA Authors. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/python/pjrt_ifrt/pjrt_attribute_map_util.h" + +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "xla/pjrt/pjrt_common.h" +#include "xla/python/ifrt/attribute_map.h" + +namespace xla { +namespace ifrt { + +AttributeMap FromPjRtDeviceAttributeMap( + absl::flat_hash_map attributes) { + AttributeMap::Map result; + result.reserve(attributes.size()); + for (auto& item : attributes) { + std::visit( + [&](auto& value) { + using T = std::decay_t; + const auto& key = item.first; + if constexpr (std::is_same_v) { + result.insert({key, AttributeMap::StringValue(std::move(value))}); + } else if constexpr (std::is_same_v) { + result.insert({key, AttributeMap::BoolValue(value)}); + } else if constexpr (std::is_same_v) { + result.insert({key, AttributeMap::Int64Value(value)}); + } else if constexpr (std::is_same_v>) { + result.insert( + {key, AttributeMap::Int64ListValue(std::move(value))}); + } else if constexpr (std::is_same_v) { + result.insert({key, AttributeMap::FloatValue(value)}); + } + }, + item.second); + } + return AttributeMap(std::move(result)); +} + +absl::flat_hash_map ToPjRtDeviceAttributeMap( + AttributeMap attributes) { + absl::flat_hash_map result; + result.reserve(attributes.map().size()); + for (auto& item : attributes.map()) { + std::visit( + [&](auto& value) { + using T = 
std::decay_t; + const auto& key = item.first; + if constexpr (std::is_same_v) { + result.insert({key, std::move(value.value)}); + } else if constexpr (std::is_same_v) { + result.insert({key, value.value}); + } else if constexpr (std::is_same_v) { + result.insert({key, value.value}); + } else if constexpr (std::is_same_v) { + result.insert({key, std::move(value.value)}); + } else if constexpr (std::is_same_v) { + result.insert({key, value.value}); + } + }, + item.second); + } + return result; +} + +} // namespace ifrt +} // namespace xla diff --git a/third_party/xla/xla/python/pjrt_ifrt/pjrt_attribute_map_util.h b/third_party/xla/xla/python/pjrt_ifrt/pjrt_attribute_map_util.h new file mode 100644 index 00000000000000..935c244266d052 --- /dev/null +++ b/third_party/xla/xla/python/pjrt_ifrt/pjrt_attribute_map_util.h @@ -0,0 +1,39 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_PYTHON_PJRT_IFRT_PJRT_ATTRIBUTE_MAP_UTIL_H_ +#define XLA_PYTHON_PJRT_IFRT_PJRT_ATTRIBUTE_MAP_UTIL_H_ + +#include + +#include "absl/container/flat_hash_map.h" +#include "xla/pjrt/pjrt_common.h" +#include "xla/python/ifrt/attribute_map.h" + +namespace xla { +namespace ifrt { + +// Converts a PjRt device attribute map into an IFRT attribute map. 
+AttributeMap FromPjRtDeviceAttributeMap( + absl::flat_hash_map attributes); + +// Converts an IFRT attribute map into a PjRt device attribute map. +absl::flat_hash_map ToPjRtDeviceAttributeMap( + AttributeMap attributes); + +} // namespace ifrt +} // namespace xla + +#endif // XLA_PYTHON_PJRT_IFRT_PJRT_ATTRIBUTE_MAP_UTIL_H_ diff --git a/third_party/xla/xla/python/pjrt_ifrt/pjrt_attribute_map_util_test.cc b/third_party/xla/xla/python/pjrt_ifrt/pjrt_attribute_map_util_test.cc new file mode 100644 index 00000000000000..f90c5b49f02937 --- /dev/null +++ b/third_party/xla/xla/python/pjrt_ifrt/pjrt_attribute_map_util_test.cc @@ -0,0 +1,75 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "xla/python/pjrt_ifrt/pjrt_attribute_map_util.h" + +#include +#include +#include + +#include +#include "absl/container/flat_hash_map.h" +#include "xla/pjrt/pjrt_common.h" +#include "xla/python/ifrt/attribute_map.h" + +namespace xla { +namespace ifrt { +namespace { + +TEST(PjRtAttributeMapUtilTest, FromPjRtDeviceAttributeMap) { + absl::flat_hash_map pjrt_map({ + {"string", xla::PjRtValueType(std::string("value"))}, + {"bool", xla::PjRtValueType(true)}, + {"int64", xla::PjRtValueType(int64_t{123})}, + {"int64_list", + xla::PjRtValueType(std::vector({int64_t{1}, int64_t{2}}))}, + {"float", xla::PjRtValueType(1.23f)}, + }); + + EXPECT_EQ(FromPjRtDeviceAttributeMap(pjrt_map).map(), + AttributeMap::Map({ + {"string", AttributeMap::StringValue("value")}, + {"bool", AttributeMap::BoolValue(true)}, + {"int64", AttributeMap::Int64Value(123)}, + {"int64_list", + AttributeMap::Int64ListValue({int64_t{1}, int64_t{2}})}, + {"float", AttributeMap::FloatValue(1.23f)}, + })); +} + +TEST(PjRtAttributeMapUtilTest, ToPjRtDeviceAttributeMap) { + AttributeMap map({ + {"string", AttributeMap::StringValue("value")}, + {"bool", AttributeMap::BoolValue(true)}, + {"int64", AttributeMap::Int64Value(123)}, + {"int64_list", AttributeMap::Int64ListValue({int64_t{1}, int64_t{2}})}, + {"float", AttributeMap::FloatValue(1.23f)}, + }); + + EXPECT_EQ( + ToPjRtDeviceAttributeMap(map), + (absl::flat_hash_map({ + {"string", xla::PjRtValueType(std::string("value"))}, + {"bool", xla::PjRtValueType(true)}, + {"int64", xla::PjRtValueType(int64_t{123})}, + {"int64_list", + xla::PjRtValueType(std::vector({int64_t{1}, int64_t{2}}))}, + {"float", xla::PjRtValueType(1.23f)}, + }))); +} + +} // namespace +} // namespace ifrt +} // namespace xla From c02a0c5ce19dd441a85417c840fd8a689a17fd27 Mon Sep 17 00:00:00 2001 From: Kyle Lucke Date: Thu, 20 Jun 2024 13:13:08 -0700 Subject: [PATCH 077/256] Stop using xla/statusor.h 
now that it just contains an alias for absl::Status. In some situations, this meant also changing unrelated files to directly include tsl/platform/statusor.h to get the definitions for TF_ASSIGN_OR_RETURN, etc., where they were getting transitively included for free. PiperOrigin-RevId: 645133070 --- third_party/xla/xla/client/BUILD | 16 ++++++++-------- third_party/xla/xla/client/client.h | 2 +- third_party/xla/xla/client/client_library.h | 2 +- third_party/xla/xla/client/compile_only_client.h | 2 +- .../xla/xla/client/executable_build_options.cc | 2 +- third_party/xla/xla/client/local_client.h | 2 +- third_party/xla/xla/client/padding.h | 2 +- third_party/xla/xla/client/value_inference.cc | 2 +- third_party/xla/xla/client/xla_builder.h | 2 +- third_party/xla/xla/pjrt/gpu/BUILD | 6 ++---- third_party/xla/xla/pjrt/gpu/gpu_helpers.cc | 2 +- third_party/xla/xla/pjrt/gpu/gpu_helpers.h | 2 +- third_party/xla/xla/pjrt/gpu/nccl_id_store.cc | 2 +- third_party/xla/xla/pjrt/gpu/nccl_id_store.h | 2 +- .../xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc | 1 - .../xla/xla/pjrt/gpu/se_gpu_pjrt_client.h | 1 - .../xla/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc | 1 - third_party/xla/xla/service/llvm_ir/BUILD | 9 +++++---- .../xla/xla/service/llvm_ir/fused_ir_emitter.cc | 2 +- .../xla/xla/service/llvm_ir/fused_ir_emitter.h | 2 +- third_party/xla/xla/service/llvm_ir/ir_array.cc | 2 +- third_party/xla/xla/service/llvm_ir/llvm_util.cc | 2 +- third_party/xla/xla/service/llvm_ir/llvm_util.h | 2 +- .../xla/xla/service/llvm_ir/loop_emitter.cc | 1 + .../xla/xla/service/llvm_ir/loop_emitter.h | 2 +- 25 files changed, 34 insertions(+), 37 deletions(-) diff --git a/third_party/xla/xla/client/BUILD b/third_party/xla/xla/client/BUILD index 400eb877cbbce1..1e7f1bf6615628 100644 --- a/third_party/xla/xla/client/BUILD +++ b/third_party/xla/xla/client/BUILD @@ -46,10 +46,10 @@ cc_library( srcs = ["padding.cc"], hdrs = ["padding.h"], deps = [ - "//xla:statusor", "//xla:types", "//xla:util", 
"@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/types:span", "@local_tsl//tsl/lib/math:math_util", "@local_tsl//tsl/platform:logging", @@ -77,13 +77,13 @@ cc_library( "//xla:execution_options_util", "//xla:literal", "//xla:status_macros", - "//xla:statusor", "//xla:types", "//xla:xla_data_proto_cc", "//xla:xla_proto_cc", "//xla/service", "//xla/service:hlo_proto_cc", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:errors", @@ -100,7 +100,6 @@ cc_library( "//xla:debug_options_flags", "//xla:execution_options_util", "//xla:shape_util", - "//xla:statusor", "//xla:util", "//xla:xla_proto_cc", "//xla/pjrt:compile_options_proto_cc", @@ -110,6 +109,7 @@ cc_library( "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:errors", @@ -128,7 +128,6 @@ cc_library( ":xla_computation", "//xla:executable_run_options", "//xla:shape_tree", - "//xla:statusor", "//xla:xla_data_proto_cc", "//xla/service:backend", "//xla/service:compiler", @@ -143,6 +142,7 @@ cc_library( "//xla/stream_executor", "//xla/stream_executor:device_memory_allocator", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/types:span", ], ) @@ -155,11 +155,11 @@ cc_library( ":client", ":xla_computation", "//xla:status_macros", - "//xla:statusor", "//xla:xla_data_proto_cc", "//xla/service:compile_only_service", "//xla/service:compiler", "//xla/stream_executor", + "@com_google_absl//absl/status:statusor", "@llvm-project//llvm:Support", "@llvm-project//llvm:TargetParser", ], @@ -175,7 +175,6 @@ cc_library( ":compile_only_client", ":local_client", "//xla:status_macros", - 
"//xla:statusor", "//xla:types", "//xla:util", "//xla/service:backend", @@ -185,6 +184,7 @@ cc_library( "//xla/stream_executor", "//xla/stream_executor:device_memory_allocator", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status:statusor", "@local_tsl//tsl/platform:logging", ], ) @@ -229,13 +229,13 @@ cc_library( "//xla:literal_util", "//xla:shape_util", "//xla:status_macros", - "//xla:statusor", "//xla:util", "//xla:xla_data_proto_cc", "//xla/hlo/evaluator:hlo_evaluator", "//xla/hlo/ir:hlo", "//xla/service:hlo_proto_cc", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:errors", @@ -263,7 +263,6 @@ cc_library( "//xla:shape_util", "//xla:sharding_op_util", "//xla:status_macros", - "//xla:statusor", "//xla:util", "//xla:window_util", "//xla:xla_data_proto_cc", @@ -278,6 +277,7 @@ cc_library( "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@local_tsl//tsl/lib/core:bitmap", diff --git a/third_party/xla/xla/client/client.h b/third_party/xla/xla/client/client.h index 1ecfcfe6f358eb..49c9911f5cdf1b 100644 --- a/third_party/xla/xla/client/client.h +++ b/third_party/xla/xla/client/client.h @@ -21,12 +21,12 @@ limitations under the License. 
#include #include +#include "absl/status/statusor.h" #include "absl/types/span.h" #include "xla/client/xla_computation.h" #include "xla/literal.h" #include "xla/service/hlo.pb.h" #include "xla/service/service.h" -#include "xla/statusor.h" #include "xla/types.h" #include "xla/xla.pb.h" #include "xla/xla_data.pb.h" diff --git a/third_party/xla/xla/client/client_library.h b/third_party/xla/xla/client/client_library.h index bb84e521b3b5ce..db867329be71ad 100644 --- a/third_party/xla/xla/client/client_library.h +++ b/third_party/xla/xla/client/client_library.h @@ -29,11 +29,11 @@ limitations under the License. #include #include "absl/container/flat_hash_map.h" +#include "absl/status/statusor.h" #include "xla/client/compile_only_client.h" #include "xla/client/local_client.h" #include "xla/service/compile_only_service.h" #include "xla/service/local_service.h" -#include "xla/statusor.h" #include "xla/stream_executor/device_memory_allocator.h" #include "xla/stream_executor/stream_executor.h" #include "xla/types.h" diff --git a/third_party/xla/xla/client/compile_only_client.h b/third_party/xla/xla/client/compile_only_client.h index fac2fbd98932a1..8dde8c884dd78c 100644 --- a/third_party/xla/xla/client/compile_only_client.h +++ b/third_party/xla/xla/client/compile_only_client.h @@ -19,11 +19,11 @@ limitations under the License. #include #include +#include "absl/status/statusor.h" #include "xla/client/client.h" #include "xla/client/xla_computation.h" #include "xla/service/compile_only_service.h" #include "xla/service/compiler.h" -#include "xla/statusor.h" #include "xla/stream_executor/stream_executor.h" #include "xla/xla_data.pb.h" diff --git a/third_party/xla/xla/client/executable_build_options.cc b/third_party/xla/xla/client/executable_build_options.cc index b746e965eca9e0..77c7791d151bbe 100644 --- a/third_party/xla/xla/client/executable_build_options.cc +++ b/third_party/xla/xla/client/executable_build_options.cc @@ -22,6 +22,7 @@ limitations under the License. 
#include #include "absl/log/check.h" +#include "absl/status/statusor.h" #include "absl/strings/str_format.h" #include "xla/debug_options_flags.h" #include "xla/execution_options_util.h" @@ -30,7 +31,6 @@ limitations under the License. #include "xla/service/computation_placer.h" #include "xla/shape.h" #include "xla/shape_util.h" -#include "xla/statusor.h" #include "xla/util.h" #include "xla/xla.pb.h" #include "tsl/platform/errors.h" diff --git a/third_party/xla/xla/client/local_client.h b/third_party/xla/xla/client/local_client.h index f26c67ced132c8..236ebe0bfb2f3c 100644 --- a/third_party/xla/xla/client/local_client.h +++ b/third_party/xla/xla/client/local_client.h @@ -21,6 +21,7 @@ limitations under the License. #include #include +#include "absl/status/statusor.h" #include "absl/types/span.h" #include "xla/client/client.h" #include "xla/client/executable_build_options.h" @@ -32,7 +33,6 @@ limitations under the License. #include "xla/service/maybe_owning_device_memory.h" #include "xla/service/shaped_buffer.h" #include "xla/shape_tree.h" -#include "xla/statusor.h" #include "xla/stream_executor/device_memory_allocator.h" #include "xla/stream_executor/stream_executor.h" #include "xla/xla_data.pb.h" diff --git a/third_party/xla/xla/client/padding.h b/third_party/xla/xla/client/padding.h index 50ed6e58057ad8..e71522616bf1ab 100644 --- a/third_party/xla/xla/client/padding.h +++ b/third_party/xla/xla/client/padding.h @@ -19,8 +19,8 @@ limitations under the License. #include #include +#include "absl/status/statusor.h" #include "absl/types/span.h" -#include "xla/statusor.h" #include "xla/types.h" namespace xla { diff --git a/third_party/xla/xla/client/value_inference.cc b/third_party/xla/xla/client/value_inference.cc index 66cdf55be98a7f..1ba694ad6154c9 100644 --- a/third_party/xla/xla/client/value_inference.cc +++ b/third_party/xla/xla/client/value_inference.cc @@ -22,6 +22,7 @@ limitations under the License. 
#include #include "absl/container/flat_hash_map.h" +#include "absl/status/statusor.h" #include "absl/strings/str_format.h" #include "absl/types/span.h" #include "xla/comparison_util.h" @@ -34,7 +35,6 @@ limitations under the License. #include "xla/service/hlo.pb.h" #include "xla/shape_util.h" #include "xla/status_macros.h" -#include "xla/statusor.h" #include "xla/util.h" #include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" diff --git a/third_party/xla/xla/client/xla_builder.h b/third_party/xla/xla/client/xla_builder.h index 0e7ba2d23f49ec..c1192bf716a0d3 100644 --- a/third_party/xla/xla/client/xla_builder.h +++ b/third_party/xla/xla/client/xla_builder.h @@ -33,6 +33,7 @@ limitations under the License. #include "absl/functional/function_ref.h" #include "absl/log/check.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/array.h" @@ -51,7 +52,6 @@ limitations under the License. #include "xla/service/hlo.pb.h" #include "xla/shape.h" #include "xla/shape_util.h" -#include "xla/statusor.h" #include "xla/util.h" #include "xla/xla_data.pb.h" #include "tsl/lib/core/bitmap.h" diff --git a/third_party/xla/xla/pjrt/gpu/BUILD b/third_party/xla/xla/pjrt/gpu/BUILD index 24d90e89faba2c..217833f0fdcdde 100644 --- a/third_party/xla/xla/pjrt/gpu/BUILD +++ b/third_party/xla/xla/pjrt/gpu/BUILD @@ -19,7 +19,6 @@ cc_library( hdrs = ["gpu_helpers.h"], visibility = internal_visibility(["//xla/pjrt:friends"]), deps = [ - "//xla:statusor", "//xla:types", "//xla:util", "//xla/client:client_library", @@ -30,6 +29,7 @@ cc_library( "//xla/tsl/framework:bfc_allocator", "//xla/tsl/framework:device_id_impl", "//xla/tsl/util:env_var", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/types:span", ], ) @@ -47,7 +47,6 @@ cc_library( "//xla:literal", "//xla:shape_util", "//xla:status_macros", - "//xla:statusor", "//xla:util", "//xla:xla_proto_cc", "//xla/client:client_library", 
@@ -145,7 +144,6 @@ xla_cc_test( "//xla:literal_util", "//xla:shape_util", "//xla:status_macros", - "//xla:statusor", "//xla:test", "//xla:xla_data_proto_cc", "//xla/client:xla_computation", @@ -184,13 +182,13 @@ cc_library( hdrs = ["nccl_id_store.h"], deps = [ "//xla:status_macros", - "//xla:statusor", "//xla/pjrt/distributed:key_value_store_interface", "//xla/service:global_device_id", "//xla/service/gpu/runtime:nccl_api", "//xla/service/gpu/runtime:nccl_clique_key", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/time", "@local_tsl//tsl/platform:errors", diff --git a/third_party/xla/xla/pjrt/gpu/gpu_helpers.cc b/third_party/xla/xla/pjrt/gpu/gpu_helpers.cc index e33521addb4765..d0fcfe70c1e29d 100644 --- a/third_party/xla/xla/pjrt/gpu/gpu_helpers.cc +++ b/third_party/xla/xla/pjrt/gpu/gpu_helpers.cc @@ -23,9 +23,9 @@ limitations under the License. #include #include +#include "absl/status/statusor.h" #include "xla/client/client_library.h" #include "xla/service/platform_util.h" -#include "xla/statusor.h" #include "xla/stream_executor/integrations/device_host_allocator.h" #include "xla/stream_executor/integrations/device_mem_allocator.h" #include "xla/tsl/framework/device_id.h" diff --git a/third_party/xla/xla/pjrt/gpu/gpu_helpers.h b/third_party/xla/xla/pjrt/gpu/gpu_helpers.h index 450dca53402881..7657483d197341 100644 --- a/third_party/xla/xla/pjrt/gpu/gpu_helpers.h +++ b/third_party/xla/xla/pjrt/gpu/gpu_helpers.h @@ -22,9 +22,9 @@ limitations under the License. 
#include #include +#include "absl/status/statusor.h" #include "absl/types/span.h" #include "xla/client/local_client.h" -#include "xla/statusor.h" #include "xla/stream_executor/stream_executor.h" #include "xla/tsl/framework/bfc_allocator.h" #include "xla/types.h" diff --git a/third_party/xla/xla/pjrt/gpu/nccl_id_store.cc b/third_party/xla/xla/pjrt/gpu/nccl_id_store.cc index c61dbb748f40b8..2825f87d2f14c3 100644 --- a/third_party/xla/xla/pjrt/gpu/nccl_id_store.cc +++ b/third_party/xla/xla/pjrt/gpu/nccl_id_store.cc @@ -18,12 +18,12 @@ limitations under the License. #include #include +#include "absl/status/statusor.h" #include "absl/synchronization/mutex.h" #include "absl/time/time.h" #include "xla/service/gpu/runtime/nccl_api.h" #include "xla/service/gpu/runtime/nccl_clique_key.h" #include "xla/status_macros.h" -#include "xla/statusor.h" #include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" diff --git a/third_party/xla/xla/pjrt/gpu/nccl_id_store.h b/third_party/xla/xla/pjrt/gpu/nccl_id_store.h index 389775debd4caf..db02988fc7d88b 100644 --- a/third_party/xla/xla/pjrt/gpu/nccl_id_store.h +++ b/third_party/xla/xla/pjrt/gpu/nccl_id_store.h @@ -21,11 +21,11 @@ limitations under the License. #include "absl/base/thread_annotations.h" #include "absl/container/flat_hash_map.h" +#include "absl/status/statusor.h" #include "absl/synchronization/mutex.h" #include "xla/pjrt/distributed/key_value_store_interface.h" #include "xla/service/global_device_id.h" #include "xla/service/gpu/runtime/nccl_clique_key.h" -#include "xla/statusor.h" namespace xla { diff --git a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc index 6e40add20420af..4be7dc6deb7f66 100644 --- a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc +++ b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc @@ -108,7 +108,6 @@ limitations under the License. 
#endif #include "xla/service/gpu/gpu_executable_run_options.h" -#include "xla/statusor.h" #include "xla/stream_executor/integrations/device_mem_allocator.h" #include "xla/stream_executor/integrations/tf_allocator_adapter.h" #include "xla/util.h" diff --git a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.h b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.h index 2afadec5fdfc14..db351f143ff44e 100644 --- a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.h +++ b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.h @@ -48,7 +48,6 @@ limitations under the License. #include "xla/service/computation_placer.h" #include "xla/service/gpu/gpu_executable_run_options.h" #include "xla/shape.h" -#include "xla/statusor.h" #include "xla/stream_executor/device_memory_allocator.h" #include "xla/tsl/framework/allocator.h" #include "tsl/platform/casts.h" diff --git a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc index c95372e1e60468..40a0c1115074f9 100644 --- a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc +++ b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc @@ -51,7 +51,6 @@ limitations under the License. 
#include "xla/shape.h" #include "xla/shape_util.h" #include "xla/status_macros.h" -#include "xla/statusor.h" #include "xla/stream_executor/stream.h" #include "xla/test.h" #include "xla/tests/literal_test_util.h" diff --git a/third_party/xla/xla/service/llvm_ir/BUILD b/third_party/xla/xla/service/llvm_ir/BUILD index 88c927a856e477..d5645efd460ce4 100644 --- a/third_party/xla/xla/service/llvm_ir/BUILD +++ b/third_party/xla/xla/service/llvm_ir/BUILD @@ -71,7 +71,6 @@ cc_library( ":llvm_type_conversion_util", "//xla:literal", "//xla:shape_util", - "//xla:statusor", "//xla:util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", @@ -80,6 +79,7 @@ cc_library( "//xla/service/cpu:cpu_options", "@com_google_absl//absl/base", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@llvm-project//llvm:Core", @@ -124,11 +124,11 @@ cc_library( ":llvm_util", "//xla:permutation_util", "//xla:shape_util", - "//xla:statusor", "//xla:types", "//xla:util", "//xla:xla_data_proto_cc", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@llvm-project//llvm:Core", @@ -163,14 +163,15 @@ cc_library( ":llvm_loop", "//xla:shape_util", "//xla:status_macros", - "//xla:statusor", "//xla:types", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", "@llvm-project//llvm:Core", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:protobuf", + "@local_tsl//tsl/platform:statusor", ], ) @@ -184,12 +185,12 @@ cc_library( ":tuple_ops", "//xla:shape_util", "//xla:status_macros", - "//xla:statusor", "//xla:util", "//xla/hlo/ir:hlo", "//xla/service:elemental_ir_emitter", "//xla/service:fusion_node_indexing_evaluation", "@com_google_absl//absl/container:flat_hash_map", + 
"@com_google_absl//absl/status:statusor", "@llvm-project//llvm:Core", "@llvm-project//llvm:TargetParser", "@local_tsl//tsl/platform:errors", diff --git a/third_party/xla/xla/service/llvm_ir/fused_ir_emitter.cc b/third_party/xla/xla/service/llvm_ir/fused_ir_emitter.cc index f69cc88a14f711..16b10a0eefc977 100644 --- a/third_party/xla/xla/service/llvm_ir/fused_ir_emitter.cc +++ b/third_party/xla/xla/service/llvm_ir/fused_ir_emitter.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "absl/status/statusor.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Module.h" @@ -35,7 +36,6 @@ limitations under the License. #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/status_macros.h" -#include "xla/statusor.h" #include "xla/util.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" diff --git a/third_party/xla/xla/service/llvm_ir/fused_ir_emitter.h b/third_party/xla/xla/service/llvm_ir/fused_ir_emitter.h index 098666d890eb96..e3e5c8204f820f 100644 --- a/third_party/xla/xla/service/llvm_ir/fused_ir_emitter.h +++ b/third_party/xla/xla/service/llvm_ir/fused_ir_emitter.h @@ -19,10 +19,10 @@ limitations under the License. #include #include "absl/container/flat_hash_map.h" +#include "absl/status/statusor.h" #include "llvm/IR/Value.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/service/elemental_ir_emitter.h" -#include "xla/statusor.h" namespace xla { diff --git a/third_party/xla/xla/service/llvm_ir/ir_array.cc b/third_party/xla/xla/service/llvm_ir/ir_array.cc index 62ec82249b572e..38833e64526b57 100644 --- a/third_party/xla/xla/service/llvm_ir/ir_array.cc +++ b/third_party/xla/xla/service/llvm_ir/ir_array.cc @@ -22,6 +22,7 @@ limitations under the License. #include #include +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "llvm/IR/Constants.h" #include "llvm/IR/Instructions.h" @@ -33,7 +34,6 @@ limitations under the License. 
#include "xla/service/llvm_ir/llvm_type_conversion_util.h" #include "xla/service/llvm_ir/llvm_util.h" #include "xla/shape_util.h" -#include "xla/statusor.h" #include "xla/util.h" #include "xla/xla_data.pb.h" #include "tsl/platform/logging.h" diff --git a/third_party/xla/xla/service/llvm_ir/llvm_util.cc b/third_party/xla/xla/service/llvm_ir/llvm_util.cc index 1992f0dea01dae..a691aee395d46a 100644 --- a/third_party/xla/xla/service/llvm_ir/llvm_util.cc +++ b/third_party/xla/xla/service/llvm_ir/llvm_util.cc @@ -27,6 +27,7 @@ limitations under the License. #include "absl/base/casts.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" @@ -68,7 +69,6 @@ limitations under the License. #include "xla/service/llvm_ir/llvm_type_conversion_util.h" #include "xla/shape.h" #include "xla/shape_util.h" -#include "xla/statusor.h" #include "xla/util.h" #include "xla/xla_data.pb.h" #include "tsl/platform/byte_order.h" diff --git a/third_party/xla/xla/service/llvm_ir/llvm_util.h b/third_party/xla/xla/service/llvm_ir/llvm_util.h index fab4e6d9e90b4c..b22f171f1fa4d5 100644 --- a/third_party/xla/xla/service/llvm_ir/llvm_util.h +++ b/third_party/xla/xla/service/llvm_ir/llvm_util.h @@ -22,6 +22,7 @@ limitations under the License. #include #include +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "llvm/IR/BasicBlock.h" @@ -42,7 +43,6 @@ limitations under the License. 
#include "xla/literal.h" #include "xla/service/hlo_module_config.h" #include "xla/shape.h" -#include "xla/statusor.h" #include "xla/xla_data.pb.h" namespace llvm { diff --git a/third_party/xla/xla/service/llvm_ir/loop_emitter.cc b/third_party/xla/xla/service/llvm_ir/loop_emitter.cc index 260fbc228250cb..17ae97c578b0be 100644 --- a/third_party/xla/xla/service/llvm_ir/loop_emitter.cc +++ b/third_party/xla/xla/service/llvm_ir/loop_emitter.cc @@ -27,6 +27,7 @@ limitations under the License. #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" #include "tsl/platform/protobuf.h" +#include "tsl/platform/statusor.h" namespace xla { namespace llvm_ir { diff --git a/third_party/xla/xla/service/llvm_ir/loop_emitter.h b/third_party/xla/xla/service/llvm_ir/loop_emitter.h index 3182de779f9987..40c6ee6e8c36f0 100644 --- a/third_party/xla/xla/service/llvm_ir/loop_emitter.h +++ b/third_party/xla/xla/service/llvm_ir/loop_emitter.h @@ -18,12 +18,12 @@ limitations under the License. #include +#include "absl/status/statusor.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Value.h" #include "xla/service/llvm_ir/ir_array.h" #include "xla/service/llvm_ir/llvm_loop.h" -#include "xla/statusor.h" namespace xla { namespace llvm_ir { From 1789531f168db0ba3d342b67104bb89b7e05394f Mon Sep 17 00:00:00 2001 From: Bixia Zheng Date: Thu, 20 Jun 2024 13:28:59 -0700 Subject: [PATCH 078/256] [xla:hlo][NFC] Fix dims in a comment to match the size of reshape_dims. 
PiperOrigin-RevId: 645137856 --- third_party/xla/xla/hlo/ir/tile_assignment.h | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/third_party/xla/xla/hlo/ir/tile_assignment.h b/third_party/xla/xla/hlo/ir/tile_assignment.h index ca7f23bda2dc14..ac277e9f4dddb6 100644 --- a/third_party/xla/xla/hlo/ir/tile_assignment.h +++ b/third_party/xla/xla/hlo/ir/tile_assignment.h @@ -53,11 +53,11 @@ class IotaTileAssignment { // `reshape_dims`: is the dimensions the 1D iota array is reshaped to. // `transpose_perm`: is the dimension permutation to transpose `reshape_dims`. // - // e.g. dims=[8,8,8] reshape_dims=[4,2,2], transpose_perm=[0,1,2] (no - // transpose) corresponds to [8,8,8]<=[16] which in full array V1 format is + // e.g. dims=[4,4,1] reshape_dims=[4,2,2], transpose_perm=[0,1,2] (no + // transpose) corresponds to [4,4,1]<=[16] which in full array V1 format is // [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15]. - // e.g. dims=[8,8,8] reshape_dims=[4,2,2], transpose_perm=[1,0,2] (swap dim 0 - // and dim 1) corresponds to [8,8,8]<=[4,2,2]T(1,0,2) which in full array V1 + // e.g. dims=[4,4,1] reshape_dims=[4,2,2], transpose_perm=[1,0,2] (swap dim 0 + // and dim 1) corresponds to [4,4,1]<=[4,2,2]T(1,0,2) which in full array V1 // format is [0,1,4,5,8,9,12,13,2,3,6,7,10,11,14,15]. 
static IotaTileAssignment Create(absl::Span dims, absl::Span reshape_dims, From 79ead0e9141aa888fcc2a847b0b26bad8474f715 Mon Sep 17 00:00:00 2001 From: Kevin Gleason Date: Thu, 20 Jun 2024 13:37:02 -0700 Subject: [PATCH 079/256] [mhlo] Remove UnaryEinsumOp from MHLO PiperOrigin-RevId: 645140478 --- .../xla/mlir_hlo/mhlo/IR/chlo_canonicalize.td | 30 ------------------ .../xla/xla/mlir_hlo/mhlo/IR/hlo_ops.cc | 9 ------ .../xla/xla/mlir_hlo/mhlo/IR/hlo_ops.td | 31 ------------------- .../xla/mlir_hlo/mhlo/IR/mhlo_canonicalize.td | 7 ----- .../hlo_legalize_to_stablehlo.cc | 26 +++++++++++++--- .../mhlo/transforms/map_stablehlo_to_hlo_op.h | 2 +- .../stablehlo_legalize_to_hlo.cc | 14 +++++++++ .../mhlo/canonicalize/canonicalize.mlir | 8 ----- .../mhlo/hlo-legalize-to-stablehlo.mlir | 11 ------- .../mhlo/stablehlo-legalize-to-hlo.mlir | 22 ++++++------- .../translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc | 6 ---- 11 files changed, 47 insertions(+), 119 deletions(-) delete mode 100644 third_party/xla/xla/mlir_hlo/mhlo/IR/chlo_canonicalize.td diff --git a/third_party/xla/xla/mlir_hlo/mhlo/IR/chlo_canonicalize.td b/third_party/xla/xla/mlir_hlo/mhlo/IR/chlo_canonicalize.td deleted file mode 100644 index 2db08da7e5f2ee..00000000000000 --- a/third_party/xla/xla/mlir_hlo/mhlo/IR/chlo_canonicalize.td +++ /dev/null @@ -1,30 +0,0 @@ -/* Copyright 2019 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -// This is the canonicalize pattern definition file. - -include "mlir/IR/OpBase.td" -include "mhlo/IR/hlo_ops.td" -include "mhlo/IR/hlo_utils.td" - -def UnaryToBinaryEinsumEq : NativeCodeCall< - "$_builder.getStringAttr(\",\" + $0.getValue().str())">; - -// Convert UnaryEinsumOp to EinsumOp with two operands with redundant first -// operand. -def UnaryEinsumToEinsum : Pat< - (MHLO_UnaryEinsumOp $operand, $equation), - (MHLO_EinsumOp (MHLO_ConstantOp (GetScalarOfType<1> $operand)), - $operand, (UnaryToBinaryEinsumEq $equation))>; diff --git a/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops.cc b/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops.cc index a5bd1fb60cc525..740c32e4f4e858 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops.cc +++ b/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops.cc @@ -5559,15 +5559,6 @@ LogicalResult TupleOp::inferReturnTypes( inferredReturnTypes); } -//===----------------------------------------------------------------------===// -// UnaryEinsumOp -//===----------------------------------------------------------------------===// - -void UnaryEinsumOp::getCanonicalizationPatterns(RewritePatternSet& results, - MLIRContext* context) { - results.add(context); -} - //===----------------------------------------------------------------------===// // CompareOp //===----------------------------------------------------------------------===// diff --git a/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops.td b/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops.td index 35eff2edbb202f..e945cc95d4f149 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops.td +++ b/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops.td @@ -2687,37 +2687,6 @@ def MHLO_EinsumOp: MHLO_Op<"einsum", [Pure]> { // side HLO ops. 
} -def MHLO_UnaryEinsumOp: MHLO_Op<"unary_einsum", [Pure]> { - let summary = "UnaryEinsum operation"; - let description = [{ - This operation is on its way out of StableHLO, so it is not included in - the specification: https://github.com/openxla/stablehlo/issues/3. - - Informally, this operation does the same thing as TF's einsum: - https://www.tensorflow.org/api_docs/python/tf/einsum - - Example: - ```mlir - %result = "mhlo.unary_einsum"(%operand) { - einsum_config = "ab->a" - } : (tensor<4x16xf32>) -> tensor<4xf32> - ``` - }]; - - let arguments = (ins - MHLO_Tensor:$operand, - StrAttr:$einsum_config - ); - - let results = (outs MHLO_Tensor); - - let hasCanonicalizer = 1; - - // UnaryEinsumOp is unconditionally canonicalized to the binary EinsumOp so - // the HLO converter shouldn't be invoked. - let hasCustomHLOConverter = 1; -} - def MHLO_FftOp: MHLO_Op<"fft", [InferTensorType, Pure]> { let summary = "Fft operation"; let description = [{ diff --git a/third_party/xla/xla/mlir_hlo/mhlo/IR/mhlo_canonicalize.td b/third_party/xla/xla/mlir_hlo/mhlo/IR/mhlo_canonicalize.td index e737d8cbf169f0..0caa5c0a0e23e3 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/IR/mhlo_canonicalize.td +++ b/third_party/xla/xla/mlir_hlo/mhlo/IR/mhlo_canonicalize.td @@ -25,13 +25,6 @@ def UnaryToBinaryEinsumEq : NativeCodeCall< def GetI64DenseElementsAttr : NativeCodeCall< "$0.mapValues($_builder.getI64Type(), [](llvm::APInt x) { return x.sext(64); })">; -// Convert UnaryEinsumOp to EinsumOp with two operands with redundant first -// operand. -def UnaryEinsumToEinsum : Pat< - (MHLO_UnaryEinsumOp $operand, $equation), - (MHLO_EinsumOp (MHLO_ConstantOp (GetScalarOfType<1> $operand)), - $operand, (UnaryToBinaryEinsumEq $equation))>; - // A dynamic reshape of a dynamic reshape is a dynamic reshape. 
def RemoveRedundantDynamicReshape : Pat< (MHLO_DynamicReshapeOp (MHLO_DynamicReshapeOp $operand, $shape1), $shape2), diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_stablehlo/hlo_legalize_to_stablehlo.cc b/third_party/xla/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_stablehlo/hlo_legalize_to_stablehlo.cc index da0897354f131a..96c599a8600adf 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_stablehlo/hlo_legalize_to_stablehlo.cc +++ b/third_party/xla/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_stablehlo/hlo_legalize_to_stablehlo.cc @@ -578,8 +578,11 @@ class HloToStablehloCustomCallOpConverter bool allowExperimentalFeatures; }; -template -class HloToStablehloOpConverter : public OpConversionPattern { +template +class HloToStablehloOpConverter + : public OpConversionPattern> { + using HloOpTy = StablehloToHloOp; + public: HloToStablehloOpConverter(TypeConverter& converter, MLIRContext* context, bool allowExperimentalFeatures) @@ -678,14 +681,27 @@ class HloToStablehloOpConverter : public OpConversionPattern { bool allowExperimentalFeatures; }; +// Deprecated ops. 
+template <> +class HloToStablehloOpConverter + : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(stablehlo::UnaryEinsumOp stablehloOp, + typename stablehlo::UnaryEinsumOp::Adaptor, + ConversionPatternRewriter&) const final { + return stablehloOp.emitError( + "UnaryEinsumOp is deprecated and not supported in MHLO"); + } +}; + template void populateHloToStablehloPatterns(RewritePatternSet* patterns, TypeConverter* converter, MLIRContext* context, bool allowExperimentalFeatures) { - patterns - ->add>...>( - *converter, context, allowExperimentalFeatures); + patterns->add...>( + *converter, context, allowExperimentalFeatures); } template diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/map_stablehlo_to_hlo_op.h b/third_party/xla/xla/mlir_hlo/mhlo/transforms/map_stablehlo_to_hlo_op.h index 80c6c0b0bd1ea7..39d6de77380ae8 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/map_stablehlo_to_hlo_op.h +++ b/third_party/xla/xla/mlir_hlo/mhlo/transforms/map_stablehlo_to_hlo_op.h @@ -153,7 +153,7 @@ MAP_STABLEHLO_TO_HLO(TorchIndexSelectOp) MAP_STABLEHLO_TO_HLO(TransposeOp) MAP_STABLEHLO_TO_HLO(TriangularSolveOp) MAP_STABLEHLO_TO_HLO(TupleOp) -MAP_STABLEHLO_TO_HLO(UnaryEinsumOp) +// (deprecated) MAP_STABLEHLO_TO_HLO(UnaryEinsumOp) MAP_STABLEHLO_TO_HLO(UniformDequantizeOp) MAP_STABLEHLO_TO_HLO(UniformQuantizeOp) MAP_STABLEHLO_TO_HLO(WhileOp) diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/stablehlo_legalize_to_hlo/stablehlo_legalize_to_hlo.cc b/third_party/xla/xla/mlir_hlo/mhlo/transforms/stablehlo_legalize_to_hlo/stablehlo_legalize_to_hlo.cc index 4d329a95c12c69..cc3cb691dbd364 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/stablehlo_legalize_to_hlo/stablehlo_legalize_to_hlo.cc +++ b/third_party/xla/xla/mlir_hlo/mhlo/transforms/stablehlo_legalize_to_hlo/stablehlo_legalize_to_hlo.cc @@ -390,6 +390,20 @@ class StablehloToHloOpConverter : public OpConversionPattern { } }; 
+// Deprecated ops. +template <> +class StablehloToHloOpConverter + : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(stablehlo::UnaryEinsumOp stablehloOp, + typename stablehlo::UnaryEinsumOp::Adaptor, + ConversionPatternRewriter&) const final { + return stablehloOp.emitError( + "UnaryEinsumOp is deprecated and not supported in MHLO"); + } +}; + template void populateStablehloToHloPatterns(RewritePatternSet* patterns, TypeConverter* converter, diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/canonicalize.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/canonicalize.mlir index 371ff439b06341..b34126213193df 100644 --- a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/canonicalize.mlir +++ b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/canonicalize.mlir @@ -973,14 +973,6 @@ func.func @iota_broadcast_second() -> tensor<5x4xi32> { func.return %0 : tensor<5x4xi32> } -// CHECK-LABEL: @unary_einsum -func.func @unary_einsum(%arg0: tensor<2x3xf32>) -> tensor<2x2xf32> { - // CHECK: %[[ONE:.*]] = mhlo.constant dense<1.000000e+00> : tensor - // CHECK: "mhlo.einsum"(%[[ONE]], %arg0) <{einsum_config = ",ab->aa"}> - %0 = "mhlo.unary_einsum"(%arg0) {einsum_config = "ab->aa"} : (tensor<2x3xf32>) -> tensor<2x2xf32> - func.return %0 : tensor<2x2xf32> -} - // CHECK-LABEL: func @fold_copy // CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]] func.func @fold_copy(%arg : tensor<1x4xf32>) -> tensor<1x4xf32> { diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo.mlir index 5e7b46629b3491..12facaef8bcc72 100644 --- a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo.mlir +++ b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo.mlir @@ -1655,17 +1655,6 @@ func.func @op_tuple(%arg0: tensor) -> tuple> { 
func.return %0 : tuple> } -// CHECK-LABEL: "op_unary_einsum" -func.func @op_unary_einsum(%arg0: tensor<8x16xf32>) -> tensor<8xf32> { - // CHECK: "stablehlo.unary_einsum"([[ARG0:%arg[0-9]+]]) <{ - // CHECK-SAME: einsum_config = "ab->a" - // CHECK-SAME: }> : (tensor<8x16xf32>) -> tensor<8xf32> - %0 = "mhlo.unary_einsum"(%arg0) { - einsum_config = "ab->a" - } : (tensor<8x16xf32>) -> tensor<8xf32> - func.return %0 : tensor<8xf32> -} - // CHECK-LABEL: "op_uniform_dequantize" func.func @op_uniform_dequantize(%arg0: tensor>) -> tensor { // CHECK: "stablehlo.uniform_dequantize"([[ARG0:%arg[0-9]+]]) : (tensor>) -> tensor diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/stablehlo-legalize-to-hlo.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/stablehlo-legalize-to-hlo.mlir index 073b8b8bf77fce..3e65235b9b5bee 100644 --- a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/stablehlo-legalize-to-hlo.mlir +++ b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/stablehlo-legalize-to-hlo.mlir @@ -1630,17 +1630,6 @@ func.func @op_tuple(%arg0: tensor) -> tuple> { func.return %0 : tuple> } -// CHECK-LABEL: "op_unary_einsum" -func.func @op_unary_einsum(%arg0: tensor<8x16xf32>) -> tensor<8xf32> { - // CHECK: "mhlo.unary_einsum"([[ARG0:%arg[0-9]+]]) <{ - // CHECK-SAME: einsum_config = "ab->a" - // CHECK-SAME: }> : (tensor<8x16xf32>) -> tensor<8xf32> - %0 = "stablehlo.unary_einsum"(%arg0) { - einsum_config = "ab->a" - } : (tensor<8x16xf32>) -> tensor<8xf32> - func.return %0 : tensor<8xf32> -} - // CHECK-LABEL: "op_uniform_dequantize" func.func @op_uniform_dequantize(%arg0: tensor>) -> tensor { // CHECK: "mhlo.uniform_dequantize"([[ARG0:%arg[0-9]+]]) : (tensor>) -> tensor @@ -1969,3 +1958,14 @@ func.func @op_topk_mhlo_v1(%arg0: tensor<5x10xf32>) -> (tensor<5x8xf32>, tensor< } : (tensor<5x10xf32>) -> (tensor<5x8xf32>, tensor<5x8xi32>) func.return %0#0, %0#1 : tensor<5x8xf32>, tensor<5x8xi32> } + +// ----- + +func.func @op_unary_einsum_deprecated(%arg0: tensor<8x16xf32>) -> 
tensor<8xf32> { + // expected-error@+2 {{failed to legalize operation 'stablehlo.unary_einsum' that was explicitly marked illegal}} + // expected-error@+1 {{UnaryEinsumOp is deprecated and not supported in MHLO}} + %0 = "stablehlo.unary_einsum"(%arg0) { + einsum_config = "ab->a" + } : (tensor<8x16xf32>) -> tensor<8xf32> + func.return %0 : tensor<8xf32> +} diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc b/third_party/xla/xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc index 4a6ad69d70428e..0178e589ca5733 100644 --- a/third_party/xla/xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc +++ b/third_party/xla/xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc @@ -2707,12 +2707,6 @@ LogicalResult ExportXlaOp(TraceOp op, OpLoweringContext ctx) { return success(); } -LogicalResult ExportXlaOp(UnaryEinsumOp op, OpLoweringContext ctx) { - // Intentional as UnaryEinsumOp is always lowered to the EinsumOp with two - // operands. - return failure(); -} - LogicalResult ExportXlaOp(WhileOp op, OpLoweringContext ctx) { xla::XlaComputation condition; xla::XlaComputation body; From 5f2a16c90fcbd0f44602ca9d38fa9b55ca049655 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 20 Jun 2024 13:43:52 -0700 Subject: [PATCH 080/256] Update the `curl` dependency: 8.4.0 -> 8.6.0. Due to security vulnerabilities CVE-2023-46219 and CVE-2023-46218. 
Fixes https://github.com/tensorflow/tensorflow/issues/69799 PiperOrigin-RevId: 645142519 --- tensorflow/workspace2.bzl | 6 +++--- third_party/curl.BUILD | 1 + third_party/xla/third_party/tsl/third_party/curl.BUILD | 1 + third_party/xla/third_party/tsl/workspace2.bzl | 6 +++--- 4 files changed, 8 insertions(+), 6 deletions(-) diff --git a/tensorflow/workspace2.bzl b/tensorflow/workspace2.bzl index 249ab7b451957b..cd04eaba8ed78c 100644 --- a/tensorflow/workspace2.bzl +++ b/tensorflow/workspace2.bzl @@ -419,10 +419,10 @@ def _tf_repositories(): tf_http_archive( name = "curl", build_file = "//third_party:curl.BUILD", - sha256 = "816e41809c043ff285e8c0f06a75a1fa250211bbfb2dc0a037eeef39f1a9e427", - strip_prefix = "curl-8.4.0", + sha256 = "9c6db808160015f30f3c656c0dec125feb9dc00753596bf858a272b5dd8dc398", + strip_prefix = "curl-8.6.0", system_build_file = "//third_party/systemlibs:curl.BUILD", - urls = tf_mirror_urls("https://curl.se/download/curl-8.4.0.tar.gz"), + urls = tf_mirror_urls("https://curl.se/download/curl-8.6.0.tar.gz"), ) # WARNING: make sure ncteisen@ and vpai@ are cc-ed on any CL to change the below rule diff --git a/third_party/curl.BUILD b/third_party/curl.BUILD index a58b18c73bd8a5..742dea2407fe2a 100644 --- a/third_party/curl.BUILD +++ b/third_party/curl.BUILD @@ -311,6 +311,7 @@ cc_library( "lib/vquic/curl_msh3.h", "lib/vquic/curl_ngtcp2.c", "lib/vquic/curl_ngtcp2.h", + "lib/vquic/curl_osslq.h", "lib/vquic/curl_quiche.c", "lib/vquic/curl_quiche.h", "lib/vquic/vquic.c", diff --git a/third_party/xla/third_party/tsl/third_party/curl.BUILD b/third_party/xla/third_party/tsl/third_party/curl.BUILD index a58b18c73bd8a5..742dea2407fe2a 100644 --- a/third_party/xla/third_party/tsl/third_party/curl.BUILD +++ b/third_party/xla/third_party/tsl/third_party/curl.BUILD @@ -311,6 +311,7 @@ cc_library( "lib/vquic/curl_msh3.h", "lib/vquic/curl_ngtcp2.c", "lib/vquic/curl_ngtcp2.h", + "lib/vquic/curl_osslq.h", "lib/vquic/curl_quiche.c", "lib/vquic/curl_quiche.h", 
"lib/vquic/vquic.c", diff --git a/third_party/xla/third_party/tsl/workspace2.bzl b/third_party/xla/third_party/tsl/workspace2.bzl index 4bd2fd2e1c158c..b583f3c0eec7be 100644 --- a/third_party/xla/third_party/tsl/workspace2.bzl +++ b/third_party/xla/third_party/tsl/workspace2.bzl @@ -327,10 +327,10 @@ def _tf_repositories(): tf_http_archive( name = "curl", build_file = "//third_party:curl.BUILD", - sha256 = "816e41809c043ff285e8c0f06a75a1fa250211bbfb2dc0a037eeef39f1a9e427", - strip_prefix = "curl-8.4.0", + sha256 = "9c6db808160015f30f3c656c0dec125feb9dc00753596bf858a272b5dd8dc398", + strip_prefix = "curl-8.6.0", system_build_file = "//third_party/systemlibs:curl.BUILD", - urls = tf_mirror_urls("https://curl.se/download/curl-8.4.0.tar.gz"), + urls = tf_mirror_urls("https://curl.se/download/curl-8.6.0.tar.gz"), ) # WARNING: make sure ncteisen@ and vpai@ are cc-ed on any CL to change the below rule From f34a29f60a7013aad4e366fddf24894acf8a3037 Mon Sep 17 00:00:00 2001 From: Penporn Koanantakool Date: Thu, 20 Jun 2024 13:49:30 -0700 Subject: [PATCH 081/256] [xla:gpu] Rename collective_ops_test_e2e to conform with Google's test naming style. PiperOrigin-RevId: 645144203 --- third_party/xla/xla/tests/BUILD | 4 ++-- ...{collective_ops_test_e2e.cc => collective_ops_e2e_test.cc} | 0 2 files changed, 2 insertions(+), 2 deletions(-) rename third_party/xla/xla/tests/{collective_ops_test_e2e.cc => collective_ops_e2e_test.cc} (100%) diff --git a/third_party/xla/xla/tests/BUILD b/third_party/xla/xla/tests/BUILD index bf95343ebd314d..d31713c3242e15 100644 --- a/third_party/xla/xla/tests/BUILD +++ b/third_party/xla/xla/tests/BUILD @@ -2186,8 +2186,8 @@ xla_test( ) xla_test( - name = "collective_ops_test_e2e", - srcs = ["collective_ops_test_e2e.cc"], + name = "collective_ops_e2e_test", + srcs = ["collective_ops_e2e_test.cc"], backend_tags = { # This test is tagged "manual" because it requires multiple GPUs, and # Forge only supports single-GPU tests. 
Guitar skips "manual" tests diff --git a/third_party/xla/xla/tests/collective_ops_test_e2e.cc b/third_party/xla/xla/tests/collective_ops_e2e_test.cc similarity index 100% rename from third_party/xla/xla/tests/collective_ops_test_e2e.cc rename to third_party/xla/xla/tests/collective_ops_e2e_test.cc From 9b2e9eebfc01187fb20c89951ca2d386fad143ab Mon Sep 17 00:00:00 2001 From: Kyle Lucke Date: Thu, 20 Jun 2024 13:55:25 -0700 Subject: [PATCH 082/256] Move StreamExecutor::MemcpyDeviceToDevice processing into Stream and its derived classes. PiperOrigin-RevId: 645146097 --- .../stream_executor/stream_executor.cc | 15 --------------- .../stream_executor/stream_executor_internal.h | 10 +++++++++- .../xla/xla/backends/interpreter/executor.h | 6 ------ .../xla/xla/stream_executor/cuda/cuda_executor.cc | 9 --------- .../xla/xla/stream_executor/gpu/gpu_executor.h | 4 ---- .../xla/xla/stream_executor/gpu/gpu_stream.cc | 13 +++++++++++++ .../xla/xla/stream_executor/gpu/gpu_stream.h | 4 +--- .../xla/xla/stream_executor/host/host_executor.cc | 14 -------------- .../xla/xla/stream_executor/host/host_executor.h | 4 ---- .../xla/xla/stream_executor/host/host_stream.cc | 12 ++++++++++++ .../xla/xla/stream_executor/host/host_stream.h | 4 +--- .../xla/stream_executor/mock_stream_executor.h | 4 ---- .../xla/xla/stream_executor/rocm/rocm_executor.cc | 9 --------- third_party/xla/xla/stream_executor/stream.h | 8 ++++++-- .../xla/xla/stream_executor/stream_common.cc | 9 --------- .../xla/xla/stream_executor/stream_common.h | 2 -- .../xla/xla/stream_executor/stream_executor.h | 9 --------- .../xla/xla/stream_executor/tpu/tpu_executor.cc | 6 ------ .../xla/xla/stream_executor/tpu/tpu_executor.h | 4 ---- .../xla/xla/stream_executor/tpu/tpu_stream.h | 3 ++- 20 files changed, 44 insertions(+), 105 deletions(-) diff --git a/tensorflow/c/experimental/stream_executor/stream_executor.cc b/tensorflow/c/experimental/stream_executor/stream_executor.cc index 4ad6de2fce2993..ff7777e1005601 100644 --- 
a/tensorflow/c/experimental/stream_executor/stream_executor.cc +++ b/tensorflow/c/experimental/stream_executor/stream_executor.cc @@ -312,21 +312,6 @@ class CStreamExecutor : public StreamExecutorCommon { size, c_status.get()); return StatusFromTF_Status(c_status.get()); } - bool MemcpyDeviceToDevice(Stream* stream, DeviceMemoryBase* gpu_dst, - const DeviceMemoryBase& gpu_src, - uint64 size) override { - OwnedTFStatus c_status(TF_NewStatus()); - SP_Stream stream_handle = static_cast(stream)->Handle(); - SP_DeviceMemoryBase device_mem_dst = DeviceMemoryBaseToC(gpu_dst); - SP_DeviceMemoryBase device_mem_src = DeviceMemoryBaseToC(&gpu_src); - stream_executor_->memcpy_dtod(&device_, stream_handle, &device_mem_dst, - &device_mem_src, size, c_status.get()); - if (TF_GetCode(c_status.get()) != TF_OK) { - LOG(ERROR) << TF_Message(c_status.get()); - return false; - } - return true; - } bool HostCallback(Stream* stream, absl::AnyInvocable callback) override { SP_Stream stream_handle = static_cast(stream)->Handle(); diff --git a/tensorflow/c/experimental/stream_executor/stream_executor_internal.h b/tensorflow/c/experimental/stream_executor/stream_executor_internal.h index ffdcfe2a64bade..f61a084f96a63b 100644 --- a/tensorflow/c/experimental/stream_executor/stream_executor_internal.h +++ b/tensorflow/c/experimental/stream_executor/stream_executor_internal.h @@ -248,7 +248,15 @@ class CStream : public StreamCommon { } absl::Status Memcpy(DeviceMemoryBase* gpu_dst, const DeviceMemoryBase& gpu_src, uint64_t size) override { - return StreamCommon::Memcpy(gpu_dst, gpu_src, size); + tensorflow::TF_StatusPtr c_status(TF_NewStatus()); + SP_DeviceMemoryBase device_mem_dst = DeviceMemoryBaseToC(gpu_dst); + SP_DeviceMemoryBase device_mem_src = DeviceMemoryBaseToC(&gpu_src); + stream_executor_->memcpy_dtod(device_, stream_handle_, &device_mem_dst, + &device_mem_src, size, c_status.get()); + if (TF_GetCode(c_status.get()) != TF_OK) { + LOG(ERROR) << TF_Message(c_status.get()); + } + return 
tensorflow::StatusFromTF_Status(c_status.get()); } absl::Status Memcpy(void* host_dst, const DeviceMemoryBase& gpu_src, uint64_t size) override { diff --git a/third_party/xla/xla/backends/interpreter/executor.h b/third_party/xla/xla/backends/interpreter/executor.h index d23fa614048688..41c49f7bf1071b 100644 --- a/third_party/xla/xla/backends/interpreter/executor.h +++ b/third_party/xla/xla/backends/interpreter/executor.h @@ -90,12 +90,6 @@ class XlaInterpreterExecutor : public StreamExecutorCommon { delete[] static_cast(mem); } - bool MemcpyDeviceToDevice(Stream *stream, DeviceMemoryBase *pop_dst, - const DeviceMemoryBase &host_src, - uint64_t size) override { - return false; - } - absl::Status Memset(Stream *stream, DeviceMemoryBase *location, uint8_t pattern, uint64_t size) override { return absl::InternalError("Interpreter can not memset"); diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc b/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc index 129d5de237c795..7d978f043a42fd 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc @@ -681,15 +681,6 @@ absl::Status GpuExecutor::Memset(Stream* stream, DeviceMemoryBase* location, AsGpuStreamValue(stream)); } -bool GpuExecutor::MemcpyDeviceToDevice(Stream* stream, - DeviceMemoryBase* gpu_dst, - const DeviceMemoryBase& gpu_src, - uint64_t size) { - return GpuDriver::AsynchronousMemcpyD2D(context_, AsCudaDevicePtr(gpu_dst), - AsCudaDevicePtr(gpu_src), size, - AsGpuStreamValue(stream)); -} - bool GpuExecutor::HostCallback(Stream* stream, absl::AnyInvocable callback) { auto callback_ptr = diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_executor.h b/third_party/xla/xla/stream_executor/gpu/gpu_executor.h index 3d3d4a3eafd436..c0a21007e00b5d 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_executor.h +++ b/third_party/xla/xla/stream_executor/gpu/gpu_executor.h @@ -211,10 +211,6 @@ class GpuExecutor 
: public StreamExecutorCommon { absl::Status Memset(Stream* stream, DeviceMemoryBase* location, uint8_t pattern, uint64_t size) override; - bool MemcpyDeviceToDevice(Stream* stream, DeviceMemoryBase* gpu_dst, - const DeviceMemoryBase& gpu_src, - uint64_t size) override; - bool HostCallback(Stream* stream, absl::AnyInvocable callback) override; diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_stream.cc b/third_party/xla/xla/stream_executor/gpu/gpu_stream.cc index fd9415662d6c4a..450e6a51acb5eb 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_stream.cc +++ b/third_party/xla/xla/stream_executor/gpu/gpu_stream.cc @@ -77,6 +77,19 @@ absl::Status GpuStream::MemZero(DeviceMemoryBase* location, uint64_t size) { } } +absl::Status GpuStream::Memcpy(DeviceMemoryBase* gpu_dst, + const DeviceMemoryBase& gpu_src, uint64_t size) { + if (GpuDriver::AsynchronousMemcpyD2D( + parent_->gpu_context(), + reinterpret_cast(gpu_dst->opaque()), + reinterpret_cast(gpu_src.opaque()), size, + gpu_stream())) { + return absl::OkStatus(); + } + + return absl::InternalError("Failed to memcpy from device to device."); +} + absl::Status GpuStream::Memcpy(DeviceMemoryBase* gpu_dst, const void* host_src, uint64_t size) { bool ok = GpuDriver::AsynchronousMemcpyH2D( diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_stream.h b/third_party/xla/xla/stream_executor/gpu/gpu_stream.h index d56321c74436a7..c226913b85d601 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_stream.h +++ b/third_party/xla/xla/stream_executor/gpu/gpu_stream.h @@ -107,9 +107,7 @@ class GpuStream : public StreamCommon { absl::Status Memcpy(void* host_dst, const DeviceMemoryBase& gpu_src, uint64_t size) override; absl::Status Memcpy(DeviceMemoryBase* gpu_dst, - const DeviceMemoryBase& gpu_src, uint64_t size) override { - return StreamCommon::Memcpy(gpu_dst, gpu_src, size); - } + const DeviceMemoryBase& gpu_src, uint64_t size) override; private: GpuExecutor* parent_; // Executor that spawned this stream. 
diff --git a/third_party/xla/xla/stream_executor/host/host_executor.cc b/third_party/xla/xla/stream_executor/host/host_executor.cc index 24c6ee65f934f0..0748cc33d8012f 100644 --- a/third_party/xla/xla/stream_executor/host/host_executor.cc +++ b/third_party/xla/xla/stream_executor/host/host_executor.cc @@ -143,20 +143,6 @@ absl::Status HostExecutor::SynchronousMemZero(DeviceMemoryBase* location, return absl::OkStatus(); } -bool HostExecutor::MemcpyDeviceToDevice(Stream* stream, - DeviceMemoryBase* gpu_dst, - const DeviceMemoryBase& gpu_src, - uint64_t size) { - void* dst_mem = gpu_dst->opaque(); - void* src_mem = const_cast(gpu_src.opaque()); - // Enqueue this [asynchronous] "device-to-device" (i.e., host-to-host, given - // the nature of the HostExecutor) memcpy on the stream (HostStream) - // associated with the HostExecutor. - AsHostStream(stream)->EnqueueTask( - [src_mem, dst_mem, size]() { memcpy(dst_mem, src_mem, size); }); - return true; -} - absl::Status HostExecutor::Memset(Stream* stream, DeviceMemoryBase* location, uint8 pattern, uint64_t size) { void* gpu_mem = location->opaque(); diff --git a/third_party/xla/xla/stream_executor/host/host_executor.h b/third_party/xla/xla/stream_executor/host/host_executor.h index 3ebe6c9a26ad03..182bbf22f9ad76 100644 --- a/third_party/xla/xla/stream_executor/host/host_executor.h +++ b/third_party/xla/xla/stream_executor/host/host_executor.h @@ -90,10 +90,6 @@ class HostExecutor : public StreamExecutorCommon { delete[] static_cast(mem); } - bool MemcpyDeviceToDevice(Stream* stream, DeviceMemoryBase* gpu_dst, - const DeviceMemoryBase& gpu_src, - uint64_t size) override; - absl::Status Memset(Stream* stream, DeviceMemoryBase* location, uint8_t pattern, uint64_t size) override; diff --git a/third_party/xla/xla/stream_executor/host/host_stream.cc b/third_party/xla/xla/stream_executor/host/host_stream.cc index b10699924f7f95..61ea26d3dbcc33 100644 --- a/third_party/xla/xla/stream_executor/host/host_stream.cc +++ 
b/third_party/xla/xla/stream_executor/host/host_stream.cc @@ -57,6 +57,18 @@ HostStream::~HostStream() { parent()->DeallocateStream(this); } +absl::Status HostStream::Memcpy(DeviceMemoryBase* gpu_dst, + const DeviceMemoryBase& gpu_src, + uint64_t size) { + void* dst_mem = gpu_dst->opaque(); + void* src_mem = const_cast(gpu_src.opaque()); + // Enqueue this [asynchronous] "device-to-device" (i.e., host-to-host, given + // the nature of the HostExecutor) memcpy on the stream (HostStream) + // associated with the HostExecutor. + EnqueueTask([src_mem, dst_mem, size]() { memcpy(dst_mem, src_mem, size); }); + return absl::OkStatus(); +} + absl::Status HostStream::Memcpy(void* host_dst, const DeviceMemoryBase& gpu_src, uint64_t size) { // Enqueue the [asynchronous] memcpy on the stream (HostStream) associated diff --git a/third_party/xla/xla/stream_executor/host/host_stream.h b/third_party/xla/xla/stream_executor/host/host_stream.h index b97a8d51bce3e6..15fdbcafaa253f 100644 --- a/third_party/xla/xla/stream_executor/host/host_stream.h +++ b/third_party/xla/xla/stream_executor/host/host_stream.h @@ -60,9 +60,7 @@ class HostStream : public StreamCommon { absl::Status Memcpy(DeviceMemoryBase* gpu_dst, const void* host_src, uint64_t size) override; absl::Status Memcpy(DeviceMemoryBase* gpu_dst, - const DeviceMemoryBase& gpu_src, uint64_t size) override { - return StreamCommon::Memcpy(gpu_dst, gpu_src, size); - } + const DeviceMemoryBase& gpu_src, uint64_t size) override; absl::Status Memcpy(void* host_dst, const DeviceMemoryBase& gpu_src, uint64_t size) override; diff --git a/third_party/xla/xla/stream_executor/mock_stream_executor.h b/third_party/xla/xla/stream_executor/mock_stream_executor.h index 33eccd6adfb1ae..2a51fb0e8ece76 100644 --- a/third_party/xla/xla/stream_executor/mock_stream_executor.h +++ b/third_party/xla/xla/stream_executor/mock_stream_executor.h @@ -109,10 +109,6 @@ class MockStreamExecutor : public StreamExecutor { (Stream * stream, DeviceMemoryBase* 
location, uint8_t pattern, uint64_t size), (override)); - MOCK_METHOD(bool, MemcpyDeviceToDevice, - (Stream * stream, DeviceMemoryBase* device_dst, - const DeviceMemoryBase& device_src, uint64_t size), - (override)); MOCK_METHOD(bool, HostCallback, (Stream * stream, absl::AnyInvocable callback), (override)); diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_executor.cc b/third_party/xla/xla/stream_executor/rocm/rocm_executor.cc index 5c8e30f8830da6..12b05f307c659e 100644 --- a/third_party/xla/xla/stream_executor/rocm/rocm_executor.cc +++ b/third_party/xla/xla/stream_executor/rocm/rocm_executor.cc @@ -597,15 +597,6 @@ absl::Status GpuExecutor::Memcpy(Stream* stream, void* host_dst, return absl::OkStatus(); } -bool GpuExecutor::MemcpyDeviceToDevice(Stream* stream, - DeviceMemoryBase* gpu_dst, - const DeviceMemoryBase& gpu_src, - uint64_t size) { - return GpuDriver::AsynchronousMemcpyD2D(context_, AsROCmDevicePtr(gpu_dst), - AsROCmDevicePtr(gpu_src), size, - AsGpuStreamValue(stream)); -} - bool GpuExecutor::HostCallback(Stream* stream, absl::AnyInvocable callback) { auto callback_ptr = diff --git a/third_party/xla/xla/stream_executor/stream.h b/third_party/xla/xla/stream_executor/stream.h index 2cc531e2ef32bb..9b74136f8e06c1 100644 --- a/third_party/xla/xla/stream_executor/stream.h +++ b/third_party/xla/xla/stream_executor/stream.h @@ -192,8 +192,12 @@ class Stream { // of the given target size. gpu_src/dst must be pointers to GPU memory and // peer access must be enabled between their owning StreamExecutors. 
virtual absl::Status Memcpy(DeviceMemoryBase *gpu_dst, - const DeviceMemoryBase &gpu_src, - uint64_t size) = 0; + const DeviceMemoryBase &gpu_src, uint64_t size) { + return absl::UnimplementedError( + "Memcpy from device to device is not implemented for this " + "stream."); + } + absl::Status MemcpyD2D(DeviceMemoryBase *gpu_dst, const DeviceMemoryBase &gpu_src, uint64_t size) { return Memcpy(gpu_dst, gpu_src, size); diff --git a/third_party/xla/xla/stream_executor/stream_common.cc b/third_party/xla/xla/stream_executor/stream_common.cc index 43db75984efdba..e9d46afc5174f2 100644 --- a/third_party/xla/xla/stream_executor/stream_common.cc +++ b/third_party/xla/xla/stream_executor/stream_common.cc @@ -143,15 +143,6 @@ void StreamCommon::ReturnSubStream(Stream *sub_stream) { << sub_stream; } -absl::Status StreamCommon::Memcpy(DeviceMemoryBase *gpu_dst, - const DeviceMemoryBase &gpu_src, - uint64_t size) { - if (parent_->MemcpyDeviceToDevice(this, gpu_dst, gpu_src, size)) { - return absl::OkStatus(); - } - return absl::InternalError("failed to memcpy"); -} - absl::Status StreamCommon::DoHostCallback( absl::AnyInvocable callback) { return DoHostCallbackWithStatus([cb = std::move(callback)]() mutable { diff --git a/third_party/xla/xla/stream_executor/stream_common.h b/third_party/xla/xla/stream_executor/stream_common.h index 4278fe7063d550..e5d501ca3c1280 100644 --- a/third_party/xla/xla/stream_executor/stream_common.h +++ b/third_party/xla/xla/stream_executor/stream_common.h @@ -71,8 +71,6 @@ class StreamCommon : public Stream { absl::StatusOr GetOrCreateSubStream() override TF_LOCKS_EXCLUDED(mu_); void ReturnSubStream(Stream *sub_stream) override TF_LOCKS_EXCLUDED(mu_); - absl::Status Memcpy(DeviceMemoryBase *gpu_dst, - const DeviceMemoryBase &gpu_src, uint64_t size) override; absl::Status BlockHostUntilDone() override TF_LOCKS_EXCLUDED(mu_); absl::Status DoHostCallback(absl::AnyInvocable callback) override; absl::Status DoHostCallbackWithStatus( diff --git 
a/third_party/xla/xla/stream_executor/stream_executor.h b/third_party/xla/xla/stream_executor/stream_executor.h index 593c5c33f3833c..0a2e9747cc7be1 100644 --- a/third_party/xla/xla/stream_executor/stream_executor.h +++ b/third_party/xla/xla/stream_executor/stream_executor.h @@ -231,15 +231,6 @@ class StreamExecutor { return absl::InternalError("Not implemented"); } - // Enqueues a memcpy operation onto stream, with a device destination location - // and a device source location, with target size size. Peer access should - // have been enabled between the StreamExecutors owning the device memory - // regions. - virtual bool MemcpyDeviceToDevice(Stream* stream, - DeviceMemoryBase* device_dst, - const DeviceMemoryBase& device_src, - uint64_t size) = 0; - // Enqueues on a stream a user-specified function to be run on the host. virtual bool HostCallback(Stream* stream, absl::AnyInvocable callback) = 0; diff --git a/third_party/xla/xla/stream_executor/tpu/tpu_executor.cc b/third_party/xla/xla/stream_executor/tpu/tpu_executor.cc index 8bbf2ca1945122..4d54c2948ea679 100644 --- a/third_party/xla/xla/stream_executor/tpu/tpu_executor.cc +++ b/third_party/xla/xla/stream_executor/tpu/tpu_executor.cc @@ -190,12 +190,6 @@ absl::Status TpuExecutor::SynchronousMemcpy( return status.status(); } -bool TpuExecutor::MemcpyDeviceToDevice( - Stream* stream, ::stream_executor::DeviceMemoryBase* gpu_dst, - const ::stream_executor::DeviceMemoryBase& host_src, uint64_t size) { - LOG(FATAL) << __func__ << " not supported on TpuExecutor"; -} - absl::Status TpuExecutor::UnloadAllPrograms() { StatusHelper status; ExecutorApiFn()->TpuExecutor_UnloadAllProgramsFn(executor_, status.c_status); diff --git a/third_party/xla/xla/stream_executor/tpu/tpu_executor.h b/third_party/xla/xla/stream_executor/tpu/tpu_executor.h index a10efb34f39b91..6e9251060c3fd1 100644 --- a/third_party/xla/xla/stream_executor/tpu/tpu_executor.h +++ b/third_party/xla/xla/stream_executor/tpu/tpu_executor.h @@ -97,10 +97,6 @@ 
class TpuExecutor : public tensorflow::tpu::TpuExecutorInterface { bool HostCallback(Stream* stream, absl::AnyInvocable callback) override; - bool MemcpyDeviceToDevice(Stream* stream, DeviceMemoryBase* gpu_dst, - const DeviceMemoryBase& host_src, - uint64_t size) override; - bool SynchronizeAllActivity() override; absl::Status SynchronousMemcpy(DeviceMemoryBase* device_dst, diff --git a/third_party/xla/xla/stream_executor/tpu/tpu_stream.h b/third_party/xla/xla/stream_executor/tpu/tpu_stream.h index c056a4716e3332..79bc6eef8cfbe4 100644 --- a/third_party/xla/xla/stream_executor/tpu/tpu_stream.h +++ b/third_party/xla/xla/stream_executor/tpu/tpu_stream.h @@ -134,7 +134,8 @@ class TpuStream : public tensorflow::tpu::TpuStreamInterface { absl::Status Memcpy(stream_executor::DeviceMemoryBase* device_dst, const stream_executor::DeviceMemoryBase& device_src, uint64_t size) override { - return StreamCommon::Memcpy(device_dst, device_src, size); + return absl::UnimplementedError( + "Memcpy from device to deviceis not implemented for TPU"); } absl::Status Memcpy(void* host_dst, const stream_executor::DeviceMemoryBase& device_src, From 61dc6fb8cf109ef9a4caf875ebd75acfda58cd1c Mon Sep 17 00:00:00 2001 From: Arturo Schmidt Date: Thu, 20 Jun 2024 14:13:28 -0700 Subject: [PATCH 083/256] Delete translate directory ConvertMlirToGraphDef. 
PiperOrigin-RevId: 645152198 --- .../tensorflow/translate/export_graphdef.cc | 20 ------------------- .../tensorflow/translate/export_graphdef.h | 5 ----- 2 files changed, 25 deletions(-) diff --git a/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc b/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc index a4d50422625de9..9ccfa83d9c29eb 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc @@ -809,24 +809,4 @@ Status ConvertMlirToGraph(mlir::ModuleOp module, &control_ret_nodes); } -absl::StatusOr> ConvertMlirToGraphdef( - mlir::ModuleOp module, const GraphExportConfig& configs) { - FunctionLibraryDefinition flib_def(OpRegistry::Global(), - FunctionDefLibrary()); - std::unique_ptr graph; - TF_RETURN_IF_ERROR(ConvertMlirToGraph(module, configs, &graph, &flib_def)); - - // If the entry function is exported to flib, then no graph is constructed. - // Construct one in that case. - if (configs.export_entry_func_to_flib) { - graph = std::make_unique(OpRegistry::Global()); - TF_RETURN_IF_ERROR( - graph->mutable_flib_def()->AddLibrary(std::move(flib_def))); - } - - auto graphdef = std::make_unique(); - graph->ToGraphDef(graphdef.get()); - return graphdef; -} - } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h b/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h index 722baaff57f41b..acf30a97bcbf05 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h +++ b/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h @@ -31,11 +31,6 @@ limitations under the License. namespace tensorflow { -ABSL_DEPRECATED("Use tensorflow::tf2xla::api::ConvertMlirToGraphdef instead.") -// Given an MLIR module, returns a GraphDef. 
-absl::StatusOr> ConvertMlirToGraphdef( - mlir::ModuleOp module, const GraphExportConfig& configs); - // Converts an MLIR module to TensorFlow graph and FunctionLibraryDefinition. // The "main" function of the module is stored in the graph and the rest of // functions are stored in the library. From d7913e7d71360341bd0eb66e71895b0497be4907 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 20 Jun 2024 14:14:47 -0700 Subject: [PATCH 084/256] Replace string with std::string in quantize_model_test.cc PiperOrigin-RevId: 645152559 --- tensorflow/compiler/mlir/lite/quantization/lite/BUILD | 1 - .../mlir/lite/quantization/lite/quantize_model_test.cc | 6 +++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/BUILD b/tensorflow/compiler/mlir/lite/quantization/lite/BUILD index e88b876e859c5d..35aa1bb8d9001e 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/quantization/lite/BUILD @@ -168,7 +168,6 @@ tf_cc_test( "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/lite:framework", - "//tensorflow/lite:string", "//tensorflow/lite/c:c_api_types", "//tensorflow/lite/tools/optimize:test_util", "@com_google_absl//absl/container:flat_hash_set", diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model_test.cc b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model_test.cc index 3bff10adfd578f..fe9a326dbd0b13 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model_test.cc +++ b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model_test.cc @@ -37,7 +37,6 @@ limitations under the License. 
#include "tensorflow/core/util/command_line_flags.h" #include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/model_builder.h" -#include "tensorflow/lite/string_type.h" #include "tensorflow/lite/tools/optimize/test_util.h" // Note: branched from tensorflow/lite/tools/optimize/quantize_model_test.cc @@ -71,7 +70,8 @@ ModelT UnPackFlatBufferModel(const Model& flatbuffer_model) { TfLiteStatus QuantizeModel( ModelT* model, const TensorType& input_type, const TensorType& output_type, - const bool allow_float, const std::unordered_set& operator_names, + const bool allow_float, + const std::unordered_set& operator_names, const TensorType& activations_type, std::string& output_buffer, const bool disable_per_channel = false, const absl::flat_hash_set& blocked_ops = {}, @@ -156,7 +156,7 @@ TfLiteStatus QuantizeModelAllOperators( disable_per_channel_for_dense_layers); } -std::unique_ptr ReadModel(const string& model_name) { +std::unique_ptr ReadModel(const std::string& model_name) { auto model_path = tensorflow::io::JoinPath(*g_test_model_dir, model_name); return FlatBufferModel::BuildFromFile(model_path.c_str()); } From e7218705a91c404fe9e37c629e8e9bdd34892d1b Mon Sep 17 00:00:00 2001 From: Mehrdad Khani Date: Thu, 20 Jun 2024 14:18:26 -0700 Subject: [PATCH 085/256] Brings back one usage of GetCorrectedUseTime() to improve code reuse PiperOrigin-RevId: 645153704 --- .../memory_space_assignment/algorithm.cc | 42 +------------------ 1 file changed, 1 insertion(+), 41 deletions(-) diff --git a/third_party/xla/xla/service/memory_space_assignment/algorithm.cc b/third_party/xla/xla/service/memory_space_assignment/algorithm.cc index baab8be470ae06..f7aab77a084e5c 100644 --- a/third_party/xla/xla/service/memory_space_assignment/algorithm.cc +++ b/third_party/xla/xla/service/memory_space_assignment/algorithm.cc @@ -1992,47 +1992,7 @@ MsaAlgorithm::AllocationRequest MsaAlgorithm::CreateAllocationRequest( latest_prefetch_time = std::min(computation_span.start - 1, 
latest_prefetch_time); } - if (hlo_use.instruction->opcode() == HloOpcode::kWhile) { - // Given an example while loop and flattened schedule (logical times - // shown on the left): - // - // 0: a = ... - // 1: ... - // cond { - // 2: p = param(0) - // 3: ... - // } - // body { - // 4: p = param(0) - // 5: ... - // 6: ROOT ... - // } - // 7: w = while(a), body=body, cond=cond - // - // When processing "a" (time 0) and its while use (time 7), we update - // the interval to time 0-4. This is so that the remaining interval - // (5-6) can be allocated separately and this buffer doesn't waste - // alternate memory space within the while loop body. - HloComputation* while_body = hlo_use.instruction->while_body(); - // We require while body ROOTs to be the last in the schedule. - CHECK_EQ(instruction_schedule.at(while_body->root_instruction()) + 1, - instruction_schedule.at(hlo_use.instruction)) - << "While body ROOTs need to be the last in the schedule! " - "Please run RootInstructionSinker."; - // Replace the use time with the parameter time so that we can decide - // on alternate memory allocations within the while loop body when we - // look at uses within the while loop body. - use_time = instruction_schedule.at(while_body->parameter_instruction(0)); - } else if (hlo_use.instruction->opcode() == HloOpcode::kConditional) { - // Replace the use time with the earliest parameter of called - // computations. 
- for (const HloComputation* called_computation : - hlo_use.instruction->called_computations()) { - use_time = std::min(use_time, - instruction_schedule.at( - called_computation->parameter_instruction(0))); - } - } + use_time = GetCorrectedUseTime(hlo_use); } // Add a required assignment in default memory if the use not allowed in From 5fe8d891e8e697c97941a082a148117a569ccdd7 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Thu, 20 Jun 2024 14:37:53 -0700 Subject: [PATCH 086/256] [xla:cpu] Add support for CollectivePermute thunk + enabled collective ops test with thunks as all collective operations are now supported PiperOrigin-RevId: 645159920 --- third_party/xla/xla/service/cpu/BUILD | 1 + third_party/xla/xla/service/cpu/runtime/BUILD | 34 +++++ .../cpu/runtime/collective_permute_thunk.cc | 133 ++++++++++++++++++ .../cpu/runtime/collective_permute_thunk.h | 52 +++++++ .../xla/xla/service/cpu/runtime/thunk.cc | 2 + .../xla/xla/service/cpu/runtime/thunk.h | 1 + .../xla/xla/service/cpu/thunk_emitter.cc | 33 +++++ .../xla/xla/service/cpu/thunk_emitter.h | 3 + third_party/xla/xla/tests/BUILD | 2 + .../xla/xla/tests/collective_ops_test.cc | 1 + 10 files changed, 262 insertions(+) create mode 100644 third_party/xla/xla/service/cpu/runtime/collective_permute_thunk.cc create mode 100644 third_party/xla/xla/service/cpu/runtime/collective_permute_thunk.h diff --git a/third_party/xla/xla/service/cpu/BUILD b/third_party/xla/xla/service/cpu/BUILD index 7d3457638ad18e..3ab2a2abb0ef56 100644 --- a/third_party/xla/xla/service/cpu/BUILD +++ b/third_party/xla/xla/service/cpu/BUILD @@ -816,6 +816,7 @@ cc_library( "//xla/service/cpu/runtime:all_reduce_thunk", "//xla/service/cpu/runtime:all_to_all_thunk", "//xla/service/cpu/runtime:call_thunk", + "//xla/service/cpu/runtime:collective_permute_thunk", "//xla/service/cpu/runtime:collective_thunk", "//xla/service/cpu/runtime:conditional_thunk", "//xla/service/cpu/runtime:copy_thunk", diff --git 
a/third_party/xla/xla/service/cpu/runtime/BUILD b/third_party/xla/xla/service/cpu/runtime/BUILD index c4ec5905e90983..e89d00a776415a 100644 --- a/third_party/xla/xla/service/cpu/runtime/BUILD +++ b/third_party/xla/xla/service/cpu/runtime/BUILD @@ -320,6 +320,40 @@ cc_library( ], ) +cc_library( + name = "collective_permute_thunk", + srcs = ["collective_permute_thunk.cc"], + hdrs = ["collective_permute_thunk.h"], + deps = [ + ":collective_thunk", + ":thunk", + "//xla:shape_util", + "//xla:status_macros", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/runtime:buffer_use", + "//xla/service:buffer_assignment", + "//xla/service:collective_ops_utils", + "//xla/service:computation_placer", + "//xla/service/cpu:collectives_interface", + "//xla/stream_executor", + "//xla/tsl/concurrency:async_value", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:numbers", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/profiler/lib:traceme", + ], +) + cc_library( name = "collective_thunk", srcs = ["collective_thunk.cc"], diff --git a/third_party/xla/xla/service/cpu/runtime/collective_permute_thunk.cc b/third_party/xla/xla/service/cpu/runtime/collective_permute_thunk.cc new file mode 100644 index 00000000000000..8b4923edd921ff --- /dev/null +++ b/third_party/xla/xla/service/cpu/runtime/collective_permute_thunk.cc @@ -0,0 +1,133 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/cpu/runtime/collective_permute_thunk.h" + +#include +#include +#include +#include +#include + +#include "absl/container/inlined_vector.h" +#include "absl/memory/memory.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "absl/types/span.h" +#include "xla/service/buffer_assignment.h" +#include "xla/service/collective_ops_utils.h" +#include "xla/service/computation_placer.h" +#include "xla/service/cpu/collectives_interface.h" +#include "xla/service/cpu/runtime/collective_thunk.h" +#include "xla/service/cpu/runtime/thunk.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/status_macros.h" +#include "xla/tsl/concurrency/async_value_ref.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/logging.h" +#include "tsl/platform/statusor.h" +#include "tsl/profiler/lib/traceme.h" + +namespace xla::cpu { + +absl::StatusOr> +CollectivePermuteThunk::Create( + Info info, OpParams op_params, OpBuffers op_buffers, + absl::Span source_target_pairs) { + return absl::WrapUnique(new CollectivePermuteThunk( + std::move(info), op_params, std::move(op_buffers), source_target_pairs)); +} + +CollectivePermuteThunk::CollectivePermuteThunk( + Info info, OpParams op_params, OpBuffers op_buffers, + absl::Span source_target_pairs) + : CollectiveThunk(Kind::kCollectivePermute, info, op_params, + std::move(op_buffers)), + source_target_pairs_(source_target_pairs.begin(), + 
source_target_pairs.end()) {} + +tsl::AsyncValueRef +CollectivePermuteThunk::Execute(const ExecuteParams& params) { + tsl::profiler::TraceMe trace([&] { return TraceMeEncode(); }); + + TF_ASSIGN_OR_RETURN(OpDeviceMemory data, GetOpDeviceMemory(params)); + + Thunk::CollectiveExecuteParams* collective_params = params.collective_params; + TF_RET_CHECK(collective_params) << "Collectives parameters are not set"; + + TF_ASSIGN_OR_RETURN(DeviceAssignment::LogicalID logical_id, + collective_params->device_assignment->LogicalIdForDevice( + collective_params->global_device_id)); + + int32_t logical_device_id = op_params().has_channel_id + ? logical_id.computation_id + : logical_id.replica_id; + + // Find replicas that we will communicate with. + std::optional source_replica_id; + std::vector copy_to; + + for (auto& [from, to] : source_target_pairs_) { + if (from == logical_device_id) { + copy_to.push_back(to); + } + if (to == logical_device_id) { + TF_RET_CHECK(!source_replica_id.has_value()) + << "Duplicate source replica: " << from << ". " + << "Previous source replica: " << *source_replica_id; + source_replica_id = from; + } + } + + VLOG(3) << absl::StreamFormat( + "CollectivePermute: #source_buffers=%d, #destination_buffers=%d, " + "source_target_pairs=[%s], logical_device_id=%d (%s), " + "source_replica_id=%d, copy_to=[%s]", + data.source.size(), data.destination.size(), + absl::StrJoin(source_target_pairs_, ", ", absl::PairFormatter("->")), + logical_device_id, + op_params().has_channel_id ? 
"computation id" : "replica id", + source_replica_id.value_or(-1), absl::StrJoin(copy_to, ",")); + + for (int i = 0; i < data.source.size(); ++i) { + VLOG(3) << absl::StreamFormat( + " src: %s in slice %s (%p)", source_shape(i).ToString(true), + source_buffer(i).ToString(), data.source[i].opaque()); + } + + for (int i = 0; i < data.destination.size(); ++i) { + VLOG(3) << absl::StreamFormat( + " dst: %s in slice %s (%p)", destination_shape(i).ToString(true), + destination_buffer(i).ToString(), data.destination[i].opaque()); + } + + return ExecuteWithCommunicator( + params.collective_params, + [&](const RendezvousKey& key, CollectivesCommunicator& comm) { + for (int32_t i = 0; i < data.source.size(); ++i) { + const Shape& shape = source_shape(i); + TF_RETURN_IF_ERROR(comm.CollectivePermute( + key, ShapeUtil::ByteSizeOf(shape), source_replica_id, copy_to, + data.source[i].opaque(), data.destination[i].opaque(), + DefaultCollectiveTimeout())); + } + return absl::OkStatus(); + }); +} + +} // namespace xla::cpu diff --git a/third_party/xla/xla/service/cpu/runtime/collective_permute_thunk.h b/third_party/xla/xla/service/cpu/runtime/collective_permute_thunk.h new file mode 100644 index 00000000000000..1689d23c795c6f --- /dev/null +++ b/third_party/xla/xla/service/cpu/runtime/collective_permute_thunk.h @@ -0,0 +1,52 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef XLA_SERVICE_CPU_RUNTIME_COLLECTIVE_PERMUTE_THUNK_H_ +#define XLA_SERVICE_CPU_RUNTIME_COLLECTIVE_PERMUTE_THUNK_H_ + +#include +#include +#include +#include + +#include "absl/status/statusor.h" +#include "absl/types/span.h" +#include "xla/service/cpu/runtime/collective_thunk.h" +#include "xla/tsl/concurrency/async_value_ref.h" +#include "xla/xla_data.pb.h" + +namespace xla::cpu { + +class CollectivePermuteThunk final : public CollectiveThunk { + public: + using SourceTargetPair = std::pair; + + static absl::StatusOr> Create( + Info info, OpParams op_params, OpBuffers op_buffers, + absl::Span source_target_pairs); + + tsl::AsyncValueRef Execute(const ExecuteParams& params) final; + + private: + CollectivePermuteThunk( + Info info, OpParams op_params, OpBuffers op_buffers, + absl::Span source_target_pairs); + + std::vector source_target_pairs_; +}; + +} // namespace xla::cpu + +#endif // XLA_SERVICE_CPU_RUNTIME_COLLECTIVE_PERMUTE_THUNK_H_ diff --git a/third_party/xla/xla/service/cpu/runtime/thunk.cc b/third_party/xla/xla/service/cpu/runtime/thunk.cc index 88484ce5416044..b360a492bc4bcc 100644 --- a/third_party/xla/xla/service/cpu/runtime/thunk.cc +++ b/third_party/xla/xla/service/cpu/runtime/thunk.cc @@ -46,6 +46,8 @@ std::string_view Thunk::KindToString(Kind kind) { return "all-to-all"; case Kind::kCall: return "call"; + case Kind::kCollectivePermute: + return "collective-permute"; case Kind::kConditional: return "conditional"; case Kind::kCopy: diff --git a/third_party/xla/xla/service/cpu/runtime/thunk.h b/third_party/xla/xla/service/cpu/runtime/thunk.h index 7a21dfa858357e..c1a88c72da9342 100644 --- a/third_party/xla/xla/service/cpu/runtime/thunk.h +++ b/third_party/xla/xla/service/cpu/runtime/thunk.h @@ -65,6 +65,7 @@ class Thunk { kAllReduce, kAllToAll, kCall, + kCollectivePermute, kCopy, kConditional, kDot, diff --git a/third_party/xla/xla/service/cpu/thunk_emitter.cc 
b/third_party/xla/xla/service/cpu/thunk_emitter.cc index f37d7360326630..0923d2a541c13f 100644 --- a/third_party/xla/xla/service/cpu/thunk_emitter.cc +++ b/third_party/xla/xla/service/cpu/thunk_emitter.cc @@ -40,6 +40,7 @@ limitations under the License. #include "xla/service/cpu/runtime/all_reduce_thunk.h" #include "xla/service/cpu/runtime/all_to_all_thunk.h" #include "xla/service/cpu/runtime/call_thunk.h" +#include "xla/service/cpu/runtime/collective_permute_thunk.h" #include "xla/service/cpu/runtime/collective_thunk.h" #include "xla/service/cpu/runtime/conditional_thunk.h" #include "xla/service/cpu/runtime/copy_thunk.h" @@ -217,6 +218,8 @@ absl::StatusOr ThunkEmitter::EmitHloInstruction( return EmitReduceScatterThunk(instruction); case HloOpcode::kAllToAll: return EmitAllToAllThunk(instruction); + case HloOpcode::kCollectivePermute: + return EmitCollectivePermuteThunk(instruction); // TODO(ezhulenev): Port pad optimizations from IrEmitter. case HloOpcode::kPad: @@ -302,6 +305,21 @@ static absl::StatusOr GetCollectiveOpParams( }; } +// TODO(ezhulenev): Figure out why CollectivePermute instruction does not have +// `use_global_device_ids` field and how to unify it with every other collective +// operation. +static absl::StatusOr GetCollectiveOpParams( + const HloCollectivePermuteInstruction* instruction) { + return CollectiveThunk::OpParams{ + /*op_id=*/instruction->channel_id().has_value() + ? 
instruction->channel_id().value() + : instruction->GetModule()->unique_id(), + /*has_channel_id=*/instruction->channel_id().has_value(), + /*use_global_device_ids=*/std::nullopt, + /*replica_groups=*/{}, // CollectivePermute does not have replica groups + }; +} + static absl::StatusOr GetCollectiveOpBuffers( const HloInstruction* instruction, const BufferAssignment& buffer_assignment) { @@ -379,6 +397,21 @@ absl::StatusOr ThunkEmitter::EmitAllToAllThunk( ThunkInfo(all_to_all), std::move(op_params), std::move(op_buffers)); } +absl::StatusOr ThunkEmitter::EmitCollectivePermuteThunk( + const HloInstruction* instruction) { + auto* collective_permute = Cast(instruction); + + TF_ASSIGN_OR_RETURN(CollectivePermuteThunk::OpParams op_params, + GetCollectiveOpParams(collective_permute)); + TF_ASSIGN_OR_RETURN( + CollectivePermuteThunk::OpBuffers op_buffers, + GetCollectiveOpBuffers(collective_permute, buffer_assignment_)); + + return ThunkSequence::Of( + ThunkInfo(collective_permute), std::move(op_params), + std::move(op_buffers), collective_permute->source_target_pairs()); +} + absl::StatusOr ThunkEmitter::EmitReduceScatterThunk( const HloInstruction* instruction) { auto* reduce_scatter = Cast(instruction); diff --git a/third_party/xla/xla/service/cpu/thunk_emitter.h b/third_party/xla/xla/service/cpu/thunk_emitter.h index cc5a5175d6d92f..78987bb4373f68 100644 --- a/third_party/xla/xla/service/cpu/thunk_emitter.h +++ b/third_party/xla/xla/service/cpu/thunk_emitter.h @@ -116,6 +116,9 @@ class ThunkEmitter { absl::StatusOr EmitAllToAllThunk( const HloInstruction* instruction); + absl::StatusOr EmitCollectivePermuteThunk( + const HloInstruction* instruction); + absl::StatusOr EmitReduceScatterThunk( const HloInstruction* instruction); diff --git a/third_party/xla/xla/tests/BUILD b/third_party/xla/xla/tests/BUILD index d31713c3242e15..831173ad721b08 100644 --- a/third_party/xla/xla/tests/BUILD +++ b/third_party/xla/xla/tests/BUILD @@ -2167,6 +2167,7 @@ xla_test( "gpu", "cpu", ], 
+ tags = ["test_xla_cpu_thunks"], deps = [ ":hlo_test_base", ":literal_test_util", @@ -2176,6 +2177,7 @@ xla_test( "//xla:literal", "//xla:literal_util", "//xla:shape_util", + "//xla/service:computation_placer", "//xla/service:hlo_module_config", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", diff --git a/third_party/xla/xla/tests/collective_ops_test.cc b/third_party/xla/xla/tests/collective_ops_test.cc index d9509a1eec2fa7..a6ba7316185f16 100644 --- a/third_party/xla/xla/tests/collective_ops_test.cc +++ b/third_party/xla/xla/tests/collective_ops_test.cc @@ -25,6 +25,7 @@ limitations under the License. #include "xla/literal.h" #include "xla/literal_util.h" #include "xla/primitive_util.h" +#include "xla/service/computation_placer.h" #include "xla/service/hlo_module_config.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/literal_test_util.h" From d5b1a3db0c652b5f68af75b708c3409b6c70a37b Mon Sep 17 00:00:00 2001 From: Hyeontaek Lim Date: Thu, 20 Jun 2024 14:40:07 -0700 Subject: [PATCH 087/256] [IFRT Proxy] Run client and backend tests with all supported protocol versions This change extends IFRT Proxy client and server tests to try all supported protocol versions to verify the compatibillity. Previously, only the maximum supported server version and the minimum supported client version were tested. 
PiperOrigin-RevId: 645160570 --- .../xla/xla/python/ifrt_proxy/client/BUILD | 1 + .../python/ifrt_proxy/client/client_test.cc | 28 ++++-- .../ifrt_proxy/server/ifrt_backend_test.cc | 98 +++++++++++-------- 3 files changed, 77 insertions(+), 50 deletions(-) diff --git a/third_party/xla/xla/python/ifrt_proxy/client/BUILD b/third_party/xla/xla/python/ifrt_proxy/client/BUILD index 89cbd3d6ce8197..8517472bde7773 100644 --- a/third_party/xla/xla/python/ifrt_proxy/client/BUILD +++ b/third_party/xla/xla/python/ifrt_proxy/client/BUILD @@ -156,6 +156,7 @@ ifrt_proxy_cc_test( "//xla/python/ifrt_proxy/common:ifrt_service_proto_cc", "//xla/service:computation_placer_hdr", "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest_main", "@local_tsl//tsl/platform", diff --git a/third_party/xla/xla/python/ifrt_proxy/client/client_test.cc b/third_party/xla/xla/python/ifrt_proxy/client/client_test.cc index 0b421f6123c212..f565990d88c174 100644 --- a/third_party/xla/xla/python/ifrt_proxy/client/client_test.cc +++ b/third_party/xla/xla/python/ifrt_proxy/client/client_test.cc @@ -20,6 +20,7 @@ #include #include #include "absl/status/status.h" +#include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "xla/pjrt/pjrt_device_description.h" #include "xla/python/ifrt/device.h" @@ -58,14 +59,14 @@ using ::testing::EquivToProto; using ::testing::proto::Partially; #endif -IfrtProxyVersion Version() { - IfrtProxyVersion version; - version.set_protocol_version(kClientMinVersion); - return version; -} - -class ClientTest : public ::testing::Test { +class ClientTest : public ::testing::TestWithParam { protected: + IfrtProxyVersion Version() { + IfrtProxyVersion version; + version.set_protocol_version(GetParam()); + return version; + } + void SetUp() override { session_ = std::make_shared(); rpc_helper_ = std::make_shared(Version(), session_); @@ -127,7 +128,7 @@ class ClientTest : public 
::testing::Test { std::unique_ptr client_; }; -TEST_F(ClientTest, Init) { +TEST_P(ClientTest, Init) { EXPECT_EQ(client_->platform_name(), "ifrt-service"); EXPECT_EQ(client_->platform_version(), "n/a"); EXPECT_EQ(client_->platform_id(), 42); @@ -172,7 +173,7 @@ TEST_F(ClientTest, Init) { // TODO(b/315809436): Test needs rewrite because protobuf matchers are not OSS #if defined(PLATFORM_GOOGLE) -TEST_F(ClientTest, GetDefaultDeviceAssignmentSuccess) { +TEST_P(ClientTest, GetDefaultDeviceAssignmentSuccess) { IfrtResponse response; xla::DeviceAssignment assignment(1, 3); assignment.Serialize( @@ -197,7 +198,7 @@ TEST_F(ClientTest, GetDefaultDeviceAssignmentSuccess) { // TODO(b/315809436): Test needs rewrite because protobuf matchers are not OSS #if defined(PLATFORM_GOOGLE) -TEST_F(ClientTest, GetDefaultDeviceAssignmentFailure) { +TEST_P(ClientTest, GetDefaultDeviceAssignmentFailure) { EXPECT_CALL(*session_, Enqueue(Pointee(Partially(EquivToProto( R"pb( get_default_device_assignment_request { @@ -212,6 +213,13 @@ TEST_F(ClientTest, GetDefaultDeviceAssignmentFailure) { } #endif +INSTANTIATE_TEST_SUITE_P( + ClientTestWithAllVersions, ClientTest, + testing::Range(kClientMinVersion, kClientMaxVersion + 1), + [](const testing::TestParamInfo& info) { + return absl::StrCat(info.param); + }); + } // namespace } // namespace proxy } // namespace ifrt diff --git a/third_party/xla/xla/python/ifrt_proxy/server/ifrt_backend_test.cc b/third_party/xla/xla/python/ifrt_proxy/server/ifrt_backend_test.cc index 1e806f18f7b2ce..9cc1af0444f4e8 100644 --- a/third_party/xla/xla/python/ifrt_proxy/server/ifrt_backend_test.cc +++ b/third_party/xla/xla/python/ifrt_proxy/server/ifrt_backend_test.cc @@ -111,11 +111,15 @@ using ::testing::proto::Partially; constexpr uint64_t kSessionId = 12345; -IfrtProxyVersion Version() { - IfrtProxyVersion version; - version.set_protocol_version(kServerMaxVersion); - return version; -} +class IfrtBackendTest + : public ::testing::TestWithParam { + protected: + 
IfrtProxyVersion Version() { + IfrtProxyVersion version; + version.set_protocol_version(GetParam()); + return version; + } +}; // Makes an empty request with the given op_id. Does not fail. std::unique_ptr NewIfrtRequest(uint64_t op_id) { @@ -125,19 +129,19 @@ std::unique_ptr NewIfrtRequest(uint64_t op_id) { return ifrt_request; } -TEST(IfrtBackendTest, CreationFailsWithNullIfrtClient) { +TEST_P(IfrtBackendTest, CreationFailsWithNullIfrtClient) { EXPECT_THAT(IfrtBackend::Create(Version(), kSessionId, nullptr, nullptr), StatusIs(absl::StatusCode::kInvalidArgument)); } -TEST(IfrtBackendTest, SuccessfulCreation) { +TEST_P(IfrtBackendTest, SuccessfulCreation) { auto ifrt_client = std::make_unique(); ASSERT_THAT(IfrtBackend::Create(Version(), kSessionId, std::move(ifrt_client), std::make_shared()), IsOk()); } -TEST(IfrtBackendTest, ShutdownSucceeds) { +TEST_P(IfrtBackendTest, ShutdownSucceeds) { auto ifrt_client = std::make_unique(); TF_ASSERT_OK_AND_ASSIGN( auto ifrt_backend, @@ -145,7 +149,7 @@ TEST(IfrtBackendTest, ShutdownSucceeds) { std::make_shared())); } -TEST(IfrtBackendTest, ProcessFailsWithNoRequestSet) { +TEST_P(IfrtBackendTest, ProcessFailsWithNoRequestSet) { auto ifrt_client = std::make_unique(); TF_ASSERT_OK_AND_ASSIGN( auto ifrt_backend, @@ -159,6 +163,13 @@ TEST(IfrtBackendTest, ProcessFailsWithNoRequestSet) { ASSERT_THAT(process_status, Not(IsOk())); } +INSTANTIATE_TEST_SUITE_P( + IfrtBackendTestWithAllVersions, IfrtBackendTest, + testing::Range(kServerMinVersion, kServerMaxVersion + 1), + [](const testing::TestParamInfo& info) { + return absl::StrCat(info.param); + }); + struct TestProgram : llvm::RTTIExtends { static char ID; // NOLINT }; @@ -217,7 +228,7 @@ class TestCompileOptionsSerDes [[maybe_unused]] char TestCompileOptionsSerDes::ID = 0; // NOLINT -class IfrtBackendHandlerTest : public testing::Test { +class IfrtBackendHandlerTest : public IfrtBackendTest { protected: static void SetUpTestSuite() { RegisterSerDes(std::make_unique()); @@ -354,7 
+365,7 @@ class IfrtBackendHandlerTest : public testing::Test { // TODO(b/315809436): Test needs rewrite because protobuf matchers are not OSS #if defined(PLATFORM_GOOGLE) -TEST_F(IfrtBackendHandlerTest, Init) { +TEST_P(IfrtBackendHandlerTest, Init) { EXPECT_CALL(*mock_client_, platform_name()) .WillRepeatedly(Return("ifrt_backend")); EXPECT_CALL(*mock_client_, platform_version()).WillRepeatedly(Return("n/a")); @@ -456,7 +467,7 @@ TEST_F(IfrtBackendHandlerTest, Init) { // Consider redoing the happy-path test below with PjRt CPU-only backend for // non-SingleDeviceSharding. -TEST_F(IfrtBackendHandlerTest, DisassembleIntoSingleDeviceArraysSucceeds) { +TEST_P(IfrtBackendHandlerTest, DisassembleIntoSingleDeviceArraysSucceeds) { // Set up a mock source array that returns two single device arrays on // disassembly. std::vector> single_device_arrays; @@ -486,7 +497,7 @@ TEST_F(IfrtBackendHandlerTest, DisassembleIntoSingleDeviceArraysSucceeds) { SizeIs(2)); } -TEST_F(IfrtBackendHandlerTest, MakeArrayFromHostBufferSuccess) { +TEST_P(IfrtBackendHandlerTest, MakeArrayFromHostBufferSuccess) { // Given the below shape, dtype, and compact byte_strides, the size of the // array data needs to be 480 bytes. 
const uint64_t kHostBufferHandle = 1234; @@ -530,7 +541,7 @@ TEST_F(IfrtBackendHandlerTest, MakeArrayFromHostBufferSuccess) { EXPECT_NE(response->make_array_from_host_buffer_response().array_handle(), 0); } -TEST_F(IfrtBackendHandlerTest, AssembleArrayFromSingleDeviceArrays) { +TEST_P(IfrtBackendHandlerTest, AssembleArrayFromSingleDeviceArrays) { auto ifrt_request = NewIfrtRequest(NewOpId()); { ASSERT_TRUE(TextFormat::ParseFromString( @@ -574,7 +585,7 @@ TEST_F(IfrtBackendHandlerTest, AssembleArrayFromSingleDeviceArrays) { 0); } -TEST_F(IfrtBackendHandlerTest, CopyToHostSuccess) { +TEST_P(IfrtBackendHandlerTest, CopyToHostSuccess) { Shape shape({5, 3, 4}); tsl::RCReference array = tsl::MakeRef(); @@ -607,7 +618,7 @@ TEST_F(IfrtBackendHandlerTest, CopyToHostSuccess) { IsOkAndHolds(Pointee(SizeIs(480)))); } -TEST_F(IfrtBackendHandlerTest, CopyToHostFailsWithNonExistentArrays) { +TEST_P(IfrtBackendHandlerTest, CopyToHostFailsWithNonExistentArrays) { auto ifrt_request = NewIfrtRequest(NewOpId()); ASSERT_TRUE(TextFormat::ParseFromString( R"pb( @@ -620,7 +631,7 @@ TEST_F(IfrtBackendHandlerTest, CopyToHostFailsWithNonExistentArrays) { StatusIs(absl::StatusCode::kNotFound)); } -TEST_F(IfrtBackendHandlerTest, +TEST_P(IfrtBackendHandlerTest, DisassembleIntoSingleArrayFailsWhenBackendRuntimeFails) { // Set up a mock source array that fails the disassembly. 
constexpr absl::string_view kDisassembleErrorMessage = @@ -645,7 +656,7 @@ TEST_F(IfrtBackendHandlerTest, StatusIs(absl::StatusCode::kUnknown, StrEq(kDisassembleErrorMessage))); } -TEST_F(IfrtBackendHandlerTest, CopyArrays) { +TEST_P(IfrtBackendHandlerTest, CopyArrays) { std::vector> src_arrays; src_arrays.push_back(tsl::MakeRef()); @@ -686,7 +697,7 @@ TEST_F(IfrtBackendHandlerTest, CopyArrays) { SizeIs(copied_arrays.size())); } -TEST_F(IfrtBackendHandlerTest, ReshardSuccess) { +TEST_P(IfrtBackendHandlerTest, ReshardSuccess) { auto src_mock_array = tsl::MakeRef(); auto resharded_mock_array = tsl::MakeRef(); EXPECT_CALL(*src_mock_array, Reshard(_, _)) @@ -711,7 +722,7 @@ TEST_F(IfrtBackendHandlerTest, ReshardSuccess) { EXPECT_NE(response->reshard_response().array_handle(), 0); } -TEST_F(IfrtBackendHandlerTest, FullyReplicatedShardSuccess) { +TEST_P(IfrtBackendHandlerTest, FullyReplicatedShardSuccess) { auto fully_replicated_mock_array = tsl::MakeRef(); auto resultant_array = tsl::MakeRef(); EXPECT_CALL(*fully_replicated_mock_array, FullyReplicatedShard(_)) @@ -732,7 +743,7 @@ TEST_F(IfrtBackendHandlerTest, FullyReplicatedShardSuccess) { EXPECT_NE(response->fully_replicated_shard_response().array_handle(), 0); } -TEST_F(IfrtBackendHandlerTest, FullyReplicatedShardFailure) { +TEST_P(IfrtBackendHandlerTest, FullyReplicatedShardFailure) { auto fully_replicated_mock_array = tsl::MakeRef(); EXPECT_CALL(*fully_replicated_mock_array, FullyReplicatedShard(_)) .WillOnce(Return(absl::UnknownError("injected error"))); @@ -752,7 +763,7 @@ TEST_F(IfrtBackendHandlerTest, FullyReplicatedShardFailure) { StatusIs(absl::StatusCode::kUnknown, StrEq("injected error"))); } -TEST_F(IfrtBackendHandlerTest, +TEST_P(IfrtBackendHandlerTest, FullyReplicatedShardFailsWithNonExistentArrayHandle) { auto ifrt_request = NewIfrtRequest(NewOpId()); auto* fully_replicated_shard_request = @@ -765,7 +776,7 @@ TEST_F(IfrtBackendHandlerTest, StatusIs(absl::StatusCode::kNotFound)); } 
-TEST_F(IfrtBackendHandlerTest, ReshardFailsWhenTheBackendFails) { +TEST_P(IfrtBackendHandlerTest, ReshardFailsWhenTheBackendFails) { auto mock_array = tsl::MakeRef(); EXPECT_CALL(*mock_array, Reshard(_, _)) .WillOnce(Return(absl::UnknownError("injected error"))); @@ -786,7 +797,7 @@ TEST_F(IfrtBackendHandlerTest, ReshardFailsWhenTheBackendFails) { StatusIs(absl::StatusCode::kUnknown, StrEq("injected error"))); } -TEST_F(IfrtBackendHandlerTest, ReshardFailsWithNonExistentArrayHandle) { +TEST_P(IfrtBackendHandlerTest, ReshardFailsWithNonExistentArrayHandle) { auto ifrt_request = NewIfrtRequest(NewOpId()); auto* reshard_request = ifrt_request->mutable_reshard_request(); reshard_request->set_array_handle(0); @@ -797,7 +808,7 @@ TEST_F(IfrtBackendHandlerTest, ReshardFailsWithNonExistentArrayHandle) { StatusIs(absl::StatusCode::kNotFound)); } -TEST_F(IfrtBackendHandlerTest, +TEST_P(IfrtBackendHandlerTest, CheckArrayReadyRequestRelaysTheResultFromBackend) { auto mock_array = tsl::MakeRef(); TF_ASSERT_OK_AND_ASSIGN(auto array_handle, @@ -827,7 +838,7 @@ TEST_F(IfrtBackendHandlerTest, } } -TEST_F(IfrtBackendHandlerTest, +TEST_P(IfrtBackendHandlerTest, CheckArrayReadyRequestFailsWithNonExistentArrayHandle) { auto ifrt_request = NewIfrtRequest(NewOpId()); ifrt_request->mutable_check_value_ready_request()->add_value_handles(0); @@ -835,7 +846,7 @@ TEST_F(IfrtBackendHandlerTest, StatusIs(absl::StatusCode::kNotFound)); } -TEST_F(IfrtBackendHandlerTest, DeleteArraySuccess) { +TEST_P(IfrtBackendHandlerTest, DeleteArraySuccess) { tsl::RCReference mock_array = tsl::MakeRef(); EXPECT_CALL(*mock_array, Delete()) @@ -851,14 +862,14 @@ TEST_F(IfrtBackendHandlerTest, DeleteArraySuccess) { EXPECT_NE(resp->delete_array_response().deletion_future_handle(), 0); } -TEST_F(IfrtBackendHandlerTest, DeleteArrayFailsWithNonExistentArrayHandle) { +TEST_P(IfrtBackendHandlerTest, DeleteArrayFailsWithNonExistentArrayHandle) { auto ifrt_request = NewIfrtRequest(NewOpId()); 
ifrt_request->mutable_delete_array_request()->set_array_handle(0); EXPECT_THAT(CallBackend(std::move(ifrt_request)), StatusIs(absl::StatusCode::kNotFound)); } -TEST_F(IfrtBackendHandlerTest, +TEST_P(IfrtBackendHandlerTest, IsDeleteRelaysBackTheReturnValueFromBackendRuntime) { tsl::RCReference mock_array = tsl::MakeRef(); @@ -883,14 +894,14 @@ TEST_F(IfrtBackendHandlerTest, EXPECT_FALSE(resp->is_array_deleted_response().deleted()); } -TEST_F(IfrtBackendHandlerTest, IsDeleteFailsForNonExistentArrays) { +TEST_P(IfrtBackendHandlerTest, IsDeleteFailsForNonExistentArrays) { auto ifrt_request = NewIfrtRequest(NewOpId()); ifrt_request->mutable_is_array_deleted_request()->set_array_handle(0); EXPECT_THAT(CallBackend(std::move(ifrt_request)), StatusIs(absl::StatusCode::kNotFound)); } -TEST_F(IfrtBackendHandlerTest, DestructArrayTest) { +TEST_P(IfrtBackendHandlerTest, DestructArrayTest) { tsl::RCReference mock_array = tsl::MakeRef(); TF_ASSERT_OK_AND_ASSIGN(auto array_handle, @@ -914,7 +925,7 @@ TEST_F(IfrtBackendHandlerTest, DestructArrayTest) { // TODO(b/315809436): Test needs rewrite because protobuf matchers are not OSS #if defined(PLATFORM_GOOGLE) -TEST_F(IfrtBackendHandlerTest, CompileSuccess) { +TEST_P(IfrtBackendHandlerTest, CompileSuccess) { std::vector devices(4); for (int i = 0; i < 4; ++i) { EXPECT_CALL(devices[i], Id()).WillOnce(Return(DeviceId(i))); @@ -956,7 +967,7 @@ TEST_F(IfrtBackendHandlerTest, CompileSuccess) { } #endif -TEST_F(IfrtBackendHandlerTest, CompileFailure) { +TEST_P(IfrtBackendHandlerTest, CompileFailure) { ASSERT_THAT( CompileTestLoadedExecutable(absl::InternalError("injected error")), StatusIs(absl::StatusCode::kInternal, StrEq("injected error"))); @@ -964,7 +975,7 @@ TEST_F(IfrtBackendHandlerTest, CompileFailure) { // TODO(b/315809436): Test needs rewrite because protobuf matchers are not OSS #if defined(PLATFORM_GOOGLE) -TEST_F(IfrtBackendHandlerTest, LoadedExecutableMetadata) { +TEST_P(IfrtBackendHandlerTest, LoadedExecutableMetadata) { 
MockLoadedExecutable* executable; uint64_t handle; { @@ -1075,7 +1086,7 @@ TEST_F(IfrtBackendHandlerTest, LoadedExecutableMetadata) { // TODO(b/315809436): Test needs rewrite because protobuf matchers are not OSS #if defined(PLATFORM_GOOGLE) -TEST_F(IfrtBackendHandlerTest, LoadedExecutableExecute) { +TEST_P(IfrtBackendHandlerTest, LoadedExecutableExecute) { MockDevice device; ON_CALL(device, Id()).WillByDefault(Return(DeviceId(0))); @@ -1173,7 +1184,7 @@ TEST_F(IfrtBackendHandlerTest, LoadedExecutableExecute) { // TODO(b/315809436): Test needs rewrite because protobuf matchers are not OSS #if defined(PLATFORM_GOOGLE) -TEST_F(IfrtBackendHandlerTest, LoadedExecutableDelete) { +TEST_P(IfrtBackendHandlerTest, LoadedExecutableDelete) { MockLoadedExecutable* executable; uint64_t handle; { @@ -1219,7 +1230,7 @@ TEST_F(IfrtBackendHandlerTest, LoadedExecutableDelete) { } #endif -TEST_F(IfrtBackendHandlerTest, LoadedExecutableDestruct) { +TEST_P(IfrtBackendHandlerTest, LoadedExecutableDestruct) { MockLoadedExecutable* executable; uint64_t handle; { @@ -1255,7 +1266,7 @@ TEST_F(IfrtBackendHandlerTest, LoadedExecutableDestruct) { } } -TEST_F(IfrtBackendHandlerTest, LoadedHostCallbackExecute) { +TEST_P(IfrtBackendHandlerTest, LoadedHostCallbackExecute) { // Build a remote host callback with one F32 argument and one F32 result. 
std::vector hcb_args = {{ .channel_id = 1, @@ -1398,7 +1409,7 @@ TEST_F(IfrtBackendHandlerTest, LoadedHostCallbackExecute) { } } -TEST_F(IfrtBackendHandlerTest, GetDefaultDeviceAssignmentSuccess) { +TEST_P(IfrtBackendHandlerTest, GetDefaultDeviceAssignmentSuccess) { const int kNumReplicas = 1; const int kNumPartitions = 3; @@ -1421,7 +1432,7 @@ TEST_F(IfrtBackendHandlerTest, GetDefaultDeviceAssignmentSuccess) { EXPECT_EQ(assignment_got->computation_count(), kNumPartitions); } -TEST_F(IfrtBackendHandlerTest, +TEST_P(IfrtBackendHandlerTest, GetDefaultDeviceAssignmentFailsIfTheBackendFails) { const int kNumReplicas = 1; const int kNumPartitions = 3; @@ -1440,6 +1451,13 @@ TEST_F(IfrtBackendHandlerTest, StatusIs(absl::StatusCode::kUnknown, StrEq("injected error"))); } +INSTANTIATE_TEST_SUITE_P( + IfrtBackendHandlerTestWithAllVersions, IfrtBackendHandlerTest, + testing::Range(kServerMinVersion, kServerMaxVersion + 1), + [](const testing::TestParamInfo& info) { + return absl::StrCat(info.param); + }); + } // namespace } // namespace proxy } // namespace ifrt From 8c81a27feafc2c80162978413de8eb3f0d0c2e0d Mon Sep 17 00:00:00 2001 From: Fergus Henderson Date: Thu, 20 Jun 2024 15:02:09 -0700 Subject: [PATCH 088/256] Minor cleanups to #includes etc. in sample_stable_delegate. 
PiperOrigin-RevId: 645167369 --- .../utils/experimental/sample_stable_delegate/BUILD | 1 + .../sample_app_using_stable_delegate.cc | 2 ++ .../sample_stable_delegate_external_test.cc | 10 ++++++---- 3 files changed, 9 insertions(+), 4 deletions(-) diff --git a/tensorflow/lite/delegates/utils/experimental/sample_stable_delegate/BUILD b/tensorflow/lite/delegates/utils/experimental/sample_stable_delegate/BUILD index 29c1af37c1e5b2..ce1761fb2e6240 100644 --- a/tensorflow/lite/delegates/utils/experimental/sample_stable_delegate/BUILD +++ b/tensorflow/lite/delegates/utils/experimental/sample_stable_delegate/BUILD @@ -53,6 +53,7 @@ cc_library_with_tflite( cc_test( name = "sample_stable_delegate_test", srcs = ["sample_stable_delegate_test.cc"], + copts = tflite_copts(), data = [ "//tensorflow/lite:testdata/add.bin", "//tensorflow/lite:testdata/sub.bin", diff --git a/tensorflow/lite/delegates/utils/experimental/sample_stable_delegate/sample_app_using_stable_delegate.cc b/tensorflow/lite/delegates/utils/experimental/sample_stable_delegate/sample_app_using_stable_delegate.cc index d09f35561d74ef..de0607afe60c99 100644 --- a/tensorflow/lite/delegates/utils/experimental/sample_stable_delegate/sample_app_using_stable_delegate.cc +++ b/tensorflow/lite/delegates/utils/experimental/sample_stable_delegate/sample_app_using_stable_delegate.cc @@ -19,8 +19,10 @@ limitations under the License. #include +#include "tensorflow/lite/acceleration/configuration/c/stable_delegate.h" #include "tensorflow/lite/acceleration/configuration/configuration_generated.h" #include "tensorflow/lite/c/c_api.h" // For TfLiteTensorByteSize. 
+#include "tensorflow/lite/c/c_api_types.h" // For kTfLiteOk #include "tensorflow/lite/delegates/utils/experimental/stable_delegate/delegate_loader.h" #include "tensorflow/lite/interpreter.h" #include "tensorflow/lite/interpreter_builder.h" diff --git a/tensorflow/lite/delegates/utils/experimental/sample_stable_delegate/sample_stable_delegate_external_test.cc b/tensorflow/lite/delegates/utils/experimental/sample_stable_delegate/sample_stable_delegate_external_test.cc index e237eef084c368..f80a52d0896742 100644 --- a/tensorflow/lite/delegates/utils/experimental/sample_stable_delegate/sample_stable_delegate_external_test.cc +++ b/tensorflow/lite/delegates/utils/experimental/sample_stable_delegate/sample_stable_delegate_external_test.cc @@ -12,10 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include +#include #include #include +#include "tensorflow/lite/acceleration/configuration/c/stable_delegate.h" #include "tensorflow/lite/acceleration/configuration/configuration_generated.h" #include "tensorflow/lite/c/c_api.h" #include "tensorflow/lite/c/c_api_opaque.h" @@ -33,7 +34,7 @@ using tflite::delegates::utils::LoadDelegateFromSharedLibrary; TEST(SampleStableDelegate, LoadFromSharedLibraryFile) { // Load the example stable opaque_delegate that implements the ADD operation - // from a shared libary file. + // from a shared library file. 
const TfLiteStableDelegate* stable_delegate_handle = LoadDelegateFromSharedLibrary( "tensorflow/lite/delegates/utils/experimental/" @@ -54,7 +55,8 @@ TEST(SampleStableDelegate, LoadFromSharedLibraryTestFile) { const TfLiteStableDelegate* stable_delegate_handle = LoadDelegateFromSharedLibrary( "tensorflow/lite/delegates/utils/experimental/" - "sample_stable_delegate/libtensorflowlite_sample_stable_delegate_for_test.so"); + "sample_stable_delegate/" + "libtensorflowlite_sample_stable_delegate_for_test.so"); ASSERT_NE(stable_delegate_handle, nullptr); EXPECT_STREQ(stable_delegate_handle->delegate_abi_version, TFL_STABLE_DELEGATE_ABI_VERSION); @@ -95,7 +97,7 @@ TEST(SampleStableDelegate, LoadFromSharedLibraryTestFile) { TfLiteInterpreterGetInputTensor(interpreter, /*input_index=*/0); ASSERT_NE(input_tensor, nullptr); const float kTensorCellValue = 3.f; - int64_t n = tflite::NumElements(input_tensor); + std::int64_t n = tflite::NumElements(input_tensor); std::vector input(n, kTensorCellValue); ASSERT_EQ(TfLiteTensorCopyFromBuffer(input_tensor, input.data(), input.size() * sizeof(float)), From 5483b6047b15a1dec9215a432860ca258b785675 Mon Sep 17 00:00:00 2001 From: David Dunleavy Date: Thu, 20 Jun 2024 15:13:41 -0700 Subject: [PATCH 089/256] Fix formatting in `tensorflow/python` These files are making it impossible to run `hg fix` for trivial changes PiperOrigin-RevId: 645171043 --- .../collective_all_reduce_strategy.py | 509 +++++++++++------- .../parameter_server_strategy_v2.py | 167 ++++-- tensorflow/python/eager/context.py | 452 ++++++++++------ tensorflow/python/training/server_lib_test.py | 226 ++++---- 4 files changed, 851 insertions(+), 503 deletions(-) diff --git a/tensorflow/python/distribute/collective_all_reduce_strategy.py b/tensorflow/python/distribute/collective_all_reduce_strategy.py index c6a48e3bdd66df..c7dac643cefb37 100644 --- a/tensorflow/python/distribute/collective_all_reduce_strategy.py +++ 
b/tensorflow/python/distribute/collective_all_reduce_strategy.py @@ -71,7 +71,8 @@ class CollectiveAllReduceStrategy(distribute_lib.Strategy): is: ``` - TF_CONFIG = '{"cluster": {"worker": ["localhost:12345", "localhost:23456"]}, "task": {"type": "worker", "index": 0} }' + TF_CONFIG = '{"cluster": {"worker": ["localhost:12345", "localhost:23456"]}, + "task": {"type": "worker", "index": 0} }' ``` Your program runs on each worker as-is. Note that collectives require each @@ -134,7 +135,8 @@ def step_fn(inputs): ``` See - [Multi-worker training with Keras](https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras) + [Multi-worker training with + Keras](https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras) for a detailed tutorial. __Saving__ @@ -143,7 +145,8 @@ def step_fn(inputs): because variables whose synchronization=ON_READ triggers aggregation during saving. It's recommended to save to a different path on each worker to avoid race conditions. Each worker saves the same thing. See - [Multi-worker training with Keras](https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras#model_saving_and_loading) + [Multi-worker training with + Keras](https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras#model_saving_and_loading) tutorial for examples. __Known Issues__ @@ -154,8 +157,8 @@ def step_fn(inputs): or `None`. * In eager mode, the strategy needs to be created before calling any other Tensorflow API. - """ + # pylint: enable=line-too-long # TODO(anjalisridhar): Update our guides with examples showing how we can use @@ -164,9 +167,7 @@ def step_fn(inputs): # The starting number for collective keys. This should only be set in tests. _collective_key_base = 0 - def __init__(self, - cluster_resolver=None, - communication_options=None): + def __init__(self, cluster_resolver=None, communication_options=None): """Creates the strategy. 
Args: @@ -186,22 +187,28 @@ def __init__(self, CollectiveAllReduceExtended( self, cluster_resolver=cluster_resolver, - communication_options=communication_options)) + communication_options=communication_options, + ) + ) distribute_lib.distribution_strategy_gauge.get_cell("V2").set( - "MultiWorkerMirroredStrategy") + "MultiWorkerMirroredStrategy" + ) # pylint: disable=protected-access distribute_lib.distribution_strategy_replica_gauge.get_cell( - "num_workers").set(self.extended._num_workers) + "num_workers" + ).set(self.extended._num_workers) distribute_lib.distribution_strategy_replica_gauge.get_cell( - "num_replicas_per_worker").set(self.extended._num_devices_per_worker) + "num_replicas_per_worker" + ).set(self.extended._num_devices_per_worker) @classmethod def _from_local_devices(cls, devices, communication_options=None): """A convenience method to create an object with a list of devices.""" obj = cls(communication_options=communication_options) obj.extended._initialize_local( # pylint: disable=protected-access - tfconfig_cluster_resolver.TFConfigClusterResolver(), devices=devices) + tfconfig_cluster_resolver.TFConfigClusterResolver(), devices=devices + ) return obj @property @@ -229,15 +236,19 @@ def __instancecheck__(cls, instance): @tf_export("distribute.experimental.MultiWorkerMirroredStrategy", v1=[]) class _CollectiveAllReduceStrategyExperimental( CollectiveAllReduceStrategy, - metaclass=_CollectiveAllReduceStrategyExperimentalMeta): + metaclass=_CollectiveAllReduceStrategyExperimentalMeta, +): __doc__ = CollectiveAllReduceStrategy.__doc__ @deprecation.deprecated( - None, "use distribute.MultiWorkerMirroredStrategy instead") - def __init__(self, - communication=collective_util.CommunicationImplementation.AUTO, - cluster_resolver=None): + None, "use distribute.MultiWorkerMirroredStrategy instead" + ) + def __init__( + self, + communication=collective_util.CommunicationImplementation.AUTO, + cluster_resolver=None, + ): """Creates the strategy. 
Args: @@ -250,22 +261,30 @@ def __init__(self, `tf.distribute.cluster_resolver.TFConfigClusterResolver` is used. """ communication_options = collective_util.Options( - implementation=communication) - super(_CollectiveAllReduceStrategyExperimental, - self).__init__(cluster_resolver, communication_options) + implementation=communication + ) + super(_CollectiveAllReduceStrategyExperimental, self).__init__( + cluster_resolver, communication_options + ) @classmethod def _from_local_devices( cls, devices, - communication=collective_util.CommunicationImplementation.AUTO): + communication=collective_util.CommunicationImplementation.AUTO, + ): """A convenience method to create an object with a list of devices.""" obj = cls(communication) - obj.extended._initialize_local(tfconfig_cluster_resolver.TFConfigClusterResolver(), devices=devices) # pylint: disable=protected-access + # pylint: disable=protected-access + obj.extended._initialize_local( + tfconfig_cluster_resolver.TFConfigClusterResolver(), devices=devices + ) return obj -_CollectiveAllReduceStrategyExperimental.__name__ = CollectiveAllReduceStrategy.__name__ +_CollectiveAllReduceStrategyExperimental.__name__ = ( + CollectiveAllReduceStrategy.__name__ +) @tf_export(v1=["distribute.experimental.MultiWorkerMirroredStrategy"]) # pylint: disable=missing-docstring @@ -276,27 +295,36 @@ class CollectiveAllReduceStrategyV1(distribute_lib.StrategyV1): # The starting number for collective keys. This should only be set in tests. 
_collective_key_base = 0 - def __init__(self, - communication=collective_util.CommunicationImplementation.AUTO, - cluster_resolver=None): + def __init__( + self, + communication=collective_util.CommunicationImplementation.AUTO, + cluster_resolver=None, + ): """Initializes the object.""" communication_options = collective_util.Options( - implementation=communication) + implementation=communication + ) super(CollectiveAllReduceStrategyV1, self).__init__( CollectiveAllReduceExtended( self, cluster_resolver=cluster_resolver, - communication_options=communication_options)) + communication_options=communication_options, + ) + ) distribute_lib.distribution_strategy_gauge.get_cell("V1").set( - "MultiWorkerMirroredStrategy") + "MultiWorkerMirroredStrategy" + ) # pylint: disable=protected-access distribute_lib.distribution_strategy_replica_gauge.get_cell( - "num_workers").set(self.extended._num_workers) + "num_workers" + ).set(self.extended._num_workers) distribute_lib.distribution_strategy_replica_gauge.get_cell( - "num_gpu_per_worker").set( - self.extended._num_devices_per_worker - if self.extended._local_device_type == "GPU" - else 0) + "num_gpu_per_worker" + ).set( + self.extended._num_devices_per_worker + if self.extended._local_device_type == "GPU" + else 0 + ) def _is_gpu_device(device): @@ -320,34 +348,49 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended): # Timeout in seconds the each check health. 
_check_health_timeout = 10 - def __init__(self, container_strategy, cluster_resolver, - communication_options, devices=None): + def __init__( + self, + container_strategy, + cluster_resolver, + communication_options, + devices=None, + ): if not isinstance(communication_options, collective_util.Options): - raise ValueError("communication_options must be an instance of " - "tf.distribute.experimental.CommunicationOptions") + raise ValueError( + "communication_options must be an instance of " + "tf.distribute.experimental.CommunicationOptions" + ) if cluster_resolver and devices: raise ValueError( - "cluster_resolver and devices cannot be set at the same time") + "cluster_resolver and devices cannot be set at the same time" + ) - self._cluster_resolver = cluster_resolver or tfconfig_cluster_resolver.TFConfigClusterResolver() - if not isinstance(self._cluster_resolver, cluster_resolver_lib.ClusterResolver): - raise ValueError("cluster_resolver must be an instance of " - "tf.distribute.cluster_resolver.ClusterResolver") + self._cluster_resolver = ( + cluster_resolver or tfconfig_cluster_resolver.TFConfigClusterResolver() + ) + if not isinstance( + self._cluster_resolver, cluster_resolver_lib.ClusterResolver + ): + raise ValueError( + "cluster_resolver must be an instance of " + "tf.distribute.cluster_resolver.ClusterResolver" + ) distribute_lib.StrategyExtendedV1.__init__(self, container_strategy) self._communication_options = communication_options self._collective_key_base = container_strategy._collective_key_base # pylint: disable=protected-access self._initialize_strategy(self._cluster_resolver, devices=devices) self._cfer_fn_cache = weakref.WeakKeyDictionary() self.experimental_enable_get_next_as_optional = True - assert isinstance(self._cross_device_ops, - cross_device_ops_lib.CollectiveAllReduce) + assert isinstance( + self._cross_device_ops, cross_device_ops_lib.CollectiveAllReduce + ) def _use_merge_call(self): # We currently only disable merge_call when XLA is 
used to compile the `fn` # passed to `strategy.run` and all devices are GPU. return not control_flow_util.GraphOrParentsInXlaContext( - ops.get_default_graph()) or not all( - [_is_gpu_device(d) for d in self._devices]) + ops.get_default_graph() + ) or not all([_is_gpu_device(d) for d in self._devices]) def _initialize_strategy(self, cluster_resolver, devices): # If devices are provided or cluster_spec is not specified, initialize @@ -360,7 +403,9 @@ def _initialize_strategy(self, cluster_resolver, devices): def _initialize_local_devices(self, cluster_resolver, worker_device): # TODO(b/126786766): TFConfigClusterResolver returns wrong number of GPUs in # some cases. - if isinstance(cluster_resolver, tfconfig_cluster_resolver.TFConfigClusterResolver): + if isinstance( + cluster_resolver, tfconfig_cluster_resolver.TFConfigClusterResolver + ): num_gpus = context.num_gpus() num_tpus = 0 else: @@ -378,7 +423,8 @@ def _initialize_local_devices(self, cluster_resolver, worker_device): num_local_devices = 1 local_devices = tuple( f"{worker_device}/device:{local_device_type}:{i}" - for i in range(num_local_devices)) + for i in range(num_local_devices) + ) return local_devices, local_device_type def _initialize_local(self, cluster_resolver, devices=None): @@ -389,10 +435,13 @@ def _initialize_local(self, cluster_resolver, devices=None): if ops.executing_eagerly_outside_functions(): try: context.context().configure_collective_ops( - scoped_allocator_enabled_ops=("CollectiveReduce",)) + scoped_allocator_enabled_ops=("CollectiveReduce",) + ) except RuntimeError: - logging.warning("Collective ops is not configured at program startup. " - "Some performance features may not be enabled.") + logging.warning( + "Collective ops is not configured at program startup. " + "Some performance features may not be enabled." 
+ ) self._collective_ops_configured = True if devices: @@ -405,26 +454,31 @@ def _initialize_local(self, cluster_resolver, devices=None): local_device_type = "CPU" else: local_devices, local_device_type = self._initialize_local_devices( - cluster_resolver, worker_device="") + cluster_resolver, worker_device="" + ) self._worker_device = device_util.canonicalize("/device:CPU:0") self._host_input_device = numpy_dataset.SingleDevice(self._worker_device) self._collective_keys = cross_device_utils.CollectiveKeys( - group_key_start=1 + self._collective_key_base) + group_key_start=1 + self._collective_key_base + ) self._cross_device_ops = cross_device_ops_lib.CollectiveAllReduce( devices=local_devices, group_size=len(local_devices), options=self._communication_options, - collective_keys=self._collective_keys) + collective_keys=self._collective_keys, + ) # CrossDeviceOps for per host tensors. self._host_cross_device_ops = cross_device_ops_lib.CollectiveAllReduce( devices=[self._worker_device], group_size=self._num_workers, options=self._communication_options, - collective_keys=self._collective_keys) + collective_keys=self._collective_keys, + ) super(CollectiveAllReduceExtended, self)._initialize_single_worker( - local_devices) + local_devices + ) self._cluster_spec = None self._task_type = None @@ -445,42 +499,54 @@ def _initialize_local(self, cluster_resolver, devices=None): logging.info( "Single-worker MultiWorkerMirroredStrategy with local_devices " - "= %r, communication = %s", local_devices, - self._communication_options.implementation) + "= %r, communication = %s", + local_devices, + self._communication_options.implementation, + ) def _initialize_multi_worker(self, cluster_resolver): """Initializes the object for multi-worker training.""" cluster_spec = multi_worker_util.normalize_cluster_spec( - cluster_resolver.cluster_spec()) + cluster_resolver.cluster_spec() + ) task_type = cluster_resolver.task_type task_id = cluster_resolver.task_id if task_type is None or 
task_id is None: - raise ValueError("When `cluster_spec` is given, you must also specify " - "`task_type` and `task_id`.") + raise ValueError( + "When `cluster_spec` is given, you must also specify " + "`task_type` and `task_id`." + ) self._cluster_spec = cluster_spec self._task_type = task_type self._task_id = task_id self._id_in_cluster = multi_worker_util.id_in_cluster( - self._cluster_spec, self._task_type, self._task_id) + self._cluster_spec, self._task_type, self._task_id + ) self._num_workers = multi_worker_util.worker_count(cluster_spec, task_type) if not self._num_workers: - raise ValueError("No `worker`, `chief` or `evaluator` tasks can be found " - "in `cluster_spec`.") + raise ValueError( + "No `worker`, `chief` or `evaluator` tasks can be found " + "in `cluster_spec`." + ) - self._is_chief = multi_worker_util.is_chief(cluster_spec, task_type, - task_id) + self._is_chief = multi_worker_util.is_chief( + cluster_spec, task_type, task_id + ) self._worker_device = "/job:%s/task:%d" % (task_type, task_id) self._host_input_device = numpy_dataset.SingleDevice(self._worker_device) - if (ops.executing_eagerly_outside_functions() and - not getattr(self, "_local_or_standalone_client_mode", False)): + if ops.executing_eagerly_outside_functions() and not getattr( + self, "_local_or_standalone_client_mode", False + ): context.context().configure_collective_ops( collective_leader=multi_worker_util.collective_leader( - cluster_spec, task_type, task_id), + cluster_spec, task_type, task_id + ), scoped_allocator_enabled_ops=("CollectiveReduce",), - device_filters=("/job:%s/task:%d" % (task_type, task_id),)) + device_filters=("/job:%s/task:%d" % (task_type, task_id),), + ) self._collective_ops_configured = True if context.context().coordination_service is None: coordinated_jobs = ["chief", "worker"] @@ -490,18 +556,23 @@ def _initialize_multi_worker(self, cluster_resolver): if job in cluster_spec.jobs: coordinated_job_config.append( coordination_config_pb2.CoordinatedJob( 
- name=job, - num_tasks=cluster_spec.num_tasks(job))) + name=job, num_tasks=cluster_spec.num_tasks(job) + ) + ) context.context().configure_coordination_service( service_type="standalone", service_leader=multi_worker_util.coordination_leader( - cluster_spec), - coordinated_jobs=coordinated_job_config) + cluster_spec + ), + coordinated_jobs=coordinated_job_config, + ) # Starting a std server in eager mode and in independent worker mode. - if (context.executing_eagerly() and - not getattr(self, "_std_server_started", False) and - not getattr(self, "_local_or_standalone_client_mode", False)): + if ( + context.executing_eagerly() + and not getattr(self, "_std_server_started", False) + and not getattr(self, "_local_or_standalone_client_mode", False) + ): # Checking _local_or_standalone_client_mode as well because we should not # create the std server in standalone client mode. config_proto = copy.deepcopy(context.context().config) @@ -522,7 +593,8 @@ def _initialize_multi_worker(self, cluster_resolver): job_name=task_type, task_index=task_id, protocol=cluster_resolver.rpc_layer or "grpc", - port=port) + port=port, + ) context.context().enable_collective_ops(server_def) self._std_server_started = True # The `ensure_initialized` is needed before calling @@ -530,7 +602,8 @@ def _initialize_multi_worker(self, cluster_resolver): context.context().ensure_initialized() logging.info( "Enabled multi-worker collective ops with available devices: %r", - context.context().devices()) + context.context().devices(), + ) # TODO(yuefengz): The `num_gpus` is only for this particular task. It # assumes all workers have the same number of GPUs. We should remove this @@ -538,25 +611,30 @@ def _initialize_multi_worker(self, cluster_resolver): # TODO(b/126786766): TFConfigClusterResolver returns wrong number of GPUs in # some cases. 
local_devices, local_device_type = self._initialize_local_devices( - cluster_resolver, self._worker_device) + cluster_resolver, self._worker_device + ) if local_device_type == "TPU": tpu_cluster_resolver.initialize_tpu_system() self._collective_keys = cross_device_utils.CollectiveKeys( - group_key_start=1 + self._collective_key_base) + group_key_start=1 + self._collective_key_base + ) self._cross_device_ops = cross_device_ops_lib.CollectiveAllReduce( devices=local_devices, group_size=len(local_devices) * self._num_workers, options=self._communication_options, - collective_keys=self._collective_keys) + collective_keys=self._collective_keys, + ) # CrossDeviceOps for per host tensors. self._host_cross_device_ops = cross_device_ops_lib.CollectiveAllReduce( devices=[self._worker_device], group_size=self._num_workers, options=self._communication_options, - collective_keys=self._collective_keys) + collective_keys=self._collective_keys, + ) super(CollectiveAllReduceExtended, self)._initialize_single_worker( - local_devices) + local_devices + ) # Add a default device so that ops without specified devices will not end up # on other workers. 
@@ -576,9 +654,14 @@ def _initialize_multi_worker(self, cluster_resolver): logging.info( "MultiWorkerMirroredStrategy with cluster_spec = %r, task_type = %r, " "task_id = %r, num_workers = %r, local_devices = %r, " - "communication = %s", cluster_spec.as_dict(), task_type, task_id, - self._num_workers, local_devices, - self._communication_options.implementation) + "communication = %s", + cluster_spec.as_dict(), + task_type, + task_id, + self._num_workers, + local_devices, + self._communication_options.implementation, + ) def __del__(self): self._stop_check_health_thread() @@ -590,18 +673,19 @@ def _input_workers_with_options(self, options=None): else: return input_lib.InputWorkers([( host_device, - [device_util.get_host_for_device(worker) for worker in - self.worker_devices])]) + [ + device_util.get_host_for_device(worker) + for worker in self.worker_devices + ], + )]) @property def _input_workers(self): return self._input_workers_with_options() - def _get_variable_creator_initial_value(self, - replica_id, - device, - primary_var, - **kwargs): + def _get_variable_creator_initial_value( + self, replica_id, device, primary_var, **kwargs + ): if replica_id == 0: # First replica on each worker. assert device is not None assert primary_var is None @@ -610,8 +694,9 @@ def initial_value_fn(): # pylint: disable=g-missing-docstring # Only the first device participates in the broadcast of initial values. 
group_key = self._collective_keys.get_group_key([device]) group_size = self._num_workers - collective_instance_key = ( - self._collective_keys.get_instance_key(group_key, device)) + collective_instance_key = self._collective_keys.get_instance_key( + group_key, device + ) with ops.device(device): initial_value = kwargs["initial_value"] @@ -621,41 +706,56 @@ def initial_value_fn(): # pylint: disable=g-missing-docstring initial_value = initial_value.wrapped_value assert not callable(initial_value) initial_value = ops.convert_to_tensor( - initial_value, dtype=kwargs.get("dtype", None)) + initial_value, dtype=kwargs.get("dtype", None) + ) if self._num_workers > 1: if self._is_chief: bcast_send = collective_ops.broadcast_send( - initial_value, initial_value.shape, initial_value.dtype, - group_size, group_key, collective_instance_key) + initial_value, + initial_value.shape, + initial_value.dtype, + group_size, + group_key, + collective_instance_key, + ) with ops.control_dependencies([bcast_send]): return array_ops.identity(initial_value) else: - return collective_ops.broadcast_recv(initial_value.shape, - initial_value.dtype, - group_size, group_key, - collective_instance_key) + return collective_ops.broadcast_recv( + initial_value.shape, + initial_value.dtype, + group_size, + group_key, + collective_instance_key, + ) return initial_value return initial_value_fn else: - return super(CollectiveAllReduceExtended, - self)._get_variable_creator_initial_value( - replica_id=replica_id, - device=device, - primary_var=primary_var, - **kwargs) + return super( + CollectiveAllReduceExtended, self + )._get_variable_creator_initial_value( + replica_id=replica_id, + device=device, + primary_var=primary_var, + **kwargs, + ) def _make_input_context(self): input_context = distribute_lib.InputContext( num_input_pipelines=self._num_workers, input_pipeline_id=self._id_in_cluster, - num_replicas_in_sync=self._num_replicas_in_sync) + num_replicas_in_sync=self._num_replicas_in_sync, + ) return 
input_context def _experimental_distribute_dataset(self, dataset, options): - if (options and options.experimental_replication_mode == - distribute_lib.InputReplicationMode.PER_REPLICA): + if ( + options + and options.experimental_replication_mode + == distribute_lib.InputReplicationMode.PER_REPLICA + ): raise NotImplementedError( "InputReplicationMode.PER_REPLICA " "is only supported in " @@ -669,32 +769,38 @@ def _experimental_distribute_dataset(self, dataset, options): self._container_strategy(), num_replicas_in_sync=self._num_replicas_in_sync, input_context=input_context, - options=options) + options=options, + ) def _distribute_datasets_from_function(self, dataset_fn, options): - if (options and options.experimental_replication_mode == - distribute_lib.InputReplicationMode.PER_REPLICA): + if ( + options + and options.experimental_replication_mode + == distribute_lib.InputReplicationMode.PER_REPLICA + ): raise NotImplementedError( "InputReplicationMode.PER_REPLICA " "is only supported in " "`distribute_datasets_from_function` " - "of tf.distribute.MirroredStrategy") + "of tf.distribute.MirroredStrategy" + ) input_context = self._make_input_context() return input_util.get_distributed_datasets_from_function( dataset_fn=dataset_fn, input_workers=self._input_workers_with_options(options), input_contexts=[input_context], strategy=self._container_strategy(), - options=options) + options=options, + ) def _experimental_distribute_values_from_function(self, value_fn): per_replica_values = [] num_local_replicas = len(self.worker_devices) for local_replica_id in range(num_local_replicas): - replica_id = (self._id_in_cluster * num_local_replicas + - local_replica_id) + replica_id = self._id_in_cluster * num_local_replicas + local_replica_id value_context = distribute_lib.ValueContext( - replica_id, self._num_replicas_in_sync) + replica_id, self._num_replicas_in_sync + ) per_replica_values.append(value_fn(value_context)) return distribute_utils.regroup(per_replica_values, 
always_wrap=True) @@ -706,23 +812,26 @@ def _make_dataset_iterator(self, dataset): self._input_workers, self._container_strategy(), num_replicas_in_sync=self._num_replicas_in_sync, - input_context=input_context) + input_context=input_context, + ) def _make_input_fn_iterator( self, input_fn, - replication_mode=distribute_lib.InputReplicationMode.PER_WORKER): + replication_mode=distribute_lib.InputReplicationMode.PER_WORKER, + ): """Distributes the input function to each local GPU.""" input_context = self._make_input_context() - return input_lib_v1.InputFunctionIterator(input_fn, self._input_workers, - [input_context], - self._container_strategy()) - - def _configure(self, - session_config=None, - cluster_spec=None, - task_type=None, - task_id=None): + return input_lib_v1.InputFunctionIterator( + input_fn, + self._input_workers, + [input_context], + self._container_strategy(), + ) + + def _configure( + self, session_config=None, cluster_spec=None, task_type=None, task_id=None + ): """Configures the object. Args: @@ -741,11 +850,14 @@ def _configure(self, task_type=task_type, task_id=task_id, num_accelerators={ - self._local_device_type: self._num_devices_per_worker}, - rpc_layer=self._rpc_layer) + self._local_device_type: self._num_devices_per_worker + }, + rpc_layer=self._rpc_layer, + ) self._initialize_multi_worker(cluster_resolver) - assert isinstance(self._cross_device_ops, - cross_device_ops_lib.CollectiveAllReduce) + assert isinstance( + self._cross_device_ops, cross_device_ops_lib.CollectiveAllReduce + ) if session_config: session_config.CopyFrom(self._update_config_proto(session_config)) @@ -757,16 +869,19 @@ def _update_config_proto(self, config_proto): # all-reduces. rewrite_options = updated_config.graph_options.rewrite_options rewrite_options.scoped_allocator_optimization = ( - rewriter_config_pb2.RewriterConfig.ON) + rewriter_config_pb2.RewriterConfig.ON + ) # We turn on ScopedAllocator only for CollectiveReduce op, i.e. 
enable_op = # ["CollectiveReduce"]. Since we can't assign to a repeated proto field, we # clear and then append. del rewrite_options.scoped_allocator_opts.enable_op[:] rewrite_options.scoped_allocator_opts.enable_op.append("CollectiveReduce") - if (not ops.executing_eagerly_outside_functions() and - self._communication_options.implementation == - collective_util.CommunicationImplementation.NCCL): + if ( + not ops.executing_eagerly_outside_functions() + and self._communication_options.implementation + == collective_util.CommunicationImplementation.NCCL + ): updated_config.experimental.collective_nccl = True if not self._cluster_spec: @@ -778,13 +893,16 @@ def _update_config_proto(self, config_proto): # Collective group leader is needed for collective ops to coordinate # workers. updated_config.experimental.collective_group_leader = ( - multi_worker_util.collective_leader(self._cluster_spec, self._task_type, - self._task_id)) + multi_worker_util.collective_leader( + self._cluster_spec, self._task_type, self._task_id + ) + ) # The device filters prevent communication between workers. 
del updated_config.device_filters[:] updated_config.device_filters.append( - "/job:%s/task:%d" % (self._task_type, self._task_id)) + "/job:%s/task:%d" % (self._task_type, self._task_id) + ) return updated_config @@ -806,36 +924,42 @@ def _get_cross_device_ops(self, value): def _gather_to_implementation(self, value, destinations, axis, options): return self._get_cross_device_ops(value)._gather( # pylint: disable=protected-access - value, - destinations=destinations, - axis=axis, - options=options) + value, destinations=destinations, axis=axis, options=options + ) def _reduce_to(self, reduce_op, value, destinations, options): - if (isinstance(value, values.Mirrored) and - reduce_op == reduce_util.ReduceOp.MEAN): + if ( + isinstance(value, values.Mirrored) + and reduce_op == reduce_util.ReduceOp.MEAN + ): return value assert not isinstance(value, values.Mirrored) - if (isinstance(value, values.DistributedValues) and - len(self.worker_devices) == 1): + if ( + isinstance(value, values.DistributedValues) + and len(self.worker_devices) == 1 + ): value = value.values[0] # When there are multiple workers, we need to reduce across workers using # collective ops. - if (not isinstance(value, values.DistributedValues) and - self._num_workers == 1): + if ( + not isinstance(value, values.DistributedValues) + and self._num_workers == 1 + ): # This function handles reducing values that are not PerReplica or # Mirrored values. For example, the same value could be present on all # replicas in which case `value` would be a single value or value could # be 0. 
return cross_device_ops_lib.reduce_non_distributed_value( - reduce_op, value, destinations, len(self.worker_devices)) + reduce_op, value, destinations, len(self.worker_devices) + ) return self._get_cross_device_ops(value).reduce( reduce_op, value, destinations=destinations, - options=self._communication_options.merge(options)) + options=self._communication_options.merge(options), + ) def _replica_ctx_all_reduce(self, reduce_op, value, options=None): """Implements `StrategyExtendedV2._replica_ctx_all_reduce`.""" @@ -854,12 +978,14 @@ def _replica_ctx_all_reduce(self, reduce_op, value, options=None): replica_context = distribute_lib.get_replica_context() assert replica_context, ( "`StrategyExtended._replica_ctx_all_reduce` must be called in a " - "replica context") + "replica context" + ) return self._cross_device_ops._all_reduce( # pylint: disable=protected-access reduce_op, value, replica_context._replica_id, # pylint: disable=protected-access - options) + options, + ) def _check_health(self): while True: @@ -873,32 +999,44 @@ def _check_health(self): attempts += 1 try: context.context().check_collective_ops_peer_health( - peer, timeout_in_ms=self._check_health_timeout * 1000) + peer, timeout_in_ms=self._check_health_timeout * 1000 + ) # If check_collective_ops_peer_health doesn't raise an Exception, # the peer is healthy. break - except (errors.UnavailableError, errors.FailedPreconditionError, - errors.DeadlineExceededError) as e: + except ( + errors.UnavailableError, + errors.FailedPreconditionError, + errors.DeadlineExceededError, + ) as e: # TODO(b/151232436): Always raise UnavailableError when a peer # fails. Now there could be many kinds of errors: # - Unavailable: when the peer is not reachable, e.g. it's down. # - FailedPrecondition: when the peer has restarted. 
if attempts < self._check_health_retry_limit: - logging.warning("%s seems down, retrying %d/%d", peer, attempts, - self._check_health_retry_limit) + logging.warning( + "%s seems down, retrying %d/%d", + peer, + attempts, + self._check_health_retry_limit, + ) continue logging.error( "Cluster check alive failed, %s is down, " - "aborting collectives: %s", peer, e) + "aborting collectives: %s", + peer, + e, + ) context.context().abort_collective_ops( errors.UNAVAILABLE, - "cluster check alive failed, {} is down".format(peer)) + "cluster check alive failed, {} is down".format(peer), + ) return except Exception as e: # pylint: disable=broad-except logging.error("Unexpected exception in check alive: %s", e) context.context().abort_collective_ops( - errors.INTERNAL, - "unexecpted exception in check alive: %s" % e) + errors.INTERNAL, "unexecpted exception in check alive: %s" % e + ) return time.sleep(self._check_health_interval) @@ -912,8 +1050,10 @@ def _start_check_health_thread(self): # # TODO(b/151232436): change to an explicit barrier if we have it. 
dummy_value = array_ops.identity([]) - logging.info("Waiting for the cluster, timeout = %s", - self._check_health_initial_timeout or "inf") + logging.info( + "Waiting for the cluster, timeout = %s", + self._check_health_initial_timeout or "inf", + ) try: self._host_cross_device_ops.reduce( reduce_util.ReduceOp.SUM, @@ -921,21 +1061,24 @@ def _start_check_health_thread(self): dummy_value, options=collective_util.Options( timeout_seconds=self._check_health_initial_timeout, - implementation=collective_util.CommunicationImplementation.RING)) + implementation=collective_util.CommunicationImplementation.RING, + ), + ) if context.is_async(): context.async_wait() except errors.DeadlineExceededError: raise RuntimeError( - "Timeout waiting for the cluster, timeout is %d seconds" % - self._check_health_initial_timeout) + "Timeout waiting for the cluster, timeout is %d seconds" + % self._check_health_initial_timeout + ) logging.info("Cluster is ready.") self._check_health_thread_should_stop = threading.Event() # Start the thread as daemon to avoid it blocking the program from exiting. # We try best to shutdown the thread but __del__ is not guaranteed to be # called when program exists. self._check_health_thread = threading.Thread( - target=self._check_health, - daemon=True) + target=self._check_health, daemon=True + ) self._check_health_thread.start() def _stop_check_health_thread(self): @@ -947,11 +1090,13 @@ def _stop_check_health_thread(self): logging.info("check health thread stopped") def _warn_nccl_no_gpu(self): - if ((self._communication_options.implementation == - collective_util.CommunicationImplementation.NCCL) and - self._local_device_type != "GPU"): - logging.warning("Enabled NCCL communication but no GPUs detected/" - "specified.") + if ( + self._communication_options.implementation + == collective_util.CommunicationImplementation.NCCL + ) and self._local_device_type != "GPU": + logging.warning( + "Enabled NCCL communication but no GPUs detected/specified." 
+ ) def _in_multi_worker_mode(self): """Whether this strategy indicates working in multi-worker settings.""" @@ -993,15 +1138,17 @@ def _get_replica_id_in_sync_group(self, replica_id): return self._id_in_cluster * len(self.worker_devices) + replica_id def _get_local_replica_id(self, replica_id_in_sync_group): - return (replica_id_in_sync_group - - self._id_in_cluster * len(self.worker_devices)) + return replica_id_in_sync_group - self._id_in_cluster * len( + self.worker_devices + ) def __deepcopy__(self, memo): # We check the check health thread instead of whether we are in eager mode # to limit the backward incompatibility. if hasattr(self, "_check_health_thread"): raise ValueError( - "MultiWorkerMirroredStrategy cannot be deep copied in eager mode.") + "MultiWorkerMirroredStrategy cannot be deep copied in eager mode." + ) # Otherwise, do a regular deepcopy. cls = self.__class__ result = cls.__new__(cls) diff --git a/tensorflow/python/distribute/parameter_server_strategy_v2.py b/tensorflow/python/distribute/parameter_server_strategy_v2.py index b69d879aa6b15c..d6bdcd0ac13d6f 100644 --- a/tensorflow/python/distribute/parameter_server_strategy_v2.py +++ b/tensorflow/python/distribute/parameter_server_strategy_v2.py @@ -72,7 +72,8 @@ @tf_export( "distribute.experimental.ParameterServerStrategy", "distribute.ParameterServerStrategy", - v1=[]) + v1=[], +) class ParameterServerStrategyV2(distribute_lib.Strategy): """An multi-worker tf.distribute strategy with parameter servers. @@ -607,14 +608,16 @@ def __init__( self._being_scheduled = False self._set_num_gpus() distribute_lib.distribution_strategy_replica_gauge.get_cell( - "num_gpus_per_worker").set(self._num_gpus_per_worker) + "num_gpus_per_worker" + ).set(self._num_gpus_per_worker) # Don't canonicalize the devices here since this code is executed on Chief, # but we want the reduce evaluation to be done on each worker. Placer will # automatically choose the right device based on current context. 
# TODO(ishark): Use select_cross_device_ops instead. self._cross_device_ops = cross_device_ops_lib.ReductionToOneDevice( - reduce_to_device="/device:CPU:0") + reduce_to_device="/device:CPU:0" + ) self._cross_device_ops._canonicalize_devices = False # pylint: disable=protected-access self._allow_run_without_coordinator = False self._coordinator_creation_lock = threading.Lock() @@ -650,17 +653,24 @@ def var_creator(**kwargs): # Create and wrap the variable. v = next_creator(**kwargs) wrapped_v = ps_values.CachingVariable(v) - wrapped = ps_values.AggregatingVariable(self._container_strategy(), - wrapped_v, aggregation) + wrapped = ps_values.AggregatingVariable( + self._container_strategy(), wrapped_v, aggregation + ) return wrapped if self._num_replicas_in_sync > 1: - if aggregation not in (vs.VariableAggregation.NONE, - vs.VariableAggregation.SUM, - vs.VariableAggregation.MEAN, - vs.VariableAggregation.ONLY_FIRST_REPLICA): - raise ValueError("Invalid variable aggregation mode: " + aggregation + - " for variable: " + kwargs["name"]) + if aggregation not in ( + vs.VariableAggregation.NONE, + vs.VariableAggregation.SUM, + vs.VariableAggregation.MEAN, + vs.VariableAggregation.ONLY_FIRST_REPLICA, + ): + raise ValueError( + "Invalid variable aggregation mode: " + + aggregation + + " for variable: " + + kwargs["name"] + ) return var_creator else: @@ -673,7 +683,8 @@ def variable_creator_single_replica(**kwargs): def _create_per_worker_variable(self, next_creator, **kwargs): """Create an unsynced, unaggregated variable on each worker.""" return ps_values.PerWorkerVariable( - self._container_strategy(), next_creator, **kwargs) + self._container_strategy(), next_creator, **kwargs + ) def _create_variable(self, next_creator, **kwargs): """Implements StrategyExtendedV2._create_variable. 
@@ -708,7 +719,10 @@ def _create_variable(self, next_creator, **kwargs): var = var_creator(**kwargs) logging.debug( "Creating variable (name:%s, shape:%r) that colocates with %s", - var.name, var.shape, kwargs["colocate_with"].name) + var.name, + var.shape, + kwargs["colocate_with"].name, + ) return var if self._variable_partitioner is None: @@ -766,11 +780,15 @@ def initializer(shape, dtype, **kwargs): return self._create_variable_round_robin(var_creator, **kwargs) num_partitions = self._variable_partitioner(shape=shape, dtype=dtype) - if not num_partitions or num_partitions[0] == 0 or any( - v != 1 for v in num_partitions[1:]): + if ( + not num_partitions + or num_partitions[0] == 0 + or any(v != 1 for v in num_partitions[1:]) + ): raise ValueError( "variable_partitioner must return a list/tuple whose elements are 1" - " besides the first element (non-zero), got: %r" % num_partitions) + " besides the first element (non-zero), got: %r" % num_partitions + ) if num_partitions[0] == 1: # no partition return self._create_variable_round_robin(var_creator, **kwargs) @@ -793,19 +811,25 @@ def initializer(shape, dtype, **kwargs): def init_shard_fn(shard_index): if not init_from_fn: logging.log_if( - logging.WARN, _INEFFICIENT_INIT_WARNING % name, shard_index == 0 and - shape.num_elements() > _LARGE_VARIABLE_NUM_ELEMENTS) - return initial_value[offsets[shard_index]:offsets[shard_index + 1]] - partition_shape = (offsets[shard_index + 1] - - offsets[shard_index],) + shape[1:] + logging.WARN, + _INEFFICIENT_INIT_WARNING % name, + shard_index == 0 + and shape.num_elements() > _LARGE_VARIABLE_NUM_ELEMENTS, + ) + return initial_value[offsets[shard_index] : offsets[shard_index + 1]] + partition_shape = ( + offsets[shard_index + 1] - offsets[shard_index], + ) + shape[1:] partition_offset = (offsets[shard_index],) + (0,) * len(shape[1:]) arg_spec = tf_inspect.getfullargspec(initial_value) - if ("shard_info" not in arg_spec.args and - "shard_info" not in arg_spec.kwonlyargs): + if ( 
+ "shard_info" not in arg_spec.args + and "shard_info" not in arg_spec.kwonlyargs + ): try: value = initial_value( - partition_shape=partition_shape, - partition_offset=partition_offset) + partition_shape=partition_shape, partition_offset=partition_offset + ) except (TypeError, ValueError): # TypeError: Initializer doesn't accept kwargs # ValueError: Initializer doesn't accept partition kwargs @@ -819,16 +843,20 @@ def init_shard_fn(shard_index): # Initializer doesn't support partition: value is the full value # and needs to be sliced to get the partition value. logging.log_if( - logging.WARN, _INEFFICIENT_INIT_WARNING % name, - shard_index == 0 and - shape.num_elements() > _LARGE_VARIABLE_NUM_ELEMENTS) - return value[offsets[shard_index]:offsets[shard_index + 1]] + logging.WARN, + _INEFFICIENT_INIT_WARNING % name, + shard_index == 0 + and shape.num_elements() > _LARGE_VARIABLE_NUM_ELEMENTS, + ) + return value[offsets[shard_index] : offsets[shard_index + 1]] else: # For compatibility with `CheckpointInitialValueCallable`. return initial_value( shard_info=trackable.ShardInfo( shape=tensor_shape.as_shape(partition_shape), - offset=partition_offset)) + offset=partition_offset, + ) + ) var_list = [] for i in range(num_partitions): @@ -847,17 +875,22 @@ def _create_variable_round_robin(self, next_creator, **kwargs): with ops.colocate_with(None, ignore_existing=True): # Explicitly set CPU:0 device for PS in case create variable is called # inside replica_fn and worker has with GPU:0 scope. 
- with ops.device("/job:ps/task:%d/device:CPU:0" % - (self._variable_count % self._num_ps)): + with ops.device( + "/job:ps/task:%d/device:CPU:0" % (self._variable_count % self._num_ps) + ): var = next_creator(**kwargs) log_method = ( - logging.info if os.getenv("TF_PSS_VERBOSE_VARIABLE_PLACEMENT") + logging.info + if os.getenv("TF_PSS_VERBOSE_VARIABLE_PLACEMENT") else logging.debug ) log_method( "Creating variable (name:%s, shape:%r) on " - "/job:ps/task:%d/device:CPU:0", var.name, var.shape, - (self._variable_count % self._num_ps)) + "/job:ps/task:%d/device:CPU:0", + var.name, + var.shape, + (self._variable_count % self._num_ps), + ) self._variable_count += 1 return var @@ -866,7 +899,8 @@ def _resource_creator_scope(self): with self._coordinator_creation_lock: if not self._container_strategy()._cluster_coordinator: # pylint: disable=protected-access cluster_coordinator.ClusterCoordinator( - strategy=self._container_strategy()) + strategy=self._container_strategy() + ) # TODO(wxinyi): We should warn the user of the inefficiency of creating # `StaticHashTable` inside a `@tf.function`-wrapped `dataset_fn` to be @@ -880,31 +914,38 @@ def _resource_creator_scope(self): def lookup_creator(next_creator, *args, **kwargs): if load_context.in_load_context(): - return (ps_values.RestoredDistributedTable( - self._container_strategy(), lambda: next_creator(*args, **kwargs))) # pylint: disable=protected-access + return ps_values.RestoredDistributedTable( + self._container_strategy(), lambda: next_creator(*args, **kwargs) + ) # pylint: disable=protected-access else: - return ps_values.DistributedTable(self._container_strategy(), - lambda: next_creator(*args, **kwargs)) # pylint: disable=protected-access + return ps_values.DistributedTable( + self._container_strategy(), lambda: next_creator(*args, **kwargs) + ) # pylint: disable=protected-access def restored_lookup_creator(next_creator, *args, **kwargs): - return (ps_values.RestoredDistributedTable( - self._container_strategy(), 
lambda: next_creator(*args, **kwargs))) # pylint: disable=protected-access + return ps_values.RestoredDistributedTable( + self._container_strategy(), lambda: next_creator(*args, **kwargs) + ) # pylint: disable=protected-access return [ ops.resource_creator_scope("StaticHashTable", lookup_creator), - ops.resource_creator_scope("RestoredStaticHashTable", - restored_lookup_creator) + ops.resource_creator_scope( + "RestoredStaticHashTable", restored_lookup_creator + ), ] def _assert_used_with_cluster_coordinator(self): - if (not self._used_with_coordinator and - not self._allow_run_without_coordinator): + if ( + not self._used_with_coordinator + and not self._allow_run_without_coordinator + ): raise NotImplementedError( "`tf.distribute.experimental.ParameterServerStrategy` must be used " "with `tf.distribute.experimental.coordinator.ClusterCoordinator` in " "a custom training loop. If you are using `Model.fit`, please supply " "a dataset function directly to a " - "`tf.keras.utils.experimental.DatasetCreator` instead.") + "`tf.keras.utils.experimental.DatasetCreator` instead." + ) def _assert_being_scheduled_by_cluster_coordinator(self): if not self._being_scheduled and not self._allow_run_without_coordinator: @@ -915,14 +956,16 @@ def _assert_being_scheduled_by_cluster_coordinator(self): "coordinator, which can be slow. To properly dispatch functions to " "run on workers, methods like `run` or `reduce` should be used " "within a function passed to `tf.distribute.experimental.coordinator." - "ClusterCoordinator.schedule`.") + "ClusterCoordinator.schedule`." + ) # options is not used right now. But we may want to support options while # creating InputWorkers in future, similar to MirroredStrategy. 
def _input_workers_with_options(self, options=None): input_workers_devices = (("/device:CPU:0", self.worker_devices),) return input_lib.InputWorkers( - input_workers_devices, canonicalize_devices=False) + input_workers_devices, canonicalize_devices=False + ) def _experimental_distribute_dataset(self, dataset, options): input_workers_devices = self._input_workers_with_options() @@ -936,7 +979,8 @@ def _experimental_distribute_dataset(self, dataset, options): self._container_strategy(), num_replicas_in_sync=self._num_replicas_in_sync, options=options, - build=ops.inside_function()) # will be built by ClusterCoordinator + build=ops.inside_function(), + ) # will be built by ClusterCoordinator def _distribute_datasets_from_function(self, dataset_fn, options): # There is no synchronization beyond a worker and thus, the number of @@ -947,7 +991,8 @@ def _distribute_datasets_from_function(self, dataset_fn, options): input_context = distribute_lib.InputContext( num_input_pipelines=num_input_pipelines_in_sync, input_pipeline_id=input_pipeline_id_in_sync, - num_replicas_in_sync=self._num_replicas_in_sync) + num_replicas_in_sync=self._num_replicas_in_sync, + ) # If this DistributedDatasetFromFunction is created outside # ClusterCoordinator, i,e, outside a tf.function, we don't build its @@ -955,10 +1000,12 @@ def _distribute_datasets_from_function(self, dataset_fn, options): # ClusterCoordinator.create_per_worker_dataset. 
return input_util.get_distributed_datasets_from_function( dataset_fn, - self._input_workers_with_options(options), [input_context], + self._input_workers_with_options(options), + [input_context], self._container_strategy(), options=options, - build=ops.inside_function()) # will be built by ClusterCoordinator + build=ops.inside_function(), + ) # will be built by ClusterCoordinator @property def worker_devices(self): @@ -972,15 +1019,17 @@ def worker_devices(self): def _call_for_each_replica(self, fn, args, kwargs): self._assert_being_scheduled_by_cluster_coordinator() - return mirrored_run.call_for_each_replica(self._container_strategy(), fn, - args, kwargs) + return mirrored_run.call_for_each_replica( + self._container_strategy(), fn, args, kwargs + ) def _reduce(self, reduce_op, value): self._assert_being_scheduled_by_cluster_coordinator() dst = device_util.current() or self._default_device or "/device:CPU:0" destinations = device_util.canonicalize_without_job_and_task(dst) result = self._local_results( - self.reduce_to(reduce_op, value, destinations))[0] + self.reduce_to(reduce_op, value, destinations) + )[0] return result def _reduce_to(self, reduce_op, value, destinations, options): @@ -989,7 +1038,8 @@ def _reduce_to(self, reduce_op, value, destinations, options): def get_values(x): if isinstance(x, values.DistributedValues): return self._cross_device_ops.reduce( - reduce_op, x, destinations=destinations) # pylint: disable=protected-access + reduce_op, x, destinations=destinations + ) # pylint: disable=protected-access return x return nest.map_structure(get_values, value) @@ -1004,6 +1054,7 @@ def get_values(x): "footprint, explicitly specify `dtype` and `shape` when creating " "variables, and use `tf.initializers` to initialize the variable. " "Note that some initializers (e.g., orthogonal) don't support " - "memory-efficient initialization and there is not much you can do here.") + "memory-efficient initialization and there is not much you can do here." 
+) _LARGE_VARIABLE_NUM_ELEMENTS = 1e9 diff --git a/tensorflow/python/eager/context.py b/tensorflow/python/eager/context.py index e400fbaa117209..fee09b7dc4a599 100644 --- a/tensorflow/python/eager/context.py +++ b/tensorflow/python/eager/context.py @@ -70,7 +70,8 @@ DEVICE_PLACEMENT_WARN = pywrap_tfe.TFE_DEVICE_PLACEMENT_WARN DEVICE_PLACEMENT_SILENT = pywrap_tfe.TFE_DEVICE_PLACEMENT_SILENT DEVICE_PLACEMENT_SILENT_FOR_INT32 = ( - pywrap_tfe.TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32) + pywrap_tfe.TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32 +) SYNC = 0 ASYNC = 1 @@ -79,7 +80,8 @@ _python_eager_context_create_counter = monitoring.Counter( "/tensorflow/api/python/eager_context_create_counter", - "Counter for number of eager contexts created in Python.") + "Counter for number of eager contexts created in Python.", +) # Re-exporting through context. is_tfrt_enabled = tfrt_utils.enabled @@ -239,16 +241,20 @@ def config_proto_serialized(self): def config_proto_serialized(self, config): if isinstance(config, config_pb2.ConfigProto): self._config_proto_serialized = config.SerializeToString( - deterministic=True) + deterministic=True + ) elif isinstance(config, str): self._config_proto_serialized = config elif config is None: self._config_proto_serialized = ( - config_pb2.ConfigProto().SerializeToString()) + config_pb2.ConfigProto().SerializeToString() + ) else: - raise ValueError("the rewriter config must be either a " - "config_pb2.ConfigProto, or a serialized string of that " - "proto or None. got: {}".format(type(config))) + raise ValueError( + "the rewriter config must be either a " + "config_pb2.ConfigProto, or a serialized string of that " + "proto or None. 
got: {}".format(type(config)) + ) def as_attrs(self): if self.config_proto_serialized is None: @@ -291,7 +297,8 @@ def zeros_cache(self): ContextSwitch = collections.namedtuple( "ContextSwitch", - ["is_building_function", "enter_context_fn", "device_stack"]) + ["is_building_function", "enter_context_fn", "device_stack"], +) # `_ContextSwitchStack` is a `threading.local` to match the semantics of @@ -311,7 +318,8 @@ def __init__(self, eager): self.push( is_building_function=False, enter_context_fn=eager_mode, - device_stack=None) + device_stack=None, + ) def push(self, is_building_function, enter_context_fn, device_stack): """Push metadata about a context switch onto the stack. @@ -331,7 +339,8 @@ def push(self, is_building_function, enter_context_fn, device_stack): """ self.stack.append( - ContextSwitch(is_building_function, enter_context_fn, device_stack)) + ContextSwitch(is_building_function, enter_context_fn, device_stack) + ) def pop(self): """Pop the stack.""" @@ -341,7 +350,8 @@ def pop(self): @tf_export("config.LogicalDevice") class LogicalDevice( - collections.namedtuple("LogicalDevice", ["name", "device_type"])): + collections.namedtuple("LogicalDevice", ["name", "device_type"]) +): """Abstraction for a logical device initialized by the runtime. A `tf.config.LogicalDevice` corresponds to an initialized logical device on a @@ -356,12 +366,20 @@ class LogicalDevice( """ -@tf_export("config.LogicalDeviceConfiguration", - "config.experimental.VirtualDeviceConfiguration") +@tf_export( + "config.LogicalDeviceConfiguration", + "config.experimental.VirtualDeviceConfiguration", +) class LogicalDeviceConfiguration( - collections.namedtuple("LogicalDeviceConfiguration", [ - "memory_limit", "experimental_priority", "experimental_device_ordinal" - ])): + collections.namedtuple( + "LogicalDeviceConfiguration", + [ + "memory_limit", + "experimental_priority", + "experimental_device_ordinal", + ], + ) +): """Configuration class for a logical devices. 
The class specifies the parameters to configure a `tf.config.PhysicalDevice` @@ -385,17 +403,21 @@ class LogicalDeviceConfiguration( Currently only supported for Nvidia GPUs. """ - def __new__(cls, - memory_limit=None, - experimental_priority=None, - experimental_device_ordinal=None): - return super().__new__(cls, memory_limit, experimental_priority, - experimental_device_ordinal) + def __new__( + cls, + memory_limit=None, + experimental_priority=None, + experimental_device_ordinal=None, + ): + return super().__new__( + cls, memory_limit, experimental_priority, experimental_device_ordinal + ) @tf_export("config.PhysicalDevice") class PhysicalDevice( - collections.namedtuple("PhysicalDevice", ["name", "device_type"])): + collections.namedtuple("PhysicalDevice", ["name", "device_type"]) +): """Abstraction for a locally visible physical device. TensorFlow can utilize various devices such as the CPU or multiple GPUs @@ -415,6 +437,7 @@ class PhysicalDevice( name: Unique identifier for device. device_type: String declaring the type of device such as "CPU" or "GPU". """ + pass @@ -458,11 +481,13 @@ class Context: # TODO(agarwal): create and link in some documentation for `execution_mode`. # pylint: disable=redefined-outer-name - def __init__(self, - config=None, - device_policy=None, - execution_mode=None, - server_def=None): + def __init__( + self, + config=None, + device_policy=None, + execution_mode=None, + server_def=None, + ): """Creates a new Context. Args: @@ -473,24 +498,18 @@ def __init__(self, operation on a device with inputs which are not on that device. When set to None, an appropriate value will be picked automatically. The value picked may change between TensorFlow releases. Defaults to - DEVICE_PLACEMENT_SILENT. - Valid values: - - DEVICE_PLACEMENT_EXPLICIT: raises an error if the placement is not - correct. - - DEVICE_PLACEMENT_WARN: copies the tensors which are not on the right - device but raises a warning. 
- - DEVICE_PLACEMENT_SILENT: silently copies the tensors. This might hide - performance problems. - - DEVICE_PLACEMENT_SILENT_FOR_INT32: silently copies int32 tensors, - raising errors on the other ones. + DEVICE_PLACEMENT_SILENT. Valid values: DEVICE_PLACEMENT_EXPLICIT - + raises an error if the placement is not correct. DEVICE_PLACEMENT_WARN - + copies the tensors which are not on the right device but raises a + warning. DEVICE_PLACEMENT_SILENT - silently copies the tensors. This + might hide performance problems. DEVICE_PLACEMENT_SILENT_FOR_INT32 - + silently copies int32 tensors, raising errors on the other ones. execution_mode: (Optional.) Policy controlling how operations dispatched are actually executed. When set to None, an appropriate value will be picked automatically. The value picked may change between TensorFlow - releases. - Valid values: - - SYNC: executes each operation synchronously. - - ASYNC: executes each operation asynchronously. These operations may - return "non-ready" handles. + releases. Valid values: - SYNC: executes each operation synchronously. + ASYNC - executes each operation asynchronously. These operations may + return "non-ready" handles. server_def: (Optional.) A tensorflow::ServerDef proto. Enables execution on remote devices. GrpcServers need to be started by creating an identical server_def to this, and setting the appropriate task_indexes, @@ -510,7 +529,8 @@ def __init__(self, self._thread_local_data = pywrap_tfe.EagerContextThreadLocalData( self, is_eager=lambda: default_execution_mode == EAGER_MODE, - device_spec=_starting_device_spec) + device_spec=_starting_device_spec, + ) self._context_switches = _ContextSwitchStack(self.executing_eagerly()) self._context_handle = None self._context_devices = None @@ -522,8 +542,9 @@ def __init__(self, self._device_policy = device_policy self._mirroring_policy = None if execution_mode not in (None, SYNC, ASYNC): - raise ValueError("execution_mode should be None/SYNC/ASYNC. 
Got %s" % - execution_mode) + raise ValueError( + "execution_mode should be None/SYNC/ASYNC. Got %s" % execution_mode + ) if execution_mode is None: execution_mode = SYNC self._default_is_async = execution_mode == ASYNC @@ -618,10 +639,14 @@ def _initialize_logical_devices(self): if spec.job == "localhost": spec = spec.replace(job=None, replica=None, task=None) logical_devices.append( - LogicalDevice(name=spec.to_string(), device_type=spec.device_type)) + LogicalDevice(name=spec.to_string(), device_type=spec.device_type) + ) dev_type = pywrap_tfe.TF_DeviceListType(device_list, i) - if (dev_type == "GPU" and spec.job == current_job and - spec.task == current_task): + if ( + dev_type == "GPU" + and spec.job == current_job + and spec.task == current_task + ): self._num_gpus += 1 finally: @@ -643,29 +668,37 @@ def ensure_initialized(self): pywrap_tfe.TFE_ContextOptionsSetConfig(opts, config_str) if self._device_policy is not None: pywrap_tfe.TFE_ContextOptionsSetDevicePlacementPolicy( - opts, self._device_policy) + opts, self._device_policy + ) if self._mirroring_policy is not None: pywrap_tfe.TFE_ContextOptionsSetMirroringPolicy( - opts, self._mirroring_policy) + opts, self._mirroring_policy + ) if self._default_is_async == ASYNC: pywrap_tfe.TFE_ContextOptionsSetAsync(opts, True) if self._use_tfrt is not None: pywrap_tfe.TFE_ContextOptionsSetTfrt(opts, self._use_tfrt) pywrap_tfe.TFE_ContextOptionsSetRunEagerOpAsFunction(opts, True) pywrap_tfe.TFE_ContextOptionsSetJitCompileRewrite( - opts, self._jit_compile_rewrite) + opts, self._jit_compile_rewrite + ) context_handle = pywrap_tfe.TFE_NewContext(opts) finally: pywrap_tfe.TFE_DeleteContextOptions(opts) assert not (self._server_def and self._collective_ops_server_def), ( "Cannot enable remote execution as well as collective ops at the " - "moment. If this is important to you, please file an issue.") + "moment. If this is important to you, please file an issue." 
+ ) if self._server_def is not None: server_def_str = self._server_def.SerializeToString() timeout = 0 # Indicates no timeout. pywrap_tfe.TFE_ContextSetServerDefWithTimeoutAndRetries( - context_handle, _KEEP_ALIVE_SECS, server_def_str, timeout, - self._set_server_def_retries) + context_handle, + _KEEP_ALIVE_SECS, + server_def_str, + timeout, + self._set_server_def_retries, + ) elif self._collective_ops_server_def is not None: server_def_str = self._collective_ops_server_def.SerializeToString() pywrap_tfe.TFE_EnableCollectiveOps(context_handle, server_def_str) @@ -733,8 +766,9 @@ def set_server_def(self, server_def, keep_alive_secs=_KEEP_ALIVE_SECS): if self._context_handle: server_def_str = server_def.SerializeToString() - pywrap_tfe.TFE_ContextSetServerDef(self._context_handle, keep_alive_secs, - server_def_str) + pywrap_tfe.TFE_ContextSetServerDef( + self._context_handle, keep_alive_secs, server_def_str + ) self._initialize_logical_devices() # Clear all the caches in case there are remote tensors in them. @@ -766,8 +800,9 @@ def update_server_def(self, server_def, keep_alive_secs=_KEEP_ALIVE_SECS): if self._context_handle: server_def_str = server_def.SerializeToString() - pywrap_tfe.TFE_ContextUpdateServerDef(self._context_handle, - keep_alive_secs, server_def_str) + pywrap_tfe.TFE_ContextUpdateServerDef( + self._context_handle, keep_alive_secs, server_def_str + ) self._initialize_logical_devices() self._clear_caches() @@ -777,7 +812,7 @@ def check_alive(self, worker_name): Args: worker_name: a string representing the remote worker. It must be a fully - specified name like "/job:worker/replica:0/task:0". + specified name like "/job:worker/replica:0/task:0". Returns: a boolean indicating whether the remote worker is alive or not. 
@@ -824,19 +859,23 @@ def clear_executor_errors(self): else: raise ValueError("Context is not initialized.") - def configure_coordination_service(self, - service_type, - service_leader="", - enable_health_check=True, - cluster_register_timeout_in_ms=0, - heartbeat_timeout_in_ms=0, - shutdown_barrier_timeout_in_ms=0, - coordinated_jobs=None, - allow_new_incarnation_to_reconnect=False): + def configure_coordination_service( + self, + service_type, + service_leader="", + enable_health_check=True, + cluster_register_timeout_in_ms=0, + heartbeat_timeout_in_ms=0, + shutdown_barrier_timeout_in_ms=0, + coordinated_jobs=None, + allow_new_incarnation_to_reconnect=False, + ): """Enable distributed coordination service with specified configs.""" if self._context_handle: - logging.warning("Configuring coordination service type may not be " - "effective because the context is already initialized.") + logging.warning( + "Configuring coordination service type may not be " + "effective because the context is already initialized." 
+ ) config = coordination_config_pb2.CoordinationServiceConfig() config.service_type = service_type if service_leader: @@ -846,13 +885,16 @@ def configure_coordination_service(self, config.heartbeat_timeout_in_ms = heartbeat_timeout_in_ms config.shutdown_barrier_timeout_in_ms = shutdown_barrier_timeout_in_ms config.allow_new_incarnation_to_reconnect = ( - allow_new_incarnation_to_reconnect) + allow_new_incarnation_to_reconnect + ) if coordinated_jobs is not None: if isinstance(coordinated_jobs, list): config.coordinated_job_list.extend(coordinated_jobs) else: - raise ValueError("`coordinated_jobs` must be list[CoordinatedJob] or " - "None, but got: %s" % (coordinated_jobs,)) + raise ValueError( + "`coordinated_jobs` must be list[CoordinatedJob] or " + "None, but got: %s" % (coordinated_jobs,) + ) self._coordination_service_config = config @property @@ -868,8 +910,9 @@ def set_config_key_value(self, key, value): def get_config_key_value(self, key, timeout_in_ms=0): ensure_initialized() with c_api_util.tf_buffer() as buffer_: - pywrap_tfe.TFE_GetConfigKeyValue(self._context_handle, key, - timeout_in_ms, buffer_) + pywrap_tfe.TFE_GetConfigKeyValue( + self._context_handle, key, timeout_in_ms, buffer_ + ) value = pywrap_tf_session.TF_GetBuffer(buffer_).decode("utf-8") return value @@ -885,8 +928,9 @@ def report_error_to_cluster(self, error_code, error_message): error_message: a string. The error message. 
""" if self._context_handle: - pywrap_tfe.TFE_ReportErrorToCluster(self._context_handle, error_code, - error_message) + pywrap_tfe.TFE_ReportErrorToCluster( + self._context_handle, error_code, error_message + ) else: raise ValueError("Context is not initialized.") @@ -901,8 +945,9 @@ def get_task_states(self, job_configs): """ if self._context_handle: job_names, task_nums = zip(*job_configs) - return pywrap_tfe.TFE_GetTaskStates(self._context_handle, job_names, - task_nums) + return pywrap_tfe.TFE_GetTaskStates( + self._context_handle, job_names, task_nums + ) else: raise ValueError("Context is not initialized.") @@ -916,8 +961,9 @@ def wait_at_barrier(self, barrier_id, timeout_in_ms): timeout_in_ms: Duration before the barrier times out and fails. """ ensure_initialized() - pywrap_tfe.TFE_WaitAtBarrier(self._context_handle, barrier_id, - timeout_in_ms) + pywrap_tfe.TFE_WaitAtBarrier( + self._context_handle, barrier_id, timeout_in_ms + ) def clear_kernel_cache(self): """Clear kernel cache and reset all stateful kernels.""" @@ -943,8 +989,10 @@ def enable_collective_ops(self, server_def): # TODO(b/129298253): Allow creating datasets/tensors before enabling # collective ops. if self._context_handle is not None: - logging.warning("Enabling collective ops after program startup may cause " - "error when accessing previously created tensors.") + logging.warning( + "Enabling collective ops after program startup may cause " + "error when accessing previously created tensors." + ) with self._initialize_lock: assert self._initialized server_def_str = self._collective_ops_server_def.SerializeToString() @@ -957,7 +1005,8 @@ def configure_collective_ops( collective_leader="", scoped_allocator_enabled_ops=("CollectiveReduce",), use_nccl_communication=False, - device_filters=None): + device_filters=None, + ): """Configure collective ops. 
Collective group leader is necessary for collective ops to run, other @@ -966,7 +1015,7 @@ def configure_collective_ops( Args: collective_leader: a device string for collective leader, e.g. "/job:worker/replica:0/task:0"; empty string means local execution of - collective ops. + collective ops. scoped_allocator_enabled_ops: a tuple or a list of op names for scoped allocator to run with. use_nccl_communication: whether to use nccl communication for collective @@ -978,11 +1027,13 @@ def configure_collective_ops( RuntimeError: if this method is not called at program startup. """ if self._collective_leader is not None: - if (self._collective_leader != collective_leader or - self._collective_scoped_allocator_enabled_ops != - scoped_allocator_enabled_ops or - self._collective_use_nccl_communication != use_nccl_communication or - self._collective_device_filters != device_filters): + if ( + self._collective_leader != collective_leader + or self._collective_scoped_allocator_enabled_ops + != scoped_allocator_enabled_ops + or self._collective_use_nccl_communication != use_nccl_communication + or self._collective_device_filters != device_filters + ): raise ValueError("Collective ops are already configured.") else: return @@ -1029,8 +1080,9 @@ def check_collective_ops_peer_health(self, task, timeout_in_ms): tf.errors.InvalidArgumentError: when the task string is invalid. """ self.ensure_initialized() - pywrap_tfe.TFE_CollectiveOpsCheckPeerHealth(self._handle, task, - timeout_in_ms) + pywrap_tfe.TFE_CollectiveOpsCheckPeerHealth( + self._handle, task, timeout_in_ms + ) @property def _handle(self): @@ -1150,21 +1202,23 @@ def execution_mode(self): def execution_mode(self, mode): """Sets execution mode for current thread.""" if mode not in (None, SYNC, ASYNC): - raise ValueError("Execution mode should be None/SYNC/ASYNC. Got %s" % - mode) + raise ValueError( + "Execution mode should be None/SYNC/ASYNC. 
Got %s" % mode + ) if mode is None: mode = SYNC - enable_async = (mode == ASYNC) + enable_async = mode == ASYNC if self.is_async() != enable_async: # Only set the execution mode if the context has already been initialized if self._context_handle is not None: self.executor.wait() executor_new = executor.new_executor(enable_async) self._thread_local_data.executor = executor_new - pywrap_tfe.TFE_ContextSetExecutorForThread(self._context_handle, - executor_new.handle()) + pywrap_tfe.TFE_ContextSetExecutorForThread( + self._context_handle, executor_new.handle() + ) else: self._default_is_async = enable_async @@ -1178,7 +1232,8 @@ def is_async(self): def executor(self): self.ensure_initialized() return executor.Executor( - pywrap_tfe.TFE_ContextGetExecutorForThread(self._context_handle)) + pywrap_tfe.TFE_ContextGetExecutorForThread(self._context_handle) + ) @executor.setter def executor(self, e): @@ -1198,7 +1253,9 @@ def config(self): if self._optimizer_jit is not None: config.graph_options.optimizer_options.global_jit_level = ( config_pb2.OptimizerOptions.ON_1 - if self._optimizer_jit else config_pb2.OptimizerOptions.OFF) + if self._optimizer_jit + else config_pb2.OptimizerOptions.OFF + ) if self._intra_op_parallelism_threads is not None: config.intra_op_parallelism_threads = self._intra_op_parallelism_threads if self._inter_op_parallelism_threads is not None: @@ -1217,22 +1274,31 @@ def config(self): is_mlir_bridge_enabled = pywrap_tfe.TF_IsMlirBridgeEnabled() config.experimental.mlir_bridge_rollout = is_mlir_bridge_enabled - if (is_mlir_bridge_enabled == - config_pb2.ConfigProto.Experimental.MLIR_BRIDGE_ROLLOUT_ENABLED): + if ( + is_mlir_bridge_enabled + == config_pb2.ConfigProto.Experimental.MLIR_BRIDGE_ROLLOUT_ENABLED + ): config.experimental.enable_mlir_bridge = True if self._enable_mlir_graph_optimization is not None: config.experimental.enable_mlir_graph_optimization = ( - self._enable_mlir_graph_optimization) + self._enable_mlir_graph_optimization + ) def 
rewriter_toggle(option): toggle = self._optimizer_experimental_options.get(option, None) if toggle is None: return - setattr(config.graph_options.rewrite_options, option, - (rewriter_config_pb2.RewriterConfig.ON - if toggle else rewriter_config_pb2.RewriterConfig.OFF)) + setattr( + config.graph_options.rewrite_options, + option, + ( + rewriter_config_pb2.RewriterConfig.ON + if toggle + else rewriter_config_pb2.RewriterConfig.OFF + ), + ) def rewriter_bool(option): toggle = self._optimizer_experimental_options.get(option, None) @@ -1286,7 +1352,8 @@ def rewriter_bool(option): if self._collective_scoped_allocator_enabled_ops: rewrite_options = config.graph_options.rewrite_options rewrite_options.scoped_allocator_optimization = ( - rewriter_config_pb2.RewriterConfig.ON) + rewriter_config_pb2.RewriterConfig.ON + ) del rewrite_options.scoped_allocator_opts.enable_op[:] for op in self._collective_scoped_allocator_enabled_ops: rewrite_options.scoped_allocator_opts.enable_op.append(op) @@ -1300,7 +1367,8 @@ def rewriter_bool(option): # Configure coordination service if self._coordination_service_config: config.experimental.coordination_config.CopyFrom( - self._coordination_service_config) + self._coordination_service_config + ) return config @@ -1345,13 +1413,16 @@ def _compute_gpu_options(self): # devices. 
if device_ordinals and len(device_limits) != len(device_ordinals): raise ValueError( - "device_ordinals must be specified for all virtual devices") + "device_ordinals must be specified for all virtual devices" + ) virtual_devices.append( config_pb2.GPUOptions.Experimental.VirtualDevices( memory_limit_mb=device_limits, priority=priority, - device_ordinal=device_ordinals)) + device_ordinal=device_ordinals, + ) + ) # Only compute growth if virtual devices have not been configured and we # have GPUs @@ -1366,7 +1437,9 @@ def _compute_gpu_options(self): allow_growth=allow_growth, visible_device_list=",".join(visible_device_list), experimental=config_pb2.GPUOptions.Experimental( - virtual_devices=virtual_devices)) + virtual_devices=virtual_devices + ), + ) @property def function_call_options(self): @@ -1383,7 +1456,8 @@ def function_call_options(self): if self._soft_device_placement is None: config.allow_soft_placement = True self._thread_local_data.function_call_options = FunctionCallOptions( - config_proto=config) + config_proto=config + ) return self._thread_local_data.function_call_options @@ -1490,12 +1564,14 @@ def is_custom_device(self, device_name): self.ensure_initialized() return pywrap_tfe.TFE_Py_IsCustomDevice(self._handle, device_name) - def register_custom_device(self, device_capsule, device_name, - device_info_capsule): + def register_custom_device( + self, device_capsule, device_name, device_info_capsule + ): """Calls TFE_RegisterCustomDevice. See the non-member function.""" self.ensure_initialized() - pywrap_tfe.TFE_Py_RegisterCustomDevice(self._handle, device_capsule, - device_name, device_info_capsule) + pywrap_tfe.TFE_Py_RegisterCustomDevice( + self._handle, device_capsule, device_name, device_info_capsule + ) def pack_eager_tensors(self, tensors): """Pack multiple `EagerTensor`s of the same dtype and shape. 
@@ -1542,9 +1618,7 @@ def function_scope_id(self): def call_function(self, name, tensor_inputs, num_outputs): """Calls the function associated with the given name.""" attrs = tuple( - itertools.chain( - *self.function_call_options.as_attrs().items() - ) + itertools.chain(*self.function_call_options.as_attrs().items()) ) cancellation_context = cancellation.context() @@ -1597,10 +1671,13 @@ def remove_op_callback(self, callback): KeyError: If `callback` is not already registered. """ if callback not in self._thread_local_data.op_callbacks: - raise KeyError("The specified op callback has not been registered, " - "and hence cannot be removed.") + raise KeyError( + "The specified op callback has not been registered, " + "and hence cannot be removed." + ) del self._thread_local_data.op_callbacks[ - self._thread_local_data.op_callbacks.index(callback)] + self._thread_local_data.op_callbacks.index(callback) + ] @property def op_callbacks(self): @@ -1690,13 +1767,18 @@ def get_device_details(self, device): # pylint: disable=redefined-outer-name A dict with string keys. 
""" if not isinstance(device, PhysicalDevice): - raise ValueError("device must be a tf.config.PhysicalDevice, but got: " - "%s" % (device,)) - if (self._physical_device_to_index is None or - device not in self._physical_device_to_index): - raise ValueError("The PhysicalDevice must be one obtained from " - "calling `tf.config.list_physical_devices`, but got: " - "%s" % (device,)) + raise ValueError( + "device must be a tf.config.PhysicalDevice, but got: %s" % (device,) + ) + if ( + self._physical_device_to_index is None + or device not in self._physical_device_to_index + ): + raise ValueError( + "The PhysicalDevice must be one obtained from " + "calling `tf.config.list_physical_devices`, but got: " + "%s" % (device,) + ) index = self._physical_device_to_index[device] details = pywrap_tfe.TF_GetDeviceDetails(index) @@ -1705,9 +1787,11 @@ def get_device_details(self, device): # pylint: disable=redefined-outer-name try: major, minor = details["compute_capability"].split(".") details["compute_capability"] = (int(major), int(minor)) - except ValueError: - raise RuntimeError("Device returned compute capability an in invalid " - "format: %s" % details["compute_capability"]) + except ValueError as exc: + raise RuntimeError( + "Device returned compute capability an in invalid format: %s" + % details["compute_capability"] + ) from exc return details def _import_config(self): @@ -1728,7 +1812,8 @@ class representation. 
self.set_visible_devices([], "CPU") elif num_cpus > 1: self.set_logical_device_configuration( - cpus[0], [LogicalDeviceConfiguration() for _ in range(num_cpus)]) + cpus[0], [LogicalDeviceConfiguration() for _ in range(num_cpus)] + ) # Parse GPU options gpus = [d for d in self._physical_devices if d.device_type == "GPU"] @@ -1801,7 +1886,8 @@ def set_visible_devices(self, devices, device_type=None): if self._context_handle is not None: raise RuntimeError( - "Visible devices cannot be modified after being initialized") + "Visible devices cannot be modified after being initialized" + ) self._visible_device_list = visible_device_list @@ -1835,18 +1921,21 @@ def set_memory_growth(self, dev, enable): if dev in self._virtual_device_map: raise ValueError( - "Cannot set memory growth on device when virtual devices configured") + "Cannot set memory growth on device when virtual devices configured" + ) if dev.device_type != "GPU" and dev not in self._pluggable_devices: raise ValueError( - "Cannot set memory growth on non-GPU and non-Pluggable devices") + "Cannot set memory growth on non-GPU and non-Pluggable devices" + ) if self._memory_growth_map.get(dev) == enable: return if self._context_handle is not None: raise RuntimeError( - "Physical devices cannot be modified after being initialized") + "Physical devices cannot be modified after being initialized" + ) self._memory_growth_map[dev] = enable @@ -1869,29 +1958,38 @@ def set_logical_device_configuration(self, dev, virtual_devices): if dev.device_type == "CPU": for vdev in virtual_devices: if vdev.memory_limit is not None: - raise ValueError("Setting memory limit on CPU virtual devices is " - "currently not supported") + raise ValueError( + "Setting memory limit on CPU virtual devices is " + "currently not supported" + ) if vdev.experimental_priority is not None: - raise ValueError("Setting experimental_priority on CPU virtual " - " devices is currently not supported") + raise ValueError( + "Setting experimental_priority 
on CPU virtual " + " devices is currently not supported" + ) if vdev.experimental_device_ordinal is not None: - raise ValueError("Setting experimental_device_ordinal on CPU virtual " - " devices is currently not supported") + raise ValueError( + "Setting experimental_device_ordinal on CPU virtual " + " devices is currently not supported" + ) elif dev.device_type == "GPU": for vdev in virtual_devices: if vdev.memory_limit is None: raise ValueError( - "Setting memory limit is required for GPU virtual devices") + "Setting memory limit is required for GPU virtual devices" + ) else: - raise ValueError("Virtual devices are not supported for %s" % - dev.device_type) + raise ValueError( + "Virtual devices are not supported for %s" % dev.device_type + ) if self._virtual_device_map.get(dev) == virtual_devices: return if self._context_handle is not None: raise RuntimeError( - "Virtual devices cannot be modified after being initialized") + "Virtual devices cannot be modified after being initialized" + ) self._virtual_device_map[dev] = virtual_devices @@ -1913,10 +2011,15 @@ def set_logical_cpu_devices(self, num_cpus, prefix=""): server_def = self._server_def or self._collective_ops_server_def local_prefix = ["/device"] if server_def is not None: - local_prefix.append("/job:%s/replica:0/task:%d" % (server_def.job_name, - server_def.task_index)) - logical_local_devices = [d for d in self.list_logical_devices("CPU") if - d.name.startswith(tuple(local_prefix))] + local_prefix.append( + "/job:%s/replica:0/task:%d" + % (server_def.job_name, server_def.task_index) + ) + logical_local_devices = [ + d + for d in self.list_logical_devices("CPU") + if d.name.startswith(tuple(local_prefix)) + ] self.ensure_initialized() # Error out if there are already multiple logical CPU in the context. 
if len(logical_local_devices) > 1: @@ -1966,7 +2069,8 @@ def get_compiler_ir( ) @deprecated( - None, "XLA:CPU and XLA:GPU devices are deprecated", warn_once=True) + None, "XLA:CPU and XLA:GPU devices are deprecated", warn_once=True + ) def enable_xla_devices(self): """Enables XLA:CPU and XLA:GPU devices registration.""" pywrap_tfe.TF_EnableXlaDevices() @@ -1992,8 +2096,10 @@ def enable_mlir_graph_optimization(self, enabled): @property def optimizer_jit(self): level = self.config.graph_options.optimizer_options.global_jit_level - return (level == config_pb2.OptimizerOptions.ON_1 or - level == config_pb2.OptimizerOptions.ON_2) + return ( + level == config_pb2.OptimizerOptions.ON_1 + or level == config_pb2.OptimizerOptions.ON_2 + ) @optimizer_jit.setter def optimizer_jit(self, enabled): @@ -2013,7 +2119,7 @@ def get_optimizer_experimental_options(self): def rewriter_toggle(option): attr = getattr(rewrite_options, option) if attr != 0: - options[option] = (attr == rewriter_config_pb2.RewriterConfig.ON) + options[option] = attr == rewriter_config_pb2.RewriterConfig.ON def rewriter_bool(option): options[option] = getattr(rewrite_options, option) @@ -2063,7 +2169,8 @@ def intra_op_parallelism_threads(self, num_threads): if self._context_handle is not None: raise RuntimeError( - "Intra op parallelism cannot be modified after initialization.") + "Intra op parallelism cannot be modified after initialization." + ) self._intra_op_parallelism_threads = num_threads @@ -2078,7 +2185,8 @@ def inter_op_parallelism_threads(self, num_threads): if self._context_handle is not None: raise RuntimeError( - "Inter op parallelism cannot be modified after initialization.") + "Inter op parallelism cannot be modified after initialization." 
+ ) self._inter_op_parallelism_threads = num_threads @@ -2143,7 +2251,8 @@ def device_policy(self, policy): # Only set the policy if the context has already been initialized if self._context_handle is not None: pywrap_tfe.TFE_ContextSetThreadLocalDevicePlacementPolicy( - self._handle, self._device_policy) + self._handle, self._device_policy + ) @property def use_tfrt(self): @@ -2171,7 +2280,8 @@ def operation_timeout_in_ms(self, timeout_in_ms): if self._context_handle is not None: raise RuntimeError( - "Operation timeout cannot be modified after initialization.") + "Operation timeout cannot be modified after initialization." + ) self._operation_timeout_in_ms = timeout_in_ms @@ -2263,23 +2373,28 @@ def __enter__(self): cache_key = (old_device_name, new_device_name) try: new_device_name, new_device_spec = _device_parsing_cache[cache_key] - except TypeError: + except TypeError as exc: # Error while trying to compute the cache key. - raise ValueError("Expecting a string device name. Got %s(%s)" % - (type(new_device_name), new_device_name)) - except KeyError: + raise ValueError( + "Expecting a string device name. Got %s(%s)" + % (type(new_device_name), new_device_name) + ) from exc + except KeyError as exc: # Handle a cache miss. if new_device_name is not None: if not isinstance(new_device_name, str): - raise ValueError("Expecting a string device name. Got %s(%s)" % - (type(new_device_name), new_device_name)) + raise ValueError( + "Expecting a string device name. 
Got %s(%s)" + % (type(new_device_name), new_device_name) + ) from exc device_spec = pydev.DeviceSpec.from_string(new_device_name) if old_device_name: new_device_spec = copy.copy(old_device_spec) else: ctx.ensure_initialized() new_device_spec = pydev.DeviceSpec.from_string( - ctx._context_devices[0]) # pylint: disable=protected-access + ctx._context_devices[0] + ) # pylint: disable=protected-access new_device_spec = new_device_spec.make_merged_spec(device_spec) else: new_device_spec = pydev.DeviceSpec.from_string("") @@ -3012,8 +3127,9 @@ def register_custom_device(device_capsule, device_name, device_info_capsule): argument to TFE_RegisterCustomDevice). This method takes ownership of the memory and clears the capsule destructor. """ - context().register_custom_device(device_capsule, device_name, - device_info_capsule) + context().register_custom_device( + device_capsule, device_name, device_info_capsule + ) # Not every user creates a Context via context.context() diff --git a/tensorflow/python/training/server_lib_test.py b/tensorflow/python/training/server_lib_test.py index 11175ec078fef7..2478ab64e6be80 100644 --- a/tensorflow/python/training/server_lib_test.py +++ b/tensorflow/python/training/server_lib_test.py @@ -87,8 +87,10 @@ def testResetFails(self): # Verifies that resetting target with no server times out. with self.assertRaises(errors_impl.DeadlineExceededError): session.Session.reset( - "grpc://localhost:0", ["test0"], - config=config_pb2.ConfigProto(operation_timeout_in_ms=5)) + "grpc://localhost:0", + ["test0"], + config=config_pb2.ConfigProto(operation_timeout_in_ms=5), + ) # Verifies no containers are reset with non-existent container. server = self._cached_server @@ -114,8 +116,11 @@ def _useRPCConfig(self): Returns: A `tf.compat.v1.ConfigProto`. 
""" - return config_pb2.ConfigProto(rpc_options=rpc_options_pb2.RPCOptions( - use_rpc_for_inprocess_master=True)) + return config_pb2.ConfigProto( + rpc_options=rpc_options_pb2.RPCOptions( + use_rpc_for_inprocess_master=True + ) + ) def testLargeConstant(self): server = self._cached_server @@ -159,8 +164,9 @@ def testCloseCancelsBlockingOperation(self): sess.run(dequeue_t) def blocking_dequeue(): - with self.assertRaisesRegex(errors_impl.CancelledError, - "Session::Close"): + with self.assertRaisesRegex( + errors_impl.CancelledError, "Session::Close" + ): sess.run(dequeue_t) blocking_thread = self.checkedThread(blocking_dequeue) @@ -180,26 +186,29 @@ def testInteractiveSession(self): def testSetConfiguration(self): config = config_pb2.ConfigProto( - gpu_options=config_pb2.GPUOptions(per_process_gpu_memory_fraction=0.1)) + gpu_options=config_pb2.GPUOptions(per_process_gpu_memory_fraction=0.1) + ) # Configure a server using the default local server options. server = server_lib.Server.create_local_server(config=config, start=False) - self.assertEqual(0.1, server.server_def.default_session_config.gpu_options. - per_process_gpu_memory_fraction) + self.assertEqual( + 0.1, + server.server_def.default_session_config.gpu_options.per_process_gpu_memory_fraction, + ) # Configure a server using an explicit ServerDefd with an # overridden config. - cluster_def = server_lib.ClusterSpec({ - "localhost": ["localhost:0"] - }).as_cluster_def() + cluster_def = server_lib.ClusterSpec( + {"localhost": ["localhost:0"]} + ).as_cluster_def() server_def = tensorflow_server_pb2.ServerDef( - cluster=cluster_def, - job_name="localhost", - task_index=0, - protocol="grpc") + cluster=cluster_def, job_name="localhost", task_index=0, protocol="grpc" + ) server = server_lib.Server(server_def, config=config, start=False) - self.assertEqual(0.1, server.server_def.default_session_config.gpu_options. 
- per_process_gpu_memory_fraction) + self.assertEqual( + 0.1, + server.server_def.default_session_config.gpu_options.per_process_gpu_memory_fraction, + ) def testRestartedMaster(self): master_old = server_lib.Server.create_local_server() @@ -210,10 +219,10 @@ def get_cluster_def(master, worker): cluster_def = cluster_pb2.ClusterDef() job = cluster_def.job.add() job.name = "master" - job.tasks[0] = master.target[len("grpc://"):] + job.tasks[0] = master.target[len("grpc://") :] job = cluster_def.job.add() job.name = "worker" - job.tasks[0] = worker.target[len("grpc://"):] + job.tasks[0] = worker.target[len("grpc://") :] return cluster_def def check_session_devices(sess): @@ -230,7 +239,8 @@ def check_session_devices(sess): b = a + a config = config_pb2.ConfigProto( - cluster_def=get_cluster_def(master_old, worker)) + cluster_def=get_cluster_def(master_old, worker) + ) sess_old = session.Session(master_old.target, config=config) check_session_devices(sess_old) @@ -241,7 +251,8 @@ def check_session_devices(sess): # the old master incarnation to be garbage collected. 
config = config_pb2.ConfigProto( - cluster_def=get_cluster_def(master_new, worker)) + cluster_def=get_cluster_def(master_new, worker) + ) sess_new = session.Session(master_new.target, config=config) check_session_devices(sess_new) @@ -251,8 +262,9 @@ def check_session_devices(sess): # Running on worker with the old session should raise an exception since # the WorkerSession of the old session has been garbage collected - with self.assertRaisesRegex(errors_impl.AbortedError, - "Session handle is not found"): + with self.assertRaisesRegex( + errors_impl.AbortedError, "Session handle is not found" + ): sess_old.run(b) sess_old.close() @@ -261,9 +273,8 @@ def check_session_devices(sess): def testInvalidHostname(self): with self.assertRaisesRegex(errors_impl.InvalidArgumentError, "port"): _ = server_lib.Server( - { - "local": ["localhost"] - }, job_name="local", task_index=0) + {"local": ["localhost"]}, job_name="local", task_index=0 + ) def testTimeoutRaisesException(self): server = self._cached_server @@ -274,29 +285,32 @@ def testTimeoutRaisesException(self): with session.Session(server.target) as sess: with self.assertRaises(errors_impl.DeadlineExceededError): sess.run( - blocking_t, options=config_pb2.RunOptions(timeout_in_ms=1000)) + blocking_t, options=config_pb2.RunOptions(timeout_in_ms=1000) + ) with session.Session(server.target, config=self._useRPCConfig()) as sess: with self.assertRaises(errors_impl.DeadlineExceededError): sess.run( - blocking_t, options=config_pb2.RunOptions(timeout_in_ms=1000)) + blocking_t, options=config_pb2.RunOptions(timeout_in_ms=1000) + ) def testTwoServersSamePort(self): # Starting a server with the same target as the cached server should fail. 
server = self._cached_server with self.assertRaises(errors_impl.UnknownError): - _ = server_lib.Server( - {"local_2": [server.target[len("grpc://"):]]}) + _ = server_lib.Server({"local_2": [server.target[len("grpc://") :]]}) def testExtendAfterQueueRunners(self): server = self._cached_server with session.Session(server.target) as sess: - input_queue = input_ops.input_producer(constant_op.constant( - [0.], dtype=dtypes.float32)) + input_queue = input_ops.input_producer( + constant_op.constant([0.0], dtype=dtypes.float32) + ) self.assertIsNotNone(input_queue) var = variable_v1.VariableV1( - 1., dtype=dtypes.float32, trainable=False, name="var") + 1.0, dtype=dtypes.float32, trainable=False, name="var" + ) sess.run(variables.global_variables_initializer()) queue_runner_impl.start_queue_runners(sess) @@ -319,7 +333,10 @@ def testIsolateSessionState(self): # Initially all variables are initialized. for sess in [ - sharing_sess_0, sharing_sess_1, isolate_sess_0, isolate_sess_1 + sharing_sess_0, + sharing_sess_1, + isolate_sess_0, + isolate_sess_1, ]: with self.assertRaises(errors_impl.FailedPreconditionError): sess.run(v) @@ -391,37 +408,45 @@ def testShapeChangingIsolateState(self): class ServerDefTest(test.TestCase): def testLocalServer(self): - cluster_def = server_lib.ClusterSpec({ - "local": ["localhost:2222"] - }).as_cluster_def() + cluster_def = server_lib.ClusterSpec( + {"local": ["localhost:2222"]} + ).as_cluster_def() server_def = tensorflow_server_pb2.ServerDef( - cluster=cluster_def, job_name="local", task_index=0, protocol="grpc") + cluster=cluster_def, job_name="local", task_index=0, protocol="grpc" + ) - self.assertProtoEquals(""" + self.assertProtoEquals( + """ cluster { job { name: 'local' tasks { key: 0 value: 'localhost:2222' } } } job_name: 'local' task_index: 0 protocol: 'grpc' - """, server_def) + """, + server_def, + ) # Verifies round trip from Proto->Spec->Proto is correct. 
cluster_spec = server_lib.ClusterSpec(cluster_def) self.assertProtoEquals(cluster_def, cluster_spec.as_cluster_def()) def testTwoProcesses(self): - cluster_def = server_lib.ClusterSpec({ - "local": ["localhost:2222", "localhost:2223"] - }).as_cluster_def() + cluster_def = server_lib.ClusterSpec( + {"local": ["localhost:2222", "localhost:2223"]} + ).as_cluster_def() server_def = tensorflow_server_pb2.ServerDef( - cluster=cluster_def, job_name="local", task_index=1, protocol="grpc") + cluster=cluster_def, job_name="local", task_index=1, protocol="grpc" + ) - self.assertProtoEquals(""" + self.assertProtoEquals( + """ cluster { job { name: 'local' tasks { key: 0 value: 'localhost:2222' } tasks { key: 1 value: 'localhost:2223' } } } job_name: 'local' task_index: 1 protocol: 'grpc' - """, server_def) + """, + server_def, + ) # Verifies round trip from Proto->Spec->Proto is correct. cluster_spec = server_lib.ClusterSpec(cluster_def) @@ -430,12 +455,14 @@ def testTwoProcesses(self): def testTwoJobs(self): cluster_def = server_lib.ClusterSpec({ "ps": ["ps0:2222", "ps1:2222"], - "worker": ["worker0:2222", "worker1:2222", "worker2:2222"] + "worker": ["worker0:2222", "worker1:2222", "worker2:2222"], }).as_cluster_def() server_def = tensorflow_server_pb2.ServerDef( - cluster=cluster_def, job_name="worker", task_index=2, protocol="grpc") + cluster=cluster_def, job_name="worker", task_index=2, protocol="grpc" + ) - self.assertProtoEquals(""" + self.assertProtoEquals( + """ cluster { job { name: 'ps' tasks { key: 0 value: 'ps0:2222' } tasks { key: 1 value: 'ps1:2222' } } @@ -444,7 +471,9 @@ def testTwoJobs(self): tasks { key: 2 value: 'worker2:2222' } } } job_name: 'worker' task_index: 2 protocol: 'grpc' - """, server_def) + """, + server_def, + ) # Verifies round trip from Proto->Spec->Proto is correct. 
cluster_spec = server_lib.ClusterSpec(cluster_def) @@ -453,15 +482,14 @@ def testTwoJobs(self): def testDenseAndSparseJobs(self): cluster_def = server_lib.ClusterSpec({ "ps": ["ps0:2222", "ps1:2222"], - "worker": { - 0: "worker0:2222", - 2: "worker2:2222" - } + "worker": {0: "worker0:2222", 2: "worker2:2222"}, }).as_cluster_def() server_def = tensorflow_server_pb2.ServerDef( - cluster=cluster_def, job_name="worker", task_index=2, protocol="grpc") + cluster=cluster_def, job_name="worker", task_index=2, protocol="grpc" + ) - self.assertProtoEquals(""" + self.assertProtoEquals( + """ cluster { job { name: 'ps' tasks { key: 0 value: 'ps0:2222' } tasks { key: 1 value: 'ps1:2222' } } @@ -469,7 +497,9 @@ def testDenseAndSparseJobs(self): tasks { key: 2 value: 'worker2:2222' } } } job_name: 'worker' task_index: 2 protocol: 'grpc' - """, server_def) + """, + server_def, + ) # Verifies round trip from Proto->Spec->Proto is correct. cluster_spec = server_lib.ClusterSpec(cluster_def) @@ -479,20 +509,20 @@ def testDenseAndSparseJobs(self): class ClusterSpecTest(test.TestCase): def testStringConversion(self): - cluster_spec = server_lib.ClusterSpec({ - "ps": ["ps0:1111"], - "worker": ["worker0:3333", "worker1:4444"] - }) + cluster_spec = server_lib.ClusterSpec( + {"ps": ["ps0:1111"], "worker": ["worker0:3333", "worker1:4444"]} + ) expected_str = ( "ClusterSpec({'ps': ['ps0:1111'], 'worker': ['worker0:3333', " - "'worker1:4444']})") + "'worker1:4444']})" + ) self.assertEqual(expected_str, str(cluster_spec)) def testProtoDictDefEquivalences(self): cluster_spec = server_lib.ClusterSpec({ "ps": ["ps0:2222", "ps1:2222"], - "worker": ["worker0:2222", "worker1:2222", "worker2:2222"] + "worker": ["worker0:2222", "worker1:2222", "worker2:2222"], }) expected_proto = """ @@ -505,22 +535,21 @@ def testProtoDictDefEquivalences(self): self.assertProtoEquals(expected_proto, cluster_spec.as_cluster_def()) self.assertProtoEquals( - expected_proto, - 
server_lib.ClusterSpec(cluster_spec).as_cluster_def()) + expected_proto, server_lib.ClusterSpec(cluster_spec).as_cluster_def() + ) self.assertProtoEquals( expected_proto, - server_lib.ClusterSpec(cluster_spec.as_cluster_def()).as_cluster_def()) + server_lib.ClusterSpec(cluster_spec.as_cluster_def()).as_cluster_def(), + ) self.assertProtoEquals( expected_proto, - server_lib.ClusterSpec(cluster_spec.as_dict()).as_cluster_def()) + server_lib.ClusterSpec(cluster_spec.as_dict()).as_cluster_def(), + ) def testProtoDictDefEquivalencesWithStringTaskIndex(self): - cluster_spec = server_lib.ClusterSpec({ - "ps": ["ps0:2222", "ps1:2222"], - "worker": { - "1": "worker1:2222" - } - }) + cluster_spec = server_lib.ClusterSpec( + {"ps": ["ps0:2222", "ps1:2222"], "worker": {"1": "worker1:2222"}} + ) expected_proto = """ job { name: 'ps' tasks { key: 0 value: 'ps0:2222' } @@ -530,20 +559,21 @@ def testProtoDictDefEquivalencesWithStringTaskIndex(self): self.assertProtoEquals(expected_proto, cluster_spec.as_cluster_def()) self.assertProtoEquals( - expected_proto, - server_lib.ClusterSpec(cluster_spec).as_cluster_def()) + expected_proto, server_lib.ClusterSpec(cluster_spec).as_cluster_def() + ) self.assertProtoEquals( expected_proto, - server_lib.ClusterSpec(cluster_spec.as_cluster_def()).as_cluster_def()) + server_lib.ClusterSpec(cluster_spec.as_cluster_def()).as_cluster_def(), + ) self.assertProtoEquals( expected_proto, - server_lib.ClusterSpec(cluster_spec.as_dict()).as_cluster_def()) + server_lib.ClusterSpec(cluster_spec.as_dict()).as_cluster_def(), + ) def testProtoDictDefEquivalencesWithZeroWorker(self): - cluster_spec = server_lib.ClusterSpec({ - "ps": ["ps0:2222", "ps1:2222"], - "worker": [] - }) + cluster_spec = server_lib.ClusterSpec( + {"ps": ["ps0:2222", "ps1:2222"], "worker": []} + ) expected_proto = """ job { name: 'ps' tasks { key: 0 value: 'ps0:2222' } @@ -553,22 +583,22 @@ def testProtoDictDefEquivalencesWithZeroWorker(self): self.assertProtoEquals(expected_proto, 
cluster_spec.as_cluster_def()) self.assertProtoEquals( - expected_proto, server_lib.ClusterSpec(cluster_spec).as_cluster_def()) + expected_proto, server_lib.ClusterSpec(cluster_spec).as_cluster_def() + ) self.assertProtoEquals( expected_proto, - server_lib.ClusterSpec(cluster_spec.as_cluster_def()).as_cluster_def()) + server_lib.ClusterSpec(cluster_spec.as_cluster_def()).as_cluster_def(), + ) self.assertProtoEquals( expected_proto, - server_lib.ClusterSpec(cluster_spec.as_dict()).as_cluster_def()) + server_lib.ClusterSpec(cluster_spec.as_dict()).as_cluster_def(), + ) def testClusterSpecAccessors(self): original_dict = { "ps": ["ps0:2222", "ps1:2222"], "worker": ["worker0:2222", "worker1:2222", "worker2:2222"], - "sparse": { - 0: "sparse0:2222", - 3: "sparse3:2222" - } + "sparse": {0: "sparse0:2222", 3: "sparse3:2222"}, } cluster_spec = server_lib.ClusterSpec(original_dict) @@ -596,10 +626,14 @@ def testClusterSpecAccessors(self): # NOTE(mrry): `ClusterSpec.job_tasks()` is not recommended for use # with sparse jobs. 
self.assertEqual(["ps0:2222", "ps1:2222"], cluster_spec.job_tasks("ps")) - self.assertEqual(["worker0:2222", "worker1:2222", "worker2:2222"], - cluster_spec.job_tasks("worker")) - self.assertEqual(["sparse0:2222", None, None, "sparse3:2222"], - cluster_spec.job_tasks("sparse")) + self.assertEqual( + ["worker0:2222", "worker1:2222", "worker2:2222"], + cluster_spec.job_tasks("worker"), + ) + self.assertEqual( + ["sparse0:2222", None, None, "sparse3:2222"], + cluster_spec.job_tasks("sparse"), + ) with self.assertRaises(ValueError): cluster_spec.job_tasks("unknown") @@ -616,9 +650,9 @@ def testEq(self): server_lib.ClusterSpec({"job": ["host:2222"]}), ) self.assertEqual( - server_lib.ClusterSpec({"job": { - 0: "host:2222" - }}), server_lib.ClusterSpec({"job": ["host:2222"]})) + server_lib.ClusterSpec({"job": {0: "host:2222"}}), + server_lib.ClusterSpec({"job": ["host:2222"]}), + ) def testNe(self): self.assertNotEqual( From 4f2091dadb41a2cbabec425efb3e5b91cd2a01cf Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Thu, 20 Jun 2024 15:17:37 -0700 Subject: [PATCH 090/256] [JAX] Teach jit fast path how to handle negative static_argnums correctly. PiperOrigin-RevId: 645172085 --- third_party/xla/xla/python/jax_jit.cc | 7 +++++-- third_party/xla/xla/python/xla_client.py | 2 +- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/third_party/xla/xla/python/jax_jit.cc b/third_party/xla/xla/python/jax_jit.cc index fa6125ded04f6c..4f10ee7c7e6036 100644 --- a/third_party/xla/xla/python/jax_jit.cc +++ b/third_party/xla/xla/python/jax_jit.cc @@ -283,9 +283,12 @@ absl::Status ParseArguments( signature.dynamic_arg_treedefs.reserve(positional_args.size()); // Positional arguments. 
+ int num_positional_args = positional_args.size(); for (int i = 0; i < positional_args.size(); ++i) { - if (std::find(static_argnums.begin(), static_argnums.end(), i) == - static_argnums.end()) { + if (std::find_if(static_argnums.begin(), static_argnums.end(), + [i, num_positional_args](int t) { + return t >= 0 ? i == t : i == t + num_positional_args; + }) == static_argnums.end()) { signature.dynamic_arg_treedefs.emplace_back(pytree_registry); xla::PyTreeDef& pytree_def = signature.dynamic_arg_treedefs.back(); pytree_def.Flatten(positional_args[i], flat_dynamic_args); diff --git a/third_party/xla/xla/python/xla_client.py b/third_party/xla/xla/python/xla_client.py index ca7f749aade67c..d49fcdafb1ca5e 100644 --- a/third_party/xla/xla/python/xla_client.py +++ b/third_party/xla/xla/python/xla_client.py @@ -50,7 +50,7 @@ # Just an internal arbitrary increasing number to help with backward-compatible # changes. In JAX, reference this via jax._src.lib.xla_extension_version. -_version = 271 +_version = 272 # Version number for MLIR:Python components. mlir_api_version = 57 From 7be686d73b184d859a6cf2cb4ee599fdab24437a Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 20 Jun 2024 15:18:18 -0700 Subject: [PATCH 091/256] Implements the `layout` method for the BasicStringArray class. 
PiperOrigin-RevId: 645172260 --- third_party/xla/xla/python/pjrt_ifrt/BUILD | 3 + .../python/pjrt_ifrt/basic_string_array.cc | 62 ++++++++++--- .../xla/python/pjrt_ifrt/basic_string_array.h | 17 ++++ .../pjrt_ifrt/basic_string_array_test.cc | 88 +++++++++++++++++++ 4 files changed, 159 insertions(+), 11 deletions(-) diff --git a/third_party/xla/xla/python/pjrt_ifrt/BUILD b/third_party/xla/xla/python/pjrt_ifrt/BUILD index 8d5f6adff7c6b1..49b498d893f859 100644 --- a/third_party/xla/xla/python/pjrt_ifrt/BUILD +++ b/third_party/xla/xla/python/pjrt_ifrt/BUILD @@ -307,6 +307,7 @@ cc_library( "//xla/tsl/concurrency:ref_count", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/hash", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -326,7 +327,9 @@ xla_cc_test( deps = [ ":basic_string_array", ":tfrt_cpu_client_test_lib", + "//xla:shape_util", "//xla/pjrt:pjrt_future", + "//xla/pjrt:pjrt_layout", "//xla/python/ifrt", "//xla/python/ifrt:test_util", "//xla/tsl/concurrency:ref_count", diff --git a/third_party/xla/xla/python/pjrt_ifrt/basic_string_array.cc b/third_party/xla/xla/python/pjrt_ifrt/basic_string_array.cc index deac6780b624a5..b832981c766998 100644 --- a/third_party/xla/xla/python/pjrt_ifrt/basic_string_array.cc +++ b/third_party/xla/xla/python/pjrt_ifrt/basic_string_array.cc @@ -22,6 +22,7 @@ limitations under the License. #include #include +#include "absl/hash/hash.h" #include "absl/log/check.h" #include "absl/status/status.h" #include "absl/status/statusor.h" @@ -50,6 +51,40 @@ limitations under the License. namespace xla { namespace ifrt { +///////////////////////////////////////////////////////////////////////////// +// +// BasicStringArrayLayout +// + +std::string BasicStringArrayLayout::Serialize() const { + // We currently do not have any state that need to be serialized. Return an + // empty string. 
+ return std::string(); +} + +std::string BasicStringArrayLayout::ToString() const { + return "BasicStringArrayLayout: Dense, major-to-minor."; +} + +bool BasicStringArrayLayout::operator==(const PjRtLayout& other) const { + auto* other_basic_string_array_layout = + dynamic_cast(&other); + if (other_basic_string_array_layout == nullptr) { + return false; + } + // All BasicStringArrayLayout objects are the same - they are all dense, + // major-to-minor. So, all of them are equal. + return true; +} + +void BasicStringArrayLayout::Hash(absl::HashState state) const { +} // Nothing to add to the hash state. Just return. + +///////////////////////////////////////////////////////////////////////////// +// +// BasicStringArray +// + char BasicStringArray::ID = 0; absl::StatusOr> BasicStringArray::Create( @@ -167,9 +202,9 @@ BasicStringArray::DisassembleIntoSingleDeviceArrays( // (3) shape and sharding by disassembing the source array's sharding. // // The Futures, the on-done-with-host-buffer callbacks, shapes and shardings - // are used to make the arrays. The promises and the buffer backing stores are - // passed onto the OnReady callback that populates them when the buffers of - // the source array become ready. + // are used to make the arrays. The promises and the buffer backing stores + // are passed onto the OnReady callback that populates them when the buffers + // of the source array become ready. std::vector> buffer_promises; buffer_promises.reserve(num_shards); std::vector> buffer_futures; @@ -226,8 +261,8 @@ BasicStringArray::DisassembleIntoSingleDeviceArrays( }); // Make and return the individual single device arrays. These will become - // ready when the this (source) array becomes ready and the callback we set up - // above runs. + // ready when the this (source) array becomes ready and the callback we set + // up above runs. 
TF_ASSIGN_OR_RETURN(auto shapes_and_shadings, sharding_->Disassemble(shape_)); std::vector> arrays; @@ -319,9 +354,10 @@ absl::StatusOr> BasicStringArray::FullyReplicatedShard( } // Some user code paths (e.g.: through JAX) may not correctly set the - // `is_fully_replicated` flag when they are using ConcreteEvenSharding. If and - // when that causes a problem, we should investigate a way to actually looking - // into the sharding to determine if it is a fully replicated sharding. + // `is_fully_replicated` flag when they are using ConcreteEvenSharding. If + // and when that causes a problem, we should investigate a way to actually + // looking into the sharding to determine if it is a fully replicated + // sharding. if (!sharding_->IsFullyReplicated()) { return absl::FailedPreconditionError("This array is not fully replicated"); } @@ -352,8 +388,8 @@ absl::StatusOr> BasicStringArray::FullyReplicatedShard( } // No need to check the size of input_buffers. The consistency checks that - // were run when the source array's buffers became ready would have ensured - // that the input_buffers have at least one shard's worth of data. + // were run when the source array's buffers became ready would have + // ensured that the input_buffers have at least one shard's worth of data. 
auto& input_buffer = (*input_buffers)[0]; backing_store->CopyFrom(input_buffer); @@ -370,7 +406,11 @@ absl::StatusOr> BasicStringArray::FullyReplicatedShard( } absl::StatusOr> BasicStringArray::layout() const { - return absl::UnimplementedError("Not implemented"); + absl::MutexLock lock(&mu_); + if (is_deleted_) { + return absl::FailedPreconditionError("Array has already been deleted"); + } + return std::make_unique(); } std::string BasicStringArray::DebugString() const { diff --git a/third_party/xla/xla/python/pjrt_ifrt/basic_string_array.h b/third_party/xla/xla/python/pjrt_ifrt/basic_string_array.h index 6344e8d7c0f14a..a02af9b12dc4cc 100644 --- a/third_party/xla/xla/python/pjrt_ifrt/basic_string_array.h +++ b/third_party/xla/xla/python/pjrt_ifrt/basic_string_array.h @@ -26,6 +26,7 @@ limitations under the License. #include "absl/base/attributes.h" #include "absl/base/thread_annotations.h" #include "absl/container/inlined_vector.h" +#include "absl/hash/hash.h" #include "absl/log/check.h" #include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" @@ -42,6 +43,22 @@ limitations under the License. namespace xla { namespace ifrt { +// Describes the layout of a `BasicStringArray`. +class BasicStringArrayLayout : public PjRtLayout { + public: + BasicStringArrayLayout() = default; + BasicStringArrayLayout(const BasicStringArrayLayout& other) = delete; + + ~BasicStringArrayLayout() override = default; + + std::string Serialize() const override; + std::string ToString() const override; + bool operator==(const PjRtLayout& other) const override; + + protected: + void Hash(absl::HashState state) const override; +}; + // `BasicStringArray` implements an `ifrt::Array` by wrapping a local (aka host) // string buffer. This object is expected to live exclusively in the IFRT layer, // and thus is not specific to any particular backend. 
However, it is currently diff --git a/third_party/xla/xla/python/pjrt_ifrt/basic_string_array_test.cc b/third_party/xla/xla/python/pjrt_ifrt/basic_string_array_test.cc index 351d5c57481248..922c3875e782c5 100644 --- a/third_party/xla/xla/python/pjrt_ifrt/basic_string_array_test.cc +++ b/third_party/xla/xla/python/pjrt_ifrt/basic_string_array_test.cc @@ -32,7 +32,9 @@ limitations under the License. #include "absl/synchronization/notification.h" #include "absl/types/span.h" #include "llvm/Support/Casting.h" +#include "xla/layout.h" #include "xla/pjrt/pjrt_future.h" +#include "xla/pjrt/pjrt_layout.h" #include "xla/python/ifrt/array.h" #include "xla/python/ifrt/device.h" #include "xla/python/ifrt/dtype.h" @@ -54,6 +56,11 @@ namespace { using ::testing::HasSubstr; using ::tsl::testing::StatusIs; +// //////////////////////////////////////////////////////////////////////////// +// +// Common utility functions. +// + // Makes a simple single device sharded `BasicStringArray` from the // user-supplied buffers and on_done_with_buffer callback by means of the // factory method: `BasicStringArray::Create`. Uses the first device from the @@ -119,6 +126,51 @@ CreateNonReadyTestArray( return std::make_pair(std::move(array), std::move(buffers_promise)); } +///////////////////////////////////////////////////////////////////////////// +// +// Tests related to BasicStringArrayLayout. +// + +TEST(BasicStringArrayLayoutTest, Serialize) { + BasicStringArrayLayout layout; + // Seerialize currently has no state to serialize, and so the returned value + // should be an empty string. 
+ EXPECT_TRUE(layout.Serialize().empty()); +} + +TEST(BasicStringArrayLayoutTest, ToString) { + BasicStringArrayLayout layout; + auto output_str = layout.ToString(); + EXPECT_THAT(output_str, HasSubstr("major-to-minor")); +} + +TEST(BasicStringArrayLayoutTest, Equality) { + BasicStringArrayLayout layout_1; + + // In the equality comparisons below, use the PjRtLayout interface for the + // second object so we can avoid the error: `ambiguity is between a regular + // call to this operator and a call with the argument order reversed`. + + // Any two BasicStringArrayLayouts are equal. + BasicStringArrayLayout layout_2; + const PjRtLayout& layout_3 = layout_2; + EXPECT_EQ(layout_1, layout_3); + + // In the next test, EXPECT_NE is not used because the version of EXCEPT_NE + // available in the open sourced libraries requires the operator `!=` to be + // overloaded. + + // Non-BasicStringArrayLayouts are not equal to BasicStringArrayLayouts. + xla::PjRtXlaLayout layout_6((xla::Layout())); + const PjRtLayout& layout_7 = layout_6; + EXPECT_FALSE(layout_7 == layout_1); +} + +///////////////////////////////////////////////////////////////////////////// +// +// Tests related to BasicStringArray. +// + TEST(BasicStringArrayTest, CreateSuccess) { TF_ASSERT_OK_AND_ASSIGN(auto client, test_util::GetClient()); BasicStringArray::Buffers buffers; @@ -883,6 +935,42 @@ TEST(FullyReplicatedShardTest, FailsAfterDeletion) { StatusIs(absl::StatusCode::kFailedPrecondition)); } +TEST(LayoutTest, Success) { + TF_ASSERT_OK_AND_ASSIGN(auto client, test_util::GetClient()); + + constexpr char kContents[] = "abc"; + auto [buffers, on_done_with_buffer] = + MakeBuffersAndOnDoneWithBuffer({kContents}); + TF_ASSERT_OK_AND_ASSIGN( + auto array, + CreateTestArray(client.get(), + Future(std::move(buffers)), + std::move(on_done_with_buffer))); + + // The number of dimensions for the testArray should be 1. 
Typical usage of + // BasicStringArrayLayout does not require an accessor to retrieve the number + // of dimensions. Instead of adding a test only method, we could just check + // the serialized layout. + TF_ASSERT_OK_AND_ASSIGN(auto layout, array->layout()); + EXPECT_TRUE(layout->Serialize().empty()); +} + +TEST(LayoutTest, FailsAfterDeletion) { + TF_ASSERT_OK_AND_ASSIGN(auto client, test_util::GetClient()); + + constexpr char kContents[] = "abc"; + auto [buffers, on_done_with_buffer] = + MakeBuffersAndOnDoneWithBuffer({kContents}); + TF_ASSERT_OK_AND_ASSIGN( + auto array, + CreateTestArray(client.get(), Future(buffers), + std::move(on_done_with_buffer))); + + array->Delete(); + + EXPECT_THAT(array->layout(), StatusIs(absl::StatusCode::kFailedPrecondition)); +} + } // namespace } // namespace ifrt } // namespace xla From cf23a3e8acd31c1dfb708e57eddbd0ef2296a3df Mon Sep 17 00:00:00 2001 From: zoranjovanovic-ns <126815388+zoranjovanovic-ns@users.noreply.github.com> Date: Thu, 20 Jun 2024 15:58:46 -0700 Subject: [PATCH 092/256] =?UTF-8?q?PR=20#13971:=20[ROCm]=20Fixed=20build?= =?UTF-8?q?=20break=20caused=20by=20https://github.com/openxla/xla/com?= =?UTF-8?q?=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Imported from GitHub PR https://github.com/openxla/xla/pull/13971 …mit/c9fba04117bf16c61a61780ab39b380ae99e73ae Copybara import of the project: -- f1aac47b10caaa9c675d7745ff8f111b89ada582 by Zoran Jovanovic : [ROCm] Fixed build break caused by https://github.com/openxla/xla/commit/c9fba04117bf16c61a61780ab39b380ae99e73ae Merging this change closes #13971 PiperOrigin-RevId: 645182930 --- third_party/xla/xla/service/gpu/triton_test_utils.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/third_party/xla/xla/service/gpu/triton_test_utils.cc b/third_party/xla/xla/service/gpu/triton_test_utils.cc index e9fbd2547412d2..2117a476a95c18 100644 --- a/third_party/xla/xla/service/gpu/triton_test_utils.cc 
+++ b/third_party/xla/xla/service/gpu/triton_test_utils.cc @@ -172,7 +172,8 @@ absl::Status ConvertEntryToTritonFusion(HloModule* module) { module->entry_computation())); gpu::GpuBackendConfig gpu_config; - gpu_config.mutable_fusion_backend_config()->set_kind(kTritonFusionKind); + gpu_config.mutable_fusion_backend_config()->set_kind( + std::string(kTritonFusionKind)); TF_RETURN_IF_ERROR(fusion->set_backend_config(gpu_config)); auto new_entry = From 5bcf657f995eeecb416ab496dba9da787467f692 Mon Sep 17 00:00:00 2001 From: Kyle Lucke Date: Thu, 20 Jun 2024 16:10:43 -0700 Subject: [PATCH 093/256] Stop using xla/statusor.h now that it just contains an alias for absl::Status. In some situations, this meant also changing unrelated files to directly include tsl/platform/statusor.h to get the definitions for TF_ASSIGN_OR_RETURN, etc., where they were getting transitively included for free. PiperOrigin-RevId: 645186390 --- tensorflow/compiler/jit/BUILD | 10 +++++----- tensorflow/compiler/jit/compilability_check_util.cc | 2 +- tensorflow/compiler/jit/compilability_check_util.h | 2 +- tensorflow/compiler/jit/device_util.h | 2 +- tensorflow/compiler/jit/mark_for_compilation_pass.cc | 2 +- tensorflow/compiler/jit/shape_inference.h | 2 +- tensorflow/compiler/jit/xla_cluster_util.h | 2 +- third_party/xla/xla/hlo/ir/BUILD | 1 - .../xla/xla/hlo/ir/dfs_hlo_visitor_with_default.h | 2 +- third_party/xla/xla/hlo/ir/hlo_computation.h | 2 +- third_party/xla/xla/hlo/ir/hlo_domain_metadata.h | 2 +- .../xla/xla/hlo/ir/hlo_input_output_alias_config.cc | 2 +- third_party/xla/xla/hlo/ir/hlo_instruction.h | 2 +- third_party/xla/xla/hlo/ir/hlo_module_metadata.h | 1 + third_party/xla/xla/hlo/ir/hlo_opcode.h | 2 +- third_party/xla/xla/pjrt/distributed/BUILD | 6 +++--- third_party/xla/xla/pjrt/distributed/distributed.cc | 2 +- third_party/xla/xla/pjrt/distributed/distributed.h | 2 +- third_party/xla/xla/pjrt/distributed/service.h | 2 +- third_party/xla/xla/pjrt/distributed/topology_util.cc | 2 +- 
third_party/xla/xla/pjrt/distributed/topology_util.h | 2 +- 21 files changed, 26 insertions(+), 26 deletions(-) diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index a3934ae99155ca..efdf25ae46f532 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -961,8 +961,8 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "@com_google_absl//absl/log", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", - "@local_xla//xla:statusor", ], ) @@ -1110,10 +1110,10 @@ cc_library( "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/memory", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", "@local_xla//xla:status_macros", - "@local_xla//xla:statusor", "@local_xla//xla:union_find", "@local_xla//xla:util", "@local_xla//xla:xla_data_proto_cc", @@ -1139,10 +1139,10 @@ cc_library( "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", "@local_xla//xla:status_macros", - "@local_xla//xla:statusor", "@local_xla//xla/service/graphcycles", ], ) @@ -1159,10 +1159,10 @@ cc_library( "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/numeric:bits", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@local_xla//xla:status_macros", - "@local_xla//xla:statusor", ], ) @@ -1376,9 +1376,9 @@ cc_library( "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status:statusor", 
"@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", - "@local_xla//xla:statusor", "@local_xla//xla:union_find", "@local_xla//xla:util", "@local_xla//xla/service/graphcycles", diff --git a/tensorflow/compiler/jit/compilability_check_util.cc b/tensorflow/compiler/jit/compilability_check_util.cc index 89cfc9c933afa4..ef46760f5065b4 100644 --- a/tensorflow/compiler/jit/compilability_check_util.cc +++ b/tensorflow/compiler/jit/compilability_check_util.cc @@ -28,6 +28,7 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "absl/strings/string_view.h" @@ -43,7 +44,6 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "xla/service/graphcycles/graphcycles.h" -#include "xla/statusor.h" #include "xla/union_find.h" #include "xla/util.h" #include "tensorflow/core/common_runtime/function.h" diff --git a/tensorflow/compiler/jit/compilability_check_util.h b/tensorflow/compiler/jit/compilability_check_util.h index 1a05b69d5756b4..7c38cc92c541b7 100644 --- a/tensorflow/compiler/jit/compilability_check_util.h +++ b/tensorflow/compiler/jit/compilability_check_util.h @@ -19,6 +19,7 @@ limitations under the License. #include #include "absl/algorithm/container.h" +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "tensorflow/compiler/jit/defs.h" @@ -29,7 +30,6 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/resource_operation_table.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "xla/service/graphcycles/graphcycles.h" -#include "xla/statusor.h" #include "xla/union_find.h" #include "xla/util.h" #include "tensorflow/core/common_runtime/function.h" diff --git a/tensorflow/compiler/jit/device_util.h b/tensorflow/compiler/jit/device_util.h index ec4d9484ae8854..de06732f39008d 100644 --- a/tensorflow/compiler/jit/device_util.h +++ b/tensorflow/compiler/jit/device_util.h @@ -21,11 +21,11 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/numeric/bits.h" +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "xla/status_macros.h" -#include "xla/statusor.h" #include "tensorflow/core/framework/types.h" namespace tensorflow { diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index 0883cff150b995..aa59b847ac3cb4 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -33,6 +33,7 @@ limitations under the License. #include "absl/base/call_once.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" #include "absl/strings/str_join.h" #include "tensorflow/compiler/jit/compilability_check_util.h" #include "tensorflow/compiler/jit/deadness_analysis.h" @@ -45,7 +46,6 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/resource_operation_table.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "xla/service/graphcycles/graphcycles.h" -#include "xla/statusor.h" #include "xla/union_find.h" #include "xla/util.h" #include "tensorflow/core/common_runtime/function.h" diff --git a/tensorflow/compiler/jit/shape_inference.h b/tensorflow/compiler/jit/shape_inference.h index e12452fb316cf6..3bd814823013f0 100644 --- a/tensorflow/compiler/jit/shape_inference.h +++ b/tensorflow/compiler/jit/shape_inference.h @@ -19,7 +19,7 @@ limitations under the License. #include #include -#include "xla/statusor.h" +#include "absl/status/statusor.h" #include "tensorflow/core/common_runtime/optimization_registry.h" #include "tensorflow/core/common_runtime/shape_refiner.h" #include "tensorflow/core/lib/core/status.h" diff --git a/tensorflow/compiler/jit/xla_cluster_util.h b/tensorflow/compiler/jit/xla_cluster_util.h index 241786fefbe108..8f5c8580de3cdd 100644 --- a/tensorflow/compiler/jit/xla_cluster_util.h +++ b/tensorflow/compiler/jit/xla_cluster_util.h @@ -22,10 +22,10 @@ limitations under the License. 
#include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" #include "absl/types/optional.h" #include "tensorflow/compiler/jit/xla_activity.pb.h" #include "xla/service/graphcycles/graphcycles.h" -#include "xla/statusor.h" #include "tensorflow/core/common_runtime/optimization_registry.h" #include "tensorflow/core/graph/algorithm.h" diff --git a/third_party/xla/xla/hlo/ir/BUILD b/third_party/xla/xla/hlo/ir/BUILD index 1102ef1a67bae0..0494eb8f9276fb 100644 --- a/third_party/xla/xla/hlo/ir/BUILD +++ b/third_party/xla/xla/hlo/ir/BUILD @@ -71,7 +71,6 @@ cc_library( "//xla:shape_tree", "//xla:shape_util", "//xla:status_macros", - "//xla:statusor", "//xla:types", "//xla:util", "//xla:window_util", diff --git a/third_party/xla/xla/hlo/ir/dfs_hlo_visitor_with_default.h b/third_party/xla/xla/hlo/ir/dfs_hlo_visitor_with_default.h index 40317cb77d85ec..37e52646ee83ff 100644 --- a/third_party/xla/xla/hlo/ir/dfs_hlo_visitor_with_default.h +++ b/third_party/xla/xla/hlo/ir/dfs_hlo_visitor_with_default.h @@ -21,12 +21,12 @@ limitations under the License. #include "absl/base/optimization.h" #include "absl/log/log.h" +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "xla/hlo/ir/dfs_hlo_visitor.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" -#include "xla/statusor.h" #include "tsl/platform/status.h" namespace xla { diff --git a/third_party/xla/xla/hlo/ir/hlo_computation.h b/third_party/xla/xla/hlo/ir/hlo_computation.h index 3e1161fdeeefdb..89340201fbdaa3 100644 --- a/third_party/xla/xla/hlo/ir/hlo_computation.h +++ b/third_party/xla/xla/hlo/ir/hlo_computation.h @@ -29,6 +29,7 @@ limitations under the License. 
#include "absl/functional/function_ref.h" #include "absl/log/check.h" #include "absl/log/log.h" +#include "absl/status/statusor.h" #include "absl/strings/cord.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" @@ -42,7 +43,6 @@ limitations under the License. #include "xla/service/hlo.pb.h" #include "xla/service/name_uniquer.h" #include "xla/shape_tree.h" -#include "xla/statusor.h" #include "xla/util.h" #include "xla/xla_data.pb.h" diff --git a/third_party/xla/xla/hlo/ir/hlo_domain_metadata.h b/third_party/xla/xla/hlo/ir/hlo_domain_metadata.h index a90f6d8a5ede28..249e406b18338e 100644 --- a/third_party/xla/xla/hlo/ir/hlo_domain_metadata.h +++ b/third_party/xla/xla/hlo/ir/hlo_domain_metadata.h @@ -21,8 +21,8 @@ limitations under the License. #include #include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" -#include "xla/statusor.h" #include "xla/types.h" #include "tsl/platform/status.h" diff --git a/third_party/xla/xla/hlo/ir/hlo_input_output_alias_config.cc b/third_party/xla/xla/hlo/ir/hlo_input_output_alias_config.cc index ada3edf2bbe618..c54af5675f4a7c 100644 --- a/third_party/xla/xla/hlo/ir/hlo_input_output_alias_config.cc +++ b/third_party/xla/xla/hlo/ir/hlo_input_output_alias_config.cc @@ -23,6 +23,7 @@ limitations under the License. #include #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" #include "xla/hlo/ir/hlo_computation.h" @@ -32,7 +33,6 @@ limitations under the License. 
#include "xla/shape.h" #include "xla/shape_util.h" #include "xla/status_macros.h" -#include "xla/statusor.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" // IWYU pragma: keep diff --git a/third_party/xla/xla/hlo/ir/hlo_instruction.h b/third_party/xla/xla/hlo/ir/hlo_instruction.h index 4de06016084736..7a55565cb12502 100644 --- a/third_party/xla/xla/hlo/ir/hlo_instruction.h +++ b/third_party/xla/xla/hlo/ir/hlo_instruction.h @@ -41,6 +41,7 @@ limitations under the License. #include "absl/container/inlined_vector.h" #include "absl/functional/function_ref.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" @@ -61,7 +62,6 @@ limitations under the License. #include "xla/service/name_uniquer.h" #include "xla/shape.h" #include "xla/shape_util.h" -#include "xla/statusor.h" #include "xla/xla_data.pb.h" #include "tsl/lib/gtl/iterator_range.h" #include "tsl/platform/errors.h" diff --git a/third_party/xla/xla/hlo/ir/hlo_module_metadata.h b/third_party/xla/xla/hlo/ir/hlo_module_metadata.h index 14231f29897e59..7f115f042d139f 100644 --- a/third_party/xla/xla/hlo/ir/hlo_module_metadata.h +++ b/third_party/xla/xla/hlo/ir/hlo_module_metadata.h @@ -26,6 +26,7 @@ limitations under the License. #include "xla/util.h" #include "tsl/platform/env.h" #include "tsl/platform/protobuf.h" +#include "tsl/platform/statusor.h" namespace xla { diff --git a/third_party/xla/xla/hlo/ir/hlo_opcode.h b/third_party/xla/xla/hlo/ir/hlo_opcode.h index 79d90e3120d608..8c1cc7c00e5dcb 100644 --- a/third_party/xla/xla/hlo/ir/hlo_opcode.h +++ b/third_party/xla/xla/hlo/ir/hlo_opcode.h @@ -20,8 +20,8 @@ limitations under the License. 
#include #include +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" -#include "xla/statusor.h" namespace xla { diff --git a/third_party/xla/xla/pjrt/distributed/BUILD b/third_party/xla/xla/pjrt/distributed/BUILD index f1d8e6ea111660..be69c677d7011c 100644 --- a/third_party/xla/xla/pjrt/distributed/BUILD +++ b/third_party/xla/xla/pjrt/distributed/BUILD @@ -23,7 +23,6 @@ cc_library( deps = [ ":topology_util", ":util", - "//xla:statusor", "//xla:types", "//xla:util", "//xla/tsl/distributed_runtime/coordination:coordination_service", @@ -32,6 +31,7 @@ cc_library( "//xla/tsl/distributed_runtime/rpc/coordination:grpc_coordination_service_impl", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/memory", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/time", @@ -102,7 +102,7 @@ cc_library( deps = [ ":client", ":service", - "//xla:statusor", + "@com_google_absl//absl/status:statusor", "@local_tsl//tsl/platform:grpc_credentials", ] + tsl_grpc_cc_dependencies(), ) @@ -114,13 +114,13 @@ cc_library( deps = [ ":key_value_store_interface", ":protocol_proto_cc", - "//xla:statusor", "//xla:util", "//xla/pjrt:pjrt_client", "//xla/pjrt:utils", "//xla/pjrt/gpu:gpu_topology_proto_cc", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/time", diff --git a/third_party/xla/xla/pjrt/distributed/distributed.cc b/third_party/xla/xla/pjrt/distributed/distributed.cc index 4eb5c8ac65ad6e..69f9f2e249b402 100644 --- a/third_party/xla/xla/pjrt/distributed/distributed.cc +++ b/third_party/xla/xla/pjrt/distributed/distributed.cc @@ -18,11 +18,11 @@ limitations under the License. 
#include #include +#include "absl/status/statusor.h" #include "grpcpp/channel.h" #include "grpcpp/create_channel.h" #include "xla/pjrt/distributed/client.h" #include "xla/pjrt/distributed/service.h" -#include "xla/statusor.h" #include "tsl/platform/grpc_credentials.h" namespace xla { diff --git a/third_party/xla/xla/pjrt/distributed/distributed.h b/third_party/xla/xla/pjrt/distributed/distributed.h index 393bd234530bbe..8145ddaa5c699e 100644 --- a/third_party/xla/xla/pjrt/distributed/distributed.h +++ b/third_party/xla/xla/pjrt/distributed/distributed.h @@ -19,9 +19,9 @@ limitations under the License. #include #include +#include "absl/status/statusor.h" #include "xla/pjrt/distributed/client.h" #include "xla/pjrt/distributed/service.h" -#include "xla/statusor.h" namespace xla { diff --git a/third_party/xla/xla/pjrt/distributed/service.h b/third_party/xla/xla/pjrt/distributed/service.h index 1bc5416fb5db9e..d1e3279a7a4a3d 100644 --- a/third_party/xla/xla/pjrt/distributed/service.h +++ b/third_party/xla/xla/pjrt/distributed/service.h @@ -20,13 +20,13 @@ limitations under the License. #include #include "absl/container/flat_hash_map.h" +#include "absl/status/statusor.h" #include "absl/synchronization/mutex.h" #include "absl/synchronization/notification.h" #include "absl/time/time.h" #include "grpcpp/grpcpp.h" #include "grpcpp/security/server_credentials.h" #include "grpcpp/server_builder.h" -#include "xla/statusor.h" #include "xla/tsl/distributed_runtime/coordination/coordination_service.h" #include "xla/tsl/distributed_runtime/rpc/async_service_interface.h" #include "xla/types.h" diff --git a/third_party/xla/xla/pjrt/distributed/topology_util.cc b/third_party/xla/xla/pjrt/distributed/topology_util.cc index b423b2ccae4b9e..e3926dcb39cd5a 100644 --- a/third_party/xla/xla/pjrt/distributed/topology_util.cc +++ b/third_party/xla/xla/pjrt/distributed/topology_util.cc @@ -25,6 +25,7 @@ limitations under the License. 
#include "absl/container/flat_hash_map.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/ascii.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" @@ -36,7 +37,6 @@ limitations under the License. #include "xla/pjrt/distributed/protocol.pb.h" #include "xla/pjrt/pjrt_client.h" #include "xla/pjrt/utils.h" -#include "xla/statusor.h" #include "xla/util.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" diff --git a/third_party/xla/xla/pjrt/distributed/topology_util.h b/third_party/xla/xla/pjrt/distributed/topology_util.h index 7549a227921a49..ec902d72efd63a 100644 --- a/third_party/xla/xla/pjrt/distributed/topology_util.h +++ b/third_party/xla/xla/pjrt/distributed/topology_util.h @@ -20,12 +20,12 @@ limitations under the License. #include #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/time/time.h" #include "absl/types/span.h" #include "xla/pjrt/distributed/key_value_store_interface.h" #include "xla/pjrt/distributed/protocol.pb.h" #include "xla/pjrt/gpu/gpu_topology.pb.h" -#include "xla/statusor.h" namespace xla { From bc30cae2b535131f4defaf55c21a9a041f0ddae9 Mon Sep 17 00:00:00 2001 From: Ionel Gog Date: Thu, 20 Jun 2024 18:45:46 -0700 Subject: [PATCH 094/256] [IFRT] Move MemoryKind attribute from IfrtShardingAttrInterface to IfrtArrayType PiperOrigin-RevId: 645221837 --- .../xla/xla/python/ifrt/ir/ifrt_dialect.cc | 20 +++---- .../xla/xla/python/ifrt/ir/ifrt_dialect.h | 1 + .../xla/xla/python/ifrt/ir/ifrt_dialect.td | 52 ++++++++++--------- .../xla/xla/python/ifrt/ir/ifrt_interfaces.h | 1 - .../xla/xla/python/ifrt/ir/ifrt_interfaces.td | 6 --- .../tests/ifrt_verify_sharding_specified.mlir | 17 +++--- 6 files changed, 43 insertions(+), 54 deletions(-) diff --git a/third_party/xla/xla/python/ifrt/ir/ifrt_dialect.cc b/third_party/xla/xla/python/ifrt/ir/ifrt_dialect.cc index 265c4deea601fc..a3927e4cbf5e04 100644 --- 
a/third_party/xla/xla/python/ifrt/ir/ifrt_dialect.cc +++ b/third_party/xla/xla/python/ifrt/ir/ifrt_dialect.cc @@ -131,7 +131,7 @@ mlir::LogicalResult IfrtDialect::verifyRegionArgAttribute( mlir::LogicalResult IfrtShardingParamAttr::verify( llvm::function_ref emitError, - ShardingParam sharding_param, mlir::StringAttr memory_kind) { + ShardingParam sharding_param) { return sharding_param.verify(emitError); } @@ -158,12 +158,6 @@ int IfrtShardingParamAttr::NumDevices() const { return getSharding().NumDevices(); }; -xla::ifrt::MemoryKind IfrtShardingParamAttr::MemoryKind() const { - return getMemoryKind() == nullptr - ? xla::ifrt::MemoryKind() - : xla::ifrt::MemoryKind(getMemoryKind().str()); -}; - //===----------------------------------------------------------------------===// // IfrtUnspecifiedShardingAttr //===----------------------------------------------------------------------===// @@ -195,10 +189,6 @@ IfrtUnspecifiedShardingAttr::LocalShapeFromGlobalShape( int IfrtUnspecifiedShardingAttr::NumDevices() const { return 0; } -xla::ifrt::MemoryKind IfrtUnspecifiedShardingAttr::MemoryKind() const { - return xla::ifrt::MemoryKind(); -} - //===----------------------------------------------------------------------===// // IfrtArrayType //===----------------------------------------------------------------------===// @@ -211,10 +201,16 @@ llvm::ArrayRef IfrtArrayType::getDevices() const { mlir::LogicalResult IfrtArrayType::verify( llvm::function_ref emitError, mlir::RankedTensorType shape, IfrtShardingAttrInterface sharding_attr, - IfrtDevicesAttr devices) { + IfrtDevicesAttr devices, mlir::StringAttr memory_kind) { return sharding_attr.CanApplyTo(emitError, shape, devices.getIds()); } +xla::ifrt::MemoryKind IfrtArrayType::MemoryKind() const { + return getMemoryKindAttr() == nullptr + ? 
xla::ifrt::MemoryKind() + : xla::ifrt::MemoryKind(getMemoryKindAttr().str()); +}; + //===----------------------------------------------------------------------===// // IfrtDevicesAttr //===----------------------------------------------------------------------===// diff --git a/third_party/xla/xla/python/ifrt/ir/ifrt_dialect.h b/third_party/xla/xla/python/ifrt/ir/ifrt_dialect.h index a8fcd0bc500b5e..155597bac4d485 100644 --- a/third_party/xla/xla/python/ifrt/ir/ifrt_dialect.h +++ b/third_party/xla/xla/python/ifrt/ir/ifrt_dialect.h @@ -19,6 +19,7 @@ limitations under the License. #include "mlir/IR/Dialect.h" // from @llvm-project #include "xla/python/ifrt/ir/ifrt_interfaces.h" #include "xla/python/ifrt/ir/sharding_param.h" +#include "xla/python/ifrt/memory.h" // Generated definitions. #include "xla/python/ifrt/ir/ifrt_dialect.h.inc" // IWYU pragma: export diff --git a/third_party/xla/xla/python/ifrt/ir/ifrt_dialect.td b/third_party/xla/xla/python/ifrt/ir/ifrt_dialect.td index dbcba1a62aaccd..d4ddbad8d43712 100644 --- a/third_party/xla/xla/python/ifrt/ir/ifrt_dialect.td +++ b/third_party/xla/xla/python/ifrt/ir/ifrt_dialect.td @@ -74,26 +74,8 @@ def Ifrt_ShardingParamAttr : AttrDef:$memory_kind - ); - let assemblyFormat = [{ - `<` $sharding (`,` `memory_kind` `=` $memory_kind^)? 
`>` - }]; - - let builders = [ - AttrBuilder<(ins "::xla::ifrt::ShardingParam":$sharding), [{ - return $_get($_ctxt, sharding, /*memory_kind=*/nullptr); - }]>, - AttrBuilder<(ins - "::xla::ifrt::ShardingParam":$sharding, - "::mlir::StringRef":$memory_kind), [{ - return $_get($_ctxt, - sharding, - ::mlir::StringAttr::get($_ctxt, memory_kind)); - }]> - ]; + let parameters = (ins Ifrt_ShardingParameter:$sharding); + let assemblyFormat = "`<` $sharding `>`"; let genVerifyDecl = 1; } @@ -181,7 +163,8 @@ def Ifrt_ArrayType : TypeDef { let parameters = (ins Builtin_RankedTensor:$shape, "::xla::ifrt::IfrtShardingAttrInterface":$sharding_attr, - Ifrt_DevicesAttr:$devices_attr); + Ifrt_DevicesAttr:$devices_attr, + OptionalParameter<"::mlir::StringAttr">:$memory_kind_attr); let builders = [ TypeBuilder<(ins @@ -190,14 +173,16 @@ def Ifrt_ArrayType : TypeDef { "::llvm::ArrayRef":$devices), [{ return Base::get( $_ctxt, shape, sharding_attr, - ::xla::ifrt::IfrtDevicesAttr::get($_ctxt, devices)); + ::xla::ifrt::IfrtDevicesAttr::get($_ctxt, devices), + /*memory_kind=*/nullptr); }]>, TypeBuilder<(ins "::mlir::RankedTensorType":$shape, "::llvm::ArrayRef":$devices), [{ return Base::get( $_ctxt, shape, ::xla::ifrt::IfrtUnspecifiedShardingAttr::get($_ctxt), - ::xla::ifrt::IfrtDevicesAttr::get($_ctxt, devices)); + ::xla::ifrt::IfrtDevicesAttr::get($_ctxt, devices), + /*memory_kind=*/nullptr); }]>, TypeBuilder<(ins "::mlir::RankedTensorType":$shape, @@ -206,17 +191,34 @@ def Ifrt_ArrayType : TypeDef { return Base::get( $_ctxt, shape, ::xla::ifrt::IfrtShardingParamAttr::get($_ctxt, sharding), - ::xla::ifrt::IfrtDevicesAttr::get($_ctxt, devices)); + ::xla::ifrt::IfrtDevicesAttr::get($_ctxt, devices), + /*memory_kind=*/nullptr); + }]>, + TypeBuilder<(ins + "::mlir::RankedTensorType":$shape, + "::xla::ifrt::ShardingParam":$sharding, + "::llvm::ArrayRef":$devices, + "::std::string":$memory_kind), [{ + return Base::get( + $_ctxt, shape, + ::xla::ifrt::IfrtShardingParamAttr::get($_ctxt, 
sharding), + ::xla::ifrt::IfrtDevicesAttr::get($_ctxt, devices), + ::mlir::StringAttr::get($_ctxt, memory_kind)); }]> ]; - let assemblyFormat = "`<` $shape`,` $sharding_attr`,` $devices_attr`>`"; + let assemblyFormat = [{ + `<` $shape`,` $sharding_attr `,` $devices_attr + (`,` `memory_kind` `=` $memory_kind_attr^)? `>` + }]; let genVerifyDecl = 1; let extraClassDeclaration = [{ // Get logical device ids from `devices_attr`. ::llvm::ArrayRef getDevices() const; + // Get the memory kind from `memory_kind`. + ::xla::ifrt::MemoryKind MemoryKind() const; }]; } diff --git a/third_party/xla/xla/python/ifrt/ir/ifrt_interfaces.h b/third_party/xla/xla/python/ifrt/ir/ifrt_interfaces.h index 77d754bbb4a39b..addbbafd202d08 100644 --- a/third_party/xla/xla/python/ifrt/ir/ifrt_interfaces.h +++ b/third_party/xla/xla/python/ifrt/ir/ifrt_interfaces.h @@ -23,7 +23,6 @@ limitations under the License. #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "xla/python/ifrt/ir/constants.h" #include "xla/python/ifrt/ir/sharding_param.h" -#include "xla/python/ifrt/memory.h" namespace mlir { namespace OpTrait { diff --git a/third_party/xla/xla/python/ifrt/ir/ifrt_interfaces.td b/third_party/xla/xla/python/ifrt/ir/ifrt_interfaces.td index f2dfe981a89957..f1a94a517b91b8 100644 --- a/third_party/xla/xla/python/ifrt/ir/ifrt_interfaces.td +++ b/third_party/xla/xla/python/ifrt/ir/ifrt_interfaces.td @@ -108,12 +108,6 @@ def Ifrt_ShardingAttrInterface : Ifrt_AttrInterface<"IfrtShardingAttrInterface"> /*retTy=*/"int", /*methodName=*/"NumDevices", /*args=*/(ins) - >, - InterfaceMethod< - /*desc=*/"Returns the memory kind.", - /*retTy=*/"::xla::ifrt::MemoryKind", - /*methodName=*/"MemoryKind", - /*args=*/(ins) > ]; } diff --git a/third_party/xla/xla/python/ifrt/ir/tests/ifrt_verify_sharding_specified.mlir b/third_party/xla/xla/python/ifrt/ir/tests/ifrt_verify_sharding_specified.mlir index e79d8816c7b937..9d4734e6301783 100644 --- 
a/third_party/xla/xla/python/ifrt/ir/tests/ifrt_verify_sharding_specified.mlir +++ b/third_party/xla/xla/python/ifrt/ir/tests/ifrt_verify_sharding_specified.mlir @@ -1,18 +1,15 @@ // RUN: ifrt-opt %s -ifrt-verify-sharding-specified -split-input-file -verify-diagnostics | FileCheck %s // CHECK-LABEL: @good_arrays -#sharding = #ifrt.sharding_param<2 to [0] on 2, memory_kind = "device"> +#sharding = #ifrt.sharding_param<2 to [0] on 2> +!array0 = !ifrt.array, #sharding, [0,1], memory_kind = "device"> +!array1 = !ifrt.array, #sharding, [2,3], memory_kind = "device"> module @good_arrays { - func.func @main(%arg0: !ifrt.array, #sharding, [0,1]>) - -> !ifrt.array, #sharding, [2,3]> - attributes {ifrt.function} { + func.func @main(%arg0: !array0) -> !array1 attributes {ifrt.function} { %0, %ctrl_1 = ifrt.Call @identity(%arg0) on devices [0,1] - : (!ifrt.array, #sharding, [0,1]>) - -> !ifrt.array, #sharding, [0,1]> - %1 = "ifrt.Reshard"(%0) - : (!ifrt.array, #sharding, [0,1]>) - -> !ifrt.array, #sharding, [2,3]> - return %1 : !ifrt.array, #sharding, [2,3]> + : (!array0) -> !array0 + %1 = "ifrt.Reshard"(%0) : (!array0) -> !array1 + return %1 : !array1 } func.func private @identity(%arg0: tensor<2xi32>) -> tensor<2xi32> { From 43b64083cb2283bcaa988c854b3ec3fe8ca34ea0 Mon Sep 17 00:00:00 2001 From: Penporn Koanantakool Date: Thu, 20 Jun 2024 19:03:54 -0700 Subject: [PATCH 095/256] [xla:cpu] Add CustomCall thunk. 
PiperOrigin-RevId: 645224939 --- third_party/xla/xla/service/cpu/BUILD | 5 +- .../xla/xla/service/cpu/cpu_executable.cc | 12 +- third_party/xla/xla/service/cpu/runtime/BUILD | 40 ++- .../service/cpu/runtime/custom_call_thunk.cc | 262 ++++++++++++++++++ .../service/cpu/runtime/custom_call_thunk.h | 76 +++++ .../xla/xla/service/cpu/runtime/thunk.cc | 26 ++ .../xla/xla/service/cpu/runtime/thunk.h | 26 ++ .../xla/xla/service/cpu/thunk_emitter.cc | 85 +++++- .../xla/xla/service/cpu/thunk_emitter.h | 3 + third_party/xla/xla/tests/BUILD | 1 + third_party/xla/xla/tests/custom_call_test.cc | 24 ++ 11 files changed, 549 insertions(+), 11 deletions(-) create mode 100644 third_party/xla/xla/service/cpu/runtime/custom_call_thunk.cc create mode 100644 third_party/xla/xla/service/cpu/runtime/custom_call_thunk.h diff --git a/third_party/xla/xla/service/cpu/BUILD b/third_party/xla/xla/service/cpu/BUILD index 3ab2a2abb0ef56..bb002039a8eb4f 100644 --- a/third_party/xla/xla/service/cpu/BUILD +++ b/third_party/xla/xla/service/cpu/BUILD @@ -801,6 +801,7 @@ cc_library( srcs = ["thunk_emitter.cc"], hdrs = ["thunk_emitter.h"], deps = [ + ":dot_op_emitter", ":ir_emitter2", ":target_machine_features", "//xla:cpu_function_runtime", @@ -811,7 +812,6 @@ cc_library( "//xla/service:buffer_assignment", "//xla/service:collective_ops_utils", "//xla/service:hlo_module_config", - "//xla/service/cpu:dot_op_emitter", "//xla/service/cpu/runtime:all_gather_thunk", "//xla/service/cpu/runtime:all_reduce_thunk", "//xla/service/cpu/runtime:all_to_all_thunk", @@ -820,6 +820,7 @@ cc_library( "//xla/service/cpu/runtime:collective_thunk", "//xla/service/cpu/runtime:conditional_thunk", "//xla/service/cpu/runtime:copy_thunk", + "//xla/service/cpu/runtime:custom_call_thunk", "//xla/service/cpu/runtime:dot_thunk", "//xla/service/cpu/runtime:fft_thunk", "//xla/service/cpu/runtime:infeed_thunk", @@ -830,11 +831,11 @@ cc_library( "//xla/service/cpu/runtime:rng_state_thunk", "//xla/service/cpu/runtime:thunk", 
"//xla/service/cpu/runtime:while_thunk", - "//xla/stream_executor:launch_dim", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", diff --git a/third_party/xla/xla/service/cpu/cpu_executable.cc b/third_party/xla/xla/service/cpu/cpu_executable.cc index 4c15f9fa17a869..c19fedb753440e 100644 --- a/third_party/xla/xla/service/cpu/cpu_executable.cc +++ b/third_party/xla/xla/service/cpu/cpu_executable.cc @@ -358,10 +358,18 @@ absl::Status CpuExecutable::ExecuteThunks( TF_ASSIGN_OR_RETURN(Thunk::CollectiveExecuteParams collective_execute_params, Thunk::CollectiveExecuteParams::Create(run_options)); + // Prepare for executing XLA custom calls. + // TODO(penporn): Consolidate with other thunk parameter set up calls. + TF_ASSIGN_OR_RETURN(Thunk::CustomCallExecuteParams custom_call_execute_params, + Thunk::CustomCallExecuteParams::Create(run_options)); + Thunk::ExecuteParams execute_params = { - &*host_kernels_, &allocations, + &*host_kernels_, + &allocations, runtime::GetXfeedManager(run_options->device_ordinal()), - run_options->intra_op_thread_pool(), &collective_execute_params}; + run_options->intra_op_thread_pool(), + &collective_execute_params, + &custom_call_execute_params}; auto executed_event = thunks_->Execute(execute_params); tsl::BlockUntilReady(executed_event); diff --git a/third_party/xla/xla/service/cpu/runtime/BUILD b/third_party/xla/xla/service/cpu/runtime/BUILD index e89d00a776415a..0e2a6bd860566e 100644 --- a/third_party/xla/xla/service/cpu/runtime/BUILD +++ b/third_party/xla/xla/service/cpu/runtime/BUILD @@ -44,7 +44,7 @@ cc_library( deps = [ ":buffer_allocations", "//xla:executable_run_options", - "//xla:util", + "//xla/ffi:execution_context", "//xla/runtime:buffer_use", 
"//xla/service:global_device_id", "//xla/service/cpu:collectives_interface", @@ -54,11 +54,8 @@ cc_library( "//xla/stream_executor", "//xla/stream_executor/host:host_kernel_c_api", "//xla/tsl/concurrency:async_value", - "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", - "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/profiler/lib:traceme", @@ -403,12 +400,10 @@ cc_library( "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/memory", - "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", "@eigen_archive//:eigen3", "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:numbers", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/profiler/lib:traceme", ], @@ -433,6 +428,39 @@ xla_cc_test( ], ) +cc_library( + name = "custom_call_thunk", + srcs = ["custom_call_thunk.cc"], + hdrs = ["custom_call_thunk.h"], + deps = [ + ":thunk", + "//xla:shape_util", + "//xla:util", + "//xla/ffi:attribute_map", + "//xla/ffi:call_frame", + "//xla/ffi:ffi_api", + "//xla/runtime:buffer_use", + "//xla/service:buffer_assignment", + "//xla/service:custom_call_status", + "//xla/service:custom_call_status_internal", + "//xla/service:custom_call_target_registry", + "//xla/stream_executor", + "//xla/tsl/concurrency:async_value", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", + "@eigen_archive//:eigen3", + "@llvm-project//mlir:AsmParser", + "@llvm-project//mlir:IR", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:statusor", + 
"@local_tsl//tsl/profiler/lib:traceme", + ], +) + cc_library( name = "dot_thunk", srcs = [ diff --git a/third_party/xla/xla/service/cpu/runtime/custom_call_thunk.cc b/third_party/xla/xla/service/cpu/runtime/custom_call_thunk.cc new file mode 100644 index 00000000000000..77e331ae620a14 --- /dev/null +++ b/third_party/xla/xla/service/cpu/runtime/custom_call_thunk.cc @@ -0,0 +1,262 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/cpu/runtime/custom_call_thunk.h" + +#include +#include +#include + +#include "absl/container/inlined_vector.h" +#include "absl/memory/memory.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "mlir/AsmParser/AsmParser.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "xla/ffi/attribute_map.h" +#include "xla/ffi/call_frame.h" +#include "xla/ffi/ffi_api.h" +#include "xla/runtime/buffer_use.h" +#include "xla/service/cpu/runtime/thunk.h" +#include "xla/service/custom_call_status_internal.h" +#include "xla/service/custom_call_target_registry.h" +#include "xla/stream_executor/device_memory.h" +#include 
"xla/tsl/concurrency/async_value_ref.h" +#include "xla/util.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/logging.h" +#include "tsl/platform/statusor.h" +#include "tsl/profiler/lib/traceme.h" + +namespace xla::cpu { + +absl::StatusOr> CustomCallThunk::Create( + Info info, absl::string_view target_name, OpBuffers op_buffers, + absl::string_view backend_config, CustomCallApiVersion api_version) { + return absl::WrapUnique( + new CustomCallThunk(std::move(info), target_name, std::move(op_buffers), + std::move(backend_config), api_version)); +} + +CustomCallThunk::CustomCallThunk(Info info, absl::string_view target_name, + OpBuffers op_buffers, + absl::string_view backend_config, + CustomCallApiVersion api_version) + : Thunk(Kind::kCustomCall, std::move(info)), + target_name_(target_name), + op_buffers_(std::move(op_buffers)), + backend_config_(std::move(backend_config)), + api_version_(api_version) {} + +tsl::AsyncValueRef CustomCallThunk::Execute( + const ExecuteParams& params) { + VLOG(3) << absl::StreamFormat( + "CustomCall: %s, #arguments=%d, #results=%d", target_name_, + op_buffers_.arguments_buffers.size(), op_buffers_.results_buffers.size()); + if (api_version_ == CustomCallApiVersion::API_VERSION_TYPED_FFI) { + return CallTypedFFI(params); + } + return CallUntypedAPI(params); +} + +tsl::AsyncValueRef CustomCallThunk::CallTypedFFI( + const ExecuteParams& params) { + tsl::profiler::TraceMe trace([&] { return TraceMeEncode(); }); + + // Find the registered FFI handler for this target. + auto handler = ffi::FindHandler(target_name_, "Host"); + if (!handler.ok()) { + // Overwrite the returned error code (kNotFound) to kInternal to match the + // original CPU implementation. + // TODO(penporn): Change this to kUnimplemented to match the GPU backend + // when thunks is the only runtime for CPU. + return Internal( + "No registered implementation for FFI custom call to %s for Host", + target_name_); + } + + // Build the FFI call frame. 
+ ffi::CallFrameBuilder builder; + + // Add input buffers. + for (int i = 0; i < op_buffers_.arguments_buffers.size(); ++i) { + auto& slice = op_buffers_.arguments_buffers[i]; + auto& shape = op_buffers_.arguments_shapes[i]; + TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase arg, + params.buffer_allocations->GetDeviceAddress(slice)); + builder.AddBufferArg(arg, shape.element_type(), shape.dimensions()); + VLOG(3) << absl::StreamFormat(" arg: %s in slice %s (%p)", + shape.ToString(true), slice.ToString(), + arg.opaque()); + } + + // Add output buffers. + for (int i = 0; i < op_buffers_.results_buffers.size(); ++i) { + auto& slice = op_buffers_.results_buffers[i]; + auto& shape = op_buffers_.results_shapes[i]; + TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase res, + params.buffer_allocations->GetDeviceAddress(slice)); + builder.AddBufferRet(res, shape.element_type(), shape.dimensions()); + VLOG(3) << absl::StreamFormat(" res: %s in slice %s (%p)", + shape.ToString(true), slice.ToString(), + res.opaque()); + } + + // Add attributes. + if (!backend_config_.empty()) { + // Parse backend config into an MLIR dictionary. + mlir::MLIRContext mlir_context; + ffi::CallFrameBuilder::FlatAttributesMap attributes; + mlir::Attribute attr = mlir::parseAttribute(backend_config_, &mlir_context); + if (auto dict = attr.dyn_cast_or_null()) { + TF_ASSIGN_OR_RETURN(attributes, xla::ffi::BuildAttributesMap(dict)); + } else { + return Internal( + "Unsupported backend config. Expected a string parsable into " + "dictionary attribute"); + } + // Convert the MLIR dictionary to FFI attributes. + ffi::CallFrameBuilder::AttributesBuilder attrs; + attrs.Append(std::move(attributes)); + builder.AddAttributes(attrs.Build()); + VLOG(3) << absl::StreamFormat(" attributes: %s", backend_config_); + } + + // Forward ExecutableRunOptions to the FFI handlers via the call options. 
+ CustomCallExecuteParams* custom_call_params = params.custom_call_params; + ffi::CallOptions call_options = {custom_call_params->device_ordinal, + custom_call_params->stream, + custom_call_params->allocator, + /*called_computation=*/nullptr, + custom_call_params->ffi_execution_context}; + + // Call the function and check execution status. + ffi::CallFrame call_frame = builder.Build(); + auto status = ffi::Call(handler->bundle.execute, call_frame, call_options); + if (!status.ok()) { + // Overwrite the returned error code to kInternal to match the original CPU + // implementation. + // TODO(penporn): Use TF_RETURN_IF_ERROR when thunks is the only runtime. + return Internal("%s", status.message()); + } + return OkExecuteEvent(); +} + +tsl::AsyncValueRef CustomCallThunk::CallUntypedAPI( + const ExecuteParams& params) { + tsl::profiler::TraceMe trace([&] { return TraceMeEncode(); }); + + // Find the corresponding call target. + void* call_target = + CustomCallTargetRegistry::Global()->Lookup(target_name_, "Host"); + if (!call_target) { + // Use kInternal to match the original CPU implementation. + // TODO(penporn): Change this to kUnimplemented to match the GPU backend + // when thunks is the only runtime for CPU. + return Internal( + "No registered implementation for untyped custom call to %s for Host", + target_name_); + } + + // Collect raw input pointers in an array. + absl::InlinedVector arguments; + arguments.reserve(op_buffers_.arguments_buffers.size()); + for (int i = 0; i < op_buffers_.arguments_buffers.size(); ++i) { + auto& slice = op_buffers_.arguments_buffers[i]; + TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase arg, + params.buffer_allocations->GetDeviceAddress(slice)); + arguments.push_back(arg.opaque()); + VLOG(3) << absl::StreamFormat( + " arg: %s in slice %s (%p)", + op_buffers_.arguments_shapes[i].ToString(true), slice.ToString(), + arg.opaque()); + } + const void** in_ptrs = arguments.data(); + + // Collect raw output pointers in another array. 
+ absl::InlinedVector results; + results.reserve(op_buffers_.results_buffers.size()); + for (int i = 0; i < op_buffers_.results_buffers.size(); ++i) { + auto& slice = op_buffers_.results_buffers[i]; + TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase res, + params.buffer_allocations->GetDeviceAddress(slice)); + results.push_back(res.opaque()); + VLOG(3) << absl::StreamFormat(" res: %s in slice %s (%p)", + op_buffers_.results_shapes[i].ToString(true), + slice.ToString(), res.opaque()); + } + void* out_ptr = results.size() == 1 ? results[0] : results.data(); + + // Set up the correct function type for each API version. + CustomCallTarget custom_call_target; + switch (api_version_) { + case CustomCallApiVersion::API_VERSION_ORIGINAL: + using v1_signature = void (*)(void* /*out*/, const void** /*in*/); + custom_call_target = [call_target](void* out, const void** in, + const char* opaque, size_t opaque_len, + XlaCustomCallStatus* status) { + auto fn = reinterpret_cast(call_target); + fn(out, in); + }; + break; + case CustomCallApiVersion::API_VERSION_STATUS_RETURNING: + using v2_signature = void (*)(void* /*out*/, const void** /*in*/, + XlaCustomCallStatus* /*status*/); + custom_call_target = [call_target](void* out, const void** in, + const char* opaque, size_t opaque_len, + XlaCustomCallStatus* status) { + auto fn = reinterpret_cast(call_target); + fn(out, in, status); + }; + break; + case CustomCallApiVersion::API_VERSION_STATUS_RETURNING_UNIFIED: + using v3_signature = + void (*)(void* /*out*/, const void** /*in*/, const char* /*opaque*/, + size_t /*opaque_len*/, XlaCustomCallStatus* /*status*/); + custom_call_target = reinterpret_cast(call_target); + break; + default: + return InvalidArgument( + "Unknown custom-call API version enum value: %d (%s)", api_version_, + CustomCallApiVersion_Name(api_version_)); + } + + // Call the function and check execution status. 
+ XlaCustomCallStatus status; + custom_call_target(out_ptr, in_ptrs, backend_config_.c_str(), + backend_config_.size(), &status); + auto status_message = xla::CustomCallStatusGetMessage(&status); + if (status_message.has_value()) { + return Internal("%s", status_message.value()); + } + return OkExecuteEvent(); +} + +CustomCallThunk::BufferUses CustomCallThunk::buffer_uses() const { + BufferUses buffer_uses; + for (const auto& argument : op_buffers_.arguments_buffers) { + buffer_uses.emplace_back(argument, BufferUse::kRead); + } + for (const auto& result : op_buffers_.results_buffers) { + buffer_uses.emplace_back(result, BufferUse::kWrite); + } + return buffer_uses; +} + +} // namespace xla::cpu diff --git a/third_party/xla/xla/service/cpu/runtime/custom_call_thunk.h b/third_party/xla/xla/service/cpu/runtime/custom_call_thunk.h new file mode 100644 index 00000000000000..c7d83d89b582c8 --- /dev/null +++ b/third_party/xla/xla/service/cpu/runtime/custom_call_thunk.h @@ -0,0 +1,76 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef XLA_SERVICE_CPU_RUNTIME_CUSTOM_CALL_THUNK_H_ +#define XLA_SERVICE_CPU_RUNTIME_CUSTOM_CALL_THUNK_H_ + +#include +#include +#include +#include + +#include "absl/status/statusor.h" +#include "xla/service/buffer_assignment.h" +#include "xla/service/cpu/runtime/thunk.h" +#include "xla/service/custom_call_status.h" +#include "xla/shape.h" +#include "xla/tsl/concurrency/async_value_ref.h" + +namespace xla::cpu { + +// Handles XLA custom calls. +class CustomCallThunk final : public Thunk { + public: + // Buffer allocation slices and shapes to fill FFI arguments. + struct OpBuffers { + std::vector arguments_buffers; + std::vector arguments_shapes; + + std::vector results_buffers; + std::vector results_shapes; + }; + + static absl::StatusOr> Create( + Info info, absl::string_view target_name, OpBuffers op_buffers, + absl::string_view backend_config, CustomCallApiVersion api_version); + + tsl::AsyncValueRef Execute(const ExecuteParams& params) final; + + BufferUses buffer_uses() const final; + + private: + CustomCallThunk(Info info, absl::string_view target_name, + OpBuffers op_buffers, absl::string_view backend_config, + CustomCallApiVersion api_version); + + // Handles typed-FFI custom calls (API v4). + tsl::AsyncValueRef CallTypedFFI(const ExecuteParams& params); + + // Handles legacy, untyped custom calls (API v1-v3). + tsl::AsyncValueRef CallUntypedAPI(const ExecuteParams& params); + + // Function signature for legacy untyped API. 
+ using CustomCallTarget = std::function; + + std::string target_name_; + OpBuffers op_buffers_; + std::string backend_config_; + CustomCallApiVersion api_version_; +}; + +} // namespace xla::cpu + +#endif // XLA_SERVICE_CPU_RUNTIME_CUSTOM_CALL_THUNK_H_ diff --git a/third_party/xla/xla/service/cpu/runtime/thunk.cc b/third_party/xla/xla/service/cpu/runtime/thunk.cc index b360a492bc4bcc..21b107f1155917 100644 --- a/third_party/xla/xla/service/cpu/runtime/thunk.cc +++ b/third_party/xla/xla/service/cpu/runtime/thunk.cc @@ -52,6 +52,8 @@ std::string_view Thunk::KindToString(Kind kind) { return "conditional"; case Kind::kCopy: return "copy"; + case Kind::kCustomCall: + return "custom-call"; case Kind::kDot: return "dot"; case Kind::kFft: @@ -109,6 +111,30 @@ Thunk::CollectiveExecuteParams::CollectiveExecuteParams( device_assignment(device_assignment), collectives(collectives) {} +absl::StatusOr +Thunk::CustomCallExecuteParams::Create( + const ExecutableRunOptions* run_options) { + // Device ordinal must be set by caller and passed in run options, if not, + // we use the device ordinal from the parent StreamExecutor. + int32_t device_ordinal = + run_options->device_ordinal() >= 0 + ? 
run_options->device_ordinal() + : run_options->stream()->parent()->device_ordinal(); + + return CustomCallExecuteParams{device_ordinal, run_options->stream(), + run_options->allocator(), + run_options->ffi_execution_context()}; +} + +Thunk::CustomCallExecuteParams::CustomCallExecuteParams( + int32_t device_ordinal, stream_executor::Stream* stream, + stream_executor::DeviceMemoryAllocator* allocator, + const ffi::ExecutionContext* ffi_execution_context) + : device_ordinal(device_ordinal), + stream(stream), + allocator(allocator), + ffi_execution_context(ffi_execution_context) {} + tsl::AsyncValueRef Thunk::OkExecuteEvent() { static tsl::AsyncValueOwningRef* event = [] { auto* storage = new tsl::internal::AsyncValueStorage(); diff --git a/third_party/xla/xla/service/cpu/runtime/thunk.h b/third_party/xla/xla/service/cpu/runtime/thunk.h index c1a88c72da9342..945b2612bd2be4 100644 --- a/third_party/xla/xla/service/cpu/runtime/thunk.h +++ b/third_party/xla/xla/service/cpu/runtime/thunk.h @@ -28,12 +28,14 @@ limitations under the License. 
#include "absl/container/inlined_vector.h" #include "absl/status/statusor.h" #include "xla/executable_run_options.h" +#include "xla/ffi/execution_context.h" #include "xla/runtime/buffer_use.h" #include "xla/service/cpu/collectives_interface.h" #include "xla/service/cpu/runtime/buffer_allocations.h" #include "xla/service/cpu/xfeed_manager.h" #include "xla/service/global_device_id.h" #include "xla/stream_executor/host/host_kernel_c_api.h" +#include "xla/stream_executor/stream.h" #include "xla/tsl/concurrency/async_value_ref.h" #include "xla/tsl/concurrency/chain.h" #include "tsl/platform/statusor.h" @@ -68,6 +70,7 @@ class Thunk { kCollectivePermute, kCopy, kConditional, + kCustomCall, kDot, kFft, kInfeed, @@ -142,6 +145,28 @@ class Thunk { CollectivesInterface* collectives); }; + //===--------------------------------------------------------------------===// + // CustomCallExecuteParams + //===--------------------------------------------------------------------===// + + // Parameters capturing all the details required for custom call execution of + // XLA executables. 
+ struct CustomCallExecuteParams { + static absl::StatusOr Create( + const ExecutableRunOptions* run_options); + + int32_t device_ordinal; + stream_executor::Stream* stream = nullptr; + stream_executor::DeviceMemoryAllocator* allocator = nullptr; + const ffi::ExecutionContext* ffi_execution_context = nullptr; + + private: + CustomCallExecuteParams(int32_t device_ordinal, + stream_executor::Stream* stream, + stream_executor::DeviceMemoryAllocator* allocator, + const ffi::ExecutionContext* ffi_execution_context); + }; + //===--------------------------------------------------------------------===// // ExecuteParams //===--------------------------------------------------------------------===// @@ -154,6 +179,7 @@ class Thunk { runtime::XfeedManager* xfeed = nullptr; const Eigen::ThreadPoolDevice* intra_op_threadpool = nullptr; CollectiveExecuteParams* collective_params = nullptr; + CustomCallExecuteParams* custom_call_params = nullptr; }; // An execute event that becomes ready when all tasks are completed. diff --git a/third_party/xla/xla/service/cpu/thunk_emitter.cc b/third_party/xla/xla/service/cpu/thunk_emitter.cc index 0923d2a541c13f..fed32a0cc7fe96 100644 --- a/third_party/xla/xla/service/cpu/thunk_emitter.cc +++ b/third_party/xla/xla/service/cpu/thunk_emitter.cc @@ -15,7 +15,6 @@ limitations under the License. #include "xla/service/cpu/thunk_emitter.h" -#include #include #include #include @@ -23,6 +22,7 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/status/status.h" #include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "absl/types/span.h" #include "xla/cpu_function_runtime.h" #include "xla/hlo/ir/hlo_casting_utils.h" @@ -44,6 +44,7 @@ limitations under the License. 
#include "xla/service/cpu/runtime/collective_thunk.h" #include "xla/service/cpu/runtime/conditional_thunk.h" #include "xla/service/cpu/runtime/copy_thunk.h" +#include "xla/service/cpu/runtime/custom_call_thunk.h" #include "xla/service/cpu/runtime/dot_thunk.h" #include "xla/service/cpu/runtime/fft_thunk.h" #include "xla/service/cpu/runtime/infeed_thunk.h" @@ -261,6 +262,9 @@ absl::StatusOr ThunkEmitter::EmitHloInstruction( case HloOpcode::kFft: return EmitFftThunk(instruction); + case HloOpcode::kCustomCall: + return EmitCustomCallThunk(instruction); + default: return absl::UnimplementedError( absl::StrCat("HLO opcode `", HloOpcodeString(instruction->opcode()), @@ -649,6 +653,85 @@ absl::StatusOr ThunkEmitter::EmitFftThunk( /*output_shape=*/instruction->shape()); } +static absl::StatusOr GetCustomCallOpBuffers( + const HloInstruction* instruction, + const BufferAssignment& buffer_assignment) { + // Collect buffer slices for all operands. + std::vector arguments_buffers; + std::vector arguments_shapes; + for (HloInstruction* operand : instruction->operands()) { + for (auto& indexed : ShapeUtil::GetLeafShapes(operand->shape())) { + TF_ASSIGN_OR_RETURN( + arguments_buffers.emplace_back(), + buffer_assignment.GetUniqueSlice(operand, indexed.index)); + arguments_shapes.push_back(indexed.shape); + } + } + + // Collect buffer slices for all results. 
+ std::vector results_buffers; + std::vector results_shapes; + for (auto& indexed : ShapeUtil::GetLeafShapes(instruction->shape())) { + TF_ASSIGN_OR_RETURN( + results_buffers.emplace_back(), + buffer_assignment.GetUniqueSlice(instruction, indexed.index)); + results_shapes.push_back(indexed.shape); + } + + return CustomCallThunk::OpBuffers{ + /*arguments_buffers=*/std::move(arguments_buffers), + /*arguments_shapes=*/std::move(arguments_shapes), + /*results_buffers=*/std::move(results_buffers), + /*results_shapes=*/std::move(results_shapes), + }; +} + +static bool IsValidCustomCallApiVersion(CustomCallApiVersion api_version) { + switch (api_version) { + case CustomCallApiVersion::API_VERSION_ORIGINAL: + case CustomCallApiVersion::API_VERSION_STATUS_RETURNING: + case CustomCallApiVersion::API_VERSION_STATUS_RETURNING_UNIFIED: + case CustomCallApiVersion::API_VERSION_TYPED_FFI: + return true; + default: + return false; + } +} + +absl::StatusOr ThunkEmitter::EmitCustomCallThunk( + const HloInstruction* instruction) { + auto custom_call = Cast(instruction); + + // TODO(penporn): Support these existing targets. + auto custom_call_target = custom_call->custom_call_target(); + if (custom_call_target == "PadToStatic" || + custom_call_target == "SliceToDynamic" || custom_call_target == "TopK" || + custom_call_target == "__onednn$matmul" || + custom_call_target == "__onednn$softmax" || + custom_call_target == "__onednn$layernorm" || + custom_call_target == "__onednn$matmul_reorder") { + return Unimplemented("Custom call target %s is not implemented.", + custom_call_target); + } + + // Check the API version. 
+ auto version = custom_call->api_version(); + if (!IsValidCustomCallApiVersion(version)) { + return InvalidArgument( + "Unknown custom-call API version enum value: %d (%s)", version, + CustomCallApiVersion_Name(version)); + } + + // Get backend config and buffer assignments.ß + auto backend_config = custom_call->opaque(); + TF_ASSIGN_OR_RETURN(auto op_buffers, + GetCustomCallOpBuffers(instruction, buffer_assignment_)); + + return ThunkSequence::Of(ThunkInfo(instruction), + custom_call_target, op_buffers, + backend_config, version); +} + absl::StatusOr ThunkEmitter::GetHostKernelAllocationSlices(const HloInstruction* instruction) { HostKernelAllocationSlices slices; diff --git a/third_party/xla/xla/service/cpu/thunk_emitter.h b/third_party/xla/xla/service/cpu/thunk_emitter.h index 78987bb4373f68..8f97169a932195 100644 --- a/third_party/xla/xla/service/cpu/thunk_emitter.h +++ b/third_party/xla/xla/service/cpu/thunk_emitter.h @@ -122,6 +122,9 @@ class ThunkEmitter { absl::StatusOr EmitReduceScatterThunk( const HloInstruction* instruction); + absl::StatusOr EmitCustomCallThunk( + const HloInstruction* instruction); + // Returns the list of buffer allocation slices assigned to the given // instruction that will be passed to the host kernel as arguments: a // flattened list of all the leaf buffers for all operands and result. 
We do diff --git a/third_party/xla/xla/tests/BUILD b/third_party/xla/xla/tests/BUILD index 831173ad721b08..fd8ae35b30ca5c 100644 --- a/third_party/xla/xla/tests/BUILD +++ b/third_party/xla/xla/tests/BUILD @@ -1828,6 +1828,7 @@ xla_test( name = "custom_call_test", srcs = ["custom_call_test.cc"], backends = ["cpu"], + tags = ["test_xla_cpu_thunks"], deps = [ ":client_library_test_base", ":hlo_test_base", diff --git a/third_party/xla/xla/tests/custom_call_test.cc b/third_party/xla/xla/tests/custom_call_test.cc index cdc42d09c3a98a..e52a0d1b85bf4a 100644 --- a/third_party/xla/xla/tests/custom_call_test.cc +++ b/third_party/xla/xla/tests/custom_call_test.cc @@ -1351,5 +1351,29 @@ XLA_TEST_F(FfiCustomCallTest, FfiNestedTupleInput) { EXPECT_EQ(result, expected); } +XLA_TEST_F(FfiCustomCallTest, FfiNestedTupleInputAndOutput) { + GTEST_SKIP() << "Nested tuple inputs/outputs not yet implemented."; + const char* const kModuleStr = R"( + HloModule m + + ENTRY test { + c0 = ((f32[], f32[]), (f32[], f32[])) constant(((7.0, 42.0), (8.0, 43.0))) + ROOT custom-call = (f32[], (f32[], f32[]), f32[]) custom-call(c0), custom_call_target="__xla_test$$FfiTupleRotate", api_version=API_VERSION_TYPED_FFI + })"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kModuleStr)); + + Literal arg0 = LiteralUtil::CreateR0(7.f); + Literal arg1 = LiteralUtil::CreateR0(42.f); + Literal arg2 = LiteralUtil::CreateR0(8.f); + Literal arg3 = LiteralUtil::CreateR0(43.f); + + Literal inner_tuple = LiteralUtil::MakeTuple({&arg2, &arg3}); + Literal expected = LiteralUtil::MakeTuple({&arg1, &inner_tuple, &arg0}); + TF_ASSERT_OK_AND_ASSIGN(auto result, Execute(std::move(module), {})); + EXPECT_EQ(result, expected); +} + } // namespace } // namespace xla From f69fafeb193a10a5f7b79fb59d064a6fb79bfa9b Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Thu, 20 Jun 2024 19:04:50 -0700 Subject: [PATCH 096/256] [XLA:Python:JAX] Add a method jax_jit.parse_arguments and a class 
jax_jit.ArgumentsSignature. These expose the C++ jit fast path argument parsing logic to Python, separate from the rest of the jit logic. A subsequent change will use this logic to cache inner jits during JAX tracing. PiperOrigin-RevId: 645225161 --- third_party/xla/xla/python/BUILD | 36 ++++++++++ third_party/xla/xla/python/jax_jit.cc | 65 +++++++++++++++++++ third_party/xla/xla/python/jax_jit.h | 2 +- third_party/xla/xla/python/jax_jit_test.py | 47 ++++++++++++++ .../xla/xla/python/nb_absl_inlined_vector.h | 33 ++++++++++ third_party/xla/xla/python/xla_client.py | 2 +- .../xla/xla/python/xla_extension/jax_jit.pyi | 25 +++++++ 7 files changed, 208 insertions(+), 2 deletions(-) create mode 100644 third_party/xla/xla/python/jax_jit_test.py create mode 100644 third_party/xla/xla/python/nb_absl_inlined_vector.h diff --git a/third_party/xla/xla/python/BUILD b/third_party/xla/xla/python/BUILD index fd1225158efdd3..c882063e075dc3 100644 --- a/third_party/xla/xla/python/BUILD +++ b/third_party/xla/xla/python/BUILD @@ -546,6 +546,8 @@ cc_library( features = ["-use_header_modules"], visibility = [":friends"], # For the functions to access C++ flags/thread-local variables deps = [ + ":nb_absl_inlined_vector", + ":nb_absl_span", ":nb_helpers", ":py_client", ":python_ref_manager", @@ -555,6 +557,7 @@ cc_library( "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/hash", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", @@ -1493,6 +1496,18 @@ cc_library( ], ) +cc_library( + name = "nb_absl_inlined_vector", + hdrs = ["nb_absl_inlined_vector.h"], + compatible_with = [], + copts = ["-fexceptions"], + features = ["-use_header_modules"], + deps = [ + "//third_party/nanobind", + "@com_google_absl//absl/container:inlined_vector", + ], +) + cc_library( name = "nb_absl_span", hdrs = ["nb_absl_span.h"], @@ -1553,3 
+1568,24 @@ xla_cc_test( "@local_tsl//tsl/profiler/protobuf:profiled_instructions_proto_cc_impl", ], ) + +py_strict_test( + name = "jax_jit_test", + srcs = ["jax_jit_test.py"], + main = "jax_jit_test.py", + python_version = "PY3", + srcs_version = "PY3", + tags = [ + "no_oss", + "not_run:arm", + ], # TODO(phawkins): This test passes, but requires --config=monolithic. + deps = [ + ":xla_client", + ":xla_extension", + "//third_party/py/numpy", + "@absl_py//absl/flags", + "@absl_py//absl/logging", + "@absl_py//absl/testing:absltest", + "@absl_py//absl/testing:parameterized", + ] + xla_py_test_deps(), +) diff --git a/third_party/xla/xla/python/jax_jit.cc b/third_party/xla/xla/python/jax_jit.cc index 4f10ee7c7e6036..4665afd98cabc2 100644 --- a/third_party/xla/xla/python/jax_jit.cc +++ b/third_party/xla/xla/python/jax_jit.cc @@ -39,6 +39,7 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/base/attributes.h" #include "absl/container/inlined_vector.h" +#include "absl/hash/hash.h" #include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" @@ -46,10 +47,14 @@ limitations under the License. 
#include "absl/types/span.h" #include "third_party/nanobind/include/nanobind/nanobind.h" #include "third_party/nanobind/include/nanobind/stl/optional.h" // IWYU pragma: keep +#include "third_party/nanobind/include/nanobind/stl/pair.h" // IWYU pragma: keep #include "third_party/nanobind/include/nanobind/stl/string.h" // IWYU pragma: keep #include "third_party/nanobind/include/nanobind/stl/string_view.h" // IWYU pragma: keep +#include "third_party/nanobind/include/nanobind/stl/vector.h" // IWYU pragma: keep #include "xla/pjrt/pjrt_client.h" #include "xla/pjrt/status_casters.h" +#include "xla/python/nb_absl_inlined_vector.h" // IWYU pragma: keep +#include "xla/python/nb_absl_span.h" // IWYU pragma: keep #include "xla/python/py_values.h" #include "xla/python/pytree.h" #include "xla/python/sharding.h" @@ -402,6 +407,66 @@ void BuildJaxjitSubmodule(nb::module_& m) { xla::ValueOrThrowWrapper(xla::PyArgSignatureOfValue)); jitlib.def("_is_float0", &xla::IsFloat0); + + nb::class_ argument_signature(jitlib, "ArgumentSignature"); + argument_signature.def_ro("static_args", &ArgumentSignature::static_args) + .def_ro("static_arg_names", &ArgumentSignature::static_arg_names) + .def_ro("dynamic_arg_names", &ArgumentSignature::dynamic_arg_names) + .def_ro("dynamic_arg_treedefs", &ArgumentSignature::dynamic_arg_treedefs) + .def("__repr__", &ArgumentSignature::DebugString) + .def("__str__", &ArgumentSignature::DebugString) + .def("__hash__", + [](const ArgumentSignature& s) { return absl::HashOf(s); }) + .def("__eq__", [](const ArgumentSignature& a, + const ArgumentSignature& b) { return a == b; }) + .def("__ne__", [](const ArgumentSignature& a, + const ArgumentSignature& b) { return a != b; }); + + jitlib.def( + "parse_arguments", + [](nb::sequence positional_args, nb::sequence keyword_args, + nb::tuple kwnames, absl::Span static_argnums, + absl::Span static_argnames, + xla::PyTreeRegistry* pytree_registry) { + ArgumentSignature signature; + absl::InlinedVector flat_dynamic_args; + 
nb::object positional_args_seq = nb::steal(PySequence_Fast( + positional_args.ptr(), "positional_args must be a list or tuple")); + if (!positional_args_seq.ptr()) { + throw nb::python_error(); + } + nb::object keyword_args_seq = nb::steal(PySequence_Fast( + keyword_args.ptr(), "keyword_args must be a list or tuple")); + if (!keyword_args_seq.ptr()) { + throw nb::python_error(); + } + absl::Span positional_args_span = + absl::MakeSpan(PySequence_Fast_ITEMS(positional_args_seq.ptr()), + PySequence_Fast_GET_SIZE(positional_args_seq.ptr())); + absl::Span keyword_args_span = + absl::MakeSpan(PySequence_Fast_ITEMS(keyword_args_seq.ptr()), + PySequence_Fast_GET_SIZE(keyword_args_seq.ptr())); + xla::ThrowIfError(ParseArguments( + positional_args_span, keyword_args_span, kwnames, static_argnums, + static_argnames, pytree_registry, signature, flat_dynamic_args)); + return std::make_pair(std::move(signature), + std::move(flat_dynamic_args)); + }, + nb::arg("positional_args"), nb::arg("keyword_args"), nb::arg("kwnames"), + nb::arg("static_argnums"), nb::arg("static_argnames"), + nb::arg("pytree_registry"), + R"doc(Parses the arguments to a function as jax.jit would. + +Returns a ArgumentSignature and the flattened dynamic arguments. + +Args: + positional_args: The positional arguments. + keyword_args: The keyword arguments. + kwnames: The keyword names. + static_argnums: The static argument numbers. + static_argnames: The static argument names. + pytree_registry: The pytree registry. 
+)doc"); } } // namespace jax diff --git a/third_party/xla/xla/python/jax_jit.h b/third_party/xla/xla/python/jax_jit.h index c076ef9cbeabe1..ecdd6205273477 100644 --- a/third_party/xla/xla/python/jax_jit.h +++ b/third_party/xla/xla/python/jax_jit.h @@ -162,7 +162,7 @@ H AbslHashValue(H h, const ArgumentSignature& s) { // arguments // static_argnames: the names of the static arguments // pytree_registry: the registry to use to convert the arguments to pytrees -// arguments: output; describes the static arguments and the identities of the +// signature: output; describes the static arguments and the identities of the // dynamic arguments. // flat_dynamic_args: output; the concatenation of the dynamic positional // arguments and sorted keyword arguments. diff --git a/third_party/xla/xla/python/jax_jit_test.py b/third_party/xla/xla/python/jax_jit_test.py new file mode 100644 index 00000000000000..abd15d8fef3cde --- /dev/null +++ b/third_party/xla/xla/python/jax_jit_test.py @@ -0,0 +1,47 @@ +# Copyright 2024 The OpenXLA Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for jax_jit helper functions.""" + +from absl.testing import absltest + +from xla.python import xla_client + +jax_jit = xla_client._xla.jax_jit +pytree = xla_client._xla.pytree + +pytree_registry = pytree.default_registry() + + +class JaxJitTest(absltest.TestCase): + + def testParseArguments(self): + sig, args = jax_jit.parse_arguments( + positional_args=[1, 2, 3], + keyword_args=[4, 5], + kwnames=("a", "b"), + static_argnums=[0, 2], + static_argnames=["a"], + pytree_registry=pytree_registry, + ) + self.assertEqual(args, [2, 5]) + self.assertEqual(sig.static_args, [1, 3, 4]) + self.assertEqual(sig.static_arg_names, ["a"]) + _, leaf = pytree_registry.flatten(0) + self.assertEqual(sig.dynamic_arg_names, ["b"]) + self.assertEqual(sig.dynamic_arg_treedefs, [leaf, leaf]) + + +if __name__ == "__main__": + absltest.main() diff --git a/third_party/xla/xla/python/nb_absl_inlined_vector.h b/third_party/xla/xla/python/nb_absl_inlined_vector.h new file mode 100644 index 00000000000000..a68a43dbe3eb98 --- /dev/null +++ b/third_party/xla/xla/python/nb_absl_inlined_vector.h @@ -0,0 +1,33 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef XLA_PYTHON_NB_ABSL_INLINED_VECTOR_H_ +#define XLA_PYTHON_NB_ABSL_INLINED_VECTOR_H_ + +#include "absl/container/inlined_vector.h" +#include "third_party/nanobind/include/nanobind/nanobind.h" +#include "third_party/nanobind/include/nanobind/stl/detail/nb_list.h" + +namespace nanobind { +namespace detail { + +template +struct type_caster> + : list_caster, Type> {}; + +} // namespace detail +} // namespace nanobind + +#endif // XLA_PYTHON_NB_ABSL_INLINED_VECTOR_H_ diff --git a/third_party/xla/xla/python/xla_client.py b/third_party/xla/xla/python/xla_client.py index d49fcdafb1ca5e..6db04632947def 100644 --- a/third_party/xla/xla/python/xla_client.py +++ b/third_party/xla/xla/python/xla_client.py @@ -50,7 +50,7 @@ # Just an internal arbitrary increasing number to help with backward-compatible # changes. In JAX, reference this via jax._src.lib.xla_extension_version. -_version = 272 +_version = 273 # Version number for MLIR:Python components. mlir_api_version = 57 diff --git a/third_party/xla/xla/python/xla_extension/jax_jit.pyi b/third_party/xla/xla/python/xla_extension/jax_jit.pyi index b0428919a20a70..931ee12dfb8779 100644 --- a/third_party/xla/xla/python/xla_extension/jax_jit.pyi +++ b/third_party/xla/xla/python/xla_extension/jax_jit.pyi @@ -18,6 +18,8 @@ from typing import Any, Callable, Optional, Sequence, Tuple import numpy as np from xla.python import xla_extension +from . import pytree + Client = xla_extension.Client Device = xla_extension.Device @@ -50,3 +52,26 @@ def _ArgSignatureOfValue( __jax_enable_x64: bool) -> ArgSignature: ... def _is_float0(__arg: Any) -> bool: ... + + +class ArgumentSignature: + static_args: Sequence[Any] + static_arg_names: Sequence[str] + dynamic_arg_names: Sequence[str] + dynamic_arg_treedefs: Sequence[pytree.PyTreeDef] + + def __eq__(self, value, /): ... + def __ne__(self, value, /): ... + def __hash__(self, /): ... + def __str__(self): ... 
+ def __repr__(self): ... + + +def parse_arguments( + positional_args: Sequence[Any], + keyword_args: Sequence[Any], + kwnames: Tuple[str, ...], + static_argnums: Sequence[int], + static_argnames: Sequence[str], + pytree_registry: pytree.PyTreeRegistry, +) -> tuple[ArgumentSignature, Sequence[Any]]: ... \ No newline at end of file From 0d801cd2025f93812f0011b804172c62863c9179 Mon Sep 17 00:00:00 2001 From: Ionel Gog Date: Thu, 20 Jun 2024 19:20:23 -0700 Subject: [PATCH 097/256] [IFRT] Add ifrt.CopyArrays op to IFRT IR. This op is the IFRT IR equivalent of `CopyArrays` from `xla::ifrt::Client`. PiperOrigin-RevId: 645227813 --- .../xla/xla/python/ifrt/ir/ifrt_ops.cc | 57 +++++++ .../xla/xla/python/ifrt/ir/ifrt_ops.td | 29 ++++ .../xla/xla/python/ifrt/ir/tests/BUILD | 1 + .../ifrt/ir/tests/ifrt_verify_donation.mlir | 16 ++ .../ifrt/ir/tests/verify_copy_arrays.mlir | 159 ++++++++++++++++++ .../transforms/ifrt_verify_donation_pass.cc | 30 ++-- 6 files changed, 278 insertions(+), 14 deletions(-) create mode 100644 third_party/xla/xla/python/ifrt/ir/tests/verify_copy_arrays.mlir diff --git a/third_party/xla/xla/python/ifrt/ir/ifrt_ops.cc b/third_party/xla/xla/python/ifrt/ir/ifrt_ops.cc index ab6468d518c005..7efb6d650107a2 100644 --- a/third_party/xla/xla/python/ifrt/ir/ifrt_ops.cc +++ b/third_party/xla/xla/python/ifrt/ir/ifrt_ops.cc @@ -18,6 +18,7 @@ limitations under the License. 
#include #include #include +#include #include #include @@ -279,6 +280,62 @@ mlir::LogicalResult DisassembleOp::verify() { return mlir::success(); } +mlir::LogicalResult CopyArraysOp::verify() { + int num_in_arrays = getInputs().size(); + int num_out_arrays = getOutputs().size(); + if (num_in_arrays == 0) { + return emitOpError() << "requires at least one input array"; + } + if (num_in_arrays != num_out_arrays) { + return emitOpError() + << "requires the same number of input and output arrays"; + } + IfrtArrayType first_input = + llvm::cast(getInputs().front().getType()); + auto src_devices = first_input.getDevicesAttr(); + auto src_memory_kind = first_input.MemoryKind(); + IfrtArrayType first_output = + llvm::cast(getOutputs().front().getType()); + auto dst_devices = first_output.getDevicesAttr(); + auto dst_memory_kind = first_output.MemoryKind(); + for (const auto [idx, pair] : + llvm::enumerate(llvm::zip(getInputs(), getOutputs()))) { + const auto input_array = + llvm::cast(std::get<0>(pair).getType()); + if (src_devices != input_array.getDevicesAttr()) { + return emitOpError() << "requires all input arrays to have the same " + "devices, but input #" + << idx << " has different devices"; + } + if (src_memory_kind != input_array.MemoryKind()) { + return emitOpError() << "requires all input arrays to have the same " + "memory kind, but input #" + << idx << " has a different memory kind"; + } + const auto output_array = + llvm::cast(std::get<1>(pair).getType()); + if (dst_devices != output_array.getDevicesAttr()) { + return emitOpError() << "requires all output arrays to have the same " + "devices, but output #" + << idx << " has different devices"; + } + if (dst_memory_kind != output_array.MemoryKind()) { + return emitOpError() << "requires all output arrays to have the same " + "memory kind, but output #" + << idx << " has a different memory kind"; + } + if (input_array.getShape() != output_array.getShape()) { + return emitOpError() << "requires input #" << idx 
<< " and output #" + << idx << " to have the same shape and dtype"; + } + if (input_array.getShardingAttr() != output_array.getShardingAttr()) { + return emitOpError() << "requires input #" << idx << " and output #" + << idx << " to have the same sharding"; + } + } + return mlir::success(); +} + mlir::LogicalResult RemapArraysOp::verify() { int num_in_arrays = getInputs().size(); int num_out_arrays = getOutputs().size(); diff --git a/third_party/xla/xla/python/ifrt/ir/ifrt_ops.td b/third_party/xla/xla/python/ifrt/ir/ifrt_ops.td index cb46203901da56..4691634746ff83 100644 --- a/third_party/xla/xla/python/ifrt/ir/ifrt_ops.td +++ b/third_party/xla/xla/python/ifrt/ir/ifrt_ops.td @@ -53,6 +53,35 @@ def Ifrt_ReshardOp : Ifrt_Op<"Reshard", [NestedInIfrtFunc]> { let hasVerifier = 1; } +def Ifrt_CopyArraysOp + : Ifrt_Op<"CopyArrays", [AttrSizedOperandSegments, NestedInIfrtFunc]> { + let summary = "Copies arrays to a new set of devices"; + let description = [{ + Copies the input arrays to the output arrays. + + This op requires that all input arrays have use the same devices and memory + king, and all the output arrays use the same devices and memory king. Note + that the devices and memory kind used by the input and output arrays might + not be the same. Moreover, the corresponding output array of an input array + must have the same sharding as the op can not reshard. 
+ }]; + + let arguments = (ins + Variadic:$inputs, + DefaultValuedOptionalAttr:$donated, + Variadic:$control_inputs); + let results = (outs + Variadic:$outputs, + Ifrt_ControlType:$control_output); + + let assemblyFormat = [{ + `(` $inputs `)` oilist(`after` $control_inputs) attr-dict + `:` functional-type($inputs, $outputs) + }]; + + let hasVerifier = 1; +} + def Ifrt_AssembleOp : Ifrt_Op<"Assemble", [AttrSizedOperandSegments, NestedInIfrtFunc]> { let summary = "Assembles single device arrays to a sharded array"; diff --git a/third_party/xla/xla/python/ifrt/ir/tests/BUILD b/third_party/xla/xla/python/ifrt/ir/tests/BUILD index 65fb16d0a7813a..1e588ebae38b45 100644 --- a/third_party/xla/xla/python/ifrt/ir/tests/BUILD +++ b/third_party/xla/xla/python/ifrt/ir/tests/BUILD @@ -20,6 +20,7 @@ lit_test_suite( "verify_attrs.mlir", "verify_call.mlir", "verify_call_loaded_executable.mlir", + "verify_copy_arrays.mlir", "verify_disassemble.mlir", "verify_loaded_executable.mlir", "verify_remap_arrays.mlir", diff --git a/third_party/xla/xla/python/ifrt/ir/tests/ifrt_verify_donation.mlir b/third_party/xla/xla/python/ifrt/ir/tests/ifrt_verify_donation.mlir index 59a1fb61b9ac77..a8cd493a16aa60 100644 --- a/third_party/xla/xla/python/ifrt/ir/tests/ifrt_verify_donation.mlir +++ b/third_party/xla/xla/python/ifrt/ir/tests/ifrt_verify_donation.mlir @@ -93,6 +93,22 @@ module @donate_to_reshard_and_call_error { // ----- +!array0 = !ifrt.array, + #ifrt.sharding_param<2 to [0] on 2>, [0, 1]> +!array1 = !ifrt.array, + #ifrt.sharding_param<2 to [0] on 2>, [2, 3]> +module @donate_to_two_copy_arrays_error { + func.func @main(%arg0: !array0 {ifrt.donated}) -> (!array1, !array1) + attributes {ifrt.function} { + %0, %ctrl_0 = ifrt.CopyArrays(%arg0) {donated=true} : (!array0) -> !array1 + // expected-error @+1 {{'ifrt.CopyArrays' op input #0 already donated.}} + %1, %ctrl_1 = ifrt.CopyArrays(%arg0) {donated=true} : (!array0) -> !array1 + return %0, %1 : !array1, !array1 + } +} + +// ----- + !array = 
!ifrt.array, #ifrt.sharding_param<2 to [0] on 2>, [0, 1]> module @program_arg_not_donated_to_remap_error { diff --git a/third_party/xla/xla/python/ifrt/ir/tests/verify_copy_arrays.mlir b/third_party/xla/xla/python/ifrt/ir/tests/verify_copy_arrays.mlir new file mode 100644 index 00000000000000..34631c8b4db1c2 --- /dev/null +++ b/third_party/xla/xla/python/ifrt/ir/tests/verify_copy_arrays.mlir @@ -0,0 +1,159 @@ +// RUN: ifrt-opt %s -split-input-file -verify-diagnostics + +!array0 = !ifrt.array, + #ifrt.sharding_param<1x2 to [0] on 2>, [0,1]> +!array1 = !ifrt.array, + #ifrt.sharding_param<1x1 to [0] on 2>, [0,1]> +!array2 = !ifrt.array, + #ifrt.sharding_param<1x2 to [0] on 2>, [2,3]> +!array3 = !ifrt.array, + #ifrt.sharding_param<1x1 to [0] on 2>, [2,3]> +func.func @copy_two_different_arrays(%arg0: !array0, %arg1: !array1) + attributes {ifrt.function} { + %0, %1, %ctrl = ifrt.CopyArrays(%arg0, %arg1) + : (!array0, !array1) -> (!array2, !array3) + return +} + +// ----- + +!array0 = !ifrt.array, + #ifrt.sharding_param<1x2 to [0] on 2>, [0,1]> +!array1 = !ifrt.array, + #ifrt.sharding_param<1x2 to [0] on 2>, [2,3]> +func.func @copy_donated_array(%arg0: !array0) + attributes {ifrt.function} { + %0, %ctrl = ifrt.CopyArrays(%arg0) {donated=true} + : (!array0) -> (!array1) + return +} + +// ----- + +!array = !ifrt.array, + #ifrt.sharding_param<1x2 to [0] on 2>, [0,1]> +func.func @requires_at_least_one_input() attributes {ifrt.function} { + // expected-error@+1 {{'ifrt.CopyArrays' op requires at least one input array}} + %ctrl = ifrt.CopyArrays() : () -> () + return +} + +// ----- + +!array = !ifrt.array, + #ifrt.sharding_param<1x2 to [0] on 2>, [0,1]> +func.func @requires_same_num_inputs_and_outputs(%arg0: !array) + attributes {ifrt.function} { + // expected-error@+1 {{'ifrt.CopyArrays' op requires the same number of input and output arrays}} + %0, %1, %ctrl = ifrt.CopyArrays(%arg0) : (!array) -> (!array, !array) + return +} + +// ----- + +!array0 = !ifrt.array, + 
#ifrt.sharding_param<1x2 to [0] on 2>, [0,1]> +!array1 = !ifrt.array, + #ifrt.sharding_param<1x2 to [0] on 2>, [0,1]> +func.func @requires_copied_array_to_have_same_dtype(%arg0: !array0) + attributes {ifrt.function} { + // expected-error@+1 {{'ifrt.CopyArrays' op requires input #1 and output #1 to have the same shape and dtype}} + %0, %1, %ctrl = ifrt.CopyArrays(%arg0, %arg0) + : (!array0, !array0) -> (!array0, !array1) + return +} + +// ----- + +!array0 = !ifrt.array, + #ifrt.sharding_param<1x2 to [0] on 2>, [0,1]> +!array1 = !ifrt.array, + #ifrt.sharding_param<1x2 to [0] on 2>, [0,1]> +func.func @requires_copied_array_to_have_same_shape(%arg0: !array0) + attributes {ifrt.function} { + // expected-error@+1 {{'ifrt.CopyArrays' op requires input #1 and output #1 to have the same shape and dtype}} + %0, %1, %ctrl = ifrt.CopyArrays(%arg0, %arg0) + : (!array0, !array0) -> (!array0, !array1) + return +} + +// ----- + +!array0 = !ifrt.array, + #ifrt.sharding_param<1x2 to [0] on 2>, [0,1]> +!array1 = !ifrt.array, + #ifrt.sharding_param<1x1 to [0] on 2>, [0,1]> +func.func @requires_copied_array_to_have_same_sharding(%arg0: !array0) + attributes {ifrt.function} { + // expected-error@+1 {{'ifrt.CopyArrays' op requires input #0 and output #0 to have the same sharding}} + %0, %ctrl = ifrt.CopyArrays(%arg0) : (!array0) -> (!array1) + return +} + +// ----- + +!array0 = !ifrt.array, + #ifrt.sharding_param<1x2 to [0] on 2>, [0,1]> +!array1 = !ifrt.array, + #ifrt.sharding_param<1x2 to [0] on 2>, [2,3]> +!array2 = !ifrt.array, + #ifrt.sharding_param<1x2 to [0] on 2>, [4,5]> +func.func @requires_inputs_with_same_devices(%arg0: !array0, %arg1: !array1) + attributes {ifrt.function} { + // expected-error@+1 {{'ifrt.CopyArrays' op requires all input arrays to have the same devices}} + %0, %1, %ctrl = ifrt.CopyArrays(%arg0, %arg1) + : (!array0, !array1) -> (!array2, !array2) + return +} + +// ----- + +!array0 = !ifrt.array, + #ifrt.sharding_param<1x2 to [0] on 2>, [0,1]> +!array1 = 
!ifrt.array, + #ifrt.sharding_param<1x2 to [0] on 2>, [2,3]> +!array2 = !ifrt.array, + #ifrt.sharding_param<1x2 to [0] on 2>, [4,5]> +func.func @requires_outputs_with_same_devices(%arg0: !array0) + attributes {ifrt.function} { + // expected-error@+1 {{'ifrt.CopyArrays' op requires all output arrays to have the same devices}} + %0, %1, %ctrl = ifrt.CopyArrays(%arg0, %arg0) + : (!array0, !array0) -> (!array1, !array2) + return +} + +// ----- + +!array0 = !ifrt.array, + #ifrt.sharding_param<1x2 to [0] on 2>, [0,1], + memory_kind = "device"> +!array1 = !ifrt.array, + #ifrt.sharding_param<1x2 to [0] on 2>, [0,1], + memory_kind = "host"> +!array2 = !ifrt.array, + #ifrt.sharding_param<1x2 to [0] on 2>, [2,3]> +func.func @requires_inputs_with_same_memory_kind(%arg0: !array0, %arg1: !array1) + attributes {ifrt.function} { + // expected-error@+1 {{'ifrt.CopyArrays' op requires all input arrays to have the same memory kind}} + %0, %1, %ctrl = ifrt.CopyArrays(%arg0, %arg1) + : (!array0, !array1) -> (!array2, !array2) + return +} + +// ----- + +!array0 = !ifrt.array, + #ifrt.sharding_param<1x2 to [0] on 2>, [0,1]> +!array1 = !ifrt.array, + #ifrt.sharding_param<1x2 to [0] on 2>, [2,3], + memory_kind = "device"> +!array2 = !ifrt.array, + #ifrt.sharding_param<1x2 to [0] on 2>, [2,3], + memory_kind = "host"> +func.func @requires_outputs_with_same_memory_kind(%arg0: !array0) + attributes {ifrt.function} { + // expected-error@+1 {{'ifrt.CopyArrays' op requires all output arrays to have the same memory kind}} + %0, %1, %ctrl = ifrt.CopyArrays(%arg0, %arg0) + : (!array0, !array0) -> (!array1, !array2) + return +} diff --git a/third_party/xla/xla/python/ifrt/ir/transforms/ifrt_verify_donation_pass.cc b/third_party/xla/xla/python/ifrt/ir/transforms/ifrt_verify_donation_pass.cc index 63f8a19d0256a9..9801b842355879 100644 --- a/third_party/xla/xla/python/ifrt/ir/transforms/ifrt_verify_donation_pass.cc +++ b/third_party/xla/xla/python/ifrt/ir/transforms/ifrt_verify_donation_pass.cc @@ 
-103,21 +103,23 @@ void IfrtVerifyDonationPass::runOnOperation() { } return mlir::success(); }) - .Case([&](auto& op) { - if (op.getDonated()) { - for (const auto [idx, input] : - llvm::enumerate(op.getInputs())) { - if (!donated_values.insert(input).second) { - op.emitOpError() << "input #" << idx << " already donated."; - return mlir::failure(); - } - if (mlir::failed(VerifyIfInputAndDonated(op, input))) { - return mlir::failure(); + .Case( + [&](auto& op) { + if (op.getDonated()) { + for (const auto [idx, input] : + llvm::enumerate(op.getInputs())) { + if (!donated_values.insert(input).second) { + op.emitOpError() + << "input #" << idx << " already donated."; + return mlir::failure(); + } + if (mlir::failed(VerifyIfInputAndDonated(op, input))) { + return mlir::failure(); + } + } } - } - } - return mlir::success(); - }) + return mlir::success(); + }) .Default(mlir::success()); if (mlir::failed(result)) { From 6f72e32cd07ad580f1b546403457dc9bbbc12701 Mon Sep 17 00:00:00 2001 From: Adrian Kuegel Date: Thu, 20 Jun 2024 23:37:42 -0700 Subject: [PATCH 098/256] Replace EXPECT_OK with TF_EXPECT_OK EXPECT_OK doesn't exist in open source. 
PiperOrigin-RevId: 645280836 --- .../xla/xla/python/ifrt_proxy/server/BUILD | 1 + .../ifrt_proxy/server/ifrt_backend_test.cc | 3 +- third_party/xla/xla/service/gpu/BUILD | 1 + .../service/gpu/gemm_fusion_autotuner_test.cc | 2 +- .../gpu/ir_emitter_triton_mem_utils_test.cc | 3 +- third_party/xla/xla/service/gpu/model/BUILD | 1 + .../gpu/model/symbolic_tile_analysis_test.cc | 6 +- .../distributed_runtime/coordination/BUILD | 2 +- .../coordination/coordination_service_test.cc | 123 +++++++++--------- 9 files changed, 75 insertions(+), 67 deletions(-) diff --git a/third_party/xla/xla/python/ifrt_proxy/server/BUILD b/third_party/xla/xla/python/ifrt_proxy/server/BUILD index df4af174557b60..29a9f3173e304c 100644 --- a/third_party/xla/xla/python/ifrt_proxy/server/BUILD +++ b/third_party/xla/xla/python/ifrt_proxy/server/BUILD @@ -196,6 +196,7 @@ ifrt_proxy_cc_test( "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest_main", "@llvm-project//llvm:Support", + "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:protobuf", diff --git a/third_party/xla/xla/python/ifrt_proxy/server/ifrt_backend_test.cc b/third_party/xla/xla/python/ifrt_proxy/server/ifrt_backend_test.cc index 9cc1af0444f4e8..e03230a08ffa4a 100644 --- a/third_party/xla/xla/python/ifrt_proxy/server/ifrt_backend_test.cc +++ b/third_party/xla/xla/python/ifrt_proxy/server/ifrt_backend_test.cc @@ -69,6 +69,7 @@ #include "xla/test.h" #include "xla/tsl/concurrency/ref_count.h" #include "xla/xla_data.pb.h" +#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" #include "tsl/platform/protobuf.h" // IWYU pragma: keep @@ -963,7 +964,7 @@ TEST_P(IfrtBackendHandlerTest, CompileSuccess) { addressable_device_ids: [ 0, 1, 2, 3 ] fingerprint_value: "fingerprint" )pb"))); - EXPECT_OK(CheckFuture(response.ready_future_handle())); + 
TF_EXPECT_OK(CheckFuture(response.ready_future_handle())); } #endif diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index 98daa26d27e662..8bc8fd4315a260 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -718,6 +718,7 @@ cc_test( "@llvm-project//mlir:IR", "@llvm-project//mlir:NVVMDialect", "@llvm-project//mlir:Support", + "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:logging", "@triton//:TritonDialects", ], diff --git a/third_party/xla/xla/service/gpu/gemm_fusion_autotuner_test.cc b/third_party/xla/xla/service/gpu/gemm_fusion_autotuner_test.cc index 19b0d3877f3490..2168bc78711acd 100644 --- a/third_party/xla/xla/service/gpu/gemm_fusion_autotuner_test.cc +++ b/third_party/xla/xla/service/gpu/gemm_fusion_autotuner_test.cc @@ -616,7 +616,7 @@ ENTRY main { GetToolkitVersion(), fp8_rewrite); } - EXPECT_OK(HloTestBase::RunHloPass(&pipeline, module.get())); + TF_EXPECT_OK(HloTestBase::RunHloPass(&pipeline, module.get())); const bool is_at_least_hopper = std::holds_alternative( autotune_config.GetGpuComputeCapability()) && diff --git a/third_party/xla/xla/service/gpu/ir_emitter_triton_mem_utils_test.cc b/third_party/xla/xla/service/gpu/ir_emitter_triton_mem_utils_test.cc index 5362609440fc5e..be0e64581c3352 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter_triton_mem_utils_test.cc +++ b/third_party/xla/xla/service/gpu/ir_emitter_triton_mem_utils_test.cc @@ -43,6 +43,7 @@ limitations under the License. 
#include "xla/service/llvm_ir/llvm_util.h" #include "xla/shape_util.h" #include "xla/tests/hlo_test_base.h" +#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/logging.h" // IWYU pragma: keep #include "third_party/triton/include/triton/Dialect/Triton/IR/Dialect.h" #include "third_party/triton/include/triton/Dialect/Triton/IR/Types.h" @@ -91,7 +92,7 @@ CreateAndTileParameterHloInstruction(std::vector shape_sizes, auto tiled_hlo = TiledHloInstruction::Create( hlo.get(), tile_sizes, tile_strides, CreateAffineMap(tile_sizes, ctx)); - EXPECT_OK(tiled_hlo); + TF_EXPECT_OK(tiled_hlo); return std::make_pair(std::move(hlo), std::move(tiled_hlo.value())); } diff --git a/third_party/xla/xla/service/gpu/model/BUILD b/third_party/xla/xla/service/gpu/model/BUILD index cf14c86064b5b2..a35308c7f0d81b 100644 --- a/third_party/xla/xla/service/gpu/model/BUILD +++ b/third_party/xla/xla/service/gpu/model/BUILD @@ -706,6 +706,7 @@ xla_cc_test( "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest_main", "@llvm-project//mlir:IR", + "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:statusor", ], ) diff --git a/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis_test.cc b/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis_test.cc index 4ecc81135f2e07..279ada6aaa6267 100644 --- a/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis_test.cc +++ b/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis_test.cc @@ -40,6 +40,7 @@ limitations under the License. #include "xla/tests/hlo_test_base.h" #include "xla/tests/verified_hlo_module.h" #include "xla/util.h" +#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" namespace xla { @@ -462,7 +463,8 @@ ENTRY main { // Passing tile parameters that satisfy the constraints should let us compute // a TiledHloComputation. 
- EXPECT_OK(analysis->ParametersSatisfyConstraints(possible_tile_parameters)); + TF_EXPECT_OK( + analysis->ParametersSatisfyConstraints(possible_tile_parameters)); // Passing tile parameters that do not satisfy the constraints should result // in an error... @@ -470,7 +472,7 @@ ENTRY main { StatusIs(absl::StatusCode::kInvalidArgument)); // ... unless we pinky-promise (lie) that they satisfy the constraints ;) - EXPECT_OK(analysis->ComputeTiledHloInstructions( + TF_EXPECT_OK(analysis->ComputeTiledHloInstructions( impossible_tile_parameters, /*constraints_are_known_satisfied=*/true)); } diff --git a/third_party/xla/xla/tsl/distributed_runtime/coordination/BUILD b/third_party/xla/xla/tsl/distributed_runtime/coordination/BUILD index 61d6d0377f67bb..194c6891fb9b05 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/coordination/BUILD +++ b/third_party/xla/xla/tsl/distributed_runtime/coordination/BUILD @@ -121,8 +121,8 @@ tsl_cc_test( "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/time", + "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:env", - "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:random", "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:test", diff --git a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_test.cc b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_test.cc index 7d5dd06c5e56d7..3fcc8cb3fd0a36 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_test.cc +++ b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_test.cc @@ -36,6 +36,7 @@ limitations under the License. 
#include "xla/tsl/distributed_runtime/coordination/coordination_client.h" #include "xla/tsl/distributed_runtime/coordination/coordination_service_error_util.h" #include "xla/tsl/distributed_runtime/coordination/test_device.pb.h" +#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/env.h" #include "tsl/platform/random.h" #include "tsl/platform/status.h" @@ -398,7 +399,7 @@ TEST(CoordinationServiceTest, RegisterTask_AlreadyConnected_Succeeds) { const absl::Status status = coord_service->RegisterTask(task_0, /*incarnation=*/0); - EXPECT_OK(status) << status; + TF_EXPECT_OK(status) << status; } TEST(CoordinationServiceTest, @@ -500,17 +501,17 @@ TEST_F(CoordinateTwoTasksTest, InsertKeyValue_Duplicate_Fail) { EXPECT_TRUE(absl::IsAlreadyExists( coord_service_->InsertKeyValue("key0", "never_added"))); auto result = coord_service_->TryGetKeyValue("key0"); - EXPECT_OK(result.status()); + TF_EXPECT_OK(result.status()); EXPECT_EQ(result.value(), "original_value"); } TEST_F(CoordinateTwoTasksTest, InsertKeyValue_Duplicate_Overwrite) { EnableCoordinationService(); ASSERT_OK(coord_service_->InsertKeyValue("key0", "original_value")); - EXPECT_OK(coord_service_->InsertKeyValue("key0", "overwritten_value", - /*allow_overwrite=*/true)); + TF_EXPECT_OK(coord_service_->InsertKeyValue("key0", "overwritten_value", + /*allow_overwrite=*/true)); auto result = coord_service_->TryGetKeyValue("key0"); - EXPECT_OK(result.status()); + TF_EXPECT_OK(result.status()); EXPECT_EQ(result.value(), "overwritten_value"); } @@ -910,9 +911,9 @@ TEST_F(CoordinationBarrierTest, Barrier) { EXPECT_TRUE(n_0.HasBeenNotified()); EXPECT_TRUE(n_1.HasBeenNotified()); EXPECT_TRUE(n_2.HasBeenNotified()); - EXPECT_OK(barrier_status_0); - EXPECT_OK(barrier_status_1); - EXPECT_OK(barrier_status_2); + TF_EXPECT_OK(barrier_status_0); + TF_EXPECT_OK(barrier_status_1); + TF_EXPECT_OK(barrier_status_2); } TEST_F(CoordinationBarrierTest, BarrierWithSubsetOfTasks) { @@ -941,8 +942,8 @@ 
TEST_F(CoordinationBarrierTest, BarrierWithSubsetOfTasks) { // All listed tasks passed the barrier. EXPECT_TRUE(n_0.HasBeenNotified()); EXPECT_TRUE(n_1.HasBeenNotified()); - EXPECT_OK(barrier_status_0); - EXPECT_OK(barrier_status_1); + TF_EXPECT_OK(barrier_status_0); + TF_EXPECT_OK(barrier_status_1); } TEST_F(CoordinationBarrierTest, BarrierWithMismatchedTasks) { @@ -1017,8 +1018,8 @@ TEST_F(CoordinationBarrierTest, BarrierByNonParticipatingTaskThreeTasks) { n_1.WaitForNotification(); // Barrier should pass because only participating tasks have called it. - EXPECT_OK(barrier_status_0); - EXPECT_OK(barrier_status_1); + TF_EXPECT_OK(barrier_status_0); + TF_EXPECT_OK(barrier_status_1); // Task 2 unexpectedly calls a barrier that it is not participating in. GetCoordinationService()->BarrierAsync( @@ -1118,7 +1119,7 @@ TEST_F(CoordinationBarrierTest, BarrierCancelled) { GetCoordinationService()->CancelBarrier(barrier_id, GetTask(0)); EXPECT_TRUE(absl::IsCancelled(barrier_status)); - EXPECT_OK(cancelled_status); + TF_EXPECT_OK(cancelled_status); } TEST_F(CoordinationBarrierTest, CancelNonExistentBarrier_FutureBarrierFails) { @@ -1161,9 +1162,9 @@ TEST_F(CoordinationBarrierTest, CancelAfterBarrierHasPassed) { GetCoordinationService()->CancelBarrier(barrier_id, GetTask(0)); EXPECT_TRUE(absl::IsFailedPrecondition(cancelled_status)); - EXPECT_OK(barrier_status_0); - EXPECT_OK(barrier_status_1); - EXPECT_OK(barrier_status_2); + TF_EXPECT_OK(barrier_status_0); + TF_EXPECT_OK(barrier_status_1); + TF_EXPECT_OK(barrier_status_2); } TEST_F(CoordinationBarrierTest, PassedBarrierReturnsImmediately) { @@ -1209,10 +1210,10 @@ TEST_F(CoordinationBarrierTest, PassedBarrierReturnsImmediately) { EXPECT_TRUE(n1.HasBeenNotified()); EXPECT_TRUE(n2.HasBeenNotified()); EXPECT_TRUE(n_repeat.HasBeenNotified()); - EXPECT_OK(barrier_status_0); - EXPECT_OK(barrier_status_1); - EXPECT_OK(barrier_status_2); - EXPECT_OK(barrier_status_repeat); + TF_EXPECT_OK(barrier_status_0); + 
TF_EXPECT_OK(barrier_status_1); + TF_EXPECT_OK(barrier_status_2); + TF_EXPECT_OK(barrier_status_repeat); } TEST_F(CoordinationBarrierTest, BarrierFailsIfTaskIsAlreadyInError) { @@ -1287,28 +1288,28 @@ TEST_F(CoordinationBarrierTest, barrier_status_2 = s; n_2.Notify(); }); - EXPECT_OK(barrier_status_0); - EXPECT_OK(barrier_status_1); - EXPECT_OK(barrier_status_2); + TF_EXPECT_OK(barrier_status_0); + TF_EXPECT_OK(barrier_status_1); + TF_EXPECT_OK(barrier_status_2); } TEST_F(CoordinateTwoTasksTest, ResetAndRegisterAgain) { EnableCoordinationService(); - EXPECT_OK(coord_service_->RegisterTask(task_0_, incarnation_0_)); + TF_EXPECT_OK(coord_service_->RegisterTask(task_0_, incarnation_0_)); - EXPECT_OK(coord_service_->ResetTask(task_0_)); + TF_EXPECT_OK(coord_service_->ResetTask(task_0_)); // Task should be allowed to register again after being reset. - EXPECT_OK(coord_service_->RegisterTask(task_0_, incarnation_0_)); + TF_EXPECT_OK(coord_service_->RegisterTask(task_0_, incarnation_0_)); } TEST_F(CoordinateTwoTasksTest, Reset_HeartbeatsAreAcceptedForAGracePeriod) { EnableCoordinationService(); - EXPECT_OK(coord_service_->RegisterTask(task_0_, incarnation_0_)); + TF_EXPECT_OK(coord_service_->RegisterTask(task_0_, incarnation_0_)); - EXPECT_OK(coord_service_->ResetTask(task_0_)); + TF_EXPECT_OK(coord_service_->ResetTask(task_0_)); // Heartbeat should be allowed for a short grace period after reset. - EXPECT_OK(coord_service_->RecordHeartbeat(task_0_, incarnation_0_)); + TF_EXPECT_OK(coord_service_->RecordHeartbeat(task_0_, incarnation_0_)); // Heartbeat failure should be triggered for disconnected task after grace // period. 
@@ -1321,7 +1322,7 @@ TEST_F(CoordinateTwoTasksTest, Reset_HeartbeatsAreAcceptedForAGracePeriod) { TEST_F(CoordinateTwoTasksTest, Reset_FailsOngoingBarrier) { EnableCoordinationService(/*has_service_to_client_connection=*/true, /*enable_shutdown_barrier=*/false); - EXPECT_OK(coord_service_->RegisterTask(task_0_, incarnation_0_)); + TF_EXPECT_OK(coord_service_->RegisterTask(task_0_, incarnation_0_)); absl::Status barrier_status; absl::Notification barrier_n; coord_service_->BarrierAsync("ongoing_barrier", absl::InfiniteDuration(), @@ -1332,7 +1333,7 @@ TEST_F(CoordinateTwoTasksTest, Reset_FailsOngoingBarrier) { barrier_n.Notify(); }); - EXPECT_OK(coord_service_->ResetTask(task_0_)); + TF_EXPECT_OK(coord_service_->ResetTask(task_0_)); // Ongoing barrier should fail with error after shutdown. EXPECT_TRUE(barrier_n.HasBeenNotified()); @@ -1342,17 +1343,17 @@ TEST_F(CoordinateTwoTasksTest, Reset_FailsOngoingBarrier) { TEST_F(CoordinateTwoTasksTest, Shutdown_HeartbeatsAreAcceptedForAGracePeriod) { EnableCoordinationService(/*has_service_to_client_connection=*/true, /*enable_shutdown_barrier=*/false); - EXPECT_OK(coord_service_->RegisterTask(task_0_, incarnation_0_)); + TF_EXPECT_OK(coord_service_->RegisterTask(task_0_, incarnation_0_)); absl::Notification n; coord_service_->ShutdownTaskAsync(task_0_, [&n](absl::Status s) { - EXPECT_OK(s); + TF_EXPECT_OK(s); n.Notify(); }); n.WaitForNotification(); // Heartbeat should be allowed for a short grace period after shutdown. - EXPECT_OK(coord_service_->RecordHeartbeat(task_0_, incarnation_0_)); + TF_EXPECT_OK(coord_service_->RecordHeartbeat(task_0_, incarnation_0_)); // Heartbeat failure should be triggered for disconnected task after grace // period. 
@@ -1365,7 +1366,7 @@ TEST_F(CoordinateTwoTasksTest, Shutdown_HeartbeatsAreAcceptedForAGracePeriod) { TEST_F(CoordinateTwoTasksTest, Shutdown_FailsOngoingBarrier) { EnableCoordinationService(/*has_service_to_client_connection=*/true, /*enable_shutdown_barrier=*/false); - EXPECT_OK(coord_service_->RegisterTask(task_0_, incarnation_0_)); + TF_EXPECT_OK(coord_service_->RegisterTask(task_0_, incarnation_0_)); absl::Status barrier_status; absl::Notification barrier_n; coord_service_->BarrierAsync("ongoing_barrier", absl::InfiniteDuration(), @@ -1378,7 +1379,7 @@ TEST_F(CoordinateTwoTasksTest, Shutdown_FailsOngoingBarrier) { absl::Notification shutdown_n; coord_service_->ShutdownTaskAsync(task_0_, [&shutdown_n](absl::Status s) { - EXPECT_OK(s); + TF_EXPECT_OK(s); shutdown_n.Notify(); }); shutdown_n.WaitForNotification(); @@ -1391,8 +1392,8 @@ TEST_F(CoordinateTwoTasksTest, Shutdown_FailsOngoingBarrier) { TEST_F(CoordinateTwoTasksTest, ShutdownWithBarrier_BarrierSucceeds) { EnableCoordinationService(/*has_service_to_client_connection=*/true, /*enable_shutdown_barrier=*/true); - EXPECT_OK(coord_service_->RegisterTask(task_0_, incarnation_0_)); - EXPECT_OK(coord_service_->RegisterTask(task_1_, incarnation_1_)); + TF_EXPECT_OK(coord_service_->RegisterTask(task_0_, incarnation_0_)); + TF_EXPECT_OK(coord_service_->RegisterTask(task_1_, incarnation_1_)); absl::Status barrier_status; absl::Status barrier_status_2; @@ -1401,22 +1402,22 @@ TEST_F(CoordinateTwoTasksTest, ShutdownWithBarrier_BarrierSucceeds) { coord_service_->ShutdownTaskAsync( task_1_, [&barrier_status_2](absl::Status s) { barrier_status_2 = s; }); - EXPECT_OK(barrier_status); - EXPECT_OK(barrier_status_2); + TF_EXPECT_OK(barrier_status); + TF_EXPECT_OK(barrier_status_2); // Confirm that both tasks have disconnected. // Note: this should not happen in prod where RegisterTask() is called after // Shutdown(), which is prevented by agent-side logic. 
- EXPECT_OK(coord_service_->RegisterTask(task_0_, incarnation_0_)); - EXPECT_OK(coord_service_->RegisterTask(task_1_, incarnation_1_)); + TF_EXPECT_OK(coord_service_->RegisterTask(task_0_, incarnation_0_)); + TF_EXPECT_OK(coord_service_->RegisterTask(task_1_, incarnation_1_)); } TEST_F(CoordinateTwoTasksTest, ShutdownWithBarrier_BarrierFails_TaskDisconnectsOtherTaskIsAlerted) { EnableCoordinationService(/*has_service_to_client_connection=*/true, /*enable_shutdown_barrier=*/true); - EXPECT_OK(coord_service_->RegisterTask(task_0_, incarnation_0_)); - EXPECT_OK(coord_service_->RegisterTask(task_1_, incarnation_1_)); + TF_EXPECT_OK(coord_service_->RegisterTask(task_0_, incarnation_0_)); + TF_EXPECT_OK(coord_service_->RegisterTask(task_1_, incarnation_1_)); absl::Status barrier_status; absl::Notification n; @@ -1432,7 +1433,7 @@ TEST_F(CoordinateTwoTasksTest, // Confirm that task_0_ has disconnected. // Note: this should not happen in prod where RegisterTask() is called after // Shutdown(), which is prevented by agent-side logic. - EXPECT_OK(coord_service_->RegisterTask(task_0_, incarnation_0_)); + TF_EXPECT_OK(coord_service_->RegisterTask(task_0_, incarnation_0_)); // Other task is alerted that shutdown has been initiated without it. 
absl::Status other_task_status = client_1_.GetStatus(); @@ -1443,8 +1444,8 @@ TEST_F(CoordinateTwoTasksTest, ShutdownWithBarrier_BarrierFailsWithoutClientConnection_ServiceStops) { EnableCoordinationService(/*has_service_to_client_connection=*/false, /*enable_shutdown_barrier=*/true); - EXPECT_OK(coord_service_->RegisterTask(task_0_, incarnation_0_)); - EXPECT_OK(coord_service_->RegisterTask(task_1_, incarnation_1_)); + TF_EXPECT_OK(coord_service_->RegisterTask(task_0_, incarnation_0_)); + TF_EXPECT_OK(coord_service_->RegisterTask(task_1_, incarnation_1_)); absl::Status barrier_status; absl::Notification n; @@ -1550,8 +1551,8 @@ TEST_F(CoordinateTwoTasksTest, UnrecoverableTaskPropagatesError) { /*enable_shutdown_barrier=*/false, /*set_worker_job_recoverable=*/false); - EXPECT_OK(coord_service_->RegisterTask(task_0_, incarnation_0_)); - EXPECT_OK(coord_service_->RegisterTask(task_1_, incarnation_1_)); + TF_EXPECT_OK(coord_service_->RegisterTask(task_0_, incarnation_0_)); + TF_EXPECT_OK(coord_service_->RegisterTask(task_1_, incarnation_1_)); ASSERT_OK(coord_service_->ReportTaskError(task_0_, absl::InternalError("test_error"))); @@ -1567,8 +1568,8 @@ TEST_F(CoordinateTwoTasksTest, RecoverableTaskWillNotPropagateError) { /*enable_shutdown_barrier=*/false, /*set_worker_job_recoverable=*/true); - EXPECT_OK(coord_service_->RegisterTask(task_0_, incarnation_0_)); - EXPECT_OK(coord_service_->RegisterTask(task_1_, incarnation_1_)); + TF_EXPECT_OK(coord_service_->RegisterTask(task_0_, incarnation_0_)); + TF_EXPECT_OK(coord_service_->RegisterTask(task_1_, incarnation_1_)); ASSERT_OK(coord_service_->ReportTaskError(task_0_, absl::InternalError("test_error"))); @@ -1577,7 +1578,7 @@ TEST_F(CoordinateTwoTasksTest, RecoverableTaskWillNotPropagateError) { coord_service_->RecordHeartbeat(task_0_, incarnation_0_))); // Since no error propagation for recoverable tasks, other tasks should work // as normal. 
- EXPECT_OK(client_1_.GetStatus()); + TF_EXPECT_OK(client_1_.GetStatus()); } TEST_F(CoordinateTwoTasksTest, @@ -1586,8 +1587,8 @@ TEST_F(CoordinateTwoTasksTest, /*enable_shutdown_barrier=*/false, /*set_worker_job_recoverable=*/true); - EXPECT_OK(coord_service_->RegisterTask(task_0_, incarnation_0_)); - EXPECT_OK(coord_service_->RegisterTask(task_1_, incarnation_1_)); + TF_EXPECT_OK(coord_service_->RegisterTask(task_0_, incarnation_0_)); + TF_EXPECT_OK(coord_service_->RegisterTask(task_1_, incarnation_1_)); ASSERT_OK(coord_service_->ReportTaskError(task_0_, absl::InternalError("test_error"))); @@ -1596,13 +1597,13 @@ TEST_F(CoordinateTwoTasksTest, coord_service_->RecordHeartbeat(task_0_, incarnation_0_))); // Since no error propagation for recoverable tasks, other tasks should work // as normal. - EXPECT_OK(client_1_.GetStatus()); + TF_EXPECT_OK(client_1_.GetStatus()); // Reset and register the error task again, both tasks should be healthy. - EXPECT_OK(coord_service_->ResetTask(task_0_)); - EXPECT_OK(coord_service_->RegisterTask(task_0_, incarnation_0_new_)); - EXPECT_OK(coord_service_->RecordHeartbeat(task_0_, incarnation_0_new_)); - EXPECT_OK(client_1_.GetStatus()); + TF_EXPECT_OK(coord_service_->ResetTask(task_0_)); + TF_EXPECT_OK(coord_service_->RegisterTask(task_0_, incarnation_0_new_)); + TF_EXPECT_OK(coord_service_->RecordHeartbeat(task_0_, incarnation_0_new_)); + TF_EXPECT_OK(client_1_.GetStatus()); } TEST_F(CoordinateTwoTasksTest, UnavailableTaskCanReconnect) { @@ -1611,11 +1612,11 @@ TEST_F(CoordinateTwoTasksTest, UnavailableTaskCanReconnect) { /*set_worker_job_recoverable=*/false, /*allow_new_incarnation_to_reconnect=*/true); - EXPECT_OK(coord_service_->RegisterTask(task_0_, incarnation_0_)); + TF_EXPECT_OK(coord_service_->RegisterTask(task_0_, incarnation_0_)); ASSERT_OK(coord_service_->ReportTaskError( task_0_, MakeCoordinationError(absl::UnavailableError("test_error")))); - EXPECT_OK(coord_service_->RegisterTask(task_0_, incarnation_0_new_)); + 
TF_EXPECT_OK(coord_service_->RegisterTask(task_0_, incarnation_0_new_)); } } // namespace tsl From 6fd81ed51e9d781d9bfbba3c40ec345217c60c44 Mon Sep 17 00:00:00 2001 From: Johannes Reifferscheid Date: Fri, 21 Jun 2024 01:00:28 -0700 Subject: [PATCH 099/256] Simplify div simplification. We have a rewrite rule that increases the number of divisions, which was needed to undo some upstream `mod` simplifications. We don't use the upstream simplifier anymore, so it can be removed, if we also make the division simplification a bit smarter. The removed test can be added back with another simplification rule (basically the one I removed here in reverse), but I'll do that separately. PiperOrigin-RevId: 645299494 --- .../service/gpu/fusions/input_slices_test.cc | 2 +- .../xla/service/gpu/fusions/loop_mlir_test.cc | 21 +++--- .../xla/xla/service/gpu/fusions/loop_test.cc | 14 ++-- .../fusions/mlir/tests/simplify_affine.mlir | 4 +- .../service/gpu/fusions/scatter_mlir_test.cc | 11 +-- .../xla/service/gpu/fusions/scatter_test.cc | 11 +-- .../xla/xla/service/gpu/model/indexing_map.cc | 67 ++++++------------- .../service/gpu/model/indexing_map_test.cc | 18 +---- 8 files changed, 53 insertions(+), 95 deletions(-) diff --git a/third_party/xla/xla/service/gpu/fusions/input_slices_test.cc b/third_party/xla/xla/service/gpu/fusions/input_slices_test.cc index 8d5b479f3221e1..9c32e9e035ebb1 100644 --- a/third_party/xla/xla/service/gpu/fusions/input_slices_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/input_slices_test.cc @@ -80,7 +80,7 @@ TEST_F(InputSlicesTest, ThreadIndexing) { (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> (0, ((bl_x * 128 + th_x) floordiv 3) mod 2, (bl_x * 128 + th_x) mod 3, - ((bl_x * 64 + th_x floordiv 2) floordiv 3) mod 5) + ((bl_x * 128 + th_x) floordiv 6) mod 5) domain: th_x in [0, 127] th_y in [0, 0] diff --git a/third_party/xla/xla/service/gpu/fusions/loop_mlir_test.cc b/third_party/xla/xla/service/gpu/fusions/loop_mlir_test.cc index 
6af3fb0a9e173d..6c5d4464396800 100644 --- a/third_party/xla/xla/service/gpu/fusions/loop_mlir_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/loop_mlir_test.cc @@ -54,9 +54,9 @@ TEST_F(MlirLoopFusionTest, ThreadId_IndexingUnrolled) { EXPECT_THAT(thread_id_to_output_indexing->ToString(thread_id_printer_), MatchIndexingString(R"( (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> ( - (((bl_x * 16 + th_x floordiv 8) floordiv 3 + chunk_id * 5376) floordiv 625) mod 100, - (((bl_x * 128 + th_x) floordiv 3 + chunk_id * 43008) floordiv 25) mod 200, - (th_x * 4 + bl_x * 512 + chunk_id * 516096) mod 300 + unroll_id + ((bl_x * 128 + th_x + chunk_id * 129024) floordiv 15000) mod 100, + ((bl_x * 128 + th_x + chunk_id * 129024) floordiv 75) mod 200, + (th_x * 4 + bl_x * 512 + chunk_id * 516096) mod 300 + unroll_id ) domain: th_x in [0, 127] @@ -148,9 +148,10 @@ TEST_F(MlirLoopFusionTest, ThreadId_Broadcast) { EXPECT_THAT(thread_id_to_output_indexing->ToString(thread_id_printer_), MatchIndexingString(R"( (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> ( - ((bl_x * 16 + th_x floordiv 8) floordiv 75) mod 10, - ((bl_x * 64 + th_x floordiv 2) floordiv 15) mod 20, - (bl_x * 128 + th_x) mod 30) + ((bl_x * 128 + th_x) floordiv 600) mod 10, + ((bl_x * 128 + th_x) floordiv 30) mod 20, + (bl_x * 128 + th_x) mod 30 + ) domain: th_x in [0, 127] th_y in [0, 0] @@ -166,8 +167,8 @@ TEST_F(MlirLoopFusionTest, ThreadId_Broadcast) { /*root_index=*/0, /*hero_operand_index=*/0, &mlir_context_); EXPECT_THAT(thread_id_to_input_indexing->ToString(thread_id_printer_), MatchIndexingString(R"( - (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> ( - ((bl_x * 64 + th_x floordiv 2) floordiv 15) mod 20) + (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> + (((bl_x * 128 + th_x) floordiv 30) mod 20) domain: th_x in [0, 127] th_y in [0, 0] @@ -196,8 +197,8 @@ TEST_F(MlirLoopFusionTest, Constant_Broadcast) { )"; TF_ASSERT_OK(EmitAndCheckIR(kHloString, R"( // 
CHECK: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d1 * 1024 + d0)> - // CHECK: #[[MAP1:.*]] = affine_map<(d0, d1) -> (((d1 * 4 + d0 floordiv 256) floordiv 3) mod 2)> - // CHECK: #[[MAP2:.*]] = affine_map<(d0, d1) -> (((d1 * 64 + d0 floordiv 16) floordiv 3) mod 16)> + // CHECK: #[[MAP1:.*]] = affine_map<(d0, d1) -> (((d1 * 1024 + d0) floordiv 768) mod 2)> + // CHECK: #[[MAP2:.*]] = affine_map<(d0, d1) -> (((d1 * 1024 + d0) floordiv 48) mod 16)> // CHECK: #[[MAP3:.*]] = affine_map<(d0, d1) -> ((d1 * 1024 + d0) mod 48)> // CHECK: func.func @fused_computation(%[[ARG0:.*]]: tensor<2x16x48xbf16> // CHECK: %[[UPPER_BOUND:.*]] = arith.constant 1535 : index diff --git a/third_party/xla/xla/service/gpu/fusions/loop_test.cc b/third_party/xla/xla/service/gpu/fusions/loop_test.cc index 6262b130c207b1..4337260ca040e0 100644 --- a/third_party/xla/xla/service/gpu/fusions/loop_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/loop_test.cc @@ -88,9 +88,9 @@ TEST_F(LoopTest, ThreadIndexingUnrolled) { EXPECT_THAT(thread_id_to_output_indexing->ToString(printer_), MatchIndexingString(R"( (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> ( - (((bl_x * 16 + th_x floordiv 8) floordiv 3 + chunk_id * 5376) floordiv 625) mod 100, - (((bl_x * 128 + th_x) floordiv 3 + chunk_id * 43008) floordiv 25) mod 200, - (th_x * 4 + bl_x * 512 + chunk_id * 516096) mod 300 + unroll_id + ((bl_x * 128 + th_x + chunk_id * 129024) floordiv 15000) mod 100, + ((bl_x * 128 + th_x + chunk_id * 129024) floordiv 75) mod 200, + (th_x * 4 + bl_x * 512 + chunk_id * 516096) mod 300 + unroll_id ) domain: th_x in [0, 127] @@ -183,8 +183,8 @@ TEST_F(LoopTest, Broadcast) { EXPECT_THAT(thread_id_to_output_indexing->ToString(printer_), MatchIndexingString(R"( (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> ( - ((bl_x * 16 + th_x floordiv 8) floordiv 75) mod 10, - ((bl_x * 64 + th_x floordiv 2) floordiv 15) mod 20, + ((bl_x * 128 + th_x) floordiv 600) mod 10, + ((bl_x * 128 + th_x) floordiv 30) mod 20, 
(bl_x * 128 + th_x) mod 30) domain: th_x in [0, 127] @@ -202,8 +202,8 @@ TEST_F(LoopTest, Broadcast) { /*root_index=*/0, /*hero_operand_index=*/0, &mlir_context_); EXPECT_THAT(thread_id_to_input_indexing->ToString(printer_), MatchIndexingString(R"( - (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> ( - ((bl_x * 64 + th_x floordiv 2) floordiv 15) mod 20) + (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> + (((bl_x * 128 + th_x) floordiv 30) mod 20) domain: th_x in [0, 127] th_y in [0, 0] diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/tests/simplify_affine.mlir b/third_party/xla/xla/service/gpu/fusions/mlir/tests/simplify_affine.mlir index 3d96446cbdd6b3..ec1a726da9db13 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/tests/simplify_affine.mlir +++ b/third_party/xla/xla/service/gpu/fusions/mlir/tests/simplify_affine.mlir @@ -7,7 +7,7 @@ func.func @op_and_for_ranges(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: !llvm.pt %0 = gpu.thread_id x {xla.range = [0 : index, 127 : index]} %1 = gpu.block_id x {xla.range = [0 : index, 3071 : index]} scf.for %arg3 = %c0 to %c4 step %c1 { - %2 = affine.apply affine_map<()[s0, s1, s2] -> (s0 * 512 + s1 * 4 + s2 - ((s1 * 4 + s2) floordiv 256) * 256 + (s1 floordiv 64) * 256 - ((s0 * 2 + s1 floordiv 64) floordiv 3) * 768 + ((s0 * 128 + s1) floordiv 192) * 768 - (((s0 * 128 + s1) floordiv 192) floordiv 1024) * 786432 + (s0 floordiv 1536) * 786432)>()[%1, %0, %arg3] + %2 = affine.apply affine_map<()[s0, s1, s2] -> (s0 * 512 + s1 * 4 + s2 + (s1 floordiv 128) + (s2 floordiv 4))>()[%1, %0, %arg3] %3 = arith.index_castui %2 : index to i64 %4 = llvm.getelementptr %arg0[%3] : (!llvm.ptr, i64) -> !llvm.ptr, f32 %5 = llvm.load %4 invariant : !llvm.ptr -> f32 @@ -62,7 +62,7 @@ func.func @op_and_for_ranges(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: !llvm.pt %0 = gpu.thread_id x %1 = gpu.block_id x scf.for %i = %c0 to %c4 step %c1 { - %2 = xla_gpu.apply_indexing affine_map<()[s0, s1, s2] -> (s0 * 512 + s1 * 
4 + s2 - ((s1 * 4 + s2) floordiv 256) * 256 + (s1 floordiv 64) * 256 - ((s0 * 2 + s1 floordiv 64) floordiv 3) * 768 + ((s0 * 128 + s1) floordiv 192) * 768 - (((s0 * 128 + s1) floordiv 192) floordiv 1024) * 786432 + (s0 floordiv 1536) * 786432)> + %2 = xla_gpu.apply_indexing affine_map<()[s0, s1, s2] -> (s0 * 512 + s1 * 4 + s2 + (s1 floordiv 128) + (s2 floordiv 4))> [%1 in [0, 3071], %0 in [0, 127], %i in [0, 3]] %3 = arith.index_castui %2 : index to i64 %4 = llvm.getelementptr %arg0[%3] : (!llvm.ptr, i64) -> !llvm.ptr, f32 diff --git a/third_party/xla/xla/service/gpu/fusions/scatter_mlir_test.cc b/third_party/xla/xla/service/gpu/fusions/scatter_mlir_test.cc index a40e8e037bb0af..e0cd913f00d341 100644 --- a/third_party/xla/xla/service/gpu/fusions/scatter_mlir_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/scatter_mlir_test.cc @@ -81,9 +81,10 @@ TEST_F(MlirScatterFusionTest, ThreadIdIndexing) { constexpr auto kUpdatesIndexing = R"( (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> ( - ((bl_x * 16 + th_x floordiv 8) floordiv 25) mod 42, - ((bl_x * 32 + th_x floordiv 4) floordiv 5) mod 10, - (bl_x * 128 + th_x) mod 20) + ((bl_x * 128 + th_x) floordiv 200) mod 42, + ((bl_x * 128 + th_x) floordiv 20) mod 10, + (bl_x * 128 + th_x) mod 20 + ) domain: th_x in [0, 127] th_y in [0, 0] @@ -121,8 +122,8 @@ TEST_F(MlirScatterFusionTest, ThreadIdIndexing) { MatchIndexingString(kUpdatesIndexing)); constexpr auto kIndicesIndexing = R"( - (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id, index_id] -> ( - ((bl_x * 16 + th_x floordiv 8) floordiv 25) mod 42, 0) + (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id, index_id] -> + (((bl_x * 128 + th_x) floordiv 200) mod 42, 0) domain: th_x in [0, 127] th_y in [0, 0] diff --git a/third_party/xla/xla/service/gpu/fusions/scatter_test.cc b/third_party/xla/xla/service/gpu/fusions/scatter_test.cc index f150e1d412cb69..b765194e3f868b 100644 --- a/third_party/xla/xla/service/gpu/fusions/scatter_test.cc +++ 
b/third_party/xla/xla/service/gpu/fusions/scatter_test.cc @@ -145,9 +145,10 @@ TEST_F(ScatterFusionTest, ThreadIdIndexing) { constexpr auto kUpdatesIndexing = R"( (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> ( - ((bl_x * 16 + th_x floordiv 8) floordiv 25) mod 42, - ((bl_x * 32 + th_x floordiv 4) floordiv 5) mod 10, - (bl_x * 128 + th_x) mod 20) + ((bl_x * 128 + th_x) floordiv 200) mod 42, + ((bl_x * 128 + th_x) floordiv 20) mod 10, + (bl_x * 128 + th_x) mod 20 + ) domain: th_x in [0, 127] th_y in [0, 0] @@ -185,8 +186,8 @@ TEST_F(ScatterFusionTest, ThreadIdIndexing) { MatchIndexingString(kUpdatesIndexing)); constexpr auto kIndicesIndexing = R"( - (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id, index_id] -> ( - ((bl_x * 16 + th_x floordiv 8) floordiv 25) mod 42, 0) + (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id, index_id] -> + (((bl_x * 128 + th_x) floordiv 200) mod 42, 0) domain: th_x in [0, 127] th_y in [0, 0] diff --git a/third_party/xla/xla/service/gpu/model/indexing_map.cc b/third_party/xla/xla/service/gpu/model/indexing_map.cc index 3b072b91b77254..16e10ac756e05b 100644 --- a/third_party/xla/xla/service/gpu/model/indexing_map.cc +++ b/third_party/xla/xla/service/gpu/model/indexing_map.cc @@ -232,6 +232,8 @@ AffineExpr AffineExprSimplifier::RewriteFloorDiv(AffineBinaryOpExpr div) { auto lhs = range_evaluator_->ComputeExpressionRange(lhs_simplified); auto rhs = range_evaluator_->ComputeExpressionRange(div.getRHS()); + // TODO(jreiffers): Split this function into multiple (one for each rewrite + // rule). 
if (0 <= lhs.lower && lhs.upper < rhs.lower) { return getAffineConstantExpr(0, mlir_context); } @@ -273,7 +275,6 @@ AffineExpr AffineExprSimplifier::RewriteFloorDiv(AffineBinaryOpExpr div) { } } - Interval no_multiplier_range{0, 0}; AffineExpr zero = getAffineConstantExpr(0, mlir_context); AffineExpr extracted = zero; auto new_dividend = RemoveSummands(lhs_simplified, [&](AffineExpr expr) { @@ -296,61 +297,31 @@ AffineExpr AffineExprSimplifier::RewriteFloorDiv(AffineBinaryOpExpr div) { return extracted; } - std::optional multiplier_gcd = std::nullopt; - // The maximum GCD of (divisor, any multiplier inside the div). - int64_t max_remaining_multiplier_gcd = 1; + // The gcd of all multipliers and the dividend. + int64_t multiplier_divisor_gcd = d; + Interval no_multiplier_range{0, 0}; VisitSummands(new_dividend, [&](AffineExpr summand) { if (auto multiplier = GetConstantRhs(summand, AffineExprKind::Mul)) { - if (multiplier_gcd.has_value()) { - multiplier_gcd = std::gcd(*multiplier_gcd, *multiplier); - } else { - multiplier_gcd = multiplier; - } - - max_remaining_multiplier_gcd = - std::max(max_remaining_multiplier_gcd, std::gcd(*multiplier, d)); + multiplier_divisor_gcd = std::gcd(multiplier_divisor_gcd, *multiplier); } else { - auto range = range_evaluator_->ComputeExpressionRange(summand); - no_multiplier_range.lower += range.lower; - no_multiplier_range.upper += range.upper; + no_multiplier_range = no_multiplier_range + + range_evaluator_->ComputeExpressionRange(summand); } }); - if (multiplier_gcd.has_value()) { - if ((d % *multiplier_gcd) == 0) { - if (no_multiplier_range.lower >= 0 && - no_multiplier_range.upper < *multiplier_gcd) { - // Remove everything that doesn't have a multiplier. - new_dividend = RemoveSummands(new_dividend, [&](AffineExpr expr) { - auto mult = GetConstantRhs(expr, AffineExprKind::Mul); - return !mult.has_value(); - }); + // Consider an expression like: `(x * 6 + y) / 9`. 
if the range of `y` is at + // most `[0; 3)`, we can rewrite it to `(x * 2) / 3`, since `y` can't affect + // the result. + if (no_multiplier_range.lower >= 0 && + no_multiplier_range.upper < multiplier_divisor_gcd) { + auto new_new_dividend = zero; + VisitSummands(new_dividend, [&](AffineExpr summand) { + if (auto mult = GetConstantRhs(summand, AffineExprKind::Mul)) { + new_new_dividend = new_new_dividend + + (GetLhs(summand) * (*mult / multiplier_divisor_gcd)); } - } - } - - // If we have a gcd > 1, we can split the div into two: - // (x * 128 + y) // 192 -> (x * 2 + y // 64) // 3 - // TODO(jreiffers): This is currently required for some simplifications, but - // it increases the number of divisions, which is not really a simplification. - // See if we can avoid this rewrite. - if (max_remaining_multiplier_gcd > 1) { - AffineExpr partially_extracted = getAffineConstantExpr(0, mlir_context); - new_dividend = RemoveSummands(new_dividend, [&](AffineExpr expr) { - if (auto multiplier = GetConstantRhs(expr, AffineExprKind::Mul); - multiplier && (*multiplier > 0) && - ((*multiplier % max_remaining_multiplier_gcd) == 0)) { - partially_extracted = - partially_extracted + - GetLhs(expr) * (*multiplier / max_remaining_multiplier_gcd); - // Remove from dividend. - return true; - } - return false; }); - return extracted + (partially_extracted + - new_dividend.floorDiv(max_remaining_multiplier_gcd)) - .floorDiv(d / max_remaining_multiplier_gcd); + return new_new_dividend.floorDiv(d / multiplier_divisor_gcd) + extracted; } // If we removed nothing, return the original division. 
diff --git a/third_party/xla/xla/service/gpu/model/indexing_map_test.cc b/third_party/xla/xla/service/gpu/model/indexing_map_test.cc index 0cb86a85c702db..882ab81543923d 100644 --- a/third_party/xla/xla/service/gpu/model/indexing_map_test.cc +++ b/third_party/xla/xla/service/gpu/model/indexing_map_test.cc @@ -605,7 +605,7 @@ TEST_F(IndexingMapTest, ParseAffineMap(serialized_map, &mlir_context_), {10, 10, 10}, {}); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( - (d0, d1, d2) -> (d0 * 2 + (d2 floordiv 4 + d1) floordiv 2, + (d0, d1, d2) -> (d0 * 2 + (d1 * 4 + d2) floordiv 8, (d1 * 4 + d2) mod 8) domain: d0 in [0, 9] @@ -715,22 +715,6 @@ TEST_F(IndexingMapTest, AffineMapSimplification_DivsInSequence) { )")); } -TEST_F(IndexingMapTest, AffineMapSimplification_DivGcdGreater1) { - auto serialized_map = - "()[s0, s1, s2] -> (s0 * 512 + s1 * 4 + s2 - ((s0 * 2 + s1 floordiv 64) " - "floordiv 3) * 768 + ((s0 * 128 + s1) floordiv 192) * 768)"; - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap(serialized_map, &mlir_context_), {}, {1234, 128, 4}); - EXPECT_TRUE(indexing_map.Simplify()); - EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( - ()[s0, s1, s2] -> (s0 * 512 + s1 * 4 + s2) - domain: - s0 in [0, 1233] - s1 in [0, 127] - s2 in [0, 3] - )")); -} - TEST_F(IndexingMapTest, AffineMapSimplification_ExtractFromMod) { auto serialized_map = "()[s0, s1, s2, s3] -> ((s0 * 458752 + s1 + s2 * 4 + s3 * 512) mod " From 003e21fa01552e6471940f665d17c997bfbcc92a Mon Sep 17 00:00:00 2001 From: Kyle Lucke Date: Fri, 21 Jun 2024 01:07:31 -0700 Subject: [PATCH 100/256] Stop using xla/statusor.h now that it just contains an alias for absl::Status. In some situations, this meant also changing unrelated files to directly include tsl/platform/statusor.h to get the definitions for TF_ASSIGN_OR_RETURN, etc., where they were getting transitively included for free. 
PiperOrigin-RevId: 645301329 --- tensorflow/compiler/jit/tests/BUILD | 2 +- .../compiler/jit/tests/auto_clustering_test_helper.cc | 2 +- .../compiler/jit/tests/auto_clustering_test_helper.h | 2 +- tensorflow/core/tpu/BUILD | 2 +- tensorflow/core/tpu/tpu_execute.cc | 2 +- tensorflow/core/tpu/tpu_execute.h | 2 +- third_party/xla/xla/hlo/utils/BUILD | 2 +- third_party/xla/xla/hlo/utils/hlo_live_range.cc | 2 +- third_party/xla/xla/hlo/utils/hlo_live_range.h | 2 +- third_party/xla/xla/pjrt/cpu/BUILD | 3 +-- third_party/xla/xla/pjrt/cpu/abstract_tfrt_cpu_buffer.h | 2 +- third_party/xla/xla/pjrt/cpu/cpu_client.h | 1 - third_party/xla/xla/tools/hlo_bisect/BUILD | 4 ++-- third_party/xla/xla/tools/hlo_bisect/hlo_bisect_state.h | 2 +- .../xla/xla/tools/hlo_bisect/hlo_bisect_state_test.cc | 2 +- third_party/xla/xla/tools/multihost_hlo_runner/BUILD | 8 ++++---- .../tools/multihost_hlo_runner/functional_hlo_runner.h | 2 +- .../multihost_hlo_runner/functional_hlo_runner_test.cc | 2 +- third_party/xla/xla/translate/hlo_to_mhlo/BUILD | 4 ++-- .../xla/xla/translate/hlo_to_mhlo/custom_call_importer.h | 2 +- .../xla/xla/translate/hlo_to_mhlo/hlo_function_importer.h | 2 +- third_party/xla/xla/translate/mhlo_to_hlo/BUILD | 4 ++-- .../xla/xla/translate/mhlo_to_hlo/attribute_exporter.h | 2 +- third_party/xla/xla/translate/mhlo_to_hlo/layout_util.h | 2 +- 24 files changed, 29 insertions(+), 31 deletions(-) diff --git a/tensorflow/compiler/jit/tests/BUILD b/tensorflow/compiler/jit/tests/BUILD index 0c3c4986e44a88..ed7f66ee50c352 100644 --- a/tensorflow/compiler/jit/tests/BUILD +++ b/tensorflow/compiler/jit/tests/BUILD @@ -28,9 +28,9 @@ cc_library( "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/tools/optimization:optimization_pass_runner_lib", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@local_xla//xla:status_macros", - "@local_xla//xla:statusor", ], ) diff --git a/tensorflow/compiler/jit/tests/auto_clustering_test_helper.cc 
b/tensorflow/compiler/jit/tests/auto_clustering_test_helper.cc index 2a761c3e5a57c6..2cb75500dd7916 100644 --- a/tensorflow/compiler/jit/tests/auto_clustering_test_helper.cc +++ b/tensorflow/compiler/jit/tests/auto_clustering_test_helper.cc @@ -15,11 +15,11 @@ limitations under the License. #include "tensorflow/compiler/jit/tests/auto_clustering_test_helper.h" +#include "absl/status/statusor.h" #include "absl/strings/numbers.h" #include "tensorflow/compiler/jit/mark_for_compilation_pass.h" #include "tensorflow/compiler/jit/xla_cluster_util.h" #include "xla/status_macros.h" -#include "xla/statusor.h" #include "tensorflow/core/common_runtime/graph_constructor.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/io/random_inputstream.h" diff --git a/tensorflow/compiler/jit/tests/auto_clustering_test_helper.h b/tensorflow/compiler/jit/tests/auto_clustering_test_helper.h index 9266ae449b6078..7f97ee0fe8136e 100644 --- a/tensorflow/compiler/jit/tests/auto_clustering_test_helper.h +++ b/tensorflow/compiler/jit/tests/auto_clustering_test_helper.h @@ -15,7 +15,7 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_JIT_TESTS_AUTO_CLUSTERING_TEST_HELPER_H_ #define TENSORFLOW_COMPILER_JIT_TESTS_AUTO_CLUSTERING_TEST_HELPER_H_ -#include "xla/statusor.h" +#include "absl/status/statusor.h" #include "tensorflow/core/common_runtime/graph_constructor.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/test_benchmark.h" diff --git a/tensorflow/core/tpu/BUILD b/tensorflow/core/tpu/BUILD index 7aacb65ac33672..a27cf6393af06f 100644 --- a/tensorflow/core/tpu/BUILD +++ b/tensorflow/core/tpu/BUILD @@ -268,6 +268,7 @@ cc_library( "@com_google_absl//absl/base", "@com_google_absl//absl/cleanup", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", @@ -278,7 +279,6 @@ cc_library( "@local_xla//xla:shape_util", "@local_xla//xla:status", "@local_xla//xla:status_macros", - "@local_xla//xla:statusor", "@local_xla//xla:util", "@local_xla//xla:xla_data_proto_cc", "@local_xla//xla/hlo/ir:hlo", diff --git a/tensorflow/core/tpu/tpu_execute.cc b/tensorflow/core/tpu/tpu_execute.cc index 8daf680e09314f..75d6586c7a84a7 100644 --- a/tensorflow/core/tpu/tpu_execute.cc +++ b/tensorflow/core/tpu/tpu_execute.cc @@ -28,6 +28,7 @@ limitations under the License. #include "absl/base/casts.h" #include "absl/cleanup/cleanup.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/types/span.h" @@ -47,7 +48,6 @@ limitations under the License. 
#include "xla/shape_tree.h" #include "xla/shape_util.h" #include "xla/status_macros.h" -#include "xla/statusor.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/tpu/c_api_conversions.h" #include "xla/stream_executor/tpu/c_api_decl.h" diff --git a/tensorflow/core/tpu/tpu_execute.h b/tensorflow/core/tpu/tpu_execute.h index 053f48807ab77e..bfa177d66a7e4c 100644 --- a/tensorflow/core/tpu/tpu_execute.h +++ b/tensorflow/core/tpu/tpu_execute.h @@ -19,10 +19,10 @@ limitations under the License. #include #include +#include "absl/status/statusor.h" #include "xla/service/computation_placer.h" #include "xla/service/executable.h" #include "xla/service/hlo.pb.h" -#include "xla/statusor.h" #include "xla/stream_executor/stream.h" #include "xla/stream_executor/tpu/tpu_node_context.h" #include "xla/stream_executor/tpu/tpu_ops_c_api.h" diff --git a/third_party/xla/xla/hlo/utils/BUILD b/third_party/xla/xla/hlo/utils/BUILD index 7c6c4719a5bd7e..94353b525fc9de 100644 --- a/third_party/xla/xla/hlo/utils/BUILD +++ b/third_party/xla/xla/hlo/utils/BUILD @@ -26,7 +26,6 @@ cc_library( hdrs = ["hlo_live_range.h"], deps = [ "//xla:shape_util", - "//xla:statusor", "//xla/hlo/ir:hlo", "//xla/service:hlo_alias_analysis", "//xla/service:hlo_buffer", @@ -35,6 +34,7 @@ cc_library( "//xla/service:hlo_value", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:logging", diff --git a/third_party/xla/xla/hlo/utils/hlo_live_range.cc b/third_party/xla/xla/hlo/utils/hlo_live_range.cc index 670bf8f570889a..093e3d8cfdb4a8 100644 --- a/third_party/xla/xla/hlo/utils/hlo_live_range.cc +++ b/third_party/xla/xla/hlo/utils/hlo_live_range.cc @@ -26,6 +26,7 @@ limitations under the License. 
#include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" +#include "absl/status/statusor.h" #include "absl/strings/str_format.h" #include "absl/types/span.h" #include "xla/hlo/ir/dfs_hlo_visitor.h" @@ -36,7 +37,6 @@ limitations under the License. #include "xla/service/hlo_buffer.h" #include "xla/service/hlo_value.h" #include "xla/shape_util.h" -#include "xla/statusor.h" #include "tsl/platform/logging.h" namespace xla { diff --git a/third_party/xla/xla/hlo/utils/hlo_live_range.h b/third_party/xla/xla/hlo/utils/hlo_live_range.h index eb1530503ab121..87c6ec1ed5f0e2 100644 --- a/third_party/xla/xla/hlo/utils/hlo_live_range.h +++ b/third_party/xla/xla/hlo/utils/hlo_live_range.h @@ -20,12 +20,12 @@ the License. #include #include "absl/container/flat_hash_map.h" +#include "absl/status/statusor.h" #include "xla/hlo/ir/dfs_hlo_visitor.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_schedule.h" #include "xla/service/hlo_alias_analysis.h" #include "xla/service/hlo_value.h" -#include "xla/statusor.h" namespace xla { diff --git a/third_party/xla/xla/pjrt/cpu/BUILD b/third_party/xla/xla/pjrt/cpu/BUILD index 79c69a0f5fb961..ba298b9f0d0044 100644 --- a/third_party/xla/xla/pjrt/cpu/BUILD +++ b/third_party/xla/xla/pjrt/cpu/BUILD @@ -68,7 +68,6 @@ cc_library( "//xla:literal", "//xla:shape_tree", "//xla:shape_util", - "//xla:statusor", "//xla:util", "//xla:xla_data_proto_cc", "//xla/pjrt:pjrt_client", @@ -90,6 +89,7 @@ cc_library( "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", @@ -145,7 +145,6 @@ cc_library( "//xla:literal", "//xla:literal_util", "//xla:shape_util", - "//xla:statusor", "//xla:util", "//xla:xla_data_proto_cc", "//xla:xla_proto_cc", diff --git a/third_party/xla/xla/pjrt/cpu/abstract_tfrt_cpu_buffer.h 
b/third_party/xla/xla/pjrt/cpu/abstract_tfrt_cpu_buffer.h index bd8004fafde01b..293bc4016ed803 100644 --- a/third_party/xla/xla/pjrt/cpu/abstract_tfrt_cpu_buffer.h +++ b/third_party/xla/xla/pjrt/cpu/abstract_tfrt_cpu_buffer.h @@ -31,6 +31,7 @@ limitations under the License. #include "absl/log/check.h" #include "absl/log/log.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" #include "absl/types/span.h" @@ -41,7 +42,6 @@ limitations under the License. #include "xla/pjrt/transpose.h" #include "xla/service/cpu/cpu_event.h" #include "xla/shape.h" -#include "xla/statusor.h" #include "xla/tsl/concurrency/async_value.h" #include "xla/tsl/concurrency/async_value_ref.h" #include "xla/tsl/concurrency/ref_count.h" diff --git a/third_party/xla/xla/pjrt/cpu/cpu_client.h b/third_party/xla/xla/pjrt/cpu/cpu_client.h index 41bda465a394cf..1de15a0c4f625c 100644 --- a/third_party/xla/xla/pjrt/cpu/cpu_client.h +++ b/third_party/xla/xla/pjrt/cpu/cpu_client.h @@ -60,7 +60,6 @@ limitations under the License. 
#include "xla/service/hlo.pb.h" #include "xla/service/hlo_cost_analysis.h" #include "xla/shape.h" -#include "xla/statusor.h" #include "xla/tsl/concurrency/async_value_ref.h" #include "xla/util.h" #include "xla/xla_data.pb.h" diff --git a/third_party/xla/xla/tools/hlo_bisect/BUILD b/third_party/xla/xla/tools/hlo_bisect/BUILD index 3ee57085839e29..3df7fbbccdbad9 100644 --- a/third_party/xla/xla/tools/hlo_bisect/BUILD +++ b/third_party/xla/xla/tools/hlo_bisect/BUILD @@ -43,7 +43,6 @@ cc_library( hdrs = ["hlo_bisect_state.h"], deps = [ "//xla:literal", - "//xla:statusor", "//xla:util", "//xla/hlo/ir:hlo", "//xla/service:hlo_dce", @@ -52,6 +51,7 @@ cc_library( "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/types:span", ], ) @@ -62,13 +62,13 @@ xla_cc_test( deps = [ ":hlo_bisect_state", "//xla:literal", - "//xla:statusor", "//xla/hlo/ir:hlo", "//xla/service:pattern_matcher", "//xla/service:pattern_matcher_gmock", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/status:statusor", ], ) diff --git a/third_party/xla/xla/tools/hlo_bisect/hlo_bisect_state.h b/third_party/xla/xla/tools/hlo_bisect/hlo_bisect_state.h index 22cd9d79378f85..2db497f127a04f 100644 --- a/third_party/xla/xla/tools/hlo_bisect/hlo_bisect_state.h +++ b/third_party/xla/xla/tools/hlo_bisect/hlo_bisect_state.h @@ -23,9 +23,9 @@ limitations under the License. 
#include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/literal.h" -#include "xla/statusor.h" namespace xla { namespace bisect { diff --git a/third_party/xla/xla/tools/hlo_bisect/hlo_bisect_state_test.cc b/third_party/xla/xla/tools/hlo_bisect/hlo_bisect_state_test.cc index 3e608c82ddb16b..b4bfc641a84a8c 100644 --- a/third_party/xla/xla/tools/hlo_bisect/hlo_bisect_state_test.cc +++ b/third_party/xla/xla/tools/hlo_bisect/hlo_bisect_state_test.cc @@ -21,12 +21,12 @@ limitations under the License. #include #include "absl/algorithm/container.h" +#include "absl/status/statusor.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/literal.h" #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" -#include "xla/statusor.h" #include "xla/tests/hlo_test_base.h" namespace xla { diff --git a/third_party/xla/xla/tools/multihost_hlo_runner/BUILD b/third_party/xla/xla/tools/multihost_hlo_runner/BUILD index 74d0a19bf3957f..ae4e0e9a0e7278 100644 --- a/third_party/xla/xla/tools/multihost_hlo_runner/BUILD +++ b/third_party/xla/xla/tools/multihost_hlo_runner/BUILD @@ -36,7 +36,6 @@ xla_cc_binary( ":functional_hlo_runner", "//xla:debug_options_flags", "//xla:status_macros", - "//xla:statusor", "//xla/pjrt:pjrt_client", "//xla/pjrt/distributed", "//xla/pjrt/distributed:client", @@ -46,6 +45,7 @@ xla_cc_binary( "//xla/tsl/util:command_line_flags", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/time", "@local_tsl//tsl/platform:errors", @@ -74,7 +74,6 @@ xla_cc_binary( ":functional_hlo_runner", "//xla:debug_options_flags", "//xla:status_macros", - "//xla:statusor", "//xla/pjrt:pjrt_client", "//xla/pjrt/distributed", "//xla/pjrt/distributed:client", @@ -85,6 
+84,7 @@ xla_cc_binary( "//xla/tsl/util:command_line_flags", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/time", "@local_tsl//tsl/platform:errors", @@ -110,7 +110,6 @@ cc_library( "//xla:literal_util", "//xla:shape_util", "//xla:status_macros", - "//xla:statusor", "//xla:util", "//xla:xla_proto_cc", "//xla/client:executable_build_options", @@ -139,6 +138,7 @@ cc_library( "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/time", @@ -178,7 +178,6 @@ xla_test( deps = [ ":functional_hlo_runner", "//xla:debug_options_flags", - "//xla:statusor", "//xla/pjrt:pjrt_client", "//xla/pjrt/distributed:key_value_store_interface", "//xla/pjrt/distributed:service", @@ -186,6 +185,7 @@ xla_test( "//xla/tsl/util:command_line_flags", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/time", diff --git a/third_party/xla/xla/tools/multihost_hlo_runner/functional_hlo_runner.h b/third_party/xla/xla/tools/multihost_hlo_runner/functional_hlo_runner.h index 2329bce1a2fd8e..06e00ba63b9ded 100644 --- a/third_party/xla/xla/tools/multihost_hlo_runner/functional_hlo_runner.h +++ b/third_party/xla/xla/tools/multihost_hlo_runner/functional_hlo_runner.h @@ -25,6 +25,7 @@ limitations under the License. #include #include "absl/container/btree_map.h" +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/hlo/ir/hlo_module.h" @@ -33,7 +34,6 @@ limitations under the License. 
#include "xla/pjrt/gpu/se_gpu_pjrt_client.h" #include "xla/pjrt/pjrt_client.h" #include "xla/pjrt/pjrt_executable.h" -#include "xla/statusor.h" #include "xla/xla.pb.h" #include "tsl/platform/logging.h" #include "tsl/platform/statusor.h" diff --git a/third_party/xla/xla/tools/multihost_hlo_runner/functional_hlo_runner_test.cc b/third_party/xla/xla/tools/multihost_hlo_runner/functional_hlo_runner_test.cc index be970c1716c4b8..a42c28fb76b35b 100644 --- a/third_party/xla/xla/tools/multihost_hlo_runner/functional_hlo_runner_test.cc +++ b/third_party/xla/xla/tools/multihost_hlo_runner/functional_hlo_runner_test.cc @@ -22,6 +22,7 @@ limitations under the License. #include #include "absl/log/check.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/match.h" #include "absl/strings/str_format.h" #include "absl/time/time.h" @@ -29,7 +30,6 @@ limitations under the License. #include "xla/pjrt/distributed/key_value_store_interface.h" #include "xla/pjrt/distributed/service.h" #include "xla/pjrt/pjrt_client.h" -#include "xla/statusor.h" #include "xla/tests/filecheck.h" #include "xla/tsl/util/command_line_flags.h" #include "tsl/lib/core/status_test_util.h" diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/BUILD b/third_party/xla/xla/translate/hlo_to_mhlo/BUILD index 272fdf22826fb2..eb12ccecb76e68 100644 --- a/third_party/xla/xla/translate/hlo_to_mhlo/BUILD +++ b/third_party/xla/xla/translate/hlo_to_mhlo/BUILD @@ -31,11 +31,11 @@ cc_library( srcs = ["custom_call_importer.cc"], hdrs = ["custom_call_importer.h"], deps = [ - "//xla:statusor", "//xla:util", "//xla/hlo/ir:hlo", "//xla/mlir_hlo", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", "@llvm-project//mlir:AsmParser", @@ -77,7 +77,6 @@ cc_library( "//xla:protobuf_util", "//xla:shape_layout", "//xla:shape_util", - "//xla:statusor", "//xla:util", "//xla:xla_data_proto_cc", "//xla:xla_proto_cc", 
@@ -88,6 +87,7 @@ cc_library( "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", "@llvm-project//llvm:Support", diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/custom_call_importer.h b/third_party/xla/xla/translate/hlo_to_mhlo/custom_call_importer.h index 4fbe4516b93201..e3ecb358d2cf7e 100644 --- a/third_party/xla/xla/translate/hlo_to_mhlo/custom_call_importer.h +++ b/third_party/xla/xla/translate/hlo_to_mhlo/custom_call_importer.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef XLA_TRANSLATE_HLO_TO_MHLO_CUSTOM_CALL_IMPORTER_H_ #define XLA_TRANSLATE_HLO_TO_MHLO_CUSTOM_CALL_IMPORTER_H_ +#include "absl/status/statusor.h" #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project @@ -23,7 +24,6 @@ limitations under the License. #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/ValueRange.h" // from @llvm-project #include "xla/hlo/ir/hlo_instructions.h" -#include "xla/statusor.h" namespace xla { diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/hlo_function_importer.h b/third_party/xla/xla/translate/hlo_to_mhlo/hlo_function_importer.h index c92bca01f078ac..93d60e5a7255d9 100644 --- a/third_party/xla/xla/translate/hlo_to_mhlo/hlo_function_importer.h +++ b/third_party/xla/xla/translate/hlo_to_mhlo/hlo_function_importer.h @@ -20,6 +20,7 @@ limitations under the License. #include #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/types/optional.h" #include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project @@ -35,7 +36,6 @@ limitations under the License. 
#include "xla/hlo/ir/hlo_sharding.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/service/hlo.pb.h" -#include "xla/statusor.h" #include "xla/xla_data.pb.h" namespace xla { diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/BUILD b/third_party/xla/xla/translate/mhlo_to_hlo/BUILD index dff79e3cf4606a..bbc756917625eb 100644 --- a/third_party/xla/xla/translate/mhlo_to_hlo/BUILD +++ b/third_party/xla/xla/translate/mhlo_to_hlo/BUILD @@ -20,7 +20,6 @@ cc_library( hdrs = ["attribute_exporter.h"], deps = [ "//xla:shape_util", - "//xla:statusor", "//xla:types", "//xla:util", "//xla:xla_data_proto_cc", @@ -28,6 +27,7 @@ cc_library( "//xla/service:hlo_parser", "//xla/service:hlo_proto_cc", "//xla/stream_executor:dnn", + "@com_google_absl//absl/status:statusor", "@llvm-project//mlir:IR", "@llvm-project//mlir:Support", ], @@ -39,11 +39,11 @@ cc_library( hdrs = ["layout_util.h"], deps = [ "//xla:shape_util", - "//xla:statusor", "//xla:xla_data_proto_cc", "//xla/client:xla_builder", "//xla/hlo/ir:hlo", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", ], ) diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/attribute_exporter.h b/third_party/xla/xla/translate/mhlo_to_hlo/attribute_exporter.h index 96aa716b588628..d965d5ec3ed2ac 100644 --- a/third_party/xla/xla/translate/mhlo_to_hlo/attribute_exporter.h +++ b/third_party/xla/xla/translate/mhlo_to_hlo/attribute_exporter.h @@ -18,11 +18,11 @@ limitations under the License. 
#include +#include "absl/status/statusor.h" #include "mlir/IR/Attributes.h" // from @llvm-project #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/service/hlo.pb.h" #include "xla/shape_util.h" -#include "xla/statusor.h" #include "xla/stream_executor/dnn.h" #include "xla/types.h" #include "xla/xla_data.pb.h" diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/layout_util.h b/third_party/xla/xla/translate/mhlo_to_hlo/layout_util.h index f68e3ed94fe851..2c85a82680345a 100644 --- a/third_party/xla/xla/translate/mhlo_to_hlo/layout_util.h +++ b/third_party/xla/xla/translate/mhlo_to_hlo/layout_util.h @@ -21,10 +21,10 @@ limitations under the License. #include #include +#include "absl/status/statusor.h" #include "xla/client/xla_builder.h" #include "xla/hlo/ir/hlo_sharding.h" #include "xla/shape.h" -#include "xla/statusor.h" #include "xla/xla_data.pb.h" namespace mlir { From 7fbc7481f82d2e68f0a6947fe97d8eeb1fc8c3a2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adam=20Bana=C5=9B?= Date: Fri, 21 Jun 2024 02:00:20 -0700 Subject: [PATCH 101/256] [xla:cpu] Add support for convolution thunks This implementation has similar performance to the current runtime for 3x3 convolution shapes in `xla/service/cpu/benchmarks/convolution_benchmark_test.cc` benchmark, for both multi-threaded and single-threaded versions. Change doesn't affect 1x1 convolutions in this benchmark, because `convolution` is rewritten as `dot` op by HLO passes in this case. All features have been ported from the current runtime, only ACL convolutions were not tested. Added convolution thunk unit tests, they cover some extra cases like F16 convolution and incorrect input data. Note F16 is not actually tested in convolution E2E tests (xla/tests/convolution_test.cc), because F16 is rewritten as F32 convolution there. 
PiperOrigin-RevId: 645314154 --- third_party/xla/xla/service/cpu/BUILD | 37 +- third_party/xla/xla/service/cpu/runtime/BUILD | 51 +++ .../service/cpu/runtime/convolution_thunk.cc | 341 ++++++++++++++++++ .../service/cpu/runtime/convolution_thunk.h | 127 +++++++ .../cpu/runtime/convolution_thunk_test.cc | 303 ++++++++++++++++ .../xla/xla/service/cpu/runtime/thunk.cc | 2 + .../xla/xla/service/cpu/runtime/thunk.h | 1 + .../xla/xla/service/cpu/thunk_emitter.cc | 58 +++ .../xla/xla/service/cpu/thunk_emitter.h | 3 + third_party/xla/xla/tests/BUILD | 1 + third_party/xla/xla/tests/convolution_test.cc | 72 ++++ 11 files changed, 980 insertions(+), 16 deletions(-) create mode 100644 third_party/xla/xla/service/cpu/runtime/convolution_thunk.cc create mode 100644 third_party/xla/xla/service/cpu/runtime/convolution_thunk.h create mode 100644 third_party/xla/xla/service/cpu/runtime/convolution_thunk_test.cc diff --git a/third_party/xla/xla/service/cpu/BUILD b/third_party/xla/xla/service/cpu/BUILD index bb002039a8eb4f..8fbe954e40d016 100644 --- a/third_party/xla/xla/service/cpu/BUILD +++ b/third_party/xla/xla/service/cpu/BUILD @@ -802,6 +802,7 @@ cc_library( hdrs = ["thunk_emitter.h"], deps = [ ":dot_op_emitter", + ":ir_emission_utils", ":ir_emitter2", ":target_machine_features", "//xla:cpu_function_runtime", @@ -819,6 +820,7 @@ cc_library( "//xla/service/cpu/runtime:collective_permute_thunk", "//xla/service/cpu/runtime:collective_thunk", "//xla/service/cpu/runtime:conditional_thunk", + "//xla/service/cpu/runtime:convolution_thunk", "//xla/service/cpu/runtime:copy_thunk", "//xla/service/cpu/runtime:custom_call_thunk", "//xla/service/cpu/runtime:dot_thunk", @@ -1011,15 +1013,24 @@ cc_library( ) cc_library( - name = "runtime_conv2d", - srcs = [ - "runtime_conv2d.cc", - "runtime_conv_impl.h", + name = "runtime_conv_impl", + hdrs = ["runtime_conv_impl.h"], + visibility = internal_visibility([":friends"]), + deps = [ + "//xla/tsl/framework/contraction:eigen_contraction_kernel", + 
"//xla/tsl/framework/convolution:eigen_helpers", + "@eigen_archive//:eigen3", ], +) + +cc_library( + name = "runtime_conv2d", + srcs = ["runtime_conv2d.cc"], hdrs = ["runtime_conv2d.h"], copts = runtime_copts(), visibility = ["//visibility:public"], deps = [ + ":runtime_conv_impl", ":runtime_lightweight_check", "//xla:executable_run_options", "//xla/tsl/framework/contraction:eigen_contraction_kernel", @@ -1032,14 +1043,12 @@ cc_library( cc_library( name = "runtime_conv3d", - srcs = [ - "runtime_conv3d.cc", - "runtime_conv_impl.h", - ], + srcs = ["runtime_conv3d.cc"], hdrs = ["runtime_conv3d.h"], copts = runtime_copts(), visibility = ["//visibility:public"], deps = [ + ":runtime_conv_impl", ":runtime_lightweight_check", "//xla:executable_run_options", "//xla/tsl/framework/contraction:eigen_contraction_kernel", @@ -1164,14 +1173,12 @@ cc_library( cc_library( name = "runtime_single_threaded_conv2d", - srcs = [ - "runtime_conv_impl.h", - "runtime_single_threaded_conv2d.cc", - ], + srcs = ["runtime_single_threaded_conv2d.cc"], hdrs = ["runtime_single_threaded_conv2d.h"], copts = runtime_copts(), visibility = ["//visibility:public"], deps = [ + ":runtime_conv_impl", "//xla/tsl/framework/contraction:eigen_contraction_kernel", "//xla/tsl/framework/convolution:eigen_helpers", "@com_google_absl//absl/base:dynamic_annotations", @@ -1182,14 +1189,12 @@ cc_library( cc_library( name = "runtime_single_threaded_conv3d", - srcs = [ - "runtime_conv_impl.h", - "runtime_single_threaded_conv3d.cc", - ], + srcs = ["runtime_single_threaded_conv3d.cc"], hdrs = ["runtime_single_threaded_conv3d.h"], copts = runtime_copts(), visibility = ["//visibility:public"], deps = [ + ":runtime_conv_impl", "//xla/tsl/framework/contraction:eigen_contraction_kernel", "//xla/tsl/framework/convolution:eigen_helpers", "@com_google_absl//absl/base:dynamic_annotations", diff --git a/third_party/xla/xla/service/cpu/runtime/BUILD b/third_party/xla/xla/service/cpu/runtime/BUILD index 
0e2a6bd860566e..8cd0e661a11d42 100644 --- a/third_party/xla/xla/service/cpu/runtime/BUILD +++ b/third_party/xla/xla/service/cpu/runtime/BUILD @@ -1,5 +1,6 @@ load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") load("//xla:xla.bzl", "xla_cc_test") +load("//xla/service/cpu:build_defs.bzl", "runtime_copts") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -221,6 +222,56 @@ cc_library( ], ) +cc_library( + name = "convolution_thunk", + srcs = ["convolution_thunk.cc"], + hdrs = ["convolution_thunk.h"], + copts = runtime_copts(), + deps = [ + ":thunk", + "//xla:executable_run_options", + "//xla:shape_util", + "//xla:status_macros", + "//xla:util", + "//xla/runtime:buffer_use", + "//xla/service:buffer_assignment", + "//xla/service/cpu:runtime_conv2d_acl", + "//xla/service/cpu:runtime_conv_impl", + "//xla/stream_executor:device_memory", + "//xla/tsl/concurrency:async_value", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/log", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", + "@eigen_archive//:eigen3", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/profiler/lib:traceme", + ], +) + +xla_cc_test( + name = "convolution_thunk_test", + srcs = ["convolution_thunk_test.cc"], + deps = [ + ":buffer_allocations", + ":convolution_thunk", + ":thunk", + "//xla:shape_util", + "//xla/service:buffer_assignment", + "//xla/service:maybe_owning_device_memory", + "//xla/stream_executor:device_memory", + "//xla/tsl/concurrency:async_value", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/status", + "@eigen_archive//:eigen3", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", + ], +) + cc_library( name = "all_reduce_thunk", srcs = ["all_reduce_thunk.cc"], diff --git 
a/third_party/xla/xla/service/cpu/runtime/convolution_thunk.cc b/third_party/xla/xla/service/cpu/runtime/convolution_thunk.cc new file mode 100644 index 00000000000000..d2aab7a165a2d4 --- /dev/null +++ b/third_party/xla/xla/service/cpu/runtime/convolution_thunk.cc @@ -0,0 +1,341 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "xla/service/cpu/runtime/convolution_thunk.h" + +#define EIGEN_USE_THREADS + +#include +#include +#include + +#include "absl/container/inlined_vector.h" +#include "absl/log/log.h" +#include "absl/memory/memory.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "absl/types/span.h" +#include "Eigen/Core" // from @eigen_archive +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "xla/executable_run_options.h" +#include "xla/service/buffer_assignment.h" +#include "xla/service/cpu/runtime/thunk.h" +#include "xla/service/cpu/runtime_conv2d_acl.h" +#include "xla/service/cpu/runtime_conv_impl.h" +#include "xla/shape.h" +#include "xla/status_macros.h" +#include "xla/stream_executor/device_memory.h" +#include "xla/tsl/concurrency/async_value_ref.h" +#include "xla/util.h" +#include "tsl/platform/statusor.h" +#include "tsl/profiler/lib/traceme.h" + +namespace xla::cpu { +namespace { + +bool IsSupportedType(PrimitiveType primitive_type) { + return primitive_type == PrimitiveType::F16 || + 
primitive_type == PrimitiveType::F32; +} + +auto GetConvolutionRank(const Shape& input_shape) { + // Convolution rank is the number of spatial dimensions. Besides spatial + // dimensions, input shape contains two other dimensions (batch size and the + // number of channels). + return input_shape.dimensions_size() - 2; +} + +bool CanUseACL(const ConvolutionThunk::Options& options, + PrimitiveType primitive_type, int64_t convolution_rank) { + return options.use_acl && primitive_type == PrimitiveType::F32 && + convolution_rank == 2; +} + +auto MakeRunOptions(const Eigen::ThreadPoolDevice* threadpool) { + ExecutableRunOptions run_options; + run_options.set_intra_op_thread_pool(threadpool); + return run_options; +} + +} // namespace + +absl::StatusOr> ConvolutionThunk::Create( + Info info, Options options, BufferAllocation::Slice input_buffer, + const Shape& input_shape, BufferAllocation::Slice kernel_buffer, + const Shape& kernel_shape, BufferAllocation::Slice output_buffer, + const Shape& output_shape, const ConvolutionDimensionNumbers& dnums, + const Window& window, int64_t feature_group_count) { + // TODO(abanas): Add shape verification (batch size, feature count, etc.) + auto primitive_type = input_shape.element_type(); + if (!IsSupportedType(primitive_type)) { + return InvalidArgument("ConvolutionThunk: Unsupported element type (%s)", + PrimitiveType_Name(primitive_type)); + } + + int64_t convolution_rank = GetConvolutionRank(input_shape); + if (convolution_rank > 3) { + return InvalidArgument("ConvolutionThunk: Incorrect convolution rank (%d)", + convolution_rank); + } + + absl::InlinedVector input_dims; + absl::InlinedVector kernel_dims; + absl::InlinedVector output_dims; + + // We lower 1D convolutions into calls to the same Eigen function as 2D + // convolutions, except that we pretend that the 1D convolution is really + // a 2D convolution with the missing dimension set to 1. We also adjust + // the padding, dilation parameters as needed. 
+ if (convolution_rank == 1) { + input_dims.push_back(1); + kernel_dims.push_back(1); + output_dims.push_back(1); + } + + // Turn off ACL if not supported for given primitive type and convolution + // rank. + options.use_acl = CanUseACL(options, primitive_type, convolution_rank); + + // Input tensor. + int64_t input_batch = input_shape.dimensions(dnums.input_batch_dimension()); + for (int d : dnums.input_spatial_dimensions()) { + input_dims.push_back(input_shape.dimensions(d)); + } + int64_t input_channels = + input_shape.dimensions(dnums.input_feature_dimension()); + + // Kernel tensor. + for (int d : dnums.kernel_spatial_dimensions()) { + kernel_dims.push_back(kernel_shape.dimensions(d)); + } + int64_t kernel_channels = + kernel_shape.dimensions(dnums.kernel_input_feature_dimension()); + int64_t kernel_filters = + kernel_shape.dimensions(dnums.kernel_output_feature_dimension()); + + // Output tensor. + for (int d : dnums.output_spatial_dimensions()) { + output_dims.push_back(output_shape.dimensions(d)); + } + + // Extract the window stride for the convolution. 
+ absl::InlinedVector strides; + absl::InlinedVector padding_before; + absl::InlinedVector padding_after; + absl::InlinedVector base_dilation; + absl::InlinedVector window_dilation; + if (convolution_rank == 1) { + strides.push_back(1); + padding_before.push_back(0); + padding_after.push_back(0); + base_dilation.push_back(1); + window_dilation.push_back(1); + } + for (const auto& d : window.dimensions()) { + strides.push_back(d.stride()); + padding_before.push_back(d.padding_low()); + padding_after.push_back(d.padding_high()); + base_dilation.push_back(d.base_dilation()); + window_dilation.push_back(d.window_dilation()); + } + + auto valid_num_dims = [](absl::Span xs) { + return xs.size() >= 2 && xs.size() <= 3; + }; + TF_RET_CHECK(valid_num_dims(input_dims)) << input_dims.size(); + TF_RET_CHECK(valid_num_dims(kernel_dims)); + TF_RET_CHECK(valid_num_dims(output_dims)); + TF_RET_CHECK(valid_num_dims(strides)); + TF_RET_CHECK(valid_num_dims(padding_before)); + TF_RET_CHECK(valid_num_dims(padding_after)); + TF_RET_CHECK(valid_num_dims(base_dilation)); + TF_RET_CHECK(valid_num_dims(window_dilation)); + + return absl::WrapUnique(new ConvolutionThunk( + std::move(info), std::move(input_buffer), input_shape, + std::move(kernel_buffer), kernel_shape, std::move(output_buffer), + output_shape, input_batch, input_dims, input_channels, kernel_dims, + kernel_channels, kernel_filters, output_dims, strides, padding_before, + padding_after, base_dilation, window_dilation, feature_group_count, + options)); +} + +ConvolutionThunk::ConvolutionThunk( + Info info, BufferAllocation::Slice input_buffer, const Shape& input_shape, + BufferAllocation::Slice kernel_buffer, const Shape& kernel_shape, + BufferAllocation::Slice output_buffer, const Shape& output_shape, + int64_t input_batch, const absl::InlinedVector& input_dims, + int64_t input_channels, const absl::InlinedVector& kernel_dims, + int64_t kernel_channels, int64_t kernel_filters, + const absl::InlinedVector& output_dims, + const 
absl::InlinedVector& strides, + const absl::InlinedVector& padding_before, + const absl::InlinedVector& padding_after, + const absl::InlinedVector& base_dilation, + const absl::InlinedVector& window_dilation, + int64_t feature_group_count, Options options) + : Thunk(Kind::kConvolution, std::move(info)), + input_buffer_(input_buffer), + input_shape_(input_shape), + kernel_buffer_(kernel_buffer), + kernel_shape_(kernel_shape), + output_buffer_(output_buffer), + output_shape_(output_shape), + input_batch_(input_batch), + input_dims_(input_dims), + input_channels_(input_channels), + kernel_dims_(kernel_dims), + kernel_channels_(kernel_channels), + kernel_filters_(kernel_filters), + output_dims_(output_dims), + strides_(strides), + padding_before_(padding_before), + padding_after_(padding_after), + base_dilation_(base_dilation), + window_dilation_(window_dilation), + feature_group_count_(feature_group_count), + convolution_rank_(input_dims.size()), + options_(options) {} + +tsl::AsyncValueRef ConvolutionThunk::Execute( + const ExecuteParams& params) { + tsl::profiler::TraceMe trace([&] { return TraceMeEncode(); }); + + TF_ASSIGN_OR_RETURN( + se::DeviceMemoryBase input_data, + params.buffer_allocations->GetDeviceAddress(input_buffer_)); + TF_ASSIGN_OR_RETURN( + se::DeviceMemoryBase kernel_data, + params.buffer_allocations->GetDeviceAddress(kernel_buffer_)); + TF_ASSIGN_OR_RETURN( + se::DeviceMemoryBase output_data, + params.buffer_allocations->GetDeviceAddress(output_buffer_)); + + VLOG(3) << absl::StreamFormat("ConvolutionThunk::Execute"); + VLOG(3) << absl::StreamFormat(" input: %s in slice %s (%p)", + input_shape_.ToString(true), + input_buffer_.ToString(), input_data.opaque()); + VLOG(3) << absl::StreamFormat( + " kernel: %s in slice %s (%p)", kernel_shape_.ToString(true), + kernel_buffer_.ToString(), kernel_data.opaque()); + VLOG(3) << absl::StreamFormat( + " output: %s in slice %s (%p)", output_shape_.ToString(true), + output_buffer_.ToString(), 
output_data.opaque()); + + if (options_.use_acl) { + HandleACLConvolution(params, input_data, kernel_data, output_data); + return OkExecuteEvent(); + } + + // Eigen convolution + if (convolution_rank_ == 2) { + HandleEigen2DConvolution(params, input_data, kernel_data, output_data); + } else { + HandleEigen3DConvolution(params, input_data, kernel_data, output_data); + } + + // TODO(abanas): Execute asynchronously in multi-thread mode using + // Eigen::ThreadPoolDevice. + return OkExecuteEvent(); +} + +void ConvolutionThunk::HandleACLConvolution(const ExecuteParams& params, + se::DeviceMemoryBase input, + se::DeviceMemoryBase kernel, + se::DeviceMemoryBase output) { + // NOTE: This is the basic support for ACL. Performance was not + // benchmarked and is likely not good, the design could be improved + // (e.g. creating run_options is a hack). + auto run_options = MakeRunOptions(params.intra_op_threadpool); + __xla_cpu_runtime_ACLConv2DF32( + &run_options, static_cast(output.opaque()), + static_cast(input.opaque()), static_cast(kernel.opaque()), + input_batch_, input_dims_.x, input_dims_.y, input_channels_, + kernel_dims_.x, kernel_dims_.y, kernel_channels_, kernel_filters_, + output_dims_.x, output_dims_.y, strides_.x, strides_.y, padding_before_.x, + padding_after_.x, padding_before_.y, padding_after_.y, base_dilation_.x, + base_dilation_.y, window_dilation_.x, window_dilation_.y, + feature_group_count_); +} + +void ConvolutionThunk::HandleEigen2DConvolution(const ExecuteParams& params, + se::DeviceMemoryBase input, + se::DeviceMemoryBase kernel, + se::DeviceMemoryBase output) { + auto dispatch = [&](auto type_tag, const auto& eigen_device) { + using scalar_type = decltype(type_tag); + tensorflow::xla::EigenConv2DImpl( + eigen_device, static_cast(output.opaque()), + static_cast(input.opaque()), + static_cast(kernel.opaque()), input_batch_, input_dims_.x, + input_dims_.y, input_channels_, kernel_dims_.x, kernel_dims_.y, + kernel_channels_, kernel_filters_, 
output_dims_.x, output_dims_.y, + strides_.x, strides_.y, padding_before_.x, padding_after_.x, + padding_before_.y, padding_after_.y, base_dilation_.x, base_dilation_.y, + window_dilation_.x, window_dilation_.y, feature_group_count_); + }; + + if (input_shape_.element_type() == PrimitiveType::F16) { + if (options_.multi_threaded) { + dispatch(Eigen::half(), *params.intra_op_threadpool); + } else { + dispatch(Eigen::half(), Eigen::DefaultDevice()); + } + } else { + if (options_.multi_threaded) { + dispatch(float(), *params.intra_op_threadpool); + } else { + dispatch(float(), Eigen::DefaultDevice()); + } + } +} + +void ConvolutionThunk::HandleEigen3DConvolution(const ExecuteParams& params, + se::DeviceMemoryBase input, + se::DeviceMemoryBase kernel, + se::DeviceMemoryBase output) { + auto dispatch = [&](auto type_tag, const auto& eigen_device) { + using scalar_type = decltype(type_tag); + tensorflow::xla::EigenConv3DImpl( + eigen_device, static_cast(output.opaque()), + static_cast(input.opaque()), + static_cast(kernel.opaque()), input_batch_, input_dims_.x, + input_dims_.y, input_dims_.z, input_channels_, kernel_dims_.x, + kernel_dims_.y, kernel_dims_.z, kernel_channels_, kernel_filters_, + output_dims_.x, output_dims_.y, output_dims_.z, strides_.x, strides_.y, + strides_.z, padding_before_.x, padding_after_.x, padding_before_.y, + padding_after_.y, padding_before_.z, padding_after_.z, base_dilation_.x, + base_dilation_.y, base_dilation_.z, window_dilation_.x, + window_dilation_.y, window_dilation_.z, feature_group_count_); + }; + + if (input_shape_.element_type() == PrimitiveType::F16) { + if (options_.multi_threaded) { + dispatch(Eigen::half(), *params.intra_op_threadpool); + } else { + dispatch(Eigen::half(), Eigen::DefaultDevice()); + } + } else { + if (options_.multi_threaded) { + dispatch(float(), *params.intra_op_threadpool); + } else { + dispatch(float(), Eigen::DefaultDevice()); + } + } +} + +ConvolutionThunk::Dims::Dims(const absl::InlinedVector& dims) + : 
x(dims[0]), y(dims[1]), z(dims.size() == 3 ? dims[2] : 0) {} + +} // namespace xla::cpu diff --git a/third_party/xla/xla/service/cpu/runtime/convolution_thunk.h b/third_party/xla/xla/service/cpu/runtime/convolution_thunk.h new file mode 100644 index 00000000000000..761f1ff63fe873 --- /dev/null +++ b/third_party/xla/xla/service/cpu/runtime/convolution_thunk.h @@ -0,0 +1,127 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_CPU_RUNTIME_CONVOLUTION_THUNK_H_ +#define XLA_SERVICE_CPU_RUNTIME_CONVOLUTION_THUNK_H_ + +#include +#include +#include + +#include "absl/container/inlined_vector.h" +#include "absl/status/statusor.h" +#include "xla/runtime/buffer_use.h" +#include "xla/service/buffer_assignment.h" +#include "xla/service/cpu/runtime/thunk.h" +#include "xla/shape.h" +#include "xla/tsl/concurrency/async_value_ref.h" + +namespace xla::cpu { + +// Performs 1D, 2D or 3D convolution. 
+class ConvolutionThunk final : public Thunk { + public: + struct Options { + bool multi_threaded = false; + bool use_acl = false; + }; + static absl::StatusOr> Create( + Info info, Options options, BufferAllocation::Slice input_buffer, + const Shape& input_shape, BufferAllocation::Slice kernel_buffer, + const Shape& kernel_shape, BufferAllocation::Slice output_buffer, + const Shape& output_shape, const ConvolutionDimensionNumbers& dnums, + const Window& window, int64_t feature_group_count); + + tsl::AsyncValueRef Execute(const ExecuteParams& params) final; + + Thunk::BufferUses buffer_uses() const final { + return {{input_buffer_, BufferUse::kRead}, + {kernel_buffer_, BufferUse::kRead}, + {output_buffer_, BufferUse::kWrite}}; + } + + private: + ConvolutionThunk(Info info, BufferAllocation::Slice input_buffer, + const Shape& input_shape, + BufferAllocation::Slice kernel_buffer, + const Shape& kernel_shape, + BufferAllocation::Slice output_buffer, + const Shape& output_shape, int64_t input_batch, + const absl::InlinedVector& input_dims, + int64_t input_channels, + const absl::InlinedVector& kernel_dims, + int64_t kernel_channels, int64_t kernel_filters, + const absl::InlinedVector& output_dims, + const absl::InlinedVector& strides, + const absl::InlinedVector& padding_before, + const absl::InlinedVector& padding_after, + const absl::InlinedVector& base_dilation, + const absl::InlinedVector& window_dilation, + int64_t feature_group_count, Options options); + + void HandleACLConvolution(const ExecuteParams& params, + se::DeviceMemoryBase input, + se::DeviceMemoryBase kernel, + se::DeviceMemoryBase output); + void HandleEigen2DConvolution(const ExecuteParams& params, + se::DeviceMemoryBase input, + se::DeviceMemoryBase kernel, + se::DeviceMemoryBase output); + void HandleEigen3DConvolution(const ExecuteParams& params, + se::DeviceMemoryBase input, + se::DeviceMemoryBase kernel, + se::DeviceMemoryBase output); + + // A helper struct to store the x, y and z dimensions of 
a tensor, introduced + // for readability. + // In case of 2D convolution, only the x and y dimensions are used and z is + // set to 0. + struct Dims { + explicit Dims(const absl::InlinedVector& dims); + + int64_t x; + int64_t y; + int64_t z; + }; + + BufferAllocation::Slice input_buffer_; + Shape input_shape_; + + BufferAllocation::Slice kernel_buffer_; + Shape kernel_shape_; + + BufferAllocation::Slice output_buffer_; + Shape output_shape_; + + int64_t input_batch_; + Dims input_dims_; + int64_t input_channels_; + Dims kernel_dims_; + int64_t kernel_channels_; + int64_t kernel_filters_; + Dims output_dims_; + Dims strides_; + Dims padding_before_; + Dims padding_after_; + Dims base_dilation_; + Dims window_dilation_; + int64_t feature_group_count_; + int convolution_rank_; + Options options_; +}; + +} // namespace xla::cpu + +#endif // XLA_SERVICE_CPU_RUNTIME_CONVOLUTION_THUNK_H_ diff --git a/third_party/xla/xla/service/cpu/runtime/convolution_thunk_test.cc b/third_party/xla/xla/service/cpu/runtime/convolution_thunk_test.cc new file mode 100644 index 00000000000000..973771cb9aa771 --- /dev/null +++ b/third_party/xla/xla/service/cpu/runtime/convolution_thunk_test.cc @@ -0,0 +1,303 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "xla/service/cpu/runtime/convolution_thunk.h" + +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/status/status.h" +#include "Eigen/Core" // from @eigen_archive +#include "xla/primitive_util.h" +#include "xla/service/buffer_assignment.h" +#include "xla/service/cpu/runtime/buffer_allocations.h" +#include "xla/service/cpu/runtime/thunk.h" +#include "xla/service/maybe_owning_device_memory.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/stream_executor/device_memory.h" +#include "xla/tsl/concurrency/async_value_ref.h" +#include "tsl/platform/statusor.h" +#include "tsl/platform/test.h" + +namespace xla::cpu { +namespace { + +// NOTE: This file serves to verify the basic functionality of the convolution +// thunk. Comprehensive tests cases are common for all backends and are covered +// in xla/tests/convolution_test.cc file. + +// Convolution dimensions to be used in the tests. +struct ConvolutionDimensions { + int batch_size = 1; + int input_size = 3; + int input_channels = 5; + int kernel_size = 3; + int output_channels = 3; + // Correct for 0 padding, default stride, default dilation. 
+ int output_size = input_size - kernel_size + 1; +}; + +template +class ConvolutionThunkTypedTest : public ::testing::Test {}; + +using CorrectTypes = ::testing::Types; +TYPED_TEST_SUITE(ConvolutionThunkTypedTest, CorrectTypes); + +std::vector MakeInputDims( + int convolution_rank, + ConvolutionDimensions dims = ConvolutionDimensions()) { + std::vector input_dims = {dims.batch_size}; + for (int i = 0; i < convolution_rank; ++i) { + input_dims.push_back(dims.input_size); + } + input_dims.push_back(dims.input_channels); + return input_dims; +} + +std::vector MakeKernelDims( + int convolution_rank, + ConvolutionDimensions dims = ConvolutionDimensions()) { + std::vector kernel_dims = {}; + for (int i = 0; i < convolution_rank; ++i) { + kernel_dims.push_back(dims.kernel_size); + } + kernel_dims.push_back(dims.input_channels); + kernel_dims.push_back(dims.output_channels); + return kernel_dims; +} + +std::vector MakeOutputDims( + int convolution_rank, + ConvolutionDimensions dims = ConvolutionDimensions()) { + std::vector output_dims = {dims.batch_size}; + for (int i = 0; i < convolution_rank; ++i) { + output_dims.push_back(dims.output_size); + } + output_dims.push_back(dims.output_channels); + return output_dims; +} + +template +std::vector MakeDataVector(const std::vector& dims) { + auto size = absl::c_accumulate(dims, 1, std::multiplies()); + return std::vector(size, ElementType(0.0)); +} + +template +std::vector MakeBuffers( + const std::vector& input, + const std::vector& kernel, + const std::vector& output) { + std::vector buffers; + size_t input_size_in_bytes = input.size() * sizeof(ElementType); + buffers.emplace_back(se::DeviceMemoryBase(input.data(), input_size_in_bytes)); + size_t kernel_size_in_bytes = kernel.size() * sizeof(ElementType); + buffers.emplace_back( + se::DeviceMemoryBase(kernel.data(), kernel_size_in_bytes)); + size_t output_size_in_bytes = output.size() * sizeof(ElementType); + buffers.emplace_back( + se::DeviceMemoryBase(output.data(), 
output_size_in_bytes)); + return buffers; +} + +ConvolutionThunk::Options MakeConvolutionOptions() { + ConvolutionThunk::Options options; + options.multi_threaded = false; + options.use_acl = false; + return options; +} + +ConvolutionDimensionNumbers MakeConvolutionDimensionNumbers( + int convolution_rank) { + ConvolutionDimensionNumbers dnums; + // Input dimensions. + int dim = 0; + dnums.set_input_batch_dimension(dim++); + for (int i = 0; i < convolution_rank; ++i) { + dnums.add_input_spatial_dimensions(dim++); + } + dnums.set_input_feature_dimension(dim++); + + // Kernel dimensions. + dim = 0; + for (int i = 0; i < convolution_rank; ++i) { + dnums.add_kernel_spatial_dimensions(dim++); + } + dnums.set_kernel_input_feature_dimension(dim++); + dnums.set_kernel_output_feature_dimension(dim++); + + // Output dimensions. + dim = 0; + dnums.set_output_batch_dimension(dim++); + for (int i = 0; i < convolution_rank; ++i) { + dnums.add_output_spatial_dimensions(dim++); + } + dnums.set_output_feature_dimension(dim++); + + return dnums; +} + +Window MakeWindow(int convolution_rank) { + Window window; + for (int i = 0; i < convolution_rank; ++i) { + WindowDimension* window_dim = window.add_dimensions(); + window_dim->set_stride(1); + window_dim->set_padding_low(0); + window_dim->set_padding_high(0); + window_dim->set_window_dilation(1); + window_dim->set_base_dilation(1); + } + return window; +} + +// This class is used to build ConvolutionThunk and execute it. It stores all +// the variables that are needed to create and execute the thunk, so it must be +// kept alive until the end of the execution. +template +class ConvolutionThunkBuilder { + public: + auto Build(int convolution_rank, + ConvolutionDimensions dims = ConvolutionDimensions()) { + // Data dimensions. + auto input_dims = MakeInputDims(convolution_rank, dims); + auto kernel_dims = MakeKernelDims(convolution_rank, dims); + auto output_dims = MakeOutputDims(convolution_rank, dims); + + // Actual data. 
+ input_ = MakeDataVector(input_dims); + kernel_ = MakeDataVector(kernel_dims); + output_ = MakeDataVector(output_dims); + + // Buffers. + size_t input_size_in_bytes = input_.size() * sizeof(ElementType); + buffers_.emplace_back( + se::DeviceMemoryBase(input_.data(), input_size_in_bytes)); + size_t kernel_size_in_bytes = kernel_.size() * sizeof(ElementType); + buffers_.emplace_back( + se::DeviceMemoryBase(kernel_.data(), kernel_size_in_bytes)); + size_t output_size_in_bytes = output_.size() * sizeof(ElementType); + buffers_.emplace_back( + se::DeviceMemoryBase(output_.data(), output_size_in_bytes)); + + // Buffer allocations. + allocations_ = std::make_unique(buffers_); + + input_alloc_ = + std::make_unique(0, input_size_in_bytes, 0); + kernel_alloc_ = + std::make_unique(1, kernel_size_in_bytes, 0); + output_alloc_ = + std::make_unique(2, output_size_in_bytes, 0); + + BufferAllocation::Slice input_slice(input_alloc_.get(), 0, + input_size_in_bytes); + BufferAllocation::Slice kernel_slice(kernel_alloc_.get(), 0, + kernel_size_in_bytes); + BufferAllocation::Slice output_slice(output_alloc_.get(), 0, + output_size_in_bytes); + + // Shapes. + auto primitive_type = primitive_util::NativeToPrimitiveType(); + Shape input_shape = ShapeUtil::MakeShape(primitive_type, input_dims); + Shape kernel_shape = ShapeUtil::MakeShape(primitive_type, kernel_dims); + Shape output_shape = ShapeUtil::MakeShape(primitive_type, output_dims); + + // Convolution parameters. + auto options = MakeConvolutionOptions(); + auto dnums = MakeConvolutionDimensionNumbers(convolution_rank); + auto window = MakeWindow(convolution_rank); + + // Create thunk. + return ConvolutionThunk::Create( + {"convolution"}, options, std::move(input_slice), input_shape, + std::move(kernel_slice), kernel_shape, std::move(output_slice), + output_shape, dnums, window, + /*feature_group_count=*/1); + } + + // Get execution parameters for the last created thunk. 
+ auto GetExecutionParams() { + return Thunk::ExecuteParams{nullptr, allocations_.get()}; + } + + private: + std::vector input_; + std::vector kernel_; + std::vector output_; + std::vector buffers_; + + // Unique pointers, because they are created only when needed. + std::unique_ptr allocations_; + std::unique_ptr input_alloc_; + std::unique_ptr kernel_alloc_; + std::unique_ptr output_alloc_; +}; + +template +void SuccessfulConvolution(int convolution_rank) { + ConvolutionThunkBuilder builder; + TF_ASSERT_OK_AND_ASSIGN(auto thunk, builder.Build(convolution_rank)) + + // Execute thunk and wait for completion. + Thunk::ExecuteParams params = builder.GetExecutionParams(); + auto execute_event = thunk->Execute(params); + tsl::BlockUntilReady(execute_event); + + // Verify that the execution was successful. + // NOTE: We don't check the correctness of the output here, just that it + // executes without errors. Numerics is verified in generic convolution tests. + ASSERT_FALSE(execute_event.IsError()) << execute_event.GetError(); +} + +TYPED_TEST(ConvolutionThunkTypedTest, SuccessfulConvolution1D) { + SuccessfulConvolution(/*convolution_rank=*/1); +} + +TYPED_TEST(ConvolutionThunkTypedTest, SuccessfulConvolution2D) { + SuccessfulConvolution(/*convolution_rank=*/2); +} + +TYPED_TEST(ConvolutionThunkTypedTest, SuccessfulConvolution3D) { + SuccessfulConvolution(/*convolution_rank=*/3); +} + +TEST(ConvolutionThunkTest, CreationErrorOnUnsupportedType) { + ConvolutionThunkBuilder builder; + + auto status_or_thunk = builder.Build(/*convolution_rank=*/2); + EXPECT_EQ(status_or_thunk.status().code(), + absl::StatusCode::kInvalidArgument); + EXPECT_THAT(status_or_thunk.status().message(), + ::testing::HasSubstr("Unsupported element type (S32)")); +} + +TEST(ConvolutionThunkTest, CreationErrorOnIncorrectConvolutionRank) { + ConvolutionThunkBuilder builder; + + auto status_or_thunk = builder.Build(/*convolution_rank=*/4); + EXPECT_EQ(status_or_thunk.status().code(), + 
absl::StatusCode::kInvalidArgument); + EXPECT_THAT(status_or_thunk.status().message(), + ::testing::HasSubstr("Incorrect convolution rank (4)")); +} + +} // namespace +} // namespace xla::cpu diff --git a/third_party/xla/xla/service/cpu/runtime/thunk.cc b/third_party/xla/xla/service/cpu/runtime/thunk.cc index 21b107f1155917..5d3924b283a59f 100644 --- a/third_party/xla/xla/service/cpu/runtime/thunk.cc +++ b/third_party/xla/xla/service/cpu/runtime/thunk.cc @@ -50,6 +50,8 @@ std::string_view Thunk::KindToString(Kind kind) { return "collective-permute"; case Kind::kConditional: return "conditional"; + case Kind::kConvolution: + return "convolution"; case Kind::kCopy: return "copy"; case Kind::kCustomCall: diff --git a/third_party/xla/xla/service/cpu/runtime/thunk.h b/third_party/xla/xla/service/cpu/runtime/thunk.h index 945b2612bd2be4..9a29f13aa79c47 100644 --- a/third_party/xla/xla/service/cpu/runtime/thunk.h +++ b/third_party/xla/xla/service/cpu/runtime/thunk.h @@ -70,6 +70,7 @@ class Thunk { kCollectivePermute, kCopy, kConditional, + kConvolution, kCustomCall, kDot, kFft, diff --git a/third_party/xla/xla/service/cpu/thunk_emitter.cc b/third_party/xla/xla/service/cpu/thunk_emitter.cc index fed32a0cc7fe96..6d20fb2a1c1184 100644 --- a/third_party/xla/xla/service/cpu/thunk_emitter.cc +++ b/third_party/xla/xla/service/cpu/thunk_emitter.cc @@ -32,9 +32,11 @@ limitations under the License. #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/ir/hlo_schedule.h" +#include "xla/layout_util.h" #include "xla/service/buffer_assignment.h" #include "xla/service/collective_ops_utils.h" #include "xla/service/cpu/dot_op_emitter.h" +#include "xla/service/cpu/ir_emission_utils.h" #include "xla/service/cpu/ir_emitter2.h" #include "xla/service/cpu/runtime/all_gather_thunk.h" #include "xla/service/cpu/runtime/all_reduce_thunk.h" @@ -43,6 +45,7 @@ limitations under the License. 
#include "xla/service/cpu/runtime/collective_permute_thunk.h" #include "xla/service/cpu/runtime/collective_thunk.h" #include "xla/service/cpu/runtime/conditional_thunk.h" +#include "xla/service/cpu/runtime/convolution_thunk.h" #include "xla/service/cpu/runtime/copy_thunk.h" #include "xla/service/cpu/runtime/custom_call_thunk.h" #include "xla/service/cpu/runtime/dot_thunk.h" @@ -253,6 +256,9 @@ absl::StatusOr ThunkEmitter::EmitHloInstruction( case HloOpcode::kOutfeed: return EmitOutfeedThunk(instruction); + case HloOpcode::kConvolution: + return EmitConvolutionThunk(instruction); + case HloOpcode::kCopy: return EmitCopyThunk(instruction); @@ -448,6 +454,58 @@ absl::StatusOr ThunkEmitter::EmitConcatenateThunk( return EmitElementalKernelThunk(instruction); } +absl::StatusOr ThunkEmitter::EmitConvolutionThunk( + const HloInstruction* instruction) { + // NOTE: The following code (along with TODOs and comments) partially + // duplicates IrEmitter::HandleConvolution. This duplication is temporary, + // as IrEmitter will be removed when we switch to thunks runtime. + const HloInstruction* input = instruction->operand(0); + const HloInstruction* kernel = instruction->operand(1); + TF_RETURN_IF_ERROR(ElementTypesSameAndSupported( + /*instruction=*/*instruction, /*operands=*/{input, kernel}, + /*supported_types=*/ + {PRED, S8, U8, S16, U16, S32, U32, S64, U64, F16, F32, F64, C64, C128})); + + // TODO(tonywy): Add PotentiallyImplementedAsMKLConvolution to support + // different data layouts. + if (PotentiallyImplementedAsEigenConvolution(*instruction, + target_machine_features_)) { + const Shape& input_shape = input->shape(); + const Shape& kernel_shape = kernel->shape(); + const Shape& output_shape = instruction->shape(); + + // The input, kernel and output agree with respect to layout. 
+ if (LayoutUtil::IsMonotonicWithDim0Major(input_shape.layout()) && + LayoutUtil::IsMonotonicWithDim0Major(kernel_shape.layout()) && + LayoutUtil::IsMonotonicWithDim0Major(output_shape.layout())) { + TF_ASSIGN_OR_RETURN(auto input_buffer, GetAllocationSlice(input)); + + TF_ASSIGN_OR_RETURN(auto kernel_buffer, GetAllocationSlice(kernel)); + + TF_ASSIGN_OR_RETURN(auto output_buffer, GetAllocationSlice(instruction)); + + ConvolutionThunk::Options options; + options.multi_threaded = + hlo_module_config_.debug_options().xla_cpu_multi_thread_eigen(); + options.use_acl = hlo_module_config_.debug_options().xla_cpu_use_acl(); + return ThunkSequence::Of( + ThunkInfo(instruction), options, input_buffer, input_shape, + kernel_buffer, kernel_shape, output_buffer, output_shape, + instruction->convolution_dimension_numbers(), instruction->window(), + instruction->feature_group_count()); + } + } + + // This is a completely un-optimized version of convolution just to + // have an early version that works. E.g. the input index and + // padding calculation is not hoisted out of the inner loop. + // + // See the description of convolution in the XLA documentation for the pseudo + // code for convolution. 
+ VLOG(2) << "Falling back to unoptimized convolution: " << instruction->name(); + return EmitElementalKernelThunk(instruction); +} + absl::StatusOr ThunkEmitter::EmitCopyThunk( const HloInstruction* instruction) { const HloInstruction* source = instruction->operand(0); diff --git a/third_party/xla/xla/service/cpu/thunk_emitter.h b/third_party/xla/xla/service/cpu/thunk_emitter.h index 8f97169a932195..adaab639b3d397 100644 --- a/third_party/xla/xla/service/cpu/thunk_emitter.h +++ b/third_party/xla/xla/service/cpu/thunk_emitter.h @@ -73,6 +73,9 @@ class ThunkEmitter { absl::StatusOr EmitConcatenateThunk( const HloInstruction* instruction); + absl::StatusOr EmitConvolutionThunk( + const HloInstruction* instruction); + absl::StatusOr EmitCopyThunk( const HloInstruction* instruction); diff --git a/third_party/xla/xla/tests/BUILD b/third_party/xla/xla/tests/BUILD index fd8ae35b30ca5c..2bdfe02297d678 100644 --- a/third_party/xla/xla/tests/BUILD +++ b/third_party/xla/xla/tests/BUILD @@ -1177,6 +1177,7 @@ xla_test( "optonly", # Timed out on 2020-07-18 "nozapfhahn", + "test_xla_cpu_thunks", ], deps = CONVOLUTION_TEST_DEPS + [ "@com_google_absl//absl/memory", diff --git a/third_party/xla/xla/tests/convolution_test.cc b/third_party/xla/xla/tests/convolution_test.cc index 6b0e6df54f841a..56703df13727da 100644 --- a/third_party/xla/xla/tests/convolution_test.cc +++ b/third_party/xla/xla/tests/convolution_test.cc @@ -375,6 +375,78 @@ class Convolve2D_1x3x3x5_3x3x5x3_Valid : public ConvolutionTest { TYPED_TEST_CASE(Convolve2D_1x3x3x5_3x3x5x3_Valid, TestTypes); TYPED_TEST(Convolve2D_1x3x3x5_3x3x5x3_Valid, Types) { this->RunTest(); } +// Test same padding for 2D convolution with kernel of such size, that every +// single pad value is different (low and high, in x and y dimension). +// Intention of this test is to verify that padding is implemented correctly. 
+template +class Convolve2D_1x6x6x1_6x2x1x1_Same : public ConvolutionTest { + public: + void RunTest() { + XlaBuilder builder(TestName()); + std::vector input_dims = {1, 6, 6, 1}; + std::vector filter_dims = {6, 2, 1, 1}; + Shape input_shape = ShapeUtil::MakeShapeWithType(input_dims); + Shape filter_shape = ShapeUtil::MakeShapeWithType(filter_dims); + { + auto input = Parameter(&builder, 0, input_shape, "input"); + auto filter = Parameter(&builder, 1, filter_shape, "filter"); + + // Tensorflow dimension numbers for 2D convolution. + ConvolutionDimensionNumbers dnums; + dnums.set_input_batch_dimension(0); + dnums.add_input_spatial_dimensions(1); + dnums.add_input_spatial_dimensions(2); + dnums.set_input_feature_dimension(3); + dnums.add_kernel_spatial_dimensions(0); + dnums.add_kernel_spatial_dimensions(1); + dnums.set_kernel_input_feature_dimension(2); + dnums.set_kernel_output_feature_dimension(3); + dnums.set_output_batch_dimension(0); + dnums.add_output_spatial_dimensions(1); + dnums.add_output_spatial_dimensions(2); + dnums.set_output_feature_dimension(3); + + ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kSame, dnums); + } + + std::vector input_elems(ShapeUtil::ElementsIn(input_shape)); + iota_int_init_value(input_elems, 1); + auto input_r1 = LiteralUtil::CreateR1(input_elems); + auto input_r4 = input_r1.Reshape(input_dims).value(); + + std::vector filter_elems(ShapeUtil::ElementsIn(filter_shape)); + iota_int_init_value(filter_elems, 1); + auto filter_r1 = LiteralUtil::CreateR1(filter_elems); + auto filter_r4 = filter_r1.Reshape(filter_dims).value(); + + auto expected_r1 = LiteralUtil::CreateR1( + {static_cast(836), static_cast(904), static_cast(972), + static_cast(1040), static_cast(1108), static_cast(540), + static_cast(1255), static_cast(1330), static_cast(1405), + static_cast(1480), static_cast(1555), static_cast(750), + static_cast(1710), static_cast(1788), static_cast(1866), + static_cast(1944), static_cast(2022), static_cast(966), + 
static_cast(1315), static_cast(1370), static_cast(1425), + static_cast(1480), static_cast(1535), static_cast(720), + static_cast(932), static_cast(968), static_cast(1004), + static_cast(1040), static_cast(1076), static_cast(492), + static_cast(585), static_cast(606), static_cast(627), + static_cast(648), static_cast(669), static_cast(294)}); + + auto expected_r4 = expected_r1.Reshape({1, 6, 6, 1}).value(); + + auto input_literal = client_->TransferToServer(input_r4).value(); + auto filter_literal = client_->TransferToServer(filter_r4).value(); + + ComputeAndCompareLiteral(&builder, expected_r4, + {input_literal.get(), filter_literal.get()}, + error_spec_); + } +}; + +TYPED_TEST_CASE(Convolve2D_1x6x6x1_6x2x1x1_Same, TestTypes); +TYPED_TEST(Convolve2D_1x6x6x1_6x2x1x1_Same, Types) { this->RunTest(); } + template class Convolve1D_1x3x5_3x5x3_Valid : public ConvolutionTest { public: From efbe0865742d9e854157fdf840ac39e9db4b3c04 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 21 Jun 2024 02:02:16 -0700 Subject: [PATCH 102/256] compat: Update forward compatibility horizon to 2024-06-21 PiperOrigin-RevId: 645314841 --- tensorflow/python/compat/compat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py index 5a4b90b703f387..d11ceda88f94e8 100644 --- a/tensorflow/python/compat/compat.py +++ b/tensorflow/python/compat/compat.py @@ -29,7 +29,7 @@ # This value changes every day with an automatic CL. It can be modified in code # via `forward_compatibility_horizon()` or with the environment variable # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date. 
-_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2024, 6, 20) +_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2024, 6, 21) _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS" _FORWARD_COMPATIBILITY_DATE_NUMBER = None From 9d03768692f59f7548edf152c8ae17392c0c1907 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 21 Jun 2024 02:02:38 -0700 Subject: [PATCH 103/256] Update GraphDef version to 1900. PiperOrigin-RevId: 645314973 --- tensorflow/core/public/version.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h index be635a1e132f54..eaab911a44b57f 100644 --- a/tensorflow/core/public/version.h +++ b/tensorflow/core/public/version.h @@ -108,7 +108,7 @@ limitations under the License. #define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0 -#define TF_GRAPH_DEF_VERSION 1899 // Updated: 2024/6/20 +#define TF_GRAPH_DEF_VERSION 1900 // Updated: 2024/6/21 // Checkpoint compatibility versions (the versions field in SavedSliceMeta). // From 87c79a19c9f92d4b8bef414d0cd10661c4c30fb6 Mon Sep 17 00:00:00 2001 From: Mohammed Anany Date: Fri, 21 Jun 2024 02:37:56 -0700 Subject: [PATCH 104/256] [XLA:GPU] Add HloFindAll to hlo_traversal to find all nodes matching a particular condition. 
PiperOrigin-RevId: 645323692 --- .../xla/xla/service/gpu/hlo_traversal.cc | 32 +++++++++++-- .../xla/xla/service/gpu/hlo_traversal.h | 9 ++++ .../xla/xla/service/gpu/hlo_traversal_test.cc | 46 +++++++++++++++++++ 3 files changed, 83 insertions(+), 4 deletions(-) diff --git a/third_party/xla/xla/service/gpu/hlo_traversal.cc b/third_party/xla/xla/service/gpu/hlo_traversal.cc index bcc70eb2729878..7bfdbdbeabf03d 100644 --- a/third_party/xla/xla/service/gpu/hlo_traversal.cc +++ b/third_party/xla/xla/service/gpu/hlo_traversal.cc @@ -568,10 +568,11 @@ std::optional HloFindIf( return result; } -std::optional HloFindIf( +std::vector HloFindAllImpl( absl::Span roots, const std::function& visit, - bool visit_operands) { + bool visit_operands, bool find_first_only = false) { + std::vector result; absl::flat_hash_set visited; std::queue q; auto enqueue = [&](const HloInstruction* node) { @@ -598,11 +599,34 @@ std::optional HloFindIf( const HloInstruction* node = q.front(); q.pop(); if (visit(node)) { - return node; + result.push_back(node); + if (find_first_only) { + return result; + } } enqueue(node); } - return std::nullopt; + return result; +} + +std::optional HloFindIf( + absl::Span roots, + const std::function& visit, + bool visit_operands) { + auto result = HloFindAllImpl(roots, visit, visit_operands, + /*find_first_only=*/true); + if (result.empty()) { + return std::nullopt; + } + return result[0]; +} + +std::vector HloFindAll( + absl::Span roots, + const std::function& visit, + bool visit_operands) { + std::vector result; + return HloFindAllImpl(roots, visit, visit_operands); } std::vector HloFindUseChain(HloInstructionAdaptor parent, diff --git a/third_party/xla/xla/service/gpu/hlo_traversal.h b/third_party/xla/xla/service/gpu/hlo_traversal.h index bcaff4f4f13ae9..bee9d06af8cd42 100644 --- a/third_party/xla/xla/service/gpu/hlo_traversal.h +++ b/third_party/xla/xla/service/gpu/hlo_traversal.h @@ -189,6 +189,15 @@ std::optional HloFindIf( const std::function& visit, 
bool visit_operands = true); +// Visit the HLO nodes starting from `roots`. If `visit_operands` is true, the +// search is going towards the operands, otherwise towards the users. Returns +// all nodes for which `visit` returns true. If no node matches, returns an +// empty vector. +std::vector HloFindAll( + absl::Span roots, + const std::function& visit, + bool visit_operands = true); + // Find a use chain from `parent` to `root`. Empty if no chain exists. // `[parent]` if `parent` is `root`. std::vector HloFindUseChain(HloInstructionAdaptor parent, diff --git a/third_party/xla/xla/service/gpu/hlo_traversal_test.cc b/third_party/xla/xla/service/gpu/hlo_traversal_test.cc index 4d03a38f5da20f..d8168d2687b3d7 100644 --- a/third_party/xla/xla/service/gpu/hlo_traversal_test.cc +++ b/third_party/xla/xla/service/gpu/hlo_traversal_test.cc @@ -269,6 +269,52 @@ TEST_F(HloTraversalTest, NotFound) { ASSERT_EQ(result, std::nullopt); } +TEST_F(HloTraversalTest, FindAllMultiple) { + const char kConverts[] = R"( + HloModule test + + ENTRY entry { + p0 = s8[128] parameter(0) + p1 = pred[128] parameter(1) + p1c = s8[128] convert(p1) + p1c1 = f16[128] convert(p1c) + p0c = f16[128] convert(p0) + ROOT diff = f16[128] subtract(p0c, p1c1) + })"; + + auto module = ParseAndReturnVerifiedModule(kConverts).value(); + auto root = module->entry_computation()->GetInstructionWithName("diff"); + std::vector converts = + HloFindAll({root}, [&](const HloInstruction* node) { + return node->opcode() == HloOpcode::kConvert; + }); + + auto get = [&](absl::string_view name) { + return module->entry_computation()->GetInstructionWithName(name); + }; + + EXPECT_THAT(converts, ElementsAre(get("p0c"), get("p1c1"), get("p1c"))); +} + +TEST_F(HloTraversalTest, FindAllNotFound) { + const char kConverts[] = R"( + HloModule test + + ENTRY entry { + p0 = s8[128] parameter(0) + p1 = f16[128] parameter(1) + p0c = f16[128] convert(p0) + ROOT diff = f16[128] subtract(p0c, p1) + })"; + auto module = 
ParseAndReturnVerifiedModule(kConverts).value(); + auto root = module->entry_computation()->GetInstructionWithName("diff"); + std::vector converts = + HloFindAll({root}, [&](const HloInstruction* node) { + return node->opcode() == HloOpcode::kAdd; + }); + EXPECT_THAT(converts, IsEmpty()); +} + const char kTwoFusions[] = R"( HloModule test From c38900629399a45394d34c75f9989098e8e8694f Mon Sep 17 00:00:00 2001 From: Johannes Reifferscheid Date: Fri, 21 Jun 2024 02:38:11 -0700 Subject: [PATCH 105/256] Unify constant folding for affine expressions. Currently, we have separate constant folders for division (two, actually!) and symbols/dimensions for some reason. We can just handle everything in one place. PiperOrigin-RevId: 645323739 --- .../xla/xla/service/gpu/model/indexing_map.cc | 36 ++++++------------- 1 file changed, 11 insertions(+), 25 deletions(-) diff --git a/third_party/xla/xla/service/gpu/model/indexing_map.cc b/third_party/xla/xla/service/gpu/model/indexing_map.cc index 16e10ac756e05b..e3ec255cc4d978 100644 --- a/third_party/xla/xla/service/gpu/model/indexing_map.cc +++ b/third_party/xla/xla/service/gpu/model/indexing_map.cc @@ -228,21 +228,18 @@ AffineExpr AffineExprSimplifier::RewriteMod(AffineBinaryOpExpr mod) { AffineExpr AffineExprSimplifier::RewriteFloorDiv(AffineBinaryOpExpr div) { auto mlir_context = range_evaluator_->GetMLIRContext(); - auto lhs_simplified = SimplifyOnce(div.getLHS()); - auto lhs = range_evaluator_->ComputeExpressionRange(lhs_simplified); - auto rhs = range_evaluator_->ComputeExpressionRange(div.getRHS()); + auto rhs_range = range_evaluator_->ComputeExpressionRange(div.getRHS()); // TODO(jreiffers): Split this function into multiple (one for each rewrite // rule). - if (0 <= lhs.lower && lhs.upper < rhs.lower) { - return getAffineConstantExpr(0, mlir_context); - } // The logic below assumes we have a constant RHS. 
- if (!rhs.IsPoint()) { + if (!rhs_range.IsPoint()) { return div; } - int64_t d = rhs.lower; + int64_t d = rhs_range.lower; + + auto lhs_simplified = SimplifyOnce(div.getLHS()); // Rewrite `(c % ab) // a` to `(c // a) % b`. // (c % ab) // a @@ -255,13 +252,6 @@ AffineExpr AffineExprSimplifier::RewriteFloorDiv(AffineBinaryOpExpr div) { return GetLhs(lhs_simplified).floorDiv(d) % (*mod / d); } - // If the dividend's range has a single element, return its value. - int64_t a = FloorDiv(lhs.lower, d); - int64_t b = FloorDiv(lhs.upper, d); - if (a == b) { - return getAffineConstantExpr(a, mlir_context); - } - // Rewrite `(a / b) / c` to `a / (b * c)` if `a >= 0` and `b` and `c` are // constants. if (lhs_simplified.getKind() == AffineExprKind::FloorDiv) { @@ -432,6 +422,12 @@ AffineExpr CanonicalizeOrder(AffineExpr in) { } AffineExpr AffineExprSimplifier::SimplifyOnce(AffineExpr expr) { + auto bounds = range_evaluator_->ComputeExpressionRange(expr); + if (bounds.IsPoint()) { + return getAffineConstantExpr(bounds.lower, + range_evaluator_->GetMLIRContext()); + } + switch (expr.getKind()) { case AffineExprKind::Mul: { auto lhs = SimplifyOnce(GetLhs(expr)); @@ -525,16 +521,6 @@ AffineExpr AffineExprSimplifier::SimplifyOnce(AffineExpr expr) { return RewriteMod(mlir::cast(expr)); case AffineExprKind::FloorDiv: return RewriteFloorDiv(mlir::cast(expr)); - case AffineExprKind::DimId: - case AffineExprKind::SymbolId: { - auto bounds = range_evaluator_->ComputeExpressionRange(expr); - if (bounds.IsPoint()) { - return getAffineConstantExpr(bounds.lower, - range_evaluator_->GetMLIRContext()); - } - return expr; - } - default: return expr; } From b45bc2e7a1970b02dce76e4f603699e5720ca145 Mon Sep 17 00:00:00 2001 From: Alexander Lyashuk Date: Fri, 21 Jun 2024 02:53:19 -0700 Subject: [PATCH 106/256] [XLA:GPU] Fix broken build In the open source version of protobuf, std::string_view is not accepted as a parameter of a string field setter. 
PiperOrigin-RevId: 645326774 --- .../xla/xla/service/gpu/gemm_fusion_autotuner_test.cc | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/third_party/xla/xla/service/gpu/gemm_fusion_autotuner_test.cc b/third_party/xla/xla/service/gpu/gemm_fusion_autotuner_test.cc index 2168bc78711acd..fa75448abf42db 100644 --- a/third_party/xla/xla/service/gpu/gemm_fusion_autotuner_test.cc +++ b/third_party/xla/xla/service/gpu/gemm_fusion_autotuner_test.cc @@ -600,8 +600,9 @@ ENTRY main { } })pb")); autotune_results_override.mutable_results(0)->set_device( - cache_key.GetModelStr()); - autotune_results_override.mutable_results(0)->set_hlo(cache_key.GetHlo()); + std::string(cache_key.GetModelStr())); + autotune_results_override.mutable_results(0)->set_hlo( + std::string(cache_key.GetHlo())); CHECK_OK(AutotunerUtil::LoadAutotuneResults(autotune_results_override)); HloPassPipeline pipeline("gemm_autotune"); From 8d71548cf5698329c04e11010319e1303b1cab13 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Paruzel?= Date: Fri, 21 Jun 2024 02:59:56 -0700 Subject: [PATCH 107/256] Introduce nested tuple support in FFI This change allows passing (nested) tuples to FFI. Tuples are flattened acting as a logical buffer separator which enables them to be treated as flat buffer arguments. All tuple arrangements are supported as long as the underlying bare buffers are valid for input and output. 
PiperOrigin-RevId: 645328206 --- third_party/xla/xla/service/cpu/BUILD | 2 + third_party/xla/xla/service/cpu/ir_emitter.cc | 245 ++++++++++++----- third_party/xla/xla/service/cpu/ir_emitter.h | 7 +- .../service/cpu/runtime_handle_ffi_call.cc | 80 ++---- .../xla/service/cpu/runtime_handle_ffi_call.h | 2 +- third_party/xla/xla/shape_util.h | 48 +++- third_party/xla/xla/tests/custom_call_test.cc | 253 +++++++++++++++++- 7 files changed, 496 insertions(+), 141 deletions(-) diff --git a/third_party/xla/xla/service/cpu/BUILD b/third_party/xla/xla/service/cpu/BUILD index 8fbe954e40d016..436535e1c405ac 100644 --- a/third_party/xla/xla/service/cpu/BUILD +++ b/third_party/xla/xla/service/cpu/BUILD @@ -733,6 +733,7 @@ cc_library( "@local_tsl//tsl/lib/math:math_util", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:statusor", ], ) @@ -1328,6 +1329,7 @@ cc_library( "@com_google_absl//absl/types:span", "@llvm-project//mlir:AsmParser", "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:statusor", ], diff --git a/third_party/xla/xla/service/cpu/ir_emitter.cc b/third_party/xla/xla/service/cpu/ir_emitter.cc index 40498df73a6b39..802cd9e14a4d6d 100644 --- a/third_party/xla/xla/service/cpu/ir_emitter.cc +++ b/third_party/xla/xla/service/cpu/ir_emitter.cc @@ -87,6 +87,7 @@ limitations under the License. 
#include "tsl/lib/math/math_util.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" +#include "tsl/platform/statusor.h" #if defined(INTEL_MKL) && defined(ENABLE_ONEDNN_V3) #include "xla/service/cpu/onednn_memory_util.h" @@ -2836,48 +2837,71 @@ absl::Status IrEmitter::HandleCustomCall(HloInstruction* custom_call) { } #endif // INTEL_MKL && ENABLE_ONEDNN_V3 absl::Span operands(custom_call->operands()); - llvm::AllocaInst* operands_alloca = - llvm_ir::EmitAllocaAtFunctionEntryWithCount(b_.getPtrTy(), - b_.getInt32(operands.size()), - "cc_operands_alloca", &b_); - for (size_t i = 0; i < operands.size(); ++i) { - llvm::Value* slot_in_operands_alloca = InBoundsGEP( - operands_alloca->getAllocatedType(), operands_alloca, {b_.getInt64(i)}); - Store(GetEmittedValueFor(operands[i]), slot_in_operands_alloca); + auto typed_custom_call = Cast(custom_call); + auto is_typed_ffi = typed_custom_call->api_version() == + CustomCallApiVersion::API_VERSION_TYPED_FFI; + std::vector operand_values; + operand_values.reserve(operands.size()); + + for (int64_t i = 0; i < operands.size(); ++i) { + HloInstruction* operand = operands[i]; + if (is_typed_ffi) { + // Emit nested tuples as flat buffer pointers + TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( + operand->shape(), [&](const Shape& shape, const ShapeIndex& index) { + if (!shape.IsArray()) { + return absl::OkStatus(); + } + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice slice, + assignment_.GetUniqueSlice(operand, index)); + operand_values.push_back(EmitBufferPointer(slice, shape)); + return absl::OkStatus(); + })); + } else { + operand_values.push_back(GetEmittedValueFor(operand)); + } } + llvm::AllocaInst* operands_alloca = + llvm_ir::EmitAllocaAtFunctionEntryWithCount( + b_.getPtrTy(), b_.getInt32(operand_values.size()), + "cc_operands_alloca", &b_); if (emit_code_for_msan_) { // Mark the alloca as initialized for msan. The buffer gets read by the // custom callee, which might be msan-instrumented. 
// TODO(b/66051036): Run the msan instrumentation pass instead. const llvm::DataLayout& dl = module_->getDataLayout(); llvm::Type* intptr_type = b_.getIntPtrTy(dl); - EmitCallToFunc( - "__msan_unpoison", - {operands_alloca, - llvm::ConstantInt::get( - intptr_type, *operands_alloca->getAllocationSizeInBits(dl) / 8)}, - b_.getVoidTy()); + EmitCallToFunc("__msan_unpoison", + {operands_alloca, + llvm::ConstantInt::get( + intptr_type, *operands_alloca->getAllocationSize(dl))}, + b_.getVoidTy()); + } + for (int64_t i = 0; i < operand_values.size(); ++i) { + llvm::Value* slot_in_operands_alloca = InBoundsGEP( + operands_alloca->getAllocatedType(), operands_alloca, {b_.getInt64(i)}); + Store(operand_values[i], slot_in_operands_alloca); } TF_RETURN_IF_ERROR(EmitTargetAddressForOp(custom_call)); // Write the tuple table if the output is a tuple. + std::vector tuple_ptrs; if (custom_call->shape().IsTuple()) { - std::vector base_ptrs; for (int i = 0; i < ShapeUtil::TupleElementCount(custom_call->shape()); ++i) { const Shape& elem_shape = ShapeUtil::GetTupleElementShape(custom_call->shape(), i); - TF_RET_CHECK(!elem_shape.IsTuple()) << "Nested tuples not implemented"; + if (!is_typed_ffi) { + TF_RET_CHECK(!elem_shape.IsTuple()) << "Nested tuples not implemented"; + } TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice slice, assignment_.GetUniqueSlice(custom_call, {i})); - llvm::Value* addr = EmitBufferPointer(slice, elem_shape); - base_ptrs.push_back(addr); + tuple_ptrs.push_back(EmitBufferPointer(slice, elem_shape)); } - llvm_ir::EmitTuple(GetIrArrayFor(custom_call), base_ptrs, &b_); + llvm_ir::EmitTuple(GetIrArrayFor(custom_call), tuple_ptrs, &b_); } auto* output_address = GetEmittedValueFor(custom_call); - auto typed_custom_call = Cast(custom_call); switch (typed_custom_call->api_version()) { case CustomCallApiVersion::API_VERSION_ORIGINAL: EmitCallToFunc(custom_call->custom_call_target(), @@ -2900,7 +2924,45 @@ absl::Status IrEmitter::HandleCustomCall(HloInstruction* 
custom_call) { break; } case CustomCallApiVersion::API_VERSION_TYPED_FFI: { - EmitCallToFfi(typed_custom_call, output_address, operands_alloca); + // Flatten into raw buffers to avoid (nested) tuples + std::vector buffer_ptrs; + if (custom_call->shape().IsTuple()) { + buffer_ptrs.reserve(ShapeUtil::TupleElementCount(custom_call->shape())); + } + TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( + custom_call->shape(), + [&](const Shape& shape, const ShapeIndex& index) { + if (!shape.IsArray()) { + return absl::OkStatus(); + } + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice slice, + assignment_.GetUniqueSlice(custom_call, index)); + buffer_ptrs.push_back(EmitBufferPointer(slice, shape)); + return absl::OkStatus(); + })); + llvm::AllocaInst* results_alloca = + llvm_ir::EmitAllocaAtFunctionEntryWithCount( + b_.getPtrTy(), b_.getInt32(buffer_ptrs.size()), + "ffi_results_alloca", &b_); + if (emit_code_for_msan_) { + // Mark the alloca as initialized for msan + // TODO(b/66051036): Run the msan instrumentation pass instead. 
+ const llvm::DataLayout& dl = module_->getDataLayout(); + llvm::Type* intptr_type = b_.getIntPtrTy(dl); + EmitCallToFunc( + "__msan_unpoison", + {results_alloca, + llvm::ConstantInt::get(intptr_type, + *results_alloca->getAllocationSize(dl))}, + b_.getVoidTy()); + } + for (int i = 0; i < buffer_ptrs.size(); ++i) { + llvm::Value* tuple_slot_in_results_alloca = + InBoundsGEP(results_alloca->getAllocatedType(), results_alloca, + {b_.getInt64(i)}); + Store(buffer_ptrs[i], tuple_slot_in_results_alloca); + } + EmitCallToFfi(typed_custom_call, results_alloca, operands_alloca); EmitEarlyReturnIfErrorStatus(); break; } @@ -3159,24 +3221,45 @@ static const Shape& GetShape(T&& arg) { } }; -template -llvm::AllocaInst* IrEmitter::StoreTypes(std::string_view alloca_name, - T&& args) { - auto* types_alloca = llvm_ir::EmitAllocaAtFunctionEntryWithCount( - b_.getInt32Ty(), b_.getInt64(args.size()), alloca_name, &b_); +struct EncodedInfo { + llvm::AllocaInst* alloca; + int64_t size; +}; + +template +static EncodedInfo StoreEncodedTypes(std::string_view alloca_name, + const Args& args, llvm::IRBuilder<>& ir) { + // Store the types of `args` into the allocated memory. These types are stored + // as int32_t values contiguously. All tuples are flattened to bare elements. 
+ int64_t total_elements = 0; + for (int64_t i = 0; i < args.size(); ++i) { + total_elements += ShapeUtil::GetLeafCount(GetShape(args[i])); + } + llvm::AllocaInst* types_alloca = llvm_ir::EmitAllocaAtFunctionEntryWithCount( + ir.getInt32Ty(), ir.getInt64(total_elements), alloca_name, &ir); + int64_t element_id = 0; + auto store_type = [&](const Shape& shape, const ShapeIndex& index) { + if (shape.IsTuple()) { + return; + } + llvm::Value* slot_in_types_alloca = ir.CreateConstInBoundsGEP1_32( + ir.getInt32Ty(), types_alloca, element_id++); + ir.CreateStore(ir.getInt32(shape.element_type()), slot_in_types_alloca); + }; for (int64_t i = 0; i < args.size(); ++i) { - llvm::Value* slot_in_types_alloca = - ConstInBoundsGEP1_32(b_.getInt32Ty(), types_alloca, i); - Store(b_.getInt32(GetShape(args[i]).element_type()), slot_in_types_alloca); + ShapeUtil::ForEachSubshape(GetShape(args[i]), store_type); } - return types_alloca; + CHECK_EQ(element_id, total_elements); + return {types_alloca, total_elements}; }; -template -llvm::Value* IrEmitter::StoreShapes(std::string_view alloca_name, T&& args) { - // Prepare metadata for all buffers - // Shapes metadata is encoded using contiguous flattened dimension values: +template +static EncodedInfo StoreEncodedShapes(std::string_view alloca_name, + const Args& args, llvm::IRBuilder<>& ir) { + // Prepare metadata for all buffers. A tuple shape is flattened to only encode + // information about its elements (buffers). 
Shapes metadata is encoded using + // contiguous flattened dimension values: // { // 1: DIMCOUNT_1, DIM_1[1], DIM_1[2], ..., DIM_1[DIMCOUNT_1], // \______________DIMCOUNT_1 _______________/ @@ -3187,47 +3270,65 @@ llvm::Value* IrEmitter::StoreShapes(std::string_view alloca_name, T&& args) { // \______________DIMCOUNT_N _______________/ // } // where N is `operand_count`, and `DIMCOUNT_i` is the # of dimensions - std::size_t total_dims = - absl::c_accumulate(args, int64_t{0}, [](int64_t acc, auto&& arg) { - return acc + GetShape(arg).dimensions().size(); - }); - int64_t encoded_shapes_size = args.size() // the dimension count identifiers - + total_dims; // the # of dimension values + int64_t total_dims = 0; + int64_t total_dim_counts = 0; + for (int64_t i = 0; i < args.size(); ++i) { + ShapeUtil::ForEachSubshape( + GetShape(args[i]), [&](const Shape& shape, const ShapeIndex& index) { + if (!shape.IsArray()) { + return; + } + total_dims += shape.dimensions().size(); + ++total_dim_counts; + }); + } + int64_t shapes_encoding_size = total_dim_counts // the # of dimension counts + + total_dims; // the # of dimension values - llvm::Value* shapes_alloca = llvm_ir::EmitAllocaAtFunctionEntryWithCount( - b_.getInt64Ty(), b_.getInt64(encoded_shapes_size), alloca_name, &b_); + llvm::AllocaInst* shapes_alloca = llvm_ir::EmitAllocaAtFunctionEntryWithCount( + ir.getInt64Ty(), ir.getInt64(shapes_encoding_size), alloca_name, &ir); int64_t slot_id = 0; - for (int64_t i = 0; i < args.size(); ++i) { - auto dims = GetShape(args[i]).dimensions(); - llvm::Value* alloca_slot = - ConstInBoundsGEP1_64(b_.getInt64Ty(), shapes_alloca, slot_id++); + auto store_shape = [&](const Shape& shape, const ShapeIndex& index) { + if (!shape.IsArray()) { + return; + } + llvm::Value* alloca_slot = ir.CreateConstInBoundsGEP1_64( + ir.getInt64Ty(), shapes_alloca, slot_id++); // Store the operand count - Store(b_.getInt64(dims.size()), alloca_slot); + ir.CreateStore(ir.getInt64(shape.dimensions().size()), 
alloca_slot); // Store the operand dimensions - for (int64_t dim : dims) { - alloca_slot = - ConstInBoundsGEP1_64(b_.getInt64Ty(), shapes_alloca, slot_id++); - Store(b_.getInt64(dim), alloca_slot); + for (int64_t dim : shape.dimensions()) { + alloca_slot = ir.CreateConstInBoundsGEP1_64(ir.getInt64Ty(), + shapes_alloca, slot_id++); + ir.CreateStore(ir.getInt64(dim), alloca_slot); } + }; + + for (int64_t i = 0; i < args.size(); ++i) { + ShapeUtil::ForEachSubshape(GetShape(args[i]), store_shape); } - CHECK_EQ(slot_id, encoded_shapes_size); // All slots are filled - return shapes_alloca; + CHECK_EQ(slot_id, shapes_encoding_size); // All slots are filled + return {shapes_alloca, shapes_encoding_size}; }; llvm::Value* IrEmitter::EmitCallToFfi(HloCustomCallInstruction* custom_call, - llvm::Value* output_address, + llvm::AllocaInst* results_alloca, llvm::AllocaInst* operands_alloca) { const auto& operands = absl::MakeSpan(custom_call->operands()); const auto& shape = custom_call->shape(); const auto& result_shapes = shape.IsTuple() ? 
shape.tuple_shapes() : std::vector({shape}); - auto operand_types_alloca = StoreTypes("meta_types_operands", operands); - auto operand_shapes_alloca = StoreShapes("meta_shapes_operands", operands); + EncodedInfo operand_types_encoded = + StoreEncodedTypes("operands_types", operands, b_); + EncodedInfo operand_shapes_encoded = + StoreEncodedShapes("operands_shapes", operands, b_); - auto result_types_alloca = StoreTypes("meta_types_results", result_shapes); - auto result_shapes_alloca = StoreShapes("meta_shapes_results", result_shapes); + EncodedInfo result_types_encoded = + StoreEncodedTypes("results_types", result_shapes, b_); + EncodedInfo result_shapes_encoded = + StoreEncodedShapes("results_shapes", result_shapes, b_); const absl::string_view target = custom_call->custom_call_target(); // name const absl::string_view opaque = custom_call->opaque(); @@ -3236,20 +3337,20 @@ llvm::Value* IrEmitter::EmitCallToFfi(HloCustomCallInstruction* custom_call, const auto opaque_ref = llvm_ir::AsStringRef(opaque); std::vector arguments = { - GetExecutableRunOptionsArgument(), // run_options_ptr - b_.CreateGlobalStringPtr(target_ref), // target_name_ptr - b_.getInt64(target.size()), // target_name_len - output_address, // output - operands_alloca, // inputs - b_.CreateGlobalStringPtr(opaque_ref), // opaque_str_ptr - b_.getInt64(opaque.size()), // opaque_str_len - GetStatusArgument(), // status_opaque - operand_types_alloca, // operand_types - b_.getInt64(operands.size()), // operand_count - operand_shapes_alloca, // operand_dims - result_types_alloca, // result_types - b_.getInt64(result_shapes.size()), // result_count - result_shapes_alloca, // result_dims + /*run_options_ptr=*/GetExecutableRunOptionsArgument(), + /*target_name_ptr=*/b_.CreateGlobalStringPtr(target_ref), + /*target_name_len=*/b_.getInt64(target.size()), + /*outputs=*/results_alloca, + /*inputs=*/operands_alloca, + /*opaque_str_ptr=*/b_.CreateGlobalStringPtr(opaque_ref), + 
/*opaque_str_len=*/b_.getInt64(opaque.size()), + /*status_opaque=*/GetStatusArgument(), + /*operand_types=*/operand_types_encoded.alloca, + /*operand_count=*/b_.getInt64(operand_types_encoded.size), + /*operand_dims=*/operand_shapes_encoded.alloca, + /*result_types=*/result_types_encoded.alloca, + /*result_count=*/b_.getInt64(result_types_encoded.size), + /*result_dims=*/result_shapes_encoded.alloca, }; return EmitCallToFunc(runtime::kHandleFfiCallSymbolName, arguments, diff --git a/third_party/xla/xla/service/cpu/ir_emitter.h b/third_party/xla/xla/service/cpu/ir_emitter.h index 2973f0755627c6..9d7b4a65dfd81d 100644 --- a/third_party/xla/xla/service/cpu/ir_emitter.h +++ b/third_party/xla/xla/service/cpu/ir_emitter.h @@ -474,14 +474,9 @@ class IrEmitter : public DfsHloVisitorWithDefault, bool only_accesses_arg_memory = false, bool only_accesses_inaccessible_mem_or_arg_mem = false); - template - llvm::AllocaInst* StoreTypes(std::string_view alloca_name, T&& args); - template - llvm::Value* StoreShapes(std::string_view alloca_name, T&& args); - // Emits a call to a proxy that builds an FFI call frame for `custom_call` llvm::Value* EmitCallToFfi(HloCustomCallInstruction* custom_call, - llvm::Value* output_address, + llvm::AllocaInst* results_alloca, llvm::AllocaInst* operands_alloca); // Assignment of the buffers needed by the computation and their shape diff --git a/third_party/xla/xla/service/cpu/runtime_handle_ffi_call.cc b/third_party/xla/xla/service/cpu/runtime_handle_ffi_call.cc index eb6c7d3f146a6c..d9ccd7f5c6a451 100644 --- a/third_party/xla/xla/service/cpu/runtime_handle_ffi_call.cc +++ b/third_party/xla/xla/service/cpu/runtime_handle_ffi_call.cc @@ -31,6 +31,7 @@ limitations under the License. 
#include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "xla/executable_run_options.h" #include "xla/ffi/attribute_map.h" #include "xla/ffi/call_frame.h" @@ -55,41 +56,28 @@ absl::Span DecodeDims(int64_t* encoded_dims_data) { return absl::MakeSpan(dims_begin, dims_begin + dims_count); } -// TODO(heinsaar): Once on C++20, this can (and should) be a local lambda with -// an explicit template parameter list. -class ArgInserter { - public: - template - explicit ArgInserter(B&& b) : b_(std::forward(b)) {} - - template - void operator()(Args&&... args) const { - b_.AddBufferArg(std::forward(args)...); - } - - private: - ffi::CallFrameBuilder& b_; -}; - -// TODO(heinsaar): Once on C++20, this can (and should) be a local lambda with -// an explicit template parameter list. -class RetInserter { - public: - template - explicit RetInserter(B&& b) : b_(std::forward(b)) {} +void BuildArgBuffers(absl::Span types, int64_t* encoded_dims, + absl::Span address_space, + ffi::CallFrameBuilder& builder) { + int64_t dim_pos = 0; + for (int64_t i = 0; i < types.size(); ++i) { + auto dtype = static_cast(types[i]); + auto dims = DecodeDims(encoded_dims + dim_pos); + auto elem_count = absl::c_accumulate(dims, 1, std::multiplies()); + auto data_width = xla::primitive_util::ByteWidth(dtype) * elem_count; - template - void operator()(Args&&... 
args) const { - b_.AddBufferRet(std::forward(args)...); + builder.AddBufferArg( + tensorflow::se::DeviceMemoryBase(address_space[i], data_width), + /*type = */ dtype, + /*dims = */ dims); + dim_pos += 1; // Jumps over count value + dim_pos += dims.size(); // Jumps over all dimensions in a shape } +} - private: - ffi::CallFrameBuilder& b_; -}; - -template -void BuildBuffers(absl::Span types, int64_t* encoded_dims, - absl::Span address_space, Builder&& builder) { +void BuildRetBuffers(absl::Span types, int64_t* encoded_dims, + absl::Span address_space, + ffi::CallFrameBuilder& builder) { int64_t dim_pos = 0; for (int64_t i = 0; i < types.size(); ++i) { auto dtype = static_cast(types[i]); @@ -97,9 +85,10 @@ void BuildBuffers(absl::Span types, int64_t* encoded_dims, auto elem_count = absl::c_accumulate(dims, 1, std::multiplies()); auto data_width = xla::primitive_util::ByteWidth(dtype) * elem_count; - builder(tensorflow::se::DeviceMemoryBase(address_space[i], data_width), - /*type = */ dtype, - /*dims = */ dims); + builder.AddBufferRet( + tensorflow::se::DeviceMemoryBase(address_space[i], data_width), + /*type = */ dtype, + /*dims = */ dims); dim_pos += 1; // Jumps over count value dim_pos += dims.size(); // Jumps over all dimensions in a shape } @@ -114,14 +103,6 @@ static absl::Status BuildAndCallFfi( CHECK_EQ(outputs.size(), result_types.size()); CHECK_EQ(inputs.size(), operand_types.size()); - if (absl::c_any_of(operand_types, [](int32_t type) { - return static_cast(type) == - xla::PrimitiveType::TUPLE; - })) { - return absl::InternalError( - "Tuple operands are not supported yet in typed FFI custom calls."); - } - // Find the registered FFI handler for this custom call target. absl::StatusOr registration = ffi::FindHandler(target_name, "Host"); @@ -139,7 +120,7 @@ static absl::Status BuildAndCallFfi( // Backend config not empty, so proceed to parse it into an MLIR attribute // and build an MLIR compatible map of attributes out of it. 
mlir::Attribute attr = mlir::parseAttribute(backend_config, &mlir_context); - if (auto dict = attr.dyn_cast_or_null()) { + if (auto dict = mlir::dyn_cast_or_null(attr)) { TF_ASSIGN_OR_RETURN(attributes, xla::ffi::BuildAttributesMap(dict)); } else { return absl::InternalError( @@ -156,8 +137,8 @@ static absl::Status BuildAndCallFfi( builder.AddAttributes(attrs.Build()); // Decode dimensions metadata into shapes and build operand & result buffers - BuildBuffers(operand_types, operand_dims, inputs, ArgInserter(builder)); - BuildBuffers(result_types, result_dims, outputs, RetInserter(builder)); + BuildArgBuffers(operand_types, operand_dims, inputs, builder); + BuildRetBuffers(result_types, result_dims, outputs, builder); // Forward executable run options to the FFI handlers via the call options. ffi::CallOptions call_options = { @@ -173,7 +154,7 @@ static absl::Status BuildAndCallFfi( ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_HandleFfiCall( const void* run_options_ptr, const char* target_name_ptr, - int64_t target_name_len, void* output, void** inputs, + int64_t target_name_len, void** outputs, void** inputs, const char* opaque_str_ptr, int64_t opaque_str_len, void* status_opaque, int32_t* operand_types, int64_t operand_count, int64_t* operand_dims, int32_t* result_types, int64_t result_count, int64_t* result_dims) { @@ -181,11 +162,6 @@ ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_HandleFfiCall( auto backend_config = absl::string_view(opaque_str_ptr, opaque_str_len); auto xla_status = reinterpret_cast(status_opaque); - void** outputs = &output; - if (result_count > 1) { // output is a tuple - outputs = reinterpret_cast(output); - } - // Annotate memory coming from jit compiled function as initialized to // suppress false positives from msan sanitizer. 
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(result_types, diff --git a/third_party/xla/xla/service/cpu/runtime_handle_ffi_call.h b/third_party/xla/xla/service/cpu/runtime_handle_ffi_call.h index 22511ebfa5c8f6..e6517c5a18833c 100644 --- a/third_party/xla/xla/service/cpu/runtime_handle_ffi_call.h +++ b/third_party/xla/xla/service/cpu/runtime_handle_ffi_call.h @@ -22,7 +22,7 @@ extern "C" { extern void __xla_cpu_runtime_HandleFfiCall( const void* run_options_ptr, const char* target_name_ptr, - int64_t target_name_len, void* output, void** inputs, + int64_t target_name_len, void** outputs, void** inputs, const char* opaque_str_ptr, int64_t opaque_str_len, void* status_opaque, int32_t* operand_types, int64_t operand_count, int64_t* operand_dims, int32_t* result_types, int64_t result_count, int64_t* result_dims); diff --git a/third_party/xla/xla/shape_util.h b/third_party/xla/xla/shape_util.h index 35690357dcfb8d..c87931674ee1ac 100644 --- a/third_party/xla/xla/shape_util.h +++ b/third_party/xla/xla/shape_util.h @@ -601,15 +601,51 @@ class ShapeUtil { // // The visitor function must have the signature // + // absl::Status fn(const Shape& subshape, const ShapeIndex& index) + // void fn(Shape* subshape, const ShapeIndex& index) (mutable version) + template + static absl::Status ForEachLeafShapeWithStatus(const Shape& shape, Fn&& fn) { + return ForEachSubshapeWithStatus( + shape, [&](const Shape& subshape, const ShapeIndex& index) { + if (IsLeafIndex(shape, index)) { + TF_RETURN_IF_ERROR(fn(subshape, index)); + } + return absl::OkStatus(); + }); + } + template + static absl::Status ForEachMutableLeafShapeWithStatus(Shape* shape, Fn&& fn) { + return ForEachMutableSubshapeWithStatus( + shape, [&](Shape* subshape, const ShapeIndex& index) { + if (IsLeafIndex(*shape, index)) { + TF_RETURN_IF_ERROR(fn(subshape, index)); + } + return absl::OkStatus(); + }); + } + + // Calls the given visitor function for each leaf subshape of the given shape. 
+ // Subshapes are visited in DFS pre-order starting with the entire shape + // (index {}). + // + // The visitor function must have the signature // void fn(const Shape& subshape, const ShapeIndex& index) + // void fn(Shape* subshape, const ShapeIndex& index) (mutable version) template static void ForEachLeafShape(const Shape& shape, Fn&& fn) { - ForEachSubshape(shape, - [&](const Shape& sub_shape, const ShapeIndex& index) { - if (IsLeafIndex(shape, index)) { - fn(sub_shape, index); - } - }); + ForEachLeafShapeWithStatus(shape, [&](const Shape& subshape, + const ShapeIndex& index) { + fn(subshape, index); + return absl::OkStatus(); + }).IgnoreError(); + } + template + static void ForEachMutableLeafShape(const Shape& shape, Fn&& fn) { + ForEachMutableLeafShapeWithStatus(shape, [&](Shape* subshape, + const ShapeIndex& index) { + fn(subshape, index); + return absl::OkStatus(); + }).IgnoreError(); } // Variants of ForEach(Mutable)Subshape which propagate absl::Status from the diff --git a/third_party/xla/xla/tests/custom_call_test.cc b/third_party/xla/xla/tests/custom_call_test.cc index e52a0d1b85bf4a..1ffc83d80e0da1 100644 --- a/third_party/xla/xla/tests/custom_call_test.cc +++ b/third_party/xla/xla/tests/custom_call_test.cc @@ -787,8 +787,128 @@ XLA_FFI_DEFINE_HANDLER(kVerifyR2Dimensions, VerifyR2Dimensions, XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "__xla_test$$VerifyR2Dimensions", "Host", kVerifyR2Dimensions); +static absl::Status SwapTupleAnyBuffersToS16U32(ffi::AnyBuffer in_1, + ffi::AnyBuffer in_2, + ResultBufferR0 out_1, + ResultBufferR0 out_2) { + auto tuple_elem_1 = reinterpret_cast(in_1.data.opaque()); + auto tuple_elem_2 = reinterpret_cast(in_2.data.opaque()); + out_1->data.base()[0] = tuple_elem_2[0]; + out_2->data.base()[0] = tuple_elem_1[0]; + return absl::OkStatus(); +} + +XLA_FFI_DEFINE_HANDLER(kSwapTupleAnyBuffersToS16U32, + SwapTupleAnyBuffersToS16U32, + ffi::Ffi::Bind() + .Arg() + .Arg() + .Ret>() + .Ret>()); + 
+XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), + "__xla_test$$SwapTupleAnyBuffersToS16U32", "Host", + kSwapTupleAnyBuffersToS16U32); + +static absl::Status SwapTupleU32S16ToS16U32(ffi::BufferR0 in_1, + ffi::BufferR0 in_2, + ResultBufferR0 out_1, + ResultBufferR0 out_2) { + auto tuple_elem_1 = in_1.data.base(); + auto tuple_elem_2 = in_2.data.base(); + out_1->data.base()[0] = tuple_elem_2[0]; + out_2->data.base()[0] = tuple_elem_1[0]; + return absl::OkStatus(); +} + +XLA_FFI_DEFINE_HANDLER(kSwapTupleU32S16ToS16U32, SwapTupleU32S16ToS16U32, + (ffi::Ffi::Bind() + .Arg>() + .Arg>() + .Ret>() + .Ret>())); + +XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), + "__xla_test$$SwapTupleU32S16ToS16U32", "Host", + kSwapTupleU32S16ToS16U32); + +static absl::Status HandleTupleDifferentRanks(ffi::BufferR0 x_1, + ffi::BufferR1 x_2, + ffi::BufferR2 y_1, + ffi::BufferR3 y_2, + ResultBuffer x_out, + ResultBuffer y_out) { + if (x_2.data.ElementCount() != x_out->data.ElementCount()) { + return absl::FailedPreconditionError( + "`x_2` parameter should have the same number of elements as `x_out`"); + } + if (y_1.dimensions != y_out->dimensions.subspan(1) || + y_2.dimensions.front() + 1 != y_out->dimensions.front()) { + return absl::FailedPreconditionError( + "Cannot concatenate `y_1` and `y_2` due to dimensions mismatch. 
" + "`y_2` dimensions should represent a batched `y_1`"); + } + // Multiply R1 vector by R0 scalar + const auto factor = x_1.data.base()[0]; + for (int i = 0; i < x_2.data.ElementCount(); ++i) { + x_out->data.base()[i] = factor * x_2.data.base()[i]; + } + // Append R2 buffer to R3 buffer + auto last_pos = + std::copy_n(y_2.data.base(), y_2.data.ElementCount(), y_out->data.base()); + std::copy_n(y_1.data.base(), y_1.data.ElementCount(), last_pos); + return absl::OkStatus(); +} + +XLA_FFI_DEFINE_HANDLER(kHandleTupleDifferentRanks, HandleTupleDifferentRanks, + ffi::Ffi::Bind() + .Arg>() + .Arg>() + .Arg>() + .Arg>() + .Ret>() + .Ret>()); + +XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), + "__xla_test$$HandleTupleDifferentRanks", "Host", + kHandleTupleDifferentRanks); + } // namespace +// __xla_test$$ConcatVectors + +static absl::Status Concat3Vectors(ffi::BufferR2 vec_1, + ffi::BufferR2 vec_2, + ffi::BufferR2 vec_3, + ResultBuffer out) { + if (out->dimensions.back() != 3) { + return absl::FailedPreconditionError("output dimension 0 expected to be 3"); + } + float* out_data = out->data.base(); + + ffi::BufferR2* vecs[3] = {&vec_1, &vec_2, &vec_3}; + for (int elem_idx = 0; elem_idx < out->dimensions.front(); ++elem_idx) { + for (int vec_idx = 0; vec_idx < 3; ++vec_idx) { + // {{vec_0[0], vec_1[0], vec_2[0]}, + // {vec_0[1], vec_1[1], vec_2[1]}, + // ...} + const auto out_idx = elem_idx * out->dimensions.back() + vec_idx; + out_data[out_idx] = vecs[vec_idx]->data.base()[elem_idx]; + } + } + return absl::OkStatus(); +} + +XLA_FFI_DEFINE_HANDLER(kConcat3Vectors, Concat3Vectors, + ffi::Ffi::Bind() + .Arg>() + .Arg>() + .Arg>() + .Ret>()); + +XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "__xla_test$$Concat3Vectors", + "Host", kConcat3Vectors); + using FfiCustomCallTest = CustomCallTest; XLA_TEST_F(FfiCustomCallTest, FfiReportsSuccess) { @@ -1279,7 +1399,6 @@ XLA_TEST_F(FfiCustomCallTest, FfiTupleOutput) { } XLA_TEST_F(FfiCustomCallTest, FfiNestedTupleOutput) { - GTEST_SKIP() 
<< "Nested tuple outputs not yet implemented."; const char* const kModuleStr = R"( HloModule m @@ -1308,7 +1427,6 @@ XLA_TEST_F(FfiCustomCallTest, FfiNestedTupleOutput) { } XLA_TEST_F(FfiCustomCallTest, FfiTupleInput) { - GTEST_SKIP() << "Tuple inputs not yet implemented."; const char* const kModuleStr = R"( HloModule m @@ -1329,7 +1447,6 @@ XLA_TEST_F(FfiCustomCallTest, FfiTupleInput) { } XLA_TEST_F(FfiCustomCallTest, FfiNestedTupleInput) { - GTEST_SKIP() << "Nested tuple inputs not yet implemented."; const char* const kModuleStr = R"( HloModule m @@ -1351,8 +1468,136 @@ XLA_TEST_F(FfiCustomCallTest, FfiNestedTupleInput) { EXPECT_EQ(result, expected); } +XLA_TEST_F(FfiCustomCallTest, SwapTupleAnyBuffersToS16U32) { + const char* const kModuleStr = R"( + HloModule m + + ENTRY test { + p0 = (u32[], s16[]) parameter(0) + ROOT custom-call = (s16[], u32[]) custom-call(p0), custom_call_target="__xla_test$$SwapTupleAnyBuffersToS16U32", api_version=API_VERSION_TYPED_FFI + })"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kModuleStr)); + + Literal arg0 = LiteralUtil::CreateR0(0xDEADC0DE); + Literal arg1 = LiteralUtil::CreateR0(29); + Literal argument = LiteralUtil::MakeTuple({&arg0, &arg1}); + Literal expected = LiteralUtil::MakeTuple({&arg1, &arg0}); + + TF_ASSERT_OK_AND_ASSIGN(auto result, Execute(std::move(module), {&argument})); + EXPECT_EQ(result, expected); +} + +XLA_TEST_F(FfiCustomCallTest, IgnoresEmptyTupleParameter) { + const char* const kModuleStr = R"( + HloModule m + + ENTRY test { + p0 = (u32[], s16[], ((), ())) parameter(0) + ROOT custom-call = (s16[], u32[]) custom-call(p0), custom_call_target="__xla_test$$SwapTupleAnyBuffersToS16U32", api_version=API_VERSION_TYPED_FFI + })"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kModuleStr)); + + Literal arg0 = LiteralUtil::CreateR0(0xDEADC0DE); + Literal arg1 = LiteralUtil::CreateR0(29); + Literal empty_tuple = LiteralUtil::MakeTuple({}); + Literal nested_tuple = 
LiteralUtil::MakeTuple({&empty_tuple, &empty_tuple}); + Literal argument = LiteralUtil::MakeTuple({&arg0, &arg1, &nested_tuple}); + Literal expected = LiteralUtil::MakeTuple({&arg1, &arg0}); + + TF_ASSERT_OK_AND_ASSIGN(auto result, Execute(std::move(module), {&argument})); + EXPECT_EQ(result, expected); +} + +XLA_TEST_F(FfiCustomCallTest, SwapTupleU32S16ToS16U32) { + const char* const kModuleStr = R"( + HloModule m + + ENTRY test { + p0 = (u32[], s16[]) parameter(0) + ROOT custom-call = (s16[], u32[]) custom-call(p0), custom_call_target="__xla_test$$SwapTupleU32S16ToS16U32", api_version=API_VERSION_TYPED_FFI + })"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kModuleStr)); + + Literal arg0 = LiteralUtil::CreateR0(0xDEADC0DE); + Literal arg1 = LiteralUtil::CreateR0(29); + Literal argument = LiteralUtil::MakeTuple({&arg0, &arg1}); + Literal expected = LiteralUtil::MakeTuple({&arg1, &arg0}); + + TF_ASSERT_OK_AND_ASSIGN(auto result, Execute(std::move(module), {&argument})); + EXPECT_EQ(result, expected); +} + +XLA_TEST_F(FfiCustomCallTest, HandleR2Tuple) { + const char* const kModuleStr = R"( + HloModule m + + ENTRY test { + p0 = (f32[2, 1], f32[2, 1], f32[2, 1]) parameter(0) + ROOT custom-call = f32[2, 3] custom-call(p0), custom_call_target="__xla_test$$Concat3Vectors", api_version=API_VERSION_TYPED_FFI + })"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kModuleStr)); + + Literal arg_0 = LiteralUtil::CreateR2({{1.f}, {2.f}}); + Literal arg_1 = LiteralUtil::CreateR2({{3.f}, {4.f}}); + Literal arg_2 = LiteralUtil::CreateR2({{5.f}, {6.f}}); + Literal tuple_arg = LiteralUtil::MakeTuple({&arg_0, &arg_1, &arg_2}); + + TF_ASSERT_OK_AND_ASSIGN(auto result, + Execute(std::move(module), {&tuple_arg})); + + LiteralTestUtil::ExpectR2Equal({{1.f, 3.f, 5.f}, // + {2.f, 4.f, 6.f}}, // + result); +} + +XLA_TEST_F(FfiCustomCallTest, HandleTupleDifferentRanks) { + const char* const kModuleStr = R"( + HloModule m + + ENTRY test { + 
p0 = ((u32[], s16[5]), (f32[2, 2], f32[4, 2, 2])) parameter(0) + ROOT custom-call = (s32[5], f32[5, 2, 2]) custom-call(p0), custom_call_target="__xla_test$$HandleTupleDifferentRanks", api_version=API_VERSION_TYPED_FFI + })"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kModuleStr)); + + Literal arg_0 = LiteralUtil::CreateR0(100); + Literal arg_1 = LiteralUtil::CreateR1({29, 30, 31, 32, 33}); + Literal arg_2 = LiteralUtil::CreateR2({{17.f, 18.f}, {19.f, 20.f}}); + Literal arg_3 = LiteralUtil::CreateR3({{{1.f, 2.f}, {3.f, 4.f}}, + {{5.f, 6.f}, {7.f, 8.f}}, + {{9.f, 10.f}, {11.f, 12.f}}, + {{13.f, 14.f}, {15.f, 16.f}}}); + Literal tuple_arg_0 = LiteralUtil::MakeTuple({&arg_0, &arg_1}); + Literal tuple_arg_1 = LiteralUtil::MakeTuple({&arg_2, &arg_3}); + Literal tuple_arg = LiteralUtil::MakeTuple({&tuple_arg_0, &tuple_arg_1}); + + TF_ASSERT_OK_AND_ASSIGN(auto result, + Execute(std::move(module), {&tuple_arg})); + + Literal expected_0 = + LiteralUtil::CreateR1({2900, 3000, 3100, 3200, 3300}); + Literal expected_1 = + LiteralUtil::CreateR3({{{1.f, 2.f}, {3.f, 4.f}}, + {{5.f, 6.f}, {7.f, 8.f}}, + {{9.f, 10.f}, {11.f, 12.f}}, + {{13.f, 14.f}, {15.f, 16.f}}, + {{17.f, 18.f}, {19.f, 20.f}}}); + + Literal expected_tuple = LiteralUtil::MakeTuple({&expected_0, &expected_1}); + EXPECT_EQ(result, expected_tuple); +} + XLA_TEST_F(FfiCustomCallTest, FfiNestedTupleInputAndOutput) { - GTEST_SKIP() << "Nested tuple inputs/outputs not yet implemented."; const char* const kModuleStr = R"( HloModule m From e15904ccb6a9d6ad118c8842f7bc56c26c3bca0f Mon Sep 17 00:00:00 2001 From: "Dimitar (Mitko) Asenov" Date: Fri, 21 Jun 2024 03:21:21 -0700 Subject: [PATCH 108/256] [XLA:GPU] Remove the requirement to run on a machine with a GPU when using `triton_test_utils`. The logic that requires a GPU is moved from `triton_test_utils` to `ir_emitter_triton_test`. 
To continue to allow for reusing the utils, I converted a few methods to free functions that can be used by both Triton support tests and `ir_emitter_triton_test`. This change temporarily disables coverage for H100, but we will reintroduce it in a follow up. PiperOrigin-RevId: 645332365 --- third_party/xla/xla/service/gpu/BUILD | 14 +- .../xla/service/gpu/ir_emitter_triton_test.cc | 265 +++++++++++------- .../service/gpu/triton_support_legacy_test.cc | 65 +++-- .../xla/service/gpu/triton_support_test.cc | 62 ++-- .../xla/xla/service/gpu/triton_test_utils.cc | 55 ++-- .../xla/xla/service/gpu/triton_test_utils.h | 76 ++--- 6 files changed, 291 insertions(+), 246 deletions(-) diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index 8bc8fd4315a260..ef8e4f2fc60d87 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -627,8 +627,6 @@ xla_test( ":backend_configs_cc", ":gpu_device_info_for_tests", ":ir_emitter_triton", - ":matmul_utils", - ":triton_fusion_analysis", ":triton_test_utils", "//xla:autotuning_proto_cc", "//xla:error_spec", @@ -639,6 +637,7 @@ xla_test( "//xla/service:pattern_matcher", "//xla/service:pattern_matcher_gmock", "//xla/service/gpu/model:tiled_hlo_computation", + "//xla/service/gpu/tests:gpu_codegen_test", "//xla/stream_executor:device_description", "//xla/stream_executor/cuda:cublas_plugin", "//xla/tests:filecheck", @@ -681,9 +680,9 @@ cc_library( "//xla/service:float_normalization", "//xla/service:hlo_pass_pipeline", "//xla/service/gpu/model:tiled_hlo_computation", - "//xla/service/gpu/tests:gpu_codegen_test", "//xla/stream_executor:device_description", "//xla/tests:filecheck", + "//xla/tests:hlo_test_base", "//xla/tests:verified_hlo_module", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", @@ -1258,14 +1257,9 @@ cc_library( ], ) -xla_test( +xla_cc_test( name = "triton_support_test", - srcs = if_gpu_is_configured(["triton_support_test.cc"]), - 
backends = [ - "gpu_a100", - "gpu_h100", - "gpu_amd_any", - ], + srcs = ["triton_support_test.cc"], deps = [ ":gpu_device_info_for_tests", ":ir_emitter_triton", diff --git a/third_party/xla/xla/service/gpu/ir_emitter_triton_test.cc b/third_party/xla/xla/service/gpu/ir_emitter_triton_test.cc index 2c22daeee72e36..af403cb6875b2e 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter_triton_test.cc +++ b/third_party/xla/xla/service/gpu/ir_emitter_triton_test.cc @@ -46,6 +46,7 @@ limitations under the License. #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/gpu_device_info_for_tests.h" #include "xla/service/gpu/model/tiled_hlo_computation.h" +#include "xla/service/gpu/tests/gpu_codegen_test.h" #include "xla/service/gpu/triton_test_utils.h" #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" @@ -71,6 +72,36 @@ namespace { namespace m = ::xla::match; +class TritonTest : public GpuCodegenTest { + public: + stream_executor::CudaComputeCapability GetCudaComputeCapability() { + return backend() + .default_stream_executor() + ->GetDeviceDescription() + .cuda_compute_capability(); + } + + const stream_executor::GpuComputeCapability& GpuComputeComp() { + return device_desc().gpu_compute_capability(); + } + stream_executor::GpuComputeCapability CudaAmpereOrRocm() { + if (std::holds_alternative( + GpuComputeComp())) { + return stream_executor::GpuComputeCapability{ + device_desc().rocm_compute_capability()}; + } else { + return stream_executor::GpuComputeCapability{ + stream_executor::CudaComputeCapability{ + stream_executor::CudaComputeCapability::AMPERE, 0}}; + } + } + + protected: + const stream_executor::DeviceDescription& device_desc() { + return backend().default_stream_executor()->GetDeviceDescription(); + } +}; + class TritonGemmTest : public TritonTest { public: DebugOptions GetDebugOptionsForTest() override { @@ -110,7 +141,7 @@ class TritonGemmTestWithoutTritonGemmAny : public TritonGemmTest { } }; 
-TEST_F(TritonFilecheckTest, TestGemm) { +TEST_F(TritonTest, TestGemm) { const std::string kHloText = R"( HloModule t, is_scheduled=true @@ -132,7 +163,8 @@ ENTRY e { "split_k":1,"num_stages":1,"num_warps":2, "num_ctas":1}}} })"; - TF_EXPECT_OK(CreateTritonIrAndFileCheckForDot(kHloText, "triton_gemm_r", R"( + TF_EXPECT_OK( + CreateTritonIrAndFileCheckForDot(this, kHloText, "triton_gemm_r", R"( CHECK: tt.func @triton_fn(%[[LHS:.*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[RHS:.*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[OUT:.*]]: !tt.ptr {tt.divisibility = 16 : i32}) { CHECK-DAG: %[[ZERO_KN:.*]] = arith.constant dense<0.000000e+00> : tensor<32x64xf32> CHECK-DAG: %[[ZERO_MK:.*]] = arith.constant dense<0.000000e+00> : tensor<16x32xf32> @@ -194,7 +226,7 @@ CHECK: } )")); } -TEST_F(TritonFilecheckTest, TestGemmWithTrivialNonContractingDimension) { +TEST_F(TritonTest, TestGemmWithTrivialNonContractingDimension) { const std::string kHloText = R"( HloModule t, is_scheduled=true @@ -216,7 +248,8 @@ ENTRY e { "num_ctas":1}}} })"; - TF_EXPECT_OK(CreateTritonIrAndFileCheckForDot(kHloText, "triton_dot", R"( + TF_EXPECT_OK( + CreateTritonIrAndFileCheckForDot(this, kHloText, "triton_dot", R"( CHECK: tt.func @triton_fn(%[[LHS:.*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[RHS:.*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[OUT:.*]]: !tt.ptr {tt.divisibility = 16 : i32}) { CHECK-DAG: %[[ZERO_KN:.*]] = arith.constant dense<0.000000e+00> : tensor<32x16xf32> CHECK-DAG: %[[ZERO_MK:.*]] = arith.constant dense<0.000000e+00> : tensor<16x32xf32> @@ -276,7 +309,7 @@ CHECK: } )")); } -TEST_F(TritonFilecheckTest, TestSoftmaxEmitterWithSingleParameter) { +TEST_F(TritonTest, TestSoftmaxEmitterWithSingleParameter) { const std::string kHloText = R"( HloModule t add { @@ -298,8 +331,9 @@ ENTRY main { param_0 = f32[125,127]{1,0} parameter(0) ROOT triton_softmax = f32[125,127]{1,0} fusion(param_0), kind=kCustom, calls=triton_softmax_computation, backend_config={"fusion_backend_config": 
{"kind":"__triton"}} })"; - TF_EXPECT_OK(CreateTritonIrAndFileCheck( - kHloText, FromOutputTileSizes({1, 127}), "triton_softmax_computation", R"( + TF_EXPECT_OK(CreateTritonIrAndFileCheck(this, kHloText, + FromOutputTileSizes({1, 127}), + "triton_softmax_computation", R"( CHECK: #[[MAP:.*]] = affine_map<(d0) -> (d0 * 127)> CHECK: tt.func @triton_fn(%[[P0:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[P1:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}) { CHECK: %[[PID:.*]] = tt.get_program_id x : i32 @@ -327,7 +361,7 @@ CHECK: } )")); } -TEST_F(TritonFilecheckTest, TestSoftmaxEmitterWithSingleScalarParameter) { +TEST_F(TritonTest, TestSoftmaxEmitterWithSingleScalarParameter) { const std::string kHloText = R"( HloModule t add { @@ -350,8 +384,9 @@ ENTRY main { param_0 = f32[] constant(42) ROOT triton_softmax = f32[125,127]{1,0} fusion(param_0), kind=kCustom, calls=triton_softmax_computation, backend_config={"fusion_backend_config": {"kind":"__triton"}} })"; - TF_EXPECT_OK(CreateTritonIrAndFileCheck( - kHloText, FromOutputTileSizes({1, 127}), "triton_softmax_computation", R"( + TF_EXPECT_OK(CreateTritonIrAndFileCheck(this, kHloText, + FromOutputTileSizes({1, 127}), + "triton_softmax_computation", R"( CHECK: #[[MAP:.*]] = affine_map<(d0) -> (d0 * 127)> CHECK: tt.func @triton_fn(%[[P0:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[P1:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}) { CHECK-DAG: %[[PID:.*]] = tt.get_program_id x : i32 @@ -378,7 +413,7 @@ CHECK: } )")); } -TEST_F(TritonFilecheckTest, TestSoftmaxEmitterWithMultipleParameters) { +TEST_F(TritonTest, TestSoftmaxEmitterWithMultipleParameters) { const std::string kHloText = R"( HloModule t @@ -405,8 +440,9 @@ ENTRY main { ROOT triton_softmax = f32[125,127]{1,0} fusion(param_0, param_1), kind=kCustom, calls=triton_softmax_computation, backend_config={"fusion_backend_config": {"kind":"__triton"}} } )"; - TF_EXPECT_OK(CreateTritonIrAndFileCheck( - kHloText, FromOutputTileSizes({1, 127}), 
"triton_softmax_computation", R"( + TF_EXPECT_OK(CreateTritonIrAndFileCheck(this, kHloText, + FromOutputTileSizes({1, 127}), + "triton_softmax_computation", R"( CHECK: #[[MAP:.*]] = affine_map<(d0) -> (d0 * 127)> CHECK: tt.func @triton_fn(%[[P0:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[P1:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[P2:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}) { CHECK-DAG: %[[PID:.*]] = tt.get_program_id x : i32 @@ -440,8 +476,7 @@ CHECK: } )")); } -TEST_F(TritonFilecheckTest, - TestSoftmaxEmitterWithMultipleParametersOrderSwapped) { +TEST_F(TritonTest, TestSoftmaxEmitterWithMultipleParametersOrderSwapped) { // This mirrors the multiple parameter test above, but with the parameter to // be batch-broadcasted in the parameter_0 place instead of parameter_1. const std::string kHloText = R"( @@ -470,8 +505,9 @@ ENTRY main { ROOT triton_softmax = f32[125,127]{1,0} fusion(param_1, param_0), kind=kCustom, calls=triton_softmax_computation, backend_config={"fusion_backend_config": {"kind":"__triton"}} } )"; - TF_EXPECT_OK(CreateTritonIrAndFileCheck( - kHloText, FromOutputTileSizes({1, 127}), "triton_softmax_computation", R"( + TF_EXPECT_OK(CreateTritonIrAndFileCheck(this, kHloText, + FromOutputTileSizes({1, 127}), + "triton_softmax_computation", R"( CHECK: #[[MAP:.*]] = affine_map<(d0) -> (d0 * 127)> CHECK: tt.func @triton_fn(%[[P0:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[P1:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[P2:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}) { CHECK-DAG: %[[PID:.*]] = tt.get_program_id x : i32 @@ -506,7 +542,7 @@ CHECK: } )")); } -TEST_F(TritonFilecheckTest, +TEST_F(TritonTest, TestSoftmaxEmitterWithAdditionalParameterEnteringAfterDiamond) { const std::string kHloText = R"( HloModule t @@ -533,8 +569,9 @@ ENTRY main { ROOT triton_softmax = f32[125,127]{1,0} fusion(param_0, param_1), kind=kCustom, calls=triton_softmax_computation, backend_config={"fusion_backend_config": {"kind":"__triton"}} } )"; - 
TF_EXPECT_OK(CreateTritonIrAndFileCheck( - kHloText, FromOutputTileSizes({1, 127}), "triton_softmax_computation", R"( + TF_EXPECT_OK(CreateTritonIrAndFileCheck(this, kHloText, + FromOutputTileSizes({1, 127}), + "triton_softmax_computation", R"( CHECK: #[[MAP:.*]] = affine_map<(d0) -> (d0 * 127)> CHECK: tt.func @triton_fn(%[[P0:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[P1:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[P2:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}) { CHECK-DAG: %[[PID:.*]] = tt.get_program_id x : i32 @@ -568,7 +605,7 @@ CHECK: } )")); } -TEST_F(TritonFilecheckTest, +TEST_F(TritonTest, TestSoftmaxEmitterWithMultipleParametersAlongTiledDimension) { const std::string kHloText = R"( HloModule t @@ -600,8 +637,9 @@ ENTRY main { ROOT triton_softmax = f32[125,127]{1,0} fusion(param_0, param_1, param_2), kind=kCustom, calls=triton_softmax_computation, backend_config={"fusion_backend_config": {"kind":"__triton"}} } )"; - TF_EXPECT_OK(CreateTritonIrAndFileCheck( - kHloText, FromOutputTileSizes({1, 127}), "triton_softmax_computation", R"( + TF_EXPECT_OK(CreateTritonIrAndFileCheck(this, kHloText, + FromOutputTileSizes({1, 127}), + "triton_softmax_computation", R"( CHECK: #[[MAP:.*]] = affine_map<(d0) -> (d0 * 127)> CHECK: tt.func @triton_fn(%[[P0:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[P1:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[P2:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[P3:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}) { CHECK-DAG: %[[C127_i64:.*]] = arith.constant 127 : i64 @@ -641,7 +679,7 @@ CHECK: } )")); } -TEST_F(TritonFilecheckTest, TestSoftmaxEmitterWithMultipleTiledDimensions) { +TEST_F(TritonTest, TestSoftmaxEmitterWithMultipleTiledDimensions) { const std::string kHloText = R"( HloModule t @@ -672,7 +710,7 @@ ENTRY main { ROOT triton_softmax = f32[10,125,127]{2,1,0} fusion(param_0, param_1, param_2), kind=kCustom, calls=triton_softmax_computation, backend_config={"fusion_backend_config": {"kind":"__triton"}} 
} )"; - TF_EXPECT_OK(CreateTritonIrAndFileCheck(kHloText, + TF_EXPECT_OK(CreateTritonIrAndFileCheck(this, kHloText, FromOutputTileSizes({1, 1, 127}), "triton_softmax_computation", R"( CHECK: #[[MAP:.*]] = affine_map<(d0) -> (d0 * 127)> @@ -715,7 +753,7 @@ CHECK: } } TEST_F( - TritonFilecheckTest, + TritonTest, DiamondWithAdditionalDiamondParameterBroadcastedAlongReductionDimProducesAccurateResults) { // NOLINT(whitespace/line_length) const std::string kHloText = R"( HloModule h1 @@ -746,8 +784,9 @@ ENTRY main { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(kHloText)); - TF_ASSERT_OK(CreateTritonIrAndFileCheck( - kHloText, FromOutputTileSizes({1, 16}), "triton_softmax_computation", R"( + TF_ASSERT_OK(CreateTritonIrAndFileCheck(this, kHloText, + FromOutputTileSizes({1, 16}), + "triton_softmax_computation", R"( CHECK: #[[MAP:.*]] = affine_map<(d0) -> (d0 * 16)> CHECK-LABEL: tt.func @triton_fn( CHECK-SAME: %[[P0:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}, @@ -783,13 +822,13 @@ CHECK-SAME: !tt.ptr> /*arel=*/0})); } -TEST_F(TritonFilecheckTest, NestedReducerFusionGetsCodegenedCorrectly) { +TEST_F(TritonTest, NestedReducerFusionGetsCodegenedCorrectly) { // TODO(b/327336797): remove filter once V100 codegen in Triton is removed. 
if (!GetCudaComputeCapability().IsAtLeast( se::CudaComputeCapability::AMPERE)) { GTEST_SKIP() << "Doesn't pass on pre-Ampere GPUs."; } - if (SkipBF16Tests()) { + if (!SupportsBF16(GpuComputeComp())) { GTEST_SKIP() << "BF16 not supported."; } @@ -830,7 +869,7 @@ ENTRY main { } TEST_F( - TritonFilecheckTest, + TritonTest, DiamondWithAdditionalDiamondParameterBroadcastedAlongBatchDimProducesAccurateResults) { // NOLINT(whitespace/line_length) const std::string kHloText = R"( HloModule h1 @@ -861,8 +900,9 @@ ENTRY main { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(kHloText)); - TF_ASSERT_OK(CreateTritonIrAndFileCheck( - kHloText, FromOutputTileSizes({1, 32}), "triton_softmax_computation", R"( + TF_ASSERT_OK(CreateTritonIrAndFileCheck(this, kHloText, + FromOutputTileSizes({1, 32}), + "triton_softmax_computation", R"( CHECK: #[[MAP:.*]] = affine_map<(d0) -> (d0 * 32)> CHECK-LABEL: tt.func @triton_fn( CHECK-SAME: %[[P0:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}, @@ -901,7 +941,7 @@ CHECK-SAME: !tt.ptr> } TEST_F( - TritonFilecheckTest, + TritonTest, DiamondWithAdditionalSplatDiamondScalarParameterProducesAccurateResults) { // NOLINT(whitespace/line_length) const std::string kHloText = R"( HloModule h1 @@ -933,7 +973,7 @@ ENTRY main { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(kHloText)); - TF_ASSERT_OK(CreateTritonIrAndFileCheck(kHloText, + TF_ASSERT_OK(CreateTritonIrAndFileCheck(this, kHloText, FromOutputTileSizes({1, 1, 16}), "triton_softmax_computation", R"( // CHECK: #[[MAP:.*]] = affine_map<(d0) -> (d0 * 16)> @@ -973,7 +1013,7 @@ ENTRY main { } TEST_F( - TritonFilecheckTest, + TritonTest, DiamondWithAdditionalBroadcastOf1DParameterAlongNonReductionDimensionsProducesAccurateResults) { // NOLINT(whitespace/line_length) const std::string kHloText = R"( HloModule h1 @@ -1004,7 +1044,7 @@ ENTRY main { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(kHloText)); - 
TF_ASSERT_OK(CreateTritonIrAndFileCheck(kHloText, + TF_ASSERT_OK(CreateTritonIrAndFileCheck(this, kHloText, FromOutputTileSizes({1, 1, 16}), "triton_softmax_computation", R"( // CHECK: #[[MAP:.*]] = affine_map<(d0) -> (d0 * 16)> @@ -1045,7 +1085,7 @@ ENTRY main { EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/0, /*arel=*/0})); } -TEST_F(TritonFilecheckTest, PredParametersAreTruncatedToI1) { +TEST_F(TritonTest, PredParametersAreTruncatedToI1) { const std::string kHloText = R"( HloModule m @@ -1078,16 +1118,15 @@ ENTRY e { } } )"; - TF_EXPECT_OK( - CreateTritonIrAndFileCheckForDot(kHloText, "triton_gemm_computation", R"( + TF_EXPECT_OK(CreateTritonIrAndFileCheckForDot(this, kHloText, + "triton_gemm_computation", R"( CHECK: %[[LOAD:.*]] = tt.load %{{.*}} {{.*}} : !tt.ptr> CHECK: %[[TRUNCI:.*]] = arith.trunci %[[LOAD]] : tensor<16x16xi8> to tensor<16x16xi1> CHECK: %{{.*}} = arith.andi %[[TRUNCI]], %{{.*}} : tensor<16x16xi1> )")); } -TEST_F(TritonFilecheckTest, - CodegenBatchedDotWithConcatenationWithCorrectBatchStride) { +TEST_F(TritonTest, CodegenBatchedDotWithConcatenationWithCorrectBatchStride) { constexpr absl::string_view kHloText = R"( HloModule t, is_scheduled=true @@ -1113,7 +1152,8 @@ ENTRY e { "num_ctas":1}}} })"; - TF_EXPECT_OK(CreateTritonIrAndFileCheckForDot(kHloText, "triton_gemm", R"( + TF_EXPECT_OK( + CreateTritonIrAndFileCheckForDot(this, kHloText, "triton_gemm", R"( CHECK: tt.func @triton_fn(%[[P0:[^:]*]]: !tt.ptr CHECK-SAME: %[[P1:[^:]*]]: !tt.ptr CHECK-SAME: %[[P2:[^:]*]]: !tt.ptr @@ -1127,7 +1167,7 @@ CHECK: %[[BLOCK_BASE_PTR:.*]] = tt.addptr %[[ARG_PTR]], %[[OFFSET]] )")); } -TEST_F(TritonFilecheckTest, CodegenDynamicSliceWithCorrectOffsets) { +TEST_F(TritonTest, CodegenDynamicSliceWithCorrectOffsets) { // The start index(es) for the non-majormost dimension(s) are constant zero(s) // because we don't support dynamic slice on those dimensions. 
constexpr absl::string_view kHloText = R"( @@ -1159,7 +1199,8 @@ ENTRY e { "num_stages":"1","num_warps":"4","num_ctas":"1"}}} })"; - ASSERT_THAT(CreateTritonIrAndFileCheckForDot(kHloText, "triton_gemm", R"( + ASSERT_THAT( + CreateTritonIrAndFileCheckForDot(this, kHloText, "triton_gemm", R"( CHECK: tt.func @triton_fn({{[^,]*}}, %[[DYNAMIC_SLICE_INPUT:[^:]*]]: !tt.ptr {{[^,]*}}, %[[START_INDEX0_PTR:[^:]*]]: !tt.ptr CHECK-DAG: %[[C0_i32:.*]] = arith.constant 0 : i32 CHECK-DAG: %[[C1_i64:.*]] = arith.constant 1 : i64 @@ -1175,10 +1216,10 @@ CHECK-DAG: %[[ROW_OFFSET_i64:.*]] = arith.extsi %[[ROW_OFFSET]] : i32 to i64 CHECK-DAG: %[[ROW_LIMIT:.*]] = arith.addi %[[ROW_OFFSET_i64]], %[[C5_i64]] : i64 CHECK-DAG: tt.make_tensor_ptr %[[DYNAMIC_SLICE_INPUT]], [%[[C2_i64]], %[[ROW_LIMIT]]], [%[[C1_i64]], %[[C2_i64]]], [%[[C0_i32]], %[[ROW_OFFSET]]] )"), - tsl::testing::IsOk()); + tsl::testing::IsOk()); } -TEST_F(TritonFilecheckTest, SparseDot) { +TEST_F(TritonTest, SparseDot) { const char* kHloText = R"( HloModule t @@ -1200,7 +1241,8 @@ ENTRY e { {"block_m":32,"block_n":32,"block_k":32,"split_k":1,"num_stages":1,"num_warps":1,"num_ctas":1}}} } )"; - TF_ASSERT_OK(CreateTritonIrAndFileCheckForDot(kHloText, "triton_dot", R"( + TF_ASSERT_OK( + CreateTritonIrAndFileCheckForDot(this, kHloText, "triton_dot", R"( CHECK: %[[LHS:[0-9]+]] = tt.load CHECK: %[[RHS:[0-9]+]] = tt.load CHECK: %[[META:[0-9]+]] = tt.load @@ -1208,7 +1250,7 @@ CHECK: triton_gpu.sparse_dot %[[LHS]], %[[RHS]], %{{[^:]+}}, %[[META]] : )")); } -TEST_F(TritonFilecheckTest, SparseDotWithMasking) { +TEST_F(TritonTest, SparseDotWithMasking) { const char* kHloText = R"( HloModule t @@ -1230,7 +1272,8 @@ ENTRY e { {"block_m":32,"block_n":32,"block_k":64,"split_k":1,"num_stages":1,"num_warps":1,"num_ctas":1}}} } )"; - TF_ASSERT_OK(CreateTritonIrAndFileCheckForDot(kHloText, "triton_dot", R"( + TF_ASSERT_OK( + CreateTritonIrAndFileCheckForDot(this, kHloText, "triton_dot", R"( CHECK-DAG: %[[C24:.+]] = arith.constant dense<24> 
CHECK-DAG: %[[C48:.+]] = arith.constant dense<48> CHECK: %[[LHS:[0-9]+]] = tt.load %{{.+}} {boundaryCheck = array @@ -1244,7 +1287,7 @@ CHECK: triton_gpu.sparse_dot %[[LHS_MASKED]], %[[RHS_MASKED]], %{{[^:]+}}, %[[ME )")); } -TEST_F(TritonFilecheckTest, SparseDotBroadcastMetadata) { +TEST_F(TritonTest, SparseDotBroadcastMetadata) { const char* kHloText = R"( HloModule t @@ -1268,7 +1311,8 @@ ENTRY e { {"block_m":32,"block_n":32,"block_k":32,"split_k":1,"num_stages":1,"num_warps":1,"num_ctas":1}}} } )"; - TF_ASSERT_OK(CreateTritonIrAndFileCheckForDot(kHloText, "triton_dot", R"( + TF_ASSERT_OK( + CreateTritonIrAndFileCheckForDot(this, kHloText, "triton_dot", R"( CHECK: %[[TWO:.+]] = arith.constant 2 : i32 CHECK: %[[LHS:[0-9]+]] = tt.load CHECK: %[[RHS:[0-9]+]] = tt.load @@ -2414,7 +2458,7 @@ ENTRY e { TEST_F(TritonGemmTestAny, DoNotFuseConcatenationOfSplitNonContractingDimension) { - if (SkipBF16Tests()) { + if (!SupportsBF16(GpuComputeComp())) { GTEST_SKIP() << "BF16 not supported."; } const std::string hlo_text = R"( @@ -2639,7 +2683,7 @@ ENTRY e { } TEST_F(TritonGemmLevel2Test, DoubleBroadcastOfScalarConstantIsHandled) { - if (SkipBF16Tests()) { + if (!SupportsBF16(GpuComputeComp())) { GTEST_SKIP() << "BF16 not supported."; } const std::string kHloText = R"( @@ -2687,7 +2731,7 @@ ENTRY e { } TEST_F(TritonGemmLevel2Test, AlwaysFuseScalarConstantAtBroadcastInput) { - if (SkipBF16Tests()) { + if (!SupportsBF16(GpuComputeComp())) { GTEST_SKIP() << "BF16 not supported."; } const std::string kHloText = R"( @@ -2743,7 +2787,7 @@ ENTRY e { } TEST_F(TritonGemmLevel2Test, FuseConcatenation) { - if (SkipBF16Tests()) { + if (!SupportsBF16(GpuComputeComp())) { GTEST_SKIP() << "BF16 not supported."; } const std::string kHloText = R"( @@ -3137,7 +3181,7 @@ ENTRY e { } TEST_F(TritonGemmLevel2Test, ParameterAfterDotIsFused) { - if (SkipBF16Tests()) { + if (!SupportsBF16(GpuComputeComp())) { GTEST_SKIP() << "BF16 not supported."; } const std::string kHloText = R"( @@ -3169,7 
+3213,7 @@ ENTRY e { } TEST_F(TritonGemmLevel2Test, OutputFusionExecutesCorrectly) { - if (SkipBF16Tests()) { + if (!SupportsBF16(GpuComputeComp())) { GTEST_SKIP() << "BF16 not supported."; } const std::string kHloText = R"( @@ -3205,7 +3249,7 @@ ENTRY e { } TEST_F(TritonGemmLevel2Test, SplitLHSOutputTransposeAloneIsNotFused) { - if (SkipBF16Tests()) { + if (!SupportsBF16(GpuComputeComp())) { GTEST_SKIP() << "BF16 not supported."; } const std::string kHloText = R"( @@ -3232,7 +3276,7 @@ ENTRY e { } TEST_F(TritonGemmLevel2Test, SplitLHSInputOutputIsFused) { - if (SkipBF16Tests()) { + if (!SupportsBF16(GpuComputeComp())) { GTEST_SKIP() << "BF16 not supported."; } const std::string kHloText = R"( @@ -3522,7 +3566,7 @@ ENTRY e { } TEST_F(CompareTest, BF16TransposedLHS) { - if (SkipBF16Tests()) { + if (!SupportsBF16(GpuComputeComp())) { GTEST_SKIP() << "BF16 not supported."; } const char* hlo_text_ref = R"( @@ -3726,7 +3770,7 @@ ENTRY e { } TEST_F(CompareTest, S8BF16) { - if (SkipBF16Tests()) { + if (!SupportsBF16(GpuComputeComp())) { GTEST_SKIP() << "BF16 not supported."; } const char* hlo_text_ref = R"( @@ -3776,7 +3820,7 @@ ENTRY e { } TEST_F(CompareTest, SplitK) { - if (SkipBF16Tests()) { + if (!SupportsBF16(GpuComputeComp())) { GTEST_SKIP() << "BF16 not supported."; } const std::string hlo_text_ref = R"( @@ -3852,7 +3896,7 @@ ENTRY e { } TEST_F(CompareTest, SplitKBatch) { - if (SkipBF16Tests()) { + if (!SupportsBF16(GpuComputeComp())) { GTEST_SKIP() << "BF16 not supported."; } const std::string kHloTextRef = R"( @@ -3917,7 +3961,7 @@ ENTRY e { } TEST_F(CompareTest, SplitKNontrivialBitcast) { - if (SkipBF16Tests()) { + if (!SupportsBF16(GpuComputeComp())) { GTEST_SKIP() << "BF16 not supported."; } const std::string kHloTextRef = R"( @@ -4492,7 +4536,7 @@ ENTRY e { } TEST_F(CompareTest, PredToBF16ConversionWorks) { - if (SkipBF16Tests()) { + if (!SupportsBF16(GpuComputeComp())) { GTEST_SKIP() << "BF16 not supported."; } const std::string kHloTextTest = R"( @@ -4615,7 
+4659,7 @@ class TritonGemmContractionDims : public TritonGemmTest { }; TEST_F(TritonGemmContractionDims, TritonDotForceContractionDims_1_0) { - if (SkipBF16Tests()) { + if (!SupportsBF16(GpuComputeComp())) { GTEST_SKIP() << "BF16 not supported."; } const std::string kHloText = R"( @@ -4640,7 +4684,7 @@ ENTRY e { } TEST_F(TritonGemmContractionDims, TritonDotForceContractionDims_1_2_1_2) { - if (SkipBF16Tests()) { + if (!SupportsBF16(GpuComputeComp())) { GTEST_SKIP() << "BF16 not supported."; } const std::string kHloText = R"( @@ -4665,7 +4709,7 @@ ENTRY e { } TEST_F(TritonGemmContractionDims, TritonDotForceContractionDims_1_2_0_1) { - if (SkipBF16Tests()) { + if (!SupportsBF16(GpuComputeComp())) { GTEST_SKIP() << "BF16 not supported."; } const std::string kHloText = R"( @@ -4691,7 +4735,7 @@ ENTRY e { } TEST_F(TritonGemmContractionDims, TritonDotForceContractionDims_1_1) { - if (SkipBF16Tests()) { + if (!SupportsBF16(GpuComputeComp())) { GTEST_SKIP() << "BF16 not supported."; } const std::string kHloText = R"( @@ -4716,10 +4760,10 @@ ENTRY e { // In these tests, we depend on "algorithm" annotations for selecting the 6XBF16 // algorithm. -class Triton6xBF16GemmTest : public TritonFilecheckTest { +class Triton6xBF16GemmTest : public TritonTest { public: DebugOptions GetDebugOptionsForTest() override { - DebugOptions debug_options = TritonFilecheckTest::GetDebugOptionsForTest(); + DebugOptions debug_options = TritonTest::GetDebugOptionsForTest(); // These 2 flags are not strictly necessary now, but we're adding them to be // on the safe side against future flakiness. // @@ -4736,7 +4780,7 @@ class Triton6xBF16GemmTest : public TritonFilecheckTest { protected: void SetUp() override { - if (SkipBF16Tests()) { + if (!SupportsBF16(GpuComputeComp())) { GTEST_SKIP() << "BF16 not supported."; } } @@ -4746,10 +4790,10 @@ class Triton6xBF16GemmTest : public TritonFilecheckTest { // algorithm. 
// TODO(b/316147294): Remove this class and the --xla_gpu_enable_bf16_6way_gemm // flag after we will support the algorithm values through the entire stack. -class Triton6xBF16GemmTestWithFlag : public TritonFilecheckTest { +class Triton6xBF16GemmTestWithFlag : public TritonTest { public: DebugOptions GetDebugOptionsForTest() override { - DebugOptions debug_options = TritonFilecheckTest::GetDebugOptionsForTest(); + DebugOptions debug_options = TritonTest::GetDebugOptionsForTest(); // Enable triton fusion for all supported GEMMs. debug_options.set_xla_gpu_triton_gemm_any(true); // Do not fall back to cuBLAS, we are testing Triton. @@ -4784,7 +4828,8 @@ ENTRY e { {"block_m":32,"block_n":32,"block_k":32,"split_k":1,"num_stages":1,"num_warps":1,"num_ctas":1}}} } )"; - TF_ASSERT_OK(CreateTritonIrAndFileCheckForDot(kHloText, "triton_dot", R"( + TF_ASSERT_OK( + CreateTritonIrAndFileCheckForDot(this, kHloText, "triton_dot", R"( CHECK: %[[INFINITY:.*]] = arith.constant dense<0x7F800000> : tensor<32x32xf32> CHECK: %[[C_MASK:.*]] = arith.constant dense<-65536> : tensor<32x32xi32> CHECK: %[[C0:.*]] = arith.constant dense<0.000000e+00> : tensor<32x32xf32> @@ -4824,7 +4869,8 @@ ENTRY e { {"block_m":32,"block_n":32,"block_k":32,"split_k":1,"num_stages":1,"num_warps":1,"num_ctas":1}}} } )"; - TF_ASSERT_OK(CreateTritonIrAndFileCheckForDot(kHloText, "triton_dot", R"( + TF_ASSERT_OK( + CreateTritonIrAndFileCheckForDot(this, kHloText, "triton_dot", R"( CHECK: %[[INFINITY:.*]] = arith.constant dense<0x7F800000> : tensor<32x32xf32> CHECK: %[[C_MASK:.*]] = arith.constant dense<-65536> : tensor<32x32xi32> CHECK: %[[C0:.*]] = arith.constant dense<0.000000e+00> : tensor<32x32xf32> @@ -4865,7 +4911,8 @@ ENTRY e { {"block_m":64,"block_n":32,"block_k":32,"split_k":1,"num_stages":1,"num_warps":4, "num_ctas":1}}} } )"; - TF_ASSERT_OK(CreateTritonIrAndFileCheckForDot(kHloText, "triton_dot", R"( + TF_ASSERT_OK( + CreateTritonIrAndFileCheckForDot(this, kHloText, "triton_dot", R"( CHECK-COUNT-6: 
%{{.*}} = tt.dot %{{.*}}, %{{.*}}, %{{.*}} : tensor<64x32xbf16> * tensor<32x32xbf16> -> tensor<64x32xf32> )")); EXPECT_TRUE(RunAndCompareNoHloPasses(kHloText, ErrorSpec{/*aabs=*/1e-5, @@ -4893,7 +4940,8 @@ ENTRY e { {"block_m":32,"block_n":32,"block_k":32,"split_k":1,"num_stages":1,"num_warps":1, "num_ctas":1}}} } )"; - TF_ASSERT_OK(CreateTritonIrAndFileCheckForDot(kHloText, "triton_dot", R"( + TF_ASSERT_OK( + CreateTritonIrAndFileCheckForDot(this, kHloText, "triton_dot", R"( CHECK-COUNT-6: %{{.*}} = tt.dot %{{.*}}, %{{.*}}, %{{.*}} : tensor<32x32xbf16> * tensor<32x32xbf16> -> tensor<32x32xf32> )")); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, @@ -4935,7 +4983,8 @@ ENTRY e { {"block_m":32,"block_n":32,"block_k":32,"split_k":1,"num_stages":1,"num_warps":1, "num_ctas":1}}} } )"; - TF_ASSERT_OK(CreateTritonIrAndFileCheckForDot(kHloText, "triton_dot", R"( + TF_ASSERT_OK( + CreateTritonIrAndFileCheckForDot(this, kHloText, "triton_dot", R"( CHECK-COUNT-6: %{{.*}} = tt.dot %{{.*}}, %{{.*}}, %{{.*}} : tensor<32x32xbf16> * tensor<32x32xbf16> -> tensor<32x32xf32> )")); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, @@ -4989,7 +5038,8 @@ ENTRY e { {"block_m":32,"block_n":32,"block_k":32,"split_k":1,"num_stages":1,"num_warps":1, "num_ctas":1}}} } )"; - TF_ASSERT_OK(CreateTritonIrAndFileCheckForDot(kHloText, "triton_dot", R"( + TF_ASSERT_OK( + CreateTritonIrAndFileCheckForDot(this, kHloText, "triton_dot", R"( CHECK-COUNT-6: %{{.*}} = tt.dot %{{.*}}, %{{.*}}, %{{.*}} : tensor<32x32xbf16> * tensor<32x32xbf16> -> tensor<32x32xf32> )")); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, @@ -5035,10 +5085,10 @@ CHECK-NOT: mma.sync.aligned.{{.*}}.row.col.f32.tf32.tf32.f32 // In these tests, we depend on "algorithm" annotations for selecting the 3XBF16 // algorithm. 
-class Triton3xBF16GemmTest : public TritonFilecheckTest { +class Triton3xBF16GemmTest : public TritonTest { public: DebugOptions GetDebugOptionsForTest() override { - DebugOptions debug_options = TritonFilecheckTest::GetDebugOptionsForTest(); + DebugOptions debug_options = TritonTest::GetDebugOptionsForTest(); // These 2 flags are not strictly necessary now, but we're adding them the // to be on the safe side against future flakiness. // @@ -5058,10 +5108,10 @@ class Triton3xBF16GemmTest : public TritonFilecheckTest { // algorithm. // TODO(b/316147294): Remove this class and the --xla_gpu_enable_bf16_3way_gemm // flag after we will support the algorithm values through the entire stack. -class Triton3xBF16GemmTestWithFlag : public TritonFilecheckTest { +class Triton3xBF16GemmTestWithFlag : public TritonTest { public: DebugOptions GetDebugOptionsForTest() override { - DebugOptions debug_options = TritonFilecheckTest::GetDebugOptionsForTest(); + DebugOptions debug_options = TritonTest::GetDebugOptionsForTest(); // Enable triton fusion for all supported GEMMs. debug_options.set_xla_gpu_triton_gemm_any(true); // Do not fall back to cuBLAS, we are testing Triton. 
@@ -5076,7 +5126,7 @@ class Triton3xBF16GemmTestWithFlag : public TritonFilecheckTest { protected: void SetUp() override { - if (SkipBF16Tests()) { + if (!SupportsBF16(GpuComputeComp())) { GTEST_SKIP() << "BF16 not supported."; } } @@ -5103,7 +5153,8 @@ ENTRY e { {"block_m":32,"block_n":32,"block_k":32,"split_k":1,"num_stages":1,"num_warps":1,"num_ctas":1}}} } )"; - TF_ASSERT_OK(CreateTritonIrAndFileCheckForDot(kHloText, "triton_dot", R"( + TF_ASSERT_OK( + CreateTritonIrAndFileCheckForDot(this, kHloText, "triton_dot", R"( CHECK: %[[INFINITY:.*]] = arith.constant dense<0x7F800000> : tensor<32x32xf32> CHECK: %[[C_MASK:.*]] = arith.constant dense<-65536> : tensor<32x32xi32> CHECK: %[[C0:.*]] = arith.constant dense<0.000000e+00> : tensor<32x32xf32> @@ -5143,7 +5194,8 @@ ENTRY e { {"block_m":32,"block_n":32,"block_k":32,"split_k":1,"num_stages":1,"num_warps":1,"num_ctas":1}}} } )"; - TF_ASSERT_OK(CreateTritonIrAndFileCheckForDot(kHloText, "triton_dot", R"( + TF_ASSERT_OK( + CreateTritonIrAndFileCheckForDot(this, kHloText, "triton_dot", R"( CHECK: %[[INFINITY:.*]] = arith.constant dense<0x7F800000> : tensor<32x32xf32> CHECK: %[[C_MASK:.*]] = arith.constant dense<-65536> : tensor<32x32xi32> CHECK: %[[C0:.*]] = arith.constant dense<0.000000e+00> : tensor<32x32xf32> @@ -5183,7 +5235,8 @@ ENTRY e { {"block_m":32,"block_n":32,"block_k":32,"split_k":1,"num_stages":1,"num_warps":1,"num_ctas":1}}} } )"; - TF_ASSERT_OK(CreateTritonIrAndFileCheckForDot(kHloText, "triton_dot", R"( + TF_ASSERT_OK( + CreateTritonIrAndFileCheckForDot(this, kHloText, "triton_dot", R"( CHECK: tt.dot CHECK-SAME: tensor<32x32xf16> * tensor<32x32xf16> -> tensor<32x32xf32> CHECK-NOT: tt.dot @@ -5211,7 +5264,8 @@ ENTRY e { {"block_m":64,"block_n":32,"block_k":32,"split_k":1,"num_stages":1,"num_warps":4, "num_ctas":1}}} } )"; - TF_ASSERT_OK(CreateTritonIrAndFileCheckForDot(kHloText, "triton_dot", R"( + TF_ASSERT_OK( + CreateTritonIrAndFileCheckForDot(this, kHloText, "triton_dot", R"( CHECK-COUNT-3: %{{.*}} = 
tt.dot %{{.*}}, %{{.*}}, %{{.*}} : tensor<64x32xbf16> * tensor<32x32xbf16> -> tensor<64x32xf32> )")); EXPECT_TRUE(RunAndCompareNoHloPasses(kHloText, ErrorSpec{/*aabs=*/1e-4, @@ -5239,7 +5293,8 @@ ENTRY e { {"block_m":32,"block_n":32,"block_k":32,"split_k":1,"num_stages":1,"num_warps":1, "num_ctas":1}}} } )"; - TF_ASSERT_OK(CreateTritonIrAndFileCheckForDot(kHloText, "triton_dot", R"( + TF_ASSERT_OK( + CreateTritonIrAndFileCheckForDot(this, kHloText, "triton_dot", R"( CHECK-COUNT-3: %{{.*}} = tt.dot %{{.*}}, %{{.*}}, %{{.*}} : tensor<32x32xbf16> * tensor<32x32xbf16> -> tensor<32x32xf32> )")); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, @@ -5281,7 +5336,8 @@ ENTRY e { {"block_m":32,"block_n":32,"block_k":32,"split_k":1,"num_stages":1,"num_warps":1, "num_ctas":1}}} } )"; - TF_ASSERT_OK(CreateTritonIrAndFileCheckForDot(kHloText, "triton_dot", R"( + TF_ASSERT_OK( + CreateTritonIrAndFileCheckForDot(this, kHloText, "triton_dot", R"( CHECK-COUNT-3: %{{.*}} = tt.dot %{{.*}}, %{{.*}}, %{{.*}} : tensor<32x32xbf16> * tensor<32x32xbf16> -> tensor<32x32xf32> )")); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, @@ -5325,7 +5381,8 @@ ENTRY e { {"block_m":32,"block_n":32,"block_k":32,"split_k":1,"num_stages":1,"num_warps":1, "num_ctas":1}}} } )"; - TF_ASSERT_OK(CreateTritonIrAndFileCheckForDot(kHloText, "triton_dot", R"( + TF_ASSERT_OK( + CreateTritonIrAndFileCheckForDot(this, kHloText, "triton_dot", R"( CHECK-COUNT-3: %{{.*}} = tt.dot %{{.*}}, %{{.*}}, %{{.*}} : tensor<32x32xbf16> * tensor<32x32xbf16> -> tensor<32x32xf32> )")); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, @@ -5413,7 +5470,7 @@ ENTRY entry { // This test could be modified to allow TF32 once this bug is fixed. // TODO(b/320659359) Allow TF32 for 8-bit or less types with F32. 
-TEST_F(TritonFilecheckTest, NoTF32For8BitOrLessWithF32) { +TEST_F(TritonTest, NoTF32For8BitOrLessWithF32) { const std::string hlo_text = R"( HloModule t @@ -5442,7 +5499,8 @@ ENTRY e { "num_ctas":1}}} })"; - TF_ASSERT_OK(CreateTritonIrAndFileCheckForDot(hlo_text, "triton_dot", R"( + TF_ASSERT_OK( + CreateTritonIrAndFileCheckForDot(this, hlo_text, "triton_dot", R"( CHECK: tt.dot CHECK-NOT: inputPrecision = tf32 )")); @@ -5450,7 +5508,7 @@ CHECK-NOT: inputPrecision = tf32 EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); } -TEST_F(TritonFilecheckTest, Fp8LoweringIsSupportedPostHopper) { +TEST_F(TritonTest, Fp8LoweringIsSupportedPostHopper) { if (!GetCudaComputeCapability().IsAtLeast( se::CudaComputeCapability::HOPPER)) { GTEST_SKIP() << "Doesn't pass on pre-Hopper GPUs."; @@ -5477,14 +5535,15 @@ ENTRY main { "num_stages":"4","num_warps":"4","num_ctas":"1"}}} })"; - TF_ASSERT_OK(CreateTritonIrAndFileCheckForDot(hlo_text, "triton_dot", R"( + TF_ASSERT_OK( + CreateTritonIrAndFileCheckForDot(this, hlo_text, "triton_dot", R"( CHECK: tt.dot {{.*}}{maxNumImpreciseAcc = 2147483647 : i32} : tensor<128x64xf8E4M3FNUZ> * tensor<64x32xf8E4M3FNUZ> -> tensor<128x32xf32> )")); EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{/*aabs=*/1.0, /*arel=*/1e-3})); } -TEST_F(TritonFilecheckTest, TestGenericEmitterReductionFusion) { +TEST_F(TritonTest, TestGenericEmitterReductionFusion) { const std::string kHloText = R"( HloModule t add { @@ -5507,7 +5566,8 @@ ENTRY main { param_1 = f32[125]{0} parameter(1) ROOT triton_reduction = f32[125]{0} fusion(param_0, param_1), kind=kCustom, calls=triton_reduction_computation, backend_config={"fusion_backend_config": {"kind":"__triton"}} })"; - TF_EXPECT_OK(CreateTritonIrAndFileCheck(kHloText, FromOutputTileSizes({1}), + TF_EXPECT_OK(CreateTritonIrAndFileCheck(this, kHloText, + FromOutputTileSizes({1}), "triton_reduction_computation", R"( CHECK: #[[MAP:.*]] = affine_map<(d0) -> (d0 * 127)> CHECK: tt.func 
@triton_fn(%[[P0:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[P1:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[P2:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}) { @@ -5544,7 +5604,7 @@ CHECK: } )")); } -TEST_F(TritonFilecheckTest, TestGenericEmitterWithSoftMaxSingleParameter) { +TEST_F(TritonTest, TestGenericEmitterWithSoftMaxSingleParameter) { const std::string kHloText = R"( HloModule t add { @@ -5566,8 +5626,9 @@ ENTRY main { param_0 = f32[125,127]{1,0} parameter(0) ROOT triton_softmax = f32[125,127]{1,0} fusion(param_0), kind=kCustom, calls=triton_softmax_computation, backend_config={"fusion_backend_config": {"kind":"__triton"}} })"; - TF_EXPECT_OK(CreateTritonIrAndFileCheck( - kHloText, FromOutputTileSizes({1, 127}), "triton_softmax_computation", R"( + TF_EXPECT_OK(CreateTritonIrAndFileCheck(this, kHloText, + FromOutputTileSizes({1, 127}), + "triton_softmax_computation", R"( CHECK: #[[MAP:.*]] = affine_map<(d0) -> (d0 * 127)> CHECK: tt.func @triton_fn(%[[P0:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[P1:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}) { CHECK: %[[PID:.*]] = tt.get_program_id x : i32 diff --git a/third_party/xla/xla/service/gpu/triton_support_legacy_test.cc b/third_party/xla/xla/service/gpu/triton_support_legacy_test.cc index d3452ea02b5a32..33063b7f81fce2 100644 --- a/third_party/xla/xla/service/gpu/triton_support_legacy_test.cc +++ b/third_party/xla/xla/service/gpu/triton_support_legacy_test.cc @@ -19,6 +19,7 @@ limitations under the License. 
#include #include #include +#include #include #include @@ -49,13 +50,23 @@ namespace xla { namespace gpu { namespace { -bool CombinationCrashesTriton( - PrimitiveType lhs_type, PrimitiveType rhs_type, PrimitiveType output_type, - se::CudaComputeCapability cuda_compute_capability) { - if (!cuda_compute_capability.IsAtLeastHopper() && - (lhs_type == F8E4M3FN || rhs_type == F8E4M3FN || - output_type == F8E4M3FN)) { - return true; +se::GpuComputeCapability GetComputeCapability() { + // TODO(b/348572380) Make this more general and test additional platforms. + return se::CudaComputeCapability::Ampere(); +} + +bool CombinationCrashesTriton(PrimitiveType lhs_type, PrimitiveType rhs_type, + PrimitiveType output_type, + se::GpuComputeCapability gpu_compute_capability) { + if (std::holds_alternative( + gpu_compute_capability)) { + auto cuda_compute_capability = + std::get(gpu_compute_capability); + if (!cuda_compute_capability.IsAtLeastHopper() && + (lhs_type == F8E4M3FN || rhs_type == F8E4M3FN || + output_type == F8E4M3FN)) { + return true; + } } return false; } @@ -64,7 +75,7 @@ class DotTest : public TritonSupportTestBaseWithParam { protected: void TestDotWithTypes(PrimitiveType lhs_type, PrimitiveType rhs_type, PrimitiveType output_type) { - if (lhs_type == BF16 && SkipBF16Tests()) { + if (lhs_type == BF16 && !SupportsBF16(GetComputeCapability())) { GTEST_SKIP(); } const HloOpcode opcode = HloOpcode::kDot; @@ -101,28 +112,29 @@ ENTRY e { TF_ASSERT_OK_AND_ASSIGN( TestedInstruction ti, ParseTemplateAndGetInstruction(hlo_test, /*data_type=*/{}, opcode)); - if (legacy_triton::IsTritonSupportedInstruction( - ti.Instruction(), GetCudaComputeCapability())) { - TF_EXPECT_OK(ApplyFloatNormalization(ti.Module().get())); + if (legacy_triton::IsTritonSupportedInstruction(ti.Instruction(), + GetComputeCapability())) { + TF_EXPECT_OK( + ApplyFloatNormalization(ti.Module().get(), GetComputeCapability())); EXPECT_TRUE(RunAndCompareNoHloPasses( std::move(ti.Module()), 
ErrorSpec{/*aabs=*/primitive_util::IsF8Type(lhs_type) ? 1.0 : 2e-4, /*arel=*/2e-4})); } else { if (CombinationCrashesTriton(lhs_type, rhs_type, output_type, - GetCudaComputeCapability())) { + GetComputeCapability())) { return; } const se::DeviceDescription dev_info = - TestGpuDeviceInfo::RTXA6000DeviceInfo(GetCudaComputeCapability()); + TestGpuDeviceInfo::RTXA6000DeviceInfo(GetComputeCapability()); BlockLevelParameters block_level_parameters; block_level_parameters.num_ctas = 1; block_level_parameters.num_stages = 4; block_level_parameters.num_warps = 8; EXPECT_THAT( - TritonWrapper("test_fn", &ti.TritonFusion(), - GetCudaComputeCapability(), dev_info, - block_level_parameters, &llvm_module_, mlir_context_), + TritonWrapper("test_fn", &ti.TritonFusion(), GetComputeCapability(), + dev_info, block_level_parameters, &llvm_module_, + mlir_context_), tsl::testing::StatusIs( absl::StatusCode::kInternal, ::testing::HasSubstr("Failed to compile Triton kernel"))); @@ -190,7 +202,7 @@ class DynamicSliceTest TEST_P(DynamicSliceTest, IsTritonSupportedDynamicSlice) { const DynamicSliceTestParam param(GetParam()); - if (param.data_type == BF16 && SkipBF16Tests()) { + if (param.data_type == BF16 && !SupportsBF16(GetComputeCapability())) { GTEST_SKIP(); } @@ -237,7 +249,7 @@ ENTRY e { const bool is_supported_instruction = legacy_triton::IsTritonSupportedInstruction(ti.Instruction(), - GetCudaComputeCapability()) + GetComputeCapability()) .CanFuse(); const bool is_supported_dynamic_slice = legacy_triton::IsTritonSupportedDynamicSlice( @@ -246,7 +258,8 @@ ENTRY e { EXPECT_EQ(is_supported_instruction, is_supported_dynamic_slice); if (is_supported_instruction) { - TF_EXPECT_OK(ApplyFloatNormalization(ti.Module().get())); + TF_EXPECT_OK( + ApplyFloatNormalization(ti.Module().get(), GetComputeCapability())); EXPECT_TRUE(RunAndCompareNoHloPasses( std::move(ti.Module()), ErrorSpec{/*aabs=*/2e-4, /*arel=*/2e-4})); } else { @@ -287,9 +300,9 @@ ENTRY e { ParseTemplateAndGetInstruction( 
kHloTest, /*data_type=*/{}, HloOpcode::kDot)); const se::DeviceDescription dev_info = - TestGpuDeviceInfo::RTXA6000DeviceInfo(GetCudaComputeCapability()); + TestGpuDeviceInfo::RTXA6000DeviceInfo(GetComputeCapability()); EXPECT_THAT(legacy_triton::IsTritonSupportedInstruction( - ti.Instruction(), GetCudaComputeCapability()) + ti.Instruction(), GetComputeCapability()) .Explain(), ::testing::HasSubstr("Unsupported output data type for Dot op.")); BlockLevelParameters block_level_parameters; @@ -297,7 +310,7 @@ ENTRY e { block_level_parameters.num_stages = 4; block_level_parameters.num_warps = 8; EXPECT_THAT( - TritonWrapper("test_fn", &ti.TritonFusion(), GetCudaComputeCapability(), + TritonWrapper("test_fn", &ti.TritonFusion(), GetComputeCapability(), dev_info, block_level_parameters, &llvm_module_, mlir_context_), tsl::testing::StatusIs( @@ -331,9 +344,9 @@ ENTRY e { ParseTemplateAndGetInstruction( kHloTest, /*data_type=*/{}, HloOpcode::kDot)); const se::DeviceDescription dev_info = - TestGpuDeviceInfo::RTXA6000DeviceInfo(GetCudaComputeCapability()); + TestGpuDeviceInfo::RTXA6000DeviceInfo(GetComputeCapability()); EXPECT_THAT(legacy_triton::IsTritonSupportedInstruction( - ti.Instruction(), GetCudaComputeCapability()) + ti.Instruction(), GetComputeCapability()) .Explain(), ::testing::HasSubstr("Multiple batch dimensions")); BlockLevelParameters block_level_parameters; @@ -341,7 +354,7 @@ ENTRY e { block_level_parameters.num_stages = 4; block_level_parameters.num_warps = 8; EXPECT_THAT( - TritonWrapper("test_fn", &ti.TritonFusion(), GetCudaComputeCapability(), + TritonWrapper("test_fn", &ti.TritonFusion(), GetComputeCapability(), dev_info, block_level_parameters, &llvm_module_, mlir_context_), tsl::testing::StatusIs(absl::StatusCode::kInternal, @@ -369,7 +382,7 @@ ENTRY e { ParseTemplateAndGetInstruction( kHloTest, /*data_type=*/{}, HloOpcode::kDot)); EXPECT_THAT(legacy_triton::IsTritonSupportedInstruction( - ti.Instruction(), GetCudaComputeCapability()) + 
ti.Instruction(), GetComputeCapability()) .Explain(), ::testing::HasSubstr("No non-contracting dimensions.")); EXPECT_THAT(TritonFusionAnalysis::Execute(ti.TritonComputation()), diff --git a/third_party/xla/xla/service/gpu/triton_support_test.cc b/third_party/xla/xla/service/gpu/triton_support_test.cc index c3cf0a2e191102..eeb9c617d34b6b 100644 --- a/third_party/xla/xla/service/gpu/triton_support_test.cc +++ b/third_party/xla/xla/service/gpu/triton_support_test.cc @@ -49,6 +49,11 @@ namespace { using ::testing::Not; using ::testing::status::IsOk; +se::GpuComputeCapability GetComputeCapability() { + // TODO(b/348572380) Make this more general and test additional platforms. + return se::CudaComputeCapability::Ampere(); +} + auto AllXlaDataTypes() { std::vector xla_data_types; std::vector to_filter_out = {PRIMITIVE_TYPE_INVALID, @@ -88,18 +93,18 @@ class TritonSupportTest : public TritonSupportTestBase { BlockLevelParameters block_level_parameters = FromOutputTileSizes(std::move(output_tile_sizes)); if (IsTritonSupportedInstruction(ti.Instruction(), - GetCudaComputeCapability())) { + GetComputeCapability())) { TF_EXPECT_OK(CreateTritonIrAndFileCheck(ti.TritonComputation(), block_level_parameters, "CHECK: tt.func @triton_fn")); } else { if (!skip_failure_branch_to_avoid_crash) { const se::DeviceDescription dev_info = - TestGpuDeviceInfo::RTXA6000DeviceInfo(GetCudaComputeCapability()); + TestGpuDeviceInfo::RTXA6000DeviceInfo(GetComputeCapability()); EXPECT_THAT( - TritonWrapper("test_fn", &ti.TritonFusion(), - GetCudaComputeCapability(), dev_info, - block_level_parameters, &llvm_module_, mlir_context_), + TritonWrapper("test_fn", &ti.TritonFusion(), GetComputeCapability(), + dev_info, block_level_parameters, &llvm_module_, + mlir_context_), Not(IsOk())); } } @@ -138,7 +143,7 @@ using UnaryElementwiseTest = TritonSupportTestWithParam; // instead of relying on triton gemm kernel. 
TEST_P(UnaryElementwiseTest, IsTritonSupportedUnaryElementwise) { auto [data_type, opcode] = GetParam(); - if (data_type == BF16 && SkipBF16Tests()) { + if (data_type == BF16 && !SupportsBF16(GetComputeCapability())) { GTEST_SKIP(); } @@ -181,7 +186,7 @@ using BinaryElementwiseTest = TritonSupportTestWithParam; TEST_P(BinaryElementwiseTest, IsTritonSupportedBinaryElementwise) { auto [data_type, opcode] = GetParam(); - if (data_type == BF16 && SkipBF16Tests()) { + if (data_type == BF16 && !SupportsBF16(GetComputeCapability())) { GTEST_SKIP(); } @@ -230,7 +235,7 @@ using CompareTest = TritonSupportTestWithParam; TEST_P(CompareTest, IsTritonSupportedCompare) { auto [data_type, opcode] = GetParam(); - if (data_type == BF16 && SkipBF16Tests()) { + if (data_type == BF16 && !SupportsBF16(GetComputeCapability())) { GTEST_SKIP(); } @@ -257,7 +262,7 @@ using TernaryElementwiseTest = TritonSupportTestWithParam; TEST_P(TernaryElementwiseTest, IsTritonSupportedTernaryElementwise) { auto [data_type, opcode] = GetParam(); - if (data_type == BF16 && SkipBF16Tests()) { + if (data_type == BF16 && !SupportsBF16(GetComputeCapability())) { GTEST_SKIP(); } @@ -285,7 +290,7 @@ using ReduceConstTest = TritonSupportTestWithParam; TEST_P(ReduceConstTest, IsTritonSupportedReduceWithConstInit) { auto [data_type, opcode] = GetParam(); - if (data_type == BF16 && SkipBF16Tests()) { + if (data_type == BF16 && !SupportsBF16(GetComputeCapability())) { GTEST_SKIP(); } @@ -316,7 +321,7 @@ INSTANTIATE_TEST_SUITE_P( TEST_F(TritonSupportTest, SupportedReduceWithConvertConstantIsCodegenedSuccessfullyWithTriton) { - if (SkipBF16Tests()) { + if (!SupportsBF16(GetComputeCapability())) { GTEST_SKIP(); } const std::string kHloTest = R"( @@ -336,9 +341,10 @@ ENTRY triton_computation { ParseTemplateAndGetInstruction( kHloTest, /*data_type=*/{}, HloOpcode::kReduce)); EXPECT_TRUE( - IsTritonSupportedInstruction(ti.Instruction(), GetCudaComputeCapability()) + IsTritonSupportedInstruction(ti.Instruction(), 
GetComputeCapability()) .CanFuse()); - TF_EXPECT_OK(ApplyFloatNormalization(ti.Module().get())); + TF_EXPECT_OK( + ApplyFloatNormalization(ti.Module().get(), GetComputeCapability())); TF_EXPECT_OK(CreateTritonIrAndFileCheck(ti.TritonComputation(), FromOutputTileSizes({1}), "CHECK: tt.func @triton_fn")); @@ -363,14 +369,14 @@ ENTRY triton_computation { ParseTemplateAndGetInstruction( kHloTest, /*data_type=*/{}, HloOpcode::kReduce)); EXPECT_THAT( - IsTritonSupportedInstruction(ti.Instruction(), GetCudaComputeCapability()) + IsTritonSupportedInstruction(ti.Instruction(), GetComputeCapability()) .Explain(), ::testing::HasSubstr( "Reduction is not a row-reduction of a single operand.")); const se::DeviceDescription dev_info = - TestGpuDeviceInfo::RTXA6000DeviceInfo(GetCudaComputeCapability()); + TestGpuDeviceInfo::RTXA6000DeviceInfo(GetComputeCapability()); EXPECT_THAT( - TritonWrapper("test_fn", &ti.TritonFusion(), GetCudaComputeCapability(), + TritonWrapper("test_fn", &ti.TritonFusion(), GetComputeCapability(), dev_info, FromOutputTileSizes({1}), &llvm_module_, mlir_context_), Not(IsOk())); @@ -394,14 +400,14 @@ ENTRY triton_computation { ParseTemplateAndGetInstruction( kHloTest, /*data_type=*/{}, HloOpcode::kReduce)); EXPECT_THAT( - IsTritonSupportedInstruction(ti.Instruction(), GetCudaComputeCapability()) + IsTritonSupportedInstruction(ti.Instruction(), GetComputeCapability()) .Explain(), ::testing::HasSubstr( "Reduction is not a row-reduction of a single operand.")); const se::DeviceDescription dev_info = - TestGpuDeviceInfo::RTXA6000DeviceInfo(GetCudaComputeCapability()); + TestGpuDeviceInfo::RTXA6000DeviceInfo(GetComputeCapability()); EXPECT_THAT( - TritonWrapper("test_fn", &ti.TritonFusion(), GetCudaComputeCapability(), + TritonWrapper("test_fn", &ti.TritonFusion(), GetComputeCapability(), dev_info, FromOutputTileSizes({1}), &llvm_module_, mlir_context_), Not(IsOk())); @@ -430,13 +436,13 @@ ENTRY triton_computation { ParseTemplateAndGetInstruction( kHloTest, 
/*data_type=*/{}, HloOpcode::kReduce)); EXPECT_THAT( - IsTritonSupportedInstruction(ti.Instruction(), GetCudaComputeCapability()) + IsTritonSupportedInstruction(ti.Instruction(), GetComputeCapability()) .Explain(), ::testing::HasSubstr("Unsupported output data type")); const se::DeviceDescription dev_info = - TestGpuDeviceInfo::RTXA6000DeviceInfo(GetCudaComputeCapability()); + TestGpuDeviceInfo::RTXA6000DeviceInfo(GetComputeCapability()); EXPECT_THAT( - TritonWrapper("test_fn", &ti.TritonFusion(), GetCudaComputeCapability(), + TritonWrapper("test_fn", &ti.TritonFusion(), GetComputeCapability(), dev_info, FromOutputTileSizes({1}), &llvm_module_, mlir_context_), Not(IsOk())); @@ -460,14 +466,14 @@ ENTRY triton_computation { ParseTemplateAndGetInstruction( kHloTest, /*data_type=*/{}, HloOpcode::kReduce)); const se::DeviceDescription dev_info = - TestGpuDeviceInfo::RTXA6000DeviceInfo(GetCudaComputeCapability()); + TestGpuDeviceInfo::RTXA6000DeviceInfo(GetComputeCapability()); EXPECT_THAT( - IsTritonSupportedInstruction(ti.Instruction(), GetCudaComputeCapability()) + IsTritonSupportedInstruction(ti.Instruction(), GetComputeCapability()) .Explain(), ::testing::HasSubstr("Reduction init value should be a constant " "or a convert of a constant.")); EXPECT_THAT( - TritonWrapper("test_fn", &ti.TritonFusion(), GetCudaComputeCapability(), + TritonWrapper("test_fn", &ti.TritonFusion(), GetComputeCapability(), dev_info, FromOutputTileSizes({1}), &llvm_module_, mlir_context_), tsl::testing::StatusIs( @@ -493,13 +499,13 @@ ENTRY triton_computation { ParseTemplateAndGetInstruction( kHloTest, /*data_type=*/{}, HloOpcode::kReduce)); const se::DeviceDescription dev_info = - TestGpuDeviceInfo::RTXA6000DeviceInfo(GetCudaComputeCapability()); + TestGpuDeviceInfo::RTXA6000DeviceInfo(GetComputeCapability()); EXPECT_THAT( - IsTritonSupportedInstruction(ti.Instruction(), GetCudaComputeCapability()) + IsTritonSupportedInstruction(ti.Instruction(), GetComputeCapability()) .Explain(), 
::testing::HasSubstr("Unsupported reduction computation by Triton.")); EXPECT_THAT( - TritonWrapper("test_fn", &ti.TritonFusion(), GetCudaComputeCapability(), + TritonWrapper("test_fn", &ti.TritonFusion(), GetComputeCapability(), dev_info, FromOutputTileSizes({1}), &llvm_module_, mlir_context_), tsl::testing::StatusIs(absl::StatusCode::kInvalidArgument, diff --git a/third_party/xla/xla/service/gpu/triton_test_utils.cc b/third_party/xla/xla/service/gpu/triton_test_utils.cc index 2117a476a95c18..9f4cd4cc8717bc 100644 --- a/third_party/xla/xla/service/gpu/triton_test_utils.cc +++ b/third_party/xla/xla/service/gpu/triton_test_utils.cc @@ -15,6 +15,7 @@ limitations under the License. #include "xla/service/gpu/triton_test_utils.h" +#include #include #include #include @@ -50,47 +51,38 @@ limitations under the License. #include "xla/status_macros.h" #include "xla/stream_executor/device_description.h" #include "xla/tests/filecheck.h" +#include "xla/tests/hlo_test_base.h" #include "xla/tests/verified_hlo_module.h" #include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" namespace xla::gpu { -bool TritonTest::SkipBF16Tests() { - if (std::holds_alternative( - GpuComputeComp())) { - auto rcc = device_desc().rocm_compute_capability(); - return !rcc.has_bf16_dtype_support(); +bool SupportsBF16(const stream_executor::GpuComputeCapability& cc) { + if (std::holds_alternative(cc)) { + return std::get(cc).IsAtLeast( + se::CudaComputeCapability::AMPERE); + } else if (std::holds_alternative( + cc)) { + return std::get(cc) + .has_bf16_dtype_support(); } - return !GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE); + CHECK(false); } -stream_executor::GpuComputeCapability TritonTest::CudaAmpereOrRocm() { - if (std::holds_alternative( - GpuComputeComp())) { - return stream_executor::GpuComputeCapability{ - device_desc().rocm_compute_capability()}; - } else { - return stream_executor::GpuComputeCapability{ - stream_executor::CudaComputeCapability{ - 
stream_executor::CudaComputeCapability::AMPERE, 0}}; - } -} - -absl::Status TritonFilecheckTest::CreateTritonIrAndFileCheck( - absl::string_view hlo_text, +absl::Status CreateTritonIrAndFileCheck( + HloTestBase* test, absl::string_view hlo_text, const BlockLevelParameters& block_level_parameters, absl::string_view triton_fusion_name, absl::string_view filecheck_pattern) { TF_ASSIGN_OR_RETURN(std::unique_ptr verified_module, - ParseAndReturnVerifiedModule(hlo_text)); + test->ParseAndReturnVerifiedModule(hlo_text)); auto* comp = verified_module->GetComputationWithName(triton_fusion_name); TF_RET_CHECK(comp != nullptr); return CreateTritonIrAndFileCheck(*comp, block_level_parameters, filecheck_pattern); } -absl::Status TritonFilecheckTest::CreateTritonIrAndFileCheck( +absl::Status CreateTritonIrAndFileCheck( const HloComputation& computation, const BlockLevelParameters& block_level_parameters, absl::string_view filecheck_pattern) { @@ -112,22 +104,23 @@ absl::Status TritonFilecheckTest::CreateTritonIrAndFileCheck( return absl::OkStatus(); } -absl::Status TritonFilecheckTest::CreateTritonIrAndFileCheckForDot( - absl::string_view hlo_text, absl::string_view triton_fusion_name, - absl::string_view filecheck_pattern) { - return CreateTritonIrAndFileCheck(hlo_text, /*block_level_parameters=*/{}, +absl::Status CreateTritonIrAndFileCheckForDot( + HloTestBase* test, absl::string_view hlo_text, + absl::string_view triton_fusion_name, absl::string_view filecheck_pattern) { + return CreateTritonIrAndFileCheck(test, hlo_text, + /*block_level_parameters=*/{}, triton_fusion_name, filecheck_pattern); } -absl::Status TritonFilecheckTest::CreateTritonIrAndFileCheckForDot( +absl::Status CreateTritonIrAndFileCheckForDot( const HloComputation& computation, absl::string_view filecheck_pattern) { return CreateTritonIrAndFileCheck(computation, /*block_level_parameters=*/{}, filecheck_pattern); } -absl::StatusOr TritonSupportTestBase::ApplyFloatNormalization( - HloModule* module) { - const 
GpuFloatSupport bf16_support(GetCudaComputeCapability(), BF16); +absl::StatusOr ApplyFloatNormalization( + HloModule* module, const stream_executor::GpuComputeCapability& cc) { + const GpuFloatSupport bf16_support(cc, BF16); HloPassPipeline pipeline("hlo float normalization"); pipeline.AddPass(&bf16_support); return pipeline.Run(module); diff --git a/third_party/xla/xla/service/gpu/triton_test_utils.h b/third_party/xla/xla/service/gpu/triton_test_utils.h index 09184145c32f0f..25a7b4b1811d27 100644 --- a/third_party/xla/xla/service/gpu/triton_test_utils.h +++ b/third_party/xla/xla/service/gpu/triton_test_utils.h @@ -28,6 +28,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Module.h" #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" @@ -36,62 +37,41 @@ limitations under the License. #include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/gpu/matmul_utils.h" #include "xla/service/gpu/model/tiled_hlo_computation.h" -#include "xla/service/gpu/tests/gpu_codegen_test.h" #include "xla/stream_executor/device_description.h" +#include "xla/tests/hlo_test_base.h" namespace xla::gpu { -class TritonTest : public GpuCodegenTest { - public: - stream_executor::CudaComputeCapability GetCudaComputeCapability() { - return backend() - .default_stream_executor() - ->GetDeviceDescription() - .cuda_compute_capability(); - } +bool SupportsBF16(const stream_executor::GpuComputeCapability& cc); - const stream_executor::GpuComputeCapability& GpuComputeComp() { - return device_desc().gpu_compute_capability(); - } +absl::Status CreateTritonIrAndFileCheck( + HloTestBase* test, absl::string_view hlo_text, + const BlockLevelParameters& block_level_parameters, + absl::string_view triton_fusion_name, absl::string_view filecheck_pattern); - bool SkipBF16Tests(); - stream_executor::GpuComputeCapability 
CudaAmpereOrRocm(); +absl::Status CreateTritonIrAndFileCheck( + const HloComputation& computation, + const BlockLevelParameters& block_level_parameters, + absl::string_view filecheck_pattern); - protected: - const stream_executor::DeviceDescription& device_desc() { - return backend().default_stream_executor()->GetDeviceDescription(); - } -}; +absl::Status CreateTritonIrAndFileCheckForDot( + HloTestBase* test, absl::string_view hlo_text, + absl::string_view triton_fusion_name, absl::string_view filecheck_pattern); -class TritonFilecheckTest : public TritonTest { - public: - absl::Status CreateTritonIrAndFileCheck( - absl::string_view hlo_text, - const BlockLevelParameters& block_level_parameters, - absl::string_view triton_fusion_name, - absl::string_view filecheck_pattern); - - absl::Status CreateTritonIrAndFileCheck( - const HloComputation& computation, - const BlockLevelParameters& block_level_parameters, - absl::string_view filecheck_pattern); - - absl::Status CreateTritonIrAndFileCheckForDot( - absl::string_view hlo_text, absl::string_view triton_fusion_name, - absl::string_view filecheck_pattern); - - absl::Status CreateTritonIrAndFileCheckForDot( - const HloComputation& computation, absl::string_view filecheck_pattern); - - BlockLevelParameters FromOutputTileSizes( - std::vector output_tile_sizes) { - BlockLevelParameters block_level_parameters; - block_level_parameters.output_tile_sizes = std::move(output_tile_sizes); - return block_level_parameters; - } -}; +absl::Status CreateTritonIrAndFileCheckForDot( + const HloComputation& computation, absl::string_view filecheck_pattern); -class TritonSupportTestBase : public TritonFilecheckTest { +inline BlockLevelParameters FromOutputTileSizes( + std::vector output_tile_sizes) { + BlockLevelParameters block_level_parameters; + block_level_parameters.output_tile_sizes = std::move(output_tile_sizes); + return block_level_parameters; +} + +absl::StatusOr ApplyFloatNormalization( + HloModule* module, const 
stream_executor::GpuComputeCapability& cc); + +class TritonSupportTestBase : public HloTestBase { protected: // An HLO module together with a reference to the instruction of interest // that's being tested. See ParseTemplateAndGetInstruction for more details. @@ -140,8 +120,6 @@ class TritonSupportTestBase : public TritonFilecheckTest { absl::string_view hlo_template, xla::PrimitiveType data_type, xla::HloOpcode opcode); - absl::StatusOr ApplyFloatNormalization(HloModule* module); - llvm::LLVMContext llvm_ctx_; llvm::Module llvm_module_{"module", llvm_ctx_}; mlir::MLIRContext mlir_context_; From 57cfb2047301cdb59374d580c5aee61983c216f6 Mon Sep 17 00:00:00 2001 From: Johannes Reifferscheid Date: Fri, 21 Jun 2024 04:15:16 -0700 Subject: [PATCH 109/256] Fix simplification of a//b//c. We previously checked if a is positive. This is not needed: c needs to be positive. PiperOrigin-RevId: 645343695 --- .../xla/xla/service/gpu/model/indexing_map.cc | 62 ++++++++++++------- .../service/gpu/model/indexing_map_test.cc | 10 +++ 2 files changed, 48 insertions(+), 24 deletions(-) diff --git a/third_party/xla/xla/service/gpu/model/indexing_map.cc b/third_party/xla/xla/service/gpu/model/indexing_map.cc index e3ec255cc4d978..4fcd4c4c865cbf 100644 --- a/third_party/xla/xla/service/gpu/model/indexing_map.cc +++ b/third_party/xla/xla/service/gpu/model/indexing_map.cc @@ -111,11 +111,16 @@ class AffineExprSimplifier { // - Rewrites a % b to a if a is known to be less than b. mlir::AffineExpr RewriteMod(mlir::AffineBinaryOpExpr mod); - // Simplifier for floordiv. - // - Rewrites (a * 100 + ...) / 100 to a + (...) / 100 - // - Rewrites a / 100 to 0 when a is known to be less than 100. + // Simplifier for floordiv. Uses all the rules defined below. mlir::AffineExpr RewriteFloorDiv(mlir::AffineBinaryOpExpr div); + // Rewrites `(c % ab) // a` to `(c // a) % b`. Returns nullptr on mismatch. 
+ AffineExpr SimplifyModDiv(AffineExpr divisor, int64_t dividend); + + // Rewrites `a // b // c` to `a // (b * c)` if `c` is positive. Returns + // nullptr on mismatch. + AffineExpr SimplifyDivDiv(AffineExpr divisor, int64_t dividend); + // Removes summands from arbitrarily nested sums (e.g, ((a+b)+c)) if `pred` // returns true. In this example, `pred` is evaluated on `a`, `b` and `c`, not // on `a+b`. @@ -226,6 +231,29 @@ AffineExpr AffineExprSimplifier::RewriteMod(AffineBinaryOpExpr mod) { return new_lhs % mod.getRHS() + extracted; } +AffineExpr AffineExprSimplifier::SimplifyModDiv(AffineExpr divisor, + int64_t dividend) { + if (auto mod = GetConstantRhs(divisor, AffineExprKind::Mod); + mod && (*mod % dividend == 0)) { + return GetLhs(divisor).floorDiv(dividend) % (*mod / dividend); + } + return nullptr; +} + +AffineExpr AffineExprSimplifier::SimplifyDivDiv(AffineExpr divisor, + int64_t dividend) { + // The outer dividend must be positive, since: + // (8 // -9) // -1 = -1 // -1 = 1 + // Whereas 8 // 9 = 0. The inner dividend can be negative. + if (dividend <= 0) { + return nullptr; + } + if (auto inner_dividend = GetConstantRhs(divisor, AffineExprKind::FloorDiv)) { + return GetLhs(divisor).floorDiv(dividend * *inner_dividend); + } + return nullptr; +} + AffineExpr AffineExprSimplifier::RewriteFloorDiv(AffineBinaryOpExpr div) { auto mlir_context = range_evaluator_->GetMLIRContext(); auto rhs_range = range_evaluator_->ComputeExpressionRange(div.getRHS()); @@ -242,27 +270,13 @@ AffineExpr AffineExprSimplifier::RewriteFloorDiv(AffineBinaryOpExpr div) { auto lhs_simplified = SimplifyOnce(div.getLHS()); // Rewrite `(c % ab) // a` to `(c // a) % b`. 
- // (c % ab) // a - // = (c - c // ab * ab) // a expand mod - // = c // a - (c // ab * b) rhs of - divides a - // = c // a - (c // a) // b * b) split ab - // = (c // a) % b contract mod - if (auto mod = GetConstantRhs(lhs_simplified, AffineExprKind::Mod); - mod && (*mod % d == 0)) { - return GetLhs(lhs_simplified).floorDiv(d) % (*mod / d); - } - - // Rewrite `(a / b) / c` to `a / (b * c)` if `a >= 0` and `b` and `c` are - // constants. - if (lhs_simplified.getKind() == AffineExprKind::FloorDiv) { - auto lhs_div = mlir::cast(lhs_simplified); - auto lhs_lhs = range_evaluator_->ComputeExpressionRange(lhs_div.getLHS()); - if (lhs_lhs.lower >= 0) { - auto lhs_rhs = range_evaluator_->ComputeExpressionRange(lhs_div.getRHS()); - if (lhs_rhs.IsPoint()) { - return lhs_div.getLHS().floorDiv(lhs_rhs.lower * d); - } - } + if (auto result = SimplifyModDiv(lhs_simplified, d)) { + return result; + } + + // Rewrite `((a // b) // c)` to `a // (b * c)`. + if (auto result = SimplifyDivDiv(lhs_simplified, d)) { + return result; } AffineExpr zero = getAffineConstantExpr(0, mlir_context); diff --git a/third_party/xla/xla/service/gpu/model/indexing_map_test.cc b/third_party/xla/xla/service/gpu/model/indexing_map_test.cc index 882ab81543923d..24bbf63fd385d4 100644 --- a/third_party/xla/xla/service/gpu/model/indexing_map_test.cc +++ b/third_party/xla/xla/service/gpu/model/indexing_map_test.cc @@ -715,6 +715,16 @@ TEST_F(IndexingMapTest, AffineMapSimplification_DivsInSequence) { )")); } +TEST_F(IndexingMapTest, AffineMapSimplification_NegativeDiv) { + // (s0 floordiv 2) floordiv -7 is not s0 floordiv -14: + // 15 // 2 // -7 = -1 + // 15 // -14 = -2 + auto serialized_map = "()[s0] -> ((s0 floordiv 2) floordiv -7)"; + IndexingMap indexing_map = IndexingMap::FromTensorSizes( + ParseAffineMap(serialized_map, &mlir_context_), {}, {1234}); + EXPECT_FALSE(indexing_map.Simplify()); +} + TEST_F(IndexingMapTest, AffineMapSimplification_ExtractFromMod) { auto serialized_map = "()[s0, s1, s2, s3] -> 
((s0 * 458752 + s1 + s2 * 4 + s3 * 512) mod " From 9aea1b91ec78ff1f7dfb91ad44ecc83414760044 Mon Sep 17 00:00:00 2001 From: Adrian Kuegel Date: Fri, 21 Jun 2024 04:34:20 -0700 Subject: [PATCH 110/256] Remove unnecessary dependencies. When splitting xla_ops target into several targets, the original dependencies with alwayslink attribute were added to all new targets. This is the right thing to do when having no additional information. But here, we know that these dependencies are just needed from xla_ops target, as the new targets are not meant to be depended on directly. Only in case there is a direct header include, we still need the dependency (one time on light_outside_compilation, one time on while_op). PiperOrigin-RevId: 645347508 --- tensorflow/compiler/tf2xla/kernels/BUILD | 1186 +++------------------- 1 file changed, 132 insertions(+), 1054 deletions(-) diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index 8a902aefbfc826..9de549cd94782a 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -48,7 +48,6 @@ cc_library( ":clip_by_value_op", ":concat_op", ":const_op", - ":conv_op_helpers", ":conv_ops", ":cross_op", ":cwise_ops", @@ -109,7 +108,6 @@ cc_library( ":retval_op", ":reverse_op", ":reverse_sequence_op", - ":rng_converter_utils", ":roll_op", ":scan_ops", ":scatter_nd_op", @@ -137,7 +135,6 @@ cc_library( ":strided_slice_op", ":tensor_array_ops", ":tensor_list_ops", - ":tensor_list_utils", ":tile_ops", ":to_bool_op", ":topk_op", @@ -525,10 +522,6 @@ tf_kernel_library( name = "xla_dot_op", srcs = ["xla_dot_op.cc"], deps = [ - ":case_op", - ":if_op", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:mlir_xla_op_kernel", "//tensorflow/compiler/tf2xla:xla_compilation_device", @@ -542,21 +535,14 @@ tf_kernel_library( "//tensorflow/core:protos_all_cc", "@local_xla//xla:xla_data_proto_cc", 
"@local_xla//xla/client:xla_builder", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) tf_kernel_library( name = "training_ops", srcs = ["training_ops.cc"], deps = [ - ":case_op", ":cwise_ops", - ":if_op", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:xla_compilation_device", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla:xla_context", @@ -569,20 +555,13 @@ tf_kernel_library( "@local_xla//xla/client:xla_builder", "@local_xla//xla/client/lib:constants", "@local_xla//xla/client/lib:math", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) tf_kernel_library( name = "batch_matmul_op", srcs = ["batch_matmul_op.cc"], deps = [ - ":case_op", - ":if_op", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compilation_device", "//tensorflow/compiler/tf2xla:xla_compiler", @@ -596,20 +575,13 @@ tf_kernel_library( "@local_xla//xla:xla_data_proto_cc", "@local_xla//xla/client/lib:math", "@local_xla//xla/client/lib:matrix", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ) + if_static(["@local_tsl//tsl/platform:tensor_float_32_utils"]), + ], ) tf_kernel_library( name = "softmax_op", srcs = ["softmax_op.cc"], deps = [ - ":case_op", - ":if_op", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:mlir_xla_op_kernel", "//tensorflow/compiler/tf2xla:xla_compilation_device", @@ -629,20 +601,13 @@ tf_kernel_library( "@local_xla//xla/client:xla_builder", "@local_xla//xla/client:xla_computation", "@local_xla//xla/client/lib:constants", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) tf_kernel_library( name = "in_topk_op", srcs = ["in_topk_op.cc"], deps = [ - ":case_op", - ":if_op", - ":while_op", - ":xla_call_module_op", 
"//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compilation_device", "//tensorflow/compiler/tf2xla:xla_compiler", @@ -660,23 +625,16 @@ tf_kernel_library( "@local_xla//xla/client/lib:arithmetic", "@local_xla//xla/client/lib:constants", "@local_xla//xla/client/lib:sorting", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) tf_kernel_library( name = "unary_ops_composition", srcs = ["unary_ops_composition.cc"], deps = [ - ":case_op", ":cwise_ops", ":elu_op", - ":if_op", ":relu_op", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compilation_device", "//tensorflow/compiler/tf2xla:xla_compiler", @@ -693,20 +651,13 @@ tf_kernel_library( "@local_xla//xla/client/lib:arithmetic", "@local_xla//xla/client/lib:constants", "@local_xla//xla/client/lib:math", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) tf_kernel_library( name = "topk_op", srcs = ["topk_op.cc"], deps = [ - ":case_op", - ":if_op", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compilation_device", "//tensorflow/compiler/tf2xla:xla_compiler", @@ -719,21 +670,14 @@ tf_kernel_library( "@local_xla//xla:literal", "@local_xla//xla/client:xla_builder", "@local_xla//xla/client/lib:sorting", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) tf_kernel_library( name = "tensor_array_ops", srcs = ["tensor_array_ops.cc"], deps = [ - ":case_op", ":gather_op", - ":if_op", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compilation_device", "//tensorflow/compiler/tf2xla:xla_compiler", @@ -746,20 +690,13 @@ tf_kernel_library( "//tensorflow/core:lib", "@local_xla//xla:literal", "@local_xla//xla/client:xla_builder", - ] + if_cuda_or_rocm( - if_false = [], - if_true = 
[":light_outside_compilation"], - ), + ], ) tf_kernel_library( name = "tile_ops", srcs = ["tile_ops.cc"], deps = [ - ":case_op", - ":if_op", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compilation_device", "//tensorflow/compiler/tf2xla:xla_compiler", @@ -775,20 +712,13 @@ tf_kernel_library( "@com_google_absl//absl/types:span", "@local_xla//xla/client:value_inference", "@local_xla//xla/client:xla_builder", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) tf_kernel_library( name = "strided_slice_op", srcs = ["strided_slice_op.cc"], deps = [ - ":case_op", - ":if_op", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compilation_device", "//tensorflow/compiler/tf2xla:xla_compiler", @@ -808,20 +738,13 @@ tf_kernel_library( "@local_xla//xla/client:xla_builder", "@local_xla//xla/client/lib:constants", "@local_xla//xla/client/lib:dynamic_shaped_ops", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) tf_kernel_library( name = "xla_broadcast_helper_op", srcs = ["xla_broadcast_helper_op.cc"], deps = [ - ":case_op", - ":if_op", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compilation_device", "//tensorflow/compiler/tf2xla:xla_compiler", @@ -836,20 +759,13 @@ tf_kernel_library( "@com_google_absl//absl/strings", "@local_xla//xla:xla_data_proto_cc", "@local_xla//xla/client:xla_builder", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) tf_kernel_library( name = "xla_svd_op", srcs = ["xla_svd_op.cc"], deps = [ - ":case_op", - ":if_op", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:xla_compilation_device", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla:xla_context", @@ -861,10 +777,7 @@ 
tf_kernel_library( "@local_xla//xla/client/lib:constants", "@local_xla//xla/client/lib:slicing", "@local_xla//xla/client/lib:svd", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) tf_kernel_library( @@ -872,10 +785,6 @@ tf_kernel_library( srcs = ["reduction_ops.cc"], hdrs = ["reduction_ops.h"], deps = [ - ":case_op", - ":if_op", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compilation_device", "//tensorflow/compiler/tf2xla:xla_compiler", @@ -894,20 +803,13 @@ tf_kernel_library( "@local_xla//xla:xla_data_proto_cc", "@local_xla//xla/client:xla_builder", "@local_xla//xla/client/lib:constants", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) tf_kernel_library( name = "batchtospace_op", srcs = ["batchtospace_op.cc"], deps = [ - ":case_op", - ":if_op", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:xla_compilation_device", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla:xla_context", @@ -916,20 +818,13 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla:xla_resource", "//tensorflow/compiler/tf2xla/ops:xla_ops", "@local_xla//xla/client:xla_builder", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) tf_kernel_library( name = "spmd_manual_sharding_ops", srcs = ["spmd_manual_sharding_ops.cc"], deps = [ - ":case_op", - ":if_op", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/mlir/tensorflow:xla_sharding_util", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compilation_device", @@ -944,21 +839,14 @@ tf_kernel_library( "@local_xla//xla:util", "@local_xla//xla:xla_data_proto_cc", "@local_xla//xla/client:xla_builder", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) tf_kernel_library( name = "fused_conv_ops", srcs = ["fused_conv_ops.cc"], 
deps = [ - ":case_op", ":conv_op_helpers", - ":if_op", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:xla_compilation_device", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla:xla_context", @@ -970,20 +858,13 @@ tf_kernel_library( "//tensorflow/core:lib", "@local_xla//xla/client/lib:constants", "@local_xla//xla/client/lib:math", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) tf_kernel_library( name = "matrix_band_part_op", srcs = ["matrix_band_part_op.cc"], deps = [ - ":case_op", - ":if_op", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:mlir_xla_op_kernel", "//tensorflow/compiler/tf2xla:xla_compilation_device", "//tensorflow/compiler/tf2xla:xla_compiler", @@ -992,20 +873,13 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla:xla_op_registry", "//tensorflow/compiler/tf2xla:xla_resource", "//tensorflow/compiler/tf2xla/ops:xla_ops", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) tf_kernel_library( name = "clip_by_value_op", srcs = ["clip_by_value_op.cc"], deps = [ - ":case_op", - ":if_op", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:xla_compilation_device", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla:xla_context", @@ -1015,20 +889,13 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/core:framework", "@local_xla//xla/client:xla_builder", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) tf_kernel_library( name = "sharding_util_ops", srcs = ["sharding_util_ops.cc"], deps = [ - ":case_op", - ":if_op", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compilation_device", "//tensorflow/compiler/tf2xla:xla_compiler", @@ -1045,20 +912,13 @@ tf_kernel_library( "@local_xla//xla:xla_data_proto_cc", 
"@local_xla//xla/client:xla_builder", "@local_xla//xla/client/lib:constants", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) tf_kernel_library( name = "sort_ops", srcs = ["sort_ops.cc"], deps = [ - ":case_op", - ":if_op", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:mlir_xla_op_kernel", "//tensorflow/compiler/tf2xla:xla_compilation_device", @@ -1070,20 +930,13 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla/ops:xla_ops", "@local_xla//xla/client:xla_builder", "@local_xla//xla/client/lib:comparators", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) tf_kernel_library( name = "matrix_inverse_op", srcs = ["matrix_inverse_op.cc"], deps = [ - ":case_op", - ":if_op", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:xla_compilation_device", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla:xla_context", @@ -1097,20 +950,13 @@ tf_kernel_library( "@local_xla//xla/client:xla_builder", "@local_xla//xla/client/lib:matrix", "@local_xla//xla/client/lib:qr", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) tf_kernel_library( name = "reduce_window_op", srcs = ["reduce_window_op.cc"], deps = [ - ":case_op", - ":if_op", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:mlir_xla_op_kernel", "//tensorflow/compiler/tf2xla:xla_compilation_device", @@ -1123,20 +969,13 @@ tf_kernel_library( "//tensorflow/core:framework", "@local_xla//xla/client:xla_builder", "@local_xla//xla/client:xla_computation", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) tf_kernel_library( name = "pad_op", srcs = ["pad_op.cc"], deps = [ - ":case_op", - ":if_op", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:common", 
"//tensorflow/compiler/tf2xla:xla_compilation_device", "//tensorflow/compiler/tf2xla:xla_compiler", @@ -1152,20 +991,13 @@ tf_kernel_library( "@local_xla//xla:xla_data_proto_cc", "@local_xla//xla/client:value_inference", "@local_xla//xla/client:xla_builder", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) tf_kernel_library( name = "function_ops", srcs = ["function_ops.cc"], deps = [ - ":case_op", - ":if_op", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compilation_device", "//tensorflow/compiler/tf2xla:xla_compiler", @@ -1176,20 +1008,13 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/core:framework", "//tensorflow/core:protos_all_cc", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) tf_kernel_library( name = "sparse_to_dense_op", srcs = ["sparse_to_dense_op.cc"], deps = [ - ":case_op", - ":if_op", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:xla_compilation_device", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla:xla_context", @@ -1198,20 +1023,13 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla:xla_resource", "//tensorflow/compiler/tf2xla/lib:scatter", "//tensorflow/compiler/tf2xla/ops:xla_ops", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) tf_kernel_library( name = "reverse_op", srcs = ["reverse_op.cc"], deps = [ - ":case_op", - ":if_op", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compilation_device", "//tensorflow/compiler/tf2xla:xla_compiler", @@ -1225,21 +1043,14 @@ tf_kernel_library( "@com_google_absl//absl/container:inlined_vector", "@local_xla//xla:literal", "@local_xla//xla/client:xla_builder", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) 
tf_kernel_library( name = "random_ops", srcs = ["random_ops.cc"], deps = [ - ":case_op", ":gather_op", - ":if_op", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:mlir_xla_op_kernel", "//tensorflow/compiler/tf2xla:xla_compilation_device", @@ -1266,10 +1077,7 @@ tf_kernel_library( "@local_xla//xla/client/lib:constants", "@local_xla//xla/client/lib:dynamic_shaped_ops", "@local_xla//xla/client/lib:loops", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) tf_kernel_library( @@ -1277,10 +1085,6 @@ tf_kernel_library( srcs = ["gather_op.cc"], hdrs = ["gather_op_helpers.h"], deps = [ - ":case_op", - ":if_op", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:mlir_xla_op_kernel", "//tensorflow/compiler/tf2xla:xla_compilation_device", @@ -1297,20 +1101,13 @@ tf_kernel_library( "@local_xla//xla/client:client_library", "@local_xla//xla/client:xla_builder", "@local_xla//xla/client/lib:slicing", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) tf_kernel_library( name = "segment_reduction_ops", srcs = ["segment_reduction_ops.cc"], deps = [ - ":case_op", - ":if_op", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compilation_device", "//tensorflow/compiler/tf2xla:xla_compiler", @@ -1323,20 +1120,13 @@ tf_kernel_library( "@local_xla//xla/client:value_inference", "@local_xla//xla/client:xla_builder", "@local_xla//xla/client/lib:constants", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) tf_kernel_library( name = "dynamic_partition_op", srcs = ["dynamic_partition_op.cc"], deps = [ - ":case_op", - ":if_op", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compilation_device", 
"//tensorflow/compiler/tf2xla:xla_compiler", @@ -1355,20 +1145,13 @@ tf_kernel_library( "@local_xla//xla/client/lib:arithmetic", "@local_xla//xla/client/lib:comparators", "@local_xla//xla/client/lib:constants", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) tf_kernel_library( name = "transpose_op", srcs = ["transpose_op.cc"], deps = [ - ":case_op", - ":if_op", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compilation_device", "//tensorflow/compiler/tf2xla:xla_compiler", @@ -1381,21 +1164,14 @@ tf_kernel_library( "//tensorflow/core:framework", "@local_xla//xla:shape_util", "@local_xla//xla/client:xla_builder", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) tf_kernel_library( name = "identity_op", srcs = ["identity_op.cc"], deps = [ - ":case_op", - ":if_op", ":tensor_list_utils", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:mlir_xla_op_kernel", "//tensorflow/compiler/tf2xla:xla_compilation_device", "//tensorflow/compiler/tf2xla:xla_compiler", @@ -1407,22 +1183,15 @@ tf_kernel_library( "//tensorflow/core:framework", "//tensorflow/core:protos_all_cc", "@com_google_absl//absl/log:check", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) tf_kernel_library( name = "tensor_list_ops", srcs = ["tensor_list_ops.cc"], deps = [ - ":case_op", ":gather_op", - ":if_op", ":tensor_list_utils", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compilation_device", "//tensorflow/compiler/tf2xla:xla_compiler", @@ -1437,20 +1206,13 @@ tf_kernel_library( "@local_xla//xla:status_macros", "@local_xla//xla:xla_data_proto_cc", "@local_xla//xla/client:xla_builder", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) tf_kernel_library( name = "cross_op", 
srcs = ["cross_op.cc"], deps = [ - ":case_op", - ":if_op", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:xla_compilation_device", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla:xla_context", @@ -1459,20 +1221,13 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla:xla_resource", "//tensorflow/compiler/tf2xla/ops:xla_ops", "@local_xla//xla/client:xla_builder", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) tf_kernel_library( name = "next_after_op", srcs = ["next_after_op.cc"], deps = [ - ":case_op", - ":if_op", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compilation_device", "//tensorflow/compiler/tf2xla:xla_compiler", @@ -1486,21 +1241,14 @@ tf_kernel_library( "//tensorflow/core:lib", "@local_xla//xla/client:xla_builder", "@local_xla//xla/client/lib:math", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) tf_kernel_library( name = "stochastic_cast_op", srcs = ["stochastic_cast_op.cc"], deps = [ - ":case_op", - ":if_op", ":random_ops_util", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compilation_device", "//tensorflow/compiler/tf2xla:xla_compiler", @@ -1514,20 +1262,13 @@ tf_kernel_library( "//tensorflow/core/kernels:stochastic_cast_op_header", "@local_xla//xla/client:xla_builder", "@local_xla//xla/client/lib:constants", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) tf_kernel_library( name = "select_op", srcs = ["select_op.cc"], deps = [ - ":case_op", - ":if_op", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:mlir_xla_op_kernel", "//tensorflow/compiler/tf2xla:xla_compilation_device", @@ -1540,20 +1281,13 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla/ops:xla_ops", 
"//tensorflow/core:framework", "@local_xla//xla/client:xla_builder", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) tf_kernel_library( name = "l2loss_op", srcs = ["l2loss_op.cc"], deps = [ - ":case_op", - ":if_op", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:xla_compilation_device", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla:xla_context", @@ -1564,20 +1298,13 @@ tf_kernel_library( "//tensorflow/core:framework", "//tensorflow/core:protos_all_cc", "@local_xla//xla/client:xla_builder", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) tf_kernel_library( name = "arg_op", srcs = ["arg_op.cc"], deps = [ - ":case_op", - ":if_op", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compilation_device", "//tensorflow/compiler/tf2xla:xla_compiler", @@ -1590,10 +1317,7 @@ tf_kernel_library( "//tensorflow/core:lib", "@local_xla//xla:literal_util", "@local_xla//xla/client:xla_builder", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) tf_kernel_library( @@ -1601,10 +1325,6 @@ tf_kernel_library( srcs = ["shape_util.cc"], hdrs = ["shape_util.h"], deps = [ - ":case_op", - ":if_op", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:xla_compilation_device", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla:xla_context", @@ -1613,20 +1333,13 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla:xla_resource", "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/core:framework", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) tf_kernel_library( name = "to_bool_op", srcs = ["to_bool_op.cc"], deps = [ - ":case_op", - ":if_op", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:xla_compilation_device", 
"//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla:xla_context", @@ -1637,20 +1350,13 @@ tf_kernel_library( "//tensorflow/core:lib", "@local_xla//xla/client:xla_builder", "@local_xla//xla/client/lib:constants", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) tf_kernel_library( name = "mirror_pad_op", srcs = ["mirror_pad_op.cc"], deps = [ - ":case_op", - ":if_op", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:xla_compilation_device", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla:xla_context", @@ -1667,20 +1373,13 @@ tf_kernel_library( "@local_xla//xla:util", "@local_xla//xla/client:xla_builder", "@local_xla//xla/client/lib:constants", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) tf_kernel_library( name = "xla_reduce_op", srcs = ["xla_reduce_op.cc"], deps = [ - ":case_op", - ":if_op", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:mlir_xla_op_kernel", "//tensorflow/compiler/tf2xla:xla_compilation_device", @@ -1695,20 +1394,13 @@ tf_kernel_library( "@com_google_absl//absl/algorithm:container", "@local_xla//xla/client:xla_builder", "@local_xla//xla/client:xla_computation", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) tf_kernel_library( name = "qr_op", srcs = ["qr_op.cc"], deps = [ - ":case_op", - ":if_op", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:xla_compilation_device", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla:xla_context", @@ -1719,20 +1411,13 @@ tf_kernel_library( "//tensorflow/core:framework", "@local_xla//xla/client:xla_builder", "@local_xla//xla/client/lib:qr", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) tf_kernel_library( name = "slice_op", srcs = ["slice_op.cc"], 
deps = [ - ":case_op", - ":if_op", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compilation_device", "//tensorflow/compiler/tf2xla:xla_compiler", @@ -1749,20 +1434,13 @@ tf_kernel_library( "@local_xla//xla/client:xla_builder", "@local_xla//xla/client/lib:constants", "@local_xla//xla/client/lib:dynamic_shaped_ops", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) tf_kernel_library( name = "pack_op", srcs = ["pack_op.cc"], deps = [ - ":case_op", - ":if_op", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compilation_device", "//tensorflow/compiler/tf2xla:xla_compiler", @@ -1775,20 +1453,13 @@ tf_kernel_library( "//tensorflow/core:lib", "@local_xla//xla:literal_util", "@local_xla//xla/client:xla_builder", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) tf_kernel_library( name = "all_reduce_op", srcs = ["all_reduce_op.cc"], deps = [ - ":case_op", - ":if_op", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:mlir_xla_op_kernel", "//tensorflow/compiler/tf2xla:xla_compilation_device", @@ -1803,20 +1474,13 @@ tf_kernel_library( "@local_xla//xla/client:xla_builder", "@local_xla//xla/client/lib:constants", "@local_xla//xla/client/lib:math", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) tf_kernel_library( name = "sharding_op", srcs = ["sharding_op.cc"], deps = [ - ":case_op", - ":if_op", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compilation_device", "//tensorflow/compiler/tf2xla:xla_compiler", @@ -1828,21 +1492,14 @@ tf_kernel_library( "//tensorflow/core:framework", "@local_xla//xla:sharding_op_util", "@local_xla//xla/client:xla_builder", - ] + if_cuda_or_rocm( - if_false = 
[], - if_true = [":light_outside_compilation"], - ), + ], ) tf_kernel_library( name = "concat_op", srcs = ["concat_op.cc"], deps = [ - ":case_op", - ":if_op", ":shape_util", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compilation_device", "//tensorflow/compiler/tf2xla:xla_compiler", @@ -1855,20 +1512,13 @@ tf_kernel_library( "//tensorflow/core:lib", "@local_xla//xla:literal_util", "@local_xla//xla/client:xla_builder", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) tf_kernel_library( name = "scan_ops", srcs = ["scan_ops.cc"], deps = [ - ":case_op", - ":if_op", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compilation_device", "//tensorflow/compiler/tf2xla:xla_compiler", @@ -1882,10 +1532,7 @@ tf_kernel_library( "@local_xla//xla:literal", "@local_xla//xla/client:xla_builder", "@local_xla//xla/client:xla_computation", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) tf_kernel_library( @@ -1893,10 +1540,6 @@ tf_kernel_library( srcs = ["image_resize_ops.cc"], hdrs = ["image_resize_ops.h"], deps = [ - ":case_op", - ":if_op", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/jit:xla_activity_listener", "//tensorflow/compiler/jit:xla_activity_proto_cc", "//tensorflow/compiler/tf2xla:common", @@ -1921,20 +1564,13 @@ tf_kernel_library( "@local_xla//xla/client:xla_builder", "@local_xla//xla/client:xla_computation", "@local_xla//xla/client/lib:constants", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ] + if_cuda_or_rocm([":light_outside_compilation"]), ) tf_kernel_library( name = "spacetobatch_op", srcs = ["spacetobatch_op.cc"], deps = [ - ":case_op", - ":if_op", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:xla_compilation_device", 
"//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla:xla_context", @@ -1944,20 +1580,13 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/core/util:overflow", "@local_xla//xla/client:xla_builder", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) tf_kernel_library( name = "lrn_ops", srcs = ["lrn_ops.cc"], deps = [ - ":case_op", - ":if_op", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:xla_compilation_device", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla:xla_context", @@ -1969,21 +1598,14 @@ tf_kernel_library( "//tensorflow/core/platform:errors", "@local_xla//xla/client:padding", "@local_xla//xla/client:xla_builder", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) tf_kernel_library( name = "binary_ops", srcs = ["binary_ops.cc"], deps = [ - ":case_op", ":cwise_ops", - ":if_op", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compilation_device", "//tensorflow/compiler/tf2xla:xla_compiler", @@ -2000,20 +1622,13 @@ tf_kernel_library( "@local_xla//xla/client:xla_builder", "@local_xla//xla/client/lib:constants", "@local_xla//xla/client/lib:math", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) tf_kernel_library( name = "roll_op", srcs = ["roll_op.cc"], deps = [ - ":case_op", - ":if_op", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:xla_compilation_device", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla:xla_context", @@ -2024,10 +1639,7 @@ tf_kernel_library( "//tensorflow/core:lib", "@local_xla//xla/client/lib:constants", "@local_xla//xla/client/lib:slicing", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) tf_kernel_library( @@ -2035,11 +1647,7 @@ tf_kernel_library( 
srcs = ["random_ops_util.cc"], hdrs = ["random_ops_util.h"], deps = [ - ":case_op", - ":if_op", ":rng_converter_utils", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compilation_device", "//tensorflow/compiler/tf2xla:xla_compiler", @@ -2062,21 +1670,14 @@ tf_kernel_library( "@local_xla//xla/client:xla_builder", "@local_xla//xla/client/lib:constants", "@local_xla//xla/client/lib:prng", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) tf_kernel_library( name = "unary_ops", srcs = ["unary_ops.cc"], deps = [ - ":case_op", ":cwise_ops", - ":if_op", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:mlir_xla_op_kernel", "//tensorflow/compiler/tf2xla:xla_compilation_device", @@ -2094,10 +1695,7 @@ tf_kernel_library( "@local_xla//xla/client/lib:arithmetic", "@local_xla//xla/client/lib:constants", "@local_xla//xla/client/lib:math", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) tf_kernel_library( @@ -2105,10 +1703,6 @@ tf_kernel_library( srcs = ["cwise_ops.cc"], hdrs = ["cwise_ops.h"], deps = [ - ":case_op", - ":if_op", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compilation_device", "//tensorflow/compiler/tf2xla:xla_compiler", @@ -2126,20 +1720,13 @@ tf_kernel_library( "@local_xla//xla/client:client_library", "@local_xla//xla/client:xla_builder", "@local_xla//xla/client/lib:constants", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) tf_kernel_library( name = "matrix_triangular_solve_op", srcs = ["matrix_triangular_solve_op.cc"], deps = [ - ":case_op", - ":if_op", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:xla_compilation_device", "//tensorflow/compiler/tf2xla:xla_compiler", 
"//tensorflow/compiler/tf2xla:xla_context", @@ -2152,10 +1739,7 @@ tf_kernel_library( "//tensorflow/core/platform:errors", "@local_xla//xla:xla_data_proto_cc", "@local_xla//xla/client:xla_builder", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) tf_kernel_library( @@ -2163,10 +1747,6 @@ tf_kernel_library( srcs = ["relu_op.cc"], hdrs = ["relu_op.h"], deps = [ - ":case_op", - ":if_op", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:mlir_xla_op_kernel", "//tensorflow/compiler/tf2xla:xla_compilation_device", "//tensorflow/compiler/tf2xla:xla_compiler", @@ -2179,21 +1759,14 @@ tf_kernel_library( "@local_xla//xla:literal", "@local_xla//xla/client:xla_builder", "@local_xla//xla/client/lib:constants", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) tf_kernel_library( name = "reduction_ops_common", srcs = ["reduction_ops_common.cc"], deps = [ - ":case_op", - ":if_op", ":reduction_ops", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compilation_device", "//tensorflow/compiler/tf2xla:xla_compiler", @@ -2215,20 +1788,13 @@ tf_kernel_library( "@local_xla//xla:xla_data_proto_cc", "@local_xla//xla/client:xla_builder", "@local_xla//xla/client:xla_computation", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) tf_kernel_library( name = "bucketize_op", srcs = ["bucketize_op.cc"], deps = [ - ":case_op", - ":if_op", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:xla_compilation_device", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla:xla_context", @@ -2239,20 +1805,13 @@ tf_kernel_library( "//tensorflow/core:framework", "@local_xla//xla/client:xla_builder", "@local_xla//xla/client/lib:arithmetic", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) 
tf_kernel_library( name = "depthtospace_op", srcs = ["depthtospace_op.cc"], deps = [ - ":case_op", - ":if_op", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:xla_compilation_device", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla:xla_context", @@ -2263,20 +1822,13 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/core:framework", "@local_xla//xla/client:xla_builder", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) tf_kernel_library( name = "xla_optimization_barrier_op", srcs = ["xla_optimization_barrier_op.cc"], deps = [ - ":case_op", - ":if_op", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:mlir_xla_op_kernel", "//tensorflow/compiler/tf2xla:xla_compilation_device", "//tensorflow/compiler/tf2xla:xla_compiler", @@ -2285,20 +1837,13 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla:xla_op_registry", "//tensorflow/compiler/tf2xla:xla_resource", "//tensorflow/compiler/tf2xla/ops:xla_ops", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) tf_kernel_library( name = "matmul_op", srcs = ["matmul_op.cc"], deps = [ - ":case_op", - ":if_op", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:xla_compilation_device", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla:xla_context", @@ -2313,20 +1858,13 @@ tf_kernel_library( "@local_xla//xla:xla_data_proto_cc", "@local_xla//xla/client:xla_builder", "@local_xla//xla/client/lib:matrix", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ) + if_static(["@local_tsl//tsl/platform:tensor_float_32_utils"]), + ], ) tf_kernel_library( name = "matrix_solve_op", srcs = ["matrix_solve_op.cc"], deps = [ - ":case_op", - ":if_op", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:xla_compilation_device", "//tensorflow/compiler/tf2xla:xla_compiler", 
"//tensorflow/compiler/tf2xla:xla_context", @@ -2340,21 +1878,14 @@ tf_kernel_library( "@local_xla//xla/client:xla_builder", "@local_xla//xla/client/lib:matrix", "@local_xla//xla/client/lib:qr", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) tf_kernel_library( name = "categorical_op", srcs = ["categorical_op.cc"], deps = [ - ":case_op", - ":if_op", ":random_ops_util", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compilation_device", "//tensorflow/compiler/tf2xla:xla_compiler", @@ -2370,20 +1901,13 @@ tf_kernel_library( "@local_xla//xla/client/lib:arithmetic", "@local_xla//xla/client/lib:constants", "@local_xla//xla/client/lib:prng", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) tf_kernel_library( name = "dynamic_stitch_op", srcs = ["dynamic_stitch_op.cc"], deps = [ - ":case_op", - ":if_op", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compilation_device", "//tensorflow/compiler/tf2xla:xla_compiler", @@ -2395,20 +1919,13 @@ tf_kernel_library( "//tensorflow/core:framework", "@local_xla//xla:literal_util", "@local_xla//xla/client:xla_builder", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) tf_kernel_library( name = "beta_op", srcs = ["beta_op.cc"], deps = [ - ":case_op", - ":if_op", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:xla_compilation_device", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla:xla_context", @@ -2423,20 +1940,13 @@ tf_kernel_library( "@local_xla//xla/client/lib:constants", "@local_xla//xla/client/lib:loops", "@local_xla//xla/client/lib:math", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) tf_kernel_library( name = "unique_op", srcs = ["unique_op.cc"], deps = [ - 
":case_op", - ":if_op", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compilation_device", "//tensorflow/compiler/tf2xla:xla_compiler", @@ -2459,20 +1969,13 @@ tf_kernel_library( "@local_xla//xla/client/lib:arithmetic", "@local_xla//xla/client/lib:comparators", "@local_xla//xla/client/lib:constants", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) tf_kernel_library( name = "reshape_op", srcs = ["reshape_op.cc"], deps = [ - ":case_op", - ":if_op", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compilation_device", "//tensorflow/compiler/tf2xla:xla_compiler", @@ -2490,20 +1993,13 @@ tf_kernel_library( "@local_xla//xla/client:value_inference", "@local_xla//xla/client:xla_builder", "@local_xla//xla/client/lib:constants", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) tf_kernel_library( name = "pooling_ops", srcs = ["pooling_ops.cc"], deps = [ - ":case_op", - ":if_op", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:mlir_xla_op_kernel", "//tensorflow/compiler/tf2xla:xla_compilation_device", @@ -2530,20 +2026,13 @@ tf_kernel_library( "@local_xla//xla/client/lib:arithmetic", "@local_xla//xla/client/lib:constants", "@local_xla//xla/client/lib:pooling", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) tf_kernel_library( name = "data_format_ops", srcs = ["data_format_ops.cc"], deps = [ - ":case_op", - ":if_op", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:xla_compilation_device", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla:xla_context", @@ -2555,20 +2044,13 @@ tf_kernel_library( "//tensorflow/core:lib", "@local_xla//xla/client:xla_builder", "@local_xla//xla/client/lib:slicing", - ] + 
if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) tf_kernel_library( name = "xla_dequantize_op", srcs = ["xla_dequantize_op.cc"], deps = [ - ":case_op", - ":if_op", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compilation_device", "//tensorflow/compiler/tf2xla:xla_compiler", @@ -2581,20 +2063,13 @@ tf_kernel_library( "@local_xla//xla:xla_data_proto_cc", "@local_xla//xla/client:xla_builder", "@local_xla//xla/client/lib:quantize", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) tf_kernel_library( name = "const_op", srcs = ["const_op.cc"], deps = [ - ":case_op", - ":if_op", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compilation_device", "//tensorflow/compiler/tf2xla:xla_compiler", @@ -2606,22 +2081,15 @@ tf_kernel_library( "//tensorflow/core:framework", "//tensorflow/core:protos_all_cc", "@local_xla//xla/client:xla_builder", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) tf_kernel_library( name = "shape_op", srcs = ["shape_op.cc"], deps = [ - ":case_op", - ":if_op", ":shape_util", ":tensor_list_utils", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compilation_device", "//tensorflow/compiler/tf2xla:xla_compiler", @@ -2641,21 +2109,14 @@ tf_kernel_library( "@local_xla//xla:shape_util", "@local_xla//xla/client:xla_builder", "@local_xla//xla/client/lib:constants", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) tf_kernel_library( name = "image_ops", srcs = ["image_ops.cc"], deps = [ - ":case_op", ":gather_op", - ":if_op", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:mlir_xla_op_kernel", "//tensorflow/compiler/tf2xla:xla_compilation_device", 
"//tensorflow/compiler/tf2xla:xla_compiler", @@ -2680,20 +2141,13 @@ tf_kernel_library( "@local_xla//xla/client/lib:dynamic_shaped_ops", "@local_xla//xla/client/lib:loops", "@local_xla//xla/client/lib:sorting", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) tf_kernel_library( name = "retval_op", srcs = ["retval_op.cc"], deps = [ - ":case_op", - ":if_op", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compilation_device", "//tensorflow/compiler/tf2xla:xla_compiler", @@ -2707,20 +2161,13 @@ tf_kernel_library( "//tensorflow/core/platform:errors", "@local_xla//xla:status_macros", "@local_xla//xla/client:xla_builder", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) tf_kernel_library( name = "xla_custom_call_v2_op", srcs = ["xla_custom_call_v2_op.cc"], deps = [ - ":case_op", - ":if_op", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compilation_device", "//tensorflow/compiler/tf2xla:xla_compiler", @@ -2735,20 +2182,13 @@ tf_kernel_library( "@local_xla//xla:shape_util", "@local_xla//xla:xla_data_proto_cc", "@local_xla//xla/client:xla_builder", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) tf_kernel_library( name = "listdiff_op", srcs = ["listdiff_op.cc"], deps = [ - ":case_op", - ":if_op", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compilation_device", "//tensorflow/compiler/tf2xla:xla_compiler", @@ -2763,20 +2203,13 @@ tf_kernel_library( "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/status", "@local_xla//xla/client:xla_builder", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) tf_kernel_library( name = "sendrecv_ops", srcs = ["sendrecv_ops.cc"], deps = [ 
- ":case_op", - ":if_op", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compilation_device", "//tensorflow/compiler/tf2xla:xla_compiler", @@ -2788,21 +2221,14 @@ tf_kernel_library( "//tensorflow/core:framework", "@local_xla//xla:xla_data_proto_cc", "@local_xla//xla/client:xla_builder", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) tf_kernel_library( name = "conv_ops", srcs = ["conv_ops.cc"], deps = [ - ":case_op", ":conv_op_helpers", - ":if_op", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compilation_device", "//tensorflow/compiler/tf2xla:xla_compiler", @@ -2817,20 +2243,13 @@ tf_kernel_library( "@local_xla//xla/client:xla_builder", "@local_xla//xla/client/lib:constants", "@local_xla//xla/client/lib:matrix", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) tf_kernel_library( name = "dequantize_op", srcs = ["dequantize_op.cc"], deps = [ - ":case_op", - ":if_op", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:xla_compilation_device", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla:xla_context", @@ -2843,20 +2262,13 @@ tf_kernel_library( "@local_xla//xla/client:xla_builder", "@local_xla//xla/client/lib:constants", "@local_xla//xla/client/lib:matrix", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) tf_kernel_library( name = "ensure_shape_op", srcs = ["ensure_shape_op.cc"], deps = [ - ":case_op", - ":if_op", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compilation_device", "//tensorflow/compiler/tf2xla:xla_compiler", @@ -2868,20 +2280,13 @@ tf_kernel_library( "//tensorflow/core:framework", "@local_xla//xla:literal", "@local_xla//xla/client:xla_builder", - ] + if_cuda_or_rocm( - 
if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) tf_kernel_library( name = "where_op", srcs = ["where_op.cc"], deps = [ - ":case_op", - ":if_op", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compilation_device", "//tensorflow/compiler/tf2xla:xla_compiler", @@ -2903,20 +2308,13 @@ tf_kernel_library( "@local_xla//xla/client/lib:comparators", "@local_xla//xla/client/lib:constants", "@local_xla//xla/client/lib:dynamic_shaped_ops", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) tf_kernel_library( name = "stack_ops", srcs = ["stack_ops.cc"], deps = [ - ":case_op", - ":if_op", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compilation_device", "//tensorflow/compiler/tf2xla:xla_compiler", @@ -2928,20 +2326,13 @@ tf_kernel_library( "//tensorflow/core:framework", "//tensorflow/core:lib", "@local_xla//xla:literal", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) tf_kernel_library( name = "xla_reduce_precision_op", srcs = ["xla_reduce_precision_op.cc"], deps = [ - ":case_op", - ":if_op", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:mlir_xla_op_kernel", "//tensorflow/compiler/tf2xla:xla_compilation_device", "//tensorflow/compiler/tf2xla:xla_compiler", @@ -2950,20 +2341,13 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla:xla_op_registry", "//tensorflow/compiler/tf2xla:xla_resource", "//tensorflow/compiler/tf2xla/ops:xla_ops", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) tf_kernel_library( name = "diag_op", srcs = ["diag_op.cc"], deps = [ - ":case_op", - ":if_op", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:mlir_xla_op_kernel", "//tensorflow/compiler/tf2xla:xla_compilation_device", @@ 
-2981,10 +2365,7 @@ tf_kernel_library( "@local_xla//xla/client/lib:constants", "@local_xla//xla/client/lib:matrix", "@local_xla//xla/client/lib:pooling", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) tf_kernel_library( @@ -2992,10 +2373,6 @@ tf_kernel_library( srcs = ["index_ops.cc"], hdrs = ["index_ops.h"], deps = [ - ":case_op", - ":if_op", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compilation_device", "//tensorflow/compiler/tf2xla:xla_compiler", @@ -3011,20 +2388,13 @@ tf_kernel_library( "@local_xla//xla:xla_data_proto_cc", "@local_xla//xla/client:xla_builder", "@local_xla//xla/client/lib:arithmetic", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) tf_kernel_library( name = "lower_upper_bound_ops", srcs = ["lower_upper_bound_ops.cc"], deps = [ - ":case_op", - ":if_op", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compilation_device", "//tensorflow/compiler/tf2xla:xla_compiler", @@ -3038,20 +2408,13 @@ tf_kernel_library( "//tensorflow/core/platform:errors", "@local_xla//xla:comparison_util", "@local_xla//xla/client:xla_builder", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) tf_kernel_library( name = "spacetodepth_op", srcs = ["spacetodepth_op.cc"], deps = [ - ":case_op", - ":if_op", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:xla_compilation_device", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla:xla_context", @@ -3062,20 +2425,13 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/core:framework", "@local_xla//xla/client:xla_builder", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) tf_kernel_library( name = "empty_op", srcs = ["empty_op.cc"], deps = [ - 
":case_op", - ":if_op", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compilation_device", "//tensorflow/compiler/tf2xla:xla_compiler", @@ -3087,20 +2443,13 @@ tf_kernel_library( "//tensorflow/core:framework", "@local_xla//xla/client:xla_builder", "@local_xla//xla/client/lib:constants", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) tf_kernel_library( name = "bincount_op", srcs = ["bincount_op.cc"], deps = [ - ":case_op", - ":if_op", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compilation_device", "//tensorflow/compiler/tf2xla:xla_compiler", @@ -3114,20 +2463,13 @@ tf_kernel_library( "@local_xla//xla/client/lib:arithmetic", "@local_xla//xla/client/lib:comparators", "@local_xla//xla/client/lib:constants", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) tf_kernel_library( name = "tridiagonal_ops", srcs = ["tridiagonal_ops.cc"], deps = [ - ":case_op", - ":if_op", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:xla_compilation_device", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla:xla_context", @@ -3139,20 +2481,13 @@ tf_kernel_library( "//tensorflow/core:lib", "@local_xla//xla/client/lib:slicing", "@local_xla//xla/client/lib:tridiagonal", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) tf_kernel_library( name = "device_index_op", srcs = ["device_index_op.cc"], deps = [ - ":case_op", - ":if_op", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:xla_compilation_device", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla:xla_context", @@ -3168,20 +2503,13 @@ tf_kernel_library( "@local_xla//xla/client/lib:arithmetic", "@local_xla//xla/client/lib:constants", "@local_xla//xla/client/lib:math", - ] + 
if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) tf_kernel_library( name = "bcast_ops", srcs = ["bcast_ops.cc"], deps = [ - ":case_op", - ":if_op", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:xla_compilation_device", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla:xla_context", @@ -3194,21 +2522,14 @@ tf_kernel_library( "@com_google_absl//absl/strings", "@local_xla//xla:literal", "@local_xla//xla/client:value_inference", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) tf_kernel_library( name = "aggregate_ops", srcs = ["aggregate_ops.cc"], deps = [ - ":case_op", - ":if_op", ":tensor_list_utils", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:xla_compilation_device", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla:xla_context", @@ -3218,20 +2539,13 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/core:lib", "@local_xla//xla/client:xla_builder", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) tf_kernel_library( name = "split_op", srcs = ["split_op.cc"], deps = [ - ":case_op", - ":if_op", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compilation_device", "//tensorflow/compiler/tf2xla:xla_compiler", @@ -3244,20 +2558,13 @@ tf_kernel_library( "//tensorflow/core:lib", "@local_xla//xla:literal", "@local_xla//xla/client:xla_builder", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) tf_kernel_library( name = "replica_id_op", srcs = ["replica_id_op.cc"], deps = [ - ":case_op", - ":if_op", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:mlir_xla_op_kernel", "//tensorflow/compiler/tf2xla:xla_compilation_device", "//tensorflow/compiler/tf2xla:xla_compiler", @@ -3267,20 +2574,13 
@@ tf_kernel_library( "//tensorflow/compiler/tf2xla:xla_resource", "//tensorflow/compiler/tf2xla/ops:xla_ops", "@local_xla//xla/client:xla_builder", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) tf_kernel_library( name = "bias_ops", srcs = ["bias_ops.cc"], deps = [ - ":case_op", - ":if_op", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:mlir_xla_op_kernel", "//tensorflow/compiler/tf2xla:xla_compilation_device", "//tensorflow/compiler/tf2xla:xla_compiler", @@ -3291,20 +2591,14 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/core:framework", "@local_xla//xla/client:xla_builder", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) tf_kernel_library( name = "xla_select_and_scatter_op", srcs = ["xla_select_and_scatter_op.cc"], deps = [ - ":case_op", - ":if_op", ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:mlir_xla_op_kernel", "//tensorflow/compiler/tf2xla:xla_compilation_device", @@ -3317,22 +2611,15 @@ tf_kernel_library( "//tensorflow/core:framework", "@local_xla//xla/client:xla_builder", "@local_xla//xla/client:xla_computation", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) tf_kernel_library( name = "stateless_random_ops_v2", srcs = ["stateless_random_ops_v2.cc"], deps = [ - ":case_op", - ":if_op", ":random_ops_util", ":rng_converter_utils", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:mlir_xla_op_kernel", "//tensorflow/compiler/tf2xla:xla_compilation_device", @@ -3350,20 +2637,13 @@ tf_kernel_library( "@local_xla//xla/client/lib:constants", "@local_xla//xla/client/lib:dynamic_shaped_ops", "@local_xla//xla/client/lib:prng", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) tf_kernel_library( name = 
"approx_topk_op", srcs = ["approx_topk_op.cc"], deps = [ - ":case_op", - ":if_op", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:xla_compilation_device", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla:xla_context", @@ -3381,21 +2661,14 @@ tf_kernel_library( "@local_xla//xla/client:xla_builder", "@local_xla//xla/client:xla_computation", "@local_xla//xla/client/lib:approx_topk", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) tf_kernel_library( name = "stateful_random_ops", srcs = ["stateful_random_ops.cc"], deps = [ - ":case_op", - ":if_op", ":random_ops_util", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compilation_device", "//tensorflow/compiler/tf2xla:xla_compiler", @@ -3413,20 +2686,13 @@ tf_kernel_library( "@local_xla//xla/client/lib:constants", "@local_xla//xla/client/lib:math", "@local_xla//xla/client/lib:prng", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) tf_kernel_library( name = "no_op", srcs = ["no_op.cc"], deps = [ - ":case_op", - ":if_op", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:xla_compilation_device", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla:xla_context", @@ -3435,20 +2701,13 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla:xla_resource", "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/core:framework", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) tf_kernel_library( name = "xla_conv_op", srcs = ["xla_conv_op.cc"], deps = [ - ":case_op", - ":if_op", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:mlir_xla_op_kernel", "//tensorflow/compiler/tf2xla:xla_compilation_device", @@ -3461,20 +2720,13 @@ tf_kernel_library( "//tensorflow/core:framework", 
"@local_xla//xla:xla_data_proto_cc", "@local_xla//xla/client:xla_builder", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) tf_kernel_library( name = "broadcast_to_op", srcs = ["broadcast_to_op.cc"], deps = [ - ":case_op", - ":if_op", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:xla_compilation_device", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla:xla_context", @@ -3485,20 +2737,13 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/core:lib", "@local_xla//xla/client:xla_builder", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) tf_kernel_library( name = "sequence_ops", srcs = ["sequence_ops.cc"], deps = [ - ":case_op", - ":if_op", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:xla_compilation_device", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla:xla_context", @@ -3512,22 +2757,15 @@ tf_kernel_library( "@local_xla//xla:xla_data_proto_cc", "@local_xla//xla/client:xla_builder", "@local_xla//xla/client/lib:constants", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) tf_kernel_library( name = "variable_ops", srcs = ["variable_ops.cc"], deps = [ - ":case_op", ":gather_op", - ":if_op", ":shape_util", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compilation_device", "//tensorflow/compiler/tf2xla:xla_compiler", @@ -3543,20 +2781,13 @@ tf_kernel_library( "@local_xla//xla:literal", "@local_xla//xla/client:xla_builder", "@local_xla//xla/client/lib:slicing", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) tf_kernel_library( name = "matrix_diag_ops", srcs = ["matrix_diag_ops.cc"], deps = [ - ":case_op", - ":if_op", - ":while_op", - ":xla_call_module_op", 
"//tensorflow/compiler/tf2xla:mlir_xla_op_kernel", "//tensorflow/compiler/tf2xla:xla_compilation_device", "//tensorflow/compiler/tf2xla:xla_compiler", @@ -3576,20 +2807,13 @@ tf_kernel_library( "@local_xla//xla/client:xla_builder", "@local_xla//xla/client/lib:constants", "@local_xla//xla/client/lib:matrix", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) tf_kernel_library( name = "reverse_sequence_op", srcs = ["reverse_sequence_op.cc"], deps = [ - ":case_op", - ":if_op", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compilation_device", "//tensorflow/compiler/tf2xla:xla_compiler", @@ -3605,20 +2829,13 @@ tf_kernel_library( "@local_xla//xla:xla_data_proto_cc", "@local_xla//xla/client:xla_builder", "@local_xla//xla/client/lib:constants", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) tf_kernel_library( name = "xla_custom_call_op", srcs = ["xla_custom_call_op.cc"], deps = [ - ":case_op", - ":if_op", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compilation_device", "//tensorflow/compiler/tf2xla:xla_compiler", @@ -3633,20 +2850,13 @@ tf_kernel_library( "//tensorflow/core/tpu:tpu_defs", "@local_xla//xla:xla_data_proto_cc", "@local_xla//xla/client:xla_builder", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) tf_kernel_library( name = "xla_self_adjoint_eig_op", srcs = ["xla_self_adjoint_eig_op.cc"], deps = [ - ":case_op", - ":if_op", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:xla_compilation_device", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla:xla_context", @@ -3656,20 +2866,13 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/core:lib", "@local_xla//xla/client/lib:self_adjoint_eig", - ] + if_cuda_or_rocm( - 
if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) tf_kernel_library( name = "cast_op", srcs = ["cast_op.cc"], deps = [ - ":case_op", - ":if_op", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compilation_device", "//tensorflow/compiler/tf2xla:xla_compiler", @@ -3687,20 +2890,13 @@ tf_kernel_library( "@local_xla//xla/client:xla_builder", "@local_xla//xla/client/lib:arithmetic", "@local_xla//xla/client/lib:constants", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) tf_kernel_library( name = "dynamic_slice_ops", srcs = ["dynamic_slice_ops.cc"], deps = [ - ":case_op", - ":if_op", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:mlir_xla_op_kernel", "//tensorflow/compiler/tf2xla:xla_compilation_device", @@ -3712,20 +2908,13 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/core:framework", "@local_xla//xla/client:xla_builder", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) tf_kernel_library( name = "fft_ops", srcs = ["fft_ops.cc"], deps = [ - ":case_op", - ":if_op", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:mlir_xla_op_kernel", "//tensorflow/compiler/tf2xla:xla_compilation_device", "//tensorflow/compiler/tf2xla:xla_compiler", @@ -3738,20 +2927,13 @@ tf_kernel_library( "@local_xla//xla:literal_util", "@local_xla//xla:util", "@local_xla//xla/client:xla_builder", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) tf_kernel_library( name = "xla_pad_op", srcs = ["xla_pad_op.cc"], deps = [ - ":case_op", - ":if_op", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compilation_device", "//tensorflow/compiler/tf2xla:xla_compiler", @@ -3764,20 +2946,13 @@ tf_kernel_library( 
"@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/strings", "@local_xla//xla/client:xla_builder", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) tf_kernel_library( name = "one_hot_op", srcs = ["one_hot_op.cc"], deps = [ - ":case_op", - ":if_op", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compilation_device", "//tensorflow/compiler/tf2xla:xla_compiler", @@ -3790,20 +2965,13 @@ tf_kernel_library( "//tensorflow/core:portable_gif_internal", "//tensorflow/core/platform:errors", "@local_xla//xla/client:xla_builder", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) tf_kernel_library( name = "unpack_op", srcs = ["unpack_op.cc"], deps = [ - ":case_op", - ":if_op", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compilation_device", "//tensorflow/compiler/tf2xla:xla_compiler", @@ -3816,10 +2984,7 @@ tf_kernel_library( "//tensorflow/core:lib", "@local_xla//xla:literal", "@local_xla//xla/client:xla_builder", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) tf_kernel_library( @@ -3827,11 +2992,7 @@ tf_kernel_library( srcs = ["elu_op.cc"], hdrs = ["elu_op.h"], deps = [ - ":case_op", ":cwise_ops", - ":if_op", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:xla_compilation_device", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla:xla_context", @@ -3843,21 +3004,14 @@ tf_kernel_library( "@local_xla//xla:literal", "@local_xla//xla/client:xla_builder", "@local_xla//xla/client/lib:constants", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) tf_kernel_library( name = "batch_norm_op", srcs = ["batch_norm_op.cc"], deps = [ - ":case_op", - ":if_op", ":relu_op", - ":while_op", - ":xla_call_module_op", 
"//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:mlir_xla_op_kernel", "//tensorflow/compiler/tf2xla:xla_compilation_device", @@ -3872,21 +3026,14 @@ tf_kernel_library( "@local_xla//xla/client:xla_builder", "@local_xla//xla/client/lib:constants", "@local_xla//xla/client/lib:math", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) tf_kernel_library( name = "extract_image_patches_op", srcs = ["extract_image_patches_op.cc"], deps = [ - ":case_op", ":conv_op_helpers", - ":if_op", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compilation_device", "//tensorflow/compiler/tf2xla:xla_compiler", @@ -3902,20 +3049,13 @@ tf_kernel_library( "@local_xla//xla/client:xla_builder", "@local_xla//xla/client/lib:constants", "@local_xla//xla/client/lib:matrix", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) tf_kernel_library( name = "scatter_nd_op", srcs = ["scatter_nd_op.cc"], deps = [ - ":case_op", - ":if_op", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compilation_device", "//tensorflow/compiler/tf2xla:xla_compiler", @@ -3930,20 +3070,13 @@ tf_kernel_library( "@local_xla//xla:status_macros", "@local_xla//xla:xla_data_proto_cc", "@local_xla//xla/client:xla_builder", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) tf_kernel_library( name = "fill_op", srcs = ["fill_op.cc"], deps = [ - ":case_op", - ":if_op", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compilation_device", "//tensorflow/compiler/tf2xla:xla_compiler", @@ -3955,21 +3088,14 @@ tf_kernel_library( "//tensorflow/core:framework", "@local_xla//xla/client:value_inference", "@local_xla//xla/client:xla_builder", - ] + if_cuda_or_rocm( - if_false = [], - if_true = 
[":light_outside_compilation"], - ), + ], ) tf_kernel_library( name = "stateless_random_ops", srcs = ["stateless_random_ops.cc"], deps = [ - ":case_op", - ":if_op", ":random_ops_util", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compilation_device", "//tensorflow/compiler/tf2xla:xla_compiler", @@ -3987,20 +3113,13 @@ tf_kernel_library( "@local_xla//xla/client/lib:constants", "@local_xla//xla/client/lib:math", "@local_xla//xla/client/lib:prng", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) tf_kernel_library( name = "gather_scatter_ops", srcs = ["gather_scatter_ops.cc"], deps = [ - ":case_op", - ":if_op", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:xla_compilation_device", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla:xla_context", @@ -4012,20 +3131,13 @@ tf_kernel_library( "//tensorflow/core:protos_all_cc", "@local_xla//xla:xla_data_proto_cc", "@local_xla//xla/client:xla_builder", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) tf_kernel_library( name = "einsum_op", srcs = ["einsum_op.cc"], deps = [ - ":case_op", - ":if_op", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:mlir_xla_op_kernel", "//tensorflow/compiler/tf2xla:xla_compilation_device", "//tensorflow/compiler/tf2xla:xla_compiler", @@ -4037,20 +3149,13 @@ tf_kernel_library( "//tensorflow/core:framework", "@local_xla//xla/client:xla_builder", "@local_xla//xla/client/lib:matrix", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) tf_kernel_library( name = "cholesky_op", srcs = ["cholesky_op.cc"], deps = [ - ":case_op", - ":if_op", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:xla_compilation_device", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla:xla_context", @@ -4060,20 
+3165,13 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla/ops:xla_ops", "@local_xla//xla/client:xla_builder", "@local_xla//xla/client/lib:matrix", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) tf_kernel_library( name = "quantize_and_dequantize_op", srcs = ["quantize_and_dequantize_op.cc"], deps = [ - ":case_op", - ":if_op", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compilation_device", "//tensorflow/compiler/tf2xla:xla_compiler", @@ -4093,20 +3191,13 @@ tf_kernel_library( "@local_xla//xla/client/lib:arithmetic", "@local_xla//xla/client/lib:constants", "@local_xla//xla/client/lib:math", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) tf_kernel_library( name = "fake_quantize_ops", srcs = ["fake_quantize_ops.cc"], deps = [ - ":case_op", - ":if_op", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:mlir_xla_op_kernel", "//tensorflow/compiler/tf2xla:xla_compilation_device", "//tensorflow/compiler/tf2xla:xla_compiler", @@ -4118,20 +3209,13 @@ tf_kernel_library( "//tensorflow/core:lib", "@local_xla//xla/client:xla_builder", "@local_xla//xla/client/lib:arithmetic", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) tf_kernel_library( name = "fake_param_op", srcs = ["fake_param_op.cc"], deps = [ - ":case_op", - ":if_op", - ":while_op", - ":xla_call_module_op", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compilation_device", "//tensorflow/compiler/tf2xla:xla_compiler", @@ -4142,10 +3226,7 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/core:framework", "@local_xla//xla/client/lib:constants", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) tf_cc_test( @@ -4166,8 +3247,5 @@ tf_kernel_library( deps = [ 
"//tensorflow/compiler/tf2xla:mlir_xla_op_kernel", "//tensorflow/compiler/tf2xla:xla_op_registry", - ] + if_cuda_or_rocm( - if_false = [], - if_true = [":light_outside_compilation"], - ), + ], ) From e193e941b790215bf444aa69e6cf4fdec4d03e70 Mon Sep 17 00:00:00 2001 From: Johannes Reifferscheid Date: Fri, 21 Jun 2024 05:27:06 -0700 Subject: [PATCH 111/256] NFC: Replace RemoveSummands with MapSummands. The latter is more general and slightly nicer to use. PiperOrigin-RevId: 645357823 --- .../xla/xla/service/gpu/model/indexing_map.cc | 79 ++++++++----------- 1 file changed, 32 insertions(+), 47 deletions(-) diff --git a/third_party/xla/xla/service/gpu/model/indexing_map.cc b/third_party/xla/xla/service/gpu/model/indexing_map.cc index 4fcd4c4c865cbf..9b7045310ecd5c 100644 --- a/third_party/xla/xla/service/gpu/model/indexing_map.cc +++ b/third_party/xla/xla/service/gpu/model/indexing_map.cc @@ -121,11 +121,10 @@ class AffineExprSimplifier { // nullptr on mismatch. AffineExpr SimplifyDivDiv(AffineExpr divisor, int64_t dividend); - // Removes summands from arbitrarily nested sums (e.g, ((a+b)+c)) if `pred` - // returns true. In this example, `pred` is evaluated on `a`, `b` and `c`, not - // on `a+b`. - mlir::AffineExpr RemoveSummands( - mlir::AffineExpr expr, const std::function& pred); + // Rewrites summands in arbitrarily nested sums (e.g, ((a+b)+c)) by applying + // `fn` to each one. In the example, the result is fn(a)+fn(b)+fn(c). 
+ AffineExpr MapSummands(AffineExpr expr, + const std::function& fn); void VisitSummands(mlir::AffineExpr expr, const std::function& visit); @@ -181,19 +180,20 @@ AffineExpr AffineExprSimplifier::RewriteMod(AffineBinaryOpExpr mod) { return (GetLhs(lhs_simplified) % (m / *mul)) * *mul; } + auto zero = getAffineConstantExpr(0, mod.getContext()); int64_t extracted_constant = 0; - auto new_lhs = RemoveSummands(lhs_simplified, [&](AffineExpr expr) { + auto new_lhs = MapSummands(lhs_simplified, [&](AffineExpr expr) { if (auto cst = mlir::dyn_cast(expr); cst && cst.getValue() >= m) { extracted_constant += cst.getValue(); - return true; + return zero; } if (auto multiplier = GetConstantRhs(expr, AffineExprKind::Mul)) { if (*multiplier % m == 0) { - return true; + return zero; } } - return false; + return expr; }); new_lhs = new_lhs + (extracted_constant % m); @@ -218,12 +218,12 @@ AffineExpr AffineExprSimplifier::RewriteMod(AffineBinaryOpExpr mod) { if (m % *multiplier_gcd == 0 && no_multiplier_range.lower >= 0 && no_multiplier_range.upper < *multiplier_gcd) { // Remove everything that doesn't have a multiplier. - new_lhs = RemoveSummands(new_lhs, [&](AffineExpr expr) { + new_lhs = MapSummands(new_lhs, [&](AffineExpr expr) { if (GetConstantRhs(expr, AffineExprKind::Mul)) { - return false; + return expr; } extracted = extracted + expr; - return true; + return zero; }); } } @@ -242,14 +242,9 @@ AffineExpr AffineExprSimplifier::SimplifyModDiv(AffineExpr divisor, AffineExpr AffineExprSimplifier::SimplifyDivDiv(AffineExpr divisor, int64_t dividend) { - // The outer dividend must be positive, since: - // (8 // -9) // -1 = -1 // -1 = 1 - // Whereas 8 // 9 = 0. The inner dividend can be negative. - if (dividend <= 0) { - return nullptr; - } - if (auto inner_dividend = GetConstantRhs(divisor, AffineExprKind::FloorDiv)) { - return GetLhs(divisor).floorDiv(dividend * *inner_dividend); + // The inner dividend here can be negative. 
+ if (auto dividend_2 = GetConstantRhs(divisor, AffineExprKind::FloorDiv)) { + return GetLhs(divisor).floorDiv(dividend * *dividend_2); } return nullptr; } @@ -261,8 +256,8 @@ AffineExpr AffineExprSimplifier::RewriteFloorDiv(AffineBinaryOpExpr div) { // TODO(jreiffers): Split this function into multiple (one for each rewrite // rule). - // The logic below assumes we have a constant RHS. - if (!rhs_range.IsPoint()) { + // The logic below assumes we have a constant positive RHS. + if (!rhs_range.IsPoint() || rhs_range.lower <= 0) { return div; } int64_t d = rhs_range.lower; @@ -281,7 +276,7 @@ AffineExpr AffineExprSimplifier::RewriteFloorDiv(AffineBinaryOpExpr div) { AffineExpr zero = getAffineConstantExpr(0, mlir_context); AffineExpr extracted = zero; - auto new_dividend = RemoveSummands(lhs_simplified, [&](AffineExpr expr) { + auto new_dividend = MapSummands(lhs_simplified, [&](AffineExpr expr) { if (auto multiplier = GetConstantRhs(expr, AffineExprKind::Mul)) { // (x * 7 + ...) / 3 -> can't extract. We could extract x * 2 and keep // one x, but we currently have no reason to do that. @@ -289,18 +284,13 @@ AffineExpr AffineExprSimplifier::RewriteFloorDiv(AffineBinaryOpExpr div) { int64_t factor = *multiplier / d; extracted = extracted + GetLhs(expr) * factor; // Remove from dividend. - return true; + return zero; } } // Not a constant multiplier, keep in dividend. - return false; + return expr; }); - // If we removed everything, skip the div. - if (new_dividend == zero) { - return extracted; - } - // The gcd of all multipliers and the dividend. int64_t multiplier_divisor_gcd = d; Interval no_multiplier_range{0, 0}; @@ -318,23 +308,18 @@ AffineExpr AffineExprSimplifier::RewriteFloorDiv(AffineBinaryOpExpr div) { // the result. 
if (no_multiplier_range.lower >= 0 && no_multiplier_range.upper < multiplier_divisor_gcd) { - auto new_new_dividend = zero; - VisitSummands(new_dividend, [&](AffineExpr summand) { + new_dividend = MapSummands(new_dividend, [&](AffineExpr summand) { if (auto mult = GetConstantRhs(summand, AffineExprKind::Mul)) { - new_new_dividend = new_new_dividend + - (GetLhs(summand) * (*mult / multiplier_divisor_gcd)); + return GetLhs(summand) * (*mult / multiplier_divisor_gcd); } + // This has no multiplier and we previously determined it can't affect + // the result of the division. + return zero; }); - return new_new_dividend.floorDiv(d / multiplier_divisor_gcd) + extracted; - } - - // If we removed nothing, return the original division. - if (extracted == getAffineConstantExpr(0, mlir_context) && - new_dividend == div.getLHS()) { - return div; + d /= multiplier_divisor_gcd; } - return extracted + new_dividend.floorDiv(div.getRHS()); + return new_dividend.floorDiv(d) + extracted; } std::optional AffineExprSimplifier::GetConstantRhs( @@ -350,18 +335,18 @@ std::optional AffineExprSimplifier::GetConstantRhs( return bound.lower; } -AffineExpr AffineExprSimplifier::RemoveSummands( - AffineExpr expr, const std::function& pred) { +AffineExpr AffineExprSimplifier::MapSummands( + AffineExpr expr, const std::function& fn) { if (expr.getKind() == AffineExprKind::Add) { auto add = mlir::dyn_cast(expr); - auto lhs = RemoveSummands(add.getLHS(), pred); - auto rhs = RemoveSummands(add.getRHS(), pred); + auto lhs = MapSummands(add.getLHS(), fn); + auto rhs = MapSummands(add.getRHS(), fn); if (lhs == add.getLHS() && rhs == add.getRHS()) { return add; } return lhs + rhs; } - return pred(expr) ? 
mlir::getAffineConstantExpr(0, expr.getContext()) : expr; + return fn(expr); } void AffineExprSimplifier::VisitSummands( From d566f8dea4978792f38585c8cbabdb1e53b47d6c Mon Sep 17 00:00:00 2001 From: Alan Kelly Date: Fri, 21 Jun 2024 05:31:18 -0700 Subject: [PATCH 112/256] Check input channels before delegating PiperOrigin-RevId: 645358861 --- .../delegates/xnnpack/xnnpack_delegate.cc | 39 ++++++++++--------- 1 file changed, 21 insertions(+), 18 deletions(-) diff --git a/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc b/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc index 547b18320d28cc..9b849d271b05d0 100644 --- a/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc +++ b/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc @@ -4213,6 +4213,27 @@ class Subgraph { logging_context, node_index, fc_params->activation, &output_min, &output_max)); + uint32_t dq_quantized_id = XNN_INVALID_VALUE_ID; + size_t num_nonbatch_dims = 0; + int ic = 1; + int input_dims_remaining = NumDimensions(&input_tensor) - 1; + // Which input dimensions are part of input_channels. + if (dynamically_quantized) { + while (ic != input_channels && input_dims_remaining >= 0) { + ic *= input_tensor.dims->data[input_dims_remaining]; + --input_dims_remaining; + ++num_nonbatch_dims; + } + if (ic != input_channels) { + TF_LITE_MAYBE_KERNEL_LOG( + logging_context, + "Could not determine how many input dimensions to use for " + "input_channels: %s node #%d", + EnumNameBuiltinOperator(BuiltinOperator_FULLY_CONNECTED), + node_index); + return kTfLiteError; + } + } if (subgraph != nullptr) { if (dynamically_quantized) { TfLiteAffineQuantization* filter_params = @@ -4230,24 +4251,6 @@ class Subgraph { filter_tensor.params.zero_point; } } - uint32_t dq_quantized_id = XNN_INVALID_VALUE_ID; - size_t num_nonbatch_dims = 0; - int ic = 1; - int input_dims_remaining = NumDimensions(&input_tensor) - 1; - // Which input dimensions are part of input_channels. 
- while (ic != input_channels && input_dims_remaining >= 0) { - ic *= input_tensor.dims->data[input_dims_remaining]; - --input_dims_remaining; - ++num_nonbatch_dims; - } - if (ic != input_channels) { - TF_LITE_KERNEL_LOG( - logging_context, - "Could not determine how many input dimensions to use for " - "input_channels: %s node #%d", - EnumNameBuiltinOperator(BuiltinOperator_FULLY_CONNECTED), - node_index); - } std::vector input_dims( &input_tensor.dims->data[0], &input_tensor.dims->data[NumDimensions(&input_tensor)]); From e73001e2fa44e9e1f2578ac4438c9ae55945161e Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Fri, 21 Jun 2024 05:45:54 -0700 Subject: [PATCH 113/256] PR #13856: Extend FFI DataType with FP8 Types Imported from GitHub PR https://github.com/openxla/xla/pull/13856 Extend FFI DataType with FP8 Types. Copybara import of the project: -- b2e55a63f42de89778376b5b903fafb30f049212 by Phuong Nguyen : added fp8 types into DataType in FFI Signed-off-by: Phuong Nguyen -- a6ab1f9e8b20f9d6536fe05e247eeb12bb3bff6d by Phuong Nguyen : fixed enum values and formatted changes with clang-f Signed-off-by: Phuong Nguyen Merging this change closes #13856 PiperOrigin-RevId: 645361764 --- third_party/xla/xla/ffi/api/api.h | 10 ++++++++++ third_party/xla/xla/ffi/api/c_api.h | 5 +++++ third_party/xla/xla/ffi/api/ffi.h | 5 +++++ third_party/xla/xla/ffi/api/ffi_test.cc | 7 +++++++ third_party/xla/xla/ffi/call_frame.cc | 5 +++++ 5 files changed, 32 insertions(+) diff --git a/third_party/xla/xla/ffi/api/api.h b/third_party/xla/xla/ffi/api/api.h index a10acc92759d5b..c000e68819bcff 100644 --- a/third_party/xla/xla/ffi/api/api.h +++ b/third_party/xla/xla/ffi/api/api.h @@ -129,6 +129,16 @@ inline std::ostream& operator<<(std::ostream& os, return os << "C128"; case XLA_FFI_DataType_TOKEN: return os << "TOKEN"; + case XLA_FFI_DataType_F8E5M2: + return os << "F8E5M2"; + case XLA_FFI_DataType_F8E4M3FN: + return os << "F8E4M3FN"; + case XLA_FFI_DataType_F8E4M3B11FNUZ: + return os << 
"F8E4M3B11FNUZ"; + case XLA_FFI_DataType_F8E5M2FNUZ: + return os << "F8E5M2FNUZ"; + case XLA_FFI_DataType_F8E4M3FNUZ: + return os << "F8E4M3FNUZ"; } } diff --git a/third_party/xla/xla/ffi/api/c_api.h b/third_party/xla/xla/ffi/api/c_api.h index 8989a6c12b7ed9..368a88f2bee44a 100644 --- a/third_party/xla/xla/ffi/api/c_api.h +++ b/third_party/xla/xla/ffi/api/c_api.h @@ -178,6 +178,11 @@ typedef enum { XLA_FFI_DataType_C64 = 15, XLA_FFI_DataType_C128 = 18, XLA_FFI_DataType_TOKEN = 17, + XLA_FFI_DataType_F8E5M2 = 19, + XLA_FFI_DataType_F8E4M3FN = 20, + XLA_FFI_DataType_F8E4M3B11FNUZ = 23, + XLA_FFI_DataType_F8E5M2FNUZ = 24, + XLA_FFI_DataType_F8E4M3FNUZ = 25, } XLA_FFI_DataType; // LINT.ThenChange(ffi_test.cc) diff --git a/third_party/xla/xla/ffi/api/ffi.h b/third_party/xla/xla/ffi/api/ffi.h index 19897667e48489..edd0e2087a3475 100644 --- a/third_party/xla/xla/ffi/api/ffi.h +++ b/third_party/xla/xla/ffi/api/ffi.h @@ -60,6 +60,11 @@ enum class DataType : uint8_t { C64 = XLA_FFI_DataType_C64, C128 = XLA_FFI_DataType_C128, TOKEN = XLA_FFI_DataType_TOKEN, + F8E5M2 = XLA_FFI_DataType_F8E5M2, + F8E4M3FN = XLA_FFI_DataType_F8E4M3FN, + F8E4M3B11FNUZ = XLA_FFI_DataType_F8E4M3B11FNUZ, + F8E5M2FNUZ = XLA_FFI_DataType_F8E5M2FNUZ, + F8E4M3FNUZ = XLA_FFI_DataType_F8E4M3FNUZ, }; inline std::ostream& operator<<(std::ostream& os, const DataType dtype) { diff --git a/third_party/xla/xla/ffi/api/ffi_test.cc b/third_party/xla/xla/ffi/api/ffi_test.cc index b6e21a4a689c14..df234501637b8a 100644 --- a/third_party/xla/xla/ffi/api/ffi_test.cc +++ b/third_party/xla/xla/ffi/api/ffi_test.cc @@ -91,6 +91,13 @@ TEST(FfiTest, DataTypeEnumValue) { EXPECT_EQ(encoded(PrimitiveType::C128), encoded(DataType::C128)); EXPECT_EQ(encoded(PrimitiveType::TOKEN), encoded(DataType::TOKEN)); + + EXPECT_EQ(encoded(PrimitiveType::F8E5M2), encoded(DataType::F8E5M2)); + EXPECT_EQ(encoded(PrimitiveType::F8E4M3FN), encoded(DataType::F8E4M3FN)); + EXPECT_EQ(encoded(PrimitiveType::F8E4M3B11FNUZ), + 
encoded(DataType::F8E4M3B11FNUZ)); + EXPECT_EQ(encoded(PrimitiveType::F8E5M2FNUZ), encoded(DataType::F8E5M2FNUZ)); + EXPECT_EQ(encoded(PrimitiveType::F8E4M3FNUZ), encoded(DataType::F8E4M3FNUZ)); } TEST(FfiTest, ErrorEnumValue) { diff --git a/third_party/xla/xla/ffi/call_frame.cc b/third_party/xla/xla/ffi/call_frame.cc index bd991f4e584700..35c945b2bbc583 100644 --- a/third_party/xla/xla/ffi/call_frame.cc +++ b/third_party/xla/xla/ffi/call_frame.cc @@ -256,6 +256,11 @@ static XLA_FFI_DataType ToDataType(PrimitiveType primitive_type) { case PrimitiveType::C64: case PrimitiveType::C128: case PrimitiveType::TOKEN: + case PrimitiveType::F8E5M2: + case PrimitiveType::F8E4M3FN: + case PrimitiveType::F8E4M3B11FNUZ: + case PrimitiveType::F8E5M2FNUZ: + case PrimitiveType::F8E4M3FNUZ: return static_cast(primitive_type); default: DCHECK(false) << "Unsupported primitive type " From f11f5e370f1022ec4b59deb2ca2c34ed0504ed91 Mon Sep 17 00:00:00 2001 From: Oleg Shyshkov Date: Fri, 21 Jun 2024 06:02:00 -0700 Subject: [PATCH 114/256] [XLA:GPU] Fix integer overflow issues in Cost Model and Symbolic Tile Analysis. 
PiperOrigin-RevId: 645364990 --- third_party/xla/xla/service/gpu/model/BUILD | 1 + .../gpu_indexing_performance_model_test.cc | 42 +++++++++++++++++++ .../gpu/model/gpu_performance_model_base.cc | 9 ++-- .../gpu/model/gpu_performance_model_base.h | 2 +- .../gpu/model/symbolic_tile_analysis.cc | 4 +- .../gpu/model/symbolic_tile_analysis_test.cc | 36 ++++++++++++++++ 6 files changed, 87 insertions(+), 7 deletions(-) diff --git a/third_party/xla/xla/service/gpu/model/BUILD b/third_party/xla/xla/service/gpu/model/BUILD index a35308c7f0d81b..8aaa9b9e3c327b 100644 --- a/third_party/xla/xla/service/gpu/model/BUILD +++ b/third_party/xla/xla/service/gpu/model/BUILD @@ -388,6 +388,7 @@ xla_cc_test( "//xla/service/gpu:backend_configs_cc", "//xla/service/gpu:gpu_device_info_for_tests", "//xla/service/gpu:hlo_traversal", + "//xla/service/gpu:launch_dimensions", "//xla/stream_executor:device_description", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", diff --git a/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model_test.cc b/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model_test.cc index a58e23d2f57a8d..3d03f8915892ab 100644 --- a/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model_test.cc +++ b/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model_test.cc @@ -29,6 +29,7 @@ limitations under the License. #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/gpu_device_info_for_tests.h" #include "xla/service/gpu/hlo_traversal.h" +#include "xla/service/gpu/launch_dimensions.h" #include "xla/service/gpu/model/fusion_analysis_cache.h" #include "xla/service/gpu/model/gpu_hlo_cost_analysis.h" #include "xla/service/gpu/model/gpu_performance_model_base.h" @@ -335,6 +336,47 @@ ENTRY main { 1); } +// This test means to catch integer overflow errors when run with ASan build. +// The checks below are just sanity checks for values. 
+TEST_F( + GpuIndexingPerformanceModelTest, + EstimateRunTimeForTiledFusion_NumberOfTilesLargerThanInt32Max_IsSupported) { + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"( +HloModule softmax + +max_computation { + arg_0 = f16[] parameter(0) + arg_1 = f16[] parameter(1) + ROOT maximum = f16[] maximum(arg_0, arg_1) +} + +softmax { + param_0 = f16[65538,32768]{1,0} parameter(0) + constant_neg_inf = f16[] constant(-inf) + reduce = f16[65538]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=max_computation + broadcast = f16[65538,32768]{1,0} broadcast(reduce), dimensions={0} + ROOT subtract = f16[65538,32768]{1,0} subtract(param_0, broadcast) +} + +ENTRY main { + param_0 = f16[65538,32768]{1,0} parameter(0) + ROOT fusion = f16[65538,32768]{1,0} fusion(param_0), kind=kCustom, calls=softmax +} +)")); + auto fusion_adaptor = HloFusionAdaptor::ForInstruction( + module->entry_computation()->root_instruction()); + + LaunchDimensions launch_dimensions{65538LL * 32768LL, 32}; + TF_ASSERT_OK_AND_ASSIGN( + auto runtime_data, + indexing_cost_model_.EstimateRunTimeForTiledFusion( + *fusion_adaptor, launch_dimensions, /*output_tile_sizes=*/{1, 1})); + + EXPECT_NEAR(absl::ToDoubleSeconds(runtime_data.read_time), 183, 1); + EXPECT_NEAR(absl::ToDoubleSeconds(runtime_data.compute_time), 39, 1); + EXPECT_NEAR(absl::ToDoubleSeconds(runtime_data.exec_time), 185, 1); +} + } // namespace } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/model/gpu_performance_model_base.cc b/third_party/xla/xla/service/gpu/model/gpu_performance_model_base.cc index 2d0fc0ab3d737d..56e34c30b963dc 100644 --- a/third_party/xla/xla/service/gpu/model/gpu_performance_model_base.cc +++ b/third_party/xla/xla/service/gpu/model/gpu_performance_model_base.cc @@ -403,12 +403,13 @@ absl::Duration GpuPerformanceModelBase::WriteTime( /*static*/ absl::Duration GpuPerformanceModelBase::ComputeTime( - const se::DeviceDescription& gpu_device_info, int64_t 
flops, int num_blocks, - int num_threads_per_block) { + const se::DeviceDescription& gpu_device_info, int64_t flops, + int64_t num_blocks, int64_t num_threads_per_block) { int64_t n_active_fpus_per_core = - std::min(num_threads_per_block, gpu_device_info.fpus_per_core()); + std::min(num_threads_per_block, gpu_device_info.fpus_per_core()); - int64_t n_active_core = std::min(num_blocks, gpu_device_info.core_count()); + int64_t n_active_core = + std::min(num_blocks, gpu_device_info.core_count()); int64_t fpu_count = n_active_core * n_active_fpus_per_core; int64_t flop_per_ns_per_fpu = gpu_device_info.clock_rate_ghz() * /*fma:*/ 2; diff --git a/third_party/xla/xla/service/gpu/model/gpu_performance_model_base.h b/third_party/xla/xla/service/gpu/model/gpu_performance_model_base.h index bd33ce14b4a186..103c6a9f9794dd 100644 --- a/third_party/xla/xla/service/gpu/model/gpu_performance_model_base.h +++ b/third_party/xla/xla/service/gpu/model/gpu_performance_model_base.h @@ -215,7 +215,7 @@ class GpuPerformanceModelBase { static absl::Duration ComputeTime( const se::DeviceDescription& gpu_device_info, int64_t flops, - int num_blocks, int num_threads_per_block); + int64_t num_blocks, int64_t num_threads_per_block); static absl::Duration CombineComputeAndMemoryAccessTime( absl::Duration compute_time, absl::Duration memory_access_time, diff --git a/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis.cc b/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis.cc index 0f22883914268b..d22917cba3f2de 100644 --- a/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis.cc +++ b/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis.cc @@ -78,11 +78,11 @@ IndexingMap ComputeBlockIdToOutputTileIndexing( mlir::MLIRContext* mlir_context) { CHECK_EQ(dimensions.size(), tile_sizes.size()); // Crash OK - int num_tiles = 1; + int64_t num_tiles = 1; std::vector outer_loop_bounds; outer_loop_bounds.reserve(dimensions.size()); for (auto [dim_size, tile_size] : 
llvm::zip(dimensions, tile_sizes)) { - int num_tiles_per_dim = (dim_size + tile_size - 1) / tile_size; + int64_t num_tiles_per_dim = (dim_size + tile_size - 1) / tile_size; num_tiles *= num_tiles_per_dim; outer_loop_bounds.push_back(num_tiles_per_dim); diff --git a/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis_test.cc b/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis_test.cc index 279ada6aaa6267..7a420ba9542298 100644 --- a/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis_test.cc +++ b/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis_test.cc @@ -752,6 +752,42 @@ ENTRY entry_computation { LogTilingsIfVlog1(good_tilings); } +// This test means to catch integer overflow errors when run with ASan build. +TEST_F(SymbolicTileAnalysisTest, + FusionWithNumberOfTilesLargerThanInt32MaxIsSupported) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( +HloModule softmax + +fused_computation { + param_0 = f16[65538,32768]{1,0} parameter(0) + ROOT log = f16[65538,32768]{1,0} log(param_0) +} + +ENTRY main { + param_0 = f16[65538,32768]{1,0} parameter(0) + ROOT fusion = f16[65538,32768]{1,0} fusion(param_0), kind=kLoop, calls=fused_computation +} +)")); + + std::optional analysis = TryAnalyzeModule(module.get()); + ASSERT_TRUE(analysis.has_value()); + + TF_ASSERT_OK_AND_ASSIGN( + TiledHloComputation tiled_hlo_computation, + analysis->ComputeTiledHloInstructions(/*tile_parameters=*/{1, 1})); + + EXPECT_THAT(*tiled_hlo_computation.GetRoot(), + MatchTiledHloInstruction( + /*tile_sizes=*/{1, 1}, + /*tile_strides=*/{1, 1}, + /*block_id_to_tile_offsets_indexing=*/R"( + (d0) -> (d0 floordiv 32768, d0 mod 32768) + domain: + d0 in [0, 2147549183] + )")); +} + } // namespace } // namespace gpu } // namespace xla From 6dda17ce00d3042e1e2a6c2be252e06bbc3b3846 Mon Sep 17 00:00:00 2001 From: Shanbin Ke Date: Fri, 21 Jun 2024 07:10:45 -0700 Subject: [PATCH 115/256] PR #13757: [XLA:GPU] Upgrade cuDNN frontend to 
1.5 Imported from GitHub PR https://github.com/openxla/xla/pull/13757 release note: https://github.com/NVIDIA/cudnn-frontend/releases/tag/v1.5.0 Copybara import of the project: -- 32dfd425085375e1debed05b8a1b83862140bd2e by Cjkkkk : init -- 9f096b6b8261068d2a41c63209d6fc8882066e10 by Cjkkkk : fix some header include and use 1.5.1 Merging this change closes #13757 PiperOrigin-RevId: 645379774 --- tensorflow/workspace2.bzl | 6 +-- third_party/cudnn_frontend_header_fix.patch | 41 ++++++++++++++++++- .../cudnn_frontend_header_fix.patch | 41 ++++++++++++++++++- third_party/xla/workspace2.bzl | 6 +-- third_party/xla/xla/tsl/cuda/cudnn.symbols | 1 + 5 files changed, 87 insertions(+), 8 deletions(-) diff --git a/tensorflow/workspace2.bzl b/tensorflow/workspace2.bzl index cd04eaba8ed78c..08d01a069bb30b 100644 --- a/tensorflow/workspace2.bzl +++ b/tensorflow/workspace2.bzl @@ -181,9 +181,9 @@ def _tf_repositories(): name = "cudnn_frontend_archive", build_file = "//third_party:cudnn_frontend.BUILD", patch_file = ["//third_party:cudnn_frontend_header_fix.patch"], - sha256 = "5727ed189a17fe888f1729ba09b2afd8df3e71192a27e9fa87e14a60f7b9d367", - strip_prefix = "cudnn-frontend-1.3.0", - urls = tf_mirror_urls("https://github.com/NVIDIA/cudnn-frontend/archive/refs/tags/v1.3.0.zip"), + sha256 = "281789777ac296f5f8215a7c4bd066de8816d240eb44c760788beebf8d25a99f", + strip_prefix = "cudnn-frontend-1.5.1", + urls = tf_mirror_urls("https://github.com/NVIDIA/cudnn-frontend/archive/refs/tags/v1.5.1.zip"), ) tf_http_archive( diff --git a/third_party/cudnn_frontend_header_fix.patch b/third_party/cudnn_frontend_header_fix.patch index 70476bd3ff5d56..6a57589b53e6a2 100644 --- a/third_party/cudnn_frontend_header_fix.patch +++ b/third_party/cudnn_frontend_header_fix.patch @@ -1,5 +1,5 @@ diff --git a/include/cudnn_frontend.h b/include/cudnn_frontend.h -index 0f0d5a6..802bcbb 100644 +index 7bca425..cd74e3a 100644 --- a/include/cudnn_frontend.h +++ b/include/cudnn_frontend.h @@ -97,7 +97,7 @@ @@ 
-11,3 +11,42 @@ index 0f0d5a6..802bcbb 100644 #include "cudnn_frontend_ConvDesc.h" #include "cudnn_frontend_Heuristics.h" +diff --git a/include/cudnn_frontend/backend/backend_descriptor.h b/include/cudnn_frontend/backend/backend_descriptor.h +index dc7ad25..ade164a 100644 +--- a/include/cudnn_frontend/backend/backend_descriptor.h ++++ b/include/cudnn_frontend/backend/backend_descriptor.h +@@ -2,7 +2,7 @@ + + #include + +-#include "cudnn.h" ++#include "third_party/gpus/cudnn/cudnn.h" + + namespace cudnn_frontend::detail { + +diff --git a/include/cudnn_frontend/backend/execution_helpers.h b/include/cudnn_frontend/backend/execution_helpers.h +index 334ffde..d2ca694 100644 +--- a/include/cudnn_frontend/backend/execution_helpers.h ++++ b/include/cudnn_frontend/backend/execution_helpers.h +@@ -2,7 +2,7 @@ + + #include + +-#include "cudnn.h" ++#include "third_party/gpus/cudnn/cudnn.h" + + #include "backend_descriptor.h" + +diff --git a/include/cudnn_frontend/backend/plan_helpers.h b/include/cudnn_frontend/backend/plan_helpers.h +index 1fa458d..8c37d10 100644 +--- a/include/cudnn_frontend/backend/plan_helpers.h ++++ b/include/cudnn_frontend/backend/plan_helpers.h +@@ -2,7 +2,7 @@ + + #include + +-#include "cudnn.h" ++#include "third_party/gpus/cudnn/cudnn.h" + + #include "backend_descriptor.h" + diff --git a/third_party/xla/third_party/cudnn_frontend_header_fix.patch b/third_party/xla/third_party/cudnn_frontend_header_fix.patch index 70476bd3ff5d56..6a57589b53e6a2 100644 --- a/third_party/xla/third_party/cudnn_frontend_header_fix.patch +++ b/third_party/xla/third_party/cudnn_frontend_header_fix.patch @@ -1,5 +1,5 @@ diff --git a/include/cudnn_frontend.h b/include/cudnn_frontend.h -index 0f0d5a6..802bcbb 100644 +index 7bca425..cd74e3a 100644 --- a/include/cudnn_frontend.h +++ b/include/cudnn_frontend.h @@ -97,7 +97,7 @@ @@ -11,3 +11,42 @@ index 0f0d5a6..802bcbb 100644 #include "cudnn_frontend_ConvDesc.h" #include "cudnn_frontend_Heuristics.h" +diff --git 
a/include/cudnn_frontend/backend/backend_descriptor.h b/include/cudnn_frontend/backend/backend_descriptor.h +index dc7ad25..ade164a 100644 +--- a/include/cudnn_frontend/backend/backend_descriptor.h ++++ b/include/cudnn_frontend/backend/backend_descriptor.h +@@ -2,7 +2,7 @@ + + #include + +-#include "cudnn.h" ++#include "third_party/gpus/cudnn/cudnn.h" + + namespace cudnn_frontend::detail { + +diff --git a/include/cudnn_frontend/backend/execution_helpers.h b/include/cudnn_frontend/backend/execution_helpers.h +index 334ffde..d2ca694 100644 +--- a/include/cudnn_frontend/backend/execution_helpers.h ++++ b/include/cudnn_frontend/backend/execution_helpers.h +@@ -2,7 +2,7 @@ + + #include + +-#include "cudnn.h" ++#include "third_party/gpus/cudnn/cudnn.h" + + #include "backend_descriptor.h" + +diff --git a/include/cudnn_frontend/backend/plan_helpers.h b/include/cudnn_frontend/backend/plan_helpers.h +index 1fa458d..8c37d10 100644 +--- a/include/cudnn_frontend/backend/plan_helpers.h ++++ b/include/cudnn_frontend/backend/plan_helpers.h +@@ -2,7 +2,7 @@ + + #include + +-#include "cudnn.h" ++#include "third_party/gpus/cudnn/cudnn.h" + + #include "backend_descriptor.h" + diff --git a/third_party/xla/workspace2.bzl b/third_party/xla/workspace2.bzl index 4044a9f10263bd..3d00bf1c1c4a1b 100644 --- a/third_party/xla/workspace2.bzl +++ b/third_party/xla/workspace2.bzl @@ -48,9 +48,9 @@ def _tf_repositories(): name = "cudnn_frontend_archive", build_file = "//third_party:cudnn_frontend.BUILD", patch_file = ["//third_party:cudnn_frontend_header_fix.patch"], - sha256 = "5727ed189a17fe888f1729ba09b2afd8df3e71192a27e9fa87e14a60f7b9d367", - strip_prefix = "cudnn-frontend-1.3.0", - urls = tf_mirror_urls("https://github.com/NVIDIA/cudnn-frontend/archive/refs/tags/v1.3.0.zip"), + sha256 = "281789777ac296f5f8215a7c4bd066de8816d240eb44c760788beebf8d25a99f", + strip_prefix = "cudnn-frontend-1.5.1", + urls = tf_mirror_urls("https://github.com/NVIDIA/cudnn-frontend/archive/refs/tags/v1.5.1.zip"), ) 
tf_http_archive( diff --git a/third_party/xla/xla/tsl/cuda/cudnn.symbols b/third_party/xla/xla/tsl/cuda/cudnn.symbols index 95c46295e1dcbb..7c97a717a9f614 100644 --- a/third_party/xla/xla/tsl/cuda/cudnn.symbols +++ b/third_party/xla/xla/tsl/cuda/cudnn.symbols @@ -270,3 +270,4 @@ cudnnSpatialTfSamplerForward cudnnTransformFilter cudnnTransformTensor cudnnTransformTensorEx +cudnnGetLastErrorString From a8c6cbd52060a8b21fcaba9ecd224a6d6e5d89da Mon Sep 17 00:00:00 2001 From: Blake Hechtman Date: Fri, 21 Jun 2024 08:32:37 -0700 Subject: [PATCH 116/256] [XLA] Do not validate the operand layout constraint of LayoutConstraint custom calls since that shape may have been sharded. PiperOrigin-RevId: 645398598 --- third_party/xla/xla/service/hlo_verifier.cc | 3 ++- third_party/xla/xla/service/layout_assignment.cc | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/third_party/xla/xla/service/hlo_verifier.cc b/third_party/xla/xla/service/hlo_verifier.cc index b0023794dae0c5..c29f2a829a258a 100644 --- a/third_party/xla/xla/service/hlo_verifier.cc +++ b/third_party/xla/xla/service/hlo_verifier.cc @@ -1320,7 +1320,8 @@ absl::Status ShapeVerifier::HandleCustomCall(HloInstruction* instruction) { const HloCustomCallInstruction* custom_call = DynCast(instruction); TF_RET_CHECK(custom_call != nullptr); - if (custom_call->layout_constrained()) { + if (custom_call->layout_constrained() && + !custom_call->IsCustomCall("LayoutConstraint")) { // If the layout is constrained, verify all the respective shapes have // layouts and that the constrained operand shapes match the shapes of the // operands. 
diff --git a/third_party/xla/xla/service/layout_assignment.cc b/third_party/xla/xla/service/layout_assignment.cc index ba99b6772b1968..688af31615c710 100644 --- a/third_party/xla/xla/service/layout_assignment.cc +++ b/third_party/xla/xla/service/layout_assignment.cc @@ -969,7 +969,8 @@ bool LayoutsInShapesEqual(const Shape& lhs, const Shape& rhs) { // Operands of layout-constrained custom calls must match the expected // constrained layouts. absl::Status CheckCustomCallLayout(HloInstruction* instruction) { - if (IsLayoutConstrainedCustomCall(instruction)) { + if (IsLayoutConstrainedCustomCall(instruction) && + !instruction->IsCustomCall("LayoutConstraint")) { const HloCustomCallInstruction* custom_call = DynCast(instruction); for (int64_t i = 0; i < custom_call->operand_count(); ++i) { From 32f6b488f0ab9cc5a7ee39d8f339ceda3cce1d12 Mon Sep 17 00:00:00 2001 From: Kuy Mainwaring Date: Fri, 21 Jun 2024 08:43:44 -0700 Subject: [PATCH 117/256] [XLA:GPU] Clang-tidy cleanup for xla/service/ar_crs_combiner_test.cc PiperOrigin-RevId: 645401424 --- third_party/xla/xla/service/BUILD | 5 ++++- third_party/xla/xla/service/ar_crs_combiner_test.cc | 10 +++++++++- 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/third_party/xla/xla/service/BUILD b/third_party/xla/xla/service/BUILD index 26fd0b20551252..760d50fbba592d 100644 --- a/third_party/xla/xla/service/BUILD +++ b/third_party/xla/xla/service/BUILD @@ -6892,11 +6892,14 @@ xla_cc_test( srcs = ["ar_crs_combiner_test.cc"], deps = [ ":ar_crs_combiner", + "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", "@com_google_absl//absl/status:statusor", - "@local_tsl//tsl/lib/core:status_test_util", + "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:statusor", ], ) diff --git a/third_party/xla/xla/service/ar_crs_combiner_test.cc b/third_party/xla/xla/service/ar_crs_combiner_test.cc index 
3e80b3d6de3e6a..e18d81d20fa93e 100644 --- a/third_party/xla/xla/service/ar_crs_combiner_test.cc +++ b/third_party/xla/xla/service/ar_crs_combiner_test.cc @@ -15,10 +15,18 @@ limitations under the License. #include "xla/service/ar_crs_combiner.h" +#include +#include +#include + +#include +#include #include "absl/status/statusor.h" +#include "absl/types/span.h" +#include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/utils/hlo_matchers.h" #include "xla/tests/hlo_test_base.h" -#include "tsl/lib/core/status_test_util.h" +#include "tsl/platform/statusor.h" namespace xla { namespace { From 5ad340221f3f1f4795b64c6468a48bf47260beaa Mon Sep 17 00:00:00 2001 From: Kuy Mainwaring Date: Fri, 21 Jun 2024 08:45:29 -0700 Subject: [PATCH 118/256] [XLA:GPU] Clang-tidy cleanup for xla/service/bfloat16_conversion_folding.cc PiperOrigin-RevId: 645401787 --- third_party/xla/xla/service/BUILD | 10 ++++++++-- .../xla/xla/service/bfloat16_conversion_folding.cc | 13 +++++++++++-- .../xla/service/bfloat16_conversion_folding_test.cc | 6 ++++++ 3 files changed, 25 insertions(+), 4 deletions(-) diff --git a/third_party/xla/xla/service/BUILD b/third_party/xla/xla/service/BUILD index 760d50fbba592d..3cc8fe9e299a0d 100644 --- a/third_party/xla/xla/service/BUILD +++ b/third_party/xla/xla/service/BUILD @@ -336,11 +336,16 @@ cc_library( ":float_support", ":hlo_dataflow_analysis", ":hlo_pass", - "//xla:status_macros", + "//xla:shape_util", + "//xla:util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", - "@com_google_absl//absl/types:span", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:status", ], ) @@ -357,6 +362,7 @@ xla_cc_test( "//xla/hlo/ir:hlo", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/status:statusor", ], ) diff --git 
a/third_party/xla/xla/service/bfloat16_conversion_folding.cc b/third_party/xla/xla/service/bfloat16_conversion_folding.cc index fa7bd08e5338d4..ee6af27bf8dac1 100644 --- a/third_party/xla/xla/service/bfloat16_conversion_folding.cc +++ b/third_party/xla/xla/service/bfloat16_conversion_folding.cc @@ -15,15 +15,24 @@ limitations under the License. #include "xla/service/bfloat16_conversion_folding.h" -#include "absl/types/span.h" +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "absl/strings/string_view.h" #include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/service/float_support.h" #include "xla/service/hlo_dataflow_analysis.h" -#include "xla/status_macros.h" +#include "xla/shape_util.h" +#include "xla/util.h" #include "xla/xla_data.pb.h" +#include "tsl/platform/errors.h" #include "tsl/platform/logging.h" +#include "tsl/platform/status.h" namespace xla { diff --git a/third_party/xla/xla/service/bfloat16_conversion_folding_test.cc b/third_party/xla/xla/service/bfloat16_conversion_folding_test.cc index dad6743c3c40ef..cd02490a843b05 100644 --- a/third_party/xla/xla/service/bfloat16_conversion_folding_test.cc +++ b/third_party/xla/xla/service/bfloat16_conversion_folding_test.cc @@ -15,11 +15,17 @@ limitations under the License. 
#include "xla/service/bfloat16_conversion_folding.h" +#include +#include + +#include "absl/status/statusor.h" +#include "xla/hlo/ir/collective_device_list.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/float_support.h" +#include "xla/shape.h" #include "xla/shape_util.h" #include "xla/test.h" #include "xla/test_helpers.h" From 9f0e241c5ce342fd80ea33a32b4d2cadec7d178f Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Fri, 21 Jun 2024 08:46:20 -0700 Subject: [PATCH 119/256] [XLA:GPU] Parametrize Triton support tests by data type and device type. Now that the tests do not require hardware to execute, it's much easier (and cheaper!) to test different hardware platforms. A crucial change here is the change to calling `TritonWrapper` directly as opposed to `CreateTritonIrAndFileCheck`. Unlike the latter, `TritonWrapper` performs lowering all the way down to LLVM IR, which allows us to notice improperly gated Triton lowering crashes that were previously completely hidden. The new crashes are commented out pending investigation. Also drive-by delete the BF16 filters at the start of the test cases. `IsTritonSupportedInstruction` takes in a compute capability as a parameter and should return `false` itself if a data type is not supported. 
PiperOrigin-RevId: 645401955 --- third_party/xla/xla/service/gpu/BUILD | 5 +- .../xla/xla/service/gpu/triton_support.cc | 2 - .../xla/service/gpu/triton_support_test.cc | 448 ++++++++---------- .../xla/xla/service/gpu/triton_test_utils.cc | 34 +- .../xla/xla/service/gpu/triton_test_utils.h | 3 + 5 files changed, 232 insertions(+), 260 deletions(-) diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index ef8e4f2fc60d87..203468c5175ed8 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -1260,23 +1260,22 @@ cc_library( xla_cc_test( name = "triton_support_test", srcs = ["triton_support_test.cc"], + shard_count = 20, deps = [ ":gpu_device_info_for_tests", ":ir_emitter_triton", ":triton_support", ":triton_test_utils", + "//xla:shape_util", "//xla:xla_data_proto_cc", "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", "//xla/service/gpu/model:tiled_hlo_computation", "//xla/stream_executor:device_description", "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:protobuf", - "@local_tsl//tsl/platform:status_matchers", "@local_tsl//tsl/platform:statusor", ], ) diff --git a/third_party/xla/xla/service/gpu/triton_support.cc b/third_party/xla/xla/service/gpu/triton_support.cc index 7d605a7aa62879..6b38d914b9c0e1 100644 --- a/third_party/xla/xla/service/gpu/triton_support.cc +++ b/third_party/xla/xla/service/gpu/triton_support.cc @@ -473,8 +473,6 @@ absl::flat_hash_set TritonSupportedBinaryElementwiseOps( absl::flat_hash_set additional_opcodes{ HloOpcode::kAtan2, HloOpcode::kDivide, HloOpcode::kPower}; ret.insert(additional_opcodes.begin(), additional_opcodes.end()); - } else if (element_type == PrimitiveType::BF16) { - ret.insert(HloOpcode::kDivide); } return ret; } diff --git 
a/third_party/xla/xla/service/gpu/triton_support_test.cc b/third_party/xla/xla/service/gpu/triton_support_test.cc index eeb9c617d34b6b..6ee50a9ce28c7e 100644 --- a/third_party/xla/xla/service/gpu/triton_support_test.cc +++ b/third_party/xla/xla/service/gpu/triton_support_test.cc @@ -18,18 +18,19 @@ limitations under the License. #include "xla/service/gpu/triton_support.h" #include -#include #include #include #include +#include #include #include #include #include "absl/algorithm/container.h" -#include "absl/status/status.h" #include "absl/strings/string_view.h" +#include "absl/strings/substitute.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/primitive_util.h" #include "xla/service/gpu/gpu_device_info_for_tests.h" #include "xla/service/gpu/ir_emitter_triton.h" #include "xla/service/gpu/model/tiled_hlo_computation.h" @@ -37,9 +38,7 @@ limitations under the License. #include "xla/stream_executor/device_description.h" #include "xla/xla.pb.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/protobuf.h" -#include "tsl/platform/status_matchers.h" #include "tsl/platform/statusor.h" namespace xla { @@ -71,11 +70,31 @@ auto AllXlaDataTypes() { return ::testing::ValuesIn(xla_data_types); } +auto AllDevicesToTest() { + using cc = se::GpuComputeCapability; +#ifdef TENSORFLOW_USE_ROCM + se::RocmComputeCapability example_rocm_compute_capability = + TestGpuDeviceInfo::AMDMI210DeviceInfo().rocm_compute_capability(); + return ::testing::Values(cc(example_rocm_compute_capability)); +#else // GOOGLE_CUDA + return ::testing::Values(cc(se::CudaComputeCapability::Ampere()), + cc(se::CudaComputeCapability::Hopper())); +#endif +} + +// Generates all the possible test combinations for a given opcodes. A test +// combination is a tuple of the form (data_type, opcode, compute_capability). 
+auto AllTestCombinationsForOpcodes(std::vector&& opcodes) { + return ::testing::Combine(AllXlaDataTypes(), ::testing::ValuesIn(opcodes), + AllDevicesToTest()); +} + class TritonSupportTest : public TritonSupportTestBase { public: - // Runs a support test for the given `TestedInstruction`. The support test - // verifies that `IsTritonSupportedInstruction` is in sync with the - // implemented Triton emitter, i.e., given an instruction `instr`, either + // Runs a support test for the given `TestedInstruction` and the given + // compute capability. The support test verifies that + // `IsTritonSupportedInstruction` is in sync with the implemented Triton + // emitter, i.e., given an instruction `instr`, either // - `IsTritonSupportedInstruction(instr)` => Triton lowering is OK // - `!IsTritonSupportedInstruction(instr)` => Triton lowering is not OK. // @@ -89,36 +108,39 @@ class TritonSupportTest : public TritonSupportTestBase { // lowering test when `IsTritonSupportedInstruction` returns `false`. void RunSupportTest(TestedInstruction ti, std::vector output_tile_sizes, + se::GpuComputeCapability cc, bool skip_failure_branch_to_avoid_crash = false) { BlockLevelParameters block_level_parameters = FromOutputTileSizes(std::move(output_tile_sizes)); - if (IsTritonSupportedInstruction(ti.Instruction(), - GetComputeCapability())) { - TF_EXPECT_OK(CreateTritonIrAndFileCheck(ti.TritonComputation(), - block_level_parameters, - "CHECK: tt.func @triton_fn")); + const se::DeviceDescription dev_info = + std::holds_alternative(cc) + ? 
TestGpuDeviceInfo::RTXA6000DeviceInfo(cc) + : TestGpuDeviceInfo::AMDMI210DeviceInfo(); + if (IsTritonSupportedInstruction(ti.Instruction(), cc)) { + EXPECT_THAT( + TritonWrapper("test_fn", &ti.TritonFusion(), cc, dev_info, + block_level_parameters, &llvm_module_, mlir_context_), + IsOk()); } else { if (!skip_failure_branch_to_avoid_crash) { - const se::DeviceDescription dev_info = - TestGpuDeviceInfo::RTXA6000DeviceInfo(GetComputeCapability()); EXPECT_THAT( - TritonWrapper("test_fn", &ti.TritonFusion(), GetComputeCapability(), - dev_info, block_level_parameters, &llvm_module_, - mlir_context_), + TritonWrapper("test_fn", &ti.TritonFusion(), cc, dev_info, + block_level_parameters, &llvm_module_, mlir_context_), Not(IsOk())); } } } }; -class TritonSupportTestWithParam : public TritonSupportTest, - public ::testing::WithParamInterface< - std::tuple> {}; +class TritonSupportTestWithParam + : public TritonSupportTest, + public ::testing::WithParamInterface< + std::tuple> {}; using BitcastOrReshapeTest = TritonSupportTestWithParam; TEST_P(BitcastOrReshapeTest, IsTritonSupportedBitcastOrReshape) { - auto [data_type, opcode] = GetParam(); + auto [data_type, opcode, cc] = GetParam(); const std::string kHloTestTemplate = R"( ENTRY triton_computation { parameter_0 = $0[1,16,4]{2,1,0} parameter(0) @@ -127,26 +149,20 @@ ENTRY triton_computation { TF_ASSERT_OK_AND_ASSIGN( TestedInstruction ti, ParseTemplateAndGetInstruction(kHloTestTemplate, data_type, opcode)); - RunSupportTest(std::move(ti), /*output_tile_sizes=*/{16}); + RunSupportTest(std::move(ti), /*output_tile_sizes=*/{16}, cc); } -INSTANTIATE_TEST_SUITE_P( - BitcastOrReshapeTestSuite, BitcastOrReshapeTest, - ::testing::Combine(AllXlaDataTypes(), - ::testing::Values(HloOpcode::kBitcast, - HloOpcode::kReshape)), - TritonSupportTestParamsToString); +INSTANTIATE_TEST_SUITE_P(BitcastOrReshapeTestSuite, BitcastOrReshapeTest, + AllTestCombinationsForOpcodes({HloOpcode::kBitcast, + HloOpcode::kReshape}), + 
TritonSupportTestTypeOpcodeAndDeviceToString); using UnaryElementwiseTest = TritonSupportTestWithParam; // TODO(b/331636835): updates elementwise op tests to directly emit single op // instead of relying on triton gemm kernel. TEST_P(UnaryElementwiseTest, IsTritonSupportedUnaryElementwise) { - auto [data_type, opcode] = GetParam(); - if (data_type == BF16 && !SupportsBF16(GetComputeCapability())) { - GTEST_SKIP(); - } - + auto [data_type, opcode, cc] = GetParam(); const std::string kHloTestTemplate = R"( ENTRY triton_computation { parameter_0 = $0[33,68]{1,0} parameter(0) @@ -156,20 +172,24 @@ ENTRY triton_computation { TF_ASSERT_OK_AND_ASSIGN( TestedInstruction ti, ParseTemplateAndGetInstruction(kHloTestTemplate, data_type, opcode)); - RunSupportTest(std::move(ti), /*output_tile_sizes=*/{1, 32}); + RunSupportTest(std::move(ti), /*output_tile_sizes=*/{1, 32}, cc); } +// TODO(b/345763510): make sure to test all the data types for the unary, +// binary, and ternary elementwise ops. INSTANTIATE_TEST_SUITE_P( UnaryElementwiseTestSuite, UnaryElementwiseTest, ::testing::Combine(::testing::Values(S8, S16, S32, F16, F32, BF16), ::testing::Values(HloOpcode::kConvert, HloOpcode::kAbs, - HloOpcode::kNegate)), - TritonSupportTestParamsToString); + HloOpcode::kNegate), + AllDevicesToTest()), + TritonSupportTestTypeOpcodeAndDeviceToString); INSTANTIATE_TEST_SUITE_P( UnaryPREDTestSuite, UnaryElementwiseTest, ::testing::Combine(::testing::Values(PRED), - ::testing::Values(HloOpcode::kConvert, HloOpcode::kNot)), - TritonSupportTestParamsToString); + ::testing::Values(HloOpcode::kConvert, HloOpcode::kNot), + AllDevicesToTest()), + TritonSupportTestTypeOpcodeAndDeviceToString); INSTANTIATE_TEST_SUITE_P( UnaryMathTestSuite, UnaryElementwiseTest, ::testing::Combine(::testing::Values(F16, F32, BF16), @@ -179,17 +199,14 @@ INSTANTIATE_TEST_SUITE_P( HloOpcode::kLog1p, HloOpcode::kRsqrt, HloOpcode::kSin, HloOpcode::kSqrt, HloOpcode::kCbrt, HloOpcode::kTan, - HloOpcode::kTanh, 
HloOpcode::kErf)), - TritonSupportTestParamsToString); + HloOpcode::kTanh, HloOpcode::kErf), + AllDevicesToTest()), + TritonSupportTestTypeOpcodeAndDeviceToString); using BinaryElementwiseTest = TritonSupportTestWithParam; TEST_P(BinaryElementwiseTest, IsTritonSupportedBinaryElementwise) { - auto [data_type, opcode] = GetParam(); - if (data_type == BF16 && !SupportsBF16(GetComputeCapability())) { - GTEST_SKIP(); - } - + auto [data_type, opcode, cc] = GetParam(); const std::string kHloTestTemplate = R"( ENTRY triton_computation { parameter_0 = $0[11,63]{1,0} parameter(0) @@ -201,10 +218,11 @@ ENTRY triton_computation { ParseTemplateAndGetInstruction(kHloTestTemplate, data_type, opcode)); bool skip_failure_branch_to_avoid_crash = false; - if (data_type == F16 && opcode == HloOpcode::kDivide) { + if (primitive_util::BitWidth(data_type) == 16 && + opcode == HloOpcode::kDivide) { skip_failure_branch_to_avoid_crash = true; } - RunSupportTest(std::move(ti), /*output_tile_sizes=*/{1, 32}, + RunSupportTest(std::move(ti), /*output_tile_sizes=*/{1, 32}, cc, /*skip_failure_branch_to_avoid_crash=*/ skip_failure_branch_to_avoid_crash); } @@ -215,30 +233,29 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(HloOpcode::kAdd, HloOpcode::kMultiply, HloOpcode::kMaximum, HloOpcode::kMinimum, - HloOpcode::kSubtract)), - TritonSupportTestParamsToString); + HloOpcode::kSubtract), + AllDevicesToTest()), + TritonSupportTestTypeOpcodeAndDeviceToString); INSTANTIATE_TEST_SUITE_P(BinaryPREDTestSuite, BinaryElementwiseTest, ::testing::Combine(::testing::Values(PRED), ::testing::Values(HloOpcode::kAnd, HloOpcode::kOr, - HloOpcode::kXor)), - TritonSupportTestParamsToString); + HloOpcode::kXor), + AllDevicesToTest()), + TritonSupportTestTypeOpcodeAndDeviceToString); INSTANTIATE_TEST_SUITE_P( BinaryMathTestSuite, BinaryElementwiseTest, ::testing::Combine(::testing::Values(F16, F32, BF16), ::testing::Values(HloOpcode::kAtan2, HloOpcode::kDivide, - HloOpcode::kPower)), - 
TritonSupportTestParamsToString); + HloOpcode::kPower), + AllDevicesToTest()), + TritonSupportTestTypeOpcodeAndDeviceToString); using CompareTest = TritonSupportTestWithParam; TEST_P(CompareTest, IsTritonSupportedCompare) { - auto [data_type, opcode] = GetParam(); - if (data_type == BF16 && !SupportsBF16(GetComputeCapability())) { - GTEST_SKIP(); - } - + auto [data_type, opcode, cc] = GetParam(); const std::string kHloTestTemplate = R"( ENTRY triton_computation { parameter_0 = $0[11,63]{1,0} parameter(0) @@ -249,23 +266,20 @@ ENTRY triton_computation { TF_ASSERT_OK_AND_ASSIGN( TestedInstruction ti, ParseTemplateAndGetInstruction(kHloTestTemplate, data_type, opcode)); - RunSupportTest(std::move(ti), /*output_tile_sizes=*/{1, 32}); + RunSupportTest(std::move(ti), /*output_tile_sizes=*/{1, 32}, cc); } INSTANTIATE_TEST_SUITE_P( CompareTestSuite, CompareTest, ::testing::Combine(::testing::Values(PRED, S8, S16, S32, F16, F32, BF16), - ::testing::Values(HloOpcode::kCompare)), - TritonSupportTestParamsToString); + ::testing::Values(HloOpcode::kCompare), + AllDevicesToTest()), + TritonSupportTestTypeOpcodeAndDeviceToString); using TernaryElementwiseTest = TritonSupportTestWithParam; TEST_P(TernaryElementwiseTest, IsTritonSupportedTernaryElementwise) { - auto [data_type, opcode] = GetParam(); - if (data_type == BF16 && !SupportsBF16(GetComputeCapability())) { - GTEST_SKIP(); - } - + auto [data_type, opcode, cc] = GetParam(); const std::string kHloTestTemplate = R"( ENTRY triton_computation { parameter_0 = $0[13,63]{1,0} parameter(0) @@ -277,25 +291,24 @@ ENTRY triton_computation { TF_ASSERT_OK_AND_ASSIGN( TestedInstruction ti, ParseTemplateAndGetInstruction(kHloTestTemplate, data_type, opcode)); - RunSupportTest(std::move(ti), /*output_tile_sizes=*/{1, 32}); + RunSupportTest(std::move(ti), /*output_tile_sizes=*/{1, 32}, cc); } INSTANTIATE_TEST_SUITE_P( TernaryElementwiseTestSuite, TernaryElementwiseTest, ::testing::Combine(::testing::Values(PRED, S8, S16, S32, F16, F32, 
BF16), - ::testing::Values(HloOpcode::kSelect)), - TritonSupportTestParamsToString); - -using ReduceConstTest = TritonSupportTestWithParam; - -TEST_P(ReduceConstTest, IsTritonSupportedReduceWithConstInit) { - auto [data_type, opcode] = GetParam(); - if (data_type == BF16 && !SupportsBF16(GetComputeCapability())) { - GTEST_SKIP(); - } - - const std::string kHloTestTemplate = R"( -HloModule t + ::testing::Values(HloOpcode::kSelect), + AllDevicesToTest()), + TritonSupportTestTypeOpcodeAndDeviceToString); + +using ReduceTest = TritonSupportTestWithParam; + +TEST_P(ReduceTest, IsTritonSupportedReduction) { + GTEST_SKIP() << "TODO(b/348565795): this test is currently broken."; + auto [data_type, opcode, cc] = GetParam(); + bool dtype_is_complex = data_type == C64 || data_type == C128; + const std::string kHloTestTemplate = + absl::Substitute(R"( add { Arg_0 = $0[] parameter(0) Arg_1 = $0[] parameter(1) @@ -304,213 +317,152 @@ add { ENTRY triton_computation { parameter_0 = $0[125,127]{1,0} parameter(0) - constant_0 = $0[] constant(0) - ROOT reduce = $0[125]{0} $1(parameter_0, constant_0), dimensions={1}, to_apply=add -})"; + constant_0 = $0[] constant($1) + ROOT reduce = $0[125]{0} reduce(parameter_0, constant_0), + dimensions={1}, to_apply=add +})", + "$0", dtype_is_complex ? 
"(0, 0)" : "0"); TF_ASSERT_OK_AND_ASSIGN( TestedInstruction ti, ParseTemplateAndGetInstruction(kHloTestTemplate, data_type, opcode)); - RunSupportTest(std::move(ti), /*output_tile_sizes=*/{1}); -} - -INSTANTIATE_TEST_SUITE_P( - ReduceConstTestSuite, ReduceConstTest, - ::testing::Combine(::testing::Values(F16, F32, BF16), - ::testing::Values(HloOpcode::kReduce)), - TritonSupportTestParamsToString); - -TEST_F(TritonSupportTest, - SupportedReduceWithConvertConstantIsCodegenedSuccessfullyWithTriton) { - if (!SupportsBF16(GetComputeCapability())) { - GTEST_SKIP(); - } - const std::string kHloTest = R"( -add { - Arg_0 = f32[] parameter(0) - Arg_1 = f32[] parameter(1) - ROOT add = f32[] add(Arg_0, Arg_1) + RunSupportTest(std::move(ti), /*output_tile_sizes=*/{1}, cc); } -ENTRY triton_computation { - parameter_0 = f32[125,127]{1,0} parameter(0) - constant_0 = bf16[] constant(0) - convert_0 = f32[] convert(constant_0) - ROOT reduce = f32[125]{0} reduce(parameter_0, convert_0), dimensions={1}, to_apply=add -})"; - TF_ASSERT_OK_AND_ASSIGN(TestedInstruction ti, - ParseTemplateAndGetInstruction( - kHloTest, /*data_type=*/{}, HloOpcode::kReduce)); - EXPECT_TRUE( - IsTritonSupportedInstruction(ti.Instruction(), GetComputeCapability()) - .CanFuse()); - TF_EXPECT_OK( - ApplyFloatNormalization(ti.Module().get(), GetComputeCapability())); - TF_EXPECT_OK(CreateTritonIrAndFileCheck(ti.TritonComputation(), - FromOutputTileSizes({1}), - "CHECK: tt.func @triton_fn")); -} - -TEST_F( - TritonSupportTestBase, +TEST_P( + ReduceTest, UnsupportedReduceWithMoreThanOneReduceDimensionsFailsGracefullyWithTriton) { - const std::string kHloTest = R"( + auto [data_type, opcode, cc] = GetParam(); + bool dtype_is_complex = data_type == C64 || data_type == C128; + const std::string kHloTestTemplate = + absl::Substitute(R"( add { - Arg_0 = f32[] parameter(0) - Arg_1 = f32[] parameter(1) - ROOT add = f32[] add(Arg_0, Arg_1) + Arg_0 = $0[] parameter(0) + Arg_1 = $0[] parameter(1) + ROOT add = $0[] add(Arg_0, 
Arg_1) } ENTRY triton_computation { - parameter_0 = f32[2,125,127]{2,1,0} parameter(0) - constant_0 = f32[] constant(0) - ROOT reduce = f32[2]{0} reduce(parameter_0, constant_0), dimensions={1,2}, to_apply=add -})"; - TF_ASSERT_OK_AND_ASSIGN(TestedInstruction ti, - ParseTemplateAndGetInstruction( - kHloTest, /*data_type=*/{}, HloOpcode::kReduce)); - EXPECT_THAT( - IsTritonSupportedInstruction(ti.Instruction(), GetComputeCapability()) - .Explain(), - ::testing::HasSubstr( - "Reduction is not a row-reduction of a single operand.")); - const se::DeviceDescription dev_info = - TestGpuDeviceInfo::RTXA6000DeviceInfo(GetComputeCapability()); - EXPECT_THAT( - TritonWrapper("test_fn", &ti.TritonFusion(), GetComputeCapability(), - dev_info, FromOutputTileSizes({1}), &llvm_module_, - mlir_context_), - Not(IsOk())); + parameter_0 = $0[2,125,127]{2,1,0} parameter(0) + constant_0 = $0[] constant($1) + ROOT reduce = $0[2]{0} reduce(parameter_0, constant_0), + dimensions={1,2}, to_apply=add +})", + "$0", dtype_is_complex ? 
"(0, 0)" : "0"); + TF_ASSERT_OK_AND_ASSIGN( + TestedInstruction ti, + ParseTemplateAndGetInstruction(kHloTestTemplate, data_type, opcode)); + EXPECT_FALSE(IsTritonSupportedInstruction(ti.Instruction(), cc)); + RunSupportTest(std::move(ti), /*output_tile_sizes=*/{1}, cc); } -TEST_F(TritonSupportTest, +TEST_P(ReduceTest, UnsupportedReduceWithNonLastReduceDimensionFailsGracefullyWithTriton) { - const std::string kHloTest = R"( + auto [data_type, opcode, cc] = GetParam(); + bool dtype_is_complex = data_type == C64 || data_type == C128; + const std::string kHloTestTemplate = + absl::Substitute(R"( add { - Arg_0 = f32[] parameter(0) - Arg_1 = f32[] parameter(1) - ROOT add = f32[] add(Arg_0, Arg_1) + Arg_0 = $0[] parameter(0) + Arg_1 = $0[] parameter(1) + ROOT add = $0[] add(Arg_0, Arg_1) } ENTRY triton_computation { - parameter_0 = f32[125,127]{1,0} parameter(0) - constant_0 = f32[] constant(0) - ROOT reduce = f32[127]{0} reduce(parameter_0, constant_0), dimensions={0}, to_apply=add -})"; - TF_ASSERT_OK_AND_ASSIGN(TestedInstruction ti, - ParseTemplateAndGetInstruction( - kHloTest, /*data_type=*/{}, HloOpcode::kReduce)); - EXPECT_THAT( - IsTritonSupportedInstruction(ti.Instruction(), GetComputeCapability()) - .Explain(), - ::testing::HasSubstr( - "Reduction is not a row-reduction of a single operand.")); - const se::DeviceDescription dev_info = - TestGpuDeviceInfo::RTXA6000DeviceInfo(GetComputeCapability()); - EXPECT_THAT( - TritonWrapper("test_fn", &ti.TritonFusion(), GetComputeCapability(), - dev_info, FromOutputTileSizes({1}), &llvm_module_, - mlir_context_), - Not(IsOk())); + parameter_0 = $0[125,127]{1,0} parameter(0) + constant_0 = $0[] constant($1) + ROOT reduce = $0[127]{0} reduce(parameter_0, constant_0), dimensions={0}, to_apply=add +})", + "$0", dtype_is_complex ? 
"(0, 0)" : "0"); + TF_ASSERT_OK_AND_ASSIGN( + TestedInstruction ti, + ParseTemplateAndGetInstruction(kHloTestTemplate, data_type, opcode)); + EXPECT_FALSE(IsTritonSupportedInstruction(ti.Instruction(), cc)); + RunSupportTest(std::move(ti), /*output_tile_sizes=*/{1}, cc); } -TEST_F(TritonSupportTest, +TEST_P(ReduceTest, UnsupportedReduceWithMoreThanOneOperandsFailsGracefullyWithTriton) { - const std::string kHloTest = R"( + auto [data_type, opcode, cc] = GetParam(); + bool dtype_is_complex = data_type == C64 || data_type == C128; + const std::string kHloTestTemplate = + absl::Substitute(R"( add { - Arg_0 = f32[] parameter(0) - Arg_2 = f32[] parameter(1) - Arg_1 = f32[] parameter(2) - Arg_3 = f32[] parameter(3) - add_0 = f32[] add(Arg_0, Arg_2) - add_1 = f32[] add(Arg_1, Arg_3) - ROOT pair = (f32[], f32[]) tuple(add_0, add_1) + Arg_0 = $0[] parameter(0) + Arg_1 = $0[] parameter(1) + Arg_2 = $0[] parameter(2) + Arg_3 = $0[] parameter(3) + add_0 = $0[] add(Arg_0, Arg_2) + add_1 = $0[] add(Arg_1, Arg_3) + ROOT pair = ($0[], $0[]) tuple(add_0, add_1) } ENTRY triton_computation { - parameter_0 = f32[125,127] parameter(0) - constant_0 = f32[] constant(0) - tuple_0 = (f32[125]{0}, f32[125]{0}) reduce(parameter_0, parameter_0, constant_0, constant_0), dimensions={1}, to_apply=add - ROOT reduce = f32[125]{0} get-tuple-element(tuple_0), index=0 -})"; - TF_ASSERT_OK_AND_ASSIGN(TestedInstruction ti, - ParseTemplateAndGetInstruction( - kHloTest, /*data_type=*/{}, HloOpcode::kReduce)); - EXPECT_THAT( - IsTritonSupportedInstruction(ti.Instruction(), GetComputeCapability()) - .Explain(), - ::testing::HasSubstr("Unsupported output data type")); - const se::DeviceDescription dev_info = - TestGpuDeviceInfo::RTXA6000DeviceInfo(GetComputeCapability()); - EXPECT_THAT( - TritonWrapper("test_fn", &ti.TritonFusion(), GetComputeCapability(), - dev_info, FromOutputTileSizes({1}), &llvm_module_, - mlir_context_), - Not(IsOk())); + parameter_0 = $0[125,127] parameter(0) + constant_0 = $0[] 
constant($1) + tuple = ($0[125]{0}, $0[125]{0}) reduce( + parameter_0, parameter_0, constant_0, constant_0), + dimensions={1}, to_apply=add + ROOT reduce = $0[125]{0} get-tuple-element(tuple), index=0 +})", + "$0", dtype_is_complex ? "(0, 0)" : "0"); + TF_ASSERT_OK_AND_ASSIGN( + TestedInstruction ti, + ParseTemplateAndGetInstruction(kHloTestTemplate, data_type, opcode)); + EXPECT_FALSE(IsTritonSupportedInstruction(ti.Instruction(), cc)); + RunSupportTest(std::move(ti), /*output_tile_sizes=*/{1}, cc); } -TEST_F(TritonSupportTest, +TEST_P(ReduceTest, UnsupportedReduceWithNonConstReduceValueFailsGracefullyWithTriton) { - const std::string kHloTest = R"( + auto [data_type, opcode, cc] = GetParam(); + const std::string kHloTestTemplate = R"( add { - Arg_0 = f32[] parameter(0) - Arg_1 = f32[] parameter(1) - ROOT add = f32[] add(Arg_0, Arg_1) + Arg_0 = $0[] parameter(0) + Arg_1 = $0[] parameter(1) + ROOT add = $0[] add(Arg_0, Arg_1) } ENTRY triton_computation { - parameter_0 = f32[125,127]{1,0} parameter(0) - init = f32[] parameter(1) - ROOT reduce = f32[125]{0} reduce(parameter_0, init), dimensions={1}, to_apply=add + parameter_0 = $0[125,127]{1,0} parameter(0) + init = $0[] parameter(1) + ROOT reduce = $0[125]{0} reduce(parameter_0, init), dimensions={1}, to_apply=add })"; - TF_ASSERT_OK_AND_ASSIGN(TestedInstruction ti, - ParseTemplateAndGetInstruction( - kHloTest, /*data_type=*/{}, HloOpcode::kReduce)); - const se::DeviceDescription dev_info = - TestGpuDeviceInfo::RTXA6000DeviceInfo(GetComputeCapability()); - EXPECT_THAT( - IsTritonSupportedInstruction(ti.Instruction(), GetComputeCapability()) - .Explain(), - ::testing::HasSubstr("Reduction init value should be a constant " - "or a convert of a constant.")); - EXPECT_THAT( - TritonWrapper("test_fn", &ti.TritonFusion(), GetComputeCapability(), - dev_info, FromOutputTileSizes({1}), &llvm_module_, - mlir_context_), - tsl::testing::StatusIs( - absl::StatusCode::kInternal, - ::testing::HasSubstr("operand->opcode() == 
HloOpcode::kConstant"))); + TF_ASSERT_OK_AND_ASSIGN( + TestedInstruction ti, + ParseTemplateAndGetInstruction(kHloTestTemplate, data_type, opcode)); + EXPECT_FALSE(IsTritonSupportedInstruction(ti.Instruction(), cc)); + RunSupportTest(std::move(ti), /*output_tile_sizes=*/{1}, cc); } -TEST_F(TritonSupportTest, - UnsupportedReductionComputationFailsGracefullyWithTriton) { - const std::string kHloTest = R"( +TEST_P(ReduceTest, UnsupportedReductionComputationFailsGracefullyWithTriton) { + auto [data_type, opcode, cc] = GetParam(); + bool dtype_is_complex = data_type == C64 || data_type == C128; + const std::string kHloTestTemplate = + absl::Substitute(R"( custom_call { - Arg_0 = f32[] parameter(0) - Arg_1 = f32[] parameter(1) - ROOT custom_call = f32[] custom-call(Arg_0, Arg_1), custom_call_target="foo" + Arg_0 = $0[] parameter(0) + Arg_1 = $0[] parameter(1) + ROOT custom_call = $0[] custom-call(Arg_0, Arg_1), custom_call_target="foo" } ENTRY triton_computation { - parameter_0 = f32[125,127]{1,0} parameter(0) - constant_0 = f32[] constant(0) - ROOT reduce = f32[125]{0} reduce(parameter_0, constant_0), dimensions={1}, to_apply=custom_call -})"; - TF_ASSERT_OK_AND_ASSIGN(TestedInstruction ti, - ParseTemplateAndGetInstruction( - kHloTest, /*data_type=*/{}, HloOpcode::kReduce)); - const se::DeviceDescription dev_info = - TestGpuDeviceInfo::RTXA6000DeviceInfo(GetComputeCapability()); - EXPECT_THAT( - IsTritonSupportedInstruction(ti.Instruction(), GetComputeCapability()) - .Explain(), - ::testing::HasSubstr("Unsupported reduction computation by Triton.")); - EXPECT_THAT( - TritonWrapper("test_fn", &ti.TritonFusion(), GetComputeCapability(), - dev_info, FromOutputTileSizes({1}), &llvm_module_, - mlir_context_), - tsl::testing::StatusIs(absl::StatusCode::kInvalidArgument, - ::testing::HasSubstr("Unsupported operation"))); + parameter_0 = $0[125,127]{1,0} parameter(0) + constant_0 = $0[] constant($1) + ROOT reduce = $0[125]{0} reduce(parameter_0, constant_0), + dimensions={1}, 
to_apply=custom_call +})", + "$0", dtype_is_complex ? "(0, 0)" : "0"); + TF_ASSERT_OK_AND_ASSIGN( + TestedInstruction ti, + ParseTemplateAndGetInstruction(kHloTestTemplate, data_type, opcode)); + EXPECT_FALSE(IsTritonSupportedInstruction(ti.Instruction(), cc)); + RunSupportTest(std::move(ti), /*output_tile_sizes=*/{1}, cc); } + +INSTANTIATE_TEST_SUITE_P(ReduceTestSuite, ReduceTest, + AllTestCombinationsForOpcodes({HloOpcode::kReduce}), + TritonSupportTestTypeOpcodeAndDeviceToString); + } // namespace } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/triton_test_utils.cc b/third_party/xla/xla/service/gpu/triton_test_utils.cc index 9f4cd4cc8717bc..c5712d567d044c 100644 --- a/third_party/xla/xla/service/gpu/triton_test_utils.cc +++ b/third_party/xla/xla/service/gpu/triton_test_utils.cc @@ -15,7 +15,6 @@ limitations under the License. #include "xla/service/gpu/triton_test_utils.h" -#include #include #include #include @@ -126,17 +125,38 @@ absl::StatusOr ApplyFloatNormalization( return pipeline.Run(module); } -std::string TritonSupportTestParamsToString( - const ::testing::TestParamInfo>& - data) { - PrimitiveType data_type; - HloOpcode opcode; - std::tie(data_type, opcode) = data.param; +namespace { + +std::string PrimitiveTypeAndHloOpcodeToString(PrimitiveType data_type, + HloOpcode opcode) { return absl::StrCat( primitive_util::LowercasePrimitiveTypeName(data_type), "_", absl::StrReplaceAll(HloOpcodeString(opcode), {{"-", "_"}})); } +} // namespace + +std::string TritonSupportTestParamsToString( + const ::testing::TestParamInfo>& + data) { + auto [data_type, opcode] = data.param; + return PrimitiveTypeAndHloOpcodeToString(data_type, opcode); +} + +std::string TritonSupportTestTypeOpcodeAndDeviceToString( + const ::testing::TestParamInfo< + std::tuple>& data) { + auto [data_type, opcode, cc] = data.param; + std::string cc_str; + if (std::holds_alternative(cc)) { + cc_str = std::get(cc).ToString(); + } else { + cc_str = "rocm"; + } + 
return absl::StrCat(PrimitiveTypeAndHloOpcodeToString(data_type, opcode), "_", + absl::StrReplaceAll(cc_str, {{".", ""}})); +} + namespace { // This function does nothing if the input module already has an entry diff --git a/third_party/xla/xla/service/gpu/triton_test_utils.h b/third_party/xla/xla/service/gpu/triton_test_utils.h index 25a7b4b1811d27..e25341b1529a82 100644 --- a/third_party/xla/xla/service/gpu/triton_test_utils.h +++ b/third_party/xla/xla/service/gpu/triton_test_utils.h @@ -134,6 +134,9 @@ class TritonSupportTestBaseWithParam std::string TritonSupportTestParamsToString( const ::testing::TestParamInfo>& data); +std::string TritonSupportTestTypeOpcodeAndDeviceToString( + const ::testing::TestParamInfo< + std::tuple>& data); } // namespace xla::gpu #endif // XLA_SERVICE_GPU_TRITON_TEST_UTILS_H_ From a9ab5056e20aefc114761ba9468b0af084e28c32 Mon Sep 17 00:00:00 2001 From: Marissa Ikonomidis Date: Fri, 21 Jun 2024 08:48:55 -0700 Subject: [PATCH 120/256] Introduce utility function GetPrevAllowedBatchSize. Original author is piatov@. 
PiperOrigin-RevId: 645402594 --- tensorflow/core/kernels/batching_util/BUILD | 1 + .../batching_util/batch_scheduler_utils.cc | 27 +++++++++++++++++++ .../batching_util/batch_scheduler_utils.h | 6 +++++ .../batch_scheduler_utils_test.cc | 24 +++++++++++++++++ 4 files changed, 58 insertions(+) diff --git a/tensorflow/core/kernels/batching_util/BUILD b/tensorflow/core/kernels/batching_util/BUILD index 74834c9e2a9707..a136eeb1fab768 100644 --- a/tensorflow/core/kernels/batching_util/BUILD +++ b/tensorflow/core/kernels/batching_util/BUILD @@ -149,6 +149,7 @@ cc_library( hdrs = ["batch_scheduler_utils.h"], deps = [ "//tensorflow/core:portable_gif_internal", + "@com_google_absl//absl/algorithm:container", ], ) diff --git a/tensorflow/core/kernels/batching_util/batch_scheduler_utils.cc b/tensorflow/core/kernels/batching_util/batch_scheduler_utils.cc index 148bb7c039f2dc..21622a5f0ba2f0 100644 --- a/tensorflow/core/kernels/batching_util/batch_scheduler_utils.cc +++ b/tensorflow/core/kernels/batching_util/batch_scheduler_utils.cc @@ -15,8 +15,10 @@ limitations under the License. 
#include "tensorflow/core/kernels/batching_util/batch_scheduler_utils.h" +#include #include +#include "absl/algorithm/container.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" @@ -29,6 +31,8 @@ int GetNextAllowedBatchSize(int batch_size, if (disable_padding || allowed_batch_sizes.empty()) { return batch_size; } + DCHECK(absl::c_is_sorted(allowed_batch_sizes)); + DCHECK_GT(batch_size, 0); for (int allowed_size : allowed_batch_sizes) { if (allowed_size >= batch_size) { return allowed_size; @@ -40,5 +44,28 @@ int GetNextAllowedBatchSize(int batch_size, return batch_size; } +int32 GetPrevAllowedBatchSize(int batch_size, + const std::vector& allowed_batch_sizes, + bool disable_padding) { + if (disable_padding || allowed_batch_sizes.empty()) { + return batch_size; + } + + DCHECK(absl::c_is_sorted(allowed_batch_sizes)); + DCHECK_GT(batch_size, 0); + + // First from the end allowed_batch_size not larger than batch_size. + auto result = std::find_if( + allowed_batch_sizes.rbegin(), allowed_batch_sizes.rend(), + [&](int allowed_size) { return allowed_size <= batch_size; }); + + if (result == allowed_batch_sizes.rend()) { + // No such element exists. + return batch_size; + } + + return *result; +} + } // namespace serving } // namespace tensorflow diff --git a/tensorflow/core/kernels/batching_util/batch_scheduler_utils.h b/tensorflow/core/kernels/batching_util/batch_scheduler_utils.h index 38831531abd6e7..7e4382a9d862db 100644 --- a/tensorflow/core/kernels/batching_util/batch_scheduler_utils.h +++ b/tensorflow/core/kernels/batching_util/batch_scheduler_utils.h @@ -30,6 +30,12 @@ int GetNextAllowedBatchSize(int batch_size, const std::vector& allowed_batch_sizes, bool disable_padding); +// Returns the largest allowed batch size that is smaller than or equal to +// batch_size. Returns batch_size if no such size exists. 
+int GetPrevAllowedBatchSize(int batch_size, + const std::vector& allowed_batch_sizes, + bool disable_padding); + } // namespace serving } // namespace tensorflow diff --git a/tensorflow/core/kernels/batching_util/batch_scheduler_utils_test.cc b/tensorflow/core/kernels/batching_util/batch_scheduler_utils_test.cc index 9cd6ce1ddcb210..2bff515a57aeb9 100644 --- a/tensorflow/core/kernels/batching_util/batch_scheduler_utils_test.cc +++ b/tensorflow/core/kernels/batching_util/batch_scheduler_utils_test.cc @@ -42,6 +42,30 @@ TEST(GetNextAllowedBatchSizeTest, GreaterThanAllowedBatchSize) { EXPECT_EQ(GetNextAllowedBatchSize(10, {2, 4, 8}, false), 10); } +TEST(GetPrevAllowedBatchSizeTest, PaddingDisallowed) { + EXPECT_EQ(GetPrevAllowedBatchSize(3, {2, 4, 8}, true), 3); +} + +TEST(GetPrevAllowedBatchSizeTest, EmptyAllowedBatchSizes) { + EXPECT_EQ(GetPrevAllowedBatchSize(3, {}, false), 3); +} + +TEST(GetPrevAllowedBatchSizeTest, PrevAllowedBatchSizeFound) { + EXPECT_EQ(GetPrevAllowedBatchSize(3, {1, 2, 4, 8}, false), 2); +} + +TEST(GetPrevAllowedBatchSizeTest, NoSmallerAllowedBatchSizeFound) { + EXPECT_EQ(GetPrevAllowedBatchSize(3, {4, 8}, false), 3); +} + +TEST(GetPrevAllowedBatchSizeTest, AlreadyAllowedBatchSize) { + EXPECT_EQ(GetPrevAllowedBatchSize(2, {1, 2, 4, 8}, false), 2); +} + +TEST(GetPrevAllowedBatchSizeTest, GreaterThanMaxAllowedBatchSize) { + EXPECT_EQ(GetPrevAllowedBatchSize(10, {2, 4, 8}, false), 8); +} + } // namespace } // namespace serving From a892e21183c030a5b571e402af5e93127ca0e16a Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Fri, 21 Jun 2024 09:03:41 -0700 Subject: [PATCH 121/256] [XLA:GPU] Introduce `ConstraintExpression` to hold generalized constraint expressions for `SymbolicTile{Analysis}`. `ConstraintExpression` represents a "flat" constraint expression of the form   ((expr0 in interval0) && (expr1 in interval1)...) || ((expr{n} in interval{n}) &&...)... 
The underlying constraints are stored in a vector of maps, such that each map represents the conjunction of some constraints, and the vector represents the disjunction of all its contained maps (conjunctions). This representation is effective because `&&` (`And`) is distributive over `||` (`Or`), ensuring that we can always flatten any given `ConstraintExpression` in this way, and that we have reasonable combinators for `&&` and `||`. We store a boolean `is_satisfiable_` to indicate whether we expect that the constraints can be satisfied. When set to `false`, we expect the `ConstraintExpression` to be empty (bottom). `ConstraintExpression` is plumbed through `SymbolicTile` and `SymbolicTileAnalysis` but disjunctions are not yet exploited. Future changes to the derivation logic for `SymbolicTile`s will address that. PiperOrigin-RevId: 645406452 --- third_party/xla/xla/service/gpu/model/BUILD | 3 +- .../xla/service/gpu/model/symbolic_tile.cc | 232 ++++++++++++----- .../xla/xla/service/gpu/model/symbolic_tile.h | 107 +++++--- .../gpu/model/symbolic_tile_analysis.cc | 66 +++-- .../gpu/model/symbolic_tile_analysis.h | 12 +- .../gpu/model/symbolic_tile_analysis_test.cc | 20 +- .../service/gpu/model/symbolic_tile_test.cc | 238 +++++++++++++++++- 7 files changed, 548 insertions(+), 130 deletions(-) diff --git a/third_party/xla/xla/service/gpu/model/BUILD b/third_party/xla/xla/service/gpu/model/BUILD index 8aaa9b9e3c327b..81a08a64f9cfeb 100644 --- a/third_party/xla/xla/service/gpu/model/BUILD +++ b/third_party/xla/xla/service/gpu/model/BUILD @@ -533,7 +533,7 @@ cc_library( "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", - "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", @@ -691,6 +691,7 @@ xla_cc_test( srcs = ["symbolic_tile_analysis_test.cc"], deps = [ ":indexing_test_utils", + 
":symbolic_tile", ":symbolic_tile_analysis", ":tiled_hlo_computation", ":tiled_hlo_instruction", diff --git a/third_party/xla/xla/service/gpu/model/symbolic_tile.cc b/third_party/xla/xla/service/gpu/model/symbolic_tile.cc index 6b6c69caa2b732..0b68868bf5c214 100644 --- a/third_party/xla/xla/service/gpu/model/symbolic_tile.cc +++ b/third_party/xla/xla/service/gpu/model/symbolic_tile.cc @@ -28,7 +28,7 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/log/check.h" #include "absl/log/log.h" -#include "absl/strings/string_view.h" +#include "absl/strings/str_join.h" #include "absl/types/span.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DenseMap.h" @@ -109,16 +109,14 @@ AffineMap SubstituteAllIndicesAndRangeVarSymbolsWithSameValue( struct SizeAndStrideExpression { AffineExpr size; AffineExpr stride; - ConstraintMap constraints; - bool is_satisfiable; + ConstraintExpression constraints; - SizeAndStrideExpression(AffineExpr size, AffineExpr stride, - ConstraintMap constraints = ConstraintMap(), - bool is_satisfiable = true) + SizeAndStrideExpression( + AffineExpr size, AffineExpr stride, + ConstraintExpression constraints = ConstraintExpression()) : size(std::move(size)), stride(std::move(stride)), - constraints(std::move(constraints)), - is_satisfiable(is_satisfiable) {} + constraints(std::move(constraints)) {} }; // Extracts size and stride expressions from the operands to a modulo @@ -154,10 +152,12 @@ std::optional ExtractSizeAndStrideFromMod( AffineExpr constrained_expr = getAffineSymbolExpr(dim_expr.getPosition(), lhs.getContext()) % modulus; - ConstraintMap constraints; + ConstraintExpression constraints; // TODO(b/334043867): we only add a constraint for n being a multiple of c // while we do not support disjunctions. 
- constraints.insert({constrained_expr, Interval{/*lower=*/0, /*upper=*/0}}); + ConstraintExpression::ConjointConstraints conjunction; + conjunction.insert({constrained_expr, Interval{/*lower=*/0, /*upper=*/0}}); + constraints.And(std::move(conjunction)); // In this case, stride is effectively 1 mod modulus = 1. return SizeAndStrideExpression( @@ -472,25 +472,11 @@ std::optional CombineSizesAndStrides( } } - std::optional maybe_constraints = ConstraintMap(); + ConstraintExpression constraints; - for (const SizeAndStrideExpression& size_and_stride : sizes_and_strides) { - maybe_constraints = MergeConstraintMapIfPresentAndCompatible( - std::move(maybe_constraints), size_and_stride.constraints); - if (!maybe_constraints.has_value()) { - break; - } - } - - ConstraintMap constraints; - bool is_satisfiable = true; - - // Handle cases that we don't know how to process by constructing a - // ConstraintMap with an unsatisfiable constraint. - if (maybe_constraints.has_value()) { - constraints = std::move(*maybe_constraints); - } else { - is_satisfiable = false; + for (SizeAndStrideExpression& size_and_stride : sizes_and_strides) { + constraints = ConstraintExpression::And( + std::move(constraints), std::move(size_and_stride.constraints)); } AffineExpr size = CombineSizes(sizes_and_strides); @@ -501,8 +487,7 @@ std::optional CombineSizesAndStrides( } // TODO(b/326998704): handle reshape constraints here. - return SizeAndStrideExpression(size, *stride, std::move(constraints), - is_satisfiable); + return SizeAndStrideExpression(size, *stride, std::move(constraints)); } std::optional ExtractSizeAndStride( @@ -612,8 +597,21 @@ AffineExpr SimplifyAffineExpr(const AffineExpr& expr, return tmp_indexing_map.GetAffineMap().getResults().back(); } -} // anonymous namespace - +// Merges `maybe_first_map` and `second_map` if +// (1) `maybe_first_map` is present, and +// (2) `second_map` and `*maybe_first_map` have distinct sets of keys. +// Otherwise, returns `std::nullopt`. 
+// +// +// The behaviour of this function is in spirit equivalent to using C++23's +// `std::optional::and_then` to merge a collection of `ConstraintMap`s. +// +// We pass `maybe_first_map` by value here in order to exploit move semantics +// to avoid copies when possible. +// +// TODO(bchetioui): allow merging constraints in more edge cases, e.g. if one +// of the intervals is contained within the other. +// TODO(bchetioui): clean up this util. std::optional MergeConstraintMapIfPresentAndCompatible( std::optional maybe_first_map, const ConstraintMap& second_map) { @@ -637,6 +635,122 @@ std::optional MergeConstraintMapIfPresentAndCompatible( return first_map; } +} // anonymous namespace + +/*static*/ ConstraintExpression ConstraintExpression::And( + ConstraintExpression first, ConstraintExpression second) { + // When either one of the expressions is unsatisfiable, their conjunction is + // necessarily unsatisfiable. + if (!first.is_satisfiable_ || !second.is_satisfiable_) { + return ConstraintExpression::GetUnsatisfiableConstraintExpression(); + } + + // Both first and second are satisfiable. Handle here explicitly the case + // where one (or both) of the maps are trivially satisfied. + if (first.IsAlwaysSatisfied()) { + return second; + } + + if (second.IsAlwaysSatisfied()) { + return first; + } + + // `IsAlwaysSatisfied()` is true if and only if the map holds literally no + // useful information and is equivalent to a default-constructed + // `ConstraintExpression`---one that is neither unsatisfiable, nor contains + // any constraints. Therefore, we can assume below that both of the provided + // `ConstraintExpression`s are satisfiable and each contain at least one + // constraint. + // + // By distributivity, we have that: + // (conj0 || conj1 || ...) && (conj2 || conj3 || ...) + // = (conj0 && conj2 || conj0 && conj3 || ... || + // conj1 && conj2 || conj1 && conj3 ...) 
+ // which allows us to construct the result by essentially taking the cartesian + // product of the disjoint conjunctions of `first` with those of `second`. + ConstraintExpression result; + for (ConjointConstraints& conjunction_1 : + first.disjoint_conjoint_constraints_) { + for (ConjointConstraints& conjunction_2 : + second.disjoint_conjoint_constraints_) { + std::optional maybe_conjunction = + MergeConstraintMapIfPresentAndCompatible(conjunction_1, + conjunction_2); + // We only add the resulting conjunction to the result + // `ConstraintExpression` if it is satisfiable, since it is otherwise + // redundant: + // (conj || false = conj). + if (maybe_conjunction.has_value()) { + result.disjoint_conjoint_constraints_.push_back( + std::move(*maybe_conjunction)); + } + } + } + + // If all the resulting conjunctions are unsatisfiable, the result itself is + // unsatisfiable: + // (false || false = false). + // In our case, this manifests as an empty list of constraints in the result. + result.is_satisfiable_ = !result.disjoint_conjoint_constraints_.empty(); + + return result; +} + +/*static*/ ConstraintExpression ConstraintExpression::Or( + ConstraintExpression first, ConstraintExpression second) { + // When either one of the expressions is unsatisfiable, we can simply return + // the other one. 
+ if (!first.is_satisfiable_) { + return second; + } + + if (!second.is_satisfiable_) { + return first; + } + + absl::c_copy(second.disjoint_conjoint_constraints_, + std::back_inserter(first.disjoint_conjoint_constraints_)); + return first; +} + +void ConstraintExpression::Or( + ConstraintExpression::ConjointConstraints conjunction) { + if (conjunction.empty()) { + return; + } + + disjoint_conjoint_constraints_.push_back(std::move(conjunction)); + is_satisfiable_ = true; +} + +void ConstraintExpression::And( + ConstraintExpression::ConjointConstraints conjunction) { + if (!is_satisfiable_ || conjunction.empty()) { + return; + } + + if (disjoint_conjoint_constraints_.empty()) { + disjoint_conjoint_constraints_.push_back(std::move(conjunction)); + return; + } + + std::vector new_constraints; + new_constraints.reserve(disjoint_conjoint_constraints_.size()); + + for (ConjointConstraints& conjunction_2 : disjoint_conjoint_constraints_) { + std::optional maybe_result = + MergeConstraintMapIfPresentAndCompatible(std::move(conjunction_2), + conjunction); + // TODO(bchetioui): rework `MergeConstraintMapIfPresentAndCompatible`. 
+ if (maybe_result.has_value()) { + new_constraints.push_back(std::move(*maybe_result)); + } + } + + is_satisfiable_ = !new_constraints.empty(); + disjoint_conjoint_constraints_ = std::move(new_constraints); +} + /*static*/ std::optional SymbolicTile::FromIndexingMap( IndexingMap indexing_map) { VLOG(1) << "SymbolicTile::FromIndexingMap: " << indexing_map.ToString(); @@ -689,7 +803,7 @@ std::optional MergeConstraintMapIfPresentAndCompatible( expr = SimplifyAffineExpr(expr, indexing_map); } - std::optional maybe_constraints = ConstraintMap(); + ConstraintExpression constraints; std::vector size_expressions; std::vector stride_expressions; size_expressions.reserve(offset_expressions.size()); @@ -711,19 +825,8 @@ std::optional MergeConstraintMapIfPresentAndCompatible( size_expressions.push_back(maybe_size_and_stride->size); stride_expressions.push_back(maybe_size_and_stride->stride); - maybe_constraints = MergeConstraintMapIfPresentAndCompatible( - std::move(maybe_constraints), maybe_size_and_stride->constraints); - } - - ConstraintMap constraints; - bool is_satisfiable = true; - - // Handle cases that we don't know how to process by constructing a - // ConstraintMap with an unsatisfiable constraint. - if (maybe_constraints.has_value()) { - constraints = std::move(*maybe_constraints); - } else { - is_satisfiable = false; + constraints = ConstraintExpression::And( + std::move(constraints), std::move(maybe_size_and_stride->constraints)); } // Eliminate negative strides and recalculate offsets. 
@@ -773,7 +876,7 @@ std::optional MergeConstraintMapIfPresentAndCompatible( CHECK_EQ(tile_map.GetRangeVarsCount(), 0); VLOG(1) << "tile_map: " << tile_map.ToString(); - return SymbolicTile(std::move(tile_map), constraints, is_satisfiable); + return SymbolicTile(std::move(tile_map), std::move(constraints)); } std::string SymbolicTile::RtVarsToString( @@ -809,24 +912,31 @@ void SymbolicTile::Print(std::ostream& out, /*first_rt_var_symbol_index=*/tile_map_.GetDimensionCount(), out, printer); } - if (!constraints_.empty() && is_satisfiable_) { + if (is_satisfiable() && !constraints_.IsAlwaysSatisfied()) { out << "\n\tconstraints: "; + absl::Span conjunctions = + constraints_.DisjointConjointConstraints(); // Accumulate constraints in a vector in order to put them in lexicographic // order and to get deterministic output. - std::vector constraint_strings; - constraint_strings.reserve(constraints_.size()); - for (const auto& [expr, interval] : constraints_) { - std::stringstream ss; - printer.Print(ss, expr); - ss << " in "; - interval.Print(ss); - constraint_strings.push_back(ss.str()); - } - std::sort(constraint_strings.begin(), constraint_strings.end()); - for (absl::string_view constraint_string : constraint_strings) { - out << "\n\t" << constraint_string; + std::vector conjunction_strings; + conjunction_strings.reserve(conjunctions.size()); + for (const auto& disjunction : conjunctions) { + std::vector constraint_strings; + constraint_strings.reserve(disjunction.size()); + for (const auto& [expr, interval] : disjunction) { + std::stringstream ss; + printer.Print(ss, expr); + ss << " in "; + interval.Print(ss); + constraint_strings.push_back(ss.str()); + } + std::sort(constraint_strings.begin(), constraint_strings.end()); + conjunction_strings.push_back( + absl::StrJoin(constraint_strings, " &&\n\t")); } - } else if (!is_satisfiable_) { + std::sort(conjunction_strings.begin(), conjunction_strings.end()); + out << "\n\t" << absl::StrJoin(conjunction_strings, 
"\n||\n\t"); + } else if (!is_satisfiable()) { out << "\n\tconstraints: "; out << "\n\tunsatisfiable"; } diff --git a/third_party/xla/xla/service/gpu/model/symbolic_tile.h b/third_party/xla/xla/service/gpu/model/symbolic_tile.h index f5c3f680eae5c3..d348f95f67ecdc 100644 --- a/third_party/xla/xla/service/gpu/model/symbolic_tile.h +++ b/third_party/xla/xla/service/gpu/model/symbolic_tile.h @@ -20,8 +20,10 @@ limitations under the License. #include #include #include +#include #include "absl/log/check.h" +#include "absl/types/span.h" #include "llvm/ADT/DenseMap.h" #include "mlir/IR/AffineExpr.h" // from @llvm-project #include "mlir/IR/AffineMap.h" // from @llvm-project @@ -31,6 +33,70 @@ limitations under the License. namespace xla { namespace gpu { +// `ConstraintExpression` represents a "flat" constraint expression of the form +// ((expr0 in interval0) && (expr1 in interval1)...) || +// ((expr{n} in interval{n}) &&...)... +// +// The underlying constraints are stored in a vector of maps, such that each +// map represents the conjunction of some constraints, and the vector represents +// the disjunction of all its contained maps (conjunctions). This representation +// is effective because `&&` (`And`) is distributive over `||` (`Or`), ensuring +// that we can always flatten any given `ConstraintExpression` in this way, and +// that we have reasonable combinators for `&&` and `||`. +// +// We store a boolean `is_satisfiable_` to indicate whether we expect that the +// constraints can be satisfied. When set to `false`, we expect the +// `ConstraintExpression` to be empty (bottom). +class ConstraintExpression { + public: + using ConjointConstraints = llvm::DenseMap; + // Takes the conjunction of the constraints of `first` and `second`. + static ConstraintExpression And(ConstraintExpression first, + ConstraintExpression second); + + // Takes the disjunction of the constraints of `first` and `second`. 
+ static ConstraintExpression Or(ConstraintExpression first, + ConstraintExpression second); + + // Produces the unsatisfiable constraint expression. + static ConstraintExpression GetUnsatisfiableConstraintExpression() { + ConstraintExpression unsatisfiable; + unsatisfiable.is_satisfiable_ = false; + return unsatisfiable; + } + + // Convenience util to take the disjunction of `this` and unwrapped + // `ConjointConstraints`. + void Or(ConjointConstraints conjunction); + + // Convenience util to take the conjunction of `this` and unwrapped + // `ConjointConstraints`. + void And(ConjointConstraints conjunction); + + // Whether the constraints can be satisfied. When this is set to `false`, + // the domain of the `TileConstraints` must be considered empty. + bool is_satisfiable() const { return is_satisfiable_; } + + // Returns `true` if the constraint expression is marked satisfiable and does + // not contain any constraint. We expect this to be the case for a default + // constructed `ConstraintExpression`. + bool IsAlwaysSatisfied() const { + return is_satisfiable_ && disjoint_conjoint_constraints_.empty(); + } + + // Accessor for the underlying disjoint conjunctions of constraints. This is + // expected to be empty if `is_satisfiable()` is `false`. + absl::Span DisjointConjointConstraints() const { + return disjoint_conjoint_constraints_; + } + + // TODO(bchetioui): add a util to verify constraints here later. + // TODO(bchetioui): is canonicalization of disjunctions necessary? + private: + bool is_satisfiable_ = true; + std::vector disjoint_conjoint_constraints_; +}; + // Tiling in the simpler case, when we don't have dynamic offsets (see the // general case later): // @@ -175,15 +241,15 @@ class SymbolicTile { // Constraints on the `sizes` of the input tile. The variable names in this // map correspond to the parameter names of `offset_map()`, `size_map()`, and - // `stride_map()`. Contents are irrelevant when `is_satisfiable()` is false. 
- const ConstraintMap& constraints() const { - CHECK(is_satisfiable_); + // `stride_map()`. Content is irrelevant when `is_satisfiable()` is false. + const ConstraintExpression& constraints() const { + CHECK(constraints_.is_satisfiable()); return constraints_; } // Whether the `SymbolicTile` constraints can be satisfied. When this is set - // to true, the domain of the `SymbolicTile` must be considered empty. - bool is_satisfiable() const { return is_satisfiable_; } + // to `false`, the domain of the `SymbolicTile` must be considered empty. + bool is_satisfiable() const { return constraints_.is_satisfiable(); } // A map from one tile's sizes and RTVars to another tile's offsets, sizes, // and strides. @@ -212,37 +278,12 @@ class SymbolicTile { IndexingMap tile_map_; // See the comment of constraints(). - ConstraintMap constraints_; + ConstraintExpression constraints_; - // See the comment of is_satisfiable(). - bool is_satisfiable_ = true; - - explicit SymbolicTile(IndexingMap tile_map, ConstraintMap constraints, - bool is_satisfiable = true) - : tile_map_(std::move(tile_map)), - constraints_(std::move(constraints)), - is_satisfiable_(is_satisfiable) {} + explicit SymbolicTile(IndexingMap tile_map, ConstraintExpression constraints) + : tile_map_(std::move(tile_map)), constraints_(std::move(constraints)) {} }; -// Merges `maybe_first_map` and `second_map` if -// (1) `maybe_first_map` is present, and -// (2) `second_map` and `*maybe_first_map` have distinct sets of keys. -// Otherwise, returns `std::nullopt`. -// -// -// The behaviour of this function is in spirit equivalent to using C++23's -// `std::optional::and_then` to merge a collection of `ConstraintMap`s. -// -// We pass `maybe_first_map` by value here in order to exploit move semantics -// to avoid copies when possible. -// -// TODO(bchetioui): allow merging constraints in more edge cases, e.g. if one -// of the intervals is contained within the other. 
-std::optional -MergeConstraintMapIfPresentAndCompatible( - std::optional maybe_first_map, - const SymbolicTile::ConstraintMap& second_map); - } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis.cc b/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis.cc index d22917cba3f2de..4cea5d1fee6830 100644 --- a/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis.cc +++ b/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis.cc @@ -171,7 +171,7 @@ absl::StatusOr ComputeBlockIdToTileOffsetIndexing( const HloInstructionAdaptor&, IndexingMap)> get_tiled_hlo_instruction; - std::optional constraints = ConstraintMap(); + ConstraintExpression constraints; // Create a new tiled hlo instruction or return existing instruction from // cache for the given hlo and indexing map. @@ -215,13 +215,11 @@ absl::StatusOr ComputeBlockIdToTileOffsetIndexing( << hlo->ToString(); } - constraints = MergeConstraintMapIfPresentAndCompatible( - std::move(constraints), symbolic_tile->constraints()); + constraints = ConstraintExpression::And(std::move(constraints), + symbolic_tile->constraints()); - if (!constraints.has_value()) { - return FusionDecision{} << "Failed to merge constraints of " - << hlo->ToString() << " in pre-existing " - << "constraint map"; + if (!constraints.is_satisfiable()) { + return FusionDecision{} << "Fusion has unsatisfiable constraints"; } tiled_hlo_instructions.push_back( @@ -293,36 +291,58 @@ absl::StatusOr ComputeBlockIdToTileOffsetIndexing( }); return SymbolicTileAnalysis(std::move(tiled_hlo_instructions), - std::move(*constraints), ctx); + std::move(constraints), ctx); } absl::StatusOr SymbolicTileAnalysis::ParametersSatisfyConstraints( absl::Span tile_parameters) const { + if (!constraints_.is_satisfiable()) { + return absl::FailedPreconditionError( + "SymbolicTileAnalysis's constraints are not satisfiable. " + "This should never happen."); + } + + // Handle the unconstrained case. 
+ if (constraints_.IsAlwaysSatisfied()) { + return true; + } + // Populate parameter map. llvm::SmallVector parameters = llvm::to_vector( llvm::map_range(tile_parameters, [this](const int64_t v) -> AffineExpr { return mlir::getAffineConstantExpr(v, context_); })); - for (auto [constrained_expr, interval] : constraints_) { - AffineExpr constrained_expr_value = - constrained_expr.replaceSymbols(parameters); - if (constrained_expr_value.getKind() != mlir::AffineExprKind::Constant) { - return absl::InvalidArgumentError(absl::StrCat( - "Failed to reduce ", AffineMapPrinter().ToString(constrained_expr), - " to a constant with tile parameters ", - absl::StrJoin(tile_parameters, ", "))); - } + // TODO(bchetioui): replace with convenience methods in + // `ConstraintExpression`. + bool constraints_are_satisfied = false; + for (const ConstraintExpression::ConjointConstraints& conjunction : + constraints_.DisjointConjointConstraints()) { + bool conjunction_is_satisfied = true; + for (const auto& [constrained_expr, interval] : conjunction) { + AffineExpr constrained_expr_value = + constrained_expr.replaceSymbols(parameters); + if (constrained_expr_value.getKind() != mlir::AffineExprKind::Constant) { + return absl::InvalidArgumentError(absl::StrCat( + "Failed to reduce ", AffineMapPrinter().ToString(constrained_expr), + " to a constant with tile parameters ", + absl::StrJoin(tile_parameters, ", "))); + } - int64_t constrained_value = - llvm::cast(constrained_expr_value).getValue(); + int64_t constrained_value = + llvm::cast(constrained_expr_value) + .getValue(); - if (constrained_value < interval.lower || - constrained_value > interval.upper) { - return false; + if (constrained_value < interval.lower || + constrained_value > interval.upper) { + conjunction_is_satisfied = false; + break; + } } + constraints_are_satisfied |= conjunction_is_satisfied; } - return true; + + return constraints_are_satisfied; } absl::StatusOr diff --git 
a/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis.h b/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis.h index ac938f5c17589e..2a601c6c1ab8d8 100644 --- a/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis.h +++ b/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis.h @@ -85,11 +85,9 @@ class SymbolicTileAnalysis { } // Returns the constraints for the parameters of the symbolic tiled HLO - // computation. This is the union of the constraints of all the symbolic tiles - // encountered throughout the computation. - const SymbolicTile::ConstraintMap& GetConstraints() const { - return constraints_; - } + // computation. This is the intersection of the constraints of all the + // symbolic tiles encountered throughout the computation. + const ConstraintExpression& GetConstraints() const { return constraints_; } // Returns true if a list of tile parameters satisfies the symbolic tile // analysis's constraints. @@ -120,7 +118,7 @@ class SymbolicTileAnalysis { private: SymbolicTileAnalysis(std::vector> symbolic_tiled_hlo_instructions, - SymbolicTile::ConstraintMap constraints, + ConstraintExpression constraints, mlir::MLIRContext* context) : symbolic_tiled_hlo_instructions_( std::move(symbolic_tiled_hlo_instructions)), @@ -132,7 +130,7 @@ class SymbolicTileAnalysis { symbolic_tiled_hlo_instructions_; // See the documentation of GetConstraints(). - SymbolicTile::ConstraintMap constraints_; + ConstraintExpression constraints_; mlir::MLIRContext* context_; }; diff --git a/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis_test.cc b/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis_test.cc index 7a420ba9542298..74942b19eaae39 100644 --- a/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis_test.cc +++ b/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis_test.cc @@ -34,6 +34,7 @@ limitations under the License. 
#include "xla/hlo/ir/hlo_module.h" #include "xla/service/gpu/hlo_traversal.h" #include "xla/service/gpu/model/indexing_test_utils.h" +#include "xla/service/gpu/model/symbolic_tile.h" #include "xla/service/gpu/model/tiled_hlo_computation.h" #include "xla/service/gpu/model/tiled_hlo_instruction.h" #include "xla/service/instruction_fusion.h" @@ -370,9 +371,14 @@ ENTRY main { })")); std::optional analysis = TryAnalyzeModule(module.get()); ASSERT_TRUE(analysis.has_value()); - EXPECT_THAT(analysis->GetConstraints(), SizeIs(1)); + const ConstraintExpression& constraints = analysis->GetConstraints(); + EXPECT_THAT(constraints.DisjointConjointConstraints(), SizeIs(1)); + EXPECT_THAT(constraints.DisjointConjointConstraints().front(), SizeIs(1)); } +// TODO(b/334043867): add disjunction tests here once disjunctions are actually +// used in `SymbolicTile`s. + TEST_F(SymbolicTileAnalysisTest, DoesNotBailOutOnConstrainedBitcast) { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(R"( @@ -387,7 +393,9 @@ ENTRY main { })")); std::optional analysis = TryAnalyzeModule(module.get()); ASSERT_TRUE(analysis.has_value()); - EXPECT_THAT(analysis->GetConstraints(), SizeIs(1)); + const ConstraintExpression& constraints = analysis->GetConstraints(); + EXPECT_THAT(constraints.DisjointConjointConstraints(), SizeIs(1)); + EXPECT_THAT(constraints.DisjointConjointConstraints().front(), SizeIs(1)); } TEST_F(SymbolicTileAnalysisTest, BailOutOnUnsupportedConcatenate) { @@ -440,7 +448,9 @@ ENTRY main { })")); std::optional analysis = TryAnalyzeModule(module.get()); ASSERT_TRUE(analysis.has_value()); - EXPECT_THAT(analysis->GetConstraints(), SizeIs(2)); + const ConstraintExpression& constraints = analysis->GetConstraints(); + EXPECT_THAT(constraints.DisjointConjointConstraints(), SizeIs(1)); + EXPECT_THAT(constraints.DisjointConjointConstraints().front(), SizeIs(2)); // We expect the constraints here to be // s0 mod 6 in [0, 0] @@ -496,7 +506,9 @@ ENTRY main { 
ASSERT_TRUE(analysis.has_value()); // Each bitcast in the above module introduces one constraint. Once they are // aggregated, we have two! - EXPECT_THAT(analysis->GetConstraints(), SizeIs(2)); + const ConstraintExpression& constraints = analysis->GetConstraints(); + EXPECT_THAT(constraints.DisjointConjointConstraints(), SizeIs(1)); + EXPECT_THAT(constraints.DisjointConjointConstraints().front(), SizeIs(2)); } TEST_F(SymbolicTileAnalysisTest, BailsOutWhenConstraintsCanNotBeMerged) { diff --git a/third_party/xla/xla/service/gpu/model/symbolic_tile_test.cc b/third_party/xla/xla/service/gpu/model/symbolic_tile_test.cc index b825efc4609812..e60a8fd7157bd8 100644 --- a/third_party/xla/xla/service/gpu/model/symbolic_tile_test.cc +++ b/third_party/xla/xla/service/gpu/model/symbolic_tile_test.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include +#include #include #include @@ -44,7 +45,11 @@ using ::mlir::AffineExpr; using ::mlir::AffineMap; using ::testing::ElementsAre; using ::testing::ExplainMatchResult; +using ::testing::IsEmpty; using ::testing::Optional; +using ::testing::SizeIs; + +using ConjointConstraints = ConstraintExpression::ConjointConstraints; MATCHER_P(MatchSymbolicTileString, symbolic_tile_string, "") { return ExplainMatchResult( @@ -716,7 +721,7 @@ TEST_F(SymbolicTileTest, CanCombineCompatibleConstraints) { size_map: ()[s0, s1] -> (1, (s0 + 5) floordiv 6, s0 - ((s0 - 1) floordiv 6) * 6, (s1 + 7) floordiv 8, s1 - ((s1 - 1) floordiv 8) * 8) stride_map: ()[s0, s1] -> (0, 1, 1, 1, 1) constraints: - s0 mod 6 in [0, 0] + s0 mod 6 in [0, 0] && s1 mod 8 in [0, 0] )"))); } @@ -782,6 +787,237 @@ TEST_F(SymbolicTileTest, CanDeriveTileWhenTheIndexingMapHasSymbolsInASum) { )"))); } +class ConstraintExpressionTest : public IndexingTestBase { + public: + using ConstraintVector = std::vector>; + + // Constructs a conjoint constraint from a vector of pairs containing a string + // representation of an affine expression and an interval. 
+ ConjointConstraints GetConjointConstraints( + ConstraintVector&& expr_and_interval_pairs) { + ConjointConstraints conjunction; + for (auto& [string_expr, interval] : expr_and_interval_pairs) { + conjunction.insert( + {ParseAffineExpr(string_expr, &mlir_context_), interval}); + } + return conjunction; + } +}; + +TEST_F(ConstraintExpressionTest, + DefaultConstructedConstraintExpressionIsAlwaysSatisfied) { + EXPECT_TRUE(ConstraintExpression().IsAlwaysSatisfied()); +} + +TEST_F(ConstraintExpressionTest, + UnsatisfiableConstraintExpressionHoldsNoConstraint) { + ConstraintExpression unsatisfiable_constraint = + ConstraintExpression::GetUnsatisfiableConstraintExpression(); + EXPECT_FALSE(unsatisfiable_constraint.is_satisfiable()); + EXPECT_THAT(unsatisfiable_constraint.DisjointConjointConstraints(), + IsEmpty()); +} + +TEST_F( + ConstraintExpressionTest, + CanSuccessfullyPerformConjunctionOfConstraintExpressionWithConjointConstraints) { // NOLINT(whitespace/line_length) + ConjointConstraints conjunction_1 = + GetConjointConstraints({{"d0", Interval{0, 5}}, {"d1", Interval{0, 5}}}); + ConjointConstraints conjunction_2 = + GetConjointConstraints({{"d2", Interval{0, 5}}}); + + ConstraintExpression constraints; + constraints.And(std::move(conjunction_1)); + constraints.And(std::move(conjunction_2)); + // Constraints can be merged without trouble, and hence the constraint + // expression is satisfiable. + EXPECT_TRUE(constraints.is_satisfiable()); + const auto& conjunctions = constraints.DisjointConjointConstraints(); + // There is a single conjunction in the disjoint expression. + EXPECT_THAT(conjunctions, SizeIs(1)); + // There are three constraints in the single conjunction. + EXPECT_THAT(conjunctions.front(), SizeIs(3)); + + // TODO(bchetioui): add test for the case where a conjunction becomes + // unsatisfiable and thus gets eliminated from the disjoint expression. 
+} + +TEST_F( + ConstraintExpressionTest, + CanSuccessfullyPerformDisjunctionOfConstraintExpressionWithConjointConstraints) { // NOLINT(whitespace/line_length) + ConjointConstraints conjunction_1 = + GetConjointConstraints({{"d0", Interval{0, 5}}, {"d1", Interval{0, 5}}}); + ConjointConstraints conjunction_2 = + GetConjointConstraints({{"d2", Interval{0, 5}}}); + + ConstraintExpression constraints; + constraints.Or(std::move(conjunction_1)); + constraints.Or(std::move(conjunction_2)); + EXPECT_TRUE(constraints.is_satisfiable()); + const auto& conjunctions = constraints.DisjointConjointConstraints(); + // There are now two conjunctions in the disjoint expression. + EXPECT_THAT(conjunctions, SizeIs(2)); + // There are two constraints in the first conjunction. + EXPECT_THAT(conjunctions.front(), SizeIs(2)); + // And one constraint in the second conjunction. + EXPECT_THAT(conjunctions.back(), SizeIs(1)); +} + +TEST_F( + ConstraintExpressionTest, + CanSuccessfullyPerformConjunctionOfConstraintExpressionWithConstraintExpression) { // NOLINT(whitespace/line_length) + // Construct the first `ConstraintExpression` to be of the form + // a || b. + ConjointConstraints conjunction_1 = + GetConjointConstraints({{"d0", Interval{0, 5}}}); + ConjointConstraints conjunction_2 = + GetConjointConstraints({{"d1", Interval{0, 5}}}); + ConstraintExpression constraints_1; + constraints_1.Or(std::move(conjunction_1)); + constraints_1.Or(std::move(conjunction_2)); + + // Construct the second `ConstraintExpression` to be of the form + // c || d || e. 
+ ConjointConstraints conjunction_3 = + GetConjointConstraints({{"d2", Interval{0, 5}}}); + ConjointConstraints conjunction_4 = + GetConjointConstraints({{"d3", Interval{0, 5}}}); + ConjointConstraints conjunction_5 = + GetConjointConstraints({{"d4", Interval{0, 5}}}); + ConstraintExpression constraints_2; + constraints_2.Or(std::move(conjunction_3)); + constraints_2.Or(std::move(conjunction_4)); + constraints_2.Or(std::move(conjunction_5)); + + // Taking the conjunction of the two `ConstraintExpression`s should result in + // a `ConstraintExpression` of the form + // a && c || a && d || a && e || b && c || b && d || b && e. + ConstraintExpression result_constraint_expression = + ConstraintExpression::And(std::move(constraints_1), constraints_2); + + EXPECT_TRUE(result_constraint_expression.is_satisfiable()); + // There are now six conjunctions in the disjoint expression, as described + // above. + EXPECT_THAT(result_constraint_expression.DisjointConjointConstraints(), + SizeIs(6)); + // And each of the conjunction consists only of two elements. + for (const ConjointConstraints& conjunction : + result_constraint_expression.DisjointConjointConstraints()) { + EXPECT_THAT(conjunction, SizeIs(2)); + } + + // Lastly, make sure that the conjunction of an empty `ConstraintExpression` + // with a non-empty one results in passing the non-empty one through, on both + // sides. + ConstraintExpression empty_constraints; + EXPECT_THAT(ConstraintExpression::And(empty_constraints, constraints_2) + .DisjointConjointConstraints(), + SizeIs(3)); + EXPECT_THAT(ConstraintExpression::And(std::move(constraints_2), + std::move(empty_constraints)) + .DisjointConjointConstraints(), + SizeIs(3)); +} + +TEST_F( + ConstraintExpressionTest, + CanSuccessfullyPerformDisjunctionOfConstraintExpressionWithConstraintExpression) { // NOLINT(whitespace/line_length) + // Construct the first `ConstraintExpression` to be of the form + // a || b. 
+ ConjointConstraints conjunction_1 = + GetConjointConstraints({{"d0", Interval{0, 5}}}); + ConjointConstraints conjunction_2 = + GetConjointConstraints({{"d1", Interval{0, 5}}}); + ConstraintExpression constraints_1; + constraints_1.Or(std::move(conjunction_1)); + constraints_1.Or(std::move(conjunction_2)); + + // Construct the second `ConstraintExpression` to be of the form + // c || d || e. + ConjointConstraints conjunction_3 = + GetConjointConstraints({{"d2", Interval{0, 5}}}); + ConjointConstraints conjunction_4 = + GetConjointConstraints({{"d3", Interval{0, 5}}}); + ConjointConstraints conjunction_5 = + GetConjointConstraints({{"d4", Interval{0, 5}}}); + ConstraintExpression constraints_2; + constraints_2.Or(std::move(conjunction_3)); + constraints_2.Or(std::move(conjunction_4)); + constraints_2.Or(std::move(conjunction_5)); + + // Taking the disjunction of the two `ConstraintExpression`s should result in + // a `ConstraintExpression` of the form + // a || b || c || d || e. + ConstraintExpression result_constraint_expression = ConstraintExpression::Or( + std::move(constraints_1), std::move(constraints_2)); + + EXPECT_TRUE(result_constraint_expression.is_satisfiable()); + // There are now five conjunctions in the disjoint expression, as described + // above. + EXPECT_THAT(result_constraint_expression.DisjointConjointConstraints(), + SizeIs(5)); + // And each of the conjunctions consists only of a single constraint. 
+ for (const ConjointConstraints& conjunction : + result_constraint_expression.DisjointConjointConstraints()) { + EXPECT_THAT(conjunction, SizeIs(1)); + } +} + +TEST_F( + ConstraintExpressionTest, + ConjunctionInvolvingUnsatisfiableConstraintExpressionIsUnsatisfiable) { // NOLINT(whitespace/line_length) + ConstraintExpression constraints = + ConstraintExpression::GetUnsatisfiableConstraintExpression(); + ConjointConstraints conjunction_1 = + GetConjointConstraints({{"d0", Interval{0, 5}}}); + + constraints.And(std::move(conjunction_1)); + EXPECT_FALSE(constraints.is_satisfiable()); + EXPECT_THAT(constraints.DisjointConjointConstraints(), IsEmpty()); +} + +TEST_F( + ConstraintExpressionTest, + DisjunctionInvolvingUnsatisfiableConstraintExpressionIsSatisfiable) { // NOLINT(whitespace/line_length) + ConstraintExpression constraints = + ConstraintExpression::GetUnsatisfiableConstraintExpression(); + ConjointConstraints conjunction_1 = + GetConjointConstraints({{"d0", Interval{0, 5}}}); + + // Try first with a single group of `ConjointConstraints`. + constraints.Or(conjunction_1); + EXPECT_TRUE(constraints.is_satisfiable()); + EXPECT_THAT(constraints.DisjointConjointConstraints(), SizeIs(1)); + + // Make sure this also works when constructing the conjunction from two + // `ConstraintExpression`s. 
+ ConstraintExpression constraints_1 = + ConstraintExpression::GetUnsatisfiableConstraintExpression(); + ConstraintExpression constraints_2; + constraints_2.Or(std::move(conjunction_1)); + + ConstraintExpression result_constraint_expression = ConstraintExpression::Or( + std::move(constraints_1), std::move(constraints_2)); + EXPECT_TRUE(result_constraint_expression.is_satisfiable()); + EXPECT_THAT(result_constraint_expression.DisjointConjointConstraints(), + SizeIs(1)); +} + +TEST_F( + ConstraintExpressionTest, + DisjunctionInvolvingTwoUnsatisfiableConstraintExpressionsIsUnsatisfiable) { // NOLINT(whitespace/line_length) + ConstraintExpression constraints_1 = + ConstraintExpression::GetUnsatisfiableConstraintExpression(); + ConstraintExpression constraints_2 = + ConstraintExpression::GetUnsatisfiableConstraintExpression(); + + EXPECT_FALSE( + ConstraintExpression::And(constraints_1, constraints_2).is_satisfiable()); +} + +// TODO(b/334043867): add support for intersecting constraints within a single +// conjunction. 
+ } // namespace } // namespace gpu } // namespace xla From c26e492fce7226f0d5b98d5265b26bbd1ac35eca Mon Sep 17 00:00:00 2001 From: Kuy Mainwaring Date: Fri, 21 Jun 2024 10:20:31 -0700 Subject: [PATCH 122/256] [XLA:GPU] Clang-tidy cleanup for xla/service/bfloat16_propagation_test.cc PiperOrigin-RevId: 645428274 --- third_party/xla/xla/service/BUILD | 6 ++++++ .../xla/xla/service/bfloat16_propagation_test.cc | 12 ++++++++++++ 2 files changed, 18 insertions(+) diff --git a/third_party/xla/xla/service/BUILD b/third_party/xla/xla/service/BUILD index 3cc8fe9e299a0d..093d373ee11212 100644 --- a/third_party/xla/xla/service/BUILD +++ b/third_party/xla/xla/service/BUILD @@ -436,6 +436,9 @@ xla_cc_test( deps = [ ":bfloat16_propagation", ":float_support", + ":hlo_verifier", + "//xla:comparison_util", + "//xla:literal_util", "//xla:shape_util", "//xla:test", "//xla:test_helpers", @@ -444,6 +447,9 @@ xla_cc_test( "//xla/tests:hlo_test_base", "//xla/tests:literal_test_util", "//xla/tests:xla_internal_test_main", # fixdeps: keep + "@com_google_absl//absl/log", + "@com_google_absl//absl/status:statusor", + "@local_tsl//tsl/platform:statusor", ], ) diff --git a/third_party/xla/xla/service/bfloat16_propagation_test.cc b/third_party/xla/xla/service/bfloat16_propagation_test.cc index 72f7be63a73426..53e76916ba98c9 100644 --- a/third_party/xla/xla/service/bfloat16_propagation_test.cc +++ b/third_party/xla/xla/service/bfloat16_propagation_test.cc @@ -15,17 +15,29 @@ limitations under the License. 
#include "xla/service/bfloat16_propagation.h" +#include +#include +#include + +#include "absl/log/log.h" +#include "absl/status/statusor.h" +#include "xla/comparison_util.h" +#include "xla/hlo/ir/collective_device_list.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/literal_util.h" #include "xla/service/float_support.h" +#include "xla/service/hlo_verifier.h" +#include "xla/shape.h" #include "xla/shape_util.h" #include "xla/test.h" #include "xla/test_helpers.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/literal_test_util.h" #include "xla/xla_data.pb.h" +#include "tsl/platform/statusor.h" namespace xla { From abcaba2e5fce335212d1f7617bdb4e3def525c23 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 21 Jun 2024 10:26:32 -0700 Subject: [PATCH 123/256] Move metadata_util to tensorflow/compiler/mlir/lite/experimental/remat PiperOrigin-RevId: 645430414 --- tensorflow/compiler/mlir/lite/BUILD | 18 ++++++++-- .../mlir/lite/experimental/remat/BUILD | 26 ++++++++++++++ .../lite/experimental/remat/metadata_util.cc | 2 +- .../lite/experimental/remat/metadata_util.h | 10 +++--- .../experimental/remat/metadata_util_test.cc | 2 +- .../compiler/mlir/lite/flatbuffer_export.cc | 9 ++--- .../compiler/mlir/lite/flatbuffer_import.cc | 2 +- .../mlir/lite/tf_to_tfl_flatbuffer.cc | 2 +- .../compiler/mlir/lite/utils/control_edges.h | 34 ++++++++++++++++++ tensorflow/lite/CMakeLists.txt | 2 ++ tensorflow/lite/core/BUILD | 10 +++--- tensorflow/lite/core/interpreter.cc | 2 +- tensorflow/lite/core/interpreter.h | 2 +- tensorflow/lite/delegates/BUILD | 2 +- tensorflow/lite/delegates/delegate_test.cc | 2 +- tensorflow/lite/experimental/remat/BUILD | 35 ------------------- tensorflow/lite/graph_info.h | 2 +- 17 files changed, 99 insertions(+), 63 deletions(-) rename tensorflow/{ => compiler/mlir}/lite/experimental/remat/metadata_util.cc (97%) rename 
tensorflow/{ => compiler/mlir}/lite/experimental/remat/metadata_util.h (86%) rename tensorflow/{ => compiler/mlir}/lite/experimental/remat/metadata_util_test.cc (95%) create mode 100644 tensorflow/compiler/mlir/lite/utils/control_edges.h delete mode 100644 tensorflow/lite/experimental/remat/BUILD diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD index 940cb0c47ce319..a6187e92719888 100644 --- a/tensorflow/compiler/mlir/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/BUILD @@ -1089,6 +1089,8 @@ cc_library( ":string_utils", ":tensorflow_lite", "//tensorflow/compiler/mlir:op_or_arg_name_mapper", + "//tensorflow/compiler/mlir/lite:control_edges", + "//tensorflow/compiler/mlir/lite/experimental/remat:metadata_util", "//tensorflow/compiler/mlir/lite/metrics:error_collector_inst", "//tensorflow/compiler/mlir/lite/quantization/ir:QuantOps", "//tensorflow/compiler/mlir/lite/schema:schema_conversion_utils", @@ -1105,7 +1107,6 @@ cc_library( "//tensorflow/lite/core:framework", "//tensorflow/lite/core/c:private_common", "//tensorflow/lite/delegates/flex:allowlisted_flex_ops_lib", - "//tensorflow/lite/experimental/remat:metadata_util", "//tensorflow/lite/python/metrics:converter_error_data_proto_cc", "//tensorflow/lite/toco:toco_flags_proto_cc", "//tensorflow/lite/tools/versioning", @@ -1148,6 +1149,7 @@ cc_library( ":offset_buffer", ":size_utils", ":tensorflow_lite", + "//tensorflow/compiler/mlir/lite/experimental/remat:metadata_util", "//tensorflow/compiler/mlir/lite/quantization/ir:QuantOps", "//tensorflow/compiler/mlir/lite/schema:schema_fbs_with_mutable", "//tensorflow/compiler/mlir/lite/schema:schema_utils", @@ -1164,7 +1166,6 @@ cc_library( "//tensorflow/core/platform:errors", "//tensorflow/core/platform:status", "//tensorflow/lite:framework", - "//tensorflow/lite/experimental/remat:metadata_util", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/status", @@ 
-1415,6 +1416,7 @@ cc_library( "//tensorflow/cc/saved_model:loader", "//tensorflow/compiler/mlir:op_or_arg_name_mapper", "//tensorflow/compiler/mlir/lite/debug", + "//tensorflow/compiler/mlir/lite/experimental/remat:metadata_util", "//tensorflow/compiler/mlir/lite/metrics:error_collector", "//tensorflow/compiler/mlir/lite/metrics:error_collector_inst", "//tensorflow/compiler/mlir/lite/quantization/stablehlo:quantization", @@ -1442,7 +1444,6 @@ cc_library( "//tensorflow/core:core_cpu_base", "//tensorflow/core:framework", "//tensorflow/core:protos_all_cc", - "//tensorflow/lite/experimental/remat:metadata_util", "//tensorflow/lite/python/metrics:converter_error_data_proto_cc", "//tensorflow/lite/toco:toco_flags_proto_cc", "//tensorflow/lite/tools/optimize:quantize_weights", @@ -1517,6 +1518,17 @@ cc_library( visibility = ["//tensorflow/lite:__pkg__"], ) +exports_files(srcs = ["utils/control_edges.h"]) + +cc_library( + name = "control_edges", + hdrs = ["utils/control_edges.h"], + visibility = [ + "//tensorflow/compiler/mlir/lite/experimental/remat:__pkg__", + "//tensorflow/lite:__pkg__", + ], +) + tf_cc_test( name = "offset_buffer_test", srcs = ["offset_buffer_test.cc"], diff --git a/tensorflow/compiler/mlir/lite/experimental/remat/BUILD b/tensorflow/compiler/mlir/lite/experimental/remat/BUILD index f0d059ca919196..089d5e695ea20a 100644 --- a/tensorflow/compiler/mlir/lite/experimental/remat/BUILD +++ b/tensorflow/compiler/mlir/lite/experimental/remat/BUILD @@ -1,4 +1,5 @@ load("//tensorflow:tensorflow.bzl", "tf_cc_test") +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") # copybara:uncomment package(default_applicable_licenses = ["//tensorflow:license"]) @@ -19,3 +20,28 @@ tf_cc_test( "@com_google_googletest//:gtest_main", ], ) + +cc_library( + name = "metadata_util", + srcs = ["metadata_util.cc"], + hdrs = ["metadata_util.h"], + compatible_with = get_compatible_with_portable(), + visibility = [ + "//tensorflow/compiler/mlir/lite:__pkg__", + 
"//tensorflow/lite/core:__pkg__", + "//tensorflow/lite/delegates:__pkg__", + ], + deps = [ + "//tensorflow/compiler/mlir/lite:control_edges", + ], +) + +tf_cc_test( + name = "metadata_util_test", + size = "small", + srcs = ["metadata_util_test.cc"], + deps = [ + ":metadata_util", + "@com_google_googletest//:gtest_main", + ], +) diff --git a/tensorflow/lite/experimental/remat/metadata_util.cc b/tensorflow/compiler/mlir/lite/experimental/remat/metadata_util.cc similarity index 97% rename from tensorflow/lite/experimental/remat/metadata_util.cc rename to tensorflow/compiler/mlir/lite/experimental/remat/metadata_util.cc index 71224865098c59..a580a64edf24fd 100644 --- a/tensorflow/lite/experimental/remat/metadata_util.cc +++ b/tensorflow/compiler/mlir/lite/experimental/remat/metadata_util.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/lite/experimental/remat/metadata_util.h" +#include "tensorflow/compiler/mlir/lite/experimental/remat/metadata_util.h" #include #include diff --git a/tensorflow/lite/experimental/remat/metadata_util.h b/tensorflow/compiler/mlir/lite/experimental/remat/metadata_util.h similarity index 86% rename from tensorflow/lite/experimental/remat/metadata_util.h rename to tensorflow/compiler/mlir/lite/experimental/remat/metadata_util.h index cf9edc751e5548..6036c468469a27 100644 --- a/tensorflow/lite/experimental/remat/metadata_util.h +++ b/tensorflow/compiler/mlir/lite/experimental/remat/metadata_util.h @@ -18,13 +18,15 @@ limitations under the License. /// information to/from model metadata. 
/// -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_REMAT_METADATA_UTIL_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_REMAT_METADATA_UTIL_H_ +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_REMAT_METADATA_UTIL_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_REMAT_METADATA_UTIL_H_ +#include +#include #include #include -#include "tensorflow/lite/graph_info.h" +#include "tensorflow/compiler/mlir/lite/utils/control_edges.h" namespace tflite { @@ -57,4 +59,4 @@ inline constexpr char kModelUseStablehloTensorKey[] = "keep_stablehlo_constant"; } // namespace tflite -#endif // TENSORFLOW_LITE_EXPERIMENTAL_REMAT_METADATA_UTIL_H_ +#endif // TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_REMAT_METADATA_UTIL_H_ diff --git a/tensorflow/lite/experimental/remat/metadata_util_test.cc b/tensorflow/compiler/mlir/lite/experimental/remat/metadata_util_test.cc similarity index 95% rename from tensorflow/lite/experimental/remat/metadata_util_test.cc rename to tensorflow/compiler/mlir/lite/experimental/remat/metadata_util_test.cc index 1773e2029d5a64..abb6a674b5dbd8 100644 --- a/tensorflow/lite/experimental/remat/metadata_util_test.cc +++ b/tensorflow/compiler/mlir/lite/experimental/remat/metadata_util_test.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/lite/experimental/remat/metadata_util.h" +#include "tensorflow/compiler/mlir/lite/experimental/remat/metadata_util.h" #include #include diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc index 0c6c56cf001a0c..d1868519766686 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc @@ -81,12 +81,14 @@ limitations under the License. 
#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "stablehlo/dialect/StablehloOps.h" // from @stablehlo #include "stablehlo/dialect/VhloOps.h" // from @stablehlo +#include "tensorflow/compiler/mlir/lite/experimental/remat/metadata_util.h" #include "tensorflow/compiler/mlir/lite/flatbuffer_operator.h" #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/metrics/error_collector_inst.h" #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" #include "tensorflow/compiler/mlir/lite/schema/mutable/schema_generated.h" #include "tensorflow/compiler/mlir/lite/schema/schema_conversion_utils.h" +#include "tensorflow/compiler/mlir/lite/utils/control_edges.h" #include "tensorflow/compiler/mlir/lite/utils/convert_type.h" #include "tensorflow/compiler/mlir/lite/utils/low_bit_utils.h" #include "tensorflow/compiler/mlir/lite/utils/stateful_ops_utils.h" @@ -110,7 +112,6 @@ limitations under the License. #include "tensorflow/lite/core/interpreter.h" #include "tensorflow/lite/core/macros.h" #include "tensorflow/lite/delegates/flex/allowlisted_flex_ops.h" -#include "tensorflow/lite/experimental/remat/metadata_util.h" #include "tensorflow/lite/python/metrics/converter_error_data.pb.h" #include "tensorflow/lite/toco/toco_flags.pb.h" #include "tensorflow/lite/tools/versioning/gpu_compatibility.h" @@ -154,12 +155,6 @@ using VectorBufferOffset = flatbuffers::Offset>; using CustomOptionsOffset = VectorBufferOffset; -// LINT.IfChange -// Node edge.second depends on node edge.first. 
-using ControlEdge = std::pair; -using ControlEdges = std::vector; -// LINT.ThenChange(//tensorflow/lite/graph_info.h) - namespace tfl = mlir::TFL; ABSL_CONST_INIT const absl::string_view kFlexOpNamePrefix = "Flex"; diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_import.cc b/tensorflow/compiler/mlir/lite/flatbuffer_import.cc index b51a5feb156964..2adeb5727cd15f 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_import.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_import.cc @@ -73,6 +73,7 @@ limitations under the License. #include "mlir/Tools/mlir-translate/Translation.h" // from @llvm-project #include "stablehlo/dialect/StablehloOps.h" // from @stablehlo #include "stablehlo/dialect/VhloOps.h" // from @stablehlo +#include "tensorflow/compiler/mlir/lite/experimental/remat/metadata_util.h" #include "tensorflow/compiler/mlir/lite/flatbuffer_operator.h" #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/offset_buffer.h" @@ -95,7 +96,6 @@ limitations under the License. #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/status.h" -#include "tensorflow/lite/experimental/remat/metadata_util.h" #include "tensorflow/lite/graph_info.h" #include "tensorflow/lite/model_builder.h" #include "tsl/platform/status.h" diff --git a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc index 19f4de9a0395b2..0468d93b0c2438 100644 --- a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc +++ b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc @@ -46,6 +46,7 @@ limitations under the License. 
#include "tensorflow/cc/saved_model/loader.h" #include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h" #include "tensorflow/compiler/mlir/lite/debug/debug.h" +#include "tensorflow/compiler/mlir/lite/experimental/remat/metadata_util.h" #include "tensorflow/compiler/mlir/lite/flatbuffer_export.h" #include "tensorflow/compiler/mlir/lite/metrics/error_collector.h" #include "tensorflow/compiler/mlir/lite/metrics/error_collector_inst.h" @@ -78,7 +79,6 @@ limitations under the License. #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/protobuf/meta_graph.pb.h" #include "tensorflow/core/public/session.h" -#include "tensorflow/lite/experimental/remat/metadata_util.h" #include "tensorflow/lite/python/metrics/converter_error_data.pb.h" #include "tensorflow/lite/tools/optimize/quantize_weights.h" #include "tensorflow/lite/tools/optimize/reduced_precision_support.h" diff --git a/tensorflow/compiler/mlir/lite/utils/control_edges.h b/tensorflow/compiler/mlir/lite/utils/control_edges.h new file mode 100644 index 00000000000000..e5a16ba7e6f7fd --- /dev/null +++ b/tensorflow/compiler/mlir/lite/utils/control_edges.h @@ -0,0 +1,34 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_UTILS_CONTROL_EDGES_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_UTILS_CONTROL_EDGES_H_ + +#include +#include +#include + +namespace tflite { + +// LINT.IfChange + +using ControlEdge = std::pair; +using ControlEdges = std::vector; + +// LINT.ThenChange(//tensorflow/lite/graph_info.h) + +} // namespace tflite + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_UTILS_CONTROL_EDGES_H_ diff --git a/tensorflow/lite/CMakeLists.txt b/tensorflow/lite/CMakeLists.txt index ab74a09c1acf3e..7d28e94619ae17 100644 --- a/tensorflow/lite/CMakeLists.txt +++ b/tensorflow/lite/CMakeLists.txt @@ -657,6 +657,8 @@ set(_ALL_TFLITE_SRCS ${TF_SOURCE_DIR}/compiler/mlir/lite/schema/schema_generated.h ${TF_SOURCE_DIR}/compiler/mlir/lite/utils/string_utils.cc ${TF_SOURCE_DIR}/compiler/mlir/lite/utils/string_utils.h + ${TF_SOURCE_DIR}/compiler/mlir/lite/experimental/remat/metadata_util.h + ${TF_SOURCE_DIR}/compiler/mlir/lite/experimental/remat/metadata_util.cc ${TFLITE_SOURCE_DIR}/schema/schema_generated.h ) add_library(tensorflow-lite diff --git a/tensorflow/lite/core/BUILD b/tensorflow/lite/core/BUILD index b1b1db307fa665..375e63b84d7565 100644 --- a/tensorflow/lite/core/BUILD +++ b/tensorflow/lite/core/BUILD @@ -49,6 +49,7 @@ cc_library( deps = [ ":cc_api_stable", ":signature_runner", + "//tensorflow/compiler/mlir/lite/experimental/remat:metadata_util", "//tensorflow/lite:allocation", "//tensorflow/lite:array", "//tensorflow/lite:external_cpu_backend_context", @@ -66,7 +67,6 @@ cc_library( "//tensorflow/lite/core/api:verifier", "//tensorflow/lite/core/async:async_signature_runner", "//tensorflow/lite/core/c:common", - "//tensorflow/lite/experimental/remat:metadata_util", "//tensorflow/lite/experimental/resource", "//tensorflow/lite/internal:signature_def", "//tensorflow/lite/profiling:root_profiler", @@ -120,6 +120,7 @@ cc_library( ":cc_api_stable", ":model_builder", 
":signature_runner", + "//tensorflow/compiler/mlir/lite/experimental/remat:metadata_util", "//tensorflow/lite:allocation", "//tensorflow/lite:array", "//tensorflow/lite:external_cpu_backend_context", @@ -137,7 +138,6 @@ cc_library( "//tensorflow/lite/core/api:verifier", "//tensorflow/lite/core/async:async_signature_runner", "//tensorflow/lite/core/c:common", - "//tensorflow/lite/experimental/remat:metadata_util", "//tensorflow/lite/experimental/resource", "//tensorflow/lite/internal:signature_def", "//tensorflow/lite/profiling:root_profiler", @@ -174,6 +174,7 @@ cc_library( ":model_builder", ":signature_runner", ":subgraph", + "//tensorflow/compiler/mlir/lite/experimental/remat:metadata_util", "//tensorflow/lite:allocation", "//tensorflow/lite:array", "//tensorflow/lite:external_cpu_backend_context", @@ -200,7 +201,6 @@ cc_library( "//tensorflow/lite/delegates:telemetry", "//tensorflow/lite/delegates/xnnpack:tflite_with_xnnpack_qs8", "//tensorflow/lite/delegates/xnnpack:tflite_with_xnnpack_qu8", - "//tensorflow/lite/experimental/remat:metadata_util", "//tensorflow/lite/experimental/resource", "//tensorflow/lite/internal:signature_def", "//tensorflow/lite/kernels/internal:compatibility", @@ -253,6 +253,7 @@ cc_library( deps = [ ":cc_api_stable", ":signature_runner", + "//tensorflow/compiler/mlir/lite/experimental/remat:metadata_util", "//tensorflow/lite:allocation", "//tensorflow/lite:array", "//tensorflow/lite:external_cpu_backend_context", @@ -271,7 +272,6 @@ cc_library( "//tensorflow/lite/core/async:async_signature_runner", "//tensorflow/lite/core/c:c_api_types", "//tensorflow/lite/core/c:common", - "//tensorflow/lite/experimental/remat:metadata_util", "//tensorflow/lite/experimental/resource", "//tensorflow/lite/internal:signature_def", "//tensorflow/lite/profiling:root_profiler", @@ -467,6 +467,7 @@ cc_library( "//tensorflow/lite/kernels:__subpackages__", ], deps = [ + "//tensorflow/compiler/mlir/lite/experimental/remat:metadata_util", 
"//tensorflow/lite:allocation", "//tensorflow/lite:array", "//tensorflow/lite:graph_info", @@ -480,7 +481,6 @@ cc_library( "//tensorflow/lite/core/api", "//tensorflow/lite/core/c:c_api_types", "//tensorflow/lite/core/c:common", - "//tensorflow/lite/experimental/remat:metadata_util", "//tensorflow/lite/experimental/resource", "//tensorflow/lite/profiling:root_profiler", "//tensorflow/lite/profiling/telemetry", diff --git a/tensorflow/lite/core/interpreter.cc b/tensorflow/lite/core/interpreter.cc index 15b5d18d272b5d..9d1623c5b4821f 100644 --- a/tensorflow/lite/core/interpreter.cc +++ b/tensorflow/lite/core/interpreter.cc @@ -27,13 +27,13 @@ limitations under the License. #include #include "ruy/denormal.h" // from @ruy +#include "tensorflow/compiler/mlir/lite/experimental/remat/metadata_util.h" #include "tensorflow/lite/allocation.h" #include "tensorflow/lite/core/api/error_reporter.h" #include "tensorflow/lite/core/api/profiler.h" #include "tensorflow/lite/core/c/c_api_types.h" #include "tensorflow/lite/core/signature_runner.h" #include "tensorflow/lite/core/subgraph.h" -#include "tensorflow/lite/experimental/remat/metadata_util.h" #include "tensorflow/lite/external_cpu_backend_context.h" #include "tensorflow/lite/interpreter_options.h" #include "tensorflow/lite/logger.h" diff --git a/tensorflow/lite/core/interpreter.h b/tensorflow/lite/core/interpreter.h index ed9d798f34753b..c83e68f2cd9ca2 100644 --- a/tensorflow/lite/core/interpreter.h +++ b/tensorflow/lite/core/interpreter.h @@ -46,7 +46,7 @@ limitations under the License. 
#include "tensorflow/lite/core/c/common.h" // IWYU pragma: export #include "tensorflow/lite/core/signature_runner.h" #include "tensorflow/lite/core/subgraph.h" -#include "tensorflow/lite/experimental/remat/metadata_util.h" +#include "tensorflow/compiler/mlir/lite/experimental/remat/metadata_util.h" #include "tensorflow/lite/experimental/resource/initialization_status.h" #include "tensorflow/lite/experimental/resource/resource_base.h" #include "tensorflow/lite/external_cpu_backend_context.h" diff --git a/tensorflow/lite/delegates/BUILD b/tensorflow/lite/delegates/BUILD index 6e171eacff486c..cf619553fedc08 100644 --- a/tensorflow/lite/delegates/BUILD +++ b/tensorflow/lite/delegates/BUILD @@ -125,6 +125,7 @@ cc_test( ":delegate_test_util", ":interpreter_utils", ":utils", + "//tensorflow/compiler/mlir/lite/experimental/remat:metadata_util", "//tensorflow/lite:builtin_ops", "//tensorflow/lite:framework", "//tensorflow/lite:kernel_api", @@ -136,7 +137,6 @@ cc_test( "//tensorflow/lite/core/c:c_api_types", "//tensorflow/lite/core/c:common", "//tensorflow/lite/core/kernels:builtin_ops", - "//tensorflow/lite/experimental/remat:metadata_util", "//tensorflow/lite/kernels:builtin_ops", "//tensorflow/lite/kernels:kernel_util", "//tensorflow/lite/kernels/internal:compatibility", diff --git a/tensorflow/lite/delegates/delegate_test.cc b/tensorflow/lite/delegates/delegate_test.cc index 8d37bfa38c8af7..f4bc7fd9f654da 100644 --- a/tensorflow/lite/delegates/delegate_test.cc +++ b/tensorflow/lite/delegates/delegate_test.cc @@ -24,6 +24,7 @@ limitations under the License. #include #include "flatbuffers/flatbuffers.h" // from @flatbuffers +#include "tensorflow/compiler/mlir/lite/experimental/remat/metadata_util.h" #include "tensorflow/lite/builtin_ops.h" #include "tensorflow/lite/core/c/c_api_opaque.h" #include "tensorflow/lite/core/c/c_api_types.h" @@ -31,7 +32,6 @@ limitations under the License. 
#include "tensorflow/lite/core/interpreter_builder.h" #include "tensorflow/lite/core/kernels/register.h" #include "tensorflow/lite/delegates/delegate_test_util.h" -#include "tensorflow/lite/experimental/remat/metadata_util.h" #include "tensorflow/lite/interpreter.h" #include "tensorflow/lite/kernels/internal/compatibility.h" #include "tensorflow/lite/kernels/kernel_util.h" diff --git a/tensorflow/lite/experimental/remat/BUILD b/tensorflow/lite/experimental/remat/BUILD deleted file mode 100644 index 2860b4c75f32b8..00000000000000 --- a/tensorflow/lite/experimental/remat/BUILD +++ /dev/null @@ -1,35 +0,0 @@ -load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") - -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = [":friends"], - licenses = ["notice"], -) - -package_group( - name = "friends", - packages = [ - "//tensorflow/compiler/mlir/lite/...", - "//tensorflow/lite/...", - ], -) - -cc_library( - name = "metadata_util", - srcs = ["metadata_util.cc"], - hdrs = ["metadata_util.h"], - compatible_with = get_compatible_with_portable(), - deps = [ - "//tensorflow/lite:graph_info", - ], -) - -cc_test( - name = "metadata_util_test", - size = "small", - srcs = ["metadata_util_test.cc"], - deps = [ - ":metadata_util", - "@com_google_googletest//:gtest_main", - ], -) diff --git a/tensorflow/lite/graph_info.h b/tensorflow/lite/graph_info.h index 0370a12c5c17e2..532ea8040c2938 100644 --- a/tensorflow/lite/graph_info.h +++ b/tensorflow/lite/graph_info.h @@ -96,7 +96,7 @@ struct NodeSubset { // Node edge.second depends on node edge.first. using ControlEdge = std::pair; using ControlEdges = std::vector; -// LINT.ThenChange(//tensorflow/compiler/mlir/lite/flatbuffer_export.cc) +// LINT.ThenChange(//tensorflow/compiler/mlir/lite/utils/control_edges.h) // Partitions a list of node indices `nodes_to_partition` into node subsets. // Each node subset is in dependency order internally (i.e. 
all members of the From ff06f4d71b5cd5a8b05d2544ce2627a889b0a90c Mon Sep 17 00:00:00 2001 From: Kyle Lucke Date: Fri, 21 Jun 2024 11:33:27 -0700 Subject: [PATCH 124/256] Stop using xla/statusor.h now that it just contains an alias for absl::Status. In some situations, this meant also changing unrelated files to directly include tsl/platform/statusor.h to get the definitions for TF_ASSIGN_OR_RETURN, etc., where they were getting transitively included for free. PiperOrigin-RevId: 645452258 --- third_party/xla/xla/service/cpu/BUILD | 21 ++++++++----------- .../xla/xla/service/cpu/compiler_functor.cc | 2 +- .../xla/xla/service/cpu/cpu_compiler.h | 2 +- .../xla/xla/service/cpu/cpu_runtime.cc | 1 - .../xla/service/cpu/cpu_transfer_manager.cc | 2 +- .../xla/service/cpu/cpu_transfer_manager.h | 2 +- third_party/xla/xla/service/cpu/cpu_xfeed.cc | 2 +- third_party/xla/xla/service/cpu/ir_emitter.h | 2 +- third_party/xla/xla/service/cpu/ir_function.h | 2 +- .../service/cpu/parallel_task_assignment.cc | 2 +- .../service/cpu/parallel_task_assignment.h | 2 +- .../xla/xla/service/cpu/sample_harness.cc | 2 +- .../xla/xla/service/cpu/xfeed_manager.h | 2 +- 13 files changed, 20 insertions(+), 24 deletions(-) diff --git a/third_party/xla/xla/service/cpu/BUILD b/third_party/xla/xla/service/cpu/BUILD index 436535e1c405ac..10a56c4e2c4c2f 100644 --- a/third_party/xla/xla/service/cpu/BUILD +++ b/third_party/xla/xla/service/cpu/BUILD @@ -146,7 +146,6 @@ cc_library( "//xla:literal_util", "//xla:shape_util", "//xla:status_macros", - "//xla:statusor", "//xla:types", "//xla:util", "//xla/service:hlo_cost_analysis", @@ -154,6 +153,7 @@ cc_library( "@com_google_absl//absl/base", "@com_google_absl//absl/cleanup", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:notification", @@ -171,7 +171,6 @@ cc_library( "//xla:literal_util", "//xla:shape_util", "//xla:status_macros", - 
"//xla:statusor", "//xla:types", "//xla:util", "//xla:xla_data_proto_cc", @@ -182,6 +181,7 @@ cc_library( "//xla/stream_executor:platform_manager", "//xla/stream_executor/host:host_platform_id", "@com_google_absl//absl/base", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", @@ -233,7 +233,6 @@ cc_library( "//xla:protobuf_util", "//xla:shape_util", "//xla:status_macros", - "//xla:statusor", "//xla:types", "//xla:util", "//xla:xla_data_proto_cc", @@ -406,7 +405,6 @@ cc_library( ":target_machine_features", ":xla_framework", "//xla:cpu_function_runtime", - "//xla:statusor", "//xla:util", "//xla/hlo/ir:hlo", "//xla/hlo/ir:hlo_module_group", @@ -420,6 +418,7 @@ cc_library( "//xla/stream_executor", "//xla/stream_executor/host:host_platform_id", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@llvm-project//llvm:Target", ], alwayslink = True, # Contains compiler registration @@ -550,7 +549,6 @@ cc_library( "//xla:shape_tree", "//xla:shape_util", "//xla:status_macros", - "//xla:statusor", "//xla:types", "//xla:util", "//xla:xla_data_proto_cc", @@ -696,7 +694,6 @@ cc_library( "//xla:literal_util", "//xla:shape_util", "//xla:status_macros", - "//xla:statusor", "//xla:types", "//xla:util", "//xla:window_util", @@ -724,6 +721,7 @@ cc_library( "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/meta:type_traits", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", @@ -772,10 +770,10 @@ cc_library( ":shape_partition", "//xla:shape_util", "//xla:status_macros", - "//xla:statusor", "//xla:types", "//xla/service:hlo_module_config", "//xla/service/llvm_ir:llvm_util", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", 
"@com_google_absl//absl/types:span", "@llvm-project//llvm:Core", @@ -911,7 +909,6 @@ xla_cc_binary( deps = [ "//xla:array4d", "//xla:literal", - "//xla:statusor", "//xla:types", "//xla:xla_data_proto_cc", "//xla/client", @@ -920,6 +917,7 @@ xla_cc_binary( "//xla/client:local_client", "//xla/client:xla_builder", "//xla/client:xla_computation", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:platform_port", @@ -933,12 +931,12 @@ cc_library( deps = [ ":cpu_runtime", ":llvm_ir_runtime", - "//xla:statusor", "//xla:types", "//xla:util", "//xla/service:llvm_compiler", "//xla/service/llvm_ir:llvm_util", "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/status:statusor", "@llvm-project//llvm:Analysis", "@llvm-project//llvm:Core", "@llvm-project//llvm:Instrumentation", @@ -969,7 +967,6 @@ cc_library( ":in_process_collectives", "//xla:executable_run_options", "//xla:shape_util", - "//xla:statusor", "//xla:types", "//xla:util", "//xla:xla_data_proto_cc", @@ -1539,7 +1536,6 @@ cc_library( ":ir_emission_utils", ":shape_partition", ":target_machine_features", - "//xla:statusor", "//xla:util", "//xla/hlo/ir:hlo", "//xla/service:hlo_cost_analysis", @@ -1549,6 +1545,7 @@ cc_library( "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:platform_port", @@ -1676,7 +1673,6 @@ cc_library( "//xla:literal", "//xla:shape_util", "//xla:status_macros", - "//xla:statusor", "//xla:types", "//xla:util", "//xla:xla_data_proto_cc", @@ -1686,6 +1682,7 @@ cc_library( "//xla/service/llvm_ir:llvm_util", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/base:dynamic_annotations", + "@com_google_absl//absl/status:statusor", 
"@com_google_absl//absl/types:span", "@llvm-project//llvm:Core", "@llvm-project//llvm:TargetParser", diff --git a/third_party/xla/xla/service/cpu/compiler_functor.cc b/third_party/xla/xla/service/cpu/compiler_functor.cc index b481ba24bd9418..756f43e371100a 100644 --- a/third_party/xla/xla/service/cpu/compiler_functor.cc +++ b/third_party/xla/xla/service/cpu/compiler_functor.cc @@ -22,6 +22,7 @@ limitations under the License. #include #include +#include "absl/status/statusor.h" #include "llvm/ADT/StringRef.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/TargetTransformInfo.h" @@ -38,7 +39,6 @@ limitations under the License. #include "xla/service/cpu/cpu_runtime.h" #include "xla/service/cpu/llvm_ir_runtime.h" #include "xla/service/llvm_ir/llvm_util.h" -#include "xla/statusor.h" #include "xla/types.h" #include "xla/util.h" #include "tsl/platform/logging.h" diff --git a/third_party/xla/xla/service/cpu/cpu_compiler.h b/third_party/xla/xla/service/cpu/cpu_compiler.h index 87fc7eb82eae15..a2984088f395d6 100644 --- a/third_party/xla/xla/service/cpu/cpu_compiler.h +++ b/third_party/xla/xla/service/cpu/cpu_compiler.h @@ -23,6 +23,7 @@ limitations under the License. #include #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "llvm/Target/TargetMachine.h" #include "xla/cpu_function_runtime.h" #include "xla/hlo/ir/hlo_module.h" @@ -37,7 +38,6 @@ limitations under the License. #include "xla/service/hlo_cost_analysis.h" #include "xla/service/hlo_profile_printer_data.pb.h" #include "xla/service/llvm_compiler.h" -#include "xla/statusor.h" #include "xla/stream_executor/stream_executor.h" #include "xla/util.h" diff --git a/third_party/xla/xla/service/cpu/cpu_runtime.cc b/third_party/xla/xla/service/cpu/cpu_runtime.cc index 90597f7ef3ef38..05125edd3d5f62 100644 --- a/third_party/xla/xla/service/cpu/cpu_runtime.cc +++ b/third_party/xla/xla/service/cpu/cpu_runtime.cc @@ -49,7 +49,6 @@ limitations under the License. 
#include "xla/service/global_device_id.h" #include "xla/service/hlo_parser.h" #include "xla/shape_util.h" -#include "xla/statusor.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/stream_executor.h" #include "xla/util.h" diff --git a/third_party/xla/xla/service/cpu/cpu_transfer_manager.cc b/third_party/xla/xla/service/cpu/cpu_transfer_manager.cc index fd68f2996357f6..98eb8851f658eb 100644 --- a/third_party/xla/xla/service/cpu/cpu_transfer_manager.cc +++ b/third_party/xla/xla/service/cpu/cpu_transfer_manager.cc @@ -21,6 +21,7 @@ limitations under the License. #include #include "absl/base/casts.h" +#include "absl/status/statusor.h" #include "xla/literal.h" #include "xla/literal_util.h" #include "xla/service/compiler.h" @@ -28,7 +29,6 @@ limitations under the License. #include "xla/service/cpu/cpu_xfeed.h" #include "xla/shape_util.h" #include "xla/status_macros.h" -#include "xla/statusor.h" #include "xla/stream_executor/host/host_platform_id.h" #include "xla/stream_executor/platform_manager.h" #include "xla/stream_executor/stream_executor.h" diff --git a/third_party/xla/xla/service/cpu/cpu_transfer_manager.h b/third_party/xla/xla/service/cpu/cpu_transfer_manager.h index ed7f0fe3f7b383..d131ef57f7c485 100644 --- a/third_party/xla/xla/service/cpu/cpu_transfer_manager.h +++ b/third_party/xla/xla/service/cpu/cpu_transfer_manager.h @@ -18,12 +18,12 @@ limitations under the License. 
#include +#include "absl/status/statusor.h" #include "absl/types/span.h" #include "xla/literal.h" #include "xla/service/cpu/xfeed_manager.h" #include "xla/service/generic_transfer_manager.h" #include "xla/service/transfer_manager.h" -#include "xla/statusor.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/stream_executor.h" #include "xla/xla_data.pb.h" diff --git a/third_party/xla/xla/service/cpu/cpu_xfeed.cc b/third_party/xla/xla/service/cpu/cpu_xfeed.cc index 332ce3b1c126fb..8d95ae7e1b3a80 100644 --- a/third_party/xla/xla/service/cpu/cpu_xfeed.cc +++ b/third_party/xla/xla/service/cpu/cpu_xfeed.cc @@ -24,6 +24,7 @@ limitations under the License. #include "absl/base/casts.h" #include "absl/cleanup/cleanup.h" +#include "absl/status/statusor.h" #include "xla/literal.h" #include "xla/literal_util.h" #include "xla/service/cpu/cpu_runtime.h" @@ -32,7 +33,6 @@ limitations under the License. #include "xla/service/shaped_buffer.h" #include "xla/shape_util.h" #include "xla/status_macros.h" -#include "xla/statusor.h" #include "xla/types.h" #include "xla/util.h" #include "tsl/platform/errors.h" diff --git a/third_party/xla/xla/service/cpu/ir_emitter.h b/third_party/xla/xla/service/cpu/ir_emitter.h index 9d7b4a65dfd81d..7e34bbf9e7b88c 100644 --- a/third_party/xla/xla/service/cpu/ir_emitter.h +++ b/third_party/xla/xla/service/cpu/ir_emitter.h @@ -27,6 +27,7 @@ limitations under the License. #include #include "absl/container/flat_hash_map.h" +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "llvm/IR/Attributes.h" @@ -48,7 +49,6 @@ limitations under the License. 
#include "xla/service/llvm_ir/ir_builder_mixin.h" #include "xla/service/llvm_ir/loop_emitter.h" #include "xla/service/name_uniquer.h" -#include "xla/statusor.h" #include "xla/xla_data.pb.h" namespace xla { diff --git a/third_party/xla/xla/service/cpu/ir_function.h b/third_party/xla/xla/service/cpu/ir_function.h index 73a85867438a9d..a9ae4ce1a817a2 100644 --- a/third_party/xla/xla/service/cpu/ir_function.h +++ b/third_party/xla/xla/service/cpu/ir_function.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef XLA_SERVICE_CPU_IR_FUNCTION_H_ #define XLA_SERVICE_CPU_IR_FUNCTION_H_ +#include "absl/status/statusor.h" #include "absl/types/span.h" #include "llvm/IR/Function.h" #include "llvm/IR/IRBuilder.h" @@ -24,7 +25,6 @@ limitations under the License. #include "xla/service/cpu/ir_emission_utils.h" #include "xla/service/hlo_module_config.h" #include "xla/shape_util.h" -#include "xla/statusor.h" #include "xla/types.h" namespace xla { diff --git a/third_party/xla/xla/service/cpu/parallel_task_assignment.cc b/third_party/xla/xla/service/cpu/parallel_task_assignment.cc index 60d6fd7c1c966a..ab55629e1fd578 100644 --- a/third_party/xla/xla/service/cpu/parallel_task_assignment.cc +++ b/third_party/xla/xla/service/cpu/parallel_task_assignment.cc @@ -25,6 +25,7 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/container/flat_hash_set.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_computation.h" @@ -36,7 +37,6 @@ limitations under the License. 
#include "xla/service/cpu/target_machine_features.h" #include "xla/service/hlo_cost_analysis.h" #include "xla/service/llvm_ir/dynamic_update_slice_util.h" -#include "xla/statusor.h" #include "xla/util.h" #include "tsl/platform/cpu_info.h" #include "tsl/platform/logging.h" // IWYU pragma: keep diff --git a/third_party/xla/xla/service/cpu/parallel_task_assignment.h b/third_party/xla/xla/service/cpu/parallel_task_assignment.h index 2fad81a5d9b870..e523323262b793 100644 --- a/third_party/xla/xla/service/cpu/parallel_task_assignment.h +++ b/third_party/xla/xla/service/cpu/parallel_task_assignment.h @@ -21,6 +21,7 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -28,7 +29,6 @@ limitations under the License. #include "xla/service/cpu/target_machine_features.h" #include "xla/service/hlo_cost_analysis.h" #include "xla/service/hlo_pass_interface.h" -#include "xla/statusor.h" #include "xla/util.h" namespace xla { diff --git a/third_party/xla/xla/service/cpu/sample_harness.cc b/third_party/xla/xla/service/cpu/sample_harness.cc index 1f21f22d2846a6..f43ff05bcb237a 100644 --- a/third_party/xla/xla/service/cpu/sample_harness.cc +++ b/third_party/xla/xla/service/cpu/sample_harness.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include +#include "absl/status/statusor.h" #include "absl/strings/str_format.h" #include "xla/array4d.h" #include "xla/client/client.h" @@ -26,7 +27,6 @@ limitations under the License. 
#include "xla/client/xla_builder.h" #include "xla/client/xla_computation.h" #include "xla/literal.h" -#include "xla/statusor.h" #include "xla/types.h" #include "xla/xla_data.pb.h" #include "tsl/platform/init_main.h" diff --git a/third_party/xla/xla/service/cpu/xfeed_manager.h b/third_party/xla/xla/service/cpu/xfeed_manager.h index e29d3b7d945eb9..054e04743c4c85 100644 --- a/third_party/xla/xla/service/cpu/xfeed_manager.h +++ b/third_party/xla/xla/service/cpu/xfeed_manager.h @@ -22,9 +22,9 @@ limitations under the License. #include +#include "absl/status/statusor.h" #include "absl/types/span.h" #include "xla/shape.h" -#include "xla/statusor.h" #include "xla/types.h" #include "xla/xla_data.pb.h" From ba946a9ddc7b2b0377f391675c59c3685b536007 Mon Sep 17 00:00:00 2001 From: Seher Ellis Date: Fri, 21 Jun 2024 11:45:09 -0700 Subject: [PATCH 125/256] [XLA] Support broadcast as a formatting op in collective pipeliner. PiperOrigin-RevId: 645455530 --- third_party/xla/xla/service/BUILD | 2 + .../xla/xla/service/collective_pipeliner.cc | 7 +- .../xla/service/collective_pipeliner_test.cc | 79 +++++++++++++++++++ 3 files changed, 86 insertions(+), 2 deletions(-) diff --git a/third_party/xla/xla/service/BUILD b/third_party/xla/xla/service/BUILD index 093d373ee11212..54657169ed0d78 100644 --- a/third_party/xla/xla/service/BUILD +++ b/third_party/xla/xla/service/BUILD @@ -610,9 +610,11 @@ xla_cc_test( "//xla/hlo/utils:hlo_matchers", "//xla/tests:filecheck", "//xla/tests:hlo_test_base", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest_main", ], ) diff --git a/third_party/xla/xla/service/collective_pipeliner.cc b/third_party/xla/xla/service/collective_pipeliner.cc index 998cbc634b0c9e..7faef4e5330872 100644 --- a/third_party/xla/xla/service/collective_pipeliner.cc +++ 
b/third_party/xla/xla/service/collective_pipeliner.cc @@ -321,7 +321,8 @@ CheckStoreIntoSliceIsCompatible(HloInstruction* instr, return HloPredicateIsOp(i) || + HloOpcode::kAllReduce, HloOpcode::kTranspose, + HloOpcode::kBroadcast>(i) || (multi_uses_pipelining && i->IsElementwise()) || i->IsCustomCall(CollectivePipeliner::kInsertedByPreviousStep); }; @@ -2121,9 +2122,11 @@ absl::Status TransformLoopForwardSink(const WhileLoopAnalysis& loop_analysis, continue; } if (formatting_op->opcode() == HloOpcode::kBroadcast) { - CHECK(formatting_op->dimensions().empty()); auto operands = collect_operands(formatting_op); std::vector dimensions(1, 0); + for (const int64_t dim : formatting_op->dimensions()) { + dimensions.push_back(dim + 1); + } // Constant scalars don't get expanded ahead of time and are kept // scalar. if (operands[0]->shape().dimensions_size() == 0) { diff --git a/third_party/xla/xla/service/collective_pipeliner_test.cc b/third_party/xla/xla/service/collective_pipeliner_test.cc index 520206a43c142a..56f0aaf829b841 100644 --- a/third_party/xla/xla/service/collective_pipeliner_test.cc +++ b/third_party/xla/xla/service/collective_pipeliner_test.cc @@ -21,9 +21,11 @@ limitations under the License. 
#include #include +#include "absl/algorithm/container.h" #include "absl/log/check.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -2630,5 +2632,82 @@ ENTRY entry { .value()); } +TEST_F(CollectivePipelinerTest, BroadcastAsFormattingOp) { + constexpr absl::string_view hlo_string = R"( +HloModule module + +add { + lhs = bf16[] parameter(0) + rhs = bf16[] parameter(1) + ROOT add = bf16[] add(lhs, rhs) +} + +add.1 { + lhs = bf16[] parameter(0) + rhs = bf16[] parameter(1) + ROOT add = bf16[] add(lhs, rhs) +} + +while_cond { + param = (s32[], bf16[3,8,128], bf16[3,8,128]) parameter(0) + gte = s32[] get-tuple-element(param), index=0 + constant.1 = s32[] constant(3) + ROOT cmp = pred[] compare(gte, constant.1), direction=LT +} + +while_body { + param = (s32[], bf16[3,8,128], bf16[3,8,128]) parameter(0) + get-tuple-element.394 = s32[] get-tuple-element(param), index=0 + get-tuple-element.395 = bf16[3,8,128] get-tuple-element(param), index=1 + get-tuple-element.35 = bf16[3,8,128] get-tuple-element(param), index=2 + constant.2557 = s32[] constant(1) + add.230 = s32[] add(get-tuple-element.394, constant.2557) + constant.2559 = s32[] constant(3) + subtract.139 = s32[] subtract(constant.2559, get-tuple-element.394) + constant.2560 = s32[] constant(-1) + add.231 = s32[] add(subtract.139, constant.2560) + constant.2561 = s32[] constant(0) + compare.747 = pred[] compare(add.231, constant.2561), direction=LT + constant.2562 = s32[] constant(2) + add.232 = s32[] add(subtract.139, constant.2562) + select.1348 = s32[] select(compare.747, add.232, add.231) + dynamic-slice.99 = bf16[1,8,128] dynamic-slice(get-tuple-element.35, select.1348, constant.2561, constant.2561), dynamic_slice_sizes={1,8,128} + mul = bf16[1,8,128] multiply(dynamic-slice.99, dynamic-slice.99) + ar.1 = bf16[1,8,128] all-reduce(mul), 
replica_groups={}, to_apply=add, channel_id=1 + b.1 = bf16[1,8,128,32] broadcast(ar.1), dimensions={0,1,2} + constant = bf16[] constant(0) + reduce = bf16[1,8,128] reduce(b.1, constant), dimensions={3}, to_apply=add.1 + dynamic-update-slice.35 = bf16[3,8,128] dynamic-update-slice(get-tuple-element.395, reduce, select.1348, constant.2561, constant.2561) + ROOT tuple = (s32[], bf16[3,8,128], bf16[3,8,128]) tuple(add.230, dynamic-update-slice.35, get-tuple-element.35) +} + +ENTRY entry { + c0 = s32[] constant(0) + p0 = bf16[3,8,128] parameter(0) + tuple = (s32[], bf16[3,8,128], bf16[3,8,128]) tuple(c0, p0, p0) + while = (s32[], bf16[3,8,128], bf16[3,8,128]) while(tuple), condition=while_cond, body=while_body + ROOT gte1 = bf16[3,8,128] get-tuple-element(while), index=1 +} +)"; + auto module = ParseAndReturnUnverifiedModule(hlo_string, config_).value(); + EXPECT_TRUE(RunOptimizer(module.get(), /*last_run=*/true, + /*level_to_operate_on=*/0, + /*pipeline_use_tree=*/true, + /*process_different_sized_ops=*/true, + CollectivePipeliner::kForwardSink) + .value()); + XLA_VLOG_LINES(1, module->ToString()); + const HloInstruction* while_instr = + FindInstruction(module.get(), HloOpcode::kWhile); + EXPECT_GE(while_instr->users().size(), 2); + EXPECT_TRUE( + absl::c_any_of(while_instr->users(), [](const HloInstruction* user) { + return absl::c_any_of( + user->users(), [](const HloInstruction* user_user) { + return user_user->opcode() == HloOpcode::kAllReduce; + }); + })); +} + } // namespace } // namespace xla From 701f78b09d0959229e6f1483f86ee6e4a97cd646 Mon Sep 17 00:00:00 2001 From: Kuy Mainwaring Date: Fri, 21 Jun 2024 12:06:05 -0700 Subject: [PATCH 126/256] [XLA:GPU] Clang-tidy cleanup for xla/service/broadcast_canonicalizer.cc PiperOrigin-RevId: 645461819 --- third_party/xla/xla/service/BUILD | 7 +++++++ .../xla/xla/service/broadcast_canonicalizer.cc | 13 +++++++++++++ 2 files changed, 20 insertions(+) diff --git a/third_party/xla/xla/service/BUILD 
b/third_party/xla/xla/service/BUILD index 54657169ed0d78..0a8de50552f6c1 100644 --- a/third_party/xla/xla/service/BUILD +++ b/third_party/xla/xla/service/BUILD @@ -312,6 +312,13 @@ cc_library( deps = [ ":hlo_creation_utils", ":hlo_pass", + "//xla/hlo/ir:hlo", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", ], ) diff --git a/third_party/xla/xla/service/broadcast_canonicalizer.cc b/third_party/xla/xla/service/broadcast_canonicalizer.cc index e763b4d60d0e23..e1bfbc217aecbe 100644 --- a/third_party/xla/xla/service/broadcast_canonicalizer.cc +++ b/third_party/xla/xla/service/broadcast_canonicalizer.cc @@ -15,7 +15,20 @@ limitations under the License. #include "xla/service/broadcast_canonicalizer.h" +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/hlo_creation_utils.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" namespace xla { From 4203409fda00fbfb2a78bfe5788b996d9e7f0523 Mon Sep 17 00:00:00 2001 From: Kuy Mainwaring Date: Fri, 21 Jun 2024 12:06:29 -0700 Subject: [PATCH 127/256] [XLA:GPU] Clang-tidy cleanup for xla/service/buffer_assignment.h PiperOrigin-RevId: 645461929 --- third_party/xla/xla/service/BUILD | 4 +++- .../xla/xla/service/buffer_assignment.h | 21 +++++++++++++------ 2 files changed, 18 insertions(+), 7 deletions(-) diff --git a/third_party/xla/xla/service/BUILD b/third_party/xla/xla/service/BUILD index 0a8de50552f6c1..b8829394f072de 100644 --- a/third_party/xla/xla/service/BUILD +++ b/third_party/xla/xla/service/BUILD @@ -1801,14 +1801,16 @@ 
cc_library( ], deps = [ ":buffer_assignment_proto_cc", + ":buffer_value", ":buffer_value_containers", + ":call_graph", ":hlo_alias_analysis", ":hlo_buffer", ":hlo_dataflow_analysis", + ":hlo_ordering", ":hlo_proto_cc", ":hlo_value", ":logical_buffer", - ":tuple_points_to_analysis", "//xla:shape_util", "//xla:status_macros", "//xla:types", diff --git a/third_party/xla/xla/service/buffer_assignment.h b/third_party/xla/xla/service/buffer_assignment.h index b1bed335ef8254..b6887beb8546c3 100644 --- a/third_party/xla/xla/service/buffer_assignment.h +++ b/third_party/xla/xla/service/buffer_assignment.h @@ -16,14 +16,21 @@ limitations under the License. #ifndef XLA_SERVICE_BUFFER_ASSIGNMENT_H_ #define XLA_SERVICE_BUFFER_ASSIGNMENT_H_ +#include +#include +#include #include #include #include +#include +#include #include +#include #include #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/types/span.h" #include "xla/hlo/ir/hlo_computation.h" @@ -31,14 +38,18 @@ limitations under the License. 
#include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/utils/hlo_live_range.h" #include "xla/service/buffer_assignment.pb.h" +#include "xla/service/buffer_value.h" +#include "xla/service/call_graph.h" #include "xla/service/heap_simulator/heap_simulator.h" #include "xla/service/hlo.pb.h" #include "xla/service/hlo_alias_analysis.h" +#include "xla/service/hlo_buffer.h" #include "xla/service/hlo_dataflow_analysis.h" +#include "xla/service/hlo_ordering.h" +#include "xla/service/hlo_value.h" #include "xla/service/logical_buffer.h" #include "xla/service/memory_space_assignment/memory_space_assignment.h" -#include "xla/service/tuple_points_to_analysis.h" -#include "xla/types.h" +#include "xla/shape_util.h" #include "tsl/platform/logging.h" namespace xla { @@ -94,7 +105,7 @@ class BufferAllocation { return !is_thread_local() && !is_tuple(); } - // Whether this allocation is readonly i.e. backed by memory we cannot write + // Whether this allocation is read-only i.e. backed by memory we cannot write // to. bool is_readonly() const { // Entry parameters are generally readonly, except when they are aliased @@ -249,9 +260,7 @@ class BufferAllocation { // Return the set of heap traces used to assign slices to logical buffers in // this allocation. - const std::vector HeapTraces() const { - return heap_traces_; - } + std::vector HeapTraces() const { return heap_traces_; } // Returns the LogicalBuffers which are live at the point of peak memory usage // for this allocation. The point of peak memory usage is the point at which From dbc7a4acaf436bef6e8ca670e6077993a3a3f647 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 21 Jun 2024 12:07:04 -0700 Subject: [PATCH 128/256] Add wrapper for building reduce-window HLO using a binary operation's opcode rather than a computation; small refactoring to use a helper function for building computations. 
PiperOrigin-RevId: 645462122 --- .../xla/xla/service/hlo_creation_utils.cc | 68 +++++++++++-------- .../xla/xla/service/hlo_creation_utils.h | 4 ++ .../xla/service/hlo_creation_utils_test.cc | 48 +++++++++++++ 3 files changed, 93 insertions(+), 27 deletions(-) diff --git a/third_party/xla/xla/service/hlo_creation_utils.cc b/third_party/xla/xla/service/hlo_creation_utils.cc index e2c1f360252177..987a4f5e34e5cb 100644 --- a/third_party/xla/xla/service/hlo_creation_utils.cc +++ b/third_party/xla/xla/service/hlo_creation_utils.cc @@ -422,6 +422,26 @@ HloInstruction* MakeReducePrecisionHlo(HloInstruction* operand, metadata); } +namespace { +static HloComputation* MakeBinaryScalarComputation(HloOpcode binary_opcode, + PrimitiveType dtype, + HloInstruction* ctx, + HloModule* module) { + CHECK_NE(ctx, nullptr); + HloComputation::Builder b( + absl::StrCat(ctx->name(), ".reduce_sub_computation")); + const Shape scalar_shape = ShapeUtil::MakeShape(dtype, {}); + HloInstruction* lhs = + b.AddInstruction(HloInstruction::CreateParameter(0, scalar_shape, "lhs")); + HloInstruction* rhs = + b.AddInstruction(HloInstruction::CreateParameter(1, scalar_shape, "rhs")); + b.AddInstruction( + HloInstruction::CreateBinary(scalar_shape, binary_opcode, lhs, rhs)); + CHECK_NE(module, nullptr); + return module->AddEmbeddedComputation(b.Build()); +} +} // namespace + absl::StatusOr MakeReduceHlo( HloInstruction* operand, HloInstruction* init_value, absl::Span dimensions, HloComputation* reduce_computation, @@ -448,24 +468,29 @@ absl::StatusOr MakeReduceWindowHlo( metadata); } +absl::StatusOr MakeReduceWindowHlo( + HloInstruction* operand, HloInstruction* init_value, const Window& window, + HloOpcode binary_opcode, const OpMetadata* metadata) { + HloComputation* reduce_computation = MakeBinaryScalarComputation( + binary_opcode, operand->shape().element_type(), operand, + operand->GetModule()); + TF_ASSIGN_OR_RETURN(Shape inferred_shape, + ShapeInference::InferReduceWindowShape( + 
operand->shape(), init_value->shape(), window, + reduce_computation->ComputeProgramShape())); + return operand->parent()->AddInstruction( + HloInstruction::CreateReduceWindow(inferred_shape, operand, init_value, + window, reduce_computation), + metadata); +} + absl::StatusOr MakeReduceHlo( HloInstruction* operand, HloInstruction* init_value, absl::Span dimensions, HloOpcode binary_opcode, const OpMetadata* metadata, const FrontendAttributes* frontend_attributes) { - auto scalar_shape = ShapeUtil::MakeShape(operand->shape().element_type(), {}); - HloComputation* reduce_computation; - { - HloComputation::Builder b( - absl::StrCat(operand->name(), ".reduce_sub_computation")); - auto lhs = b.AddInstruction( - HloInstruction::CreateParameter(0, scalar_shape, "lhs")); - auto rhs = b.AddInstruction( - HloInstruction::CreateParameter(1, scalar_shape, "rhs")); - b.AddInstruction( - HloInstruction::CreateBinary(scalar_shape, binary_opcode, lhs, rhs)); - reduce_computation = - operand->GetModule()->AddEmbeddedComputation(b.Build()); - } + HloComputation* reduce_computation = MakeBinaryScalarComputation( + binary_opcode, operand->shape().element_type(), operand, + operand->GetModule()); return MakeReduceHlo(operand, init_value, dimensions, reduce_computation, metadata, frontend_attributes); } @@ -478,19 +503,8 @@ absl::StatusOr MakeReduceHlo( std::vector all_dims(operand->shape().rank()); std::iota(all_dims.begin(), all_dims.end(), 0); - auto scalar_shape = ShapeUtil::MakeShape(operand->shape().element_type(), {}); - HloComputation* reduce_computation; - { - HloComputation::Builder b( - absl::StrCat(operand->name(), ".reduce_sub_computation")); - auto lhs = b.AddInstruction( - HloInstruction::CreateParameter(0, scalar_shape, "lhs")); - auto rhs = b.AddInstruction( - HloInstruction::CreateParameter(1, scalar_shape, "rhs")); - b.AddInstruction( - HloInstruction::CreateBinary(scalar_shape, binary_opcode, lhs, rhs)); - reduce_computation = 
module->AddEmbeddedComputation(b.Build()); - } + HloComputation* reduce_computation = MakeBinaryScalarComputation( + binary_opcode, operand->shape().element_type(), operand, module); return MakeReduceHlo(operand, init_value, all_dims, reduce_computation, metadata, frontend_attributes); } diff --git a/third_party/xla/xla/service/hlo_creation_utils.h b/third_party/xla/xla/service/hlo_creation_utils.h index 8ca3ac5ed716e0..26af13d70a4d39 100644 --- a/third_party/xla/xla/service/hlo_creation_utils.h +++ b/third_party/xla/xla/service/hlo_creation_utils.h @@ -198,6 +198,10 @@ absl::StatusOr MakeReduceWindowHlo( HloInstruction* operand, HloInstruction* init_value, const Window& window, HloComputation* reduce_computation, const OpMetadata* metadata = nullptr); +absl::StatusOr MakeReduceWindowHlo( + HloInstruction* operand, HloInstruction* init_value, const Window& window, + HloOpcode binary_opcode, const OpMetadata* metadata = nullptr); + // Creates a Reduce HLO instruction and adds it to the computation containing // the operand. This will create the sub-computation needed for the reduction in // the given module. binary_opcode should represent a binary operation. diff --git a/third_party/xla/xla/service/hlo_creation_utils_test.cc b/third_party/xla/xla/service/hlo_creation_utils_test.cc index 4df62a6463e484..252345fbbbc5ff 100644 --- a/third_party/xla/xla/service/hlo_creation_utils_test.cc +++ b/third_party/xla/xla/service/hlo_creation_utils_test.cc @@ -487,6 +487,54 @@ TEST_F(HloCreationUtilsTest, ReduceWindow) { expected_output_shape); } +TEST_F(HloCreationUtilsTest, ReduceWindowBinaryOpcode) { + const Shape scalar_shape = ShapeUtil::MakeShape(S32, {}); + std::unique_ptr module = CreateNewVerifiedModule(); + + auto builder = HloComputation::Builder(TestName()); + Shape input_shape = ShapeUtil::MakeShape(F32, {2, 4, 4}); + Shape expected_output_shape = ShapeUtil::MakeShape(F32, {2, 2, 2}); + + Window window; + // First dimension is unchanged. 
+ WindowDimension* batch_dim = window.add_dimensions(); + batch_dim->set_size(1); + batch_dim->set_stride(1); + batch_dim->set_padding_low(0); + batch_dim->set_padding_high(0); + batch_dim->set_window_dilation(1); + batch_dim->set_base_dilation(1); + + // Second and third dimension are reduced. + for (int64_t i = 0; i < 2; ++i) { + WindowDimension* dim = window.add_dimensions(); + dim->set_size(2); + dim->set_stride(2); + dim->set_padding_low(0); + dim->set_padding_high(0); + dim->set_window_dilation(1); + dim->set_base_dilation(1); + } + + auto* a_param = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, input_shape, "A")); + + auto init = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0))); + module->AddEntryComputation(builder.Build()); + TF_ASSERT_OK_AND_ASSIGN( + HloInstruction * reduce_window, + MakeReduceWindowHlo(a_param, init, window, HloOpcode::kAdd)); + module->entry_computation()->set_root_instruction( + reduce_window, + /*accept_different_shape=*/true); + + *module->mutable_entry_computation_layout() = + module->compute_computation_layout(); + EXPECT_EQ(module->entry_computation()->root_instruction()->shape(), + expected_output_shape); +} + TEST_F(HloCreationUtilsTest, DynamicBroadcastShape) { HloInstruction* param; HloComputation* entry_computation; From 5e39d5ce77d8f7f69992d9d2780cbe4af789a5cc Mon Sep 17 00:00:00 2001 From: wenchenvincent <32376000+wenchenvincent@users.noreply.github.com> Date: Fri, 21 Jun 2024 12:22:08 -0700 Subject: [PATCH 129/256] PR #12928: [ROCM] Updated fp8 matmul with adjustments for updated hipBlasLt support in ROCm 6.2 Imported from GitHub PR https://github.com/openxla/xla/pull/12928 Main changes include: * Added support for fp8 matmul with output data type to be fp8 and bf16. 
* Added buffer comparators for fp8e4m3fnuz and fp8e5m2fnuz Copybara import of the project: -- 230f6e405b65d44e030078835699ab6fad9c7886 by Wen Chen : [ROCM] Updated fp8 matmul with adjustments for updated hipBlasLt support Main changes include: * Added support for fp8 matmul with output data type to be fp8 and bf16. * Added buffer comparators for fp8e4m3fnuz and fp8e5m2fnuz -- 0bf7e04d60adb528fb8eea949285770fcc91d748 by Wen Chen : [ROCM] Addressed reviewer comment. Merging this change closes #12928 PiperOrigin-RevId: 645466349 --- .../xla/xla/service/gpu/buffer_comparator.cc | 10 ++ .../xla/service/gpu/buffer_comparator.cu.cc | 61 ++++++++++ .../xla/xla/service/gpu/buffer_comparator.h | 8 ++ .../xla/xla/service/gpu/gemm_rewriter.cc | 38 +++--- .../xla/xla/service/gpu/matmul_utils.cc | 4 + .../service/gpu/tests/gemm_rewrite_test.cc | 115 +++++++++++------- .../xla/stream_executor/rocm/hip_blas_lt.cc | 17 +++ 7 files changed, 192 insertions(+), 61 deletions(-) diff --git a/third_party/xla/xla/service/gpu/buffer_comparator.cc b/third_party/xla/xla/service/gpu/buffer_comparator.cc index 6834ed52f65325..e00889b23b27dd 100644 --- a/third_party/xla/xla/service/gpu/buffer_comparator.cc +++ b/third_party/xla/xla/service/gpu/buffer_comparator.cc @@ -187,6 +187,16 @@ absl::StatusOr BufferComparator::CompareEqual( stream, current, expected, "fp8_e5m2_comparison", buffer_comparator::fp8_e5m2_comparison()); #endif // GOOGLE_CUDA +#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION >= 60200 + case xla::F8E4M3FNUZ: + return CompareEqualParameterized( + stream, current, expected, "fp8_e4m3fnuz_comparison", + buffer_comparator::fp8_e4m3fnuz_comparison()); + case xla::F8E5M2FNUZ: + return CompareEqualParameterized( + stream, current, expected, "fp8_e5m2fnuz_comparison", + buffer_comparator::fp8_e5m2fnuz_comparison()); +#endif // TENSORFLOW_USE_ROCM && TF_ROCM_VERSION >= 60200 case xla::F16: return CompareEqualParameterized( stream, current, expected, "fp16_comparison", diff --git 
a/third_party/xla/xla/service/gpu/buffer_comparator.cu.cc b/third_party/xla/xla/service/gpu/buffer_comparator.cu.cc index bbe3395345a054..b8e5a8e8d1e662 100644 --- a/third_party/xla/xla/service/gpu/buffer_comparator.cu.cc +++ b/third_party/xla/xla/service/gpu/buffer_comparator.cu.cc @@ -25,6 +25,11 @@ using bfloat16 = __nv_bfloat16; #include #include +#include "rocm/rocm_config.h" +#if TF_ROCM_VERSION >= 60200 +#include +#endif // TF_ROCM_VERSION >= 60200 + using bfloat16 = hip_bfloat16; #define BF16_TO_F32 float @@ -97,6 +102,52 @@ __global__ void xla_fp8_e5m2_comparison(__nv_fp8_storage_t* buffer_a, } #endif // GOOGLE_CUDA +#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION >= 60200 +__global__ void xla_fp8_e4m3fnuz_comparison(__hip_fp8_storage_t* buffer_a, + __hip_fp8_storage_t* buffer_b, + float rel_error_threshold, + uint64_t buffer_length, + int* mismatch_count) { + int idx = threadIdx.x + blockIdx.x * blockDim.x; + if (idx >= buffer_length) return; + __hip_fp8_e4m3_fnuz elem_a_fp8, elem_b_fp8; + elem_a_fp8.__x = buffer_a[idx]; + elem_b_fp8.__x = buffer_b[idx]; + float elem_a = static_cast(elem_a_fp8); + float elem_b = static_cast(elem_b_fp8); + elem_a = Canonicalize(elem_a); + elem_b = Canonicalize(elem_b); + if (isnan(elem_a) && isnan(elem_b)) return; + + float rel_error = abs(elem_a - elem_b) / (max(abs(elem_a), abs(elem_b)) + 1); + + if (rel_error > rel_error_threshold || isnan(rel_error)) + atomicAdd(mismatch_count, 1); +} + +__global__ void xla_fp8_e5m2fnuz_comparison(__hip_fp8_storage_t* buffer_a, + __hip_fp8_storage_t* buffer_b, + float rel_error_threshold, + uint64_t buffer_length, + int* mismatch_count) { + int idx = threadIdx.x + blockIdx.x * blockDim.x; + if (idx >= buffer_length) return; + __hip_fp8_e5m2_fnuz elem_a_fp8, elem_b_fp8; + elem_a_fp8.__x = buffer_a[idx]; + elem_b_fp8.__x = buffer_b[idx]; + float elem_a = static_cast(elem_a_fp8); + float elem_b = static_cast(elem_b_fp8); + elem_a = Canonicalize(elem_a); + elem_b = Canonicalize(elem_b); + if 
(isnan(elem_a) && isnan(elem_b)) return; + + float rel_error = abs(elem_a - elem_b) / (max(abs(elem_a), abs(elem_b)) + 1); + + if (rel_error > rel_error_threshold || isnan(rel_error)) + atomicAdd(mismatch_count, 1); +} +#endif // TENSORFLOW_USE_ROCM && TF_ROCM_VERSION >= 60200 + __global__ void xla_fp16_comparison(__half* buffer_a, __half* buffer_b, float rel_error_threshold, uint64_t buffer_length, @@ -206,6 +257,16 @@ void* fp8_e5m2_comparison() { } #endif +#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION >= 60200 +void* fp8_e4m3fnuz_comparison() { + return reinterpret_cast(&xla_fp8_e4m3fnuz_comparison); +} + +void* fp8_e5m2fnuz_comparison() { + return reinterpret_cast(&xla_fp8_e5m2fnuz_comparison); +} +#endif // TENSORFLOW_USE_ROCM && TF_ROCM_VERSION >= 60200 + void* fp16_comparison() { return reinterpret_cast(&xla_fp16_comparison); } diff --git a/third_party/xla/xla/service/gpu/buffer_comparator.h b/third_party/xla/xla/service/gpu/buffer_comparator.h index 28c0ec2b03d737..b8c00fa5d9bb8f 100644 --- a/third_party/xla/xla/service/gpu/buffer_comparator.h +++ b/third_party/xla/xla/service/gpu/buffer_comparator.h @@ -22,6 +22,10 @@ limitations under the License. #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/stream_executor.h" +#if TENSORFLOW_USE_ROCM +#include "rocm/rocm_config.h" +#endif + namespace xla::gpu { // A device-side comparator that compares buffers. @@ -76,6 +80,10 @@ namespace buffer_comparator { // Returns a pointer to CUDA C++ device function implementing comparison. 
void* fp8_e4m3fn_comparison(); void* fp8_e5m2_comparison(); +#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION >= 60200 +void* fp8_e4m3fnuz_comparison(); +void* fp8_e5m2fnuz_comparison(); +#endif // TENSORFLOW_USE_ROCM && TF_ROCM_VERSION >= 60200 void* fp16_comparison(); void* bf16_comparison(); void* fp32_comparison(); diff --git a/third_party/xla/xla/service/gpu/gemm_rewriter.cc b/third_party/xla/xla/service/gpu/gemm_rewriter.cc index 660ae0412c8e2d..1559f64f94450b 100644 --- a/third_party/xla/xla/service/gpu/gemm_rewriter.cc +++ b/third_party/xla/xla/service/gpu/gemm_rewriter.cc @@ -630,7 +630,8 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { const_cast(instr->operand(0)))) && (b = MatchFp8Param( const_cast(instr->operand(1))))) { - if (IsRocm(gpu_version_) && instr->shape().element_type() != F16 && + if (IsRocm(gpu_version_) && toolkit_version_ < 60200 && + instr->shape().element_type() != F16 && instr->shape().element_type() != F32) { TF_ASSIGN_OR_RETURN(instr, TurnF8DotWithUnsupportedOutputTypeIntoF32(instr)); @@ -1095,20 +1096,24 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { } } - switch (instr->shape().element_type()) { - case F8E4M3FN: - case F8E5M2: - case BF16: - case F16: - case F32: - break; - default: - - VLOG(1) << "Failed to rewrite " << instr->ToShortString() - << " into FP8 Custom Call. Output element type must be " - "F8E4M3FN, F8E5M2, BF16, F16 or F32. 
Actual element type is " - << PrimitiveType_Name(instr->shape().element_type()); - return false; + PrimitiveType d_type = instr->shape().element_type(); + bool supported_d_type = (d_type == BF16 || d_type == F16 || d_type == F32); + if (IsCuda(gpu_version_) && (d_type == F8E4M3FN || d_type == F8E5M2)) { + supported_d_type = true; + } + if (IsRocm(gpu_version_) && toolkit_version_ >= 60200 && + (d_type == F8E4M3FNUZ || d_type == F8E5M2FNUZ)) { + supported_d_type = true; + } + if (!supported_d_type) { + VLOG(1) << "Failed to rewrite " << instr->ToShortString() + << " into FP8 Custom Call. Output element type must be " + << (IsCuda(gpu_version_) ? "F8E4M3FN, F8E5M2, BF16, F16 or F32. " + : toolkit_version_ >= 60200 + ? "F8E4M3FNUZ, F8E5M2FNUZ, BF16, F16 or F32. " + : "BF16, F16 or F32. ") + << "Actual element type is " << PrimitiveType_Name(d_type); + return false; } // Each operand must have exactly one contracting and one non-contracting @@ -1768,7 +1773,8 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { // CUBLAS_STATUS_NOT_SUPPORTED in some cases when fusing gelu into an FP8 // matmul. We cannot check the patch version, so disable this fusion with // CUDA versions less than 12.4. - if (toolkit_version_ < 12040 && IsCublasLtMatmulF8(*gemm)) { + if (IsCuda(gpu_version_) && toolkit_version_ < 12040 && + IsCublasLtMatmulF8(*gemm)) { return absl::OkStatus(); } diff --git a/third_party/xla/xla/service/gpu/matmul_utils.cc b/third_party/xla/xla/service/gpu/matmul_utils.cc index 0e056848f62101..fe4982e9a223b9 100644 --- a/third_party/xla/xla/service/gpu/matmul_utils.cc +++ b/third_party/xla/xla/service/gpu/matmul_utils.cc @@ -366,8 +366,12 @@ absl::StatusOr CanFoldTransposeOperandIntoDot(const HloInstruction& dot, // for matmuls with FP8 inputs and outputs, C must instead have the same // dtype as the vector bias if present, and either BF16 or F16 otherwise. So // we set the dtype of C here. 
+#if GOOGLE_CUDA + // hipBlasLt does not yet support the C matrix to be BF16 for fp8 matmul + // with fp8 output. Thus only do this for CUDA side. c_matrix_shape.set_element_type( bias_shape_ptr != nullptr ? bias_shape_ptr->element_type() : BF16); +#endif } TF_ASSIGN_OR_RETURN(MatrixLayout c_layout, diff --git a/third_party/xla/xla/service/gpu/tests/gemm_rewrite_test.cc b/third_party/xla/xla/service/gpu/tests/gemm_rewrite_test.cc index 3426a278b75505..2c851e2c98b32d 100644 --- a/third_party/xla/xla/service/gpu/tests/gemm_rewrite_test.cc +++ b/third_party/xla/xla/service/gpu/tests/gemm_rewrite_test.cc @@ -115,9 +115,7 @@ class GemmRewriteTest : public GpuCodegenTest { if (IsCuda()) { return std::get(Capability()).IsAtLeast(8, 9); } - return std::get(Capability()) - .has_fp8_support() && - GetDebugOptionsForTest().xla_gpu_enable_cublaslt(); + return std::get(Capability()).has_fp8_support(); } bool HasCudaComputeCapability(const se::CudaComputeCapability& cc) { @@ -4923,9 +4921,15 @@ TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABUnscaledDF8) { ; CHECK-NEXT: [[P1:%[^ ]+]] = <>[32,16]{1,0} parameter(1) ; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <>[16,32]{1,0} transpose([[P1]]), dimensions={1,0} ; CHECK-NEXT: [[C1:[^ ]+]] = f32[] constant(1) -; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = f32[16,16]{1,0} custom-call([[P0]], [[P1_TRANSPOSE]], [[C1]], [[C1]], [[C1]], /*index=5*/[[C1]]), -; CHECK-PTX-NEXT: [[OUT:%[^ ]+]] = (<>[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[C1]], [[C1]], [[C1]], /*index=5*/[[C1]]), -; CHECK: custom_call_target="__cublas$lt$matmul$f8", +)" +#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60200 + R"(; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[C1]], [[C1]], [[C1]], /*index=5*/[[C1]]), +)" +#else + R"(; CHECK-NEXT: [[OUT:%[^ ]+]] = (<>[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[C1]], [[C1]], [[C1]], /*index=5*/[[C1]]), +)" +#endif + R"(; 
CHECK: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config={ ; CHECK-DAG: "alpha_real":1 ; CHECK-DAG: "alpha_imag":0 @@ -4950,9 +4954,10 @@ TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABUnscaledDMatrixBiasF8) { GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above."; #endif // CUDA_VERSION < 12000 -#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000 - GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above."; -#endif // TF_ROCM_VERSION < 60000 +#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60200 + GTEST_SKIP() << "F8 gemm rewrite for D to be fp8 with Matrix Bias is only " + "supported in ROCm 6.2 and above."; +#endif // TF_ROCM_VERSION < 60200 const char* hlo_text = R"( HloModule test @@ -4978,8 +4983,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABUnscaledDMatrixBiasF8) { ; CHECK-NEXT: [[P1:%[^ ]+]] = <>[32,16]{1,0} parameter(1) ; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <>[16,32]{1,0} transpose([[P1]]), dimensions={1,0} ; CHECK-NEXT: [[C1:[^ ]+]] = f32[] constant(1) -; CHECK-GCN-NEXT: [[DOT_TUPLE:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[C1]], [[C1]], [[C1]], /*index=5*/[[C1]]), -; CHECK-PTX-NEXT: [[DOT_TUPLE:%[^ ]+]] = (<>[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[C1]], [[C1]], [[C1]], /*index=5*/[[C1]]), +; CHECK-NEXT: [[DOT_TUPLE:%[^ ]+]] = (<>[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[C1]], [[C1]], [[C1]], /*index=5*/[[C1]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config={ ; CHECK-DAG: "alpha_real":1 @@ -5809,10 +5813,16 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ; CHECK-NEXT: [[P3:%[^ ]+]] = bf16[] parameter(3) ; CHECK-NEXT: [[XS1:%[^ ]+]] = f32[] convert([[P3]]) ; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1) -; CHECK-PTX-NEXT: [[B:%[^ ]+]] = bf16[16]{0} parameter(4) -; CHECK-PTX-NEXT: [[OUT:%[^ ]+]] = (bf16[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], 
[[P1_TRANSPOSE]], [[XS]], [[XS1]], [[C1]], /*index=5*/[[C1]], [[B]]), -; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[XS]], [[XS1]], [[C1]], /*index=5*/[[C1]]), -; CHECK: custom_call_target="__cublas$lt$matmul$f8", +)" +#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60200 + R"(; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[XS]], [[XS1]], [[C1]], /*index=5*/[[C1]]), +)" +#else + R"(; CHECK-NEXT: [[B:%[^ ]+]] = bf16[16]{0} parameter(4) +; CHECK-NEXT: [[OUT:%[^ ]+]] = (bf16[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[XS]], [[XS1]], [[C1]], /*index=5*/[[C1]], [[B]]), +)" +#endif + R"(; CHECK: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config={ ; CHECK-DAG: "alpha_real":1 ; CHECK-DAG: "alpha_imag":0 @@ -5826,9 +5836,15 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ; CHECK-DAG: "precision_config":{ ; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"] ; CHECK-DAG: } -; CHECK-PTX-DAG: "epilogue":"BIAS_GELU" -; CHECK-GCN-DAG: "epilogue":"DEFAULT" -; CHECK: } +)" +#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60200 + R"(; CHECK-GCN-DAG: "epilogue":"DEFAULT" +)" +#else + R"(; CHECK-DAG: "epilogue":"BIAS_GELU" +)" +#endif + R"(; CHECK: } )"); #endif // (GOOGLE_CUDA && CUDA_VERSION >= 12040) || TENSORFLOW_USE_ROCM } @@ -5898,9 +5914,15 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ; CHECK-NEXT: [[P3:%[^ ]+]] = bf16[] parameter(3) ; CHECK-NEXT: [[XS1:%[^ ]+]] = f32[] convert([[P3]]) ; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1) -; CHECK-PTX-NEXT: [[OUT:%[^ ]+]] = (bf16[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[XS]], [[XS1]], [[C1]], /*index=5*/[[C1]]), -; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[XS]], [[XS1]], [[C1]], /*index=5*/[[C1]]), -; CHECK: custom_call_target="__cublas$lt$matmul$f8", +)" +#if 
TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60200 + R"(; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[XS]], [[XS1]], [[C1]], /*index=5*/[[C1]]), +)" +#else + R"(; CHECK-NEXT: [[OUT:%[^ ]+]] = (bf16[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[XS]], [[XS1]], [[C1]], /*index=5*/[[C1]]), +)" +#endif + R"(; CHECK: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config={ ; CHECK-DAG: "alpha_real":1 ; CHECK-DAG: "alpha_imag":0 @@ -5914,9 +5936,15 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ; CHECK-DAG: "precision_config":{ ; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"] ; CHECK-DAG: } -; CHECK-PTX-DAG: "epilogue":"GELU" -; CHECK-GCN-DAG: "epilogue":"DEFAULT" -; CHECK: } +)" +#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60200 + R"(; CHECK-GCN-DAG: "epilogue":"DEFAULT" +)" +#else + R"(; CHECK-DAG: "epilogue":"GELU" +)" +#endif + R"(; CHECK: } )"); #endif // (GOOGLE_CUDA && CUDA_VERSION >= 12040) || TENSORFLOW_USE_ROCM } @@ -6143,11 +6171,11 @@ TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABScaledDF8) { ; CHECK-NEXT: [[P1:%[^ ]+]] = <>[32,16]{1,0} parameter(1) ; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <>[16,32]{1,0} transpose([[P1]]), dimensions={1,0} ; CHECK-NEXT: [[C0:%[^ ]+]] = f32[] constant(1) -; CHECK-PTX-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2) -; CHECK-PTX-NEXT: [[P2_INV:%[^ ]+]] = f32[] divide([[C0]], [[P2]]) -; CHECK-PTX-NEXT: [[C1:%[^ ]+]] = f32[] constant(1) +; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2) +; CHECK-NEXT: [[P2_INV:%[^ ]+]] = f32[] divide([[C0]], [[P2]]) +; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1) ; CHECK-PTX-NEXT: [[OUT:%[^ ]+]] = (<>[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2_INV]], [[C1]], [[C1]], /*index=5*/[[C1]]), -; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = f32[16,16]{1,0} custom-call([[P0]], [[P1_TRANSPOSE]], [[C1]], [[C1]], [[C1]], /*index=5*/[[C1]]), +; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = 
(f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2_INV]], [[C1]], [[C1]], /*index=5*/[[C1]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config={ ; CHECK-DAG: "alpha_real":1 @@ -6201,11 +6229,10 @@ TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABScaledF32DF8) { ; CHECK-NEXT: [[P1:%[^ ]+]] = <>[32,16]{1,0} parameter(1) ; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <>[16,32]{1,0} transpose([[P1]]), dimensions={1,0} ; CHECK-NEXT: [[C0:%[^ ]+]] = f32[] constant(1) -; CHECK-PTX-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2) -; CHECK-PTX-NEXT: [[P2_INV:%[^ ]+]] = f32[] divide([[C0]], [[P2]]) -; CHECK-PTX-NEXT: [[C1:%[^ ]+]] = f32[] constant(1) -; CHECK-PTX-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2_INV]], [[C1]], [[C1]], /*index=5*/[[C1]]), -; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = f32[16,16]{1,0} custom-call([[P0]], [[P1_TRANSPOSE]], [[C1]], [[C1]], [[C1]], /*index=5*/[[C1]]), +; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2) +; CHECK-NEXT: [[P2_INV:%[^ ]+]] = f32[] divide([[C0]], [[P2]]) +; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1) +; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2_INV]], [[C1]], [[C1]], /*index=5*/[[C1]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config={ ; CHECK-DAG: "alpha_real":1 @@ -6258,10 +6285,9 @@ TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABInvScaledF32DF8) { ; CHECK: [[P0:%[^ ]+]] = <>[16,32]{1,0} parameter(0) ; CHECK-NEXT: [[P1:%[^ ]+]] = <>[32,16]{1,0} parameter(1) ; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <>[16,32]{1,0} transpose([[P1]]), dimensions={1,0} -; CHECK-PTX-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2) -; CHECK-PTX-NEXT: [[C0:%[^ ]+]] = f32[] constant(1) -; CHECK-PTX-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[C0]], [[C0]], /*index=5*/[[C0]]), -; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] 
= f32[16,16]{1,0} custom-call([[P0]], [[P1_TRANSPOSE]], [[C0]], [[C0]], [[C0]], /*index=5*/[[C0]]), +; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2) +; CHECK-NEXT: [[C0:%[^ ]+]] = f32[] constant(1) +; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[C0]], [[C0]], /*index=5*/[[C0]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config={ ; CHECK-DAG: "alpha_real":1 @@ -6318,10 +6344,9 @@ TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABScaledF32DMatrixBiasF8) { ; CHECK: [[P0:%[^ ]+]] = <>[16,32]{1,0} parameter(0) ; CHECK-NEXT: [[P1:%[^ ]+]] = <>[32,16]{1,0} parameter(1) ; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <>[16,32]{1,0} transpose([[P1]]), dimensions={1,0} -; CHECK-PTX-NEXT: [[P2:%[^ ]+]] = f32[16,16]{1,0} parameter(2) -; CHECK-PTX-NEXT: [[C0:%[^ ]+]] = f32[] constant(1) -; CHECK-PTX-NEXT: [[GEMM_TUPLE:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[C0]], [[C0]], /*index=5*/[[C0]], [[C0]]), -; CHECK-GCN-NEXT: [[GEMM:%[^ ]+]] = f32[16,16]{1,0} custom-call([[P0]], [[P1_TRANSPOSE]], [[C0]], [[C0]], [[C0]], /*index=5*/[[C0]]), +; CHECK-NEXT: [[P2:%[^ ]+]] = f32[16,16]{1,0} parameter(2) +; CHECK-NEXT: [[C0:%[^ ]+]] = f32[] constant(1) +; CHECK-NEXT: [[GEMM_TUPLE:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[C0]], [[C0]], /*index=5*/[[C0]], [[C0]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config={ ; CHECK-DAG: "alpha_real":1 @@ -6399,7 +6424,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDF8) { ; CHECK-PTX-NEXT: [[P4:%[^ ]+]] = f32[] parameter(4) ; CHECK-PTX-NEXT: [[P4_INV:%[^ ]+]] = f32[] divide([[C2]], [[P4]]) ; CHECK-PTX-NEXT: [[OUT:%[^ ]+]] = (<>[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]], /*index=5*/[[P4_INV]]), -; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = f32[16,16]{1,0} custom-call([[P0]], 
[[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]], /*index=5*/[[C1]]), +; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]], /*index=5*/[[C1]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config={ ; CHECK-DAG: "alpha_real":1 @@ -6525,7 +6550,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDReluActivationF8) { ; CHECK-PTX-NEXT: [[P4:%[^ ]+]] = f32[] parameter(4) ; CHECK-PTX-NEXT: [[P4_INV:%[^ ]+]] = f32[] divide([[C2]], [[P4]]) ; CHECK-PTX-NEXT: [[OUT:%[^ ]+]] = (<>[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]], /*index=5*/[[P4_INV]]), -; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = f32[16,16]{1,0} custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]], /*index=5*/[[C1]]), +; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]], /*index=5*/[[C1]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config={ ; CHECK-DAG: "alpha_real":1 @@ -6614,7 +6639,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDMatrixBiasWithDAmaxF8) { ; CHECK-PTX: [[P4:%[^ ]+]] = f16[] parameter(5) ; CHECK-PTX: [[OUT:%[^ ]+]] = (<>[16,16]{1,0}, f32[], s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[C0]], [[DUMMY0:%[^ ]+]], [[DUMMY1:%[^ ]+]], /*index=5*/[[C1]], [[DUMMY2:%[^ ]+]]), ; CHECK-NOT: output_to_operand_aliasing -; CHECK-GCN: [[OUT:%[^ ]+]] = f16[16,16]{1,0} custom-call([[P0]], [[P1_TRANSPOSE]], [[C0]], [[DUMMY0:%[^ ]+]], [[DUMMY1:%[^ ]+]], /*index=5*/[[C1]], [[DUMMY2:%[^ ]+]]), +; CHECK-GCN: [[OUT:%[^ ]+]] = (f16[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[C0]], [[DUMMY0:%[^ ]+]], [[DUMMY1:%[^ ]+]], /*index=5*/[[C1]], [[DUMMY2:%[^ ]+]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config={ ; CHECK-DAG: "alpha_real":1 @@ -6696,7 +6721,7 @@ 
TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDVectorBiasF8) { ; CHECK-PTX-NEXT: [[CV2:%[^ ]+]] = f32[] convert([[DV]]) ; CHECK-NEXT: [[VB:%[^ ]+]] = f16[16]{0} parameter(2) ; CHECK-PTX: [[OUT:%[^ ]+]] = (<>[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[CV]], [[CV1]], [[C]], /*index=5*/[[CV2]], [[VB]]), -; CHECK-GCN: [[OUT:%[^ ]+]] = f16[16,16]{1,0} custom-call([[P0]], [[P1_TRANSPOSE]], [[CV]], [[CV1]], [[C]], /*index=5*/[[C]], [[VB]]), +; CHECK-GCN: [[OUT:%[^ ]+]] = (f16[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[CV]], [[CV1]], [[C]], /*index=5*/[[C]], [[VB]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config={ ; CHECK-DAG: "alpha_real":1 diff --git a/third_party/xla/xla/stream_executor/rocm/hip_blas_lt.cc b/third_party/xla/xla/stream_executor/rocm/hip_blas_lt.cc index 426b461bcdd312..a3f5b8c180c278 100644 --- a/third_party/xla/xla/stream_executor/rocm/hip_blas_lt.cc +++ b/third_party/xla/xla/stream_executor/rocm/hip_blas_lt.cc @@ -558,6 +558,8 @@ absl::Status BlasLt::MatmulPlan::ExecuteOnStream( profile_result); \ } +// FP8 compatible types combinations (Full table in +// https://github.com/ROCm/hipBLASLt/blob/develop/docs/api-reference.rst?plain=1) #if TF_ROCM_VERSION >= 60000 TYPED_MATMUL(float, HIP_R_8F_E4M3_FNUZ, HIP_R_8F_E4M3_FNUZ, HIP_R_16F, HIP_R_16F) @@ -575,6 +577,21 @@ absl::Status BlasLt::MatmulPlan::ExecuteOnStream( HIP_R_32F) #endif +#if TF_ROCM_VERSION >= 60200 + TYPED_MATMUL(float, HIP_R_8F_E4M3_FNUZ, HIP_R_8F_E4M3_FNUZ, HIP_R_16BF, + HIP_R_16BF) + TYPED_MATMUL(float, HIP_R_8F_E4M3_FNUZ, HIP_R_8F_E5M2_FNUZ, HIP_R_16BF, + HIP_R_16BF) + TYPED_MATMUL(float, HIP_R_8F_E5M2_FNUZ, HIP_R_8F_E4M3_FNUZ, HIP_R_16BF, + HIP_R_16BF) + TYPED_MATMUL(float, HIP_R_8F_E4M3_FNUZ, HIP_R_8F_E4M3_FNUZ, + HIP_R_8F_E4M3_FNUZ, HIP_R_8F_E4M3_FNUZ) + TYPED_MATMUL(float, HIP_R_8F_E4M3_FNUZ, HIP_R_8F_E5M2_FNUZ, + HIP_R_8F_E5M2_FNUZ, HIP_R_8F_E5M2_FNUZ) + TYPED_MATMUL(float, 
HIP_R_8F_E5M2_FNUZ, HIP_R_8F_E4M3_FNUZ, + HIP_R_8F_E5M2_FNUZ, HIP_R_8F_E5M2_FNUZ) +#endif + // Other data types: TYPED_MATMUL(float, HIP_R_16BF, HIP_R_16BF, HIP_R_16BF, HIP_R_16BF) TYPED_MATMUL(float, HIP_R_16F, HIP_R_16F, HIP_R_16F, HIP_R_16F) From be6bfff2bd2bc5c93d79935f090ee09e6419bd7e Mon Sep 17 00:00:00 2001 From: Chi Zeng Date: Fri, 21 Jun 2024 12:47:00 -0700 Subject: [PATCH 130/256] [XLA:GPU] Use radix sort in place of classic sort for TopK if input size >= 33K. PiperOrigin-RevId: 645473037 --- third_party/xla/xla/service/gpu/gpu_sort_rewriter.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/xla/xla/service/gpu/gpu_sort_rewriter.h b/third_party/xla/xla/service/gpu/gpu_sort_rewriter.h index 2f60816b1a05cb..4e71c2d7c58a48 100644 --- a/third_party/xla/xla/service/gpu/gpu_sort_rewriter.h +++ b/third_party/xla/xla/service/gpu/gpu_sort_rewriter.h @@ -54,7 +54,7 @@ class GpuSortRewriter : public HloModulePass { absl::StatusOr RunOnInstruction(HloSortInstruction* sort_op); absl::StatusOr RunOnComputation(HloComputation* computation); - static inline int sort_size_threshold_ = 100000; + static inline int sort_size_threshold_ = 33000; }; } // namespace gpu From 34a8b4c35eac9021b2aef993a1af78a1d1f5d105 Mon Sep 17 00:00:00 2001 From: Kuy Mainwaring Date: Fri, 21 Jun 2024 12:51:26 -0700 Subject: [PATCH 131/256] [XLA:GPU] Clang-tidy cleanup for xla/service/call_graph.h PiperOrigin-RevId: 645474439 --- third_party/xla/xla/service/BUILD | 10 ++++++++++ third_party/xla/xla/service/call_graph.h | 10 ++++++++++ third_party/xla/xla/service/call_graph_test.cc | 14 ++++++++++++++ 3 files changed, 34 insertions(+) diff --git a/third_party/xla/xla/service/BUILD b/third_party/xla/xla/service/BUILD index b8829394f072de..614864cfaef0b0 100644 --- a/third_party/xla/xla/service/BUILD +++ b/third_party/xla/xla/service/BUILD @@ -1045,11 +1045,15 @@ cc_library( "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", 
"@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/functional:function_ref", "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", ], ) @@ -1058,6 +1062,8 @@ xla_cc_test( srcs = ["call_graph_test.cc"], deps = [ ":call_graph", + "//xla:comparison_util", + "//xla:literal_util", "//xla:shape_util", "//xla:util", "//xla:xla_data_proto_cc", @@ -1065,9 +1071,13 @@ xla_cc_test( "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status", "@com_google_googletest//:gtest", "@local_tsl//tsl/lib/core:status_test_util", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", ], ) diff --git a/third_party/xla/xla/service/call_graph.h b/third_party/xla/xla/service/call_graph.h index a64346e5d0dc3d..c6f933ef1a250d 100644 --- a/third_party/xla/xla/service/call_graph.h +++ b/third_party/xla/xla/service/call_graph.h @@ -18,15 +18,25 @@ limitations under the License. 
#ifndef XLA_SERVICE_CALL_GRAPH_H_ #define XLA_SERVICE_CALL_GRAPH_H_ +#include #include #include +#include +#include +#include #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "absl/container/inlined_vector.h" #include "absl/functional/function_ref.h" +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "tsl/platform/logging.h" namespace xla { diff --git a/third_party/xla/xla/service/call_graph_test.cc b/third_party/xla/xla/service/call_graph_test.cc index e4127567010db2..dfa7d28f06ab1d 100644 --- a/third_party/xla/xla/service/call_graph_test.cc +++ b/third_party/xla/xla/service/call_graph_test.cc @@ -15,18 +15,32 @@ limitations under the License. #include "xla/service/call_graph.h" +#include +#include +#include +#include +#include #include +#include #include +#include "absl/algorithm/container.h" #include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "xla/comparison_util.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/utils/hlo_matchers.h" +#include "xla/literal_util.h" +#include "xla/shape.h" #include "xla/shape_util.h" #include "xla/tests/hlo_test_base.h" #include "xla/util.h" #include "xla/xla_data.pb.h" #include "tsl/lib/core/status_test_util.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" namespace xla { namespace { From 0ef3cc94c2d0e0ebf685c5d3f8374de95fa7950b Mon Sep 17 00:00:00 2001 From: Kyle Lucke Date: Fri, 21 Jun 2024 13:17:30 -0700 Subject: [PATCH 132/256] Fix TSAN issue in interpreter Stream. 
PiperOrigin-RevId: 645481601 --- .../xla/xla/backends/interpreter/executor.h | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/third_party/xla/xla/backends/interpreter/executor.h b/third_party/xla/xla/backends/interpreter/executor.h index 41c49f7bf1071b..d551ed0b158d93 100644 --- a/third_party/xla/xla/backends/interpreter/executor.h +++ b/third_party/xla/xla/backends/interpreter/executor.h @@ -59,6 +59,22 @@ class InterpreterStream : public host::HostStream { absl::Status RecordEvent(Event *event) override { return absl::UnimplementedError("Not implemented."); } + + absl::Status Memcpy(void *host_dst, const DeviceMemoryBase &gpu_src, + uint64_t size) override { + void *src_mem = const_cast(gpu_src.opaque()); + EnqueueTask( + [this, host_dst, src_mem, size]() { memcpy(host_dst, src_mem, size); }); + return BlockUntilDone(); + } + + absl::Status Memcpy(DeviceMemoryBase *gpu_dst, const void *host_src, + uint64_t size) override { + void *dst_mem = gpu_dst->opaque(); + EnqueueTask( + [this, dst_mem, host_src, size]() { memcpy(dst_mem, host_src, size); }); + return BlockUntilDone(); + } }; class XlaInterpreterExecutor : public StreamExecutorCommon { From 0d541c304b89b5e595c47d1cb1dda89c0ba9cb99 Mon Sep 17 00:00:00 2001 From: Kyle Lucke Date: Fri, 21 Jun 2024 13:30:19 -0700 Subject: [PATCH 133/256] Stop using xla/statusor.h now that it just contains an alias for absl::Status. In some situations, this meant also changing unrelated files to directly include tsl/platform/statusor.h to get the definitions for TF_ASSIGN_OR_RETURN, etc., where they were getting transitively included for free. 
PiperOrigin-RevId: 645485250 --- .../compiler/aot/embedded_protocol_buffers.cc | 1 + third_party/xla/xla/BUILD | 18 +++++++++--------- third_party/xla/xla/comparison_util.cc | 2 +- third_party/xla/xla/comparison_util.h | 2 +- third_party/xla/xla/literal_util.cc | 2 +- third_party/xla/xla/literal_util.h | 2 +- third_party/xla/xla/map_util.h | 2 +- third_party/xla/xla/packed_literal_reader.h | 2 +- third_party/xla/xla/primitive_util.h | 2 +- third_party/xla/xla/refcounting_hash_map.h | 2 +- third_party/xla/xla/shape_tree.h | 2 +- third_party/xla/xla/shape_util.h | 2 +- third_party/xla/xla/test_helpers.h | 2 +- third_party/xla/xla/text_literal_reader.h | 2 +- 14 files changed, 22 insertions(+), 21 deletions(-) diff --git a/tensorflow/compiler/aot/embedded_protocol_buffers.cc b/tensorflow/compiler/aot/embedded_protocol_buffers.cc index d14de99e55e6f9..5c664da9d64215 100644 --- a/tensorflow/compiler/aot/embedded_protocol_buffers.cc +++ b/tensorflow/compiler/aot/embedded_protocol_buffers.cc @@ -32,6 +32,7 @@ limitations under the License. 
#include "llvm/TargetParser/Triple.h" #include "xla/service/llvm_ir/llvm_type_conversion_util.h" #include "xla/util.h" +#include "tsl/platform/statusor.h" namespace tensorflow { namespace tfcompile { diff --git a/third_party/xla/xla/BUILD b/third_party/xla/xla/BUILD index 5b813d50f79dc2..80bf0ff68b7d94 100644 --- a/third_party/xla/xla/BUILD +++ b/third_party/xla/xla/BUILD @@ -143,12 +143,12 @@ cc_library( visibility = internal_visibility([":friends"]), deps = [ ":shape_util", - ":statusor", ":types", ":util", ":xla_data_proto_cc", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:logging", @@ -339,7 +339,6 @@ cc_library( deps = [ ":status", ":status_macros", - ":statusor", ":types", ":xla_data_proto_cc", "@com_google_absl//absl/algorithm:container", @@ -351,6 +350,7 @@ cc_library( "@com_google_absl//absl/log:check", "@com_google_absl//absl/memory", "@com_google_absl//absl/numeric:bits", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", @@ -455,7 +455,6 @@ cc_library( ":permutation_util", ":printer", ":status_macros", - ":statusor", ":types", ":util", ":xla_data_proto_cc", @@ -467,6 +466,7 @@ cc_library( "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", @@ -664,11 +664,11 @@ cc_library( ":literal", ":shape_util", ":status_macros", - ":statusor", ":types", ":util", ":xla_data_proto_cc", "@com_google_absl//absl/functional:function_ref", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", 
"@local_tsl//tsl/lib/core:bitmap", @@ -851,11 +851,11 @@ cc_library( ":literal", ":shape_util", ":status_macros", - ":statusor", ":types", ":util", ":xla_data_proto_cc", "@com_google_absl//absl/base", + "@com_google_absl//absl/status:statusor", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:protobuf", @@ -868,8 +868,8 @@ cc_library( hdrs = ["test_helpers.h"], visibility = internal_visibility([":friends"]), deps = [ - ":statusor", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@local_tsl//tsl/platform:test", ], ) @@ -883,11 +883,11 @@ cc_library( ":literal", ":shape_util", ":status_macros", - ":statusor", ":types", ":util", ":xla_data_proto_cc", "//xla/service:hlo_parser", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@local_tsl//tsl/lib/io:buffered_inputstream", "@local_tsl//tsl/lib/io:random_inputstream", @@ -952,10 +952,10 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":shape_util", - ":statusor", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/functional:function_ref", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/types:span", "@local_tsl//tsl/lib/gtl:iterator_range", "@local_tsl//tsl/platform:errors", @@ -1149,10 +1149,10 @@ cc_library( name = "refcounting_hash_map", hdrs = ["refcounting_hash_map.h"], deps = [ - ":statusor", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:node_hash_map", "@com_google_absl//absl/functional:function_ref", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/synchronization", ], ) diff --git a/third_party/xla/xla/comparison_util.cc b/third_party/xla/xla/comparison_util.cc index a0df27e07928d8..25f17e2c080684 100644 --- a/third_party/xla/xla/comparison_util.cc +++ b/third_party/xla/xla/comparison_util.cc @@ -19,10 +19,10 @@ limitations under the License. 
#include #include "absl/container/flat_hash_map.h" +#include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "xla/primitive_util.h" -#include "xla/statusor.h" #include "xla/util.h" #include "xla/xla_data.pb.h" #include "tsl/platform/logging.h" // IWYU pragma: keep diff --git a/third_party/xla/xla/comparison_util.h b/third_party/xla/xla/comparison_util.h index 3c72cac3655b3b..2fd4defb11b401 100644 --- a/third_party/xla/xla/comparison_util.h +++ b/third_party/xla/xla/comparison_util.h @@ -22,9 +22,9 @@ limitations under the License. #include #include +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "xla/primitive_util.h" -#include "xla/statusor.h" #include "xla/types.h" #include "xla/util.h" #include "xla/xla_data.pb.h" diff --git a/third_party/xla/xla/literal_util.cc b/third_party/xla/xla/literal_util.cc index 27d801a8038201..d878e98dbc0676 100644 --- a/third_party/xla/xla/literal_util.cc +++ b/third_party/xla/xla/literal_util.cc @@ -24,6 +24,7 @@ limitations under the License. #include #include +#include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "absl/strings/string_view.h" @@ -35,7 +36,6 @@ limitations under the License. #include "xla/primitive_util.h" #include "xla/shape.h" #include "xla/shape_util.h" -#include "xla/statusor.h" #include "xla/types.h" #include "xla/util.h" #include "xla/xla_data.pb.h" diff --git a/third_party/xla/xla/literal_util.h b/third_party/xla/xla/literal_util.h index fad4b5ed9f2826..24cccb58438c17 100644 --- a/third_party/xla/xla/literal_util.h +++ b/third_party/xla/xla/literal_util.h @@ -30,6 +30,7 @@ limitations under the License. #include #include "absl/functional/function_ref.h" +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/array.h" @@ -43,7 +44,6 @@ limitations under the License. 
#include "xla/shape.h" #include "xla/shape_util.h" #include "xla/status_macros.h" -#include "xla/statusor.h" #include "xla/xla_data.pb.h" #include "tsl/lib/core/bitmap.h" #include "tsl/platform/errors.h" diff --git a/third_party/xla/xla/map_util.h b/third_party/xla/xla/map_util.h index f25caf0874d408..80f34992f506d9 100644 --- a/third_party/xla/xla/map_util.h +++ b/third_party/xla/xla/map_util.h @@ -21,7 +21,7 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" -#include "xla/statusor.h" +#include "absl/status/statusor.h" #include "xla/util.h" #include "tsl/platform/logging.h" diff --git a/third_party/xla/xla/packed_literal_reader.h b/third_party/xla/xla/packed_literal_reader.h index 9103e7544ca57d..1b9b14a0c93c8d 100644 --- a/third_party/xla/xla/packed_literal_reader.h +++ b/third_party/xla/xla/packed_literal_reader.h @@ -18,8 +18,8 @@ limitations under the License. #include +#include "absl/status/statusor.h" #include "xla/literal.h" -#include "xla/statusor.h" #include "xla/types.h" #include "xla/xla_data.pb.h" #include "tsl/platform/env.h" diff --git a/third_party/xla/xla/primitive_util.h b/third_party/xla/xla/primitive_util.h index 1f400fc7a89d60..8fbeedbff94dad 100644 --- a/third_party/xla/xla/primitive_util.h +++ b/third_party/xla/xla/primitive_util.h @@ -29,9 +29,9 @@ limitations under the License. #include "absl/base/attributes.h" #include "absl/base/optimization.h" +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" -#include "xla/statusor.h" #include "xla/types.h" #include "xla/util.h" #include "xla/xla_data.pb.h" diff --git a/third_party/xla/xla/refcounting_hash_map.h b/third_party/xla/xla/refcounting_hash_map.h index 3f7e3280245312..68520a636cfcbb 100644 --- a/third_party/xla/xla/refcounting_hash_map.h +++ b/third_party/xla/xla/refcounting_hash_map.h @@ -22,8 +22,8 @@ limitations under the License. 
#include "absl/base/thread_annotations.h" #include "absl/container/node_hash_map.h" #include "absl/functional/function_ref.h" +#include "absl/status/statusor.h" #include "absl/synchronization/mutex.h" -#include "xla/statusor.h" namespace xla { diff --git a/third_party/xla/xla/shape_tree.h b/third_party/xla/xla/shape_tree.h index 8fe712cfe2a1e8..ba4e13560fd2c3 100644 --- a/third_party/xla/xla/shape_tree.h +++ b/third_party/xla/xla/shape_tree.h @@ -27,10 +27,10 @@ limitations under the License. #include "absl/container/inlined_vector.h" #include "absl/functional/function_ref.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/types/span.h" #include "xla/shape.h" #include "xla/shape_util.h" -#include "xla/statusor.h" #include "tsl/lib/gtl/iterator_range.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" // IWYU pragma: keep diff --git a/third_party/xla/xla/shape_util.h b/third_party/xla/xla/shape_util.h index c87931674ee1ac..72abbea5325f70 100644 --- a/third_party/xla/xla/shape_util.h +++ b/third_party/xla/xla/shape_util.h @@ -36,6 +36,7 @@ limitations under the License. #include "absl/functional/function_ref.h" #include "absl/log/check.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/types/span.h" #include "xla/layout.h" #include "xla/layout_util.h" @@ -43,7 +44,6 @@ limitations under the License. #include "xla/primitive_util.h" #include "xla/printer.h" #include "xla/shape.h" -#include "xla/statusor.h" #include "xla/util.h" #include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" diff --git a/third_party/xla/xla/test_helpers.h b/third_party/xla/xla/test_helpers.h index abcf4e438d1208..bc0a054626b497 100644 --- a/third_party/xla/xla/test_helpers.h +++ b/third_party/xla/xla/test_helpers.h @@ -17,7 +17,7 @@ limitations under the License. 
#define XLA_TEST_HELPERS_H_ #include "absl/status/status.h" -#include "xla/statusor.h" +#include "absl/status/statusor.h" #include "tsl/platform/test.h" // This module contains a minimal subset of gmock functionality just diff --git a/third_party/xla/xla/text_literal_reader.h b/third_party/xla/xla/text_literal_reader.h index a0d56611bac030..397229e74d81cf 100644 --- a/third_party/xla/xla/text_literal_reader.h +++ b/third_party/xla/xla/text_literal_reader.h @@ -18,9 +18,9 @@ limitations under the License. #include +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "xla/literal.h" -#include "xla/statusor.h" #include "xla/types.h" #include "xla/xla_data.pb.h" #include "tsl/platform/env.h" From 6688a28e46c9ea31d6a44671276121a1773a620d Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 21 Jun 2024 13:45:26 -0700 Subject: [PATCH 134/256] Fix a typo and add corresponding tests. The typo essentially reversed the semantics of the boolean parameter allow_shardings_small_dims_across_many_devices added in cl/644531507. 
PiperOrigin-RevId: 645489632 --- .../auto_sharding/auto_sharding_strategy.cc | 2 +- .../auto_sharding/auto_sharding_test.cc | 49 +++++++++++++++++++ 2 files changed, 50 insertions(+), 1 deletion(-) diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc index 1cc88b3da7fd00..c1af8a87e5e376 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc @@ -921,7 +921,7 @@ BuildStrategyAndCost( } } } - if (option.allow_shardings_small_dims_across_many_devices) { + if (!option.allow_shardings_small_dims_across_many_devices) { RemoveShardingsWhereSmallDimsShardedAcrossManyDevices( ins->shape(), strategy_group.get(), /* instruction_has_user_sharding */ ins->has_sharding()); diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_test.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_test.cc index 4991a67d064ab7..3a231c51a31ca1 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_test.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_test.cc @@ -410,6 +410,55 @@ ENTRY %elementwise { EXPECT_TRUE(changed); } +TEST_F(AutoShardingTest, AllowShardingsSmallDimsAcrossManyDevicesTest) { + constexpr absl::string_view kHloString = R"( +HloModule module + +ENTRY %elementwise { + parameter.1 = bf16[8,1024]{1,0} parameter(0), sharding={devices=[16,16]<=[256]} + add.1 = bf16[8,1024]{1,0} add(parameter.1, parameter.1) + ROOT copy.45 = bf16[8,1024]{1,0} copy(add.1) +})"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloString)); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, + AutoSharding( + /* option */ AutoShardingOption{ + .enable = true, + .preserve_shardings = + AutoShardingOption::PreserveShardingsType::kKeepAllShardings, + 
.device_mesh_shape = {128, 1}, + .device_mesh_alpha = {1.0, 1.0}, + .device_mesh_beta = {0.01, 1.0}, + .allow_shardings_small_dims_across_many_devices = true}) + .Run(module.get())); + VLOG(10) << module->ToString(); + EXPECT_TRUE(changed); + const HloInstruction* add1 = FindInstruction(module.get(), "add.1"); + EXPECT_THAT(add1, op::Sharding("{devices=[16,16]<=[256]}")); + + // Test with allow_shardings_small_dims_across_many_devices = False + TF_ASSERT_OK_AND_ASSIGN(module, ParseAndReturnVerifiedModule(kHloString)); + TF_ASSERT_OK_AND_ASSIGN( + changed, + AutoSharding( + /* option */ AutoShardingOption{ + .enable = true, + .preserve_shardings = + AutoShardingOption::PreserveShardingsType::kKeepAllShardings, + .device_mesh_shape = {128, 1}, + .device_mesh_alpha = {1.0, 1.0}, + .device_mesh_beta = {0.01, 1.0}, + .allow_shardings_small_dims_across_many_devices = false}) + .Run(module.get())); + VLOG(10) << module->ToString(); + EXPECT_TRUE(changed); + add1 = FindInstruction(module.get(), "add.1"); + EXPECT_THAT(add1, Not(op::Sharding("{devices=[16,16]<=[256]}"))); +} + TEST_F(AutoShardingTest, RngBitGeneratorArrayInput) { constexpr absl::string_view kHloString = R"( HloModule rng_bit_generator From 525d163fb6df00e5d946076bf6c7ffbbe6d977d7 Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Fri, 21 Jun 2024 13:47:48 -0700 Subject: [PATCH 135/256] Integrate LLVM at llvm/llvm-project@c07be08df573 Updates LLVM usage to match [c07be08df573](https://github.com/llvm/llvm-project/commit/c07be08df573) PiperOrigin-RevId: 645490268 --- tensorflow/compiler/mlir/lite/BUILD | 2 + .../compiler/mlir/lite/flatbuffer_import.cc | 2 + .../tests/legalize-stablehlo-vhlo.mlir | 4 +- .../transforms/legalize_stablehlo_to_vhlo.cc | 18 +-- .../mlir/lite/tf_to_tfl_flatbuffer.cc | 3 + .../compiler/mlir/tensorflow/ir/tf_ops_n_z.cc | 36 ++--- .../utils/side_effect_analysis_util.cc | 22 ++- .../utils/side_effect_analysis_util.h | 4 +- third_party/llvm/workspace.bzl | 4 +- .../triton/llvm_integration/cl644752548.patch | 144 ++++++++++++++++++ .../triton/llvm_integration/series.bzl | 1 + .../triton/llvm_integration/cl644752548.patch | 144 ++++++++++++++++++ .../triton/llvm_integration/series.bzl | 1 + 13 files changed, 339 insertions(+), 46 deletions(-) create mode 100644 third_party/triton/llvm_integration/cl644752548.patch create mode 100644 third_party/xla/third_party/triton/llvm_integration/cl644752548.patch diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD index a6187e92719888..e1e4712276fe28 100644 --- a/tensorflow/compiler/mlir/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/BUILD @@ -1178,6 +1178,7 @@ cc_library( "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", "@llvm-project//mlir:QuantOps", + "@llvm-project//mlir:ReconcileUnrealizedCasts", "@llvm-project//mlir:Support", "@llvm-project//mlir:TranslateLib", "@local_tsl//tsl/platform:status", @@ -1459,6 +1460,7 @@ cc_library( "@llvm-project//mlir:IR", "@llvm-project//mlir:Parser", "@llvm-project//mlir:Pass", + "@llvm-project//mlir:ReconcileUnrealizedCasts", "@llvm-project//mlir:Support", "@local_tsl//tsl/platform:protobuf", "@local_tsl//tsl/platform:statusor", diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_import.cc 
b/tensorflow/compiler/mlir/lite/flatbuffer_import.cc index 2adeb5727cd15f..58f0deee7c95e9 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_import.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_import.cc @@ -44,6 +44,7 @@ limitations under the License. #include "llvm/Analysis/AssumeBundleQueries.h" #include "llvm/Support/Casting.h" #include "llvm/Support/FormatVariadic.h" +#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h" // from @llvm-project #include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/Dialect/Func/Extensions/AllExtensions.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project @@ -1630,6 +1631,7 @@ OwningOpRef tflite::FlatBufferToMlir( if (!disable_vhlo_to_stablehlo) { mlir::PassManager pass_manager(module.getContext()); pass_manager.addPass(mlir::odml::createLegalizeVhloToStablehloPass()); + pass_manager.addPass(mlir::createReconcileUnrealizedCastsPass()); auto result = pass_manager.run(module); if (failed(result)) { return nullptr; diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-vhlo.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-vhlo.mlir index 9b811bc823da43..2391c2e94a2b8c 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-vhlo.mlir +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-vhlo.mlir @@ -1,5 +1,5 @@ -// RUN: odml-to-stablehlo-opt %s --stablehlo-legalize-vhlo -split-input-file | FileCheck %s -// RUN: odml-to-stablehlo-opt --stablehlo-legalize-vhlo %s | odml-to-stablehlo-opt --vhlo-legalize-stablehlo > %t.0 +// RUN: odml-to-stablehlo-opt %s --stablehlo-legalize-vhlo -reconcile-unrealized-casts -split-input-file | FileCheck %s +// RUN: odml-to-stablehlo-opt --stablehlo-legalize-vhlo -reconcile-unrealized-casts %s | odml-to-stablehlo-opt --vhlo-legalize-stablehlo -reconcile-unrealized-casts > %t.0 // RUN: odml-to-stablehlo-opt %s > %t.1 // RUN: diff 
%t.0 %t.1 diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_stablehlo_to_vhlo.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_stablehlo_to_vhlo.cc index b3a85259cc482a..be2411bffb167b 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_stablehlo_to_vhlo.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_stablehlo_to_vhlo.cc @@ -257,18 +257,6 @@ LogicalResult ApplyVhloToStablehloPatterns(ModuleOp module) { return success(); } -LogicalResult ApplyUnrealizedCastCanonicalization(ModuleOp module) { - MLIRContext *context = module->getContext(); - RewritePatternSet patterns(context); - ConversionTarget target(*context); - target.addIllegalOp(); - populateReconcileUnrealizedCastsPatterns(patterns); - if (failed(applyPartialConversion(module, target, std::move(patterns)))) { - return module->emitError("Failed to fold unrealized cast"); - } - return success(); -} - } // namespace struct LegalizeStablehloToVhloPass @@ -286,8 +274,7 @@ struct LegalizeStablehloToVhloPass if (failed(ApplyStablehloToVhloPatterns(module, /*is_func_legal=*/true)) || failed(ApplyVhloToVersionPatterns(module, target_version)) || - failed(ApplyTypeConverter(module, to_builtin_converter)) || - failed(ApplyUnrealizedCastCanonicalization(module))) + failed(ApplyTypeConverter(module, to_builtin_converter))) return signalPassFailure(); } }; @@ -308,8 +295,7 @@ struct LegalizeVhloToStablehloPass if (failed(ApplyTypeConverter(module, to_vhlo_converter)) || failed(ApplyVhloToVersionPatterns(module, stablehlo::getCurrentVersion())) || - failed(ApplyVhloToStablehloPatterns(module)) || - failed(ApplyUnrealizedCastCanonicalization(module))) + failed(ApplyVhloToStablehloPatterns(module))) return signalPassFailure(); } }; diff --git a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc index 0468d93b0c2438..2edeb71430e4a0 100644 --- 
a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc +++ b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc @@ -33,6 +33,7 @@ limitations under the License. #include "flatbuffers/flatbuffer_builder.h" // from @flatbuffers #include "llvm/Support/Casting.h" #include "llvm/Support/raw_ostream.h" +#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h" // from @llvm-project #include "mlir/Dialect/Func/Extensions/AllExtensions.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/Diagnostics.h" // from @llvm-project @@ -358,6 +359,7 @@ absl::Status ConvertTFExecutorToStablehloFlatbuffer( } pass_manager.clear(); pass_manager.addPass(mlir::odml::createLegalizeStablehloToVhloPass()); + pass_manager.addPass(mlir::createReconcileUnrealizedCastsPass()); if (failed(pass_manager.run(module))) { return status_handler.Combine( absl::InvalidArgumentError("VHLO lowering failed")); @@ -505,6 +507,7 @@ absl::Status ConvertTFExecutorToTFLOrFlatbuffer( } pass_manager.clear(); pass_manager.addPass(mlir::odml::createLegalizeStablehloToVhloPass()); + pass_manager.addPass(mlir::createReconcileUnrealizedCastsPass()); if (failed(pass_manager.run(module))) { return status_handler.Combine( absl::InvalidArgumentError("VHLO lowering failed")); diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc index c0528f35bd11fe..18ca601fbb6f39 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc @@ -2374,8 +2374,8 @@ void TPUExecuteOp::getEffects( // effects on resources. For the MLIR bridge, this op will never be // populated with resource handles and tf.TPUExecuteAndUpdateVariables is // used instead. 
- for (Value value : getArgs()) { - MarkResourceAsReadAndWrite(value, effects); + for (OpOperand &op_operand : getArgsMutable()) { + MarkResourceAsReadAndWrite(op_operand, effects); } } @@ -2393,8 +2393,8 @@ void _XlaRunOp::getEffects( // Conservatively mark resource handles as read and write, as without // analyzing _XlaCompile, there is not sufficient information to determine // effects on resources. - for (Value value : getArgs()) { - MarkResourceAsReadAndWrite(value, effects); + for (OpOperand &op_operand : getArgsMutable()) { + MarkResourceAsReadAndWrite(op_operand, effects); } } @@ -2455,22 +2455,24 @@ void TPUExecuteAndUpdateVariablesOp::getEffects( effects.reserve(getDeviceVarReadsIndices().size() + 1); effects.emplace_back(MemoryEffects::Write::get(), ResourceEffects::TPUExecute::get()); - auto resource_handles = llvm::make_filter_range(getArgs(), [](Value value) { - return value.getType() - .cast() - .getElementType() - .isa(); - }); - - for (const auto &entry : llvm::enumerate(resource_handles)) { - Value value = entry.value(); - effects.emplace_back(MemoryEffects::Read::get(), value, + auto resource_handles = + llvm::make_filter_range(getArgsMutable(), [](OpOperand &op_operand) { + return op_operand.get() + .getType() + .cast() + .getElementType() + .isa(); + }); + + for (const auto& entry : llvm::enumerate(resource_handles)) { + OpOperand &op_operand = entry.value(); + effects.emplace_back(MemoryEffects::Read::get(), &op_operand, ResourceEffects::Variable::get()); if (getDeviceVarUpdatesIndices() .getValue()[entry.index()] .cast() .getInt() >= 0) - effects.emplace_back(MemoryEffects::Write::get(), value, + effects.emplace_back(MemoryEffects::Write::get(), &op_operand, ResourceEffects::Variable::get()); } } @@ -3070,8 +3072,8 @@ void XlaLaunchOp::getEffects( // Conservatively mark resource handles as read and write, as without // analyzing XlaLaunch, there is not sufficient information to determine // effects on resources. 
- for (Value value : getArgs()) { - MarkResourceAsReadAndWrite(value, effects); + for (OpOperand &op_operand : getArgsMutable()) { + MarkResourceAsReadAndWrite(op_operand, effects); } } diff --git a/tensorflow/compiler/mlir/tensorflow/utils/side_effect_analysis_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/side_effect_analysis_util.cc index 7a6da9fcbd04d2..9d4305b8e033f4 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/side_effect_analysis_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/side_effect_analysis_util.cc @@ -38,23 +38,31 @@ std::string GetDeviceAttrAsResourceInstanceStr(mlir::Operation* op) { } void MarkResourceAsReadAndWrite( - Value value, + OpOperand& op_operand, SmallVectorImpl>& effects) { - if (value.getType().cast().getElementType().isa()) { - effects.emplace_back(MemoryEffects::Read::get(), value, + if (op_operand.get() + .getType() + .cast() + .getElementType() + .isa()) { + effects.emplace_back(MemoryEffects::Read::get(), &op_operand, ResourceEffects::Variable::get()); - effects.emplace_back(MemoryEffects::Write::get(), value, + effects.emplace_back(MemoryEffects::Write::get(), &op_operand, ResourceEffects::Variable::get()); } } void MarkResourceAsReadOnly( - Value value, + OpOperand& op_operand, SmallVectorImpl>& effects) { - if (value.getType().cast().getElementType().isa()) { - effects.emplace_back(MemoryEffects::Read::get(), value, + if (op_operand.get() + .getType() + .cast() + .getElementType() + .isa()) { + effects.emplace_back(MemoryEffects::Read::get(), &op_operand, ResourceEffects::Variable::get()); } } diff --git a/tensorflow/compiler/mlir/tensorflow/utils/side_effect_analysis_util.h b/tensorflow/compiler/mlir/tensorflow/utils/side_effect_analysis_util.h index c55ad530f15962..102dd7008f00f9 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/side_effect_analysis_util.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/side_effect_analysis_util.h @@ -29,12 +29,12 @@ namespace TF { std::string 
GetDeviceAttrAsResourceInstanceStr(Operation* op); void MarkResourceAsReadAndWrite( - Value value, + OpOperand& op_operand, SmallVectorImpl>& effect); void MarkResourceAsReadOnly( - Value value, + OpOperand& op_operand, SmallVectorImpl>& effect); diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl index 18af3c86dfba84..7d5e50e6a89f12 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" - LLVM_COMMIT = "e5b0c210cc4cdaae7075ad2d4aa1efe4eb4cb0c5" - LLVM_SHA256 = "40440422a7e5d0fec35d6b542f4aa5e73af304b029e59dc5516c994696086a70" + LLVM_COMMIT = "c07be08df5731dac0b36e029a0dd03ccb099deea" + LLVM_SHA256 = "9f18d8c90a81c966f819bbfd911baed7fc67e019f6b5de1af4bcdd6bd1fb87bf" tf_http_archive( name = name, diff --git a/third_party/triton/llvm_integration/cl644752548.patch b/third_party/triton/llvm_integration/cl644752548.patch new file mode 100644 index 00000000000000..f8d745f44a7056 --- /dev/null +++ b/third_party/triton/llvm_integration/cl644752548.patch @@ -0,0 +1,144 @@ +==== triton/lib/Dialect/Triton/IR/Ops.cpp#34 - /google/src/cloud/joelwee/mlir_2c1ae801e1b66a09a15028ae4ba614e0911eec00_1718810061/triton/lib/Dialect/Triton/IR/Ops.cpp ==== +# action=edit type=text +--- triton/lib/Dialect/Triton/IR/Ops.cpp 2024-06-18 01:09:34.000000000 -0700 ++++ triton/lib/Dialect/Triton/IR/Ops.cpp 2024-06-19 09:46:57.000000000 -0700 +@@ -15,7 +15,7 @@ + void LoadOp::getEffects( + SmallVectorImpl> + &effects) { +- effects.emplace_back(MemoryEffects::Read::get(), getPtr(), ++ effects.emplace_back(MemoryEffects::Read::get(), &getPtrMutable(), + triton::GlobalMemory::get()); + if (getIsVolatile()) + effects.emplace_back(MemoryEffects::Write::get(), +==== triton/lib/Dialect/TritonGPU/IR/Dialect.cpp#51 - /google/src/cloud/joelwee/mlir_2c1ae801e1b66a09a15028ae4ba614e0911eec00_1718810061/triton/lib/Dialect/TritonGPU/IR/Dialect.cpp ==== +# 
action=edit type=text +--- triton/lib/Dialect/TritonGPU/IR/Dialect.cpp 2024-06-19 07:20:43.000000000 -0700 ++++ triton/lib/Dialect/TritonGPU/IR/Dialect.cpp 2024-06-19 09:48:52.000000000 -0700 +@@ -2993,7 +2993,8 @@ + effects.emplace_back(MemoryEffects::Allocate::get(), + mlir::triton::gpu::SharedMemory::get()); + if (getSrc()) +- effects.emplace_back(MemoryEffects::Write::get(), getResult(), ++ effects.emplace_back(MemoryEffects::Write::get(), ++ getOperation()->getOpResult(0), + mlir::triton::gpu::SharedMemory::get()); + } + +@@ -3001,7 +3002,7 @@ + void LocalLoadOp::getEffects( + SmallVectorImpl> + &effects) { +- effects.emplace_back(MemoryEffects::Read::get(), getSrc(), ++ effects.emplace_back(MemoryEffects::Read::get(), &getSrcMutable(), + mlir::triton::gpu::SharedMemory::get()); + } + +@@ -3009,7 +3010,7 @@ + void LocalStoreOp::getEffects( + SmallVectorImpl> + &effects) { +- effects.emplace_back(MemoryEffects::Write::get(), getDst(), ++ effects.emplace_back(MemoryEffects::Write::get(), &getDstMutable(), + mlir::triton::gpu::SharedMemory::get()); + } + +@@ -3017,9 +3018,9 @@ + void AsyncCopyGlobalToLocalOp::getEffects( + SmallVectorImpl> + &effects) { +- effects.emplace_back(MemoryEffects::Read::get(), getSrc(), ++ effects.emplace_back(MemoryEffects::Read::get(), &getSrcMutable(), + mlir::triton::GlobalMemory::get()); +- effects.emplace_back(MemoryEffects::Write::get(), getResult(), ++ effects.emplace_back(MemoryEffects::Write::get(), &getResultMutable(), + mlir::triton::gpu::SharedMemory::get()); + } + +==== triton/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp#8 - /google/src/cloud/joelwee/mlir_2c1ae801e1b66a09a15028ae4ba614e0911eec00_1718810061/triton/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp ==== +# action=edit type=text +--- triton/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp 2024-06-07 10:43:52.000000000 -0700 ++++ triton/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp 2024-06-19 09:50:49.000000000 -0700 +@@ -60,13 +60,13 @@ + void WarpGroupDotOp::getEffects( + SmallVectorImpl> + 
&effects) { +- auto a = getA(); +- auto b = getB(); +- if (isa(a.getType())) +- effects.emplace_back(MemoryEffects::Read::get(), a, ++ auto& a = getAMutable(); ++ auto& b = getBMutable(); ++ if (isa(a.get().getType())) ++ effects.emplace_back(MemoryEffects::Read::get(), &a, + mlir::triton::gpu::SharedMemory::get()); +- if (isa(b.getType())) +- effects.emplace_back(MemoryEffects::Read::get(), b, ++ if (isa(b.get().getType())) ++ effects.emplace_back(MemoryEffects::Read::get(), &b, + mlir::triton::gpu::SharedMemory::get()); + } + +@@ -99,7 +99,7 @@ + void InitBarrierOp::getEffects( + SmallVectorImpl> + &effects) { +- effects.emplace_back(MemoryEffects::Write::get(), getAlloc(), ++ effects.emplace_back(MemoryEffects::Write::get(), &getAllocMutable(), + mlir::triton::gpu::SharedMemory::get()); + } + +@@ -113,7 +113,7 @@ + void InvalBarrierOp::getEffects( + SmallVectorImpl> + &effects) { +- effects.emplace_back(MemoryEffects::Write::get(), getAlloc(), ++ effects.emplace_back(MemoryEffects::Write::get(), &getAllocMutable(), + mlir::triton::gpu::SharedMemory::get()); + } + +@@ -127,7 +127,7 @@ + void BarrierExpectOp::getEffects( + SmallVectorImpl> + &effects) { +- effects.emplace_back(MemoryEffects::Write::get(), getAlloc(), ++ effects.emplace_back(MemoryEffects::Write::get(), &getAllocMutable(), + mlir::triton::gpu::SharedMemory::get()); + } + +@@ -141,7 +141,7 @@ + void WaitBarrierOp::getEffects( + SmallVectorImpl> + &effects) { +- effects.emplace_back(MemoryEffects::Read::get(), getAlloc(), ++ effects.emplace_back(MemoryEffects::Read::get(), &getAllocMutable(), + mlir::triton::gpu::SharedMemory::get()); + // Need a side effect to prevent compiler from reordering and removing + // the wait operation. 
+@@ -161,11 +161,11 @@ + void AsyncTMACopyGlobalToLocalOp::getEffects( + SmallVectorImpl> + &effects) { +- effects.emplace_back(MemoryEffects::Read::get(), getDescPtr(), ++ effects.emplace_back(MemoryEffects::Read::get(), &getDescPtrMutable(), + mlir::triton::GlobalMemory::get()); +- effects.emplace_back(MemoryEffects::Write::get(), getBarrier(), ++ effects.emplace_back(MemoryEffects::Write::get(), &getBarrierMutable(), + mlir::triton::gpu::SharedMemory::get()); +- effects.emplace_back(MemoryEffects::Write::get(), getResult(), ++ effects.emplace_back(MemoryEffects::Write::get(), &getResultMutable(), + mlir::triton::gpu::SharedMemory::get()); + } + +@@ -173,9 +173,9 @@ + void AsyncTMACopyLocalToGlobalOp::getEffects( + SmallVectorImpl> + &effects) { +- effects.emplace_back(MemoryEffects::Write::get(), getDescPtr(), ++ effects.emplace_back(MemoryEffects::Write::get(), &getDescPtrMutable(), + mlir::triton::GlobalMemory::get()); +- effects.emplace_back(MemoryEffects::Read::get(), getSrc(), ++ effects.emplace_back(MemoryEffects::Read::get(), &getSrcMutable(), + mlir::triton::gpu::SharedMemory::get()); + } + diff --git a/third_party/triton/llvm_integration/series.bzl b/third_party/triton/llvm_integration/series.bzl index 9d97061753ae0a..09ff0934c79a69 100644 --- a/third_party/triton/llvm_integration/series.bzl +++ b/third_party/triton/llvm_integration/series.bzl @@ -5,4 +5,5 @@ These should be upstreamed to openai/triton as part of the next triton integrati """ llvm_patch_list = [ + "//third_party/triton/llvm_integration:cl644752548.patch", ] diff --git a/third_party/xla/third_party/triton/llvm_integration/cl644752548.patch b/third_party/xla/third_party/triton/llvm_integration/cl644752548.patch new file mode 100644 index 00000000000000..f8d745f44a7056 --- /dev/null +++ b/third_party/xla/third_party/triton/llvm_integration/cl644752548.patch @@ -0,0 +1,144 @@ +==== triton/lib/Dialect/Triton/IR/Ops.cpp#34 - 
/google/src/cloud/joelwee/mlir_2c1ae801e1b66a09a15028ae4ba614e0911eec00_1718810061/triton/lib/Dialect/Triton/IR/Ops.cpp ==== +# action=edit type=text +--- triton/lib/Dialect/Triton/IR/Ops.cpp 2024-06-18 01:09:34.000000000 -0700 ++++ triton/lib/Dialect/Triton/IR/Ops.cpp 2024-06-19 09:46:57.000000000 -0700 +@@ -15,7 +15,7 @@ + void LoadOp::getEffects( + SmallVectorImpl> + &effects) { +- effects.emplace_back(MemoryEffects::Read::get(), getPtr(), ++ effects.emplace_back(MemoryEffects::Read::get(), &getPtrMutable(), + triton::GlobalMemory::get()); + if (getIsVolatile()) + effects.emplace_back(MemoryEffects::Write::get(), +==== triton/lib/Dialect/TritonGPU/IR/Dialect.cpp#51 - /google/src/cloud/joelwee/mlir_2c1ae801e1b66a09a15028ae4ba614e0911eec00_1718810061/triton/lib/Dialect/TritonGPU/IR/Dialect.cpp ==== +# action=edit type=text +--- triton/lib/Dialect/TritonGPU/IR/Dialect.cpp 2024-06-19 07:20:43.000000000 -0700 ++++ triton/lib/Dialect/TritonGPU/IR/Dialect.cpp 2024-06-19 09:48:52.000000000 -0700 +@@ -2993,7 +2993,8 @@ + effects.emplace_back(MemoryEffects::Allocate::get(), + mlir::triton::gpu::SharedMemory::get()); + if (getSrc()) +- effects.emplace_back(MemoryEffects::Write::get(), getResult(), ++ effects.emplace_back(MemoryEffects::Write::get(), ++ getOperation()->getOpResult(0), + mlir::triton::gpu::SharedMemory::get()); + } + +@@ -3001,7 +3002,7 @@ + void LocalLoadOp::getEffects( + SmallVectorImpl> + &effects) { +- effects.emplace_back(MemoryEffects::Read::get(), getSrc(), ++ effects.emplace_back(MemoryEffects::Read::get(), &getSrcMutable(), + mlir::triton::gpu::SharedMemory::get()); + } + +@@ -3009,7 +3010,7 @@ + void LocalStoreOp::getEffects( + SmallVectorImpl> + &effects) { +- effects.emplace_back(MemoryEffects::Write::get(), getDst(), ++ effects.emplace_back(MemoryEffects::Write::get(), &getDstMutable(), + mlir::triton::gpu::SharedMemory::get()); + } + +@@ -3017,9 +3018,9 @@ + void AsyncCopyGlobalToLocalOp::getEffects( + SmallVectorImpl> + &effects) { +- 
effects.emplace_back(MemoryEffects::Read::get(), getSrc(), ++ effects.emplace_back(MemoryEffects::Read::get(), &getSrcMutable(), + mlir::triton::GlobalMemory::get()); +- effects.emplace_back(MemoryEffects::Write::get(), getResult(), ++ effects.emplace_back(MemoryEffects::Write::get(), &getResultMutable(), + mlir::triton::gpu::SharedMemory::get()); + } + +==== triton/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp#8 - /google/src/cloud/joelwee/mlir_2c1ae801e1b66a09a15028ae4ba614e0911eec00_1718810061/triton/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp ==== +# action=edit type=text +--- triton/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp 2024-06-07 10:43:52.000000000 -0700 ++++ triton/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp 2024-06-19 09:50:49.000000000 -0700 +@@ -60,13 +60,13 @@ + void WarpGroupDotOp::getEffects( + SmallVectorImpl> + &effects) { +- auto a = getA(); +- auto b = getB(); +- if (isa(a.getType())) +- effects.emplace_back(MemoryEffects::Read::get(), a, ++ auto& a = getAMutable(); ++ auto& b = getBMutable(); ++ if (isa(a.get().getType())) ++ effects.emplace_back(MemoryEffects::Read::get(), &a, + mlir::triton::gpu::SharedMemory::get()); +- if (isa(b.getType())) +- effects.emplace_back(MemoryEffects::Read::get(), b, ++ if (isa(b.get().getType())) ++ effects.emplace_back(MemoryEffects::Read::get(), &b, + mlir::triton::gpu::SharedMemory::get()); + } + +@@ -99,7 +99,7 @@ + void InitBarrierOp::getEffects( + SmallVectorImpl> + &effects) { +- effects.emplace_back(MemoryEffects::Write::get(), getAlloc(), ++ effects.emplace_back(MemoryEffects::Write::get(), &getAllocMutable(), + mlir::triton::gpu::SharedMemory::get()); + } + +@@ -113,7 +113,7 @@ + void InvalBarrierOp::getEffects( + SmallVectorImpl> + &effects) { +- effects.emplace_back(MemoryEffects::Write::get(), getAlloc(), ++ effects.emplace_back(MemoryEffects::Write::get(), &getAllocMutable(), + mlir::triton::gpu::SharedMemory::get()); + } + +@@ -127,7 +127,7 @@ + void BarrierExpectOp::getEffects( + SmallVectorImpl> + &effects) { +- 
effects.emplace_back(MemoryEffects::Write::get(), getAlloc(), ++ effects.emplace_back(MemoryEffects::Write::get(), &getAllocMutable(), + mlir::triton::gpu::SharedMemory::get()); + } + +@@ -141,7 +141,7 @@ + void WaitBarrierOp::getEffects( + SmallVectorImpl> + &effects) { +- effects.emplace_back(MemoryEffects::Read::get(), getAlloc(), ++ effects.emplace_back(MemoryEffects::Read::get(), &getAllocMutable(), + mlir::triton::gpu::SharedMemory::get()); + // Need a side effect to prevent compiler from reordering and removing + // the wait operation. +@@ -161,11 +161,11 @@ + void AsyncTMACopyGlobalToLocalOp::getEffects( + SmallVectorImpl> + &effects) { +- effects.emplace_back(MemoryEffects::Read::get(), getDescPtr(), ++ effects.emplace_back(MemoryEffects::Read::get(), &getDescPtrMutable(), + mlir::triton::GlobalMemory::get()); +- effects.emplace_back(MemoryEffects::Write::get(), getBarrier(), ++ effects.emplace_back(MemoryEffects::Write::get(), &getBarrierMutable(), + mlir::triton::gpu::SharedMemory::get()); +- effects.emplace_back(MemoryEffects::Write::get(), getResult(), ++ effects.emplace_back(MemoryEffects::Write::get(), &getResultMutable(), + mlir::triton::gpu::SharedMemory::get()); + } + +@@ -173,9 +173,9 @@ + void AsyncTMACopyLocalToGlobalOp::getEffects( + SmallVectorImpl> + &effects) { +- effects.emplace_back(MemoryEffects::Write::get(), getDescPtr(), ++ effects.emplace_back(MemoryEffects::Write::get(), &getDescPtrMutable(), + mlir::triton::GlobalMemory::get()); +- effects.emplace_back(MemoryEffects::Read::get(), getSrc(), ++ effects.emplace_back(MemoryEffects::Read::get(), &getSrcMutable(), + mlir::triton::gpu::SharedMemory::get()); + } + diff --git a/third_party/xla/third_party/triton/llvm_integration/series.bzl b/third_party/xla/third_party/triton/llvm_integration/series.bzl index 9d97061753ae0a..09ff0934c79a69 100644 --- a/third_party/xla/third_party/triton/llvm_integration/series.bzl +++ b/third_party/xla/third_party/triton/llvm_integration/series.bzl @@ -5,4 
+5,5 @@ These should be upstreamed to openai/triton as part of the next triton integrati """ llvm_patch_list = [ + "//third_party/triton/llvm_integration:cl644752548.patch", ] From a76a162af4c24f11f035e7cf9ed39d45d51ed5bd Mon Sep 17 00:00:00 2001 From: Laura Pak Date: Fri, 21 Jun 2024 13:48:01 -0700 Subject: [PATCH 136/256] Move `::mlir::lite::QuantizeWeights` from `TfLiteStatus` to `absl::Status`. PiperOrigin-RevId: 645490322 --- .../mlir/lite/quantization/lite/BUILD | 4 +- .../quantization/lite/quantize_weights.cc | 21 ++++---- .../lite/quantization/lite/quantize_weights.h | 10 ++-- .../lite/quantize_weights_test.cc | 52 +++++++------------ .../lite/tools/optimize/quantize_weights.cc | 29 +++-------- 5 files changed, 47 insertions(+), 69 deletions(-) diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/BUILD b/tensorflow/compiler/mlir/lite/quantization/lite/BUILD index 35aa1bb8d9001e..8370dd98d79e3b 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/quantization/lite/BUILD @@ -65,9 +65,9 @@ cc_library( "//tensorflow/compiler/mlir/quantization/common/quantization_lib:quantization_config", "//tensorflow/compiler/mlir/tensorflow:error_util", "//tensorflow/core:protos_all_cc", - "//tensorflow/lite/c:c_api_types", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@flatbuffers//:runtime_cc", "@llvm-project//llvm:Support", @@ -204,8 +204,8 @@ tf_cc_test( "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/lite:framework", - "//tensorflow/lite/c:c_api_types", "//tensorflow/lite/tools/optimize:test_util", + "@com_google_absl//absl/status", "@com_google_googletest//:gtest", "@flatbuffers", "@local_tsl//tsl/platform:logging", diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_weights.cc b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_weights.cc index 
7d37ded0a21937..33f7df2f4ad472 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_weights.cc +++ b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_weights.cc @@ -22,6 +22,7 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/log/log.h" +#include "absl/status/status.h" #include "absl/strings/string_view.h" #include "flatbuffers/buffer.h" // from @flatbuffers #include "flatbuffers/flatbuffer_builder.h" // from @flatbuffers @@ -44,7 +45,6 @@ limitations under the License. #include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_config.h" #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" #include "tensorflow/core/framework/types.pb.h" -#include "tensorflow/lite/c/c_api_types.h" namespace mlir { namespace lite { @@ -79,7 +79,7 @@ std::unique_ptr CreateMutableModelFromFile( // TODO(b/214314076): Support MLIR model as an input for the C++ dynamic range // quantization API -TfLiteStatus QuantizeWeights( +absl::Status QuantizeWeights( flatbuffers::FlatBufferBuilder* builder, const tflite::Model* input_model, const tflite::TensorType& inference_type, const absl::flat_hash_set& denylisted_ops, @@ -136,7 +136,8 @@ TfLiteStatus QuantizeWeights( quant_specs.inference_type == tensorflow::DT_QINT8)) { LOG(ERROR) << "Couldn't apply dynamic range quantization since unsupported " "inference_type is passed."; - return kTfLiteError; + return absl::InvalidArgumentError( + "Quantize weights transformation failed."); } llvm::dbgs() << "weight_quantization: " << true @@ -152,7 +153,8 @@ TfLiteStatus QuantizeWeights( if (failed(pm.run(module.get()))) { absl::string_view err = statusHandler.ConsumeStatus().message(); LOG(ERROR) << "Failed to quantize: " << err; - return kTfLiteError; + return absl::InvalidArgumentError( + "Quantize weights transformation failed."); } // Export the results to the builder @@ -164,15 +166,16 @@ TfLiteStatus QuantizeWeights( if 
(!tflite::MlirToFlatBufferTranslateFunction(module.get(), options, &result)) { LOG(ERROR) << "Failed to export MLIR to flatbuffer."; - return kTfLiteError; + return absl::InvalidArgumentError( + "Quantize weights transformation failed."); } builder->PushFlatBuffer(reinterpret_cast(result.data()), result.size()); - return kTfLiteOk; + return absl::OkStatus(); } -TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder, +absl::Status QuantizeWeights(flatbuffers::FlatBufferBuilder* builder, const tflite::Model* input_model, int64_t weights_min_num_elements, bool use_hybrid_evaluation) { @@ -188,7 +191,7 @@ TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder, } // In MLIR use_updated_hybrid_scheme = true means per-channel operation. -TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder, +absl::Status QuantizeWeights(flatbuffers::FlatBufferBuilder* builder, const tflite::Model* input_model, BufferType quant_type, bool use_updated_hybrid_scheme) { @@ -209,7 +212,7 @@ TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder, /*legacy_float_scale=*/true); } -TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder, +absl::Status QuantizeWeights(flatbuffers::FlatBufferBuilder* builder, const tflite::Model* input_model, int64_t weights_min_num_elements, const CustomOpMap& custom_op_map, diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_weights.h b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_weights.h index 1707a9d6517d8b..77c0813eda432c 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_weights.h +++ b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_weights.h @@ -22,9 +22,9 @@ limitations under the License. 
#include #include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" #include "flatbuffers/flatbuffer_builder.h" // from @flatbuffers #include "tensorflow/compiler/mlir/lite/schema/schema_generated.h" -#include "tensorflow/lite/c/c_api_types.h" namespace mlir { namespace lite { @@ -57,7 +57,7 @@ using CustomOpMap = std::unordered_map; // third_party/tensorflow/lite/tools/optimize/quantize_weights.h. // TODO(b/202468183): Selective quantization + quant debugger support for // dynamic range quantization for verify_numeric and whole_model_verify flags. -TfLiteStatus QuantizeWeights( +absl::Status QuantizeWeights( flatbuffers::FlatBufferBuilder* builder, const tflite::Model* input_model, const tflite::TensorType& inference_type, const absl::flat_hash_set& denylisted_ops, @@ -67,17 +67,17 @@ TfLiteStatus QuantizeWeights( bool legacy_float_scale = false); // Overloading methods to support old quantizer versions API -TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder, +absl::Status QuantizeWeights(flatbuffers::FlatBufferBuilder* builder, const tflite::Model* input_model, int64_t weights_min_num_elements, bool use_hybrid_evaluation = true); -TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder, +absl::Status QuantizeWeights(flatbuffers::FlatBufferBuilder* builder, const tflite::Model* input_model, BufferType quant_type = BufferType::QUANTIZED_INT8, bool use_updated_hybrid_scheme = true); -TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder, +absl::Status QuantizeWeights(flatbuffers::FlatBufferBuilder* builder, const tflite::Model* input_model, int64_t weights_min_num_elements, const CustomOpMap& custom_op_map, diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_weights_test.cc b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_weights_test.cc index 479de3e728743b..2e80bcae7486b4 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_weights_test.cc +++ 
b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_weights_test.cc @@ -20,7 +20,9 @@ limitations under the License. #include #include +#include #include +#include "absl/status/status.h" #include "flatbuffers/buffer.h" // from @flatbuffers #include "flatbuffers/flatbuffer_builder.h" // from @flatbuffers #include "flatbuffers/vector.h" // from @flatbuffers @@ -30,7 +32,6 @@ limitations under the License. #include "tensorflow/core/platform/path.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/util/command_line_flags.h" -#include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/model_builder.h" #include "tensorflow/lite/tools/optimize/test_util.h" #include "tsl/platform/logging.h" @@ -209,8 +210,7 @@ const Tensor* FindMatchingExpectedTensor( TEST_F(QuantizeWeightsTest, QuantizationSucceeds) { LoadBasicModel(); flatbuffers::FlatBufferBuilder builder; - auto status = QuantizeWeights(&builder, model_, 0); - EXPECT_EQ(status, kTfLiteOk); + EXPECT_OK(QuantizeWeights(&builder, model_, 0)); const uint8_t* buffer = builder.GetBufferPointer(); const Model* output_model = GetModel(buffer); @@ -220,9 +220,8 @@ TEST_F(QuantizeWeightsTest, QuantizationSucceeds) { TEST_F(QuantizeWeightsTest, QuantizationFails) { LoadBasicModel(); flatbuffers::FlatBufferBuilder builder; - auto status = - QuantizeWeights(&builder, model_, TensorType_UINT8, {}, {}, 1024); - EXPECT_EQ(status, kTfLiteError); + EXPECT_EQ(QuantizeWeights(&builder, model_, TensorType_UINT8, {}, {}, 1024), + absl::InternalError("Quantize weights transformation failed.")); } TEST_F(QuantizeWeightsTest, WeightsMinNumElements) { @@ -232,7 +231,7 @@ TEST_F(QuantizeWeightsTest, WeightsMinNumElements) { flatbuffers::FlatBufferBuilder builder; const uint64_t kWeightsMinNumElements = 1000000; EXPECT_EQ(QuantizeWeights(&builder, model_, kWeightsMinNumElements), - kTfLiteOk); + absl::InternalError("Quantize weights transformation failed.")); const uint8_t* buffer = builder.GetBufferPointer(); 
const Model* output_model = GetModel(buffer); @@ -261,8 +260,7 @@ TEST_F(QuantizeWeightsTest, WeightsMinNumElements) { TEST_F(QuantizeWeightsTest, HybridConv) { LoadBasicModel(); flatbuffers::FlatBufferBuilder builder; - auto status = QuantizeWeights(&builder, model_, 0); - EXPECT_EQ(status, kTfLiteOk); + EXPECT_OK(QuantizeWeights(&builder, model_, 0)); const uint8_t* buffer = builder.GetBufferPointer(); const Model* output_model = GetModel(buffer); @@ -321,9 +319,8 @@ TEST_F(QuantizeWeightsTest, HybridConv) { TEST_F(QuantizeWeightsTest, DequantizeConv) { LoadBasicModel(); flatbuffers::FlatBufferBuilder builder; - auto status = QuantizeWeights(&builder, model_, 0, - /*use_hybrid_evaluation=*/false); - EXPECT_EQ(status, kTfLiteOk); + EXPECT_OK(QuantizeWeights(&builder, model_, 0, + /*use_hybrid_evaluation=*/false)); const uint8_t* buffer = builder.GetBufferPointer(); const Model* output_model = GetModel(buffer); @@ -387,9 +384,7 @@ TEST_F(QuantizeWeightsTest, DequantizeConv) { TEST_F(QuantizeWeightsTest, DequantizeConvFloat16) { LoadBasicModel(); flatbuffers::FlatBufferBuilder builder; - auto status = - QuantizeWeights(&builder, model_, BufferType::QUANTIZED_FLOAT16); - EXPECT_EQ(status, kTfLiteOk); + EXPECT_OK(QuantizeWeights(&builder, model_, BufferType::QUANTIZED_FLOAT16)); const uint8_t* buffer = builder.GetBufferPointer(); const Model* output_model = GetModel(buffer); @@ -449,8 +444,7 @@ TEST_F(QuantizeWeightsTest, DequantizeConvFloat16) { TEST_F(QuantizeWeightsTest, SharedWeights_Hybrid) { LoadSharedWeightsModel(); flatbuffers::FlatBufferBuilder builder; - auto status = QuantizeWeights(&builder, model_, 0); - EXPECT_EQ(status, kTfLiteOk); + EXPECT_OK(QuantizeWeights(&builder, model_, 0)); const uint8_t* buffer = builder.GetBufferPointer(); const Model* output_model = GetModel(buffer); @@ -482,9 +476,8 @@ TEST_F(QuantizeWeightsTest, SharedWeights_Hybrid) { TEST_F(QuantizeWeightsTest, SharedWeights_Dequantize) { LoadSharedWeightsModel(); 
flatbuffers::FlatBufferBuilder builder; - auto status = QuantizeWeights(&builder, model_, 0, - /*use_hybrid_evaluation=*/false); - EXPECT_EQ(status, kTfLiteOk); + EXPECT_OK( + QuantizeWeights(&builder, model_, 0, /*use_hybrid_evaluation=*/false)); const uint8_t* buffer = builder.GetBufferPointer(); const Model* output_model = GetModel(buffer); @@ -524,8 +517,7 @@ TEST_F(QuantizeWeightsTest, SharedWeights_Dequantize) { TEST_F(QuantizeWeightsTest, VerifyGatherQuantization) { LoadGatherTestModel(); flatbuffers::FlatBufferBuilder builder; - auto status = QuantizeWeights(&builder, model_, 0); - EXPECT_EQ(status, kTfLiteOk); + EXPECT_OK(QuantizeWeights(&builder, model_, 0)); const uint8_t* buffer = builder.GetBufferPointer(); const Model* output_model = GetModel(buffer); @@ -562,8 +554,7 @@ TEST_F(QuantizeWeightsTest, VerifyCustomOpQuantizationDequantize) { }; flatbuffers::FlatBufferBuilder builder; - auto status = QuantizeWeights(&builder, model_, 0, custom_op_map); - ASSERT_EQ(status, kTfLiteOk); + EXPECT_OK(QuantizeWeights(&builder, model_, 0, custom_op_map)); const uint8_t* buffer = builder.GetBufferPointer(); const Model* output_model = GetModel(buffer); @@ -609,8 +600,7 @@ TEST_F(QuantizeWeightsTest, VerifyCustomOpQuantizationHybrid) { }; flatbuffers::FlatBufferBuilder builder; - auto status = QuantizeWeights(&builder, model_, 0, custom_op_map); - ASSERT_EQ(status, kTfLiteOk); + ASSERT_OK(QuantizeWeights(&builder, model_, 0, custom_op_map)); const uint8_t* buffer = builder.GetBufferPointer(); const Model* output_model = GetModel(buffer); @@ -641,8 +631,7 @@ TEST_F(QuantizeWeightsTest, VerifyUpdatedHybridSchemeFalseQuantizationHybrid) { LoadBasicModel(); flatbuffers::FlatBufferBuilder builder; const CustomOpMap custom_op_map; - auto status = QuantizeWeights(&builder, model_, 0, custom_op_map, false); - EXPECT_EQ(status, kTfLiteOk); + EXPECT_OK(QuantizeWeights(&builder, model_, 0, custom_op_map, false)); const uint8_t* buffer = builder.GetBufferPointer(); const 
Model* output_model = GetModel(buffer); @@ -698,10 +687,9 @@ TEST_F(QuantizeWeightsTest, DequantizeConvBlocklisted) { LoadBasicModel(); flatbuffers::FlatBufferBuilder builder; const CustomOpMap custom_op_map; - auto status = QuantizeWeights(&builder, model_, 0, custom_op_map, - /*use_updated_hybrid_scheme=*/true, - {BuiltinOperator_CONV_2D}); - EXPECT_EQ(status, kTfLiteOk); + EXPECT_OK(QuantizeWeights(&builder, model_, 0, custom_op_map, + /*use_updated_hybrid_scheme=*/true, + {BuiltinOperator_CONV_2D})); const uint8_t* buffer = builder.GetBufferPointer(); const Model* output_model = GetModel(buffer); diff --git a/tensorflow/lite/tools/optimize/quantize_weights.cc b/tensorflow/lite/tools/optimize/quantize_weights.cc index 93a65ecc6bfa11..837f8a180f0a9a 100644 --- a/tensorflow/lite/tools/optimize/quantize_weights.cc +++ b/tensorflow/lite/tools/optimize/quantize_weights.cc @@ -618,11 +618,8 @@ absl::Status QuantizeWeights(flatbuffers::FlatBufferBuilder* builder, // By default we require that only weights with more than // kWeightsMinSizeDefault elements are quantized. if (quantizer_type == QuantizerType::MLIR_QUANTIZER) { - return mlir::lite::QuantizeWeights(builder, input_model, - weights_min_num_elements, - use_hybrid_evaluation) == kTfLiteOk - ? absl::OkStatus() - : absl::InternalError("QuantizeWeights failed"); + return mlir::lite::QuantizeWeights( + builder, input_model, weights_min_num_elements, use_hybrid_evaluation); } CustomOpMap custom_op_map; return QuantizeWeightsInt8(builder, input_model, use_hybrid_evaluation, @@ -637,9 +634,7 @@ absl::Status QuantizeWeights(flatbuffers::FlatBufferBuilder* builder, QuantizerType quantizer_type) { if (quantizer_type == QuantizerType::MLIR_QUANTIZER) { return mlir::lite::QuantizeWeights(builder, input_model, - weights_min_num_elements) == kTfLiteOk - ? 
absl::OkStatus() - : absl::InternalError("QuantizeWeights failed"); + weights_min_num_elements); } CustomOpMap custom_op_map; return QuantizeWeightsInt8(builder, input_model, true, @@ -656,9 +651,7 @@ absl::Status QuantizeWeights(flatbuffers::FlatBufferBuilder* builder, if (quantizer_type == QuantizerType::MLIR_QUANTIZER) { return mlir::lite::QuantizeWeights(builder, input_model, (mlir::lite::BufferType)quant_type, - use_updated_hybrid_scheme) == kTfLiteOk - ? absl::OkStatus() - : absl::InternalError("QuantizeWeights failed"); + use_updated_hybrid_scheme); } switch (quant_type) { case BufferType::QUANTIZED_INT8: { @@ -680,11 +673,8 @@ absl::Status QuantizeWeights(flatbuffers::FlatBufferBuilder* builder, if (quantizer_type == QuantizerType::MLIR_QUANTIZER) { mlir::lite::CustomOpMap mlir_custom_op_map; ConstructMLIRCustomOpMap(mlir_custom_op_map, custom_op_map); - return mlir::lite::QuantizeWeights(builder, input_model, - weights_min_num_elements, - mlir_custom_op_map) == kTfLiteOk - ? absl::OkStatus() - : absl::InternalError("QuantizeWeights failed"); + return mlir::lite::QuantizeWeights( + builder, input_model, weights_min_num_elements, mlir_custom_op_map); } return QuantizeWeightsInt8(builder, input_model, true, weights_min_num_elements, custom_op_map, @@ -702,11 +692,8 @@ absl::Status QuantizeWeights(flatbuffers::FlatBufferBuilder* builder, mlir::lite::CustomOpMap mlir_custom_op_map; ConstructMLIRCustomOpMap(mlir_custom_op_map, custom_op_map); return mlir::lite::QuantizeWeights( - builder, input_model, weights_min_num_elements, - mlir_custom_op_map, use_updated_hybrid_scheme, - op_denylist) == kTfLiteOk - ? absl::OkStatus() - : absl::InternalError("QuantizeWeights failed"); + builder, input_model, weights_min_num_elements, mlir_custom_op_map, + use_updated_hybrid_scheme, op_denylist); } return QuantizeWeightsInt8(builder, input_model, /*use_hybrid_evaluation=*/true, From 5bfe6d087f2c2e04b24ebd1d294eeef0aa925efb Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Fri, 21 Jun 2024 14:55:04 -0700 Subject: [PATCH 137/256] Allow strategies for slice ops where the sliced dimensions could be sharded. PiperOrigin-RevId: 645509500 --- .../auto_sharding/auto_sharding_strategy.cc | 80 ++++++++++++++++--- .../auto_sharding/auto_sharding_test.cc | 41 ++++++++++ .../auto_sharding/auto_sharding_util.cc | 26 ------ .../auto_sharding/auto_sharding_util.h | 5 -- 4 files changed, 111 insertions(+), 41 deletions(-) diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc index c1af8a87e5e376..f84034b0838acf 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc @@ -74,6 +74,54 @@ std::optional ConstructImprovedSharding( allow_aggressive_resharding); } +std::pair +ComputeSliceShardingAndCommunicationCostFromOperand( + const HloSharding& input_spec, const Shape& old_shape, + const Shape& new_shape, const Array& device_mesh, + const ClusterEnvironment& cluster_env) { + if (input_spec.IsReplicated()) { + return std::make_pair(input_spec, 0); + } + + CHECK(old_shape.IsArray()); + + std::vector tensor_to_mesh_dim = + GetTensorDimToMeshDim(new_shape.rank(), input_spec, device_mesh, + /* consider_reverse_device_meshes */ true); + + std::vector mesh_dims_for_communication; + std::vector tensor_dims; + std::vector mesh_dims; + for (size_t i = 0; i < new_shape.rank(); ++i) { + if (tensor_to_mesh_dim[i] == -1) { + continue; + } + tensor_dims.push_back(i); + mesh_dims.push_back(tensor_to_mesh_dim[i]); + if (new_shape.dimensions(i) != old_shape.dimensions(i)) { + mesh_dims_for_communication.push_back(tensor_to_mesh_dim[i]); + } + } + + // When input_spec shards one or more or the sliced tensor dimensions, we + // might be required to perform some collective communication. 
In the worst + // case, the sliced output would be available on one machine, which we would + // need to then re-shard across the devices per result. We approximate the + // cost for this operation by adding up the ReduceScatter cost across the mesh + // dimensions that shard sliced tensor dimensions. + const HloSharding& result = + Tile(new_shape, tensor_dims, mesh_dims, device_mesh); + double num_bytes_to_transfer = GetBytes(new_shape); + double communication_cost = 0; + for (size_t i = 0; i < mesh_dims_for_communication.size(); ++i) { + int64_t mesh_dim = mesh_dims_for_communication[i]; + num_bytes_to_transfer /= device_mesh.dim(mesh_dim); + communication_cost += + cluster_env.ReduceScatterCost(num_bytes_to_transfer, mesh_dim); + } + return std::make_pair(result, communication_cost); +} + // NOLINTBEGIN(readability/fn_size) // TODO(zhuohan): Decompose this function into smaller pieces absl::StatusOr> @@ -462,6 +510,7 @@ BuildStrategyAndCost( HloSharding input_spec = src_strategy_group->strategies[sid].output_sharding; + double compute_cost = 0, communication_cost = 0; // Find output shardings. 
switch (opcode) { case HloOpcode::kSlice: { @@ -477,19 +526,31 @@ BuildStrategyAndCost( if (is_1d_sharding && input_spec.TotalNumTiles() == cluster_env.device_mesh_1d_.num_elements()) { - output_spec = PropagateDimwiseShardingSlice( - input_spec, operand->shape(), ins->shape(), - cluster_env.device_mesh_1d_); + std::pair + output_spec_and_communication_cost = + ComputeSliceShardingAndCommunicationCostFromOperand( + input_spec, operand->shape(), ins->shape(), + cluster_env.device_mesh_1d_, cluster_env); + output_spec = output_spec_and_communication_cost.first; + communication_cost = output_spec_and_communication_cost.second; } else if (is_1d_sharding) { CHECK_EQ(input_spec.TotalNumTiles(), cluster_env.original_device_mesh_1d_.num_elements()); - output_spec = PropagateDimwiseShardingSlice( - input_spec, operand->shape(), ins->shape(), - cluster_env.original_device_mesh_1d_); + std::pair + output_spec_and_communication_cost = + ComputeSliceShardingAndCommunicationCostFromOperand( + input_spec, operand->shape(), ins->shape(), + cluster_env.original_device_mesh_1d_, cluster_env); + output_spec = output_spec_and_communication_cost.first; + communication_cost = output_spec_and_communication_cost.second; } else { - output_spec = PropagateDimwiseShardingSlice( - input_spec, operand->shape(), ins->shape(), - cluster_env.device_mesh_); + std::pair + output_spec_and_communication_cost = + ComputeSliceShardingAndCommunicationCostFromOperand( + input_spec, operand->shape(), ins->shape(), + cluster_env.device_mesh_, cluster_env); + output_spec = output_spec_and_communication_cost.first; + communication_cost = output_spec_and_communication_cost.second; } break; } @@ -525,7 +586,6 @@ BuildStrategyAndCost( } std::string name = ToStringSimple(*output_spec); - double compute_cost = 0, communication_cost = 0; double memory_cost = GetBytes(ins->shape()) / output_spec->NumTiles(); std::pair resharding_costs = GenerateReshardingCostsAndMissingShardingsForAllOperands( diff --git 
a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_test.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_test.cc index 3a231c51a31ca1..686a5842dc6945 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_test.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_test.cc @@ -63,6 +63,7 @@ using ::testing::FieldsAre; using ::testing::IsEmpty; using ::testing::IsFalse; using ::testing::IsTrue; +using ::testing::Not; using ::testing::Pair; using ::testing::ResultOf; using ::testing::UnorderedElementsAre; @@ -382,6 +383,46 @@ ENTRY %elementwise { EXPECT_THAT(instructions, Each(op::Sharding("{devices=[4,1]0,2,1,3}"))); } +TEST_F(AutoShardingTest, SlicedTensorDimensionShardedTest) { + // Below we check that sharding candidates that shard sliced tensor dimensions + // are generated by ensuring that no reshapes are added when pre-existing + // sharding annotations shard sliced tensor dimensions. + constexpr absl::string_view kHloString = R"( +HloModule module + +ENTRY %slicemodule { + param = s32[512,3084]{1,0} parameter(0), sharding={devices=[1,4]0,2,1,3} + slice = s32[512,2048]{1,0} slice(param), slice={[0:512], [0:2048]}, sharding={devices=[1,4]0,2,1,3} + ROOT copy = s32[512,2048]{1,0} copy(slice) +})"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloString)); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, + AutoSharding( + /* option */ { + .enable = true, + .preserve_shardings = + AutoShardingOption::PreserveShardingsType::kKeepAllShardings, + .solve_nd_sharding_iteratively = true, + .device_mesh_shape = {2, 2}, + .device_mesh_ids = {0, 2, 1, 3}, + .device_mesh_alpha = {1.0, 1.0}, + .device_mesh_beta = {0.01, 1.0}}) + .Run(module.get())); + + VLOG(10) << module->ToString(); + EXPECT_TRUE(changed); + + std::vector instructions = + module->entry_computation()->MakeInstructionPostOrder(); + EXPECT_THAT(instructions, + Not(Contains(ResultOf( + [](const 
HloInstruction* ins) { return ins->opcode(); }, + Eq(HloOpcode::kReshape))))); +} + TEST_F(AutoShardingTest, UserShardingTest) { constexpr absl::string_view kHloString = R"( HloModule module diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc index 039fc07706c7a0..49ef8d22830957 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc @@ -154,32 +154,6 @@ std::optional PropagateDimwiseSharding( return input_spec; } -HloSharding PropagateDimwiseShardingSlice(const HloSharding& input_spec, - const Shape& old_shape, - const Shape& new_shape, - const Array& device_mesh) { - if (input_spec.IsReplicated()) { - return input_spec; - } - - CHECK(old_shape.IsArray()); - - std::vector tensor_to_mesh_dim = - GetTensorDimToMeshDim(new_shape.rank(), input_spec, device_mesh, - /* consider_reverse_device_meshes */ true); - - std::vector tensor_dims; - std::vector mesh_dims; - for (size_t i = 0; i < new_shape.rank(); ++i) { - if (new_shape.dimensions(i) == old_shape.dimensions(i) && - tensor_to_mesh_dim[i] > -1) { - tensor_dims.push_back(i); - mesh_dims.push_back(tensor_to_mesh_dim[i]); - } - } - return Tile(new_shape, tensor_dims, mesh_dims, device_mesh); -} - // Propagate sharding for ReduceWindow-like operations. // The sharding can successfully propagate if the window operation only happens // on tensor dimensions that are not tiled. 
diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.h b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.h index c8fd040c7a2107..dd78eda6ebf304 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.h +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.h @@ -419,11 +419,6 @@ std::optional PropagateDimwiseSharding( const HloSharding& input_spec, const Shape& old_shape, const Shape& new_shape); -HloSharding PropagateDimwiseShardingSlice(const HloSharding& input_spec, - const Shape& old_shape, - const Shape& new_shape, - const Array& device_mesh); - // Propagate sharding for ReduceWindow-like operations. // The sharding can successfully propagate if the window operation only happens // on tensor dimensions that are not tiled. From eea84120ca6cc5da887db720e1dc7cd5ab5655d6 Mon Sep 17 00:00:00 2001 From: Sandeep Dasgupta Date: Fri, 21 Jun 2024 16:19:50 -0700 Subject: [PATCH 138/256] Integrate StableHLO at openxla/stablehlo@61826746 PiperOrigin-RevId: 645531287 --- third_party/stablehlo/temporary.patch | 904 ------------------ third_party/stablehlo/workspace.bzl | 4 +- .../xla/third_party/stablehlo/temporary.patch | 904 ------------------ .../xla/third_party/stablehlo/workspace.bzl | 4 +- 4 files changed, 4 insertions(+), 1812 deletions(-) diff --git a/third_party/stablehlo/temporary.patch b/third_party/stablehlo/temporary.patch index 0eb67a7bda2ee2..cc9c7aee78573f 100755 --- a/third_party/stablehlo/temporary.patch +++ b/third_party/stablehlo/temporary.patch @@ -164,893 +164,6 @@ diff --ruN a/stablehlo/CMakeLists.txt b/stablehlo/CMakeLists.txt #------------------------------------------------------------------------------- # Directory setup -diff --ruN a/stablehlo/build_tools/math/generate_ChloDecompositionPatternsMath.py b/stablehlo/build_tools/math/generate_ChloDecompositionPatternsMath.py ---- stablehlo/build_tools/math/generate_ChloDecompositionPatternsMath.py -+++ 
stablehlo/build_tools/math/generate_ChloDecompositionPatternsMath.py -@@ -22,54 +22,64 @@ - - - def main(): -- try: -- import functional_algorithms as fa -- except ImportError as msg: -- print(f"Skipping: {msg}") -- return -+ try: -+ import functional_algorithms as fa -+ except ImportError as msg: -+ print(f"Skipping: {msg}") -+ return - -- fa_version = tuple(map(int, fa.__version__.split(".", 4)[:3])) -- if fa_version < (0, 4, 0): -- warnings.warn("functional_algorithm version 0.4.0 or newer is required," -- f" got {fa.__version__}") -- return -+ fa_version = tuple(map(int, fa.__version__.split(".", 4)[:3])) -+ if fa_version < (0, 4, 0): -+ warnings.warn( -+ "functional_algorithm version 0.4.0 or newer is required," -+ f" got {fa.__version__}" -+ ) -+ return - -- output_file = os.path.relpath( -- os.path.normpath( -- os.path.join(os.path.dirname(__file__), "..", "..", "stablehlo", -- "transforms", "ChloDecompositionPatternsMath.td")), -- os.getcwd()) -+ output_file = os.path.relpath( -+ os.path.normpath( -+ os.path.join( -+ os.path.dirname(__file__), -+ "..", -+ "..", -+ "stablehlo", -+ "transforms", -+ "ChloDecompositionPatternsMath.td", -+ ) -+ ), -+ os.getcwd(), -+ ) - -- sources = [] -- target = fa.targets.stablehlo -- for chloname, fname, args in [ -- ("CHLO_AsinOp", "complex_asin", ("z:complex",)), -- ("CHLO_AsinOp", "real_asin", ("x:float",)), -- ]: -- func = getattr(fa.algorithms, fname, None) -- if func is None: -- warnings.warn( -- "{fa.algorithms.__name__} does not define {fname}. 
Skipping.") -- continue -- ctx = fa.Context(paths=[fa.algorithms]) -- graph = ctx.trace(func, *args).implement_missing(target).simplify() -- graph.props.update(name=chloname) -- src = graph.tostring(target) -- sources.append(target.make_comment( -- func.__doc__)) if func.__doc__ else None -- sources[-1] += src -- source = "\n\n".join(sources) + "\n" -+ sources = [] -+ target = fa.targets.stablehlo -+ for chloname, fname, args in [ -+ ("CHLO_AsinOp", "complex_asin", ("z:complex",)), -+ ("CHLO_AsinOp", "real_asin", ("x:float",)), -+ ]: -+ func = getattr(fa.algorithms, fname, None) -+ if func is None: -+ warnings.warn( -+ "{fa.algorithms.__name__} does not define {fname}. Skipping." -+ ) -+ continue -+ ctx = fa.Context(paths=[fa.algorithms]) -+ graph = ctx.trace(func, *args).implement_missing(target).simplify() -+ graph.props.update(name=chloname) -+ src = graph.tostring(target) -+ sources.append(target.make_comment(func.__doc__)) if func.__doc__ else None -+ sources[-1] += src -+ source = "\n\n".join(sources) + "\n" - -- if os.path.isfile(output_file): -- f = open(output_file, "r") -- content = f.read() -- f.close() -- if content.endswith(source): -- print(f"{output_file} is up-to-date.") -- return -+ if os.path.isfile(output_file): -+ f = open(output_file, "r") -+ content = f.read() -+ f.close() -+ if content.endswith(source): -+ print(f"{output_file} is up-to-date.") -+ return - -- f = open(output_file, "w") -- f.write("""\ -+ f = open(output_file, "w") -+ f.write("""\ - /* Copyright 2024 The StableHLO Authors. - - Licensed under the Apache License, Version 2.0 (the "License"); -@@ -86,15 +96,14 @@ - ==============================================================================*/ - - """) -- f.write( -- target.make_comment(f"""\ -+ f.write(target.make_comment(f"""\ - - This file is generated using functional_algorithms tool ({fa.__version__}). 
- See build_tools/math/README.md for more information.""") + "\n") -- f.write(source) -- f.close() -- print(f"Created {output_file}") -+ f.write(source) -+ f.close() -+ print(f"Created {output_file}") - - - if __name__ == "__main__": -- main() -+ main() -diff --ruN a/stablehlo/build_tools/math/generate_tests.py b/stablehlo/build_tools/math/generate_tests.py ---- stablehlo/build_tools/math/generate_tests.py -+++ stablehlo/build_tools/math/generate_tests.py -@@ -55,374 +55,394 @@ - - - def main(): -- try: -- import functional_algorithms as fa -- except ImportError as msg: -- print(f"Skipping: {msg}") -- return -- -- fa_version = tuple(map(int, fa.__version__.split(".", 4)[:3])) -- if fa_version < (0, 4, 0): -- warnings.warn("functional_algorithm version 0.4.0 or newer is required," -- f" got {fa.__version__}") -- return -- -- target_dir = os.path.relpath( -- os.path.normpath( -- os.path.join(os.path.dirname(__file__), "..", "..", "stablehlo", -- "tests", "math")), os.getcwd()) -- -- flush_subnormals = False -- for op in operations: -- opname = op["name"] -- mpmath_opname = op.get("mpmath_name", opname) -- size_re = size_im = op.get("size", default_size) -- extra_prec_multiplier = op.get("extra_prec_multiplier", -- default_extra_prec_multiplier) -- max_ulp_difference = op.get("max_ulp_difference", -- default_max_ulp_difference) -- -- nmp = fa.utils.numpy_with_mpmath( -- extra_prec_multiplier=extra_prec_multiplier, -- flush_subnormals=flush_subnormals) -- for dtype in [np.complex64, np.complex128, np.float32, np.float64]: -- fi = np.finfo(dtype) -- -- float_dtype = to_float_dtype[dtype] -- finfo = np.finfo(float_dtype) -- -- if dtype in [np.complex64, np.complex128]: -- samples = fa.utils.complex_samples( -- size=(size_re, size_im), -- dtype=dtype, -- include_subnormal=not flush_subnormals).flatten() -- else: -- samples = fa.utils.real_samples( -- size=size_re * size_im, -- dtype=dtype, -- include_subnormal=not flush_subnormals).flatten() -- -- expected = getattr(nmp, 
mpmath_opname)(samples) -- -- module_name = f"{opname}_{dtype.__name__}" -- m = SSA.make_module(module_name) -- -- samples_func = m.make_function("samples", "", mlir_type(samples)) -- samples_func.assign(samples) -- samples_func.return_last() -- -- expected_func = m.make_function("expected", "", mlir_type(expected)) -- expected_func.assign(expected) -- expected_func.return_last() -- -- main_func = m.make_function("main", "", "", "public") -- -- ref_samples = main_func.call("samples") -- actual = main_func.composite(f"chlo.{opname}", ref_samples) -- expected = main_func.call("expected") -- -- main_func.void_call( -- "check.expect_close", -- actual, -- expected, -- f"max_ulp_difference = {max_ulp_difference}", -- atypes=", ".join(map(main_func.get_ref_type, -- [actual, expected])), -- ) -- main_func.void_call("func.return") -- source = str(m).rstrip() + "\n" -- fname = os.path.join(target_dir, f"{module_name}.mlir") -- if os.path.isfile(fname): -- f = open(fname, "r") -- content = f.read() -- f.close() -- if content.endswith(source): -- print(f"{fname} is up-to-date.") -- continue -- -- f = open(fname, "w") -- f.write( -- "// RUN: stablehlo-opt --chlo-legalize-to-stablehlo %s | stablehlo-translate --interpret\n" -- ) -- f.write( -- "// This file is generated, see build_tools/math/README.md for more information.\n" -- ) -- f.write(source) -- f.close() -- print(f"Created {fname}") -- -- # Testing ULP difference -- for dtype in [np.float32, np.float64]: -- fi = np.finfo(dtype) -- -- max_ulp_difference = 0 -- min_ulp_difference = 0 -- -- finfo = np.finfo(dtype) -- module_name = f"ulp_difference_{dtype.__name__}" -- m = SSA.make_module(module_name) -- -- main_func = m.make_function("main", "", "", "public") -- -- def samples_generator(): -- data = [ -- -finfo.max, -1e9 - 1.2, -finfo.smallest_normal, -- -finfo.smallest_subnormal, 0, finfo.smallest_subnormal, -- finfo.smallest_normal, 1.2, 1e9 -- ] -- for expected_ulp_difference in [0, 1, 5, 50]: -- if 
expected_ulp_difference == 0: -- actual = np.array(data + [np.inf, -np.inf, np.nan], -- dtype=dtype) -- else: -- actual = np.array(data, dtype=dtype) -- shifted = actual -- for i in range(expected_ulp_difference): -- shifted = np.nextafter(shifted, np.inf, dtype=dtype) -- label = str(expected_ulp_difference) -- yield actual, shifted, expected_ulp_difference, label -- -- actual = np.array([np.inf] * 5, dtype=dtype) -- shifted = np.array([-np.inf, np.nan, 0, 1.2, finfo.max], -- dtype=dtype) -- yield actual, shifted, 2**64 - 1, "nonfinite" -- -- for actual, shifted, expected_ulp_difference, label in samples_generator( -- ): -- -- actual_func = m.make_function(f"actual_{label}", "", -- mlir_type(actual)) -- actual_func.comment(f'{list(actual)}') -- actual_func.assign(actual) -- actual_func.return_last() -- -- shifted_func = m.make_function(f"shifted_{label}", "", -- mlir_type(shifted)) -- shifted_func.comment(f'{list(shifted)}') -- shifted_func.assign(shifted) -- shifted_func.return_last() -- -- actual_values = main_func.call(f"actual_{label}") -- shifted_values = main_func.call(f"shifted_{label}") -- -- main_func.void_call( -- "check.expect_close", -- actual_values, -- shifted_values, -- f"max_ulp_difference = {expected_ulp_difference}", -- f"min_ulp_difference = {expected_ulp_difference}", -- atypes=", ".join( -- map(main_func.get_ref_type, -- [actual_values, shifted_values])), -- ) -- -- main_func.void_call("func.return") -- source = str(m).rstrip() + "\n" -- fname = os.path.join(target_dir, f"{module_name}.mlir") -- if os.path.isfile(fname): -- f = open(fname, "r") -- content = f.read() -- f.close() -- if content.endswith(source): -- print(f"{fname} is up-to-date.") -- continue -- -- f = open(fname, "w") -- f.write( -- "// RUN: stablehlo-opt --chlo-legalize-to-stablehlo %s | stablehlo-translate --interpret\n" -- ) -- f.write( -- "// This file is generated, see build_tools/math/README.md for more information.\n" -- ) -- f.write(source) -+ try: -+ import 
functional_algorithms as fa -+ except ImportError as msg: -+ print(f"Skipping: {msg}") -+ return -+ -+ fa_version = tuple(map(int, fa.__version__.split(".", 4)[:3])) -+ if fa_version < (0, 4, 0): -+ warnings.warn( -+ "functional_algorithm version 0.4.0 or newer is required," -+ f" got {fa.__version__}" -+ ) -+ return -+ -+ target_dir = os.path.relpath( -+ os.path.normpath( -+ os.path.join( -+ os.path.dirname(__file__), -+ "..", -+ "..", -+ "stablehlo", -+ "tests", -+ "math", -+ ) -+ ), -+ os.getcwd(), -+ ) -+ -+ flush_subnormals = False -+ for op in operations: -+ opname = op["name"] -+ mpmath_opname = op.get("mpmath_name", opname) -+ size_re = size_im = op.get("size", default_size) -+ extra_prec_multiplier = op.get( -+ "extra_prec_multiplier", default_extra_prec_multiplier -+ ) -+ max_ulp_difference = op.get( -+ "max_ulp_difference", default_max_ulp_difference -+ ) -+ -+ nmp = fa.utils.numpy_with_mpmath( -+ extra_prec_multiplier=extra_prec_multiplier, -+ flush_subnormals=flush_subnormals, -+ ) -+ for dtype in [np.complex64, np.complex128, np.float32, np.float64]: -+ fi = np.finfo(dtype) -+ -+ float_dtype = to_float_dtype[dtype] -+ finfo = np.finfo(float_dtype) -+ -+ if dtype in [np.complex64, np.complex128]: -+ samples = fa.utils.complex_samples( -+ size=(size_re, size_im), -+ dtype=dtype, -+ include_subnormal=not flush_subnormals, -+ ).flatten() -+ else: -+ samples = fa.utils.real_samples( -+ size=size_re * size_im, -+ dtype=dtype, -+ include_subnormal=not flush_subnormals, -+ ).flatten() -+ -+ expected = getattr(nmp, mpmath_opname)(samples) -+ -+ module_name = f"{opname}_{dtype.__name__}" -+ m = SSA.make_module(module_name) -+ -+ samples_func = m.make_function("samples", "", mlir_type(samples)) -+ samples_func.assign(samples) -+ samples_func.return_last() -+ -+ expected_func = m.make_function("expected", "", mlir_type(expected)) -+ expected_func.assign(expected) -+ expected_func.return_last() -+ -+ main_func = m.make_function("main", "", "", "public") -+ -+ 
ref_samples = main_func.call("samples") -+ actual = main_func.composite(f"chlo.{opname}", ref_samples) -+ expected = main_func.call("expected") -+ -+ main_func.void_call( -+ "check.expect_close", -+ actual, -+ expected, -+ f"max_ulp_difference = {max_ulp_difference}", -+ atypes=", ".join(map(main_func.get_ref_type, [actual, expected])), -+ ) -+ main_func.void_call("func.return") -+ source = str(m).rstrip() + "\n" -+ fname = os.path.join(target_dir, f"{module_name}.mlir") -+ if os.path.isfile(fname): -+ f = open(fname, "r") -+ content = f.read() - f.close() -- print(f"Created {fname}") -+ if content.endswith(source): -+ print(f"{fname} is up-to-date.") -+ continue -+ -+ f = open(fname, "w") -+ f.write( -+ "// RUN: stablehlo-opt --chlo-legalize-to-stablehlo %s |" -+ " stablehlo-translate --interpret\n" -+ ) -+ f.write( -+ "// This file is generated, see build_tools/math/README.md for more" -+ " information.\n" -+ ) -+ f.write(source) -+ f.close() -+ print(f"Created {fname}") -+ -+ # Testing ULP difference -+ for dtype in [np.float32, np.float64]: -+ fi = np.finfo(dtype) -+ -+ max_ulp_difference = 0 -+ min_ulp_difference = 0 -+ -+ finfo = np.finfo(dtype) -+ module_name = f"ulp_difference_{dtype.__name__}" -+ m = SSA.make_module(module_name) -+ -+ main_func = m.make_function("main", "", "", "public") -+ -+ def samples_generator(): -+ data = [ -+ -finfo.max, -+ -1e9 - 1.2, -+ -finfo.smallest_normal, -+ -finfo.smallest_subnormal, -+ 0, -+ finfo.smallest_subnormal, -+ finfo.smallest_normal, -+ 1.2, -+ 1e9, -+ ] -+ for expected_ulp_difference in [0, 1, 5, 50]: -+ if expected_ulp_difference == 0: -+ actual = np.array(data + [np.inf, -np.inf, np.nan], dtype=dtype) -+ else: -+ actual = np.array(data, dtype=dtype) -+ shifted = actual -+ for i in range(expected_ulp_difference): -+ shifted = np.nextafter(shifted, np.inf, dtype=dtype) -+ label = str(expected_ulp_difference) -+ yield actual, shifted, expected_ulp_difference, label -+ -+ actual = np.array([np.inf] * 5, dtype=dtype) 
-+ shifted = np.array([-np.inf, np.nan, 0, 1.2, finfo.max], dtype=dtype) -+ yield actual, shifted, 2**64 - 1, "nonfinite" -+ -+ for actual, shifted, expected_ulp_difference, label in samples_generator(): -+ -+ actual_func = m.make_function(f"actual_{label}", "", mlir_type(actual)) -+ actual_func.comment(f"{list(actual)}") -+ actual_func.assign(actual) -+ actual_func.return_last() -+ -+ shifted_func = m.make_function(f"shifted_{label}", "", mlir_type(shifted)) -+ shifted_func.comment(f"{list(shifted)}") -+ shifted_func.assign(shifted) -+ shifted_func.return_last() -+ -+ actual_values = main_func.call(f"actual_{label}") -+ shifted_values = main_func.call(f"shifted_{label}") -+ -+ main_func.void_call( -+ "check.expect_close", -+ actual_values, -+ shifted_values, -+ f"max_ulp_difference = {expected_ulp_difference}", -+ f"min_ulp_difference = {expected_ulp_difference}", -+ atypes=", ".join( -+ map(main_func.get_ref_type, [actual_values, shifted_values]) -+ ), -+ ) -+ -+ main_func.void_call("func.return") -+ source = str(m).rstrip() + "\n" -+ fname = os.path.join(target_dir, f"{module_name}.mlir") -+ if os.path.isfile(fname): -+ f = open(fname, "r") -+ content = f.read() -+ f.close() -+ if content.endswith(source): -+ print(f"{fname} is up-to-date.") -+ continue -+ -+ f = open(fname, "w") -+ f.write( -+ "// RUN: stablehlo-opt --chlo-legalize-to-stablehlo %s |" -+ " stablehlo-translate --interpret\n" -+ ) -+ f.write( -+ "// This file is generated, see build_tools/math/README.md for more" -+ " information.\n" -+ ) -+ f.write(source) -+ f.close() -+ print(f"Created {fname}") - - - class Block: -- """A data structure used in SSA""" -- -- def __init__(self, parent, prefix, suffix, start_counter=0): -- self.parent = parent -- self.prefix = prefix -- self.suffix = suffix -- self.counter = start_counter -- self.statements = {} -- -- def tostr(self, tab=""): -- lines = [] -- lines.append(tab + self.prefix) -- for i in sorted(self.statements): -- op, expr, typ = self.statements[i] 
-- if op == "//": -- lines.append(f"{tab} {op} {expr}") -- elif typ: -- lines.append(f"{tab} {op} {expr} : {typ}") -- else: -- assert not expr, (op, expr, typ) -- lines.append(f"{tab} {op}") -- lines.append(tab + self.suffix) -- return "\n".join(lines) -- -- def comment(self, message): -- # add comment to code -- self.statements[self.counter] = ("//", message, None) -- self.counter += 1 -- -- def assign(self, expr, typ=None): -- if isinstance(expr, np.ndarray): -- assert typ is None, typ -- typ = mlir_type(expr) -- expr = shlo_constant(expr) -- elif isinstance(expr, str) and typ is not None: -- pass -- elif isinstance(expr, bool) and typ is not None: -- expr = shlo_constant(expr) -- else: -- raise NotImplementedError((expr, typ)) -- target = f"%{self.counter}" -- self.statements[self.counter] = (f"{target} =", expr, typ) -- self.counter += 1 -- return target -- -- def call(self, name, *args): -- # call function created with make_function -- sargs = ", ".join(args) -- return self.assign(f"call @{name}({sargs})", -- typ=self.get_function_type(name)) -- -- def composite(self, name, *args, **options): -- sargs = ", ".join(args) -- atypes = tuple(map(self.get_ref_type, args)) -- rtype = options.get("rtype") -- if rtype is None: -- # assuming the first op argument defines the op type -- rtype = atypes[0] -- sargs = ", ".join(args) -- typ = f'({", ".join(atypes)}) -> {rtype}' -- return self.assign(f'"{name}"({sargs})', typ=typ) -- -- def void_call(self, name, *args, **options): -- # call function that has void return -- if args: -- sargs = ", ".join(args) -- atypes = options.get("atypes") -- if atypes is None: -- atypes = ", ".join(map(self.get_ref_type, args)) -- self.statements[self.counter] = (name, f"{sargs}", f"{atypes}") -- else: -- self.statements[self.counter] = (name, "", "") -- self.counter += 1 -- -- def apply(self, op, *args, **options): -- sargs = ", ".join(args) -- atypes = tuple(map(self.get_ref_type, args)) -- rtype = options.get("rtype") -- if rtype is 
None: -- # assuming the first op argument defines the op type -- rtype = atypes[0] -- typ = f'({", ".join(atypes)}) -> {rtype}' -- return self.assign(f"{op} {sargs}", typ=typ) -- -- def return_last(self): -- ref = f"%{self.counter - 1}" -- self.statements[self.counter] = ("return", ref, self.get_ref_type(ref)) -- self.counter += 1 -- -- @property -- def is_function(self): -- return self.prefix.startwith("func.func") -- -- @property -- def function_name(self): -- if self.prefix.startswith("func.func"): -- i = self.prefix.find("@") -- j = self.prefix.find("(", i) -- assert -1 not in {i, j}, self.prefix -- return self.prefix[i + 1:j] -- -- @property -- def function_type(self): -- if self.prefix.startswith("func.func"): -- i = self.prefix.find("(", self.prefix.find("@")) -- j = self.prefix.find("{", i) -- assert -1 not in {i, j}, self.prefix -- return self.prefix[i:j].strip() -- -- def get_function_type(self, name): -- for block in self.parent.blocks: -- if block.function_name == name: -- return block.function_type -- -- def get_ref_type(self, ref): -- assert ref.startswith("%"), ref -- counter = int(ref[1:]) -- typ = self.statements[counter][-1] -- return typ.rsplit("->", 1)[-1].strip() -+ """A data structure used in SSA""" -+ -+ def __init__(self, parent, prefix, suffix, start_counter=0): -+ self.parent = parent -+ self.prefix = prefix -+ self.suffix = suffix -+ self.counter = start_counter -+ self.statements = {} -+ -+ def tostr(self, tab=""): -+ lines = [] -+ lines.append(tab + self.prefix) -+ for i in sorted(self.statements): -+ op, expr, typ = self.statements[i] -+ if op == "//": -+ lines.append(f"{tab} {op} {expr}") -+ elif typ: -+ lines.append(f"{tab} {op} {expr} : {typ}") -+ else: -+ assert not expr, (op, expr, typ) -+ lines.append(f"{tab} {op}") -+ lines.append(tab + self.suffix) -+ return "\n".join(lines) -+ -+ def comment(self, message): -+ # add comment to code -+ self.statements[self.counter] = ("//", message, None) -+ self.counter += 1 -+ -+ def 
assign(self, expr, typ=None): -+ if isinstance(expr, np.ndarray): -+ assert typ is None, typ -+ typ = mlir_type(expr) -+ expr = shlo_constant(expr) -+ elif isinstance(expr, str) and typ is not None: -+ pass -+ elif isinstance(expr, bool) and typ is not None: -+ expr = shlo_constant(expr) -+ else: -+ raise NotImplementedError((expr, typ)) -+ target = f"%{self.counter}" -+ self.statements[self.counter] = (f"{target} =", expr, typ) -+ self.counter += 1 -+ return target -+ -+ def call(self, name, *args): -+ # call function created with make_function -+ sargs = ", ".join(args) -+ return self.assign( -+ f"call @{name}({sargs})", typ=self.get_function_type(name) -+ ) -+ -+ def composite(self, name, *args, **options): -+ sargs = ", ".join(args) -+ atypes = tuple(map(self.get_ref_type, args)) -+ rtype = options.get("rtype") -+ if rtype is None: -+ # assuming the first op argument defines the op type -+ rtype = atypes[0] -+ sargs = ", ".join(args) -+ typ = f'({", ".join(atypes)}) -> {rtype}' -+ return self.assign(f'"{name}"({sargs})', typ=typ) -+ -+ def void_call(self, name, *args, **options): -+ # call function that has void return -+ if args: -+ sargs = ", ".join(args) -+ atypes = options.get("atypes") -+ if atypes is None: -+ atypes = ", ".join(map(self.get_ref_type, args)) -+ self.statements[self.counter] = (name, f"{sargs}", f"{atypes}") -+ else: -+ self.statements[self.counter] = (name, "", "") -+ self.counter += 1 -+ -+ def apply(self, op, *args, **options): -+ sargs = ", ".join(args) -+ atypes = tuple(map(self.get_ref_type, args)) -+ rtype = options.get("rtype") -+ if rtype is None: -+ # assuming the first op argument defines the op type -+ rtype = atypes[0] -+ typ = f'({", ".join(atypes)}) -> {rtype}' -+ return self.assign(f"{op} {sargs}", typ=typ) -+ -+ def return_last(self): -+ ref = f"%{self.counter - 1}" -+ self.statements[self.counter] = ("return", ref, self.get_ref_type(ref)) -+ self.counter += 1 -+ -+ @property -+ def is_function(self): -+ return 
self.prefix.startwith("func.func") -+ -+ @property -+ def function_name(self): -+ if self.prefix.startswith("func.func"): -+ i = self.prefix.find("@") -+ j = self.prefix.find("(", i) -+ assert -1 not in {i, j}, self.prefix -+ return self.prefix[i + 1 : j] -+ -+ @property -+ def function_type(self): -+ if self.prefix.startswith("func.func"): -+ i = self.prefix.find("(", self.prefix.find("@")) -+ j = self.prefix.find("{", i) -+ assert -1 not in {i, j}, self.prefix -+ return self.prefix[i:j].strip() -+ -+ def get_function_type(self, name): -+ for block in self.parent.blocks: -+ if block.function_name == name: -+ return block.function_type -+ -+ def get_ref_type(self, ref): -+ assert ref.startswith("%"), ref -+ counter = int(ref[1:]) -+ typ = self.statements[counter][-1] -+ return typ.rsplit("->", 1)[-1].strip() - - - class SSA: -- """A light-weight SSA form factory.""" -- -- def __init__(self, prefix, suffix): -- self.prefix = prefix -- self.suffix = suffix -- self.blocks = [] -- -- @classmethod -- def make_module(cls, name): -- return SSA(f"module @{name} {{", "}") -- -- def make_function(self, name, args, rtype, attrs="private"): -- if rtype: -- b = Block(self, f"func.func {attrs} @{name}({args}) -> {rtype} {{", -- "}") -- else: -- b = Block(self, f"func.func {attrs} @{name}({args}) {{", "}") -- self.blocks.append(b) -- return b -- -- def tostr(self, tab=""): -- lines = [] -- lines.append(tab + self.prefix) -- for b in self.blocks: -- lines.extend(b.tostr(tab=tab + " ").split("\n")) -- lines.append(tab + self.suffix) -- return "\n".join(lines) -- -- def __str__(self): -- return self.tostr() -+ """A light-weight SSA form factory.""" -+ -+ def __init__(self, prefix, suffix): -+ self.prefix = prefix -+ self.suffix = suffix -+ self.blocks = [] -+ -+ @classmethod -+ def make_module(cls, name): -+ return SSA(f"module @{name} {{", "}") -+ -+ def make_function(self, name, args, rtype, attrs="private"): -+ if rtype: -+ b = Block(self, f"func.func {attrs} @{name}({args}) -> 
{rtype} {{", "}") -+ else: -+ b = Block(self, f"func.func {attrs} @{name}({args}) {{", "}") -+ self.blocks.append(b) -+ return b -+ -+ def tostr(self, tab=""): -+ lines = [] -+ lines.append(tab + self.prefix) -+ for b in self.blocks: -+ lines.extend(b.tostr(tab=tab + " ").split("\n")) -+ lines.append(tab + self.suffix) -+ return "\n".join(lines) -+ -+ def __str__(self): -+ return self.tostr() - - - def mlir_type(obj): -- if isinstance(obj, np.ndarray): -- s = "x".join(map(str, obj.shape)) -- t = { -- np.bool_: "i1", -- np.float16: "f16", -- np.float32: "f32", -- np.float64: "f64", -- np.complex64: "complex", -- np.complex128: "complex", -- }[obj.dtype.type] -- return f"tensor<{s}x{t}>" -+ if isinstance(obj, np.ndarray): -+ s = "x".join(map(str, obj.shape)) -+ t = { -+ np.bool_: "i1", -+ np.float16: "f16", -+ np.float32: "f32", -+ np.float64: "f64", -+ np.complex64: "complex", -+ np.complex128: "complex", -+ }[obj.dtype.type] -+ return f"tensor<{s}x{t}>" -+ else: -+ raise NotImplementedError(type(obj)) -+ -+ -+def shlo_constant(obj): -+ if isinstance(obj, bool): -+ v = str(obj).lower() -+ return f"stablehlo.constant dense<{v}>" -+ if isinstance(obj, np.ndarray): -+ if obj.dtype == np.bool_: -+ h = "".join(map(lambda n: "%01x" % n, obj.view(np.uint8))).upper() - else: -- raise NotImplementedError(type(obj)) -- -- --def shlo_constant(obj): -- if isinstance(obj, bool): -- v = str(obj).lower() -- return f"stablehlo.constant dense<{v}>" -- if isinstance(obj, np.ndarray): -- if obj.dtype == np.bool_: -- h = "".join(map(lambda n: "%01x" % n, obj.view(np.uint8))).upper() -- else: -- h = "".join(map(lambda n: "%02x" % n, obj.view(np.uint8))).upper() -- return f'stablehlo.constant dense<"0x{h}">' -- else: -- raise NotImplementedError(type(obj)) -+ h = "".join(map(lambda n: "%02x" % n, obj.view(np.uint8))).upper() -+ return f'stablehlo.constant dense<"0x{h}">' -+ else: -+ raise NotImplementedError(type(obj)) - - - if __name__ == "__main__": -- main() -+ main() diff --ruN 
a/stablehlo/stablehlo/CMakeLists.txt b/stablehlo/stablehlo/CMakeLists.txt --- stablehlo/stablehlo/CMakeLists.txt +++ stablehlo/stablehlo/CMakeLists.txt @@ -3830,21 +2943,4 @@ diff --ruN a/stablehlo/stablehlo/experimental/transforms/StablehloRefineShapes.c +} // namespace experimental +} // namespace stablehlo +} // namespace mlir -diff --ruN a/stablehlo/stablehlo/integrations/python/tests/stablehlo.py b/stablehlo/stablehlo/integrations/python/tests/stablehlo.py ---- stablehlo/stablehlo/integrations/python/tests/stablehlo.py -+++ stablehlo/stablehlo/integrations/python/tests/stablehlo.py -@@ -283,11 +283,13 @@ - expected = arg + arg - assert (actual == expected).all() - -+ - @run - def test_get_smaller_version(): - curr_version = stablehlo.get_current_version() - min_version = stablehlo.get_minimum_version() - assert stablehlo.get_smaller_version(curr_version, min_version) == min_version -+ - - @run - def test_serialization_apis(): diff --git a/third_party/stablehlo/workspace.bzl b/third_party/stablehlo/workspace.bzl index 15648501cf7f1c..4a70795435d0f9 100644 --- a/third_party/stablehlo/workspace.bzl +++ b/third_party/stablehlo/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): # LINT.IfChange - STABLEHLO_COMMIT = "f1f49945f3862a46ecd6c7fc111e9d7843c2b7da" - STABLEHLO_SHA256 = "5e79eaf23075e627c8cbcb4e8572f91fd2db70a4a96722f02490482f26177b4b" + STABLEHLO_COMMIT = "61826746d640f6342856af5ac71dabe6fbf37ff5" + STABLEHLO_SHA256 = "fece99979b885686438068a786895dcd9a559fdc86072178e42181404022d483" # LINT.ThenChange(Google-internal path) tf_http_archive( diff --git a/third_party/xla/third_party/stablehlo/temporary.patch b/third_party/xla/third_party/stablehlo/temporary.patch index 0eb67a7bda2ee2..cc9c7aee78573f 100755 --- a/third_party/xla/third_party/stablehlo/temporary.patch +++ b/third_party/xla/third_party/stablehlo/temporary.patch @@ -164,893 +164,6 @@ diff --ruN a/stablehlo/CMakeLists.txt 
b/stablehlo/CMakeLists.txt #------------------------------------------------------------------------------- # Directory setup -diff --ruN a/stablehlo/build_tools/math/generate_ChloDecompositionPatternsMath.py b/stablehlo/build_tools/math/generate_ChloDecompositionPatternsMath.py ---- stablehlo/build_tools/math/generate_ChloDecompositionPatternsMath.py -+++ stablehlo/build_tools/math/generate_ChloDecompositionPatternsMath.py -@@ -22,54 +22,64 @@ - - - def main(): -- try: -- import functional_algorithms as fa -- except ImportError as msg: -- print(f"Skipping: {msg}") -- return -+ try: -+ import functional_algorithms as fa -+ except ImportError as msg: -+ print(f"Skipping: {msg}") -+ return - -- fa_version = tuple(map(int, fa.__version__.split(".", 4)[:3])) -- if fa_version < (0, 4, 0): -- warnings.warn("functional_algorithm version 0.4.0 or newer is required," -- f" got {fa.__version__}") -- return -+ fa_version = tuple(map(int, fa.__version__.split(".", 4)[:3])) -+ if fa_version < (0, 4, 0): -+ warnings.warn( -+ "functional_algorithm version 0.4.0 or newer is required," -+ f" got {fa.__version__}" -+ ) -+ return - -- output_file = os.path.relpath( -- os.path.normpath( -- os.path.join(os.path.dirname(__file__), "..", "..", "stablehlo", -- "transforms", "ChloDecompositionPatternsMath.td")), -- os.getcwd()) -+ output_file = os.path.relpath( -+ os.path.normpath( -+ os.path.join( -+ os.path.dirname(__file__), -+ "..", -+ "..", -+ "stablehlo", -+ "transforms", -+ "ChloDecompositionPatternsMath.td", -+ ) -+ ), -+ os.getcwd(), -+ ) - -- sources = [] -- target = fa.targets.stablehlo -- for chloname, fname, args in [ -- ("CHLO_AsinOp", "complex_asin", ("z:complex",)), -- ("CHLO_AsinOp", "real_asin", ("x:float",)), -- ]: -- func = getattr(fa.algorithms, fname, None) -- if func is None: -- warnings.warn( -- "{fa.algorithms.__name__} does not define {fname}. 
Skipping.") -- continue -- ctx = fa.Context(paths=[fa.algorithms]) -- graph = ctx.trace(func, *args).implement_missing(target).simplify() -- graph.props.update(name=chloname) -- src = graph.tostring(target) -- sources.append(target.make_comment( -- func.__doc__)) if func.__doc__ else None -- sources[-1] += src -- source = "\n\n".join(sources) + "\n" -+ sources = [] -+ target = fa.targets.stablehlo -+ for chloname, fname, args in [ -+ ("CHLO_AsinOp", "complex_asin", ("z:complex",)), -+ ("CHLO_AsinOp", "real_asin", ("x:float",)), -+ ]: -+ func = getattr(fa.algorithms, fname, None) -+ if func is None: -+ warnings.warn( -+ "{fa.algorithms.__name__} does not define {fname}. Skipping." -+ ) -+ continue -+ ctx = fa.Context(paths=[fa.algorithms]) -+ graph = ctx.trace(func, *args).implement_missing(target).simplify() -+ graph.props.update(name=chloname) -+ src = graph.tostring(target) -+ sources.append(target.make_comment(func.__doc__)) if func.__doc__ else None -+ sources[-1] += src -+ source = "\n\n".join(sources) + "\n" - -- if os.path.isfile(output_file): -- f = open(output_file, "r") -- content = f.read() -- f.close() -- if content.endswith(source): -- print(f"{output_file} is up-to-date.") -- return -+ if os.path.isfile(output_file): -+ f = open(output_file, "r") -+ content = f.read() -+ f.close() -+ if content.endswith(source): -+ print(f"{output_file} is up-to-date.") -+ return - -- f = open(output_file, "w") -- f.write("""\ -+ f = open(output_file, "w") -+ f.write("""\ - /* Copyright 2024 The StableHLO Authors. - - Licensed under the Apache License, Version 2.0 (the "License"); -@@ -86,15 +96,14 @@ - ==============================================================================*/ - - """) -- f.write( -- target.make_comment(f"""\ -+ f.write(target.make_comment(f"""\ - - This file is generated using functional_algorithms tool ({fa.__version__}). 
- See build_tools/math/README.md for more information.""") + "\n") -- f.write(source) -- f.close() -- print(f"Created {output_file}") -+ f.write(source) -+ f.close() -+ print(f"Created {output_file}") - - - if __name__ == "__main__": -- main() -+ main() -diff --ruN a/stablehlo/build_tools/math/generate_tests.py b/stablehlo/build_tools/math/generate_tests.py ---- stablehlo/build_tools/math/generate_tests.py -+++ stablehlo/build_tools/math/generate_tests.py -@@ -55,374 +55,394 @@ - - - def main(): -- try: -- import functional_algorithms as fa -- except ImportError as msg: -- print(f"Skipping: {msg}") -- return -- -- fa_version = tuple(map(int, fa.__version__.split(".", 4)[:3])) -- if fa_version < (0, 4, 0): -- warnings.warn("functional_algorithm version 0.4.0 or newer is required," -- f" got {fa.__version__}") -- return -- -- target_dir = os.path.relpath( -- os.path.normpath( -- os.path.join(os.path.dirname(__file__), "..", "..", "stablehlo", -- "tests", "math")), os.getcwd()) -- -- flush_subnormals = False -- for op in operations: -- opname = op["name"] -- mpmath_opname = op.get("mpmath_name", opname) -- size_re = size_im = op.get("size", default_size) -- extra_prec_multiplier = op.get("extra_prec_multiplier", -- default_extra_prec_multiplier) -- max_ulp_difference = op.get("max_ulp_difference", -- default_max_ulp_difference) -- -- nmp = fa.utils.numpy_with_mpmath( -- extra_prec_multiplier=extra_prec_multiplier, -- flush_subnormals=flush_subnormals) -- for dtype in [np.complex64, np.complex128, np.float32, np.float64]: -- fi = np.finfo(dtype) -- -- float_dtype = to_float_dtype[dtype] -- finfo = np.finfo(float_dtype) -- -- if dtype in [np.complex64, np.complex128]: -- samples = fa.utils.complex_samples( -- size=(size_re, size_im), -- dtype=dtype, -- include_subnormal=not flush_subnormals).flatten() -- else: -- samples = fa.utils.real_samples( -- size=size_re * size_im, -- dtype=dtype, -- include_subnormal=not flush_subnormals).flatten() -- -- expected = getattr(nmp, 
mpmath_opname)(samples) -- -- module_name = f"{opname}_{dtype.__name__}" -- m = SSA.make_module(module_name) -- -- samples_func = m.make_function("samples", "", mlir_type(samples)) -- samples_func.assign(samples) -- samples_func.return_last() -- -- expected_func = m.make_function("expected", "", mlir_type(expected)) -- expected_func.assign(expected) -- expected_func.return_last() -- -- main_func = m.make_function("main", "", "", "public") -- -- ref_samples = main_func.call("samples") -- actual = main_func.composite(f"chlo.{opname}", ref_samples) -- expected = main_func.call("expected") -- -- main_func.void_call( -- "check.expect_close", -- actual, -- expected, -- f"max_ulp_difference = {max_ulp_difference}", -- atypes=", ".join(map(main_func.get_ref_type, -- [actual, expected])), -- ) -- main_func.void_call("func.return") -- source = str(m).rstrip() + "\n" -- fname = os.path.join(target_dir, f"{module_name}.mlir") -- if os.path.isfile(fname): -- f = open(fname, "r") -- content = f.read() -- f.close() -- if content.endswith(source): -- print(f"{fname} is up-to-date.") -- continue -- -- f = open(fname, "w") -- f.write( -- "// RUN: stablehlo-opt --chlo-legalize-to-stablehlo %s | stablehlo-translate --interpret\n" -- ) -- f.write( -- "// This file is generated, see build_tools/math/README.md for more information.\n" -- ) -- f.write(source) -- f.close() -- print(f"Created {fname}") -- -- # Testing ULP difference -- for dtype in [np.float32, np.float64]: -- fi = np.finfo(dtype) -- -- max_ulp_difference = 0 -- min_ulp_difference = 0 -- -- finfo = np.finfo(dtype) -- module_name = f"ulp_difference_{dtype.__name__}" -- m = SSA.make_module(module_name) -- -- main_func = m.make_function("main", "", "", "public") -- -- def samples_generator(): -- data = [ -- -finfo.max, -1e9 - 1.2, -finfo.smallest_normal, -- -finfo.smallest_subnormal, 0, finfo.smallest_subnormal, -- finfo.smallest_normal, 1.2, 1e9 -- ] -- for expected_ulp_difference in [0, 1, 5, 50]: -- if 
expected_ulp_difference == 0: -- actual = np.array(data + [np.inf, -np.inf, np.nan], -- dtype=dtype) -- else: -- actual = np.array(data, dtype=dtype) -- shifted = actual -- for i in range(expected_ulp_difference): -- shifted = np.nextafter(shifted, np.inf, dtype=dtype) -- label = str(expected_ulp_difference) -- yield actual, shifted, expected_ulp_difference, label -- -- actual = np.array([np.inf] * 5, dtype=dtype) -- shifted = np.array([-np.inf, np.nan, 0, 1.2, finfo.max], -- dtype=dtype) -- yield actual, shifted, 2**64 - 1, "nonfinite" -- -- for actual, shifted, expected_ulp_difference, label in samples_generator( -- ): -- -- actual_func = m.make_function(f"actual_{label}", "", -- mlir_type(actual)) -- actual_func.comment(f'{list(actual)}') -- actual_func.assign(actual) -- actual_func.return_last() -- -- shifted_func = m.make_function(f"shifted_{label}", "", -- mlir_type(shifted)) -- shifted_func.comment(f'{list(shifted)}') -- shifted_func.assign(shifted) -- shifted_func.return_last() -- -- actual_values = main_func.call(f"actual_{label}") -- shifted_values = main_func.call(f"shifted_{label}") -- -- main_func.void_call( -- "check.expect_close", -- actual_values, -- shifted_values, -- f"max_ulp_difference = {expected_ulp_difference}", -- f"min_ulp_difference = {expected_ulp_difference}", -- atypes=", ".join( -- map(main_func.get_ref_type, -- [actual_values, shifted_values])), -- ) -- -- main_func.void_call("func.return") -- source = str(m).rstrip() + "\n" -- fname = os.path.join(target_dir, f"{module_name}.mlir") -- if os.path.isfile(fname): -- f = open(fname, "r") -- content = f.read() -- f.close() -- if content.endswith(source): -- print(f"{fname} is up-to-date.") -- continue -- -- f = open(fname, "w") -- f.write( -- "// RUN: stablehlo-opt --chlo-legalize-to-stablehlo %s | stablehlo-translate --interpret\n" -- ) -- f.write( -- "// This file is generated, see build_tools/math/README.md for more information.\n" -- ) -- f.write(source) -+ try: -+ import 
functional_algorithms as fa -+ except ImportError as msg: -+ print(f"Skipping: {msg}") -+ return -+ -+ fa_version = tuple(map(int, fa.__version__.split(".", 4)[:3])) -+ if fa_version < (0, 4, 0): -+ warnings.warn( -+ "functional_algorithm version 0.4.0 or newer is required," -+ f" got {fa.__version__}" -+ ) -+ return -+ -+ target_dir = os.path.relpath( -+ os.path.normpath( -+ os.path.join( -+ os.path.dirname(__file__), -+ "..", -+ "..", -+ "stablehlo", -+ "tests", -+ "math", -+ ) -+ ), -+ os.getcwd(), -+ ) -+ -+ flush_subnormals = False -+ for op in operations: -+ opname = op["name"] -+ mpmath_opname = op.get("mpmath_name", opname) -+ size_re = size_im = op.get("size", default_size) -+ extra_prec_multiplier = op.get( -+ "extra_prec_multiplier", default_extra_prec_multiplier -+ ) -+ max_ulp_difference = op.get( -+ "max_ulp_difference", default_max_ulp_difference -+ ) -+ -+ nmp = fa.utils.numpy_with_mpmath( -+ extra_prec_multiplier=extra_prec_multiplier, -+ flush_subnormals=flush_subnormals, -+ ) -+ for dtype in [np.complex64, np.complex128, np.float32, np.float64]: -+ fi = np.finfo(dtype) -+ -+ float_dtype = to_float_dtype[dtype] -+ finfo = np.finfo(float_dtype) -+ -+ if dtype in [np.complex64, np.complex128]: -+ samples = fa.utils.complex_samples( -+ size=(size_re, size_im), -+ dtype=dtype, -+ include_subnormal=not flush_subnormals, -+ ).flatten() -+ else: -+ samples = fa.utils.real_samples( -+ size=size_re * size_im, -+ dtype=dtype, -+ include_subnormal=not flush_subnormals, -+ ).flatten() -+ -+ expected = getattr(nmp, mpmath_opname)(samples) -+ -+ module_name = f"{opname}_{dtype.__name__}" -+ m = SSA.make_module(module_name) -+ -+ samples_func = m.make_function("samples", "", mlir_type(samples)) -+ samples_func.assign(samples) -+ samples_func.return_last() -+ -+ expected_func = m.make_function("expected", "", mlir_type(expected)) -+ expected_func.assign(expected) -+ expected_func.return_last() -+ -+ main_func = m.make_function("main", "", "", "public") -+ -+ 
ref_samples = main_func.call("samples") -+ actual = main_func.composite(f"chlo.{opname}", ref_samples) -+ expected = main_func.call("expected") -+ -+ main_func.void_call( -+ "check.expect_close", -+ actual, -+ expected, -+ f"max_ulp_difference = {max_ulp_difference}", -+ atypes=", ".join(map(main_func.get_ref_type, [actual, expected])), -+ ) -+ main_func.void_call("func.return") -+ source = str(m).rstrip() + "\n" -+ fname = os.path.join(target_dir, f"{module_name}.mlir") -+ if os.path.isfile(fname): -+ f = open(fname, "r") -+ content = f.read() - f.close() -- print(f"Created {fname}") -+ if content.endswith(source): -+ print(f"{fname} is up-to-date.") -+ continue -+ -+ f = open(fname, "w") -+ f.write( -+ "// RUN: stablehlo-opt --chlo-legalize-to-stablehlo %s |" -+ " stablehlo-translate --interpret\n" -+ ) -+ f.write( -+ "// This file is generated, see build_tools/math/README.md for more" -+ " information.\n" -+ ) -+ f.write(source) -+ f.close() -+ print(f"Created {fname}") -+ -+ # Testing ULP difference -+ for dtype in [np.float32, np.float64]: -+ fi = np.finfo(dtype) -+ -+ max_ulp_difference = 0 -+ min_ulp_difference = 0 -+ -+ finfo = np.finfo(dtype) -+ module_name = f"ulp_difference_{dtype.__name__}" -+ m = SSA.make_module(module_name) -+ -+ main_func = m.make_function("main", "", "", "public") -+ -+ def samples_generator(): -+ data = [ -+ -finfo.max, -+ -1e9 - 1.2, -+ -finfo.smallest_normal, -+ -finfo.smallest_subnormal, -+ 0, -+ finfo.smallest_subnormal, -+ finfo.smallest_normal, -+ 1.2, -+ 1e9, -+ ] -+ for expected_ulp_difference in [0, 1, 5, 50]: -+ if expected_ulp_difference == 0: -+ actual = np.array(data + [np.inf, -np.inf, np.nan], dtype=dtype) -+ else: -+ actual = np.array(data, dtype=dtype) -+ shifted = actual -+ for i in range(expected_ulp_difference): -+ shifted = np.nextafter(shifted, np.inf, dtype=dtype) -+ label = str(expected_ulp_difference) -+ yield actual, shifted, expected_ulp_difference, label -+ -+ actual = np.array([np.inf] * 5, dtype=dtype) 
-+ shifted = np.array([-np.inf, np.nan, 0, 1.2, finfo.max], dtype=dtype) -+ yield actual, shifted, 2**64 - 1, "nonfinite" -+ -+ for actual, shifted, expected_ulp_difference, label in samples_generator(): -+ -+ actual_func = m.make_function(f"actual_{label}", "", mlir_type(actual)) -+ actual_func.comment(f"{list(actual)}") -+ actual_func.assign(actual) -+ actual_func.return_last() -+ -+ shifted_func = m.make_function(f"shifted_{label}", "", mlir_type(shifted)) -+ shifted_func.comment(f"{list(shifted)}") -+ shifted_func.assign(shifted) -+ shifted_func.return_last() -+ -+ actual_values = main_func.call(f"actual_{label}") -+ shifted_values = main_func.call(f"shifted_{label}") -+ -+ main_func.void_call( -+ "check.expect_close", -+ actual_values, -+ shifted_values, -+ f"max_ulp_difference = {expected_ulp_difference}", -+ f"min_ulp_difference = {expected_ulp_difference}", -+ atypes=", ".join( -+ map(main_func.get_ref_type, [actual_values, shifted_values]) -+ ), -+ ) -+ -+ main_func.void_call("func.return") -+ source = str(m).rstrip() + "\n" -+ fname = os.path.join(target_dir, f"{module_name}.mlir") -+ if os.path.isfile(fname): -+ f = open(fname, "r") -+ content = f.read() -+ f.close() -+ if content.endswith(source): -+ print(f"{fname} is up-to-date.") -+ continue -+ -+ f = open(fname, "w") -+ f.write( -+ "// RUN: stablehlo-opt --chlo-legalize-to-stablehlo %s |" -+ " stablehlo-translate --interpret\n" -+ ) -+ f.write( -+ "// This file is generated, see build_tools/math/README.md for more" -+ " information.\n" -+ ) -+ f.write(source) -+ f.close() -+ print(f"Created {fname}") - - - class Block: -- """A data structure used in SSA""" -- -- def __init__(self, parent, prefix, suffix, start_counter=0): -- self.parent = parent -- self.prefix = prefix -- self.suffix = suffix -- self.counter = start_counter -- self.statements = {} -- -- def tostr(self, tab=""): -- lines = [] -- lines.append(tab + self.prefix) -- for i in sorted(self.statements): -- op, expr, typ = self.statements[i] 
-- if op == "//": -- lines.append(f"{tab} {op} {expr}") -- elif typ: -- lines.append(f"{tab} {op} {expr} : {typ}") -- else: -- assert not expr, (op, expr, typ) -- lines.append(f"{tab} {op}") -- lines.append(tab + self.suffix) -- return "\n".join(lines) -- -- def comment(self, message): -- # add comment to code -- self.statements[self.counter] = ("//", message, None) -- self.counter += 1 -- -- def assign(self, expr, typ=None): -- if isinstance(expr, np.ndarray): -- assert typ is None, typ -- typ = mlir_type(expr) -- expr = shlo_constant(expr) -- elif isinstance(expr, str) and typ is not None: -- pass -- elif isinstance(expr, bool) and typ is not None: -- expr = shlo_constant(expr) -- else: -- raise NotImplementedError((expr, typ)) -- target = f"%{self.counter}" -- self.statements[self.counter] = (f"{target} =", expr, typ) -- self.counter += 1 -- return target -- -- def call(self, name, *args): -- # call function created with make_function -- sargs = ", ".join(args) -- return self.assign(f"call @{name}({sargs})", -- typ=self.get_function_type(name)) -- -- def composite(self, name, *args, **options): -- sargs = ", ".join(args) -- atypes = tuple(map(self.get_ref_type, args)) -- rtype = options.get("rtype") -- if rtype is None: -- # assuming the first op argument defines the op type -- rtype = atypes[0] -- sargs = ", ".join(args) -- typ = f'({", ".join(atypes)}) -> {rtype}' -- return self.assign(f'"{name}"({sargs})', typ=typ) -- -- def void_call(self, name, *args, **options): -- # call function that has void return -- if args: -- sargs = ", ".join(args) -- atypes = options.get("atypes") -- if atypes is None: -- atypes = ", ".join(map(self.get_ref_type, args)) -- self.statements[self.counter] = (name, f"{sargs}", f"{atypes}") -- else: -- self.statements[self.counter] = (name, "", "") -- self.counter += 1 -- -- def apply(self, op, *args, **options): -- sargs = ", ".join(args) -- atypes = tuple(map(self.get_ref_type, args)) -- rtype = options.get("rtype") -- if rtype is 
None: -- # assuming the first op argument defines the op type -- rtype = atypes[0] -- typ = f'({", ".join(atypes)}) -> {rtype}' -- return self.assign(f"{op} {sargs}", typ=typ) -- -- def return_last(self): -- ref = f"%{self.counter - 1}" -- self.statements[self.counter] = ("return", ref, self.get_ref_type(ref)) -- self.counter += 1 -- -- @property -- def is_function(self): -- return self.prefix.startwith("func.func") -- -- @property -- def function_name(self): -- if self.prefix.startswith("func.func"): -- i = self.prefix.find("@") -- j = self.prefix.find("(", i) -- assert -1 not in {i, j}, self.prefix -- return self.prefix[i + 1:j] -- -- @property -- def function_type(self): -- if self.prefix.startswith("func.func"): -- i = self.prefix.find("(", self.prefix.find("@")) -- j = self.prefix.find("{", i) -- assert -1 not in {i, j}, self.prefix -- return self.prefix[i:j].strip() -- -- def get_function_type(self, name): -- for block in self.parent.blocks: -- if block.function_name == name: -- return block.function_type -- -- def get_ref_type(self, ref): -- assert ref.startswith("%"), ref -- counter = int(ref[1:]) -- typ = self.statements[counter][-1] -- return typ.rsplit("->", 1)[-1].strip() -+ """A data structure used in SSA""" -+ -+ def __init__(self, parent, prefix, suffix, start_counter=0): -+ self.parent = parent -+ self.prefix = prefix -+ self.suffix = suffix -+ self.counter = start_counter -+ self.statements = {} -+ -+ def tostr(self, tab=""): -+ lines = [] -+ lines.append(tab + self.prefix) -+ for i in sorted(self.statements): -+ op, expr, typ = self.statements[i] -+ if op == "//": -+ lines.append(f"{tab} {op} {expr}") -+ elif typ: -+ lines.append(f"{tab} {op} {expr} : {typ}") -+ else: -+ assert not expr, (op, expr, typ) -+ lines.append(f"{tab} {op}") -+ lines.append(tab + self.suffix) -+ return "\n".join(lines) -+ -+ def comment(self, message): -+ # add comment to code -+ self.statements[self.counter] = ("//", message, None) -+ self.counter += 1 -+ -+ def 
assign(self, expr, typ=None): -+ if isinstance(expr, np.ndarray): -+ assert typ is None, typ -+ typ = mlir_type(expr) -+ expr = shlo_constant(expr) -+ elif isinstance(expr, str) and typ is not None: -+ pass -+ elif isinstance(expr, bool) and typ is not None: -+ expr = shlo_constant(expr) -+ else: -+ raise NotImplementedError((expr, typ)) -+ target = f"%{self.counter}" -+ self.statements[self.counter] = (f"{target} =", expr, typ) -+ self.counter += 1 -+ return target -+ -+ def call(self, name, *args): -+ # call function created with make_function -+ sargs = ", ".join(args) -+ return self.assign( -+ f"call @{name}({sargs})", typ=self.get_function_type(name) -+ ) -+ -+ def composite(self, name, *args, **options): -+ sargs = ", ".join(args) -+ atypes = tuple(map(self.get_ref_type, args)) -+ rtype = options.get("rtype") -+ if rtype is None: -+ # assuming the first op argument defines the op type -+ rtype = atypes[0] -+ sargs = ", ".join(args) -+ typ = f'({", ".join(atypes)}) -> {rtype}' -+ return self.assign(f'"{name}"({sargs})', typ=typ) -+ -+ def void_call(self, name, *args, **options): -+ # call function that has void return -+ if args: -+ sargs = ", ".join(args) -+ atypes = options.get("atypes") -+ if atypes is None: -+ atypes = ", ".join(map(self.get_ref_type, args)) -+ self.statements[self.counter] = (name, f"{sargs}", f"{atypes}") -+ else: -+ self.statements[self.counter] = (name, "", "") -+ self.counter += 1 -+ -+ def apply(self, op, *args, **options): -+ sargs = ", ".join(args) -+ atypes = tuple(map(self.get_ref_type, args)) -+ rtype = options.get("rtype") -+ if rtype is None: -+ # assuming the first op argument defines the op type -+ rtype = atypes[0] -+ typ = f'({", ".join(atypes)}) -> {rtype}' -+ return self.assign(f"{op} {sargs}", typ=typ) -+ -+ def return_last(self): -+ ref = f"%{self.counter - 1}" -+ self.statements[self.counter] = ("return", ref, self.get_ref_type(ref)) -+ self.counter += 1 -+ -+ @property -+ def is_function(self): -+ return 
self.prefix.startwith("func.func") -+ -+ @property -+ def function_name(self): -+ if self.prefix.startswith("func.func"): -+ i = self.prefix.find("@") -+ j = self.prefix.find("(", i) -+ assert -1 not in {i, j}, self.prefix -+ return self.prefix[i + 1 : j] -+ -+ @property -+ def function_type(self): -+ if self.prefix.startswith("func.func"): -+ i = self.prefix.find("(", self.prefix.find("@")) -+ j = self.prefix.find("{", i) -+ assert -1 not in {i, j}, self.prefix -+ return self.prefix[i:j].strip() -+ -+ def get_function_type(self, name): -+ for block in self.parent.blocks: -+ if block.function_name == name: -+ return block.function_type -+ -+ def get_ref_type(self, ref): -+ assert ref.startswith("%"), ref -+ counter = int(ref[1:]) -+ typ = self.statements[counter][-1] -+ return typ.rsplit("->", 1)[-1].strip() - - - class SSA: -- """A light-weight SSA form factory.""" -- -- def __init__(self, prefix, suffix): -- self.prefix = prefix -- self.suffix = suffix -- self.blocks = [] -- -- @classmethod -- def make_module(cls, name): -- return SSA(f"module @{name} {{", "}") -- -- def make_function(self, name, args, rtype, attrs="private"): -- if rtype: -- b = Block(self, f"func.func {attrs} @{name}({args}) -> {rtype} {{", -- "}") -- else: -- b = Block(self, f"func.func {attrs} @{name}({args}) {{", "}") -- self.blocks.append(b) -- return b -- -- def tostr(self, tab=""): -- lines = [] -- lines.append(tab + self.prefix) -- for b in self.blocks: -- lines.extend(b.tostr(tab=tab + " ").split("\n")) -- lines.append(tab + self.suffix) -- return "\n".join(lines) -- -- def __str__(self): -- return self.tostr() -+ """A light-weight SSA form factory.""" -+ -+ def __init__(self, prefix, suffix): -+ self.prefix = prefix -+ self.suffix = suffix -+ self.blocks = [] -+ -+ @classmethod -+ def make_module(cls, name): -+ return SSA(f"module @{name} {{", "}") -+ -+ def make_function(self, name, args, rtype, attrs="private"): -+ if rtype: -+ b = Block(self, f"func.func {attrs} @{name}({args}) -> 
{rtype} {{", "}") -+ else: -+ b = Block(self, f"func.func {attrs} @{name}({args}) {{", "}") -+ self.blocks.append(b) -+ return b -+ -+ def tostr(self, tab=""): -+ lines = [] -+ lines.append(tab + self.prefix) -+ for b in self.blocks: -+ lines.extend(b.tostr(tab=tab + " ").split("\n")) -+ lines.append(tab + self.suffix) -+ return "\n".join(lines) -+ -+ def __str__(self): -+ return self.tostr() - - - def mlir_type(obj): -- if isinstance(obj, np.ndarray): -- s = "x".join(map(str, obj.shape)) -- t = { -- np.bool_: "i1", -- np.float16: "f16", -- np.float32: "f32", -- np.float64: "f64", -- np.complex64: "complex", -- np.complex128: "complex", -- }[obj.dtype.type] -- return f"tensor<{s}x{t}>" -+ if isinstance(obj, np.ndarray): -+ s = "x".join(map(str, obj.shape)) -+ t = { -+ np.bool_: "i1", -+ np.float16: "f16", -+ np.float32: "f32", -+ np.float64: "f64", -+ np.complex64: "complex", -+ np.complex128: "complex", -+ }[obj.dtype.type] -+ return f"tensor<{s}x{t}>" -+ else: -+ raise NotImplementedError(type(obj)) -+ -+ -+def shlo_constant(obj): -+ if isinstance(obj, bool): -+ v = str(obj).lower() -+ return f"stablehlo.constant dense<{v}>" -+ if isinstance(obj, np.ndarray): -+ if obj.dtype == np.bool_: -+ h = "".join(map(lambda n: "%01x" % n, obj.view(np.uint8))).upper() - else: -- raise NotImplementedError(type(obj)) -- -- --def shlo_constant(obj): -- if isinstance(obj, bool): -- v = str(obj).lower() -- return f"stablehlo.constant dense<{v}>" -- if isinstance(obj, np.ndarray): -- if obj.dtype == np.bool_: -- h = "".join(map(lambda n: "%01x" % n, obj.view(np.uint8))).upper() -- else: -- h = "".join(map(lambda n: "%02x" % n, obj.view(np.uint8))).upper() -- return f'stablehlo.constant dense<"0x{h}">' -- else: -- raise NotImplementedError(type(obj)) -+ h = "".join(map(lambda n: "%02x" % n, obj.view(np.uint8))).upper() -+ return f'stablehlo.constant dense<"0x{h}">' -+ else: -+ raise NotImplementedError(type(obj)) - - - if __name__ == "__main__": -- main() -+ main() diff --ruN 
a/stablehlo/stablehlo/CMakeLists.txt b/stablehlo/stablehlo/CMakeLists.txt --- stablehlo/stablehlo/CMakeLists.txt +++ stablehlo/stablehlo/CMakeLists.txt @@ -3830,21 +2943,4 @@ diff --ruN a/stablehlo/stablehlo/experimental/transforms/StablehloRefineShapes.c +} // namespace experimental +} // namespace stablehlo +} // namespace mlir -diff --ruN a/stablehlo/stablehlo/integrations/python/tests/stablehlo.py b/stablehlo/stablehlo/integrations/python/tests/stablehlo.py ---- stablehlo/stablehlo/integrations/python/tests/stablehlo.py -+++ stablehlo/stablehlo/integrations/python/tests/stablehlo.py -@@ -283,11 +283,13 @@ - expected = arg + arg - assert (actual == expected).all() - -+ - @run - def test_get_smaller_version(): - curr_version = stablehlo.get_current_version() - min_version = stablehlo.get_minimum_version() - assert stablehlo.get_smaller_version(curr_version, min_version) == min_version -+ - - @run - def test_serialization_apis(): diff --git a/third_party/xla/third_party/stablehlo/workspace.bzl b/third_party/xla/third_party/stablehlo/workspace.bzl index 15648501cf7f1c..4a70795435d0f9 100644 --- a/third_party/xla/third_party/stablehlo/workspace.bzl +++ b/third_party/xla/third_party/stablehlo/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): # LINT.IfChange - STABLEHLO_COMMIT = "f1f49945f3862a46ecd6c7fc111e9d7843c2b7da" - STABLEHLO_SHA256 = "5e79eaf23075e627c8cbcb4e8572f91fd2db70a4a96722f02490482f26177b4b" + STABLEHLO_COMMIT = "61826746d640f6342856af5ac71dabe6fbf37ff5" + STABLEHLO_SHA256 = "fece99979b885686438068a786895dcd9a559fdc86072178e42181404022d483" # LINT.ThenChange(Google-internal path) tf_http_archive( From 04938cc96427b81171d87a2abedc4e2df3c778ff Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Fri, 21 Jun 2024 23:11:03 -0700 Subject: [PATCH 139/256] Automated Code Change PiperOrigin-RevId: 645604862 --- tensorflow/core/profiler/lib/BUILD | 3 ++- tensorflow/core/profiler/lib/profiler_disabled_test.cc | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/tensorflow/core/profiler/lib/BUILD b/tensorflow/core/profiler/lib/BUILD index 27155ff3c5e393..ae0cd00e1a95e3 100644 --- a/tensorflow/core/profiler/lib/BUILD +++ b/tensorflow/core/profiler/lib/BUILD @@ -249,10 +249,11 @@ tf_cc_test( name = "profiler_disabled_test", srcs = ["profiler_disabled_test.cc"], deps = [ - ":profiler_lock", "//tensorflow/core:lib", "//tensorflow/core:test", "//tensorflow/core:test_main", + "@com_google_absl//absl/status:statusor", + "@local_tsl//tsl/profiler/lib:profiler_lock", ], ) diff --git a/tensorflow/core/profiler/lib/profiler_disabled_test.cc b/tensorflow/core/profiler/lib/profiler_disabled_test.cc index 7bf0b9090a31f1..f55b50ad0375f8 100644 --- a/tensorflow/core/profiler/lib/profiler_disabled_test.cc +++ b/tensorflow/core/profiler/lib/profiler_disabled_test.cc @@ -14,9 +14,10 @@ limitations under the License. ==============================================================================*/ #include +#include "absl/status/statusor.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/test.h" -#include "tensorflow/core/profiler/lib/profiler_lock.h" +#include "tsl/profiler/lib/profiler_lock.h" namespace tensorflow { namespace profiler { From a025656d3f0afad668ae947187bf48e24d18ae03 Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Sat, 22 Jun 2024 02:00:05 -0700 Subject: [PATCH 140/256] Automated Code Change PiperOrigin-RevId: 645631649 --- tensorflow/compiler/mlir/tfrt/BUILD | 2 ++ tensorflow/compiler/mlir/tfrt/saved_model/saved_model.cc | 2 ++ 2 files changed, 4 insertions(+) diff --git a/tensorflow/compiler/mlir/tfrt/BUILD b/tensorflow/compiler/mlir/tfrt/BUILD index 3e991f23808a2f..d3a0925d1020b1 100644 --- a/tensorflow/compiler/mlir/tfrt/BUILD +++ b/tensorflow/compiler/mlir/tfrt/BUILD @@ -328,6 +328,8 @@ cc_library( "//tensorflow/core:protos_all_cc", "//tensorflow/core/platform:status", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", "@llvm-project//mlir:FuncDialect", diff --git a/tensorflow/compiler/mlir/tfrt/saved_model/saved_model.cc b/tensorflow/compiler/mlir/tfrt/saved_model/saved_model.cc index 93d50a012a6fed..407871ecbdcd94 100644 --- a/tensorflow/compiler/mlir/tfrt/saved_model/saved_model.cc +++ b/tensorflow/compiler/mlir/tfrt/saved_model/saved_model.cc @@ -17,6 +17,8 @@ limitations under the License. #include +#include "absl/status/status.h" +#include "absl/status/statusor.h" #include "llvm/ADT/STLFunctionalExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" From 342af9b54852d96a3ed119b8c96195e447f57d2e Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sat, 22 Jun 2024 02:02:09 -0700 Subject: [PATCH 141/256] compat: Update forward compatibility horizon to 2024-06-22 PiperOrigin-RevId: 645631949 --- tensorflow/python/compat/compat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py index d11ceda88f94e8..49606d977e5051 100644 --- a/tensorflow/python/compat/compat.py +++ b/tensorflow/python/compat/compat.py @@ -29,7 +29,7 @@ # This value changes every day with an automatic CL. 
It can be modified in code # via `forward_compatibility_horizon()` or with the environment variable # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date. -_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2024, 6, 21) +_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2024, 6, 22) _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS" _FORWARD_COMPATIBILITY_DATE_NUMBER = None From a2c59354200eac6e6689644d9c05589c1320feec Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sat, 22 Jun 2024 02:03:42 -0700 Subject: [PATCH 142/256] Update GraphDef version to 1901. PiperOrigin-RevId: 645632322 --- tensorflow/core/public/version.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h index eaab911a44b57f..e462e1429448b5 100644 --- a/tensorflow/core/public/version.h +++ b/tensorflow/core/public/version.h @@ -108,7 +108,7 @@ limitations under the License. #define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0 -#define TF_GRAPH_DEF_VERSION 1900 // Updated: 2024/6/21 +#define TF_GRAPH_DEF_VERSION 1901 // Updated: 2024/6/22 // Checkpoint compatibility versions (the versions field in SavedSliceMeta). // From 09a5049be3fce9f9d050a7266a31cf24141d0e88 Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Sat, 22 Jun 2024 02:12:19 -0700 Subject: [PATCH 143/256] Automated Code Change PiperOrigin-RevId: 645633722 --- tensorflow/core/framework/BUILD | 1 + .../simple_hash_table_kernel.cc | 17 ++++++++++++++++- .../simple_hash_table/simple_hash_table_op.cc | 6 ++++++ 3 files changed, 23 insertions(+), 1 deletion(-) diff --git a/tensorflow/core/framework/BUILD b/tensorflow/core/framework/BUILD index 257ce3f9dd2bb3..303edd624df834 100644 --- a/tensorflow/core/framework/BUILD +++ b/tensorflow/core/framework/BUILD @@ -719,6 +719,7 @@ cc_library( "//learning/deepmind/tensorflow/queues:__pkg__", "//learning/deepmind/tensorflow/sstable:__pkg__", "//tensorflow/compiler/mlir/tools/kernel_gen:__pkg__", + "//tensorflow/examples/custom_ops_doc/simple_hash_table:__pkg__", "//third_party/py/grain/_src/tensorflow/ops:__pkg__", "//waymo/ml/compiler/frontend/kernels:__pkg__", ], diff --git a/tensorflow/examples/custom_ops_doc/simple_hash_table/simple_hash_table_kernel.cc b/tensorflow/examples/custom_ops_doc/simple_hash_table/simple_hash_table_kernel.cc index 67183be4ae0187..577c5d75ca3f28 100644 --- a/tensorflow/examples/custom_ops_doc/simple_hash_table/simple_hash_table_kernel.cc +++ b/tensorflow/examples/custom_ops_doc/simple_hash_table/simple_hash_table_kernel.cc @@ -16,10 +16,25 @@ limitations under the License. 
#include #include "absl/container/flat_hash_map.h" -#include "absl/strings/str_cat.h" +#include "tensorflow/core/framework/allocator.h" #include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/op_requires.h" +#include "tensorflow/core/framework/resource_base.h" +#include "tensorflow/core/framework/resource_handle.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/lib/gtl/map_util.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/strcat.h" +#include "tensorflow/core/platform/tstring.h" +#include "tensorflow/core/platform/types.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" +#include "tsl/platform/thread_annotations.h" // Please use the appropriate namespace for your project namespace tensorflow { diff --git a/tensorflow/examples/custom_ops_doc/simple_hash_table/simple_hash_table_op.cc b/tensorflow/examples/custom_ops_doc/simple_hash_table/simple_hash_table_op.cc index bc7461fbd77b45..76a343830b49ad 100644 --- a/tensorflow/examples/custom_ops_doc/simple_hash_table/simple_hash_table_op.cc +++ b/tensorflow/examples/custom_ops_doc/simple_hash_table/simple_hash_table_op.cc @@ -17,6 +17,12 @@ limitations under the License. 
#include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/types.h" +#include "tsl/platform/errors.h" // Please use the appropriate namespace for your project namespace tensorflow { From 0e07c958546762913beae6f6c4d969799a56845d Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sat, 22 Jun 2024 10:41:06 -0700 Subject: [PATCH 144/256] Automated Code Change PiperOrigin-RevId: 645692440 --- tensorflow/core/tfrt/mlrt/bytecode/BUILD | 6 ++++++ tensorflow/core/tfrt/mlrt/bytecode/bytecode_test.cc | 1 + tensorflow/core/tfrt/mlrt/bytecode/executable_test.cc | 1 + tensorflow/core/tfrt/mlrt/bytecode/function_test.cc | 2 ++ tensorflow/core/tfrt/mlrt/bytecode/kernel_test.cc | 1 + tensorflow/core/tfrt/mlrt/bytecode/span_test.cc | 1 + 6 files changed, 12 insertions(+) diff --git a/tensorflow/core/tfrt/mlrt/bytecode/BUILD b/tensorflow/core/tfrt/mlrt/bytecode/BUILD index 9e1f0a979c846a..30f3d7b1b6508c 100644 --- a/tensorflow/core/tfrt/mlrt/bytecode/BUILD +++ b/tensorflow/core/tfrt/mlrt/bytecode/BUILD @@ -56,6 +56,7 @@ tf_cc_test( srcs = ["bytecode_test.cc"], deps = [ ":bytecode", + "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest_main", ], ) @@ -64,6 +65,7 @@ tf_cc_test( name = "kernel_test", srcs = ["kernel_test.cc"], deps = [ + ":bytecode", ":kernel", "@com_google_googletest//:gtest_main", ], @@ -73,7 +75,9 @@ tf_cc_test( name = "function_test", srcs = ["function_test.cc"], deps = [ + ":bytecode", ":function", + ":kernel", "@com_google_googletest//:gtest_main", ], ) @@ -82,6 +86,7 @@ tf_cc_test( name = "executable_test", srcs = ["executable_test.cc"], deps = [ + ":bytecode", ":executable", "@com_google_googletest//:gtest_main", ], @@ -91,6 +96,7 @@ tf_cc_test( name = "span_test", srcs 
= ["span_test.cc"], deps = [ + ":bytecode", ":span", "@com_google_googletest//:gtest_main", ], diff --git a/tensorflow/core/tfrt/mlrt/bytecode/bytecode_test.cc b/tensorflow/core/tfrt/mlrt/bytecode/bytecode_test.cc index 124c850d5275f2..dd0e8a1c18d4e1 100644 --- a/tensorflow/core/tfrt/mlrt/bytecode/bytecode_test.cc +++ b/tensorflow/core/tfrt/mlrt/bytecode/bytecode_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "absl/strings/string_view.h" namespace mlrt { namespace bc { diff --git a/tensorflow/core/tfrt/mlrt/bytecode/executable_test.cc b/tensorflow/core/tfrt/mlrt/bytecode/executable_test.cc index e10800482c4297..73d673427dc54b 100644 --- a/tensorflow/core/tfrt/mlrt/bytecode/executable_test.cc +++ b/tensorflow/core/tfrt/mlrt/bytecode/executable_test.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "tensorflow/core/tfrt/mlrt/bytecode/bytecode.h" namespace mlrt { namespace bc { diff --git a/tensorflow/core/tfrt/mlrt/bytecode/function_test.cc b/tensorflow/core/tfrt/mlrt/bytecode/function_test.cc index 2a6a4e75eb9349..ab354fb8ad79d0 100644 --- a/tensorflow/core/tfrt/mlrt/bytecode/function_test.cc +++ b/tensorflow/core/tfrt/mlrt/bytecode/function_test.cc @@ -16,6 +16,8 @@ limitations under the License. #include #include +#include "tensorflow/core/tfrt/mlrt/bytecode/bytecode.h" +#include "tensorflow/core/tfrt/mlrt/bytecode/kernel.h" namespace mlrt { namespace bc { diff --git a/tensorflow/core/tfrt/mlrt/bytecode/kernel_test.cc b/tensorflow/core/tfrt/mlrt/bytecode/kernel_test.cc index b8d506aceb05d9..e6af0978aa1b3e 100644 --- a/tensorflow/core/tfrt/mlrt/bytecode/kernel_test.cc +++ b/tensorflow/core/tfrt/mlrt/bytecode/kernel_test.cc @@ -16,6 +16,7 @@ limitations under the License. 
#include #include +#include "tensorflow/core/tfrt/mlrt/bytecode/bytecode.h" namespace mlrt { namespace bc { diff --git a/tensorflow/core/tfrt/mlrt/bytecode/span_test.cc b/tensorflow/core/tfrt/mlrt/bytecode/span_test.cc index da7fa2ea5d4c53..70c70d8a3d6121 100644 --- a/tensorflow/core/tfrt/mlrt/bytecode/span_test.cc +++ b/tensorflow/core/tfrt/mlrt/bytecode/span_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "tensorflow/core/tfrt/mlrt/bytecode/bytecode.h" namespace mlrt { namespace bc { From 0f0409a71fb03123988860c1a6df593cfc7b04b8 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sat, 22 Jun 2024 11:16:29 -0700 Subject: [PATCH 145/256] Adds an option to enable / disable post-processing. PiperOrigin-RevId: 645697145 --- .../xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc | 2 +- .../xla/hlo/experimental/auto_sharding/auto_sharding_option.h | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc index 3b81bc79578a2e..6c23923952f623 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc @@ -4208,7 +4208,7 @@ absl::StatusOr AutoShardingImplementation::RunAutoSharding( // ----- Set Sharding ----- SetHloSharding(sequence, instructions_to_shard, strategy_map, cost_graph, output.s_val, (mesh_idx == partial_mesh_shapes.size() - 1)); - if (mesh_idx == partial_mesh_shapes.size() - 1) { + if (option_.post_process && mesh_idx == partial_mesh_shapes.size() - 1) { if (!SetHloShardingPostProcessing( sequence, instructions_to_shard, strategy_map, cost_graph, output.s_val, cluster_env, diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_option.h b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_option.h index b2fbb1eabd551e..17e8db8354fde7 
100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_option.h +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_option.h @@ -209,6 +209,9 @@ struct AutoShardingOption { // Split constant expressions as well when invoking HloConstantSplitter. bool enable_expression_constant_splitter = false; + // Whether to post-process the solution by reshaping / resharding tensors. + bool post_process = true; + // Prints a debug string. std::string ToString() const; From 069e05167c3a317add1e8e480b4a078276f67c0c Mon Sep 17 00:00:00 2001 From: Farzin Houshmand Date: Sat, 22 Jun 2024 11:51:34 -0700 Subject: [PATCH 146/256] Add GetUniqueGteInstruction to hlo_query utility file. PiperOrigin-RevId: 645701191 --- third_party/xla/xla/hlo/utils/BUILD | 1 + third_party/xla/xla/hlo/utils/hlo_query.cc | 21 +++++++++++++++++ third_party/xla/xla/hlo/utils/hlo_query.h | 7 ++++++ .../xla/xla/hlo/utils/hlo_query_test.cc | 23 +++++++++++++++++++ 4 files changed, 52 insertions(+) diff --git a/third_party/xla/xla/hlo/utils/BUILD b/third_party/xla/xla/hlo/utils/BUILD index 94353b525fc9de..63db0ae7d5ef6e 100644 --- a/third_party/xla/xla/hlo/utils/BUILD +++ b/third_party/xla/xla/hlo/utils/BUILD @@ -155,6 +155,7 @@ cc_library( "//xla:literal", "//xla:shape_util", "//xla/hlo/ir:hlo", + "//xla/service:pattern_matcher", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", ], diff --git a/third_party/xla/xla/hlo/utils/hlo_query.cc b/third_party/xla/xla/hlo/utils/hlo_query.cc index 487fc72e99b5d9..85e41fff68a149 100644 --- a/third_party/xla/xla/hlo/utils/hlo_query.cc +++ b/third_party/xla/xla/hlo/utils/hlo_query.cc @@ -16,6 +16,7 @@ limitations under the License. #include "xla/hlo/utils/hlo_query.h" #include +#include #include "absl/algorithm/container.h" #include "xla/hlo/ir/hlo_casting_utils.h" @@ -23,6 +24,7 @@ limitations under the License. 
#include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/literal.h" +#include "xla/service/pattern_matcher.h" #include "xla/shape_util.h" namespace xla { @@ -247,5 +249,24 @@ bool HasX64TransformedHostTransfer(const HloModule& module) { return false; } +HloInstruction* GetUniqueGteInstruction(const HloInstruction* operand, + int64_t index) { + HloInstruction* gte = nullptr; + for (HloInstruction* instr : operand->parent()->MakeInstructionPostOrder()) { + if (!Match(instr, match::GetTupleElement().WithTupleIndex(index))) { + continue; + } + if (instr->operand(0) != operand) { + continue; + } + // If gte is not unique, return nullptr. + if (gte != nullptr) { + return nullptr; + } + gte = instr; + } + return gte; +} + } // namespace hlo_query } // namespace xla diff --git a/third_party/xla/xla/hlo/utils/hlo_query.h b/third_party/xla/xla/hlo/utils/hlo_query.h index 8343df4dc24472..cda265362d452b 100644 --- a/third_party/xla/xla/hlo/utils/hlo_query.h +++ b/third_party/xla/xla/hlo/utils/hlo_query.h @@ -16,6 +16,8 @@ limitations under the License. #ifndef XLA_HLO_UTILS_HLO_QUERY_H_ #define XLA_HLO_UTILS_HLO_QUERY_H_ +#include + #include "absl/container/flat_hash_set.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -145,6 +147,11 @@ int64_t NextChannelId(const HloModule& module); // rewritten into tuple shaped transfers. bool HasX64TransformedHostTransfer(const HloModule& module); +// Returns the unique GTE instruction with the given operand and index. Returns +// nullptr if no such instruction exists or is not unique. 
+HloInstruction* GetUniqueGteInstruction(const HloInstruction* operand, + int64_t index); + } // namespace hlo_query } // namespace xla diff --git a/third_party/xla/xla/hlo/utils/hlo_query_test.cc b/third_party/xla/xla/hlo/utils/hlo_query_test.cc index 7697bffc855806..acefa21aa9e2f4 100644 --- a/third_party/xla/xla/hlo/utils/hlo_query_test.cc +++ b/third_party/xla/xla/hlo/utils/hlo_query_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/hlo_parser.h" @@ -109,5 +110,27 @@ ENTRY main { EXPECT_EQ(CountInstructions(*computation, HloOpcode::kMultiply), 3); } +TEST_F(HloQueryTest, GetUniqueGteTest) { + constexpr absl::string_view kHloString = R"( + HloModule m + + ENTRY main { + param.0 = (f32[32]{0}, f32[32]{0}, f32[32]{0}, f32[32]{0}) parameter(0) + gte1 = f32[32]{0} get-tuple-element(param.0), index=0 + gte2 = f32[32]{0} get-tuple-element(param.0), index=1 + dup_gte2 = f32[32]{0} get-tuple-element(param.0), index=1 + gte3 = f32[32]{0} get-tuple-element(param.0), index=2 + ROOT gte4 = f32[32]{0} get-tuple-element(param.0), index=3 + })"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnUnverifiedModule(kHloString)); + HloInstruction* param = module->entry_computation()->parameter_instruction(0); + HloInstruction* gte1 = hlo_query::GetUniqueGteInstruction(param, /*index=*/0); + EXPECT_NE(gte1, nullptr); + HloInstruction* gte2 = hlo_query::GetUniqueGteInstruction(param, /*index=*/1); + EXPECT_EQ(gte2, nullptr); +} + } // namespace } // namespace xla From 0ea7f4ae6816df76bf580e7b0440134ff12f44de Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Sat, 22 Jun 2024 11:52:57 -0700 Subject: [PATCH 147/256] Automated Code Change PiperOrigin-RevId: 645701336 --- third_party/xla/xla/hlo/evaluator/BUILD | 2 +- third_party/xla/xla/hlo/evaluator/hlo_evaluator.cc | 2 -- third_party/xla/xla/hlo/evaluator/hlo_evaluator.h | 10 ++++++++++ .../evaluator/hlo_evaluator_typed_visitor_bfloat16.cc | 1 + .../hlo_evaluator_typed_visitor_complex128.cc | 1 + .../evaluator/hlo_evaluator_typed_visitor_complex64.cc | 1 + .../evaluator/hlo_evaluator_typed_visitor_float8.cc | 1 + .../hlo/evaluator/hlo_evaluator_typed_visitor_half.cc | 1 + 8 files changed, 16 insertions(+), 3 deletions(-) diff --git a/third_party/xla/xla/hlo/evaluator/BUILD b/third_party/xla/xla/hlo/evaluator/BUILD index c6ecac78ce8160..90674968c68567 100644 --- a/third_party/xla/xla/hlo/evaluator/BUILD +++ b/third_party/xla/xla/hlo/evaluator/BUILD @@ -70,6 +70,7 @@ cc_library( "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/container:node_hash_map", "@com_google_absl//absl/functional:function_ref", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/memory", "@com_google_absl//absl/numeric:bits", "@com_google_absl//absl/status", @@ -87,7 +88,6 @@ cc_library( "@local_tsl//tsl/platform:platform_port", "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/platform:types", ], ) diff --git a/third_party/xla/xla/hlo/evaluator/hlo_evaluator.cc b/third_party/xla/xla/hlo/evaluator/hlo_evaluator.cc index 85fe5b90ab897d..a26b9f6e97005a 100644 --- a/third_party/xla/xla/hlo/evaluator/hlo_evaluator.cc +++ b/third_party/xla/xla/hlo/evaluator/hlo_evaluator.cc @@ -84,10 +84,8 @@ limitations under the License. 
#include "tsl/platform/env.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" -#include "tsl/platform/ml_dtypes.h" #include "tsl/platform/status.h" #include "tsl/platform/statusor.h" -#include "tsl/platform/types.h" namespace xla { diff --git a/third_party/xla/xla/hlo/evaluator/hlo_evaluator.h b/third_party/xla/xla/hlo/evaluator/hlo_evaluator.h index e0871ff82b5d51..59ed4416f3a2fd 100644 --- a/third_party/xla/xla/hlo/evaluator/hlo_evaluator.h +++ b/third_party/xla/xla/hlo/evaluator/hlo_evaluator.h @@ -16,6 +16,16 @@ limitations under the License. #ifndef XLA_HLO_EVALUATOR_HLO_EVALUATOR_H_ #define XLA_HLO_EVALUATOR_HLO_EVALUATOR_H_ +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "Eigen/Core" // from @eigen_archive +#include "xla/comparison_util.h" +#include "xla/hlo/ir/dfs_hlo_visitor.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/shape.h" +#include "xla/status_macros.h" +#include "tsl/platform/errors.h" #define _USE_MATH_DEFINES #include diff --git a/third_party/xla/xla/hlo/evaluator/hlo_evaluator_typed_visitor_bfloat16.cc b/third_party/xla/xla/hlo/evaluator/hlo_evaluator_typed_visitor_bfloat16.cc index 09859c69bb874d..8074e4d0dce3aa 100644 --- a/third_party/xla/xla/hlo/evaluator/hlo_evaluator_typed_visitor_bfloat16.cc +++ b/third_party/xla/xla/hlo/evaluator/hlo_evaluator_typed_visitor_bfloat16.cc @@ -15,6 +15,7 @@ limitations under the License. 
#include "xla/hlo/evaluator/hlo_evaluator.h" #include "xla/hlo/evaluator/hlo_evaluator_typed_visitor.h" +#include "xla/types.h" namespace xla { template class HloEvaluatorTypedVisitor; diff --git a/third_party/xla/xla/hlo/evaluator/hlo_evaluator_typed_visitor_complex128.cc b/third_party/xla/xla/hlo/evaluator/hlo_evaluator_typed_visitor_complex128.cc index d113a6e30c6265..67007bfbf56c99 100644 --- a/third_party/xla/xla/hlo/evaluator/hlo_evaluator_typed_visitor_complex128.cc +++ b/third_party/xla/xla/hlo/evaluator/hlo_evaluator_typed_visitor_complex128.cc @@ -15,6 +15,7 @@ limitations under the License. #include "xla/hlo/evaluator/hlo_evaluator.h" #include "xla/hlo/evaluator/hlo_evaluator_typed_visitor.h" +#include "xla/types.h" namespace xla { template class HloEvaluatorTypedVisitor; diff --git a/third_party/xla/xla/hlo/evaluator/hlo_evaluator_typed_visitor_complex64.cc b/third_party/xla/xla/hlo/evaluator/hlo_evaluator_typed_visitor_complex64.cc index 9a55017daff5c3..80b390480270e8 100644 --- a/third_party/xla/xla/hlo/evaluator/hlo_evaluator_typed_visitor_complex64.cc +++ b/third_party/xla/xla/hlo/evaluator/hlo_evaluator_typed_visitor_complex64.cc @@ -15,6 +15,7 @@ limitations under the License. #include "xla/hlo/evaluator/hlo_evaluator.h" #include "xla/hlo/evaluator/hlo_evaluator_typed_visitor.h" +#include "xla/types.h" namespace xla { template class HloEvaluatorTypedVisitor; diff --git a/third_party/xla/xla/hlo/evaluator/hlo_evaluator_typed_visitor_float8.cc b/third_party/xla/xla/hlo/evaluator/hlo_evaluator_typed_visitor_float8.cc index 9df467e7fd5f67..7c97c210aa36a5 100644 --- a/third_party/xla/xla/hlo/evaluator/hlo_evaluator_typed_visitor_float8.cc +++ b/third_party/xla/xla/hlo/evaluator/hlo_evaluator_typed_visitor_float8.cc @@ -15,6 +15,7 @@ limitations under the License. 
#include "xla/hlo/evaluator/hlo_evaluator.h" #include "xla/hlo/evaluator/hlo_evaluator_typed_visitor.h" +#include "tsl/platform/ml_dtypes.h" namespace xla { template class HloEvaluatorTypedVisitor; diff --git a/third_party/xla/xla/hlo/evaluator/hlo_evaluator_typed_visitor_half.cc b/third_party/xla/xla/hlo/evaluator/hlo_evaluator_typed_visitor_half.cc index bb34e07c4d5e61..1420816cd79315 100644 --- a/third_party/xla/xla/hlo/evaluator/hlo_evaluator_typed_visitor_half.cc +++ b/third_party/xla/xla/hlo/evaluator/hlo_evaluator_typed_visitor_half.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "Eigen/Core" // from @eigen_archive #include "xla/hlo/evaluator/hlo_evaluator.h" #include "xla/hlo/evaluator/hlo_evaluator_typed_visitor.h" From d608037a986403bc6726cacd5522cf51a220be59 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Sat, 22 Jun 2024 13:24:10 -0700 Subject: [PATCH 148/256] [xla:cpu] Add thunk_testlib for writing tests for thunks PiperOrigin-RevId: 645712264 --- third_party/xla/xla/service/cpu/runtime/BUILD | 13 ++++++ .../cpu/runtime/conditional_thunk_test.cc | 21 +-------- .../xla/service/cpu/runtime/thunk_testlib.h | 44 +++++++++++++++++++ 3 files changed, 59 insertions(+), 19 deletions(-) create mode 100644 third_party/xla/xla/service/cpu/runtime/thunk_testlib.h diff --git a/third_party/xla/xla/service/cpu/runtime/BUILD b/third_party/xla/xla/service/cpu/runtime/BUILD index 8cd0e661a11d42..7a1c4071a0e619 100644 --- a/third_party/xla/xla/service/cpu/runtime/BUILD +++ b/third_party/xla/xla/service/cpu/runtime/BUILD @@ -64,6 +64,18 @@ cc_library( ], ) +cc_library( + name = "thunk_testlib", + testonly = 1, + hdrs = ["thunk_testlib.h"], + deps = [ + ":thunk", + "//xla/runtime:buffer_use", + "//xla/tsl/concurrency:async_value", + "@com_google_absl//absl/status", + ], +) + xla_cc_test( name = "thunk_test", 
srcs = ["thunk_test.cc"], @@ -176,6 +188,7 @@ xla_cc_test( ":buffer_allocations", ":conditional_thunk", ":thunk", + ":thunk_testlib", "//xla:shape_util", "//xla/runtime:buffer_use", "//xla/service:buffer_assignment", diff --git a/third_party/xla/xla/service/cpu/runtime/conditional_thunk_test.cc b/third_party/xla/xla/service/cpu/runtime/conditional_thunk_test.cc index f3eeb7e3ba4739..cd142aad6f3ba2 100644 --- a/third_party/xla/xla/service/cpu/runtime/conditional_thunk_test.cc +++ b/third_party/xla/xla/service/cpu/runtime/conditional_thunk_test.cc @@ -20,33 +20,16 @@ limitations under the License. #include #include -#include "absl/status/status.h" #include "xla/runtime/buffer_use.h" #include "xla/service/buffer_assignment.h" #include "xla/service/cpu/runtime/thunk.h" -#include "xla/tsl/concurrency/async_value_ref.h" +#include "xla/service/cpu/runtime/thunk_testlib.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" namespace xla::cpu { namespace { -// A test-only thunk to create a Thunk with a specific buffer use. 
-class TestThunk : public Thunk { - public: - explicit TestThunk(BufferUse buffer_use) - : Thunk(Kind::kKernel, {"test"}), buffer_use_(buffer_use) {} - - tsl::AsyncValueRef Execute(const ExecuteParams&) final { - return absl::UnimplementedError("Unimplemented"); - } - - BufferUses buffer_uses() const final { return {buffer_use_}; } - - private: - BufferUse buffer_use_; -}; - TEST(ConditionalThunkTest, BufferUses) { BufferAllocation alloc(0, 1024, 0); BufferAllocation::Slice branch_index_slice(&alloc, 0, sizeof(int32_t)); @@ -54,7 +37,7 @@ TEST(ConditionalThunkTest, BufferUses) { std::vector branch_sequences(1); branch_sequences[0].push_back( - std::make_unique(BufferUse::Read(read_slice))); + std::make_unique(BufferUse::Read(read_slice))); TF_ASSERT_OK_AND_ASSIGN( auto thunk, ConditionalThunk::Create({"conditional"}, branch_index_slice, diff --git a/third_party/xla/xla/service/cpu/runtime/thunk_testlib.h b/third_party/xla/xla/service/cpu/runtime/thunk_testlib.h new file mode 100644 index 00000000000000..2c6bdc7360a67b --- /dev/null +++ b/third_party/xla/xla/service/cpu/runtime/thunk_testlib.h @@ -0,0 +1,44 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef XLA_SERVICE_CPU_RUNTIME_THUNK_TESTLIB_H_ +#define XLA_SERVICE_CPU_RUNTIME_THUNK_TESTLIB_H_ + +#include "absl/status/status.h" +#include "xla/runtime/buffer_use.h" +#include "xla/service/cpu/runtime/thunk.h" +#include "xla/tsl/concurrency/async_value_ref.h" + +namespace xla::cpu { + +// A test-only thunk to create a Thunk with a specific buffer use. +class BufferUseThunk : public Thunk { + public: + explicit BufferUseThunk(BufferUse buffer_use) + : Thunk(Kind::kKernel, {"buffer-use"}), buffer_use_(buffer_use) {} + + tsl::AsyncValueRef Execute(const ExecuteParams&) final { + return absl::UnimplementedError("Unimplemented"); + } + + BufferUses buffer_uses() const final { return {buffer_use_}; } + + private: + BufferUse buffer_use_; +}; + +} // namespace xla::cpu + +#endif // XLA_SERVICE_CPU_RUNTIME_THUNK_TESTLIB_H_ From bd2b781b0ab164a42e45582a0d25c29b86b281b8 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Sat, 22 Jun 2024 13:44:20 -0700 Subject: [PATCH 149/256] [xla:cpu] Add WhileThunk test PiperOrigin-RevId: 645715011 --- third_party/xla/xla/service/cpu/runtime/BUILD | 22 +++++++ .../service/cpu/runtime/while_thunk_test.cc | 59 +++++++++++++++++++ 2 files changed, 81 insertions(+) create mode 100644 third_party/xla/xla/service/cpu/runtime/while_thunk_test.cc diff --git a/third_party/xla/xla/service/cpu/runtime/BUILD b/third_party/xla/xla/service/cpu/runtime/BUILD index 7a1c4071a0e619..8e4977c01e0a04 100644 --- a/third_party/xla/xla/service/cpu/runtime/BUILD +++ b/third_party/xla/xla/service/cpu/runtime/BUILD @@ -737,6 +737,28 @@ cc_library( ], ) +xla_cc_test( + name = "while_thunk_test", + srcs = ["while_thunk_test.cc"], + deps = [ + ":buffer_allocations", + ":thunk", + ":thunk_testlib", + ":while_thunk", + "//xla:shape_util", + "//xla/runtime:buffer_use", + "//xla/service:buffer_assignment", + "//xla/service:maybe_owning_device_memory", + "//xla/stream_executor", + 
"//xla/tsl/concurrency:async_value", + "@com_google_absl//absl/status", + "@local_tsl//tsl/lib/core:status_test_util", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", + ], +) + cc_library( name = "fft_thunk", srcs = ["fft_thunk.cc"], diff --git a/third_party/xla/xla/service/cpu/runtime/while_thunk_test.cc b/third_party/xla/xla/service/cpu/runtime/while_thunk_test.cc new file mode 100644 index 00000000000000..99ee184861b158 --- /dev/null +++ b/third_party/xla/xla/service/cpu/runtime/while_thunk_test.cc @@ -0,0 +1,59 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "xla/service/cpu/runtime/while_thunk.h" + +#include +#include +#include +#include + +#include "xla/runtime/buffer_use.h" +#include "xla/service/buffer_assignment.h" +#include "xla/service/cpu/runtime/thunk.h" +#include "xla/service/cpu/runtime/thunk_testlib.h" +#include "tsl/platform/statusor.h" +#include "tsl/platform/test.h" + +namespace xla::cpu { +namespace { + +TEST(WhileThunkTest, BufferUses) { + BufferAllocation alloc(0, 1024, 0); + BufferAllocation::Slice predicate_slice(&alloc, 0, sizeof(int32_t)); + BufferAllocation::Slice cond_read_slice(&alloc, 10, 10); + BufferAllocation::Slice body_read_slice(&alloc, 20, 10); + + ThunkSequence cond_sequence; + cond_sequence.push_back( + std::make_unique(BufferUse::Read(cond_read_slice))); + + ThunkSequence body_sequence; + body_sequence.push_back( + std::make_unique(BufferUse::Read(body_read_slice))); + + TF_ASSERT_OK_AND_ASSIGN( + auto thunk, + WhileThunk::Create({"while"}, predicate_slice, std::move(cond_sequence), + std::move(body_sequence))); + + EXPECT_EQ(thunk->buffer_uses().size(), 3); + EXPECT_EQ(thunk->buffer_uses()[0], BufferUse::Write(predicate_slice)); + EXPECT_EQ(thunk->buffer_uses()[1], BufferUse::Read(cond_read_slice)); + EXPECT_EQ(thunk->buffer_uses()[2], BufferUse::Read(body_read_slice)); +} + +} // namespace +} // namespace xla::cpu From e371313cf0e073a47dd6f67fe795643a5e27286c Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Sat, 22 Jun 2024 14:29:21 -0700 Subject: [PATCH 150/256] [xla:cpu] Add ReplicaId thunk test PiperOrigin-RevId: 645719700 --- third_party/xla/xla/service/cpu/runtime/BUILD | 21 +++++ .../cpu/runtime/replica_id_thunk_test.cc | 79 +++++++++++++++++++ 2 files changed, 100 insertions(+) create mode 100644 third_party/xla/xla/service/cpu/runtime/replica_id_thunk_test.cc diff --git a/third_party/xla/xla/service/cpu/runtime/BUILD b/third_party/xla/xla/service/cpu/runtime/BUILD index 
8e4977c01e0a04..10561be09ea72a 100644 --- a/third_party/xla/xla/service/cpu/runtime/BUILD +++ b/third_party/xla/xla/service/cpu/runtime/BUILD @@ -616,6 +616,27 @@ cc_library( ], ) +xla_cc_test( + name = "replica_id_thunk_test", + srcs = ["replica_id_thunk_test.cc"], + deps = [ + ":buffer_allocations", + ":replica_id_thunk", + ":thunk", + "//xla:executable_run_options", + "//xla:shape_util", + "//xla/service:buffer_assignment", + "//xla/service:executable", + "//xla/service:maybe_owning_device_memory", + "//xla/stream_executor", + "//xla/tsl/concurrency:async_value", + "@local_tsl//tsl/lib/core:status_test_util", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", + ], +) + cc_library( name = "infeed_thunk", srcs = ["infeed_thunk.cc"], diff --git a/third_party/xla/xla/service/cpu/runtime/replica_id_thunk_test.cc b/third_party/xla/xla/service/cpu/runtime/replica_id_thunk_test.cc new file mode 100644 index 00000000000000..c1345cef5e00a2 --- /dev/null +++ b/third_party/xla/xla/service/cpu/runtime/replica_id_thunk_test.cc @@ -0,0 +1,79 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "xla/service/cpu/runtime/replica_id_thunk.h" + +#include +#include +#include + +#include "xla/executable_run_options.h" +#include "xla/service/buffer_assignment.h" +#include "xla/service/cpu/runtime/buffer_allocations.h" +#include "xla/service/cpu/runtime/thunk.h" +#include "xla/service/maybe_owning_device_memory.h" +#include "xla/stream_executor/device_memory.h" +#include "xla/tsl/concurrency/async_value_ref.h" +#include "tsl/platform/statusor.h" +#include "tsl/platform/test.h" + +namespace xla::cpu { +namespace { + +DeviceAssignment CreateDeviceAssignment(std::vector devices) { + DeviceAssignment device_assignment(/*replica_count=*/devices.size(), + /*computation_count=*/1); + for (int64_t i = 0; i < devices.size(); ++i) { + device_assignment(i, 0) = devices[i]; + } + return device_assignment; +} + +TEST(ReplicaIdThunkTest, GetReplicaId) { + std::vector dst(1, -1); + + std::vector buffers; + buffers.emplace_back(se::DeviceMemoryBase(dst.data(), sizeof(int32_t))); + + BufferAllocation alloc(/*index=*/0, /*size=*/sizeof(int32_t), /*color=*/0); + BufferAllocation::Slice id_slice(&alloc, /*offset=*/0, + /*size=*/sizeof(int32_t)); + + std::string name(Thunk::KindToString(Thunk::Kind::kReplicaId)); + TF_ASSERT_OK_AND_ASSIGN(auto thunk, ReplicaIdThunk::Create({name}, id_slice)); + + BufferAllocations allocations(buffers); + DeviceAssignment device_assn = CreateDeviceAssignment({0, 1}); + + ExecutableRunOptions run_options; + run_options.set_device_ordinal(0); + run_options.set_device_assignment(&device_assn); + + TF_ASSERT_OK_AND_ASSIGN(Thunk::CollectiveExecuteParams collective_params, + Thunk::CollectiveExecuteParams::Create(&run_options)); + + Thunk::ExecuteParams params; + params.buffer_allocations = &allocations; + params.collective_params = &collective_params; + + auto execute_event = thunk->Execute(params); + tsl::BlockUntilReady(execute_event); + 
ASSERT_FALSE(execute_event.IsError()); + + EXPECT_EQ(dst[0], 0); +} + +} // namespace +} // namespace xla::cpu From d7602ced905769892be97eb8dc1e472e705c1eab Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adam=20Bana=C5=9B?= Date: Sat, 22 Jun 2024 14:33:37 -0700 Subject: [PATCH 151/256] [xla:cpu] Add extern templates for Conv2D and Conv3D. These templates were instantiated twice (once for current runtime, once for thunks runtime). Now they are instantiated once. It reduces binary size and compilation time. PiperOrigin-RevId: 645720154 --- third_party/xla/xla/service/cpu/BUILD | 4 ++ .../xla/xla/service/cpu/runtime_conv_impl.cc | 66 +++++++++++++++++++ .../xla/xla/service/cpu/runtime_conv_impl.h | 50 ++++++++++++++ 3 files changed, 120 insertions(+) create mode 100644 third_party/xla/xla/service/cpu/runtime_conv_impl.cc diff --git a/third_party/xla/xla/service/cpu/BUILD b/third_party/xla/xla/service/cpu/BUILD index 10a56c4e2c4c2f..6bad3c5c9b9c8f 100644 --- a/third_party/xla/xla/service/cpu/BUILD +++ b/third_party/xla/xla/service/cpu/BUILD @@ -72,6 +72,7 @@ filegroup( srcs = [ # Single-threaded support. "runtime_custom_call_status.cc", + "runtime_conv_impl.cc", "runtime_fp16.cc", "runtime_key_value_sort.cc", "runtime_pow.cc", @@ -1012,12 +1013,15 @@ cc_library( cc_library( name = "runtime_conv_impl", + srcs = ["runtime_conv_impl.cc"], hdrs = ["runtime_conv_impl.h"], + copts = runtime_copts(), visibility = internal_visibility([":friends"]), deps = [ "//xla/tsl/framework/contraction:eigen_contraction_kernel", "//xla/tsl/framework/convolution:eigen_helpers", "@eigen_archive//:eigen3", + "@local_tsl//tsl/platform:mutex", # build_cleaner: keep ], ) diff --git a/third_party/xla/xla/service/cpu/runtime_conv_impl.cc b/third_party/xla/xla/service/cpu/runtime_conv_impl.cc new file mode 100644 index 00000000000000..b3990cbed1ea65 --- /dev/null +++ b/third_party/xla/xla/service/cpu/runtime_conv_impl.cc @@ -0,0 +1,66 @@ +/* Copyright 2024 The OpenXLA Authors. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#define EIGEN_USE_THREADS + +#include "xla/service/cpu/runtime_conv_impl.h" + +namespace tensorflow::xla { + +// Instantiate Conv2D template for all supported devices and data types. +#define CONV2D_INSTANTIATE_TEMPLATE(EigenDevice, ScalarType) \ + template void EigenConv2DImpl( \ + const EigenDevice& device, ScalarType* out, ScalarType* lhs, \ + ScalarType* rhs, Eigen::Index input_batch, Eigen::Index input_x, \ + Eigen::Index input_y, Eigen::Index input_channels, \ + Eigen::Index kernel_x, Eigen::Index kernel_y, \ + Eigen::Index kernel_channels, Eigen::Index kernel_filters, \ + Eigen::Index output_x, Eigen::Index output_y, Eigen::Index x_stride, \ + Eigen::Index y_stride, Eigen::Index padding_x_before, \ + Eigen::Index padding_x_after, Eigen::Index padding_y_before, \ + Eigen::Index padding_y_after, Eigen::Index lhs_x_dilation, \ + Eigen::Index lhs_y_dilation, Eigen::Index rhs_x_dilation, \ + Eigen::Index rhs_y_dilation, Eigen::Index feature_group_count) + +CONV2D_INSTANTIATE_TEMPLATE(Eigen::DefaultDevice, Eigen::half); +CONV2D_INSTANTIATE_TEMPLATE(Eigen::DefaultDevice, float); +CONV2D_INSTANTIATE_TEMPLATE(Eigen::ThreadPoolDevice, Eigen::half); +CONV2D_INSTANTIATE_TEMPLATE(Eigen::ThreadPoolDevice, float); + +#undef CONV2D_INSTANTIATE_TEMPLATE + +// Instantiate Conv3D template for all supported devices and data types. 
+#define CONV3D_INSTANTIATE_TEMPLATE(EigenDevice, ScalarType) \ + template void EigenConv3DImpl( \ + const EigenDevice& device, ScalarType* out, ScalarType* lhs, \ + ScalarType* rhs, Eigen::Index input_batch, Eigen::Index input_x, \ + Eigen::Index input_y, Eigen::Index input_z, Eigen::Index input_channels, \ + Eigen::Index kernel_x, Eigen::Index kernel_y, Eigen::Index kernel_z, \ + Eigen::Index kernel_channels, Eigen::Index kernel_filters, \ + Eigen::Index output_x, Eigen::Index output_y, Eigen::Index output_z, \ + Eigen::Index x_stride, Eigen::Index y_stride, Eigen::Index z_stride, \ + Eigen::Index padding_x_before, Eigen::Index padding_x_after, \ + Eigen::Index padding_y_before, Eigen::Index padding_y_after, \ + Eigen::Index padding_z_before, Eigen::Index padding_z_after, \ + Eigen::Index lhs_x_dilation, Eigen::Index lhs_y_dilation, \ + Eigen::Index lhs_z_dilation, Eigen::Index rhs_x_dilation, \ + Eigen::Index rhs_y_dilation, Eigen::Index rhs_z_dilation, \ + Eigen::Index feature_group_count) + +CONV3D_INSTANTIATE_TEMPLATE(Eigen::DefaultDevice, Eigen::half); +CONV3D_INSTANTIATE_TEMPLATE(Eigen::DefaultDevice, float); +CONV3D_INSTANTIATE_TEMPLATE(Eigen::ThreadPoolDevice, Eigen::half); +CONV3D_INSTANTIATE_TEMPLATE(Eigen::ThreadPoolDevice, float); + +} // namespace tensorflow::xla diff --git a/third_party/xla/xla/service/cpu/runtime_conv_impl.h b/third_party/xla/xla/service/cpu/runtime_conv_impl.h index 957d0b2c82f5bd..5545aa3cd51087 100644 --- a/third_party/xla/xla/service/cpu/runtime_conv_impl.h +++ b/third_party/xla/xla/service/cpu/runtime_conv_impl.h @@ -191,6 +191,56 @@ void EigenConv3DImpl( } } +// Extern Conv2D template for all supported devices and data types. +// TODO(abanas): These templates are instantiated in convolution_thunk.cc. Move +// the definitions from this file to convolution thunk, and make all runtime +// conv targets depend on it. 
+#define CONV2D_EXTERN_TEMPLATE(EigenDevice, ScalarType) \ + extern template void EigenConv2DImpl( \ + const EigenDevice& device, ScalarType* out, ScalarType* lhs, \ + ScalarType* rhs, Eigen::Index input_batch, Eigen::Index input_x, \ + Eigen::Index input_y, Eigen::Index input_channels, \ + Eigen::Index kernel_x, Eigen::Index kernel_y, \ + Eigen::Index kernel_channels, Eigen::Index kernel_filters, \ + Eigen::Index output_x, Eigen::Index output_y, Eigen::Index x_stride, \ + Eigen::Index y_stride, Eigen::Index padding_x_before, \ + Eigen::Index padding_x_after, Eigen::Index padding_y_before, \ + Eigen::Index padding_y_after, Eigen::Index lhs_x_dilation, \ + Eigen::Index lhs_y_dilation, Eigen::Index rhs_x_dilation, \ + Eigen::Index rhs_y_dilation, Eigen::Index feature_group_count) + +CONV2D_EXTERN_TEMPLATE(Eigen::DefaultDevice, Eigen::half); +CONV2D_EXTERN_TEMPLATE(Eigen::DefaultDevice, float); +CONV2D_EXTERN_TEMPLATE(Eigen::ThreadPoolDevice, Eigen::half); +CONV2D_EXTERN_TEMPLATE(Eigen::ThreadPoolDevice, float); + +#undef CONV2D_EXTERN_TEMPLATE + +// Extern Conv3D template for all supported devices and data types. 
+#define CONV3D_EXTERN_TEMPLATE(EigenDevice, ScalarType) \ + extern template void EigenConv3DImpl( \ + const EigenDevice& device, ScalarType* out, ScalarType* lhs, \ + ScalarType* rhs, Eigen::Index input_batch, Eigen::Index input_x, \ + Eigen::Index input_y, Eigen::Index input_z, Eigen::Index input_channels, \ + Eigen::Index kernel_x, Eigen::Index kernel_y, Eigen::Index kernel_z, \ + Eigen::Index kernel_channels, Eigen::Index kernel_filters, \ + Eigen::Index output_x, Eigen::Index output_y, Eigen::Index output_z, \ + Eigen::Index x_stride, Eigen::Index y_stride, Eigen::Index z_stride, \ + Eigen::Index padding_x_before, Eigen::Index padding_x_after, \ + Eigen::Index padding_y_before, Eigen::Index padding_y_after, \ + Eigen::Index padding_z_before, Eigen::Index padding_z_after, \ + Eigen::Index lhs_x_dilation, Eigen::Index lhs_y_dilation, \ + Eigen::Index lhs_z_dilation, Eigen::Index rhs_x_dilation, \ + Eigen::Index rhs_y_dilation, Eigen::Index rhs_z_dilation, \ + Eigen::Index feature_group_count) + +CONV3D_EXTERN_TEMPLATE(Eigen::DefaultDevice, Eigen::half); +CONV3D_EXTERN_TEMPLATE(Eigen::DefaultDevice, float); +CONV3D_EXTERN_TEMPLATE(Eigen::ThreadPoolDevice, Eigen::half); +CONV3D_EXTERN_TEMPLATE(Eigen::ThreadPoolDevice, float); + +#undef CONV3D_EXTERN_TEMPLATE + } // namespace xla } // namespace tensorflow From eb1f2b41bfcc5813ad063da963738894bc7e9b3d Mon Sep 17 00:00:00 2001 From: Derek Murray Date: Sat, 22 Jun 2024 20:28:32 -0700 Subject: [PATCH 152/256] [Gradients] Tag constant zero tensors for outputs with no gradient with `_is_zeros_tensor`. Previously, when computing gradients in a V2 while loop, we would not tag numeric zero tensors that have a fully-defined shape with the `_is_zeros_tensor` attribute. This attribute is used to optimize the gradient graph for `SoftmaxCrossEntropyWithLogits` and `SparseSoftmaxCrossEntropyWithLogits`, and omitting it leads to an unnecessary softmax and matrix multiplication in the eventual program. 
PiperOrigin-RevId: 645761816 --- tensorflow/python/ops/control_flow_state.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/tensorflow/python/ops/control_flow_state.py b/tensorflow/python/ops/control_flow_state.py index b604189e91aa2c..ab687285bc41b2 100644 --- a/tensorflow/python/ops/control_flow_state.py +++ b/tensorflow/python/ops/control_flow_state.py @@ -803,6 +803,12 @@ def _ZerosLikeV1(op, index): return array_ops.zeros_like(val, optimize=False) +@array_ops._tag_zeros_tensor # pylint: disable=protected-access +def _ConstantZeros(shape, dtype): + """Create a constant zero tensor.""" + return constant_op.constant(0, shape=shape, dtype=dtype) + + def _ZerosLikeV2(op, index): """Branch of ZerosLike for TF2.""" val = op.outputs[index] @@ -817,7 +823,7 @@ def _ZerosLikeV2(op, index): # it helps avoid creating extra nodes(possibly Consts) for the shape. # For variants, we must use ZerosLike. if val.shape.is_fully_defined(): - return constant_op.constant(0, shape=val.shape.dims, dtype=val.dtype) + return _ConstantZeros(val.shape.dims, val.dtype) else: # Note: Even though we add `Shape` in the default graph, while_v2 is smart # enough to place it in the forward graph i.e. `val.graph`. From a4fbe1b37fa21939ac457669998ed1521003fdcd Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sun, 23 Jun 2024 02:02:13 -0700 Subject: [PATCH 153/256] Update GraphDef version to 1902. PiperOrigin-RevId: 645809917 --- tensorflow/core/public/version.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h index e462e1429448b5..755f92897906c3 100644 --- a/tensorflow/core/public/version.h +++ b/tensorflow/core/public/version.h @@ -108,7 +108,7 @@ limitations under the License. 
#define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0 -#define TF_GRAPH_DEF_VERSION 1901 // Updated: 2024/6/22 +#define TF_GRAPH_DEF_VERSION 1902 // Updated: 2024/6/23 // Checkpoint compatibility versions (the versions field in SavedSliceMeta). // From efd51a9eb7909edc391285e4bc12789514de191c Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sun, 23 Jun 2024 02:02:15 -0700 Subject: [PATCH 154/256] compat: Update forward compatibility horizon to 2024-06-23 PiperOrigin-RevId: 645809921 --- tensorflow/python/compat/compat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py index 49606d977e5051..82244cce6f8bf7 100644 --- a/tensorflow/python/compat/compat.py +++ b/tensorflow/python/compat/compat.py @@ -29,7 +29,7 @@ # This value changes every day with an automatic CL. It can be modified in code # via `forward_compatibility_horizon()` or with the environment variable # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date. 
-_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2024, 6, 22) +_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2024, 6, 23) _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS" _FORWARD_COMPATIBILITY_DATE_NUMBER = None From 841fecad98ff3d8f4c1ff2b68bfe2a9f6e76056f Mon Sep 17 00:00:00 2001 From: Chi Zeng Date: Sun, 23 Jun 2024 02:12:21 -0700 Subject: [PATCH 155/256] [XLA:GPU] Set reduce_window_rewrite_base_length to 16 by default PiperOrigin-RevId: 645811580 --- third_party/xla/xla/debug_options_flags.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/xla/xla/debug_options_flags.cc b/third_party/xla/xla/debug_options_flags.cc index b223d7c0c320e7..df05be304cf1a1 100644 --- a/third_party/xla/xla/debug_options_flags.cc +++ b/third_party/xla/xla/debug_options_flags.cc @@ -252,7 +252,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_gpu_use_memcpy_local_p2p(false); - opts.set_xla_reduce_window_rewrite_base_length(32); + opts.set_xla_reduce_window_rewrite_base_length(16); opts.set_xla_gpu_require_complete_aot_autotune_results(false); From bea53c9ad42dd5150ef9954c4aefc4bc2fd90de4 Mon Sep 17 00:00:00 2001 From: akhilgoe <114951738+akhilgoe@users.noreply.github.com> Date: Sun, 23 Jun 2024 16:48:04 -0700 Subject: [PATCH 156/256] PR #10301: [XLA:CPU][oneDNN] Convolution XLA HLO Pattern Matcher with oneDNN custom call rewrite Imported from GitHub PR https://github.com/openxla/xla/pull/10301 This PR enables a oneDNN library call for the matched XLA HLO Convolution pattern through a custom_call instruction. In particular, this PR: 1. Adds a oneDNN convolution rewriter pass that rewrites HLO Convolution to oneDNN Convolution. 2. Refactors the backend config to enhance code reusability. 3.
Adds a convolution test file to verify rewrite and execution result Copybara import of the project: -- b7d0abb4f683595c91bf144bcf1b208254ca5c74 by Akhil Goel : Add onednn convolution support -- fbf544ea346250129c00023da7e193aecad5375b by Akhil Goel : Remove unused symbol from BUILD file -- 76c079109361c75a6c512d32801039ed7a3f30e1 by Akhil Goel : Fix buildifier error -- 01f59d2d11481a1ac0b3e7a491261e1ee541aac0 by Akhil Goel : Address Review Comments -- 667502387a855ed0974822334283b890f84d1c34 by Akhil Goel : Refactor oneDNN rewritability check to a separate function -- 6670413b5150518044a371e99e60a9ea36984660 by Akhil Goel : Add cpu package to onednn_config proto file -- c1e8fc78e2a6a039ef615b17130be8b0fbd9c901 by Akhil Goel : Push missing change in merge -- 9f446d82fef58f0b0b946f6c5e4baba8cf5e4a50 by Akhil Goel : Mark outdated ids as reserved Merging this change closes #10301 PiperOrigin-RevId: 645921244 --- third_party/xla/xla/service/BUILD | 1 + .../xla/xla/service/algebraic_simplifier.cc | 85 +++++++ .../xla/xla/service/algebraic_simplifier.h | 11 + third_party/xla/xla/service/cpu/BUILD | 63 +++++ .../xla/xla/service/cpu/backend_config.proto | 58 +---- .../xla/xla/service/cpu/cpu_compiler.cc | 4 + .../xla/xla/service/cpu/cpu_float_support.cc | 4 + .../xla/xla/service/cpu/cpu_runtime.cc | 2 + third_party/xla/xla/service/cpu/cpu_runtime.h | 1 + third_party/xla/xla/service/cpu/ir_emitter.cc | 86 ++++++- third_party/xla/xla/service/cpu/ir_emitter.h | 1 + .../xla/xla/service/cpu/onednn_config.proto | 116 +++++++++ .../xla/xla/service/cpu/onednn_convolution.cc | 230 ++++++++++++++++++ .../xla/xla/service/cpu/onednn_convolution.h | 31 +++ .../cpu/onednn_convolution_rewriter.cc | 140 +++++++++++ .../service/cpu/onednn_convolution_rewriter.h | 49 ++++ .../xla/xla/service/cpu/onednn_layer_norm.cc | 5 +- .../xla/xla/service/cpu/onednn_matmul.cc | 43 ++-- .../xla/service/cpu/onednn_matmul_rewriter.cc | 88 ++++--- .../xla/service/cpu/onednn_ops_rewriter.cc | 4 +- 
.../xla/xla/service/cpu/simple_orc_jit.cc | 2 + third_party/xla/xla/tests/BUILD | 20 ++ .../xla/xla/tests/onednn_convolution_test.cc | 87 +++++++ .../xla/xla/tests/onednn_layer_norm_test.cc | 2 +- .../xla/xla/tests/onednn_matmul_test.cc | 58 +++-- 25 files changed, 1060 insertions(+), 131 deletions(-) create mode 100644 third_party/xla/xla/service/cpu/onednn_config.proto create mode 100644 third_party/xla/xla/service/cpu/onednn_convolution.cc create mode 100644 third_party/xla/xla/service/cpu/onednn_convolution.h create mode 100644 third_party/xla/xla/service/cpu/onednn_convolution_rewriter.cc create mode 100644 third_party/xla/xla/service/cpu/onednn_convolution_rewriter.h create mode 100644 third_party/xla/xla/tests/onednn_convolution_test.cc diff --git a/third_party/xla/xla/service/BUILD b/third_party/xla/xla/service/BUILD index 614864cfaef0b0..5f490b4d833ff4 100644 --- a/third_party/xla/xla/service/BUILD +++ b/third_party/xla/xla/service/BUILD @@ -2598,6 +2598,7 @@ cc_library( name = "algebraic_simplifier", srcs = ["algebraic_simplifier.cc"], hdrs = ["algebraic_simplifier.h"], + copts = tsl_copts(), deps = [ ":hlo_cost_analysis", ":hlo_creation_utils", diff --git a/third_party/xla/xla/service/algebraic_simplifier.cc b/third_party/xla/xla/service/algebraic_simplifier.cc index 9cb49ab70bedcf..c9b9850a94d2a8 100644 --- a/third_party/xla/xla/service/algebraic_simplifier.cc +++ b/third_party/xla/xla/service/algebraic_simplifier.cc @@ -8805,6 +8805,82 @@ absl::StatusOr AlgebraicSimplifierVisitor::SwapConvOperands( return true; } +absl::StatusOr AlgebraicSimplifierVisitor::IsOneDnnRewritableBF16Conv( + HloInstruction** convolution) { + bool can_rewrite = true; + auto from_dtype = (*convolution)->shape().element_type(); + if (!options_.executing_on_cpu() || from_dtype != PrimitiveType::BF16) { + return false; + } + if ((*convolution)->batch_group_count() != 1 || + (*convolution)->operand(1)->opcode() == HloOpcode::kReverse) { + can_rewrite = false; + } + const Shape& 
inp_shape = (*convolution)->operand(0)->shape(); + const Shape& ker_shape = (*convolution)->operand(1)->shape(); + const Shape& out_shape = (*convolution)->shape(); + if (ShapeUtil::IsZeroElementArray(inp_shape) || + ShapeUtil::IsZeroElementArray(ker_shape) || + ShapeUtil::IsZeroElementArray(out_shape)) { + can_rewrite = false; + } + + auto dims = (*convolution)->window().dimensions().size(); + if (dims >= 4 || dims <= 0) can_rewrite = false; + + if (inp_shape.rank() != ker_shape.rank() || + inp_shape.rank() != out_shape.rank()) { + can_rewrite = false; + } + + for (auto it = (*convolution)->window().dimensions().begin(); + it != (*convolution)->window().dimensions().end(); it++) { + if ((*it).padding_low() < 0 || (*it).padding_high() < 0 || + (*it).stride() < 0 || (*it).base_dilation() != 1 || + (*it).window_reversal()) { + can_rewrite = false; + } + } + + if (can_rewrite) { + return true; + } + + // To ensure the correctness of the generated LLVM IR, we cast + // the convolutions that are not rewritable to onednn custom calls to higher + // precision. This does not compromise performance as lower floating point + // precision convolutions are converted to higher precision in the regular + // optimization pipeline. 
+ auto to_dtype = PrimitiveType::F32; + std::vector new_operands; + auto from_dtype_operands = (*convolution)->operands(); + + std::for_each( + from_dtype_operands.begin(), from_dtype_operands.end(), + [&new_operands, &to_dtype](HloInstruction* instr) { + new_operands.push_back( + instr->AddInstruction(HloInstruction::CreateConvert( + ShapeUtil::ChangeElementType(instr->shape(), to_dtype), + instr))); + }); + + HloInstruction* to_conv = + (*convolution) + ->AddInstruction( + (*convolution) + ->CloneWithNewOperands(ShapeUtil::ChangeElementType( + (*convolution)->shape(), to_dtype), + new_operands)); + + HloInstruction* from_conv = + to_conv->AddInstruction(HloInstruction::CreateConvert( + ShapeUtil::ChangeElementType(to_conv->shape(), from_dtype), to_conv)); + + TF_RETURN_IF_ERROR(ReplaceInstruction(*convolution, from_conv)); + *convolution = to_conv; + return false; +} + absl::StatusOr AlgebraicSimplifierVisitor::SimplifyConvToDot( HloInstruction* convolution) { auto* lhs = convolution->mutable_operand(0); @@ -9064,6 +9140,15 @@ absl::Status AlgebraicSimplifierVisitor::HandleConvolution( if (swapped) { return absl::OkStatus(); } +#if defined(INTEL_MKL) && defined(ENABLE_ONEDNN_V3) + // Convert the data type back to F32 if we can't rewrite BF16 convolution to + // oneDNN custom call. + TF_ASSIGN_OR_RETURN(bool can_rewrite_bf16_conv_to_onednn, + IsOneDnnRewritableBF16Conv(&convolution)); + if (can_rewrite_bf16_conv_to_onednn) { + return OkStatus(); + } +#endif // INTEL_MKL && ENABLE_ONEDNN_V3 // Try to replace the convolution with a kDot or a kMultiply instruction. 
TF_ASSIGN_OR_RETURN(bool replaced_with_dot, SimplifyConvToDot(convolution)); if (replaced_with_dot) { diff --git a/third_party/xla/xla/service/algebraic_simplifier.h b/third_party/xla/xla/service/algebraic_simplifier.h index 18e8b56f6d4e0a..7bece1ccd322b2 100644 --- a/third_party/xla/xla/service/algebraic_simplifier.h +++ b/third_party/xla/xla/service/algebraic_simplifier.h @@ -247,6 +247,12 @@ class AlgebraicSimplifierOptions { enable_unconditional_reduce_of_concat_replacement; } + // Indicates whether running on CPU + bool executing_on_cpu() const { return executing_on_cpu_; } + void set_executing_on_cpu(bool executing_on_cpu) { + executing_on_cpu_ = executing_on_cpu; + } + private: // Metadata struct can be used to store any metadata information encapsulated // with the AlgebraicSimplierOptions that can be later used in an @@ -280,6 +286,7 @@ class AlgebraicSimplifierOptions { bool minmax_propagate_nan_{true}; bool enable_unconditional_reduce_of_concat_replacement_{true}; bool use_associative_reordering_{false}; + bool executing_on_cpu_{false}; double associative_reordering_threshold_{2.0}; Metadata metadata_; }; @@ -590,6 +597,10 @@ class AlgebraicSimplifierVisitor : public DfsHloRewriteVisitor { // convolution. absl::StatusOr SwapConvOperands(HloInstruction* convolution); + // Checks if the given convolution is in BF16 and is oneDNN rewritable, if not + // then it promotes the data type of the convolution to F32 + absl::StatusOr IsOneDnnRewritableBF16Conv(HloInstruction** convolution); + // Tries to use a kDot in place of the given convolution. absl::StatusOr SimplifyConvToDot(HloInstruction* convolution); // Tries to use a multiplication in place of the given convolution. 
diff --git a/third_party/xla/xla/service/cpu/BUILD b/third_party/xla/xla/service/cpu/BUILD index 6bad3c5c9b9c8f..e7350a66d091e0 100644 --- a/third_party/xla/xla/service/cpu/BUILD +++ b/third_party/xla/xla/service/cpu/BUILD @@ -221,6 +221,7 @@ cc_library( ":ir_emission_utils", ":ir_emitter", ":ir_emitter2", + ":onednn_convolution_rewriter", ":onednn_matmul_rewriter", ":onednn_ops_rewriter", ":parallel_task_assignment", @@ -460,6 +461,7 @@ cc_library( deps = [ ":compiler_functor", ":cpu_runtime", + ":onednn_convolution", ":onednn_layer_norm", ":onednn_matmul", ":onednn_softmax", @@ -1641,10 +1643,19 @@ xla_cc_test( ], ) +tf_proto_library( + name = "onednn_config_proto", + srcs = ["onednn_config.proto"], + cc_api_version = 2, +) + tf_proto_library( name = "backend_config_proto", srcs = ["backend_config.proto"], cc_api_version = 2, + protodeps = [ + ":onednn_config_proto", + ], ) cc_library( @@ -1720,6 +1731,29 @@ cc_library( ] + mkl_deps(), ) +cc_library( + name = "onednn_convolution", + srcs = ["onednn_convolution.cc"], + hdrs = ["onednn_convolution.h"], + copts = runtime_copts() + tsl_copts(), + visibility = ["//visibility:public"], + deps = [ + ":backend_config_proto_cc", + ":onednn_memory_util", + ":onednn_util", + ":runtime_lightweight_check", + "//xla:executable_run_options", + "//xla:shape_util", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/base:dynamic_annotations", + "@eigen_archive//:eigen3", + "@local_tsl//tsl/platform:blocking_counter", + "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:platform_port", + ] + mkl_deps(), +) + cc_library( name = "onednn_layer_norm", srcs = ["onednn_layer_norm.cc"], @@ -1832,12 +1866,41 @@ cc_library( ] + mkl_deps(), ) +cc_library( + name = "onednn_convolution_rewriter", + srcs = ["onednn_convolution_rewriter.cc"], + hdrs = ["onednn_convolution_rewriter.h"], + copts = tsl_copts(), + deps = [ + ":backend_config_proto_cc", + ":onednn_convolution", + 
":onednn_memory_util", + ":onednn_util", + "//xla:executable_run_options", + "//xla:shape_util", + "//xla:status_macros", + "//xla:xla_data_proto_cc", + "//xla/hlo/evaluator:hlo_evaluator", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_creation_utils", + "//xla/service:hlo_pass", + "//xla/service:pattern_matcher", + "@com_google_absl//absl/algorithm:container", + "@eigen_archive//:eigen3", + "@local_tsl//tsl/platform:blocking_counter", + "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:platform_port", + ] + mkl_deps(), +) + cc_library( name = "cpu_float_support", srcs = ["cpu_float_support.cc"], hdrs = ["cpu_float_support.h"], copts = tsl_copts(), deps = [ + ":onednn_convolution_rewriter", ":onednn_matmul_rewriter", "//xla/service:float_support", ], diff --git a/third_party/xla/xla/service/cpu/backend_config.proto b/third_party/xla/xla/service/cpu/backend_config.proto index 9bbab4d1c5daf3..c75a3e6df9e5a1 100644 --- a/third_party/xla/xla/service/cpu/backend_config.proto +++ b/third_party/xla/xla/service/cpu/backend_config.proto @@ -2,58 +2,22 @@ syntax = "proto3"; package xla.cpu; +import "xla/service/cpu/onednn_config.proto"; + // Backend config for XLA:CPU. message BackendConfig { // Number of partitions per outer dimension (in order, starting with // outer-most dimension first). Used by the parallel cpu backend to partition // HLOs into parallel tasks. repeated int64 outer_dimension_partitions = 1; - // Configuration to be used by oneDNN matmul - OneDnnMatMulConfig onednn_matmul_config = 2; - OneDnnLayerNormConfig onednn_layer_norm_config = 3; - OneDnnSoftmaxConfig onednn_softmax_config = 4; -} - -message OneDnnMatMulConfig { - bool transpose_a = 1; - bool transpose_b = 2; - // These enum needs to be mapped to oneDNN enum for post_op algorithm. - // TODO(intel-tf): Add kinds supported by oneDNN. 
- enum FusionKind { - UNDEFINED = 0; - BIAS = 1; - RELU = 2; - TANH = 3; - GELU_ERF = 4; - GELU_TANH = 5; - BINARY_ADD = 6; - LINEAR = 7; - ELU = 8; - RELU6 = 9; - SIGMOID = 10; - } - repeated FusionKind fused_ops = 3; - bool bias_broadcast = 4; - // To avoid protobuf failures for specific decimal values, - // the original float value alpha is type-casted to int32. - int32 alpha_typecast = 5; - bool weights_prepacked = 6; - bool user_scratchpad = 7; -} - -message OneDnnLayerNormConfig { - // These enum needs to be mapped to oneDNN enum for post_op algorithm. - // TODO(intel-tf): Add kinds supported by oneDNN. - enum FusionKind { - UNDEFINED = 0; - SCALE = 1; - SHIFT = 2; - SCALE_AND_SHIFT = 3; + oneof backend_config_oneof { + // Configuration to be used by oneDNN matmul + OneDnnMatMulConfig onednn_matmul_config = 2; + // Configuration to be used by oneDNN layer norm + OneDnnNormConfig onednn_layer_norm_config = 3; + // Configuration to be used by oneDNN softmax + OneDnnSoftmaxConfig onednn_softmax_config = 4; + // Configuration to be used by oneDNN convolution + OneDnnConvolutionConfig onednn_conv_config = 5; } - FusionKind fused_ops = 1; - int32 epsilon_typecast = 2; -} - -message OneDnnSoftmaxConfig { - int32 softmax_axis = 1; } diff --git a/third_party/xla/xla/service/cpu/cpu_compiler.cc b/third_party/xla/xla/service/cpu/cpu_compiler.cc index 8cd2d488e14802..8b4b11b43a4773 100644 --- a/third_party/xla/xla/service/cpu/cpu_compiler.cc +++ b/third_party/xla/xla/service/cpu/cpu_compiler.cc @@ -202,6 +202,7 @@ limitations under the License. #if defined(INTEL_MKL) && defined(ENABLE_ONEDNN_V3) #include "xla/service/cpu/cpu_float_support.h" +#include "xla/service/cpu/onednn_convolution_rewriter.h" #include "xla/service/cpu/onednn_matmul_rewriter.h" #include "xla/service/cpu/onednn_ops_rewriter.h" #include "xla/service/simplify_fp_conversions.h" @@ -629,6 +630,7 @@ absl::Status CpuCompiler::RunHloPassesThroughLayoutAssn( // other platforms do, so it should be changed. 
options.set_minmax_propagate_nan(false); options.set_supports_non_canonical_dots(false); + options.set_executing_on_cpu(true); pipeline.AddPass(options); pipeline.AddPass(); pipeline.AddPass(); @@ -743,6 +745,7 @@ absl::Status CpuCompiler::RunHloPassesAfterLayoutAssn( if (debug_options.xla_allow_excess_precision()) { pipeline.AddPass(); } + pipeline.AddPass(); pipeline.AddPass(max_parallelism, compile_options.thread_pool); // Run SimplifyFPConversions pass again to remove redundant Convert ops @@ -774,6 +777,7 @@ absl::Status CpuCompiler::RunHloPassesAfterLayoutAssn( // TODO(b/209827141): XLA:CPU doesn't propagate NaN through min/max, but // other platforms do, so it should be changed. options.set_minmax_propagate_nan(false); + options.set_executing_on_cpu(true); pipeline.AddPass(options); pipeline.AddPass(); pipeline.AddPass(/*is_layout_sensitive=*/true); diff --git a/third_party/xla/xla/service/cpu/cpu_float_support.cc b/third_party/xla/xla/service/cpu/cpu_float_support.cc index 0bb4dd8e875a75..6914e656b900a6 100644 --- a/third_party/xla/xla/service/cpu/cpu_float_support.cc +++ b/third_party/xla/xla/service/cpu/cpu_float_support.cc @@ -17,6 +17,7 @@ limitations under the License. #include "xla/service/cpu/cpu_float_support.h" +#include "xla/service/cpu/onednn_convolution_rewriter.h" #include "xla/service/cpu/onednn_matmul_rewriter.h" namespace xla { @@ -28,6 +29,9 @@ bool CpuFloatSupport::IsSupported(const HloInstruction& hlo) const { case HloOpcode::kDot: return LowPrecisionType() == BF16 && OneDnnMatMulRewriter::ShouldRewrite(&hlo); + case HloOpcode::kConvolution: + return LowPrecisionType() == BF16 && + OneDnnConvolutionRewriter::ShouldRewrite(&hlo); // Collective ops. 
case HloOpcode::kAllGather: case HloOpcode::kAllReduce: diff --git a/third_party/xla/xla/service/cpu/cpu_runtime.cc b/third_party/xla/xla/service/cpu/cpu_runtime.cc index 05125edd3d5f62..95916458f0232d 100644 --- a/third_party/xla/xla/service/cpu/cpu_runtime.cc +++ b/third_party/xla/xla/service/cpu/cpu_runtime.cc @@ -163,6 +163,8 @@ extern const char* const kOneDnnSoftmaxSymbolName = "__xla_cpu_runtime_OneDnnSoftmax"; extern const char* const kOneDnnLayerNormSymbolName = "__xla_cpu_runtime_OneDnnLayerNorm"; +extern const char* const kOneDnnConvolutionSymbolName = + "__xla_cpu_runtime_OneDnnConvolution"; extern const char* const kOneDnnMatMulReorderSymbolName = "__xla_cpu_runtime_OneDnnMatMulReorder"; extern const char* const kHandleFfiCallSymbolName = diff --git a/third_party/xla/xla/service/cpu/cpu_runtime.h b/third_party/xla/xla/service/cpu/cpu_runtime.h index 976017ae83ca6b..54385855f0bb36 100644 --- a/third_party/xla/xla/service/cpu/cpu_runtime.h +++ b/third_party/xla/xla/service/cpu/cpu_runtime.h @@ -89,6 +89,7 @@ extern const char* const kReduceScatterSymbolName; extern const char* const kOneDnnMatMulSymbolName; extern const char* const kOneDnnSoftmaxSymbolName; extern const char* const kOneDnnLayerNormSymbolName; +extern const char* const kOneDnnConvolutionSymbolName; extern const char* const kOneDnnMatMulReorderSymbolName; extern const char* const kHandleFfiCallSymbolName; diff --git a/third_party/xla/xla/service/cpu/ir_emitter.cc b/third_party/xla/xla/service/cpu/ir_emitter.cc index 802cd9e14a4d6d..a5a87ef0ba2b8d 100644 --- a/third_party/xla/xla/service/cpu/ir_emitter.cc +++ b/third_party/xla/xla/service/cpu/ir_emitter.cc @@ -2702,14 +2702,89 @@ absl::Status IrEmitter::HandleOneDnnMatMulCalls( return absl::OkStatus(); } +absl::Status IrEmitter::HandleOneDnnConvolution(HloInstruction* custom_call) { + // args[0]: ptr to nargs + // args[1]: ptr to ExecutableRunOptions + // args[2]: ptr to OneDnnConvolutionConfig + // args[3...]: ptrs to operands + + // First 
three arguments: nargs, ExecutableRunOptions, and + // OneDnnConvolutionConfig. + const int nargs_offset = 3; + const int num_operands = custom_call->operand_count(); + const int nargs = nargs_offset + num_operands; + int arg_indx = 0; + + llvm::Type* i64_type = b_.getInt64Ty(); + llvm::Type* ptr_type = b_.getPtrTy(); + llvm::ArrayType* ptr_array_type = llvm::ArrayType::get(ptr_type, nargs); + llvm::Value* args_val = llvm::UndefValue::get(ptr_array_type); + + llvm::Value* nargs_val = b_.getInt64(nargs); + llvm::Value* nargs_ptr = + llvm_ir::EmitAllocaAtFunctionEntry(i64_type, "nargs", &b_); + b_.CreateLifetimeStart(nargs_ptr, b_.getInt64(-1)); + b_.CreateStore(nargs_val, nargs_ptr); + args_val = b_.CreateInsertValue(args_val, nargs_ptr, arg_indx++); + + llvm::Value* run_opts_val = GetExecutableRunOptionsArgument(); + args_val = b_.CreateInsertValue(args_val, run_opts_val, arg_indx++); + + auto typed_custom_call = Cast(custom_call); + auto backend_config = typed_custom_call->backend_config(); + OneDnnConvolutionConfig conv_config; + conv_config.CopyFrom(backend_config->onednn_conv_config()); + std::string str_config; + conv_config.SerializeToString(&str_config); + llvm::Value* conv_config_val = + b_.CreateGlobalStringPtr(llvm_ir::AsStringRef(str_config)); + args_val = b_.CreateInsertValue(args_val, conv_config_val, arg_indx++); + + std::vector operands_stack_alloca; + operands_stack_alloca.reserve(num_operands); + absl::c_transform(custom_call->operands(), operands_stack_alloca.begin(), + [this](HloInstruction* instr) { + llvm_ir::IrArray ir_array(GetIrArrayFor(instr)); + return GetAllocaAndEmitMemrefInfo(b_, ir_array); + }); + for (int i = 0; i < num_operands; ++i) { + args_val = b_.CreateInsertValue(args_val, operands_stack_alloca[i].value, + arg_indx++); + } + TF_RET_CHECK(nargs == arg_indx) + << "Number of arguments don't equal the last argument index."; + + llvm::Value* args_ptr = llvm_ir::EmitAllocaAtFunctionEntry( + ptr_array_type, "convolution.args", &b_); + 
b_.CreateLifetimeStart(args_ptr, b_.getInt64(-1)); + b_.CreateStore(args_val, args_ptr); + + TF_RETURN_IF_ERROR(EmitTargetAddressForOp(custom_call)); + llvm_ir::IrArray result_array = GetIrArrayFor(custom_call); + auto result_stack_alloca = GetAllocaAndEmitMemrefInfo(b_, result_array); + + EmitCallToFunc(runtime::kOneDnnConvolutionSymbolName, + {result_stack_alloca.value, args_ptr}, b_.getVoidTy()); + + // Lifetime ends for all stack allocations. + b_.CreateLifetimeEnd(nargs_ptr, b_.getInt64(-1)); + for (int i = 0; i < num_operands; ++i) { + operands_stack_alloca[i].EmitLifetimeEnd(); + } + b_.CreateLifetimeEnd(args_ptr, b_.getInt64(-1)); + result_stack_alloca.EmitLifetimeEnd(); + + return OkStatus(); +} + absl::Status IrEmitter::HandleOneDnnLayerNorm(HloInstruction* custom_call) { // args[0]: ptr to nargs // args[1]: ptr to ExecutableRunOptions - // args[2]: ptr to OneDnnLayerNormConfig + // args[2]: ptr to OneDnnNormConfig // args[3...]: ptrs to operands // First three arguments: nargs, ExecutableRunOptions, and - // OneDnnLayerNormConfig. + // OneDnnNormConfig. const int nargs_offset = 3; const int num_operands = custom_call->operand_count(); const int nargs = nargs_offset + num_operands; @@ -2732,10 +2807,10 @@ absl::Status IrEmitter::HandleOneDnnLayerNorm(HloInstruction* custom_call) { llvm::Value* run_opts_val = GetExecutableRunOptionsArgument(); args_val = b_.CreateInsertValue(args_val, run_opts_val, arg_indx++); - // Insert OneDnnLayerNormConfig. + // Insert OneDnnNormConfig. 
auto typed_custom_call = Cast(custom_call); auto backend_config = typed_custom_call->backend_config(); - OneDnnLayerNormConfig ln_config; + OneDnnNormConfig ln_config; ln_config.CopyFrom(backend_config->onednn_layer_norm_config()); std::string str_config; ln_config.SerializeToString(&str_config); @@ -2831,6 +2906,9 @@ absl::Status IrEmitter::HandleCustomCall(HloInstruction* custom_call) { if (custom_call->custom_call_target() == "__onednn$layernorm") { return HandleOneDnnLayerNorm(custom_call); } + if (custom_call->custom_call_target() == "__onednn$convolution") { + return HandleOneDnnConvolution(custom_call); + } if (custom_call->custom_call_target() == "__onednn$matmul_reorder") { return HandleOneDnnMatMulCalls(custom_call, runtime::kOneDnnMatMulReorderSymbolName); diff --git a/third_party/xla/xla/service/cpu/ir_emitter.h b/third_party/xla/xla/service/cpu/ir_emitter.h index 7e34bbf9e7b88c..1e05edc0bccf9f 100644 --- a/third_party/xla/xla/service/cpu/ir_emitter.h +++ b/third_party/xla/xla/service/cpu/ir_emitter.h @@ -234,6 +234,7 @@ class IrEmitter : public DfsHloVisitorWithDefault, std::string runtime_symbol_name); absl::Status HandleOneDnnSoftmax(HloInstruction* hlo); absl::Status HandleOneDnnLayerNorm(HloInstruction* hlo); + absl::Status HandleOneDnnConvolution(HloInstruction* hlo); #endif // INTEL_MKL && ENABLE_ONEDNN_V3 // Private helper to initialize an IR function for the computation. 
void InitializeIrFunction(const std::string& function_name); diff --git a/third_party/xla/xla/service/cpu/onednn_config.proto b/third_party/xla/xla/service/cpu/onednn_config.proto new file mode 100644 index 00000000000000..9f38673eaacebd --- /dev/null +++ b/third_party/xla/xla/service/cpu/onednn_config.proto @@ -0,0 +1,116 @@ +syntax = "proto3"; + +package xla.cpu; + +message OneDnnDataLayoutProto { + // The batch dimension of the tensor + uint64 batch_dim = 1; + // The feature dimension of the tensor + uint64 feature_dim = 2; + // The spatial dimensions of the tensor + repeated uint64 spatial_dims = 3; +} + +message OneDnnFilterLayoutProto { + // The input feature dimension of the tensor + uint64 input_feature_dim = 1; + // The output feature dimension of the tensor + uint64 output_feature_dim = 2; + // The spatial dimensions of the tensor + repeated uint64 spatial_dims = 3; + // Shape of the tensor + repeated uint64 shape = 4; +} + +message OneDnnFactorLayoutProto { + // The dimensions of the tensor + repeated uint64 dimensions = 1; + // Shape of the tensor + repeated uint64 shape = 2; +} + +message OneDnnOptimizationConfig { + bool weights_prepacked = 1; + bool user_scratchpad = 2; + bool bias_broadcast = 3; +} + +message OneDnnFusionConfig { + // These enum needs to be mapped to oneDNN enum for post_op algorithm. + // TODO(intel-tf): Add kinds supported by oneDNN. + enum FusionKind { + UNDEFINED = 0; + BIAS = 1; + RELU = 2; + TANH = 3; + GELU_ERF = 4; + GELU_TANH = 5; + BINARY_ADD = 6; + LINEAR = 7; + ELU = 8; + RELU6 = 9; + SIGMOID = 10; + } + repeated FusionKind ops = 1; + // To avoid protobuf failures for specific decimal values, + // the original float value alpha is type-casted to int32. 
+ int32 alpha_typecast = 2; +} + +message OneDnnTensorLayoutProto { + uint64 dims = 1; + oneof layout { + OneDnnDataLayoutProto data = 2; + OneDnnFilterLayoutProto filter = 3; + OneDnnFactorLayoutProto tensor = 4; + } +} + +message OneDnnSoftmaxConfig { + int32 softmax_axis = 1; +} + +message OneDnnMatMulConfig { + bool transpose_a = 1; + bool transpose_b = 2; + + OneDnnFusionConfig fusions = 3; + + reserved 4; // was bias_broadcast + reserved 5; // was alpha_typecast + reserved 6; // was weights_prepacked + reserved 7; // was user_scratchpad + + OneDnnOptimizationConfig optimization_config = 8; +} + +message OneDnnWindowProto { + repeated uint64 size = 1; + repeated uint64 pad_left = 2; + repeated uint64 pad_right = 3; + repeated uint64 strides = 4; + repeated uint64 window_dilations = 5; +} + +message OneDnnNormConfig { + enum ScaleAndShift { + UNDEFINED = 0; + SCALE = 1; + SHIFT = 2; + SCALE_AND_SHIFT = 3; + } + ScaleAndShift rescale = 1; + int32 epsilon_typecast = 2; + OneDnnFusionConfig fusions = 3; +} + +message OneDnnConvolutionConfig { + uint64 dims = 1; + OneDnnTensorLayoutProto input = 2; + OneDnnTensorLayoutProto kernel = 3; + OneDnnTensorLayoutProto output = 4; + OneDnnWindowProto window = 5; + + OneDnnFusionConfig fusions = 6; + uint64 feature_groups = 7; +} diff --git a/third_party/xla/xla/service/cpu/onednn_convolution.cc b/third_party/xla/xla/service/cpu/onednn_convolution.cc new file mode 100644 index 00000000000000..f7a4e17d339fb7 --- /dev/null +++ b/third_party/xla/xla/service/cpu/onednn_convolution.cc @@ -0,0 +1,230 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#if defined(INTEL_MKL) && defined(ENABLE_ONEDNN_V3) + +#include "xla/service/cpu/onednn_convolution.h" + +#include +#include +#include +#include +#include +#include + +#define EIGEN_USE_THREADS + +#include "dnnl.hpp" +#include "absl/base/dynamic_annotations.h" +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "xla/executable_run_options.h" +#include "xla/service/cpu/backend_config.pb.h" +#include "xla/service/cpu/onednn_memory_util.h" +#include "xla/service/cpu/runtime_lightweight_check.h" +#include "xla/tsl/util/onednn_threadpool.h" +#include "tsl/platform/logging.h" + +namespace xla { +namespace cpu { +namespace { +using dnnl::algorithm; +using dnnl::convolution_forward; +using dnnl::memory; +using dnnl::prop_kind; +using dnnl::stream; +} // namespace + +dnnl::memory ReorderMemory(const dnnl::engine& engine, + const dnnl::memory::desc& dest_md, + dnnl::memory& src_mem, + const dnnl::stream& onednn_stream) { + auto dest_mem = memory(dest_md, engine); + dnnl::reorder(src_mem, dest_mem).execute(onednn_stream, src_mem, dest_mem); + return dest_mem; +} + +dnnl::memory::format_tag GetFormatTag(const int dims) { + return (dims == 3) ? dnnl::memory::format_tag::nwc + : (dims == 4) ? dnnl::memory::format_tag::nhwc + : (dims == 5) ? 
dnnl::memory::format_tag::ndhwc + : dnnl::memory::format_tag::any; +} + +ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_OneDnnConvolution( + void* result, void** args) { + // args[0]: ptr to nargs + // args[1]: ptr to ExecutableRunOptions + // args[2]: ptr to OneDnnConvolutionConfig + // args[3...]: ptrs to operands + int arg_indx = 0; + const int64_t num_args = *(static_cast(args[arg_indx++])); + + const xla::ExecutableRunOptions* run_options = + static_cast(args[arg_indx++]); + XLA_LIGHTWEIGHT_CHECK(run_options != nullptr); + XLA_LIGHTWEIGHT_CHECK(run_options->intra_op_thread_pool() != nullptr); + tsl::OneDnnThreadPool thread_pool( + run_options->intra_op_thread_pool()->getPool(), false); + dnnl::engine cpu_engine(dnnl::engine::kind::cpu, 0); +#ifndef ENABLE_ONEDNN_OPENMP + auto onednn_stream = + stream(dnnl::threadpool_interop::make_stream(cpu_engine, &thread_pool)); +#else + auto onednn_stream = stream(cpu_engine); +#endif // ENABLE_ONEDNN_OPENMP + + std::string config_str(static_cast(args[arg_indx++])); + OneDnnConvolutionConfig conv_config; + conv_config.ParseFromString(config_str); + + // Generate permutations to create memory descriptors + std::vector inp_perm_axes(conv_config.dims()); + std::vector ker_perm_axes(conv_config.dims()); + std::vector out_perm_axes(conv_config.dims()); + + int index_i = 0; + int index_o = 0; + int index_k = 0; + + inp_perm_axes[conv_config.input().data().batch_dim()] = index_i++; + out_perm_axes[conv_config.output().data().batch_dim()] = index_o++; + ker_perm_axes[conv_config.kernel().filter().output_feature_dim()] = index_k++; + + inp_perm_axes[conv_config.input().data().feature_dim()] = index_i++; + out_perm_axes[conv_config.output().data().feature_dim()] = index_o++; + ker_perm_axes[conv_config.kernel().filter().input_feature_dim()] = index_k++; + + std::vector inp_dim_axes( + conv_config.input().data().spatial_dims().begin(), + conv_config.input().data().spatial_dims().end()); + std::vector ker_dim_axes( + 
conv_config.kernel().filter().spatial_dims().begin(), + conv_config.kernel().filter().spatial_dims().end()); + std::vector out_dim_axes( + conv_config.output().data().spatial_dims().begin(), + conv_config.output().data().spatial_dims().end()); + + std::for_each(inp_dim_axes.begin(), inp_dim_axes.end(), + [&inp_perm_axes, &index_i](int64_t& n) { + n -= 1; + inp_perm_axes[n] = index_i++; + }); + std::for_each(ker_dim_axes.begin(), ker_dim_axes.end(), + [&ker_perm_axes, &index_k](int64_t& n) { + n -= 1; + ker_perm_axes[n] = index_k++; + }); + std::for_each(out_dim_axes.begin(), out_dim_axes.end(), + [&out_perm_axes, &index_o](int64_t& n) { + n -= 1; + out_perm_axes[n] = index_o++; + }); + + memory::dims strides(conv_config.window().strides().begin(), + conv_config.window().strides().end()); + memory::dims pad_left(conv_config.window().pad_left().begin(), + conv_config.window().pad_left().end()); + memory::dims pad_right(conv_config.window().pad_right().begin(), + conv_config.window().pad_right().end()); + memory::dims rhs_dilations(conv_config.window().window_dilations().begin(), + conv_config.window().window_dilations().end()); + + std::for_each(strides.begin(), strides.end(), [](int64_t& n) { n -= 1; }); + std::for_each(pad_left.begin(), pad_left.end(), [](int64_t& n) { n -= 1; }); + std::for_each(pad_right.begin(), pad_right.end(), [](int64_t& n) { n -= 1; }); + std::for_each(rhs_dilations.begin(), rhs_dilations.end(), + [](int64_t& n) { n -= 2; }); + + auto groups = conv_config.feature_groups(); + + MemrefInfo inp_minfo(args[arg_indx++]); + MemrefInfo ker_minfo(args[arg_indx++]); + MemrefInfo res_minfo(result); + + // Permute memory descriptors + auto inp_md = inp_minfo.GetOneDnnMemDesc(); + auto ker_md = ker_minfo.GetOneDnnMemDesc(); + auto res_md = res_minfo.GetOneDnnMemDesc(); + + std::vector inp_axes(inp_perm_axes.begin(), inp_perm_axes.end()); + std::vector ker_axes(ker_perm_axes.begin(), ker_perm_axes.end()); + std::vector out_axes(out_perm_axes.begin(), 
out_perm_axes.end()); + + auto new_inp_md = inp_md.permute_axes(inp_axes); + auto new_ker_md = ker_md.permute_axes(ker_axes); + auto new_res_md = res_md.permute_axes(out_axes); + + if (groups > 1) { + auto corr_dims = new_ker_md.get_dims(); + corr_dims.insert(corr_dims.begin(), 1, groups); + corr_dims[1] = corr_dims[1] / groups; + new_ker_md = new_ker_md.reshape(corr_dims); + } + + auto any_ker_md = + memory::desc(new_ker_md.get_dims(), new_ker_md.get_data_type(), + dnnl::memory::format_tag::any); + auto any_inp_md = + memory::desc(new_inp_md.get_dims(), new_inp_md.get_data_type(), + GetFormatTag(new_inp_md.get_ndims())); + auto any_res_md = + memory::desc(new_res_md.get_dims(), new_res_md.get_data_type(), + GetFormatTag(new_res_md.get_ndims())); + + XLA_LIGHTWEIGHT_CHECK(num_args == arg_indx); + + dnnl::primitive_attr attrs; + + auto inp_mem = memory(new_inp_md, cpu_engine, inp_minfo.Data()); + auto ker_mem = memory(new_ker_md, cpu_engine, ker_minfo.Data()); + auto res_mem = memory(new_res_md, cpu_engine, res_minfo.Data()); + + auto conv_pd = convolution_forward::primitive_desc( + cpu_engine, prop_kind::forward_inference, algorithm::convolution_direct, + any_inp_md, any_ker_md, any_res_md, strides, rhs_dilations, pad_left, + pad_right, attrs); + + auto new_inp_mem = (conv_pd.src_desc() == inp_mem.get_desc()) + ? inp_mem + : ReorderMemory(cpu_engine, conv_pd.src_desc(), + inp_mem, onednn_stream); + auto new_ker_mem = (conv_pd.weights_desc() == ker_mem.get_desc()) + ? ker_mem + : ReorderMemory(cpu_engine, conv_pd.weights_desc(), + ker_mem, onednn_stream); + auto new_res_mem = (conv_pd.dst_desc() == res_mem.get_desc()) + ? 
res_mem + : memory(conv_pd.dst_desc(), cpu_engine); + + auto conv_prim = convolution_forward(conv_pd); + + std::unordered_map conv_args{{DNNL_ARG_SRC, new_inp_mem}, + {DNNL_ARG_WEIGHTS, new_ker_mem}, + {DNNL_ARG_DST, new_res_mem}}; + + conv_prim.execute(onednn_stream, conv_args); + + if (conv_pd.dst_desc() == res_mem.get_desc()) { + res_mem = new_res_mem; + } else { + dnnl::reorder(new_res_mem, res_mem) + .execute(onednn_stream, new_res_mem, res_mem); + } +} + +} // namespace cpu +} // namespace xla + +#endif // INTEL_MKL && ENABLE_ONEDNN_V3 diff --git a/third_party/xla/xla/service/cpu/onednn_convolution.h b/third_party/xla/xla/service/cpu/onednn_convolution.h new file mode 100644 index 00000000000000..19cbbe2e2a371a --- /dev/null +++ b/third_party/xla/xla/service/cpu/onednn_convolution.h @@ -0,0 +1,31 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef XLA_SERVICE_CPU_ONEDNN_CONVOLUTION_H_ +#define XLA_SERVICE_CPU_ONEDNN_CONVOLUTION_H_ +#if defined(INTEL_MKL) && defined(ENABLE_ONEDNN_V3) + +namespace xla { +namespace cpu { + +extern "C" { +extern void __xla_cpu_runtime_OneDnnConvolution(void* result, void** args); +} // extern "C" + +} // namespace cpu +} // namespace xla + +#endif // INTEL_MKL && ENABLE_ONEDNN_V3 +#endif // XLA_SERVICE_CPU_ONEDNN_CONVOLUTION_H_ diff --git a/third_party/xla/xla/service/cpu/onednn_convolution_rewriter.cc b/third_party/xla/xla/service/cpu/onednn_convolution_rewriter.cc new file mode 100644 index 00000000000000..c148f116988ea4 --- /dev/null +++ b/third_party/xla/xla/service/cpu/onednn_convolution_rewriter.cc @@ -0,0 +1,140 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#if defined(INTEL_MKL) && defined(ENABLE_ONEDNN_V3) + +#include "xla/service/cpu/onednn_convolution_rewriter.h" + +#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/service/cpu/backend_config.pb.h" +#include "xla/service/cpu/onednn_memory_util.h" +#include "xla/service/cpu/onednn_util.h" +#include "xla/service/pattern_matcher.h" +#include "xla/status_macros.h" + +namespace xla { +namespace cpu { + +namespace { +namespace m = match; +} // namespace + +bool OneDnnConvolutionRewriter::ShouldRewrite(const HloInstruction* conv) { + if (conv->HasControlDependencies()) return false; + if (!IsSupportedType(conv->shape().element_type())) return false; + if (conv->batch_group_count() != 1) return false; + + if (conv->operand(1)->opcode() == HloOpcode::kReverse) return false; + + const Shape& inp_shape = conv->operand(0)->shape(); + const Shape& ker_shape = conv->operand(1)->shape(); + const Shape& out_shape = conv->shape(); + if (ShapeUtil::IsZeroElementArray(inp_shape) || + ShapeUtil::IsZeroElementArray(ker_shape) || + ShapeUtil::IsZeroElementArray(out_shape)) { + return false; + } + + auto dims = conv->window().dimensions().size(); + if (dims >= 4 || dims <= 0) return false; + + if (inp_shape.rank() != ker_shape.rank() || + inp_shape.rank() != out_shape.rank()) { + return false; + } + + return true; +} + +class OneDnnConvolutionRewriterVisitor : public DfsHloRewriteVisitor { + public: + Status HandleConvolution(HloInstruction* conv) override { + auto pattern = match::Op(&conv).WithOpcode(HloOpcode::kConvolution); + if (!Match(conv, pattern)) return OkStatus(); + if (!OneDnnConvolutionRewriter::ShouldRewrite(conv)) return OkStatus(); + + const Shape& conv_shape = conv->shape(); + auto dims = conv->window().dimensions().size(); + const ConvolutionDimensionNumbers& conv_ddata = + conv->convolution_dimension_numbers(); + + BackendConfig 
backend_config; + OneDnnConvolutionConfig* conv_config = + backend_config.mutable_onednn_conv_config(); + + conv_config->set_dims(conv_shape.rank()); + conv_config->set_feature_groups(conv->feature_group_count()); + conv_config->mutable_input()->mutable_data()->set_batch_dim( + conv_ddata.input_batch_dimension()); + conv_config->mutable_kernel()->mutable_filter()->set_input_feature_dim( + conv_ddata.kernel_input_feature_dimension()); + conv_config->mutable_output()->mutable_data()->set_batch_dim( + conv_ddata.output_batch_dimension()); + conv_config->mutable_input()->mutable_data()->set_feature_dim( + conv_ddata.input_feature_dimension()); + conv_config->mutable_kernel()->mutable_filter()->set_output_feature_dim( + conv_ddata.kernel_output_feature_dimension()); + conv_config->mutable_output()->mutable_data()->set_feature_dim( + conv_ddata.output_feature_dimension()); + + const Shape& output_shape = conv->shape(); + + for (auto it = conv->window().dimensions().begin(); + it != conv->window().dimensions().end(); it++) { + if ((*it).padding_low() < 0 || (*it).padding_high() < 0 || + (*it).stride() < 0) { + return OkStatus(); + } + conv_config->mutable_window()->add_pad_left((*it).padding_low() + 1); + conv_config->mutable_window()->add_pad_right((*it).padding_high() + 1); + conv_config->mutable_window()->add_strides((*it).stride() + 1); + conv_config->mutable_window()->add_window_dilations( + (*it).window_dilation() + 1); + if ((*it).base_dilation() != 1 || (*it).window_reversal()) { + return OkStatus(); + } + } + + for (int i = 0; i < dims; i++) { + conv_config->mutable_input()->mutable_data()->add_spatial_dims( + conv_ddata.input_spatial_dimensions()[i] + 1); + conv_config->mutable_kernel()->mutable_filter()->add_spatial_dims( + conv_ddata.kernel_spatial_dimensions()[i] + 1); + conv_config->mutable_output()->mutable_data()->add_spatial_dims( + conv_ddata.output_spatial_dimensions()[i] + 1); + } + + HloInstruction* custom_call = + 
conv->AddInstruction(HloInstruction::CreateCustomCall( + output_shape, {conv->mutable_operand(0), conv->mutable_operand(1)}, + "__onednn$convolution")); + + TF_RETURN_IF_ERROR(custom_call->set_backend_config(backend_config)); + TF_RETURN_IF_ERROR(ReplaceInstruction(conv, custom_call)); + return OkStatus(); + } +}; + +StatusOr OneDnnConvolutionRewriter::Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) { + OneDnnConvolutionRewriterVisitor visitor; + return visitor.RunOnModule(module, execution_threads); +} + +} // namespace cpu +} // namespace xla + +#endif // INTEL_MKL && ENABLE_ONEDNN_V3 diff --git a/third_party/xla/xla/service/cpu/onednn_convolution_rewriter.h b/third_party/xla/xla/service/cpu/onednn_convolution_rewriter.h new file mode 100644 index 00000000000000..334db7f60b8356 --- /dev/null +++ b/third_party/xla/xla/service/cpu/onednn_convolution_rewriter.h @@ -0,0 +1,49 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef XLA_SERVICE_CPU_ONEDNN_CONVOLUTION_REWRITER_H_ +#define XLA_SERVICE_CPU_ONEDNN_CONVOLUTION_REWRITER_H_ +#if defined(INTEL_MKL) && defined(ENABLE_ONEDNN_V3) + +#include + +#include "absl/algorithm/container.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/service/hlo_pass_interface.h" + +namespace xla { +namespace cpu { + +// This pass converts hlo convolution instructions into a single oneDNN +// operation and rewrites into custom calls. +class OneDnnConvolutionRewriter : public HloModulePass { + public: + absl::string_view name() const override { + return "onednn-convolution-rewriter"; + } + + using HloPassInterface::Run; + StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; + + static bool ShouldRewrite(const HloInstruction* instr); +}; + +} // namespace cpu +} // namespace xla + +#endif // INTEL_MKL && ENABLE_ONEDNN_V3 +#endif // XLA_SERVICE_CPU_ONEDNN_CONVOLUTION_REWRITER_H_ diff --git a/third_party/xla/xla/service/cpu/onednn_layer_norm.cc b/third_party/xla/xla/service/cpu/onednn_layer_norm.cc index d2109a1bc2f956..6abd698f5898a6 100644 --- a/third_party/xla/xla/service/cpu/onednn_layer_norm.cc +++ b/third_party/xla/xla/service/cpu/onednn_layer_norm.cc @@ -48,7 +48,7 @@ ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_OneDnnLayerNorm( void* result, void** args) { // args[0]: ptr to nargs. We don't use nargs here. 
// args[1]: ptr to ExecutableRunOptions - // args[2]: ptr to OneDnnLayerNormConfig + // args[2]: ptr to OneDnnNormConfig // args[3...]: ptrs to operands int arg_indx = 1; const xla::ExecutableRunOptions* run_options = @@ -65,7 +65,7 @@ ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_OneDnnLayerNorm( auto onednn_stream = stream(cpu_engine); #endif // ENABLE_ONEDNN_OPENMP std::string config_str(static_cast(args[arg_indx++])); - OneDnnLayerNormConfig ln_config; + OneDnnNormConfig ln_config; ln_config.ParseFromString(config_str); MemrefInfo layer_minfo(args[arg_indx++]); @@ -82,7 +82,6 @@ ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_OneDnnLayerNorm( auto scale_mem = memory(scaleshift_md, cpu_engine, gamma_minfo.Data()); auto shift_mem = memory(scaleshift_md, cpu_engine, beta_minfo.Data()); - // TODO(intel-tf): Move epsilon to OneDnnLayerNormConfig. float epsilon; *(reinterpret_cast(&epsilon)) = ln_config.epsilon_typecast(); diff --git a/third_party/xla/xla/service/cpu/onednn_matmul.cc b/third_party/xla/xla/service/cpu/onednn_matmul.cc index 414e4d6da9a577..5e066795344766 100644 --- a/third_party/xla/xla/service/cpu/onednn_matmul.cc +++ b/third_party/xla/xla/service/cpu/onednn_matmul.cc @@ -70,10 +70,10 @@ dnnl::memory::desc OneDnnMatMulOptWeightsDesc( auto weights_md = ShapeToMemDesc(weights_shape); TRANSPOSE_LAST_TWO_DIMS_IF(matmul_config->transpose_a(), input_md); TRANSPOSE_LAST_TWO_DIMS_IF(matmul_config->transpose_b(), weights_md); - auto bias_md = - absl::c_count(matmul_config->fused_ops(), OneDnnMatMulConfig::BIAS) > 0 - ? ShapeToMemDesc(bias_shape) - : dnnl::memory::desc{}; + auto bias_md = absl::c_count(matmul_config->fusions().ops(), + OneDnnFusionConfig::BIAS) > 0 + ? 
ShapeToMemDesc(bias_shape) + : dnnl::memory::desc{}; auto output_md = ShapeToMemDesc(output_shape); // extend bias rank to match result rank @@ -115,7 +115,7 @@ std::unique_ptr CreateMatMulPrimDesc( const OneDnnMatMulConfig& matmul_config, FusedOperandsRef* fused_operands_ref = nullptr) { auto bias_md = memory::desc(); - bool weights_packed = matmul_config.weights_prepacked(); + bool weights_packed = matmul_config.optimization_config().weights_prepacked(); auto weights_md = plain_weights_md; if (weights_packed) { weights_md = memory::desc(weights_md.get_dims(), weights_md.get_data_type(), @@ -124,27 +124,27 @@ std::unique_ptr CreateMatMulPrimDesc( dnnl::post_ops post_ops; int fused_operand_idx = 0; - for (auto& fused_op : matmul_config.fused_ops()) { + for (auto& fused_op : matmul_config.fusions().ops()) { switch (fused_op) { - case OneDnnMatMulConfig::RELU: + case OneDnnFusionConfig::RELU: post_ops.append_eltwise(dnnl::algorithm::eltwise_relu, 0.f, 0.f); break; - case OneDnnMatMulConfig::TANH: + case OneDnnFusionConfig::TANH: post_ops.append_eltwise(dnnl::algorithm::eltwise_tanh, 0.f, 0.f); break; - case OneDnnMatMulConfig::GELU_TANH: + case OneDnnFusionConfig::GELU_TANH: post_ops.append_eltwise(dnnl::algorithm::eltwise_gelu_tanh, 0.f, 0.f); break; - case OneDnnMatMulConfig::GELU_ERF: + case OneDnnFusionConfig::GELU_ERF: post_ops.append_eltwise(dnnl::algorithm::eltwise_gelu_erf, 0.f, 0.f); break; - case OneDnnMatMulConfig::RELU6: + case OneDnnFusionConfig::RELU6: post_ops.append_eltwise(dnnl::algorithm::eltwise_clip_v2, 0.f, 6.0f); break; - case OneDnnMatMulConfig::SIGMOID: + case OneDnnFusionConfig::SIGMOID: post_ops.append_eltwise(dnnl::algorithm::eltwise_logistic, 0.f, 0.f); break; - case OneDnnMatMulConfig::BIAS: { + case OneDnnFusionConfig::BIAS: { bias_md = fused_mds.at(fused_operand_idx); // Extend bias rank to match result rank. 
auto missed_rank = output_md.get_ndims() - bias_md.get_ndims(); @@ -162,10 +162,10 @@ std::unique_ptr CreateMatMulPrimDesc( } fused_operand_idx++; } break; - case OneDnnMatMulConfig::ELU: + case OneDnnFusionConfig::ELU: post_ops.append_eltwise(dnnl::algorithm::eltwise_elu, 1.0f, 0.0f); break; - case OneDnnMatMulConfig::BINARY_ADD: { + case OneDnnFusionConfig::BINARY_ADD: { auto binary_md = fused_mds.at(fused_operand_idx); // Extend addend rank to match result rank. auto missed_rank = output_md.get_ndims() - binary_md.get_ndims(); @@ -186,10 +186,10 @@ std::unique_ptr CreateMatMulPrimDesc( post_ops.append_binary(dnnl::algorithm::binary_add, binary_md); fused_operand_idx++; } break; - case OneDnnMatMulConfig::LINEAR: { + case OneDnnFusionConfig::LINEAR: { float const_float; *(reinterpret_cast(&const_float)) = - matmul_config.alpha_typecast(); + matmul_config.fusions().alpha_typecast(); post_ops.append_eltwise(dnnl::algorithm::eltwise_linear, const_float, 0.f); } break; @@ -202,7 +202,7 @@ std::unique_ptr CreateMatMulPrimDesc( } dnnl::primitive_attr attrs; - if (matmul_config.user_scratchpad()) { + if (matmul_config.optimization_config().user_scratchpad()) { attrs.set_scratchpad_mode(dnnl::scratchpad_mode::user); } if (post_ops.len() > 0) { @@ -294,7 +294,7 @@ ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_OneDnnMatMul( TRANSPOSE_LAST_TWO_DIMS_IF( matmul_config.transpose_b() && weights_md.get_ndims() > 1, weights_md); auto output_md = output_minfo.GetOneDnnMemDesc(); - if (matmul_config.weights_prepacked()) { + if (matmul_config.optimization_config().weights_prepacked()) { // Weight pre-packing is supported for 2D weights only. // Since prepacked weights array is flattened, try to infer the dims from // input and output. 
@@ -336,7 +336,7 @@ ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_OneDnnMatMul( {DNNL_ARG_WEIGHTS, rhs_mem}, {DNNL_ARG_DST, result_mem}}; - if (matmul_config.user_scratchpad()) { + if (matmul_config.optimization_config().user_scratchpad()) { XLA_LIGHTWEIGHT_CHECK(scratch != nullptr); MemrefInfo scratch_minfo(scratch); auto scratchpad_md = matmul_pd->scratchpad_desc(); @@ -380,7 +380,8 @@ ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_OneDnnMatMulReorder( auto output_md = output_minfo.GetOneDnnMemDesc(); auto bias_md = dnnl::memory::desc{}; - if (absl::c_count(matmul_config.fused_ops(), OneDnnMatMulConfig::BIAS) > 0) { + if (absl::c_count(matmul_config.fusions().ops(), OneDnnFusionConfig::BIAS) > + 0) { MemrefInfo bias_minfo(args[arg_indx++]); bias_md = bias_minfo.GetOneDnnMemDesc(); } diff --git a/third_party/xla/xla/service/cpu/onednn_matmul_rewriter.cc b/third_party/xla/xla/service/cpu/onednn_matmul_rewriter.cc index 08ae6767998d51..95098dd31cfc6b 100644 --- a/third_party/xla/xla/service/cpu/onednn_matmul_rewriter.cc +++ b/third_party/xla/xla/service/cpu/onednn_matmul_rewriter.cc @@ -264,13 +264,13 @@ auto GELUActivation(HloInstruction* instr, HloInstruction** src) { if (Match(errf, errf_apprx_pattern)) { // Matched Gelu-approximate pattern - return OneDnnMatMulConfig::GELU_TANH; + return OneDnnFusionConfig::GELU_TANH; } else if (Match(errf, errf_exact_pattern)) { // Matched Gelu-exact pattern - return OneDnnMatMulConfig::GELU_ERF; + return OneDnnFusionConfig::GELU_ERF; } } - return OneDnnMatMulConfig::UNDEFINED; + return OneDnnFusionConfig::UNDEFINED; } // OneDNN matmul can fuse add operation with automatic broadcasting along the @@ -502,11 +502,13 @@ class OneDnnMatMulRewriteVisitor : public DfsHloRewriteVisitor { // Add(bias) + Add(e.g., residual) is enabled. 
if (!dot->backend_config() ->mutable_onednn_matmul_config() - ->fused_ops() + ->mutable_fusions() + ->ops() .empty() && dot->backend_config() ->mutable_onednn_matmul_config() - ->fused_ops(0) == OneDnnMatMulConfig::BIAS) { + ->mutable_fusions() + ->ops(0) == OneDnnFusionConfig::BIAS) { return absl::OkStatus(); } std::vector new_operands; @@ -559,19 +561,25 @@ class OneDnnMatMulRewriteVisitor : public DfsHloRewriteVisitor { auto matmul_call = Cast(instr->AddInstruction( dot->CloneWithNewOperands(dot->shape(), new_operands))); - OneDnnMatMulConfig_FusionKind kind; + OneDnnFusionConfig_FusionKind kind; auto backend_config = matmul_call->backend_config(); - if (backend_config->mutable_onednn_matmul_config()->fused_ops().empty() && + if (backend_config->mutable_onednn_matmul_config() + ->fusions() + .ops() + .empty() && addend->shape().rank() == 1) { - kind = OneDnnMatMulConfig::BIAS; + kind = OneDnnFusionConfig::BIAS; } else { - kind = OneDnnMatMulConfig::BINARY_ADD; + kind = OneDnnFusionConfig::BINARY_ADD; } - backend_config->mutable_onednn_matmul_config()->add_fused_ops(kind); + backend_config->mutable_onednn_matmul_config() + ->mutable_fusions() + ->add_ops(kind); if (optional_addend_broadcast) { - backend_config->mutable_onednn_matmul_config()->set_bias_broadcast( - true); + backend_config->mutable_onednn_matmul_config() + ->mutable_optimization_config() + ->set_bias_broadcast(true); } TF_RETURN_IF_ERROR(matmul_call->set_backend_config(*backend_config)); @@ -629,7 +637,7 @@ class OneDnnMatMulRewriteVisitor : public DfsHloRewriteVisitor { OneDnnMatmulInstr(&matmul_call)) .WithOneUser(), BcastConstScalar(0)))) { - return FuseActivation(OneDnnMatMulConfig::RELU, instr, matmul_call, + return FuseActivation(OneDnnFusionConfig::RELU, instr, matmul_call, intermediate_instr, optional_bitcast); } return absl::OkStatus(); @@ -659,7 +667,7 @@ class OneDnnMatMulRewriteVisitor : public DfsHloRewriteVisitor { if (Match(src, ElementwiseSafeIntermediates( &intermediate_instr, 
&optional_bitcast, OneDnnMatmulInstr(&matmul_call)))) { - return FuseActivation(OneDnnMatMulConfig::ELU, instr, matmul_call, + return FuseActivation(OneDnnFusionConfig::ELU, instr, matmul_call, intermediate_instr); } } @@ -676,7 +684,7 @@ class OneDnnMatMulRewriteVisitor : public DfsHloRewriteVisitor { &intermediate_instr, &optional_bitcast, OneDnnMatmulInstr(&matmul_call)) .WithOneUser()))) { - return FuseActivation(OneDnnMatMulConfig::TANH, instr, matmul_call, + return FuseActivation(OneDnnFusionConfig::TANH, instr, matmul_call, intermediate_instr); } return absl::OkStatus(); @@ -694,7 +702,7 @@ class OneDnnMatMulRewriteVisitor : public DfsHloRewriteVisitor { OneDnnMatmulInstr(&matmul_call)) .WithOneUser(), BcastConstScalar(6)))) { - return FuseActivation(OneDnnMatMulConfig::RELU6, instr, matmul_call, + return FuseActivation(OneDnnFusionConfig::RELU6, instr, matmul_call, intermediate_instr); } return absl::OkStatus(); @@ -705,7 +713,7 @@ class OneDnnMatMulRewriteVisitor : public DfsHloRewriteVisitor { HloInstruction* intermediate_instr = nullptr; HloInstruction* src; auto activation = GELUActivation(instr, &src); - if (activation != OneDnnMatMulConfig::UNDEFINED) { + if (activation != OneDnnFusionConfig::UNDEFINED) { HloInstruction* optional_bitcast = nullptr; if (Match(src, ElementwiseSafeIntermediates( &intermediate_instr, &optional_bitcast, @@ -741,12 +749,15 @@ class OneDnnMatMulRewriteVisitor : public DfsHloRewriteVisitor { auto matmul_call = Cast(instr->AddInstruction( dot->CloneWithNewOperands(instr->shape(), new_operands))); auto backend_config = matmul_call->backend_config(); - backend_config->mutable_onednn_matmul_config()->add_fused_ops( - OneDnnMatMulConfig::LINEAR); + backend_config->mutable_onednn_matmul_config() + ->mutable_fusions() + ->add_ops(OneDnnFusionConfig::LINEAR); // Casting to int32 because of issues in proto config for decimal types // handling. 
- backend_config->mutable_onednn_matmul_config()->set_alpha_typecast( - *(reinterpret_cast(&constant_value.value()))); + backend_config->mutable_onednn_matmul_config() + ->mutable_fusions() + ->set_alpha_typecast( + *(reinterpret_cast(&constant_value.value()))); TF_RETURN_IF_ERROR(matmul_call->set_backend_config(*backend_config)); HloInstruction* new_instr; if (optional_convert != nullptr && @@ -781,14 +792,14 @@ class OneDnnMatMulRewriteVisitor : public DfsHloRewriteVisitor { &intermediate_instr, &optional_bitcast, OneDnnMatmulInstr(&matmul_call)) .WithOneUser())) { - return FuseActivation(OneDnnMatMulConfig::SIGMOID, instr, matmul_call, + return FuseActivation(OneDnnFusionConfig::SIGMOID, instr, matmul_call, intermediate_instr, optional_bitcast); } } return absl::OkStatus(); } - absl::Status FuseActivation(OneDnnMatMulConfig_FusionKind kind, + absl::Status FuseActivation(OneDnnFusionConfig_FusionKind kind, HloInstruction* activation, HloInstruction* matmul, HloInstruction* intermediate_instr = nullptr, @@ -796,7 +807,7 @@ class OneDnnMatMulRewriteVisitor : public DfsHloRewriteVisitor { TF_ASSIGN_OR_RETURN(auto backend_config, matmul->backend_config()); auto* matmul_config = backend_config.mutable_onednn_matmul_config(); - matmul_config->add_fused_ops(kind); + matmul_config->mutable_fusions()->add_ops(kind); TF_RETURN_IF_ERROR(matmul->set_backend_config(backend_config)); std::unique_ptr output = matmul->Clone(); if (optional_bitcast != nullptr && @@ -1052,42 +1063,45 @@ class OneDnnPostRewriteVisitor : public DfsHloRewriteVisitor { std::unique_ptr threadpool_device_; }; -#define EMIT_GET_BACKEND_CONFIG_SPECIALIZATION(GETTER, PRIM_DESC, CONFIG, \ - FIELD) \ - template <> \ - inline bool OneDnnPostRewriteVisitor::GETTER(HloInstruction * \ - custom_call) { \ - auto backend_config = custom_call->backend_config(); \ - return backend_config.ok() ? 
backend_config->CONFIG().FIELD() : false; \ +#define EMIT_GET_BACKEND_CONFIG_SPECIALIZATION(GETTER, PRIM_DESC, CONFIG, \ + SUB_CONFIG, FIELD) \ + template <> \ + inline bool OneDnnPostRewriteVisitor::GETTER(HloInstruction * \ + custom_call) { \ + auto backend_config = custom_call->backend_config(); \ + return backend_config.ok() ? backend_config->CONFIG().SUB_CONFIG().FIELD() \ + : false; \ } EMIT_GET_BACKEND_CONFIG_SPECIALIZATION(GetUserScratch, dnnl::matmul::primitive_desc, - onednn_matmul_config, user_scratchpad); + onednn_matmul_config, + optimization_config, user_scratchpad); EMIT_GET_BACKEND_CONFIG_SPECIALIZATION(GetWeightsPrepack, dnnl::matmul::primitive_desc, - onednn_matmul_config, weights_prepacked); + onednn_matmul_config, + optimization_config, weights_prepacked); #define EMIT_SET_BACKEND_CONFIG_SPECIALIZATION(SETTER, PRIM_DESC, CONFIG_TYPE, \ - CONFIG, FIELD) \ + CONFIG, SUB_CONFIG, FIELD) \ template <> \ inline absl::Status OneDnnPostRewriteVisitor::SETTER( \ HloInstruction * custom_call, bool value) { \ TF_ASSIGN_OR_RETURN(auto backend_config, \ custom_call->backend_config()); \ CONFIG_TYPE* config = backend_config.mutable_##CONFIG(); \ - config->set_##FIELD(value); \ + config->mutable_##SUB_CONFIG()->set_##FIELD(value); \ return custom_call->set_backend_config(backend_config); \ } EMIT_SET_BACKEND_CONFIG_SPECIALIZATION(SetWeightsPrepack, dnnl::matmul::primitive_desc, OneDnnMatMulConfig, onednn_matmul_config, - weights_prepacked); + optimization_config, weights_prepacked); EMIT_SET_BACKEND_CONFIG_SPECIALIZATION(SetUserScratch, dnnl::matmul::primitive_desc, OneDnnMatMulConfig, onednn_matmul_config, - user_scratchpad); + optimization_config, user_scratchpad); absl::StatusOr OneDnnMatMulRewriter::Run( HloModule* module, diff --git a/third_party/xla/xla/service/cpu/onednn_ops_rewriter.cc b/third_party/xla/xla/service/cpu/onednn_ops_rewriter.cc index 0a54d70c69b18b..f95bf13880000a 100644 --- a/third_party/xla/xla/service/cpu/onednn_ops_rewriter.cc +++ 
b/third_party/xla/xla/service/cpu/onednn_ops_rewriter.cc @@ -499,9 +499,9 @@ class OneDnnOpsRewriterVisitor : public DfsHloRewriteVisitor { src_shape, {src, scale_operand, bias_operand}, "__onednn$layernorm")); BackendConfig backend_config; - OneDnnLayerNormConfig* ln_config = + OneDnnNormConfig* ln_config = backend_config.mutable_onednn_layer_norm_config(); - ln_config->set_fused_ops(OneDnnLayerNormConfig::SCALE_AND_SHIFT); + ln_config->set_rescale(OneDnnNormConfig::SCALE_AND_SHIFT); ln_config->set_epsilon_typecast(*(reinterpret_cast(&eps))); TF_RETURN_IF_ERROR(ln_call->set_backend_config(backend_config)); diff --git a/third_party/xla/xla/service/cpu/simple_orc_jit.cc b/third_party/xla/xla/service/cpu/simple_orc_jit.cc index 21ce0280906c1e..715d5ef1ad48cf 100644 --- a/third_party/xla/xla/service/cpu/simple_orc_jit.cc +++ b/third_party/xla/xla/service/cpu/simple_orc_jit.cc @@ -70,6 +70,7 @@ limitations under the License. #include "tsl/platform/logging.h" #if defined(INTEL_MKL) && defined(ENABLE_ONEDNN_V3) +#include "xla/service/cpu/onednn_convolution.h" #include "xla/service/cpu/onednn_layer_norm.h" #include "xla/service/cpu/onednn_matmul.h" #include "xla/service/cpu/onednn_softmax.h" @@ -547,6 +548,7 @@ bool RegisterKnownJITSymbols() { REGISTER_CPU_RUNTIME_SYMBOL(OneDnnMatMul); REGISTER_CPU_RUNTIME_SYMBOL(OneDnnSoftmax); REGISTER_CPU_RUNTIME_SYMBOL(OneDnnLayerNorm); + REGISTER_CPU_RUNTIME_SYMBOL(OneDnnConvolution); REGISTER_CPU_RUNTIME_SYMBOL(OneDnnMatMulReorder); #endif // INTEL_MKL && ENABLE_ONEDNN_V3 diff --git a/third_party/xla/xla/tests/BUILD b/third_party/xla/xla/tests/BUILD index 2bdfe02297d678..c10e7c04725339 100644 --- a/third_party/xla/xla/tests/BUILD +++ b/third_party/xla/xla/tests/BUILD @@ -2985,6 +2985,26 @@ xla_test( ], ) +xla_test( + name = "onednn_convolution_test", + srcs = ["onednn_convolution_test.cc"], + backends = [ + "cpu", + ], + copts = tsl_copts(), + deps = [ + ":hlo_test_base", + ":test_macros_header", + ":xla_internal_test_main", + 
"//xla:literal", + "//xla:shape_util", + "//xla:test", + "//xla:test_helpers", + "//xla/hlo/utils:hlo_matchers", + "@local_tsl//tsl/platform:platform_port", + ], +) + xla_test( name = "onednn_layer_norm_test", srcs = ["onednn_layer_norm_test.cc"], diff --git a/third_party/xla/xla/tests/onednn_convolution_test.cc b/third_party/xla/xla/tests/onednn_convolution_test.cc new file mode 100644 index 00000000000000..9d606ae44fa93b --- /dev/null +++ b/third_party/xla/xla/tests/onednn_convolution_test.cc @@ -0,0 +1,87 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#if defined(INTEL_MKL) && defined(ENABLE_ONEDNN_V3) + +#include + +#include "xla/hlo/utils/hlo_matchers.h" +#include "xla/literal.h" +#include "xla/service/cpu/onednn_matmul_rewriter.h" +#include "xla/service/cpu/onednn_util.h" +#include "xla/shape_util.h" +#include "xla/test.h" +#include "xla/test_helpers.h" +#include "xla/tests/filecheck.h" +#include "xla/tests/hlo_test_base.h" +#include "xla/tests/test_macros.h" +#include "tsl/platform/cpu_info.h" + +namespace xla { +namespace cpu { + +class ConvolutionTest : public HloTestBase { + protected: + const char* conv_rewrite_str_ = R"( + ; CHECK: custom_call_target="__onednn$convolution", + ; CHECK: backend_config={ + ; CHECK-DAG: "outer_dimension_partitions":[], + ; CHECK-DAG: "onednn_conv_config":{ + ; CHECK-DAG: } + ; CHECK: } + )"; +}; + +TEST_F(ConvolutionTest, Simple2DTestF32) { + const char* convolution_module_str = R"( + HloModule convolution.test.f32 + + ENTRY convolution.test.f32 { + arg.0 = f32[1,22,22,1] parameter(0), parameter_replication={false} + reshape.0 = f32[1,22,22,1] reshape(arg.0) + arg.1 = f32[8,8,1,1] parameter(1), parameter_replication={false} + reshape.1 = f32[8,8,1,1] reshape(arg.1) + convolution.0 = f32[1,11,11,1] convolution(reshape.0, reshape.1), window={size=8x8 stride=2x2 pad=3_3x3_3}, dim_labels=b01f_01io->b01f + reshape.2 = f32[1,11,11,1] reshape(convolution.0) + tuple.0 = (f32[1,11,11,1]) tuple(reshape.2) + ROOT get-tuple-element.0 = f32[1,11,11,1] get-tuple-element(tuple.0), index=0 + })"; + + EXPECT_TRUE(RunAndCompare(convolution_module_str, ErrorSpec{1e-4, 1e-4})); + MatchOptimizedHlo(convolution_module_str, conv_rewrite_str_); +} + +TEST_F(ConvolutionTest, Simple3DTestBF16) { + if (!IsSupportedType(PrimitiveType::BF16)) { + GTEST_SKIP() << "CPU does not support BF16."; + } + + const char* convolution_module_str = R"( + HloModule convolution.test.bf16 + + ENTRY convolution.test.bf16 { + p0 = 
bf16[8,4,5,5,1] parameter(0) + p1 = bf16[3,3,3,1,32] parameter(1) + ROOT conv = bf16[8,4,5,5,32] convolution(p0, p1), window={size=3x3x3 pad=1_1x1_1x1_1}, dim_labels=b012f_012io->b012f +})"; + + EXPECT_TRUE(RunAndCompare(convolution_module_str, ErrorSpec{1e-4, 1e-4})); + MatchOptimizedHlo(convolution_module_str, conv_rewrite_str_); +} + +} // namespace cpu +} // namespace xla + +#endif // INTEL_MKL && ENABLE_ONEDNN_V3 diff --git a/third_party/xla/xla/tests/onednn_layer_norm_test.cc b/third_party/xla/xla/tests/onednn_layer_norm_test.cc index 5645822f9e1368..9751e207b5e5da 100644 --- a/third_party/xla/xla/tests/onednn_layer_norm_test.cc +++ b/third_party/xla/xla/tests/onednn_layer_norm_test.cc @@ -29,7 +29,7 @@ class LayerNormTest : public HloTestBase { ; CHECK: custom_call_target="__onednn$layernorm", ; CHECK: backend_config={ ; CHECK-DAG: "onednn_layer_norm_config":{ - ; CHECK-DAG: "fused_ops":"SCALE_AND_SHIFT" + ; CHECK-DAG: "rescale":"SCALE_AND_SHIFT" ; CHECK-DAG: } ; CHECK: } )"; diff --git a/third_party/xla/xla/tests/onednn_matmul_test.cc b/third_party/xla/xla/tests/onednn_matmul_test.cc index 463bac0313ba94..0c250c14276213 100644 --- a/third_party/xla/xla/tests/onednn_matmul_test.cc +++ b/third_party/xla/xla/tests/onednn_matmul_test.cc @@ -41,7 +41,9 @@ class MatmulTest : public HloTestBase { ; CHECK: backend_config={ ; CHECK-DAG: "outer_dimension_partitions":[], ; CHECK-DAG: "onednn_matmul_config":{ - ; CHECK-DAG: "fused_ops":["BIAS"] + ; CHECK-DAG: "fusions":{ + ; CHECK-DAG: "ops":["BIAS"] + ; CHECK-DAG: } ; CHECK-DAG: } ; CHECK: } )"; @@ -50,7 +52,9 @@ class MatmulTest : public HloTestBase { ; CHECK: backend_config={ ; CHECK-DAG: "outer_dimension_partitions":[], ; CHECK-DAG: "onednn_matmul_config":{ - ; CHECK-DAG: "fused_ops":["BINARY_ADD"] + ; CHECK-DAG: "fusions":{ + ; CHECK-DAG: "ops":["BINARY_ADD"] + ; CHECK-DAG: } ; CHECK-DAG: } ; CHECK: } )"; @@ -59,7 +63,6 @@ class MatmulTest : public HloTestBase { ; CHECK: backend_config={ ; CHECK-DAG: 
"outer_dimension_partitions":[], ; CHECK-DAG: "onednn_matmul_config":{ - ; CHECK-DAG: "fused_ops":[] ; CHECK-DAG: } ; CHECK: } )"; @@ -68,7 +71,9 @@ class MatmulTest : public HloTestBase { ; CHECK: backend_config={ ; CHECK-DAG: "outer_dimension_partitions":[], ; CHECK-DAG: "onednn_matmul_config":{ - ; CHECK-DAG: "fused_ops":["BIAS","GELU_TANH"] + ; CHECK-DAG: "fusions":{ + ; CHECK-DAG: "ops":["BIAS","GELU_TANH"] + ; CHECK-DAG: } ; CHECK-DAG: } ; CHECK: } )"; @@ -77,7 +82,9 @@ class MatmulTest : public HloTestBase { ; CHECK: backend_config={ ; CHECK-DAG: "outer_dimension_partitions":[], ; CHECK-DAG: "onednn_matmul_config":{ - ; CHECK-DAG: "fused_ops":["BIAS","GELU_ERF"] + ; CHECK-DAG: "fusions":{ + ; CHECK-DAG: "ops":["BIAS","GELU_ERF"] + ; CHECK-DAG: } ; CHECK-DAG: } ; CHECK: } )"; @@ -86,7 +93,9 @@ class MatmulTest : public HloTestBase { ; CHECK: backend_config={ ; CHECK-DAG: "outer_dimension_partitions":[], ; CHECK-DAG: "onednn_matmul_config":{ - ; CHECK-DAG: "fused_ops":["BIAS","ELU"] + ; CHECK-DAG: "fusions":{ + ; CHECK-DAG: "ops":["BIAS","ELU"] + ; CHECK-DAG: } ; CHECK-DAG: } ; CHECK: } )"; @@ -95,7 +104,9 @@ class MatmulTest : public HloTestBase { ; CHECK: backend_config={ ; CHECK-DAG: "outer_dimension_partitions":[], ; CHECK-DAG: "onednn_matmul_config":{ - ; CHECK-DAG: "fused_ops":["BIAS","TANH"] + ; CHECK-DAG: "fusions":{ + ; CHECK-DAG: "ops":["BIAS","TANH"] + ; CHECK-DAG: } ; CHECK-DAG: } ; CHECK: } )"; @@ -104,7 +115,9 @@ class MatmulTest : public HloTestBase { ; CHECK: backend_config={ ; CHECK-DAG: "outer_dimension_partitions":[], ; CHECK-DAG: "onednn_matmul_config":{ - ; CHECK-DAG: "fused_ops":["BIAS","RELU6"] + ; CHECK-DAG: "fusions":{ + ; CHECK-DAG: "ops":["BIAS","RELU6"] + ; CHECK-DAG: } ; CHECK-DAG: } ; CHECK: } )"; @@ -113,8 +126,9 @@ class MatmulTest : public HloTestBase { ; CHECK: backend_config={ ; CHECK-DAG: "outer_dimension_partitions":[], ; CHECK-DAG: "onednn_matmul_config":{ - ; CHECK-DAG: "fused_ops":["BIAS","SIGMOID"] - ; CHECK-DAG: } + ; 
CHECK-DAG: "fusions":{ + ; CHECK-DAG: "ops":["BIAS","SIGMOID"] + ; CHECK-DAG: } ; CHECK: } )"; }; @@ -406,7 +420,9 @@ TEST_F(MatmulTest, ApproxGELUTestF32) { ; CHECK: backend_config={ ; CHECK-DAG: "outer_dimension_partitions":[], ; CHECK-DAG: "onednn_matmul_config":{ - ; CHECK-DAG: "fused_ops":["GELU_TANH"] + ; CHECK-DAG: "fusions":{ + ; CHECK-DAG: "ops":["GELU_TANH"] + ; CHECK-DAG: } ; CHECK-DAG: } ; CHECK: } )"); @@ -597,7 +613,9 @@ TEST_F(MatmulTest, ExactGELUTestF32) { ; CHECK: backend_config={ ; CHECK-DAG: "outer_dimension_partitions":[], ; CHECK-DAG: "onednn_matmul_config":{ - ; CHECK-DAG: "fused_ops":["GELU_ERF"] + ; CHECK-DAG: "fusions":{ + ; CHECK-DAG: "ops":["GELU_ERF"] + ; CHECK-DAG: } ; CHECK-DAG: } ; CHECK: } )"); @@ -810,7 +828,9 @@ TEST_F(MatmulTest, ReLUTestF32) { ; CHECK: backend_config={ ; CHECK-DAG: "outer_dimension_partitions":[], ; CHECK-DAG: "onednn_matmul_config":{ - ; CHECK-DAG: "fused_ops":["RELU"] + ; CHECK-DAG: "fusions":{ + ; CHECK-DAG: "ops":["RELU"] + ; CHECK-DAG: } ; CHECK-DAG: } ; CHECK: } )"); @@ -888,7 +908,9 @@ TEST_F(MatmulTest, DivisionByConstantWithEltwiseLinearF32) { ; CHECK: backend_config={ ; CHECK-DAG: "outer_dimension_partitions":[], ; CHECK-DAG: "onednn_matmul_config":{ - ; CHECK-DAG: "fused_ops":["LINEAR"] + ; CHECK-DAG: "fusions":{ + ; CHECK-DAG: "ops":["LINEAR"] + ; CHECK-DAG: } ; CHECK-DAG: } ; CHECK: } )"); @@ -1264,7 +1286,9 @@ TEST_F(MatmulTest, SimpleTestF32WithMulAndAddFusion) { ; CHECK: backend_config={ ; CHECK-DAG: "outer_dimension_partitions":[], ; CHECK-DAG: "onednn_matmul_config":{ - ; CHECK-DAG: "fused_ops":["LINEAR","BINARY_ADD"] + ; CHECK-DAG: "fusions":{ + ; CHECK-DAG: "ops":["LINEAR","BINARY_ADD"] + ; CHECK-DAG: } ; CHECK-DAG: } ; CHECK: } )"); @@ -1437,7 +1461,9 @@ TEST_F(MatmulTest, SimpleTestBF16WithMulAndAddFusion) { ; CHECK: backend_config={ ; CHECK-DAG: "outer_dimension_partitions":[], ; CHECK-DAG: "onednn_matmul_config":{ - ; CHECK-DAG: "fused_ops":["LINEAR","BINARY_ADD"] + ; CHECK-DAG: 
"fusions":{ + ; CHECK-DAG: "ops":["LINEAR","BINARY_ADD"] + ; CHECK-DAG: } ; CHECK-DAG: } ; CHECK: } )"); From 88980c0dd2e1cc6c7c84f86e131720e7d11ff40e Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sun, 23 Jun 2024 17:52:39 -0700 Subject: [PATCH 157/256] Automated Code Change PiperOrigin-RevId: 645929520 --- tensorflow/compiler/mlir/tfr/BUILD | 20 +++++++++++++--- .../tfr/integration/graph_decompose_pass.cc | 11 +++++++-- .../tfr/integration/graph_decompose_pass.h | 7 ++++++ .../tfr/integration/node_expansion_pass.cc | 8 +++++++ .../tfr/integration/node_expansion_pass.h | 3 +++ .../mlir/tfr/integration/tfr_decompose_ctx.cc | 24 +++++++++++++++---- .../mlir/tfr/integration/tfr_decompose_ctx.h | 3 +++ .../tfr/integration/tfr_decompose_ctx_test.cc | 7 +----- 8 files changed, 68 insertions(+), 15 deletions(-) diff --git a/tensorflow/compiler/mlir/tfr/BUILD b/tensorflow/compiler/mlir/tfr/BUILD index aa4a2160295682..0cffc86dbc8b18 100644 --- a/tensorflow/compiler/mlir/tfr/BUILD +++ b/tensorflow/compiler/mlir/tfr/BUILD @@ -226,23 +226,27 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:convert_attr", "//tensorflow/compiler/mlir/tensorflow:convert_type", - "//tensorflow/compiler/mlir/tensorflow:export_graphdef", "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags", "//tensorflow/compiler/mlir/tensorflow/transforms:tensorflow_passes", "//tensorflow/compiler/mlir/tf2xla/api/v2:tf_executor_to_graph", + "//tensorflow/core:framework", + "//tensorflow/core:framework_types_hdr", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:FuncExtensions", "@llvm-project//mlir:IR", "@llvm-project//mlir:Parser", "@llvm-project//mlir:Pass", 
"@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:ShapeDialect", - "@llvm-project//mlir:Transforms", + "@llvm-project//mlir:Support", ], ) @@ -258,7 +262,6 @@ tf_cc_test( "//tensorflow/core:test_main", "@com_google_absl//absl/types:span", "@llvm-project//mlir:IR", - "@local_tsl//tsl/platform:statusor", "@local_xla//xla:test", ], ) @@ -270,8 +273,14 @@ cc_library( deps = [ ":tfr_decompose_ctx", "//tensorflow/compiler/mlir:mlir_graph_optimization_pass", + "//tensorflow/core:core_cpu_base", + "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core/common_runtime:device_set", + "//tensorflow/core/protobuf:for_core_protos_cc", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", ], alwayslink = 1, @@ -306,8 +315,13 @@ cc_library( deps = [ ":tfr_decompose_ctx", "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", "//tensorflow/core/common_runtime/eager:core_no_xla", + "//tensorflow/core/common_runtime/eager:custom_device", "//tensorflow/core/common_runtime/eager:eager_op_rewrite_registry", + "//tensorflow/core/common_runtime/eager:eager_operation", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", ], alwayslink = 1, diff --git a/tensorflow/compiler/mlir/tfr/integration/graph_decompose_pass.cc b/tensorflow/compiler/mlir/tfr/integration/graph_decompose_pass.cc index 5324c425fb3a0a..46a6cc4e80c52f 100644 --- a/tensorflow/compiler/mlir/tfr/integration/graph_decompose_pass.cc +++ b/tensorflow/compiler/mlir/tfr/integration/graph_decompose_pass.cc @@ -14,11 +14,18 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/compiler/mlir/tfr/integration/graph_decompose_pass.h" -#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "tensorflow/compiler/mlir/mlir_graph_optimization_pass.h" #include "tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.h" +#include "tensorflow/core/common_runtime/device_set.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/graph/graph.h" #include "tensorflow/core/lib/monitoring/counter.h" -#include "tsl/platform/statusor.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/protobuf/config.pb.h" +#include "tsl/platform/errors.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/mlir/tfr/integration/graph_decompose_pass.h b/tensorflow/compiler/mlir/tfr/integration/graph_decompose_pass.h index 65f9f28bf1a33f..efde075691ba87 100644 --- a/tensorflow/compiler/mlir/tfr/integration/graph_decompose_pass.h +++ b/tensorflow/compiler/mlir/tfr/integration/graph_decompose_pass.h @@ -17,9 +17,16 @@ limitations under the License. 
#include +#include "llvm/ADT/StringRef.h" +#include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "tensorflow/compiler/mlir/mlir_graph_optimization_pass.h" #include "tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.h" +#include "tensorflow/core/common_runtime/device_set.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/protobuf/config.pb.h" #include "tsl/platform/statusor.h" namespace tensorflow { diff --git a/tensorflow/compiler/mlir/tfr/integration/node_expansion_pass.cc b/tensorflow/compiler/mlir/tfr/integration/node_expansion_pass.cc index 939ecc682e73a5..fec69d490f4282 100644 --- a/tensorflow/compiler/mlir/tfr/integration/node_expansion_pass.cc +++ b/tensorflow/compiler/mlir/tfr/integration/node_expansion_pass.cc @@ -16,11 +16,19 @@ limitations under the License. #include +#include "absl/log/log.h" +#include "absl/status/status.h" #include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.h" +#include "tensorflow/core/common_runtime/eager/custom_device.h" +#include "tensorflow/core/common_runtime/eager/eager_op_rewrite_registry.h" +#include "tensorflow/core/common_runtime/eager/eager_operation.h" +#include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/lib/monitoring/counter.h" +#include "tensorflow/core/platform/fingerprint.h" #include "tensorflow/core/platform/status.h" +#include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" namespace tensorflow { diff --git a/tensorflow/compiler/mlir/tfr/integration/node_expansion_pass.h b/tensorflow/compiler/mlir/tfr/integration/node_expansion_pass.h index 58ea34ee1e1bc2..477e5d1d3fa85b 100644 --- a/tensorflow/compiler/mlir/tfr/integration/node_expansion_pass.h +++ b/tensorflow/compiler/mlir/tfr/integration/node_expansion_pass.h 
@@ -17,6 +17,9 @@ limitations under the License. #include "tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.h" #include "tensorflow/core/common_runtime/eager/eager_op_rewrite_registry.h" +#include "tensorflow/core/common_runtime/eager/eager_operation.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/types.h" #include "tsl/platform/statusor.h" namespace tensorflow { diff --git a/tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.cc b/tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.cc index a53aa0542c368a..a297112c84ded8 100644 --- a/tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.cc +++ b/tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.cc @@ -17,6 +17,8 @@ limitations under the License. #include #include +#include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" @@ -24,32 +26,46 @@ limitations under the License. 
#include "llvm/Support/SMLoc.h" #include "llvm/Support/SourceMgr.h" #include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/Dialect/Func/Extensions/AllExtensions.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/Dialect/SCF/IR/SCF.h" // from @llvm-project #include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/Location.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OperationSupport.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project #include "mlir/IR/Verifier.h" // from @llvm-project #include "mlir/Parser/Parser.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project -#include "mlir/Transforms/Passes.h" // from @llvm-project -#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h" +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" -#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" #include "tensorflow/compiler/mlir/tensorflow/utils/convert_attr.h" #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h" #include "tensorflow/compiler/mlir/tf2xla/api/v2/tf_executor_to_graph.h" #include "tensorflow/compiler/mlir/tfr/ir/tfr_ops.h" #include "tensorflow/compiler/mlir/tfr/passes/passes.h" +#include 
"tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/path.h" +#include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/stringpiece.h" +#include "tensorflow/core/platform/types.h" #include "tensorflow/core/util/env_var.h" +#include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" namespace tensorflow { diff --git a/tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.h b/tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.h index f32e8edd02c138..5e13c49d8bf840 100644 --- a/tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.h +++ b/tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.h @@ -15,12 +15,15 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_MLIR_TFR_INTEGRATION_TFR_DECOMPOSE_CTX_H_ #define TENSORFLOW_COMPILER_MLIR_TFR_INTEGRATION_TFR_DECOMPOSE_CTX_H_ +#include "absl/status/statusor.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" #include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/stringpiece.h" #include "tsl/platform/statusor.h" namespace tensorflow { diff --git a/tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx_test.cc b/tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx_test.cc index 430328eec1f9fe..f13a4755a700ca 100644 --- a/tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx_test.cc +++ b/tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx_test.cc @@ -18,9 +18,6 @@ limitations under the License. #include #include "absl/types/span.h" -#include "mlir/IR/Builders.h" // from @llvm-project -#include "mlir/IR/BuiltinTypes.h" // from @llvm-project -#include "mlir/IR/Dialect.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "xla/test.h" #include "tensorflow/core/framework/attr_value.pb.h" @@ -28,11 +25,9 @@ limitations under the License. 
#include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/node_def_builder.h" -#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/types.pb.h" -#include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" -#include "tsl/platform/statusor.h" using testing::ElementsAreArray; using testing::Test; From 24d85f377ee31a7e6b79a79d4ff0fb1f0cfa6371 Mon Sep 17 00:00:00 2001 From: Quentin Khan Date: Sun, 23 Jun 2024 19:46:55 -0700 Subject: [PATCH 158/256] Add missing `const` qualifier in `tflite::Subgraph`. PiperOrigin-RevId: 645944358 --- tensorflow/lite/core/subgraph.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/lite/core/subgraph.h b/tensorflow/lite/core/subgraph.h index 281ac04adc2096..160f9c644e387a 100644 --- a/tensorflow/lite/core/subgraph.h +++ b/tensorflow/lite/core/subgraph.h @@ -592,7 +592,7 @@ class Subgraph { // Returns true if the subgraph has been fully delegated. 
bool IsFullyDelegated() const; - const std::unordered_map& GetTensorBufferIdentifiers() { + const std::unordered_map& GetTensorBufferIdentifiers() const { return tensor_buffer_identifiers_; } From 4e518efe4f8212692d4814b52c7271422822314b Mon Sep 17 00:00:00 2001 From: Leo Heinsaar Date: Mon, 24 Jun 2024 00:08:26 -0700 Subject: [PATCH 159/256] [xla:cpu] Add benchmark for op `gather` PiperOrigin-RevId: 645988622 --- .../xla/xla/service/cpu/benchmarks/BUILD | 17 +++ .../cpu/benchmarks/gather_benchmark_test.cc | 136 ++++++++++++++++++ 2 files changed, 153 insertions(+) create mode 100644 third_party/xla/xla/service/cpu/benchmarks/gather_benchmark_test.cc diff --git a/third_party/xla/xla/service/cpu/benchmarks/BUILD b/third_party/xla/xla/service/cpu/benchmarks/BUILD index 52a4ec2fea2f0d..0a6e049b3a8572 100644 --- a/third_party/xla/xla/service/cpu/benchmarks/BUILD +++ b/third_party/xla/xla/service/cpu/benchmarks/BUILD @@ -147,3 +147,20 @@ xla_cc_test( "@local_tsl//tsl/platform:test_main", ], ) + +xla_cc_test( + name = "gather_benchmark_test", + srcs = ["gather_benchmark_test.cc"], + deps = [ + ":hlo_benchmark_runner", + "//xla:array2d", + "//xla:literal", + "//xla:literal_util", + "//xla:shape_util", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/strings", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:test_benchmark", + "@local_tsl//tsl/platform:test_main", + ], +) diff --git a/third_party/xla/xla/service/cpu/benchmarks/gather_benchmark_test.cc b/third_party/xla/xla/service/cpu/benchmarks/gather_benchmark_test.cc new file mode 100644 index 00000000000000..ca6ec89fa35fc1 --- /dev/null +++ b/third_party/xla/xla/service/cpu/benchmarks/gather_benchmark_test.cc @@ -0,0 +1,136 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/strings/str_cat.h" +#include "xla/array2d.h" +#include "xla/literal.h" +#include "xla/literal_util.h" +#include "xla/service/cpu/benchmarks/hlo_benchmark_runner.h" +#include "xla/shape_util.h" +#include "tsl/platform/logging.h" +#include "tsl/platform/test_benchmark.h" + +namespace xla::cpu { + +static void BM_GatherS32(benchmark::State& state) { + int64_t d0 = state.range(0); + int64_t d1 = state.range(1); + int64_t slice_size = state.range(2); + + std::string_view hlo = R"( + HloModule gather_s32_d$d0_d$d1_s$slice_size + + ENTRY e { + operand = s32[$d0,$d1] parameter(0) + indices = s32[$slice_size, 1] parameter(1) + ROOT gather = s32[$slice_size, $d1] gather(operand, indices), + offset_dims={1}, + collapsed_slice_dims={0}, + start_index_map={0}, + index_vector_dim=1, + slice_sizes={1, $d1} + } + )"; + + std::minstd_rand0 engine; + + auto operand_shape = ShapeUtil::MakeShape(S32, {d0, d1}); + auto indices_shape = ShapeUtil::MakeShape(S32, {slice_size, 1}); + auto operand = *LiteralUtil::CreateRandomLiteral( + operand_shape, &engine, /*mean=*/50, /*stddev=*/10); + + // Generate random indices to be used in the gather + std::vector random_indices(slice_size); + std::uniform_int_distribution dist(0, d0 - 1); + absl::c_generate(random_indices, [&]() { return dist(engine); }); + + // Transform the indices into a 2D array - as expected by the gather op + Array2D indices_2d(slice_size, 1); + for (int 
i = 0; i < slice_size; ++i) { + indices_2d(i, 0) = random_indices[i]; + } + auto indices = LiteralUtil::CreateR2FromArray2D(indices_2d); + + std::vector args = {&operand, &indices}; + CHECK_OK(RunHloBenchmark(state, hlo, args, + {{"$d0", absl::StrCat(d0)}, + {"$d1", absl::StrCat(d1)}, + {"$slice_size", absl::StrCat(slice_size)}})); +} + +BENCHMARK(BM_GatherS32) + ->MeasureProcessCPUTime() + ->Args({3, 3, 1}) + ->Args({3, 3, 2}) + ->Args({3, 3, 4}) + ->Args({3, 32, 1}) + ->Args({3, 32, 2}) + ->Args({3, 32, 8}) + ->Args({3, 64, 1}) + ->Args({3, 64, 2}) + ->Args({3, 64, 16}) + ->Args({3, 128, 1}) + ->Args({3, 128, 2}) + ->Args({3, 128, 32}) + ->Args({3, 256, 1}) + ->Args({3, 256, 2}) + ->Args({3, 256, 64}) + ->Args({3, 512, 1}) + ->Args({3, 512, 2}) + ->Args({3, 512, 128}) + ->Args({10, 3, 1}) + ->Args({10, 3, 2}) + ->Args({10, 3, 4}) + ->Args({10, 32, 1}) + ->Args({10, 32, 2}) + ->Args({10, 32, 8}) + ->Args({10, 64, 1}) + ->Args({10, 64, 2}) + ->Args({10, 64, 16}) + ->Args({10, 128, 1}) + ->Args({10, 128, 2}) + ->Args({10, 128, 32}) + ->Args({10, 256, 1}) + ->Args({10, 256, 2}) + ->Args({10, 256, 64}) + ->Args({10, 512, 1}) + ->Args({10, 512, 2}) + ->Args({10, 512, 128}) + ->Args({100, 3, 1}) + ->Args({100, 3, 2}) + ->Args({100, 3, 4}) + ->Args({100, 32, 1}) + ->Args({100, 32, 2}) + ->Args({100, 32, 8}) + ->Args({100, 64, 1}) + ->Args({100, 64, 2}) + ->Args({100, 64, 16}) + ->Args({100, 128, 1}) + ->Args({100, 128, 2}) + ->Args({100, 128, 32}) + ->Args({100, 256, 1}) + ->Args({100, 256, 2}) + ->Args({100, 256, 64}) + ->Args({100, 512, 1}) + ->Args({100, 512, 2}) + ->Args({100, 512, 128}); + +} // namespace xla::cpu From e5fc01001e110f6d67695e039278854349f5d853 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 24 Jun 2024 02:03:48 -0700 Subject: [PATCH 160/256] Update GraphDef version to 1903. 
PiperOrigin-RevId: 646015314 --- tensorflow/core/public/version.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h index 755f92897906c3..56c889c5865f6e 100644 --- a/tensorflow/core/public/version.h +++ b/tensorflow/core/public/version.h @@ -108,7 +108,7 @@ limitations under the License. #define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0 -#define TF_GRAPH_DEF_VERSION 1902 // Updated: 2024/6/23 +#define TF_GRAPH_DEF_VERSION 1903 // Updated: 2024/6/24 // Checkpoint compatibility versions (the versions field in SavedSliceMeta). // From ab8ac1fd535917738d3cbd83e142892d30fbbbbc Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 24 Jun 2024 02:03:48 -0700 Subject: [PATCH 161/256] compat: Update forward compatibility horizon to 2024-06-24 PiperOrigin-RevId: 646015316 --- tensorflow/python/compat/compat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py index 82244cce6f8bf7..beb6522ab43b7e 100644 --- a/tensorflow/python/compat/compat.py +++ b/tensorflow/python/compat/compat.py @@ -29,7 +29,7 @@ # This value changes every day with an automatic CL. It can be modified in code # via `forward_compatibility_horizon()` or with the environment variable # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date. -_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2024, 6, 23) +_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2024, 6, 24) _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS" _FORWARD_COMPATIBILITY_DATE_NUMBER = None From 84587690fb77206957eb3584d0b3870347d665b8 Mon Sep 17 00:00:00 2001 From: Christian Sigg Date: Mon, 24 Jun 2024 03:18:53 -0700 Subject: [PATCH 162/256] Disable Zapfhahn for tests that time out. 
PiperOrigin-RevId: 646031586 --- tensorflow/compiler/tests/BUILD | 1 + third_party/xla/xla/service/gpu/BUILD | 1 + 2 files changed, 2 insertions(+) diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index a0906bfd7b0a59..44061a2bfc6977 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -1565,6 +1565,7 @@ tf_xla_py_strict_test( shard_count = 10, tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + "nozapfhahn", # Times out under coverage "optonly", ], deps = [ diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index 203468c5175ed8..9e873c5182b2cf 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -735,6 +735,7 @@ xla_test( "large", "no_oss", # requires-mem:16g tag doesn't work in open source "nomac", + "nozapfhahn", # Times out under coverage "requires-mem:16g", ], deps = [ From 94728ca510e84d028598ce900d4f80f189f2f3d9 Mon Sep 17 00:00:00 2001 From: "Dimitar (Mitko) Asenov" Date: Mon, 24 Jun 2024 03:51:09 -0700 Subject: [PATCH 163/256] [XLA:GPU] Remove unused function in `triton_support_test` PiperOrigin-RevId: 646038400 --- third_party/xla/xla/service/gpu/triton_support_test.cc | 5 ----- 1 file changed, 5 deletions(-) diff --git a/third_party/xla/xla/service/gpu/triton_support_test.cc b/third_party/xla/xla/service/gpu/triton_support_test.cc index 6ee50a9ce28c7e..72478795aa2bb8 100644 --- a/third_party/xla/xla/service/gpu/triton_support_test.cc +++ b/third_party/xla/xla/service/gpu/triton_support_test.cc @@ -48,11 +48,6 @@ namespace { using ::testing::Not; using ::testing::status::IsOk; -se::GpuComputeCapability GetComputeCapability() { - // TODO(b/348572380) Make this more general and test additional platforms. 
- return se::CudaComputeCapability::Ampere(); -} - auto AllXlaDataTypes() { std::vector xla_data_types; std::vector to_filter_out = {PRIMITIVE_TYPE_INVALID, From b6744034b0da32b9df3ab4e320b9c4f1dd0d49fd Mon Sep 17 00:00:00 2001 From: Benjamin Kramer Date: Mon, 24 Jun 2024 06:05:38 -0700 Subject: [PATCH 164/256] Integrate LLVM at llvm/llvm-project@e5a41f0afc15 Updates LLVM usage to match [e5a41f0afc15](https://github.com/llvm/llvm-project/commit/e5a41f0afc15) PiperOrigin-RevId: 646067665 --- third_party/llvm/workspace.bzl | 4 ++-- third_party/xla/xla/python/ifrt/ir/ifrt_dialect.td | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl index 7d5e50e6a89f12..9f6d4c6f1a30a5 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" - LLVM_COMMIT = "c07be08df5731dac0b36e029a0dd03ccb099deea" - LLVM_SHA256 = "9f18d8c90a81c966f819bbfd911baed7fc67e019f6b5de1af4bcdd6bd1fb87bf" + LLVM_COMMIT = "e5a41f0afc152cc24b8fef3aa177ef53b2e77c43" + LLVM_SHA256 = "088d6f89839fcdba456f9ab887a43c05cd1224823535ecadf801865677e1d146" tf_http_archive( name = name, diff --git a/third_party/xla/xla/python/ifrt/ir/ifrt_dialect.td b/third_party/xla/xla/python/ifrt/ir/ifrt_dialect.td index d4ddbad8d43712..e46ff35490c2a1 100644 --- a/third_party/xla/xla/python/ifrt/ir/ifrt_dialect.td +++ b/third_party/xla/xla/python/ifrt/ir/ifrt_dialect.td @@ -161,7 +161,7 @@ def Ifrt_ArrayType : TypeDef { let summary = "An Ifrt array sharded on a set of devices."; let parameters = (ins - Builtin_RankedTensor:$shape, + "::mlir::RankedTensorType":$shape, "::xla::ifrt::IfrtShardingAttrInterface":$sharding_attr, Ifrt_DevicesAttr:$devices_attr, OptionalParameter<"::mlir::StringAttr">:$memory_kind_attr); From 347de0e06593efa11bb0b884d92e33bd7d7d25ce Mon Sep 17 00:00:00 2001 From: Benjamin Kramer Date: Mon, 24 Jun 2024 
08:29:06 -0700 Subject: [PATCH 165/256] Integrate LLVM at llvm/llvm-project@5cd0ba30f53d Updates LLVM usage to match [5cd0ba30f53d](https://github.com/llvm/llvm-project/commit/5cd0ba30f53d) PiperOrigin-RevId: 646103303 --- third_party/llvm/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl index 9f6d4c6f1a30a5..72267d070e6b9a 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" - LLVM_COMMIT = "e5a41f0afc152cc24b8fef3aa177ef53b2e77c43" - LLVM_SHA256 = "088d6f89839fcdba456f9ab887a43c05cd1224823535ecadf801865677e1d146" + LLVM_COMMIT = "5cd0ba30f53d11835dbfd05ad4071d397387fb04" + LLVM_SHA256 = "cdf76f8704646b105ca671ab9fd4fbeb856d1af964f177e3740b8ce362631af4" tf_http_archive( name = name, From a4065439f0f713a82b29f7bef6c7e90f3fd5ae1c Mon Sep 17 00:00:00 2001 From: Mohammed Anany Date: Mon, 24 Jun 2024 08:48:01 -0700 Subject: [PATCH 166/256] [XLA:GPU][NFC] Updating xla_gpu_enable_triton_hopper description to reflect the current behavior. 
PiperOrigin-RevId: 646108882 --- third_party/xla/xla/debug_options_flags.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/xla/xla/debug_options_flags.cc b/third_party/xla/xla/debug_options_flags.cc index df05be304cf1a1..9cfce6dd46b0c8 100644 --- a/third_party/xla/xla/debug_options_flags.cc +++ b/third_party/xla/xla/debug_options_flags.cc @@ -1645,7 +1645,7 @@ void MakeDebugOptionsFlags(std::vector* flag_list, "xla_gpu_enable_triton_hopper", bool_setter_for(&DebugOptions::set_xla_gpu_enable_triton_hopper), debug_options->xla_gpu_enable_triton_hopper(), - "Enable Hopper-specific optimizations such as MMA_V3 and pipelining.")); + "Currently used to enable MMA_V3 for Hopper in Triton")); flag_list->push_back(tsl::Flag( "xla_gpu_enable_libnvptxcompiler", [debug_options](bool enabled) { From a2291251e3b9b7cae26666aedf0557d059241afc Mon Sep 17 00:00:00 2001 From: Alexander Pivovarov Date: Mon, 24 Jun 2024 09:02:28 -0700 Subject: [PATCH 167/256] PR #14075: Add pow(A, 0.5) => sqrt(A), A >= 0 to algsimp Imported from GitHub PR https://github.com/openxla/xla/pull/14075 This PR adds pattern `pow(A, 0.5) => sqrt(A), A >= 0` Validation - the following checks are valid both before and after simplification.: ```c If x == 0 the result is 0 If x > 0 the result > 0 If x is inf the result is inf If x is nan the result is nan ``` Copybara import of the project: -- 214b16a491f8f257f2f15576e5e20d512696592d by Alexander Pivovarov : Add pow(A, 0.5) => sqrt(A), A >= 0 to algsimp Merging this change closes #14075 PiperOrigin-RevId: 646113258 --- .../xla/xla/service/algebraic_simplifier.cc | 21 ++++++++++++ .../xla/service/algebraic_simplifier_test.cc | 34 +++++++++++++++++++ 2 files changed, 55 insertions(+) diff --git a/third_party/xla/xla/service/algebraic_simplifier.cc b/third_party/xla/xla/service/algebraic_simplifier.cc index c9b9850a94d2a8..9df5c33d095261 100644 --- a/third_party/xla/xla/service/algebraic_simplifier.cc +++ 
b/third_party/xla/xla/service/algebraic_simplifier.cc @@ -94,6 +94,19 @@ bool IsAll(const HloInstruction* op, int8_t value) { } } +// Unwraps broadcasts hunting for a constant. If we find one, checks if the +// constant contains only the given value. +bool IsAllFloat(const HloInstruction* op, float value) { + switch (op->opcode()) { + case HloOpcode::kBroadcast: + return IsAllFloat(op->operand(0), value); + case HloOpcode::kConstant: + return op->literal().IsAllFloat(value); + default: + return false; + } +} + bool IsAll(const HloInstruction* op, const Literal& scalar) { CHECK(ShapeUtil::IsScalar(scalar.shape())); switch (op->opcode()) { @@ -5449,6 +5462,14 @@ absl::Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power) { MakeScalarLike(lhs, 1), lhs)); } + VLOG(10) << "trying transform [pow(A, 0.5) => sqrt(A)], for A >= 0: " + << power->ToString(); + if (IsAllFloat(rhs, 0.5) && IsNonNegative(lhs, options_)) { + return ReplaceWithNewInstruction( + power, + HloInstruction::CreateUnary(power->shape(), HloOpcode::kSqrt, lhs)); + } + return absl::OkStatus(); } diff --git a/third_party/xla/xla/service/algebraic_simplifier_test.cc b/third_party/xla/xla/service/algebraic_simplifier_test.cc index c52400b831b6c5..9b2fd29d36e723 100644 --- a/third_party/xla/xla/service/algebraic_simplifier_test.cc +++ b/third_party/xla/xla/service/algebraic_simplifier_test.cc @@ -2166,6 +2166,40 @@ TEST_F(AlgebraicSimplifierTest, PowNegative1) { EXPECT_EQ(root->operand(0)->literal().GetFirstElement(), 1); } +// pow(A, 0.5) => sqrt(A), for A >= 0 +TEST_F(AlgebraicSimplifierTest, PowHalf) { + const char* kModuleStr = R"( + HloModule m + test { + p0 = f32[1,32] parameter(0) + c0 = f32[] constant(0.5) + br0 = f32[1,32] broadcast(f32[] c0), dimensions={} + abs0 = f32[1,32] abs(p0) + ROOT pow = f32[1,32] power(abs0, br0) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).value()); + 
ASSERT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::Sqrt(m::Abs(m::Parameter(0))))); +} + +// pow(A, 0.5) ≠> sqrt(A) +// if A is arbitrary number - no simplification +TEST_F(AlgebraicSimplifierTest, PowHalf_NegativeTestCase) { + const char* kModuleStr = R"( + HloModule m + test { + p0 = f32[1,32] parameter(0) + c0 = f32[] constant(0.5) + br0 = f32[1,32] broadcast(f32[] c0), dimensions={} + ROOT pow = f32[1,32] power(p0, br0) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_FALSE(AlgebraicSimplifier(default_options_).Run(m.get()).value()); +} + TEST_F(AlgebraicSimplifierTest, ZeroSizedConvolution) { auto m = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); From ba299c3345899a42805922482944c1db79ce02c9 Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Mon, 24 Jun 2024 09:20:43 -0700 Subject: [PATCH 168/256] [XLA:GPU] Add pretty printing util for `ConstraintExpression`. PiperOrigin-RevId: 646119174 --- .../xla/service/gpu/model/symbolic_tile.cc | 69 +++++++++++-------- .../xla/xla/service/gpu/model/symbolic_tile.h | 5 ++ .../service/gpu/model/symbolic_tile_test.cc | 19 +++++ 3 files changed, 66 insertions(+), 27 deletions(-) diff --git a/third_party/xla/xla/service/gpu/model/symbolic_tile.cc b/third_party/xla/xla/service/gpu/model/symbolic_tile.cc index 0b68868bf5c214..038e83f08cd55c 100644 --- a/third_party/xla/xla/service/gpu/model/symbolic_tile.cc +++ b/third_party/xla/xla/service/gpu/model/symbolic_tile.cc @@ -751,6 +751,46 @@ void ConstraintExpression::And( disjoint_conjoint_constraints_ = std::move(new_constraints); } +std::string ConstraintExpression::ToString( + const AffineMapPrinter& printer) const { + std::stringstream ss; + Print(ss, printer); + return ss.str(); +} + +void ConstraintExpression::Print(std::ostream& out, + const AffineMapPrinter& printer) const { + if (IsAlwaysSatisfied()) { + out << "always satisfied"; + return; + } + + if 
(is_satisfiable()) { + // Accumulate constraints in a vector in order to put them in lexicographic + // order and to get deterministic output. + std::vector conjunction_strings; + conjunction_strings.reserve(disjoint_conjoint_constraints_.size()); + for (const auto& disjunction : disjoint_conjoint_constraints_) { + std::vector constraint_strings; + constraint_strings.reserve(disjunction.size()); + for (const auto& [expr, interval] : disjunction) { + std::stringstream ss; + printer.Print(ss, expr); + ss << " in "; + interval.Print(ss); + constraint_strings.push_back(ss.str()); + } + std::sort(constraint_strings.begin(), constraint_strings.end()); + conjunction_strings.push_back(absl::StrJoin(constraint_strings, " && ")); + } + std::sort(conjunction_strings.begin(), conjunction_strings.end()); + out << absl::StrJoin(conjunction_strings, " || "); + } else if (!is_satisfiable()) { + out << "unsatisfiable"; + } + out << "\n"; +} + /*static*/ std::optional SymbolicTile::FromIndexingMap( IndexingMap indexing_map) { VLOG(1) << "SymbolicTile::FromIndexingMap: " << indexing_map.ToString(); @@ -912,35 +952,10 @@ void SymbolicTile::Print(std::ostream& out, /*first_rt_var_symbol_index=*/tile_map_.GetDimensionCount(), out, printer); } - if (is_satisfiable() && !constraints_.IsAlwaysSatisfied()) { + if (!constraints_.IsAlwaysSatisfied()) { out << "\n\tconstraints: "; - absl::Span conjunctions = - constraints_.DisjointConjointConstraints(); - // Accumulate constraints in a vector in order to put them in lexicographic - // order and to get deterministic output. 
- std::vector conjunction_strings; - conjunction_strings.reserve(conjunctions.size()); - for (const auto& disjunction : conjunctions) { - std::vector constraint_strings; - constraint_strings.reserve(disjunction.size()); - for (const auto& [expr, interval] : disjunction) { - std::stringstream ss; - printer.Print(ss, expr); - ss << " in "; - interval.Print(ss); - constraint_strings.push_back(ss.str()); - } - std::sort(constraint_strings.begin(), constraint_strings.end()); - conjunction_strings.push_back( - absl::StrJoin(constraint_strings, " &&\n\t")); - } - std::sort(conjunction_strings.begin(), conjunction_strings.end()); - out << "\n\t" << absl::StrJoin(conjunction_strings, "\n||\n\t"); - } else if (!is_satisfiable()) { - out << "\n\tconstraints: "; - out << "\n\tunsatisfiable"; + constraints_.Print(out, printer); } - out << "\n"; } namespace { diff --git a/third_party/xla/xla/service/gpu/model/symbolic_tile.h b/third_party/xla/xla/service/gpu/model/symbolic_tile.h index d348f95f67ecdc..47b4084ae84ae8 100644 --- a/third_party/xla/xla/service/gpu/model/symbolic_tile.h +++ b/third_party/xla/xla/service/gpu/model/symbolic_tile.h @@ -90,6 +90,11 @@ class ConstraintExpression { return disjoint_conjoint_constraints_; } + std::string ToString( + const AffineMapPrinter& printer = AffineMapPrinter()) const; + + void Print(std::ostream& out, const AffineMapPrinter& printer) const; + // TODO(bchetioui): add a util to verify constraints here later. // TODO(bchetioui): is canonicalization of disjunctions necessary? 
private: diff --git a/third_party/xla/xla/service/gpu/model/symbolic_tile_test.cc b/third_party/xla/xla/service/gpu/model/symbolic_tile_test.cc index e60a8fd7157bd8..393b3305786ffb 100644 --- a/third_party/xla/xla/service/gpu/model/symbolic_tile_test.cc +++ b/third_party/xla/xla/service/gpu/model/symbolic_tile_test.cc @@ -809,6 +809,25 @@ TEST_F(ConstraintExpressionTest, EXPECT_TRUE(ConstraintExpression().IsAlwaysSatisfied()); } +TEST_F(ConstraintExpressionTest, PrettyPrintingTest) { + EXPECT_TRUE( + ApproximateMatch(ConstraintExpression().ToString(), "always satisfied")); + EXPECT_TRUE(ApproximateMatch( + ConstraintExpression::GetUnsatisfiableConstraintExpression().ToString(), + "unsatisfiable")); + + ConjointConstraints conjunction_1 = + GetConjointConstraints({{"d0", Interval{0, 5}}, {"d1", Interval{0, 5}}}); + ConjointConstraints conjunction_2 = + GetConjointConstraints({{"d2", Interval{0, 5}}}); + + ConstraintExpression constraints; + constraints.Or(std::move(conjunction_1)); + constraints.Or(std::move(conjunction_2)); + EXPECT_TRUE(ApproximateMatch(constraints.ToString(), + "d0 in [0, 5] && d1 in [0, 5] || d2 in [0, 5]")); +} + TEST_F(ConstraintExpressionTest, UnsatisfiableConstraintExpressionHoldsNoConstraint) { ConstraintExpression unsatisfiable_constraint = From 1e0d90e566051bbcd2f1a350ce91f2832643a6fa Mon Sep 17 00:00:00 2001 From: Kuy Mainwaring Date: Mon, 24 Jun 2024 10:09:03 -0700 Subject: [PATCH 169/256] [XLA:GPU] Clang-tidy cleanup for xla/service/bitcast_dtypes_expander.h PiperOrigin-RevId: 646135550 --- third_party/xla/xla/service/BUILD | 2 +- third_party/xla/xla/service/bitcast_dtypes_expander.h | 9 ++++++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/third_party/xla/xla/service/BUILD b/third_party/xla/xla/service/BUILD index 5f490b4d833ff4..f78cd4e18038db 100644 --- a/third_party/xla/xla/service/BUILD +++ b/third_party/xla/xla/service/BUILD @@ -2855,7 +2855,6 @@ cc_library( srcs = ["bitcast_dtypes_expander.cc"], hdrs = 
["bitcast_dtypes_expander.h"], deps = [ - ":hlo_pass", ":op_expander_pass", "//xla:literal_util", "//xla:shape_util", @@ -2867,6 +2866,7 @@ cc_library( "//xla/client/lib:constants", "//xla/hlo/ir:hlo", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@local_tsl//tsl/platform:logging", diff --git a/third_party/xla/xla/service/bitcast_dtypes_expander.h b/third_party/xla/xla/service/bitcast_dtypes_expander.h index 6aff287d411fdc..f103c37878a603 100644 --- a/third_party/xla/xla/service/bitcast_dtypes_expander.h +++ b/third_party/xla/xla/service/bitcast_dtypes_expander.h @@ -13,10 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + +#include "absl/container/flat_hash_map.h" #include "absl/status/statusor.h" -#include "xla/hlo/ir/hlo_instructions.h" -#include "xla/hlo/ir/hlo_module.h" -#include "xla/service/hlo_pass_interface.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" #include "xla/service/op_expander_pass.h" #ifndef XLA_SERVICE_BITCAST_DTYPES_EXPANDER_H_ From e8345e2753bf4b2907a81bd29f32e5b819b99de8 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 24 Jun 2024 10:15:15 -0700 Subject: [PATCH 170/256] Add a new target for versioning that includes warnings. 
This removes flatbuffer_exports.cc dependency on tflite schema_fbs_version PiperOrigin-RevId: 646137639 --- tensorflow/compiler/mlir/lite/BUILD | 12 ++++++++- tensorflow/compiler/mlir/lite/build_def.bzl | 21 ++++++++++++++++ .../compiler/mlir/lite/flatbuffer_export.cc | 2 +- tensorflow/compiler/mlir/lite/version.h | 25 +++++++++++++++++++ 4 files changed, 58 insertions(+), 2 deletions(-) create mode 100644 tensorflow/compiler/mlir/lite/version.h diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD index e1e4712276fe28..974f767940924f 100644 --- a/tensorflow/compiler/mlir/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/BUILD @@ -4,6 +4,7 @@ load("@bazel_skylib//rules:common_settings.bzl", "bool_flag") load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") load("//tensorflow:tensorflow.bzl", "if_google", "if_oss", "tf_cc_binary", "tf_cc_test", "tf_native_cc_binary") load("//tensorflow:tensorflow.default.bzl", "filegroup", "get_compatible_with_portable") +load("//tensorflow/compiler/mlir/lite:build_def.bzl", "tflite_copts_warnings") load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") package( @@ -1084,6 +1085,7 @@ cc_library( deps = [ ":convert_type", ":flatbuffer_tflite_operator_lib", + ":lite_version", ":low_bit_utils", ":stateful_ops_utils", ":string_utils", @@ -1103,7 +1105,6 @@ cc_library( "//tensorflow/core:framework", "//tensorflow/core:portable_gif_internal", "//tensorflow/core:protos_all_cc", - "//tensorflow/lite:schema_fbs_version", "//tensorflow/lite/core:framework", "//tensorflow/lite/core/c:private_common", "//tensorflow/lite/delegates/flex:allowlisted_flex_ops_lib", @@ -1599,3 +1600,12 @@ selects.config_setting_group( ], ) # LINT.ThenChange(//tensorflow/lite/BUILD) + +# LINT.IfChange(version) +cc_library( + name = "lite_version", + hdrs = ["version.h"], + compatible_with = get_compatible_with_portable(), + copts = tflite_copts_warnings(), +) +# LINT.ThenChange(//tensorflow/lite:version) diff 
--git a/tensorflow/compiler/mlir/lite/build_def.bzl b/tensorflow/compiler/mlir/lite/build_def.bzl index 15a7df2c68852d..19a2dd6d557114 100644 --- a/tensorflow/compiler/mlir/lite/build_def.bzl +++ b/tensorflow/compiler/mlir/lite/build_def.bzl @@ -66,3 +66,24 @@ def tflite_copts(): return copts + tflite_copts_extra() # LINT.ThenChange(//tensorflow/lite/build_def.bzl:tflite_copts) + +# LINT.IfChange(tflite_copts_warnings) +def tflite_copts_warnings(): + """Defines common warning flags used primarily by internal TFLite libraries.""" + + # TODO(b/155906820): Include with `tflite_copts()` after validating clients. + + return select({ + clean_dep("//tensorflow:windows"): [ + # We run into trouble on Windows toolchains with warning flags, + # as mentioned in the comments below on each flag. + # We could be more aggressive in enabling supported warnings on each + # Windows toolchain, but we compromise with keeping BUILD files simple + # by limiting the number of config_setting's. + ], + "//conditions:default": [ + "-Wall", + ], + }) + +# LINT.ThenChange(//tensorflow/lite/build_def.bzl:tflite_copts_warnings) diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc index d1868519766686..6250c60680e867 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc @@ -93,6 +93,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/utils/low_bit_utils.h" #include "tensorflow/compiler/mlir/lite/utils/stateful_ops_utils.h" #include "tensorflow/compiler/mlir/lite/utils/string_utils.h" +#include "tensorflow/compiler/mlir/lite/version.h" #include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" @@ -117,7 +118,6 @@ limitations under the License. 
#include "tensorflow/lite/tools/versioning/gpu_compatibility.h" #include "tensorflow/lite/tools/versioning/op_version.h" #include "tensorflow/lite/tools/versioning/runtime_version.h" -#include "tensorflow/lite/version.h" #include "tsl/platform/fingerprint.h" #include "tsl/platform/status.h" #include "tsl/platform/tstring.h" diff --git a/tensorflow/compiler/mlir/lite/version.h b/tensorflow/compiler/mlir/lite/version.h new file mode 100644 index 00000000000000..321bd39568721e --- /dev/null +++ b/tensorflow/compiler/mlir/lite/version.h @@ -0,0 +1,25 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_VERSION_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_VERSION_H_ + +// LINT.IfChange(tflite_schema_version) +// The version number of the Schema. Ideally all changes will be backward +// compatible. If that ever changes, we must ensure that version is the first +// entry in the new tflite root so that we can see that version is not 1. +#define TFLITE_SCHEMA_VERSION (3) +// LINT.ThenChange(//tensorflow/lite/version.h) + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_VERSION_H_ From e1df4c39060913cd49f2927cb302b79bfbf93ef5 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 24 Jun 2024 10:21:02 -0700 Subject: [PATCH 171/256] Fix bytes accessed for paramter feeding both nested fusion and other trivial user. 
The assumption here is that the nested fusion and other trivial user are unable to share reads of the parameter. PiperOrigin-RevId: 646139659 --- .../xla/xla/service/hlo_cost_analysis.cc | 29 ++++++++------- .../xla/xla/service/hlo_cost_analysis_test.cc | 36 ++++++++++++++++++- 2 files changed, 51 insertions(+), 14 deletions(-) diff --git a/third_party/xla/xla/service/hlo_cost_analysis.cc b/third_party/xla/xla/service/hlo_cost_analysis.cc index 5382929fe23c55..16597b4769d04f 100644 --- a/third_party/xla/xla/service/hlo_cost_analysis.cc +++ b/third_party/xla/xla/service/hlo_cost_analysis.cc @@ -174,17 +174,15 @@ int64_t HloCostAnalysis::GetShapeSize(const Shape& shape) const { int64_t HloCostAnalysis::FusionParameterReadBytes( const HloInstruction* hlo) const { - int64_t size = 0; - bool seen_trivial_user = false; CHECK(hlo->IsFused() && (hlo->opcode() == HloOpcode::kParameter || hlo->opcode() == HloOpcode::kGetTupleElement)); auto handle_slice = [this](const HloInstruction* hlo, const HloInstruction* user) -> int64_t { return GetShapeSize(user->shape()); }; - auto handle_dynamic_slice = [&seen_trivial_user, this]( - const HloInstruction* hlo, - const HloInstruction* user) -> int64_t { + auto handle_dynamic_slice = [this](const HloInstruction* hlo, + const HloInstruction* user, + bool& seen_trivial_user) -> int64_t { if (hlo == user->operand(0)) { return GetShapeSize(user->shape()); } @@ -195,8 +193,8 @@ int64_t HloCostAnalysis::FusionParameterReadBytes( return 0; }; auto handle_dynamic_update_slice = - [&seen_trivial_user, this](const HloInstruction* hlo, - const HloInstruction* user) -> int64_t { + [this](const HloInstruction* hlo, const HloInstruction* user, + bool& seen_trivial_user) -> int64_t { // Operand 0 is aliased to the output. 
if (hlo != user->operand(0) && !seen_trivial_user) { seen_trivial_user = true; @@ -204,10 +202,13 @@ int64_t HloCostAnalysis::FusionParameterReadBytes( } return 0; }; + int64_t size = 0; + bool seen_trivial_user = false; for (const HloInstruction* user : hlo->users()) { switch (user->opcode()) { case HloOpcode::kFusion: { for (int64_t idx : user->OperandIndices(hlo)) { + bool nested_seen_trivial_user = false; const auto& fusion_users = user->users(); const HloInstruction* root_instruction = user->fused_instructions_computation()->root_instruction(); @@ -222,12 +223,14 @@ int64_t HloCostAnalysis::FusionParameterReadBytes( size += handle_slice(user, fusion_user); } else if (fusion_is_simple && fusion_user->opcode() == HloOpcode::kDynamicSlice) { - size += handle_dynamic_slice(user, fusion_user); + size += handle_dynamic_slice(user, fusion_user, + nested_seen_trivial_user); } else if (fusion_is_simple && fusion_user->opcode() == HloOpcode::kDynamicUpdateSlice) { - size += handle_dynamic_update_slice(user, fusion_user); - } else if (!seen_trivial_user) { - seen_trivial_user = true; + size += handle_dynamic_update_slice(user, fusion_user, + nested_seen_trivial_user); + } else if (!nested_seen_trivial_user) { + nested_seen_trivial_user = true; size += FusionParameterReadBytes(user->fused_parameter(idx)); } } @@ -238,10 +241,10 @@ int64_t HloCostAnalysis::FusionParameterReadBytes( size += handle_slice(hlo, user); break; case HloOpcode::kDynamicSlice: - size += handle_dynamic_slice(hlo, user); + size += handle_dynamic_slice(hlo, user, seen_trivial_user); break; case HloOpcode::kDynamicUpdateSlice: - size += handle_dynamic_update_slice(hlo, user); + size += handle_dynamic_update_slice(hlo, user, seen_trivial_user); break; case HloOpcode::kBroadcast: case HloOpcode::kReshape: diff --git a/third_party/xla/xla/service/hlo_cost_analysis_test.cc b/third_party/xla/xla/service/hlo_cost_analysis_test.cc index ac73f9b43cc5be..74c6e158f834de 100644 --- 
a/third_party/xla/xla/service/hlo_cost_analysis_test.cc +++ b/third_party/xla/xla/service/hlo_cost_analysis_test.cc @@ -962,7 +962,7 @@ HloModule temp, is_scheduled=true fused_computation.1 { tmp_0 = bf16[64,16,512,512]{2,3,1,0:T(8,128)(2,1)} parameter(0) - tmp_1 = bf16[64,16,512,512]{2,3,1,0:T(8,128)(2,1)} fusion(bf16[64,16,512,512]{2,3,1,0:T(8,128)(2,1)} tmp_0), kind=kOutput, calls= + tmp_1 = bf16[64,16,512,512]{2,3,1,0:T(8,128)(2,1)} fusion(bf16[64,16,512,512]{2,3,1,0:T(8,128)(2,1)} tmp_0), kind=kLoop, calls= { tmp_0 = bf16[64,16,512,512]{2,3,1,0:T(8,128)(2,1)} parameter(0) ROOT tmp_4 = bf16[64,16,512,512]{2,3,1,0:T(8,128)(2,1)} add(bf16[64,16,512,512]{2,3,1,0:T(8,128)(2,1)} tmp_0, bf16[64,16,512,512]{2,3,1,0:T(8,128)(2,1)} tmp_0) @@ -990,6 +990,40 @@ ENTRY temp { EXPECT_EQ(1073741824, fusion_analysis.bytes_accessed(*fusion_root)); } +TEST_F(FusionCostAnalysis, ParamFeedsNestedFusionAndTrivialUser) { + absl::string_view hlo_text = R"( +HloModule temp, is_scheduled=true + +fused_computation.1 { + tmp_0 = bf16[64,16,512,512]{2,3,1,0:T(8,128)(2,1)} parameter(0) + tmp_1 = bf16[64,16,512,512]{2,3,1,0:T(8,128)(2,1)} fusion(bf16[64,16,512,512]{2,3,1,0:T(8,128)(2,1)} tmp_0), kind=kLoop, calls= + { + tmp_0 = bf16[64,16,512,512]{2,3,1,0:T(8,128)(2,1)} parameter(0) + ROOT tmp_4 = bf16[64,16,512,512]{2,3,1,0:T(8,128)(2,1)} add(bf16[64,16,512,512]{2,3,1,0:T(8,128)(2,1)} tmp_0, bf16[64,16,512,512]{2,3,1,0:T(8,128)(2,1)} tmp_0) + } + tmp_2 = bf16[]{:T(256)} constant(0) + tmp_3 = bf16[64,16,512,512]{2,3,1,0:T(8,128)(2,1)} reduce-window(bf16[64,16,512,512]{2,3,1,0:T(8,128)(2,1)} tmp_1, bf16[]{:T(256)} tmp_2), window={size=1x1x1x1023 pad=0_0x0_0x0_0x511_511}, to_apply= + { + tmp_0 = bf16[]{:T(256)} parameter(0) + tmp_1 = bf16[]{:T(256)} parameter(1) + ROOT tmp_2 = bf16[]{:T(256)} add(bf16[]{:T(256)} tmp_0, bf16[]{:T(256)} tmp_1) + } + ROOT tmp_4 = bf16[64,16,512,512]{2,3,1,0:T(8,128)(2,1)} divide(bf16[64,16,512,512]{2,3,1,0:T(8,128)(2,1)} tmp_0, 
bf16[64,16,512,512]{2,3,1,0:T(8,128)(2,1)} tmp_3) +} + +ENTRY temp { + tmp_0 = bf16[64,16,512,512]{2,3,1,0:T(8,128)(2,1)} parameter(0) + ROOT result = bf16[64,16,512,512]{2,3,1,0:T(8,128)(2,1)} fusion(tmp_0), kind=kLoop, calls=fused_computation.1 +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto fusion_module, + ParseAndReturnVerifiedModule(hlo_text)); + HloCostAnalysis fusion_analysis(ShapeSize); + auto* fusion_root = fusion_module->entry_computation()->root_instruction(); + ASSERT_IS_OK(fusion_root->Accept(&fusion_analysis)); + EXPECT_EQ(1610612736, fusion_analysis.bytes_accessed(*fusion_root)); +} + TEST_F(FusionCostAnalysis, LoopFusionTupleOutput) { Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 2}); From 5cc913d1a8dfc99e6703bf34b14f96057ca7c606 Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Mon, 24 Jun 2024 10:45:08 -0700 Subject: [PATCH 172/256] [XLA:GPU] Add test for rendering a `ConstraintExpression` unsatisfiable by `And`ing it with incompatible constraints. Also introduce the `MatchConstraintExpressionString` matcher for convenience. 
PiperOrigin-RevId: 646148530 --- .../service/gpu/model/symbolic_tile_test.cc | 51 ++++++++++++++++--- 1 file changed, 43 insertions(+), 8 deletions(-) diff --git a/third_party/xla/xla/service/gpu/model/symbolic_tile_test.cc b/third_party/xla/xla/service/gpu/model/symbolic_tile_test.cc index 393b3305786ffb..2d4785aa243d67 100644 --- a/third_party/xla/xla/service/gpu/model/symbolic_tile_test.cc +++ b/third_party/xla/xla/service/gpu/model/symbolic_tile_test.cc @@ -57,6 +57,12 @@ MATCHER_P(MatchSymbolicTileString, symbolic_tile_string, "") { result_listener); } +MATCHER_P(MatchConstraintExpressionString, constraint_expression_string, "") { + return ExplainMatchResult( + true, ApproximateMatch(constraint_expression_string, arg.ToString()), + result_listener); +} + std::vector EvaluateMapAt(AffineMap affine_map, absl::Span parameters) { CHECK_EQ(affine_map.getNumSymbols(), parameters.size()); @@ -810,11 +816,10 @@ TEST_F(ConstraintExpressionTest, } TEST_F(ConstraintExpressionTest, PrettyPrintingTest) { - EXPECT_TRUE( - ApproximateMatch(ConstraintExpression().ToString(), "always satisfied")); - EXPECT_TRUE(ApproximateMatch( - ConstraintExpression::GetUnsatisfiableConstraintExpression().ToString(), - "unsatisfiable")); + EXPECT_THAT(ConstraintExpression(), + MatchConstraintExpressionString("always satisfied")); + EXPECT_THAT(ConstraintExpression::GetUnsatisfiableConstraintExpression(), + MatchConstraintExpressionString("unsatisfiable")); ConjointConstraints conjunction_1 = GetConjointConstraints({{"d0", Interval{0, 5}}, {"d1", Interval{0, 5}}}); @@ -824,7 +829,7 @@ TEST_F(ConstraintExpressionTest, PrettyPrintingTest) { ConstraintExpression constraints; constraints.Or(std::move(conjunction_1)); constraints.Or(std::move(conjunction_2)); - EXPECT_TRUE(ApproximateMatch(constraints.ToString(), + EXPECT_THAT(constraints, MatchConstraintExpressionString( "d0 in [0, 5] && d1 in [0, 5] || d2 in [0, 5]")); } @@ -856,9 +861,39 @@ TEST_F( EXPECT_THAT(conjunctions, SizeIs(1)); // There 
are three constraints in the single conjunction. EXPECT_THAT(conjunctions.front(), SizeIs(3)); +} + +TEST_F( + ConstraintExpressionTest, + CorrectlyEliminatesConjunctionFromDisjunctionWhenItBecomesUnsatisfiable) { + ConjointConstraints conjunction_1 = + GetConjointConstraints({{"d0", Interval{0, 5}}}); + ConjointConstraints conjunction_2 = + GetConjointConstraints({{"d1", Interval{0, 5}}}); + + ConstraintExpression constraints; + constraints.Or(std::move(conjunction_1)); + constraints.Or(std::move(conjunction_2)); + EXPECT_THAT(constraints, + MatchConstraintExpressionString("d0 in [0, 5] || d1 in [0, 5]")); + + // `conjunction_1` && `conjunction_3` is an unsatisfiable constraint. Taking + // the conjunction of the existing constraint expression with `conjunction_3` + // should therefore evict the unsatisfiable intersection of `conjunction_1` + // and `conjunction_3` from the disjoint expression. + ConjointConstraints conjunction_3 = + GetConjointConstraints({{"d0", Interval{6, 6}}}); + constraints.And(std::move(conjunction_3)); - // TODO(bchetioui): add test for the case where a conjunction becomes - // unsatisfiable and thus gets eliminated from the disjoint expression. + EXPECT_THAT(constraints, + MatchConstraintExpressionString("d0 in [6, 6] && d1 in [0, 5]")); + + // But becomes unsatisfiable if we eliminate the last remaining constraint by + // constructing another unsatisfiable conjunction. + ConjointConstraints conjunction_4 = + GetConjointConstraints({{"d0", Interval{7, 7}}}); + constraints.And(std::move(conjunction_4)); + EXPECT_THAT(constraints, MatchConstraintExpressionString("unsatisfiable")); } TEST_F( From 63acb867e40000495af020ccd713ca5b8da878af Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Mon, 24 Jun 2024 10:58:20 -0700 Subject: [PATCH 173/256] Reverts 27125ab80d84e1a9f0e0d93aa5416e316c73e91d PiperOrigin-RevId: 646153363 --- tensorflow/core/common_runtime/eager/BUILD | 1 + .../eager/context_distributed_manager.cc | 18 +++--- .../core/common_runtime/gpu/gpu_device.cc | 2 +- .../saved_model/saved_model_aot_compile.cc | 4 +- third_party/xla/xla/pjrt/c/BUILD | 1 + .../xla/xla/pjrt/c/pjrt_c_api_gpu_internal.cc | 10 +++- third_party/xla/xla/pjrt/gpu/BUILD | 2 + third_party/xla/xla/pjrt/gpu/gpu_topology.cc | 3 +- third_party/xla/xla/pjrt/gpu/gpu_topology.h | 2 +- .../xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc | 38 +++++++----- .../xla/xla/pjrt/gpu/se_gpu_pjrt_client.h | 58 ++++++++----------- .../xla/pjrt/gpu/se_gpu_pjrt_compiler_test.cc | 25 +++++--- 12 files changed, 95 insertions(+), 69 deletions(-) diff --git a/tensorflow/core/common_runtime/eager/BUILD b/tensorflow/core/common_runtime/eager/BUILD index 0cd5ea00b11499..c7da536c946e30 100644 --- a/tensorflow/core/common_runtime/eager/BUILD +++ b/tensorflow/core/common_runtime/eager/BUILD @@ -274,6 +274,7 @@ tf_cuda_library( "//tensorflow/core/framework:resource_base", "@local_xla//xla/pjrt/distributed:key_value_store_interface", "@local_xla//xla/pjrt:local_device_state", + "@local_xla//xla/pjrt/gpu:gpu_topology", "@local_xla//xla/pjrt:pjrt_client", "@local_xla//xla/pjrt:pjrt_compiler", "@local_xla//xla/service/gpu:gpu_executable_run_options", diff --git a/tensorflow/core/common_runtime/eager/context_distributed_manager.cc b/tensorflow/core/common_runtime/eager/context_distributed_manager.cc index 036f1da0f31f82..d2b51ee6652580 100644 --- a/tensorflow/core/common_runtime/eager/context_distributed_manager.cc +++ b/tensorflow/core/common_runtime/eager/context_distributed_manager.cc @@ -79,6 +79,7 @@ limitations under the License. 
#if (defined(PLATFORM_GOOGLE) && defined(TF_PLATFORM_LINUX_X86_64)) #define TF_GPU_USE_PJRT #include "xla/pjrt/distributed/key_value_store_interface.h" +#include "xla/pjrt/gpu/gpu_topology.h" #include "xla/pjrt/gpu/se_gpu_pjrt_client.h" #include "xla/pjrt/local_device_state.h" #include "xla/pjrt/pjrt_compiler.h" @@ -328,17 +329,17 @@ absl::Status CreateClientOnce( // proceed. creation_state->SetReady(); } - auto status = BuildDistributedDevices( + auto device_topology_pair = BuildDistributedDevices( platform_name, std::move(unique_local_device_states), node_id, num_nodes, - &pjrt_devices, gpu_run_options.get(), kv_store, - /*enable_mock_nccl=*/false); - if (!status.ok()) { + gpu_run_options.get(), kv_store, /*enable_mock_nccl=*/false); + if (!device_topology_pair.ok()) { if (use_creation_info) { creation_state->SetDone(); } - return status; + return device_topology_pair.status(); } + pjrt_devices = std::move(device_topology_pair->first); VLOG(2) << "Distributed devices built with size=" << pjrt_devices.size(); int i = 0; for (const auto& pjrt_device : pjrt_devices) { @@ -358,10 +359,11 @@ absl::Status CreateClientOnce( /*allocator=*/std::move(info->allocator), /*host_memory_allocator=*/std::move(info->host_memory_allocator), /*should_stage_host_to_device_transfers=*/true, - /*gpu_run_options=*/std::move(gpu_run_options), kv_store); + /*gpu_run_options=*/std::move(gpu_run_options), kv_store, + xla::GpuTopology::FromProto(device_topology_pair->second)); VLOG(2) << "PJRT GPU client with remote devices created."; - status = SetPjRtClientInTFGlobalResourceManager(DeviceType(DEVICE_GPU), - std::move(pjrt_client)); + auto status = SetPjRtClientInTFGlobalResourceManager( + DeviceType(DEVICE_GPU), std::move(pjrt_client)); creation_state->SetDone(); return status; } else { diff --git a/tensorflow/core/common_runtime/gpu/gpu_device.cc b/tensorflow/core/common_runtime/gpu/gpu_device.cc index b5913bf06705f1..49d0c4dc1e6282 100644 --- 
a/tensorflow/core/common_runtime/gpu/gpu_device.cc +++ b/tensorflow/core/common_runtime/gpu/gpu_device.cc @@ -1931,7 +1931,7 @@ Status BaseGPUDeviceFactory::CreateDevices( /*host_memory_allocator=*/std::move(pjrt_gpu_host_allocator), /*should_stage_host_to_device_transfers=*/true, /*gpu_run_options=*/std::move(gpu_run_options), - /*kv_store=*/nullptr); + /*kv_store=*/nullptr, /*gpu_topology=*/nullptr); return SetPjRtClientInTFGlobalResourceManager(DeviceType(DEVICE_GPU), std::move(pjrt_client)); diff --git a/tensorflow/core/tfrt/saved_model/saved_model_aot_compile.cc b/tensorflow/core/tfrt/saved_model/saved_model_aot_compile.cc index 0ed2e8618c4d71..dbaa76925a372c 100644 --- a/tensorflow/core/tfrt/saved_model/saved_model_aot_compile.cc +++ b/tensorflow/core/tfrt/saved_model/saved_model_aot_compile.cc @@ -310,8 +310,8 @@ AotCompileToGpuPjRtExecutable( xla::Compiler::TargetConfig gpu_config(gpu_target_config); xla::StreamExecutorGpuCompiler pjrt_gpu_compiler; // Create a trivial topology, which won't be used. 
- xla::StreamExecutorGpuTopologyDescription topology( - xla::CudaId(), xla::CudaName(), "fake_device", {0}); + xla::StreamExecutorGpuTopologyDescription topology(xla::CudaId(), + xla::CudaName(), nullptr); xla::CompileOptions pjrt_options = GetPjRtCompileOptions(options, **compilation_result); pjrt_options.target_config = gpu_config; diff --git a/third_party/xla/xla/pjrt/c/BUILD b/third_party/xla/xla/pjrt/c/BUILD index 37e1003d5d0ce5..a677da30a3361c 100644 --- a/third_party/xla/xla/pjrt/c/BUILD +++ b/third_party/xla/xla/pjrt/c/BUILD @@ -269,6 +269,7 @@ cc_library( "//xla/pjrt:pjrt_device_description", "//xla/pjrt:pjrt_executable", "//xla/pjrt/gpu:gpu_helpers", + "//xla/pjrt/gpu:gpu_topology", "//xla/pjrt/gpu:se_gpu_pjrt_client", "//xla/pjrt/gpu:se_gpu_pjrt_compiler", # To register GPU AOT compiler "//xla/python:custom_partition_callback", diff --git a/third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_internal.cc b/third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_internal.cc index 93e48d9e807703..d16483166b5725 100644 --- a/third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_internal.cc +++ b/third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_internal.cc @@ -42,6 +42,7 @@ limitations under the License. #include "xla/pjrt/c/pjrt_c_api_stream_extension.h" #include "xla/pjrt/c/pjrt_c_api_wrapper_impl.h" #include "xla/pjrt/gpu/gpu_helpers.h" +#include "xla/pjrt/gpu/gpu_topology.h" #include "xla/pjrt/gpu/se_gpu_pjrt_client.h" #include "xla/pjrt/pjrt_client.h" #include "xla/pjrt/pjrt_common.h" @@ -186,9 +187,16 @@ PJRT_Error* PJRT_GpuDeviceTopology_Create( device_ids.push_back(executor->device_ordinal()); } auto gpu_target_config = xla::Compiler::TargetConfig(executor); + // TODO(b/341334898): Create a single-host GPU topology. Will be updated for + // multi-host support in the future. 
+ auto gpu_topology = std::make_shared( + device_ids, description.name(), + /*num_slices=*/1, + /*num_hosts_per_slice=*/1, + /*num_devices_per_host=*/device_ids.size()); auto pjrt_topology = std::make_unique( - xla::CudaId(), xla::CudaName(), description.name(), device_ids, + xla::CudaId(), xla::CudaName(), std::move(gpu_topology), absl::flat_hash_map{ {"target_config", gpu_target_config.ToProto().SerializeAsString()}}); diff --git a/third_party/xla/xla/pjrt/gpu/BUILD b/third_party/xla/xla/pjrt/gpu/BUILD index 217833f0fdcdde..42e5f6f3c5c8c8 100644 --- a/third_party/xla/xla/pjrt/gpu/BUILD +++ b/third_party/xla/xla/pjrt/gpu/BUILD @@ -44,6 +44,7 @@ cc_library( ":gpu_helpers", ":gpu_metrics", ":gpu_topology", + ":gpu_topology_proto_cc", "//xla:literal", "//xla:shape_util", "//xla:status_macros", @@ -298,6 +299,7 @@ xla_test( "no_oss", ] + if_google(["config-cuda-only"]), deps = [ + ":gpu_topology", ":se_gpu_pjrt_client", ":se_gpu_pjrt_compiler", "//xla:test", diff --git a/third_party/xla/xla/pjrt/gpu/gpu_topology.cc b/third_party/xla/xla/pjrt/gpu/gpu_topology.cc index e9baf5f359ba66..600adf98231fcb 100644 --- a/third_party/xla/xla/pjrt/gpu/gpu_topology.cc +++ b/third_party/xla/xla/pjrt/gpu/gpu_topology.cc @@ -16,6 +16,7 @@ limitations under the License. 
#include "xla/pjrt/gpu/gpu_topology.h" #include +#include #include namespace xla { @@ -33,7 +34,7 @@ std::unique_ptr GpuTopology::FromProto( GpuTopologyProto GpuTopology::ToProto() const { GpuTopologyProto proto; proto.mutable_device_ids()->Add(device_ids().begin(), device_ids().end()); - proto.set_platform_version(platform_version()); + proto.set_platform_version(std::string(platform_version())); proto.set_num_slices(num_slices()); proto.set_num_hosts_per_slice(num_hosts_per_slice()); proto.set_num_devices_per_host(num_devices_per_host()); diff --git a/third_party/xla/xla/pjrt/gpu/gpu_topology.h b/third_party/xla/xla/pjrt/gpu/gpu_topology.h index 5c87b4223b92e8..957a36001b0968 100644 --- a/third_party/xla/xla/pjrt/gpu/gpu_topology.h +++ b/third_party/xla/xla/pjrt/gpu/gpu_topology.h @@ -56,7 +56,7 @@ class GpuTopology { const GpuTopologyProto& proto); GpuTopologyProto ToProto() const; - std::string platform_version() const { return platform_version_; } + std::string_view platform_version() const { return platform_version_; } int32_t num_slices() const { return num_slices_; } int32_t num_hosts_per_slice() const { return num_hosts_per_slice_; } int32_t num_devices_per_host() const { return num_devices_per_host_; } diff --git a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc index 4be7dc6deb7f66..5bf6bac568779c 100644 --- a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc +++ b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc @@ -55,6 +55,8 @@ limitations under the License. 
#include "xla/pjrt/distributed/topology_util.h" #include "xla/pjrt/event_pool.h" #include "xla/pjrt/gpu/gpu_helpers.h" +#include "xla/pjrt/gpu/gpu_topology.h" +#include "xla/pjrt/gpu/gpu_topology.pb.h" #include "xla/pjrt/host_memory_spaces.h" #include "xla/pjrt/local_device_state.h" #include "xla/pjrt/pjrt_client.h" @@ -485,14 +487,15 @@ StreamExecutorGpuClient::StreamExecutorGpuClient( std::unique_ptr host_memory_allocator, bool should_stage_host_to_device_transfers, std::unique_ptr gpu_run_options, - std::shared_ptr kv_store) + std::shared_ptr kv_store, + std::shared_ptr gpu_topology) : xla::PjRtStreamExecutorClient( platform_name, client, std::move(devices), process_index, std::move(allocator), std::move(host_memory_allocator), should_stage_host_to_device_transfers, std::move(gpu_run_options)), topology_(xla::StreamExecutorGpuTopologyDescription::Create( tsl::Fingerprint64(platform_name), platform_name, - devices_.back()->device_kind(), devices_)), + std::move(gpu_topology))), kv_store_(std::move(kv_store)) { for (auto* device : addressable_devices()) { // Use the device id to construct a globally unique memory space id. 
We do @@ -948,15 +951,15 @@ GetStreamExecutorGpuDeviceAllocator( } // namespace -absl::Status BuildDistributedDevices( +absl::StatusOr BuildDistributedDevices( std::string_view platform_name, std::map> local_device_states, int node_id, int num_nodes, - std::vector>* devices, gpu::GpuExecutableRunOptions* gpu_executable_run_options, std::shared_ptr kv_store, bool enable_mock_nccl, absl::Duration get_local_topology_timeout, absl::Duration get_global_topology_timeout) { + std::vector> devices; LocalTopologyProto local_topology; local_topology.set_node_id(node_id); std::string boot_id_str; @@ -1020,7 +1023,7 @@ absl::Status BuildDistributedDevices( device_proto.name(), device_proto.vendor(), device_proto.compute_capability(), device_proto.core_count(), node.node_id(), device_proto.slice_index()); - devices->push_back(std::move(device)); + devices.push_back(std::move(device)); } } for (const auto& device : local_device_states) { @@ -1038,7 +1041,9 @@ absl::Status BuildDistributedDevices( }); } #endif // GOOGLE_CUDA - return absl::OkStatus(); + TF_ASSIGN_OR_RETURN(GpuTopologyProto gpu_topology, + BuildGpuTopology(global_topology)); + return std::make_pair(std::move(devices), gpu_topology); } std::string MakeComputeCapabilityString(const se::DeviceDescription* desc) { @@ -1170,22 +1175,27 @@ absl::StatusOr> GetStreamExecutorGpuClient( kv_store = std::make_shared(); } TF_RET_CHECK(options.num_nodes == 1 || kv_store != nullptr); - TF_RETURN_IF_ERROR(BuildDistributedDevices( - pjrt_platform_name, std::move(local_device_states), options.node_id, - options.num_nodes, &devices, gpu_run_options.get(), kv_store, - options.enable_mock_nccl)); + TF_ASSIGN_OR_RETURN( + DeviceTopologyPair device_topology_pair, + BuildDistributedDevices(pjrt_platform_name, + std::move(local_device_states), options.node_id, + options.num_nodes, gpu_run_options.get(), + kv_store, options.enable_mock_nccl)); + + auto gpu_topology = std::shared_ptr( + GpuTopology::FromProto(device_topology_pair.second)); 
return std::unique_ptr(std::make_unique( - pjrt_platform_name, xla_client, std::move(devices), options.node_id, - std::move(allocator), std::move(host_memory_allocator), + pjrt_platform_name, xla_client, std::move(device_topology_pair.first), + options.node_id, std::move(allocator), std::move(host_memory_allocator), options.should_stage_host_to_device_transfers, std::move(gpu_run_options), - std::move(kv_store))); + std::move(kv_store), std::move(gpu_topology))); } absl::StatusOr StreamExecutorGpuTopologyDescription::Serialize() const { std::string result; - if (!tsl::SerializeToStringDeterministic(gpu_topology_.ToProto(), &result)) { + if (!tsl::SerializeToStringDeterministic(gpu_topology_->ToProto(), &result)) { return absl::InternalError("Failed to serialize gpu_topology"); } return result; diff --git a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.h b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.h index db351f143ff44e..afb624b248f863 100644 --- a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.h +++ b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.h @@ -38,6 +38,7 @@ limitations under the License. 
#include "xla/pjrt/distributed/key_value_store_interface.h" #include "xla/pjrt/gpu/gpu_helpers.h" #include "xla/pjrt/gpu/gpu_topology.h" +#include "xla/pjrt/gpu/gpu_topology.pb.h" #include "xla/pjrt/local_device_state.h" #include "xla/pjrt/pjrt_client.h" #include "xla/pjrt/pjrt_compiler.h" @@ -60,36 +61,27 @@ class MultiDeviceAdapter; } namespace xla { +using DeviceTopologyPair = + std::pair>, + GpuTopologyProto>; class StreamExecutorGpuTopologyDescription : public PjRtTopologyDescription { public: static StreamExecutorGpuTopologyDescription Create( const PjRtPlatformId platform_id, const absl::string_view platform_name, - const absl::string_view platform_version, - const std::vector& devices) { - std::vector device_ids; - device_ids.reserve(devices.size()); - for (PjRtDevice* device : devices) { - device_ids.push_back(device->id()); - } + std::shared_ptr gpu_topology) { return StreamExecutorGpuTopologyDescription(platform_id, platform_name, - platform_version, device_ids); + gpu_topology); } - // `gpu_device_ids` is the list of logical device ids for the GPU devices and - // will be used to initialize the GPU topology. + StreamExecutorGpuTopologyDescription( const PjRtPlatformId platform_id, const absl::string_view platform_name, - const absl::string_view platform_version, - const std::vector& gpu_device_ids, + std::shared_ptr gpu_topology, const absl::flat_hash_map& attributes = {}) : platform_id_(platform_id), platform_name_(platform_name), - platform_version_(platform_version), - // TODO(b/331224674): Add support for multi-host. 
- gpu_topology_(gpu_device_ids, platform_version, /*num_slices=*/1, - /*num_hosts_per_slice=*/1, - /*num_devices_per_host=*/gpu_device_ids.size()), + gpu_topology_(std::move(gpu_topology)), attributes_(attributes) {} bool operator==(const StreamExecutorGpuTopologyDescription& other) const { @@ -104,39 +96,40 @@ class StreamExecutorGpuTopologyDescription : public PjRtTopologyDescription { absl::string_view platform_name() const override { return platform_name_; } absl::string_view platform_version() const override { - return platform_version_; + return gpu_topology_->platform_version(); } std::vector> DeviceDescriptions() const override { std::vector> devices; - devices.reserve(gpu_topology_.number_of_devices()); - for (const int device_id : gpu_topology_.device_ids()) { + devices.reserve(gpu_topology_->number_of_devices()); + for (const int device_id : gpu_topology_->device_ids()) { devices.push_back(std::make_unique( - device_id, platform_version_)); + device_id, std::string(platform_version()))); } return devices; } - const GpuTopology& gpu_topology() const { return gpu_topology_; } - const GpuTopology* gpu_topology_ptr() const { return &gpu_topology_; } + const GpuTopology& gpu_topology() const { return *gpu_topology_; } + const GpuTopology* gpu_topology_ptr() const { return gpu_topology_.get(); } // No subslice is supported. bool is_subslice_topology() const override { return false; } - // The topology support only single host now. 
- absl::StatusOr ProcessCount() const override { return 1; } + absl::StatusOr ProcessCount() const override { + return gpu_topology_->number_of_hosts(); + } absl::StatusOr CoreCountOfDefaultType() const override { - return gpu_topology_.number_of_devices(); + return gpu_topology_->number_of_devices(); } absl::StatusOr LogicalDeviceCountOfDefaultType() const override { - return gpu_topology_.number_of_devices(); + return gpu_topology_->number_of_devices(); } absl::StatusOr CoreCountOfDefaultTypePerProcess() const override { - return gpu_topology_.number_of_devices(); + return gpu_topology_->number_of_devices(); } absl::StatusOr CoreCountOfDefaultTypePerChip() const override { @@ -158,8 +151,7 @@ class StreamExecutorGpuTopologyDescription : public PjRtTopologyDescription { private: const PjRtPlatformId platform_id_; const std::string platform_name_; - const std::string platform_version_; - const GpuTopology gpu_topology_; + std::shared_ptr gpu_topology_; absl::flat_hash_map attributes_; }; @@ -208,7 +200,8 @@ class StreamExecutorGpuClient : public xla::PjRtStreamExecutorClient { std::unique_ptr host_memory_allocator, bool should_stage_host_to_device_transfers, std::unique_ptr gpu_run_options, - std::shared_ptr kv_store); + std::shared_ptr kv_store, + std::shared_ptr gpu_topology); absl::StatusOr GetDefaultDeviceAssignment( int num_replicas, int num_partitions) const override; @@ -265,11 +258,10 @@ std::vector> BuildLocalDevices( std::string MakeComputeCapabilityString(const se::DeviceDescription* desc); -absl::Status BuildDistributedDevices( +absl::StatusOr BuildDistributedDevices( std::string_view platform_name, std::map> local_device_states, int node_id, int num_nodes, - std::vector>* devices, gpu::GpuExecutableRunOptions* gpu_executable_run_options, std::shared_ptr kv_store, bool enable_mock_nccl, absl::Duration get_local_topology_timeout = absl::Minutes(2), diff --git a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler_test.cc 
b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler_test.cc index 70a47c2400c433..4a94f14ab162a0 100644 --- a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler_test.cc +++ b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler_test.cc @@ -24,6 +24,7 @@ limitations under the License. #include "mlir/Parser/Parser.h" // from @llvm-project #include "xla/client/xla_computation.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" +#include "xla/pjrt/gpu/gpu_topology.h" #include "xla/pjrt/gpu/se_gpu_pjrt_client.h" #include "xla/pjrt/pjrt_client.h" #include "xla/pjrt/pjrt_compiler.h" @@ -60,10 +61,18 @@ absl::StatusOr GetXlaComputation( return XlaComputation(hlo_module->ToProto()); } +std::shared_ptr GetGpuTopology( + std::vector device_ids, absl::string_view platform_version, + int num_slices, int num_hosts_per_slice, int num_devices_per_host) { + return std::make_shared(device_ids, platform_version, + num_slices, num_hosts_per_slice, + num_devices_per_host); +} + TEST(StreamExecutorGpuCompilerTest, NoClientXla) { StreamExecutorGpuCompiler compiler; - StreamExecutorGpuTopologyDescription topology(CudaId(), CudaName(), - "Fake_device", {0, 1}); + StreamExecutorGpuTopologyDescription topology( + CudaId(), CudaName(), GetGpuTopology({0, 1}, "Fake_device", 1, 1, 2)); TF_ASSERT_OK_AND_ASSIGN(auto computation, GetXlaComputation(kProgram)); EXPECT_THAT(compiler.Compile(xla::CompileOptions(), computation, topology, @@ -73,8 +82,8 @@ TEST(StreamExecutorGpuCompilerTest, NoClientXla) { TEST(StreamExecutorGpuCompilerTest, TopologyNotSameXla) { StreamExecutorGpuCompiler compiler; - StreamExecutorGpuTopologyDescription topology(CudaId(), CudaName(), - "Fake_device", {0, 1}); + StreamExecutorGpuTopologyDescription topology( + CudaId(), CudaName(), GetGpuTopology({0, 1}, "Fake_device", 1, 1, 2)); TF_ASSERT_OK_AND_ASSIGN(auto client, GetStreamExecutorGpuClient(GpuClientOptions())); @@ -119,8 +128,8 @@ TEST(StreamExecutorGpuCompilerTest, NoClientMlir) { auto mlir_module = 
mlir::parseSourceString(mlir_str, &context); - StreamExecutorGpuTopologyDescription topology(CudaId(), CudaName(), - "Fake_device", {0, 1}); + StreamExecutorGpuTopologyDescription topology( + CudaId(), CudaName(), GetGpuTopology({0, 1}, "Fake_device", 1, 1, 2)); EXPECT_THAT( compiler.Compile(xla::CompileOptions(), mlir_module.get(), topology, @@ -137,8 +146,8 @@ TEST(StreamExecutorGpuCompilerTest, TopologyNotSameMlir) { auto mlir_module = mlir::parseSourceString(mlir_str, &context); - StreamExecutorGpuTopologyDescription topology(CudaId(), CudaName(), - "Fake_device", {0, 1}); + StreamExecutorGpuTopologyDescription topology( + CudaId(), CudaName(), GetGpuTopology({0, 1}, "Fake_device", 1, 1, 2)); TF_ASSERT_OK_AND_ASSIGN(auto client, GetStreamExecutorGpuClient(GpuClientOptions())); From 20fa4d95b5b8339a7d1a7e817b636332060dcc4a Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 24 Jun 2024 11:34:48 -0700 Subject: [PATCH 174/256] Remove unused enable_propagation_on_dots option in SpaceToBatchConverter. PiperOrigin-RevId: 646166829 --- third_party/xla/xla/service/space_to_batch_converter.h | 1 - 1 file changed, 1 deletion(-) diff --git a/third_party/xla/xla/service/space_to_batch_converter.h b/third_party/xla/xla/service/space_to_batch_converter.h index 0686c12f91f333..97d19918320ee9 100644 --- a/third_party/xla/xla/service/space_to_batch_converter.h +++ b/third_party/xla/xla/service/space_to_batch_converter.h @@ -31,7 +31,6 @@ struct SpaceToBatchController { bool enable_propagations_on_trivial_window_dilations; bool disable_starting_on_small_chains; int64_t limit_on_batch_size; - bool enable_propagations_on_dots = false; int64_t dimension_from_end_to_convert = 1; // We choose the new batch size to be number_of_splits times that of the old // batch so that space-to-batch propagation through several convolutional From 2ff5bd5663d6852fd995fccb34bbfc80a0010cd7 Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Mon, 24 Jun 2024 11:46:46 -0700 Subject: [PATCH 175/256] Use the internal version of builtin_op_data.h in mlir/lite. PiperOrigin-RevId: 646170820 --- tensorflow/compiler/mlir/lite/BUILD | 4 +-- tensorflow/compiler/mlir/lite/core/c/BUILD | 20 +++++++++-- .../mlir/lite/core/c/builtin_op_data.h | 34 ++++++++++++++++++- .../compiler/mlir/lite/flatbuffer_export.cc | 2 +- .../compiler/mlir/lite/flatbuffer_operator.cc | 3 +- 5 files changed, 54 insertions(+), 9 deletions(-) diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD index 974f767940924f..db052746dcda3f 100644 --- a/tensorflow/compiler/mlir/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/BUILD @@ -1028,6 +1028,7 @@ cc_library( ":convert_type", ":converter_inc", ":tensorflow_lite", + "//tensorflow/compiler/mlir/lite/core/c:private_common", "//tensorflow/compiler/mlir/lite/schema:schema_fbs", "//tensorflow/compiler/mlir/lite/schema:schema_fbs_with_mutable", "//tensorflow/compiler/mlir/lite/schema:schema_utils", @@ -1036,7 +1037,6 @@ cc_library( "//tensorflow/core/platform:errors", "//tensorflow/core/platform:status", "//tensorflow/core/platform:statusor", - "//tensorflow/lite/core/c:private_common", "//tensorflow/lite/kernels/internal:kernel_utils", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status:statusor", @@ -1092,6 +1092,7 @@ cc_library( ":tensorflow_lite", "//tensorflow/compiler/mlir:op_or_arg_name_mapper", "//tensorflow/compiler/mlir/lite:control_edges", + "//tensorflow/compiler/mlir/lite/core/c:private_common", "//tensorflow/compiler/mlir/lite/experimental/remat:metadata_util", "//tensorflow/compiler/mlir/lite/metrics:error_collector_inst", "//tensorflow/compiler/mlir/lite/quantization/ir:QuantOps", @@ -1106,7 +1107,6 @@ cc_library( "//tensorflow/core:portable_gif_internal", "//tensorflow/core:protos_all_cc", "//tensorflow/lite/core:framework", - "//tensorflow/lite/core/c:private_common", 
"//tensorflow/lite/delegates/flex:allowlisted_flex_ops_lib", "//tensorflow/lite/python/metrics:converter_error_data_proto_cc", "//tensorflow/lite/toco:toco_flags_proto_cc", diff --git a/tensorflow/compiler/mlir/lite/core/c/BUILD b/tensorflow/compiler/mlir/lite/core/c/BUILD index 02e10ba9de270b..8368a273de4141 100644 --- a/tensorflow/compiler/mlir/lite/core/c/BUILD +++ b/tensorflow/compiler/mlir/lite/core/c/BUILD @@ -9,7 +9,7 @@ package( licenses = ["notice"], ) -# LINT.IfChange(cc_library_common) +# LINT.IfChange(common) cc_library( name = "common", srcs = [], @@ -17,9 +17,23 @@ cc_library( compatible_with = get_compatible_with_portable(), copts = tflite_copts(), visibility = [ - "//tensorflow/compiler/mlir/lite/kernels:__pkg__", + "//tensorflow/compiler/mlir/lite:__subpackages__", "//tensorflow/compiler/mlir/quantization/tensorflow/utils:__pkg__", ], alwayslink = 1, # Why?? TODO(b/161243354): eliminate this. ) -# LINT.ThenChange(//tensorflow/lite/core/c/BUILD) +# LINT.ThenChange(//tensorflow/lite/core/c:common) + +# LINT.IfChange(private_common) +# This is a private target, its visibility is set to public only to be +# used by "tflite_custom_c_library" and "tflite_flex_cc_library". +# Do not use this target directly and don't consider it as a part of the public API. +alias( + name = "private_common", + actual = ":common", + tags = ["avoid_dep"], + visibility = [ + "//visibility:public", + ], +) +# LINT.ThenChange(//tensorflow/lite/core/c:private_common) diff --git a/tensorflow/compiler/mlir/lite/core/c/builtin_op_data.h b/tensorflow/compiler/mlir/lite/core/c/builtin_op_data.h index 8185215ba0cf8c..7a67c630fe1ebd 100644 --- a/tensorflow/compiler/mlir/lite/core/c/builtin_op_data.h +++ b/tensorflow/compiler/mlir/lite/core/c/builtin_op_data.h @@ -15,13 +15,45 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_CORE_C_BUILTIN_OP_DATA_H_ #define TENSORFLOW_COMPILER_MLIR_LITE_CORE_C_BUILTIN_OP_DATA_H_ -// LINT.IfChange +// LINT.IfChange(enum) typedef enum { kTfLitePaddingUnknown = 0, kTfLitePaddingSame, kTfLitePaddingValid, } TfLitePadding; +// Possible fused activation functions. +typedef enum { + kTfLiteActNone = 0, + kTfLiteActRelu, + kTfLiteActReluN1To1, // min(max(-1, x), 1) + kTfLiteActRelu6, // min(max(0, x), 6) + kTfLiteActTanh, + kTfLiteActSignBit, + kTfLiteActSigmoid, +} TfLiteFusedActivation; +// LINT.ThenChange(//tensorflow/lite/core/c/builtin_op_data.h) + +// LINT.IfChange(struct) +// TODO(b/130259536): We should move this out of builtin_op_data. +typedef struct { + int width; + int height; + int width_offset; + int height_offset; +} TfLitePaddingValues; + +typedef struct { + TfLitePadding padding; + int stride_width; + int stride_height; + int filter_width; + int filter_height; + TfLiteFusedActivation activation; + struct { + TfLitePaddingValues padding; + } computed; +} TfLitePoolParams; // LINT.ThenChange(//tensorflow/lite/core/c/builtin_op_data.h) #endif // TENSORFLOW_COMPILER_MLIR_LITE_CORE_C_BUILTIN_OP_DATA_H_ diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc index 6250c60680e867..7b8059c0024d12 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc @@ -81,6 +81,7 @@ limitations under the License. #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "stablehlo/dialect/StablehloOps.h" // from @stablehlo #include "stablehlo/dialect/VhloOps.h" // from @stablehlo +#include "tensorflow/compiler/mlir/lite/core/c/builtin_op_data.h" #include "tensorflow/compiler/mlir/lite/experimental/remat/metadata_util.h" #include "tensorflow/compiler/mlir/lite/flatbuffer_operator.h" #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" @@ -109,7 +110,6 @@ limitations under the License. 
#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/platform/tstring.h" -#include "tensorflow/lite/core/c/builtin_op_data.h" #include "tensorflow/lite/core/interpreter.h" #include "tensorflow/lite/core/macros.h" #include "tensorflow/lite/delegates/flex/allowlisted_flex_ops.h" diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_operator.cc b/tensorflow/compiler/mlir/lite/flatbuffer_operator.cc index 25bf15f6be61e7..affb3936967eb3 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_operator.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_operator.cc @@ -46,6 +46,7 @@ limitations under the License. #include "mlir/Support/LLVM.h" // from @llvm-project #include "stablehlo/dialect/StablehloOps.h" // from @stablehlo #include "stablehlo/dialect/VhloOps.h" // from @stablehlo +#include "tensorflow/compiler/mlir/lite/core/c/builtin_op_data.h" #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/schema/mutable/schema_generated.h" #include "tensorflow/compiler/mlir/lite/schema/schema_utils.h" @@ -54,8 +55,6 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/utils/dynamic_shape_utils.h" #include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/status.h" -#include "tensorflow/lite/core/c/builtin_op_data.h" -#include "tensorflow/lite/kernels/internal/kernel_utils.h" #include "tsl/platform/status.h" namespace { From b24db0b2a858d001445850c39158500b633d5eb8 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 24 Jun 2024 12:25:01 -0700 Subject: [PATCH 176/256] Add back `xla/stream_executor:cuda_platform` to `tf_additional_binary_deps`. 
Should fix https://github.com/tensorflow/tensorflow/issues/63362 Reverts changelist 582804278 PiperOrigin-RevId: 646182849 --- tensorflow/core/platform/build_config.default.bzl | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tensorflow/core/platform/build_config.default.bzl b/tensorflow/core/platform/build_config.default.bzl index 650e776c0423fc..b39d6b913740e5 100644 --- a/tensorflow/core/platform/build_config.default.bzl +++ b/tensorflow/core/platform/build_config.default.bzl @@ -1,5 +1,6 @@ """OSS versions of Bazel macros that can't be migrated to TSL.""" +load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm") load( "@local_xla//xla/tsl:tsl.bzl", @@ -28,7 +29,9 @@ def tf_additional_binary_deps(): # core. Label("//tensorflow/core/kernels:lookup_util"), Label("//tensorflow/core/util/tensor_bundle"), - ] + if_rocm([ + ] + if_cuda([ + Label("@local_xla//xla/stream_executor:cuda_platform"), + ]) + if_rocm([ "@local_xla//xla/stream_executor:rocm_platform", "@local_xla//xla/stream_executor/rocm:rocm_rpath", ]) + if_mkl_ml([ From 3207538ce66d64ddcfe51cd31cda744f10fff695 Mon Sep 17 00:00:00 2001 From: Laura Pak Date: Mon, 24 Jun 2024 12:30:24 -0700 Subject: [PATCH 177/256] Change the return type of ::mlir::lite::QuantizeModel from TfLiteStatus to absl::Status. 
PiperOrigin-RevId: 646184579 --- tensorflow/compiler/mlir/lite/python/BUILD | 1 + .../mlir/lite/python/converter_python_api.cc | 4 +- .../mlir/lite/quantization/lite/BUILD | 7 +- .../lite/quantization/lite/quantize_model.cc | 15 +- .../lite/quantization/lite/quantize_model.h | 4 +- .../quantization/lite/quantize_model_test.cc | 206 ++++++++---------- .../lite/quantization/lite/tfl_quantizer.cc | 7 +- tensorflow/lite/toco/python/BUILD | 1 + .../lite/toco/python/toco_python_api.cc | 4 +- 9 files changed, 113 insertions(+), 136 deletions(-) diff --git a/tensorflow/compiler/mlir/lite/python/BUILD b/tensorflow/compiler/mlir/lite/python/BUILD index 8d50caa5f1321b..3624da1347036e 100644 --- a/tensorflow/compiler/mlir/lite/python/BUILD +++ b/tensorflow/compiler/mlir/lite/python/BUILD @@ -229,6 +229,7 @@ cc_library( "//tensorflow/lite/toco/logging:toco_conversion_log_proto_cc", "//third_party/python_runtime:headers", # build_cleaner: keep; DNR: b/35864863 "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@com_google_absl//absl/strings:string_view", "@com_google_protobuf//:protobuf_headers", diff --git a/tensorflow/compiler/mlir/lite/python/converter_python_api.cc b/tensorflow/compiler/mlir/lite/python/converter_python_api.cc index 4509ed004e4927..704c378609440d 100644 --- a/tensorflow/compiler/mlir/lite/python/converter_python_api.cc +++ b/tensorflow/compiler/mlir/lite/python/converter_python_api.cc @@ -23,6 +23,7 @@ limitations under the License. 
#include #include "absl/container/flat_hash_set.h" +#include "absl/log/log.h" #include "absl/status/status.h" #include "absl/strings/string_view.h" #include "flatbuffers/flatbuffer_builder.h" // from @flatbuffers @@ -385,7 +386,8 @@ PyObject* MlirQuantizeModel(PyObject* data, bool disable_per_channel, /*legacy_float_scale=*/true, denylisted_ops, denylisted_nodes, enable_variable_quantization, disable_per_channel_for_dense_layers, debug_options); - if (status != kTfLiteOk) { + if (!status.ok()) { + LOG(ERROR) << "Failed to quantize model: " << status; error_reporter->exception(); return nullptr; } diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/BUILD b/tensorflow/compiler/mlir/lite/quantization/lite/BUILD index 8370dd98d79e3b..78ae512eb54202 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/quantization/lite/BUILD @@ -37,9 +37,9 @@ cc_library( "//tensorflow/compiler/mlir/quantization/common/quantization_lib:quantization_config", "//tensorflow/compiler/mlir/tensorflow:error_util", "//tensorflow/core:protos_all_cc", - "//tensorflow/lite/c:c_api_types", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", @@ -114,7 +114,7 @@ tf_cc_binary( deps = [ ":quantize_model", "//tensorflow/compiler/mlir/lite/schema:schema_fbs", - "//tensorflow/lite/c:c_api_types", + "@com_google_absl//absl/status", "@llvm-project//llvm:Support", ], ) @@ -168,11 +168,12 @@ tf_cc_test( "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/lite:framework", - "//tensorflow/lite/c:c_api_types", "//tensorflow/lite/tools/optimize:test_util", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status", "@com_google_googletest//:gtest", "@flatbuffers", + "@local_tsl//tsl/lib/core:status_test_util", ], ) diff --git 
a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc index 72076858c6b742..cf4e7273799362 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc +++ b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc @@ -21,6 +21,7 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/log/log.h" +#include "absl/status/status.h" #include "absl/strings/string_view.h" #include "llvm/ADT/StringRef.h" #include "llvm/ADT/Twine.h" @@ -43,7 +44,6 @@ limitations under the License. #include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_config.h" #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" #include "tensorflow/core/framework/types.pb.h" -#include "tensorflow/lite/c/c_api_types.h" namespace mlir { namespace lite { @@ -54,7 +54,7 @@ std::string TfLiteToMlir(const absl::string_view tflite_op_name) { } // TODO(fengliuai): check the result for `fully_quantize` flag. -TfLiteStatus QuantizeModel( +absl::Status QuantizeModel( const absl::string_view model_buffer, const tflite::TensorType &input_type, const tflite::TensorType &output_type, const tflite::TensorType &inference_type, @@ -82,8 +82,7 @@ TfLiteStatus QuantizeModel( OwningOpRef module = tflite::FlatBufferToMlir( model_buffer, &context, UnknownLoc::get(&context)); if (!module) { - LOG(ERROR) << "Couldn't import flatbuffer to MLIR."; - return kTfLiteError; + return absl::InternalError("Couldn't import flatbuffer to MLIR."); } // Apply quantization passes. @@ -128,8 +127,7 @@ TfLiteStatus QuantizeModel( pm.addPass(TFL::CreatePostQuantizeRemoveQDQPass()); if (failed(pm.run(module.get()))) { const std::string err(statusHandler.ConsumeStatus().message()); - LOG(ERROR) << "Failed to quantize: " << err; - return kTfLiteError; + return absl::InternalError(err); } // Export the results. 
@@ -139,10 +137,9 @@ TfLiteStatus QuantizeModel( options.toco_flags.set_allow_custom_ops(true); if (!tflite::MlirToFlatBufferTranslateFunction(module.get(), options, &output_buffer)) { - LOG(ERROR) << "Failed to export MLIR to flatbuffer."; - return kTfLiteError; + return absl::InternalError("Failed to export MLIR to flatbuffer."); } - return kTfLiteOk; + return absl::OkStatus(); } } // namespace lite diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.h b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.h index afd075c1614c88..895c7344966fe9 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.h +++ b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.h @@ -19,10 +19,10 @@ limitations under the License. #include #include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" #include "absl/strings/string_view.h" #include "tensorflow/compiler/mlir/lite/debug/debug_options.pb.h" #include "tensorflow/compiler/mlir/lite/schema/schema_generated.h" -#include "tensorflow/lite/c/c_api_types.h" namespace mlir { namespace lite { @@ -44,7 +44,7 @@ namespace lite { // When `legacy_float_scale` is true, the quantizer will use float scale instead // of double, and call TOCO's quantization routines to maintain bit-exactness of // the values with the TOCO quantizer. -TfLiteStatus QuantizeModel( +absl::Status QuantizeModel( absl::string_view model_buffer, const tflite::TensorType &input_type, const tflite::TensorType &output_type, const tflite::TensorType &inference_type, diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model_test.cc b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model_test.cc index fe9a326dbd0b13..e7d5e00b703392 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model_test.cc +++ b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model_test.cc @@ -27,6 +27,7 @@ limitations under the License. 
#include #include #include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" #include "flatbuffers/flatbuffer_builder.h" // from @flatbuffers #include "flatbuffers/vector.h" // from @flatbuffers #include "tensorflow/compiler/mlir/lite/schema/schema_generated.h" @@ -35,9 +36,9 @@ limitations under the License. #include "tensorflow/core/platform/path.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/util/command_line_flags.h" -#include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/model_builder.h" #include "tensorflow/lite/tools/optimize/test_util.h" +#include "tsl/lib/core/status_test_util.h" // Note: branched from tensorflow/lite/tools/optimize/quantize_model_test.cc @@ -68,7 +69,7 @@ ModelT UnPackFlatBufferModel(const Model& flatbuffer_model) { return model; } -TfLiteStatus QuantizeModel( +absl::Status QuantizeModel( ModelT* model, const TensorType& input_type, const TensorType& output_type, const bool allow_float, const std::unordered_set& operator_names, @@ -95,36 +96,36 @@ TfLiteStatus QuantizeModel( /*enable_variable_quantization=*/false, /*disable_per_channel_for_dense_layers=*/ disable_per_channel_for_dense_layers); - if (status != kTfLiteOk) { + if (!status.ok()) { return status; } auto flatbuffer_model = FlatBufferModel::BuildFromBuffer( output_buffer.data(), output_buffer.size()); *model = UnPackFlatBufferModel(*flatbuffer_model->GetModel()); - return kTfLiteOk; + return absl::OkStatus(); } -TfLiteStatus QuantizeModel(ModelT* model, const TensorType& input_type, +absl::Status QuantizeModel(ModelT* model, const TensorType& input_type, const TensorType& output_type, bool allow_float, std::string& output_buffer) { return QuantizeModel(model, input_type, output_type, allow_float, /*operator_names=*/{}, TensorType_INT8, output_buffer); } -TfLiteStatus QuantizeModel(ModelT* model, const TensorType& input_type, +absl::Status QuantizeModel(ModelT* model, const TensorType& input_type, const TensorType& 
output_type, std::string& output_buffer) { return QuantizeModel(model, input_type, output_type, /*allow_float=*/false, output_buffer); } -TfLiteStatus QuantizeModel(ModelT* model, std::string& output_buffer) { +absl::Status QuantizeModel(ModelT* model, std::string& output_buffer) { return QuantizeModel(model, TensorType_FLOAT32, TensorType_FLOAT32, /*allow_float=*/true, output_buffer); } -TfLiteStatus QuantizeModelAllOperators( +absl::Status QuantizeModelAllOperators( ModelT* model, const TensorType& input_type, const TensorType& output_type, bool allow_float, const TensorType& activations_type, bool disable_per_channel, std::string& output_buffer) { @@ -133,7 +134,7 @@ TfLiteStatus QuantizeModelAllOperators( disable_per_channel); } -TfLiteStatus QuantizeModelAllOperators(ModelT* model, +absl::Status QuantizeModelAllOperators(ModelT* model, const TensorType& input_type, const TensorType& output_type, bool allow_float, @@ -143,7 +144,7 @@ TfLiteStatus QuantizeModelAllOperators(ModelT* model, /*operator_names=*/{}, activations_type, output_buffer); } -TfLiteStatus QuantizeModelAllOperators( +absl::Status QuantizeModelAllOperators( ModelT* model, const TensorType& input_type, const TensorType& output_type, bool allow_float, const TensorType& activations_type, std::string& output_buffer, bool disable_per_channel_for_dense_layers) { @@ -297,21 +298,19 @@ INSTANTIATE_TEST_SUITE_P(QuantizeConvModelTestInst, QuantizeConvModelTest, testing::ValuesIn({TensorType_INT8})); TEST_P(QuantizeConvModelTest, QuantizationSucceeds) { - auto status = QuantizeModelAllOperators(&model_, tensor_type_, tensor_type_, - /*allow_float=*/false, tensor_type_, - output_buffer_); - EXPECT_THAT(status, Eq(kTfLiteOk)); + TF_EXPECT_OK(QuantizeModelAllOperators(&model_, tensor_type_, tensor_type_, + /*allow_float=*/false, tensor_type_, + output_buffer_)); const Model* output_model = GetModel(output_buffer_.data()); ASSERT_TRUE(output_model); } TEST_P(QuantizeConvModelTest, SkipUnspecifiedLayer) { - 
auto status = QuantizeModel(&model_, TensorType_FLOAT32, TensorType_FLOAT32, - /*allow_float=*/true, /*operator_names=*/{}, - TensorType_FLOAT32, output_buffer_, - /*disable_per_channel=*/false, {"CONV_2D"}); - EXPECT_THAT(status, Eq(kTfLiteOk)); + TF_EXPECT_OK(QuantizeModel(&model_, TensorType_FLOAT32, TensorType_FLOAT32, + /*allow_float=*/true, /*operator_names=*/{}, + TensorType_FLOAT32, output_buffer_, + /*disable_per_channel=*/false, {"CONV_2D"})); ModelT expected_model; readonly_model_->UnPackTo(&expected_model); @@ -320,12 +319,11 @@ TEST_P(QuantizeConvModelTest, SkipUnspecifiedLayer) { } TEST_P(QuantizeConvModelTest, SkipUnspecifiedLayerByName) { - auto status = QuantizeModel(&model_, TensorType_FLOAT32, TensorType_FLOAT32, - /*allow_float=*/true, /*operator_names=*/{}, - TensorType_FLOAT32, output_buffer_, - /*disable_per_channel=*/false, - /*blocked_ops=*/{}, {"output"}); - EXPECT_THAT(status, Eq(kTfLiteOk)); + TF_EXPECT_OK(QuantizeModel(&model_, TensorType_FLOAT32, TensorType_FLOAT32, + /*allow_float=*/true, /*operator_names=*/{}, + TensorType_FLOAT32, output_buffer_, + /*disable_per_channel=*/false, + /*blocked_ops=*/{}, {"output"})); ModelT expected_model; readonly_model_->UnPackTo(&expected_model); @@ -334,10 +332,9 @@ TEST_P(QuantizeConvModelTest, SkipUnspecifiedLayerByName) { } TEST_P(QuantizeConvModelTest, GraphIsFullyQuantized) { - auto status = QuantizeModelAllOperators(&model_, tensor_type_, tensor_type_, - /*allow_float=*/false, tensor_type_, - output_buffer_); - EXPECT_THAT(status, Eq(kTfLiteOk)); + TF_EXPECT_OK(QuantizeModelAllOperators(&model_, tensor_type_, tensor_type_, + /*allow_float=*/false, tensor_type_, + output_buffer_)); for (const auto& subgraph : model_.subgraphs) { for (const auto& tensor : subgraph->tensors) { @@ -379,10 +376,9 @@ class QuantizeSplitModelTest : public QuantizeModelTest { // There are two outputs for split with different scales, the resulting model // should have the scales be hardcodes to the input scale value. 
TEST_F(QuantizeSplitModelTest, QuantizeSplit) { - auto status = QuantizeModelAllOperators( - &model_, TensorType_INT8, TensorType_INT8, /*allow_float=*/false, - TensorType_INT8, output_buffer_); - EXPECT_THAT(status, Eq(kTfLiteOk)); + TF_EXPECT_OK(QuantizeModelAllOperators( + &model_, TensorType_INT8, TensorType_INT8, + /*allow_float=*/false, TensorType_INT8, output_buffer_)); // There is only one subgraph. const int32_t subgraph_idx = 0; @@ -477,10 +473,9 @@ INSTANTIATE_TEST_SUITE_P(QuantizeConvModel2TestInst, QuantizeConvModel2Test, testing::ValuesIn({TensorType_INT8})); TEST_P(QuantizeConvModel2Test, VerifyConvQuantization) { - auto status = QuantizeModelAllOperators(&model_, tensor_type_, tensor_type_, - /*allow_float=*/false, tensor_type_, - output_buffer_); - ASSERT_THAT(status, Eq(kTfLiteOk)); + TF_ASSERT_OK(QuantizeModelAllOperators(&model_, tensor_type_, tensor_type_, + /*allow_float=*/false, tensor_type_, + output_buffer_)); const auto& subgraph = model_.subgraphs[0]; auto conv_op = subgraph->operators[0].get(); const int input_tensor_idx = 0; @@ -584,10 +579,10 @@ TEST_P(QuantizeConvModel2Test, VerifyConvQuantization) { } TEST_P(QuantizeConvModel2Test, VerifyConvDisablePerChannelQuantization) { - auto status = QuantizeModelAllOperators( - &model_, tensor_type_, tensor_type_, /*allow_float=*/false, tensor_type_, - /*disable_per_channel=*/true, output_buffer_); - ASSERT_THAT(status, Eq(kTfLiteOk)); + TF_ASSERT_OK(QuantizeModelAllOperators(&model_, tensor_type_, tensor_type_, + /*allow_float=*/false, tensor_type_, + /*disable_per_channel=*/true, + output_buffer_)); const auto& subgraph = model_.subgraphs[0]; auto conv_op = subgraph->operators[0].get(); const int input_tensor_idx = 0; @@ -702,10 +697,9 @@ class QuantizeSoftmaxTest : public QuantizeModelTest { }; TEST_F(QuantizeSoftmaxTest, VerifySoftmaxQuantization) { - auto status = QuantizeModelAllOperators( - &model_, TensorType_INT8, TensorType_INT8, /*allow_float=*/false, - TensorType_INT8, 
output_buffer_); - ASSERT_THAT(status, Eq(kTfLiteOk)); + TF_ASSERT_OK(QuantizeModelAllOperators( + &model_, TensorType_INT8, TensorType_INT8, + /*allow_float=*/false, TensorType_INT8, output_buffer_)); const auto& subgraph = model_.subgraphs[0]; auto op = subgraph->operators[0].get(); @@ -766,10 +760,9 @@ class QuantizeAvgPoolTest : public QuantizeModelTest { }; TEST_F(QuantizeAvgPoolTest, VerifyAvgPoolQuantization) { - auto status = QuantizeModelAllOperators( - &model_, TensorType_INT8, TensorType_INT8, /*allow_float=*/false, - TensorType_INT8, output_buffer_); - ASSERT_THAT(status, Eq(kTfLiteOk)); + TF_ASSERT_OK(QuantizeModelAllOperators( + &model_, TensorType_INT8, TensorType_INT8, + /*allow_float=*/false, TensorType_INT8, output_buffer_)); const auto& subgraph = model_.subgraphs[0]; auto op = subgraph->operators[0].get(); @@ -827,11 +820,9 @@ class QuantizeMultiInputAddWithReshapeTest : public QuantizeModelTest { }; TEST_F(QuantizeMultiInputAddWithReshapeTest, VerifyReshapeQuantization) { - auto status = QuantizeModelAllOperators( - &model_, TensorType_INT8, TensorType_INT8, /*allow_float=*/false, - TensorType_INT8, output_buffer_); - - ASSERT_THAT(status, Eq(kTfLiteOk)); + TF_ASSERT_OK(QuantizeModelAllOperators( + &model_, TensorType_INT8, TensorType_INT8, + /*allow_float=*/false, TensorType_INT8, output_buffer_)); // Verify Reshape is quantized. const auto& subgraph = model_.subgraphs[0]; @@ -879,10 +870,9 @@ TEST_F(QuantizeMultiInputAddWithReshapeTest, VerifyReshapeQuantization) { } TEST_F(QuantizeMultiInputAddWithReshapeTest, VerifyAddQuantization) { - auto status = QuantizeModelAllOperators( - &model_, TensorType_INT8, TensorType_INT8, /*allow_float=*/false, - TensorType_INT8, output_buffer_); - ASSERT_THAT(status, Eq(kTfLiteOk)); + TF_ASSERT_OK(QuantizeModelAllOperators( + &model_, TensorType_INT8, TensorType_INT8, + /*allow_float=*/false, TensorType_INT8, output_buffer_)); // Verify ADD is quantized. 
const auto& subgraph = model_.subgraphs[0]; @@ -954,10 +944,9 @@ INSTANTIATE_TEST_SUITE_P(QuantizeConstInputTestInst, QuantizeConstInputTest, testing::ValuesIn({TensorType_INT8})); TEST_P(QuantizeConstInputTest, VerifyConstOpInput) { - auto status = QuantizeModelAllOperators(&model_, tensor_type_, tensor_type_, - /*allow_float=*/false, tensor_type_, - output_buffer_); - ASSERT_THAT(status, Eq(kTfLiteOk)); + TF_ASSERT_OK(QuantizeModelAllOperators(&model_, tensor_type_, tensor_type_, + /*allow_float=*/false, tensor_type_, + output_buffer_)); // Verify ConstOp is quantized. const auto& subgraph = model_.subgraphs[0]; @@ -998,10 +987,9 @@ class QuantizeArgMaxTest : public QuantizeModelTest { }; TEST_F(QuantizeArgMaxTest, VerifyArgMax) { - auto status = QuantizeModelAllOperators( - &model_, TensorType_INT8, TensorType_INT8, /*allow_float=*/false, - TensorType_INT8, output_buffer_); - ASSERT_THAT(status, Eq(kTfLiteOk)); + TF_ASSERT_OK(QuantizeModelAllOperators( + &model_, TensorType_INT8, TensorType_INT8, + /*allow_float=*/false, TensorType_INT8, output_buffer_)); const auto& subgraph = model_.subgraphs[0]; auto op = subgraph->operators[0].get(); @@ -1044,10 +1032,9 @@ class QuantizeLSTMTest : public QuantizeModelTest { }; TEST_F(QuantizeLSTMTest, VerifyLSTM) { - auto status = QuantizeModelAllOperators( - &model_, TensorType_FLOAT32, TensorType_FLOAT32, /*allow_float=*/true, - TensorType_INT8, output_buffer_); - ASSERT_THAT(status, Eq(kTfLiteOk)); + TF_ASSERT_OK(QuantizeModelAllOperators( + &model_, TensorType_FLOAT32, TensorType_FLOAT32, + /*allow_float=*/true, TensorType_INT8, output_buffer_)); // Read expected model. auto expected_fb_model = ReadModel(internal::kLstmQuantized); @@ -1069,10 +1056,9 @@ class QuantizeLSTM2Test : public QuantizeModelTest { TEST_F(QuantizeLSTM2Test, VerifyLSTM) { // Quantize model. 
- auto status = QuantizeModelAllOperators( + TF_ASSERT_OK(QuantizeModelAllOperators( &model_, TensorType_FLOAT32, TensorType_FLOAT32, - /*allow_float=*/false, TensorType_INT8, output_buffer_); - ASSERT_THAT(status, Eq(kTfLiteOk)); + /*allow_float=*/false, TensorType_INT8, output_buffer_)); // Read expected model. auto expected_fb_model = ReadModel(internal::kLstmQuantized2); @@ -1094,10 +1080,9 @@ class QuantizeUnidirectionalSequenceLSTMTest : public QuantizeModelTest { TEST_F(QuantizeUnidirectionalSequenceLSTMTest, VerifyUnidirectionalSequenceLSTM) { - auto status = QuantizeModelAllOperators( - &model_, TensorType_FLOAT32, TensorType_FLOAT32, /*allow_float=*/false, - TensorType_INT8, output_buffer_); - ASSERT_THAT(status, Eq(kTfLiteOk)); + TF_ASSERT_OK(QuantizeModelAllOperators( + &model_, TensorType_FLOAT32, TensorType_FLOAT32, + /*allow_float=*/false, TensorType_INT8, output_buffer_)); // Read expected model. auto expected_fb_model = @@ -1120,10 +1105,9 @@ class QuantizeSVDFTest : public QuantizeModelTest { TEST_F(QuantizeSVDFTest, VerifySVDF) { // Quantize model. - auto status = QuantizeModelAllOperators( - &model_, TensorType_INT8, TensorType_INT8, /*allow_float=*/false, - TensorType_INT8, output_buffer_); - ASSERT_THAT(status, Eq(kTfLiteOk)); + TF_ASSERT_OK(QuantizeModelAllOperators( + &model_, TensorType_INT8, TensorType_INT8, + /*allow_float=*/false, TensorType_INT8, output_buffer_)); // Read expected model. 
auto expected_fb_model = ReadModel(internal::kSvdfQuantized); @@ -1148,10 +1132,9 @@ class QuantizeFCTest : public QuantizeModelTest, }; TEST_P(QuantizeFCTest, VerifyFC8x8) { - auto status = QuantizeModelAllOperators( - &model_, TensorType_INT8, TensorType_INT8, /*allow_float=*/false, - TensorType_INT8, output_buffer_); - ASSERT_THAT(status, Eq(kTfLiteOk)); + TF_ASSERT_OK(QuantizeModelAllOperators( + &model_, TensorType_INT8, TensorType_INT8, + /*allow_float=*/false, TensorType_INT8, output_buffer_)); const auto& subgraph = model_.subgraphs[0]; auto op = subgraph->operators[0].get(); @@ -1201,10 +1184,9 @@ TEST_P(QuantizeFCTest, VerifyFC8x8) { } TEST_P(QuantizeFCTest, VerifyFCFor16x8) { - auto status = QuantizeModelAllOperators( - &model_, TensorType_INT8, TensorType_INT8, /*allow_float=*/false, - TensorType_INT16, output_buffer_); - ASSERT_THAT(status, Eq(kTfLiteOk)); + TF_ASSERT_OK(QuantizeModelAllOperators( + &model_, TensorType_INT8, TensorType_INT8, + /*allow_float=*/false, TensorType_INT16, output_buffer_)); const std::unique_ptr& subgraph = model_.subgraphs[0]; const tflite::OperatorT* op = subgraph->operators[0].get(); @@ -1256,12 +1238,11 @@ TEST_P(QuantizeFCTest, VerifyFCFor16x8) { } TEST_P(QuantizeFCTest, VerifyDisablePerChannelQuantization) { - auto status = QuantizeModelAllOperators( - &model_, TensorType_INT8, TensorType_INT8, /*allow_float=*/false, - TensorType_INT8, output_buffer_, + TF_ASSERT_OK(QuantizeModelAllOperators( + &model_, TensorType_INT8, TensorType_INT8, + /*allow_float=*/false, TensorType_INT8, output_buffer_, /*disable_per_channel_for_dense_layers=*/ - disable_per_channel_quantization_for_dense_); - ASSERT_THAT(status, Eq(kTfLiteOk)); + disable_per_channel_quantization_for_dense_)); const auto& subgraph = model_.subgraphs[0]; auto fc_op = subgraph->operators[0].get(); @@ -1397,10 +1378,9 @@ class QuantizeCustomOpTest }; TEST_P(QuantizeCustomOpTest, VerifyMixedQuantization) { - auto status = QuantizeModelAllOperators(&model_, 
GetParam(), GetParam(), - /*allow_float=*/true, GetParam(), - output_buffer_); - ASSERT_THAT(status, Eq(kTfLiteOk)); + TF_ASSERT_OK(QuantizeModelAllOperators(&model_, GetParam(), GetParam(), + /*allow_float=*/true, GetParam(), + output_buffer_)); const auto& subgraph = model_.subgraphs[0]; auto float_graph = readonly_model_->subgraphs()->Get(0); // The original model reshape->custom->custom->squeeze. @@ -1436,9 +1416,7 @@ class QuantizePackTest : public QuantizeModelTest { }; TEST_F(QuantizePackTest, VerifyPack) { - auto status = QuantizeModel(&model_, output_buffer_); - - ASSERT_THAT(status, Eq(kTfLiteOk)); + TF_ASSERT_OK(QuantizeModel(&model_, output_buffer_)); const auto subgraph = model_.subgraphs[0].get(); @@ -1500,8 +1478,7 @@ class QuantizeMinimumMaximumTest }; TEST_P(QuantizeMinimumMaximumTest, VerifyMinimumMaximum) { - auto status = QuantizeModel(&model_, output_buffer_); - ASSERT_THAT(status, Eq(kTfLiteOk)); + TF_ASSERT_OK(QuantizeModel(&model_, output_buffer_)); const auto& subgraph = model_.subgraphs[0]; // Check that the first op is Quantize and the last is Dequant. 
const auto& quant_op = subgraph->operators[0]; @@ -1563,9 +1540,7 @@ class QuantizeUnpackTest : public QuantizeModelTest { }; TEST_F(QuantizeUnpackTest, VerifyUnpack) { - auto status = QuantizeModel(&model_, output_buffer_); - - ASSERT_THAT(status, Eq(kTfLiteOk)); + TF_ASSERT_OK(QuantizeModel(&model_, output_buffer_)); const auto subgraph = model_.subgraphs[0].get(); auto op = subgraph->operators[1].get(); @@ -1620,10 +1595,9 @@ INSTANTIATE_TEST_SUITE_P(QuantizeBroadcastToModelTestInst, testing::ValuesIn({TensorType_INT8})); TEST_P(QuantizeBroadcastToModelTest, VerifyBroadcastToQuantization) { - auto status = QuantizeModelAllOperators(&model_, tensor_type_, tensor_type_, - /*allow_float=*/false, tensor_type_, - output_buffer_); - EXPECT_THAT(status, Eq(kTfLiteOk)); + TF_EXPECT_OK(QuantizeModelAllOperators(&model_, tensor_type_, tensor_type_, + /*allow_float=*/false, tensor_type_, + output_buffer_)); // There is only one subgraph. const int32_t subgraph_idx = 0; @@ -1685,10 +1659,9 @@ INSTANTIATE_TEST_SUITE_P(QuantizeGatherNDModelTestInst, testing::ValuesIn({TensorType_INT8})); TEST_P(QuantizeGatherNDModelTest, QuantizeGatherND) { - auto status = QuantizeModelAllOperators(&model_, tensor_type_, tensor_type_, - /*allow_float=*/false, tensor_type_, - output_buffer_); - EXPECT_THAT(status, Eq(kTfLiteOk)); + TF_EXPECT_OK(QuantizeModelAllOperators(&model_, tensor_type_, tensor_type_, + /*allow_float=*/false, tensor_type_, + output_buffer_)); // There is only one subgraph. const int32_t subgraph_idx = 0; @@ -1743,9 +1716,8 @@ TEST_F(QuantizeWhereModelTest, QuantizeWhere) { // Where operator takes a BOOL tensor as input // and outputs INT64 indices, both of which // should not be quantized - auto status = - QuantizeModel(&model_, TensorType_BOOL, TensorType_INT64, output_buffer_); - EXPECT_THAT(status, Eq(kTfLiteOk)); + TF_EXPECT_OK(QuantizeModel(&model_, TensorType_BOOL, TensorType_INT64, + output_buffer_)); // There is only one subgraph. 
const int32_t subgraph_idx = 0; diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/tfl_quantizer.cc b/tensorflow/compiler/mlir/lite/quantization/lite/tfl_quantizer.cc index 8eb225689142a2..c070b8ac282a6f 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/tfl_quantizer.cc +++ b/tensorflow/compiler/mlir/lite/quantization/lite/tfl_quantizer.cc @@ -15,6 +15,7 @@ limitations under the License. #include +#include "absl/status/status.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/InitLLVM.h" @@ -23,7 +24,6 @@ limitations under the License. #include "llvm/Support/raw_ostream.h" #include "tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.h" #include "tensorflow/compiler/mlir/lite/schema/schema_generated.h" -#include "tensorflow/lite/c/c_api_types.h" using llvm::cl::opt; @@ -35,7 +35,7 @@ static opt inputFileName(llvm::cl::Positional, namespace mlir { namespace { -TfLiteStatus QuantizeAnnotatedModel(llvm::StringRef buffer, +absl::Status QuantizeAnnotatedModel(llvm::StringRef buffer, std::string& output_buffer) { return mlir::lite::QuantizeModel( buffer, tflite::TensorType_INT8, tflite::TensorType_INT8, @@ -59,7 +59,8 @@ int main(int argc, char** argv) { std::string output_buffer; if (auto status = mlir::QuantizeAnnotatedModel(buffer->getBuffer().str(), output_buffer); - status != kTfLiteOk) { + !status.ok()) { + llvm::errs() << status.message() << "\n"; return 1; } diff --git a/tensorflow/lite/toco/python/BUILD b/tensorflow/lite/toco/python/BUILD index 991eef9dde4e49..71ba1009e7503c 100644 --- a/tensorflow/lite/toco/python/BUILD +++ b/tensorflow/lite/toco/python/BUILD @@ -70,6 +70,7 @@ cc_library( "//tensorflow/lite/toco/logging:toco_conversion_log_proto_cc", "//third_party/python_runtime:headers", # build_cleaner: keep; DNR: b/35864863 "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/status", 
"@com_google_absl//absl/strings:string_view", "@com_google_protobuf//:protobuf_headers", diff --git a/tensorflow/lite/toco/python/toco_python_api.cc b/tensorflow/lite/toco/python/toco_python_api.cc index 3393bc7737f7d4..0dfbb9630f8cd2 100644 --- a/tensorflow/lite/toco/python/toco_python_api.cc +++ b/tensorflow/lite/toco/python/toco_python_api.cc @@ -23,6 +23,7 @@ limitations under the License. #include #include "absl/container/flat_hash_set.h" +#include "absl/log/absl_log.h" #include "absl/status/status.h" #include "absl/strings/string_view.h" #include "flatbuffers/flatbuffer_builder.h" // from @flatbuffers @@ -386,7 +387,8 @@ PyObject* MlirQuantizeModel(PyObject* data, bool disable_per_channel, /*legacy_float_scale=*/true, denylisted_ops, denylisted_nodes, enable_variable_quantization, disable_per_channel_for_dense_layers, debug_options); - if (status != kTfLiteOk) { + if (!status.ok()) { + ABSL_LOG(ERROR) << "QuantizeModel failed: " << status.message(); error_reporter->exception(); return nullptr; } From cd298863621ce928849146803654d7be86110dd0 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Mon, 24 Jun 2024 12:44:05 -0700 Subject: [PATCH 178/256] [xla:cpu] Work around missing token dependency support in infeed/outfeed thunks PiperOrigin-RevId: 646188716 --- third_party/xla/xla/service/cpu/runtime/BUILD | 44 ++++++++++++++++ .../xla/service/cpu/runtime/infeed_thunk.cc | 9 ++++ .../service/cpu/runtime/infeed_thunk_test.cc | 50 +++++++++++++++++++ .../xla/service/cpu/runtime/outfeed_thunk.cc | 9 ++++ .../service/cpu/runtime/outfeed_thunk_test.cc | 50 +++++++++++++++++++ 5 files changed, 162 insertions(+) create mode 100644 third_party/xla/xla/service/cpu/runtime/infeed_thunk_test.cc create mode 100644 third_party/xla/xla/service/cpu/runtime/outfeed_thunk_test.cc diff --git a/third_party/xla/xla/service/cpu/runtime/BUILD b/third_party/xla/xla/service/cpu/runtime/BUILD index 10561be09ea72a..360be19dfe940b 100644 --- 
a/third_party/xla/xla/service/cpu/runtime/BUILD +++ b/third_party/xla/xla/service/cpu/runtime/BUILD @@ -592,6 +592,28 @@ cc_library( ], ) +xla_cc_test( + name = "outfeed_thunk_test", + srcs = ["outfeed_thunk_test.cc"], + deps = [ + ":buffer_allocations", + ":outfeed_thunk", + ":thunk", + ":thunk_testlib", + "//xla:shape_util", + "//xla/runtime:buffer_use", + "//xla/service:buffer_assignment", + "//xla/service:maybe_owning_device_memory", + "//xla/stream_executor", + "//xla/tsl/concurrency:async_value", + "@com_google_absl//absl/status", + "@local_tsl//tsl/lib/core:status_test_util", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", + ], +) + cc_library( name = "replica_id_thunk", srcs = ["replica_id_thunk.cc"], @@ -662,6 +684,28 @@ cc_library( ], ) +xla_cc_test( + name = "infeed_thunk_test", + srcs = ["infeed_thunk_test.cc"], + deps = [ + ":buffer_allocations", + ":infeed_thunk", + ":thunk", + ":thunk_testlib", + "//xla:shape_util", + "//xla/runtime:buffer_use", + "//xla/service:buffer_assignment", + "//xla/service:maybe_owning_device_memory", + "//xla/stream_executor", + "//xla/tsl/concurrency:async_value", + "@com_google_absl//absl/status", + "@local_tsl//tsl/lib/core:status_test_util", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", + ], +) + cc_library( name = "kernel_thunk", srcs = ["kernel_thunk.cc"], diff --git a/third_party/xla/xla/service/cpu/runtime/infeed_thunk.cc b/third_party/xla/xla/service/cpu/runtime/infeed_thunk.cc index 2a16510e499d68..b0f6b5e6a34c35 100644 --- a/third_party/xla/xla/service/cpu/runtime/infeed_thunk.cc +++ b/third_party/xla/xla/service/cpu/runtime/infeed_thunk.cc @@ -96,6 +96,15 @@ InfeedThunk::BufferUses InfeedThunk::buffer_uses() const { for (const InfeedBuffer& infeed_buffer : infeed_buffers_) { buffer_uses.emplace_back(infeed_buffer.slice, BufferUse::kWrite); } + + // TODO(ezhulenev): It is a hack to 
make sure that we execute all xfeed + // operations in the same order as in HLO schedule, because otherwise racing + // xfeeds lead to undefined behavior. Instead we should correctly model + // side effects of Thunks. + static auto* fake_alloc = new BufferAllocation(0, 1, 0); + buffer_uses.push_back( + BufferUse::Write(BufferAllocation::Slice(fake_alloc, 0, 1))); + return buffer_uses; } diff --git a/third_party/xla/xla/service/cpu/runtime/infeed_thunk_test.cc b/third_party/xla/xla/service/cpu/runtime/infeed_thunk_test.cc new file mode 100644 index 00000000000000..510903279555a2 --- /dev/null +++ b/third_party/xla/xla/service/cpu/runtime/infeed_thunk_test.cc @@ -0,0 +1,50 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "xla/service/cpu/runtime/infeed_thunk.h" + +#include + +#include "xla/runtime/buffer_use.h" +#include "xla/service/buffer_assignment.h" +#include "xla/service/cpu/runtime/thunk.h" +#include "xla/shape_util.h" +#include "tsl/platform/statusor.h" +#include "tsl/platform/test.h" + +namespace xla::cpu { +namespace { + +TEST(InfeedThunkTest, BufferUses) { + BufferAllocation alloc(0, 1024, 0); + BufferAllocation::Slice infeed_slice(&alloc, 10, 40); + + InfeedThunk::InfeedBuffer infeed_buffer = { + infeed_slice, + ShapeUtil::MakeShape(F32, {10}), + }; + + TF_ASSERT_OK_AND_ASSIGN(auto thunk, + InfeedThunk::Create({"infeed"}, {infeed_buffer})); + + EXPECT_EQ(thunk->buffer_uses().size(), 2); + EXPECT_EQ(thunk->buffer_uses()[0], BufferUse::Write(infeed_slice)); + + BufferAllocation::Slice side_effect_slice(&alloc, 0, 1); + EXPECT_EQ(thunk->buffer_uses()[1], BufferUse::Write(side_effect_slice)); +} + +} // namespace +} // namespace xla::cpu diff --git a/third_party/xla/xla/service/cpu/runtime/outfeed_thunk.cc b/third_party/xla/xla/service/cpu/runtime/outfeed_thunk.cc index 87c70f4a78ae8b..3edcc23f03ca59 100644 --- a/third_party/xla/xla/service/cpu/runtime/outfeed_thunk.cc +++ b/third_party/xla/xla/service/cpu/runtime/outfeed_thunk.cc @@ -97,6 +97,15 @@ OutfeedThunk::BufferUses OutfeedThunk::buffer_uses() const { for (const OutfeedBuffer& outfeed_buffer : outfeed_buffers_) { buffer_uses.emplace_back(outfeed_buffer.slice, BufferUse::kRead); } + + // TODO(ezhulenev): It is a hack to make sure that we execute all xfeed + // operations in the same order as in HLO schedule, because otherwise racing + // xfeeds lead to undefined behavior. Instead we should correctly model + // side effects of Thunks. 
+ static auto* fake_alloc = new BufferAllocation(0, 1, 0); + buffer_uses.push_back( + BufferUse::Write(BufferAllocation::Slice(fake_alloc, 0, 1))); + return buffer_uses; } diff --git a/third_party/xla/xla/service/cpu/runtime/outfeed_thunk_test.cc b/third_party/xla/xla/service/cpu/runtime/outfeed_thunk_test.cc new file mode 100644 index 00000000000000..0f745140424ada --- /dev/null +++ b/third_party/xla/xla/service/cpu/runtime/outfeed_thunk_test.cc @@ -0,0 +1,50 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "xla/service/cpu/runtime/outfeed_thunk.h" + +#include + +#include "xla/runtime/buffer_use.h" +#include "xla/service/buffer_assignment.h" +#include "xla/service/cpu/runtime/thunk.h" +#include "xla/shape_util.h" +#include "tsl/platform/statusor.h" +#include "tsl/platform/test.h" + +namespace xla::cpu { +namespace { + +TEST(OutfeedThunkTest, BufferUses) { + BufferAllocation alloc(0, 1024, 0); + BufferAllocation::Slice outfeed_slice(&alloc, 10, 40); + + OutfeedThunk::OutfeedBuffer outfeed_buffer = { + outfeed_slice, + ShapeUtil::MakeShape(F32, {10}), + }; + + TF_ASSERT_OK_AND_ASSIGN(auto thunk, + OutfeedThunk::Create({"outfeed"}, {outfeed_buffer})); + + EXPECT_EQ(thunk->buffer_uses().size(), 2); + EXPECT_EQ(thunk->buffer_uses()[0], BufferUse::Read(outfeed_slice)); + + BufferAllocation::Slice side_effect_slice(&alloc, 0, 1); + EXPECT_EQ(thunk->buffer_uses()[1], BufferUse::Write(side_effect_slice)); +} + +} // namespace +} // namespace xla::cpu From 4295e4f665444a08aaaaca58d874aa89094a9819 Mon Sep 17 00:00:00 2001 From: Kyle Lucke Date: Mon, 24 Jun 2024 13:03:51 -0700 Subject: [PATCH 179/256] Stop using xla/status.h, xla:status, and xla::Status now that xla::Status is just an alias for an absl::Status PiperOrigin-RevId: 646194531 --- tensorflow/compiler/mlir/lite/BUILD | 2 ++ tensorflow/core/tpu/BUILD | 1 - third_party/xla/xla/BUILD | 3 ++- .../xla/xla/service/algebraic_simplifier.cc | 2 +- third_party/xla/xla/service/cpu/ir_emitter.cc | 2 +- .../service/cpu/onednn_convolution_rewriter.cc | 16 +++++++++------- third_party/xla/xla/service/gpu/BUILD | 3 +-- third_party/xla/xla/service/gpu/gemm_rewriter.cc | 1 - .../xla/xla/service/gpu/gpu_conv_rewriter.cc | 9 +++++---- .../xla/service/memory_space_assignment/BUILD | 1 - .../memory_space_assignment/simulator_test.cc | 1 - third_party/xla/xla/status.h | 11 +---------- third_party/xla/xla/stream_executor/rocm/BUILD | 2 -- 
third_party/xla/xla/util.h | 2 +- 14 files changed, 23 insertions(+), 33 deletions(-) diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD index db052746dcda3f..6ab833fcac1394 100644 --- a/tensorflow/compiler/mlir/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/BUILD @@ -737,6 +737,7 @@ cc_library( "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/memory", "@com_google_absl//absl/numeric:bits", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@llvm-project//llvm:Support", "@llvm-project//mlir:AffineAnalysis", @@ -752,6 +753,7 @@ cc_library( "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", "@local_xla//xla:status", + "@local_xla//xla:statusor", "@local_xla//xla/mlir_hlo", "@stablehlo//:stablehlo_ops", ], diff --git a/tensorflow/core/tpu/BUILD b/tensorflow/core/tpu/BUILD index a27cf6393af06f..4ddc0386b3283d 100644 --- a/tensorflow/core/tpu/BUILD +++ b/tensorflow/core/tpu/BUILD @@ -277,7 +277,6 @@ cc_library( "@local_xla//xla:shape_layout", "@local_xla//xla:shape_tree", "@local_xla//xla:shape_util", - "@local_xla//xla:status", "@local_xla//xla:status_macros", "@local_xla//xla:util", "@local_xla//xla:xla_data_proto_cc", diff --git a/third_party/xla/xla/BUILD b/third_party/xla/xla/BUILD index 80bf0ff68b7d94..f81390565195b0 100644 --- a/third_party/xla/xla/BUILD +++ b/third_party/xla/xla/BUILD @@ -301,6 +301,7 @@ xla_cc_test( cc_library( name = "status", hdrs = ["status.h"], + deprecation = "Use @com_google_absl//absl/status instead.", visibility = ["//visibility:public"], deps = [ "@com_google_absl//absl/log:check", @@ -337,7 +338,6 @@ cc_library( ], visibility = ["//visibility:public"], deps = [ - ":status", ":status_macros", ":types", ":xla_data_proto_cc", @@ -350,6 +350,7 @@ cc_library( "@com_google_absl//absl/log:check", "@com_google_absl//absl/memory", "@com_google_absl//absl/numeric:bits", + "@com_google_absl//absl/status", 
"@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", diff --git a/third_party/xla/xla/service/algebraic_simplifier.cc b/third_party/xla/xla/service/algebraic_simplifier.cc index 9df5c33d095261..9a8b616b20549e 100644 --- a/third_party/xla/xla/service/algebraic_simplifier.cc +++ b/third_party/xla/xla/service/algebraic_simplifier.cc @@ -9167,7 +9167,7 @@ absl::Status AlgebraicSimplifierVisitor::HandleConvolution( TF_ASSIGN_OR_RETURN(bool can_rewrite_bf16_conv_to_onednn, IsOneDnnRewritableBF16Conv(&convolution)); if (can_rewrite_bf16_conv_to_onednn) { - return OkStatus(); + return absl::OkStatus(); } #endif // INTEL_MKL && ENABLE_ONEDNN_V3 // Try to replace the convolution with a kDot or a kMultiply instruction. diff --git a/third_party/xla/xla/service/cpu/ir_emitter.cc b/third_party/xla/xla/service/cpu/ir_emitter.cc index a5a87ef0ba2b8d..c33219914d3ff5 100644 --- a/third_party/xla/xla/service/cpu/ir_emitter.cc +++ b/third_party/xla/xla/service/cpu/ir_emitter.cc @@ -2774,7 +2774,7 @@ absl::Status IrEmitter::HandleOneDnnConvolution(HloInstruction* custom_call) { b_.CreateLifetimeEnd(args_ptr, b_.getInt64(-1)); result_stack_alloca.EmitLifetimeEnd(); - return OkStatus(); + return absl::OkStatus(); } absl::Status IrEmitter::HandleOneDnnLayerNorm(HloInstruction* custom_call) { diff --git a/third_party/xla/xla/service/cpu/onednn_convolution_rewriter.cc b/third_party/xla/xla/service/cpu/onednn_convolution_rewriter.cc index c148f116988ea4..b8533cb2eb7481 100644 --- a/third_party/xla/xla/service/cpu/onednn_convolution_rewriter.cc +++ b/third_party/xla/xla/service/cpu/onednn_convolution_rewriter.cc @@ -60,10 +60,12 @@ bool OneDnnConvolutionRewriter::ShouldRewrite(const HloInstruction* conv) { class OneDnnConvolutionRewriterVisitor : public DfsHloRewriteVisitor { public: - Status HandleConvolution(HloInstruction* conv) override { + absl::Status HandleConvolution(HloInstruction* conv) override { auto pattern = 
match::Op(&conv).WithOpcode(HloOpcode::kConvolution); - if (!Match(conv, pattern)) return OkStatus(); - if (!OneDnnConvolutionRewriter::ShouldRewrite(conv)) return OkStatus(); + if (!Match(conv, pattern)) return absl::OkStatus(); + if (!OneDnnConvolutionRewriter::ShouldRewrite(conv)) { + return absl::OkStatus(); + } const Shape& conv_shape = conv->shape(); auto dims = conv->window().dimensions().size(); @@ -95,7 +97,7 @@ class OneDnnConvolutionRewriterVisitor : public DfsHloRewriteVisitor { it != conv->window().dimensions().end(); it++) { if ((*it).padding_low() < 0 || (*it).padding_high() < 0 || (*it).stride() < 0) { - return OkStatus(); + return absl::OkStatus(); } conv_config->mutable_window()->add_pad_left((*it).padding_low() + 1); conv_config->mutable_window()->add_pad_right((*it).padding_high() + 1); @@ -103,7 +105,7 @@ class OneDnnConvolutionRewriterVisitor : public DfsHloRewriteVisitor { conv_config->mutable_window()->add_window_dilations( (*it).window_dilation() + 1); if ((*it).base_dilation() != 1 || (*it).window_reversal()) { - return OkStatus(); + return absl::OkStatus(); } } @@ -123,11 +125,11 @@ class OneDnnConvolutionRewriterVisitor : public DfsHloRewriteVisitor { TF_RETURN_IF_ERROR(custom_call->set_backend_config(backend_config)); TF_RETURN_IF_ERROR(ReplaceInstruction(conv, custom_call)); - return OkStatus(); + return absl::OkStatus(); } }; -StatusOr OneDnnConvolutionRewriter::Run( +absl::StatusOr OneDnnConvolutionRewriter::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { OneDnnConvolutionRewriterVisitor visitor; diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index 9e873c5182b2cf..d054ac443cc08c 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -1211,7 +1211,6 @@ cc_library( "//xla:literal", "//xla:literal_util", "//xla:shape_util", - "//xla:status", "//xla:status_macros", "//xla:types", "//xla:util", @@ -2143,7 +2142,6 @@ cc_library( 
":cublas_cudnn", "//xla:permutation_util", "//xla:shape_util", - "//xla:status", "//xla:util", "//xla:window_util", "//xla:xla_data_proto_cc", @@ -2152,6 +2150,7 @@ cc_library( "//xla/stream_executor:device_description", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:string_view", diff --git a/third_party/xla/xla/service/gpu/gemm_rewriter.cc b/third_party/xla/xla/service/gpu/gemm_rewriter.cc index 1559f64f94450b..e3dd0cfa5fc75f 100644 --- a/third_party/xla/xla/service/gpu/gemm_rewriter.cc +++ b/third_party/xla/xla/service/gpu/gemm_rewriter.cc @@ -56,7 +56,6 @@ limitations under the License. #include "xla/service/pattern_matcher.h" #include "xla/shape.h" #include "xla/shape_util.h" -#include "xla/status.h" #include "xla/status_macros.h" #include "xla/stream_executor/blas.h" #include "xla/stream_executor/device_description.h" diff --git a/third_party/xla/xla/service/gpu/gpu_conv_rewriter.cc b/third_party/xla/xla/service/gpu/gpu_conv_rewriter.cc index 4016623959e4a6..cb5b1867241e58 100644 --- a/third_party/xla/xla/service/gpu/gpu_conv_rewriter.cc +++ b/third_party/xla/xla/service/gpu/gpu_conv_rewriter.cc @@ -29,6 +29,7 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" #include "absl/strings/str_replace.h" #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_computation.h" @@ -40,7 +41,6 @@ limitations under the License. 
#include "xla/service/gpu/cublas_cudnn.h" #include "xla/shape.h" #include "xla/shape_util.h" -#include "xla/status.h" #include "xla/stream_executor/device_description.h" #include "xla/util.h" #include "xla/window_util.h" @@ -55,7 +55,8 @@ namespace gpu { namespace { -Status CheckTypes(HloInstruction* conv, const se::GpuComputeCapability cc) { +absl::Status CheckTypes(HloInstruction* conv, + const se::GpuComputeCapability cc) { auto valid_shape = [conv, &cc](const Shape& shape) -> absl::Status { PrimitiveType type = shape.element_type(); if (!primitive_util::IsFloatingPointType(type) && @@ -90,13 +91,13 @@ Status CheckTypes(HloInstruction* conv, const se::GpuComputeCapability cc) { conv->ToString()); } } - return OkStatus(); + return absl::OkStatus(); }; TF_RETURN_IF_ERROR(valid_shape(conv->shape())); TF_RETURN_IF_ERROR(valid_shape(conv->operand(0)->shape())); TF_RETURN_IF_ERROR(valid_shape(conv->operand(1)->shape())); - return OkStatus(); + return absl::OkStatus(); } using ConvolutionMatch = std::optional< diff --git a/third_party/xla/xla/service/memory_space_assignment/BUILD b/third_party/xla/xla/service/memory_space_assignment/BUILD index bade8dc2134936..d6ab413044d66d 100644 --- a/third_party/xla/xla/service/memory_space_assignment/BUILD +++ b/third_party/xla/xla/service/memory_space_assignment/BUILD @@ -327,7 +327,6 @@ xla_cc_test( ":cost_analysis", ":simulator", "//xla:shape_util", - "//xla:status", "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_live_range", "//xla/service:hlo_alias_analysis", diff --git a/third_party/xla/xla/service/memory_space_assignment/simulator_test.cc b/third_party/xla/xla/service/memory_space_assignment/simulator_test.cc index e379ccecb0f83f..7207ebbbdf670a 100644 --- a/third_party/xla/xla/service/memory_space_assignment/simulator_test.cc +++ b/third_party/xla/xla/service/memory_space_assignment/simulator_test.cc @@ -30,7 +30,6 @@ limitations under the License. 
#include "xla/service/memory_space_assignment/cost_analysis.h" #include "xla/shape.h" #include "xla/shape_util.h" -#include "xla/status.h" #include "xla/tests/hlo_test_base.h" #include "tsl/lib/core/status_test_util.h" #include "tsl/platform/errors.h" diff --git a/third_party/xla/xla/status.h b/third_party/xla/xla/status.h index 818bfdf4b1ba2d..9ed24549214eed 100644 --- a/third_party/xla/xla/status.h +++ b/third_party/xla/xla/status.h @@ -16,15 +16,6 @@ limitations under the License. #ifndef XLA_STATUS_H_ #define XLA_STATUS_H_ -#include "absl/log/check.h" // IWYU pragma: export -#include "absl/status/status.h" -#include "absl/status/statusor.h" - -namespace xla { -// NOLINTBEGIN(misc-unused-using-decls) -using absl::OkStatus; -using absl::Status; -// NOLINTEND(misc-unused-using-decls) -} // namespace xla +// This is an obsolete header. Please use absl/status/status.h instead. #endif // XLA_STATUS_H_ diff --git a/third_party/xla/xla/stream_executor/rocm/BUILD b/third_party/xla/xla/stream_executor/rocm/BUILD index d206e398d6dfdf..98d472e6f6f9e8 100644 --- a/third_party/xla/xla/stream_executor/rocm/BUILD +++ b/third_party/xla/xla/stream_executor/rocm/BUILD @@ -544,7 +544,6 @@ cc_library( ":rocm_executor", ":rocm_platform_id", "//xla:shape_util", - "//xla:status", "//xla:status_macros", "//xla:types", "//xla:util", @@ -578,7 +577,6 @@ cc_library( visibility = ["//visibility:public"], deps = if_rocm_is_configured([ # keep sorted - "//xla:status", "//xla:types", "//xla/stream_executor", "//xla/stream_executor:host_or_device_scalar", diff --git a/third_party/xla/xla/util.h b/third_party/xla/xla/util.h index e21fb6fcf00340..0df05a1f0eddd3 100644 --- a/third_party/xla/xla/util.h +++ b/third_party/xla/xla/util.h @@ -38,13 +38,13 @@ limitations under the License. 
#include "absl/container/inlined_vector.h" #include "absl/memory/memory.h" #include "absl/numeric/bits.h" +#include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" #include "absl/types/span.h" #include "Eigen/Core" // from @eigen_archive -#include "xla/status.h" #include "xla/status_macros.h" #include "xla/types.h" #include "xla/xla_data.pb.h" From 138b03517af2b784a297ecc1ac216b2fd474e96f Mon Sep 17 00:00:00 2001 From: David Dunleavy Date: Mon, 24 Jun 2024 13:24:22 -0700 Subject: [PATCH 180/256] Instead of copybara rules, use `if_google` to remove extra proto deps PiperOrigin-RevId: 646200636 --- third_party/xla/third_party/tsl/tsl/protobuf/BUILD | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/third_party/xla/third_party/tsl/tsl/protobuf/BUILD b/third_party/xla/third_party/tsl/tsl/protobuf/BUILD index 395527e4e22e6e..65000ff408801c 100644 --- a/third_party/xla/third_party/tsl/tsl/protobuf/BUILD +++ b/third_party/xla/third_party/tsl/tsl/protobuf/BUILD @@ -31,6 +31,7 @@ tf_proto_library( name = "dnn_proto", srcs = ["dnn.proto"], make_default_target_header_only = True, + protodeps = if_google(["//google/protobuf:wrappers"]), visibility = ["//visibility:public"], ) @@ -38,6 +39,7 @@ tf_proto_library( name = "error_codes_proto_impl", srcs = ["error_codes.proto"], make_default_target_header_only = True, + protodeps = if_google(["//google/protobuf:any"]), visibility = ["//visibility:public"], ) @@ -73,6 +75,7 @@ tf_proto_library( create_grpc_library = True, create_java_proto = False, create_service = True, + protodeps = if_google(["//google/protobuf:any"]), visibility = ["//visibility:public"], ) @@ -103,6 +106,10 @@ tf_proto_library( name = "test_log_proto", srcs = ["test_log.proto"], make_default_target_header_only = True, + protodeps = if_google([ + "//google/protobuf:any", + "//google/protobuf:wrappers", + ]), visibility 
= internal_visibility([ "//tensorflow/core:__subpackages__", "@local_xla//xla/tsl/util:__pkg__", @@ -124,6 +131,6 @@ tf_proto_library( ":rpc_options_proto", ":status_proto", ":test_log_proto", - ], + ] + if_google(["//google/protobuf:any"]), visibility = ["//visibility:public"], ) From ea6473f33bae9bcf27400a3c338fd07b9b7e6334 Mon Sep 17 00:00:00 2001 From: Kris Tonthat Date: Mon, 24 Jun 2024 13:43:56 -0700 Subject: [PATCH 181/256] Mark Task Library as deprecated in TFLite documentation. PiperOrigin-RevId: 646207497 --- tensorflow/lite/g3doc/_book.yaml | 1 + tensorflow/lite/g3doc/android/index.md | 10 +++++++++- .../task_library/audio_classifier.md | 8 ++++++++ .../task_library/bert_nl_classifier.md | 8 ++++++++ .../task_library/bert_question_answerer.md | 8 ++++++++ .../task_library/customized_task_api.md | 8 ++++++++ .../task_library/image_classifier.md | 8 ++++++++ .../task_library/image_embedder.md | 8 ++++++++ .../task_library/image_searcher.md | 8 ++++++++ .../task_library/image_segmenter.md | 8 ++++++++ .../task_library/nl_classifier.md | 8 ++++++++ .../task_library/object_detector.md | 8 ++++++++ .../inference_with_metadata/task_library/overview.md | 8 ++++++++ .../task_library/text_embedder.md | 10 +++++++++- .../task_library/text_searcher.md | 8 ++++++++ 15 files changed, 115 insertions(+), 2 deletions(-) diff --git a/tensorflow/lite/g3doc/_book.yaml b/tensorflow/lite/g3doc/_book.yaml index a229d8baf5e0c3..dbd6f81e7652e6 100644 --- a/tensorflow/lite/g3doc/_book.yaml +++ b/tensorflow/lite/g3doc/_book.yaml @@ -23,6 +23,7 @@ upper_tabs: - heading: "Libraries and tools" - title: "Task Library" + status: deprecated section: - title: "Overview" path: /lite/inference_with_metadata/task_library/overview diff --git a/tensorflow/lite/g3doc/android/index.md b/tensorflow/lite/g3doc/android/index.md index 97603760c2d983..3f456f074ac085 100644 --- a/tensorflow/lite/g3doc/android/index.md +++ b/tensorflow/lite/g3doc/android/index.md @@ -134,10 +134,18 @@ environment, 
which are described in the ### Development APIs and libraries {:#apis} + + There are two main APIs you can use to integrate TensorFlow Lite machine learning models into your Android app: -* **[TensorFlow Lite Task API](../api_docs/java/org/tensorflow/lite/task/core/package-summary) (recommended)** +* [TensorFlow Lite Task API](../api_docs/java/org/tensorflow/lite/task/core/package-summary) * [TensorFlow Lite Interpreter API](../api_docs/java/org/tensorflow/lite/InterpreterApi) The diff --git a/tensorflow/lite/g3doc/inference_with_metadata/task_library/audio_classifier.md b/tensorflow/lite/g3doc/inference_with_metadata/task_library/audio_classifier.md index 5f62b56c0fde2c..5667938f9e1283 100644 --- a/tensorflow/lite/g3doc/inference_with_metadata/task_library/audio_classifier.md +++ b/tensorflow/lite/g3doc/inference_with_metadata/task_library/audio_classifier.md @@ -1,5 +1,13 @@ # Integrate audio classifiers + + Audio classification is a common use case of Machine Learning to classify the sound types. For example, it can identify the bird species by their songs. 
diff --git a/tensorflow/lite/g3doc/inference_with_metadata/task_library/bert_nl_classifier.md b/tensorflow/lite/g3doc/inference_with_metadata/task_library/bert_nl_classifier.md index f156880316ebb5..3a8c18a565113f 100644 --- a/tensorflow/lite/g3doc/inference_with_metadata/task_library/bert_nl_classifier.md +++ b/tensorflow/lite/g3doc/inference_with_metadata/task_library/bert_nl_classifier.md @@ -1,5 +1,13 @@ # Integrate BERT natural language classifier + + The Task Library `BertNLClassifier` API is very similar to the `NLClassifier` that classifies input text into different categories, except that this API is specially tailored for Bert related models that require Wordpiece and diff --git a/tensorflow/lite/g3doc/inference_with_metadata/task_library/bert_question_answerer.md b/tensorflow/lite/g3doc/inference_with_metadata/task_library/bert_question_answerer.md index 2754eb1bd0ba9a..675ef8de8349b2 100644 --- a/tensorflow/lite/g3doc/inference_with_metadata/task_library/bert_question_answerer.md +++ b/tensorflow/lite/g3doc/inference_with_metadata/task_library/bert_question_answerer.md @@ -1,5 +1,13 @@ # Integrate BERT question answerer + + The Task Library `BertQuestionAnswerer` API loads a Bert model and answers questions based on the content of a given passage. For more information, see the documentation for the Question-Answer model diff --git a/tensorflow/lite/g3doc/inference_with_metadata/task_library/customized_task_api.md b/tensorflow/lite/g3doc/inference_with_metadata/task_library/customized_task_api.md index c9d2d7dc06abcd..07fd64a7d2118a 100644 --- a/tensorflow/lite/g3doc/inference_with_metadata/task_library/customized_task_api.md +++ b/tensorflow/lite/g3doc/inference_with_metadata/task_library/customized_task_api.md @@ -1,5 +1,13 @@ # Build you own Task API + + TensorFlow Lite Task Library provides prebuilt native/Android/iOS APIs on top of the same infrastructure that abstracts TensorFlow. 
You can extend the Task API infrastructure to build customized APIs diff --git a/tensorflow/lite/g3doc/inference_with_metadata/task_library/image_classifier.md b/tensorflow/lite/g3doc/inference_with_metadata/task_library/image_classifier.md index 15725b50aff688..ea1d846f5f8e3c 100644 --- a/tensorflow/lite/g3doc/inference_with_metadata/task_library/image_classifier.md +++ b/tensorflow/lite/g3doc/inference_with_metadata/task_library/image_classifier.md @@ -1,5 +1,13 @@ # Integrate image classifiers + + Image classification is a common use of machine learning to identify what an image represents. For example, we might want to know what type of animal appears in a given picture. The task of predicting what an image represents is called diff --git a/tensorflow/lite/g3doc/inference_with_metadata/task_library/image_embedder.md b/tensorflow/lite/g3doc/inference_with_metadata/task_library/image_embedder.md index 1c134ca7f543fc..66d5c290df8ac5 100644 --- a/tensorflow/lite/g3doc/inference_with_metadata/task_library/image_embedder.md +++ b/tensorflow/lite/g3doc/inference_with_metadata/task_library/image_embedder.md @@ -1,5 +1,13 @@ # Integrate image embedders + + Image embedders allow embedding images into a high-dimensional feature vector representing the semantic meaning of an image, which can then be compared with the feature vector of other images to evaluate their semantic similarity. diff --git a/tensorflow/lite/g3doc/inference_with_metadata/task_library/image_searcher.md b/tensorflow/lite/g3doc/inference_with_metadata/task_library/image_searcher.md index fa1a901f17fcc8..e6f354f1ffd21e 100644 --- a/tensorflow/lite/g3doc/inference_with_metadata/task_library/image_searcher.md +++ b/tensorflow/lite/g3doc/inference_with_metadata/task_library/image_searcher.md @@ -1,5 +1,13 @@ # Integrate image searchers + + Image search allows searching for similar images in a database of images. 
It works by embedding the search query into a high-dimensional vector representing the semantic meaning of the query, followed by similarity search in a diff --git a/tensorflow/lite/g3doc/inference_with_metadata/task_library/image_segmenter.md b/tensorflow/lite/g3doc/inference_with_metadata/task_library/image_segmenter.md index cf92a139b6c548..ab727756bdecaf 100644 --- a/tensorflow/lite/g3doc/inference_with_metadata/task_library/image_segmenter.md +++ b/tensorflow/lite/g3doc/inference_with_metadata/task_library/image_segmenter.md @@ -1,5 +1,13 @@ # Integrate image segmenters + + Image segmenters predict whether each pixel of an image is associated with a certain class. This is in contrast to object detection, diff --git a/tensorflow/lite/g3doc/inference_with_metadata/task_library/nl_classifier.md b/tensorflow/lite/g3doc/inference_with_metadata/task_library/nl_classifier.md index 2c4b2459aca5d1..f3f75ca5aad97c 100644 --- a/tensorflow/lite/g3doc/inference_with_metadata/task_library/nl_classifier.md +++ b/tensorflow/lite/g3doc/inference_with_metadata/task_library/nl_classifier.md @@ -1,5 +1,13 @@ # Integrate Natural language classifier + + The Task Library's `NLClassifier` API classifies input text into different categories, and is a versatile and configurable API that can handle most text classification models. diff --git a/tensorflow/lite/g3doc/inference_with_metadata/task_library/object_detector.md b/tensorflow/lite/g3doc/inference_with_metadata/task_library/object_detector.md index 092e6be1395d82..46644d5ac56b87 100644 --- a/tensorflow/lite/g3doc/inference_with_metadata/task_library/object_detector.md +++ b/tensorflow/lite/g3doc/inference_with_metadata/task_library/object_detector.md @@ -1,5 +1,13 @@ # Integrate object detectors + + Object detectors can identify which of a known set of objects might be present and provide information about their positions within the given image or a video stream. 
An object detector is trained to detect the presence and location of diff --git a/tensorflow/lite/g3doc/inference_with_metadata/task_library/overview.md b/tensorflow/lite/g3doc/inference_with_metadata/task_library/overview.md index 3ade16ce142a8e..139b95a8395d5b 100644 --- a/tensorflow/lite/g3doc/inference_with_metadata/task_library/overview.md +++ b/tensorflow/lite/g3doc/inference_with_metadata/task_library/overview.md @@ -1,5 +1,13 @@ # TensorFlow Lite Task Library + + TensorFlow Lite Task Library contains a set of powerful and easy-to-use task-specific libraries for app developers to create ML experiences with TFLite. It provides optimized out-of-box model interfaces for popular machine learning diff --git a/tensorflow/lite/g3doc/inference_with_metadata/task_library/text_embedder.md b/tensorflow/lite/g3doc/inference_with_metadata/task_library/text_embedder.md index 91a0db27fd8ff5..71bd3c3cd4b3a7 100644 --- a/tensorflow/lite/g3doc/inference_with_metadata/task_library/text_embedder.md +++ b/tensorflow/lite/g3doc/inference_with_metadata/task_library/text_embedder.md @@ -1,4 +1,12 @@ -# Integrate text embedders. +# Integrate text embedders + + Text embedders allow embedding text into a high-dimensional feature vector representing its semantic meaning, which can then be compared with the feature diff --git a/tensorflow/lite/g3doc/inference_with_metadata/task_library/text_searcher.md b/tensorflow/lite/g3doc/inference_with_metadata/task_library/text_searcher.md index 000d2b9ed28ddf..657a6418fd5910 100644 --- a/tensorflow/lite/g3doc/inference_with_metadata/task_library/text_searcher.md +++ b/tensorflow/lite/g3doc/inference_with_metadata/task_library/text_searcher.md @@ -1,5 +1,13 @@ # Integrate text searchers + + Text search allows searching for semantically similar text in a corpus. 
It works by embedding the search query into a high-dimensional vector representing the semantic meaning of the query, followed by similarity search in a predefined, From 8583ef6dd875447a008ec6110a1baba6dc6fe616 Mon Sep 17 00:00:00 2001 From: Alexander Pivovarov Date: Mon, 24 Jun 2024 13:46:29 -0700 Subject: [PATCH 182/256] PR #14082: [MLIR][MHLO] Repack EliminateBroadcastInDimTranspose as OpRewritePattern class Imported from GitHub PR https://github.com/openxla/xla/pull/14082 `mlir::Canonicalizer` pass constructor allows to specify `disabledPatterns`. `disabledPatterns` - is a set of names used to filter out input patterns with **debug name** in this set. The `EliminateBroadcastInDimTranspose` pattern was previously implemented as a function, resulting in an empty debug name. Consequently, this pattern could not be disabled when necessary. ## Solution To address this issue, the `EliminateBroadcastInDimTranspose` pattern has been refactored into a class that derives from `OpRewritePattern`. This change ensures that the pattern now has an automatically generated debug name: `mlir::mhlo::EliminateBroadcastInDimTranspose`. Additionally, similar changes have been applied to other patterns listed in `TransposeOp::getCanonicalizationPatterns()`. The debug names of the refactored patterns are as follows: ```c mlir::mhlo::EliminateRedundantTranspose mlir::mhlo::EliminateBroadcastInDimTranspose mlir::mhlo::SimplifyTranspose ``` These changes improve the flexibility and maintainability of the pattern disabling mechanism within the `mlir::Canonicalizer` pass. 
Copybara import of the project: -- 355a79955a3d6ef5b19f3f359be8323082e25fb4 by Alexander Pivovarov : Refactor EliminateBroadcastInDimTranspose as OpRewritePattern class Merging this change closes #14082 PiperOrigin-RevId: 646208210 --- .../xla/xla/mlir_hlo/mhlo/IR/hlo_ops.cc | 147 ++++++++++-------- 1 file changed, 80 insertions(+), 67 deletions(-) diff --git a/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops.cc b/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops.cc index 740c32e4f4e858..5ec514d5b2f982 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops.cc +++ b/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops.cc @@ -5378,84 +5378,97 @@ OpFoldResult TransposeOp::fold(FoldAdaptor adaptor) { } // transpose(transpose(X)) => transpose(X) -static LogicalResult eliminateRedundantTranspose(TransposeOp op, - PatternRewriter& rewriter) { - auto tranposeOperand = op.getOperand().getDefiningOp(); - if (!tranposeOperand) { - return failure(); +class EliminateRedundantTranspose : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(TransposeOp op, + PatternRewriter& rewriter) const override { + auto tranposeOperand = op.getOperand().getDefiningOp(); + if (!tranposeOperand) { + return failure(); + } + auto operandPermutation = + tranposeOperand.getPermutation().getValues(); + auto newPermutation = + cast(op.getPermutation().mapValues( + op.getPermutation().getElementType(), + [&operandPermutation](const APInt& index) -> APInt { + return operandPermutation[index.getSExtValue()]; + })); + rewriter.replaceOpWithNewOp(op, op.getResult().getType(), + tranposeOperand.getOperand(), + newPermutation); + return success(); } - auto operandPermutation = tranposeOperand.getPermutation().getValues(); - auto newPermutation = - cast(op.getPermutation().mapValues( - op.getPermutation().getElementType(), - [&operandPermutation](const APInt& index) -> APInt { - return operandPermutation[index.getSExtValue()]; - })); - 
rewriter.replaceOpWithNewOp(op, op.getResult().getType(), - tranposeOperand.getOperand(), - newPermutation); - return success(); -} +}; -// transpose(broadcast_in_dim(X)) => broadcast_in_dim(X) -static LogicalResult eliminateBroadcastInDimTranspose( - TransposeOp op, PatternRewriter& rewriter) { - auto broadcastInDimOp = op.getOperand().getDefiningOp(); - if (!broadcastInDimOp) { - return failure(); - } - DenseIntElementsAttr broadcastDimensions = - broadcastInDimOp.getBroadcastDimensions(); - DenseIntElementsAttr permutation = op.getPermutation(); - SmallVector newBroadcastDimensions; - for (auto dimension : broadcastDimensions.getValues()) { - int64_t index = 0; - for (auto p : permutation.getValues()) { - if (p == dimension) { - newBroadcastDimensions.push_back(index); - break; +// transpose(broadcast_in_dim(X)) => broadcast_in_dim(X) +class EliminateBroadcastInDimTranspose : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(TransposeOp op, + PatternRewriter& rewriter) const override { + auto broadcastInDimOp = op.getOperand().getDefiningOp(); + if (!broadcastInDimOp) { + return failure(); + } + DenseIntElementsAttr broadcastDimensions = + broadcastInDimOp.getBroadcastDimensions(); + DenseIntElementsAttr permutation = op.getPermutation(); + SmallVector newBroadcastDimensions; + for (auto dimension : broadcastDimensions.getValues()) { + int64_t index = 0; + for (auto p : permutation.getValues()) { + if (p == dimension) { + newBroadcastDimensions.push_back(index); + break; + } + index++; } - index++; } + rewriter.replaceOpWithNewOp( + op, op->getResultTypes(), broadcastInDimOp.getOperand(), + rewriter.getI64TensorAttr(newBroadcastDimensions)); + return success(); } - rewriter.replaceOpWithNewOp( - op, op->getResultTypes(), broadcastInDimOp.getOperand(), - rewriter.getI64TensorAttr(newBroadcastDimensions)); - return success(); -} +}; // simplify Transpose: replace Transpose with Reshape if they are
equivalent -static LogicalResult simplifyTranspose(TransposeOp op, - PatternRewriter& rewriter) { - auto operandType = dyn_cast(op.getOperand().getType()); - auto resultType = dyn_cast(op.getResult().getType()); - if (!operandType || !resultType) { - return failure(); - } - // Not support dynamic shape a.t.m. BTW, when it's dynamic shape, - // maybe Transpose should be replaced by DynamicReshape. - if (!operandType.hasStaticShape() || !resultType.hasStaticShape()) { - return failure(); - } - auto permutation = op.getPermutation().getValues(); - llvm::SmallVector sortedPermutation; - for (int64_t i = 0, e = resultType.getRank(); i < e; i++) { - if (resultType.getDimSize(i) != 1) { - sortedPermutation.push_back(permutation[i]); +class SimplifyTranspose : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(TransposeOp op, + PatternRewriter& rewriter) const override { + auto operandType = dyn_cast(op.getOperand().getType()); + auto resultType = dyn_cast(op.getResult().getType()); + if (!operandType || !resultType) { + return failure(); } + // Not support dynamic shape a.t.m. BTW, when it's dynamic shape, + // maybe Transpose should be replaced by DynamicReshape. 
+ if (!operandType.hasStaticShape() || !resultType.hasStaticShape()) { + return failure(); + } + auto permutation = op.getPermutation().getValues(); + llvm::SmallVector sortedPermutation; + for (int64_t i = 0, e = resultType.getRank(); i < e; i++) { + if (resultType.getDimSize(i) != 1) { + sortedPermutation.push_back(permutation[i]); + } + } + if (llvm::is_sorted(sortedPermutation)) { + rewriter.replaceOpWithNewOp(op, op.getType(), op.getOperand()); + return success(); + } + return failure(); } - if (llvm::is_sorted(sortedPermutation)) { - rewriter.replaceOpWithNewOp(op, op.getType(), op.getOperand()); - return success(); - } - return failure(); -} +}; void TransposeOp::getCanonicalizationPatterns(RewritePatternSet& results, - MLIRContext* /*context*/) { - results.add(eliminateRedundantTranspose); - results.add(eliminateBroadcastInDimTranspose); - results.add(simplifyTranspose); + MLIRContext* context) { + results.add(context); + results.add(context); + results.add(context); } LogicalResult TransposeOp::reifyReturnTypeShapes( From 7f4cd34e384b2094fd4fa98e32754c38136b56ce Mon Sep 17 00:00:00 2001 From: Kyle Lucke Date: Mon, 24 Jun 2024 13:47:35 -0700 Subject: [PATCH 183/256] Fix ROCm xla build. PiperOrigin-RevId: 646208546 --- third_party/xla/xla/client/client.h | 2 +- .../xla/xla/stream_executor/gpu/gpu_stream.cc | 9 +++++---- .../xla/xla/stream_executor/rocm/rocm_executor.cc | 15 --------------- 3 files changed, 6 insertions(+), 20 deletions(-) diff --git a/third_party/xla/xla/client/client.h b/third_party/xla/xla/client/client.h index 49c9911f5cdf1b..f3eacbcad3ec06 100644 --- a/third_party/xla/xla/client/client.h +++ b/third_party/xla/xla/client/client.h @@ -40,7 +40,7 @@ class Client { explicit Client(Service* stub); virtual ~Client(); - using XlaComputationInstance = XlaComputationInstance; + using XlaComputationInstance = xla::XlaComputationInstance; // Compile the computation with the given argument shapes and returns the // handle to the compiled executable. 
The compiled executable is cached on the diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_stream.cc b/third_party/xla/xla/stream_executor/gpu/gpu_stream.cc index 450e6a51acb5eb..98797cf5dc443d 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_stream.cc +++ b/third_party/xla/xla/stream_executor/gpu/gpu_stream.cc @@ -81,9 +81,9 @@ absl::Status GpuStream::Memcpy(DeviceMemoryBase* gpu_dst, const DeviceMemoryBase& gpu_src, uint64_t size) { if (GpuDriver::AsynchronousMemcpyD2D( parent_->gpu_context(), - reinterpret_cast(gpu_dst->opaque()), - reinterpret_cast(gpu_src.opaque()), size, - gpu_stream())) { + reinterpret_cast(const_cast(gpu_dst->opaque())), + reinterpret_cast(const_cast(gpu_src.opaque())), + size, gpu_stream())) { return absl::OkStatus(); } @@ -105,7 +105,8 @@ absl::Status GpuStream::Memcpy(void* host_dst, const DeviceMemoryBase& gpu_src, uint64_t size) { bool ok = GpuDriver::AsynchronousMemcpyD2H( parent_->gpu_context(), host_dst, - reinterpret_cast(gpu_src.opaque()), size, gpu_stream()); + reinterpret_cast(const_cast(gpu_src.opaque())), size, + gpu_stream()); // TODO(b/326130105): Change AsynchronousMemcpyD2H calls to return // absl::Status. if (!ok) { diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_executor.cc b/third_party/xla/xla/stream_executor/rocm/rocm_executor.cc index 12b05f307c659e..f405de0c781537 100644 --- a/third_party/xla/xla/stream_executor/rocm/rocm_executor.cc +++ b/third_party/xla/xla/stream_executor/rocm/rocm_executor.cc @@ -582,21 +582,6 @@ absl::Status GpuExecutor::Memset(Stream* stream, DeviceMemoryBase* location, AsGpuStreamValue(stream)); } -absl::Status GpuExecutor::Memcpy(Stream* stream, void* host_dst, - const DeviceMemoryBase& gpu_src, - uint64_t size) { - bool ok = GpuDriver::AsynchronousMemcpyD2H(context_, host_dst, - AsROCmDevicePtr(gpu_src), size, - AsGpuStreamValue(stream)); - - // TODO(b/326130105): Change AsynchronousMemcpyD2H calls to return - // absl::Status. 
- if (!ok) { - return absl::InternalError("Failed to memcpy from device to host."); - } - return absl::OkStatus(); -} - bool GpuExecutor::HostCallback(Stream* stream, absl::AnyInvocable callback) { auto callback_ptr = From efad0872a0fd3f35f9a17c2635155d428f3fcc9b Mon Sep 17 00:00:00 2001 From: Kris Tonthat Date: Mon, 24 Jun 2024 13:59:03 -0700 Subject: [PATCH 184/256] Deprecate TensorFlow Lite Model Maker Library PiperOrigin-RevId: 646212055 --- tensorflow/lite/g3doc/_book.yaml | 3 +-- .../modify/model_maker/audio_classification.ipynb | 11 ++++++++--- .../modify/model_maker/image_classification.ipynb | 8 ++++++++ .../lite/g3doc/models/modify/model_maker/index.md | 8 ++++++++ .../models/modify/model_maker/object_detection.ipynb | 10 +++++++++- .../models/modify/model_maker/question_answer.ipynb | 8 ++++++++ .../modify/model_maker/speech_recognition.ipynb | 7 +++++++ .../modify/model_maker/text_classification.ipynb | 8 ++++++++ .../models/modify/model_maker/text_searcher.ipynb | 8 ++++++++ 9 files changed, 65 insertions(+), 6 deletions(-) diff --git a/tensorflow/lite/g3doc/_book.yaml b/tensorflow/lite/g3doc/_book.yaml index dbd6f81e7652e6..8b04e13063d0b9 100644 --- a/tensorflow/lite/g3doc/_book.yaml +++ b/tensorflow/lite/g3doc/_book.yaml @@ -236,7 +236,7 @@ upper_tabs: - heading: "Modify models" - title: "Model Maker" - status: experimental + status: deprecated section: - title: "Overview" path: /lite/models/modify/model_maker @@ -253,7 +253,6 @@ upper_tabs: path: /lite/models/modify/model_maker/question_answer - title: "Text search" path: /lite/models/modify/model_maker/text_searcher - status: New - heading: "Audio" - title: "Audio classification" path: /lite/models/modify/model_maker/audio_classification diff --git a/tensorflow/lite/g3doc/models/modify/model_maker/audio_classification.ipynb b/tensorflow/lite/g3doc/models/modify/model_maker/audio_classification.ipynb index e5ea411a6ec0e7..4eadaf1373bf66 100644 --- 
a/tensorflow/lite/g3doc/models/modify/model_maker/audio_classification.ipynb +++ b/tensorflow/lite/g3doc/models/modify/model_maker/audio_classification.ipynb @@ -64,6 +64,13 @@ "id": "BB5k6xNKJ5Xe" }, "source": [ + "\u003caside class=\"warning\"\u003e\n", + " \u003cp\u003e\u003cb\u003eWarning:\u003c/b\u003e The\n", + " \u003ca href=\"https://www.tensorflow.org/lite/models/modify/model_maker\"\u003e\n", + " TensorFlow Lite Model Maker Library\u003c/a\u003e is deprecated and replaced by\n", + " \u003ca href=\"https://ai.google.dev/edge/mediapipe/solutions/model_maker\"\u003e\n", + " MediaPipe Model Maker\u003c/a\u003e.\u003c/p\u003e\n", + "\u003c/aside\u003e\n", "\n", "In this colab notebook, you'll learn how to use the [TensorFlow Lite Model Maker](https://www.tensorflow.org/lite/models/modify/model_maker) to train a custom audio classification model.\n", "\n", @@ -383,9 +390,7 @@ "metadata": { "id": "xEM1aRNKtvHS" }, - "source": [ - "" - ] + "source": [] }, { "cell_type": "code", diff --git a/tensorflow/lite/g3doc/models/modify/model_maker/image_classification.ipynb b/tensorflow/lite/g3doc/models/modify/model_maker/image_classification.ipynb index 4eb4c3d4a733f9..795b9fac43f162 100644 --- a/tensorflow/lite/g3doc/models/modify/model_maker/image_classification.ipynb +++ b/tensorflow/lite/g3doc/models/modify/model_maker/image_classification.ipynb @@ -71,6 +71,14 @@ "id": "m86-Nh4pMHqY" }, "source": [ + "\u003caside class=\"warning\"\u003e\n", + " \u003cp\u003e\u003cb\u003eWarning:\u003c/b\u003e The\n", + " \u003ca href=\"https://www.tensorflow.org/lite/models/modify/model_maker\"\u003e\n", + " TensorFlow Lite Model Maker Library\u003c/a\u003e is deprecated and replaced by\n", + " \u003ca href=\"https://ai.google.dev/edge/mediapipe/solutions/model_maker\"\u003e\n", + " MediaPipe Model Maker\u003c/a\u003e.\u003c/p\u003e\n", + "\u003c/aside\u003e\n", + "\n", "The [TensorFlow Lite Model Maker library](https://www.tensorflow.org/lite/models/modify/model_maker) simplifies 
the process of adapting and converting a TensorFlow neural-network model to particular input data when deploying this model for on-device ML applications.\n", "\n", "This notebook shows an end-to-end example that utilizes this Model Maker library to illustrate the adaption and conversion of a commonly-used image classification model to classify flowers on a mobile device." diff --git a/tensorflow/lite/g3doc/models/modify/model_maker/index.md b/tensorflow/lite/g3doc/models/modify/model_maker/index.md index e1447e4e71bd1f..b09ec6abcdc296 100644 --- a/tensorflow/lite/g3doc/models/modify/model_maker/index.md +++ b/tensorflow/lite/g3doc/models/modify/model_maker/index.md @@ -1,5 +1,13 @@ # TensorFlow Lite Model Maker + + ## Overview The TensorFlow Lite Model Maker library simplifies the process of training a diff --git a/tensorflow/lite/g3doc/models/modify/model_maker/object_detection.ipynb b/tensorflow/lite/g3doc/models/modify/model_maker/object_detection.ipynb index 0946c2957d579d..037f9d3a3bfce4 100644 --- a/tensorflow/lite/g3doc/models/modify/model_maker/object_detection.ipynb +++ b/tensorflow/lite/g3doc/models/modify/model_maker/object_detection.ipynb @@ -68,13 +68,21 @@ "id": "sr3q-gvm3cI8" }, "source": [ + "\u003caside class=\"warning\"\u003e\n", + " \u003cp\u003e\u003cb\u003eWarning:\u003c/b\u003e The\n", + " \u003ca href=\"https://www.tensorflow.org/lite/models/modify/model_maker\"\u003e\n", + " TensorFlow Lite Model Maker Library\u003c/a\u003e is deprecated and replaced by\n", + " \u003ca href=\"https://ai.google.dev/edge/mediapipe/solutions/model_maker\"\u003e\n", + " MediaPipe Model Maker\u003c/a\u003e.\u003c/p\u003e\n", + "\u003c/aside\u003e\n", + "\n", "In this colab notebook, you'll learn how to use the [TensorFlow Lite Model Maker](https://www.tensorflow.org/lite/models/modify/model_maker) library to train a custom object detection model capable of detecting salads within images on a mobile device.\n", "\n", "The Model Maker library uses *transfer 
learning* to simplify the process of training a TensorFlow Lite model using a custom dataset. Retraining a TensorFlow Lite model with your own custom dataset reduces the amount of training data required and will shorten the training time.\n", "\n", "You'll use the publicly available *Salads* dataset, which was created from the [Open Images Dataset V4](https://storage.googleapis.com/openimages/web/index.html).\n", "\n", - "Each image in the dataset \u0008contains objects labeled as one of the following classes:\n", + "Each image in the dataset \bcontains objects labeled as one of the following classes:\n", "* Baked Good\n", "* Cheese\n", "* Salad\n", diff --git a/tensorflow/lite/g3doc/models/modify/model_maker/question_answer.ipynb b/tensorflow/lite/g3doc/models/modify/model_maker/question_answer.ipynb index 706afac6c24ba7..fc6aae59cbfc6b 100644 --- a/tensorflow/lite/g3doc/models/modify/model_maker/question_answer.ipynb +++ b/tensorflow/lite/g3doc/models/modify/model_maker/question_answer.ipynb @@ -68,6 +68,14 @@ "id": "sr3q-gvm3cI8" }, "source": [ + "\u003caside class=\"warning\"\u003e\n", + " \u003cp\u003e\u003cb\u003eWarning:\u003c/b\u003e The\n", + " \u003ca href=\"https://www.tensorflow.org/lite/models/modify/model_maker\"\u003e\n", + " TensorFlow Lite Model Maker Library\u003c/a\u003e is deprecated and replaced by\n", + " \u003ca href=\"https://ai.google.dev/edge/mediapipe/solutions/model_maker\"\u003e\n", + " MediaPipe Model Maker\u003c/a\u003e.\u003c/p\u003e\n", + "\u003c/aside\u003e\n", + "\n", "The [TensorFlow Lite Model Maker library](https://www.tensorflow.org/lite/models/modify/model_maker) simplifies the process of adapting and converting a TensorFlow model to particular input data when deploying this model for on-device ML applications.\n", "\n", "This notebook shows an end-to-end example that utilizes the Model Maker library to illustrate the adaptation and conversion of a commonly-used question answer model for question answer task." 
diff --git a/tensorflow/lite/g3doc/models/modify/model_maker/speech_recognition.ipynb b/tensorflow/lite/g3doc/models/modify/model_maker/speech_recognition.ipynb index 765b957e75c75b..096bedc5118b55 100644 --- a/tensorflow/lite/g3doc/models/modify/model_maker/speech_recognition.ipynb +++ b/tensorflow/lite/g3doc/models/modify/model_maker/speech_recognition.ipynb @@ -69,6 +69,13 @@ "id": "BB5k6xNKJ5Xe" }, "source": [ + "\u003caside class=\"warning\"\u003e\n", + " \u003cp\u003e\u003cb\u003eWarning:\u003c/b\u003e The\n", + " \u003ca href=\"https://www.tensorflow.org/lite/models/modify/model_maker\"\u003e\n", + " TensorFlow Lite Model Maker Library\u003c/a\u003e is deprecated and replaced by\n", + " \u003ca href=\"https://ai.google.dev/edge/mediapipe/solutions/model_maker\"\u003e\n", + " MediaPipe Model Maker\u003c/a\u003e.\u003c/p\u003e\n", + "\u003c/aside\u003e\n", "\n", "In this colab notebook, you'll learn how to use the [TensorFlow Lite Model Maker](https://www.tensorflow.org/lite/models/modify/model_maker) to train a speech recognition model that can classify spoken words or short phrases using one-second sound samples. The Model Maker library uses transfer learning to retrain an existing TensorFlow model with a new dataset, which reduces the amount of sample data and time required for training. 
\n", "\n", diff --git a/tensorflow/lite/g3doc/models/modify/model_maker/text_classification.ipynb b/tensorflow/lite/g3doc/models/modify/model_maker/text_classification.ipynb index 881945469b0049..83947fd12b3460 100644 --- a/tensorflow/lite/g3doc/models/modify/model_maker/text_classification.ipynb +++ b/tensorflow/lite/g3doc/models/modify/model_maker/text_classification.ipynb @@ -68,6 +68,14 @@ "id": "sr3q-gvm3cI8" }, "source": [ + "\u003caside class=\"warning\"\u003e\n", + " \u003cp\u003e\u003cb\u003eWarning:\u003c/b\u003e The\n", + " \u003ca href=\"https://www.tensorflow.org/lite/models/modify/model_maker\"\u003e\n", + " TensorFlow Lite Model Maker Library\u003c/a\u003e is deprecated and replaced by\n", + " \u003ca href=\"https://ai.google.dev/edge/mediapipe/solutions/model_maker\"\u003e\n", + " MediaPipe Model Maker\u003c/a\u003e.\u003c/p\u003e\n", + "\u003c/aside\u003e\n", + "\n", "The [TensorFlow Lite Model Maker library](https://www.tensorflow.org/lite/models/modify/model_maker) simplifies the process of adapting and converting a TensorFlow model to particular input data when deploying this model for on-device ML applications.\n", "\n", "This notebook shows an end-to-end example that utilizes the Model Maker library to illustrate the adaptation and conversion of a commonly-used text classification model to classify movie reviews on a mobile device. The text classification model classifies text into predefined categories. The inputs should be preprocessed text and the outputs are the probabilities of the categories. The dataset used in this tutorial are positive and negative movie reviews." 
diff --git a/tensorflow/lite/g3doc/models/modify/model_maker/text_searcher.ipynb b/tensorflow/lite/g3doc/models/modify/model_maker/text_searcher.ipynb index 6f595a830426a9..9c1b9e518864c5 100644 --- a/tensorflow/lite/g3doc/models/modify/model_maker/text_searcher.ipynb +++ b/tensorflow/lite/g3doc/models/modify/model_maker/text_searcher.ipynb @@ -70,6 +70,14 @@ "id": "c2sdIlXEVPZR" }, "source": [ + "\u003caside class=\"warning\"\u003e\n", + " \u003cp\u003e\u003cb\u003eWarning:\u003c/b\u003e The\n", + " \u003ca href=\"https://www.tensorflow.org/lite/models/modify/model_maker\"\u003e\n", + " TensorFlow Lite Model Maker Library\u003c/a\u003e is deprecated and replaced by\n", + " \u003ca href=\"https://ai.google.dev/edge/mediapipe/solutions/model_maker\"\u003e\n", + " MediaPipe Model Maker\u003c/a\u003e.\u003c/p\u003e\n", + "\u003c/aside\u003e\n", + "\n", "In this colab notebook, you can learn how to use the [TensorFlow Lite Model Maker](https://www.tensorflow.org/lite/models/modify/model_maker) library to create a TFLite Searcher model. You can use a text Searcher model to build Semantic Search or Smart Reply for your app. This type of model lets you take a text query and search for the most related entries in a text dataset, such as a database of web pages. The model returns a list of the smallest distance scoring entries in the dataset, including metadata you specify, such as URL, page title, or other text entry identifiers. After building this, you can deploy it onto devices (e.g. Android) using [Task Library Searcher API](https://www.tensorflow.org/lite/inference_with_metadata/task_library/text_searcher) to run inference with just a few lines of code.\n", "\n", "This tutorial leverages CNN/DailyMail dataset as an instance to create the TFLite Searcher model. You can try with your own dataset with the compatible input comma separated value (CSV) format." From 3efa90487df90fef4050188a7a8c9bf6130c8343 Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Mon, 24 Jun 2024 15:06:44 -0700 Subject: [PATCH 185/256] Adding some VLOGs to async collective creator and merger for debugging. PiperOrigin-RevId: 646233999 --- third_party/xla/xla/service/BUILD | 1 + third_party/xla/xla/service/async_collective_creator.cc | 6 ++++++ 2 files changed, 7 insertions(+) diff --git a/third_party/xla/xla/service/BUILD b/third_party/xla/xla/service/BUILD index f78cd4e18038db..17c8f39a482604 100644 --- a/third_party/xla/xla/service/BUILD +++ b/third_party/xla/xla/service/BUILD @@ -119,6 +119,7 @@ cc_library( "//xla:util", "//xla/hlo/ir:hlo", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log", "@local_tsl//tsl/platform:errors", ], ) diff --git a/third_party/xla/xla/service/async_collective_creator.cc b/third_party/xla/xla/service/async_collective_creator.cc index 58ab8a56100903..d46f17fd37eb6b 100644 --- a/third_party/xla/xla/service/async_collective_creator.cc +++ b/third_party/xla/xla/service/async_collective_creator.cc @@ -15,10 +15,12 @@ limitations under the License. 
#include "xla/service/async_collective_creator.h" +#include #include #include #include "absl/container/flat_hash_map.h" +#include "absl/log/log.h" #include "xla/frontend_attributes.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" @@ -226,6 +228,7 @@ absl::StatusOr AsyncCollectiveCreator::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { bool changed = false; + int64_t collectives_replaced = 0; for (HloComputation* computation : module->MakeNonfusionComputations(execution_threads)) { std::vector supported_collectives = @@ -235,8 +238,11 @@ absl::StatusOr AsyncCollectiveCreator::Run( } TF_ASSIGN_OR_RETURN(bool comp_changed, ReplaceCollectives(computation, supported_collectives)); + collectives_replaced += supported_collectives.size(); changed |= comp_changed; } + VLOG(1) << "Replaced " << collectives_replaced + << " sync collectives with async versions."; return changed; } From 5d94a25b7e789705db87f67b880f019af71f7ace Mon Sep 17 00:00:00 2001 From: David Dunleavy Date: Mon, 24 Jun 2024 15:25:17 -0700 Subject: [PATCH 186/256] Change visibilities to public in `tsl/lib` PiperOrigin-RevId: 646239154 --- .../xla/third_party/tsl/tsl/lib/histogram/BUILD | 6 +----- .../xla/third_party/tsl/tsl/lib/monitoring/BUILD | 9 ++------- third_party/xla/third_party/tsl/tsl/lib/random/BUILD | 7 ++----- third_party/xla/third_party/tsl/tsl/lib/strings/BUILD | 10 +--------- 4 files changed, 6 insertions(+), 26 deletions(-) diff --git a/third_party/xla/third_party/tsl/tsl/lib/histogram/BUILD b/third_party/xla/third_party/tsl/tsl/lib/histogram/BUILD index 7d6f66ad8497ed..4de34f8e390755 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/histogram/BUILD +++ b/third_party/xla/third_party/tsl/tsl/lib/histogram/BUILD @@ -18,11 +18,7 @@ cc_library( name = "histogram", srcs = ["histogram.cc"], hdrs = ["histogram.h"], - visibility = internal_visibility([ - "//learning/brain/google/monitoring:__pkg__", - 
"//tensorflow/core/lib/histogram:__pkg__", - "//tsl/lib/monitoring:__pkg__", - ]), + visibility = ["//visibility:public"], deps = [ "//tsl/platform:logging", "//tsl/platform:macros", diff --git a/third_party/xla/third_party/tsl/tsl/lib/monitoring/BUILD b/third_party/xla/third_party/tsl/tsl/lib/monitoring/BUILD index ab00117320cc5b..0d8129444005fa 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/monitoring/BUILD +++ b/third_party/xla/third_party/tsl/tsl/lib/monitoring/BUILD @@ -95,9 +95,7 @@ cc_library( cc_library( name = "metric_def", hdrs = ["metric_def.h"], - visibility = internal_visibility([ - "//tensorflow/core:__subpackages__", - ]), + visibility = ["//visibility:public"], deps = [ ":types", "//tsl/platform:stringpiece", @@ -110,10 +108,7 @@ cc_library( name = "collection_registry", srcs = ["collection_registry.cc"], hdrs = ["collection_registry.h"], - visibility = internal_visibility([ - "//tensorflow/core:__subpackages__", - "//tensorflow_serving/model_servers:__pkg__", - ]), + visibility = ["//visibility:public"], deps = [ ":collected_metrics", ":metric_def", diff --git a/third_party/xla/third_party/tsl/tsl/lib/random/BUILD b/third_party/xla/third_party/tsl/tsl/lib/random/BUILD index d9da60f5887004..c64a1332e76ff8 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/random/BUILD +++ b/third_party/xla/third_party/tsl/tsl/lib/random/BUILD @@ -35,7 +35,7 @@ cc_library( "random_distributions.h", "simple_philox.h", ], - visibility = internal_visibility(default_visibility), + visibility = ["//visibility:public"], deps = [ ":exact_uniform_int", ":philox_random", @@ -64,10 +64,7 @@ cc_library( name = "philox_random", hdrs = ["philox_random.h"], compatible_with = get_compatible_with_portable(), - visibility = internal_visibility([ - "//tensorflow/core/lib/random:__pkg__", - "//tensorflow/lite:__subpackages__", - ]), + visibility = ["//visibility:public"], ) cc_library( diff --git a/third_party/xla/third_party/tsl/tsl/lib/strings/BUILD 
b/third_party/xla/third_party/tsl/tsl/lib/strings/BUILD index f529fb395957a7..699965e401c526 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/strings/BUILD +++ b/third_party/xla/third_party/tsl/tsl/lib/strings/BUILD @@ -11,15 +11,7 @@ cc_library( name = "proto_serialization", srcs = ["proto_serialization.cc"], hdrs = ["proto_serialization.h"], - visibility = internal_visibility([ - "@local_xla//xla/pjrt:__subpackages__", - "@local_xla//xla/python:__pkg__", - "@local_xla//xla/service:__pkg__", - "@local_xla//xla/stream_executor:__pkg__", - "//tensorflow/core/lib/strings:__pkg__", - "//tensorflow/compiler/tf2xla/kernels:__pkg__", - "//tensorflow/core/util/autotune_maps:__pkg__", - ]), + visibility = ["//visibility:public"], deps = [ "//tsl/lib/gtl:inlined_vector", "//tsl/platform:hash", From 88a5c0363705a67fea4ac4e1a3eb32e394d675c9 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Mon, 24 Jun 2024 15:54:21 -0700 Subject: [PATCH 187/256] [xla:cpu] Enable select-and-scatter emitter PiperOrigin-RevId: 646247814 --- third_party/xla/xla/service/cpu/BUILD | 1 + .../xla/xla/service/cpu/benchmarks/BUILD | 16 ++++ .../select_and_scatter_benchmark_test.cc | 80 +++++++++++++++++++ third_party/xla/xla/service/cpu/ir_emitter.cc | 16 +++- third_party/xla/xla/service/cpu/ir_emitter.h | 13 +++ .../xla/xla/service/cpu/ir_emitter2.cc | 18 +++++ third_party/xla/xla/service/cpu/ir_emitter2.h | 4 + .../xla/xla/service/cpu/thunk_emitter.cc | 14 ++++ .../xla/xla/service/cpu/thunk_emitter.h | 3 + 9 files changed, 162 insertions(+), 3 deletions(-) create mode 100644 third_party/xla/xla/service/cpu/benchmarks/select_and_scatter_benchmark_test.cc diff --git a/third_party/xla/xla/service/cpu/BUILD b/third_party/xla/xla/service/cpu/BUILD index e7350a66d091e0..7ff98b5e32a330 100644 --- a/third_party/xla/xla/service/cpu/BUILD +++ b/third_party/xla/xla/service/cpu/BUILD @@ -724,6 +724,7 @@ cc_library( "@com_google_absl//absl/container:flat_hash_set", 
"@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/meta:type_traits", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", diff --git a/third_party/xla/xla/service/cpu/benchmarks/BUILD b/third_party/xla/xla/service/cpu/benchmarks/BUILD index 0a6e049b3a8572..5b2594f5cc32d0 100644 --- a/third_party/xla/xla/service/cpu/benchmarks/BUILD +++ b/third_party/xla/xla/service/cpu/benchmarks/BUILD @@ -164,3 +164,19 @@ xla_cc_test( "@local_tsl//tsl/platform:test_main", ], ) + +xla_cc_test( + name = "select_and_scatter_benchmark_test", + srcs = ["select_and_scatter_benchmark_test.cc"], + deps = [ + ":hlo_benchmark_runner", + "//xla:literal", + "//xla:literal_util", + "//xla:shape_util", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:test_benchmark", + "@local_tsl//tsl/platform:test_main", + ], +) diff --git a/third_party/xla/xla/service/cpu/benchmarks/select_and_scatter_benchmark_test.cc b/third_party/xla/xla/service/cpu/benchmarks/select_and_scatter_benchmark_test.cc new file mode 100644 index 00000000000000..fdfdc7b00b1882 --- /dev/null +++ b/third_party/xla/xla/service/cpu/benchmarks/select_and_scatter_benchmark_test.cc @@ -0,0 +1,80 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include +#include +#include + +#include "absl/strings/str_cat.h" +#include "absl/types/span.h" +#include "xla/literal.h" +#include "xla/literal_util.h" +#include "xla/service/cpu/benchmarks/hlo_benchmark_runner.h" +#include "xla/shape_util.h" +#include "tsl/platform/logging.h" +#include "tsl/platform/test_benchmark.h" + +namespace xla::cpu { + +static void BM_SelectAndScatterF32(benchmark::State& state) { + int64_t d0 = state.range(0); + int64_t d1 = (d0 - 1) / 2; + + std::string_view hlo = R"( + HloModule select_and_scatter_f32_$d0 + + ge { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT compare = pred[] compare(a, b), direction=GE + } + + add { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + ROOT add = f32[] add(p0, p1) + } + + ENTRY e { + p0 = f32[$d0,$d0] parameter(0) + p1 = f32[$d1,$d1] parameter(1) + p2 = f32[] parameter(2) + ROOT sas = f32[$d0,$d0] select-and-scatter(p0, p1, p2), + window={size=3x3 stride=2x2 pad=0_0x0_0}, select=ge, scatter=add + } + )"; + + std::minstd_rand0 engine; + + auto p0 = *LiteralUtil::CreateRandomLiteral( + ShapeUtil::MakeShape(F32, {d0, d0}), &engine, 1.0f, 0.1f); + auto p1 = *LiteralUtil::CreateRandomLiteral( + ShapeUtil::MakeShape(F32, {d1, d1}), &engine, 1.0f, 0.1f); + auto p2 = LiteralUtil::CreateR0(1.0f); + + std::vector args = {&p0, &p1, &p2}; + CHECK_OK( + RunHloBenchmark(state, hlo, args, + {{"$d0", absl::StrCat(d0)}, {"$d1", absl::StrCat(d1)}})); +} + +BENCHMARK(BM_SelectAndScatterF32) + ->MeasureProcessCPUTime() + ->Arg(128) + ->Arg(256) + ->Arg(512); + +} // namespace xla::cpu diff --git a/third_party/xla/xla/service/cpu/ir_emitter.cc b/third_party/xla/xla/service/cpu/ir_emitter.cc index c33219914d3ff5..0af8d8f5d23cd7 100644 --- a/third_party/xla/xla/service/cpu/ir_emitter.cc +++ b/third_party/xla/xla/service/cpu/ir_emitter.cc @@ -738,6 +738,19 @@ absl::Status IrEmitter::HandleSelectAndScatter( 
CHECK_EQ(select_and_scatter->operand_count(), 3); const auto operand = select_and_scatter->operand(0); const auto source = select_and_scatter->operand(1); + + return HandleSelectAndScatter(select_and_scatter, GetIrArrayFor(operand), + GetIrArrayFor(source), + GetIrArrayFor(select_and_scatter)); +} + +absl::Status IrEmitter::HandleSelectAndScatter( + HloInstruction* select_and_scatter, const llvm_ir::IrArray& operand_array, + const llvm_ir::IrArray& source_array, + const llvm_ir::IrArray& output_array) { + CHECK_EQ(select_and_scatter->operand_count(), 3); + const auto operand = select_and_scatter->operand(0); + const auto source = select_and_scatter->operand(1); const auto init_value = select_and_scatter->operand(2); const Window& window = select_and_scatter->window(); PrimitiveType operand_element_type = operand->shape().element_type(); @@ -849,7 +862,6 @@ absl::Status IrEmitter::HandleSelectAndScatter( Store(operand_index[i], selected_index_address_slot); } }; - llvm_ir::IrArray operand_array(GetIrArrayFor(operand)); llvm_ir::IrArray::Index operand_index( operand_multi_index, operand_array.GetShape(), b_.getInt64Ty()); llvm::Value* operand_data = @@ -899,10 +911,8 @@ absl::Status IrEmitter::HandleSelectAndScatter( selected_index_address->getAllocatedType(), gep_index); selected_multi_index.push_back(Load(type, selected_index_address_slot)); } - llvm_ir::IrArray source_array(GetIrArrayFor(source)); llvm::Value* source_value = source_array.EmitReadArrayElement(source_index, &b_); - llvm_ir::IrArray output_array(GetIrArrayFor(select_and_scatter)); llvm_ir::IrArray::Index selected_index( selected_multi_index, output_array.GetShape(), source_index.GetType()); llvm::Value* output_value = diff --git a/third_party/xla/xla/service/cpu/ir_emitter.h b/third_party/xla/xla/service/cpu/ir_emitter.h index 1e05edc0bccf9f..9988a260657e1f 100644 --- a/third_party/xla/xla/service/cpu/ir_emitter.h +++ b/third_party/xla/xla/service/cpu/ir_emitter.h @@ -27,6 +27,7 @@ limitations under 
the License. #include #include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" @@ -39,6 +40,7 @@ limitations under the License. #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" #include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" #include "xla/service/buffer_assignment.h" #include "xla/service/cpu/ir_function.h" #include "xla/service/cpu/target_machine_features.h" @@ -53,6 +55,10 @@ limitations under the License. namespace xla { namespace cpu { + +// Forward declare emitter for XLA:CPU thunks. +class IrEmitter2; + // This class is the top-level API for the XLA HLO --> LLVM IR compiler. It // implements the DfsHloVisitor interface and emits HLO computations as LLVM IR // functions. @@ -160,6 +166,8 @@ class IrEmitter : public DfsHloVisitorWithDefault, } protected: + friend class IrEmitter2; + // // The following methods implement the DfsHloVisitor interface. // @@ -217,6 +225,11 @@ class IrEmitter : public DfsHloVisitorWithDefault, absl::Status Preprocess(HloInstruction* hlo) override; absl::Status Postprocess(HloInstruction* hlo) override; + absl::Status HandleSelectAndScatter(HloInstruction* select_and_scatter, + const llvm_ir::IrArray& operand_array, + const llvm_ir::IrArray& source_array, + const llvm_ir::IrArray& output_array); + // A convenient helper for calling BufferAssignment::GetUniqueSlice. 
BufferAllocation::Slice GetAllocationSlice( const HloInstruction& hlo, const ShapeIndex& index = {}) const { diff --git a/third_party/xla/xla/service/cpu/ir_emitter2.cc b/third_party/xla/xla/service/cpu/ir_emitter2.cc index 80c6db985d623b..ba5797c628b2a2 100644 --- a/third_party/xla/xla/service/cpu/ir_emitter2.cc +++ b/third_party/xla/xla/service/cpu/ir_emitter2.cc @@ -415,6 +415,24 @@ absl::StatusOr IrEmitter2::EmitDotFusionHostKernel( se::ThreadDim()}); } +// Emits a host kernel for the given select-and-scatter instruction. +absl::StatusOr +IrEmitter2::EmitSelectAndScatterHostKernel(const HloInstruction* instr) { + KernelPrototype kernel_prototype = EmitKernelPrototype(instr); + + llvm_ir::IrArray operand_array = kernel_prototype.arguments[0]; + llvm_ir::IrArray source_array = kernel_prototype.arguments[1]; + llvm_ir::IrArray output_array = kernel_prototype.results[0]; + + TF_RETURN_IF_ERROR(nested_ir_emitter_->HandleSelectAndScatter( + const_cast(instr), operand_array, source_array, + output_array)); + + return kernels_.emplace_back( + KernelInfo{kernel_prototype.function->getName().str(), se::BlockDim(), + se::ThreadDim()}); +} + //===----------------------------------------------------------------------===// // Building HostKernel prototypes. //===----------------------------------------------------------------------===// diff --git a/third_party/xla/xla/service/cpu/ir_emitter2.h b/third_party/xla/xla/service/cpu/ir_emitter2.h index 8ff46cb1c31d60..503409a07d357d 100644 --- a/third_party/xla/xla/service/cpu/ir_emitter2.h +++ b/third_party/xla/xla/service/cpu/ir_emitter2.h @@ -123,6 +123,10 @@ class IrEmitter2 { absl::StatusOr EmitDotFusionHostKernel( const HloFusionInstruction* fusion); + // Emits a host kernel for the given select-and-scatter instruction. + absl::StatusOr EmitSelectAndScatterHostKernel( + const HloInstruction* instr); + // Emits a host kernel prototype and prepares function for emitting kernel // body into it. 
KernelPrototype EmitKernelPrototype(std::string_view name, diff --git a/third_party/xla/xla/service/cpu/thunk_emitter.cc b/third_party/xla/xla/service/cpu/thunk_emitter.cc index 6d20fb2a1c1184..e9c8807fc3f887 100644 --- a/third_party/xla/xla/service/cpu/thunk_emitter.cc +++ b/third_party/xla/xla/service/cpu/thunk_emitter.cc @@ -209,6 +209,9 @@ absl::StatusOr ThunkEmitter::EmitHloInstruction( case HloOpcode::kXor: return EmitElementalKernelThunk(instruction); + case HloOpcode::kSelectAndScatter: + return EmitSelectAndScatterThunk(instruction); + // ReplicaId and PartitionId identify the location of the current device in // a logical grid of communicating devices. case HloOpcode::kReplicaId: @@ -790,6 +793,17 @@ absl::StatusOr ThunkEmitter::EmitCustomCallThunk( backend_config, version); } +absl::StatusOr ThunkEmitter::EmitSelectAndScatterThunk( + const HloInstruction* instruction) { + TF_ASSIGN_OR_RETURN(auto kernel, + ir_emitter_.EmitSelectAndScatterHostKernel(instruction)); + TF_ASSIGN_OR_RETURN(auto buffers, GetHostKernelAllocationSlices(instruction)); + + return ThunkSequence::Of(ThunkInfo(instruction), + buffers.arguments, buffers.results, + kernel.name, kernel.thread_dims); +} + absl::StatusOr ThunkEmitter::GetHostKernelAllocationSlices(const HloInstruction* instruction) { HostKernelAllocationSlices slices; diff --git a/third_party/xla/xla/service/cpu/thunk_emitter.h b/third_party/xla/xla/service/cpu/thunk_emitter.h index adaab639b3d397..f536beea0b6302 100644 --- a/third_party/xla/xla/service/cpu/thunk_emitter.h +++ b/third_party/xla/xla/service/cpu/thunk_emitter.h @@ -128,6 +128,9 @@ class ThunkEmitter { absl::StatusOr EmitCustomCallThunk( const HloInstruction* instruction); + absl::StatusOr EmitSelectAndScatterThunk( + const HloInstruction* instruction); + // Returns the list of buffer allocation slices assigned to the given // instruction that will be passed to the host kernel as arguments: a // flattened list of all the leaf buffers for all operands and 
result. We do From 3f7f0be0eded089f3615a9d14fc5b0a2c5a7ad36 Mon Sep 17 00:00:00 2001 From: Alexander Pivovarov Date: Mon, 24 Jun 2024 15:59:06 -0700 Subject: [PATCH 188/256] PR #14074: Add sqrt(x) * sqrt(x) => x to algsimp Imported from GitHub PR https://github.com/openxla/xla/pull/14074 This PR adds pattern `sqrt(x) * sqrt(x) => x , for x >= 0` Validation - the following checks are valid both before and after simplification.: ```c If x == 0 the result is 0 If x > 0 the result > 0 If x is inf the result is inf If x is nan the result is nan ``` Related PR - https://github.com/openxla/xla/pull/13771 Copybara import of the project: -- b47b88ad5d726fb5227ce52e8a04572967efaef2 by Alexander Pivovarov : Add sqrt(x) * sqrt(x) => x to algsimp Merging this change closes #14074 PiperOrigin-RevId: 646249151 --- .../xla/xla/service/algebraic_simplifier.cc | 7 ++++ .../xla/service/algebraic_simplifier_test.cc | 32 +++++++++++++++++++ 2 files changed, 39 insertions(+) diff --git a/third_party/xla/xla/service/algebraic_simplifier.cc b/third_party/xla/xla/service/algebraic_simplifier.cc index 9a8b616b20549e..6be5e65adf7fc3 100644 --- a/third_party/xla/xla/service/algebraic_simplifier.cc +++ b/third_party/xla/xla/service/algebraic_simplifier.cc @@ -4582,6 +4582,13 @@ absl::Status AlgebraicSimplifierVisitor::HandleMultiply( HloInstruction::CreateUnary(multiply->shape(), HloOpcode::kExp, add)); } + VLOG(10) << "trying transform [sqrt(x) * sqrt(x) => x], for x >= 0 " + << multiply->ToString(); + if (Match(multiply, m::Multiply(m::Sqrt(m::Op(&a)), m::Sqrt(m::Op(&a)))) && + IsNonNegative(a, options_)) { + return ReplaceInstruction(multiply, a); + } + VLOG(10) << "trying transform [rsqrt(B) * rsqrt(B) => 1/B], for B >= 0 " << multiply->ToString(); HloInstruction* b; diff --git a/third_party/xla/xla/service/algebraic_simplifier_test.cc b/third_party/xla/xla/service/algebraic_simplifier_test.cc index 9b2fd29d36e723..d37bb196383b98 100644 --- 
a/third_party/xla/xla/service/algebraic_simplifier_test.cc +++ b/third_party/xla/xla/service/algebraic_simplifier_test.cc @@ -9203,6 +9203,38 @@ TEST_F(AlgebraicSimplifierTest, RsqrtDivide_NegativeTestCase) { ASSERT_FALSE(AlgebraicSimplifier(default_options_).Run(m.get()).value()); } +// sqrt(x) * sqrt(x) => x, for x >= 0 +TEST_F(AlgebraicSimplifierTest, MultiplySelfSqrt) { + const char* kModuleStr = R"( + HloModule m + test { + p0 = f32[1,32] parameter(0) + abs0 = f32[1,32] abs(p0) + sqrt = f32[1,32] sqrt(abs0) + ROOT mul = f32[1,32] multiply(sqrt, sqrt) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).value()); + ASSERT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::Abs(m::Parameter(0)))); +} + +// sqrt(x) * sqrt(x) ≠> x +// if x is arbitrary number - no simplification +TEST_F(AlgebraicSimplifierTest, MultiplySelfSqrt_NegativeTestCase) { + const char* kModuleStr = R"( + HloModule m + test { + p0 = f32[1,32] parameter(0) + sqrt = f32[1,32] sqrt(p0) + ROOT mul = f32[1,32] multiply(sqrt, sqrt) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_FALSE(AlgebraicSimplifier(default_options_).Run(m.get()).value()); +} + // rsqrt(x) * rsqrt(x) -> 1/x, for x >= 0 TEST_F(AlgebraicSimplifierTest, MultiplySelfRsqrt) { const char* kModuleStr = R"( From bd594187900c0d898c54ad56e96b486a729946d7 Mon Sep 17 00:00:00 2001 From: Kyle Lucke Date: Mon, 24 Jun 2024 16:23:20 -0700 Subject: [PATCH 189/256] Use tsl/platform/statusor.h in status_macros.h instead of xla/statusor.h PiperOrigin-RevId: 646256201 --- third_party/xla/xla/BUILD | 3 ++- .../xla/xla/service/cpu/onednn_convolution_rewriter.h | 2 +- third_party/xla/xla/status_macros.h | 2 +- third_party/xla/xla/statusor.h | 10 ++-------- 4 files changed, 6 insertions(+), 11 deletions(-) diff --git a/third_party/xla/xla/BUILD b/third_party/xla/xla/BUILD index 
f81390565195b0..496c3c1c3ffdf6 100644 --- a/third_party/xla/xla/BUILD +++ b/third_party/xla/xla/BUILD @@ -271,7 +271,6 @@ cc_library( hdrs = ["status_macros.h"], visibility = ["//visibility:public"], deps = [ - ":statusor", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", @@ -279,6 +278,7 @@ cc_library( "@local_tsl//tsl/platform:macros", "@local_tsl//tsl/platform:stacktrace", "@local_tsl//tsl/platform:status", + "@local_tsl//tsl/platform:statusor", ], ) @@ -315,6 +315,7 @@ cc_library( hdrs = [ "statusor.h", ], + deprecation = "Use @com_google_absl//absl/status:statusor and/or //third_party/tensorflow/tsl/platform:statusor instead.", linkopts = select({ "//xla/tsl:freebsd": ["-lexecinfo"], "//conditions:default": [], diff --git a/third_party/xla/xla/service/cpu/onednn_convolution_rewriter.h b/third_party/xla/xla/service/cpu/onednn_convolution_rewriter.h index 334db7f60b8356..312221f212d7d9 100644 --- a/third_party/xla/xla/service/cpu/onednn_convolution_rewriter.h +++ b/third_party/xla/xla/service/cpu/onednn_convolution_rewriter.h @@ -35,7 +35,7 @@ class OneDnnConvolutionRewriter : public HloModulePass { } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; diff --git a/third_party/xla/xla/status_macros.h b/third_party/xla/xla/status_macros.h index d62bfa276a4c28..3f42d962e4b307 100644 --- a/third_party/xla/xla/status_macros.h +++ b/third_party/xla/xla/status_macros.h @@ -25,9 +25,9 @@ limitations under the License. 
#include "absl/base/optimization.h" #include "absl/status/status.h" -#include "xla/statusor.h" #include "tsl/platform/macros.h" #include "tsl/platform/status.h" +#include "tsl/platform/statusor.h" namespace xla { namespace status_macros { diff --git a/third_party/xla/xla/statusor.h b/third_party/xla/xla/statusor.h index 3a7d6da42897f7..67dce52566d5b5 100644 --- a/third_party/xla/xla/statusor.h +++ b/third_party/xla/xla/statusor.h @@ -15,13 +15,7 @@ limitations under the License. #ifndef XLA_STATUSOR_H_ #define XLA_STATUSOR_H_ -#include "tsl/platform/statusor.h" - -namespace xla { - -// Use steam_executor's absl::StatusOr so we don't duplicate code. -using tsl::StatusOr; // TENSORFLOW_STATUS_OK - -} // namespace xla +// This file is deprecated. Use absl/status/statusor.h and/or +// tsl/platform/statusor.h instead. #endif // XLA_STATUSOR_H_ From 54312e394681006d4c2b2c1d9a3b5eaa5ef0ff76 Mon Sep 17 00:00:00 2001 From: Kyle Lucke Date: Mon, 24 Jun 2024 16:40:51 -0700 Subject: [PATCH 190/256] Remove the last few uses of xla/status.h and xla/statusor.h from XLA. 
PiperOrigin-RevId: 646260993 --- third_party/xla/xla/service/gpu/BUILD | 2 -- 1 file changed, 2 deletions(-) diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index d054ac443cc08c..4a4efce3557eed 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -6101,8 +6101,6 @@ cc_library( deps = [ ":backend_configs_cc", "//xla:literal_util", - "//xla:status", - "//xla:statusor", "//xla:util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", From 875cdd3a278440b260fed5ee24b73c87eefea1de Mon Sep 17 00:00:00 2001 From: Gunhyun Park Date: Mon, 24 Jun 2024 17:09:43 -0700 Subject: [PATCH 191/256] Integrate StableHLO at openxla/stablehlo@b6129ded PiperOrigin-RevId: 646269058 --- third_party/stablehlo/workspace.bzl | 4 ++-- third_party/xla/third_party/stablehlo/workspace.bzl | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/third_party/stablehlo/workspace.bzl b/third_party/stablehlo/workspace.bzl index 4a70795435d0f9..2202e6f9d36bb0 100644 --- a/third_party/stablehlo/workspace.bzl +++ b/third_party/stablehlo/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): # LINT.IfChange - STABLEHLO_COMMIT = "61826746d640f6342856af5ac71dabe6fbf37ff5" - STABLEHLO_SHA256 = "fece99979b885686438068a786895dcd9a559fdc86072178e42181404022d483" + STABLEHLO_COMMIT = "b6129dedc3799fa7714a22dc03b645db7b46486b" + STABLEHLO_SHA256 = "92498ea51363d79c89c377ef1723078e258b09b6af006103f327443e3e6ff2f8" # LINT.ThenChange(Google-internal path) tf_http_archive( diff --git a/third_party/xla/third_party/stablehlo/workspace.bzl b/third_party/xla/third_party/stablehlo/workspace.bzl index 4a70795435d0f9..2202e6f9d36bb0 100644 --- a/third_party/xla/third_party/stablehlo/workspace.bzl +++ b/third_party/xla/third_party/stablehlo/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): # LINT.IfChange - 
STABLEHLO_COMMIT = "61826746d640f6342856af5ac71dabe6fbf37ff5" - STABLEHLO_SHA256 = "fece99979b885686438068a786895dcd9a559fdc86072178e42181404022d483" + STABLEHLO_COMMIT = "b6129dedc3799fa7714a22dc03b645db7b46486b" + STABLEHLO_SHA256 = "92498ea51363d79c89c377ef1723078e258b09b6af006103f327443e3e6ff2f8" # LINT.ThenChange(Google-internal path) tf_http_archive( From 9f2a24fe01d07f2539fc54c9f2098475333859dc Mon Sep 17 00:00:00 2001 From: Victor Stone Date: Mon, 24 Jun 2024 17:12:07 -0700 Subject: [PATCH 192/256] Override name() of HloPass for ShardonnayCallInliner PiperOrigin-RevId: 646269601 --- third_party/xla/xla/service/call_inliner.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/xla/xla/service/call_inliner.h b/third_party/xla/xla/service/call_inliner.h index ebc32684c93d67..5698d0bc17d3d0 100644 --- a/third_party/xla/xla/service/call_inliner.h +++ b/third_party/xla/xla/service/call_inliner.h @@ -45,7 +45,7 @@ class CallInliner : public HloModulePass { bool update_domain = false) : single_call_site_(single_call_site), update_domain_(update_domain) {} ~CallInliner() override = default; - absl::string_view name() const override { return "CallInliner"; } + absl::string_view name() const override { return "call-inliner"; } using HloPassInterface::Run; absl::StatusOr Run( From 6f34f2e4633d95ca5dfabddac337cd1a17989d56 Mon Sep 17 00:00:00 2001 From: Junwhan Ahn Date: Mon, 24 Jun 2024 17:45:09 -0700 Subject: [PATCH 193/256] Rewrite `basic_string_array_test.cc` to use `Client::CopyArrays` instead of `Array::Reshard` This is in preparation for removing the now-deprecated `Array::Reshard`. 
PiperOrigin-RevId: 646277841 --- .../pjrt_ifrt/basic_string_array_test.cc | 87 ++++++++++--------- 1 file changed, 45 insertions(+), 42 deletions(-) diff --git a/third_party/xla/xla/python/pjrt_ifrt/basic_string_array_test.cc b/third_party/xla/xla/python/pjrt_ifrt/basic_string_array_test.cc index 922c3875e782c5..a0d21a4cf11307 100644 --- a/third_party/xla/xla/python/pjrt_ifrt/basic_string_array_test.cc +++ b/third_party/xla/xla/python/pjrt_ifrt/basic_string_array_test.cc @@ -693,53 +693,53 @@ TEST(DisassembleArrayIntoSingleDeviceArrays, FailsIfTheArrayHasBeenDeleted) { StatusIs(absl::StatusCode::kFailedPrecondition)); } -TEST(ReshardTest, SuccessSingleDeviceShardedArray) { +TEST(CopyTest, SuccessSingleDeviceShardedArray) { TF_ASSERT_OK_AND_ASSIGN(auto client, test_util::GetClient()); auto devices = client->addressable_devices(); ASSERT_GE(devices.size(), 2); auto [buffers, on_done_with_buffer] = MakeBuffersAndOnDoneWithBuffer({"abc"}); + std::vector> arrays; TF_ASSERT_OK_AND_ASSIGN( - auto array, + arrays.emplace_back(), CreateTestArray(client.get(), Future(buffers), std::move(on_done_with_buffer))); // CreateTestArray above would place the array on the first device. Use the - // second one for the new sharding. - std::shared_ptr new_sharding = - SingleDeviceSharding::Create(devices[1], MemoryKind()); - + // second one for the new array. 
TF_ASSERT_OK_AND_ASSIGN( - auto new_array, - array->Reshard(new_sharding, ArrayCopySemantics::kAlwaysCopy)); + auto new_arrays, + client->CopyArrays(absl::MakeSpan(arrays), DeviceList({devices[1]}), + MemoryKind(), ArrayCopySemantics::kAlwaysCopy)); auto new_basic_string_array = - llvm::dyn_cast(new_array.get()); + llvm::dyn_cast(new_arrays[0].get()); TF_ASSERT_OK_AND_ASSIGN(auto new_buffers, new_basic_string_array->buffers().Await()); ASSERT_EQ(new_buffers.size(), 1); EXPECT_THAT(new_buffers[0], testing::ElementsAre("abc")); } -TEST(ReshardTest, SuccessMultiDeviceShardedArray) { +TEST(CopyTest, SuccessMultiDeviceShardedArray) { TF_ASSERT_OK_AND_ASSIGN(auto client, test_util::GetClient()); auto devices = client->addressable_devices(); ASSERT_GE(devices.size(), 4); const std::vector per_shard_contents({"shard 0", "shard 1"}); + std::vector> arrays; TF_ASSERT_OK_AND_ASSIGN( - auto array, MakeShardedStringTestArray(client.get(), per_shard_contents, - /*is_fully_replicated=*/false)); - - std::shared_ptr new_sharding = OpaqueSharding::Create( - DeviceList({devices[2], devices[3]}), MemoryKind()); + arrays.emplace_back(), + MakeShardedStringTestArray(client.get(), per_shard_contents, + /*is_fully_replicated=*/false)); TF_ASSERT_OK_AND_ASSIGN( - auto new_array, - array->Reshard(new_sharding, ArrayCopySemantics::kAlwaysCopy)); + auto new_arrays, + client->CopyArrays(absl::MakeSpan(arrays), + DeviceList({devices[2], devices[3]}), MemoryKind(), + ArrayCopySemantics::kAlwaysCopy)); auto new_basic_string_array = - llvm::dyn_cast(new_array.get()); + llvm::dyn_cast(new_arrays[0].get()); TF_ASSERT_OK_AND_ASSIGN(auto new_buffers, new_basic_string_array->buffers().Await()); ASSERT_EQ(new_buffers.size(), 2); @@ -747,44 +747,45 @@ TEST(ReshardTest, SuccessMultiDeviceShardedArray) { EXPECT_THAT(new_buffers[1], testing::ElementsAre("shard 1")); } -TEST(ReshardTest, FailsAfterDeletion) { +TEST(CopyTest, FailsAfterDeletion) { TF_ASSERT_OK_AND_ASSIGN(auto client, test_util::GetClient()); 
auto devices = client->addressable_devices(); ASSERT_GE(devices.size(), 2); auto [buffers, on_done_with_buffer] = MakeBuffersAndOnDoneWithBuffer({"abc"}); + std::vector> arrays; TF_ASSERT_OK_AND_ASSIGN( - auto array, + arrays.emplace_back(), CreateTestArray(client.get(), Future(buffers), std::move(on_done_with_buffer))); - array->Delete(); + arrays[0]->Delete(); EXPECT_THAT( - array->Reshard(SingleDeviceSharding::Create(devices[1], MemoryKind()), - ArrayCopySemantics::kAlwaysCopy), + client->CopyArrays(absl::MakeSpan(arrays), DeviceList({devices[1]}), + MemoryKind(), ArrayCopySemantics::kAlwaysCopy), StatusIs(absl::StatusCode::kFailedPrecondition)); } -TEST(ReshardTest, FailsWithDifferentNumbersDevicesInNewSharding) { +TEST(CopyTest, FailsWithDifferentNumbersDevices) { TF_ASSERT_OK_AND_ASSIGN(auto client, test_util::GetClient()); auto devices = client->addressable_devices(); ASSERT_GE(devices.size(), 2); auto [buffers, on_done_with_buffer] = MakeBuffersAndOnDoneWithBuffer({"abc"}); + std::vector> arrays; TF_ASSERT_OK_AND_ASSIGN( - auto array, + arrays.emplace_back(), CreateTestArray(client.get(), Future(buffers), std::move(on_done_with_buffer))); - EXPECT_THAT( - array->Reshard(OpaqueSharding::Create( - DeviceList({devices[0], devices[1]}), MemoryKind()), - ArrayCopySemantics::kAlwaysCopy), - StatusIs(absl::StatusCode::kInvalidArgument)); + EXPECT_THAT(client->CopyArrays(absl::MakeSpan(arrays), + DeviceList({devices[0], devices[1]}), + MemoryKind(), ArrayCopySemantics::kAlwaysCopy), + StatusIs(absl::StatusCode::kInvalidArgument)); } -TEST(ReshardTest, NonReadySourceArraySuccessfullyBecomesReadyAfterReshard) { +TEST(CopyTest, NonReadySourceArraySuccessfullyBecomesReadyAfterCopy) { TF_ASSERT_OK_AND_ASSIGN(auto client, test_util::GetClient()); auto devices = client->addressable_devices(); ASSERT_GE(devices.size(), 2); @@ -795,12 +796,13 @@ TEST(ReshardTest, NonReadySourceArraySuccessfullyBecomesReadyAfterReshard) { TF_ASSERT_OK_AND_ASSIGN( auto ret, 
CreateNonReadyTestArray(client.get(), devices[0], std::move(on_done_with_buffer))); - auto array = std::move(ret.first); + std::vector> arrays; + arrays.push_back(std::move(ret.first)); auto promise = std::move(ret.second); - TF_ASSERT_OK( - array->Reshard(SingleDeviceSharding::Create(devices[1], MemoryKind()), - ArrayCopySemantics::kAlwaysCopy)); + TF_ASSERT_OK(client->CopyArrays(absl::MakeSpan(arrays), + DeviceList({devices[1]}), MemoryKind(), + ArrayCopySemantics::kAlwaysCopy)); absl::Notification done_readying_single_device_arrays; tsl::Env::Default()->SchedClosure(([&]() mutable { @@ -808,7 +810,7 @@ TEST(ReshardTest, NonReadySourceArraySuccessfullyBecomesReadyAfterReshard) { done_readying_single_device_arrays.Notify(); })); - auto basic_string_array = llvm::dyn_cast(array.get()); + auto basic_string_array = llvm::dyn_cast(arrays[0].get()); ASSERT_NE(basic_string_array, nullptr); TF_ASSERT_OK_AND_ASSIGN(auto new_buffers, @@ -822,7 +824,7 @@ TEST(ReshardTest, NonReadySourceArraySuccessfullyBecomesReadyAfterReshard) { done_readying_single_device_arrays.WaitForNotification(); } -TEST(ReshardTest, NonReadySourceArrayFailsToBecomeReadyAfterReshard) { +TEST(CopyTest, NonReadySourceArrayFailsToBecomeReadyAfterCopy) { TF_ASSERT_OK_AND_ASSIGN(auto client, test_util::GetClient()); auto devices = client->addressable_devices(); ASSERT_GE(devices.size(), 2); @@ -833,12 +835,13 @@ TEST(ReshardTest, NonReadySourceArrayFailsToBecomeReadyAfterReshard) { TF_ASSERT_OK_AND_ASSIGN( auto ret, CreateNonReadyTestArray(client.get(), devices[0], std::move(on_done_with_buffer))); - auto array = std::move(ret.first); + std::vector> arrays; + arrays.push_back(std::move(ret.first)); auto promise = std::move(ret.second); - TF_ASSERT_OK( - array->Reshard(SingleDeviceSharding::Create(devices[1], MemoryKind()), - ArrayCopySemantics::kAlwaysCopy)); + TF_ASSERT_OK(client->CopyArrays(absl::MakeSpan(arrays), + DeviceList({devices[1]}), MemoryKind(), + ArrayCopySemantics::kAlwaysCopy)); 
absl::Notification done_readying_single_device_arrays; tsl::Env::Default()->SchedClosure(([&]() mutable { @@ -846,7 +849,7 @@ TEST(ReshardTest, NonReadySourceArrayFailsToBecomeReadyAfterReshard) { done_readying_single_device_arrays.Notify(); })); - auto basic_string_array = llvm::dyn_cast(array.get()); + auto basic_string_array = llvm::dyn_cast(arrays[0].get()); ASSERT_NE(basic_string_array, nullptr); auto buffers_future = basic_string_array->buffers(); From f436745a12214221e0fecec5c0fa9a8517dc9eec Mon Sep 17 00:00:00 2001 From: Shawn Wang Date: Mon, 24 Jun 2024 18:33:37 -0700 Subject: [PATCH 194/256] PR #13885: [XLA:GPU] Allow user to change command buffer trace cache capacity Imported from GitHub PR https://github.com/openxla/xla/pull/13885 Allow user to set the command buffer trace cache size, increasing the cache size may sometimes reduces the chances of doing command buffer tracing for updating command buffer instance. Copybara import of the project: -- 824b1ebf7bcb50e892e931f192b9ff883e21f0ca by Shawn Wang : Allow user to change command buffer trace cache capacity Merging this change closes #13885 PiperOrigin-RevId: 646288429 --- third_party/xla/xla/debug_options_flags.cc | 9 +- third_party/xla/xla/service/gpu/runtime/BUILD | 1 + .../service/gpu/runtime/command_buffer_cmd.cc | 9 +- .../gpu/runtime/command_buffer_cmd_test.cc | 156 +++++++++--------- third_party/xla/xla/xla.proto | 7 +- 5 files changed, 104 insertions(+), 78 deletions(-) diff --git a/third_party/xla/xla/debug_options_flags.cc b/third_party/xla/xla/debug_options_flags.cc index 9cfce6dd46b0c8..6e4d82c877f992 100644 --- a/third_party/xla/xla/debug_options_flags.cc +++ b/third_party/xla/xla/debug_options_flags.cc @@ -115,6 +115,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.add_xla_gpu_enable_command_buffer(DebugOptions::CUDNN); opts.set_xla_gpu_graph_min_graph_size(5); opts.set_xla_gpu_graph_enable_concurrent_region(false); + opts.set_xla_cmd_buffer_trace_cache_size(16); // Despite 
the name, fast min/max on GPUs does not seem to be any faster, and // adds very counter-intuitive "NaN-swallowing" behavior. @@ -1217,7 +1218,13 @@ void MakeDebugOptionsFlags(std::vector* flag_list, debug_options->xla_gpu_graph_enable_concurrent_region(), "Identify concurrent regions in gpu graphs and execute them " "concurrently.")); - + flag_list->push_back(tsl::Flag( + "xla_cmd_buffer_trace_cache_size", + int64_setter_for(&DebugOptions::set_xla_cmd_buffer_trace_cache_size), + debug_options->xla_cmd_buffer_trace_cache_size(), + "Set the command buffer trace cache size, increasing the cache size may " + "sometimes reduces the chances of doing command buffer tracing for " + "updating command buffer instance.")); flag_list->push_back( tsl::Flag("xla_dump_disable_metadata", bool_setter_for(&DebugOptions::set_xla_dump_disable_metadata), diff --git a/third_party/xla/xla/service/gpu/runtime/BUILD b/third_party/xla/xla/service/gpu/runtime/BUILD index 5af4d5703fee7f..9db883559efcea 100644 --- a/third_party/xla/xla/service/gpu/runtime/BUILD +++ b/third_party/xla/xla/service/gpu/runtime/BUILD @@ -63,6 +63,7 @@ cc_library( ":nccl_collective_broadcast_thunk", ":nccl_collective_thunk", ":thunk", + "//xla:debug_options_flags", "//xla:executable_run_options", "//xla:types", "//xla:util", diff --git a/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd.cc b/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd.cc index 78c5afc1e8ab61..77787478bc7839 100644 --- a/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd.cc +++ b/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd.cc @@ -38,6 +38,7 @@ limitations under the License. 
#include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" #include "absl/types/span.h" +#include "xla/debug_options_flags.h" #include "xla/executable_run_options.h" #include "xla/ffi/call_frame.h" #include "xla/ffi/ffi_api.h" @@ -432,8 +433,12 @@ absl::Status TracedCommandBufferCmd::AddTracedCommandBuffer( const Thunk::ExecuteParams& execute_params, const RecordParams& record_params, se::CommandBuffer* command_buffer, absl::FunctionRef trace) { - auto traced_cmd = record_params.state.GetOrCreate( - this, [&] { return std::make_unique(buffers()); }); + auto traced_cmd = + record_params.state.GetOrCreate(this, [&] { + const auto& debug_options = xla::GetDebugOptionsFromFlags(); + return std::make_unique( + buffers(), debug_options.xla_cmd_buffer_trace_cache_size()); + }); TF_ASSIGN_OR_RETURN( auto nested_cmd, diff --git a/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd_test.cc b/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd_test.cc index f76fec88ed8e31..ceb0d58821f247 100644 --- a/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd_test.cc +++ b/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd_test.cc @@ -398,81 +398,89 @@ TEST(CommandBufferCmdStateManageTest, GetOrCreateState) { } TEST(TracedCommandBuffer, GetOrUpdateCommandBuffer) { - se::StreamExecutor* executor = GpuExecutor(); - - auto stream = executor->CreateStream().value(); - BufferAllocation alloc0(/*index=*/0, /*size=*/1024, /*color=*/0); - BufferAllocation alloc1(/*index=*/1, /*size=*/1024, /*color=*/0); - - CommandBufferCmd::BufferUsageVector buffers = { - {BufferAllocation::Slice(&alloc0, 0, 1024), MemoryAccess::kRead}, - {BufferAllocation::Slice(&alloc1, 0, 1024), MemoryAccess::kWrite}}; - - TracedCommandBuffer traced_cmd_buffer(buffers, /*capacity=*/2); - - se::DeviceMemoryBase mem0(reinterpret_cast(0x01234567)); - se::DeviceMemoryBase mem1(reinterpret_cast(0x12345670)); - - se::StreamExecutorMemoryAllocator allocator(executor); - 
BufferAllocations allocations({mem0, mem1}, 0, &allocator); - - // No-op trace callback to count how many times it was called. - int64_t num_calls = 0; - auto trace = [&](se::Stream*) { - num_calls++; - return absl::OkStatus(); + auto run_traced_test = [](int trace_cache_size) { + se::StreamExecutor* executor = GpuExecutor(); + + auto stream = executor->CreateStream().value(); + BufferAllocation alloc0(/*index=*/0, /*size=*/1024, /*color=*/0); + BufferAllocation alloc1(/*index=*/1, /*size=*/1024, /*color=*/0); + + CommandBufferCmd::BufferUsageVector buffers = { + {BufferAllocation::Slice(&alloc0, 0, 1024), MemoryAccess::kRead}, + {BufferAllocation::Slice(&alloc1, 0, 1024), MemoryAccess::kWrite}}; + + TracedCommandBuffer traced_cmd_buffer(buffers, + /*capacity=*/trace_cache_size); + + se::DeviceMemoryBase mem0(reinterpret_cast(0x01234567)); + se::DeviceMemoryBase mem1(reinterpret_cast(0x12345670)); + + se::StreamExecutorMemoryAllocator allocator(executor); + BufferAllocations allocations({mem0, mem1}, 0, &allocator); + + // No-op trace callback to count how many times it was called. + int64_t num_calls = 0; + auto trace = [&](se::Stream*) { + num_calls++; + return absl::OkStatus(); + }; + + TF_ASSERT_OK_AND_ASSIGN(auto* command_buffer0, + traced_cmd_buffer.GetOrTraceCommandBuffer( + &allocations, executor, stream.get(), trace)); + + TF_ASSERT_OK_AND_ASSIGN(auto* command_buffer1, + traced_cmd_buffer.GetOrTraceCommandBuffer( + &allocations, executor, stream.get(), trace)); + + // Check that command buffer was reused as buffer allocations didn't + // change. + ASSERT_EQ(command_buffer0, command_buffer1); + EXPECT_EQ(num_calls, 1); + + // Check that when memory address changes we re-trace the command + // buffer. 
+ se::DeviceMemoryBase mem2(reinterpret_cast(0x23456701)); + allocations = BufferAllocations({mem0, mem2}, 0, &allocator); + + TF_ASSERT_OK_AND_ASSIGN(auto* command_buffer2, + traced_cmd_buffer.GetOrTraceCommandBuffer( + &allocations, executor, stream.get(), trace)); + + ASSERT_NE(command_buffer0, command_buffer2); + EXPECT_EQ(num_calls, 2); + + // Check that we keep first command buffer in cache. + allocations = BufferAllocations({mem0, mem1}, 0, &allocator); + + TF_ASSERT_OK_AND_ASSIGN(auto* command_buffer3, + traced_cmd_buffer.GetOrTraceCommandBuffer( + &allocations, executor, stream.get(), trace)); + ASSERT_EQ(command_buffer0, command_buffer3); + EXPECT_EQ(num_calls, 2); + + // Check that we trace a new graph when buffer allocation pattern is + // new. + allocations = BufferAllocations({mem0, mem0}, 0, &allocator); + + TF_ASSERT_OK_AND_ASSIGN(auto* command_buffer4, + traced_cmd_buffer.GetOrTraceCommandBuffer( + &allocations, executor, stream.get(), trace)); + ASSERT_NE(command_buffer4, command_buffer3); + ASSERT_NE(command_buffer4, command_buffer2); + EXPECT_EQ(num_calls, 3); + + // Check that we still keep the previous graph in cache. + allocations = BufferAllocations({mem0, mem1}, 0, &allocator); + + TF_ASSERT_OK_AND_ASSIGN(auto* command_buffer5, + traced_cmd_buffer.GetOrTraceCommandBuffer( + &allocations, executor, stream.get(), trace)); + ASSERT_EQ(command_buffer0, command_buffer5); + EXPECT_EQ(num_calls, 3); }; - - TF_ASSERT_OK_AND_ASSIGN(auto* command_buffer0, - traced_cmd_buffer.GetOrTraceCommandBuffer( - &allocations, executor, stream.get(), trace)); - - TF_ASSERT_OK_AND_ASSIGN(auto* command_buffer1, - traced_cmd_buffer.GetOrTraceCommandBuffer( - &allocations, executor, stream.get(), trace)); - - // Check that command buffer was reused as buffer allocations didn't change. - ASSERT_EQ(command_buffer0, command_buffer1); - EXPECT_EQ(num_calls, 1); - - // Check that when memory address changes we re-trace the command buffer. 
- se::DeviceMemoryBase mem2(reinterpret_cast(0x23456701)); - allocations = BufferAllocations({mem0, mem2}, 0, &allocator); - - TF_ASSERT_OK_AND_ASSIGN(auto* command_buffer2, - traced_cmd_buffer.GetOrTraceCommandBuffer( - &allocations, executor, stream.get(), trace)); - - ASSERT_NE(command_buffer0, command_buffer2); - EXPECT_EQ(num_calls, 2); - - // Check that we keep first command buffer in cache. - allocations = BufferAllocations({mem0, mem1}, 0, &allocator); - - TF_ASSERT_OK_AND_ASSIGN(auto* command_buffer3, - traced_cmd_buffer.GetOrTraceCommandBuffer( - &allocations, executor, stream.get(), trace)); - ASSERT_EQ(command_buffer0, command_buffer3); - EXPECT_EQ(num_calls, 2); - - // Check that we trace a new graph when buffer allocation pattern is new. - allocations = BufferAllocations({mem0, mem0}, 0, &allocator); - - TF_ASSERT_OK_AND_ASSIGN(auto* command_buffer4, - traced_cmd_buffer.GetOrTraceCommandBuffer( - &allocations, executor, stream.get(), trace)); - ASSERT_NE(command_buffer4, command_buffer3); - ASSERT_NE(command_buffer4, command_buffer2); - EXPECT_EQ(num_calls, 3); - - // Check that we still keep the previous graph in cache. 
- allocations = BufferAllocations({mem0, mem1}, 0, &allocator); - - TF_ASSERT_OK_AND_ASSIGN(auto* command_buffer5, - traced_cmd_buffer.GetOrTraceCommandBuffer( - &allocations, executor, stream.get(), trace)); - ASSERT_EQ(command_buffer0, command_buffer5); - EXPECT_EQ(num_calls, 3); + run_traced_test(2); + run_traced_test(3); } //===----------------------------------------------------------------------===// diff --git a/third_party/xla/xla/xla.proto b/third_party/xla/xla/xla.proto index 8e92692a535323..b8b7a4a3f43def 100644 --- a/third_party/xla/xla/xla.proto +++ b/third_party/xla/xla/xla.proto @@ -826,7 +826,12 @@ message DebugOptions { string xla_gpu_per_fusion_autotune_cache_dir = 310; - // Next id: 311 + // The command buffer trace cache size, increasing the cache size may + // sometimes reduces the chances of doing command buffer tracing for + // updating command buffer instance. + int64 xla_cmd_buffer_trace_cache_size = 311; + + // Next id: 312 // Extra options to pass to the compilation backend (e.g. LLVM); specific // interpretation of these values is left to the backend. From 87d592a0ea9ad2b7eb38591f791f752f03ed0101 Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Mon, 24 Jun 2024 18:45:42 -0700 Subject: [PATCH 195/256] Reverts bd594187900c0d898c54ad56e96b486a729946d7 PiperOrigin-RevId: 646290687 --- third_party/xla/xla/BUILD | 3 +-- .../xla/xla/service/cpu/onednn_convolution_rewriter.h | 2 +- third_party/xla/xla/status_macros.h | 2 +- third_party/xla/xla/statusor.h | 10 ++++++++-- 4 files changed, 11 insertions(+), 6 deletions(-) diff --git a/third_party/xla/xla/BUILD b/third_party/xla/xla/BUILD index 496c3c1c3ffdf6..f81390565195b0 100644 --- a/third_party/xla/xla/BUILD +++ b/third_party/xla/xla/BUILD @@ -271,6 +271,7 @@ cc_library( hdrs = ["status_macros.h"], visibility = ["//visibility:public"], deps = [ + ":statusor", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", @@ -278,7 +279,6 @@ cc_library( "@local_tsl//tsl/platform:macros", "@local_tsl//tsl/platform:stacktrace", "@local_tsl//tsl/platform:status", - "@local_tsl//tsl/platform:statusor", ], ) @@ -315,7 +315,6 @@ cc_library( hdrs = [ "statusor.h", ], - deprecation = "Use @com_google_absl//absl/status:statusor and/or //third_party/tensorflow/tsl/platform:statusor instead.", linkopts = select({ "//xla/tsl:freebsd": ["-lexecinfo"], "//conditions:default": [], diff --git a/third_party/xla/xla/service/cpu/onednn_convolution_rewriter.h b/third_party/xla/xla/service/cpu/onednn_convolution_rewriter.h index 312221f212d7d9..334db7f60b8356 100644 --- a/third_party/xla/xla/service/cpu/onednn_convolution_rewriter.h +++ b/third_party/xla/xla/service/cpu/onednn_convolution_rewriter.h @@ -35,7 +35,7 @@ class OneDnnConvolutionRewriter : public HloModulePass { } using HloPassInterface::Run; - absl::StatusOr Run( + StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; diff --git a/third_party/xla/xla/status_macros.h b/third_party/xla/xla/status_macros.h index 3f42d962e4b307..d62bfa276a4c28 100644 --- a/third_party/xla/xla/status_macros.h +++ 
b/third_party/xla/xla/status_macros.h @@ -25,9 +25,9 @@ limitations under the License. #include "absl/base/optimization.h" #include "absl/status/status.h" +#include "xla/statusor.h" #include "tsl/platform/macros.h" #include "tsl/platform/status.h" -#include "tsl/platform/statusor.h" namespace xla { namespace status_macros { diff --git a/third_party/xla/xla/statusor.h b/third_party/xla/xla/statusor.h index 67dce52566d5b5..3a7d6da42897f7 100644 --- a/third_party/xla/xla/statusor.h +++ b/third_party/xla/xla/statusor.h @@ -15,7 +15,13 @@ limitations under the License. #ifndef XLA_STATUSOR_H_ #define XLA_STATUSOR_H_ -// This file is deprecated. Use absl/status/statusor.h and/or -// tsl/platform/statusor.h instead. +#include "tsl/platform/statusor.h" + +namespace xla { + +// Use steam_executor's absl::StatusOr so we don't duplicate code. +using tsl::StatusOr; // TENSORFLOW_STATUS_OK + +} // namespace xla #endif // XLA_STATUSOR_H_ From 6a1e73dfa46f281cfba753ea883eceee4d04b7b7 Mon Sep 17 00:00:00 2001 From: Seher Ellis Date: Mon, 24 Jun 2024 19:13:03 -0700 Subject: [PATCH 196/256] [XLA] Allow CollectivePipeliner to iteratively forward-sink the collectives that have at least one other forward-sinkable collective in their user tree. 
PiperOrigin-RevId: 646296580 --- third_party/xla/xla/hlo/ir/hlo_instruction.h | 6 + third_party/xla/xla/service/BUILD | 1 + .../xla/xla/service/collective_pipeliner.cc | 30 ++++- .../xla/xla/service/collective_pipeliner.h | 7 ++ .../xla/service/collective_pipeliner_test.cc | 108 ++++++++++++++++++ 5 files changed, 149 insertions(+), 3 deletions(-) diff --git a/third_party/xla/xla/hlo/ir/hlo_instruction.h b/third_party/xla/xla/hlo/ir/hlo_instruction.h index 7a55565cb12502..9109e45da6d3f3 100644 --- a/third_party/xla/xla/hlo/ir/hlo_instruction.h +++ b/third_party/xla/xla/hlo/ir/hlo_instruction.h @@ -2861,6 +2861,12 @@ bool HloPredicateIsOp(const HloInstruction* instruction) { ((instruction->opcode() == rest) || ...); } +template +bool HloPredicateIsNotOp(const HloInstruction* instruction) { + return (instruction->opcode() != op) && + ((instruction->opcode() != rest) && ...); +} + /* static */ inline bool HloInstruction::MightHaveCalledComputations( HloOpcode opcode) { switch (opcode) { diff --git a/third_party/xla/xla/service/BUILD b/third_party/xla/xla/service/BUILD index 17c8f39a482604..6a70a0aa96c65d 100644 --- a/third_party/xla/xla/service/BUILD +++ b/third_party/xla/xla/service/BUILD @@ -619,6 +619,7 @@ xla_cc_test( "//xla/tests:filecheck", "//xla/tests:hlo_test_base", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", diff --git a/third_party/xla/xla/service/collective_pipeliner.cc b/third_party/xla/xla/service/collective_pipeliner.cc index 7faef4e5330872..6e08ae23708dad 100644 --- a/third_party/xla/xla/service/collective_pipeliner.cc +++ b/third_party/xla/xla/service/collective_pipeliner.cc @@ -2578,11 +2578,9 @@ static absl::Status TransformLoopBackward( return absl::OkStatus(); } -absl::StatusOr CollectivePipeliner::Run( +absl::StatusOr CollectivePipeliner::RunPipeliner( HloModule* module, const 
absl::flat_hash_set& execution_threads) { - CHECK(config_.acceptable_formatting); - CHECK(config_.should_process); bool changed = false; std::vector while_loop_instructions; for (HloComputation* computation : module->MakeComputationPostOrder()) { @@ -2687,4 +2685,30 @@ absl::StatusOr CollectivePipeliner::Run( return changed; } +absl::StatusOr CollectivePipeliner::Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) { + CHECK(config_.acceptable_formatting); + CHECK(config_.should_process); + + if (config_.pipelining_direction != PipeliningDirection::kForwardSink) { + return RunPipeliner(module, execution_threads); + } + + // If the pipelining direction is kForwardSink, run the pipeliner until it + // does not change the module anymore. The maximum number of iterations should + // be equal to the maximum number of pipelineable collectives in a chain of + // users plus one. In each iteration, we pipeline the last pipelineable + // collectives, which do not have any other pipelineable collectives in their + // user subtree. + bool changed = true; + int64_t iter = 0; + while (changed) { + TF_ASSIGN_OR_RETURN(changed, RunPipeliner(module, execution_threads)); + VLOG(1) << "Finished running pipeliner's iteration: " << iter; + iter++; + } + return iter > 1; +} + } // namespace xla diff --git a/third_party/xla/xla/service/collective_pipeliner.h b/third_party/xla/xla/service/collective_pipeliner.h index 80103f02a32145..8df3de487027d7 100644 --- a/third_party/xla/xla/service/collective_pipeliner.h +++ b/third_party/xla/xla/service/collective_pipeliner.h @@ -16,6 +16,7 @@ limitations under the License. 
#ifndef XLA_SERVICE_COLLECTIVE_PIPELINER_H_ #define XLA_SERVICE_COLLECTIVE_PIPELINER_H_ +#include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_module.h" @@ -132,6 +133,12 @@ class CollectivePipeliner : public HloModulePass { } } + // Pipelines the collectives that do not have any other pipelineable + // collectives in their user subtree. + absl::StatusOr RunPipeliner( + HloModule* module, + const absl::flat_hash_set& execution_threads); + using HloPassInterface::Run; absl::StatusOr Run( HloModule* module, diff --git a/third_party/xla/xla/service/collective_pipeliner_test.cc b/third_party/xla/xla/service/collective_pipeliner_test.cc index 56f0aaf829b841..4fc7e8de731dd8 100644 --- a/third_party/xla/xla/service/collective_pipeliner_test.cc +++ b/third_party/xla/xla/service/collective_pipeliner_test.cc @@ -15,13 +15,16 @@ limitations under the License. #include "xla/service/collective_pipeliner.h" +#include #include #include +#include #include #include #include #include "absl/algorithm/container.h" +#include "absl/container/flat_hash_set.h" #include "absl/log/check.h" #include "absl/status/status.h" #include "absl/status/statusor.h" @@ -2709,5 +2712,110 @@ ENTRY entry { })); } +TEST_F(CollectivePipelinerTest, ForwardSinkDependentPipelineableCollectives) { + constexpr absl::string_view hlo_string = R"( +HloModule module + +add { + lhs = bf16[] parameter(0) + rhs = bf16[] parameter(1) + ROOT add = bf16[] add(lhs, rhs) +} + +add.1 { + lhs = bf16[] parameter(0) + rhs = bf16[] parameter(1) + ROOT add = bf16[] add(lhs, rhs) +} + +while_cond { + param = (s32[], bf16[3,8,128], bf16[3,8,128]) parameter(0) + gte = s32[] get-tuple-element(param), index=0 + constant.1 = s32[] constant(3) + ROOT cmp = pred[] compare(gte, constant.1), direction=LT +} + +while_body { + param = (s32[], bf16[3,8,128], bf16[3,8,128]) parameter(0) + get-tuple-element.394 = s32[] get-tuple-element(param), index=0 + 
get-tuple-element.395 = bf16[3,8,128] get-tuple-element(param), index=1 + get-tuple-element.35 = bf16[3,8,128] get-tuple-element(param), index=2 + constant.2557 = s32[] constant(1) + add.230 = s32[] add(get-tuple-element.394, constant.2557) + constant.2559 = s32[] constant(3) + subtract.139 = s32[] subtract(constant.2559, get-tuple-element.394) + constant.2560 = s32[] constant(-1) + add.231 = s32[] add(subtract.139, constant.2560) + constant.2561 = s32[] constant(0) + compare.747 = pred[] compare(add.231, constant.2561), direction=LT + constant.2562 = s32[] constant(2) + add.232 = s32[] add(subtract.139, constant.2562) + select.1348 = s32[] select(compare.747, add.232, add.231) + dynamic-slice.99 = bf16[1,8,128] dynamic-slice(get-tuple-element.35, select.1348, constant.2561, constant.2561), dynamic_slice_sizes={1,8,128} + mul = bf16[1,8,128] multiply(dynamic-slice.99, dynamic-slice.99) + ar.1 = bf16[1,8,128] all-reduce(mul), replica_groups={}, to_apply=add, channel_id=1 + b.1 = bf16[1,8,128,32] broadcast(ar.1), dimensions={0,1,2} + constant = bf16[] constant(0) + reduce = bf16[1,8,128] reduce(b.1, constant), dimensions={3}, to_apply=add.1 + ar.2 = bf16[1,8,128] all-reduce(reduce), replica_groups={}, to_apply=add, channel_id=2 + c1 = bf16[] constant(2.0) + bc = bf16[1,8,128] broadcast(c1) + mul1 = bf16[1,8,128] multiply(ar.2, bc) + mul3 = bf16[1,8,128] multiply(mul1, ar.2) + dynamic-update-slice.35 = bf16[3,8,128] dynamic-update-slice(get-tuple-element.395, mul3, select.1348, constant.2561, constant.2561) + ROOT tuple = (s32[], bf16[3,8,128], bf16[3,8,128]) tuple(add.230, dynamic-update-slice.35, get-tuple-element.35) +} + +ENTRY entry { + c0 = s32[] constant(0) + p0 = bf16[3,8,128] parameter(0) + tuple = (s32[], bf16[3,8,128], bf16[3,8,128]) tuple(c0, p0, p0) + while = (s32[], bf16[3,8,128], bf16[3,8,128]) while(tuple), condition=while_cond, body=while_body + ROOT gte1 = bf16[3,8,128] get-tuple-element(while), index=1 +} +)"; + 
config_.set_use_spmd_partitioning(true); + auto module = ParseAndReturnUnverifiedModule(hlo_string, config_).value(); + EXPECT_TRUE( + RunOptimizer( + module.get(), /*last_run=*/true, + /*level_to_operate_on=*/0, + /*pipeline_use_tree=*/true, + /*process_different_sized_ops=*/true, + CollectivePipeliner::kForwardSink, + /*should_process=*/HloPredicateIsOp, + /*acceptable_formatting=*/HloPredicateIsNotOp) + .value()); + XLA_VLOG_LINES(1, module->ToString()); + // Return the closest all-reduce in the user subtree rooted at instruction i. + std::function find_all_reduce = + [&](const HloInstruction* i) -> const HloInstruction* { + std::queue queue; + queue.push(i); + absl::flat_hash_set visited; + while (!queue.empty()) { + const HloInstruction* curr_inst = queue.front(); + queue.pop(); + for (HloInstruction* operand : curr_inst->operands()) { + if (operand->opcode() == HloOpcode::kAllReduce) { + return operand; + } + if (visited.insert(operand).second) { + queue.push(operand); + } + } + } + return nullptr; + }; + // Check if root has the two all-reduces in the operand subtree where one is + // an ancestor of the other. + const HloInstruction* all_reduce1 = + find_all_reduce(module->entry_computation()->root_instruction()); + EXPECT_NE(all_reduce1, nullptr); + const HloInstruction* all_reduce2 = find_all_reduce(all_reduce1); + EXPECT_NE(all_reduce2, nullptr); + EXPECT_THAT(all_reduce2, op::AllReduce(op::GetTupleElement(op::While()))); +} + } // namespace } // namespace xla From f9b7c9948ee998f3548eb4be49aafdee162f76a1 Mon Sep 17 00:00:00 2001 From: Philipp Hack Date: Mon, 24 Jun 2024 20:14:31 -0700 Subject: [PATCH 197/256] PR #13938: Set Asynchronous Custom Call Start Thunks Imported from GitHub PR https://github.com/openxla/xla/pull/13938 Corrects the types of the thunks emitted for asynchronous GEMM Custom Call starts. Copybara import of the project: -- b52bf7af4c505a832d2eb67c2d8357fc1ac3b15c by Philipp Hack : Corrects the thunks of asynchronous custom call starts. 
-- 76eaa3a985d9c21b76170c95456041a1d6169a1d by Philipp Hack : Corrects the thunks of asynchronous custom call starts. Merging this change closes #13938 PiperOrigin-RevId: 646309731 --- .../xla/xla/service/gpu/ir_emitter_unnested.cc | 4 ++-- .../xla/xla/tests/collective_ops_e2e_test.cc | 14 ++++++++++++-- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc b/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc index 118948792b215a..0580a2d611c021 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc +++ b/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc @@ -1727,14 +1727,14 @@ absl::Status IrEmitterUnnested::EmitAsyncCustomCallStart( } #if GOOGLE_CUDA || TF_HIPBLASLT if (IsCublasLtMatmul(*wrapped)) { - auto status = EmitGemmThunk(custom_call); + auto status = EmitCublasLtMatmulThunk(custom_call); if (status.ok()) { thunk_sequence_.back()->set_execution_stream_id(execution_stream_id); } return status; } if (IsCublasLtMatmulF8(*wrapped)) { - auto status = EmitGemmThunk(custom_call); + auto status = EmitCublasLtMatmulThunkF8(custom_call); if (status.ok()) { thunk_sequence_.back()->set_execution_stream_id(execution_stream_id); } diff --git a/third_party/xla/xla/tests/collective_ops_e2e_test.cc b/third_party/xla/xla/tests/collective_ops_e2e_test.cc index 17ebdc2c01805a..fc14363173ad1b 100644 --- a/third_party/xla/xla/tests/collective_ops_e2e_test.cc +++ b/third_party/xla/xla/tests/collective_ops_e2e_test.cc @@ -650,7 +650,8 @@ TEST_F(CollectiveOpsTestE2E, NoAllToAllDecomposition) { // E2E tests comparing the results of windowed einsum and non-windowed cases. 
class CollectiveOpsTestE2EWindowedNonWindowed : public CollectiveOpsTestE2E { public: - void CollectiveOpsCompareWindowedNonWindowed(absl::string_view hlo_text) { + void CollectiveOpsCompareWindowedNonWindowed( + absl::string_view hlo_text, bool disable_dot_merger = false) { const int64_t kNumReplicas = 1; const int64_t kNumPartitions = 4; @@ -661,6 +662,9 @@ class CollectiveOpsTestE2EWindowedNonWindowed : public CollectiveOpsTestE2E { opts.set_xla_gpu_multi_streamed_windowed_einsum(true); opts.set_xla_gpu_graph_min_graph_size(200); opts.set_xla_gpu_enable_triton_gemm(false); + if (disable_dot_merger) { + opts.add_xla_disable_hlo_passes("dot-merger"); + } config.set_debug_options(opts); config.set_num_partitions(kNumPartitions); TF_ASSERT_OK_AND_ASSIGN(auto module, @@ -689,6 +693,9 @@ class CollectiveOpsTestE2EWindowedNonWindowed : public CollectiveOpsTestE2E { auto ref_opts = GetDebugOptionsForTest(); ref_opts.set_xla_gpu_graph_min_graph_size(200); ref_opts.set_xla_gpu_enable_triton_gemm(false); + if (disable_dot_merger) { + ref_opts.add_xla_disable_hlo_passes("dot-merger"); + } ref_config.set_debug_options(ref_opts); ref_config.set_num_partitions(kNumPartitions); TF_ASSERT_OK_AND_ASSIGN(auto ref_module, @@ -779,7 +786,10 @@ ENTRY main.12 { } // main.12 )"; - CollectiveOpsCompareWindowedNonWindowed(kModuleReplicatedStr); + // Disable the dot merger pass which can prevent the creation of FP8 GEMM + // Custom Calls. + CollectiveOpsCompareWindowedNonWindowed(kModuleReplicatedStr, + /*disable_dot_merger=*/true); } TEST_F(CollectiveOpsTestE2EWindowedNonWindowed, From 6b5f0aeec40998ebe289937578c69a7cabd52927 Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Mon, 24 Jun 2024 21:30:05 -0700 Subject: [PATCH 198/256] Automated Code Change PiperOrigin-RevId: 646324593 --- tensorflow/cc/framework/fuzzing/BUILD | 2 ++ tensorflow/cc/framework/fuzzing/cc_op_fuzz_gen_main.cc | 3 ++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/tensorflow/cc/framework/fuzzing/BUILD b/tensorflow/cc/framework/fuzzing/BUILD index 74a946c283777d..b31ccdb2ece913 100644 --- a/tensorflow/cc/framework/fuzzing/BUILD +++ b/tensorflow/cc/framework/fuzzing/BUILD @@ -30,6 +30,8 @@ cc_library( "//tensorflow/core/platform:hash", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@local_tsl//tsl/platform:status", ], diff --git a/tensorflow/cc/framework/fuzzing/cc_op_fuzz_gen_main.cc b/tensorflow/cc/framework/fuzzing/cc_op_fuzz_gen_main.cc index 585baa08df864a..667d566a4fa5c1 100644 --- a/tensorflow/cc/framework/fuzzing/cc_op_fuzz_gen_main.cc +++ b/tensorflow/cc/framework/fuzzing/cc_op_fuzz_gen_main.cc @@ -19,6 +19,8 @@ limitations under the License. #include #include +#include "absl/status/status.h" +#include "absl/status/statusor.h" #include "tensorflow/cc/framework/cc_op_gen_util.h" #include "tensorflow/cc/framework/fuzzing/cc_op_fuzz_gen.h" #include "tensorflow/core/framework/api_def.pb.h" @@ -28,7 +30,6 @@ limitations under the License. #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/file_system.h" #include "tensorflow/core/platform/init_main.h" -#include "tensorflow/core/platform/statusor.h" #include "tensorflow/core/platform/str_util.h" #include "tensorflow/core/platform/types.h" #include "tsl/platform/status.h" From e350c969926c188c1cbc2bd04fb0ba33e62d14a7 Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Mon, 24 Jun 2024 22:14:42 -0700 Subject: [PATCH 199/256] Automated Code Change PiperOrigin-RevId: 646333281 --- .../lite/tools/evaluation/tasks/coco_object_detection/BUILD | 1 + .../tools/evaluation/tasks/coco_object_detection/run_eval.cc | 1 + 2 files changed, 2 insertions(+) diff --git a/tensorflow/lite/tools/evaluation/tasks/coco_object_detection/BUILD b/tensorflow/lite/tools/evaluation/tasks/coco_object_detection/BUILD index c75c6fc24f0e15..09520ff67ff966 100644 --- a/tensorflow/lite/tools/evaluation/tasks/coco_object_detection/BUILD +++ b/tensorflow/lite/tools/evaluation/tasks/coco_object_detection/BUILD @@ -31,6 +31,7 @@ cc_library( "//tensorflow/lite/core/c:common", "//tensorflow/lite/tools:command_line_flags", "//tensorflow/lite/tools:logging", + "//tensorflow/lite/tools/evaluation:evaluation_delegate_provider", "//tensorflow/lite/tools/evaluation:evaluation_stage", "//tensorflow/lite/tools/evaluation:utils", "//tensorflow/lite/tools/evaluation/proto:evaluation_config_cc_proto", diff --git a/tensorflow/lite/tools/evaluation/tasks/coco_object_detection/run_eval.cc b/tensorflow/lite/tools/evaluation/tasks/coco_object_detection/run_eval.cc index 453d4ce74d9296..a35accb00aec15 100644 --- a/tensorflow/lite/tools/evaluation/tasks/coco_object_detection/run_eval.cc +++ b/tensorflow/lite/tools/evaluation/tasks/coco_object_detection/run_eval.cc @@ -22,6 +22,7 @@ limitations under the License. 
#include "absl/types/optional.h" #include "tensorflow/lite/core/c/common.h" #include "tensorflow/lite/tools/command_line_flags.h" +#include "tensorflow/lite/tools/evaluation/evaluation_delegate_provider.h" #include "tensorflow/lite/tools/evaluation/proto/evaluation_config.pb.h" #include "tensorflow/lite/tools/evaluation/proto/evaluation_stages.pb.h" #include "tensorflow/lite/tools/evaluation/stages/object_detection_stage.h" From 34e68084b44c782dcde7f6a9666b9efc0e965da0 Mon Sep 17 00:00:00 2001 From: zoranjovanovic-ns <126815388+zoranjovanovic-ns@users.noreply.github.com> Date: Mon, 24 Jun 2024 23:13:17 -0700 Subject: [PATCH 200/256] PR #13003: [ROCm] Switch on Triton feature for ROCm. Imported from GitHub PR https://github.com/openxla/xla/pull/13003 Copybara import of the project: -- f0e5fa9facd95011ac427d3e5f5ab91e5e89c39d by Zoran Jovanovic : [ROCm] Switch on Triton feature for ROCm. -- 9ebba63afde51d7ca1f9333678e15653a7a02d5d by Zoran Jovanovic : [ROCm] Fixed an issue with test cases from ir_emitter_triton_test.cc -- 4124c03cf3e32f0c01d41b41a8fbd59bf3da14ae by Zoran Jovanovic : [ROCm] Fixed an issue with gpu_compiler_test.cc -- 0098edc9c05103368985db1c8744018880764a40 by Zoran Jovanovic : [ROCm] Applied comments from code review. 
-- 3f38ab121b3a68bc83704814a9fece7cdf67ded1 by Zoran Jovanovic : [ROCm] Fixed failed tests because of https://github.com/openxla/xla/commit/19c11baa83f31e25a3f841cf41fa47a53e8ca161 Merging this change closes #13003 PiperOrigin-RevId: 646345795 --- third_party/triton/temporary/amd_pr7.patch | 46 ++++++++++++++++++ third_party/triton/temporary/series.bzl | 1 + .../triton/temporary/amd_pr7.patch | 46 ++++++++++++++++++ .../third_party/triton/temporary/series.bzl | 1 + third_party/xla/xla/service/gpu/BUILD | 4 +- .../xla/xla/service/gpu/amdgpu_compiler.cc | 7 +-- .../xla/xla/service/gpu/fusions/triton.cc | 3 ++ .../xla/xla/service/gpu/gemm_fusion.cc | 10 ++-- .../xla/xla/service/gpu/gemm_fusion_test.cc | 18 ------- .../xla/xla/service/gpu/gpu_compiler.cc | 8 +++- .../xla/service/gpu/ir_emitter_triton_rocm.cc | 47 ++++++++++--------- .../xla/service/gpu/ir_emitter_triton_test.cc | 12 +++++ .../xla/service/gpu/ir_emitter_unnested.cc | 6 +-- .../xla/stream_executor/device_description.h | 5 ++ 14 files changed, 163 insertions(+), 51 deletions(-) create mode 100644 third_party/triton/temporary/amd_pr7.patch create mode 100644 third_party/xla/third_party/triton/temporary/amd_pr7.patch diff --git a/third_party/triton/temporary/amd_pr7.patch b/third_party/triton/temporary/amd_pr7.patch new file mode 100644 index 00000000000000..4dbf9c37ae0dc4 --- /dev/null +++ b/third_party/triton/temporary/amd_pr7.patch @@ -0,0 +1,46 @@ +==== triton/BUILD#46 - /google/src/cloud/csigg/triton_amd/triton/BUILD ==== +# action=edit type=text +--- triton/BUILD 2024-04-11 02:00:21.000000000 -0700 ++++ triton/BUILD 2024-04-21 23:52:01.000000000 -0700 +@@ -725,12 +725,12 @@ + "@llvm-project//llvm:Support", + "@llvm-project//mlir:ControlFlowDialect", + "@llvm-project//mlir:GPUDialect", ++ "@llvm-project//mlir:GPUToROCDLTransforms", + "@llvm-project//mlir:IR", +- "@llvm-project//mlir:LLVMCommonConversion", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:NVVMDialect", + 
"@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", +- "@llvm-project//mlir:TransformUtils", ++ "@llvm-project//mlir:ROCDLDialect", + "@llvm-project//mlir:Transforms", + ], + ) +diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp +index f59efd6..cf601f0 100644 +--- a/third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp ++++ b/third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp +@@ -1132,6 +1132,21 @@ struct FpToFpOpConversion + for (unsigned i = 0; i < std::min(numElements, operands.size()); i++) { + inVals.push_back(operands[i][0]); + } ++ ++ bool isSrcFP16 = srcElementType.isF16(); ++ bool isSrcBF16 = srcElementType.isBF16(); ++ ++ if ((isSrcFP16 || isSrcBF16) ++ && isDstFP32) { ++ SmallVector outVals; ++ for (Value &v : inVals) { ++ if(isSrcFP16) ++ outVals.push_back(convertFp16ToFp32(loc, rewriter, v)); ++ else ++ outVals.push_back(convertBf16ToFp32(loc, rewriter, v)); ++ } ++ return outVals; ++ } + if (useFP16IntermediateSrc) + for (Value &v : inVals) + v = cvtFp32ToFp16(loc, rewriter, v, diff --git a/third_party/triton/temporary/series.bzl b/third_party/triton/temporary/series.bzl index 2e2e98d90bb87b..76d9121963651a 100644 --- a/third_party/triton/temporary/series.bzl +++ b/third_party/triton/temporary/series.bzl @@ -7,4 +7,5 @@ internal patch during the next triton integration process. 
temporary_patch_list = [ "//third_party/triton/temporary:linear_layout_rank_fix.patch", + "//third_party/triton/temporary:amd_pr7.patch", ] diff --git a/third_party/xla/third_party/triton/temporary/amd_pr7.patch b/third_party/xla/third_party/triton/temporary/amd_pr7.patch new file mode 100644 index 00000000000000..4dbf9c37ae0dc4 --- /dev/null +++ b/third_party/xla/third_party/triton/temporary/amd_pr7.patch @@ -0,0 +1,46 @@ +==== triton/BUILD#46 - /google/src/cloud/csigg/triton_amd/triton/BUILD ==== +# action=edit type=text +--- triton/BUILD 2024-04-11 02:00:21.000000000 -0700 ++++ triton/BUILD 2024-04-21 23:52:01.000000000 -0700 +@@ -725,12 +725,12 @@ + "@llvm-project//llvm:Support", + "@llvm-project//mlir:ControlFlowDialect", + "@llvm-project//mlir:GPUDialect", ++ "@llvm-project//mlir:GPUToROCDLTransforms", + "@llvm-project//mlir:IR", +- "@llvm-project//mlir:LLVMCommonConversion", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:NVVMDialect", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", +- "@llvm-project//mlir:TransformUtils", ++ "@llvm-project//mlir:ROCDLDialect", + "@llvm-project//mlir:Transforms", + ], + ) +diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp +index f59efd6..cf601f0 100644 +--- a/third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp ++++ b/third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp +@@ -1132,6 +1132,21 @@ struct FpToFpOpConversion + for (unsigned i = 0; i < std::min(numElements, operands.size()); i++) { + inVals.push_back(operands[i][0]); + } ++ ++ bool isSrcFP16 = srcElementType.isF16(); ++ bool isSrcBF16 = srcElementType.isBF16(); ++ ++ if ((isSrcFP16 || isSrcBF16) ++ && isDstFP32) { ++ SmallVector outVals; ++ for (Value &v : inVals) { ++ if(isSrcFP16) ++ outVals.push_back(convertFp16ToFp32(loc, rewriter, v)); ++ else ++ outVals.push_back(convertBf16ToFp32(loc, rewriter, v)); ++ } ++ return outVals; 
++ } + if (useFP16IntermediateSrc) + for (Value &v : inVals) + v = cvtFp32ToFp16(loc, rewriter, v, diff --git a/third_party/xla/third_party/triton/temporary/series.bzl b/third_party/xla/third_party/triton/temporary/series.bzl index 2e2e98d90bb87b..76d9121963651a 100644 --- a/third_party/xla/third_party/triton/temporary/series.bzl +++ b/third_party/xla/third_party/triton/temporary/series.bzl @@ -7,4 +7,5 @@ internal patch during the next triton integration process. temporary_patch_list = [ "//third_party/triton/temporary:linear_layout_rank_fix.patch", + "//third_party/triton/temporary:amd_pr7.patch", ] diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index 4a4efce3557eed..4b808a5b0a5002 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -583,12 +583,14 @@ cc_library( "@triton//:TritonGPUToLLVM", "@triton//:TritonToTritonGPU", "@triton//:TritonGPUTransforms", + "@triton//:TritonLLVMIR", ]) + if_cuda_is_configured([ "@triton//third_party/nvidia:NVGPUToLLVM", "@triton//third_party/nvidia:TritonNVIDIAGPUToLLVM", - "@triton//:TritonLLVMIR", ]) + if_rocm_is_configured([ "@local_tsl//tsl/platform:rocm_rocdl_path", + "@triton//third_party/amd:TritonAMDGPUToLLVM", + "@triton//third_party/amd:TritonAMDGPUTransforms", ]), ) diff --git a/third_party/xla/xla/service/gpu/amdgpu_compiler.cc b/third_party/xla/xla/service/gpu/amdgpu_compiler.cc index fc56b9bc4f0360..04483ce86c78c9 100644 --- a/third_party/xla/xla/service/gpu/amdgpu_compiler.cc +++ b/third_party/xla/xla/service/gpu/amdgpu_compiler.cc @@ -39,6 +39,7 @@ limitations under the License. 
#include "xla/service/gpu/cudnn_fused_conv_rewriter.h" #include "xla/service/gpu/cusolver_rewriter.h" #include "xla/service/gpu/gemm_algorithm_picker.h" +#include "xla/service/gpu/gpu_algebraic_simplifier.h" #include "xla/service/gpu/gpu_compiler.h" #include "xla/service/gpu/gpu_conv_padding_legalization.h" #include "xla/service/gpu/gpu_conv_rewriter.h" @@ -141,7 +142,7 @@ absl::Status AMDGPUCompiler::OptimizeHloConvolutionCanonicalization( GetAlgebraicSimplifierOptions(hlo_module->config()); options.set_enable_conv_operand_swap(false); options.set_enable_unconditional_reduce_of_concat_replacement(false); - pipeline.AddPass>(options); + pipeline.AddPass>(options, gpu_version); // tf2xla bridge, DepthwiseConvolutionConverter, GpuConvRewriter, and // CudnnSimplifyPadding introduce reshapes and transposes. Run ReshapeMover @@ -151,7 +152,7 @@ absl::Status AMDGPUCompiler::OptimizeHloConvolutionCanonicalization( ReshapeMoverOptions reshape_mover_options; reshape_mover_options.reshape_of_1d_broadcast_is_cheap = true; pipeline.AddPass(reshape_mover_options); - pipeline.AddPass(options); + pipeline.AddPass(options, gpu_version); }(); // The reshapes and transposes can possibly be eliminated using @@ -162,7 +163,7 @@ absl::Status AMDGPUCompiler::OptimizeHloConvolutionCanonicalization( [&, &pipeline = pipeline.AddPass>( "simplify_after_conv_canonicalization")] { pipeline.AddPass(); - pipeline.AddPass(options); + pipeline.AddPass(options, gpu_version); }(); // GpuConvRewriter, GpuConvPaddingLegalization and diff --git a/third_party/xla/xla/service/gpu/fusions/triton.cc b/third_party/xla/xla/service/gpu/fusions/triton.cc index d983ab056ba213..f31724b71fa728 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton.cc +++ b/third_party/xla/xla/service/gpu/fusions/triton.cc @@ -205,6 +205,9 @@ absl::StatusOr TritonFusion::Emit( triton_config.set_block_k(64); triton_config.set_block_n(64); triton_config.set_split_k(1); + triton_config.set_num_stages(1); + 
triton_config.set_num_warps(2); + triton_config.set_num_ctas(1); block_level_parameters.num_ctas = 1; block_level_parameters.num_stages = 1; diff --git a/third_party/xla/xla/service/gpu/gemm_fusion.cc b/third_party/xla/xla/service/gpu/gemm_fusion.cc index 6fd59c31b2e54c..e80100c806beb0 100644 --- a/third_party/xla/xla/service/gpu/gemm_fusion.cc +++ b/third_party/xla/xla/service/gpu/gemm_fusion.cc @@ -800,10 +800,14 @@ absl::StatusOr GemmFusion::Run( const absl::flat_hash_set& execution_threads) { auto cuda_compute_capability = std::get_if(&gpu_version_); - if (!cuda_compute_capability) { + auto rocm_compute_capability = + std::get_if(&gpu_version_); + if (!cuda_compute_capability && !rocm_compute_capability) { return absl::FailedPreconditionError( - "Triton support is only enabled for CUDA GPUs."); - } else if (!cuda_compute_capability->IsAtLeastAmpere()) { + "Triton support is only enabled for CUDA and ROCm GPUs."); + } + + if (cuda_compute_capability && !cuda_compute_capability->IsAtLeastAmpere()) { return absl::FailedPreconditionError( absl::StrCat("Triton support is only enabled for Ampere GPUs (compute ", "capability 8.0) and up, but got compute capability ", diff --git a/third_party/xla/xla/service/gpu/gemm_fusion_test.cc b/third_party/xla/xla/service/gpu/gemm_fusion_test.cc index 1eb6066af9be46..a07ed2e5604704 100644 --- a/third_party/xla/xla/service/gpu/gemm_fusion_test.cc +++ b/third_party/xla/xla/service/gpu/gemm_fusion_test.cc @@ -876,24 +876,6 @@ ENTRY e { "(compute capability 8.0) and up, but got"))); } -TEST_F(GemmFusionLevel2Test, GemmFusionBailsOutOnNonCudaGpu) { - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(R"( -ENTRY e { - p0 = f32[2,53] parameter(0) - p0e = f32[2,53] exponential(p0) - p1 = s16[53,2] parameter(1) - p1c = f32[53,2] convert(p1) - ROOT dot = f32[2,2] dot(p0e, p1c), - lhs_contracting_dims={1}, rhs_contracting_dims={0} -})")); - EXPECT_THAT( - GemmFusion(se::RocmComputeCapability{}).Run(module.get()), 
- tsl::testing::StatusIs( - absl::StatusCode::kFailedPrecondition, - ::testing::StrEq("Triton support is only enabled for CUDA GPUs."))); -} - TEST_F(GemmFusionLevel2Test, ParameterUsedElementwiseTwiceIsFused) { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(R"( diff --git a/third_party/xla/xla/service/gpu/gpu_compiler.cc b/third_party/xla/xla/service/gpu/gpu_compiler.cc index f342425606d158..06874df49a2714 100644 --- a/third_party/xla/xla/service/gpu/gpu_compiler.cc +++ b/third_party/xla/xla/service/gpu/gpu_compiler.cc @@ -1340,16 +1340,20 @@ absl::Status GpuCompiler::OptimizeHloPostLayoutAssignment( gpu_target_config.device_description.gpu_compute_capability(); pipeline.AddPass(gpu_version); const auto* cuda_cc = std::get_if(&gpu_version); + const auto* rocm_cc = std::get_if(&gpu_version); // Rewrite FP8 GEMMs ahead of Triton which currently lacks support for FP8 // and may rewrite quantized FP8 GEMMs as higher-precision GEMMs. pipeline.AddPass(gpu_version, GetToolkitVersion(), /*f8_rewrite=*/true); - if (debug_options.xla_gpu_enable_triton_gemm() && cuda_cc != nullptr && - cuda_cc->IsAtLeast(se::CudaComputeCapability::AMPERE)) { + if (debug_options.xla_gpu_enable_triton_gemm() && + ((cuda_cc != nullptr && + cuda_cc->IsAtLeast(se::CudaComputeCapability::AMPERE)) || + rocm_cc)) { pipeline.AddPass(); pipeline.AddPass(gpu_version); } + // Rewrite non-FP8 GEMMs. pipeline.AddPass(gpu_version, GetToolkitVersion(), /*f8_rewrite=*/false); diff --git a/third_party/xla/xla/service/gpu/ir_emitter_triton_rocm.cc b/third_party/xla/xla/service/gpu/ir_emitter_triton_rocm.cc index b0d5dc5187c33a..9a39fb3f886279 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter_triton_rocm.cc +++ b/third_party/xla/xla/service/gpu/ir_emitter_triton_rocm.cc @@ -14,7 +14,8 @@ limitations under the License. 
==============================================================================*/ // TODO(ROCm): Enable and include ROCm Triton passes when ROCm Triton is // included in build. -// #include "third_party/amd/include/TritonAMDGPUToLLVM/Passes.h" +#include "third_party/amd/include/TritonAMDGPUToLLVM/Passes.h" +#include "third_party/amd/include/TritonAMDGPUTransforms/Passes.h" #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" // from @llvm-project #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" // from @llvm-project #include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h" // from @llvm-project @@ -56,9 +57,10 @@ absl::Status CreateTritonPipeline( const int ccAsInt = 0; // TODO(ROCm): Check why some test fail when threadsPerWarp is set to 64. const int threadsPerWarp = 32; + auto ccRocm = std::get(cc); // Based on make_ttir() in - // @triton//:third_party/nvidia/backend/compiler.py + // @triton//:third_party/amd/backend/compiler.py pm.addPass(mlir::createInlinerPass()); pm.addPass(mt::createRewriteTensorPointerPass()); pm.addPass(mt::createCombineOpsPass()); @@ -69,47 +71,50 @@ absl::Status CreateTritonPipeline( pm.addPass(mlir::createSymbolDCEPass()); // Based on make_ttgir() in - // @triton//:third_party/nvidia/backend/compiler.py + // @triton//:third_party/amd/backend/compiler.py pm.addPass(mt::createConvertTritonToTritonGPUPass( - absl::StrFormat("cuda:%u", ccAsInt), block_level_parameters.num_warps, - threadsPerWarp, block_level_parameters.num_ctas)); + absl::StrCat("hip:", ccRocm.gfx_version()), + block_level_parameters.num_warps, threadsPerWarp, + block_level_parameters.num_ctas)); pm.addPass(mt::gpu::createTritonGPUCoalesce()); pm.addPass(mt::gpu::createTritonGPURemoveLayoutConversions()); pm.addPass(mt::gpu::createTritonGPUOptimizeThreadLocality()); - pm.addPass(createSparseBlockedToMMAPass()); pm.addPass(mt::gpu::createTritonGPUAccelerateMatmul()); pm.addPass(mt::gpu::createTritonGPURemoveLayoutConversions()); // TODO ROCm Check if we want to 
compare MI100 and greater + pm.addPass(mlir::createTritonAMDGPUOptimizeEpiloguePass()); pm.addPass(mt::gpu::createTritonGPUOptimizeDotOperands({true})); - pm.addPass(mlir::createCSEPass()); - pm.addPass( - mt::gpu::createTritonGPUPipeline({block_level_parameters.num_stages})); - pm.addPass(mt::gpu::createTritonGPUPrefetch()); - - // TODO ROCm Check if we want to compare MI100 and greater + if (block_level_parameters.num_stages == 0 && ccRocm.has_amd_matrix_core()) { + pm.addPass(mlir::createTritonAMDGPUStreamPipelinePass()); + pm.addPass(mlir::createCanonicalizerPass()); + } pm.addPass(mt::gpu::createTritonGPUOptimizeDotOperands({true})); pm.addPass(mt::gpu::createTritonGPURemoveLayoutConversions()); pm.addPass(mt::gpu::createTritonGPUReduceDataDuplication()); - pm.addPass(mt::gpu::createTritonGPUReorderInstructions()); + if (block_level_parameters.num_stages == 0) { + pm.addPass(mt::gpu::createTritonGPUReorderInstructions()); + } pm.addPass(mlir::createCSEPass()); pm.addPass(mlir::createSymbolDCEPass()); - pm.addPass(mlir::createCanonicalizerPass()); // Based on make_llir() in - // @triton//:third_party/nvidia/backend/compiler.py - // pm.addPass(mt::gpu::createDecomposeUnsupportedConversionsPass()); + // @triton//:third_party/amd/backend/compiler.py + pm.addPass(mlir::triton::AMD::createDecomposeUnsupportedConversionsPass( + ccRocm.gfx_version())); pm.addPass(mlir::createConvertSCFToCFPass()); pm.addPass(mlir::createConvertIndexToLLVMPass()); pm.addPass(mt::gpu::createAllocateSharedMemoryPass()); - // pm.addPass(mt::createConvertTritonAMDGPUToLLVMPass()); - pm.addPass(mlir::createArithToLLVMConversionPass()); + pm.addPass( + mt::createConvertTritonAMDGPUToLLVMPass(ccRocm.gfx_version(), true)); pm.addPass(mlir::createCanonicalizerPass()); pm.addPass(mlir::createCSEPass()); - pm.addPass(mlir::createSymbolDCEPass()); // Note: translateTritonGPUToLLVMIR adds line info with LLVMDIScopePass. 
- pm.addPass(mlir::createConvertSCFToCFPass()); pm.addPass(mlir::createConvertControlFlowToLLVMPass()); - + pm.addPass(mlir::createArithToLLVMConversionPass()); + pm.addPass(mlir::createCanonicalizerPass()); + pm.addPass(mlir::createCSEPass()); + pm.addPass(mlir::createSymbolDCEPass()); + pm.addPass(mt::createConvertBuiltinFuncToLLVMPass()); // There is no clusters in ROCm for now. out_cluster_info.clusterDimX = 1; out_cluster_info.clusterDimY = 1; diff --git a/third_party/xla/xla/service/gpu/ir_emitter_triton_test.cc b/third_party/xla/xla/service/gpu/ir_emitter_triton_test.cc index af403cb6875b2e..224dc94a1b3a2b 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter_triton_test.cc +++ b/third_party/xla/xla/service/gpu/ir_emitter_triton_test.cc @@ -2458,6 +2458,9 @@ ENTRY e { TEST_F(TritonGemmTestAny, DoNotFuseConcatenationOfSplitNonContractingDimension) { + if (std::holds_alternative(GpuComputeComp())) { + GTEST_SKIP() << "Not using autotuner on ROCM yet."; + } if (!SupportsBF16(GpuComputeComp())) { GTEST_SKIP() << "BF16 not supported."; } @@ -3609,6 +3612,9 @@ ENTRY e { } TEST_F(CompareTest, UsingOptinSharedMemoryOnAmpereProducesSameResult) { + if (std::holds_alternative(GpuComputeComp())) { + GTEST_SKIP() << "No Optin Shared Memory on AMD."; + } const se::DeviceDescription dev_info = backend().default_stream_executor()->GetDeviceDescription(); constexpr int kBytesOfSharedMemoryTested = 64 * 1024; @@ -5061,6 +5067,9 @@ CHECK-COUNT-6: %{{.*}} = tt.dot %{{.*}}, %{{.*}}, %{{.*}} : tensor<32x32xbf16> } TEST_F(Triton6xBF16GemmTest, Emit6xBF16GemmEndToEnd) { + if (std::holds_alternative(GpuComputeComp())) { + GTEST_SKIP() << "ALG_DOT_BF16_BF16_F32_X6 not supported on ROCM."; + } const char* kHloText = R"( HloModule t @@ -5404,6 +5413,9 @@ CHECK-COUNT-3: %{{.*}} = tt.dot %{{.*}}, %{{.*}}, %{{.*}} : tensor<32x32xbf16> } TEST_F(Triton3xBF16GemmTest, Emit3xBF16GemmEndToEnd) { + if (std::holds_alternative(GpuComputeComp())) { + GTEST_SKIP() << "ALG_DOT_BF16_BF16_F32_X3 not 
supported on ROCM."; + } const char* kHloText = R"( HloModule t diff --git a/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc b/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc index 0580a2d611c021..33b37150c07c52 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc +++ b/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc @@ -1584,8 +1584,8 @@ absl::Status IrEmitterUnnested::EmitTopKCustomCall( absl::Status IrEmitterUnnested::EmitTritonCustomCall( const HloCustomCallInstruction* instr) { -#if !GOOGLE_CUDA - return absl::UnimplementedError("Triton support requires CUDA"); +#if !GOOGLE_CUDA && !TENSORFLOW_USE_ROCM + return absl::UnimplementedError("Triton support requires CUDA or ROCm"); #else auto generate = [this, &instr]() -> absl::StatusOr { mlir::MLIRContext& mlir_context = *ir_emitter_context_->mlir_context(); @@ -1613,7 +1613,7 @@ absl::Status IrEmitterUnnested::EmitTritonCustomCall( TF_ASSIGN_OR_RETURN( auto result, CompileTritonToLLVM(hlo_module->config(), hlo_module->name(), - ir_emitter_context_->cuda_compute_capability(), + ir_emitter_context_->gpu_compute_capability(), ir_emitter_context_->gpu_device_info(), block_level_parameters, triton_module.get(), ir_emitter_context_->llvm_module(), mlir_context)); diff --git a/third_party/xla/xla/stream_executor/device_description.h b/third_party/xla/xla/stream_executor/device_description.h index 3b5267bf305d8d..94f90b2d8422f7 100644 --- a/third_party/xla/xla/stream_executor/device_description.h +++ b/third_party/xla/xla/stream_executor/device_description.h @@ -218,6 +218,11 @@ class RocmComputeCapability { bool has_mfma_instr_support() const { return gfx9_mi100_or_later(); } + bool has_amd_matrix_core() const { + return (gfx9_mi100_or_later() || gfx_version().find("gfx11") != std::string::npos || + gfx_version().find("gfx12") != std::string::npos); + } + bool has_fp16_atomics_support() const { // TODO(rocm): Check. This should be the same as has_fast_fp16_support().
return gfx9_mi200_or_later(); From f5cd3ce94a296af4202113be110da5d5844e0a3a Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 24 Jun 2024 23:36:25 -0700 Subject: [PATCH 201/256] HLO Buffer Assignment: include HloAliasAnalysis in BufferAssigner::MustNotLiveOut function parameters. PiperOrigin-RevId: 646350803 --- third_party/xla/xla/service/BUILD | 1 + third_party/xla/xla/service/buffer_assignment.cc | 6 ++++-- third_party/xla/xla/service/buffer_assignment.h | 4 ++-- third_party/xla/xla/service/buffer_assignment_test.cc | 4 +++- 4 files changed, 10 insertions(+), 5 deletions(-) diff --git a/third_party/xla/xla/service/BUILD b/third_party/xla/xla/service/BUILD index 6a70a0aa96c65d..59a5dc94278e45 100644 --- a/third_party/xla/xla/service/BUILD +++ b/third_party/xla/xla/service/BUILD @@ -1858,6 +1858,7 @@ xla_cc_test( ":copy_insertion", ":cpu_plugin", ":flatten_call_graph", + ":hlo_alias_analysis", ":hlo_dce", ":hlo_memory_scheduler", ":hlo_ordering", diff --git a/third_party/xla/xla/service/buffer_assignment.cc b/third_party/xla/xla/service/buffer_assignment.cc index 0ac5dd39e9692d..0f366f7d514272 100644 --- a/third_party/xla/xla/service/buffer_assignment.cc +++ b/third_party/xla/xla/service/buffer_assignment.cc @@ -1201,7 +1201,8 @@ bool BufferAssigner::MaybeAssignBuffer(BufferAllocation* allocation, // If a buffer maybe live out, the allocation cannot contain any node // where must_not_live_out_ returns true. 
for (const HloValue* value : hlo_buffer.values()) { - if ((*must_not_live_out_)(value->instruction(), value->index())) { + if ((*must_not_live_out_)(assignment->alias_analysis(), + value->instruction(), value->index())) { VLOG(4) << "Can't assign: " << value->instruction()->ToString() << " cannot live out of the module"; return false; @@ -1216,7 +1217,8 @@ bool BufferAssigner::MaybeAssignBuffer(BufferAllocation* allocation, if (assignment->alias_analysis().BufferLivesOut(hlo_buffer)) { for (const auto& buffer_offset_size : allocation->assigned_buffers()) { const HloValue* value = buffer_offset_size.first; - if ((*must_not_live_out_)(value->instruction(), value->index())) { + if ((*must_not_live_out_)(assignment->alias_analysis(), + value->instruction(), value->index())) { VLOG(4) << "Can't assign: " << value->instruction() << " cannot live out of the module"; return false; diff --git a/third_party/xla/xla/service/buffer_assignment.h b/third_party/xla/xla/service/buffer_assignment.h index b6887beb8546c3..b3f05e345bd767 100644 --- a/third_party/xla/xla/service/buffer_assignment.h +++ b/third_party/xla/xla/service/buffer_assignment.h @@ -623,8 +623,8 @@ class BufferAssigner { public: using Colorer = std::function; - using MustNotLiveOut = - std::function; + using MustNotLiveOut = std::function; using PrivateStacks = absl::flat_hash_map>; diff --git a/third_party/xla/xla/service/buffer_assignment_test.cc b/third_party/xla/xla/service/buffer_assignment_test.cc index 40061a1c7d4d41..a11b86ca357043 100644 --- a/third_party/xla/xla/service/buffer_assignment_test.cc +++ b/third_party/xla/xla/service/buffer_assignment_test.cc @@ -36,6 +36,7 @@ limitations under the License. 
#include "xla/service/copy_insertion.h" #include "xla/service/flatten_call_graph.h" #include "xla/service/hlo.pb.h" +#include "xla/service/hlo_alias_analysis.h" #include "xla/service/hlo_dce.h" #include "xla/service/hlo_memory_scheduler.h" #include "xla/service/hlo_ordering.h" @@ -143,7 +144,8 @@ class BufferAssignmentTest : public HloTestBase { std::unique_ptr RunBufferAssignmentNoBuffersReuseForAdd( HloModule* module, int64_t alignment = 1) { - auto must_not_live_out = [](const HloInstruction* instruction, + auto must_not_live_out = [](const HloAliasAnalysis& alias_analysis, + const HloInstruction* instruction, const ShapeIndex&) { return instruction->opcode() == HloOpcode::kAdd; }; From e75634a7c117e2e64869d9909918fb6dd75a3535 Mon Sep 17 00:00:00 2001 From: Johannes Reifferscheid Date: Mon, 24 Jun 2024 23:39:32 -0700 Subject: [PATCH 202/256] NFC: Move floordiv rewrite for LHS sums to a separate function. PiperOrigin-RevId: 646351431 --- .../xla/xla/service/gpu/model/indexing_map.cc | 118 +++++++++--------- 1 file changed, 59 insertions(+), 59 deletions(-) diff --git a/third_party/xla/xla/service/gpu/model/indexing_map.cc b/third_party/xla/xla/service/gpu/model/indexing_map.cc index 9b7045310ecd5c..bcfcf494a5a5b9 100644 --- a/third_party/xla/xla/service/gpu/model/indexing_map.cc +++ b/third_party/xla/xla/service/gpu/model/indexing_map.cc @@ -115,11 +115,14 @@ class AffineExprSimplifier { mlir::AffineExpr RewriteFloorDiv(mlir::AffineBinaryOpExpr div); // Rewrites `(c % ab) // a` to `(c // a) % b`. Returns nullptr on mismatch. - AffineExpr SimplifyModDiv(AffineExpr divisor, int64_t dividend); + AffineExpr SimplifyModDiv(AffineExpr dividend, int64_t divisor); // Rewrites `a // b // c` to `a // (b * c)` if `c` is positive. Returns // nullptr on mismatch. - AffineExpr SimplifyDivDiv(AffineExpr divisor, int64_t dividend); + AffineExpr SimplifyDivDiv(AffineExpr dividend, int64_t divisor); + + // Rewrites `a // b` where a may be a sum. 
+ AffineExpr SimplifySumDiv(AffineExpr dividend, int64_t divisor); // Rewrites summands in arbitrarily nested sums (e.g, ((a+b)+c)) by applying // `fn` to each one. In the example, the result is fn(a)+fn(b)+fn(c). @@ -183,15 +186,13 @@ AffineExpr AffineExprSimplifier::RewriteMod(AffineBinaryOpExpr mod) { auto zero = getAffineConstantExpr(0, mod.getContext()); int64_t extracted_constant = 0; auto new_lhs = MapSummands(lhs_simplified, [&](AffineExpr expr) { - if (auto cst = mlir::dyn_cast(expr); - cst && cst.getValue() >= m) { + if (auto cst = mlir::dyn_cast(expr)) { extracted_constant += cst.getValue(); return zero; } - if (auto multiplier = GetConstantRhs(expr, AffineExprKind::Mul)) { - if (*multiplier % m == 0) { - return zero; - } + if (auto multiplier = GetConstantRhs(expr, AffineExprKind::Mul); + multiplier && (*multiplier % m == 0)) { + return zero; } return expr; }); @@ -231,57 +232,33 @@ AffineExpr AffineExprSimplifier::RewriteMod(AffineBinaryOpExpr mod) { return new_lhs % mod.getRHS() + extracted; } -AffineExpr AffineExprSimplifier::SimplifyModDiv(AffineExpr divisor, - int64_t dividend) { - if (auto mod = GetConstantRhs(divisor, AffineExprKind::Mod); - mod && (*mod % dividend == 0)) { - return GetLhs(divisor).floorDiv(dividend) % (*mod / dividend); +AffineExpr AffineExprSimplifier::SimplifyModDiv(AffineExpr dividend, + int64_t divisor) { + if (auto mod = GetConstantRhs(dividend, AffineExprKind::Mod); + mod && (*mod % divisor == 0)) { + return GetLhs(dividend).floorDiv(divisor) % (*mod / divisor); } return nullptr; } -AffineExpr AffineExprSimplifier::SimplifyDivDiv(AffineExpr divisor, - int64_t dividend) { - // The inner dividend here can be negative. - if (auto dividend_2 = GetConstantRhs(divisor, AffineExprKind::FloorDiv)) { - return GetLhs(divisor).floorDiv(dividend * *dividend_2); +AffineExpr AffineExprSimplifier::SimplifyDivDiv(AffineExpr dividend, + int64_t divisor) { + // The inner divisor here can be negative. 
+ if (auto inner_divisor = GetConstantRhs(dividend, AffineExprKind::FloorDiv)) { + return GetLhs(dividend).floorDiv(divisor * *inner_divisor); } return nullptr; } -AffineExpr AffineExprSimplifier::RewriteFloorDiv(AffineBinaryOpExpr div) { - auto mlir_context = range_evaluator_->GetMLIRContext(); - auto rhs_range = range_evaluator_->ComputeExpressionRange(div.getRHS()); - - // TODO(jreiffers): Split this function into multiple (one for each rewrite - // rule). - - // The logic below assumes we have a constant positive RHS. - if (!rhs_range.IsPoint() || rhs_range.lower <= 0) { - return div; - } - int64_t d = rhs_range.lower; - - auto lhs_simplified = SimplifyOnce(div.getLHS()); - - // Rewrite `(c % ab) // a` to `(c // a) % b`. - if (auto result = SimplifyModDiv(lhs_simplified, d)) { - return result; - } - - // Rewrite `((a // b) // c)` to `a // (b * c)`. - if (auto result = SimplifyDivDiv(lhs_simplified, d)) { - return result; - } - - AffineExpr zero = getAffineConstantExpr(0, mlir_context); +AffineExpr AffineExprSimplifier::SimplifySumDiv(AffineExpr dividend, + int64_t divisor) { + AffineExpr zero = getAffineConstantExpr(0, dividend.getContext()); AffineExpr extracted = zero; - auto new_dividend = MapSummands(lhs_simplified, [&](AffineExpr expr) { + auto new_dividend = MapSummands(dividend, [&](AffineExpr expr) { if (auto multiplier = GetConstantRhs(expr, AffineExprKind::Mul)) { - // (x * 7 + ...) / 3 -> can't extract. We could extract x * 2 and keep - // one x, but we currently have no reason to do that. - if (*multiplier % d == 0) { - int64_t factor = *multiplier / d; + // We can extract summands whose factor is a multiple of the divisor. + if (*multiplier % divisor == 0) { + int64_t factor = *multiplier / divisor; extracted = extracted + GetLhs(expr) * factor; // Remove from dividend. return zero; @@ -291,8 +268,8 @@ AffineExpr AffineExprSimplifier::RewriteFloorDiv(AffineBinaryOpExpr div) { return expr; }); - // The gcd of all multipliers and the dividend. 
- int64_t multiplier_divisor_gcd = d; + // The gcd of all multipliers and the divisor. + int64_t multiplier_divisor_gcd = divisor; Interval no_multiplier_range{0, 0}; VisitSummands(new_dividend, [&](AffineExpr summand) { if (auto multiplier = GetConstantRhs(summand, AffineExprKind::Mul)) { @@ -316,10 +293,39 @@ AffineExpr AffineExprSimplifier::RewriteFloorDiv(AffineBinaryOpExpr div) { // the result of the division. return zero; }); - d /= multiplier_divisor_gcd; + divisor /= multiplier_divisor_gcd; } - return new_dividend.floorDiv(d) + extracted; + return new_dividend.floorDiv(divisor) + extracted; +} + +AffineExpr AffineExprSimplifier::RewriteFloorDiv(AffineBinaryOpExpr div) { + auto rhs_range = range_evaluator_->ComputeExpressionRange(div.getRHS()); + + // The logic below assumes we have a constant positive RHS. + if (!rhs_range.IsPoint() || rhs_range.lower <= 0) { + return div; + } + int64_t d = rhs_range.lower; + + auto lhs_simplified = SimplifyOnce(div.getLHS()); + + // Rewrite `(c % ab) // a` to `(c // a) % b`. + if (auto result = SimplifyModDiv(lhs_simplified, d)) { + return result; + } + + // Rewrite `((a // b) // c)` to `a // (b * c)`. + if (auto result = SimplifyDivDiv(lhs_simplified, d)) { + return result; + } + + // Rewrite sums on the LHS. + if (auto result = SimplifySumDiv(lhs_simplified, d)) { + return result; + } + + return div; } std::optional AffineExprSimplifier::GetConstantRhs( @@ -596,12 +602,6 @@ AffineMap AffineExprSimplifier::Simplify(AffineMap affine_map) { affine_map.getContext())); } -// Computes intersection of two ranges. -Interval Intersect(const Interval& lhs, const Interval& rhs) { - return Interval{std::max(lhs.lower, rhs.lower), - std::min(lhs.upper, rhs.upper)}; -} - // Simplifies a constraint range, i.e. a constraint d0 + x in [lb, ub] will // become d0 in [lb - x, ub - x]. Also supports *, floorDiv. 
bool SimplifyConstraintRangeOnce(AffineExpr* expr, Interval* range) { From f7d4c3cb4b473ef360821bb3c835cc8ccad6b35f Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Mon, 24 Jun 2024 23:59:46 -0700 Subject: [PATCH 203/256] [XLA:GPU] Add proper constraint intersection logic for `ConstraintExpression`s. Given two constraints on the same `AffineExpr`, taking their intersection now constrains the `AffineExpr` to the intersection of their intervals. PiperOrigin-RevId: 646355471 --- third_party/xla/xla/service/gpu/BUILD | 1 + .../xla/service/gpu/model/symbolic_tile.cc | 75 ++++++++----------- .../xla/xla/service/gpu/model/symbolic_tile.h | 2 - .../gpu/model/symbolic_tile_analysis.cc | 1 - .../gpu/model/symbolic_tile_analysis_test.cc | 20 ----- .../service/gpu/model/symbolic_tile_test.cc | 41 ++++------ .../gpu/softmax_rewriter_triton_test.cc | 48 +----------- 7 files changed, 53 insertions(+), 135 deletions(-) diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index 4b808a5b0a5002..9755cf25cb81db 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -2592,6 +2592,7 @@ xla_cc_test( ":softmax_rewriter_triton", "//xla:shape_util", "//xla/hlo/ir:hlo", + "//xla/service:instruction_fusion", "//xla/service:pattern_matcher", "//xla/service:pattern_matcher_gmock", "//xla/service/gpu/model:gpu_hlo_cost_analysis", diff --git a/third_party/xla/xla/service/gpu/model/symbolic_tile.cc b/third_party/xla/xla/service/gpu/model/symbolic_tile.cc index 038e83f08cd55c..91fffa1f39cd8a 100644 --- a/third_party/xla/xla/service/gpu/model/symbolic_tile.cc +++ b/third_party/xla/xla/service/gpu/model/symbolic_tile.cc @@ -54,7 +54,7 @@ using ::mlir::AffineSymbolExpr; using ::mlir::getAffineConstantExpr; using ::mlir::getAffineDimExpr; using ::mlir::MLIRContext; -using ConstraintMap = SymbolicTile::ConstraintMap; +using ConjointConstraints = ConstraintExpression::ConjointConstraints; // Gets a modified version of 
`expressions` where both the original dimensions // and symbols are replaced with symbols. @@ -597,42 +597,37 @@ AffineExpr SimplifyAffineExpr(const AffineExpr& expr, return tmp_indexing_map.GetAffineMap().getResults().back(); } -// Merges `maybe_first_map` and `second_map` if -// (1) `maybe_first_map` is present, and -// (2) `second_map` and `*maybe_first_map` have distinct sets of keys. -// Otherwise, returns `std::nullopt`. -// -// -// The behaviour of this function is in spirit equivalent to using C++23's -// `std::optional::and_then` to merge a collection of `ConstraintMap`s. -// -// We pass `maybe_first_map` by value here in order to exploit move semantics -// to avoid copies when possible. -// -// TODO(bchetioui): allow merging constraints in more edge cases, e.g. if one -// of the intervals is contained within the other. -// TODO(bchetioui): clean up this util. -std::optional MergeConstraintMapIfPresentAndCompatible( - std::optional maybe_first_map, - const ConstraintMap& second_map) { - if (!maybe_first_map.has_value()) { - return std::nullopt; - } - - ConstraintMap& first_map = *maybe_first_map; - - for (const auto& [expr, interval] : second_map) { - if (first_map.contains(expr)) { - AffineMapPrinter printer; - VLOG(1) << "Got two different constraints for expression " - << printer.ToString(expr); - return std::nullopt; +// Tries to take the conjunction of `conjunction_1` and `conjunction_2`. +// Fails and returns `std::nullopt` if and only if the conjunction attempt +// results in an unsatisfiable constraint. 
+std::optional TryIntersectConjointConstraints( + ConjointConstraints conjunction_1, + const ConjointConstraints& conjunction_2) { + if (conjunction_1.empty()) { + return conjunction_2; + } + + if (conjunction_2.empty()) { + return std::move(conjunction_1); + } + + ConjointConstraints result = std::move(conjunction_1); + for (const auto& [expr, interval] : conjunction_2) { + if (auto result_it = result.find(expr); result_it != result.end()) { + auto& [result_expr, result_interval] = *result_it; + result_interval = result_interval.Intersect(interval); + if (!result_interval.IsFeasible()) { + AffineMapPrinter printer; + VLOG(1) << "Got two incompatible intervals for expression " + << printer.ToString(expr); + return std::nullopt; + } + } else { + result.insert({expr, interval}); } - - first_map.insert({expr, interval}); } - return first_map; + return result; } } // anonymous namespace @@ -674,8 +669,7 @@ std::optional MergeConstraintMapIfPresentAndCompatible( for (ConjointConstraints& conjunction_2 : second.disjoint_conjoint_constraints_) { std::optional maybe_conjunction = - MergeConstraintMapIfPresentAndCompatible(conjunction_1, - conjunction_2); + TryIntersectConjointConstraints(conjunction_1, conjunction_2); // We only add the resulting conjunction to the result // `ConstraintExpression` if it is satisfiable, since it is otherwise // redundant: @@ -713,8 +707,7 @@ std::optional MergeConstraintMapIfPresentAndCompatible( return first; } -void ConstraintExpression::Or( - ConstraintExpression::ConjointConstraints conjunction) { +void ConstraintExpression::Or(ConjointConstraints conjunction) { if (conjunction.empty()) { return; } @@ -723,8 +716,7 @@ void ConstraintExpression::Or( is_satisfiable_ = true; } -void ConstraintExpression::And( - ConstraintExpression::ConjointConstraints conjunction) { +void ConstraintExpression::And(ConjointConstraints conjunction) { if (!is_satisfiable_ || conjunction.empty()) { return; } @@ -739,8 +731,7 @@ void ConstraintExpression::And( 
for (ConjointConstraints& conjunction_2 : disjoint_conjoint_constraints_) { std::optional maybe_result = - MergeConstraintMapIfPresentAndCompatible(std::move(conjunction_2), - conjunction); + TryIntersectConjointConstraints(std::move(conjunction_2), conjunction); // TODO(bchetioui): rework `MergeConstraintMapIfPresentAndCompatible`. if (maybe_result.has_value()) { new_constraints.push_back(std::move(*maybe_result)); diff --git a/third_party/xla/xla/service/gpu/model/symbolic_tile.h b/third_party/xla/xla/service/gpu/model/symbolic_tile.h index 47b4084ae84ae8..75830944b01308 100644 --- a/third_party/xla/xla/service/gpu/model/symbolic_tile.h +++ b/third_party/xla/xla/service/gpu/model/symbolic_tile.h @@ -230,8 +230,6 @@ class SymbolicTile { public: static std::optional FromIndexingMap(IndexingMap indexing_map); - using ConstraintMap = llvm::DenseMap; - // For printing in tests. std::string RtVarsToString( const AffineMapPrinter& printer = AffineMapPrinter()) const; diff --git a/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis.cc b/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis.cc index 4cea5d1fee6830..415925580cb390 100644 --- a/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis.cc +++ b/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis.cc @@ -69,7 +69,6 @@ namespace { using ::mlir::AffineExpr; using ::mlir::MLIRContext; -using ConstraintMap = SymbolicTile::ConstraintMap; // Computes indexing map from program id into the tile offset for the given // shape and tile sizes. 
diff --git a/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis_test.cc b/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis_test.cc index 74942b19eaae39..97ddfb0036a8ac 100644 --- a/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis_test.cc +++ b/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis_test.cc @@ -511,26 +511,6 @@ ENTRY main { EXPECT_THAT(constraints.DisjointConjointConstraints().front(), SizeIs(2)); } -TEST_F(SymbolicTileAnalysisTest, BailsOutWhenConstraintsCanNotBeMerged) { - // TODO(bchetioui): allow merging a constraint with itself. - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(R"( -fusion { - p0 = f32[1,48,4,8]{3,2,1,0} parameter(0) - p1 = f32[1,48,4,8]{3,2,1,0} parameter(1) - bitcast_p0 = f32[48,32]{1,0} bitcast(p0) - bitcast_p1 = f32[48,32]{1,0} bitcast(p1) - ROOT add = f32[48,32]{1,0} add(bitcast_p0, bitcast_p1) -} - -ENTRY main { - p0 = f32[1,48,4,8]{3,2,1,0} parameter(0) - p1 = f32[1,48,4,8]{3,2,1,0} parameter(1) - ROOT fusion = f32[48,32]{1,0} fusion(p0, p1), kind=kLoop, calls=fusion -})")); - EXPECT_FALSE(TryAnalyzeModule(module.get()).has_value()); -} - bool AlwaysValid(absl::Span) { return true; } TEST(GetGoodTilingsTest, ReturnsOneTilingWhenRankIsZero) { diff --git a/third_party/xla/xla/service/gpu/model/symbolic_tile_test.cc b/third_party/xla/xla/service/gpu/model/symbolic_tile_test.cc index 2d4785aa243d67..da64bf2d16c2c8 100644 --- a/third_party/xla/xla/service/gpu/model/symbolic_tile_test.cc +++ b/third_party/xla/xla/service/gpu/model/symbolic_tile_test.cc @@ -732,28 +732,6 @@ TEST_F(SymbolicTileTest, CanCombineCompatibleConstraints) { )"))); } -TEST_F(SymbolicTileTest, - DerivesUnsatisfiableConstraintWhenMergingOfConstraintsIsUnsupported) { - // This is kind of an artificial test case that we could easily support---we - // assume here that we can't merge two constraints that are the same. 
- // Nevertheless, there doesn't seem to be an obvious way to produce other - // constraints that would trigger this particular failure at the moment. This - // will change as we support more constraints, disjunctions, etc... - IndexingMap indexing_map( - ParseAffineMap("(d0) -> (d0 mod 6, d0 mod 6)", &mlir_context_), - /*dimensions=*/{DimVar{0, 10}}, /*range_vars=*/{}, /*rt_vars=*/{}); - - EXPECT_THAT(SymbolicTile::FromIndexingMap(std::move(indexing_map)), - Optional(MatchSymbolicTileString(R"( - Symbolic tile with - offset_map: ()[s0] -> (0, 0) - size_map: ()[s0] -> (s0 - ((s0 - 1) floordiv 6) * 6, s0 - ((s0 - 1) floordiv 6) * 6) - stride_map: ()[s0] -> (1, 1) - constraints: - unsatisfiable - )"))); -} - TEST_F(SymbolicTileTest, CanDeriveTileWhenPreexistingConstraintsCanBeSimplifiedAway) { // The example is from @@ -833,6 +811,22 @@ TEST_F(ConstraintExpressionTest, PrettyPrintingTest) { "d0 in [0, 5] && d1 in [0, 5] || d2 in [0, 5]")); } +TEST_F(ConstraintExpressionTest, + ConjunctionOfConstraintsOnTheSameExpressionAreIntersected) { + ConstraintExpression constraints; + + constraints.And(GetConjointConstraints({{"d0", Interval{0, 5}}})); + EXPECT_THAT(constraints, MatchConstraintExpressionString("d0 in [0, 5]")); + + // Constraints are intersected. + constraints.And(GetConjointConstraints({{"d0", Interval{3, 6}}})); + EXPECT_THAT(constraints, MatchConstraintExpressionString("d0 in [3, 5]")); + + // Empty intersection results in unsatisfiability. + constraints.And(GetConjointConstraints({{"d0", Interval{7, 8}}})); + EXPECT_THAT(constraints, MatchConstraintExpressionString("unsatisfiable")); +} + TEST_F(ConstraintExpressionTest, UnsatisfiableConstraintExpressionHoldsNoConstraint) { ConstraintExpression unsatisfiable_constraint = @@ -1069,9 +1063,6 @@ TEST_F( ConstraintExpression::And(constraints_1, constraints_2).is_satisfiable()); } -// TODO(b/334043867): add support for intersecting constraints within a single -// conjunction. 
- } // namespace } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/softmax_rewriter_triton_test.cc b/third_party/xla/xla/service/gpu/softmax_rewriter_triton_test.cc index 679e8416eccfd9..7f8f848f557a23 100644 --- a/third_party/xla/xla/service/gpu/softmax_rewriter_triton_test.cc +++ b/third_party/xla/xla/service/gpu/softmax_rewriter_triton_test.cc @@ -33,6 +33,7 @@ limitations under the License. #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/gpu_device_info_for_tests.h" #include "xla/service/gpu/model/gpu_hlo_cost_analysis.h" +#include "xla/service/instruction_fusion.h" #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" #include "xla/shape.h" @@ -251,51 +252,8 @@ ENTRY main { m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig))); } -TEST_P(SoftmaxRewriterTritonTest, - CanNotFuseSoftmaxWhenResultingComputationCanNotBeTiledCorrectly) { - PrimitiveType data_type = GetParam(); - const std::string hlo_string_template = R"( -HloModule softmax -max_computation { - arg_0 = $0[] parameter(0) - arg_1 = $0[] parameter(1) - ROOT maximum = $0[] maximum(arg_0, arg_1) -} -add_computation { - arg_0.1 = $0[] parameter(0) - arg_1.1 = $0[] parameter(1) - ROOT add = $0[] add(arg_0.1, arg_1.1) -} -ENTRY main { - param_0 = $0[130,125]{1,0} parameter(0) - constant_neg_inf = $0[] constant(-inf) - bitcasted_param_0 = $0[65,2,125] bitcast(param_0) - reduce = $0[65,2]{1,0} reduce(bitcasted_param_0, constant_neg_inf), dimensions={2}, to_apply=max_computation - bitcasted_reduce = $0[130] bitcast(reduce) - broadcast = $0[130,125]{1,0} broadcast(bitcasted_reduce), dimensions={0} - bitcasted_broadcast = $0[65,2,125] bitcast(broadcast) - subtract = $0[65,2,125]{2,1,0} subtract(bitcasted_param_0, bitcasted_broadcast) - bitcasted_subtract = $0[130,125] bitcast(subtract) - exponential = $0[130,125]{1,0} exponential(bitcasted_subtract) - constant_zero = $0[] constant(0) - 
bitcasted_exponential = $0[2,65,125] bitcast(exponential) - second_reduce = $0[2,65]{1,0} reduce(bitcasted_exponential, constant_zero), dimensions={2}, to_apply=add_computation - second_bitcasted_reduce = $0[130] bitcast(second_reduce) - second_broadcast = $0[130,125]{1,0} broadcast(second_bitcasted_reduce), dimensions={0} - second_bitcasted_broadcast = $0[2,65,125] bitcast(second_broadcast) - divide = $0[2,65,125]{2,1,0} divide(bitcasted_exponential, second_bitcasted_broadcast) - ROOT bitcasted_divide = $0[130,125] bitcast(divide) -} -)"; - const std::string hlo_string = - absl::Substitute(hlo_string_template, - primitive_util::LowercasePrimitiveTypeName(data_type)); - - auto module = ParseAndReturnVerifiedModule(hlo_string).value(); - - EXPECT_FALSE( - SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value()); -} +// TODO(b/334043867): is there still a meaningful written that can be written +// here for 'CanNotFuseSoftmaxWhenResultingComputationCanNotBeTiledCorrectly'? TEST_P(SoftmaxRewriterTritonTest, CanNotFuseSoftmaxDiamondWithWrongLayout) { PrimitiveType data_type = GetParam(); From 0e7c7e784d0f5a1611457afc96dde71a38c80ede Mon Sep 17 00:00:00 2001 From: Johannes Reifferscheid Date: Tue, 25 Jun 2024 00:22:08 -0700 Subject: [PATCH 204/256] Simplify more divisions. This reintroduces the simplification I recently removed, but does it properly this time. Previously, we rewrote: (d0 * 128 + d1) floordiv 192 to (d0 * 2 + d1 floordiv 64) floordiv 3. Now, we do it the other way around, which is the correct way, since it reduces the number of divisons. 
PiperOrigin-RevId: 646361465 --- .../xla/xla/service/gpu/model/indexing_map.cc | 47 ++++++++++++++----- .../service/gpu/model/indexing_map_test.cc | 16 +++++++ 2 files changed, 51 insertions(+), 12 deletions(-) diff --git a/third_party/xla/xla/service/gpu/model/indexing_map.cc b/third_party/xla/xla/service/gpu/model/indexing_map.cc index bcfcf494a5a5b9..5e74a5da529bba 100644 --- a/third_party/xla/xla/service/gpu/model/indexing_map.cc +++ b/third_party/xla/xla/service/gpu/model/indexing_map.cc @@ -271,6 +271,8 @@ AffineExpr AffineExprSimplifier::SimplifySumDiv(AffineExpr dividend, // The gcd of all multipliers and the divisor. int64_t multiplier_divisor_gcd = divisor; Interval no_multiplier_range{0, 0}; + std::optional min_inner_divisor = std::nullopt; + std::optional inner_divisor_gcd = std::nullopt; VisitSummands(new_dividend, [&](AffineExpr summand) { if (auto multiplier = GetConstantRhs(summand, AffineExprKind::Mul)) { multiplier_divisor_gcd = std::gcd(multiplier_divisor_gcd, *multiplier); @@ -278,6 +280,14 @@ AffineExpr AffineExprSimplifier::SimplifySumDiv(AffineExpr dividend, no_multiplier_range = no_multiplier_range + range_evaluator_->ComputeExpressionRange(summand); } + + if (auto inner_divisor = + GetConstantRhs(summand, AffineExprKind::FloorDiv)) { + min_inner_divisor = + std::min(min_inner_divisor.value_or(*inner_divisor), *inner_divisor); + inner_divisor_gcd = + std::gcd(inner_divisor_gcd.value_or(*inner_divisor), *inner_divisor); + } }); // Consider an expression like: `(x * 6 + y) / 9`. if the range of `y` is at @@ -296,6 +306,24 @@ AffineExpr AffineExprSimplifier::SimplifySumDiv(AffineExpr dividend, divisor /= multiplier_divisor_gcd; } + // If we have an inner divisor whose value is equal to the GCD of all the + // divisors, we can remove a division: + // `(a0 / c + a1 / cd + ...) / e` -> `(a0 + a1 / d + (...) * c) / ce` + // This potentially increases the number of multiplications, but it's + // generally a win. 
It also matches what the MLIR simplifier does better, so + // we can get more simplifications. + if (min_inner_divisor && *min_inner_divisor > 0 && + min_inner_divisor == inner_divisor_gcd) { + new_dividend = MapSummands(new_dividend, [&](AffineExpr summand) { + if (auto inner_divisor = + GetConstantRhs(summand, AffineExprKind::FloorDiv)) { + return GetLhs(summand).floorDiv(*inner_divisor / *inner_divisor_gcd); + } + return summand * *inner_divisor_gcd; + }); + divisor *= *inner_divisor_gcd; + } + return new_dividend.floorDiv(divisor) + extracted; } @@ -481,18 +509,13 @@ AffineExpr AffineExprSimplifier::SimplifyOnce(AffineExpr expr) { if (!div) continue; // Already erased. if ((div_mul % mod_mul) || (div_mul / mod_mul) != mod_c) continue; - auto mod_lhs = GetLhs(mod); - if (GetConstantRhs(mod_lhs, AffineExprKind::FloorDiv)) { - // If x is a floorDiv itself, we need to check a bit more carefully: - // ((x // c0) % c1) * d + (x // (c0 * c1)) * (c1 * d)` - // `x // (c0 * c1)` will be simplified, so we we may not even have - // `c0 * c1` in the expression, if `x` contains a multiplier. - if (Simplify(mod_lhs.floorDiv(*mod_c)) != Simplify(div)) continue; - } else { - if (mod_lhs != GetLhs(div)) continue; - auto div_c = GetConstantRhs(div, AffineExprKind::FloorDiv); - if (mod_c != div_c) continue; - } + // In many cases, we could just compare the LHSes of the mod and the + // div, but if x is a floorDiv itself, we need to check a bit more + // carefully: + // ((x // c0) % c1) * d + (x // (c0 * c1)) * (c1 * d)` + // `x // (c0 * c1)` will be simplified, so we we may not even have + // `c0 * c1` in the expression, if `x` contains a multiplier. 
+ if (Simplify(GetLhs(mod).floorDiv(*mod_c)) != Simplify(div)) continue; others.push_back(GetLhs(mod) * mod_mul); divs[div_i].first = nullptr; diff --git a/third_party/xla/xla/service/gpu/model/indexing_map_test.cc b/third_party/xla/xla/service/gpu/model/indexing_map_test.cc index 24bbf63fd385d4..1e326a532de107 100644 --- a/third_party/xla/xla/service/gpu/model/indexing_map_test.cc +++ b/third_party/xla/xla/service/gpu/model/indexing_map_test.cc @@ -715,6 +715,22 @@ TEST_F(IndexingMapTest, AffineMapSimplification_DivsInSequence) { )")); } +TEST_F(IndexingMapTest, AffineMapSimplification_DivGcdGreater1) { + auto serialized_map = + "()[s0, s1, s2] -> (s0 * 512 + s1 * 4 + s2 - ((s0 * 2 + s1 floordiv 64) " + "floordiv 3) * 768 + ((s0 * 128 + s1) floordiv 192) * 768)"; + IndexingMap indexing_map = IndexingMap::FromTensorSizes( + ParseAffineMap(serialized_map, &mlir_context_), {}, {1234, 128, 4}); + EXPECT_TRUE(indexing_map.Simplify()); + EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( + ()[s0, s1, s2] -> (s0 * 512 + s1 * 4 + s2) + domain: + s0 in [0, 1233] + s1 in [0, 127] + s2 in [0, 3] + )")); +} + TEST_F(IndexingMapTest, AffineMapSimplification_NegativeDiv) { // (s0 floordiv 2) floordiv -7 is not s0 floordiv -14: // 15 // 2 // -7 = -1 From b2a67e829fc56df9bb9ff0e3aa48fc769e24e762 Mon Sep 17 00:00:00 2001 From: lingzhi98 <103185827+lingzhi98@users.noreply.github.com> Date: Tue, 25 Jun 2024 00:42:29 -0700 Subject: [PATCH 205/256] PR #13958: Prefer vector size as 4 for u8/s8 transpose Imported from GitHub PR https://github.com/openxla/xla/pull/13958 We observe 4 bytes as per vectorized load can get better performance for transpose, which means 4xi8 is the better choice compared with previous 2xi8. 
``` %fused_computation { %p0 = u8[8,4096,4096] parameter(0) %transpose = u8[8,4096,4096] transpose(%p0), dimensions={0,2,1} } ENTRY main { %param = u8[8,4096,4096] parameter(0) ROOT %fusion = u8[8,4096,4096] fusion(%param), kind=kInput, calls=%fused_computation } improve from 270us to 214us, collected on A100 40GB machine. ``` Copybara import of the project: -- 96d19dc0b9e15de7c77bd223fa4c3c83e4e35285 by Zhou, Lingzhi : prefer vector size as 4 for u8/s8 transpose -- 425f0e9239d236956e62c78234142796c6142bbe by Zhou, Lingzhi : remove result comparation for vectorized transpose -- 8283d3c69dc9517164f6b4cc6a0787d80ed006e2 by Zhou, Lingzhi : refine -- 13dc604351bff933fd44e1c195a930c7818886e2 by Zhou, Lingzhi : change variable name Merging this change closes #13958 PiperOrigin-RevId: 646367040 --- .../xla/service/gpu/fusions/transpose_mlir.cc | 22 ++++++++---- .../gpu/fusions/transpose_mlir_test.cc | 36 +++++++++++++++++-- 2 files changed, 49 insertions(+), 9 deletions(-) diff --git a/third_party/xla/xla/service/gpu/fusions/transpose_mlir.cc b/third_party/xla/xla/service/gpu/fusions/transpose_mlir.cc index 0f5cee3ca64a1a..8459134499f7c4 100644 --- a/third_party/xla/xla/service/gpu/fusions/transpose_mlir.cc +++ b/third_party/xla/xla/service/gpu/fusions/transpose_mlir.cc @@ -79,6 +79,7 @@ using mlir_converter::ApplyIndexing; constexpr int kNumRows = 4; constexpr int kBaseBlockSize = WarpSize(); constexpr int kNumThreadsPerBlock = 128; +constexpr int kMaxVectorizedBytes = 4; } // namespace @@ -126,13 +127,20 @@ MlirTransposeFusion::MlirTransposeFusion(const HloFusionAnalysis& analysis) // the input dimensions are divisible by the vector size. Vectorizing loads // for large data types does not help (there's already enough parallelism). 
const auto& device = analysis_.device_info(); - bool enough_work = Product(block_counts_) * kNumThreadsPerBlock >= - 4 * device.core_count() * device.threads_per_core_limit(); - bool enough_shmem = shmem_usage * 4 <= device.shared_memory_per_block(); - bool aligned_dims = - (input_shape_[2] % 2 == 0) && (input_shape_[permutation_[2]] % 2 == 0); - if (max_element_bytes < 4 && enough_work && enough_shmem && aligned_dims) { - compute_block_sizes(2); + for (int vec_size = kMaxVectorizedBytes / max_element_bytes; vec_size > 1; + vec_size /= 2) { + int elems_per_thread = vec_size * vec_size; + bool enough_work = Product(block_counts_) * kNumThreadsPerBlock >= + elems_per_thread * device.core_count() * + device.threads_per_core_limit(); + bool enough_shmem = + shmem_usage * elems_per_thread <= device.shared_memory_per_block(); + bool aligned_dims = (input_shape_[2] % vec_size == 0) && + (input_shape_[permutation_[2]] % vec_size == 0); + if (enough_work && enough_shmem && aligned_dims) { + compute_block_sizes(vec_size); + break; + } } } diff --git a/third_party/xla/xla/service/gpu/fusions/transpose_mlir_test.cc b/third_party/xla/xla/service/gpu/fusions/transpose_mlir_test.cc index fef26c162b7282..a492cb0af146c5 100644 --- a/third_party/xla/xla/service/gpu/fusions/transpose_mlir_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/transpose_mlir_test.cc @@ -666,7 +666,6 @@ TEST_F(MlirTransposeFusionTest, VectorizedTranspose021) { )"; TF_EXPECT_OK(EmitAndCheckIR( kHloString, "// CHECK: xla_gpu.allocate_shared : tensor<1x64x65xbf16>")); - EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); } TEST_F(MlirTransposeFusionTest, VectorizedTranspose210) { @@ -684,7 +683,40 @@ TEST_F(MlirTransposeFusionTest, VectorizedTranspose210) { )"; TF_EXPECT_OK(EmitAndCheckIR( kHloString, "// CHECK: xla_gpu.allocate_shared : tensor<64x1x65xbf16>")); - EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); +} + +TEST_F(MlirTransposeFusionTest, 
PreferLargeVectorSize021) { + auto kHloString = R"( + HloModule Transpose + %fused_computation { + %p0 = u8[256,256,256] parameter(0) + %transpose = u8[256,256,256] transpose(%p0), dimensions={0,2,1} + } + ENTRY main { + %param = u8[256,256,256] parameter(0) + ROOT %fusion = u8[256,256,256] fusion(%param), kind=kInput, + calls=%fused_computation + } + )"; + TF_EXPECT_OK(EmitAndCheckIR( + kHloString, "// CHECK: xla_gpu.allocate_shared : tensor<1x128x129xi8>")); +} + +TEST_F(MlirTransposeFusionTest, PreferLargeVectorSize210) { + auto kHloString = R"( + HloModule Transpose + %fused_computation { + %p0 = u8[256,256,256] parameter(0) + %transpose = u8[256,256,256] transpose(%p0), dimensions={2,1,0} + } + ENTRY main { + %param = u8[256,256,256] parameter(0) + ROOT %fusion = u8[256,256,256] fusion(%param), kind=kInput, + calls=%fused_computation + } + )"; + TF_EXPECT_OK(EmitAndCheckIR( + kHloString, "// CHECK: xla_gpu.allocate_shared : tensor<128x1x129xi8>")); } } // namespace From fd7f0851a08ce3b9d8ea50ad1ad47676af631863 Mon Sep 17 00:00:00 2001 From: Henning Becker Date: Tue, 25 Jun 2024 00:55:24 -0700 Subject: [PATCH 206/256] Add config-cuda-only tag to pjrt_c_api_gpu_test Internally the PJRT client needs to be explicitly compiled with CUDA/GPU support and all its tests do as well, otherwise they would be testing a PJRT client without GPU support Adding the tag `config-cuda-only` ensures that the test is being built with CUDA support enabled in presubmits. 
PiperOrigin-RevId: 646370443 --- third_party/xla/xla/pjrt/c/BUILD | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/third_party/xla/xla/pjrt/c/BUILD b/third_party/xla/xla/pjrt/c/BUILD index a677da30a3361c..c254e9e8205cb1 100644 --- a/third_party/xla/xla/pjrt/c/BUILD +++ b/third_party/xla/xla/pjrt/c/BUILD @@ -13,6 +13,10 @@ load( "xla_cc_test", ) load("//xla/tests:build_defs.bzl", "xla_test") +load( + "//xla/tsl:tsl.bzl", + "if_google", +) # copybara:uncomment package(default_applicable_licenses = ["//tensorflow:license"]) @@ -350,6 +354,9 @@ xla_test( name = "pjrt_c_api_gpu_test", srcs = ["pjrt_c_api_gpu_test.cc"], backends = ["gpu"], + tags = if_google([ + "config-cuda-only", + ]), deps = [ ":pjrt_c_api_ffi_extension_hdrs", ":pjrt_c_api_gpu", From 00951cf4df42c3567c451a1d3956de27142648e5 Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Tue, 25 Jun 2024 01:33:27 -0700 Subject: [PATCH 207/256] [XLA:GPU] Add missing disjoint constraints in `ExtractSizeAndStrideFromMod`. This leverages disjunctions to relax constraints on `bitcast`s and `reshape`s. 
PiperOrigin-RevId: 646381273 --- .../xla/service/gpu/model/symbolic_tile.cc | 18 +++++--- .../gpu/model/symbolic_tile_analysis_test.cc | 45 ++++++++++--------- .../service/gpu/model/symbolic_tile_test.cc | 15 +++---- 3 files changed, 40 insertions(+), 38 deletions(-) diff --git a/third_party/xla/xla/service/gpu/model/symbolic_tile.cc b/third_party/xla/xla/service/gpu/model/symbolic_tile.cc index 91fffa1f39cd8a..b2fbaf02b92a89 100644 --- a/third_party/xla/xla/service/gpu/model/symbolic_tile.cc +++ b/third_party/xla/xla/service/gpu/model/symbolic_tile.cc @@ -150,14 +150,18 @@ std::optional ExtractSizeAndStrideFromMod( dim_expr - 1, modulus) * modulus; - AffineExpr constrained_expr = - getAffineSymbolExpr(dim_expr.getPosition(), lhs.getContext()) % modulus; + AffineExpr tile_size_expr = + getAffineSymbolExpr(dim_expr.getPosition(), lhs.getContext()); + Interval zero_interval{/*lower=*/0, /*upper=*/0}; + // TODO(b/326998704): the below also becomes more complicated if stride is + // not unit. + // + // tile_size % modulus == 0 || modulus % tile_size == 0 ConstraintExpression constraints; - // TODO(b/334043867): we only add a constraint for n being a multiple of c - // while we do not support disjunctions. - ConstraintExpression::ConjointConstraints conjunction; - conjunction.insert({constrained_expr, Interval{/*lower=*/0, /*upper=*/0}}); - constraints.And(std::move(conjunction)); + constraints.And( + /*conjunction=*/{{tile_size_expr % modulus, zero_interval}}); + constraints.Or( + /*conjunction=*/{{modulus % tile_size_expr, zero_interval}}); // In this case, stride is effectively 1 mod modulus = 1. 
return SizeAndStrideExpression( diff --git a/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis_test.cc b/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis_test.cc index 97ddfb0036a8ac..31adf8b1780982 100644 --- a/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis_test.cc +++ b/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis_test.cc @@ -372,13 +372,10 @@ ENTRY main { std::optional analysis = TryAnalyzeModule(module.get()); ASSERT_TRUE(analysis.has_value()); const ConstraintExpression& constraints = analysis->GetConstraints(); - EXPECT_THAT(constraints.DisjointConjointConstraints(), SizeIs(1)); + EXPECT_THAT(constraints.DisjointConjointConstraints(), SizeIs(2)); EXPECT_THAT(constraints.DisjointConjointConstraints().front(), SizeIs(1)); } -// TODO(b/334043867): add disjunction tests here once disjunctions are actually -// used in `SymbolicTile`s. - TEST_F(SymbolicTileAnalysisTest, DoesNotBailOutOnConstrainedBitcast) { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(R"( @@ -394,7 +391,7 @@ ENTRY main { std::optional analysis = TryAnalyzeModule(module.get()); ASSERT_TRUE(analysis.has_value()); const ConstraintExpression& constraints = analysis->GetConstraints(); - EXPECT_THAT(constraints.DisjointConjointConstraints(), SizeIs(1)); + EXPECT_THAT(constraints.DisjointConjointConstraints(), SizeIs(2)); EXPECT_THAT(constraints.DisjointConjointConstraints().front(), SizeIs(1)); } @@ -449,18 +446,22 @@ ENTRY main { std::optional analysis = TryAnalyzeModule(module.get()); ASSERT_TRUE(analysis.has_value()); const ConstraintExpression& constraints = analysis->GetConstraints(); - EXPECT_THAT(constraints.DisjointConjointConstraints(), SizeIs(1)); - EXPECT_THAT(constraints.DisjointConjointConstraints().front(), SizeIs(2)); + EXPECT_THAT(constraints.DisjointConjointConstraints(), SizeIs(4)); + for (const ConstraintExpression::ConjointConstraints& conjunction : + constraints.DisjointConjointConstraints()) + 
EXPECT_THAT(conjunction, SizeIs(2)); // We expect the constraints here to be - // s0 mod 6 in [0, 0] - // s1 mod 8 in [0, 0] - // We expect tile sizes {6, 8} to satisfy these constraints. + // 6 mod s0 in [0, 0] && 8 mod s1 in [0, 0] || + // 6 mod s0 in [0, 0] && s1 mod 8 in [0, 0] || + // 8 mod s1 in [0, 0] && s0 mod 6 in [0, 0] || + // s0 mod 6 in [0, 0] && s1 mod 8 in [0, 0] + // Tile sizes {6, 8} satisfy these constraints. std::vector possible_tile_parameters({6, 8}); EXPECT_THAT(analysis->ParametersSatisfyConstraints(possible_tile_parameters), IsOkAndHolds(true)); - // However, we do not expect tile sizes {6, 7} to satisfy these constraints. + // However, tile sizes {6, 7} do not satisfy these constraints. std::vector impossible_tile_parameters({6, 7}); EXPECT_THAT( analysis->ParametersSatisfyConstraints(impossible_tile_parameters), @@ -504,10 +505,10 @@ ENTRY main { })")); std::optional analysis = TryAnalyzeModule(module.get()); ASSERT_TRUE(analysis.has_value()); - // Each bitcast in the above module introduces one constraint. Once they are - // aggregated, we have two! + // Each bitcast in the above module introduces one disjoint constraint. Once + // they are aggregated, we have four disjoint constraints! 
const ConstraintExpression& constraints = analysis->GetConstraints(); - EXPECT_THAT(constraints.DisjointConjointConstraints(), SizeIs(1)); + EXPECT_THAT(constraints.DisjointConjointConstraints(), SizeIs(4)); EXPECT_THAT(constraints.DisjointConjointConstraints().front(), SizeIs(2)); } @@ -587,13 +588,13 @@ TEST_F(SymbolicTileAnalysisTest, TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(R"( fusion { - p0 = f32[1,8,6,4]{3,2,1,0} parameter(0) - ROOT bitcast = f32[48,4]{1,0} bitcast(p0) + p0 = f32[1,8,6,1]{3,2,1,0} parameter(0) + ROOT bitcast = f32[48,1]{1,0} bitcast(p0) } ENTRY main { - p0 = f32[1,8,6,4]{3,2,1,0} parameter(0) - ROOT fusion = f32[48,4]{1,0} fusion(p0), kind=kLoop, calls=fusion + p0 = f32[1,8,6,1]{3,2,1,0} parameter(0) + ROOT fusion = f32[48,1]{1,0} fusion(p0), kind=kLoop, calls=fusion })")); std::optional opt_analysis = @@ -604,11 +605,13 @@ ENTRY main { TF_ASSERT_OK_AND_ASSIGN( std::vector good_tilings, analysis.GetGoodTilings()); - // The constraint on the 1st dimension is "s0 mod 6 in [0, 0]", and only 48 - // fulfills that from the set of possible tile sizes (1, 2, 4, 8, 16, 32, 48). + // The constraint on the 1st dimension is + // 6 mod s0 in [0, 0] || s0 mod 6 in [0, 0], + // and only 48, 1, and 2 fulfill it from the set of possible tile sizes + // (1, 2, 4, 8, 16, 32, 48). // There is no constraint on the 2nd dimension. EXPECT_EQ(good_tilings, std::vector( - {{48, 1}, {48, 2}, {48, 4}})); + {{1, 1}, {2, 1}, {48, 1}})); } // Logs the tilings if VLOG level 1 is enabled. 
diff --git a/third_party/xla/xla/service/gpu/model/symbolic_tile_test.cc b/third_party/xla/xla/service/gpu/model/symbolic_tile_test.cc index da64bf2d16c2c8..9f776f4ccca12d 100644 --- a/third_party/xla/xla/service/gpu/model/symbolic_tile_test.cc +++ b/third_party/xla/xla/service/gpu/model/symbolic_tile_test.cc @@ -141,10 +141,6 @@ TEST_F(SymbolicTileTest, // TODO(bchetioui): support expanding one dimension to more than two // dimensions and constrain accordingly. - // TODO(b/334043867): add disjunctions in order to relax some of these - // constraints. Currently we only support the reshaped tile size to be a - // multiple of the smaller collapsed axes---we also need to support the case - // where the tile size is a divisor of the collapsed axis. EXPECT_THAT( SymbolicTile::FromIndexingMap(*input_indexing.indexing_maps[0].begin()), Optional(MatchSymbolicTileString(R"( @@ -153,7 +149,7 @@ TEST_F(SymbolicTileTest, size_map: ()[s0, s1] -> (1, (s0 + 5) floordiv 6, s0 - ((s0 - 1) floordiv 6) * 6, s1) stride_map: ()[s0, s1] -> (0, 1, 1, 1) constraints: - s0 mod 6 in [0, 0] + 6 mod s0 in [0, 0] || s0 mod 6 in [0, 0] )"))); } @@ -716,9 +712,6 @@ TEST_F(SymbolicTileTest, CanCombineCompatibleConstraints) { } )")); - // TODO(b/334043867): add disjunctions in order to relax some of these - // constraints. Currently we only support the reshaped axis to be a multiple - // of the smaller collapsed axes. 
EXPECT_THAT( SymbolicTile::FromIndexingMap(*input_indexing.indexing_maps[0].begin()), Optional(MatchSymbolicTileString(R"( @@ -727,8 +720,10 @@ TEST_F(SymbolicTileTest, CanCombineCompatibleConstraints) { size_map: ()[s0, s1] -> (1, (s0 + 5) floordiv 6, s0 - ((s0 - 1) floordiv 6) * 6, (s1 + 7) floordiv 8, s1 - ((s1 - 1) floordiv 8) * 8) stride_map: ()[s0, s1] -> (0, 1, 1, 1, 1) constraints: - s0 mod 6 in [0, 0] && - s1 mod 8 in [0, 0] + 6 mod s0 in [0, 0] && 8 mod s1 in [0, 0] || + 6 mod s0 in [0, 0] && s1 mod 8 in [0, 0] || + 8 mod s1 in [0, 0] && s0 mod 6 in [0, 0] || + s0 mod 6 in [0, 0] && s1 mod 8 in [0, 0] )"))); } From a0da03263492463221066ee5b3376787e2188465 Mon Sep 17 00:00:00 2001 From: Johannes Reifferscheid Date: Tue, 25 Jun 2024 01:34:05 -0700 Subject: [PATCH 208/256] This is not correct in general. (((s1 + s1) floordiv 3) + (s0 floordiv 3)) floordiv 6) is simplified to (s1 * 2 + s0) floordiv 18 But this is not equivalent: the results for [2, 8] are 1 and 0, respectively. Reverts 0e7c7e784d0f5a1611457afc96dde71a38c80ede PiperOrigin-RevId: 646381435 --- .../xla/xla/service/gpu/model/indexing_map.cc | 47 +++++-------------- .../service/gpu/model/indexing_map_test.cc | 16 ------- 2 files changed, 12 insertions(+), 51 deletions(-) diff --git a/third_party/xla/xla/service/gpu/model/indexing_map.cc b/third_party/xla/xla/service/gpu/model/indexing_map.cc index 5e74a5da529bba..bcfcf494a5a5b9 100644 --- a/third_party/xla/xla/service/gpu/model/indexing_map.cc +++ b/third_party/xla/xla/service/gpu/model/indexing_map.cc @@ -271,8 +271,6 @@ AffineExpr AffineExprSimplifier::SimplifySumDiv(AffineExpr dividend, // The gcd of all multipliers and the divisor. 
int64_t multiplier_divisor_gcd = divisor; Interval no_multiplier_range{0, 0}; - std::optional min_inner_divisor = std::nullopt; - std::optional inner_divisor_gcd = std::nullopt; VisitSummands(new_dividend, [&](AffineExpr summand) { if (auto multiplier = GetConstantRhs(summand, AffineExprKind::Mul)) { multiplier_divisor_gcd = std::gcd(multiplier_divisor_gcd, *multiplier); @@ -280,14 +278,6 @@ AffineExpr AffineExprSimplifier::SimplifySumDiv(AffineExpr dividend, no_multiplier_range = no_multiplier_range + range_evaluator_->ComputeExpressionRange(summand); } - - if (auto inner_divisor = - GetConstantRhs(summand, AffineExprKind::FloorDiv)) { - min_inner_divisor = - std::min(min_inner_divisor.value_or(*inner_divisor), *inner_divisor); - inner_divisor_gcd = - std::gcd(inner_divisor_gcd.value_or(*inner_divisor), *inner_divisor); - } }); // Consider an expression like: `(x * 6 + y) / 9`. if the range of `y` is at @@ -306,24 +296,6 @@ AffineExpr AffineExprSimplifier::SimplifySumDiv(AffineExpr dividend, divisor /= multiplier_divisor_gcd; } - // If we have an inner divisor whose value is equal to the GCD of all the - // divisors, we can remove a division: - // `(a0 / c + a1 / cd + ...) / e` -> `(a0 + a1 / d + (...) * c) / ce` - // This potentially increases the number of multiplications, but it's - // generally a win. It also matches what the MLIR simplifier does better, so - // we can get more simplifications. 
- if (min_inner_divisor && *min_inner_divisor > 0 && - min_inner_divisor == inner_divisor_gcd) { - new_dividend = MapSummands(new_dividend, [&](AffineExpr summand) { - if (auto inner_divisor = - GetConstantRhs(summand, AffineExprKind::FloorDiv)) { - return GetLhs(summand).floorDiv(*inner_divisor / *inner_divisor_gcd); - } - return summand * *inner_divisor_gcd; - }); - divisor *= *inner_divisor_gcd; - } - return new_dividend.floorDiv(divisor) + extracted; } @@ -509,13 +481,18 @@ AffineExpr AffineExprSimplifier::SimplifyOnce(AffineExpr expr) { if (!div) continue; // Already erased. if ((div_mul % mod_mul) || (div_mul / mod_mul) != mod_c) continue; - // In many cases, we could just compare the LHSes of the mod and the - // div, but if x is a floorDiv itself, we need to check a bit more - // carefully: - // ((x // c0) % c1) * d + (x // (c0 * c1)) * (c1 * d)` - // `x // (c0 * c1)` will be simplified, so we we may not even have - // `c0 * c1` in the expression, if `x` contains a multiplier. - if (Simplify(GetLhs(mod).floorDiv(*mod_c)) != Simplify(div)) continue; + auto mod_lhs = GetLhs(mod); + if (GetConstantRhs(mod_lhs, AffineExprKind::FloorDiv)) { + // If x is a floorDiv itself, we need to check a bit more carefully: + // ((x // c0) % c1) * d + (x // (c0 * c1)) * (c1 * d)` + // `x // (c0 * c1)` will be simplified, so we we may not even have + // `c0 * c1` in the expression, if `x` contains a multiplier. 
+ if (Simplify(mod_lhs.floorDiv(*mod_c)) != Simplify(div)) continue; + } else { + if (mod_lhs != GetLhs(div)) continue; + auto div_c = GetConstantRhs(div, AffineExprKind::FloorDiv); + if (mod_c != div_c) continue; + } others.push_back(GetLhs(mod) * mod_mul); divs[div_i].first = nullptr; diff --git a/third_party/xla/xla/service/gpu/model/indexing_map_test.cc b/third_party/xla/xla/service/gpu/model/indexing_map_test.cc index 1e326a532de107..24bbf63fd385d4 100644 --- a/third_party/xla/xla/service/gpu/model/indexing_map_test.cc +++ b/third_party/xla/xla/service/gpu/model/indexing_map_test.cc @@ -715,22 +715,6 @@ TEST_F(IndexingMapTest, AffineMapSimplification_DivsInSequence) { )")); } -TEST_F(IndexingMapTest, AffineMapSimplification_DivGcdGreater1) { - auto serialized_map = - "()[s0, s1, s2] -> (s0 * 512 + s1 * 4 + s2 - ((s0 * 2 + s1 floordiv 64) " - "floordiv 3) * 768 + ((s0 * 128 + s1) floordiv 192) * 768)"; - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap(serialized_map, &mlir_context_), {}, {1234, 128, 4}); - EXPECT_TRUE(indexing_map.Simplify()); - EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( - ()[s0, s1, s2] -> (s0 * 512 + s1 * 4 + s2) - domain: - s0 in [0, 1233] - s1 in [0, 127] - s2 in [0, 3] - )")); -} - TEST_F(IndexingMapTest, AffineMapSimplification_NegativeDiv) { // (s0 floordiv 2) floordiv -7 is not s0 floordiv -14: // 15 // 2 // -7 = -1 From bf438c36804cc75f65848b93ed9fd9f84f37838e Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Tue, 25 Jun 2024 02:03:29 -0700 Subject: [PATCH 209/256] compat: Update forward compatibility horizon to 2024-06-25 PiperOrigin-RevId: 646389779 --- tensorflow/python/compat/compat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py index beb6522ab43b7e..7f327f40856558 100644 --- a/tensorflow/python/compat/compat.py +++ b/tensorflow/python/compat/compat.py @@ -29,7 +29,7 @@ # This value changes every day with an automatic CL. It can be modified in code # via `forward_compatibility_horizon()` or with the environment variable # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date. -_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2024, 6, 24) +_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2024, 6, 25) _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS" _FORWARD_COMPATIBILITY_DATE_NUMBER = None From e9002503620af1fa174ebbced0fe78d88d141afa Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 25 Jun 2024 02:03:30 -0700 Subject: [PATCH 210/256] Update GraphDef version to 1904. PiperOrigin-RevId: 646389781 --- tensorflow/core/public/version.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h index 56c889c5865f6e..0979c04fe2b229 100644 --- a/tensorflow/core/public/version.h +++ b/tensorflow/core/public/version.h @@ -108,7 +108,7 @@ limitations under the License. #define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0 -#define TF_GRAPH_DEF_VERSION 1903 // Updated: 2024/6/24 +#define TF_GRAPH_DEF_VERSION 1904 // Updated: 2024/6/25 // Checkpoint compatibility versions (the versions field in SavedSliceMeta). 
// From 3175327dd14654f6a17ffc9c2f4181306e7e0316 Mon Sep 17 00:00:00 2001 From: Jaroslav Sevcik Date: Tue, 25 Jun 2024 02:38:46 -0700 Subject: [PATCH 211/256] PR #14088: Create executable buffer at the appropriate memory space on GPUs Imported from GitHub PR https://github.com/openxla/xla/pull/14088 Copybara import of the project: -- fa7d6071f64c3c614ccf78e49c359078a0c49604 by Jaroslav Sevcik : Create executable buffer at the appropriate memory space Merging this change closes #14088 PiperOrigin-RevId: 646398973 --- third_party/xla/xla/pjrt/BUILD | 1 + .../xla/pjrt/gpu/se_gpu_pjrt_client_test.cc | 96 +++++++++++++++++++ .../xla/pjrt/pjrt_stream_executor_client.cc | 14 ++- 3 files changed, 110 insertions(+), 1 deletion(-) diff --git a/third_party/xla/xla/pjrt/BUILD b/third_party/xla/xla/pjrt/BUILD index aa7dfdbdc39a11..bb576fa04bc541 100644 --- a/third_party/xla/xla/pjrt/BUILD +++ b/third_party/xla/xla/pjrt/BUILD @@ -477,6 +477,7 @@ cc_library( "//xla/client:local_client", "//xla/client:xla_computation", "//xla/hlo/ir:hlo", + "//xla/pjrt:host_memory_spaces", "//xla/pjrt/distributed:protocol_proto_cc", "//xla/service:compiler", "//xla/service:computation_layout", diff --git a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc index 40a0c1115074f9..a8476013edfe20 100644 --- a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc +++ b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc @@ -22,6 +22,7 @@ limitations under the License. 
#include #include #include +#include #include #include #include @@ -780,5 +781,100 @@ TEST(StreamExecutorGpuClientTest, MockNcclClientTest) { } } +namespace { + +absl::StatusOr> CreateDeviceBufferForTest( + xla::PjRtClient* client) { + auto device = client->addressable_devices()[0]; + TF_EXPECT_OK(device->default_memory_space()); + + std::vector data{1, 2, 3, 4}; + Shape shape = ShapeUtil::MakeShapeWithDenseLayout(S32, {4}, {0}); + TF_ASSIGN_OR_RETURN( + auto input, client->BufferFromHostBuffer( + data.data(), shape.element_type(), shape.dimensions(), + /*byte_strides=*/std::nullopt, + PjRtClient::HostBufferSemantics::kImmutableOnlyDuringCall, + /*on_done_with_host_buffer=*/nullptr, device)); + EXPECT_EQ(input->memory_space()->kind(), "device"); + return input; +} + +} // namespace + +TEST(StreamExecutorGpuClientTest, ExecutePinnedHostOutputTest) { + TF_ASSERT_OK_AND_ASSIGN(auto client, + GetStreamExecutorGpuClient(GpuClientOptions())); + TF_ASSERT_OK_AND_ASSIGN(auto input, CreateDeviceBufferForTest(client.get())); + + static constexpr char const* kD2HProgram = R"( + HloModule f + + ENTRY main.5 { + p = s32[4]{0} parameter(0) + ROOT cc = s32[4] custom-call(p), + custom_call_target="annotate_device_placement", + frontend_attributes={_xla_buffer_placement="pinned_host"} + } + )"; + + TF_ASSERT_OK_AND_ASSIGN(auto executable, + CompileExecutable(kD2HProgram, *client)); + TF_ASSERT_OK_AND_ASSIGN( + auto result, executable->Execute({{input.get()}}, ExecuteOptions())); + + std::vector>& result_buffers = result[0]; + EXPECT_EQ(result_buffers[0]->memory_space()->kind(), "pinned_host"); + + TF_ASSERT_OK_AND_ASSIGN(auto memory_stats, + executable->GetCompiledMemoryStats()); + EXPECT_EQ(memory_stats.output_size_in_bytes, 0); + EXPECT_EQ(memory_stats.host_output_size_in_bytes, 16); +} + +TEST(StreamExecutorGpuClientTest, ExecutePinnedHostOutputTupleTest) { + TF_ASSERT_OK_AND_ASSIGN(auto client, + GetStreamExecutorGpuClient(GpuClientOptions())); + TF_ASSERT_OK_AND_ASSIGN(auto 
input, CreateDeviceBufferForTest(client.get())); + + static constexpr char const* kD2HProgram = R"( + HloModule f + + ENTRY main.5 { + p = s32[4]{0} parameter(0) + cc = s32[4] custom-call(p), + custom_call_target="annotate_device_placement", + frontend_attributes={_xla_buffer_placement="pinned_host"} + ROOT tuple = (s32[4]{0}, s32[4]{0}) tuple(s32[4]{0} p, s32[4]{0} cc) + } + )"; + + // Build the output shape with the correct memory space set. + Shape host_shape = input->on_device_shape(); + host_shape.mutable_layout()->set_memory_space(Layout::kHostMemorySpace); + Shape out_shape = + ShapeUtil::MakeTupleShape({input->on_device_shape(), host_shape}); + + // Set the result layout so that the compiler assertions on memory + // spaces pass. + xla::CompileOptions options; + options.executable_build_options.set_result_layout(out_shape); + + TF_ASSERT_OK_AND_ASSIGN(auto executable, + CompileExecutable(kD2HProgram, *client, options)); + + // Untuple the result so that we get separate buffers. + // This is how JAX invokes XLA. + ExecuteOptions execute_options; + execute_options.untuple_result = true; + TF_ASSERT_OK_AND_ASSIGN( + auto result, executable->Execute({{input.get()}}, execute_options)); + + std::vector>& result_buffers = result[0]; + EXPECT_EQ(result_buffers.size(), 2); + EXPECT_EQ(result_buffers[0]->memory_space()->kind(), "device"); + EXPECT_EQ(result_buffers[1]->memory_space()->kind(), "pinned_host"); +} + } // namespace } // namespace xla diff --git a/third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc b/third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc index 99f60350679752..de9dd466f88fce 100644 --- a/third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc +++ b/third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc @@ -104,6 +104,7 @@ limitations under the License. 
#include "xla/pjrt/distributed/protocol.pb.h" #include "xla/pjrt/event_pool.h" #include "xla/pjrt/host_callback.h" +#include "xla/pjrt/host_memory_spaces.h" #include "xla/pjrt/local_device_state.h" #include "xla/pjrt/metrics.h" #include "xla/pjrt/mlir_to_hlo.h" @@ -2240,9 +2241,20 @@ std::unique_ptr OutputBufferHelper( std::shared_ptr out_buffer = TrackedDeviceBuffer::FromScopedShapedBuffer(result_buffer, {definition_event}); + Shape shape = result_buffer->on_device_shape(); + PjRtMemorySpace* memory_space = + device->default_memory_space().value_or(nullptr); + if (shape.has_layout() && + shape.layout().memory_space() == Layout::kHostMemorySpace) { + absl::StatusOr memory_space_or = + device->memory_space_by_kind(PinnedHostMemorySpace::kKind); + if (memory_space_or.ok()) { + memory_space = memory_space_or.value(); + } + } auto pjrt_buffer = std::make_unique( result_buffer->on_device_shape(), std::move(out_buffer), client, device, - device->default_memory_space().value_or(nullptr)); + memory_space); RecordUsage(pjrt_buffer->GetBufferWithUsageHold(), local_device, local_device, definition_event, local_device->compute_stream(), /*prefer_to_retain_reference=*/false, &buffers_to_release); From 9d5165ddba7bbb7da88b2d07e6e9c2e7f7801610 Mon Sep 17 00:00:00 2001 From: Alexander Belyaev Date: Tue, 25 Jun 2024 02:41:32 -0700 Subject: [PATCH 212/256] [XLA:GPU] Print Interval as a semi-open interval. 
Replaced by %s/in \[\(\d\+\), \(\d\+\)\]/\='in ['.(submatch(1)).', '.(submatch(2)+1).')'/g PiperOrigin-RevId: 646399750 --- third_party/xla/docs/indexing.md | 288 ++--- .../gpu/fusions/concatenate_mlir_test.cc | 18 +- .../service/gpu/fusions/concatenate_test.cc | 18 +- ...in_place_dynamic_update_slice_mlir_test.cc | 24 +- .../in_place_dynamic_update_slice_test.cc | 16 +- .../gpu/fusions/input_slices_mlir_test.cc | 36 +- .../service/gpu/fusions/input_slices_test.cc | 18 +- .../xla/service/gpu/fusions/loop_mlir_test.cc | 86 +- .../xla/xla/service/gpu/fusions/loop_test.cc | 86 +- .../mlir/elemental_hlo_to_mlir_test.cc | 70 +- .../gpu/fusions/mlir/ir/xla_gpu_ops.cc | 8 +- .../gpu/fusions/mlir/tests/canonicalize.mlir | 44 +- .../gpu/fusions/mlir/tests/invalid.mlir | 4 +- .../gpu/fusions/mlir/tests/lower_tensors.mlir | 2 +- .../service/gpu/fusions/mlir/tests/ops.mlir | 12 +- .../fusions/mlir/tests/optimize_loops.mlir | 16 +- .../fusions/mlir/tests/simplify_affine.mlir | 8 +- .../mlir/tests/vectorize_loads_stores.mlir | 18 +- .../gpu/fusions/reduction_mlir_test.cc | 560 ++++----- .../xla/service/gpu/fusions/reduction_test.cc | 194 ++-- .../service/gpu/fusions/scatter_mlir_test.cc | 38 +- .../xla/service/gpu/fusions/scatter_test.cc | 38 +- .../gpu/fusions/transpose_mlir_test.cc | 172 +-- .../xla/service/gpu/fusions/transpose_test.cc | 164 +-- .../service/gpu/model/coalescing_analysis.cc | 8 +- .../gpu/model/indexing_analysis_test.cc | 1014 ++++++++--------- .../xla/xla/service/gpu/model/indexing_map.cc | 3 +- .../service/gpu/model/indexing_map_test.cc | 276 ++--- .../gpu/model/symbolic_tile_analysis_test.cc | 30 +- .../service/gpu/model/symbolic_tile_test.cc | 32 +- 30 files changed, 1651 insertions(+), 1650 deletions(-) diff --git a/third_party/xla/docs/indexing.md b/third_party/xla/docs/indexing.md index 5f0b70a8daa6c1..7d1722de431beb 100644 --- a/third_party/xla/docs/indexing.md +++ b/third_party/xla/docs/indexing.md @@ -15,7 +15,7 @@ bc0 = f32[10, 20, 30] 
broadcast(p0), dimensions={1} ``` the indexing map from the output to input is `(i, j, k) -> (j)` for `i in -[0, 10]`, `j in [0, 20]` and `k in [0, 30]`. +[0, 10]`, `j in [0, 21)` and `k in [0, 31)`. ## Motivation @@ -166,8 +166,8 @@ The output to input maps: ``` (d0, d1) -> (d0, d1) domain: -d0 in [0, 9] -d1 in [0, 19] +d0 in [0, 10) +d1 in [0, 20) ``` The input to output maps @@ -177,8 +177,8 @@ The input to output maps ``` (d0, d1) -> (d0, d1) domain: -d0 in [0, 9] -d1 in [0, 19] +d0 in [0, 10) +d1 in [0, 20) ``` ### [Broadcast](https://openxla.org/xla/operation_semantics#broadcastindim) @@ -196,9 +196,9 @@ The output to input map: ``` (d0, d1, d2) -> (d1) domain: -d0 in [0, 9] -d1 in [0, 19] -d2 in [0, 29] +d0 in [0, 10) +d1 in [0, 20) +d2 in [0, 30) ``` The input to output map @@ -206,9 +206,9 @@ The input to output map ``` (d0)[s0, s1] -> (s0, d0, s1) domain: -d0 in [0, 19] -s0 in [0, 9] -s1 in [0, 29] +d0 in [0, 20) +s0 in [0, 10) +s1 in [0, 30) ``` Note that now we have **s** on the right side for the input-to-output @@ -235,16 +235,16 @@ The output to input map for `src`: ``` (d0, d1, d2)[s0, s1, s2] -> (d0 + s0, d1 + s1, d2 + s2) domain: -d0 in [0, 0] -d1 in [0, 1] -d2 in [0, 31] -s0 in [0, 1] +d0 in [0, 1) +d1 in [0, 2) +d2 in [0, 32) +s0 in [0, 2) hlo: of1 = s32[] parameter(1) (d0, d1, d2) -> () -s1 in [0, 0] +s1 in [0, 1) hlo: of2 = s32[] parameter(2) (d0, d1, d2) -> () -s2 in [0, 226] +s2 in [0, 227) hlo: of3 = s32[] parameter(3) (d0, d1, d2) -> () ``` @@ -260,9 +260,9 @@ The output to input map for `of1`, `of2` and `of3`: ``` (d0, d1, d2) -> () domain: -d0 in [0, 0] -d1 in [0, 1] -d2 in [0, 31] +d0 in [0, 1) +d1 in [0, 2) +d2 in [0, 32) ``` ### [DynamicUpdateSlice](https://openxla.org/xla/operation_semantics#dynamicupdateslice) @@ -281,20 +281,20 @@ do not support inqequality constraints. 
``` (d0, d1) -> (d0, d1) domain: -d0 in [0, 19] -d1 in [0, 29] +d0 in [0, 20) +d1 in [0, 30) ``` The output to input map for `upd`: ``` (d0, d1)[s0, s1] -> (d0 - s0, d1 - s1) domain: -d0 in [0, 19] -d1 in [0, 29] -s0 in [0, 15] +d0 in [0, 20) +d1 in [0, 30) +s0 in [0, 16) hlo: of1 = s32[] parameter(2) (d0, d1) -> () -s1 in [0, 20] +s1 in [0, 21) hlo: of2 = s32[] parameter(3) (d0, d1) -> () ``` @@ -311,8 +311,8 @@ The output to input map for `of1` and `of2`: ``` (d0, d1) -> () domain: -d0 in [0, 19] -d1 in [0, 29] +d0 in [0, 20) +d1 in [0, 30) ``` ### [Gather](https://openxla.org/xla/operation_semantics#gather) @@ -334,14 +334,14 @@ The output to input map for `operand`: (d0, d1, d2, d3)[s0, s1] -> (d1 + s0, d2 + s1, d3) domain: -d0 in [0, 1805] -d1 in [0, 6] -d2 in [0, 7] -d3 in [0, 3] -s0 in [0, 26] +d0 in [0, 1806) +d1 in [0, 7) +d2 in [0, 8) +d3 in [0, 4) +s0 in [0, 27) hlo: indices = s32[1806,2]{1,0} parameter(1) (d0, d1, d2, d3) -> (d0, 0) -s1 in [0, 68] +s1 in [0, 69) hlo: indices = s32[1806,2]{1,0} parameter(1) (d0, d1, d2, d3) -> (d0, 1) ``` @@ -356,11 +356,11 @@ The output to input map for `indices`: ``` (d0, d1, d2, d3)[s0] -> (d0, s0) domain: - d0 in [0, 1805] - d1 in [0, 6] - d2 in [0, 7] - d3 in [0, 3] - s0 in [0, 1] + d0 in [0, 1806) + d1 in [0, 7) + d2 in [0, 8) + d3 in [0, 4) + s0 in [0, 2) ``` The range variable `s0` shows that we need the entire row (d0, *) of the `indices` tensor to compute an element of the output. 
@@ -380,10 +380,10 @@ The output to input map: ``` (d0, d1, d2, d3) -> (d0, d3, d1, d2) domain: -d0 in [0, 2] -d1 in [0, 5] -d2 in [0, 127] -d3 in [0, 12287] +d0 in [0, 3) +d1 in [0, 6) +d2 in [0, 128) +d3 in [0, 12288) ``` The input to output map: @@ -391,10 +391,10 @@ The input to output map: ``` (d0, d1, d2, d3) -> (d0, d2, d3, d1) domain: -d0 in [0, 2] -d1 in [0, 12287] -d2 in [0, 5] -d3 in [0, 127] +d0 in [0, 3) +d1 in [0, 12288) +d2 in [0, 6) +d3 in [0, 128) ``` ### [Reverse](https://openxla.org/xla/operation_semantics#rev_reverse) @@ -412,10 +412,10 @@ The output to input map: ``` (d0, d1, d2, d3) -> (d0, -d1 + 16, -d2 + 8, d3) domain: -d0 in [0, 0] -d1 in [0, 16] -d2 in [0, 8] -d3 in [0, 8] +d0 in [0, 1) +d1 in [0, 17) +d2 in [0, 9) +d3 in [0, 9) ``` The input to output map: @@ -423,10 +423,10 @@ The input to output map: ``` (d0, d1, d2, d3) -> (d0, -d1 + 16, -d2 + 8, d3) domain: -d0 in [0, 0] -d1 in [0, 16] -d2 in [0, 8] -d3 in [0, 8] +d0 in [0, 1) +d1 in [0, 17) +d2 in [0, 9) +d3 in [0, 9) ``` ### **[(Variadic)Reduce](https://openxla.org/xla/operation_semantics#reduce)** @@ -451,8 +451,8 @@ The output to input maps: ``` (d0)[s0] -> (s0, d0) domain: -d0 in [0, 9] -s0 in [0, 255] +d0 in [0, 10) +s0 in [0, 256) ``` - output -> init_j: @@ -460,7 +460,7 @@ s0 in [0, 255] ``` (d0) -> () domain: -d0 in [0, 9] +d0 in [0, 10) ``` The input to output maps: @@ -470,8 +470,8 @@ The input to output maps: ``` (d0, d1) -> (d1) domain: -d0 in [0, 255] -d1 in [0, 9] +d0 in [0, 256) +d1 in [0, 10) ``` - init_i -> output_j: @@ -479,7 +479,7 @@ d1 in [0, 9] ``` ()[s0] -> (s0) domain: -s0 in [0, 9] +s0 in [0, 10) ``` for i, j = 0, ... INPUT_COUNT. 
@@ -501,9 +501,9 @@ The output to input map: ``` (d0, d1, d2) -> (d0 + 5, d1 * 7 + 3, d2 * 2) domain: -d0 in [0, 4] -d1 in [0, 2] -d2 in [0, 24] +d0 in [0, 5) +d1 in [0, 3) +d2 in [0, 25) ``` The input to output map: @@ -511,11 +511,11 @@ The input to output map: ``` (d0, d1, d2) -> (d0 - 5, (d1 - 3) floordiv 7, d2 floordiv 2) domain: -d0 in [5, 9] -d1 in [3, 17] -d2 in [0, 48] -(d1 - 3) mod 7 in [0, 0] -d2 mod 2 in [0, 0] +d0 in [5, 10) +d1 in [3, 18) +d2 in [0, 49) +(d1 - 3) mod 7 in [0, 1) +d2 mod 2 in [0, 1) ``` ### [Reshape](https://openxla.org/xla/operation_semantics#reshape) @@ -536,7 +536,7 @@ The output to input map: ``` (d0) -> (d0 floordiv 8, d0 mod 8) domain: -d0 in [0, 31] +d0 in [0, 32) ``` The input to output map: @@ -544,8 +544,8 @@ The input to output map: ``` (d0, d1) -> (d0 * 8 + d1) domain: -d0 in [0, 3] -d1 in [0, 7] +d0 in [0, 4) +d1 in [0, 8) ``` #### Expand shape @@ -562,8 +562,8 @@ The output to input map: ``` (d0, d1) -> (d0 * 8 + d1) domain: -d0 in [0, 3] -d1 in [0, 7] +d0 in [0, 4) +d1 in [0, 8) ``` The input to output map: @@ -571,7 +571,7 @@ The input to output map: ``` (d0) -> (d0 floordiv 8, d0 mod 8) domain: -d0 in [0, 31] +d0 in [0, 32) ``` #### Generic reshape @@ -594,11 +594,11 @@ This reshape can be represented as a composition of collapse shape of The output to input map: ``` -(d0, d1, d2) -> (d0 * 2 + d1 floordiv 2, d2 + (d1 mod 2) * 4) +(d0, d1, d2) -> (d0 * 2 + d1 floordiv 2, d2 + (d1 mod 2) * 4) domain: -d0 in [0, 1] -d1 in [0, 3] -d2 in [0, 3] +d0 in [0, 2) +d1 in [0, 4) +d2 in [0, 4) ``` The input to output map: @@ -606,8 +606,8 @@ The input to output map: ``` (d0, d1) -> (d0 floordiv 2, d1 floordiv 4 + (d0 mod 2) * 2, d1 mod 4) domain: -d0 in [0, 3] -d1 in [0, 7] +d0 in [0, 4) +d1 in [0, 8) ``` ##### Example 2: Expanded and collapsed subshapes @@ -627,9 +627,9 @@ The output to input map: ``` (d0, d1, d2) -> (d0 floordiv 8, d0 mod 8, d1 * 4 + d2) domain: -d0 in [0, 31] -d1 in [0, 2] -d2 in [0, 3] +d0 in [0, 32) +d1 in [0, 
3) +d2 in [0, 4) ``` The input to output map: @@ -637,9 +637,9 @@ The input to output map: ``` (d0, d1, d2) -> (d0 * 8 + d1, d2 floordiv 4, d2 mod 4) domain: -d0 in [0, 3] -d1 in [0, 7] -d2 in [0, 11] +d0 in [0, 4) +d1 in [0, 8) +d2 in [0, 12) ``` ### Bitcast @@ -668,9 +668,9 @@ The output to inputs maps: ``` (d0, d1, d2) -> (d0, d1, d2) domain: -d0 in [0, 1] -d1 in [0, 4] -d2 in [0, 6] +d0 in [0, 2) +d1 in [0, 5) +d2 in [0, 7) ``` - output -> input 2: @@ -678,9 +678,9 @@ d2 in [0, 6] ``` (d0, d1, d2) -> (d0, d1 - 5, d2) domain: -d0 in [0, 1] -d1 in [5, 15] -d2 in [0, 6] +d0 in [0, 2) +d1 in [5, 16) +d2 in [0, 7) ``` - output -> input 3: @@ -688,9 +688,9 @@ d2 in [0, 6] ``` (d0, d1, d2) -> (d0, d1 - 16, d2) domain: -d0 in [0, 1] -d1 in [16, 32] -d2 in [0, 6] +d0 in [0, 2) +d1 in [16, 33) +d2 in [0, 7) ``` @@ -701,9 +701,9 @@ The inputs to output maps: ``` (d0, d1, d2) -> (d0, d1, d2) domain: -d0 in [0, 1] -d1 in [0, 4] -d2 in [0, 6] +d0 in [0, 2) +d1 in [0, 5) +d2 in [0, 7) ``` - input 2 -> output: @@ -711,9 +711,9 @@ d2 in [0, 6] ``` (d0, d1, d2) -> (d0, d1 + 5, d2) domain: -d0 in [0, 1] -d1 in [0, 10] -d2 in [0, 6] +d0 in [0, 2) +d1 in [0, 11) +d2 in [0, 7) ``` - input 3 -> output: @@ -721,9 +721,9 @@ d2 in [0, 6] ``` (d0, d1, d2) -> (d0, d1 + 16, d2) domain: -d0 in [0, 1] -d1 in [0, 16] -d2 in [0, 6] +d0 in [0, 2) +d1 in [0, 17) +d2 in [0, 7) ``` ### [Dot](https://openxla.org/xla/operation_semantics#dot) @@ -745,10 +745,10 @@ The output to inputs maps: ``` (d0, d1, d2)[s0] -> (d0, d1, s0) domain: -d0 in [0, 3] -d1 in [0, 127] -d2 in [0, 63] -s0 in [0, 255] +d0 in [0, 4) +d1 in [0, 128) +d2 in [0, 64) +s0 in [0, 256) ``` - output -> input_2: @@ -756,10 +756,10 @@ s0 in [0, 255] ``` (d0, d1, d2)[s0] -> (d0, s0, d2) domain: -d0 in [0, 3] -d1 in [0, 127] -d2 in [0, 63] -s0 in [0, 255] +d0 in [0, 4) +d1 in [0, 128) +d2 in [0, 64) +s0 in [0, 256) ``` The inputs to output maps: @@ -769,10 +769,10 @@ The inputs to output maps: ``` (d0, d1, d2)[s0] -> (d0, d1, s0) 
domain: -d0 in [0, 3] -d1 in [0, 127] -d2 in [0, 255] -s0 in [0, 63] +d0 in [0, 4) +d1 in [0, 128) +d2 in [0, 256) +s0 in [0, 64) ``` - input_2 -> output: @@ -780,10 +780,10 @@ s0 in [0, 63] ``` (d0, d1, d2)[s0] -> (d0, s0, d1) domain: -d0 in [0, 3] -d1 in [0, 255] -d2 in [0, 63] -s0 in [0, 127] +d0 in [0, 4) +d1 in [0, 256) +d2 in [0, 64) +s0 in [0, 128) ``` ### [Pad](https://openxla.org/xla/operation_semantics#pad) @@ -805,9 +805,9 @@ The output to input maps: ``` (d0, d1) -> ((d0 - 1) floordiv 2, d1 - 4) domain: -d0 in [1, 7] -d1 in [4, 7] -(d0 - 1) mod 2 in [0, 0] +d0 in [1, 8) +d1 in [4, 8) +(d0 - 1) mod 2 in [0, 1) ``` - output -> init: @@ -815,8 +815,8 @@ d1 in [4, 7] ``` (d0, d1) -> () domain: -d0 in [0, 11] -d1 in [0, 15] +d0 in [0, 12) +d1 in [0, 16) ``` @@ -841,9 +841,9 @@ The output to input maps: ``` (d0, d1)[s0] -> (d0, d1 + s0) domain: -d0 in [0, 1023] -d1 in [0, 2] -s0 in [0, 511] +d0 in [0, 1024) +d1 in [0, 3) +s0 in [0, 512) ``` - output -> init: @@ -851,8 +851,8 @@ s0 in [0, 511] ``` (d0, d1) -> () domain: -d0 in [0, 1023] -d1 in [0, 2] +d0 in [0, 1024) +d1 in [0, 3) ``` ## Indexing Maps for Fusion @@ -873,7 +873,7 @@ f { } ``` -The output-to-input indexing maps for `p0` will be `(d0, d1) -> (d0, d1)` and +The output-to-input indexing maps for `p0` will be `(d0, d1) -> (d0, d1)` and `(d0, d1) -> (d1, d0)`. It means that to compute one element of the output we might need to read the input parameter twice. @@ -909,10 +909,10 @@ The output-to-input indexing maps for `parameter 0` for softmax: ``` (d0, d1, d2)[s0] -> (d0, d1, s0) domain: -d0 in [0, 1] -d1 in [0, 64] -d2 in [0, 124] -s0 in [0, 124] +d0 in [0, 2) +d1 in [0, 65) +d2 in [0, 125) +s0 in [0, 125) ``` and @@ -920,9 +920,9 @@ and ``` (d0, d1, d2) -> (d0, d1, d2) domain: -d0 in [0, 1] -d1 in [0, 64] -d2 in [0, 124] +d0 in [0, 2) +d1 in [0, 65) +d2 in [0, 125) ``` where `s0` refers to the inner-most dimension of the input. 
@@ -941,10 +941,10 @@ The simplifier can rewrite the following expressions. 1. `(d0, d1) -> (d0 + d1 floordiv 16, d1 mod 16)` for **d** in `[0, 6] x [0, 14]` becomes `(d0, d1) -> (d0, d1)` 2. `(d0, d1, d2) -> ((100d0 + 10d1 + d2) floorDiv 100, ((100d0 + 10d1 + - d2) mod 100) floordiv 10, d2 mod 10)` for `di in [0, 9]` becomes `(d0, d1, + d2) mod 100) floordiv 10, d2 mod 10)` for `di in [0, 10)` becomes `(d0, d1, d2) -> (d0, d1, d2)`. 3. `(d0, d1, d2) -> ((16d0 + 4d1 + d2) floordiv 8, (16d0 + 4d1 + d2) mod - 8)` for `d_i in [0, 9]` becomes `(d0, d1, d2) -> (2d0 + (4d1 + + 8)` for `d_i in [0, 10)` becomes `(d0, d1, d2) -> (2d0 + (4d1 + d2) floordiv 8,(4d1 + d2) mod 8)`. 4. `(d0, d1) -> (-(-11d0 - d1 + 109) floordiv 11 + 9)` for **d** in `[0, 9] x [0, 10]` becomes `(d0, d1) -> (d0)`. @@ -967,10 +967,10 @@ Indexing map simplification also simplifies the constraints. 1. Constraints of type `lower_bound <= affine_expr (floordiv, +, -, *) constant <= upper_bound` are rewritten as `updated_lower_bound <= affine_expr <= updated_upped_bound`. -2. Constraints that are always satisfied, e.g. `d0 + s0 in [0, 20]` -for `d0 in [0, 5]` and `s0 in [1, 3]` are eliminated. +2. Constraints that are always satisfied, e.g. `d0 + s0 in [0, 21)` +for `d0 in [0, 6)` and `s0 in [1, 4)` are eliminated. 3. Affine expressions in the constraints are optimized as the indexing affine map above. -For more examples see [indexing_map_test.cc](https://github.com/openxla/xla/blob/main/xla/service/gpu/model/indexing_map_test.cc). \ No newline at end of file +For more examples see [indexing_map_test.cc](https://github.com/openxla/xla/blob/main/xla/service/gpu/model/indexing_map_test.cc). 
diff --git a/third_party/xla/xla/service/gpu/fusions/concatenate_mlir_test.cc b/third_party/xla/xla/service/gpu/fusions/concatenate_mlir_test.cc index 4a104da784d7ea..4ca3becb759aa8 100644 --- a/third_party/xla/xla/service/gpu/fusions/concatenate_mlir_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/concatenate_mlir_test.cc @@ -59,15 +59,15 @@ TEST_F(MlirConcatenateFusionTest, ThreadIdIndexing) { (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> ( (bl_x * 128 + th_x) mod 400) domain: - th_x in [0, 127] - th_y in [0, 0] - th_z in [0, 0] - bl_x in [0, 3] - bl_y in [0, 0] - bl_z in [0, 0] - chunk_id in [0, 0] - unroll_id in [0, 0] - th_x + bl_x * 128 in [0, 399] + th_x in [0, 128) + th_y in [0, 1) + th_z in [0, 1) + bl_x in [0, 4) + bl_y in [0, 1) + bl_z in [0, 1) + chunk_id in [0, 1) + unroll_id in [0, 1) + th_x + bl_x * 128 in [0, 400) )"; auto thread_id_to_output_indexing_0 = fusion.ComputeThreadIdToInputIndexing( /*root_index=*/0, /*hero_operand_index=*/0, &mlir_context_); diff --git a/third_party/xla/xla/service/gpu/fusions/concatenate_test.cc b/third_party/xla/xla/service/gpu/fusions/concatenate_test.cc index c1bc6603ca65ec..6ceb4cee88a0e7 100644 --- a/third_party/xla/xla/service/gpu/fusions/concatenate_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/concatenate_test.cc @@ -81,15 +81,15 @@ TEST_F(ConcatenateTest, ThreadIndexing) { (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> ( (bl_x * 128 + th_x) mod 400) domain: - th_x in [0, 127] - th_y in [0, 0] - th_z in [0, 0] - bl_x in [0, 3] - bl_y in [0, 0] - bl_z in [0, 0] - chunk_id in [0, 0] - unroll_id in [0, 0] - th_x + bl_x * 128 in [0, 399] + th_x in [0, 128) + th_y in [0, 1) + th_z in [0, 1) + bl_x in [0, 4) + bl_y in [0, 1) + bl_z in [0, 1) + chunk_id in [0, 1) + unroll_id in [0, 1) + th_x + bl_x * 128 in [0, 400) )"; EXPECT_THAT( fusion diff --git a/third_party/xla/xla/service/gpu/fusions/in_place_dynamic_update_slice_mlir_test.cc 
b/third_party/xla/xla/service/gpu/fusions/in_place_dynamic_update_slice_mlir_test.cc index b68a95e9516bfd..c3ef11a04f1de8 100644 --- a/third_party/xla/xla/service/gpu/fusions/in_place_dynamic_update_slice_mlir_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/in_place_dynamic_update_slice_mlir_test.cc @@ -66,14 +66,14 @@ TEST_F(MlirInPlaceDynamicUpdateSliceFusionTest, ThreadIndexing) { (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> ( th_x floordiv 6, th_x mod 6) domain: - th_x in [0, 29] - th_y in [0, 0] - th_z in [0, 0] - bl_x in [0, 0] - bl_y in [0, 0] - bl_z in [0, 0] - chunk_id in [0, 0] - unroll_id in [0, 0] + th_x in [0, 30) + th_y in [0, 1) + th_z in [0, 1) + bl_x in [0, 1) + bl_y in [0, 1) + bl_z in [0, 1) + chunk_id in [0, 1) + unroll_id in [0, 1) )")); auto thread_id_dst_indexing = fusion.ComputeThreadIdToInputIndexing( /*root_index=*/0, /*hero_operand_index=*/0, &mlir_context_); @@ -112,8 +112,8 @@ TEST_F(MlirInPlaceDynamicUpdateSliceFusionTest, SimpleDUS) { // CHECK-DAG: %[[C_15:.*]] = arith.constant 15 // CHECK-DAG: %[[C_0:.*]] = arith.constant 0 // CHECK: %[[THREAD_ID:.*]] = gpu.thread_id x - // CHECK: %[[INPUT_INDEX_0:.*]] = xla_gpu.apply_indexing #[[MAP_1]](%[[THREAD_ID]] in [0, 29]) - // CHECK: %[[INPUT_INDEX_1:.*]] = xla_gpu.apply_indexing #[[MAP_2]](%[[THREAD_ID]] in [0, 29]) + // CHECK: %[[INPUT_INDEX_0:.*]] = xla_gpu.apply_indexing #[[MAP_1]](%[[THREAD_ID]] in [0, 30)) + // CHECK: %[[INPUT_INDEX_1:.*]] = xla_gpu.apply_indexing #[[MAP_2]](%[[THREAD_ID]] in [0, 30)) // CHECK: %[[I0:.*]] = xla_gpu.pure_call @fused_computation_i0 // CHECK: %[[I1:.*]] = xla_gpu.pure_call @fused_computation_i1 // CHECK: %[[IDX0:.*]] = arith.index_cast %[[I0]] @@ -162,8 +162,8 @@ TEST_F(MlirInPlaceDynamicUpdateSliceFusionTest, OutOfBoundDUS) { // CHECK-DAG: %[[C_5:.*]] = arith.constant 5 // CHECK-DAG: %[[C_0:.*]] = arith.constant 0 // CHECK: %[[THREAD_ID:.*]] = gpu.thread_id x - // CHECK: %[[INPUT_INDEX_0:.*]] = xla_gpu.apply_indexing 
#[[MAP_1]](%[[THREAD_ID]] in [0, 5]) - // CHECK: %[[INPUT_INDEX_1:.*]] = xla_gpu.apply_indexing #[[MAP_2]](%[[THREAD_ID]] in [0, 5]) + // CHECK: %[[INPUT_INDEX_0:.*]] = xla_gpu.apply_indexing #[[MAP_1]](%[[THREAD_ID]] in [0, 6)) + // CHECK: %[[INPUT_INDEX_1:.*]] = xla_gpu.apply_indexing #[[MAP_2]](%[[THREAD_ID]] in [0, 6)) // CHECK: %[[I0:.*]] = xla_gpu.pure_call @fused_computation_i0 // CHECK: %[[I1:.*]] = xla_gpu.pure_call @fused_computation_i1 // CHECK: %[[IDX0:.*]] = arith.index_cast %[[I0]] diff --git a/third_party/xla/xla/service/gpu/fusions/in_place_dynamic_update_slice_test.cc b/third_party/xla/xla/service/gpu/fusions/in_place_dynamic_update_slice_test.cc index 98e533a3c5c13c..e4c67af1013708 100644 --- a/third_party/xla/xla/service/gpu/fusions/in_place_dynamic_update_slice_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/in_place_dynamic_update_slice_test.cc @@ -83,14 +83,14 @@ TEST_F(InPlaceDynamicUpdateSliceFusionTest, ThreadIndexing) { (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> ( th_x floordiv 6, th_x mod 6) domain: - th_x in [0, 29] - th_y in [0, 0] - th_z in [0, 0] - bl_x in [0, 0] - bl_y in [0, 0] - bl_z in [0, 0] - chunk_id in [0, 0] - unroll_id in [0, 0] + th_x in [0, 30) + th_y in [0, 1) + th_z in [0, 1) + bl_x in [0, 1) + bl_y in [0, 1) + bl_z in [0, 1) + chunk_id in [0, 1) + unroll_id in [0, 1) )")); auto thread_id_dst_indexing = fusion->ComputeThreadIdToInputIndexing( /*root_index=*/0, /*hero_operand_index=*/0, &mlir_context_); diff --git a/third_party/xla/xla/service/gpu/fusions/input_slices_mlir_test.cc b/third_party/xla/xla/service/gpu/fusions/input_slices_mlir_test.cc index 27e5fb35a0e4fb..a145f616efeda3 100644 --- a/third_party/xla/xla/service/gpu/fusions/input_slices_mlir_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/input_slices_mlir_test.cc @@ -57,15 +57,15 @@ TEST_F(MlirInputSlicesFusionTest, ThreadIndexing) { th_x mod 5 ) domain: - th_x in [5, 19] - th_y in [0, 0] - th_z in [0, 0] - bl_x in [0, 0] - bl_y 
in [0, 0] - bl_z in [0, 0] - s0 in [0, 0] - s1 in [0, 0] - th_x mod 5 in [0, 2] + th_x in [5, 20) + th_y in [0, 1) + th_z in [0, 1) + bl_x in [0, 1) + bl_y in [0, 1) + bl_z in [0, 1) + s0 in [0, 1) + s1 in [0, 1) + th_x mod 5 in [0, 3) )")); auto thread_id_to_output_indexing_1 = emitter->ComputeThreadIdToOutputIndexing(1, &mlir_context_); @@ -77,15 +77,15 @@ TEST_F(MlirInputSlicesFusionTest, ThreadIndexing) { th_x mod 5 ) domain: - th_x in [0, 9] - th_y in [0, 0] - th_z in [0, 0] - bl_x in [0, 0] - bl_y in [0, 0] - bl_z in [0, 0] - s0 in [0, 0] - s1 in [0, 0] - th_x mod 5 in [0, 2] + th_x in [0, 10) + th_y in [0, 1) + th_z in [0, 1) + bl_x in [0, 1) + bl_y in [0, 1) + bl_z in [0, 1) + s0 in [0, 1) + s1 in [0, 1) + th_x mod 5 in [0, 3) )")); } diff --git a/third_party/xla/xla/service/gpu/fusions/input_slices_test.cc b/third_party/xla/xla/service/gpu/fusions/input_slices_test.cc index 9c32e9e035ebb1..c36b67e85e2df6 100644 --- a/third_party/xla/xla/service/gpu/fusions/input_slices_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/input_slices_test.cc @@ -82,15 +82,15 @@ TEST_F(InputSlicesTest, ThreadIndexing) { (bl_x * 128 + th_x) mod 3, ((bl_x * 128 + th_x) floordiv 6) mod 5) domain: - th_x in [0, 127] - th_y in [0, 0] - th_z in [0, 0] - bl_x in [0, 1] - bl_y in [0, 0] - bl_z in [0, 0] - chunk_id in [0, 0] - unroll_id in [0, 0] - th_x + bl_x * 128 in [0, 29] + th_x in [0, 128) + th_y in [0, 1) + th_z in [0, 1) + bl_x in [0, 2) + bl_y in [0, 1) + bl_z in [0, 1) + chunk_id in [0, 1) + unroll_id in [0, 1) + th_x + bl_x * 128 in [0, 30) )")); } diff --git a/third_party/xla/xla/service/gpu/fusions/loop_mlir_test.cc b/third_party/xla/xla/service/gpu/fusions/loop_mlir_test.cc index 6c5d4464396800..52feb6c50fea11 100644 --- a/third_party/xla/xla/service/gpu/fusions/loop_mlir_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/loop_mlir_test.cc @@ -59,15 +59,15 @@ TEST_F(MlirLoopFusionTest, ThreadId_IndexingUnrolled) { (th_x * 4 + bl_x * 512 + chunk_id * 516096) mod 300 
+ unroll_id ) domain: - th_x in [0, 127] - th_y in [0, 0] - th_z in [0, 0] - bl_x in [0, 1007] - bl_y in [0, 0] - bl_z in [0, 0] - chunk_id in [0, 11] - unroll_id in [0, 3] - (th_x + bl_x * 128) * 4 + chunk_id * 516096 in [0, 5999996] + th_x in [0, 128) + th_y in [0, 1) + th_z in [0, 1) + bl_x in [0, 1008) + bl_y in [0, 1) + bl_z in [0, 1) + chunk_id in [0, 12) + unroll_id in [0, 4) + (th_x + bl_x * 128) * 4 + chunk_id * 516096 in [0, 5999997) )")); } @@ -97,14 +97,14 @@ TEST_F(MlirLoopFusionTest, ThreadId_IndexingNotUnrolled) { MatchIndexingString(R"( (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> (th_x) domain: - th_x in [0, 19] - th_y in [0, 0] - th_z in [0, 0] - bl_x in [0, 0] - bl_y in [0, 0] - bl_z in [0, 0] - chunk_id in [0, 0] - unroll_id in [0, 0] + th_x in [0, 20) + th_y in [0, 1) + th_z in [0, 1) + bl_x in [0, 1) + bl_y in [0, 1) + bl_z in [0, 1) + chunk_id in [0, 1) + unroll_id in [0, 1) )")); auto thread_id_to_input_indexing = fusion.ComputeThreadIdToInputIndexing( /*root_index=*/0, /*hero_operand_index=*/0, &mlir_context_); @@ -112,14 +112,14 @@ TEST_F(MlirLoopFusionTest, ThreadId_IndexingNotUnrolled) { MatchIndexingString(R"( (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> (th_x) domain: - th_x in [0, 19] - th_y in [0, 0] - th_z in [0, 0] - bl_x in [0, 0] - bl_y in [0, 0] - bl_z in [0, 0] - chunk_id in [0, 0] - unroll_id in [0, 0] + th_x in [0, 20) + th_y in [0, 1) + th_z in [0, 1) + bl_x in [0, 1) + bl_y in [0, 1) + bl_z in [0, 1) + chunk_id in [0, 1) + unroll_id in [0, 1) )")); } @@ -153,15 +153,15 @@ TEST_F(MlirLoopFusionTest, ThreadId_Broadcast) { (bl_x * 128 + th_x) mod 30 ) domain: - th_x in [0, 127] - th_y in [0, 0] - th_z in [0, 0] - bl_x in [0, 46] - bl_y in [0, 0] - bl_z in [0, 0] - chunk_id in [0, 0] - unroll_id in [0, 0] - th_x + bl_x * 128 in [0, 5999] + th_x in [0, 128) + th_y in [0, 1) + th_z in [0, 1) + bl_x in [0, 47) + bl_y in [0, 1) + bl_z in [0, 1) + chunk_id in [0, 1) + unroll_id in [0, 1) + th_x + bl_x 
* 128 in [0, 6000) )")); auto thread_id_to_input_indexing = fusion.ComputeThreadIdToInputIndexing( /*root_index=*/0, /*hero_operand_index=*/0, &mlir_context_); @@ -170,15 +170,15 @@ TEST_F(MlirLoopFusionTest, ThreadId_Broadcast) { (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> (((bl_x * 128 + th_x) floordiv 30) mod 20) domain: - th_x in [0, 127] - th_y in [0, 0] - th_z in [0, 0] - bl_x in [0, 46] - bl_y in [0, 0] - bl_z in [0, 0] - chunk_id in [0, 0] - unroll_id in [0, 0] - th_x + bl_x * 128 in [0, 5999] + th_x in [0, 128) + th_y in [0, 1) + th_z in [0, 1) + bl_x in [0, 47) + bl_y in [0, 1) + bl_z in [0, 1) + chunk_id in [0, 1) + unroll_id in [0, 1) + th_x + bl_x * 128 in [0, 6000) )")); } diff --git a/third_party/xla/xla/service/gpu/fusions/loop_test.cc b/third_party/xla/xla/service/gpu/fusions/loop_test.cc index 4337260ca040e0..eb583f6def0da8 100644 --- a/third_party/xla/xla/service/gpu/fusions/loop_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/loop_test.cc @@ -93,15 +93,15 @@ TEST_F(LoopTest, ThreadIndexingUnrolled) { (th_x * 4 + bl_x * 512 + chunk_id * 516096) mod 300 + unroll_id ) domain: - th_x in [0, 127] - th_y in [0, 0] - th_z in [0, 0] - bl_x in [0, 1007] - bl_y in [0, 0] - bl_z in [0, 0] - chunk_id in [0, 11] - unroll_id in [0, 3] - (th_x + bl_x * 128) * 4 + chunk_id * 516096 in [0, 5999996] + th_x in [0, 128) + th_y in [0, 1) + th_z in [0, 1) + bl_x in [0, 1008) + bl_y in [0, 1) + bl_z in [0, 1) + chunk_id in [0, 12) + unroll_id in [0, 4) + (th_x + bl_x * 128) * 4 + chunk_id * 516096 in [0, 5999997) )")); } @@ -131,14 +131,14 @@ TEST_F(LoopTest, ThreadIndexingNotUnrolled) { MatchIndexingString(R"( (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> (th_x) domain: - th_x in [0, 19] - th_y in [0, 0] - th_z in [0, 0] - bl_x in [0, 0] - bl_y in [0, 0] - bl_z in [0, 0] - chunk_id in [0, 0] - unroll_id in [0, 0] + th_x in [0, 20) + th_y in [0, 1) + th_z in [0, 1) + bl_x in [0, 1) + bl_y in [0, 1) + bl_z in [0, 1) + chunk_id in 
[0, 1) + unroll_id in [0, 1) )")); auto thread_id_to_input_indexing = loop_fusion->ComputeThreadIdToInputIndexing( @@ -147,14 +147,14 @@ TEST_F(LoopTest, ThreadIndexingNotUnrolled) { MatchIndexingString(R"( (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> (th_x) domain: - th_x in [0, 19] - th_y in [0, 0] - th_z in [0, 0] - bl_x in [0, 0] - bl_y in [0, 0] - bl_z in [0, 0] - chunk_id in [0, 0] - unroll_id in [0, 0] + th_x in [0, 20) + th_y in [0, 1) + th_z in [0, 1) + bl_x in [0, 1) + bl_y in [0, 1) + bl_z in [0, 1) + chunk_id in [0, 1) + unroll_id in [0, 1) )")); } @@ -187,15 +187,15 @@ TEST_F(LoopTest, Broadcast) { ((bl_x * 128 + th_x) floordiv 30) mod 20, (bl_x * 128 + th_x) mod 30) domain: - th_x in [0, 127] - th_y in [0, 0] - th_z in [0, 0] - bl_x in [0, 46] - bl_y in [0, 0] - bl_z in [0, 0] - chunk_id in [0, 0] - unroll_id in [0, 0] - th_x + bl_x * 128 in [0, 5999] + th_x in [0, 128) + th_y in [0, 1) + th_z in [0, 1) + bl_x in [0, 47) + bl_y in [0, 1) + bl_z in [0, 1) + chunk_id in [0, 1) + unroll_id in [0, 1) + th_x + bl_x * 128 in [0, 6000) )")); auto thread_id_to_input_indexing = loop_fusion->ComputeThreadIdToInputIndexing( @@ -205,15 +205,15 @@ TEST_F(LoopTest, Broadcast) { (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> (((bl_x * 128 + th_x) floordiv 30) mod 20) domain: - th_x in [0, 127] - th_y in [0, 0] - th_z in [0, 0] - bl_x in [0, 46] - bl_y in [0, 0] - bl_z in [0, 0] - chunk_id in [0, 0] - unroll_id in [0, 0] - th_x + bl_x * 128 in [0, 5999] + th_x in [0, 128) + th_y in [0, 1) + th_z in [0, 1) + bl_x in [0, 47) + bl_y in [0, 1) + bl_z in [0, 1) + chunk_id in [0, 1) + unroll_id in [0, 1) + th_x + bl_x * 128 in [0, 6000) )")); } diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir_test.cc b/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir_test.cc index 7bc60617f43b9c..4647c3f6478b41 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir_test.cc +++ 
b/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir_test.cc @@ -236,9 +236,9 @@ TEST_F(ElementalHloToMlirTest, ReduceWindow) { // CHECK: %[[RET:.*]] = scf.for %[[I:.*]] = %[[C0]] to %[[C7]] // CHECK-SAME: step %[[C1]] iter_args(%[[ACC:.*]] = %[[INIT]]) // CHECK: %[[J0:.*]] = xla_gpu.apply_indexing affine_map<(d0) -> (d0 * 4)> - // CHECK-SAME: (%[[Y]] in [0, 2]) + // CHECK-SAME: (%[[Y]] in [0, 3)) // CHECK: %[[J1:.*]] = xla_gpu.apply_indexing affine_map<(d0)[s0] -> (d0 + s0 - 3)> - // CHECK-SAME: (%[[Z]] in [0, 7])[%[[I]] in [0, 6]] + // CHECK-SAME: (%[[Z]] in [0, 8))[%[[I]] in [0, 7)] // CHECK: %[[VAL:.*]] = tensor.extract %[[ARG0]] // CHECK-SAME: [%[[X]], %[[J0]], %[[J1]]] // CHECK: %[[UPD:.*]] = func.call @add_sum(%[[ACC]], @@ -286,7 +286,7 @@ TEST_F(ElementalHloToMlirTest, ReduceWindowWithRescaling) { // `s0 floordiv ` in the map: // CHECK: %[[K:.*]] = xla_gpu.apply_indexing // CHECK-SAME: affine_map<(d0)[s0] -> (d0 * 2 + s0)> - // CHECK-SAME: (%[[X]] in [0, 18])[%[[I]] in [0, 3]] + // CHECK-SAME: (%[[X]] in [0, 19))[%[[I]] in [0, 4)] // CHECK: tensor.extract %[[ARG0]][%[[K]], %[[Y]], %[[Z]]] )")); @@ -433,7 +433,7 @@ TEST_F(ElementalHloToMlirTest, Pad) { // CHECK-DAG: %[[C4:.*]] = arith.constant 4 // CHECK-DAG: %[[C7:.*]] = arith.constant 7 // CHECK: %[[CONSTRAINT_VAL:.*]] = xla_gpu.apply_indexing - // CHECK-SAME: <(d0) -> ((d0 - 1) mod 2)>(%[[X]] in [1, 7]) + // CHECK-SAME: <(d0) -> ((d0 - 1) mod 2)>(%[[X]] in [1, 8)) // CHECK: %[[CONSTRAINT:.*]] = arith.cmpi eq, %[[CONSTRAINT_VAL]], %[[C0]] // CHECK-DAG: %[[X_L:.*]] = arith.cmpi sge, %[[X]], %[[C1]] // CHECK-DAG: %[[X_H:.*]] = arith.cmpi sle, %[[X]], %[[C7]] @@ -446,10 +446,10 @@ TEST_F(ElementalHloToMlirTest, Pad) { // CHECK: %[[RET:.*]] = scf.if %[[FROM_INPUT]] // CHECK: %[[IN0:.*]] = xla_gpu.apply_indexing // CHECK-SAME: <(d0) -> ((d0 - 1) floordiv 2)> - // CHECK-SAME: (%[[X]] in [1, 7]) + // CHECK-SAME: (%[[X]] in [1, 8)) // CHECK: %[[IN1:.*]] = xla_gpu.apply_indexing // CHECK-SAME: <(d0) 
-> (d0 - 4)> - // CHECK-SAME: (%[[Y]] in [4, 7]) + // CHECK-SAME: (%[[Y]] in [4, 8)) // CHECK: %[[VAL:.*]] = tensor.extract %[[ARG0]][%[[IN0]], %[[IN1]]] // CHECK: scf.yield %[[VAL]] // CHECK: } else { @@ -477,7 +477,7 @@ TEST_F(ElementalHloToMlirTest, PadUnsigned) { // CHECK-DAG: %[[C4:.*]] = arith.constant 4 // CHECK-DAG: %[[C7:.*]] = arith.constant 7 // CHECK: %[[CONSTRAINT_VAL:.*]] = xla_gpu.apply_indexing - // CHECK-SAME: <(d0) -> ((d0 - 1) mod 2)>(%[[X]] in [1, 7]) + // CHECK-SAME: <(d0) -> ((d0 - 1) mod 2)>(%[[X]] in [1, 8)) // CHECK: %[[CONSTRAINT:.*]] = arith.cmpi eq, %[[CONSTRAINT_VAL]], %[[C0]] // CHECK-DAG: %[[X_L:.*]] = arith.cmpi sge, %[[X]], %[[C1]] // CHECK-DAG: %[[X_H:.*]] = arith.cmpi sle, %[[X]], %[[C7]] @@ -490,10 +490,10 @@ TEST_F(ElementalHloToMlirTest, PadUnsigned) { // CHECK: %[[RET:.*]] = scf.if %[[FROM_INPUT]] // CHECK: %[[IN0:.*]] = xla_gpu.apply_indexing // CHECK-SAME: <(d0) -> ((d0 - 1) floordiv 2)> - // CHECK-SAME: (%[[X]] in [1, 7]) + // CHECK-SAME: (%[[X]] in [1, 8)) // CHECK: %[[IN1:.*]] = xla_gpu.apply_indexing // CHECK-SAME: <(d0) -> (d0 - 4)> - // CHECK-SAME: (%[[Y]] in [4, 7]) + // CHECK-SAME: (%[[Y]] in [4, 8)) // CHECK: %[[VAL:.*]] = tensor.extract %[[ARG0]][%[[IN0]], %[[IN1]]] // CHECK: scf.yield %[[VAL]] // CHECK: } else { @@ -811,10 +811,10 @@ TEST_F(ElementalHloToMlirTest, ConvolutionSimple) { // CHECK: %[[R3:.+]] = scf.if {{.+}} -> (f32) { // CHECK: %[[XX0:.+]] = xla_gpu.apply_indexing // CHECK-SAME: affine_map<(d0)[s0] -> (d0 + s0)> - // CHECK-SAME: (%[[W]] in [0, 5])[%[[X]] in [0, 2]] + // CHECK-SAME: (%[[W]] in [0, 6))[%[[X]] in [0, 3)] // CHECK: %[[XX1:.+]] = xla_gpu.apply_indexing // CHECK-SAME: affine_map<(d0)[s0] -> (d0 + s0)> - // CHECK-SAME: (%[[H]] in [0, 7])[%[[Y]] in [0, 4]] + // CHECK-SAME: (%[[H]] in [0, 8))[%[[Y]] in [0, 5)] // CHECK-DAG: %[[VL:.+]] = tensor.extract %[[LHS]][%[[B]], %[[XX0]], %[[XX1]], %[[I]]] : tensor<2x8x12x4xf32> // CHECK-DAG: %[[VR:.+]] = tensor.extract %[[RHS]][%[[I]], %[[X]], %[[Y]], 
%[[O]]] : tensor<4x3x5x16xf32> // CHECK: %[[MUL:.+]] = arith.mulf %[[VL]], %[[VR]] : f32 @@ -857,10 +857,10 @@ TEST_F(ElementalHloToMlirTest, ConvolutionWithWindowStrides) { // CHECK: %[[R3:.+]] = scf.if {{.+}} -> (f32) { // CHECK: %[[XX0:.+]] = xla_gpu.apply_indexing // CHECK-SAME: affine_map<(d0)[s0] -> (d0 * 2 + s0)> - // CHECK-SAME: (%[[W]] in [0, 2])[%[[X]] in [0, 2]] + // CHECK-SAME: (%[[W]] in [0, 3))[%[[X]] in [0, 3)] // CHECK: %[[XX1:.+]] = xla_gpu.apply_indexing // CHECK-SAME: affine_map<(d0)[s0] -> (d0 * 2 + s0)> - // CHECK-SAME: (%[[H]] in [0, 3])[%[[Y]] in [0, 4]] + // CHECK-SAME: (%[[H]] in [0, 4))[%[[Y]] in [0, 5)] // CHECK-DAG: %[[VL:.+]] = tensor.extract %[[LHS]][%[[B]], %[[XX0]], %[[XX1]], %[[I]]] : tensor<2x8x12x4xf32> // CHECK-DAG: %[[VR:.+]] = tensor.extract %[[RHS]][%[[I]], %[[X]], %[[Y]], %[[O]]] : tensor<4x3x5x16xf32> // CHECK: %[[MUL:.+]] = arith.mulf %[[VL]], %[[VR]] : f32 @@ -903,21 +903,21 @@ TEST_F(ElementalHloToMlirTest, ConvolutionWithPadding) { // CHECK: %[[R0:.+]] = scf.for %[[X:.+]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[A0:.+]] = %[[INIT]]) -> (f32) { // CHECK-NEXT: %[[R1:.+]] = scf.for %[[Y:.+]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[A1:.+]] = %[[A0]]) -> (f32) { // CHECK-NEXT: %[[R2:.+]] = scf.for %[[I:.+]] = %[[C0]] to %[[C4]] step %[[C1]] iter_args(%[[ACC:.+]] = %[[A1]]) -> (f32) { - // CHECK-DAG: %[[TESTX:.+]] = xla_gpu.apply_indexing affine_map<(d0)[s0] -> (d0 + s0)>(%[[W]] in [0, 7])[%[[X]] in [0, 2]] + // CHECK-DAG: %[[TESTX:.+]] = xla_gpu.apply_indexing affine_map<(d0)[s0] -> (d0 + s0)>(%[[W]] in [0, 8))[%[[X]] in [0, 3)] // CHECK-DAG: %[[TXGE:.+]] = arith.cmpi sge, %[[TESTX]], %[[C1]] : index // CHECK-DAG: %[[TXLE:.+]] = arith.cmpi sle, %[[TESTX]], %[[C8]] : index // CHECK-DAG: %[[TX:.+]] = arith.andi %[[TXGE]], %[[TXLE]] : i1 - // CHECK-DAG: %[[TESTY:.+]] = xla_gpu.apply_indexing affine_map<(d0)[s0] -> (d0 + s0)>(%[[H]] in [0, 11])[%[[Y]] in [0, 4]] + // CHECK-DAG: %[[TESTY:.+]] = xla_gpu.apply_indexing 
affine_map<(d0)[s0] -> (d0 + s0)>(%[[H]] in [0, 12))[%[[Y]] in [0, 5)] // CHECK-DAG: %[[TYGE:.+]] = arith.cmpi sge, %[[TESTY]], %[[C2]] : index // CHECK-DAG: %[[TYLE:.+]] = arith.cmpi sle, %[[TESTY]], %[[C13]] : index // CHECK-DAG: %[[TY:.+]] = arith.andi %[[TYGE]], %[[TYLE]] : i1 // CHECK: %[[R3:.+]] = scf.if {{.+}} -> (f32) { // CHECK: %[[XX0:.+]] = xla_gpu.apply_indexing // CHECK-SAME: affine_map<(d0)[s0] -> (d0 + s0 - 1)> - // CHECK-SAME: (%[[W]] in [0, 7])[%[[X]] in [0, 2]] + // CHECK-SAME: (%[[W]] in [0, 8))[%[[X]] in [0, 3)] // CHECK: %[[XX1:.+]] = xla_gpu.apply_indexing // CHECK-SAME: affine_map<(d0)[s0] -> (d0 + s0 - 2)> - // CHECK-SAME: (%[[H]] in [0, 11])[%[[Y]] in [0, 4]] + // CHECK-SAME: (%[[H]] in [0, 12))[%[[Y]] in [0, 5)] // CHECK-DAG: %[[VL:.+]] = tensor.extract %[[LHS]][%[[B]], %[[XX0]], %[[XX1]], %[[I]]] : tensor<2x8x12x4xf32> // CHECK-DAG: %[[VR:.+]] = tensor.extract %[[RHS]][%[[I]], %[[X]], %[[Y]], %[[O]]] : tensor<4x3x5x16xf32> // CHECK: %[[MUL:.+]] = arith.mulf %[[VL]], %[[VR]] : f32 @@ -957,17 +957,17 @@ TEST_F(ElementalHloToMlirTest, ConvolutionWithLhsDilation) { // CHECK: %[[R0:.+]] = scf.for %[[X:.+]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[A0:.+]] = %[[INIT]]) -> (f32) { // CHECK-NEXT: %[[R1:.+]] = scf.for %[[Y:.+]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[A1:.+]] = %[[A0]]) -> (f32) { // CHECK-NEXT: %[[R2:.+]] = scf.for %[[I:.+]] = %[[C0]] to %[[C4]] step %[[C1]] iter_args(%[[ACC:.+]] = %[[A1]]) -> (f32) { - // CHECK-DAG: %[[TESTX:.+]] = xla_gpu.apply_indexing affine_map<(d0)[s0] -> ((d0 + s0) mod 2)>(%[[W]] in [0, 12])[%[[X]] in [0, 2]] + // CHECK-DAG: %[[TESTX:.+]] = xla_gpu.apply_indexing affine_map<(d0)[s0] -> ((d0 + s0) mod 2)>(%[[W]] in [0, 13))[%[[X]] in [0, 3)] // CHECK-DAG: %[[TX:.+]] = arith.cmpi eq, %[[TESTX]], %[[C0]] : index - // CHECK-DAG: %[[TESTY:.+]] = xla_gpu.apply_indexing affine_map<(d0)[s0] -> ((d0 + s0) mod 2)>(%[[H]] in [0, 18])[%[[Y]] in [0, 4]] + // CHECK-DAG: %[[TESTY:.+]] = xla_gpu.apply_indexing 
affine_map<(d0)[s0] -> ((d0 + s0) mod 2)>(%[[H]] in [0, 19))[%[[Y]] in [0, 5)] // CHECK-DAG: %[[TY:.+]] = arith.cmpi eq, %[[TESTY]], %[[C0]] : index // CHECK: %[[R3:.+]] = scf.if {{.+}} -> (f32) { // CHECK: %[[XX0:.+]] = xla_gpu.apply_indexing // CHECK-SAME: affine_map<(d0)[s0] -> ((d0 + s0) floordiv 2)> - // CHECK-SAME: (%[[W]] in [0, 12])[%[[X]] in [0, 2]] + // CHECK-SAME: (%[[W]] in [0, 13))[%[[X]] in [0, 3)] // CHECK: %[[XX1:.+]] = xla_gpu.apply_indexing // CHECK-SAME: affine_map<(d0)[s0] -> ((d0 + s0) floordiv 2)> - // CHECK-SAME: (%[[H]] in [0, 18])[%[[Y]] in [0, 4]] + // CHECK-SAME: (%[[H]] in [0, 19))[%[[Y]] in [0, 5)] // CHECK-DAG: %[[VL:.+]] = tensor.extract %[[LHS]][%[[B]], %[[XX0]], %[[XX1]], %[[I]]] : tensor<2x8x12x4xf32> // CHECK-DAG: %[[VR:.+]] = tensor.extract %[[RHS]][%[[I]], %[[X]], %[[Y]], %[[O]]] : tensor<4x3x5x16xf32> // CHECK: %[[MUL:.+]] = arith.mulf %[[VL]], %[[VR]] : f32 @@ -1010,10 +1010,10 @@ TEST_F(ElementalHloToMlirTest, ConvolutionWithRhsDilation) { // CHECK: %[[R3:.+]] = scf.if {{.+}} -> (f32) { // CHECK: %[[XX0:.+]] = xla_gpu.apply_indexing // CHECK-SAME: affine_map<(d0)[s0] -> (d0 + s0 * 2)> - // CHECK-SAME: (%[[W]] in [0, 3])[%[[X]] in [0, 2]] + // CHECK-SAME: (%[[W]] in [0, 4))[%[[X]] in [0, 3)] // CHECK: %[[XX1:.+]] = xla_gpu.apply_indexing // CHECK-SAME: affine_map<(d0)[s0] -> (d0 + s0 * 2)> - // CHECK-SAME: (%[[H]] in [0, 3])[%[[Y]] in [0, 4]] + // CHECK-SAME: (%[[H]] in [0, 4))[%[[Y]] in [0, 5)] // CHECK-DAG: %[[VL:.+]] = tensor.extract %[[LHS]][%[[B]], %[[XX0]], %[[XX1]], %[[I]]] : tensor<2x8x12x4xf32> // CHECK-DAG: %[[VR:.+]] = tensor.extract %[[RHS]][%[[I]], %[[X]], %[[Y]], %[[O]]] : tensor<4x3x5x16xf32> // CHECK: %[[MUL:.+]] = arith.mulf %[[VL]], %[[VR]] : f32 @@ -1056,16 +1056,16 @@ TEST_F(ElementalHloToMlirTest, ConvolutionWithFeatureGroupCount) { // CHECK: %[[R3:.+]] = scf.if {{.+}} -> (f32) { // CHECK: %[[XX0:.+]] = xla_gpu.apply_indexing // CHECK-SAME: affine_map<(d0)[s0] -> (d0 + s0)> - // CHECK-SAME: (%[[W]] in [0, 
5]) - // CHECK-SAME: [%[[X]] in [0, 2]] + // CHECK-SAME: (%[[W]] in [0, 6)) + // CHECK-SAME: [%[[X]] in [0, 3)] // CHECK: %[[XX1:.+]] = xla_gpu.apply_indexing // CHECK-SAME: affine_map<(d0)[s0] -> (d0 + s0)> - // CHECK-SAME: (%[[H]] in [0, 7]) - // CHECK-SAME: [%[[Y]] in [0, 4]] + // CHECK-SAME: (%[[H]] in [0, 8)) + // CHECK-SAME: [%[[Y]] in [0, 5)] // CHECK: %[[XX2:.+]] = xla_gpu.apply_indexing // CHECK-SAME: affine_map<(d0)[s0] -> ((d0 floordiv 8) * 2 + s0)> - // CHECK-SAME: (%[[O]] in [0, 15]) - // CHECK-SAME: [%[[I]] in [0, 1]] + // CHECK-SAME: (%[[O]] in [0, 16)) + // CHECK-SAME: [%[[I]] in [0, 2)] // CHECK-DAG: %[[VL:.+]] = tensor.extract %[[LHS]][%[[B]], %[[XX0]], %[[XX1]], %[[XX2]]] : tensor<2x8x12x4xf32> // CHECK-DAG: %[[VR:.+]] = tensor.extract %[[RHS]][%[[I]], %[[X]], %[[Y]], %[[O]]] : tensor<2x3x5x16xf32> // CHECK: %[[MUL:.+]] = arith.mulf %[[VL]], %[[VR]] : f32 @@ -1110,12 +1110,12 @@ TEST_F(ElementalHloToMlirTest, ConvolutionWithBatchGroupCount) { // CHECK: %[[R4:.+]] = scf.if {{.+}} -> (f32) { // CHECK: %[[XX0:.+]] = xla_gpu.apply_indexing // CHECK-SAME: affine_map<(d0)[s0] -> (d0 + s0)> - // CHECK-SAME: (%[[W]] in [0, 5]) - // CHECK-SAME: [%[[X]] in [0, 2]] + // CHECK-SAME: (%[[W]] in [0, 6)) + // CHECK-SAME: [%[[X]] in [0, 3)] // CHECK: %[[XX1:.+]] = xla_gpu.apply_indexing // CHECK-SAME: affine_map<(d0)[s0] -> (d0 + s0)> - // CHECK-SAME: (%[[H]] in [0, 7]) - // CHECK-SAME: [%[[Y]] in [0, 4]] + // CHECK-SAME: (%[[H]] in [0, 8)) + // CHECK-SAME: [%[[Y]] in [0, 5)] // CHECK-DAG: %[[VL:.+]] = tensor.extract %[[LHS]][%[[G]], %[[XX0]], %[[XX1]], %[[I]]] : tensor<2x8x12x4xf32> // CHECK-DAG: %[[VR:.+]] = tensor.extract %[[RHS]][%[[I]], %[[X]], %[[Y]], %[[O]]] : tensor<4x3x5x16xf32> // CHECK: %[[MUL:.+]] = arith.mulf %[[VL]], %[[VR]] : f32 @@ -1557,7 +1557,7 @@ TEST_F(ElementalHloToMlirTest, MixedIndexingTuple) { // CHECK: %[[A:.*]] = tensor.extract %[[P0]][%[[X]], %[[Y]]] // CHECK: %[[IDX:.*]] = xla_gpu.apply_indexing // CHECK-SAME: affine_map<(d0, d1) -> 
(d0 * 10 + d1)> - // CHECK-SAME: (%[[X]] in [0, 9], %[[Y]] in [0, 9]) + // CHECK-SAME: (%[[X]] in [0, 10), %[[Y]] in [0, 10)) // CHECK: %[[B:.*]] = tensor.extract %[[P1]][%[[IDX]]] // CHECK: return %[[A]], %[[B]] )")); @@ -1581,7 +1581,7 @@ TEST_F(ElementalHloToMlirTest, NestedTuple) { // CHECK: %[[P0_V:.*]] = xla_gpu.pure_call @main_p0 // CHECK: %[[IDX:.*]] = // CHECK-SAME: affine_map<(d0, d1) -> (d0 * 10 + d1)> - // CHECK-SAME: (%[[X]] in [0, 9], %[[Y]] in [0, 9]) + // CHECK-SAME: (%[[X]] in [0, 10), %[[Y]] in [0, 10)) // CHECK: %[[P1_V:.*]] = xla_gpu.pure_call @main_p1 // CHECK-SAME: (%[[P0]], %[[P1]], %[[IDX]]) // CHECK: return %[[P0_V]], %[[P1_V]], %[[P1_V]], %[[P1_V]], %[[P0_V]] diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.cc b/third_party/xla/xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.cc index 88af1769ac4e02..efa683a4c044a4 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.cc +++ b/third_party/xla/xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.cc @@ -243,12 +243,12 @@ mlir::ParseResult parseOperandsWithBoundsList( if (parser.parseOperand(operand) || parser.parseKeyword("in") || parser.parseLSquare() || parser.parseInteger(lower_bound) || parser.parseComma() || parser.parseInteger(upper_bound) || - parser.parseRSquare()) { + parser.parseRParen()) { return failure(); } operands->push_back(operand); lower_bounds->push_back(lower_bound); - upper_bounds->push_back(upper_bound); + upper_bounds->push_back(upper_bound - 1); return success(); })) { return failure(); @@ -309,7 +309,7 @@ void ApplyIndexingOp::print(mlir::OpAsmPrinter& p) { p << '('; for (int dim_id = 0; dim_id < num_dimensions; ++dim_id) { p << operands[dim_id] << " in " << '[' << lower_bounds[dim_id] << ", " - << upper_bounds[dim_id] << ']'; + << upper_bounds[dim_id] + 1 << ')'; if (dim_id != num_dimensions - 1) { p << ", "; } @@ -322,7 +322,7 @@ void ApplyIndexingOp::print(mlir::OpAsmPrinter& p) { for (int symbol_id = 0; symbol_id < num_symbols; 
++symbol_id) { unsigned operand_id = num_dimensions + symbol_id; p << operands[operand_id] << " in " << '[' << lower_bounds[operand_id] - << ", " << upper_bounds[operand_id] << ']'; + << ", " << upper_bounds[operand_id] + 1 << ')'; if (symbol_id != num_symbols - 1) { p << ", "; } diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/tests/canonicalize.mlir b/third_party/xla/xla/service/gpu/fusions/mlir/tests/canonicalize.mlir index ded1cf683f3ffd..337ce1563e8e8f 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/tests/canonicalize.mlir +++ b/third_party/xla/xla/service/gpu/fusions/mlir/tests/canonicalize.mlir @@ -2,14 +2,14 @@ #map0 = affine_map<()[s0, s1] -> (1 + s0 + s1 mod 3 - s1, s0 mod 2)> func.func @simplify_apply_indexing(%s0: index, %s1: index) -> (index, index) { - %0:2 = xla_gpu.apply_indexing #map0 [%s0 in [-10, 10], %s1 in [0, 2]] + %0:2 = xla_gpu.apply_indexing #map0 [%s0 in [-10, 10), %s1 in [0, 3)] func.return %0#0, %0#1 : index, index } // CHECK: #[[$MAP:.*]] = affine_map<()[s0] -> (s0 + 1, s0 mod 2)> // CHECK-LABEL: func.func @simplify_apply_indexing // CHECK-SAME: %[[ARG_0:.*]]: index, %[[ARG_1:.*]]: index) -// CHECK: xla_gpu.apply_indexing #[[$MAP]][%[[ARG_0]] in [-10, 10]] +// CHECK: xla_gpu.apply_indexing #[[$MAP]][%[[ARG_0]] in [-10, 10)] // ----- @@ -17,8 +17,8 @@ func.func @simplify_apply_indexing(%s0: index, %s1: index) -> (index, index) { func.func @simplify_apply_indexing_remove_dims(%d0: index, %d1: index, %d2: index, %s0: index, %s1: index) -> (index, index, index) { %0:3 = xla_gpu.apply_indexing #map0 - (%d0 in [0, 1], %d1 in [0, 2], %d2 in [0, 3]) - [%s0 in [-11, 11], %s1 in [0, 3]] + (%d0 in [0, 2), %d1 in [0, 3), %d2 in [0, 4)) + [%s0 in [-11, 11), %s1 in [0, 4)] func.return %0#0, %0#1, %0#2 : index, index, index } // CHECK: #[[$MAP:.*]] = affine_map<(d0, d1)[s0] -> (s0 + 1, s0 mod 2, d0 + d1)> @@ -30,15 +30,15 @@ func.func @simplify_apply_indexing_remove_dims(%d0: index, %d1: index, // CHECK-SAME: 
%[[ARG_3:[a-zA-Z0-9_]+]]: index, // CHECK-SAME: %[[ARG_4:[a-zA-Z0-9_]+]]: index) // CHECK: xla_gpu.apply_indexing #[[$MAP]] -// CHECK-SAME: (%[[ARG_0]] in [0, 1], %[[ARG_2]] in [0, 3]) -// CHECK-SAME: [%[[ARG_3]] in [-11, 11]] +// CHECK-SAME: (%[[ARG_0]] in [0, 2), %[[ARG_2]] in [0, 4)) +// CHECK-SAME: [%[[ARG_3]] in [-11, 11)] // ----- #map0 = affine_map<(d0, d1)[s0] -> (d0 + s0, 4, d1, 1, s0)> func.func @fold_indexing_map_results(%d0: index, %d1: index, %s0: index) -> (index, index, index, index, index) { - %0:5 = xla_gpu.apply_indexing #map0 (%d0 in [-10, 10], %d1 in [0, 2])[%s0 in [-1, 1]] + %0:5 = xla_gpu.apply_indexing #map0 (%d0 in [-10, 10), %d1 in [0, 3))[%s0 in [-1, 2)] func.return %0#0, %0#1, %0#2, %0#3, %0#4 : index, index, index, index, index } // CHECK: #[[$MAP:.*]] = affine_map<(d0)[s0] -> (d0 + s0)> @@ -56,7 +56,7 @@ func.func @fold_indexing_map_results(%d0: index, %d1: index, %s0: index) #map0 = affine_map<(d0, d1)[s0] -> (d0 + s0, s0 + 4, d1 mod 2, 1 + d1, s0)> func.func @remove_unused_results(%d0: index, %d1: index, %s0: index) -> (index) { - %0:5 = xla_gpu.apply_indexing #map0 (%d0 in [-10, 10], %d1 in [0, 2])[%s0 in [-1, 1]] + %0:5 = xla_gpu.apply_indexing #map0 (%d0 in [-10, 10), %d1 in [0, 3))[%s0 in [-1, 2)] func.return %0#2 : index } // CHECK: #[[$MAP:.*]] = affine_map<(d0) -> (d0 mod 2)> @@ -64,7 +64,7 @@ func.func @remove_unused_results(%d0: index, %d1: index, %s0: index) -> (index) // CHECK-LABEL: func.func @remove_unused_results // CHECK-SAME: %[[ARG_0:.*]]: index, %[[ARG_1:.*]]: index, %[[ARG_2:.*]]: index) -// CHECK: %[[NEW_RESULT:.*]] = xla_gpu.apply_indexing #[[$MAP]](%[[ARG_1]] in [0, 2]) +// CHECK: %[[NEW_RESULT:.*]] = xla_gpu.apply_indexing #[[$MAP]](%[[ARG_1]] in [0, 3)) // CHECK: return %[[NEW_RESULT]] // ----- @@ -74,22 +74,22 @@ func.func @fold_operands(%d0: index) -> index { %d1 = arith.constant 1 : index %s0 = arith.constant 2 : index %s1 = arith.constant 3 : index - %0 = xla_gpu.apply_indexing #map0 (%d0 in [0, 10], %d1 in 
[0, 5]) - [%s0 in [-10, 10], %s1 in [0, 4]] + %0 = xla_gpu.apply_indexing #map0 (%d0 in [0, 11), %d1 in [0, 6)) + [%s0 in [-10, 10), %s1 in [0, 5)] func.return %0 : index } // CHECK: #[[$MAP:.*]] = affine_map<(d0) -> (d0 + 3)> // CHECK-LABEL: func.func @fold_operands // CHECK-SAME: %[[ARG_0:.*]]: index) -// CHECK: xla_gpu.apply_indexing #[[$MAP]](%[[ARG_0]] in [0, 10]) +// CHECK: xla_gpu.apply_indexing #[[$MAP]](%[[ARG_0]] in [0, 11)) // ----- func.func @fold_operands_and_results(%arg0: index, %arg1: index) -> (index, index) { %0:2 = xla_gpu.apply_indexing affine_map<(d0, d1) -> (0, d1)> - (%arg0 in [0, 4], %arg1 in [0, 5]) + (%arg0 in [0, 5), %arg1 in [0, 6)) return %0#0, %0#1 : index, index } @@ -102,9 +102,9 @@ func.func @fold_operands_and_results(%arg0: index, %arg1: index) func.func @fold_sequence(%arg0: index, %arg1: index) -> index { %0 = xla_gpu.apply_indexing affine_map<(d0, d1) -> (d0 + d1)> - (%arg0 in [0, 5], %arg1 in [0, 4]) + (%arg0 in [0, 6), %arg1 in [0, 5)) %1 = xla_gpu.apply_indexing affine_map<(d0) -> (d0 mod 100 + 42)> - (%0 in [0, 10000]) + (%0 in [0, 10001)) func.return %1 : index } @@ -112,15 +112,15 @@ func.func @fold_sequence(%arg0: index, %arg1: index) -> index { // CHECK-LABEL: func.func @fold_sequence // CHECK-SAME: %[[ARG0:.*]]: index, %[[ARG1:.*]]: index) // CHECK-NEXT: xla_gpu.apply_indexing #[[$MAP]] -// CHECK-SAME: (%[[ARG0]] in [0, 5], %[[ARG1]] in [0, 4]) +// CHECK-SAME: (%[[ARG0]] in [0, 6), %[[ARG1]] in [0, 5)) // ----- func.func @fold_sequence_sym(%arg0: index, %arg1: index) -> index { %0 = xla_gpu.apply_indexing affine_map<(d0, d1) -> (d0 + d1)> - (%arg0 in [0, 5], %arg1 in [0, 4]) + (%arg0 in [0, 6), %arg1 in [0, 5)) %1 = xla_gpu.apply_indexing affine_map<()[s0] -> (s0 mod 100 + 42)> - [%0 in [0, 10000]] + [%0 in [0, 10001)] func.return %1 : index } @@ -128,15 +128,15 @@ func.func @fold_sequence_sym(%arg0: index, %arg1: index) -> index { // CHECK-LABEL: func.func @fold_sequence_sym // CHECK-SAME: %[[ARG0:.*]]: index, 
%[[ARG1:.*]]: index) // CHECK-NEXT: xla_gpu.apply_indexing #[[$MAP]] -// CHECK-SAME: (%[[ARG0]] in [0, 5], %[[ARG1]] in [0, 4]) +// CHECK-SAME: (%[[ARG0]] in [0, 6), %[[ARG1]] in [0, 5)) // ----- func.func @fold_sequence_shared_operands(%arg0: index, %arg1: index) -> index { %0 = xla_gpu.apply_indexing affine_map<(d0, d1) -> (d0 + d1)> - (%arg0 in [0, 5], %arg1 in [0, 4]) + (%arg0 in [0, 6), %arg1 in [0, 5)) %1 = xla_gpu.apply_indexing affine_map<(d0, d1) -> (d0 + d1)> - (%arg1 in [0, 4], %0 in [0, 10000]) + (%arg1 in [0, 5), %0 in [0, 10001)) func.return %1 : index } @@ -144,4 +144,4 @@ func.func @fold_sequence_shared_operands(%arg0: index, %arg1: index) -> index { // CHECK-LABEL: func.func @fold_sequence_shared_operands // CHECK-SAME: %[[ARG0:.*]]: index, %[[ARG1:.*]]: index) // CHECK-NEXT: xla_gpu.apply_indexing #[[$MAP]] -// CHECK-SAME: (%[[ARG1]] in [0, 4], %[[ARG0]] in [0, 5]) +// CHECK-SAME: (%[[ARG1]] in [0, 5), %[[ARG0]] in [0, 6)) diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/tests/invalid.mlir b/third_party/xla/xla/service/gpu/fusions/mlir/tests/invalid.mlir index fcc6c6d6d7b129..55aa9c78512034 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/tests/invalid.mlir +++ b/third_party/xla/xla/service/gpu/fusions/mlir/tests/invalid.mlir @@ -3,6 +3,6 @@ #map0 = affine_map<(d0, d1)[s0] -> (d0, d1 + s0)> func.func @apply_indexing(%d0: index, %d1: index, %s0: index) -> (index, index) { // expected-error @+1 {{operand, lower_bounds, upper_bounds count and affine map dimension and symbol count must match}} - %0:2 = xla_gpu.apply_indexing #map0 (%d0 in [0, 2]) + %0:2 = xla_gpu.apply_indexing #map0 (%d0 in [0, 3)) func.return %0#0, %0#1 : index, index -} \ No newline at end of file +} diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/tests/lower_tensors.mlir b/third_party/xla/xla/service/gpu/fusions/mlir/tests/lower_tensors.mlir index ba03eb8e047c0f..22f579e8f44424 100644 --- 
a/third_party/xla/xla/service/gpu/fusions/mlir/tests/lower_tensors.mlir +++ b/third_party/xla/xla/service/gpu/fusions/mlir/tests/lower_tensors.mlir @@ -91,7 +91,7 @@ module { // CHECK-SAME: %[[ARG0:.*]]: !llvm.ptr, // CHECK-SAME: %[[X:.*]]: index, %[[Y:.*]]: index // CHECK: %[[IDX:.*]] = xla_gpu.apply_indexing #[[$MAP]] -// CHECK-SAME: (%[[X]] in [0, 1], %[[Y]] in [0, 2]) +// CHECK-SAME: (%[[X]] in [0, 2), %[[Y]] in [0, 3)) // CHECK: %[[IDX_CAST:.*]] = arith.index_castui %[[IDX]] : index to i64 // CHECK: %[[PTR:.*]] = llvm.getelementptr inbounds %[[ARG0]][%[[IDX_CAST]]] // CHECK: llvm.load %[[PTR]] diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/tests/ops.mlir b/third_party/xla/xla/service/gpu/fusions/mlir/tests/ops.mlir index ce571070c0bb6e..da3ef936395b72 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/tests/ops.mlir +++ b/third_party/xla/xla/service/gpu/fusions/mlir/tests/ops.mlir @@ -58,7 +58,7 @@ func.func @caller(%a: f32, %b: f32) -> f32 { #map0 = affine_map<(d0, d1)[s0] -> (d0, d1 + s0)> func.func @apply_indexing(%d0: index, %d1: index, %s0: index) -> (index, index) { - %0:2 = xla_gpu.apply_indexing #map0 (%d0 in [0, 2], %d1 in [1, 3])[%s0 in [2, 4]] + %0:2 = xla_gpu.apply_indexing #map0 (%d0 in [0, 3), %d1 in [1, 4))[%s0 in [2, 5)] func.return %0#0, %0#1 : index, index } // CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1)[s0] -> (d0, d1 + s0)> @@ -66,13 +66,13 @@ func.func @apply_indexing(%d0: index, %d1: index, %s0: index) -> (index, index) // CHECK-LABEL: @apply_indexing // CHECK: (%[[d0:.*]]: index, %[[d1:.*]]: index, %[[s0:.*]]: index) // CHECK: xla_gpu.apply_indexing #[[$MAP0]] -// CHECK-SAME: (%[[d0]] in [0, 2], %[[d1]] in [1, 3])[%[[s0]] in [2, 4]] +// CHECK-SAME: (%[[d0]] in [0, 3), %[[d1]] in [1, 4))[%[[s0]] in [2, 5)] // ----- #map0 = affine_map<(d0, d1) -> (d0, d1)> func.func @apply_indexing_no_symbols(%d0: index, %d1: index) -> (index, index) { - %0:2 = xla_gpu.apply_indexing #map0 (%d0 in [0, 2], %d1 in [1, 3]) + %0:2 = 
xla_gpu.apply_indexing #map0 (%d0 in [0, 3), %d1 in [1, 4)) func.return %0#0, %0#1 : index, index } // CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)> @@ -80,17 +80,17 @@ func.func @apply_indexing_no_symbols(%d0: index, %d1: index) -> (index, index) { // CHECK-LABEL: @apply_indexing_no_symbols // CHECK: (%[[d0:.*]]: index, %[[d1:.*]]: index) // CHECK: xla_gpu.apply_indexing #[[$MAP0]] -// CHECK-SAME: (%[[d0]] in [0, 2], %[[d1]] in [1, 3]) +// CHECK-SAME: (%[[d0]] in [0, 3), %[[d1]] in [1, 4)) // ----- #map0 = affine_map<()[s0] -> (s0, s0)> func.func @apply_indexing_no_dims(%s0: index) -> (index, index) { - %0:2 = xla_gpu.apply_indexing #map0 [%s0 in [2, 4]] + %0:2 = xla_gpu.apply_indexing #map0 [%s0 in [2, 5)] func.return %0#0, %0#1 : index, index } // CHECK: #[[$MAP0:.*]] = affine_map<()[s0] -> (s0, s0)> // CHECK-LABEL: @apply_indexing_no_dims // CHECK: (%[[s0:.*]]: index) -// CHECK: xla_gpu.apply_indexing #[[$MAP0]][%[[s0]] in [2, 4]] \ No newline at end of file +// CHECK: xla_gpu.apply_indexing #[[$MAP0]][%[[s0]] in [2, 5)] diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/tests/optimize_loops.mlir b/third_party/xla/xla/service/gpu/fusions/mlir/tests/optimize_loops.mlir index 6f903f3ace4748..1f173fa26b1d47 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/tests/optimize_loops.mlir +++ b/third_party/xla/xla/service/gpu/fusions/mlir/tests/optimize_loops.mlir @@ -21,23 +21,23 @@ module { %1 = arith.cmpi eq, %0, %c0 : index %2 = arith.divui %thread_id_x, %c32 : index %3 = arith.cmpi ult, %thread_id_x, %c8 : index - %4 = xla_gpu.apply_indexing #map(%block_id_x in [0, 31]) - %5 = xla_gpu.apply_indexing #map1(%block_id_x in [0, 31]) + %4 = xla_gpu.apply_indexing #map(%block_id_x in [0, 32)) + %5 = xla_gpu.apply_indexing #map1(%block_id_x in [0, 32)) %extracted = tensor.extract %arg2[%4, %5] : tensor<4x8xf32> %6 = arith.mulf %extracted, %cst : f32 %7 = arith.addf %6, %cst : f32 %8 = math.rsqrt %7 : f32 %9:2 = scf.for %arg7 = %c0 to %c8 step %c1 
iter_args(%arg8 = %arg6, %arg9 = %cst) -> (tensor<4x8x4096xf32>, f32) { - %18 = xla_gpu.apply_indexing #map2(%c0 in [0, 1], %thread_id_x in [0, 255])[%arg7 in [0, 7]] + %18 = xla_gpu.apply_indexing #map2(%c0 in [0, 2), %thread_id_x in [0, 256))[%arg7 in [0, 8)] %19 = vector.transfer_read %arg1[%18], %cst_1 {in_bounds = [true]} : tensor<4096xbf16>, vector<2xbf16> - %20 = xla_gpu.apply_indexing #map2(%c0 in [0, 1], %thread_id_x in [0, 255])[%arg7 in [0, 7]] + %20 = xla_gpu.apply_indexing #map2(%c0 in [0, 2), %thread_id_x in [0, 256))[%arg7 in [0, 8)] %21 = vector.transfer_read %arg3[%20], %cst_1 {in_bounds = [true]} : tensor<4096xbf16>, vector<2xbf16> - %22 = xla_gpu.apply_indexing #map2(%c0 in [0, 1], %thread_id_x in [0, 255])[%arg7 in [0, 7]] + %22 = xla_gpu.apply_indexing #map2(%c0 in [0, 2), %thread_id_x in [0, 256))[%arg7 in [0, 8)] %23 = vector.transfer_read %arg4[%4, %5, %22], %cst_1 {in_bounds = [true]} : tensor<4x8x4096xbf16>, vector<2xbf16> - %24 = xla_gpu.apply_indexing #map2(%c0 in [0, 1], %thread_id_x in [0, 255])[%arg7 in [0, 7]] + %24 = xla_gpu.apply_indexing #map2(%c0 in [0, 2), %thread_id_x in [0, 256))[%arg7 in [0, 8)] %25 = vector.transfer_read %arg0[%4, %5, %24], %cst {in_bounds = [true]} : tensor<4x8x4096xf32>, vector<2xf32> %26:2 = scf.for %arg10 = %c0 to %c2 step %c1 iter_args(%arg11 = %arg8, %arg12 = %arg9) -> (tensor<4x8x4096xf32>, f32) { - %27 = xla_gpu.apply_indexing #map2(%arg10 in [0, 1], %thread_id_x in [0, 255])[%arg7 in [0, 7]] + %27 = xla_gpu.apply_indexing #map2(%arg10 in [0, 2), %thread_id_x in [0, 256))[%arg7 in [0, 8)] %28 = vector.extract %25[%arg10] : f32 from vector<2xf32> %29 = vector.extract %23[%arg10] : bf16 from vector<2xbf16> %30 = arith.extf %29 : bf16 to f32 @@ -151,7 +151,7 @@ module { %cst = arith.constant dense<[0.0, 0.0]> : vector<2xf32> %cst0 = arith.constant 0.0 : f32 %ret = scf.for %i = %c0 to %c17 step %c1 iter_args (%iter = %cst) -> (vector<2xf32>) { - %base = xla_gpu.apply_indexing affine_map<(d0) -> (d0 * 
2)>(%i in [0, 15]) + %base = xla_gpu.apply_indexing affine_map<(d0) -> (d0 * 2)>(%i in [0, 16)) %val = vector.transfer_read %arg[%base], %cst0 : tensor<34xf32>, vector<2xf32> %log = math.log %val : vector<2xf32> %add = arith.addf %log, %iter : vector<2xf32> diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/tests/simplify_affine.mlir b/third_party/xla/xla/service/gpu/fusions/mlir/tests/simplify_affine.mlir index ec1a726da9db13..acd6cc097f9e63 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/tests/simplify_affine.mlir +++ b/third_party/xla/xla/service/gpu/fusions/mlir/tests/simplify_affine.mlir @@ -63,7 +63,7 @@ func.func @op_and_for_ranges(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: !llvm.pt %1 = gpu.block_id x scf.for %i = %c0 to %c4 step %c1 { %2 = xla_gpu.apply_indexing affine_map<()[s0, s1, s2] -> (s0 * 512 + s1 * 4 + s2 + (s1 floordiv 128) + (s2 floordiv 4))> - [%1 in [0, 3071], %0 in [0, 127], %i in [0, 3]] + [%1 in [0, 3072), %0 in [0, 128), %i in [0, 4)] %3 = arith.index_castui %2 : index to i64 %4 = llvm.getelementptr %arg0[%3] : (!llvm.ptr, i64) -> !llvm.ptr, f32 %5 = llvm.load %4 invariant : !llvm.ptr -> f32 @@ -92,7 +92,7 @@ func.func @op_and_for_ranges(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: !llvm.pt func.func @arg_ranges(%arg0: index, %arg1: index) -> index { %0 = xla_gpu.apply_indexing affine_map<()[s0, s1] -> (s0 floordiv 100 + s1 floordiv 100)> - [%arg0 in [0, 42], %arg1 in [0, 1000]] + [%arg0 in [0, 43), %arg1 in [0, 1001)] return %0 : index } @@ -106,7 +106,7 @@ func.func @arg_ranges(%arg0: index, %arg1: index) -> index { func.func @cant_lower(%arg0: index, %arg1: index) -> (index, index) { %0:2 = xla_gpu.apply_indexing affine_map<()[s0, s1] -> (s0 floordiv 100 + s1 floordiv 100, s0 + s1)> - [%arg0 in [-10, 42], %arg1 in [0, 1000]] + [%arg0 in [-10, 43), %arg1 in [0, 1001)] return %0#0, %0#1 : index, index } @@ -124,7 +124,7 @@ func.func @order_summands(%arg1: index) { scf.for %arg3 = %c0 to %c4 step %c1 { %0 = 
xla_gpu.apply_indexing affine_map<()[s0, s1, s2] -> ((s0 + s1) floordiv 3 + s0 * 512 + s1 * 4 + s2 * 10)> - [%arg2 in [0, 3], %arg1 in [0, 3], %arg3 in [0, 3]] + [%arg2 in [0, 4), %arg1 in [0, 4), %arg3 in [0, 4)] "dummy.op"(%0) : (index) -> () } } diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/tests/vectorize_loads_stores.mlir b/third_party/xla/xla/service/gpu/fusions/mlir/tests/vectorize_loads_stores.mlir index 1141d1581505ea..fb944a5dcb923d 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/tests/vectorize_loads_stores.mlir +++ b/third_party/xla/xla/service/gpu/fusions/mlir/tests/vectorize_loads_stores.mlir @@ -10,7 +10,7 @@ module { %cst = arith.constant 0.0 : f32 %outer = scf.for %i = %c0 to %c64 step %c1 iter_args(%iter = %cst) -> f32 { %inner = scf.for %j = %c0 to %c2 step %c1 iter_args(%iter1 = %iter) -> f32 { - %idx = xla_gpu.apply_indexing #map(%i in [0, 63])[%j in [0, 1]] + %idx = xla_gpu.apply_indexing #map(%i in [0, 64))[%j in [0, 2)] %extracted = tensor.extract %arg0[%idx] : tensor<128xf32> %added = arith.addf %iter1, %extracted : f32 scf.yield %added : f32 @@ -29,7 +29,7 @@ module { // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index // CHECK-DAG: %[[C64:.*]] = arith.constant 64 : index // CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[C64]] step %[[C1]] iter_args(%[[ITER:.*]] = -// CHECK: %[[BASE:.*]] = xla_gpu.apply_indexing #map(%[[I]] in [0, 63]) +// CHECK: %[[BASE:.*]] = xla_gpu.apply_indexing #map(%[[I]] in [0, 64)) // CHECK-NEXT: %[[V:.*]] = vector.transfer_read %[[ARG0]][%[[BASE]]] // CHECK-NEXT: scf.for %[[J:.*]] = %[[C0]] // CHECK-NEXT: vector.extract %[[V]][%[[J]]] @@ -76,7 +76,7 @@ module { %cst = arith.constant 0.0 : f32 %outer = scf.for %i = %c0 to %c63 step %c1 iter_args(%iter = %cst) -> f32 { %inner = scf.for %j = %c0 to %c2 step %c1 iter_args(%iter1 = %iter) -> f32 { - %idx = xla_gpu.apply_indexing #map(%i in [0, 63])[%j in [0, 1]] + %idx = xla_gpu.apply_indexing #map(%i in [0, 64))[%j in [0, 2)] %extracted = 
tensor.extract %arg0[%idx] : tensor<128xf32> %added = arith.addf %iter1, %extracted : f32 scf.yield %added : f32 @@ -102,7 +102,7 @@ module { %cst = arith.constant 0.0 : f32 %outer = scf.for %i = %c0 to %c63 step %c1 iter_args(%iter = %cst) -> f32 { %inner = scf.for %j = %c0 to %c2 step %c1 iter_args(%iter1 = %iter) -> f32 { - %idx = xla_gpu.apply_indexing #map(%i in [0, 63])[%j in [0, 1]] + %idx = xla_gpu.apply_indexing #map(%i in [0, 64))[%j in [0, 2)] %extracted = tensor.extract %arg0[%idx] : tensor<128xf32> %added = arith.addf %iter1, %extracted : f32 scf.yield %added : f32 @@ -152,7 +152,7 @@ module { %cst = arith.constant 0.0 : f32 %outer = scf.for %i = %c0 to %c63 step %c1 iter_args(%iter = %cst) -> f32 { %inner = scf.for %j = %c0 to %c2 step %c1 iter_args(%iter1 = %iter) -> f32 { - %idx = xla_gpu.apply_indexing #map(%i in [0, 63])[%j in [0, 1]] + %idx = xla_gpu.apply_indexing #map(%i in [0, 64))[%j in [0, 2)] %extracted = tensor.extract %arg0[%idx] : tensor<128xf32> %added = arith.addf %iter1, %extracted : f32 scf.yield %added : f32 @@ -312,7 +312,7 @@ module { %extracted1 = tensor.extract %arg2[%arg4] : tensor<32xf32> %0:2 = scf.for %i = %c0 to %c8 step %c1 iter_args(%iter0 = %arg3, %iter1 = %cst) -> (tensor<32x4096xf32>, f32) { %1:2 = scf.for %j = %c0 to %c2 step %c1 iter_args(%iter2 = %iter0, %iter3 = %iter1) -> (tensor<32x4096xf32>, f32) { - %2 = xla_gpu.apply_indexing #map(%j in [0, 1], %arg4 in [0, 255])[%i in [0, 7]] + %2 = xla_gpu.apply_indexing #map(%j in [0, 2), %arg4 in [0, 256))[%i in [0, 8)] %extracted2 = tensor.extract %arg0[%i, %2] : tensor<32x4096xf32> %extracted3 = tensor.extract %arg1[%2] : tensor<4096xbf16> %3 = arith.extf %extracted3 : bf16 to f32 @@ -333,7 +333,7 @@ module { // CHECK-SAME: (%[[ARG0:.*]]: tensor{{.*}}, %[[ARG1:.*]]: tensor{{.*}}, %[[ARG2:.*]]: tensor{{.*}}, %[[ARG3:.*]]: tensor{{.*}}, %[[ARG4:.*]]: index) // CHECK: %[[C0:.*]] = arith.constant 0 : index // CHECK: scf.for %[[I:.*]] = %[[C0]] -// CHECK: %[[BASE:.*]] = 
xla_gpu.apply_indexing #[[$MAP]](%[[ARG4]] in [0, 255])[%[[I]] in [0, 7]] +// CHECK: %[[BASE:.*]] = xla_gpu.apply_indexing #[[$MAP]](%[[ARG4]] in [0, 256))[%[[I]] in [0, 8)] // CHECK: %[[READ1:.*]] = vector.transfer_read %[[ARG1]][%[[BASE]]] // CHECK: %[[READ2:.*]] = vector.transfer_read %[[ARG0]][%[[I]], %[[BASE]]] // CHECK: %[[INNER:.*]]:2 = scf.for %[[J:.*]] = %[[C0]] {{.*}} iter_args(%[[F:.*]] = {{.*}}, %[[V:.*]] = {{.*}}) -> (f32, vector<2xf32>) @@ -360,7 +360,7 @@ module { %cst = arith.constant 0.0 : f32 %outer = scf.for %i = %c0 to %c63 step %c1 iter_args(%iter = %cst) -> f32 { %inner = scf.for %j = %c0 to %c2 step %c1 iter_args(%iter1 = %iter) -> f32 { - %idx = xla_gpu.apply_indexing #map(%i in [0, 63])[%j in [0, 1]] + %idx = xla_gpu.apply_indexing #map(%i in [0, 64))[%j in [0, 2)] %extracted = tensor.extract %arg0[%idx] : tensor<128xf32> %added = arith.addf %iter1, %extracted : f32 scf.yield %added : f32 @@ -390,7 +390,7 @@ module { %cst = arith.constant 0.0 : f32 %outer = scf.for %i = %c0 to %c63 step %c1 iter_args(%iter = %cst) -> f32 { %inner = scf.for %j = %c0 to %c2 step %c1 iter_args(%iter1 = %iter) -> f32 { - %idx = xla_gpu.apply_indexing #map(%i in [0, 63])[%j in [0, 1]] + %idx = xla_gpu.apply_indexing #map(%i in [0, 64))[%j in [0, 2)] %extracted = tensor.extract %arg0[%idx] : tensor<128xf32> %added = arith.addf %iter1, %extracted : f32 scf.yield %added : f32 diff --git a/third_party/xla/xla/service/gpu/fusions/reduction_mlir_test.cc b/third_party/xla/xla/service/gpu/fusions/reduction_mlir_test.cc index 1bc6d48da5363b..35e8656ae0a2e9 100644 --- a/third_party/xla/xla/service/gpu/fusions/reduction_mlir_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/reduction_mlir_test.cc @@ -74,19 +74,19 @@ TEST_F(MlirRowReductionTest, VariadicRowReduce) { (d3 * 2 + d0 floordiv 128) mod 3, (d0 mod 128 + s2 * 128) * 2 + s3) domain: - d0 in [0, 255] - d1 in [0, 0] - d2 in [0, 0] - d3 in [0, 2] - d4 in [0, 0] - d5 in [0, 0] - s0 in [0, 0] - s1 in [0, 0] - s2 in 
[0, 7] - s3 in [0, 1] - s4 in [0, 0] - d0 mod 128 + s2 * 128 in [0, 1023] - d3 * 2 + d0 floordiv 128 in [0, 5] + d0 in [0, 256) + d1 in [0, 1) + d2 in [0, 1) + d3 in [0, 3) + d4 in [0, 1) + d5 in [0, 1) + s0 in [0, 1) + s1 in [0, 1) + s2 in [0, 8) + s3 in [0, 2) + s4 in [0, 1) + d0 mod 128 + s2 * 128 in [0, 1024) + d3 * 2 + d0 floordiv 128 in [0, 6) )")); EXPECT_THAT( fusion.ComputeThreadIdToOutputIndexing(0, &mlir_context_)->ToString(), @@ -94,14 +94,14 @@ TEST_F(MlirRowReductionTest, VariadicRowReduce) { (d0, d1, d2, d3, d4, d5) -> ((d3 * 2 + d0 floordiv 128) floordiv 3, (d3 * 2 + d0 floordiv 128) mod 3) domain: - d0 in [0, 255] - d1 in [0, 0] - d2 in [0, 0] - d3 in [0, 2] - d4 in [0, 0] - d5 in [0, 0] - d0 mod 128 in [0, 0] - d3 * 2 + d0 floordiv 128 in [0, 5] + d0 in [0, 256) + d1 in [0, 1) + d2 in [0, 1) + d3 in [0, 3) + d4 in [0, 1) + d5 in [0, 1) + d0 mod 128 in [0, 1) + d3 * 2 + d0 floordiv 128 in [0, 6) )")); EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); } @@ -242,33 +242,33 @@ TEST_F(MlirRowReductionTest, F64RowReduction) { d3 * 8 + d0 floordiv 32, (d0 mod 32 + s2 * 32) * 2 + s3) domain: - d0 in [0, 255] - d1 in [0, 0] - d2 in [0, 0] - d3 in [0, 12] - d4 in [0, 0] - d5 in [0, 0] - s0 in [0, 0] - s1 in [0, 0] - s2 in [0, 1] - s3 in [0, 1] - s4 in [0, 0] - d0 mod 32 + s2 * 32 in [0, 63] - d3 * 8 + d0 floordiv 32 in [0, 99] + d0 in [0, 256) + d1 in [0, 1) + d2 in [0, 1) + d3 in [0, 13) + d4 in [0, 1) + d5 in [0, 1) + s0 in [0, 1) + s1 in [0, 1) + s2 in [0, 2) + s3 in [0, 2) + s4 in [0, 1) + d0 mod 32 + s2 * 32 in [0, 64) + d3 * 8 + d0 floordiv 32 in [0, 100) )")); EXPECT_THAT( fusion.ComputeThreadIdToOutputIndexing(0, &mlir_context_)->ToString(), MatchIndexingString(R"( (d0, d1, d2, d3, d4, d5) -> (d3 * 8 + d0 floordiv 32) domain: - d0 in [0, 255] - d1 in [0, 0] - d2 in [0, 0] - d3 in [0, 12] - d4 in [0, 0] - d5 in [0, 0] - d0 mod 32 in [0, 0] - d3 * 8 + d0 floordiv 32 in [0, 99] + d0 in [0, 256) + d1 in [0, 1) + d2 in [0, 1) + d3 in 
[0, 13) + d4 in [0, 1) + d5 in [0, 1) + d0 mod 32 in [0, 1) + d3 * 8 + d0 floordiv 32 in [0, 100) )")); // This reduction is small enough not to require shared memory. TF_ASSERT_OK(EmitAndCheckIR(kHloString, R"( @@ -307,32 +307,32 @@ TEST_F(MlirRowReductionTest, MultiRowReduction) { (d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3] -> ( d3 * 64 + d0 floordiv 4, d0 mod 4) domain: - d0 in [0, 255] - d1 in [0, 0] - d2 in [0, 0] - d3 in [0, 15] - d4 in [0, 0] - d5 in [0, 0] - s0 in [0, 0] - s1 in [0, 0] - s2 in [0, 0] - s3 in [0, 0] - d0 mod 4 in [0, 3] - d3 * 64 + d0 floordiv 4 in [0, 1023] + d0 in [0, 256) + d1 in [0, 1) + d2 in [0, 1) + d3 in [0, 16) + d4 in [0, 1) + d5 in [0, 1) + s0 in [0, 1) + s1 in [0, 1) + s2 in [0, 1) + s3 in [0, 1) + d0 mod 4 in [0, 4) + d3 * 64 + d0 floordiv 4 in [0, 1024) )")); EXPECT_THAT( fusion.ComputeThreadIdToOutputIndexing(0, &mlir_context_)->ToString(), MatchIndexingString(R"( (d0, d1, d2, d3, d4, d5) -> (d3 * 64 + d0 floordiv 4) domain: - d0 in [0, 255] - d1 in [0, 0] - d2 in [0, 0] - d3 in [0, 15] - d4 in [0, 0] - d5 in [0, 0] - d0 mod 4 in [0, 0] - d3 * 64 + d0 floordiv 4 in [0, 1023] + d0 in [0, 256) + d1 in [0, 1) + d2 in [0, 1) + d3 in [0, 16) + d4 in [0, 1) + d5 in [0, 1) + d0 mod 4 in [0, 1) + d3 * 64 + d0 floordiv 4 in [0, 1024) )")); // Multi-row reductions don't use shared memory. 
TF_ASSERT_OK(EmitAndCheckIR(kHloString, R"( @@ -371,10 +371,10 @@ TEST_F(MlirRowReductionTest, NonPowerOfTwoRowReduction) { // CHECK: %[[FULL_TILES:.*]] = scf.for %[[I:.*]] = %[[C0]] to %[[C4]] step %[[C1]] // CHECK-NEXT: scf.for %[[J:.*]] = %[[C0]] to %[[C2]] step %[[C1]] // CHECK-NOT: scf.if - // CHECK: xla_gpu.apply_indexing #[[MAP1]](%[[J]] in [0, 1], %thread_id_x in [0, 255])[%[[I]] in [0, 4]] + // CHECK: xla_gpu.apply_indexing #[[MAP1]](%[[J]] in [0, 2), %thread_id_x in [0, 256))[%[[I]] in [0, 5)] // CHECK: scf.for %[[J:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%{{.*}} = %[[FULL_TILES]]) // CHECK: scf.if - // CHECK: xla_gpu.apply_indexing #[[MAP2]](%[[J]] in [0, 1], %thread_id_x in [0, 255]) + // CHECK: xla_gpu.apply_indexing #[[MAP2]](%[[J]] in [0, 2), %thread_id_x in [0, 256)) )")); EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); } @@ -419,30 +419,30 @@ TEST_F(MlirRowReductionTest, NonTrivialEpilogue) { (d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3] -> ( (d0 floordiv 4) * 4 + d0 mod 4) domain: - d0 in [0, 3] - d1 in [0, 0] - d2 in [0, 0] - d3 in [0, 0] - d4 in [0, 0] - d5 in [0, 0] - s0 in [0, 0] - s1 in [0, 0] - s2 in [0, 0] - s3 in [0, 0] - d0 mod 4 in [0, 3] + d0 in [0, 4) + d1 in [0, 1) + d2 in [0, 1) + d3 in [0, 1) + d4 in [0, 1) + d5 in [0, 1) + s0 in [0, 1) + s1 in [0, 1) + s2 in [0, 1) + s3 in [0, 1) + d0 mod 4 in [0, 4) )")); EXPECT_THAT( fusion.ComputeThreadIdToOutputIndexing(0, &mlir_context_)->ToString(), MatchIndexingString(R"( (d0, d1, d2, d3, d4, d5) -> () domain: - d0 in [0, 3] - d1 in [0, 0] - d2 in [0, 0] - d3 in [0, 0] - d4 in [0, 0] - d5 in [0, 0] - d0 mod 4 in [0, 0] + d0 in [0, 4) + d1 in [0, 1) + d2 in [0, 1) + d3 in [0, 1) + d4 in [0, 1) + d5 in [0, 1) + d0 mod 4 in [0, 1) )")); EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); } @@ -480,33 +480,33 @@ TEST_F(MlirRowReductionTest, SideOutput) { (d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4] -> ( d3 * 2 + d0 floordiv 128, (d0 mod 128 + s2 * 128) * 2 
+ s3) domain: - d0 in [0, 255] - d1 in [0, 0] - d2 in [0, 0] - d3 in [0, 3] - d4 in [0, 0] - d5 in [0, 0] - s0 in [0, 0] - s1 in [0, 0] - s2 in [0, 7] - s3 in [0, 1] - s4 in [0, 0] - d0 mod 128 + s2 * 128 in [0, 1023] - d3 * 2 + d0 floordiv 128 in [0, 7] + d0 in [0, 256) + d1 in [0, 1) + d2 in [0, 1) + d3 in [0, 4) + d4 in [0, 1) + d5 in [0, 1) + s0 in [0, 1) + s1 in [0, 1) + s2 in [0, 8) + s3 in [0, 2) + s4 in [0, 1) + d0 mod 128 + s2 * 128 in [0, 1024) + d3 * 2 + d0 floordiv 128 in [0, 8) )")); EXPECT_THAT( fusion.ComputeThreadIdToOutputIndexing(0, &mlir_context_)->ToString(), MatchIndexingString(R"( (d0, d1, d2, d3, d4, d5) -> (d3 * 2 + d0 floordiv 128) domain: - d0 in [0, 255] - d1 in [0, 0] - d2 in [0, 0] - d3 in [0, 3] - d4 in [0, 0] - d5 in [0, 0] - d0 mod 128 in [0, 0] - d3 * 2 + d0 floordiv 128 in [0, 7] + d0 in [0, 256) + d1 in [0, 1) + d2 in [0, 1) + d3 in [0, 4) + d4 in [0, 1) + d5 in [0, 1) + d0 mod 128 in [0, 1) + d3 * 2 + d0 floordiv 128 in [0, 8) )")); TF_ASSERT_OK(EmitAndCheckIR(kHloString, R"( // CHECK: @fused_computation @@ -551,33 +551,33 @@ TEST_F(MlirRowReductionTest, UnsignedSideOutput) { (d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4] -> ( d3 * 2 + d0 floordiv 128, (d0 mod 128 + s2 * 128) * 2 + s3) domain: - d0 in [0, 255] - d1 in [0, 0] - d2 in [0, 0] - d3 in [0, 3] - d4 in [0, 0] - d5 in [0, 0] - s0 in [0, 0] - s1 in [0, 0] - s2 in [0, 7] - s3 in [0, 1] - s4 in [0, 0] - d0 mod 128 + s2 * 128 in [0, 1023] - d3 * 2 + d0 floordiv 128 in [0, 7] + d0 in [0, 256) + d1 in [0, 1) + d2 in [0, 1) + d3 in [0, 4) + d4 in [0, 1) + d5 in [0, 1) + s0 in [0, 1) + s1 in [0, 1) + s2 in [0, 8) + s3 in [0, 2) + s4 in [0, 1) + d0 mod 128 + s2 * 128 in [0, 1024) + d3 * 2 + d0 floordiv 128 in [0, 8) )")); EXPECT_THAT( fusion.ComputeThreadIdToOutputIndexing(0, &mlir_context_)->ToString(), MatchIndexingString(R"( (d0, d1, d2, d3, d4, d5) -> (d3 * 2 + d0 floordiv 128) domain: - d0 in [0, 255] - d1 in [0, 0] - d2 in [0, 0] - d3 in [0, 3] - d4 in [0, 0] - d5 in [0, 0] - 
d0 mod 128 in [0, 0] - d3 * 2 + d0 floordiv 128 in [0, 7] + d0 in [0, 256) + d1 in [0, 1) + d2 in [0, 1) + d3 in [0, 4) + d4 in [0, 1) + d5 in [0, 1) + d0 mod 128 in [0, 1) + d3 * 2 + d0 floordiv 128 in [0, 8) )")); EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); } @@ -610,18 +610,18 @@ TEST_F(MlirRowReductionTest, BroadcastSideOutput) { MatchIndexingString(R"( (d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3] -> () domain: - d0 in [0, 31] - d1 in [0, 0] - d2 in [0, 0] - d3 in [0, 0] - d4 in [0, 0] - d5 in [0, 0] - s0 in [0, 0] - s1 in [0, 0] - s2 in [0, 1] - s3 in [0, 0] - (d0 + s2 * 32) mod 6 in [0, 5] - d0 + s2 * 32 in [0, 35] + d0 in [0, 32) + d1 in [0, 1) + d2 in [0, 1) + d3 in [0, 1) + d4 in [0, 1) + d5 in [0, 1) + s0 in [0, 1) + s1 in [0, 1) + s2 in [0, 2) + s3 in [0, 1) + (d0 + s2 * 32) mod 6 in [0, 6) + d0 + s2 * 32 in [0, 36) )")); EXPECT_THAT( fusion.ComputeThreadIdToOutputIndexing(0, &mlir_context_)->ToString(), @@ -629,17 +629,17 @@ TEST_F(MlirRowReductionTest, BroadcastSideOutput) { (d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3] -> ( (d0 + s2 * 32) floordiv 6, (d0 + s2 * 32) mod 6) domain: - d0 in [0, 31] - d1 in [0, 0] - d2 in [0, 0] - d3 in [0, 0] - d4 in [0, 0] - d5 in [0, 0] - s0 in [0, 0] - s1 in [0, 0] - s2 in [0, 1] - s3 in [0, 0] - d0 + s2 * 32 in [0, 35] + d0 in [0, 32) + d1 in [0, 1) + d2 in [0, 1) + d3 in [0, 1) + d4 in [0, 1) + d5 in [0, 1) + s0 in [0, 1) + s1 in [0, 1) + s2 in [0, 2) + s3 in [0, 1) + d0 + s2 * 32 in [0, 36) )")); TF_ASSERT_OK(EmitAndCheckIR(kHloString, R"( // CHECK: @fused_computation @@ -686,29 +686,29 @@ TEST_F(MlirRowReductionTest, VariadicMOF) { (d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3] -> ( (d0 + s2 * 32) floordiv 6, (d0 + s2 * 32) mod 6) domain: - d0 in [0, 31] - d1 in [0, 0] - d2 in [0, 0] - d3 in [0, 0] - d4 in [0, 0] - d5 in [0, 0] - s0 in [0, 0] - s1 in [0, 0] - s2 in [0, 1] - s3 in [0, 0] - d0 + s2 * 32 in [0, 35] + d0 in [0, 32) + d1 in [0, 1) + d2 in [0, 1) + d3 in [0, 1) + d4 in [0, 1) + d5 in [0, 1) + 
s0 in [0, 1) + s1 in [0, 1) + s2 in [0, 2) + s3 in [0, 1) + d0 + s2 * 32 in [0, 36) )")); EXPECT_THAT( fusion.ComputeThreadIdToOutputIndexing(0, &mlir_context_)->ToString(), MatchIndexingString(R"( (d0, d1, d2, d3, d4, d5) -> () domain: - d0 in [0, 0] - d1 in [0, 0] - d2 in [0, 0] - d3 in [0, 0] - d4 in [0, 0] - d5 in [0, 0] + d0 in [0, 1) + d1 in [0, 1) + d2 in [0, 1) + d3 in [0, 1) + d4 in [0, 1) + d5 in [0, 1) )")); TF_ASSERT_OK(EmitAndCheckIR(kHloString, R"( // CHECK: @fused_computation @@ -750,19 +750,19 @@ TEST_F(MlirRowReductionTest, ThreadIndexingOutputLayout) { (d3 * 8 + d0 floordiv 32) mod 64, (d0 mod 32 + s2 * 32) * 2 + s3) domain: - d0 in [0, 255] - d1 in [0, 0] - d2 in [0, 0] - d3 in [0, 799] - d4 in [0, 0] - d5 in [0, 0] - s0 in [0, 0] - s1 in [0, 0] - s2 in [0, 7] - s3 in [0, 1] - s4 in [0, 0] - d0 mod 32 + s2 * 32 in [0, 255] - d3 * 8 + d0 floordiv 32 in [0, 6399] + d0 in [0, 256) + d1 in [0, 1) + d2 in [0, 1) + d3 in [0, 800) + d4 in [0, 1) + d5 in [0, 1) + s0 in [0, 1) + s1 in [0, 1) + s2 in [0, 8) + s3 in [0, 2) + s4 in [0, 1) + d0 mod 32 + s2 * 32 in [0, 256) + d3 * 8 + d0 floordiv 32 in [0, 6400) )")); EXPECT_THAT( fusion.ComputeThreadIdToOutputIndexing(0, &mlir_context_)->ToString(), @@ -772,14 +772,14 @@ TEST_F(MlirRowReductionTest, ThreadIndexingOutputLayout) { (d3 * 8 + d0 floordiv 32) mod 64 ) domain: - d0 in [0, 255] - d1 in [0, 0] - d2 in [0, 0] - d3 in [0, 799] - d4 in [0, 0] - d5 in [0, 0] - d0 mod 32 in [0, 0] - d3 * 8 + d0 floordiv 32 in [0, 6399] + d0 in [0, 256) + d1 in [0, 1) + d2 in [0, 1) + d3 in [0, 800) + d4 in [0, 1) + d5 in [0, 1) + d0 mod 32 in [0, 1) + d3 * 8 + d0 floordiv 32 in [0, 6400) )")); } @@ -894,16 +894,16 @@ TEST_F(MlirColumnReductionTest, ColumnReduction) { (d3 mod 11) * 32 + d0 mod 32 + s1 ) domain: - d0 in [0, 1023] - d1 in [0, 0] - d2 in [0, 0] - d3 in [0, 142] - d4 in [0, 0] - d5 in [0, 0] - s0 in [0, 32] - s1 in [0, 0] - (d3 mod 11) * 32 + d0 mod 32 + s1 in [0, 320] - d0 floordiv 32 + s0 * 32 in [0, 1050] + 
d0 in [0, 1024) + d1 in [0, 1) + d2 in [0, 1) + d3 in [0, 143) + d4 in [0, 1) + d5 in [0, 1) + s0 in [0, 33) + s1 in [0, 1) + (d3 mod 11) * 32 + d0 mod 32 + s1 in [0, 321) + d0 floordiv 32 + s0 * 32 in [0, 1051) )")); EXPECT_THAT( fusion.ComputeThreadIdToOutputIndexing(0, &mlir_context_)->ToString(), @@ -912,15 +912,15 @@ TEST_F(MlirColumnReductionTest, ColumnReduction) { d3 floordiv 11, (d3 mod 11) * 32 + d0 floordiv 32 + s0 ) domain: - d0 in [0, 1023] - d1 in [0, 0] - d2 in [0, 0] - d3 in [0, 142] - d4 in [0, 0] - d5 in [0, 0] - s0 in [0, 0] - (d3 mod 11) * 32 + d0 floordiv 32 + s0 in [0, 320] - d0 mod 32 in [0, 0] + d0 in [0, 1024) + d1 in [0, 1) + d2 in [0, 1) + d3 in [0, 143) + d4 in [0, 1) + d5 in [0, 1) + s0 in [0, 1) + (d3 mod 11) * 32 + d0 floordiv 32 + s0 in [0, 321) + d0 mod 32 in [0, 1) )")); TF_ASSERT_OK(EmitAndCheckIR(kHloString, R"( // CHECK: xla_gpu.pure_call @Add_add @@ -1009,16 +1009,16 @@ TEST_F(MlirColumnReductionTest, ColumnReductionVectorization) { (d3 floordiv 256) * 2048 + d0 floordiv 32 + s0 * 32, ((d3 mod 256) * 32 + d0 mod 32) * 2 + s1) domain: - d0 in [0, 1023] - d1 in [0, 0] - d2 in [0, 0] - d3 in [0, 255] - d4 in [0, 0] - d5 in [0, 0] - s0 in [0, 63] - s1 in [0, 1] - ((d3 mod 256) * 32 + d0 mod 32) * 2 + s1 in [0, 16383] - d0 floordiv 32 + s0 * 32 in [0, 2047] + d0 in [0, 1024) + d1 in [0, 1) + d2 in [0, 1) + d3 in [0, 256) + d4 in [0, 1) + d5 in [0, 1) + s0 in [0, 64) + s1 in [0, 2) + ((d3 mod 256) * 32 + d0 mod 32) * 2 + s1 in [0, 16384) + d0 floordiv 32 + s0 * 32 in [0, 2048) )")); EXPECT_THAT( fusion.ComputeThreadIdToOutputIndexing(0, &mlir_context_)->ToString(), @@ -1026,15 +1026,15 @@ TEST_F(MlirColumnReductionTest, ColumnReductionVectorization) { (d0, d1, d2, d3, d4, d5)[s0] -> ((d3 floordiv 256) * 16384 + ((d3 mod 256) * 32 + d0 floordiv 32) * 2 + s0) domain: - d0 in [0, 1023] - d1 in [0, 0] - d2 in [0, 0] - d3 in [0, 255] - d4 in [0, 0] - d5 in [0, 0] - s0 in [0, 1] - ((d3 mod 256) * 32 + d0 floordiv 32) * 2 + s0 in [0, 16383] 
- d0 mod 32 in [0, 0] + d0 in [0, 1024) + d1 in [0, 1) + d2 in [0, 1) + d3 in [0, 256) + d4 in [0, 1) + d5 in [0, 1) + s0 in [0, 2) + ((d3 mod 256) * 32 + d0 floordiv 32) * 2 + s0 in [0, 16384) + d0 mod 32 in [0, 1) )")); TF_ASSERT_OK(EmitAndCheckIR(kHloString, R"( // CHECK: vector<2xf32> @@ -1084,16 +1084,16 @@ TEST_F(MlirColumnReductionTest, ThreadIndexingColumn_v2) { ((d3 mod 24) * 32 + d0 mod 32) * 2 + s1 ) domain: - d0 in [0, 1023] - d1 in [0, 0] - d2 in [0, 0] - d3 in [0, 4607] - d4 in [0, 0] - d5 in [0, 0] - s0 in [0, 1] - s1 in [0, 1] - ((d3 mod 24) * 32 + d0 mod 32) * 2 + s1 in [0, 1535] - d0 floordiv 32 + s0 * 32 in [0, 63] + d0 in [0, 1024) + d1 in [0, 1) + d2 in [0, 1) + d3 in [0, 4608) + d4 in [0, 1) + d5 in [0, 1) + s0 in [0, 2) + s1 in [0, 2) + ((d3 mod 24) * 32 + d0 mod 32) * 2 + s1 in [0, 1536) + d0 floordiv 32 + s0 * 32 in [0, 64) )")); EXPECT_THAT( fusion.ComputeThreadIdToOutputIndexing(0, &mlir_context_)->ToString(), @@ -1102,15 +1102,15 @@ TEST_F(MlirColumnReductionTest, ThreadIndexingColumn_v2) { d3 floordiv 24, ((d3 mod 24) * 32 + d0 floordiv 32) * 2 + s0) domain: - d0 in [0, 1023] - d1 in [0, 0] - d2 in [0, 0] - d3 in [0, 4607] - d4 in [0, 0] - d5 in [0, 0] - s0 in [0, 1] - ((d3 mod 24) * 32 + d0 floordiv 32) * 2 + s0 in [0, 1535] - d0 mod 32 in [0, 0] + d0 in [0, 1024) + d1 in [0, 1) + d2 in [0, 1) + d3 in [0, 4608) + d4 in [0, 1) + d5 in [0, 1) + s0 in [0, 2) + ((d3 mod 24) * 32 + d0 floordiv 32) * 2 + s0 in [0, 1536) + d0 mod 32 in [0, 1) )")); } @@ -1132,16 +1132,16 @@ TEST_F(MlirColumnReductionTest, ThreadIndexingColumn_v4) { ((d3 mod 12) * 32 + d0 mod 32) * 4 + s1 ) domain: - d0 in [0, 1023] - d1 in [0, 0] - d2 in [0, 0] - d3 in [0, 2303] - d4 in [0, 0] - d5 in [0, 0] - s0 in [0, 1] - s1 in [0, 3] - ((d3 mod 12) * 32 + d0 mod 32) * 4 + s1 in [0, 1535] - d0 floordiv 32 + s0 * 32 in [0, 63] + d0 in [0, 1024) + d1 in [0, 1) + d2 in [0, 1) + d3 in [0, 2304) + d4 in [0, 1) + d5 in [0, 1) + s0 in [0, 2) + s1 in [0, 4) + ((d3 mod 12) * 32 + 
d0 mod 32) * 4 + s1 in [0, 1536) + d0 floordiv 32 + s0 * 32 in [0, 64) )")); EXPECT_THAT( fusion.ComputeThreadIdToOutputIndexing(0, &mlir_context_)->ToString(), @@ -1150,15 +1150,15 @@ TEST_F(MlirColumnReductionTest, ThreadIndexingColumn_v4) { d3 floordiv 12, ((d3 mod 12) * 32 + d0 floordiv 32) * 4 + s0) domain: - d0 in [0, 1023] - d1 in [0, 0] - d2 in [0, 0] - d3 in [0, 2303] - d4 in [0, 0] - d5 in [0, 0] - s0 in [0, 3] - ((d3 mod 12) * 32 + d0 floordiv 32) * 4 + s0 in [0, 1535] - d0 mod 32 in [0, 0] + d0 in [0, 1024) + d1 in [0, 1) + d2 in [0, 1) + d3 in [0, 2304) + d4 in [0, 1) + d5 in [0, 1) + s0 in [0, 4) + ((d3 mod 12) * 32 + d0 floordiv 32) * 4 + s0 in [0, 1536) + d0 mod 32 in [0, 1) )")); } @@ -1182,16 +1182,16 @@ TEST_F(MlirColumnReductionTest, ThreadIndexingColumn_Complex) { (d3 mod 48) * 32 + d0 mod 32 + s1 ) domain: - d0 in [0, 1023] - d1 in [0, 0] - d2 in [0, 0] - d3 in [0, 9215] - d4 in [0, 0] - d5 in [0, 0] - s0 in [0, 1] - s1 in [0, 0] - (d3 mod 48) * 32 + d0 mod 32 + s1 in [0, 1535] - d0 floordiv 32 + s0 * 32 in [0, 63] + d0 in [0, 1024) + d1 in [0, 1) + d2 in [0, 1) + d3 in [0, 9216) + d4 in [0, 1) + d5 in [0, 1) + s0 in [0, 2) + s1 in [0, 1) + (d3 mod 48) * 32 + d0 mod 32 + s1 in [0, 1536) + d0 floordiv 32 + s0 * 32 in [0, 64) )")); EXPECT_THAT( fusion.ComputeThreadIdToOutputIndexing(0, &mlir_context_)->ToString(), @@ -1200,15 +1200,15 @@ TEST_F(MlirColumnReductionTest, ThreadIndexingColumn_Complex) { d3 floordiv 48, (d3 mod 48) * 32 + d0 floordiv 32 + s0) domain: - d0 in [0, 1023] - d1 in [0, 0] - d2 in [0, 0] - d3 in [0, 9215] - d4 in [0, 0] - d5 in [0, 0] - s0 in [0, 0] - (d3 mod 48) * 32 + d0 floordiv 32 + s0 in [0, 1535] - d0 mod 32 in [0, 0] + d0 in [0, 1024) + d1 in [0, 1) + d2 in [0, 1) + d3 in [0, 9216) + d4 in [0, 1) + d5 in [0, 1) + s0 in [0, 1) + (d3 mod 48) * 32 + d0 floordiv 32 + s0 in [0, 1536) + d0 mod 32 in [0, 1) )")); } diff --git a/third_party/xla/xla/service/gpu/fusions/reduction_test.cc 
b/third_party/xla/xla/service/gpu/fusions/reduction_test.cc index 636ffccadc4866..491a80c84ad215 100644 --- a/third_party/xla/xla/service/gpu/fusions/reduction_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/reduction_test.cc @@ -81,18 +81,18 @@ TEST_F(ReductionTest, ThreadIndexingRowReduction) { (d0 mod 32 + s2 * 32) * 2 + s3 ) domain: - d0 in [0, 255] - d1 in [0, 0] - d2 in [0, 0] - d3 in [0, 799] - d4 in [0, 0] - d5 in [0, 0] - s0 in [0, 0] - s1 in [0, 0] - s2 in [0, 7] - s3 in [0, 1] - d0 mod 32 + s2 * 32 in [0, 255] - d3 * 8 + d0 floordiv 32 in [0, 6399] + d0 in [0, 256) + d1 in [0, 1) + d2 in [0, 1) + d3 in [0, 800) + d4 in [0, 1) + d5 in [0, 1) + s0 in [0, 1) + s1 in [0, 1) + s2 in [0, 8) + s3 in [0, 2) + d0 mod 32 + s2 * 32 in [0, 256) + d3 * 8 + d0 floordiv 32 in [0, 6400) )")); EXPECT_THAT( fusion.ComputeThreadIdToOutputIndexing(0, &mlir_context_)->ToString(), @@ -102,14 +102,14 @@ TEST_F(ReductionTest, ThreadIndexingRowReduction) { (d3 * 8 + d0 floordiv 32) mod 64 ) domain: - d0 in [0, 255] - d1 in [0, 0] - d2 in [0, 0] - d3 in [0, 799] - d4 in [0, 0] - d5 in [0, 0] - d0 mod 32 in [0, 0] - d3 * 8 + d0 floordiv 32 in [0, 6399] + d0 in [0, 256) + d1 in [0, 1) + d2 in [0, 1) + d3 in [0, 800) + d4 in [0, 1) + d5 in [0, 1) + d0 mod 32 in [0, 1) + d3 * 8 + d0 floordiv 32 in [0, 6400) )")); } @@ -148,17 +148,17 @@ TEST_F(ReductionTest, ThreadIndexingMultiRowReduction) { d0 mod 4 ) domain: - d0 in [0, 255] - d1 in [0, 0] - d2 in [0, 0] - d3 in [0, 99] - d4 in [0, 0] - d5 in [0, 0] - s0 in [0, 0] - s1 in [0, 0] - s2 in [0, 0] - d0 mod 4 in [0, 3] - d3 * 64 + d0 floordiv 4 in [0, 6399] + d0 in [0, 256) + d1 in [0, 1) + d2 in [0, 1) + d3 in [0, 100) + d4 in [0, 1) + d5 in [0, 1) + s0 in [0, 1) + s1 in [0, 1) + s2 in [0, 1) + d0 mod 4 in [0, 4) + d3 * 64 + d0 floordiv 4 in [0, 6400) )")); EXPECT_THAT( fusion.ComputeThreadIdToOutputIndexing(0, &mlir_context_)->ToString(), @@ -168,14 +168,14 @@ TEST_F(ReductionTest, ThreadIndexingMultiRowReduction) { (d0 
floordiv 4) mod 64 ) domain: - d0 in [0, 255] - d1 in [0, 0] - d2 in [0, 0] - d3 in [0, 99] - d4 in [0, 0] - d5 in [0, 0] - d0 mod 4 in [0, 0] - d3 * 64 + d0 floordiv 4 in [0, 6399] + d0 in [0, 256) + d1 in [0, 1) + d2 in [0, 1) + d3 in [0, 100) + d4 in [0, 1) + d5 in [0, 1) + d0 mod 4 in [0, 1) + d3 * 64 + d0 floordiv 4 in [0, 6400) )")); } @@ -214,11 +214,11 @@ TEST_F(ReductionTest, ThreadIndexingColumnReduction) { d0 mod 32 ) domain: - d0 in [0, 1023] d1 in [0, 0] d2 in [0, 0] - d3 in [0, 99] d4 in [0, 0] d5 in [0, 0] - s0 in [0, 0] s1 in [0, 127] s2 in [0, 0] - d0 floordiv 32 + s1 * 32 in [0, 63] - d0 mod 32 in [0, 31] + d0 in [0, 1024) d1 in [0, 1) d2 in [0, 1) + d3 in [0, 100) d4 in [0, 1) d5 in [0, 1) + s0 in [0, 1) s1 in [0, 128) s2 in [0, 1) + d0 floordiv 32 + s1 * 32 in [0, 64) + d0 mod 32 in [0, 32) )")); EXPECT_THAT( fusion.ComputeThreadIdToOutputIndexing(0, &mlir_context_)->ToString(), @@ -228,9 +228,9 @@ TEST_F(ReductionTest, ThreadIndexingColumnReduction) { d0 floordiv 32 ) domain: - d0 in [0, 1023] d1 in [0, 0] d2 in [0, 0] - d3 in [0, 99] d4 in [0, 0] d5 in [0, 0] - d0 mod 32 in [0, 0] + d0 in [0, 1024) d1 in [0, 1) d2 in [0, 1) + d3 in [0, 100) d4 in [0, 1) d5 in [0, 1) + d0 mod 32 in [0, 1) )")); } @@ -268,14 +268,14 @@ TEST_F(ReductionTest, ThreadIndexingOutputLayout) { (d3 * 8 + d0 floordiv 32) mod 64 ) domain: - d0 in [0, 255] - d1 in [0, 0] - d2 in [0, 0] - d3 in [0, 799] - d4 in [0, 0] - d5 in [0, 0] - d0 mod 32 in [0, 0] - d3 * 8 + d0 floordiv 32 in [0, 6399] + d0 in [0, 256) + d1 in [0, 1) + d2 in [0, 1) + d3 in [0, 800) + d4 in [0, 1) + d5 in [0, 1) + d0 mod 32 in [0, 1) + d3 * 8 + d0 floordiv 32 in [0, 6400) )")); } @@ -314,16 +314,16 @@ TEST_F(ReductionTest, ThreadIndexingSideOutput) { (d0 mod 32) * 2 + s2 * 64 + s3 ) domain: - d0 in [0, 255] - d1 in [0, 0] - d2 in [0, 0] - d3 in [0, 799] - d4 in [0, 0] - d5 in [0, 0] - s0 in [0, 0] - s1 in [0, 0] - s2 in [0, 7] - s3 in [0, 1] + d0 in [0, 256) + d1 in [0, 1) + d2 in [0, 1) + d3 in [0, 
800) + d4 in [0, 1) + d5 in [0, 1) + s0 in [0, 1) + s1 in [0, 1) + s2 in [0, 8) + s3 in [0, 2) )"; auto input_indexing = fusion.ComputeThreadIdToInputIndexing(1, 0, &mlir_context_); @@ -369,17 +369,17 @@ TEST_F(ReductionTest, ThreadIndexingVectorized) { (d0 + s2 * 512) * 2 + s3 ) domain: - d0 in [0, 511] - d1 in [0, 0] - d2 in [0, 0] - d3 in [0, 1023] - d4 in [0, 0] - d5 in [0, 0] - s0 in [0, 0] - s1 in [0, 0] - s2 in [0, 7] - s3 in [0, 1] - d0 + s2 * 512 in [0, 4095] + d0 in [0, 512) + d1 in [0, 1) + d2 in [0, 1) + d3 in [0, 1024) + d4 in [0, 1) + d5 in [0, 1) + s0 in [0, 1) + s1 in [0, 1) + s2 in [0, 8) + s3 in [0, 2) + d0 + s2 * 512 in [0, 4096) )")); } @@ -414,33 +414,33 @@ TEST_F(ReductionTest, ThreadIndexingBroadcastSideOutput) { (d0 + s2 * 32) mod 6 ) domain: - d0 in [0, 31] - d1 in [0, 0] - d2 in [0, 0] - d3 in [0, 0] - d4 in [0, 0] - d5 in [0, 0] - s0 in [0, 0] - s1 in [0, 0] - s2 in [0, 15] - d0 + s2 * 32 in [0, 35] + d0 in [0, 32) + d1 in [0, 1) + d2 in [0, 1) + d3 in [0, 1) + d4 in [0, 1) + d5 in [0, 1) + s0 in [0, 1) + s1 in [0, 1) + s2 in [0, 16) + d0 + s2 * 32 in [0, 36) )")); EXPECT_THAT( fusion.ComputeThreadIdToInputIndexing(0, 0, &mlir_context_)->ToString(), MatchIndexingString(R"( (d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> () domain: - d0 in [0, 31] - d1 in [0, 0] - d2 in [0, 0] - d3 in [0, 0] - d4 in [0, 0] - d5 in [0, 0] - s0 in [0, 0] - s1 in [0, 0] - s2 in [0, 15] - (d0 + s2 * 32) mod 6 in [0, 5] - d0 + s2 * 32 in [0, 35] + d0 in [0, 32) + d1 in [0, 1) + d2 in [0, 1) + d3 in [0, 1) + d4 in [0, 1) + d5 in [0, 1) + s0 in [0, 1) + s1 in [0, 1) + s2 in [0, 16) + (d0 + s2 * 32) mod 6 in [0, 6) + d0 + s2 * 32 in [0, 36) )")); } diff --git a/third_party/xla/xla/service/gpu/fusions/scatter_mlir_test.cc b/third_party/xla/xla/service/gpu/fusions/scatter_mlir_test.cc index e0cd913f00d341..266e2325e4ef46 100644 --- a/third_party/xla/xla/service/gpu/fusions/scatter_mlir_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/scatter_mlir_test.cc @@ -86,15 
+86,15 @@ TEST_F(MlirScatterFusionTest, ThreadIdIndexing) { (bl_x * 128 + th_x) mod 20 ) domain: - th_x in [0, 127] - th_y in [0, 0] - th_z in [0, 0] - bl_x in [0, 65] - bl_y in [0, 0] - bl_z in [0, 0] - chunk_id in [0, 0] - unroll_id in [0, 0] - th_x + bl_x * 128 in [0, 8399] + th_x in [0, 128) + th_y in [0, 1) + th_z in [0, 1) + bl_x in [0, 66) + bl_y in [0, 1) + bl_z in [0, 1) + chunk_id in [0, 1) + unroll_id in [0, 1) + th_x + bl_x * 128 in [0, 8400) )"; EXPECT_THAT( fusion @@ -125,16 +125,16 @@ TEST_F(MlirScatterFusionTest, ThreadIdIndexing) { (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id, index_id] -> (((bl_x * 128 + th_x) floordiv 200) mod 42, 0) domain: - th_x in [0, 127] - th_y in [0, 0] - th_z in [0, 0] - bl_x in [0, 65] - bl_y in [0, 0] - bl_z in [0, 0] - chunk_id in [0, 0] - unroll_id in [0, 0] - index_id in [0, 0] - th_x + bl_x * 128 in [0, 8399] + th_x in [0, 128) + th_y in [0, 1) + th_z in [0, 1) + bl_x in [0, 66) + bl_y in [0, 1) + bl_z in [0, 1) + chunk_id in [0, 1) + unroll_id in [0, 1) + index_id in [0, 1) + th_x + bl_x * 128 in [0, 8400) )"; EXPECT_THAT( fusion diff --git a/third_party/xla/xla/service/gpu/fusions/scatter_test.cc b/third_party/xla/xla/service/gpu/fusions/scatter_test.cc index b765194e3f868b..4ef767c4d0986a 100644 --- a/third_party/xla/xla/service/gpu/fusions/scatter_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/scatter_test.cc @@ -150,15 +150,15 @@ TEST_F(ScatterFusionTest, ThreadIdIndexing) { (bl_x * 128 + th_x) mod 20 ) domain: - th_x in [0, 127] - th_y in [0, 0] - th_z in [0, 0] - bl_x in [0, 65] - bl_y in [0, 0] - bl_z in [0, 0] - chunk_id in [0, 0] - unroll_id in [0, 0] - th_x + bl_x * 128 in [0, 8399] + th_x in [0, 128) + th_y in [0, 1) + th_z in [0, 1) + bl_x in [0, 66) + bl_y in [0, 1) + bl_z in [0, 1) + chunk_id in [0, 1) + unroll_id in [0, 1) + th_x + bl_x * 128 in [0, 8400) )"; EXPECT_THAT( fusion @@ -189,16 +189,16 @@ TEST_F(ScatterFusionTest, ThreadIdIndexing) { (th_x, th_y, th_z, bl_x, bl_y, 
bl_z)[chunk_id, unroll_id, index_id] -> (((bl_x * 128 + th_x) floordiv 200) mod 42, 0) domain: - th_x in [0, 127] - th_y in [0, 0] - th_z in [0, 0] - bl_x in [0, 65] - bl_y in [0, 0] - bl_z in [0, 0] - chunk_id in [0, 0] - unroll_id in [0, 0] - index_id in [0, 0] - th_x + bl_x * 128 in [0, 8399] + th_x in [0, 128) + th_y in [0, 1) + th_z in [0, 1) + bl_x in [0, 66) + bl_y in [0, 1) + bl_z in [0, 1) + chunk_id in [0, 1) + unroll_id in [0, 1) + index_id in [0, 1) + th_x + bl_x * 128 in [0, 8400) )"; EXPECT_THAT( fusion diff --git a/third_party/xla/xla/service/gpu/fusions/transpose_mlir_test.cc b/third_party/xla/xla/service/gpu/fusions/transpose_mlir_test.cc index a492cb0af146c5..a54885894f96b4 100644 --- a/third_party/xla/xla/service/gpu/fusions/transpose_mlir_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/transpose_mlir_test.cc @@ -56,15 +56,15 @@ TEST_F(MlirTransposeFusionTest, ThreadIndexing021) { (d3 mod 2) * 32 + d0 mod 32 ) domain: - d0 in [0, 127] - d1 in [0, 0] - d2 in [0, 0] - d3 in [0, 199] - d4 in [0, 0] - d5 in [0, 0] - - s0 in [0, 7] - s1 in [0, 0] + d0 in [0, 128) + d1 in [0, 1) + d2 in [0, 1) + d3 in [0, 200) + d4 in [0, 1) + d5 in [0, 1) + + s0 in [0, 8) + s1 in [0, 1) )")); EXPECT_THAT( fusion.ComputeThreadIdToOutputIndexing(0, &mlir_context_)->ToString(), @@ -75,15 +75,15 @@ TEST_F(MlirTransposeFusionTest, ThreadIndexing021) { d0 mod 32 ) domain: - d0 in [0, 127] - d1 in [0, 0] - d2 in [0, 0] - d3 in [0, 199] - d4 in [0, 0] - d5 in [0, 0] - - s0 in [0, 7] - s1 in [0, 0] + d0 in [0, 128) + d1 in [0, 1) + d2 in [0, 1) + d3 in [0, 200) + d4 in [0, 1) + d5 in [0, 1) + + s0 in [0, 8) + s1 in [0, 1) )")); } @@ -113,15 +113,15 @@ TEST_F(MlirTransposeFusionTest, ThreadIndexing201) { d0 mod 32 ) domain: - d0 in [0, 127] - d1 in [0, 0] - d2 in [0, 0] - d3 in [0, 199] - d4 in [0, 0] - d5 in [0, 0] - - s0 in [0, 7] - s1 in [0, 0] + d0 in [0, 128) + d1 in [0, 1) + d2 in [0, 1) + d3 in [0, 200) + d4 in [0, 1) + d5 in [0, 1) + + s0 in [0, 8) + s1 in [0, 1) 
)")); EXPECT_THAT( fusion.ComputeThreadIdToOutputIndexing(0, &mlir_context_)->ToString(), @@ -132,15 +132,15 @@ TEST_F(MlirTransposeFusionTest, ThreadIndexing201) { (d3 mod 2) * 32 + d0 mod 32 ) domain: - d0 in [0, 127] - d1 in [0, 0] - d2 in [0, 0] - d3 in [0, 199] - d4 in [0, 0] - d5 in [0, 0] - - s0 in [0, 7] - s1 in [0, 0] + d0 in [0, 128) + d1 in [0, 1) + d2 in [0, 1) + d3 in [0, 200) + d4 in [0, 1) + d5 in [0, 1) + + s0 in [0, 8) + s1 in [0, 1) )")); } @@ -170,14 +170,14 @@ TEST_F(MlirTransposeFusionTest, ThreadIndexingVectorized021) { (d0 mod 32) * 2 + s1 ) domain: - d0 in [0, 127] - d1 in [0, 0] - d2 in [0, 0] - d3 in [0, 8191] - d4 in [0, 0] - d5 in [0, 0] - s0 in [0, 15] - s1 in [0, 1] + d0 in [0, 128) + d1 in [0, 1) + d2 in [0, 1) + d3 in [0, 8192) + d4 in [0, 1) + d5 in [0, 1) + s0 in [0, 16) + s1 in [0, 2) )")); EXPECT_THAT( fusion.ComputeThreadIdToOutputIndexing(0, &mlir_context_)->ToString(), @@ -188,14 +188,14 @@ TEST_F(MlirTransposeFusionTest, ThreadIndexingVectorized021) { (d0 mod 32) * 2 + s1 ) domain: - d0 in [0, 127] - d1 in [0, 0] - d2 in [0, 0] - d3 in [0, 8191] - d4 in [0, 0] - d5 in [0, 0] - s0 in [0, 15] - s1 in [0, 1] + d0 in [0, 128) + d1 in [0, 1) + d2 in [0, 1) + d3 in [0, 8192) + d4 in [0, 1) + d5 in [0, 1) + s0 in [0, 16) + s1 in [0, 2) )")); } @@ -224,14 +224,14 @@ TEST_F(MlirTransposeFusionTest, ThreadIndexingVectorized210) { (d0 mod 32) * 2 + s1 + (d3 mod 128) * 64 ) domain: - d0 in [0, 127] - d1 in [0, 0] - d2 in [0, 0] - d3 in [0, 8191] - d4 in [0, 0] - d5 in [0, 0] - s0 in [0, 15] - s1 in [0, 1] + d0 in [0, 128) + d1 in [0, 1) + d2 in [0, 1) + d3 in [0, 8192) + d4 in [0, 1) + d5 in [0, 1) + s0 in [0, 16) + s1 in [0, 2) )")); EXPECT_THAT( fusion.ComputeThreadIdToOutputIndexing(0, &mlir_context_)->ToString(), @@ -242,14 +242,14 @@ TEST_F(MlirTransposeFusionTest, ThreadIndexingVectorized210) { (d0 mod 32) * 2 + s1 ) domain: - d0 in [0, 127] - d1 in [0, 0] - d2 in [0, 0] - d3 in [0, 8191] - d4 in [0, 0] - d5 in [0, 0] - s0 in [0, 
15] - s1 in [0, 1] + d0 in [0, 128) + d1 in [0, 1) + d2 in [0, 1) + d3 in [0, 8192) + d4 in [0, 1) + d5 in [0, 1) + s0 in [0, 16) + s1 in [0, 2) )")); } @@ -620,15 +620,15 @@ TEST_F(MlirTransposeFusionTest, ThreadIndexingSideOutput) { d0 floordiv 32 + s0 * 4 ) domain: - d0 in [0, 127] - d1 in [0, 0] - d2 in [0, 0] - d3 in [0, 199] - d4 in [0, 0] - d5 in [0, 0] - - s0 in [0, 7] - s1 in [0, 0] + d0 in [0, 128) + d1 in [0, 1) + d2 in [0, 1) + d3 in [0, 200) + d4 in [0, 1) + d5 in [0, 1) + + s0 in [0, 8) + s1 in [0, 1) )")); EXPECT_THAT( fusion.ComputeThreadIdToOutputIndexing(1, &mlir_context)->ToString(), @@ -639,15 +639,15 @@ TEST_F(MlirTransposeFusionTest, ThreadIndexingSideOutput) { (d3 mod 2) * 32 + d0 mod 32 ) domain: - d0 in [0, 127] - d1 in [0, 0] - d2 in [0, 0] - d3 in [0, 199] - d4 in [0, 0] - d5 in [0, 0] - - s0 in [0, 7] - s1 in [0, 0] + d0 in [0, 128) + d1 in [0, 1) + d2 in [0, 1) + d3 in [0, 200) + d4 in [0, 1) + d5 in [0, 1) + + s0 in [0, 8) + s1 in [0, 1) )")); } diff --git a/third_party/xla/xla/service/gpu/fusions/transpose_test.cc b/third_party/xla/xla/service/gpu/fusions/transpose_test.cc index b290e073cec333..6255b09cca63d7 100644 --- a/third_party/xla/xla/service/gpu/fusions/transpose_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/transpose_test.cc @@ -80,16 +80,16 @@ TEST_F(TransposeTest, ThreadIndexing021) { (d3 mod 2) * 32 + d0 mod 32 ) domain: - d0 in [0, 127] - d1 in [0, 0] - d2 in [0, 0] - d3 in [0, 199] - d4 in [0, 0] - d5 in [0, 0] - - s0 in [0, 0] - s1 in [0, 7] - s2 in [0, 0] + d0 in [0, 128) + d1 in [0, 1) + d2 in [0, 1) + d3 in [0, 200) + d4 in [0, 1) + d5 in [0, 1) + + s0 in [0, 1) + s1 in [0, 8) + s2 in [0, 1) )")); EXPECT_THAT( fusion->ComputeThreadIdToOutputIndexing(0, &mlir_context)->ToString(), @@ -100,16 +100,16 @@ TEST_F(TransposeTest, ThreadIndexing021) { d0 mod 32 ) domain: - d0 in [0, 127] - d1 in [0, 0] - d2 in [0, 0] - d3 in [0, 199] - d4 in [0, 0] - d5 in [0, 0] - - s0 in [0, 0] - s1 in [0, 7] - s2 in [0, 0] + d0 in 
[0, 128) + d1 in [0, 1) + d2 in [0, 1) + d3 in [0, 200) + d4 in [0, 1) + d5 in [0, 1) + + s0 in [0, 1) + s1 in [0, 8) + s2 in [0, 1) )")); } @@ -142,16 +142,16 @@ TEST_F(TransposeTest, ThreadIndexing201) { d0 mod 32 ) domain: - d0 in [0, 127] - d1 in [0, 0] - d2 in [0, 0] - d3 in [0, 199] - d4 in [0, 0] - d5 in [0, 0] - - s0 in [0, 0] - s1 in [0, 7] - s2 in [0, 0] + d0 in [0, 128) + d1 in [0, 1) + d2 in [0, 1) + d3 in [0, 200) + d4 in [0, 1) + d5 in [0, 1) + + s0 in [0, 1) + s1 in [0, 8) + s2 in [0, 1) )")); EXPECT_THAT( fusion->ComputeThreadIdToOutputIndexing(0, &mlir_context)->ToString(), @@ -162,16 +162,16 @@ TEST_F(TransposeTest, ThreadIndexing201) { (d3 mod 2) * 32 + d0 mod 32 ) domain: - d0 in [0, 127] - d1 in [0, 0] - d2 in [0, 0] - d3 in [0, 199] - d4 in [0, 0] - d5 in [0, 0] - - s0 in [0, 0] - s1 in [0, 7] - s2 in [0, 0] + d0 in [0, 128) + d1 in [0, 1) + d2 in [0, 1) + d3 in [0, 200) + d4 in [0, 1) + d5 in [0, 1) + + s0 in [0, 1) + s1 in [0, 8) + s2 in [0, 1) )")); } @@ -207,17 +207,17 @@ TEST_F(TransposeTest, ThreadIndexingPartialBlock) { d0 mod 4 ) domain: - d0 in [0, 127] - d1 in [0, 0] - d2 in [0, 0] - d3 in [0, 1] - d4 in [0, 0] - d5 in [0, 0] - s0 in [0, 7] - s1 in [0, 0] - s2 in [0, 0] - d0 floordiv 32 + s0 * 4 in [0, 23] - d0 mod 32 in [0, 23] + d0 in [0, 128) + d1 in [0, 1) + d2 in [0, 1) + d3 in [0, 2) + d4 in [0, 1) + d5 in [0, 1) + s0 in [0, 8) + s1 in [0, 1) + s2 in [0, 1) + d0 floordiv 32 + s0 * 4 in [0, 24) + d0 mod 32 in [0, 24) )")); EXPECT_THAT( fusion->ComputeThreadIdToOutputIndexing(0, &mlir_context)->ToString(), @@ -229,17 +229,17 @@ TEST_F(TransposeTest, ThreadIndexingPartialBlock) { d0 mod 32 ) domain: - d0 in [0, 127] - d1 in [0, 0] - d2 in [0, 0] - d3 in [0, 1] - d4 in [0, 0] - d5 in [0, 0] - s0 in [0, 7] - s1 in [0, 0] - s2 in [0, 0] - d0 floordiv 32 + s0 * 4 in [0, 23] - d0 mod 32 in [0, 23] + d0 in [0, 128) + d1 in [0, 1) + d2 in [0, 1) + d3 in [0, 2) + d4 in [0, 1) + d5 in [0, 1) + s0 in [0, 8) + s1 in [0, 1) + s2 in [0, 1) + 
d0 floordiv 32 + s0 * 4 in [0, 24) + d0 mod 32 in [0, 24) )")); } @@ -305,16 +305,16 @@ TEST_F(TransposeTest, ThreadIndexingSideOutput) { d0 floordiv 32 + s1 * 4 ) domain: - d0 in [0, 127] - d1 in [0, 0] - d2 in [0, 0] - d3 in [0, 199] - d4 in [0, 0] - d5 in [0, 0] - - s0 in [0, 0] - s1 in [0, 7] - s2 in [0, 0] + d0 in [0, 128) + d1 in [0, 1) + d2 in [0, 1) + d3 in [0, 200) + d4 in [0, 1) + d5 in [0, 1) + + s0 in [0, 1) + s1 in [0, 8) + s2 in [0, 1) )")); EXPECT_THAT( fusion->ComputeThreadIdToOutputIndexing(1, &mlir_context)->ToString(), @@ -325,16 +325,16 @@ TEST_F(TransposeTest, ThreadIndexingSideOutput) { (d3 mod 2) * 32 + d0 mod 32 ) domain: - d0 in [0, 127] - d1 in [0, 0] - d2 in [0, 0] - d3 in [0, 199] - d4 in [0, 0] - d5 in [0, 0] - - s0 in [0, 0] - s1 in [0, 7] - s2 in [0, 0] + d0 in [0, 128) + d1 in [0, 1) + d2 in [0, 1) + d3 in [0, 200) + d4 in [0, 1) + d5 in [0, 1) + + s0 in [0, 1) + s1 in [0, 8) + s2 in [0, 1) )")); } diff --git a/third_party/xla/xla/service/gpu/model/coalescing_analysis.cc b/third_party/xla/xla/service/gpu/model/coalescing_analysis.cc index a4dde9f2ae2a5a..ca69e81d2b93a4 100644 --- a/third_party/xla/xla/service/gpu/model/coalescing_analysis.cc +++ b/third_party/xla/xla/service/gpu/model/coalescing_analysis.cc @@ -359,9 +359,9 @@ int64_t EvaluateAffineExpr(AffineExpr expr, // For example, for the following indexing map: // (d0)[s0] -> (d0 + s0) // domain: -// d0 in [0, 3] +// d0 in [0, 4) // s0 in [0, 1, 2] -// s0 mod 2 in [0, 0] +// s0 mod 2 in [0, 1) // The function will compute the following indices [0, 2, 1, 3, 2, 4, 3, 5]. void FindAllIndices(AffineExpr expr, int dim_id, int symbol_id, const std::vector& dimension_ranges, @@ -397,8 +397,8 @@ void FindAllIndices(AffineExpr expr, int dim_id, int symbol_id, // Computes contiguous intervals of accessed elements. 
// For example, for an indexing map // (thread_x) -> (thread_x * 4 + s0 + (thread_x floordiv 16) * 1984) -// d0 in [0, 31] -// s0 in [0, 3] +// d0 in [0, 32) +// s0 in [0, 4) // The intervals are [0, 63] and [2047, 2111]. std::vector FindIntervals( AffineExpr expr, const std::vector& dimension_ranges, diff --git a/third_party/xla/xla/service/gpu/model/indexing_analysis_test.cc b/third_party/xla/xla/service/gpu/model/indexing_analysis_test.cc index 301f6331fb39a0..e5c82bcf0ef833 100644 --- a/third_party/xla/xla/service/gpu/model/indexing_analysis_test.cc +++ b/third_party/xla/xla/service/gpu/model/indexing_analysis_test.cc @@ -65,14 +65,14 @@ TEST_F(IndexingAnalysisTest, FuseProducerConsumerOutputToInputIndexing) { UnorderedElementsAre(Pair(parameter, ElementsAre(MatchIndexingMap(R"( (d0, d1) -> (d0, d1) domain: - d0 in [0, 999] - d1 in [0, 999] + d0 in [0, 1000) + d1 in [0, 1000) )"))), Pair(transpose, ElementsAre(MatchIndexingMap(R"( (d0, d1) -> (d0, d1) domain: - d0 in [0, 999] - d1 in [0, 999] + d0 in [0, 1000) + d1 in [0, 1000) )"))))); } @@ -97,26 +97,26 @@ TEST_F(IndexingAnalysisTest, ComputeGroupedOutputToInputIndexing) { Pair(root, ElementsAre(MatchIndexingMap(R"( (d0, d1) -> (d0, d1) domain: - d0 in [0, 999] - d1 in [0, 999] + d0 in [0, 1000) + d1 in [0, 1000) )"))), Pair(transpose, ElementsAre(MatchIndexingMap(R"( (d0, d1) -> (d0, d1) domain: - d0 in [0, 999] - d1 in [0, 999] + d0 in [0, 1000) + d1 in [0, 1000) )"))), Pair(parameter, UnorderedElementsAre(MatchIndexingMap(R"( (d0, d1) -> (d0, d1) domain: - d0 in [0, 999] - d1 in [0, 999] + d0 in [0, 1000) + d1 in [0, 1000) )"), MatchIndexingMap(R"( (d0, d1) -> (d1, d0) domain: - d0 in [0, 999] - d1 in [0, 999] + d0 in [0, 1000) + d1 in [0, 1000) )"))))); } @@ -155,29 +155,29 @@ TEST_F(IndexingAnalysisTest, Pair(root, ElementsAre(MatchIndexingMap(R"( (d0) -> (d0) domain: - d0 in [0, 31] + d0 in [0, 32) )"))), Pair(root->operand(0), ElementsAre(MatchIndexingMap(R"( (d0)[s0] -> (d0, s0) domain: - d0 in [0, 
31] - s0 in [0, 39] + d0 in [0, 32) + s0 in [0, 40) )"))), Pair(root->operand(1), ElementsAre(MatchIndexingMap(R"( (d0)[s0] -> (d0, s0) domain: - d0 in [0, 31] - s0 in [0, 39] + d0 in [0, 32) + s0 in [0, 40) )"))), Pair(root->operand(2), ElementsAre(MatchIndexingMap(R"( (d0) -> () domain: - d0 in [0, 31] + d0 in [0, 32) )"))), Pair(root->operand(3), ElementsAre(MatchIndexingMap(R"( (d0) -> () domain: - d0 in [0, 31] + d0 in [0, 32) )"))))); } @@ -206,8 +206,8 @@ TEST_F(IndexingAnalysisTest, ComputeGroupedOutputToInputIndexing_SingleOp) { parameter, ElementsAre(MatchIndexingMap(R"( (d0, d1) -> (d0, d1) domain: - d0 in [0, 999] - d1 in [0, 999] + d0 in [0, 1000) + d1 in [0, 1000) )"))))); } @@ -248,18 +248,18 @@ TEST_F(IndexingAnalysisTest, Pair(&bcast.instruction(), ElementsAre(MatchIndexingMap(R"( (d0, d1, d2, d3) -> (d0, d1, d2, d3) domain: - d0 in [0, 14] - d1 in [0, 31] - d2 in [0, 19] - d3 in [0, 63] + d0 in [0, 15) + d1 in [0, 32) + d2 in [0, 20) + d3 in [0, 64) )"))), Pair(¶meter_0.instruction(), ElementsAre(MatchIndexingMap(R"( (d0, d1, d2, d3) -> (d0, d2) domain: - d0 in [0, 14] - d1 in [0, 31] - d2 in [0, 19] - d3 in [0, 63] + d0 in [0, 15) + d1 in [0, 32) + d2 in [0, 20) + d3 in [0, 64) )"))))); } @@ -277,9 +277,9 @@ TEST_F(IndexingAnalysisTest, PhysicalLayoutTestOutputPermutation) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1, d2) -> (d1, d2, d0) domain: - d0 in [0, 29] - d1 in [0, 9] - d2 in [0, 19] + d0 in [0, 30) + d1 in [0, 10) + d2 in [0, 20) )")))); auto output_indexing = GetInputToOutputIndexing(root, /*input_id=*/0, @@ -288,9 +288,9 @@ TEST_F(IndexingAnalysisTest, PhysicalLayoutTestOutputPermutation) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1, d2) -> (d2, d0, d1) domain: - d0 in [0, 9] - d1 in [0, 19] - d2 in [0, 29] + d0 in [0, 10) + d1 in [0, 20) + d2 in [0, 30) )")))); } @@ -351,9 +351,9 @@ TEST_F(IndexingAnalysisTest, PhysicalLayoutTestInputPermutation) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1, d2) -> (d2, 
d0, d1) domain: - d0 in [0, 9] - d1 in [0, 19] - d2 in [0, 29] + d0 in [0, 10) + d1 in [0, 20) + d2 in [0, 30) )")))); auto output_indexing = GetInputToOutputIndexing(root, /*input_id=*/0, @@ -362,9 +362,9 @@ TEST_F(IndexingAnalysisTest, PhysicalLayoutTestInputPermutation) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1, d2) -> (d1, d2, d0) domain: - d0 in [0, 29] - d1 in [0, 9] - d2 in [0, 19] + d0 in [0, 30) + d1 in [0, 10) + d2 in [0, 20) )")))); } @@ -382,9 +382,9 @@ TEST_F(IndexingAnalysisTest, PhysicalLayoutTestInputAndOutputPermutation) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1, d2) -> (d0, d1, d2) domain: - d0 in [0, 29] - d1 in [0, 9] - d2 in [0, 19] + d0 in [0, 30) + d1 in [0, 10) + d2 in [0, 20) )")))); auto output_indexing = GetInputToOutputIndexing(root, /*input_id=*/0, @@ -393,9 +393,9 @@ TEST_F(IndexingAnalysisTest, PhysicalLayoutTestInputAndOutputPermutation) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1, d2) -> (d0, d1, d2) domain: - d0 in [0, 29] - d1 in [0, 9] - d2 in [0, 19] + d0 in [0, 30) + d1 in [0, 10) + d2 in [0, 20) )")))); } @@ -413,14 +413,14 @@ TEST_F(IndexingAnalysisTest, ElementwiseOp) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1) -> (d0, d1) domain: - d0 in [0, 9] - d1 in [0, 19] + d0 in [0, 10) + d1 in [0, 20) )")), ElementsAre(MatchIndexingMap(R"( (d0, d1) -> (d0, d1) domain: - d0 in [0, 9] - d1 in [0, 19] + d0 in [0, 10) + d1 in [0, 20) )")))); auto output_indexing_0 = GetInputToOutputIndexing(root, /*input_id=*/0); @@ -428,8 +428,8 @@ TEST_F(IndexingAnalysisTest, ElementwiseOp) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1) -> (d0, d1) domain: - d0 in [0, 9] - d1 in [0, 19] + d0 in [0, 10) + d1 in [0, 20) )")))); auto output_indexing_1 = GetInputToOutputIndexing(root, /*input_id=*/1); @@ -437,8 +437,8 @@ TEST_F(IndexingAnalysisTest, ElementwiseOp) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1) -> (d0, d1) domain: - d0 in [0, 9] - d1 in [0, 19] + d0 in [0, 10) + d1 in [0, 
20) )")))); } @@ -461,14 +461,14 @@ TEST_F(IndexingAnalysisTest, Map) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1) -> (d0, d1) domain: - d0 in [0, 9] - d1 in [0, 19] + d0 in [0, 10) + d1 in [0, 20) )")), ElementsAre(MatchIndexingMap(R"( (d0, d1) -> (d0, d1) domain: - d0 in [0, 9] - d1 in [0, 19] + d0 in [0, 10) + d1 in [0, 20) )")))); auto output_indexing_0 = GetInputToOutputIndexing(root, /*input_id=*/0); @@ -476,8 +476,8 @@ TEST_F(IndexingAnalysisTest, Map) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1) -> (d0, d1) domain: - d0 in [0, 9] - d1 in [0, 19] + d0 in [0, 10) + d1 in [0, 20) )")))); auto output_indexing_1 = GetInputToOutputIndexing(root, /*input_id=*/1); @@ -485,8 +485,8 @@ TEST_F(IndexingAnalysisTest, Map) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1) -> (d0, d1) domain: - d0 in [0, 9] - d1 in [0, 19] + d0 in [0, 10) + d1 in [0, 20) )")))); } @@ -502,9 +502,9 @@ TEST_F(IndexingAnalysisTest, BitcastIsReshape) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1, d2) -> (d0, d1 * 4 + d2) domain: - d0 in [0, 3] - d1 in [0, 7] - d2 in [0, 3] + d0 in [0, 4) + d1 in [0, 8) + d2 in [0, 4) )")))); } @@ -520,10 +520,10 @@ TEST_F(IndexingAnalysisTest, BitcastIsTranspose) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1, d2, d3) -> (d0, d3, d1, d2) domain: - d0 in [0, 2] - d1 in [0, 5] - d2 in [0, 127] - d3 in [0, 12287] + d0 in [0, 3) + d1 in [0, 6) + d2 in [0, 128) + d3 in [0, 12288) )")))); } @@ -540,17 +540,17 @@ TEST_F(IndexingAnalysisTest, BitcastIsTransposeReshapeTranspose) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1) -> (d1, d0 floordiv 3, d0 mod 3) domain: - d0 in [0, 50] - d1 in [0, 15] + d0 in [0, 51) + d1 in [0, 16) )")))); auto output_indexing = GetInputToOutputIndexing(root); EXPECT_THAT(output_indexing.indexing_maps, ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1, d2) -> (d1 * 3 + d2, d0) domain: - d0 in [0, 15] - d1 in [0, 16] - d2 in [0, 2] + d0 in [0, 16) + d1 in [0, 17) + d2 in [0, 3) 
)")))); } @@ -567,9 +567,9 @@ TEST_F(IndexingAnalysisTest, BroadcastOp) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1, d2) -> (d1) domain: - d0 in [0, 9] - d1 in [0, 19] - d2 in [0, 29] + d0 in [0, 10) + d1 in [0, 20) + d2 in [0, 30) )")))); auto output_indexing = GetInputToOutputIndexing(root); @@ -577,9 +577,9 @@ TEST_F(IndexingAnalysisTest, BroadcastOp) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0)[s0, s1] -> (s0, d0, s1) domain: - d0 in [0, 19] - s0 in [0, 9] - s1 in [0, 29] + d0 in [0, 20) + s0 in [0, 10) + s1 in [0, 30) )")))); } @@ -610,23 +610,23 @@ TEST_F(IndexingAnalysisTest, ConcatenateOp) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1, d2) -> (d0, d1, d2) domain: - d0 in [0, 1] - d1 in [0, 4] - d2 in [0, 6] + d0 in [0, 2) + d1 in [0, 5) + d2 in [0, 7) )")), ElementsAre(MatchIndexingMap(R"( (d0, d1, d2) -> (d0, d1 - 5, d2) domain: - d0 in [0, 1] - d1 in [5, 15] - d2 in [0, 6] + d0 in [0, 2) + d1 in [5, 16) + d2 in [0, 7) )")), ElementsAre(MatchIndexingMap(R"( (d0, d1, d2) -> (d0, d1 - 16, d2) domain: - d0 in [0, 1] - d1 in [16, 32] - d2 in [0, 6] + d0 in [0, 2) + d1 in [16, 33) + d2 in [0, 7) )")))); auto output_indexing_0 = GetInputToOutputIndexing(root, /*input_id=*/0); @@ -634,9 +634,9 @@ TEST_F(IndexingAnalysisTest, ConcatenateOp) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1, d2) -> (d0, d1, d2) domain: - d0 in [0, 1] - d1 in [0, 4] - d2 in [0, 6] + d0 in [0, 2) + d1 in [0, 5) + d2 in [0, 7) )")))); auto output_indexing_1 = GetInputToOutputIndexing(root, /*input_id=*/1); @@ -644,9 +644,9 @@ TEST_F(IndexingAnalysisTest, ConcatenateOp) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1, d2) -> (d0, d1 + 5, d2) domain: - d0 in [0, 1] - d1 in [0, 10] - d2 in [0, 6] + d0 in [0, 2) + d1 in [0, 11) + d2 in [0, 7) )")))); auto output_indexing_2 = GetInputToOutputIndexing(root, /*input_id=*/2); @@ -654,9 +654,9 @@ TEST_F(IndexingAnalysisTest, ConcatenateOp) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1, d2) -> 
(d0, d1 + 16, d2) domain: - d0 in [0, 1] - d1 in [0, 16] - d2 in [0, 6] + d0 in [0, 2) + d1 in [0, 17) + d2 in [0, 7) )")))); } @@ -677,39 +677,39 @@ TEST_F(IndexingAnalysisTest, DynamicSliceOp) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1, d2)[s0, s1, s2] -> (d0 + s0, d1 + s1, d2 + s2) domain: - d0 in [0, 0] - d1 in [0, 1] - d2 in [0, 31] - s0 in [0, 1] + d0 in [0, 1) + d1 in [0, 2) + d2 in [0, 32) + s0 in [0, 2) hlo: %of1 = s32[] parameter(1) (d0, d1, d2) -> () - s1 in [0, 0] + s1 in [0, 1) hlo: %of2 = s32[] parameter(2) (d0, d1, d2) -> () - s2 in [0, 226] + s2 in [0, 227) hlo: %of3 = s32[] parameter(3) (d0, d1, d2) -> () )")), ElementsAre(MatchIndexingMap(R"( (d0, d1, d2) -> () domain: - d0 in [0, 0] - d1 in [0, 1] - d2 in [0, 31] + d0 in [0, 1) + d1 in [0, 2) + d2 in [0, 32) )")), ElementsAre(MatchIndexingMap(R"( (d0, d1, d2) -> () domain: - d0 in [0, 0] - d1 in [0, 1] - d2 in [0, 31] + d0 in [0, 1) + d1 in [0, 2) + d2 in [0, 32) )")), ElementsAre(MatchIndexingMap(R"( (d0, d1, d2) -> () domain: - d0 in [0, 0] - d1 in [0, 1] - d2 in [0, 31] + d0 in [0, 1) + d1 in [0, 2) + d2 in [0, 32) )")))); } @@ -729,32 +729,32 @@ TEST_F(IndexingAnalysisTest, DynamicUpdateSliceOp) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1) -> (d0, d1) domain: - d0 in [0, 19] - d1 in [0, 29] + d0 in [0, 20) + d1 in [0, 30) )")), ElementsAre(MatchIndexingMap(R"( (d0, d1)[s0, s1] -> (d0 - s0, d1 - s1) domain: - d0 in [0, 19] - d1 in [0, 29] - s0 in [0, 15] + d0 in [0, 20) + d1 in [0, 30) + s0 in [0, 16) hlo: %of1 = s32[] parameter(2) (d0, d1) -> () - s1 in [0, 20] + s1 in [0, 21) hlo: %of2 = s32[] parameter(3) (d0, d1) -> () )")), ElementsAre(MatchIndexingMap(R"( (d0, d1) -> () domain: - d0 in [0, 19] - d1 in [0, 29] + d0 in [0, 20) + d1 in [0, 30) )")), ElementsAre(MatchIndexingMap(R"( (d0, d1) -> () domain: - d0 in [0, 19] - d1 in [0, 29] + d0 in [0, 20) + d1 in [0, 30) )")))); } @@ -776,12 +776,12 @@ TEST_F(IndexingAnalysisTest, FusionOpWithSingleBinaryOp) { 
ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0) -> (d0) domain: - d0 in [0, 99] + d0 in [0, 100) )")), ElementsAre(MatchIndexingMap(R"( (d0) -> (d0) domain: - d0 in [0, 99] + d0 in [0, 100) )")))); } @@ -849,66 +849,66 @@ TEST_F(IndexingAnalysisTest, FusionOpWithDot) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1, d2, d3, d4, d5)[s0] -> (d2, d0 * 768 + s0, d4, d5) domain: - d0 in [0, 15] - d1 in [0, 15] - d2 in [0, 2] - d3 in [0, 0] - d4 in [0, 5] - d5 in [0, 127] - s0 in [0, 767] + d0 in [0, 16) + d1 in [0, 16) + d2 in [0, 3) + d3 in [0, 1) + d4 in [0, 6) + d5 in [0, 128) + s0 in [0, 768) )")), ElementsAre(MatchIndexingMap(R"( (d0, d1, d2, d3, d4, d5)[s0] -> (d0 * 768 + s0) domain: - d0 in [0, 15] - d1 in [0, 15] - d2 in [0, 2] - d3 in [0, 0] - d4 in [0, 5] - d5 in [0, 127] - s0 in [0, 767] + d0 in [0, 16) + d1 in [0, 16) + d2 in [0, 3) + d3 in [0, 1) + d4 in [0, 6) + d5 in [0, 128) + s0 in [0, 768) )")), ElementsAre(MatchIndexingMap(R"( (d0, d1, d2, d3, d4, d5) -> (d1) domain: - d0 in [0, 15] - d1 in [0, 15] - d2 in [0, 2] - d3 in [0, 0] - d4 in [0, 5] - d5 in [0, 127] + d0 in [0, 16) + d1 in [0, 16) + d2 in [0, 3) + d3 in [0, 1) + d4 in [0, 6) + d5 in [0, 128) )")), ElementsAre(MatchIndexingMap(R"( (d0, d1, d2, d3, d4, d5)[s0] -> (d1, d0 * 768 + s0) domain: - d0 in [0, 15] - d1 in [0, 15] - d2 in [0, 2] - d3 in [0, 0] - d4 in [0, 5] - d5 in [0, 127] - s0 in [0, 767] + d0 in [0, 16) + d1 in [0, 16) + d2 in [0, 3) + d3 in [0, 1) + d4 in [0, 6) + d5 in [0, 128) + s0 in [0, 768) )")), ElementsAre(MatchIndexingMap(R"( (d0, d1, d2, d3, d4, d5)[s0] -> (d1, d0 * 768 + s0) domain: - d0 in [0, 15] - d1 in [0, 15] - d2 in [0, 2] - d3 in [0, 0] - d4 in [0, 5] - d5 in [0, 127] - s0 in [0, 767] + d0 in [0, 16) + d1 in [0, 16) + d2 in [0, 3) + d3 in [0, 1) + d4 in [0, 6) + d5 in [0, 128) + s0 in [0, 768) )")), ElementsAre(MatchIndexingMap(R"( (d0, d1, d2, d3, d4, d5) -> (d2, d4, d5) domain: - d0 in [0, 15] - d1 in [0, 15] - d2 in [0, 2] - d3 in [0, 0] - d4 in [0, 5] 
- d5 in [0, 127] + d0 in [0, 16) + d1 in [0, 16) + d2 in [0, 3) + d3 in [0, 1) + d4 in [0, 6) + d5 in [0, 128) )")))); } @@ -962,17 +962,17 @@ TEST_F(IndexingAnalysisTest, FusionOpWithSoftmax) { ElementsAre(UnorderedElementsAre(MatchIndexingMap(R"( (d0, d1, d2)[s0] -> (d0, d1, s0) domain: - d0 in [0, 1] - d1 in [0, 64] - d2 in [0, 124] - s0 in [0, 124] + d0 in [0, 2) + d1 in [0, 65) + d2 in [0, 125) + s0 in [0, 125) )"), MatchIndexingMap(R"( (d0, d1, d2) -> (d0, d1, d2) domain: - d0 in [0, 1] - d1 in [0, 64] - d2 in [0, 124] + d0 in [0, 2) + d1 in [0, 65) + d2 in [0, 125) )")))); } @@ -993,14 +993,14 @@ TEST_F(IndexingAnalysisTest, FusionOpTensorPlusTransposedTensor) { ElementsAre(UnorderedElementsAre(MatchIndexingMap(R"( (d0, d1) -> (d0, d1) domain: - d0 in [0, 999] - d1 in [0, 999] + d0 in [0, 1000) + d1 in [0, 1000) )"), MatchIndexingMap(R"( (d0, d1) -> (d1, d0) domain: - d0 in [0, 999] - d1 in [0, 999] + d0 in [0, 1000) + d1 in [0, 1000) )")))); } @@ -1030,32 +1030,32 @@ TEST_F(IndexingAnalysisTest, FusionExponentialDuplication) { ElementsAre(UnorderedElementsAre(MatchIndexingMap(R"( (d0) -> (d0 + 1) domain: - d0 in [0, 1] + d0 in [0, 2) )"), MatchIndexingMap(R"( (d0) -> (d0) domain: - d0 in [0, 1] + d0 in [0, 2) )"), MatchIndexingMap(R"( (d0) -> (d0 + 2) domain: - d0 in [0, 1] + d0 in [0, 2) )")), UnorderedElementsAre(MatchIndexingMap(R"( (d0) -> (d0 + 2) domain: - d0 in [0, 1] + d0 in [0, 2) )"), MatchIndexingMap(R"( (d0) -> (d0 + 1) domain: - d0 in [0, 1] + d0 in [0, 2) )"), MatchIndexingMap(R"( (d0) -> (d0) domain: - d0 in [0, 1] + d0 in [0, 2) )")))); } @@ -1074,25 +1074,25 @@ TEST_F(IndexingAnalysisTest, GatherOp) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1, d2, d3)[s0, s1] -> (d1 + s0, d2 + s1, d3) domain: - d0 in [0, 1805] - d1 in [0, 6] - d2 in [0, 7] - d3 in [0, 3] - s0 in [0, 26] + d0 in [0, 1806) + d1 in [0, 7) + d2 in [0, 8) + d3 in [0, 4) + s0 in [0, 27) hlo: %indices = s32[1806,2]{1,0} parameter(1) (d0, d1, d2, d3) -> (d0, 0) - s1 in 
[0, 68] + s1 in [0, 69) hlo: %indices = s32[1806,2]{1,0} parameter(1) (d0, d1, d2, d3) -> (d0, 1) )")), ElementsAre(MatchIndexingMap(R"( (d0, d1, d2, d3)[s0] -> (d0, s0) domain: - d0 in [0, 1805] - d1 in [0, 6] - d2 in [0, 7] - d3 in [0, 3] - s0 in [0, 1] + d0 in [0, 1806) + d1 in [0, 7) + d2 in [0, 8) + d3 in [0, 4) + s0 in [0, 2) )")))); } @@ -1122,15 +1122,15 @@ TEST_F(IndexingAnalysisTest, FusionOpWithReduceOfReduce) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0)[s0, s1, s2] -> (s0, s2, d0, s1) domain: - d0 in [0, 9] - s0 in [0, 149] - s1 in [0, 49] - s2 in [0, 19] + d0 in [0, 10) + s0 in [0, 150) + s1 in [0, 50) + s2 in [0, 20) )")), ElementsAre(MatchIndexingMap(R"( (d0) -> () domain: - d0 in [0, 9] + d0 in [0, 10) )")))); } @@ -1160,15 +1160,15 @@ TEST_F(IndexingAnalysisTest, FusionOpWithReduceOfBroadcast) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1)[s0] -> (d0, s0) domain: - d0 in [0, 14] - d1 in [0, 63] - s0 in [0, 19] + d0 in [0, 15) + d1 in [0, 64) + s0 in [0, 20) )")), ElementsAre(MatchIndexingMap(R"( (d0, d1) -> () domain: - d0 in [0, 14] - d1 in [0, 63] + d0 in [0, 15) + d1 in [0, 64) )")))); } @@ -1201,9 +1201,9 @@ TEST_F(IndexingAnalysisTest, FusionOpWithTransposeOfTranspose) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1, d2) -> (d2, d0, d1) domain: - d0 in [0, 9] - d1 in [0, 49] - d2 in [0, 19] + d0 in [0, 10) + d1 in [0, 50) + d2 in [0, 20) )")))); } @@ -1233,14 +1233,14 @@ TEST_F(IndexingAnalysisTest, FusionOpWithReducedSlice) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0)[s0, s1] -> (s0 + 5, d0 * 2, s1 * 3 + 50) domain: - d0 in [0, 31] - s0 in [0, 15] - s1 in [0, 127] + d0 in [0, 32) + s0 in [0, 16) + s1 in [0, 128) )")), ElementsAre(MatchIndexingMap(R"( (d0) -> () domain: - d0 in [0, 31] + d0 in [0, 32) )")))); } @@ -1261,7 +1261,7 @@ TEST_F(IndexingAnalysisTest, FusionOpWithReshape_CollapseOfExpand) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0) -> (d0) domain: - d0 in [0, 127] + d0 in [0, 128) )")))); } @@ 
-1282,8 +1282,8 @@ TEST_F(IndexingAnalysisTest, FusionOpWithReshape_ExpandOfCollapse) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1) -> (d0, d1) domain: - d0 in [0, 7] - d1 in [0, 15] + d0 in [0, 8) + d1 in [0, 16) )")))); } @@ -1304,9 +1304,9 @@ TEST_F(IndexingAnalysisTest, FusionOpWithReshape_ChainedGenericReshapes) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1, d2) -> (d0, d1, d2) domain: - d0 in [0, 9] - d1 in [0, 9] - d2 in [0, 9] + d0 in [0, 10) + d1 in [0, 10) + d2 in [0, 10) )")))); } @@ -1331,9 +1331,9 @@ TEST_F(IndexingAnalysisTest, FusionOpWithSliceOfSlice) { d1 * 6 + 8, d2 * 12 + 65) domain: - d0 in [0, 6] - d1 in [0, 8] - d2 in [0, 23] + d0 in [0, 7) + d1 in [0, 9) + d2 in [0, 24) )")))); } @@ -1367,44 +1367,44 @@ TEST_F(IndexingAnalysisTest, FusionOpWithDynSliceOfDynSlice) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1)[s0, s1, s2, s3] -> (d0 + s0 + s2, d1 + s1 + s3) domain: - d0 in [0, 24] - d1 in [0, 15] - s0 in [0, 100] + d0 in [0, 25) + d1 in [0, 16) + s0 in [0, 101) hlo: %of11 = s32[] parameter(1) (d0, d1) -> () - s1 in [0, 32] + s1 in [0, 33) hlo: %of12 = s32[] parameter(2) (d0, d1) -> () - s2 in [0, 25] + s2 in [0, 26) hlo: %of21 = s32[] parameter(3) (d0, d1) -> () - s3 in [0, 16] + s3 in [0, 17) hlo: %of22 = s32[] parameter(4) (d0, d1) -> () )")), ElementsAre(MatchIndexingMap(R"( (d0, d1) -> () domain: - d0 in [0, 24] - d1 in [0, 15] + d0 in [0, 25) + d1 in [0, 16) )")), ElementsAre(MatchIndexingMap(R"( (d0, d1) -> () domain: - d0 in [0, 24] - d1 in [0, 15] + d0 in [0, 25) + d1 in [0, 16) )")), ElementsAre(MatchIndexingMap(R"( (d0, d1) -> () domain: - d0 in [0, 24] - d1 in [0, 15] + d0 in [0, 25) + d1 in [0, 16) )")), ElementsAre(MatchIndexingMap(R"( (d0, d1) -> () domain: - d0 in [0, 24] - d1 in [0, 15] + d0 in [0, 25) + d1 in [0, 16) )")))); } @@ -1431,23 +1431,23 @@ TEST_F(IndexingAnalysisTest, FusionOpSliceOfAllConcatenateOpInputs) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1, d2) -> (d0, d1 * 3, d2) 
domain: - d0 in [0, 1] - d1 in [0, 1] - d2 in [0, 6] + d0 in [0, 2) + d1 in [0, 2) + d2 in [0, 7) )")), ElementsAre(MatchIndexingMap(R"( (d0, d1, d2) -> (d0, d1 * 3 - 5, d2) domain: - d0 in [0, 1] - d1 in [2, 5] - d2 in [0, 6] + d0 in [0, 2) + d1 in [2, 6) + d2 in [0, 7) )")), ElementsAre(MatchIndexingMap(R"( (d0, d1, d2) -> (d0, d1 * 3 - 16, d2) domain: - d0 in [0, 1] - d1 in [6, 10] - d2 in [0, 6] + d0 in [0, 2) + d1 in [6, 11) + d2 in [0, 7) )")))); } @@ -1474,9 +1474,9 @@ TEST_F(IndexingAnalysisTest, FusionOpSliceOfOneOfConcatenateOpInputs) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1, d2) -> (d0, d1 * 2, d2) domain: - d0 in [0, 1] - d1 in [0, 2] - d2 in [0, 6] + d0 in [0, 2) + d1 in [0, 3) + d2 in [0, 7) )")), ElementsAre(MatchIndexingMap("KNOWN EMPTY")), ElementsAre(MatchIndexingMap("KNOWN EMPTY")))); @@ -1501,16 +1501,16 @@ TEST_F(IndexingAnalysisTest, FusionOpReshapeOfConcat) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1) -> (d0 * 8 + d1) domain: - d0 in [0, 3] - d1 in [0, 7] - d0 * 8 + d1 in [0, 1] + d0 in [0, 4) + d1 in [0, 8) + d0 * 8 + d1 in [0, 2) )")), ElementsAre(MatchIndexingMap(R"( (d0, d1) -> (d0 * 8 + d1 - 2) domain: - d0 in [0, 3] - d1 in [0, 7] - d0 * 8 + d1 in [2, 31] + d0 in [0, 4) + d1 in [0, 8) + d0 * 8 + d1 in [2, 32) )")))); } @@ -1537,7 +1537,7 @@ TEST_F(IndexingAnalysisTest, ReshapeOpCollapseShape) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0) -> (d0 floordiv 8, d0 mod 8) domain: - d0 in [0, 31] + d0 in [0, 32) )")))); } @@ -1553,8 +1553,8 @@ TEST_F(IndexingAnalysisTest, ReshapeOpExpandShape) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1) -> (d0 * 8 + d1) domain: - d0 in [0, 3] - d1 in [0, 7] + d0 in [0, 4) + d1 in [0, 8) )")))); } @@ -1571,9 +1571,9 @@ TEST_F(IndexingAnalysisTest, ReshapeOpExpandAndCollapseShape) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1, d2) -> (d0 floordiv 8, d0 mod 8, d1 * 4 + d2) domain: - d0 in [0, 31] - d1 in [0, 2] - d2 in [0, 3] + d0 in [0, 32) + d1 in [0, 3) + d2 
in [0, 4) )")))); auto output_indexing = GetInputToOutputIndexing(root); @@ -1581,9 +1581,9 @@ TEST_F(IndexingAnalysisTest, ReshapeOpExpandAndCollapseShape) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1, d2) -> (d0 * 8 + d1, d2 floordiv 4, d2 mod 4) domain: - d0 in [0, 3] - d1 in [0, 7] - d2 in [0, 11] + d0 in [0, 4) + d1 in [0, 8) + d2 in [0, 12) )")))); } @@ -1599,9 +1599,9 @@ TEST_F(IndexingAnalysisTest, ReshapeOpExpandSubshapeOnly) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1, d2) -> (d0 * 4 + d1, d2) domain: - d0 in [0, 3] - d1 in [0, 3] - d2 in [0, 7] + d0 in [0, 4) + d1 in [0, 4) + d2 in [0, 8) )")))); } @@ -1618,9 +1618,9 @@ TEST_F(IndexingAnalysisTest, ReshapeOpGenericReshape2DTo3D) { (d0, d1, d2) -> (d0 * 2 + d1 floordiv 2, (d1 mod 2) * 4 + d2) domain: - d0 in [0, 1] - d1 in [0, 3] - d2 in [0, 3] + d0 in [0, 2) + d1 in [0, 4) + d2 in [0, 4) )")))); } @@ -1638,8 +1638,8 @@ TEST_F(IndexingAnalysisTest, ReshapeOpGenericReshape3DTo2D) { (d0 mod 2) * 2 + d1 floordiv 4, d1 mod 4) domain: - d0 in [0, 3] - d1 in [0, 7] + d0 in [0, 4) + d1 in [0, 8) )")))); } @@ -1659,15 +1659,15 @@ TEST_F(IndexingAnalysisTest, PadOp) { d1 - 4 ) domain: - d0 in [1, 7] - d1 in [4, 7] - (d0 - 1) mod 2 in [0, 0] + d0 in [1, 8) + d1 in [4, 8) + (d0 - 1) mod 2 in [0, 1) )")), ElementsAre(MatchIndexingMap(R"( (d0, d1) -> () domain: - d0 in [0, 11] - d1 in [0, 15] + d0 in [0, 12) + d1 in [0, 16) )")))); } @@ -1684,14 +1684,14 @@ TEST_F(IndexingAnalysisTest, PadOpNoInterior) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1) -> (d0 - 1, d1) domain: - d0 in [1, 2] - d1 in [0, 7] + d0 in [1, 3) + d1 in [0, 8) )")), ElementsAre(MatchIndexingMap(R"( (d0, d1) -> () domain: - d0 in [0, 9] - d1 in [0, 7] + d0 in [0, 10) + d1 in [0, 8) )")))); } @@ -1713,13 +1713,13 @@ TEST_F(IndexingAnalysisTest, PadOpNegativePadding) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0) -> ((d0 + 3) floordiv 2) domain: - d0 in [0, 4] - (d0 + 3) mod 2 in [0, 0] + d0 in [0, 5) + (d0 + 3) mod 
2 in [0, 1) )")), ElementsAre(MatchIndexingMap(R"( (d0) -> () domain: - d0 in [0, 4] + d0 in [0, 5) )")))); } @@ -1743,16 +1743,16 @@ TEST_F(IndexingAnalysisTest, ReduceOp) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1)[s0, s1] -> (d0, s0, d1, s1) domain: - d0 in [0, 149] - d1 in [0, 9] - s0 in [0, 19] - s1 in [0, 49] + d0 in [0, 150) + d1 in [0, 10) + s0 in [0, 20) + s1 in [0, 50) )")), ElementsAre(MatchIndexingMap(R"( (d0, d1) -> () domain: - d0 in [0, 149] - d1 in [0, 9] + d0 in [0, 150) + d1 in [0, 10) )")))); auto output_indexing_0 = GetInputToOutputIndexing(root, 0); @@ -1760,18 +1760,18 @@ TEST_F(IndexingAnalysisTest, ReduceOp) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1, d2, d3) -> (d0, d2) domain: - d0 in [0, 149] - d1 in [0, 19] - d2 in [0, 9] - d3 in [0, 49] + d0 in [0, 150) + d1 in [0, 20) + d2 in [0, 10) + d3 in [0, 50) )")))); auto output_indexing_1 = GetInputToOutputIndexing(root, 1); EXPECT_THAT(output_indexing_1.indexing_maps, ElementsAre(ElementsAre(MatchIndexingMap(R"( ()[s0, s1] -> (s0, s1) domain: - s0 in [0, 149] - s1 in [0, 9] + s0 in [0, 150) + s1 in [0, 10) )")))); } @@ -1803,26 +1803,26 @@ TEST_F(IndexingAnalysisTest, VariadicReduceOp) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0)[s0] -> (s0, d0) domain: - d0 in [0, 9] - s0 in [0, 255] + d0 in [0, 10) + s0 in [0, 256) )")), ElementsAre(MatchIndexingMap(R"( (d0)[s0] -> (s0, d0) domain: - d0 in [0, 9] - s0 in [0, 255] + d0 in [0, 10) + s0 in [0, 256) )")), ElementsAre(MatchIndexingMap(R"( (d0) -> () domain: - d0 in [0, 9] + d0 in [0, 10) )")), ElementsAre(MatchIndexingMap(R"( (d0) -> () domain: - d0 in [0, 9] + d0 in [0, 10) )")))); auto output_indexing_1 = GetOutputToInputIndexing(root, /*output_id=*/1); @@ -1830,31 +1830,31 @@ TEST_F(IndexingAnalysisTest, VariadicReduceOp) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0)[s0] -> (s0, d0) domain: - d0 in [0, 9] - s0 in [0, 255] + d0 in [0, 10) + s0 in [0, 256) )")), ElementsAre(MatchIndexingMap(R"( (d0)[s0] -> 
(s0, d0) domain: - d0 in [0, 9] - s0 in [0, 255] + d0 in [0, 10) + s0 in [0, 256) )")), ElementsAre(MatchIndexingMap(R"( (d0) -> () domain: - d0 in [0, 9] + d0 in [0, 10) )")), ElementsAre(MatchIndexingMap(R"( (d0) -> () domain: - d0 in [0, 9] + d0 in [0, 10) )")))); constexpr std::string_view kInputToOutputIndexing = R"( (d0, d1) -> (d1) domain: - d0 in [0, 255] - d1 in [0, 9] + d0 in [0, 256) + d1 in [0, 10) )"; auto input_indexing_0 = GetInputToOutputIndexing(root, /*input_id=*/0); EXPECT_THAT( @@ -1871,7 +1871,7 @@ TEST_F(IndexingAnalysisTest, VariadicReduceOp) { constexpr std::string_view kInitToOutputIndexing = R"( ()[s0] -> (s0) domain: - s0 in [0, 9] + s0 in [0, 10) )"; auto input_indexing_2 = GetInputToOutputIndexing(root, /*input_id=*/2); EXPECT_THAT( @@ -1905,15 +1905,15 @@ TEST_F(IndexingAnalysisTest, ReduceWindowOp_NoPadding) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1)[s0] -> (d0, d1 + s0) domain: - d0 in [0, 1023] - d1 in [0, 2] - s0 in [0, 511] + d0 in [0, 1024) + d1 in [0, 3) + s0 in [0, 512) )")), ElementsAre(MatchIndexingMap(R"( (d0, d1) -> () domain: - d0 in [0, 1023] - d1 in [0, 2] + d0 in [0, 1024) + d1 in [0, 3) )")))); } @@ -1937,18 +1937,18 @@ TEST_F(IndexingAnalysisTest, ReduceWindowOp_PaddingAndWindowStride) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1)[s0, s1] -> (d0 * 2 + s0 - 1, d1 + s1) domain: - d0 in [0, 6] - d1 in [0, 16] - s0 in [0, 2] - s1 in [0, 1] - d0 * 2 + s0 in [1, 13] - d1 + s1 in [0, 16] + d0 in [0, 7) + d1 in [0, 17) + s0 in [0, 3) + s1 in [0, 2) + d0 * 2 + s0 in [1, 14) + d1 + s1 in [0, 17) )")), ElementsAre(MatchIndexingMap(R"( (d0, d1) -> () domain: - d0 in [0, 6] - d1 in [0, 16] + d0 in [0, 7) + d1 in [0, 17) )")))); } @@ -1972,16 +1972,16 @@ TEST_F(IndexingAnalysisTest, ReduceWindowOp_BaseDilation) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1) -> (d0 floordiv 2, d1 floordiv 2) domain: - d0 in [0, 2] - d1 in [0, 4] - d0 mod 2 in [0, 0] - d1 mod 2 in [0, 0] + d0 in [0, 3) + d1 in [0, 5) + 
d0 mod 2 in [0, 1) + d1 mod 2 in [0, 1) )")), ElementsAre(MatchIndexingMap(R"( (d0, d1) -> () domain: - d0 in [0, 2] - d1 in [0, 4] + d0 in [0, 3) + d1 in [0, 5) )")))); } @@ -2005,15 +2005,15 @@ TEST_F(IndexingAnalysisTest, ReduceWindowOp_WindowDilation) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1)[s0] -> (d0 + s0 * 3, d1) domain: - d0 in [0, 3] - d1 in [0, 2] - s0 in [0, 1] + d0 in [0, 4) + d1 in [0, 3) + s0 in [0, 2) )")), ElementsAre(MatchIndexingMap(R"( (d0, d1) -> () domain: - d0 in [0, 3] - d1 in [0, 2] + d0 in [0, 4) + d1 in [0, 3) )")))); } @@ -2044,60 +2044,60 @@ TEST_F(IndexingAnalysisTest, ReduceWindowOp_Variadic) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1)[s0, s1] -> (s0, d1 + s1) domain: - d0 in [0, 0] - d1 in [0, 1] - s0 in [0, 1] - s1 in [0, 1] + d0 in [0, 1) + d1 in [0, 2) + s0 in [0, 2) + s1 in [0, 2) )")), ElementsAre(MatchIndexingMap(R"( (d0, d1)[s0, s1] -> (s0, d1 + s1) domain: - d0 in [0, 0] - d1 in [0, 1] - s0 in [0, 1] - s1 in [0, 1] + d0 in [0, 1) + d1 in [0, 2) + s0 in [0, 2) + s1 in [0, 2) )")), ElementsAre(MatchIndexingMap(R"( (d0, d1) -> () domain: - d0 in [0, 0] - d1 in [0, 1] + d0 in [0, 1) + d1 in [0, 2) )")), ElementsAre(MatchIndexingMap(R"( (d0, d1) -> () domain: - d0 in [0, 0] - d1 in [0, 1] + d0 in [0, 1) + d1 in [0, 2) )")))); auto input_indexing_1 = GetOutputToInputIndexing(root, /*output_id=*/1); EXPECT_THAT(input_indexing_1.indexing_maps, ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1)[s0, s1] -> (s0, d1 + s1) domain: - d0 in [0, 0] - d1 in [0, 1] - s0 in [0, 1] - s1 in [0, 1] + d0 in [0, 1) + d1 in [0, 2) + s0 in [0, 2) + s1 in [0, 2) )")), ElementsAre(MatchIndexingMap(R"( (d0, d1)[s0, s1] -> (s0, d1 + s1) domain: - d0 in [0, 0] - d1 in [0, 1] - s0 in [0, 1] - s1 in [0, 1] + d0 in [0, 1) + d1 in [0, 2) + s0 in [0, 2) + s1 in [0, 2) )")), ElementsAre(MatchIndexingMap(R"( (d0, d1) -> () domain: - d0 in [0, 0] - d1 in [0, 1] + d0 in [0, 1) + d1 in [0, 2) )")), ElementsAre(MatchIndexingMap(R"( (d0, 
d1) -> () domain: - d0 in [0, 0] - d1 in [0, 1] + d0 in [0, 1) + d1 in [0, 2) )")))); } @@ -2116,24 +2116,24 @@ TEST_F(IndexingAnalysisTest, ConvolutionOp_NoPadding) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1, d2, d3)[s0, s1, s2] -> (d0, d1 + s0, d2 + s1, s2) domain: - d0 in [0, 0] - d1 in [0, 9] - d2 in [0, 5] - d3 in [0, 7] - s0 in [0, 2] - s1 in [0, 4] - s2 in [0, 3] + d0 in [0, 1) + d1 in [0, 10) + d2 in [0, 6) + d3 in [0, 8) + s0 in [0, 3) + s1 in [0, 5) + s2 in [0, 4) )")), ElementsAre(MatchIndexingMap(R"( (d0, d1, d2, d3)[s0, s1, s2] -> (s2, s0, s1, d3) domain: - d0 in [0, 0] - d1 in [0, 9] - d2 in [0, 5] - d3 in [0, 7] - s0 in [0, 2] - s1 in [0, 4] - s2 in [0, 3] + d0 in [0, 1) + d1 in [0, 10) + d2 in [0, 6) + d3 in [0, 8) + s0 in [0, 3) + s1 in [0, 5) + s2 in [0, 4) )")))); } @@ -2152,26 +2152,26 @@ TEST_F(IndexingAnalysisTest, ConvolutionOp_PaddingAndWindowStride) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1, d2, d3)[s0, s1, s2] -> (d0, d1 * 2 + s0 - 1, d2 * 2 + s1 - 2, s2) domain: - d0 in [0, 0] - d1 in [0, 5] - d2 in [0, 4] - d3 in [0, 7] - s0 in [0, 2] - s1 in [0, 4] - s2 in [0, 3] - d1 * 2 + s0 in [1, 12] - d2 * 2 + s1 in [2, 11] + d0 in [0, 1) + d1 in [0, 6) + d2 in [0, 5) + d3 in [0, 8) + s0 in [0, 3) + s1 in [0, 5) + s2 in [0, 4) + d1 * 2 + s0 in [1, 13) + d2 * 2 + s1 in [2, 12) )")), ElementsAre(MatchIndexingMap(R"( (d0, d1, d2, d3)[s0, s1, s2] -> (s2, s0, s1, d3) domain: - d0 in [0, 0] - d1 in [0, 5] - d2 in [0, 4] - d3 in [0, 7] - s0 in [0, 2] - s1 in [0, 4] - s2 in [0, 3] + d0 in [0, 1) + d1 in [0, 6) + d2 in [0, 5) + d3 in [0, 8) + s0 in [0, 3) + s1 in [0, 5) + s2 in [0, 4) )")))); } @@ -2190,26 +2190,26 @@ TEST_F(IndexingAnalysisTest, ConvolutionOp_LhsDilation) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1, d2, d3)[s0, s1, s2] -> (d0, (d1 + s0) floordiv 2, (d2 + s1) floordiv 2, s2) domain: - d0 in [0, 0] - d1 in [0, 20] - d2 in [0, 14] - d3 in [0, 7] - s0 in [0, 2] - s1 in [0, 4] - s2 in [0, 3] - (d1 + s0) mod 2 
in [0, 0] - (d2 + s1) mod 2 in [0, 0] + d0 in [0, 1) + d1 in [0, 21) + d2 in [0, 15) + d3 in [0, 8) + s0 in [0, 3) + s1 in [0, 5) + s2 in [0, 4) + (d1 + s0) mod 2 in [0, 1) + (d2 + s1) mod 2 in [0, 1) )")), ElementsAre(MatchIndexingMap(R"( (d0, d1, d2, d3)[s0, s1, s2] -> (s2, s0, s1, d3) domain: - d0 in [0, 0] - d1 in [0, 20] - d2 in [0, 14] - d3 in [0, 7] - s0 in [0, 2] - s1 in [0, 4] - s2 in [0, 3] + d0 in [0, 1) + d1 in [0, 21) + d2 in [0, 15) + d3 in [0, 8) + s0 in [0, 3) + s1 in [0, 5) + s2 in [0, 4) )")))); } @@ -2228,24 +2228,24 @@ TEST_F(IndexingAnalysisTest, ConvolutionOp_RhsDilation) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1, d2, d3)[s0, s1, s2] -> (d0, d1 + s0 * 2, d2 + s1 * 2, s2) domain: - d0 in [0, 0] - d1 in [0, 7] - d2 in [0, 1] - d3 in [0, 7] - s0 in [0, 2] - s1 in [0, 4] - s2 in [0, 3] + d0 in [0, 1) + d1 in [0, 8) + d2 in [0, 2) + d3 in [0, 8) + s0 in [0, 3) + s1 in [0, 5) + s2 in [0, 4) )")), ElementsAre(MatchIndexingMap(R"( (d0, d1, d2, d3)[s0, s1, s2] -> (s2, s0, s1, d3) domain: - d0 in [0, 0] - d1 in [0, 7] - d2 in [0, 1] - d3 in [0, 7] - s0 in [0, 2] - s1 in [0, 4] - s2 in [0, 3] + d0 in [0, 1) + d1 in [0, 8) + d2 in [0, 2) + d3 in [0, 8) + s0 in [0, 3) + s1 in [0, 5) + s2 in [0, 4) )")))); } @@ -2264,24 +2264,24 @@ TEST_F(IndexingAnalysisTest, ConvolutionOp_FeatureGroups) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1, d2, d3)[s0, s1, s2] -> (d0, d1 + s0, d2 + s1, (d3 floordiv 8) * 4 + s2) domain: - d0 in [0, 0] - d1 in [0, 9] - d2 in [0, 5] - d3 in [0, 47] - s0 in [0, 2] - s1 in [0, 4] - s2 in [0, 3] + d0 in [0, 1) + d1 in [0, 10) + d2 in [0, 6) + d3 in [0, 48) + s0 in [0, 3) + s1 in [0, 5) + s2 in [0, 4) )")), ElementsAre(MatchIndexingMap(R"( (d0, d1, d2, d3)[s0, s1, s2] -> (s2, s0, s1, d3) domain: - d0 in [0, 0] - d1 in [0, 9] - d2 in [0, 5] - d3 in [0, 47] - s0 in [0, 2] - s1 in [0, 4] - s2 in [0, 3] + d0 in [0, 1) + d1 in [0, 10) + d2 in [0, 6) + d3 in [0, 48) + s0 in [0, 3) + s1 in [0, 5) + s2 in [0, 4) )")))); } 
@@ -2300,25 +2300,25 @@ TEST_F(IndexingAnalysisTest, ConvolutionOp_BatchGroups) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0 + s3 * 2, d1 + s0, d2 + s1, s2) domain: - d0 in [0, 1] - d1 in [0, 9] - d2 in [0, 5] - d3 in [0, 20] - s0 in [0, 2] - s1 in [0, 4] - s2 in [0, 3] - s3 in [0, 6] + d0 in [0, 2) + d1 in [0, 10) + d2 in [0, 6) + d3 in [0, 21) + s0 in [0, 3) + s1 in [0, 5) + s2 in [0, 4) + s3 in [0, 7) )")), ElementsAre(MatchIndexingMap(R"( (d0, d1, d2, d3)[s0, s1, s2] -> (s2, s0, s1, d3) domain: - d0 in [0, 1] - d1 in [0, 9] - d2 in [0, 5] - d3 in [0, 20] - s0 in [0, 2] - s1 in [0, 4] - s2 in [0, 3] + d0 in [0, 2) + d1 in [0, 10) + d2 in [0, 6) + d3 in [0, 21) + s0 in [0, 3) + s1 in [0, 5) + s2 in [0, 4) )")))); } @@ -2335,10 +2335,10 @@ TEST_F(IndexingAnalysisTest, ReverseOp) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1, d2, d3) -> (d0, -d1 + 16, -d2 + 8, d3) domain: - d0 in [0, 0] - d1 in [0, 16] - d2 in [0, 8] - d3 in [0, 8] + d0 in [0, 1) + d1 in [0, 17) + d2 in [0, 9) + d3 in [0, 9) )")))); auto output_indexing = GetInputToOutputIndexing(root); @@ -2346,10 +2346,10 @@ TEST_F(IndexingAnalysisTest, ReverseOp) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1, d2, d3) -> (d0, -d1 + 16, -d2 + 8, d3) domain: - d0 in [0, 0] - d1 in [0, 16] - d2 in [0, 8] - d3 in [0, 8] + d0 in [0, 1) + d1 in [0, 17) + d2 in [0, 9) + d3 in [0, 9) )")))); } @@ -2373,8 +2373,8 @@ TEST_F(IndexingAnalysisTest, ReverseReshape) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1) -> (d0, d1) domain: - d0 in [0, 9] - d1 in [0, 10] + d0 in [0, 10) + d1 in [0, 11) )")))); } @@ -2392,9 +2392,9 @@ TEST_F(IndexingAnalysisTest, SliceOp) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1, d2) -> (d0 + 5, d1 * 7 + 3, d2 * 2) domain: - d0 in [0, 4] - d1 in [0, 2] - d2 in [0, 24] + d0 in [0, 5) + d1 in [0, 3) + d2 in [0, 25) )")))); auto output_indexing = GetInputToOutputIndexing(root); EXPECT_THAT(output_indexing.indexing_maps, @@ 
-2405,11 +2405,11 @@ TEST_F(IndexingAnalysisTest, SliceOp) { d2 floordiv 2 ) domain: - d0 in [5, 9] - d1 in [3, 17] - d2 in [0, 48] - (d1 - 3) mod 7 in [0, 0] - d2 mod 2 in [0, 0] + d0 in [5, 10) + d1 in [3, 18) + d2 in [0, 49) + (d1 - 3) mod 7 in [0, 1) + d2 mod 2 in [0, 1) )")))); } @@ -2427,20 +2427,20 @@ TEST_F(IndexingAnalysisTest, TransposeOp) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1, d2, d3) -> (d0, d3, d1, d2) domain: - d0 in [0, 2] - d1 in [0, 5] - d2 in [0, 127] - d3 in [0, 12287] + d0 in [0, 3) + d1 in [0, 6) + d2 in [0, 128) + d3 in [0, 12288) )")))); auto output_indexing = GetInputToOutputIndexing(root); EXPECT_THAT(output_indexing.indexing_maps, ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1, d2, d3) -> (d0, d2, d3, d1) domain: - d0 in [0, 2] - d1 in [0, 12287] - d2 in [0, 5] - d3 in [0, 127] + d0 in [0, 3) + d1 in [0, 12288) + d2 in [0, 6) + d3 in [0, 128) )")))); } @@ -2456,10 +2456,10 @@ TEST_F(IndexingAnalysisTest, TransposeOp4D) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1, d2, d3) -> (d0, d3, d1, d2) domain: - d0 in [0, 2] - d1 in [0, 5] - d2 in [0, 127] - d3 in [0, 12287] + d0 in [0, 3) + d1 in [0, 6) + d2 in [0, 128) + d3 in [0, 12288) )")))); } @@ -2478,26 +2478,26 @@ TEST_F(IndexingAnalysisTest, DotOp) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1, d2, d3, d4, d5)[s0, s1] -> (d2, d1, s1, d3, s0, d0) domain: - d0 in [0, 9] - d1 in [0, 37] - d2 in [0, 3] - d3 in [0, 10] - d4 in [0, 15] - d5 in [0, 21] - s0 in [0, 17] - s1 in [0, 16] + d0 in [0, 10) + d1 in [0, 38) + d2 in [0, 4) + d3 in [0, 11) + d4 in [0, 16) + d5 in [0, 22) + s0 in [0, 18) + s1 in [0, 17) )")), ElementsAre(MatchIndexingMap(R"( (d0, d1, d2, d3, d4, d5)[s0, s1] -> (s1, d0, d4, s0, d5, d1) domain: - d0 in [0, 9] - d1 in [0, 37] - d2 in [0, 3] - d3 in [0, 10] - d4 in [0, 15] - d5 in [0, 21] - s0 in [0, 17] - s1 in [0, 16] + d0 in [0, 10) + d1 in [0, 38) + d2 in [0, 4) + d3 in [0, 11) + d4 in [0, 16) + d5 in [0, 22) + s0 in [0, 18) + s1 in 
[0, 17) )")))); } @@ -2558,8 +2558,8 @@ TEST_F(IndexingAnalysisTest, FusionWithUnsupportedOp) { ElementsAre(UnorderedElementsAre(MatchIndexingMap(R"( (d0, d1) -> (d0 * 6, d1 * 2) domain: - d0 in [0, 3] - d1 in [0, 2] + d0 in [0, 4) + d1 in [0, 3) )")), ElementsAre(UndefinedMap()), ElementsAre(UndefinedMap()))); } @@ -2577,16 +2577,16 @@ TEST_F(IndexingAnalysisTest, TilingIndexing) { d0 mod 4 + s2 * 4 ) domain: - d0 in [0, 15] - d1 in [0, 0] - d2 in [0, 0] - d3 in [0, 8191] - d4 in [0, 0] - d5 in [0, 0] - s0 in [0, 7] - s1 in [0, 0] - s2 in [0, 3] - (d3 floordiv 64) * 8 + s0 in [0, 1021] + d0 in [0, 16) + d1 in [0, 1) + d2 in [0, 1) + d3 in [0, 8192) + d4 in [0, 1) + d5 in [0, 1) + s0 in [0, 8) + s1 in [0, 1) + s2 in [0, 4) + (d3 floordiv 64) * 8 + s0 in [0, 1022) )")); } @@ -2619,8 +2619,8 @@ TEST_F(IndexingAnalysisTest, EpilogueIndexing) { MatchIndexingString(R"( (d0, d1) -> (d1 * 1000 + d0) domain: - d0 in [0, 999] - d1 in [0, 999] + d0 in [0, 1000) + d1 in [0, 1000) )")); } @@ -2649,8 +2649,8 @@ TEST_F(IndexingAnalysisTest, EpilogueIndexing_NoEpilogue) { MatchIndexingString(R"( (d0, d1) -> (d0, d1) domain: - d0 in [0, 999] - d1 in [0, 999] + d0 in [0, 1000) + d1 in [0, 1000) )")); } diff --git a/third_party/xla/xla/service/gpu/model/indexing_map.cc b/third_party/xla/xla/service/gpu/model/indexing_map.cc index bcfcf494a5a5b9..9eff63baa40256 100644 --- a/third_party/xla/xla/service/gpu/model/indexing_map.cc +++ b/third_party/xla/xla/service/gpu/model/indexing_map.cc @@ -749,7 +749,8 @@ std::string Interval::ToString() const { } void Interval::Print(std::ostream& out) const { - out << '[' << lower << ", " << upper << "]"; + // The interval is printed as a semi-open one because it is easier to read. 
+ out << '[' << lower << ", " << upper + 1 << ")"; } int64_t Interval::GetLoopTripCount() const { diff --git a/third_party/xla/xla/service/gpu/model/indexing_map_test.cc b/third_party/xla/xla/service/gpu/model/indexing_map_test.cc index 24bbf63fd385d4..6eeb6f9e0c3de9 100644 --- a/third_party/xla/xla/service/gpu/model/indexing_map_test.cc +++ b/third_party/xla/xla/service/gpu/model/indexing_map_test.cc @@ -80,13 +80,13 @@ TEST_F(IndexingMapTest, RTVar) { EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( (d0, d1)[range, rt_0, rt_1] -> (d1, d0, range + rt_0, rt_0) domain: - d0 in [0, 99] - d1 in [0, 43] - range in [-99, 99] - rt_0 in [0, 2] + d0 in [0, 100) + d1 in [0, 44) + range in [-99, 100) + rt_0 in [0, 3) hlo: NULL () -> () - rt_1 in [0, 7] + rt_1 in [0, 8) hlo: NULL () -> () )")); @@ -128,10 +128,10 @@ TEST_F(IndexingMapTest, Composition_Permutation) { EXPECT_THAT(composed, MatchIndexingMap(R"( (d0)[s0, s1, s2] -> (s2, d0, s1, s0) domain: - d0 in [0, 3] - s0 in [0, 1] - s1 in [0, 1] - s2 in [0, 3] + d0 in [0, 4) + s0 in [0, 2) + s1 in [0, 2) + s2 in [0, 4) )")); } @@ -147,10 +147,10 @@ TEST_F(IndexingMapTest, Composition_RestrictedInterval) { EXPECT_THAT(composed, MatchIndexingMap(R"( (d0)[s0, s1, s2] -> (s2, d0, s1, s0) domain: - d0 in [0, 4] - s0 in [0, 6] - s1 in [0, 1] - s2 in [0, 5] + d0 in [0, 5) + s0 in [0, 7) + s1 in [0, 2) + s2 in [0, 6) )")); } @@ -174,26 +174,26 @@ TEST_F(IndexingMapTest, Composition_ProducerAndConsumerHaveConstraints) { EXPECT_THAT(composed, MatchIndexingMap(R"( (d0)[s0, s1, s2] -> (s2, d0, s1, s0) domain: - d0 in [0, 9] - s0 in [0, 69] - s1 in [0, 19] - s2 in [0, 7] - d0 + s2 in [0, 20] - d0 mod 8 in [0, 0] - s0 mod 3 in [1, 1] - s2 mod 4 in [0, 0] + d0 in [0, 10) + s0 in [0, 70) + s1 in [0, 20) + s2 in [0, 8) + d0 + s2 in [0, 21) + d0 mod 8 in [0, 1) + s0 mod 3 in [1, 2) + s2 mod 4 in [0, 1) )")); EXPECT_TRUE(composed.Simplify()); EXPECT_THAT(composed, MatchIndexingMap(R"( (d0)[s0, s1, s2] -> (s2, d0, s1, s0) 
domain: - d0 in [0, 8] - s0 in [1, 67] - s1 in [0, 19] - s2 in [0, 4] - d0 mod 8 in [0, 0] - s0 mod 3 in [1, 1] - s2 mod 4 in [0, 0] + d0 in [0, 9) + s0 in [1, 68) + s1 in [0, 20) + s2 in [0, 5) + d0 mod 8 in [0, 1) + s0 mod 3 in [1, 2) + s2 mod 4 in [0, 1) )")); } @@ -210,12 +210,12 @@ TEST_F(IndexingMapTest, RemoveUnusedDimensions_ConstraintUsesDim) { EXPECT_THAT(indexing_map, MatchIndexingMap(R"( (d0, d1)[s0, s1] -> (d1, s0, s1) domain: - d0 in [0, 49] - d1 in [0, 59] - s0 in [0, 69] - s1 in [0, 19] - d0 + s0 in [1, 100] - s0 mod 3 in [0, 0] + d0 in [0, 50) + d1 in [0, 60) + s0 in [0, 70) + s1 in [0, 20) + d0 + s0 in [1, 101) + s0 mod 3 in [0, 1) )")); } @@ -230,9 +230,9 @@ TEST_F(IndexingMapTest, RemoveUnusedDimensions_ConstraintUsesOnlyUnusedDim) { EXPECT_THAT(indexing_map, MatchIndexingMap(R"( (d0)[s0, s1] -> (s0, d0, s1) domain: - d0 in [0, 59] - s0 in [0, 69] - s1 in [0, 19] + d0 in [0, 60) + s0 in [0, 70) + s1 in [0, 20) )")); } @@ -248,11 +248,11 @@ TEST_F(IndexingMapTest, RemoveUnusedDimensions_ConstraintsWithManyDims) { EXPECT_THAT(indexing_map, MatchIndexingMap(R"( (d0, d1)[s0, s1] -> (d0 + s0 * 4 + d1 - 42) domain: - d0 in [0, 1] - d1 in [0, 3] - s0 in [0, 31] - s1 in [0, 63] - d0 + s0 * 4 + d1 in [24, 459] + d0 in [0, 2) + d1 in [0, 4) + s0 in [0, 32) + s1 in [0, 64) + d0 + s0 * 4 + d1 in [24, 460) )")); } @@ -271,12 +271,12 @@ TEST_F(IndexingMapTest, RemoveUnusedVars_ConstraintsWithManyDims) { EXPECT_THAT(indexing_map, MatchIndexingMap(R"( (d0, d1)[s0, s1] -> (d0 + s0 * 4 + d1 - 42) domain: - d0 in [0, 1] - d1 in [0, 3] - s0 in [0, 31] - s1 in [0, 95] - d0 + s0 * 4 + d1 in [24, 459] - s0 + s1 in [0, 512] + d0 in [0, 2) + d1 in [0, 4) + s0 in [0, 32) + s1 in [0, 96) + d0 + s0 * 4 + d1 in [24, 460) + s0 + s1 in [0, 513) )")); EXPECT_THAT(ConvertToSTL(unused_vars), ::testing::ElementsAreArray( @@ -296,12 +296,12 @@ TEST_F(IndexingMapTest, RemoveUnusedSymbols_ConstraintUsesSymbol) { EXPECT_THAT(indexing_map, MatchIndexingMap(R"( (d0, d1)[s0, s1] -> (d1, 
d0, s1) domain: - d0 in [0, 49] - d1 in [0, 59] - s0 in [0, 69] - s1 in [0, 19] - s0 + s1 in [1, 100] - s0 mod 3 in [0, 0] + d0 in [0, 50) + d1 in [0, 60) + s0 in [0, 70) + s1 in [0, 20) + s0 + s1 in [1, 101) + s0 mod 3 in [0, 1) )")); } @@ -316,9 +316,9 @@ TEST_F(IndexingMapTest, RemoveUnusedSymbols_ConstraintUsesOnlyUnusedSymbols) { EXPECT_THAT(indexing_map, MatchIndexingMap(R"( (d0, d1)[s0] -> (d1, d0, s0) domain: - d0 in [0, 49] - d1 in [0, 59] - s0 in [0, 19] + d0 in [0, 50) + d1 in [0, 60) + s0 in [0, 20) )")); } @@ -330,7 +330,7 @@ TEST_F(IndexingMapTest, RemoveUnusedSymbols_ConstraintIsAConstantWithinRange) { EXPECT_THAT(indexing_map, MatchIndexingMap(R"( (d0) -> (d0) domain: - d0 in [0, 49] + d0 in [0, 50) )")); } @@ -383,10 +383,10 @@ TEST_F(IndexingMapTest, RemoveUnusedSymbols_ConstraintsWithManySymbols) { EXPECT_THAT(indexing_map, MatchIndexingMap(R"( (d0)[s0, s1] -> (d0 * 4 + s0 + s1 - 42) domain: - d0 in [0, 31] - s0 in [0, 1] - s1 in [0, 3] - d0 * 4 + s0 + s1 in [24, 459] + d0 in [0, 32) + s0 in [0, 2) + s1 in [0, 4) + d0 * 4 + s0 + s1 in [24, 460) )")); } @@ -407,12 +407,12 @@ TEST_F(IndexingMapTest, RemoveUnusedSymbols_ConstraintsWithRTVars) { EXPECT_THAT(indexing_map, MatchIndexingMap(R"( (d0)[s0, s1] -> (d0 * 4 + s0 + s1 - 42) domain: - d0 in [0, 31] - s0 in [0, 1] - s1 in [0, 3] + d0 in [0, 32) + s0 in [0, 2) + s1 in [0, 4) hlo: NULL () -> () - d0 * 4 + s0 + s1 in [24, 459] + d0 * 4 + s0 + s1 in [24, 460) )")); } @@ -426,8 +426,8 @@ TEST_F(IndexingMapTest, ConstraintIntervalSimplification_Sum) { EXPECT_THAT(indexing_map.ToString(), MatchIndexingString(R"( (d0) -> (d0) domain: - d0 in [0, 99] - d0 mod 8 in [45, 49] + d0 in [0, 100) + d0 mod 8 in [45, 50) )")); } @@ -441,7 +441,7 @@ TEST_F(IndexingMapTest, EXPECT_THAT(indexing_map.ToString(), MatchIndexingString(R"( (d0) -> (d0) domain: - d0 in [40, 95] + d0 in [40, 96) )")); } @@ -456,8 +456,8 @@ TEST_F(IndexingMapTest, EXPECT_THAT(indexing_map.ToString(), MatchIndexingString(R"( (d0)[s0] -> (d0) 
domain: - d0 in [0, 99] - s0 in [-33, -13] + d0 in [0, 100) + s0 in [-33, -12) )")); } @@ -472,8 +472,8 @@ TEST_F(IndexingMapTest, EXPECT_THAT(indexing_map.ToString(), MatchIndexingString(R"( (d0)[s0] -> (d0) domain: - d0 in [0, 99] - s0 in [15, 35] + d0 in [0, 100) + s0 in [15, 36) )")); } @@ -487,7 +487,7 @@ TEST_F(IndexingMapTest, EXPECT_THAT(indexing_map.ToString(), MatchIndexingString(R"( (d0) -> (d0) domain: - d0 in [2, 4] + d0 in [2, 5) )")); } @@ -502,8 +502,8 @@ TEST_F(IndexingMapTest, EXPECT_THAT(indexing_map.ToString(), MatchIndexingString(R"( (d0)[s0] -> (d0) domain: - d0 in [0, 99] - s0 in [-3, -2] + d0 in [0, 100) + s0 in [-3, -1) )")); } @@ -518,8 +518,8 @@ TEST_F(IndexingMapTest, EXPECT_THAT(indexing_map.ToString(), MatchIndexingString(R"( (d0)[s0] -> (d0) domain: - d0 in [0, 99] - s0 in [2, 3] + d0 in [0, 100) + s0 in [2, 4) )")); } @@ -541,12 +541,12 @@ TEST_F(IndexingMapTest, ConstraintMerge_Mod) { EXPECT_THAT(indexing_map.ToString(), MatchIndexingString(R"( (d0)[s0, s1] -> (d0, s1, s0) domain: - d0 in [0, 3] - s0 in [-18, -6] - s1 in [1, 6] - d0 mod 3 in [0, 0] - s0 mod 6 in [0, 0] - s1 mod 5 in [1, 1] + d0 in [0, 4) + s0 in [-18, -5) + s1 in [1, 7) + d0 mod 3 in [0, 1) + s0 mod 6 in [0, 1) + s1 mod 5 in [1, 2) )")); } @@ -558,7 +558,7 @@ TEST_F(IndexingMapTest, AffineMapSimplification_ConstantDims) { EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( (d0) -> (5) domain: - d0 in [5, 5] + d0 in [5, 6) )")); } @@ -571,8 +571,8 @@ TEST_F(IndexingMapTest, EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( (d0, d1) -> (d0, d1) domain: - d0 in [0, 7] - d1 in [0, 15] + d0 in [0, 8) + d1 in [0, 16) )")); } @@ -589,9 +589,9 @@ TEST_F(IndexingMapTest, AffineMapSimplification_DivsAndModsWithMultipliers) { EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( (d0, d1, d2) -> (d0, d1, d2) domain: - d0 in [0, 8] - d1 in [0, 8] - d2 in [0, 8] + d0 in [0, 9) + d1 in [0, 9) + d2 in [0, 9) )")); } @@ -608,9 
+608,9 @@ TEST_F(IndexingMapTest, (d0, d1, d2) -> (d0 * 2 + (d1 * 4 + d2) floordiv 8, (d1 * 4 + d2) mod 8) domain: - d0 in [0, 9] - d1 in [0, 9] - d2 in [0, 9] + d0 in [0, 10) + d1 in [0, 10) + d2 in [0, 10) )")); } @@ -624,8 +624,8 @@ TEST_F(IndexingMapTest, AffineMapSimplification_DivsAndModsWithReverse) { EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( (d0, d1) -> (d0, d1) domain: - d0 in [0, 7] - d1 in [0, 8] + d0 in [0, 8) + d1 in [0, 9) )")); } @@ -637,7 +637,7 @@ TEST_F(IndexingMapTest, AffineMapSimplification_SimplifyReshape) { EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( ()[s0] -> (s0 * 128) - domain: s0 in [0, 127] + domain: s0 in [0, 128) )")); } @@ -650,8 +650,8 @@ TEST_F(IndexingMapTest, AffineMapSimplification_SimplifyReshape2) { EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( (d0, d1) -> (d0 * 128 + d1) domain: - d0 in [0, 1023] - d1 in [0, 127] + d0 in [0, 1024) + d1 in [0, 128) )")); } @@ -664,7 +664,7 @@ TEST_F(IndexingMapTest, EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( (d0) -> ((-d0) mod 2) domain: - d0 in [0, 127] + d0 in [0, 128) )")); } @@ -683,8 +683,8 @@ TEST_F(IndexingMapTest, AffineMapSimplification_SimplifyBitcastAndBack) { EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( (d0, d1) -> (d0 * 512 + d1 * 4) domain: - d0 in [0, 3071] - d1 in [0, 127] + d0 in [0, 3072) + d1 in [0, 128) )")); } @@ -697,7 +697,7 @@ TEST_F(IndexingMapTest, AffineMapSimplification_SimplifyReshape_Regression) { EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( ()[s0] -> (((s0 * 64) floordiv 715) * 715 + (s0 * 128) mod 715) - domain: s0 in [0, 127] + domain: s0 in [0, 128) )")); } @@ -711,7 +711,7 @@ TEST_F(IndexingMapTest, AffineMapSimplification_DivsInSequence) { EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( ()[s0] -> (s0) domain: - s0 in 
[0, 1233] + s0 in [0, 1234) )")); } @@ -737,10 +737,10 @@ TEST_F(IndexingMapTest, AffineMapSimplification_ExtractFromMod) { (s0 * 458752 + s2 * 4 + s3 * 512) mod 20000 + s1 ) domain: - s0 in [0, 871] - s1 in [0, 3] - s2 in [0, 127] - s3 in [0, 895] + s0 in [0, 872) + s1 in [0, 4) + s2 in [0, 128) + s3 in [0, 896) )")); } @@ -757,8 +757,8 @@ TEST_F(IndexingMapTest, s0 * 4 + s1 floordiv 32 ) domain: - s0 in [0, 1] - s1 in [0, 127] + s0 in [0, 2) + s1 in [0, 128) )")); } @@ -773,10 +773,10 @@ TEST_F(IndexingMapTest, RescaleSymbols_Simple) { EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( (d0)[s0, s1, s2] -> (s2, d0, s1, s0) domain: - d0 in [0, 3] - s0 in [0, 1] - s1 in [0, 1] - s2 in [0, 5] + d0 in [0, 4) + s0 in [0, 2) + s1 in [0, 2) + s2 in [0, 6) )")); } @@ -793,10 +793,10 @@ TEST_F(IndexingMapTest, RescaleSymbols_WithShift) { EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( (d0)[s0, s1, s2] -> (s2, d0, s1, s0 * 6 + 3) domain: - d0 in [0, 3] - s0 in [0, 6] - s1 in [0, 1] - s2 in [0, 5] + d0 in [0, 4) + s0 in [0, 7) + s1 in [0, 2) + s2 in [0, 6) )")); } @@ -813,10 +813,10 @@ TEST_F(IndexingMapTest, RescaleSymbols_TwoModConstraints) { EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( (d0)[s0, s1, s2] -> (s2, d0, s1, s0) domain: - d0 in [0, 3] - s0 in [0, 1] - s1 in [0, 1] - s2 in [0, 5] + d0 in [0, 4) + s0 in [0, 2) + s1 in [0, 2) + s2 in [0, 6) )")); } @@ -833,11 +833,11 @@ TEST_F(IndexingMapTest, RescaleSymbols_RescaledSymbolInOtherNonModConstraint) { EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( (d0)[s0, s1, s2] -> (s2, d0, s1, s0 * 6 + 3) domain: - d0 in [0, 3] - s0 in [0, 1] - s1 in [0, 1] - s2 in [0, 5] - (s0 * 6 + 3) * s2 in [0, 28] + d0 in [0, 4) + s0 in [0, 2) + s1 in [0, 2) + s2 in [0, 6) + (s0 * 6 + 3) * s2 in [0, 29) )")); } @@ -1152,7 +1152,7 @@ TEST_F(IndexingMapTest, ReplaceConstantRTVars_Iota) { EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( (d0) 
-> (d0, d0) domain: - d0 in [0, 255] + d0 in [0, 256) )")); } @@ -1181,7 +1181,7 @@ TEST_F(IndexingMapTest, ReplaceConstantRTVars_IotaAsConstant) { EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( (d0) -> (d0, 7) domain: - d0 in [0, 255] + d0 in [0, 256) )")); } @@ -1212,8 +1212,8 @@ TEST_F(IndexingMapTest, ReplaceConstantRTVars_ConstraintsGetUpdated) { EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( (d0) -> (d0, d0) domain: - d0 in [0, 254] - d0 mod 2 in [0, 0] + d0 in [0, 255) + d0 mod 2 in [0, 1) )")); } @@ -1245,7 +1245,7 @@ TEST_F(IndexingMapTest, ReplaceConstantRTVars_Broadcast) { EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( (d0) -> (d0, 11) domain: - d0 in [0, 31] + d0 in [0, 32) )")); } @@ -1286,7 +1286,7 @@ TEST_F(IndexingMapTest, ReplaceConstantRTVars_ChainedNoncomputeOps) { EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( (d0) -> (d0, (d0 floordiv 12) * -4 + 8) domain: - d0 in [0, 35] + d0 in [0, 36) )")); } @@ -1319,8 +1319,8 @@ TEST_F(IndexingMapTest, ReplaceConstantRTVars_PartialRTVarRemoval) { EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( (d0)[s0] -> (d0, s0) domain: - d0 in [0, 23] - s0 in [0, 512] + d0 in [0, 24) + s0 in [0, 513) hlo: %constant = s64[12]{0} constant({...}) (d0) -> (d0 floordiv 2) )")); @@ -1356,7 +1356,7 @@ TEST_F(IndexingMapTest, ReplaceConstantRTVars_Add) { EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( (d0) -> (d0, d0 * 2 + 42) domain: - d0 in [0, 11] + d0 in [0, 12) )")); } @@ -1395,7 +1395,7 @@ TEST_F(IndexingMapTest, ReplaceConstantRTVars_Multiply) { EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( (d0) -> (d0, (-d0 + 11) * d0) domain: - d0 in [0, 11] + d0 in [0, 12) )")); } @@ -1431,8 +1431,8 @@ TEST_F(IndexingMapTest, ReplaceConstantRTVars_PartiallyOptimizableAdd) { EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( (d0)[s0] -> (d0, d0 * 2 + s0) domain: - d0 
in [0, 11] - s0 in [0, 11] + d0 in [0, 12) + s0 in [0, 12) hlo: %constant = s64[12]{0} constant({...}) (d0) -> (d0) )")); diff --git a/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis_test.cc b/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis_test.cc index 31adf8b1780982..d5cbfe9da5c8f7 100644 --- a/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis_test.cc +++ b/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis_test.cc @@ -132,7 +132,7 @@ ENTRY main { EXPECT_THAT(root->block_id_to_tile_offsets_indexing(), MatchIndexingMap(R"( (d0) -> (d0 floordiv 10, (d0 mod 10) * 10) domain: - d0 in [0, 19] + d0 in [0, 20) )")); auto p0_from_subtract0 = root->operand(0); @@ -144,7 +144,7 @@ ENTRY main { /*block_id_to_tile_offsets_indexing=*/R"( (d0) -> (d0 floordiv 10, (d0 mod 10) * 10) domain: - d0 in [0, 19] + d0 in [0, 20) )")); EXPECT_THAT(*p0_from_subtract1, MatchTiledHloInstruction( @@ -153,7 +153,7 @@ ENTRY main { /*block_id_to_tile_offsets_indexing=*/R"( (d0) -> (d0 floordiv 10, 0) domain: - d0 in [0, 19] + d0 in [0, 20) )")); } @@ -243,7 +243,7 @@ ENTRY main { /*tile_sizes=*/{1, 97}, /*tile_strides=*/{1, 1}, /*block_id_to_tile_offsets_indexing=*/R"( (d0) -> (d0, 0) - domain: d0 in [0, 1] + domain: d0 in [0, 2) )")); } @@ -273,7 +273,7 @@ ENTRY main { /*block_id_to_tile_offsets_indexing=*/R"( (d0) -> ((d0 floordiv 16) * 2, ((d0 floordiv 8) mod 2) * 4, (d0 mod 8) * 2) domain: - d0 in [0, 31] + d0 in [0, 32) )")); EXPECT_THAT(*root->operand(0), @@ -282,7 +282,7 @@ ENTRY main { /*block_id_to_tile_offsets_indexing=*/R"( (d0) -> (((d0 floordiv 8) mod 2) * 4, (d0 mod 8) * 2, (d0 floordiv 16) * 2) domain: - d0 in [0, 31] + d0 in [0, 32) )")); } @@ -316,7 +316,7 @@ ENTRY main { /*block_id_to_tile_offsets_indexing=*/R"( (d0) -> ((d0 floordiv 4) * 2, (d0 mod 4) * 2) domain: - d0 in [0, 7] + d0 in [0, 8) )")); EXPECT_THAT(*p0_from_slice0, @@ -325,7 +325,7 @@ ENTRY main { /*block_id_to_tile_offsets_indexing=*/R"( (d0) -> ((d0 floordiv 4) * 
2, (d0 mod 4) * 2 + 2) domain: - d0 in [0, 7] + d0 in [0, 8) )")); EXPECT_THAT(*p0_from_slice1, @@ -334,7 +334,7 @@ ENTRY main { /*block_id_to_tile_offsets_indexing=*/R"( (d0) -> ((d0 floordiv 4) * 2 + 3, (d0 mod 4) * 2 + 4) domain: - d0 in [0, 7] + d0 in [0, 8) )")); } @@ -452,10 +452,10 @@ ENTRY main { EXPECT_THAT(conjunction, SizeIs(2)); // We expect the constraints here to be - // 6 mod s0 in [0, 0] && 8 mod s1 in [0, 0] || - // 6 mod s0 in [0, 0] && s1 mod 8 in [0, 0] || - // 8 mod s1 in [0, 0] && s0 mod 6 in [0, 0] || - // s0 mod 6 in [0, 0] && s1 mod 8 in [0, 0] + // 6 mod s0 in [0, 1) && 8 mod s1 in [0, 1) || + // 6 mod s0 in [0, 1) && s1 mod 8 in [0, 1) || + // 8 mod s1 in [0, 1) && s0 mod 6 in [0, 1) || + // s0 mod 6 in [0, 1) && s1 mod 8 in [0, 1) // Tile sizes {6, 8} satisfy these constraints. std::vector possible_tile_parameters({6, 8}); EXPECT_THAT(analysis->ParametersSatisfyConstraints(possible_tile_parameters), @@ -606,7 +606,7 @@ ENTRY main { std::vector good_tilings, analysis.GetGoodTilings()); // The constraint on the 1st dimension is - // 6 mod s0 in [0, 0] || s0 mod 6 in [0, 0], + // 6 mod s0 in [0, 1) || s0 mod 6 in [0, 1), // and only 48, 1, and 2 fulfill it from the set of possible tile sizes // (1, 2, 4, 8, 16, 32, 48). // There is no constraint on the 2nd dimension. 
@@ -779,7 +779,7 @@ ENTRY main { /*block_id_to_tile_offsets_indexing=*/R"( (d0) -> (d0 floordiv 32768, d0 mod 32768) domain: - d0 in [0, 2147549183] + d0 in [0, 2147549184) )")); } diff --git a/third_party/xla/xla/service/gpu/model/symbolic_tile_test.cc b/third_party/xla/xla/service/gpu/model/symbolic_tile_test.cc index 9f776f4ccca12d..1f59a4156de433 100644 --- a/third_party/xla/xla/service/gpu/model/symbolic_tile_test.cc +++ b/third_party/xla/xla/service/gpu/model/symbolic_tile_test.cc @@ -149,7 +149,7 @@ TEST_F(SymbolicTileTest, size_map: ()[s0, s1] -> (1, (s0 + 5) floordiv 6, s0 - ((s0 - 1) floordiv 6) * 6, s1) stride_map: ()[s0, s1] -> (0, 1, 1, 1) constraints: - 6 mod s0 in [0, 0] || s0 mod 6 in [0, 0] + 6 mod s0 in [0, 1) || s0 mod 6 in [0, 1) )"))); } @@ -430,10 +430,10 @@ TEST_F(SymbolicTileTest, CanPropagateTileThroughDynamicSlice) { size_map: ()[s0, s1, s2] -> (1, s1, s2) stride_map: ()[s0, s1, s2] -> (0, 1, 1) rt_vars: - s3 in [0, 1] + s3 in [0, 2) hlo: %of1 = s32[] parameter(1) (d0, d1, d2) -> () - s4 in [0, 226] + s4 in [0, 227) hlo: %of3 = s32[] parameter(3) (d0, d1, d2) -> () )"))); @@ -482,10 +482,10 @@ TEST_F(SymbolicTileTest, CanPropagateTileThroughDynamicUpdateSlice) { size_map: ()[s0, s1] -> (s0, s1) stride_map: ()[s0, s1] -> (1, 1) rt_vars: - s2 in [0, 15] + s2 in [0, 16) hlo: %of1 = s32[] parameter(2) (d0, d1) -> () - s3 in [0, 20] + s3 in [0, 21) hlo: %of2 = s32[] parameter(3) (d0, d1) -> () )"))); @@ -525,10 +525,10 @@ TEST_F(SymbolicTileTest, CanPropagateTileThroughGather) { size_map: ()[s0, s1, s2, s3] -> (s1, s2, s3) stride_map: ()[s0, s1, s2, s3] -> (1, 1, 1) rt_vars: - s4 in [0, 26] + s4 in [0, 27) hlo: %indices = s32[1806,2]{1,0} parameter(1) (d0, d1, d2, d3) -> (d0, 0) - s5 in [0, 68] + s5 in [0, 69) hlo: %indices = s32[1806,2]{1,0} parameter(1) (d0, d1, d2, d3) -> (d0, 1) )"))); @@ -720,10 +720,10 @@ TEST_F(SymbolicTileTest, CanCombineCompatibleConstraints) { size_map: ()[s0, s1] -> (1, (s0 + 5) floordiv 6, s0 - ((s0 - 1) floordiv 6) 
* 6, (s1 + 7) floordiv 8, s1 - ((s1 - 1) floordiv 8) * 8) stride_map: ()[s0, s1] -> (0, 1, 1, 1, 1) constraints: - 6 mod s0 in [0, 0] && 8 mod s1 in [0, 0] || - 6 mod s0 in [0, 0] && s1 mod 8 in [0, 0] || - 8 mod s1 in [0, 0] && s0 mod 6 in [0, 0] || - s0 mod 6 in [0, 0] && s1 mod 8 in [0, 0] + 6 mod s0 in [0, 1) && 8 mod s1 in [0, 1) || + 6 mod s0 in [0, 1) && s1 mod 8 in [0, 1) || + 8 mod s1 in [0, 1) && s0 mod 6 in [0, 1) || + s0 mod 6 in [0, 1) && s1 mod 8 in [0, 1) )"))); } @@ -803,7 +803,7 @@ TEST_F(ConstraintExpressionTest, PrettyPrintingTest) { constraints.Or(std::move(conjunction_1)); constraints.Or(std::move(conjunction_2)); EXPECT_THAT(constraints, MatchConstraintExpressionString( - "d0 in [0, 5] && d1 in [0, 5] || d2 in [0, 5]")); + "d0 in [0, 6) && d1 in [0, 6) || d2 in [0, 6)")); } TEST_F(ConstraintExpressionTest, @@ -811,11 +811,11 @@ TEST_F(ConstraintExpressionTest, ConstraintExpression constraints; constraints.And(GetConjointConstraints({{"d0", Interval{0, 5}}})); - EXPECT_THAT(constraints, MatchConstraintExpressionString("d0 in [0, 5]")); + EXPECT_THAT(constraints, MatchConstraintExpressionString("d0 in [0, 6)")); // Constraints are intersected. constraints.And(GetConjointConstraints({{"d0", Interval{3, 6}}})); - EXPECT_THAT(constraints, MatchConstraintExpressionString("d0 in [3, 5]")); + EXPECT_THAT(constraints, MatchConstraintExpressionString("d0 in [3, 6)")); // Empty intersection results in unsatisfiability. constraints.And(GetConjointConstraints({{"d0", Interval{7, 8}}})); @@ -864,7 +864,7 @@ TEST_F( constraints.Or(std::move(conjunction_1)); constraints.Or(std::move(conjunction_2)); EXPECT_THAT(constraints, - MatchConstraintExpressionString("d0 in [0, 5] || d1 in [0, 5]")); + MatchConstraintExpressionString("d0 in [0, 6) || d1 in [0, 6)")); // `conjunction_1` && `conjunction_3` is an unsatisfiable constraint. 
Taking // the conjunction of the existing constraint expression with `conjunction_3` @@ -875,7 +875,7 @@ TEST_F( constraints.And(std::move(conjunction_3)); EXPECT_THAT(constraints, - MatchConstraintExpressionString("d0 in [6, 6] && d1 in [0, 5]")); + MatchConstraintExpressionString("d0 in [6, 7) && d1 in [0, 6)")); // But becomes unsatisfiable if we eliminate the last remaining constraint by // constructing another unsatisfiable conjunction. From c573d79097ba98f928c1c835430937c7dc15a0d0 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 25 Jun 2024 02:59:46 -0700 Subject: [PATCH 213/256] Automated Code Change PiperOrigin-RevId: 646403895 --- tensorflow/lite/BUILD | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tensorflow/lite/BUILD b/tensorflow/lite/BUILD index b0264333080f73..4a1d5bda4f394d 100644 --- a/tensorflow/lite/BUILD +++ b/tensorflow/lite/BUILD @@ -413,11 +413,12 @@ cc_library( # "//base:addressmap", # "//base:low_level_alloc", # "//base:malloc_hook", -# "//tensorflow/core/profiler/lib:traceme", # "//tensorflow/lite/core/c:common", # "//tensorflow/lite/kernels:kernel_util", # "@local_tsl//tsl/profiler/backends/cpu:traceme_recorder", # "@local_tsl//tsl/profiler/lib:scoped_memory_debug_annotation", +# "@local_tsl//tsl/profiler/lib:traceme", +# "@local_tsl//tsl/profiler/lib:traceme_encode", # "@com_google_absl//absl/base", # "@com_google_absl//absl/base:core_headers", # "@com_google_absl//absl/debugging:stacktrace", From 960949aceeeceb4b1d075aa70f221a7496c3903d Mon Sep 17 00:00:00 2001 From: Johannes Reifferscheid Date: Tue, 25 Jun 2024 03:17:47 -0700 Subject: [PATCH 214/256] Fix simplification of nested divisons. The rewrite is only correct if there's a single division in the sum. 
Reverts a0da03263492463221066ee5b3376787e2188465 PiperOrigin-RevId: 646408142 --- .../xla/xla/service/gpu/model/indexing_map.cc | 49 ++++++++++++++----- .../service/gpu/model/indexing_map_test.cc | 22 +++++++++ 2 files changed, 59 insertions(+), 12 deletions(-) diff --git a/third_party/xla/xla/service/gpu/model/indexing_map.cc b/third_party/xla/xla/service/gpu/model/indexing_map.cc index 9eff63baa40256..f732af92bd5d45 100644 --- a/third_party/xla/xla/service/gpu/model/indexing_map.cc +++ b/third_party/xla/xla/service/gpu/model/indexing_map.cc @@ -271,6 +271,8 @@ AffineExpr AffineExprSimplifier::SimplifySumDiv(AffineExpr dividend, // The gcd of all multipliers and the divisor. int64_t multiplier_divisor_gcd = divisor; Interval no_multiplier_range{0, 0}; + std::optional inner_divisor = std::nullopt; + int num_inner_divisors = 0; VisitSummands(new_dividend, [&](AffineExpr summand) { if (auto multiplier = GetConstantRhs(summand, AffineExprKind::Mul)) { multiplier_divisor_gcd = std::gcd(multiplier_divisor_gcd, *multiplier); @@ -278,6 +280,11 @@ AffineExpr AffineExprSimplifier::SimplifySumDiv(AffineExpr dividend, no_multiplier_range = no_multiplier_range + range_evaluator_->ComputeExpressionRange(summand); } + + if (auto divisor = GetConstantRhs(summand, AffineExprKind::FloorDiv)) { + inner_divisor = divisor; + ++num_inner_divisors; + } }); // Consider an expression like: `(x * 6 + y) / 9`. if the range of `y` is at @@ -296,6 +303,29 @@ AffineExpr AffineExprSimplifier::SimplifySumDiv(AffineExpr dividend, divisor /= multiplier_divisor_gcd; } + // If we have an inner divisor whose value is equal to the GCD of all the + // divisors, we can remove a division: + // `(a0 / c0 + ...) / c1` -> `(a0 + (...) * c0) / c0c1` + // This potentially increases the number of multiplications, but it's + // generally a win. It also matches what the MLIR simplifier does better, so + // we can get more simplifications. 
Note that this rewrite is not correct if + // there's more than one inner division, since each inner dividend may be + // rounded down, whereas the sum might not be. For example, in + // `(a0 / 3 + a1 / 3) / 6)` + // If a0 is 16 and a1 is 2, the result is `(5 + 0) / 6 = 0`, whereas the + // rewritten form `(a0 + a1) / 18` evaluates to 1. This can only happen when + // there is more than one division. + if (num_inner_divisors == 1) { + new_dividend = MapSummands(new_dividend, [&](AffineExpr summand) { + if (auto inner_divisor = + GetConstantRhs(summand, AffineExprKind::FloorDiv)) { + return GetLhs(summand).floorDiv(*inner_divisor / *inner_divisor); + } + return summand * *inner_divisor; + }); + divisor *= *inner_divisor; + } + return new_dividend.floorDiv(divisor) + extracted; } @@ -481,18 +511,13 @@ AffineExpr AffineExprSimplifier::SimplifyOnce(AffineExpr expr) { if (!div) continue; // Already erased. if ((div_mul % mod_mul) || (div_mul / mod_mul) != mod_c) continue; - auto mod_lhs = GetLhs(mod); - if (GetConstantRhs(mod_lhs, AffineExprKind::FloorDiv)) { - // If x is a floorDiv itself, we need to check a bit more carefully: - // ((x // c0) % c1) * d + (x // (c0 * c1)) * (c1 * d)` - // `x // (c0 * c1)` will be simplified, so we we may not even have - // `c0 * c1` in the expression, if `x` contains a multiplier. - if (Simplify(mod_lhs.floorDiv(*mod_c)) != Simplify(div)) continue; - } else { - if (mod_lhs != GetLhs(div)) continue; - auto div_c = GetConstantRhs(div, AffineExprKind::FloorDiv); - if (mod_c != div_c) continue; - } + // In many cases, we could just compare the LHSes of the mod and the + // div, but if x is a floorDiv itself, we need to check a bit more + // carefully: + // ((x // c0) % c1) * d + (x // (c0 * c1)) * (c1 * d)` + // `x // (c0 * c1)` will be simplified, so we we may not even have + // `c0 * c1` in the expression, if `x` contains a multiplier. 
+ if (Simplify(GetLhs(mod).floorDiv(*mod_c)) != Simplify(div)) continue; others.push_back(GetLhs(mod) * mod_mul); divs[div_i].first = nullptr; diff --git a/third_party/xla/xla/service/gpu/model/indexing_map_test.cc b/third_party/xla/xla/service/gpu/model/indexing_map_test.cc index 6eeb6f9e0c3de9..8cf7cb7201ecf1 100644 --- a/third_party/xla/xla/service/gpu/model/indexing_map_test.cc +++ b/third_party/xla/xla/service/gpu/model/indexing_map_test.cc @@ -715,6 +715,28 @@ TEST_F(IndexingMapTest, AffineMapSimplification_DivsInSequence) { )")); } +TEST_F(IndexingMapTest, AffineMapSimplification_DivDiv) { + auto serialized_map = "()[s0, s1] -> ((s0 * 2 + s1 floordiv 64) floordiv 3)"; + IndexingMap indexing_map = IndexingMap::FromTensorSizes( + ParseAffineMap(serialized_map, &mlir_context_), {}, {1234, 128}); + EXPECT_TRUE(indexing_map.Simplify()); + EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( + ()[s0, s1] -> ((s0 * 128 + s1) floordiv 192) + domain: + s0 in [0, 1234) + s1 in [0, 128) + )")); +} + +TEST_F(IndexingMapTest, AffineMapSimplification_DivSumDiv) { + auto serialized_map = + "()[s0, s1] -> ((s0 floordiv 3 + s1 floordiv 3) floordiv 6)"; + IndexingMap indexing_map = IndexingMap::FromTensorSizes( + ParseAffineMap(serialized_map, &mlir_context_), {}, {1234, 128}); + // The rewrite tested in AffineMapSimplification_DivDiv must not trigger here. + EXPECT_FALSE(indexing_map.Simplify()); +} + TEST_F(IndexingMapTest, AffineMapSimplification_NegativeDiv) { // (s0 floordiv 2) floordiv -7 is not s0 floordiv -14: // 15 // 2 // -7 = -1 From f004a90f916cb9a5c669ea7626e667a39d788681 Mon Sep 17 00:00:00 2001 From: Adrian Kuegel Date: Tue, 25 Jun 2024 03:23:56 -0700 Subject: [PATCH 215/256] Fix bugs in matching code. The logic was missing checks whether the operands that are supposed to be the same are actually the same. 
PiperOrigin-RevId: 646409358 --- .../xla/xla/service/algebraic_simplifier.cc | 17 +++++----- .../xla/service/algebraic_simplifier_test.cc | 34 +++++++++++++++++++ 2 files changed, 43 insertions(+), 8 deletions(-) diff --git a/third_party/xla/xla/service/algebraic_simplifier.cc b/third_party/xla/xla/service/algebraic_simplifier.cc index 6be5e65adf7fc3..8ed51b652109b5 100644 --- a/third_party/xla/xla/service/algebraic_simplifier.cc +++ b/third_party/xla/xla/service/algebraic_simplifier.cc @@ -4584,20 +4584,21 @@ absl::Status AlgebraicSimplifierVisitor::HandleMultiply( VLOG(10) << "trying transform [sqrt(x) * sqrt(x) => x], for x >= 0 " << multiply->ToString(); - if (Match(multiply, m::Multiply(m::Sqrt(m::Op(&a)), m::Sqrt(m::Op(&a)))) && - IsNonNegative(a, options_)) { - return ReplaceInstruction(multiply, a); + if (Match(multiply, + m::Multiply(m::Sqrt(m::Op(&lhs)), m::Sqrt(m::Op(&rhs)))) && + lhs == rhs && IsNonNegative(lhs, options_)) { + return ReplaceInstruction(multiply, lhs); } - VLOG(10) << "trying transform [rsqrt(B) * rsqrt(B) => 1/B], for B >= 0 " + VLOG(10) << "trying transform [rsqrt(x) * rsqrt(x) => 1/x], for x >= 0 " << multiply->ToString(); - HloInstruction* b; - if (Match(multiply, m::Multiply(m::Rsqrt(m::Op(&b)), m::Rsqrt(m::Op(&b)))) && - IsNonNegative(b, options_)) { + if (Match(multiply, + m::Multiply(m::Rsqrt(m::Op(&lhs)), m::Rsqrt(m::Op(&rhs)))) && + lhs == rhs && IsNonNegative(lhs, options_)) { return ReplaceWithNewInstruction( multiply, HloInstruction::CreateBinary(multiply->shape(), HloOpcode::kDivide, - MakeScalarLike(b, 1), b)); + MakeScalarLike(lhs, 1), lhs)); } return absl::OkStatus(); diff --git a/third_party/xla/xla/service/algebraic_simplifier_test.cc b/third_party/xla/xla/service/algebraic_simplifier_test.cc index d37bb196383b98..2e2b81eb5ab3fa 100644 --- a/third_party/xla/xla/service/algebraic_simplifier_test.cc +++ b/third_party/xla/xla/service/algebraic_simplifier_test.cc @@ -9220,6 +9220,23 @@ TEST_F(AlgebraicSimplifierTest, 
MultiplySelfSqrt) { GmockMatch(m::Abs(m::Parameter(0)))); } +// sqrt(x) * sqrt(y) is not simplified. +TEST_F(AlgebraicSimplifierTest, MultiplySqrtDifferentOperands) { + const char* kModuleStr = R"( + HloModule m + test { + p0 = f32[1,32] parameter(0) + abs = f32[1,32] abs(p0) + exp = f32[1,32] exponential(p0) + sqrt = f32[1,32] sqrt(abs) + sqrt2 = f32[1,32] sqrt(exp) + ROOT mul = f32[1,32] multiply(sqrt, sqrt2) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_FALSE(AlgebraicSimplifier(default_options_).Run(m.get()).value()); +} + // sqrt(x) * sqrt(x) ≠> x // if x is arbitrary number - no simplification TEST_F(AlgebraicSimplifierTest, MultiplySelfSqrt_NegativeTestCase) { @@ -9253,6 +9270,23 @@ TEST_F(AlgebraicSimplifierTest, MultiplySelfRsqrt) { m::Abs(m::Parameter(0))))); } +// rsqrt(x) * rsqrt(y) is not simplified. +TEST_F(AlgebraicSimplifierTest, MultiplyRsqrtDifferentOperands) { + const char* kModuleStr = R"( + HloModule m + test { + p0 = f32[1,32] parameter(0) + abs = f32[1,32] abs(p0) + exp = f32[1,32] exponential(p0) + rsqrt = f32[1,32] rsqrt(abs) + rsqrt2 = f32[1,32] rsqrt(exp) + ROOT mul = f32[1,32] multiply(rsqrt, rsqrt2) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_FALSE(AlgebraicSimplifier(default_options_).Run(m.get()).value()); +} + // rsqrt(x) * rsqrt(x) -> 1/x // if x is arbitrary number - no simplification TEST_F(AlgebraicSimplifierTest, MultiplySelfRsqrt_NegativeTestCase) { From 9e0957281ee29229a2f907cad196a041c1b126d8 Mon Sep 17 00:00:00 2001 From: Shawn Wang Date: Tue, 25 Jun 2024 03:34:28 -0700 Subject: [PATCH 216/256] PR #13224: [XLA:GPU] Adding an option that allocates temp buffer through separate memory allocator. 
Imported from GitHub PR https://github.com/openxla/xla/pull/13224 This PR will assign XLA temp buffer to a dedicated memory space, and uses a standalone cuda async allocator to allocate temp buffer, as temp buffer is always allocated & deallocated for each step, this will ensure that temp buffer can be allocated to a stable address, which is good for cuda-graph's perf. Copybara import of the project: -- 2b27e0d0c83ce5450fc40364992f2feba0156d30 by Shawn Wang : use separate memory allocator for XLA module temp buffer Merging this change closes #13224 PiperOrigin-RevId: 646411996 --- third_party/xla/xla/debug_options_flags.cc | 10 ++ third_party/xla/xla/pjrt/gpu/BUILD | 2 + .../xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc | 127 ++++++++++-------- .../xla/xla/service/buffer_assignment.cc | 30 +++-- .../xla/xla/service/buffer_assignment.h | 11 +- .../service/gpu/compile_module_to_llvm_ir.cc | 11 +- .../service/gpu/gpu_memory_space_assignment.h | 1 + .../gpu/gpu_cudamallocasync_allocator_test.cc | 2 - third_party/xla/xla/xla.proto | 9 +- 9 files changed, 130 insertions(+), 73 deletions(-) diff --git a/third_party/xla/xla/debug_options_flags.cc b/third_party/xla/xla/debug_options_flags.cc index 6e4d82c877f992..dc3cdf54d4ce60 100644 --- a/third_party/xla/xla/debug_options_flags.cc +++ b/third_party/xla/xla/debug_options_flags.cc @@ -149,6 +149,8 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_gpu_enable_nccl_comm_splitting(false); opts.set_xla_gpu_enable_nccl_per_stream_comms(false); + opts.set_xla_gpu_temp_buffer_use_separate_color(false); + // Set 4GB space limit for redzone scratch allocator. opts.set_xla_gpu_redzone_scratch_max_megabytes(1LL << 12); opts.set_xla_gpu_redzone_padding_bytes(8 * 1024 * 1024); @@ -1281,6 +1283,14 @@ void MakeDebugOptionsFlags(std::vector* flag_list, "Enables NCCL User Buffer Registration. 
collective_memory_size in the " "allocator config must also be set to a non-zero value that is large " "enough to meet peak collective memory usage.")); + flag_list->push_back(tsl::Flag( + "xla_gpu_temp_buffer_use_separate_color", + bool_setter_for( + &DebugOptions::set_xla_gpu_temp_buffer_use_separate_color), + debug_options->xla_gpu_temp_buffer_use_separate_color(), + "Enables temp User Buffer Registration. Enable this flag will use a " + "separate cuda async memory allocator to allocate temp buffer, this will " + "allocate temp buffer to the fixed address on every iteration")); flag_list->push_back(tsl::Flag( "xla_gpu_enable_nccl_comm_splitting", bool_setter_for(&DebugOptions::set_xla_gpu_enable_nccl_comm_splitting), diff --git a/third_party/xla/xla/pjrt/gpu/BUILD b/third_party/xla/xla/pjrt/gpu/BUILD index 42e5f6f3c5c8c8..fb8184b8a87b24 100644 --- a/third_party/xla/xla/pjrt/gpu/BUILD +++ b/third_party/xla/xla/pjrt/gpu/BUILD @@ -123,6 +123,8 @@ cc_library( ]) + if_cuda([ "@local_config_cuda//cuda:cuda_headers", "//xla/stream_executor/gpu:gpu_cudamallocasync_allocator", + "//xla/service/gpu:gpu_memory_space_assignment", + "//xla:debug_options_flags", ]) + if_rocm([ "@local_config_rocm//rocm:rocm_headers", ]), diff --git a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc index 5bf6bac568779c..7d2bc51a04b199 100644 --- a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc +++ b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc @@ -93,11 +93,13 @@ limitations under the License. 
#include "tsl/profiler/lib/traceme.h" #if defined(GOOGLE_CUDA) || defined(TENSORFLOW_USE_ROCM) +#include "xla/debug_options_flags.h" #include "xla/pjrt/compile_options.pb.h" #include "xla/pjrt/gpu/gpu_metrics.h" #include "xla/pjrt/gpu/nccl_id_store.h" #include "xla/pjrt/stream_executor_executable.pb.h" #include "xla/service/gpu/gpu_compiler.h" +#include "xla/service/gpu/gpu_memory_space_assignment.h" #include "xla/xla.pb.h" #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM @@ -797,59 +799,53 @@ namespace { #if defined(GOOGLE_CUDA) && CUDA_VERSION >= 11020 -absl::StatusOr> -CreateCudaAsyncAllocator( - se::Platform* platform, - const std::map>& addressable_devices, - double memory_fraction, bool preallocate) { - CHECK_GT(addressable_devices.size(), 0); - std::vector allocators; +absl::StatusOr> +CreateCudaAsyncAllocator(const LocalDeviceState& device, double memory_fraction, + bool reserve_memory, bool create_new_pool, + bool sync_mode, bool compute_stats = true) { + se::StreamExecutor* executor = device.executor(); + int device_ordinal = executor->device_ordinal(); + + int64_t free_memory; + int64_t total_memory; + if (!executor->DeviceMemoryUsage(&free_memory, &total_memory)) { + return Unavailable("Failed to query available memory from device %i", + device_ordinal); + } + // To allow full GPU memory to be visible to the Cuda Async allocator + // if using unified memory. + // When unified memory is enabled, allow GPU memory oversubscription by + // setting memory_fraction > 1. 
+ size_t allocator_memory = total_memory * memory_fraction; + if (reserve_memory) { + LOG(INFO) << "XLA backend allocating " << allocator_memory + << " bytes on device " << device_ordinal + << " for CudaAsyncAllocator."; + } else { + LOG(INFO) << "XLA backend will use up to " << allocator_memory + << " bytes on device " << device_ordinal + << " for CudaAsyncAllocator."; + } - for (auto& ordinal_and_device : addressable_devices) { - se::StreamExecutor* executor = ordinal_and_device.second->executor(); - int device_ordinal = executor->device_ordinal(); + auto allocator = std::make_unique( + /*platform_device_id*/ tsl::PlatformDeviceId(device_ordinal), + /*create_new_pool*/ create_new_pool, + /*new_pool_size*/ allocator_memory, + /*reserve_memory*/ reserve_memory, + /*reserve_memory_size*/ reserve_memory ? allocator_memory : 0, + /*sync_mode*/ sync_mode, + /*compute_stats*/ compute_stats); - int64_t free_memory; - int64_t total_memory; - if (!executor->DeviceMemoryUsage(&free_memory, &total_memory)) { - return Unavailable("Failed to query available memory from device %i", - device_ordinal); - } - // To allow full GPU memory to be visible to the Cuda Async allocator - // if using unified memory. - // When unified memory is enabled, allow GPU memory oversubscription by - // setting memory_fraction > 1. 
- size_t allocator_memory = total_memory * memory_fraction; - if (preallocate) { - LOG(INFO) << "XLA backend allocating " << allocator_memory - << " bytes on device " << device_ordinal - << " for CudaAsyncAllocator."; - } else { - LOG(INFO) << "XLA backend will use up to " << allocator_memory - << " bytes on device " << device_ordinal - << " for CudaAsyncAllocator."; - } + allocator->SetStreamAndPreallocateMemory( + device.compute_stream()->platform_specific_handle().stream); - auto allocator = std::make_unique( - tsl::PlatformDeviceId(device_ordinal), allocator_memory, preallocate); - allocator->SetStreamAndPreallocateMemory( - ordinal_and_device.second->compute_stream() - ->platform_specific_handle() - .stream); - allocators.emplace_back(std::move(allocator), - ordinal_and_device.second->compute_stream(), - /*memory_space=*/0); - } - return allocators; + return allocator; } #else // defined(GOOGLE_CUDA) && CUDA_VERSION >= 11020 - -absl::StatusOr> -CreateCudaAsyncAllocator( - se::Platform* platform, - const std::map>& addressable_devices, - double memory_fraction, bool preallocate) { +absl::StatusOr> CreateCudaAsyncAllocator( + const LocalDeviceState& device, double memory_fraction, bool reserve_memory, + bool create_new_pool, bool sync_mode, bool compute_stats = true) { return FailedPrecondition("CUDA async allocator requires CUDA >= 11.2"); } @@ -881,17 +877,17 @@ GetStreamExecutorGpuDeviceAllocator( std::vector allocators; switch (allocator_config.kind) { case GpuAllocatorConfig::Kind::kCudaAsync: { - auto allocators_or = CreateCudaAsyncAllocator( - platform, addressable_devices, allocator_config.memory_fraction, - allocator_config.preallocate); - if (allocators_or.ok()) { - LOG(INFO) << "Using CUDA async allocator."; - allocators = std::move(allocators_or.value()); - break; + for (const auto& ordinal_and_device : addressable_devices) { + TF_ASSIGN_OR_RETURN( + auto async_allocator, + CreateCudaAsyncAllocator( + *(ordinal_and_device.second), 
allocator_config.memory_fraction, + allocator_config.preallocate, false, false, true)); + allocators.emplace_back(std::move(async_allocator), + ordinal_and_device.second->compute_stream(), + /*memory_space=*/0); } - LOG(ERROR) << "Failed to initialize CUDA async allocator: " - << allocators_or.status() << "; falling back to BFC."; - [[fallthrough]]; + break; } case GpuAllocatorConfig::Kind::kDefault: @@ -945,6 +941,23 @@ GetStreamExecutorGpuDeviceAllocator( static_cast(se::MemoryType::kHost)); } +#if defined(GOOGLE_CUDA) && CUDA_VERSION >= 11020 + const auto& debug_options = xla::GetDebugOptionsFromFlags(); + if (debug_options.xla_gpu_temp_buffer_use_separate_color()) { + // Add memory allocator to allocate memory buffers with persistent temp + // memory space color. + for (const auto& ordinal_and_device : addressable_devices) { + TF_ASSIGN_OR_RETURN( + auto async_allocator, + CreateCudaAsyncAllocator(*(ordinal_and_device.second), 1.0, false, + true, true, true)); + allocators.emplace_back( + std::move(async_allocator), + ordinal_and_device.second->compute_stream(), + /*memory_space=*/gpu::kTempBufferMemorySpaceColor); + } + } +#endif return std::make_unique(platform, std::move(allocators)); } diff --git a/third_party/xla/xla/service/buffer_assignment.cc b/third_party/xla/xla/service/buffer_assignment.cc index 0f366f7d514272..77aa9283aaade0 100644 --- a/third_party/xla/xla/service/buffer_assignment.cc +++ b/third_party/xla/xla/service/buffer_assignment.cc @@ -22,6 +22,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -40,6 +41,7 @@ limitations under the License. 
#include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/utils/hlo_live_range.h" #include "xla/map_util.h" +#include "xla/service/buffer_value.h" #include "xla/service/buffer_value_containers.h" #include "xla/service/heap_simulator/heap_simulator.h" #include "xla/service/hlo.pb.h" @@ -251,9 +253,11 @@ void BufferAllocation::AddAssignment(const HloValue& buffer, int64_t offset, CHECK_LE(offset + size, size_) << "LogicalBuffer " << buffer << " size out of range at offset: " << offset << " with size: " << size; - CHECK_EQ(buffer.color(), color()) - << "Buffer color " << buffer.color() << " for buffer " << buffer - << " does not match allocation color " << color() << "."; + if (!(IsPreallocatedTempBuffer() && color() != 0)) { + CHECK_EQ(buffer.color(), color()) + << "Buffer color " << buffer.color() << " for buffer " << buffer + << " does not match allocation color " << color() << "."; + } OffsetSize offset_size; offset_size.offset = offset; offset_size.size = size; @@ -616,7 +620,8 @@ void BufferAssignment::AddAssignment(BufferAllocation* allocation, // Combines allocations of temporary buffers of the same color into one big // BufferAllocation. void BufferAssignment::CombineTempAllocations( - const absl::flat_hash_set& private_stack_colors) { + const absl::flat_hash_set& private_stack_colors, + std::optional temp_buffer_color) { VLOG(1) << "CombineTempAllocations()"; // Stores the combined allocations. std::deque combined_allocations; @@ -700,6 +705,12 @@ void BufferAssignment::CombineTempAllocations( temp_allocation.peak_buffers_.begin(), temp_allocation.peak_buffers_.end()); } + + if (temp_buffer_color.has_value()) { + if (combined_allocation->color() == 0) { + combined_allocation->set_color(temp_buffer_color.value()); + } + } } // Replace all existing temporary allocations with the new combined // allocations. 
@@ -1091,13 +1102,14 @@ absl::StatusOr> BufferAssigner::Run( const PrivateStacks& private_stacks, GlobalDecreasingSizeBestFitHeap::BufferIntervalCompare heap_buffer_interval_compare, - std::optional isolation_options) { + std::optional isolation_options, + std::optional temp_buffer_color) { BufferAssigner assigner(allocate_buffers_for_constants, std::move(colorer), must_not_live_out, std::move(preset_assignments)); return assigner.CreateAssignment( module, std::move(hlo_ordering), std::move(buffer_size), std::move(color_alignment), std::move(can_share_buffer), private_stacks, - heap_buffer_interval_compare, isolation_options); + heap_buffer_interval_compare, isolation_options, temp_buffer_color); } bool BufferAssigner::LiveRangeInterferes(const HloValue* buffer1, @@ -2002,7 +2014,8 @@ BufferAssigner::CreateAssignment( const PrivateStacks& private_stacks, GlobalDecreasingSizeBestFitHeap::BufferIntervalCompare heap_buffer_interval_compare, - std::optional isolation_options) { + std::optional isolation_options, + std::optional temp_buffer_color) { TF_ASSIGN_OR_RETURN(std::unique_ptr alias_analysis, HloAliasAnalysis::Run(module, can_share_buffer)); @@ -2110,7 +2123,8 @@ BufferAssigner::CreateAssignment( for (const auto& [color, computations] : private_stacks) { private_stack_colors.insert(color); } - assignment->CombineTempAllocations(private_stack_colors); + + assignment->CombineTempAllocations(private_stack_colors, temp_buffer_color); XLA_VLOG_LINES(2, assignment->ToString()); TF_RETURN_IF_ERROR(assignment->ComputeSummaryStats()); diff --git a/third_party/xla/xla/service/buffer_assignment.h b/third_party/xla/xla/service/buffer_assignment.h index b3f05e345bd767..337d9faa9f64ac 100644 --- a/third_party/xla/xla/service/buffer_assignment.h +++ b/third_party/xla/xla/service/buffer_assignment.h @@ -163,6 +163,7 @@ class BufferAllocation { // color can reside in this allocation. 
LogicalBuffer::Color color() const { return color_; } + void set_color(LogicalBuffer::Color color) { color_ = color; } struct OffsetSize { int64_t offset = 0; int64_t size = 0; @@ -579,7 +580,8 @@ class BufferAssignment { // Combines allocations of temporary buffers into one big BufferAllocation. void CombineTempAllocations( - const absl::flat_hash_set& private_stack_colors); + const absl::flat_hash_set& private_stack_colors, + std::optional temp_buffer_color); // Computes stats for the assignment, to be retrieved by GetStats. absl::Status ComputeSummaryStats(); @@ -665,7 +667,8 @@ class BufferAssigner { GlobalDecreasingSizeBestFitHeap::BufferIntervalCompare heap_buffer_interval_compare = nullptr, std::optional - isolation_options = std::nullopt); + isolation_options = std::nullopt, + std::optional temp_buffer_color = std::nullopt); private: BufferAssigner(bool allocate_buffers_for_constants, Colorer colorer, @@ -687,8 +690,8 @@ class BufferAssigner { const PrivateStacks& private_stacks, GlobalDecreasingSizeBestFitHeap::BufferIntervalCompare heap_buffer_interval_compare, - std::optional - isolation_options); + std::optional isolation_options, + std::optional temp_buffer_color); // Assigns buffers to the instructions in the given computations. "assignment" // is modified to reflect the new buffer assignments. If is_thread_local is diff --git a/third_party/xla/xla/service/gpu/compile_module_to_llvm_ir.cc b/third_party/xla/xla/service/gpu/compile_module_to_llvm_ir.cc index ea0a19c4568358..1f09d44d974f7a 100644 --- a/third_party/xla/xla/service/gpu/compile_module_to_llvm_ir.cc +++ b/third_party/xla/xla/service/gpu/compile_module_to_llvm_ir.cc @@ -183,7 +183,16 @@ absl::StatusOr CompileModuleToLlvmIr( .xla_gpu_enable_nccl_user_buffers() ? 
CollectiveColorer() : BufferAssigner::DefaultColorer(), - /*must_not_live_out=*/{}, can_share_buffer_function)); + /*must_not_live_out=*/{}, + /*can_share_buffer*/ can_share_buffer_function, + /*preset_assignments*/ {}, + /*private_stack*/ {}, /*heap_buffer_interval_compare*/ nullptr, + /*isolation_options*/ std::nullopt, + hlo_module->config() + .debug_options() + .xla_gpu_temp_buffer_use_separate_color() + ? std::optional(kTempBufferMemorySpaceColor) + : std::nullopt)); } VLOG(1) << "Buffer Assignment Stats for " << hlo_module->name() << "\n" << results.buffer_assignment->GetStats().ToString(); diff --git a/third_party/xla/xla/service/gpu/gpu_memory_space_assignment.h b/third_party/xla/xla/service/gpu/gpu_memory_space_assignment.h index ebd1af9dfb2b25..fff43614afd98e 100644 --- a/third_party/xla/xla/service/gpu/gpu_memory_space_assignment.h +++ b/third_party/xla/xla/service/gpu/gpu_memory_space_assignment.h @@ -29,6 +29,7 @@ namespace xla { namespace gpu { inline constexpr int64_t kCollectiveMemorySpaceColor = 1; +inline constexpr int64_t kTempBufferMemorySpaceColor = 2; // Set memory space to kCollectiveMemorySpaceColor for all allocations used by // all-reduce, all-gather, and reduce-scatter. 
This memory space maps to diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator_test.cc b/third_party/xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator_test.cc index 5fb64838554809..0b155be7f91ced 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator_test.cc +++ b/third_party/xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator_test.cc @@ -88,7 +88,6 @@ TEST(GpuCudaMallocAsyncAllocator, AddressAlignedNewPool) { /*compute_stats*/ true); allocator.SetStreamAndPreallocateMemory( se::gpu::AsGpuStreamValue(stream.get())); - void* addr1 = allocator.AllocateRaw(128, 127); void* addr2 = allocator.AllocateRaw(128, 129); CHECK_EQ((reinterpret_cast(addr1) & 127), 0); @@ -115,7 +114,6 @@ TEST(GpuCudaMallocAsyncAllocator, SyncAddressAlignedNewPool) { /*compute_stats*/ true); allocator.SetStreamAndPreallocateMemory( se::gpu::AsGpuStreamValue(stream.get())); - void* addr1 = allocator.AllocateRaw(128, 127); void* addr2 = allocator.AllocateRaw(128, 129); CHECK_EQ((reinterpret_cast(addr1) & 127), 0); diff --git a/third_party/xla/xla/xla.proto b/third_party/xla/xla/xla.proto index b8b7a4a3f43def..19fc5ef5b839b7 100644 --- a/third_party/xla/xla/xla.proto +++ b/third_party/xla/xla/xla.proto @@ -831,7 +831,14 @@ message DebugOptions { // updating command buffer instance. int64 xla_cmd_buffer_trace_cache_size = 311; - // Next id: 312 + // Enable this flag will use a separate memory space color for + // temp buffer, and then will use separate memory allocator to allocate it, + // as there is no other memory allocation interference, + // it will allocate temp buffer to some fix address on every iteration, + // which is good for cuda-graph perf. + bool xla_gpu_temp_buffer_use_separate_color = 312; + + // Next id: 313 // Extra options to pass to the compilation backend (e.g. LLVM); specific // interpretation of these values is left to the backend. 
From 821993cbaa8fa4cd8686ecfc81aab465b110529f Mon Sep 17 00:00:00 2001 From: Johannes Reifferscheid Date: Tue, 25 Jun 2024 04:38:15 -0700 Subject: [PATCH 217/256] Canonicalize nested sums correctly. Currently, we only canonicalize the order of LHS and RHS. In some very rare cases, the structure of an expression is such that `canonicalize(simplify-mlir(expr))` is not idempotent, leading to infinite loops in apply_indexing canonicalization. PiperOrigin-RevId: 646427921 --- .../xla/service/gpu/fusions/loop_mlir_test.cc | 4 +- .../xla/xla/service/gpu/fusions/loop_test.cc | 4 +- .../gpu/fusions/reduction_mlir_test.cc | 2 +- .../gpu/fusions/transpose_mlir_test.cc | 2 +- .../xla/xla/service/gpu/model/indexing_map.cc | 84 ++++++++++--------- .../service/gpu/model/indexing_map_test.cc | 12 +++ 6 files changed, 64 insertions(+), 44 deletions(-) diff --git a/third_party/xla/xla/service/gpu/fusions/loop_mlir_test.cc b/third_party/xla/xla/service/gpu/fusions/loop_mlir_test.cc index 52feb6c50fea11..ccc86600eb10a8 100644 --- a/third_party/xla/xla/service/gpu/fusions/loop_mlir_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/loop_mlir_test.cc @@ -54,8 +54,8 @@ TEST_F(MlirLoopFusionTest, ThreadId_IndexingUnrolled) { EXPECT_THAT(thread_id_to_output_indexing->ToString(thread_id_printer_), MatchIndexingString(R"( (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> ( - ((bl_x * 128 + th_x + chunk_id * 129024) floordiv 15000) mod 100, - ((bl_x * 128 + th_x + chunk_id * 129024) floordiv 75) mod 200, + ((bl_x * 128 + chunk_id * 129024 + th_x) floordiv 15000) mod 100, + ((bl_x * 128 + chunk_id * 129024 + th_x) floordiv 75) mod 200, (th_x * 4 + bl_x * 512 + chunk_id * 516096) mod 300 + unroll_id ) domain: diff --git a/third_party/xla/xla/service/gpu/fusions/loop_test.cc b/third_party/xla/xla/service/gpu/fusions/loop_test.cc index eb583f6def0da8..c28491b18448a7 100644 --- a/third_party/xla/xla/service/gpu/fusions/loop_test.cc +++ 
b/third_party/xla/xla/service/gpu/fusions/loop_test.cc @@ -88,8 +88,8 @@ TEST_F(LoopTest, ThreadIndexingUnrolled) { EXPECT_THAT(thread_id_to_output_indexing->ToString(printer_), MatchIndexingString(R"( (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> ( - ((bl_x * 128 + th_x + chunk_id * 129024) floordiv 15000) mod 100, - ((bl_x * 128 + th_x + chunk_id * 129024) floordiv 75) mod 200, + ((bl_x * 128 + chunk_id * 129024 + th_x) floordiv 15000) mod 100, + ((bl_x * 128 + chunk_id * 129024 + th_x) floordiv 75) mod 200, (th_x * 4 + bl_x * 512 + chunk_id * 516096) mod 300 + unroll_id ) domain: diff --git a/third_party/xla/xla/service/gpu/fusions/reduction_mlir_test.cc b/third_party/xla/xla/service/gpu/fusions/reduction_mlir_test.cc index 35e8656ae0a2e9..a9dfd784b924df 100644 --- a/third_party/xla/xla/service/gpu/fusions/reduction_mlir_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/reduction_mlir_test.cc @@ -362,7 +362,7 @@ TEST_F(MlirRowReductionTest, NonPowerOfTwoRowReduction) { ROOT fusion = f64[100] fusion(a, c), kind=kInput, calls=fused_computation })"; TF_EXPECT_OK(EmitAndCheckIR(kHloString, R"( - // CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1)[s0] -> (d0 + s0 * 128 + (d1 mod 64) * 2)> + // CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1)[s0] -> ((d1 mod 64) * 2 + s0 * 128 + d0)> // CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1) -> ((d1 mod 64) * 2 + d0 + 512)> // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index diff --git a/third_party/xla/xla/service/gpu/fusions/transpose_mlir_test.cc b/third_party/xla/xla/service/gpu/fusions/transpose_mlir_test.cc index a54885894f96b4..33d18cd2461ce6 100644 --- a/third_party/xla/xla/service/gpu/fusions/transpose_mlir_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/transpose_mlir_test.cc @@ -221,7 +221,7 @@ TEST_F(MlirTransposeFusionTest, ThreadIndexingVectorized210) { (d0, d1, d2, d3, d4, d5)[s0, s1] -> ( d0 floordiv 32 + s0 * 4, d3 floordiv 128, - (d0 mod 
32) * 2 + s1 + (d3 mod 128) * 64 + (d0 mod 32) * 2 + (d3 mod 128) * 64 + s1 ) domain: d0 in [0, 128) diff --git a/third_party/xla/xla/service/gpu/model/indexing_map.cc b/third_party/xla/xla/service/gpu/model/indexing_map.cc index f732af92bd5d45..b455336daec615 100644 --- a/third_party/xla/xla/service/gpu/model/indexing_map.cc +++ b/third_party/xla/xla/service/gpu/model/indexing_map.cc @@ -80,10 +80,38 @@ using mlir::MLIRContext; AffineExpr GetLhs(AffineExpr e) { return mlir::cast(e).getLHS(); -}; +} + AffineExpr GetRhs(AffineExpr e) { return mlir::cast(e).getRHS(); -}; +} + +// Rewrites summands in arbitrarily nested sums (e.g, ((a+b)+c)) by applying +// `fn` to each one. In the example, the result is fn(a)+fn(b)+fn(c). +AffineExpr MapSummands(AffineExpr expr, + const std::function& fn) { + if (expr.getKind() == AffineExprKind::Add) { + auto add = mlir::dyn_cast(expr); + auto lhs = MapSummands(add.getLHS(), fn); + auto rhs = MapSummands(add.getRHS(), fn); + if (lhs == add.getLHS() && rhs == add.getRHS()) { + return add; + } + return lhs + rhs; + } + return fn(expr); +} + +// Calls `visit` for each summand in an arbitrarily nested sum. +void VisitSummands(mlir::AffineExpr expr, + const std::function& visit) { + if (expr.getKind() == AffineExprKind::Add) { + VisitSummands(GetLhs(expr), visit); + VisitSummands(GetRhs(expr), visit); + } else { + visit(expr); + } +} class AffineExprSimplifier { public: @@ -124,13 +152,6 @@ class AffineExprSimplifier { // Rewrites `a // b` where a may be a sum. AffineExpr SimplifySumDiv(AffineExpr dividend, int64_t divisor); - // Rewrites summands in arbitrarily nested sums (e.g, ((a+b)+c)) by applying - // `fn` to each one. In the example, the result is fn(a)+fn(b)+fn(c). - AffineExpr MapSummands(AffineExpr expr, - const std::function& fn); - void VisitSummands(mlir::AffineExpr expr, - const std::function& visit); - // Attempts to simplify the expression, but doesn't attempt to simplify the // result further. 
mlir::AffineExpr SimplifyOnce(mlir::AffineExpr expr); @@ -371,30 +392,6 @@ std::optional AffineExprSimplifier::GetConstantRhs( return bound.lower; } -AffineExpr AffineExprSimplifier::MapSummands( - AffineExpr expr, const std::function& fn) { - if (expr.getKind() == AffineExprKind::Add) { - auto add = mlir::dyn_cast(expr); - auto lhs = MapSummands(add.getLHS(), fn); - auto rhs = MapSummands(add.getRHS(), fn); - if (lhs == add.getLHS() && rhs == add.getRHS()) { - return add; - } - return lhs + rhs; - } - return fn(expr); -} - -void AffineExprSimplifier::VisitSummands( - mlir::AffineExpr expr, const std::function& visit) { - if (expr.getKind() == AffineExprKind::Add) { - VisitSummands(GetLhs(expr), visit); - VisitSummands(GetRhs(expr), visit); - } else { - visit(expr); - } -} - // Compares the two expression by their AST. The ordering is arbitrary but // similar to what MLIR's simplifier does. int CompareExprs(AffineExpr a, AffineExpr b) { @@ -443,14 +440,25 @@ int CompareExprs(AffineExpr a, AffineExpr b) { } AffineExpr CanonicalizeOrder(AffineExpr in) { + if (in.getKind() == AffineExprKind::Add) { + // If we have nested adds, canonicalize all of them together. 
+ llvm::SmallVector summands; + VisitSummands(in, [&](AffineExpr summand) { + summands.push_back(CanonicalizeOrder(summand)); + }); + llvm::sort(summands, [](AffineExpr a, AffineExpr b) { + return CompareExprs(a, b) < 0; + }); + auto result = mlir::getAffineConstantExpr(0, in.getContext()); + for (auto summand : summands) { + result = result + summand; + } + return result; + } + if (auto binop = mlir::dyn_cast(in)) { auto lhs = CanonicalizeOrder(binop.getLHS()); auto rhs = CanonicalizeOrder(binop.getRHS()); - if ((binop.getKind() == AffineExprKind::Add || - binop.getKind() == AffineExprKind::Mul) && - CompareExprs(lhs, rhs) > 0) { - std::swap(lhs, rhs); - } return getAffineBinaryOpExpr(binop.getKind(), lhs, rhs); } return in; diff --git a/third_party/xla/xla/service/gpu/model/indexing_map_test.cc b/third_party/xla/xla/service/gpu/model/indexing_map_test.cc index 8cf7cb7201ecf1..b497ad39b3575d 100644 --- a/third_party/xla/xla/service/gpu/model/indexing_map_test.cc +++ b/third_party/xla/xla/service/gpu/model/indexing_map_test.cc @@ -562,6 +562,18 @@ TEST_F(IndexingMapTest, AffineMapSimplification_ConstantDims) { )")); } +TEST_F(IndexingMapTest, AffineMapSimplification_SumOrderRegression) { + // This is a regression test for a bug where we didn't canonicalize the order + // of summands correctly, leading to `Simplify` not being idempotent. 
+ IndexingMap indexing_map = IndexingMap::FromTensorSizes( + ParseAffineMap("(d0, d1)[s0, s1] -> (((((d0 + (d0 mod 3)) floordiv 3) + " + "(s0 + ((s0 + s0) mod 3))) + (((d0 + s0) mod 3) + 0)))", + &mlir_context_), + {10, 20}, {30, 40}); + EXPECT_TRUE(indexing_map.Simplify()); + EXPECT_FALSE(indexing_map.Simplify()); +} + TEST_F(IndexingMapTest, AffineMapSimplification_DivsAndModsIfSmallerThanDivisor) { auto serialized_map = "(d0, d1) -> (d0 + d1 floordiv 16, d1 mod 16)"; From 9ebb13917879a7ec500e0c7b895bc9b9eddd0f07 Mon Sep 17 00:00:00 2001 From: Johannes Reifferscheid Date: Tue, 25 Jun 2024 04:41:10 -0700 Subject: [PATCH 218/256] Don't generate vector<1x*> in shuffle_reduce lowering. Previously, this was the case for types < 32 bits in width. The generated ptx is the same, but the intermediate IR is unnecessarily complex. PiperOrigin-RevId: 646428725 --- .../gpu/fusions/mlir/lower_xla_gpu_to_scf.cc | 23 +++++++++++-------- .../mlir/tests/lower_xla_gpu_to_scf.mlir | 17 ++++++++++++++ 2 files changed, 31 insertions(+), 9 deletions(-) diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/lower_xla_gpu_to_scf.cc b/third_party/xla/xla/service/gpu/fusions/mlir/lower_xla_gpu_to_scf.cc index 4222fd9db15767..c2cc64536cf80b 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/lower_xla_gpu_to_scf.cc +++ b/third_party/xla/xla/service/gpu/fusions/mlir/lower_xla_gpu_to_scf.cc @@ -125,16 +125,21 @@ struct RewriteShuffleReduce : mlir::OpRewritePattern { auto padded_int_ty = b.getIntegerType(n_shuffles * 32); value = b.create(int_ty, value); value = b.create(padded_int_ty, value); - auto vector_type = ml::getVectorType(b.getI32Type(), n_shuffles); - value = b.create(vector_type, value); - mlir::Value result_vec = b.create(vector_type); - for (int i = 0; i < n_shuffles; ++i) { - auto idx = b.create(i, 32); - result_vec = b.create( - result_vec, - shuffle_32(b.create(value, idx)), idx); + if (n_shuffles > 1) { + // Don't generate vectors if the size is 1. 
+ auto vector_type = ml::getVectorType(b.getI32Type(), n_shuffles); + value = b.create(vector_type, value); + mlir::Value result_vec = b.create(vector_type); + for (int i = 0; i < n_shuffles; ++i) { + auto idx = b.create(i, 32); + result_vec = b.create( + result_vec, + shuffle_32(b.create(value, idx)), idx); + } + value = b.create(padded_int_ty, result_vec); + } else { + value = shuffle_32(value); } - value = b.create(padded_int_ty, result_vec); value = b.create(int_ty, value); value = b.create(ty, value); return value; diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/tests/lower_xla_gpu_to_scf.mlir b/third_party/xla/xla/service/gpu/fusions/mlir/tests/lower_xla_gpu_to_scf.mlir index efc8bdb2a28953..645430ae0d1bcc 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/tests/lower_xla_gpu_to_scf.mlir +++ b/third_party/xla/xla/service/gpu/fusions/mlir/tests/lower_xla_gpu_to_scf.mlir @@ -79,6 +79,23 @@ module { // ----- +module { + func.func @reducer(%a: i8, %b: i8) -> i8 { + return %a : i8 + } + + func.func @shuffler_i8(%a: i8) -> i8 { + %ret = xla_gpu.shuffle_reduce @reducer(%a) to 1 : i8 + return %ret : i8 + } +} + +// CHECK: @shuffler_i8( +// CHECK-NOT: vector +// CHECK-COUNT-1: gpu.shuffle down {{.*}}, %[[C1]] + +// ----- + module { func.func @predicated_insert( %v: i32, %tensor: tensor<2xi32>, %index: index, From 3183231b3beb37dbbd1581fc664e291e8d3e8113 Mon Sep 17 00:00:00 2001 From: Shu Wang Date: Tue, 25 Jun 2024 04:41:19 -0700 Subject: [PATCH 219/256] PR #13920: Move collective pipeliner after post-layout opts and add a unit test Imported from GitHub PR https://github.com/openxla/xla/pull/13920 @kaixih @frgossen This is just a rebased version of https://github.com/openxla/xla/pull/12866 Copybara import of the project: -- dcf940b6f41e3e14b8e999e36055c07bf468f9a1 by shuw : Move collective pipeliner after post-layout opts and add a unit test -- 9a731e9ec6c84f809804776a3fba2d0b11a60294 by shuw : Change to EXPECT_EQ Merging this change closes #13920 
PiperOrigin-RevId: 646428752 --- third_party/xla/xla/BUILD | 2 + third_party/xla/xla/debug_options_flags.cc | 11 ++ .../xla/xla/service/gpu/gpu_compiler.cc | 70 +++++++---- third_party/xla/xla/tests/BUILD | 1 + .../xla/xla/tests/collective_ops_e2e_test.cc | 116 ++++++++++++++++++ third_party/xla/xla/xla.proto | 3 +- 6 files changed, 179 insertions(+), 24 deletions(-) diff --git a/third_party/xla/xla/BUILD b/third_party/xla/xla/BUILD index f81390565195b0..27f86f61f9c431 100644 --- a/third_party/xla/xla/BUILD +++ b/third_party/xla/xla/BUILD @@ -1112,6 +1112,8 @@ cc_library( "@com_google_absl//absl/base", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:node_hash_map", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@local_tsl//tsl/platform:logging", diff --git a/third_party/xla/xla/debug_options_flags.cc b/third_party/xla/xla/debug_options_flags.cc index dc3cdf54d4ce60..8e07a7428eb168 100644 --- a/third_party/xla/xla/debug_options_flags.cc +++ b/third_party/xla/xla/debug_options_flags.cc @@ -15,6 +15,7 @@ limitations under the License. #include "xla/debug_options_flags.h" +#include #include #include #include @@ -26,6 +27,8 @@ limitations under the License. 
#include "absl/base/call_once.h" #include "absl/container/flat_hash_map.h" #include "absl/container/node_hash_map.h" +#include "absl/log/check.h" +#include "absl/log/log.h" #include "absl/strings/ascii.h" #include "absl/strings/match.h" #include "absl/strings/str_cat.h" @@ -169,6 +172,8 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_gpu_enable_pipelined_reduce_scatter(false); opts.set_xla_gpu_enable_pipelined_p2p(false); + opts.set_xla_gpu_run_post_layout_collective_pipeliner(true); + opts.set_xla_gpu_collective_permute_decomposer_threshold( std::numeric_limits::max()); @@ -1432,6 +1437,12 @@ void MakeDebugOptionsFlags(std::vector* flag_list, bool_setter_for(&DebugOptions::set_xla_gpu_enable_pipelined_p2p), debug_options->xla_gpu_enable_pipelined_p2p(), "Enable pipelinling of P2P instructions.")); + flag_list->push_back(tsl::Flag( + "xla_gpu_run_post_layout_collective_pipeliner", + bool_setter_for( + &DebugOptions::set_xla_gpu_run_post_layout_collective_pipeliner), + debug_options->xla_gpu_run_post_layout_collective_pipeliner(), + "Move collective pipeliner after the post-layout optimization.")); flag_list->push_back(tsl::Flag( "xla_gpu_collective_permute_decomposer_threshold", int64_setter_for( diff --git a/third_party/xla/xla/service/gpu/gpu_compiler.cc b/third_party/xla/xla/service/gpu/gpu_compiler.cc index 06874df49a2714..37ca5ee31e5e8a 100644 --- a/third_party/xla/xla/service/gpu/gpu_compiler.cc +++ b/third_party/xla/xla/service/gpu/gpu_compiler.cc @@ -792,29 +792,8 @@ absl::Status RunOptimizationPasses( return pipeline.Run(hlo_module).status(); } -absl::Status RunCollectiveOptimizationPasses( - HloModule* hlo_module, - const AlgebraicSimplifierOptions& layout_insensitive_algsimp_opts, - se::GpuComputeCapability gpu_version) { - // Optimize collectives generated by SPMD partitioning. Enable these passes - // otherwise as well so that all collectives can get these optimizations. 
- const DebugOptions& debug_options = hlo_module->config().debug_options(); - - HloPassPipeline collectives_pipeline("collective-optimizations"); - collectives_pipeline.AddPass(); - if (debug_options.xla_gpu_enable_all_reduce_splitter()) { - collectives_pipeline.AddPass(); - } - collectives_pipeline.AddPass(); - collectives_pipeline.AddPass(); - collectives_pipeline.AddPass( - debug_options.xla_gpu_enable_reassociation_for_converted_ar()); - collectives_pipeline.AddPass(); - - collectives_pipeline.AddPass( - /*enable_reduce_scatter=*/debug_options - .xla_gpu_enable_while_loop_reduce_scatter_code_motion()); - +absl::Status AddCollectivePipelinerPasses( + const DebugOptions& debug_options, HloPassPipeline& collectives_pipeline) { if (debug_options.xla_gpu_enable_pipelined_collectives() || debug_options.xla_gpu_enable_pipelined_all_reduce()) { CollectivePipeliner::Config config{ @@ -872,6 +851,49 @@ absl::Status RunCollectiveOptimizationPasses( /*reuse_pipelined_op_buffer=*/HloPredicateFalse}; collectives_pipeline.AddPass(config); } + return absl::OkStatus(); +} + +absl::Status RunPostLayoutCollectivePipelinerPasses(HloModule* hlo_module) { + const DebugOptions& debug_options = hlo_module->config().debug_options(); + HloPassPipeline collectives_pipeline("collective-pipeliner-optimizations"); + if (debug_options.xla_gpu_run_post_layout_collective_pipeliner()) { + TF_RETURN_IF_ERROR( + AddCollectivePipelinerPasses(debug_options, collectives_pipeline)); + // We call WhileLoopTripCountAnnotator at the end of the collective + // pipeline, which might have changed the loop trip count. + collectives_pipeline.AddPass(); + } + return collectives_pipeline.Run(hlo_module).status(); +} + +absl::Status RunCollectiveOptimizationPasses( + HloModule* hlo_module, + const AlgebraicSimplifierOptions& layout_insensitive_algsimp_opts, + se::GpuComputeCapability gpu_version) { + // Optimize collectives generated by SPMD partitioning. 
Enable these passes + // otherwise as well so that all collectives can get these optimizations. + const DebugOptions& debug_options = hlo_module->config().debug_options(); + + HloPassPipeline collectives_pipeline("collective-optimizations"); + collectives_pipeline.AddPass(); + if (debug_options.xla_gpu_enable_all_reduce_splitter()) { + collectives_pipeline.AddPass(); + } + collectives_pipeline.AddPass(); + collectives_pipeline.AddPass(); + collectives_pipeline.AddPass( + debug_options.xla_gpu_enable_reassociation_for_converted_ar()); + collectives_pipeline.AddPass(); + + collectives_pipeline.AddPass( + /*enable_reduce_scatter=*/debug_options + .xla_gpu_enable_while_loop_reduce_scatter_code_motion()); + + if (!debug_options.xla_gpu_run_post_layout_collective_pipeliner()) { + TF_RETURN_IF_ERROR( + AddCollectivePipelinerPasses(debug_options, collectives_pipeline)); + } collectives_pipeline.AddPass( hlo_module->config() @@ -1215,6 +1237,8 @@ absl::Status GpuCompiler::OptimizeHloModule( hlo_module, stream_exec, options, gpu_target_config, thread_pool.get_mutable())); + TF_RETURN_IF_ERROR(RunPostLayoutCollectivePipelinerPasses(hlo_module)); + // This is a "low effort, high impact" fusion that should be run first. if (hlo_module->config() .debug_options() diff --git a/third_party/xla/xla/tests/BUILD b/third_party/xla/xla/tests/BUILD index c10e7c04725339..333c103e222810 100644 --- a/third_party/xla/xla/tests/BUILD +++ b/third_party/xla/xla/tests/BUILD @@ -2217,6 +2217,7 @@ xla_test( "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_matchers", "//xla/service/gpu:backend_configs_cc", + "//xla/stream_executor", ], ) diff --git a/third_party/xla/xla/tests/collective_ops_e2e_test.cc b/third_party/xla/xla/tests/collective_ops_e2e_test.cc index fc14363173ad1b..a539f6a161cd2d 100644 --- a/third_party/xla/xla/tests/collective_ops_e2e_test.cc +++ b/third_party/xla/xla/tests/collective_ops_e2e_test.cc @@ -15,6 +15,7 @@ limitations under the License. 
#include #include +#include #include #include "xla/hlo/ir/hlo_casting_utils.h" @@ -24,6 +25,8 @@ limitations under the License. #include "xla/hlo/utils/hlo_matchers.h" #include "xla/literal.h" #include "xla/service/gpu/backend_configs.pb.h" +#include "xla/stream_executor/device_description.h" +#include "xla/stream_executor/stream_executor.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/literal_test_util.h" #include "xla/tests/test_macros.h" @@ -47,6 +50,26 @@ DeviceAssignment MakeDeviceAssn(int64_t num_replicas) { class CollectiveOpsTestE2E : public HloTestBase { public: + bool IsCuda() { + return std::holds_alternative(Capability()); + } + + const se::GpuComputeCapability& Capability() { + return backend() + .default_stream_executor() + ->GetDeviceDescription() + .gpu_compute_capability(); + } + + bool HasFp8Support() { + if (IsCuda()) { + return std::get(Capability()).IsAtLeast(8, 9); + } + return std::get(Capability()) + .has_fp8_support() && + GetDebugOptionsForTest().xla_gpu_enable_cublaslt(); + } + absl::StatusOr> ExecuteReplicated(Executable* executable, int64_t num_replicas) { DeviceAssignment device_assignment = MakeDeviceAssn(num_replicas); @@ -865,5 +888,98 @@ ENTRY main.9_spmd { CollectiveOpsCompareWindowedNonWindowed(kModuleReplicatedStr); } +TEST_F(CollectiveOpsTestE2E, PostLayoutCollectivePipeliner) { + // We need fp8 support to test the post-layout collective pipeliner. This will + // preserve the desired fp8 patterns and so the gemm rewriter can correctly + // recognize them and rewrite to custom fp8 gemm calls. 
+ if (!HasFp8Support()) { + GTEST_SKIP() << "Test requires a post-Ada GPU."; + } + + absl::string_view kModuleReplicatedStr = R"( +HloModule module, entry_computation_layout={(bf16[384,128], bf16[96,128], bf16[], bf16[])->bf16[384,128]}, allow_spmd_sharding_propagation_to_parameters={false,false,false,false}, num_partitions=4 +add { + lhs = bf16[] parameter(0) + rhs = bf16[] parameter(1) + ROOT add = bf16[] add(lhs, rhs) +} +while_cond { + param = (s32[], bf16[384,128], bf16[96,128], bf16[], bf16[]) parameter(0) + gte = s32[] get-tuple-element(param), index=0 + constant.1 = s32[] constant(3) + ROOT cmp = pred[] compare(gte, constant.1), direction=LT +} +while_body { + param = (s32[], bf16[384,128], bf16[96,128], bf16[], bf16[]) parameter(0) + get-tuple-element.394 = s32[] get-tuple-element(param), index=0 + get-tuple-element.395 = bf16[384,128] get-tuple-element(param), index=1 + get-tuple-element.k = bf16[96,128] get-tuple-element(param), index=2 + constant.2561 = s32[] constant(0) + constant.2557 = s32[] constant(1) + add.230 = s32[] add(get-tuple-element.394, constant.2557) + constant.2559 = s32[] constant(3) + subtract.139 = s32[] subtract(constant.2559, get-tuple-element.394) + constant.2560 = s32[] constant(-1) + add.231 = s32[] add(subtract.139, constant.2560) + compare.747 = pred[] compare(add.231, constant.2561), direction=LT + constant.2562 = s32[] constant(2) + add.232 = s32[] add(subtract.139, constant.2562) + select.1348 = s32[] select(compare.747, add.232, add.231) + dynamic-slice.k = bf16[32,128] dynamic-slice(get-tuple-element.k, select.1348, constant.2561), dynamic_slice_sizes={32,128} + r = bf16[32,128] bitcast(dynamic-slice.k) + a = bf16[32,128] add(r, r), control-predecessors={constant.2559} + // A fp8 pattern of quant-dequant before the collective AG. 
+ qa = f8e4m3fn[32,128] convert(a) + dqa = bf16[32,128] convert(qa) + a_scale = bf16[] get-tuple-element(param), index=3 + a_scales = bf16[32,128] broadcast(a_scale), dimensions={} + dqa_unscaled = bf16[32,128] multiply(dqa, a_scales) + mb = bf16[128,128] all-gather(dqa_unscaled), channel_id=1, use_global_device_ids=true, dimensions={0}, replica_groups={{0,1,2,3}} + ma = bf16[128,128] dynamic-slice(get-tuple-element.395, select.1348, constant.2561), dynamic_slice_sizes={128,128} + + qma = f8e4m3fn[128,128] convert(ma) + dqma = bf16[128,128] convert(qma) + ma_scale = bf16[] get-tuple-element(param), index=4 + ma_scales = bf16[128,128] broadcast(ma_scale), dimensions={} + dqma_unscaled = bf16[128,128] multiply(dqma, ma_scales) + mc = bf16[128,128] dot(dqma_unscaled, mb), lhs_contracting_dims={1}, rhs_contracting_dims={0} + dynamic-update-slice.35 = bf16[384,128] dynamic-update-slice(get-tuple-element.395, mc, select.1348, constant.2561) + ROOT tuple = (s32[], bf16[384,128], bf16[96,128], bf16[], bf16[]) tuple(add.230, dynamic-update-slice.35, get-tuple-element.k, a_scale, ma_scale), control-predecessors={a} +} +ENTRY entry { + c0 = s32[] constant(0) + p0 = bf16[384,128] parameter(0) + p1 = bf16[96,128] parameter(1) + s0 = bf16[] parameter(2) + s1 = bf16[] parameter(3) + tuple = (s32[], bf16[384,128], bf16[96,128], bf16[], bf16[]) tuple(c0, p0, p1, s0, s1) + while = (s32[], bf16[384,128], bf16[96,128], bf16[], bf16[]) while(tuple), condition=while_cond, body=while_body + ROOT gte1 = bf16[384,128] get-tuple-element(while), index=1 +} +)"; + + const int64_t kNumReplicas = 1; + const int64_t kNumPartitions = 4; + + HloModuleConfig config = + GetModuleConfigForTest(/*replica_count=*/kNumReplicas); + auto opts = GetDebugOptionsForTest(); + opts.set_xla_gpu_run_post_layout_collective_pipeliner(true); + opts.set_xla_gpu_enable_pipelined_collectives(true); + opts.set_xla_gpu_enable_triton_gemm(false); + config.set_debug_options(opts); + 
config.set_num_partitions(kNumPartitions); + TF_ASSERT_OK_AND_ASSIGN( + auto module, ParseAndReturnVerifiedModule(kModuleReplicatedStr, config)); + + TF_ASSERT_OK_AND_ASSIGN(auto executable, + CreateExecutable(std::move(module), + /*run_hlo_passes=*/true)); + EXPECT_TRUE(executable->has_module()); + HloInstruction* gemm_op = + FindInstruction(&executable->module(), HloOpcode::kCustomCall); + EXPECT_THAT(gemm_op, NotNull()); + EXPECT_EQ(gemm_op->custom_call_target(), "__cublas$lt$matmul$f8"); +} } // namespace } // namespace xla diff --git a/third_party/xla/xla/xla.proto b/third_party/xla/xla/xla.proto index 19fc5ef5b839b7..b8efd2fba7b1bf 100644 --- a/third_party/xla/xla/xla.proto +++ b/third_party/xla/xla/xla.proto @@ -584,6 +584,7 @@ message DebugOptions { bool xla_gpu_enable_pipelined_all_gather = 227; bool xla_gpu_enable_pipelined_reduce_scatter = 231; bool xla_gpu_enable_pipelined_p2p = 246; + bool xla_gpu_run_post_layout_collective_pipeliner = 313; // The minimum data size in bytes to trigger collective-permute-decomposer // transformation. @@ -838,7 +839,7 @@ message DebugOptions { // which is good for cuda-graph perf. bool xla_gpu_temp_buffer_use_separate_color = 312; - // Next id: 313 + // Next id: 314 // Extra options to pass to the compilation backend (e.g. LLVM); specific // interpretation of these values is left to the backend. From 743c1c8e34e11bafd069fe8c84bb343b1a0e8ca7 Mon Sep 17 00:00:00 2001 From: Alexander Belyaev Date: Tue, 25 Jun 2024 04:56:15 -0700 Subject: [PATCH 220/256] [XLA:GPU] Remove obsolete comment. 
PiperOrigin-RevId: 646432168 --- third_party/xla/xla/service/gpu/model/indexing_test_utils.h | 3 --- 1 file changed, 3 deletions(-) diff --git a/third_party/xla/xla/service/gpu/model/indexing_test_utils.h b/third_party/xla/xla/service/gpu/model/indexing_test_utils.h index 1c784191ebf913..62abd0e5e7fdb4 100644 --- a/third_party/xla/xla/service/gpu/model/indexing_test_utils.h +++ b/third_party/xla/xla/service/gpu/model/indexing_test_utils.h @@ -33,9 +33,6 @@ namespace xla { namespace gpu { // Matches two strings ignoring whitespaces. -// 'lhs' may contain regions bounded by the special pattern '###', -// in which case, each region is parsed as a sequence of terms separated by -// '+ signs. The function will try to match all permutations of terms. bool ApproximateMatch(std::string_view lhs, std::string_view rhs); MATCHER(UndefinedMap, "") { return arg.IsUndefined(); } From 812f36f04e02b58a541f3287da54f49b37f10757 Mon Sep 17 00:00:00 2001 From: Shawn Wang Date: Tue, 25 Jun 2024 04:58:38 -0700 Subject: [PATCH 221/256] PR #13946: [XLA:GPU] Fix the bug of thread conflicts for running command buffer cublasLtCmd Imported from GitHub PR https://github.com/openxla/xla/pull/13946 It is a bug that CublasLtCmd's save the gemm plan and algorithm across Initialize and Execute. Copybara import of the project: -- 236e513efa1a93f5aaf014ba31b876f4f6c47c04 by Shawn Wang : add mutex to resolve multi thread conflicting in cublasLt command -- 39a7dd78218dadd69ae360d4024b1bbc9de3c2e5 by Shawn Wang : update the cublasLt test with thread conflicting case. 
Merging this change closes #13946 PiperOrigin-RevId: 646432697 --- .../service/gpu/runtime/command_buffer_cmd.cc | 15 +- .../service/gpu/runtime/command_buffer_cmd.h | 3 - .../gpu/runtime/command_buffer_thunk_test.cc | 140 +++++++++--------- 3 files changed, 83 insertions(+), 75 deletions(-) diff --git a/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd.cc b/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd.cc index 77787478bc7839..4cf4c3a9b9274f 100644 --- a/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd.cc +++ b/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd.cc @@ -1194,15 +1194,20 @@ absl::Status CublasLtCmd::Initialize(const Thunk::InitializeParams& params, if (!params.stream->parent()->AsBlas()) { return absl::InternalError("Failed to initialize BLAS support for GemmCmd"); } - TF_ASSIGN_OR_RETURN(plan_, GetMatmulPlan(params.stream)); - TF_ASSIGN_OR_RETURN(algorithm_, - GetMatmulAlgorithm(plan_, workspace_buffer_.size())); + // Populate plan and algorithm cache; + TF_ASSIGN_OR_RETURN(auto plan, GetMatmulPlan(params.stream)); + TF_RETURN_IF_ERROR( + GetMatmulAlgorithm(plan, workspace_buffer_.size()).status()); return absl::OkStatus(); } absl::Status CublasLtCmd::Record(const Thunk::ExecuteParams& execute_params, const RecordParams& record_params, se::CommandBuffer* command_buffer) { + TF_ASSIGN_OR_RETURN(auto plan, GetMatmulPlan(execute_params.stream)); + TF_ASSIGN_OR_RETURN(auto algorithm, + GetMatmulAlgorithm(plan, workspace_buffer_.size())); + const BufferAllocations& allocs = *execute_params.buffer_allocations; se::DeviceMemoryBase bias, a_scale, b_scale, c_scale, d_scale, aux, d_amax; @@ -1247,12 +1252,12 @@ absl::Status CublasLtCmd::Record(const Thunk::ExecuteParams& execute_params, return AddTracedCommandBuffer( execute_params, record_params, command_buffer, [&](se::Stream* stream) { - return plan_->ExecuteOnStream( + return plan->ExecuteOnStream( stream, allocs.GetDeviceAddress(a_buffer_), 
allocs.GetDeviceAddress(b_buffer_), allocs.GetDeviceAddress(c_buffer_), allocs.GetDeviceAddress(d_buffer_), bias, aux, a_scale, b_scale, - c_scale, d_scale, d_amax, algorithm_, + c_scale, d_scale, d_amax, algorithm, allocs.GetDeviceAddress(workspace_buffer_)); }); } diff --git a/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd.h b/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd.h index 507b216e026de4..c9c26f1826b7f0 100644 --- a/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd.h +++ b/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd.h @@ -781,9 +781,6 @@ class CublasLtCmd : public TracedCommandBufferCmd { se::gpu::BlasLt::MatmulAlgorithm> matmul_algorithm_cache_; - se::gpu::BlasLt::MatmulPlan* plan_; - se::gpu::BlasLt::MatmulAlgorithm algorithm_; - const GemmConfig gemm_config_; const se::gpu::BlasLt::Epilogue epilogue_; const int64_t algorithm_idx_; diff --git a/third_party/xla/xla/service/gpu/runtime/command_buffer_thunk_test.cc b/third_party/xla/xla/service/gpu/runtime/command_buffer_thunk_test.cc index 4355cee257dd06..60c7a3e3ddfd85 100644 --- a/third_party/xla/xla/service/gpu/runtime/command_buffer_thunk_test.cc +++ b/third_party/xla/xla/service/gpu/runtime/command_buffer_thunk_test.cc @@ -19,6 +19,7 @@ limitations under the License. 
#include #include #include +#include // NOLINT #include #include @@ -594,7 +595,8 @@ TEST(CommandBufferThunkTest, CublasLtCmd) { se::StreamExecutor* executor = GpuExecutor(); - TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream()); + TF_ASSERT_OK_AND_ASSIGN(auto stream1, executor->CreateStream()); + TF_ASSERT_OK_AND_ASSIGN(auto stream2, executor->CreateStream()); // CublasLt formula: D = alpha*(A*B) + beta*(C), @@ -603,35 +605,6 @@ TEST(CommandBufferThunkTest, CublasLtCmd) { int64_t c_length = sizeof(float) * 2 * 3; int64_t d_length = sizeof(float) * 2 * 3; - // Prepare arguments: - // a = [1.0, 2.0, 3.0, 4.0 - // 5.0, 6.0, 7.0, 8.0] - // b = [1.0, 1.0, 1.0 - // 1.0, 1.0, 1.0 - // 1.0, 1.0, 1.0 - // 1.0, 1.0, 1.0] - // c = [1.0, 1.0, 1.0 - // 1.0, 1.0, 1.0] - - se::DeviceMemory a = executor->AllocateArray(2 * 4); - std::vector a_arr{1, 2, 3, 4, 5, 6, 7, 8}; - TF_ASSERT_OK(stream->Memcpy(&a, a_arr.data(), a_length)); - - se::DeviceMemory b = executor->AllocateArray(4 * 3); - std::vector b_arr(12, 1); - TF_ASSERT_OK(stream->Memcpy(&b, b_arr.data(), b_length)); - - se::DeviceMemory c = executor->AllocateArray(2 * 3); - std::vector c_arr(6, 1); - TF_ASSERT_OK(stream->Memcpy(&c, c_arr.data(), c_length)); - - se::DeviceMemory d = executor->AllocateArray(2 * 3); - TF_ASSERT_OK(stream->MemZero(&d, d_length)); - - se::DeviceMemory workspace = - executor->AllocateArray(1024 * 1024); - TF_ASSERT_OK(stream->MemZero(&workspace, 1024 * 1024)); - // Prepare buffer allocations for recording command buffer. BufferAllocation alloc_a(/*index=*/0, a_length, /*color=*/0); BufferAllocation alloc_b(/*index=*/1, b_length, /*color=*/0); @@ -673,57 +646,90 @@ TEST(CommandBufferThunkTest, CublasLtCmd) { // Construct a thunk with command sequence. 
CommandBufferThunk thunk(std::move(commands), Thunk::ThunkInfo()); - ServiceExecutableRunOptions run_options; - se::StreamExecutorMemoryAllocator allocator(executor); - BufferAllocations allocations({a, b, c, d, workspace}, 0, &allocator); + std::vector a_arr_1{1, 2, 3, 4, 5, 6, 7, 8}; + std::vector a_arr_2{2, 3, 4, 5, 6, 7, 8, 9}; + std::vector result_1{11, 11, 11, 27, 27, 27}; + std::vector result_2{15, 15, 15, 31, 31, 31}; - Thunk::ExecuteParams params = Thunk::ExecuteParams::Create( - run_options, allocations, stream.get(), stream.get(), nullptr, nullptr); + auto run_cublaslt_test = [&](std::unique_ptr& stream, + std::vector a_arr, + std::vector result) { + se::DeviceMemory a = executor->AllocateArray(2 * 4); + TF_ASSERT_OK(stream->Memcpy(&a, a_arr.data(), a_length)); - Thunk::ExecutableSource source = {/*text=*/"", /*binary=*/{}}; - TF_ASSERT_OK(thunk.Initialize( - {executor, source, &allocations, stream.get(), stream.get()})); + se::DeviceMemory b = executor->AllocateArray(4 * 3); + std::vector b_arr(12, 1); + TF_ASSERT_OK(stream->Memcpy(&b, b_arr.data(), b_length)); - // Execute command buffer thunk and verify that it executed a GEMM. - TF_ASSERT_OK(thunk.ExecuteOnStream(params)); - TF_ASSERT_OK(stream->BlockHostUntilDone()); + se::DeviceMemory c = executor->AllocateArray(2 * 3); + std::vector c_arr(6, 1); + TF_ASSERT_OK(stream->Memcpy(&c, c_arr.data(), c_length)); - // Copy `out` data back to host. - std::vector dst(6, 0); - TF_ASSERT_OK(stream->Memcpy(dst.data(), d, d_length)); + se::DeviceMemory d = executor->AllocateArray(2 * 3); + TF_ASSERT_OK(stream->MemZero(&d, d_length)); - ASSERT_EQ(dst, std::vector({11, 11, 11, 27, 27, 27})); + se::DeviceMemory workspace = + executor->AllocateArray(1024 * 1024); + TF_ASSERT_OK(stream->MemZero(&workspace, 1024 * 1024)); - // Prepare buffer allocation for updating command buffer. 
- se::DeviceMemory updated_d = executor->AllocateArray(2 * 3); - TF_ASSERT_OK(stream->MemZero(&updated_d, d_length)); + ServiceExecutableRunOptions run_options; + se::StreamExecutorMemoryAllocator allocator(executor); + BufferAllocations allocations({a, b, c, d, workspace}, 0, &allocator); - // Update buffer allocation to updated `d` buffer. - allocations = - BufferAllocations({a, b, c, updated_d, workspace}, 0, &allocator); + Thunk::ExecuteParams params = Thunk::ExecuteParams::Create( + run_options, allocations, stream.get(), stream.get(), nullptr, nullptr); - // Thunk execution should automatically update underlying command buffer. - TF_ASSERT_OK(thunk.ExecuteOnStream(params)); - TF_ASSERT_OK(stream->BlockHostUntilDone()); + Thunk::ExecutableSource source = {/*text=*/"", /*binary=*/{}}; + TF_ASSERT_OK(thunk.Initialize( + {executor, source, &allocations, stream.get(), stream.get()})); - // Copy `updated_out` data back to host. - std::fill(dst.begin(), dst.end(), 0); - TF_ASSERT_OK(stream->Memcpy(dst.data(), updated_d, d_length)); + // Execute command buffer thunk and verify that it executed a GEMM. + TF_ASSERT_OK(thunk.ExecuteOnStream(params)); + TF_ASSERT_OK(stream->BlockHostUntilDone()); - ASSERT_EQ(dst, std::vector({11, 11, 11, 27, 27, 27})); + // Copy `out` data back to host. + std::vector dst(6, 0); + TF_ASSERT_OK(stream->Memcpy(dst.data(), d, d_length)); - // Try to update the command buffer with the same buffers. - TF_ASSERT_OK(stream->MemZero(&updated_d, d_length)); + ASSERT_EQ(dst, result); - // Thunk execution should automatically update underlying command buffer. - TF_ASSERT_OK(thunk.ExecuteOnStream(params)); - TF_ASSERT_OK(stream->BlockHostUntilDone()); + // Prepare buffer allocation for updating command buffer. + se::DeviceMemory updated_d = executor->AllocateArray(2 * 3); + TF_ASSERT_OK(stream->MemZero(&updated_d, d_length)); - // Copy `updated_out` data back to host. 
- std::fill(dst.begin(), dst.end(), 0); - TF_ASSERT_OK(stream->Memcpy(dst.data(), updated_d, d_length)); + // Update buffer allocation to updated `d` buffer. + allocations = + BufferAllocations({a, b, c, updated_d, workspace}, 0, &allocator); + + // Thunk execution should automatically update underlying command + // buffer. + TF_ASSERT_OK(thunk.ExecuteOnStream(params)); + TF_ASSERT_OK(stream->BlockHostUntilDone()); + + // Copy `updated_out` data back to host. + std::fill(dst.begin(), dst.end(), 0); + TF_ASSERT_OK(stream->Memcpy(dst.data(), updated_d, d_length)); + + ASSERT_EQ(dst, result); - ASSERT_EQ(dst, std::vector({11, 11, 11, 27, 27, 27})); + // Try to update the command buffer with the same buffers. + TF_ASSERT_OK(stream->MemZero(&updated_d, d_length)); + + // Thunk execution should automatically update underlying command + // buffer. + TF_ASSERT_OK(thunk.ExecuteOnStream(params)); + TF_ASSERT_OK(stream->BlockHostUntilDone()); + + // Copy `updated_out` data back to host. + std::fill(dst.begin(), dst.end(), 0); + TF_ASSERT_OK(stream->Memcpy(dst.data(), updated_d, d_length)); + + ASSERT_EQ(dst, result); + }; + std::thread t1(run_cublaslt_test, std::ref(stream1), a_arr_1, result_1); + std::thread t2(run_cublaslt_test, std::ref(stream2), a_arr_2, result_2); + t1.join(); + t2.join(); } TEST(CommandBufferThunkTest, MultipleLaunchCmd) { From fb136ad345ddbedeec2ea2382dc44aa6ff0f5b2f Mon Sep 17 00:00:00 2001 From: Johannes Reifferscheid Date: Tue, 25 Jun 2024 05:32:56 -0700 Subject: [PATCH 222/256] Remove a dead function declaration. 
PiperOrigin-RevId: 646441411 --- third_party/xla/xla/service/gpu/model/indexing_analysis.h | 2 -- 1 file changed, 2 deletions(-) diff --git a/third_party/xla/xla/service/gpu/model/indexing_analysis.h b/third_party/xla/xla/service/gpu/model/indexing_analysis.h index 9bc879de491ce2..f41e204a728bac 100644 --- a/third_party/xla/xla/service/gpu/model/indexing_analysis.h +++ b/third_party/xla/xla/service/gpu/model/indexing_analysis.h @@ -64,8 +64,6 @@ struct HloInstructionIndexing { std::ostream& operator<<(std::ostream& out, const HloInstructionIndexing& instr_indexing); -std::string ToString(const mlir::AffineMap& affine_map); - // Computes indexing maps for all input operands necessary to compute an element // of the `output_id` instruction output. HloInstructionIndexing ComputeOutputToInputIndexing(const HloInstruction* instr, From b2ce016554616a1da7b70e4a427ea27bf755b2a6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Paruzel?= Date: Tue, 25 Jun 2024 05:59:55 -0700 Subject: [PATCH 223/256] [xla:cpu] Add LogicalId thunk encompassing PartitionId and ReplicaId PiperOrigin-RevId: 646447540 --- .../xla/xla/service/computation_placer.cc | 7 + .../xla/xla/service/computation_placer.h | 2 + third_party/xla/xla/service/cpu/BUILD | 2 +- third_party/xla/xla/service/cpu/runtime/BUILD | 23 ++-- .../service/cpu/runtime/logical_id_thunk.cc | 120 ++++++++++++++++++ ...{replica_id_thunk.h => logical_id_thunk.h} | 35 +++-- ...thunk_test.cc => logical_id_thunk_test.cc} | 64 ++++++++-- .../service/cpu/runtime/replica_id_thunk.cc | 80 ------------ .../xla/xla/service/cpu/thunk_emitter.cc | 12 +- .../xla/xla/service/cpu/thunk_emitter.h | 3 + 10 files changed, 237 insertions(+), 111 deletions(-) create mode 100644 third_party/xla/xla/service/cpu/runtime/logical_id_thunk.cc rename third_party/xla/xla/service/cpu/runtime/{replica_id_thunk.h => logical_id_thunk.h} (52%) rename third_party/xla/xla/service/cpu/runtime/{replica_id_thunk_test.cc => logical_id_thunk_test.cc} (50%) delete 
mode 100644 third_party/xla/xla/service/cpu/runtime/replica_id_thunk.cc diff --git a/third_party/xla/xla/service/computation_placer.cc b/third_party/xla/xla/service/computation_placer.cc index 58aedeb6319237..ee0cf2932a1e86 100644 --- a/third_party/xla/xla/service/computation_placer.cc +++ b/third_party/xla/xla/service/computation_placer.cc @@ -71,6 +71,13 @@ absl::StatusOr DeviceAssignment::ReplicaIdForDevice( return logical_id.replica_id; } +absl::StatusOr DeviceAssignment::PartitionIdForDevice( + GlobalDeviceId device_id) const { + TF_ASSIGN_OR_RETURN(const LogicalID logical_id, + LogicalIdForDevice(device_id)); + return logical_id.computation_id; +} + absl::flat_hash_map DeviceAssignment::GetDeviceToLogicalIdMap() const { absl::flat_hash_map diff --git a/third_party/xla/xla/service/computation_placer.h b/third_party/xla/xla/service/computation_placer.h index f85aa4153ac0d6..552fce2d84ba91 100644 --- a/third_party/xla/xla/service/computation_placer.h +++ b/third_party/xla/xla/service/computation_placer.h @@ -58,6 +58,8 @@ class DeviceAssignment : public Array2D { absl::StatusOr LogicalIdForDevice(GlobalDeviceId device_id) const; // Finds the replica ID for the given device. absl::StatusOr ReplicaIdForDevice(GlobalDeviceId device_id) const; + // Finds the partition ID for the given device. + absl::StatusOr PartitionIdForDevice(GlobalDeviceId device_id) const; // Returns a map from device ID to logical ID. Querying this map is much more // efficient than `LogicalIdForDevice` if queried repeatedly. 
absl::flat_hash_map GetDeviceToLogicalIdMap() diff --git a/third_party/xla/xla/service/cpu/BUILD b/third_party/xla/xla/service/cpu/BUILD index 7ff98b5e32a330..11e1bc0d7bd59e 100644 --- a/third_party/xla/xla/service/cpu/BUILD +++ b/third_party/xla/xla/service/cpu/BUILD @@ -830,9 +830,9 @@ cc_library( "//xla/service/cpu/runtime:fft_thunk", "//xla/service/cpu/runtime:infeed_thunk", "//xla/service/cpu/runtime:kernel_thunk", + "//xla/service/cpu/runtime:logical_id_thunk", "//xla/service/cpu/runtime:outfeed_thunk", "//xla/service/cpu/runtime:reduce_scatter_thunk", - "//xla/service/cpu/runtime:replica_id_thunk", "//xla/service/cpu/runtime:rng_state_thunk", "//xla/service/cpu/runtime:thunk", "//xla/service/cpu/runtime:while_thunk", diff --git a/third_party/xla/xla/service/cpu/runtime/BUILD b/third_party/xla/xla/service/cpu/runtime/BUILD index 360be19dfe940b..018063221843eb 100644 --- a/third_party/xla/xla/service/cpu/runtime/BUILD +++ b/third_party/xla/xla/service/cpu/runtime/BUILD @@ -615,23 +615,21 @@ xla_cc_test( ) cc_library( - name = "replica_id_thunk", - srcs = ["replica_id_thunk.cc"], - hdrs = ["replica_id_thunk.h"], + name = "logical_id_thunk", + srcs = ["logical_id_thunk.cc"], + hdrs = ["logical_id_thunk.h"], deps = [ ":thunk", "//xla:status_macros", - "//xla:util", "//xla/runtime:buffer_use", "//xla/service:buffer_assignment", - "//xla/service/cpu:cpu_runtime", + "//xla/service:computation_placer_hdr", + "//xla/service:global_device_id", "//xla/stream_executor", "//xla/tsl/concurrency:async_value", "@com_google_absl//absl/memory", - "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/profiler/lib:traceme", @@ -639,20 +637,19 @@ cc_library( ) xla_cc_test( - name = "replica_id_thunk_test", - srcs = ["replica_id_thunk_test.cc"], + name = "logical_id_thunk_test", + srcs = 
["logical_id_thunk_test.cc"], deps = [ ":buffer_allocations", - ":replica_id_thunk", + ":logical_id_thunk", ":thunk", "//xla:executable_run_options", - "//xla:shape_util", "//xla/service:buffer_assignment", - "//xla/service:executable", "//xla/service:maybe_owning_device_memory", "//xla/stream_executor", "//xla/tsl/concurrency:async_value", - "@local_tsl//tsl/lib/core:status_test_util", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", "@local_tsl//tsl/platform:test_main", diff --git a/third_party/xla/xla/service/cpu/runtime/logical_id_thunk.cc b/third_party/xla/xla/service/cpu/runtime/logical_id_thunk.cc new file mode 100644 index 00000000000000..9a047b54e72a5a --- /dev/null +++ b/third_party/xla/xla/service/cpu/runtime/logical_id_thunk.cc @@ -0,0 +1,120 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "xla/service/cpu/runtime/logical_id_thunk.h" + +#include +#include +#include +#include + +#include "absl/memory/memory.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "xla/runtime/buffer_use.h" +#include "xla/service/buffer_assignment.h" +#include "xla/service/computation_placer.h" +#include "xla/service/cpu/runtime/thunk.h" +#include "xla/service/global_device_id.h" +#include "xla/status_macros.h" +#include "xla/stream_executor/device_memory.h" +#include "xla/tsl/concurrency/async_value_ref.h" +#include "tsl/platform/logging.h" +#include "tsl/platform/statusor.h" +#include "tsl/profiler/lib/traceme.h" + +namespace xla::cpu { + +static Thunk::Kind ToThunkKind(LogicalIdKind logical_id_kind) { + switch (logical_id_kind) { + case LogicalIdKind::kPartitionId: + return Thunk::Kind::kPartitionId; + case LogicalIdKind::kReplicaId: + return Thunk::Kind::kReplicaId; + } +} + +template +absl::StatusOr>> +LogicalIdThunk::Create(Info info, + BufferAllocation::Slice logical_id_buffer) { + return absl::WrapUnique( + new LogicalIdThunk(std::move(info), logical_id_buffer)); +} + +template +LogicalIdThunk::LogicalIdThunk(Info info, + BufferAllocation::Slice logical_id_buffer) + : Thunk(ToThunkKind(type), info), logical_id_buffer_(logical_id_buffer) {} + +template +static constexpr auto ToString() { + if constexpr (type == LogicalIdKind::kPartitionId) { + return "Partition"; + } else if constexpr (type == LogicalIdKind::kReplicaId) { + return "Replica"; + } +} + +template +absl::StatusOr LogicalIdThunk::GetIdForDevice( + const DeviceAssignment* device_assignment, GlobalDeviceId device_id) const { + if constexpr (type == LogicalIdKind::kPartitionId) { + return device_assignment->PartitionIdForDevice(device_id); + } else if constexpr (type == LogicalIdKind::kReplicaId) { + return device_assignment->ReplicaIdForDevice(device_id); + } +} + +template 
+tsl::AsyncValueRef::ExecuteEvent> +LogicalIdThunk::Execute(const ExecuteParams& params) { + tsl::profiler::TraceMe trace([&] { return TraceMeEncode(); }); + + TF_ASSIGN_OR_RETURN( + se::DeviceMemoryBase logical_id_data, + params.buffer_allocations->GetDeviceAddress(logical_id_buffer_)); + + TF_RET_CHECK(logical_id_data.size() == sizeof(int32_t)) + << "Logical id buffer must be able to fit logical id value"; + + TF_RET_CHECK(params.collective_params) + << ToString() << " id requires collective params"; + + TF_ASSIGN_OR_RETURN( + int32_t logical_id, + GetIdForDevice(params.collective_params->device_assignment, + params.collective_params->global_device_id)); + + VLOG(3) << absl::StreamFormat("%s id: %d", ToString(), logical_id); + VLOG(3) << absl::StreamFormat(" logical_id: slice %s (%p)", + logical_id_buffer_.ToString(), + logical_id_data.opaque()); + + std::memcpy(logical_id_data.opaque(), &logical_id, sizeof(int32_t)); + return OkExecuteEvent(); +} + +template +using BufferUses = typename LogicalIdThunk::BufferUses; + +template +BufferUses LogicalIdThunk::buffer_uses() const { + return {BufferUse::Write(logical_id_buffer_)}; +} + +template class LogicalIdThunk; +template class LogicalIdThunk; + +} // namespace xla::cpu diff --git a/third_party/xla/xla/service/cpu/runtime/replica_id_thunk.h b/third_party/xla/xla/service/cpu/runtime/logical_id_thunk.h similarity index 52% rename from third_party/xla/xla/service/cpu/runtime/replica_id_thunk.h rename to third_party/xla/xla/service/cpu/runtime/logical_id_thunk.h index 22f61d372186c5..bd5bd84621e2a7 100644 --- a/third_party/xla/xla/service/cpu/runtime/replica_id_thunk.h +++ b/third_party/xla/xla/service/cpu/runtime/logical_id_thunk.h @@ -13,33 +13,52 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef XLA_SERVICE_CPU_RUNTIME_REPLICA_ID_THUNK_H_ -#define XLA_SERVICE_CPU_RUNTIME_REPLICA_ID_THUNK_H_ +#ifndef XLA_SERVICE_CPU_RUNTIME_LOGICAL_ID_THUNK_H_ +#define XLA_SERVICE_CPU_RUNTIME_LOGICAL_ID_THUNK_H_ +#include #include #include "absl/status/statusor.h" #include "xla/service/buffer_assignment.h" +#include "xla/service/computation_placer.h" #include "xla/service/cpu/runtime/thunk.h" +#include "xla/service/global_device_id.h" #include "xla/tsl/concurrency/async_value_ref.h" namespace xla::cpu { -class ReplicaIdThunk final : public Thunk { +enum class LogicalIdKind { + kPartitionId, + kReplicaId, +}; + +template +class LogicalIdThunk : public Thunk { public: - static absl::StatusOr> Create( - Info info, BufferAllocation::Slice replica_id_buffer); + static absl::StatusOr> Create( + Info info, BufferAllocation::Slice logical_id_buffer); tsl::AsyncValueRef Execute(const ExecuteParams& params) final; BufferUses buffer_uses() const final; private: - ReplicaIdThunk(Info info, BufferAllocation::Slice replica_id_buffer); + LogicalIdThunk(Info info, BufferAllocation::Slice logical_id_buffer); + + absl::StatusOr GetIdForDevice( + const DeviceAssignment* device_assignment, + GlobalDeviceId device_id) const; - BufferAllocation::Slice replica_id_buffer_; + BufferAllocation::Slice logical_id_buffer_; }; +class ReplicaIdThunk final : public LogicalIdThunk { +}; + +class PartitionIdThunk final + : public LogicalIdThunk {}; + } // namespace xla::cpu -#endif // XLA_SERVICE_CPU_RUNTIME_REPLICA_ID_THUNK_H_ +#endif // XLA_SERVICE_CPU_RUNTIME_LOGICAL_ID_THUNK_H_ diff --git a/third_party/xla/xla/service/cpu/runtime/replica_id_thunk_test.cc b/third_party/xla/xla/service/cpu/runtime/logical_id_thunk_test.cc similarity index 50% rename from third_party/xla/xla/service/cpu/runtime/replica_id_thunk_test.cc rename to third_party/xla/xla/service/cpu/runtime/logical_id_thunk_test.cc index 
c1345cef5e00a2..8aff05ca3d39e8 100644 --- a/third_party/xla/xla/service/cpu/runtime/replica_id_thunk_test.cc +++ b/third_party/xla/xla/service/cpu/runtime/logical_id_thunk_test.cc @@ -13,12 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/cpu/runtime/replica_id_thunk.h" +#include "xla/service/cpu/runtime/logical_id_thunk.h" #include #include #include +#include "absl/status/status.h" +#include "absl/status/statusor.h" #include "xla/executable_run_options.h" #include "xla/service/buffer_assignment.h" #include "xla/service/cpu/runtime/buffer_allocations.h" @@ -32,16 +34,23 @@ limitations under the License. namespace xla::cpu { namespace { -DeviceAssignment CreateDeviceAssignment(std::vector devices) { - DeviceAssignment device_assignment(/*replica_count=*/devices.size(), - /*computation_count=*/1); - for (int64_t i = 0; i < devices.size(); ++i) { - device_assignment(i, 0) = devices[i]; +absl::StatusOr CreateDeviceAssignment( + std::vector> devices) { + const auto computation_count = devices.size(); + if (devices.empty()) { + return absl::InternalError("Devices must not be empty."); + } + const auto replica_count = devices[0].size(); + DeviceAssignment device_assignment(replica_count, computation_count); + for (int64_t partition = 0; partition < computation_count; ++partition) { + for (int64_t replica = 0; replica < replica_count; ++replica) { + device_assignment(replica, partition) = devices[partition][replica]; + } } return device_assignment; } -TEST(ReplicaIdThunkTest, GetReplicaId) { +TEST(LogicalIdThunkTest, GetReplicaId) { std::vector dst(1, -1); std::vector buffers; @@ -55,7 +64,8 @@ TEST(ReplicaIdThunkTest, GetReplicaId) { TF_ASSERT_OK_AND_ASSIGN(auto thunk, ReplicaIdThunk::Create({name}, id_slice)); BufferAllocations allocations(buffers); - DeviceAssignment device_assn = CreateDeviceAssignment({0, 
1}); + TF_ASSERT_OK_AND_ASSIGN(DeviceAssignment device_assn, + CreateDeviceAssignment({{0, 1}})); ExecutableRunOptions run_options; run_options.set_device_ordinal(0); @@ -75,5 +85,43 @@ TEST(ReplicaIdThunkTest, GetReplicaId) { EXPECT_EQ(dst[0], 0); } +TEST(LogicalIdThunkTest, GetPartitionId) { + std::vector dst(2, -1); + + std::vector buffers; + static constexpr auto kDataSize = 2 * sizeof(int32_t); + buffers.emplace_back(se::DeviceMemoryBase(dst.data(), kDataSize)); + + BufferAllocation alloc(/*index=*/0, /*size=*/kDataSize, /*color=*/0); + BufferAllocation::Slice id_slice(&alloc, /*offset=*/sizeof(int32_t), + /*size=*/sizeof(int32_t)); + + std::string name(Thunk::KindToString(Thunk::Kind::kPartitionId)); + TF_ASSERT_OK_AND_ASSIGN(auto thunk, + PartitionIdThunk::Create({name}, id_slice)); + + BufferAllocations allocations(buffers); + TF_ASSERT_OK_AND_ASSIGN(DeviceAssignment device_assn, + CreateDeviceAssignment({{0}, {1}})); + + ExecutableRunOptions run_options; + run_options.set_device_ordinal(0); + run_options.set_device_assignment(&device_assn); + + TF_ASSERT_OK_AND_ASSIGN(Thunk::CollectiveExecuteParams collective_params, + Thunk::CollectiveExecuteParams::Create(&run_options)); + + Thunk::ExecuteParams params; + params.buffer_allocations = &allocations; + params.collective_params = &collective_params; + + auto execute_event = thunk->Execute(params); + tsl::BlockUntilReady(execute_event); + ASSERT_FALSE(execute_event.IsError()); + + EXPECT_EQ(dst[0], -1); + EXPECT_EQ(dst[1], 0); +} + } // namespace } // namespace xla::cpu diff --git a/third_party/xla/xla/service/cpu/runtime/replica_id_thunk.cc b/third_party/xla/xla/service/cpu/runtime/replica_id_thunk.cc deleted file mode 100644 index 974e39418f148d..00000000000000 --- a/third_party/xla/xla/service/cpu/runtime/replica_id_thunk.cc +++ /dev/null @@ -1,80 +0,0 @@ -/* Copyright 2024 The OpenXLA Authors. 
- -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/cpu/runtime/replica_id_thunk.h" - -#include -#include -#include -#include - -#include "absl/memory/memory.h" -#include "absl/status/statusor.h" -#include "absl/strings/str_format.h" -#include "xla/runtime/buffer_use.h" -#include "xla/service/buffer_assignment.h" -#include "xla/service/cpu/runtime/thunk.h" -#include "xla/status_macros.h" -#include "xla/stream_executor/device_memory.h" -#include "xla/tsl/concurrency/async_value_ref.h" -#include "tsl/platform/logging.h" -#include "tsl/platform/statusor.h" -#include "tsl/profiler/lib/traceme.h" - -namespace xla::cpu { - -absl::StatusOr> ReplicaIdThunk::Create( - Info info, BufferAllocation::Slice replica_id_buffer) { - return absl::WrapUnique( - new ReplicaIdThunk(std::move(info), replica_id_buffer)); -} - -ReplicaIdThunk::ReplicaIdThunk(Info info, - BufferAllocation::Slice replica_id_buffer) - : Thunk(Kind::kReplicaId, info), replica_id_buffer_(replica_id_buffer) {} - -tsl::AsyncValueRef ReplicaIdThunk::Execute( - const ExecuteParams& params) { - tsl::profiler::TraceMe trace([&] { return TraceMeEncode(); }); - - TF_ASSIGN_OR_RETURN( - se::DeviceMemoryBase replica_id_data, - params.buffer_allocations->GetDeviceAddress(replica_id_buffer_)); - - TF_RET_CHECK(replica_id_data.size() == sizeof(int32_t)) - << "Replica id buffer must be able to fit replica id value"; - - 
TF_RET_CHECK(params.collective_params) - << "Replica id requires collective params"; - - TF_ASSIGN_OR_RETURN( - int32_t replica_id, - params.collective_params->device_assignment->ReplicaIdForDevice( - params.collective_params->global_device_id)); - - VLOG(3) << absl::StreamFormat("Replica id: %d", replica_id); - VLOG(3) << absl::StreamFormat(" replica_id: slice %s (%p)", - replica_id_buffer_.ToString(), - replica_id_data.opaque()); - - std::memcpy(replica_id_data.opaque(), &replica_id, sizeof(int32_t)); - return OkExecuteEvent(); -} - -ReplicaIdThunk::BufferUses ReplicaIdThunk::buffer_uses() const { - return {BufferUse::Write(replica_id_buffer_)}; -} - -} // namespace xla::cpu diff --git a/third_party/xla/xla/service/cpu/thunk_emitter.cc b/third_party/xla/xla/service/cpu/thunk_emitter.cc index e9c8807fc3f887..c0fa81beb7c9ca 100644 --- a/third_party/xla/xla/service/cpu/thunk_emitter.cc +++ b/third_party/xla/xla/service/cpu/thunk_emitter.cc @@ -52,9 +52,9 @@ limitations under the License. #include "xla/service/cpu/runtime/fft_thunk.h" #include "xla/service/cpu/runtime/infeed_thunk.h" #include "xla/service/cpu/runtime/kernel_thunk.h" +#include "xla/service/cpu/runtime/logical_id_thunk.h" #include "xla/service/cpu/runtime/outfeed_thunk.h" #include "xla/service/cpu/runtime/reduce_scatter_thunk.h" -#include "xla/service/cpu/runtime/replica_id_thunk.h" #include "xla/service/cpu/runtime/rng_state_thunk.h" #include "xla/service/cpu/runtime/thunk.h" #include "xla/service/cpu/runtime/while_thunk.h" @@ -216,6 +216,8 @@ absl::StatusOr ThunkEmitter::EmitHloInstruction( // a logical grid of communicating devices. 
case HloOpcode::kReplicaId: return EmitReplicaIdThunk(instruction); + case HloOpcode::kPartitionId: + return EmitPartitionIdThunk(instruction); case HloOpcode::kAllGather: return EmitAllGatherThunk(instruction); @@ -693,6 +695,14 @@ absl::StatusOr ThunkEmitter::EmitReplicaIdThunk( replica_id_buffer); } +absl::StatusOr ThunkEmitter::EmitPartitionIdThunk( + const HloInstruction* instruction) { + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice partition_id_buffer, + GetAllocationSlice(instruction)); + return ThunkSequence::Of(ThunkInfo(instruction), + partition_id_buffer); +} + absl::StatusOr ThunkEmitter::EmitFftThunk( const HloInstruction* instruction) { TF_RETURN_IF_ERROR(ElementTypesSameAndSupported( diff --git a/third_party/xla/xla/service/cpu/thunk_emitter.h b/third_party/xla/xla/service/cpu/thunk_emitter.h index f536beea0b6302..66c4bfd71d2b9e 100644 --- a/third_party/xla/xla/service/cpu/thunk_emitter.h +++ b/third_party/xla/xla/service/cpu/thunk_emitter.h @@ -110,6 +110,9 @@ class ThunkEmitter { absl::StatusOr EmitReplicaIdThunk( const HloInstruction* instruction); + absl::StatusOr EmitPartitionIdThunk( + const HloInstruction* instruction); + absl::StatusOr EmitAllGatherThunk( const HloInstruction* instruction); From 2bc3af3a12a0e79877d090ae0d4613338be5cccc Mon Sep 17 00:00:00 2001 From: Greg Olechwierowicz Date: Tue, 25 Jun 2024 07:27:19 -0700 Subject: [PATCH 224/256] [XLA:GPU][NFC] Section comments for GPU LHS components. 
PiperOrigin-RevId: 646470872 --- .../service/gpu/gpu_latency_hiding_scheduler.cc | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/third_party/xla/xla/service/gpu/gpu_latency_hiding_scheduler.cc b/third_party/xla/xla/service/gpu/gpu_latency_hiding_scheduler.cc index 1d5a3b99d49e43..4ecb194121022b 100644 --- a/third_party/xla/xla/service/gpu/gpu_latency_hiding_scheduler.cc +++ b/third_party/xla/xla/service/gpu/gpu_latency_hiding_scheduler.cc @@ -116,7 +116,9 @@ CanonicalAsyncOp GpuGetCanonicalAsyncOp(const HloInstruction& hlo) { } } -// GpuAsyncTrackerBase implementations begin +//===--------------------------------------------------------------------===// +// GpuAsyncTrackerBase +//===--------------------------------------------------------------------===// GpuAsyncTrackerBase::GpuAsyncTrackerBase(const SchedulerConfig& config, GetCanonicalAsyncOpFunc func) : AsyncTracker(config, func) {} @@ -160,9 +162,10 @@ void GpuAsyncTrackerBase::PostProcessScheduleGraph( } } } -// GpuAsyncTrackerBase implementations end -// GpuAsyncTracker implementations begin +//===--------------------------------------------------------------------===// +// GpuAsyncTracker +//===--------------------------------------------------------------------===// GpuAsyncTracker::GpuAsyncTracker(const SchedulerConfig& config) : GpuAsyncTrackerBase(config) {} @@ -308,9 +311,9 @@ int64_t GpuAsyncTracker::GetNumResourcesPerInstruction( return num_resources - (found ? 
1 : 0); } -// GpuAsyncTracker implementations end - -// GpuLatencyEstimator implementations begin +//===--------------------------------------------------------------------===// +// GpuLatencyEstimator +//===--------------------------------------------------------------------===// GpuLatencyEstimator::GpuLatencyEstimator(int64_t pointer_size, GetCanonicalAsyncOpFunc func) : ApproximateLatencyEstimator(func), pointer_size_(pointer_size) {} From c6716bcd3e68bd7f44587094402c2b9bcc33ba33 Mon Sep 17 00:00:00 2001 From: Alexander Pivovarov Date: Tue, 25 Jun 2024 07:42:23 -0700 Subject: [PATCH 225/256] PR #13450: Fold reduce_window + slice[-1] to reduce + reshape Imported from GitHub PR https://github.com/openxla/xla/pull/13450 In some cases, a framework can generate reduce_window + slice instead of a simple reduce. This PR extends [existing simplification](https://github.com/openxla/xla/blob/main/xla/service/algebraic_simplifier.cc#L7756) of `reduce_window` to `reduce` for the following patterns: ``` // Example 1 (one reduce dimension): r = s32[2,8] reduce-window(s32[2,8] p, c), window={size=1x8 pad=0_0x7_0} s = s32[2,1] slice(r), slice={[0:2], [7:8]} // Can be folded to: r = s32[2] reduce(s32[2,8] p, c), dimensions={1}, s = s32[2,1] reshape(r) ``` ``` // Example 2 (two reduce dimensions): p = s32[3,4,2] r = s32[3,4,2] reduce-window(p, c), window={size=1x4x2 pad=0_0x3_0x1_0} s = s32[3,1,1] slice(r), slice={[0:3], [3:4], [1:2]} // Can be folded to: r = s32[3] reduce(p, c), dimensions={1,2}, s = s32[3,1,1] reshape(r) ``` @DavidNorman @blakehechtman Could you review it?
Copybara import of the project: -- 131f1dad265d038732dff9218b793ec4b1ad9afa by Alexander Pivovarov : Fold reduce_window + slice[-1] to reduce + reshape Merging this change closes #13450 PiperOrigin-RevId: 646475259 --- .../xla/xla/service/algebraic_simplifier.cc | 82 +++++++++++++++++++ .../xla/service/algebraic_simplifier_test.cc | 60 ++++++++++++++ 2 files changed, 142 insertions(+) diff --git a/third_party/xla/xla/service/algebraic_simplifier.cc b/third_party/xla/xla/service/algebraic_simplifier.cc index 8ed51b652109b5..dd13ffebd9e852 100644 --- a/third_party/xla/xla/service/algebraic_simplifier.cc +++ b/third_party/xla/xla/service/algebraic_simplifier.cc @@ -6597,6 +6597,88 @@ absl::Status AlgebraicSimplifierVisitor::HandleSlice(HloInstruction* slice) { } } + if (HloInstruction * reduce_window; + options_.enable_window_reduce_to_reduce_replacement() && + hlo_instruction_utils::IsUnstridedSlice(slice) && + Match(slice, m::Slice(m::ReduceWindow(&reduce_window).WithOneUse()))) { + // A reduce_window with window pad + slice[:,-1] can be expressed as + // reduce + reshape if all dimensions either have a window size of one or + // the entire dimension. No stride or dilation are expected. reduce_window + // pad should be present to make output shape equal to input shape. + // slice_limit[dim] should be equal to reduce_window shape[dim]. + // slice_limit[dim] - slice_start[dim] should be equal to 1 for reduced dim + // + // The reshape is a bitcast since it adds one-sized dimensions. Often + // these ones are immediately removed as well with another reshape. The + // implementation of reduce tends to be slightly more efficient at + // reducing entire dimensions compared to reduce window. 
+ // + //// Example 1: + // r = s32[2,8] reduce-window(s32[2,8] p, c), window={size=1x8 pad=0_0x7_0} + // s = s32[2,1] slice(r), slice={[0:2], [7:8]} + //// Can be folded to: + // r = s32[2] reduce(s32[2,8] p, c), dimensions={1}, + // s = s32[2] reshape(r) + // + //// Example 2: + // p = s32[3,4,2] + // r = s32[3,4,2] reduce-window(p, c), window={size=1x4x2 pad=0_0x3_0x1_0} + // s = s32[3,1,1] slice(r), slice={[0:3], [3:4], [1:2]} + //// Can be folded to: + // r = s32[3] reduce(p, c), dimensions={1,2}, + // s = s32[3,1,1] reshape(r) + auto effective_reduce_dims = [&] { + auto& window = reduce_window->window(); + // reduce_window should have padding, but no Strides and Dilation + if (window_util::HasStride(window) || window_util::HasDilation(window) || + !window_util::HasPadding(window)) { + return DimensionVector{}; + } + auto rank = reduce_window->shape().dimensions_size(); + auto& slice_starts = slice->slice_starts(); + auto& slice_limits = slice->slice_limits(); + DimensionVector reduce_dims; + for (auto i = 0; i < rank; ++i) { + auto window_dim_size = window.dimensions(i).size(); + auto reduce_window_dim_size = reduce_window->shape().dimensions(i); + auto slice_dim_size = slice->shape().dimensions(i); + if (reduce_window_dim_size != slice_limits[i] || + window.dimensions(i).padding_low() != slice_starts[i] || + window.dimensions(i).padding_high() != 0) { + return DimensionVector{}; + } + if (window_dim_size == 1 && reduce_window_dim_size == slice_dim_size && + slice_starts[i] == 0) { + continue; + } + if (slice_dim_size == 1 && reduce_window_dim_size == window_dim_size && + slice_limits[i] - slice_starts[i] == 1) { + reduce_dims.push_back(i); + } else { + return DimensionVector{}; + } + } + return reduce_dims; + }(); + + // If a reduce window can be expressed as a reduce, do so and reshape the + // output. 
+ if (!effective_reduce_dims.empty()) { + Shape reduce_shape = ShapeUtil::DeleteDimensions(effective_reduce_dims, + reduce_window->shape()); + simplifier_->UpdateLayout(&reduce_shape); + HloInstruction* reduce = + slice->AddInstruction(HloInstruction::CreateReduce( + /*shape=*/reduce_shape, + /*operand=*/reduce_window->mutable_operand(0), + /*init_value=*/reduce_window->mutable_operand(1), + /*dimensions_to_reduce=*/effective_reduce_dims, + /*reduce_computation=*/reduce_window->to_apply())); + return ReplaceWithNewInstruction( + slice, HloInstruction::CreateReshape(slice->shape(), reduce)); + } + } + // Do not try to reorder slices and reshapes after layout assignment as it may // be invalid. if (!options_.is_layout_sensitive()) { diff --git a/third_party/xla/xla/service/algebraic_simplifier_test.cc b/third_party/xla/xla/service/algebraic_simplifier_test.cc index 2e2b81eb5ab3fa..00970a51546b1a 100644 --- a/third_party/xla/xla/service/algebraic_simplifier_test.cc +++ b/third_party/xla/xla/service/algebraic_simplifier_test.cc @@ -6721,6 +6721,66 @@ TEST_F(AlgebraicSimplifierTest, SliceOfConcatNonScalarInput) { EXPECT_EQ(root->slice_limits(0), 2); } +TEST_F(AlgebraicSimplifierTest, SliceOfReduceWindowOneReduceDim) { + const char* hlo = R"( + HloModule m + Add.1 { + p0 = s32[] parameter(0) + p1 = s32[] parameter(1) + ROOT r = s32[] add(p0, p1) + } + ENTRY test { + p0 = s32[2,8] parameter(0) + c0 = s32[] constant(0) + r = s32[2,8] reduce-window(s32[2,8] p0, s32[] c0), window={size=1x8 pad=0_0x7_0}, to_apply=Add.1 + ROOT s = s32[2,1] slice(s32[2,8] r), slice={[0:2], [7:8]} + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo)); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).value()); + auto root = m->entry_computation()->root_instruction(); + EXPECT_THAT( + root, + GmockMatch(m::Reshape(m::Reduce(m::Parameter(0), m::Constant()) + .WithShape(S32, {2}) + .WithPredicate([](const HloInstruction* instr) { + 
return instr->dimensions() == + std::vector({1}); + })) + .WithShape(S32, {2, 1}))); +} + +TEST_F(AlgebraicSimplifierTest, SliceOfReduceWindowTwoReduceDims) { + const char* hlo = R"( + HloModule m + Add.1 { + p0 = s32[] parameter(0) + p1 = s32[] parameter(1) + ROOT r = s32[] add(p0, p1) + } + ENTRY test { + p0 = s32[3,4,2] parameter(0) + c0 = s32[] constant(0) + r = s32[3,4,2] reduce-window(s32[3,4,2] p0, s32[] c0), window={size=1x4x2 pad=0_0x3_0x1_0}, to_apply=Add.1 + ROOT s = s32[3,1,1] slice(s32[3,4,2] r), slice={[0:3], [3:4], [1:2]} + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo)); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).value()); + auto root = m->entry_computation()->root_instruction(); + EXPECT_THAT( + root, + GmockMatch(m::Reshape(m::Reduce(m::Parameter(0), m::Constant()) + .WithShape(S32, {3}) + .WithPredicate([](const HloInstruction* instr) { + return instr->dimensions() == + std::vector({1, 2}); + })) + .WithShape(S32, {3, 1, 1}))); +} + TEST_F(AlgebraicSimplifierTest, ConcatToBroadcast) { const char* hlo_string = R"( HloModule module From f14adf79822cb8bb4f5670e2f94a8b1c0862ad5a Mon Sep 17 00:00:00 2001 From: Kyle Lucke Date: Tue, 25 Jun 2024 07:58:23 -0700 Subject: [PATCH 226/256] Call Stream::Launch instead of StreamExecutor::Launch everywhere. Future CLs will remove ::Launch from StreamExecutor, and place it on the correct stream. 
PiperOrigin-RevId: 646479446 --- .../kernels/cutlass_gemm_custom_kernel_benchmarks.cc | 4 ++-- .../gpu/kernels/cutlass_gemm_custom_kernel_test.cc | 8 ++++---- .../xla/service/gpu/kernels/topk_custom_kernel_test.cc | 8 ++++---- .../xla/xla/service/gpu/runtime/kernel_thunk.cc | 10 +++++----- .../xla/xla/service/gpu/stream_executor_util.cc | 9 ++++----- .../xla/stream_executor/gpu/gpu_command_buffer_test.cc | 2 +- .../xla/xla/stream_executor/host/host_kernel_test.cc | 6 ++---- third_party/xla/xla/stream_executor/stream.h | 8 ++++++++ third_party/xla/xla/stream_executor/stream_common.cc | 7 +++++++ third_party/xla/xla/stream_executor/stream_common.h | 3 +++ 10 files changed, 40 insertions(+), 25 deletions(-) diff --git a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel_benchmarks.cc b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel_benchmarks.cc index b843fac733154c..4175412335ee84 100644 --- a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel_benchmarks.cc +++ b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel_benchmarks.cc @@ -74,8 +74,8 @@ static void BM_RowMajorGemm(benchmark::State& state) { custom_kernel->shared_memory_bytes()); for (auto s : state) { - TF_CHECK_OK(executor->Launch(stream.get(), custom_kernel->thread_dims(), - custom_kernel->block_dims(), *gemm, args)); + TF_CHECK_OK(stream->Launch(custom_kernel->thread_dims(), + custom_kernel->block_dims(), *gemm, args)); TF_CHECK_OK(stream->BlockHostUntilDone()); } } diff --git a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel_test.cc b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel_test.cc index b566fceab8511c..9d94ed77dfb275 100644 --- a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel_test.cc +++ b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel_test.cc @@ -70,8 +70,8 @@ TEST(CutlassGemmKernelTest, SimpleGemm) { se::KernelArgsDeviceMemoryArray arr( 
std::vector({a, b, c}), custom_kernel->shared_memory_bytes()); - TF_ASSERT_OK(executor->Launch(stream.get(), custom_kernel->thread_dims(), - custom_kernel->block_dims(), *gemm, arr)); + TF_ASSERT_OK(stream->Launch(custom_kernel->thread_dims(), + custom_kernel->block_dims(), *gemm, arr)); // Copy `c` data back to host. std::vector dst(length, -1.0f); @@ -120,8 +120,8 @@ TEST(CutlassGemmKernelTest, LoadFromSharedLibrary) { se::KernelArgsDeviceMemoryArray arr( std::vector({a, b, c}), custom_kernel->shared_memory_bytes()); - TF_ASSERT_OK(executor->Launch(stream.get(), custom_kernel->thread_dims(), - custom_kernel->block_dims(), *gemm, arr)); + TF_ASSERT_OK(stream->Launch(custom_kernel->thread_dims(), + custom_kernel->block_dims(), *gemm, arr)); // Copy `c` data back to host. std::vector dst(length, -1.0f); diff --git a/third_party/xla/xla/service/gpu/kernels/topk_custom_kernel_test.cc b/third_party/xla/xla/service/gpu/kernels/topk_custom_kernel_test.cc index b083118d0a8d7c..0a8a4d9342b81d 100644 --- a/third_party/xla/xla/service/gpu/kernels/topk_custom_kernel_test.cc +++ b/third_party/xla/xla/service/gpu/kernels/topk_custom_kernel_test.cc @@ -120,8 +120,8 @@ TEST_P(TopKKernelTest, TopKFloat) { std::vector( {input_buffer, output_values, output_indices}), custom_kernel->shared_memory_bytes()); - TF_ASSERT_OK(executor->Launch(stream.get(), custom_kernel->thread_dims(), - custom_kernel->block_dims(), *kernel, arr)); + TF_ASSERT_OK(stream->Launch(custom_kernel->thread_dims(), + custom_kernel->block_dims(), *kernel, arr)); std::vector got(k); ASSERT_TRUE(stream->BlockHostUntilDone().ok()); @@ -175,8 +175,8 @@ TEST_P(TopKKernelTest, TopKPackedNegative) { std::vector( {input_buffer, output_values, output_indices}), custom_kernel->shared_memory_bytes()); - TF_ASSERT_OK(executor->Launch(stream.get(), custom_kernel->thread_dims(), - custom_kernel->block_dims(), *kernel, arr)); + TF_ASSERT_OK(stream->Launch(custom_kernel->thread_dims(), + custom_kernel->block_dims(), *kernel, 
arr)); std::vector got(k); ASSERT_TRUE(stream->BlockHostUntilDone().ok()); diff --git a/third_party/xla/xla/service/gpu/runtime/kernel_thunk.cc b/third_party/xla/xla/service/gpu/runtime/kernel_thunk.cc index ad15ee2eb5a6f7..3ea5a010658af4 100644 --- a/third_party/xla/xla/service/gpu/runtime/kernel_thunk.cc +++ b/third_party/xla/xla/service/gpu/runtime/kernel_thunk.cc @@ -223,12 +223,12 @@ absl::Status CustomKernelThunk::ExecuteOnStream(const ExecuteParams& params) { custom_kernel_.shared_memory_bytes()); if (auto cluster = custom_kernel_.cluster_dims(); cluster.has_value()) { - return executor->Launch(params.stream, custom_kernel_.thread_dims(), - custom_kernel_.block_dims(), *cluster, *kernel, - args); + return params.stream->Launch(custom_kernel_.thread_dims(), + custom_kernel_.block_dims(), *cluster, *kernel, + args); } else { - return executor->Launch(params.stream, custom_kernel_.thread_dims(), - custom_kernel_.block_dims(), *kernel, args); + return params.stream->Launch(custom_kernel_.thread_dims(), + custom_kernel_.block_dims(), *kernel, args); } } diff --git a/third_party/xla/xla/service/gpu/stream_executor_util.cc b/third_party/xla/xla/service/gpu/stream_executor_util.cc index c278fbe8b4ff37..141308de52f834 100644 --- a/third_party/xla/xla/service/gpu/stream_executor_util.cc +++ b/third_party/xla/xla/service/gpu/stream_executor_util.cc @@ -393,8 +393,8 @@ absl::Status ExecuteKernelOnStream(const se::Kernel& kernel, std::unique_ptr kernel_args, se::PackKernelArgs(args, kernel.metadata())); - return stream->parent()->Launch(stream, dims.thread_counts_per_block(), - dims.block_counts(), kernel, *kernel_args); + return stream->Launch(dims.thread_counts_per_block(), dims.block_counts(), + kernel, *kernel_args); } absl::Status ExecuteKernelOnStream(const se::Kernel& kernel, @@ -406,9 +406,8 @@ absl::Status ExecuteKernelOnStream(const se::Kernel& kernel, std::unique_ptr kernel_args, se::PackKernelArgs(args, kernel.metadata())); - return 
stream->parent()->Launch(stream, dims.thread_counts_per_block(), - dims.block_counts(), cluster_dim, kernel, - *kernel_args); + return stream->Launch(dims.thread_counts_per_block(), dims.block_counts(), + cluster_dim, kernel, *kernel_args); } // Unimplemented for integers yet. diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer_test.cc b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer_test.cc index 14c5897118ead7..ef31559eefc5bd 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer_test.cc +++ b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer_test.cc @@ -206,7 +206,7 @@ TEST(CudaCommandBufferTest, TraceSingleKernel) { auto cmd_buffer = TraceCommandBufferFactory::Create( executor, [&](Stream* stream) { - return executor->Launch(stream, ThreadDim(), BlockDim(4), *add, args); + return stream->Launch(ThreadDim(), BlockDim(4), *add, args); }, primary); diff --git a/third_party/xla/xla/stream_executor/host/host_kernel_test.cc b/third_party/xla/xla/stream_executor/host/host_kernel_test.cc index de2ddabedef777..5a121bf17cb5b7 100644 --- a/third_party/xla/xla/stream_executor/host/host_kernel_test.cc +++ b/third_party/xla/xla/stream_executor/host/host_kernel_test.cc @@ -168,8 +168,7 @@ TEST(HostKernelTest, Addition3D) { KernelFactory::Create(executor.get(), spec)); const KernelArgsDeviceMemoryArray kargs{args, /*shared_memory_bytes=*/0}; - TF_ASSERT_OK(executor->Launch(stream.get(), ThreadDim(2, 2, 3), BlockDim(1), - *add, kargs)); + TF_ASSERT_OK(stream->Launch(ThreadDim(2, 2, 3), BlockDim(1), *add, kargs)); std::vector expected = {11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31, 33}; @@ -196,8 +195,7 @@ TEST(HostKernelTest, JitAddition) { KernelFactory::Create(executor.get(), spec)); const KernelArgsDeviceMemoryArray kargs{args, /*shared_memory_bytes=*/0}; - TF_ASSERT_OK( - executor->Launch(stream.get(), ThreadDim(4), BlockDim(1), *add, kargs)); + TF_ASSERT_OK(stream->Launch(ThreadDim(4), BlockDim(1), *add, kargs)); 
std::vector expected = {6, 8, 10, 12}; EXPECT_EQ(out, expected); diff --git a/third_party/xla/xla/stream_executor/stream.h b/third_party/xla/xla/stream_executor/stream.h index 9b74136f8e06c1..8ccb438c98cdfc 100644 --- a/third_party/xla/xla/stream_executor/stream.h +++ b/third_party/xla/xla/stream_executor/stream.h @@ -265,6 +265,14 @@ class Stream { virtual absl::Status Launch(const ThreadDim &thread_dims, const BlockDim &block_dims, const Kernel &k, const KernelArgs &args) = 0; + + // Launches a data parallel kernel with the given thread/block + // dimensionality and already-packed args/sizes to pass to the underlying + // platform driver. + virtual absl::Status Launch(const ThreadDim &thread_dims, + const BlockDim &block_dims, + const ClusterDim &cluster_dims, const Kernel &k, + const KernelArgs &args) = 0; }; template diff --git a/third_party/xla/xla/stream_executor/stream_common.cc b/third_party/xla/xla/stream_executor/stream_common.cc index e9d46afc5174f2..3ce495e35d2740 100644 --- a/third_party/xla/xla/stream_executor/stream_common.cc +++ b/third_party/xla/xla/stream_executor/stream_common.cc @@ -54,6 +54,13 @@ absl::Status StreamCommon::Launch(const ThreadDim &thread_dims, return parent_->Launch(this, thread_dims, block_dims, k, args); } +absl::Status StreamCommon::Launch(const ThreadDim &thread_dims, + const BlockDim &block_dims, + const ClusterDim &cluster_dims, + const Kernel &k, const KernelArgs &args) { + return parent_->Launch(this, thread_dims, block_dims, cluster_dims, k, args); +} + StreamCommon::PlatformSpecificHandle StreamCommon::platform_specific_handle() const { PlatformSpecificHandle handle; diff --git a/third_party/xla/xla/stream_executor/stream_common.h b/third_party/xla/xla/stream_executor/stream_common.h index e5d501ca3c1280..4f00641ae7d947 100644 --- a/third_party/xla/xla/stream_executor/stream_common.h +++ b/third_party/xla/xla/stream_executor/stream_common.h @@ -92,6 +92,9 @@ class StreamCommon : public Stream { } absl::Status 
Launch(const ThreadDim &thread_dims, const BlockDim &block_dims, const Kernel &k, const KernelArgs &args) override; + absl::Status Launch(const ThreadDim &thread_dims, const BlockDim &block_dims, + const ClusterDim &cluster_dims, const Kernel &k, + const KernelArgs &args) override; protected: bool InErrorState() const TF_LOCKS_EXCLUDED(mu_) { From 9873a6ae693fadce57a234700f14d485d93efe4e Mon Sep 17 00:00:00 2001 From: Henning Becker Date: Tue, 25 Jun 2024 08:03:01 -0700 Subject: [PATCH 227/256] Restore non-CUDA dependencies for :autotuner_util_test The dependency `:xla_internal_test_main` was recently accidentally moved into the `if_cuda_is_configured` macro which makes the test not compiled anymore in a non-CUDA build config. This change restores the dependency and also sorts the CUDA specific dependencies. PiperOrigin-RevId: 646480891 --- third_party/xla/xla/service/gpu/BUILD | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index 9755cf25cb81db..0383ca23925e73 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -5840,35 +5840,37 @@ xla_cc_test( name = "autotuner_util_test", srcs = if_cuda_is_configured(["autotuner_util_test.cc"]), deps = if_cuda_is_configured([ + # keep sorted ":autotuner_util", - "@com_google_googletest//:gtest", - "@com_google_absl//absl/base:log_severity", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/log:scoped_mock_log", - "@com_google_absl//absl/status", - "@com_google_absl//absl/strings", "//xla:autotune_results_proto_cc", "//xla:autotuning_proto_cc", "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_query", - "//xla/stream_executor:platform_manager", "//xla/stream_executor:platform", + "//xla/stream_executor:platform_manager", "//xla/stream_executor/host:host_platform", 
"//xla/tests:hlo_test_base", "//xla/tests:verified_hlo_module", - "//xla/tests:xla_internal_test_main", # Keep outside GPU guard + "@com_google_absl//absl/base:log_severity", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/log:scoped_mock_log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest", "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:path", "@local_tsl//tsl/platform:protobuf", - "@local_tsl//tsl/platform:status_matchers", "@local_tsl//tsl/platform:status", + "@local_tsl//tsl/platform:status_matchers", "@local_tsl//tsl/platform:statusor", - ]), + ]) + [ + "//xla/tests:xla_internal_test_main", # Keep outside GPU guard + ], ) cc_library( From 5f23cec9d40a8a832a591f871462cb0b551991a0 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 25 Jun 2024 08:11:36 -0700 Subject: [PATCH 228/256] Use compiler/mlir version of control edges for flatbuffer_import, and remove framework dep. 
PiperOrigin-RevId: 646483715 --- tensorflow/compiler/mlir/lite/BUILD | 3 ++- tensorflow/compiler/mlir/lite/flatbuffer_import.cc | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD index 6ab833fcac1394..2a48f298b114b5 100644 --- a/tensorflow/compiler/mlir/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/BUILD @@ -1152,6 +1152,7 @@ cc_library( ":offset_buffer", ":size_utils", ":tensorflow_lite", + "//tensorflow/compiler/mlir/lite:control_edges", "//tensorflow/compiler/mlir/lite/experimental/remat:metadata_util", "//tensorflow/compiler/mlir/lite/quantization/ir:QuantOps", "//tensorflow/compiler/mlir/lite/schema:schema_fbs_with_mutable", @@ -1168,7 +1169,7 @@ cc_library( "//tensorflow/core:protos_all_cc", "//tensorflow/core/platform:errors", "//tensorflow/core/platform:status", - "//tensorflow/lite:framework", + "//tensorflow/lite:model_builder", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/status", diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_import.cc b/tensorflow/compiler/mlir/lite/flatbuffer_import.cc index 58f0deee7c95e9..a03e988fde32fa 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_import.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_import.cc @@ -83,6 +83,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/schema/schema_utils.h" #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h" #include "tensorflow/compiler/mlir/lite/utils/const_tensor_utils.h" +#include "tensorflow/compiler/mlir/lite/utils/control_edges.h" #include "tensorflow/compiler/mlir/lite/utils/convert_type.h" #include "tensorflow/compiler/mlir/lite/utils/size_utils.h" #include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_traits.h" @@ -97,7 +98,6 @@ limitations under the License. 
#include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/status.h" -#include "tensorflow/lite/graph_info.h" #include "tensorflow/lite/model_builder.h" #include "tsl/platform/status.h" #include "tsl/platform/statusor.h" From d657a42f9c9e45e20218f6e458424a62427a451e Mon Sep 17 00:00:00 2001 From: Marissa Ikonomidis Date: Tue, 25 Jun 2024 09:04:38 -0700 Subject: [PATCH 229/256] Add a new batching padding policy to the batching op. Original author is piatov@. PiperOrigin-RevId: 646498929 --- .../tensorflow/passes/convert_tpu_model_to_cpu.td | 2 +- .../compiler/mlir/tensorflow/ir/tf_generated_ops.td | 7 ++++--- tensorflow/core/kernels/batch_kernels.cc | 1 + tensorflow/core/kernels/batch_kernels.h | 1 + tensorflow/core/kernels/batch_kernels_test.cc | 2 ++ tensorflow/core/ops/batch_ops.cc | 8 ++++++++ .../runtime_fallback/runtime/fallback_batch_kernel.cc | 1 + .../core/runtime_fallback/runtime/fallback_batch_kernel.h | 1 + .../runtime/runtime_fallback_batch_tf_opkernels.cc | 5 +++++ tensorflow/core/tfrt/mlrt/kernel/kernel_test.cc | 4 ++++ tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt | 2 +- tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt | 2 +- 12 files changed, 30 insertions(+), 6 deletions(-) diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/convert_tpu_model_to_cpu.td b/tensorflow/compiler/mlir/quantization/tensorflow/passes/convert_tpu_model_to_cpu.td index 03e7b18569d7a2..5be01f936250b0 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/convert_tpu_model_to_cpu.td +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/convert_tpu_model_to_cpu.td @@ -26,7 +26,7 @@ def GetBatchFunctionOpArgOperands: // because `TF_BatchFunctionOp` doesn't have the `CallOpInterface` trait. 
def ReplaceBatchFunctionOpToPartitionedCallOp : Pat< (TF_BatchFunctionOp:$src_op_res - $_, $_, $f, $_, $_, $_, $_, $_, $_, $_, $_, $_, $_, $_, $_, $_, $_), + $_, $_, $f, $_, $_, $_, $_, $_, $_, $_, $_, $_, $_, $_, $_, $_, $_, $_), (TF_PartitionedCallOp (GetBatchFunctionOpArgOperands $src_op_res), $f, diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td index d5a190db180e0a..3b531c4ed0590b 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td @@ -1137,6 +1137,7 @@ to be batched.}]>:$captured_tensors, DefaultValuedOptionalAttr:$low_priority_allowed_batch_sizes, DefaultValuedOptionalAttr:$low_priority_max_enqueued_batches, DefaultValuedOptionalAttr, "\"low_priority_padding_with_max_batch_size\"">:$mixed_priority_policy, + DefaultValuedOptionalAttr, "\"PAD_UP\"">:$batch_padding_policy, DefaultValuedOptionalAttr:$enable_large_batch_splitting ); @@ -15362,7 +15363,7 @@ e.g. Max(segment_ids) should be equal to `num_segments` - 1 for a 1-d segment_id With inconsistent num_segments, the op still runs. only difference is, the output takes the size of num_segments irrespective of size of segment_ids and data. for num_segments less than expected output size, the last elements are ignored -for num_segments more than the expected output size, last elements are assigned +for num_segments more than the expected output size, last elements are assigned smallest possible value for the specific numeric type. For example: @@ -15532,7 +15533,7 @@ e.g. Max(segment_ids) should be equal to `num_segments` - 1 for a 1-d segment_id With inconsistent num_segments, the op still runs. only difference is, the output takes the size of num_segments irrespective of size of segment_ids and data. 
for num_segments less than expected output size, the last elements are ignored -for num_segments more than the expected output size, last elements are assigned +for num_segments more than the expected output size, last elements are assigned the largest possible value for the specific numeric type. For example: @@ -15638,7 +15639,7 @@ The only difference with SegmentProd is the additional input `num_segments`. This helps in evaluating the output shape in compile time. `num_segments` should be consistent with segment_ids. e.g. Max(segment_ids) - 1 should be equal to `num_segments` for a 1-d segment_ids -With inconsistent num_segments, the op still runs. only difference is, +With inconsistent num_segments, the op still runs. only difference is, the output takes the size of num_segments irrespective of size of segment_ids and data. for num_segments less than expected output size, the last elements are ignored for num_segments more than the expected output size, last elements are assigned 1. 
diff --git a/tensorflow/core/kernels/batch_kernels.cc b/tensorflow/core/kernels/batch_kernels.cc index f8634974c0e3e1..8e1e97dc2565f9 100644 --- a/tensorflow/core/kernels/batch_kernels.cc +++ b/tensorflow/core/kernels/batch_kernels.cc @@ -305,6 +305,7 @@ BatchFunctionKernel::BatchFunctionKernel(OpKernelConstruction* c) &low_priority_max_enqueued_batches_)); OP_REQUIRES_OK(c, c->GetAttr("mixed_priority_policy", &mixed_priority_policy_)); + OP_REQUIRES_OK(c, c->GetAttr("batch_padding_policy", &batch_padding_policy_)); OP_REQUIRES_OK(c, c->GetAttr("f", &func_)); diff --git a/tensorflow/core/kernels/batch_kernels.h b/tensorflow/core/kernels/batch_kernels.h index f13bff06d303ca..11373af2048991 100644 --- a/tensorflow/core/kernels/batch_kernels.h +++ b/tensorflow/core/kernels/batch_kernels.h @@ -111,6 +111,7 @@ class BatchFunctionKernel : public AsyncOpKernel { int32 low_priority_max_enqueued_batches_; std::vector low_priority_allowed_batch_sizes_; std::string mixed_priority_policy_; + std::string batch_padding_policy_; NameAttrList func_; absl::optional fhandle_ TF_GUARDED_BY(mu_); bool enable_large_batch_splitting_ = false; diff --git a/tensorflow/core/kernels/batch_kernels_test.cc b/tensorflow/core/kernels/batch_kernels_test.cc index 4bdc1dbb3c8140..62666c099518fd 100644 --- a/tensorflow/core/kernels/batch_kernels_test.cc +++ b/tensorflow/core/kernels/batch_kernels_test.cc @@ -139,6 +139,7 @@ class BatchFunctionTestState : public SharedBatchFunctionTestState { .Attr("low_priority_max_enqueued_batches", enable_low_priority_queue ? 
2 : 0) .Attr("mixed_priority_policy", mixed_priority_policy) + .Attr("batch_padding_policy", "PAD_UP") .Attr("Tin", {DataType::DT_INT64}) .Input(inputs) .Attr("Tcaptured", std::vector{}) @@ -609,6 +610,7 @@ class BatchFunctionKernelParallelWarmupTestState .Attr("low_priority_batch_timeout_micros", 8000) .Attr("low_priority_allowed_batch_sizes", {32, 64}) .Attr("low_priority_max_enqueued_batches", 1000) + .Attr("batch_padding_policy", "PAD_UP") .Attr("Tin", {DataType::DT_INT64}) .Input(inputs) .Attr("Tcaptured", std::vector{}) diff --git a/tensorflow/core/ops/batch_ops.cc b/tensorflow/core/ops/batch_ops.cc index 6361fe72d850ac..99d45512374584 100644 --- a/tensorflow/core/ops/batch_ops.cc +++ b/tensorflow/core/ops/batch_ops.cc @@ -71,6 +71,14 @@ REGISTER_OP("BatchFunction") "{'low_priority_padding_with_max_batch_size', " "'low_priority_padding_with_next_allowed_batch_size', " "'priority_isolation'} = 'low_priority_padding_with_max_batch_size'") + // The policy that a batch scheduler is using when deciding what to do when, + // say, 18 requests need to be batched, but only 16 and 32 batch sizes are + // allowed. The following options are available. + // + // - PAD_UP: pad to size 32. 
+ .Attr( + "batch_padding_policy: " + "{'PAD_UP'} = 'PAD_UP'") .Attr("Tin: list(type)") .Attr("Tcaptured: list(type) >= 0") .Attr("Tout: list(type)") diff --git a/tensorflow/core/runtime_fallback/runtime/fallback_batch_kernel.cc b/tensorflow/core/runtime_fallback/runtime/fallback_batch_kernel.cc index 54f4fcbc863ce9..227b5b1a65650b 100644 --- a/tensorflow/core/runtime_fallback/runtime/fallback_batch_kernel.cc +++ b/tensorflow/core/runtime_fallback/runtime/fallback_batch_kernel.cc @@ -101,6 +101,7 @@ BatchFunctionFallbackKernelBase::BatchFunctionFallbackKernelBase( &low_priority_max_enqueued_batches_)); OP_REQUIRES_OK(c, c->GetAttr("mixed_priority_policy", &mixed_priority_policy_)); + OP_REQUIRES_OK(c, c->GetAttr("batch_padding_policy", &batch_padding_policy_)); if (shared_name_.empty()) { // If shared_name is not supplied, use name instead (prevent collisions by diff --git a/tensorflow/core/runtime_fallback/runtime/fallback_batch_kernel.h b/tensorflow/core/runtime_fallback/runtime/fallback_batch_kernel.h index 0d3f83e9246b9e..dae6eb35a43f59 100644 --- a/tensorflow/core/runtime_fallback/runtime/fallback_batch_kernel.h +++ b/tensorflow/core/runtime_fallback/runtime/fallback_batch_kernel.h @@ -83,6 +83,7 @@ class BatchFunctionFallbackKernelBase : public AsyncOpKernel { bool enable_large_batch_splitting_; bool has_attribute_enable_large_batch_splitting_; bool disable_padding_; + std::string batch_padding_policy_; // Parameters for adaptive batch scheduler only. 
// Note 'num_batch_threads_' above is shared by two implementations of batch diff --git a/tensorflow/core/runtime_fallback/runtime/runtime_fallback_batch_tf_opkernels.cc b/tensorflow/core/runtime_fallback/runtime/runtime_fallback_batch_tf_opkernels.cc index d5a27a01a0ddc4..e953c7088d5583 100644 --- a/tensorflow/core/runtime_fallback/runtime/runtime_fallback_batch_tf_opkernels.cc +++ b/tensorflow/core/runtime_fallback/runtime/runtime_fallback_batch_tf_opkernels.cc @@ -433,6 +433,11 @@ REGISTER_OP("_BatchFunctionFallback") "{'low_priority_padding_with_max_batch_size', " "'low_priority_padding_with_next_allowed_batch_size', " "'priority_isolation'} = 'low_priority_padding_with_max_batch_size'") + // See the description of the batch_padding_policy attribute of + // BatchFunction in core/ops/batch_ops.cc. + .Attr( + "batch_padding_policy: " + "{'PAD_UP'} = 'PAD_UP'") .Attr("Tin: list(type)") .Attr("Tcaptured: list(type) >= 0") .Attr("Tout: list(type)") diff --git a/tensorflow/core/tfrt/mlrt/kernel/kernel_test.cc b/tensorflow/core/tfrt/mlrt/kernel/kernel_test.cc index d6348880280517..f9966959ea517d 100644 --- a/tensorflow/core/tfrt/mlrt/kernel/kernel_test.cc +++ b/tensorflow/core/tfrt/mlrt/kernel/kernel_test.cc @@ -1291,6 +1291,10 @@ mlrt::bc::Buffer CreateExecutableForBatchFunctionOp() { key: "mixed_priority_policy" value { s: "low_priority_padding_with_max_batch_size" } } + attr { + key: "batch_padding_policy" + value { s: "PAD_UP" } + } attr { key: "container" value { s: "container" } diff --git a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt index 98c7718e9c479d..3830ed010ec0de 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt @@ -382,7 +382,7 @@ tf_module { } member_method { name: "BatchFunction" - argspec: "args=[\'in_tensors\', \'captured_tensors\', \'f\', \'num_batch_threads\', \'max_batch_size\', 
\'batch_timeout_micros\', \'Tout\', \'max_enqueued_batches\', \'allowed_batch_sizes\', \'container\', \'shared_name\', \'batching_queue\', \'low_priority_max_batch_size\', \'low_priority_batch_timeout_micros\', \'low_priority_allowed_batch_sizes\', \'low_priority_max_enqueued_batches\', \'mixed_priority_policy\', \'enable_large_batch_splitting\', \'name\'], varargs=None, keywords=None, defaults=[\'10\', \'[]\', \'\', \'\', \'\', \'0\', \'0\', \'[]\', \'0\', \'low_priority_padding_with_max_batch_size\', \'False\', \'None\'], " + argspec: "args=[\'in_tensors\', \'captured_tensors\', \'f\', \'num_batch_threads\', \'max_batch_size\', \'batch_timeout_micros\', \'Tout\', \'max_enqueued_batches\', \'allowed_batch_sizes\', \'container\', \'shared_name\', \'batching_queue\', \'low_priority_max_batch_size\', \'low_priority_batch_timeout_micros\', \'low_priority_allowed_batch_sizes\', \'low_priority_max_enqueued_batches\', \'mixed_priority_policy\', \'batch_padding_policy\', \'enable_large_batch_splitting\', \'name\'], varargs=None, keywords=None, defaults=[\'10\', \'[]\', \'\', \'\', \'\', \'0\', \'0\', \'[]\', \'0\', \'low_priority_padding_with_max_batch_size\', \'PAD_UP\', \'False\', \'None\'], " } member_method { name: "BatchIFFT" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt index 98c7718e9c479d..3830ed010ec0de 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt @@ -382,7 +382,7 @@ tf_module { } member_method { name: "BatchFunction" - argspec: "args=[\'in_tensors\', \'captured_tensors\', \'f\', \'num_batch_threads\', \'max_batch_size\', \'batch_timeout_micros\', \'Tout\', \'max_enqueued_batches\', \'allowed_batch_sizes\', \'container\', \'shared_name\', \'batching_queue\', \'low_priority_max_batch_size\', \'low_priority_batch_timeout_micros\', \'low_priority_allowed_batch_sizes\', 
\'low_priority_max_enqueued_batches\', \'mixed_priority_policy\', \'enable_large_batch_splitting\', \'name\'], varargs=None, keywords=None, defaults=[\'10\', \'[]\', \'\', \'\', \'\', \'0\', \'0\', \'[]\', \'0\', \'low_priority_padding_with_max_batch_size\', \'False\', \'None\'], " + argspec: "args=[\'in_tensors\', \'captured_tensors\', \'f\', \'num_batch_threads\', \'max_batch_size\', \'batch_timeout_micros\', \'Tout\', \'max_enqueued_batches\', \'allowed_batch_sizes\', \'container\', \'shared_name\', \'batching_queue\', \'low_priority_max_batch_size\', \'low_priority_batch_timeout_micros\', \'low_priority_allowed_batch_sizes\', \'low_priority_max_enqueued_batches\', \'mixed_priority_policy\', \'batch_padding_policy\', \'enable_large_batch_splitting\', \'name\'], varargs=None, keywords=None, defaults=[\'10\', \'[]\', \'\', \'\', \'\', \'0\', \'0\', \'[]\', \'0\', \'low_priority_padding_with_max_batch_size\', \'PAD_UP\', \'False\', \'None\'], " } member_method { name: "BatchIFFT" From fd9e0b76c3ad43d7cb44c3f4747a2c2cb2452a7f Mon Sep 17 00:00:00 2001 From: Benjamin Kramer Date: Tue, 25 Jun 2024 09:16:03 -0700 Subject: [PATCH 230/256] Integrate LLVM at llvm/llvm-project@4e0a0eae58f7 Updates LLVM usage to match [4e0a0eae58f7](https://github.com/llvm/llvm-project/commit/4e0a0eae58f7) PiperOrigin-RevId: 646502371 --- third_party/llvm/toolchains.patch | 6 +++--- third_party/llvm/workspace.bzl | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/third_party/llvm/toolchains.patch b/third_party/llvm/toolchains.patch index 1cbcfe31d3d072..2f8721373d0a81 100644 --- a/third_party/llvm/toolchains.patch +++ b/third_party/llvm/toolchains.patch @@ -1,9 +1,9 @@ diff --git a/utils/bazel/llvm-project-overlay/llvm/BUILD.bazel b/utils/bazel/llvm-project-overlay/llvm/BUILD.bazel -index c43ab727e285..7d848d2dffae 100644 +index 38970d9929b9..2690c97aa3e0 100644 --- a/utils/bazel/llvm-project-overlay/llvm/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/llvm/BUILD.bazel 
-@@ -30,6 +30,36 @@ exports_files([ - "utils/lit/lit.py", +@@ -34,6 +34,36 @@ exports_files([ + "utils/lldbDataFormatters.py", ]) +config_setting( diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl index 72267d070e6b9a..3354fb5736390d 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" - LLVM_COMMIT = "5cd0ba30f53d11835dbfd05ad4071d397387fb04" - LLVM_SHA256 = "cdf76f8704646b105ca671ab9fd4fbeb856d1af964f177e3740b8ce362631af4" + LLVM_COMMIT = "4e0a0eae58f7a6998866719f7eb970096a2a52e9" + LLVM_SHA256 = "ee3ed32065549d13c33ac52fd188c78d45364e740c6b54752feeb951ce15f617" tf_http_archive( name = name, From 1b5fc0570edac4d7b82435b3612e402f9594cf14 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 25 Jun 2024 09:18:58 -0700 Subject: [PATCH 231/256] Update ops-related pbtxt files. PiperOrigin-RevId: 646503190 --- .../compat/ops_history_v2/BatchFunction.pbtxt | 147 ++++++++++++++++++ tensorflow/core/ops/ops.pbtxt | 12 ++ 2 files changed, 159 insertions(+) diff --git a/tensorflow/core/ops/compat/ops_history_v2/BatchFunction.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/BatchFunction.pbtxt index cf5e5896d084ba..8fecdf6b1490e7 100644 --- a/tensorflow/core/ops/compat/ops_history_v2/BatchFunction.pbtxt +++ b/tensorflow/core/ops/compat/ops_history_v2/BatchFunction.pbtxt @@ -655,3 +655,150 @@ op { } is_distributed_communication: true } +op { + name: "BatchFunction" + input_arg { + name: "in_tensors" + type_list_attr: "Tin" + } + input_arg { + name: "captured_tensors" + type_list_attr: "Tcaptured" + } + output_arg { + name: "out_tensors" + type_list_attr: "Tout" + } + attr { + name: "f" + type: "func" + } + attr { + name: "num_batch_threads" + type: "int" + } + attr { + name: "max_batch_size" + type: "int" + } + attr { + name: "batch_timeout_micros" + type: "int" + } + attr { + name: "max_enqueued_batches" + 
type: "int" + default_value { + i: 10 + } + } + attr { + name: "allowed_batch_sizes" + type: "list(int)" + default_value { + list { + } + } + } + attr { + name: "container" + type: "string" + default_value { + s: "" + } + } + attr { + name: "shared_name" + type: "string" + default_value { + s: "" + } + } + attr { + name: "batching_queue" + type: "string" + default_value { + s: "" + } + } + attr { + name: "low_priority_max_batch_size" + type: "int" + default_value { + i: 0 + } + } + attr { + name: "low_priority_batch_timeout_micros" + type: "int" + default_value { + i: 0 + } + } + attr { + name: "low_priority_allowed_batch_sizes" + type: "list(int)" + default_value { + list { + } + } + } + attr { + name: "low_priority_max_enqueued_batches" + type: "int" + default_value { + i: 0 + } + } + attr { + name: "mixed_priority_policy" + type: "string" + default_value { + s: "low_priority_padding_with_max_batch_size" + } + allowed_values { + list { + s: "low_priority_padding_with_max_batch_size" + s: "low_priority_padding_with_next_allowed_batch_size" + s: "priority_isolation" + } + } + } + attr { + name: "batch_padding_policy" + type: "string" + default_value { + s: "PAD_UP" + } + allowed_values { + list { + s: "PAD_UP" + } + } + } + attr { + name: "Tin" + type: "list(type)" + has_minimum: true + minimum: 1 + } + attr { + name: "Tcaptured" + type: "list(type)" + has_minimum: true + } + attr { + name: "Tout" + type: "list(type)" + has_minimum: true + minimum: 1 + } + attr { + name: "enable_large_batch_splitting" + type: "bool" + default_value { + b: false + } + } + is_distributed_communication: true +} diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt index da60d1a2cffb2c..64378c608c10f1 100644 --- a/tensorflow/core/ops/ops.pbtxt +++ b/tensorflow/core/ops/ops.pbtxt @@ -4103,6 +4103,18 @@ op { } } } + attr { + name: "batch_padding_policy" + type: "string" + default_value { + s: "PAD_UP" + } + allowed_values { + list { + s: "PAD_UP" + } + } + } attr { 
name: "Tin" type: "list(type)" From 1081403feec133be9b9eda1e5c473f31de595c8e Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 25 Jun 2024 19:35:51 +0000 Subject: [PATCH 232/256] Merged commit includes the following changes: 646570256 by A. Unique TensorFlower: Put back type aliases for some 3p projects until they're migrated off of xla::Status. -- 646567067 by A. Unique TensorFlower: [xla:cpu] Optimize KernelThunk alignment checks -- 646562233 by A. Unique TensorFlower: Automated rollback of changelist 609005660. 646560125 by A. Unique TensorFlower: [XLA:CollectivePipeliner] Add more execution tests (using the HLOs in collective_pipeliner_test.cc). -- 646554714 by A. Unique TensorFlower: Instead of copybara rules, use `if_google` to remove extra proto deps Followup will do the same for TSL -- 646551061 by A. Unique TensorFlower: Remove `Array::Reshard` This CL removes the deprecated `Array::Reshard` API. All existing users have been manually migrated to use `Client::CopyArrays`. IFRT Proxy is updated such that the client no longer issues `Array::Reshard` and the server emulates the reshard behavior by using `Client::CopyArrays`. Since this does not actually change the wire format, we do not need to update the version number. Once the reshard API passes the compatibility window, we can remove its proto message and handler altogether. -- 646545951 by A. Unique TensorFlower: Add license header to `dependabot.yml` -- 646545541 by A. Unique TensorFlower: Remove force_synchronous attribute from ParallelMap op in map_parallelization optimizer. The code reuses the attributes/inputs of the original Map op but just changes it to a ParallelMap op. But the force_synchronous attribute is not supported in ParallelMap and causes log warnings. The issue was introduced in cl/642418430 -- 646534280 by A. Unique TensorFlower: Use absl::StatusOr instead of xla::StatusOr. -- 646517068 by A. Unique TensorFlower: Add more pattern to HloUnstacker pass + some refactoring. 
Added a support for handling slicing fusion pattern: fusion(stacked_operand, loop_iteration_var), calls=fusion_computation fusion_computation { p0 = parameter(0) p1 = parameter(1) slice = dynamic_slice(p0, p1, zero, ...) ROOT bitcast = bitcast(slice) } Add "xla_enable_hlo_unstacker" flag to the compiler. -- 646513305 by A. Unique TensorFlower: Remove unused deps. -- 646513101 by A. Unique TensorFlower: [xla:cpu] Add a fast path for executing thunks sequentially -- 646507520 by A. Unique TensorFlower: Added a fingerprint field to PjRtStreamExecutorLoadedExecutable to avoid recalculating fingerprints when FingerprintExecutable() is called. This change significantly reduces idle time before execution when the GPU load tracker enqueues an executable. -- 646505763 by A. Unique TensorFlower: Change visibility rules. -- 646505592 by A. Unique TensorFlower: [XLA:GPU] Parse block-level parameters from backend config when available. If block-level parameters are not available, fall back to the SoftMax heuristic. The original plan was to parse block-level parameters from the config and remove the heuristic, but it turned out that we don't support all "valid" tiling. With this change it will be easier to write tests and verify that we don't have problem, before we could remove the heuristic and fully migrate to fusion backend config. Also fix strides in ir_emitter_triton.cc. This was not a problem before, because SoftMax heuristic only produces tiles that are contiguous in memory. -- 646505352 by A. Unique TensorFlower: [xla:cpu] Add dynamic-update-slice fusion optimization to IrEmitter2 + enable select-and-scatter test that used to time out without DUS optimization -- 646504512 by A. Unique TensorFlower: PR #62472: Hash Pin docker images Imported from GitHub PR https://github.com/tensorflow/tensorflow/pull/62472 Also related to https://github.com/tensorflow/tensorflow/pull/62471, would you consider hash pin the docker images? 
The security benefit of doing so is that it mitigates the risk of typosquatting attacks since the images are public. If there is a need for them to be updated regularly, I can also submit a .github/dependabot file to update the docker images regularly (weekly or monthly for example). Besides, AFAIUC, the dockerfiles are used for build and tests, which lead to another benefit of hash pinning: reliability and stability. Let me know your thoughts about i. Thanks! Copybara import of the project: -- 8f4589fe583518d3099c98215e5e6bf3858fa24e by Joyce Brum : feat: create dependabot Signed-off-by: Joyce Brum Merging this change closes #62472 -- PiperOrigin-RevId: 646570256 --- .github/dependabot.yml | 44 ++ ci/devinfra/docker_windows/Dockerfile | 2 +- ci/official/containers/linux_arm64/Dockerfile | 4 +- tensorflow/compiler/jit/xla_launch_util.cc | 2 +- .../gpu/gpu_serving_device_selector.h | 3 - .../optimizers/data/map_parallelization.cc | 1 + tensorflow/lite/core/BUILD | 3 - tensorflow/lite/core/model_builder.h | 4 - tensorflow/tools/gcs_test/Dockerfile | 2 +- .../tools/tf_sig_build_dockerfiles/Dockerfile | 4 +- .../tsl/tsl/platform/cloud/gcs_file_system.cc | 6 +- .../platform/cloud/gcs_file_system_test.cc | 444 +++++++++--------- third_party/xla/xla/BUILD | 13 +- .../xla/xla/backends/interpreter/BUILD | 2 +- .../xla/pjrt/pjrt_stream_executor_client.cc | 21 +- .../xla/pjrt/pjrt_stream_executor_client.h | 5 +- third_party/xla/xla/python/BUILD | 5 +- third_party/xla/xla/python/ifrt/array.h | 23 - .../xla/python/ifrt/array_impl_test_lib.cc | 64 --- third_party/xla/xla/python/ifrt/mock.cc | 5 - third_party/xla/xla/python/ifrt/mock.h | 4 - .../xla/xla/python/ifrt_proxy/client/array.h | 4 +- .../xla/python/ifrt_proxy/client/client.cc | 11 +- .../xla/xla/python/ifrt_proxy/common/BUILD | 2 +- .../ifrt_proxy/common/ifrt_service.proto | 4 +- .../python/ifrt_proxy/server/ifrt_backend.cc | 17 +- .../ifrt_proxy/server/ifrt_backend_test.cc | 85 ++-- 
third_party/xla/xla/python/pjrt_ifrt/BUILD | 6 +- .../python/pjrt_ifrt/basic_string_array.cc | 8 +- .../xla/python/pjrt_ifrt/basic_string_array.h | 9 +- .../xla/xla/python/pjrt_ifrt/pjrt_array.cc | 7 +- .../xla/xla/python/pjrt_ifrt/pjrt_array.h | 8 +- .../xla/xla/python/pjrt_ifrt/pjrt_client.cc | 16 +- third_party/xla/xla/python/tools/BUILD | 2 +- third_party/xla/xla/service/BUILD | 13 +- third_party/xla/xla/service/cpu/BUILD | 1 + .../cpu/benchmarks/fusion_benchmark_test.cc | 36 ++ third_party/xla/xla/service/cpu/ir_emitter.h | 2 + .../xla/xla/service/cpu/ir_emitter2.cc | 19 +- .../service/cpu/onednn_convolution_rewriter.h | 3 +- third_party/xla/xla/service/cpu/runtime/BUILD | 3 + .../xla/service/cpu/runtime/kernel_thunk.cc | 20 +- .../xla/service/cpu/runtime/kernel_thunk.h | 11 +- .../service/cpu/runtime/kernel_thunk_test.cc | 8 + .../xla/service/cpu/runtime/thunk_executor.cc | 85 +++- .../xla/service/cpu/runtime/thunk_executor.h | 16 + .../cpu/runtime/thunk_executor_test.cc | 96 ++-- third_party/xla/xla/service/gpu/BUILD | 2 +- .../xla/xla/service/gpu/fusions/triton.cc | 31 +- .../xla/service/gpu/fusions/triton_test.cc | 62 ++- .../xla/xla/service/gpu/ir_emitter_triton.cc | 18 +- .../gpu/ir_emitter_triton_mem_utils_test.cc | 9 +- .../xla/service/gpu/ir_emitter_triton_test.cc | 54 +++ third_party/xla/xla/service/hlo_unstacker.cc | 316 +++++++++---- .../xla/xla/service/hlo_unstacker_test.cc | 59 ++- third_party/xla/xla/status.h | 8 + third_party/xla/xla/stream_executor/BUILD | 9 +- third_party/xla/xla/tests/BUILD | 4 + .../collective_pipeliner_execution_test.cc | 345 ++++++++++++++ third_party/xla/xla/tools/BUILD | 4 +- third_party/xla/xla/tsl/util/proto/BUILD | 4 +- 61 files changed, 1487 insertions(+), 591 deletions(-) create mode 100644 .github/dependabot.yml diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 00000000000000..b6d4f3ced20add --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,44 @@ +# Copyright 2024 The 
TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +version: 2 +updates: + - package-ecosystem: github-actions + directory: / + schedule: + interval: monthly + groups: + github-actions: + patterns: + - "*" + + - package-ecosystem: docker + directory: /ci/devinfra/docker_windows + schedule: + interval: monthly + + - package-ecosystem: docker + directory: /ci/official/containers/linux_arm64 + schedule: + interval: monthly + + - package-ecosystem: docker + directory: /tensorflow/tools/gcs_test + schedule: + interval: monthly + + - package-ecosystem: docker + directory: /tensorflow/tools/tf_sig_build_dockerfiles + schedule: + interval: monthly diff --git a/ci/devinfra/docker_windows/Dockerfile b/ci/devinfra/docker_windows/Dockerfile index 8d70ccf7611d0a..491887e198aa07 100644 --- a/ci/devinfra/docker_windows/Dockerfile +++ b/ci/devinfra/docker_windows/Dockerfile @@ -1,4 +1,4 @@ -FROM mcr.microsoft.com/dotnet/framework/sdk:4.8-windowsservercore-ltsc2019 +FROM mcr.microsoft.com/dotnet/framework/sdk:4.8-windowsservercore-ltsc2019@sha256:46e393cbb7c915c504a810639e35f40cb516f8e886e4cbcf8a3b49f86705a070 # Set default powershell policy for this script (ProgressPreference='SilentlyContinue' makes # downloads with Invoke-WebRequest not show the progress bar and is MUCH faster). 
diff --git a/ci/official/containers/linux_arm64/Dockerfile b/ci/official/containers/linux_arm64/Dockerfile index 5ddf6b02f46d60..9ba44d58d7004c 100644 --- a/ci/official/containers/linux_arm64/Dockerfile +++ b/ci/official/containers/linux_arm64/Dockerfile @@ -1,5 +1,5 @@ ################################################################################ -FROM ubuntu:20.04 as builder +FROM ubuntu:20.04@sha256:874aca52f79ae5f8258faff03e10ce99ae836f6e7d2df6ecd3da5c1cad3a912b as builder ################################################################################ # Install devtoolset build dependencies @@ -23,7 +23,7 @@ COPY apt.conf /etc/apt/ RUN /build_patchelf.sh ################################################################################ -FROM nvidia/cuda:12.3.1-devel-ubuntu20.04 as devel +FROM nvidia/cuda:12.3.1-devel-ubuntu20.04@sha256:befbdfddbb52727f9ce8d0c574cac0f631c606b1e6f0e523f3a0777fe2720c99 as devel ################################################################################ COPY --from=builder /dt10 /dt10 COPY --from=builder /patchelf/patchelf_0.14.3-1_arm64.deb /patchelf/patchelf_0.14.3-1_arm64.deb diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc index d146c8878450ee..1e8caacac02aa8 100644 --- a/tensorflow/compiler/jit/xla_launch_util.cc +++ b/tensorflow/compiler/jit/xla_launch_util.cc @@ -866,7 +866,7 @@ Status RunPjRtExecutable( xla::PjRtLocalDeviceId(pjrt_device_id))); gpu::GpuServingDeviceSelectorResource* device_selector_resource = nullptr; - if (device_type == DEVICE_GPU && gpu::kUseGpuServingDeviceSelector) { + if (device_type == DEVICE_GPU) { auto rm = ctx->resource_manager(); TF_RETURN_IF_ERROR(rm->LookupOrCreate< gpu::GpuServingDeviceSelectorResource>( diff --git a/tensorflow/core/common_runtime/gpu/gpu_serving_device_selector.h b/tensorflow/core/common_runtime/gpu/gpu_serving_device_selector.h index 814a0c9f5d0004..c6f352acb961f6 100644 --- 
a/tensorflow/core/common_runtime/gpu/gpu_serving_device_selector.h +++ b/tensorflow/core/common_runtime/gpu/gpu_serving_device_selector.h @@ -32,9 +32,6 @@ namespace gpu { class GpuServingDeviceSelector; const char kGpuServingDeviceSelectorResourceName[] = "gpu_serving_device_selector"; -// TODO(b/335729939): Disable GPU load tracker for performance regression -// investigation. Remove when fixed. -const bool kUseGpuServingDeviceSelector = false; class GpuServingDeviceSelectorResource : public ResourceBase { public: diff --git a/tensorflow/core/grappler/optimizers/data/map_parallelization.cc b/tensorflow/core/grappler/optimizers/data/map_parallelization.cc index 21c334201fedca..f1208076f6dc5d 100644 --- a/tensorflow/core/grappler/optimizers/data/map_parallelization.cc +++ b/tensorflow/core/grappler/optimizers/data/map_parallelization.cc @@ -47,6 +47,7 @@ NodeDef MakeParallelMap(const string& name, MutableGraphView* graph) { auto* num_parallel_calls = graph_utils::AddScalarConstNode( static_cast(data::model::kAutotune), graph); parallel_map.add_input(num_parallel_calls->name()); + parallel_map.mutable_attr()->erase("force_synchronous"); AddNodeAttr("deterministic", "true", ¶llel_map); return parallel_map; diff --git a/tensorflow/lite/core/BUILD b/tensorflow/lite/core/BUILD index 375e63b84d7565..0a7c63486bc9b1 100644 --- a/tensorflow/lite/core/BUILD +++ b/tensorflow/lite/core/BUILD @@ -294,13 +294,10 @@ cc_library( deps = [ ":macros", "//tensorflow/lite:allocation", - "//tensorflow/lite:mutable_op_resolver", "//tensorflow/lite:stderr_reporter", "//tensorflow/lite:string", "//tensorflow/lite/core/api:error_reporter", - "//tensorflow/lite/core/api:op_resolver", "//tensorflow/lite/core/api:verifier", - "//tensorflow/lite/core/c:common", "//tensorflow/lite/schema:schema_fbs", "@com_google_absl//absl/strings", "@flatbuffers", diff --git a/tensorflow/lite/core/model_builder.h b/tensorflow/lite/core/model_builder.h index cd94fbd5ad8bc9..3337d9358aac22 100644 --- 
a/tensorflow/lite/core/model_builder.h +++ b/tensorflow/lite/core/model_builder.h @@ -33,13 +33,9 @@ limitations under the License. #include "tensorflow/lite/allocation.h" #include "tensorflow/lite/core/api/error_reporter.h" -#include "tensorflow/lite/core/api/op_resolver.h" #include "tensorflow/lite/core/api/verifier.h" -#include "tensorflow/lite/core/c/common.h" -#include "tensorflow/lite/mutable_op_resolver.h" #include "tensorflow/lite/schema/schema_generated.h" #include "tensorflow/lite/stderr_reporter.h" -#include "tensorflow/lite/string_type.h" namespace tflite { diff --git a/tensorflow/tools/gcs_test/Dockerfile b/tensorflow/tools/gcs_test/Dockerfile index 69b554047bb8ea..e92a7c7db92a00 100644 --- a/tensorflow/tools/gcs_test/Dockerfile +++ b/tensorflow/tools/gcs_test/Dockerfile @@ -1,4 +1,4 @@ -FROM ubuntu:16.04 +FROM ubuntu:16.04@sha256:1f1a2d56de1d604801a9671f301190704c25d604a416f59e03c04f5c6ffee0d6 LABEL maintainer="Shanqing Cai " diff --git a/tensorflow/tools/tf_sig_build_dockerfiles/Dockerfile b/tensorflow/tools/tf_sig_build_dockerfiles/Dockerfile index 5ca83bcd730c18..0433b3a4127ef3 100644 --- a/tensorflow/tools/tf_sig_build_dockerfiles/Dockerfile +++ b/tensorflow/tools/tf_sig_build_dockerfiles/Dockerfile @@ -1,5 +1,5 @@ ################################################################################ -FROM ubuntu:22.04 as builder +FROM ubuntu:22.04@sha256:a6d2b38300ce017add71440577d5b0a90460d0e57fd7aec21dd0d1b0761bbfb2 as builder ################################################################################ # Install devtoolset build dependencies @@ -16,7 +16,7 @@ COPY builder.devtoolset/glibc2.17-inline.patch /glibc2.17-inline.patch RUN /build_devtoolset.sh devtoolset-9 /dt9 ################################################################################ -FROM nvidia/cuda:12.3.1-base-ubuntu22.04 as devel +FROM nvidia/cuda:12.3.1-base-ubuntu22.04@sha256:6a7febf317514458233b87819ce47d5441357dd7763e91800c35f6745f34bbbd as devel 
################################################################################ COPY --from=builder /dt9 /dt9 diff --git a/third_party/xla/third_party/tsl/tsl/platform/cloud/gcs_file_system.cc b/third_party/xla/third_party/tsl/tsl/platform/cloud/gcs_file_system.cc index 869dc993ee0a9d..ea65028a96cd22 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/cloud/gcs_file_system.cc +++ b/third_party/xla/third_party/tsl/tsl/platform/cloud/gcs_file_system.cc @@ -66,10 +66,10 @@ limitations under the License. namespace tsl { namespace { -constexpr char kGcsUriBase[] = "https://www.googleapis.com./storage/v1/"; +constexpr char kGcsUriBase[] = "https://www.googleapis.com/storage/v1/"; constexpr char kGcsUploadUriBase[] = - "https://www.googleapis.com./upload/storage/v1/"; -constexpr char kStorageHost[] = "storage.googleapis.com."; + "https://www.googleapis.com/upload/storage/v1/"; +constexpr char kStorageHost[] = "storage.googleapis.com"; constexpr char kBucketMetadataLocationKey[] = "location"; constexpr size_t kReadAppendableFileBufferSize = 1024 * 1024; // In bytes. 
constexpr int kGetChildrenDefaultPageSize = 1000; diff --git a/third_party/xla/third_party/tsl/tsl/platform/cloud/gcs_file_system_test.cc b/third_party/xla/third_party/tsl/tsl/platform/cloud/gcs_file_system_test.cc index e403599096e5f3..9221128276af9e 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/cloud/gcs_file_system_test.cc +++ b/third_party/xla/third_party/tsl/tsl/platform/cloud/gcs_file_system_test.cc @@ -62,13 +62,13 @@ class FakeZoneProvider : public ZoneProvider { TEST(GcsFileSystemTest, NewRandomAccessFile_NoBlockCache) { std::vector requests( {new FakeHttpRequest( - "Uri: https://storage.googleapis.com./bucket/random_access.txt\n" + "Uri: https://storage.googleapis.com/bucket/random_access.txt\n" "Auth Token: fake_token\n" "Range: 0-5\n" "Timeouts: 5 1 20\n", "012345"), new FakeHttpRequest( - "Uri: https://storage.googleapis.com./bucket/random_access.txt\n" + "Uri: https://storage.googleapis.com/bucket/random_access.txt\n" "Auth Token: fake_token\n" "Range: 6-11\n" "Timeouts: 5 1 20\n", @@ -108,13 +108,13 @@ TEST(GcsFileSystemTest, NewRandomAccessFile_NoBlockCache) { TEST(GcsFileSystemTest, NewRandomAccessFile_Buffered) { std::vector requests({ new FakeHttpRequest( - "Uri: https://storage.googleapis.com./bucket/random_access.txt\n" + "Uri: https://storage.googleapis.com/bucket/random_access.txt\n" "Auth Token: fake_token\n" "Range: 0-9\n" "Timeouts: 5 1 20\n", "0123456789"), new FakeHttpRequest( - "Uri: https://storage.googleapis.com./bucket/random_access.txt\n" + "Uri: https://storage.googleapis.com/bucket/random_access.txt\n" "Auth Token: fake_token\n" "Range: 10-19\n" "Timeouts: 5 1 20\n", @@ -155,14 +155,14 @@ TEST(GcsFileSystemTest, NewRandomAccessFile_Buffered) { TEST(GcsFileSystemTest, NewRandomAccessFile_Buffered_Errors) { std::vector requests({ new FakeHttpRequest( - "Uri: https://storage.googleapis.com./bucket/random_access.txt\n" + "Uri: https://storage.googleapis.com/bucket/random_access.txt\n" "Auth Token: fake_token\n" "Range: 
0-9\n" "Timeouts: 5 1 20\n", "Server Not", errors::Unavailable("important HTTP error 308"), nullptr, {}, 308), new FakeHttpRequest( - "Uri: https://storage.googleapis.com./bucket/random_access.txt\n" + "Uri: https://storage.googleapis.com/bucket/random_access.txt\n" "Auth Token: fake_token\n" "Range: 6-15\n" "Timeouts: 5 1 20\n", @@ -204,13 +204,13 @@ TEST(GcsFileSystemTest, NewRandomAccessFile_Buffered_Errors) { TEST(GcsFileSystemTest, NewRandomAccessFile_Buffered_ReadAtEOF) { std::vector requests( {new FakeHttpRequest( - "Uri: https://storage.googleapis.com./bucket/random_access.txt\n" + "Uri: https://storage.googleapis.com/bucket/random_access.txt\n" "Auth Token: fake_token\n" "Range: 0-9\n" "Timeouts: 5 1 20\n", "0123456789"), new FakeHttpRequest( - "Uri: https://storage.googleapis.com./bucket/random_access.txt\n" + "Uri: https://storage.googleapis.com/bucket/random_access.txt\n" "Auth Token: fake_token\n" "Range: 10-19\n" "Timeouts: 5 1 20\n", @@ -251,7 +251,7 @@ TEST(GcsFileSystemTest, NewRandomAccessFile_Buffered_CachedOutOfRange) { // In this test, there is only one backend request since we cache the file // size. std::vector requests({new FakeHttpRequest( - "Uri: https://storage.googleapis.com./bucket/random_access.txt\n" + "Uri: https://storage.googleapis.com/bucket/random_access.txt\n" "Auth Token: fake_token\n" "Range: 0-9\n" "Timeouts: 5 1 20\n", @@ -297,13 +297,13 @@ TEST(GcsFileSystemTest, NewRandomAccessFile_Buffered_CachedNotSequential) { // a backend request. 
std::vector requests( {new FakeHttpRequest( - "Uri: https://storage.googleapis.com./bucket/random_access.txt\n" + "Uri: https://storage.googleapis.com/bucket/random_access.txt\n" "Auth Token: fake_token\n" "Range: 1-10\n" "Timeouts: 5 1 20\n", "12345678"), new FakeHttpRequest( - "Uri: https://storage.googleapis.com./bucket/random_access.txt\n" + "Uri: https://storage.googleapis.com/bucket/random_access.txt\n" "Auth Token: fake_token\n" "Range: 0-9\n" "Timeouts: 5 1 20\n", @@ -339,13 +339,13 @@ TEST(GcsFileSystemTest, NewRandomAccessFile_Buffered_CachedNotSequential) { TEST(GcsFileSystemTest, NewRandomAccessFile_Buffered_Growing) { std::vector requests( {new FakeHttpRequest( - "Uri: https://storage.googleapis.com./bucket/random_access.txt\n" + "Uri: https://storage.googleapis.com/bucket/random_access.txt\n" "Auth Token: fake_token\n" "Range: 0-9\n" "Timeouts: 5 1 20\n", "012345678"), new FakeHttpRequest( - "Uri: https://storage.googleapis.com./bucket/random_access.txt\n" + "Uri: https://storage.googleapis.com/bucket/random_access.txt\n" "Auth Token: fake_token\n" "Range: 9-18\n" "Timeouts: 5 1 20\n", @@ -387,13 +387,13 @@ TEST(GcsFileSystemTest, NewRandomAccessFile_Buffered_ReadBackwards) { // Go backwards in the file. It should trigger a new read. 
std::vector requests( {new FakeHttpRequest( - "Uri: https://storage.googleapis.com./bucket/random_access.txt\n" + "Uri: https://storage.googleapis.com/bucket/random_access.txt\n" "Auth Token: fake_token\n" "Range: 5-14\n" "Timeouts: 5 1 20\n", "56789"), new FakeHttpRequest( - "Uri: https://storage.googleapis.com./bucket/random_access.txt\n" + "Uri: https://storage.googleapis.com/bucket/random_access.txt\n" "Auth Token: fake_token\n" "Range: 0-9\n" "Timeouts: 5 1 20\n", @@ -433,7 +433,7 @@ TEST(GcsFileSystemTest, NewRandomAccessFile_Buffered_ReadBackwards) { TEST(GcsFileSystemTest, NewRandomAccessFile_WithLocationConstraintInSameLocation) { std::vector requests({new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket\n" + "Uri: https://www.googleapis.com/storage/v1/b/bucket\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", R"( @@ -460,7 +460,7 @@ TEST(GcsFileSystemTest, TEST(GcsFileSystemTest, NewRandomAccessFile_WithLocationConstraintCaching) { std::vector requests( {new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket\n" + "Uri: https://www.googleapis.com/storage/v1/b/bucket\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", R"( @@ -468,7 +468,7 @@ TEST(GcsFileSystemTest, NewRandomAccessFile_WithLocationConstraintCaching) { "location":"US-EAST1" })"), new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/anotherbucket\n" + "Uri: https://www.googleapis.com/storage/v1/b/anotherbucket\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", R"( @@ -476,7 +476,7 @@ TEST(GcsFileSystemTest, NewRandomAccessFile_WithLocationConstraintCaching) { "location":"US-EAST1" })"), new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket\n" + "Uri: https://www.googleapis.com/storage/v1/b/bucket\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", R"( @@ -517,7 +517,7 @@ TEST(GcsFileSystemTest, NewRandomAccessFile_WithLocationConstraintCaching) { TEST(GcsFileSystemTest, 
NewRandomAccessFile_WithLocationConstraintInDifferentLocation) { std::vector requests({new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket\n" + "Uri: https://www.googleapis.com/storage/v1/b/bucket\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", R"( @@ -547,13 +547,13 @@ TEST(GcsFileSystemTest, TEST(GcsFileSystemTest, NewRandomAccessFile_NoBlockCache_DifferentN) { std::vector requests( {new FakeHttpRequest( - "Uri: https://storage.googleapis.com./bucket/random_access.txt\n" + "Uri: https://storage.googleapis.com/bucket/random_access.txt\n" "Auth Token: fake_token\n" "Range: 0-2\n" "Timeouts: 5 1 20\n", "012"), new FakeHttpRequest( - "Uri: https://storage.googleapis.com./bucket/random_access.txt\n" + "Uri: https://storage.googleapis.com/bucket/random_access.txt\n" "Auth Token: fake_token\n" "Range: 3-12\n" "Timeouts: 5 1 20\n", @@ -593,26 +593,26 @@ TEST(GcsFileSystemTest, NewRandomAccessFile_WithBlockCache) { // "0123456789abcde". std::vector requests( {new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" "random_access.txt?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", strings::StrCat("{\"size\": \"15\",\"generation\": \"1\"," "\"updated\": \"2016-04-29T23:15:24.896Z\"}")), new FakeHttpRequest( - "Uri: https://storage.googleapis.com./bucket/random_access.txt\n" + "Uri: https://storage.googleapis.com/bucket/random_access.txt\n" "Auth Token: fake_token\n" "Range: 0-8\n" "Timeouts: 5 1 20\n", "012345678"), new FakeHttpRequest( - "Uri: https://storage.googleapis.com./bucket/random_access.txt\n" + "Uri: https://storage.googleapis.com/bucket/random_access.txt\n" "Auth Token: fake_token\n" "Range: 9-17\n" "Timeouts: 5 1 20\n", "9abcde"), new FakeHttpRequest( - "Uri: https://storage.googleapis.com./bucket/random_access.txt\n" + "Uri: https://storage.googleapis.com/bucket/random_access.txt\n" "Auth Token: fake_token\n" 
"Range: 18-26\n" "Timeouts: 5 1 20\n", @@ -679,27 +679,27 @@ TEST(GcsFileSystemTest, NewRandomAccessFile_WithBlockCache_Flush) { // "0123456789abcde". std::vector requests( {new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" "random_access.txt?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", strings::StrCat("{\"size\": \"15\",\"generation\": \"1\"," "\"updated\": \"2016-04-29T23:15:24.896Z\"}")), new FakeHttpRequest( - "Uri: https://storage.googleapis.com./bucket/random_access.txt\n" + "Uri: https://storage.googleapis.com/bucket/random_access.txt\n" "Auth Token: fake_token\n" "Range: 0-8\n" "Timeouts: 5 1 20\n", "012345678"), new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" "random_access.txt?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", strings::StrCat("{\"size\": \"15\",\"generation\": \"1\"," "\"updated\": \"2016-04-29T23:15:24.896Z\"}")), new FakeHttpRequest( - "Uri: https://storage.googleapis.com./bucket/random_access.txt\n" + "Uri: https://storage.googleapis.com/bucket/random_access.txt\n" "Auth Token: fake_token\n" "Range: 0-8\n" "Timeouts: 5 1 20\n", @@ -738,24 +738,22 @@ TEST(GcsFileSystemTest, NewRandomAccessFile_WithBlockCache_MaxStaleness) { // "0123456789abcdef". 
std::vector requests( {new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" "object?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", strings::StrCat("{\"size\": \"16\",\"generation\": \"1\"," "\"updated\": \"2016-04-29T23:15:24.896Z\"}")), - new FakeHttpRequest( - "Uri: https://storage.googleapis.com./bucket/object\n" - "Auth Token: fake_token\n" - "Range: 0-7\n" - "Timeouts: 5 1 20\n", - "01234567"), - new FakeHttpRequest( - "Uri: https://storage.googleapis.com./bucket/object\n" - "Auth Token: fake_token\n" - "Range: 8-15\n" - "Timeouts: 5 1 20\n", - "89abcdef")}); + new FakeHttpRequest("Uri: https://storage.googleapis.com/bucket/object\n" + "Auth Token: fake_token\n" + "Range: 0-7\n" + "Timeouts: 5 1 20\n", + "01234567"), + new FakeHttpRequest("Uri: https://storage.googleapis.com/bucket/object\n" + "Auth Token: fake_token\n" + "Range: 8-15\n" + "Timeouts: 5 1 20\n", + "89abcdef")}); GcsFileSystem fs( std::unique_ptr(new FakeAuthProvider), std::unique_ptr( @@ -802,27 +800,27 @@ TEST(GcsFileSystemTest, NewRandomAccessFile_WithBlockCache_FileSignatureChanges) { std::vector requests( {new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" "random_access.txt?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", strings::StrCat("{\"size\": \"5\",\"generation\": \"1\"," "\"updated\": \"2016-04-29T23:15:24.896Z\"}")), new FakeHttpRequest( - "Uri: https://storage.googleapis.com./bucket/random_access.txt\n" + "Uri: https://storage.googleapis.com/bucket/random_access.txt\n" "Auth Token: fake_token\n" "Range: 0-8\n" "Timeouts: 5 1 20\n", "01234"), new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" 
"random_access.txt?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", strings::StrCat("{\"size\": \"5\",\"generation\": \"2\"," "\"updated\": \"2016-04-29T23:15:24.896Z\"}")), new FakeHttpRequest( - "Uri: https://storage.googleapis.com./bucket/random_access.txt\n" + "Uri: https://storage.googleapis.com/bucket/random_access.txt\n" "Auth Token: fake_token\n" "Range: 0-8\n" "Timeouts: 5 1 20\n", @@ -876,14 +874,14 @@ TEST(GcsFileSystemTest, NewRandomAccessFile_NoObjectName) { TEST(GcsFileSystemTest, NewRandomAccessFile_InconsistentRead) { std::vector requests( {new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" "random_access.txt?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", strings::StrCat("{\"size\": \"6\",\"generation\": \"1\"," "\"updated\": \"2016-04-29T23:15:24.896Z\"}")), new FakeHttpRequest( - "Uri: https://storage.googleapis.com./bucket/random_access.txt\n" + "Uri: https://storage.googleapis.com/bucket/random_access.txt\n" "Auth Token: fake_token\n" "Range: 0-5\n" "Timeouts: 5 1 20\n", @@ -919,20 +917,20 @@ TEST(GcsFileSystemTest, NewRandomAccessFile_InconsistentRead) { TEST(GcsFileSystemTest, NewWritableFile) { std::vector requests( {new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" "path%2Fwriteable?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", strings::StrCat("{\"size\": \"16\",\"generation\": \"1\"," "\"updated\": \"2016-04-29T23:15:24.896Z\"}")), new FakeHttpRequest( - "Uri: https://storage.googleapis.com./bucket/path%2Fwriteable\n" + "Uri: https://storage.googleapis.com/bucket/path%2Fwriteable\n" "Auth Token: fake_token\n" "Range: 0-7\n" "Timeouts: 5 1 20\n", "01234567"), new FakeHttpRequest( - "Uri: https://www.googleapis.com./upload/storage/v1/b/bucket/o?" 
+ "Uri: https://www.googleapis.com/upload/storage/v1/b/bucket/o?" "uploadType=resumable&name=path%2Fwriteable\n" "Auth Token: fake_token\n" "Header X-Upload-Content-Length: 17\n" @@ -946,14 +944,14 @@ TEST(GcsFileSystemTest, NewWritableFile) { "Put body: content1,content2\n", ""), new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" "path%2Fwriteable?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", strings::StrCat("{\"size\": \"33\",\"generation\": \"2\"," "\"updated\": \"2016-04-29T23:15:34.896Z\"}")), new FakeHttpRequest( - "Uri: https://storage.googleapis.com./bucket/path%2Fwriteable\n" + "Uri: https://storage.googleapis.com/bucket/path%2Fwriteable\n" "Auth Token: fake_token\n" "Range: 0-7\n" "Timeouts: 5 1 20\n", @@ -1000,7 +998,7 @@ TEST(GcsFileSystemTest, NewWritableFile) { TEST(GcsFileSystemTest, NewWritableFile_ResumeUploadSucceeds) { std::vector requests( {new FakeHttpRequest( - "Uri: https://www.googleapis.com./upload/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com/upload/storage/v1/b/bucket/o?" "uploadType=resumable&name=path%2Fwriteable.txt\n" "Auth Token: fake_token\n" "Header X-Upload-Content-Length: 17\n" @@ -1078,20 +1076,20 @@ TEST(GcsFileSystemTest, NewWritableFile_ResumeUploadSucceedsOnGetStatus) { // path. 
std::vector requests( {new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" "path%2Fwriteable?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", strings::StrCat("{\"size\": \"16\",\"generation\": \"1\"," "\"updated\": \"2016-04-29T23:15:24.896Z\"}")), new FakeHttpRequest( - "Uri: https://storage.googleapis.com./bucket/path%2Fwriteable\n" + "Uri: https://storage.googleapis.com/bucket/path%2Fwriteable\n" "Auth Token: fake_token\n" "Range: 0-7\n" "Timeouts: 5 1 20\n", "01234567"), new FakeHttpRequest( - "Uri: https://www.googleapis.com./upload/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com/upload/storage/v1/b/bucket/o?" "uploadType=resumable&name=path%2Fwriteable\n" "Auth Token: fake_token\n" "Header X-Upload-Content-Length: 17\n" @@ -1111,14 +1109,14 @@ TEST(GcsFileSystemTest, NewWritableFile_ResumeUploadSucceedsOnGetStatus) { "Put: yes\n", "", OkStatus(), nullptr, {}, 201), new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" "path%2Fwriteable?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", strings::StrCat("{\"size\": \"33\",\"generation\": \"2\"," "\"updated\": \"2016-04-29T23:19:24.896Z\"}")), new FakeHttpRequest( - "Uri: https://storage.googleapis.com./bucket/path%2Fwriteable\n" + "Uri: https://storage.googleapis.com/bucket/path%2Fwriteable\n" "Auth Token: fake_token\n" "Range: 0-7\n" "Timeouts: 5 1 20\n", @@ -1165,7 +1163,7 @@ TEST(GcsFileSystemTest, NewWritableFile_ResumeUploadSucceedsOnGetStatus) { TEST(GcsFileSystemTest, NewWritableFile_ResumeUploadAllAttemptsFail) { std::vector requests( {new FakeHttpRequest( - "Uri: https://www.googleapis.com./upload/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com/upload/storage/v1/b/bucket/o?" 
"uploadType=resumable&name=path%2Fwriteable.txt\n" "Auth Token: fake_token\n" "Header X-Upload-Content-Length: 17\n" @@ -1198,7 +1196,7 @@ TEST(GcsFileSystemTest, NewWritableFile_ResumeUploadAllAttemptsFail) { // These calls will be made in the Close() attempt from the destructor. // Letting the destructor succeed. requests.emplace_back(new FakeHttpRequest( - "Uri: https://www.googleapis.com./upload/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com/upload/storage/v1/b/bucket/o?" "uploadType=resumable&name=path%2Fwriteable.txt\n" "Auth Token: fake_token\n" "Header X-Upload-Content-Length: 17\n" @@ -1247,7 +1245,7 @@ TEST(GcsFileSystemTest, NewWritableFile_UploadReturns410) { std::vector requests( {new FakeHttpRequest( - "Uri: https://www.googleapis.com./upload/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com/upload/storage/v1/b/bucket/o?" "uploadType=resumable&name=path%2Fwriteable.txt\n" "Auth Token: fake_token\n" "Header X-Upload-Content-Length: 17\n" @@ -1264,7 +1262,7 @@ TEST(GcsFileSystemTest, NewWritableFile_UploadReturns410) { // These calls will be made in the Close() attempt from the destructor. // Letting the destructor succeed. new FakeHttpRequest( - "Uri: https://www.googleapis.com./upload/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com/upload/storage/v1/b/bucket/o?" 
"uploadType=resumable&name=path%2Fwriteable.txt\n" "Auth Token: fake_token\n" "Header X-Upload-Content-Length: 17\n" @@ -1336,26 +1334,26 @@ TEST(GcsFileSystemTest, NewWritableFile_NoObjectName) { TEST(GcsFileSystemTest, NewAppendableFile) { std::vector requests( {new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" "path%2Fappendable?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", strings::StrCat("{\"size\": \"8\",\"generation\": \"1\"," "\"updated\": \"2016-04-29T23:15:24.896Z\"}")), new FakeHttpRequest( - "Uri: https://storage.googleapis.com./bucket/path%2Fappendable\n" + "Uri: https://storage.googleapis.com/bucket/path%2Fappendable\n" "Auth Token: fake_token\n" "Range: 0-1048575\n" "Timeouts: 5 1 20\n", "content1,"), new FakeHttpRequest( - "Uri: https://storage.googleapis.com./bucket/path%2Fappendable\n" + "Uri: https://storage.googleapis.com/bucket/path%2Fappendable\n" "Auth Token: fake_token\n" "Range: 0-31\n" "Timeouts: 5 1 20\n", "content1,"), new FakeHttpRequest( - "Uri: https://www.googleapis.com./upload/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com/upload/storage/v1/b/bucket/o?" 
"uploadType=resumable&name=path%2Fappendable\n" "Auth Token: fake_token\n" "Header X-Upload-Content-Length: 17\n" @@ -1369,14 +1367,14 @@ TEST(GcsFileSystemTest, NewAppendableFile) { "Put body: content1,content2\n", ""), new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" "path%2Fappendable?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", strings::StrCat("{\"size\": \"8\",\"generation\": \"2\"," "\"updated\": \"2016-04-29T23:25:24.896Z\"}")), new FakeHttpRequest( - "Uri: https://storage.googleapis.com./bucket/path%2Fappendable\n" + "Uri: https://storage.googleapis.com/bucket/path%2Fappendable\n" "Auth Token: fake_token\n" "Range: 0-31\n" "Timeouts: 5 1 20\n", @@ -1437,13 +1435,13 @@ TEST(GcsFileSystemTest, NewAppendableFile_NoObjectName) { TEST(GcsFileSystemTest, NewAppendableFile_ObjectDoesNotExist) { std::vector requests( {new FakeHttpRequest( - "Uri: https://storage.googleapis.com./bucket/filename\n" + "Uri: https://storage.googleapis.com/bucket/filename\n" "Auth Token: fake_token\n" "Range: 0-1048575\n" "Timeouts: 5 1 20\n", "", errors::NotFound("404"), 404), new FakeHttpRequest( - "Uri: https://www.googleapis.com./upload/storage/v1/b/bucket/o" + "Uri: https://www.googleapis.com/upload/storage/v1/b/bucket/o" "?uploadType=resumable&name=filename\n" "Auth Token: fake_token\n" "Header X-Upload-Content-Length: 0\n" @@ -1469,7 +1467,7 @@ TEST(GcsFileSystemTest, NewReadOnlyMemoryRegionFromFile) { const string content = "file content"; std::vector requests( {new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" "path%2Frandom_access.txt?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", @@ -1477,7 +1475,7 @@ TEST(GcsFileSystemTest, NewReadOnlyMemoryRegionFromFile) { ", \"generation\": \"1\"", ", \"updated\": 
\"2016-04-29T23:15:24.896Z\"}")), new FakeHttpRequest( - strings::StrCat("Uri: https://storage.googleapis.com./bucket/" + strings::StrCat("Uri: https://storage.googleapis.com/bucket/" "path%2Frandom_access.txt\n" "Auth Token: fake_token\n" "Range: 0-", @@ -1522,7 +1520,7 @@ TEST(GcsFileSystemTest, NewReadOnlyMemoryRegionFromFile_NoObjectName) { TEST(GcsFileSystemTest, FileExists_YesAsObject) { std::vector requests({new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" "path%2Ffile1.txt?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", @@ -1545,13 +1543,13 @@ TEST(GcsFileSystemTest, FileExists_YesAsObject) { TEST(GcsFileSystemTest, FileExists_YesAsFolder) { std::vector requests( {new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" "path%2Fsubfolder?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", "", errors::NotFound("404"), 404), new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" 
"fields=items%2Fname%2CnextPageToken&prefix=path%2Fsubfolder%2F" "&maxResults=1\n" "Auth Token: fake_token\n" @@ -1575,12 +1573,12 @@ TEST(GcsFileSystemTest, FileExists_YesAsFolder) { TEST(GcsFileSystemTest, FileExists_YesAsBucket) { std::vector requests( {new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket1\n" + "Uri: https://www.googleapis.com/storage/v1/b/bucket1\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", "{\"size\": \"100\"}"), new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket1\n" + "Uri: https://www.googleapis.com/storage/v1/b/bucket1\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", "{\"size\": \"100\"}")}); @@ -1602,13 +1600,13 @@ TEST(GcsFileSystemTest, FileExists_YesAsBucket) { TEST(GcsFileSystemTest, FileExists_NotAsObjectOrFolder) { std::vector requests( {new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" "path%2Ffile1.txt?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", "", errors::NotFound("404"), 404), new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" 
"fields=items%2Fname%2CnextPageToken&prefix=path%2Ffile1.txt%2F" "&maxResults=1\n" "Auth Token: fake_token\n" @@ -1632,12 +1630,12 @@ TEST(GcsFileSystemTest, FileExists_NotAsObjectOrFolder) { TEST(GcsFileSystemTest, FileExists_NotAsBucket) { std::vector requests( {new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket2\n" + "Uri: https://www.googleapis.com/storage/v1/b/bucket2\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", "", errors::NotFound("404"), 404), new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket2\n" + "Uri: https://www.googleapis.com/storage/v1/b/bucket2\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", "", errors::NotFound("404"), 404)}); @@ -1658,20 +1656,20 @@ TEST(GcsFileSystemTest, FileExists_NotAsBucket) { TEST(GcsFileSystemTest, FileExists_StatCache) { std::vector requests( {new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" "path%2Ffile1.txt?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", strings::StrCat("{\"size\": \"1010\",\"generation\": \"1\"," "\"updated\": \"2016-04-29T23:15:24.896Z\"}")), new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" "path%2Fsubfolder%2F?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", "", errors::NotFound("404"), 404), new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" 
"fields=items%2Fname%2CnextPageToken&prefix=path%2Fsubfolder%2F" "&maxResults=1\n" "Auth Token: fake_token\n" @@ -1699,7 +1697,7 @@ TEST(GcsFileSystemTest, FileExists_StatCache) { TEST(GcsFileSystemTest, FileExists_DirectoryMark) { std::vector requests({new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" "dir%2F?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", @@ -1722,7 +1720,7 @@ TEST(GcsFileSystemTest, FileExists_DirectoryMark) { TEST(GcsFileSystemTest, GetChildren_NoItems) { std::vector requests({new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" "fields=items%2Fname%2Cprefixes%2CnextPageToken&delimiter=%2F&prefix=" "path%2F\n" "Auth Token: fake_token\n" @@ -1747,7 +1745,7 @@ TEST(GcsFileSystemTest, GetChildren_NoItems) { TEST(GcsFileSystemTest, GetChildren_ThreeFiles) { std::vector requests({new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" "fields=items%2Fname%2Cprefixes%2CnextPageToken&delimiter=%2F&prefix=" "path%2F\n" "Auth Token: fake_token\n" @@ -1776,7 +1774,7 @@ TEST(GcsFileSystemTest, GetChildren_ThreeFiles) { TEST(GcsFileSystemTest, GetChildren_SelfDirectoryMarker) { std::vector requests({new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" "fields=items%2Fname%2Cprefixes%2CnextPageToken&delimiter=%2F&prefix=" "path%2F\n" "Auth Token: fake_token\n" @@ -1804,7 +1802,7 @@ TEST(GcsFileSystemTest, GetChildren_SelfDirectoryMarker) { TEST(GcsFileSystemTest, GetChildren_ThreeFiles_NoSlash) { std::vector requests({new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" 
"fields=items%2Fname%2Cprefixes%2CnextPageToken&delimiter=%2F&prefix=" "path%2F\n" "Auth Token: fake_token\n" @@ -1833,7 +1831,7 @@ TEST(GcsFileSystemTest, GetChildren_ThreeFiles_NoSlash) { TEST(GcsFileSystemTest, GetChildren_Root) { std::vector requests({new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket-a-b-c/o?" + "Uri: https://www.googleapis.com/storage/v1/b/bucket-a-b-c/o?" "fields=items%2Fname%2Cprefixes%2CnextPageToken&delimiter=%2F\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", @@ -1857,7 +1855,7 @@ TEST(GcsFileSystemTest, GetChildren_Root) { TEST(GcsFileSystemTest, GetChildren_Empty) { std::vector requests({new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" "fields=items%2Fname%2Cprefixes%2CnextPageToken&delimiter=%2F&prefix=" "path%2F\n" "Auth Token: fake_token\n" @@ -1883,7 +1881,7 @@ TEST(GcsFileSystemTest, GetChildren_Empty) { TEST(GcsFileSystemTest, GetChildren_Pagination) { std::vector requests( {new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" "fields=items%2Fname%2Cprefixes%2CnextPageToken&delimiter=%2F&" "prefix=path%2F\n" "Auth Token: fake_token\n" @@ -1894,7 +1892,7 @@ TEST(GcsFileSystemTest, GetChildren_Pagination) { " { \"name\": \"path/file3.txt\" }]," "\"prefixes\": [\"path/subpath/\"]}"), new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" "fields=items%2Fname%2Cprefixes%2CnextPageToken&delimiter=%2F&" "prefix=path%2F" "&pageToken=ABCD==\n" @@ -1925,7 +1923,7 @@ TEST(GcsFileSystemTest, GetChildren_Pagination) { TEST(GcsFileSystemTest, GetMatchingPaths_NoWildcard) { std::vector requests({new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" 
"fields=items%2Fname%2CnextPageToken&prefix=path%2Fsubpath%2F\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", @@ -1951,7 +1949,7 @@ TEST(GcsFileSystemTest, GetMatchingPaths_NoWildcard) { TEST(GcsFileSystemTest, GetMatchingPaths_BucketAndWildcard) { std::vector requests({new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" "fields=items%2Fname%2CnextPageToken\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", @@ -1980,7 +1978,7 @@ TEST(GcsFileSystemTest, GetMatchingPaths_BucketAndWildcard) { TEST(GcsFileSystemTest, GetMatchingPaths_FolderAndWildcard_Matches) { std::vector requests({new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" "fields=items%2Fname%2CnextPageToken&prefix=path%2F\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", @@ -2008,7 +2006,7 @@ TEST(GcsFileSystemTest, GetMatchingPaths_FolderAndWildcard_Matches) { TEST(GcsFileSystemTest, GetMatchingPaths_SelfDirectoryMarker) { std::vector requests({new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" "fields=items%2Fname%2CnextPageToken&prefix=path%2F\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", @@ -2033,7 +2031,7 @@ TEST(GcsFileSystemTest, GetMatchingPaths_SelfDirectoryMarker) { TEST(GcsFileSystemTest, GetMatchingPaths_SlashInObjectName) { std::vector requests({new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" 
"fields=items%2Fname%2CnextPageToken&prefix=path%2F\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", @@ -2058,7 +2056,7 @@ TEST(GcsFileSystemTest, GetMatchingPaths_SlashInObjectName) { TEST(GcsFileSystemTest, GetMatchingPaths_SlashInObjectNameEscaped) { std::vector requests({new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" "fields=items%2Fname%2CnextPageToken&prefix=path%2F\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", @@ -2083,7 +2081,7 @@ TEST(GcsFileSystemTest, GetMatchingPaths_SlashInObjectNameEscaped) { TEST(GcsFileSystemTest, GetMatchingPaths_FolderAndWildcard_NoMatches) { std::vector requests({new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" "fields=items%2Fname%2CnextPageToken&prefix=path%2F\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", @@ -2129,14 +2127,14 @@ TEST(GcsFileSystemTest, GetMatchingPaths_OnlyWildcard) { TEST(GcsFileSystemTest, GetMatchingPaths_Cache) { std::vector requests( {new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" "fields=items%2Fname%2CnextPageToken&prefix=path%2Fsubpath%2F\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", "{\"items\": [ " " { \"name\": \"path/subpath/file2.txt\" }]}"), new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" "fields=items%2Fname%2CnextPageToken\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", @@ -2174,14 +2172,14 @@ TEST(GcsFileSystemTest, GetMatchingPaths_Cache) { TEST(GcsFileSystemTest, GetMatchingPaths_Cache_Flush) { std::vector requests( {new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" 
"fields=items%2Fname%2CnextPageToken&prefix=path%2Fsubpath%2F\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", "{\"items\": [ " " { \"name\": \"path/subpath/file2.txt\" }]}"), new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" "fields=items%2Fname%2CnextPageToken&prefix=path%2Fsubpath%2F\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", @@ -2220,33 +2218,33 @@ TEST(GcsFileSystemTest, GetMatchingPaths_Cache_Flush) { TEST(GcsFileSystemTest, DeleteFile) { std::vector requests( {new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" "path%2Ffile1.txt?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", strings::StrCat("{\"size\": \"8\",\"generation\": \"1\"," "\"updated\": \"2016-04-29T23:15:24.896Z\"}")), new FakeHttpRequest( - "Uri: https://storage.googleapis.com./bucket/path%2Ffile1.txt\n" + "Uri: https://storage.googleapis.com/bucket/path%2Ffile1.txt\n" "Auth Token: fake_token\n" "Range: 0-15\n" "Timeouts: 5 1 20\n", "01234567"), - new FakeHttpRequest("Uri: https://www.googleapis.com./storage/v1/b" + new FakeHttpRequest("Uri: https://www.googleapis.com/storage/v1/b" "/bucket/o/path%2Ffile1.txt\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n" "Delete: yes\n", ""), new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" "path%2Ffile1.txt?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", strings::StrCat("{\"size\": \"8\",\"generation\": \"2\"," "\"updated\": \"2016-04-29T23:19:24.896Z\"}")), new FakeHttpRequest( - "Uri: https://storage.googleapis.com./bucket/path%2Ffile1.txt\n" + "Uri: https://storage.googleapis.com/bucket/path%2Ffile1.txt\n" "Auth Token: fake_token\n" "Range: 0-15\n" "Timeouts: 5 1 20\n", @@ -2298,26 +2296,26 @@ 
TEST(GcsFileSystemTest, DeleteFile_NoObjectName) { TEST(GcsFileSystemTest, DeleteFile_StatCacheRemoved) { std::vector requests( {new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" "file.txt?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", strings::StrCat("{\"size\": \"1010\",\"generation\": \"1\"," "\"updated\": \"2016-04-29T23:15:24.896Z\"}")), - new FakeHttpRequest("Uri: https://www.googleapis.com./storage/v1/b" + new FakeHttpRequest("Uri: https://www.googleapis.com/storage/v1/b" "/bucket/o/file.txt\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n" "Delete: yes\n", ""), new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" "file.txt?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", "", errors::NotFound("404"), 404), new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" "fields=items%2Fname%2CnextPageToken&prefix=file.txt%2F" "&maxResults=1\n" "Auth Token: fake_token\n" @@ -2349,7 +2347,7 @@ TEST(GcsFileSystemTest, DeleteFile_StatCacheRemoved) { TEST(GcsFileSystemTest, DeleteDir_Empty) { std::vector requests({new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" "fields=items%2Fname%2CnextPageToken&prefix=path%2F&maxResults=2\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", @@ -2371,13 +2369,13 @@ TEST(GcsFileSystemTest, DeleteDir_Empty) { TEST(GcsFileSystemTest, DeleteDir_OnlyDirMarkerLeft) { std::vector requests( {new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" 
"fields=items%2Fname%2CnextPageToken&prefix=path%2F&maxResults=2\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", "{\"items\": [ " " { \"name\": \"path/\" }]}"), - new FakeHttpRequest("Uri: https://www.googleapis.com./storage/v1/b" + new FakeHttpRequest("Uri: https://www.googleapis.com/storage/v1/b" "/bucket/o/path%2F\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n" @@ -2399,7 +2397,7 @@ TEST(GcsFileSystemTest, DeleteDir_OnlyDirMarkerLeft) { TEST(GcsFileSystemTest, DeleteDir_BucketOnly) { std::vector requests({new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?fields=items%2F" + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?fields=items%2F" "name%2CnextPageToken&maxResults=2\nAuth Token: fake_token\n" "Timeouts: 5 1 10\n", "{}")}); @@ -2419,7 +2417,7 @@ TEST(GcsFileSystemTest, DeleteDir_BucketOnly) { TEST(GcsFileSystemTest, DeleteDir_NonEmpty) { std::vector requests({new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" "fields=items%2Fname%2CnextPageToken&prefix=path%2F&maxResults=2\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", @@ -2442,7 +2440,7 @@ TEST(GcsFileSystemTest, DeleteDir_NonEmpty) { TEST(GcsFileSystemTest, GetFileSize) { std::vector requests({new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" "file.txt?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", @@ -2486,7 +2484,7 @@ TEST(GcsFileSystemTest, RenameFile_Folder) { std::vector requests( {// Check if this is a folder or an object. new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" 
"fields=items%2Fname%2CnextPageToken&prefix=path1%2F" "&maxResults=1\n" "Auth Token: fake_token\n" @@ -2495,7 +2493,7 @@ TEST(GcsFileSystemTest, RenameFile_Folder) { " { \"name\": \"path1/subfolder/file1.txt\" }]}"), // Requesting the full list of files in the folder. new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" "fields=items%2Fname%2CnextPageToken&prefix=path1%2F\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", @@ -2505,7 +2503,7 @@ TEST(GcsFileSystemTest, RenameFile_Folder) { " { \"name\": \"path1/file2.txt\" }]}"), // Copying the directory marker. new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" "path1%2F/rewriteTo/b/bucket/o/path2%2F\n" "Auth Token: fake_token\n" "Post: yes\n" @@ -2513,7 +2511,7 @@ TEST(GcsFileSystemTest, RenameFile_Folder) { "{\"done\": true}"), // Deleting the original directory marker. new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" "path1%2F\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n" @@ -2521,7 +2519,7 @@ TEST(GcsFileSystemTest, RenameFile_Folder) { ""), // Copying the first file. new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" "path1%2Fsubfolder%2Ffile1.txt/rewriteTo/b/bucket/o/" "path2%2Fsubfolder%2Ffile1.txt\n" "Auth Token: fake_token\n" @@ -2530,7 +2528,7 @@ TEST(GcsFileSystemTest, RenameFile_Folder) { "{\"done\": true}"), // Deleting the first original file. 
new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" "path1%2Fsubfolder%2Ffile1.txt\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n" @@ -2538,7 +2536,7 @@ TEST(GcsFileSystemTest, RenameFile_Folder) { ""), // Copying the second file. new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" "path1%2Ffile2.txt/rewriteTo/b/bucket/o/path2%2Ffile2.txt\n" "Auth Token: fake_token\n" "Post: yes\n" @@ -2546,7 +2544,7 @@ TEST(GcsFileSystemTest, RenameFile_Folder) { "{\"done\": true}"), // Deleting the second original file. new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" "path1%2Ffile2.txt\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n" @@ -2570,34 +2568,34 @@ TEST(GcsFileSystemTest, RenameFile_Folder) { TEST(GcsFileSystemTest, RenameFile_Object) { std::vector requests( {new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" "path%2Fsrc.txt?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", strings::StrCat("{\"size\": \"8\",\"generation\": \"1\"," "\"updated\": \"2016-04-29T23:15:24.896Z\"}")), new FakeHttpRequest( - "Uri: https://storage.googleapis.com./bucket/path%2Fsrc.txt\n" + "Uri: https://storage.googleapis.com/bucket/path%2Fsrc.txt\n" "Auth Token: fake_token\n" "Range: 0-15\n" "Timeouts: 5 1 20\n", "01234567"), new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" "path%2Fdst.txt?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", strings::StrCat("{\"size\": \"8\",\"generation\": \"1\"," "\"updated\": \"2016-04-29T23:15:24.896Z\"}")), new FakeHttpRequest( - "Uri: 
https://storage.googleapis.com./bucket/path%2Fdst.txt\n" + "Uri: https://storage.googleapis.com/bucket/path%2Fdst.txt\n" "Auth Token: fake_token\n" "Range: 0-15\n" "Timeouts: 5 1 20\n", "76543210"), // IsDirectory is checking whether there are children objects. new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" "fields=items%2Fname%2CnextPageToken&prefix=path%2Fsrc.txt%2F" "&maxResults=1\n" "Auth Token: fake_token\n" @@ -2605,7 +2603,7 @@ TEST(GcsFileSystemTest, RenameFile_Object) { "{}"), // Copying to the new location. new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" "path%2Fsrc.txt/rewriteTo/b/bucket/o/path%2Fdst.txt\n" "Auth Token: fake_token\n" "Post: yes\n" @@ -2613,34 +2611,34 @@ TEST(GcsFileSystemTest, RenameFile_Object) { "{\"done\": true}"), // Deleting the original file. new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" "path%2Fsrc.txt\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n" "Delete: yes\n", ""), new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" "path%2Fsrc.txt?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", strings::StrCat("{\"size\": \"8\",\"generation\": \"2\"," "\"updated\": \"2016-04-29T23:15:24.896Z\"}")), new FakeHttpRequest( - "Uri: https://storage.googleapis.com./bucket/path%2Fsrc.txt\n" + "Uri: https://storage.googleapis.com/bucket/path%2Fsrc.txt\n" "Auth Token: fake_token\n" "Range: 0-15\n" "Timeouts: 5 1 20\n", "89abcdef"), new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" "path%2Fdst.txt?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" 
"Timeouts: 5 1 10\n", strings::StrCat("{\"size\": \"8\",\"generation\": \"2\"," "\"updated\": \"2016-04-29T23:15:24.896Z\"}")), new FakeHttpRequest( - "Uri: https://storage.googleapis.com./bucket/path%2Fdst.txt\n" + "Uri: https://storage.googleapis.com/bucket/path%2Fdst.txt\n" "Auth Token: fake_token\n" "Range: 0-15\n" "Timeouts: 5 1 20\n", @@ -2683,7 +2681,7 @@ TEST(GcsFileSystemTest, RenameFile_Object_FlushTargetStatCache) { std::vector requests( {// Stat the target file. new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" "path%2Fdst.txt?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", @@ -2691,7 +2689,7 @@ TEST(GcsFileSystemTest, RenameFile_Object_FlushTargetStatCache) { "\"updated\": \"2016-04-29T23:15:24.896Z\"}")), // IsDirectory is checking whether there are children objects. new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" "fields=items%2Fname%2CnextPageToken&prefix=path%2Fsrc.txt%2F" "&maxResults=1\n" "Auth Token: fake_token\n" @@ -2699,7 +2697,7 @@ TEST(GcsFileSystemTest, RenameFile_Object_FlushTargetStatCache) { "{}"), // IsDirectory is checking if the path exists as an object. new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" "path%2Fsrc.txt?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", @@ -2707,7 +2705,7 @@ TEST(GcsFileSystemTest, RenameFile_Object_FlushTargetStatCache) { "\"updated\": \"2016-04-29T23:15:24.896Z\"}")), // Copying to the new location. 
new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" "path%2Fsrc.txt/rewriteTo/b/bucket/o/path%2Fdst.txt\n" "Auth Token: fake_token\n" "Post: yes\n" @@ -2715,14 +2713,14 @@ TEST(GcsFileSystemTest, RenameFile_Object_FlushTargetStatCache) { "{\"done\": true}"), // Deleting the original file. new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" "path%2Fsrc.txt\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n" "Delete: yes\n", ""), new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" "path%2Fdst.txt?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", @@ -2759,7 +2757,7 @@ TEST(GcsFileSystemTest, RenameFile_Object_DeletionRetried) { std::vector requests( {// IsDirectory is checking whether there are children objects. new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" "fields=items%2Fname%2CnextPageToken&prefix=path%2Fsrc.txt%2F" "&maxResults=1\n" "Auth Token: fake_token\n" @@ -2767,7 +2765,7 @@ TEST(GcsFileSystemTest, RenameFile_Object_DeletionRetried) { "{}"), // IsDirectory is checking if the path exists as an object. new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" "path%2Fsrc.txt?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", @@ -2775,7 +2773,7 @@ TEST(GcsFileSystemTest, RenameFile_Object_DeletionRetried) { "\"updated\": \"2016-04-29T23:15:24.896Z\"}")), // Copying to the new location. 
new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" "path%2Fsrc.txt/rewriteTo/b/bucket/o/path%2Fdst.txt\n" "Auth Token: fake_token\n" "Post: yes\n" @@ -2783,7 +2781,7 @@ TEST(GcsFileSystemTest, RenameFile_Object_DeletionRetried) { "{\"done\": true}"), // Deleting the original file - the deletion returns a failure. new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" "path%2Fsrc.txt\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n" @@ -2791,7 +2789,7 @@ TEST(GcsFileSystemTest, RenameFile_Object_DeletionRetried) { "", errors::Unavailable("503"), 503), // Deleting the original file again - the deletion returns NOT_FOUND. new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" "path%2Fsrc.txt\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n" @@ -2817,7 +2815,7 @@ TEST(GcsFileSystemTest, RenameFile_Object_Incomplete) { std::vector requests( {// IsDirectory is checking whether there are children objects. new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" "fields=items%2Fname%2CnextPageToken&prefix=path%2Fsrc.txt%2F" "&maxResults=1\n" "Auth Token: fake_token\n" @@ -2825,7 +2823,7 @@ TEST(GcsFileSystemTest, RenameFile_Object_Incomplete) { "{}"), // IsDirectory is checking if the path exists as an object. new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" "path%2Fsrc.txt?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", @@ -2833,7 +2831,7 @@ TEST(GcsFileSystemTest, RenameFile_Object_Incomplete) { "\"updated\": \"2016-04-29T23:15:24.896Z\"}")), // Copying to the new location. 
new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" "path%2Fsrc.txt/rewriteTo/b/bucket/o/path%2Fdst.txt\n" "Auth Token: fake_token\n" "Post: yes\n" @@ -2856,7 +2854,7 @@ TEST(GcsFileSystemTest, RenameFile_Object_Incomplete) { TEST(GcsFileSystemTest, Stat_Object) { std::vector requests({new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" "file.txt?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", @@ -2883,13 +2881,13 @@ TEST(GcsFileSystemTest, Stat_Object) { TEST(GcsFileSystemTest, Stat_Folder) { std::vector requests( {new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" "subfolder?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", "", errors::NotFound("404"), 404), new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" "fields=items%2Fname%2CnextPageToken&prefix=subfolder%2F" "&maxResults=1\n" "Auth Token: fake_token\n" @@ -2917,13 +2915,13 @@ TEST(GcsFileSystemTest, Stat_Folder) { TEST(GcsFileSystemTest, Stat_ObjectOrFolderNotFound) { std::vector requests( {new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" "path?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", "", errors::NotFound("404"), 404), new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" 
"fields=items%2Fname%2CnextPageToken&prefix=path%2F" "&maxResults=1\n" "Auth Token: fake_token\n" @@ -2947,7 +2945,7 @@ TEST(GcsFileSystemTest, Stat_ObjectOrFolderNotFound) { TEST(GcsFileSystemTest, Stat_Bucket) { std::vector requests({new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket\n" + "Uri: https://www.googleapis.com/storage/v1/b/bucket\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", "{}")}); @@ -2971,7 +2969,7 @@ TEST(GcsFileSystemTest, Stat_Bucket) { TEST(GcsFileSystemTest, Stat_BucketNotFound) { std::vector requests({new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket\n" + "Uri: https://www.googleapis.com/storage/v1/b/bucket\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", "", errors::NotFound("404"), 404)}); @@ -2994,20 +2992,20 @@ TEST(GcsFileSystemTest, Stat_BucketNotFound) { TEST(GcsFileSystemTest, Stat_Cache) { std::vector requests( {new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" "file.txt?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", strings::StrCat("{\"size\": \"1010\",\"generation\": \"1\"," "\"updated\": \"2016-04-29T23:15:24.896Z\"}")), new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" "subfolder%2F?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", "", errors::NotFound("404"), 404), new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" 
"fields=items%2Fname%2CnextPageToken&prefix=subfolder%2F" "&maxResults=1\n" "Auth Token: fake_token\n" @@ -3043,14 +3041,14 @@ TEST(GcsFileSystemTest, Stat_Cache) { TEST(GcsFileSystemTest, Stat_Cache_Flush) { std::vector requests( {new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" "file.txt?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", strings::StrCat("{\"size\": \"1010\",\"generation\": \"1\"," "\"updated\": \"2016-04-29T23:15:24.896Z\"}")), new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" "file.txt?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", @@ -3087,7 +3085,7 @@ TEST(GcsFileSystemTest, Stat_Cache_Flush) { TEST(GcsFileSystemTest, Stat_FilenameEndingWithSlash) { std::vector requests({new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" "dir%2F?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", @@ -3113,14 +3111,14 @@ TEST(GcsFileSystemTest, Stat_FilenameEndingWithSlash) { TEST(GcsFileSystemTest, IsDirectory_NotFound) { std::vector requests( {new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" 
"fields=items%2Fname%2CnextPageToken&prefix=file.txt%2F" "&maxResults=1\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", "{}"), new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" "file.txt?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", @@ -3143,14 +3141,14 @@ TEST(GcsFileSystemTest, IsDirectory_NotFound) { TEST(GcsFileSystemTest, IsDirectory_NotDirectoryButObject) { std::vector requests( {new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" "fields=items%2Fname%2CnextPageToken&prefix=file.txt%2F" "&maxResults=1\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", "{}"), new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" "file.txt?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", @@ -3174,14 +3172,14 @@ TEST(GcsFileSystemTest, IsDirectory_NotDirectoryButObject) { TEST(GcsFileSystemTest, IsDirectory_Yes) { std::vector requests( {new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" "fields=items%2Fname%2CnextPageToken&prefix=subfolder%2F" "&maxResults=1\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", "{\"items\": [{\"name\": \"subfolder/\"}]}"), new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" 
"fields=items%2Fname%2CnextPageToken&prefix=subfolder%2F" "&maxResults=1\n" "Auth Token: fake_token\n" @@ -3205,12 +3203,12 @@ TEST(GcsFileSystemTest, IsDirectory_Yes) { TEST(GcsFileSystemTest, IsDirectory_Bucket) { std::vector requests( {new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket\n" + "Uri: https://www.googleapis.com/storage/v1/b/bucket\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", "{}"), new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket\n" + "Uri: https://www.googleapis.com/storage/v1/b/bucket\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", "{}")}); @@ -3231,7 +3229,7 @@ TEST(GcsFileSystemTest, IsDirectory_Bucket) { TEST(GcsFileSystemTest, IsDirectory_BucketNotFound) { std::vector requests({new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket\n" + "Uri: https://www.googleapis.com/storage/v1/b/bucket\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", "", errors::NotFound("404"), 404)}); @@ -3256,14 +3254,14 @@ TEST(GcsFileSystemTest, CreateDir_Folder) { { // File doesn't exist. new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" "subpath%2F?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", "{}"), // Simple upload. new FakeHttpRequest( - "Uri: https://www.googleapis.com./upload/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com/upload/storage/v1/b/bucket/o?" "uploadType=media&name=subpath%2F&ifGenerationMatch=0\n" "Auth Token: fake_token\n" "Post: yes\n" @@ -3271,7 +3269,7 @@ TEST(GcsFileSystemTest, CreateDir_Folder) { ""), // File exists. 
new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" "subpath%2F?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", @@ -3279,14 +3277,14 @@ TEST(GcsFileSystemTest, CreateDir_Folder) { "\"updated\": \"2016-04-29T23:15:24.896Z\"}")), // File doesn't exist again. new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" "subpath%2F?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", "{}"), // Simulate object uploaded in between. new FakeHttpRequest( - "Uri: https://www.googleapis.com./upload/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com/upload/storage/v1/b/bucket/o?" "uploadType=media&name=subpath%2F&ifGenerationMatch=0\n" "Auth Token: fake_token\n" "Post: yes\n" @@ -3318,12 +3316,12 @@ TEST(GcsFileSystemTest, CreateDir_Folder) { TEST(GcsFileSystemTest, CreateDir_Bucket) { std::vector requests( {new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket\n" + "Uri: https://www.googleapis.com/storage/v1/b/bucket\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", ""), new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket\n" + "Uri: https://www.googleapis.com/storage/v1/b/bucket\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", "")}); @@ -3346,7 +3344,7 @@ TEST(GcsFileSystemTest, DeleteRecursively_Ok) { std::vector requests( {// IsDirectory is checking whether there are children objects. new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" "fields=items%2Fname%2CnextPageToken&prefix=path%2F" "&maxResults=1\n" "Auth Token: fake_token\n" @@ -3355,7 +3353,7 @@ TEST(GcsFileSystemTest, DeleteRecursively_Ok) { " { \"name\": \"path/file1.txt\" }]}"), // GetChildren recursively. 
new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" "fields=items%2Fname%2CnextPageToken&prefix=path%2F\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", @@ -3365,35 +3363,35 @@ TEST(GcsFileSystemTest, DeleteRecursively_Ok) { " { \"name\": \"path/subpath/file2.txt\" }," " { \"name\": \"path/file3.txt\" }]}"), // Delete the current directory's marker. - new FakeHttpRequest("Uri: https://www.googleapis.com./storage/v1/b" + new FakeHttpRequest("Uri: https://www.googleapis.com/storage/v1/b" "/bucket/o/path%2F\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n" "Delete: yes\n", ""), // Delete the object - fails and will be retried. - new FakeHttpRequest("Uri: https://www.googleapis.com./storage/v1/b" + new FakeHttpRequest("Uri: https://www.googleapis.com/storage/v1/b" "/bucket/o/path%2Ffile1.txt\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n" "Delete: yes\n", "", errors::Unavailable("500"), 500), // Delete the object again. - new FakeHttpRequest("Uri: https://www.googleapis.com./storage/v1/b" + new FakeHttpRequest("Uri: https://www.googleapis.com/storage/v1/b" "/bucket/o/path%2Ffile1.txt\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n" "Delete: yes\n", ""), // Delete the object. - new FakeHttpRequest("Uri: https://www.googleapis.com./storage/v1/b" + new FakeHttpRequest("Uri: https://www.googleapis.com/storage/v1/b" "/bucket/o/path%2Fsubpath%2Ffile2.txt\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n" "Delete: yes\n", ""), // Delete the object. - new FakeHttpRequest("Uri: https://www.googleapis.com./storage/v1/b" + new FakeHttpRequest("Uri: https://www.googleapis.com/storage/v1/b" "/bucket/o/path%2Ffile3.txt\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n" @@ -3421,7 +3419,7 @@ TEST(GcsFileSystemTest, DeleteRecursively_DeletionErrors) { std::vector requests( {// IsDirectory is checking whether there are children objects. 
new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" "fields=items%2Fname%2CnextPageToken&prefix=path%2F" "&maxResults=1\n" "Auth Token: fake_token\n" @@ -3430,7 +3428,7 @@ TEST(GcsFileSystemTest, DeleteRecursively_DeletionErrors) { " { \"name\": \"path/file1.txt\" }]}"), // Calling GetChildren recursively. new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" "fields=items%2Fname%2CnextPageToken&prefix=path%2F\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", @@ -3440,14 +3438,14 @@ TEST(GcsFileSystemTest, DeleteRecursively_DeletionErrors) { " { \"name\": \"path/subpath/file2.txt\" }," " { \"name\": \"path/file3.txt\" }]}"), // Deleting the object. - new FakeHttpRequest("Uri: https://www.googleapis.com./storage/v1/b" + new FakeHttpRequest("Uri: https://www.googleapis.com/storage/v1/b" "/bucket/o/path%2Ffile1.txt\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n" "Delete: yes\n", ""), // Deleting the directory marker gs://bucket/path/ - fails with 404. - new FakeHttpRequest("Uri: https://www.googleapis.com./storage/v1/b" + new FakeHttpRequest("Uri: https://www.googleapis.com/storage/v1/b" "/bucket/o/path%2Fsubpath%2F\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n" @@ -3455,7 +3453,7 @@ TEST(GcsFileSystemTest, DeleteRecursively_DeletionErrors) { "", errors::NotFound("404"), 404), // Checking if gs://bucket/path/subpath/ is a folder - it is. new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" 
"fields=items%2Fname%2CnextPageToken&prefix=path%2Fsubpath%2F" "&maxResults=1\n" "Auth Token: fake_token\n" @@ -3463,14 +3461,14 @@ TEST(GcsFileSystemTest, DeleteRecursively_DeletionErrors) { strings::StrCat("{\"items\": [ " " { \"name\": \"path/subpath/\" }]}")), // Deleting the object gs://bucket/path/subpath/file2.txt - new FakeHttpRequest("Uri: https://www.googleapis.com./storage/v1/b" + new FakeHttpRequest("Uri: https://www.googleapis.com/storage/v1/b" "/bucket/o/path%2Fsubpath%2Ffile2.txt\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n" "Delete: yes\n", ""), // Deleting the object s://bucket/path/file3.txt - fails with 404. - new FakeHttpRequest("Uri: https://www.googleapis.com./storage/v1/b" + new FakeHttpRequest("Uri: https://www.googleapis.com/storage/v1/b" "/bucket/o/path%2Ffile3.txt\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n" @@ -3478,7 +3476,7 @@ TEST(GcsFileSystemTest, DeleteRecursively_DeletionErrors) { "", errors::NotFound("404"), 404), // Checking if gs://bucket/path/file3.txt/ is a folder - it's not. new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" "fields=items%2Fname%2CnextPageToken&prefix=path%2Ffile3.txt%2F" "&maxResults=1\n" "Auth Token: fake_token\n" @@ -3486,7 +3484,7 @@ TEST(GcsFileSystemTest, DeleteRecursively_DeletionErrors) { "{}"), // Checking if gs://bucket/path/file3.txt is an object - fails with 404. new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" "path%2Ffile3.txt?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", @@ -3514,7 +3512,7 @@ TEST(GcsFileSystemTest, DeleteRecursively_NotAFolder) { std::vector requests( {// IsDirectory is checking whether there are children objects. new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" 
+ "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" "fields=items%2Fname%2CnextPageToken&prefix=path%2F" "&maxResults=1\n" "Auth Token: fake_token\n" @@ -3522,7 +3520,7 @@ TEST(GcsFileSystemTest, DeleteRecursively_NotAFolder) { "{}"), // IsDirectory is checking if the path exists as an object. new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" "path?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", @@ -3606,7 +3604,7 @@ TEST(GcsFileSystemTest, AdditionalRequestHeaderTest) { std::vector requests( {// IsDirectory is checking whether there are children objects. - new FakeHttpRequest("Uri: https://www.googleapis.com./fake\n" + new FakeHttpRequest("Uri: https://www.googleapis.com/fake\n" "Auth Token: fake_token\n" "Header mynewheader: newheadercontents\n" "Header Hello: world\n", @@ -3624,7 +3622,7 @@ TEST(GcsFileSystemTest, AdditionalRequestHeaderTest) { std::unique_ptr request; TF_EXPECT_OK(fs7.CreateHttpRequest(&request)); - request->SetUri("https://www.googleapis.com./fake"); + request->SetUri("https://www.googleapis.com/fake"); request->AddHeader("Hello", "world"); TF_EXPECT_OK(request->Send()); } @@ -3686,7 +3684,7 @@ TEST(GcsFileSystemTest, OverrideCacheParameters) { TEST(GcsFileSystemTest, CreateHttpRequest) { std::vector requests( {// IsDirectory is checking whether there are children objects. 
- new FakeHttpRequest("Uri: https://www.googleapis.com./fake\n" + new FakeHttpRequest("Uri: https://www.googleapis.com/fake\n" "Auth Token: fake_token\n" "Header Hello: world\n", "{}")}); @@ -3703,7 +3701,7 @@ TEST(GcsFileSystemTest, CreateHttpRequest) { std::unique_ptr request; TF_EXPECT_OK(fs.CreateHttpRequest(&request)); - request->SetUri("https://www.googleapis.com./fake"); + request->SetUri("https://www.googleapis.com/fake"); request->AddHeader("Hello", "world"); TF_EXPECT_OK(request->Send()); } @@ -3747,7 +3745,7 @@ class TestGcsStats : public GcsStatsInterface { TEST(GcsFileSystemTest, Stat_StatsRecording) { std::vector requests({new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" "file.txt?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", @@ -3775,7 +3773,7 @@ TEST(GcsFileSystemTest, Stat_StatsRecording) { TEST(GcsFileSystemTest, NewRandomAccessFile_StatsRecording) { std::vector requests({new FakeHttpRequest( - "Uri: https://storage.googleapis.com./bucket/random_access.txt\n" + "Uri: https://storage.googleapis.com/bucket/random_access.txt\n" "Auth Token: fake_token\n" "Range: 0-5\n" "Timeouts: 5 1 20\n", @@ -3817,7 +3815,7 @@ TEST(GcsFileSystemTest, NewAppendableFile_MultipleFlushesWithCompose) { // Fetch the file (stats and then content) new FakeHttpRequest( "Uri: " - "https://www.googleapis.com./storage/v1/b/bucket/o/" + "https://www.googleapis.com/storage/v1/b/bucket/o/" "some%2Fpath%2Fappendable?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", @@ -3825,14 +3823,14 @@ TEST(GcsFileSystemTest, NewAppendableFile_MultipleFlushesWithCompose) { "\"updated\": \"2016-04-29T23:15:24.896Z\"}")), new FakeHttpRequest( "Uri: " - "https://storage.googleapis.com./bucket/some%2Fpath%2Fappendable\n" + "https://storage.googleapis.com/bucket/some%2Fpath%2Fappendable\n" "Auth Token: fake_token\n" "Range: 
0-1048575\n" "Timeouts: 5 1 20\n", contents[0]), // Upload entire file new FakeHttpRequest( - "Uri: https://www.googleapis.com./upload/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com/upload/storage/v1/b/bucket/o?" "uploadType=resumable&name=some%2Fpath%2Fappendable\n" "Auth Token: fake_token\n" "Header X-Upload-Content-Length: 18\n" @@ -3850,7 +3848,7 @@ TEST(GcsFileSystemTest, NewAppendableFile_MultipleFlushesWithCompose) { // Upload new part to a temporary object new FakeHttpRequest( "Uri: " - "https://www.googleapis.com./upload/storage/v1/b/bucket/" + "https://www.googleapis.com/upload/storage/v1/b/bucket/" "o?uploadType=resumable&name=some%2Fpath%2F.tmpcompose%2Fappendable." "18\n" "Auth Token: fake_token\n" @@ -3872,7 +3870,7 @@ TEST(GcsFileSystemTest, NewAppendableFile_MultipleFlushesWithCompose) { // Fetch generation new FakeHttpRequest( "Uri: " - "https://www.googleapis.com./storage/v1/b/bucket/o/" + "https://www.googleapis.com/storage/v1/b/bucket/o/" "some%2Fpath%2Fappendable?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", @@ -3880,7 +3878,7 @@ TEST(GcsFileSystemTest, NewAppendableFile_MultipleFlushesWithCompose) { "\"updated\": \"2016-04-29T23:15:24.896Z\"}")), // Compose the new part at the end of the original object. new FakeHttpRequest("Uri: " - "https://www.googleapis.com./storage/v1/b/bucket/o/" + "https://www.googleapis.com/storage/v1/b/bucket/o/" "some%2Fpath%2Fappendable/compose\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n" @@ -3893,14 +3891,14 @@ TEST(GcsFileSystemTest, NewAppendableFile_MultipleFlushesWithCompose) { ""), // Delete the temporary object. new FakeHttpRequest("Uri: " - "https://www.googleapis.com./storage/v1/b/bucket/o/" + "https://www.googleapis.com/storage/v1/b/bucket/o/" "some%2Fpath%2F.tmpcompose%2Fappendable.18\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n" "Delete: yes\n", ""), new FakeHttpRequest( - "Uri: https://www.googleapis.com./upload/storage/v1/b/bucket/o?" 
+ "Uri: https://www.googleapis.com/upload/storage/v1/b/bucket/o?" "uploadType=resumable&name=some%2Fpath%2F.tmpcompose%2Fappendable." "27\n" "Auth Token: fake_token\n" @@ -3919,14 +3917,14 @@ TEST(GcsFileSystemTest, NewAppendableFile_MultipleFlushesWithCompose) { // Fetch generation new FakeHttpRequest( "Uri: " - "https://www.googleapis.com./storage/v1/b/bucket/o/" + "https://www.googleapis.com/storage/v1/b/bucket/o/" "some%2Fpath%2Fappendable?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", strings::StrCat("{\"size\": \"8\",\"generation\": \"4567\"," "\"updated\": \"2016-04-29T23:15:24.896Z\"}")), new FakeHttpRequest("Uri: " - "https://www.googleapis.com./storage/v1/b/bucket/o/" + "https://www.googleapis.com/storage/v1/b/bucket/o/" "some%2Fpath%2Fappendable/compose\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n" @@ -3938,7 +3936,7 @@ TEST(GcsFileSystemTest, NewAppendableFile_MultipleFlushesWithCompose) { "'some/path/.tmpcompose/appendable.27'}]}\n", ""), new FakeHttpRequest("Uri: " - "https://www.googleapis.com./storage/v1/b/bucket/o/" + "https://www.googleapis.com/storage/v1/b/bucket/o/" "some%2Fpath%2F.tmpcompose%2Fappendable." 
"27\n" "Auth Token: fake_token\n" @@ -3975,20 +3973,20 @@ TEST(GcsFileSystemTest, NewAppendableFile_MultipleFlushesWithoutCompose) { {"content0,", "content1,", "content2,", "content3,"}); std::vector requests({ new FakeHttpRequest( - "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" "path%2Fappendable?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", strings::StrCat("{\"size\": \"8\",\"generation\": \"1\"," "\"updated\": \"2016-04-29T23:15:24.896Z\"}")), new FakeHttpRequest( - "Uri: https://storage.googleapis.com./bucket/path%2Fappendable\n" + "Uri: https://storage.googleapis.com/bucket/path%2Fappendable\n" "Auth Token: fake_token\n" "Range: 0-1048575\n" "Timeouts: 5 1 20\n", contents[0]), new FakeHttpRequest( - "Uri: https://www.googleapis.com./upload/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com/upload/storage/v1/b/bucket/o?" "uploadType=resumable&name=path%2Fappendable\n" "Auth Token: fake_token\n" "Header X-Upload-Content-Length: 18\n" @@ -4005,7 +4003,7 @@ TEST(GcsFileSystemTest, NewAppendableFile_MultipleFlushesWithoutCompose) { contents[0], contents[1], "\n"), ""), new FakeHttpRequest("Uri: " - "https://www.googleapis.com./upload/storage/v1/b/" + "https://www.googleapis.com/upload/storage/v1/b/" "bucket/o?" "uploadType=resumable&name=path%2Fappendable\n" "Auth Token: fake_token\n" @@ -4026,7 +4024,7 @@ TEST(GcsFileSystemTest, NewAppendableFile_MultipleFlushesWithoutCompose) { contents[0], contents[1], contents[2], "\n"), ""), new FakeHttpRequest( - "Uri: https://www.googleapis.com./upload/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com/upload/storage/v1/b/bucket/o?" 
"uploadType=resumable&name=path%2Fappendable\n" "Auth Token: fake_token\n" "Header X-Upload-Content-Length: 36\n" diff --git a/third_party/xla/xla/BUILD b/third_party/xla/xla/BUILD index 27f86f61f9c431..7b8bacdead5b0f 100644 --- a/third_party/xla/xla/BUILD +++ b/third_party/xla/xla/BUILD @@ -8,7 +8,7 @@ load("//third_party/compute_library:build_defs.bzl", "if_enable_acl") # Placeholder: load py_proto_library load("//xla:xla.bzl", "xla_cc_test", "xla_py_proto_library") -load("//xla/tsl:tsl.bzl", "internal_visibility") +load("//xla/tsl:tsl.bzl", "if_google", "internal_visibility") load("//xla/tsl:tsl.default.bzl", "filegroup", "get_compatible_with_portable") package( @@ -103,7 +103,7 @@ tf_proto_library( protodeps = [ ":xla_data_proto", "//xla/service:hlo_proto", - ], + ] + if_google(["@com_google_protobuf//:any"]), visibility = ["//visibility:public"], ) @@ -304,9 +304,7 @@ cc_library( deprecation = "Use @com_google_absl//absl/status instead.", visibility = ["//visibility:public"], deps = [ - "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", ], ) @@ -1209,7 +1207,12 @@ tf_proto_library( name = "autotuning_proto", srcs = ["autotuning.proto"], make_default_target_header_only = True, - protodeps = ["@local_tsl//tsl/protobuf:dnn_proto"], + protodeps = [ + "@local_tsl//tsl/protobuf:dnn_proto", + ] + if_google([ + "@com_google_protobuf//:any", + "@com_google_protobuf//:duration", + ]), ) cc_library( diff --git a/third_party/xla/xla/backends/interpreter/BUILD b/third_party/xla/xla/backends/interpreter/BUILD index 8f5e83d870bd57..b2166ea9a28d42 100644 --- a/third_party/xla/xla/backends/interpreter/BUILD +++ b/third_party/xla/xla/backends/interpreter/BUILD @@ -69,7 +69,7 @@ cc_library( srcs = ["platform_id.cc"], hdrs = ["platform_id.h"], deps = ["//xla/stream_executor"] + if_static( - ["@com_google_protobuf//:protobuf"], + ["@com_google_protobuf//:any_cc_proto"], ["@com_google_protobuf//:protobuf_headers"], ), ) 
diff --git a/third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc b/third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc index de9dd466f88fce..6d30ef0330f1d3 100644 --- a/third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc +++ b/third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc @@ -2287,6 +2287,7 @@ PjRtStreamExecutorLoadedExecutable::PjRtStreamExecutorLoadedExecutable( TransferManager* transfer_manager = client_->client()->backend().transfer_manager(); executables_.reserve(executables.size()); + tsl::Fprint128 fingerprint = tsl::Fingerprint128(fingerprint_); for (auto& executable : executables) { const auto& computation_layout = executable->executable()->module().entry_computation_layout(); @@ -2296,10 +2297,14 @@ PjRtStreamExecutorLoadedExecutable::PjRtStreamExecutorLoadedExecutable( parameter_shapes.push_back(transfer_manager->HostShapeToDeviceShape( computation_layout.parameter_shape(i))); } + fingerprint = tsl::FingerprintCat128( + fingerprint, + tsl::Fingerprint128(executable->executable()->module().ToString())); executables_.emplace_back(std::move(executable)); on_device_executable_parameter_shapes_.push_back( std::move(parameter_shapes)); } + fingerprint_ = absl::StrCat(fingerprint.low64, fingerprint.high64); int num_partitions; if (device_assignment_ == nullptr) { @@ -3251,22 +3256,6 @@ PjRtStreamExecutorLoadedExecutable::GetOutputMemoryKinds() const { return Unimplemented("GetOutputMemoryKinds is not supported."); } -absl::StatusOr -PjRtStreamExecutorLoadedExecutable::FingerprintExecutable() const { - if (executables_.size() != 1) { - return absl::InternalError( - "Fingerprinting multiple executables within one " - "PjRtStreamExecutorLoadedExecutable is not supported."); - } - - Executable* executable = executables_[0]->executable(); - if (executable->has_module()) { - return executable->module().GetFingerprint128(); - } else { - return absl::InternalError("Executable does not have HLO modules."); - } -} - absl::StatusOr 
PjRtStreamExecutorClient::GetExecutableExtras(CompileOptions* options) { ExecutableExtras extras; diff --git a/third_party/xla/xla/pjrt/pjrt_stream_executor_client.h b/third_party/xla/xla/pjrt/pjrt_stream_executor_client.h index 64b7e60a618f4a..76d628ec9f7d7b 100644 --- a/third_party/xla/xla/pjrt/pjrt_stream_executor_client.h +++ b/third_party/xla/xla/pjrt/pjrt_stream_executor_client.h @@ -998,7 +998,9 @@ class PjRtStreamExecutorLoadedExecutable : public PjRtLoadedExecutable { return compile_options_; } - absl::StatusOr FingerprintExecutable() const override; + absl::StatusOr FingerprintExecutable() const override { + return fingerprint_; + }; protected: bool parameter_is_tupled_arguments() const { @@ -1077,6 +1079,7 @@ class PjRtStreamExecutorLoadedExecutable : public PjRtLoadedExecutable { // addressable_device_logical_ids_[i] is assigned. shared_ptrs instead of // unique_ptrs to play well with the Python bindings (see xla.cc). std::vector addressable_devices_; + std::string fingerprint_; }; } // namespace xla diff --git a/third_party/xla/xla/python/BUILD b/third_party/xla/xla/python/BUILD index c882063e075dc3..87df6e556ff23c 100644 --- a/third_party/xla/xla/python/BUILD +++ b/third_party/xla/xla/python/BUILD @@ -410,7 +410,6 @@ cc_library( "@local_tsl//tsl/profiler/lib:profiler_session", "@local_tsl//tsl/profiler/lib:traceme", "@local_tsl//tsl/profiler/protobuf:profiled_instructions_proto_cc", - "@com_google_protobuf//:protobuf", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", ] + if_cuda([ @@ -419,7 +418,9 @@ cc_library( "//xla/stream_executor/cuda:cuda_driver", ]) + if_rocm([ "@local_config_rocm//rocm:rocm_headers", - ]) + if_cuda_or_rocm([":py_client_gpu"]), # TODO(b/337876408): remove after migration to plugin + ]) + if_cuda_or_rocm([ + ":py_client_gpu", # TODO(b/337876408): remove after migration to plugin + ]) + if_google(["@com_google_protobuf//:any_cc_proto"]), ) cc_library( diff --git a/third_party/xla/xla/python/ifrt/array.h 
b/third_party/xla/xla/python/ifrt/array.h index ec620fd72f4b16..e080cceeec31c0 100644 --- a/third_party/xla/xla/python/ifrt/array.h +++ b/third_party/xla/xla/python/ifrt/array.h @@ -122,29 +122,6 @@ class Array : public llvm::RTTIExtends { void* data, std::optional> byte_strides, ArrayCopySemantics semantics) = 0; - // Copies the array with a new sharding, creating a new array. - // - // Resharding falls into one of the three cases: - // - // * Metadata-only resharding: Use a new sharding for the array that expects - // the same physical layout of underlying buffers on the same devices. - // * 1-to-1 buffer copy: Copy individual buffers to different devices without - // altering their physical layout. - // * M-to-N buffer resharding: Shuffle the buffer data across the boundary of - // the buffers, changing their physical layout. - // - // Implementations may return `UNIMPLEMENTED` if they do not know how to copy - // or reshuffle the data to match the new sharding. - // - // It may fail if the buffer data would be sent from/to an unaddressable - // device. - // - // TODO(b/343992694): Remove this API in favor of `Client::CopyArrays`. 
- ABSL_DEPRECATED("Use `Client::CopyArrays` instead") - virtual absl::StatusOr> Reshard( - std::shared_ptr new_sharding, - ArrayCopySemantics semantics) = 0; - static char ID; // NOLINT }; diff --git a/third_party/xla/xla/python/ifrt/array_impl_test_lib.cc b/third_party/xla/xla/python/ifrt/array_impl_test_lib.cc index d1b1051d3c51ee..e93500f2fd5954 100644 --- a/third_party/xla/xla/python/ifrt/array_impl_test_lib.cc +++ b/third_party/xla/xla/python/ifrt/array_impl_test_lib.cc @@ -445,70 +445,6 @@ TEST(ArrayImplTest, AssembleAndDisassembleSingleDeviceArray) { ElementsAreArray(array->sharding().devices().devices())); } -TEST(ArrayImplTest, ReshardToSameSharding) { - TF_ASSERT_OK_AND_ASSIGN(auto client, test_util::GetClient()); - - DType dtype(DType::kF32); - Shape shape({2, 3}); - std::vector data(6); - std::iota(data.begin(), data.end(), 0); - Device* device = client->addressable_devices().at(0); - std::shared_ptr sharding = - SingleDeviceSharding::Create(device, MemoryKind()); - auto semantics = Client::HostBufferSemantics::kImmutableOnlyDuringCall; - - TF_ASSERT_OK_AND_ASSIGN( - auto array, client->MakeArrayFromHostBuffer( - data.data(), dtype, shape, - /*byte_strides=*/std::nullopt, sharding, semantics, - /*on_done_with_host_buffer=*/{})); - - TF_ASSERT_OK_AND_ASSIGN( - auto resharded_array, - array->Reshard(sharding, ArrayCopySemantics::kAlwaysCopy)); - - std::vector out_data(6); - auto future = resharded_array->CopyToHostBuffer( - out_data.data(), /*byte_strides=*/std::nullopt, - ArrayCopySemantics::kAlwaysCopy); - TF_ASSERT_OK(future.Await()); - EXPECT_THAT(out_data, ElementsAreArray(data)); -} - -TEST(ArrayImplTest, ReshardToDifferentDevice) { - TF_ASSERT_OK_AND_ASSIGN(auto client, test_util::GetClient()); - - DType dtype(DType::kF32); - Shape shape({2, 3}); - std::vector data(6); - std::iota(data.begin(), data.end(), 0); - Device* device = client->addressable_devices().at(0); - std::shared_ptr sharding = - SingleDeviceSharding::Create(device, MemoryKind()); - 
auto semantics = Client::HostBufferSemantics::kImmutableOnlyDuringCall; - - TF_ASSERT_OK_AND_ASSIGN( - auto array, client->MakeArrayFromHostBuffer( - data.data(), dtype, shape, - /*byte_strides=*/std::nullopt, sharding, semantics, - /*on_done_with_host_buffer=*/{})); - - Device* new_device = client->addressable_devices().at(1); - std::shared_ptr new_sharding = - SingleDeviceSharding::Create(new_device, MemoryKind()); - - TF_ASSERT_OK_AND_ASSIGN( - auto resharded_array, - array->Reshard(new_sharding, ArrayCopySemantics::kAlwaysCopy)); - - std::vector out_data(6); - auto future = resharded_array->CopyToHostBuffer( - out_data.data(), /*byte_strides=*/std::nullopt, - ArrayCopySemantics::kAlwaysCopy); - TF_ASSERT_OK(future.Await()); - EXPECT_THAT(out_data, ElementsAreArray(data)); -} - TEST(ArrayImplTest, CopyToSameDevices) { TF_ASSERT_OK_AND_ASSIGN(auto client, test_util::GetClient()); diff --git a/third_party/xla/xla/python/ifrt/mock.cc b/third_party/xla/xla/python/ifrt/mock.cc index f11e4cd11ffe44..2972adad4d0186 100644 --- a/third_party/xla/xla/python/ifrt/mock.cc +++ b/third_party/xla/xla/python/ifrt/mock.cc @@ -95,11 +95,6 @@ MockArray::MockArray(tsl::RCReference delegated) ArrayCopySemantics semantics) { return delegated_->CopyToHostBuffer(data, byte_strides, semantics); }); - ON_CALL(*this, Reshard) - .WillByDefault([this](std::shared_ptr new_sharding, - ArrayCopySemantics semantics) { - return delegated_->Reshard(std::move(new_sharding), semantics); - }); } // LINT.ThenChange() diff --git a/third_party/xla/xla/python/ifrt/mock.h b/third_party/xla/xla/python/ifrt/mock.h index c3fafdcd2e16b7..ae1810f6fb3a82 100644 --- a/third_party/xla/xla/python/ifrt/mock.h +++ b/third_party/xla/xla/python/ifrt/mock.h @@ -88,10 +88,6 @@ class MockArray : public llvm::RTTIExtends { std::optional> byte_strides, ArrayCopySemantics semantics), (final)); - MOCK_METHOD(absl::StatusOr>, Reshard, - (std::shared_ptr new_sharding, - ArrayCopySemantics semantics), - (final)); // 
LINT.ThenChange(mock.cc:MockArrayDelegation) tsl::RCReference delegated() const { return delegated_; } diff --git a/third_party/xla/xla/python/ifrt_proxy/client/array.h b/third_party/xla/xla/python/ifrt_proxy/client/array.h index 26d337a21916af..0452e504594a22 100644 --- a/third_party/xla/xla/python/ifrt_proxy/client/array.h +++ b/third_party/xla/xla/python/ifrt_proxy/client/array.h @@ -127,9 +127,11 @@ class Array final : public llvm::RTTIExtends { void* data, std::optional> byte_strides, ArrayCopySemantics semantics) override; + // This will be deleted once the client requires the minimum version of 3. + ABSL_DEPRECATED("Use `Client::CopyArrays` instead") absl::StatusOr> Reshard( std::shared_ptr new_sharding, - ArrayCopySemantics semantics) override; + ArrayCopySemantics semantics); static char ID; // NOLINT diff --git a/third_party/xla/xla/python/ifrt_proxy/client/client.cc b/third_party/xla/xla/python/ifrt_proxy/client/client.cc index 83b8bcab682334..42c83556440129 100644 --- a/third_party/xla/xla/python/ifrt_proxy/client/client.cc +++ b/third_party/xla/xla/python/ifrt_proxy/client/client.cc @@ -227,8 +227,15 @@ Client::CopyArrays(absl::Span> arrays, TF_ASSIGN_OR_RETURN( auto new_sharding, array->sharding().WithDeviceAssignment(devices, memory_kind)); - TF_ASSIGN_OR_RETURN(new_arrays.emplace_back(), - array->Reshard(std::move(new_sharding), semantics)); + if (auto* const proxy_array = + llvm::dyn_cast(array.get())) { + TF_ASSIGN_OR_RETURN( + new_arrays.emplace_back(), + proxy_array->Reshard(std::move(new_sharding), semantics)); + } else { + return absl::InvalidArgumentError( + "Unsupported array type for xla::ifrt::proxy::Client::CopyArrays"); + } } return new_arrays; } diff --git a/third_party/xla/xla/python/ifrt_proxy/common/BUILD b/third_party/xla/xla/python/ifrt_proxy/common/BUILD index 30e93ada9ccfda..1230456a6600a3 100644 --- a/third_party/xla/xla/python/ifrt_proxy/common/BUILD +++ b/third_party/xla/xla/python/ifrt_proxy/common/BUILD @@ -67,7 +67,7 @@ 
tf_proto_library( srcs = ["ifrt_service.proto"], protodeps = [ ":types_proto", - # copybara:uncomment "//google/protobuf:any", + # copybara:uncomment "@com_google_protobuf//:any", "//xla:xla_data_proto", "//xla/pjrt:execute_options_proto", "//xla/python/ifrt:dtype_proto", diff --git a/third_party/xla/xla/python/ifrt_proxy/common/ifrt_service.proto b/third_party/xla/xla/python/ifrt_proxy/common/ifrt_service.proto index 64cd920166c1ea..945b66a3538045 100644 --- a/third_party/xla/xla/python/ifrt_proxy/common/ifrt_service.proto +++ b/third_party/xla/xla/python/ifrt_proxy/common/ifrt_service.proto @@ -55,7 +55,7 @@ message IfrtRequest { disassemble_into_single_device_arrays_request = 7; DeleteArrayRequest delete_array_request = 9; CopyArraysRequest copy_arrays_request = 24; - ReshardRequest reshard_request = 10; + ReshardRequest reshard_request = 10 [deprecated = true]; FullyReplicatedShardRequest fully_replicated_shard_request = 20; IsArrayDeletedRequest is_array_deleted_request = 11; DestructArrayRequest destruct_array_request = 12; @@ -102,7 +102,7 @@ message IfrtResponse { disassemble_into_single_device_arrays_response = 7; DeleteArrayResponse delete_array_response = 9; CopyArraysResponse copy_arrays_response = 24; - ReshardResponse reshard_response = 10; + ReshardResponse reshard_response = 10 [deprecated = true]; FullyReplicatedShardResponse fully_replicated_shard_response = 20; IsArrayDeletedResponse is_array_deleted_response = 11; DestructArrayResponse destruct_array_response = 12; diff --git a/third_party/xla/xla/python/ifrt_proxy/server/ifrt_backend.cc b/third_party/xla/xla/python/ifrt_proxy/server/ifrt_backend.cc index 40eee8ad0416d5..85521905214040 100644 --- a/third_party/xla/xla/python/ifrt_proxy/server/ifrt_backend.cc +++ b/third_party/xla/xla/python/ifrt_proxy/server/ifrt_backend.cc @@ -697,13 +697,24 @@ absl::StatusOr IfrtBackend::HandleReshardRequest( TF_ASSIGN_OR_RETURN(auto semantics, FromArrayCopySemanticsProto( reshard_request.copy_semantics())); - 
TF_ASSIGN_OR_RETURN(auto resharded_array, - array->Reshard(sharding, semantics)); + // Emulate the old `Array::Reshard` behavior using `Client::CopyArrays`. No + // existing IFRT implementations before `Array::Reshard` was deleted actually + // supported resharding, so this should be safe. + if (!array->sharding().HasSamePartitioning(*sharding)) { + return absl::InvalidArgumentError(absl::StrCat( + "IFRT Proxy does not support resharding, but got ", + array->sharding().DebugString(), " as the original sharding and ", + sharding->DebugString(), " as the target sharding")); + } + TF_ASSIGN_OR_RETURN( + auto copied_arrays, + client_->CopyArrays(absl::MakeSpan(&array, 1), sharding->devices(), + sharding->memory_kind(), semantics)); uint64_t resharded_array_handle = handle_generator_.New(); { absl::MutexLock lock(&arrays_mutex_); - arrays_.insert({resharded_array_handle, std::move(resharded_array)}); + arrays_.insert({resharded_array_handle, std::move(copied_arrays[0])}); } auto ifrt_resp = NewIfrtResponse(request->request_metadata().op_id()); diff --git a/third_party/xla/xla/python/ifrt_proxy/server/ifrt_backend_test.cc b/third_party/xla/xla/python/ifrt_proxy/server/ifrt_backend_test.cc index e03230a08ffa4a..f54ea4effd7371 100644 --- a/third_party/xla/xla/python/ifrt_proxy/server/ifrt_backend_test.cc +++ b/third_party/xla/xla/python/ifrt_proxy/server/ifrt_backend_test.cc @@ -700,21 +700,27 @@ TEST_P(IfrtBackendHandlerTest, CopyArrays) { TEST_P(IfrtBackendHandlerTest, ReshardSuccess) { auto src_mock_array = tsl::MakeRef(); - auto resharded_mock_array = tsl::MakeRef(); - EXPECT_CALL(*src_mock_array, Reshard(_, _)) - .WillOnce(Return(std::move(resharded_mock_array))); + TF_ASSERT_OK_AND_ASSIGN(auto* device, + mock_client_->LookupDevice(DeviceId(0))); + auto src_sharding = SingleDeviceSharding::Create(device, MemoryKind()); + ON_CALL(*src_mock_array, sharding()).WillByDefault(ReturnRef(*src_sharding)); TF_ASSERT_OK_AND_ASSIGN(auto src_array_handle, 
MakeTestArray(std::move(src_mock_array))); + auto copied_mock_array = tsl::MakeRef(); + EXPECT_CALL(*mock_client_, CopyArrays(_, _, _, _)) + .WillOnce(Return(std::vector>( + {copied_mock_array}))); + auto ifrt_request = NewIfrtRequest(NewOpId()); auto* reshard_request = ifrt_request->mutable_reshard_request(); reshard_request->set_array_handle(src_array_handle); reshard_request->set_copy_semantics(proto::ARRAY_COPY_SEMANTICS_ALWAYS_COPY); - TF_ASSERT_OK_AND_ASSIGN(auto* device, + TF_ASSERT_OK_AND_ASSIGN(auto* new_device, mock_client_->LookupDevice(DeviceId(1))); TF_ASSERT_OK_AND_ASSIGN( *ifrt_request->mutable_reshard_request()->mutable_sharding(), - SingleDeviceSharding::Create(device, MemoryKind())->ToProto()); + SingleDeviceSharding::Create(new_device, MemoryKind())->ToProto()); TF_ASSERT_OK_AND_ASSIGN(auto response, CallBackend(std::move(ifrt_request))); @@ -723,6 +729,43 @@ TEST_P(IfrtBackendHandlerTest, ReshardSuccess) { EXPECT_NE(response->reshard_response().array_handle(), 0); } +TEST_P(IfrtBackendHandlerTest, ReshardFailsWhenTheBackendFails) { + auto mock_array = tsl::MakeRef(); + TF_ASSERT_OK_AND_ASSIGN(auto* device, + mock_client_->LookupDevice(DeviceId(1))); + auto sharding = SingleDeviceSharding::Create(device, MemoryKind()); + ON_CALL(*mock_array, sharding()).WillByDefault(ReturnRef(*sharding)); + TF_ASSERT_OK_AND_ASSIGN(auto array_handle, + MakeTestArray(std::move(mock_array))); + + EXPECT_CALL(*mock_client_, CopyArrays(_, _, _, _)) + .WillOnce(Return(absl::UnknownError("injected error"))); + + auto ifrt_request = NewIfrtRequest(NewOpId()); + auto* reshard_request = ifrt_request->mutable_reshard_request(); + reshard_request->set_array_handle(array_handle); + reshard_request->set_copy_semantics(proto::ARRAY_COPY_SEMANTICS_ALWAYS_COPY); + TF_ASSERT_OK_AND_ASSIGN(auto* new_device, + mock_client_->LookupDevice(DeviceId(1))); + TF_ASSERT_OK_AND_ASSIGN( + *ifrt_request->mutable_reshard_request()->mutable_sharding(), + 
SingleDeviceSharding::Create(new_device, MemoryKind())->ToProto()); + + EXPECT_THAT(CallBackend(std::move(ifrt_request)), + StatusIs(absl::StatusCode::kUnknown, StrEq("injected error"))); +} + +TEST_P(IfrtBackendHandlerTest, ReshardFailsWithNonExistentArrayHandle) { + auto ifrt_request = NewIfrtRequest(NewOpId()); + auto* reshard_request = ifrt_request->mutable_reshard_request(); + reshard_request->set_array_handle(0); + reshard_request->set_copy_semantics(proto::ARRAY_COPY_SEMANTICS_ALWAYS_COPY); + reshard_request->mutable_sharding(); + + EXPECT_THAT(CallBackend(std::move(ifrt_request)), + StatusIs(absl::StatusCode::kNotFound)); +} + TEST_P(IfrtBackendHandlerTest, FullyReplicatedShardSuccess) { auto fully_replicated_mock_array = tsl::MakeRef(); auto resultant_array = tsl::MakeRef(); @@ -777,38 +820,6 @@ TEST_P(IfrtBackendHandlerTest, StatusIs(absl::StatusCode::kNotFound)); } -TEST_P(IfrtBackendHandlerTest, ReshardFailsWhenTheBackendFails) { - auto mock_array = tsl::MakeRef(); - EXPECT_CALL(*mock_array, Reshard(_, _)) - .WillOnce(Return(absl::UnknownError("injected error"))); - TF_ASSERT_OK_AND_ASSIGN(auto array_handle, - MakeTestArray(std::move(mock_array))); - - auto ifrt_request = NewIfrtRequest(NewOpId()); - auto* reshard_request = ifrt_request->mutable_reshard_request(); - reshard_request->set_array_handle(array_handle); - reshard_request->set_copy_semantics(proto::ARRAY_COPY_SEMANTICS_ALWAYS_COPY); - TF_ASSERT_OK_AND_ASSIGN(auto* device, - mock_client_->LookupDevice(DeviceId(1))); - TF_ASSERT_OK_AND_ASSIGN( - *ifrt_request->mutable_reshard_request()->mutable_sharding(), - SingleDeviceSharding::Create(device, MemoryKind())->ToProto()); - - EXPECT_THAT(CallBackend(std::move(ifrt_request)), - StatusIs(absl::StatusCode::kUnknown, StrEq("injected error"))); -} - -TEST_P(IfrtBackendHandlerTest, ReshardFailsWithNonExistentArrayHandle) { - auto ifrt_request = NewIfrtRequest(NewOpId()); - auto* reshard_request = ifrt_request->mutable_reshard_request(); - 
reshard_request->set_array_handle(0); - reshard_request->set_copy_semantics(proto::ARRAY_COPY_SEMANTICS_ALWAYS_COPY); - reshard_request->mutable_sharding(); - - EXPECT_THAT(CallBackend(std::move(ifrt_request)), - StatusIs(absl::StatusCode::kNotFound)); -} - TEST_P(IfrtBackendHandlerTest, CheckArrayReadyRequestRelaysTheResultFromBackend) { auto mock_array = tsl::MakeRef(); diff --git a/third_party/xla/xla/python/pjrt_ifrt/BUILD b/third_party/xla/xla/python/pjrt_ifrt/BUILD index 49b498d893f859..9726b446eb21dd 100644 --- a/third_party/xla/xla/python/pjrt_ifrt/BUILD +++ b/third_party/xla/xla/python/pjrt_ifrt/BUILD @@ -1,6 +1,6 @@ load("@local_tsl//tsl/platform:build_config.bzl", "tf_proto_library") load("//xla:xla.bzl", "xla_cc_test") -load("//xla/tsl:tsl.bzl", "internal_visibility") +load("//xla/tsl:tsl.bzl", "if_google", "internal_visibility") load("//xla/tsl:tsl.default.bzl", "get_compatible_with_portable") package_group( @@ -70,7 +70,9 @@ tf_proto_library( name = "xla_host_callback_proto", srcs = ["xla_host_callback.proto"], cc_api_version = 2, - protodeps = ["//xla:xla_data_proto"], + protodeps = [ + "//xla:xla_data_proto", + ] + if_google(["@com_google_protobuf//:any"]), ) tf_proto_library( diff --git a/third_party/xla/xla/python/pjrt_ifrt/basic_string_array.cc b/third_party/xla/xla/python/pjrt_ifrt/basic_string_array.cc index b832981c766998..74c8cd6e097c3f 100644 --- a/third_party/xla/xla/python/pjrt_ifrt/basic_string_array.cc +++ b/third_party/xla/xla/python/pjrt_ifrt/basic_string_array.cc @@ -33,6 +33,7 @@ limitations under the License. 
#include "absl/types/span.h" #include "xla/pjrt/pjrt_layout.h" #include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/device.h" #include "xla/python/ifrt/future.h" #include "xla/python/ifrt/memory.h" #include "xla/python/ifrt/shape.h" @@ -286,8 +287,9 @@ Future<> BasicStringArray::CopyToHostBuffer( return Future<>(absl::UnimplementedError("Not implemented")); } -absl::StatusOr> BasicStringArray::Reshard( - std::shared_ptr new_sharding, +absl::StatusOr> BasicStringArray::Copy( + std::optional devices, + std::optional memory_kind, ArrayCopySemantics semantics) { DCHECK(this); absl::MutexLock lock(&mu_); @@ -295,6 +297,8 @@ absl::StatusOr> BasicStringArray::Reshard( return absl::FailedPreconditionError("Array has already been deleted"); } + TF_ASSIGN_OR_RETURN(auto new_sharding, + sharding().WithDeviceAssignment(devices, memory_kind)); if (new_sharding->devices().size() != sharding_->devices().size()) { return absl::InvalidArgumentError(absl::StrCat( "Number of devices in new sharding: ", new_sharding->devices().size(), diff --git a/third_party/xla/xla/python/pjrt_ifrt/basic_string_array.h b/third_party/xla/xla/python/pjrt_ifrt/basic_string_array.h index a02af9b12dc4cc..6513663d231a56 100644 --- a/third_party/xla/xla/python/pjrt_ifrt/basic_string_array.h +++ b/third_party/xla/xla/python/pjrt_ifrt/basic_string_array.h @@ -34,8 +34,10 @@ limitations under the License. 
#include "llvm/Support/ExtensibleRTTI.h" #include "xla/pjrt/pjrt_layout.h" #include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/device.h" #include "xla/python/ifrt/dtype.h" #include "xla/python/ifrt/future.h" +#include "xla/python/ifrt/memory.h" #include "xla/python/ifrt/shape.h" #include "xla/python/ifrt/sharding.h" #include "xla/tsl/concurrency/ref_count.h" @@ -128,9 +130,10 @@ class BasicStringArray final void* data, std::optional> byte_strides, ArrayCopySemantics semantics) override; - absl::StatusOr> Reshard( - std::shared_ptr new_sharding, - ArrayCopySemantics semantics) override; + absl::StatusOr> Copy( + std::optional devices, + std::optional memory_kind, + ArrayCopySemantics semantics); Future<> GetReadyFuture() const override; diff --git a/third_party/xla/xla/python/pjrt_ifrt/pjrt_array.cc b/third_party/xla/xla/python/pjrt_ifrt/pjrt_array.cc index c73212464db7be..73324e6b1c8c91 100644 --- a/third_party/xla/xla/python/pjrt_ifrt/pjrt_array.cc +++ b/third_party/xla/xla/python/pjrt_ifrt/pjrt_array.cc @@ -450,10 +450,13 @@ absl::StatusOr GetMemorySpaceFromMemoryKind( return memory; } -absl::StatusOr> PjRtArray::Reshard( - std::shared_ptr new_sharding, +absl::StatusOr> PjRtArray::Copy( + std::optional devices, + std::optional memory_kind, ArrayCopySemantics semantics) { DCHECK(this); + TF_ASSIGN_OR_RETURN(auto new_sharding, + sharding().WithDeviceAssignment(devices, memory_kind)); if (new_sharding->devices().size() != sharding_->devices().size()) { return InvalidArgument( "Resharding to a different number of devices: %d; expected %d", diff --git a/third_party/xla/xla/python/pjrt_ifrt/pjrt_array.h b/third_party/xla/xla/python/pjrt_ifrt/pjrt_array.h index 94253b30a3b912..37396a5f708df4 100644 --- a/third_party/xla/xla/python/pjrt_ifrt/pjrt_array.h +++ b/third_party/xla/xla/python/pjrt_ifrt/pjrt_array.h @@ -26,6 +26,7 @@ limitations under the License. 
#include "llvm/Support/ExtensibleRTTI.h" #include "xla/python/ifrt/array.h" #include "xla/python/ifrt/client.h" +#include "xla/python/ifrt/device.h" #include "xla/python/ifrt/shape.h" #include "xla/python/pjrt_ifrt/pjrt_client.h" #include "xla/tsl/concurrency/ref_count.h" @@ -156,9 +157,10 @@ class PjRtArray final void* data, std::optional> byte_strides, ArrayCopySemantics semantics) override; - absl::StatusOr> Reshard( - std::shared_ptr new_sharding, - ArrayCopySemantics semantics) override; + absl::StatusOr> Copy( + std::optional devices, + std::optional memory_kind, + ArrayCopySemantics semantics); Future<> GetReadyFuture() const override; diff --git a/third_party/xla/xla/python/pjrt_ifrt/pjrt_client.cc b/third_party/xla/xla/python/pjrt_ifrt/pjrt_client.cc index e412d9cb87450b..f4a4e4c79687a4 100644 --- a/third_party/xla/xla/python/pjrt_ifrt/pjrt_client.cc +++ b/third_party/xla/xla/python/pjrt_ifrt/pjrt_client.cc @@ -696,11 +696,17 @@ absl::StatusOr>> PjRtClient::CopyArrays( std::vector> new_arrays; new_arrays.reserve(arrays.size()); for (const auto& array : arrays) { - TF_ASSIGN_OR_RETURN( - auto new_sharding, - array->sharding().WithDeviceAssignment(devices, memory_kind)); - TF_ASSIGN_OR_RETURN(new_arrays.emplace_back(), - array->Reshard(std::move(new_sharding), semantics)); + if (auto* const pjrt_array = llvm::dyn_cast(array.get())) { + TF_ASSIGN_OR_RETURN(new_arrays.emplace_back(), + pjrt_array->Copy(devices, memory_kind, semantics)); + } else if (auto* const string_array = + llvm::dyn_cast(array.get())) { + TF_ASSIGN_OR_RETURN(new_arrays.emplace_back(), + string_array->Copy(devices, memory_kind, semantics)); + } else { + return absl::InvalidArgumentError( + "Unsupported array type for PjRtClient::CopyArrays"); + } } return new_arrays; } diff --git a/third_party/xla/xla/python/tools/BUILD b/third_party/xla/xla/python/tools/BUILD index 480f20bb2fa769..05e1b0ff08ca2d 100644 --- a/third_party/xla/xla/python/tools/BUILD +++ 
b/third_party/xla/xla/python/tools/BUILD @@ -48,7 +48,7 @@ pytype_strict_library( ) # NOTE: Copybara detects the `tsl_pybind_extension` rule and automatically -# injects the "@com_google_protobuf//:protobuf_python" python dependency +# injects the @com_google_protobuf//:protobuf_python python dependency # required by "@pybind11_protobuf//pybind11_protobuf:native_proto_caster". tsl_pybind_extension( name = "_types", diff --git a/third_party/xla/xla/service/BUILD b/third_party/xla/xla/service/BUILD index 59a5dc94278e45..553d29e960458e 100644 --- a/third_party/xla/xla/service/BUILD +++ b/third_party/xla/xla/service/BUILD @@ -51,7 +51,9 @@ tf_proto_library( srcs = ["hlo.proto"], cc_api_version = 2, make_default_target_header_only = True, - protodeps = ["//xla:xla_data_proto"], + protodeps = [ + "//xla:xla_data_proto", + ] + if_google(["@com_google_protobuf//:any"]), visibility = ["//visibility:public"], ) @@ -79,6 +81,11 @@ tf_proto_library( name = "metrics_proto", srcs = ["metrics.proto"], cc_api_version = 2, + protodeps = if_google([ + "@com_google_protobuf//:any", + "@com_google_protobuf//:duration", + "@com_google_protobuf//:timestamp", + ]), visibility = ["//visibility:public"], ) @@ -7566,7 +7573,7 @@ cc_library( "@local_tsl//tsl/platform:protobuf", "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:statusor", - ], + ] + if_google(["@com_google_protobuf//:any_cc_proto"]), ) cc_library( @@ -8154,7 +8161,7 @@ tf_proto_library( protodeps = [ ":hlo_proto", "@local_tsl//tsl/protobuf:status_proto", - ], + ] + if_google(["@com_google_protobuf//:duration"]), visibility = ["//visibility:public"], ) diff --git a/third_party/xla/xla/service/cpu/BUILD b/third_party/xla/xla/service/cpu/BUILD index 11e1bc0d7bd59e..9bcaf4e324fd15 100644 --- a/third_party/xla/xla/service/cpu/BUILD +++ b/third_party/xla/xla/service/cpu/BUILD @@ -633,6 +633,7 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/service:elemental_ir_emitter", "//xla/service/cpu:dot_op_emitter", + 
"//xla/service/llvm_ir:dynamic_update_slice_util", "//xla/service/llvm_ir:fused_ir_emitter", "//xla/service/llvm_ir:ir_array", "//xla/service/llvm_ir:llvm_util", diff --git a/third_party/xla/xla/service/cpu/benchmarks/fusion_benchmark_test.cc b/third_party/xla/xla/service/cpu/benchmarks/fusion_benchmark_test.cc index cc4b571179cf7b..60738192e2b54c 100644 --- a/third_party/xla/xla/service/cpu/benchmarks/fusion_benchmark_test.cc +++ b/third_party/xla/xla/service/cpu/benchmarks/fusion_benchmark_test.cc @@ -87,6 +87,33 @@ static void BM_BcastFusionF32(benchmark::State& state) { CHECK_OK(RunHloBenchmark(state, hlo, args, {{"$d0", absl::StrCat(d0)}})); } +static void BM_DynamicUpdateSliceFusionF32(benchmark::State& state) { + int64_t d0 = state.range(0); + + std::string_view hlo = R"( + HloModule dynamic_update_slice_fusion_f32_$d0 + + ENTRY e { + p0 = f32[$d0,256] parameter(0) + p1 = s32[] parameter(1) + p2 = s32[] parameter(2) + slice = f32[1,1] dynamic-slice(p0, p1, p2), dynamic_slice_sizes={1,1} + add = f32[1,1] add(slice, slice) + ROOT update = f32[$d0,256] dynamic-update-slice(p0, add, p1, p2) + } + )"; + + std::minstd_rand0 engine; + + auto shape = ShapeUtil::MakeShape(F32, {d0, 256}); + auto p0 = *LiteralUtil::CreateRandomLiteral(shape, &engine, 1.0f, 0.1f); + auto p1 = LiteralUtil::CreateR0(0); + auto p2 = LiteralUtil::CreateR0(0); + + std::vector args = {&p0, &p1, &p2}; + CHECK_OK(RunHloBenchmark(state, hlo, args, {{"$d0", absl::StrCat(d0)}})); +} + BENCHMARK(BM_FusionF32) ->MeasureProcessCPUTime() ->Arg(128) @@ -105,4 +132,13 @@ BENCHMARK(BM_BcastFusionF32) ->Arg(8192) ->Arg(16384); +BENCHMARK(BM_DynamicUpdateSliceFusionF32) + ->MeasureProcessCPUTime() + ->Arg(128) + ->Arg(256) + ->Arg(512) + ->Arg(1024) + ->Arg(8192) + ->Arg(16384); + } // namespace xla::cpu diff --git a/third_party/xla/xla/service/cpu/ir_emitter.h b/third_party/xla/xla/service/cpu/ir_emitter.h index 9988a260657e1f..cabcbfd1a09ba5 100644 --- a/third_party/xla/xla/service/cpu/ir_emitter.h +++ 
b/third_party/xla/xla/service/cpu/ir_emitter.h @@ -165,6 +165,8 @@ class IrEmitter : public DfsHloVisitorWithDefault, return target_machine_features_; } + const BufferAssignment& assignment() const { return assignment_; } + protected: friend class IrEmitter2; diff --git a/third_party/xla/xla/service/cpu/ir_emitter2.cc b/third_party/xla/xla/service/cpu/ir_emitter2.cc index ba5797c628b2a2..eb5b8139a5275e 100644 --- a/third_party/xla/xla/service/cpu/ir_emitter2.cc +++ b/third_party/xla/xla/service/cpu/ir_emitter2.cc @@ -55,6 +55,7 @@ limitations under the License. #include "xla/service/cpu/parallel_loop_emitter.h" #include "xla/service/cpu/shape_partition.h" #include "xla/service/elemental_ir_emitter.h" +#include "xla/service/llvm_ir/dynamic_update_slice_util.h" #include "xla/service/llvm_ir/fused_ir_emitter.h" #include "xla/service/llvm_ir/ir_array.h" #include "xla/service/llvm_ir/llvm_util.h" @@ -285,8 +286,8 @@ absl::StatusOr IrEmitter2::EmitFusionHostKernel( ElementalIrEmitter elemental_emitter(module_, &b, &hlo_module_, nested_ir_emitter_, fast_min_max()); - FusedIrEmitter fused_emitter(elemental_emitter); + FusedIrEmitter fused_emitter(elemental_emitter); for (int i = 0; i < fusion->operand_count(); i++) { fused_emitter.BindGenerator( *fusion->fused_parameter(i), [&, i](llvm_ir::IrArray::Index idx) { @@ -294,6 +295,22 @@ absl::StatusOr IrEmitter2::EmitFusionHostKernel( }); } + // Check if the fusion can be emitted in-place and skip expensive loop for + // all elements in the output array. + if (llvm_ir::CanEmitFusedDynamicUpdateSliceInPlace( + const_cast(fusion), + nested_ir_emitter_->assignment())) { + // Delegate to common implementation of fused in-place dynamic-update-slice. 
+ TF_RETURN_IF_ERROR(llvm_ir::EmitFusedDynamicUpdateSliceInPlace( + const_cast(fusion), kernel_prototype.results[0], + &fused_emitter, &b)); + + return kernels_.emplace_back( + KernelInfo{kernel_prototype.function->getName().str(), se::BlockDim(), + se::ThreadDim()}); + } + + // Emit plain elemental loops for the fusion operation. TF_ASSIGN_OR_RETURN( auto element_generator, fused_emitter.GetGenerator(*fusion->fused_expression_root())); diff --git a/third_party/xla/xla/service/cpu/onednn_convolution_rewriter.h b/third_party/xla/xla/service/cpu/onednn_convolution_rewriter.h index 334db7f60b8356..2dbd3a66eb48c2 100644 --- a/third_party/xla/xla/service/cpu/onednn_convolution_rewriter.h +++ b/third_party/xla/xla/service/cpu/onednn_convolution_rewriter.h @@ -19,6 +19,7 @@ limitations under the License. #include #include "absl/algorithm/container.h" +#include "absl/status/statusor.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/service/hlo_pass_interface.h" @@ -35,7 +36,7 @@ class OneDnnConvolutionRewriter : public HloModulePass { } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; diff --git a/third_party/xla/xla/service/cpu/runtime/BUILD b/third_party/xla/xla/service/cpu/runtime/BUILD index 018063221843eb..8c37031edf7985 100644 --- a/third_party/xla/xla/service/cpu/runtime/BUILD +++ b/third_party/xla/xla/service/cpu/runtime/BUILD @@ -717,8 +717,10 @@ cc_library( "//xla/stream_executor/host:host_kernel", "//xla/stream_executor/host:host_kernel_c_api", "//xla/tsl/concurrency:async_value", + "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/memory", + "@com_google_absl//absl/numeric:bits", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", @@ -746,6 +748,7 @@ xla_cc_test( 
"//xla/stream_executor/host:host_kernel_c_api", "//xla/tsl/concurrency:async_value", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", diff --git a/third_party/xla/xla/service/cpu/runtime/kernel_thunk.cc b/third_party/xla/xla/service/cpu/runtime/kernel_thunk.cc index 5633426f192210..85d65c0ec323fd 100644 --- a/third_party/xla/xla/service/cpu/runtime/kernel_thunk.cc +++ b/third_party/xla/xla/service/cpu/runtime/kernel_thunk.cc @@ -23,8 +23,10 @@ limitations under the License. #include #include +#include "absl/base/optimization.h" #include "absl/container/inlined_vector.h" #include "absl/memory/memory.h" +#include "absl/numeric/bits.h" #include "absl/strings/str_format.h" #include "absl/types/span.h" #include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive @@ -49,7 +51,12 @@ absl::StatusOr> KernelThunk::Create( Info info, absl::Span arguments_buffers, absl::Span results_buffers, std::string kernel_name, se::ThreadDim thread_dim, - std::optional min_alignment) { + std::optional min_alignment) { + if (min_alignment.has_value() && !absl::has_single_bit(*min_alignment)) { + return Internal("Host kernel %s minimum alignment %d is not a power of 2", + info.op_name, *min_alignment); + } + return absl::WrapUnique( new KernelThunk(std::move(info), arguments_buffers, results_buffers, std::move(kernel_name), thread_dim, min_alignment)); @@ -59,13 +66,14 @@ KernelThunk::KernelThunk( Info info, absl::Span arguments_buffers, absl::Span results_buffers, std::string kernel_name, se::ThreadDim thread_dim, - std::optional min_alignment) + std::optional min_alignment) : Thunk(Kind::kKernel, std::move(info)), arguments_buffers_(arguments_buffers.begin(), arguments_buffers.end()), results_buffers_(results_buffers.begin(), results_buffers.end()), kernel_name_(std::move(kernel_name)), thread_dim_(thread_dim), min_alignment_(min_alignment), + 
use_task_runner_(thread_dim != se::ThreadDim()), kernel_ptr_(nullptr) {} tsl::AsyncValueRef KernelThunk::Execute( @@ -104,12 +112,12 @@ tsl::AsyncValueRef KernelThunk::Execute( // will crash with a segmentation fault, or worse, produce incorrect results. if (min_alignment_.has_value()) { for (int64_t i = 0; i < buffers_data.size(); ++i) { - se::DeviceMemoryBase& data = buffers_data[i]; - if (reinterpret_cast(data.opaque()) % *min_alignment_ != 0) { + auto ptr = reinterpret_cast(buffers_data[i].opaque()); + if (ABSL_PREDICT_FALSE((ptr & (*min_alignment_ - 1)) != 0)) { return Internal( "Host kernel %s buffer argument #%d (%p) is not aligned to a " "required minimum alignment of %d bytes", - info().op_name, i, data.opaque(), *min_alignment_); + info().op_name, i, buffers_data[i].opaque(), *min_alignment_); } } } @@ -130,7 +138,7 @@ tsl::AsyncValueRef KernelThunk::Execute( // If intra-op thread pool is not nullptr, we launch HostKernel in async mode // by scheduling tasks into it. HostKernel launch completion will // automatically signal KernelThunk execute completion. 
- if (params.intra_op_threadpool) { + if (params.intra_op_threadpool && use_task_runner_) { return kernel.Launch(thread_dim_, buffers_data, [¶ms](se::host::HostKernel::Task task) { params.intra_op_threadpool->getPool()->Schedule( diff --git a/third_party/xla/xla/service/cpu/runtime/kernel_thunk.h b/third_party/xla/xla/service/cpu/runtime/kernel_thunk.h index b0bce468de5240..72cd1be097ac25 100644 --- a/third_party/xla/xla/service/cpu/runtime/kernel_thunk.h +++ b/third_party/xla/xla/service/cpu/runtime/kernel_thunk.h @@ -40,7 +40,7 @@ class KernelThunk final : public Thunk { Info info, absl::Span arguments_buffers, absl::Span results_buffers, std::string kernel_name, se::ThreadDim thread_dim, - std::optional min_alignment = std::nullopt); + std::optional min_alignment = std::nullopt); tsl::AsyncValueRef Execute(const ExecuteParams& params) final; @@ -51,13 +51,18 @@ class KernelThunk final : public Thunk { absl::Span arguments_buffers, absl::Span results_buffers, std::string kernel_name, se::ThreadDim thread_dim, - std::optional min_alignment); + std::optional min_alignment); std::vector arguments_buffers_; std::vector results_buffers_; std::string kernel_name_; se::ThreadDim thread_dim_; - std::optional min_alignment_; + std::optional min_alignment_; + + // If `true`, pass a HostKernel::TaskRunner to the kernel launch. If kernel + // has a single thread, we skip constructing HostKernel::TaskRunner and + // launch the kernel directly in the caller thread. + bool use_task_runner_; // Pointer to the host kernel corresponding to `kernel_name_`. Initialized // lazily at run time by looking it up in the HostKernels passed via params. 
diff --git a/third_party/xla/xla/service/cpu/runtime/kernel_thunk_test.cc b/third_party/xla/xla/service/cpu/runtime/kernel_thunk_test.cc index 7fa2c3bc357b2e..a80db35857e86f 100644 --- a/third_party/xla/xla/service/cpu/runtime/kernel_thunk_test.cc +++ b/third_party/xla/xla/service/cpu/runtime/kernel_thunk_test.cc @@ -21,6 +21,7 @@ limitations under the License. #include #include "absl/status/statusor.h" +#include "absl/strings/match.h" #include "xla/service/buffer_assignment.h" #include "xla/service/cpu/runtime/buffer_allocations.h" #include "xla/service/cpu/runtime/thunk.h" @@ -53,6 +54,13 @@ class AddF32HostKernels : public Thunk::HostKernels { } }; +TEST(KernelThunkTest, CheckAlignment) { + auto thunk = KernelThunk::Create({"test"}, {}, {}, "test", se::ThreadDim(), + /*min_alignment=*/3); + EXPECT_TRUE(absl::StrContains(thunk.status().message(), + "minimum alignment 3 is not a power of 2")); +} + TEST(KernelThunkTest, AddF32) { std::vector buffers; std::vector in = {1.0, 2.0, 3.0, 4.0}; diff --git a/third_party/xla/xla/service/cpu/runtime/thunk_executor.cc b/third_party/xla/xla/service/cpu/runtime/thunk_executor.cc index 37513cc37d89e5..6dbf7d750c5cd8 100644 --- a/third_party/xla/xla/service/cpu/runtime/thunk_executor.cc +++ b/third_party/xla/xla/service/cpu/runtime/thunk_executor.cc @@ -41,7 +41,8 @@ namespace xla::cpu { ThunkExecutor::ThunkExecutor(ThunkSequence thunk_sequence, std::vector nodes_defs) : thunk_sequence_(std::move(thunk_sequence)), - nodes_defs_(std::move(nodes_defs)) { + nodes_defs_(std::move(nodes_defs)), + is_sequential_(true) { for (NodeId i = 0; i < nodes_defs_.size(); ++i) { // Mark nodes with empty in-edges as source nodes. if (nodes_defs_[i].in_edges.empty()) { @@ -57,10 +58,17 @@ ThunkExecutor::ThunkExecutor(ThunkSequence thunk_sequence, // Erase redundant edges between nodes. 
int64_t num_erased_edges = TransitiveReduction(); + // Check if constructed execution DAG is sequential: every node depends on the + // completion of the previous node. + for (NodeId i = 1; i < nodes_defs_.size() && is_sequential_; ++i) { + is_sequential_ &= (absl::c_count(nodes_defs_[i].in_edges, i - 1) != 0); + } + VLOG(2) << absl::StreamFormat( "Constructed ThunkExecutor with %d nodes: #source_nodes=%d " - "#sink_nodes=%d, #erased_edges=%d", - nodes_defs_.size(), source_.size(), sink_.size(), num_erased_edges); + "#sink_nodes=%d, #erased_edges=%d, is_sequential=%v", + nodes_defs_.size(), source_.size(), sink_.size(), num_erased_edges, + is_sequential_); // Sanity check that all vectors are empty or all vectors are non-empty. DCHECK((!source_.empty() && !sink_.empty() && !thunk_sequence_.empty()) || @@ -123,6 +131,13 @@ tsl::AsyncValueRef ThunkExecutor::Execute( return thunk_sequence_[0]->Execute(params); } + // If thunk sequence dependencies form a sequential execution graph, we skip + // expensive async execution and simply run thunks one by one. + if (is_sequential_) { + return ExecuteSequential(params); + } + + // Create async execution state on heap and kick-off execution. auto state = std::make_unique(this, std::move(runner)); Execute(state.get(), params, ReadyQueue(source_.begin(), source_.end())); @@ -138,6 +153,70 @@ tsl::AsyncValueRef ThunkExecutor::Execute( return execute_event; } +tsl::AsyncValueRef +ThunkExecutor::ExecuteSequential(const Thunk::ExecuteParams& params) { + for (int64_t i = 0; i < thunk_sequence_.size(); ++i) { + Thunk& thunk = *thunk_sequence_[i]; + auto execute_event = thunk.Execute(params); + + // If thunk execution is not completed yet, attach a continuation to + // resume sequential execution starting from the next thunk. 
+ if (ABSL_PREDICT_FALSE(!execute_event.IsAvailable())) { + auto event = tsl::MakeConstructedAsyncValueRef(); + execute_event.AndThen([this, ¶ms, i, event](absl::Status status) { + if (ABSL_PREDICT_FALSE(!status.ok())) { + event.SetError(std::move(status)); + } else { + ResumeExecuteSequential(i + 1, params, std::move(event)); + } + }); + return event; + } + + // Abort execution if any of the thunks failed. + if (ABSL_PREDICT_FALSE(execute_event.IsError())) { + return execute_event; + } + } + + // If we got to the end of the sequence it means that all thunks have + // succeeded. + return Thunk::OkExecuteEvent(); +} + +void ThunkExecutor::ResumeExecuteSequential( + int64_t index, const Thunk::ExecuteParams& params, + tsl::AsyncValueRef event) { + for (int64_t i = index; i < thunk_sequence_.size(); ++i) { + Thunk& thunk = *thunk_sequence_[i]; + auto execute_event = thunk.Execute(params); + + // If thunk execution is not completed yet, attach a continuation to + // resume sequential execution starting from the next thunk. + if (ABSL_PREDICT_FALSE(!execute_event.IsAvailable())) { + execute_event.AndThen( + [this, ¶ms, i, event = std::move(event)](absl::Status status) { + if (ABSL_PREDICT_FALSE(!status.ok())) { + event.SetError(std::move(status)); + } else { + ResumeExecuteSequential(i + 1, params, std::move(event)); + } + }); + return; + } + + // Abort execution if any of the thunks failed. + if (ABSL_PREDICT_FALSE(execute_event.IsError())) { + event.SetError(execute_event.GetError()); + return; + } + } + + // If we got to the end of the sequence it means that all thunks have + // succeeded. 
+ event.SetStateConcrete(); +} + void ThunkExecutor::Execute(ExecuteState* state, const Thunk::ExecuteParams& params, ReadyQueue ready_queue) { diff --git a/third_party/xla/xla/service/cpu/runtime/thunk_executor.h b/third_party/xla/xla/service/cpu/runtime/thunk_executor.h index c16d140a91af12..e7fe89926dd192 100644 --- a/third_party/xla/xla/service/cpu/runtime/thunk_executor.h +++ b/third_party/xla/xla/service/cpu/runtime/thunk_executor.h @@ -85,6 +85,8 @@ class ThunkExecutor { std::string ToString() const; + bool is_sequential() const { return is_sequential_; } + private: using ReadyQueue = absl::InlinedVector; @@ -121,6 +123,15 @@ class ThunkExecutor { tsl::AsyncValueRef execute_event; }; + // Executes thunks sequentially starting from the first thunk in the sequence. + tsl::AsyncValueRef ExecuteSequential( + const Thunk::ExecuteParams& params); + + // Resumes sequential thunk execution starting from the given index. + void ResumeExecuteSequential(int64_t index, + const Thunk::ExecuteParams& params, + tsl::AsyncValueRef event); + // Executes nodes in the ready queue with given thunk parameters. void Execute(ExecuteState* state, const Thunk::ExecuteParams& params, ReadyQueue ready_queue); @@ -143,6 +154,11 @@ class ThunkExecutor { std::vector source_; std::vector sink_; + + // If NodeDef graph dependency structure is sequential and does not have any + // opportunities for executing thunks concurrently, we skip the expensive + // async execution and simply run thunks in the `thunk_sequence_` one by one. 
+ bool is_sequential_; }; } // namespace xla::cpu diff --git a/third_party/xla/xla/service/cpu/runtime/thunk_executor_test.cc b/third_party/xla/xla/service/cpu/runtime/thunk_executor_test.cc index 38bdd2ee4c5d8c..7ba4cad6a4dea4 100644 --- a/third_party/xla/xla/service/cpu/runtime/thunk_executor_test.cc +++ b/third_party/xla/xla/service/cpu/runtime/thunk_executor_test.cc @@ -63,13 +63,15 @@ using ::testing::ElementsAre; class AddI32Thunk final : public Thunk { public: AddI32Thunk(std::string name, std::vector srcs, - std::vector dsts, bool inject_error, - std::vector* trace); + std::vector dsts, + std::vector* trace, bool inject_error, + bool inject_side_effect); static std::unique_ptr Create( std::string name, std::vector srcs, - std::vector dsts, bool inject_error = false, - std::vector* trace = nullptr); + std::vector dsts, + std::vector* trace = nullptr, bool inject_error = false, + bool inject_side_effect = false); static std::vector AsDeviceMemory( absl::Span* const> data); @@ -86,16 +88,18 @@ class AddI32Thunk final : public Thunk { private: std::vector srcs_; std::vector dsts_; - bool inject_error_; std::vector* trace_; + bool inject_error_; + bool inject_side_effect_; }; std::unique_ptr AddI32Thunk::Create( std::string name, std::vector srcs, - std::vector dsts, bool inject_error, - std::vector* trace) { + std::vector dsts, std::vector* trace, + bool inject_error, bool inject_side_effect) { return std::make_unique(std::move(name), std::move(srcs), - std::move(dsts), inject_error, trace); + std::move(dsts), trace, inject_error, + inject_side_effect); } std::vector AddI32Thunk::AsDeviceMemory( @@ -111,12 +115,14 @@ std::vector AddI32Thunk::AsDeviceMemory( AddI32Thunk::AddI32Thunk(std::string name, std::vector srcs, std::vector dsts, - bool inject_error, std::vector* trace) + std::vector* trace, bool inject_error, + bool inject_side_effect) : Thunk(Kind::kKernel, Info{name}), srcs_(std::move(srcs)), dsts_(std::move(dsts)), + trace_(trace), 
inject_error_(inject_error), - trace_(trace) {} + inject_side_effect_(inject_side_effect) {} absl::Status AddI32Thunk::Execute(const BufferAllocations* allocations, BufferAllocation::Slice src_slice, @@ -178,10 +184,20 @@ AddI32Thunk::BufferUses AddI32Thunk::buffer_uses() const { BufferUses buffer_uses; for (const auto& src : srcs_) buffer_uses.push_back(BufferUse::Read(src)); for (const auto& dst : dsts_) buffer_uses.push_back(BufferUse::Write(dst)); + + // TODO(ezhulenev): Add proper side-effect support to Thunks. For now we just + // inject a write to a random slice of allocation 0 to emulate a side-effect + // and force all thunks to be executed sequentially. + if (inject_side_effect_) { + static auto* fake_alloc = new BufferAllocation(0, 1, 0); + buffer_uses.push_back( + BufferUse::Write(BufferAllocation::Slice(fake_alloc, 0, 1))); + } + return buffer_uses; } -TEST(ThunkExecutorTest, Ordering) { +TEST(ThunkExecutorTest, DependencyOrdering) { BufferAllocation alloc(/*index=*/0, /*size=*/80, /*color=*/0); BufferAllocation::Slice slice0(&alloc, /*offset=*/0, /*size=*/40); @@ -196,10 +212,28 @@ TEST(ThunkExecutorTest, Ordering) { TF_ASSERT_OK_AND_ASSIGN(ThunkExecutor executor, ThunkExecutor::Create(std::move(sequence))); + EXPECT_FALSE(executor.is_sequential()); EXPECT_THAT(executor.source(), ElementsAre(0, 1)); EXPECT_THAT(executor.sink(), ElementsAre(2)); } +TEST(ThunkExecutorTest, SequentialOrdering) { + BufferAllocation alloc(/*index=*/0, /*size=*/80, /*color=*/0); + BufferAllocation::Slice slice(&alloc, /*offset=*/0, /*size=*/40); + + ThunkSequence sequence; + sequence.push_back(AddI32Thunk::Create("a", {slice}, {slice})); + sequence.push_back(AddI32Thunk::Create("b", {slice}, {slice})); + sequence.push_back(AddI32Thunk::Create("c", {slice}, {slice})); + + TF_ASSERT_OK_AND_ASSIGN(ThunkExecutor executor, + ThunkExecutor::Create(std::move(sequence))); + + EXPECT_TRUE(executor.is_sequential()); + EXPECT_THAT(executor.source(), ElementsAre(0)); + 
EXPECT_THAT(executor.sink(), ElementsAre(2)); +} + TEST(ThunkExecutorTest, TransitiveReduction) { BufferAllocation alloc(/*index=*/0, /*size=*/80, /*color=*/0); BufferAllocation::Slice slice(&alloc, /*offset=*/0, /*size=*/40); @@ -231,12 +265,9 @@ TEST(ThunkExecutorTest, Execute) { std::vector trace; ThunkSequence sequence; - sequence.push_back(AddI32Thunk::Create("a", {slice0}, {slice0}, - /*inject_error=*/false, &trace)); - sequence.push_back(AddI32Thunk::Create("b", {slice1}, {slice1}, - /*inject_error=*/false, &trace)); - sequence.push_back(AddI32Thunk::Create("c", {slice2}, {slice2}, - /*inject_error=*/false, &trace)); + sequence.push_back(AddI32Thunk::Create("a", {slice0}, {slice0}, &trace)); + sequence.push_back(AddI32Thunk::Create("b", {slice1}, {slice1}, &trace)); + sequence.push_back(AddI32Thunk::Create("c", {slice2}, {slice2}, &trace)); TF_ASSERT_OK_AND_ASSIGN(ThunkExecutor executor, ThunkExecutor::Create(std::move(sequence))); @@ -281,7 +312,7 @@ struct GeneratedThunkSequence { static absl::StatusOr> GenerateThunkSequence(size_t num_elements, size_t num_thunks, - bool inject_errors = false) { + bool inject_errors, bool inject_side_effects) { auto g = std::make_unique(GeneratedThunkSequence{ BufferAllocation(/*index=*/0, num_elements * sizeof(int32_t), 0), BufferAllocation(/*index=*/1, num_elements * sizeof(int32_t), 0), @@ -316,8 +347,9 @@ GenerateThunkSequence(size_t num_elements, size_t num_thunks, TF_RETURN_IF_ERROR(AddI32Thunk::Execute(&allocations, src, dst)); bool inject_error = inject_errors && inject_error_dist(engine) == 0; - g->sequence.push_back( - AddI32Thunk::Create(absl::StrCat(i), {src}, {dst}, inject_error)); + g->sequence.push_back(AddI32Thunk::Create(absl::StrCat(i), {src}, {dst}, + /*trace=*/nullptr, inject_error, + inject_side_effects)); } return g; @@ -326,10 +358,12 @@ GenerateThunkSequence(size_t num_elements, size_t num_thunks, // Parameterized thunk executor stress tests that builds a random thunk sequence // and optionally uses 
a thread pool to execute thunk executor tasks. class ThunkExecutorStressTest - : public testing::TestWithParam> { + : public testing::TestWithParam< + std::tuple> { public: void SetUp() override { - auto& [_, use_task_runner, use_device, inject_errors] = GetParam(); + auto& [_, use_task_runner, use_device, inject_errors, inject_side_effects] = + GetParam(); use_task_runner_ = use_task_runner; use_device_ = use_device; @@ -366,11 +400,13 @@ class ThunkExecutorStressTest }; TEST_P(ThunkExecutorStressTest, Execute) { - auto [num_thunks, use_task_runner, use_device, inject_errors] = GetParam(); + auto [num_thunks, use_task_runner, use_device, inject_errors, + inject_side_effects] = GetParam(); TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr g, - GenerateThunkSequence(/*num_elements=*/1024, num_thunks, inject_errors)); + GenerateThunkSequence(/*num_elements=*/1024, num_thunks, inject_errors, + inject_side_effects)); TF_ASSERT_OK_AND_ASSIGN(ThunkExecutor executor, ThunkExecutor::Create(std::move(g->sequence))); @@ -393,7 +429,7 @@ TEST_P(ThunkExecutorStressTest, Execute) { INSTANTIATE_TEST_SUITE_P(ThunkExecutor, ThunkExecutorStressTest, testing::Combine(testing::ValuesIn({10, 100, 1000}), testing::Bool(), testing::Bool(), - testing::Bool())); + testing::Bool(), testing::Bool())); //===----------------------------------------------------------------------===// // Performance benchmarks below @@ -402,7 +438,10 @@ INSTANTIATE_TEST_SUITE_P(ThunkExecutor, ThunkExecutorStressTest, static void BM_SyncThunkExecutor(benchmark::State& state) { const size_t num_thunks = state.range(0); - auto g = GenerateThunkSequence(/*num_elements=*/1024, num_thunks).value(); + auto g = GenerateThunkSequence(/*num_elements=*/1024, num_thunks, + /*inject_errors=*/false, + /*inject_side_effects=*/false) + .value(); auto e = ThunkExecutor::Create(std::move(g->sequence)).value(); BufferAllocations allocations(g->buffers); @@ -422,7 +461,10 @@ static void BM_AsyncThunkExecutor(benchmark::State& state) { 
Eigen::ThreadPoolDevice device(thread_pool.AsEigenThreadPool(), thread_pool.NumThreads()); - auto g = GenerateThunkSequence(/*num_elements=*/1024, num_thunks).value(); + auto g = GenerateThunkSequence(/*num_elements=*/1024, num_thunks, + /*inject_errors=*/false, + /*inject_side_effects=*/false) + .value(); auto e = ThunkExecutor::Create(std::move(g->sequence)).value(); BufferAllocations allocations(g->buffers); diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index 0383ca23925e73..2a6e8de134c2c0 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -4797,7 +4797,7 @@ cc_library( "@local_config_cuda//cuda:cuda_headers", "@local_config_cuda//cuda:cudnn_header", ]) + if_static([ - "@com_google_protobuf//:protobuf", + "@com_google_protobuf//:wrappers_cc_proto", ]), ) diff --git a/third_party/xla/xla/service/gpu/fusions/triton.cc b/third_party/xla/xla/service/gpu/fusions/triton.cc index f31724b71fa728..49e3646510bfc6 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton.cc +++ b/third_party/xla/xla/service/gpu/fusions/triton.cc @@ -171,13 +171,9 @@ absl::StatusOr TritonFusion::Emit( auto launch_config = *this->launch_config(); launch_dimensions = launch_config.launch_dimensions; - // TODO(bchetioui): parse block-level parameters from backend config - // where available. 
BlockLevelParameters block_level_parameters; - block_level_parameters.output_tile_sizes = std::vector( - hlo_computation->root_instruction()->shape().rank() - 1, 1); - block_level_parameters.output_tile_sizes.push_back( - hlo_computation->root_instruction()->shape().dimensions().back()); + block_level_parameters.output_tile_sizes = + launch_config.output_tile_sizes; block_level_parameters.num_warps = launch_dimensions.num_threads_per_block() / WarpSize(); block_level_parameters.num_ctas = 1; @@ -283,6 +279,29 @@ absl::StatusOr TritonFusion::Emit( } std::optional TritonFusion::launch_config() const { + if (analysis_.fusion_backend_config().has_block_level_fusion_config()) { + BlockLevelParameters block_level_parameters = + BlockLevelParameters::FromBlockLevelFusionConfig( + analysis_.fusion_backend_config().block_level_fusion_config()); + + int64_t num_blocks = 1; + for (auto [dim_size, dim_tile_size] : + llvm::zip(analysis_.fusion_root(0).shape().dimensions(), + block_level_parameters.output_tile_sizes)) { + num_blocks *= (dim_size + dim_tile_size - 1) / dim_tile_size; + } + + LaunchConfig launch_config; + launch_config.launch_dimensions = LaunchDimensions( + static_cast(num_blocks), + static_cast(block_level_parameters.num_warps * WarpSize())); + launch_config.output_tile_sizes = block_level_parameters.output_tile_sizes; + return launch_config; + } + + // TODO(shyshkov): Remove the SoftMax heuristic once the block-level fusion + // config is fully rolled out. All tiles size should be set before reaching + // this point. 
if (analysis_.fusion_backend_config().kind() == kTritonFusionKind) { // TODO(b/332649307): Change the line below to something more generic that // can handle different instructions (not just Reduce) and different diff --git a/third_party/xla/xla/service/gpu/fusions/triton_test.cc b/third_party/xla/xla/service/gpu/fusions/triton_test.cc index f714d2df9ebf6e..63d12b816152b5 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/triton_test.cc @@ -23,7 +23,6 @@ limitations under the License. #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/device_description.pb.h" #include "xla/tests/hlo_test_base.h" -#include "tsl/platform/statusor.h" namespace xla { namespace gpu { @@ -33,9 +32,66 @@ using ::testing::ElementsAre; class TritonFusionTest : public HloTestBase {}; -TEST_F(TritonFusionTest, TritonSoftmaxFusion) { +TEST_F(TritonFusionTest, + TritonFusionWithBlockLevelFusionConfig_LaunchDimensionsAreCorrect) { #ifndef GOOGLE_CUDA - GTEST_SKIP() << "Triton fusion only enable for CUDA devices."; + GTEST_SKIP() << "Triton fusion only enabled for CUDA devices."; +#endif + + auto module = ParseAndReturnVerifiedModule(R"( + HloModule t + + add { + Arg_0 = f32[] parameter(0) + Arg_1 = f32[] parameter(1) + ROOT add = f32[] add(Arg_0, Arg_1) + } + + auxiliary_computation { + parameter_0 = f32[125]{0} parameter(0) + ROOT broadcast = f32[125,127]{1,0} broadcast(parameter_0), dimensions={0} + } + + triton_softmax_computation { + parameter_0 = f32[125,127]{1,0} parameter(0) + multiply_0 = f32[125,127]{1,0} multiply(parameter_0, parameter_0) + constant_0 = f32[] constant(0) + reduce_0 = f32[125]{0} reduce(multiply_0, constant_0), dimensions={1}, to_apply=add + broadcast_4 = f32[125,127]{1,0} broadcast(reduce_0), dimensions={0} + ROOT multiply = f32[125,127]{1,0} multiply(multiply_0, broadcast_4) + } + + ENTRY main { + param_0 = f32[125]{0} parameter(0) + auxiliary_fusion = f32[125,127]{1,0} 
fusion(param_0), kind=kLoop, calls=auxiliary_computation + ROOT triton_softmax = f32[125,127]{1,0} fusion(auxiliary_fusion), kind=kCustom, calls=triton_softmax_computation, backend_config={"fusion_backend_config":{"kind":"__triton","block_level_fusion_config":{"output_tile_sizes":["3","127"],"num_warps":"4"}}} + })") + .value(); + + stream_executor::GpuDeviceInfoProto device_info_proto; + stream_executor::DeviceDescription device_info(device_info_proto); + + auto* root = module->entry_computation()->root_instruction(); + auto analysis_fused = + AnalyzeProducerConsumerFusion(*root->operand(0), *root, device_info); + + auto emitter_fused = + GetFusionEmitter(PreBufferAssignmentFusionInfo{analysis_fused}); + auto triton_fusion = dynamic_cast(emitter_fused.get()); + ASSERT_NE(triton_fusion, nullptr); + auto launch_config = triton_fusion->launch_config(); + ASSERT_NE(launch_config, std::nullopt); + EXPECT_EQ(launch_config->launch_dimensions.num_blocks(), + /*ceil(125 / 3)=*/42); + EXPECT_EQ(launch_config->launch_dimensions.num_threads_per_block(), + /*32 * num_warps=*/128); + EXPECT_THAT(launch_config->output_tile_sizes, ElementsAre(3, 127)); +} + +TEST_F(TritonFusionTest, + TritonFusionWithoutBlockLevelFusionConfig_LaunchFromSoftMaxHeuristic) { +#ifndef GOOGLE_CUDA + GTEST_SKIP() << "Triton fusion only enabled for CUDA devices."; #endif auto module = ParseAndReturnVerifiedModule(R"( diff --git a/third_party/xla/xla/service/gpu/ir_emitter_triton.cc b/third_party/xla/xla/service/gpu/ir_emitter_triton.cc index 371850d884a608..48a723d605d402 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter_triton.cc +++ b/third_party/xla/xla/service/gpu/ir_emitter_triton.cc @@ -102,6 +102,7 @@ limitations under the License. 
#include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/utils/hlo_query.h" +#include "xla/layout_util.h" #include "xla/literal.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/mlir_hlo/mhlo/transforms/map_mhlo_to_scalar_op.h" @@ -2493,8 +2494,23 @@ MakeTensorPtrOpAndBoundaryChecks CreateMakeTensorPtrOp( llvm::SmallVector order; llvm::SmallVector boundary_checks; + const std::vector& tile_strides = tiled_hlo.tile_strides(); + const Shape& shape = tiled_hlo.hlo()->shape(); + + // Compute physical strides of the tile. `tile_strides` contains strides for + // individual dimensions. We need to convert them to strides in the buffer + // taking into account physical layout. + // TODO(b/331332678): Compute indexing maps to physical layout indexing in + // SymbolicTileAnalysis. + llvm::SmallVector physical_strides(tile_strides.size(), 1); + int64_t current_stride = 1; + for (int64_t cur_dim : LayoutUtil::MinorToMajor(shape)) { + physical_strides[cur_dim] = tile_strides[cur_dim] * current_stride; + current_stride *= shape.dimensions(cur_dim); + } + for (auto [size, stride] : - llvm::zip(tiled_hlo.tile_sizes(), tiled_hlo.tile_strides())) { + llvm::zip(tiled_hlo.tile_sizes(), physical_strides)) { if (size == 1) continue; int dimension_index = sizes.size(); diff --git a/third_party/xla/xla/service/gpu/ir_emitter_triton_mem_utils_test.cc b/third_party/xla/xla/service/gpu/ir_emitter_triton_mem_utils_test.cc index be0e64581c3352..01703672f66b16 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter_triton_mem_utils_test.cc +++ b/third_party/xla/xla/service/gpu/ir_emitter_triton_mem_utils_test.cc @@ -161,7 +161,7 @@ TEST_F(TritonMakeTensorPtrTest, BlockProperties) { EXPECT_THAT(ConstOpValuesToInt(ptr.op.getShape()), ElementsAre(3, 4)); EXPECT_THAT(TensorShape(ptr.op), ElementsAre(4, 4)); EXPECT_THAT(ptr.boundary_checks, ElementsAre(0)); - EXPECT_THAT(ConstOpValuesToInt(ptr.op.getStrides()), ElementsAre(1, 1)); + 
EXPECT_THAT(ConstOpValuesToInt(ptr.op.getStrides()), ElementsAre(20, 1)); EXPECT_THAT(ConstOpValuesToInt(ptr.op.getOffsets()), ElementsAre(0, 0)); EXPECT_THAT(ptr.op.getOrder(), ElementsAre(1, 0)); } @@ -170,7 +170,7 @@ TEST_F(TritonMakeTensorPtrTest, BlockProperties) { EXPECT_THAT(ConstOpValuesToInt(ptr.op.getShape()), ElementsAre(4, 4)); EXPECT_THAT(TensorShape(ptr.op), ElementsAre(4, 4)); EXPECT_TRUE(ptr.boundary_checks.empty()); - EXPECT_THAT(ConstOpValuesToInt(ptr.op.getStrides()), ElementsAre(1, 1)); + EXPECT_THAT(ConstOpValuesToInt(ptr.op.getStrides()), ElementsAre(20, 1)); EXPECT_THAT(ConstOpValuesToInt(ptr.op.getOffsets()), ElementsAre(0, 0)); EXPECT_THAT(ptr.op.getOrder(), ElementsAre(1, 0)); } @@ -199,7 +199,7 @@ TEST_F(TritonMakeTensorPtrTest, BlockProperties) { EXPECT_THAT(ConstOpValuesToInt(ptr.op.getShape()), ElementsAre(3, 4)); EXPECT_THAT(TensorShape(ptr.op), ElementsAre(4, 4)); EXPECT_THAT(ptr.boundary_checks, ElementsAre(0)); - EXPECT_THAT(ConstOpValuesToInt(ptr.op.getStrides()), ElementsAre(1, 1)); + EXPECT_THAT(ConstOpValuesToInt(ptr.op.getStrides()), ElementsAre(20, 1)); EXPECT_THAT(ConstOpValuesToInt(ptr.op.getOffsets()), ElementsAre(0, 0)); EXPECT_THAT(ptr.op.getOrder(), ElementsAre(1, 0)); } @@ -212,7 +212,8 @@ TEST_F(TritonMakeTensorPtrTest, BlockProperties) { EXPECT_THAT(ConstOpValuesToInt(ptr.op.getShape()), ElementsAre(3, 4, 6)); EXPECT_THAT(TensorShape(ptr.op), ElementsAre(4, 4, 8)); EXPECT_THAT(ptr.boundary_checks, ElementsAre(0, 2)); - EXPECT_THAT(ConstOpValuesToInt(ptr.op.getStrides()), ElementsAre(1, 1, 1)); + EXPECT_THAT(ConstOpValuesToInt(ptr.op.getStrides()), + ElementsAre(3000, 150, 1)); EXPECT_THAT(ConstOpValuesToInt(ptr.op.getOffsets()), ElementsAre(0, 0, 0)); EXPECT_THAT(ptr.op.getOrder(), ElementsAre(2, 1, 0)); } diff --git a/third_party/xla/xla/service/gpu/ir_emitter_triton_test.cc b/third_party/xla/xla/service/gpu/ir_emitter_triton_test.cc index 224dc94a1b3a2b..45a04dbb82d261 100644 --- 
a/third_party/xla/xla/service/gpu/ir_emitter_triton_test.cc +++ b/third_party/xla/xla/service/gpu/ir_emitter_triton_test.cc @@ -5668,6 +5668,60 @@ CHECK: } )")); } +TEST_F(TritonTest, TestSoftMaxWithTileElementsNotAllContiguous) { + const std::string kHloText = R"( +HloModule m + +region { + param_0 = f32[] parameter(0) + param_1 = f32[] parameter(1) + ROOT add.1 = f32[] add(param_0, param_1) +} + +triton_softmax_computation { + constant.1 = f32[] constant(0) + broadcast.2 = f32[4,4,8] broadcast(constant.1), dimensions={} + param_0.1 = f32[4,4,8] parameter(0) + constant = f32[] constant(0) + reduce = f32[4,4] reduce(param_0.1, constant), dimensions={2}, to_apply=region + broadcast = f32[4,4,8] broadcast(reduce), dimensions={0,1} + multiply = f32[4,4,8] multiply(broadcast.2, broadcast) + ROOT add.2 = f32[4,4,8] add(multiply, broadcast) +} + +ENTRY entry_computation { + param_0.2 = f32[4,4,8] parameter(0) + ROOT fusion = f32[4,4,8] fusion(param_0.2), kind=kCustom, calls=triton_softmax_computation, backend_config={"fusion_backend_config": {"kind":"__triton","block_level_fusion_config":{"output_tile_sizes":["2","2","8"],"num_warps":"1"}}} +})"; + EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-6, + /*arel=*/1e-6})); +} + +TEST_F(TritonTest, TestSliceWithTileElementsNotAllContiguous) { + const std::string kHloText = R"( +HloModule m + +region { + param_0 = f32[] parameter(0) + param_1 = f32[] parameter(1) + ROOT add.2 = f32[] add(param_0, param_1) +} + +fused_computation { + param_0.1 = f32[16,16,32] parameter(0) + slice = f32[4,4,8] slice(param_0.1), slice={[2:10:2], [2:6], [3:11]} + slice.1 = f32[4,4,8] slice(param_0.1), slice={[4:8], [8:16:2], [13:21]} + ROOT add.3 = f32[4,4,8] add(slice, slice.1) +} + +ENTRY entry_computation { + param_0.2 = f32[16,16,32] parameter(0) + ROOT fusion = f32[4,4,8] fusion(param_0.2), kind=kCustom, calls=fused_computation, backend_config={"fusion_backend_config": 
{"kind":"__triton","block_level_fusion_config":{"output_tile_sizes":["2","2","8"],"num_warps":"1"}}} +})"; + EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-6, + /*arel=*/1e-6})); +} + } // namespace } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/hlo_unstacker.cc b/third_party/xla/xla/service/hlo_unstacker.cc index 2f647b914da8fb..7702f9d578d612 100644 --- a/third_party/xla/xla/service/hlo_unstacker.cc +++ b/third_party/xla/xla/service/hlo_unstacker.cc @@ -15,6 +15,7 @@ limitations under the License. #include "xla/service/hlo_unstacker.h" +#include #include #include #include @@ -46,7 +47,7 @@ limitations under the License. namespace xla { namespace { -// TODO(b/342457472): Remove this struct and move its field to the +// TODO: b/342457472 - Remove this struct and move its field to the // UnstackerTransformer as static members. A struct that holds the required // information for unstacking that is fixed across different unstacker // instastances. @@ -63,96 +64,141 @@ struct UnstackerMetadata { WhileLoopUnroller::GetUnrollableLoops(module, {}); for (const auto& [instr, while_loop_config] : loops) { metadata.unrollable_loop_bodies[instr->while_body()] = while_loop_config; + metadata.bodies[instr->while_body()] = instr; } return metadata; } absl::flat_hash_map unrollable_loop_bodies; - // A pair of custom pattern and its handler lambda that describes the - // transformation needed to unstack the hlo graph for the pattern. - std::pair, - std::function> - custom_handler; + absl::flat_hash_map bodies; + // Vector containing pairs of custom patterns and their corresponding handler + // lambdas. The patterns are checked in the order in which they are inserted + // into this vector. + std::vector< + std::pair, + std::function>> + custom_handlers; }; -// A struct that holds the required information for two-step unstacking. The -// content of each instance differs for each operand of a while loop. 
-struct UnstackerTransformer { - UnstackerMetadata metadata; - static absl::StatusOr Create( - const UnstackerMetadata& c) { - UnstackerTransformer transformer; - transformer.metadata = std::move(c); - return transformer; - } +// Performs the two-step unstacking. Each instance of this class is responsible +// for a single operand of a while loop. +class UnstackerTransformer { + public: + // Default unroll_factor of -1 indicates full unrolling + explicit UnstackerTransformer(const UnstackerMetadata& metadata) + : metadata_(metadata) {} // Given an instruction and the index of the its changed operand, it applies // the custom handler and populates body_changes lambdas that unstacks the hlo // graph accordingly. bool HandleInstruction(const HloInstruction* instr, int64_t changed_idx) { + // Currently, we only unstack operands that are used within fusion + // computations. + if (instr->opcode() != HloOpcode::kFusion) { + return false; + } VLOG(3) << "HandleInstruction(" << instr->shape().ToString() << instr->name() << ", " << changed_idx << ")"; - auto custom_pattern = metadata.custom_handler.first; - auto custom_handler = metadata.custom_handler.second; + for (const auto& [custom_pattern, custom_handler] : + metadata_.custom_handlers) { + const HloInstruction* stacked_user = + custom_pattern(metadata_, instr, changed_idx); + // Try the next pattern if current pattern is not found. + if (stacked_user == nullptr) { + continue; + } + if (unstacking_computation_ != nullptr) { + VLOG(3) << "Seen multiple users, cannot handle. \n instr: " + << instr->ToString() << "\n hoisted_computation: " + << unstacking_computation_->ToString( + HloPrintOptions::Fingerprint()); + return false; + } - const HloInstruction* stacked_user = - custom_pattern(metadata, instr, changed_idx); - if (stacked_user == nullptr) { - return false; - } - if (unstacking_computation != nullptr) { - LOG(ERROR) << "Seen multiple users, cannot handle. 
\n instr: " - << instr->ToString() << "\n hoisted_computation: " - << unstacking_computation->ToString( - HloPrintOptions::Fingerprint()); - return false; + unstacking_computation_ = + stacked_user->fused_instructions_computation()->Clone( + "hoisted_unstacking"); + VLOG(3) << "Unstacking computation: " + << unstacking_computation_->ToString( + HloPrintOptions::Fingerprint()); + + // TODO: b/342440749 - Currently, we assume the stacked dimension is + // always the most major dimension. This condition can be checked and + // terminate unstacking if not met. + Shape slice_shape = stacked_user->shape(); + int64_t num_layers = stacked_user->operand(0)->shape().dimensions(0); + std::vector shapes; + for (int64_t i = 0; i < num_layers; ++i) { + shapes.push_back(slice_shape); + } + unstacked_shape_ = + std::make_unique(ShapeUtil::MakeTupleShape(shapes)); + + unstacked_instrs_.push_back(instr); + + // Wrapper function around the unstacker lambda which calls the unstacker. + std::function unstack_wrapper = + [&custom_handler = custom_handler, stacked_user, + slice_shape]() mutable -> absl::Status { + HloInstruction* mutable_dynamic_slicing_fusion = + const_cast(stacked_user); + return custom_handler(mutable_dynamic_slicing_fusion, slice_shape); + }; + body_changes_.push_back(unstack_wrapper); + return true; } + return false; + } - unstacking_computation = - stacked_user->fused_instructions_computation()->Clone( - "hoisted_unstacking"); - VLOG(3) << "Unstacking computation: " - << unstacking_computation->ToString(HloPrintOptions::Fingerprint()); - - // TODO(b/342440749): Currently, we assume the stacked dimension is always - // the most major dimension. This condition can be checked and terminate - // unstacking if not met. 
- Shape slice_shape = stacked_user->shape(); - int64_t num_layers = stacked_user->operand(0)->shape().dimensions(0); - std::vector shapes; - for (int64_t i = 0; i < num_layers; ++i) { - shapes.push_back(slice_shape); - } - unstacked_shape = - std::make_unique(ShapeUtil::MakeTupleShape(shapes)); - - // Wrapper function around the unstacker lambda which calls the unstacker. - std::function unstack_wrapper = - [=]() mutable -> absl::Status { - HloInstruction* mutable_dynamic_slicing_fusion = - const_cast(stacked_user); - return custom_handler(mutable_dynamic_slicing_fusion, slice_shape); - }; - body_changes.push_back(unstack_wrapper); - return true; + std::vector& GetUnstackedInstructions() { + return unstacked_instrs_; } + const Shape* GetUnstackedShape() const { return unstacked_shape_.get(); } + + // The function returns a mutable pointer to the unstacking computation since + // the pointer is later used to clone the computation. + HloComputation* GetUnstackingComputation() const { + return unstacking_computation_.get(); + } + + std::vector>& GetLoopChanges() { + return loop_changes_; + } + + std::vector>& GetBodyChanges() { + return body_changes_; + } + + absl::flat_hash_map& GetOperandChanges() { + return operand_changes_; + } + + void AddLoopChange(std::function loop_change) { + loop_changes_.push_back(loop_change); + } + + private: + const UnstackerMetadata& metadata_; // This pointer is populated if the unstacker finds unstackable loop input. - std::unique_ptr unstacked_shape = nullptr; + std::unique_ptr unstacked_shape_ = nullptr; // This is a pointer to the computation that is responsible for unstacking. It // is used to hoist the unstacking computations outside the loop bodies. - std::unique_ptr unstacking_computation = nullptr; + std::unique_ptr unstacking_computation_ = nullptr; // A vector of lambdas that describe necessary changes to the shape of the // loops to unstack. The lambdas accept the pointer to the new unstacked // shape. 
- std::vector> loop_changes; + std::vector> loop_changes_; // a list of lambdas that captures all the changes to the hlo graph needed for // unstacking. - std::vector> body_changes; + std::vector> body_changes_; // A map that tracks the index of the changed operand for instructions of type // get-tuple-element, tuple, and while during unstacking. - absl::flat_hash_map operand_changes; + absl::flat_hash_map operand_changes_; + // Holds the list of unstacked instructions that will be used to identify + // loops that need to be unrolled. + std::vector unstacked_instrs_; }; bool CanUnstackWhileOperand(const HloInstruction* while_instr, @@ -169,12 +215,12 @@ bool PropagateGteShapeChange(HloInstruction* gte, UnstackerTransformer& unstacker) { VLOG(5) << "PropagateGteShapeChange(" << gte->ToString() << ")"; - // TODO(b/343457903): Use HloDataflowAnalysis to track the usage of a value + // TODO: b/343457903 - Use HloDataflowAnalysis to track the usage of a value // instead of manually applying bfs // // Apply BFS to propagate the index of the changed operand. 
absl::flat_hash_map& visited = - unstacker.operand_changes; + unstacker.GetOperandChanges(); std::deque worklist; worklist.push_back(gte); visited.insert({gte, gte->tuple_index()}); @@ -283,11 +329,12 @@ bool CanUnstackWhileOperand(const HloInstruction* while_instr, loop->while_condition()->ReplaceParameter( 0, HloInstruction::CreateParameter(0, old_shape, "unstacked")); }; - auto loop_change_wrapper = [=](const Shape* new_shape) { + auto loop_change_wrapper = [&loop_change, while_instr, + index](const Shape* new_shape) { HloInstruction* mutable_loop = const_cast(while_instr); loop_change(mutable_loop, new_shape, index); }; - unstacker.loop_changes.push_back(loop_change_wrapper); + unstacker.AddLoopChange(loop_change_wrapper); return true; } return false; @@ -303,7 +350,7 @@ void UnstackWhileInput(const UnstackerTransformer& unstacker, HloInstruction* old_while_input = while_instr->while_init()->mutable_operand(index); - // TODO(b/341815540): Instead of creating the unstacked tuple for every input + // TODO: b/341815540 - Instead of creating the unstacked tuple for every input // index, we should reuse if the input and unstacking computations are the // same. 
// @@ -312,15 +359,16 @@ void UnstackWhileInput(const UnstackerTransformer& unstacker, std::vector slices; for (int64_t i = 0; i < new_shape->tuple_shapes_size(); ++i) { std::vector operands = { - old_while_input, - while_instr->AddInstruction(MakeConstantWithShape( - unstacker.unstacking_computation->parameter_instruction(1)->shape(), - i))}; + old_while_input, while_instr->AddInstruction(MakeConstantWithShape( + unstacker.GetUnstackingComputation() + ->parameter_instruction(1) + ->shape(), + i))}; HloInstruction* slice = while_instr->AddInstruction(HloInstruction::CreateFusion( slice_shape, HloInstruction::FusionKind::kLoop, operands, while_instr->GetModule()->AddEmbeddedComputation( - unstacker.unstacking_computation->Clone()), + unstacker.GetUnstackingComputation()->Clone()), "hoisted")); slices.push_back(slice); } @@ -335,10 +383,12 @@ void UnstackWhileInput(const UnstackerTransformer& unstacker, // Apply the two-step unstacking algorithm to the given while_instr at the given // index. -bool UnstackWhileOperandAtIndex(const UnstackerMetadata& metadata, - HloInstruction* while_instr, int64_t index) { - UnstackerTransformer unstacker = - UnstackerTransformer::Create(metadata).value(); +bool UnstackWhileOperandAtIndex( + const UnstackerMetadata& metadata, HloInstruction* while_instr, + int64_t index, std::vector& unstacked_instructions) { + // UnstackerTransformer unstacker = + // UnstackerTransformer::Create(metadata).value(); + UnstackerTransformer unstacker = UnstackerTransformer(metadata); // First step of unstacking to determine whether while_instr at index is // unstackable. @@ -357,7 +407,7 @@ bool UnstackWhileOperandAtIndex(const UnstackerMetadata& metadata, // If unstacker has not found an unstackable shape, there is no point in // applying the unstacker changes. 
- if (unstacker.unstacked_shape == nullptr) { + if (unstacker.GetUnstackedShape() == nullptr) { return false; } @@ -366,17 +416,17 @@ bool UnstackWhileOperandAtIndex(const UnstackerMetadata& metadata, // // Update the shape of get-tuple-element, tuple, and, while instructions // based on the unstacked_shape and the index of the changed operand. - for (const auto& [instr, index] : unstacker.operand_changes) { + for (const auto& [instr, index] : unstacker.GetOperandChanges()) { switch (instr->opcode()) { case HloOpcode::kGetTupleElement: - *instr->mutable_shape() = *unstacker.unstacked_shape; + *instr->mutable_shape() = *unstacker.GetUnstackedShape(); break; case HloOpcode::kTuple: *instr->mutable_shape()->mutable_tuple_shapes(index) = - *unstacker.unstacked_shape; + *unstacker.GetUnstackedShape(); break; case HloOpcode::kWhile: - ShapeUtil::UpdateTupleShape(*unstacker.unstacked_shape, index, + ShapeUtil::UpdateTupleShape(*unstacker.GetUnstackedShape(), index, instr->mutable_shape()); break; default: @@ -384,22 +434,87 @@ bool UnstackWhileOperandAtIndex(const UnstackerMetadata& metadata, } } // Apply the changes to the body according to the provided custom handler. - for (const auto& body_change : unstacker.body_changes) { + for (const auto& body_change : unstacker.GetBodyChanges()) { CHECK_OK(body_change()); } // Update the input and output shape of the loop. - UnstackWhileInput(unstacker, while_instr, unstacker.unstacked_shape.get(), + UnstackWhileInput(unstacker, while_instr, unstacker.GetUnstackedShape(), index); const Shape& new_while_shape = while_instr->while_init()->shape(); *while_instr->mutable_shape() = new_while_shape; // Apply the changes to the shape of the loop body and condition // computations. 
- for (auto& loop_change : unstacker.loop_changes) { - loop_change(unstacker.unstacked_shape.get()); + for (auto& loop_change : unstacker.GetLoopChanges()) { + loop_change(unstacker.GetUnstackedShape()); + } + for (const HloInstruction* instr : unstacker.GetUnstackedInstructions()) { + unstacked_instructions.push_back(instr); } return true; } +// This function recognizes fusions with the following pattern: +// fusion(stacked, loop_iteration_var) +// computation { +// p0 = parameter(0) +// p1 = parameter(1) +// slice = dynamic_slice(p0, p1, zero, ...) +// ROOT bitcast = bitcast(slice) +// } +const HloInstruction* IsDynamicSlicingFusion(const UnstackerMetadata& metadata, + const HloInstruction* instr, + int64_t stacked_operand_idx) { + CHECK_EQ(instr->opcode(), HloOpcode::kFusion); + if (instr->fused_parameters().size() != 2) { + return nullptr; + } + if (!metadata.unrollable_loop_bodies.contains(instr->parent())) { + VLOG(5) << "Instruction not inside unrollable while body, " + << instr->ToString() << instr->parent()->ToString(); + return nullptr; + } + + WhileLoopConfig while_instr_config = + metadata.unrollable_loop_bodies.at(instr->parent()); + + for (HloInstruction* fused_instr : + instr->fused_instructions_computation()->MakeInstructionPostOrder()) { + if (!Match(fused_instr, match::DynamicSlice())) { + continue; + } + std::optional dynamic_index = + MatchShapeCoveringDynamicIndexInstruction( + fused_instr, + instr->fused_instructions_computation()->parameter_instruction( + stacked_operand_idx), + HloOpcode::kDynamicSlice, while_instr_config); + if (dynamic_index.has_value() && dynamic_index.value() == 0) { + HloInstruction* bitcast_operand = nullptr; + if (Match(instr->fused_instructions_computation()->root_instruction(), + match::Bitcast(match::Op(&bitcast_operand)))) { + if (bitcast_operand == fused_instr) { + return instr; + } + } + } + } + return nullptr; +} + +absl::Status UnstackDynamicSlicingFusion( + HloInstruction* mutable_dynamic_slicing_fusion, 
const Shape& slice_shape) { + HloComputation* parent_loop = mutable_dynamic_slicing_fusion->parent(); + + HloInstruction* stacked = mutable_dynamic_slicing_fusion->mutable_operand(0); + HloInstruction* offset = mutable_dynamic_slicing_fusion->mutable_operand(1); + + HloInstruction* new_operand = + parent_loop->AddInstruction(HloInstruction::CreateCustomCall( + slice_shape, {stacked, offset}, "DynamicGte")); + return mutable_dynamic_slicing_fusion->ReplaceAllUsesWithDifferentShape( + new_operand); +} + // This method checks if the given instruction is a fusion with the following // properties: // 1. It is inside the body of an unrollable loop @@ -413,9 +528,7 @@ bool UnstackWhileOperandAtIndex(const UnstackerMetadata& metadata, const HloInstruction* GetNestedDynamicSlicingFusion( const UnstackerMetadata& metadata, const HloInstruction* instr, int64_t stacked_operand_idx) { - if (!Match(instr, match::Fusion())) { - return nullptr; - } + CHECK_EQ(instr->opcode(), HloOpcode::kFusion); if (!metadata.unrollable_loop_bodies.contains(instr->parent())) { VLOG(5) << "Instruction not inside unrollable while body, " @@ -536,13 +649,13 @@ absl::StatusOr HloUnstacker::Run( const absl::flat_hash_set& execution_threads) { TF_ASSIGN_OR_RETURN(auto metadata, UnstackerMetadata::Create(module)); - // Custom handler is a pair of pattern and transformation function that - // captures different cases of unstacking. It is decoupled from the unstacking - // algorithm for modularity. 
- metadata.custom_handler = std::make_pair(GetNestedDynamicSlicingFusion, - UnstackNestedDynamicSlicingFusion); + metadata.custom_handlers.push_back( + std::make_pair(IsDynamicSlicingFusion, UnstackDynamicSlicingFusion)); + metadata.custom_handlers.push_back(std::make_pair( + GetNestedDynamicSlicingFusion, UnstackNestedDynamicSlicingFusion)); bool unstacked = false; + std::vector unstacked_instructions; for (HloInstruction* instr : module->entry_computation()->MakeInstructionPostOrder()) { if (instr->opcode() != HloOpcode::kWhile) { @@ -552,7 +665,8 @@ absl::StatusOr HloUnstacker::Run( VLOG(3) << "Attempting to unstack " << instr->name() << " at " << i << " with stacked shape " << instr->shape().tuple_shapes(i).ToString(); - if (UnstackWhileOperandAtIndex(metadata, instr, i)) { + if (UnstackWhileOperandAtIndex(metadata, instr, i, + unstacked_instructions)) { VLOG(3) << "Unstacked " << instr->name() << " at " << i << " with stacked shape " << instr->shape().tuple_shapes(i).ToString(); @@ -566,7 +680,19 @@ absl::StatusOr HloUnstacker::Run( TF_RETURN_IF_ERROR(module->RemoveUnusedComputations()); // We rely on the WhileLoopUnroller pass to unroll loop bodies and rewrite // custom-calls created by unstacker, i.e., DynamicGte and DynamicTuple. 
- TF_RETURN_IF_ERROR(WhileLoopUnroller(-1, true).Run(module).status()); + std::vector loops_to_unroll; + for (const HloInstruction* instr : unstacked_instructions) { + HloInstruction* loop = metadata.bodies[instr->parent()]; + if (std::find(loops_to_unroll.begin(), loops_to_unroll.end(), loop) == + loops_to_unroll.end()) { + loops_to_unroll.push_back(loop); + } + } + for (HloInstruction* loop : loops_to_unroll) { + TF_ASSIGN_OR_RETURN(bool unrolled, + WhileLoopUnroller::Unroll(loop, -1, true, true)); + CHECK(unrolled); + } } return unstacked; } diff --git a/third_party/xla/xla/service/hlo_unstacker_test.cc b/third_party/xla/xla/service/hlo_unstacker_test.cc index 707f442077592e..3ad1bfea95dda4 100644 --- a/third_party/xla/xla/service/hlo_unstacker_test.cc +++ b/third_party/xla/xla/service/hlo_unstacker_test.cc @@ -21,6 +21,7 @@ limitations under the License. #include #include +#include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/tests/hlo_test_base.h" #include "tsl/platform/statusor.h" @@ -30,6 +31,62 @@ namespace { using UnstackerTest = HloTestBase; +TEST_F(UnstackerTest, UnstackLoopSingleFusionUser) { + std::string hlo_string = R"( + HloModule SimpleLoop + %fused_computation.slice (param_0.51117: s8[3,128,128], p1: s32[]) -> + s8[128,128] { + %param_0.51117 = s8[3,128,128] parameter(0) + p1 = s32[] parameter(1) + %constant.85694 = s32[] constant(0) + %dynamic-slice.22040 = s8[1,128,128] dynamic-slice(s8[3,128,128] + %param_0.51117, p1, s32[] %constant.85694, s32[] %constant.85694), + dynamic_slice_sizes={1,128,128} ROOT %bitcast.31250 = s8[128,128] + bitcast(s8[1,128,128] %dynamic-slice.22040) + } + + %while.body (wide_param: (s32[], bf16[8,128], s8[3,128,128])) -> (s32[], + bf16[8,128], s8[3,128,128]) { + wide_p = (s32[], bf16[8,128], s8[3,128,128]) parameter(0) + i = s32[] get-tuple-element(wide_p), index=0 + p0 = bf16[8,128] get-tuple-element(wide_p), index=1 + p1 = s8[3,128,128] get-tuple-element(wide_p), index=2 + one 
= s32[] constant(1) + inc = s32[] add(i, one) + %fusion.67830 = s8[128,128] fusion(s8[3,128,128] p1, i), kind=kLoop, + calls=%fused_computation.slice conv = bf16[8,128] convolution(bf16[8,128] + p0, s8[128,128] %fusion.67830), dim_labels=bf_io->bf ROOT out = (s32[], + bf16[8,128], s8[3,128,128]) tuple(inc, conv, p1) + } + + %while.cond (wide_param: (s32[], bf16[8,128], s8[3,128,128])) -> pred[] { + wide_p = (s32[], bf16[8,128], s8[3,128,128]) parameter(0) + i = s32[] get-tuple-element(wide_p), index=0 + %constant.12857 = s32[] constant(3) + ROOT %compare.1921 = pred[]{:T(512)} compare(s32[] i, s32[] + %constant.12857), direction=LT + } + + ENTRY main { + p0 = s8[3,128,128] parameter(0) + p1 = bf16[8,128] parameter(1) + init = s32[] constant(0) + while.input = (s32[], bf16[8,128], s8[3,128,128]) tuple(init, p1, p0) + while.out = (s32[], bf16[8,128], s8[3,128,128]) while(while.input), + condition=%while.cond , body=%while.body while_use = s8[3,128,128] + get-tuple-element(while.out), index=2 ROOT out = bf16[8,128] + get-tuple-element(while.out), index=1 + } + )"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + auto original = module->Clone(); + TF_ASSERT_OK_AND_ASSIGN(bool unstacked, HloUnstacker().Run(module.get())); + EXPECT_TRUE(unstacked); + EXPECT_TRUE(RunAndCompareTwoModules(std::move(module), std::move(original), + std::nullopt)); +} + TEST_F(UnstackerTest, UnstackLoopSingleNestedFusionUser) { std::string hlo_string = R"( HloModule SimpleLoop @@ -412,7 +469,7 @@ TEST_F(UnstackerTest, UnstackMultipleLoops) { while.input = (s32[], bf16[8,128], s8[4,128,128]) tuple(init, p1, weight) while.out = (s32[], bf16[8,128], s8[4,128,128]) while(while.input), condition=%while.cond1 , body=%while.body1 second.while.input = (s32[], bf16[8,128], s8[4,128,128]) tuple(init, p1, weight) - second.while.output = (s32[], bf16[8,128], s8[4,128,128]) while(second.while.input), condition=%while.cond2 , body=%while.body2 + 
second.while.out = (s32[], bf16[8,128], s8[4,128,128]) while(second.while.input), condition=%while.cond2 , body=%while.body2 ROOT out = bf16[8,128] get-tuple-element(while.out), index=1 } )"; diff --git a/third_party/xla/xla/status.h b/third_party/xla/xla/status.h index 9ed24549214eed..2c91a61b4ed048 100644 --- a/third_party/xla/xla/status.h +++ b/third_party/xla/xla/status.h @@ -16,6 +16,14 @@ limitations under the License. #ifndef XLA_STATUS_H_ #define XLA_STATUS_H_ +#include "absl/status/status.h" + // This is an obsolete header. Please use absl/status/status.h instead. +namespace xla { +// NOLINTBEGIN(misc-unused-using-decls) +using absl::OkStatus; +using absl::Status; +// NOLINTEND(misc-unused-using-decls) +} // namespace xla #endif // XLA_STATUS_H_ diff --git a/third_party/xla/xla/stream_executor/BUILD b/third_party/xla/xla/stream_executor/BUILD index bc5bf073667f94..f46d705caa82ea 100644 --- a/third_party/xla/xla/stream_executor/BUILD +++ b/third_party/xla/xla/stream_executor/BUILD @@ -4,7 +4,7 @@ load("@local_tsl//tsl/platform:build_config_root.bzl", "if_static") load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") load("//xla:xla.bzl", "xla_cc_test") load("//xla/stream_executor:build_defs.bzl", "stream_executor_build_defs_bzl_deps", "stream_executor_friends", "stream_executor_internal") -load("//xla/tsl:tsl.bzl", "internal_visibility") +load("//xla/tsl:tsl.bzl", "if_google", "internal_visibility") load("//xla/tsl:tsl.default.bzl", "filegroup") package( @@ -137,7 +137,8 @@ cc_library( "@local_tsl//tsl/protobuf:dnn_proto_cc", ] + if_static([ ":stream_executor_impl", - "@com_google_protobuf//:protobuf", # indirectly-used by dnn.h + ]) + if_google([ + "@com_google_protobuf//:wrappers_cc_proto", # indirectly-used by dnn.h ]), ) @@ -225,7 +226,7 @@ cc_library( "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/types:span", - ] + if_static(["@com_google_protobuf//:protobuf"]), + ] + 
if_google(["@com_google_protobuf//:wrappers_cc_proto"]), ) cc_library( @@ -398,7 +399,7 @@ cc_library( "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/protobuf:dnn_proto_cc", - ] + if_static(["@com_google_protobuf//:protobuf"]), + ] + if_google(["@com_google_protobuf//:wrappers_cc_proto"]), ) cc_library( diff --git a/third_party/xla/xla/tests/BUILD b/third_party/xla/xla/tests/BUILD index 333c103e222810..7426ca4a1f5b41 100644 --- a/third_party/xla/xla/tests/BUILD +++ b/third_party/xla/xla/tests/BUILD @@ -1718,6 +1718,7 @@ xla_test( "nomac", # b/194731834 "nozapfhahn", "optonly", + "test_xla_cpu_thunks", ], deps = [ ":client_library_test_base", @@ -2227,12 +2228,15 @@ xla_test( deps = [ ":hlo_test_base", ":xla_internal_test_main", + "//xla:error_spec", + "//xla:util", "//xla/hlo/ir:hlo", "//xla/service:collective_pipeliner", "//xla/service:hlo_dce", "//xla/service:hlo_parser", "//xla/service:hlo_pass_pipeline", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest_main", ], ) diff --git a/third_party/xla/xla/tests/collective_pipeliner_execution_test.cc b/third_party/xla/xla/tests/collective_pipeliner_execution_test.cc index bc54b0817bc99f..d06981fafa7267 100644 --- a/third_party/xla/xla/tests/collective_pipeliner_execution_test.cc +++ b/third_party/xla/xla/tests/collective_pipeliner_execution_test.cc @@ -20,14 +20,18 @@ limitations under the License. 
#include #include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/error_spec.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/collective_pipeliner.h" #include "xla/service/hlo_dce.h" #include "xla/service/hlo_parser.h" #include "xla/service/hlo_pass_pipeline.h" #include "xla/tests/hlo_test_base.h" +#include "xla/util.h" namespace xla { namespace { @@ -125,6 +129,71 @@ ENTRY entry { ErrorSpec{0.1, 0.1})); } +TEST_F(CollectivePipelinerExecutionTest, TransformIncrementIndexByOneNoReuse) { + constexpr absl::string_view hlo_string = R"( +HloModule module + +add { + lhs = bf16[] parameter(0) + rhs = bf16[] parameter(1) + ROOT add = bf16[] add(lhs, rhs) +} + +while_cond { + param = (s32[], bf16[3,8,128], bf16[3,8,128]) parameter(0) + gte = s32[] get-tuple-element(param), index=0 + constant.1 = s32[] constant(3) + ROOT cmp = pred[] compare(gte, constant.1), direction=LT +} + +while_body { + param = (s32[], bf16[3,8,128], bf16[3,8,128]) parameter(0) + get-tuple-element.394 = s32[] get-tuple-element(param), index=0 + get-tuple-element.395 = bf16[3,8,128] get-tuple-element(param), index=1 + get-tuple-element.5 = bf16[3,8,128] get-tuple-element(param), index=2 + constant.2557 = s32[] constant(1) + add.230 = s32[] add(get-tuple-element.394, constant.2557) + constant.2559 = s32[] constant(3) + subtract.139 = s32[] subtract(constant.2559, get-tuple-element.394) + constant.2560 = s32[] constant(-1) + add.231 = s32[] add(subtract.139, constant.2560) + constant.2561 = s32[] constant(0) + compare.747 = pred[] compare(add.231, constant.2561), direction=LT + constant.2562 = s32[] constant(2) + add.232 = s32[] add(subtract.139, constant.2562) + select.1348 = s32[] select(compare.747, add.232, add.231) + dynamic-slice.99 = bf16[1,8,128] dynamic-slice(get-tuple-element.5, select.1348, constant.2561, constant.2561), 
dynamic_slice_sizes={1,8,128} + mul = bf16[1,8,128] multiply(dynamic-slice.99, dynamic-slice.99) + ar.1 = bf16[1,8,128] negate(mul) + dynamic-update-slice.35 = bf16[3,8,128] dynamic-update-slice(get-tuple-element.395, ar.1, select.1348, constant.2561, constant.2561) + ROOT tuple = (s32[], bf16[3,8,128], bf16[3,8,128]) tuple(add.230, dynamic-update-slice.35, get-tuple-element.5) +} + +ENTRY entry { + c0 = s32[] constant(0) + p0 = bf16[3,8,128] parameter(0) + tuple = (s32[], bf16[3,8,128], bf16[3,8,128]) tuple(c0, p0, p0) + while = (s32[], bf16[3,8,128], bf16[3,8,128]) while(tuple), condition=while_cond, body=while_body + ROOT gte1 = bf16[3,8,128] get-tuple-element(while), index=1 +} +)"; + auto module = ParseAndReturnUnverifiedModule(hlo_string).value(); + auto module2 = ParseAndReturnUnverifiedModule(hlo_string).value(); + EXPECT_TRUE( + RunOptimizer(module.get(), /*last_run=*/true, /*level_to_operate_on=*/0, + /*should_process=*/HloPredicateIsOp, + /*pipelining_direction=*/ + CollectivePipeliner::PipeliningDirection::kForward, + /*pipeline_use_tree=*/false, + /*acceptable_formatting=*/HloPredicateTrue, + /*reuse_pipelined_op_buffer=*/HloPredicateFalse) + .value()); + XLA_VLOG_LINES(1, module->ToString()); + XLA_VLOG_LINES(1, module2->ToString()); + EXPECT_TRUE(RunAndCompareTwoModules(std::move(module), std::move(module2), + ErrorSpec{0.1, 0.1})); +} + TEST_F(CollectivePipelinerExecutionTest, PushAgOver) { constexpr absl::string_view hlo_string = R"( HloModule module, entry_computation_layout={(bf16[3,8,128]{2,1,0})->bf16[3,8,128]{2,1,0}} @@ -953,5 +1022,281 @@ ENTRY entry { ErrorSpec{0.1, 0.1})); } +TEST_F(CollectivePipelinerExecutionTest, TransformIncrementByTwoFormat) { + constexpr absl::string_view hlo_string = R"( +HloModule module + +add { + lhs = bf16[] parameter(0) + rhs = bf16[] parameter(1) + ROOT add = bf16[] add(lhs, rhs) +} + +while_cond { + param = (s32[], bf16[3,8,128], bf16[3,8,128]) parameter(0) + gte = s32[] get-tuple-element(param), index=0 + 
constant.1 = s32[] constant(3) + ROOT cmp = pred[] compare(gte, constant.1), direction=LT +} + +while_body { + param = (s32[], bf16[3,8,128], bf16[3,8,128]) parameter(0) + get-tuple-element.394 = s32[] get-tuple-element(param), index=0 + get-tuple-element.395 = bf16[3,8,128] get-tuple-element(param), index=1 + get-tuple-element.35 = bf16[3,8,128] get-tuple-element(param), index=2 + constant.2557 = s32[] constant(1) + add.230 = s32[] add(get-tuple-element.394, constant.2557) + constant.2559 = s32[] constant(3) + subtract.139 = s32[] subtract(constant.2559, get-tuple-element.394) + constant.2560 = s32[] constant(-1) + add.231 = s32[] add(subtract.139, constant.2560) + constant.2561 = s32[] constant(0) + compare.747 = pred[] compare(add.231, constant.2561), direction=LT + constant.2562 = s32[] constant(2) + add.232 = s32[] add(subtract.139, constant.2562) + select.1348 = s32[] select(compare.747, add.232, add.231) + dynamic-slice.99 = bf16[1,8,128] dynamic-slice(get-tuple-element.35, select.1348, constant.2561, constant.2561), dynamic_slice_sizes={1,8,128} + mul = bf16[1,8,128] multiply(dynamic-slice.99, dynamic-slice.99) + ar.1 = bf16[1,8,128] negate(mul) + c = bf16[] constant(5.0) + b = bf16[1,8,128] broadcast(c), dimensions={} + a = bf16[1,8,128] add(ar.1, b) + dynamic-update-slice.35 = bf16[3,8,128] dynamic-update-slice(get-tuple-element.395, a, select.1348, constant.2561, constant.2561) + ROOT tuple = (s32[], bf16[3,8,128], bf16[3,8,128]) tuple(add.230, dynamic-update-slice.35, get-tuple-element.35) +} + +ENTRY entry { + c0 = s32[] constant(0) + p0 = bf16[3,8,128] parameter(0) + tuple = (s32[], bf16[3,8,128], bf16[3,8,128]) tuple(c0, p0, p0) + while = (s32[], bf16[3,8,128], bf16[3,8,128]) while(tuple), condition=while_cond, body=while_body + ROOT gte1 = bf16[3,8,128] get-tuple-element(while), index=1 +} +)"; + auto module = ParseAndReturnUnverifiedModule(hlo_string).value(); + auto module2 = ParseAndReturnUnverifiedModule(hlo_string).value(); + + EXPECT_TRUE( + 
RunOptimizer(module.get(), /*last_run=*/true, 0, + /*should_process=*/HloPredicateIsOp, + CollectivePipeliner::PipeliningDirection::kForwardSink, + /*pipeline_use_tree=*/true) + .value()); + XLA_VLOG_LINES(1, module->ToString()); + XLA_VLOG_LINES(1, module2->ToString()); + EXPECT_TRUE(RunAndCompareTwoModules(std::move(module), std::move(module2), + ErrorSpec{0.1, 0.1})); +} + +TEST_F(CollectivePipelinerExecutionTest, MultiUsesElementwiseMerge) { + constexpr absl::string_view hlo_string = R"( +HloModule module + +add { + lhs = bf16[] parameter(0) + rhs = bf16[] parameter(1) + ROOT add = bf16[] add(lhs, rhs) +} + +while_cond { + param = (s32[], bf16[3,8,128], bf16[3,8,128]) parameter(0) + gte = s32[] get-tuple-element(param), index=0 + constant.1 = s32[] constant(3) + ROOT cmp = pred[] compare(gte, constant.1), direction=LT +} + +while_body { + param = (s32[], bf16[3,8,128], bf16[3,8,128]) parameter(0) + get-tuple-element.394 = s32[] get-tuple-element(param), index=0 + get-tuple-element.395 = bf16[3,8,128] get-tuple-element(param), index=1 + get-tuple-element.5 = bf16[3,8,128] get-tuple-element(param), index=2 + constant.2557 = s32[] constant(1) + add.230 = s32[] add(get-tuple-element.394, constant.2557) + constant.2559 = s32[] constant(3) + subtract.139 = s32[] subtract(constant.2559, get-tuple-element.394) + constant.2560 = s32[] constant(-1) + add.231 = s32[] add(subtract.139, constant.2560) + constant.2561 = s32[] constant(0) + compare.747 = pred[] compare(add.231, constant.2561), direction=LT + constant.2562 = s32[] constant(2) + add.232 = s32[] add(subtract.139, constant.2562) + select.1348 = s32[] select(compare.747, add.232, add.231) + dynamic-slice.99 = bf16[1,8,128] dynamic-slice(get-tuple-element.5, select.1348, constant.2561, constant.2561), dynamic_slice_sizes={1,8,128} + mul = bf16[1,8,128] multiply(dynamic-slice.99, dynamic-slice.99) + c2 = bf16[] constant(2.0) + bc = bf16[1,8,128] broadcast(c2) + ar.1 = bf16[1,8,128] sqrt(mul) + ar.2 = bf16[1,8,128] 
negate(mul) + mul2 = bf16[1,8,128] multiply(ar.1, bc) + mul3 = bf16[1,8,128] multiply(mul2, ar.2) + mul4 = bf16[1,8,128] multiply(mul3, mul) + dynamic-update-slice.35 = bf16[3,8,128] dynamic-update-slice(get-tuple-element.395, mul4, select.1348, constant.2561, constant.2561) + ROOT tuple = (s32[], bf16[3,8,128], bf16[3,8,128]) tuple(add.230, dynamic-update-slice.35, get-tuple-element.5) +} + +ENTRY entry { + c0 = s32[] constant(0) + p0 = bf16[3,8,128] parameter(0) + tuple = (s32[], bf16[3,8,128], bf16[3,8,128]) tuple(c0, p0, p0) + while = (s32[], bf16[3,8,128], bf16[3,8,128]) while(tuple), condition=while_cond, body=while_body + ROOT gte1 = bf16[3,8,128] get-tuple-element(while), index=1 +} +)"; + auto module = ParseAndReturnUnverifiedModule(hlo_string).value(); + auto module2 = ParseAndReturnUnverifiedModule(hlo_string).value(); + + EXPECT_TRUE( + RunOptimizer(module.get(), /*last_run=*/true, 0, + /*should_process=*/ + HloPredicateIsOp, + CollectivePipeliner::PipeliningDirection::kForward, + /*pipeline_use_tree=*/true) + .value()); + XLA_VLOG_LINES(1, module->ToString()); + XLA_VLOG_LINES(1, module2->ToString()); + EXPECT_TRUE(RunAndCompareTwoModules(std::move(module), std::move(module2), + ErrorSpec{0.1, 0.1})); +} + +TEST_F(CollectivePipelinerExecutionTest, BroadcastAsFormattingOp) { + constexpr absl::string_view hlo_string = R"( +HloModule module + +add.1 { + lhs = bf16[] parameter(0) + rhs = bf16[] parameter(1) + ROOT add = bf16[] add(lhs, rhs) +} + +while_cond { + param = (s32[], bf16[3,8,128], bf16[3,8,128]) parameter(0) + gte = s32[] get-tuple-element(param), index=0 + constant.1 = s32[] constant(3) + ROOT cmp = pred[] compare(gte, constant.1), direction=LT +} + +while_body { + param = (s32[], bf16[3,8,128], bf16[3,8,128]) parameter(0) + get-tuple-element.394 = s32[] get-tuple-element(param), index=0 + get-tuple-element.395 = bf16[3,8,128] get-tuple-element(param), index=1 + get-tuple-element.35 = bf16[3,8,128] get-tuple-element(param), index=2 + 
constant.2557 = s32[] constant(1) + add.230 = s32[] add(get-tuple-element.394, constant.2557) + constant.2559 = s32[] constant(3) + subtract.139 = s32[] subtract(constant.2559, get-tuple-element.394) + constant.2560 = s32[] constant(-1) + add.231 = s32[] add(subtract.139, constant.2560) + constant.2561 = s32[] constant(0) + compare.747 = pred[] compare(add.231, constant.2561), direction=LT + constant.2562 = s32[] constant(2) + add.232 = s32[] add(subtract.139, constant.2562) + select.1348 = s32[] select(compare.747, add.232, add.231) + dynamic-slice.99 = bf16[1,8,128] dynamic-slice(get-tuple-element.35, select.1348, constant.2561, constant.2561), dynamic_slice_sizes={1,8,128} + mul = bf16[1,8,128] multiply(dynamic-slice.99, dynamic-slice.99) + ar.1 = bf16[1,8,128] negate(mul) + b.1 = bf16[1,8,128,32] broadcast(ar.1), dimensions={0,1,2} + constant = bf16[] constant(0) + reduce = bf16[1,8,128] reduce(b.1, constant), dimensions={3}, to_apply=add.1 + dynamic-update-slice.35 = bf16[3,8,128] dynamic-update-slice(get-tuple-element.395, reduce, select.1348, constant.2561, constant.2561) + ROOT tuple = (s32[], bf16[3,8,128], bf16[3,8,128]) tuple(add.230, dynamic-update-slice.35, get-tuple-element.35) +} + +ENTRY entry { + c0 = s32[] constant(0) + p0 = bf16[3,8,128] parameter(0) + tuple = (s32[], bf16[3,8,128], bf16[3,8,128]) tuple(c0, p0, p0) + while = (s32[], bf16[3,8,128], bf16[3,8,128]) while(tuple), condition=while_cond, body=while_body + ROOT gte1 = bf16[3,8,128] get-tuple-element(while), index=1 +} +)"; + auto module = ParseAndReturnUnverifiedModule(hlo_string).value(); + auto module2 = ParseAndReturnUnverifiedModule(hlo_string).value(); + + EXPECT_TRUE( + RunOptimizer(module.get(), /*last_run=*/true, 0, + /*should_process=*/HloPredicateIsOp, + CollectivePipeliner::PipeliningDirection::kForwardSink, + /*pipeline_use_tree=*/true) + .value()); + XLA_VLOG_LINES(1, module->ToString()); + XLA_VLOG_LINES(1, module2->ToString()); + 
EXPECT_TRUE(RunAndCompareTwoModules(std::move(module), std::move(module2), + ErrorSpec{0.1, 0.1})); +} + +TEST_F(CollectivePipelinerExecutionTest, + ForwardSinkDependentPipelineableCollectives) { + constexpr absl::string_view hlo_string = R"( +HloModule module + +add.1 { + lhs = bf16[] parameter(0) + rhs = bf16[] parameter(1) + ROOT add = bf16[] add(lhs, rhs) +} + +while_cond { + param = (s32[], bf16[3,8,128], bf16[3,8,128]) parameter(0) + gte = s32[] get-tuple-element(param), index=0 + constant.1 = s32[] constant(3) + ROOT cmp = pred[] compare(gte, constant.1), direction=LT +} + +while_body { + param = (s32[], bf16[3,8,128], bf16[3,8,128]) parameter(0) + get-tuple-element.394 = s32[] get-tuple-element(param), index=0 + get-tuple-element.395 = bf16[3,8,128] get-tuple-element(param), index=1 + get-tuple-element.35 = bf16[3,8,128] get-tuple-element(param), index=2 + constant.2557 = s32[] constant(1) + add.230 = s32[] add(get-tuple-element.394, constant.2557) + constant.2559 = s32[] constant(3) + subtract.139 = s32[] subtract(constant.2559, get-tuple-element.394) + constant.2560 = s32[] constant(-1) + add.231 = s32[] add(subtract.139, constant.2560) + constant.2561 = s32[] constant(0) + compare.747 = pred[] compare(add.231, constant.2561), direction=LT + constant.2562 = s32[] constant(2) + add.232 = s32[] add(subtract.139, constant.2562) + select.1348 = s32[] select(compare.747, add.232, add.231) + dynamic-slice.99 = bf16[1,8,128] dynamic-slice(get-tuple-element.35, select.1348, constant.2561, constant.2561), dynamic_slice_sizes={1,8,128} + mul = bf16[1,8,128] multiply(dynamic-slice.99, dynamic-slice.99) + ar.1 = bf16[1,8,128] negate(mul) + b.1 = bf16[1,8,128,32] broadcast(ar.1), dimensions={0,1,2} + constant = bf16[] constant(0) + reduce = bf16[1,8,128] reduce(b.1, constant), dimensions={3}, to_apply=add.1 + ar.2 = bf16[1,8,128] negate(reduce) + c1 = bf16[] constant(2.0) + bc = bf16[1,8,128] broadcast(c1) + mul1 = bf16[1,8,128] multiply(ar.2, bc) + mul3 = 
bf16[1,8,128] multiply(mul1, ar.2) + dynamic-update-slice.35 = bf16[3,8,128] dynamic-update-slice(get-tuple-element.395, mul3, select.1348, constant.2561, constant.2561) + ROOT tuple = (s32[], bf16[3,8,128], bf16[3,8,128]) tuple(add.230, dynamic-update-slice.35, get-tuple-element.35) +} + +ENTRY entry { + c0 = s32[] constant(0) + p0 = bf16[3,8,128] parameter(0) + tuple = (s32[], bf16[3,8,128], bf16[3,8,128]) tuple(c0, p0, p0) + while = (s32[], bf16[3,8,128], bf16[3,8,128]) while(tuple), condition=while_cond, body=while_body + ROOT gte1 = bf16[3,8,128] get-tuple-element(while), index=1 +} +)"; + auto module = ParseAndReturnUnverifiedModule(hlo_string).value(); + auto module2 = ParseAndReturnUnverifiedModule(hlo_string).value(); + + EXPECT_TRUE( + RunOptimizer( + module.get(), /*last_run=*/true, 0, + /*should_process=*/HloPredicateIsOp, + CollectivePipeliner::PipeliningDirection::kForwardSink, + /*pipeline_use_tree=*/true, + /*acceptable_formatting=*/HloPredicateIsNotOp) + .value()); + XLA_VLOG_LINES(1, module->ToString()); + XLA_VLOG_LINES(1, module2->ToString()); + EXPECT_TRUE(RunAndCompareTwoModules(std::move(module), std::move(module2), + ErrorSpec{0.1, 0.1})); +} + } // namespace } // namespace xla diff --git a/third_party/xla/xla/tools/BUILD b/third_party/xla/xla/tools/BUILD index d7109b0b69f5e2..6b2b3adab2812f 100644 --- a/third_party/xla/xla/tools/BUILD +++ b/third_party/xla/xla/tools/BUILD @@ -797,7 +797,7 @@ tsl_gpu_library( "//xla/service/gpu:gpu_compiler", "//xla/stream_executor/gpu:gpu_init", "//xla/service/gpu:gpu_symbol_repository", - ]), + ]) + if_google(["@com_google_protobuf//:duration_cc_proto"]), ) xla_test( @@ -840,7 +840,7 @@ xla_test( "@local_tsl//tsl/platform:test", "@local_tsl//tsl/protobuf:error_codes_proto_impl_cc", "@local_tsl//tsl/protobuf:status_proto_cc", - ], + ] + if_google(["@com_google_protobuf//:duration_cc_proto"]), ) xla_test( diff --git a/third_party/xla/xla/tsl/util/proto/BUILD b/third_party/xla/xla/tsl/util/proto/BUILD index 
2752d1f13e07d0..eed31a15a8aa1c 100644 --- a/third_party/xla/xla/tsl/util/proto/BUILD +++ b/third_party/xla/xla/tsl/util/proto/BUILD @@ -2,6 +2,7 @@ load( "@local_tsl//tsl/platform:rules_cc.bzl", "cc_library", ) +load("//xla/tsl:tsl.bzl", "if_google") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -16,6 +17,5 @@ cc_library( hdrs = ["proto_utils.h"], deps = [ "@com_google_absl//absl/time", - "@com_google_protobuf//:protobuf_headers", - ], + ] + if_google(["@com_google_protobuf//:duration_cc_proto"]), ) From 1a0285ee0b48b2dc416eb64d4945e5dfb88db01d Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 25 Jun 2024 13:21:15 -0700 Subject: [PATCH 233/256] Updated API docs for tf.quantization.fake_quant_with_min_max_args_gradient with Example code. PiperOrigin-RevId: 646584697 --- .../mlir/tensorflow/ir/tf_generated_ops.td | 37 ++++++++++++++++++- ..._def_FakeQuantWithMinMaxArgsGradient.pbtxt | 35 ++++++++++++++++++ 2 files changed, 71 insertions(+), 1 deletion(-) diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td index 3b531c4ed0590b..97066983e1110f 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td @@ -5459,7 +5459,42 @@ def TF_FakeQuantWithMinMaxArgsGradientOp : TF_Op<"FakeQuantWithMinMaxArgsGradien let results = (outs Res= min && inputs <= max)`.}]>:$backprops +`gradients * (inputs >= min && inputs <= max)`. 
+ +``` +import tensorflow as tf + +# Define some sample data +gradients = tf.random.uniform((2, 3), minval=-5.0, maxval=5.0, dtype=tf.float32) +inputs = tf.random.uniform((2, 3), minval=-10.0, maxval=10.0, dtype=tf.float32) + +# Define quantization parameters (adjust as needed) +min_val = -2.0 +max_val = 8.0 +num_bits = 4 # Number of bits for quantization + +# Calculate gradients for fake quantization with specified parameters +output_gradients = tf.quantization.fake_quant_with_min_max_args_gradient( + gradients=gradients, inputs=inputs, min=min_val, max=max_val, num_bits=num_bits, narrow_range = False, name=None +) + +# Print the original gradients and the gradients after the fake-quant operation +print("Original Gradients:") +print(gradients) +print("\nGradients after Fake-Quantization:") +print(output_gradients) + +``` +#Original Gradients: +#tf.Tensor( +#[[ 1.242547 3.217492 3.568469 ] +#[-0.55371046 0.23130894 2.608243 ]], shape=(2, 3), dtype=float32) + +#Gradients after Fake-Quantization: +#tf.Tensor( +#[[ 0. 3.217492 3.568469 ] +# [-0.55371046 0.23130894 2.608243 ]], shape=(2, 3), dtype=float32) +}]>:$backprops ); let extraClassDeclaration = [{ diff --git a/tensorflow/core/api_def/base_api/api_def_FakeQuantWithMinMaxArgsGradient.pbtxt b/tensorflow/core/api_def/base_api/api_def_FakeQuantWithMinMaxArgsGradient.pbtxt index 5241acc559ead8..46d0dff0c4f5d6 100644 --- a/tensorflow/core/api_def/base_api/api_def_FakeQuantWithMinMaxArgsGradient.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_FakeQuantWithMinMaxArgsGradient.pbtxt @@ -17,6 +17,41 @@ END description: <= min && inputs <= max)`. 
+ +``` +import tensorflow as tf + +# Define some sample data +gradients = tf.random.uniform((2, 3), minval=-5.0, maxval=5.0, dtype=tf.float32) +inputs = tf.random.uniform((2, 3), minval=-10.0, maxval=10.0, dtype=tf.float32) + +# Define quantization parameters (adjust as needed) +min_val = -2.0 +max_val = 8.0 +num_bits = 4 # Number of bits for quantization + +# Calculate gradients for fake quantization with specified parameters +output_gradients = tf.quantization.fake_quant_with_min_max_args_gradient( + gradients=gradients, inputs=inputs, min=min_val, max=max_val, num_bits=num_bits, narrow_range = False, name=None +) + +# Print the original gradients and the gradients after the fake-quant operation +print("Original Gradients:") +print(gradients) +print("\nGradients after Fake-Quantization:") +print(output_gradients) + +``` +#Original Gradients: +#tf.Tensor( +#[[ 1.242547 3.217492 3.568469 ] +#[-0.55371046 0.23130894 2.608243 ]], shape=(2, 3), dtype=float32) + +#Gradients after Fake-Quantization: +#tf.Tensor( +#[[ 0. 3.217492 3.568469 ] +# [-0.55371046 0.23130894 2.608243 ]], shape=(2, 3), dtype=float32) + END } summary: "Compute gradients for a FakeQuantWithMinMaxArgs operation." 
From 38cdbf3dad63c4143b7343d23b0bb819d2609a69 Mon Sep 17 00:00:00 2001 From: Jaroslav Sevcik Date: Tue, 25 Jun 2024 13:21:17 -0700 Subject: [PATCH 234/256] PR #14087: Implement GetOutputMemoryKinds for GPU executables Imported from GitHub PR https://github.com/openxla/xla/pull/14087 Copybara import of the project: -- 8bf16c545c937d7cab93f94e9ee57dc88903c199 by Jaroslav Sevcik : Implement GetOutputMemoryKinds for GPU executables -- 092f0dadbf90a3f5cf53fd87f968ee7f15ff5ed6 by Jaroslav Sevcik : Reserve space in vectors -- 6716ae8cba91b636c0d44c6a9ead8f74c24c0721 by Jaroslav Sevcik : Reuse bits of tests Merging this change closes #14087 PiperOrigin-RevId: 646584714 --- .../xla/pjrt/gpu/se_gpu_pjrt_client_test.cc | 93 +++++++++++++------ .../xla/pjrt/pjrt_stream_executor_client.cc | 51 +++++++++- 2 files changed, 117 insertions(+), 27 deletions(-) diff --git a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc index a8476013edfe20..70cb28f66b51b6 100644 --- a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc +++ b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc @@ -800,24 +800,35 @@ absl::StatusOr> CreateDeviceBufferForTest( return input; } +constexpr char const* kD2HProgram = R"( + HloModule f + + ENTRY main.5 { + p = s32[4]{0} parameter(0) + ROOT cc = s32[4] custom-call(p), + custom_call_target="annotate_device_placement", + frontend_attributes={_xla_buffer_placement="pinned_host"} + } +)"; + +constexpr char const* kD2HProgramTupleOutput = R"( + HloModule f + + ENTRY main.5 { + p = s32[4]{0} parameter(0) + cc = s32[4] custom-call(p), + custom_call_target="annotate_device_placement", + frontend_attributes={_xla_buffer_placement="pinned_host"} + ROOT tuple = (s32[4]{0}, s32[4]{0}) tuple(s32[4]{0} p, s32[4]{0} cc) + } +)"; + } // namespace TEST(StreamExecutorGpuClientTest, ExecutePinnedHostOutputTest) { TF_ASSERT_OK_AND_ASSIGN(auto client, GetStreamExecutorGpuClient(GpuClientOptions())); 
TF_ASSERT_OK_AND_ASSIGN(auto input, CreateDeviceBufferForTest(client.get())); - - static constexpr char const* kD2HProgram = R"( - HloModule f - - ENTRY main.5 { - p = s32[4]{0} parameter(0) - ROOT cc = s32[4] custom-call(p), - custom_call_target="annotate_device_placement", - frontend_attributes={_xla_buffer_placement="pinned_host"} - } - )"; - TF_ASSERT_OK_AND_ASSIGN(auto executable, CompileExecutable(kD2HProgram, *client)); TF_ASSERT_OK_AND_ASSIGN( @@ -837,18 +848,6 @@ TEST(StreamExecutorGpuClientTest, ExecutePinnedHostOutputTupleTest) { GetStreamExecutorGpuClient(GpuClientOptions())); TF_ASSERT_OK_AND_ASSIGN(auto input, CreateDeviceBufferForTest(client.get())); - static constexpr char const* kD2HProgram = R"( - HloModule f - - ENTRY main.5 { - p = s32[4]{0} parameter(0) - cc = s32[4] custom-call(p), - custom_call_target="annotate_device_placement", - frontend_attributes={_xla_buffer_placement="pinned_host"} - ROOT tuple = (s32[4]{0}, s32[4]{0}) tuple(s32[4]{0} p, s32[4]{0} cc) - } - )"; - // Build the output shape with the correct memory space set. Shape host_shape = input->on_device_shape(); host_shape.mutable_layout()->set_memory_space(Layout::kHostMemorySpace); @@ -860,8 +859,9 @@ TEST(StreamExecutorGpuClientTest, ExecutePinnedHostOutputTupleTest) { xla::CompileOptions options; options.executable_build_options.set_result_layout(out_shape); - TF_ASSERT_OK_AND_ASSIGN(auto executable, - CompileExecutable(kD2HProgram, *client, options)); + TF_ASSERT_OK_AND_ASSIGN( + auto executable, + CompileExecutable(kD2HProgramTupleOutput, *client, options)); // Untuple the result so that we get separate buffers. // This is how JAX invokes XLA. 
@@ -876,5 +876,46 @@ TEST(StreamExecutorGpuClientTest, ExecutePinnedHostOutputTupleTest) { EXPECT_EQ(result_buffers[1]->memory_space()->kind(), "pinned_host"); } +TEST(StreamExecutorGpuClientTest, ExecutablePinnedHostOutputMemoryKindTest) { + TF_ASSERT_OK_AND_ASSIGN(auto client, + GetStreamExecutorGpuClient(GpuClientOptions())); + TF_ASSERT_OK_AND_ASSIGN(auto executable, + CompileExecutable(kD2HProgram, *client)); + + TF_ASSERT_OK_AND_ASSIGN(auto memory_kinds, + executable->GetOutputMemoryKinds()); + EXPECT_EQ(memory_kinds.size(), 1); + EXPECT_EQ(memory_kinds[0].size(), 1); + EXPECT_EQ(memory_kinds[0][0], "pinned_host"); +} + +TEST(StreamExecutorGpuClientTest, + ExecutablePinnedHostTupleOutputMemoryKindTest) { + TF_ASSERT_OK_AND_ASSIGN(auto client, + GetStreamExecutorGpuClient(GpuClientOptions())); + + // Build the output shape with the correct memory space set. + Shape shape = ShapeUtil::MakeShapeWithDenseLayout(S32, {4}, {0}); + Shape host_shape = shape; + host_shape.mutable_layout()->set_memory_space(Layout::kHostMemorySpace); + Shape out_shape = ShapeUtil::MakeTupleShape({shape, host_shape}); + + // Set the result layout so that the compiler assertions on memory + // spaces pass. 
+ xla::CompileOptions options; + options.executable_build_options.set_result_layout(out_shape); + + TF_ASSERT_OK_AND_ASSIGN( + auto executable, + CompileExecutable(kD2HProgramTupleOutput, *client, options)); + + TF_ASSERT_OK_AND_ASSIGN(auto memory_kinds, + executable->GetOutputMemoryKinds()); + EXPECT_EQ(memory_kinds.size(), 1); + EXPECT_EQ(memory_kinds[0].size(), 2); + EXPECT_EQ(memory_kinds[0][0], "device"); + EXPECT_EQ(memory_kinds[0][1], "pinned_host"); +} + } // namespace } // namespace xla diff --git a/third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc b/third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc index 6d30ef0330f1d3..fb143187699dd7 100644 --- a/third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc +++ b/third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc @@ -3251,9 +3251,58 @@ PjRtStreamExecutorLoadedExecutable::GetHloModules() const { return std::move(modules); } +namespace { + +absl::StatusOr MemoryKindFromSimpleShape( + const Shape& shape, absl::string_view default_memory_kind) { + if (!shape.has_layout()) { + return default_memory_kind; + } + switch (shape.layout().memory_space()) { + case Layout::kHostMemorySpace: + return PinnedHostMemorySpace::kKind; + case Layout::kDefaultMemorySpace: + return default_memory_kind; + default: + return InvalidArgument("Unexpected memory space %d in output layout", + shape.layout().memory_space()); + } +} + +absl::StatusOr> MemoryKindsFromShape( + const Shape& shape, absl::string_view default_memory_kind) { + if (!shape.IsTuple()) { + TF_ASSIGN_OR_RETURN(absl::string_view memory_kind, + MemoryKindFromSimpleShape(shape, default_memory_kind)); + return {{memory_kind}}; + } + std::vector result; + result.reserve(shape.tuple_shapes_size()); + for (auto element_shape : shape.tuple_shapes()) { + TF_ASSIGN_OR_RETURN( + absl::string_view element_memory_kind, + MemoryKindFromSimpleShape(element_shape, default_memory_kind)); + result.push_back(element_memory_kind); + } + return result; +} + +} // namespace 
+ absl::StatusOr>> PjRtStreamExecutorLoadedExecutable::GetOutputMemoryKinds() const { - return Unimplemented("GetOutputMemoryKinds is not supported."); + TF_ASSIGN_OR_RETURN(auto shapes, GetOutputShapes()); + TF_ASSIGN_OR_RETURN(PjRtMemorySpace * default_memory_space, + addressable_devices()[0]->default_memory_space()); + std::vector> out; + out.reserve(shapes.size()); + for (auto shape : shapes) { + TF_ASSIGN_OR_RETURN( + std::vector memory_kind, + MemoryKindsFromShape(shape, default_memory_space->kind())); + out.push_back(memory_kind); + } + return out; } absl::StatusOr From 37510ef390a098c740bfd34678c14738ea761dc9 Mon Sep 17 00:00:00 2001 From: Kris Tonthat Date: Tue, 25 Jun 2024 13:25:11 -0700 Subject: [PATCH 235/256] Add NNAPI and Hexagon deprecation warnings to TF Lite docs PiperOrigin-RevId: 646585930 --- tensorflow/lite/g3doc/_book.yaml | 2 ++ .../lite/g3doc/android/delegates/hexagon.md | 16 ++++++++++++---- tensorflow/lite/g3doc/android/delegates/nnapi.md | 12 ++++++++++++ tensorflow/lite/g3doc/performance/delegates.md | 12 ++++++++++++ .../g3doc/performance/implementing_delegate.md | 12 ++++++++++++ 5 files changed, 50 insertions(+), 4 deletions(-) diff --git a/tensorflow/lite/g3doc/_book.yaml b/tensorflow/lite/g3doc/_book.yaml index 8b04e13063d0b9..0f24420a7b5a30 100644 --- a/tensorflow/lite/g3doc/_book.yaml +++ b/tensorflow/lite/g3doc/_book.yaml @@ -141,8 +141,10 @@ upper_tabs: path: /lite/android/delegates/gpu_native.md - title: "NNAPI delegate" path: /lite/android/delegates/nnapi + status: deprecated - title: "Hexagon delegate" path: /lite/android/delegates/hexagon + status: deprecated - heading: "Models with metadata" - title: "Overview" diff --git a/tensorflow/lite/g3doc/android/delegates/hexagon.md b/tensorflow/lite/g3doc/android/delegates/hexagon.md index 56662db33d3d5f..60fb6604dd8c15 100644 --- a/tensorflow/lite/g3doc/android/delegates/hexagon.md +++ b/tensorflow/lite/g3doc/android/delegates/hexagon.md @@ -1,5 +1,17 @@ # TensorFlow Lite Hexagon 
delegate + + This document explains how to use the TensorFlow Lite Hexagon Delegate in your application using the Java and/or C API. The delegate leverages the Qualcomm Hexagon library to execute quantized kernels on the DSP. Note that the delegate @@ -7,10 +19,6 @@ is intended to *complement* NNAPI functionality, particularly for devices where NNAPI DSP acceleration is unavailable (e.g., on older devices, or devices that don’t yet have a DSP NNAPI driver). -Caution: The currently released versions of the Hexagon delegate, [up to version -1.20.0.1](#hexagon_versions), are no longer supported. An updated version of this -delegate is expected soon. - **Supported devices:** Currently the following Hexagon architecture are supported, including but not diff --git a/tensorflow/lite/g3doc/android/delegates/nnapi.md b/tensorflow/lite/g3doc/android/delegates/nnapi.md index c8d4ce7ee06ea6..96a17d3a31ecbd 100644 --- a/tensorflow/lite/g3doc/android/delegates/nnapi.md +++ b/tensorflow/lite/g3doc/android/delegates/nnapi.md @@ -1,5 +1,17 @@ # TensorFlow Lite NNAPI delegate + + The [Android Neural Networks API (NNAPI)](https://developer.android.com/ndk/guides/neuralnetworks) is available on all Android devices running Android 8.1 (API level 27) or diff --git a/tensorflow/lite/g3doc/performance/delegates.md b/tensorflow/lite/g3doc/performance/delegates.md index 51a3c305bd3b84..d0866fbcca302b 100644 --- a/tensorflow/lite/g3doc/performance/delegates.md +++ b/tensorflow/lite/g3doc/performance/delegates.md @@ -1,5 +1,17 @@ # TensorFlow Lite Delegates + + ## Introduction **Delegates** enable hardware acceleration of TensorFlow Lite models by diff --git a/tensorflow/lite/g3doc/performance/implementing_delegate.md b/tensorflow/lite/g3doc/performance/implementing_delegate.md index e2908127dd8f4a..fdbc880cf11cd5 100644 --- a/tensorflow/lite/g3doc/performance/implementing_delegate.md +++ b/tensorflow/lite/g3doc/performance/implementing_delegate.md @@ -1,5 +1,17 @@ # Implementing a Custom 
Delegate + + [TOC] ## What is a TensorFlow Lite Delegate? From 59e731c177ca597e0429d070ac929361de6c1603 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 25 Jun 2024 13:45:20 -0700 Subject: [PATCH 236/256] Go: Update generated wrapper functions for TensorFlow ops. PiperOrigin-RevId: 646592600 --- tensorflow/go/op/wrappers.go | 36 ++++++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go index 97090ec541caba..beb3950d5417ca 100644 --- a/tensorflow/go/op/wrappers.go +++ b/tensorflow/go/op/wrappers.go @@ -17236,6 +17236,42 @@ func FakeQuantWithMinMaxArgsGradientNarrowRange(value bool) FakeQuantWithMinMaxA // // Returns Backpropagated gradients below the FakeQuantWithMinMaxArgs operation: // `gradients * (inputs >= min && inputs <= max)`. +// +// ``` +// import tensorflow as tf +// +// # Define some sample data +// gradients = tf.random.uniform((2, 3), minval=-5.0, maxval=5.0, dtype=tf.float32) +// inputs = tf.random.uniform((2, 3), minval=-10.0, maxval=10.0, dtype=tf.float32) +// +// # Define quantization parameters (adjust as needed) +// min_val = -2.0 +// max_val = 8.0 +// num_bits = 4 # Number of bits for quantization +// +// # Calculate gradients for fake quantization with specified parameters +// output_gradients = tf.quantization.fake_quant_with_min_max_args_gradient( +// +// gradients=gradients, inputs=inputs, min=min_val, max=max_val, num_bits=num_bits, narrow_range = False, name=None +// +// ) +// +// # Print the original gradients and the gradients after the fake-quant operation +// print("Original Gradients:") +// print(gradients) +// print("\nGradients after Fake-Quantization:") +// print(output_gradients) +// +// ``` +// #Original Gradients: +// #tf.Tensor( +// #[[ 1.242547 3.217492 3.568469 ] +// #[-0.55371046 0.23130894 2.608243 ]], shape=(2, 3), dtype=float32) +// +// #Gradients after Fake-Quantization: +// #tf.Tensor( +// #[[ 0. 
3.217492 3.568469 ] +// # [-0.55371046 0.23130894 2.608243 ]], shape=(2, 3), dtype=float32) func FakeQuantWithMinMaxArgsGradient(scope *Scope, gradients tf.Output, inputs tf.Output, optional ...FakeQuantWithMinMaxArgsGradientAttr) (backprops tf.Output) { if scope.Err() != nil { return From 038957a5b6cde41376a96fcc42eaab617881a4d3 Mon Sep 17 00:00:00 2001 From: Matthias Kramm Date: Tue, 25 Jun 2024 14:34:16 -0700 Subject: [PATCH 237/256] Remove stay type annotation from context(). PiperOrigin-RevId: 646609853 --- tensorflow/python/eager/context.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/eager/context.py b/tensorflow/python/eager/context.py index fee09b7dc4a599..f163eca309db3e 100644 --- a/tensorflow/python/eager/context.py +++ b/tensorflow/python/eager/context.py @@ -2465,7 +2465,7 @@ def _reset_jit_compiler_flags(): pywrap_tfe.TF_ResetJitCompilerFlags() -def context() -> Context: +def context(): """Returns a singleton context object.""" if _context is None: _create_context() From f47f49c708571bb9eae5f2464f2fd0f2ef2ee9f8 Mon Sep 17 00:00:00 2001 From: David Dunleavy Date: Tue, 25 Jun 2024 15:12:37 -0700 Subject: [PATCH 238/256] Properly override repository for JAX builds in `build.py` Also change an `if_static` to an `if_google` to fix JAX builds PiperOrigin-RevId: 646622609 --- third_party/xla/build_tools/build.py | 8 ++++++-- third_party/xla/xla/service/gpu/BUILD | 2 +- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/third_party/xla/build_tools/build.py b/third_party/xla/build_tools/build.py index 6629ea43d9e88a..8b353f3cdbb9e7 100755 --- a/third_party/xla/build_tools/build.py +++ b/third_party/xla/build_tools/build.py @@ -274,7 +274,9 @@ def nvidia_gpu_build_with_compute_capability( JAX_NUM_GENERATED_CASES=25, JAX_SKIP_SLOW_TESTS=1, ), - options=_DEFAULT_BAZEL_OPTIONS, + options=dict( + **_DEFAULT_BAZEL_OPTIONS, override_repository="xla=/github/xla" + ), ) _JAX_GPU_BUILD = Build( @@ -294,7 +296,9 @@ def 
nvidia_gpu_build_with_compute_capability( TF_CPP_MIN_LOG_LEVEL=0, JAX_EXCLUDE_TEST_TARGETS="PmapTest.testSizeOverflow", ), - options=_DEFAULT_BAZEL_OPTIONS, + options=dict( + **_DEFAULT_BAZEL_OPTIONS, override_repository="xla=/github/xla" + ), ) _KOKORO_JOB_NAME_TO_BUILD_MAP = { diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index 2a6e8de134c2c0..0726033d31e418 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -4796,7 +4796,7 @@ cc_library( ] + if_cuda_is_configured([ "@local_config_cuda//cuda:cuda_headers", "@local_config_cuda//cuda:cudnn_header", - ]) + if_static([ + ]) + if_google([ "@com_google_protobuf//:wrappers_cc_proto", ]), ) From 2d72742d40f1d3121c895f8584ec8882d1e97fc8 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 25 Jun 2024 15:33:22 -0700 Subject: [PATCH 239/256] Add tensorflow support for 16k page sizes on arm64 Tested both libtensorflowlite.so and libtensorflowlite_jni.so to ensure both libraries are 16k ELF aligned with this change: $ objdump -p bazel-bin/tensorflow/lite/libtensorflowlite.so | grep LOAD | awk '{ print $1 " " $NF }' LOAD 2**14 LOAD 2**14 $ objdump -p bazel-bin/tensorflow/lite/java/libtensorflowlite_jni.so | grep LOAD | awk '{ print $1 " " $NF }' LOAD 2**14 LOAD 2**14 PiperOrigin-RevId: 646629366 --- tensorflow/lite/build_def.bzl | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/tensorflow/lite/build_def.bzl b/tensorflow/lite/build_def.bzl index da05c312d3c0d3..8ebe89096a8128 100644 --- a/tensorflow/lite/build_def.bzl +++ b/tensorflow/lite/build_def.bzl @@ -181,13 +181,22 @@ def tflite_linkopts_no_undefined(): }), ) +def tflite_pagesize_linkopts(): + """Defines linker flags for setting the page size.""" + return select({ + clean_dep("//tensorflow:android_arm64"): [ + "-Wl,-z,max-page-size=16384", + ], + "//conditions:default": [], + }) + def tflite_linkopts(): """Defines linker flags for 
linking TFLite binary.""" - return tflite_linkopts_unstripped() + tflite_symbol_opts() + return tflite_linkopts_unstripped() + tflite_symbol_opts() + tflite_pagesize_linkopts() def tflite_jni_linkopts(): """Defines linker flags for linking TFLite binary with JNI.""" - return tflite_jni_linkopts_unstripped() + tflite_symbol_opts() + return tflite_jni_linkopts_unstripped() + tflite_symbol_opts() + tflite_pagesize_linkopts() def tflite_jni_binary( name, From dea0e2a6f8beca190ca138093e669d42ee244056 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Tue, 25 Jun 2024 15:38:57 -0700 Subject: [PATCH 240/256] [xla:cpu] Optimize KernelThunk host kernel loading PiperOrigin-RevId: 646631043 --- .../select_and_scatter_benchmark_test.cc | 3 +- .../xla/service/cpu/runtime/kernel_thunk.cc | 29 ++++++++++--------- .../xla/service/cpu/runtime/kernel_thunk.h | 14 ++++----- 3 files changed, 25 insertions(+), 21 deletions(-) diff --git a/third_party/xla/xla/service/cpu/benchmarks/select_and_scatter_benchmark_test.cc b/third_party/xla/xla/service/cpu/benchmarks/select_and_scatter_benchmark_test.cc index fdfdc7b00b1882..7521dcda5b0f86 100644 --- a/third_party/xla/xla/service/cpu/benchmarks/select_and_scatter_benchmark_test.cc +++ b/third_party/xla/xla/service/cpu/benchmarks/select_and_scatter_benchmark_test.cc @@ -75,6 +75,7 @@ BENCHMARK(BM_SelectAndScatterF32) ->MeasureProcessCPUTime() ->Arg(128) ->Arg(256) - ->Arg(512); + ->Arg(512) + ->Arg(1024); } // namespace xla::cpu diff --git a/third_party/xla/xla/service/cpu/runtime/kernel_thunk.cc b/third_party/xla/xla/service/cpu/runtime/kernel_thunk.cc index 85d65c0ec323fd..247c995649a525 100644 --- a/third_party/xla/xla/service/cpu/runtime/kernel_thunk.cc +++ b/third_party/xla/xla/service/cpu/runtime/kernel_thunk.cc @@ -28,6 +28,7 @@ limitations under the License. 
#include "absl/memory/memory.h" #include "absl/numeric/bits.h" #include "absl/strings/str_format.h" +#include "absl/synchronization/mutex.h" #include "absl/types/span.h" #include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive #include "xla/runtime/buffer_use.h" @@ -124,29 +125,31 @@ tsl::AsyncValueRef KernelThunk::Execute( // TODO(ezhulenev): Kernel ptr should be loaded as a part of Thunk // initialization stage. - SE_HOST_Kernel* kernel_ptr = kernel_ptr_.load(); + se::host::HostKernel* kernel = kernel_ptr_.load(); // Because thunks are owned by a parent CpuExecutable, we can safely assume // that kernel pointer will not change after we find it the first time. - if (kernel_ptr == nullptr) { - TF_ASSIGN_OR_RETURN(kernel_ptr, params.host_kernels->Find(kernel_name_)); - kernel_ptr_.store(kernel_ptr); - } + if (ABSL_PREDICT_FALSE(kernel == nullptr)) { + TF_ASSIGN_OR_RETURN(SE_HOST_Kernel * kernel_fn, + params.host_kernels->Find(kernel_name_)); - se::host::HostKernel kernel(buffers_data.size(), kernel_ptr, nullptr); + absl::MutexLock lock(&mutex_); + kernel_.emplace(buffers_data.size(), kernel_fn, nullptr); + kernel_ptr_.store(kernel = &kernel_.value()); + } // If intra-op thread pool is not nullptr, we launch HostKernel in async mode // by scheduling tasks into it. HostKernel launch completion will // automatically signal KernelThunk execute completion. 
- if (params.intra_op_threadpool && use_task_runner_) { - return kernel.Launch(thread_dim_, buffers_data, - [¶ms](se::host::HostKernel::Task task) { - params.intra_op_threadpool->getPool()->Schedule( - ToCopyableTask(std::move(task))); - }); + if (ABSL_PREDICT_FALSE(params.intra_op_threadpool && use_task_runner_)) { + return kernel->Launch(thread_dim_, buffers_data, + [¶ms](se::host::HostKernel::Task task) { + params.intra_op_threadpool->getPool()->Schedule( + ToCopyableTask(std::move(task))); + }); } - TF_RETURN_IF_ERROR(kernel.Launch(thread_dim_, buffers_data)); + TF_RETURN_IF_ERROR(kernel->Launch(thread_dim_, buffers_data)); return OkExecuteEvent(); } diff --git a/third_party/xla/xla/service/cpu/runtime/kernel_thunk.h b/third_party/xla/xla/service/cpu/runtime/kernel_thunk.h index 72cd1be097ac25..708f918d342c96 100644 --- a/third_party/xla/xla/service/cpu/runtime/kernel_thunk.h +++ b/third_party/xla/xla/service/cpu/runtime/kernel_thunk.h @@ -23,11 +23,13 @@ limitations under the License. #include #include +#include "absl/base/thread_annotations.h" #include "absl/status/statusor.h" +#include "absl/synchronization/mutex.h" #include "absl/types/span.h" #include "xla/service/buffer_assignment.h" #include "xla/service/cpu/runtime/thunk.h" -#include "xla/stream_executor/host/host_kernel_c_api.h" +#include "xla/stream_executor/host/host_kernel.h" #include "xla/stream_executor/launch_dim.h" #include "xla/tsl/concurrency/async_value_ref.h" @@ -64,12 +66,10 @@ class KernelThunk final : public Thunk { // launch the kernel directly in the caller thread. bool use_task_runner_; - // Pointer to the host kernel corresponding to `kernel_name_`. Initialized - // lazily at run time by looking it up in the HostKernels passed via params. - // - // TODO(ezhulenev): This should be moved to initialization stage when we'll - // have it for CPU thunks. - std::atomic kernel_ptr_; + // Lazily loaded host kernel corresponding to `kernel_name_`. 
+ absl::Mutex mutex_; + std::optional kernel_ ABSL_GUARDED_BY(mutex_); + std::atomic kernel_ptr_; // pointer to `kernel_` }; } // namespace xla::cpu From 01aeb511e1c2e357c24d6f8f57bdcc46638549fd Mon Sep 17 00:00:00 2001 From: Kyle Lucke Date: Tue, 25 Jun 2024 15:46:23 -0700 Subject: [PATCH 241/256] Remove unused VlogOccupancyInfo calls. PiperOrigin-RevId: 646633442 --- .../xla/stream_executor/cuda/cuda_executor.cc | 90 ------------------- .../xla/stream_executor/gpu/gpu_executor.h | 25 ------ .../xla/stream_executor/rocm/rocm_executor.cc | 86 ------------------ 3 files changed, 201 deletions(-) diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc b/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc index 7d978f043a42fd..901f051e290606 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc @@ -465,20 +465,6 @@ absl::Status GpuExecutor::Launch(Stream* stream, const ThreadDim& thread_dims, const GpuKernel* cuda_kernel = AsGpuKernel(&kernel); CUfunction cufunc = cuda_kernel->AsGpuFunctionHandle(); - // Only perform/print the occupancy check once. Even just checking to see - // whether we've done an occupancy check on this kernel before isn't free - // (because we have to synchronize), so we only do this at -v 2+. - if (VLOG_IS_ON(2)) { - absl::MutexLock lock(&launched_kernels_mu_); - if (!launched_kernels_.count(cufunc)) { - VlogOccupancyInfo(stream->parent()->GetDeviceDescription(), kernel, - thread_dims, block_dims); - // TODO(rspringer): Remove elements from launched_kernels_...if we ever - // expose a kernel/module deallocation method. 
- launched_kernels_.insert(cufunc); - } - } - if (cuda_kernel->cache_config() != KernelCacheConfig::kNoPreference) { TF_RETURN_IF_ERROR(GpuDriver::FuncSetCacheConfig( cufunc, cuda_kernel->GetGpuCacheConfig())); @@ -547,82 +533,6 @@ absl::Status GpuExecutor::Submit(Stream* stream, return GpuDriver::GraphLaunch(exec, AsGpuStreamValue(stream)); } -// This is a non-essential operation; if there's a failure, proceed without -// logging an error. It's nearly certain that in case of failures, we'd never -// get here in the first place; these are very low-impact routines. -void GpuExecutor::VlogOccupancyInfo(const DeviceDescription& device_description, - const Kernel& kernel, - const ThreadDim& thread_dims, - const BlockDim& block_dims) { - VLOG(2) << "Computing kernel occupancy for kernel " - << kernel.demangled_name(); - VLOG(2) << "Thread dimensions (" << thread_dims.x << ", " << thread_dims.y - << ", " << thread_dims.z << ")"; - - auto regs_per_thread = kernel.metadata().registers_per_thread(); - auto smem_per_block = kernel.metadata().shared_memory_bytes(); - - if (!regs_per_thread && !smem_per_block) { - return; - } - - const GpuKernel* cuda_kernel = AsGpuKernel(&kernel); - CUfunction cufunc = cuda_kernel->AsGpuFunctionHandle(); - - int blocks_per_sm = CalculateOccupancy(device_description, *regs_per_thread, - *smem_per_block, thread_dims, cufunc); - VLOG(2) << "Resident blocks per SM is " << blocks_per_sm; - - int suggested_threads = - CompareOccupancy(&blocks_per_sm, device_description, *regs_per_thread, - *smem_per_block, thread_dims, cufunc); - if (suggested_threads != 0) { - VLOG(2) << "The cuda occupancy calculator recommends using " - << suggested_threads - << " threads per block to achieve an occupancy of " << blocks_per_sm - << " blocks per SM."; - } -} - -// Compute and return maximum blocks per core (occupancy) based on the -// device description, some kernel characteristics and the number of threads per -// block. 
If unable to compute occupancy, zero is returned. -int GpuExecutor::CalculateOccupancy(const DeviceDescription& device_description, - uint64_t registers_per_thread, - uint64_t shared_memory_per_block, - const ThreadDim& thread_dims, - CUfunction func) { - int suggested_blocks = 0; - int suggested_threads = 0; - CUresult err = cuOccupancyMaxPotentialBlockSize( - &suggested_blocks, &suggested_threads, func, nullptr, - shared_memory_per_block, 0); - CHECK_EQ(err, CUDA_SUCCESS); - return suggested_blocks; -} - -// Compute and return the suggested thread count to achieve ideal occupancy. -// If the provided thread dimensions match this number, zero is returned. -int GpuExecutor::CompareOccupancy(int* initial_blocks, - const DeviceDescription& device_description, - uint64_t registers_per_thread, - uint64_t shared_memory_per_block, - const ThreadDim& thread_dims, - CUfunction func) { - int suggested_blocks = 0; - int suggested_threads = 0; - CUresult err = cuOccupancyMaxPotentialBlockSize( - &suggested_blocks, &suggested_threads, func, nullptr, - shared_memory_per_block, 0); - CHECK_EQ(err, CUDA_SUCCESS); - if (suggested_blocks > *initial_blocks) { - *initial_blocks = suggested_blocks; - return suggested_threads; - } else { - return 0; - } -} - DeviceMemoryBase GpuExecutor::Allocate(uint64_t size, int64_t memory_space) { if (memory_space == 1) { auto result = GpuCollectives::CollectiveMemoryAllocate(context_, size); diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_executor.h b/third_party/xla/xla/stream_executor/gpu/gpu_executor.h index c0a21007e00b5d..d17959e8fd4bf0 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_executor.h +++ b/third_party/xla/xla/stream_executor/gpu/gpu_executor.h @@ -147,17 +147,6 @@ class GpuExecutor : public StreamExecutorCommon { absl::Status Submit(Stream* stream, const CommandBuffer& command_buffer) override; - int CalculateOccupancy(const DeviceDescription& device_description, - uint64_t registers_per_thread, - uint64_t 
shared_memory_per_block, - const ThreadDim& thread_dims, GpuFunctionHandle func); - - int CompareOccupancy(int* initial_blocks, - const DeviceDescription& device_description, - uint64_t registers_per_thread, - uint64_t shared_memory_per_block, - const ThreadDim& thread_dims, GpuFunctionHandle func); - DeviceMemoryBase Allocate(uint64_t size, int64_t memory_space) override; void Deallocate(DeviceMemoryBase* mem) override; @@ -320,12 +309,6 @@ class GpuExecutor : public StreamExecutorCommon { absl::Status GetKernelMetadata(GpuKernel* cuda_kernel, KernelMetadata* kernel_metadata); - // Prints to VLOG(2) information about the kernel's occupancy and how it might - // be improved. - void VlogOccupancyInfo(const DeviceDescription& device_description, - const Kernel& kernel, const ThreadDim& thread_dims, - const BlockDim& block_dims); - // (supported on CUDA only) absl::Status LoadModuleFromCuBin(const char* cubin, GpuModuleHandle* module) TF_EXCLUSIVE_LOCKS_REQUIRED(in_memory_modules_mu_); @@ -377,14 +360,6 @@ class GpuExecutor : public StreamExecutorCommon { std::unordered_map> gpu_binary_to_module_ ABSL_GUARDED_BY(in_memory_modules_mu_); - // Guards the launched kernel set. - absl::Mutex launched_kernels_mu_; - - // Keeps track of the set of launched kernels. Currently used to suppress the - // occupancy check on subsequent launches. - std::set launched_kernels_ - ABSL_GUARDED_BY(launched_kernels_mu_); - // Handle for the CUDA device being operated on. Immutable // post-initialization. 
GpuDeviceHandle device_; diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_executor.cc b/third_party/xla/xla/stream_executor/rocm/rocm_executor.cc index f405de0c781537..86b0b3574f922f 100644 --- a/third_party/xla/xla/stream_executor/rocm/rocm_executor.cc +++ b/third_party/xla/xla/stream_executor/rocm/rocm_executor.cc @@ -337,20 +337,6 @@ absl::Status GpuExecutor::Launch(Stream* stream, const ThreadDim& thread_dims, const GpuKernel* rocm_kernel = AsGpuKernel(&kernel); hipFunction_t hipfunc = rocm_kernel->AsGpuFunctionHandle(); - // Only perform/print the occupancy check once. Even just checking to see - // whether we've done an occupancy check on this kernel before isn't free - // (because we have to synchronize), so we only do this at -v 2+. - if (VLOG_IS_ON(2)) { - absl::MutexLock lock(&launched_kernels_mu_); - if (!launched_kernels_.count(hipfunc)) { - VlogOccupancyInfo(stream->parent()->GetDeviceDescription(), kernel, - thread_dims, block_dims); - // TODO(rspringer): Remove elements from launched_kernels_...if we ever - // expose a kernel/module deallocation method. - launched_kernels_.insert(hipfunc); - } - } - if (rocm_kernel->cache_config() != KernelCacheConfig::kNoPreference) { TF_RETURN_IF_ERROR(GpuDriver::FuncSetCacheConfig( hipfunc, rocm_kernel->GetGpuCacheConfig())); @@ -458,78 +444,6 @@ absl::Status GpuExecutor::LoadModuleFromHsaco(const char* hsaco, return absl::OkStatus(); } -// This is a non-essential operation; if there's a failure, proceed without -// logging an error. It's nearly certain that in case of failures, we'd never -// get here in the first place; these are very low-impact routines. 
-void GpuExecutor::VlogOccupancyInfo(const DeviceDescription& device_description, - const Kernel& kernel, - const ThreadDim& thread_dims, - const BlockDim& block_dims) { - VLOG(2) << "Computing kernel occupancy for kernel " - << kernel.demangled_name(); - VLOG(2) << "Thread dimensions (" << thread_dims.x << ", " << thread_dims.y - << ", " << thread_dims.z << ")"; - - auto regs_per_thread = kernel.metadata().registers_per_thread(); - auto smem_per_block = kernel.metadata().shared_memory_bytes(); - - if (!regs_per_thread && !smem_per_block) { - return; - } - - const GpuKernel* rocm_kernel = AsGpuKernel(&kernel); - auto hipfunc = rocm_kernel->AsGpuFunctionHandle(); - - int blocks_per_sm = CalculateOccupancy(device_description, *regs_per_thread, - *smem_per_block, thread_dims, hipfunc); - VLOG(2) << "Resident blocks per SM is " << blocks_per_sm; - - int suggested_threads = - CompareOccupancy(&blocks_per_sm, device_description, *regs_per_thread, - *smem_per_block, thread_dims, hipfunc); - if (suggested_threads != 0) { - VLOG(2) << "The rocm occupancy calculator recommends using " - << suggested_threads - << " threads per block to achieve an occupancy of " << blocks_per_sm - << " blocks per SM."; - } -} - -// Compute and return maximum blocks per core (occupancy) based on the -// device description, some kernel characteristics and the number of threads per -// block. If unable to compute occupancy, zero is returned. -int GpuExecutor::CalculateOccupancy(const DeviceDescription& device_description, - uint64_t registers_per_thread, - uint64_t shared_memory_per_block, - const ThreadDim& thread_dims, - GpuFunctionHandle func) { - int suggested_blocks = 0; - int suggested_threads = 0; - (void)rocm::OccupancyGetMaxPotentialBlockSize( - &suggested_blocks, &suggested_threads, func, shared_memory_per_block, 0); - return suggested_blocks; -} - -// Compute and return the suggested thread count to achieve ideal occupancy. 
-// If the provided thread dimensions match this number, zero is returned. -int GpuExecutor::CompareOccupancy(int* initial_blocks, - const DeviceDescription& device_description, - uint64_t registers_per_thread, - uint64_t shared_memory_per_block, - const ThreadDim& thread_dims, - GpuFunctionHandle func) { - int suggested_blocks = 0; - int suggested_threads = 0; - (void)rocm::OccupancyGetMaxPotentialBlockSize( - &suggested_blocks, &suggested_threads, func, shared_memory_per_block, 0); - if (suggested_blocks > *initial_blocks) { - *initial_blocks = suggested_blocks; - return suggested_threads; - } else { - return 0; - } -} - DeviceMemoryBase GpuExecutor::Allocate(uint64_t size, int64_t memory_space) { if (memory_space == static_cast(stream_executor::MemoryType::kHost)) { From 89489e76b02a78963d5739393bdb3b93ae46c082 Mon Sep 17 00:00:00 2001 From: David Dunleavy Date: Tue, 25 Jun 2024 15:47:13 -0700 Subject: [PATCH 242/256] Add `xla/package_groups.bzl` and `xla/tsl/package_groups.bzl` to hold `package_groups` and replace Copybara rules PiperOrigin-RevId: 646633709 --- third_party/xla/opensource_only.files | 2 ++ third_party/xla/xla/BUILD | 39 ++-------------------- third_party/xla/xla/package_groups.bzl | 23 +++++++++++++ third_party/xla/xla/tests/BUILD | 11 ++---- third_party/xla/xla/tsl/BUILD | 33 ++---------------- third_party/xla/xla/tsl/package_groups.bzl | 7 ++++ 6 files changed, 41 insertions(+), 74 deletions(-) create mode 100644 third_party/xla/xla/package_groups.bzl create mode 100644 third_party/xla/xla/tsl/package_groups.bzl diff --git a/third_party/xla/opensource_only.files b/third_party/xla/opensource_only.files index 666a858f3e4bd9..baafd35265caaf 100644 --- a/third_party/xla/opensource_only.files +++ b/third_party/xla/opensource_only.files @@ -1,10 +1,12 @@ compiler/xla/mlir_hlo/WORKSPACE: +compiler/xla/package_groups.bzl: compiler/xla/stream_executor/build_defs.bzl: compiler/xla/tsl/cuda/stub.bzl: compiler/xla/tsl/mkl/BUILD: 
compiler/xla/tsl/mkl/LICENSE: compiler/xla/tsl/mkl/MKL_LICENSE: compiler/xla/tsl/mkl/build_defs.bzl: +compiler/xla/tsl/package_groups.bzl: third_party/BUILD: third_party/__init__:.py third_party/compute_library/BUILD: diff --git a/third_party/xla/xla/BUILD b/third_party/xla/xla/BUILD index 7b8bacdead5b0f..0b58e14e6b0a58 100644 --- a/third_party/xla/xla/BUILD +++ b/third_party/xla/xla/BUILD @@ -7,6 +7,7 @@ load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") load("//third_party/compute_library:build_defs.bzl", "if_enable_acl") # Placeholder: load py_proto_library +load("//xla:package_groups.bzl", "xla_package_groups") load("//xla:xla.bzl", "xla_cc_test", "xla_py_proto_library") load("//xla/tsl:tsl.bzl", "if_google", "internal_visibility") load("//xla/tsl:tsl.default.bzl", "filegroup", "get_compatible_with_portable") @@ -17,46 +18,12 @@ package( licenses = ["notice"], ) -package_group( - name = "friends", - includes = ["//xla:internal"], - packages = [ - # copybara:uncomment "//learning/...", - "//third_party/australis/...", - "//third_party/iree/...", - "//third_party/libxc/...", - "//third_party/mira/...", - "//third_party/mlcompass/...", - "//third_party/mlir_edge/model_curriculum/...", - "//third_party/openxla/shardonnay/...", - "//third_party/py/enzyme_ad/...", - "//third_party/py/jax/...", - "//third_party/py/t5x/...", - "//third_party/py/tpu_graphs/...", - "//tensorflow/compiler/...", - "//tensorflow/python/tpu/...", - ], -) - -package_group( - name = "internal", - packages = [ - "//xla/...", - ], -) - -package_group( - name = "runtime", - packages = [ - "//xla/runtime/...", - "//xla/service/gpu/runtime/...", - ], -) - exports_files([ "lit.cfg.py", ]) +xla_package_groups() + # Filegroup used to collect source files for dependency checking. 
filegroup( name = "c_srcs", diff --git a/third_party/xla/xla/package_groups.bzl b/third_party/xla/xla/package_groups.bzl new file mode 100644 index 00000000000000..ea554b3a8aa09d --- /dev/null +++ b/third_party/xla/xla/package_groups.bzl @@ -0,0 +1,23 @@ +"""XLA package_group definitions""" + +def xla_package_groups(name = "xla_package_groups"): + native.package_group( + name = "friends", + packages = ["//..."], + ) + + native.package_group( + name = "internal", + packages = ["//..."], + ) + + native.package_group( + name = "runtime", + packages = ["//..."], + ) + +def xla_tests_package_groups(name = "xla_tests_package_groups"): + native.package_group( + name = "friends", + packages = ["//..."], + ) diff --git a/third_party/xla/xla/tests/BUILD b/third_party/xla/xla/tests/BUILD index 7426ca4a1f5b41..c6b1a763480b26 100644 --- a/third_party/xla/xla/tests/BUILD +++ b/third_party/xla/xla/tests/BUILD @@ -11,6 +11,7 @@ load( "@local_tsl//tsl/platform/default:cuda_build_defs.bzl", "if_cuda_is_configured", ) +load("//xla:package_groups.bzl", "xla_tests_package_groups") load("//xla:xla.bzl", "tests_build_defs_bzl_deps", "xla_cc_binary", "xla_cc_test") load("//xla/tests:build_defs.bzl", "generate_backend_suites", "generate_backend_test_macros", "xla_test", "xla_test_library") load("//xla/tsl:tsl.bzl", "internal_visibility", "tsl_copts") @@ -22,14 +23,6 @@ package( licenses = ["notice"], ) -package_group( - name = "friends", - includes = [ - "//xla:friends", - ], - packages = ["//platforms/testing/tests/..."], -) - # Filegroup used to collect source files for dependency checking. filegroup( name = "c_srcs", @@ -39,6 +32,8 @@ filegroup( ]), ) +xla_tests_package_groups() + # Generate test_suites for all backends, named "${backend}_tests". 
generate_backend_suites() diff --git a/third_party/xla/xla/tsl/BUILD b/third_party/xla/xla/tsl/BUILD index 4d15019d7586f4..27a3e2ae8ab9bf 100644 --- a/third_party/xla/xla/tsl/BUILD +++ b/third_party/xla/xla/tsl/BUILD @@ -1,11 +1,14 @@ load("@bazel_skylib//:bzl_library.bzl", "bzl_library") load("@bazel_skylib//lib:selects.bzl", "selects") load("@bazel_skylib//rules:common_settings.bzl", "bool_flag", "bool_setting") +load("package_groups.bzl", "tsl_package_groups") load("tsl.bzl", "if_google", "if_oss") load("tsl.default.bzl", "tsl_google_bzl_deps") # copybara:uncomment package(default_applicable_licenses = ["//tensorflow:license"]) +tsl_package_groups() + # Config setting to use in select()s to distinguish open source build from # google internal build on configurable attributes. # @@ -497,36 +500,6 @@ config_setting( visibility = ["//visibility:public"], ) -# DO NOT ADD ANY NEW EXCEPTIONS TO THIS LIST! -# Instead, please use public APIs or public build rules TF provides. -# If you need functionality that is not exposed, we will work with you to expand our public APIs. -# TODO(b/173549186): Move Google-internal TF code out of learning/brain -# TODO(jakeharmon): Prune this for use in TSL -package_group( - name = "internal", - packages = [ - "//devtools/python/indexer/...", - "//learning/brain/keras/...", - "//learning/brain/mlir/...", - "//learning/brain/tfrt/...", - "//learning/lib/ami/simple_ml/...", - "//learning/pathways/...", - "//smartass/brain/configure/...", - "//tensorflow/...", - "//tensorflow_decision_forests/...", - "//tensorflow_federated/...", - "//third_party/cloud_tpu/convergence_tools/sdc_monitoring/...", - "//third_party/cloud_tpu/inference_converter/...", - "//third_party/py/cloud_ml_autoflow/...", - "//third_party/py/envlogger/...", - "//third_party/yggdrasil_decision_forests/...", - ] + if_google([ - # Needed in OSS, where bazel won't allow a package group to refer to an - # external repo. 
- "@local_tsl//tsl/...", - ]), -) - bzl_library( name = "tsl_bzl", srcs = ["tsl.bzl"], diff --git a/third_party/xla/xla/tsl/package_groups.bzl b/third_party/xla/xla/tsl/package_groups.bzl new file mode 100644 index 00000000000000..e4e44a7f4020f5 --- /dev/null +++ b/third_party/xla/xla/tsl/package_groups.bzl @@ -0,0 +1,7 @@ +"""TSL package_group definitions""" + +def tsl_package_groups(name = "tsl_package_groups"): + native.package_group( + name = "internal", + packages = ["//..."], + ) From 18a0eb07ef315a3a59a339b0b182fc05a81c28a3 Mon Sep 17 00:00:00 2001 From: Mohammed Anany Date: Tue, 25 Jun 2024 17:28:44 -0700 Subject: [PATCH 243/256] [Triton] Refactoring condition in autotuner to be more robust. Added test to make sure crashing Triton configurations are actually skipped and to guard against breaking it. PiperOrigin-RevId: 646663340 --- .../xla/service/gpu/gemm_fusion_autotuner.cc | 45 ++++++++++--------- .../service/gpu/gemm_fusion_autotuner_test.cc | 28 ++++++++++++ 2 files changed, 51 insertions(+), 22 deletions(-) diff --git a/third_party/xla/xla/service/gpu/gemm_fusion_autotuner.cc b/third_party/xla/xla/service/gpu/gemm_fusion_autotuner.cc index 25c1b5b3894c8f..6359e5f28dbcb5 100644 --- a/third_party/xla/xla/service/gpu/gemm_fusion_autotuner.cc +++ b/third_party/xla/xla/service/gpu/gemm_fusion_autotuner.cc @@ -312,11 +312,6 @@ absl::StatusOr GetLimits(const HloDotInstruction& dot) { const int max_k = tsl::NextPowerOfTwoS64( dot.operand(1)->shape().dimensions(contracting_index)); - // TODO(b/337839570): block_k = 16 is bugged in Triton for dots with 8-bit - // input. Setting minimum to 32 instead of 16 for these cases. - // TODO(b/337838200): Write the restriction on the minimum tile size to be - // generic. Currently we only handle the 8-bit case as this was the bug we - // ran into. 
return TileSizeLimit{ /*block_m=*/std::max(max_m, kMinTileSize), /*block_n=*/std::max(max_n, kMinTileSize), @@ -634,13 +629,20 @@ absl::StatusOr> GemmFusionAutotunerImpl::GenerateConfigs( absl::StatusOr> GemmFusionAutotunerImpl::GenerateTritonConfigs(const HloDotInstruction& dot) { - bool has_8_bit_operand = HloAnyOf({&dot}, [&](const HloInstruction* node) { - if (node->opcode() != HloOpcode::kConvert) { - return false; - } - auto in_type = node->operand(0)->shape().element_type(); - return primitive_util::BitWidth(in_type) == 8; - }); + // Retrieve the minimum bit-width participating in the dot. This is needed + // to avoid autotuning configurations that are not supported by Triton. This + // is used to restrict the values for tile_k. + std::vector converts = + HloFindAll({&dot}, [&](const HloInstruction* node) { + return node->opcode() == HloOpcode::kConvert; + }); + int minBitWidth = primitive_util::BitWidth(dot.shape().element_type()); + for (auto convert : converts) { + auto in_type = convert->operand(0)->shape().element_type(); + auto out_type = convert->shape().element_type(); + minBitWidth = std::min({minBitWidth, primitive_util::BitWidth(in_type), + primitive_util::BitWidth(out_type)}); + } std::vector result_configs; TF_ASSIGN_OR_RETURN(TileSizeLimit limits, GetLimits(dot)); @@ -690,14 +692,12 @@ GemmFusionAutotunerImpl::GenerateTritonConfigs(const HloDotInstruction& dot) { } config.split_k = std::min(config.split_k, max_split_k); - // TODO(b/337839570): block_k = 16 is bugged in Triton for dots with 8-bit - // input. Setting minimum to 32 instead of 16 for these cases. - // TODO(b/337838200): Write the restriction on the minimum tile size to be - // generic. Currently we only handle the 8-bit case as this was the bug we - // ran into. 
- if (has_8_bit_operand && config.block_k == kMinTileSize) { - config.block_k *= 2; - } + // TODO(b/337839570): Triton currently has a limitation where it crashes + // on small block_k values depending on the bit-width of the inputs to the + // dot. The logic below accounts for this limitation. + constexpr int kLdmatrixGranularity = 256; + config.block_k = + std::max(config.block_k, kLdmatrixGranularity / minBitWidth); // Sparse meta should have at least one element per thread. // Note: only 2:4 structured sparsity is currently supported. @@ -706,8 +706,9 @@ GemmFusionAutotunerImpl::GenerateTritonConfigs(const HloDotInstruction& dot) { config.block_m = std::max(config.block_m, 64); config.num_warps = std::max(config.num_warps, 4); } - config.block_k = - std::max(config.block_k, kMinTileSize * (has_8_bit_operand ? 4 : 2)); + config.block_k = std::max( + config.block_k, + 2 * std::max(kMinTileSize, kLdmatrixGranularity / minBitWidth)); int meta_elements = config.block_m * config.block_k / 16; config.num_warps = std::min(config.num_warps, meta_elements / WarpSize()); diff --git a/third_party/xla/xla/service/gpu/gemm_fusion_autotuner_test.cc b/third_party/xla/xla/service/gpu/gemm_fusion_autotuner_test.cc index fa75448abf42db..16410403233987 100644 --- a/third_party/xla/xla/service/gpu/gemm_fusion_autotuner_test.cc +++ b/third_party/xla/xla/service/gpu/gemm_fusion_autotuner_test.cc @@ -845,6 +845,34 @@ ENTRY e { )"); } +// TODO(b/337839570): Triton currently has a limitation where it crashes +// on small block_k values depending on the bit-width of the inputs to the +// dot. For this test case, it should skip any block_k values that are <= 16 +// since the smallest type has a bit-width of 8. 
+TEST_F(GemmFusionAutotunerExhaustiveTest, SkipsCrashingTileKConfig) { + std::unique_ptr module = ParseAndReturnVerifiedModule(R"( +HloModule module +ENTRY e { + x = s8[33,33]{1,0} parameter(0) + c = f16[33,33]{1,0} convert(x) + y = f16[33,33]{1,0} parameter(1) + ROOT out = f16[33,33]{1,0} dot(c, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} +} +)") + .value(); + const se::CudaComputeCapability compute_capability{ + se::CudaComputeCapability::AMPERE, /*minor=*/0}; + TF_ASSERT_OK_AND_ASSIGN( + const std::vector configs, + GetPossibleMatmulAutotuneConfigs( + *Cast( + module->entry_computation()->root_instruction()), + compute_capability, GetToolkitVersion(), GetDebugOptionsForTest())); + EXPECT_TRUE(std::all_of( + configs.begin(), configs.end(), + [](const TritonGemmConfig& config) { return config.block_k > 16; })); +} + class GemmFusionAutotunerDisableSplitK : public GemmFusionAutotunerTest { public: DebugOptions GetDebugOptionsForTest() override { From 633e9ccc0e93efeadffd572f8b969a54540df00b Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Tue, 25 Jun 2024 17:34:58 -0700 Subject: [PATCH 244/256] [xla:cpu] Optimize KernelThunk by passing SE_HOST_KernelArg directly to the kernel PiperOrigin-RevId: 646664927 --- .../select_and_scatter_benchmark_test.cc | 1 + .../xla/service/cpu/runtime/kernel_thunk.cc | 30 +++++++++-------- .../service/cpu/runtime/kernel_thunk_test.cc | 4 +-- .../xla/stream_executor/host/host_kernel.cc | 33 ++++++++++++++----- .../xla/stream_executor/host/host_kernel.h | 5 +++ .../stream_executor/host/host_kernel_c_api.h | 2 +- .../stream_executor/host/host_kernel_test.cc | 25 +++++++++----- 7 files changed, 65 insertions(+), 35 deletions(-) diff --git a/third_party/xla/xla/service/cpu/benchmarks/select_and_scatter_benchmark_test.cc b/third_party/xla/xla/service/cpu/benchmarks/select_and_scatter_benchmark_test.cc index 7521dcda5b0f86..bbc32250444b0f 100644 --- 
a/third_party/xla/xla/service/cpu/benchmarks/select_and_scatter_benchmark_test.cc +++ b/third_party/xla/xla/service/cpu/benchmarks/select_and_scatter_benchmark_test.cc @@ -73,6 +73,7 @@ static void BM_SelectAndScatterF32(benchmark::State& state) { BENCHMARK(BM_SelectAndScatterF32) ->MeasureProcessCPUTime() + ->Arg(64) ->Arg(128) ->Arg(256) ->Arg(512) diff --git a/third_party/xla/xla/service/cpu/runtime/kernel_thunk.cc b/third_party/xla/xla/service/cpu/runtime/kernel_thunk.cc index 247c995649a525..a8d793d1076071 100644 --- a/third_party/xla/xla/service/cpu/runtime/kernel_thunk.cc +++ b/third_party/xla/xla/service/cpu/runtime/kernel_thunk.cc @@ -87,38 +87,40 @@ tsl::AsyncValueRef KernelThunk::Execute( kernel_name_, arguments_buffers_.size(), results_buffers_.size(), thread_dim_.ToString()); - absl::InlinedVector buffers_data; - buffers_data.reserve(arguments_buffers_.size() + results_buffers_.size()); + absl::InlinedVector kernel_args; + kernel_args.reserve(arguments_buffers_.size() + results_buffers_.size()); int64_t arg_num = 0; for (BufferAllocation::Slice& buffer : arguments_buffers_) { - TF_ASSIGN_OR_RETURN(buffers_data.emplace_back(), + TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase arg_data, params.buffer_allocations->GetDeviceAddress(buffer)); + kernel_args.push_back( + SE_HOST_KernelArg{arg_data.opaque(), arg_data.size()}); VLOG(3) << absl::StreamFormat(" arg #%d: %s (%p)", arg_num++, - buffer.ToString(), - buffers_data.back().opaque()); + buffer.ToString(), kernel_args.back().data); } int64_t res_num = 0; for (BufferAllocation::Slice& buffer : results_buffers_) { - TF_ASSIGN_OR_RETURN(buffers_data.emplace_back(), + TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase result_data, params.buffer_allocations->GetDeviceAddress(buffer)); + kernel_args.push_back( + SE_HOST_KernelArg{result_data.opaque(), result_data.size()}); VLOG(3) << absl::StreamFormat(" res #%d: %s (%p)", res_num++, - buffer.ToString(), - buffers_data.back().opaque()); + buffer.ToString(), 
kernel_args.back().data); } // Check that all buffers are aligned to the minimum alignment. We codegen // with the assumption that all buffers are aligned, and if they are not, we // will crash with a segmentation fault, or worse, produce incorrect results. if (min_alignment_.has_value()) { - for (int64_t i = 0; i < buffers_data.size(); ++i) { - auto ptr = reinterpret_cast(buffers_data[i].opaque()); + for (int64_t i = 0; i < kernel_args.size(); ++i) { + auto ptr = reinterpret_cast(kernel_args[i].data); if (ABSL_PREDICT_FALSE((ptr & (*min_alignment_ - 1)) != 0)) { return Internal( "Host kernel %s buffer argument #%d (%p) is not aligned to a " "required minimum alignment of %d bytes", - info().op_name, i, buffers_data[i].opaque(), *min_alignment_); + info().op_name, i, kernel_args[i].data, *min_alignment_); } } } @@ -134,7 +136,7 @@ tsl::AsyncValueRef KernelThunk::Execute( params.host_kernels->Find(kernel_name_)); absl::MutexLock lock(&mutex_); - kernel_.emplace(buffers_data.size(), kernel_fn, nullptr); + kernel_.emplace(kernel_args.size(), kernel_fn, nullptr); kernel_ptr_.store(kernel = &kernel_.value()); } @@ -142,14 +144,14 @@ tsl::AsyncValueRef KernelThunk::Execute( // by scheduling tasks into it. HostKernel launch completion will // automatically signal KernelThunk execute completion. 
if (ABSL_PREDICT_FALSE(params.intra_op_threadpool && use_task_runner_)) { - return kernel->Launch(thread_dim_, buffers_data, + return kernel->Launch(thread_dim_, kernel_args, [¶ms](se::host::HostKernel::Task task) { params.intra_op_threadpool->getPool()->Schedule( ToCopyableTask(std::move(task))); }); } - TF_RETURN_IF_ERROR(kernel->Launch(thread_dim_, buffers_data)); + TF_RETURN_IF_ERROR(kernel->Launch(thread_dim_, kernel_args)); return OkExecuteEvent(); } diff --git a/third_party/xla/xla/service/cpu/runtime/kernel_thunk_test.cc b/third_party/xla/xla/service/cpu/runtime/kernel_thunk_test.cc index a80db35857e86f..9f6993dc9a73a7 100644 --- a/third_party/xla/xla/service/cpu/runtime/kernel_thunk_test.cc +++ b/third_party/xla/xla/service/cpu/runtime/kernel_thunk_test.cc @@ -40,8 +40,8 @@ class AddF32HostKernels : public Thunk::HostKernels { public: absl::StatusOr Find(std::string_view name) override { return +[](const SE_HOST_KernelCallFrame* call_frame) { - SE_HOST_KernelArg& in = call_frame->args[0]; - SE_HOST_KernelArg& out = call_frame->args[1]; + const SE_HOST_KernelArg& in = call_frame->args[0]; + const SE_HOST_KernelArg& out = call_frame->args[1]; float* in_ptr = reinterpret_cast(in.data); float* out_ptr = reinterpret_cast(out.data); diff --git a/third_party/xla/xla/stream_executor/host/host_kernel.cc b/third_party/xla/xla/stream_executor/host/host_kernel.cc index ceb148cdaaf918..04586b5272432b 100644 --- a/third_party/xla/xla/stream_executor/host/host_kernel.cc +++ b/third_party/xla/xla/stream_executor/host/host_kernel.cc @@ -69,7 +69,7 @@ class HostKernelExecuteState HostKernelExecuteState(HostKernel::TaskRunner task_runner, HostKernel::KernelFunction* function, ThreadDim thread_dims, - absl::Span buffers); + absl::Span args); // Notify of a completion of a host kernel task. 
void Notify(absl::Status status); @@ -118,11 +118,19 @@ HostKernel::HostKernel(unsigned arity, SE_HOST_Kernel* kernel, absl::Status HostKernel::Launch( const ThreadDim& thread_dims, absl::Span buffers) const { - SE_HOST_KernelThreadDim kernel_thread_dims = {thread_dims.x, thread_dims.y, - thread_dims.z}; + return Launch(thread_dims, ConvertBuffersToKernelArgs(buffers)); +} + +absl::Status HostKernel::Launch( + const ThreadDim& thread_dims, + absl::Span args) const { + SE_HOST_KernelThreadDim kernel_thread_dims = { + thread_dims.x, + thread_dims.y, + thread_dims.z, + }; SE_HOST_Kernel* kernel = function_->kernel(); - auto args = ConvertBuffersToKernelArgs(buffers); for (uint64_t z = 0; z < thread_dims.z; ++z) { for (uint64_t y = 0; y < thread_dims.y; ++y) { @@ -134,7 +142,7 @@ absl::Status HostKernel::Launch( SE_HOST_KernelError* error = (*kernel)(&call_frame); - if (error != nullptr) { + if (ABSL_PREDICT_FALSE(error != nullptr)) { return absl::InternalError("Failed to call host kernel"); } } @@ -147,12 +155,19 @@ absl::Status HostKernel::Launch( tsl::AsyncValueRef HostKernel::Launch( const ThreadDim& thread_dims, absl::Span buffers, TaskRunner task_runner) const { + return Launch(thread_dims, ConvertBuffersToKernelArgs(buffers), + std::move(task_runner)); +} + +tsl::AsyncValueRef HostKernel::Launch( + const ThreadDim& thread_dims, absl::Span args, + TaskRunner task_runner) const { size_t num_tasks = thread_dims.x * thread_dims.y * thread_dims.z; CHECK_GT(num_tasks, 0) << "Number of tasks must be positive"; // Crash Ok // Short-circuit launch with a single task and run it in the caller thread. if (ABSL_PREDICT_TRUE(num_tasks == 1)) { - absl::Status launched = Launch(thread_dims, buffers); + absl::Status launched = Launch(thread_dims, args); return ABSL_PREDICT_TRUE(launched.ok()) ? 
OkLaunchEvent() : tsl::MakeErrorAsyncValueRef(std::move(launched)); @@ -160,7 +175,7 @@ tsl::AsyncValueRef HostKernel::Launch( // Allocate a control structure that will orchestrate kernel execution. auto state = tsl::MakeRef( - std::move(task_runner), function_.get(), thread_dims, buffers); + std::move(task_runner), function_.get(), thread_dims, args); state->CallAsync(/*start_index=*/0, /*end_index=*/num_tasks); @@ -169,12 +184,12 @@ tsl::AsyncValueRef HostKernel::Launch( HostKernelExecuteState::HostKernelExecuteState( HostKernel::TaskRunner task_runner, HostKernel::KernelFunction* function, - ThreadDim thread_dims, absl::Span buffers) + ThreadDim thread_dims, absl::Span args) : task_runner_(std::move(task_runner)), num_tasks_(thread_dims.x * thread_dims.y * thread_dims.z), kernel_(function->kernel()), thread_dims_({thread_dims.x, thread_dims.y, thread_dims.z}), - args_(ConvertBuffersToKernelArgs(buffers)), + args_(args.begin(), args.end()), abort_(false), counter_(num_tasks_), event_(tsl::MakeConstructedAsyncValueRef()) {} diff --git a/third_party/xla/xla/stream_executor/host/host_kernel.h b/third_party/xla/xla/stream_executor/host/host_kernel.h index e8e040a2d86173..9d278b2b79c357 100644 --- a/third_party/xla/xla/stream_executor/host/host_kernel.h +++ b/third_party/xla/xla/stream_executor/host/host_kernel.h @@ -80,6 +80,8 @@ class HostKernel : public Kernel { // `thread_dims` and calling the kernel function. 
absl::Status Launch(const ThreadDim& thread_dims, absl::Span buffers) const; + absl::Status Launch(const ThreadDim& thread_dims, + absl::Span args) const; // Launches the kernel by iterating over all threads in `thread_dims` and // calling `task_runner` to run individual task (implementation might decide @@ -93,6 +95,9 @@ class HostKernel : public Kernel { tsl::AsyncValueRef Launch( const ThreadDim& thread_dims, absl::Span buffers, TaskRunner task_runner) const; + tsl::AsyncValueRef Launch( + const ThreadDim& thread_dims, absl::Span args, + TaskRunner task_runner) const; // For host platform, we assume that a core is a thread, and we can run at // most one instance of a kernel on a given thread. diff --git a/third_party/xla/xla/stream_executor/host/host_kernel_c_api.h b/third_party/xla/xla/stream_executor/host/host_kernel_c_api.h index 6768706abc2800..30f710cb44b264 100644 --- a/third_party/xla/xla/stream_executor/host/host_kernel_c_api.h +++ b/third_party/xla/xla/stream_executor/host/host_kernel_c_api.h @@ -71,7 +71,7 @@ typedef struct SE_HOST_KernelCallFrame { SE_HOST_KernelThread* thread; size_t num_args; - SE_HOST_KernelArg* args; + const SE_HOST_KernelArg* args; } SE_HOST_KernelCallFrame; // Error reporting for host kernels. NULL means success. 
diff --git a/third_party/xla/xla/stream_executor/host/host_kernel_test.cc b/third_party/xla/xla/stream_executor/host/host_kernel_test.cc index 5a121bf17cb5b7..aff9e1ed19ce7b 100644 --- a/third_party/xla/xla/stream_executor/host/host_kernel_test.cc +++ b/third_party/xla/xla/stream_executor/host/host_kernel_test.cc @@ -53,9 +53,9 @@ static auto ToCopyableTask(HostKernel::Task task) { } static SE_HOST_KernelError* AddI32(const SE_HOST_KernelCallFrame* call_frame) { - SE_HOST_KernelArg& lhs = call_frame->args[0]; - SE_HOST_KernelArg& rhs = call_frame->args[1]; - SE_HOST_KernelArg& out = call_frame->args[2]; + const SE_HOST_KernelArg& lhs = call_frame->args[0]; + const SE_HOST_KernelArg& rhs = call_frame->args[1]; + const SE_HOST_KernelArg& out = call_frame->args[2]; int32_t* lhs_ptr = reinterpret_cast(lhs.data); int32_t* rhs_ptr = reinterpret_cast(rhs.data); @@ -217,7 +217,9 @@ TEST(HostKernelTest, LaunchAsync) { }; HostKernel host_kernel(/*arity=*/0, no_op); - auto event = host_kernel.Launch(ThreadDim(4, 4, 4), {}, std::move(runner)); + auto event = host_kernel.Launch(ThreadDim(4, 4, 4), + absl::Span(), + std::move(runner)); tsl::BlockUntilReady(event); EXPECT_TRUE(event.IsConcrete()); @@ -245,7 +247,9 @@ TEST(HostKernelTest, LaunchAsyncError) { }; HostKernel host_kernel(/*arity=*/0, maybe_error); - auto event = host_kernel.Launch(ThreadDim(4, 4, 4), {}, std::move(runner)); + auto event = host_kernel.Launch(ThreadDim(4, 4, 4), + absl::Span(), + std::move(runner)); tsl::BlockUntilReady(event); ASSERT_TRUE(event.IsError()); @@ -269,7 +273,8 @@ static void BM_HostKernelSyncLaunch(benchmark::State& state) { HostKernel kernel(/*arity=*/0, NoOp); for (auto _ : state) { - benchmark::DoNotOptimize(kernel.Launch(ThreadDim(tdim_x), /*buffers=*/{})); + benchmark::DoNotOptimize(kernel.Launch( + ThreadDim(tdim_x), absl::Span())); } } @@ -281,9 +286,11 @@ static void BM_HostKernelAsyncLaunch(benchmark::State& state) { HostKernel kernel(/*arity=*/0, NoOp); for (auto _ : state) { - 
auto event = kernel.Launch(ThreadDim(tdim_x), {}, [&](auto task) { - thread_pool->Schedule(ToCopyableTask(std::move(task))); - }); + auto event = + kernel.Launch(ThreadDim(tdim_x), absl::Span(), + [&](auto task) { + thread_pool->Schedule(ToCopyableTask(std::move(task))); + }); tsl::BlockUntilReady(event); } } From 43a5d4f5d8a067f3d4e2343d7ab757223fd92ae3 Mon Sep 17 00:00:00 2001 From: TJ Xu Date: Tue, 25 Jun 2024 17:38:55 -0700 Subject: [PATCH 245/256] PR #13813: [NVIDIA GPU] Assign a fixed index for cached activation Imported from GitHub PR https://github.com/openxla/xla/pull/13813 gpu_windowed_einsum_handler pass has been re-using the empty buffer of the transformed while loop. This buffer is given by the spmd dot_handler pass. The shape of the buffer has changed from the allgathered shape of the sharded operand to the output shape of the dot which leads to a shape incompatibility error. To make the gpu handler completely safe, we will make a new element in the tuple to host the cached activation with the desired shape. The slice index of where to write the slice into the full buffer also changes based on whether it's contracting or non-contracting dim is sharded. With the new element, we will need to determine the slice index ourselves in the handler pass. 
Copybara import of the project: -- ceeff8e5da8ecb3f382bbd8dee83e2f0c909b22d by TJ Xu : Assign a fixed index for cached activation Cache correct activation slice when contracting dim is sharded -- 233763b8efb4ab0045eb998b437c7b28c8f776c8 by TJ Xu : Simplified logic in gpu einsum handler to be more generic -- 2220cd1a022ad519cd23ab36c31c70c9627fc76d by TJ Xu : remove un-used variables Merging this change closes #13813 PiperOrigin-RevId: 646666635 --- .../gpu/gpu_windowed_einsum_handler.cc | 149 +++++++++++++++--- .../gpu/gpu_windowed_einsum_handler_test.cc | 86 +++++++++- 2 files changed, 206 insertions(+), 29 deletions(-) diff --git a/third_party/xla/xla/service/gpu/gpu_windowed_einsum_handler.cc b/third_party/xla/xla/service/gpu/gpu_windowed_einsum_handler.cc index ce1dfa4f1c7863..8f5e26124f24a4 100644 --- a/third_party/xla/xla/service/gpu/gpu_windowed_einsum_handler.cc +++ b/third_party/xla/xla/service/gpu/gpu_windowed_einsum_handler.cc @@ -378,8 +378,16 @@ absl::StatusOr HandleAgWindowedEinsumLoop(HloComputation* comp, return changed; } +static int64_t GetAgActivationCacheIndex(const HloInstruction* while_loop) { + const HloInstruction* loop_tuple = while_loop->operand(0); + const Shape& tuple_shape = loop_tuple->shape(); + CHECK(tuple_shape.IsTuple()); + return tuple_shape.tuple_shapes_size(); +} + absl::Status ProcessWindowedEinsumLoopForActivationCaching( - GpuWindowedEinsumHandler::WindowedEinsumAgLoops& ag_loop) { + GpuWindowedEinsumHandler::WindowedEinsumAgLoops& ag_loop, + HloInstruction* ag_with_shared_operand) { HloInstruction* loop = ag_loop.loop; // Transform the while body to cache the allgathered result in the // output buffer to be consumed by the dot @@ -392,15 +400,61 @@ absl::Status ProcessWindowedEinsumLoopForActivationCaching( } // Get the output operand of the full buffer. HloInstruction* root = while_body->root_instruction(); + // Change loop body to include the new input and output element. 
+ HloInstruction* input_tuple = while_body->parameter_instruction(0); + const Shape& input_shape = input_tuple->shape(); // The full buffer that we will use to cache the accumulated activation - // is the 4th operand in the output tuple. - int64_t full_cache_buffer_index = 3; + // is the last operand in the output tuple. + int64_t full_cache_buffer_index = GetAgActivationCacheIndex(loop); + std::vector new_input_shapes(input_shape.tuple_shapes().begin(), + input_shape.tuple_shapes().end()); + new_input_shapes.push_back(ag_with_shared_operand->shape()); + // Update body input shape + Shape new_input_shape = ShapeUtil::MakeTupleShape(new_input_shapes); + *input_tuple->mutable_shape() = new_input_shape; HloInstruction* full_buffer_output_gte = - root->mutable_operand(full_cache_buffer_index); - HloInstruction* new_full_buffer_output; + while_body->AddInstruction(HloInstruction::CreateGetTupleElement( + ag_with_shared_operand->shape(), input_tuple, + full_cache_buffer_index)); + + // Update condition input shape + HloComputation* cond_comp = loop->while_condition(); + HloInstruction* cond_input_tuple = cond_comp->parameter_instruction(0); + *cond_input_tuple->mutable_shape() = new_input_shape; + + // Update input to the while instruction in parent computation + HloInstruction* original_while_input = loop->mutable_operand(0); + HloComputation* parent_comp = loop->parent(); + std::vector new_operands( + original_while_input->operands().begin(), + original_while_input->operands().end()); + new_operands.push_back( + parent_comp->AddInstruction(HloInstruction::CreateBroadcast( + ag_with_shared_operand->shape(), + parent_comp->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::Zero(new_input_shapes[0].element_type()))), + {}))); + HloInstruction* new_while_input = + parent_comp->AddInstruction(HloInstruction::CreateTuple(new_operands)); + TF_RETURN_IF_ERROR( + loop->ReplaceOperandWithDifferentShape(0, new_while_input)); + 
TF_RETURN_IF_ERROR(parent_comp->ReplaceInstructionWithDifferentShape( + original_while_input, new_while_input)); + *loop->mutable_shape() = new_input_shape; + + HloInstruction* new_full_buffer_output = nullptr; // Find the DUS in the loop body and re-use the slice indices // This should just be a constant(0) HloInstruction* dus_boundary_constant; + // The slice we need this time is the output of the first + // collective-permute + HloInstruction* first_cp_output; + for (HloInstruction* gte_user : input_gte->users()) { + if (gte_user->opcode() == HloOpcode::kCollectivePermute) { + first_cp_output = gte_user; + break; + } + } for (HloInstruction* inst : while_body->MakeInstructionPostOrder()) { HloInstruction* slice_indices; // If we have a DUS(PARAM,DS) pattern, we need to update the output @@ -434,24 +488,68 @@ absl::Status ProcessWindowedEinsumLoopForActivationCaching( dus_boundary_constant->shape(), slice_indices)); VLOG(5) << "Created slice op for second slice: " << slice_indices->ToString(); - // The slice we need this time is the output of the first - // collective-permute - HloInstruction* cp_output; - for (HloInstruction* gte_user : input_gte->users()) { - if (gte_user->opcode() == HloOpcode::kCollectivePermute) { - cp_output = gte_user; - break; - } - } new_full_buffer_output = while_body->AddInstruction(HloInstruction::CreateDynamicUpdateSlice( full_buffer_output_gte->shape(), full_buffer_output_gte, - cp_output, + first_cp_output, {dus_boundary_constant, slice_indices, dus_boundary_constant})); } + + // If we have a Dot(DS(parameter_index1)), then operands are sharded along + // the contracting dim. Slice indices will be the contracting dim's slices. 
+ HloInstruction* slice_index; + HloInstruction* ds_index_constant; + HloInstruction* remainder; + HloInstruction* ds_param; + // There will be 2 dynamic-slices for unrolled loops, match for each one to + // get the slice index which will be used to write the corresponding + // received shard into cached activation buffer. For unrolled loops, we need + // to write to the final buffer twice per iteration, so we need to match for + // the correct slice index based on each DS. + if (Match(inst, m::Dot(m::Op(), m::DynamicSlice(&ds_param))) && + Match(ds_param->operand(0), m::GetTupleElement(m::Parameter(), 1))) { + for (int64_t ds_op_i = 1; ds_op_i < ds_param->operands().size(); + ds_op_i++) { + if (!Match( + ds_param->mutable_operand(ds_op_i), + m::Reshape(&slice_index, m::DynamicSlice(m::Constant(), + m::Op(&remainder)))) && + !Match(ds_param->mutable_operand(ds_op_i), + m::Constant(&ds_index_constant))) { + return absl::OkStatus(); + } + } + // First DS has slice index calculated based on loop iterator + // Remainder(add(gte, partition_id)) + if (Match(remainder, + m::Remainder(m::Add(m::GetTupleElement(), m::Op()), m::Op()))) { + full_buffer_output_gte = + while_body->AddInstruction(HloInstruction::CreateDynamicUpdateSlice( + full_buffer_output_gte->shape(), full_buffer_output_gte, + input_gte, + {ds_index_constant, ds_index_constant, slice_index})); + } + // Second DS has slice index calculated based on loop iterator+1 hence + // Remainder(add(add(gte, 1), partition_id)) + if (Match(remainder, + m::Remainder( + m::Add(m::Add(m::GetTupleElement(), m::Op()), m::Op()), + m::Op()))) { + new_full_buffer_output = + while_body->AddInstruction(HloInstruction::CreateDynamicUpdateSlice( + full_buffer_output_gte->shape(), full_buffer_output_gte, + first_cp_output, + {ds_index_constant, ds_index_constant, slice_index})); + } + } } - TF_RETURN_IF_ERROR(root->ReplaceOperandWith(full_cache_buffer_index, - new_full_buffer_output)); + std::vector 
original_operands(root->operands().begin(), + root->operands().end()); + original_operands.push_back(new_full_buffer_output); + HloInstruction* new_output_tuple = while_body->AddInstruction( + HloInstruction::CreateTuple(original_operands)); + TF_RETURN_IF_ERROR( + while_body->ReplaceInstructionWithDifferentShape(root, new_output_tuple)); return absl::OkStatus(); } @@ -620,17 +718,20 @@ class WindowedEinsumVisitor : public DfsHloRewriteVisitor { VLOG(5) << "Found all-gather that shares the same operand with a " "windowed einsum loop : " << loop->ToString(); + + if (!ag_loop.consumed) { + TF_RETURN_IF_ERROR(ProcessWindowedEinsumLoopForActivationCaching( + ag_loop, ag_with_shared_operand)); + ag_loop.consumed = true; + } int64_t cache_output_index = dot->operand_index(ag_with_shared_operand); - HloInstruction* new_gte = comp->AddInstruction( - HloInstruction::CreateGetTupleElement(loop, 3)); + HloComputation* comp = dot->parent(); + HloInstruction* new_gte = + comp->AddInstruction(HloInstruction::CreateGetTupleElement( + loop, GetAgActivationCacheIndex(loop) - 1)); TF_RETURN_IF_ERROR( dot->ReplaceOperandWith(cache_output_index, new_gte)); TF_RETURN_IF_ERROR(comp->RemoveInstruction(ag_with_shared_operand)); - if (!ag_loop.consumed) { - TF_RETURN_IF_ERROR( - ProcessWindowedEinsumLoopForActivationCaching(ag_loop)); - ag_loop.consumed = true; - } } } // Rewrites an all-to-all+gemm into multiple independent partial a2a+gemms diff --git a/third_party/xla/xla/service/gpu/gpu_windowed_einsum_handler_test.cc b/third_party/xla/xla/service/gpu/gpu_windowed_einsum_handler_test.cc index 23257e1c71a34b..6f23319980e90c 100644 --- a/third_party/xla/xla/service/gpu/gpu_windowed_einsum_handler_test.cc +++ b/third_party/xla/xla/service/gpu/gpu_windowed_einsum_handler_test.cc @@ -269,23 +269,22 @@ ENTRY main.12_spmd { FindInstructionByName(module->entry_computation(), "dot.7"); // dot.7 should now consume output of the windowed einsum while loop. 
EXPECT_EQ(inst->operand(0)->opcode(), HloOpcode::kGetTupleElement); - EXPECT_EQ(inst->operand(0)->tuple_index(), 3); + EXPECT_EQ(inst->operand(0)->tuple_index(), 5); EXPECT_EQ(inst->operand(0)->operand(0), ag_loop); // while loop's root should now have a chain of DUS. HloInstruction* ag_while_root = ag_loop->while_body()->root_instruction(); EXPECT_THAT(ag_while_root, GmockMatch(m::Tuple( - m::Op(), m::Op(), m::Op(), + m::Op(), m::Op(), m::Op(), m::Op(), m::Op(), m::DynamicUpdateSlice( m::DynamicUpdateSlice( m::GetTupleElement(m::Parameter()) .WithPredicate([](const HloInstruction* instr) { - return instr->tuple_index() == 3; + return instr->tuple_index() == 5; }), m::Op(), m::Op(), m::Op(), m::Op()), - m::Op(), m::Op(), m::Op(), m::Op()), - m::Op()))); + m::Op(), m::Op(), m::Op(), m::Op())))); } TEST_F(GpuWindowedEinsumHanlderTest, A2aGemmHaveStreamIds) { constexpr absl::string_view kHloString = R"( @@ -838,5 +837,82 @@ ENTRY main.9_spmd { )"); } +TEST_F(GpuWindowedEinsumHanlderTest, + AgLoopsMultipleConsumersAreChainedWithShardedContratingDim) { + constexpr absl::string_view kHloString = R"( +HloModule pjit__unnamed_wrapped_function_, entry_computation_layout={(bf16[16,2048,512]{2,1,0}, bf16[4096,6288]{1,0}, bf16[16,2048,6288]{2,1,0})->bf16[4096,6288]{1,0}}, num_partitions=8 + +windowed_dot_general_body_ag { + param.195 = (bf16[16,2048,512]{2,1,0}, bf16[4096,6288]{1,0}, bf16[16,2048,6288]{2,1,0}, bf16[16,2048,6288]{2,1,0}, u32[]) parameter(0) + get-tuple-element.588 = bf16[16,2048,512]{2,1,0} get-tuple-element(param.195), index=0 + collective-permute.194 = bf16[16,2048,512]{2,1,0} collective-permute(get-tuple-element.588), channel_id=446, source_target_pairs={{0,7},{1,0},{2,1},{3,2},{4,3},{5,4},{6,5},{7,6}} + collective-permute.195 = bf16[16,2048,512]{2,1,0} collective-permute(collective-permute.194), channel_id=447, source_target_pairs={{0,7},{1,0},{2,1},{3,2},{4,3},{5,4},{6,5},{7,6}} + get-tuple-element.589 = bf16[4096,6288]{1,0} get-tuple-element(param.195), 
index=1 + get-tuple-element.590 = bf16[16,2048,6288]{2,1,0} get-tuple-element(param.195), index=2 + constant.11432 = s32[8]{0} constant({0, 512, 1024, 1536, 2048, 2560, 3072, 3584}) + get-tuple-element.592 = u32[] get-tuple-element(param.195), index=4 + partition-id.194 = u32[] partition-id() + add.4309 = u32[] add(get-tuple-element.592, partition-id.194) + constant.11431 = u32[] constant(8) + remainder.194 = u32[] remainder(add.4309, constant.11431) + dynamic-slice.388 = s32[1]{0} dynamic-slice(constant.11432, remainder.194), dynamic_slice_sizes={1} + reshape.12959 = s32[] reshape(dynamic-slice.388) + constant.11433 = s32[] constant(0) + dynamic-slice.389 = bf16[512,6288]{1,0} dynamic-slice(get-tuple-element.589, reshape.12959, constant.11433), dynamic_slice_sizes={512,6288} + dot.244 = bf16[16,2048,6288]{2,1,0} dot(get-tuple-element.588, dynamic-slice.389), lhs_contracting_dims={2}, rhs_contracting_dims={0} + add.4310 = bf16[16,2048,6288]{2,1,0} add(get-tuple-element.590, dot.244) + constant.11434 = u32[] constant(1) + add.4312 = u32[] add(get-tuple-element.592, constant.11434) + add.4313 = u32[] add(add.4312, partition-id.194) + remainder.195 = u32[] remainder(add.4313, constant.11431) + dynamic-slice.390 = s32[1]{0} dynamic-slice(constant.11432, remainder.195), dynamic_slice_sizes={1} + reshape.12960 = s32[] reshape(dynamic-slice.390) + dynamic-slice.391 = bf16[512,6288]{1,0} dynamic-slice(get-tuple-element.589, reshape.12960, constant.11433), dynamic_slice_sizes={512,6288} + dot.245 = bf16[16,2048,6288]{2,1,0} dot(collective-permute.194, dynamic-slice.391), lhs_contracting_dims={2}, rhs_contracting_dims={0} + add.4314 = bf16[16,2048,6288]{2,1,0} add(add.4310, dot.245) + get-tuple-element.591 = bf16[16,2048,6288]{2,1,0} get-tuple-element(param.195), index=3 + add.4315 = u32[] add(add.4312, constant.11434) + ROOT tuple.98 = (bf16[16,2048,512]{2,1,0}, bf16[4096,6288]{1,0}, bf16[16,2048,6288]{2,1,0}, bf16[16,2048,6288]{2,1,0}, u32[]) tuple(collective-permute.195, 
get-tuple-element.589, add.4314, get-tuple-element.591, add.4315) +} // windowed_dot_general_body_ag + +windowed_dot_general_cond_ag { + param = (bf16[16,2048,512]{2,1,0}, bf16[4096,6288]{1,0}, bf16[16,2048,6288]{2,1,0}, bf16[16,2048,6288]{2,1,0}, u32[]) parameter(0) + get-tuple-element = u32[] get-tuple-element(param), index=4 + constant = u32[] constant(4) + ROOT compare = pred[] compare(get-tuple-element, constant), direction=LT +} + +ENTRY main.12_spmd { + param.4 = bf16[16,2048,512]{2,1,0} parameter(0) + param.5 = bf16[4096,6288]{1,0} parameter(1) + constant.22 = bf16[] constant(0) + broadcast = bf16[16,2048,6288]{2,1,0} broadcast(constant.22), dimensions={} + constant.24 = u32[] constant(0) + tuple.2 = (bf16[16,2048,512]{2,1,0}, bf16[4096,6288]{1,0}, bf16[16,2048,6288]{2,1,0}, bf16[16,2048,6288]{2,1,0}, u32[]) tuple(param.4, param.5, broadcast, broadcast, constant.24) + while = (bf16[16,2048,512]{2,1,0}, bf16[4096,6288]{1,0}, bf16[16,2048,6288]{2,1,0}, bf16[16,2048,6288]{2,1,0}, u32[]) while(tuple.2), condition=windowed_dot_general_cond_ag, body=windowed_dot_general_body_ag + get-tuple-element.13 = bf16[16,2048,6288]{2,1,0} get-tuple-element(while), index=2 + all-gather = bf16[16,2048,4096]{2,1,0} all-gather(param.4), channel_id=1, replica_groups={{0,1,2,3,4,5,6,7}}, dimensions={2}, use_global_device_ids=true + param.6 = bf16[16,2048,6288]{2,1,0} parameter(2) + ROOT dot.7 = bf16[4096,6288]{1,0} dot(all-gather, param.6), lhs_contracting_dims={0,1}, rhs_contracting_dims={0,1} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloString)); + + GpuWindowedEinsumHandler gpu_handler; + bool changed; + TF_ASSERT_OK_AND_ASSIGN(changed, gpu_handler.Run(module.get())); + EXPECT_TRUE(changed); + + HloInstruction* ag_loop = + FindInstructionByName(module->entry_computation(), "while"); + HloInstruction* inst = + FindInstructionByName(module->entry_computation(), "dot.7"); + // dot.7 should now consume output of the windowed einsum 
while loop. + EXPECT_EQ(inst->operand(0)->opcode(), HloOpcode::kGetTupleElement); + EXPECT_EQ(inst->operand(0)->tuple_index(), 5); + EXPECT_EQ(inst->operand(0)->operand(0), ag_loop); +} } // namespace } // namespace xla::gpu From 7343933df4f96affee731371c674782409677fa3 Mon Sep 17 00:00:00 2001 From: Alexander Pivovarov Date: Tue, 25 Jun 2024 17:40:21 -0700 Subject: [PATCH 246/256] PR #14073: Add select(compare(a, b, GT/GE), a, b) => or(a, b) to algsimp MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Imported from GitHub PR https://github.com/openxla/xla/pull/14073 In one of the customer's HLOs (in the reduceOp computation function), I found the following pattern: ```c p0 = pred[]{0} parameter(0) p1 = pred[]{0} parameter(1) compare = pred[]{0} compare(p0, p1), direction=GT select = pred[]{0} select(compare, p0, p1) ``` It can be simplified to `logical_or`. This PR adds the following patterns to algsimp ```c select(compare(a, b, GT/GE), a, b) => or(a, b) select(compare(a, b, LT/LE), a, b) => and(a, b) select(compare(a, b, EQ), a, b) => b select(compare(a, b, NE), a, b) => a a,b ∈ PRED ``` Copybara import of the project: -- 6fe68d7319b272ff041b67e038359540cddda489 by Alexander Pivovarov : Add select(compare(a, b, GT/GE), a, b) => or(a, b) to algsimp Merging this change closes #14073 PiperOrigin-RevId: 646667024 --- .../xla/xla/service/algebraic_simplifier.cc | 27 ++++++ .../xla/service/algebraic_simplifier_test.cc | 89 +++++++++++++++++++ 2 files changed, 116 insertions(+) diff --git a/third_party/xla/xla/service/algebraic_simplifier.cc b/third_party/xla/xla/service/algebraic_simplifier.cc index dd13ffebd9e852..641fedf0c72405 100644 --- a/third_party/xla/xla/service/algebraic_simplifier.cc +++ b/third_party/xla/xla/service/algebraic_simplifier.cc @@ -8188,6 +8188,33 @@ absl::Status AlgebraicSimplifierVisitor::HandleSelect(HloInstruction* select) { select->mutable_operand(0)->shape(), HloOpcode::kNot, 
select->mutable_operand(0))); } + // select(compare(a, b, GT/GE), a, b) => or(a, b) + // select(compare(a, b, LT/LE), a, b) => and(a, b) + // select(compare(a, b, EQ), a, b) => b + // select(compare(a, b, NE), a, b) => a + HloInstruction *compare, *lhs, *rhs; + if (Match(select, m::Select(m::Op(&compare), m::Op(&lhs), m::Op(&rhs))) && + Match(compare, m::Compare(m::Op().Is(lhs), m::Op().Is(rhs)))) { + auto cmp_dir = compare->comparison_direction(); + if (cmp_dir == ComparisonDirection::kGt || + cmp_dir == ComparisonDirection::kGe) { + return ReplaceWithNewInstruction( + select, HloInstruction::CreateBinary(select->shape(), + HloOpcode::kOr, lhs, rhs)); + } + if (cmp_dir == ComparisonDirection::kLt || + cmp_dir == ComparisonDirection::kLe) { + return ReplaceWithNewInstruction( + select, HloInstruction::CreateBinary(select->shape(), + HloOpcode::kAnd, lhs, rhs)); + } + if (cmp_dir == ComparisonDirection::kEq) { + return ReplaceInstruction(select, rhs); + } + if (cmp_dir == ComparisonDirection::kNe) { + return ReplaceInstruction(select, lhs); + } + } } // select(pred, xs, dynamic_update_slice(xs, x, i)) diff --git a/third_party/xla/xla/service/algebraic_simplifier_test.cc b/third_party/xla/xla/service/algebraic_simplifier_test.cc index 00970a51546b1a..921098aa7565e8 100644 --- a/third_party/xla/xla/service/algebraic_simplifier_test.cc +++ b/third_party/xla/xla/service/algebraic_simplifier_test.cc @@ -736,6 +736,95 @@ TEST_F(AlgebraicSimplifierTest, SelectPredPred2) { GmockMatch(m::Not(m::Parameter(0)))); } +// select(compare(a, b, GT/GE), a, b) => or(a, b), a,b ∈ PRED +TEST_F(AlgebraicSimplifierTest, SelectGtCompare) { + for (const auto cmp_dir : {"GT", "GE"}) { + const auto kModuleStr = absl::StrFormat(R"( + HloModule m + test { + p0 = pred[8]{0} parameter(0) + p1 = pred[8]{0} parameter(1) + compare = pred[8]{0} compare(p0, p1), direction=%s + ROOT select = pred[8]{0} select(compare, p0, p1) + } + )", + cmp_dir); + TF_ASSERT_OK_AND_ASSIGN(auto m, 
ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).value()); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::Or(m::Parameter(0), m::Parameter(1)))); + } +} + +// select(compare(a, b, LT/LE), a, b) => and(a, b), a,b ∈ PRED +TEST_F(AlgebraicSimplifierTest, SelectLtCompare) { + for (const auto cmp_dir : {"LT", "LE"}) { + const auto kModuleStr = absl::StrFormat(R"( + HloModule m + test { + p0 = pred[8]{0} parameter(0) + p1 = pred[8]{0} parameter(1) + compare = pred[8]{0} compare(p0, p1), direction=%s + ROOT select = pred[8]{0} select(compare, p0, p1) + } + )", + cmp_dir); + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).value()); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::And(m::Parameter(0), m::Parameter(1)))); + } +} + +// select(compare(a, b, EQ), a, b) => b, a,b ∈ PRED +TEST_F(AlgebraicSimplifierTest, SelectEqCompare) { + const char* kModuleStr = R"( + HloModule m + test { + p0 = pred[8]{0} parameter(0) + p1 = pred[8]{0} parameter(1) + compare = pred[8]{0} compare(p0, p1), direction=EQ + ROOT select = pred[8]{0} select(compare, p0, p1) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).value()); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::Parameter(1))); +} + +// select(compare(a, b, NE), a, b) => a, a,b ∈ PRED +TEST_F(AlgebraicSimplifierTest, SelectNeCompare) { + const char* kModuleStr = R"( + HloModule m + test { + p0 = pred[8]{0} parameter(0) + p1 = pred[8]{0} parameter(1) + compare = pred[8]{0} compare(p0, p1), direction=NE + ROOT select = pred[8]{0} select(compare, p0, p1) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + 
ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).value()); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::Parameter(0))); +} + +// select(compare(a, b, NE), b, a) ≠> a - wrong operands order +TEST_F(AlgebraicSimplifierTest, SelectNeCompare_NegativeTestCase) { + const char* kModuleStr = R"( + HloModule m + test { + p0 = pred[8]{0} parameter(0) + p1 = pred[8]{0} parameter(1) + compare = pred[8]{0} compare(p0, p1), direction=NE + ROOT select = pred[8]{0} select(compare, p1, p0) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_FALSE(AlgebraicSimplifier(default_options_).Run(m.get()).value()); +} + // Test that select(pred, xs, dynamic_update_slice(xs, x, i)) is simplified // to dynamic_update_slice(xs, select(pred, dynamic_slice(xs, i), x), i) TEST_F(AlgebraicSimplifierTest, SelectDUSWithShapedPred) { From 60a8e1a2d4cebeee0424bf4e38ccff2b61ef9b40 Mon Sep 17 00:00:00 2001 From: Laura Pak Date: Tue, 25 Jun 2024 17:51:51 -0700 Subject: [PATCH 247/256] Move tensorflow/lite/tools/optimize/testdata models to tf/compiler/mlir/lite/quantization/lite/testdata. 
PiperOrigin-RevId: 646669412 --- .../mlir/lite/quantization/lite/BUILD | 93 ++++++++------- .../quantization/lite/quantize_model_test.cc | 65 ++++++----- .../lite/quantize_weights_test.cc | 17 +-- .../mlir/lite/quantization/lite}/test_util.cc | 10 +- .../mlir/lite/quantization/lite}/test_util.h | 16 +-- .../quantization/lite}/testdata/README.md | 0 .../lite}/testdata/add_with_const_input.bin | Bin .../quantization/lite}/testdata/argmax.bin | Bin .../lite}/testdata/broadcast_to.bin | Bin .../quantization/lite}/testdata/concat.bin | Bin .../quantization/lite}/testdata/custom_op.bin | Bin .../lite/quantization/lite}/testdata/fc.bin | Bin .../quantization/lite}/testdata/fc_qat.bin | Bin .../quantization/lite}/testdata/gather_nd.bin | Bin .../lite}/testdata/lstm_calibrated.bin | Bin .../lite}/testdata/lstm_calibrated2.bin | Bin .../lite}/testdata/lstm_quantized.bin | Bin .../lite}/testdata/lstm_quantized2.bin | Bin .../quantization/lite}/testdata/maximum.bin | Bin .../quantization/lite}/testdata/minimum.bin | Bin .../quantization/lite}/testdata/mixed.bin | Bin .../quantization/lite}/testdata/mixed16x8.bin | Bin .../testdata/multi_input_add_reshape.bin | Bin .../lite/quantization/lite}/testdata/pack.bin | Bin .../lite}/testdata/quantized_with_gather.bin | Bin .../testdata/resource_vars_calibrated.bin | Bin ...single_avg_pool_min_minus_5_max_plus_5.bin | Bin .../lite}/testdata/single_conv_no_bias.bin | Bin .../single_conv_weights_min_0_max_plus_10.bin | Bin ...onv_weights_min_minus_127_max_plus_127.bin | Bin .../single_softmax_min_minus_5_max_plus_5.bin | Bin .../quantization/lite}/testdata/split.bin | Bin .../lite}/testdata/svdf_calibrated.bin | Bin .../lite}/testdata/svdf_quantized.bin | Bin .../quantization/lite}/testdata/transpose.bin | Bin ...nidirectional_sequence_lstm_calibrated.bin | Bin ...unidirectional_sequence_lstm_quantized.bin | Bin .../quantization/lite}/testdata/unpack.bin | Bin .../testdata/weight_shared_between_convs.bin | Bin 
.../quantization/lite}/testdata/where.bin | Bin tensorflow/lite/tools/optimize/BUILD | 103 +++++++---------- .../lite/tools/optimize/model_utils_test.cc | 1 - .../tools/optimize/quantization_utils_test.cc | 4 +- .../tools/optimize/quantize_model_test.cc | 106 +++++++++++------- .../tools/optimize/quantize_weights_test.cc | 16 +-- .../reduced_precision_support_test.cc | 1 - 46 files changed, 228 insertions(+), 204 deletions(-) rename tensorflow/{lite/tools/optimize => compiler/mlir/lite/quantization/lite}/test_util.cc (95%) rename tensorflow/{lite/tools/optimize => compiler/mlir/lite/quantization/lite}/test_util.h (93%) rename tensorflow/{lite/tools/optimize => compiler/mlir/lite/quantization/lite}/testdata/README.md (100%) rename tensorflow/{lite/tools/optimize => compiler/mlir/lite/quantization/lite}/testdata/add_with_const_input.bin (100%) rename tensorflow/{lite/tools/optimize => compiler/mlir/lite/quantization/lite}/testdata/argmax.bin (100%) rename tensorflow/{lite/tools/optimize => compiler/mlir/lite/quantization/lite}/testdata/broadcast_to.bin (100%) rename tensorflow/{lite/tools/optimize => compiler/mlir/lite/quantization/lite}/testdata/concat.bin (100%) rename tensorflow/{lite/tools/optimize => compiler/mlir/lite/quantization/lite}/testdata/custom_op.bin (100%) rename tensorflow/{lite/tools/optimize => compiler/mlir/lite/quantization/lite}/testdata/fc.bin (100%) rename tensorflow/{lite/tools/optimize => compiler/mlir/lite/quantization/lite}/testdata/fc_qat.bin (100%) rename tensorflow/{lite/tools/optimize => compiler/mlir/lite/quantization/lite}/testdata/gather_nd.bin (100%) rename tensorflow/{lite/tools/optimize => compiler/mlir/lite/quantization/lite}/testdata/lstm_calibrated.bin (100%) rename tensorflow/{lite/tools/optimize => compiler/mlir/lite/quantization/lite}/testdata/lstm_calibrated2.bin (100%) rename tensorflow/{lite/tools/optimize => compiler/mlir/lite/quantization/lite}/testdata/lstm_quantized.bin (100%) rename tensorflow/{lite/tools/optimize 
=> compiler/mlir/lite/quantization/lite}/testdata/lstm_quantized2.bin (100%) rename tensorflow/{lite/tools/optimize => compiler/mlir/lite/quantization/lite}/testdata/maximum.bin (100%) rename tensorflow/{lite/tools/optimize => compiler/mlir/lite/quantization/lite}/testdata/minimum.bin (100%) rename tensorflow/{lite/tools/optimize => compiler/mlir/lite/quantization/lite}/testdata/mixed.bin (100%) rename tensorflow/{lite/tools/optimize => compiler/mlir/lite/quantization/lite}/testdata/mixed16x8.bin (100%) rename tensorflow/{lite/tools/optimize => compiler/mlir/lite/quantization/lite}/testdata/multi_input_add_reshape.bin (100%) rename tensorflow/{lite/tools/optimize => compiler/mlir/lite/quantization/lite}/testdata/pack.bin (100%) rename tensorflow/{lite/tools/optimize => compiler/mlir/lite/quantization/lite}/testdata/quantized_with_gather.bin (100%) rename tensorflow/{lite/tools/optimize => compiler/mlir/lite/quantization/lite}/testdata/resource_vars_calibrated.bin (100%) rename tensorflow/{lite/tools/optimize => compiler/mlir/lite/quantization/lite}/testdata/single_avg_pool_min_minus_5_max_plus_5.bin (100%) rename tensorflow/{lite/tools/optimize => compiler/mlir/lite/quantization/lite}/testdata/single_conv_no_bias.bin (100%) rename tensorflow/{lite/tools/optimize => compiler/mlir/lite/quantization/lite}/testdata/single_conv_weights_min_0_max_plus_10.bin (100%) rename tensorflow/{lite/tools/optimize => compiler/mlir/lite/quantization/lite}/testdata/single_conv_weights_min_minus_127_max_plus_127.bin (100%) rename tensorflow/{lite/tools/optimize => compiler/mlir/lite/quantization/lite}/testdata/single_softmax_min_minus_5_max_plus_5.bin (100%) rename tensorflow/{lite/tools/optimize => compiler/mlir/lite/quantization/lite}/testdata/split.bin (100%) rename tensorflow/{lite/tools/optimize => compiler/mlir/lite/quantization/lite}/testdata/svdf_calibrated.bin (100%) rename tensorflow/{lite/tools/optimize => compiler/mlir/lite/quantization/lite}/testdata/svdf_quantized.bin 
(100%) rename tensorflow/{lite/tools/optimize => compiler/mlir/lite/quantization/lite}/testdata/transpose.bin (100%) rename tensorflow/{lite/tools/optimize => compiler/mlir/lite/quantization/lite}/testdata/unidirectional_sequence_lstm_calibrated.bin (100%) rename tensorflow/{lite/tools/optimize => compiler/mlir/lite/quantization/lite}/testdata/unidirectional_sequence_lstm_quantized.bin (100%) rename tensorflow/{lite/tools/optimize => compiler/mlir/lite/quantization/lite}/testdata/unpack.bin (100%) rename tensorflow/{lite/tools/optimize => compiler/mlir/lite/quantization/lite}/testdata/weight_shared_between_convs.bin (100%) rename tensorflow/{lite/tools/optimize => compiler/mlir/lite/quantization/lite}/testdata/where.bin (100%) diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/BUILD b/tensorflow/compiler/mlir/lite/quantization/lite/BUILD index 78ae512eb54202..e57c2b30808d82 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/quantization/lite/BUILD @@ -10,6 +10,10 @@ package( licenses = ["notice"], ) +exports_files(glob([ + "testdata/*.bin", +])) + package_group( name = "friends", packages = [ @@ -123,39 +127,39 @@ tf_cc_test( name = "quantize_model_test", srcs = ["quantize_model_test.cc"], args = [ - "--test_model_file=$(location //tensorflow/lite/tools/optimize:testdata/single_conv_weights_min_0_max_plus_10.bin)", + "--test_model_file=$(location //tensorflow/compiler/mlir/lite/quantization/lite:testdata/single_conv_weights_min_0_max_plus_10.bin)", ], data = [ - "//tensorflow/lite/tools/optimize:testdata/add_with_const_input.bin", - "//tensorflow/lite/tools/optimize:testdata/argmax.bin", - "//tensorflow/lite/tools/optimize:testdata/broadcast_to.bin", - "//tensorflow/lite/tools/optimize:testdata/concat.bin", - "//tensorflow/lite/tools/optimize:testdata/fc.bin", - "//tensorflow/lite/tools/optimize:testdata/fc_qat.bin", - "//tensorflow/lite/tools/optimize:testdata/gather_nd.bin", - 
"//tensorflow/lite/tools/optimize:testdata/lstm_calibrated.bin", - "//tensorflow/lite/tools/optimize:testdata/lstm_calibrated2.bin", - "//tensorflow/lite/tools/optimize:testdata/lstm_quantized.bin", - "//tensorflow/lite/tools/optimize:testdata/lstm_quantized2.bin", - "//tensorflow/lite/tools/optimize:testdata/maximum.bin", - "//tensorflow/lite/tools/optimize:testdata/minimum.bin", - "//tensorflow/lite/tools/optimize:testdata/mixed.bin", - "//tensorflow/lite/tools/optimize:testdata/mixed16x8.bin", - "//tensorflow/lite/tools/optimize:testdata/multi_input_add_reshape.bin", - "//tensorflow/lite/tools/optimize:testdata/pack.bin", - "//tensorflow/lite/tools/optimize:testdata/single_avg_pool_min_minus_5_max_plus_5.bin", - "//tensorflow/lite/tools/optimize:testdata/single_conv_no_bias.bin", - "//tensorflow/lite/tools/optimize:testdata/single_conv_weights_min_0_max_plus_10.bin", - "//tensorflow/lite/tools/optimize:testdata/single_conv_weights_min_minus_127_max_plus_127.bin", - "//tensorflow/lite/tools/optimize:testdata/single_softmax_min_minus_5_max_plus_5.bin", - "//tensorflow/lite/tools/optimize:testdata/split.bin", - "//tensorflow/lite/tools/optimize:testdata/svdf_calibrated.bin", - "//tensorflow/lite/tools/optimize:testdata/svdf_quantized.bin", - "//tensorflow/lite/tools/optimize:testdata/transpose.bin", - "//tensorflow/lite/tools/optimize:testdata/unidirectional_sequence_lstm_calibrated.bin", - "//tensorflow/lite/tools/optimize:testdata/unidirectional_sequence_lstm_quantized.bin", - "//tensorflow/lite/tools/optimize:testdata/unpack.bin", - "//tensorflow/lite/tools/optimize:testdata/where.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/add_with_const_input.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/argmax.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/broadcast_to.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/concat.bin", + 
"//tensorflow/compiler/mlir/lite/quantization/lite:testdata/fc.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/fc_qat.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/gather_nd.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/lstm_calibrated.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/lstm_calibrated2.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/lstm_quantized.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/lstm_quantized2.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/maximum.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/minimum.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/mixed.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/mixed16x8.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/multi_input_add_reshape.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/pack.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/single_avg_pool_min_minus_5_max_plus_5.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/single_conv_no_bias.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/single_conv_weights_min_0_max_plus_10.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/single_conv_weights_min_minus_127_max_plus_127.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/single_softmax_min_minus_5_max_plus_5.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/split.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/svdf_calibrated.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/svdf_quantized.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/transpose.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/unidirectional_sequence_lstm_calibrated.bin", + 
"//tensorflow/compiler/mlir/lite/quantization/lite:testdata/unidirectional_sequence_lstm_quantized.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/unpack.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/where.bin", ], tags = [ "tflite_not_portable_android", @@ -163,12 +167,12 @@ tf_cc_test( ], deps = [ ":quantize_model", + ":test_util", "//tensorflow/compiler/mlir/lite/schema:schema_fbs", "//tensorflow/compiler/mlir/lite/schema:schema_utils", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/lite:framework", - "//tensorflow/lite/tools/optimize:test_util", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/status", "@com_google_googletest//:gtest", @@ -181,13 +185,13 @@ tf_cc_test( name = "quantize_weights_test", srcs = ["quantize_weights_test.cc"], args = [ - "--test_model_file=$(location //tensorflow/lite/tools/optimize:testdata/single_conv_weights_min_0_max_plus_10.bin)", + "--test_model_file=$(location //tensorflow/compiler/mlir/lite/quantization/lite:testdata/single_conv_weights_min_0_max_plus_10.bin)", ], data = [ - "//tensorflow/lite/tools/optimize:testdata/custom_op.bin", - "//tensorflow/lite/tools/optimize:testdata/quantized_with_gather.bin", - "//tensorflow/lite/tools/optimize:testdata/single_conv_weights_min_0_max_plus_10.bin", - "//tensorflow/lite/tools/optimize:testdata/weight_shared_between_convs.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/custom_op.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/quantized_with_gather.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/single_conv_weights_min_0_max_plus_10.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/weight_shared_between_convs.bin", ], tags = [ # TODO(b/327796566): re-enable after the bug is fixed @@ -200,15 +204,28 @@ tf_cc_test( ], deps = [ ":quantize_weights", + ":test_util", "//tensorflow/compiler/mlir/lite/schema:schema_fbs", 
"//tensorflow/compiler/mlir/lite/schema:schema_utils", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/lite:framework", - "//tensorflow/lite/tools/optimize:test_util", "@com_google_absl//absl/status", "@com_google_googletest//:gtest", "@flatbuffers", "@local_tsl//tsl/platform:logging", ], ) + +cc_library( + name = "test_util", + testonly = 1, + srcs = ["test_util.cc"], + hdrs = ["test_util.h"], + deps = [ + "//tensorflow/lite:framework", + "//tensorflow/lite/core/api", + "@com_google_googletest//:gtest", + "@flatbuffers", + ], +) diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model_test.cc b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model_test.cc index e7d5e00b703392..1e7cdcdea07d33 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model_test.cc +++ b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model_test.cc @@ -30,6 +30,7 @@ limitations under the License. #include "absl/status/status.h" #include "flatbuffers/flatbuffer_builder.h" // from @flatbuffers #include "flatbuffers/vector.h" // from @flatbuffers +#include "tensorflow/compiler/mlir/lite/quantization/lite/test_util.h" #include "tensorflow/compiler/mlir/lite/schema/schema_generated.h" #include "tensorflow/compiler/mlir/lite/schema/schema_utils.h" #include "tensorflow/core/platform/init_main.h" @@ -37,7 +38,6 @@ limitations under the License. 
#include "tensorflow/core/platform/types.h" #include "tensorflow/core/util/command_line_flags.h" #include "tensorflow/lite/model_builder.h" -#include "tensorflow/lite/tools/optimize/test_util.h" #include "tsl/lib/core/status_test_util.h" // Note: branched from tensorflow/lite/tools/optimize/quantize_model_test.cc @@ -192,7 +192,8 @@ void VerifyQuantizationScale( class QuantizeModelTest : public testing::Test { protected: QuantizeModelTest() { - input_model_ = ReadModel(internal::kConvModelWith0Plus10Weights); + input_model_ = + ReadModel(::mlir::lite::internal::kConvModelWith0Plus10Weights); readonly_model_ = input_model_->GetModel(); model_ = UnPackFlatBufferModel(*readonly_model_); } @@ -277,7 +278,8 @@ class QuantizeConvModelTest : public QuantizeModelTest, protected: QuantizeConvModelTest() { tensor_type_ = GetParam(); - input_model_ = ReadModel(internal::kConvModelWith0Plus10Weights); + input_model_ = + ReadModel(::mlir::lite::internal::kConvModelWith0Plus10Weights); readonly_model_ = input_model_->GetModel(); model_ = UnPackFlatBufferModel(*readonly_model_); // Flatbuffer is missing calibration data -- add dummy params. 
@@ -347,7 +349,7 @@ TEST_P(QuantizeConvModelTest, GraphIsFullyQuantized) { class QuantizeConvNoBiasModelTest : public QuantizeModelTest { protected: QuantizeConvNoBiasModelTest() { - input_model_ = ReadModel(internal::kConvModelWithNoBias); + input_model_ = ReadModel(::mlir::lite::internal::kConvModelWithNoBias); readonly_model_ = input_model_->GetModel(); model_ = UnPackFlatBufferModel(*readonly_model_); } @@ -367,7 +369,7 @@ class QuantizeConvNoBiasModelTest : public QuantizeModelTest { class QuantizeSplitModelTest : public QuantizeModelTest { protected: QuantizeSplitModelTest() { - input_model_ = ReadModel(internal::kModelSplit); + input_model_ = ReadModel(::mlir::lite::internal::kModelSplit); readonly_model_ = input_model_->GetModel(); model_ = UnPackFlatBufferModel(*readonly_model_); } @@ -452,7 +454,8 @@ class QuantizeConvModel2Test : public QuantizeModelTest, protected: QuantizeConvModel2Test() { tensor_type_ = GetParam(); - input_model_ = ReadModel(internal::kConvModelWith0Plus10Weights); + input_model_ = + ReadModel(::mlir::lite::internal::kConvModelWith0Plus10Weights); readonly_model_ = input_model_->GetModel(); model_ = UnPackFlatBufferModel(*readonly_model_); auto& subgraph = model_.subgraphs[0]; @@ -690,7 +693,8 @@ TEST_P(QuantizeConvModel2Test, VerifyConvDisablePerChannelQuantization) { class QuantizeSoftmaxTest : public QuantizeModelTest { protected: QuantizeSoftmaxTest() { - input_model_ = ReadModel(internal::kSingleSoftmaxModelMinMinus5MaxPlus5); + input_model_ = + ReadModel(::mlir::lite::internal::kSingleSoftmaxModelMinMinus5MaxPlus5); readonly_model_ = input_model_->GetModel(); model_ = UnPackFlatBufferModel(*readonly_model_); } @@ -753,7 +757,8 @@ TEST_F(QuantizeSoftmaxTest, VerifySoftmaxQuantization) { class QuantizeAvgPoolTest : public QuantizeModelTest { protected: QuantizeAvgPoolTest() { - input_model_ = ReadModel(internal::kSingleAvgPoolModelMinMinus5MaxPlus5); + input_model_ = + 
ReadModel(::mlir::lite::internal::kSingleAvgPoolModelMinMinus5MaxPlus5); readonly_model_ = input_model_->GetModel(); model_ = UnPackFlatBufferModel(*readonly_model_); } @@ -813,7 +818,7 @@ TEST_F(QuantizeAvgPoolTest, VerifyAvgPoolQuantization) { class QuantizeMultiInputAddWithReshapeTest : public QuantizeModelTest { protected: QuantizeMultiInputAddWithReshapeTest() { - input_model_ = ReadModel(internal::kMultiInputAddWithReshape); + input_model_ = ReadModel(::mlir::lite::internal::kMultiInputAddWithReshape); readonly_model_ = input_model_->GetModel(); model_ = UnPackFlatBufferModel(*readonly_model_); } @@ -933,7 +938,7 @@ class QuantizeConstInputTest : public QuantizeModelTest, protected: QuantizeConstInputTest() { tensor_type_ = GetParam(); - input_model_ = ReadModel(internal::kConstInputAddModel); + input_model_ = ReadModel(::mlir::lite::internal::kConstInputAddModel); readonly_model_ = input_model_->GetModel(); model_ = UnPackFlatBufferModel(*readonly_model_); } @@ -980,7 +985,7 @@ TEST_P(QuantizeConstInputTest, VerifyConstOpInput) { class QuantizeArgMaxTest : public QuantizeModelTest { protected: QuantizeArgMaxTest() { - input_model_ = ReadModel(internal::kModelWithArgMaxOp); + input_model_ = ReadModel(::mlir::lite::internal::kModelWithArgMaxOp); readonly_model_ = input_model_->GetModel(); model_ = UnPackFlatBufferModel(*readonly_model_); } @@ -1025,7 +1030,7 @@ TEST_F(QuantizeArgMaxTest, VerifyArgMax) { class QuantizeLSTMTest : public QuantizeModelTest { protected: QuantizeLSTMTest() { - input_model_ = ReadModel(internal::kLstmCalibrated); + input_model_ = ReadModel(::mlir::lite::internal::kLstmCalibrated); readonly_model_ = input_model_->GetModel(); model_ = UnPackFlatBufferModel(*readonly_model_); } @@ -1037,7 +1042,7 @@ TEST_F(QuantizeLSTMTest, VerifyLSTM) { /*allow_float=*/true, TensorType_INT8, output_buffer_)); // Read expected model. 
- auto expected_fb_model = ReadModel(internal::kLstmQuantized); + auto expected_fb_model = ReadModel(::mlir::lite::internal::kLstmQuantized); auto expected_read_only_model = expected_fb_model->GetModel(); ModelT expected_model; expected_read_only_model->UnPackTo(&expected_model); @@ -1048,7 +1053,7 @@ TEST_F(QuantizeLSTMTest, VerifyLSTM) { class QuantizeLSTM2Test : public QuantizeModelTest { protected: QuantizeLSTM2Test() { - input_model_ = ReadModel(internal::kLstmCalibrated2); + input_model_ = ReadModel(::mlir::lite::internal::kLstmCalibrated2); readonly_model_ = input_model_->GetModel(); model_ = UnPackFlatBufferModel(*readonly_model_); } @@ -1061,7 +1066,7 @@ TEST_F(QuantizeLSTM2Test, VerifyLSTM) { /*allow_float=*/false, TensorType_INT8, output_buffer_)); // Read expected model. - auto expected_fb_model = ReadModel(internal::kLstmQuantized2); + auto expected_fb_model = ReadModel(::mlir::lite::internal::kLstmQuantized2); auto expected_read_only_model = expected_fb_model->GetModel(); ModelT expected_model; expected_read_only_model->UnPackTo(&expected_model); @@ -1072,7 +1077,8 @@ TEST_F(QuantizeLSTM2Test, VerifyLSTM) { class QuantizeUnidirectionalSequenceLSTMTest : public QuantizeModelTest { protected: QuantizeUnidirectionalSequenceLSTMTest() { - input_model_ = ReadModel(internal::kUnidirectionalSequenceLstmCalibrated); + input_model_ = ReadModel( + ::mlir::lite::internal::kUnidirectionalSequenceLstmCalibrated); readonly_model_ = input_model_->GetModel(); model_ = UnPackFlatBufferModel(*readonly_model_); } @@ -1086,7 +1092,7 @@ TEST_F(QuantizeUnidirectionalSequenceLSTMTest, // Read expected model. 
auto expected_fb_model = - ReadModel(internal::kUnidirectionalSequenceLstmQuantized); + ReadModel(::mlir::lite::internal::kUnidirectionalSequenceLstmQuantized); auto expected_read_only_model = expected_fb_model->GetModel(); ModelT expected_model; expected_read_only_model->UnPackTo(&expected_model); @@ -1097,7 +1103,7 @@ TEST_F(QuantizeUnidirectionalSequenceLSTMTest, class QuantizeSVDFTest : public QuantizeModelTest { protected: QuantizeSVDFTest() { - input_model_ = ReadModel(internal::kSvdfCalibrated); + input_model_ = ReadModel(::mlir::lite::internal::kSvdfCalibrated); readonly_model_ = input_model_->GetModel(); model_ = UnPackFlatBufferModel(*readonly_model_); } @@ -1110,7 +1116,7 @@ TEST_F(QuantizeSVDFTest, VerifySVDF) { /*allow_float=*/false, TensorType_INT8, output_buffer_)); // Read expected model. - auto expected_fb_model = ReadModel(internal::kSvdfQuantized); + auto expected_fb_model = ReadModel(::mlir::lite::internal::kSvdfQuantized); auto expected_read_only_model = expected_fb_model->GetModel(); ModelT expected_model; expected_read_only_model->UnPackTo(&expected_model); @@ -1123,7 +1129,7 @@ class QuantizeFCTest : public QuantizeModelTest, protected: QuantizeFCTest() { disable_per_channel_quantization_for_dense_ = GetParam(); - input_model_ = ReadModel(internal::kModelWithFCOp); + input_model_ = ReadModel(::mlir::lite::internal::kModelWithFCOp); readonly_model_ = input_model_->GetModel(); model_ = UnPackFlatBufferModel(*readonly_model_); } @@ -1371,7 +1377,7 @@ class QuantizeCustomOpTest public ::testing::WithParamInterface { protected: QuantizeCustomOpTest() { - input_model_ = ReadModel(internal::kModelMixed); + input_model_ = ReadModel(::mlir::lite::internal::kModelMixed); readonly_model_ = input_model_->GetModel(); model_ = UnPackFlatBufferModel(*readonly_model_); } @@ -1409,7 +1415,7 @@ INSTANTIATE_TEST_SUITE_P(QuantizeCustomOpTest, QuantizeCustomOpTest, class QuantizePackTest : public QuantizeModelTest { protected: QuantizePackTest() { - input_model_ 
= ReadModel(internal::kModelPack); + input_model_ = ReadModel(::mlir::lite::internal::kModelPack); readonly_model_ = input_model_->GetModel(); model_ = UnPackFlatBufferModel(*readonly_model_); } @@ -1526,14 +1532,15 @@ TEST_P(QuantizeMinimumMaximumTest, VerifyMinimumMaximum) { Eq(input2->quantization->zero_point)); } -INSTANTIATE_TEST_SUITE_P(MinimumMaximumTestInst, QuantizeMinimumMaximumTest, - testing::ValuesIn({internal::kModelWithMinimumOp, - internal::kModelWithMaximumOp})); +INSTANTIATE_TEST_SUITE_P( + MinimumMaximumTestInst, QuantizeMinimumMaximumTest, + testing::ValuesIn({::mlir::lite::internal::kModelWithMinimumOp, + ::mlir::lite::internal::kModelWithMaximumOp})); class QuantizeUnpackTest : public QuantizeModelTest { protected: QuantizeUnpackTest() { - input_model_ = ReadModel(internal::kModelWithUnpack); + input_model_ = ReadModel(::mlir::lite::internal::kModelWithUnpack); readonly_model_ = input_model_->GetModel(); model_ = UnPackFlatBufferModel(*readonly_model_); } @@ -1583,7 +1590,7 @@ class QuantizeBroadcastToModelTest protected: QuantizeBroadcastToModelTest() { tensor_type_ = GetParam(); - input_model_ = ReadModel(internal::kModelWithBroadcastToOp); + input_model_ = ReadModel(::mlir::lite::internal::kModelWithBroadcastToOp); readonly_model_ = input_model_->GetModel(); model_ = UnPackFlatBufferModel(*readonly_model_); } @@ -1646,7 +1653,7 @@ class QuantizeGatherNDModelTest protected: QuantizeGatherNDModelTest() { tensor_type_ = GetParam(); - input_model_ = ReadModel(internal::kModelWithGatherNDOp); + input_model_ = ReadModel(::mlir::lite::internal::kModelWithGatherNDOp); readonly_model_ = input_model_->GetModel(); model_ = UnPackFlatBufferModel(*readonly_model_); } @@ -1706,7 +1713,7 @@ TEST_P(QuantizeGatherNDModelTest, QuantizeGatherND) { class QuantizeWhereModelTest : public QuantizeModelTest { protected: QuantizeWhereModelTest() { - input_model_ = ReadModel(internal::kModelWithWhereOp); + input_model_ = 
ReadModel(::mlir::lite::internal::kModelWithWhereOp); readonly_model_ = input_model_->GetModel(); model_ = UnPackFlatBufferModel(*readonly_model_); } diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_weights_test.cc b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_weights_test.cc index 2e80bcae7486b4..7a42e74c2619af 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_weights_test.cc +++ b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_weights_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/quantization/lite/quantize_weights.h" #include +#include #include #include #include @@ -26,6 +27,7 @@ limitations under the License. #include "flatbuffers/buffer.h" // from @flatbuffers #include "flatbuffers/flatbuffer_builder.h" // from @flatbuffers #include "flatbuffers/vector.h" // from @flatbuffers +#include "tensorflow/compiler/mlir/lite/quantization/lite/test_util.h" #include "tensorflow/compiler/mlir/lite/schema/schema_generated.h" #include "tensorflow/compiler/mlir/lite/schema/schema_utils.h" #include "tensorflow/core/platform/init_main.h" @@ -33,7 +35,6 @@ limitations under the License. 
#include "tensorflow/core/platform/types.h" #include "tensorflow/core/util/command_line_flags.h" #include "tensorflow/lite/model_builder.h" -#include "tensorflow/lite/tools/optimize/test_util.h" #include "tsl/platform/logging.h" // Note: branched from tensorflow/lite/tools/optimize/quantize_weights_test.cc @@ -59,25 +60,25 @@ std::unique_ptr CreateMutableModelFromFile(const Model* input_model) { std::unique_ptr ReadTestModel() { auto model_path = tensorflow::io::JoinPath( - *g_test_model_dir, internal::kConvModelWith0Plus10Weights); + *g_test_model_dir, ::mlir::lite::internal::kConvModelWith0Plus10Weights); return FlatBufferModel::BuildFromFile(model_path.c_str()); } std::unique_ptr ReadSharedWeightsTestModel() { - auto model_path = tensorflow::io::JoinPath(*g_test_model_dir, - internal::kModelWithSharedWeights); + auto model_path = tensorflow::io::JoinPath( + *g_test_model_dir, ::mlir::lite::internal::kModelWithSharedWeights); return FlatBufferModel::BuildFromFile(model_path.c_str()); } std::unique_ptr ReadGatherTestModel() { - auto model_path = tensorflow::io::JoinPath(*g_test_model_dir, - internal::kQuantizedWithGather); + auto model_path = tensorflow::io::JoinPath( + *g_test_model_dir, ::mlir::lite::internal::kQuantizedWithGather); return FlatBufferModel::BuildFromFile(model_path.c_str()); } std::unique_ptr ReadCustomOpTestModel() { - auto model_path = - tensorflow::io::JoinPath(*g_test_model_dir, internal::kModelWithCustomOp); + auto model_path = tensorflow::io::JoinPath( + *g_test_model_dir, ::mlir::lite::internal::kModelWithCustomOp); return FlatBufferModel::BuildFromFile(model_path.c_str()); } diff --git a/tensorflow/lite/tools/optimize/test_util.cc b/tensorflow/compiler/mlir/lite/quantization/lite/test_util.cc similarity index 95% rename from tensorflow/lite/tools/optimize/test_util.cc rename to tensorflow/compiler/mlir/lite/quantization/lite/test_util.cc index 5ca45326d1dcad..e096868eec8807 100644 --- a/tensorflow/lite/tools/optimize/test_util.cc +++ 
b/tensorflow/compiler/mlir/lite/quantization/lite/test_util.cc @@ -12,12 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/lite/tools/optimize/test_util.h" +#include "tensorflow/compiler/mlir/lite/quantization/lite/test_util.h" #include -namespace tflite { -namespace optimize { +namespace mlir { +namespace lite { namespace internal { const char* kConvModelWithMinus128Plus127Weights = "single_conv_weights_min_minus_127_max_plus_127.bin"; @@ -89,5 +89,5 @@ int FailOnErrorReporter::Report(const char* format, va_list args) { return 0; } } // namespace internal -} // namespace optimize -} // namespace tflite +} // namespace lite +} // namespace mlir diff --git a/tensorflow/lite/tools/optimize/test_util.h b/tensorflow/compiler/mlir/lite/quantization/lite/test_util.h similarity index 93% rename from tensorflow/lite/tools/optimize/test_util.h rename to tensorflow/compiler/mlir/lite/quantization/lite/test_util.h index 11e7ef230910f2..b4e317c131888e 100644 --- a/tensorflow/lite/tools/optimize/test_util.h +++ b/tensorflow/compiler/mlir/lite/quantization/lite/test_util.h @@ -12,13 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef TENSORFLOW_LITE_TOOLS_OPTIMIZE_TEST_UTIL_H_ -#define TENSORFLOW_LITE_TOOLS_OPTIMIZE_TEST_UTIL_H_ +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_LITE_TEST_UTIL_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_LITE_TEST_UTIL_H_ #include "tensorflow/lite/core/api/error_reporter.h" -namespace tflite { -namespace optimize { +namespace mlir { +namespace lite { namespace internal { // Test model with a single convolution. // Floating point weights of the model are all integers and lie in @@ -132,12 +132,12 @@ extern const char* kQatModelWithFc; extern const char* kModelWithResourceVarsCalibrated; // An error reporter that fails on testing. -class FailOnErrorReporter : public ErrorReporter { +class FailOnErrorReporter : public tflite::ErrorReporter { public: int Report(const char* format, va_list args) override; }; } // namespace internal -} // namespace optimize -} // namespace tflite +} // namespace lite +} // namespace mlir -#endif // TENSORFLOW_LITE_TOOLS_OPTIMIZE_TEST_UTIL_H_ +#endif // TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_LITE_TEST_UTIL_H_ diff --git a/tensorflow/lite/tools/optimize/testdata/README.md b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/README.md similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/README.md rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/README.md diff --git a/tensorflow/lite/tools/optimize/testdata/add_with_const_input.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/add_with_const_input.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/add_with_const_input.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/add_with_const_input.bin diff --git a/tensorflow/lite/tools/optimize/testdata/argmax.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/argmax.bin similarity index 100% rename from 
tensorflow/lite/tools/optimize/testdata/argmax.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/argmax.bin diff --git a/tensorflow/lite/tools/optimize/testdata/broadcast_to.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/broadcast_to.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/broadcast_to.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/broadcast_to.bin diff --git a/tensorflow/lite/tools/optimize/testdata/concat.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/concat.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/concat.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/concat.bin diff --git a/tensorflow/lite/tools/optimize/testdata/custom_op.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/custom_op.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/custom_op.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/custom_op.bin diff --git a/tensorflow/lite/tools/optimize/testdata/fc.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/fc.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/fc.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/fc.bin diff --git a/tensorflow/lite/tools/optimize/testdata/fc_qat.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/fc_qat.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/fc_qat.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/fc_qat.bin diff --git a/tensorflow/lite/tools/optimize/testdata/gather_nd.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/gather_nd.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/gather_nd.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/gather_nd.bin diff --git 
a/tensorflow/lite/tools/optimize/testdata/lstm_calibrated.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/lstm_calibrated.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/lstm_calibrated.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/lstm_calibrated.bin diff --git a/tensorflow/lite/tools/optimize/testdata/lstm_calibrated2.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/lstm_calibrated2.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/lstm_calibrated2.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/lstm_calibrated2.bin diff --git a/tensorflow/lite/tools/optimize/testdata/lstm_quantized.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/lstm_quantized.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/lstm_quantized.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/lstm_quantized.bin diff --git a/tensorflow/lite/tools/optimize/testdata/lstm_quantized2.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/lstm_quantized2.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/lstm_quantized2.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/lstm_quantized2.bin diff --git a/tensorflow/lite/tools/optimize/testdata/maximum.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/maximum.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/maximum.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/maximum.bin diff --git a/tensorflow/lite/tools/optimize/testdata/minimum.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/minimum.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/minimum.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/minimum.bin diff --git a/tensorflow/lite/tools/optimize/testdata/mixed.bin 
b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/mixed.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/mixed.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/mixed.bin diff --git a/tensorflow/lite/tools/optimize/testdata/mixed16x8.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/mixed16x8.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/mixed16x8.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/mixed16x8.bin diff --git a/tensorflow/lite/tools/optimize/testdata/multi_input_add_reshape.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/multi_input_add_reshape.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/multi_input_add_reshape.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/multi_input_add_reshape.bin diff --git a/tensorflow/lite/tools/optimize/testdata/pack.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/pack.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/pack.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/pack.bin diff --git a/tensorflow/lite/tools/optimize/testdata/quantized_with_gather.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/quantized_with_gather.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/quantized_with_gather.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/quantized_with_gather.bin diff --git a/tensorflow/lite/tools/optimize/testdata/resource_vars_calibrated.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/resource_vars_calibrated.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/resource_vars_calibrated.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/resource_vars_calibrated.bin diff --git 
a/tensorflow/lite/tools/optimize/testdata/single_avg_pool_min_minus_5_max_plus_5.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/single_avg_pool_min_minus_5_max_plus_5.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/single_avg_pool_min_minus_5_max_plus_5.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/single_avg_pool_min_minus_5_max_plus_5.bin diff --git a/tensorflow/lite/tools/optimize/testdata/single_conv_no_bias.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/single_conv_no_bias.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/single_conv_no_bias.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/single_conv_no_bias.bin diff --git a/tensorflow/lite/tools/optimize/testdata/single_conv_weights_min_0_max_plus_10.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/single_conv_weights_min_0_max_plus_10.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/single_conv_weights_min_0_max_plus_10.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/single_conv_weights_min_0_max_plus_10.bin diff --git a/tensorflow/lite/tools/optimize/testdata/single_conv_weights_min_minus_127_max_plus_127.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/single_conv_weights_min_minus_127_max_plus_127.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/single_conv_weights_min_minus_127_max_plus_127.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/single_conv_weights_min_minus_127_max_plus_127.bin diff --git a/tensorflow/lite/tools/optimize/testdata/single_softmax_min_minus_5_max_plus_5.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/single_softmax_min_minus_5_max_plus_5.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/single_softmax_min_minus_5_max_plus_5.bin rename to 
tensorflow/compiler/mlir/lite/quantization/lite/testdata/single_softmax_min_minus_5_max_plus_5.bin diff --git a/tensorflow/lite/tools/optimize/testdata/split.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/split.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/split.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/split.bin diff --git a/tensorflow/lite/tools/optimize/testdata/svdf_calibrated.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/svdf_calibrated.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/svdf_calibrated.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/svdf_calibrated.bin diff --git a/tensorflow/lite/tools/optimize/testdata/svdf_quantized.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/svdf_quantized.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/svdf_quantized.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/svdf_quantized.bin diff --git a/tensorflow/lite/tools/optimize/testdata/transpose.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/transpose.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/transpose.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/transpose.bin diff --git a/tensorflow/lite/tools/optimize/testdata/unidirectional_sequence_lstm_calibrated.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/unidirectional_sequence_lstm_calibrated.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/unidirectional_sequence_lstm_calibrated.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/unidirectional_sequence_lstm_calibrated.bin diff --git a/tensorflow/lite/tools/optimize/testdata/unidirectional_sequence_lstm_quantized.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/unidirectional_sequence_lstm_quantized.bin 
similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/unidirectional_sequence_lstm_quantized.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/unidirectional_sequence_lstm_quantized.bin diff --git a/tensorflow/lite/tools/optimize/testdata/unpack.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/unpack.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/unpack.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/unpack.bin diff --git a/tensorflow/lite/tools/optimize/testdata/weight_shared_between_convs.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/weight_shared_between_convs.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/weight_shared_between_convs.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/weight_shared_between_convs.bin diff --git a/tensorflow/lite/tools/optimize/testdata/where.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/where.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/where.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/where.bin diff --git a/tensorflow/lite/tools/optimize/BUILD b/tensorflow/lite/tools/optimize/BUILD index 711f97bdfddd16..a05a5cbdb10710 100644 --- a/tensorflow/lite/tools/optimize/BUILD +++ b/tensorflow/lite/tools/optimize/BUILD @@ -14,10 +14,6 @@ package( licenses = ["notice"], ) -exports_files(glob([ - "testdata/*.bin", -])) - cc_library( name = "reduced_precision_support", srcs = [], @@ -39,7 +35,6 @@ tf_cc_test( ], deps = [ ":reduced_precision_support", - ":test_util", "//tensorflow/core/platform:platform_port", "//tensorflow/lite:framework", "//tensorflow/lite/core:framework", @@ -223,7 +218,6 @@ tf_cc_test( ], deps = [ ":model_utils", - ":test_util", "//tensorflow/lite:framework", "//tensorflow/lite/core:framework", "//tensorflow/lite/schema:schema_fbs", @@ -250,10 +244,10 @@ tf_cc_test( name = 
"quantization_utils_test", srcs = ["quantization_utils_test.cc"], args = [ - "--test_model_file=$(location //tensorflow/lite/tools/optimize:testdata/single_conv_weights_min_0_max_plus_10.bin)", + "--test_model_file=$(location //tensorflow/compiler/mlir/lite/quantization/lite:testdata/single_conv_weights_min_0_max_plus_10.bin)", ], data = [ - ":testdata/single_conv_weights_min_0_max_plus_10.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/single_conv_weights_min_0_max_plus_10.bin", ], tags = [ "tflite_not_portable_android", @@ -261,7 +255,7 @@ tf_cc_test( ], deps = [ ":quantization_utils", - ":test_util", + "//tensorflow/compiler/mlir/lite/quantization/lite:test_util", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/lite:framework", @@ -316,13 +310,13 @@ tf_cc_test( name = "quantize_weights_test", srcs = ["quantize_weights_test.cc"], args = [ - "--test_model_file=$(location //tensorflow/lite/tools/optimize:testdata/single_conv_weights_min_0_max_plus_10.bin)", + "--test_model_file=$(location //tensorflow/compiler/mlir/lite/quantization/lite:testdata/single_conv_weights_min_0_max_plus_10.bin)", ], data = [ - "//tensorflow/lite/tools/optimize:testdata/custom_op.bin", - "//tensorflow/lite/tools/optimize:testdata/quantized_with_gather.bin", - "//tensorflow/lite/tools/optimize:testdata/single_conv_weights_min_0_max_plus_10.bin", - "//tensorflow/lite/tools/optimize:testdata/weight_shared_between_convs.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/custom_op.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/quantized_with_gather.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/single_conv_weights_min_0_max_plus_10.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/weight_shared_between_convs.bin", ], tags = [ "tflite_not_portable_android", @@ -330,7 +324,7 @@ tf_cc_test( ], deps = [ ":quantize_weights", - ":test_util", + 
"//tensorflow/compiler/mlir/lite/quantization/lite:test_util", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/lite:framework", @@ -342,19 +336,6 @@ tf_cc_test( ], ) -cc_library( - name = "test_util", - testonly = 1, - srcs = ["test_util.cc"], - hdrs = ["test_util.h"], - deps = [ - "//tensorflow/lite:framework", - "//tensorflow/lite/core/api", - "@com_google_googletest//:gtest", - "@flatbuffers", - ], -) - cc_library( name = "quantize_model", srcs = ["quantize_model.cc"], @@ -379,40 +360,40 @@ tf_cc_test( name = "quantize_model_test", srcs = ["quantize_model_test.cc"], args = [ - "--test_model_file=$(location //tensorflow/lite/tools/optimize:testdata/single_conv_weights_min_0_max_plus_10.bin)", + "--test_model_file=$(location //tensorflow/compiler/mlir/lite/quantization/lite:testdata/single_conv_weights_min_0_max_plus_10.bin)", ], data = [ - "//tensorflow/lite/tools/optimize:testdata/add_with_const_input.bin", - "//tensorflow/lite/tools/optimize:testdata/argmax.bin", - "//tensorflow/lite/tools/optimize:testdata/broadcast_to.bin", - "//tensorflow/lite/tools/optimize:testdata/concat.bin", - "//tensorflow/lite/tools/optimize:testdata/fc.bin", - "//tensorflow/lite/tools/optimize:testdata/fc_qat.bin", - "//tensorflow/lite/tools/optimize:testdata/gather_nd.bin", - "//tensorflow/lite/tools/optimize:testdata/lstm_calibrated.bin", - "//tensorflow/lite/tools/optimize:testdata/lstm_calibrated2.bin", - "//tensorflow/lite/tools/optimize:testdata/lstm_quantized.bin", - "//tensorflow/lite/tools/optimize:testdata/lstm_quantized2.bin", - "//tensorflow/lite/tools/optimize:testdata/maximum.bin", - "//tensorflow/lite/tools/optimize:testdata/minimum.bin", - "//tensorflow/lite/tools/optimize:testdata/mixed.bin", - "//tensorflow/lite/tools/optimize:testdata/mixed16x8.bin", - "//tensorflow/lite/tools/optimize:testdata/multi_input_add_reshape.bin", - "//tensorflow/lite/tools/optimize:testdata/pack.bin", - 
"//tensorflow/lite/tools/optimize:testdata/resource_vars_calibrated.bin", - "//tensorflow/lite/tools/optimize:testdata/single_avg_pool_min_minus_5_max_plus_5.bin", - "//tensorflow/lite/tools/optimize:testdata/single_conv_no_bias.bin", - "//tensorflow/lite/tools/optimize:testdata/single_conv_weights_min_0_max_plus_10.bin", - "//tensorflow/lite/tools/optimize:testdata/single_conv_weights_min_minus_127_max_plus_127.bin", - "//tensorflow/lite/tools/optimize:testdata/single_softmax_min_minus_5_max_plus_5.bin", - "//tensorflow/lite/tools/optimize:testdata/split.bin", - "//tensorflow/lite/tools/optimize:testdata/svdf_calibrated.bin", - "//tensorflow/lite/tools/optimize:testdata/svdf_quantized.bin", - "//tensorflow/lite/tools/optimize:testdata/transpose.bin", - "//tensorflow/lite/tools/optimize:testdata/unidirectional_sequence_lstm_calibrated.bin", - "//tensorflow/lite/tools/optimize:testdata/unidirectional_sequence_lstm_quantized.bin", - "//tensorflow/lite/tools/optimize:testdata/unpack.bin", - "//tensorflow/lite/tools/optimize:testdata/where.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/add_with_const_input.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/argmax.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/broadcast_to.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/concat.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/fc.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/fc_qat.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/gather_nd.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/lstm_calibrated.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/lstm_calibrated2.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/lstm_quantized.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/lstm_quantized2.bin", + 
"//tensorflow/compiler/mlir/lite/quantization/lite:testdata/maximum.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/minimum.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/mixed.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/mixed16x8.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/multi_input_add_reshape.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/pack.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/resource_vars_calibrated.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/single_avg_pool_min_minus_5_max_plus_5.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/single_conv_no_bias.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/single_conv_weights_min_0_max_plus_10.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/single_conv_weights_min_minus_127_max_plus_127.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/single_softmax_min_minus_5_max_plus_5.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/split.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/svdf_calibrated.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/svdf_quantized.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/transpose.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/unidirectional_sequence_lstm_calibrated.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/unidirectional_sequence_lstm_quantized.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/unpack.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/where.bin", ], tags = [ "tflite_not_portable_android", @@ -420,7 +401,7 @@ tf_cc_test( ], deps = [ ":quantize_model", - ":test_util", + "//tensorflow/compiler/mlir/lite/quantization/lite:test_util", "//tensorflow/core:framework_internal", 
"//tensorflow/core:lib", "//tensorflow/lite:framework", diff --git a/tensorflow/lite/tools/optimize/model_utils_test.cc b/tensorflow/lite/tools/optimize/model_utils_test.cc index 65e3afe35e2da2..f702e1fa0a0ddd 100644 --- a/tensorflow/lite/tools/optimize/model_utils_test.cc +++ b/tensorflow/lite/tools/optimize/model_utils_test.cc @@ -21,7 +21,6 @@ limitations under the License. #include #include "tensorflow/lite/core/model.h" #include "tensorflow/lite/schema/schema_generated.h" -#include "tensorflow/lite/tools/optimize/test_util.h" namespace tflite { namespace optimize { diff --git a/tensorflow/lite/tools/optimize/quantization_utils_test.cc b/tensorflow/lite/tools/optimize/quantization_utils_test.cc index a09acef6f4aa3c..a0ab9c43eacb75 100644 --- a/tensorflow/lite/tools/optimize/quantization_utils_test.cc +++ b/tensorflow/lite/tools/optimize/quantization_utils_test.cc @@ -22,6 +22,7 @@ limitations under the License. #include #include #include "absl/memory/memory.h" +#include "tensorflow/compiler/mlir/lite/quantization/lite/test_util.h" #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/util/command_line_flags.h" @@ -29,7 +30,6 @@ limitations under the License. 
#include "tensorflow/lite/schema/schema_generated.h" #include "tensorflow/lite/schema/schema_utils.h" #include "tensorflow/lite/testing/util.h" -#include "tensorflow/lite/tools/optimize/test_util.h" namespace { tensorflow::string* g_test_model_dir = nullptr; @@ -46,7 +46,7 @@ std::unique_ptr ReadModel(const char* model) { } std::unique_ptr ReadConvModel() { - return ReadModel(internal::kConvModelWith0Plus10Weights); + return ReadModel(mlir::lite::internal::kConvModelWith0Plus10Weights); } using ::testing::ElementsAreArray; diff --git a/tensorflow/lite/tools/optimize/quantize_model_test.cc b/tensorflow/lite/tools/optimize/quantize_model_test.cc index 681507c8e0d31d..a7e9115f8bdaaa 100644 --- a/tensorflow/lite/tools/optimize/quantize_model_test.cc +++ b/tensorflow/lite/tools/optimize/quantize_model_test.cc @@ -25,6 +25,7 @@ limitations under the License. #include #include "flatbuffers/flatbuffers.h" // from @flatbuffers #include "flatbuffers/flexbuffers.h" // from @flatbuffers +#include "tensorflow/compiler/mlir/lite/quantization/lite/test_util.h" #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/util/command_line_flags.h" @@ -32,7 +33,6 @@ limitations under the License. 
#include "tensorflow/lite/schema/schema_generated.h" #include "tensorflow/lite/schema/schema_utils.h" #include "tensorflow/lite/testing/util.h" -#include "tensorflow/lite/tools/optimize/test_util.h" // Note: More rigorous model tests can be found in subgraph_quantizer_test.cc @@ -78,7 +78,8 @@ TensorType GetBiasTensorType(TensorType& activation_type) { class QuantizeModelTest : public testing::Test { protected: QuantizeModelTest() - : QuantizeModelTest(ReadModel(internal::kConvModelWith0Plus10Weights)) {} + : QuantizeModelTest( + ReadModel(mlir::lite::internal::kConvModelWith0Plus10Weights)) {} explicit QuantizeModelTest(std::unique_ptr input_model) { input_model_ = std::move(input_model); @@ -91,7 +92,7 @@ class QuantizeModelTest : public testing::Test { const Model* readonly_model_; tflite::ModelT model_; flatbuffers::FlatBufferBuilder builder_; - internal::FailOnErrorReporter error_reporter_; + ::mlir::lite::internal::FailOnErrorReporter error_reporter_; }; void ExpectSameModels(const ModelT& model, const ModelT& expected_model) { @@ -136,7 +137,8 @@ class QuantizeConvModelTest : public QuantizeModelTest, public testing::WithParamInterface { protected: QuantizeConvModelTest() - : QuantizeModelTest(ReadModel(internal::kConvModelWith0Plus10Weights)), + : QuantizeModelTest( + ReadModel(::mlir::lite::internal::kConvModelWith0Plus10Weights)), tensor_type_(GetParam()), bias_type_(GetBiasTensorType(tensor_type_)) {} TensorType tensor_type_; @@ -405,7 +407,8 @@ TEST_P(QuantizeConvModelTest, Uint8InputAndOutput) { class QuantizeConvNoBiasModelTest : public QuantizeModelTest { protected: QuantizeConvNoBiasModelTest() - : QuantizeModelTest(ReadModel(internal::kConvModelWithNoBias)) {} + : QuantizeModelTest( + ReadModel(::mlir::lite::internal::kConvModelWithNoBias)) {} }; TEST_F(QuantizeConvNoBiasModelTest, QuantizationSucceeds) { @@ -422,7 +425,8 @@ class QuantizeConcatModelTest : public QuantizeModelTest, public testing::WithParamInterface { protected: 
QuantizeConcatModelTest() - : QuantizeModelTest(ReadModel(internal::kFloatConcatMax5Max10Max10)) {} + : QuantizeModelTest( + ReadModel(::mlir::lite::internal::kFloatConcatMax5Max10Max10)) {} void SetUp() override { tensor_type_ = GetParam(); @@ -536,7 +540,7 @@ INSTANTIATE_TEST_SUITE_P(QuantizeConcatModelInst, QuantizeConcatModelTest, class QuantizeSplitModelTest : public QuantizeModelTest { protected: QuantizeSplitModelTest() - : QuantizeModelTest(ReadModel(internal::kModelSplit)) {} + : QuantizeModelTest(ReadModel(::mlir::lite::internal::kModelSplit)) {} }; // There are two outputs for split with different scales, the resulting model @@ -601,8 +605,8 @@ TEST_F(QuantizeSplitModelTest, QuantizeSplit) { class QuantizeConvModel1Test : public QuantizeModelTest { protected: QuantizeConvModel1Test() - : QuantizeModelTest( - ReadModel(internal::kConvModelWithMinus128Plus127Weights)) {} + : QuantizeModelTest(ReadModel( + ::mlir::lite::internal::kConvModelWithMinus128Plus127Weights)) {} }; TEST_F(QuantizeConvModel1Test, VerifyConvQuantizationWithUnitScale) { @@ -703,7 +707,8 @@ class QuantizeConvModel2Test : public QuantizeModelTest, public testing::WithParamInterface { protected: QuantizeConvModel2Test() - : QuantizeModelTest(ReadModel(internal::kConvModelWith0Plus10Weights)), + : QuantizeModelTest( + ReadModel(::mlir::lite::internal::kConvModelWith0Plus10Weights)), tensor_type_(GetParam()), bias_type_(GetBiasTensorType(tensor_type_)) {} @@ -925,8 +930,8 @@ TEST_P(QuantizeConvModel2Test, VerifyConvDisablePerChannelQuantization) { class QuantizeSoftmaxTest : public QuantizeModelTest { protected: QuantizeSoftmaxTest() - : QuantizeModelTest( - ReadModel(internal::kSingleSoftmaxModelMinMinus5MaxPlus5)) {} + : QuantizeModelTest(ReadModel( + ::mlir::lite::internal::kSingleSoftmaxModelMinMinus5MaxPlus5)) {} }; TEST_F(QuantizeSoftmaxTest, VerifySoftmaxQuantization) { @@ -985,8 +990,8 @@ TEST_F(QuantizeSoftmaxTest, VerifySoftmaxQuantization) { class QuantizeAvgPoolTest : public 
QuantizeModelTest { protected: QuantizeAvgPoolTest() - : QuantizeModelTest( - ReadModel(internal::kSingleAvgPoolModelMinMinus5MaxPlus5)) {} + : QuantizeModelTest(ReadModel( + ::mlir::lite::internal::kSingleAvgPoolModelMinMinus5MaxPlus5)) {} }; TEST_F(QuantizeAvgPoolTest, VerifyAvgPoolQuantization) { @@ -1045,7 +1050,8 @@ TEST_F(QuantizeAvgPoolTest, VerifyAvgPoolQuantization) { class QuantizeMultiInputAddWithReshapeTest : public QuantizeModelTest { protected: QuantizeMultiInputAddWithReshapeTest() - : QuantizeModelTest(ReadModel(internal::kMultiInputAddWithReshape)) {} + : QuantizeModelTest( + ReadModel(::mlir::lite::internal::kMultiInputAddWithReshape)) {} }; TEST_F(QuantizeMultiInputAddWithReshapeTest, VerifyReshapeQuantization) { @@ -1155,7 +1161,8 @@ class QuantizeConstInputTest : public QuantizeModelTest, public testing::WithParamInterface { protected: QuantizeConstInputTest() - : QuantizeModelTest(ReadModel(internal::kConstInputAddModel)), + : QuantizeModelTest( + ReadModel(::mlir::lite::internal::kConstInputAddModel)), tensor_type_(GetParam()), bias_type_(GetBiasTensorType(tensor_type_)) {} @@ -1213,7 +1220,8 @@ TEST_P(QuantizeConstInputTest, VerifyConstOpInput) { class QuantizeArgMaxTest : public QuantizeModelTest { protected: QuantizeArgMaxTest() - : QuantizeModelTest(ReadModel(internal::kModelWithArgMaxOp)) {} + : QuantizeModelTest( + ReadModel(::mlir::lite::internal::kModelWithArgMaxOp)) {} }; TEST_F(QuantizeArgMaxTest, VerifyArgMax) { @@ -1254,7 +1262,7 @@ TEST_F(QuantizeArgMaxTest, VerifyArgMax) { class QuantizeLSTMTest : public QuantizeModelTest { protected: QuantizeLSTMTest() - : QuantizeModelTest(ReadModel(internal::kLstmCalibrated)) {} + : QuantizeModelTest(ReadModel(::mlir::lite::internal::kLstmCalibrated)) {} }; TEST_F(QuantizeLSTMTest, VerifyLSTM) { @@ -1265,7 +1273,7 @@ TEST_F(QuantizeLSTMTest, VerifyLSTM) { ASSERT_EQ(kTfLiteOk, status); // Read expected model. 
- auto expected_fb_model = ReadModel(internal::kLstmQuantized); + auto expected_fb_model = ReadModel(::mlir::lite::internal::kLstmQuantized); auto expected_read_only_model = expected_fb_model->GetModel(); ModelT expected_model; expected_read_only_model->UnPackTo(&expected_model); @@ -1276,7 +1284,8 @@ TEST_F(QuantizeLSTMTest, VerifyLSTM) { class QuantizeLSTM2Test : public QuantizeModelTest { protected: QuantizeLSTM2Test() - : QuantizeModelTest(ReadModel(internal::kLstmCalibrated2)) {} + : QuantizeModelTest(ReadModel(::mlir::lite::internal::kLstmCalibrated2)) { + } }; TEST_F(QuantizeLSTM2Test, VerifyLSTM) { @@ -1287,7 +1296,7 @@ TEST_F(QuantizeLSTM2Test, VerifyLSTM) { ASSERT_EQ(kTfLiteOk, status); // Read expected model. - auto expected_fb_model = ReadModel(internal::kLstmQuantized2); + auto expected_fb_model = ReadModel(::mlir::lite::internal::kLstmQuantized2); auto expected_read_only_model = expected_fb_model->GetModel(); ModelT expected_model; expected_read_only_model->UnPackTo(&expected_model); @@ -1298,8 +1307,8 @@ TEST_F(QuantizeLSTM2Test, VerifyLSTM) { class QuantizeUnidirectionalSequenceLSTMTest : public QuantizeModelTest { protected: QuantizeUnidirectionalSequenceLSTMTest() - : QuantizeModelTest( - ReadModel(internal::kUnidirectionalSequenceLstmCalibrated)) {} + : QuantizeModelTest(ReadModel( + ::mlir::lite::internal::kUnidirectionalSequenceLstmCalibrated)) {} }; TEST_F(QuantizeUnidirectionalSequenceLSTMTest, @@ -1312,7 +1321,7 @@ TEST_F(QuantizeUnidirectionalSequenceLSTMTest, // Read expected model. 
auto expected_fb_model = - ReadModel(internal::kUnidirectionalSequenceLstmQuantized); + ReadModel(::mlir::lite::internal::kUnidirectionalSequenceLstmQuantized); auto expected_read_only_model = expected_fb_model->GetModel(); ModelT expected_model; expected_read_only_model->UnPackTo(&expected_model); @@ -1323,7 +1332,7 @@ TEST_F(QuantizeUnidirectionalSequenceLSTMTest, class QuantizeSVDFTest : public QuantizeModelTest { protected: QuantizeSVDFTest() - : QuantizeModelTest(ReadModel(internal::kSvdfCalibrated)) {} + : QuantizeModelTest(ReadModel(::mlir::lite::internal::kSvdfCalibrated)) {} }; TEST_F(QuantizeSVDFTest, VerifySVDF) { @@ -1334,7 +1343,7 @@ TEST_F(QuantizeSVDFTest, VerifySVDF) { ASSERT_EQ(kTfLiteOk, status); // Read expected model. - auto expected_fb_model = ReadModel(internal::kSvdfQuantized); + auto expected_fb_model = ReadModel(::mlir::lite::internal::kSvdfQuantized); auto expected_read_only_model = expected_fb_model->GetModel(); ModelT expected_model; expected_read_only_model->UnPackTo(&expected_model); @@ -1379,7 +1388,8 @@ TEST_F(QuantizeSVDFTest, VerifySVDF) { class QuantizeFCTest : public QuantizeModelTest { protected: - QuantizeFCTest() : QuantizeModelTest(ReadModel(internal::kModelWithFCOp)) {} + QuantizeFCTest() + : QuantizeModelTest(ReadModel(::mlir::lite::internal::kModelWithFCOp)) {} }; TEST_F(QuantizeFCTest, VerifyFC) { @@ -1430,7 +1440,7 @@ class QuantizeCustomOpTest public ::testing::WithParamInterface { protected: QuantizeCustomOpTest() - : QuantizeModelTest(ReadModel(internal::kModelMixed)), + : QuantizeModelTest(ReadModel(::mlir::lite::internal::kModelMixed)), tensor_type_(GetParam()), bias_type_(GetBiasTensorType(tensor_type_)) {} @@ -1471,7 +1481,7 @@ INSTANTIATE_TEST_SUITE_P(QuantizeCustomOpTest, QuantizeCustomOpTest, class QuantizeOp16x8Test : public QuantizeModelTest { protected: QuantizeOp16x8Test() - : QuantizeModelTest(ReadModel(internal::kModelMixed16x8)) {} + : QuantizeModelTest(ReadModel(::mlir::lite::internal::kModelMixed16x8)) 
{} }; TEST_F(QuantizeOp16x8Test, VerifyMixedQuantization16x8) { @@ -1502,7 +1512,8 @@ TEST_F(QuantizeOp16x8Test, VerifyMixedQuantization16x8) { class QuantizePackTest : public QuantizeModelTest { protected: - QuantizePackTest() : QuantizeModelTest(ReadModel(internal::kModelPack)) {} + QuantizePackTest() + : QuantizeModelTest(ReadModel(::mlir::lite::internal::kModelPack)) {} }; TEST_F(QuantizePackTest, VerifyPack) { @@ -1628,14 +1639,16 @@ TEST_P(QuantizeMinimumMaximumTest, VerifyMinimumMaximum) { EXPECT_EQ(subgraph->tensors[5]->name, "output"); } -INSTANTIATE_TEST_SUITE_P(MinimumMaximumTestInst, QuantizeMinimumMaximumTest, - testing::ValuesIn({internal::kModelWithMinimumOp, - internal::kModelWithMaximumOp})); +INSTANTIATE_TEST_SUITE_P( + MinimumMaximumTestInst, QuantizeMinimumMaximumTest, + testing::ValuesIn({::mlir::lite::internal::kModelWithMinimumOp, + ::mlir::lite::internal::kModelWithMaximumOp})); class QuantizeUnpackTest : public QuantizeModelTest { protected: QuantizeUnpackTest() - : QuantizeModelTest(ReadModel(internal::kModelWithUnpack)) {} + : QuantizeModelTest(ReadModel(::mlir::lite::internal::kModelWithUnpack)) { + } }; TEST_F(QuantizeUnpackTest, VerifyUnpack) { auto status = QuantizeModel(&builder_, &model_, &error_reporter_); @@ -1680,7 +1693,8 @@ TEST_F(QuantizeUnpackTest, VerifyUnpack) { class QuantizeTransposeTest : public QuantizeModelTest { protected: QuantizeTransposeTest() - : QuantizeModelTest(ReadModel(internal::kModelWithTranspose)) {} + : QuantizeModelTest( + ReadModel(::mlir::lite::internal::kModelWithTranspose)) {} }; TEST_F(QuantizeTransposeTest, VerifyTranspose) { @@ -1720,7 +1734,8 @@ TEST_F(QuantizeTransposeTest, VerifyTranspose) { class QuantizeQatTest : public QuantizeModelTest { protected: - QuantizeQatTest() : QuantizeModelTest(ReadModel(internal::kQatModelWithFc)) {} + QuantizeQatTest() + : QuantizeModelTest(ReadModel(::mlir::lite::internal::kQatModelWithFc)) {} }; TEST_F(QuantizeQatTest, VerifySingleQuantize) { @@ -1777,7 
+1792,8 @@ class QuantizeBroadcastToModelTest public testing::WithParamInterface { protected: QuantizeBroadcastToModelTest() - : QuantizeModelTest(ReadModel(internal::kModelWithBroadcastToOp)), + : QuantizeModelTest( + ReadModel(::mlir::lite::internal::kModelWithBroadcastToOp)), tensor_type_(GetParam()), bias_type_(GetBiasTensorType(tensor_type_)) {} @@ -1844,7 +1860,8 @@ class QuantizeGatherNDModelTest public testing::WithParamInterface { protected: QuantizeGatherNDModelTest() - : QuantizeModelTest(ReadModel(internal::kModelWithGatherNDOp)), + : QuantizeModelTest( + ReadModel(::mlir::lite::internal::kModelWithGatherNDOp)), tensor_type_(GetParam()), bias_type_(GetBiasTensorType(tensor_type_)) {} @@ -1906,7 +1923,8 @@ TEST_P(QuantizeGatherNDModelTest, QuantizeGatherND) { class QuantizeWhereModelTest : public QuantizeModelTest { protected: QuantizeWhereModelTest() - : QuantizeModelTest(ReadModel(internal::kModelWithWhereOp)) {} + : QuantizeModelTest( + ReadModel(::mlir::lite::internal::kModelWithWhereOp)) {} }; TEST_F(QuantizeWhereModelTest, QuantizeWhere) { @@ -1976,8 +1994,8 @@ class QuantizeResourcesModelTest public testing::WithParamInterface { protected: QuantizeResourcesModelTest() - : QuantizeModelTest( - ReadModel(internal::kModelWithResourceVarsCalibrated)) { + : QuantizeModelTest(ReadModel( + ::mlir::lite::internal::kModelWithResourceVarsCalibrated)) { TestType obj = GetParam(); tensor_type_ = obj.tensor_type; modify_range_ = obj.modify_range; @@ -2119,7 +2137,8 @@ class QuantizeConcatConstModelTest public testing::WithParamInterface { protected: QuantizeConcatConstModelTest() - : QuantizeModelTest(ReadModel(internal::kFloatConcatMax5Max10Max10)) { + : QuantizeModelTest( + ReadModel(::mlir::lite::internal::kFloatConcatMax5Max10Max10)) { // Make one of the values constant. 
MakeInputConstant(&model_); } @@ -2224,7 +2243,8 @@ class BiasInputTest : public QuantizeModelTest, public testing::WithParamInterface { protected: BiasInputTest() - : QuantizeModelTest(ReadModel(internal::kConvModelWith0Plus10Weights)) { + : QuantizeModelTest( + ReadModel(::mlir::lite::internal::kConvModelWith0Plus10Weights)) { BiasTestType obj = GetParam(); tensor_type_ = obj.tensor_type; bias_type_ = obj.bias_type; diff --git a/tensorflow/lite/tools/optimize/quantize_weights_test.cc b/tensorflow/lite/tools/optimize/quantize_weights_test.cc index 0e9c3efc17acd9..b2279ed34908f6 100644 --- a/tensorflow/lite/tools/optimize/quantize_weights_test.cc +++ b/tensorflow/lite/tools/optimize/quantize_weights_test.cc @@ -22,13 +22,13 @@ limitations under the License. #include #include "flatbuffers/flatbuffers.h" // from @flatbuffers #include "flatbuffers/flexbuffers.h" // from @flatbuffers +#include "tensorflow/compiler/mlir/lite/quantization/lite/test_util.h" #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/util/command_line_flags.h" #include "tensorflow/lite/core/model.h" #include "tensorflow/lite/schema/schema_generated.h" #include "tensorflow/lite/schema/schema_utils.h" -#include "tensorflow/lite/tools/optimize/test_util.h" namespace { tensorflow::string* g_test_model_dir = nullptr; @@ -40,25 +40,25 @@ namespace { std::unique_ptr ReadTestModel() { auto model_path = tensorflow::io::JoinPath( - *g_test_model_dir, internal::kConvModelWith0Plus10Weights); + *g_test_model_dir, ::mlir::lite::internal::kConvModelWith0Plus10Weights); return FlatBufferModel::BuildFromFile(model_path.c_str()); } std::unique_ptr ReadSharedWeightsTestModel() { - auto model_path = tensorflow::io::JoinPath(*g_test_model_dir, - internal::kModelWithSharedWeights); + auto model_path = tensorflow::io::JoinPath( + *g_test_model_dir, ::mlir::lite::internal::kModelWithSharedWeights); return FlatBufferModel::BuildFromFile(model_path.c_str()); } 
std::unique_ptr ReadGatherTestModel() { - auto model_path = tensorflow::io::JoinPath(*g_test_model_dir, - internal::kQuantizedWithGather); + auto model_path = tensorflow::io::JoinPath( + *g_test_model_dir, ::mlir::lite::internal::kQuantizedWithGather); return FlatBufferModel::BuildFromFile(model_path.c_str()); } std::unique_ptr ReadCustomOpTestModel() { - auto model_path = - tensorflow::io::JoinPath(*g_test_model_dir, internal::kModelWithCustomOp); + auto model_path = tensorflow::io::JoinPath( + *g_test_model_dir, ::mlir::lite::internal::kModelWithCustomOp); return FlatBufferModel::BuildFromFile(model_path.c_str()); } diff --git a/tensorflow/lite/tools/optimize/reduced_precision_support_test.cc b/tensorflow/lite/tools/optimize/reduced_precision_support_test.cc index 6b5cf538b50c43..19400079b17e96 100644 --- a/tensorflow/lite/tools/optimize/reduced_precision_support_test.cc +++ b/tensorflow/lite/tools/optimize/reduced_precision_support_test.cc @@ -25,7 +25,6 @@ limitations under the License. #include "tensorflow/lite/schema/schema_generated.h" #include "tensorflow/lite/schema/schema_utils.h" #include "tensorflow/lite/testing/util.h" -#include "tensorflow/lite/tools/optimize/test_util.h" namespace tflite { namespace optimize { From d10098df02f5cf2b26e0f791bf9e1d9a6b1989df Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 25 Jun 2024 18:23:58 -0700 Subject: [PATCH 248/256] In this change we 1. rename GetInstructionSize to ByteSizeOfShape. For the case where a sharding is provided, we also compute the bytes of the sharded shape by first computing explicitly the sharded shape, rather than merely dividing the byte size of the unsharded shape by the number of tiles it is sharded into. This takes into account any padding that might be needed when sharding tensors. 2. Simplify GetBytes to remove a redundant if condition 3. Use a pointer size of 8 bytes to compute shape byte sizes. This helps us better account of tuple metadata storage. 
PiperOrigin-RevId: 646676353 --- .../auto_sharding/auto_sharding.cc | 9 ++--- .../auto_sharding/auto_sharding_util.cc | 38 +++++++++++++------ .../auto_sharding/auto_sharding_util.h | 15 ++++++-- 3 files changed, 42 insertions(+), 20 deletions(-) diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc index 6c23923952f623..0d8ebd650c3422 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc @@ -1630,7 +1630,7 @@ void CheckMemoryCosts(StrategyGroup* strategy_group, const Shape& shape) { for (const auto& strategy : strategy_group->strategies) { if (strategy.output_sharding.IsReplicated()) { full_mem = strategy.memory_cost; - size_t size = GetInstructionSize(shape); + size_t size = ByteSizeOfShape(shape); CHECK_EQ(strategy.memory_cost, size); } } @@ -2186,7 +2186,7 @@ void CheckHloSharding( ins->opcode() != HloOpcode::kGetTupleElement) { // TODO(yuemmawang) Check other cases when it's helpful (it's not // needed so far). - double size = GetInstructionSize(ins->shape()) / 1024 / 1024 / 1024; + double size = ByteSizeOfShape(ins->shape()) / 1024 / 1024 / 1024; if ((!ShardingIsComplete(ins->sharding(), total_num_devices) || ins->sharding().IsReplicated()) && size > 1) { @@ -2224,7 +2224,7 @@ void CheckHloSharding( // of resharding overheads, and some inconsistent shardings are // unavoidable. 
size_t op_size = - GetInstructionSize(op->shape()) / (1024.0 * 1024 * 1024); + ByteSizeOfShape(op->shape()) / (1024.0 * 1024 * 1024); std::string str = absl::StrCat("Shardings not consistent (op size ", op_size, " GB):", ins->ToString(), "\n Operand: ", op->ToString()); @@ -4290,8 +4290,7 @@ bool ModuleIsManuallyPartitioned(const HloModule* module) { bool IsSmallTensor(const HloInstruction* ins, const AutoShardingOption& option) { - return spmd::GetInstructionSize(ins->shape()) <= - option.small_tensor_byte_size; + return spmd::ByteSizeOfShape(ins->shape()) <= option.small_tensor_byte_size; } bool ShardedOnTooManyMeshAxes(const HloModule& module) { diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc index 49ef8d22830957..a6399b7965c18b 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc @@ -1416,8 +1416,7 @@ absl::Status FixMixedMeshShapeReshardingGetTupleElement( HloInstruction* replace_with = ReshardTensor(inst, src_sharding, dst_sharding, device_mesh); inst->set_sharding(src_sharding); - size_t size = - GetInstructionSize(replace_with->shape()) / (1024 * 1024 * 1024); + size_t size = ByteSizeOfShape(replace_with->shape()) / (1024 * 1024 * 1024); if (size > 1) { LOG(WARNING) << "Large reshape instruction inserted (operand of " << inst->name() << ") with size " << size @@ -1484,8 +1483,7 @@ absl::Status FixMixedMeshShapeResharding(HloInstruction* inst, int operand_num, } } - size_t size = - GetInstructionSize(replace_with->shape()) / (1024 * 1024 * 1024); + size_t size = ByteSizeOfShape(replace_with->shape()) / (1024 * 1024 * 1024); if (size > 1) { LOG(WARNING) << "Large reshape instruction inserted (operand of " << inst->name() << ") with size " << size @@ -1854,15 +1852,32 @@ std::vector VectorGreaterThanOneElementIndices( return result; } 
-int64_t GetInstructionSize(const Shape& shape) { +int64_t ByteSizeOfShapeWithSharding(const Shape& shape, + std::optional sharding) { if (shape.IsTuple()) { - int64_t size = 0; - for (const Shape& subshape : shape.tuple_shapes()) { - size += GetInstructionSize(subshape); + int64_t size = + ShapeUtil::ByteSizeOf(shape, /*pointer_size=*/kAutoShardingPointerSize); + for (size_t i = 0; i < shape.tuple_shapes_size(); i++) { + const Shape& subshape = shape.tuple_shapes().at(i); + if (sharding) { + const HloSharding& sub_sharding = + sharding->IsTuple() + ? sharding->GetSubSharding(shape, + ShapeIndex{static_cast(i)}) + : *sharding; + size += ByteSizeOfShapeWithSharding(subshape, sub_sharding); + } else { + size += ByteSizeOfShapeWithSharding(subshape, std::nullopt); + } } return size; + } else if (shape.IsArray()) { + return ShapeUtil::ByteSizeOf(sharding ? sharding->TileShape(shape) : shape); + } else if (shape.IsToken()) { + return 0; + } else { + return kAutoShardingPointerSize; } - return GetBytes(shape); } int64_t GetShardedInstructionSize(const Shape& shape, int64_t num_devices, @@ -1871,9 +1886,10 @@ int64_t GetShardedInstructionSize(const Shape& shape, int64_t num_devices, sharding = HloSharding::Replicate(); } if (shape.IsTuple()) { - int64_t size = 0; + int64_t size = + ShapeUtil::ByteSizeOf(shape, /*pointer_size=*/kAutoShardingPointerSize); for (size_t i = 0; i < shape.tuple_shapes_size(); i++) { - Shape subshape = shape.tuple_shapes().at(i); + const Shape& subshape = shape.tuple_shapes().at(i); size += GetShardedInstructionSize( subshape, sharding.has_value() diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.h b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.h index dd78eda6ebf304..5b3e65f164155a 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.h +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.h @@ -54,6 +54,8 @@ inline constexpr 
absl::string_view kIdentityMarker = "identity"; inline constexpr absl::string_view kPipelineMarkerStartType = "start"; inline constexpr absl::string_view kPipelineMarkerEndType = "end"; +inline constexpr int64_t kAutoShardingPointerSize = 8; + inline bool IsSPMDFullToShardShapeCustomCall(const HloInstruction* ins) { return ins->IsCustomCall("SPMDFullToShardShape"); } @@ -157,9 +159,6 @@ std::string ToString(absl::Span span) { // Get the number of bytes of a shape. inline double GetBytes(const Shape& shape) { - if (shape.IsArray()) { - return ShapeUtil::ByteSizeOfElements(shape); - } return ShapeUtil::ByteSizeOf(shape, /*pointer_size=*/8); } @@ -584,7 +583,15 @@ size_t VectorGreaterThanOneElementCount(absl::Span span, std::vector VectorGreaterThanOneElementIndices( absl::Span span, bool omit_last_dim = false); -int64_t GetInstructionSize(const Shape& shape); +// Computes bytes size of a shape recursively if it is sharded according to an +// optionally provided sharding +int64_t ByteSizeOfShapeWithSharding(const Shape& shape, + std::optional sharding); + +// Computes bytes size of a shape recursively +inline int64_t ByteSizeOfShape(const Shape& shape) { + return ByteSizeOfShapeWithSharding(shape, /*sharding=*/std::nullopt); +} int64_t GetShardedInstructionSize( const Shape& shape, int64_t num_devices, From 7817e4afe2be703f2705d03cc915b434421d9d33 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 25 Jun 2024 18:38:14 -0700 Subject: [PATCH 249/256] Change AllReduceSimplifier to handle trivial cross-partition all-reduces. This CL ensures that AllReduceSimplifier can simplify trivial all-reduces (an all-reduce where each subgroup is formed of a single participant) that are not necessarily cross replica (for example a cross partition all-reduce). We only simplify non cross replica all-reduce when the module is SPMD. 
PiperOrigin-RevId: 646679054 --- third_party/xla/xla/service/BUILD | 5 + .../xla/xla/service/all_reduce_simplifier.cc | 62 ++++++++--- .../xla/service/all_reduce_simplifier_test.cc | 91 ++++++++++++++++ .../xla/xla/service/collective_ops_utils.cc | 8 +- .../xla/service/collective_ops_utils_test.cc | 103 ++++++++++++++++-- 5 files changed, 240 insertions(+), 29 deletions(-) diff --git a/third_party/xla/xla/service/BUILD b/third_party/xla/xla/service/BUILD index 553d29e960458e..a4fc2041629600 100644 --- a/third_party/xla/xla/service/BUILD +++ b/third_party/xla/xla/service/BUILD @@ -3057,11 +3057,14 @@ cc_library( srcs = ["all_reduce_simplifier.cc"], hdrs = ["all_reduce_simplifier.h"], deps = [ + ":collective_ops_utils", + ":hlo_module_config", ":hlo_pass", ":hlo_replication_analysis", "//xla:literal_util", "//xla:shape_util", "//xla/hlo/ir:hlo", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", "@com_google_absl//absl/status:statusor", @@ -3076,6 +3079,7 @@ xla_cc_test( srcs = ["all_reduce_simplifier_test.cc"], deps = [ ":all_reduce_simplifier", + ":hlo_module_config", ":hlo_parser", ":pattern_matcher", ":pattern_matcher_gmock", @@ -7254,6 +7258,7 @@ xla_cc_test( "//xla/tests:xla_internal_test_main", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", diff --git a/third_party/xla/xla/service/all_reduce_simplifier.cc b/third_party/xla/xla/service/all_reduce_simplifier.cc index 5837ea49da0aae..0760433bda4489 100644 --- a/third_party/xla/xla/service/all_reduce_simplifier.cc +++ b/third_party/xla/xla/service/all_reduce_simplifier.cc @@ -19,14 +19,18 @@ limitations under the License. 
#include #include +#include "absl/algorithm/container.h" #include "absl/container/flat_hash_set.h" #include "absl/log/log.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/literal_util.h" +#include "xla/service/collective_ops_utils.h" +#include "xla/service/hlo_module_config.h" #include "xla/service/hlo_replication_analysis.h" #include "xla/shape_util.h" #include "tsl/platform/errors.h" @@ -42,22 +46,33 @@ absl::StatusOr AllReduceSimplifier::Run( HloReplicationAnalysis::Run(module, /*cross_partition_spmd=*/false)); std::vector> all_reduces_to_replace; - // Returns the size of a replica group if all groups have the same size, or -1 - // if they have different sizes. - auto get_replica_group_size = - [this](const HloInstruction* all_reduce) -> int64_t { - if (all_reduce->replica_groups().empty()) { - return replica_count_; + // Returns the number of participants in a replica group if all groups have + // the same size, or -1 if they have different sizes. + // Number of participants depends on the mode of the collective operation. 
+ auto get_participant_counts_for_replica_group = + [](const HloInstruction* all_reduce) -> absl::StatusOr { + const HloModuleConfig& config = all_reduce->GetModule()->config(); + TF_ASSIGN_OR_RETURN( + CollectiveOpGroupMode group_mode, + GetCollectiveOpGroupMode(all_reduce->channel_id().has_value(), + Cast(all_reduce) + ->use_global_device_ids())); + + int64_t num_devices = config.num_partitions(); + int64_t num_replicas = config.replica_count(); + TF_ASSIGN_OR_RETURN(std::vector participant_counts, + GetPariticipantCountsForReplicaGroups( + num_replicas, num_devices, + all_reduce->replica_groups(), group_mode)); + if (participant_counts.empty()) { + return -1; } - int64_t replica_group_size = -1; - for (const auto& group : all_reduce->replica_groups()) { - if (replica_group_size == -1) { - replica_group_size = group.replica_ids_size(); - } else if (replica_group_size != group.replica_ids_size()) { - return -1; - } + if (!absl::c_all_of(participant_counts, [&](int64_t participant_count) { + return participant_count == participant_counts[0]; + })) { + return -1; } - return replica_group_size; + return participant_counts[0]; }; bool changed = false; @@ -83,11 +98,24 @@ absl::StatusOr AllReduceSimplifier::Run( // optimize out (being fed within a tuple input). continue; } - if (!inst->IsCrossReplicaAllReduce()) { + if (!inst->IsCrossReplicaAllReduce() && !inst->IsCrossModuleAllReduce()) { continue; } - int64_t group_size = get_replica_group_size(inst); - if (group_size == -1) { + TF_ASSIGN_OR_RETURN(int64_t group_size, + get_participant_counts_for_replica_group(inst)); + + // We will not simplify this all reduce if any of the following is true: + // 1. All group do not have the same size. + // + // 2. The AllReduce is not cross replica and the group size is not 1. + // Since the replication analysis performed earlier is only for cross + // replica spmd. + // + // 3. The AllReduce is not cross replica and the module is not using spmd. 
+ if (group_size == -1 || + (!inst->IsCrossReplicaAllReduce() && group_size != 1) || + (!inst->IsCrossReplicaAllReduce() && + !module->config().use_spmd_partitioning())) { continue; } if (replication->HloInstructionIsReplicatedAt(inst->operand(0), {}) || diff --git a/third_party/xla/xla/service/all_reduce_simplifier_test.cc b/third_party/xla/xla/service/all_reduce_simplifier_test.cc index 0843fc6df1a87b..e78881a0c19292 100644 --- a/third_party/xla/xla/service/all_reduce_simplifier_test.cc +++ b/third_party/xla/xla/service/all_reduce_simplifier_test.cc @@ -21,6 +21,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/service/hlo_module_config.h" #include "xla/service/hlo_parser.h" #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" @@ -191,5 +192,95 @@ test { EXPECT_THAT(module->entry_computation()->root_instruction(), GmockMatch(m::Parameter(0))); } + +TEST_F(AllReduceSimplifierTest, TrivialSubgroupNonCrossReplicaAllReduce) { + const char* kModuleStr = R"( +HloModule m + +sum { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT add.2 = f32[] add(a, b) +} + + +test { + p0 = f32[8,16] parameter(0), parameter_replication={false} + ROOT all-reduce = f32[8,16] all-reduce(p0), + channel_id=1, + use_global_device_ids=true, + replica_groups={{0},{1},{2},{3},{4},{5},{6},{7}}, + to_apply=sum +} +)"; + TF_ASSERT_OK_AND_ASSIGN( + auto module, ParseAndReturnVerifiedModule(kModuleStr, /*replica_count=*/1, + /*num_partitions=*/8)); + module->mutable_config().set_use_spmd_partitioning(true); + AllReduceSimplifier simplifier(/*replica_count=*/1); + EXPECT_TRUE(simplifier.Run(module.get()).value()); + EXPECT_THAT(module->entry_computation()->root_instruction(), + GmockMatch(m::Parameter(0))); +} + +TEST_F(AllReduceSimplifierTest, NonCrossReplicaAllReduceAfterAllReduce) { + const char* kModuleStr = R"( +HloModule m + +sum { + a = 
f32[] parameter(0) + b = f32[] parameter(1) + ROOT add.2 = f32[] add(a, b) +} + + +test { + p0 = f32[8,16] parameter(0), parameter_replication={false} + all-reduce = f32[8,16] all-reduce(p0), + channel_id=1, + use_global_device_ids=true, + replica_groups={{0,2},{1,3},{4,6},{5,7}}, + to_apply=sum + ROOT all-reduce.1 = f32[8,16] all-reduce(all-reduce), + channel_id=2, + use_global_device_ids=true, + replica_groups={{0,4},{1,5},{2,6},{3,7}}, + to_apply=sum +} +)"; + TF_ASSERT_OK_AND_ASSIGN( + auto module, ParseAndReturnVerifiedModule(kModuleStr, /*replica_count=*/1, + /*num_partitions=*/8)); + module->mutable_config().set_use_spmd_partitioning(true); + AllReduceSimplifier simplifier(/*replica_count=*/1); + EXPECT_FALSE(simplifier.Run(module.get()).value()); +} + +TEST_F(AllReduceSimplifierTest, MPMDNonCrossReplicaAllReduce) { + const char* kModuleStr = R"( +HloModule m + +sum { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT add.2 = f32[] add(a, b) +} + +test { + p0 = f32[8,16] parameter(0), parameter_replication={false} + ROOT all-reduce = f32[8,16] all-reduce(p0), + channel_id=1, + replica_groups={{0},{1}}, + to_apply=sum +} +)"; + TF_ASSERT_OK_AND_ASSIGN( + auto module, ParseAndReturnVerifiedModule(kModuleStr, /*replica_count=*/2, + /*num_partitions=*/1)); + // Mark as MPMD. 
+ module->mutable_config().set_use_spmd_partitioning(false); + AllReduceSimplifier simplifier(/*replica_count=*/2); + EXPECT_FALSE(simplifier.Run(module.get()).value()); +} } // namespace } // namespace xla diff --git a/third_party/xla/xla/service/collective_ops_utils.cc b/third_party/xla/xla/service/collective_ops_utils.cc index 01c4bba5abfefb..5bd802c343f523 100644 --- a/third_party/xla/xla/service/collective_ops_utils.cc +++ b/third_party/xla/xla/service/collective_ops_utils.cc @@ -515,8 +515,12 @@ absl::StatusOr> GetPariticipantCountsForReplicaGroups( switch (group_mode) { case CollectiveOpGroupMode::kCrossReplica: { - participant_counts.resize(participating_replica_groups.size(), - num_partitions); + for (const auto& replica_group : participating_replica_groups) { + for (int partition_id = 0; partition_id < num_partitions; + ++partition_id) { + participant_counts.push_back(replica_group.replica_ids().size()); + } + } return participant_counts; } case CollectiveOpGroupMode::kCrossPartition: { diff --git a/third_party/xla/xla/service/collective_ops_utils_test.cc b/third_party/xla/xla/service/collective_ops_utils_test.cc index 9fbcaf5f1f2ba2..c71776323f869f 100644 --- a/third_party/xla/xla/service/collective_ops_utils_test.cc +++ b/third_party/xla/xla/service/collective_ops_utils_test.cc @@ -22,6 +22,8 @@ limitations under the License. #include #include +#include +#include #include "absl/algorithm/container.h" #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_computation.h" @@ -129,6 +131,21 @@ TEST(CollectiveOpsUtilsTest, CollectiveWithChannelId2) { EXPECT_EQ(IsOrHasCollectiveWithChannelId(fusion2.get()), nullptr); } +// Creates a container of ReplicaGroups. 
+std::vector CreateReplicaGroups( + const std::vector> &replica_groups) { + std::vector result; + result.reserve(replica_groups.size()); + for (const auto &replica_group : replica_groups) { + ReplicaGroup group; + for (auto id : replica_group) { + group.add_replica_ids(id); + } + result.push_back(group); + } + return result; +} + } // namespace // Tests for GetCollectOpGroupMode @@ -190,7 +207,7 @@ namespace GetParticipatingDevicesTest { // expected output corresponding to those values. struct TestCase { xla::Array2D device_assignment; - std::vector> replica_groups; + std::vector> replica_groups; bool has_channel_id; std::optional use_global_device_ids; @@ -455,15 +472,8 @@ TEST_P(GetParticipatingDevicesTest, Test) { } } - std::vector replica_groups; - absl::c_transform(tc.replica_groups, std::back_inserter(replica_groups), - [](const std::vector &ids) { - ReplicaGroup group; - for (int id : ids) { - group.add_replica_ids(id); - } - return group; - }); + std::vector replica_groups = + CreateReplicaGroups(tc.replica_groups); absl::StatusOr group_mode = GetCollectiveOpGroupMode(tc.has_channel_id, tc.use_global_device_ids); @@ -518,4 +528,77 @@ INSTANTIATE_TEST_SUITE_P(GetParticipatingDevices, GetParticipatingDevicesTest, testing::ValuesIn(GetTestCases())); } // namespace GetParticipatingDevicesTest + +namespace GetPariticipantCountsForReplicaGroupsTest { + +struct TestCase { + std::string test_name; + std::vector> replica_groups; + CollectiveOpGroupMode group_mode; + int64_t num_replicas; + int64_t num_partitions; + std::vector expected; +}; + +class GetPariticipantCountsForReplicaGroupsTest + : public testing::TestWithParam {}; + +TEST_P(GetPariticipantCountsForReplicaGroupsTest, Test) { + const TestCase &tc = GetParam(); + + std::vector replica_groups = + CreateReplicaGroups(tc.replica_groups); + TF_ASSERT_OK_AND_ASSIGN( + std::vector actual, + GetPariticipantCountsForReplicaGroups(tc.num_replicas, tc.num_partitions, + replica_groups, tc.group_mode)); + 
EXPECT_THAT(actual, testing::ElementsAreArray(tc.expected)); +} + +std::vector GetTestCases() { + return { + { + "CrossReplicaEmptyGroup", + {}, + CollectiveOpGroupMode::kCrossReplica, + 8, + 1, + {8}, + }, + { + "CrossReplicaWithPartitions", + {{0, 1}, {2, 3}}, + CollectiveOpGroupMode::kCrossReplica, + 4, + 2, + {2, 2, 2, 2}, + }, + { + "CrossReplicaAndPartition", + {{0, 1}, {2, 3}}, + CollectiveOpGroupMode::kCrossReplicaAndPartition, + 4, + 2, + {4, 4}, + }, + { + "FlattenedID", + {{0}, {1}, {2}, {3}, {4}, {5}, {6}, {7}}, + CollectiveOpGroupMode::kFlattenedID, + 4, + 2, + {1, 1, 1, 1, 1, 1, 1, 1}, + }, + }; +} +INSTANTIATE_TEST_SUITE_P( + GetPariticipantCountsForReplicaGroups, + GetPariticipantCountsForReplicaGroupsTest, + testing::ValuesIn(GetTestCases()), + [](const testing::TestParamInfo< + GetPariticipantCountsForReplicaGroupsTest::ParamType> &info) { + return info.param.test_name; + }); + +} // namespace GetPariticipantCountsForReplicaGroupsTest } // namespace xla From 778ebc83db53e27cb1f6457c7a0d92c5ead0db46 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Tue, 25 Jun 2024 18:46:41 -0700 Subject: [PATCH 250/256] [xla:cpu] NFC: Micro-optimizations for KernelThunk PiperOrigin-RevId: 646680450 --- .../select_and_scatter_benchmark_test.cc | 3 +- .../xla/service/cpu/runtime/kernel_thunk.cc | 30 +++++++++++-------- .../xla/stream_executor/host/host_kernel.cc | 16 +++++----- .../xla/stream_executor/host/host_kernel.h | 2 ++ 4 files changed, 28 insertions(+), 23 deletions(-) diff --git a/third_party/xla/xla/service/cpu/benchmarks/select_and_scatter_benchmark_test.cc b/third_party/xla/xla/service/cpu/benchmarks/select_and_scatter_benchmark_test.cc index bbc32250444b0f..b03df23e41e764 100644 --- a/third_party/xla/xla/service/cpu/benchmarks/select_and_scatter_benchmark_test.cc +++ b/third_party/xla/xla/service/cpu/benchmarks/select_and_scatter_benchmark_test.cc @@ -76,7 +76,6 @@ BENCHMARK(BM_SelectAndScatterF32) ->Arg(64) ->Arg(128) ->Arg(256) - ->Arg(512) - 
->Arg(1024); + ->Arg(512); } // namespace xla::cpu diff --git a/third_party/xla/xla/service/cpu/runtime/kernel_thunk.cc b/third_party/xla/xla/service/cpu/runtime/kernel_thunk.cc index a8d793d1076071..21c57fef35f940 100644 --- a/third_party/xla/xla/service/cpu/runtime/kernel_thunk.cc +++ b/third_party/xla/xla/service/cpu/runtime/kernel_thunk.cc @@ -87,40 +87,46 @@ tsl::AsyncValueRef KernelThunk::Execute( kernel_name_, arguments_buffers_.size(), results_buffers_.size(), thread_dim_.ToString()); - absl::InlinedVector kernel_args; - kernel_args.reserve(arguments_buffers_.size() + results_buffers_.size()); + int64_t num_args = arguments_buffers_.size() + results_buffers_.size(); + absl::InlinedVector kernel_args(num_args); + + // We initialize `kernel_args` array using pointer to the first argument, + // because individual elements access adds up measurable overhead, and this + // code is on the critical path. + SE_HOST_KernelArg* kernel_args_ptr = kernel_args.data(); + int64_t kernel_arg_idx = 0; int64_t arg_num = 0; for (BufferAllocation::Slice& buffer : arguments_buffers_) { TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase arg_data, params.buffer_allocations->GetDeviceAddress(buffer)); - kernel_args.push_back( - SE_HOST_KernelArg{arg_data.opaque(), arg_data.size()}); VLOG(3) << absl::StreamFormat(" arg #%d: %s (%p)", arg_num++, - buffer.ToString(), kernel_args.back().data); + buffer.ToString(), arg_data.opaque()); + kernel_args_ptr[kernel_arg_idx++] = + SE_HOST_KernelArg{arg_data.opaque(), arg_data.size()}; } int64_t res_num = 0; for (BufferAllocation::Slice& buffer : results_buffers_) { TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase result_data, params.buffer_allocations->GetDeviceAddress(buffer)); - kernel_args.push_back( - SE_HOST_KernelArg{result_data.opaque(), result_data.size()}); VLOG(3) << absl::StreamFormat(" res #%d: %s (%p)", res_num++, - buffer.ToString(), kernel_args.back().data); + buffer.ToString(), result_data.opaque()); + kernel_args_ptr[kernel_arg_idx++] = + 
SE_HOST_KernelArg{result_data.opaque(), result_data.size()}; } // Check that all buffers are aligned to the minimum alignment. We codegen // with the assumption that all buffers are aligned, and if they are not, we // will crash with a segmentation fault, or worse, produce incorrect results. if (min_alignment_.has_value()) { - for (int64_t i = 0; i < kernel_args.size(); ++i) { - auto ptr = reinterpret_cast(kernel_args[i].data); + for (int64_t i = 0; i < num_args; ++i) { + auto ptr = reinterpret_cast(kernel_args_ptr[i].data); if (ABSL_PREDICT_FALSE((ptr & (*min_alignment_ - 1)) != 0)) { return Internal( "Host kernel %s buffer argument #%d (%p) is not aligned to a " "required minimum alignment of %d bytes", - info().op_name, i, kernel_args[i].data, *min_alignment_); + info().op_name, i, kernel_args_ptr[i].data, *min_alignment_); } } } @@ -136,7 +142,7 @@ tsl::AsyncValueRef KernelThunk::Execute( params.host_kernels->Find(kernel_name_)); absl::MutexLock lock(&mutex_); - kernel_.emplace(kernel_args.size(), kernel_fn, nullptr); + kernel_.emplace(num_args, kernel_fn, nullptr); kernel_ptr_.store(kernel = &kernel_.value()); } diff --git a/third_party/xla/xla/stream_executor/host/host_kernel.cc b/third_party/xla/xla/stream_executor/host/host_kernel.cc index 04586b5272432b..cad37e1bfa4fb0 100644 --- a/third_party/xla/xla/stream_executor/host/host_kernel.cc +++ b/third_party/xla/xla/stream_executor/host/host_kernel.cc @@ -67,8 +67,7 @@ class HostKernelExecuteState : public tsl::ReferenceCounted { public: HostKernelExecuteState(HostKernel::TaskRunner task_runner, - HostKernel::KernelFunction* function, - ThreadDim thread_dims, + SE_HOST_Kernel* kernel, ThreadDim thread_dims, absl::Span args); // Notify of a completion of a host kernel task. 
@@ -112,6 +111,7 @@ HostKernel::HostKernel(std::shared_ptr thread_pool) HostKernel::HostKernel(unsigned arity, SE_HOST_Kernel* kernel, std::shared_ptr thread_pool) : function_(std::make_unique(kernel)), + kernel_(function_->kernel()), arity_(arity), thread_pool_(thread_pool) {} @@ -130,8 +130,6 @@ absl::Status HostKernel::Launch( thread_dims.z, }; - SE_HOST_Kernel* kernel = function_->kernel(); - for (uint64_t z = 0; z < thread_dims.z; ++z) { for (uint64_t y = 0; y < thread_dims.y; ++y) { for (uint64_t x = 0; x < thread_dims.x; ++x) { @@ -140,7 +138,7 @@ absl::Status HostKernel::Launch( SE_HOST_KernelCallFrame call_frame = { &kernel_thread_dims, &kernel_thread, args.size(), args.data()}; - SE_HOST_KernelError* error = (*kernel)(&call_frame); + SE_HOST_KernelError* error = (*kernel_)(&call_frame); if (ABSL_PREDICT_FALSE(error != nullptr)) { return absl::InternalError("Failed to call host kernel"); @@ -174,8 +172,8 @@ tsl::AsyncValueRef HostKernel::Launch( } // Allocate a control structure that will orchestrate kernel execution. 
- auto state = tsl::MakeRef( - std::move(task_runner), function_.get(), thread_dims, args); + auto state = tsl::MakeRef(std::move(task_runner), + kernel_, thread_dims, args); state->CallAsync(/*start_index=*/0, /*end_index=*/num_tasks); @@ -183,11 +181,11 @@ tsl::AsyncValueRef HostKernel::Launch( } HostKernelExecuteState::HostKernelExecuteState( - HostKernel::TaskRunner task_runner, HostKernel::KernelFunction* function, + HostKernel::TaskRunner task_runner, SE_HOST_Kernel kernel, ThreadDim thread_dims, absl::Span args) : task_runner_(std::move(task_runner)), num_tasks_(thread_dims.x * thread_dims.y * thread_dims.z), - kernel_(function->kernel()), + kernel_(kernel), thread_dims_({thread_dims.x, thread_dims.y, thread_dims.z}), args_(args.begin(), args.end()), abort_(false), diff --git a/third_party/xla/xla/stream_executor/host/host_kernel.h b/third_party/xla/xla/stream_executor/host/host_kernel.h index 9d278b2b79c357..9bc96cb9e7ca2a 100644 --- a/third_party/xla/xla/stream_executor/host/host_kernel.h +++ b/third_party/xla/xla/stream_executor/host/host_kernel.h @@ -113,10 +113,12 @@ class HostKernel : public Kernel { std::enable_if_t>* = nullptr> void SetKernelFunction(std::unique_ptr function) { function_ = std::move(function); + kernel_ = function_->kernel(); } private: std::unique_ptr function_; + SE_HOST_Kernel* kernel_; // pointer to the kernel owned by `function_` unsigned arity_; std::shared_ptr thread_pool_; From b9bf806a8b752628d34ecf5aeb344352c851f183 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Tue, 25 Jun 2024 19:06:57 -0700 Subject: [PATCH 251/256] [xla:cpu] Move BufferAllocations implementation to header file Resolving buffer slice device memory is on a critical path of every thunk. Move implementation to header and force inlining to improve performance of ultra small kernels. 
PiperOrigin-RevId: 646684024 --- third_party/xla/xla/service/cpu/runtime/BUILD | 19 ++++- .../service/cpu/runtime/buffer_allocations.cc | 76 ----------------- .../service/cpu/runtime/buffer_allocations.h | 85 ++++++++++++++++--- .../cpu/runtime/buffer_allocations_test.cc | 53 ++++++++++++ 4 files changed, 144 insertions(+), 89 deletions(-) delete mode 100644 third_party/xla/xla/service/cpu/runtime/buffer_allocations.cc create mode 100644 third_party/xla/xla/service/cpu/runtime/buffer_allocations_test.cc diff --git a/third_party/xla/xla/service/cpu/runtime/BUILD b/third_party/xla/xla/service/cpu/runtime/BUILD index 8c37031edf7985..f9ad7ef5c51300 100644 --- a/third_party/xla/xla/service/cpu/runtime/BUILD +++ b/third_party/xla/xla/service/cpu/runtime/BUILD @@ -17,21 +17,36 @@ package_group( cc_library( name = "buffer_allocations", - srcs = ["buffer_allocations.cc"], hdrs = ["buffer_allocations.h"], deps = [ + "//xla:util", "//xla/service:buffer_assignment", "//xla/service:maybe_owning_device_memory", "//xla/stream_executor", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:statusor", ], ) +xla_cc_test( + name = "buffer_allocations_test", + srcs = ["buffer_allocations_test.cc"], + deps = [ + ":buffer_allocations", + "//xla/service:buffer_assignment", + "//xla/service:maybe_owning_device_memory", + "//xla/stream_executor:device_memory", + "@com_google_absl//absl/status", + "@local_tsl//tsl/lib/core:status_test_util", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", + ], +) + cc_library( name = "task", hdrs = ["task.h"], diff --git a/third_party/xla/xla/service/cpu/runtime/buffer_allocations.cc b/third_party/xla/xla/service/cpu/runtime/buffer_allocations.cc deleted file mode 100644 index e35b931c08e5bc..00000000000000 --- 
a/third_party/xla/xla/service/cpu/runtime/buffer_allocations.cc +++ /dev/null @@ -1,76 +0,0 @@ -/* Copyright 2024 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/cpu/runtime/buffer_allocations.h" - -#include - -#include "absl/base/optimization.h" -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "absl/strings/str_cat.h" -#include "xla/service/buffer_assignment.h" -#include "xla/stream_executor/device_memory.h" -#include "tsl/platform/statusor.h" - -namespace xla::cpu { - -absl::StatusOr BufferAllocations::GetDeviceAddress( - BufferAllocation::Index buffer_index) const { - if (ABSL_PREDICT_FALSE(buffer_index < 0 || buffer_index >= buffers_.size())) { - return absl::InvalidArgumentError(absl::StrCat( - "Invalid buffer_index ", buffer_index, - " value. It must be in the range [0, ", buffers_.size(), ")")); - } - - return buffers_[buffer_index].AsDeviceMemoryBase(); -} - -absl::StatusOr BufferAllocations::GetDeviceAddress( - const BufferAllocation::Slice& buffer_slice) const { - // Handle empty slices explicitly and return a null pointer device memory to - // guarantee that we do not accidentally write through the empty slice which - // would hide a real bug in the code. 
- if (ABSL_PREDICT_FALSE(buffer_slice.size() == 0)) { - return se::DeviceMemoryBase(nullptr, 0); - } - - int64_t index = buffer_slice.index(); - TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase base, GetDeviceAddress(index)); - - int64_t offset = buffer_slice.offset(); - int64_t extent = offset + buffer_slice.size(); - - if (ABSL_PREDICT_FALSE(offset < 0)) { - return absl::InvalidArgumentError( - absl::StrCat("Buffer slice offset ", offset, " must be non-negative")); - } - - if (ABSL_PREDICT_FALSE(offset >= base.size())) { - return absl::InvalidArgumentError(absl::StrCat( - "Buffer slice offset ", offset, " is out of range for buffer #", index, - " of size ", base.size())); - } - - if (ABSL_PREDICT_FALSE(extent > base.size())) { - return absl::InvalidArgumentError(absl::StrCat( - "Buffer slice extent ", extent, " is out of range for buffer #", index, - " of size ", base.size())); - } - - return base.GetByteSlice(offset, buffer_slice.size()); -} - -} // namespace xla::cpu diff --git a/third_party/xla/xla/service/cpu/runtime/buffer_allocations.h b/third_party/xla/xla/service/cpu/runtime/buffer_allocations.h index 76f05390a01b07..7abcff73fb5b66 100644 --- a/third_party/xla/xla/service/cpu/runtime/buffer_allocations.h +++ b/third_party/xla/xla/service/cpu/runtime/buffer_allocations.h @@ -16,39 +16,102 @@ limitations under the License. #ifndef XLA_SERVICE_CPU_RUNTIME_BUFFER_ALLOCATIONS_H_ #define XLA_SERVICE_CPU_RUNTIME_BUFFER_ALLOCATIONS_H_ +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/base/optimization.h" #include "absl/status/statusor.h" #include "absl/types/span.h" #include "xla/service/buffer_assignment.h" #include "xla/service/maybe_owning_device_memory.h" #include "xla/stream_executor/device_memory.h" +#include "xla/util.h" namespace xla::cpu { // Buffer allocation is a container for device buffers allocated for a // particular XLA execution. Buffers are indexed by the buffer allocation index. 
-// -// TODO(b/342513610): BufferAllocations should be unified with a same class in -// the XLA:GPU runtime, probably as a part of `buffer_assignment.h`. class BufferAllocations { public: - explicit BufferAllocations(absl::Span buffers) - : buffers_(buffers) {} + explicit inline BufferAllocations( + absl::Span buffers); // Returns the device address of buffer `buffer_index`. `buffer_index` must be // a valid index, i.e., in [0, buffer_count). - absl::StatusOr GetDeviceAddress( - BufferAllocation::Index buffer_index) const; + inline ABSL_ATTRIBUTE_ALWAYS_INLINE absl::StatusOr + GetDeviceAddress(BufferAllocation::Index buffer_index) const; // Same as above, but also adjusts the returned address for the offset and // size contained in the given slice. - absl::StatusOr GetDeviceAddress( - const BufferAllocation::Slice& buffer_slice) const; + inline ABSL_ATTRIBUTE_ALWAYS_INLINE absl::StatusOr + GetDeviceAddress(const BufferAllocation::Slice& buffer_slice) const; private: - // TODO(ezhulenev): Make BufferAllocations an owner of the buffers. - absl::Span buffers_; // not owned + std::vector buffers_; + size_t num_buffers_; }; +BufferAllocations::BufferAllocations( + absl::Span buffers) + : buffers_(buffers.size()), num_buffers_(buffers_.size()) { + for (size_t i = 0; i < buffers.size(); ++i) { + buffers_[i] = buffers[i].AsDeviceMemoryBase(); + } +} + +absl::StatusOr BufferAllocations::GetDeviceAddress( + BufferAllocation::Index index) const { + if (ABSL_PREDICT_FALSE(index < 0 || index >= num_buffers_)) { + return InvalidArgument( + "Invalid buffer index %d. It must be in the range [0, %d)", index, + num_buffers_); + } + + return buffers_[index]; +} + +absl::StatusOr BufferAllocations::GetDeviceAddress( + const BufferAllocation::Slice& buffer_slice) const { + // Handle empty slices explicitly and return a null pointer device memory to + // guarantee that we do not accidentally write through the empty slice which + // would hide a real bug in the code. 
+ if (ABSL_PREDICT_FALSE(buffer_slice.size() == 0)) { + return se::DeviceMemoryBase(nullptr, 0); + } + + int64_t index = buffer_slice.index(); + if (ABSL_PREDICT_FALSE(index < 0 || index >= num_buffers_)) { + return InvalidArgument( + "Invalid buffer index %d. It must be in the range [0, %d)", index, + num_buffers_); + } + const se::DeviceMemoryBase& base = buffers_[index]; + + int64_t offset = buffer_slice.offset(); + int64_t extent = offset + buffer_slice.size(); + + if (ABSL_PREDICT_FALSE(offset < 0)) { + return InvalidArgument("Buffer slice offset %d must be non-negative", + offset); + } + + if (ABSL_PREDICT_FALSE(offset >= base.size())) { + return InvalidArgument( + "Buffer slice offset %d is out of range for buffer #%d of size %d", + offset, index, base.size()); + } + + if (ABSL_PREDICT_FALSE(extent > base.size())) { + return InvalidArgument( + "Buffer slice extent %d is out of range for buffer #%d of size %d", + extent, index, base.size()); + } + + return base.GetByteSlice(offset, buffer_slice.size()); +} + } // namespace xla::cpu #endif // XLA_SERVICE_CPU_RUNTIME_BUFFER_ALLOCATIONS_H_ diff --git a/third_party/xla/xla/service/cpu/runtime/buffer_allocations_test.cc b/third_party/xla/xla/service/cpu/runtime/buffer_allocations_test.cc new file mode 100644 index 00000000000000..f281924e2542ac --- /dev/null +++ b/third_party/xla/xla/service/cpu/runtime/buffer_allocations_test.cc @@ -0,0 +1,53 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "xla/service/cpu/runtime/buffer_allocations.h" + +#include +#include + +#include "xla/service/buffer_assignment.h" +#include "xla/service/maybe_owning_device_memory.h" +#include "xla/stream_executor/device_memory.h" +#include "tsl/platform/statusor.h" +#include "tsl/platform/test.h" + +namespace xla::cpu { +namespace { + +TEST(BufferAllocationsTest, GetDeviceAddress) { + std::vector buffers; + std::vector data = {1.0, 2.0, 3.0, 4.0}; + + size_t size_in_bytes = data.size() * sizeof(float); + buffers.emplace_back(se::DeviceMemoryBase(data.data(), size_in_bytes)); + + BufferAllocations allocations(buffers); + + BufferAllocation alloc(0, size_in_bytes, 0); + BufferAllocation::Slice slice(&alloc, /*offset=*/2 * sizeof(float), + /*size=*/sizeof(float)); + + TF_ASSERT_OK_AND_ASSIGN(se::DeviceMemoryBase alloc_mem, + allocations.GetDeviceAddress(0)); + EXPECT_EQ(alloc_mem.opaque(), &data[0]); + + TF_ASSERT_OK_AND_ASSIGN(se::DeviceMemoryBase slice_mem, + allocations.GetDeviceAddress(slice)); + EXPECT_EQ(slice_mem.opaque(), &data[2]); +} + +} // namespace +} // namespace xla::cpu From 03816bce65879c7921578ed6aaf39280897257a4 Mon Sep 17 00:00:00 2001 From: Ilia Sergachev Date: Tue, 25 Jun 2024 20:34:33 -0700 Subject: [PATCH 252/256] PR #13869: [GPU] Dump cuDNN GEMM fusion graphs to json files. Imported from GitHub PR https://github.com/openxla/xla/pull/13869 Requires cuDNN frontend v1.5: merge https://github.com/openxla/xla/pull/13757 first. Copybara import of the project: -- afe425460a5135f5e8e1ad72accfa52fb9c9808a by Ilia Sergachev : [GPU] Dump cuDNN GEMM fusion graphs to json files. 
Merging this change closes #13869 PiperOrigin-RevId: 646703335 --- third_party/xla/xla/service/gpu/BUILD | 1 + .../xla/service/gpu/cudnn_fusion_compiler.cc | 12 ++- third_party/xla/xla/service/gpu/fusions/BUILD | 2 + .../xla/xla/service/gpu/fusions/cudnn_test.cc | 75 +++++++++++++++++++ 4 files changed, 89 insertions(+), 1 deletion(-) diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index 0726033d31e418..1699d7f5ee5afe 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -3014,6 +3014,7 @@ cc_library( "//xla/service:hlo_pass", "//xla/stream_executor:dnn", "//xla/stream_executor:stream_executor_h", + "//xla/service:dump", "//xla/stream_executor/cuda:cudnn_frontend_helpers", "//xla/stream_executor/cuda:cudnn_plugin", "@local_tsl//tsl/platform:errors", diff --git a/third_party/xla/xla/service/gpu/cudnn_fusion_compiler.cc b/third_party/xla/xla/service/gpu/cudnn_fusion_compiler.cc index 1f459446c38425..f1136d116fd313 100644 --- a/third_party/xla/xla/service/gpu/cudnn_fusion_compiler.cc +++ b/third_party/xla/xla/service/gpu/cudnn_fusion_compiler.cc @@ -43,6 +43,7 @@ limitations under the License. 
#include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/utils/hlo_query.h" #include "xla/primitive_util.h" +#include "xla/service/dump.h" #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/cudnn_support_utils.h" #include "xla/service/gpu/ir_emission_utils.h" @@ -573,6 +574,16 @@ absl::StatusOr> HloFusionToCuDnnGraph( .set_dim(dimensions) .set_stride(strides) .set_uid(se::gpu::CuDnnTensorUID(fusion.operand_count())); + if (!fusion.GetModule()->config().debug_options().xla_dump_to().empty()) { + json dump; + graph.serialize(dump); + DumpToFileInDirOrStdout( + /*module=*/*fusion.GetModule(), + /*file_prefix=*/"", + /*file_suffix=*/ + absl::StrCat("cudnn_fusion_", fusion.name(), ".json"), + /*contents=*/dump.dump(1)); + } if (cudnn_frontend::error_t result = graph.validate(); result.is_bad()) { VLOG(3) << result.get_message(); return std::nullopt; @@ -589,7 +600,6 @@ absl::StatusOr PrepareGraph( if (!graph.has_value()) { return absl::InternalError("Construction of cuDNN graph failed."); } - VLOG(6) << graph->Graph().print(); TF_ASSIGN_OR_RETURN(bool supported, graph->Prepare(dnn_support)); if (!supported) { return absl::InternalError("cuDNN graph is not supported."); diff --git a/third_party/xla/xla/service/gpu/fusions/BUILD b/third_party/xla/xla/service/gpu/fusions/BUILD index 8af2b159c2967d..aae31efb38fcc8 100644 --- a/third_party/xla/xla/service/gpu/fusions/BUILD +++ b/third_party/xla/xla/service/gpu/fusions/BUILD @@ -677,6 +677,7 @@ xla_test( "//xla:xla_data_proto_cc", "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", + "//xla/service:dump", "//xla/service:executable", "//xla/service:pattern_matcher", "//xla/service:pattern_matcher_gmock", @@ -693,6 +694,7 @@ xla_test( "@com_google_absl//absl/strings", "@com_google_googletest//:gtest_main", "@local_tsl//tsl/lib/core:status_test_util", + "@local_tsl//tsl/platform:path", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test_main", ], diff --git 
a/third_party/xla/xla/service/gpu/fusions/cudnn_test.cc b/third_party/xla/xla/service/gpu/fusions/cudnn_test.cc index 0c9dec9849da0c..4eeed4ed8c0606 100644 --- a/third_party/xla/xla/service/gpu/fusions/cudnn_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/cudnn_test.cc @@ -29,6 +29,7 @@ limitations under the License. #include "xla/error_spec.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/primitive_util.h" +#include "xla/service/dump.h" #include "xla/service/executable.h" #include "xla/service/gpu/cudnn_fusion_compiler.h" #include "xla/service/gpu/runtime/thunk.h" @@ -44,6 +45,7 @@ limitations under the License. #include "xla/xla.pb.h" #include "xla/xla_data.pb.h" #include "tsl/lib/core/status_test_util.h" +#include "tsl/platform/path.h" #include "tsl/platform/statusor.h" namespace xla { @@ -83,6 +85,79 @@ class CuDnnFusionTest : public GpuCodegenTest { } }; +TEST_F(CuDnnFusionTest, DumpingWorks) { + HloModuleConfig config; + DebugOptions options = GetDebugOptionsForTest(); + std::string output_directory; + if (!tsl::io::GetTestUndeclaredOutputsDir(&output_directory)) { + output_directory = tsl::testing::TmpDir(); + } + options.set_xla_dump_to(output_directory); + config.set_debug_options(options); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( +fd0 { + p0 = f32[64,64] parameter(0) + p1 = f32[64,64] parameter(1) + ROOT d = f32[64,64] dot(p0, p1), lhs_contracting_dims={1}, rhs_contracting_dims={0} +} + +ENTRY e { + p0 = f32[64,64] parameter(0) + p1 = f32[64,64] parameter(1) + ROOT d0 = f32[64,64] fusion(p0, p1), kind=kCustom, calls=fd0, + backend_config={"fusion_backend_config":{"kind":"__cudnn$fusion","cudnn_fusion_config":{"plan_id":"0"}}} +})", + config)); + Thunk::BinaryMap dnn_compiled_graphs; + CuDnnFusionCompiler cudnn_compiler(*backend().default_stream_executor(), + dnn_compiled_graphs); + TF_ASSERT_OK_AND_ASSIGN(bool changed, cudnn_compiler.Run(module.get())); + EXPECT_TRUE(changed); + std::string dump; + 
TF_EXPECT_OK(tsl::ReadFileToString( + tsl::Env::Default(), + tsl::io::JoinPath(output_directory, + FilenameFor(*module, /*prefix=*/"", + /*suffix=*/"cudnn_fusion_d0.json")), + &dump)); + EXPECT_TRUE(*RunFileCheck(dump, R"( +CHECK: "nodes": [ +CHECK: "inputs": { +CHECK: "A": "p0", +CHECK: "B": "p1" +CHECK: }, +CHECK: "outputs": { +CHECK: "C": "d" +CHECK: }, +CHECK: "tag": "MATMUL" +CHECK: } +CHECK: ], +CHECK: "tensors": { +CHECK: "d": { +CHECK: "data_type": "FLOAT", +CHECK: "dim": [1,64,64], +CHECK: "stride": [1,64,1], +CHECK: "uid": 3, +CHECK: "uid_assigned": true +CHECK: }, +CHECK: "p0": { +CHECK: "data_type": "FLOAT", +CHECK: "dim": [1,64,64], +CHECK: "stride": [1,64,1], +CHECK: "uid": 1, +CHECK: "uid_assigned": true +CHECK: }, +CHECK: "p1": { +CHECK: "data_type": "FLOAT", +CHECK: "dim": [1,64,64], +CHECK: "stride": [1,64,1], +CHECK: "uid": 2, +CHECK: "uid_assigned": true +CHECK: } +)")); +} + using CuDnnFusionExecutionTest = CuDnnFusionTest; namespace m = ::xla::match; From d24dadc14e6172b30ed45324c849e2666b505c3f Mon Sep 17 00:00:00 2001 From: Jieying Luo Date: Tue, 25 Jun 2024 21:31:17 -0700 Subject: [PATCH 253/256] Fix a bug in PjRtStreamExecutorLoadedExecutable::GetOutputMemoryKinds. If there are no addressable devices, the code will crash when trying to access the default memory space of the first device. Also changed to const & for two loop variables to avoid copies. 
PiperOrigin-RevId: 646715753 --- third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc b/third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc index fb143187699dd7..be44421dcd84de 100644 --- a/third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc +++ b/third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc @@ -3278,7 +3278,7 @@ absl::StatusOr> MemoryKindsFromShape( } std::vector result; result.reserve(shape.tuple_shapes_size()); - for (auto element_shape : shape.tuple_shapes()) { + for (const auto& element_shape : shape.tuple_shapes()) { TF_ASSIGN_OR_RETURN( absl::string_view element_memory_kind, MemoryKindFromSimpleShape(element_shape, default_memory_kind)); @@ -3292,11 +3292,16 @@ absl::StatusOr> MemoryKindsFromShape( absl::StatusOr>> PjRtStreamExecutorLoadedExecutable::GetOutputMemoryKinds() const { TF_ASSIGN_OR_RETURN(auto shapes, GetOutputShapes()); + if (addressable_devices().empty()) { + return Unimplemented( + "GetOutputMemoryKinds is not supported when there are no addressable " + "devices in PjRtStreamExecutorLoadedExecutable."); + } TF_ASSIGN_OR_RETURN(PjRtMemorySpace * default_memory_space, addressable_devices()[0]->default_memory_space()); std::vector> out; out.reserve(shapes.size()); - for (auto shape : shapes) { + for (const auto& shape : shapes) { TF_ASSIGN_OR_RETURN( std::vector memory_kind, MemoryKindsFromShape(shape, default_memory_space->kind())); From 80c7862a3eb91ed0d57ce81cdc2ae3821eb8c3b3 Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Tue, 25 Jun 2024 22:17:55 -0700 Subject: [PATCH 254/256] Automated Code Change PiperOrigin-RevId: 646725767 --- .../test/src/org/tensorflow/demo/CameraConnectionFragment.java | 2 +- .../org/tensorflow/demo/LegacyCameraConnectionFragment.java | 3 ++- .../src/org/tensorflow/demo/TensorFlowMultiBoxDetector.java | 2 +- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/tensorflow/tools/android/test/src/org/tensorflow/demo/CameraConnectionFragment.java b/tensorflow/tools/android/test/src/org/tensorflow/demo/CameraConnectionFragment.java index 361cf0e87c7f80..974ccd90ad9b6b 100644 --- a/tensorflow/tools/android/test/src/org/tensorflow/demo/CameraConnectionFragment.java +++ b/tensorflow/tools/android/test/src/org/tensorflow/demo/CameraConnectionFragment.java @@ -401,7 +401,7 @@ private void setUpCameraOutputs() { // reuse throughout app. ErrorDialog.newInstance(getString(R.string.camera_error)) .show(getChildFragmentManager(), FRAGMENT_DIALOG); - throw new RuntimeException(getString(R.string.camera_error)); + throw new RuntimeException(getString(R.string.camera_error), e); } cameraConnectionCallback.onPreviewSizeChosen(previewSize, sensorOrientation); diff --git a/tensorflow/tools/android/test/src/org/tensorflow/demo/LegacyCameraConnectionFragment.java b/tensorflow/tools/android/test/src/org/tensorflow/demo/LegacyCameraConnectionFragment.java index 068c7b0d945669..b16c21a89bd8e6 100644 --- a/tensorflow/tools/android/test/src/org/tensorflow/demo/LegacyCameraConnectionFragment.java +++ b/tensorflow/tools/android/test/src/org/tensorflow/demo/LegacyCameraConnectionFragment.java @@ -208,8 +208,9 @@ private int getCameraId() { CameraInfo ci = new CameraInfo(); for (int i = 0; i < Camera.getNumberOfCameras(); i++) { Camera.getCameraInfo(i, ci); - if (ci.facing == CameraInfo.CAMERA_FACING_BACK) + if (ci.facing == CameraInfo.CAMERA_FACING_BACK) { return i; + } } return -1; // No camera found } diff --git 
a/tensorflow/tools/android/test/src/org/tensorflow/demo/TensorFlowMultiBoxDetector.java b/tensorflow/tools/android/test/src/org/tensorflow/demo/TensorFlowMultiBoxDetector.java index ea837e62d5b06a..4ed90925d9badc 100644 --- a/tensorflow/tools/android/test/src/org/tensorflow/demo/TensorFlowMultiBoxDetector.java +++ b/tensorflow/tools/android/test/src/org/tensorflow/demo/TensorFlowMultiBoxDetector.java @@ -116,7 +116,7 @@ public static Classifier create( try { d.loadCoderOptions(assetManager, locationFilename, d.boxPriors); } catch (final IOException e) { - throw new RuntimeException("Error initializing box priors from " + locationFilename); + throw new RuntimeException("Error initializing box priors from " + locationFilename, e); } // Pre-allocate buffers. From cfc5c2e0b08a97dff1d27e76422d6ded37a90456 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 26 Jun 2024 00:03:43 -0700 Subject: [PATCH 255/256] Automated Code Change PiperOrigin-RevId: 646751150 --- tensorflow/core/common_runtime/executor.cc | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tensorflow/core/common_runtime/executor.cc b/tensorflow/core/common_runtime/executor.cc index 81bc4df917807e..48ec47636e30df 100644 --- a/tensorflow/core/common_runtime/executor.cc +++ b/tensorflow/core/common_runtime/executor.cc @@ -647,7 +647,7 @@ void ExecutorState::ProcessAsync( tsl::profiler::GetTFTraceMeLevel(/*is_expensive=*/false)); // Trace async op start. - profiler::TraceMeProducer producer( + tsl::profiler::TraceMeProducer producer( [&] { return tsl::profiler::TraceMeEncode( "ExecutorState::ProcessAsync::Start", @@ -659,7 +659,7 @@ void ExecutorState::ProcessAsync( auto done = [this, state, activity_id, ctx_id = producer.GetContextId()]() { // Trace async op done. 
- profiler::TraceMeConsumer consumer( + tsl::profiler::TraceMeConsumer consumer( [&] { return profiler::TraceMeEncode( "ExecutorState::ProcessAsync::Done", @@ -807,13 +807,13 @@ void ExecutorState::ProcessInline( bool completed = false; int64_t last_iter_num = -1; - std::unique_ptr iteration_scope; + std::unique_ptr iteration_scope; while (!inline_ready->empty()) { TaggedNode tagged_node = inline_ready->front(); int64_t current_iter_num = tagged_node.get_iter_num(); if (current_iter_num != last_iter_num) { - iteration_scope = std::make_unique( + iteration_scope = std::make_unique( // From TraceMeProducer in DirectSession::RunInternal, // GraphMgr::ExecuteAsync, or FunctionLibraryRuntime::Run. [&] { @@ -1459,7 +1459,7 @@ void ExecutorState::Finish() { } delete this; runner([step_id, trace_id, status, done_cb = std::move(done_cb)]() { - profiler::TraceMeConsumer activity( + tsl::profiler::TraceMeConsumer activity( // From TraceMeProducer in KernelAndDeviceFunc::RunAsync, // DirectSession::RunInternal or GraphMgr::ExecuteAsync. [&] { @@ -1482,7 +1482,7 @@ void ExecutorState::Finish() { done_cb = std::move(done_cb)](const Status& status) mutable { delete this; runner([step_id, trace_id, status, done_cb = std::move(done_cb)]() { - profiler::TraceMeConsumer activity( + tsl::profiler::TraceMeConsumer activity( // From TraceMeProducer in KernelAndDeviceFunc::RunAsync, // DirectSession::RunInternal or GraphMgr::ExecuteAsync. [&] { @@ -1497,7 +1497,7 @@ void ExecutorState::Finish() { } else { delete this; runner([step_id, trace_id, status, done_cb = std::move(done_cb)]() { - profiler::TraceMeConsumer activity( + tsl::profiler::TraceMeConsumer activity( // From TraceMeProducer in KernelAndDeviceFunc::RunAsync, // DirectSession::RunInternal or GraphMgr::ExecuteAsync. [&] { From dd054d82b22efb51ac99185266920c2a6b9f8aea Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Tue, 18 Jun 2024 01:05:15 -0700 Subject: [PATCH 256/256] [NCCL] Upgrade TF NCCL version to 2.21.5 PiperOrigin-RevId: 644285199 --- tensorflow/tools/pip_package/setup.py | 2 +- tensorflow/workspace2.bzl | 6 +-- third_party/nccl/archive.BUILD | 4 +- third_party/nccl/archive.patch | 50 +++++++++++++++++-- .../tsl/third_party/nccl/archive.BUILD | 4 +- .../tsl/third_party/nccl/archive.patch | 50 +++++++++++++++++-- .../xla/third_party/tsl/workspace2.bzl | 6 +-- 7 files changed, 105 insertions(+), 17 deletions(-) diff --git a/tensorflow/tools/pip_package/setup.py b/tensorflow/tools/pip_package/setup.py index 8e7c1d2b22836b..44e0112141d770 100644 --- a/tensorflow/tools/pip_package/setup.py +++ b/tensorflow/tools/pip_package/setup.py @@ -160,7 +160,7 @@ def standard_or_nightly(standard, nightly): 'nvidia-curand-cu12 == 10.3.4.107', 'nvidia-cusolver-cu12 == 11.5.4.101', 'nvidia-cusparse-cu12 == 12.2.0.103', - 'nvidia-nccl-cu12 == 2.19.3', + 'nvidia-nccl-cu12 == 2.21.5', 'nvidia-nvjitlink-cu12 == 12.3.101', ] diff --git a/tensorflow/workspace2.bzl b/tensorflow/workspace2.bzl index 08d01a069bb30b..380a1a93585bde 100644 --- a/tensorflow/workspace2.bzl +++ b/tensorflow/workspace2.bzl @@ -516,9 +516,9 @@ def _tf_repositories(): name = "nccl_archive", build_file = "//third_party:nccl/archive.BUILD", patch_file = ["//third_party/nccl:archive.patch"], - sha256 = "1c5474553afedb88e878c772f13d6f90b9226b3f2971dfa6f873adb9443100c2", - strip_prefix = "nccl-2.19.3-1", - urls = tf_mirror_urls("https://github.com/nvidia/nccl/archive/v2.19.3-1.tar.gz"), + sha256 = "1923596984d85e310b5b6c52b2c72a1b93da57218f2bc5a5c7ac3d59297a3303", + strip_prefix = "nccl-2.21.5-1", + urls = tf_mirror_urls("https://github.com/nvidia/nccl/archive/v2.21.5-1.tar.gz"), ) java_import_external( diff --git a/third_party/nccl/archive.BUILD b/third_party/nccl/archive.BUILD index 72f91a68474f97..1f4b58f47e379c 100644 --- a/third_party/nccl/archive.BUILD +++ b/third_party/nccl/archive.BUILD @@ 
-22,9 +22,9 @@ exports_files(["LICENSE.txt"]) NCCL_MAJOR = 2 -NCCL_MINOR = 19 +NCCL_MINOR = 21 -NCCL_PATCH = 3 +NCCL_PATCH = 5 NCCL_VERSION = NCCL_MAJOR * 10000 + NCCL_MINOR * 100 + NCCL_PATCH # e.g., 21605 diff --git a/third_party/nccl/archive.patch b/third_party/nccl/archive.patch index 372c0f493fc387..2b4fa56a97e759 100644 --- a/third_party/nccl/archive.patch +++ b/third_party/nccl/archive.patch @@ -1,9 +1,31 @@ +diff --git a/src/device/all_gather.h b/src/device/all_gather.h +index 809e8ae..57eab81 100644 +--- a/src/device/all_gather.h ++++ b/src/device/all_gather.h +@@ -296,7 +296,7 @@ struct RunWorkElement(scat); ++ prims.template process(scat); + } + } + return; +@@ -314,7 +314,7 @@ struct RunWorkElement(scat); ++ prims.template process(scat); + } + return; + } diff --git a/src/device/common.cu b/src/device/common.cu.cc similarity index 100% rename from src/device/common.cu rename to src/device/common.cu.cc diff --git a/src/device/common.h b/src/device/common.h -index 97581f7..134fdb8 100644 +index d8581d3..09ac3b6 100644 --- a/src/device/common.h +++ b/src/device/common.h @@ -15,7 +15,7 @@ @@ -14,9 +36,9 @@ index 97581f7..134fdb8 100644 +extern __device__ ncclDevFuncPtr_t ncclDevFuncTable[]; struct ncclShmemGroup { - ncclConnInfo *recvConns[NCCL_MAX_NVLS_ARITY]; + ncclConnInfo *recvConns[NCCL_MAX_ARITY]; diff --git a/src/device/generate.py b/src/device/generate.py -index 0b053de..87bf6cb 100755 +index 43de85d..87cd677 100755 --- a/src/device/generate.py +++ b/src/device/generate.py @@ -195,7 +195,7 @@ kernel_funcs = sorted(set(best_kernel(*fn) for fn in primary_funcs)) @@ -111,3 +133,25 @@ diff --git a/src/device/onerank.cu b/src/device/onerank.cu.cc similarity index 100% rename from src/device/onerank.cu rename to src/device/onerank.cu.cc +diff --git a/src/device/reduce_scatter.h b/src/device/reduce_scatter.h +index d0b5249..2dacd60 100644 +--- a/src/device/reduce_scatter.h ++++ b/src/device/reduce_scatter.h +@@ -254,7 +254,7 @@ struct RunWorkElement(scat); 
++ prims.template process(scat); + } + return; + } +@@ -278,7 +278,7 @@ struct RunWorkElement(scat); ++ prims.template process(scat); + } + } + return; diff --git a/third_party/xla/third_party/tsl/third_party/nccl/archive.BUILD b/third_party/xla/third_party/tsl/third_party/nccl/archive.BUILD index 72f91a68474f97..1f4b58f47e379c 100644 --- a/third_party/xla/third_party/tsl/third_party/nccl/archive.BUILD +++ b/third_party/xla/third_party/tsl/third_party/nccl/archive.BUILD @@ -22,9 +22,9 @@ exports_files(["LICENSE.txt"]) NCCL_MAJOR = 2 -NCCL_MINOR = 19 +NCCL_MINOR = 21 -NCCL_PATCH = 3 +NCCL_PATCH = 5 NCCL_VERSION = NCCL_MAJOR * 10000 + NCCL_MINOR * 100 + NCCL_PATCH # e.g., 21605 diff --git a/third_party/xla/third_party/tsl/third_party/nccl/archive.patch b/third_party/xla/third_party/tsl/third_party/nccl/archive.patch index 372c0f493fc387..2b4fa56a97e759 100644 --- a/third_party/xla/third_party/tsl/third_party/nccl/archive.patch +++ b/third_party/xla/third_party/tsl/third_party/nccl/archive.patch @@ -1,9 +1,31 @@ +diff --git a/src/device/all_gather.h b/src/device/all_gather.h +index 809e8ae..57eab81 100644 +--- a/src/device/all_gather.h ++++ b/src/device/all_gather.h +@@ -296,7 +296,7 @@ struct RunWorkElement(scat); ++ prims.template process(scat); + } + } + return; +@@ -314,7 +314,7 @@ struct RunWorkElement(scat); ++ prims.template process(scat); + } + return; + } diff --git a/src/device/common.cu b/src/device/common.cu.cc similarity index 100% rename from src/device/common.cu rename to src/device/common.cu.cc diff --git a/src/device/common.h b/src/device/common.h -index 97581f7..134fdb8 100644 +index d8581d3..09ac3b6 100644 --- a/src/device/common.h +++ b/src/device/common.h @@ -15,7 +15,7 @@ @@ -14,9 +36,9 @@ index 97581f7..134fdb8 100644 +extern __device__ ncclDevFuncPtr_t ncclDevFuncTable[]; struct ncclShmemGroup { - ncclConnInfo *recvConns[NCCL_MAX_NVLS_ARITY]; + ncclConnInfo *recvConns[NCCL_MAX_ARITY]; diff --git a/src/device/generate.py b/src/device/generate.py 
-index 0b053de..87bf6cb 100755 +index 43de85d..87cd677 100755 --- a/src/device/generate.py +++ b/src/device/generate.py @@ -195,7 +195,7 @@ kernel_funcs = sorted(set(best_kernel(*fn) for fn in primary_funcs)) @@ -111,3 +133,25 @@ diff --git a/src/device/onerank.cu b/src/device/onerank.cu.cc similarity index 100% rename from src/device/onerank.cu rename to src/device/onerank.cu.cc +diff --git a/src/device/reduce_scatter.h b/src/device/reduce_scatter.h +index d0b5249..2dacd60 100644 +--- a/src/device/reduce_scatter.h ++++ b/src/device/reduce_scatter.h +@@ -254,7 +254,7 @@ struct RunWorkElement(scat); ++ prims.template process(scat); + } + return; + } +@@ -278,7 +278,7 @@ struct RunWorkElement(scat); ++ prims.template process(scat); + } + } + return; diff --git a/third_party/xla/third_party/tsl/workspace2.bzl b/third_party/xla/third_party/tsl/workspace2.bzl index b583f3c0eec7be..6738dbec1b827a 100644 --- a/third_party/xla/third_party/tsl/workspace2.bzl +++ b/third_party/xla/third_party/tsl/workspace2.bzl @@ -397,9 +397,9 @@ def _tf_repositories(): name = "nccl_archive", build_file = "//third_party:nccl/archive.BUILD", patch_file = ["//third_party/nccl:archive.patch"], - sha256 = "1c5474553afedb88e878c772f13d6f90b9226b3f2971dfa6f873adb9443100c2", - strip_prefix = "nccl-2.19.3-1", - urls = tf_mirror_urls("https://github.com/nvidia/nccl/archive/v2.19.3-1.tar.gz"), + sha256 = "1923596984d85e310b5b6c52b2c72a1b93da57218f2bc5a5c7ac3d59297a3303", + strip_prefix = "nccl-2.21.5-1", + urls = tf_mirror_urls("https://github.com/nvidia/nccl/archive/v2.21.5-1.tar.gz"), ) # Note that we are currently taking NVTX headers from a NCCL release to get nvToolsExtPayload.h