[go: nahoru, domu]

Skip to content

Commit

Permalink
[XLA:GPU] Extract useful utils from ir_emitter_triton_test into their own file.
Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 642300620
  • Loading branch information
dimitar-asenov authored and tensorflower-gardener committed Jun 11, 2024
1 parent 8a69fee commit c7515ca
Show file tree
Hide file tree
Showing 5 changed files with 199 additions and 99 deletions.
34 changes: 28 additions & 6 deletions third_party/xla/xla/service/gpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -637,22 +637,18 @@ xla_test(
deps = [
":backend_configs_cc",
":gpu_device_info_for_tests",
":hlo_fusion_analysis",
":ir_emission_utils",
":ir_emitter_triton",
":matmul_utils",
":triton_fusion_analysis",
":triton_test_utils",
"//xla:autotuning_proto_cc",
"//xla:error_spec",
"//xla:literal",
"//xla:literal_util",
"//xla:status_macros",
"//xla:xla_proto_cc",
"//xla/hlo/ir:hlo",
"//xla/service:pattern_matcher",
"//xla/service:pattern_matcher_gmock",
"//xla/service/gpu/fusions:triton",
"//xla/service/gpu/tests:gpu_codegen_test",
"//xla/stream_executor:device_description",
"//xla/stream_executor/cuda:cublas_plugin",
"//xla/tests:filecheck",
Expand All @@ -663,7 +659,6 @@ xla_test(
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
"@com_google_googletest//:gtest",
"@llvm-project//llvm:Support",
"@llvm-project//llvm:ir_headers",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
Expand All @@ -678,6 +673,33 @@ xla_test(
],
)

# Shared test fixture classes for Triton emitter tests (TritonTest and
# TritonFilecheckTest, declared in triton_test_utils.h and implemented in
# triton_test_utils.cc). testonly = True means only other testonly targets
# (tests) may depend on this library.
cc_library(
name = "triton_test_utils",
testonly = True,
srcs = ["triton_test_utils.cc"],
hdrs = ["triton_test_utils.h"],
deps = [
":gpu_device_info_for_tests",
":hlo_fusion_analysis",
":ir_emission_utils",
":ir_emitter_triton",
":matmul_utils",
":triton_fusion_analysis",
"//xla:status_macros",
"//xla/hlo/ir:hlo",
"//xla/service/gpu/fusions:triton",
"//xla/service/gpu/tests:gpu_codegen_test",
"//xla/stream_executor:device_description",
"//xla/tests:filecheck",
"//xla/tests:verified_hlo_module",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings:string_view",
# Provides llvm/Support/raw_ostream.h, used to print the Triton module.
"@llvm-project//llvm:Support",
# Provides mlir/IR/MLIRContext.h.
"@llvm-project//mlir:IR",
"@local_tsl//tsl/platform:statusor",
],
)

cc_test(
name = "ir_emitter_triton_mem_utils_test",
srcs = if_cuda_is_configured(["ir_emitter_triton_mem_utils_test.cc"]),
Expand Down
93 changes: 1 addition & 92 deletions third_party/xla/xla/service/gpu/ir_emitter_triton_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ limitations under the License.
#include "absl/strings/substitute.h"
#include "absl/types/span.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/Pass/PassManager.h" // from @llvm-project
#include "xla/autotuning.pb.h"
Expand All @@ -46,16 +45,12 @@ limitations under the License.
#include "xla/literal.h"
#include "xla/literal_util.h"
#include "xla/service/gpu/backend_configs.pb.h"
#include "xla/service/gpu/fusions/triton.h"
#include "xla/service/gpu/gpu_device_info_for_tests.h"
#include "xla/service/gpu/hlo_fusion_analysis.h"
#include "xla/service/gpu/ir_emission_utils.h"
#include "xla/service/gpu/matmul_utils.h"
#include "xla/service/gpu/tests/gpu_codegen_test.h"
#include "xla/service/gpu/triton_fusion_analysis.h"
#include "xla/service/gpu/triton_test_utils.h"
#include "xla/service/pattern_matcher.h"
#include "xla/service/pattern_matcher_gmock.h"
#include "xla/status_macros.h"
#include "xla/stream_executor/device_description.h"
#include "xla/tests/filecheck.h"
#include "xla/tests/verified_hlo_module.h"
Expand All @@ -78,42 +73,6 @@ namespace {

namespace m = ::xla::match;

// Base fixture for Triton codegen tests: exposes the capabilities of the
// device behind the test backend's default stream executor.
class TritonTest : public GpuCodegenTest {
 protected:
  // Description of the device used by the test backend. Kept first: its
  // return type is deduced, so it must be defined before other members use it.
  const auto& device_desc() {
    return backend().default_stream_executor()->GetDeviceDescription();
  }

 public:
  // CUDA compute capability reported by the device description.
  se::CudaComputeCapability GetCudaComputeCapability() {
    return device_desc().cuda_compute_capability();
  }

  // Compute capability variant (CUDA or ROCm) of the device under test.
  const se::GpuComputeCapability& GpuComputeComp() {
    return device_desc().gpu_compute_capability();
  }

  // Non-ROCm devices never skip; ROCm devices skip when they report no bf16
  // dtype support.
  bool SkipBF16Tests() {
    const bool on_rocm =
        std::holds_alternative<se::RocmComputeCapability>(GpuComputeComp());
    if (!on_rocm) {
      return false;
    }
    auto rocm_capability = device_desc().rocm_compute_capability();
    return !rocm_capability.has_bf16_dtype_support();
  }

  // The device's own capability on ROCm; otherwise a fixed CUDA Ampere
  // capability (minor version 0), whatever the actual CUDA device is.
  se::GpuComputeCapability CudaAmpereOrRocm() {
    if (!std::holds_alternative<se::RocmComputeCapability>(GpuComputeComp())) {
      return se::GpuComputeCapability{
          se::CudaComputeCapability{se::CudaComputeCapability::AMPERE, 0}};
    }
    return se::GpuComputeCapability{device_desc().rocm_compute_capability()};
  }
};

class TritonGemmTest : public TritonTest {
public:
DebugOptions GetDebugOptionsForTest() override {
Expand Down Expand Up @@ -153,56 +112,6 @@ class TritonGemmTestWithoutTritonGemmAny : public TritonGemmTest {
}
};

// Fixture that lowers a Triton fusion from HLO text to Triton IR and verifies
// the printed IR with FileCheck.
class TritonFilecheckTest : public TritonTest {
public:
// Parses `hlo_text`, emits Triton IR for the fused computation named
// `triton_fusion_name` using `config`, `output_tile_sizes` and `emitter`, and
// runs FileCheck with `filecheck_pattern` over the printed module. For Triton
// softmax fusions, tile sizes from the fusion's launch config (when present)
// replace `output_tile_sizes`. Returns a non-OK status if any step fails or
// FileCheck does not match.
absl::Status CreateTritonIrAndFileCheck(
absl::string_view hlo_text, const TritonGemmConfig& config,
std::vector<int64_t> output_tile_sizes, TritonIrEmitter emitter,
absl::string_view triton_fusion_name,
absl::string_view filecheck_pattern);
};

// Parses `hlo_text`, emits Triton IR for the fused computation named
// `triton_fusion_name`, and runs FileCheck with `filecheck_pattern` over the
// printed module.
//
// Returns a non-OK status if parsing/verification fails, the named computation
// does not exist or is not a fusion computation, Triton fusion analysis or IR
// emission fails, or FileCheck does not match.
absl::Status TritonFilecheckTest::CreateTritonIrAndFileCheck(
    absl::string_view hlo_text, const TritonGemmConfig& config,
    std::vector<int64_t> output_tile_sizes, TritonIrEmitter emitter,
    absl::string_view triton_fusion_name, absl::string_view filecheck_pattern) {
  TF_ASSIGN_OR_RETURN(std::unique_ptr<VerifiedHloModule> verified_module,
                      ParseAndReturnVerifiedModule(hlo_text));

  // The lookup result may be null. Validate it before dereferencing, so a
  // mistyped fusion name yields an error status rather than a crash.
  auto* computation =
      verified_module->GetComputationWithName(triton_fusion_name);
  TF_RET_CHECK(computation != nullptr)
      << "No computation named " << triton_fusion_name;
  TF_RET_CHECK(computation->FusionInstruction() != nullptr)
      << "Computation " << triton_fusion_name
      << " is not a fusion computation";
  auto* fusion = Cast<HloFusionInstruction>(computation->FusionInstruction());

  TF_ASSIGN_OR_RETURN(auto analysis,
                      TritonFusionAnalysis::Execute(*computation));

  auto fusion_analysis = HloFusionAnalysis::Create(fusion, &device_desc());

  // For softmax fusions, the tile sizes from the fusion's launch config (if it
  // has one) replace the caller-provided output_tile_sizes.
  if (fusion_analysis.fusion_backend_config().kind() ==
      kTritonSoftmaxFusionKind) {
    TritonFusion triton_fusion(fusion_analysis);
    if (auto launch_config = triton_fusion.launch_config()) {
      output_tile_sizes = launch_config->output_tile_sizes;
    }
  }

  // `context` is declared before `module` so it outlives the module.
  // NOTE(review): IR is emitted for a fixed RTX A6000 device description,
  // while the fusion analysis above uses the test backend's device.
  mlir::MLIRContext context;
  TF_ASSIGN_OR_RETURN(
      auto module,
      CreateTritonModule(analysis, "triton_fn", fusion,
                         TestGpuDeviceInfo::RTXA6000DeviceInfo(), config,
                         output_tile_sizes, emitter, context));

  std::string out;
  llvm::raw_string_ostream os(out);
  module->print(os);
  TF_ASSIGN_OR_RETURN(bool succeeded, RunFileCheck(out, filecheck_pattern));
  if (!succeeded) {
    return absl::InternalError("FileCheck failed.");
  }
  return absl::OkStatus();
}

TEST_F(TritonFilecheckTest, TestGemm) {
const std::string kHloText = R"(
HloModule t, is_scheduled=true
Expand Down
106 changes: 106 additions & 0 deletions third_party/xla/xla/service/gpu/triton_test_utils.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
/* Copyright 2024 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "xla/service/gpu/triton_test_utils.h"

#include <cstdint>
#include <memory>
#include <string>
#include <variant>
#include <vector>

#include "absl/status/status.h"
#include "absl/strings/string_view.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "xla/hlo/ir/hlo_casting_utils.h"
#include "xla/hlo/ir/hlo_instructions.h"
#include "xla/service/gpu/fusions/triton.h"
#include "xla/service/gpu/gpu_device_info_for_tests.h"
#include "xla/service/gpu/hlo_fusion_analysis.h"
#include "xla/service/gpu/ir_emission_utils.h"
#include "xla/service/gpu/ir_emitter_triton.h"
#include "xla/service/gpu/matmul_utils.h"
#include "xla/service/gpu/triton_fusion_analysis.h"
#include "xla/status_macros.h"
#include "xla/stream_executor/device_description.h"
#include "xla/tests/filecheck.h"
#include "xla/tests/verified_hlo_module.h"
#include "tsl/platform/statusor.h"

namespace xla::gpu {
// Reports whether bf16 tests should be skipped on the current device.
// Non-ROCm devices never skip; ROCm devices skip when their compute
// capability reports no bf16 dtype support.
bool TritonTest::SkipBF16Tests() {
  const bool is_rocm =
      std::holds_alternative<stream_executor::RocmComputeCapability>(
          GpuComputeComp());
  if (!is_rocm) {
    return false;
  }
  auto rocm_capability = device_desc().rocm_compute_capability();
  return !rocm_capability.has_bf16_dtype_support();
}

// Returns the device's own compute capability when running on ROCm.
// Otherwise returns a fixed CUDA Ampere capability (minor version 0),
// whatever CUDA device is actually present.
stream_executor::GpuComputeCapability TritonTest::CudaAmpereOrRocm() {
  if (!std::holds_alternative<stream_executor::RocmComputeCapability>(
          GpuComputeComp())) {
    const stream_executor::CudaComputeCapability ampere{
        stream_executor::CudaComputeCapability::AMPERE, 0};
    return stream_executor::GpuComputeCapability{ampere};
  }
  return stream_executor::GpuComputeCapability{
      device_desc().rocm_compute_capability()};
}

// Parses `hlo_text`, emits Triton IR for the fused computation named
// `triton_fusion_name`, and runs FileCheck with `filecheck_pattern` over the
// printed module.
//
// Returns a non-OK status if parsing/verification fails, the named computation
// does not exist or is not a fusion computation, Triton fusion analysis or IR
// emission fails, or FileCheck does not match.
absl::Status TritonFilecheckTest::CreateTritonIrAndFileCheck(
    absl::string_view hlo_text, const TritonGemmConfig& config,
    std::vector<int64_t> output_tile_sizes, TritonIrEmitter emitter,
    absl::string_view triton_fusion_name, absl::string_view filecheck_pattern) {
  TF_ASSIGN_OR_RETURN(std::unique_ptr<VerifiedHloModule> verified_module,
                      ParseAndReturnVerifiedModule(hlo_text));

  // The lookup result may be null. Validate it before dereferencing, so a
  // mistyped fusion name yields an error status rather than a crash.
  auto* computation =
      verified_module->GetComputationWithName(triton_fusion_name);
  TF_RET_CHECK(computation != nullptr)
      << "No computation named " << triton_fusion_name;
  TF_RET_CHECK(computation->FusionInstruction() != nullptr)
      << "Computation " << triton_fusion_name
      << " is not a fusion computation";
  auto* fusion = Cast<HloFusionInstruction>(computation->FusionInstruction());

  TF_ASSIGN_OR_RETURN(auto analysis,
                      TritonFusionAnalysis::Execute(*computation));

  auto fusion_analysis = HloFusionAnalysis::Create(fusion, &device_desc());

  // For softmax fusions, the tile sizes from the fusion's launch config (if it
  // has one) replace the caller-provided output_tile_sizes.
  if (fusion_analysis.fusion_backend_config().kind() ==
      kTritonSoftmaxFusionKind) {
    TritonFusion triton_fusion(fusion_analysis);
    if (auto launch_config = triton_fusion.launch_config()) {
      output_tile_sizes = launch_config->output_tile_sizes;
    }
  }

  // `context` is declared before `module` so it outlives the module.
  // NOTE(review): IR is emitted for a fixed RTX A6000 device description,
  // while the fusion analysis above uses the test backend's device.
  mlir::MLIRContext context;
  TF_ASSIGN_OR_RETURN(
      auto module,
      CreateTritonModule(analysis, "triton_fn", fusion,
                         TestGpuDeviceInfo::RTXA6000DeviceInfo(), config,
                         output_tile_sizes, emitter, context));

  std::string out;
  llvm::raw_string_ostream os(out);
  module->print(os);
  TF_ASSIGN_OR_RETURN(bool succeeded, RunFileCheck(out, filecheck_pattern));
  if (!succeeded) {
    return absl::InternalError("FileCheck failed.");
  }
  return absl::OkStatus();
}

} // namespace xla::gpu
64 changes: 64 additions & 0 deletions third_party/xla/xla/service/gpu/triton_test_utils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
/* Copyright 2024 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef XLA_SERVICE_GPU_TRITON_TEST_UTILS_H_
#define XLA_SERVICE_GPU_TRITON_TEST_UTILS_H_

#include <cstdint>
#include <vector>

#include "absl/status/status.h"
#include "absl/strings/string_view.h"
#include "xla/service/gpu/ir_emitter_triton.h"
#include "xla/service/gpu/matmul_utils.h"
#include "xla/service/gpu/tests/gpu_codegen_test.h"
#include "xla/stream_executor/device_description.h"

namespace xla::gpu {

// Base fixture for Triton codegen tests: exposes the capabilities of the
// device behind the test backend's default stream executor.
class TritonTest : public GpuCodegenTest {
 protected:
  // Description of the device used by the test backend.
  const stream_executor::DeviceDescription& device_desc() {
    return backend().default_stream_executor()->GetDeviceDescription();
  }

 public:
  // CUDA compute capability reported by the device description.
  stream_executor::CudaComputeCapability GetCudaComputeCapability() {
    return device_desc().cuda_compute_capability();
  }

  // Compute capability variant (CUDA or ROCm) of the device under test.
  const stream_executor::GpuComputeCapability& GpuComputeComp() {
    return device_desc().gpu_compute_capability();
  }

  // True on a ROCm device without bf16 dtype support; false otherwise.
  bool SkipBF16Tests();

  // The device's ROCm capability on ROCm; otherwise a fixed CUDA Ampere
  // capability with minor version 0.
  stream_executor::GpuComputeCapability CudaAmpereOrRocm();
};

// Fixture that lowers a Triton fusion from HLO text to Triton IR and verifies
// the printed IR with FileCheck.
class TritonFilecheckTest : public TritonTest {
public:
// Parses `hlo_text`, emits Triton IR for the fused computation named
// `triton_fusion_name` using `config`, `output_tile_sizes` and `emitter`, and
// runs FileCheck with `filecheck_pattern` over the printed module. For Triton
// softmax fusions, tile sizes from the fusion's launch config (when present)
// replace `output_tile_sizes`. Returns a non-OK status if any step fails or
// FileCheck does not match.
absl::Status CreateTritonIrAndFileCheck(
absl::string_view hlo_text, const TritonGemmConfig& config,
std::vector<int64_t> output_tile_sizes, TritonIrEmitter emitter,
absl::string_view triton_fusion_name,
absl::string_view filecheck_pattern);
};

} // namespace xla::gpu

#endif // XLA_SERVICE_GPU_TRITON_TEST_UTILS_H_
1 change: 0 additions & 1 deletion third_party/xla/xla/stream_executor/device_description.h
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,6 @@ class RocmComputeCapability {
: gcn_arch_name_(proto.gcn_arch_name()) {}

RocmComputeCapability() = default;
~RocmComputeCapability() = default;

std::string gcn_arch_name() const { return gcn_arch_name_; }

Expand Down

0 comments on commit c7515ca

Please sign in to comment.