Automated Code Change #69953

Draft
wants to merge 1 commit into base: master
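
The diff below converts the Triton IR emitter's XLA-to-Triton type mapping (`TritonType`) from crashing via `LOG(FATAL)` on unsupported types to returning `absl::StatusOr<Type>`, and threads the resulting errors through callers with `TF_ASSIGN_OR_RETURN`; the BUILD hunks adjust the absl status/statusor dependencies to match. Below is a minimal sketch of that pattern with hypothetical names, not code from this PR (it assumes `TF_ASSIGN_OR_RETURN` comes from `tsl/platform/statusor.h`):

```cpp
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "tsl/platform/statusor.h"  // TF_ASSIGN_OR_RETURN (assumed location)

// Hypothetical stand-ins for the real XLA/MLIR types touched by this diff.
enum class ExampleType { kF32, kF64, kU16 };
struct BackendType { int bit_width; };

// Pattern applied in the diff: instead of LOG(FATAL) on an unsupported input,
// return the error as a value that callers can propagate.
absl::StatusOr<BackendType> ToBackendType(ExampleType t) {
  switch (t) {
    case ExampleType::kF32:
      return BackendType{32};
    case ExampleType::kF64:
      return BackendType{64};
    default:
      return absl::UnimplementedError(absl::StrCat(
          "This type is not supported yet: ", static_cast<int>(t)));
  }
}

// Callers that used to take the type by value now unwrap the StatusOr and
// bubble failures up as absl::Status.
absl::Status EmitCast(ExampleType t) {
  TF_ASSIGN_OR_RETURN(BackendType ty, ToBackendType(t));
  return ty.bit_width >= 32 ? absl::OkStatus()
                            : absl::InvalidArgumentError("narrow type");
}
```
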
14 changes: 8 additions & 6 deletions third_party/xla/xla/service/gpu/BUILD
@@ -237,7 +237,6 @@ cc_library(
"//xla/hlo/ir:hlo",
"//xla/service/llvm_ir:llvm_type_conversion_util",
"//xla/service/llvm_ir:llvm_util",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
@@ -1259,16 +1258,16 @@ xla_test(
deps = [
":gpu_device_info_for_tests",
":ir_emitter_triton",
":matmul_utils",
":triton_fusion_analysis",
":triton_support",
":triton_test_utils",
"//third_party/protobuf",
"//xla:xla_data_proto_cc",
"//xla:xla_proto_cc",
"//xla/hlo/ir:hlo",
"//xla/stream_executor:device_description",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_googletest//:gtest_main",
"@local_tsl//tsl/lib/core:status_test_util",
@@ -1495,7 +1494,6 @@ cc_library(
"//xla:autotuning_proto_cc",
"//xla:literal_util",
"//xla:shape_util",
"//xla:statusor",
"//xla:util",
"//xla:xla_data_proto_cc",
"//xla/hlo/ir:hlo",
@@ -1565,6 +1563,7 @@ cc_library(
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@local_tsl//tsl/platform:errors",
"@local_tsl//tsl/platform:logging",
@@ -2577,6 +2576,7 @@ xla_cc_test(
deps = [
":softmax_rewriter_triton",
"//xla:shape_util",
"//xla:xla_data_proto_cc",
"//xla/hlo/ir:hlo",
"//xla/service:instruction_fusion",
"//xla/service:pattern_matcher",
@@ -4424,7 +4424,6 @@ cc_library(
deps = [
":cublas_cudnn",
":launch_dimensions",
":stream_executor_util_kernel",
"//xla:autotuning_proto_cc",
"//xla:shape_util",
"//xla:util",
@@ -4454,6 +4453,7 @@ cc_library(
"@local_tsl//tsl/platform:ml_dtypes",
"@local_tsl//tsl/platform:status",
"@local_tsl//tsl/platform:statusor",
"@local_tsl//tsl/protobuf:dnn_proto_cc",
],
)

@@ -5524,6 +5524,7 @@ xla_cc_test(
deps = [
":scatter_slice_simplifier",
"//xla:shape_util",
"//xla:xla_data_proto_cc",
"//xla/service:pattern_matcher",
"//xla/service:pattern_matcher_gmock",
"//xla/tests:hlo_test_base",
@@ -6013,7 +6014,6 @@ cc_library(
deps = [
":backend_configs_cc",
":gpu_fusible",
"//xla:statusor",
"//xla:util",
"//xla:xla_data_proto_cc",
"//xla/hlo/ir:hlo",
@@ -6023,6 +6023,7 @@ cc_library(
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:string_view",
"@local_tsl//tsl/platform:errors",
"@local_tsl//tsl/platform:logging",
@@ -6059,6 +6060,7 @@ cc_library(
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:string_view",
"@local_tsl//tsl/platform:errors",
"@local_tsl//tsl/platform:logging",
"@local_tsl//tsl/platform:statusor",
],
119 changes: 74 additions & 45 deletions third_party/xla/xla/service/gpu/ir_emitter_triton.cc
@@ -167,7 +167,7 @@ using mlir::ValueRange;
namespace {

// XLA -> Triton type conversions.
Type TritonType(mlir::OpBuilder b, PrimitiveType t) {
absl::StatusOr<Type> TritonType(mlir::OpBuilder b, PrimitiveType t) {
switch (t) {
case F64:
return b.getF64Type();
@@ -195,8 +195,9 @@ Type TritonType(mlir::OpBuilder b, PrimitiveType t) {
// Triton.
return b.getFloat8E4M3FNUZType();
default:
LOG(FATAL) << "This type is not supported yet: "
<< primitive_util::LowercasePrimitiveTypeName(t);
return absl::UnimplementedError(
absl::StrCat("This type is not supported yet: ",
primitive_util::LowercasePrimitiveTypeName(t)));
}
}

@@ -485,8 +486,11 @@ absl::StatusOr<Value> EmitElementwise(ImplicitLocOpBuilder& b,
case HloOpcode::kNegate:
// NegFOp is not supported by Triton.
return Subtract(b, {ZerosLike(b, inputs[0]), inputs[0]});
case HloOpcode::kConvert:
return Cast(b, inputs[0], TritonType(b, hlo.shape().element_type()));
case HloOpcode::kConvert: {
TF_ASSIGN_OR_RETURN(Type dst_ty,
TritonType(b, hlo.shape().element_type()));
return Cast(b, inputs[0], dst_ty);
}
case HloOpcode::kAdd:
if (is_integer) {
return b.create<ma::AddIOp>(inputs[0], inputs[1]);
@@ -577,8 +581,9 @@ Value EmitParameterLoad(ImplicitLocOpBuilder& b, Value pointer,
{});
}

Value EmitConstant(ImplicitLocOpBuilder& b, const HloInstruction& constant) {
Type ty = TritonType(b, constant.shape().element_type());
absl::StatusOr<Value> EmitConstant(ImplicitLocOpBuilder& b,
const HloInstruction& constant) {
TF_ASSIGN_OR_RETURN(Type ty, TritonType(b, constant.shape().element_type()));
if (constant.shape().IsInteger()) {
if (constant.shape().element_type() == U64) {
return CreateConst(b, ty, ScalarConstantValue<uint64_t>(constant, U64));
@@ -681,13 +686,14 @@ absl::StatusOr<Value> EmitReduce(ImplicitLocOpBuilder& b,
if (operand->opcode() == HloOpcode::kConvert) {
TF_RET_CHECK(operand->operand(0)->opcode() == HloOpcode::kConstant);
TF_RET_CHECK(operand->operand(0)->shape().element_type() == BF16);
PrimitiveType dest_ty = operand->shape().element_type();
TF_RET_CHECK(dest_ty == F32);
neutral = EmitConstant(b, *operand->operand(0));
neutral = Cast(b, neutral, TritonType(b, dest_ty));
TF_RET_CHECK(operand->shape().element_type() == F32);
TF_ASSIGN_OR_RETURN(Type dst_ty,
TritonType(b, operand->shape().element_type()));
TF_ASSIGN_OR_RETURN(neutral, EmitConstant(b, *operand->operand(0)));
neutral = Cast(b, neutral, dst_ty);
} else {
TF_RET_CHECK(operand->opcode() == HloOpcode::kConstant);
neutral = EmitConstant(b, *operand);
TF_ASSIGN_OR_RETURN(neutral, EmitConstant(b, *operand));
}

// Since every shape is padded to a power of 2 in Triton, the input tile may
@@ -756,7 +762,9 @@ absl::StatusOr<Value> EmitReduce(ImplicitLocOpBuilder& b,
result = Splat(b, result, {});
}

return Cast(b, result, TritonType(b, hlo_reduce.shape().element_type()));
TF_ASSIGN_OR_RETURN(Type result_ty,
TritonType(b, hlo_reduce.shape().element_type()));
return Cast(b, result, result_ty);
}

// Emit code corresponding to a fusion instruction somehow nested within the
@@ -873,8 +881,9 @@ absl::StatusOr<Value> EmitTiledHloInstruction(

if (hlo->opcode() == HloOpcode::kConstant &&
ShapeUtil::IsEffectiveScalar(hlo->shape())) {
TF_ASSIGN_OR_RETURN(Value constant, EmitConstant(b, *hlo));
// Splat makes it a tensor to avoid type mismatches.
return Splat(b, EmitConstant(b, *hlo), {});
return Splat(b, constant, {});
}

if (hlo->opcode() == HloOpcode::kBroadcast) {
@@ -896,16 +905,18 @@ absl::StatusOr<Value> EmitTiledHloInstruction(
return EmitElementwise(b, libdevice_path, device_info, *hlo, operands);
}

if (hlo->opcode() == HloOpcode::kTranspose ||
hlo->opcode() == HloOpcode::kSlice || hlo->opcode() == HloOpcode::kPad) {
// All these are currently supported only as operations on indices
// which are pushed to loads and stores. No operations on tiles are
// performed here.
// All these operations are currently supported only as operations on indices
// which are pushed to loads and stores. We don't generate any further code
// for these operations here.
std::vector<HloOpcode> passthrough_opcodes(
{HloOpcode::kBitcast, HloOpcode::kPad, HloOpcode::kReshape,
HloOpcode::kSlice, HloOpcode::kTranspose});
if (absl::c_linear_search(passthrough_opcodes, hlo->opcode())) {
return values[tiled_hlo.operand(0)];
}

return absl::UnimplementedError(
absl::StrCat("Unsupported opcode: ", hlo->opcode()));
absl::StrCat("Unsupported operation ", hlo->ToString()));
}

// Emit sequence of instructions using compatible tiling ordered producers
@@ -954,8 +965,9 @@ absl::StatusOr<Value> EmitScope(
TF_RET_CHECK(values.contains(hlo)) << hlo->ToString();
continue;
} else if (hlo->opcode() == HloOpcode::kConstant) {
TF_ASSIGN_OR_RETURN(Value constant, EmitConstant(b, *hlo));
// Splat makes it a tensor to avoid type mismatches.
result = Splat(b, EmitConstant(b, *hlo), {});
result = Splat(b, constant, {});
} else if (hlo->opcode() == HloOpcode::kBroadcast) {
TF_ASSIGN_OR_RETURN(
result, EmitBroadcast(b, analysis, scope, tiled_dimensions, *hlo,
@@ -1362,12 +1374,13 @@ class MatMulEmitterHelper {

// TODO(b/266862493): Accumulator can be integer too.
// Otherwise only f64 x f64 -> f64 uses f64 accumulator.
mlir::FloatType GetDotAccumulatorType() {
absl::StatusOr<mlir::FloatType> GetDotAccumulatorType() {
const PrecisionConfig::Algorithm algorithm =
dot_instr_->precision_config().algorithm();

if (algorithm == PrecisionConfig::ALG_UNSET) {
Type dot_output_ty = TritonType(b_, dot_instr_->shape().element_type());
TF_ASSIGN_OR_RETURN(Type dot_output_ty,
TritonType(b_, dot_instr_->shape().element_type()));
// The code below assumes that lhs and rhs have the same type. However
// it's not always the case with fp8 matmuls, e.g. e4m3×e5m2 is supported
// at the hardware level. NVidia GPU currently only supports f32
@@ -1377,14 +1390,14 @@ class MatMulEmitterHelper {
}

// Data type of dot() immediate inputs.
Type dot_input_ty = [&] {
const Type lhs_ty =
TritonType(b_, dot_instr_->operand(0)->shape().element_type());
const Type rhs_ty =
TritonType(b_, dot_instr_->operand(1)->shape().element_type());
CHECK(lhs_ty == rhs_ty);
return lhs_ty;
}();
TF_ASSIGN_OR_RETURN(
const Type lhs_ty,
TritonType(b_, dot_instr_->operand(0)->shape().element_type()));
TF_ASSIGN_OR_RETURN(
const Type rhs_ty,
TritonType(b_, dot_instr_->operand(1)->shape().element_type()));
TF_RET_CHECK(lhs_ty == rhs_ty);
Type dot_input_ty = lhs_ty;
// TODO(b/266862493): Accumulator can be integer too.
// Otherwise only f64 x f64 -> f64 uses f64 accumulator.
return (dot_output_ty.isF64() && dot_input_ty.isF64()) ? b_.getF64Type()
@@ -1395,7 +1408,8 @@ class MatMulEmitterHelper {
algorithm_util::GetDotAccumulatorType(algorithm);
CHECK(accum_type.ok()) << "Unexpected algorithm: "
<< PrecisionConfig::Algorithm_Name(algorithm);
Type mlir_accum_type = TritonType(b_, accum_type.value());
TF_ASSIGN_OR_RETURN(Type mlir_accum_type,
TritonType(b_, accum_type.value()));
if (auto float_accum_type =
mlir::dyn_cast<mlir::FloatType>(mlir_accum_type)) {
return float_accum_type;
@@ -2131,10 +2145,9 @@ absl::Status EmitMatMul(mlir::OpBuilder builder,
if (node->opcode() != HloOpcode::kConvert) {
return false;
}
Type in_type =
TritonType(builder, node->operand(0)->shape().element_type());
Type out_type = TritonType(builder, node->shape().element_type());
return in_type.getIntOrFloatBitWidth() <= 8 && out_type.isF32();
int in_width =
primitive_util::BitWidth(node->operand(0)->shape().element_type());
return in_width <= 8 && node->shape().element_type() == F32;
});

// We'll be creating a lot of instructions from a single dot, use an
@@ -2181,7 +2194,7 @@ absl::Status EmitMatMul(mlir::OpBuilder builder,
auto pid_n = b.create<ma::DivSIOp>(b.create<ma::RemSIOp>(pid_nc, c32(width)),
group_size);

mlir::FloatType acc_ty = emitter.GetDotAccumulatorType();
TF_ASSIGN_OR_RETURN(mlir::FloatType acc_ty, emitter.GetDotAccumulatorType());

ma::ConstantOp accumulator_init =
CreateConst(b, acc_ty, 0, {block_m, block_n});
@@ -2229,8 +2242,17 @@ absl::Status EmitMatMul(mlir::OpBuilder builder,
size_t lsize = ScopeInputs(analysis, TritonFusionAnalysis::Scope::LHS).size();
size_t rsize = ScopeInputs(analysis, TritonFusionAnalysis::Scope::RHS).size();

absl::flat_hash_map<const HloInstruction*, Type> triton_type_for_input;
for (const Side& side : {lhs, rhs}) {
for (const HloInstruction* input : ScopeInputs(analysis, side.scope)) {
TF_ASSIGN_OR_RETURN(Type input_ty,
TritonType(b, input->shape().element_type()));
triton_type_for_input.insert({input, input_ty});
}
}

auto body_builder = [&](mlir::OpBuilder&, mlir::Location, Value ki,
ValueRange iter_args) {
ValueRange iter_args) -> void {
SmallVector<Value> iter_args_next;
iter_args_next.reserve(iter_args.size());
std::array<absl::flat_hash_map<const HloInstruction*, Value>, 3> values;
@@ -2243,7 +2265,7 @@ absl::Status EmitMatMul(mlir::OpBuilder builder,
const HloInstruction* param_hlo = iter_args_to_inputs[i];
Type param_ty = index == kLhsMetaOperandIdx
? b.getI16Type()
: TritonType(b, param_hlo->shape().element_type());
: triton_type_for_input.at(param_hlo);
Type param_storage_ty = StorageType(b, param_ty);
Value param_value =
EmitParameterLoad(b, iter_args[i], iter_args_to_boundary_checks[i]);
@@ -2364,6 +2386,7 @@ absl::Status EmitMatMul(mlir::OpBuilder builder,
iter_args_next.push_back(accumulator_next);

b.create<mlir::scf::YieldOp>(iter_args_next);
return;
};

// Pointers to inputs of LHS scope, then RHS, then the accumulator
@@ -2393,8 +2416,9 @@ absl::Status EmitMatMul(mlir::OpBuilder builder,
/*iterArgs=*/iter_args, body_builder)
.getResult(iter_args.size() - 1);
absl::flat_hash_map<const HloInstruction*, Value> values_out;
values_out[dot_instr] =
Cast(b, acc_final, TritonType(b, dot_instr->shape().element_type()));
TF_ASSIGN_OR_RETURN(Type acc_final_ty,
TritonType(b, dot_instr->shape().element_type()));
values_out[dot_instr] = Cast(b, acc_final, acc_final_ty);

// Emit the output scope.
if (std::vector<const HloInstruction*> to_emit =
@@ -2774,16 +2798,21 @@ absl::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> CreateTritonModule(
SmallVector<Type> fn_arg_types;
for (HloInstruction* p : hlo_computation->parameter_instructions()) {
PrimitiveType type = p->shape().element_type();
Type ir_type = type != U16 ? TritonType(b, type) : b.getI16Type();
Type ir_type;
if (type == U16) {
ir_type = b.getI16Type();
} else {
TF_ASSIGN_OR_RETURN(ir_type, TritonType(b, type));
}
fn_arg_types.push_back(
mt::PointerType::get(StorageType(b, ir_type), mn::kGlobalMemorySpace));
}

for (const ShapeUtil::IndexedShape& s :
ShapeUtil::GetLeafShapes(fusion->shape())) {
fn_arg_types.push_back(mt::PointerType::get(
StorageType(b, TritonType(b, s.shape.element_type())),
mn::kGlobalMemorySpace));
TF_ASSIGN_OR_RETURN(Type triton_ty, TritonType(b, s.shape.element_type()));
fn_arg_types.push_back(mt::PointerType::get(StorageType(b, triton_ty),
mn::kGlobalMemorySpace));
}

auto fn = b.create<mt::FuncOp>(loc, fn_name,
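
One detail from the ir_emitter_triton.cc hunks above: the new `triton_type_for_input` map is filled with `TF_ASSIGN_OR_RETURN` before the `body_builder` lambda is defined, presumably because that callback returns `void` and so cannot propagate an `absl::Status` if `TritonType` fails. A minimal sketch of that precompute-then-capture pattern under the same assumption (hypothetical names, not the PR's code):

```cpp
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "tsl/platform/statusor.h"  // TF_ASSIGN_OR_RETURN (assumed location)

// Hypothetical stand-in for a fallible query like TritonType().
absl::StatusOr<int> BitWidthOf(int element_type) {
  if (element_type <= 0) return absl::UnimplementedError("unsupported type");
  return 8 * element_type;
}

// The body builder handed to the loop op returns void, so it cannot use
// TF_ASSIGN_OR_RETURN. All fallible lookups happen up front; the lambda only
// reads the precomputed map.
absl::Status EmitLoop(const std::vector<int>& element_types) {
  absl::flat_hash_map<int, int> width_for_type;
  for (int t : element_types) {
    TF_ASSIGN_OR_RETURN(int width, BitWidthOf(t));
    width_for_type.insert({t, width});
  }

  int total = 0;
  auto body_builder = [&](int t) -> void {
    // Infallible: the map was fully populated before the loop was built.
    total += width_for_type.at(t);
  };
  for (int t : element_types) body_builder(t);  // stand-in for the real loop op

  if (total == 0 && !element_types.empty()) {
    return absl::InternalError("expected a nonzero total bit width");
  }
  return absl::OkStatus();
}
```
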
1 change: 1 addition & 0 deletions third_party/xla/xla/service/gpu/model/BUILD
Original file line number Diff line number Diff line change
@@ -691,6 +691,7 @@ xla_cc_test(
"//xla/service/gpu:hlo_traversal",
"//xla/tests:hlo_test_base",
"//xla/tests:verified_hlo_module",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
"@com_google_googletest//:gtest_main",