diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc index 696b96468c183c..4649162b44c195 100644 --- a/tensorflow/compiler/jit/xla_launch_util.cc +++ b/tensorflow/compiler/jit/xla_launch_util.cc @@ -410,7 +410,7 @@ Status XlaComputationLaunchContext::PopulateOutputs( if (output.on_host_shape().is_dynamic()) { const se::Platform* platform = nullptr; if (stream != nullptr) { - platform = stream->parent()->platform(); + platform = stream->parent()->GetPlatform(); } else { // Stream is not set for the host platform. TF_ASSIGN_OR_RETURN(platform, diff --git a/tensorflow/compiler/jit/xla_platform_info.cc b/tensorflow/compiler/jit/xla_platform_info.cc index 1486340e95da3b..a6b066f7460168 100644 --- a/tensorflow/compiler/jit/xla_platform_info.cc +++ b/tensorflow/compiler/jit/xla_platform_info.cc @@ -377,7 +377,7 @@ XlaPlatformInfo XlaPlatformInfoFromDevice(DeviceBase* device_base) { auto device = static_cast(device_base); platform_id = device->tensorflow_accelerator_device_info() ->stream->parent() - ->platform() + ->GetPlatform() ->id(); } else if (XlaDevice::GetMetadataFromDevice(device_base, &xla_device_metadata) .ok()) { diff --git a/third_party/xla/xla/backends/interpreter/executable_base.cc b/third_party/xla/xla/backends/interpreter/executable_base.cc index 78443e635ad737..0bb1639daa58a0 100644 --- a/third_party/xla/xla/backends/interpreter/executable_base.cc +++ b/third_party/xla/xla/backends/interpreter/executable_base.cc @@ -44,7 +44,7 @@ absl::StatusOr InterpreterExecutableBase::ExecuteAsyncOnStream( HloExecutionProfile* hlo_execution_profile) { se::Stream* stream = run_options->stream(); se::StreamExecutor* executor = stream->parent(); - const se::Platform* platform = executor->platform(); + const se::Platform* platform = executor->GetPlatform(); // Convert the ShapeTree to a ShapedBuffer. We do this so we can call // TransferManager methods below. @@ -175,7 +175,7 @@ InterpreterExecutableBase::AllocateOutputMemoryWithInputReuse( })); se::StreamExecutor* executor = stream->parent(); - const se::Platform* platform = executor->platform(); + const se::Platform* platform = executor->GetPlatform(); TF_ASSIGN_OR_RETURN(TransferManager * transfer_manager, TransferManager::GetForPlatform(platform)); diff --git a/third_party/xla/xla/client/local_client.cc b/third_party/xla/xla/client/local_client.cc index 033361e897b676..59d4ecf6a6e82d 100644 --- a/third_party/xla/xla/client/local_client.cc +++ b/third_party/xla/xla/client/local_client.cc @@ -60,7 +60,7 @@ Status LocalExecutable::ValidateExecutionOptions( // Check stream matches service platform. const se::Platform* stream_platform = - run_options.stream()->parent()->platform(); + run_options.stream()->parent()->GetPlatform(); if (stream_platform != backend_->platform()) { return InvalidArgument( "stream is for platform %s, but service targets platform %s", diff --git a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc index 817f49a87e4e34..081d57949126fd 100644 --- a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc +++ b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc @@ -916,7 +916,7 @@ Status BuildDistributedDevices( local_topology.set_boot_id(boot_id_str); for (const auto& ordinal_and_device : local_device_states) { const se::Platform* platform = - ordinal_and_device.second->executor()->platform(); + ordinal_and_device.second->executor()->GetPlatform(); TF_ASSIGN_OR_RETURN( std::unique_ptr desc, platform->DescriptionForDevice(ordinal_and_device.first)); diff --git a/third_party/xla/xla/service/compiler.cc b/third_party/xla/xla/service/compiler.cc index b0feb9a62ae811..ee78a3b661818f 100644 --- a/third_party/xla/xla/service/compiler.cc +++ b/third_party/xla/xla/service/compiler.cc @@ -31,7 +31,7 @@ namespace xla { Compiler::TargetConfig::TargetConfig(se::StreamExecutor* s) : device_description(s->GetDeviceDescription().ToGpuProto()), - platform_name(s->platform()->Name()), + platform_name(s->GetPlatform()->Name()), device_description_str(s->GetDeviceDescription().name()) { se::dnn::DnnSupport* dnn = s->AsDnn(); if (dnn != nullptr) { diff --git a/third_party/xla/xla/service/gpu/autotuner_compile_util.cc b/third_party/xla/xla/service/gpu/autotuner_compile_util.cc index ea18ca8abf4a23..19d556cf6a0f2f 100644 --- a/third_party/xla/xla/service/gpu/autotuner_compile_util.cc +++ b/third_party/xla/xla/service/gpu/autotuner_compile_util.cc @@ -174,7 +174,7 @@ AutotunerCompileUtil::Create(const AutotuneConfig& config, se::DeviceMemoryAllocator* allocator = config.GetAllocator(); TF_ASSIGN_OR_RETURN(se::Stream* const stream, config.GetStream()); TF_ASSIGN_OR_RETURN(Compiler * compiler, - Compiler::GetForPlatform(stream_exec->platform())); + Compiler::GetForPlatform(stream_exec->GetPlatform())); return AutotunerCompileUtil(config, compiler, *stream_exec, *stream, *allocator, opts); } diff --git a/third_party/xla/xla/service/gpu/conv_algorithm_picker.cc b/third_party/xla/xla/service/gpu/conv_algorithm_picker.cc index 5de0e32ffb6251..5a5706f2a56449 100644 --- a/third_party/xla/xla/service/gpu/conv_algorithm_picker.cc +++ b/third_party/xla/xla/service/gpu/conv_algorithm_picker.cc @@ -426,7 +426,7 @@ absl::StatusOr GpuConvAlgorithmPicker::PickBestAlgorithmNoCache( // Check StreamExecutor on which platform it is. ROCm and Cuda implementation // have diverged. Specifically, we need to make sure redzone allocator related // utilities are not used in ROCm routine - se::Platform::Id platform_id = stream_exec->platform()->id(); + se::Platform::Id platform_id = stream_exec->GetPlatform()->id(); if (platform_id == se::rocm::kROCmPlatformId) { result_or = PickBestAlgorithmNoCacheRocm(instr); } else if (platform_id == se::cuda::kCudaPlatformId) { diff --git a/third_party/xla/xla/service/gpu/gpu_executable.cc b/third_party/xla/xla/service/gpu/gpu_executable.cc index 169fa9cddc7ceb..1f53f5678bbf39 100644 --- a/third_party/xla/xla/service/gpu/gpu_executable.cc +++ b/third_party/xla/xla/service/gpu/gpu_executable.cc @@ -207,7 +207,7 @@ absl::Status GpuExecutable::CheckCompatibilityWithServiceExecutableRunOptions( se::Stream* main_stream = run_options->stream(); stream_executor::Platform::Id platform_id = - main_stream->parent()->platform()->id(); + main_stream->parent()->GetPlatform()->id(); if (platform_id == stream_executor::rocm::kROCmPlatformId) { auto cc = main_stream->GetRocmComputeCapability(); std::string stream_arch = cc.gcn_arch_name(); @@ -647,7 +647,8 @@ GpuExecutable::ResolveConstantGlobals(se::Stream* stream) { // The CUDA driver isn't able to load a PTX and a binary which are both empty. // It's okay if we skip loading in this case; if the module isn't loaded, all // symbol lookups will fail, just as they should for an empty module. - if (!(executor->platform()->id() == stream_executor::cuda::kCudaPlatformId && + if (!(executor->GetPlatform()->id() == + stream_executor::cuda::kCudaPlatformId && binary().empty() && text().empty())) { TF_RETURN_IF_ERROR(executor->LoadModule(module_spec, &module_handle)); } diff --git a/third_party/xla/xla/service/gpu/gpu_transfer_manager.cc b/third_party/xla/xla/service/gpu/gpu_transfer_manager.cc index 8bd97486d34358..89dd6343810bb2 100644 --- a/third_party/xla/xla/service/gpu/gpu_transfer_manager.cc +++ b/third_party/xla/xla/service/gpu/gpu_transfer_manager.cc @@ -102,8 +102,8 @@ absl::Status GpuTransferManager::ReadDynamicShapes( DCHECK(device_shape->is_dynamic()); Shape original_device_shape = *device_shape; - TF_ASSIGN_OR_RETURN(auto compiler, - Compiler::GetForPlatform(stream->parent()->platform())); + TF_ASSIGN_OR_RETURN( + auto compiler, Compiler::GetForPlatform(stream->parent()->GetPlatform())); auto shape_size_fn = compiler->ShapeSizeBytesFunction(); // First, figure out which parts of `device_shape` are dynamic and where the diff --git a/third_party/xla/xla/service/gpu/stream_executor_util.cc b/third_party/xla/xla/service/gpu/stream_executor_util.cc index 4d3489f600d37e..69faf8b17d6ca6 100644 --- a/third_party/xla/xla/service/gpu/stream_executor_util.cc +++ b/third_party/xla/xla/service/gpu/stream_executor_util.cc @@ -348,7 +348,7 @@ absl::Mutex& GetGpuMutex(const se::StreamExecutor* stream_exec) { absl::MutexLock global_lock(&mu); auto it = mutexes ->emplace(std::piecewise_construct, - std::make_tuple(stream_exec->platform(), + std::make_tuple(stream_exec->GetPlatform(), stream_exec->device_ordinal()), std::make_tuple()) .first; diff --git a/third_party/xla/xla/service/platform_util.cc b/third_party/xla/xla/service/platform_util.cc index 1d7cb0c2088d24..dc4d13b9e750ff 100644 --- a/third_party/xla/xla/service/platform_util.cc +++ b/third_party/xla/xla/service/platform_util.cc @@ -135,7 +135,7 @@ PlatformUtil::GetSupportedPlatforms() { // by XLA. static bool IsDeviceSupported(se::StreamExecutor* executor) { const auto& description = executor->GetDeviceDescription(); - if (executor->platform()->id() == se::cuda::kCudaPlatformId) { + if (executor->GetPlatform()->id() == se::cuda::kCudaPlatformId) { // CUDA devices must have a minimum compute capability. se::CudaComputeCapability cc = description.cuda_compute_capability(); if (!cc.IsAtLeast(kMinCudaComputeCapabilityMajor, @@ -148,7 +148,7 @@ static bool IsDeviceSupported(se::StreamExecutor* executor) { << "device is " << cc.ToString(); return false; } - } else if (executor->platform()->id() == se::rocm::kROCmPlatformId) { + } else if (executor->GetPlatform()->id() == se::rocm::kROCmPlatformId) { auto rocm_compute_capability = description.rocm_compute_capability(); if (!rocm_compute_capability.is_supported_gfx_version()) { LOG(INFO) << "StreamExecutor ROCM device (" << executor->device_ordinal() diff --git a/third_party/xla/xla/stream_executor/integrations/tf_allocator_adapter.cc b/third_party/xla/xla/stream_executor/integrations/tf_allocator_adapter.cc index 97a6c7c7b09e3e..c1ea8add16272d 100644 --- a/third_party/xla/xla/stream_executor/integrations/tf_allocator_adapter.cc +++ b/third_party/xla/xla/stream_executor/integrations/tf_allocator_adapter.cc @@ -26,7 +26,7 @@ limitations under the License. namespace stream_executor { TfAllocatorAdapter::TfAllocatorAdapter(tsl::Allocator *wrapped, Stream *stream) - : DeviceMemoryAllocator(stream->parent()->platform()), + : DeviceMemoryAllocator(stream->parent()->GetPlatform()), wrapped_(wrapped), stream_(stream) {} diff --git a/third_party/xla/xla/stream_executor/mock_stream_executor.h b/third_party/xla/xla/stream_executor/mock_stream_executor.h index 1c7d959b90dc3d..c0325e5d96d2e5 100644 --- a/third_party/xla/xla/stream_executor/mock_stream_executor.h +++ b/third_party/xla/xla/stream_executor/mock_stream_executor.h @@ -37,6 +37,7 @@ limitations under the License. #include "xla/stream_executor/launch_dim.h" #include "xla/stream_executor/memory_allocation.h" #include "xla/stream_executor/module_spec.h" +#include "xla/stream_executor/platform.h" #include "xla/stream_executor/stream_executor_interface.h" #include "xla/stream_executor/stream_interface.h" #include "xla/test.h" @@ -172,6 +173,7 @@ class MockStreamExecutor : public StreamExecutorInterface { MOCK_METHOD(bool, ClearAllocatorStats, (), (override)); MOCK_METHOD(absl::Status, FlushCompilationCache, (), (override)); MOCK_METHOD(Stream*, FindAllocatedStream, (void* device_stream), (override)); + MOCK_METHOD(const Platform*, GetPlatform, (), (const, override)); }; } // namespace stream_executor diff --git a/third_party/xla/xla/stream_executor/stream_executor_interface.h b/third_party/xla/xla/stream_executor/stream_executor_interface.h index 6f2b433bfc04ed..f51c579d15cd50 100644 --- a/third_party/xla/xla/stream_executor/stream_executor_interface.h +++ b/third_party/xla/xla/stream_executor/stream_executor_interface.h @@ -35,6 +35,7 @@ limitations under the License. #include "xla/stream_executor/launch_dim.h" #include "xla/stream_executor/memory_allocation.h" #include "xla/stream_executor/module_spec.h" +#include "xla/stream_executor/platform.h" #include "xla/stream_executor/stream_interface.h" namespace stream_executor { @@ -48,6 +49,9 @@ class StreamExecutorInterface { StreamExecutorInterface() = default; virtual ~StreamExecutorInterface() = default; + // Returns a reference to the platform that created this executor. + virtual const Platform* GetPlatform() const = 0; + // Initializes the device for use. virtual absl::Status Init() = 0; diff --git a/third_party/xla/xla/stream_executor/stream_executor_pimpl.cc b/third_party/xla/xla/stream_executor/stream_executor_pimpl.cc index 92ed9b9c868a2f..0041d0b2c53e8a 100644 --- a/third_party/xla/xla/stream_executor/stream_executor_pimpl.cc +++ b/third_party/xla/xla/stream_executor/stream_executor_pimpl.cc @@ -121,7 +121,7 @@ absl::StatusOr> StreamExecutor::CreateStream( StreamExecutorMemoryAllocator::StreamExecutorMemoryAllocator( StreamExecutor* executor) - : DeviceMemoryAllocator(executor->platform()) { + : DeviceMemoryAllocator(executor->GetPlatform()) { stream_executors_ = {executor}; } diff --git a/third_party/xla/xla/stream_executor/stream_executor_pimpl.h b/third_party/xla/xla/stream_executor/stream_executor_pimpl.h index e27a7a9297b5c1..702cc70508d31e 100644 --- a/third_party/xla/xla/stream_executor/stream_executor_pimpl.h +++ b/third_party/xla/xla/stream_executor/stream_executor_pimpl.h @@ -71,7 +71,11 @@ class StreamExecutor : public StreamExecutorInterface { ~StreamExecutor() = default; + const Platform* GetPlatform() const override { return platform_; } + // Returns a reference to the platform that created this executor. + // TODO(b/301020144) Delete this once all callers are migrated to GetPlatform. + ABSL_DEPRECATED("Use GetPlatform instead.") const Platform* platform() const { return platform_; } // Synchronously allocates an array on the device of type T with element_count diff --git a/third_party/xla/xla/stream_executor/tpu/tpu_executable_interface.cc b/third_party/xla/xla/stream_executor/tpu/tpu_executable_interface.cc index 69b9c4a0fc8067..e07b904f65f6c0 100644 --- a/third_party/xla/xla/stream_executor/tpu/tpu_executable_interface.cc +++ b/third_party/xla/xla/stream_executor/tpu/tpu_executable_interface.cc @@ -53,8 +53,9 @@ namespace { static Status PopulateResultTupleBuffers(const ShapedBuffer& result, se::Stream* stream, se::Stream* transfer_stream) { - TF_ASSIGN_OR_RETURN(auto transfer_manager, TransferManager::GetForPlatform( - stream->parent()->platform())); + TF_ASSIGN_OR_RETURN( + auto transfer_manager, + TransferManager::GetForPlatform(stream->parent()->GetPlatform())); if (transfer_manager->CanShapedBufferBeAccessedNow(stream->parent(), result)) { TF_RETURN_IF_ERROR(transfer_manager->WriteTupleIndexTablesAsync( @@ -77,7 +78,7 @@ TpuExecutableInterface::AllocateOutputMemoryWithInputReuse( std::vector* arguments, se::Stream* stream, se::Stream* transfer_stream) { auto stream_exec = stream->parent(); - auto platform = stream_exec->platform(); + auto platform = stream_exec->GetPlatform(); TF_ASSIGN_OR_RETURN(auto transfer_manager, TransferManager::GetForPlatform(platform)); TF_ASSIGN_OR_RETURN(auto compiler, Compiler::GetForPlatform(platform));