[go: nahoru, domu]

Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Try replacing synchronous copies with asynchronous ones (Phase I). In this phase, we only try replacing copies with less complicated use patterns, i.e., the ones without conditional, loop, or nested copy uses. #66948

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Try replacing synchronous copies with asynchronous ones (Phase I). In…
… this phase, we only try replacing copies with less complicated use patterns, i.e., the ones without conditional, loop, or nested copy uses.

PiperOrigin-RevId: 630452571
  • Loading branch information
mehrdadkhani authored and tensorflower-gardener committed Jun 5, 2024
commit 2aaa4f5ad214eba81121021d431a3b03e1d0f370
10 changes: 10 additions & 0 deletions third_party/xla/xla/service/cpu/runtime/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,19 @@ cc_library(
":thunk",
"//xla/runtime:buffer_use",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/container:fixed_array",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/functional:any_invocable",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/types:span",
"@local_tsl//tsl/platform:errors",
"@local_tsl//tsl/platform:logging",
"@local_tsl//tsl/profiler/lib:traceme",
],
)

Expand All @@ -73,6 +82,7 @@ xla_cc_test(
"//xla/runtime:buffer_use",
"//xla/service:buffer_assignment",
"@com_google_absl//absl/status",
"@local_tsl//tsl/lib/core:status_test_util",
"@local_tsl//tsl/platform:statusor",
"@local_tsl//tsl/platform:test",
"@local_tsl//tsl/platform:test_main",
Expand Down
78 changes: 78 additions & 0 deletions third_party/xla/xla/service/cpu/runtime/thunk_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,24 @@ limitations under the License.

#include "xla/service/cpu/runtime/thunk_executor.h"

#include <atomic>
#include <cstdint>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "absl/types/span.h"
#include "xla/runtime/buffer_use.h"
#include "xla/service/cpu/runtime/thunk.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/logging.h"
#include "tsl/profiler/lib/traceme.h"

namespace xla::cpu {

Expand Down Expand Up @@ -75,6 +83,76 @@ absl::StatusOr<ThunkExecutor> ThunkExecutor::Create(
return ThunkExecutor(std::move(thunk_sequence), std::move(defs));
}

// Builds the per-run bookkeeping for `executor`: one pending-dependency
// counter and one run-time Node per node definition, plus a blocking counter
// sized to the number of sink nodes.
ThunkExecutor::ExecuteState::ExecuteState(ThunkExecutor* executor,
                                          TaskRunner runner)
    : executor(executor),
      runner(std::move(runner)),
      counters(executor->nodes_defs().size()),
      nodes(executor->nodes_defs().size()),
      done(executor->sink().size()) {
  NodeId id = 0;
  for (Node& node : nodes) {
    const NodeDef& def = executor->node_def(id);

    // A node becomes ready once all of its in-edges have completed, so the
    // counter starts at the number of in-edges.
    std::atomic<int64_t>& pending = counters[id];
    pending.store(def.in_edges.size(), std::memory_order_relaxed);

    // The node borrows its counter from `counters` and its out-edges from the
    // executor's node definition.
    node = Node{id, &pending, &def.out_edges};
    ++id;
  }
}

// Executes all thunks in the sequence according to the prepared dataflow graph
// and blocks the calling thread until every sink node has completed.
//
// `runner` is handed the tasks that the executor decides to offload (the tail
// of a ready queue); it may run them inline or on another thread. Offloaded
// tasks capture `params` by reference, which is safe because this function
// does not return before all sink nodes are done.
//
// Returns OK on completion. Thunk failures currently CHECK-fail inside the
// worker overload rather than being reported through the returned status.
absl::Status ThunkExecutor::Execute(const Thunk::ExecuteParams& params,
                                    TaskRunner runner) {
  // An empty graph has no source nodes; without this guard the empty ready
  // queue built below would trip the CHECK in the worker overload.
  if (nodes_defs_.empty()) return absl::OkStatus();

  auto state = std::make_unique<ExecuteState>(this, std::move(runner));

  // Seed execution with all nodes that have no in-edges.
  ReadyQueue ready_queue(source_.begin(), source_.end());
  TF_RETURN_IF_ERROR(Execute(state.get(), params, std::move(ready_queue)));

  // The call above only covers work done on this thread; wait for tasks that
  // were handed to the runner to finish the remaining sink nodes.
  tsl::profiler::TraceMe trace("ThunkExecutor::Execute (wait for done)");
  state->done.Wait();

  return absl::OkStatus();
}

// Worker loop: executes the nodes in `ready_queue` on the current thread,
// offloading everything but the node being processed to `state->runner`, and
// enqueueing successor nodes as soon as their last dependency completes.
//
// `state` holds the shared per-run bookkeeping and must outlive all tasks.
// `params` is forwarded to every thunk and captured by reference in offloaded
// tasks. `ready_queue` must be non-empty (CHECK-enforced).
//
// Always returns OK for now; thunk failures CHECK-fail (see TODOs).
absl::Status ThunkExecutor::Execute(ExecuteState* state,
                                    const Thunk::ExecuteParams& params,
                                    ReadyQueue ready_queue) {
  tsl::profiler::TraceMe trace("ThunkExecutor::Execute");
  CHECK(!ready_queue.empty()) << "Ready queue must not be empty";

  // Index-based loop: the queue is truncated and appended to while iterating.
  for (int64_t i = 0; i < ready_queue.size(); ++i) {
    NodeId id = ready_queue[i];
    Node& node = state->nodes[id];

    // Push the tail of the ready queue to the task runner, so that this thread
    // keeps only the current node and other ready nodes can run elsewhere.
    if (i < ready_queue.size() - 1) {
      ReadyQueue tail(ready_queue.begin() + i + 1, ready_queue.end());
      ready_queue.erase(ready_queue.begin() + i + 1, ready_queue.end());
      state->runner([&params, state, tail = std::move(tail)]() mutable {
        // TODO(ezhulenev): Add proper error handling.
        CHECK_OK(state->executor->Execute(state, params, std::move(tail)));
      });
    }

    // Execute thunk for the given node id.
    Thunk& thunk = *state->executor->thunk_sequence_[id];
    // TODO(ezhulenev): Add proper error handling.
    CHECK_OK(thunk.Execute(params));

    // Append ready nodes to the back of the queue.
    for (NodeId out_edge : *node.out_edges) {
      Node& out_node = state->nodes[out_edge];

      // acq_rel (not relaxed): the release half publishes the side effects of
      // the thunk that just ran, and the acquire half makes the effects of all
      // other predecessors visible to whichever thread observes the counter
      // reaching zero and goes on to execute the successor.
      int64_t cnt = out_node.counter->fetch_sub(1, std::memory_order_acq_rel);
      DCHECK_GE(cnt, 1) << "Node counter can't drop below 0";
      if (cnt == 1) ready_queue.push_back(out_edge);
    }

    // Drop done counter if the node has no out-edges.
    if (node.out_edges->empty()) {
      state->done.DecrementCount();
    }
  }

  return absl::OkStatus();
}

std::string ThunkExecutor::ToString() const {
std::string str = absl::StrFormat(
"ThunkExecutor: #thunks=%d #source_nodes=%d #sink_nodes=%d",
Expand Down
45 changes: 44 additions & 1 deletion third_party/xla/xla/service/cpu/runtime/thunk_executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,18 @@ limitations under the License.
#ifndef XLA_SERVICE_CPU_RUNTIME_THUNK_EXECUTOR_H_
#define XLA_SERVICE_CPU_RUNTIME_THUNK_EXECUTOR_H_

#include <atomic>
#include <cstdint>
#include <limits>
#include <string>
#include <vector>

#include "absl/container/fixed_array.h"
#include "absl/container/inlined_vector.h"
#include "absl/functional/any_invocable.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/synchronization/blocking_counter.h"
#include "absl/types/span.h"
#include "xla/service/cpu/runtime/thunk.h"

Expand All @@ -32,6 +38,13 @@ namespace xla::cpu {
// thunks concurrently in a given thread pool.
class ThunkExecutor {
public:
// It's up to the caller to provide the task runner that will execute tasks
// produced by the executor. It can be a simple inline executor that runs
// tasks on the same thread, or a runner backed by a thread pool.
using Task = absl::AnyInvocable<void()>;
using TaskRunner = absl::AnyInvocable<void(Task)>;

// Nodes identified by their index in the captured ThunkSequence.
using NodeId = int64_t;

static constexpr NodeId kInvalidNodeId = std::numeric_limits<NodeId>::min();
Expand All @@ -44,19 +57,49 @@ class ThunkExecutor {
// NodeDef defines an execution order for all thunks in a sequence.
struct NodeDef {
NodeId id = kInvalidNodeId;
std::vector<NodeId> out_edges;
std::vector<NodeId> in_edges;
std::vector<NodeId> out_edges;
};

// Executes the thunk sequence using the prepared dataflow graph.
absl::Status Execute(const Thunk::ExecuteParams& params, TaskRunner runner);

absl::Span<const NodeDef> nodes_defs() const { return nodes_defs_; }
const NodeDef& node_def(NodeId id) const { return nodes_defs_[id]; }

absl::Span<const NodeId> source() const { return source_; }
absl::Span<const NodeId> sink() const { return sink_; }

std::string ToString() const;

private:
using ReadyQueue = absl::InlinedVector<NodeId, 8>;

ThunkExecutor(ThunkSequence thunk_sequence, std::vector<NodeDef> nodes_defs);

// Run-time counterpart of a NodeDef. A Node does not own any data: it points
// at an atomic counter that starts at the number of in-edges and is
// decremented as predecessors finish (the node is ready when it reaches
// zero), and at the out-edges of the corresponding NodeDef.
//
// Member order matters: nodes are created with aggregate initialization.
struct Node {
// Index of the node (and of its thunk) in the thunk sequence.
NodeId id = kInvalidNodeId;
// Number of in-edges that have not completed yet; storage lives in
// ExecuteState::counters.
std::atomic<int64_t>* counter = nullptr;
// Successor node ids; storage lives in the executor's NodeDef.
const std::vector<NodeId>* out_edges = nullptr;
};

// A struct to keep the state of a running executor. One instance is created
// per top-level Execute() call and shared (by raw pointer) with every task
// handed to the task runner.
struct ExecuteState {
// Sizes `counters` and `nodes` to the number of node definitions of
// `executor`, initializes each counter with the node's in-edge count, and
// sizes `done` to the number of sink nodes.
ExecuteState(ThunkExecutor* executor, TaskRunner runner);

// Executor that started this run; not owned.
ThunkExecutor* executor;
// Runner used to offload the tail of a ready queue.
TaskRunner runner;

// Per-node count of pending in-edges, indexed by NodeId.
absl::FixedArray<std::atomic<int64_t>> counters;
// Run-time nodes, indexed by NodeId; each points into `counters`.
absl::InlinedVector<Node, 32> nodes;
// Decremented once for every executed node without out-edges; the top-level
// Execute() waits on it before returning.
absl::BlockingCounter done;
};

absl::Status Execute(ExecuteState* state, const Thunk::ExecuteParams& params,
ReadyQueue ready_queue);

ThunkSequence thunk_sequence_;
std::vector<NodeDef> nodes_defs_;

Expand Down
51 changes: 44 additions & 7 deletions third_party/xla/xla/service/cpu/runtime/thunk_executor_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,13 @@ limitations under the License.
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/status/status.h"
#include "xla/runtime/buffer_use.h"
#include "xla/service/buffer_assignment.h"
#include "xla/service/cpu/runtime/thunk.h"
#include "tsl/lib/core/status_test_util.h"
#include "tsl/platform/statusor.h"
#include "tsl/platform/test.h"

Expand All @@ -36,25 +38,32 @@ class BufferUseThunk : public Thunk {
public:
using BufferUses = Thunk::BufferUses;

BufferUseThunk(std::string name, BufferUses buffer_uses)
BufferUseThunk(std::string name, BufferUses buffer_uses,
std::vector<std::string>* trace = nullptr)
: Thunk(Kind::kKernel, Info{name}),
buffer_uses_(std::move(buffer_uses)) {}
buffer_uses_(std::move(buffer_uses)),
trace_(trace) {}

static std::unique_ptr<Thunk> Create(std::string name,
BufferUses buffer_uses) {
static std::unique_ptr<Thunk> Create(
std::string name, BufferUses buffer_uses,
std::vector<std::string>* trace = nullptr) {
return std::make_unique<BufferUseThunk>(std::move(name),
std::move(buffer_uses));
std::move(buffer_uses), trace);
}

BufferUses buffer_uses() const override { return buffer_uses_; }

absl::Status Execute(const ExecuteParams&) final { return absl::OkStatus(); }
absl::Status Execute(const ExecuteParams&) final {
if (trace_) trace_->push_back(info().op_name);
return absl::OkStatus();
}

private:
BufferUses buffer_uses_;
std::vector<std::string>* trace_;
};

TEST(ThunkExecutorTest, Basics) {
TEST(ThunkExecutorTest, Ordering) {
BufferAllocation alloc(/*index=*/0, /*size=*/1024, /*color=*/0);

BufferAllocation::Slice slice0(&alloc, /*offset=*/0, /*size=*/10);
Expand All @@ -73,5 +82,33 @@ TEST(ThunkExecutorTest, Basics) {
EXPECT_THAT(executor.sink(), ElementsAre(0, 2));
}

// Runs a three-thunk sequence with an inline task runner and checks the order
// in which the runner and the thunks were invoked.
TEST(ThunkExecutorTest, Execute) {
  // Shared log: thunks append their name, the runner appends a marker.
  std::vector<std::string> events;

  BufferAllocation allocation(/*index=*/0, /*size=*/1024, /*color=*/0);

  // Three 10-byte slices starting at offsets 0, 5 and 10.
  BufferAllocation::Slice slice_a(&allocation, /*offset=*/0, /*size=*/10);
  BufferAllocation::Slice slice_b(&allocation, /*offset=*/5, /*size=*/10);
  BufferAllocation::Slice slice_c(&allocation, /*offset=*/10, /*size=*/10);

  // NOTE(review): presumably "c" depends on "b" because its write overlaps
  // the slice read by "b", while "a" and "b" are independent reads -- confirm
  // against ThunkExecutor::Create.
  ThunkSequence thunks;
  thunks.push_back(
      BufferUseThunk::Create("a", {BufferUse::Read(slice_a)}, &events));
  thunks.push_back(
      BufferUseThunk::Create("b", {BufferUse::Read(slice_b)}, &events));
  thunks.push_back(
      BufferUseThunk::Create("c", {BufferUse::Write(slice_c)}, &events));

  TF_ASSERT_OK_AND_ASSIGN(auto executor,
                          ThunkExecutor::Create(std::move(thunks)));

  // Inline runner: records that it was called, then runs the task right away
  // on the calling thread.
  ThunkExecutor::TaskRunner inline_runner = [&](ThunkExecutor::Task task) {
    events.push_back("<TaskRunner>");
    task();
  };

  Thunk::ExecuteParams params;
  TF_ASSERT_OK(executor.Execute(params, std::move(inline_runner)));

  // The tail of the initial ready queue is offloaded (and, with an inline
  // runner, fully executed) before the first ready thunk runs.
  EXPECT_THAT(events, ElementsAre("<TaskRunner>", "b", "c", "a"));
}

} // namespace
} // namespace xla::cpu
Loading
Loading