[go: nahoru, domu]

Skip to content

Commit

Permalink
PR #13569: [GPU] Add on-disk per-kernel compilation cache.
Browse files Browse the repository at this point in the history
Imported from GitHub PR openxla/xla#13569

Copybara import of the project:

--
6c366cf77a749f10c9e23a52e4ee7b16359baaf8 by Ilia Sergachev <isergachev@nvidia.com>:

[GPU] Let constants be emitted into a separate LLVM module.

--
a164cd1214fecd70e6538c8995d0a7318264a439 by Ilia Sergachev <isergachev@nvidia.com>:

[GPU] Add on-disk per-kernel compilation cache.

--
7808613eb73dda37f15b69a82ca0617613cd98ce by Ilia Sergachev <isergachev@nvidia.com>:

Skip kernel cache tests when linking is not available.

--
89694a031c441605f45b36fc606a1b6e6a351492 by Ilia Sergachev <isergachev@nvidia.com>:

Address feedback

--
cf8ddb3f41413d4fb8c7440cf3b10d2100a89bf4 by Ilia Sergachev <isergachev@nvidia.com>:

Remove static keyword

Merging this change closes #13569

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#13569 from openxla:disk_kernel_cache cf8ddb3f41413d4fb8c7440cf3b10d2100a89bf4
PiperOrigin-RevId: 642202829
  • Loading branch information
sergachev authored and tensorflower-gardener committed Jun 11, 2024
1 parent 4222b3f commit a81da4e
Show file tree
Hide file tree
Showing 18 changed files with 630 additions and 112 deletions.
5 changes: 3 additions & 2 deletions tensorflow/cc/framework/fuzzing/cc_op_fuzz_gen_main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,13 @@ namespace {
void WriteAllFuzzers(string root_location, std::vector<string> api_def_dirs,
std::vector<string> op_names) {
OpList ops;
StatusOr<ApiDefMap> api_def_map = LoadOpsAndApiDefs(ops, false, api_def_dirs);
absl::StatusOr<ApiDefMap> api_def_map =
LoadOpsAndApiDefs(ops, false, api_def_dirs);

TF_CHECK_OK(api_def_map.status());

Env* env = Env::Default();
tsl::Status status;
absl::Status status;
std::unique_ptr<WritableFile> fuzz_file = nullptr;
for (const OpDef& op_def : ops.op()) {
if (std::find(op_names.begin(), op_names.end(), op_def.name()) ==
Expand Down
6 changes: 3 additions & 3 deletions third_party/xla/third_party/tsl/workspace2.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -378,10 +378,10 @@ def _tf_repositories():
tf_http_archive(
name = "zlib",
build_file = "//third_party:zlib.BUILD",
sha256 = "b3a24de97a8fdbc835b9833169501030b8977031bcb54b3b3ac13740f846ab30",
strip_prefix = "zlib-1.2.13",
sha256 = "9a93b2b7dfdac77ceba5a558a580e74667dd6fede4585b91eefb60f03b72df23",
strip_prefix = "zlib-1.3.1",
system_build_file = "//third_party/systemlibs:zlib.BUILD",
urls = tf_mirror_urls("https://zlib.net/fossils/zlib-1.2.13.tar.gz"),
urls = tf_mirror_urls("https://zlib.net/zlib-1.3.1.tar.gz"),
)

tf_http_archive(
Expand Down
6 changes: 3 additions & 3 deletions third_party/xla/xla/backends/profiler/gpu/rocm_collector.cc
Original file line number Diff line number Diff line change
Expand Up @@ -615,9 +615,10 @@ class RocmTraceCollectorImpl : public profiler::RocmTraceCollector {
absl::flat_hash_map<uint32_t, RocmTracerEvent> auxiliary_api_events_map_
TF_GUARDED_BY(event_maps_mutex_);

const std::vector<RocmTracerEvent> ApiActivityInfoExchange();
const std::vector<RocmTracerEvent> ApiActivityInfoExchange()
TF_EXCLUSIVE_LOCKS_REQUIRED(event_maps_mutex_);

absl::node_hash_map<uint32_t, PerDeviceCollector> per_device_collector_;
absl::flat_hash_map<uint32_t, PerDeviceCollector> per_device_collector_;
};
//==========

Expand Down Expand Up @@ -732,7 +733,6 @@ RocmTraceCollectorImpl::ApiActivityInfoExchange() {
std::vector<RocmTracerEvent> aggregated_events;

// Copy info from activity events to API callback events
mutex_lock lock{event_maps_mutex_};
for (auto& api_iter : api_events_map_) {
RocmTracerEvent& api_event = api_iter.second;
auto activity_event =
Expand Down
8 changes: 8 additions & 0 deletions third_party/xla/xla/debug_options_flags.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1715,6 +1715,14 @@ void MakeDebugOptionsFlags(std::vector<tsl::Flag>* flag_list,
"xla_use_shardonnay",
bool_setter_for(&DebugOptions::set_xla_use_shardonnay),
debug_options->xla_use_shardonnay(), "Whether to use Shardonnay."));
// Registers the "xla_gpu_kernel_cache_file" string flag, backed by the
// DebugOptions field of the same name.
flag_list->push_back(tsl::Flag(
    "xla_gpu_kernel_cache_file",
    string_setter_for(&DebugOptions::set_xla_gpu_kernel_cache_file),
    debug_options->xla_gpu_kernel_cache_file(),
    // Adjacent string literals are concatenated verbatim, so every piece
    // that is followed by another one must end with a space.
    "Path to a file to cache compiled kernels. If the file doesn't exist "
    "write the compilation cache of the first compiled HLO module into it. "
    "Once the file exists, further compilations will read it to reuse "
    "the kernels, but not write it. This behavior may change later."));
} // NOLINT(readability/fn_size)

// Allocates flag_values and flag_objects; this function must not be called more
Expand Down
18 changes: 18 additions & 0 deletions third_party/xla/xla/service/gpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -3138,6 +3138,7 @@ cc_library(
":metrics",
":runtime_intrinsics",
"//xla:shape_util",
"//xla:status_macros",
"//xla:util",
"//xla:xla_data_proto_cc",
"//xla/hlo/ir:hlo",
Expand All @@ -3159,6 +3160,7 @@ cc_library(
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@llvm-project//llvm:AsmParser",
"@llvm-project//llvm:TransformUtils",
"@llvm-project//llvm:ir_headers",
Expand All @@ -3169,7 +3171,9 @@ cc_library(
"@local_tsl//tsl/platform:env",
"@local_tsl//tsl/platform:errors",
"@local_tsl//tsl/platform:logging",
"@local_tsl//tsl/platform:path",
"@local_tsl//tsl/platform:statusor",
"@local_tsl//tsl/profiler/lib:scoped_annotation",
],
)

Expand Down Expand Up @@ -5635,19 +5639,33 @@ cc_library(
srcs = ["kernel_reuse_cache.cc"],
hdrs = ["kernel_reuse_cache.h"],
deps = [
":executable_proto_cc",
":kernel_arguments",
":launch_dimensions",
"//xla:status_macros",
"//xla:util",
"//xla/hlo/ir:hlo",
"//xla/stream_executor:launch_dim",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
"@local_tsl//tsl/platform:logging",
],
)

# Test target built from kernel_reuse_cache_test.cc; it links against the
# :kernel_reuse_cache library and the shared XLA test main.
xla_cc_test(
name = "kernel_reuse_cache_test",
srcs = ["kernel_reuse_cache_test.cc"],
deps = [
":kernel_reuse_cache",
"//xla/tests:xla_internal_test_main",
"@local_tsl//tsl/lib/core:status_test_util",
"@local_tsl//tsl/platform:test",
],
)

cc_library(
name = "kernel_arguments",
srcs = ["kernel_arguments.cc"],
Expand Down
1 change: 1 addition & 0 deletions third_party/xla/xla/service/gpu/autotuner_compile_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ AutotunerCompileUtil::AutotunerCompileUtil(const AutotuneConfig& config,
// Avoid using GPU graphs as we don't want to measure graph construction time.
opts_.clear_xla_gpu_enable_command_buffer();
opts_.set_xla_embed_ir_in_executable(false);
opts_.set_xla_gpu_kernel_cache_file("");
}

absl::StatusOr<std::optional<AutotunerCompileUtil::ProfilingOutput>>
Expand Down
51 changes: 50 additions & 1 deletion third_party/xla/xla/service/gpu/compile_module_to_llvm_ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ limitations under the License.
#include <stdlib.h>

#include <cstdint>
#include <functional>
#include <memory>
#include <optional>
#include <string>
Expand All @@ -28,6 +27,8 @@ limitations under the License.

#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/string_view.h"
#include "llvm/AsmParser/Parser.h"
#include "llvm/IR/DiagnosticInfo.h"
#include "llvm/IR/DiagnosticPrinter.h"
Expand Down Expand Up @@ -61,6 +62,7 @@ limitations under the License.
#include "xla/service/hlo_ordering.h"
#include "xla/service/logical_buffer.h"
#include "xla/shape.h"
#include "xla/status_macros.h"
#include "xla/stream_executor/device_description.h"
#include "xla/stream_executor/platform.h"
#include "xla/stream_executor/rocm/rocm_platform_id.h"
Expand All @@ -70,7 +72,9 @@ limitations under the License.
#include "tsl/platform/env.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/logging.h"
#include "tsl/platform/path.h"
#include "tsl/platform/statusor.h"
#include "tsl/profiler/lib/scoped_annotation.h"

namespace xla::gpu {

Expand Down Expand Up @@ -101,6 +105,35 @@ void RemoveUnusedAndUninitializedGlobals(
}
}

// Loads a serialized CompilationCacheProto from `cache_file_path` into the
// kernel cache of `ir_emitter_context`.
//
// Outcomes:
//   * FailedPrecondition if tsl::io::ResolveTestPrefixes() rejects the path.
//   * OK, leaving the context untouched, if the resolved file is not reported
//     as existing (this is only logged at VLOG level 1).
//   * The read error if the file exists but cannot be read.
//   * Internal if the file contents do not parse as CompilationCacheProto.
//   * A TF_RET_CHECK failure if a cached kernel name cannot be registered
//     unchanged in the context's name uniquer.
//   * Otherwise, the status returned by kernel_cache().Load().
static absl::Status LoadCache(IrEmitterContext& ir_emitter_context,
                              absl::string_view cache_file_path) {
  std::string resolved_path;
  if (!tsl::io::ResolveTestPrefixes(cache_file_path, resolved_path)) {
    return FailedPrecondition("File path can not be resolved: %s",
                              cache_file_path);
  }

  tsl::Env* const env = tsl::Env::Default();
  // Any non-OK FileExists() result is treated as "no cache file yet"; that is
  // not an error for the caller.
  if (!env->FileExists(resolved_path).ok()) {
    VLOG(1) << "Compilation cache file does not exist: " << resolved_path;
    return absl::OkStatus();
  }

  std::string serialized;
  TF_RETURN_IF_ERROR(tsl::ReadFileToString(env, resolved_path, &serialized));
  CompilationCacheProto proto;
  if (!proto.ParseFromString(serialized)) {
    return Internal("Failed to parse serialized CompilationCacheProto.");
  }

  // Register all cached kernel names with the name uniquer to avoid
  // naming conflicts. GetUniqueName() returning something other than the
  // requested name means that name was not available as-is.
  for (const auto& [name, _] : proto.entries()) {
    TF_RET_CHECK(ir_emitter_context.name_uniquer()->GetUniqueName(name) ==
                 name)
        << "Failed registering " << name << " in NameUniquer.";
  }
  return ir_emitter_context.kernel_cache().Load(proto);
}

} // namespace

absl::StatusOr<CompileModuleResults> CompileModuleToLlvmIr(
Expand All @@ -117,6 +150,14 @@ absl::StatusOr<CompileModuleResults> CompileModuleToLlvmIr(
results.llvm_module->setTargetTriple(target_triple);
results.llvm_module->setDataLayout(data_layout);

absl::string_view cache_file_path =
hlo_module->config().debug_options().xla_gpu_kernel_cache_file();
const bool use_cache =
!cache_file_path.empty() && split_constants_module &&
hlo_module->config()
.debug_options()
.xla_gpu_enable_llvm_module_compilation_parallelism();

if (split_constants_module) {
// Constants are emitted into a separate module to avoid caching them.
results.llvm_module_constants = std::make_unique<llvm::Module>(
Expand Down Expand Up @@ -191,6 +232,10 @@ absl::StatusOr<CompileModuleResults> CompileModuleToLlvmIr(
mlir_context.get(), results.llvm_module.get(),
results.llvm_module_constants.get(), /*emit_kernels=*/true);

if (use_cache) {
TF_RETURN_IF_ERROR(LoadCache(ir_emitter_context, cache_file_path));
}

std::vector<BufferAllocation*> allocations;
results.output_shape = hlo_module->result_shape();
TF_ASSIGN_OR_RETURN(results.output_info,
Expand Down Expand Up @@ -227,6 +272,10 @@ absl::StatusOr<CompileModuleResults> CompileModuleToLlvmIr(
}

results.executable = ir_emitter->ConsumeThunkSequence();
if (use_cache) {
results.kernel_compilation_cache =
ir_emitter_context.kernel_cache().Export();
}

return results;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ struct CompileModuleResults {
absl::flat_hash_map<ShapeIndex, GpuExecutable::OutputInfo> output_info;
Shape output_shape;
std::string module_name;
CompilationCacheProto kernel_compilation_cache;

// If true, the compiled module uses buffer allocations owned by
// buffer_assignment. Otherwise the compiled module uses buffer allocations
Expand Down
24 changes: 24 additions & 0 deletions third_party/xla/xla/service/gpu/executable.proto
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,27 @@ message CompilationResultProto {
bytes binary = 4;
map<string, string> dnn_compiled_graphs = 5;
}

// Launch dimensions of a kernel as stored in a compilation cache entry: a
// block count and a per-block thread count.
// NOTE(review): presumably mirrors the C++ launch-dimensions type from the
// :launch_dimensions target — confirm.
message LaunchDimensionsProto {
uint64 num_blocks = 1;
uint64 num_threads_per_block = 2;
}

// Three-component (x, y, z) cluster dimension of a kernel.
// NOTE(review): presumably the serialized form of the stream_executor cluster
// dimension type — confirm.
message ClusterDimProto {
uint64 x = 1;
uint64 y = 2;
uint64 z = 3;
}

// One cached kernel. Entries are stored in CompilationCacheProto keyed by
// kernel name.
message CompilationCacheEntryProto {
// NOTE(review): presumably the fingerprint the kernel reuse cache uses to
// decide whether two kernels are interchangeable — confirm.
string fingerprint = 1;
LaunchDimensionsProto launch_dimensions = 2;
// Explicitly optional, so presence can be told apart from an all-zero value.
optional ClusterDimProto cluster_dim = 3;
// NOTE(review): presumably the shared memory size of the kernel in bytes —
// confirm.
int64 shmem_bytes = 4;
// NOTE(review): presumably the compiled kernel binary — confirm.
bytes binary = 5;
}

// Top-level message of the kernel compilation cache: a collection of cache
// entries indexed by kernel name.
message CompilationCacheProto {
// Key is the kernel name.
map<string, CompilationCacheEntryProto> entries = 1;
}
Loading

0 comments on commit a81da4e

Please sign in to comment.