Refactor llvm_compiler_test.
We can run the CpuCompiler- and GpuCompiler-related tests in separate test
targets.

PiperOrigin-RevId: 644307904
akuegel authored and tensorflower-gardener committed Jun 18, 2024
1 parent aa29a22 commit fe06dc7
Showing 3 changed files with 80 additions and 170 deletions.
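
The gist of the refactor, before the per-file diffs: instead of each test fixture constructing its own backend-specific compiler (cpu::CpuCompiler, or a GpuDummyCompiler subclass of GpuCompiler), the shared fixture now asks the default backend for its compiler and installs the optimization hooks on that. The following is a minimal standalone sketch of that pattern; the Backend and LLVMCompiler classes here are simplified stand-ins for the XLA types (the real hooks take a const llvm::Module&), not the actual interfaces.

// Standalone model of the pattern the refactored test adopts. Simplified
// stand-ins for xla::Backend and xla::LLVMCompiler, not the real XLA types.
#include <cassert>
#include <functional>
#include <iostream>
#include <utility>

class LLVMCompiler {
 public:
  using ModuleHook = std::function<void()>;
  void SetPreOptimizationHook(ModuleHook hook) { pre_hook_ = std::move(hook); }
  void SetPostOptimizationHook(ModuleHook hook) {
    post_hook_ = std::move(hook);
  }
  // Pretend to compile: a real RunBackend() optimizes the llvm::Module and
  // fires each hook once around the optimization pipeline.
  void RunBackend() {
    if (pre_hook_) pre_hook_();
    if (post_hook_) post_hook_();
  }

 private:
  ModuleHook pre_hook_, post_hook_;
};

// Stand-in for xla::Backend: it owns the compiler of whichever platform the
// per-backend test target was built for (CPU or GPU).
class Backend {
 public:
  LLVMCompiler* compiler() { return &compiler_; }

 private:
  LLVMCompiler compiler_;
};

int main() {
  Backend backend;  // In the real test: Backend::CreateDefaultBackend().
  int pre_calls = 0, post_calls = 0;

  // The test body no longer instantiates cpu::CpuCompiler or a dummy
  // GpuCompiler itself; it borrows the backend's compiler, so one TEST_F
  // body serves both backends.
  LLVMCompiler* compiler = backend.compiler();
  compiler->SetPreOptimizationHook([&] { ++pre_calls; });
  compiler->SetPostOptimizationHook([&] { ++post_calls; });
  compiler->RunBackend();

  assert(pre_calls == 1 && post_calls == 1);
  std::cout << "pre/post optimization hooks each fired once\n";
  return 0;
}

Because the xla_test rule below builds the test once per backend, the same test body exercises whichever compiler the target's backend provides.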
7 changes: 6 additions & 1 deletion third_party/xla/xla/service/gpu/nvptx_compiler.cc
@@ -182,6 +182,7 @@ absl::Status NVPTXCompiler::OptimizeHloConvolutionCanonicalization(
HloModule* hlo_module, se::GpuComputeCapability gpu_version,
se::dnn::VersionInfo dnn_version,
se::DeviceMemoryAllocator* device_allocator) {
LOG(ERROR) << "Running OptimizeHloConvolutionCanonicalization()";
auto cuda_compute_capability =
std::get<se::CudaComputeCapability>(gpu_version);
// Convert convolutions into CustomCalls to cudnn, then canonicalize them
@@ -254,6 +255,7 @@ absl::Status NVPTXCompiler::OptimizeHloConvolutionCanonicalization(
// by constant folding.
pipeline.AddPass<HloConstantFolding>();
TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status());
LOG(ERROR) << "Finished running OptimizeHloConvolutionCanonicalization()";

return absl::OkStatus();
}
@@ -262,6 +264,7 @@ absl::Status NVPTXCompiler::OptimizeHloPostLayoutAssignment(
HloModule* hlo_module, se::StreamExecutor* stream_exec,
const CompileOptions& options, const TargetConfig& gpu_target_config,
tsl::thread::ThreadPool* thread_pool) {
LOG(ERROR) << "Running OptimizeHloPostLayoutAssignment()";
// This needs to run before GemmRewriter, which is part of
// OptimizeHloPostLayoutAssignment().
auto cuda_compute_capability = std::get<se::CudaComputeCapability>(
@@ -339,7 +342,7 @@ absl::Status NVPTXCompiler::OptimizeHloPostLayoutAssignment(
post_pipeline.AddPass<CuDnnWorkspaceRewriter>(*stream_exec);
}
TF_RETURN_IF_ERROR(post_pipeline.Run(hlo_module).status());

LOG(ERROR) << "Finished running OptimizeHloPostLayoutAssignment()";
return absl::OkStatus();
}

@@ -535,6 +538,7 @@ NVPTXCompiler::CompileTargetBinary(const HloModuleConfig& module_config,
bool relocatable,
const HloModule* debug_module,
const CompileOptions& options) {
LOG(ERROR) << "Running CompileTargetBinary()";
std::unique_ptr<llvm::Module> loaded_module =
MaybeLoadLLVMFromFile(debug_module, llvm_module);
llvm::Module* selected_module = nullptr;
@@ -574,6 +578,7 @@ NVPTXCompiler::CompileTargetBinary(const HloModuleConfig& module_config,
if (!maybe_cubin.ok()) {
return maybe_cubin.status();
}
LOG(ERROR) << "Finished unning CompileTargetBinary()";
return BackendCompileResult{std::move(ptx), std::move(maybe_cubin.value())};
}

38 changes: 17 additions & 21 deletions third_party/xla/xla/tests/BUILD
@@ -12,10 +12,6 @@ load(
"if_cuda_is_configured",
)
load("//xla:xla.bzl", "tests_build_defs_bzl_deps", "xla_cc_binary", "xla_cc_test")
load(
"//xla/stream_executor:build_defs.bzl",
"if_gpu_is_configured",
)
load("//xla/tests:build_defs.bzl", "generate_backend_suites", "generate_backend_test_macros", "xla_test", "xla_test_library")
load("//xla/tsl:tsl.bzl", "internal_visibility", "tsl_copts")
load("//xla/tsl:tsl.default.bzl", "filegroup")
@@ -2457,32 +2453,32 @@ xla_test(
xla_test(
name = "llvm_compiler_test",
srcs = ["llvm_compiler_test.cc"],
backends = ["gpu"],
local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([
"TENSORFLOW_USE_ROCM",
]),
# TODO(b/317293391) Remove once Bazel test_suite handles tags correctly
tags = ["gpu"],
deps = if_gpu_is_configured([
backend_tags = {
# TODO(b/347860572): Remove tag when asan bug is fixed.
"cpu": ["noasan"],
# TODO(b/317293391) Remove once Bazel test_suite handles tags correctly.
"gpu": ["gpu"],
},
backends = [
"cpu",
"gpu",
],
deps = [
":verified_hlo_module",
"//xla:literal_util",
"//xla:test_helpers",
"//xla/hlo/ir:hlo",
"@com_google_absl//absl/status",
"//xla/hlo/ir:hlo_module_group",
"//xla/service:backend",
"//xla/service:cpu_plugin",
"//xla/service:llvm_compiler",
"//xla/service:platform_util",
"//xla/service/cpu:cpu_compiler",
"//xla/service/gpu:gpu_compiler",
"//xla/stream_executor",
"@com_google_absl//absl/status",
"@com_google_googletest//:gtest_main",
"@llvm-project//llvm:Core",
"@local_tsl//tsl/platform:casts",
"@local_tsl//tsl/platform:env",
"@local_tsl//tsl/platform:test_main",
]) + if_cuda_is_configured([
"//xla/stream_executor/cuda:cuda_platform_id",
]) + if_rocm_is_configured([
"//xla/stream_executor/rocm:rocm_platform_id",
]),
],
)
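
A note on what this BUILD change accomplishes: listing both backends in a single xla_test rule lets the macro expand it into one test target per backend (presumably named llvm_compiler_test_cpu and llvm_compiler_test_gpu; the exact names are determined by //xla/tests:build_defs.bzl), so the CpuCompiler and GpuCompiler variants build and run independently, each picking up its own backend_tags entry.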

xla_test(
205 changes: 57 additions & 148 deletions third_party/xla/xla/tests/llvm_compiler_test.cc
@@ -21,155 +21,35 @@ limitations under the License.
#include <utility>
#include <vector>

#include <gtest/gtest.h>
#include "absl/status/status.h"
#include "llvm/IR/Module.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_module_group.h"
#include "xla/literal_util.h"
#include "xla/service/backend.h"
#include "xla/service/cpu/cpu_compiler.h"
#include "xla/service/gpu/gpu_compiler.h"
#include "xla/service/platform_util.h"
#if GOOGLE_CUDA
#include "xla/stream_executor/cuda/cuda_platform_id.h"
#elif TENSORFLOW_USE_ROCM
#include "xla/stream_executor/rocm/rocm_platform_id.h"
#endif
#include "xla/stream_executor/device_description.h"
#include "xla/stream_executor/stream_executor.h"
#include "xla/test_helpers.h"
#include "xla/tests/verified_hlo_module.h"
#include "tsl/platform/casts.h"
#include "tsl/platform/threadpool.h"

namespace xla {
namespace gpu {

// Creating dummy data structure needed to initialize a GpuDummyCompiler
constexpr char kDummyTriple[] = "dummy-triple";
constexpr char kDummyLayout[] = "e";
const se::Platform::Id kGpuPlatformId =
#if GOOGLE_CUDA
se::cuda::kCudaPlatformId;
#elif TENSORFLOW_USE_ROCM
se::rocm::kROCmPlatformId;
#endif
// This class is a dummy implementation of GpuCompiler and is targeted for unit
// test only
class GpuDummyCompiler : public GpuCompiler {
public:
GpuDummyCompiler()
: GpuCompiler(kGpuPlatformId, kDummyTriple, kDummyLayout) {}

int32_t GetToolkitVersion() const override { return 0; }

absl::Status OptimizeHloConvolutionCanonicalization(
HloModule* hlo_module, se::GpuComputeCapability gpu_version,
se::dnn::VersionInfo dnn_version,
se::DeviceMemoryAllocator* device_allocator) {
return absl::OkStatus();
}

absl::Status OptimizeHloPostLayoutAssignment(
HloModule* hlo_module, se::StreamExecutor* stream_executor,
const CompileOptions& options, const TargetConfig& gpu_target_config,
tsl::thread::ThreadPool* thread_pool) override {
return absl::OkStatus();
}

absl::StatusOr<GpuCompiler::BackendCompileResult> CompileTargetBinary(
const HloModuleConfig& module_config, llvm::Module* llvm_module,
se::GpuComputeCapability gpu_version, bool relocatable,
const HloModule* debug_module, const CompileOptions& options) override {
return BackendCompileResult{};
}
};
} // namespace gpu

namespace {

class LLVMCompilerTest : public ::testing::Test {
public:
void SetUp() override {
Platform* platform = FindPlatform();
ASSERT_NE(platform, nullptr);

BackendOptions backend_options;
backend_options.set_platform(platform);
absl::StatusOr<std::unique_ptr<Backend>> backend_or_status =
Backend::CreateBackend(backend_options);
Backend::CreateDefaultBackend();
ASSERT_IS_OK(backend_or_status.status());
backend_ = std::move(backend_or_status).value();
}

~LLVMCompilerTest() override {}

protected:
using Platform = se::Platform;

explicit LLVMCompilerTest(std::string platform_name)
: platform_name_(std::move(platform_name)) {}

void TestCompilerHooks(LLVMCompiler* compiler) {
int pre_opt_hook_call_count = 0;
int post_opt_hook_call_count = 0;

auto pre_opt_hook = [&pre_opt_hook_call_count](const llvm::Module&) {
++pre_opt_hook_call_count;
return absl::OkStatus();
};
auto post_opt_hook = [&post_opt_hook_call_count](const llvm::Module&) {
++post_opt_hook_call_count;
return absl::OkStatus();
};

// Create HLO module, and run the compiler.
auto builder = HloComputation::Builder(TestName());
builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));

auto hlo_module = CreateNewVerifiedModule();
hlo_module->AddEntryComputation(builder.Build());

compiler->SetPreOptimizationHook(pre_opt_hook);
compiler->SetPostOptimizationHook(post_opt_hook);

ASSERT_TRUE(compiler
->RunBackend(std::move(hlo_module),
backend_->default_stream_executor(),
/*device_allocator=*/nullptr)
.ok());

// Test that hooks were called.
EXPECT_EQ(1, pre_opt_hook_call_count);
EXPECT_EQ(1, post_opt_hook_call_count);
}

void TestMultiModuleCompilation(LLVMCompiler* compiler) {
HloComputation::Builder builder(TestName());
builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));

std::unique_ptr<HloModule> hlo_module = CreateNewVerifiedModule();
hlo_module->AddEntryComputation(builder.Build());

auto module_group = std::make_unique<HloModuleGroup>("test_module_group");
module_group->push_back(hlo_module->Clone());
module_group->push_back(std::move(hlo_module));

std::vector<std::vector<se::StreamExecutor*>> executors;
executors.push_back({backend_->default_stream_executor()});
executors.push_back({backend_->default_stream_executor()});

EXPECT_IS_OK(compiler->Compile(std::move(module_group),
std::move(executors),
/*device_allocator=*/nullptr));
}

private:
Platform* FindPlatform() {
auto status_or_platform = PlatformUtil::GetPlatform(platform_name_);
return status_or_platform.ok() ? status_or_platform.value() : nullptr;
}

std::string platform_name_;
std::unique_ptr<Backend> backend_;

static std::string TestName() {
@@ -186,34 +66,63 @@ class LLVMCompilerTest : public ::testing::Test {
}
};

class CpuCompilerTest : public LLVMCompilerTest {
public:
CpuCompilerTest() : LLVMCompilerTest("Host") {}
};
TEST_F(LLVMCompilerTest, HooksTest) {
int pre_opt_hook_call_count = 0;
int post_opt_hook_call_count = 0;

class GpuCompilerTest : public LLVMCompilerTest {
public:
GpuCompilerTest() : LLVMCompilerTest("GPU") {}
};

TEST_F(CpuCompilerTest, HooksTest) {
cpu::CpuCompiler compiler;
TestCompilerHooks(&compiler);
auto pre_opt_hook = [&pre_opt_hook_call_count](const llvm::Module&) {
++pre_opt_hook_call_count;
return absl::OkStatus();
};
auto post_opt_hook = [&post_opt_hook_call_count](const llvm::Module&) {
++post_opt_hook_call_count;
return absl::OkStatus();
};

// Create HLO module, and run the compiler.
auto builder = HloComputation::Builder(TestName());
builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));

auto hlo_module = CreateNewVerifiedModule();
hlo_module->AddEntryComputation(builder.Build());

LLVMCompiler* compiler =
tensorflow::down_cast<xla::LLVMCompiler*>(backend_->compiler());
compiler->SetPreOptimizationHook(pre_opt_hook);
compiler->SetPostOptimizationHook(post_opt_hook);

ASSERT_TRUE(compiler
->RunBackend(std::move(hlo_module),
backend_->default_stream_executor(),
/*device_allocator=*/nullptr)
.ok());

// Test that hooks were called.
EXPECT_EQ(1, pre_opt_hook_call_count);
EXPECT_EQ(1, post_opt_hook_call_count);
}

TEST_F(GpuCompilerTest, HooksTest) {
gpu::GpuDummyCompiler compiler;
TestCompilerHooks(&compiler);
}
TEST_F(LLVMCompilerTest, MultiModuleCompilation) {
HloComputation::Builder builder(TestName());
builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));

TEST_F(CpuCompilerTest, CpuMultiModuleCompilation) {
cpu::CpuCompiler compiler;
TestMultiModuleCompilation(&compiler);
}
std::unique_ptr<HloModule> hlo_module = CreateNewVerifiedModule();
hlo_module->AddEntryComputation(builder.Build());

auto module_group = std::make_unique<HloModuleGroup>("test_module_group");
module_group->push_back(hlo_module->Clone());
module_group->push_back(std::move(hlo_module));

TEST_F(GpuCompilerTest, GpuMultModuleCompilation) {
gpu::GpuDummyCompiler compiler;
TestMultiModuleCompilation(&compiler);
std::vector<std::vector<se::StreamExecutor*>> executors;
executors.push_back({backend_->default_stream_executor()});
executors.push_back({backend_->default_stream_executor()});

EXPECT_IS_OK(backend_->compiler()->Compile(std::move(module_group),
std::move(executors),
/*device_allocator=*/nullptr));
}

} // namespace
} // namespace xla
