[go: nahoru, domu]

Skip to content

Commit

Permalink
[XLA:GPU] Refactor inputPrecision condition for Triton to use HloAnyOf without requiring adaptors
Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 633484897
  • Loading branch information
Moerafaat authored and tensorflower-gardener committed May 14, 2024
1 parent 1624431 commit 934c0de
Showing 1 changed file with 7 additions and 10 deletions.
17 changes: 7 additions & 10 deletions third_party/xla/xla/service/gpu/ir_emitter_triton.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1942,18 +1942,15 @@ absl::Status EmitMatMul(mlir::OpBuilder builder,
const HloInstruction* root = dot_instr->parent()->root_instruction();
TF_RET_CHECK(!root->shape().IsTuple());

auto fusion_adaptor = HloFusionAdaptor::ForComputation(computation);
HloInstructionAdaptor instr_adaptor{*instr, fusion_adaptor.get()};
// TODO(b/320659359) Allow TF32 for 8-bit or less types with F32.
bool is_8_bit_or_less_dot_with_F32 = HloAnyOf(
instr_adaptor.GetOperands(), *fusion_adaptor,
[&](HloInstructionAdaptor node) {
if (node.opcode() != HloOpcode::kConvert) {
bool is_unsupported_bitwidth =
HloAnyOf({dot_instr}, [&](const HloInstruction* node) {
if (node->opcode() != HloOpcode::kConvert) {
return false;
}
Type in_type =
TritonType(builder, node.GetOperand(0).shape().element_type());
Type out_type = TritonType(builder, node.shape().element_type());
auto in_type =
TritonType(builder, node->operand(0)->shape().element_type());
Type out_type = TritonType(builder, node->shape().element_type());
return in_type.getIntOrFloatBitWidth() <= 8 && out_type.isF32();
});

Expand Down Expand Up @@ -2168,7 +2165,7 @@ absl::Status EmitMatMul(mlir::OpBuilder builder,
// lower precision than the output type. The change was introduced here:
// https://github.com/openai/triton/commit/31b0c521427109a8eda609b58d756c380b21599a
auto input_precision =
IsTf32Allowed(dot_instr) && !is_8_bit_or_less_dot_with_F32
IsTf32Allowed(dot_instr) && !is_unsupported_bitwidth
? mt::InputPrecision::TF32
: mt::InputPrecision::IEEE;
accumulator_next =
Expand Down

0 comments on commit 934c0de

Please sign in to comment.