[XLA:GPU] Add constraints to SymbolicTileAnalysis.
Constraints are constructed by merging the constraints of all the
`SymbolicTile`s encountered while constructing the resulting symbolic tiled HLO
computation. If any of the `SymbolicTile`s is unsatisfiable, the construction
of the `SymbolicTileAnalysis` object does not succeed. Likewise, construction
fails if some constraints cannot be merged with others.

Constraints are now checked against the provided tile parameters when
attempting to extract a `TiledHloComputation` out of a
`SymbolicTileAnalysis`. To avoid checking the constraints more often than
necessary, callers may pinky-promise that the provided tile parameters
satisfy the constraints, voluntarily bypassing the checks.
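
As an illustration, a minimal sketch of the intended usage pattern (entry
point and parameter names are simplified here and may not match the actual
API exactly):

    // Construction fails if any tile's constraints are unsatisfiable or
    // cannot be merged with the constraints gathered so far.
    TF_ASSIGN_OR_RETURN(
        SymbolicTileAnalysis analysis,
        SymbolicTileAnalysis::AnalyzeComputation(*computation, &mlir_ctx));

    // Tile parameters are checked against the merged constraints here...
    TF_ASSIGN_OR_RETURN(
        TiledHloComputation tiled_hlo,
        analysis.ComputeTiledHloInstructions(tile_parameters));

    // ...unless the caller pinky-promises that they already hold.
    TF_ASSIGN_OR_RETURN(
        tiled_hlo,
        analysis.ComputeTiledHloInstructions(
            tile_parameters, /*constraints_are_known_satisfied=*/true));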

PiperOrigin-RevId: 644301234
bchetioui authored and tensorflower-gardener committed Jun 18, 2024
1 parent cd75b11 commit 7530918
Showing 9 changed files with 395 additions and 105 deletions.
4 changes: 2 additions & 2 deletions third_party/xla/xla/service/gpu/BUILD
@@ -1259,16 +1259,16 @@ xla_test(
deps = [
":gpu_device_info_for_tests",
":ir_emitter_triton",
":matmul_utils",
":triton_fusion_analysis",
":triton_support",
":triton_test_utils",
"//third_party/protobuf",
"//xla:xla_data_proto_cc",
"//xla:xla_proto_cc",
"//xla/hlo/ir:hlo",
"//xla/stream_executor:device_description",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_googletest//:gtest_main",
"@local_tsl//tsl/lib/core:status_test_util",
119 changes: 74 additions & 45 deletions third_party/xla/xla/service/gpu/ir_emitter_triton.cc
@@ -167,7 +167,7 @@ using mlir::ValueRange;
namespace {

// XLA -> Triton type conversions.
Type TritonType(mlir::OpBuilder b, PrimitiveType t) {
absl::StatusOr<Type> TritonType(mlir::OpBuilder b, PrimitiveType t) {
switch (t) {
case F64:
return b.getF64Type();
@@ -195,8 +195,9 @@ Type TritonType(mlir::OpBuilder b, PrimitiveType t) {
// Triton.
return b.getFloat8E4M3FNUZType();
default:
LOG(FATAL) << "This type is not supported yet: "
<< primitive_util::LowercasePrimitiveTypeName(t);
return absl::UnimplementedError(
absl::StrCat("This type is not supported yet: ",
primitive_util::LowercasePrimitiveTypeName(t)));
}
}

@@ -485,8 +486,11 @@ absl::StatusOr<Value> EmitElementwise(ImplicitLocOpBuilder& b,
case HloOpcode::kNegate:
// NegFOp is not supported by Triton.
return Subtract(b, {ZerosLike(b, inputs[0]), inputs[0]});
case HloOpcode::kConvert:
return Cast(b, inputs[0], TritonType(b, hlo.shape().element_type()));
case HloOpcode::kConvert: {
TF_ASSIGN_OR_RETURN(Type dst_ty,
TritonType(b, hlo.shape().element_type()));
return Cast(b, inputs[0], dst_ty);
}
case HloOpcode::kAdd:
if (is_integer) {
return b.create<ma::AddIOp>(inputs[0], inputs[1]);
@@ -577,8 +581,9 @@ Value EmitParameterLoad(ImplicitLocOpBuilder& b, Value pointer,
{});
}

Value EmitConstant(ImplicitLocOpBuilder& b, const HloInstruction& constant) {
Type ty = TritonType(b, constant.shape().element_type());
absl::StatusOr<Value> EmitConstant(ImplicitLocOpBuilder& b,
const HloInstruction& constant) {
TF_ASSIGN_OR_RETURN(Type ty, TritonType(b, constant.shape().element_type()));
if (constant.shape().IsInteger()) {
if (constant.shape().element_type() == U64) {
return CreateConst(b, ty, ScalarConstantValue<uint64_t>(constant, U64));
@@ -681,13 +686,14 @@ absl::StatusOr<Value> EmitReduce(ImplicitLocOpBuilder& b,
if (operand->opcode() == HloOpcode::kConvert) {
TF_RET_CHECK(operand->operand(0)->opcode() == HloOpcode::kConstant);
TF_RET_CHECK(operand->operand(0)->shape().element_type() == BF16);
PrimitiveType dest_ty = operand->shape().element_type();
TF_RET_CHECK(dest_ty == F32);
neutral = EmitConstant(b, *operand->operand(0));
neutral = Cast(b, neutral, TritonType(b, dest_ty));
TF_RET_CHECK(operand->shape().element_type() == F32);
TF_ASSIGN_OR_RETURN(Type dst_ty,
TritonType(b, operand->shape().element_type()));
TF_ASSIGN_OR_RETURN(neutral, EmitConstant(b, *operand->operand(0)));
neutral = Cast(b, neutral, dst_ty);
} else {
TF_RET_CHECK(operand->opcode() == HloOpcode::kConstant);
neutral = EmitConstant(b, *operand);
TF_ASSIGN_OR_RETURN(neutral, EmitConstant(b, *operand));
}

// Since every shape is padded to a power of 2 in Triton, the input tile may
@@ -756,7 +762,9 @@ absl::StatusOr<Value> EmitReduce(ImplicitLocOpBuilder& b,
result = Splat(b, result, {});
}

return Cast(b, result, TritonType(b, hlo_reduce.shape().element_type()));
TF_ASSIGN_OR_RETURN(Type result_ty,
TritonType(b, hlo_reduce.shape().element_type()));
return Cast(b, result, result_ty);
}

// Emit code corresponding to a fusion instruction somehow nested within the
@@ -873,8 +881,9 @@ absl::StatusOr<Value> EmitTiledHloInstruction(

if (hlo->opcode() == HloOpcode::kConstant &&
ShapeUtil::IsEffectiveScalar(hlo->shape())) {
TF_ASSIGN_OR_RETURN(Value constant, EmitConstant(b, *hlo));
// Splat makes it a tensor to avoid type mismatches.
return Splat(b, EmitConstant(b, *hlo), {});
return Splat(b, constant, {});
}

if (hlo->opcode() == HloOpcode::kBroadcast) {
@@ -896,16 +905,18 @@ absl::StatusOr<Value> EmitTiledHloInstruction(
return EmitElementwise(b, libdevice_path, device_info, *hlo, operands);
}

if (hlo->opcode() == HloOpcode::kTranspose ||
hlo->opcode() == HloOpcode::kSlice || hlo->opcode() == HloOpcode::kPad) {
// All these are currently supported only as operations on indices
// which are pushed to loads and stores. No operations on tiles are
// performed here.
// All these operations are currently supported only as operations on indices
// which are pushed to loads and stores. We don't generate any further code
// for these operations here.
std::vector<HloOpcode> passthrough_opcodes(
{HloOpcode::kBitcast, HloOpcode::kPad, HloOpcode::kReshape,
HloOpcode::kSlice, HloOpcode::kTranspose});
if (absl::c_linear_search(passthrough_opcodes, hlo->opcode())) {
return values[tiled_hlo.operand(0)];
}

return absl::UnimplementedError(
absl::StrCat("Unsupported opcode: ", hlo->opcode()));
absl::StrCat("Unsupported operation ", hlo->ToString()));
}

// Emit sequence of instructions using compatible tiling ordered producers
@@ -954,8 +965,9 @@ absl::StatusOr<Value> EmitScope(
TF_RET_CHECK(values.contains(hlo)) << hlo->ToString();
continue;
} else if (hlo->opcode() == HloOpcode::kConstant) {
TF_ASSIGN_OR_RETURN(Value constant, EmitConstant(b, *hlo));
// Splat makes it a tensor to avoid type mismatches.
result = Splat(b, EmitConstant(b, *hlo), {});
result = Splat(b, constant, {});
} else if (hlo->opcode() == HloOpcode::kBroadcast) {
TF_ASSIGN_OR_RETURN(
result, EmitBroadcast(b, analysis, scope, tiled_dimensions, *hlo,
@@ -1362,12 +1374,13 @@ class MatMulEmitterHelper {

// TODO(b/266862493): Accumulator can be integer too.
// Otherwise only f64 x f64 -> f64 uses f64 accumulator.
mlir::FloatType GetDotAccumulatorType() {
absl::StatusOr<mlir::FloatType> GetDotAccumulatorType() {
const PrecisionConfig::Algorithm algorithm =
dot_instr_->precision_config().algorithm();

if (algorithm == PrecisionConfig::ALG_UNSET) {
Type dot_output_ty = TritonType(b_, dot_instr_->shape().element_type());
TF_ASSIGN_OR_RETURN(Type dot_output_ty,
TritonType(b_, dot_instr_->shape().element_type()));
// The code below assumes that lhs and rhs have the same type. However
// it's not always the case with fp8 matmuls, e.g. e4m3×e5m2 is supported
// at the hardware level. NVidia GPU currently only supports f32
@@ -1377,14 +1390,14 @@ class MatMulEmitterHelper {
}

// Data type of dot() immediate inputs.
Type dot_input_ty = [&] {
const Type lhs_ty =
TritonType(b_, dot_instr_->operand(0)->shape().element_type());
const Type rhs_ty =
TritonType(b_, dot_instr_->operand(1)->shape().element_type());
CHECK(lhs_ty == rhs_ty);
return lhs_ty;
}();
TF_ASSIGN_OR_RETURN(
const Type lhs_ty,
TritonType(b_, dot_instr_->operand(0)->shape().element_type()));
TF_ASSIGN_OR_RETURN(
const Type rhs_ty,
TritonType(b_, dot_instr_->operand(1)->shape().element_type()));
TF_RET_CHECK(lhs_ty == rhs_ty);
Type dot_input_ty = lhs_ty;
// TODO(b/266862493): Accumulator can be integer too.
// Otherwise only f64 x f64 -> f64 uses f64 accumulator.
return (dot_output_ty.isF64() && dot_input_ty.isF64()) ? b_.getF64Type()
@@ -1395,7 +1408,8 @@ class MatMulEmitterHelper {
algorithm_util::GetDotAccumulatorType(algorithm);
CHECK(accum_type.ok()) << "Unexpected algorithm: "
<< PrecisionConfig::Algorithm_Name(algorithm);
Type mlir_accum_type = TritonType(b_, accum_type.value());
TF_ASSIGN_OR_RETURN(Type mlir_accum_type,
TritonType(b_, accum_type.value()));
if (auto float_accum_type =
mlir::dyn_cast<mlir::FloatType>(mlir_accum_type)) {
return float_accum_type;
@@ -2131,10 +2145,9 @@ absl::Status EmitMatMul(mlir::OpBuilder builder,
if (node->opcode() != HloOpcode::kConvert) {
return false;
}
Type in_type =
TritonType(builder, node->operand(0)->shape().element_type());
Type out_type = TritonType(builder, node->shape().element_type());
return in_type.getIntOrFloatBitWidth() <= 8 && out_type.isF32();
int in_width =
primitive_util::BitWidth(node->operand(0)->shape().element_type());
return in_width <= 8 && node->shape().element_type() == F32;
});

// We'll be creating a lot of instructions from a single dot, use an
@@ -2181,7 +2194,7 @@ absl::Status EmitMatMul(mlir::OpBuilder builder,
auto pid_n = b.create<ma::DivSIOp>(b.create<ma::RemSIOp>(pid_nc, c32(width)),
group_size);

mlir::FloatType acc_ty = emitter.GetDotAccumulatorType();
TF_ASSIGN_OR_RETURN(mlir::FloatType acc_ty, emitter.GetDotAccumulatorType());

ma::ConstantOp accumulator_init =
CreateConst(b, acc_ty, 0, {block_m, block_n});
@@ -2229,8 +2242,17 @@ absl::Status EmitMatMul(mlir::OpBuilder builder,
size_t lsize = ScopeInputs(analysis, TritonFusionAnalysis::Scope::LHS).size();
size_t rsize = ScopeInputs(analysis, TritonFusionAnalysis::Scope::RHS).size();

absl::flat_hash_map<const HloInstruction*, Type> triton_type_for_input;
for (const Side& side : {lhs, rhs}) {
for (const HloInstruction* input : ScopeInputs(analysis, side.scope)) {
TF_ASSIGN_OR_RETURN(Type input_ty,
TritonType(b, input->shape().element_type()));
triton_type_for_input.insert({input, input_ty});
}
}

auto body_builder = [&](mlir::OpBuilder&, mlir::Location, Value ki,
ValueRange iter_args) {
ValueRange iter_args) -> void {
SmallVector<Value> iter_args_next;
iter_args_next.reserve(iter_args.size());
std::array<absl::flat_hash_map<const HloInstruction*, Value>, 3> values;
Expand All @@ -2243,7 +2265,7 @@ absl::Status EmitMatMul(mlir::OpBuilder builder,
const HloInstruction* param_hlo = iter_args_to_inputs[i];
Type param_ty = index == kLhsMetaOperandIdx
? b.getI16Type()
: TritonType(b, param_hlo->shape().element_type());
: triton_type_for_input.at(param_hlo);
Type param_storage_ty = StorageType(b, param_ty);
Value param_value =
EmitParameterLoad(b, iter_args[i], iter_args_to_boundary_checks[i]);
@@ -2364,6 +2386,7 @@ absl::Status EmitMatMul(mlir::OpBuilder builder,
iter_args_next.push_back(accumulator_next);

b.create<mlir::scf::YieldOp>(iter_args_next);
return;
};

// Pointers to inputs of LHS scope, then RHS, then the accumulator
@@ -2393,8 +2416,9 @@ absl::Status EmitMatMul(mlir::OpBuilder builder,
/*iterArgs=*/iter_args, body_builder)
.getResult(iter_args.size() - 1);
absl::flat_hash_map<const HloInstruction*, Value> values_out;
values_out[dot_instr] =
Cast(b, acc_final, TritonType(b, dot_instr->shape().element_type()));
TF_ASSIGN_OR_RETURN(Type acc_final_ty,
TritonType(b, dot_instr->shape().element_type()));
values_out[dot_instr] = Cast(b, acc_final, acc_final_ty);

// Emit the output scope.
if (std::vector<const HloInstruction*> to_emit =
@@ -2774,16 +2798,21 @@ absl::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> CreateTritonModule(
SmallVector<Type> fn_arg_types;
for (HloInstruction* p : hlo_computation->parameter_instructions()) {
PrimitiveType type = p->shape().element_type();
Type ir_type = type != U16 ? TritonType(b, type) : b.getI16Type();
Type ir_type;
if (type == U16) {
ir_type = b.getI16Type();
} else {
TF_ASSIGN_OR_RETURN(ir_type, TritonType(b, type));
}
fn_arg_types.push_back(
mt::PointerType::get(StorageType(b, ir_type), mn::kGlobalMemorySpace));
}

for (const ShapeUtil::IndexedShape& s :
ShapeUtil::GetLeafShapes(fusion->shape())) {
fn_arg_types.push_back(mt::PointerType::get(
StorageType(b, TritonType(b, s.shape.element_type())),
mn::kGlobalMemorySpace));
TF_ASSIGN_OR_RETURN(Type triton_ty, TritonType(b, s.shape.element_type()));
fn_arg_types.push_back(mt::PointerType::get(StorageType(b, triton_ty),
mn::kGlobalMemorySpace));
}

auto fn = b.create<mt::FuncOp>(loc, fn_name,
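
The recurring pattern in this file: each previously-infallible
`TritonType(...)` call becomes a `TF_ASSIGN_OR_RETURN`, so an unsupported
element type now surfaces as an `absl::UnimplementedError` instead of
terminating the process through `LOG(FATAL)`. A minimal sketch of the
pattern in a hypothetical caller (`EmitCast` is illustrative, not part of
the change):

    // Hypothetical caller; the enclosing function must itself return
    // absl::Status or absl::StatusOr so that the error can propagate.
    absl::StatusOr<Value> EmitCast(ImplicitLocOpBuilder& b,
                                   const HloInstruction& hlo, Value input) {
      // Before this change: Type ty = TritonType(b, ...); which would
      // LOG(FATAL) on unsupported element types.
      TF_ASSIGN_OR_RETURN(Type ty, TritonType(b, hlo.shape().element_type()));
      return Cast(b, input, ty);
    }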
1 change: 1 addition & 0 deletions third_party/xla/xla/service/gpu/model/BUILD
@@ -691,6 +691,7 @@ xla_cc_test(
"//xla/service/gpu:hlo_traversal",
"//xla/tests:hlo_test_base",
"//xla/tests:verified_hlo_module",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
"@com_google_googletest//:gtest_main",
60 changes: 23 additions & 37 deletions third_party/xla/xla/service/gpu/model/symbolic_tile.cc
@@ -106,43 +106,6 @@ AffineMap SubstituteAllIndicesAndRangeVarSymbolsWithSameValue(
return simplifyAffineMap(affine_map.replace(indices, num_dims, num_symbols));
}

// Merges `maybe_first_map` and `second_map` if
// (1) `maybe_first_map` is present, and
// (2) `second_map` and `*maybe_first_map` have distinct sets of keys.
// Otherwise, returns `std::nullopt`.
//
//
// The behaviour of this function is in spirit equivalent to using C++23's
// `std::optional<T>::and_then` to merge a collection of `ConstraintMap`s.
//
// We pass `maybe_first_map` by value here in order to exploit move semantics
// to avoid copies when possible.
//
// TODO(bchetioui): allow merging constraints in more edge cases, e.g. if one
// of the intervals is contained within the other.
std::optional<ConstraintMap> MergeConstraintMapIfPresentAndCompatible(
std::optional<ConstraintMap> maybe_first_map,
const ConstraintMap& second_map) {
if (!maybe_first_map.has_value()) {
return std::nullopt;
}

ConstraintMap& first_map = *maybe_first_map;

for (const auto& [expr, interval] : second_map) {
if (first_map.contains(expr)) {
AffineMapPrinter printer;
VLOG(1) << "Got two different constraints for expression "
<< printer.ToString(expr);
return std::nullopt;
}

first_map.insert({expr, interval});
}

return first_map;
}

struct SizeAndStrideExpression {
AffineExpr size;
AffineExpr stride;
@@ -620,6 +583,29 @@ AffineExpr SimplifyAffineExpr(const AffineExpr& expr,

} // anonymous namespace

std::optional<ConstraintMap> MergeConstraintMapIfPresentAndCompatible(
std::optional<ConstraintMap> maybe_first_map,
const ConstraintMap& second_map) {
if (!maybe_first_map.has_value()) {
return std::nullopt;
}

ConstraintMap& first_map = *maybe_first_map;

for (const auto& [expr, interval] : second_map) {
if (first_map.contains(expr)) {
AffineMapPrinter printer;
VLOG(1) << "Got two different constraints for expression "
<< printer.ToString(expr);
return std::nullopt;
}

first_map.insert({expr, interval});
}

return first_map;
}

/*static*/ std::optional<SymbolicTile> SymbolicTile::FromIndexingMap(
const IndexingMap& indexing_map) {
VLOG(1) << "SymbolicTile::FromIndexingMap: " << indexing_map.ToString();
19 changes: 19 additions & 0 deletions third_party/xla/xla/service/gpu/model/symbolic_tile.h
@@ -225,6 +225,25 @@ class SymbolicTile {
is_satisfiable_(is_satisfiable) {}
};

// Merges `maybe_first_map` and `second_map` if
// (1) `maybe_first_map` is present, and
// (2) `second_map` and `*maybe_first_map` have distinct sets of keys.
// Otherwise, returns `std::nullopt`.
//
//
// The behaviour of this function is in spirit equivalent to using C++23's
// `std::optional<T>::and_then` to merge a collection of `ConstraintMap`s.
//
// We pass `maybe_first_map` by value here in order to exploit move semantics
// to avoid copies when possible.
//
// TODO(bchetioui): allow merging constraints in more edge cases, e.g. if one
// of the intervals is contained within the other.
std::optional<SymbolicTile::ConstraintMap>
MergeConstraintMapIfPresentAndCompatible(
std::optional<SymbolicTile::ConstraintMap> maybe_first_map,
const SymbolicTile::ConstraintMap& second_map);

} // namespace gpu
} // namespace xla

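
With the helper moved out of the anonymous namespace into this header,
`SymbolicTileAnalysis` can fold together the constraints of every tile it
encounters. A sketch of that fold (loop structure and names illustrative,
assuming each `SymbolicTile` exposes its `ConstraintMap`):

    std::optional<SymbolicTile::ConstraintMap> constraints =
        SymbolicTile::ConstraintMap();
    for (const SymbolicTile& tile : symbolic_tiles) {
      // std::nullopt is sticky: once two constraints conflict, the merge
      // fails for good and SymbolicTileAnalysis construction is aborted.
      constraints = MergeConstraintMapIfPresentAndCompatible(
          std::move(constraints), tile.constraints());
      if (!constraints.has_value()) break;
    }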
