diff --git a/third_party/xla/xla/service/gpu/ir_emitter_triton.cc b/third_party/xla/xla/service/gpu/ir_emitter_triton.cc
index c190665f71e4e3..20f290e27ff18a 100644
--- a/third_party/xla/xla/service/gpu/ir_emitter_triton.cc
+++ b/third_party/xla/xla/service/gpu/ir_emitter_triton.cc
@@ -1942,18 +1942,15 @@ absl::Status EmitMatMul(mlir::OpBuilder builder,
   const HloInstruction* root = dot_instr->parent()->root_instruction();
   TF_RET_CHECK(!root->shape().IsTuple());
 
-  auto fusion_adaptor = HloFusionAdaptor::ForComputation(computation);
-  HloInstructionAdaptor instr_adaptor{*instr, fusion_adaptor.get()};
   // TODO(b/320659359) Allow TF32 for 8-bit or less types with F32.
-  bool is_8_bit_or_less_dot_with_F32 = HloAnyOf(
-      instr_adaptor.GetOperands(), *fusion_adaptor,
-      [&](HloInstructionAdaptor node) {
-        if (node.opcode() != HloOpcode::kConvert) {
+  bool is_unsupported_bitwidth =
+      HloAnyOf({dot_instr}, [&](const HloInstruction* node) {
+        if (node->opcode() != HloOpcode::kConvert) {
           return false;
         }
-        Type in_type =
-            TritonType(builder, node.GetOperand(0).shape().element_type());
-        Type out_type = TritonType(builder, node.shape().element_type());
+        auto in_type =
+            TritonType(builder, node->operand(0)->shape().element_type());
+        Type out_type = TritonType(builder, node->shape().element_type());
         return in_type.getIntOrFloatBitWidth() <= 8 && out_type.isF32();
       });
 
@@ -2168,7 +2165,7 @@ absl::Status EmitMatMul(mlir::OpBuilder builder,
   // lower precision than the output type. The change was introduced here:
   // https://github.com/openai/triton/commit/31b0c521427109a8eda609b58d756c380b21599a
   auto input_precision =
-      IsTf32Allowed(dot_instr) && !is_8_bit_or_less_dot_with_F32
+      IsTf32Allowed(dot_instr) && !is_unsupported_bitwidth
           ? mt::InputPrecision::TF32
           : mt::InputPrecision::IEEE;
   accumulator_next =
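
Gist of the change: the TF32-eligibility check no longer walks the fusion through HloFusionAdaptor/HloInstructionAdaptor; HloAnyOf now starts directly from dot_instr and inspects plain HloInstruction pointers, and the flag is renamed from is_8_bit_or_less_dot_with_F32 to is_unsupported_bitwidth. A minimal standalone sketch of the resulting gating logic follows; ConvertNode and the function names are hypothetical stand-ins for illustration, not the real XLA/MLIR types.

// Standalone sketch, not XLA code: the real check runs HloAnyOf over the
// HLO graph reachable from the dot instruction; here that traversal is
// approximated by a flat list of the kConvert nodes it would visit.
#include <vector>

enum class InputPrecision { kTF32, kIEEE };

// Hypothetical stand-in for an HloInstruction with opcode kConvert.
struct ConvertNode {
  int input_bit_width;  // bit width of operand(0)'s element type
  bool output_is_f32;   // whether the convert produces f32
};

// Mirrors the lambda in the first hunk: any convert from an 8-bit-or-smaller
// type to f32 feeding the dot disables TF32.
bool IsUnsupportedBitwidth(const std::vector<ConvertNode>& converts) {
  for (const ConvertNode& c : converts) {
    if (c.input_bit_width <= 8 && c.output_is_f32) return true;
  }
  return false;
}

// Mirrors the second hunk: TF32 is used only when it is both allowed for
// the dot and not vetoed by the bit-width check above.
InputPrecision ChooseInputPrecision(bool tf32_allowed,
                                    const std::vector<ConvertNode>& converts) {
  return tf32_allowed && !IsUnsupportedBitwidth(converts)
             ? InputPrecision::kTF32
             : InputPrecision::kIEEE;
}

Read this way, the rename makes the ternary self-describing: TF32 is requested, then vetoed in the narrow-input case that TODO(b/320659359) still tracks for future support.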