[go: nahoru, domu]

Skip to content

Commit

Permalink
Automated Code Change
Browse files Browse the repository at this point in the history
FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#13781 from openxla:kernel_cache_grow f421890b6548a9a4dfcc5350e791bc8860615dbc
PiperOrigin-RevId: 643541625
  • Loading branch information
tensorflower-gardener committed Jun 18, 2024
1 parent 44cb866 commit 87dfff6
Show file tree
Hide file tree
Showing 18 changed files with 372 additions and 161 deletions.
2 changes: 1 addition & 1 deletion tensorflow/core/kernels/scatter_nd_op_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,7 @@ TEST_F(ScatterNdOpConstructionTest, Error_BadIndicesPolicyInvalid) {
.Input(FakeInput(DT_INT32))
.Attr("bad_indices_policy", "AN_UNRECOGNIZED_POLICY")
.Finalize(node_def()));
EXPECT_NE(InitOp(), OkStatus());
EXPECT_NE(InitOp(), absl::OkStatus());
}

class ScatterNdUpdateBM : public ScatterNdUpdateOpTest {
Expand Down
16 changes: 8 additions & 8 deletions third_party/xla/xla/debug_options_flags.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1760,14 +1760,14 @@ void MakeDebugOptionsFlags(std::vector<tsl::Flag>* flag_list,
debug_options->xla_gpu_shard_autotuning(),
"Shard autotuning between participating compiler processes (typically in "
"multi-host setups) and join the results when it's done."));
flag_list->push_back(tsl::Flag(
"xla_gpu_kernel_cache_file",
string_setter_for(&DebugOptions::set_xla_gpu_kernel_cache_file),
debug_options->xla_gpu_kernel_cache_file(),
"Path to a file to cache compiled kernels. If the file doesn't exist "
"write the compilation cache of the first compiled HLO module into it."
"Once the file exists, further compilations will read it to reuse "
"the kernels, but not write it. This behavior may change later."));
flag_list->push_back(
tsl::Flag("xla_gpu_kernel_cache_file",
string_setter_for(&DebugOptions::set_xla_gpu_kernel_cache_file),
debug_options->xla_gpu_kernel_cache_file(),
"Path to a file to cache compiled kernels. Cached kernels get "
"reused in further compilations; not yet cached kernels are "
"compiled as usual and get appended to the cache file whenever "
"possible."));
} // NOLINT(readability/fn_size)

// Allocates flag_values and flag_objects; this function must not be called more
Expand Down
11 changes: 8 additions & 3 deletions third_party/xla/xla/service/gpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -3425,6 +3425,7 @@ cc_library(
":custom_kernel_fusion_rewriter",
":dot_dimension_sorter",
":dot_operand_converter",
":double_buffer_loop_unrolling",
":executable_proto_cc",
":fusion_merger",
":fusion_wrapper",
Expand All @@ -3451,7 +3452,7 @@ cc_library(
":instruction_fusion",
":ir_emission_utils",
":ir_emitter",
":double_buffer_loop_unrolling",
":kernel_reuse_cache",
":matmul_utils",
":metrics",
":move_copy_to_users",
Expand Down Expand Up @@ -3625,6 +3626,7 @@ cc_library(
":command_buffer_scheduling",
":execution_stream_assignment",
":fusion_pipeline",
":gpu_latency_hiding_scheduler",
":ir_emitter_context",
":ir_emitter_unnested",
":prepare_hlo_for_ir_emitting_pipeline",
Expand Down Expand Up @@ -3872,6 +3874,7 @@ xla_test(
deps = [
":gpu_constants",
":gpu_hlo_schedule",
":gpu_latency_hiding_scheduler",
":nvptx_compiler_impl",
"//xla:util",
"//xla:xla_proto_cc",
Expand Down Expand Up @@ -4231,11 +4234,9 @@ cc_library(
hdrs = ["gpu_hlo_schedule.h"],
deps = [
":backend_configs_cc",
":cublas_cudnn",
":gpu_latency_hiding_scheduler",
":gpu_schedule_postprocessing",
"//xla:shape_util",
"//xla:statusor",
"//xla:util",
"//xla/hlo/ir:hlo",
"//xla/hlo/utils:hlo_query",
Expand Down Expand Up @@ -5730,7 +5731,9 @@ xla_cc_test(
deps = [
":kernel_reuse_cache",
"//xla/tests:xla_internal_test_main",
"@com_google_googletest//:gtest",
"@local_tsl//tsl/lib/core:status_test_util",
"@local_tsl//tsl/platform:env",
"@local_tsl//tsl/platform:test",
],
)
Expand Down Expand Up @@ -6257,6 +6260,8 @@ cc_library(
hdrs = ["gpu_latency_hiding_scheduler.h"],
deps = [
":backend_configs_cc",
":cublas_cudnn",
"//xla:shape_util",
"//xla/hlo/ir:hlo",
"//xla/hlo/utils:hlo_query",
"//xla/service:collective_ops_utils",
Expand Down
69 changes: 28 additions & 41 deletions third_party/xla/xla/service/gpu/gpu_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ limitations under the License.
#include "xla/service/gpu/gpu_executable.h"
#include "xla/service/gpu/gpu_float_support.h"
#include "xla/service/gpu/gpu_hlo_schedule.h"
#include "xla/service/gpu/gpu_latency_hiding_scheduler.h"
#include "xla/service/gpu/gpu_layout_assignment.h"
#include "xla/service/gpu/gpu_p2p_pipeliner.h"
#include "xla/service/gpu/gpu_reduce_scatter_creator.h"
Expand All @@ -147,6 +148,7 @@ limitations under the License.
#include "xla/service/gpu/ir_emission_utils.h"
#include "xla/service/gpu/ir_emitter_context.h"
#include "xla/service/gpu/ir_emitter_unnested.h"
#include "xla/service/gpu/kernel_reuse_cache.h"
#include "xla/service/gpu/matmul_utils.h"
#include "xla/service/gpu/metrics.h"
#include "xla/service/gpu/model/gpu_cost_model_stats_collection.h"
Expand Down Expand Up @@ -1922,12 +1924,7 @@ absl::StatusOr<GpuCompiler::BackendCompileResult> GpuCompiler::CompileAndLink(
std::string ptx_snippets;
std::vector<std::vector<uint8_t>> binaries_to_link;
binaries_to_link.reserve(compile_results.size());
struct NamedBinary {
// The string is the function name or empty just like for llvm_modules.
std::string name;
std::vector<uint8_t> binary;
};
std::vector<NamedBinary> binaries_to_cache;
std::vector<KernelReuseCache::NamedBinary> binaries_to_cache;
binaries_to_cache.reserve(single_function_module_count);
for (const auto& [name, maybe_result] : compile_results) {
TF_ASSIGN_OR_RETURN(auto result, maybe_result);
Expand All @@ -1948,51 +1945,40 @@ absl::StatusOr<GpuCompiler::BackendCompileResult> GpuCompiler::CompileAndLink(
return FailedPrecondition("File path can not be resolved: %s",
cache_path);
}
CompilationCacheProto& cache =
// current_cache contains new kernels from the current compilation and
// kernels to reuse from previous compilations if some were loaded from the
// cache file.
const CompilationCacheProto& current_cache =
compile_module_results.kernel_compilation_cache;
if (tsl::Env::Default()->FileExists(resolved_path).ok()) {
const bool cache_file_exists =
tsl::Env::Default()->FileExists(resolved_path).ok();
if (cache_file_exists) {
// Pick reused binaries from previous compilations needed to link the
// current executable.
int loaded_kernel_count = 0;
for (const auto& [name, entry] : cache.entries()) {
if (llvm_module->getFunction(name)) {
VLOG(5)
<< "Skipping cached " << name
<< " in favor of the just compiled kernel with the same name.";
CHECK(entry.binary().empty());
for (const auto& [name, entry] : current_cache.entries()) {
if (llvm_module->getFunction(name) != nullptr) {
VLOG(5) << "Using the just compiled kernel for " << name;
TF_RET_CHECK(entry.binary().empty())
<< name
<< " is a just compiled kernel and is not expected to have a "
"binary yet.";
continue;
}
const uint8_t* binary =
reinterpret_cast<const uint8_t*>(entry.binary().data());
binaries_to_link.push_back(
std::vector<uint8_t>(binary, binary + entry.binary().size()));
VLOG(5) << "Loaded " << name << ": " << entry.binary().size();
VLOG(5) << "Using " << name << " from cache: " << entry.binary().size();
++loaded_kernel_count;
}
VLOG(2) << "Loaded " << loaded_kernel_count << " / "
<< cache.entries_size() << " cached kernels.";
} else {
auto entries = cache.mutable_entries();
for (const auto& [name, binary] : binaries_to_cache) {
auto it = entries->find(name);
if (it == entries->end()) {
continue;
}
it->second.set_binary(reinterpret_cast<const char*>(binary.data()),
binary.size());
VLOG(5) << "Cached kernels: " << name << ": " << binary.size();
}
for (auto it = entries->begin(); it != entries->end();) {
if (it->second.binary().empty()) {
it = entries->erase(it);
} else {
++it;
}
}
if (cache.entries_size() > 0) {
TF_RETURN_IF_ERROR(tsl::WriteStringToFile(
tsl::Env::Default(), resolved_path, cache.SerializeAsString()));
VLOG(2) << "Stored " << cache.entries_size() << " / "
<< binaries_to_cache.size();
}
VLOG(2) << "Using " << loaded_kernel_count << " / "
<< current_cache.entries_size() << " cached kernels.";
}
if (!binaries_to_cache.empty()) {
TF_RETURN_IF_ERROR(
UpdateDiskKernelCache(resolved_path, /*do_append=*/cache_file_exists,
current_cache, binaries_to_cache));
}
}

Expand All @@ -2007,6 +1993,7 @@ absl::StatusOr<GpuCompiler::BackendCompileResult> GpuCompiler::CompileAndLink(
return maybe_backend_result.status();
}
VLOG(4) << "Binary size after linking [B]: " << maybe_backend_result->size();
compile_module_results.kernel_compilation_cache.Clear();
return BackendCompileResult{ptx_snippets, std::move(*maybe_backend_result)};
}

Expand Down
22 changes: 13 additions & 9 deletions third_party/xla/xla/service/gpu/gpu_compiler_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -559,6 +559,7 @@ CHECK-SAME: frontend_attributes={_xla_send_recv_pipeline="0"}
class KernelCacheTest : public HloTestBase {
public:
void SetUp() override {
CHECK(tsl::Env::Default()->LocalTempFilename(&cache_file_name_));
HloModuleConfig config;
config.set_debug_options(GetDebugOptionsForTest());
TF_ASSERT_OK_AND_ASSIGN(bool can_use_link_modules,
Expand All @@ -568,8 +569,8 @@ class KernelCacheTest : public HloTestBase {
GTEST_SKIP() << "Caching compiled kernels requires support of linking.";
}
}

DebugOptions GetDebugOptionsForTest() override {
CHECK(tsl::Env::Default()->LocalTempFilename(&cache_file_name_));
DebugOptions debug_options = HloTestBase::GetDebugOptionsForTest();
debug_options.set_xla_gpu_kernel_cache_file(cache_file_name_);
debug_options.set_xla_gpu_enable_llvm_module_compilation_parallelism(true);
Expand All @@ -583,16 +584,16 @@ class KernelCacheTest : public HloTestBase {
return true;
}

bool NonEmptyCacheExists() {
int CacheEntryCount() {
if (!CacheFileExists()) {
return false;
return 0;
}
std::string serialized;
TF_EXPECT_OK(tsl::ReadFileToString(tsl::Env::Default(), cache_file_name_,
&serialized));
CompilationCacheProto proto;
EXPECT_TRUE(proto.ParseFromString(std::string(serialized)));
return proto.entries_size() > 0;
return proto.entries_size();
}

std::string cache_file_name_;
Expand All @@ -609,9 +610,10 @@ TEST_F(KernelCacheTest, CacheIsGenerated) {
EXPECT_FALSE(CacheFileExists());
EXPECT_TRUE(Run(kHloText, /*run_hlo_passes=*/false));
// First run generates a cache
EXPECT_TRUE(NonEmptyCacheExists());
EXPECT_EQ(CacheEntryCount(), 1);
// Second run - with cache file
EXPECT_TRUE(Run(kHloText, /*run_hlo_passes=*/false));
EXPECT_EQ(CacheEntryCount(), 1);
}

TEST_F(KernelCacheTest, NoCacheIsGeneratedWithoutCompiledKernels) {
Expand All @@ -626,17 +628,18 @@ TEST_F(KernelCacheTest, NoCacheIsGeneratedWithoutCompiledKernels) {
EXPECT_FALSE(CacheFileExists());
}

TEST_F(KernelCacheTest, CacheGrowsWithNewKernels) {
  // Start from a clean state: no cache file yet.
  EXPECT_FALSE(CacheFileExists());
  EXPECT_TRUE(Run(kHloText, /*run_hlo_passes=*/false));
  // The first module contributes one kernel to the cache.
  EXPECT_EQ(CacheEntryCount(), 1);
  // Second run - with cache file and another HLO; its new kernel must be
  // appended, growing the cache to two entries.
  EXPECT_TRUE(Run(R"(
  ENTRY e {
    p = s8[] parameter(0)
    ROOT _ = s8[] multiply(p, p)
  })",
                  /*run_hlo_passes=*/false));
  EXPECT_EQ(CacheEntryCount(), 2);
}

class KernelCacheTestSingleThreaded : public KernelCacheTest {
Expand All @@ -651,8 +654,9 @@ class KernelCacheTestSingleThreaded : public KernelCacheTest {
TEST_F(KernelCacheTestSingleThreaded, CacheIsGenerated) {
  // No cache file may exist before the first compilation.
  EXPECT_FALSE(CacheFileExists());
  EXPECT_TRUE(Run(kHloText, /*run_hlo_passes=*/false));
  // One kernel is cached even with a single compilation thread.
  EXPECT_EQ(CacheEntryCount(), 1);
  // Recompiling the same module reuses the cache without growing it.
  EXPECT_TRUE(Run(kHloText, /*run_hlo_passes=*/false));
  EXPECT_EQ(CacheEntryCount(), 1);
}

class NoKernelCacheTest : public KernelCacheTest {
Expand All @@ -666,7 +670,7 @@ class NoKernelCacheTest : public KernelCacheTest {

TEST_F(NoKernelCacheTest, NoCacheWithoutCompilationParallelism) {
  EXPECT_TRUE(Run(kHloText, /*run_hlo_passes=*/false));
  // Kernel caching requires LLVM module compilation parallelism; without it
  // no cache file may be written at all.
  EXPECT_FALSE(CacheFileExists());
}

} // namespace
Expand Down
Loading

0 comments on commit 87dfff6

Please sign in to comment.