PR #13258: Give GPU compiler class access to PJRT key-value store.
Imported from GitHub PR openxla/xla#13258

Copybara import of the project:

--
735291df0028b9d3a1a73a7db8fe5104910b4ca5 by Ilia Sergachev <isergachev@nvidia.com>:

Give GPU compiler class access to PJRT key-value store.

Merging this change closes #13258

PiperOrigin-RevId: 640871152
sergachev authored and tensorflower-gardener committed Jun 6, 2024
1 parent 258b6d8 commit 12ccaf3
Showing 13 changed files with 90 additions and 37 deletions.
@@ -358,7 +358,7 @@ absl::Status CreateClientOnce(
/*allocator=*/std::move(info->allocator),
/*host_memory_allocator=*/std::move(info->host_memory_allocator),
/*should_stage_host_to_device_transfers=*/true,
/*gpu_run_options=*/std::move(gpu_run_options));
/*gpu_run_options=*/std::move(gpu_run_options), kv_store);
VLOG(2) << "PJRT GPU client with remote devices created.";
status = SetPjRtClientInTFGlobalResourceManager(DeviceType(DEVICE_GPU),
std::move(pjrt_client));
3 changes: 2 additions & 1 deletion tensorflow/core/common_runtime/gpu/gpu_device.cc
@@ -1930,7 +1930,8 @@ Status BaseGPUDeviceFactory::CreateDevices(
/*allocator=*/std::move(allocator_adapter),
/*host_memory_allocator=*/std::move(pjrt_gpu_host_allocator),
/*should_stage_host_to_device_transfers=*/true,
/*gpu_run_options=*/std::move(gpu_run_options));
/*gpu_run_options=*/std::move(gpu_run_options),
/*kv_store=*/nullptr);

return SetPjRtClientInTFGlobalResourceManager(DeviceType(DEVICE_GPU),
std::move(pjrt_client));
1 change: 1 addition & 0 deletions third_party/xla/xla/client/BUILD
@@ -105,6 +105,7 @@ cc_library(
"//xla:util",
"//xla:xla_proto_cc",
"//xla/pjrt:compile_options_proto_cc",
"//xla/pjrt/distributed:key_value_store_interface",
"//xla/service:compilation_environments",
"//xla/service:computation_placer",
"@com_google_absl//absl/algorithm:container",
20 changes: 20 additions & 0 deletions third_party/xla/xla/client/executable_build_options.h
@@ -25,6 +25,7 @@ limitations under the License.
#include "absl/algorithm/container.h"
#include "absl/container/inlined_vector.h"
#include "xla/pjrt/compile_options.pb.h"
#include "xla/pjrt/distributed/key_value_store_interface.h"
#include "xla/service/compilation_environments.h"
#include "xla/service/computation_placer.h"
#include "xla/shape.h"
@@ -236,6 +237,22 @@ class ExecutableBuildOptions {

absl::StatusOr<ExecutableBuildOptionsProto> ToProto() const;

int process_index() const { return process_index_; }
void set_process_index(const int process_index) {
process_index_ = process_index;
}
int process_count() const { return process_count_; }
void set_process_count(const int process_count) {
process_count_ = process_count;
}

std::shared_ptr<KeyValueStoreInterface> key_value_store() const {
return key_value_store_;
}
void set_key_value_store(std::shared_ptr<KeyValueStoreInterface> kv_store) {
key_value_store_ = kv_store;
}

private:
int device_ordinal_ = -1;
Shape result_layout_;
@@ -262,6 +279,9 @@ class ExecutableBuildOptions {
LayoutCanonicalizationCallback layout_canonicalization_callback_;
std::string fdo_profile_;
int64_t device_memory_size_ = 0;
int process_index_ = 0;
int process_count_ = 1;
std::shared_ptr<KeyValueStoreInterface> key_value_store_;
};

absl::StatusOr<ExecutableBuildOptions> ExecutableBuildOptionsFromProto(
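Callers populate the new fields when they assemble build options. A minimal sketch, not part of this change; how the store is obtained (distributed runtime, in-memory test store) is left to the caller, and the helper name below is hypothetical:

#include <memory>
#include <utility>

#include "xla/client/executable_build_options.h"
#include "xla/pjrt/distributed/key_value_store_interface.h"

// Sketch only: fills in the new multi-process fields on ExecutableBuildOptions.
xla::ExecutableBuildOptions MakeMultiProcessBuildOptions(
    int process_index, int process_count,
    std::shared_ptr<xla::KeyValueStoreInterface> kv_store) {
  xla::ExecutableBuildOptions build_options;
  build_options.set_process_index(process_index);  // this process's rank
  build_options.set_process_count(process_count);  // total participating processes
  build_options.set_key_value_store(std::move(kv_store));
  return build_options;
}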
1 change: 1 addition & 0 deletions third_party/xla/xla/pjrt/BUILD
@@ -472,6 +472,7 @@ cc_library(
"//xla:literal",
"//xla:shape_tree",
"//xla:shape_util",
"//xla:status_macros",
"//xla:statusor",
"//xla:util",
"//xla:xla_data_proto_cc",
11 changes: 7 additions & 4 deletions third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc
@@ -484,14 +484,16 @@ StreamExecutorGpuClient::StreamExecutorGpuClient(
int process_index, std::unique_ptr<se::DeviceMemoryAllocator> allocator,
std::unique_ptr<tsl::Allocator> host_memory_allocator,
bool should_stage_host_to_device_transfers,
std::unique_ptr<gpu::GpuExecutableRunOptions> gpu_run_options)
std::unique_ptr<gpu::GpuExecutableRunOptions> gpu_run_options,
std::shared_ptr<KeyValueStoreInterface> kv_store)
: xla::PjRtStreamExecutorClient(
platform_name, client, std::move(devices), process_index,
std::move(allocator), std::move(host_memory_allocator),
should_stage_host_to_device_transfers, std::move(gpu_run_options)),
topology_(xla::StreamExecutorGpuTopologyDescription::Create(
tsl::Fingerprint64(platform_name), platform_name,
devices_.back()->device_kind(), devices_)) {
devices_.back()->device_kind(), devices_)),
kv_store_(std::move(kv_store)) {
for (auto* device : addressable_devices()) {
// Use the device id to construct a globally unique memory space id. We do
// not promise that memory space ids and device ids are the same.
@@ -718,6 +720,7 @@ PjRtFuture<> StreamExecutorGpuClient::CopyRawSubBufferToHost(
absl::StatusOr<std::unique_ptr<PjRtLoadedExecutable>>
StreamExecutorGpuClient::Compile(const XlaComputation& computation,
CompileOptions options) {
options.executable_build_options.set_key_value_store(kv_store_);
auto executable = PjRtStreamExecutorClient::Compile(computation, options);

#if defined(GOOGLE_CUDA) || defined(TENSORFLOW_USE_ROCM)
@@ -1166,8 +1169,8 @@ absl::StatusOr<std::unique_ptr<PjRtClient>> GetStreamExecutorGpuClient(
return std::unique_ptr<PjRtClient>(std::make_unique<StreamExecutorGpuClient>(
pjrt_platform_name, xla_client, std::move(devices), options.node_id,
std::move(allocator), std::move(host_memory_allocator),
options.should_stage_host_to_device_transfers,
std::move(gpu_run_options)));
options.should_stage_host_to_device_transfers, std::move(gpu_run_options),
std::move(kv_store)));
}

absl::StatusOr<std::string> StreamExecutorGpuTopologyDescription::Serialize()
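Put together, a GPU client built through GetStreamExecutorGpuClient now hands its store to every Compile call. A rough usage sketch; the header paths and the GetDistributedKeyValueStore helper are taken from surrounding XLA code rather than from this diff and should be treated as assumptions:

#include <memory>

#include "absl/status/statusor.h"
#include "xla/pjrt/distributed/client.h"
#include "xla/pjrt/gpu/se_gpu_pjrt_client.h"

// Sketch: `distributed_client` is assumed to be an already-connected
// xla::DistributedRuntimeClient.
absl::StatusOr<std::unique_ptr<xla::PjRtClient>> MakeMultiNodeGpuClient(
    std::shared_ptr<xla::DistributedRuntimeClient> distributed_client,
    int node_id, int num_nodes) {
  xla::GpuClientOptions options;
  options.node_id = node_id;
  options.num_nodes = num_nodes;
  options.kv_store = xla::GetDistributedKeyValueStore(distributed_client,
                                                      /*key_prefix=*/"gpu:");
  // Compile() on the returned client forwards kv_store to the executable
  // build options (see se_gpu_pjrt_client.cc above).
  return xla::GetStreamExecutorGpuClient(options);
}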
4 changes: 3 additions & 1 deletion third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.h
@@ -208,7 +208,8 @@ class StreamExecutorGpuClient : public xla::PjRtStreamExecutorClient {
int process_index, std::unique_ptr<se::DeviceMemoryAllocator> allocator,
std::unique_ptr<tsl::Allocator> host_memory_allocator,
bool should_stage_host_to_device_transfers,
std::unique_ptr<gpu::GpuExecutableRunOptions> gpu_run_options);
std::unique_ptr<gpu::GpuExecutableRunOptions> gpu_run_options,
std::shared_ptr<KeyValueStoreInterface> kv_store);

absl::StatusOr<xla::DeviceAssignment> GetDefaultDeviceAssignment(
int num_replicas, int num_partitions) const override;
@@ -256,6 +257,7 @@ class StreamExecutorGpuClient : public xla::PjRtStreamExecutorClient {

private:
xla::StreamExecutorGpuTopologyDescription topology_;
std::shared_ptr<KeyValueStoreInterface> kv_store_;
};

std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> BuildLocalDevices(
6 changes: 6 additions & 0 deletions third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc
@@ -128,6 +128,7 @@ limitations under the License.
#include "xla/shape.h"
#include "xla/shape_tree.h"
#include "xla/shape_util.h"
#include "xla/status_macros.h"
#include "xla/statusor.h"
#include "xla/stream_executor/device_memory.h"
#include "xla/stream_executor/device_memory_allocator.h"
@@ -3347,6 +3348,11 @@ PjRtStreamExecutorClient::Compile(const XlaComputation& computation,
CompileOptions options) {
tsl::profiler::TraceMe traceme("PjRtStreamExecutorClient::Compile");
VLOG(1) << "PjRtStreamExecutorClient::Compile";
options.executable_build_options.set_process_index(process_index());
TF_RET_CHECK(device_count() % addressable_device_count() == 0)
<< "Each process is expected to have the same number of devices";
options.executable_build_options.set_process_count(
device_count() / addressable_device_count());
auto input_options = options;

TF_RETURN_IF_ERROR(options.ApplyAllOptionOverrides());
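The process count is derived from the ratio of all devices to the devices addressable by this process; the TF_RET_CHECK rejects topologies where processes own different numbers of devices. A quick illustration of the arithmetic, with made-up numbers:

// Two hosts, eight GPUs each, one process per host:
constexpr int kDeviceCount = 16;            // devices visible across all processes
constexpr int kAddressableDeviceCount = 8;  // devices owned by this process
static_assert(kDeviceCount % kAddressableDeviceCount == 0,
              "each process is expected to own the same number of devices");
constexpr int kProcessCount = kDeviceCount / kAddressableDeviceCount;  // == 2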
1 change: 1 addition & 0 deletions third_party/xla/xla/service/BUILD
@@ -1565,6 +1565,7 @@ cc_library(
"//xla:util",
"//xla/hlo/ir:hlo",
"//xla/hlo/ir:hlo_module_group",
"//xla/pjrt/distributed:key_value_store_interface",
"//xla/stream_executor",
"//xla/stream_executor:dnn",
"@com_google_absl//absl/container:flat_hash_map",
5 changes: 5 additions & 0 deletions third_party/xla/xla/service/compiler.h
@@ -33,6 +33,7 @@ limitations under the License.
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/hlo/ir/hlo_module_group.h"
#include "xla/pjrt/distributed/key_value_store_interface.h"
#include "xla/service/buffer_assignment.h"
#include "xla/service/buffer_value.h"
#include "xla/service/computation_placer.h"
@@ -160,6 +161,10 @@ class Compiler {
// Registry of MLIR dialects and plugins to be loaded during optimization.
// If non-null, it will be used to construct relevant MLIR contexts.
mlir::DialectRegistry* registry = nullptr;

int process_index = 0;
int process_count = 1;
std::shared_ptr<KeyValueStoreInterface> key_value_store;
};

virtual ~Compiler() = default;
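With the store plumbed into Compiler::CompileOptions, a backend compiler can coordinate across processes during compilation. A hedged sketch of such a use, assuming KeyValueStoreInterface exposes a blocking Get(key, timeout) and a Set(key, value) as in xla/pjrt/distributed/key_value_store_interface.h; those signatures are not shown in this diff and are assumptions here:

#include <string>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/time/time.h"
#include "xla/pjrt/distributed/key_value_store_interface.h"
#include "xla/service/compiler.h"

// Sketch: process 0 publishes a compile-time artifact, every other process
// blocks until it appears. The key name is illustrative.
absl::StatusOr<std::string> ShareCompileArtifact(
    const xla::Compiler::CompileOptions& options, const std::string& blob) {
  if (options.key_value_store == nullptr || options.process_count <= 1) {
    return blob;  // single-process compilation: nothing to exchange
  }
  const std::string key = "gpu_compile_artifact";
  if (options.process_index == 0) {
    absl::Status status = options.key_value_store->Set(key, blob);
    if (!status.ok()) return status;
    return blob;
  }
  return options.key_value_store->Get(key, /*timeout=*/absl::Minutes(5));
}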
21 changes: 12 additions & 9 deletions third_party/xla/xla/service/local_service.cc
@@ -87,14 +87,21 @@ LocalService::CompileExecutables(
// TODO(cjfj): Investigate why there are a couple of test failures when the
// single partition computations are built using `BuildExecutables`, fix it,
// and remove this special case (provided the performance if similar).
const Compiler::CompileOptions compile_options{
build_options.device_allocator(),
build_options.compile_thread_pool(),
build_options.layout_canonicalization_callback(),
false,
{},
nullptr,
build_options.process_index(),
build_options.process_count(),
build_options.key_value_store()};
if (build_options.num_partitions() == 1) {
TF_ASSIGN_OR_RETURN(
std::unique_ptr<Executable> executable,
BuildExecutable(computation.proto(), std::move(module_config),
execute_backend_.get(), executor,
{build_options.device_allocator(),
build_options.compile_thread_pool(),
build_options.layout_canonicalization_callback()},
execute_backend_.get(), executor, compile_options,
build_options.run_backend_only()));
std::vector<std::unique_ptr<Executable>> executables;
executables.push_back(std::move(executable));
@@ -109,11 +116,7 @@

return BuildExecutables(
/*module_protos=*/{&computation.proto()}, std::move(module_configs),
execute_backend_.get(), {executors},
Compiler::CompileOptions{
build_options.device_allocator(),
build_options.compile_thread_pool(),
build_options.layout_canonicalization_callback()},
execute_backend_.get(), {executors}, compile_options,
build_options.run_backend_only());
}
}
@@ -116,8 +116,8 @@ absl::StatusOr<std::unique_ptr<xla::PjRtClient>> GetPjRtClient(
TF_QCHECK_OK(distributed_client->Connect());
kv_store = GetDistributedKeyValueStore(distributed_client,
/*key_prefix=*/"gpu:");
return xla::FunctionalHloRunner::CreateGpuClient(distributed_client,
node_id, num_nodes);
return xla::FunctionalHloRunner::CreateGpuClient(kv_store, node_id,
num_nodes);
}
}
}
@@ -344,20 +344,19 @@ FunctionalHloRunner::CreateMockGpuClient(int num_nodes) {

absl::StatusOr<std::unique_ptr<PjRtClient>>
FunctionalHloRunner::CreateGpuClient(
std::shared_ptr<xla::DistributedRuntimeClient> distributed_client,
int node_id, int num_nodes) {
std::shared_ptr<xla::KeyValueStoreInterface> kv_store, int node_id,
int num_nodes) {
if (node_id < 0 || node_id >= num_nodes) {
return absl::InvalidArgumentError(
"Node id is expected to be in range [0, num_nodes)");
}

TF_RET_CHECK(distributed_client != nullptr);
TF_RET_CHECK(kv_store != nullptr);

GpuClientOptions options;
options.node_id = node_id;
options.num_nodes = num_nodes;
options.kv_store =
GetDistributedKeyValueStore(distributed_client, /*key_prefix=*/"gpu:");
options.kv_store = kv_store;
return GetStreamExecutorGpuClient(options);
}

@@ -371,7 +370,8 @@ absl::StatusOr<ExecutionOptions> FunctionalHloRunner::LoadExecutionOptions(

absl::StatusOr<CompileOptions> FunctionalHloRunner::CreateCompileOptions(
const PjRtClient& client,
const FunctionalHloRunner::RawCompileOptions& raw_options, int task_id) {
const FunctionalHloRunner::RawCompileOptions& raw_options, int task_id,
int num_nodes, std::shared_ptr<xla::KeyValueStoreInterface> kv_store) {
CompileOptions compile_options;
if (raw_options.execution_options.has_value()) {
compile_options.executable_build_options =
Expand All @@ -388,6 +388,9 @@ absl::StatusOr<CompileOptions> FunctionalHloRunner::CreateCompileOptions(
raw_options.num_slices.value_or(1));
build_options.set_num_replicas(replicas_and_partitions.replicas);
build_options.set_num_partitions(replicas_and_partitions.partitions);
build_options.set_process_index(task_id);
build_options.set_process_count(num_nodes);
build_options.set_key_value_store(kv_store);
if (raw_options.spmd_mode == SpmdMode::kUseSpmdPartitioning) {
build_options.set_use_spmd_partitioning(true);
}
@@ -570,10 +573,12 @@ absl::Status FunctionalHloRunner::LoadAndRunAndDump(
const xla::FunctionalHloRunner::RawCompileOptions& raw_compile_options,
const xla::FunctionalHloRunner::RunningOptions& running_options,
absl::string_view hlo_text, InputFormat input_format,
std::string dump_output_to, int task_id) {
TF_ASSIGN_OR_RETURN(CompileOptions compile_options,
FunctionalHloRunner::CreateCompileOptions(
client, raw_compile_options, task_id));
std::string dump_output_to, int task_id, int num_nodes,
std::shared_ptr<xla::KeyValueStoreInterface> kv_store) {
TF_ASSIGN_OR_RETURN(
CompileOptions compile_options,
FunctionalHloRunner::CreateCompileOptions(client, raw_compile_options,
task_id, num_nodes, kv_store));
TF_ASSIGN_OR_RETURN(
FunctionalHloRunner::PerDeviceLiteralVecType output,
FunctionalHloRunner::LoadAndRun(client, debug_options, preproc_options,
@@ -620,10 +625,12 @@ absl::Status FunctionalHloRunner::LoadAndCompile(
PjRtClient& client, const DebugOptions& debug_options,
const PreprocessingOptions& preproc_options,
const RawCompileOptions& raw_compile_options, std::string_view hlo_file,
InputFormat input_format, int task_id) {
TF_ASSIGN_OR_RETURN(CompileOptions compile_options,
FunctionalHloRunner::CreateCompileOptions(
client, raw_compile_options, task_id));
InputFormat input_format, int task_id, int num_nodes,
std::shared_ptr<xla::KeyValueStoreInterface> kv_store) {
TF_ASSIGN_OR_RETURN(
CompileOptions compile_options,
FunctionalHloRunner::CreateCompileOptions(client, raw_compile_options,
task_id, num_nodes, kv_store));

int num_replicas = compile_options.executable_build_options.num_replicas();
int num_partitions =
@@ -234,8 +234,8 @@ class FunctionalHloRunner {
// The distributed client pointer passed as a parameter is expected to be
// non-null, and 0 <= node_id < num_nodes must hold.
static absl::StatusOr<std::unique_ptr<PjRtClient>> CreateGpuClient(
std::shared_ptr<xla::DistributedRuntimeClient> distributed_client,
int node_id, int num_nodes);
std::shared_ptr<xla::KeyValueStoreInterface> kv_store, int node_id,
int num_nodes);

// Loads an ExecutionOptions proto (which can be used in RawCompileOptions).
static absl::StatusOr<ExecutionOptions> LoadExecutionOptions(
@@ -248,7 +248,8 @@
static absl::StatusOr<CompileOptions> CreateCompileOptions(
const PjRtClient& client,
const FunctionalHloRunner::RawCompileOptions& raw_options,
int task_id = 0);
int task_id = 0, int num_nodes = 1,
std::shared_ptr<xla::KeyValueStoreInterface> kv_store = nullptr);

// Runs on HLO module and dumps the output if needed.
//
@@ -259,7 +260,8 @@
const xla::FunctionalHloRunner::RawCompileOptions& raw_compile_options,
const xla::FunctionalHloRunner::RunningOptions& running_options,
absl::string_view hlo_text, InputFormat input_format,
std::string dump_output_to = "", int task_id = 0);
std::string dump_output_to = "", int task_id = 0, int num_nodes = 1,
std::shared_ptr<xla::KeyValueStoreInterface> kv_store = nullptr);

// Loads an HLO module from hlo_file according to input_format and run it.
// The HLO module is run with the provided arguments if the arguments map is
@@ -281,7 +283,8 @@
PjRtClient& client, const DebugOptions& debug_options,
const PreprocessingOptions& preproc_options,
const RawCompileOptions& raw_compile_options, std::string_view hlo_file,
InputFormat input_format, int task_id = 0);
InputFormat input_format, int task_id = 0, int num_nodes = 1,
std::shared_ptr<xla::KeyValueStoreInterface> kv_store = nullptr);

// Compiles and runs the given HLO module with the given arguments for each
// device. The given arguments is a map from device ID to a list of arguments.
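The runner entry points gain num_nodes and kv_store parameters with backward-compatible defaults, so single-process callers are unchanged. A multi-process caller would thread its rank, node count, and store through roughly as below; the header path and wrapper are illustrative, not part of this change:

#include <memory>
#include <utility>

#include "absl/status/statusor.h"
#include "xla/pjrt/distributed/key_value_store_interface.h"
#include "xla/tools/multihost_hlo_runner/functional_hlo_runner.h"

// Sketch: build compile options for one participating process. Single-process
// callers can omit the trailing arguments and rely on the new defaults
// (task_id = 0, num_nodes = 1, kv_store = nullptr).
absl::StatusOr<xla::CompileOptions> MakeRunnerCompileOptions(
    const xla::PjRtClient& client,
    const xla::FunctionalHloRunner::RawCompileOptions& raw_options,
    int node_id, int num_nodes,
    std::shared_ptr<xla::KeyValueStoreInterface> kv_store) {
  return xla::FunctionalHloRunner::CreateCompileOptions(
      client, raw_options, /*task_id=*/node_id, num_nodes, std::move(kv_store));
}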
