From 78385a14e5df07618cb5fe2368bd40ce283f097f Mon Sep 17 00:00:00 2001 From: Antonio Sanchez Date: Wed, 27 Mar 2024 22:14:12 -0700 Subject: [PATCH] Make usages of Eigen::array compatible with std::array. Eigen::array is no longer necessary, so will be deprecated/removed and replaced with `std::array`. The main difference is the constructor - currently Eigen::array allows `array(a, b, c, ...)` construction, whereas `std::array` requires an initializer list. We also need to remove any direct access to the `Eigen::array::values` internal parameter, in favor of regular index access. PiperOrigin-RevId: 619787691 --- tensorflow/core/distributed_runtime/BUILD | 1 + .../core/distributed_runtime/master_test.cc | 6 +- .../core/kernels/gather_nd_op_gpu.cu.cc | 11 +- .../core/kernels/image/adjust_contrast_op.cc | 40 ++-- ...arameterized_truncated_normal_op_gpu.cu.cc | 2 - tensorflow/core/kernels/random_binomial_op.cc | 2 - .../kernels/sparse_tensor_dense_matmul_op.cc | 3 +- .../eigen_spatial_convolutions_test.cc | 4 - third_party/xla/xla/pjrt/gpu/BUILD | 1 + .../xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc | 183 +++++++++++------- .../xla/xla/pjrt/gpu/se_gpu_pjrt_client.h | 2 +- .../xla/pjrt/gpu/se_gpu_pjrt_client_test.cc | 11 +- .../xla/pjrt/pjrt_stream_executor_client.cc | 23 +-- .../xla/pjrt/pjrt_stream_executor_client.h | 3 +- .../xla/service/gpu/elemental_ir_emitter.cc | 4 +- 15 files changed, 166 insertions(+), 130 deletions(-) diff --git a/tensorflow/core/distributed_runtime/BUILD b/tensorflow/core/distributed_runtime/BUILD index 77c8c9207f8b37..87789060878b05 100644 --- a/tensorflow/core/distributed_runtime/BUILD +++ b/tensorflow/core/distributed_runtime/BUILD @@ -779,6 +779,7 @@ tf_cuda_cc_test( "//tensorflow/core/kernels:identity_op", "//tensorflow/core/kernels:variable_ops", "//tensorflow/core/protobuf:master_proto_cc", + "@eigen_archive//:eigen3", ] + tf_grpc_cc_dependencies(), ) diff --git a/tensorflow/core/distributed_runtime/master_test.cc b/tensorflow/core/distributed_runtime/master_test.cc index 5c2f17e31f819d..1e9e5545183191 100644 --- a/tensorflow/core/distributed_runtime/master_test.cc +++ b/tensorflow/core/distributed_runtime/master_test.cc @@ -19,7 +19,7 @@ limitations under the License. #include #include "grpcpp/grpcpp.h" - +#include "Eigen/Core" // from @eigen_archive #include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_testlib.h" @@ -389,8 +389,8 @@ TEST_F(MasterTest, EigenProblem) { TF_CHECK_OK(CreateSession(def, &handle, &initial_version)); // Temps supporting the computation of the convergence condition. - const Eigen::array sum_along_dim(0); - const Eigen::array matrix_transpose({1, 0}); + const Eigen::array sum_along_dim{0}; + const Eigen::array matrix_transpose{1, 0}; Tensor x(DT_FLOAT, TensorShape({2, 1})); Tensor y(DT_FLOAT, TensorShape({2, 1})); Eigen::Tensor y_square_sum; diff --git a/tensorflow/core/kernels/gather_nd_op_gpu.cu.cc b/tensorflow/core/kernels/gather_nd_op_gpu.cu.cc index 227cd311c244ee..c26f5bdf492e39 100644 --- a/tensorflow/core/kernels/gather_nd_op_gpu.cu.cc +++ b/tensorflow/core/kernels/gather_nd_op_gpu.cu.cc @@ -39,11 +39,14 @@ __global__ void GatherSliceOpKernel( const auto indices_i = indices + IXDIM * loc; bool out_of_bounds = false; Index offset = 0; + // Avoid empty std::array access, which fails to compile on GPU. 
+ if constexpr (IXDIM > 0) { #pragma unroll - for (int j = 0; j < IXDIM; ++j) { - const Index index_j = ldg(indices_i + j); - out_of_bounds |= !FastBoundsCheck(index_j, batch_indices[j]); - offset += batch_strides[j] * index_j; + for (int j = 0; j < IXDIM; ++j) { + const Index index_j = ldg(indices_i + j); + out_of_bounds |= !FastBoundsCheck(index_j, batch_indices[j]); + offset += batch_strides[j] * index_j; + } } // TODO(ebrevdo): // This is the only part that depends on the offset. The part diff --git a/tensorflow/core/kernels/image/adjust_contrast_op.cc b/tensorflow/core/kernels/image/adjust_contrast_op.cc index 7cef95b9479022..df8650ebfed515 100644 --- a/tensorflow/core/kernels/image/adjust_contrast_op.cc +++ b/tensorflow/core/kernels/image/adjust_contrast_op.cc @@ -248,6 +248,7 @@ class AdjustContrastOpv2 : public AdjustContrastOpV2Base { TTypes::Tensor mean_flat(&mean(0, 0), mean.size()); TTypes::Tensor summation_scratch(&scratch(0, 0, 0), scratch.size()); + using Eigen::DenseIndex; typedef Eigen::array Index; const int64_t plane_size = image_size * channels; // Since the number of channels in the early layers is often small, a @@ -255,10 +256,10 @@ class AdjustContrastOpv2 : public AdjustContrastOpV2Base { // This algorithm repeatedly folds each image plane by half, until // only one set of channels remains. for (int64_t i = 0; i < batch; i++) { - auto input_plane = - input_flat.slice(Index(i * plane_size), Index(plane_size)); - auto summation_plane = - summation_scratch.slice(Index(i * plane_size), Index(plane_size)); + auto input_plane = input_flat.slice(Index{DenseIndex(i * plane_size)}, + Index{DenseIndex(plane_size)}); + auto summation_plane = summation_scratch.slice( + Index{DenseIndex(i * plane_size)}, Index{DenseIndex(plane_size)}); int64_t remaining_size = image_size; int round = 0; // Sum the input(i, :, k) into mean(i, k). Repeatedly splits the input @@ -289,26 +290,29 @@ class AdjustContrastOpv2 : public AdjustContrastOpV2Base { if (round == 0) { // In the first round, sum the left side and right side of the input // array into the summation area. - summation_plane.slice(Index(0), Index(right_size * channels)) = - input_plane.slice(Index(left_size * channels), - Index(right_size * channels)) + - input_plane.slice(Index(0), Index(right_size * channels)); + summation_plane.slice(Index{0}, + Index{DenseIndex(right_size * channels)}) = + input_plane.slice(Index{DenseIndex(left_size * channels)}, + Index{DenseIndex(right_size * channels)}) + + input_plane.slice(Index{0}, + Index{DenseIndex(right_size * channels)}); if (left_size > right_size) { DCHECK_EQ(left_size - right_size, 1); // Copy over the remaining column if the remaining_size is odd. // This also handles the case where image_size == 1. - summation_plane.slice(Index(right_size * channels), - Index(channels)) = - input_plane.slice(Index(right_size * channels), - Index(channels)); + summation_plane.slice(Index{DenseIndex(right_size * channels)}, + Index{DenseIndex(channels)}) = + input_plane.slice(Index{DenseIndex(right_size * channels)}, + Index{DenseIndex(channels)}); } } else { // For all the remaining rounds, add the second half of the inputs // into the first half of the inputs. With the flat structure and // large size, this utilizes vectorization between components. 
- summation_plane.slice(Index(0), Index(right_size * channels)) += - summation_plane.slice(Index(left_size * channels), - Index(right_size * channels)); + summation_plane.slice(Index{0}, + Index{DenseIndex(right_size * channels)}) += + summation_plane.slice(Index{DenseIndex(left_size * channels)}, + Index{DenseIndex(right_size * channels)}); } remaining_size = left_size; round++; @@ -316,9 +320,11 @@ class AdjustContrastOpv2 : public AdjustContrastOpV2Base { const float mean_scaling = 1.0f / image_size; // The first channels elements in summation_plane now holds the summation. // Scale it with image_size and copy over to the means. - auto mean_plane = mean_flat.slice(Index(i * channels), Index(channels)); + auto mean_plane = mean_flat.slice(Index{DenseIndex(i * channels)}, + Index{DenseIndex(channels)}); mean_plane = - summation_plane.slice(Index(0), Index(channels)) * mean_scaling; + summation_plane.slice(Index{0}, Index{DenseIndex(channels)}) * + mean_scaling; } } diff --git a/tensorflow/core/kernels/parameterized_truncated_normal_op_gpu.cu.cc b/tensorflow/core/kernels/parameterized_truncated_normal_op_gpu.cu.cc index b826564437c0a1..e7b76653dc329e 100644 --- a/tensorflow/core/kernels/parameterized_truncated_normal_op_gpu.cu.cc +++ b/tensorflow/core/kernels/parameterized_truncated_normal_op_gpu.cu.cc @@ -132,8 +132,6 @@ __global__ void __launch_bounds__(1024) (normMax >= T(0.))) || ((normMax > kStdDevsInsideBoundsToUseRandnSampler) && (normMin <= T(0.)))) { - Eigen::array n; - int numIterations = 0; while (numIterations < kMaxIterations) { const auto randn = normal_dist(&gen); diff --git a/tensorflow/core/kernels/random_binomial_op.cc b/tensorflow/core/kernels/random_binomial_op.cc index 8fceaf70c0dbbb..98118b78eb5b58 100644 --- a/tensorflow/core/kernels/random_binomial_op.cc +++ b/tensorflow/core/kernels/random_binomial_op.cc @@ -187,8 +187,6 @@ struct RandomBinomialFunctor { &gen, &output](int64_t start_output, int64_t limit_output) { // Vectorized intermediate calculations for uniform rejection sampling. // We always generate at most 4 samples. - Eigen::array z; - Eigen::array g; const bool should_bcast = bcast.IsBroadcastingRequired(); const auto& counts_batch_indices = bcast.x_batch_indices(); const auto& probs_batch_indices = bcast.y_batch_indices(); diff --git a/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.cc b/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.cc index 04aff711362552..cb80fa34230a20 100644 --- a/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.cc +++ b/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/core/kernels/sparse_tensor_dense_matmul_op.h" +#include "Eigen/Core" // from @eigen_archive #include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" @@ -310,7 +311,7 @@ Status SparseTensorDenseMatMulImpl( if (ADJ_B) { // Perform transpose and conjugation on B once, since we chip out B's // columns in the nnz loop. 
- Eigen::array shuffle(1, 0); // preserve dimension order + Eigen::array shuffle{1, 0}; // preserve dimension order Eigen::Tensor col_major_conj_b = b.swap_layout().shuffle(shuffle).conjugate(); LOOP_NNZ(col_major_conj_b); diff --git a/third_party/xla/third_party/tsl/tsl/framework/convolution/eigen_spatial_convolutions_test.cc b/third_party/xla/third_party/tsl/tsl/framework/convolution/eigen_spatial_convolutions_test.cc index 48e0379bce1466..9f589c549901e9 100644 --- a/third_party/xla/third_party/tsl/tsl/framework/convolution/eigen_spatial_convolutions_test.cc +++ b/third_party/xla/third_party/tsl/tsl/framework/convolution/eigen_spatial_convolutions_test.cc @@ -1036,10 +1036,6 @@ static void PackLhsHelper(::testing::benchmark::State& state, reshape_dims[0] = filter_count; reshape_dims[1] = input_depth * filter_rows * filter_cols; - // We are going to contract along the 'in_depth * filter_rows * filter_cols`. - nocontract_t nocontract_dim = {0}; - contract_t contract_dim = {1}; - // These values computed using the algorithm in TensorContraction.h, with // 'nocontract_dim' and 'contract_dim' values specified above. nocontract_t nocontract_strides = {1}; diff --git a/third_party/xla/xla/pjrt/gpu/BUILD b/third_party/xla/xla/pjrt/gpu/BUILD index d440f0ace7b424..9a192f20c65aec 100644 --- a/third_party/xla/xla/pjrt/gpu/BUILD +++ b/third_party/xla/xla/pjrt/gpu/BUILD @@ -91,6 +91,7 @@ cc_library( "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/functional:bind_front", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/memory", diff --git a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc index 2b167fcb94b857..6f77844ef19cba 100644 --- a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc +++ b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc @@ -33,6 +33,7 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/container/inlined_vector.h" #include "absl/functional/any_invocable.h" +#include "absl/functional/bind_front.h" #include "absl/log/check.h" #include "absl/log/log.h" #include "absl/memory/memory.h" @@ -504,97 +505,141 @@ StreamExecutorGpuClient::GetDefaultDeviceAssignment(int num_replicas, } PjRtFuture<> StreamExecutorGpuClient::CopyRawSubBufferToHost( - PjRtBuffer* pjrt_buffer, void* dst, int64_t offset, int64_t transfer_size) { + PjRtBuffer* pjrt_buffer, PjRtFuture dst, int64_t offset, + int64_t transfer_size) { auto* buffer = tensorflow::down_cast(pjrt_buffer); DCHECK(buffer); PjRtStreamExecutorDevice* device = buffer->device(); LocalDeviceState* local_device = device->local_device_state(); se::Stream* stream = local_device->GetDeviceToHostStream(); + // Acquire the usage hold inline so that the buffer is kept alive even if + // `dst` is not immediately available. 
PjRtStreamExecutorBuffer::ScopedHold hold(buffer->GetBufferWithUsageHold()); if (!hold.ok()) { return PjRtFuture<>(hold.status()); } + auto device_buffer = hold.buffer(); if (device_buffer->device_memory().size() != 1) { return PjRtFuture<>(InvalidArgument("Copy raw buffer called on tuple")); } - auto& device_memory = device_buffer->device_memory()[0]; - if (offset < 0 || offset > device_memory.size() || - device_memory.size() - offset < transfer_size) { - return PjRtFuture<>( - InvalidArgument("Copy raw buffer called on buffer size %lld with " - "invalid offset %lld, transfer size %lld", - device_memory.size(), offset, transfer_size)); - } - WaitForBufferDefinitionEventsOnStream(*device_buffer, stream); - absl::StatusOr event_or = - local_device->event_pool().AllocateEvent(stream->parent()); - if (!event_or.ok()) { - return PjRtFuture<>(event_or.status()); - } - std::unique_ptr sub_buffer; - if (transfer_size < device_memory.size()) { - sub_buffer = std::make_unique( - device_memory.GetByteSlice(offset, transfer_size)); - } else { - sub_buffer = std::make_unique(device_memory); - } + auto promise = PjRtFuture<>::CreatePromise(); + auto usage_event = + std::make_shared(this->thread_pool()); - if (transfer_size != 0) { - if (should_stage_host_to_device_transfers()) { - if (host_memory_allocator() == nullptr) { - return PjRtFuture<>(InvalidArgument( - "host_memory_allocator should be initialized for staging buffer " - "transfer.")); - } - void* ptr = host_memory_allocator()->AllocateRaw( - tsl::Allocator::kAllocatorAlignment, transfer_size); + // When using the ComputeSynchronized allocation model, retain a reference to + // the device_buffer until the copy completes, to ensure that the buffer isn't + // deleted or donated while it is still in use. The choice of retaining a + // reference at the host is a heuristic; the alternative is to ensure, before + // freeing the buffer, that the compute stream is synchronized past the + // transfer, but it seems better to hold onto the buffer too long than to + // stall the compute stream. 
+ hold.ConvertUsageHold(stream, usage_event, /*reference_held=*/true); + + auto async_copy = [this, promise, offset, transfer_size, stream, local_device, + device_buffer, usage_event = std::move(usage_event)]( + absl::StatusOr dst) mutable { + absl::StatusOr event = + local_device->event_pool().AllocateEvent(stream->parent()); + if (!event.ok()) { + promise.Set(event.status()); + return; + } - std::shared_ptr staging_buffer = std::shared_ptr( - ptr, [host_memory_allocator = host_memory_allocator()](void* ptr) { - host_memory_allocator->DeallocateRaw(ptr); - }); - if (auto status = - stream->Memcpy(staging_buffer.get(), *sub_buffer, transfer_size); - !status.ok()) { - return PjRtFuture<>(status); - } - auto copy_to_staging_buffer = [dst, transfer_size, - staging_buffer]() mutable { - std::memcpy(dst, staging_buffer.get(), transfer_size); - }; - if (auto status = stream->DoHostCallback(copy_to_staging_buffer); - !status.ok()) { - return PjRtFuture<>(status); - } + absl::Status defined_status = + device_buffer->definition_events()[0]->GetDefinedStatus(); + if (!defined_status.ok()) { + promise.Set(defined_status); + return; + } + + auto& device_memory = device_buffer->device_memory()[0]; + if (offset < 0 || offset > device_memory.size() || + device_memory.size() - offset < transfer_size) { + promise.Set( + InvalidArgument("Copy raw buffer called on buffer size %lld with " + "invalid offset %lld, transfer size %lld", + device_memory.size(), offset, transfer_size)); + return; + } + + std::unique_ptr sub_buffer; + if (transfer_size < device_memory.size()) { + sub_buffer = std::make_unique( + device_memory.GetByteSlice(offset, transfer_size)); } else { - // D2H request holds a non-owned pointer into sub_buffer base address - // that needs to outlive the transfer until the stream callback is - // invoked. - auto status = stream->Memcpy(dst, *sub_buffer, transfer_size); - if (!status.ok()) { - return PjRtFuture<>(status); + sub_buffer = std::make_unique(device_memory); + } + + WaitForBufferDefinitionEventsOnStream(*device_buffer, stream); + + if (transfer_size != 0) { + if (should_stage_host_to_device_transfers()) { + if (host_memory_allocator() == nullptr) { + promise.Set(InvalidArgument( + "host_memory_allocator should be initialized for staging buffer " + "transfer.")); + return; + } + void* ptr = host_memory_allocator()->AllocateRaw( + tsl::Allocator::kAllocatorAlignment, transfer_size); + + std::shared_ptr staging_buffer = std::shared_ptr( + ptr, [host_memory_allocator = host_memory_allocator()](void* ptr) { + host_memory_allocator->DeallocateRaw(ptr); + }); + if (auto status = stream->Memcpy(staging_buffer.get(), *sub_buffer, + transfer_size); + !status.ok()) { + promise.Set(std::move(status)); + return; + } + auto copy_to_staging_buffer = [dst, transfer_size, + staging_buffer]() mutable { + std::memcpy(*dst, staging_buffer.get(), transfer_size); + }; + if (auto status = stream->DoHostCallback(copy_to_staging_buffer); + !status.ok()) { + promise.Set(std::move(status)); + return; + } + } else { + // D2H request holds a non-owned pointer into sub_buffer base address + // that needs to outlive the transfer until the stream callback is + // invoked. 
+ auto status = stream->Memcpy(*dst, *sub_buffer, transfer_size); + if (!status.ok()) { + promise.Set(std::move(status)); + return; + } } } - } - auto usage_event = - std::make_shared(this->thread_pool()); - local_device->event_pool().ThenRecordEvent(stream, event_or.value()); - usage_event->SetSequencingEvent(std::move(event_or).value(), stream); - // This usage hold will prevent device_buffer from being deleted before - // the transfer is complete. - hold.ConvertUsageHold(stream, std::move(usage_event), - /*reference_held=*/false); + local_device->event_pool().ThenRecordEvent(stream, event.value()); + usage_event->SetSequencingEvent(std::move(event).value(), stream); - auto promise = PjRtFuture<>::CreatePromise(); - auto callback_status = local_device->ThenExecuteCallback( - stream, [promise]() mutable { promise.Set(); }); - if (!callback_status.ok()) { - return PjRtFuture<>(callback_status); - } + auto callback_status = local_device->ThenExecuteCallback( + stream, [promise, device_buffer = std::move(device_buffer)]() mutable { + promise.Set(); + }); + if (!callback_status.ok()) { + promise.Set(std::move(callback_status)); + return; + } + }; + + device_buffer->definition_events()[0]->ExecuteOrAddToFutureTasks( + absl::StrFormat("async_copy_raw_sub_buffer_to_host_%p", &async_copy), + [this, dst, async_copy = std::move(async_copy)]() mutable { + dst.OnReady([this, async_copy = std::move(async_copy)]( + absl::StatusOr dst) { + // Trampoline through a thread pool since GPUs do not allow calling + // D2H inside the callback's context. + thread_pool()->Schedule(absl::bind_front(async_copy, std::move(dst))); + }); + }); return PjRtFuture<>( std::move(promise), diff --git a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.h b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.h index 8436dccd1d4011..52daf66f96dfc7 100644 --- a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.h +++ b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.h @@ -196,7 +196,7 @@ class StreamExecutorGpuClient : public xla::PjRtStreamExecutorClient { CreateBuffersForAsyncHostToDevice(absl::Span shapes, PjRtDevice* device) override; - PjRtFuture<> CopyRawSubBufferToHost(PjRtBuffer* buffer, void* dst, + PjRtFuture<> CopyRawSubBufferToHost(PjRtBuffer* buffer, PjRtFuture dst, int64_t offset, int64_t transfer_size) override; diff --git a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc index 8ff02b10928335..2f2d11f2333990 100644 --- a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc +++ b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc @@ -440,12 +440,17 @@ TEST(StreamExecutorGpuClientTest, CopyRawToHostFuture) { xla::PjRtFuture dst_future(dst_promise); TF_ASSERT_OK_AND_ASSIGN(int64_t size, buffer->GetOnDeviceSizeInBytes()); - buffer->GetReadyFuture().OnReady([dst_promise = std::move(dst_promise), - size](absl::Status status) mutable { + auto ready = buffer->GetReadyFuture(); + auto result = buffer->CopyRawToHostFuture(dst_future, 0, size); + + // Drop the buffer before fulfilling `dst`. The transfer should still keep the + // buffer alive. 
+ buffer.reset(); + ready.OnReady([dst_promise = std::move(dst_promise), + size](absl::Status status) mutable { dst_promise.Set(aligned_alloc(size, 0)); }); - auto result = buffer->CopyRawToHostFuture(dst_future, 0, size); TF_EXPECT_OK(result.Await()); TF_ASSERT_OK_AND_ASSIGN(auto* dst, dst_future.Await()); EXPECT_EQ(*(static_cast(dst)), 41.0f); diff --git a/third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc b/third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc index efa0d9fe2c25aa..6175f93f2e2e74 100644 --- a/third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc +++ b/third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc @@ -1634,30 +1634,13 @@ StatusOr PjRtStreamExecutorBuffer::GetOnDeviceSizeInBytes() const { PjRtFuture<> PjRtStreamExecutorBuffer::CopyRawToHost(void* dst, int64_t offset, int64_t transfer_size) { - return client_->CopyRawSubBufferToHost(this, dst, offset, transfer_size); + return client_->CopyRawSubBufferToHost(this, PjRtFuture(dst), offset, + transfer_size); } PjRtFuture<> PjRtStreamExecutorBuffer::CopyRawToHostFuture( PjRtFuture dst, int64_t offset, int64_t transfer_size) { - auto promise = PjRtFuture<>::CreatePromise(); - dst.OnReady([this, promise, offset, - transfer_size](absl::StatusOr dst) mutable { - if (dst.ok()) { - // Trampoline through a thread pool since some device types (e.g., GPUs) - // do not allow calling D2H inside the callback's context. - client_->thread_pool()->Schedule( - [this, dst = *dst, offset, transfer_size, - promise = std::move(promise)]() mutable { - CopyRawToHost(dst, offset, transfer_size) - .OnReady([promise = std::move(promise)](Status status) mutable { - promise.Set(std::move(status)); - }); - }); - } else { - promise.Set(dst.status()); - } - }); - return PjRtFuture<>(std::move(promise)); + return client_->CopyRawSubBufferToHost(this, dst, offset, transfer_size); } StatusOr PjRtStreamExecutorBuffer::AsShapedBuffer() const { diff --git a/third_party/xla/xla/pjrt/pjrt_stream_executor_client.h b/third_party/xla/xla/pjrt/pjrt_stream_executor_client.h index 545d03e7c2c78d..58a2282d37e9a8 100644 --- a/third_party/xla/xla/pjrt/pjrt_stream_executor_client.h +++ b/third_party/xla/xla/pjrt/pjrt_stream_executor_client.h @@ -399,7 +399,8 @@ class PjRtStreamExecutorClient : public PjRtClient { } } - virtual PjRtFuture<> CopyRawSubBufferToHost(PjRtBuffer* buffer, void* dst, + virtual PjRtFuture<> CopyRawSubBufferToHost(PjRtBuffer* buffer, + PjRtFuture dst, int64_t offset, int64_t transfer_size) { return PjRtFuture<>(Unimplemented("Raw copies to host not implemented.")); diff --git a/third_party/xla/xla/service/gpu/elemental_ir_emitter.cc b/third_party/xla/xla/service/gpu/elemental_ir_emitter.cc index 6d6675f21e46ee..0ff1cc6d3bf654 100644 --- a/third_party/xla/xla/service/gpu/elemental_ir_emitter.cc +++ b/third_party/xla/xla/service/gpu/elemental_ir_emitter.cc @@ -55,8 +55,6 @@ limitations under the License. namespace xla { namespace gpu { -using absl::StrAppend; - GpuElementalIrEmitter::GpuElementalIrEmitter( IrEmitterContext& ir_emitter_context, llvm::IRBuilder<>* b) : ElementalIrEmitter(ir_emitter_context.llvm_module(), b), @@ -133,7 +131,7 @@ llvm_ir::IrArray::Index GpuElementalIrEmitter::GetSourceIndexOfBitcast( // Decode the layout of the shape from the Protobufs attached to // backend_config_. auto gpu_config = hlo->backend_config(); - CHECK(gpu_config.ok()); + CHECK_OK(gpu_config); const BitcastBackendConfig& bitcast_config = gpu_config.value().bitcast_backend_config();
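
As a quick illustration of the migration pattern described in the commit message above (a sketch with an arbitrary element type, rank, values, and variable names; not code taken from the patch):

  // Requires the Eigen Tensor headers (e.g. unsupported/Eigen/CXX11/Tensor)
  // for Eigen::array.

  // Before: Eigen::array's variadic-style constructor and direct access to
  // its internal `values` member; neither is available on std::array.
  Eigen::array<Eigen::DenseIndex, 3> old_dims(4, 2, 7);
  Eigen::DenseIndex old_d0 = old_dims.values[0];

  // After: brace initialization and regular operator[] indexing, both of
  // which compile whether Eigen::array is its own type or an alias for
  // std::array.
  Eigen::array<Eigen::DenseIndex, 3> new_dims{4, 2, 7};
  Eigen::DenseIndex new_d0 = new_dims[0];

The master_test.cc, adjust_contrast_op.cc, and sparse_tensor_dense_matmul_op.cc hunks above are direct instances of this switch from parenthesized construction to brace initialization.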