From eaef53bb90c18fd6698dc207f47d6ef9bf3310e8 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 19 Jun 2024 03:04:00 -0700 Subject: [PATCH 1/6] Automated Code Change PiperOrigin-RevId: 644686640 --- third_party/xla/third_party/tsl/tsl/platform/env.cc | 3 ++- third_party/xla/third_party/tsl/tsl/platform/env.h | 4 ++-- third_party/xla/third_party/tsl/tsl/platform/file_system.cc | 6 +++--- third_party/xla/third_party/tsl/tsl/platform/file_system.h | 4 ++-- .../xla/third_party/tsl/tsl/platform/retrying_file_system.h | 4 ++-- 5 files changed, 11 insertions(+), 10 deletions(-) diff --git a/third_party/xla/third_party/tsl/tsl/platform/env.cc b/third_party/xla/third_party/tsl/tsl/platform/env.cc index 77f48b3372e1eb..789725e8856b94 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/env.cc +++ b/third_party/xla/third_party/tsl/tsl/platform/env.cc @@ -328,7 +328,8 @@ absl::Status Env::HasAtomicMove(const string& path, bool* has_atomic_move) { return fs->HasAtomicMove(path, has_atomic_move); } -Status Env::CanCreateTempFile(const string& fname, bool* can_create_temp_file) { +absl::Status Env::CanCreateTempFile(const string& fname, + bool* can_create_temp_file) { FileSystem* fs; TF_RETURN_IF_ERROR(GetFileSystemForFile(fname, &fs)); return fs->CanCreateTempFile(fname, can_create_temp_file); diff --git a/third_party/xla/third_party/tsl/tsl/platform/env.h b/third_party/xla/third_party/tsl/tsl/platform/env.h index 37abc1dee97d54..0952517f9b7f8c 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/env.h +++ b/third_party/xla/third_party/tsl/tsl/platform/env.h @@ -352,8 +352,8 @@ class Env { /// If this returns false, TensorFlow will write directly to output files /// instead of creating a temporary file and swapping it in. This may mean /// that incomplete writes are visible to consumers. 
- Status CanCreateTempFile(const std::string& fname, - bool* can_create_temp_file); + absl::Status CanCreateTempFile(const std::string& fname, + bool* can_create_temp_file); /// Stores the size of `fname` in `*file_size`. absl::Status GetFileSize(const std::string& fname, uint64* file_size); diff --git a/third_party/xla/third_party/tsl/tsl/platform/file_system.cc b/third_party/xla/third_party/tsl/tsl/platform/file_system.cc index 68d0fcf0499ca5..ee385af7354074 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/file_system.cc +++ b/third_party/xla/third_party/tsl/tsl/platform/file_system.cc @@ -95,10 +95,10 @@ absl::Status FileSystem::HasAtomicMove(const string& path, return absl::OkStatus(); } -Status FileSystem::CanCreateTempFile(const std::string& fname, - bool* can_create_temp_file) { +absl::Status FileSystem::CanCreateTempFile(const std::string& fname, + bool* can_create_temp_file) { *can_create_temp_file = true; - return OkStatus(); + return absl::OkStatus(); } void FileSystem::FlushCaches(TransactionToken* token) {} diff --git a/third_party/xla/third_party/tsl/tsl/platform/file_system.h b/third_party/xla/third_party/tsl/tsl/platform/file_system.h index 4b728a42c4d507..67209ed491055f 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/file_system.h +++ b/third_party/xla/third_party/tsl/tsl/platform/file_system.h @@ -392,8 +392,8 @@ class FileSystem { /// to determine if there needs to be a temp location to safely write objects. /// If the file system cannot create a temp file, it's possibile that /// uncomplete result may appear in the given file. - virtual Status CanCreateTempFile(const std::string& fname, - bool* can_create_temp_file); + virtual absl::Status CanCreateTempFile(const std::string& fname, + bool* can_create_temp_file); /// \brief Flushes any cached filesystem objects from memory. 
virtual void FlushCaches() { FlushCaches(nullptr); } diff --git a/third_party/xla/third_party/tsl/tsl/platform/retrying_file_system.h b/third_party/xla/third_party/tsl/tsl/platform/retrying_file_system.h index 88da8787c3d618..a64ecc20e960ff 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/retrying_file_system.h +++ b/third_party/xla/third_party/tsl/tsl/platform/retrying_file_system.h @@ -151,8 +151,8 @@ class RetryingFileSystem : public FileSystem { return base_file_system_->HasAtomicMove(path, has_atomic_move); } - Status CanCreateTempFile(const std::string& fname, - bool* can_create_temp_file) override { + absl::Status CanCreateTempFile(const std::string& fname, + bool* can_create_temp_file) override { // this method does not need to be retried return base_file_system_->CanCreateTempFile(fname, can_create_temp_file); } From cdafce893a83095e93018257db4ec0326f92bec7 Mon Sep 17 00:00:00 2001 From: George Karpenkov Date: Wed, 19 Jun 2024 03:18:55 -0700 Subject: [PATCH 2/6] [XLA:GPU] [NFC] Remove argument which is never passed PiperOrigin-RevId: 644690105 --- .../xla/xla/service/gpu/runtime/sequential_thunk.cc | 2 +- third_party/xla/xla/service/gpu/runtime/thunk.cc | 7 +------ third_party/xla/xla/service/gpu/runtime/thunk.h | 4 +--- 3 files changed, 3 insertions(+), 10 deletions(-) diff --git a/third_party/xla/xla/service/gpu/runtime/sequential_thunk.cc b/third_party/xla/xla/service/gpu/runtime/sequential_thunk.cc index c58f31b5207a7a..a0dc62fbc7155d 100644 --- a/third_party/xla/xla/service/gpu/runtime/sequential_thunk.cc +++ b/third_party/xla/xla/service/gpu/runtime/sequential_thunk.cc @@ -33,7 +33,7 @@ SequentialThunk::SequentialThunk(ThunkInfo thunk_info, ThunkSequence thunks) std::string SequentialThunk::ToStringExtra(int indent) const { std::string result = "\n"; - absl::StrAppend(&result, thunks().ToString(indent + 1, nullptr)); + absl::StrAppend(&result, thunks().ToString(indent + 1)); return result; } diff --git 
a/third_party/xla/xla/service/gpu/runtime/thunk.cc b/third_party/xla/xla/service/gpu/runtime/thunk.cc index 0c95623db733be..877db05fe04e89 100644 --- a/third_party/xla/xla/service/gpu/runtime/thunk.cc +++ b/third_party/xla/xla/service/gpu/runtime/thunk.cc @@ -303,9 +303,7 @@ std::ostream& operator<<(std::ostream& os, Thunk::Kind kind) { return os << Thunk::KindToString(kind); } -std::string ThunkSequence::ToString( - int indent, - std::function get_thunk_annotation) const { +std::string ThunkSequence::ToString(int indent) const { const std::string indent_str(indent * 2, ' '); if (empty()) return indent_str + "No thunks."; @@ -324,9 +322,6 @@ std::string ThunkSequence::ToString( absl::StrAppend(&result, indent_str, kind_str, std::string(max_thunk_kind_len - kind_str.length(), ' '), "\t"); - if (get_thunk_annotation) { - absl::StrAppend(&result, get_thunk_annotation(thunk.get())); - } absl::StrAppend(&result, thunk->ToStringExtra(indent)); absl::StrAppend(&result, "\n"); } diff --git a/third_party/xla/xla/service/gpu/runtime/thunk.h b/third_party/xla/xla/service/gpu/runtime/thunk.h index 6a271ce8f623ba..78698859bf8553 100644 --- a/third_party/xla/xla/service/gpu/runtime/thunk.h +++ b/third_party/xla/xla/service/gpu/runtime/thunk.h @@ -460,9 +460,7 @@ class Thunk { // A sequence of thunks. class ThunkSequence : public std::vector> { public: - std::string ToString(int indent = 0, - std::function - get_thunk_annotation = nullptr) const; + std::string ToString(int indent = 0) const; }; std::ostream& operator<<(std::ostream& os, Thunk::Kind kind); From 160e76085434b413699e6979c800edaef838b302 Mon Sep 17 00:00:00 2001 From: "Dimitar (Mitko) Asenov" Date: Wed, 19 Jun 2024 03:39:07 -0700 Subject: [PATCH 3/6] [XLA:GPU] Simplify TritonSupport tests by providing a standard ENTRY computation. 
PiperOrigin-RevId: 644694885 --- third_party/xla/xla/service/gpu/BUILD | 3 + .../xla/service/gpu/triton_support_test.cc | 120 ++---------------- .../xla/xla/service/gpu/triton_test_utils.cc | 50 ++++++++ .../xla/xla/service/gpu/triton_test_utils.h | 7 + 4 files changed, 72 insertions(+), 108 deletions(-) diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index bf3e191aec5e35..dba824f313e4d6 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -671,6 +671,7 @@ cc_library( deps = [ ":gpu_device_info_for_tests", ":gpu_float_support", + ":ir_emission_utils", ":ir_emitter_triton", ":matmul_utils", "//xla:shape_util", @@ -684,6 +685,7 @@ cc_library( "//xla/stream_executor:device_description", "//xla/tests:filecheck", "//xla/tests:verified_hlo_module", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -693,6 +695,7 @@ cc_library( "@llvm-project//llvm:Support", "@llvm-project//llvm:ir_headers", "@llvm-project//mlir:IR", + "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:statusor", ], ) diff --git a/third_party/xla/xla/service/gpu/triton_support_test.cc b/third_party/xla/xla/service/gpu/triton_support_test.cc index 6c690bad4c47a9..3cf75a4cbad7e9 100644 --- a/third_party/xla/xla/service/gpu/triton_support_test.cc +++ b/third_party/xla/xla/service/gpu/triton_support_test.cc @@ -72,16 +72,9 @@ using BitcastOrReshapeTest = TritonSupportTestWithParam; TEST_P(BitcastOrReshapeTest, IsTritonSupportedBitcastOrReshape) { auto [data_type, opcode] = GetParam(); const std::string kHloTestTemplate = R"( -triton_computation { +ENTRY triton_computation { parameter_0 = $0[1,16,4]{2,1,0} parameter(0) ROOT bitcast_or_reshape = $0[64]{0} $1(parameter_0) -} - -ENTRY e { - parameter_0 = $0[1,16,4]{2,1,0} parameter(0) - ROOT root_op = $0[64]{0} fusion(parameter_0), - kind=kCustom, 
calls=triton_computation, - backend_config={"fusion_backend_config":{"kind":"__triton"}} })"; TF_ASSERT_OK_AND_ASSIGN( TestedInstruction ti, @@ -120,17 +113,10 @@ TEST_P(UnaryElementwiseTest, IsTritonSupportedUnaryElementwise) { } const std::string kHloTestTemplate = R"( -triton_computation { +ENTRY triton_computation { parameter_0 = $0[33,68]{1,0} parameter(0) unary = $0[33,68]{1,0} $1(parameter_0) ROOT convert = f32[33,68]{1,0} convert(unary) -} - -ENTRY e { - parameter_0 = $0[33,68]{1,0} parameter(0) - ROOT root_op = f32[33,68]{1,0} fusion(parameter_0), - kind=kCustom, calls=triton_computation, - backend_config={"fusion_backend_config":{"kind":"__triton"}} })"; TF_ASSERT_OK_AND_ASSIGN( TestedInstruction ti, @@ -184,18 +170,10 @@ TEST_P(BinaryElementwiseTest, IsTritonSupportedBinaryElementwise) { } const std::string kHloTestTemplate = R"( -triton_computation { +ENTRY triton_computation { parameter_0 = $0[11,63]{1,0} parameter(0) parameter_1 = $0[11,63]{1,0} parameter(1) ROOT binary = $0[11,63]{1,0} $1(parameter_0, parameter_1) -} - -ENTRY e { - parameter_0 = $0[11,63]{1,0} parameter(0) - parameter_1 = $0[11,63]{1,0} parameter(1) - ROOT triton_op = $0[11,63]{1,0} fusion(parameter_0, parameter_1), - kind=kCustom, calls=triton_computation, - backend_config={"fusion_backend_config":{"kind":"__triton"}} })"; TF_ASSERT_OK_AND_ASSIGN( TestedInstruction ti, @@ -251,19 +229,11 @@ TEST_P(CompareTest, IsTritonSupportedCompare) { } const std::string kHloTestTemplate = R"( -triton_computation { +ENTRY triton_computation { parameter_0 = $0[11,63]{1,0} parameter(0) parameter_1 = $0[11,63]{1,0} parameter(1) compare = pred[11,63]{1,0} $1(parameter_0, parameter_1), direction=GE ROOT convert = f32[11,63]{1,0} convert(compare) -} - -ENTRY e { - parameter_0 = $0[11,63]{1,0} parameter(0) - parameter_1 = $0[11,63]{1,0} parameter(1) - ROOT triton_op = f32[11,63]{1,0} fusion(parameter_0, parameter_1), - kind=kCustom, calls=triton_computation, - 
backend_config={"fusion_backend_config":{"kind":"__triton"}} })"; TF_ASSERT_OK_AND_ASSIGN( TestedInstruction ti, @@ -298,21 +268,12 @@ TEST_P(TernaryElementwiseTest, IsTritonSupportedTernaryElementwise) { } const std::string kHloTestTemplate = R"( -triton_computation { +ENTRY triton_computation { parameter_0 = $0[13,63]{1,0} parameter(0) parameter_1 = $0[13,63]{1,0} parameter(1) parameter_2 = pred[13,63]{1,0} parameter(2) ternary = $0[13,63]{1,0} $1(parameter_2, parameter_0, parameter_1) ROOT convert = f32[13,63]{1,0} convert(ternary) -} - -ENTRY e { - parameter_0 = $0[13,63]{1,0} parameter(0) - parameter_1 = $0[13,63]{1,0} parameter(1) - parameter_2 = pred[13,63]{1,0} parameter(2) - ROOT triton_op = f32[13,63]{1,0} fusion(parameter_0, parameter_1, parameter_2), - kind=kCustom, calls=triton_computation, - backend_config={"fusion_backend_config":{"kind":"__triton"}} })"; TF_ASSERT_OK_AND_ASSIGN( TestedInstruction ti, @@ -353,18 +314,10 @@ add { ROOT add = $0[] add(Arg_0, Arg_1) } -triton_computation { +ENTRY triton_computation { parameter_0 = $0[125,127]{1,0} parameter(0) constant_0 = $0[] constant(0) ROOT reduce = $0[125]{0} $1(parameter_0, constant_0), dimensions={1}, to_apply=add -} - -ENTRY main { - parameter_0 = $0[125,127]{1,0} parameter(0) - ROOT triton_op = $0[125]{0} fusion(parameter_0), - kind=kCustom, calls=triton_computation, - backend_config={"fusion_backend_config": - {"kind":"__triton"}} })"; TF_ASSERT_OK_AND_ASSIGN( TestedInstruction ti, @@ -406,19 +359,11 @@ add { ROOT add = f32[] add(Arg_0, Arg_1) } -triton_computation { +ENTRY triton_computation { parameter_0 = f32[125,127]{1,0} parameter(0) constant_0 = bf16[] constant(0) convert_0 = f32[] convert(constant_0) ROOT reduce = f32[125]{0} reduce(parameter_0, convert_0), dimensions={1}, to_apply=add -} - -ENTRY main { - parameter_0 = f32[125,127]{1,0} parameter(0) - ROOT triton_op = f32[125]{0} fusion(parameter_0), kind=kCustom, - calls=triton_computation, - backend_config={"fusion_backend_config": - 
{"kind":"__triton"}} })"; TF_ASSERT_OK_AND_ASSIGN(TestedInstruction ti, ParseTemplateAndGetInstruction( @@ -442,18 +387,10 @@ add { ROOT add = f32[] add(Arg_0, Arg_1) } -triton_computation { +ENTRY triton_computation { parameter_0 = f32[2,125,127]{2,1,0} parameter(0) constant_0 = f32[] constant(0) ROOT reduce = f32[2]{0} reduce(parameter_0, constant_0), dimensions={1,2}, to_apply=add -} - -ENTRY main { - parameter_0 = f32[2,125,127]{2,1,0} parameter(0) - ROOT triton_op = f32[2]{0} fusion(parameter_0), - kind=kCustom, calls=triton_computation, - backend_config={"fusion_backend_config": - {"kind":"__triton"}} })"; TF_ASSERT_OK_AND_ASSIGN(TestedInstruction ti, ParseTemplateAndGetInstruction( @@ -479,18 +416,10 @@ add { ROOT add = f32[] add(Arg_0, Arg_1) } -triton_computation { +ENTRY triton_computation { parameter_0 = f32[125,127]{1,0} parameter(0) constant_0 = f32[] constant(0) ROOT reduce = f32[127]{0} reduce(parameter_0, constant_0), dimensions={0}, to_apply=add -} - -ENTRY main { - parameter_0 = f32[125,127]{1,0} parameter(0) - ROOT triton_op = f32[127]{0} fusion(parameter_0), - kind=kCustom, calls=triton_computation, - backend_config={"fusion_backend_config": - {"kind":"__triton"}} })"; TF_ASSERT_OK_AND_ASSIGN(TestedInstruction ti, ParseTemplateAndGetInstruction( @@ -520,19 +449,11 @@ add { ROOT pair = (f32[], f32[]) tuple(add_0, add_1) } -triton_computation { +ENTRY triton_computation { parameter_0 = f32[125,127] parameter(0) constant_0 = f32[] constant(0) tuple_0 = (f32[125]{0}, f32[125]{0}) reduce(parameter_0, parameter_0, constant_0, constant_0), dimensions={1}, to_apply=add ROOT reduce = f32[125]{0} get-tuple-element(tuple_0), index=0 -} - -ENTRY main { - parameter_0 = f32[125,127]{1,0} parameter(0) - ROOT triton_op = f32[125]{0} fusion(parameter_0), - kind=kCustom, calls=triton_computation, - backend_config={"fusion_backend_config": - {"kind":"__triton"}} })"; TF_ASSERT_OK_AND_ASSIGN(TestedInstruction ti, ParseTemplateAndGetInstruction( @@ -557,19 +478,10 
@@ add { ROOT add = f32[] add(Arg_0, Arg_1) } -triton_computation { +ENTRY triton_computation { parameter_0 = f32[125,127]{1,0} parameter(0) init = f32[] parameter(1) ROOT reduce = f32[125]{0} reduce(parameter_0, init), dimensions={1}, to_apply=add -} - -ENTRY main { - parameter_0 = f32[125,127]{1,0} parameter(0) - parameter_1 = f32[] parameter(1) - ROOT triton_op = f32[125]{0} fusion(parameter_0, parameter_1), - kind=kCustom, calls=triton_computation, - backend_config={"fusion_backend_config": - {"kind":"__triton"}} })"; TF_ASSERT_OK_AND_ASSIGN(TestedInstruction ti, ParseTemplateAndGetInstruction( @@ -599,18 +511,10 @@ custom_call { ROOT custom_call = f32[] custom-call(Arg_0, Arg_1), custom_call_target="foo" } -triton_computation { +ENTRY triton_computation { parameter_0 = f32[125,127]{1,0} parameter(0) constant_0 = f32[] constant(0) ROOT reduce = f32[125]{0} reduce(parameter_0, constant_0), dimensions={1}, to_apply=custom_call -} - -ENTRY main { - parameter_0 = f32[125,127]{1,0} parameter(0) - ROOT triton_op = f32[125]{0} fusion(parameter_0), - kind=kCustom, calls=triton_computation, - backend_config={"fusion_backend_config": - {"kind":"__triton"}} })"; TF_ASSERT_OK_AND_ASSIGN(TestedInstruction ti, ParseTemplateAndGetInstruction( diff --git a/third_party/xla/xla/service/gpu/triton_test_utils.cc b/third_party/xla/xla/service/gpu/triton_test_utils.cc index bb5bbe765f858a..97e5925d8a42f2 100644 --- a/third_party/xla/xla/service/gpu/triton_test_utils.cc +++ b/third_party/xla/xla/service/gpu/triton_test_utils.cc @@ -20,8 +20,10 @@ limitations under the License. #include #include #include +#include #include +#include "absl/log/check.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" @@ -41,6 +43,7 @@ limitations under the License. 
#include "xla/service/float_normalization.h" #include "xla/service/gpu/gpu_device_info_for_tests.h" #include "xla/service/gpu/gpu_float_support.h" +#include "xla/service/gpu/ir_emission_utils.h" #include "xla/service/gpu/ir_emitter_triton.h" #include "xla/service/gpu/model/tiled_hlo_computation.h" #include "xla/service/hlo_pass_pipeline.h" @@ -48,6 +51,7 @@ limitations under the License. #include "xla/stream_executor/device_description.h" #include "xla/tests/filecheck.h" #include "xla/tests/verified_hlo_module.h" +#include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" namespace xla::gpu { @@ -140,6 +144,46 @@ std::string TritonSupportTestParamsToString( absl::StrReplaceAll(HloOpcodeString(opcode), {{"-", "_"}})); } +namespace { + +// This function does nothing if the input module already has an entry +// computation whose root is a fusion. Otherwise, creates a new entry +// computation whose root is a fusion instruction that calls the original entry +// computation. The new fusion instruction uses the generic Triton backend kind. 
+absl::Status ConvertEntryToTritonFusion(HloModule* module) { + if (module->entry_computation()->root_instruction()->opcode() == + HloOpcode::kFusion) { + return absl::OkStatus(); + } + auto builder = HloComputation::Builder("entry"); + std::vector params; + for (auto& param : module->entry_computation()->parameter_instructions()) { + TF_ASSIGN_OR_RETURN( + auto param_clone, + builder.AddParameter(HloInstruction::CreateParameter( + param->parameter_number(), param->shape(), + absl::StrCat("param_", param->parameter_number())))); + params.push_back(param_clone); + } + + auto fusion = builder.AddInstruction(HloInstruction::CreateFusion( + module->entry_computation()->root_instruction()->shape(), + HloInstruction::FusionKind::kCustom, params, + module->entry_computation())); + + gpu::GpuBackendConfig gpu_config; + gpu_config.mutable_fusion_backend_config()->set_kind(kTritonFusionKind); + TF_RETURN_IF_ERROR(fusion->set_backend_config(gpu_config)); + + auto new_entry = + module->AddComputationAndUnifyNamesAndIds(builder.Build(), + /*is_entry=*/false); + module->ReplaceEntryComputation(new_entry); + return absl::OkStatus(); +} + +} // namespace + absl::StatusOr TritonSupportTest::ParseTemplateAndGetInstruction( absl::string_view hlo_template, xla::PrimitiveType data_type, @@ -149,8 +193,14 @@ TritonSupportTest::ParseTemplateAndGetInstruction( HloOpcodeString(opcode)); TF_ASSIGN_OR_RETURN(std::unique_ptr module, ParseAndReturnVerifiedModule(hlo_text)); + TF_RETURN_IF_ERROR(ConvertEntryToTritonFusion(module.get())); const HloComputation* computation = module->GetComputationWithName("triton_computation"); + if (computation == module->entry_computation()) { + return absl::InvalidArgumentError( + "The `triton_computation` and the module's entry computation cannot be " + "the same."); + } const HloFusionInstruction* fusion = DynCast( module->entry_computation()->root_instruction()); if (fusion == nullptr) { diff --git a/third_party/xla/xla/service/gpu/triton_test_utils.h 
b/third_party/xla/xla/service/gpu/triton_test_utils.h index de503bf8bd36f8..c7cb16abdf36ec 100644 --- a/third_party/xla/xla/service/gpu/triton_test_utils.h +++ b/third_party/xla/xla/service/gpu/triton_test_utils.h @@ -129,6 +129,13 @@ class TritonSupportTest : public TritonFilecheckTest { // The provided template must contain a computation called // `triton_computation`. If the template contains parameters $0 and $1, they // will be replaced with the data type and opcode respectively. + // If the template's entry computation does not have a root fusion + // instruction, a new entry computation will be created. The new computation + // will have a root fusion instruction that has the same parameters as the + // `triton_computation` and contains a fusion instruction that calls the + // `triton_computation` with the generic Triton emitter. Tests that need + // the `__triton_gemm` backend kind should provide their own ENTRY + // computation. absl::StatusOr ParseTemplateAndGetInstruction( absl::string_view hlo_template, xla::PrimitiveType data_type, xla::HloOpcode opcode); From c72dbfc17b8aba90da3d26036387016cee23bc2f Mon Sep 17 00:00:00 2001 From: George Karpenkov Date: Wed, 19 Jun 2024 03:40:56 -0700 Subject: [PATCH 4/6] [XLA] Remove proto-based communication for service/client Originally, service/client interface in XLA was envisioned as a compilation service interface, with proto-based communication for RPCs. That compilation service never quite worked in that shape, and in the meantime, main compilation API moved to PjRT. Having lots of protos around complicates the XLA compilation API, and serialization/deserialization to protos also hurts performance. Since all Service usages are realistically local, let's just not use protos at the boundary and pass original datastructures. 
PiperOrigin-RevId: 644695252 --- third_party/xla/xla/client/BUILD | 1 - third_party/xla/xla/client/client.cc | 302 +------------- third_party/xla/xla/client/client.h | 27 +- third_party/xla/xla/client/global_data.cc | 73 ---- third_party/xla/xla/client/global_data.h | 33 +- third_party/xla/xla/client/local_client.cc | 7 +- third_party/xla/xla/client/local_client.h | 2 +- third_party/xla/xla/service/BUILD | 3 + .../xla/xla/service/compile_only_service.h | 25 +- third_party/xla/xla/service/service.cc | 369 +++++++++--------- third_party/xla/xla/service/service.h | 120 ++++-- third_party/xla/xla/tests/BUILD | 1 + third_party/xla/xla/tests/client_test.cc | 4 +- .../xla/xla/tests/gather_operation_test.cc | 3 +- third_party/xla/xla/xla.proto | 183 --------- 15 files changed, 325 insertions(+), 828 deletions(-) delete mode 100644 third_party/xla/xla/client/global_data.cc diff --git a/third_party/xla/xla/client/BUILD b/third_party/xla/xla/client/BUILD index e0c53df856f3b1..400eb877cbbce1 100644 --- a/third_party/xla/xla/client/BUILD +++ b/third_party/xla/xla/client/BUILD @@ -29,7 +29,6 @@ filegroup( cc_library( name = "global_data", - srcs = ["global_data.cc"], hdrs = ["global_data.h"], deps = [ "//xla:types", diff --git a/third_party/xla/xla/client/client.cc b/third_party/xla/xla/client/client.cc index a58737a181fd84..6e89947f237ff7 100644 --- a/third_party/xla/xla/client/client.cc +++ b/third_party/xla/xla/client/client.cc @@ -41,129 +41,28 @@ Client::~Client() = default; absl::StatusOr Client::Transfer(const GlobalData& data, const Shape* shape_with_layout) { - TransferToClientRequest request; - *request.mutable_data() = data.handle(); - if (shape_with_layout != nullptr) { - *request.mutable_shape_with_layout() = shape_with_layout->ToProto(); - } - TransferToClientResponse response; - - VLOG(1) << "making transfer request"; - VLOG(3) << "TransferToClientRequest: {" << request.DebugString() << "}"; - absl::Status s = stub_->TransferToClient(&request, &response); - 
VLOG(1) << "done with request"; - - if (!s.ok()) { - return s; - } - VLOG(3) << "TransferToClientResponse: {" << response.DebugString() << "}"; - - if (!response.has_literal()) { - return FailedPrecondition( - "server provided response without a literal in " - "TransferToClient request"); - } - return Literal::CreateFromProto(*response.mutable_literal()); + return stub_->TransferToClient(data, shape_with_layout); } absl::StatusOr> Client::TransferToServer( const LiteralSlice& literal, const DeviceHandle* device_handle) { - TransferToServerRequest request; - *request.mutable_literal() = literal.ToProto(); - if (device_handle) { - *request.mutable_device_handle() = *device_handle; - } - TransferToServerResponse response; - - VLOG(1) << "making transfer to server request"; - VLOG(3) << "TransferToServerRequest: {" << request.DebugString() << "}"; - absl::Status s = stub_->TransferToServer(&request, &response); - VLOG(1) << "done with request"; - - if (!s.ok()) { - return s; - } - VLOG(3) << "TransferToServerResponse: {" << response.DebugString() << "}"; - - if (!response.has_data()) { - return FailedPrecondition( - "server provided response without a data handle in " - "TransferToServer request"); - } - - return std::make_unique(stub_, response.data()); + return stub_->TransferToServer(literal, device_handle); } absl::Status Client::TransferToInfeed(const LiteralSlice& literal, int64_t replica_id, const DeviceHandle* device_handle) { - TransferToInfeedRequest request; - *request.mutable_literal() = literal.ToProto(); - if (device_handle) { - *request.mutable_device_handle() = *device_handle; - } - request.set_replica_id(replica_id); - TransferToInfeedResponse response; - - VLOG(1) << "making transfer to infeed request"; - VLOG(3) << "TransferToInfeedRequest: {" << request.DebugString() << "}"; - absl::Status s = stub_->TransferToInfeed(&request, &response); - VLOG(1) << "done with request"; - - if (!s.ok()) { - return s; - } - VLOG(3) << "TransferToInfeedResponse: {" 
<< response.DebugString() << "}"; - return absl::OkStatus(); + return stub_->TransferToInfeed(literal, replica_id, device_handle); } absl::StatusOr Client::TransferFromOutfeed( const Shape* shape_with_layout, int64_t replica_id, const DeviceHandle* device_handle) { - TransferFromOutfeedRequest request; - if (device_handle) { - *request.mutable_device_handle() = *device_handle; - } - request.set_replica_id(replica_id); - if (shape_with_layout != nullptr) { - *request.mutable_shape_with_layout() = shape_with_layout->ToProto(); - } - TransferFromOutfeedResponse response; - - VLOG(1) << "making transfer from outfeed request"; - VLOG(3) << "TransferFromOutfeedRequest: {" << request.DebugString() << "}"; - absl::Status s = stub_->TransferFromOutfeed(&request, &response); - VLOG(1) << "done with request"; - - if (!s.ok()) { - return s; - } - VLOG(3) << "TransferFromOutfeedResponse: {" << response.DebugString() << "}"; - - if (!response.has_literal()) { - return FailedPrecondition( - "server provided response without a literal in " - "TransferToClient request"); - } - - return Literal::CreateFromProto(response.literal()); + return stub_->TransferFromOutfeed(shape_with_layout, replica_id, + device_handle); } -absl::Status Client::ResetDevice() { - ResetDeviceRequest request; - ResetDeviceResponse response; - - VLOG(1) << "making reset device request"; - VLOG(3) << "ResetDeviceRequest: {" << request.DebugString() << "}"; - absl::Status s = stub_->ResetDevice(&request, &response); - VLOG(1) << "done with request"; - - if (!s.ok()) { - return s; - } - VLOG(3) << "ResetDeviceResponse: {" << response.DebugString() << "}"; - return absl::OkStatus(); -} +absl::Status Client::ResetDevice() { return stub_->ResetDevice(); } absl::StatusOr Client::ExecuteAndTransfer( const XlaComputation& computation, absl::Span arguments, @@ -185,30 +84,7 @@ absl::StatusOr Client::ExecuteAndTransfer( absl::StatusOr Client::ComputeConstant( const XlaComputation& computation, const Layout* 
output_layout) const { - ComputeConstantGraphRequest request; - *request.mutable_computation() = computation.proto(); - if (output_layout != nullptr) { - *request.mutable_output_layout() = output_layout->ToProto(); - } - - ComputeConstantResponse response; - - VLOG(2) << "making compute-constant-graph request"; - absl::Status s = stub_->ComputeConstantGraph(&request, &response); - VLOG(2) << "done with request"; - - if (!s.ok()) { - return s; - } - - VLOG(3) << "ComputeConstant: {" << response.DebugString() << "}"; - - if (!response.has_literal()) { - return Internal( - "no computed literal in the provided response in ComputeConstantGraph " - "request"); - } - return Literal::CreateFromProto(response.literal()); + return stub_->ComputeConstantGraph(computation, output_layout); } absl::StatusOr Client::LoadSnapshot(const HloSnapshot& module) { @@ -219,61 +95,19 @@ absl::StatusOr Client::LoadSnapshot(const HloSnapshot& module) { absl::StatusOr Client::Compile( const XlaComputation& computation, absl::Span argument_shapes, const ExecutionOptions* execution_options) { - CompileRequest request; - *request.mutable_computation() = computation.proto(); - - if (execution_options == nullptr) { - *request.mutable_execution_options() = CreateDefaultExecutionOptions(); - } else { - *request.mutable_execution_options() = *execution_options; - } - if (request.execution_options().device_handles_size() > 1) { - return InvalidArgument( - "Compiling with multiple device handles is not supported. Use " - "'Execute' instead."); - } - - // The argument shapes affect how the computation is compiled. 
- for (const auto& arg_shape : argument_shapes) { - *request.add_input_shape_with_layout() = arg_shape.ToProto(); + std::optional opts; + if (!execution_options) { + opts = CreateDefaultExecutionOptions(); } - CompileResponse response; - VLOG(1) << "making compile request: " << request.ShortDebugString(); - absl::Status s = stub_->Compile(&request, &response); - VLOG(1) << "done with request"; - - if (!s.ok()) { - return s; - } - TF_RET_CHECK(response.has_handle()); - return response.handle(); + return stub_->Compile(computation, argument_shapes, + execution_options ? *execution_options : *opts); } absl::StatusOr> Client::Execute( const ExecutionHandle& handle, absl::Span arguments, ExecutionProfile* execution_profile) { - ExecuteRequest request; - *request.mutable_handle() = handle; - for (GlobalData* argument : arguments) { - CHECK(argument != nullptr) << "Argument pointers must not be null."; - *request.add_arguments() = argument->handle(); - } - - ExecuteResponse response; - VLOG(1) << "making execute request: " << request.ShortDebugString(); - absl::Status s = stub_->Execute(&request, &response); - VLOG(1) << "done with request"; - - if (!s.ok()) { - return s; - } - - if (execution_profile != nullptr) { - *execution_profile = response.profile(); - } - - return std::make_unique(stub_, response.output()); + return stub_->Execute(handle, arguments, execution_profile); } absl::StatusOr> Client::Execute( @@ -329,39 +163,7 @@ absl::StatusOr> Client::Execute( absl::StatusOr>> Client::ExecuteParallel(absl::Span computations) { - ExecuteGraphParallelRequest request; - - for (const XlaComputationInstance& computation : computations) { - ExecuteGraphRequest single_request; - *single_request.mutable_computation() = computation.computation.proto(); - for (GlobalData* argument : computation.arguments) { - *single_request.add_arguments() = argument->handle(); - } - *single_request.mutable_execution_options() = computation.execution_options; - *request.add_requests() = 
single_request; - } - - ExecuteParallelResponse response; - VLOG(1) << "making execute-graph-parallel request: " - << request.ShortDebugString(); - absl::Status s = stub_->ExecuteGraphParallel(&request, &response); - VLOG(1) << "done with request"; - - if (!s.ok()) { - return s; - } - - std::vector> outputs; - for (size_t i = 0, end = response.responses_size(); i < end; ++i) { - outputs.push_back( - std::make_unique(stub_, response.responses(i).output())); - if (i < computations.size() && - computations[i].execution_profile != nullptr) { - *computations[i].execution_profile = response.responses(i).profile(); - } - } - - return std::move(outputs); + return stub_->ExecuteGraphParallel(computations); } absl::StatusOr> Client::GetDeviceHandles( @@ -369,59 +171,17 @@ absl::StatusOr> Client::GetDeviceHandles( if (device_count < 1) { return InvalidArgument("device_count must be greater than 0"); } - GetDeviceHandlesRequest request; - request.set_device_count(device_count); - GetDeviceHandlesResponse response; - VLOG(1) << "making get device request: " << request.ShortDebugString(); - absl::Status s = stub_->GetDeviceHandles(&request, &response); - VLOG(1) << "done with request"; - - if (!s.ok()) { - return s; - } - - std::vector device_handles; - const auto& response_device_handles = response.device_handles(); - device_handles.reserve(response_device_handles.size()); - for (const DeviceHandle& device_handle : response_device_handles) { - device_handles.push_back(device_handle); - } - - return device_handles; + return stub_->GetDeviceHandles(device_count); } absl::Status Client::Unregister(const GlobalData& data) { - UnregisterRequest request; - *request.add_data() = data.handle(); - UnregisterResponse response; - - VLOG(1) << "making unregister request"; - absl::Status s = stub_->Unregister(&request, &response); - VLOG(1) << "done with request"; - - return s; + return stub_->Unregister(data.handle()); } absl::StatusOr>> Client::DeconstructTuple(const GlobalData& data) { - 
DeconstructTupleRequest request; - *request.mutable_tuple_handle() = data.handle(); - DeconstructTupleResponse response; - - VLOG(1) << "making DestructTuple request"; - absl::Status s = stub_->DeconstructTuple(&request, &response); - VLOG(1) << "done with request"; - - if (!s.ok()) { - return s; - } - - std::vector> handles; - for (auto& handle : response.element_handles()) { - handles.push_back(std::make_unique(stub_, handle)); - } - return std::move(handles); + return stub_->DeconstructTuple(data); } absl::StatusOr> Client::GetComputationShape( @@ -431,36 +191,12 @@ absl::StatusOr> Client::GetComputationShape( } absl::StatusOr Client::GetShape(const GlobalData& data) { - GetShapeRequest request; - *request.mutable_data() = data.handle(); - GetShapeResponse response; - - VLOG(1) << "making get shape request"; - absl::Status s = stub_->GetShape(&request, &response); - VLOG(1) << "done with request"; - - if (!s.ok()) { - return s; - } - - return Shape(response.shape()); + return stub_->GetShape(data); } absl::StatusOr Client::CreateChannelHandleByType( ChannelHandle::ChannelType type) { - CreateChannelHandleRequest request; - request.set_channel_type(type); - CreateChannelHandleResponse response; - - VLOG(1) << "making create channel handle request"; - absl::Status s = stub_->CreateChannelHandle(&request, &response); - VLOG(1) << "done with request"; - - if (!s.ok()) { - return s; - } - - return response.channel(); + return stub_->CreateChannelHandle(type); } absl::StatusOr Client::CreateChannelHandle() { diff --git a/third_party/xla/xla/client/client.h b/third_party/xla/xla/client/client.h index 66d82ec79f744e..1ecfcfe6f358eb 100644 --- a/third_party/xla/xla/client/client.h +++ b/third_party/xla/xla/client/client.h @@ -22,7 +22,6 @@ limitations under the License. 
#include #include "absl/types/span.h" -#include "xla/client/global_data.h" #include "xla/client/xla_computation.h" #include "xla/literal.h" #include "xla/service/hlo.pb.h" @@ -41,6 +40,8 @@ class Client { explicit Client(Service* stub); virtual ~Client(); + using XlaComputationInstance = xla::XlaComputationInstance; + // Compile the computation with the given argument shapes and returns the // handle to the compiled executable. The compiled executable is cached on the // service, and the returned handle can be used for execution without @@ -70,7 +71,9 @@ class Client { // will be filled with profile data from the execution. absl::StatusOr> Execute( const ExecutionHandle& handle, absl::Span arguments, - ExecutionProfile* execution_profile = nullptr); + ExecutionProfile* execution_profile = nullptr + + ); // Executes the computation with the given arguments and returns the global // data that was produced from the execution. @@ -93,26 +96,6 @@ class Client { const ExecutionOptions* execution_options = nullptr, ExecutionProfile* execution_profile = nullptr); - // A struct to represent a computation instance to be executed. - // * If execution_options.device_handles is not empty, the computation is - // executed on the devices associated with the handles by partitioning the - // computation based on the attached sharding attributes. Otherwise, a - // device is chosen by the service. - struct XlaComputationInstance { - const XlaComputation& computation; - std::vector arguments; - ExecutionOptions execution_options; - ExecutionProfile* execution_profile; - - XlaComputationInstance(const XlaComputation& computation, - std::vector arguments, - ExecutionOptions execution_options, - ExecutionProfile* execution_profile) - : computation(computation), - arguments(std::move(arguments)), - execution_options(execution_options), - execution_profile(execution_profile) {} - }; // Executes a list XlaComputationInstances and returns global data produced // from each computation.
diff --git a/third_party/xla/xla/client/global_data.cc b/third_party/xla/xla/client/global_data.cc deleted file mode 100644 index 66a5c5fee61673..00000000000000 --- a/third_party/xla/xla/client/global_data.cc +++ /dev/null @@ -1,73 +0,0 @@ -/* Copyright 2017 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/client/global_data.h" - -#include -#include -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "xla/types.h" -#include "tsl/platform/logging.h" - -namespace xla { -namespace { - -// Releases a set of global data handles owned by the parent service -// interface. 
-void ReleaseHandles(Service* parent, - const absl::Span handles) { - UnregisterRequest request; - for (auto& handle : handles) { - VLOG(1) << "Requesting to unregister " << handle.ShortDebugString(); - *request.add_data() = handle; - } - UnregisterResponse response; - absl::Status status = parent->Unregister(&request, &response); - VLOG(1) << "Done with request"; - if (!status.ok()) { - LOG(WARNING) << "Failed to unregister handles: " << status - << "; continuing anyway..."; - } -} - -} // namespace - -GlobalData::GlobalData(Service* parent, GlobalDataHandle handle) - : handle_(std::move(handle)), parent_(parent) {} - -GlobalData::~GlobalData() { - if (parent_ != nullptr) { - ReleaseHandles(parent_, {handle_}); - } -} - -/* static */ void GlobalData::Release( - std::vector> instances) { - absl::flat_hash_map> - parent_handles_map; - for (auto& instance : instances) { - if (instance->parent_ != nullptr) { - parent_handles_map[instance->parent_].push_back(instance->Release()); - } - } - for (auto& parent_handles : parent_handles_map) { - ReleaseHandles(parent_handles.first, parent_handles.second); - } -} - -} // namespace xla diff --git a/third_party/xla/xla/client/global_data.h b/third_party/xla/xla/client/global_data.h index 790a92a97c26bc..d47209edee377b 100644 --- a/third_party/xla/xla/client/global_data.h +++ b/third_party/xla/xla/client/global_data.h @@ -26,37 +26,8 @@ limitations under the License. namespace xla { -// A GlobalData object represents a globally-accessible allocation of -// data in the associated XLA service. -class GlobalData { - public: - // Gives ownership of the global data handle to this object. - GlobalData(Service* parent, GlobalDataHandle handle); - - // Unregisters the wrapped handle, which causes the service to - // deallocate the associated data. - ~GlobalData(); - - const GlobalDataHandle& handle() const { return handle_; } - - // Releases a set of GlobalData handles. 
A single RPC will be issued - // per unique Service of the given GlobalData objects. - static void Release(std::vector> instances); - - private: - // Detaches the global data handle from the object, such that the destructor - // will not try to release it. - GlobalDataHandle Release() { - parent_ = nullptr; - return handle_; - } - - GlobalDataHandle handle_; // Handle being wrapped. - Service* parent_; // Service used to unregister handle_. - - GlobalData(const GlobalData&) = delete; - GlobalData& operator=(const GlobalData&) = delete; -}; +// TODO(cheshire): Remove. +// Deprecated target for backwards compatibility. } // namespace xla diff --git a/third_party/xla/xla/client/local_client.cc b/third_party/xla/xla/client/local_client.cc index b45cb62eaad3cb..e00f39143bb6ee 100644 --- a/third_party/xla/xla/client/local_client.cc +++ b/third_party/xla/xla/client/local_client.cc @@ -493,7 +493,7 @@ absl::StatusOr LocalClient::ReplicaNumberToDeviceOrdinal( return local_service_->ReplicaNumberToDeviceOrdinal(replica_number); } -absl::StatusOr LocalClient::TransferToLocalServer( +absl::StatusOr LocalClient::TransferToLocalServer( const ::xla::BorrowingLiteral& literal, int device_ordinal) { const ::xla::Shape& shape = literal.shape(); @@ -506,14 +506,13 @@ absl::StatusOr LocalClient::TransferToLocalServer( stream.get(), literal, shaped_buffer)); std::vector<::xla::ScopedShapedBuffer> replicated_buffer; replicated_buffer.emplace_back(std::move(shaped_buffer)); - ::xla::TransferToServerResponse result; - TF_ASSIGN_OR_RETURN(*result.mutable_data(), + TF_ASSIGN_OR_RETURN(GlobalDataHandle data, local_service_->RegisterReplicatedBuffers( std::move(replicated_buffer), absl::StrCat("TransferToServer literal of shape ", ::xla::ShapeUtil::HumanString(shape)))); - return result; + return data; } } // namespace xla diff --git a/third_party/xla/xla/client/local_client.h b/third_party/xla/xla/client/local_client.h index 0d7ee75b07beab..f26c67ced132c8 100644 --- 
a/third_party/xla/xla/client/local_client.h +++ b/third_party/xla/xla/client/local_client.h @@ -171,7 +171,7 @@ class LocalClient : public Client { se::DeviceMemoryAllocator* allocator = nullptr); // Transfer the BorrowingLiteral to the device with the given ordinal. - absl::StatusOr TransferToLocalServer( + absl::StatusOr TransferToLocalServer( const ::xla::BorrowingLiteral& literal, int device_ordinal); // Copy the data from the device contained in the given ShapedBuffer and diff --git a/third_party/xla/xla/service/BUILD b/third_party/xla/xla/service/BUILD index 99bb153d1d32c1..3d1f6e42012fa7 100644 --- a/third_party/xla/xla/service/BUILD +++ b/third_party/xla/xla/service/BUILD @@ -1238,6 +1238,7 @@ cc_library( "//xla:debug_options_flags", "//xla:executable_run_options", "//xla:execution_options_util", + "//xla:literal", "//xla:shape_layout", "//xla:shape_util", "//xla:status_macros", @@ -1245,11 +1246,13 @@ cc_library( "//xla:util", "//xla:xla_data_proto_cc", "//xla:xla_proto_cc", + "//xla/client:xla_computation", "//xla/hlo/evaluator:hlo_evaluator", "//xla/hlo/ir:hlo", "//xla/hlo/ir:hlo_module_group", "//xla/stream_executor", "//xla/stream_executor:device_memory_allocator", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", diff --git a/third_party/xla/xla/service/compile_only_service.h b/third_party/xla/xla/service/compile_only_service.h index 27f4b4c97626f1..0238a16f282946 100644 --- a/third_party/xla/xla/service/compile_only_service.h +++ b/third_party/xla/xla/service/compile_only_service.h @@ -53,30 +53,23 @@ class CompileOnlyService : public Service { const AotCompilationOptions& options, std::unique_ptr* metadata); - absl::Status GetDeviceHandles(const GetDeviceHandlesRequest* arg, - GetDeviceHandlesResponse* result) override { + absl::StatusOr> GetDeviceHandles( + int64_t device_count) override { return 
Unimplemented("CompileOnlyService does not support devices."); } - absl::Status TransferToServer(const TransferToServerRequest* arg, - TransferToServerResponse* result) override { - return Unimplemented( - "CompileOnlyService does not support device data transfers."); - } - absl::Status TransferToInfeed(const TransferToInfeedRequest* arg, - TransferToInfeedResponse* result) override { + + absl::StatusOr> TransferToServer( + const LiteralSlice& literal_slice, + const DeviceHandle* device_handle) override { return Unimplemented( "CompileOnlyService does not support device data transfers."); } - absl::Status TransferFromOutfeed( - const TransferFromOutfeedRequest* arg, - TransferFromOutfeedResponse* result) override { + + absl::Status TransferToInfeed(const LiteralSlice& literal, int64_t replica_id, + const DeviceHandle* device_handle) override { return Unimplemented( "CompileOnlyService does not support device data transfers."); } - absl::Status ResetDevice(const ResetDeviceRequest* arg, - ResetDeviceResponse* result) override { - return Unimplemented("CompileOnlyService does not support devices."); - } private: explicit CompileOnlyService(const ServiceOptions& options, diff --git a/third_party/xla/xla/service/service.cc b/third_party/xla/xla/service/service.cc index 207eae0ab440cc..4a1f323850d9e7 100644 --- a/third_party/xla/xla/service/service.cc +++ b/third_party/xla/xla/service/service.cc @@ -23,6 +23,7 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "xla/debug_options_flags.h" @@ -33,6 +34,7 @@ limitations under the License. 
#include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_module_group.h" #include "xla/layout_util.h" +#include "xla/literal.h" #include "xla/service/backend.h" #include "xla/service/compiler.h" #include "xla/service/computation_layout.h" @@ -160,36 +162,26 @@ Service::Service(const ServiceOptions& options, } } -absl::Status Service::CreateChannelHandle(const CreateChannelHandleRequest* arg, - CreateChannelHandleResponse* result) { - TF_ASSIGN_OR_RETURN(*result->mutable_channel(), - channel_tracker_.NewChannel(arg->channel_type())); - return absl::OkStatus(); +absl::StatusOr Service::CreateChannelHandle( + ChannelHandle::ChannelType type) { + return channel_tracker_.NewChannel(type); } -absl::Status Service::Unregister(const UnregisterRequest* arg, - UnregisterResponse* result) { - absl::Status status; - for (auto& data : arg->data()) { - absl::Status unregister_status = allocation_tracker_.Unregister(data); - if (!unregister_status.ok() && status.ok()) { - status = unregister_status; - } - } - return status; +absl::Status Service::Unregister(const GlobalDataHandle& data) { + return allocation_tracker_.Unregister(data); } // Deconstructs a previously-allocated global handle. 
-absl::Status Service::DeconstructTuple(const DeconstructTupleRequest* arg, - DeconstructTupleResponse* result) { - TF_ASSIGN_OR_RETURN( - std::vector elements, - allocation_tracker_.DeconstructTuple(arg->tuple_handle())); - - for (auto& element : elements) { - *result->add_element_handles() = element; +absl::StatusOr>> +Service::DeconstructTuple(const GlobalData& data) { + TF_ASSIGN_OR_RETURN(std::vector elements, + allocation_tracker_.DeconstructTuple(data.handle())); + std::vector> out; + out.reserve(elements.size()); + for (GlobalDataHandle& element : elements) { + out.push_back(std::make_unique(this, element)); } - return absl::OkStatus(); + return out; } absl::Status Service::ValidateResultShape(const Shape& client_shape, @@ -207,20 +199,14 @@ absl::Status Service::ValidateResultShape(const Shape& client_shape, absl::StatusOr>> Service::ResolveAndValidateArguments( - absl::Span arguments, + absl::Span arguments, absl::Span stream_executors) const { CHECK_EQ(options_.number_of_replicas(), stream_executors.size()); std::vector> replicated_arguments; replicated_arguments.resize(options_.number_of_replicas()); for (size_t i = 0; i < arguments.size(); ++i) { - auto buffer_status = allocation_tracker_.Resolve(*arguments[i]); - if (!buffer_status.ok()) { - return tsl::errors::CreateWithUpdatedMessage( - buffer_status.status(), - StrCat(buffer_status.status().message(), ", ", - "failed to resolve allocation for parameter ", i)); - } - auto replicated_buffers = buffer_status.value(); + TF_ASSIGN_OR_RETURN(std::vector replicated_buffers, + allocation_tracker_.Resolve(arguments[i]->handle())); CHECK_EQ(options_.number_of_replicas(), replicated_buffers.size()); for (int replica = 0; replica < options_.number_of_replicas(); ++replica) { const ShapedBuffer* shaped_buffer = replicated_buffers[replica]; @@ -515,9 +501,8 @@ absl::StatusOr> Service::GetExecutors( } absl::StatusOr>> -Service::GetArguments( - const ExecutionOptions& execution_options, - absl::Span arguments) 
const { +Service::GetArguments(const ExecutionOptions& execution_options, + absl::Span arguments) const { // Resolve the allocations for the arguments of the computation, and create // a vector of device memory offsets for the arguments from the allocations. // In the case of partitioned computations, assume all arguments go on the @@ -531,8 +516,9 @@ Service::GetArguments( return replicated_arguments; } -absl::Status Service::ExecuteGraphParallel( - const ExecuteGraphParallelRequest* arg, ExecuteParallelResponse* result) { +absl::StatusOr>> +Service::ExecuteGraphParallel( + absl::Span computations) { VLOG(1) << "running execute-graph-parallel request"; std::vector>> all_arguments; @@ -543,10 +529,11 @@ absl::Status Service::ExecuteGraphParallel( std::vector device_handles; int num_requested_devices = - std::accumulate(arg->requests().begin(), arg->requests().end(), 0, - [](int a, const ExecuteGraphRequest& r) -> int { - return a + r.execution_options().device_handles_size(); + std::accumulate(computations.begin(), computations.end(), 0, + [](int a, const XlaComputationInstance& r) -> int { + return a + r.execution_options.device_handles_size(); }); + if (num_requested_devices * options_.number_of_replicas() > execute_backend_->device_count()) { return FailedPrecondition( @@ -554,23 +541,24 @@ absl::Status Service::ExecuteGraphParallel( num_requested_devices); } - for (int64_t i = 0; i < arg->requests_size(); ++i) { + for (int64_t i = 0; i < computations.size(); ++i) { + const XlaComputationInstance& computation = computations[i]; + // Get the stream executor for the i'th computation. This stream executor // is one of the executors to run the replicated computation. 
- const ExecutionOptions& execution_options = - arg->requests(i).execution_options(); - const ExecuteGraphRequest& request = arg->requests(i); - TF_RET_CHECK(request.has_computation()) << "computations may not be empty"; - TF_RET_CHECK(request.computation().has_host_program_shape()) + const ExecutionOptions& execution_options = computation.execution_options; + TF_RET_CHECK(computation.computation.proto().has_host_program_shape()) << "program shape may not be empty"; // Get the executors. - TF_ASSIGN_OR_RETURN(auto executors, GetExecutors(execution_options, - arg->requests_size(), i)); + TF_ASSIGN_OR_RETURN( + std::vector executors, + GetExecutors(execution_options, computations.size(), i)); // Get the replicated arguments. - TF_ASSIGN_OR_RETURN(auto replicated_arguments, - GetArguments(execution_options, request.arguments())); + TF_ASSIGN_OR_RETURN( + std::vector> replicated_arguments, + GetArguments(execution_options, computation.arguments)); for (auto& args : replicated_arguments) { for (auto& arg : args) { @@ -596,8 +584,8 @@ absl::Status Service::ExecuteGraphParallel( TF_ASSIGN_OR_RETURN( std::unique_ptr module_config, CreateModuleConfig( - ProgramShape{request.computation().host_program_shape()}, - replicated_arguments.front(), request.execution_options())); + ProgramShape{computation.computation.proto().host_program_shape()}, + replicated_arguments.front(), computation.execution_options)); VLOG(3) << "ExecuteGraphParallel created HloModuleConfig computation layout: " << module_config->entry_computation_layout().ToString(); @@ -605,10 +593,10 @@ absl::Status Service::ExecuteGraphParallel( // Adds to the vectors to build and execute the computations after the loop. 
all_arguments.push_back(replicated_arguments); all_arguments.insert(all_arguments.end(), executors.size() - 1, {{}}); - module_protos.push_back(&request.computation()); + module_protos.push_back(&computation.computation.proto()); module_configs.push_back(std::move(module_config)); computation_names.insert(computation_names.end(), executors.size(), - request.computation().name()); + computation.computation.name()); all_executors.push_back(executors); device_handles.insert(device_handles.end(), execution_options.device_handles().begin(), @@ -680,6 +668,12 @@ absl::Status Service::ExecuteGraphParallel( } } + for (int64_t i = 0; i < computations.size(); ++i) { + if (computations[i].execution_profile != nullptr) { + *computations[i].execution_profile = profile; + } + } + if (!execution_status.ok()) { // Execution failed so we don't have the results. Dump the HLO snapshot // with just the program arguments. @@ -690,11 +684,11 @@ absl::Status Service::ExecuteGraphParallel( TF_RETURN_IF_ERROR(execution_status); - for (const GlobalDataHandle& output : outputs) { - ExecuteResponse response; - *response.mutable_output() = output; - *response.mutable_profile() = profile; - *result->add_responses() = response; + std::vector> out; + + out.reserve(outputs.size()); + for (GlobalDataHandle& output : outputs) { + out.push_back(std::make_unique(this, output)); } for (int i = 0, end = executable_ptrs.size(); i < end; i++) { @@ -712,31 +706,32 @@ absl::Status Service::ExecuteGraphParallel( } VLOG(1) << "successfully completed 'execute-graph-parallel' request"; - return absl::OkStatus(); + return out; } -absl::Status Service::GetDeviceHandles(const GetDeviceHandlesRequest* arg, - GetDeviceHandlesResponse* result) { +absl::StatusOr> Service::GetDeviceHandles( + int64_t device_count) { const int64_t available_device_count = execute_backend_->device_count(); const int64_t replica_count = options_.number_of_replicas(); if (replica_count <= 0) { return FailedPrecondition("Replica count must be a
positive integer"); } - if (available_device_count < arg->device_count() * replica_count) { + if (available_device_count < device_count * replica_count) { return ResourceExhausted( "Requested logical device count (%d) with replica count (%d) exceeds " "the number of available physical devices on the target (%d)", - arg->device_count(), replica_count, available_device_count); + device_count, replica_count, available_device_count); } - for (int64_t i = 0; i < arg->device_count(); ++i) { + std::vector out; + + for (int64_t i = 0; i < device_count; ++i) { DeviceHandle device_handle; device_handle.set_handle(i); - device_handle.set_device_count(arg->device_count()); - *result->add_device_handles() = device_handle; + device_handle.set_device_count(device_count); + out.push_back(device_handle); } - - return absl::OkStatus(); + return out; } absl::StatusOr> Service::BuildExecutable( @@ -795,71 +790,66 @@ absl::StatusOr> Service::BuildExecutable( return executable; } -absl::Status Service::Compile(const CompileRequest* arg, - CompileResponse* result) { +absl::StatusOr Service::Compile( + const XlaComputation& computation, absl::Span argument_shapes, + const ExecutionOptions& execution_options) { VLOG(1) << "running compile request"; - if (!arg->has_computation()) { - return InvalidArgument("computations may not be empty"); - } - if (!arg->computation().has_host_program_shape()) { + + if (!computation.proto().has_host_program_shape()) { return InvalidArgument("program shape may not be empty"); } - if (arg->execution_options().device_handles_size() > 1) { + if (execution_options.device_handles_size() > 1) { return InvalidArgument( "The compile request does not support multiple device handles."); } - std::vector argument_shapes; - argument_shapes.reserve(arg->input_shape_with_layout_size()); std::vector argument_shape_ptrs; - for (const ShapeProto& shape_proto : arg->input_shape_with_layout()) { - argument_shapes.push_back(Shape(shape_proto)); - 
argument_shape_ptrs.push_back(&argument_shapes.back()); + for (const Shape& shape : argument_shapes) { + argument_shape_ptrs.push_back(&shape); } + TF_ASSIGN_OR_RETURN( std::unique_ptr module_config, - CreateModuleConfig(ProgramShape{arg->computation().host_program_shape()}, - argument_shape_ptrs, &arg->execution_options())); + CreateModuleConfig(ProgramShape{computation.proto().host_program_shape()}, + argument_shape_ptrs, &execution_options)); VLOG(3) << "Compile created HloModuleConfig computation layout: " << module_config->entry_computation_layout().ToString(); TF_ASSIGN_OR_RETURN( std::unique_ptr executable, - BuildExecutable(arg->computation(), std::move(module_config), + BuildExecutable(computation.proto(), std::move(module_config), execute_backend_.get(), execute_backend_->default_stream_executor(), {/*device_allocator=*/nullptr})); - *result->mutable_handle() = compilation_cache_.Insert(std::move(executable)); - VLOG(1) << "successfully completed 'compile' request"; - return absl::OkStatus(); + return compilation_cache_.Insert(std::move(executable)); } -absl::Status Service::Execute(const ExecuteRequest* arg, - ExecuteResponse* result) { +absl::StatusOr> Service::Execute( + const ExecutionHandle& handle, absl::Span arguments, + ExecutionProfile* execution_profile) { VLOG(1) << "running execute request"; - if (!arg->has_handle()) { - return InvalidArgument("execution handle should not be empty"); - } - TF_ASSIGN_OR_RETURN(auto executable, - compilation_cache_.LookUp(arg->handle())); - TF_ASSIGN_OR_RETURN(auto replicas, Replicas(*execute_backend_, - SingleComputationDeviceHandle())); + TF_ASSIGN_OR_RETURN(std::shared_ptr executable, + compilation_cache_.LookUp(handle)); + + TF_ASSIGN_OR_RETURN( + std::vector replicas, + Replicas(*execute_backend_, SingleComputationDeviceHandle())); TF_ASSIGN_OR_RETURN( std::vector> replicated_arguments, - ResolveAndValidateArguments(arg->arguments(), replicas)); + ResolveAndValidateArguments(arguments, replicas)); // Check 
that the replicated_arguments has the same shape and layout as the // module config used when creating the executable. const int64_t num_module_args = executable->module_config().entry_computation_layout().parameter_count(); - if (num_module_args != arg->arguments_size()) { + if (num_module_args != arguments.size()) { return InvalidArgument( "The executable expects %lld arguments, but sees %lld.", - num_module_args, arg->arguments_size()); + num_module_args, arguments.size()); } for (int64_t i = 0; i < num_module_args; i++) { const Shape& shape_module = @@ -887,35 +877,32 @@ absl::Status Service::Execute(const ExecuteRequest* arg, } TF_ASSIGN_OR_RETURN( - *result->mutable_output(), - ExecuteAndRegisterResult(executable.get(), replicated_arguments, - execute_backend_.get(), - SingleComputationDeviceHandle(), - "result of " + executable->module().name(), - result->mutable_profile())); + GlobalDataHandle output, + ExecuteAndRegisterResult( + executable.get(), replicated_arguments, execute_backend_.get(), + SingleComputationDeviceHandle(), + "result of " + executable->module().name(), execution_profile)); if (executable->dumping_snapshot()) { - TF_ASSIGN_OR_RETURN( - const ShapedBuffer* result_buffer, - allocation_tracker_.ResolveForReplica(result->output(), 0)); + TF_ASSIGN_OR_RETURN(const ShapedBuffer* result_buffer, + allocation_tracker_.ResolveForReplica(output, 0)); TF_RETURN_IF_ERROR(RecordResult(*result_buffer, stream.get(), execute_backend_->transfer_manager(), &snapshot)); DumpHloSnapshotIfEnabled(executable->module(), snapshot); } - VLOG(1) << "successfully completed 'execute' request"; - return absl::OkStatus(); + return std::make_unique(this, output); } -absl::Status Service::TransferToClient(const TransferToClientRequest* arg, - TransferToClientResponse* result) { +absl::StatusOr Service::TransferToClient( + const GlobalData& data, const Shape* shape_with_layout) { TF_ASSIGN_OR_RETURN(const ShapedBuffer* shaped_buffer, - 
allocation_tracker_.ResolveForReplica(arg->data(), 0)); + allocation_tracker_.ResolveForReplica(data.handle(), 0)); Shape return_shape; - if (arg->has_shape_with_layout()) { - return_shape = Shape(arg->shape_with_layout()); + if (shape_with_layout) { + return_shape = Shape(*shape_with_layout); if (!LayoutUtil::HasLayout(return_shape)) { return InvalidArgument("shape_with_layout must have layout if present."); } @@ -943,24 +930,17 @@ absl::Status Service::TransferToClient(const TransferToClientRequest* arg, stream.get(), *shaped_buffer)); if (LayoutUtil::LayoutsInShapesEqual(return_shape, result_literal.shape())) { - *result->mutable_literal() = result_literal.ToProto(); - } else { - *result->mutable_literal() = - result_literal.Relayout(return_shape).ToProto(); + return result_literal; } - return absl::OkStatus(); + return result_literal.Relayout(return_shape); } -absl::Status Service::TransferToServer(const TransferToServerRequest* arg, - TransferToServerResponse* result) { - TF_ASSIGN_OR_RETURN(Literal literal, - Literal::CreateFromProto(arg->literal())); - const Shape& shape = literal.shape(); - +absl::StatusOr> Service::TransferToServer( + const LiteralSlice& literal_slice, const DeviceHandle* device_handle) { + const Shape& shape = literal_slice.shape(); std::vector replicas; - if (arg->has_device_handle()) { - TF_ASSIGN_OR_RETURN(replicas, - Replicas(*execute_backend_, arg->device_handle())); + if (device_handle) { + TF_ASSIGN_OR_RETURN(replicas, Replicas(*execute_backend_, *device_handle)); } else { TF_ASSIGN_OR_RETURN( replicas, Replicas(*execute_backend_, SingleComputationDeviceHandle())); @@ -982,100 +962,93 @@ absl::Status Service::TransferToServer(const TransferToServerRequest* arg, TF_ASSIGN_OR_RETURN(auto stream, execute_backend_->BorrowStream(executor)); TF_RETURN_IF_ERROR( execute_backend_->transfer_manager()->TransferLiteralToDevice( - stream.get(), literal, shaped_buffer)); + stream.get(), literal_slice, shaped_buffer)); 
replicated_buffers.emplace_back(std::move(shaped_buffer)); } - TF_ASSIGN_OR_RETURN(*result->mutable_data(), + + TF_ASSIGN_OR_RETURN(GlobalDataHandle out, allocation_tracker_.RegisterReplicatedBuffers( std::move(replicated_buffers), StrCat("TransferToServer literal of shape ", ShapeUtil::HumanString(shape)))); - return absl::OkStatus(); + return std::make_unique(this, out); } -absl::Status Service::TransferToInfeed(const TransferToInfeedRequest* arg, - TransferToInfeedResponse* result) { +absl::Status Service::TransferToInfeed(const LiteralSlice& literal, + int64_t replica_id, + const DeviceHandle* device_handle) { const int64_t replica_count = options_.number_of_replicas(); - if (arg->replica_id() < 0 || arg->replica_id() >= replica_count) { + if (replica_id < 0 || replica_id >= replica_count) { return FailedPrecondition( "%s", - StrCat("The replica_id=", arg->replica_id(), + StrCat("The replica_id=", replica_id, " on TransferToInfeedRequest not in range [0, replica_count=", replica_count, ").")); } se::StreamExecutor* executor; - if (arg->has_device_handle()) { + if (device_handle) { TF_ASSIGN_OR_RETURN(auto replicas, - Replicas(*execute_backend_, arg->device_handle())); - executor = replicas[arg->replica_id()]; + Replicas(*execute_backend_, *device_handle)); + executor = replicas[replica_id]; } else { TF_ASSIGN_OR_RETURN( auto replicas, Replicas(*execute_backend_, SingleComputationDeviceHandle())); - executor = replicas[arg->replica_id()]; + executor = replicas[replica_id]; } - TF_ASSIGN_OR_RETURN(Literal literal, - Literal::CreateFromProto(arg->literal())); return execute_backend_->transfer_manager()->TransferLiteralToInfeed(executor, literal); } -absl::Status Service::TransferFromOutfeed(const TransferFromOutfeedRequest* arg, - TransferFromOutfeedResponse* result) { +absl::StatusOr Service::TransferFromOutfeed( + const Shape* shape_with_layout, int64_t replica_id, + const DeviceHandle* device_handle) { const int64_t replica_count = 
options_.number_of_replicas(); - if (arg->replica_id() < 0 || arg->replica_id() >= replica_count) { + if (replica_id < 0 || replica_id >= replica_count) { return FailedPrecondition( "The replica_id=%d on TransferFromOutfeedRequest not in range [0, %d)", - arg->replica_id(), replica_count); + replica_id, replica_count); } se::StreamExecutor* executor; - if (arg->has_device_handle()) { + if (device_handle) { TF_ASSIGN_OR_RETURN(auto replicas, - Replicas(*execute_backend_, arg->device_handle())); - executor = replicas[arg->replica_id()]; + Replicas(*execute_backend_, *device_handle)); + executor = replicas[replica_id]; } else { TF_ASSIGN_OR_RETURN( auto replicas, Replicas(*execute_backend_, SingleComputationDeviceHandle())); - executor = replicas[arg->replica_id()]; + executor = replicas[replica_id]; } - auto literal = Literal::CreateFromShape(Shape(arg->shape_with_layout())); + auto literal = Literal::CreateFromShape(*shape_with_layout); TF_RETURN_IF_ERROR( execute_backend_->transfer_manager()->TransferLiteralFromOutfeed( executor, &literal)); - *result->mutable_literal() = literal.ToProto(); - return absl::OkStatus(); + return literal; } -absl::Status Service::ResetDevice(const ResetDeviceRequest* arg, - ResetDeviceResponse* result) { - return execute_backend_->ResetDevices(); -} +absl::Status Service::ResetDevice() { return execute_backend_->ResetDevices(); } -absl::Status Service::ComputeConstantGraph( - const ComputeConstantGraphRequest* arg, ComputeConstantResponse* result) { - if (!arg->has_computation()) { - return InvalidArgument("computations may not be empty"); - } - if (!arg->computation().has_host_program_shape()) { +absl::StatusOr Service::ComputeConstantGraph( + const XlaComputation& computation, const Layout* output_layout) { + if (!computation.proto().has_host_program_shape()) { return InvalidArgument("program shape may not be empty"); } - if (arg->computation().host_program_shape().parameters_size() != 0) { + if 
(computation.proto().host_program_shape().parameters_size() != 0) { return InvalidArgument( "constant computation may not depend on any parameters."); } - ProgramShape program_shape(arg->computation().host_program_shape()); + ProgramShape program_shape(computation.proto().host_program_shape()); TF_DCHECK_OK(ShapeUtil::ValidateShape(program_shape.result())); - std::optional output_layout; - if (arg->has_output_layout()) { - output_layout = Layout::CreateFromProto(arg->output_layout()); + + if (output_layout) { TF_RETURN_IF_ERROR(LayoutUtil::ValidateLayoutForShape( *output_layout, program_shape.result())); } @@ -1083,7 +1056,7 @@ absl::Status Service::ComputeConstantGraph( HloModuleConfig config(program_shape); TF_ASSIGN_OR_RETURN(std::unique_ptr module, - CreateModuleFromProto(arg->computation(), config)); + CreateModuleFromProto(computation.proto(), config)); DynamicPadder dynamic_padder; TF_RETURN_IF_ERROR(dynamic_padder.Run(module.get()).status()); @@ -1112,20 +1085,16 @@ absl::Status Service::ComputeConstantGraph( // relayout here. // // TODO(b/77824332): Make HloEvaluator take care of the re-layout. 
- if (output_layout.has_value()) { + if (output_layout) { result_literal = result_literal.Relayout(*output_layout); } - *result->mutable_literal() = result_literal.ToProto(); - - return absl::OkStatus(); + return result_literal; } -absl::Status Service::GetShape(const GetShapeRequest* arg, - GetShapeResponse* result) { +absl::StatusOr Service::GetShape(const GlobalData& data) { TF_ASSIGN_OR_RETURN(const ShapedBuffer* buffer, - allocation_tracker_.ResolveForReplica(arg->data(), 0)); - *result->mutable_shape() = buffer->on_device_shape().ToProto(); - return absl::OkStatus(); + allocation_tracker_.ResolveForReplica(data.handle(), 0)); + return buffer->on_device_shape(); } DeviceHandle Service::SingleComputationDeviceHandle() const { @@ -1152,4 +1121,46 @@ absl::StatusOr> Service::Replicas( return replicas; } +namespace { + +// Releases a set of global data handles owned by the parent service +// interface. +void ReleaseHandles(Service* parent, + const absl::Span handles) { + for (const GlobalDataHandle& handle : handles) { + VLOG(1) << "Requesting to unregister " << handle.ShortDebugString(); + absl::Status status = parent->Unregister(handle); + if (!status.ok()) { + LOG(WARNING) << "Failed to unregister handles: " << status + << "; continuing anyway..."; + } + } + VLOG(1) << "Done with request"; +} + +} // namespace + +GlobalData::GlobalData(Service* parent, GlobalDataHandle handle) + : handle_(std::move(handle)), parent_(parent) {} + +GlobalData::~GlobalData() { + if (parent_ != nullptr) { + ReleaseHandles(parent_, {handle_}); + } +} + +/* static */ void GlobalData::Release( + std::vector> instances) { + absl::flat_hash_map> + parent_handles_map; + for (auto& instance : instances) { + if (instance->parent_ != nullptr) { + parent_handles_map[instance->parent_].push_back(instance->Release()); + } + } + for (auto& parent_handles : parent_handles_map) { + ReleaseHandles(parent_handles.first, parent_handles.second); + } +} + } // namespace xla diff --git 
a/third_party/xla/xla/service/service.h b/third_party/xla/xla/service/service.h index ff54b36ae04435..3fd7f227c362e8 100644 --- a/third_party/xla/xla/service/service.h +++ b/third_party/xla/xla/service/service.h @@ -24,6 +24,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/types/span.h" +#include "xla/client/xla_computation.h" #include "xla/debug_options_flags.h" #include "xla/executable_run_options.h" #include "xla/hlo/ir/hlo_module.h" @@ -44,6 +45,8 @@ limitations under the License. namespace xla { +class Service; + // Options to configure the service when it is created. class ServiceOptions { public: @@ -73,6 +76,59 @@ class ServiceOptions { std::optional> allowed_devices_; }; +// A GlobalData object represents a globally-accessible allocation of +// data in the associated XLA service. +class GlobalData { + public: + // Gives ownership of the global data handle to this object. + GlobalData(Service* parent, GlobalDataHandle handle); + + // Unregisters the wrapped handle, which causes the service to + // deallocate the associated data. + ~GlobalData(); + + const GlobalDataHandle& handle() const { return handle_; } + + // Releases a set of GlobalData handles. A single RPC will be issued + // per unique Service of the given GlobalData objects. + static void Release(std::vector> instances); + + private: + // Detaches the global data handle from the object, such that the destructor + // will not try to release it. + GlobalDataHandle Release() { + parent_ = nullptr; + return handle_; + } + + GlobalDataHandle handle_; // Handle being wrapped. + Service* parent_; // Service used to unregister handle_. + + GlobalData(const GlobalData&) = delete; + GlobalData& operator=(const GlobalData&) = delete; +}; + +// A struct to represent a computation instance to be executed. 
+// * If execution_options.device_handles is not empty, the computation is +// executed on the devices associated with the handles by partitioning the +// computation based on the attached sharding attributes. Otherwise, a +// device is chosen by the service. +struct XlaComputationInstance { + const XlaComputation& computation; + std::vector arguments; + ExecutionOptions execution_options; + ExecutionProfile* execution_profile; + + XlaComputationInstance(const XlaComputation& computation, + std::vector arguments, + ExecutionOptions execution_options, + ExecutionProfile* execution_profile) + : computation(computation), + arguments(std::move(arguments)), + execution_options(execution_options), + execution_profile(execution_profile) {} +}; + // The XLA service object, which is the same across all platforms. It maintains // the service state of computations and allocations, and delegates // target-specific requests to the target-specific infrastructure @@ -83,30 +139,32 @@ class Service { // // If the handle given is not currently allocated, a NOT_FOUND status is // returned. - virtual absl::Status Unregister(const UnregisterRequest* arg, - UnregisterResponse* result); + virtual absl::Status Unregister(const GlobalDataHandle& data); // Deconstructs a tuple. Returns a newly created GlobalDataHandle for each // element in the tuple. - virtual absl::Status DeconstructTuple(const DeconstructTupleRequest* arg, - DeconstructTupleResponse* result); + virtual absl::StatusOr>> + DeconstructTuple(const GlobalData& data); // Compiles a computation into an executable. The request contains the whole // computation graph. Returns the handle to the executable. - virtual absl::Status Compile(const CompileRequest* arg, - CompileResponse* result); + virtual absl::StatusOr Compile( + const XlaComputation& computation, + absl::Span argument_shapes, + const ExecutionOptions& execution_options); // Executes an executable with the provided global data passes as immutable // arguments. 
The request contains the handle to the executable. Returns // global data output and execution timing. - virtual absl::Status Execute(const ExecuteRequest* arg, - ExecuteResponse* result); + virtual absl::StatusOr> Execute( + const ExecutionHandle& handle, absl::Span arguments, + ExecutionProfile* execution_profile); // Executes one or more computations in parallel with the provided global data // passed as immutable arguments. Returns global data output for each // computation. - virtual absl::Status ExecuteGraphParallel( - const ExecuteGraphParallelRequest* arg, ExecuteParallelResponse* result); + absl::StatusOr>> ExecuteGraphParallel( + absl::Span computations); // Requests one or more device handles from the target. // @@ -116,27 +174,28 @@ class Service { // the first set of replicas, and the next R devices to the second set of // replicas, etc. Each returned device handle represents the device with the // replica id 0. - virtual absl::Status GetDeviceHandles(const GetDeviceHandlesRequest* arg, - GetDeviceHandlesResponse* result); + virtual absl::StatusOr> GetDeviceHandles( + int64_t device_count); // Requests that global data be transferred to the client in literal form. - virtual absl::Status TransferToClient(const TransferToClientRequest* arg, - TransferToClientResponse* result); + virtual absl::StatusOr TransferToClient( + const GlobalData& data, const Shape* shape_with_layout); // Transfers data from a literal provided by the client, into device memory. - virtual absl::Status TransferToServer(const TransferToServerRequest* arg, - TransferToServerResponse* result); + virtual absl::StatusOr> TransferToServer( + const LiteralSlice& literal_slice, const DeviceHandle* device_handle); // Transfers data from a literal provided by the client, into the Infeed // buffer of the device. 
- virtual absl::Status TransferToInfeed(const TransferToInfeedRequest* arg, - TransferToInfeedResponse* result); + virtual absl::Status TransferToInfeed(const LiteralSlice& literal, + int64_t replica_id, + const DeviceHandle* device_handle); // Transfers data from the Outfeed othe device to the literal provided by the // client. - virtual absl::Status TransferFromOutfeed( - const TransferFromOutfeedRequest* arg, - TransferFromOutfeedResponse* result); + virtual absl::StatusOr TransferFromOutfeed( + const Shape* shape_with_layout, int64_t replica_id, + const DeviceHandle* device_handle); // Resets devices, clearing all existing state on all the devices associated // with this service (including memory allocated on the devices). @@ -147,22 +206,19 @@ class Service { // ResetDevice should be called before an Execution that expect the device to // be in the reset state. For example, if the prior Execution modifies device // state (e.g., architectural state) that the next Execution depends on. - virtual absl::Status ResetDevice(const ResetDeviceRequest* arg, - ResetDeviceResponse* result); + virtual absl::Status ResetDevice(); - virtual absl::Status ComputeConstantGraph( - const ComputeConstantGraphRequest* arg, ComputeConstantResponse* result); + virtual absl::StatusOr ComputeConstantGraph( + const XlaComputation& computation, const Layout* output_layout); // Returns the shape (with layout) of an array associated with a given data // handle. - virtual absl::Status GetShape(const GetShapeRequest* arg, - GetShapeResponse* result); + virtual absl::StatusOr GetShape(const GlobalData& data); // Creates a unique channel handle that can be used for Send/Recv // instructions. - virtual absl::Status CreateChannelHandle( - const CreateChannelHandleRequest* arg, - CreateChannelHandleResponse* result); + virtual absl::StatusOr CreateChannelHandle( + ChannelHandle::ChannelType type); // Returns the backend used to execute computations. 
const Backend& backend() const { return *execute_backend_; } @@ -201,7 +257,7 @@ class Service { // Prepare the arguments for executing parallel. absl::StatusOr>> GetArguments( const ExecutionOptions& execution_options, - absl::Span arguments) const; + absl::Span arguments) const; protected: friend class LocalExecutable; @@ -217,7 +273,7 @@ class Service { // the corresponding replica. absl::StatusOr>> ResolveAndValidateArguments( - absl::Span arguments, + absl::Span arguments, absl::Span stream_executors) const; public: diff --git a/third_party/xla/xla/tests/BUILD b/third_party/xla/xla/tests/BUILD index 918992c6aeff99..69c591200c9025 100644 --- a/third_party/xla/xla/tests/BUILD +++ b/third_party/xla/xla/tests/BUILD @@ -1045,6 +1045,7 @@ xla_test( "//xla:status_macros", "//xla:test", "//xla/client:xla_builder", + "//xla/service", ], ) diff --git a/third_party/xla/xla/tests/client_test.cc b/third_party/xla/xla/tests/client_test.cc index 6d67d07eb969da..1adb92207748aa 100644 --- a/third_party/xla/xla/tests/client_test.cc +++ b/third_party/xla/xla/tests/client_test.cc @@ -128,14 +128,14 @@ XLA_TEST_F(ClientTest, // We can't really test parallel execution on CPU since all of the cores in a // CPU are presented as a single device. So for now we test "parallel" // execution on a single device. 
- std::vector computation_instances; + std::vector computation_instances; TF_ASSERT_OK_AND_ASSIGN(std::vector devices, client_->GetDeviceHandles(1)); ASSERT_EQ(devices.size(), 1); ExecutionOptions options = execution_options_; *options.add_device_handles() = devices[0]; - computation_instances.push_back(Client::XlaComputationInstance( + computation_instances.push_back(XlaComputationInstance( add_with_one_arg, {const_arg.get()}, options, nullptr)); TF_ASSERT_OK_AND_ASSIGN(auto results, diff --git a/third_party/xla/xla/tests/gather_operation_test.cc b/third_party/xla/xla/tests/gather_operation_test.cc index b80d1570a9b3b0..94d02abbea3b04 100644 --- a/third_party/xla/xla/tests/gather_operation_test.cc +++ b/third_party/xla/xla/tests/gather_operation_test.cc @@ -17,6 +17,7 @@ limitations under the License. #include "xla/client/xla_builder.h" #include "xla/execution_options_util.h" #include "xla/literal_util.h" +#include "xla/service/service.h" #include "xla/status_macros.h" #include "xla/test.h" #include "xla/tests/client_library_test_base.h" @@ -778,7 +779,7 @@ XLA_TEST_F(GatherClientLibraryTest, xla::ExecutionOptions execution_options = CreateDefaultExecutionOptions(); *execution_options.add_device_handles() = devices[0]; TF_ASSERT_OK_AND_ASSIGN(XlaComputation computation, builder.Build()); - std::vector computation_instances = { + std::vector computation_instances = { {computation, {operand_arg.get(), indices_arg.get()}, execution_options, diff --git a/third_party/xla/xla/xla.proto b/third_party/xla/xla/xla.proto index d9e4594d2b4708..16b1da1999d7c3 100644 --- a/third_party/xla/xla/xla.proto +++ b/third_party/xla/xla/xla.proto @@ -1037,189 +1037,6 @@ message HloModuleProtoWithConfig { HloModuleConfigProto config = 2; } -message GetDeviceHandlesRequest { - int64 device_count = 1; -} - -message GetDeviceHandlesResponse { - repeated DeviceHandle device_handles = 1; -} - -message TransferToClientRequest { - GlobalDataHandle data = 1; - - // This optional field directs 
the service to return the literal in this - // layout. A shape is used to hold the layout to accommodate tuples. - ShapeProto shape_with_layout = 2; -} - -message TransferToClientResponse { - LiteralProto literal = 1; -} - -message TransferToServerRequest { - LiteralProto literal = 1; - DeviceHandle device_handle = 2; -} - -message TransferToServerResponse { - GlobalDataHandle data = 1; -} - -message TransferToInfeedRequest { - LiteralProto literal = 1; - int64 replica_id = 2; - DeviceHandle device_handle = 3; -} - -message TransferToInfeedResponse {} - -message TransferFromOutfeedRequest { - // This optional field directs the service to return the literal in this - // layout. A shape is used to hold the layout to accommodate tuples. - ShapeProto shape_with_layout = 1; - - int64 replica_id = 2; - DeviceHandle device_handle = 3; -} - -message TransferFromOutfeedResponse { - LiteralProto literal = 1; -} - -message ResetDeviceRequest { - DeviceHandle device_handle = 1; -} - -message ResetDeviceResponse {} - -message CreateChannelHandleRequest { - ChannelHandle.ChannelType channel_type = 1; -} - -message CreateChannelHandleResponse { - ChannelHandle channel = 1; -} - -message UnregisterRequest { - repeated GlobalDataHandle data = 1; -} - -message UnregisterResponse {} - -message CompileRequest { - // The graph to be compiled. - HloModuleProto computation = 1; - - // Options that affect how XLA compiles code to service this request. - ExecutionOptions execution_options = 2; - - // The layouts of the input arguments. If not set, the default layout will be - // used. Although the real arguments are not needed in compilation, the - // layouts of the arguments can affect the compilation. - repeated ShapeProto input_shape_with_layout = 3; -} - -message CompileResponse { - // The handle to the executable. 
- ExecutionHandle handle = 1; -} - -message ExecuteRequest { - ExecutionHandle handle = 1; - - // The shape and layout of the arguments must be the same as the those of the - // executable's parameters. - repeated GlobalDataHandle arguments = 2; -} - -// TODO(b/118493728): Remove this and ExecuteGraphParallelRequest and replace -// the uses with calls to Compile and Execute. -message ExecuteGraphRequest { - HloModuleProto computation = 1; - repeated GlobalDataHandle arguments = 2; - - // Options that affect how XLA compiles and runs code to service this request. - ExecutionOptions execution_options = 3; -} - -message ExecuteGraphParallelRequest { - repeated ExecuteGraphRequest requests = 1; -} - -message ExecuteResponse { - GlobalDataHandle output = 1; - ExecutionProfile profile = 2; -} - -message ExecuteParallelResponse { - repeated ExecuteResponse responses = 1; -} - -message ComputeConstantGraphRequest { - HloModuleProto computation = 1; - LayoutProto output_layout = 2; -} - -message ComputeConstantResponse { - // A LiteralProto is returned directly for this request. - LiteralProto literal = 1; -} - -message DeconstructTupleRequest { - GlobalDataHandle tuple_handle = 2; -} - -message DeconstructTupleResponse { - repeated GlobalDataHandle element_handles = 1; -} - -message LoadDataRequest { - // Describes the path of the ColumnIO tablet to load. - string columnio_tablet_path = 1; - - // Describes the field to load within the ColumnIO tablet. - string columnio_field = 2; - - // Individual element shape, excluding rows. - ShapeProto element_shape = 3; - - // Warning: ColumnIO does not support random-access, so use offset with - // caution in performance-critical scenarios. - int64 offset = 4; - - // Maximum number of elements (with shape element_shape) to load. - int64 limit = 5; - - // If more than one item is requested (via limit > 1), then this request - // attribute zips together the produced vectors. 
- bool zip = 6; -} - -message LoadDataResponse { - GlobalDataHandle data = 1; - ShapeProto data_shape = 2; - int64 available_rows = 3; - int64 rows_loaded = 4; - int64 nanoseconds = 5; -} - -message GetShapeRequest { - GlobalDataHandle data = 1; -} - -message GetShapeResponse { - ShapeProto shape = 1; -} - -message UnpackRequest { - GlobalDataHandle data = 1; -} - -message UnpackResponse { - repeated GlobalDataHandle tied_data = 1; -} - // A trace estimated by the Latency Hiding Scheduler. message ScheduleProto { message Instruction { From b1f94db136e67c5d31887f1f00ff4c903202d91c Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 19 Jun 2024 04:04:32 -0700 Subject: [PATCH 5/6] Automated Code Change PiperOrigin-RevId: 644700338 --- .../xla/third_party/tsl/tsl/lib/io/block.cc | 4 +- .../tsl/tsl/lib/io/buffered_file.h | 20 +++---- .../tsl/tsl/lib/io/buffered_inputstream.cc | 55 +++++++++--------- .../tsl/tsl/lib/io/buffered_inputstream.h | 22 +++---- .../tsl/lib/io/buffered_inputstream_test.cc | 8 +-- .../xla/third_party/tsl/tsl/lib/io/format.cc | 17 +++--- .../xla/third_party/tsl/tsl/lib/io/format.h | 8 +-- .../third_party/tsl/tsl/lib/io/inputbuffer.cc | 57 ++++++++++--------- .../third_party/tsl/tsl/lib/io/inputbuffer.h | 33 +++++------ .../tsl/tsl/lib/io/inputstream_interface.cc | 4 +- .../tsl/tsl/lib/io/inputstream_interface.h | 8 +-- .../tsl/lib/io/inputstream_interface_test.cc | 8 +-- .../third_party/tsl/tsl/lib/io/iterator.cc | 10 ++-- .../xla/third_party/tsl/tsl/lib/io/iterator.h | 4 +- .../tsl/tsl/lib/io/random_inputstream.cc | 23 ++++---- .../tsl/tsl/lib/io/random_inputstream.h | 12 ++-- .../tsl/tsl/lib/io/record_reader.cc | 29 +++++----- .../tsl/tsl/lib/io/record_reader.h | 18 +++--- .../tsl/lib/io/record_reader_writer_test.cc | 4 +- .../tsl/tsl/lib/io/record_writer.cc | 30 +++++----- .../tsl/tsl/lib/io/record_writer.h | 8 +-- .../tsl/tsl/lib/io/recordio_test.cc | 30 +++++----- 22 files changed, 209 insertions(+), 203 deletions(-) diff --git 
a/third_party/xla/third_party/tsl/tsl/lib/io/block.cc b/third_party/xla/third_party/tsl/tsl/lib/io/block.cc index 0bc9fa3664c97b..8eefa4b5a3609f 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/block.cc +++ b/third_party/xla/third_party/tsl/tsl/lib/io/block.cc @@ -98,7 +98,7 @@ class Block::Iter : public Iterator { uint32 restart_index_; // Index of restart block in which current_ falls string key_; StringPiece value_; - Status status_; + absl::Status status_; inline int Compare(const StringPiece& a, const StringPiece& b) const { return a.compare(b); @@ -135,7 +135,7 @@ class Block::Iter : public Iterator { } bool Valid() const override { return current_ < restarts_; } - Status status() const override { return status_; } + absl::Status status() const override { return status_; } StringPiece key() const override { assert(Valid()); return key_; diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/buffered_file.h b/third_party/xla/third_party/tsl/tsl/lib/io/buffered_file.h index 5627a7228fb782..69300956d9fe20 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/buffered_file.h +++ b/third_party/xla/third_party/tsl/tsl/lib/io/buffered_file.h @@ -36,7 +36,7 @@ class BufferedWritableFile : public WritableFile { } ~BufferedWritableFile() override { Close().IgnoreError(); } - Status Append(StringPiece str_data) override { + absl::Status Append(StringPiece str_data) override { int64_t bytes_left = str_data.size(); const char* data = str_data.data(); @@ -58,22 +58,22 @@ class BufferedWritableFile : public WritableFile { bytes_left -= append_bytes; } - return OkStatus(); + return absl::OkStatus(); } - Status Append(const absl::Cord& data) override { + absl::Status Append(const absl::Cord& data) override { for (absl::string_view fragment : data.Chunks()) { TF_RETURN_IF_ERROR(Append(fragment)); } - return OkStatus(); + return absl::OkStatus(); } - Status Close() override { + absl::Status Close() override { TF_RETURN_IF_ERROR(Flush()); return file_->Close(); } - 
Status Flush() override { + absl::Status Flush() override { if (buffer_pos_ > 0) { TF_RETURN_IF_ERROR(file_->Append(StringPiece(&buffer_[0], buffer_pos_))); buffer_pos_ = 0; @@ -81,18 +81,18 @@ class BufferedWritableFile : public WritableFile { return file_->Flush(); } - tsl::Status Tell(int64_t* position) override { + absl::Status Tell(int64_t* position) override { int64_t bytes_written; - tsl::Status status = file_->Tell(&bytes_written); + absl::Status status = file_->Tell(&bytes_written); if (status.ok()) { *position = bytes_written + buffer_pos_; - return OkStatus(); + return absl::OkStatus(); } else { return status; } } - Status Sync() override { return file_->Sync(); } + absl::Status Sync() override { return file_->Sync(); } // For compatibilty with the TensorBundle writer, we expose CRC32 checksums. uint32_t crc32() const { return crc32_; } diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/buffered_inputstream.cc b/third_party/xla/third_party/tsl/tsl/lib/io/buffered_inputstream.cc index b3cfdbb20818ec..89ed20757cf093 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/buffered_inputstream.cc +++ b/third_party/xla/third_party/tsl/tsl/lib/io/buffered_inputstream.cc @@ -41,13 +41,13 @@ BufferedInputStream::~BufferedInputStream() { } } -Status BufferedInputStream::FillBuffer() { +absl::Status BufferedInputStream::FillBuffer() { if (!file_status_.ok()) { pos_ = 0; limit_ = 0; return file_status_; } - Status s = input_stream_->ReadNBytes(size_, &buf_); + absl::Status s = input_stream_->ReadNBytes(size_, &buf_); pos_ = 0; limit_ = buf_.size(); if (!s.ok()) { @@ -57,10 +57,10 @@ Status BufferedInputStream::FillBuffer() { } template -Status BufferedInputStream::ReadLineHelper(StringType* result, - bool include_eol) { +absl::Status BufferedInputStream::ReadLineHelper(StringType* result, + bool include_eol) { result->clear(); - Status s; + absl::Status s; size_t start_pos = pos_; while (true) { if (pos_ == limit_) { @@ -79,7 +79,7 @@ Status 
BufferedInputStream::ReadLineHelper(StringType* result, result->append(1, c); } pos_++; - return OkStatus(); + return absl::OkStatus(); } // We don't append '\r' to *result if (c == '\r') { @@ -89,12 +89,13 @@ Status BufferedInputStream::ReadLineHelper(StringType* result, pos_++; } if (absl::IsOutOfRange(s) && !result->empty()) { - return OkStatus(); + return absl::OkStatus(); } return s; } -Status BufferedInputStream::ReadNBytes(int64_t bytes_to_read, tstring* result) { +absl::Status BufferedInputStream::ReadNBytes(int64_t bytes_to_read, + tstring* result) { if (bytes_to_read < 0) { return errors::InvalidArgument("Can't read a negative number of bytes: ", bytes_to_read); @@ -105,7 +106,7 @@ Status BufferedInputStream::ReadNBytes(int64_t bytes_to_read, tstring* result) { } result->reserve(bytes_to_read); - Status s; + absl::Status s; while (result->size() < static_cast(bytes_to_read)) { // Check whether the buffer is fully read or not. if (pos_ == limit_) { @@ -127,12 +128,12 @@ Status BufferedInputStream::ReadNBytes(int64_t bytes_to_read, tstring* result) { // obtained enough data to satisfy the function call. Returning OK then. if (absl::IsOutOfRange(s) && (result->size() == static_cast(bytes_to_read))) { - return OkStatus(); + return absl::OkStatus(); } return s; } -Status BufferedInputStream::SkipNBytes(int64_t bytes_to_skip) { +absl::Status BufferedInputStream::SkipNBytes(int64_t bytes_to_skip) { if (bytes_to_skip < 0) { return errors::InvalidArgument("Can only skip forward, not ", bytes_to_skip); @@ -144,7 +145,7 @@ Status BufferedInputStream::SkipNBytes(int64_t bytes_to_skip) { // Otherwise, we already have read limit_ - pos_, so skip the rest. At this // point we need to get fresh data into the buffer, so reset pos_ and // limit_. 
- Status s = input_stream_->SkipNBytes(bytes_to_skip - (limit_ - pos_)); + absl::Status s = input_stream_->SkipNBytes(bytes_to_skip - (limit_ - pos_)); pos_ = 0; limit_ = 0; if (absl::IsOutOfRange(s)) { @@ -152,14 +153,14 @@ Status BufferedInputStream::SkipNBytes(int64_t bytes_to_skip) { } return s; } - return OkStatus(); + return absl::OkStatus(); } int64_t BufferedInputStream::Tell() const { return input_stream_->Tell() - (limit_ - pos_); } -Status BufferedInputStream::Seek(int64_t position) { +absl::Status BufferedInputStream::Seek(int64_t position) { if (position < 0) { return errors::InvalidArgument("Seeking to a negative position: ", position); @@ -176,7 +177,7 @@ Status BufferedInputStream::Seek(int64_t position) { if (position < Tell()) { // Seek within buffer before 'pos_' pos_ -= Tell() - position; - return OkStatus(); + return absl::OkStatus(); } // Seek after 'pos_' @@ -184,9 +185,9 @@ Status BufferedInputStream::Seek(int64_t position) { } template -Status BufferedInputStream::ReadAll(T* result) { +absl::Status BufferedInputStream::ReadAll(T* result) { result->clear(); - Status status; + absl::Status status; while (status.ok()) { status = FillBuffer(); if (limit_ == 0) { @@ -198,7 +199,7 @@ Status BufferedInputStream::ReadAll(T* result) { if (absl::IsOutOfRange(status)) { file_status_ = status; - return OkStatus(); + return absl::OkStatus(); } return status; } @@ -206,19 +207,19 @@ Status BufferedInputStream::ReadAll(T* result) { template Status BufferedInputStream::ReadAll(std::string* result); template Status BufferedInputStream::ReadAll(tstring* result); -Status BufferedInputStream::Reset() { +absl::Status BufferedInputStream::Reset() { TF_RETURN_IF_ERROR(input_stream_->Reset()); pos_ = 0; limit_ = 0; - file_status_ = OkStatus(); - return OkStatus(); + file_status_ = absl::OkStatus(); + return absl::OkStatus(); } -Status BufferedInputStream::ReadLine(std::string* result) { +absl::Status BufferedInputStream::ReadLine(std::string* result) { return 
ReadLineHelper(result, false); } -Status BufferedInputStream::ReadLine(tstring* result) { +absl::Status BufferedInputStream::ReadLine(tstring* result) { return ReadLineHelper(result, false); } @@ -228,8 +229,8 @@ std::string BufferedInputStream::ReadLineAsString() { return result; } -Status BufferedInputStream::SkipLine() { - Status s; +absl::Status BufferedInputStream::SkipLine() { + absl::Status s; bool skipped = false; while (true) { if (pos_ == limit_) { @@ -242,11 +243,11 @@ Status BufferedInputStream::SkipLine() { char c = buf_[pos_++]; skipped = true; if (c == '\n') { - return OkStatus(); + return absl::OkStatus(); } } if (absl::IsOutOfRange(s) && skipped) { - return OkStatus(); + return absl::OkStatus(); } return s; } diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/buffered_inputstream.h b/third_party/xla/third_party/tsl/tsl/lib/io/buffered_inputstream.h index 5318434c63919c..6681f1bbfbed32 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/buffered_inputstream.h +++ b/third_party/xla/third_party/tsl/tsl/lib/io/buffered_inputstream.h @@ -43,9 +43,9 @@ class BufferedInputStream : public InputStreamInterface { ~BufferedInputStream() override; - Status ReadNBytes(int64_t bytes_to_read, tstring* result) override; + absl::Status ReadNBytes(int64_t bytes_to_read, tstring* result) override; - Status SkipNBytes(int64_t bytes_to_skip) override; + absl::Status SkipNBytes(int64_t bytes_to_skip) override; int64_t Tell() const override; @@ -58,7 +58,7 @@ class BufferedInputStream : public InputStreamInterface { // Note: When seeking backwards in a stream, this implementation uses // Reset() + SkipNBytes(), so its performance will be dependent // largely on the performance of SkipNBytes(). - Status Seek(int64_t position); + absl::Status Seek(int64_t position); // Read one text line of data into "*result" until end-of-file or a // \n is read. (The \n is not included in the result.) 
Overwrites @@ -67,8 +67,8 @@ class BufferedInputStream : public InputStreamInterface { // If successful, returns OK. If we are already at the end of the // file, we return an OUT_OF_RANGE error. Otherwise, we return // some other non-OK status. - Status ReadLine(std::string* result); - Status ReadLine(tstring* result); + absl::Status ReadLine(std::string* result); + absl::Status ReadLine(tstring* result); // Returns one text line of data until end-of-file or a '\n' is read. The '\n' // is included in the result. @@ -83,21 +83,21 @@ class BufferedInputStream : public InputStreamInterface { // If successful, returns OK. If we are already at the end of the // file, we return an OUT_OF_RANGE error. Otherwise, we return // some other non-OK status. - Status SkipLine(); + absl::Status SkipLine(); // Reads the entire contents of the file into *result. // // Note: the amount of memory used by this function call is unbounded, so only // use in ops that expect that behavior. template - Status ReadAll(T* result); + absl::Status ReadAll(T* result); - Status Reset() override; + absl::Status Reset() override; private: - Status FillBuffer(); + absl::Status FillBuffer(); template - Status ReadLineHelper(StringType* result, bool include_eol); + absl::Status ReadLineHelper(StringType* result, bool include_eol); InputStreamInterface* input_stream_; // not owned. size_t size_; // buffer size. @@ -108,7 +108,7 @@ class BufferedInputStream : public InputStreamInterface { bool owns_input_stream_ = false; // When EoF is reached, file_status_ contains the status to skip unnecessary // buffer allocations. 
- Status file_status_ = OkStatus(); + absl::Status file_status_ = absl::OkStatus(); BufferedInputStream(const BufferedInputStream&) = delete; void operator=(const BufferedInputStream&) = delete; diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/buffered_inputstream_test.cc b/third_party/xla/third_party/tsl/tsl/lib/io/buffered_inputstream_test.cc index 56dd88510377bd..ab1f58e0b14a83 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/buffered_inputstream_test.cc +++ b/third_party/xla/third_party/tsl/tsl/lib/io/buffered_inputstream_test.cc @@ -36,7 +36,7 @@ class ReadOnceInputStream : public InputStreamInterface { public: ReadOnceInputStream() : start_(true) {} - virtual Status ReadNBytes(int64_t bytes_to_read, tstring* result) { + virtual absl::Status ReadNBytes(int64_t bytes_to_read, tstring* result) { if (bytes_to_read < 11) { return errors::InvalidArgument("Not reading all bytes: ", bytes_to_read); } @@ -52,9 +52,9 @@ class ReadOnceInputStream : public InputStreamInterface { int64_t Tell() const override { return start_ ? 0 : 10; } // Resets the stream to the beginning. - Status Reset() override { + absl::Status Reset() override { start_ = true; - return OkStatus(); + return absl::OkStatus(); } private: @@ -311,7 +311,7 @@ TEST(BufferedInputStream, OutOfRangeCache) { TF_ASSERT_OK((in.ReadNBytes(7, &read))); EXPECT_EQ(read, "3456789"); EXPECT_EQ(10, in.Tell()); - Status s = in.ReadNBytes(5, &read); + absl::Status s = in.ReadNBytes(5, &read); // Make sure the read is failing with OUT_OF_RANGE error. If it is failing // with other errors, it is not caching the OUT_OF_RANGE properly. 
EXPECT_EQ(error::OUT_OF_RANGE, s.code()) << s; diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/format.cc b/third_party/xla/third_party/tsl/tsl/lib/io/format.cc index bc12656f7fbec7..d0b20da64a385e 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/format.cc +++ b/third_party/xla/third_party/tsl/tsl/lib/io/format.cc @@ -36,9 +36,9 @@ void BlockHandle::EncodeTo(string* dst) const { core::PutVarint64(dst, size_); } -Status BlockHandle::DecodeFrom(StringPiece* input) { +absl::Status BlockHandle::DecodeFrom(StringPiece* input) { if (core::GetVarint64(input, &offset_) && core::GetVarint64(input, &size_)) { - return OkStatus(); + return absl::OkStatus(); } else { return errors::DataLoss("bad block handle"); } @@ -56,7 +56,7 @@ void Footer::EncodeTo(string* dst) const { assert(dst->size() == original_size + kEncodedLength); } -Status Footer::DecodeFrom(StringPiece* input) { +absl::Status Footer::DecodeFrom(StringPiece* input) { const char* magic_ptr = input->data() + kEncodedLength - 8; const uint32 magic_lo = core::DecodeFixed32(magic_ptr); const uint32 magic_hi = core::DecodeFixed32(magic_ptr + 4); @@ -66,7 +66,7 @@ Status Footer::DecodeFrom(StringPiece* input) { return errors::DataLoss("not an sstable (bad magic number)"); } - Status result = metaindex_handle_.DecodeFrom(input); + absl::Status result = metaindex_handle_.DecodeFrom(input); if (result.ok()) { result = index_handle_.DecodeFrom(input); } @@ -78,8 +78,8 @@ Status Footer::DecodeFrom(StringPiece* input) { return result; } -Status ReadBlock(RandomAccessFile* file, const BlockHandle& handle, - BlockContents* result) { +absl::Status ReadBlock(RandomAccessFile* file, const BlockHandle& handle, + BlockContents* result) { result->data = StringPiece(); result->cacheable = false; result->heap_allocated = false; @@ -94,7 +94,8 @@ Status ReadBlock(RandomAccessFile* file, const BlockHandle& handle, char* buf = new char[n + kBlockTrailerSize]; StringPiece contents; - Status s = file->Read(handle.offset(), n + 
kBlockTrailerSize, &contents, buf); + absl::Status s = + file->Read(handle.offset(), n + kBlockTrailerSize, &contents, buf); if (!s.ok()) { delete[] buf; return s; @@ -159,7 +160,7 @@ Status ReadBlock(RandomAccessFile* file, const BlockHandle& handle, return errors::DataLoss("bad block type"); } - return OkStatus(); + return absl::OkStatus(); } } // namespace table diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/format.h b/third_party/xla/third_party/tsl/tsl/lib/io/format.h index cd8e863435f440..ae5bb26b8b8c86 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/format.h +++ b/third_party/xla/third_party/tsl/tsl/lib/io/format.h @@ -46,7 +46,7 @@ class BlockHandle { void set_size(uint64 size) { size_ = size; } void EncodeTo(string* dst) const; - Status DecodeFrom(StringPiece* input); + absl::Status DecodeFrom(StringPiece* input); // Maximum encoding length of a BlockHandle enum { kMaxEncodedLength = 10 + 10 }; @@ -71,7 +71,7 @@ class Footer { void set_index_handle(const BlockHandle& h) { index_handle_ = h; } void EncodeTo(string* dst) const; - Status DecodeFrom(StringPiece* input); + absl::Status DecodeFrom(StringPiece* input); // Encoded length of a Footer. Note that the serialization of a // Footer will always occupy exactly this many bytes. It consists @@ -99,8 +99,8 @@ struct BlockContents { // Read the block identified by "handle" from "file". On failure // return non-OK. On success fill *result and return OK. -extern Status ReadBlock(RandomAccessFile* file, const BlockHandle& handle, - BlockContents* result); +extern absl::Status ReadBlock(RandomAccessFile* file, const BlockHandle& handle, + BlockContents* result); // Implementation details follow. 
Clients should ignore, diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/inputbuffer.cc b/third_party/xla/third_party/tsl/tsl/lib/io/inputbuffer.cc index 1b9e2dc6fe2b21..3c183ee1ae1b3c 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/inputbuffer.cc +++ b/third_party/xla/third_party/tsl/tsl/lib/io/inputbuffer.cc @@ -33,9 +33,9 @@ InputBuffer::InputBuffer(RandomAccessFile* file, size_t buffer_bytes) InputBuffer::~InputBuffer() { delete[] buf_; } -Status InputBuffer::FillBuffer() { +absl::Status InputBuffer::FillBuffer() { StringPiece data; - Status s = file_->Read(file_pos_, size_, &data, buf_); + absl::Status s = file_->Read(file_pos_, size_, &data, buf_); if (data.data() != buf_) { memmove(buf_, data.data(), data.size()); } @@ -46,9 +46,9 @@ Status InputBuffer::FillBuffer() { } template -Status InputBuffer::ReadLine(T* result) { +absl::Status InputBuffer::ReadLine(T* result) { result->clear(); - Status s; + absl::Status s; do { size_t buf_remain = limit_ - pos_; char* newline = static_cast(memchr(pos_, '\n', buf_remain)); @@ -59,7 +59,7 @@ Status InputBuffer::ReadLine(T* result) { if (!result->empty() && result->back() == '\r') { result->resize(result->size() - 1); } - return OkStatus(); + return absl::OkStatus(); } if (buf_remain > 0) result->append(pos_, buf_remain); // Get more data into buffer @@ -70,7 +70,7 @@ Status InputBuffer::ReadLine(T* result) { result->resize(result->size() - 1); } if (errors::IsOutOfRange(s) && !result->empty()) { - return OkStatus(); + return absl::OkStatus(); } return s; } @@ -78,7 +78,8 @@ Status InputBuffer::ReadLine(T* result) { template Status InputBuffer::ReadLine(std::string* result); template Status InputBuffer::ReadLine(tstring* result); -Status InputBuffer::ReadNBytes(int64_t bytes_to_read, std::string* result) { +absl::Status InputBuffer::ReadNBytes(int64_t bytes_to_read, + std::string* result) { result->clear(); if (bytes_to_read < 0) { return errors::InvalidArgument("Can't read a negative number of bytes: ", 
@@ -86,18 +87,18 @@ Status InputBuffer::ReadNBytes(int64_t bytes_to_read, std::string* result) { } result->resize(bytes_to_read); size_t bytes_read = 0; - Status status = ReadNBytes(bytes_to_read, &(*result)[0], &bytes_read); + absl::Status status = ReadNBytes(bytes_to_read, &(*result)[0], &bytes_read); if (bytes_read < bytes_to_read) result->resize(bytes_read); return status; } -Status InputBuffer::ReadNBytes(int64_t bytes_to_read, char* result, - size_t* bytes_read) { +absl::Status InputBuffer::ReadNBytes(int64_t bytes_to_read, char* result, + size_t* bytes_read) { if (bytes_to_read < 0) { return errors::InvalidArgument("Can't read a negative number of bytes: ", bytes_to_read); } - Status status; + absl::Status status; *bytes_read = 0; while (*bytes_read < static_cast(bytes_to_read)) { if (pos_ == limit_) { @@ -117,21 +118,21 @@ Status InputBuffer::ReadNBytes(int64_t bytes_to_read, char* result, } if (errors::IsOutOfRange(status) && (*bytes_read == static_cast(bytes_to_read))) { - return OkStatus(); + return absl::OkStatus(); } return status; } -Status InputBuffer::ReadVarint32Fallback(uint32* result) { - Status s = ReadVarintFallback(result, core::kMaxVarint32Bytes); +absl::Status InputBuffer::ReadVarint32Fallback(uint32* result) { + absl::Status s = ReadVarintFallback(result, core::kMaxVarint32Bytes); if (errors::IsDataLoss(s)) { return errors::DataLoss("Stored data is too large to be a varint32."); } return s; } -Status InputBuffer::ReadVarint64Fallback(uint64* result) { - Status s = ReadVarintFallback(result, core::kMaxVarint64Bytes); +absl::Status InputBuffer::ReadVarint64Fallback(uint64* result) { + absl::Status s = ReadVarintFallback(result, core::kMaxVarint64Bytes); if (errors::IsDataLoss(s)) { return errors::DataLoss("Stored data is too large to be a varint64."); } @@ -139,7 +140,7 @@ Status InputBuffer::ReadVarint64Fallback(uint64* result) { } template -Status InputBuffer::ReadVarintFallback(T* result, int max_bytes) { +absl::Status 
InputBuffer::ReadVarintFallback(T* result, int max_bytes) { uint8 scratch = 0; auto* p = reinterpret_cast(&scratch); size_t unused_bytes_read = 0; @@ -149,18 +150,18 @@ Status InputBuffer::ReadVarintFallback(T* result, int max_bytes) { int shift = 7 * index; TF_RETURN_IF_ERROR(ReadNBytes(1, p, &unused_bytes_read)); *result |= (static_cast(scratch) & 127) << shift; - if (!(scratch & 128)) return OkStatus(); + if (!(scratch & 128)) return absl::OkStatus(); } return errors::DataLoss("Stored data longer than ", max_bytes, " bytes."); } -Status InputBuffer::SkipNBytes(int64_t bytes_to_skip) { +absl::Status InputBuffer::SkipNBytes(int64_t bytes_to_skip) { if (bytes_to_skip < 0) { return errors::InvalidArgument("Can only skip forward, not ", bytes_to_skip); } int64_t bytes_skipped = 0; - Status s; + absl::Status s; while (bytes_skipped < bytes_to_skip) { if (pos_ == limit_) { // Get more data into buffer @@ -175,12 +176,12 @@ Status InputBuffer::SkipNBytes(int64_t bytes_to_skip) { pos_ += bytes_to_advance; } if (errors::IsOutOfRange(s) && bytes_skipped == bytes_to_skip) { - return OkStatus(); + return absl::OkStatus(); } return s; } -Status InputBuffer::Seek(int64_t position) { +absl::Status InputBuffer::Seek(int64_t position) { if (position < 0) { return errors::InvalidArgument("Seeking to a negative position: ", position); @@ -196,10 +197,10 @@ Status InputBuffer::Seek(int64_t position) { pos_ = limit_ = buf_; file_pos_ = position; } - return OkStatus(); + return absl::OkStatus(); } -Status InputBuffer::Hint(int64_t bytes_to_read) { +absl::Status InputBuffer::Hint(int64_t bytes_to_read) { if (bytes_to_read < 0) { return errors::InvalidArgument("Can't read a negative number of bytes: ", bytes_to_read); @@ -207,14 +208,14 @@ Status InputBuffer::Hint(int64_t bytes_to_read) { // The internal buffer is too small. Do nothing. 
if (bytes_to_read > size_) { - return OkStatus(); + return absl::OkStatus(); } const int64_t bytes_remain_in_buf = static_cast(limit_ - pos_); // There are enough data in the buffer. Do nothing. if (bytes_to_read <= bytes_remain_in_buf) { - return OkStatus(); + return absl::OkStatus(); } // Additional read from file is necessary. Make some room. @@ -225,7 +226,7 @@ Status InputBuffer::Hint(int64_t bytes_to_read) { // Read the remaining bytes from file. StringPiece data; - Status s = file_->Read(file_pos_, bytes_to_read, &data, limit_); + absl::Status s = file_->Read(file_pos_, bytes_to_read, &data, limit_); if (data.data() != limit_) { memmove(limit_, data.data(), data.size()); } @@ -233,7 +234,7 @@ Status InputBuffer::Hint(int64_t bytes_to_read) { file_pos_ += data.size(); if (errors::IsOutOfRange(s) && data.size() == bytes_to_read) { - return OkStatus(); + return absl::OkStatus(); } else { return s; } diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/inputbuffer.h b/third_party/xla/third_party/tsl/tsl/lib/io/inputbuffer.h index e357efb5f75b53..57a4a983c11e75 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/inputbuffer.h +++ b/third_party/xla/third_party/tsl/tsl/lib/io/inputbuffer.h @@ -45,38 +45,39 @@ class InputBuffer { // file, we return an OUT_OF_RANGE error. Otherwise, we return // some other non-OK status. template - Status ReadLine(T* result); + absl::Status ReadLine(T* result); // Reads bytes_to_read bytes into *result, overwriting *result. // // If successful, returns OK. If we there are not enough bytes to // read before the end of the file, we return an OUT_OF_RANGE error. // Otherwise, we return some other non-OK status. - Status ReadNBytes(int64_t bytes_to_read, std::string* result); + absl::Status ReadNBytes(int64_t bytes_to_read, std::string* result); // An overload that writes to char*. Caller must ensure result[0, // bytes_to_read) is valid to be overwritten. Returns OK iff "*bytes_read == // bytes_to_read". 
- Status ReadNBytes(int64_t bytes_to_read, char* result, size_t* bytes_read); + absl::Status ReadNBytes(int64_t bytes_to_read, char* result, + size_t* bytes_read); // Reads a single varint32. - Status ReadVarint32(uint32* result); + absl::Status ReadVarint32(uint32* result); // Reads a single varint64. - Status ReadVarint64(uint64* result); + absl::Status ReadVarint64(uint64* result); // Like ReadNBytes() without returning the bytes read. - Status SkipNBytes(int64_t bytes_to_skip); + absl::Status SkipNBytes(int64_t bytes_to_skip); // Seek to this offset within the file. // // If we seek to somewhere within our pre-buffered data, we will re-use what // data we can. Otherwise, Seek() throws out the current buffer and the next // read will trigger a File::Read(). - Status Seek(int64_t position); + absl::Status Seek(int64_t position); // Provides a hint about future reads, which may improve their performance. - Status Hint(int64_t bytes_to_read); + absl::Status Hint(int64_t bytes_to_read); // Returns the position in the file. int64_t Tell() const { return file_pos_ - (limit_ - pos_); } @@ -85,19 +86,19 @@ class InputBuffer { RandomAccessFile* file() const { return file_; } private: - Status FillBuffer(); + absl::Status FillBuffer(); // Internal slow-path routine used by ReadVarint32(). - Status ReadVarint32Fallback(uint32* result); + absl::Status ReadVarint32Fallback(uint32* result); // Internal slow-path routine used by ReadVarint64(). - Status ReadVarint64Fallback(uint64* result); + absl::Status ReadVarint64Fallback(uint64* result); // Helper method for reading a varint which can span at max `max_bytes`. // If the varint is longer, a DataLoss error status is returned. // If end of file is reached while reading, OutOfRange error is returned. 
template - Status ReadVarintFallback(T* result, int max_bytes); + absl::Status ReadVarintFallback(T* result, int max_bytes); RandomAccessFile* file_; // Not owned int64_t file_pos_; // Next position to read from in "file_" @@ -118,28 +119,28 @@ extern template Status InputBuffer::ReadLine(std::string* result); extern template Status InputBuffer::ReadLine(tstring* result); // Inlined for performance. -inline Status InputBuffer::ReadVarint32(uint32* result) { +inline absl::Status InputBuffer::ReadVarint32(uint32* result) { if (pos_ + core::kMaxVarint32Bytes <= limit_) { // Fast path: directly parse from buffered data. // Reads strictly from the range [pos_, limit_). const char* offset = core::GetVarint32Ptr(pos_, limit_, result); if (offset == nullptr) return errors::OutOfRange("Parsed past limit."); pos_ = const_cast(offset); - return OkStatus(); + return absl::OkStatus(); } else { return ReadVarint32Fallback(result); } } // Inlined for performance. -inline Status InputBuffer::ReadVarint64(uint64* result) { +inline absl::Status InputBuffer::ReadVarint64(uint64* result) { if (pos_ + core::kMaxVarint64Bytes <= limit_) { // Fast path: directly parse from buffered data. // Reads strictly from the range [pos_, limit_). const char* offset = core::GetVarint64Ptr(pos_, limit_, result); if (offset == nullptr) return errors::OutOfRange("Parsed past limit."); pos_ = const_cast(offset); - return OkStatus(); + return absl::OkStatus(); } else { return ReadVarint64Fallback(result); } diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/inputstream_interface.cc b/third_party/xla/third_party/tsl/tsl/lib/io/inputstream_interface.cc index 1a2f11d4d2b2b2..6425ff0656b658 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/inputstream_interface.cc +++ b/third_party/xla/third_party/tsl/tsl/lib/io/inputstream_interface.cc @@ -24,7 +24,7 @@ namespace io { // 8MB at a time. 
static constexpr int64_t kMaxSkipSize = 8 * 1024 * 1024; -Status InputStreamInterface::SkipNBytes(int64_t bytes_to_skip) { +absl::Status InputStreamInterface::SkipNBytes(int64_t bytes_to_skip) { if (bytes_to_skip < 0) { return errors::InvalidArgument("Can't skip a negative number of bytes"); } @@ -35,7 +35,7 @@ Status InputStreamInterface::SkipNBytes(int64_t bytes_to_skip) { TF_RETURN_IF_ERROR(ReadNBytes(bytes_to_read, &unused)); bytes_to_skip -= bytes_to_read; } - return OkStatus(); + return absl::OkStatus(); } } // namespace io diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/inputstream_interface.h b/third_party/xla/third_party/tsl/tsl/lib/io/inputstream_interface.h index afe87a4b9cc37e..8eb7f2ad868965 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/inputstream_interface.h +++ b/third_party/xla/third_party/tsl/tsl/lib/io/inputstream_interface.h @@ -35,13 +35,13 @@ class InputStreamInterface { // Reads the next bytes_to_read from the file. Typical return codes: // * OK - in case of success. // * OUT_OF_RANGE - not enough bytes remaining before end of file. - virtual Status ReadNBytes(int64_t bytes_to_read, tstring* result) = 0; + virtual absl::Status ReadNBytes(int64_t bytes_to_read, tstring* result) = 0; #if defined(TF_CORD_SUPPORT) // Reads the next bytes_to_read from the file. Typical return codes: // * OK - in case of success. // * OUT_OF_RANGE - not enough bytes remaining before end of file. - virtual Status ReadNBytes(int64_t bytes_to_read, absl::Cord* cord) { + virtual absl::Status ReadNBytes(int64_t bytes_to_read, absl::Cord* cord) { return errors::Unimplemented( "ReadNBytes(int64, absl::Cord*) is not implemented."); } @@ -51,7 +51,7 @@ class InputStreamInterface { // Typical return codes: // * OK - in case of success. // * OUT_OF_RANGE - not enough bytes remaining before end of file. 
- virtual Status SkipNBytes(int64_t bytes_to_skip); + virtual absl::Status SkipNBytes(int64_t bytes_to_skip); // Return the offset of the current byte relative to the beginning of the // file. @@ -61,7 +61,7 @@ class InputStreamInterface { virtual int64_t Tell() const = 0; // Resets the stream to the beginning. - virtual Status Reset() = 0; + virtual absl::Status Reset() = 0; }; } // namespace io diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/inputstream_interface_test.cc b/third_party/xla/third_party/tsl/tsl/lib/io/inputstream_interface_test.cc index 2f7cda954fd13d..23d4fb0ddf50bc 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/inputstream_interface_test.cc +++ b/third_party/xla/third_party/tsl/tsl/lib/io/inputstream_interface_test.cc @@ -27,21 +27,21 @@ class TestStringStream : public InputStreamInterface { public: explicit TestStringStream(const string& content) : content_(content) {} - Status ReadNBytes(int64_t bytes_to_read, tstring* result) override { + absl::Status ReadNBytes(int64_t bytes_to_read, tstring* result) override { result->clear(); if (pos_ + bytes_to_read > content_.size()) { return errors::OutOfRange("limit reached"); } *result = content_.substr(pos_, bytes_to_read); pos_ += bytes_to_read; - return OkStatus(); + return absl::OkStatus(); } int64_t Tell() const override { return pos_; } - Status Reset() override { + absl::Status Reset() override { pos_ = 0; - return OkStatus(); + return absl::OkStatus(); } private: diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/iterator.cc b/third_party/xla/third_party/tsl/tsl/lib/io/iterator.cc index 4dff9eb4f61761..a02a4254985087 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/iterator.cc +++ b/third_party/xla/third_party/tsl/tsl/lib/io/iterator.cc @@ -53,7 +53,7 @@ void Iterator::RegisterCleanup(CleanupFunction func, void* arg1, void* arg2) { namespace { class EmptyIterator : public Iterator { public: - explicit EmptyIterator(const Status& s) : status_(s) {} + explicit 
EmptyIterator(const absl::Status& s) : status_(s) {} bool Valid() const override { return false; } void Seek(const StringPiece& target) override {} void SeekToFirst() override {} @@ -66,16 +66,16 @@ class EmptyIterator : public Iterator { assert(false); return StringPiece(); } - Status status() const override { return status_; } + absl::Status status() const override { return status_; } private: - Status status_; + absl::Status status_; }; } // namespace -Iterator* NewEmptyIterator() { return new EmptyIterator(OkStatus()); } +Iterator* NewEmptyIterator() { return new EmptyIterator(absl::OkStatus()); } -Iterator* NewErrorIterator(const Status& status) { +Iterator* NewErrorIterator(const absl::Status& status) { return new EmptyIterator(status); } diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/iterator.h b/third_party/xla/third_party/tsl/tsl/lib/io/iterator.h index bb83f41ea47dd9..f0b16943c44b9c 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/iterator.h +++ b/third_party/xla/third_party/tsl/tsl/lib/io/iterator.h @@ -68,7 +68,7 @@ class Iterator { virtual StringPiece value() const = 0; // If an error has occurred, return it. Else return an ok status. - virtual Status status() const = 0; + virtual absl::Status status() const = 0; // Clients are allowed to register function/arg1/arg2 triples that // will be invoked when this iterator is destroyed. @@ -96,7 +96,7 @@ class Iterator { extern Iterator* NewEmptyIterator(); // Return an empty iterator with the specified status. 
-extern Iterator* NewErrorIterator(const Status& status); +extern Iterator* NewErrorIterator(const absl::Status& status); } // namespace table } // namespace tsl diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/random_inputstream.cc b/third_party/xla/third_party/tsl/tsl/lib/io/random_inputstream.cc index 1b5262057771b7..841e3d1bf26f6c 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/random_inputstream.cc +++ b/third_party/xla/third_party/tsl/tsl/lib/io/random_inputstream.cc @@ -30,8 +30,8 @@ RandomAccessInputStream::~RandomAccessInputStream() { } } -Status RandomAccessInputStream::ReadNBytes(int64_t bytes_to_read, - tstring* result) { +absl::Status RandomAccessInputStream::ReadNBytes(int64_t bytes_to_read, + tstring* result) { if (bytes_to_read < 0) { return errors::InvalidArgument("Cannot read negative number of bytes"); } @@ -39,7 +39,7 @@ Status RandomAccessInputStream::ReadNBytes(int64_t bytes_to_read, result->resize_uninitialized(bytes_to_read); char* result_buffer = &(*result)[0]; StringPiece data; - Status s = file_->Read(pos_, bytes_to_read, &data, result_buffer); + absl::Status s = file_->Read(pos_, bytes_to_read, &data, result_buffer); if (data.data() != result_buffer) { memmove(result_buffer, data.data(), data.size()); } @@ -51,13 +51,13 @@ Status RandomAccessInputStream::ReadNBytes(int64_t bytes_to_read, } #if defined(TF_CORD_SUPPORT) -Status RandomAccessInputStream::ReadNBytes(int64_t bytes_to_read, - absl::Cord* result) { +absl::Status RandomAccessInputStream::ReadNBytes(int64_t bytes_to_read, + absl::Cord* result) { if (bytes_to_read < 0) { return errors::InvalidArgument("Cannot read negative number of bytes"); } int64_t current_size = result->size(); - Status s = file_->Read(pos_, bytes_to_read, result); + absl::Status s = file_->Read(pos_, bytes_to_read, result); if (s.ok() || errors::IsOutOfRange(s)) { pos_ += result->size() - current_size; } @@ -69,7 +69,7 @@ Status RandomAccessInputStream::ReadNBytes(int64_t bytes_to_read, // 8MB 
at a time. static constexpr int64_t kMaxSkipSize = 8 * 1024 * 1024; -Status RandomAccessInputStream::SkipNBytes(int64_t bytes_to_skip) { +absl::Status RandomAccessInputStream::SkipNBytes(int64_t bytes_to_skip) { if (bytes_to_skip < 0) { return errors::InvalidArgument("Can't skip a negative number of bytes"); } @@ -78,17 +78,18 @@ Status RandomAccessInputStream::SkipNBytes(int64_t bytes_to_skip) { // not reached yet and we could return. if (bytes_to_skip > 0) { StringPiece data; - Status s = file_->Read(pos_ + bytes_to_skip - 1, 1, &data, scratch.get()); + absl::Status s = + file_->Read(pos_ + bytes_to_skip - 1, 1, &data, scratch.get()); if ((s.ok() || errors::IsOutOfRange(s)) && data.size() == 1) { pos_ += bytes_to_skip; - return OkStatus(); + return absl::OkStatus(); } } // Read kDefaultSkipSize at a time till bytes_to_skip. while (bytes_to_skip > 0) { int64_t bytes_to_read = std::min(kMaxSkipSize, bytes_to_skip); StringPiece data; - Status s = file_->Read(pos_, bytes_to_read, &data, scratch.get()); + absl::Status s = file_->Read(pos_, bytes_to_read, &data, scratch.get()); if (s.ok() || errors::IsOutOfRange(s)) { pos_ += data.size(); } else { @@ -99,7 +100,7 @@ Status RandomAccessInputStream::SkipNBytes(int64_t bytes_to_skip) { } bytes_to_skip -= bytes_to_read; } - return OkStatus(); + return absl::OkStatus(); } int64_t RandomAccessInputStream::Tell() const { return pos_; } diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/random_inputstream.h b/third_party/xla/third_party/tsl/tsl/lib/io/random_inputstream.h index e1608ce3ec2b9b..4d48db62c2b03f 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/random_inputstream.h +++ b/third_party/xla/third_party/tsl/tsl/lib/io/random_inputstream.h @@ -33,22 +33,22 @@ class RandomAccessInputStream : public InputStreamInterface { ~RandomAccessInputStream() override; - Status ReadNBytes(int64_t bytes_to_read, tstring* result) override; + absl::Status ReadNBytes(int64_t bytes_to_read, tstring* result) override; #if 
defined(TF_CORD_SUPPORT) - Status ReadNBytes(int64_t bytes_to_read, absl::Cord* result) override; + absl::Status ReadNBytes(int64_t bytes_to_read, absl::Cord* result) override; #endif - Status SkipNBytes(int64_t bytes_to_skip) override; + absl::Status SkipNBytes(int64_t bytes_to_skip) override; int64_t Tell() const override; - Status Seek(int64_t position) { + absl::Status Seek(int64_t position) { pos_ = position; - return OkStatus(); + return absl::OkStatus(); } - Status Reset() override { return Seek(0); } + absl::Status Reset() override { return Seek(0); } private: RandomAccessFile* file_; // Not owned. diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/record_reader.cc b/third_party/xla/third_party/tsl/tsl/lib/io/record_reader.cc index e267b5cee84dab..8d17c610b09f71 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/record_reader.cc +++ b/third_party/xla/third_party/tsl/tsl/lib/io/record_reader.cc @@ -101,7 +101,8 @@ inline const char* GetChecksumErrorSuffix(uint64 offset) { // and is used only in error messages. For failures at offset 0, // a reminder about the file format is added, because TFRecord files // contain no explicit format marker. 
-Status RecordReader::ReadChecksummed(uint64 offset, size_t n, tstring* result) { +absl::Status RecordReader::ReadChecksummed(uint64 offset, size_t n, + tstring* result) { if (n >= SIZE_MAX - sizeof(uint32)) { return errors::DataLoss("record size too large", GetChecksumErrorSuffix(offset)); @@ -125,10 +126,10 @@ Status RecordReader::ReadChecksummed(uint64 offset, size_t n, tstring* result) { GetChecksumErrorSuffix(offset)); } result->resize(n); - return OkStatus(); + return absl::OkStatus(); } -Status RecordReader::GetMetadata(Metadata* md) { +absl::Status RecordReader::GetMetadata(Metadata* md) { if (!md) { return errors::InvalidArgument( "Metadata object call to GetMetadata() was null"); @@ -148,7 +149,7 @@ Status RecordReader::GetMetadata(Metadata* md) { tstring record; while (true) { // Read header, containing size of data. - Status s = ReadChecksummed(offset, sizeof(uint64), &record); + absl::Status s = ReadChecksummed(offset, sizeof(uint64), &record); if (!s.ok()) { if (errors::IsOutOfRange(s)) { // We should reach out of range when the record file is complete. @@ -178,10 +179,10 @@ Status RecordReader::GetMetadata(Metadata* md) { } md->stats = cached_metadata_->stats; - return OkStatus(); + return absl::OkStatus(); } -Status RecordReader::PositionInputStream(uint64 offset) { +absl::Status RecordReader::PositionInputStream(uint64 offset) { int64_t curr_pos = input_stream_->Tell(); int64_t desired_pos = static_cast(offset); if (curr_pos > desired_pos || curr_pos < 0 /* EOF */ || @@ -193,14 +194,14 @@ Status RecordReader::PositionInputStream(uint64 offset) { TF_RETURN_IF_ERROR(input_stream_->SkipNBytes(desired_pos - curr_pos)); } DCHECK_EQ(desired_pos, input_stream_->Tell()); - return OkStatus(); + return absl::OkStatus(); } -Status RecordReader::ReadRecord(uint64* offset, tstring* record) { +absl::Status RecordReader::ReadRecord(uint64* offset, tstring* record) { TF_RETURN_IF_ERROR(PositionInputStream(*offset)); // Read header data. 
- Status s = ReadChecksummed(*offset, sizeof(uint64), record); + absl::Status s = ReadChecksummed(*offset, sizeof(uint64), record); if (!s.ok()) { last_read_failed_ = true; return s; @@ -220,14 +221,14 @@ Status RecordReader::ReadRecord(uint64* offset, tstring* record) { *offset += kHeaderSize + length + kFooterSize; DCHECK_EQ(*offset, input_stream_->Tell()); - return OkStatus(); + return absl::OkStatus(); } -Status RecordReader::SkipRecords(uint64* offset, int num_to_skip, - int* num_skipped) { +absl::Status RecordReader::SkipRecords(uint64* offset, int num_to_skip, + int* num_skipped) { TF_RETURN_IF_ERROR(PositionInputStream(*offset)); - Status s; + absl::Status s; tstring record; *num_skipped = 0; for (int i = 0; i < num_to_skip; ++i) { @@ -252,7 +253,7 @@ Status RecordReader::SkipRecords(uint64* offset, int num_to_skip, DCHECK_EQ(*offset, input_stream_->Tell()); (*num_skipped)++; } - return OkStatus(); + return absl::OkStatus(); } SequentialRecordReader::SequentialRecordReader( diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/record_reader.h b/third_party/xla/third_party/tsl/tsl/lib/io/record_reader.h index 282c0daff2a5a8..61540a657324c8 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/record_reader.h +++ b/third_party/xla/third_party/tsl/tsl/lib/io/record_reader.h @@ -94,14 +94,14 @@ class RecordReader { // Read the record at "*offset" into *record and update *offset to // point to the offset of the next record. Returns OK on success, // OUT_OF_RANGE for end of file, or something else for an error. - Status ReadRecord(uint64* offset, tstring* record); + absl::Status ReadRecord(uint64* offset, tstring* record); // Skip num_to_skip record starting at "*offset" and update *offset // to point to the offset of the next num_to_skip + 1 record. // Return OK on success, OUT_OF_RANGE for end of file, or something // else for an error. "*num_skipped" records the number of records that // are actually skipped. It should be equal to num_to_skip on success. 
- Status SkipRecords(uint64* offset, int num_to_skip, int* num_skipped); + absl::Status SkipRecords(uint64* offset, int num_to_skip, int* num_skipped); // Return the metadata of the Record file. // @@ -112,11 +112,11 @@ class RecordReader { // so that GetMetadata() could be a const method. // // 'metadata' must not be nullptr. - Status GetMetadata(Metadata* md); + absl::Status GetMetadata(Metadata* md); private: - Status ReadChecksummed(uint64 offset, size_t n, tstring* result); - Status PositionInputStream(uint64 offset); + absl::Status ReadChecksummed(uint64 offset, size_t n, tstring* result); + absl::Status PositionInputStream(uint64 offset); RecordReaderOptions options_; std::unique_ptr input_stream_; @@ -143,7 +143,7 @@ class SequentialRecordReader { // Read the next record in the file into *record. Returns OK on success, // OUT_OF_RANGE for end of file, or something else for an error. - Status ReadRecord(tstring* record) { + absl::Status ReadRecord(tstring* record) { return underlying_.ReadRecord(&offset_, record); } @@ -151,7 +151,7 @@ class SequentialRecordReader { // OUT_OF_RANGE for end of file, or something else for an error. // "*num_skipped" records the number of records that are actually skipped. // It should be equal to num_to_skip on success. - Status SkipRecords(int num_to_skip, int* num_skipped) { + absl::Status SkipRecords(int num_to_skip, int* num_skipped) { return underlying_.SkipRecords(&offset_, num_to_skip, num_skipped); } @@ -160,13 +160,13 @@ class SequentialRecordReader { // Seek to this offset within the file and set this offset as the current // offset. Trying to seek backward will throw error. 
- Status SeekOffset(uint64 offset) { + absl::Status SeekOffset(uint64 offset) { if (offset < offset_) return errors::InvalidArgument( "Trying to seek offset: ", offset, " which is less than the current offset: ", offset_); offset_ = offset; - return OkStatus(); + return absl::OkStatus(); } private: diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/record_reader_writer_test.cc b/third_party/xla/third_party/tsl/tsl/lib/io/record_reader_writer_test.cc index 2497db348a5729..67df783112f9ee 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/record_reader_writer_test.cc +++ b/third_party/xla/third_party/tsl/tsl/lib/io/record_reader_writer_test.cc @@ -226,7 +226,7 @@ TEST(RecordReaderWriterTest, TestSkipOutOfRange) { uint64 offset = 0; int num_skipped; tstring record; - Status s = reader.SkipRecords(&offset, 3, &num_skipped); + absl::Status s = reader.SkipRecords(&offset, 3, &num_skipped); EXPECT_EQ(2, num_skipped); EXPECT_EQ(error::OUT_OF_RANGE, s.code()); } @@ -254,7 +254,7 @@ TEST(RecordReaderWriterTest, TestMalformedInput) { tstring record; // At offset 0, the error message reminds of the file type. 
uint64 offset = 0; - Status s = reader.ReadRecord(&offset, &record); + absl::Status s = reader.ReadRecord(&offset, &record); EXPECT_EQ(error::DATA_LOSS, s.code()); EXPECT_EQ("corrupted record at 0 (Is this even a TFRecord file?)", s.message()); diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/record_writer.cc b/third_party/xla/third_party/tsl/tsl/lib/io/record_writer.cc index aace9f10e14c6d..9a6a932dd77a26 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/record_writer.cc +++ b/third_party/xla/third_party/tsl/tsl/lib/io/record_writer.cc @@ -69,7 +69,7 @@ RecordWriter::RecordWriter(WritableFile* dest, ZlibOutputBuffer* zlib_output_buffer = new ZlibOutputBuffer( dest, options.zlib_options.input_buffer_size, options.zlib_options.output_buffer_size, options.zlib_options); - Status s = zlib_output_buffer->Init(); + absl::Status s = zlib_output_buffer->Init(); if (!s.ok()) { LOG(FATAL) << "Failed to initialize Zlib inputbuffer. Error: " << s.ToString(); @@ -89,17 +89,17 @@ RecordWriter::RecordWriter(WritableFile* dest, RecordWriter::~RecordWriter() { if (dest_ != nullptr) { - Status s = Close(); + absl::Status s = Close(); if (!s.ok()) { LOG(ERROR) << "Could not finish writing file: " << s; } } } -Status RecordWriter::WriteRecord(StringPiece data) { +absl::Status RecordWriter::WriteRecord(StringPiece data) { if (dest_ == nullptr) { - return Status(absl::StatusCode::kFailedPrecondition, - "Writer not initialized or previously closed"); + return absl::Status(absl::StatusCode::kFailedPrecondition, + "Writer not initialized or previously closed"); } // Format of a single record: // uint64 length @@ -116,10 +116,10 @@ Status RecordWriter::WriteRecord(StringPiece data) { } #if defined(TF_CORD_SUPPORT) -Status RecordWriter::WriteRecord(const absl::Cord& data) { +absl::Status RecordWriter::WriteRecord(const absl::Cord& data) { if (dest_ == nullptr) { - return Status(absl::StatusCode::kFailedPrecondition, - "Writer not initialized or previously closed"); + return 
absl::Status(absl::StatusCode::kFailedPrecondition, + "Writer not initialized or previously closed"); } // Format of a single record: // uint64 length @@ -136,21 +136,21 @@ Status RecordWriter::WriteRecord(const absl::Cord& data) { } #endif -Status RecordWriter::Close() { - if (dest_ == nullptr) return OkStatus(); +absl::Status RecordWriter::Close() { + if (dest_ == nullptr) return absl::OkStatus(); if (IsZlibCompressed(options_) || IsSnappyCompressed(options_)) { - Status s = dest_->Close(); + absl::Status s = dest_->Close(); delete dest_; dest_ = nullptr; return s; } - return OkStatus(); + return absl::OkStatus(); } -Status RecordWriter::Flush() { +absl::Status RecordWriter::Flush() { if (dest_ == nullptr) { - return Status(absl::StatusCode::kFailedPrecondition, - "Writer not initialized or previously closed"); + return absl::Status(absl::StatusCode::kFailedPrecondition, + "Writer not initialized or previously closed"); } return dest_->Flush(); } diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/record_writer.h b/third_party/xla/third_party/tsl/tsl/lib/io/record_writer.h index b585cb9b52f70c..06e9a5c847910c 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/record_writer.h +++ b/third_party/xla/third_party/tsl/tsl/lib/io/record_writer.h @@ -77,22 +77,22 @@ class RecordWriter { // implicit Close() call in the destructor. ~RecordWriter(); - Status WriteRecord(StringPiece data); + absl::Status WriteRecord(StringPiece data); #if defined(TF_CORD_SUPPORT) - Status WriteRecord(const absl::Cord& data); + absl::Status WriteRecord(const absl::Cord& data); #endif // Flushes any buffered data held by underlying containers of the // RecordWriter to the WritableFile. Does *not* flush the // WritableFile. - Status Flush(); + absl::Status Flush(); // Writes all output to the file. Does *not* close the WritableFile. // // After calling Close(), any further calls to `WriteRecord()` or `Flush()` // are invalid. 
- Status Close(); + absl::Status Close(); // Utility method to populate TFRecord headers. Populates record-header in // "header[0,kHeaderSize-1]". The record-header is based on data[0, n-1]. diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/recordio_test.cc b/third_party/xla/third_party/tsl/tsl/lib/io/recordio_test.cc index f9702f2ed13997..42adf76f7ef0d3 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/recordio_test.cc +++ b/third_party/xla/third_party/tsl/tsl/lib/io/recordio_test.cc @@ -55,22 +55,22 @@ class StringDest : public WritableFile { public: explicit StringDest(string* contents) : contents_(contents) {} - Status Close() override { return OkStatus(); } - Status Flush() override { return OkStatus(); } - Status Sync() override { return OkStatus(); } - Status Append(StringPiece slice) override { + absl::Status Close() override { return absl::OkStatus(); } + absl::Status Flush() override { return absl::OkStatus(); } + absl::Status Sync() override { return absl::OkStatus(); } + absl::Status Append(StringPiece slice) override { contents_->append(slice.data(), slice.size()); - return OkStatus(); + return absl::OkStatus(); } #if defined(TF_CORD_SUPPORT) - Status Append(const absl::Cord& data) override { + absl::Status Append(const absl::Cord& data) override { contents_->append(std::string(data)); - return OkStatus(); + return absl::OkStatus(); } #endif - Status Tell(int64_t* pos) override { + absl::Status Tell(int64_t* pos) override { *pos = contents_->size(); - return OkStatus(); + return absl::OkStatus(); } private: @@ -82,8 +82,8 @@ class StringSource : public RandomAccessFile { explicit StringSource(string* contents) : contents_(contents), force_error_(false) {} - Status Read(uint64 offset, size_t n, StringPiece* result, - char* scratch) const override { + absl::Status Read(uint64 offset, size_t n, StringPiece* result, + char* scratch) const override { if (force_error_) { force_error_ = false; return errors::DataLoss("read error"); @@ -97,7 +97,7 
@@ class StringSource : public RandomAccessFile { n = contents_->size() - offset; } *result = StringPiece(contents_->data() + offset, n); - return OkStatus(); + return absl::OkStatus(); } void force_error() { force_error_ = true; } @@ -150,7 +150,7 @@ class RecordioTest : public ::testing::Test { reading_ = true; } tstring record; - Status s = reader_->ReadRecord(&readpos_, &record); + absl::Status s = reader_->ReadRecord(&readpos_, &record); if (s.ok()) { return record; } else if (errors::IsOutOfRange(s)) { @@ -184,7 +184,7 @@ class RecordioTest : public ::testing::Test { reading_ = true; uint64 offset = WrittenBytes() + offset_past_end; tstring record; - Status s = reader_->ReadRecord(&offset, &record); + absl::Status s = reader_->ReadRecord(&offset, &record); ASSERT_TRUE(errors::IsOutOfRange(s)) << s; } }; @@ -317,7 +317,7 @@ void TestReadError(const RecordWriterOptions& writer_options, uint64 offset = 0; tstring read; file.force_error(); - Status status = reader.ReadRecord(&offset, &read); + absl::Status status = reader.ReadRecord(&offset, &read); ASSERT_TRUE(errors::IsDataLoss(status)); ASSERT_EQ(0, offset); From 89a47214bc7cd1582547ad81c8ed1af7e6e2cee1 Mon Sep 17 00:00:00 2001 From: Adrian Kuegel Date: Wed, 19 Jun 2024 04:20:29 -0700 Subject: [PATCH 6/6] Refactor llvm_compiler_test. We can run the CpuCompiler and GPUCompiler related tests in separate test targets. 
PiperOrigin-RevId: 644703636 --- third_party/xla/xla/service/BUILD | 1 + third_party/xla/xla/service/llvm_compiler.cc | 7 +- third_party/xla/xla/tests/BUILD | 38 ++- .../xla/xla/tests/llvm_compiler_test.cc | 232 +++++------------- 4 files changed, 80 insertions(+), 198 deletions(-) diff --git a/third_party/xla/xla/service/BUILD b/third_party/xla/xla/service/BUILD index 3d1f6e42012fa7..e1a420ff607158 100644 --- a/third_party/xla/xla/service/BUILD +++ b/third_party/xla/xla/service/BUILD @@ -1654,6 +1654,7 @@ cc_library( local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), deps = [ ":compiler", + "@com_google_absl//absl/status:statusor", "@llvm-project//llvm:Core", "@local_tsl//tsl/platform:denormal", "@local_tsl//tsl/profiler/lib:scoped_annotation", diff --git a/third_party/xla/xla/service/llvm_compiler.cc b/third_party/xla/xla/service/llvm_compiler.cc index 4bbcca0b484cd2..02afa91c5d9ea5 100644 --- a/third_party/xla/xla/service/llvm_compiler.cc +++ b/third_party/xla/xla/service/llvm_compiler.cc @@ -15,6 +15,11 @@ limitations under the License. 
#include "xla/service/llvm_compiler.h" +#include +#include +#include + +#include "absl/status/statusor.h" #include "tsl/platform/denormal.h" #include "tsl/profiler/lib/scoped_annotation.h" @@ -55,6 +60,6 @@ absl::StatusOr>> LLVMCompiler::Compile( result.push_back(std::move(executable)); } - return {std::move(result)}; + return std::move(result); } } // namespace xla diff --git a/third_party/xla/xla/tests/BUILD b/third_party/xla/xla/tests/BUILD index 69c591200c9025..b841b47a70624f 100644 --- a/third_party/xla/xla/tests/BUILD +++ b/third_party/xla/xla/tests/BUILD @@ -12,10 +12,6 @@ load( "if_cuda_is_configured", ) load("//xla:xla.bzl", "tests_build_defs_bzl_deps", "xla_cc_binary", "xla_cc_test") -load( - "//xla/stream_executor:build_defs.bzl", - "if_gpu_is_configured", -) load("//xla/tests:build_defs.bzl", "generate_backend_suites", "generate_backend_test_macros", "xla_test", "xla_test_library") load("//xla/tsl:tsl.bzl", "internal_visibility", "tsl_copts") load("//xla/tsl:tsl.default.bzl", "filegroup") @@ -2462,32 +2458,30 @@ xla_test( xla_test( name = "llvm_compiler_test", srcs = ["llvm_compiler_test.cc"], - backends = ["gpu"], - local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ - "TENSORFLOW_USE_ROCM", - ]), - # TODO(b/317293391) Remove once Bazel test_suite handles tags correctly - tags = ["gpu"], - deps = if_gpu_is_configured([ - ":verified_hlo_module", + backend_tags = { + # TODO(b/317293391) Remove once Bazel test_suite handles tags correctly. 
+ "gpu": ["gpu"], + }, + backends = [ + "cpu", + "gpu", + ], + deps = [ + ":hlo_test_base", "//xla:literal_util", "//xla:test_helpers", "//xla/hlo/ir:hlo", - "@com_google_absl//absl/status", + "//xla/hlo/ir:hlo_module_group", "//xla/service:backend", - "//xla/service:cpu_plugin", "//xla/service:llvm_compiler", - "//xla/service:platform_util", - "//xla/service/cpu:cpu_compiler", - "//xla/service/gpu:gpu_compiler", "//xla/stream_executor", + "@com_google_absl//absl/status", + "@com_google_googletest//:gtest_main", + "@llvm-project//llvm:Core", + "@local_tsl//tsl/platform:casts", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:test_main", - ]) + if_cuda_is_configured([ - "//xla/stream_executor/cuda:cuda_platform_id", - ]) + if_rocm_is_configured([ - "//xla/stream_executor/rocm:rocm_platform_id", - ]), + ], ) xla_test( diff --git a/third_party/xla/xla/tests/llvm_compiler_test.cc b/third_party/xla/xla/tests/llvm_compiler_test.cc index dd15cbbeba4cc2..238482b650c3d6 100644 --- a/third_party/xla/xla/tests/llvm_compiler_test.cc +++ b/third_party/xla/xla/tests/llvm_compiler_test.cc @@ -21,199 +21,81 @@ limitations under the License. 
#include #include +#include #include "absl/status/status.h" +#include "llvm/IR/Module.h" #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module_group.h" #include "xla/literal_util.h" #include "xla/service/backend.h" -#include "xla/service/cpu/cpu_compiler.h" -#include "xla/service/gpu/gpu_compiler.h" -#include "xla/service/platform_util.h" -#if GOOGLE_CUDA -#include "xla/stream_executor/cuda/cuda_platform_id.h" -#elif TENSORFLOW_USE_ROCM -#include "xla/stream_executor/rocm/rocm_platform_id.h" -#endif #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/stream_executor.h" #include "xla/test_helpers.h" -#include "xla/tests/verified_hlo_module.h" +#include "xla/tests/hlo_test_base.h" +#include "tsl/platform/casts.h" #include "tsl/platform/threadpool.h" namespace xla { -namespace gpu { - -// Creating dummy data structure needed to initialize a GpuDummyCompiler -constexpr char kDummyTriple[] = "dummy-triple"; -constexpr char kDummyLayout[] = "e"; -const se::Platform::Id kGpuPlatformId = -#if GOOGLE_CUDA - se::cuda::kCudaPlatformId; -#elif TENSORFLOW_USE_ROCM - se::rocm::kROCmPlatformId; -#endif -// This class is a dummy implementation of GpuCompiler and is targeted for unit -// test only -class GpuDummyCompiler : public GpuCompiler { - public: - GpuDummyCompiler() - : GpuCompiler(kGpuPlatformId, kDummyTriple, kDummyLayout) {} - - int32_t GetToolkitVersion() const override { return 0; } - - absl::Status OptimizeHloConvolutionCanonicalization( - HloModule* hlo_module, se::GpuComputeCapability gpu_version, - se::dnn::VersionInfo dnn_version, - se::DeviceMemoryAllocator* device_allocator) { - return absl::OkStatus(); - } - - absl::Status OptimizeHloPostLayoutAssignment( - HloModule* hlo_module, se::StreamExecutor* stream_executor, - const CompileOptions& options, const TargetConfig& gpu_target_config, - tsl::thread::ThreadPool* thread_pool) override { - return absl::OkStatus(); - } +namespace { - absl::StatusOr 
CompileTargetBinary( - const HloModuleConfig& module_config, llvm::Module* llvm_module, - se::GpuComputeCapability gpu_version, bool relocatable, - const HloModule* debug_module, const CompileOptions& options) override { - return BackendCompileResult{}; - } -}; -} // namespace gpu +using LLVMCompilerTest = HloTestBase; -namespace { +const char* const kHloText = R"( +HloModule Constant -class LLVMCompilerTest : public ::testing::Test { - public: - void SetUp() override { - Platform* platform = FindPlatform(); - ASSERT_NE(platform, nullptr); - - BackendOptions backend_options; - backend_options.set_platform(platform); - absl::StatusOr> backend_or_status = - Backend::CreateBackend(backend_options); - ASSERT_IS_OK(backend_or_status.status()); - backend_ = std::move(backend_or_status).value(); - } - - ~LLVMCompilerTest() override {} - - protected: - using Platform = se::Platform; - - explicit LLVMCompilerTest(std::string platform_name) - : platform_name_(std::move(platform_name)) {} - - void TestCompilerHooks(LLVMCompiler* compiler) { - int pre_opt_hook_call_count = 0; - int post_opt_hook_call_count = 0; - - auto pre_opt_hook = [&pre_opt_hook_call_count](const llvm::Module&) { - ++pre_opt_hook_call_count; - return absl::OkStatus(); - }; - auto post_opt_hook = [&post_opt_hook_call_count](const llvm::Module&) { - ++post_opt_hook_call_count; - return absl::OkStatus(); - }; - - // Create HLO module, and run the compiler. - auto builder = HloComputation::Builder(TestName()); - builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); - - auto hlo_module = CreateNewVerifiedModule(); - hlo_module->AddEntryComputation(builder.Build()); - - compiler->SetPreOptimizationHook(pre_opt_hook); - compiler->SetPostOptimizationHook(post_opt_hook); - - ASSERT_TRUE(compiler - ->RunBackend(std::move(hlo_module), - backend_->default_stream_executor(), - /*device_allocator=*/nullptr) - .ok()); - - // Test that hooks were called. 
- EXPECT_EQ(1, pre_opt_hook_call_count); - EXPECT_EQ(1, post_opt_hook_call_count); - } - - void TestMultiModuleCompilation(LLVMCompiler* compiler) { - HloComputation::Builder builder(TestName()); - builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); - - std::unique_ptr hlo_module = CreateNewVerifiedModule(); - hlo_module->AddEntryComputation(builder.Build()); - - auto module_group = std::make_unique("test_module_group"); - module_group->push_back(hlo_module->Clone()); - module_group->push_back(std::move(hlo_module)); - - std::vector> executors; - executors.push_back({backend_->default_stream_executor()}); - executors.push_back({backend_->default_stream_executor()}); - - EXPECT_IS_OK(compiler->Compile(std::move(module_group), - std::move(executors), - /*device_allocator=*/nullptr)); - } - - private: - Platform* FindPlatform() { - auto status_or_platform = PlatformUtil::GetPlatform(platform_name_); - return status_or_platform.ok() ? status_or_platform.value() : nullptr; - } - - std::string platform_name_; - std::unique_ptr backend_; - - static std::string TestName() { - return ::testing::UnitTest::GetInstance()->current_test_info()->name(); - } - - std::unique_ptr CreateNewVerifiedModule() { - HloModuleConfig config; - config.set_debug_options(GetDebugOptionsFromFlags()); - return std::make_unique( - TestName(), config, /*verifier_layout_sensitive=*/false, - /*allow_mixed_precision_in_hlo_verifier=*/true, - backend_->compiler()->ShapeSizeBytesFunction()); - } -}; - -class CpuCompilerTest : public LLVMCompilerTest { - public: - CpuCompilerTest() : LLVMCompilerTest("Host") {} -}; - -class GpuCompilerTest : public LLVMCompilerTest { - public: - GpuCompilerTest() : LLVMCompilerTest("GPU") {} -}; - -TEST_F(CpuCompilerTest, HooksTest) { - cpu::CpuCompiler compiler; - TestCompilerHooks(&compiler); +ENTRY main { + ROOT constant = f32[] constant(42.0) } +)"; -TEST_F(GpuCompilerTest, HooksTest) { - gpu::GpuDummyCompiler compiler; - 
TestCompilerHooks(&compiler); -} +TEST_F(LLVMCompilerTest, HooksTest) { + int pre_opt_hook_call_count = 0; + int post_opt_hook_call_count = 0; -TEST_F(CpuCompilerTest, CpuMultiModuleCompilation) { - cpu::CpuCompiler compiler; - TestMultiModuleCompilation(&compiler); + auto pre_opt_hook = [&pre_opt_hook_call_count](const llvm::Module&) { + ++pre_opt_hook_call_count; + return absl::OkStatus(); + }; + auto post_opt_hook = [&post_opt_hook_call_count](const llvm::Module&) { + ++post_opt_hook_call_count; + return absl::OkStatus(); + }; + + // Create HLO module, and run the compiler. + auto hlo_module = ParseAndReturnVerifiedModule(kHloText).value(); + LLVMCompiler* compiler = + tensorflow::down_cast(backend().compiler()); + compiler->SetPreOptimizationHook(pre_opt_hook); + compiler->SetPostOptimizationHook(post_opt_hook); + + ASSERT_TRUE(compiler + ->RunBackend(std::move(hlo_module), + backend().default_stream_executor(), + /*device_allocator=*/nullptr) + .ok()); + + // Test that hooks were called. + EXPECT_EQ(1, pre_opt_hook_call_count); + EXPECT_EQ(1, post_opt_hook_call_count); } -TEST_F(GpuCompilerTest, GpuMultModuleCompilation) { - gpu::GpuDummyCompiler compiler; - TestMultiModuleCompilation(&compiler); +TEST_F(LLVMCompilerTest, DISABLED_MultiModuleCompilation) { + auto hlo_module = ParseAndReturnVerifiedModule(kHloText).value(); + auto hlo_module2 = ParseAndReturnVerifiedModule(kHloText).value(); + std::vector> modules; + modules.push_back(std::move(hlo_module)); + modules.push_back(std::move(hlo_module2)); + auto module_group = + std::make_unique("test_module_group", std::move(modules)); + + std::vector> executors; + executors.push_back({backend().default_stream_executor()}); + executors.push_back({backend().default_stream_executor()}); + + EXPECT_IS_OK(backend().compiler()->Compile(std::move(module_group), + std::move(executors), + backend().memory_allocator())); } + } // namespace } // namespace xla