Add virtual StreamExecutorInterface::GetPlatform method.
This will eventually replace StreamExecutor::platform, which would collide with the TpuExecutorInterface::platform method if StreamExecutor::platform were made virtual.

PiperOrigin-RevId: 630094133
klucke authored and tensorflower-gardener committed May 2, 2024
1 parent 1075195 commit 62baf6a
Showing 18 changed files with 33 additions and 21 deletions.
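
The call-site change is mechanical across the 18 files below: swap the deprecated non-virtual accessor for the new virtual method. A minimal sketch of the pattern, assuming only that the executor pointer is obtained elsewhere; the helper name PlatformOf is hypothetical and not part of this commit:

    namespace se = ::stream_executor;  // alias used throughout these call sites

    // Hypothetical helper showing the before/after at a typical call site.
    const se::Platform* PlatformOf(se::StreamExecutor* executor) {
      // Before (now deprecated): executor->platform()
      // After (virtual on StreamExecutorInterface):
      return executor->GetPlatform();
    }

Because platform() is now marked ABSL_DEPRECATED (see stream_executor_pimpl.h below), remaining callers keep compiling but emit a deprecation warning until they migrate.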
2 changes: 1 addition & 1 deletion tensorflow/compiler/jit/xla_launch_util.cc
@@ -410,7 +410,7 @@ Status XlaComputationLaunchContext::PopulateOutputs(
if (output.on_host_shape().is_dynamic()) {
const se::Platform* platform = nullptr;
if (stream != nullptr) {
- platform = stream->parent()->platform();
+ platform = stream->parent()->GetPlatform();
} else {
// Stream is not set for the host platform.
TF_ASSIGN_OR_RETURN(platform,
2 changes: 1 addition & 1 deletion tensorflow/compiler/jit/xla_platform_info.cc
@@ -377,7 +377,7 @@ XlaPlatformInfo XlaPlatformInfoFromDevice(DeviceBase* device_base) {
auto device = static_cast<Device*>(device_base);
platform_id = device->tensorflow_accelerator_device_info()
->stream->parent()
- ->platform()
+ ->GetPlatform()
->id();
} else if (XlaDevice::GetMetadataFromDevice(device_base, &xla_device_metadata)
.ok()) {
4 changes: 2 additions & 2 deletions third_party/xla/xla/backends/interpreter/executable_base.cc
@@ -44,7 +44,7 @@ absl::StatusOr<ExecutionOutput> InterpreterExecutableBase::ExecuteAsyncOnStream(
HloExecutionProfile* hlo_execution_profile) {
se::Stream* stream = run_options->stream();
se::StreamExecutor* executor = stream->parent();
- const se::Platform* platform = executor->platform();
+ const se::Platform* platform = executor->GetPlatform();

// Convert the ShapeTree to a ShapedBuffer. We do this so we can call
// TransferManager methods below.
@@ -175,7 +175,7 @@ InterpreterExecutableBase::AllocateOutputMemoryWithInputReuse(
}));

se::StreamExecutor* executor = stream->parent();
- const se::Platform* platform = executor->platform();
+ const se::Platform* platform = executor->GetPlatform();
TF_ASSIGN_OR_RETURN(TransferManager * transfer_manager,
TransferManager::GetForPlatform(platform));

2 changes: 1 addition & 1 deletion third_party/xla/xla/client/local_client.cc
@@ -60,7 +60,7 @@ Status LocalExecutable::ValidateExecutionOptions(

// Check stream matches service platform.
const se::Platform* stream_platform =
- run_options.stream()->parent()->platform();
+ run_options.stream()->parent()->GetPlatform();
if (stream_platform != backend_->platform()) {
return InvalidArgument(
"stream is for platform %s, but service targets platform %s",
2 changes: 1 addition & 1 deletion third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc
@@ -916,7 +916,7 @@ Status BuildDistributedDevices(
local_topology.set_boot_id(boot_id_str);
for (const auto& ordinal_and_device : local_device_states) {
const se::Platform* platform =
- ordinal_and_device.second->executor()->platform();
+ ordinal_and_device.second->executor()->GetPlatform();
TF_ASSIGN_OR_RETURN(
std::unique_ptr<xla::se::DeviceDescription> desc,
platform->DescriptionForDevice(ordinal_and_device.first));
2 changes: 1 addition & 1 deletion third_party/xla/xla/service/compiler.cc
@@ -31,7 +31,7 @@ namespace xla {

Compiler::TargetConfig::TargetConfig(se::StreamExecutor* s)
: device_description(s->GetDeviceDescription().ToGpuProto()),
- platform_name(s->platform()->Name()),
+ platform_name(s->GetPlatform()->Name()),
device_description_str(s->GetDeviceDescription().name()) {
se::dnn::DnnSupport* dnn = s->AsDnn();
if (dnn != nullptr) {
2 changes: 1 addition & 1 deletion third_party/xla/xla/service/gpu/autotuner_compile_util.cc
@@ -174,7 +174,7 @@ AutotunerCompileUtil::Create(const AutotuneConfig& config,
se::DeviceMemoryAllocator* allocator = config.GetAllocator();
TF_ASSIGN_OR_RETURN(se::Stream* const stream, config.GetStream());
TF_ASSIGN_OR_RETURN(Compiler * compiler,
- Compiler::GetForPlatform(stream_exec->platform()));
+ Compiler::GetForPlatform(stream_exec->GetPlatform()));
return AutotunerCompileUtil(config, compiler, *stream_exec, *stream,
*allocator, opts);
}
2 changes: 1 addition & 1 deletion third_party/xla/xla/service/gpu/conv_algorithm_picker.cc
@@ -426,7 +426,7 @@ absl::StatusOr<AutotuneResult> GpuConvAlgorithmPicker::PickBestAlgorithmNoCache(
// Check StreamExecutor on which platform it is. ROCm and Cuda implementation
// have diverged. Specifically, we need to make sure redzone allocator related
// utilities are not used in ROCm routine
- se::Platform::Id platform_id = stream_exec->platform()->id();
+ se::Platform::Id platform_id = stream_exec->GetPlatform()->id();
if (platform_id == se::rocm::kROCmPlatformId) {
result_or = PickBestAlgorithmNoCacheRocm(instr);
} else if (platform_id == se::cuda::kCudaPlatformId) {
5 changes: 3 additions & 2 deletions third_party/xla/xla/service/gpu/gpu_executable.cc
@@ -207,7 +207,7 @@ absl::Status GpuExecutable::CheckCompatibilityWithServiceExecutableRunOptions(
se::Stream* main_stream = run_options->stream();

stream_executor::Platform::Id platform_id =
- main_stream->parent()->platform()->id();
+ main_stream->parent()->GetPlatform()->id();
if (platform_id == stream_executor::rocm::kROCmPlatformId) {
auto cc = main_stream->GetRocmComputeCapability();
std::string stream_arch = cc.gcn_arch_name();
@@ -647,7 +647,8 @@ GpuExecutable::ResolveConstantGlobals(se::Stream* stream) {
// The CUDA driver isn't able to load a PTX and a binary which are both empty.
// It's okay if we skip loading in this case; if the module isn't loaded, all
// symbol lookups will fail, just as they should for an empty module.
- if (!(executor->platform()->id() == stream_executor::cuda::kCudaPlatformId &&
+ if (!(executor->GetPlatform()->id() ==
+ stream_executor::cuda::kCudaPlatformId &&
binary().empty() && text().empty())) {
TF_RETURN_IF_ERROR(executor->LoadModule(module_spec, &module_handle));
}
4 changes: 2 additions & 2 deletions third_party/xla/xla/service/gpu/gpu_transfer_manager.cc
@@ -102,8 +102,8 @@ absl::Status GpuTransferManager::ReadDynamicShapes(
DCHECK(device_shape->is_dynamic());
Shape original_device_shape = *device_shape;

- TF_ASSIGN_OR_RETURN(auto compiler,
- Compiler::GetForPlatform(stream->parent()->platform()));
+ TF_ASSIGN_OR_RETURN(
+ auto compiler, Compiler::GetForPlatform(stream->parent()->GetPlatform()));
auto shape_size_fn = compiler->ShapeSizeBytesFunction();

// First, figure out which parts of `device_shape` are dynamic and where the
2 changes: 1 addition & 1 deletion third_party/xla/xla/service/gpu/stream_executor_util.cc
@@ -348,7 +348,7 @@ absl::Mutex& GetGpuMutex(const se::StreamExecutor* stream_exec) {
absl::MutexLock global_lock(&mu);
auto it = mutexes
->emplace(std::piecewise_construct,
- std::make_tuple(stream_exec->platform(),
+ std::make_tuple(stream_exec->GetPlatform(),
stream_exec->device_ordinal()),
std::make_tuple())
.first;
4 changes: 2 additions & 2 deletions third_party/xla/xla/service/platform_util.cc
@@ -135,7 +135,7 @@ PlatformUtil::GetSupportedPlatforms() {
// by XLA.
static bool IsDeviceSupported(se::StreamExecutor* executor) {
const auto& description = executor->GetDeviceDescription();
- if (executor->platform()->id() == se::cuda::kCudaPlatformId) {
+ if (executor->GetPlatform()->id() == se::cuda::kCudaPlatformId) {
// CUDA devices must have a minimum compute capability.
se::CudaComputeCapability cc = description.cuda_compute_capability();
if (!cc.IsAtLeast(kMinCudaComputeCapabilityMajor,
@@ -148,7 +148,7 @@ static bool IsDeviceSupported(se::StreamExecutor* executor) {
<< "device is " << cc.ToString();
return false;
}
- } else if (executor->platform()->id() == se::rocm::kROCmPlatformId) {
+ } else if (executor->GetPlatform()->id() == se::rocm::kROCmPlatformId) {
auto rocm_compute_capability = description.rocm_compute_capability();
if (!rocm_compute_capability.is_supported_gfx_version()) {
LOG(INFO) << "StreamExecutor ROCM device (" << executor->device_ordinal()
@@ -26,7 +26,7 @@ limitations under the License.
namespace stream_executor {

TfAllocatorAdapter::TfAllocatorAdapter(tsl::Allocator *wrapped, Stream *stream)
- : DeviceMemoryAllocator(stream->parent()->platform()),
+ : DeviceMemoryAllocator(stream->parent()->GetPlatform()),
wrapped_(wrapped),
stream_(stream) {}

2 changes: 2 additions & 0 deletions third_party/xla/xla/stream_executor/mock_stream_executor.h
@@ -37,6 +37,7 @@ limitations under the License.
#include "xla/stream_executor/launch_dim.h"
#include "xla/stream_executor/memory_allocation.h"
#include "xla/stream_executor/module_spec.h"
#include "xla/stream_executor/platform.h"
#include "xla/stream_executor/stream_executor_interface.h"
#include "xla/stream_executor/stream_interface.h"
#include "xla/test.h"
@@ -172,6 +173,7 @@ class MockStreamExecutor : public StreamExecutorInterface {
MOCK_METHOD(bool, ClearAllocatorStats, (), (override));
MOCK_METHOD(absl::Status, FlushCompilationCache, (), (override));
MOCK_METHOD(Stream*, FindAllocatedStream, (void* device_stream), (override));
+ MOCK_METHOD(const Platform*, GetPlatform, (), (const, override));
};

} // namespace stream_executor
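In tests, the new mock method can be stubbed like any other gMock method. A minimal sketch; ReturnNull() is used here only to keep the example self-contained, where a real test would return a test platform instance:

    using ::testing::ReturnNull;

    stream_executor::MockStreamExecutor executor;
    // Stub the newly added method on the mock.
    EXPECT_CALL(executor, GetPlatform()).WillRepeatedly(ReturnNull());
    const stream_executor::Platform* platform = executor.GetPlatform();
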
@@ -35,6 +35,7 @@ limitations under the License.
#include "xla/stream_executor/launch_dim.h"
#include "xla/stream_executor/memory_allocation.h"
#include "xla/stream_executor/module_spec.h"
#include "xla/stream_executor/platform.h"
#include "xla/stream_executor/stream_interface.h"

namespace stream_executor {
@@ -48,6 +49,9 @@ class StreamExecutorInterface {
StreamExecutorInterface() = default;
virtual ~StreamExecutorInterface() = default;

+ // Returns a reference to the platform that created this executor.
+ virtual const Platform* GetPlatform() const = 0;
+
// Initializes the device for use.
virtual absl::Status Init() = 0;

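Every StreamExecutorInterface implementation must now provide this override. A sketch of the shape an implementation takes, assuming the platform pointer is injected at construction; the class name is illustrative, and the remaining pure-virtual methods are omitted, so the class stays abstract:

    class MyExecutor : public stream_executor::StreamExecutorInterface {
     public:
      explicit MyExecutor(const stream_executor::Platform* platform)
          : platform_(platform) {}

      // Satisfies the pure-virtual method added by this commit.
      const stream_executor::Platform* GetPlatform() const override {
        return platform_;
      }

      // ...remaining StreamExecutorInterface methods omitted in this sketch...

     private:
      const stream_executor::Platform* platform_;  // not owned
    };

This mirrors StreamExecutor's own override in stream_executor_pimpl.h below, which returns the stored platform_ member.
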
@@ -121,7 +121,7 @@ absl::StatusOr<std::unique_ptr<Stream>> StreamExecutor::CreateStream(

StreamExecutorMemoryAllocator::StreamExecutorMemoryAllocator(
StreamExecutor* executor)
- : DeviceMemoryAllocator(executor->platform()) {
+ : DeviceMemoryAllocator(executor->GetPlatform()) {
stream_executors_ = {executor};
}

4 changes: 4 additions & 0 deletions third_party/xla/xla/stream_executor/stream_executor_pimpl.h
@@ -71,7 +71,11 @@ class StreamExecutor : public StreamExecutorInterface {

~StreamExecutor() = default;

+ const Platform* GetPlatform() const override { return platform_; }
+
// Returns a reference to the platform that created this executor.
+ // TODO(b/301020144) Delete this once all callers are migrated to GetPlatform.
+ ABSL_DEPRECATED("Use GetPlatform instead.")
const Platform* platform() const { return platform_; }

// Synchronously allocates an array on the device of type T with element_count
@@ -53,8 +53,9 @@ namespace {
static Status PopulateResultTupleBuffers(const ShapedBuffer& result,
se::Stream* stream,
se::Stream* transfer_stream) {
- TF_ASSIGN_OR_RETURN(auto transfer_manager, TransferManager::GetForPlatform(
- stream->parent()->platform()));
+ TF_ASSIGN_OR_RETURN(
+ auto transfer_manager,
+ TransferManager::GetForPlatform(stream->parent()->GetPlatform()));
if (transfer_manager->CanShapedBufferBeAccessedNow(stream->parent(),
result)) {
TF_RETURN_IF_ERROR(transfer_manager->WriteTupleIndexTablesAsync(
@@ -77,7 +78,7 @@ TpuExecutableInterface::AllocateOutputMemoryWithInputReuse(
std::vector<ExecutionInput>* arguments, se::Stream* stream,
se::Stream* transfer_stream) {
auto stream_exec = stream->parent();
- auto platform = stream_exec->platform();
+ auto platform = stream_exec->GetPlatform();
TF_ASSIGN_OR_RETURN(auto transfer_manager,
TransferManager::GetForPlatform(platform));
TF_ASSIGN_OR_RETURN(auto compiler, Compiler::GetForPlatform(platform));
