[go: nahoru, domu]

Skip to content

Commit

Permalink
Try replacing synchronous copies with asynchronous ones (Phase I).
Browse files Browse the repository at this point in the history
In this phase, we only try to replace copies with simpler use patterns, i.e., those without conditional, loop, or nested copy uses.

PiperOrigin-RevId: 630452571
  • Loading branch information
mehrdadkhani authored and tensorflower-gardener committed Jun 5, 2024
1 parent fb6b954 commit 2aaa4f5
Show file tree
Hide file tree
Showing 10 changed files with 3,426 additions and 324 deletions.
10 changes: 10 additions & 0 deletions third_party/xla/xla/service/cpu/runtime/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,19 @@ cc_library(
":thunk",
"//xla/runtime:buffer_use",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/container:fixed_array",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/functional:any_invocable",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/types:span",
"@local_tsl//tsl/platform:errors",
"@local_tsl//tsl/platform:logging",
"@local_tsl//tsl/profiler/lib:traceme",
],
)

Expand All @@ -73,6 +82,7 @@ xla_cc_test(
"//xla/runtime:buffer_use",
"//xla/service:buffer_assignment",
"@com_google_absl//absl/status",
"@local_tsl//tsl/lib/core:status_test_util",
"@local_tsl//tsl/platform:statusor",
"@local_tsl//tsl/platform:test",
"@local_tsl//tsl/platform:test_main",
Expand Down
78 changes: 78 additions & 0 deletions third_party/xla/xla/service/cpu/runtime/thunk_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,24 @@ limitations under the License.

#include "xla/service/cpu/runtime/thunk_executor.h"

#include <atomic>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "absl/types/span.h"
#include "xla/runtime/buffer_use.h"
#include "xla/service/cpu/runtime/thunk.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/logging.h"
#include "tsl/profiler/lib/traceme.h"

namespace xla::cpu {

Expand Down Expand Up @@ -75,6 +83,76 @@ absl::StatusOr<ThunkExecutor> ThunkExecutor::Create(
return ThunkExecutor(std::move(thunk_sequence), std::move(defs));
}

// Builds the per-run state for `executor`: one atomic counter and one runtime
// Node per NodeDef, plus a blocking counter sized to the number of sink nodes.
ThunkExecutor::ExecuteState::ExecuteState(ThunkExecutor* executor,
                                          TaskRunner runner)
    : executor(executor),
      runner(std::move(runner)),
      counters(executor->nodes_defs().size()),
      nodes(executor->nodes_defs().size()),
      done(executor->sink().size()) {
  // Compare against a signed bound so the NodeId loop index is never mixed
  // with an unsigned size.
  const NodeId num_nodes = static_cast<NodeId>(nodes.size());
  for (NodeId node_id = 0; node_id < num_nodes; ++node_id) {
    const NodeDef& def = executor->node_def(node_id);
    std::atomic<int64_t>& pending = counters[node_id];
    // A node becomes ready once all of its in-edges have been satisfied.
    pending.store(def.in_edges.size(), std::memory_order_relaxed);
    // The runtime node borrows its out-edge list from the definition.
    nodes[node_id] = Node{node_id, &pending, &def.out_edges};
  }
}

// Executes the whole thunk sequence using the prepared dataflow graph.
//
// Execution starts from the source nodes; tasks that run ready-queue tails are
// handed to `runner`. Blocks until every sink node has finished.
//
// Returns OK when there is nothing to execute (no source nodes, e.g. an empty
// thunk sequence).
absl::Status ThunkExecutor::Execute(const Thunk::ExecuteParams& params,
                                    TaskRunner runner) {
  // Without source nodes there is nothing to start from. The private Execute
  // overload CHECK-fails on an empty ready queue, so return early instead of
  // crashing the process.
  if (source_.empty()) return absl::OkStatus();

  auto state = std::make_unique<ExecuteState>(this, std::move(runner));

  ReadyQueue ready_queue(source_.begin(), source_.end());
  // NOTE(review): if this ever returns an error, `state` is destroyed while
  // tasks already handed to the runner may still reference it. Today the
  // private overload always returns OK (errors CHECK-fail).
  TF_RETURN_IF_ERROR(Execute(state.get(), params, std::move(ready_queue)));

  tsl::profiler::TraceMe trace("ThunkExecutor::Execute (wait for done)");
  state->done.Wait();

  return absl::OkStatus();
}

// Executes the nodes in `ready_queue` on the calling thread, appending
// successors that become ready along the way.
//
// Before running node `i`, any nodes queued behind it are split off into a
// separate task handed to `state->runner`, so they can run concurrently.
// `ready_queue` must not be empty.
absl::Status ThunkExecutor::Execute(ExecuteState* state,
                                    const Thunk::ExecuteParams& params,
                                    ReadyQueue ready_queue) {
  tsl::profiler::TraceMe trace("ThunkExecutor::Execute");
  CHECK(!ready_queue.empty()) << "Ready queue must not be empty";

  // `ready_queue` grows while we iterate (newly ready nodes are appended), so
  // its size is re-read on every iteration.
  for (size_t i = 0; i < ready_queue.size(); ++i) {
    NodeId id = ready_queue[i];
    Node& node = state->nodes[id];

    // Push the tail of the ready queue to the task runner.
    if (i + 1 < ready_queue.size()) {
      ReadyQueue tail(ready_queue.begin() + i + 1, ready_queue.end());
      ready_queue.erase(ready_queue.begin() + i + 1, ready_queue.end());
      state->runner([&params, state, tail = std::move(tail)]() mutable {
        // TODO(ezhulenev): Add proper error handling.
        CHECK_OK(state->executor->Execute(state, params, std::move(tail)));
      });
    }

    // Execute thunk for the given node id.
    Thunk& thunk = *state->executor->thunk_sequence_[id];
    // TODO(ezhulenev): Add proper error handling.
    CHECK_OK(thunk.Execute(params));

    // Read `node` BEFORE releasing any successor. After the last decrement
    // below, another thread may finish the remaining sink nodes. The caller
    // waiting on `state->done` then destroys `state`, including the `nodes`
    // array that `node` refers into.
    const bool is_sink = node.out_edges->empty();

    // Append ready nodes to the back of the queue. The release half of
    // acq_rel publishes this thunk's side effects. The acquire half on the
    // final decrement makes the side effects of *all* predecessors visible to
    // whichever thread goes on to execute the successor. Relaxed ordering
    // would race when the successor runs inline on this thread.
    for (NodeId out_edge : *node.out_edges) {
      Node& out_node = state->nodes[out_edge];

      int64_t cnt = out_node.counter->fetch_sub(1, std::memory_order_acq_rel);
      DCHECK_GE(cnt, 1) << "Node counter can't drop below 0";
      if (cnt == 1) ready_queue.push_back(out_edge);
    }

    // Drop done counter if the node has no out-edges.
    if (is_sink) {
      state->done.DecrementCount();
    }
  }

  return absl::OkStatus();
}

std::string ThunkExecutor::ToString() const {
std::string str = absl::StrFormat(
"ThunkExecutor: #thunks=%d #source_nodes=%d #sink_nodes=%d",
Expand Down
45 changes: 44 additions & 1 deletion third_party/xla/xla/service/cpu/runtime/thunk_executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,18 @@ limitations under the License.
#ifndef XLA_SERVICE_CPU_RUNTIME_THUNK_EXECUTOR_H_
#define XLA_SERVICE_CPU_RUNTIME_THUNK_EXECUTOR_H_

#include <atomic>
#include <cstdint>
#include <limits>
#include <string>
#include <vector>

#include "absl/container/fixed_array.h"
#include "absl/container/inlined_vector.h"
#include "absl/functional/any_invocable.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/synchronization/blocking_counter.h"
#include "absl/types/span.h"
#include "xla/service/cpu/runtime/thunk.h"

Expand All @@ -32,6 +38,13 @@ namespace xla::cpu {
// thunks concurrently in a given thread pool.
class ThunkExecutor {
public:
// It's up to the caller to provide the task runner that will execute tasks
// produced by the executor. It can be a simple inline executor that runs
// tasks on the same thread, or a runner backed by a thread pool.
using Task = absl::AnyInvocable<void()>;
using TaskRunner = absl::AnyInvocable<void(Task)>;

// Nodes identified by their index in the captured ThunkSequence.
using NodeId = int64_t;

static constexpr NodeId kInvalidNodeId = std::numeric_limits<NodeId>::min();
Expand All @@ -44,19 +57,49 @@ class ThunkExecutor {
// NodeDef defines an execution order for all thunks in a sequence.
struct NodeDef {
NodeId id = kInvalidNodeId;
std::vector<NodeId> out_edges;
std::vector<NodeId> in_edges;
std::vector<NodeId> out_edges;
};

// Executes the thunk sequence using the prepared dataflow graph.
absl::Status Execute(const Thunk::ExecuteParams& params, TaskRunner runner);

absl::Span<const NodeDef> nodes_defs() const { return nodes_defs_; }
const NodeDef& node_def(NodeId id) const { return nodes_defs_[id]; }

absl::Span<const NodeId> source() const { return source_; }
absl::Span<const NodeId> sink() const { return sink_; }

std::string ToString() const;

private:
using ReadyQueue = absl::InlinedVector<NodeId, 8>;

ThunkExecutor(ThunkSequence thunk_sequence, std::vector<NodeDef> nodes_defs);

// At run time NodeDef instantiated as a Node with an atomic counter that
// drops to zero when all in_edges are ready.
struct Node {
// Position of this node in ExecuteState::nodes; the ExecuteState constructor
// sets it to the loop index used to build `nodes`.
NodeId id = kInvalidNodeId;
// Points at this node's entry in ExecuteState::counters (not owned). Starts at
// the NodeDef's in_edges.size() and is decremented once per finished
// predecessor; the decrement that observes 1 makes the node ready.
std::atomic<int64_t>* counter = nullptr;
// Borrowed from the executor's NodeDef for this node (not owned).
const std::vector<NodeId>* out_edges = nullptr;
};

// A struct to keep the state of a running executor.
struct ExecuteState {
// Sizes `counters` and `nodes` to the number of node definitions and `done`
// to the number of sink nodes, then instantiates one Node per NodeDef.
ExecuteState(ThunkExecutor* executor, TaskRunner runner);

// Executor whose thunks are run; a raw pointer, not owned by this state.
ThunkExecutor* executor;
// Receives the tasks produced during execution; each task runs the tail of a
// ready queue via the private Execute overload.
TaskRunner runner;

// Per-node count of predecessors that have not finished yet, indexed by
// NodeId; initialized to the node's in_edges.size().
absl::FixedArray<std::atomic<int64_t>> counters;
// Runtime Node for every NodeDef, indexed by NodeId.
absl::InlinedVector<Node, 32> nodes;
// Initialized to the number of sink nodes and decremented as each sink node
// finishes; the public Execute waits on it before returning.
absl::BlockingCounter done;
};

absl::Status Execute(ExecuteState* state, const Thunk::ExecuteParams& params,
ReadyQueue ready_queue);

ThunkSequence thunk_sequence_;
std::vector<NodeDef> nodes_defs_;

Expand Down
51 changes: 44 additions & 7 deletions third_party/xla/xla/service/cpu/runtime/thunk_executor_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,13 @@ limitations under the License.
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/status/status.h"
#include "xla/runtime/buffer_use.h"
#include "xla/service/buffer_assignment.h"
#include "xla/service/cpu/runtime/thunk.h"
#include "tsl/lib/core/status_test_util.h"
#include "tsl/platform/statusor.h"
#include "tsl/platform/test.h"

Expand All @@ -36,25 +38,32 @@ class BufferUseThunk : public Thunk {
public:
using BufferUses = Thunk::BufferUses;

BufferUseThunk(std::string name, BufferUses buffer_uses)
BufferUseThunk(std::string name, BufferUses buffer_uses,
std::vector<std::string>* trace = nullptr)
: Thunk(Kind::kKernel, Info{name}),
buffer_uses_(std::move(buffer_uses)) {}
buffer_uses_(std::move(buffer_uses)),
trace_(trace) {}

static std::unique_ptr<Thunk> Create(std::string name,
BufferUses buffer_uses) {
static std::unique_ptr<Thunk> Create(
std::string name, BufferUses buffer_uses,
std::vector<std::string>* trace = nullptr) {
return std::make_unique<BufferUseThunk>(std::move(name),
std::move(buffer_uses));
std::move(buffer_uses), trace);
}

BufferUses buffer_uses() const override { return buffer_uses_; }

absl::Status Execute(const ExecuteParams&) final { return absl::OkStatus(); }
absl::Status Execute(const ExecuteParams&) final {
if (trace_) trace_->push_back(info().op_name);
return absl::OkStatus();
}

private:
BufferUses buffer_uses_;
std::vector<std::string>* trace_;
};

TEST(ThunkExecutorTest, Basics) {
TEST(ThunkExecutorTest, Ordering) {
BufferAllocation alloc(/*index=*/0, /*size=*/1024, /*color=*/0);

BufferAllocation::Slice slice0(&alloc, /*offset=*/0, /*size=*/10);
Expand All @@ -73,5 +82,33 @@ TEST(ThunkExecutorTest, Basics) {
EXPECT_THAT(executor.sink(), ElementsAre(0, 2));
}

// Runs three thunks through an inline task runner and checks the exact order
// in which the runner and the thunks were invoked.
TEST(ThunkExecutorTest, Execute) {
  BufferAllocation alloc(/*index=*/0, /*size=*/1024, /*color=*/0);

  // slice0 = [0, 10), slice1 = [5, 15), slice2 = [10, 20). "c" writes slice2,
  // which overlaps the slice1 read by "b"; "a" only touches slice0.
  BufferAllocation::Slice slice0(&alloc, /*offset=*/0, /*size=*/10);
  BufferAllocation::Slice slice1(&alloc, /*offset=*/5, /*size=*/10);
  BufferAllocation::Slice slice2(&alloc, /*offset=*/10, /*size=*/10);

  // Both the thunks and the task runner append to this shared log.
  std::vector<std::string> execution_log;

  ThunkSequence thunks;
  thunks.push_back(BufferUseThunk::Create("a", {BufferUse::Read(slice0)},
                                          &execution_log));
  thunks.push_back(BufferUseThunk::Create("b", {BufferUse::Read(slice1)},
                                          &execution_log));
  thunks.push_back(BufferUseThunk::Create("c", {BufferUse::Write(slice2)},
                                          &execution_log));

  TF_ASSERT_OK_AND_ASSIGN(auto executor,
                          ThunkExecutor::Create(std::move(thunks)));

  // Runs every task synchronously on the calling thread, logging a marker
  // right before the task starts.
  auto inline_runner = [&execution_log](ThunkExecutor::Task task) {
    execution_log.push_back("<TaskRunner>");
    task();
  };

  Thunk::ExecuteParams params;
  TF_ASSERT_OK(executor.Execute(params, inline_runner));

  EXPECT_THAT(execution_log, ElementsAre("<TaskRunner>", "b", "c", "a"));
}

} // namespace
} // namespace xla::cpu
Loading

0 comments on commit 2aaa4f5

Please sign in to comment.