[go: nahoru, domu]

Skip to content

Commit

Permalink
Make StreamExecutorMemoryAllocator work in terms of StreamExecutorInterface instead of StreamExecutor.
Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 631556982
  • Loading branch information
klucke authored and tensorflower-gardener committed May 7, 2024
1 parent 027d6b3 commit 783c383
Show file tree
Hide file tree
Showing 17 changed files with 90 additions and 56 deletions.
3 changes: 2 additions & 1 deletion third_party/xla/xla/service/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1098,6 +1098,7 @@ cc_library(
"//xla:util",
"//xla/stream_executor",
"//xla/stream_executor:device_memory_allocator",
"//xla/stream_executor:stream_executor_interface",
"//xla/stream_executor/host:host_platform_id",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/strings",
Expand Down Expand Up @@ -1445,8 +1446,8 @@ xla_cc_test(
":shaped_buffer",
"//xla:shape_util",
"//xla:test",
"//xla/stream_executor",
"//xla/stream_executor:device_memory_allocator",
"//xla/stream_executor:stream_executor_interface",
"//xla/tests:xla_internal_test_main",
"@local_tsl//tsl/platform:test_benchmark",
],
Expand Down
4 changes: 3 additions & 1 deletion third_party/xla/xla/service/backend.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ limitations under the License.
#include "xla/statusor.h"
#include "xla/stream_executor/host/host_platform_id.h"
#include "xla/stream_executor/stream_executor.h"
#include "xla/stream_executor/stream_executor_interface.h"
#include "xla/util.h"
#include "tsl/platform/cpu_info.h"
#include "tsl/platform/env.h"
Expand Down Expand Up @@ -147,7 +148,8 @@ Backend::Backend(se::Platform* platform, Compiler* compiler,
stream_executors_(stream_executors.begin(), stream_executors.end()) {
// Create a memory allocator for the valid stream executors.
memory_allocator_ = std::make_shared<se::StreamExecutorMemoryAllocator>(
platform, stream_executors_);
platform, std::vector<se::StreamExecutorInterface*>{
stream_executors_.begin(), stream_executors_.end()});
CHECK(!stream_executors_.empty())
<< "Service found no devices for backend " << platform_->Name() << '.';

Expand Down
7 changes: 5 additions & 2 deletions third_party/xla/xla/service/shaped_buffer_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,12 @@ limitations under the License.

#include <memory>
#include <utility>
#include <vector>

#include "xla/service/platform_util.h"
#include "xla/shape_util.h"
#include "xla/stream_executor/device_memory_allocator.h"
#include "xla/stream_executor/stream_executor.h"
#include "xla/stream_executor/stream_executor_interface.h"
#include "xla/test.h"
#include "tsl/platform/test_benchmark.h"

Expand All @@ -33,7 +34,9 @@ TEST(ShapedBufferTest, ScopedShapeBufferAsShapedBufferB71629047) {
xla::PlatformUtil::GetDefaultPlatform());
TF_ASSERT_OK_AND_ASSIGN(auto executors,
xla::PlatformUtil::GetStreamExecutors(platform));
xla::se::StreamExecutorMemoryAllocator allocator(platform, executors);
xla::se::StreamExecutorMemoryAllocator allocator(
platform, std::vector<se::StreamExecutorInterface*>{executors.begin(),
executors.end()});
const xla::Shape shape = xla::ShapeUtil::MakeShape(xla::F32, {});
const int kDeviceOrdinal = 0;
auto scoped_buffer = std::make_unique<xla::ScopedShapedBuffer>(
Expand Down
1 change: 1 addition & 0 deletions third_party/xla/xla/stream_executor/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ cc_library(
":stream_executor_plugin_headers",
],
deps = STREAM_EXECUTOR_DEPENDENCIES + [
":stream_executor_interface",
":stream_executor_pimpl",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status:statusor",
Expand Down
27 changes: 13 additions & 14 deletions third_party/xla/xla/stream_executor/device_memory_allocator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,29 +28,29 @@ limitations under the License.
#include "xla/stream_executor/device_memory.h"
#include "xla/stream_executor/platform.h"
#include "xla/stream_executor/stream.h"
#include "xla/stream_executor/stream_executor_pimpl.h"
#include "xla/stream_executor/stream_executor_interface.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/numbers.h"
#include "tsl/platform/statusor.h"

namespace stream_executor {

StreamExecutorMemoryAllocator::StreamExecutorMemoryAllocator(
StreamExecutor* executor)
StreamExecutorInterface* executor)
: DeviceMemoryAllocator(executor->GetPlatform()) {
stream_executors_ = {executor};
}

StreamExecutorMemoryAllocator::StreamExecutorMemoryAllocator(
const Platform* platform,
absl::Span<StreamExecutor* const> stream_executors)
absl::Span<StreamExecutorInterface* const> stream_executors)
: DeviceMemoryAllocator(platform),
stream_executors_(stream_executors.begin(), stream_executors.end()) {}

absl::StatusOr<OwningDeviceMemory> StreamExecutorMemoryAllocator::Allocate(
int device_ordinal, uint64_t size, bool retry_on_failure,
int64_t memory_space) {
TF_ASSIGN_OR_RETURN(StreamExecutor * executor,
TF_ASSIGN_OR_RETURN(StreamExecutorInterface * executor,
GetStreamExecutor(device_ordinal));
DeviceMemoryBase result =
executor->AllocateArray<uint8_t>(size, memory_space);
Expand All @@ -68,7 +68,7 @@ absl::StatusOr<OwningDeviceMemory> StreamExecutorMemoryAllocator::Allocate(
absl::Status StreamExecutorMemoryAllocator::Deallocate(int device_ordinal,
DeviceMemoryBase mem) {
if (!mem.is_null()) {
TF_ASSIGN_OR_RETURN(StreamExecutor * executor,
TF_ASSIGN_OR_RETURN(StreamExecutorInterface * executor,
GetStreamExecutor(device_ordinal));
VLOG(3) << absl::StreamFormat("Freeing %p on device ordinal %d",
mem.opaque(), device_ordinal);
Expand All @@ -77,13 +77,13 @@ absl::Status StreamExecutorMemoryAllocator::Deallocate(int device_ordinal,
return absl::OkStatus();
}

absl::StatusOr<StreamExecutor*>
absl::StatusOr<StreamExecutorInterface*>
StreamExecutorMemoryAllocator::GetStreamExecutor(int device_ordinal) const {
if (device_ordinal < 0) {
return absl::InvalidArgumentError(absl::StrFormat(
"device ordinal value (%d) must be non-negative", device_ordinal));
}
for (StreamExecutor* se : stream_executors_) {
for (StreamExecutorInterface* se : stream_executors_) {
if (se->device_ordinal() == device_ordinal) {
return se;
}
Expand All @@ -101,17 +101,16 @@ absl::StatusOr<Stream*> StreamExecutorMemoryAllocator::GetStream(
int device_ordinal) {
CHECK(!AllowsAsynchronousDeallocation())
<< "The logic below only works for synchronous allocators";
TF_ASSIGN_OR_RETURN(StreamExecutor * executor,
TF_ASSIGN_OR_RETURN(StreamExecutorInterface * executor,
GetStreamExecutor(device_ordinal));
absl::MutexLock lock(&mutex_);
if (!streams_.count(device_ordinal)) {
auto p = streams_.emplace(std::piecewise_construct,
std::forward_as_tuple(device_ordinal),
std::forward_as_tuple(executor));
TF_RETURN_IF_ERROR(p.first->second.Initialize());
return &p.first->second;
TF_ASSIGN_OR_RETURN(auto stream, executor->CreateStream());
auto stream_ptr = stream.get();
streams_.emplace(device_ordinal, std::move(stream));
return stream_ptr;
}
return &streams_.at(device_ordinal);
return streams_.at(device_ordinal).get();
}

} // namespace stream_executor
12 changes: 7 additions & 5 deletions third_party/xla/xla/stream_executor/device_memory_allocator.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ limitations under the License.
#include "absl/types/span.h"
#include "xla/stream_executor/device_memory.h"
#include "xla/stream_executor/platform.h"
#include "xla/stream_executor/stream_executor_interface.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/status.h"

Expand Down Expand Up @@ -227,14 +228,14 @@ class StreamExecutorMemoryAllocator : public DeviceMemoryAllocator {
public:
// Create an allocator supporting a single device, corresponding to the passed
// executor.
explicit StreamExecutorMemoryAllocator(StreamExecutor *executor);
explicit StreamExecutorMemoryAllocator(StreamExecutorInterface *executor);

// Create an allocator supporting multiple stream executors.
//
// Precondition: all stream_executors have different device ordinals.
StreamExecutorMemoryAllocator(
const Platform *platform,
absl::Span<StreamExecutor *const> stream_executors);
absl::Span<StreamExecutorInterface *const> stream_executors);

absl::StatusOr<OwningDeviceMemory> Allocate(int device_ordinal, uint64_t size,
bool retry_on_failure,
Expand All @@ -252,17 +253,18 @@ class StreamExecutorMemoryAllocator : public DeviceMemoryAllocator {
absl::StatusOr<Stream *> GetStream(int device_ordinal) override;

// Gets the stream executor for given device ordinal.
absl::StatusOr<StreamExecutor *> GetStreamExecutor(int device_ordinal) const;
absl::StatusOr<StreamExecutorInterface *> GetStreamExecutor(
int device_ordinal) const;

private:
// Available stream executors. Each stream executor has a different device
// ordinal.
std::vector<StreamExecutor *> stream_executors_;
std::vector<StreamExecutorInterface *> stream_executors_;

absl::Mutex mutex_;

// Cache of streams for GetStream.
std::map<int, Stream> streams_ ABSL_GUARDED_BY(mutex_);
std::map<int, std::unique_ptr<Stream>> streams_ ABSL_GUARDED_BY(mutex_);
};

template <typename ElemT>
Expand Down
1 change: 1 addition & 0 deletions third_party/xla/xla/stream_executor/mock_stream_executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ class MockStreamExecutor : public StreamExecutorInterface {
MOCK_METHOD(const Platform*, GetPlatform, (), (const, override));
MOCK_METHOD(absl::StatusOr<std::unique_ptr<Stream>>, CreateStream,
((std::optional<std::variant<StreamPriority, int>>)), (override));
// Mocks the const pure-virtual StreamExecutorInterface::GetMemoryLimitBytes().
// MOCK_METHOD qualifiers must be a comma-separated list: (const, override).
  MOCK_METHOD(int64_t, GetMemoryLimitBytes, (), (const, override));
};

} // namespace stream_executor
Expand Down
25 changes: 25 additions & 0 deletions third_party/xla/xla/stream_executor/stream_executor_interface.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,12 @@ class StreamExecutorInterface {
std::optional<std::variant<StreamPriority, int>> priority =
std::nullopt) = 0;

// Synchronously allocates an array on the device of type T with element_count
// elements (sizeof(T) * element_count bytes) in the given memory_space.
//
// Returns a null DeviceMemory<T> if the request exceeds a positive
// GetMemoryLimitBytes(); otherwise wraps the result of Allocate().
template <typename T>
DeviceMemory<T> AllocateArray(uint64_t element_count,
int64_t memory_space = 0);

// Retrieves (loads) a kernel, if one exists.
//
// Parameters:
Expand Down Expand Up @@ -369,11 +375,30 @@ class StreamExecutorInterface {
// Returns a stream allocated by this executor, or nullptr if not found.
virtual Stream* FindAllocatedStream(void* device_stream) { return nullptr; }

// Returns the memory limit in bytes supported by this executor.
// AllocateArray() only enforces the limit when this value is positive; a
// value <= 0 disables the check.
virtual int64_t GetMemoryLimitBytes() const = 0;

private:
StreamExecutorInterface(const StreamExecutorInterface&) = delete;
void operator=(const StreamExecutorInterface&) = delete;
};

// Synchronously allocates room for `element_count` values of type T in
// `memory_space` on this executor's device.
//
// Returns a null DeviceMemory<T> (after logging a warning) when:
//   * sizeof(T) * element_count does not fit in uint64_t, or
//   * GetMemoryLimitBytes() is positive and the request exceeds it.
// Otherwise returns the result of Allocate() wrapped as DeviceMemory<T>.
template <typename T>
inline DeviceMemory<T> StreamExecutorInterface::AllocateArray(
    uint64_t element_count, int64_t memory_space) {
  // Guard the multiplication: without this, a huge element_count wraps
  // around and silently requests a much smaller buffer than the caller
  // expects.
  if (element_count > UINT64_MAX / sizeof(T)) {
    LOG(WARNING) << "Cannot allocate " << element_count
                 << " elements of size " << sizeof(T) << " on device "
                 << device_ordinal() << ": byte count overflows uint64_t";
    return DeviceMemory<T>();
  }
  uint64_t bytes = sizeof(T) * element_count;
  auto memory_limit_bytes = GetMemoryLimitBytes();
  // A non-positive limit disables the check. Compare in the unsigned domain:
  // casting `bytes` to int64_t would make requests above INT64_MAX negative,
  // letting them slip past the limit.
  if (memory_limit_bytes > 0 &&
      bytes > static_cast<uint64_t>(memory_limit_bytes)) {
    LOG(WARNING) << "Not enough memory to allocate " << bytes << " on device "
                 << device_ordinal()
                 << " within provided limit. [limit=" << memory_limit_bytes
                 << "]";
    return DeviceMemory<T>();
  }
  return DeviceMemory<T>(Allocate(bytes, memory_space));
}

} // namespace stream_executor

#endif // XLA_STREAM_EXECUTOR_STREAM_EXECUTOR_INTERFACE_H_
5 changes: 3 additions & 2 deletions third_party/xla/xla/stream_executor/stream_executor_pimpl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -57,15 +57,16 @@ namespace stream_executor {

// Get per-device memory limit in bytes. Returns 0 if
// TF_PER_DEVICE_MEMORY_LIMIT_MB environment variable is not set.
static int64_t GetMemoryLimitBytes() {
static int64_t GetMemoryLimitBytesFromEnvironmentVariable() {
  // The configured limit is expressed in megabytes; 0 (the default passed
  // below) is used when TF_PER_DEVICE_MEMORY_LIMIT_MB is not set. A read
  // failure is fatal via TF_CHECK_OK.
  int64_t limit_in_mb;
  TF_CHECK_OK(tsl::ReadInt64FromEnvVar("TF_PER_DEVICE_MEMORY_LIMIT_MB", 0,
                                       &limit_in_mb));
  // 1 MB == 2^20 bytes.
  constexpr int64_t kBytesPerMb = int64_t{1} << 20;
  return limit_in_mb * kBytesPerMb;
}

StreamExecutor::StreamExecutor(const Platform* platform)
: platform_(platform), memory_limit_bytes_(GetMemoryLimitBytes()) {}
: platform_(platform),
memory_limit_bytes_(GetMemoryLimitBytesFromEnvironmentVariable()) {}

const DeviceDescription& StreamExecutor::GetDeviceDescription() const {
absl::MutexLock lock(&mu_);
Expand Down
26 changes: 2 additions & 24 deletions third_party/xla/xla/stream_executor/stream_executor_pimpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,12 +70,6 @@ class StreamExecutor : public StreamExecutorInterface {

const Platform* GetPlatform() const override { return platform_; }

// Synchronously allocates an array on the device of type T with element_count
// elements.
template <typename T>
DeviceMemory<T> AllocateArray(uint64_t element_count,
int64_t memory_space = 0);

// Convenience wrapper that allocates space for a single element of type T in
// device memory.
template <typename T>
Expand Down Expand Up @@ -113,6 +107,8 @@ class StreamExecutor : public StreamExecutorInterface {
std::optional<std::variant<StreamPriority, int>> priority =
std::nullopt) override;

// Returns memory_limit_bytes_, which the constructor initializes from the
// TF_PER_DEVICE_MEMORY_LIMIT_MB environment variable (converted to bytes).
int64_t GetMemoryLimitBytes() const override { return memory_limit_bytes_; }

private:
// Reader/writer lock for mutable data structures on this StreamExecutor.
//
Expand Down Expand Up @@ -140,24 +136,6 @@ class StreamExecutor : public StreamExecutorInterface {
void operator=(const StreamExecutor&) = delete;
};

////////////
// Inlines

template <typename T>
inline DeviceMemory<T> StreamExecutor::AllocateArray(uint64_t element_count,
int64_t memory_space) {
uint64_t bytes = sizeof(T) * element_count;
if (memory_limit_bytes_ > 0 &&
static_cast<int64_t>(bytes) > memory_limit_bytes_) {
LOG(WARNING) << "Not enough memory to allocate " << bytes << " on device "
<< device_ordinal()
<< " within provided limit. limit=" << memory_limit_bytes_
<< "]";
return DeviceMemory<T>();
}
return DeviceMemory<T>(Allocate(bytes, memory_space));
}

} // namespace stream_executor

#endif // XLA_STREAM_EXECUTOR_STREAM_EXECUTOR_PIMPL_H_
4 changes: 3 additions & 1 deletion third_party/xla/xla/tests/buffer_donation_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,10 @@ class BufferDonationTest : public HloTestBase {

TF_ASSERT_OK_AND_ASSIGN(auto stream, executor_->CreateStream());

auto& executors = backend_->stream_executors();
se::StreamExecutorMemoryAllocator memory_allocator(
platform_, backend_->stream_executors());
platform_, std::vector<se::StreamExecutorInterface*>(executors.begin(),
executors.end()));
ExecutableRunOptions run_options;
run_options.set_stream(stream.get());
run_options.set_allocator(&memory_allocator);
Expand Down
4 changes: 3 additions & 1 deletion third_party/xla/xla/tests/cpu_gpu_fusion_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -892,7 +892,9 @@ void BM_ParallelFusion(::testing::benchmark::State& state) {

se::Platform* platform = PlatformUtil::GetDefaultPlatform().value();
auto executors = PlatformUtil::GetStreamExecutors(platform).value();
se::StreamExecutorMemoryAllocator allocator(platform, executors);
se::StreamExecutorMemoryAllocator allocator(
platform, std::vector<se::StreamExecutorInterface*>(executors.begin(),
executors.end()));

const int64_t intra_op_parallelism_threads = 24;
xla::LocalClientOptions client_options;
Expand Down
4 changes: 3 additions & 1 deletion third_party/xla/xla/tests/dot_operation_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2307,7 +2307,9 @@ ENTRY MatrixVectorComplex {
void DOT_ReorderContracting(::testing::benchmark::State& state) {
se::Platform* platform = PlatformUtil::GetDefaultPlatform().value();
auto executors = PlatformUtil::GetStreamExecutors(platform).value();
se::StreamExecutorMemoryAllocator allocator(platform, executors);
se::StreamExecutorMemoryAllocator allocator(
platform, std::vector<se::StreamExecutorInterface*>(executors.begin(),
executors.end()));

xla::LocalClientOptions client_options;
client_options.set_platform(platform);
Expand Down
4 changes: 3 additions & 1 deletion third_party/xla/xla/tests/dynamic_ops_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -973,7 +973,9 @@ ENTRY main {
void BM_DynamicSlice(::testing::benchmark::State& state) {
se::Platform* platform = PlatformUtil::GetDefaultPlatform().value();
auto executors = PlatformUtil::GetStreamExecutors(platform).value();
se::StreamExecutorMemoryAllocator allocator(platform, executors);
se::StreamExecutorMemoryAllocator allocator(
platform, std::vector<se::StreamExecutorInterface*>(executors.begin(),
executors.end()));
LocalClient* client = ClientLibrary::GetOrCreateLocalClient(platform).value();
auto* transfer_manager = TransferManager::GetForPlatform(platform).value();
int device_ordinal = client->default_device_ordinal();
Expand Down
4 changes: 3 additions & 1 deletion third_party/xla/xla/tests/local_client_execute_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -904,7 +904,9 @@ XLA_TEST_F(LocalClientExecuteTest, DISABLED_ON_INTERPRETER(InfeedOutfeedTest)) {
void BM_LocalClientOverhead(::testing::benchmark::State& state) {
se::Platform* platform = PlatformUtil::GetDefaultPlatform().value();
auto executors = PlatformUtil::GetStreamExecutors(platform).value();
se::StreamExecutorMemoryAllocator allocator(platform, executors);
se::StreamExecutorMemoryAllocator allocator(
platform, std::vector<se::StreamExecutorInterface*>(executors.begin(),
executors.end()));
LocalClient* client = ClientLibrary::GetOrCreateLocalClient(platform).value();
auto* transfer_manager = TransferManager::GetForPlatform(platform).value();
int device_ordinal = client->default_device_ordinal();
Expand Down
11 changes: 10 additions & 1 deletion third_party/xla/xla/tests/local_client_test_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,9 @@ class TestAllocator : public se::StreamExecutorMemoryAllocator {
public:
explicit TestAllocator(se::Platform* platform)
: se::StreamExecutorMemoryAllocator(
platform, PlatformUtil::GetStreamExecutors(platform).value()) {}
platform, GetInterfaceVectorFromExecutors(
PlatformUtil::GetStreamExecutors(platform).value())) {
}

absl::StatusOr<se::OwningDeviceMemory> Allocate(
int device_ordinal, uint64_t size, bool retry_on_failure,
Expand All @@ -61,6 +63,13 @@ class TestAllocator : public se::StreamExecutorMemoryAllocator {
int64_t deallocation_count(int device_ordinal) const;

private:
// Helper function to turn a vector<StreamExecutor*> into a
// vector<StreamExecutorInterface*>.
//
// This must be static. The TestAllocator constructor calls it while
// initializing its se::StreamExecutorMemoryAllocator base. Calling a
// non-static member function before the base-class mem-initializers have
// completed is undefined behavior ([class.base.init]).
static std::vector<se::StreamExecutorInterface*>
GetInterfaceVectorFromExecutors(
    const std::vector<se::StreamExecutor*>& executors) {
  return std::vector<se::StreamExecutorInterface*>(executors.begin(),
                                                   executors.end());
}
mutable absl::Mutex count_mutex_;

// Global counts of allocations and deallocations.
Expand Down
Loading

0 comments on commit 783c383

Please sign in to comment.