[go: nahoru, domu]

Skip to content

Commit

Permalink
In PjRtStreamExecutorBuffer::Delete, fix a bug causing memory corruption.
Browse files Browse the repository at this point in the history

The bug is that we don't wait for usage events of the buffer that are on the compute stream, nor for the buffer's definition events. This can cause a race condition in which the buffer is deleted before it is read or written, which can lead to memory corruption.

For events on the compute stream, the fix is to schedule the deallocation on the compute stream.
For the definition events, the fix is to wait for them on a stream in a stream pool if they are not defined on the compute stream.

Added a fixed-size stream pool to avoid the overhead and potential deadlocks introduced by BorrowStreamFromPool; we previously saw deadlocks between cuStreamCreate and other CUDA calls such as cuMemHostAlloc.

This change probably wouldn't be needed if all writes were scheduled on the compute stream, but that is not the case today: there are many usages of the host-to-device (h2d) and device-to-device (d2d) streams.

PiperOrigin-RevId: 634944963
  • Loading branch information
yifjiang authored and tensorflower-gardener committed May 18, 2024
1 parent b409682 commit d3f2cdb
Show file tree
Hide file tree
Showing 4 changed files with 100 additions and 16 deletions.
13 changes: 13 additions & 0 deletions third_party/xla/xla/pjrt/local_device_state.cc
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,10 @@ LocalDeviceState::LocalDeviceState(se::StreamExecutor* executor,
for (int i = 0; i < num_device_to_device_streams; ++i) {
device_to_device_streams_.emplace_back(create_stream());
}
fixed_size_pool_usage_streams_.reserve(kNumFixedSizePoolUsageStreams);
for (int i = 0; i < kNumFixedSizePoolUsageStreams; ++i) {
fixed_size_pool_usage_streams_.emplace_back(create_stream());
}
external_ready_event_streams_.reserve(kNumExternalReadyEventStreams);
for (int i = 0; i < kNumExternalReadyEventStreams; ++i) {
external_ready_event_streams_.emplace_back(create_stream());
Expand Down Expand Up @@ -167,6 +171,15 @@ se::Stream* LocalDeviceState::GetDeviceToDeviceStream() {
return device_to_device_streams_.at(i).get();
}

// Hands out one of the fixed-size pool's usage streams, cycling through them
// round-robin. The stream stays owned by this LocalDeviceState; callers only
// receive a non-owning pointer. The cursor is advanced under mu_.
se::Stream* LocalDeviceState::GetFixedSizePoolUsageStream() {
  absl::MutexLock lock(&mu_);
  // Remember the slot to return before moving the cursor on.
  const int chosen = next_fixed_size_pool_usage_stream_;
  const auto pool_size = fixed_size_pool_usage_streams_.size();
  // Advance to the next slot, wrapping back to the start of the pool.
  next_fixed_size_pool_usage_stream_ = (chosen + 1) % pool_size;
  return fixed_size_pool_usage_streams_.at(chosen).get();
}

se::Stream* LocalDeviceState::GetExternalReadyEventStream() {
absl::MutexLock lock(&mu_);
int i = next_external_ready_event_stream_;
Expand Down
10 changes: 9 additions & 1 deletion third_party/xla/xla/pjrt/local_device_state.h
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,11 @@ class LocalDeviceState {
// fashion amongst the available streams.
se::Stream* GetDeviceToDeviceStream();

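// Returns a usage stream from a small fixed-size pool, handing out streams in
// a round-robin fashion amongst the available streams. Unlike
// BorrowStreamFromPool, the returned stream remains owned by LocalDeviceState
// and is shared with other callers. It may therefore already have outstanding
// work, and it must not be passed to ReturnStreamToPool. When the overhead
// from BorrowStreamFromPool is too large for a use case, consider using this
// API instead.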
se::Stream* GetFixedSizePoolUsageStream();

// Return a stream that should be used to track when an externally-managed
// buffer is ready. This is intended to support dlpack on GPU. Allocates
// streams in a round-robin fashion amongst the available streams.
Expand All @@ -151,7 +156,7 @@ class LocalDeviceState {
// Returns a vector of device to device streams.
std::vector<se::Stream*> GetDeviceToDeviceStreams();

// Returns a stream from a pool. The stream is guaranteed not to have any
// Borrows a stream from a pool. The stream is guaranteed not to have any
// currently outstanding work at its tail.
std::unique_ptr<se::Stream> BorrowStreamFromPool();
// Returns a stream to the pool. The caller must ensure the stream does not
Expand Down Expand Up @@ -213,15 +218,18 @@ class LocalDeviceState {
std::unique_ptr<se::Stream> host_to_device_stream_;
std::vector<std::unique_ptr<se::Stream>> device_to_host_streams_;
std::vector<std::unique_ptr<se::Stream>> device_to_device_streams_;
std::vector<std::unique_ptr<se::Stream>> fixed_size_pool_usage_streams_;
std::vector<std::unique_ptr<se::Stream>> external_ready_event_streams_;

static constexpr int kNumDeviceToHostStreams = 4;
static constexpr int kNumDeviceToDeviceStreams = 4;
static constexpr int kNumFixedSizePoolUsageStreams = 4;
static constexpr int kNumExternalReadyEventStreams = 4;

absl::Mutex mu_;
int next_device_to_host_stream_ ABSL_GUARDED_BY(mu_) = 0;
int next_device_to_device_stream_ ABSL_GUARDED_BY(mu_) = 0;
int next_fixed_size_pool_usage_stream_ ABSL_GUARDED_BY(mu_) = 0;
int next_external_ready_event_stream_ ABSL_GUARDED_BY(mu_) = 0;

std::random_device prng_seed_device_ ABSL_GUARDED_BY(mu_);
Expand Down
80 changes: 71 additions & 9 deletions third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1486,32 +1486,82 @@ PjRtStreamExecutorBuffer::Release(bool wait_for_operations_to_complete) {
} else {
if (local_device_state->allocation_model() ==
LocalDeviceState::kComputeSynchronized) {
std::unique_ptr<se::Stream> block_stream;
se::Stream* block_stream = nullptr;
for (const auto& stream_and_event : events) {
VLOG(2)
<< "Checking whether need to wait for stream_and_event: stream: "
<< stream_and_event.stream
<< "; event: " << stream_and_event.event.get()
<< "; reference_held: " << stream_and_event.reference_held
<< "; is_predetermined_error: "
<< stream_and_event.event->IsPredeterminedError();
// We only need to do something for events that didn't already acquire a
// reference to the buffer, and also which the compute stream didn't
// already wait for. Based on our heuristics this rare case should only
// occur when a buffer was copied to a device and then never used there.
// In that case we get a new stream and use it to hold onto a reference
// to the buffer until the events are complete.
//
// It is also important that we check IsPredeterminedError before
// checking DefinedOn(compute_stream) because otherwise DefinedOn would
// indefinitely wait since the event is never recorded when the buffer
// is predetermined error.
if (!stream_and_event.event->IsPredeterminedError() &&
!stream_and_event.reference_held &&
!stream_and_event.event->DefinedOn(
local_device_state->compute_stream()) &&
!stream_and_event.event->IsComplete()) {
if (block_stream == nullptr) {
block_stream = local_device_state->BorrowStreamFromPool();
block_stream = local_device_state->GetFixedSizePoolUsageStream();
}
stream_and_event.event->WaitForEventOnStream(block_stream.get());
VLOG(2) << "Waiting for stream_and_event: stream: "
<< stream_and_event.stream
<< "; event: " << stream_and_event.event.get()
<< "; reference_held: " << stream_and_event.reference_held
<< "; is_predetermined_error: "
<< stream_and_event.event->IsPredeterminedError();
stream_and_event.event->WaitForEventOnStream(block_stream);
}
}
for (const auto& definition_event : device_buffer->definition_events()) {
VLOG(2) << "Checking whether need to wait for definition_event: "
<< definition_event.get() << "; is_predetermined_error: "
<< definition_event->IsPredeterminedError();
// Here we wait for the definition events to complete on block_stream as
// well, if they are not on the compute stream and not also recorded as
// usage events.
//
// It is also important that we check IsPredeterminedError before
// checking DefinedOn(compute_stream) because otherwise DefinedOn would
// indefinitely wait since the event is never recorded when the buffer
// is predetermined error.
//
// Since it's possible that definition_event.SetSequencingEvent()
// is called on a different host thread than this host thread, when in
// future more conditions are added to this check, we should be careful
// about whether we put them before the DefinedOn check or after it.
// For example, we shouldn't add an IsDefined() check before the
// DefinedOn() check here because that could potentially cause a
// shortcut where we don't wait for
// definition_event.SetSequencingEvent() on the other thread and
// eventually cause memory corruption.
if (!definition_event->IsPredeterminedError() &&
!definition_event->DefinedOn(
local_device_state->compute_stream()) &&
!definition_event->IsComplete()) {
if (block_stream == nullptr) {
block_stream = local_device_state->GetFixedSizePoolUsageStream();
}
VLOG(2) << "Waiting for definition_event: " << definition_event.get()
<< "; is_predetermined_error: "
<< definition_event->IsPredeterminedError();
definition_event->WaitForEventOnStream(block_stream);
}
}
if (block_stream != nullptr) {
se::Stream* block_stream_ptr = block_stream.release();
TF_RETURN_IF_ERROR(local_device_state->ThenExecuteCallback(
block_stream_ptr,
[device_buffer, block_stream_ptr, local_device_state]() {
local_device_state->ReturnStreamToPool(
std::unique_ptr<se::Stream>(block_stream_ptr));
block_stream, [device_buffer]() {
// Drops device_buffer shared pointer.
}));
}
}
Expand All @@ -1521,8 +1571,20 @@ PjRtStreamExecutorBuffer::Release(bool wait_for_operations_to_complete) {

void PjRtStreamExecutorBuffer::Delete() {
VLOG(1) << "PjRtStreamExecutorBuffer::Delete";

// When wait_for_reads_to_complete is false, Release should never fail.
TF_CHECK_OK(Release(/*wait_for_operations_to_complete=*/false).status());
absl::StatusOr<std::shared_ptr<TrackedDeviceBuffer>> tracked_device_buffer =
Release(/*wait_for_operations_to_complete=*/false);

// The only usage events that
// Release(/*wait_for_operations_to_complete=*/false) doesn't wait for are
// events defined on the compute stream. So we schedule the deallocation on
// the compute stream so that there wouldn't be use-after-free usages on the
// compute stream.
TF_CHECK_OK(tracked_device_buffer.status());
TF_CHECK_OK(device_->local_device_state()->ThenRelease(
device_->local_device_state()->compute_stream(),
std::move(tracked_device_buffer.value())));
}

bool PjRtStreamExecutorBuffer::IsDeleted() {
Expand Down
13 changes: 7 additions & 6 deletions third_party/xla/xla/pjrt/pjrt_stream_executor_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -798,12 +798,13 @@ class PjRtStreamExecutorBuffer : public PjRtBuffer {
// Similar to Delete, drops the buffer's reference to its associated device
// memory, leaving the buffer in an invalid state, but returns the
// TrackedDeviceBuffer rather than freeing the device memory, so that another
// framework can take ownership of it. The buffer returned from Release may
// be safely dropped at any time even if it still has pending async
// operations. The client should call GetReadyFuture()->Await() before calling
// Release with wait_for_operations_to_complete=false, to ensure that the host
// has synchronized past any outstanding write operations to the buffer. If
// wait_for_operations_to_complete=true the host will block until any
// framework can take ownership of it.
//
// When called with wait_for_operations_to_complete=false, the buffer returned
// from Release should be dropped on the compute stream, since the only events
// that Release doesn't wait for are events defined on the compute stream.
//
// If wait_for_operations_to_complete=true, the host will block until any
// potentially outstanding asynchronous operations have completed before
// returning, in which case it is safe to read or mutate the returned buffer.
// If the buffer was shared via an external reference it is the client's
Expand Down

0 comments on commit d3f2cdb

Please sign in to comment.