diff --git a/tensorflow/core/data/service/client/data_service_client.cc b/tensorflow/core/data/service/client/data_service_client.cc index 15e0e519575354..dbd65ea37d7a61 100644 --- a/tensorflow/core/data/service/client/data_service_client.cc +++ b/tensorflow/core/data/service/client/data_service_client.cc @@ -42,6 +42,7 @@ limitations under the License. #include "tensorflow/core/data/utils.h" #include "tensorflow/core/framework/allocator.h" #include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/framework/device_base.h" #include "tensorflow/core/framework/metrics.h" #include "tensorflow/core/framework/model.h" #include "tensorflow/core/platform/env.h" @@ -102,7 +103,10 @@ DataServiceClient::~DataServiceClient() { << iteration_client_id_; } -Status DataServiceClient::Initialize(Allocator* allocator) { +Status DataServiceClient::Initialize( + const DeviceBase::AcceleratorDeviceInfo* accelerator_device_info, + Allocator* allocator) { + accelerator_device_info_ = accelerator_device_info; allocator_ = allocator; TF_RETURN_IF_ERROR(ValidateDataServiceParams(params_)); VLOG(3) << "Connecting to " << params_.address @@ -343,7 +347,7 @@ DataServiceClient::CreateWorkerClient(const std::string& protocol, TF_ASSIGN_OR_RETURN(DataTransferServerInfo transfer_server, GetTransferServer(protocol, task_info)); return CreateDataServiceWorkerClient(params_.protocol, transfer_server, - allocator_); + accelerator_device_info_, allocator_); } absl::StatusOr> @@ -356,7 +360,7 @@ DataServiceClient::CreateAlternativeWorkerClientWithGrpcFallback( const DataTransferServerInfo& transfer_server, const TaskInfo& task_info) { absl::StatusOr> worker = CreateDataServiceWorkerClient(params_.protocol, transfer_server, - allocator_); + accelerator_device_info_, allocator_); if (worker.ok()) { LOG(INFO) << "Successfully started client for data transfer protocol '" << transfer_server.protocol() << "' for worker '" @@ -383,7 +387,8 @@ DataServiceClient::CreateWorkerClient(const TaskInfo& task_info) { DataTransferServerInfo info; info.set_protocol(kLocalTransferProtocol); info.set_address(task_info.worker_address()); - return CreateDataServiceWorkerClient(params_.protocol, info, allocator_); + return CreateDataServiceWorkerClient(params_.protocol, info, + accelerator_device_info_, allocator_); } if (!params_.data_transfer_protocol.empty()) { TF_ASSIGN_OR_RETURN( diff --git a/tensorflow/core/data/service/client/data_service_client.h b/tensorflow/core/data/service/client/data_service_client.h index 5c8e3d82fa92d1..0faa2a9ee19be3 100644 --- a/tensorflow/core/data/service/client/data_service_client.h +++ b/tensorflow/core/data/service/client/data_service_client.h @@ -80,7 +80,9 @@ class DataServiceClient { DataServiceClient& operator=(const DataServiceClient&) = delete; // Initializes the client. - Status Initialize(Allocator* allocator); + Status Initialize( + const DeviceBase::AcceleratorDeviceInfo* accelerator_device_info, + Allocator* allocator); // Reads the next element from tf.data workers. Blocks if the next element is // not ready. @@ -246,6 +248,7 @@ class DataServiceClient { int64_t job_id_; int64_t iteration_client_id_; std::unique_ptr dispatcher_; + const DeviceBase::AcceleratorDeviceInfo* accelerator_device_info_; Allocator* allocator_; int64_t get_next_index_ TF_GUARDED_BY(mu_) = 0; diff --git a/tensorflow/core/data/service/client/data_service_client_test.cc b/tensorflow/core/data/service/client/data_service_client_test.cc index 07e7ca0ad9fa40..8ec654b33eabde 100644 --- a/tensorflow/core/data/service/client/data_service_client_test.cc +++ b/tensorflow/core/data/service/client/data_service_client_test.cc @@ -134,7 +134,8 @@ TEST(DataServiceClientTest, NoSharding) { DataServiceParams params = GetDataServiceParams( dataset_id, test_cluster.DispatcherAddress(), ProcessingModeDef::OFF); DataServiceClient client(params); - TF_ASSERT_OK(client.Initialize(/*allocator=*/nullptr)); + TF_ASSERT_OK(client.Initialize(/*accelerator_device_info=*/nullptr, + /*allocator=*/nullptr)); EXPECT_THAT(GetResults(client), IsOkAndHolds(ElementsAreArray(Range(10)))); client.Cancel(); @@ -150,7 +151,8 @@ TEST(DataServiceClientTest, DynamicSharding) { DataServiceParams params = GetDataServiceParams( dataset_id, test_cluster.DispatcherAddress(), ProcessingModeDef::DYNAMIC); DataServiceClient client(params); - TF_ASSERT_OK(client.Initialize(/*allocator=*/nullptr)); + TF_ASSERT_OK(client.Initialize(/*accelerator_device_info=*/nullptr, + /*allocator=*/nullptr)); EXPECT_THAT(GetResults(client), IsOkAndHolds(UnorderedElementsAreArray(Range(10)))); client.Cancel(); @@ -167,7 +169,8 @@ TEST(DataServiceClientTest, StaticSharding) { GetDataServiceParams(dataset_id, test_cluster.DispatcherAddress(), ProcessingModeDef::FILE_OR_DATA); DataServiceClient client(params); - TF_ASSERT_OK(client.Initialize(/*allocator=*/nullptr)); + TF_ASSERT_OK(client.Initialize(/*accelerator_device_info=*/nullptr, + /*allocator=*/nullptr)); EXPECT_THAT(GetResults(client), IsOkAndHolds(UnorderedElementsAreArray(Range(10)))); client.Cancel(); @@ -183,7 +186,8 @@ TEST(DataServiceClientTest, RecordBufferEvents) { DataServiceParams params = GetDataServiceParams( dataset_id, test_cluster.DispatcherAddress(), ProcessingModeDef::OFF); DataServiceClient client(params); - TF_ASSERT_OK(client.Initialize(/*allocator=*/nullptr)); + TF_ASSERT_OK(client.Initialize(/*accelerator_device_info=*/nullptr, + /*allocator=*/nullptr)); auto mock_context = std::make_unique(); TestDataServiceContext* ctx = mock_context.get(); @@ -206,7 +210,8 @@ TEST(DataServiceClientTest, Cancel) { DataServiceParams params = GetDataServiceParams( dataset_id, test_cluster.DispatcherAddress(), ProcessingModeDef::OFF); DataServiceClient client(params); - TF_ASSERT_OK(client.Initialize(/*allocator=*/nullptr)); + TF_ASSERT_OK(client.Initialize(/*accelerator_device_info=*/nullptr, + /*allocator=*/nullptr)); client.Cancel(); EXPECT_THAT(client.GetNext(GetTestDataServiceContext), StatusIs(error::CANCELLED)); @@ -218,7 +223,8 @@ TEST(DataServiceClientTest, ValidationError) { params.target_workers = TARGET_WORKERS_LOCAL; DataServiceClient client(params); EXPECT_THAT( - client.Initialize(/*allocator=*/nullptr), + client.Initialize(/*accelerator_device_info=*/nullptr, + /*allocator=*/nullptr), StatusIs( error::INVALID_ARGUMENT, HasSubstr( diff --git a/tensorflow/core/data/service/data_transfer.h b/tensorflow/core/data/service/data_transfer.h index c769ae40f69a6d..ac3c6f68e95140 100644 --- a/tensorflow/core/data/service/data_transfer.h +++ b/tensorflow/core/data/service/data_transfer.h @@ -70,6 +70,7 @@ class DataTransferClient { struct Config { absl::string_view protocol; std::string address; + const DeviceBase::AcceleratorDeviceInfo* accelerator_device_info; Allocator* allocator; }; using ClientFactoryT = diff --git a/tensorflow/core/data/service/test_cluster.h b/tensorflow/core/data/service/test_cluster.h index 4a3ee4b1f608e1..2e1fb24e16753b 100644 --- a/tensorflow/core/data/service/test_cluster.h +++ b/tensorflow/core/data/service/test_cluster.h @@ -168,7 +168,8 @@ DatasetClient::DatasetClient(const TestCluster& cluster) for (size_t i = 0; i < cluster.NumWorkers(); ++i) { worker_clients_[cluster_.WorkerAddress(i)] = std::make_unique( - cluster_.WorkerAddress(i), "grpc", "grpc", /*allocator=*/nullptr); + cluster_.WorkerAddress(i), "grpc", "grpc", + /*accelerator_device_info=*/nullptr, /*allocator=*/nullptr); } } diff --git a/tensorflow/core/data/service/worker_client.cc b/tensorflow/core/data/service/worker_client.cc index e6a5091a0d3801..5b7b3facb81c27 100644 --- a/tensorflow/core/data/service/worker_client.cc +++ b/tensorflow/core/data/service/worker_client.cc @@ -38,6 +38,7 @@ limitations under the License. #include "tensorflow/core/data/service/worker_impl.h" #include "tensorflow/core/framework/allocator.h" #include "tensorflow/core/framework/dataset.pb.h" +#include "tensorflow/core/framework/device_base.h" #include "tensorflow/core/framework/metrics.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" @@ -56,11 +57,13 @@ namespace tensorflow { namespace data { StatusOr> -CreateDataServiceWorkerClient(const std::string& dispatcher_protocol, - const DataTransferServerInfo& info, - Allocator* allocator) { +CreateDataServiceWorkerClient( + const std::string& dispatcher_protocol, const DataTransferServerInfo& info, + const DeviceBase::AcceleratorDeviceInfo* accelerator_device_info, + Allocator* allocator) { auto client = std::make_unique( - info.address(), dispatcher_protocol, info.protocol(), allocator); + info.address(), dispatcher_protocol, info.protocol(), + accelerator_device_info, allocator); TF_RETURN_IF_ERROR(client->Initialize()); TF_RETURN_WITH_CONTEXT_IF_ERROR( client->CheckCompatibility(info.compatibility_info()), @@ -82,7 +85,8 @@ Status DataServiceWorkerClient::EnsureInitialized() { return absl::OkStatus(); } TF_RETURN_IF_ERROR(DataTransferClient::Build( - GetDataTransferProtocol(), {protocol_, address_, allocator_}, &client_)); + GetDataTransferProtocol(), + {protocol_, address_, accelerator_device_info_, allocator_}, &client_)); return absl::OkStatus(); } diff --git a/tensorflow/core/data/service/worker_client.h b/tensorflow/core/data/service/worker_client.h index f1bcb3e887c4da..b2de89dce25c88 100644 --- a/tensorflow/core/data/service/worker_client.h +++ b/tensorflow/core/data/service/worker_client.h @@ -37,12 +37,14 @@ constexpr const char kGrpcTransferProtocol[] = "grpc"; // Client for communicating with the tf.data service worker. class DataServiceWorkerClient : public DataServiceClientBase { public: - DataServiceWorkerClient(const std::string& address, - const std::string& protocol, - const std::string& transfer_protocol, - Allocator* allocator) + DataServiceWorkerClient( + const std::string& address, const std::string& protocol, + const std::string& transfer_protocol, + const DeviceBase::AcceleratorDeviceInfo* accelerator_device_info, + Allocator* allocator) : DataServiceClientBase(address, protocol), transfer_protocol_(transfer_protocol), + accelerator_device_info_(accelerator_device_info), allocator_(allocator) {} // Fetches an element from the worker. @@ -66,6 +68,7 @@ class DataServiceWorkerClient : public DataServiceClientBase { private: std::string transfer_protocol_; + const DeviceBase::AcceleratorDeviceInfo* accelerator_device_info_; Allocator* allocator_; mutex mu_; @@ -77,9 +80,10 @@ class DataServiceWorkerClient : public DataServiceClientBase { // Creates and initializes a new tf.data service worker client to read // from the data transfer server specified in `info`. StatusOr> -CreateDataServiceWorkerClient(const std::string& dispatcher_protocol, - const DataTransferServerInfo& info, - Allocator* allocator); +CreateDataServiceWorkerClient( + const std::string& dispatcher_protocol, const DataTransferServerInfo& info, + const DeviceBase::AcceleratorDeviceInfo* accelerator_device_info, + Allocator* allocator); } // namespace data } // namespace tensorflow diff --git a/tensorflow/core/data/service/worker_client_test.cc b/tensorflow/core/data/service/worker_client_test.cc index eab30984c321ac..3974cdb83aec38 100644 --- a/tensorflow/core/data/service/worker_client_test.cc +++ b/tensorflow/core/data/service/worker_client_test.cc @@ -107,6 +107,7 @@ class WorkerClientTest : public ::testing::Test { info.set_address(GetWorkerAddress()); info.set_protocol(data_transfer_protocol); return CreateDataServiceWorkerClient(kProtocol, info, + /*accelerator_device_info=*/nullptr, /*allocator=*/nullptr); } diff --git a/tensorflow/core/framework/dataset.h b/tensorflow/core/framework/dataset.h index 3f1061290ec1e2..effa997eca60f9 100644 --- a/tensorflow/core/framework/dataset.h +++ b/tensorflow/core/framework/dataset.h @@ -667,7 +667,8 @@ class IteratorContext { public: struct Params { explicit Params(IteratorContext* ctx) - : allocator_getter(ctx->allocator_getter()), + : accelerator_device_info(ctx->accelerator_device_info()), + allocator_getter(ctx->allocator_getter()), cancellation_manager(ctx->cancellation_manager()), collective_executor(ctx->collective_executor()), env(ctx->env()), @@ -697,6 +698,7 @@ class IteratorContext { // NOTE: need reinterpret_cast because function.h forward-declares Device. DeviceBase* device = reinterpret_cast(ctx->function_library()->device()); + accelerator_device_info = device->tensorflow_accelerator_device_info(); allocator_getter = [device](AllocatorAttributes attrs) { return device->GetAllocator(attrs); }; @@ -719,6 +721,9 @@ class IteratorContext { *ctx->runner(), std::placeholders::_1); } + // If non-null, information about the GPU or TPU on which the op is placed. + const DeviceBase::AcceleratorDeviceInfo* accelerator_device_info = nullptr; + // The Allocator to be used to allocate the output of an iterator. std::function allocator_getter = nullptr; @@ -825,6 +830,10 @@ class IteratorContext { return params_.id_registry; } + const DeviceBase::AcceleratorDeviceInfo* accelerator_device_info() { + return params_.accelerator_device_info; + } + Allocator* allocator(AllocatorAttributes attrs) { return params_.allocator_getter(attrs); } diff --git a/tensorflow/core/kernels/data/experimental/data_service_dataset_op.cc b/tensorflow/core/kernels/data/experimental/data_service_dataset_op.cc index f5086454228c12..951b53a3b8504b 100644 --- a/tensorflow/core/kernels/data/experimental/data_service_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/data_service_dataset_op.cc @@ -340,7 +340,8 @@ class DataServiceDatasetOp::Dataset : public DatasetBase { TF_RETURN_IF_ERROR(RegisterCancellationCallback( ctx->cancellation_manager(), [this]() { data_service_client_.Cancel(); }, &deregister_fn_)); - return data_service_client_.Initialize(ctx->allocator(/*attrs=*/{})); + return data_service_client_.Initialize(ctx->accelerator_device_info(), + ctx->allocator(/*attrs=*/{})); } Status GetNextInternal(IteratorContext* ctx,