[go: nahoru, domu]

Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[xla:cpu] Optimize KernelThunk by passing SE_HOST_KernelArg directly to the kernel #70362

Merged
merged 1 commit into from
Jun 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
[xla:cpu] Optimize KernelThunk by passing SE_HOST_KernelArg directly to the kernel

PiperOrigin-RevId: 646664927
  • Loading branch information
ezhulenev authored and tensorflower-gardener committed Jun 26, 2024
commit 633e9ccc0e93efeadffd572f8b969a54540df00b
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ static void BM_SelectAndScatterF32(benchmark::State& state) {

BENCHMARK(BM_SelectAndScatterF32)
->MeasureProcessCPUTime()
->Arg(64)
->Arg(128)
->Arg(256)
->Arg(512)
Expand Down
30 changes: 16 additions & 14 deletions third_party/xla/xla/service/cpu/runtime/kernel_thunk.cc
Original file line number Diff line number Diff line change
Expand Up @@ -87,38 +87,40 @@ tsl::AsyncValueRef<Thunk::ExecuteEvent> KernelThunk::Execute(
kernel_name_, arguments_buffers_.size(), results_buffers_.size(),
thread_dim_.ToString());

absl::InlinedVector<se::DeviceMemoryBase, 8> buffers_data;
buffers_data.reserve(arguments_buffers_.size() + results_buffers_.size());
absl::InlinedVector<SE_HOST_KernelArg, 8> kernel_args;
kernel_args.reserve(arguments_buffers_.size() + results_buffers_.size());

int64_t arg_num = 0;
for (BufferAllocation::Slice& buffer : arguments_buffers_) {
TF_ASSIGN_OR_RETURN(buffers_data.emplace_back(),
TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase arg_data,
params.buffer_allocations->GetDeviceAddress(buffer));
kernel_args.push_back(
SE_HOST_KernelArg{arg_data.opaque(), arg_data.size()});
VLOG(3) << absl::StreamFormat(" arg #%d: %s (%p)", arg_num++,
buffer.ToString(),
buffers_data.back().opaque());
buffer.ToString(), kernel_args.back().data);
}

int64_t res_num = 0;
for (BufferAllocation::Slice& buffer : results_buffers_) {
TF_ASSIGN_OR_RETURN(buffers_data.emplace_back(),
TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase result_data,
params.buffer_allocations->GetDeviceAddress(buffer));
kernel_args.push_back(
SE_HOST_KernelArg{result_data.opaque(), result_data.size()});
VLOG(3) << absl::StreamFormat(" res #%d: %s (%p)", res_num++,
buffer.ToString(),
buffers_data.back().opaque());
buffer.ToString(), kernel_args.back().data);
}

// Check that all buffers are aligned to the minimum alignment. We codegen
// with the assumption that all buffers are aligned, and if they are not, we
// will crash with a segmentation fault, or worse, produce incorrect results.
if (min_alignment_.has_value()) {
for (int64_t i = 0; i < buffers_data.size(); ++i) {
auto ptr = reinterpret_cast<uintptr_t>(buffers_data[i].opaque());
for (int64_t i = 0; i < kernel_args.size(); ++i) {
auto ptr = reinterpret_cast<uintptr_t>(kernel_args[i].data);
if (ABSL_PREDICT_FALSE((ptr & (*min_alignment_ - 1)) != 0)) {
return Internal(
"Host kernel %s buffer argument #%d (%p) is not aligned to a "
"required minimum alignment of %d bytes",
info().op_name, i, buffers_data[i].opaque(), *min_alignment_);
info().op_name, i, kernel_args[i].data, *min_alignment_);
}
}
}
Expand All @@ -134,22 +136,22 @@ tsl::AsyncValueRef<Thunk::ExecuteEvent> KernelThunk::Execute(
params.host_kernels->Find(kernel_name_));

absl::MutexLock lock(&mutex_);
kernel_.emplace(buffers_data.size(), kernel_fn, nullptr);
kernel_.emplace(kernel_args.size(), kernel_fn, nullptr);
kernel_ptr_.store(kernel = &kernel_.value());
}

// If intra-op thread pool is not nullptr, we launch HostKernel in async mode
// by scheduling tasks into it. HostKernel launch completion will
// automatically signal KernelThunk execute completion.
if (ABSL_PREDICT_FALSE(params.intra_op_threadpool && use_task_runner_)) {
return kernel->Launch(thread_dim_, buffers_data,
return kernel->Launch(thread_dim_, kernel_args,
[&params](se::host::HostKernel::Task task) {
params.intra_op_threadpool->getPool()->Schedule(
ToCopyableTask(std::move(task)));
});
}

TF_RETURN_IF_ERROR(kernel->Launch(thread_dim_, buffers_data));
TF_RETURN_IF_ERROR(kernel->Launch(thread_dim_, kernel_args));
return OkExecuteEvent();
}

Expand Down
4 changes: 2 additions & 2 deletions third_party/xla/xla/service/cpu/runtime/kernel_thunk_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ class AddF32HostKernels : public Thunk::HostKernels {
public:
absl::StatusOr<SE_HOST_Kernel*> Find(std::string_view name) override {
return +[](const SE_HOST_KernelCallFrame* call_frame) {
SE_HOST_KernelArg& in = call_frame->args[0];
SE_HOST_KernelArg& out = call_frame->args[1];
const SE_HOST_KernelArg& in = call_frame->args[0];
const SE_HOST_KernelArg& out = call_frame->args[1];

float* in_ptr = reinterpret_cast<float*>(in.data);
float* out_ptr = reinterpret_cast<float*>(out.data);
Expand Down
33 changes: 24 additions & 9 deletions third_party/xla/xla/stream_executor/host/host_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ class HostKernelExecuteState
HostKernelExecuteState(HostKernel::TaskRunner task_runner,
HostKernel::KernelFunction* function,
ThreadDim thread_dims,
absl::Span<const DeviceMemoryBase> buffers);
absl::Span<const SE_HOST_KernelArg> args);

// Notify of a completion of a host kernel task.
void Notify(absl::Status status);
Expand Down Expand Up @@ -118,11 +118,19 @@ HostKernel::HostKernel(unsigned arity, SE_HOST_Kernel* kernel,
absl::Status HostKernel::Launch(
const ThreadDim& thread_dims,
absl::Span<const DeviceMemoryBase> buffers) const {
SE_HOST_KernelThreadDim kernel_thread_dims = {thread_dims.x, thread_dims.y,
thread_dims.z};
return Launch(thread_dims, ConvertBuffersToKernelArgs(buffers));
}

absl::Status HostKernel::Launch(
const ThreadDim& thread_dims,
absl::Span<const SE_HOST_KernelArg> args) const {
SE_HOST_KernelThreadDim kernel_thread_dims = {
thread_dims.x,
thread_dims.y,
thread_dims.z,
};

SE_HOST_Kernel* kernel = function_->kernel();
auto args = ConvertBuffersToKernelArgs(buffers);

for (uint64_t z = 0; z < thread_dims.z; ++z) {
for (uint64_t y = 0; y < thread_dims.y; ++y) {
Expand All @@ -134,7 +142,7 @@ absl::Status HostKernel::Launch(

SE_HOST_KernelError* error = (*kernel)(&call_frame);

if (error != nullptr) {
if (ABSL_PREDICT_FALSE(error != nullptr)) {
return absl::InternalError("Failed to call host kernel");
}
}
Expand All @@ -147,20 +155,27 @@ absl::Status HostKernel::Launch(
tsl::AsyncValueRef<LaunchEvent> HostKernel::Launch(
const ThreadDim& thread_dims, absl::Span<const DeviceMemoryBase> buffers,
TaskRunner task_runner) const {
return Launch(thread_dims, ConvertBuffersToKernelArgs(buffers),
std::move(task_runner));
}

tsl::AsyncValueRef<LaunchEvent> HostKernel::Launch(
const ThreadDim& thread_dims, absl::Span<const SE_HOST_KernelArg> args,
TaskRunner task_runner) const {
size_t num_tasks = thread_dims.x * thread_dims.y * thread_dims.z;
CHECK_GT(num_tasks, 0) << "Number of tasks must be positive"; // Crash Ok

// Short-circuit launch with a single task and run it in the caller thread.
if (ABSL_PREDICT_TRUE(num_tasks == 1)) {
absl::Status launched = Launch(thread_dims, buffers);
absl::Status launched = Launch(thread_dims, args);
return ABSL_PREDICT_TRUE(launched.ok())
? OkLaunchEvent()
: tsl::MakeErrorAsyncValueRef(std::move(launched));
}

// Allocate a control structure that will orchestrate kernel execution.
auto state = tsl::MakeRef<HostKernelExecuteState>(
std::move(task_runner), function_.get(), thread_dims, buffers);
std::move(task_runner), function_.get(), thread_dims, args);

state->CallAsync(/*start_index=*/0, /*end_index=*/num_tasks);

Expand All @@ -169,12 +184,12 @@ tsl::AsyncValueRef<LaunchEvent> HostKernel::Launch(

HostKernelExecuteState::HostKernelExecuteState(
HostKernel::TaskRunner task_runner, HostKernel::KernelFunction* function,
ThreadDim thread_dims, absl::Span<const DeviceMemoryBase> buffers)
ThreadDim thread_dims, absl::Span<const SE_HOST_KernelArg> args)
: task_runner_(std::move(task_runner)),
num_tasks_(thread_dims.x * thread_dims.y * thread_dims.z),
kernel_(function->kernel()),
thread_dims_({thread_dims.x, thread_dims.y, thread_dims.z}),
args_(ConvertBuffersToKernelArgs(buffers)),
args_(args.begin(), args.end()),
abort_(false),
counter_(num_tasks_),
event_(tsl::MakeConstructedAsyncValueRef<LaunchEvent>()) {}
Expand Down
5 changes: 5 additions & 0 deletions third_party/xla/xla/stream_executor/host/host_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ class HostKernel : public Kernel {
// `thread_dims` and calling the kernel function.
absl::Status Launch(const ThreadDim& thread_dims,
absl::Span<const DeviceMemoryBase> buffers) const;
absl::Status Launch(const ThreadDim& thread_dims,
absl::Span<const SE_HOST_KernelArg> args) const;

// Launches the kernel by iterating over all threads in `thread_dims` and
// calling `task_runner` to run individual task (implementation might decide
Expand All @@ -93,6 +95,9 @@ class HostKernel : public Kernel {
tsl::AsyncValueRef<LaunchEvent> Launch(
const ThreadDim& thread_dims, absl::Span<const DeviceMemoryBase> buffers,
TaskRunner task_runner) const;
tsl::AsyncValueRef<LaunchEvent> Launch(
const ThreadDim& thread_dims, absl::Span<const SE_HOST_KernelArg> args,
TaskRunner task_runner) const;

// For host platform, we assume that a core is a thread, and we can run at
// most one instance of a kernel on a given thread.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ typedef struct SE_HOST_KernelCallFrame {
SE_HOST_KernelThread* thread;

size_t num_args;
SE_HOST_KernelArg* args;
const SE_HOST_KernelArg* args;
} SE_HOST_KernelCallFrame;

// Error reporting for host kernels. NULL means success.
Expand Down
25 changes: 16 additions & 9 deletions third_party/xla/xla/stream_executor/host/host_kernel_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,9 @@ static auto ToCopyableTask(HostKernel::Task task) {
}

static SE_HOST_KernelError* AddI32(const SE_HOST_KernelCallFrame* call_frame) {
SE_HOST_KernelArg& lhs = call_frame->args[0];
SE_HOST_KernelArg& rhs = call_frame->args[1];
SE_HOST_KernelArg& out = call_frame->args[2];
const SE_HOST_KernelArg& lhs = call_frame->args[0];
const SE_HOST_KernelArg& rhs = call_frame->args[1];
const SE_HOST_KernelArg& out = call_frame->args[2];

int32_t* lhs_ptr = reinterpret_cast<int32_t*>(lhs.data);
int32_t* rhs_ptr = reinterpret_cast<int32_t*>(rhs.data);
Expand Down Expand Up @@ -217,7 +217,9 @@ TEST(HostKernelTest, LaunchAsync) {
};

HostKernel host_kernel(/*arity=*/0, no_op);
auto event = host_kernel.Launch(ThreadDim(4, 4, 4), {}, std::move(runner));
auto event = host_kernel.Launch(ThreadDim(4, 4, 4),
absl::Span<const SE_HOST_KernelArg>(),
std::move(runner));

tsl::BlockUntilReady(event);
EXPECT_TRUE(event.IsConcrete());
Expand Down Expand Up @@ -245,7 +247,9 @@ TEST(HostKernelTest, LaunchAsyncError) {
};

HostKernel host_kernel(/*arity=*/0, maybe_error);
auto event = host_kernel.Launch(ThreadDim(4, 4, 4), {}, std::move(runner));
auto event = host_kernel.Launch(ThreadDim(4, 4, 4),
absl::Span<const SE_HOST_KernelArg>(),
std::move(runner));

tsl::BlockUntilReady(event);
ASSERT_TRUE(event.IsError());
Expand All @@ -269,7 +273,8 @@ static void BM_HostKernelSyncLaunch(benchmark::State& state) {

HostKernel kernel(/*arity=*/0, NoOp);
for (auto _ : state) {
benchmark::DoNotOptimize(kernel.Launch(ThreadDim(tdim_x), /*buffers=*/{}));
benchmark::DoNotOptimize(kernel.Launch(
ThreadDim(tdim_x), absl::Span<const SE_HOST_KernelArg>()));
}
}

Expand All @@ -281,9 +286,11 @@ static void BM_HostKernelAsyncLaunch(benchmark::State& state) {

HostKernel kernel(/*arity=*/0, NoOp);
for (auto _ : state) {
auto event = kernel.Launch(ThreadDim(tdim_x), {}, [&](auto task) {
thread_pool->Schedule(ToCopyableTask(std::move(task)));
});
auto event =
kernel.Launch(ThreadDim(tdim_x), absl::Span<const SE_HOST_KernelArg>(),
[&](auto task) {
thread_pool->Schedule(ToCopyableTask(std::move(task)));
});
tsl::BlockUntilReady(event);
}
}
Expand Down