Refactor llvm_compiler_test.

We can run the CpuCompiler- and GPUCompiler-related tests in separate test targets.

PiperOrigin-RevId: 644307904
akuegel authored and tensorflower-gardener committed Jun 18, 2024
1 parent dc76516 commit 36fc1f4
Showing 10 changed files with 475 additions and 219 deletions.
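
The bulk of the ir_emitter_triton.cc change below converts TritonType from LOG(FATAL)-ing on unsupported types to returning absl::StatusOr<Type>, with its callers (EmitConstant, EmitElementwise, GetDotAccumulatorType, CreateTritonModule, ...) propagating the failure via TF_ASSIGN_OR_RETURN. A minimal standalone sketch of that error-propagation pattern follows; the DemoPrimitiveType enum and the ToMlirType/EmitCast helpers are hypothetical stand-ins for the real XLA/MLIR code, not part of the commit.

// Illustration only: hypothetical stand-ins for XLA's TritonType and its callers.
#include <string>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "tsl/platform/statusor.h"  // TF_ASSIGN_OR_RETURN

enum class DemoPrimitiveType { kF32, kF64, kU4 };

// Before the refactoring this kind of helper would LOG(FATAL) in the default
// case; now an unsupported type surfaces as a recoverable absl::Status.
absl::StatusOr<std::string> ToMlirType(DemoPrimitiveType t) {
  switch (t) {
    case DemoPrimitiveType::kF32:
      return std::string("f32");
    case DemoPrimitiveType::kF64:
      return std::string("f64");
    default:
      return absl::UnimplementedError(absl::StrCat(
          "This type is not supported yet: ", static_cast<int>(t)));
  }
}

// Callers that previously used the result directly now unwrap it with
// TF_ASSIGN_OR_RETURN, which early-returns the error to their own caller.
absl::StatusOr<std::string> EmitCast(DemoPrimitiveType dst) {
  TF_ASSIGN_OR_RETURN(std::string dst_ty, ToMlirType(dst));
  return absl::StrCat("convert(..., ", dst_ty, ")");
}

The practical effect is that an unsupported element type can be reported as an error to the caller instead of aborting the compiler process; this same TF_ASSIGN_OR_RETURN rewrite accounts for most of the 74 added lines in the file.
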
4 changes: 2 additions & 2 deletions third_party/xla/xla/service/gpu/BUILD
@@ -1259,16 +1259,16 @@ xla_test(
deps = [
":gpu_device_info_for_tests",
":ir_emitter_triton",
":matmul_utils",
":triton_fusion_analysis",
":triton_support",
":triton_test_utils",
"//third_party/protobuf",
"//xla:xla_data_proto_cc",
"//xla:xla_proto_cc",
"//xla/hlo/ir:hlo",
"//xla/stream_executor:device_description",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_googletest//:gtest_main",
"@local_tsl//tsl/lib/core:status_test_util",
119 changes: 74 additions & 45 deletions third_party/xla/xla/service/gpu/ir_emitter_triton.cc
@@ -167,7 +167,7 @@ using mlir::ValueRange;
namespace {

// XLA -> Triton type conversions.
Type TritonType(mlir::OpBuilder b, PrimitiveType t) {
absl::StatusOr<Type> TritonType(mlir::OpBuilder b, PrimitiveType t) {
switch (t) {
case F64:
return b.getF64Type();
@@ -195,8 +195,9 @@ Type TritonType(mlir::OpBuilder b, PrimitiveType t) {
// Triton.
return b.getFloat8E4M3FNUZType();
default:
LOG(FATAL) << "This type is not supported yet: "
<< primitive_util::LowercasePrimitiveTypeName(t);
return absl::UnimplementedError(
absl::StrCat("This type is not supported yet: ",
primitive_util::LowercasePrimitiveTypeName(t)));
}
}

@@ -485,8 +486,11 @@ absl::StatusOr<Value> EmitElementwise(ImplicitLocOpBuilder& b,
case HloOpcode::kNegate:
// NegFOp is not supported by Triton.
return Subtract(b, {ZerosLike(b, inputs[0]), inputs[0]});
case HloOpcode::kConvert:
return Cast(b, inputs[0], TritonType(b, hlo.shape().element_type()));
case HloOpcode::kConvert: {
TF_ASSIGN_OR_RETURN(Type dst_ty,
TritonType(b, hlo.shape().element_type()));
return Cast(b, inputs[0], dst_ty);
}
case HloOpcode::kAdd:
if (is_integer) {
return b.create<ma::AddIOp>(inputs[0], inputs[1]);
@@ -577,8 +581,9 @@ Value EmitParameterLoad(ImplicitLocOpBuilder& b, Value pointer,
{});
}

Value EmitConstant(ImplicitLocOpBuilder& b, const HloInstruction& constant) {
Type ty = TritonType(b, constant.shape().element_type());
absl::StatusOr<Value> EmitConstant(ImplicitLocOpBuilder& b,
const HloInstruction& constant) {
TF_ASSIGN_OR_RETURN(Type ty, TritonType(b, constant.shape().element_type()));
if (constant.shape().IsInteger()) {
if (constant.shape().element_type() == U64) {
return CreateConst(b, ty, ScalarConstantValue<uint64_t>(constant, U64));
@@ -681,13 +686,14 @@ absl::StatusOr<Value> EmitReduce(ImplicitLocOpBuilder& b,
if (operand->opcode() == HloOpcode::kConvert) {
TF_RET_CHECK(operand->operand(0)->opcode() == HloOpcode::kConstant);
TF_RET_CHECK(operand->operand(0)->shape().element_type() == BF16);
PrimitiveType dest_ty = operand->shape().element_type();
TF_RET_CHECK(dest_ty == F32);
neutral = EmitConstant(b, *operand->operand(0));
neutral = Cast(b, neutral, TritonType(b, dest_ty));
TF_RET_CHECK(operand->shape().element_type() == F32);
TF_ASSIGN_OR_RETURN(Type dst_ty,
TritonType(b, operand->shape().element_type()));
TF_ASSIGN_OR_RETURN(neutral, EmitConstant(b, *operand->operand(0)));
neutral = Cast(b, neutral, dst_ty);
} else {
TF_RET_CHECK(operand->opcode() == HloOpcode::kConstant);
neutral = EmitConstant(b, *operand);
TF_ASSIGN_OR_RETURN(neutral, EmitConstant(b, *operand));
}

// Since every shape is padded to a power of 2 in Triton, the input tile may
@@ -756,7 +762,9 @@ absl::StatusOr<Value> EmitReduce(ImplicitLocOpBuilder& b,
result = Splat(b, result, {});
}

return Cast(b, result, TritonType(b, hlo_reduce.shape().element_type()));
TF_ASSIGN_OR_RETURN(Type result_ty,
TritonType(b, hlo_reduce.shape().element_type()));
return Cast(b, result, result_ty);
}

// Emit code corresponding to a fusion instruction somehow nested within the
@@ -873,8 +881,9 @@ absl::StatusOr<Value> EmitTiledHloInstruction(

if (hlo->opcode() == HloOpcode::kConstant &&
ShapeUtil::IsEffectiveScalar(hlo->shape())) {
TF_ASSIGN_OR_RETURN(Value constant, EmitConstant(b, *hlo));
// Splat makes it a tensor to avoid type mismatches.
return Splat(b, EmitConstant(b, *hlo), {});
return Splat(b, constant, {});
}

if (hlo->opcode() == HloOpcode::kBroadcast) {
@@ -896,16 +905,18 @@ absl::StatusOr<Value> EmitTiledHloInstruction(
return EmitElementwise(b, libdevice_path, device_info, *hlo, operands);
}

if (hlo->opcode() == HloOpcode::kTranspose ||
hlo->opcode() == HloOpcode::kSlice || hlo->opcode() == HloOpcode::kPad) {
// All these are currently supported only as operations on indices
// which are pushed to loads and stores. No operations on tiles are
// performed here.
// All these operations are currently supported only as operations on indices
// which are pushed to loads and stores. We don't generate any further code
// for these operations here.
std::vector<HloOpcode> passthrough_opcodes(
{HloOpcode::kBitcast, HloOpcode::kPad, HloOpcode::kReshape,
HloOpcode::kSlice, HloOpcode::kTranspose});
if (absl::c_linear_search(passthrough_opcodes, hlo->opcode())) {
return values[tiled_hlo.operand(0)];
}

return absl::UnimplementedError(
absl::StrCat("Unsupported opcode: ", hlo->opcode()));
absl::StrCat("Unsupported operation ", hlo->ToString()));
}

// Emit sequence of instructions using compatible tiling ordered producers
@@ -954,8 +965,9 @@ absl::StatusOr<Value> EmitScope(
TF_RET_CHECK(values.contains(hlo)) << hlo->ToString();
continue;
} else if (hlo->opcode() == HloOpcode::kConstant) {
TF_ASSIGN_OR_RETURN(Value constant, EmitConstant(b, *hlo));
// Splat makes it a tensor to avoid type mismatches.
result = Splat(b, EmitConstant(b, *hlo), {});
result = Splat(b, constant, {});
} else if (hlo->opcode() == HloOpcode::kBroadcast) {
TF_ASSIGN_OR_RETURN(
result, EmitBroadcast(b, analysis, scope, tiled_dimensions, *hlo,
@@ -1362,12 +1374,13 @@ class MatMulEmitterHelper {

// TODO(b/266862493): Accumulator can be integer too.
// Otherwise only f64 x f64 -> f64 uses f64 accumulator.
mlir::FloatType GetDotAccumulatorType() {
absl::StatusOr<mlir::FloatType> GetDotAccumulatorType() {
const PrecisionConfig::Algorithm algorithm =
dot_instr_->precision_config().algorithm();

if (algorithm == PrecisionConfig::ALG_UNSET) {
Type dot_output_ty = TritonType(b_, dot_instr_->shape().element_type());
TF_ASSIGN_OR_RETURN(Type dot_output_ty,
TritonType(b_, dot_instr_->shape().element_type()));
// The code below assumes that lhs and rhs have the same type. However
// it's not always the case with fp8 matmuls, e.g. e4m3×e5m2 is supported
// at the hardware level. NVidia GPU currently only supports f32
@@ -1377,14 +1390,14 @@ class MatMulEmitterHelper {
}

// Data type of dot() immediate inputs.
Type dot_input_ty = [&] {
const Type lhs_ty =
TritonType(b_, dot_instr_->operand(0)->shape().element_type());
const Type rhs_ty =
TritonType(b_, dot_instr_->operand(1)->shape().element_type());
CHECK(lhs_ty == rhs_ty);
return lhs_ty;
}();
TF_ASSIGN_OR_RETURN(
const Type lhs_ty,
TritonType(b_, dot_instr_->operand(0)->shape().element_type()));
TF_ASSIGN_OR_RETURN(
const Type rhs_ty,
TritonType(b_, dot_instr_->operand(1)->shape().element_type()));
TF_RET_CHECK(lhs_ty == rhs_ty);
Type dot_input_ty = lhs_ty;
// TODO(b/266862493): Accumulator can be integer too.
// Otherwise only f64 x f64 -> f64 uses f64 accumulator.
return (dot_output_ty.isF64() && dot_input_ty.isF64()) ? b_.getF64Type()
@@ -1395,7 +1408,8 @@ class MatMulEmitterHelper {
algorithm_util::GetDotAccumulatorType(algorithm);
CHECK(accum_type.ok()) << "Unexpected algorithm: "
<< PrecisionConfig::Algorithm_Name(algorithm);
Type mlir_accum_type = TritonType(b_, accum_type.value());
TF_ASSIGN_OR_RETURN(Type mlir_accum_type,
TritonType(b_, accum_type.value()));
if (auto float_accum_type =
mlir::dyn_cast<mlir::FloatType>(mlir_accum_type)) {
return float_accum_type;
@@ -2131,10 +2145,9 @@ absl::Status EmitMatMul(mlir::OpBuilder builder,
if (node->opcode() != HloOpcode::kConvert) {
return false;
}
Type in_type =
TritonType(builder, node->operand(0)->shape().element_type());
Type out_type = TritonType(builder, node->shape().element_type());
return in_type.getIntOrFloatBitWidth() <= 8 && out_type.isF32();
int in_width =
primitive_util::BitWidth(node->operand(0)->shape().element_type());
return in_width <= 8 && node->shape().element_type() == F32;
});

// We'll be creating a lot of instructions from a single dot, use an
@@ -2181,7 +2194,7 @@ absl::Status EmitMatMul(mlir::OpBuilder builder,
auto pid_n = b.create<ma::DivSIOp>(b.create<ma::RemSIOp>(pid_nc, c32(width)),
group_size);

mlir::FloatType acc_ty = emitter.GetDotAccumulatorType();
TF_ASSIGN_OR_RETURN(mlir::FloatType acc_ty, emitter.GetDotAccumulatorType());

ma::ConstantOp accumulator_init =
CreateConst(b, acc_ty, 0, {block_m, block_n});
@@ -2229,8 +2242,17 @@ absl::Status EmitMatMul(mlir::OpBuilder builder,
size_t lsize = ScopeInputs(analysis, TritonFusionAnalysis::Scope::LHS).size();
size_t rsize = ScopeInputs(analysis, TritonFusionAnalysis::Scope::RHS).size();

absl::flat_hash_map<const HloInstruction*, Type> triton_type_for_input;
for (const Side& side : {lhs, rhs}) {
for (const HloInstruction* input : ScopeInputs(analysis, side.scope)) {
TF_ASSIGN_OR_RETURN(Type input_ty,
TritonType(b, input->shape().element_type()));
triton_type_for_input.insert({input, input_ty});
}
}

auto body_builder = [&](mlir::OpBuilder&, mlir::Location, Value ki,
ValueRange iter_args) {
ValueRange iter_args) -> void {
SmallVector<Value> iter_args_next;
iter_args_next.reserve(iter_args.size());
std::array<absl::flat_hash_map<const HloInstruction*, Value>, 3> values;
@@ -2243,7 +2265,7 @@ absl::Status EmitMatMul(mlir::OpBuilder builder,
const HloInstruction* param_hlo = iter_args_to_inputs[i];
Type param_ty = index == kLhsMetaOperandIdx
? b.getI16Type()
: TritonType(b, param_hlo->shape().element_type());
: triton_type_for_input.at(param_hlo);
Type param_storage_ty = StorageType(b, param_ty);
Value param_value =
EmitParameterLoad(b, iter_args[i], iter_args_to_boundary_checks[i]);
@@ -2364,6 +2386,7 @@ absl::Status EmitMatMul(mlir::OpBuilder builder,
iter_args_next.push_back(accumulator_next);

b.create<mlir::scf::YieldOp>(iter_args_next);
return;
};

// Pointers to inputs of LHS scope, then RHS, then the accumulator
@@ -2393,8 +2416,9 @@ absl::Status EmitMatMul(mlir::OpBuilder builder,
/*iterArgs=*/iter_args, body_builder)
.getResult(iter_args.size() - 1);
absl::flat_hash_map<const HloInstruction*, Value> values_out;
values_out[dot_instr] =
Cast(b, acc_final, TritonType(b, dot_instr->shape().element_type()));
TF_ASSIGN_OR_RETURN(Type acc_final_ty,
TritonType(b, dot_instr->shape().element_type()));
values_out[dot_instr] = Cast(b, acc_final, acc_final_ty);

// Emit the output scope.
if (std::vector<const HloInstruction*> to_emit =
@@ -2774,16 +2798,21 @@ absl::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> CreateTritonModule(
SmallVector<Type> fn_arg_types;
for (HloInstruction* p : hlo_computation->parameter_instructions()) {
PrimitiveType type = p->shape().element_type();
Type ir_type = type != U16 ? TritonType(b, type) : b.getI16Type();
Type ir_type;
if (type == U16) {
ir_type = b.getI16Type();
} else {
TF_ASSIGN_OR_RETURN(ir_type, TritonType(b, type));
}
fn_arg_types.push_back(
mt::PointerType::get(StorageType(b, ir_type), mn::kGlobalMemorySpace));
}

for (const ShapeUtil::IndexedShape& s :
ShapeUtil::GetLeafShapes(fusion->shape())) {
fn_arg_types.push_back(mt::PointerType::get(
StorageType(b, TritonType(b, s.shape.element_type())),
mn::kGlobalMemorySpace));
TF_ASSIGN_OR_RETURN(Type triton_ty, TritonType(b, s.shape.element_type()));
fn_arg_types.push_back(mt::PointerType::get(StorageType(b, triton_ty),
mn::kGlobalMemorySpace));
}

auto fn = b.create<mt::FuncOp>(loc, fn_name,
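
One design choice worth noting in the EmitMatMul hunk above: the loop-body builder passed to the scf::ForOp returns void and cannot propagate an absl::Status, so the fallible TritonType lookups are hoisted into the new triton_type_for_input map, filled with TF_ASSIGN_OR_RETURN before the loop is built; inside the lambda only the infallible triton_type_for_input.at(param_hlo) lookup remains. A rough standalone sketch of that hoisting pattern, reusing the hypothetical DemoPrimitiveType/ToMlirType from the earlier example (not the actual XLA code):

// Illustration only: hoist fallible conversions out of a callback that must
// not fail (analogous to the scf::ForOp body builder in EmitMatMul).
#include <string>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/status/status.h"
#include "tsl/platform/statusor.h"  // TF_ASSIGN_OR_RETURN

absl::Status EmitLoopBody(const std::vector<DemoPrimitiveType>& inputs) {
  // Fallible step: resolve all types up front, while we can still return an
  // error status to our caller.
  absl::flat_hash_map<int, std::string> type_for_input;
  for (int i = 0; i < static_cast<int>(inputs.size()); ++i) {
    TF_ASSIGN_OR_RETURN(std::string ty, ToMlirType(inputs[i]));
    type_for_input.insert({i, ty});
  }

  // Infallible step: the callback handed to the loop builder only performs
  // lookups that cannot fail, mirroring triton_type_for_input.at(param_hlo).
  auto body_builder = [&](int i) -> void {
    const std::string& ty = type_for_input.at(i);
    (void)ty;  // ... emit the loop-body IR for input i using ty ...
  };

  for (int i = 0; i < static_cast<int>(inputs.size()); ++i) {
    body_builder(i);
  }
  return absl::OkStatus();
}
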
12 changes: 12 additions & 0 deletions third_party/xla/xla/service/gpu/model/BUILD
@@ -261,6 +261,7 @@ xla_cc_test(
name = "gpu_performance_model_test",
srcs = ["gpu_performance_model_test.cc"],
deps = [
":fusion_analysis_cache",
":gpu_hlo_cost_analysis",
":gpu_indexing_performance_model",
":gpu_performance_model",
@@ -341,32 +342,43 @@ cc_library(
hdrs = ["gpu_indexing_performance_model.h"],
deps = [
":coalescing_analysis",
":fusion_analysis_cache",
":gpu_hlo_cost_analysis",
":gpu_performance_model_base",
":hlo_op_profiles",
":indexing_analysis",
":symbolic_tile_analysis",
":tiled_hlo_computation",
"//xla:shape_util",
"//xla:util",
"//xla/hlo/ir:hlo",
"//xla/service:hlo_cost_analysis",
"//xla/service:instruction_fusion",
"//xla/service/gpu:backend_configs_cc",
"//xla/service/gpu:hlo_fusion_analysis",
"//xla/service/gpu:hlo_traversal",
"//xla/service/gpu:launch_dimensions",
"//xla/service/gpu/fusions:triton",
"//xla/stream_executor:device_description",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/time",
"@com_google_absl//absl/types:span",
"@llvm-project//mlir:IR",
"@local_tsl//tsl/platform:status",
"@local_tsl//tsl/platform:statusor",
],
)

xla_cc_test(
name = "gpu_indexing_performance_model_test",
srcs = ["gpu_indexing_performance_model_test.cc"],
deps = [
":fusion_analysis_cache",
":gpu_hlo_cost_analysis",
":gpu_indexing_performance_model",
":gpu_performance_model_base",