[XLA:GPU][NFC] Consolidate error raising for invalid compute capability with Triton.

PiperOrigin-RevId: 647664942
bchetioui authored and tensorflower-gardener committed Jun 28, 2024
1 parent 4172d72 commit 0609624
Showing 7 changed files with 36 additions and 38 deletions.
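For context, this change replaces the per-pass compute-capability checks in GemmFusion::Run and SoftmaxRewriterTriton::Run with a single helper, EnsureTritonSupportsComputeCapability, declared in triton_support.h. The snippet below is a standalone approximation of the consolidated logic for illustration only: it substitutes a local std::variant and simplified capability structs for the stream_executor types, so it is a sketch of the behavior rather than the actual XLA code.

#include <iostream>
#include <variant>

#include "absl/status/status.h"
#include "absl/strings/str_cat.h"

// Simplified stand-ins for se::CudaComputeCapability / se::RocmComputeCapability.
struct CudaComputeCapability {
  int major = 0;
  int minor = 0;
  bool IsAtLeastAmpere() const { return major >= 8; }
};
struct RocmComputeCapability {};
using GpuComputeCapability =
    std::variant<CudaComputeCapability, RocmComputeCapability>;

// Mirrors the new helper: only pre-Ampere CUDA capabilities are rejected;
// non-CUDA capabilities currently pass through unchecked.
absl::Status EnsureTritonSupportsComputeCapability(
    const GpuComputeCapability& gpu_compute_capability) {
  auto* cuda_cc = std::get_if<CudaComputeCapability>(&gpu_compute_capability);
  if (cuda_cc != nullptr && !cuda_cc->IsAtLeastAmpere()) {
    return absl::FailedPreconditionError(absl::StrCat(
        "Triton support is only enabled for Ampere GPUs (compute ",
        "capability 8.0) and up, but got compute capability ", cuda_cc->major,
        ".", cuda_cc->minor, "."));
  }
  return absl::OkStatus();
}

int main() {
  // Pre-Ampere CUDA: FAILED_PRECONDITION with the consolidated message.
  std::cout << EnsureTritonSupportsComputeCapability(CudaComputeCapability{7, 5})
            << "\n";
  // Ampere and newer CUDA: OK.
  std::cout << EnsureTritonSupportsComputeCapability(CudaComputeCapability{8, 0})
            << "\n";
  // Non-CUDA (e.g. ROCm): OK, because the helper does not check it.
  std::cout << EnsureTritonSupportsComputeCapability(RocmComputeCapability{})
            << "\n";
}

Because non-CUDA capabilities are no longer rejected, the two tests below that previously expected a FailedPrecondition on ROCm devices now expect Run() to succeed.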
2 changes: 2 additions & 0 deletions third_party/xla/xla/service/gpu/BUILD
@@ -1258,6 +1258,8 @@ cc_library(
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
"@local_tsl//tsl/platform:tensor_float_32_utils",
],
)
13 changes: 1 addition & 12 deletions third_party/xla/xla/service/gpu/gemm_fusion.cc
@@ -798,18 +798,7 @@ bool ShouldTritonHandleGEMM(HloDotInstruction& dot,
absl::StatusOr<bool> GemmFusion::Run(
HloModule* module,
const absl::flat_hash_set<absl::string_view>& execution_threads) {
auto cuda_compute_capability =
std::get_if<se::CudaComputeCapability>(&gpu_version_);
if (!cuda_compute_capability) {
return absl::FailedPreconditionError(
"Triton support is only enabled for CUDA GPUs.");
} else if (!cuda_compute_capability->IsAtLeastAmpere()) {
return absl::FailedPreconditionError(
absl::StrCat("Triton support is only enabled for Ampere GPUs (compute ",
"capability 8.0) and up, but got compute capability ",
cuda_compute_capability->major, ".",
cuda_compute_capability->minor, "."));
}
TF_RETURN_IF_ERROR(EnsureTritonSupportsComputeCapability(gpu_version_));

bool changed = false;
for (HloComputation* computation :
8 changes: 2 additions & 6 deletions third_party/xla/xla/service/gpu/gemm_fusion_test.cc
@@ -876,7 +876,7 @@ ENTRY e {
"(compute capability 8.0) and up, but got")));
}

TEST_F(GemmFusionLevel2Test, GemmFusionBailsOutOnNonCudaGpu) {
TEST_F(GemmFusionLevel2Test, GemmFusionSucceedsOnNonCudaGpu) {
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
ParseAndReturnVerifiedModule(R"(
ENTRY e {
@@ -887,11 +887,7 @@ ENTRY e {
ROOT dot = f32[2,2] dot(p0e, p1c),
lhs_contracting_dims={1}, rhs_contracting_dims={0}
})"));
EXPECT_THAT(
GemmFusion(se::RocmComputeCapability{}).Run(module.get()),
tsl::testing::StatusIs(
absl::StatusCode::kFailedPrecondition,
::testing::StrEq("Triton support is only enabled for CUDA GPUs.")));
EXPECT_TRUE(GemmFusion(se::RocmComputeCapability{}).Run(module.get()).ok());
}

TEST_F(GemmFusionLevel2Test, ParameterUsedElementwiseTwiceIsFused) {
14 changes: 2 additions & 12 deletions third_party/xla/xla/service/gpu/softmax_rewriter_triton.cc
@@ -721,18 +721,8 @@ absl::Status SoftmaxRewriterTriton::FuseDiamondChain(
absl::StatusOr<bool> SoftmaxRewriterTriton::Run(
HloModule* module,
const absl::flat_hash_set<absl::string_view>& execution_threads) {
auto cuda_compute_capability = std::get_if<se::CudaComputeCapability>(
&device_info_.gpu_compute_capability());
if (!cuda_compute_capability) {
return absl::FailedPreconditionError(
"Triton support is only enabled for CUDA GPUs.");
} else if (!cuda_compute_capability->IsAtLeastAmpere()) {
return absl::FailedPreconditionError(
absl::StrCat("Triton support is only enabled for Ampere GPUs (compute ",
"capability 8.0) and up, but got compute capability ",
cuda_compute_capability->major, ".",
cuda_compute_capability->minor, "."));
}
TF_RETURN_IF_ERROR(EnsureTritonSupportsComputeCapability(
device_info_.gpu_compute_capability()));

TF_ASSIGN_OR_RETURN(std::vector<DiamondChainDescriptor> diamond_chains,
FindAllFusibleDiamondChains(*module, execution_threads));
13 changes: 5 additions & 8 deletions third_party/xla/xla/service/gpu/softmax_rewriter_triton_test.cc
@@ -1140,7 +1140,7 @@ ENTRY main {
"(compute capability 8.0) and up, but got")));
}

TEST_F(SoftmaxRewriterTritonTest, RewriterBailsOutOnNonCudaGpu) {
TEST_F(SoftmaxRewriterTritonTest, RewriterSucceedsOnNonCudaGpu) {
const std::string hlo_string = R"(
HloModule softmax
max_computation {
@@ -1159,13 +1159,10 @@ ENTRY main {

auto module = ParseAndReturnVerifiedModule(hlo_string).value();

EXPECT_THAT(
SoftmaxRewriterTriton(TestGpuDeviceInfo::AMDMI210DeviceInfo(),
ShapeSizeBytesFunction())
.Run(module.get()),
tsl::testing::StatusIs(
tsl::error::FAILED_PRECONDITION,
::testing::StrEq("Triton support is only enabled for CUDA GPUs.")));
EXPECT_TRUE(SoftmaxRewriterTriton(TestGpuDeviceInfo::AMDMI210DeviceInfo(),
ShapeSizeBytesFunction())
.Run(module.get())
.ok());
}

TEST_P(SoftmaxRewriterTritonTest, DoesNotFuseConvertWithC64DataType) {
17 changes: 17 additions & 0 deletions third_party/xla/xla/service/gpu/triton_support.cc
@@ -24,6 +24,8 @@ limitations under the License.
#include "absl/container/flat_hash_set.h"
#include "absl/log/check.h"
#include "absl/log/log.h"
#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "xla/hlo/ir/hlo_casting_utils.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_instructions.h"
@@ -572,6 +574,21 @@ bool IsTritonSupportedElementwise(HloOpcode opcode, PrimitiveType element_type,

} // namespace

absl::Status EnsureTritonSupportsComputeCapability(
const se::GpuComputeCapability& gpu_compute_capability) {
auto cuda_compute_capability =
std::get_if<se::CudaComputeCapability>(&gpu_compute_capability);
if (cuda_compute_capability && !cuda_compute_capability->IsAtLeastAmpere()) {
return absl::FailedPreconditionError(
absl::StrCat("Triton support is only enabled for Ampere GPUs (compute ",
"capability 8.0) and up, but got compute capability ",
cuda_compute_capability->major, ".",
cuda_compute_capability->minor, "."));
}

return absl::OkStatus();
}

CodegenDecision IsTritonSupportedInstruction(
const HloInstruction& instr, const se::GpuComputeCapability& gpu_version) {
bool output_type_is_supported =
7 changes: 7 additions & 0 deletions third_party/xla/xla/service/gpu/triton_support.h
@@ -20,6 +20,7 @@ limitations under the License.

#include <vector>

#include "absl/status/status.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_instructions.h"
#include "xla/hlo/ir/hlo_opcode.h"
@@ -103,6 +104,12 @@ CodegenDecision IsTritonSupportedDynamicSlice(
const HloDynamicSliceInstruction& instr);
} // namespace legacy_triton

// Checks that Triton officially supports the provided compute capability.
//
// Currently does not perform any check for non-CUDA compute capabilities.
absl::Status EnsureTritonSupportsComputeCapability(
const se::GpuComputeCapability& gpu_compute_capability);

// Return `CodegenDecision`'s equivalent of `true` if the parameter instruction
// is supported by the Triton emitters for the given compute capability. Note
// that this function makes no assumption about what happens if
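As a usage illustration, here is a hypothetical test sketch (not part of this commit; the test names, include paths, and brace-initialization of se::CudaComputeCapability are assumptions) exercising the contract documented above: pre-Ampere CUDA capabilities are rejected, while Ampere-or-newer CUDA and non-CUDA capabilities pass.

#include <gtest/gtest.h>

#include "absl/status/status.h"
#include "xla/service/gpu/triton_support.h"
#include "xla/stream_executor/device_description.h"

namespace xla::gpu {
namespace {

namespace se = ::stream_executor;

TEST(EnsureTritonSupportsComputeCapabilityTest, RejectsPreAmpereCuda) {
  // Turing (compute capability 7.5) predates Ampere, so the helper errors out.
  absl::Status status =
      EnsureTritonSupportsComputeCapability(se::CudaComputeCapability{7, 5});
  EXPECT_EQ(status.code(), absl::StatusCode::kFailedPrecondition);
}

TEST(EnsureTritonSupportsComputeCapabilityTest, AcceptsAmpereAndNonCuda) {
  // Ampere (8.0) and newer CUDA capabilities are accepted.
  EXPECT_TRUE(EnsureTritonSupportsComputeCapability(
                  se::CudaComputeCapability{8, 0})
                  .ok());
  // Non-CUDA capabilities are currently not checked at all.
  EXPECT_TRUE(
      EnsureTritonSupportsComputeCapability(se::RocmComputeCapability{}).ok());
}

}  // namespace
}  // namespace xla::gpu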
