[NFC] Inline ExecuteThunksOrXlaRuntime
The method no longer makes sense, since we currently always go through thunks.

PiperOrigin-RevId: 640506854
cheshire authored and tensorflower-gardener committed Jun 5, 2024
1 parent a1ae1a6 commit 46627bc
Showing 2 changed files with 28 additions and 54 deletions.
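
For readers who do not want to walk the full diff, here is a minimal, self-contained sketch of the refactoring pattern. `Status`, `Executable`, and `has_thunks` are stand-ins, not the actual XLA types; the point is only that a single-use dispatch helper whose non-thunk branch is dead gets folded into its caller.

```cpp
#include <iostream>

struct Status {
  bool ok = true;
};

struct Executable {
  bool has_thunks = true;  // In current XLA GPU executables this is always true.

  // The surviving code path.
  Status ExecuteThunks() {
    std::cout << "running thunk sequence\n";
    return {};
  }

  // Before this commit: a dispatch helper whose "else" branch could never fire.
  Status ExecuteThunksOrXlaRuntime() {
    if (has_thunks) return ExecuteThunks();
    Status failed;
    failed.ok = false;  // stands in for the FailedPrecondition(...) branch
    return failed;
  }

  // After this commit: the call site invokes the thunk path directly.
  Status ExecuteAsyncOnStreamImpl() { return ExecuteThunks(); }
};

int main() { return Executable{}.ExecuteAsyncOnStreamImpl().ok ? 0 : 1; }
```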
73 changes: 28 additions & 45 deletions third_party/xla/xla/service/gpu/gpu_executable.cc
@@ -354,15 +354,23 @@ absl::Status RendezvousAfterInitialization(
     const ServiceExecutableRunOptions* run_options);
 
 absl::Status ExecuteThunks(
-    const std::string& module_name, ModuleIdentifier module_id,
-    const ThunkSequence& thunk_sequence,
+    const DebugOptions* debug_options, const std::string& module_name,
+    ModuleIdentifier module_id, const ThunkSequence& thunk_sequence,
     Thunk::ExecutableSource executable_source,
     const ServiceExecutableRunOptions* run_options,
     const BufferAllocations& buffer_allocations, bool block_host_until_done,
-    bool use_highest_priority_for_async_stream,
     const absl::flat_hash_set<ExecutionStreamId>& execution_stream_ids,
-    int64_t collective_max_nchannels, int64_t p2p_max_nchannels,
     const ModuleAnnotations& module_annotations) {
+  int64_t collective_max_nchannels =
+      debug_options ? debug_options->xla_gpu_nccl_collective_max_nchannels()
+                    : 0;
+  int64_t p2p_max_nchannels =
+      debug_options ? debug_options->xla_gpu_nccl_p2p_max_nchannels() : 0;
+  bool use_highest_priority_for_async_stream =
+      debug_options
+          ? debug_options->xla_gpu_enable_highest_priority_async_stream()
+          : false;
+
   se::Stream* main_stream = run_options->stream();
   se::StreamExecutor* executor = main_stream->parent();
   stream_executor::StreamPriority stream_priority =
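
The new ExecuteThunks signature takes a nullable `const DebugOptions*` and derives the NCCL channel limits and the async-stream-priority flag itself, as seen above. A compilable sketch of that null-guarded-defaults pattern, using a hypothetical stand-in `DebugOptions` struct rather than the real `xla::DebugOptions` proto:

```cpp
#include <cstdint>
#include <iostream>

struct DebugOptions {  // stand-in with plain fields instead of proto getters
  int64_t collective_max_nchannels = 8;
  int64_t p2p_max_nchannels = 4;
  bool highest_priority_async_stream = true;
};

void ExecuteThunks(const DebugOptions* debug_options) {
  // Each value falls back to a conservative default when no options are given.
  int64_t collective_max_nchannels =
      debug_options ? debug_options->collective_max_nchannels : 0;
  int64_t p2p_max_nchannels =
      debug_options ? debug_options->p2p_max_nchannels : 0;
  bool use_highest_priority_for_async_stream =
      debug_options ? debug_options->highest_priority_async_stream : false;
  std::cout << collective_max_nchannels << " " << p2p_max_nchannels << " "
            << use_highest_priority_for_async_stream << "\n";
}

int main() {
  DebugOptions opts;
  ExecuteThunks(&opts);    // prints 8 4 1: values taken from the options
  ExecuteThunks(nullptr);  // prints 0 0 0: the defaults from the diff above
}
```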
@@ -1000,8 +1008,22 @@ absl::StatusOr<ExecutionOutput> GpuExecutable::ExecuteAsyncOnStreamImpl(
     buffers_in_result.insert(result_buffer);
   }
 
-  TF_RETURN_IF_ERROR(ExecuteThunksOrXlaRuntime(run_options, buffer_allocations,
-                                               block_host_until_done));
+  {
+    TF_RETURN_IF_ERROR(
+        CheckCompatibilityWithServiceExecutableRunOptions(run_options));
+
+    ScopedAnnotation annotation([&] { return module_annotations_.top_level; });
+    ScopedModuleAnnotations module_annotations(&module_annotations_);
+
+    ModuleIdentifier unique_id = has_module() ? module().unique_id() : -1;
+    Thunk::ExecutableSource executable_source = {text_, binary_,
+                                                 dnn_compiled_graphs_};
+
+    TF_RETURN_IF_ERROR(ExecuteThunks(
+        has_module() ? &module_config().debug_options() : nullptr, module_name_,
+        unique_id, *thunks_, executable_source, run_options, buffer_allocations,
+        block_host_until_done, execution_stream_ids_, module_annotations_));
+  }
 
   TF_RETURN_IF_ERROR(
       buffer_allocations.TearDown(buffers_in_result, GetAllocations()));
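
Note on the call site above: instead of pre-computing three flag values the way the deleted helper did, the caller now forwards a single `has_module() ? &module_config().debug_options() : nullptr` pointer, so an executable without an attached HloModule simply gets the defaults (0 channels, no highest-priority async stream) that ExecuteThunks now derives itself.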
@@ -1013,45 +1035,6 @@ absl::StatusOr<ExecutionOutput> GpuExecutable::ExecuteAsyncOnStreamImpl(
   return std::move(result);
 }
 
-absl::Status GpuExecutable::ExecuteThunksOrXlaRuntime(
-    const ServiceExecutableRunOptions* run_options,
-    const BufferAllocations& buffer_allocations, bool block_host_until_done) {
-  TF_RETURN_IF_ERROR(
-      CheckCompatibilityWithServiceExecutableRunOptions(run_options));
-
-  ScopedAnnotation annotation([&] { return module_annotations_.top_level; });
-  ScopedModuleAnnotations module_annotations(&module_annotations_);
-
-  ModuleIdentifier unique_id = has_module() ? module().unique_id() : -1;
-
-  if (thunks_) {
-    Thunk::ExecutableSource executable_source = {text_, binary_,
-                                                 dnn_compiled_graphs_};
-    int64_t collective_max_nchannels =
-        has_module() ? module_config()
-                           .debug_options()
-                           .xla_gpu_nccl_collective_max_nchannels()
-                     : 0;
-    int64_t p2p_max_nchannels =
-        has_module()
-            ? module_config().debug_options().xla_gpu_nccl_p2p_max_nchannels()
-            : 0;
-
-    return ExecuteThunks(
-        module_name_, unique_id, *thunks_, executable_source, run_options,
-        buffer_allocations, block_host_until_done,
-        /*use_highest_priority_for_async_stream*/
-        has_module() ? module_config()
-                           .debug_options()
-                           .xla_gpu_enable_highest_priority_async_stream()
-                     : false,
-        execution_stream_ids_, collective_max_nchannels, p2p_max_nchannels,
-        module_annotations_);
-  }
-
-  return FailedPrecondition("Expected XLA gpu executable is not supplied.");
-}
-
 int64_t GpuExecutable::SizeOfGeneratedCodeInBytes() const {
   // Non-empty PTX but empty cubin: compilation must have failed, return
   // "unknown".
9 changes: 0 additions & 9 deletions third_party/xla/xla/service/gpu/gpu_executable.h
@@ -177,15 +177,6 @@ class GpuExecutable : public Executable {
   // Use GpuExecutable::Create() to create an instance.
   explicit GpuExecutable(Params params);
 
-  // If `block_host_until_done` is false, execution will not block the host
-  // until the kernels have completed. This is used as an optimization for
-  // clients, such as Tensorflow, that use a single stream of execution for
-  // computations, and allow host-side deallocation from the allocator before
-  // GPU execution completes.
-  absl::Status ExecuteThunksOrXlaRuntime(
-      const ServiceExecutableRunOptions* run_options,
-      const BufferAllocations& buffer_allocations, bool block_host_until_done);
-
   using BufferAllocToDeviceMemoryMap =
       absl::flat_hash_map<BufferAllocation::Index, se::DeviceMemoryBase>;
 
