Create new PjRt GPU client with remote devices after coordination service agent is available

This does not modify the previous PjRt GPU client. The test-environment case, in which multiple threads simulate multiple workers, is also supported.

For the PjRt GPU MultiWorkerMirroredStrategy (MWMS) case, where there are multiple workers each with one or more GPUs, a client without this fix fails with an error message like the following, where the number at the end is the ID of the first remote device: e.g. 8 when there are eight local GPUs with IDs 0..7, or 1 when there is one local GPU with ID 0.
`INVALID_ARGUMENT: No matching device found for device_id 8`
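To make the failure mode concrete, here is a minimal C++ sketch of why a client that only knows its local GPUs rejects the first remote device ID. It assumes, for illustration only, that global device IDs are assigned contiguously in worker (task) order, which matches the example above (local IDs 0..7, first remote ID 8). `VisibleDeviceIds`, `LookupDevice`, and `DeviceLookupResult` are hypothetical helpers, not TensorFlow, XLA, or PjRt APIs.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical result type for a device lookup (not a TensorFlow/XLA type).
struct DeviceLookupResult {
  bool ok;
  std::string message;  // Empty on success; an error string on failure.
};

// Hypothetical helper. Assumes each worker ("task") owns `gpus_per_task` GPUs
// and global IDs are assigned contiguously in task order: task 0 owns
// 0..gpus_per_task-1, task 1 owns the next block, and so on. A client built
// without remote devices only sees its own task's block.
std::vector<int> VisibleDeviceIds(int task_id, int num_tasks, int gpus_per_task,
                                  bool include_remote) {
  assert(gpus_per_task > 0 && num_tasks > 0);
  assert(task_id >= 0 && task_id < num_tasks);
  std::vector<int> ids;
  const int first = include_remote ? 0 : task_id * gpus_per_task;
  const int last =
      include_remote ? num_tasks * gpus_per_task : first + gpus_per_task;
  for (int id = first; id < last; ++id) ids.push_back(id);
  return ids;
}

// Hypothetical lookup: searches the visible IDs and, if the requested ID is
// unknown, builds an error string modeled on the message quoted above.
DeviceLookupResult LookupDevice(const std::vector<int>& visible_ids,
                                int device_id) {
  for (int id : visible_ids) {
    if (id == device_id) return {true, ""};
  }
  return {false, "INVALID_ARGUMENT: No matching device found for device_id " +
                     std::to_string(device_id)};
}
```

For worker 0 with eight GPUs per worker, `LookupDevice(VisibleDeviceIds(0, 2, 8, false), 8)` fails with the message quoted above, while the same lookup succeeds once `include_remote` is true.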

Note that while the primary purpose of `BaseGPUDeviceFactory::CreateDevices` is one-time initialization, it is often called multiple times. In the typical production MWMS case, it is called both when a TF Context is created and when gRPC servers are created to enable collectives. In the unit test added by this CL (which uses two GPUs), it is first called when a TF Context is created at test-environment startup (with both GPUs); then two worker TF processes are created, each of which calls it twice during MWMS startup (each with the one GPU assigned to that process). Other test cases follow different patterns.
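The repeated-call pattern described above can be sketched as follows. This is a hypothetical model, not the actual `BaseGPUDeviceFactory` code: `FakeGpuDeviceFactory`, `FakeGpuClient`, and the `agent_available` flag are invented for illustration. It assumes the expensive setup should run only on the first call, and that once the coordination service agent is available a new client covering remote devices replaces the local-only one without modifying it, in the spirit of the commit title.

```cpp
#include <cassert>
#include <memory>
#include <mutex>
#include <vector>

// Hypothetical stand-in for a device client (not xla::PjRtClient): it only
// records the global device IDs it knows about.
struct FakeGpuClient {
  std::vector<int> device_ids;
  bool has_remote_devices = false;
};

// Hypothetical factory whose CreateDevices() may be called many times.
class FakeGpuDeviceFactory {
 public:
  std::shared_ptr<const FakeGpuClient> CreateDevices(int task_id, int num_tasks,
                                                     int gpus_per_task,
                                                     bool agent_available) {
    assert(gpus_per_task > 0 && task_id >= 0 && task_id < num_tasks);
    std::lock_guard<std::mutex> lock(mu_);
    ++num_calls_;
    if (client_ == nullptr) {
      // One-time initialization: only this task's local GPUs are known.
      auto local = std::make_shared<FakeGpuClient>();
      for (int i = 0; i < gpus_per_task; ++i) {
        local->device_ids.push_back(task_id * gpus_per_task + i);
      }
      client_ = local;
    }
    if (agent_available && !client_->has_remote_devices) {
      // Build a new client that sees every task's GPUs. The earlier
      // local-only client object is left unchanged for existing holders.
      auto global = std::make_shared<FakeGpuClient>();
      for (int id = 0; id < num_tasks * gpus_per_task; ++id) {
        global->device_ids.push_back(id);
      }
      global->has_remote_devices = true;
      client_ = global;
    }
    return client_;
  }

  int num_calls() const {
    std::lock_guard<std::mutex> lock(mu_);
    return num_calls_;
  }

 private:
  mutable std::mutex mu_;
  std::shared_ptr<FakeGpuClient> client_;
  int num_calls_ = 0;
};
```

Returning a `std::shared_ptr` means callers that obtained the earlier local-only client keep a valid, unchanged object, while later calls observe the client with remote devices.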

PiperOrigin-RevId: 611621382
SeeForTwo authored and tensorflower-gardener committed Mar 20, 2024
1 parent 590adf8 commit 4310720
Showing 13 changed files with 737 additions and 54 deletions.
43 changes: 42 additions & 1 deletion tensorflow/core/common_runtime/eager/BUILD
@@ -1,5 +1,8 @@
load("@bazel_skylib//lib:selects.bzl", "selects")
load(
"//tensorflow:tensorflow.bzl",
"clean_dep",
"if_google",
"if_zendnn",
"tf_cc_test",
"tf_cc_test_mkl",
@@ -234,12 +237,50 @@ tf_cuda_library(
"//tensorflow/core/distributed_runtime/eager:cluster_function_library_runtime",
"//tensorflow/core/distributed_runtime/eager:eager_client",
"//tensorflow/core/distributed_runtime/eager:remote_mgr",
"//tensorflow/core/tfrt/common:global_state",
"//tensorflow/core/tfrt/common:pjrt_state",
"//tensorflow/core/tfrt/common:pjrt_util",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/log",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/time",
"@local_tsl//tsl/distributed_runtime/coordination:coordination_service",
"@local_tsl//tsl/distributed_runtime/coordination:coordination_service_agent",
"@local_tsl//tsl/distributed_runtime/coordination:coordination_service_impl",
"@local_tsl//tsl/distributed_runtime/preemption:preemption_notifier",
"@local_tsl//tsl/platform:mutex",
"@local_tsl//tsl/platform:statusor",
"@local_xla//xla/pjrt:pjrt_stream_executor_client",
"@local_xla//xla/pjrt/gpu:se_gpu_pjrt_client",
],
}),
}) + if_google(
# TODO(b/282068262): PJRT pulls in TFRT components that are incompatible with ARM platform.
# Clean up so that PJRT can run on ARM (and remove "#if defined(PLATFORM_GOOGLE) ..." use
# from gpu_util.cc).
# Also it won't build with WeightWatcher which tracks OSS build binaries.
# TODO(b/290533709): Clean up this build rule.
selects.with_or({
clean_dep("//tensorflow:linux_x86_64_with_weightwatcher"): [],
(
clean_dep("//tensorflow:linux_x86_64"),
clean_dep("//tensorflow:haswell"),
): [
"//tensorflow/core",
"//tensorflow/core/framework:resource_base",
"@local_xla//xla/pjrt/distributed:key_value_store_interface",
"@local_xla//xla/pjrt:local_device_state",
"@local_xla//xla/pjrt:pjrt_client",
"@local_xla//xla/pjrt:pjrt_compiler",
"@local_xla//xla/service/gpu:gpu_executable_run_options",
"//tensorflow/core/common_runtime/gpu:gpu_runtime_impl",
],
"//conditions:default": [],
}),
),
)

tf_cc_test(