diff --git a/tensorflow/cc/framework/fuzzing/cc_op_fuzz_gen_main.cc b/tensorflow/cc/framework/fuzzing/cc_op_fuzz_gen_main.cc
index 8b2570fdc953a7..585baa08df864a 100644
--- a/tensorflow/cc/framework/fuzzing/cc_op_fuzz_gen_main.cc
+++ b/tensorflow/cc/framework/fuzzing/cc_op_fuzz_gen_main.cc
@@ -40,12 +40,13 @@ namespace {
 void WriteAllFuzzers(string root_location, std::vector<string> api_def_dirs,
                      std::vector<string> op_names) {
   OpList ops;
-  StatusOr<ApiDefMap> api_def_map = LoadOpsAndApiDefs(ops, false, api_def_dirs);
+  absl::StatusOr<ApiDefMap> api_def_map =
+      LoadOpsAndApiDefs(ops, false, api_def_dirs);
   TF_CHECK_OK(api_def_map.status());
   Env* env = Env::Default();
-  tsl::Status status;
+  absl::Status status;
   std::unique_ptr<WritableFile> fuzz_file = nullptr;
   for (const OpDef& op_def : ops.op()) {
     if (std::find(op_names.begin(), op_names.end(), op_def.name()) ==
diff --git a/third_party/xla/third_party/tsl/workspace2.bzl b/third_party/xla/third_party/tsl/workspace2.bzl
index 67011d3d2bb3ea..4bd2fd2e1c158c 100644
--- a/third_party/xla/third_party/tsl/workspace2.bzl
+++ b/third_party/xla/third_party/tsl/workspace2.bzl
@@ -378,10 +378,10 @@ def _tf_repositories():
     tf_http_archive(
         name = "zlib",
         build_file = "//third_party:zlib.BUILD",
-        sha256 = "b3a24de97a8fdbc835b9833169501030b8977031bcb54b3b3ac13740f846ab30",
-        strip_prefix = "zlib-1.2.13",
+        sha256 = "9a93b2b7dfdac77ceba5a558a580e74667dd6fede4585b91eefb60f03b72df23",
+        strip_prefix = "zlib-1.3.1",
         system_build_file = "//third_party/systemlibs:zlib.BUILD",
-        urls = tf_mirror_urls("https://zlib.net/fossils/zlib-1.2.13.tar.gz"),
+        urls = tf_mirror_urls("https://zlib.net/zlib-1.3.1.tar.gz"),
     )
 
     tf_http_archive(
diff --git a/third_party/xla/xla/backends/profiler/gpu/rocm_collector.cc b/third_party/xla/xla/backends/profiler/gpu/rocm_collector.cc
index eb6928404d5e92..e53b017195d717 100644
--- a/third_party/xla/xla/backends/profiler/gpu/rocm_collector.cc
+++ b/third_party/xla/xla/backends/profiler/gpu/rocm_collector.cc
@@ -615,9 +615,10 @@ class RocmTraceCollectorImpl : public profiler::RocmTraceCollector {
   absl::flat_hash_map<uint32_t, RocmTracerEvent> auxiliary_api_events_map_
       TF_GUARDED_BY(event_maps_mutex_);
 
-  const std::vector<RocmTracerEvent> ApiActivityInfoExchange();
+  const std::vector<RocmTracerEvent> ApiActivityInfoExchange()
+      TF_EXCLUSIVE_LOCKS_REQUIRED(event_maps_mutex_);
 
-  absl::node_hash_map<uint32_t, PerDeviceCollector> per_device_collector_;
+  absl::flat_hash_map<uint32_t, PerDeviceCollector> per_device_collector_;
 };
 
 //==========
@@ -732,7 +733,6 @@ RocmTraceCollectorImpl::ApiActivityInfoExchange() {
   std::vector<RocmTracerEvent> aggregated_events;
 
   // Copy info from activity events to API callback events
-  mutex_lock lock{event_maps_mutex_};
   for (auto& api_iter : api_events_map_) {
     RocmTracerEvent& api_event = api_iter.second;
     auto activity_event =
diff --git a/third_party/xla/xla/debug_options_flags.cc b/third_party/xla/xla/debug_options_flags.cc
index 88e3778d87ecdf..56c2e97ad0b5da 100644
--- a/third_party/xla/xla/debug_options_flags.cc
+++ b/third_party/xla/xla/debug_options_flags.cc
@@ -1715,6 +1715,14 @@ void MakeDebugOptionsFlags(std::vector<tsl::Flag>* flag_list,
   flag_list->push_back(tsl::Flag(
       "xla_use_shardonnay",
       bool_setter_for(&DebugOptions::set_xla_use_shardonnay),
      debug_options->xla_use_shardonnay(), "Whether to use Shardonnay."));
+  flag_list->push_back(tsl::Flag(
+      "xla_gpu_kernel_cache_file",
+      string_setter_for(&DebugOptions::set_xla_gpu_kernel_cache_file),
+      debug_options->xla_gpu_kernel_cache_file(),
+      "Path to a file to cache compiled kernels. If the file doesn't exist, "
+      "write the compilation cache of the first compiled HLO module into it. "
+      "Once the file exists, further compilations will read it to reuse "
+      "the kernels, but not write it. This behavior may change later."));
 }  // NOLINT(readability/fn_size)
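The new flag is read at compilation time, and (as the gating logic in compile_module_to_llvm_ir.cc further below shows) it only takes effect when parallel LLVM module compilation is also enabled. A minimal sketch of turning it on programmatically; the cache path is hypothetical and the setup is test-style, not part of this change:

    #include "xla/debug_options_flags.h"
    #include "xla/service/hlo_module_config.h"

    // Sketch: build an HloModuleConfig whose compilation uses the kernel cache.
    xla::HloModuleConfig MakeConfigWithKernelCache() {
      xla::DebugOptions opts = xla::GetDebugOptionsFromFlags();
      opts.set_xla_gpu_kernel_cache_file("/tmp/xla_kernel_cache");  // hypothetical path
      opts.set_xla_gpu_enable_llvm_module_compilation_parallelism(true);
      xla::HloModuleConfig config;
      config.set_debug_options(opts);
      return config;
    }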
+ "Once the file exists, further compilations will read it to reuse " + "the kernels, but not write it. This behavior may change later.")); } // NOLINT(readability/fn_size) // Allocates flag_values and flag_objects; this function must not be called more diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index b446758b6e2acc..afd3dbd818f254 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -3138,6 +3138,7 @@ cc_library( ":metrics", ":runtime_intrinsics", "//xla:shape_util", + "//xla:status_macros", "//xla:util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", @@ -3159,6 +3160,7 @@ cc_library( "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", "@llvm-project//llvm:AsmParser", "@llvm-project//llvm:TransformUtils", "@llvm-project//llvm:ir_headers", @@ -3169,7 +3171,9 @@ cc_library( "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:path", "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/profiler/lib:scoped_annotation", ], ) @@ -5635,12 +5639,15 @@ cc_library( srcs = ["kernel_reuse_cache.cc"], hdrs = ["kernel_reuse_cache.h"], deps = [ + ":executable_proto_cc", ":kernel_arguments", ":launch_dimensions", + "//xla:status_macros", "//xla:util", "//xla/hlo/ir:hlo", "//xla/stream_executor:launch_dim", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", @@ -5648,6 +5655,17 @@ cc_library( ], ) +xla_cc_test( + name = "kernel_reuse_cache_test", + srcs = ["kernel_reuse_cache_test.cc"], + deps = [ + ":kernel_reuse_cache", + "//xla/tests:xla_internal_test_main", + "@local_tsl//tsl/lib/core:status_test_util", + "@local_tsl//tsl/platform:test", + ], +) + cc_library( name = "kernel_arguments", srcs = ["kernel_arguments.cc"], diff --git a/third_party/xla/xla/service/gpu/autotuner_compile_util.cc b/third_party/xla/xla/service/gpu/autotuner_compile_util.cc index 89c8640fdc7391..b3e880b08ac01f 100644 --- a/third_party/xla/xla/service/gpu/autotuner_compile_util.cc +++ b/third_party/xla/xla/service/gpu/autotuner_compile_util.cc @@ -95,6 +95,7 @@ AutotunerCompileUtil::AutotunerCompileUtil(const AutotuneConfig& config, // Avoid using GPU graphs as we don't want to measure graph construction time. opts_.clear_xla_gpu_enable_command_buffer(); opts_.set_xla_embed_ir_in_executable(false); + opts_.set_xla_gpu_kernel_cache_file(""); } absl::StatusOr> diff --git a/third_party/xla/xla/service/gpu/compile_module_to_llvm_ir.cc b/third_party/xla/xla/service/gpu/compile_module_to_llvm_ir.cc index 510ad9ed95c3bb..9e63254a5c2776 100644 --- a/third_party/xla/xla/service/gpu/compile_module_to_llvm_ir.cc +++ b/third_party/xla/xla/service/gpu/compile_module_to_llvm_ir.cc @@ -18,7 +18,6 @@ limitations under the License. #include #include -#include #include #include #include @@ -28,6 +27,8 @@ limitations under the License. #include "absl/status/status.h" #include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" #include "llvm/AsmParser/Parser.h" #include "llvm/IR/DiagnosticInfo.h" #include "llvm/IR/DiagnosticPrinter.h" @@ -61,6 +62,7 @@ limitations under the License. 
#include "xla/service/hlo_ordering.h" #include "xla/service/logical_buffer.h" #include "xla/shape.h" +#include "xla/status_macros.h" #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/platform.h" #include "xla/stream_executor/rocm/rocm_platform_id.h" @@ -70,7 +72,9 @@ limitations under the License. #include "tsl/platform/env.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" +#include "tsl/platform/path.h" #include "tsl/platform/statusor.h" +#include "tsl/profiler/lib/scoped_annotation.h" namespace xla::gpu { @@ -101,6 +105,35 @@ void RemoveUnusedAndUninitializedGlobals( } } +static absl::Status LoadCache(IrEmitterContext& ir_emitter_context, + absl::string_view cache_file_path) { + std::string resolved_path; + if (!tsl::io::ResolveTestPrefixes(cache_file_path, resolved_path)) { + return FailedPrecondition("File path can not be resolved: %s", + cache_file_path); + } + if (tsl::Env::Default()->FileExists(resolved_path).ok()) { + std::string serialized; + TF_RETURN_IF_ERROR( + tsl::ReadFileToString(tsl::Env::Default(), resolved_path, &serialized)); + CompilationCacheProto proto; + if (!proto.ParseFromString(serialized)) { + return Internal("Failed to parse serialized CompilationCacheProto."); + } + // Register all cached kernel names with the name uniquer to avoid + // naming conflicts. + for (const auto& [name, _] : proto.entries()) { + TF_RET_CHECK(ir_emitter_context.name_uniquer()->GetUniqueName(name) == + name) + << "Failed registering " << name << "in NameUniquer."; + } + TF_RETURN_IF_ERROR(ir_emitter_context.kernel_cache().Load(proto)); + } else { + VLOG(1) << "Compilation cache file does not exist: " << resolved_path; + } + return absl::OkStatus(); +} + } // namespace absl::StatusOr CompileModuleToLlvmIr( @@ -117,6 +150,14 @@ absl::StatusOr CompileModuleToLlvmIr( results.llvm_module->setTargetTriple(target_triple); results.llvm_module->setDataLayout(data_layout); + absl::string_view cache_file_path = + hlo_module->config().debug_options().xla_gpu_kernel_cache_file(); + const bool use_cache = + !cache_file_path.empty() && split_constants_module && + hlo_module->config() + .debug_options() + .xla_gpu_enable_llvm_module_compilation_parallelism(); + if (split_constants_module) { // Constants are emitted into a separate module to avoid caching them. 
 
 absl::StatusOr<CompileModuleResults> CompileModuleToLlvmIr(
@@ -117,6 +150,14 @@ absl::StatusOr<CompileModuleResults> CompileModuleToLlvmIr(
   results.llvm_module->setTargetTriple(target_triple);
   results.llvm_module->setDataLayout(data_layout);
 
+  absl::string_view cache_file_path =
+      hlo_module->config().debug_options().xla_gpu_kernel_cache_file();
+  const bool use_cache =
+      !cache_file_path.empty() && split_constants_module &&
+      hlo_module->config()
+          .debug_options()
+          .xla_gpu_enable_llvm_module_compilation_parallelism();
+
   if (split_constants_module) {
     // Constants are emitted into a separate module to avoid caching them.
     results.llvm_module_constants = std::make_unique<llvm::Module>(
@@ -191,6 +232,10 @@ absl::StatusOr<CompileModuleResults> CompileModuleToLlvmIr(
       mlir_context.get(), results.llvm_module.get(),
       results.llvm_module_constants.get(), /*emit_kernels=*/true);
 
+  if (use_cache) {
+    TF_RETURN_IF_ERROR(LoadCache(ir_emitter_context, cache_file_path));
+  }
+
   std::vector<BufferAllocation> allocations;
   results.output_shape = hlo_module->result_shape();
   TF_ASSIGN_OR_RETURN(results.output_info,
@@ -227,6 +272,10 @@ absl::StatusOr<CompileModuleResults> CompileModuleToLlvmIr(
   }
 
   results.executable = ir_emitter->ConsumeThunkSequence();
+  if (use_cache) {
+    results.kernel_compilation_cache =
+        ir_emitter_context.kernel_cache().Export();
+  }
   return results;
 }
diff --git a/third_party/xla/xla/service/gpu/compile_module_to_llvm_ir.h b/third_party/xla/xla/service/gpu/compile_module_to_llvm_ir.h
index c76d301c212485..0fdef39f261df2 100644
--- a/third_party/xla/xla/service/gpu/compile_module_to_llvm_ir.h
+++ b/third_party/xla/xla/service/gpu/compile_module_to_llvm_ir.h
@@ -54,6 +54,7 @@ struct CompileModuleResults {
   absl::flat_hash_map<ShapeIndex, GpuExecutable::OutputInfo> output_info;
   Shape output_shape;
   std::string module_name;
+  CompilationCacheProto kernel_compilation_cache;
 
   // If true, the compiled module uses buffer allocations owned by
   // buffer_assignment. Otherwise the compiled module uses buffer allocations
diff --git a/third_party/xla/xla/service/gpu/executable.proto b/third_party/xla/xla/service/gpu/executable.proto
index f4c2a479bcc6ea..0c384beca953fc 100644
--- a/third_party/xla/xla/service/gpu/executable.proto
+++ b/third_party/xla/xla/service/gpu/executable.proto
@@ -27,3 +27,27 @@ message CompilationResultProto {
   bytes binary = 4;
   map<string, bytes> dnn_compiled_graphs = 5;
 }
+
+message LaunchDimensionsProto {
+  uint64 num_blocks = 1;
+  uint64 num_threads_per_block = 2;
+}
+
+message ClusterDimProto {
+  uint64 x = 1;
+  uint64 y = 2;
+  uint64 z = 3;
+}
+
+message CompilationCacheEntryProto {
+  string fingerprint = 1;
+  LaunchDimensionsProto launch_dimensions = 2;
+  optional ClusterDimProto cluster_dim = 3;
+  int64 shmem_bytes = 4;
+  bytes binary = 5;
+}
+
+message CompilationCacheProto {
+  // Key is the kernel name.
+  map<string, CompilationCacheEntryProto> entries = 1;
+}
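Since the cache file is just this serialized CompilationCacheProto, it can be inspected offline. A hypothetical debugging helper (not part of this change; error handling abbreviated):

    #include <string>

    #include "tsl/platform/env.h"
    #include "tsl/platform/logging.h"
    #include "tsl/platform/status.h"
    #include "xla/service/gpu/executable.pb.h"

    // Sketch: print each cached kernel's name, binary size and launch shape.
    void DumpKernelCache(const std::string& path) {
      std::string serialized;
      TF_CHECK_OK(tsl::ReadFileToString(tsl::Env::Default(), path, &serialized));
      xla::gpu::CompilationCacheProto cache;
      CHECK(cache.ParseFromString(serialized)) << "not a CompilationCacheProto";
      for (const auto& [kernel_name, entry] : cache.entries()) {
        LOG(INFO) << kernel_name << ": " << entry.binary().size() << " bytes, "
                  << entry.launch_dimensions().num_blocks() << " block(s) x "
                  << entry.launch_dimensions().num_threads_per_block()
                  << " thread(s)";
      }
    }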
diff --git a/third_party/xla/xla/service/gpu/gpu_compiler.cc b/third_party/xla/xla/service/gpu/gpu_compiler.cc
index 1dbf0bfb027264..e28445c46c1968 100644
--- a/third_party/xla/xla/service/gpu/gpu_compiler.cc
+++ b/third_party/xla/xla/service/gpu/gpu_compiler.cc
@@ -244,6 +244,7 @@ limitations under the License.
 #include "tsl/platform/errors.h"
 #include "tsl/platform/logging.h"
 #include "tsl/platform/numbers.h"
+#include "tsl/platform/path.h"
 #include "tsl/platform/protobuf.h"  // IWYU pragma: keep
 #include "tsl/platform/statusor.h"
 #include "tsl/platform/threadpool.h"
@@ -1769,37 +1770,44 @@ GpuCompiler::CompileSingleModule(const HloModuleConfig& module_config,
   return result;
 }
 
-absl::StatusOr<BackendCompileResult>
-GpuCompiler::CompileToTargetBinary(const HloModuleConfig& module_config,
-                                   CompileModuleResults& compile_module_results,
-                                   se::GpuComputeCapability gpu_version,
-                                   se::StreamExecutor* stream_exec,
-                                   const CompileOptions& options,
-                                   const HloModule* debug_module) {
-  MaybeOwningThreadPool thread_pool =
-      module_config.debug_options()
-              .xla_gpu_enable_llvm_module_compilation_parallelism()
-          ? CreateMaybeOwningThreadPool(
-                /*parallelism=*/module_config.debug_options()
-                    .xla_gpu_force_compilation_parallelism(),
-                /*default_thread_pool=*/options.thread_pool,
-                /*default_parallelism=*/1)
-          : MaybeOwningThreadPool(nullptr);
-  llvm::Module* llvm_module = &*compile_module_results.llvm_module;
-
-  // Test whether LinkModules is supported.
-  TF_ASSIGN_OR_RETURN(bool can_use_link_modules,
-                      CanUseLinkModules(module_config));
-
-  // Disable multi-threading during deviceless AOT compilation.
-  // TODO(anlunx): Enable multi-threading once deviceless AOT compilation is
-  // enabled.
-  if (!can_use_link_modules || !thread_pool.get() || !stream_exec) {
-    return CompileSingleModule(module_config, gpu_version, debug_module,
+namespace {
+int CountFunctions(const llvm::Module& module) {
+  int num_functions = 0;
+  for (const llvm::Function& func : module.functions()) {
+    if (!func.isDeclaration() &&
+        func.getLinkage() == llvm::GlobalValue::LinkageTypes::ExternalLinkage) {
+      ++num_functions;
+    }
+  }
+  return num_functions;
+}
 
-                               llvm_module, /*relocatable=*/false, options,
-                               /*shard_number=*/std::nullopt);
+// Returns the name of the single function in the module or an empty string if
+// it's not a single-function module.
+std::string SingleFunctionName(const llvm::Module& module) {
+  std::string name;
+  for (const llvm::Function& func : module.functions()) {
+    if (!func.isDeclaration() &&
+        func.getLinkage() == llvm::GlobalValue::LinkageTypes::ExternalLinkage) {
+      if (name.empty()) {
+        // First function in a module: name the module with it.
+        name = func.getName().str();
+      } else {
+        // Not the first function - the module is not cacheable.
+        return "";
+      }
+    }
   }
+  return name;
+}
+}  // namespace
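Both helpers look only at defined functions with external linkage; declarations and internal helpers do not affect cacheability. A standalone illustration using plain LLVM API (names are illustrative):

    #include "llvm/IR/IRBuilder.h"
    #include "llvm/IR/Module.h"

    // Sketch: a module becomes a "single-function module" only once its one
    // externally visible function has a body.
    void SingleFunctionModuleDemo() {
      llvm::LLVMContext ctx;
      llvm::Module m("demo", ctx);
      llvm::Function* f = llvm::Function::Create(
          llvm::FunctionType::get(llvm::Type::getVoidTy(ctx), /*isVarArg=*/false),
          llvm::GlobalValue::ExternalLinkage, "kernel_0", m);
      // Declaration only: CountFunctions(m) == 0, SingleFunctionName(m) == "".
      llvm::IRBuilder<> b(llvm::BasicBlock::Create(ctx, "entry", f));
      b.CreateRetVoid();
      // Definition: CountFunctions(m) == 1, SingleFunctionName(m) == "kernel_0".
    }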
 
+absl::StatusOr<BackendCompileResult> GpuCompiler::CompileAndLink(
+    const HloModuleConfig& module_config,
+    CompileModuleResults& compile_module_results,
+    se::GpuComputeCapability gpu_version, se::StreamExecutor* stream_exec,
+    const CompileOptions& options, const HloModule* debug_module) {
+  llvm::Module* llvm_module = &*compile_module_results.llvm_module;
 
   bool force_module_split =
       module_config.debug_options().xla_llvm_force_inline_before_split();
@@ -1826,15 +1834,6 @@ GpuCompiler::CompileToTargetBinary(const HloModuleConfig& module_config,
     }
   }
 
-  std::vector<std::unique_ptr<llvm::Module>> llvm_modules;
-  int num_functions = 0;
-  for (llvm::Function& func : llvm_module->functions()) {
-    if (!func.isDeclaration() &&
-        func.getLinkage() == llvm::GlobalValue::LinkageTypes::ExternalLinkage) {
-      num_functions++;
-    }
-  }
-
   // Record the name of some constant global variables and their initializers.
   // We'll change the linkage type of these variables from external to internal
   // to ensure constant-folding works properly after calling llvm::SplitModule.
@@ -1861,14 +1860,42 @@ GpuCompiler::CompileToTargetBinary(const HloModuleConfig& module_config,
     }
   }
 
+  llvm_ir::DumpIrIfEnabled(*debug_module, *llvm_module,
+                           /*optimized=*/false, "inlined");
+
+  absl::string_view cache_path =
+      module_config.debug_options().xla_gpu_kernel_cache_file();
+  const bool use_cache = !cache_path.empty();
+
+  struct NamedModule {
+    // The string is the function name for single-function modules (used to
+    // cache them), empty for all other modules.
+    std::string name;
+    std::unique_ptr<llvm::Module> module;
+  };
+  std::vector<NamedModule> llvm_modules;
+  MaybeOwningThreadPool thread_pool = CreateMaybeOwningThreadPool(
+      /*parallelism=*/module_config.debug_options()
+          .xla_gpu_force_compilation_parallelism(),
+      /*default_thread_pool=*/options.thread_pool,
+      /*default_parallelism=*/1);
+  // Only single-function modules are cacheable -> for caching try to get 1
+  // function per module. If caching is not used, limit the number of modules
+  // to the number of threads.
+  int num_modules = CountFunctions(*llvm_module);
+  if (thread_pool.get() != nullptr && !use_cache) {
+    num_modules = std::max(1, std::min(thread_pool->NumThreads(), num_modules));
+  }
   if (compile_module_results.llvm_module_constants != nullptr) {
+    llvm_modules.reserve(num_modules + 1);
     llvm_modules.push_back(
-        std::move(compile_module_results.llvm_module_constants));
+        {"", std::move(compile_module_results.llvm_module_constants)});
+  } else {
+    llvm_modules.reserve(num_modules);
   }
+
+  int single_function_module_count = 0;
   llvm::SplitModule(
-      *llvm_module,
-      std::max(
-          1, std::min(thread_pool->NumThreads(), num_functions)),
+      *llvm_module, num_modules,
       [&](std::unique_ptr<llvm::Module> module) {
         // Change the linkage type of some global constant variables to internal
         for (llvm::GlobalVariable& gv : module->globals()) {
@@ -1878,45 +1905,132 @@ GpuCompiler::CompileToTargetBinary(const HloModuleConfig& module_config,
             gv.setLinkage(llvm::GlobalValue::InternalLinkage);
           }
         }
-        llvm_modules.push_back(std::move(module));
+        const std::string name = SingleFunctionName(*module);
+        if (!name.empty()) {
+          ++single_function_module_count;
+        }
+        llvm_modules.push_back({name, std::move(module)});
       },
       /*PreserveLocals=*/true);
+  VLOG(2) << "Single-function cacheable modules: "
+          << single_function_module_count << " / " << llvm_modules.size();
 
-  std::vector<absl::StatusOr<BackendCompileResult>> compile_results(
-      llvm_modules.size());
-  tsl::BlockingCounter counter(llvm_modules.size());
-  for (int i = 0; i < llvm_modules.size(); i++) {
-    thread_pool.get_mutable()->Schedule(
-        [&compile_results, i, &llvm_modules, &counter, this, &module_config,
-         &gpu_version, &debug_module, &options] {
-          // Each thread has its own context to avoid race conditions.
-          llvm::LLVMContext new_context;
-          std::unique_ptr<llvm::Module> new_module =
-              CopyToContext(*llvm_modules.at(i), new_context);
-          compile_results.at(i) = CompileSingleModule(
-              module_config, gpu_version, debug_module, new_module.get(),
-              /*relocatable=*/true, options,
-              /*shard_number=*/i);
-          counter.DecrementCount();
-        });
-  }
-  counter.Wait();
+  struct NamedCompileResult {
+    // Single function name or empty just like for llvm_modules.
+    std::string name;
+    absl::StatusOr<BackendCompileResult> result;
+  };
+  std::vector<NamedCompileResult> compile_results(llvm_modules.size());
+  if (thread_pool.get() != nullptr) {
+    tsl::BlockingCounter counter(llvm_modules.size());
+    for (int i = 0; i < llvm_modules.size(); ++i) {
+      thread_pool.get_mutable()->Schedule(
+          [&compile_results, i, &llvm_modules, &counter, this, &module_config,
+           &gpu_version, &debug_module, &options] {
+            // Each thread has its own context to avoid race conditions.
+            llvm::LLVMContext new_context;
+            std::unique_ptr<llvm::Module> new_module =
+                CopyToContext(*llvm_modules.at(i).module, new_context);
+            compile_results.at(i) = {
+                llvm_modules.at(i).name,
+                CompileSingleModule(module_config, gpu_version, debug_module,
+                                    new_module.get(),
+                                    /*relocatable=*/true, options,
+                                    /*shard_number=*/i)};
+            counter.DecrementCount();
+          });
+    }
+    counter.Wait();
+  } else {
+    for (int i = 0; i < llvm_modules.size(); ++i) {
+      compile_results.at(i) = {
+          llvm_modules.at(i).name,
+          CompileSingleModule(module_config, gpu_version, debug_module,
+                              &*llvm_modules.at(i).module,
+                              /*relocatable=*/true, options,
+                              /*shard_number=*/i)};
+    }
+  }
 
   std::string ptx_snippets;
-  std::vector<std::vector<uint8_t>> submodule_compile_results;
-  for (auto& maybe_result : compile_results) {
+  std::vector<std::vector<uint8_t>> binaries_to_link;
+  binaries_to_link.reserve(compile_results.size());
+  struct NamedBinary {
+    // The string is the function name or empty just like for llvm_modules.
+    std::string name;
+    std::vector<uint8_t> binary;
+  };
+  std::vector<NamedBinary> binaries_to_cache;
+  binaries_to_cache.reserve(single_function_module_count);
+  for (const auto& [name, maybe_result] : compile_results) {
     TF_ASSIGN_OR_RETURN(auto result, maybe_result);
     if (result.binary.empty()) {
       continue;
     }
     ptx_snippets += result.asm_text;
     ptx_snippets += "\n";
-    submodule_compile_results.push_back(result.binary);
+    binaries_to_link.push_back(result.binary);
+    if (!name.empty()) {
+      binaries_to_cache.push_back({name, result.binary});
+    }
+  }
+
+  if (use_cache) {
+    std::string resolved_path;
+    if (!tsl::io::ResolveTestPrefixes(cache_path, resolved_path)) {
+      return FailedPrecondition("File path cannot be resolved: %s",
+                                cache_path);
+    }
+    CompilationCacheProto& cache =
+        compile_module_results.kernel_compilation_cache;
+    if (tsl::Env::Default()->FileExists(resolved_path).ok()) {
+      int loaded_kernel_count = 0;
+      for (const auto& [name, entry] : cache.entries()) {
+        if (llvm_module->getFunction(name)) {
+          VLOG(5)
+              << "Skipping cached " << name
+              << " in favor of the just compiled kernel with the same name.";
+          CHECK(entry.binary().empty());
+          continue;
+        }
+        const uint8_t* binary =
+            reinterpret_cast<const uint8_t*>(entry.binary().data());
+        binaries_to_link.push_back(
+            std::vector<uint8_t>(binary, binary + entry.binary().size()));
+        VLOG(5) << "Loaded " << name << ": " << entry.binary().size();
+        ++loaded_kernel_count;
+      }
+      VLOG(2) << "Loaded " << loaded_kernel_count << " / "
+              << cache.entries_size() << " cached kernels.";
+    } else {
+      auto entries = cache.mutable_entries();
+      for (const auto& [name, binary] : binaries_to_cache) {
+        auto it = entries->find(name);
+        if (it == entries->end()) {
+          continue;
+        }
+        it->second.set_binary(reinterpret_cast<const char*>(binary.data()),
+                              binary.size());
+        VLOG(5) << "Cached kernels: " << name << ": " << binary.size();
+      }
+      for (auto it = entries->begin(); it != entries->end();) {
+        if (it->second.binary().empty()) {
+          it = entries->erase(it);
+        } else {
+          ++it;
+        }
+      }
+      if (cache.entries_size() > 0) {
+        TF_RETURN_IF_ERROR(tsl::WriteStringToFile(
+            tsl::Env::Default(), resolved_path, cache.SerializeAsString()));
+        VLOG(2) << "Stored " << cache.entries_size() << " / "
+                << binaries_to_cache.size();
+      }
+    }
+  }
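The write path above joins two sources keyed by kernel name: the proto exported from KernelReuseCache during IR emission (metadata, binaries still empty) and the binaries just produced by CompileSingleModule; entries that never receive a binary are dropped so they are never linked from the file. The same join, distilled into a hypothetical standalone helper (generated proto API only):

    #include <cstdint>
    #include <string>
    #include <utility>
    #include <vector>

    #include "xla/service/gpu/executable.pb.h"

    // Sketch: fill in binaries for kernels compiled in this process, then
    // prune entries that stayed empty.
    void FillAndPrune(
        const std::vector<std::pair<std::string, std::vector<uint8_t>>>& binaries,
        xla::gpu::CompilationCacheProto& cache) {
      auto* entries = cache.mutable_entries();
      for (const auto& [name, binary] : binaries) {
        auto it = entries->find(name);
        if (it != entries->end()) {
          it->second.set_binary(binary.data(), binary.size());
        }
      }
      for (auto it = entries->begin(); it != entries->end();) {
        if (it->second.binary().empty()) {
          it = entries->erase(it);
        } else {
          ++it;
        }
      }
    }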
 
-  auto maybe_backend_result =
-      this->LinkModules(stream_exec, std::move(submodule_compile_results),
-                        module_config.debug_options());
+  auto maybe_backend_result = LinkModules(
+      stream_exec, std::move(binaries_to_link), module_config.debug_options());
   if (!maybe_backend_result.ok()) {
     LOG(ERROR) << "The CUDA linking API did not work. Please use XLA_FLAGS="
                   "--xla_gpu_enable_llvm_module_compilation_parallelism=false "
@@ -1925,6 +2039,7 @@
                << maybe_backend_result.status();
     return maybe_backend_result.status();
   }
+  VLOG(4) << "Binary size after linking [B]: " << maybe_backend_result->size();
   return BackendCompileResult{ptx_snippets, std::move(*maybe_backend_result)};
 }
 
@@ -1942,12 +2057,28 @@ GpuCompiler::CompileToBackendResult(
   TF_ASSIGN_OR_RETURN(se::Platform * platform,
                       se::PlatformManager::PlatformWithId(PlatformId()));
 
+  // Test whether LinkModules is supported.
+  bool can_use_link_modules = (executor != nullptr);
+  if (can_use_link_modules) {
+    TF_ASSIGN_OR_RETURN(can_use_link_modules,
+                        CanUseLinkModules(module->config()));
+  }
+  const bool split_modules =
+      can_use_link_modules &&
+      module->config()
+          .debug_options()
+          .xla_gpu_enable_llvm_module_compilation_parallelism();
+  const bool use_cache =
+      split_modules &&
+      !module->config().debug_options().xla_gpu_kernel_cache_file().empty();
+
   // Compile the module
   TF_ASSIGN_OR_RETURN(
       CompileModuleResults compile_module_results,
       CompileModuleToLlvmIr(module, llvm_context, target_triple_, data_layout_,
                             platform->Name(), platform->id(), gpu_device_info,
-                            GetCanShareBuffer(), BufferSizeBytesFunction()));
+                            GetCanShareBuffer(), BufferSizeBytesFunction(),
+                            /*split_constants_module=*/use_cache));
 
   if (user_pre_optimization_hook_) {
     user_pre_optimization_hook_(*compile_module_results.llvm_module);
@@ -1959,18 +2090,31 @@ GpuCompiler::CompileToBackendResult(
   llvm_ir::DumpIrIfEnabled(*module, *compile_module_results.llvm_module,
                            /*optimized=*/false);
-
   if (compile_module_results.llvm_module_constants != nullptr) {
     llvm_ir::DumpIrIfEnabled(*module,
                              *compile_module_results.llvm_module_constants,
                              /*optimized=*/false, "constants");
   }
 
-  TF_ASSIGN_OR_RETURN(
-      BackendCompileResult backend_result,
-      CompileToTargetBinary(module->config(), compile_module_results,
-                            gpu_device_info.gpu_compute_capability(), executor,
-                            options, module));
+  BackendCompileResult backend_result;
+  // Disable multi-threading during deviceless AOT compilation.
+  // TODO(anlunx): Enable multi-threading once deviceless AOT compilation is
+  // enabled.
+  if (split_modules) {
+    TF_ASSIGN_OR_RETURN(backend_result,
+                        CompileAndLink(module->config(), compile_module_results,
+                                       gpu_device_info.gpu_compute_capability(),
+                                       executor, options, module));
+  } else {
+    CHECK(compile_module_results.llvm_module_constants == nullptr);
+    TF_ASSIGN_OR_RETURN(
+        backend_result,
+        CompileSingleModule(module->config(),
+                            gpu_device_info.gpu_compute_capability(), module,
+                            &*compile_module_results.llvm_module,
+                            /*relocatable=*/false, options,
+                            /*shard_number=*/std::nullopt));
+  }
   RecordXlaDeviceBinarySize(backend_result.binary.size());
   if (DumpingEnabledForHloModule(*module)) {
     DumpToFileInDirOrStdout(*module, "", "thunk_sequence.txt",
@@ -2086,7 +2230,8 @@ absl::StatusOr<std::unique_ptr<Executable>> GpuCompiler::RunBackend(
           : std::move(res.compile_module_results.allocations)),
       /*buffer_assignment=*/
       std::move(res.compile_module_results.buffer_assignment),
-      /*debug_buffer_assignment_show_max=*/debug_buffer_assignment_show_max,
+      /*debug_buffer_assignment_show_max=*/
+      debug_buffer_assignment_show_max,
       /*debug_module=*/options.is_autotuning_compilation
           ? std::unique_ptr<HloModule>()
           : std::move(module),
@@ -2282,13 +2427,13 @@ absl::Status GpuCompiler::RunPostSchedulingPipelines(
     HloPassPipeline pipeline("fusion-wrapper");
     pipeline.AddPass<FusionWrapper>();
     // Wrap remaining unfused ops that have no LHLO equivalent in single-op
-    // fusions. This needs to happen after rematerialization, because that will
-    // insert additional copies.
+    // fusions. This needs to happen after rematerialization, because that
+    // will insert additional copies.
     TF_RETURN_IF_ERROR(pipeline.Run(module).status());
   }
 
-  // After we have a scheduled module and all operations wrapped into fusions we
-  // can decide how to wrap them into command buffers.
+  // After we have a scheduled module and all operations wrapped into fusions
+  // we can decide how to wrap them into command buffers.
   {
     HloPassPipeline pipeline("command-buffer-scheduling");
     auto driver_version = se::gpu::GpuDriver::GetDriverVersion();
@@ -2325,8 +2470,8 @@ absl::Status GpuCompiler::SerializeAutotuneResultsToFile(
   if (absl::string_view file_path =
           debug_options.xla_gpu_dump_autotune_results_to();
       !file_path.empty()) {
-    // Warning: This writes the autotune results at every compilation, possibly
-    // multiple times per process.
+    // Warning: This writes the autotune results at every compilation,
+    // possibly multiple times per process.
     TF_RETURN_IF_ERROR(
         AutotunerUtil::SerializeAutotuneResultsToFile(file_path));
   }
diff --git a/third_party/xla/xla/service/gpu/gpu_compiler.h b/third_party/xla/xla/service/gpu/gpu_compiler.h
index 37671ff9011bac..48d65c4602f9f1 100644
--- a/third_party/xla/xla/service/gpu/gpu_compiler.h
+++ b/third_party/xla/xla/service/gpu/gpu_compiler.h
@@ -117,6 +117,11 @@ class GpuCompiler : public LLVMCompiler {
 
   virtual int32_t GetToolkitVersion() const = 0;
 
+  virtual absl::StatusOr<bool> CanUseLinkModules(
+      const HloModuleConfig& config) {
+    return false;
+  }
+
  protected:
   struct BackendCompileResult {
     std::string asm_text;
@@ -186,7 +191,7 @@ class GpuCompiler : public LLVMCompiler {
       se::StreamExecutor* executor, const CompileOptions& options,
       const se::DeviceDescription& gpu_device_info);
 
-  absl::StatusOr<BackendCompileResult> CompileToTargetBinary(
+  absl::StatusOr<BackendCompileResult> CompileAndLink(
       const HloModuleConfig& module_config,
       CompileModuleResults& compile_module_results,
       se::GpuComputeCapability gpu_version, se::StreamExecutor* stream_exec,
@@ -227,11 +232,6 @@ class GpuCompiler : public LLVMCompiler {
 
   absl::Status PrepareHloModuleForIrEmitting(HloModule* hlo_module);
 
-  virtual absl::StatusOr<bool> CanUseLinkModules(
-      const HloModuleConfig& config) {
-    return false;
-  }
-
   virtual absl::StatusOr<std::vector<uint8_t>> LinkModules(
       se::StreamExecutor* stream_exec,
       std::vector<std::vector<uint8_t>> modules,
diff --git a/third_party/xla/xla/service/gpu/gpu_compiler_test.cc b/third_party/xla/xla/service/gpu/gpu_compiler_test.cc
index d725e1b1e38636..1e4b66a7d6dad6 100644
--- a/third_party/xla/xla/service/gpu/gpu_compiler_test.cc
+++ b/third_party/xla/xla/service/gpu/gpu_compiler_test.cc
@@ -533,6 +533,119 @@ CHECK-SAME: frontend_attributes={_xla_send_recv_pipeline="0"}
   EXPECT_TRUE(filecheck_matched);
 }
 
+class KernelCacheTest : public HloTestBase {
+ public:
+  void SetUp() override {
+    HloModuleConfig config;
+    config.set_debug_options(GetDebugOptionsForTest());
+    TF_ASSERT_OK_AND_ASSIGN(bool can_use_link_modules,
+                            dynamic_cast<GpuCompiler*>(backend().compiler())
+                                ->CanUseLinkModules(config));
+    if (!can_use_link_modules) {
+      GTEST_SKIP() << "Caching compiled kernels requires support of linking.";
+    }
+  }
+  DebugOptions GetDebugOptionsForTest() override {
+    CHECK(tsl::Env::Default()->LocalTempFilename(&cache_file_name_));
+    DebugOptions debug_options = HloTestBase::GetDebugOptionsForTest();
+    debug_options.set_xla_gpu_kernel_cache_file(cache_file_name_);
+    debug_options.set_xla_gpu_enable_llvm_module_compilation_parallelism(true);
+    return debug_options;
+  }
+
+  bool CacheFileExists() {
+    if (!tsl::Env::Default()->FileExists(cache_file_name_).ok()) {
+      return false;
+    }
+    return true;
+  }
+
+  bool NonEmptyCacheExists() {
+    if (!CacheFileExists()) {
+      return false;
+    }
+    std::string serialized;
+    TF_EXPECT_OK(tsl::ReadFileToString(tsl::Env::Default(), cache_file_name_,
+                                       &serialized));
+    CompilationCacheProto proto;
+    EXPECT_TRUE(proto.ParseFromString(serialized));
+    return proto.entries_size() > 0;
+  }
+
+  std::string cache_file_name_;
+  static constexpr absl::string_view kHloText = R"(
+  ENTRY e {
+    p = s8[] parameter(0)
+    c = s8[] constant(8)
+    ROOT _ = s8[] add(p, c)
+  })";
+};
+
+TEST_F(KernelCacheTest, CacheIsGenerated) {
+  // First run - no cache file
+  EXPECT_FALSE(CacheFileExists());
+  EXPECT_TRUE(Run(kHloText, /*run_hlo_passes=*/false));
+  // First run generates a cache
+  EXPECT_TRUE(NonEmptyCacheExists());
+  // Second run - with cache file
+  EXPECT_TRUE(Run(kHloText, /*run_hlo_passes=*/false));
+}
+
+TEST_F(KernelCacheTest, NoCacheIsGeneratedWithoutCompiledKernels) {
+  EXPECT_FALSE(CacheFileExists());
+  EXPECT_TRUE(Run(R"(
+  ENTRY e {
+    a = f32[5,5] parameter(0)
+    ROOT _ = f32[5,5] custom-call(a, a), custom_call_target="__cublas$gemm",
+      backend_config="{ \"gemm_backend_config\": {\"alpha_real\":1,\"beta\":0,\"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]},\"alpha_imag\":0,\"precision_config\":{\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]},\"epilogue\":\"DEFAULT\"}}"
+  })",
+                  /*run_hlo_passes=*/false));
+  EXPECT_FALSE(CacheFileExists());
+}
+
+TEST_F(KernelCacheTest, UsingCacheFromAnotherModuleDoesNotFail) {
+  EXPECT_FALSE(CacheFileExists());
+  EXPECT_TRUE(Run(kHloText, /*run_hlo_passes=*/false));
+  EXPECT_TRUE(NonEmptyCacheExists());
+  // Second run - with cache file and another HLO
+  EXPECT_TRUE(Run(R"(
+  ENTRY e {
+    p = s8[] parameter(0)
+    ROOT _ = s8[] multiply(p, p)
+  })",
+                  /*run_hlo_passes=*/false));
+}
+
+class KernelCacheTestSingleThreaded : public KernelCacheTest {
+ public:
+  DebugOptions GetDebugOptionsForTest() override {
+    DebugOptions debug_options = KernelCacheTest::GetDebugOptionsForTest();
+    debug_options.set_xla_gpu_force_compilation_parallelism(1);
+    return debug_options;
+  }
+};
+
+TEST_F(KernelCacheTestSingleThreaded, CacheIsGenerated) {
+  EXPECT_FALSE(CacheFileExists());
+  EXPECT_TRUE(Run(kHloText, /*run_hlo_passes=*/false));
+  EXPECT_TRUE(NonEmptyCacheExists());
+  EXPECT_TRUE(Run(kHloText, /*run_hlo_passes=*/false));
+}
+
+class NoKernelCacheTest : public KernelCacheTest {
+ public:
+  DebugOptions GetDebugOptionsForTest() override {
+    DebugOptions debug_options = KernelCacheTest::GetDebugOptionsForTest();
+    debug_options.set_xla_gpu_enable_llvm_module_compilation_parallelism(false);
+    return debug_options;
+  }
+};
+
+TEST_F(NoKernelCacheTest, NoCacheWithoutCompilationParallelism) {
+  EXPECT_TRUE(Run(kHloText, /*run_hlo_passes=*/false));
+  EXPECT_FALSE(NonEmptyCacheExists());
+}
+
 }  // namespace
 }  // namespace gpu
 }  // namespace xla
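NonEmptyCacheExists() only asserts that some entry was stored. A test wanting stronger guarantees could parse the file directly; a possible extra assertion for the fixture above (sketch):

    // Sketch: every stored entry must carry a non-empty binary, because empty
    // entries are pruned before the cache file is written.
    std::string serialized;
    TF_ASSERT_OK(tsl::ReadFileToString(tsl::Env::Default(), cache_file_name_,
                                       &serialized));
    CompilationCacheProto proto;
    ASSERT_TRUE(proto.ParseFromString(serialized));
    for (const auto& [kernel_name, entry] : proto.entries()) {
      EXPECT_FALSE(entry.binary().empty()) << kernel_name;
    }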
diff --git a/third_party/xla/xla/service/gpu/kernel_reuse_cache.cc b/third_party/xla/xla/service/gpu/kernel_reuse_cache.cc
index 2d4a9d3356b672..3c2a53ac84d1a8 100644
--- a/third_party/xla/xla/service/gpu/kernel_reuse_cache.cc
+++ b/third_party/xla/xla/service/gpu/kernel_reuse_cache.cc
@@ -15,10 +15,12 @@ limitations under the License.
 #include "xla/service/gpu/kernel_reuse_cache.h"
 
 #include 
+#include 
 #include 
 #include 
 
 #include "absl/container/flat_hash_map.h"
+#include "absl/status/status.h"
 #include "absl/status/statusor.h"
 #include "absl/strings/str_cat.h"
 #include "absl/strings/str_join.h"
@@ -27,6 +29,9 @@ limitations under the License.
 #include "xla/hlo/ir/hlo_computation.h"
 #include "xla/hlo/ir/hlo_instruction.h"
 #include "xla/service/gpu/kernel_arguments.h"
+#include "xla/service/gpu/launch_dimensions.h"
+#include "xla/status_macros.h"
+#include "xla/stream_executor/launch_dim.h"
 #include "xla/util.h"
 #include "tsl/platform/logging.h"
 
@@ -84,6 +89,60 @@ std::string GetComputationFingerprint(
                       fused_computation->ToString(print_options));
 }
 
+absl::Status KernelReuseCache::Load(const CompilationCacheProto& proto) {
+  for (const auto& [name, entry] : proto.entries()) {
+    std::optional<se::ClusterDim> cluster_dim;
+    if (entry.has_cluster_dim()) {
+      cluster_dim =
+          se::ClusterDim{entry.cluster_dim().x(), entry.cluster_dim().y(),
+                         entry.cluster_dim().z()};
+    }
+    TF_RET_CHECK(
+        cache_
+            .insert(
+                {entry.fingerprint(),
+                 Entry{name,
+                       LaunchDimensions{
+                           entry.launch_dimensions().num_blocks(),
+                           entry.launch_dimensions().num_threads_per_block()},
+                       cluster_dim, entry.shmem_bytes(), entry.binary()}})
+            .second);
+  }
+
+  return absl::OkStatus();
+}
+
+CompilationCacheProto KernelReuseCache::Export() const {
+  CompilationCacheProto proto;
+  for (const auto& [fingerprint, cache_entry] : cache_) {
+    if (!hits_.contains(fingerprint)) {
+      VLOG(5) << "Not exporting unused " << cache_entry.kernel_name;
+      continue;
+    }
+    auto [it, inserted] = proto.mutable_entries()->emplace(
+        cache_entry.kernel_name, CompilationCacheEntryProto{});
+    CHECK(inserted) << cache_entry.kernel_name;
+    CompilationCacheEntryProto& proto_entry = it->second;
+    proto_entry.set_fingerprint(fingerprint);
+    LaunchDimensionsProto launch_dimensions_proto;
+    launch_dimensions_proto.set_num_blocks(
+        cache_entry.launch_dimensions.num_blocks());
+    launch_dimensions_proto.set_num_threads_per_block(
+        cache_entry.launch_dimensions.num_threads_per_block());
+    *proto_entry.mutable_launch_dimensions() = launch_dimensions_proto;
+    if (cache_entry.cluster_dim.has_value()) {
+      ClusterDimProto cluster_dim_proto;
+      cluster_dim_proto.set_x(cache_entry.cluster_dim->x);
+      cluster_dim_proto.set_y(cache_entry.cluster_dim->y);
+      cluster_dim_proto.set_z(cache_entry.cluster_dim->z);
+      *proto_entry.mutable_cluster_dim() = cluster_dim_proto;
+    }
+    proto_entry.set_shmem_bytes(cache_entry.shmem_bytes);
+    proto_entry.set_binary(cache_entry.binary);
+  }
+  return proto;
+}
+
 std::pair<absl::StatusOr<const KernelReuseCache::Entry*>, bool>
 KernelReuseCache::GetWithStatus(
     const HloComputation* fused_computation,
@@ -101,6 +160,7 @@ std::pair<absl::StatusOr<const KernelReuseCache::Entry*>, bool>
 KernelReuseCache::GetWithStatus(
     std::string fingerprint,
     const std::function<absl::StatusOr<KernelReuseCache::Entry>()>& generator) {
+  hits_.insert(fingerprint);
   auto it = cache_.find(fingerprint);
   if (it != cache_.end()) {
     return {&it->second, /*was_cached=*/true};
diff --git a/third_party/xla/xla/service/gpu/kernel_reuse_cache.h b/third_party/xla/xla/service/gpu/kernel_reuse_cache.h
index ea55d97dc43989..a66a5fac70dd50 100644
--- a/third_party/xla/xla/service/gpu/kernel_reuse_cache.h
+++ b/third_party/xla/xla/service/gpu/kernel_reuse_cache.h
@@ -26,6 +26,7 @@ limitations under the License.
#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/hlo/ir/hlo_computation.h" +#include "xla/service/gpu/executable.pb.h" #include "xla/service/gpu/kernel_arguments.h" #include "xla/service/gpu/launch_dimensions.h" #include "xla/stream_executor/launch_dim.h" @@ -42,8 +43,19 @@ class KernelReuseCache { LaunchDimensions launch_dimensions; std::optional cluster_dim; int64_t shmem_bytes = 0; + std::string binary; }; + absl::Status Load(const CompilationCacheProto& proto); + // Exporting skips kernels that were loaded but not used during emission. + // See comment for hits_ below. + CompilationCacheProto Export() const; + bool IsEmpty() const { return cache_.empty(); } + void Clear() { + cache_.clear(); + hits_.clear(); + } + // Retrieves the cache entry for the given computation, or generates it using // the given generator function and stores it in the cache. // @@ -70,6 +82,10 @@ class KernelReuseCache { private: absl::flat_hash_map cache_; + // Track which fingerprints are in use. Unused ones can appear from loading a + // partially compatible cache file. These should not be exported to avoid + // linking the corresponding kernels later. + absl::flat_hash_set hits_; }; // Calculates the fingerprint of a (fused_computation, kernel_arguments, diff --git a/third_party/xla/xla/service/gpu/kernel_reuse_cache_test.cc b/third_party/xla/xla/service/gpu/kernel_reuse_cache_test.cc new file mode 100644 index 00000000000000..19b0c0d0d3f3c7 --- /dev/null +++ b/third_party/xla/xla/service/gpu/kernel_reuse_cache_test.cc @@ -0,0 +1,44 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
diff --git a/third_party/xla/xla/service/gpu/kernel_reuse_cache_test.cc b/third_party/xla/xla/service/gpu/kernel_reuse_cache_test.cc
new file mode 100644
index 00000000000000..19b0c0d0d3f3c7
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/kernel_reuse_cache_test.cc
@@ -0,0 +1,44 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "xla/service/gpu/kernel_reuse_cache.h"
+
+#include "tsl/lib/core/status_test_util.h"
+#include "tsl/platform/test.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+using KernelReuseTest = ::testing::Test;
+
+TEST_F(KernelReuseTest, ExportAndLoadWork) {
+  KernelReuseCache cache;
+  EXPECT_TRUE(cache.IsEmpty());
+  auto [result, was_cached] = cache.GetWithStatus(
+      "fingerprint", []() { return KernelReuseCache::Entry{}; });
+  TF_EXPECT_OK(result);
+  EXPECT_NE(result.value(), nullptr);
+  EXPECT_FALSE(was_cached);
+  EXPECT_FALSE(cache.IsEmpty());
+  const CompilationCacheProto proto = cache.Export();
+  cache.Clear();
+  EXPECT_TRUE(cache.IsEmpty());
+  TF_EXPECT_OK(cache.Load(proto));
+  EXPECT_FALSE(cache.IsEmpty());
+}
+
+}  // namespace
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/nvptx_compiler.h b/third_party/xla/xla/service/gpu/nvptx_compiler.h
index b400d041b89a61..f90846e157647b 100644
--- a/third_party/xla/xla/service/gpu/nvptx_compiler.h
+++ b/third_party/xla/xla/service/gpu/nvptx_compiler.h
@@ -97,10 +97,10 @@ class NVPTXCompiler : public GpuCompiler {
     kDriver,
   };
 
- private:
   absl::StatusOr<bool> CanUseLinkModules(
       const HloModuleConfig& module_config) override;
 
+ private:
   absl::StatusOr<std::vector<uint8_t>> LinkModules(
       se::StreamExecutor* stream_exec,
       std::vector<std::vector<uint8_t>> modules,
diff --git a/third_party/xla/xla/tests/collective_ops_test.cc b/third_party/xla/xla/tests/collective_ops_test.cc
index e366264a1d4c4a..7809bc0f63b4ce 100644
--- a/third_party/xla/xla/tests/collective_ops_test.cc
+++ b/third_party/xla/xla/tests/collective_ops_test.cc
@@ -1985,18 +1985,30 @@ XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(SendRecv_TwoConcurrentChains)) {
              channel_id=0, frontend_attributes={
               _xla_send_recv_source_target_pairs="{{0,1}}"
             }
-    recv-done.0 = (u32[2], token[]) recv-done(recv.0), channel_id=0
+
+    recv-done.0 = (u32[2], token[]) recv-done(recv.0), channel_id=0,
+      frontend_attributes={
+        _xla_send_recv_pipeline="1"
+      }
     recv-data.0 = u32[2] get-tuple-element(recv-done.0), index=0
-    recv-done.1 = (u32[2], token[]) recv-done(recv.1), channel_id=0
+    recv-done.1 = (u32[2], token[]) recv-done(recv.1), channel_id=0,
+      frontend_attributes={
+        _xla_send_recv_pipeline="0"
+      }
     recv-data.1 = u32[2] get-tuple-element(recv-done.1), index=0
 
     compare0 = pred[] compare(replica, c0), direction=EQ
     compare = pred[2] broadcast(compare0), dimensions={}
     recv-data = u32[2] select(compare, recv-data.0, recv-data.1)
 
-    send-done.0 = token[] send-done(send.0), channel_id=0
-    send-done.1 = token[] send-done(send.1), channel_id=0
-
+    send-done.0 = token[] send-done(send.0), channel_id=0,
+      frontend_attributes={
+        _xla_send_recv_pipeline="1"
+      }
+    send-done.1 = token[] send-done(send.1), channel_id=0,
+      frontend_attributes={
+        _xla_send_recv_pipeline="0"
+      }
     c1b = u32[2] broadcast(c1), dimensions={}
     ROOT result = u32[2] add(c1b, recv-data)
   })";
@@ -2042,9 +2054,15 @@ XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(SendRecv_ValidationAttr1)) {
              _xla_send_recv_source_target_pairs="{{1,0}}",
             _xla_send_recv_validation="invalid"
            }
-    recv-done.0 = (u32[2], token[]) recv-done(recv.0), channel_id=0
+    recv-done.0 = (u32[2], token[]) recv-done(recv.0), channel_id=0,
+      frontend_attributes={
+        _xla_send_recv_pipeline="0"
+      }
     recv-data.0 = u32[2] get-tuple-element(recv-done.0), index=0
-    send-done.0 = token[] send-done(send.0), channel_id=0
+    send-done.0 = token[] send-done(send.0), channel_id=0,
+      frontend_attributes={
+        _xla_send_recv_pipeline="0"
+      }
 
     after-all.1 = token[] after-all()
     recv.1 = (u32[2], u32[], token[]) recv(after-all.1), channel_id=0,
@@ -2055,13 +2073,19 @@ XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(SendRecv_ValidationAttr1)) {
              channel_id=0, frontend_attributes={
               _xla_send_recv_source_target_pairs="{{0,1}}"
             }
-    recv-done.1 = (u32[2], token[]) recv-done(recv.1), channel_id=0
+    recv-done.1 = (u32[2], token[]) recv-done(recv.1), channel_id=0,
+      frontend_attributes={
+        _xla_send_recv_pipeline="0"
+      }
     recv-data.1 = u32[2] get-tuple-element(recv-done.1), index=0
 
     compare0 = pred[] compare(replica, c0), direction=EQ
     compare = pred[2] broadcast(compare0), dimensions={}
     recv-data = u32[2] select(compare, recv-data.0, recv-data.1)
 
-    send-done.1 = token[] send-done(send.1), channel_id=0
+    send-done.1 = token[] send-done(send.1), channel_id=0,
+      frontend_attributes={
+        _xla_send_recv_pipeline="0"
+      }
 
     c1b = u32[2] broadcast(c1), dimensions={}
     ROOT result = u32[2] add(c1b, recv-data)
@@ -2113,9 +2137,15 @@ body {
              _xla_send_recv_source_target_pairs="{{1,0}}",
             _xla_send_recv_validation="{{0,1}}"
            }
-    recv-done.0 = (u32[2], token[]) recv-done(recv.0), channel_id=0
+    recv-done.0 = (u32[2], token[]) recv-done(recv.0), channel_id=0,
+      frontend_attributes={
+        _xla_send_recv_pipeline="0"
+      }
     recv-data.0 = u32[2] get-tuple-element(recv-done.0), index=0
-    send-done.0 = token[] send-done(send.0), channel_id=0
+    send-done.0 = token[] send-done(send.0), channel_id=0,
+      frontend_attributes={
+        _xla_send_recv_pipeline="0"
+      }
 
     after-all.1 = token[] after-all()
     recv.1 = (u32[2], u32[], token[]) recv(after-all.1), channel_id=0,
@@ -2127,7 +2157,10 @@ body {
              frontend_attributes={
               _xla_send_recv_source_target_pairs="{{0,1}}"
             }
-    recv-done.1 = (u32[2], token[]) recv-done(recv.1), channel_id=0
+    recv-done.1 = (u32[2], token[]) recv-done(recv.1), channel_id=0,
+      frontend_attributes={
+        _xla_send_recv_pipeline="0"
+      }
     recv-data.1 = u32[2] get-tuple-element(recv-done.1), index=0
 
     replica = u32[] replica-id()
@@ -2142,7 +2175,10 @@ body {
     r = u32[2] broadcast(c1), dimensions={}
     s = u32[2] add(r, recv-data)
 
-    send-done.1 = token[] send-done(send.1), channel_id=0
+    send-done.1 = token[] send-done(send.1), channel_id=0,
+      frontend_attributes={
+        _xla_send_recv_pipeline="0"
+      }
 
     ROOT result = (u32[], u32[2]) tuple(new_count, s)
 }
diff --git a/third_party/xla/xla/xla.proto b/third_party/xla/xla/xla.proto
index 5afed9774d695b..b681ae11c1dd80 100644
--- a/third_party/xla/xla/xla.proto
+++ b/third_party/xla/xla/xla.proto
@@ -803,7 +803,9 @@ message DebugOptions {
   // details.
   bool xla_use_shardonnay = 302;
 
-  // Next id: 304
+  string xla_gpu_kernel_cache_file = 304;
+
+  // Next id: 305
 
   // Extra options to pass to the compilation backend (e.g. LLVM); specific
   // interpretation of these values is left to the backend.