[go: nahoru, domu]

Skip to content

Commit

Permalink
#tf-data-service Pass accelerator device info to alternative data transfer clients.
Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 632229720
  • Loading branch information
mpcallanan authored and tensorflower-gardener committed May 9, 2024
1 parent 76b909c commit 1187fe2
Show file tree
Hide file tree
Showing 10 changed files with 61 additions and 26 deletions.
13 changes: 9 additions & 4 deletions tensorflow/core/data/service/client/data_service_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ limitations under the License.
#include "tensorflow/core/data/utils.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/metrics.h"
#include "tensorflow/core/framework/model.h"
#include "tensorflow/core/platform/env.h"
Expand Down Expand Up @@ -102,7 +103,10 @@ DataServiceClient::~DataServiceClient() {
<< iteration_client_id_;
}

Status DataServiceClient::Initialize(Allocator* allocator) {
Status DataServiceClient::Initialize(
const DeviceBase::AcceleratorDeviceInfo* accelerator_device_info,
Allocator* allocator) {
accelerator_device_info_ = accelerator_device_info;
allocator_ = allocator;
TF_RETURN_IF_ERROR(ValidateDataServiceParams(params_));
VLOG(3) << "Connecting to " << params_.address
Expand Down Expand Up @@ -343,7 +347,7 @@ DataServiceClient::CreateWorkerClient(const std::string& protocol,
TF_ASSIGN_OR_RETURN(DataTransferServerInfo transfer_server,
GetTransferServer(protocol, task_info));
return CreateDataServiceWorkerClient(params_.protocol, transfer_server,
allocator_);
accelerator_device_info_, allocator_);
}

absl::StatusOr<std::unique_ptr<DataServiceWorkerClient>>
Expand All @@ -356,7 +360,7 @@ DataServiceClient::CreateAlternativeWorkerClientWithGrpcFallback(
const DataTransferServerInfo& transfer_server, const TaskInfo& task_info) {
absl::StatusOr<std::unique_ptr<DataServiceWorkerClient>> worker =
CreateDataServiceWorkerClient(params_.protocol, transfer_server,
allocator_);
accelerator_device_info_, allocator_);
if (worker.ok()) {
LOG(INFO) << "Successfully started client for data transfer protocol '"
<< transfer_server.protocol() << "' for worker '"
Expand All @@ -383,7 +387,8 @@ DataServiceClient::CreateWorkerClient(const TaskInfo& task_info) {
DataTransferServerInfo info;
info.set_protocol(kLocalTransferProtocol);
info.set_address(task_info.worker_address());
return CreateDataServiceWorkerClient(params_.protocol, info, allocator_);
return CreateDataServiceWorkerClient(params_.protocol, info,
accelerator_device_info_, allocator_);
}
if (!params_.data_transfer_protocol.empty()) {
TF_ASSIGN_OR_RETURN(
Expand Down
5 changes: 4 additions & 1 deletion tensorflow/core/data/service/client/data_service_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,9 @@ class DataServiceClient {
DataServiceClient& operator=(const DataServiceClient&) = delete;

// Initializes the client.
Status Initialize(Allocator* allocator);
Status Initialize(
const DeviceBase::AcceleratorDeviceInfo* accelerator_device_info,
Allocator* allocator);

// Reads the next element from tf.data workers. Blocks if the next element is
// not ready.
Expand Down Expand Up @@ -246,6 +248,7 @@ class DataServiceClient {
int64_t job_id_;
int64_t iteration_client_id_;
std::unique_ptr<DataServiceDispatcherClient> dispatcher_;
const DeviceBase::AcceleratorDeviceInfo* accelerator_device_info_;
Allocator* allocator_;

int64_t get_next_index_ TF_GUARDED_BY(mu_) = 0;
Expand Down
18 changes: 12 additions & 6 deletions tensorflow/core/data/service/client/data_service_client_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,8 @@ TEST(DataServiceClientTest, NoSharding) {
DataServiceParams params = GetDataServiceParams(
dataset_id, test_cluster.DispatcherAddress(), ProcessingModeDef::OFF);
DataServiceClient client(params);
TF_ASSERT_OK(client.Initialize(/*allocator=*/nullptr));
TF_ASSERT_OK(client.Initialize(/*accelerator_device_info=*/nullptr,
/*allocator=*/nullptr));
EXPECT_THAT(GetResults<int64_t>(client),
IsOkAndHolds(ElementsAreArray(Range(10))));
client.Cancel();
Expand All @@ -150,7 +151,8 @@ TEST(DataServiceClientTest, DynamicSharding) {
DataServiceParams params = GetDataServiceParams(
dataset_id, test_cluster.DispatcherAddress(), ProcessingModeDef::DYNAMIC);
DataServiceClient client(params);
TF_ASSERT_OK(client.Initialize(/*allocator=*/nullptr));
TF_ASSERT_OK(client.Initialize(/*accelerator_device_info=*/nullptr,
/*allocator=*/nullptr));
EXPECT_THAT(GetResults<int64_t>(client),
IsOkAndHolds(UnorderedElementsAreArray(Range(10))));
client.Cancel();
Expand All @@ -167,7 +169,8 @@ TEST(DataServiceClientTest, StaticSharding) {
GetDataServiceParams(dataset_id, test_cluster.DispatcherAddress(),
ProcessingModeDef::FILE_OR_DATA);
DataServiceClient client(params);
TF_ASSERT_OK(client.Initialize(/*allocator=*/nullptr));
TF_ASSERT_OK(client.Initialize(/*accelerator_device_info=*/nullptr,
/*allocator=*/nullptr));
EXPECT_THAT(GetResults<int64_t>(client),
IsOkAndHolds(UnorderedElementsAreArray(Range(10))));
client.Cancel();
Expand All @@ -183,7 +186,8 @@ TEST(DataServiceClientTest, RecordBufferEvents) {
DataServiceParams params = GetDataServiceParams(
dataset_id, test_cluster.DispatcherAddress(), ProcessingModeDef::OFF);
DataServiceClient client(params);
TF_ASSERT_OK(client.Initialize(/*allocator=*/nullptr));
TF_ASSERT_OK(client.Initialize(/*accelerator_device_info=*/nullptr,
/*allocator=*/nullptr));

auto mock_context = std::make_unique<TestDataServiceContext>();
TestDataServiceContext* ctx = mock_context.get();
Expand All @@ -206,7 +210,8 @@ TEST(DataServiceClientTest, Cancel) {
DataServiceParams params = GetDataServiceParams(
dataset_id, test_cluster.DispatcherAddress(), ProcessingModeDef::OFF);
DataServiceClient client(params);
TF_ASSERT_OK(client.Initialize(/*allocator=*/nullptr));
TF_ASSERT_OK(client.Initialize(/*accelerator_device_info=*/nullptr,
/*allocator=*/nullptr));
client.Cancel();
EXPECT_THAT(client.GetNext(GetTestDataServiceContext),
StatusIs(error::CANCELLED));
Expand All @@ -218,7 +223,8 @@ TEST(DataServiceClientTest, ValidationError) {
params.target_workers = TARGET_WORKERS_LOCAL;
DataServiceClient client(params);
EXPECT_THAT(
client.Initialize(/*allocator=*/nullptr),
client.Initialize(/*accelerator_device_info=*/nullptr,
/*allocator=*/nullptr),
StatusIs(
error::INVALID_ARGUMENT,
HasSubstr(
Expand Down
1 change: 1 addition & 0 deletions tensorflow/core/data/service/data_transfer.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ class DataTransferClient {
// Settings used to build a data transfer client.
// NOTE(review): in this change the worker client brace-initializes these four
// fields, in order, when calling DataTransferClient::Build — confirm that call
// is the only producer.
struct Config {
// Protocol name. Held as a non-owning view, so the string it refers to must
// stay alive for as long as this Config is used.
absl::string_view protocol;
// Address the client should talk to (owned copy).
std::string address;
// Accelerator device info forwarded by the caller. Callers visible in this
// change pass nullptr in tests, so consumers should expect a null value.
const DeviceBase::AcceleratorDeviceInfo* accelerator_device_info;
// Allocator forwarded by the caller. Tests in this change pass nullptr, so
// consumers should expect a null value here as well.
Allocator* allocator;
};
using ClientFactoryT =
Expand Down
3 changes: 2 additions & 1 deletion tensorflow/core/data/service/test_cluster.h
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,8 @@ DatasetClient<T>::DatasetClient(const TestCluster& cluster)
for (size_t i = 0; i < cluster.NumWorkers(); ++i) {
worker_clients_[cluster_.WorkerAddress(i)] =
std::make_unique<DataServiceWorkerClient>(
cluster_.WorkerAddress(i), "grpc", "grpc", /*allocator=*/nullptr);
cluster_.WorkerAddress(i), "grpc", "grpc",
/*accelerator_device_info=*/nullptr, /*allocator=*/nullptr);
}
}

Expand Down
14 changes: 9 additions & 5 deletions tensorflow/core/data/service/worker_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ limitations under the License.
#include "tensorflow/core/data/service/worker_impl.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/dataset.pb.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/metrics.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
Expand All @@ -56,11 +57,13 @@ namespace tensorflow {
namespace data {

StatusOr<std::unique_ptr<DataServiceWorkerClient>>
CreateDataServiceWorkerClient(const std::string& dispatcher_protocol,
const DataTransferServerInfo& info,
Allocator* allocator) {
CreateDataServiceWorkerClient(
const std::string& dispatcher_protocol, const DataTransferServerInfo& info,
const DeviceBase::AcceleratorDeviceInfo* accelerator_device_info,
Allocator* allocator) {
auto client = std::make_unique<DataServiceWorkerClient>(
info.address(), dispatcher_protocol, info.protocol(), allocator);
info.address(), dispatcher_protocol, info.protocol(),
accelerator_device_info, allocator);
TF_RETURN_IF_ERROR(client->Initialize());
TF_RETURN_WITH_CONTEXT_IF_ERROR(
client->CheckCompatibility(info.compatibility_info()),
Expand All @@ -82,7 +85,8 @@ Status DataServiceWorkerClient::EnsureInitialized() {
return absl::OkStatus();
}
TF_RETURN_IF_ERROR(DataTransferClient::Build(
GetDataTransferProtocol(), {protocol_, address_, allocator_}, &client_));
GetDataTransferProtocol(),
{protocol_, address_, accelerator_device_info_, allocator_}, &client_));
return absl::OkStatus();
}

Expand Down
18 changes: 11 additions & 7 deletions tensorflow/core/data/service/worker_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,14 @@ constexpr const char kGrpcTransferProtocol[] = "grpc";
// Client for communicating with the tf.data service worker.
class DataServiceWorkerClient : public DataServiceClientBase {
public:
DataServiceWorkerClient(const std::string& address,
const std::string& protocol,
const std::string& transfer_protocol,
Allocator* allocator)
DataServiceWorkerClient(
const std::string& address, const std::string& protocol,
const std::string& transfer_protocol,
const DeviceBase::AcceleratorDeviceInfo* accelerator_device_info,
Allocator* allocator)
: DataServiceClientBase(address, protocol),
transfer_protocol_(transfer_protocol),
accelerator_device_info_(accelerator_device_info),
allocator_(allocator) {}

// Fetches an element from the worker.
Expand All @@ -66,6 +68,7 @@ class DataServiceWorkerClient : public DataServiceClientBase {

private:
std::string transfer_protocol_;
const DeviceBase::AcceleratorDeviceInfo* accelerator_device_info_;
Allocator* allocator_;

mutex mu_;
Expand All @@ -77,9 +80,10 @@ class DataServiceWorkerClient : public DataServiceClientBase {
// Creates and initializes a new tf.data service worker client to read
// from the data transfer server specified in `info`.
StatusOr<std::unique_ptr<DataServiceWorkerClient>>
CreateDataServiceWorkerClient(const std::string& dispatcher_protocol,
const DataTransferServerInfo& info,
Allocator* allocator);
CreateDataServiceWorkerClient(
const std::string& dispatcher_protocol, const DataTransferServerInfo& info,
const DeviceBase::AcceleratorDeviceInfo* accelerator_device_info,
Allocator* allocator);

} // namespace data
} // namespace tensorflow
Expand Down
1 change: 1 addition & 0 deletions tensorflow/core/data/service/worker_client_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ class WorkerClientTest : public ::testing::Test {
info.set_address(GetWorkerAddress());
info.set_protocol(data_transfer_protocol);
return CreateDataServiceWorkerClient(kProtocol, info,
/*accelerator_device_info=*/nullptr,
/*allocator=*/nullptr);
}

Expand Down
11 changes: 10 additions & 1 deletion tensorflow/core/framework/dataset.h
Original file line number Diff line number Diff line change
Expand Up @@ -667,7 +667,8 @@ class IteratorContext {
public:
struct Params {
explicit Params(IteratorContext* ctx)
: allocator_getter(ctx->allocator_getter()),
: accelerator_device_info(ctx->accelerator_device_info()),
allocator_getter(ctx->allocator_getter()),
cancellation_manager(ctx->cancellation_manager()),
collective_executor(ctx->collective_executor()),
env(ctx->env()),
Expand Down Expand Up @@ -697,6 +698,7 @@ class IteratorContext {
// NOTE: need reinterpret_cast because function.h forward-declares Device.
DeviceBase* device =
reinterpret_cast<DeviceBase*>(ctx->function_library()->device());
accelerator_device_info = device->tensorflow_accelerator_device_info();
allocator_getter = [device](AllocatorAttributes attrs) {
return device->GetAllocator(attrs);
};
Expand All @@ -719,6 +721,9 @@ class IteratorContext {
*ctx->runner(), std::placeholders::_1);
}

// If non-null, information about the GPU or TPU on which the op is placed.
const DeviceBase::AcceleratorDeviceInfo* accelerator_device_info = nullptr;

// The Allocator to be used to allocate the output of an iterator.
std::function<Allocator*(AllocatorAttributes)> allocator_getter = nullptr;

Expand Down Expand Up @@ -825,6 +830,10 @@ class IteratorContext {
return params_.id_registry;
}

// Returns the accelerator device info stored in this context's params.
// The underlying field defaults to nullptr, so the result may be null when no
// device info was provided.
const DeviceBase::AcceleratorDeviceInfo* accelerator_device_info() {
return params_.accelerator_device_info;
}

Allocator* allocator(AllocatorAttributes attrs) {
return params_.allocator_getter(attrs);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,8 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
TF_RETURN_IF_ERROR(RegisterCancellationCallback(
ctx->cancellation_manager(),
[this]() { data_service_client_.Cancel(); }, &deregister_fn_));
return data_service_client_.Initialize(ctx->allocator(/*attrs=*/{}));
return data_service_client_.Initialize(ctx->accelerator_device_info(),
ctx->allocator(/*attrs=*/{}));
}

Status GetNextInternal(IteratorContext* ctx,
Expand Down

0 comments on commit 1187fe2

Please sign in to comment.