[go: nahoru, domu]

Skip to content

Commit

Permalink
[XLA] Remove proto-based communication for service/client
Browse files Browse the repository at this point in the history
Originally, the service/client interface in XLA was envisioned as a compilation
service interface, with proto-based communication for RPCs.
That compilation service never quite worked in that shape, and in the meantime,
the main compilation API moved to PjRT.

Having lots of protos around complicates the XLA compilation API, and
serialization/deserialization to and from protos also hurts performance.

Since all Service usages are realistically local, let's just not use protos at
the boundary and pass the original data structures instead.

PiperOrigin-RevId: 644695252
  • Loading branch information
cheshire authored and tensorflower-gardener committed Jun 19, 2024
1 parent 160e760 commit c72dbfc
Show file tree
Hide file tree
Showing 15 changed files with 325 additions and 828 deletions.
1 change: 0 additions & 1 deletion third_party/xla/xla/client/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ filegroup(

cc_library(
name = "global_data",
srcs = ["global_data.cc"],
hdrs = ["global_data.h"],
deps = [
"//xla:types",
Expand Down
302 changes: 19 additions & 283 deletions third_party/xla/xla/client/client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,129 +41,28 @@ Client::~Client() = default;

absl::StatusOr<Literal> Client::Transfer(const GlobalData& data,
const Shape* shape_with_layout) {
TransferToClientRequest request;
*request.mutable_data() = data.handle();
if (shape_with_layout != nullptr) {
*request.mutable_shape_with_layout() = shape_with_layout->ToProto();
}
TransferToClientResponse response;

VLOG(1) << "making transfer request";
VLOG(3) << "TransferToClientRequest: {" << request.DebugString() << "}";
absl::Status s = stub_->TransferToClient(&request, &response);
VLOG(1) << "done with request";

if (!s.ok()) {
return s;
}
VLOG(3) << "TransferToClientResponse: {" << response.DebugString() << "}";

if (!response.has_literal()) {
return FailedPrecondition(
"server provided response without a literal in "
"TransferToClient request");
}
return Literal::CreateFromProto(*response.mutable_literal());
return stub_->TransferToClient(data, shape_with_layout);
}

absl::StatusOr<std::unique_ptr<GlobalData>> Client::TransferToServer(
const LiteralSlice& literal, const DeviceHandle* device_handle) {
TransferToServerRequest request;
*request.mutable_literal() = literal.ToProto();
if (device_handle) {
*request.mutable_device_handle() = *device_handle;
}
TransferToServerResponse response;

VLOG(1) << "making transfer to server request";
VLOG(3) << "TransferToServerRequest: {" << request.DebugString() << "}";
absl::Status s = stub_->TransferToServer(&request, &response);
VLOG(1) << "done with request";

if (!s.ok()) {
return s;
}
VLOG(3) << "TransferToServerResponse: {" << response.DebugString() << "}";

if (!response.has_data()) {
return FailedPrecondition(
"server provided response without a data handle in "
"TransferToServer request");
}

return std::make_unique<GlobalData>(stub_, response.data());
return stub_->TransferToServer(literal, device_handle);
}

absl::Status Client::TransferToInfeed(const LiteralSlice& literal,
int64_t replica_id,
const DeviceHandle* device_handle) {
TransferToInfeedRequest request;
*request.mutable_literal() = literal.ToProto();
if (device_handle) {
*request.mutable_device_handle() = *device_handle;
}
request.set_replica_id(replica_id);
TransferToInfeedResponse response;

VLOG(1) << "making transfer to infeed request";
VLOG(3) << "TransferToInfeedRequest: {" << request.DebugString() << "}";
absl::Status s = stub_->TransferToInfeed(&request, &response);
VLOG(1) << "done with request";

if (!s.ok()) {
return s;
}
VLOG(3) << "TransferToInfeedResponse: {" << response.DebugString() << "}";
return absl::OkStatus();
return stub_->TransferToInfeed(literal, replica_id, device_handle);
}

absl::StatusOr<Literal> Client::TransferFromOutfeed(
const Shape* shape_with_layout, int64_t replica_id,
const DeviceHandle* device_handle) {
TransferFromOutfeedRequest request;
if (device_handle) {
*request.mutable_device_handle() = *device_handle;
}
request.set_replica_id(replica_id);
if (shape_with_layout != nullptr) {
*request.mutable_shape_with_layout() = shape_with_layout->ToProto();
}
TransferFromOutfeedResponse response;

VLOG(1) << "making transfer from outfeed request";
VLOG(3) << "TransferFromOutfeedRequest: {" << request.DebugString() << "}";
absl::Status s = stub_->TransferFromOutfeed(&request, &response);
VLOG(1) << "done with request";

if (!s.ok()) {
return s;
}
VLOG(3) << "TransferFromOutfeedResponse: {" << response.DebugString() << "}";

if (!response.has_literal()) {
return FailedPrecondition(
"server provided response without a literal in "
"TransferToClient request");
}

return Literal::CreateFromProto(response.literal());
return stub_->TransferFromOutfeed(shape_with_layout, replica_id,
device_handle);
}

absl::Status Client::ResetDevice() {
ResetDeviceRequest request;
ResetDeviceResponse response;

VLOG(1) << "making reset device request";
VLOG(3) << "ResetDeviceRequest: {" << request.DebugString() << "}";
absl::Status s = stub_->ResetDevice(&request, &response);
VLOG(1) << "done with request";

if (!s.ok()) {
return s;
}
VLOG(3) << "ResetDeviceResponse: {" << response.DebugString() << "}";
return absl::OkStatus();
}
absl::Status Client::ResetDevice() { return stub_->ResetDevice(); }

absl::StatusOr<Literal> Client::ExecuteAndTransfer(
const XlaComputation& computation, absl::Span<GlobalData* const> arguments,
Expand All @@ -185,30 +84,7 @@ absl::StatusOr<Literal> Client::ExecuteAndTransfer(

absl::StatusOr<Literal> Client::ComputeConstant(
const XlaComputation& computation, const Layout* output_layout) const {
ComputeConstantGraphRequest request;
*request.mutable_computation() = computation.proto();
if (output_layout != nullptr) {
*request.mutable_output_layout() = output_layout->ToProto();
}

ComputeConstantResponse response;

VLOG(2) << "making compute-constant-graph request";
absl::Status s = stub_->ComputeConstantGraph(&request, &response);
VLOG(2) << "done with request";

if (!s.ok()) {
return s;
}

VLOG(3) << "ComputeConstant: {" << response.DebugString() << "}";

if (!response.has_literal()) {
return Internal(
"no computed literal in the provided response in ComputeConstantGraph "
"request");
}
return Literal::CreateFromProto(response.literal());
return stub_->ComputeConstantGraph(computation, output_layout);
}

absl::StatusOr<XlaComputation> Client::LoadSnapshot(const HloSnapshot& module) {
Expand All @@ -219,61 +95,19 @@ absl::StatusOr<XlaComputation> Client::LoadSnapshot(const HloSnapshot& module) {
absl::StatusOr<ExecutionHandle> Client::Compile(
const XlaComputation& computation, absl::Span<const Shape> argument_shapes,
const ExecutionOptions* execution_options) {
CompileRequest request;
*request.mutable_computation() = computation.proto();

if (execution_options == nullptr) {
*request.mutable_execution_options() = CreateDefaultExecutionOptions();
} else {
*request.mutable_execution_options() = *execution_options;
}
if (request.execution_options().device_handles_size() > 1) {
return InvalidArgument(
"Compiling with multiple device handles is not supported. Use "
"'Execute' instead.");
}

// The argument shapes affect how the computation is compiled.
for (const auto& arg_shape : argument_shapes) {
*request.add_input_shape_with_layout() = arg_shape.ToProto();
std::optional<ExecutionOptions> opts;
if (!execution_options) {
opts = CreateDefaultExecutionOptions();
}

CompileResponse response;
VLOG(1) << "making compile request: " << request.ShortDebugString();
absl::Status s = stub_->Compile(&request, &response);
VLOG(1) << "done with request";

if (!s.ok()) {
return s;
}
TF_RET_CHECK(response.has_handle());
return response.handle();
return stub_->Compile(computation, argument_shapes,
execution_options ? *execution_options : *opts);
}

absl::StatusOr<std::unique_ptr<GlobalData>> Client::Execute(
const ExecutionHandle& handle, absl::Span<GlobalData* const> arguments,
ExecutionProfile* execution_profile) {
ExecuteRequest request;
*request.mutable_handle() = handle;
for (GlobalData* argument : arguments) {
CHECK(argument != nullptr) << "Argument pointers must not be null.";
*request.add_arguments() = argument->handle();
}

ExecuteResponse response;
VLOG(1) << "making execute request: " << request.ShortDebugString();
absl::Status s = stub_->Execute(&request, &response);
VLOG(1) << "done with request";

if (!s.ok()) {
return s;
}

if (execution_profile != nullptr) {
*execution_profile = response.profile();
}

return std::make_unique<GlobalData>(stub_, response.output());
return stub_->Execute(handle, arguments, execution_profile);
}

absl::StatusOr<std::unique_ptr<GlobalData>> Client::Execute(
Expand Down Expand Up @@ -329,99 +163,25 @@ absl::StatusOr<std::unique_ptr<GlobalData>> Client::Execute(

absl::StatusOr<std::vector<std::unique_ptr<GlobalData>>>
Client::ExecuteParallel(absl::Span<const XlaComputationInstance> computations) {
ExecuteGraphParallelRequest request;

for (const XlaComputationInstance& computation : computations) {
ExecuteGraphRequest single_request;
*single_request.mutable_computation() = computation.computation.proto();
for (GlobalData* argument : computation.arguments) {
*single_request.add_arguments() = argument->handle();
}
*single_request.mutable_execution_options() = computation.execution_options;
*request.add_requests() = single_request;
}

ExecuteParallelResponse response;
VLOG(1) << "making execute-graph-parallel request: "
<< request.ShortDebugString();
absl::Status s = stub_->ExecuteGraphParallel(&request, &response);
VLOG(1) << "done with request";

if (!s.ok()) {
return s;
}

std::vector<std::unique_ptr<GlobalData>> outputs;
for (size_t i = 0, end = response.responses_size(); i < end; ++i) {
outputs.push_back(
std::make_unique<GlobalData>(stub_, response.responses(i).output()));
if (i < computations.size() &&
computations[i].execution_profile != nullptr) {
*computations[i].execution_profile = response.responses(i).profile();
}
}

return std::move(outputs);
return stub_->ExecuteGraphParallel(computations);
}

absl::StatusOr<std::vector<DeviceHandle>> Client::GetDeviceHandles(
int64_t device_count) {
if (device_count < 1) {
return InvalidArgument("device_count must be greater than 0");
}
GetDeviceHandlesRequest request;
request.set_device_count(device_count);

GetDeviceHandlesResponse response;
VLOG(1) << "making get device request: " << request.ShortDebugString();
absl::Status s = stub_->GetDeviceHandles(&request, &response);
VLOG(1) << "done with request";

if (!s.ok()) {
return s;
}

std::vector<DeviceHandle> device_handles;
const auto& response_device_handles = response.device_handles();
device_handles.reserve(response_device_handles.size());
for (const DeviceHandle& device_handle : response_device_handles) {
device_handles.push_back(device_handle);
}

return device_handles;
return stub_->GetDeviceHandles(device_count);
}

absl::Status Client::Unregister(const GlobalData& data) {
UnregisterRequest request;
*request.add_data() = data.handle();
UnregisterResponse response;

VLOG(1) << "making unregister request";
absl::Status s = stub_->Unregister(&request, &response);
VLOG(1) << "done with request";

return s;
return stub_->Unregister(data.handle());
}

absl::StatusOr<std::vector<std::unique_ptr<GlobalData>>>
Client::DeconstructTuple(const GlobalData& data) {
DeconstructTupleRequest request;
*request.mutable_tuple_handle() = data.handle();
DeconstructTupleResponse response;

VLOG(1) << "making DestructTuple request";
absl::Status s = stub_->DeconstructTuple(&request, &response);
VLOG(1) << "done with request";

if (!s.ok()) {
return s;
}

std::vector<std::unique_ptr<GlobalData>> handles;
for (auto& handle : response.element_handles()) {
handles.push_back(std::make_unique<GlobalData>(stub_, handle));
}
return std::move(handles);
return stub_->DeconstructTuple(data);
}

absl::StatusOr<std::unique_ptr<ProgramShape>> Client::GetComputationShape(
Expand All @@ -431,36 +191,12 @@ absl::StatusOr<std::unique_ptr<ProgramShape>> Client::GetComputationShape(
}

absl::StatusOr<Shape> Client::GetShape(const GlobalData& data) {
GetShapeRequest request;
*request.mutable_data() = data.handle();
GetShapeResponse response;

VLOG(1) << "making get shape request";
absl::Status s = stub_->GetShape(&request, &response);
VLOG(1) << "done with request";

if (!s.ok()) {
return s;
}

return Shape(response.shape());
return stub_->GetShape(data);
}

absl::StatusOr<ChannelHandle> Client::CreateChannelHandleByType(
ChannelHandle::ChannelType type) {
CreateChannelHandleRequest request;
request.set_channel_type(type);
CreateChannelHandleResponse response;

VLOG(1) << "making create channel handle request";
absl::Status s = stub_->CreateChannelHandle(&request, &response);
VLOG(1) << "done with request";

if (!s.ok()) {
return s;
}

return response.channel();
return stub_->CreateChannelHandle(type);
}

absl::StatusOr<ChannelHandle> Client::CreateChannelHandle() {
Expand Down
Loading

0 comments on commit c72dbfc

Please sign in to comment.