diff --git a/third_party/xla/third_party/tsl/tsl/platform/env.cc b/third_party/xla/third_party/tsl/tsl/platform/env.cc
index 77f48b3372e1eb..789725e8856b94 100644
--- a/third_party/xla/third_party/tsl/tsl/platform/env.cc
+++ b/third_party/xla/third_party/tsl/tsl/platform/env.cc
@@ -328,7 +328,8 @@ absl::Status Env::HasAtomicMove(const string& path, bool* has_atomic_move) {
   return fs->HasAtomicMove(path, has_atomic_move);
 }
 
-Status Env::CanCreateTempFile(const string& fname, bool* can_create_temp_file) {
+absl::Status Env::CanCreateTempFile(const string& fname,
+                                    bool* can_create_temp_file) {
   FileSystem* fs;
   TF_RETURN_IF_ERROR(GetFileSystemForFile(fname, &fs));
   return fs->CanCreateTempFile(fname, can_create_temp_file);
diff --git a/third_party/xla/third_party/tsl/tsl/platform/env.h b/third_party/xla/third_party/tsl/tsl/platform/env.h
index 37abc1dee97d54..0952517f9b7f8c 100644
--- a/third_party/xla/third_party/tsl/tsl/platform/env.h
+++ b/third_party/xla/third_party/tsl/tsl/platform/env.h
@@ -352,8 +352,8 @@ class Env {
   /// If this returns false, TensorFlow will write directly to output files
   /// instead of creating a temporary file and swapping it in. This may mean
   /// that incomplete writes are visible to consumers.
-  Status CanCreateTempFile(const std::string& fname,
-                           bool* can_create_temp_file);
+  absl::Status CanCreateTempFile(const std::string& fname,
+                                 bool* can_create_temp_file);
 
   /// Stores the size of `fname` in `*file_size`.
   absl::Status GetFileSize(const std::string& fname, uint64* file_size);
diff --git a/third_party/xla/third_party/tsl/tsl/platform/file_system.cc b/third_party/xla/third_party/tsl/tsl/platform/file_system.cc
index 68d0fcf0499ca5..ee385af7354074 100644
--- a/third_party/xla/third_party/tsl/tsl/platform/file_system.cc
+++ b/third_party/xla/third_party/tsl/tsl/platform/file_system.cc
@@ -95,10 +95,10 @@ absl::Status FileSystem::HasAtomicMove(const string& path,
   return absl::OkStatus();
 }
 
-Status FileSystem::CanCreateTempFile(const std::string& fname,
-                                     bool* can_create_temp_file) {
+absl::Status FileSystem::CanCreateTempFile(const std::string& fname,
+                                           bool* can_create_temp_file) {
   *can_create_temp_file = true;
-  return OkStatus();
+  return absl::OkStatus();
 }
 
 void FileSystem::FlushCaches(TransactionToken* token) {}
diff --git a/third_party/xla/third_party/tsl/tsl/platform/file_system.h b/third_party/xla/third_party/tsl/tsl/platform/file_system.h
index 4b728a42c4d507..67209ed491055f 100644
--- a/third_party/xla/third_party/tsl/tsl/platform/file_system.h
+++ b/third_party/xla/third_party/tsl/tsl/platform/file_system.h
@@ -392,8 +392,8 @@ class FileSystem {
   /// to determine if there needs to be a temp location to safely write objects.
   /// If the file system cannot create a temp file, it's possible that an
   /// incomplete result may appear in the given file.
-  virtual Status CanCreateTempFile(const std::string& fname,
-                                   bool* can_create_temp_file);
+  virtual absl::Status CanCreateTempFile(const std::string& fname,
+                                         bool* can_create_temp_file);
 
   /// \brief Flushes any cached filesystem objects from memory.
virtual void FlushCaches() { FlushCaches(nullptr); } diff --git a/third_party/xla/third_party/tsl/tsl/platform/retrying_file_system.h b/third_party/xla/third_party/tsl/tsl/platform/retrying_file_system.h index 88da8787c3d618..a64ecc20e960ff 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/retrying_file_system.h +++ b/third_party/xla/third_party/tsl/tsl/platform/retrying_file_system.h @@ -151,8 +151,8 @@ class RetryingFileSystem : public FileSystem { return base_file_system_->HasAtomicMove(path, has_atomic_move); } - Status CanCreateTempFile(const std::string& fname, - bool* can_create_temp_file) override { + absl::Status CanCreateTempFile(const std::string& fname, + bool* can_create_temp_file) override { // this method does not need to be retried return base_file_system_->CanCreateTempFile(fname, can_create_temp_file); } diff --git a/third_party/xla/xla/client/BUILD b/third_party/xla/xla/client/BUILD index e0c53df856f3b1..400eb877cbbce1 100644 --- a/third_party/xla/xla/client/BUILD +++ b/third_party/xla/xla/client/BUILD @@ -29,7 +29,6 @@ filegroup( cc_library( name = "global_data", - srcs = ["global_data.cc"], hdrs = ["global_data.h"], deps = [ "//xla:types", diff --git a/third_party/xla/xla/client/client.cc b/third_party/xla/xla/client/client.cc index a58737a181fd84..6e89947f237ff7 100644 --- a/third_party/xla/xla/client/client.cc +++ b/third_party/xla/xla/client/client.cc @@ -41,129 +41,28 @@ Client::~Client() = default; absl::StatusOr Client::Transfer(const GlobalData& data, const Shape* shape_with_layout) { - TransferToClientRequest request; - *request.mutable_data() = data.handle(); - if (shape_with_layout != nullptr) { - *request.mutable_shape_with_layout() = shape_with_layout->ToProto(); - } - TransferToClientResponse response; - - VLOG(1) << "making transfer request"; - VLOG(3) << "TransferToClientRequest: {" << request.DebugString() << "}"; - absl::Status s = stub_->TransferToClient(&request, &response); - VLOG(1) << "done with request"; - - if (!s.ok()) { - return s; - } - VLOG(3) << "TransferToClientResponse: {" << response.DebugString() << "}"; - - if (!response.has_literal()) { - return FailedPrecondition( - "server provided response without a literal in " - "TransferToClient request"); - } - return Literal::CreateFromProto(*response.mutable_literal()); + return stub_->TransferToClient(data, shape_with_layout); } absl::StatusOr> Client::TransferToServer( const LiteralSlice& literal, const DeviceHandle* device_handle) { - TransferToServerRequest request; - *request.mutable_literal() = literal.ToProto(); - if (device_handle) { - *request.mutable_device_handle() = *device_handle; - } - TransferToServerResponse response; - - VLOG(1) << "making transfer to server request"; - VLOG(3) << "TransferToServerRequest: {" << request.DebugString() << "}"; - absl::Status s = stub_->TransferToServer(&request, &response); - VLOG(1) << "done with request"; - - if (!s.ok()) { - return s; - } - VLOG(3) << "TransferToServerResponse: {" << response.DebugString() << "}"; - - if (!response.has_data()) { - return FailedPrecondition( - "server provided response without a data handle in " - "TransferToServer request"); - } - - return std::make_unique(stub_, response.data()); + return stub_->TransferToServer(literal, device_handle); } absl::Status Client::TransferToInfeed(const LiteralSlice& literal, int64_t replica_id, const DeviceHandle* device_handle) { - TransferToInfeedRequest request; - *request.mutable_literal() = literal.ToProto(); - if (device_handle) { - 
*request.mutable_device_handle() = *device_handle; - } - request.set_replica_id(replica_id); - TransferToInfeedResponse response; - - VLOG(1) << "making transfer to infeed request"; - VLOG(3) << "TransferToInfeedRequest: {" << request.DebugString() << "}"; - absl::Status s = stub_->TransferToInfeed(&request, &response); - VLOG(1) << "done with request"; - - if (!s.ok()) { - return s; - } - VLOG(3) << "TransferToInfeedResponse: {" << response.DebugString() << "}"; - return absl::OkStatus(); + return stub_->TransferToInfeed(literal, replica_id, device_handle); } absl::StatusOr Client::TransferFromOutfeed( const Shape* shape_with_layout, int64_t replica_id, const DeviceHandle* device_handle) { - TransferFromOutfeedRequest request; - if (device_handle) { - *request.mutable_device_handle() = *device_handle; - } - request.set_replica_id(replica_id); - if (shape_with_layout != nullptr) { - *request.mutable_shape_with_layout() = shape_with_layout->ToProto(); - } - TransferFromOutfeedResponse response; - - VLOG(1) << "making transfer from outfeed request"; - VLOG(3) << "TransferFromOutfeedRequest: {" << request.DebugString() << "}"; - absl::Status s = stub_->TransferFromOutfeed(&request, &response); - VLOG(1) << "done with request"; - - if (!s.ok()) { - return s; - } - VLOG(3) << "TransferFromOutfeedResponse: {" << response.DebugString() << "}"; - - if (!response.has_literal()) { - return FailedPrecondition( - "server provided response without a literal in " - "TransferToClient request"); - } - - return Literal::CreateFromProto(response.literal()); + return stub_->TransferFromOutfeed(shape_with_layout, replica_id, + device_handle); } -absl::Status Client::ResetDevice() { - ResetDeviceRequest request; - ResetDeviceResponse response; - - VLOG(1) << "making reset device request"; - VLOG(3) << "ResetDeviceRequest: {" << request.DebugString() << "}"; - absl::Status s = stub_->ResetDevice(&request, &response); - VLOG(1) << "done with request"; - - if (!s.ok()) { - return s; - } - VLOG(3) << "ResetDeviceResponse: {" << response.DebugString() << "}"; - return absl::OkStatus(); -} +absl::Status Client::ResetDevice() { return stub_->ResetDevice(); } absl::StatusOr Client::ExecuteAndTransfer( const XlaComputation& computation, absl::Span arguments, @@ -185,30 +84,7 @@ absl::StatusOr Client::ExecuteAndTransfer( absl::StatusOr Client::ComputeConstant( const XlaComputation& computation, const Layout* output_layout) const { - ComputeConstantGraphRequest request; - *request.mutable_computation() = computation.proto(); - if (output_layout != nullptr) { - *request.mutable_output_layout() = output_layout->ToProto(); - } - - ComputeConstantResponse response; - - VLOG(2) << "making compute-constant-graph request"; - absl::Status s = stub_->ComputeConstantGraph(&request, &response); - VLOG(2) << "done with request"; - - if (!s.ok()) { - return s; - } - - VLOG(3) << "ComputeConstant: {" << response.DebugString() << "}"; - - if (!response.has_literal()) { - return Internal( - "no computed literal in the provided response in ComputeConstantGraph " - "request"); - } - return Literal::CreateFromProto(response.literal()); + return stub_->ComputeConstantGraph(computation, output_layout); } absl::StatusOr Client::LoadSnapshot(const HloSnapshot& module) { @@ -219,61 +95,19 @@ absl::StatusOr Client::LoadSnapshot(const HloSnapshot& module) { absl::StatusOr Client::Compile( const XlaComputation& computation, absl::Span argument_shapes, const ExecutionOptions* execution_options) { - CompileRequest request; - 
*request.mutable_computation() = computation.proto(); - - if (execution_options == nullptr) { - *request.mutable_execution_options() = CreateDefaultExecutionOptions(); - } else { - *request.mutable_execution_options() = *execution_options; - } - if (request.execution_options().device_handles_size() > 1) { - return InvalidArgument( - "Compiling with multiple device handles is not supported. Use " - "'Execute' instead."); - } - - // The argument shapes affect how the computation is compiled. - for (const auto& arg_shape : argument_shapes) { - *request.add_input_shape_with_layout() = arg_shape.ToProto(); + std::optional opts; + if (!execution_options) { + opts = CreateDefaultExecutionOptions(); } - CompileResponse response; - VLOG(1) << "making compile request: " << request.ShortDebugString(); - absl::Status s = stub_->Compile(&request, &response); - VLOG(1) << "done with request"; - - if (!s.ok()) { - return s; - } - TF_RET_CHECK(response.has_handle()); - return response.handle(); + return stub_->Compile(computation, argument_shapes, + execution_options ? *execution_options : *opts); } absl::StatusOr> Client::Execute( const ExecutionHandle& handle, absl::Span arguments, ExecutionProfile* execution_profile) { - ExecuteRequest request; - *request.mutable_handle() = handle; - for (GlobalData* argument : arguments) { - CHECK(argument != nullptr) << "Argument pointers must not be null."; - *request.add_arguments() = argument->handle(); - } - - ExecuteResponse response; - VLOG(1) << "making execute request: " << request.ShortDebugString(); - absl::Status s = stub_->Execute(&request, &response); - VLOG(1) << "done with request"; - - if (!s.ok()) { - return s; - } - - if (execution_profile != nullptr) { - *execution_profile = response.profile(); - } - - return std::make_unique(stub_, response.output()); + return stub_->Execute(handle, arguments, execution_profile); } absl::StatusOr> Client::Execute( @@ -329,39 +163,7 @@ absl::StatusOr> Client::Execute( absl::StatusOr>> Client::ExecuteParallel(absl::Span computations) { - ExecuteGraphParallelRequest request; - - for (const XlaComputationInstance& computation : computations) { - ExecuteGraphRequest single_request; - *single_request.mutable_computation() = computation.computation.proto(); - for (GlobalData* argument : computation.arguments) { - *single_request.add_arguments() = argument->handle(); - } - *single_request.mutable_execution_options() = computation.execution_options; - *request.add_requests() = single_request; - } - - ExecuteParallelResponse response; - VLOG(1) << "making execute-graph-parallel request: " - << request.ShortDebugString(); - absl::Status s = stub_->ExecuteGraphParallel(&request, &response); - VLOG(1) << "done with request"; - - if (!s.ok()) { - return s; - } - - std::vector> outputs; - for (size_t i = 0, end = response.responses_size(); i < end; ++i) { - outputs.push_back( - std::make_unique(stub_, response.responses(i).output())); - if (i < computations.size() && - computations[i].execution_profile != nullptr) { - *computations[i].execution_profile = response.responses(i).profile(); - } - } - - return std::move(outputs); + return stub_->ExecuteGraphParallel(computations); } absl::StatusOr> Client::GetDeviceHandles( @@ -369,59 +171,17 @@ absl::StatusOr> Client::GetDeviceHandles( if (device_count < 1) { return InvalidArgument("device_count must be greater than 0"); } - GetDeviceHandlesRequest request; - request.set_device_count(device_count); - GetDeviceHandlesResponse response; - VLOG(1) << "making get device request: " << 
request.ShortDebugString(); - absl::Status s = stub_->GetDeviceHandles(&request, &response); - VLOG(1) << "done with request"; - - if (!s.ok()) { - return s; - } - - std::vector device_handles; - const auto& response_device_handles = response.device_handles(); - device_handles.reserve(response_device_handles.size()); - for (const DeviceHandle& device_handle : response_device_handles) { - device_handles.push_back(device_handle); - } - - return device_handles; + return stub_->GetDeviceHandles(device_count); } absl::Status Client::Unregister(const GlobalData& data) { - UnregisterRequest request; - *request.add_data() = data.handle(); - UnregisterResponse response; - - VLOG(1) << "making unregister request"; - absl::Status s = stub_->Unregister(&request, &response); - VLOG(1) << "done with request"; - - return s; + return stub_->Unregister(data.handle()); } absl::StatusOr>> Client::DeconstructTuple(const GlobalData& data) { - DeconstructTupleRequest request; - *request.mutable_tuple_handle() = data.handle(); - DeconstructTupleResponse response; - - VLOG(1) << "making DestructTuple request"; - absl::Status s = stub_->DeconstructTuple(&request, &response); - VLOG(1) << "done with request"; - - if (!s.ok()) { - return s; - } - - std::vector> handles; - for (auto& handle : response.element_handles()) { - handles.push_back(std::make_unique(stub_, handle)); - } - return std::move(handles); + return stub_->DeconstructTuple(data); } absl::StatusOr> Client::GetComputationShape( @@ -431,36 +191,12 @@ absl::StatusOr> Client::GetComputationShape( } absl::StatusOr Client::GetShape(const GlobalData& data) { - GetShapeRequest request; - *request.mutable_data() = data.handle(); - GetShapeResponse response; - - VLOG(1) << "making get shape request"; - absl::Status s = stub_->GetShape(&request, &response); - VLOG(1) << "done with request"; - - if (!s.ok()) { - return s; - } - - return Shape(response.shape()); + return stub_->GetShape(data); } absl::StatusOr Client::CreateChannelHandleByType( ChannelHandle::ChannelType type) { - CreateChannelHandleRequest request; - request.set_channel_type(type); - CreateChannelHandleResponse response; - - VLOG(1) << "making create channel handle request"; - absl::Status s = stub_->CreateChannelHandle(&request, &response); - VLOG(1) << "done with request"; - - if (!s.ok()) { - return s; - } - - return response.channel(); + return stub_->CreateChannelHandle(type); } absl::StatusOr Client::CreateChannelHandle() { diff --git a/third_party/xla/xla/client/client.h b/third_party/xla/xla/client/client.h index 66d82ec79f744e..1ecfcfe6f358eb 100644 --- a/third_party/xla/xla/client/client.h +++ b/third_party/xla/xla/client/client.h @@ -22,7 +22,6 @@ limitations under the License. #include #include "absl/types/span.h" -#include "xla/client/global_data.h" #include "xla/client/xla_computation.h" #include "xla/literal.h" #include "xla/service/hlo.pb.h" @@ -41,6 +40,8 @@ class Client { explicit Client(Service* stub); virtual ~Client(); + using XlaComputationInstance = XlaComputationInstance; + // Compile the computation with the given argument shapes and returns the // handle to the compiled executable. The compiled executable is cached on the // service, and the returned handle can be used for execution without @@ -70,7 +71,9 @@ class Client { // will be filled with profile data from the execution. 
absl::StatusOr> Execute( const ExecutionHandle& handle, absl::Span arguments, - ExecutionProfile* execution_profile = nullptr); + ExecutionProfile* execution_profile = nullptr + + ); // Executes the computation with the given arguments and returns the global // data that was produced from the execution. @@ -93,26 +96,6 @@ class Client { const ExecutionOptions* execution_options = nullptr, ExecutionProfile* execution_profile = nullptr); - // A struct to represent a computation instance to be executed. - // * If execution_options.device_handles is not empty, the computation is - // executed on the devices associated with the handles by partitioning the - // computation based on the attached sharding attributes. Otherwise, a - // device is chosen by the service. - struct XlaComputationInstance { - const XlaComputation& computation; - std::vector arguments; - ExecutionOptions execution_options; - ExecutionProfile* execution_profile; - - XlaComputationInstance(const XlaComputation& computation, - std::vector arguments, - ExecutionOptions execution_options, - ExecutionProfile* execution_profile) - : computation(computation), - arguments(std::move(arguments)), - execution_options(execution_options), - execution_profile(execution_profile) {} - }; // Executes a list XlaComputationInstances and returns global data produced // from each computation. diff --git a/third_party/xla/xla/client/global_data.cc b/third_party/xla/xla/client/global_data.cc deleted file mode 100644 index 66a5c5fee61673..00000000000000 --- a/third_party/xla/xla/client/global_data.cc +++ /dev/null @@ -1,73 +0,0 @@ -/* Copyright 2017 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/client/global_data.h" - -#include -#include -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "xla/types.h" -#include "tsl/platform/logging.h" - -namespace xla { -namespace { - -// Releases a set of global data handles owned by the parent service -// interface. 
-void ReleaseHandles(Service* parent, - const absl::Span handles) { - UnregisterRequest request; - for (auto& handle : handles) { - VLOG(1) << "Requesting to unregister " << handle.ShortDebugString(); - *request.add_data() = handle; - } - UnregisterResponse response; - absl::Status status = parent->Unregister(&request, &response); - VLOG(1) << "Done with request"; - if (!status.ok()) { - LOG(WARNING) << "Failed to unregister handles: " << status - << "; continuing anyway..."; - } -} - -} // namespace - -GlobalData::GlobalData(Service* parent, GlobalDataHandle handle) - : handle_(std::move(handle)), parent_(parent) {} - -GlobalData::~GlobalData() { - if (parent_ != nullptr) { - ReleaseHandles(parent_, {handle_}); - } -} - -/* static */ void GlobalData::Release( - std::vector> instances) { - absl::flat_hash_map> - parent_handles_map; - for (auto& instance : instances) { - if (instance->parent_ != nullptr) { - parent_handles_map[instance->parent_].push_back(instance->Release()); - } - } - for (auto& parent_handles : parent_handles_map) { - ReleaseHandles(parent_handles.first, parent_handles.second); - } -} - -} // namespace xla diff --git a/third_party/xla/xla/client/global_data.h b/third_party/xla/xla/client/global_data.h index 790a92a97c26bc..d47209edee377b 100644 --- a/third_party/xla/xla/client/global_data.h +++ b/third_party/xla/xla/client/global_data.h @@ -26,37 +26,8 @@ limitations under the License. namespace xla { -// A GlobalData object represents a globally-accessible allocation of -// data in the associated XLA service. -class GlobalData { - public: - // Gives ownership of the global data handle to this object. - GlobalData(Service* parent, GlobalDataHandle handle); - - // Unregisters the wrapped handle, which causes the service to - // deallocate the associated data. - ~GlobalData(); - - const GlobalDataHandle& handle() const { return handle_; } - - // Releases a set of GlobalData handles. A single RPC will be issued - // per unique Service of the given GlobalData objects. - static void Release(std::vector> instances); - - private: - // Detaches the global data handle from the object, such that the destructor - // will not try to release it. - GlobalDataHandle Release() { - parent_ = nullptr; - return handle_; - } - - GlobalDataHandle handle_; // Handle being wrapped. - Service* parent_; // Service used to unregister handle_. - - GlobalData(const GlobalData&) = delete; - GlobalData& operator=(const GlobalData&) = delete; -}; +// TODO(cheshire): Remove. +// Deprecated target for backwards compatibility. 
} // namespace xla diff --git a/third_party/xla/xla/client/local_client.cc b/third_party/xla/xla/client/local_client.cc index b45cb62eaad3cb..e00f39143bb6ee 100644 --- a/third_party/xla/xla/client/local_client.cc +++ b/third_party/xla/xla/client/local_client.cc @@ -493,7 +493,7 @@ absl::StatusOr LocalClient::ReplicaNumberToDeviceOrdinal( return local_service_->ReplicaNumberToDeviceOrdinal(replica_number); } -absl::StatusOr LocalClient::TransferToLocalServer( +absl::StatusOr LocalClient::TransferToLocalServer( const ::xla::BorrowingLiteral& literal, int device_ordinal) { const ::xla::Shape& shape = literal.shape(); @@ -506,14 +506,13 @@ absl::StatusOr LocalClient::TransferToLocalServer( stream.get(), literal, shaped_buffer)); std::vector<::xla::ScopedShapedBuffer> replicated_buffer; replicated_buffer.emplace_back(std::move(shaped_buffer)); - ::xla::TransferToServerResponse result; - TF_ASSIGN_OR_RETURN(*result.mutable_data(), + TF_ASSIGN_OR_RETURN(GlobalDataHandle data, local_service_->RegisterReplicatedBuffers( std::move(replicated_buffer), absl::StrCat("TransferToServer literal of shape ", ::xla::ShapeUtil::HumanString(shape)))); - return result; + return data; } } // namespace xla diff --git a/third_party/xla/xla/client/local_client.h b/third_party/xla/xla/client/local_client.h index 0d7ee75b07beab..f26c67ced132c8 100644 --- a/third_party/xla/xla/client/local_client.h +++ b/third_party/xla/xla/client/local_client.h @@ -171,7 +171,7 @@ class LocalClient : public Client { se::DeviceMemoryAllocator* allocator = nullptr); // Transfer the BorrowingLiteral to the device with the given ordinal. - absl::StatusOr TransferToLocalServer( + absl::StatusOr TransferToLocalServer( const ::xla::BorrowingLiteral& literal, int device_ordinal); // Copy the data from the device contained in the given ShapedBuffer and diff --git a/third_party/xla/xla/service/BUILD b/third_party/xla/xla/service/BUILD index 99bb153d1d32c1..e1a420ff607158 100644 --- a/third_party/xla/xla/service/BUILD +++ b/third_party/xla/xla/service/BUILD @@ -1238,6 +1238,7 @@ cc_library( "//xla:debug_options_flags", "//xla:executable_run_options", "//xla:execution_options_util", + "//xla:literal", "//xla:shape_layout", "//xla:shape_util", "//xla:status_macros", @@ -1245,11 +1246,13 @@ cc_library( "//xla:util", "//xla:xla_data_proto_cc", "//xla:xla_proto_cc", + "//xla/client:xla_computation", "//xla/hlo/evaluator:hlo_evaluator", "//xla/hlo/ir:hlo", "//xla/hlo/ir:hlo_module_group", "//xla/stream_executor", "//xla/stream_executor:device_memory_allocator", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", @@ -1651,6 +1654,7 @@ cc_library( local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), deps = [ ":compiler", + "@com_google_absl//absl/status:statusor", "@llvm-project//llvm:Core", "@local_tsl//tsl/platform:denormal", "@local_tsl//tsl/profiler/lib:scoped_annotation", diff --git a/third_party/xla/xla/service/compile_only_service.h b/third_party/xla/xla/service/compile_only_service.h index 27f4b4c97626f1..0238a16f282946 100644 --- a/third_party/xla/xla/service/compile_only_service.h +++ b/third_party/xla/xla/service/compile_only_service.h @@ -53,30 +53,23 @@ class CompileOnlyService : public Service { const AotCompilationOptions& options, std::unique_ptr* metadata); - absl::Status GetDeviceHandles(const GetDeviceHandlesRequest* arg, - GetDeviceHandlesResponse* result) override { + absl::StatusOr> GetDeviceHandles( + 
int64_t device_count) override { return Unimplemented("CompileOnlyService does not support devices."); } - absl::Status TransferToServer(const TransferToServerRequest* arg, - TransferToServerResponse* result) override { - return Unimplemented( - "CompileOnlyService does not support device data transfers."); - } - absl::Status TransferToInfeed(const TransferToInfeedRequest* arg, - TransferToInfeedResponse* result) override { + + absl::StatusOr> TransferToServer( + const LiteralSlice& literal_slice, + const DeviceHandle* device_handle) override { return Unimplemented( "CompileOnlyService does not support device data transfers."); } - absl::Status TransferFromOutfeed( - const TransferFromOutfeedRequest* arg, - TransferFromOutfeedResponse* result) override { + + absl::Status TransferToInfeed(const LiteralSlice& literal, int64_t replica_id, + const DeviceHandle* device_handle) override { return Unimplemented( "CompileOnlyService does not support device data transfers."); } - absl::Status ResetDevice(const ResetDeviceRequest* arg, - ResetDeviceResponse* result) override { - return Unimplemented("CompileOnlyService does not support devices."); - } private: explicit CompileOnlyService(const ServiceOptions& options, diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index bf3e191aec5e35..dba824f313e4d6 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -671,6 +671,7 @@ cc_library( deps = [ ":gpu_device_info_for_tests", ":gpu_float_support", + ":ir_emission_utils", ":ir_emitter_triton", ":matmul_utils", "//xla:shape_util", @@ -684,6 +685,7 @@ cc_library( "//xla/stream_executor:device_description", "//xla/tests:filecheck", "//xla/tests:verified_hlo_module", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -693,6 +695,7 @@ cc_library( "@llvm-project//llvm:Support", "@llvm-project//llvm:ir_headers", "@llvm-project//mlir:IR", + "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:statusor", ], ) diff --git a/third_party/xla/xla/service/gpu/runtime/sequential_thunk.cc b/third_party/xla/xla/service/gpu/runtime/sequential_thunk.cc index c58f31b5207a7a..a0dc62fbc7155d 100644 --- a/third_party/xla/xla/service/gpu/runtime/sequential_thunk.cc +++ b/third_party/xla/xla/service/gpu/runtime/sequential_thunk.cc @@ -33,7 +33,7 @@ SequentialThunk::SequentialThunk(ThunkInfo thunk_info, ThunkSequence thunks) std::string SequentialThunk::ToStringExtra(int indent) const { std::string result = "\n"; - absl::StrAppend(&result, thunks().ToString(indent + 1, nullptr)); + absl::StrAppend(&result, thunks().ToString(indent + 1)); return result; } diff --git a/third_party/xla/xla/service/gpu/runtime/thunk.cc b/third_party/xla/xla/service/gpu/runtime/thunk.cc index 0c95623db733be..877db05fe04e89 100644 --- a/third_party/xla/xla/service/gpu/runtime/thunk.cc +++ b/third_party/xla/xla/service/gpu/runtime/thunk.cc @@ -303,9 +303,7 @@ std::ostream& operator<<(std::ostream& os, Thunk::Kind kind) { return os << Thunk::KindToString(kind); } -std::string ThunkSequence::ToString( - int indent, - std::function get_thunk_annotation) const { +std::string ThunkSequence::ToString(int indent) const { const std::string indent_str(indent * 2, ' '); if (empty()) return indent_str + "No thunks."; @@ -324,9 +322,6 @@ std::string ThunkSequence::ToString( absl::StrAppend(&result, indent_str, kind_str, std::string(max_thunk_kind_len - kind_str.length(), ' '), 
"\t"); - if (get_thunk_annotation) { - absl::StrAppend(&result, get_thunk_annotation(thunk.get())); - } absl::StrAppend(&result, thunk->ToStringExtra(indent)); absl::StrAppend(&result, "\n"); } diff --git a/third_party/xla/xla/service/gpu/runtime/thunk.h b/third_party/xla/xla/service/gpu/runtime/thunk.h index 6a271ce8f623ba..78698859bf8553 100644 --- a/third_party/xla/xla/service/gpu/runtime/thunk.h +++ b/third_party/xla/xla/service/gpu/runtime/thunk.h @@ -460,9 +460,7 @@ class Thunk { // A sequence of thunks. class ThunkSequence : public std::vector> { public: - std::string ToString(int indent = 0, - std::function - get_thunk_annotation = nullptr) const; + std::string ToString(int indent = 0) const; }; std::ostream& operator<<(std::ostream& os, Thunk::Kind kind); diff --git a/third_party/xla/xla/service/gpu/triton_support_test.cc b/third_party/xla/xla/service/gpu/triton_support_test.cc index 6c690bad4c47a9..3cf75a4cbad7e9 100644 --- a/third_party/xla/xla/service/gpu/triton_support_test.cc +++ b/third_party/xla/xla/service/gpu/triton_support_test.cc @@ -72,16 +72,9 @@ using BitcastOrReshapeTest = TritonSupportTestWithParam; TEST_P(BitcastOrReshapeTest, IsTritonSupportedBitcastOrReshape) { auto [data_type, opcode] = GetParam(); const std::string kHloTestTemplate = R"( -triton_computation { +ENTRY triton_computation { parameter_0 = $0[1,16,4]{2,1,0} parameter(0) ROOT bitcast_or_reshape = $0[64]{0} $1(parameter_0) -} - -ENTRY e { - parameter_0 = $0[1,16,4]{2,1,0} parameter(0) - ROOT root_op = $0[64]{0} fusion(parameter_0), - kind=kCustom, calls=triton_computation, - backend_config={"fusion_backend_config":{"kind":"__triton"}} })"; TF_ASSERT_OK_AND_ASSIGN( TestedInstruction ti, @@ -120,17 +113,10 @@ TEST_P(UnaryElementwiseTest, IsTritonSupportedUnaryElementwise) { } const std::string kHloTestTemplate = R"( -triton_computation { +ENTRY triton_computation { parameter_0 = $0[33,68]{1,0} parameter(0) unary = $0[33,68]{1,0} $1(parameter_0) ROOT convert = f32[33,68]{1,0} convert(unary) -} - -ENTRY e { - parameter_0 = $0[33,68]{1,0} parameter(0) - ROOT root_op = f32[33,68]{1,0} fusion(parameter_0), - kind=kCustom, calls=triton_computation, - backend_config={"fusion_backend_config":{"kind":"__triton"}} })"; TF_ASSERT_OK_AND_ASSIGN( TestedInstruction ti, @@ -184,18 +170,10 @@ TEST_P(BinaryElementwiseTest, IsTritonSupportedBinaryElementwise) { } const std::string kHloTestTemplate = R"( -triton_computation { +ENTRY triton_computation { parameter_0 = $0[11,63]{1,0} parameter(0) parameter_1 = $0[11,63]{1,0} parameter(1) ROOT binary = $0[11,63]{1,0} $1(parameter_0, parameter_1) -} - -ENTRY e { - parameter_0 = $0[11,63]{1,0} parameter(0) - parameter_1 = $0[11,63]{1,0} parameter(1) - ROOT triton_op = $0[11,63]{1,0} fusion(parameter_0, parameter_1), - kind=kCustom, calls=triton_computation, - backend_config={"fusion_backend_config":{"kind":"__triton"}} })"; TF_ASSERT_OK_AND_ASSIGN( TestedInstruction ti, @@ -251,19 +229,11 @@ TEST_P(CompareTest, IsTritonSupportedCompare) { } const std::string kHloTestTemplate = R"( -triton_computation { +ENTRY triton_computation { parameter_0 = $0[11,63]{1,0} parameter(0) parameter_1 = $0[11,63]{1,0} parameter(1) compare = pred[11,63]{1,0} $1(parameter_0, parameter_1), direction=GE ROOT convert = f32[11,63]{1,0} convert(compare) -} - -ENTRY e { - parameter_0 = $0[11,63]{1,0} parameter(0) - parameter_1 = $0[11,63]{1,0} parameter(1) - ROOT triton_op = f32[11,63]{1,0} fusion(parameter_0, parameter_1), - kind=kCustom, calls=triton_computation, - 
backend_config={"fusion_backend_config":{"kind":"__triton"}} })"; TF_ASSERT_OK_AND_ASSIGN( TestedInstruction ti, @@ -298,21 +268,12 @@ TEST_P(TernaryElementwiseTest, IsTritonSupportedTernaryElementwise) { } const std::string kHloTestTemplate = R"( -triton_computation { +ENTRY triton_computation { parameter_0 = $0[13,63]{1,0} parameter(0) parameter_1 = $0[13,63]{1,0} parameter(1) parameter_2 = pred[13,63]{1,0} parameter(2) ternary = $0[13,63]{1,0} $1(parameter_2, parameter_0, parameter_1) ROOT convert = f32[13,63]{1,0} convert(ternary) -} - -ENTRY e { - parameter_0 = $0[13,63]{1,0} parameter(0) - parameter_1 = $0[13,63]{1,0} parameter(1) - parameter_2 = pred[13,63]{1,0} parameter(2) - ROOT triton_op = f32[13,63]{1,0} fusion(parameter_0, parameter_1, parameter_2), - kind=kCustom, calls=triton_computation, - backend_config={"fusion_backend_config":{"kind":"__triton"}} })"; TF_ASSERT_OK_AND_ASSIGN( TestedInstruction ti, @@ -353,18 +314,10 @@ add { ROOT add = $0[] add(Arg_0, Arg_1) } -triton_computation { +ENTRY triton_computation { parameter_0 = $0[125,127]{1,0} parameter(0) constant_0 = $0[] constant(0) ROOT reduce = $0[125]{0} $1(parameter_0, constant_0), dimensions={1}, to_apply=add -} - -ENTRY main { - parameter_0 = $0[125,127]{1,0} parameter(0) - ROOT triton_op = $0[125]{0} fusion(parameter_0), - kind=kCustom, calls=triton_computation, - backend_config={"fusion_backend_config": - {"kind":"__triton"}} })"; TF_ASSERT_OK_AND_ASSIGN( TestedInstruction ti, @@ -406,19 +359,11 @@ add { ROOT add = f32[] add(Arg_0, Arg_1) } -triton_computation { +ENTRY triton_computation { parameter_0 = f32[125,127]{1,0} parameter(0) constant_0 = bf16[] constant(0) convert_0 = f32[] convert(constant_0) ROOT reduce = f32[125]{0} reduce(parameter_0, convert_0), dimensions={1}, to_apply=add -} - -ENTRY main { - parameter_0 = f32[125,127]{1,0} parameter(0) - ROOT triton_op = f32[125]{0} fusion(parameter_0), kind=kCustom, - calls=triton_computation, - backend_config={"fusion_backend_config": - {"kind":"__triton"}} })"; TF_ASSERT_OK_AND_ASSIGN(TestedInstruction ti, ParseTemplateAndGetInstruction( @@ -442,18 +387,10 @@ add { ROOT add = f32[] add(Arg_0, Arg_1) } -triton_computation { +ENTRY triton_computation { parameter_0 = f32[2,125,127]{2,1,0} parameter(0) constant_0 = f32[] constant(0) ROOT reduce = f32[2]{0} reduce(parameter_0, constant_0), dimensions={1,2}, to_apply=add -} - -ENTRY main { - parameter_0 = f32[2,125,127]{2,1,0} parameter(0) - ROOT triton_op = f32[2]{0} fusion(parameter_0), - kind=kCustom, calls=triton_computation, - backend_config={"fusion_backend_config": - {"kind":"__triton"}} })"; TF_ASSERT_OK_AND_ASSIGN(TestedInstruction ti, ParseTemplateAndGetInstruction( @@ -479,18 +416,10 @@ add { ROOT add = f32[] add(Arg_0, Arg_1) } -triton_computation { +ENTRY triton_computation { parameter_0 = f32[125,127]{1,0} parameter(0) constant_0 = f32[] constant(0) ROOT reduce = f32[127]{0} reduce(parameter_0, constant_0), dimensions={0}, to_apply=add -} - -ENTRY main { - parameter_0 = f32[125,127]{1,0} parameter(0) - ROOT triton_op = f32[127]{0} fusion(parameter_0), - kind=kCustom, calls=triton_computation, - backend_config={"fusion_backend_config": - {"kind":"__triton"}} })"; TF_ASSERT_OK_AND_ASSIGN(TestedInstruction ti, ParseTemplateAndGetInstruction( @@ -520,19 +449,11 @@ add { ROOT pair = (f32[], f32[]) tuple(add_0, add_1) } -triton_computation { +ENTRY triton_computation { parameter_0 = f32[125,127] parameter(0) constant_0 = f32[] constant(0) tuple_0 = (f32[125]{0}, f32[125]{0}) reduce(parameter_0, parameter_0, 
constant_0, constant_0), dimensions={1}, to_apply=add ROOT reduce = f32[125]{0} get-tuple-element(tuple_0), index=0 -} - -ENTRY main { - parameter_0 = f32[125,127]{1,0} parameter(0) - ROOT triton_op = f32[125]{0} fusion(parameter_0), - kind=kCustom, calls=triton_computation, - backend_config={"fusion_backend_config": - {"kind":"__triton"}} })"; TF_ASSERT_OK_AND_ASSIGN(TestedInstruction ti, ParseTemplateAndGetInstruction( @@ -557,19 +478,10 @@ add { ROOT add = f32[] add(Arg_0, Arg_1) } -triton_computation { +ENTRY triton_computation { parameter_0 = f32[125,127]{1,0} parameter(0) init = f32[] parameter(1) ROOT reduce = f32[125]{0} reduce(parameter_0, init), dimensions={1}, to_apply=add -} - -ENTRY main { - parameter_0 = f32[125,127]{1,0} parameter(0) - parameter_1 = f32[] parameter(1) - ROOT triton_op = f32[125]{0} fusion(parameter_0, parameter_1), - kind=kCustom, calls=triton_computation, - backend_config={"fusion_backend_config": - {"kind":"__triton"}} })"; TF_ASSERT_OK_AND_ASSIGN(TestedInstruction ti, ParseTemplateAndGetInstruction( @@ -599,18 +511,10 @@ custom_call { ROOT custom_call = f32[] custom-call(Arg_0, Arg_1), custom_call_target="foo" } -triton_computation { +ENTRY triton_computation { parameter_0 = f32[125,127]{1,0} parameter(0) constant_0 = f32[] constant(0) ROOT reduce = f32[125]{0} reduce(parameter_0, constant_0), dimensions={1}, to_apply=custom_call -} - -ENTRY main { - parameter_0 = f32[125,127]{1,0} parameter(0) - ROOT triton_op = f32[125]{0} fusion(parameter_0), - kind=kCustom, calls=triton_computation, - backend_config={"fusion_backend_config": - {"kind":"__triton"}} })"; TF_ASSERT_OK_AND_ASSIGN(TestedInstruction ti, ParseTemplateAndGetInstruction( diff --git a/third_party/xla/xla/service/gpu/triton_test_utils.cc b/third_party/xla/xla/service/gpu/triton_test_utils.cc index bb5bbe765f858a..97e5925d8a42f2 100644 --- a/third_party/xla/xla/service/gpu/triton_test_utils.cc +++ b/third_party/xla/xla/service/gpu/triton_test_utils.cc @@ -20,8 +20,10 @@ limitations under the License. #include #include #include +#include #include +#include "absl/log/check.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" @@ -41,6 +43,7 @@ limitations under the License. #include "xla/service/float_normalization.h" #include "xla/service/gpu/gpu_device_info_for_tests.h" #include "xla/service/gpu/gpu_float_support.h" +#include "xla/service/gpu/ir_emission_utils.h" #include "xla/service/gpu/ir_emitter_triton.h" #include "xla/service/gpu/model/tiled_hlo_computation.h" #include "xla/service/hlo_pass_pipeline.h" @@ -48,6 +51,7 @@ limitations under the License. #include "xla/stream_executor/device_description.h" #include "xla/tests/filecheck.h" #include "xla/tests/verified_hlo_module.h" +#include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" namespace xla::gpu { @@ -140,6 +144,46 @@ std::string TritonSupportTestParamsToString( absl::StrReplaceAll(HloOpcodeString(opcode), {{"-", "_"}})); } +namespace { + +// This function does nothing if the input module already has an entry +// computation whose root is a fusion. Otherwise, creates a new entry +// computation whose root is a fusion instruction that calls the original entry +// computation. The new fusion instruction uses the generic Triton backend kind. 
+absl::Status ConvertEntryToTritonFusion(HloModule* module) { + if (module->entry_computation()->root_instruction()->opcode() == + HloOpcode::kFusion) { + return absl::OkStatus(); + } + auto builder = HloComputation::Builder("entry"); + std::vector params; + for (auto& param : module->entry_computation()->parameter_instructions()) { + TF_ASSIGN_OR_RETURN( + auto param_clone, + builder.AddParameter(HloInstruction::CreateParameter( + param->parameter_number(), param->shape(), + absl::StrCat("param_", param->parameter_number())))); + params.push_back(param_clone); + } + + auto fusion = builder.AddInstruction(HloInstruction::CreateFusion( + module->entry_computation()->root_instruction()->shape(), + HloInstruction::FusionKind::kCustom, params, + module->entry_computation())); + + gpu::GpuBackendConfig gpu_config; + gpu_config.mutable_fusion_backend_config()->set_kind(kTritonFusionKind); + TF_RETURN_IF_ERROR(fusion->set_backend_config(gpu_config)); + + auto new_entry = + module->AddComputationAndUnifyNamesAndIds(builder.Build(), + /*is_entry=*/false); + module->ReplaceEntryComputation(new_entry); + return absl::OkStatus(); +} + +} // namespace + absl::StatusOr TritonSupportTest::ParseTemplateAndGetInstruction( absl::string_view hlo_template, xla::PrimitiveType data_type, @@ -149,8 +193,14 @@ TritonSupportTest::ParseTemplateAndGetInstruction( HloOpcodeString(opcode)); TF_ASSIGN_OR_RETURN(std::unique_ptr module, ParseAndReturnVerifiedModule(hlo_text)); + TF_RETURN_IF_ERROR(ConvertEntryToTritonFusion(module.get())); const HloComputation* computation = module->GetComputationWithName("triton_computation"); + if (computation == module->entry_computation()) { + return absl::InvalidArgumentError( + "The `triton_computation` and the module's entry computation cannot be " + "the same."); + } const HloFusionInstruction* fusion = DynCast( module->entry_computation()->root_instruction()); if (fusion == nullptr) { diff --git a/third_party/xla/xla/service/gpu/triton_test_utils.h b/third_party/xla/xla/service/gpu/triton_test_utils.h index de503bf8bd36f8..c7cb16abdf36ec 100644 --- a/third_party/xla/xla/service/gpu/triton_test_utils.h +++ b/third_party/xla/xla/service/gpu/triton_test_utils.h @@ -129,6 +129,13 @@ class TritonSupportTest : public TritonFilecheckTest { // The provided template must contain a computation called // `triton_computation`. If the template contains parameters $0 and $1, they // will be replaced with the data type and opcode respectively. + // If the template's entry computation does not have a root fusion + // instruction, a new entry computation will be created. The new computation + // will have a root fusion instruction that has the same parameters as the + // `triton_computation` and contains a fusion instruction that calls the + // `triton_computation` with the generic Triton emitter. Tests that need + // the `__triton_gemm` backend kind should provide their own ENTRY + // computation. absl::StatusOr ParseTemplateAndGetInstruction( absl::string_view hlo_template, xla::PrimitiveType data_type, xla::HloOpcode opcode); diff --git a/third_party/xla/xla/service/llvm_compiler.cc b/third_party/xla/xla/service/llvm_compiler.cc index 4bbcca0b484cd2..02afa91c5d9ea5 100644 --- a/third_party/xla/xla/service/llvm_compiler.cc +++ b/third_party/xla/xla/service/llvm_compiler.cc @@ -15,6 +15,11 @@ limitations under the License. 
#include "xla/service/llvm_compiler.h" +#include +#include +#include + +#include "absl/status/statusor.h" #include "tsl/platform/denormal.h" #include "tsl/profiler/lib/scoped_annotation.h" @@ -55,6 +60,6 @@ absl::StatusOr>> LLVMCompiler::Compile( result.push_back(std::move(executable)); } - return {std::move(result)}; + return std::move(result); } } // namespace xla diff --git a/third_party/xla/xla/service/service.cc b/third_party/xla/xla/service/service.cc index 207eae0ab440cc..4a1f323850d9e7 100644 --- a/third_party/xla/xla/service/service.cc +++ b/third_party/xla/xla/service/service.cc @@ -23,6 +23,7 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "xla/debug_options_flags.h" @@ -33,6 +34,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_module_group.h" #include "xla/layout_util.h" +#include "xla/literal.h" #include "xla/service/backend.h" #include "xla/service/compiler.h" #include "xla/service/computation_layout.h" @@ -160,36 +162,26 @@ Service::Service(const ServiceOptions& options, } } -absl::Status Service::CreateChannelHandle(const CreateChannelHandleRequest* arg, - CreateChannelHandleResponse* result) { - TF_ASSIGN_OR_RETURN(*result->mutable_channel(), - channel_tracker_.NewChannel(arg->channel_type())); - return absl::OkStatus(); +absl::StatusOr Service::CreateChannelHandle( + ChannelHandle::ChannelType type) { + return channel_tracker_.NewChannel(type); } -absl::Status Service::Unregister(const UnregisterRequest* arg, - UnregisterResponse* result) { - absl::Status status; - for (auto& data : arg->data()) { - absl::Status unregister_status = allocation_tracker_.Unregister(data); - if (!unregister_status.ok() && status.ok()) { - status = unregister_status; - } - } - return status; +absl::Status Service::Unregister(const GlobalDataHandle& data) { + return allocation_tracker_.Unregister(data); } // Deconstructs a previously-allocated global handle. 
-absl::Status Service::DeconstructTuple(const DeconstructTupleRequest* arg, - DeconstructTupleResponse* result) { - TF_ASSIGN_OR_RETURN( - std::vector elements, - allocation_tracker_.DeconstructTuple(arg->tuple_handle())); - - for (auto& element : elements) { - *result->add_element_handles() = element; +absl::StatusOr>> +Service::DeconstructTuple(const GlobalData& data) { + TF_ASSIGN_OR_RETURN(std::vector elements, + allocation_tracker_.DeconstructTuple(data.handle())); + std::vector> out; + out.reserve(elements.size()); + for (GlobalDataHandle& element : elements) { + out.push_back(std::make_unique(this, element)); } - return absl::OkStatus(); + return out; } absl::Status Service::ValidateResultShape(const Shape& client_shape, @@ -207,20 +199,14 @@ absl::Status Service::ValidateResultShape(const Shape& client_shape, absl::StatusOr>> Service::ResolveAndValidateArguments( - absl::Span arguments, + absl::Span arguments, absl::Span stream_executors) const { CHECK_EQ(options_.number_of_replicas(), stream_executors.size()); std::vector> replicated_arguments; replicated_arguments.resize(options_.number_of_replicas()); for (size_t i = 0; i < arguments.size(); ++i) { - auto buffer_status = allocation_tracker_.Resolve(*arguments[i]); - if (!buffer_status.ok()) { - return tsl::errors::CreateWithUpdatedMessage( - buffer_status.status(), - StrCat(buffer_status.status().message(), ", ", - "failed to resolve allocation for parameter ", i)); - } - auto replicated_buffers = buffer_status.value(); + TF_ASSIGN_OR_RETURN(std::vector replicated_buffers, + allocation_tracker_.Resolve(arguments[i]->handle())); CHECK_EQ(options_.number_of_replicas(), replicated_buffers.size()); for (int replica = 0; replica < options_.number_of_replicas(); ++replica) { const ShapedBuffer* shaped_buffer = replicated_buffers[replica]; @@ -515,9 +501,8 @@ absl::StatusOr> Service::GetExecutors( } absl::StatusOr>> -Service::GetArguments( - const ExecutionOptions& execution_options, - absl::Span arguments) const { +Service::GetArguments(const ExecutionOptions& execution_options, + absl::Span arguments) const { // Resolve the allocations for the arguments of the computation, and create // a vector of device memory offsets for the arguments from the allocations. 
// In the case of partitioned computations, assume all arguments go on the @@ -531,8 +516,9 @@ Service::GetArguments( return replicated_arguments; } -absl::Status Service::ExecuteGraphParallel( - const ExecuteGraphParallelRequest* arg, ExecuteParallelResponse* result) { +absl::StatusOr>> +Service::ExecuteGraphParallel( + absl::Span computations) { VLOG(1) << "running execute-graph-parallel request"; std::vector>> all_arguments; @@ -543,10 +529,11 @@ absl::Status Service::ExecuteGraphParallel( std::vector device_handles; int num_requested_devices = - std::accumulate(arg->requests().begin(), arg->requests().end(), 0, - [](int a, const ExecuteGraphRequest& r) -> int { - return a + r.execution_options().device_handles_size(); + std::accumulate(computations.begin(), computations.end(), 0, + [](int a, const XlaComputationInstance& r) -> int { + return a + r.execution_options.device_handles_size(); }); + if (num_requested_devices * options_.number_of_replicas() > execute_backend_->device_count()) { return FailedPrecondition( @@ -554,23 +541,24 @@ absl::Status Service::ExecuteGraphParallel( num_requested_devices); } - for (int64_t i = 0; i < arg->requests_size(); ++i) { + for (int64_t i = 0; i < computations.size(); ++i) { + const XlaComputationInstance& computation = computations[i]; + // Get the stream executor for the i'th computation. This stream executor // is one of the executors to run the replicated computation. - const ExecutionOptions& execution_options = - arg->requests(i).execution_options(); - const ExecuteGraphRequest& request = arg->requests(i); - TF_RET_CHECK(request.has_computation()) << "computations may not be empty"; - TF_RET_CHECK(request.computation().has_host_program_shape()) + const ExecutionOptions& execution_options = computation.execution_options; + TF_RET_CHECK(computation.computation.proto().has_host_program_shape()) << "program shape may not be empty"; // Get the executors. - TF_ASSIGN_OR_RETURN(auto executors, GetExecutors(execution_options, - arg->requests_size(), i)); + TF_ASSIGN_OR_RETURN( + std::vector executors, + GetExecutors(execution_options, computations.size(), i)); // Get the replicated arguments. - TF_ASSIGN_OR_RETURN(auto replicated_arguments, - GetArguments(execution_options, request.arguments())); + TF_ASSIGN_OR_RETURN( + std::vector> replicated_arguments, + GetArguments(execution_options, computation.arguments)); for (auto& args : replicated_arguments) { for (auto& arg : args) { @@ -596,8 +584,8 @@ absl::Status Service::ExecuteGraphParallel( TF_ASSIGN_OR_RETURN( std::unique_ptr module_config, CreateModuleConfig( - ProgramShape{request.computation().host_program_shape()}, - replicated_arguments.front(), request.execution_options())); + ProgramShape{computation.computation.proto().host_program_shape()}, + replicated_arguments.front(), computation.execution_options)); VLOG(3) << "ExecuteGraphParallel created HloModuleConfig computation layout: " << module_config->entry_computation_layout().ToString(); @@ -605,10 +593,10 @@ absl::Status Service::ExecuteGraphParallel( // Adds to the vectors to build and execute the computations after the loop. 
     all_arguments.push_back(replicated_arguments);
     all_arguments.insert(all_arguments.end(), executors.size() - 1, {{}});
-    module_protos.push_back(&request.computation());
+    module_protos.push_back(&computation.computation.proto());
     module_configs.push_back(std::move(module_config));
     computation_names.insert(computation_names.end(), executors.size(),
-                             request.computation().name());
+                             computation.computation.name());
     all_executors.push_back(executors);
     device_handles.insert(device_handles.end(),
                           execution_options.device_handles().begin(),
@@ -680,6 +668,12 @@ absl::Status Service::ExecuteGraphParallel(
     }
   }
 
+  for (int64_t i = 0; i < computations.size(); ++i) {
+    if (computations[i].execution_profile != nullptr) {
+      *computations[i].execution_profile = profile;
+    }
+  }
+
   if (!execution_status.ok()) {
     // Execution failed so we don't have the results. Dump the HLO snapshot
     // with just the program arguments.
@@ -690,11 +684,11 @@ absl::Status Service::ExecuteGraphParallel(
 
   TF_RETURN_IF_ERROR(execution_status);
 
-  for (const GlobalDataHandle& output : outputs) {
-    ExecuteResponse response;
-    *response.mutable_output() = output;
-    *response.mutable_profile() = profile;
-    *result->add_responses() = response;
+  std::vector<std::unique_ptr<GlobalData>> out;
+
+  out.reserve(outputs.size());
+  for (GlobalDataHandle& output : outputs) {
+    out.push_back(std::make_unique<GlobalData>(this, output));
   }
 
   for (int i = 0, end = executable_ptrs.size(); i < end; i++) {
@@ -712,31 +706,32 @@ absl::Status Service::ExecuteGraphParallel(
   }
 
   VLOG(1) << "successfully completed 'execute-graph-parallel' request";
-  return absl::OkStatus();
+  return out;
 }
 
-absl::Status Service::GetDeviceHandles(const GetDeviceHandlesRequest* arg,
-                                       GetDeviceHandlesResponse* result) {
+absl::StatusOr<std::vector<DeviceHandle>> Service::GetDeviceHandles(
+    int64_t device_count) {
   const int64_t available_device_count = execute_backend_->device_count();
   const int64_t replica_count = options_.number_of_replicas();
   if (replica_count <= 0) {
     return FailedPrecondition("Replica count must be a positive integer");
   }
-  if (available_device_count < arg->device_count() * replica_count) {
+  if (available_device_count < device_count * replica_count) {
     return ResourceExhausted(
         "Requested logical device count (%d) with replica count (%d) exceeds "
         "the number of available physical devices on the target (%d)",
-        arg->device_count(), replica_count, available_device_count);
+        device_count, replica_count, available_device_count);
   }
 
-  for (int64_t i = 0; i < arg->device_count(); ++i) {
+  std::vector<DeviceHandle> out;
+
+  for (int64_t i = 0; i < device_count; ++i) {
     DeviceHandle device_handle;
     device_handle.set_handle(i);
-    device_handle.set_device_count(arg->device_count());
-    *result->add_device_handles() = device_handle;
+    device_handle.set_device_count(device_count);
+    out.push_back(device_handle);
   }
-
-  return absl::OkStatus();
+  return out;
 }
 
 absl::StatusOr<std::unique_ptr<Executable>> Service::BuildExecutable(
@@ -795,71 +790,66 @@ absl::StatusOr<std::unique_ptr<Executable>> Service::BuildExecutable(
   return executable;
 }
 
-absl::Status Service::Compile(const CompileRequest* arg,
-                              CompileResponse* result) {
+absl::StatusOr<ExecutionHandle> Service::Compile(
+    const XlaComputation& computation, absl::Span<const Shape> argument_shapes,
+    const ExecutionOptions& execution_options) {
   VLOG(1) << "running compile request";
-  if (!arg->has_computation()) {
-    return InvalidArgument("computations may not be empty");
-  }
-  if (!arg->computation().has_host_program_shape()) {
+
+  if (!computation.proto().has_host_program_shape()) {
    return InvalidArgument("program shape may not be empty");
  }
-  if 
(arg->execution_options().device_handles_size() > 1) { + if (execution_options.device_handles_size() > 1) { return InvalidArgument( "The compile request does not support multiple device handles."); } - std::vector argument_shapes; - argument_shapes.reserve(arg->input_shape_with_layout_size()); std::vector argument_shape_ptrs; - for (const ShapeProto& shape_proto : arg->input_shape_with_layout()) { - argument_shapes.push_back(Shape(shape_proto)); - argument_shape_ptrs.push_back(&argument_shapes.back()); + for (const Shape& shape : argument_shapes) { + argument_shape_ptrs.push_back(&shape); } + TF_ASSIGN_OR_RETURN( std::unique_ptr module_config, - CreateModuleConfig(ProgramShape{arg->computation().host_program_shape()}, - argument_shape_ptrs, &arg->execution_options())); + CreateModuleConfig(ProgramShape{computation.proto().host_program_shape()}, + argument_shape_ptrs, &execution_options)); VLOG(3) << "Compile created HloModuleConfig computation layout: " << module_config->entry_computation_layout().ToString(); TF_ASSIGN_OR_RETURN( std::unique_ptr executable, - BuildExecutable(arg->computation(), std::move(module_config), + BuildExecutable(computation.proto(), std::move(module_config), execute_backend_.get(), execute_backend_->default_stream_executor(), {/*device_allocator=*/nullptr})); - *result->mutable_handle() = compilation_cache_.Insert(std::move(executable)); - VLOG(1) << "successfully completed 'compile' request"; - return absl::OkStatus(); + return compilation_cache_.Insert(std::move(executable)); } -absl::Status Service::Execute(const ExecuteRequest* arg, - ExecuteResponse* result) { +absl::StatusOr> Service::Execute( + const ExecutionHandle& handle, absl::Span arguments, + ExecutionProfile* execution_profile) { VLOG(1) << "running execute request"; - if (!arg->has_handle()) { - return InvalidArgument("execution handle should not be empty"); - } - TF_ASSIGN_OR_RETURN(auto executable, - compilation_cache_.LookUp(arg->handle())); - TF_ASSIGN_OR_RETURN(auto replicas, Replicas(*execute_backend_, - SingleComputationDeviceHandle())); + TF_ASSIGN_OR_RETURN(std::shared_ptr executable, + compilation_cache_.LookUp(handle)); + + TF_ASSIGN_OR_RETURN( + std::vector replicas, + Replicas(*execute_backend_, SingleComputationDeviceHandle())); TF_ASSIGN_OR_RETURN( std::vector> replicated_arguments, - ResolveAndValidateArguments(arg->arguments(), replicas)); + ResolveAndValidateArguments(arguments, replicas)); // Check that the replicated_arguments has the same shape and layout as the // module config used when creating the executable. 
const int64_t num_module_args = executable->module_config().entry_computation_layout().parameter_count(); - if (num_module_args != arg->arguments_size()) { + if (num_module_args != arguments.size()) { return InvalidArgument( "The executable expects %lld arguments, but sees %lld.", - num_module_args, arg->arguments_size()); + num_module_args, arguments.size()); } for (int64_t i = 0; i < num_module_args; i++) { const Shape& shape_module = @@ -887,35 +877,32 @@ absl::Status Service::Execute(const ExecuteRequest* arg, } TF_ASSIGN_OR_RETURN( - *result->mutable_output(), - ExecuteAndRegisterResult(executable.get(), replicated_arguments, - execute_backend_.get(), - SingleComputationDeviceHandle(), - "result of " + executable->module().name(), - result->mutable_profile())); + GlobalDataHandle output, + ExecuteAndRegisterResult( + executable.get(), replicated_arguments, execute_backend_.get(), + SingleComputationDeviceHandle(), + "result of " + executable->module().name(), execution_profile)); if (executable->dumping_snapshot()) { - TF_ASSIGN_OR_RETURN( - const ShapedBuffer* result_buffer, - allocation_tracker_.ResolveForReplica(result->output(), 0)); + TF_ASSIGN_OR_RETURN(const ShapedBuffer* result_buffer, + allocation_tracker_.ResolveForReplica(output, 0)); TF_RETURN_IF_ERROR(RecordResult(*result_buffer, stream.get(), execute_backend_->transfer_manager(), &snapshot)); DumpHloSnapshotIfEnabled(executable->module(), snapshot); } - VLOG(1) << "successfully completed 'execute' request"; - return absl::OkStatus(); + return std::make_unique(this, output); } -absl::Status Service::TransferToClient(const TransferToClientRequest* arg, - TransferToClientResponse* result) { +absl::StatusOr Service::TransferToClient( + const GlobalData& data, const Shape* shape_with_layout) { TF_ASSIGN_OR_RETURN(const ShapedBuffer* shaped_buffer, - allocation_tracker_.ResolveForReplica(arg->data(), 0)); + allocation_tracker_.ResolveForReplica(data.handle(), 0)); Shape return_shape; - if (arg->has_shape_with_layout()) { - return_shape = Shape(arg->shape_with_layout()); + if (shape_with_layout) { + return_shape = Shape(*shape_with_layout); if (!LayoutUtil::HasLayout(return_shape)) { return InvalidArgument("shape_with_layout must have layout if present."); } @@ -943,24 +930,17 @@ absl::Status Service::TransferToClient(const TransferToClientRequest* arg, stream.get(), *shaped_buffer)); if (LayoutUtil::LayoutsInShapesEqual(return_shape, result_literal.shape())) { - *result->mutable_literal() = result_literal.ToProto(); - } else { - *result->mutable_literal() = - result_literal.Relayout(return_shape).ToProto(); + return result_literal; } - return absl::OkStatus(); + return result_literal.Relayout(return_shape); } -absl::Status Service::TransferToServer(const TransferToServerRequest* arg, - TransferToServerResponse* result) { - TF_ASSIGN_OR_RETURN(Literal literal, - Literal::CreateFromProto(arg->literal())); - const Shape& shape = literal.shape(); - +absl::StatusOr> Service::TransferToServer( + const LiteralSlice& literal_slice, const DeviceHandle* device_handle) { + const Shape& shape = literal_slice.shape(); std::vector replicas; - if (arg->has_device_handle()) { - TF_ASSIGN_OR_RETURN(replicas, - Replicas(*execute_backend_, arg->device_handle())); + if (device_handle) { + TF_ASSIGN_OR_RETURN(replicas, Replicas(*execute_backend_, *device_handle)); } else { TF_ASSIGN_OR_RETURN( replicas, Replicas(*execute_backend_, SingleComputationDeviceHandle())); @@ -982,100 +962,93 @@ absl::Status Service::TransferToServer(const 
TransferToServerRequest* arg, TF_ASSIGN_OR_RETURN(auto stream, execute_backend_->BorrowStream(executor)); TF_RETURN_IF_ERROR( execute_backend_->transfer_manager()->TransferLiteralToDevice( - stream.get(), literal, shaped_buffer)); + stream.get(), literal_slice, shaped_buffer)); replicated_buffers.emplace_back(std::move(shaped_buffer)); } - TF_ASSIGN_OR_RETURN(*result->mutable_data(), + + TF_ASSIGN_OR_RETURN(GlobalDataHandle out, allocation_tracker_.RegisterReplicatedBuffers( std::move(replicated_buffers), StrCat("TransferToServer literal of shape ", ShapeUtil::HumanString(shape)))); - return absl::OkStatus(); + return std::make_unique(this, out); } -absl::Status Service::TransferToInfeed(const TransferToInfeedRequest* arg, - TransferToInfeedResponse* result) { +absl::Status Service::TransferToInfeed(const LiteralSlice& literal, + int64_t replica_id, + const DeviceHandle* device_handle) { const int64_t replica_count = options_.number_of_replicas(); - if (arg->replica_id() < 0 || arg->replica_id() >= replica_count) { + if (replica_id < 0 || replica_id >= replica_count) { return FailedPrecondition( "%s", - StrCat("The replica_id=", arg->replica_id(), + StrCat("The replica_id=", replica_id, " on TransferToInfeedRequest not in range [0, replica_count=", replica_count, ").")); } se::StreamExecutor* executor; - if (arg->has_device_handle()) { + if (device_handle) { TF_ASSIGN_OR_RETURN(auto replicas, - Replicas(*execute_backend_, arg->device_handle())); - executor = replicas[arg->replica_id()]; + Replicas(*execute_backend_, *device_handle)); + executor = replicas[replica_id]; } else { TF_ASSIGN_OR_RETURN( auto replicas, Replicas(*execute_backend_, SingleComputationDeviceHandle())); - executor = replicas[arg->replica_id()]; + executor = replicas[replica_id]; } - TF_ASSIGN_OR_RETURN(Literal literal, - Literal::CreateFromProto(arg->literal())); return execute_backend_->transfer_manager()->TransferLiteralToInfeed(executor, literal); } -absl::Status Service::TransferFromOutfeed(const TransferFromOutfeedRequest* arg, - TransferFromOutfeedResponse* result) { +absl::StatusOr Service::TransferFromOutfeed( + const Shape* shape_with_layout, int64_t replica_id, + const DeviceHandle* device_handle) { const int64_t replica_count = options_.number_of_replicas(); - if (arg->replica_id() < 0 || arg->replica_id() >= replica_count) { + if (replica_id < 0 || replica_id >= replica_count) { return FailedPrecondition( "The replica_id=%d on TransferFromOutfeedRequest not in range [0, %d)", - arg->replica_id(), replica_count); + replica_id, replica_count); } se::StreamExecutor* executor; - if (arg->has_device_handle()) { + if (device_handle) { TF_ASSIGN_OR_RETURN(auto replicas, - Replicas(*execute_backend_, arg->device_handle())); - executor = replicas[arg->replica_id()]; + Replicas(*execute_backend_, *device_handle)); + executor = replicas[replica_id]; } else { TF_ASSIGN_OR_RETURN( auto replicas, Replicas(*execute_backend_, SingleComputationDeviceHandle())); - executor = replicas[arg->replica_id()]; + executor = replicas[replica_id]; } - auto literal = Literal::CreateFromShape(Shape(arg->shape_with_layout())); + auto literal = Literal::CreateFromShape(*shape_with_layout); TF_RETURN_IF_ERROR( execute_backend_->transfer_manager()->TransferLiteralFromOutfeed( executor, &literal)); - *result->mutable_literal() = literal.ToProto(); - return absl::OkStatus(); + return literal; } -absl::Status Service::ResetDevice(const ResetDeviceRequest* arg, - ResetDeviceResponse* result) { - return execute_backend_->ResetDevices(); -} 
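For readers following the API change, a minimal sketch of the round trip these new signatures enable is included here; it is illustration only, not part of the patch. The elided template arguments are inferred from the surrounding declarations (Compile taken to return absl::StatusOr<ExecutionHandle>, Execute and TransferToServer absl::StatusOr<std::unique_ptr<GlobalData>>, TransferToClient absl::StatusOr<Literal>), and the helper RunOnce and its inputs are hypothetical.

#include "absl/status/status.h"
#include "xla/client/xla_computation.h"
#include "xla/execution_options_util.h"
#include "xla/literal.h"
#include "xla/service/service.h"
#include "tsl/platform/statusor.h"

// Hypothetical caller; template arguments and parameter types are inferred
// from the new declarations and may differ in detail.
absl::Status RunOnce(xla::Service* service,
                     const xla::XlaComputation& computation,
                     const xla::Literal& input) {
  // Stage the input on the service-chosen device.
  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<xla::GlobalData> arg,
      service->TransferToServer(input, /*device_handle=*/nullptr));
  // Compile once, then execute via the returned handle.
  xla::ExecutionOptions options = xla::CreateDefaultExecutionOptions();
  TF_ASSIGN_OR_RETURN(xla::ExecutionHandle handle,
                      service->Compile(computation, {input.shape()}, options));
  TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::GlobalData> output,
                      service->Execute(handle, {arg.get()},
                                       /*execution_profile=*/nullptr));
  // Pull the result back; a null shape lets the service pick the layout.
  TF_ASSIGN_OR_RETURN(xla::Literal result,
                      service->TransferToClient(*output,
                                                /*shape_with_layout=*/nullptr));
  (void)result;  // e.g. compare against an expected literal here.
  return absl::OkStatus();
}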
+absl::Status Service::ResetDevice() { return execute_backend_->ResetDevices(); } -absl::Status Service::ComputeConstantGraph( - const ComputeConstantGraphRequest* arg, ComputeConstantResponse* result) { - if (!arg->has_computation()) { - return InvalidArgument("computations may not be empty"); - } - if (!arg->computation().has_host_program_shape()) { +absl::StatusOr Service::ComputeConstantGraph( + const XlaComputation& computation, const Layout* output_layout) { + if (!computation.proto().has_host_program_shape()) { return InvalidArgument("program shape may not be empty"); } - if (arg->computation().host_program_shape().parameters_size() != 0) { + if (computation.proto().host_program_shape().parameters_size() != 0) { return InvalidArgument( "constant computation may not depend on any parameters."); } - ProgramShape program_shape(arg->computation().host_program_shape()); + ProgramShape program_shape(computation.proto().host_program_shape()); TF_DCHECK_OK(ShapeUtil::ValidateShape(program_shape.result())); - std::optional output_layout; - if (arg->has_output_layout()) { - output_layout = Layout::CreateFromProto(arg->output_layout()); + + if (output_layout) { TF_RETURN_IF_ERROR(LayoutUtil::ValidateLayoutForShape( *output_layout, program_shape.result())); } @@ -1083,7 +1056,7 @@ absl::Status Service::ComputeConstantGraph( HloModuleConfig config(program_shape); TF_ASSIGN_OR_RETURN(std::unique_ptr module, - CreateModuleFromProto(arg->computation(), config)); + CreateModuleFromProto(computation.proto(), config)); DynamicPadder dynamic_padder; TF_RETURN_IF_ERROR(dynamic_padder.Run(module.get()).status()); @@ -1112,20 +1085,16 @@ absl::Status Service::ComputeConstantGraph( // relayout here. // // TODO(b/77824332): Make HloEvaluator take care of the re-layout. - if (output_layout.has_value()) { + if (output_layout) { result_literal = result_literal.Relayout(*output_layout); } - *result->mutable_literal() = result_literal.ToProto(); - - return absl::OkStatus(); + return result_literal; } -absl::Status Service::GetShape(const GetShapeRequest* arg, - GetShapeResponse* result) { +absl::StatusOr Service::GetShape(const GlobalData& data) { TF_ASSIGN_OR_RETURN(const ShapedBuffer* buffer, - allocation_tracker_.ResolveForReplica(arg->data(), 0)); - *result->mutable_shape() = buffer->on_device_shape().ToProto(); - return absl::OkStatus(); + allocation_tracker_.ResolveForReplica(data.handle(), 0)); + return buffer->on_device_shape(); } DeviceHandle Service::SingleComputationDeviceHandle() const { @@ -1152,4 +1121,46 @@ absl::StatusOr> Service::Replicas( return replicas; } +namespace { + +// Releases a set of global data handles owned by the parent service +// interface. 
+void ReleaseHandles(Service* parent, + const absl::Span handles) { + for (const GlobalDataHandle& handle : handles) { + VLOG(1) << "Requesting to unregister " << handle.ShortDebugString(); + absl::Status status = parent->Unregister(handle); + if (!status.ok()) { + LOG(WARNING) << "Failed to unregister handles: " << status + << "; continuing anyway..."; + } + } + VLOG(1) << "Done with request"; +} + +} // namespace + +GlobalData::GlobalData(Service* parent, GlobalDataHandle handle) + : handle_(std::move(handle)), parent_(parent) {} + +GlobalData::~GlobalData() { + if (parent_ != nullptr) { + ReleaseHandles(parent_, {handle_}); + } +} + +/* static */ void GlobalData::Release( + std::vector> instances) { + absl::flat_hash_map> + parent_handles_map; + for (auto& instance : instances) { + if (instance->parent_ != nullptr) { + parent_handles_map[instance->parent_].push_back(instance->Release()); + } + } + for (auto& parent_handles : parent_handles_map) { + ReleaseHandles(parent_handles.first, parent_handles.second); + } +} + } // namespace xla diff --git a/third_party/xla/xla/service/service.h b/third_party/xla/xla/service/service.h index ff54b36ae04435..3fd7f227c362e8 100644 --- a/third_party/xla/xla/service/service.h +++ b/third_party/xla/xla/service/service.h @@ -24,6 +24,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/types/span.h" +#include "xla/client/xla_computation.h" #include "xla/debug_options_flags.h" #include "xla/executable_run_options.h" #include "xla/hlo/ir/hlo_module.h" @@ -44,6 +45,8 @@ limitations under the License. namespace xla { +class Service; + // Options to configure the service when it is created. class ServiceOptions { public: @@ -73,6 +76,59 @@ class ServiceOptions { std::optional> allowed_devices_; }; +// A GlobalData object represents a globally-accessible allocation of +// data in the associated XLA service. +class GlobalData { + public: + // Gives ownership of the global data handle to this object. + GlobalData(Service* parent, GlobalDataHandle handle); + + // Unregisters the wrapped handle, which causes the service to + // deallocate the associated data. + ~GlobalData(); + + const GlobalDataHandle& handle() const { return handle_; } + + // Releases a set of GlobalData handles. A single RPC will be issued + // per unique Service of the given GlobalData objects. + static void Release(std::vector> instances); + + private: + // Detaches the global data handle from the object, such that the destructor + // will not try to release it. + GlobalDataHandle Release() { + parent_ = nullptr; + return handle_; + } + + GlobalDataHandle handle_; // Handle being wrapped. + Service* parent_; // Service used to unregister handle_. + + GlobalData(const GlobalData&) = delete; + GlobalData& operator=(const GlobalData&) = delete; +}; + +// A struct to represent a computation instance to be executed. +// * If execution_options.device_handles is not empty, the computation is +// executed on the devices associated with the handles by partitioning the +// computation based on the attached sharding attributes. Otherwise, a +// device is chosen by the service. 
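Since GlobalData now lives beside Service, a short ownership sketch may help; it is illustrative only, reuses the includes from the sketch above, and assumes TransferToServer results obtained the same way, with template arguments inferred from the declarations shown in this hunk.

// Hypothetical illustration of the RAII behaviour; `service` and the input
// literals are assumed to exist. (.value() aborts on error; fine for a sketch.)
void OwnershipSketch(xla::Service* service, const xla::Literal& a,
                     const xla::Literal& b) {
  std::unique_ptr<xla::GlobalData> da =
      service->TransferToServer(a, /*device_handle=*/nullptr).value();
  std::unique_ptr<xla::GlobalData> db =
      service->TransferToServer(b, /*device_handle=*/nullptr).value();

  // handle() only exposes the wrapped GlobalDataHandle; ownership stays with
  // the GlobalData object.
  const xla::GlobalDataHandle& h = da->handle();
  (void)h;

  // Destroying a GlobalData unregisters its handle with the parent Service.
  da.reset();

  // Batch form: handles are grouped per parent Service before unregistering.
  std::vector<std::unique_ptr<xla::GlobalData>> rest;
  rest.push_back(std::move(db));
  xla::GlobalData::Release(std::move(rest));
}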
+struct XlaComputationInstance { + const XlaComputation& computation; + std::vector arguments; + ExecutionOptions execution_options; + ExecutionProfile* execution_profile; + + XlaComputationInstance(const XlaComputation& computation, + std::vector arguments, + ExecutionOptions execution_options, + ExecutionProfile* execution_profile) + : computation(computation), + arguments(std::move(arguments)), + execution_options(execution_options), + execution_profile(execution_profile) {} +}; + // The XLA service object, which is the same across all platforms. It maintains // the service state of computations and allocations, and delegates // target-specific requests to the target-specific infrastructure @@ -83,30 +139,32 @@ class Service { // // If the handle given is not currently allocated, a NOT_FOUND status is // returned. - virtual absl::Status Unregister(const UnregisterRequest* arg, - UnregisterResponse* result); + virtual absl::Status Unregister(const GlobalDataHandle& data); // Deconstructs a tuple. Returns a newly created GlobalDataHandle for each // element in the tuple. - virtual absl::Status DeconstructTuple(const DeconstructTupleRequest* arg, - DeconstructTupleResponse* result); + virtual absl::StatusOr>> + DeconstructTuple(const GlobalData& data); // Compiles a computation into an executable. The request contains the whole // computation graph. Returns the handle to the executable. - virtual absl::Status Compile(const CompileRequest* arg, - CompileResponse* result); + virtual absl::StatusOr Compile( + const XlaComputation& computation, + absl::Span argument_shapes, + const ExecutionOptions& execution_options); // Executes an executable with the provided global data passes as immutable // arguments. The request contains the handle to the executable. Returns // global data output and execution timing. - virtual absl::Status Execute(const ExecuteRequest* arg, - ExecuteResponse* result); + virtual absl::StatusOr> Execute( + const ExecutionHandle& handle, absl::Span arguments, + ExecutionProfile* execution_profile); // Executes one or more computations in parallel with the provided global data // passed as immutable arguments. Returns global data output for each // computation. - virtual absl::Status ExecuteGraphParallel( - const ExecuteGraphParallelRequest* arg, ExecuteParallelResponse* result); + absl::StatusOr>> ExecuteGraphParallel( + absl::Span computations); // Requests one or more device handles from the target. // @@ -116,27 +174,28 @@ class Service { // the first set of replicas, and the next R devices to the second set of // replicas, etc. Each returned device handle represents the device with the // replica id 0. - virtual absl::Status GetDeviceHandles(const GetDeviceHandlesRequest* arg, - GetDeviceHandlesResponse* result); + virtual absl::StatusOr> GetDeviceHandles( + int64_t device_count); // Requests that global data be transferred to the client in literal form. - virtual absl::Status TransferToClient(const TransferToClientRequest* arg, - TransferToClientResponse* result); + virtual absl::StatusOr TransferToClient( + const GlobalData& data, const Shape* shape_with_layout); // Transfers data from a literal provided by the client, into device memory. - virtual absl::Status TransferToServer(const TransferToServerRequest* arg, - TransferToServerResponse* result); + virtual absl::StatusOr> TransferToServer( + const LiteralSlice& literal_slice, const DeviceHandle* device_handle); // Transfers data from a literal provided by the client, into the Infeed // buffer of the device. 
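XlaComputationInstance bundles everything ExecuteGraphParallel needs per computation, so a hedged usage sketch follows. The span element type and the return type (taken here as absl::StatusOr<std::vector<std::unique_ptr<GlobalData>>>) are inferred from the truncated declaration above; RunParallel and its inputs are hypothetical, and the includes from the first sketch are assumed.

// Hypothetical driver; assumes `service`, a built `computation`, a staged
// argument `arg`, and a device handle obtained from GetDeviceHandles(1).
absl::Status RunParallel(xla::Service* service,
                         const xla::XlaComputation& computation,
                         xla::GlobalData* arg,
                         const xla::DeviceHandle& device) {
  xla::ExecutionOptions options = xla::CreateDefaultExecutionOptions();
  *options.add_device_handles() = device;

  std::vector<xla::XlaComputationInstance> instances;
  instances.push_back(xla::XlaComputationInstance(
      computation, {arg}, options, /*execution_profile=*/nullptr));

  // One output handle is expected per computation instance.
  TF_ASSIGN_OR_RETURN(auto outputs, service->ExecuteGraphParallel(instances));
  return outputs.size() == instances.size()
             ? absl::OkStatus()
             : absl::InternalError("unexpected number of outputs");
}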
- virtual absl::Status TransferToInfeed(const TransferToInfeedRequest* arg, - TransferToInfeedResponse* result); + virtual absl::Status TransferToInfeed(const LiteralSlice& literal, + int64_t replica_id, + const DeviceHandle* device_handle); // Transfers data from the Outfeed othe device to the literal provided by the // client. - virtual absl::Status TransferFromOutfeed( - const TransferFromOutfeedRequest* arg, - TransferFromOutfeedResponse* result); + virtual absl::StatusOr TransferFromOutfeed( + const Shape* shape_with_layout, int64_t replica_id, + const DeviceHandle* device_handle); // Resets devices, clearing all existing state on all the devices associated // with this service (including memory allocated on the devices). @@ -147,22 +206,19 @@ class Service { // ResetDevice should be called before an Execution that expect the device to // be in the reset state. For example, if the prior Execution modifies device // state (e.g., architectural state) that the next Execution depends on. - virtual absl::Status ResetDevice(const ResetDeviceRequest* arg, - ResetDeviceResponse* result); + virtual absl::Status ResetDevice(); - virtual absl::Status ComputeConstantGraph( - const ComputeConstantGraphRequest* arg, ComputeConstantResponse* result); + virtual absl::StatusOr ComputeConstantGraph( + const XlaComputation& computation, const Layout* output_layout); // Returns the shape (with layout) of an array associated with a given data // handle. - virtual absl::Status GetShape(const GetShapeRequest* arg, - GetShapeResponse* result); + virtual absl::StatusOr GetShape(const GlobalData& data); // Creates a unique channel handle that can be used for Send/Recv // instructions. - virtual absl::Status CreateChannelHandle( - const CreateChannelHandleRequest* arg, - CreateChannelHandleResponse* result); + virtual absl::StatusOr CreateChannelHandle( + ChannelHandle::ChannelType type); // Returns the backend used to execute computations. const Backend& backend() const { return *execute_backend_; } @@ -201,7 +257,7 @@ class Service { // Prepare the arguments for executing parallel. absl::StatusOr>> GetArguments( const ExecutionOptions& execution_options, - absl::Span arguments) const; + absl::Span arguments) const; protected: friend class LocalExecutable; @@ -217,7 +273,7 @@ class Service { // the corresponding replica. 
absl::StatusOr>> ResolveAndValidateArguments( - absl::Span arguments, + absl::Span arguments, absl::Span stream_executors) const; public: diff --git a/third_party/xla/xla/tests/BUILD b/third_party/xla/xla/tests/BUILD index 918992c6aeff99..b841b47a70624f 100644 --- a/third_party/xla/xla/tests/BUILD +++ b/third_party/xla/xla/tests/BUILD @@ -12,10 +12,6 @@ load( "if_cuda_is_configured", ) load("//xla:xla.bzl", "tests_build_defs_bzl_deps", "xla_cc_binary", "xla_cc_test") -load( - "//xla/stream_executor:build_defs.bzl", - "if_gpu_is_configured", -) load("//xla/tests:build_defs.bzl", "generate_backend_suites", "generate_backend_test_macros", "xla_test", "xla_test_library") load("//xla/tsl:tsl.bzl", "internal_visibility", "tsl_copts") load("//xla/tsl:tsl.default.bzl", "filegroup") @@ -1045,6 +1041,7 @@ xla_test( "//xla:status_macros", "//xla:test", "//xla/client:xla_builder", + "//xla/service", ], ) @@ -2461,32 +2458,30 @@ xla_test( xla_test( name = "llvm_compiler_test", srcs = ["llvm_compiler_test.cc"], - backends = ["gpu"], - local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ - "TENSORFLOW_USE_ROCM", - ]), - # TODO(b/317293391) Remove once Bazel test_suite handles tags correctly - tags = ["gpu"], - deps = if_gpu_is_configured([ - ":verified_hlo_module", + backend_tags = { + # TODO(b/317293391) Remove once Bazel test_suite handles tags correctly. + "gpu": ["gpu"], + }, + backends = [ + "cpu", + "gpu", + ], + deps = [ + ":hlo_test_base", "//xla:literal_util", "//xla:test_helpers", "//xla/hlo/ir:hlo", - "@com_google_absl//absl/status", + "//xla/hlo/ir:hlo_module_group", "//xla/service:backend", - "//xla/service:cpu_plugin", "//xla/service:llvm_compiler", - "//xla/service:platform_util", - "//xla/service/cpu:cpu_compiler", - "//xla/service/gpu:gpu_compiler", "//xla/stream_executor", + "@com_google_absl//absl/status", + "@com_google_googletest//:gtest_main", + "@llvm-project//llvm:Core", + "@local_tsl//tsl/platform:casts", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:test_main", - ]) + if_cuda_is_configured([ - "//xla/stream_executor/cuda:cuda_platform_id", - ]) + if_rocm_is_configured([ - "//xla/stream_executor/rocm:rocm_platform_id", - ]), + ], ) xla_test( diff --git a/third_party/xla/xla/tests/client_test.cc b/third_party/xla/xla/tests/client_test.cc index 6d67d07eb969da..1adb92207748aa 100644 --- a/third_party/xla/xla/tests/client_test.cc +++ b/third_party/xla/xla/tests/client_test.cc @@ -128,14 +128,14 @@ XLA_TEST_F(ClientTest, // We can't really test parallel execution on CPU since all of the cores in a // CPU are presented as a single device. So for now we test "parallel" // execution on a single device. - std::vector computation_instances; + std::vector computation_instances; TF_ASSERT_OK_AND_ASSIGN(std::vector devices, client_->GetDeviceHandles(1)); ASSERT_EQ(devices.size(), 1); ExecutionOptions options = execution_options_; *options.add_device_handles() = devices[0]; - computation_instances.push_back(Client::XlaComputationInstance( + computation_instances.push_back(XlaComputationInstance( add_with_one_arg, {const_arg.get()}, options, nullptr)); TF_ASSERT_OK_AND_ASSIGN(auto results, diff --git a/third_party/xla/xla/tests/gather_operation_test.cc b/third_party/xla/xla/tests/gather_operation_test.cc index b80d1570a9b3b0..94d02abbea3b04 100644 --- a/third_party/xla/xla/tests/gather_operation_test.cc +++ b/third_party/xla/xla/tests/gather_operation_test.cc @@ -17,6 +17,7 @@ limitations under the License. 
#include "xla/client/xla_builder.h" #include "xla/execution_options_util.h" #include "xla/literal_util.h" +#include "xla/service/service.h" #include "xla/status_macros.h" #include "xla/test.h" #include "xla/tests/client_library_test_base.h" @@ -778,7 +779,7 @@ XLA_TEST_F(GatherClientLibraryTest, xla::ExecutionOptions execution_options = CreateDefaultExecutionOptions(); *execution_options.add_device_handles() = devices[0]; TF_ASSERT_OK_AND_ASSIGN(XlaComputation computation, builder.Build()); - std::vector computation_instances = { + std::vector computation_instances = { {computation, {operand_arg.get(), indices_arg.get()}, execution_options, diff --git a/third_party/xla/xla/tests/llvm_compiler_test.cc b/third_party/xla/xla/tests/llvm_compiler_test.cc index dd15cbbeba4cc2..238482b650c3d6 100644 --- a/third_party/xla/xla/tests/llvm_compiler_test.cc +++ b/third_party/xla/xla/tests/llvm_compiler_test.cc @@ -21,199 +21,81 @@ limitations under the License. #include #include +#include #include "absl/status/status.h" +#include "llvm/IR/Module.h" #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module_group.h" #include "xla/literal_util.h" #include "xla/service/backend.h" -#include "xla/service/cpu/cpu_compiler.h" -#include "xla/service/gpu/gpu_compiler.h" -#include "xla/service/platform_util.h" -#if GOOGLE_CUDA -#include "xla/stream_executor/cuda/cuda_platform_id.h" -#elif TENSORFLOW_USE_ROCM -#include "xla/stream_executor/rocm/rocm_platform_id.h" -#endif #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/stream_executor.h" #include "xla/test_helpers.h" -#include "xla/tests/verified_hlo_module.h" +#include "xla/tests/hlo_test_base.h" +#include "tsl/platform/casts.h" #include "tsl/platform/threadpool.h" namespace xla { -namespace gpu { - -// Creating dummy data structure needed to initialize a GpuDummyCompiler -constexpr char kDummyTriple[] = "dummy-triple"; -constexpr char kDummyLayout[] = "e"; -const se::Platform::Id kGpuPlatformId = -#if GOOGLE_CUDA - se::cuda::kCudaPlatformId; -#elif TENSORFLOW_USE_ROCM - se::rocm::kROCmPlatformId; -#endif -// This class is a dummy implementation of GpuCompiler and is targeted for unit -// test only -class GpuDummyCompiler : public GpuCompiler { - public: - GpuDummyCompiler() - : GpuCompiler(kGpuPlatformId, kDummyTriple, kDummyLayout) {} - - int32_t GetToolkitVersion() const override { return 0; } - - absl::Status OptimizeHloConvolutionCanonicalization( - HloModule* hlo_module, se::GpuComputeCapability gpu_version, - se::dnn::VersionInfo dnn_version, - se::DeviceMemoryAllocator* device_allocator) { - return absl::OkStatus(); - } - - absl::Status OptimizeHloPostLayoutAssignment( - HloModule* hlo_module, se::StreamExecutor* stream_executor, - const CompileOptions& options, const TargetConfig& gpu_target_config, - tsl::thread::ThreadPool* thread_pool) override { - return absl::OkStatus(); - } +namespace { - absl::StatusOr CompileTargetBinary( - const HloModuleConfig& module_config, llvm::Module* llvm_module, - se::GpuComputeCapability gpu_version, bool relocatable, - const HloModule* debug_module, const CompileOptions& options) override { - return BackendCompileResult{}; - } -}; -} // namespace gpu +using LLVMCompilerTest = HloTestBase; -namespace { +const char* const kHloText = R"( +HloModule Constant -class LLVMCompilerTest : public ::testing::Test { - public: - void SetUp() override { - Platform* platform = FindPlatform(); - ASSERT_NE(platform, nullptr); - - BackendOptions backend_options; - 
backend_options.set_platform(platform); - absl::StatusOr> backend_or_status = - Backend::CreateBackend(backend_options); - ASSERT_IS_OK(backend_or_status.status()); - backend_ = std::move(backend_or_status).value(); - } - - ~LLVMCompilerTest() override {} - - protected: - using Platform = se::Platform; - - explicit LLVMCompilerTest(std::string platform_name) - : platform_name_(std::move(platform_name)) {} - - void TestCompilerHooks(LLVMCompiler* compiler) { - int pre_opt_hook_call_count = 0; - int post_opt_hook_call_count = 0; - - auto pre_opt_hook = [&pre_opt_hook_call_count](const llvm::Module&) { - ++pre_opt_hook_call_count; - return absl::OkStatus(); - }; - auto post_opt_hook = [&post_opt_hook_call_count](const llvm::Module&) { - ++post_opt_hook_call_count; - return absl::OkStatus(); - }; - - // Create HLO module, and run the compiler. - auto builder = HloComputation::Builder(TestName()); - builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); - - auto hlo_module = CreateNewVerifiedModule(); - hlo_module->AddEntryComputation(builder.Build()); - - compiler->SetPreOptimizationHook(pre_opt_hook); - compiler->SetPostOptimizationHook(post_opt_hook); - - ASSERT_TRUE(compiler - ->RunBackend(std::move(hlo_module), - backend_->default_stream_executor(), - /*device_allocator=*/nullptr) - .ok()); - - // Test that hooks were called. - EXPECT_EQ(1, pre_opt_hook_call_count); - EXPECT_EQ(1, post_opt_hook_call_count); - } - - void TestMultiModuleCompilation(LLVMCompiler* compiler) { - HloComputation::Builder builder(TestName()); - builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); - - std::unique_ptr hlo_module = CreateNewVerifiedModule(); - hlo_module->AddEntryComputation(builder.Build()); - - auto module_group = std::make_unique("test_module_group"); - module_group->push_back(hlo_module->Clone()); - module_group->push_back(std::move(hlo_module)); - - std::vector> executors; - executors.push_back({backend_->default_stream_executor()}); - executors.push_back({backend_->default_stream_executor()}); - - EXPECT_IS_OK(compiler->Compile(std::move(module_group), - std::move(executors), - /*device_allocator=*/nullptr)); - } - - private: - Platform* FindPlatform() { - auto status_or_platform = PlatformUtil::GetPlatform(platform_name_); - return status_or_platform.ok() ? 
status_or_platform.value() : nullptr; - } - - std::string platform_name_; - std::unique_ptr backend_; - - static std::string TestName() { - return ::testing::UnitTest::GetInstance()->current_test_info()->name(); - } - - std::unique_ptr CreateNewVerifiedModule() { - HloModuleConfig config; - config.set_debug_options(GetDebugOptionsFromFlags()); - return std::make_unique( - TestName(), config, /*verifier_layout_sensitive=*/false, - /*allow_mixed_precision_in_hlo_verifier=*/true, - backend_->compiler()->ShapeSizeBytesFunction()); - } -}; - -class CpuCompilerTest : public LLVMCompilerTest { - public: - CpuCompilerTest() : LLVMCompilerTest("Host") {} -}; - -class GpuCompilerTest : public LLVMCompilerTest { - public: - GpuCompilerTest() : LLVMCompilerTest("GPU") {} -}; - -TEST_F(CpuCompilerTest, HooksTest) { - cpu::CpuCompiler compiler; - TestCompilerHooks(&compiler); +ENTRY main { + ROOT constant = f32[] constant(42.0) } +)"; -TEST_F(GpuCompilerTest, HooksTest) { - gpu::GpuDummyCompiler compiler; - TestCompilerHooks(&compiler); -} +TEST_F(LLVMCompilerTest, HooksTest) { + int pre_opt_hook_call_count = 0; + int post_opt_hook_call_count = 0; -TEST_F(CpuCompilerTest, CpuMultiModuleCompilation) { - cpu::CpuCompiler compiler; - TestMultiModuleCompilation(&compiler); + auto pre_opt_hook = [&pre_opt_hook_call_count](const llvm::Module&) { + ++pre_opt_hook_call_count; + return absl::OkStatus(); + }; + auto post_opt_hook = [&post_opt_hook_call_count](const llvm::Module&) { + ++post_opt_hook_call_count; + return absl::OkStatus(); + }; + + // Create HLO module, and run the compiler. + auto hlo_module = ParseAndReturnVerifiedModule(kHloText).value(); + LLVMCompiler* compiler = + tensorflow::down_cast(backend().compiler()); + compiler->SetPreOptimizationHook(pre_opt_hook); + compiler->SetPostOptimizationHook(post_opt_hook); + + ASSERT_TRUE(compiler + ->RunBackend(std::move(hlo_module), + backend().default_stream_executor(), + /*device_allocator=*/nullptr) + .ok()); + + // Test that hooks were called. + EXPECT_EQ(1, pre_opt_hook_call_count); + EXPECT_EQ(1, post_opt_hook_call_count); } -TEST_F(GpuCompilerTest, GpuMultModuleCompilation) { - gpu::GpuDummyCompiler compiler; - TestMultiModuleCompilation(&compiler); +TEST_F(LLVMCompilerTest, DISABLED_MultiModuleCompilation) { + auto hlo_module = ParseAndReturnVerifiedModule(kHloText).value(); + auto hlo_module2 = ParseAndReturnVerifiedModule(kHloText).value(); + std::vector> modules; + modules.push_back(std::move(hlo_module)); + modules.push_back(std::move(hlo_module2)); + auto module_group = + std::make_unique("test_module_group", std::move(modules)); + + std::vector> executors; + executors.push_back({backend().default_stream_executor()}); + executors.push_back({backend().default_stream_executor()}); + + EXPECT_IS_OK(backend().compiler()->Compile(std::move(module_group), + std::move(executors), + backend().memory_allocator())); } + } // namespace } // namespace xla diff --git a/third_party/xla/xla/xla.proto b/third_party/xla/xla/xla.proto index d9e4594d2b4708..16b1da1999d7c3 100644 --- a/third_party/xla/xla/xla.proto +++ b/third_party/xla/xla/xla.proto @@ -1037,189 +1037,6 @@ message HloModuleProtoWithConfig { HloModuleConfigProto config = 2; } -message GetDeviceHandlesRequest { - int64 device_count = 1; -} - -message GetDeviceHandlesResponse { - repeated DeviceHandle device_handles = 1; -} - -message TransferToClientRequest { - GlobalDataHandle data = 1; - - // This optional field directs the service to return the literal in this - // layout. 
A shape is used to hold the layout to accommodate tuples. - ShapeProto shape_with_layout = 2; -} - -message TransferToClientResponse { - LiteralProto literal = 1; -} - -message TransferToServerRequest { - LiteralProto literal = 1; - DeviceHandle device_handle = 2; -} - -message TransferToServerResponse { - GlobalDataHandle data = 1; -} - -message TransferToInfeedRequest { - LiteralProto literal = 1; - int64 replica_id = 2; - DeviceHandle device_handle = 3; -} - -message TransferToInfeedResponse {} - -message TransferFromOutfeedRequest { - // This optional field directs the service to return the literal in this - // layout. A shape is used to hold the layout to accommodate tuples. - ShapeProto shape_with_layout = 1; - - int64 replica_id = 2; - DeviceHandle device_handle = 3; -} - -message TransferFromOutfeedResponse { - LiteralProto literal = 1; -} - -message ResetDeviceRequest { - DeviceHandle device_handle = 1; -} - -message ResetDeviceResponse {} - -message CreateChannelHandleRequest { - ChannelHandle.ChannelType channel_type = 1; -} - -message CreateChannelHandleResponse { - ChannelHandle channel = 1; -} - -message UnregisterRequest { - repeated GlobalDataHandle data = 1; -} - -message UnregisterResponse {} - -message CompileRequest { - // The graph to be compiled. - HloModuleProto computation = 1; - - // Options that affect how XLA compiles code to service this request. - ExecutionOptions execution_options = 2; - - // The layouts of the input arguments. If not set, the default layout will be - // used. Although the real arguments are not needed in compilation, the - // layouts of the arguments can affect the compilation. - repeated ShapeProto input_shape_with_layout = 3; -} - -message CompileResponse { - // The handle to the executable. - ExecutionHandle handle = 1; -} - -message ExecuteRequest { - ExecutionHandle handle = 1; - - // The shape and layout of the arguments must be the same as the those of the - // executable's parameters. - repeated GlobalDataHandle arguments = 2; -} - -// TODO(b/118493728): Remove this and ExecuteGraphParallelRequest and replace -// the uses with calls to Compile and Execute. -message ExecuteGraphRequest { - HloModuleProto computation = 1; - repeated GlobalDataHandle arguments = 2; - - // Options that affect how XLA compiles and runs code to service this request. - ExecutionOptions execution_options = 3; -} - -message ExecuteGraphParallelRequest { - repeated ExecuteGraphRequest requests = 1; -} - -message ExecuteResponse { - GlobalDataHandle output = 1; - ExecutionProfile profile = 2; -} - -message ExecuteParallelResponse { - repeated ExecuteResponse responses = 1; -} - -message ComputeConstantGraphRequest { - HloModuleProto computation = 1; - LayoutProto output_layout = 2; -} - -message ComputeConstantResponse { - // A LiteralProto is returned directly for this request. - LiteralProto literal = 1; -} - -message DeconstructTupleRequest { - GlobalDataHandle tuple_handle = 2; -} - -message DeconstructTupleResponse { - repeated GlobalDataHandle element_handles = 1; -} - -message LoadDataRequest { - // Describes the path of the ColumnIO tablet to load. - string columnio_tablet_path = 1; - - // Describes the field to load within the ColumnIO tablet. - string columnio_field = 2; - - // Individual element shape, excluding rows. - ShapeProto element_shape = 3; - - // Warning: ColumnIO does not support random-access, so use offset with - // caution in performance-critical scenarios. 
- int64 offset = 4; - - // Maximum number of elements (with shape element_shape) to load. - int64 limit = 5; - - // If more than one item is requested (via limit > 1), then this request - // attribute zips together the produced vectors. - bool zip = 6; -} - -message LoadDataResponse { - GlobalDataHandle data = 1; - ShapeProto data_shape = 2; - int64 available_rows = 3; - int64 rows_loaded = 4; - int64 nanoseconds = 5; -} - -message GetShapeRequest { - GlobalDataHandle data = 1; -} - -message GetShapeResponse { - ShapeProto shape = 1; -} - -message UnpackRequest { - GlobalDataHandle data = 1; -} - -message UnpackResponse { - repeated GlobalDataHandle tied_data = 1; -} - // A trace estimated by the Latency Hiding Scheduler. message ScheduleProto { message Instruction {