[xla:cpu] Optimize KernelThunk by passing SE_HOST_KernelArg directly to the kernel

PiperOrigin-RevId: 646311089
ezhulenev authored and tensorflower-gardener committed Jun 25, 2024
1 parent 01aeb51 commit e4ab12a
Showing 11 changed files with 88 additions and 98 deletions.
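
At a high level, this change removes a per-launch conversion: KernelThunk used
to collect se::DeviceMemoryBase handles, and the launch path converted them to
the C API's SE_HOST_KernelArg structs (via ConvertBuffersToKernelArgs) on
every call. After this commit the thunk builds SE_HOST_KernelArg values
directly while resolving buffer addresses. A minimal sketch of the idea, using
simplified stand-ins for the real StreamExecutor types:

#include <cstddef>
#include <vector>

// Stand-in for SE_HOST_KernelArg from the host C API: pointer plus size.
struct KernelArg {
  void* data;
  size_t size;
};

// Simplified stand-in for se::DeviceMemoryBase.
struct DeviceMemory {
  void* opaque() const { return ptr; }
  size_t size() const { return len; }
  void* ptr;
  size_t len;
};

// Before: the thunk stored DeviceMemory handles, and every launch paid for
// this conversion into the C API representation.
std::vector<KernelArg> ConvertBuffersToKernelArgs(
    const std::vector<DeviceMemory>& buffers) {
  std::vector<KernelArg> args;
  args.reserve(buffers.size());
  for (const DeviceMemory& buffer : buffers) {
    args.push_back(KernelArg{buffer.opaque(), buffer.size()});
  }
  return args;
}

// After: the thunk emits KernelArg directly when it resolves each buffer, so
// the conversion above only runs for legacy callers.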
2 changes: 2 additions & 0 deletions third_party/xla/opensource_only.files
@@ -1,10 +1,12 @@
 compiler/xla/mlir_hlo/WORKSPACE:
+compiler/xla/package_groups.bzl:
 compiler/xla/stream_executor/build_defs.bzl:
 compiler/xla/tsl/cuda/stub.bzl:
 compiler/xla/tsl/mkl/BUILD:
 compiler/xla/tsl/mkl/LICENSE:
 compiler/xla/tsl/mkl/MKL_LICENSE:
 compiler/xla/tsl/mkl/build_defs.bzl:
+compiler/xla/tsl/package_groups.bzl:
 third_party/BUILD:
 third_party/__init__.py:
 third_party/compute_library/BUILD:
39 changes: 3 additions & 36 deletions third_party/xla/xla/BUILD
@@ -7,6 +7,7 @@ load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library")
load("//third_party/compute_library:build_defs.bzl", "if_enable_acl")

# Placeholder: load py_proto_library
load("//xla:package_groups.bzl", "xla_package_groups")
load("//xla:xla.bzl", "xla_cc_test", "xla_py_proto_library")
load("//xla/tsl:tsl.bzl", "if_google", "internal_visibility")
load("//xla/tsl:tsl.default.bzl", "filegroup", "get_compatible_with_portable")
@@ -17,46 +18,12 @@ package(
     licenses = ["notice"],
 )
 
-package_group(
-    name = "friends",
-    includes = ["//xla:internal"],
-    packages = [
-        # copybara:uncomment "//learning/...",
-        "//third_party/australis/...",
-        "//third_party/iree/...",
-        "//third_party/libxc/...",
-        "//third_party/mira/...",
-        "//third_party/mlcompass/...",
-        "//third_party/mlir_edge/model_curriculum/...",
-        "//third_party/openxla/shardonnay/...",
-        "//third_party/py/enzyme_ad/...",
-        "//third_party/py/jax/...",
-        "//third_party/py/t5x/...",
-        "//third_party/py/tpu_graphs/...",
-        "//tensorflow/compiler/...",
-        "//tensorflow/python/tpu/...",
-    ],
-)
-
-package_group(
-    name = "internal",
-    packages = [
-        "//xla/...",
-    ],
-)
-
-package_group(
-    name = "runtime",
-    packages = [
-        "//xla/runtime/...",
-        "//xla/service/gpu/runtime/...",
-    ],
-)
-
 exports_files([
     "lit.cfg.py",
 ])
 
+xla_package_groups()
+
 # Filegroup used to collect source files for dependency checking.
 filegroup(
     name = "c_srcs",
23 changes: 23 additions & 0 deletions third_party/xla/xla/package_groups.bzl
@@ -0,0 +1,23 @@
"""XLA package_group definitions"""

def xla_package_groups(name = "xla_package_groups"):
native.package_group(
name = "friends",
packages = ["//..."],
)

native.package_group(
name = "internal",
packages = ["//..."],
)

native.package_group(
name = "runtime",
packages = ["//..."],
)

def xla_tests_package_groups(name = "xla_tests_package_groups"):
native.package_group(
name = "friends",
packages = ["//..."],
)
@@ -73,6 +73,7 @@ static void BM_SelectAndScatterF32(benchmark::State& state) {
 
 BENCHMARK(BM_SelectAndScatterF32)
     ->MeasureProcessCPUTime()
+    ->Arg(64)
     ->Arg(128)
     ->Arg(256)
     ->Arg(512)
30 changes: 16 additions & 14 deletions third_party/xla/xla/service/cpu/runtime/kernel_thunk.cc
@@ -87,38 +87,40 @@ tsl::AsyncValueRef<Thunk::ExecuteEvent> KernelThunk::Execute(
       kernel_name_, arguments_buffers_.size(), results_buffers_.size(),
       thread_dim_.ToString());
 
-  absl::InlinedVector<se::DeviceMemoryBase, 8> buffers_data;
-  buffers_data.reserve(arguments_buffers_.size() + results_buffers_.size());
+  absl::InlinedVector<SE_HOST_KernelArg, 8> kernel_args;
+  kernel_args.reserve(arguments_buffers_.size() + results_buffers_.size());
 
   int64_t arg_num = 0;
   for (BufferAllocation::Slice& buffer : arguments_buffers_) {
-    TF_ASSIGN_OR_RETURN(buffers_data.emplace_back(),
+    TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase arg_data,
                         params.buffer_allocations->GetDeviceAddress(buffer));
+    kernel_args.push_back(
+        SE_HOST_KernelArg{arg_data.opaque(), arg_data.size()});
     VLOG(3) << absl::StreamFormat(" arg #%d: %s (%p)", arg_num++,
-                                  buffer.ToString(),
-                                  buffers_data.back().opaque());
+                                  buffer.ToString(), kernel_args.back().data);
   }
 
   int64_t res_num = 0;
   for (BufferAllocation::Slice& buffer : results_buffers_) {
-    TF_ASSIGN_OR_RETURN(buffers_data.emplace_back(),
+    TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase result_data,
                         params.buffer_allocations->GetDeviceAddress(buffer));
+    kernel_args.push_back(
+        SE_HOST_KernelArg{result_data.opaque(), result_data.size()});
     VLOG(3) << absl::StreamFormat(" res #%d: %s (%p)", res_num++,
-                                  buffer.ToString(),
-                                  buffers_data.back().opaque());
+                                  buffer.ToString(), kernel_args.back().data);
   }
 
   // Check that all buffers are aligned to the minimum alignment. We codegen
   // with the assumption that all buffers are aligned, and if they are not, we
   // will crash with a segmentation fault, or worse, produce incorrect results.
   if (min_alignment_.has_value()) {
-    for (int64_t i = 0; i < buffers_data.size(); ++i) {
-      auto ptr = reinterpret_cast<uintptr_t>(buffers_data[i].opaque());
+    for (int64_t i = 0; i < kernel_args.size(); ++i) {
+      auto ptr = reinterpret_cast<uintptr_t>(kernel_args[i].data);
       if (ABSL_PREDICT_FALSE((ptr & (*min_alignment_ - 1)) != 0)) {
         return Internal(
             "Host kernel %s buffer argument #%d (%p) is not aligned to a "
            "required minimum alignment of %d bytes",
-            info().op_name, i, buffers_data[i].opaque(), *min_alignment_);
+            info().op_name, i, kernel_args[i].data, *min_alignment_);
      }
    }
  }
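
The alignment check above uses a standard bit trick: for a power-of-two
alignment a, ptr % a equals ptr & (a - 1), so a nonzero masked value means the
pointer is misaligned. A self-contained illustration (not part of the commit):

#include <cstdint>
#include <cstdio>

// Returns true iff `ptr` is aligned to `alignment`, which must be a power of
// two; only then does (value & (alignment - 1)) compute value % alignment.
bool IsAligned(const void* ptr, uintptr_t alignment) {
  auto value = reinterpret_cast<uintptr_t>(ptr);
  return (value & (alignment - 1)) == 0;
}

int main() {
  alignas(64) static char buffer[128];
  std::printf("%d\n", IsAligned(buffer, 64));      // prints 1
  std::printf("%d\n", IsAligned(buffer + 8, 64));  // prints 0
}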
@@ -134,22 +136,22 @@ tsl::AsyncValueRef<Thunk::ExecuteEvent> KernelThunk::Execute(
                         params.host_kernels->Find(kernel_name_));
 
     absl::MutexLock lock(&mutex_);
-    kernel_.emplace(buffers_data.size(), kernel_fn, nullptr);
+    kernel_.emplace(kernel_args.size(), kernel_fn, nullptr);
     kernel_ptr_.store(kernel = &kernel_.value());
   }
 
   // If intra-op thread pool is not nullptr, we launch HostKernel in async mode
   // by scheduling tasks into it. HostKernel launch completion will
   // automatically signal KernelThunk execute completion.
   if (ABSL_PREDICT_FALSE(params.intra_op_threadpool && use_task_runner_)) {
-    return kernel->Launch(thread_dim_, buffers_data,
+    return kernel->Launch(thread_dim_, kernel_args,
                           [&params](se::host::HostKernel::Task task) {
                             params.intra_op_threadpool->getPool()->Schedule(
                                 ToCopyableTask(std::move(task)));
                           });
   }
 
-  TF_RETURN_IF_ERROR(kernel->Launch(thread_dim_, buffers_data));
+  TF_RETURN_IF_ERROR(kernel->Launch(thread_dim_, kernel_args));
   return OkExecuteEvent();
 }

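The mutex block shown above is the slow path of a lazy one-time
initialization: the kernel object is constructed on first execution under
mutex_, then published through an atomic pointer (kernel_ptr_) that later
executions read without locking. A simplified sketch of that pattern, using
stand-in types rather than the real HostKernel:

#include <atomic>
#include <mutex>
#include <optional>

struct Kernel {
  explicit Kernel(int arity) : arity(arity) {}
  int arity;
};

class LazyKernel {
 public:
  Kernel* Get(int arity) {
    // Fast path: already constructed and published; no lock taken.
    if (Kernel* kernel = kernel_ptr_.load()) return kernel;

    // Slow path: construct once under the mutex, then publish.
    std::lock_guard<std::mutex> lock(mutex_);
    if (!kernel_.has_value()) kernel_.emplace(arity);
    Kernel* kernel = &kernel_.value();
    kernel_ptr_.store(kernel);
    return kernel;
  }

 private:
  std::atomic<Kernel*> kernel_ptr_{nullptr};
  std::mutex mutex_;
  std::optional<Kernel> kernel_;
};
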
33 changes: 24 additions & 9 deletions third_party/xla/xla/stream_executor/host/host_kernel.cc
@@ -69,7 +69,7 @@ class HostKernelExecuteState
   HostKernelExecuteState(HostKernel::TaskRunner task_runner,
                          HostKernel::KernelFunction* function,
                          ThreadDim thread_dims,
-                         absl::Span<const DeviceMemoryBase> buffers);
+                         absl::Span<const SE_HOST_KernelArg> args);
 
   // Notify of a completion of a host kernel task.
   void Notify(absl::Status status);
@@ -118,11 +118,19 @@ HostKernel::HostKernel(unsigned arity, SE_HOST_Kernel* kernel,
 absl::Status HostKernel::Launch(
     const ThreadDim& thread_dims,
     absl::Span<const DeviceMemoryBase> buffers) const {
-  SE_HOST_KernelThreadDim kernel_thread_dims = {thread_dims.x, thread_dims.y,
-                                                thread_dims.z};
+  return Launch(thread_dims, ConvertBuffersToKernelArgs(buffers));
+}
+
+absl::Status HostKernel::Launch(
+    const ThreadDim& thread_dims,
+    absl::Span<const SE_HOST_KernelArg> args) const {
+  SE_HOST_KernelThreadDim kernel_thread_dims = {
+      thread_dims.x,
+      thread_dims.y,
+      thread_dims.z,
+  };
 
   SE_HOST_Kernel* kernel = function_->kernel();
-  auto args = ConvertBuffersToKernelArgs(buffers);
 
   for (uint64_t z = 0; z < thread_dims.z; ++z) {
     for (uint64_t y = 0; y < thread_dims.y; ++y) {
@@ -134,7 +142,7 @@ absl::Status HostKernel::Launch(
 
         SE_HOST_KernelError* error = (*kernel)(&call_frame);
 
-        if (error != nullptr) {
+        if (ABSL_PREDICT_FALSE(error != nullptr)) {
           return absl::InternalError("Failed to call host kernel");
         }
       }
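
The new ABSL_PREDICT_FALSE annotation tells the compiler the error branch is
cold, keeping the per-thread dispatch loop on the fall-through path. It
expands to __builtin_expect; a simplified equivalent (see
absl/base/optimization.h for the real definition):

// Simplified stand-in for ABSL_PREDICT_FALSE (GCC/Clang builtin assumed).
#define MY_PREDICT_FALSE(x) (__builtin_expect(false || (x), false))

int SafeDivide(int a, int b) {
  if (MY_PREDICT_FALSE(b == 0)) {
    return 0;  // cold path: laid out off the hot, fall-through sequence
  }
  return a / b;  // hot path
}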
@@ -147,20 +155,27 @@ absl::Status HostKernel::Launch(
 tsl::AsyncValueRef<LaunchEvent> HostKernel::Launch(
     const ThreadDim& thread_dims, absl::Span<const DeviceMemoryBase> buffers,
     TaskRunner task_runner) const {
+  return Launch(thread_dims, ConvertBuffersToKernelArgs(buffers),
+                std::move(task_runner));
+}
+
+tsl::AsyncValueRef<LaunchEvent> HostKernel::Launch(
+    const ThreadDim& thread_dims, absl::Span<const SE_HOST_KernelArg> args,
+    TaskRunner task_runner) const {
   size_t num_tasks = thread_dims.x * thread_dims.y * thread_dims.z;
   CHECK_GT(num_tasks, 0) << "Number of tasks must be positive";  // Crash Ok
 
   // Short-circuit launch with a single task and run it in the caller thread.
   if (ABSL_PREDICT_TRUE(num_tasks == 1)) {
-    absl::Status launched = Launch(thread_dims, buffers);
+    absl::Status launched = Launch(thread_dims, args);
     return ABSL_PREDICT_TRUE(launched.ok())
                ? OkLaunchEvent()
                : tsl::MakeErrorAsyncValueRef(std::move(launched));
   }
 
   // Allocate a control structure that will orchestrate kernel execution.
   auto state = tsl::MakeRef<HostKernelExecuteState>(
-      std::move(task_runner), function_.get(), thread_dims, buffers);
+      std::move(task_runner), function_.get(), thread_dims, args);
 
   state->CallAsync(/*start_index=*/0, /*end_index=*/num_tasks);
 
@@ -169,12 +184,12 @@ tsl::AsyncValueRef<LaunchEvent> HostKernel::Launch(
 
 HostKernelExecuteState::HostKernelExecuteState(
     HostKernel::TaskRunner task_runner, HostKernel::KernelFunction* function,
-    ThreadDim thread_dims, absl::Span<const DeviceMemoryBase> buffers)
+    ThreadDim thread_dims, absl::Span<const SE_HOST_KernelArg> args)
     : task_runner_(std::move(task_runner)),
       num_tasks_(thread_dims.x * thread_dims.y * thread_dims.z),
       kernel_(function->kernel()),
       thread_dims_({thread_dims.x, thread_dims.y, thread_dims.z}),
-      args_(ConvertBuffersToKernelArgs(buffers)),
+      args_(args.begin(), args.end()),
       abort_(false),
       counter_(num_tasks_),
       event_(tsl::MakeConstructedAsyncValueRef<LaunchEvent>()) {}
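
With the execute state now taking a span of SE_HOST_KernelArg, the old
DeviceMemoryBase overloads survive only as thin adapters: they convert once
and forward, so a single implementation of the launch logic remains. A
schematic of that overload-delegation pattern in plain C++20 (stand-in types,
not the XLA classes):

#include <cstddef>
#include <span>
#include <vector>

struct KernelArg {
  void* data;
  size_t size;
};

struct DeviceMemory {
  void* ptr;
  size_t len;
};

class Kernel {
 public:
  // Legacy entry point: convert once, then delegate.
  void Launch(std::span<const DeviceMemory> buffers) {
    std::vector<KernelArg> args;
    args.reserve(buffers.size());
    for (const DeviceMemory& b : buffers) args.push_back({b.ptr, b.len});
    Launch(std::span<const KernelArg>(args));
  }

  // Preferred entry point: arguments already match the kernel ABI.
  void Launch(std::span<const KernelArg> args) {
    // The single dispatch-loop implementation would live here.
    (void)args;
  }
};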
5 changes: 5 additions & 0 deletions third_party/xla/xla/stream_executor/host/host_kernel.h
@@ -80,6 +80,8 @@ class HostKernel : public Kernel {
   // `thread_dims` and calling the kernel function.
   absl::Status Launch(const ThreadDim& thread_dims,
                       absl::Span<const DeviceMemoryBase> buffers) const;
+  absl::Status Launch(const ThreadDim& thread_dims,
+                      absl::Span<const SE_HOST_KernelArg> args) const;
 
   // Launches the kernel by iterating over all threads in `thread_dims` and
   // calling `task_runner` to run individual task (implementation might decide
@@ -93,6 +95,9 @@
   tsl::AsyncValueRef<LaunchEvent> Launch(
       const ThreadDim& thread_dims, absl::Span<const DeviceMemoryBase> buffers,
       TaskRunner task_runner) const;
+  tsl::AsyncValueRef<LaunchEvent> Launch(
+      const ThreadDim& thread_dims, absl::Span<const SE_HOST_KernelArg> args,
+      TaskRunner task_runner) const;
 
   // For host platform, we assume that a core is a thread, and we can run at
   // most one instance of a kernel on a given thread.
@@ -71,7 +71,7 @@ typedef struct SE_HOST_KernelCallFrame {
   SE_HOST_KernelThread* thread;
 
   size_t num_args;
-  SE_HOST_KernelArg* args;
+  const SE_HOST_KernelArg* args;
 } SE_HOST_KernelCallFrame;
 
 // Error reporting for host kernels. NULL means success.
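
The const qualifier added here documents that a kernel may read the argument
table but never rewrite it, which lets the launch path pass a pointer into its
own argument span without a defensive copy. A hypothetical kernel written
against this C API (typedefs assumed from the surrounding header; the function
is illustrative, not part of the commit):

// Adds one f32 element per thread along x; returning NULL reports success.
static SE_HOST_KernelError* AddF32Kernel(
    const SE_HOST_KernelCallFrame* call_frame) {
  const SE_HOST_KernelArg& lhs = call_frame->args[0];
  const SE_HOST_KernelArg& rhs = call_frame->args[1];
  const SE_HOST_KernelArg& out = call_frame->args[2];

  uint64_t i = call_frame->thread->x;  // this thread's element index
  static_cast<float*>(out.data)[i] =
      static_cast<const float*>(lhs.data)[i] +
      static_cast<const float*>(rhs.data)[i];

  return nullptr;
}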
11 changes: 3 additions & 8 deletions third_party/xla/xla/tests/BUILD
@@ -11,6 +11,7 @@ load(
"@local_tsl//tsl/platform/default:cuda_build_defs.bzl",
"if_cuda_is_configured",
)
load("//xla:package_groups.bzl", "xla_tests_package_groups")
load("//xla:xla.bzl", "tests_build_defs_bzl_deps", "xla_cc_binary", "xla_cc_test")
load("//xla/tests:build_defs.bzl", "generate_backend_suites", "generate_backend_test_macros", "xla_test", "xla_test_library")
load("//xla/tsl:tsl.bzl", "internal_visibility", "tsl_copts")
@@ -22,14 +23,6 @@ package(
     licenses = ["notice"],
 )
 
-package_group(
-    name = "friends",
-    includes = [
-        "//xla:friends",
-    ],
-    packages = ["//platforms/testing/tests/..."],
-)
-
 # Filegroup used to collect source files for dependency checking.
 filegroup(
     name = "c_srcs",
@@ -39,6 +32,8 @@
     ]),
 )
 
+xla_tests_package_groups()
+
 # Generate test_suites for all backends, named "${backend}_tests".
 generate_backend_suites()
 
33 changes: 3 additions & 30 deletions third_party/xla/xla/tsl/BUILD
@@ -1,11 +1,14 @@
load("@bazel_skylib//:bzl_library.bzl", "bzl_library")
load("@bazel_skylib//lib:selects.bzl", "selects")
load("@bazel_skylib//rules:common_settings.bzl", "bool_flag", "bool_setting")
load("package_groups.bzl", "tsl_package_groups")
load("tsl.bzl", "if_google", "if_oss")
load("tsl.default.bzl", "tsl_google_bzl_deps")

# copybara:uncomment package(default_applicable_licenses = ["//tensorflow:license"])

tsl_package_groups()

# Config setting to use in select()s to distinguish open source build from
# google internal build on configurable attributes.
#
@@ -497,36 +500,6 @@ config_setting(
     visibility = ["//visibility:public"],
 )
 
-# DO NOT ADD ANY NEW EXCEPTIONS TO THIS LIST!
-# Instead, please use public APIs or public build rules TF provides.
-# If you need functionality that is not exposed, we will work with you to expand our public APIs.
-# TODO(b/173549186): Move Google-internal TF code out of learning/brain
-# TODO(jakeharmon): Prune this for use in TSL
-package_group(
-    name = "internal",
-    packages = [
-        "//devtools/python/indexer/...",
-        "//learning/brain/keras/...",
-        "//learning/brain/mlir/...",
-        "//learning/brain/tfrt/...",
-        "//learning/lib/ami/simple_ml/...",
-        "//learning/pathways/...",
-        "//smartass/brain/configure/...",
-        "//tensorflow/...",
-        "//tensorflow_decision_forests/...",
-        "//tensorflow_federated/...",
-        "//third_party/cloud_tpu/convergence_tools/sdc_monitoring/...",
-        "//third_party/cloud_tpu/inference_converter/...",
-        "//third_party/py/cloud_ml_autoflow/...",
-        "//third_party/py/envlogger/...",
-        "//third_party/yggdrasil_decision_forests/...",
-    ] + if_google([
-        # Needed in OSS, where bazel won't allow a package group to refer to an
-        # external repo.
-        "@local_tsl//tsl/...",
-    ]),
-)
-
 bzl_library(
     name = "tsl_bzl",
     srcs = ["tsl.bzl"],
7 changes: 7 additions & 0 deletions third_party/xla/xla/tsl/package_groups.bzl
@@ -0,0 +1,7 @@
"""TSL package_group definitions"""

def tsl_package_groups(name = "tsl_package_groups"):
native.package_group(
name = "internal",
packages = ["//..."],
)
