From c395b48cdbbad799e93415fd2bf5c9d3726967b5 Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Tue, 18 Jun 2024 02:55:15 -0700 Subject: [PATCH 01/59] [XLA:GPU] Add `bitcast` and `reshape` to the list of "passthrough" opcodes in Triton emitter. In order to test this exhaustively, change the `TritonType` derivation logic to propagate a `Status` instead of crashing whenever a mapping between the provided HLO type and Triton types has not been defined. Landing this as a single change makes sense because exhaustively testing `bitcast`s and `reshape`s is a rather canonical way of checking that this new error propagation logic works as intended. PiperOrigin-RevId: 644313457 --- third_party/xla/xla/service/gpu/BUILD | 4 +- .../xla/xla/service/gpu/ir_emitter_triton.cc | 119 +++++++++++------- .../xla/service/gpu/triton_support_test.cc | 68 +++++++++- 3 files changed, 143 insertions(+), 48 deletions(-) diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index c6daa3e8d6745a..ea8feebb6aa687 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -1259,16 +1259,16 @@ xla_test( deps = [ ":gpu_device_info_for_tests", ":ir_emitter_triton", - ":matmul_utils", ":triton_fusion_analysis", ":triton_support", ":triton_test_utils", + "//third_party/protobuf", "//xla:xla_data_proto_cc", "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", "//xla/stream_executor:device_description", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_googletest//:gtest_main", "@local_tsl//tsl/lib/core:status_test_util", diff --git a/third_party/xla/xla/service/gpu/ir_emitter_triton.cc b/third_party/xla/xla/service/gpu/ir_emitter_triton.cc index fd03050d701fe3..d3e02d7e09237d 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter_triton.cc +++ b/third_party/xla/xla/service/gpu/ir_emitter_triton.cc @@ -167,7 
+167,7 @@ using mlir::ValueRange; namespace { // XLA -> Triton type conversions. -Type TritonType(mlir::OpBuilder b, PrimitiveType t) { +absl::StatusOr TritonType(mlir::OpBuilder b, PrimitiveType t) { switch (t) { case F64: return b.getF64Type(); @@ -195,8 +195,9 @@ Type TritonType(mlir::OpBuilder b, PrimitiveType t) { // Triton. return b.getFloat8E4M3FNUZType(); default: - LOG(FATAL) << "This type is not supported yet: " - << primitive_util::LowercasePrimitiveTypeName(t); + return absl::UnimplementedError( + absl::StrCat("This type is not supported yet: ", + primitive_util::LowercasePrimitiveTypeName(t))); } } @@ -485,8 +486,11 @@ absl::StatusOr EmitElementwise(ImplicitLocOpBuilder& b, case HloOpcode::kNegate: // NegFOp is not supported by Triton. return Subtract(b, {ZerosLike(b, inputs[0]), inputs[0]}); - case HloOpcode::kConvert: - return Cast(b, inputs[0], TritonType(b, hlo.shape().element_type())); + case HloOpcode::kConvert: { + TF_ASSIGN_OR_RETURN(Type dst_ty, + TritonType(b, hlo.shape().element_type())); + return Cast(b, inputs[0], dst_ty); + } case HloOpcode::kAdd: if (is_integer) { return b.create(inputs[0], inputs[1]); @@ -577,8 +581,9 @@ Value EmitParameterLoad(ImplicitLocOpBuilder& b, Value pointer, {}); } -Value EmitConstant(ImplicitLocOpBuilder& b, const HloInstruction& constant) { - Type ty = TritonType(b, constant.shape().element_type()); +absl::StatusOr EmitConstant(ImplicitLocOpBuilder& b, + const HloInstruction& constant) { + TF_ASSIGN_OR_RETURN(Type ty, TritonType(b, constant.shape().element_type())); if (constant.shape().IsInteger()) { if (constant.shape().element_type() == U64) { return CreateConst(b, ty, ScalarConstantValue(constant, U64)); @@ -681,13 +686,14 @@ absl::StatusOr EmitReduce(ImplicitLocOpBuilder& b, if (operand->opcode() == HloOpcode::kConvert) { TF_RET_CHECK(operand->operand(0)->opcode() == HloOpcode::kConstant); TF_RET_CHECK(operand->operand(0)->shape().element_type() == BF16); - PrimitiveType dest_ty = 
operand->shape().element_type(); - TF_RET_CHECK(dest_ty == F32); - neutral = EmitConstant(b, *operand->operand(0)); - neutral = Cast(b, neutral, TritonType(b, dest_ty)); + TF_RET_CHECK(operand->shape().element_type() == F32); + TF_ASSIGN_OR_RETURN(Type dst_ty, + TritonType(b, operand->shape().element_type())); + TF_ASSIGN_OR_RETURN(neutral, EmitConstant(b, *operand->operand(0))); + neutral = Cast(b, neutral, dst_ty); } else { TF_RET_CHECK(operand->opcode() == HloOpcode::kConstant); - neutral = EmitConstant(b, *operand); + TF_ASSIGN_OR_RETURN(neutral, EmitConstant(b, *operand)); } // Since every shape is padded to a power of 2 in Triton, the input tile may @@ -756,7 +762,9 @@ absl::StatusOr EmitReduce(ImplicitLocOpBuilder& b, result = Splat(b, result, {}); } - return Cast(b, result, TritonType(b, hlo_reduce.shape().element_type())); + TF_ASSIGN_OR_RETURN(Type result_ty, + TritonType(b, hlo_reduce.shape().element_type())); + return Cast(b, result, result_ty); } // Emit code corresponding to a fusion instruction somehow nested within the @@ -873,8 +881,9 @@ absl::StatusOr EmitTiledHloInstruction( if (hlo->opcode() == HloOpcode::kConstant && ShapeUtil::IsEffectiveScalar(hlo->shape())) { + TF_ASSIGN_OR_RETURN(Value constant, EmitConstant(b, *hlo)); // Splat makes it a tensor to avoid type mismatches. - return Splat(b, EmitConstant(b, *hlo), {}); + return Splat(b, constant, {}); } if (hlo->opcode() == HloOpcode::kBroadcast) { @@ -896,16 +905,18 @@ absl::StatusOr EmitTiledHloInstruction( return EmitElementwise(b, libdevice_path, device_info, *hlo, operands); } - if (hlo->opcode() == HloOpcode::kTranspose || - hlo->opcode() == HloOpcode::kSlice || hlo->opcode() == HloOpcode::kPad) { - // All these are currently supported only as operations on indices - // which are pushed to loads and stores. No operations on tiles are - // performed here. + // All these operations are currently supported only as operations on indices + // which are pushed to loads and stores. 
We don't generate any further code + // for these operations here. + std::vector passthrough_opcodes( + {HloOpcode::kBitcast, HloOpcode::kPad, HloOpcode::kReshape, + HloOpcode::kSlice, HloOpcode::kTranspose}); + if (absl::c_linear_search(passthrough_opcodes, hlo->opcode())) { return values[tiled_hlo.operand(0)]; } return absl::UnimplementedError( - absl::StrCat("Unsupported opcode: ", hlo->opcode())); + absl::StrCat("Unsupported operation ", hlo->ToString())); } // Emit sequence of instructions using compatible tiling ordered producers @@ -954,8 +965,9 @@ absl::StatusOr EmitScope( TF_RET_CHECK(values.contains(hlo)) << hlo->ToString(); continue; } else if (hlo->opcode() == HloOpcode::kConstant) { + TF_ASSIGN_OR_RETURN(Value constant, EmitConstant(b, *hlo)); // Splat makes it a tensor to avoid type mismatches. - result = Splat(b, EmitConstant(b, *hlo), {}); + result = Splat(b, constant, {}); } else if (hlo->opcode() == HloOpcode::kBroadcast) { TF_ASSIGN_OR_RETURN( result, EmitBroadcast(b, analysis, scope, tiled_dimensions, *hlo, @@ -1362,12 +1374,13 @@ class MatMulEmitterHelper { // TODO(b/266862493): Accumulator can be integer too. // Otherwise only f64 x f64 -> f64 uses f64 accumulator. - mlir::FloatType GetDotAccumulatorType() { + absl::StatusOr GetDotAccumulatorType() { const PrecisionConfig::Algorithm algorithm = dot_instr_->precision_config().algorithm(); if (algorithm == PrecisionConfig::ALG_UNSET) { - Type dot_output_ty = TritonType(b_, dot_instr_->shape().element_type()); + TF_ASSIGN_OR_RETURN(Type dot_output_ty, + TritonType(b_, dot_instr_->shape().element_type())); // The code below assumes that lhs and rhs have the same type. However // it's not always the case with fp8 matmuls, e.g. e4m3×e5m2 is supported // at the hardware level. NVidia GPU currently only supports f32 @@ -1377,14 +1390,14 @@ class MatMulEmitterHelper { } // Data type of dot() immediate inputs. 
- Type dot_input_ty = [&] { - const Type lhs_ty = - TritonType(b_, dot_instr_->operand(0)->shape().element_type()); - const Type rhs_ty = - TritonType(b_, dot_instr_->operand(1)->shape().element_type()); - CHECK(lhs_ty == rhs_ty); - return lhs_ty; - }(); + TF_ASSIGN_OR_RETURN( + const Type lhs_ty, + TritonType(b_, dot_instr_->operand(0)->shape().element_type())); + TF_ASSIGN_OR_RETURN( + const Type rhs_ty, + TritonType(b_, dot_instr_->operand(1)->shape().element_type())); + TF_RET_CHECK(lhs_ty == rhs_ty); + Type dot_input_ty = lhs_ty; // TODO(b/266862493): Accumulator can be integer too. // Otherwise only f64 x f64 -> f64 uses f64 accumulator. return (dot_output_ty.isF64() && dot_input_ty.isF64()) ? b_.getF64Type() @@ -1395,7 +1408,8 @@ class MatMulEmitterHelper { algorithm_util::GetDotAccumulatorType(algorithm); CHECK(accum_type.ok()) << "Unexpected algorithm: " << PrecisionConfig::Algorithm_Name(algorithm); - Type mlir_accum_type = TritonType(b_, accum_type.value()); + TF_ASSIGN_OR_RETURN(Type mlir_accum_type, + TritonType(b_, accum_type.value())); if (auto float_accum_type = mlir::dyn_cast(mlir_accum_type)) { return float_accum_type; @@ -2131,10 +2145,9 @@ absl::Status EmitMatMul(mlir::OpBuilder builder, if (node->opcode() != HloOpcode::kConvert) { return false; } - Type in_type = - TritonType(builder, node->operand(0)->shape().element_type()); - Type out_type = TritonType(builder, node->shape().element_type()); - return in_type.getIntOrFloatBitWidth() <= 8 && out_type.isF32(); + int in_width = + primitive_util::BitWidth(node->operand(0)->shape().element_type()); + return in_width <= 8 && node->shape().element_type() == F32; }); // We'll be creating a lot of instructions from a single dot, use an @@ -2181,7 +2194,7 @@ absl::Status EmitMatMul(mlir::OpBuilder builder, auto pid_n = b.create(b.create(pid_nc, c32(width)), group_size); - mlir::FloatType acc_ty = emitter.GetDotAccumulatorType(); + TF_ASSIGN_OR_RETURN(mlir::FloatType acc_ty, 
emitter.GetDotAccumulatorType()); ma::ConstantOp accumulator_init = CreateConst(b, acc_ty, 0, {block_m, block_n}); @@ -2229,8 +2242,17 @@ absl::Status EmitMatMul(mlir::OpBuilder builder, size_t lsize = ScopeInputs(analysis, TritonFusionAnalysis::Scope::LHS).size(); size_t rsize = ScopeInputs(analysis, TritonFusionAnalysis::Scope::RHS).size(); + absl::flat_hash_map triton_type_for_input; + for (const Side& side : {lhs, rhs}) { + for (const HloInstruction* input : ScopeInputs(analysis, side.scope)) { + TF_ASSIGN_OR_RETURN(Type input_ty, + TritonType(b, input->shape().element_type())); + triton_type_for_input.insert({input, input_ty}); + } + } + auto body_builder = [&](mlir::OpBuilder&, mlir::Location, Value ki, - ValueRange iter_args) { + ValueRange iter_args) -> void { SmallVector iter_args_next; iter_args_next.reserve(iter_args.size()); std::array, 3> values; @@ -2243,7 +2265,7 @@ absl::Status EmitMatMul(mlir::OpBuilder builder, const HloInstruction* param_hlo = iter_args_to_inputs[i]; Type param_ty = index == kLhsMetaOperandIdx ? b.getI16Type() - : TritonType(b, param_hlo->shape().element_type()); + : triton_type_for_input.at(param_hlo); Type param_storage_ty = StorageType(b, param_ty); Value param_value = EmitParameterLoad(b, iter_args[i], iter_args_to_boundary_checks[i]); @@ -2364,6 +2386,7 @@ absl::Status EmitMatMul(mlir::OpBuilder builder, iter_args_next.push_back(accumulator_next); b.create(iter_args_next); + return; }; // Pointers to inputs of LHS scope, then RHS, then the accumulator @@ -2393,8 +2416,9 @@ absl::Status EmitMatMul(mlir::OpBuilder builder, /*iterArgs=*/iter_args, body_builder) .getResult(iter_args.size() - 1); absl::flat_hash_map values_out; - values_out[dot_instr] = - Cast(b, acc_final, TritonType(b, dot_instr->shape().element_type())); + TF_ASSIGN_OR_RETURN(Type acc_final_ty, + TritonType(b, dot_instr->shape().element_type())); + values_out[dot_instr] = Cast(b, acc_final, acc_final_ty); // Emit the output scope. 
if (std::vector to_emit = @@ -2774,16 +2798,21 @@ absl::StatusOr> CreateTritonModule( SmallVector fn_arg_types; for (HloInstruction* p : hlo_computation->parameter_instructions()) { PrimitiveType type = p->shape().element_type(); - Type ir_type = type != U16 ? TritonType(b, type) : b.getI16Type(); + Type ir_type; + if (type == U16) { + ir_type = b.getI16Type(); + } else { + TF_ASSIGN_OR_RETURN(ir_type, TritonType(b, type)); + } fn_arg_types.push_back( mt::PointerType::get(StorageType(b, ir_type), mn::kGlobalMemorySpace)); } for (const ShapeUtil::IndexedShape& s : ShapeUtil::GetLeafShapes(fusion->shape())) { - fn_arg_types.push_back(mt::PointerType::get( - StorageType(b, TritonType(b, s.shape.element_type())), - mn::kGlobalMemorySpace)); + TF_ASSIGN_OR_RETURN(Type triton_ty, TritonType(b, s.shape.element_type())); + fn_arg_types.push_back(mt::PointerType::get(StorageType(b, triton_ty), + mn::kGlobalMemorySpace)); } auto fn = b.create(loc, fn_name, diff --git a/third_party/xla/xla/service/gpu/triton_support_test.cc b/third_party/xla/xla/service/gpu/triton_support_test.cc index 99fb5b7538515f..0d6f2998953362 100644 --- a/third_party/xla/xla/service/gpu/triton_support_test.cc +++ b/third_party/xla/xla/service/gpu/triton_support_test.cc @@ -20,11 +20,14 @@ limitations under the License. 
#include #include #include +#include #include #include +#include "absl/algorithm/container.h" #include "absl/status/status.h" #include "absl/strings/string_view.h" +#include "third_party/protobuf/descriptor.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/gpu/gpu_device_info_for_tests.h" #include "xla/service/gpu/ir_emitter_triton.h" @@ -41,9 +44,72 @@ namespace xla { namespace gpu { namespace { -using UnaryElementwiseTest = TritonSupportTestWithParam; +using ::testing::Not; +using ::testing::status::IsOk; + +auto AllXlaDataTypes() { + std::vector xla_data_types; + std::vector to_filter_out = {PRIMITIVE_TYPE_INVALID, + TUPLE, OPAQUE_TYPE, TOKEN}; + const proto2::EnumDescriptor* xla_type_descriptor = + proto2::GetEnumDescriptor(); + for (int enum_ix = 0; enum_ix < xla_type_descriptor->value_count(); + ++enum_ix) { + xla::PrimitiveType xla_type = static_cast( + xla_type_descriptor->value(enum_ix)->number()); + if (!absl::c_linear_search(to_filter_out, xla_type)) { + xla_data_types.push_back(xla_type); + } + } + return ::testing::ValuesIn(xla_data_types); +} // TODO(b/343158720): remove references to TritonFusionAnalysis in this file. +// TODO(b/343158720): factor out implication tests into a util in order to +// simplify the test structure. 
+using BitcastOrReshapeTest = TritonSupportTestWithParam; + +TEST_P(BitcastOrReshapeTest, IsTritonSupportedBitcastOrReshape) { + auto [data_type, opcode] = GetParam(); + const std::string kHloTestTemplate = R"( +triton_computation { + parameter_0 = $0[1,16,4]{2,1,0} parameter(0) + ROOT bitcast_or_reshape = $0[64]{0} $1(parameter_0) +} + +ENTRY e { + parameter_0 = $0[1,16,4]{2,1,0} parameter(0) + ROOT root_op = $0[64]{0} fusion(parameter_0), + kind=kCustom, calls=triton_computation, + backend_config={"fusion_backend_config":{"kind":"__triton"}} +})"; + TF_ASSERT_OK_AND_ASSIGN( + TestedInstruction ti, + ParseTemplateAndGetInstruction(kHloTestTemplate, data_type, opcode)); + if (IsTritonSupportedInstruction(ti.Instruction(), + GetCudaComputeCapability())) { + TF_EXPECT_OK(CreateTritonIrAndFileCheck(ti.TritonComputation(), + FromOutputTileSizes({16}), + "CHECK: tt.func @triton_fn")); + } else { + const se::DeviceDescription dev_info = + TestGpuDeviceInfo::RTXA6000DeviceInfo(GetCudaComputeCapability()); + EXPECT_THAT( + TritonWrapper("test_fn", &ti.TritonFusion(), GetCudaComputeCapability(), + dev_info, FromOutputTileSizes({1}), &llvm_module_, + mlir_context_), + Not(IsOk())); + } +} + +INSTANTIATE_TEST_SUITE_P( + BitcastOrReshapeTestSuite, BitcastOrReshapeTest, + ::testing::Combine(AllXlaDataTypes(), + ::testing::Values(HloOpcode::kBitcast, + HloOpcode::kReshape)), + TritonSupportTestParamsToString); + +using UnaryElementwiseTest = TritonSupportTestWithParam; // TODO(b/331636835): updates elementwise op tests to directly emit single op // instead of relying on triton gemm kernel. From aa29a22fddfee7510e9d9239f56aecfa0260ef59 Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Tue, 18 Jun 2024 03:29:28 -0700 Subject: [PATCH 02/59] [XLA:GPU] Add constraints to `SymbolicTileAnalysis`. Constraints are constructed from merging the constraints of all the `SymbolicTile`s encountered while constructing the resulting symbolic tiled HLO computation. 
If any of the `SymbolicTile`s is unsatisfiable, the construction of the `SymbolicTileAnalysis` object does not succeed. Likewise, construction fails if some constraints can not be merged with others. Constraints are now checked to be satisfied by the provided tile parameters when attempting to extract a `TiledHloComputation` out of `SymbolicTileAnalysis`. In order to avoid checking constraints too many times, we allow pinky-promising that the provided tile parameters satisfy the constraints, to voluntarily bypass the checks. PiperOrigin-RevId: 644321131 --- third_party/xla/xla/service/gpu/model/BUILD | 1 + .../xla/service/gpu/model/symbolic_tile.cc | 60 +++----- .../xla/xla/service/gpu/model/symbolic_tile.h | 19 +++ .../gpu/model/symbolic_tile_analysis.cc | 68 ++++++++- .../gpu/model/symbolic_tile_analysis.h | 30 +++- .../gpu/model/symbolic_tile_analysis_test.cc | 131 ++++++++++++++++-- 6 files changed, 252 insertions(+), 57 deletions(-) diff --git a/third_party/xla/xla/service/gpu/model/BUILD b/third_party/xla/xla/service/gpu/model/BUILD index 9da56e03fffedf..32150a534fbce2 100644 --- a/third_party/xla/xla/service/gpu/model/BUILD +++ b/third_party/xla/xla/service/gpu/model/BUILD @@ -691,6 +691,7 @@ xla_cc_test( "//xla/service/gpu:hlo_traversal", "//xla/tests:hlo_test_base", "//xla/tests:verified_hlo_module", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest_main", diff --git a/third_party/xla/xla/service/gpu/model/symbolic_tile.cc b/third_party/xla/xla/service/gpu/model/symbolic_tile.cc index 585e15ac9b5d9a..93aec8922eec80 100644 --- a/third_party/xla/xla/service/gpu/model/symbolic_tile.cc +++ b/third_party/xla/xla/service/gpu/model/symbolic_tile.cc @@ -106,43 +106,6 @@ AffineMap SubstituteAllIndicesAndRangeVarSymbolsWithSameValue( return simplifyAffineMap(affine_map.replace(indices, num_dims, num_symbols)); } -// Merges `maybe_first_map` and `second_map` if -// (1) 
`maybe_first_map` is present, and -// (2) `second_map` and `*maybe_first_map` have distinct sets of keys. -// Otherwise, returns `std::nullopt`. -// -// -// The behaviour of this function is in spirit equivalent to using C++23's -// `std::optional::and_then` to merge a collection of `ConstraintMap`s. -// -// We pass `maybe_first_map` by value here in order to exploit move semantics -// to avoid copies when possible. -// -// TODO(bchetioui): allow merging constraints in more edge cases, e.g. if one -// of the intervals is contained within the other. -std::optional MergeConstraintMapIfPresentAndCompatible( - std::optional maybe_first_map, - const ConstraintMap& second_map) { - if (!maybe_first_map.has_value()) { - return std::nullopt; - } - - ConstraintMap& first_map = *maybe_first_map; - - for (const auto& [expr, interval] : second_map) { - if (first_map.contains(expr)) { - AffineMapPrinter printer; - VLOG(1) << "Got two different constraints for expression " - << printer.ToString(expr); - return std::nullopt; - } - - first_map.insert({expr, interval}); - } - - return first_map; -} - struct SizeAndStrideExpression { AffineExpr size; AffineExpr stride; @@ -620,6 +583,29 @@ AffineExpr SimplifyAffineExpr(const AffineExpr& expr, } // anonymous namespace +std::optional MergeConstraintMapIfPresentAndCompatible( + std::optional maybe_first_map, + const ConstraintMap& second_map) { + if (!maybe_first_map.has_value()) { + return std::nullopt; + } + + ConstraintMap& first_map = *maybe_first_map; + + for (const auto& [expr, interval] : second_map) { + if (first_map.contains(expr)) { + AffineMapPrinter printer; + VLOG(1) << "Got two different constraints for expression " + << printer.ToString(expr); + return std::nullopt; + } + + first_map.insert({expr, interval}); + } + + return first_map; +} + /*static*/ std::optional SymbolicTile::FromIndexingMap( const IndexingMap& indexing_map) { VLOG(1) << "SymbolicTile::FromIndexingMap: " << indexing_map.ToString(); diff --git 
a/third_party/xla/xla/service/gpu/model/symbolic_tile.h b/third_party/xla/xla/service/gpu/model/symbolic_tile.h index 654f42f0fcfc2b..ddce4de4699a28 100644 --- a/third_party/xla/xla/service/gpu/model/symbolic_tile.h +++ b/third_party/xla/xla/service/gpu/model/symbolic_tile.h @@ -225,6 +225,25 @@ class SymbolicTile { is_satisfiable_(is_satisfiable) {} }; +// Merges `maybe_first_map` and `second_map` if +// (1) `maybe_first_map` is present, and +// (2) `second_map` and `*maybe_first_map` have distinct sets of keys. +// Otherwise, returns `std::nullopt`. +// +// +// The behaviour of this function is in spirit equivalent to using C++23's +// `std::optional::and_then` to merge a collection of `ConstraintMap`s. +// +// We pass `maybe_first_map` by value here in order to exploit move semantics +// to avoid copies when possible. +// +// TODO(bchetioui): allow merging constraints in more edge cases, e.g. if one +// of the intervals is contained within the other. +std::optional +MergeConstraintMapIfPresentAndCompatible( + std::optional maybe_first_map, + const SymbolicTile::ConstraintMap& second_map); + } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis.cc b/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis.cc index a0c162ff883e9d..c2f56266782bdd 100644 --- a/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis.cc +++ b/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis.cc @@ -35,6 +35,8 @@ limitations under the License. 
#include "absl/strings/str_join.h" #include "absl/types/span.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Casting.h" #include "mlir/IR/AffineExpr.h" // from @llvm-project #include "mlir/IR/AffineMap.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project @@ -61,6 +63,7 @@ namespace { using ::mlir::AffineExpr; using ::mlir::MLIRContext; +using ConstraintMap = SymbolicTile::ConstraintMap; // Computes indexing map from program id into the tile offset for the given // shape and tile sizes. @@ -162,6 +165,8 @@ absl::StatusOr ComputeBlockIdToTileOffsetIndexing( const HloInstructionAdaptor&, IndexingMap)> get_tiled_hlo_instruction; + std::optional constraints = ConstraintMap(); + // Create a new tiled hlo instruction or return existing instruction from // cache for the given hlo and indexing map. get_tiled_hlo_instruction = @@ -180,8 +185,6 @@ absl::StatusOr ComputeBlockIdToTileOffsetIndexing( // line. This is not an inherent limitation of the approach, but simply // issues to be resolved in the current implementation. 
if (hlo->opcode() == HloOpcode::kDot || - hlo->opcode() == HloOpcode::kReshape || - hlo->opcode() == HloOpcode::kBitcast || hlo->opcode() == HloOpcode::kConcatenate) { return FusionDecision{} << "Bailing out on " << hlo->ToString(); } @@ -199,6 +202,22 @@ absl::StatusOr ComputeBlockIdToTileOffsetIndexing( << hlo->ToString(); } + if (!symbolic_tile->is_satisfiable()) { + return FusionDecision{} << "Symbolic tile " << symbolic_tile->ToString() + << " is not satisfiable for " + << indexing_map.ToString() << " for HLO " + << hlo->ToString(); + } + + constraints = MergeConstraintMapIfPresentAndCompatible( + std::move(constraints), symbolic_tile->constraints()); + + if (!constraints.has_value()) { + return FusionDecision{} << "Failed to merge constraints of " + << hlo->ToString() << " in pre-existing " + << "constraint map"; + } + tiled_hlo_instructions.push_back( std::make_unique( hlo, std::move(indexing_map), std::move(*symbolic_tile))); @@ -258,12 +277,53 @@ absl::StatusOr ComputeBlockIdToTileOffsetIndexing( return topological_order.at(i1.get()) < topological_order.at(i2.get()); }); - return SymbolicTileAnalysis(std::move(tiled_hlo_instructions), ctx); + return SymbolicTileAnalysis(std::move(tiled_hlo_instructions), + std::move(*constraints), ctx); +} + +absl::StatusOr SymbolicTileAnalysis::ParametersSatisfyConstraints( + const std::vector& tile_parameters) const { + // Populate parameter map. 
+ llvm::SmallVector parameters = llvm::to_vector( + llvm::map_range(tile_parameters, [this](const int64_t v) -> AffineExpr { + return mlir::getAffineConstantExpr(v, context_); + })); + + for (auto [constrained_expr, interval] : constraints_) { + AffineExpr constrained_expr_value = + constrained_expr.replaceSymbols(parameters); + if (constrained_expr_value.getKind() != mlir::AffineExprKind::Constant) { + return absl::InvalidArgumentError(absl::StrCat( + "Failed to reduce ", AffineMapPrinter().ToString(constrained_expr), + " to a constant with tile parameters ", + absl::StrJoin(tile_parameters, ", "))); + } + + int64_t constrained_value = + llvm::cast(constrained_expr_value).getValue(); + + if (constrained_value < interval.lower || + constrained_value > interval.upper) { + return false; + } + } + return true; } absl::StatusOr SymbolicTileAnalysis::ComputeTiledHloInstructions( - const std::vector& tile_parameters) const { + const std::vector& tile_parameters, + bool constraints_are_known_satisfied) const { + if (!constraints_are_known_satisfied) { + TF_ASSIGN_OR_RETURN(bool constraints_are_satisfied, + ParametersSatisfyConstraints(tile_parameters)); + if (!constraints_are_satisfied) { + return absl::InvalidArgumentError(absl::StrCat( + "Tile parameters ", absl::StrJoin(tile_parameters, ", "), + " do not satisfy the SymbolicTileAnalysis's constraints.")); + } + } + IndexingMap block_id_to_root_tile_offset = ComputeBlockIdToOutputTileIndexing( GetRoot()->hlo()->shape().dimensions(), tile_parameters, context_); diff --git a/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis.h b/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis.h index e01c7335c88181..7a1d7ddaeb98c7 100644 --- a/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis.h +++ b/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis.h @@ -29,6 +29,7 @@ limitations under the License. 
#include "xla/hlo/ir/hlo_instruction.h" #include "xla/service/gpu/hlo_traversal.h" #include "xla/service/gpu/model/affine_map_printer.h" +#include "xla/service/gpu/model/symbolic_tile.h" #include "xla/service/gpu/model/symbolic_tiled_hlo_instruction.h" #include "xla/service/gpu/model/tiled_hlo_computation.h" #include "xla/service/instruction_fusion.h" @@ -56,8 +57,13 @@ class SymbolicTileAnalysis { const HloFusionAdaptor& fusion, mlir::MLIRContext* ctx); // Returns a graph of HLO instructions tiled with the given tile parameters. + // The provided tile parameters must satisfy the analysis's constraints. + // By default, `ComputeTiledHloInstructions` performs a check that the + // constraints are satisfied by the chosen tile parameters. Setting + // `constraints_are_known_satisfied` to true bypasses this check. absl::StatusOr ComputeTiledHloInstructions( - const std::vector& tile_parameters) const; + const std::vector& tile_parameters, + bool constraints_are_known_satisfied = false) const; // Returns the tiled root instruction. const SymbolicTiledHloInstruction* GetRoot() const { @@ -70,6 +76,23 @@ class SymbolicTileAnalysis { return symbolic_tiled_hlo_instructions_; } + // Returns the constraints for the parameters of the symbolic tiled HLO + // computation. This is the union of the constraints of all the symbolic tiles + // encountered throughout the computation. + const SymbolicTile::ConstraintMap& GetConstraints() const { + return constraints_; + } + + // Returns true if a list of tile parameters satisfies the symbolic tile + // analysis's constraints. + // + // Returns false if the constraints are not satisfied but can be evaluated + // correctly. Returns an error if the constraints cannot be evaluated + // correctly. This is typically the case if too few tile parameters are + // provided to fully reduce the constraint expressions to constants.
+ absl::StatusOr ParametersSatisfyConstraints( + const std::vector& tile_parameters) const; + // Return the underlying MLIRContext. mlir::MLIRContext* GetMLIRContext() const { return context_; }; @@ -81,15 +104,20 @@ class SymbolicTileAnalysis { private: SymbolicTileAnalysis(std::vector> symbolic_tiled_hlo_instructions, + SymbolicTile::ConstraintMap constraints, mlir::MLIRContext* context) : symbolic_tiled_hlo_instructions_( std::move(symbolic_tiled_hlo_instructions)), + constraints_(std::move(constraints)), context_(context) {} // The tiled HLO instructions in def-before-use order. std::vector> symbolic_tiled_hlo_instructions_; + // See the documentation of GetConstraints(). + SymbolicTile::ConstraintMap constraints_; + mlir::MLIRContext* context_; }; diff --git a/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis_test.cc b/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis_test.cc index a3b0f13063fa1b..9e7258c44bd291 100644 --- a/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis_test.cc +++ b/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis_test.cc @@ -20,9 +20,11 @@ limitations under the License. 
#include #include #include +#include #include #include +#include "absl/status/status.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "mlir/IR/MLIRContext.h" // from @llvm-project @@ -42,6 +44,9 @@ namespace { using ::testing::ElementsAreArray; using ::testing::ExplainMatchResult; using ::testing::Matcher; +using ::testing::SizeIs; +using ::testing::status::IsOkAndHolds; +using ::testing::status::StatusIs; MATCHER_P3(MatchTiledHloInstructionImpl, tile_sizes, tile_strides, block_id_to_tile_offsets_indexing, "") { @@ -336,37 +341,41 @@ ENTRY main { p1 = f32[2,3]{1,0} parameter(1) ROOT fusion = f32[1,3]{1,0} fusion(p0, p1), kind=kLoop, calls=fusion })")); - EXPECT_EQ(TryAnalyzeModule(module.get()), std::nullopt); + EXPECT_FALSE(TryAnalyzeModule(module.get()).has_value()); } -TEST_F(SymbolicTileAnalysisTest, BailOutOnUnsupportedReshape) { +TEST_F(SymbolicTileAnalysisTest, DoesNotBailOutOnConstrainedReshape) { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(R"( fusion { - p0 = f32[1,2]{1,0} parameter(0) - ROOT reshape = f32[2] reshape(p0) + p0 = f32[4,2]{1,0} parameter(0) + ROOT reshape = f32[8] reshape(p0) } ENTRY main { - p0 = f32[1,2]{1,0} parameter(0) - ROOT fusion = f32[2] fusion(p0), kind=kLoop, calls=fusion + p0 = f32[4,2]{1,0} parameter(0) + ROOT fusion = f32[8] fusion(p0), kind=kLoop, calls=fusion })")); - EXPECT_EQ(TryAnalyzeModule(module.get()), std::nullopt); + std::optional analysis = TryAnalyzeModule(module.get()); + ASSERT_TRUE(analysis.has_value()); + EXPECT_THAT(analysis->GetConstraints(), SizeIs(1)); } -TEST_F(SymbolicTileAnalysisTest, BailOutOnUnsupportedBitcast) { +TEST_F(SymbolicTileAnalysisTest, DoesNotBailOutOnConstrainedBitcast) { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(R"( fusion { - p0 = f32[1,2]{1,0} parameter(0) - ROOT bitcast = f32[2] bitcast(p0) + p0 = f32[4,2]{1,0} parameter(0) + ROOT bitcast = f32[8] bitcast(p0) } ENTRY main { - p0 = 
f32[1,2]{1,0} parameter(0) - ROOT fusion = f32[2] fusion(p0), kind=kLoop, calls=fusion + p0 = f32[4,2]{1,0} parameter(0) + ROOT fusion = f32[8] fusion(p0), kind=kLoop, calls=fusion })")); - EXPECT_EQ(TryAnalyzeModule(module.get()), std::nullopt); + std::optional analysis = TryAnalyzeModule(module.get()); + ASSERT_TRUE(analysis.has_value()); + EXPECT_THAT(analysis->GetConstraints(), SizeIs(1)); } TEST_F(SymbolicTileAnalysisTest, BailOutOnUnsupportedConcatenate) { @@ -383,7 +392,7 @@ ENTRY main { p1 = f32[1,3]{1,0} parameter(1) ROOT fusion = f32[2,3] fusion(p0, p1), kind=kLoop, calls=fusion })")); - EXPECT_EQ(TryAnalyzeModule(module.get()), std::nullopt); + EXPECT_FALSE(TryAnalyzeModule(module.get()).has_value()); } TEST_F(SymbolicTileAnalysisTest, MultiOutputFusionIsNotSupported) { @@ -402,7 +411,99 @@ ENTRY main { p1 = f32[32] parameter(1) ROOT fusion = (f32[32], f32[32]) fusion(p0, p1), kind=kLoop, calls=fusion })")); - EXPECT_EQ(TryAnalyzeModule(module.get()), std::nullopt); + EXPECT_FALSE(TryAnalyzeModule(module.get()).has_value()); +} + +TEST_F(SymbolicTileAnalysisTest, ConstraintSatisfactionIsEvaluatedCorrectly) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( +fusion { + p0 = f32[1,8,6,4,8]{4,3,2,1,0} parameter(0) + ROOT bitcast = f32[48,32]{1,0} bitcast(p0) +} + +ENTRY main { + p0 = f32[1,8,6,4,8]{4,3,2,1,0} parameter(0) + ROOT fusion = f32[48,32]{1,0} fusion(p0), kind=kLoop, calls=fusion +})")); + std::optional analysis = TryAnalyzeModule(module.get()); + ASSERT_TRUE(analysis.has_value()); + EXPECT_THAT(analysis->GetConstraints(), SizeIs(2)); + + // We expect the constraints here to be + // s0 mod 6 in [0, 0] + // s1 mod 8 in [0, 0] + // We expect tile sizes {6, 8} to satisfy these constraints. 
+ std::vector possible_tile_parameters({6, 8}); + EXPECT_THAT(analysis->ParametersSatisfyConstraints(possible_tile_parameters), + IsOkAndHolds(true)); + + // However, we do not expect tile sizes {6, 7} to satisfy these constraints. + std::vector impossible_tile_parameters({6, 7}); + EXPECT_THAT( + analysis->ParametersSatisfyConstraints(impossible_tile_parameters), + IsOkAndHolds(false)); + + // Passing too few tile parameters results in an error since constraints can + // not be properly evaluated. + EXPECT_THAT(analysis->ParametersSatisfyConstraints(/*tile_parameters==*/{6}), + StatusIs(absl::StatusCode::kInvalidArgument)); + + // Passing tile parameters that satisfy the constraints should let us compute + // a TiledHloComputation. + EXPECT_OK(analysis->ParametersSatisfyConstraints(possible_tile_parameters)); + + // Passing tile parameters that do not satisfy the constraints should result + // in an error... + EXPECT_THAT(analysis->ComputeTiledHloInstructions(impossible_tile_parameters), + StatusIs(absl::StatusCode::kInvalidArgument)); + + // ... 
unless we pinky-promise (lie) that they satisfy the constraints ;) + EXPECT_OK(analysis->ComputeTiledHloInstructions( + impossible_tile_parameters, /*constraints_are_known_satisfied=*/true)); +} + +TEST_F(SymbolicTileAnalysisTest, ConstraintsAreAggregatedCorrectly) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( +fusion { + p0 = f32[1,48,4,8]{3,2,1,0} parameter(0) + p1 = f32[1,8,6,32]{3,2,1,0} parameter(1) + bitcast_p0 = f32[48,32]{1,0} bitcast(p0) + bitcast_p1 = f32[48,32]{1,0} bitcast(p1) + ROOT add = f32[48,32]{1,0} add(bitcast_p0, bitcast_p1) +} + +ENTRY main { + p0 = f32[1,48,4,8]{3,2,1,0} parameter(0) + p1 = f32[1,8,6,32]{3,2,1,0} parameter(1) + ROOT fusion = f32[48,32]{1,0} fusion(p0, p1), kind=kLoop, calls=fusion +})")); + std::optional analysis = TryAnalyzeModule(module.get()); + ASSERT_TRUE(analysis.has_value()); + // Each bitcast in the above module introduces one constraint. Once they are + // aggregated, we have two! + EXPECT_THAT(analysis->GetConstraints(), SizeIs(2)); +} + +TEST_F(SymbolicTileAnalysisTest, BailsOutWhenConstraintsCanNotBeMerged) { + // TODO(bchetioui): allow merging a constraint with itself. + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( +fusion { + p0 = f32[1,48,4,8]{3,2,1,0} parameter(0) + p1 = f32[1,48,4,8]{3,2,1,0} parameter(1) + bitcast_p0 = f32[48,32]{1,0} bitcast(p0) + bitcast_p1 = f32[48,32]{1,0} bitcast(p1) + ROOT add = f32[48,32]{1,0} add(bitcast_p0, bitcast_p1) +} + +ENTRY main { + p0 = f32[1,48,4,8]{3,2,1,0} parameter(0) + p1 = f32[1,48,4,8]{3,2,1,0} parameter(1) + ROOT fusion = f32[48,32]{1,0} fusion(p0, p1), kind=kLoop, calls=fusion +})")); + EXPECT_FALSE(TryAnalyzeModule(module.get()).has_value()); } } // namespace From 19b82d6a7a9e8f1608a38f19c712313f2413ecd8 Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Tue, 18 Jun 2024 04:26:13 -0700 Subject: [PATCH 03/59] Integrate LLVM at llvm/llvm-project@93ffe1792fd9 Updates LLVM usage to match [93ffe1792fd9](https://github.com/llvm/llvm-project/commit/93ffe1792fd9) PiperOrigin-RevId: 644334496 --- .../transforms/functional_to_region/pass.cc | 3 +- .../transforms/region_to_functional/pass.cc | 3 +- third_party/llvm/generated.patch | 516 ++++++++++++++---- third_party/llvm/workspace.bzl | 4 +- third_party/stablehlo/temporary.patch | 43 +- .../triton/llvm_integration/cl643947742.patch | 13 + .../triton/llvm_integration/series.bzl | 1 + .../xla/third_party/stablehlo/temporary.patch | 43 +- .../triton/llvm_integration/cl643947742.patch | 13 + .../triton/llvm_integration/series.bzl | 1 + 10 files changed, 509 insertions(+), 131 deletions(-) create mode 100644 third_party/triton/llvm_integration/cl643947742.patch create mode 100644 third_party/xla/third_party/triton/llvm_integration/cl643947742.patch diff --git a/tensorflow/core/transforms/functional_to_region/pass.cc b/tensorflow/core/transforms/functional_to_region/pass.cc index d98f7191b8792d..6d21eb179bc6b4 100644 --- a/tensorflow/core/transforms/functional_to_region/pass.cc +++ b/tensorflow/core/transforms/functional_to_region/pass.cc @@ -44,7 +44,8 @@ struct FunctionalToRegionPass // Use top-down traversal for more efficient conversion. Disable region // simplification as all regions are single block. config.useTopDownTraversal = true; - config.enableRegionSimplification = false; + config.enableRegionSimplification = + mlir::GreedySimplifyRegionLevel::Disabled; // If there are deeply nested conditionals, instantiating them too deep will // cause the verifiers, which are implemented recursively, to stack // overflow. Set a relatively low iteration limit. 
diff --git a/tensorflow/core/transforms/region_to_functional/pass.cc b/tensorflow/core/transforms/region_to_functional/pass.cc index 867fa8f4c4393d..62d7d5061a68af 100644 --- a/tensorflow/core/transforms/region_to_functional/pass.cc +++ b/tensorflow/core/transforms/region_to_functional/pass.cc @@ -48,7 +48,8 @@ struct RegionToFunctionalPass // Use top-down traversal for more efficient conversion. Disable region // simplification as all regions are single block. config.useTopDownTraversal = true; - config.enableRegionSimplification = false; + config.enableRegionSimplification = + mlir::GreedySimplifyRegionLevel::Disabled; // Iterate until all regions have been outlined. This is guaranteed to // terminate because the IR can only hold a finite depth of regions. config.maxIterations = GreedyRewriteConfig::kNoLimit; diff --git a/third_party/llvm/generated.patch b/third_party/llvm/generated.patch index f7ef792b9350c0..1c001c5121cf1a 100644 --- a/third_party/llvm/generated.patch +++ b/third_party/llvm/generated.patch @@ -1,122 +1,396 @@ Auto generated patch. Do not edit or delete it, even if empty. 
-diff -ruN --strip-trailing-cr a/libcxx/docs/Status/Cxx2cIssues.csv b/libcxx/docs/Status/Cxx2cIssues.csv ---- a/libcxx/docs/Status/Cxx2cIssues.csv -+++ b/libcxx/docs/Status/Cxx2cIssues.csv -@@ -65,5 +65,4 @@ - "`3343 `__","Ordering of calls to ``unlock()`` and ``notify_all()`` in Effects element of ``notify_all_at_thread_exit()`` should be reversed","Not Yet Adopted","|Complete|","16.0","" - "XXXX","","The sys_info range should be affected by save","Not Yet Adopted","|Complete|","19.0" - "`4071 `__","","``reference_wrapper`` comparisons are not SFINAE-friendly","Not Yet Adopted","|Complete|","19.0" --"`4110 `__","","``shared_ptr(nullptr_t, Deleter)`` is overconstrained, breaking some sensible deleters","Not Yet Adopted","|Complete|","19.0" - "","","","","","" -diff -ruN --strip-trailing-cr a/libcxx/include/__memory/shared_ptr.h b/libcxx/include/__memory/shared_ptr.h ---- a/libcxx/include/__memory/shared_ptr.h -+++ b/libcxx/include/__memory/shared_ptr.h -@@ -404,7 +404,7 @@ - }; - - template --using __shared_ptr_nullptr_deleter_ctor_reqs = _And, __well_formed_deleter<_Dp, _Tp*> >; -+using __shared_ptr_nullptr_deleter_ctor_reqs = _And, __well_formed_deleter<_Dp, nullptr_t> >; - - #if defined(_LIBCPP_ABI_ENABLE_SHARED_PTR_TRIVIAL_ABI) - # define _LIBCPP_SHARED_PTR_TRIVIAL_ABI __attribute__((__trivial_abi__)) -diff -ruN --strip-trailing-cr a/libcxx/test/std/utilities/memory/util.smartptr/util.smartptr.shared/util.smartptr.shared.const/nullptr_t_deleter_allocator.pass.cpp b/libcxx/test/std/utilities/memory/util.smartptr/util.smartptr.shared/util.smartptr.shared.const/nullptr_t_deleter_allocator.pass.cpp ---- a/libcxx/test/std/utilities/memory/util.smartptr/util.smartptr.shared/util.smartptr.shared.const/nullptr_t_deleter_allocator.pass.cpp -+++ b/libcxx/test/std/utilities/memory/util.smartptr/util.smartptr.shared/util.smartptr.shared.const/nullptr_t_deleter_allocator.pass.cpp -@@ -33,21 +33,17 @@ - // LWG 3233. 
Broken requirements for shared_ptr converting constructors - // https://cplusplus.github.io/LWG/issue3233 - static_assert( std::is_constructible, std::nullptr_t, test_deleter, test_allocator >::value, ""); --static_assert(!std::is_constructible, std::nullptr_t, bad_deleter, test_allocator >::value, -- ""); -+static_assert(!std::is_constructible, std::nullptr_t, bad_deleter, test_allocator >::value, ""); -+static_assert(!std::is_constructible, std::nullptr_t, no_nullptr_deleter, test_allocator >::value, ""); - static_assert(!std::is_constructible, std::nullptr_t, no_move_deleter, test_allocator >::value, ""); - - #if TEST_STD_VER >= 17 --static_assert( -- std::is_constructible, std::nullptr_t, test_deleter, test_allocator >::value, -- ""); -+static_assert( std::is_constructible, std::nullptr_t, test_deleter, test_allocator >::value, ""); - static_assert(!std::is_constructible, std::nullptr_t, bad_deleter, test_allocator >::value, ""); - static_assert(!std::is_constructible, std::nullptr_t, no_nullptr_deleter, test_allocator >::value, ""); - static_assert(!std::is_constructible, std::nullptr_t, no_move_deleter, test_allocator >::value, ""); - --static_assert( -- std::is_constructible, std::nullptr_t, test_deleter, test_allocator >::value, -- ""); -+static_assert( std::is_constructible, std::nullptr_t, test_deleter, test_allocator >::value, ""); - static_assert(!std::is_constructible, std::nullptr_t, bad_deleter, test_allocator >::value, ""); - static_assert(!std::is_constructible, std::nullptr_t, no_nullptr_deleter, test_allocator >::value, ""); - static_assert(!std::is_constructible, std::nullptr_t, no_move_deleter, test_allocator >::value, ""); -diff -ruN --strip-trailing-cr a/libcxx/test/std/utilities/memory/util.smartptr/util.smartptr.shared/util.smartptr.shared.const/nullptr_t_deleter.pass.cpp b/libcxx/test/std/utilities/memory/util.smartptr/util.smartptr.shared/util.smartptr.shared.const/nullptr_t_deleter.pass.cpp ---- 
a/libcxx/test/std/utilities/memory/util.smartptr/util.smartptr.shared/util.smartptr.shared.const/nullptr_t_deleter.pass.cpp -+++ b/libcxx/test/std/utilities/memory/util.smartptr/util.smartptr.shared/util.smartptr.shared.const/nullptr_t_deleter.pass.cpp -@@ -32,16 +32,17 @@ - // LWG 3233. Broken requirements for shared_ptr converting constructors - // https://cplusplus.github.io/LWG/issue3233 - static_assert( std::is_constructible, std::nullptr_t, test_deleter >::value, ""); --static_assert(!std::is_constructible, std::nullptr_t, bad_deleter>::value, ""); -+static_assert(!std::is_constructible, std::nullptr_t, bad_deleter>::value, ""); -+static_assert(!std::is_constructible, std::nullptr_t, no_nullptr_deleter>::value, ""); - static_assert(!std::is_constructible, std::nullptr_t, no_move_deleter>::value, ""); - - #if TEST_STD_VER >= 17 --static_assert(std::is_constructible, std::nullptr_t, test_deleter >::value, ""); -+static_assert( std::is_constructible, std::nullptr_t, test_deleter >::value, ""); - static_assert(!std::is_constructible, std::nullptr_t, bad_deleter>::value, ""); - static_assert(!std::is_constructible, std::nullptr_t, no_nullptr_deleter>::value, ""); - static_assert(!std::is_constructible, std::nullptr_t, no_move_deleter>::value, ""); - --static_assert(std::is_constructible, std::nullptr_t, test_deleter >::value, ""); -+static_assert( std::is_constructible, std::nullptr_t, test_deleter >::value, ""); - static_assert(!std::is_constructible, std::nullptr_t, bad_deleter>::value, ""); - static_assert(!std::is_constructible, std::nullptr_t, no_nullptr_deleter>::value, ""); - static_assert(!std::is_constructible, std::nullptr_t, no_move_deleter>::value, ""); -diff -ruN --strip-trailing-cr a/libcxx/test/std/utilities/memory/util.smartptr/util.smartptr.shared/util.smartptr.shared.const/pointer_deleter_allocator.pass.cpp b/libcxx/test/std/utilities/memory/util.smartptr/util.smartptr.shared/util.smartptr.shared.const/pointer_deleter_allocator.pass.cpp ---- 
a/libcxx/test/std/utilities/memory/util.smartptr/util.smartptr.shared/util.smartptr.shared.const/pointer_deleter_allocator.pass.cpp -+++ b/libcxx/test/std/utilities/memory/util.smartptr/util.smartptr.shared/util.smartptr.shared.const/pointer_deleter_allocator.pass.cpp -@@ -165,13 +165,5 @@ - test_allocator >::value, ""); - } - --#if TEST_STD_VER >= 14 -- { -- // LWG 4110 -- auto deleter = [](auto pointer) { delete pointer; }; -- std::shared_ptr p(new int, deleter, std::allocator()); -- } --#endif -- - return 0; - } -diff -ruN --strip-trailing-cr a/libcxx/test/std/utilities/memory/util.smartptr/util.smartptr.shared/util.smartptr.shared.const/pointer_deleter.pass.cpp b/libcxx/test/std/utilities/memory/util.smartptr/util.smartptr.shared/util.smartptr.shared.const/pointer_deleter.pass.cpp ---- a/libcxx/test/std/utilities/memory/util.smartptr/util.smartptr.shared/util.smartptr.shared.const/pointer_deleter.pass.cpp -+++ b/libcxx/test/std/utilities/memory/util.smartptr/util.smartptr.shared/util.smartptr.shared.const/pointer_deleter.pass.cpp -@@ -115,14 +115,6 @@ - } - #endif // TEST_STD_VER >= 11 - --#if TEST_STD_VER >= 14 -- { -- // LWG 4110 -- auto deleter = [](auto pointer) { delete pointer; }; -- std::shared_ptr p(new int, deleter); -- } --#endif -- - test_function_type(); - return 0; - } -diff -ruN --strip-trailing-cr a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel ---- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel -+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel -@@ -1051,7 +1051,7 @@ - ":CAPIDebugObjects", - ":CAPIIRObjects", - ":CAPIInterfacesObjects", -- ":CAPITransformObjects", -+ ":CAPITransformsObjects", - ], - ) - +diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/AMDGPU/abs_i16.ll b/llvm/test/CodeGen/AMDGPU/abs_i16.ll +--- a/llvm/test/CodeGen/AMDGPU/abs_i16.ll ++++ b/llvm/test/CodeGen/AMDGPU/abs_i16.ll +@@ -98,10 +98,10 @@ + ; GFX8-LABEL: v_abs_v2i16: + ; GFX8: ; %bb.0: + ; GFX8-NEXT: 
s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +-; GFX8-NEXT: v_lshrrev_b32_e32 v1, 16, v0 +-; GFX8-NEXT: v_sub_u16_e32 v2, 0, v1 +-; GFX8-NEXT: v_max_i16_sdwa v1, v1, v2 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:DWORD ++; GFX8-NEXT: v_mov_b32_e32 v1, 0 ++; GFX8-NEXT: v_sub_u16_sdwa v1, v1, v0 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:WORD_1 + ; GFX8-NEXT: v_sub_u16_e32 v2, 0, v0 ++; GFX8-NEXT: v_max_i16_sdwa v1, v0, v1 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:WORD_1 src1_sel:DWORD + ; GFX8-NEXT: v_max_i16_e32 v0, v0, v2 + ; GFX8-NEXT: v_or_b32_e32 v0, v0, v1 + ; GFX8-NEXT: s_setpc_b64 s[30:31] +@@ -181,12 +181,12 @@ + ; GFX8-LABEL: v_abs_v3i16: + ; GFX8: ; %bb.0: + ; GFX8-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +-; GFX8-NEXT: v_lshrrev_b32_e32 v2, 16, v0 +-; GFX8-NEXT: v_sub_u16_e32 v3, 0, v2 +-; GFX8-NEXT: v_max_i16_sdwa v2, v2, v3 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:DWORD ++; GFX8-NEXT: v_mov_b32_e32 v2, 0 + ; GFX8-NEXT: v_sub_u16_e32 v3, 0, v1 ++; GFX8-NEXT: v_sub_u16_sdwa v2, v2, v0 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:WORD_1 + ; GFX8-NEXT: v_max_i16_e32 v1, v1, v3 + ; GFX8-NEXT: v_sub_u16_e32 v3, 0, v0 ++; GFX8-NEXT: v_max_i16_sdwa v2, v0, v2 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:WORD_1 src1_sel:DWORD + ; GFX8-NEXT: v_max_i16_e32 v0, v0, v3 + ; GFX8-NEXT: v_or_b32_e32 v0, v0, v2 + ; GFX8-NEXT: s_setpc_b64 s[30:31] +@@ -286,18 +286,17 @@ + ; GFX8-LABEL: v_abs_v4i16: + ; GFX8: ; %bb.0: + ; GFX8-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +-; GFX8-NEXT: v_lshrrev_b32_e32 v2, 16, v1 +-; GFX8-NEXT: v_sub_u16_e32 v3, 0, v2 +-; GFX8-NEXT: v_max_i16_sdwa v2, v2, v3 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:DWORD +-; GFX8-NEXT: v_lshrrev_b32_e32 v3, 16, v0 +-; GFX8-NEXT: v_sub_u16_e32 v4, 0, v3 +-; GFX8-NEXT: v_max_i16_sdwa v3, v3, v4 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:DWORD ++; GFX8-NEXT: v_mov_b32_e32 v2, 0 ++; GFX8-NEXT: 
v_sub_u16_sdwa v3, v2, v1 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:WORD_1 ++; GFX8-NEXT: v_sub_u16_sdwa v2, v2, v0 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:WORD_1 + ; GFX8-NEXT: v_sub_u16_e32 v4, 0, v1 + ; GFX8-NEXT: v_sub_u16_e32 v5, 0, v0 ++; GFX8-NEXT: v_max_i16_sdwa v3, v1, v3 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:WORD_1 src1_sel:DWORD ++; GFX8-NEXT: v_max_i16_sdwa v2, v0, v2 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:WORD_1 src1_sel:DWORD + ; GFX8-NEXT: v_max_i16_e32 v0, v0, v5 + ; GFX8-NEXT: v_max_i16_e32 v1, v1, v4 +-; GFX8-NEXT: v_or_b32_e32 v0, v0, v3 +-; GFX8-NEXT: v_or_b32_e32 v1, v1, v2 ++; GFX8-NEXT: v_or_b32_e32 v0, v0, v2 ++; GFX8-NEXT: v_or_b32_e32 v1, v1, v3 + ; GFX8-NEXT: s_setpc_b64 s[30:31] + ; + ; GFX9-LABEL: v_abs_v4i16: +@@ -411,24 +410,22 @@ + ; GFX8-LABEL: v_abs_v6i16: + ; GFX8: ; %bb.0: + ; GFX8-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +-; GFX8-NEXT: v_lshrrev_b32_e32 v3, 16, v2 +-; GFX8-NEXT: v_sub_u16_e32 v4, 0, v3 +-; GFX8-NEXT: v_max_i16_sdwa v3, v3, v4 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:DWORD +-; GFX8-NEXT: v_lshrrev_b32_e32 v4, 16, v1 +-; GFX8-NEXT: v_sub_u16_e32 v5, 0, v4 +-; GFX8-NEXT: v_max_i16_sdwa v4, v4, v5 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:DWORD +-; GFX8-NEXT: v_lshrrev_b32_e32 v5, 16, v0 +-; GFX8-NEXT: v_sub_u16_e32 v6, 0, v5 +-; GFX8-NEXT: v_max_i16_sdwa v5, v5, v6 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:DWORD ++; GFX8-NEXT: v_mov_b32_e32 v3, 0 ++; GFX8-NEXT: v_sub_u16_sdwa v4, v3, v2 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:WORD_1 ++; GFX8-NEXT: v_sub_u16_sdwa v5, v3, v1 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:WORD_1 ++; GFX8-NEXT: v_sub_u16_sdwa v3, v3, v0 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:WORD_1 + ; GFX8-NEXT: v_sub_u16_e32 v6, 0, v2 + ; GFX8-NEXT: v_sub_u16_e32 v7, 0, v1 + ; GFX8-NEXT: v_sub_u16_e32 v8, 0, v0 ++; GFX8-NEXT: 
v_max_i16_sdwa v4, v2, v4 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:WORD_1 src1_sel:DWORD ++; GFX8-NEXT: v_max_i16_sdwa v5, v1, v5 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:WORD_1 src1_sel:DWORD ++; GFX8-NEXT: v_max_i16_sdwa v3, v0, v3 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:WORD_1 src1_sel:DWORD + ; GFX8-NEXT: v_max_i16_e32 v0, v0, v8 + ; GFX8-NEXT: v_max_i16_e32 v1, v1, v7 + ; GFX8-NEXT: v_max_i16_e32 v2, v2, v6 +-; GFX8-NEXT: v_or_b32_e32 v0, v0, v5 +-; GFX8-NEXT: v_or_b32_e32 v1, v1, v4 +-; GFX8-NEXT: v_or_b32_e32 v2, v2, v3 ++; GFX8-NEXT: v_or_b32_e32 v0, v0, v3 ++; GFX8-NEXT: v_or_b32_e32 v1, v1, v5 ++; GFX8-NEXT: v_or_b32_e32 v2, v2, v4 + ; GFX8-NEXT: s_setpc_b64 s[30:31] + ; + ; GFX9-LABEL: v_abs_v6i16: +@@ -572,30 +569,27 @@ + ; GFX8-LABEL: v_abs_v8i16: + ; GFX8: ; %bb.0: + ; GFX8-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +-; GFX8-NEXT: v_lshrrev_b32_e32 v4, 16, v3 +-; GFX8-NEXT: v_sub_u16_e32 v5, 0, v4 +-; GFX8-NEXT: v_max_i16_sdwa v4, v4, v5 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:DWORD +-; GFX8-NEXT: v_lshrrev_b32_e32 v5, 16, v2 +-; GFX8-NEXT: v_sub_u16_e32 v6, 0, v5 +-; GFX8-NEXT: v_max_i16_sdwa v5, v5, v6 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:DWORD +-; GFX8-NEXT: v_lshrrev_b32_e32 v6, 16, v1 +-; GFX8-NEXT: v_sub_u16_e32 v7, 0, v6 +-; GFX8-NEXT: v_max_i16_sdwa v6, v6, v7 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:DWORD +-; GFX8-NEXT: v_lshrrev_b32_e32 v7, 16, v0 +-; GFX8-NEXT: v_sub_u16_e32 v8, 0, v7 +-; GFX8-NEXT: v_max_i16_sdwa v7, v7, v8 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:DWORD ++; GFX8-NEXT: v_mov_b32_e32 v4, 0 ++; GFX8-NEXT: v_sub_u16_sdwa v5, v4, v3 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:WORD_1 ++; GFX8-NEXT: v_sub_u16_sdwa v6, v4, v2 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:WORD_1 ++; GFX8-NEXT: v_sub_u16_sdwa v7, v4, v1 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:WORD_1 ++; GFX8-NEXT: 
v_sub_u16_sdwa v4, v4, v0 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:WORD_1 + ; GFX8-NEXT: v_sub_u16_e32 v8, 0, v3 + ; GFX8-NEXT: v_sub_u16_e32 v9, 0, v2 + ; GFX8-NEXT: v_sub_u16_e32 v10, 0, v1 + ; GFX8-NEXT: v_sub_u16_e32 v11, 0, v0 ++; GFX8-NEXT: v_max_i16_sdwa v5, v3, v5 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:WORD_1 src1_sel:DWORD ++; GFX8-NEXT: v_max_i16_sdwa v6, v2, v6 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:WORD_1 src1_sel:DWORD ++; GFX8-NEXT: v_max_i16_sdwa v7, v1, v7 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:WORD_1 src1_sel:DWORD ++; GFX8-NEXT: v_max_i16_sdwa v4, v0, v4 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:WORD_1 src1_sel:DWORD + ; GFX8-NEXT: v_max_i16_e32 v0, v0, v11 + ; GFX8-NEXT: v_max_i16_e32 v1, v1, v10 + ; GFX8-NEXT: v_max_i16_e32 v2, v2, v9 + ; GFX8-NEXT: v_max_i16_e32 v3, v3, v8 +-; GFX8-NEXT: v_or_b32_e32 v0, v0, v7 +-; GFX8-NEXT: v_or_b32_e32 v1, v1, v6 +-; GFX8-NEXT: v_or_b32_e32 v2, v2, v5 +-; GFX8-NEXT: v_or_b32_e32 v3, v3, v4 ++; GFX8-NEXT: v_or_b32_e32 v0, v0, v4 ++; GFX8-NEXT: v_or_b32_e32 v1, v1, v7 ++; GFX8-NEXT: v_or_b32_e32 v2, v2, v6 ++; GFX8-NEXT: v_or_b32_e32 v3, v3, v5 + ; GFX8-NEXT: s_setpc_b64 s[30:31] + ; + ; GFX9-LABEL: v_abs_v8i16: +@@ -820,30 +814,15 @@ + ; GFX8-LABEL: v_abs_v16i16: + ; GFX8: ; %bb.0: + ; GFX8-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +-; GFX8-NEXT: v_lshrrev_b32_e32 v8, 16, v7 +-; GFX8-NEXT: v_sub_u16_e32 v9, 0, v8 +-; GFX8-NEXT: v_max_i16_sdwa v8, v8, v9 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:DWORD +-; GFX8-NEXT: v_lshrrev_b32_e32 v9, 16, v6 +-; GFX8-NEXT: v_sub_u16_e32 v10, 0, v9 +-; GFX8-NEXT: v_max_i16_sdwa v9, v9, v10 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:DWORD +-; GFX8-NEXT: v_lshrrev_b32_e32 v10, 16, v5 +-; GFX8-NEXT: v_sub_u16_e32 v11, 0, v10 +-; GFX8-NEXT: v_max_i16_sdwa v10, v10, v11 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:DWORD +-; GFX8-NEXT: v_lshrrev_b32_e32 v11, 16, v4 +-; GFX8-NEXT: 
v_sub_u16_e32 v12, 0, v11 +-; GFX8-NEXT: v_max_i16_sdwa v11, v11, v12 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:DWORD +-; GFX8-NEXT: v_lshrrev_b32_e32 v12, 16, v3 +-; GFX8-NEXT: v_sub_u16_e32 v13, 0, v12 +-; GFX8-NEXT: v_max_i16_sdwa v12, v12, v13 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:DWORD +-; GFX8-NEXT: v_lshrrev_b32_e32 v13, 16, v2 +-; GFX8-NEXT: v_sub_u16_e32 v14, 0, v13 +-; GFX8-NEXT: v_max_i16_sdwa v13, v13, v14 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:DWORD +-; GFX8-NEXT: v_lshrrev_b32_e32 v14, 16, v1 +-; GFX8-NEXT: v_sub_u16_e32 v15, 0, v14 +-; GFX8-NEXT: v_max_i16_sdwa v14, v14, v15 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:DWORD +-; GFX8-NEXT: v_lshrrev_b32_e32 v15, 16, v0 +-; GFX8-NEXT: v_sub_u16_e32 v16, 0, v15 +-; GFX8-NEXT: v_max_i16_sdwa v15, v15, v16 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:DWORD ++; GFX8-NEXT: v_mov_b32_e32 v8, 0 ++; GFX8-NEXT: v_sub_u16_sdwa v9, v8, v7 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:WORD_1 ++; GFX8-NEXT: v_sub_u16_sdwa v10, v8, v6 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:WORD_1 ++; GFX8-NEXT: v_sub_u16_sdwa v11, v8, v5 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:WORD_1 ++; GFX8-NEXT: v_sub_u16_sdwa v12, v8, v4 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:WORD_1 ++; GFX8-NEXT: v_sub_u16_sdwa v13, v8, v3 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:WORD_1 ++; GFX8-NEXT: v_sub_u16_sdwa v14, v8, v2 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:WORD_1 ++; GFX8-NEXT: v_sub_u16_sdwa v15, v8, v1 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:WORD_1 ++; GFX8-NEXT: v_sub_u16_sdwa v8, v8, v0 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:WORD_1 + ; GFX8-NEXT: v_sub_u16_e32 v16, 0, v7 + ; GFX8-NEXT: v_sub_u16_e32 v17, 0, v6 + ; GFX8-NEXT: v_sub_u16_e32 v18, 0, v5 +@@ -852,6 +831,14 @@ + ; GFX8-NEXT: 
v_sub_u16_e32 v21, 0, v2 + ; GFX8-NEXT: v_sub_u16_e32 v22, 0, v1 + ; GFX8-NEXT: v_sub_u16_e32 v23, 0, v0 ++; GFX8-NEXT: v_max_i16_sdwa v9, v7, v9 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:WORD_1 src1_sel:DWORD ++; GFX8-NEXT: v_max_i16_sdwa v10, v6, v10 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:WORD_1 src1_sel:DWORD ++; GFX8-NEXT: v_max_i16_sdwa v11, v5, v11 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:WORD_1 src1_sel:DWORD ++; GFX8-NEXT: v_max_i16_sdwa v12, v4, v12 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:WORD_1 src1_sel:DWORD ++; GFX8-NEXT: v_max_i16_sdwa v13, v3, v13 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:WORD_1 src1_sel:DWORD ++; GFX8-NEXT: v_max_i16_sdwa v14, v2, v14 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:WORD_1 src1_sel:DWORD ++; GFX8-NEXT: v_max_i16_sdwa v15, v1, v15 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:WORD_1 src1_sel:DWORD ++; GFX8-NEXT: v_max_i16_sdwa v8, v0, v8 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:WORD_1 src1_sel:DWORD + ; GFX8-NEXT: v_max_i16_e32 v0, v0, v23 + ; GFX8-NEXT: v_max_i16_e32 v1, v1, v22 + ; GFX8-NEXT: v_max_i16_e32 v2, v2, v21 +@@ -860,14 +847,14 @@ + ; GFX8-NEXT: v_max_i16_e32 v5, v5, v18 + ; GFX8-NEXT: v_max_i16_e32 v6, v6, v17 + ; GFX8-NEXT: v_max_i16_e32 v7, v7, v16 +-; GFX8-NEXT: v_or_b32_e32 v0, v0, v15 +-; GFX8-NEXT: v_or_b32_e32 v1, v1, v14 +-; GFX8-NEXT: v_or_b32_e32 v2, v2, v13 +-; GFX8-NEXT: v_or_b32_e32 v3, v3, v12 +-; GFX8-NEXT: v_or_b32_e32 v4, v4, v11 +-; GFX8-NEXT: v_or_b32_e32 v5, v5, v10 +-; GFX8-NEXT: v_or_b32_e32 v6, v6, v9 +-; GFX8-NEXT: v_or_b32_e32 v7, v7, v8 ++; GFX8-NEXT: v_or_b32_e32 v0, v0, v8 ++; GFX8-NEXT: v_or_b32_e32 v1, v1, v15 ++; GFX8-NEXT: v_or_b32_e32 v2, v2, v14 ++; GFX8-NEXT: v_or_b32_e32 v3, v3, v13 ++; GFX8-NEXT: v_or_b32_e32 v4, v4, v12 ++; GFX8-NEXT: v_or_b32_e32 v5, v5, v11 ++; GFX8-NEXT: v_or_b32_e32 v6, v6, v10 ++; GFX8-NEXT: v_or_b32_e32 v7, v7, v9 + ; GFX8-NEXT: s_setpc_b64 s[30:31] + ; + ; GFX9-LABEL: v_abs_v16i16: +@@ -1267,102 +1254,87 @@ + ; 
GFX8-LABEL: v_abs_v32i16: + ; GFX8: ; %bb.0: + ; GFX8-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +-; GFX8-NEXT: v_lshrrev_b32_e32 v16, 16, v15 +-; GFX8-NEXT: v_sub_u16_e32 v17, 0, v16 +-; GFX8-NEXT: v_max_i16_sdwa v16, v16, v17 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:DWORD +-; GFX8-NEXT: v_lshrrev_b32_e32 v17, 16, v14 +-; GFX8-NEXT: v_sub_u16_e32 v18, 0, v17 +-; GFX8-NEXT: v_max_i16_sdwa v17, v17, v18 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:DWORD +-; GFX8-NEXT: v_lshrrev_b32_e32 v18, 16, v13 +-; GFX8-NEXT: v_sub_u16_e32 v19, 0, v18 +-; GFX8-NEXT: v_max_i16_sdwa v18, v18, v19 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:DWORD +-; GFX8-NEXT: v_lshrrev_b32_e32 v19, 16, v12 +-; GFX8-NEXT: v_sub_u16_e32 v20, 0, v19 +-; GFX8-NEXT: v_max_i16_sdwa v19, v19, v20 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:DWORD +-; GFX8-NEXT: v_lshrrev_b32_e32 v20, 16, v11 +-; GFX8-NEXT: v_sub_u16_e32 v21, 0, v20 +-; GFX8-NEXT: v_max_i16_sdwa v20, v20, v21 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:DWORD +-; GFX8-NEXT: v_lshrrev_b32_e32 v21, 16, v10 +-; GFX8-NEXT: v_sub_u16_e32 v22, 0, v21 +-; GFX8-NEXT: v_max_i16_sdwa v21, v21, v22 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:DWORD +-; GFX8-NEXT: v_lshrrev_b32_e32 v22, 16, v9 +-; GFX8-NEXT: v_sub_u16_e32 v23, 0, v22 +-; GFX8-NEXT: v_max_i16_sdwa v22, v22, v23 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:DWORD +-; GFX8-NEXT: v_lshrrev_b32_e32 v23, 16, v8 +-; GFX8-NEXT: v_sub_u16_e32 v24, 0, v23 +-; GFX8-NEXT: v_max_i16_sdwa v23, v23, v24 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:DWORD +-; GFX8-NEXT: v_lshrrev_b32_e32 v24, 16, v7 +-; GFX8-NEXT: v_sub_u16_e32 v25, 0, v24 +-; GFX8-NEXT: v_max_i16_sdwa v24, v24, v25 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:DWORD +-; GFX8-NEXT: v_lshrrev_b32_e32 v25, 16, v6 +-; GFX8-NEXT: v_sub_u16_e32 v26, 0, v25 +-; GFX8-NEXT: v_max_i16_sdwa v25, 
v25, v26 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:DWORD +-; GFX8-NEXT: v_lshrrev_b32_e32 v26, 16, v5 +-; GFX8-NEXT: v_sub_u16_e32 v27, 0, v26 +-; GFX8-NEXT: v_max_i16_sdwa v26, v26, v27 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:DWORD +-; GFX8-NEXT: v_lshrrev_b32_e32 v27, 16, v4 +-; GFX8-NEXT: v_sub_u16_e32 v28, 0, v27 +-; GFX8-NEXT: v_max_i16_sdwa v27, v27, v28 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:DWORD +-; GFX8-NEXT: v_lshrrev_b32_e32 v28, 16, v3 +-; GFX8-NEXT: v_sub_u16_e32 v29, 0, v28 +-; GFX8-NEXT: v_max_i16_sdwa v28, v28, v29 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:DWORD +-; GFX8-NEXT: v_lshrrev_b32_e32 v29, 16, v2 +-; GFX8-NEXT: v_sub_u16_e32 v30, 0, v29 +-; GFX8-NEXT: v_max_i16_sdwa v29, v29, v30 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:DWORD +-; GFX8-NEXT: v_lshrrev_b32_e32 v30, 16, v1 +-; GFX8-NEXT: v_sub_u16_e32 v31, 0, v30 +-; GFX8-NEXT: v_max_i16_sdwa v30, v30, v31 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:DWORD +-; GFX8-NEXT: v_lshrrev_b32_e32 v31, 16, v0 +-; GFX8-NEXT: v_sub_u16_e32 v32, 0, v31 +-; GFX8-NEXT: v_max_i16_sdwa v31, v31, v32 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:DWORD +-; GFX8-NEXT: v_sub_u16_e32 v32, 0, v0 +-; GFX8-NEXT: v_max_i16_e32 v0, v0, v32 +-; GFX8-NEXT: v_or_b32_e32 v0, v0, v31 +-; GFX8-NEXT: v_sub_u16_e32 v31, 0, v1 +-; GFX8-NEXT: v_max_i16_e32 v1, v1, v31 +-; GFX8-NEXT: v_or_b32_e32 v1, v1, v30 +-; GFX8-NEXT: v_sub_u16_e32 v30, 0, v2 +-; GFX8-NEXT: v_max_i16_e32 v2, v2, v30 +-; GFX8-NEXT: v_or_b32_e32 v2, v2, v29 +-; GFX8-NEXT: v_sub_u16_e32 v29, 0, v3 +-; GFX8-NEXT: v_max_i16_e32 v3, v3, v29 +-; GFX8-NEXT: v_or_b32_e32 v3, v3, v28 +-; GFX8-NEXT: v_sub_u16_e32 v28, 0, v4 +-; GFX8-NEXT: v_max_i16_e32 v4, v4, v28 +-; GFX8-NEXT: v_or_b32_e32 v4, v4, v27 +-; GFX8-NEXT: v_sub_u16_e32 v27, 0, v5 +-; GFX8-NEXT: v_max_i16_e32 v5, v5, v27 +-; GFX8-NEXT: v_or_b32_e32 v5, v5, v26 +-; GFX8-NEXT: 
v_sub_u16_e32 v26, 0, v6 +-; GFX8-NEXT: v_max_i16_e32 v6, v6, v26 +-; GFX8-NEXT: v_or_b32_e32 v6, v6, v25 +-; GFX8-NEXT: v_sub_u16_e32 v25, 0, v7 +-; GFX8-NEXT: v_max_i16_e32 v7, v7, v25 +-; GFX8-NEXT: v_or_b32_e32 v7, v7, v24 +-; GFX8-NEXT: v_sub_u16_e32 v24, 0, v8 +-; GFX8-NEXT: v_max_i16_e32 v8, v8, v24 +-; GFX8-NEXT: v_or_b32_e32 v8, v8, v23 +-; GFX8-NEXT: v_sub_u16_e32 v23, 0, v9 +-; GFX8-NEXT: v_max_i16_e32 v9, v9, v23 +-; GFX8-NEXT: v_or_b32_e32 v9, v9, v22 +-; GFX8-NEXT: v_sub_u16_e32 v22, 0, v10 +-; GFX8-NEXT: v_max_i16_e32 v10, v10, v22 +-; GFX8-NEXT: v_or_b32_e32 v10, v10, v21 +-; GFX8-NEXT: v_sub_u16_e32 v21, 0, v11 +-; GFX8-NEXT: v_max_i16_e32 v11, v11, v21 ++; GFX8-NEXT: v_mov_b32_e32 v16, 0 ++; GFX8-NEXT: v_sub_u16_sdwa v19, v16, v0 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:WORD_1 ++; GFX8-NEXT: v_sub_u16_e32 v20, 0, v0 ++; GFX8-NEXT: v_max_i16_sdwa v19, v0, v19 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:WORD_1 src1_sel:DWORD ++; GFX8-NEXT: v_max_i16_e32 v0, v0, v20 ++; GFX8-NEXT: v_sub_u16_sdwa v20, v16, v1 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:WORD_1 ++; GFX8-NEXT: v_or_b32_e32 v0, v0, v19 ++; GFX8-NEXT: v_sub_u16_e32 v19, 0, v1 ++; GFX8-NEXT: v_max_i16_sdwa v20, v1, v20 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:WORD_1 src1_sel:DWORD ++; GFX8-NEXT: v_max_i16_e32 v1, v1, v19 ++; GFX8-NEXT: v_sub_u16_sdwa v19, v16, v2 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:WORD_1 ++; GFX8-NEXT: v_or_b32_e32 v1, v1, v20 ++; GFX8-NEXT: v_sub_u16_e32 v20, 0, v2 ++; GFX8-NEXT: v_max_i16_sdwa v19, v2, v19 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:WORD_1 src1_sel:DWORD ++; GFX8-NEXT: v_max_i16_e32 v2, v2, v20 ++; GFX8-NEXT: v_sub_u16_sdwa v20, v16, v3 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:WORD_1 ++; GFX8-NEXT: v_or_b32_e32 v2, v2, v19 ++; GFX8-NEXT: v_sub_u16_e32 v19, 0, v3 ++; GFX8-NEXT: v_max_i16_sdwa v20, v3, v20 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:WORD_1 
src1_sel:DWORD ++; GFX8-NEXT: v_max_i16_e32 v3, v3, v19 ++; GFX8-NEXT: v_sub_u16_sdwa v19, v16, v4 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:WORD_1 ++; GFX8-NEXT: v_or_b32_e32 v3, v3, v20 ++; GFX8-NEXT: v_sub_u16_e32 v20, 0, v4 ++; GFX8-NEXT: v_max_i16_sdwa v19, v4, v19 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:WORD_1 src1_sel:DWORD ++; GFX8-NEXT: v_max_i16_e32 v4, v4, v20 ++; GFX8-NEXT: v_sub_u16_sdwa v20, v16, v5 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:WORD_1 ++; GFX8-NEXT: v_or_b32_e32 v4, v4, v19 ++; GFX8-NEXT: v_sub_u16_e32 v19, 0, v5 ++; GFX8-NEXT: v_max_i16_sdwa v20, v5, v20 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:WORD_1 src1_sel:DWORD ++; GFX8-NEXT: v_max_i16_e32 v5, v5, v19 ++; GFX8-NEXT: v_sub_u16_sdwa v19, v16, v6 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:WORD_1 ++; GFX8-NEXT: v_or_b32_e32 v5, v5, v20 ++; GFX8-NEXT: v_sub_u16_e32 v20, 0, v6 ++; GFX8-NEXT: v_max_i16_sdwa v19, v6, v19 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:WORD_1 src1_sel:DWORD ++; GFX8-NEXT: v_max_i16_e32 v6, v6, v20 ++; GFX8-NEXT: v_sub_u16_sdwa v20, v16, v7 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:WORD_1 ++; GFX8-NEXT: v_or_b32_e32 v6, v6, v19 ++; GFX8-NEXT: v_sub_u16_e32 v19, 0, v7 ++; GFX8-NEXT: v_max_i16_sdwa v20, v7, v20 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:WORD_1 src1_sel:DWORD ++; GFX8-NEXT: v_max_i16_e32 v7, v7, v19 ++; GFX8-NEXT: v_sub_u16_sdwa v19, v16, v8 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:WORD_1 ++; GFX8-NEXT: v_or_b32_e32 v7, v7, v20 ++; GFX8-NEXT: v_sub_u16_e32 v20, 0, v8 ++; GFX8-NEXT: v_max_i16_sdwa v19, v8, v19 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:WORD_1 src1_sel:DWORD ++; GFX8-NEXT: v_max_i16_e32 v8, v8, v20 ++; GFX8-NEXT: v_sub_u16_sdwa v20, v16, v9 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:WORD_1 ++; GFX8-NEXT: v_or_b32_e32 v8, v8, v19 ++; GFX8-NEXT: v_sub_u16_e32 v19, 0, v9 ++; GFX8-NEXT: v_max_i16_sdwa v20, v9, v20 
dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:WORD_1 src1_sel:DWORD ++; GFX8-NEXT: v_max_i16_e32 v9, v9, v19 ++; GFX8-NEXT: v_sub_u16_sdwa v19, v16, v10 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:WORD_1 ++; GFX8-NEXT: v_or_b32_e32 v9, v9, v20 ++; GFX8-NEXT: v_sub_u16_e32 v20, 0, v10 ++; GFX8-NEXT: v_max_i16_sdwa v19, v10, v19 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:WORD_1 src1_sel:DWORD ++; GFX8-NEXT: v_max_i16_e32 v10, v10, v20 ++; GFX8-NEXT: v_sub_u16_sdwa v20, v16, v11 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:WORD_1 ++; GFX8-NEXT: v_or_b32_e32 v10, v10, v19 ++; GFX8-NEXT: v_sub_u16_e32 v19, 0, v11 ++; GFX8-NEXT: v_max_i16_sdwa v20, v11, v20 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:WORD_1 src1_sel:DWORD ++; GFX8-NEXT: v_max_i16_e32 v11, v11, v19 ++; GFX8-NEXT: v_sub_u16_sdwa v17, v16, v15 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:WORD_1 ++; GFX8-NEXT: v_sub_u16_sdwa v18, v16, v14 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:WORD_1 ++; GFX8-NEXT: v_sub_u16_sdwa v19, v16, v13 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:WORD_1 ++; GFX8-NEXT: v_sub_u16_sdwa v16, v16, v12 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:WORD_1 + ; GFX8-NEXT: v_or_b32_e32 v11, v11, v20 + ; GFX8-NEXT: v_sub_u16_e32 v20, 0, v12 ++; GFX8-NEXT: v_max_i16_sdwa v16, v12, v16 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:WORD_1 src1_sel:DWORD + ; GFX8-NEXT: v_max_i16_e32 v12, v12, v20 +-; GFX8-NEXT: v_or_b32_e32 v12, v12, v19 +-; GFX8-NEXT: v_sub_u16_e32 v19, 0, v13 ++; GFX8-NEXT: v_or_b32_e32 v12, v12, v16 ++; GFX8-NEXT: v_sub_u16_e32 v16, 0, v13 ++; GFX8-NEXT: v_max_i16_sdwa v19, v13, v19 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:WORD_1 src1_sel:DWORD + ; GFX8-NEXT: v_sub_u16_e32 v20, 0, v15 +-; GFX8-NEXT: v_max_i16_e32 v13, v13, v19 +-; GFX8-NEXT: v_sub_u16_e32 v19, 0, v14 +-; GFX8-NEXT: v_max_i16_e32 v14, v14, v19 ++; GFX8-NEXT: v_max_i16_e32 v13, v13, v16 ++; GFX8-NEXT: 
v_sub_u16_e32 v16, 0, v14 ++; GFX8-NEXT: v_max_i16_sdwa v17, v15, v17 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:WORD_1 src1_sel:DWORD ++; GFX8-NEXT: v_max_i16_sdwa v18, v14, v18 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:WORD_1 src1_sel:DWORD ++; GFX8-NEXT: v_max_i16_e32 v14, v14, v16 + ; GFX8-NEXT: v_max_i16_e32 v15, v15, v20 +-; GFX8-NEXT: v_or_b32_e32 v13, v13, v18 +-; GFX8-NEXT: v_or_b32_e32 v14, v14, v17 +-; GFX8-NEXT: v_or_b32_e32 v15, v15, v16 ++; GFX8-NEXT: v_or_b32_e32 v13, v13, v19 ++; GFX8-NEXT: v_or_b32_e32 v14, v14, v18 ++; GFX8-NEXT: v_or_b32_e32 v15, v15, v17 + ; GFX8-NEXT: s_setpc_b64 s[30:31] + ; + ; GFX9-LABEL: v_abs_v32i16: diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl index b37e08b1de85ff..fc1cf70ed11f47 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" - LLVM_COMMIT = "e83adfe59632d2e2f8ff26db33087ba7fb754485" - LLVM_SHA256 = "5730ac18a80109d6189a34e212f61f2b0832b91b8cc175c3ebd0aa063db85f59" + LLVM_COMMIT = "93ffe1792fd9a985b96fee1105b399b5196a15bc" + LLVM_SHA256 = "f92f48acaf9a6543097dc50acb761584ca3da3f2a137d7ad32aa92650b9ea932" tf_http_archive( name = name, diff --git a/third_party/stablehlo/temporary.patch b/third_party/stablehlo/temporary.patch index dd093a981a47dd..1db9bf674f5b38 100755 --- a/third_party/stablehlo/temporary.patch +++ b/third_party/stablehlo/temporary.patch @@ -2442,7 +2442,7 @@ diff --ruN a/stablehlo/stablehlo/experimental/transforms/ChloRecomposeOps.cpp b/ + // Do a single traversal to recompose CustomCallOp to CHLO ops. 
+ GreedyRewriteConfig config; + config.useTopDownTraversal = true; -+ config.enableRegionSimplification = true; ++ config.enableRegionSimplification = GreedySimplifyRegionLevel::Aggressive; + config.maxIterations = 1; + config.maxNumRewrites = GreedyRewriteConfig::kNoLimit; + config.strictMode = GreedyRewriteStrictness::ExistingOps; @@ -2754,7 +2754,7 @@ diff --ruN a/stablehlo/stablehlo/experimental/transforms/StablehloCanonicalizeDy + void runOnOperation() override { + GreedyRewriteConfig config; + config.useTopDownTraversal = true; -+ config.enableRegionSimplification = true; ++ config.enableRegionSimplification = GreedySimplifyRegionLevel::Aggressive; + config.maxIterations = 2; + config.maxNumRewrites = GreedyRewriteConfig::kNoLimit; + config.strictMode = GreedyRewriteStrictness::AnyOp; @@ -2930,7 +2930,7 @@ diff --ruN a/stablehlo/stablehlo/experimental/transforms/StablehloRefineShapes.c + // upstream, and that might be the reason. + GreedyRewriteConfig config; + config.useTopDownTraversal = true; -+ config.enableRegionSimplification = true; ++ config.enableRegionSimplification = GreedySimplifyRegionLevel::Aggressive; + config.maxIterations = 3; + config.maxNumRewrites = GreedyRewriteConfig::kNoLimit; + config.strictMode = GreedyRewriteStrictness::AnyOp; @@ -2955,4 +2955,41 @@ diff --ruN a/stablehlo/stablehlo/experimental/transforms/StablehloRefineShapes.c +} // namespace experimental +} // namespace stablehlo +} // namespace mlir +diff --ruN a/stablehlo/stablehlo/tests/TestUtils.cpp b/stablehlo/stablehlo/tests/TestUtils.cpp +--- stablehlo/stablehlo/tests/TestUtils.cpp ++++ stablehlo/stablehlo/tests/TestUtils.cpp +@@ -176,7 +176,8 @@ + GreedyRewriteConfig config; + config.maxIterations = 1; + config.useTopDownTraversal = true; +- config.enableRegionSimplification = false; ++ config.enableRegionSimplification = ++ mlir::GreedySimplifyRegionLevel::Disabled; + (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + } + +diff --ruN 
a/stablehlo/stablehlo/transforms/StablehloCanonicalizeDynamism.cpp b/stablehlo/stablehlo/transforms/StablehloCanonicalizeDynamism.cpp +--- stablehlo/stablehlo/transforms/StablehloCanonicalizeDynamism.cpp ++++ stablehlo/stablehlo/transforms/StablehloCanonicalizeDynamism.cpp +@@ -300,7 +300,7 @@ + + LogicalResult initialize(MLIRContext* context) override { + config.useTopDownTraversal = true; +- config.enableRegionSimplification = true; ++ config.enableRegionSimplification = GreedySimplifyRegionLevel::Aggressive; + config.maxIterations = 2; + config.maxNumRewrites = GreedyRewriteConfig::kNoLimit; + config.strictMode = GreedyRewriteStrictness::AnyOp; +diff --ruN a/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp b/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp +--- stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp ++++ stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp +@@ -728,7 +728,7 @@ + // There have been recent refactors to applyPatternsAndFoldGreedily + // upstream, and that might be the reason. 
+ config.useTopDownTraversal = true; +- config.enableRegionSimplification = true; ++ config.enableRegionSimplification = GreedySimplifyRegionLevel::Aggressive; + config.maxIterations = 2; + config.maxNumRewrites = GreedyRewriteConfig::kNoLimit; + config.strictMode = GreedyRewriteStrictness::AnyOp; diff --git a/third_party/triton/llvm_integration/cl643947742.patch b/third_party/triton/llvm_integration/cl643947742.patch new file mode 100644 index 00000000000000..4715517a5272f3 --- /dev/null +++ b/third_party/triton/llvm_integration/cl643947742.patch @@ -0,0 +1,13 @@ +==== triton/test/Triton/reproducer.mlir#3 - triton/test/Triton/reproducer.mlir ==== +# action=edit type=text +--- triton/test/Triton/reproducer.mlir 2024-05-14 06:33:36.000000000 -0700 ++++ triton/test/Triton/reproducer.mlir 2024-06-17 08:37:24.000000000 -0700 +@@ -9,7 +9,7 @@ + {-# + external_resources: { + mlir_reproducer: { +- pipeline: "builtin.module(any(convert-scf-to-cf,convert-index-to-llvm{index-bitwidth=0},convert-triton-gpu-to-llvm{compute-capability=90},convert-nv-gpu-to-llvm,convert-arith-to-llvm{index-bitwidth=0},canonicalize{ max-iterations=10 max-num-rewrites=-1 region-simplify=true test-convergence=false top-down=true},cse,symbol-dce,enable-line-info))", ++ pipeline: "builtin.module(any(convert-scf-to-cf,convert-index-to-llvm{index-bitwidth=0},convert-triton-gpu-to-llvm{compute-capability=90},convert-nv-gpu-to-llvm,convert-arith-to-llvm{index-bitwidth=0},canonicalize{ max-iterations=10 max-num-rewrites=-1 region-simplify=aggressive test-convergence=false top-down=true},cse,symbol-dce,enable-line-info))", + disable_threading: false, + verify_each: false + } diff --git a/third_party/triton/llvm_integration/series.bzl b/third_party/triton/llvm_integration/series.bzl index 4425b4e71caf27..4610d5289abdd3 100644 --- a/third_party/triton/llvm_integration/series.bzl +++ b/third_party/triton/llvm_integration/series.bzl @@ -6,4 +6,5 @@ These should be upstreamed to openai/triton as part of the 
next triton integrati llvm_patch_list = [ "//third_party/triton/llvm_integration:cl642434908.patch", + "//third_party/triton/llvm_integration:cl643947742.patch", ] diff --git a/third_party/xla/third_party/stablehlo/temporary.patch b/third_party/xla/third_party/stablehlo/temporary.patch index dd093a981a47dd..1db9bf674f5b38 100755 --- a/third_party/xla/third_party/stablehlo/temporary.patch +++ b/third_party/xla/third_party/stablehlo/temporary.patch @@ -2442,7 +2442,7 @@ diff --ruN a/stablehlo/stablehlo/experimental/transforms/ChloRecomposeOps.cpp b/ + // Do a single traversal to recompose CustomCallOp to CHLO ops. + GreedyRewriteConfig config; + config.useTopDownTraversal = true; -+ config.enableRegionSimplification = true; ++ config.enableRegionSimplification = GreedySimplifyRegionLevel::Aggressive; + config.maxIterations = 1; + config.maxNumRewrites = GreedyRewriteConfig::kNoLimit; + config.strictMode = GreedyRewriteStrictness::ExistingOps; @@ -2754,7 +2754,7 @@ diff --ruN a/stablehlo/stablehlo/experimental/transforms/StablehloCanonicalizeDy + void runOnOperation() override { + GreedyRewriteConfig config; + config.useTopDownTraversal = true; -+ config.enableRegionSimplification = true; ++ config.enableRegionSimplification = GreedySimplifyRegionLevel::Aggressive; + config.maxIterations = 2; + config.maxNumRewrites = GreedyRewriteConfig::kNoLimit; + config.strictMode = GreedyRewriteStrictness::AnyOp; @@ -2930,7 +2930,7 @@ diff --ruN a/stablehlo/stablehlo/experimental/transforms/StablehloRefineShapes.c + // upstream, and that might be the reason. 
+ GreedyRewriteConfig config; + config.useTopDownTraversal = true; -+ config.enableRegionSimplification = true; ++ config.enableRegionSimplification = GreedySimplifyRegionLevel::Aggressive; + config.maxIterations = 3; + config.maxNumRewrites = GreedyRewriteConfig::kNoLimit; + config.strictMode = GreedyRewriteStrictness::AnyOp; @@ -2955,4 +2955,41 @@ diff --ruN a/stablehlo/stablehlo/experimental/transforms/StablehloRefineShapes.c +} // namespace experimental +} // namespace stablehlo +} // namespace mlir +diff --ruN a/stablehlo/stablehlo/tests/TestUtils.cpp b/stablehlo/stablehlo/tests/TestUtils.cpp +--- stablehlo/stablehlo/tests/TestUtils.cpp ++++ stablehlo/stablehlo/tests/TestUtils.cpp +@@ -176,7 +176,8 @@ + GreedyRewriteConfig config; + config.maxIterations = 1; + config.useTopDownTraversal = true; +- config.enableRegionSimplification = false; ++ config.enableRegionSimplification = ++ mlir::GreedySimplifyRegionLevel::Disabled; + (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + } + +diff --ruN a/stablehlo/stablehlo/transforms/StablehloCanonicalizeDynamism.cpp b/stablehlo/stablehlo/transforms/StablehloCanonicalizeDynamism.cpp +--- stablehlo/stablehlo/transforms/StablehloCanonicalizeDynamism.cpp ++++ stablehlo/stablehlo/transforms/StablehloCanonicalizeDynamism.cpp +@@ -300,7 +300,7 @@ + + LogicalResult initialize(MLIRContext* context) override { + config.useTopDownTraversal = true; +- config.enableRegionSimplification = true; ++ config.enableRegionSimplification = GreedySimplifyRegionLevel::Aggressive; + config.maxIterations = 2; + config.maxNumRewrites = GreedyRewriteConfig::kNoLimit; + config.strictMode = GreedyRewriteStrictness::AnyOp; +diff --ruN a/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp b/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp +--- stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp ++++ stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp +@@ -728,7 +728,7 @@ + // There have been recent 
refactors to applyPatternsAndFoldGreedily + // upstream, and that might be the reason. + config.useTopDownTraversal = true; +- config.enableRegionSimplification = true; ++ config.enableRegionSimplification = GreedySimplifyRegionLevel::Aggressive; + config.maxIterations = 2; + config.maxNumRewrites = GreedyRewriteConfig::kNoLimit; + config.strictMode = GreedyRewriteStrictness::AnyOp; diff --git a/third_party/xla/third_party/triton/llvm_integration/cl643947742.patch b/third_party/xla/third_party/triton/llvm_integration/cl643947742.patch new file mode 100644 index 00000000000000..4715517a5272f3 --- /dev/null +++ b/third_party/xla/third_party/triton/llvm_integration/cl643947742.patch @@ -0,0 +1,13 @@ +==== triton/test/Triton/reproducer.mlir#3 - triton/test/Triton/reproducer.mlir ==== +# action=edit type=text +--- triton/test/Triton/reproducer.mlir 2024-05-14 06:33:36.000000000 -0700 ++++ triton/test/Triton/reproducer.mlir 2024-06-17 08:37:24.000000000 -0700 +@@ -9,7 +9,7 @@ + {-# + external_resources: { + mlir_reproducer: { +- pipeline: "builtin.module(any(convert-scf-to-cf,convert-index-to-llvm{index-bitwidth=0},convert-triton-gpu-to-llvm{compute-capability=90},convert-nv-gpu-to-llvm,convert-arith-to-llvm{index-bitwidth=0},canonicalize{ max-iterations=10 max-num-rewrites=-1 region-simplify=true test-convergence=false top-down=true},cse,symbol-dce,enable-line-info))", ++ pipeline: "builtin.module(any(convert-scf-to-cf,convert-index-to-llvm{index-bitwidth=0},convert-triton-gpu-to-llvm{compute-capability=90},convert-nv-gpu-to-llvm,convert-arith-to-llvm{index-bitwidth=0},canonicalize{ max-iterations=10 max-num-rewrites=-1 region-simplify=aggressive test-convergence=false top-down=true},cse,symbol-dce,enable-line-info))", + disable_threading: false, + verify_each: false + } diff --git a/third_party/xla/third_party/triton/llvm_integration/series.bzl b/third_party/xla/third_party/triton/llvm_integration/series.bzl index 4425b4e71caf27..4610d5289abdd3 100644 --- 
a/third_party/xla/third_party/triton/llvm_integration/series.bzl +++ b/third_party/xla/third_party/triton/llvm_integration/series.bzl @@ -6,4 +6,5 @@ These should be upstreamed to openai/triton as part of the next triton integrati llvm_patch_list = [ "//third_party/triton/llvm_integration:cl642434908.patch", + "//third_party/triton/llvm_integration:cl643947742.patch", ] From 4b12044106fd2d4e20bc4809e198432225028861 Mon Sep 17 00:00:00 2001 From: Oleg Shyshkov Date: Tue, 18 Jun 2024 04:37:04 -0700 Subject: [PATCH 04/59] [XLA:GPU] Use priority fusion in TritonGemmAutotunerExtractor. PiperOrigin-RevId: 644336930 --- third_party/xla/xla/service/gpu/BUILD | 3 ++ .../xla/service/gpu/gemm_fusion_autotuner.cc | 40 +++++++++++-------- 2 files changed, 26 insertions(+), 17 deletions(-) diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index ea8feebb6aa687..638f9dc37b7007 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -836,6 +836,9 @@ cc_library( "@local_tsl//tsl/profiler/lib:scoped_annotation", "//xla/tsl/util/proto:proto_utils", "//xla/service/gpu:hlo_traversal", + ":fusion_wrapper", + ":priority_fusion", + "//xla/service/gpu/model:gpu_hlo_cost_analysis", "//xla/stream_executor:stream_executor_memory_allocator", "@local_tsl//tsl/platform:path", ]), diff --git a/third_party/xla/xla/service/gpu/gemm_fusion_autotuner.cc b/third_party/xla/xla/service/gpu/gemm_fusion_autotuner.cc index 1082b0a4155c46..f872efc3865101 100644 --- a/third_party/xla/xla/service/gpu/gemm_fusion_autotuner.cc +++ b/third_party/xla/xla/service/gpu/gemm_fusion_autotuner.cc @@ -62,13 +62,15 @@ limitations under the License. 
#include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/buffer_comparator.h" #include "xla/service/gpu/cudnn_fusion_compiler.h" +#include "xla/service/gpu/fusion_wrapper.h" #include "xla/service/gpu/gemm_rewriter.h" #include "xla/service/gpu/gpu_float_support.h" -#include "xla/service/gpu/gpu_fusible.h" #include "xla/service/gpu/hlo_traversal.h" #include "xla/service/gpu/instruction_fusion.h" #include "xla/service/gpu/ir_emission_utils.h" #include "xla/service/gpu/matmul_utils.h" +#include "xla/service/gpu/model/gpu_hlo_cost_analysis.h" +#include "xla/service/gpu/priority_fusion.h" #include "xla/service/gpu/split_k_gemm_rewriter.h" #include "xla/service/gpu/stream_executor_util.h" #include "xla/service/hlo_module_config.h" @@ -355,22 +357,26 @@ absl::StatusOr> TritonGemmAutotuneExtractor( BF16); FloatNormalization float_normalization(&bf16_support); TF_RETURN_IF_ERROR(float_normalization.Run(new_module.get()).status()); - GpuInstructionFusion instruction_fusion(/*may_duplicate=*/false, - gpu_device_info); - TF_RETURN_IF_ERROR(instruction_fusion.Run(new_module.get()).status()); - HloInstruction* root = entry_computation->root_instruction(); - // If the instruction fusion pass above skipped the reduction, turn it - // into a fusion for a universal set of arguments for execution. - if (root->opcode() == HloOpcode::kReduce) { - HloInstruction* fusion_instruction = - entry_computation->AddInstruction(HloInstruction::CreateFusion( - root->shape(), ChooseFusionKind(*root, *root), root)); - HloInstruction* init_value = root->mutable_operand(1); - TF_CHECK_OK( - entry_computation->ReplaceInstruction(root, fusion_instruction)); - fusion_instruction->FuseInstruction(init_value); - TF_CHECK_OK(entry_computation->RemoveInstruction(init_value)); - } + + auto shape_size_function = [&](const Shape& shape) { + // The real pointer size is set in GpuCompiler. In HloCostAnalysis, the + // pointer size is used only to determine the size of tuple types. 
We + // shouldn't have any tuples in the autotuned module, so it's safe to use + // a constant here, instead of piping the real value. + constexpr int64_t kPointerSize = 8; + return ShapeUtil::ByteSizeOf(shape, kPointerSize); + }; + GpuPriorityFusion priority_fusion( + /*thread_pool=*/nullptr, gpu_device_info, + GpuHloCostAnalysis::Options{/*shape_size=*/shape_size_function, + /*per_second_rates=*/{}, + /*count_multiple_input_accesses=*/true}); + TF_RETURN_IF_ERROR(priority_fusion.Run(new_module.get()).status()); + + // If the priority fusion pass above skipped some instructions, turn them + // into fusions. + FusionWrapper fusion_wrapper; + TF_RETURN_IF_ERROR(fusion_wrapper.Run(new_module.get()).status()); } return new_module; } From f1a240d54e2340b63063ab5835f2c8c978ff12f8 Mon Sep 17 00:00:00 2001 From: Kyle Lucke Date: Tue, 18 Jun 2024 04:55:23 -0700 Subject: [PATCH 05/59] Stop using xla/statusor.h now that it just contains an alias for absl::Status. In some situations, this meant also changing unrelated files to directly include tsl/platform/statusor.h to get the definitions for TF_ASSIGN_OR_RETURN, etc., where they were getting transitively included for free. 
PiperOrigin-RevId: 644340828 --- third_party/xla/xla/python/pjrt_ifrt/BUILD | 1 - third_party/xla/xla/python/pjrt_ifrt/pjrt_executable.cc | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/third_party/xla/xla/python/pjrt_ifrt/BUILD b/third_party/xla/xla/python/pjrt_ifrt/BUILD index b9d0a0100faf31..86427ec27a8214 100644 --- a/third_party/xla/xla/python/pjrt_ifrt/BUILD +++ b/third_party/xla/xla/python/pjrt_ifrt/BUILD @@ -215,7 +215,6 @@ cc_library( "//xla:literal", "//xla:shape_util", "//xla:status_macros", - "//xla:statusor", "//xla:util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", diff --git a/third_party/xla/xla/python/pjrt_ifrt/pjrt_executable.cc b/third_party/xla/xla/python/pjrt_ifrt/pjrt_executable.cc index c4fc0325d93375..a973cff8a21fb2 100644 --- a/third_party/xla/xla/python/pjrt_ifrt/pjrt_executable.cc +++ b/third_party/xla/xla/python/pjrt_ifrt/pjrt_executable.cc @@ -25,6 +25,7 @@ limitations under the License. #include "absl/log/check.h" #include "absl/log/log.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "llvm/Support/Casting.h" @@ -55,7 +56,6 @@ limitations under the License. 
#include "xla/shape.h" #include "xla/shape_util.h" #include "xla/status_macros.h" -#include "xla/statusor.h" #include "xla/translate/mhlo_to_hlo/type_to_shape.h" #include "xla/tsl/concurrency/ref_count.h" #include "xla/util.h" From 44cb866c3d76e8eeb6da601eca4417903dede4c9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tam=C3=A1s=20Danyluk?= Date: Tue, 18 Jun 2024 05:02:22 -0700 Subject: [PATCH 06/59] [XLA:GPU] Add initial SymbolicTileAnalysis::GetGoodTilings implementation PiperOrigin-RevId: 644342386 --- third_party/xla/xla/service/gpu/model/BUILD | 5 + .../gpu/model/symbolic_tile_analysis.cc | 92 +++++++++++++++- .../gpu/model/symbolic_tile_analysis.h | 27 ++++- .../gpu/model/symbolic_tile_analysis_test.cc | 103 ++++++++++++++++++ 4 files changed, 223 insertions(+), 4 deletions(-) diff --git a/third_party/xla/xla/service/gpu/model/BUILD b/third_party/xla/xla/service/gpu/model/BUILD index 32150a534fbce2..03dcb308b2367f 100644 --- a/third_party/xla/xla/service/gpu/model/BUILD +++ b/third_party/xla/xla/service/gpu/model/BUILD @@ -659,6 +659,8 @@ cc_library( ":symbolic_tiled_hlo_instruction", ":tiled_hlo_computation", ":tiled_hlo_instruction", + "//xla:shape_util", + "//xla:status_macros", "//xla/hlo/ir:hlo", "//xla/service:instruction_fusion", "//xla/service:name_uniquer", @@ -668,9 +670,11 @@ cc_library( "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/log:check", + "@com_google_absl//absl/numeric:bits", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", @@ -691,6 +695,7 @@ xla_cc_test( "//xla/service/gpu:hlo_traversal", "//xla/tests:hlo_test_base", "//xla/tests:verified_hlo_module", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", 
"@com_google_absl//absl/types:span", diff --git a/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis.cc b/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis.cc index c2f56266782bdd..59fafc442f80d2 100644 --- a/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis.cc +++ b/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis.cc @@ -15,6 +15,7 @@ limitations under the License. #include "xla/service/gpu/model/symbolic_tile_analysis.h" +#include #include #include #include @@ -30,8 +31,11 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/container/inlined_vector.h" #include "absl/log/check.h" +#include "absl/numeric/bits.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "absl/strings/str_join.h" #include "absl/types/span.h" #include "llvm/ADT/STLExtras.h" @@ -53,6 +57,8 @@ limitations under the License. #include "xla/service/gpu/model/tiled_hlo_instruction.h" #include "xla/service/instruction_fusion.h" #include "xla/service/name_uniquer.h" +#include "xla/shape.h" +#include "xla/status_macros.h" #include "tsl/platform/status.h" #include "tsl/platform/statusor.h" @@ -282,7 +288,7 @@ absl::StatusOr ComputeBlockIdToTileOffsetIndexing( } absl::StatusOr SymbolicTileAnalysis::ParametersSatisfyConstraints( - const std::vector& tile_parameters) const { + absl::Span tile_parameters) const { // Populate parameter map. 
llvm::SmallVector parameters = llvm::to_vector( llvm::map_range(tile_parameters, [this](const int64_t v) -> AffineExpr { @@ -312,7 +318,7 @@ absl::StatusOr SymbolicTileAnalysis::ParametersSatisfyConstraints( absl::StatusOr SymbolicTileAnalysis::ComputeTiledHloInstructions( - const std::vector& tile_parameters, + absl::Span tile_parameters, bool constraints_are_known_satisfied) const { if (!constraints_are_known_satisfied) { TF_ASSIGN_OR_RETURN(bool constraints_are_satisfied, @@ -420,5 +426,87 @@ std::string SymbolicTileAnalysis::ToString( return ss.str(); } +namespace { + +// The possible tiles sizes for one dimension. +std::vector PossibleTileSizesForOneDimension(int64_t dim_size) { + CHECK_GE(dim_size, 1); + + std::vector result; + result.reserve(absl::bit_width(static_cast(dim_size))); + for (int64_t tile_size = 1; tile_size < dim_size; tile_size *= 2) { + result.push_back(tile_size); + } + result.push_back(dim_size); + return result; +} + +} // namespace + +namespace detail { +std::vector GetGoodTilings( + absl::Span dim_sizes, + std::function)> is_valid) { + CHECK(is_valid != nullptr); + + std::vector tilings; + tilings.push_back({}); + for (int dim_size : dim_sizes) { + std::vector possible_tile_sizes = + PossibleTileSizesForOneDimension(dim_size); + std::vector extended_tilings; + extended_tilings.reserve(tilings.size() * possible_tile_sizes.size()); + for (const SymbolicTileAnalysis::Tiling& tiling : tilings) { + for (int64_t tile_size : possible_tile_sizes) { + SymbolicTileAnalysis::Tiling extended_tiling = tiling; + extended_tiling.push_back(tile_size); + extended_tilings.push_back(extended_tiling); + } + } + tilings = std::move(extended_tilings); + } + + tilings.erase( + std::remove_if(tilings.begin(), tilings.end(), std::not_fn(is_valid)), + tilings.end()); + + return tilings; +} +} // namespace detail + +absl::StatusOr> +SymbolicTileAnalysis::GetGoodTilings() const { + TF_RET_CHECK(!symbolic_tiled_hlo_instructions_.empty()); + 
TF_RET_CHECK(symbolic_tiled_hlo_instructions_.back() != nullptr); + + const SymbolicTiledHloInstruction& instr = + *symbolic_tiled_hlo_instructions_.back(); + TF_RET_CHECK(instr.hlo() != nullptr); + const Shape& shape = instr.hlo()->shape(); + if (!absl::c_all_of(shape.dimensions(), + [](int64_t dim_size) { return dim_size >= 1; })) { + return absl::InvalidArgumentError(absl::StrFormat( + "Shape %s has zero or negative dimensions.", shape.ToString())); + } + + absl::Status status = absl::OkStatus(); + std::vector result = detail::GetGoodTilings( + shape.dimensions(), [&](absl::Span tile_sizes) { + absl::StatusOr is_valid = + ParametersSatisfyConstraints(tile_sizes); + if (!is_valid.ok()) { + status = is_valid.status(); + return false; + } + return is_valid.value(); + }); + + if (status.ok()) { + return result; + } + + return status; +} + } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis.h b/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis.h index 7a1d7ddaeb98c7..ac938f5c17589e 100644 --- a/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis.h +++ b/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis.h @@ -17,13 +17,16 @@ limitations under the License. #define XLA_SERVICE_GPU_MODEL_SYMBOLIC_TILE_ANALYSIS_H_ #include +#include #include #include #include #include #include +#include "absl/container/inlined_vector.h" #include "absl/status/statusor.h" +#include "absl/types/span.h" #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -49,6 +52,11 @@ using SymbolicTileAnalysisOrError = // instruction of the computation to the relevant instruction. class SymbolicTileAnalysis { public: + // A tile size for each dimension. + // + // This is an inlined vector to avoid too many heap allocations. + using Tiling = absl::InlinedVector; + // Tries to construct a symbolic tile analysis from a computation. 
Returns // a diagnostic if the construction fails for any reason. static SymbolicTileAnalysisOrError AnalyzeComputation( @@ -62,7 +70,7 @@ class SymbolicTileAnalysis { // constraints are satisfied by the chosen tiled parameters. Setting // `constraints_are_known_satisfied` to true bypasses this check. absl::StatusOr ComputeTiledHloInstructions( - const std::vector& tile_parameters, + absl::Span tile_parameters, bool constraints_are_known_satisfied = false) const; // Returns the tiled root instruction. @@ -91,7 +99,7 @@ class SymbolicTileAnalysis { // correctly. This is typically the case if too few tile parameters are // provided to fully reduce the constraint expressions to constants. absl::StatusOr ParametersSatisfyConstraints( - const std::vector& tile_parameters) const; + absl::Span tile_parameters) const; // Return the underlying MLIRContext. mlir::MLIRContext* GetMLIRContext() const { return context_; }; @@ -101,6 +109,14 @@ class SymbolicTileAnalysis { std::string ToString( const AffineMapPrinter& printer = AffineMapPrinter()) const; + // Returns a list of tilings for the symbolic tiled HLO computation of the + // analysis that are expected to perform well. + // + // Note: This is an initial implementation where the results may not perform + // that well, and now we're filtering the tilings with Triton in mind + // (allowing only powers of 2 or the full dimension size). + absl::StatusOr> GetGoodTilings() const; + private: SymbolicTileAnalysis(std::vector> symbolic_tiled_hlo_instructions, @@ -121,6 +137,13 @@ class SymbolicTileAnalysis { mlir::MLIRContext* context_; }; +namespace detail { +// Only exposed for testing, please use SymbolicTileAnalysis::GetGoodTilings() +// instead. 
+std::vector GetGoodTilings( + absl::Span dim_sizes, + std::function)> is_valid); +} // namespace detail } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis_test.cc b/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis_test.cc index 9e7258c44bd291..6350459c0d750b 100644 --- a/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis_test.cc +++ b/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis_test.cc @@ -24,6 +24,7 @@ limitations under the License. #include #include +#include "absl/algorithm/container.h" #include "absl/status/status.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" @@ -41,12 +42,14 @@ namespace xla { namespace gpu { namespace { +using detail::GetGoodTilings; using ::testing::ElementsAreArray; using ::testing::ExplainMatchResult; using ::testing::Matcher; using ::testing::SizeIs; using ::testing::status::IsOkAndHolds; using ::testing::status::StatusIs; +using TilingVector = std::vector; MATCHER_P3(MatchTiledHloInstructionImpl, tile_sizes, tile_strides, block_id_to_tile_offsets_indexing, "") { @@ -506,6 +509,106 @@ ENTRY main { EXPECT_FALSE(TryAnalyzeModule(module.get()).has_value()); } +bool AlwaysValid(absl::Span) { return true; } + +TEST(GetGoodTilingsTest, ReturnsOneTilingWhenRankIsZero) { + EXPECT_EQ(GetGoodTilings({}, AlwaysValid), + TilingVector{SymbolicTileAnalysis::Tiling{}}); +} + +TEST(GetGoodTilingsTest, ReturnsPowersOfTwoAndTheDimSizeForRankOne) { + EXPECT_EQ(GetGoodTilings({1}, AlwaysValid), TilingVector{{1}}); + EXPECT_EQ(GetGoodTilings({2}, AlwaysValid), TilingVector({{1}, {2}})); + EXPECT_EQ(GetGoodTilings({3}, AlwaysValid), TilingVector({{1}, {2}, {3}})); + EXPECT_EQ(GetGoodTilings({4}, AlwaysValid), TilingVector({{1}, {2}, {4}})); + EXPECT_EQ(GetGoodTilings({5}, AlwaysValid), + TilingVector({{1}, {2}, {4}, {5}})); + EXPECT_EQ(GetGoodTilings({11}, AlwaysValid), + TilingVector({{1}, {2}, {4}, {8}, {11}})); +} + 
+TEST(GetGoodTilingsTest, CreatesCartesianProductForRankTwo) { + EXPECT_EQ(GetGoodTilings({3, 4}, AlwaysValid), TilingVector({{1, 1}, + {1, 2}, + {1, 4}, + {2, 1}, + {2, 2}, + {2, 4}, + {3, 1}, + {3, 2}, + {3, 4}})); +} + +TEST(GetGoodTilingsTest, CreatesCartesianProductForRankThree) { + EXPECT_EQ(GetGoodTilings({3, 4, 2}, AlwaysValid), TilingVector({{1, 1, 1}, + {1, 1, 2}, + {1, 2, 1}, + {1, 2, 2}, + {1, 4, 1}, + {1, 4, 2}, + {2, 1, 1}, + {2, 1, 2}, + {2, 2, 1}, + {2, 2, 2}, + {2, 4, 1}, + {2, 4, 2}, + {3, 1, 1}, + {3, 1, 2}, + {3, 2, 1}, + {3, 2, 2}, + {3, 4, 1}, + {3, 4, 2}})); +} + +TEST(GetGoodTilingsTest, FiltersTheTilingsUsingThePredicate) { + auto all_even = [](absl::Span tile_sizes) { + return absl::c_all_of(tile_sizes, + [](int64_t tile_size) { return tile_size % 2 == 0; }); + }; + + EXPECT_EQ(GetGoodTilings({3, 4}, all_even), TilingVector({{2, 2}, {2, 4}})); + + auto all_equal = [](absl::Span tile_sizes) { + return absl::c_all_of(tile_sizes, [&](int64_t tile_size) { + return tile_size == tile_sizes.at(0); + }); + }; + + EXPECT_EQ(GetGoodTilings({3, 3, 3}, all_equal), + TilingVector({{1, 1, 1}, {2, 2, 2}, {3, 3, 3}})); +} + +TEST_F(SymbolicTileAnalysisTest, + GetGoodTilingsWorksTakingConstraintsIntoAccount) { + // The module was chosen (from SymbolicTileTest) because it has a constraint + // on the tile sizes. 
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( +fusion { + p0 = f32[1,8,6,4]{3,2,1,0} parameter(0) + ROOT bitcast = f32[48,4]{1,0} bitcast(p0) +} + +ENTRY main { + p0 = f32[1,8,6,4]{3,2,1,0} parameter(0) + ROOT fusion = f32[48,4]{1,0} fusion(p0), kind=kLoop, calls=fusion +})")); + + std::optional opt_analysis = + TryAnalyzeModule(module.get()); + ASSERT_TRUE(opt_analysis.has_value()); + + const SymbolicTileAnalysis& analysis = opt_analysis.value(); + TF_ASSERT_OK_AND_ASSIGN( + std::vector good_tilings, + analysis.GetGoodTilings()); + // The constraint on the 1st dimension is "s0 mod 6 in [0, 0]", and only 48 + // fulfills that from the set of possible tile sizes (1, 2, 4, 8, 16, 32, 48). + // There is no constraint on the 2nd dimension. + EXPECT_EQ(good_tilings, std::vector( + {{48, 1}, {48, 2}, {48, 4}})); +} + } // namespace } // namespace gpu } // namespace xla From 5cfe3aca5e1b006528d1c71e00e641c69cff7928 Mon Sep 17 00:00:00 2001 From: Ilia Sergachev Date: Tue, 18 Jun 2024 05:13:40 -0700 Subject: [PATCH 07/59] PR #13781: [GPU] Let the on-disk kernel compilation cache grow. Imported from GitHub PR https://github.com/openxla/xla/pull/13781 Copybara import of the project: -- 060f02a0c356edffa1037da07150eb19ef387231 by Ilia Sergachev : [GPU] Let the on-disk kernel compilation cache grow. 
-- 3572421e08fef72e8a6c49105c6a5ec5c9b47a5d by Ilia Sergachev : Add new flag description -- ecb79ca37583599b8bed8cd8c457139981c589be by Ilia Sergachev : Move code -- 3695df02c2f9ec73f4c70d6cacb23d983e3e7a21 by Ilia Sergachev : Improve checks -- f421890b6548a9a4dfcc5350e791bc8860615dbc by Ilia Sergachev : Add another test Merging this change closes #13781 PiperOrigin-RevId: 644345246 --- third_party/xla/xla/debug_options_flags.cc | 16 ++--- third_party/xla/xla/service/gpu/BUILD | 5 +- .../xla/xla/service/gpu/gpu_compiler.cc | 68 ++++++++----------- .../xla/xla/service/gpu/gpu_compiler_test.cc | 22 +++--- .../xla/xla/service/gpu/kernel_reuse_cache.cc | 36 ++++++++++ .../xla/xla/service/gpu/kernel_reuse_cache.h | 12 ++++ .../service/gpu/kernel_reuse_cache_test.cc | 38 ++++++++++- 7 files changed, 137 insertions(+), 60 deletions(-) diff --git a/third_party/xla/xla/debug_options_flags.cc b/third_party/xla/xla/debug_options_flags.cc index 84594a623d9bcb..7bb8e33a62336b 100644 --- a/third_party/xla/xla/debug_options_flags.cc +++ b/third_party/xla/xla/debug_options_flags.cc @@ -1760,14 +1760,14 @@ void MakeDebugOptionsFlags(std::vector* flag_list, debug_options->xla_gpu_shard_autotuning(), "Shard autotuning between participating compiler processes (typically in " "multi-host setups) and join the results when it's done.")); - flag_list->push_back(tsl::Flag( - "xla_gpu_kernel_cache_file", - string_setter_for(&DebugOptions::set_xla_gpu_kernel_cache_file), - debug_options->xla_gpu_kernel_cache_file(), - "Path to a file to cache compiled kernels. If the file doesn't exist " - "write the compilation cache of the first compiled HLO module into it." - "Once the file exists, further compilations will read it to reuse " - "the kernels, but not write it. 
This behavior may change later.")); + flag_list->push_back( + tsl::Flag("xla_gpu_kernel_cache_file", + string_setter_for(&DebugOptions::set_xla_gpu_kernel_cache_file), + debug_options->xla_gpu_kernel_cache_file(), + "Path to a file to cache compiled kernels. Cached kernels get " + "reused in further compilations; not yet cached kernels are " + "compiled as usual and get appended to the cache file whenever " + "possible.")); } // NOLINT(readability/fn_size) // Allocates flag_values and flag_objects; this function must not be called more diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index 638f9dc37b7007..b18bc40aab77e9 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -3425,6 +3425,7 @@ cc_library( ":custom_kernel_fusion_rewriter", ":dot_dimension_sorter", ":dot_operand_converter", + ":double_buffer_loop_unrolling", ":executable_proto_cc", ":fusion_merger", ":fusion_wrapper", @@ -3451,7 +3452,7 @@ cc_library( ":instruction_fusion", ":ir_emission_utils", ":ir_emitter", - ":double_buffer_loop_unrolling", + ":kernel_reuse_cache", ":matmul_utils", ":metrics", ":move_copy_to_users", @@ -5730,7 +5731,9 @@ xla_cc_test( deps = [ ":kernel_reuse_cache", "//xla/tests:xla_internal_test_main", + "@com_google_googletest//:gtest", "@local_tsl//tsl/lib/core:status_test_util", + "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:test", ], ) diff --git a/third_party/xla/xla/service/gpu/gpu_compiler.cc b/third_party/xla/xla/service/gpu/gpu_compiler.cc index 601619ee6ee973..00f11ec73d8dbf 100644 --- a/third_party/xla/xla/service/gpu/gpu_compiler.cc +++ b/third_party/xla/xla/service/gpu/gpu_compiler.cc @@ -147,6 +147,7 @@ limitations under the License. 
#include "xla/service/gpu/ir_emission_utils.h" #include "xla/service/gpu/ir_emitter_context.h" #include "xla/service/gpu/ir_emitter_unnested.h" +#include "xla/service/gpu/kernel_reuse_cache.h" #include "xla/service/gpu/matmul_utils.h" #include "xla/service/gpu/metrics.h" #include "xla/service/gpu/model/gpu_cost_model_stats_collection.h" @@ -1922,12 +1923,7 @@ absl::StatusOr GpuCompiler::CompileAndLink( std::string ptx_snippets; std::vector> binaries_to_link; binaries_to_link.reserve(compile_results.size()); - struct NamedBinary { - // The string is the function name or empty just like for llvm_modules. - std::string name; - std::vector binary; - }; - std::vector binaries_to_cache; + std::vector binaries_to_cache; binaries_to_cache.reserve(single_function_module_count); for (const auto& [name, maybe_result] : compile_results) { TF_ASSIGN_OR_RETURN(auto result, maybe_result); @@ -1948,51 +1944,40 @@ absl::StatusOr GpuCompiler::CompileAndLink( return FailedPrecondition("File path can not be resolved: %s", cache_path); } - CompilationCacheProto& cache = + // current_cache contains new kernels from the current compilation and + // kernels to reuse from previous compilations if some were loaded from the + // cache file. + const CompilationCacheProto& current_cache = compile_module_results.kernel_compilation_cache; - if (tsl::Env::Default()->FileExists(resolved_path).ok()) { + const bool cache_file_exists = + tsl::Env::Default()->FileExists(resolved_path).ok(); + if (cache_file_exists) { + // Pick reused binaries from previous compilations needed to link the + // current executable. 
int loaded_kernel_count = 0; - for (const auto& [name, entry] : cache.entries()) { - if (llvm_module->getFunction(name)) { - VLOG(5) - << "Skipping cached " << name - << " in favor of the just compiled kernel with the same name."; - CHECK(entry.binary().empty()); + for (const auto& [name, entry] : current_cache.entries()) { + if (llvm_module->getFunction(name) != nullptr) { + VLOG(5) << "Using the just compiled kernel for " << name; + TF_RET_CHECK(entry.binary().empty()) + << name + << " is a just compiled kernel and is not expected to have a " + "binary yet."; continue; } const uint8_t* binary = reinterpret_cast(entry.binary().data()); binaries_to_link.push_back( std::vector(binary, binary + entry.binary().size())); - VLOG(5) << "Loaded " << name << ": " << entry.binary().size(); + VLOG(5) << "Using " << name << " from cache: " << entry.binary().size(); ++loaded_kernel_count; } - VLOG(2) << "Loaded " << loaded_kernel_count << " / " - << cache.entries_size() << " cached kernels."; - } else { - auto entries = cache.mutable_entries(); - for (const auto& [name, binary] : binaries_to_cache) { - auto it = entries->find(name); - if (it == entries->end()) { - continue; - } - it->second.set_binary(reinterpret_cast(binary.data()), - binary.size()); - VLOG(5) << "Cached kernels: " << name << ": " << binary.size(); - } - for (auto it = entries->begin(); it != entries->end();) { - if (it->second.binary().empty()) { - it = entries->erase(it); - } else { - ++it; - } - } - if (cache.entries_size() > 0) { - TF_RETURN_IF_ERROR(tsl::WriteStringToFile( - tsl::Env::Default(), resolved_path, cache.SerializeAsString())); - VLOG(2) << "Stored " << cache.entries_size() << " / " - << binaries_to_cache.size(); - } + VLOG(2) << "Using " << loaded_kernel_count << " / " + << current_cache.entries_size() << " cached kernels."; + } + if (!binaries_to_cache.empty()) { + TF_RETURN_IF_ERROR( + UpdateDiskKernelCache(resolved_path, /*do_append=*/cache_file_exists, + current_cache, 
binaries_to_cache)); } } @@ -2007,6 +1992,7 @@ absl::StatusOr GpuCompiler::CompileAndLink( return maybe_backend_result.status(); } VLOG(4) << "Binary size after linking [B]: " << maybe_backend_result->size(); + compile_module_results.kernel_compilation_cache.Clear(); return BackendCompileResult{ptx_snippets, std::move(*maybe_backend_result)}; } diff --git a/third_party/xla/xla/service/gpu/gpu_compiler_test.cc b/third_party/xla/xla/service/gpu/gpu_compiler_test.cc index 643a9ff7e906b0..8c11571a731e62 100644 --- a/third_party/xla/xla/service/gpu/gpu_compiler_test.cc +++ b/third_party/xla/xla/service/gpu/gpu_compiler_test.cc @@ -559,6 +559,7 @@ CHECK-SAME: frontend_attributes={_xla_send_recv_pipeline="0"} class KernelCacheTest : public HloTestBase { public: void SetUp() override { + CHECK(tsl::Env::Default()->LocalTempFilename(&cache_file_name_)); HloModuleConfig config; config.set_debug_options(GetDebugOptionsForTest()); TF_ASSERT_OK_AND_ASSIGN(bool can_use_link_modules, @@ -568,8 +569,8 @@ class KernelCacheTest : public HloTestBase { GTEST_SKIP() << "Caching compiled kernels requires support of linking."; } } + DebugOptions GetDebugOptionsForTest() override { - CHECK(tsl::Env::Default()->LocalTempFilename(&cache_file_name_)); DebugOptions debug_options = HloTestBase::GetDebugOptionsForTest(); debug_options.set_xla_gpu_kernel_cache_file(cache_file_name_); debug_options.set_xla_gpu_enable_llvm_module_compilation_parallelism(true); @@ -583,16 +584,16 @@ class KernelCacheTest : public HloTestBase { return true; } - bool NonEmptyCacheExists() { + int CacheEntryCount() { if (!CacheFileExists()) { - return false; + return 0; } std::string serialized; TF_EXPECT_OK(tsl::ReadFileToString(tsl::Env::Default(), cache_file_name_, &serialized)); CompilationCacheProto proto; EXPECT_TRUE(proto.ParseFromString(std::string(serialized))); - return proto.entries_size() > 0; + return proto.entries_size(); } std::string cache_file_name_; @@ -609,9 +610,10 @@ TEST_F(KernelCacheTest, 
CacheIsGenerated) { EXPECT_FALSE(CacheFileExists()); EXPECT_TRUE(Run(kHloText, /*run_hlo_passes=*/false)); // First run generates a cache - EXPECT_TRUE(NonEmptyCacheExists()); + EXPECT_EQ(CacheEntryCount(), 1); // Second run - with cache file EXPECT_TRUE(Run(kHloText, /*run_hlo_passes=*/false)); + EXPECT_EQ(CacheEntryCount(), 1); } TEST_F(KernelCacheTest, NoCacheIsGeneratedWithoutCompiledKernels) { @@ -626,10 +628,10 @@ TEST_F(KernelCacheTest, NoCacheIsGeneratedWithoutCompiledKernels) { EXPECT_FALSE(CacheFileExists()); } -TEST_F(KernelCacheTest, UsingCacheFromAnotherModuleDoesNotFail) { +TEST_F(KernelCacheTest, CacheGrowsWithNewKernels) { EXPECT_FALSE(CacheFileExists()); EXPECT_TRUE(Run(kHloText, /*run_hlo_passes=*/false)); - EXPECT_TRUE(NonEmptyCacheExists()); + EXPECT_EQ(CacheEntryCount(), 1); // Second run - with cache file and another HLO EXPECT_TRUE(Run(R"( ENTRY e { @@ -637,6 +639,7 @@ TEST_F(KernelCacheTest, UsingCacheFromAnotherModuleDoesNotFail) { ROOT _ = s8[] multiply(p, p) })", /*run_hlo_passes=*/false)); + EXPECT_EQ(CacheEntryCount(), 2); } class KernelCacheTestSingleThreaded : public KernelCacheTest { @@ -651,8 +654,9 @@ class KernelCacheTestSingleThreaded : public KernelCacheTest { TEST_F(KernelCacheTestSingleThreaded, CacheIsGenerated) { EXPECT_FALSE(CacheFileExists()); EXPECT_TRUE(Run(kHloText, /*run_hlo_passes=*/false)); - EXPECT_TRUE(NonEmptyCacheExists()); + EXPECT_EQ(CacheEntryCount(), 1); EXPECT_TRUE(Run(kHloText, /*run_hlo_passes=*/false)); + EXPECT_EQ(CacheEntryCount(), 1); } class NoKernelCacheTest : public KernelCacheTest { @@ -666,7 +670,7 @@ class NoKernelCacheTest : public KernelCacheTest { TEST_F(NoKernelCacheTest, NoCacheWithoutCompilationParallelism) { EXPECT_TRUE(Run(kHloText, /*run_hlo_passes=*/false)); - EXPECT_FALSE(NonEmptyCacheExists()); + EXPECT_FALSE(CacheFileExists()); } } // namespace diff --git a/third_party/xla/xla/service/gpu/kernel_reuse_cache.cc b/third_party/xla/xla/service/gpu/kernel_reuse_cache.cc index 
0b29670e6cb43f..0c1d97af215688 100644 --- a/third_party/xla/xla/service/gpu/kernel_reuse_cache.cc +++ b/third_party/xla/xla/service/gpu/kernel_reuse_cache.cc @@ -139,6 +139,42 @@ CompilationCacheProto KernelReuseCache::Export() const { return proto; } +absl::Status UpdateDiskKernelCache( + absl::string_view path, const bool do_append, + const CompilationCacheProto& current_cache, + absl::Span binaries_to_cache) { + CompilationCacheProto disk_cache; + if (do_append) { + std::string serialized; + TF_RETURN_IF_ERROR(tsl::ReadFileToString(tsl::Env::Default(), + std::string(path), &serialized)); + if (!disk_cache.ParseFromString(std::string(serialized))) { + return Internal("Failed to parse serialized CompilationCacheProto."); + } + } + auto entries = disk_cache.mutable_entries(); + int stored_kernel_count = 0; + for (const auto& [name, binary] : binaries_to_cache) { + auto it_current = current_cache.entries().find(name); + TF_RET_CHECK(it_current != current_cache.entries().end()); + auto [it_disk, inserted] = entries->insert({name, it_current->second}); + TF_RET_CHECK(inserted); + TF_RET_CHECK(!binary.empty()); + it_disk->second.set_binary(reinterpret_cast(binary.data()), + binary.size()); + VLOG(5) << "Cached kernel: " << name << ": " << binary.size(); + ++stored_kernel_count; + } + if (stored_kernel_count > 0) { + TF_RETURN_IF_ERROR(tsl::WriteStringToFile(tsl::Env::Default(), + std::string(path), + disk_cache.SerializeAsString())); + VLOG(2) << "Stored " << stored_kernel_count << " / " + << binaries_to_cache.size() << " kernels in the cache file."; + } + return absl::OkStatus(); +} + std::pair, bool> KernelReuseCache::GetWithStatus( const HloComputation* fused_computation, diff --git a/third_party/xla/xla/service/gpu/kernel_reuse_cache.h b/third_party/xla/xla/service/gpu/kernel_reuse_cache.h index a66a5fac70dd50..bf165b5a7033f0 100644 --- a/third_party/xla/xla/service/gpu/kernel_reuse_cache.h +++ b/third_party/xla/xla/service/gpu/kernel_reuse_cache.h @@ -45,6 +45,10 
@@ class KernelReuseCache { int64_t shmem_bytes = 0; std::string binary; }; + struct NamedBinary { + std::string name; + std::vector binary; + }; absl::Status Load(const CompilationCacheProto& proto); // Exporting skips kernels that were loaded but not used during emission. @@ -88,6 +92,14 @@ class KernelReuseCache { absl::flat_hash_set hits_; }; +// Add kernels to the cache file. Binaries are taken from binaries_to_cache, +// all other kernel properties are taken from current_cache. +// do_append makes an existing file be loaded first. +absl::Status UpdateDiskKernelCache( + absl::string_view path, bool do_append, + const CompilationCacheProto& current_cache, + absl::Span binaries_to_cache); + // Calculates the fingerprint of a (fused_computation, kernel_arguments, // discriminator) tuple. // diff --git a/third_party/xla/xla/service/gpu/kernel_reuse_cache_test.cc b/third_party/xla/xla/service/gpu/kernel_reuse_cache_test.cc index 19b0c0d0d3f3c7..75c9b009f7d93d 100644 --- a/third_party/xla/xla/service/gpu/kernel_reuse_cache_test.cc +++ b/third_party/xla/xla/service/gpu/kernel_reuse_cache_test.cc @@ -14,8 +14,9 @@ limitations under the License. 
==============================================================================*/ #include "xla/service/gpu/kernel_reuse_cache.h" +#include #include "tsl/lib/core/status_test_util.h" -#include "tsl/platform/test.h" +#include "tsl/platform/env.h" namespace xla { namespace gpu { @@ -39,6 +40,41 @@ TEST_F(KernelReuseTest, ExportAndLoadWork) { EXPECT_FALSE(cache.IsEmpty()); } +TEST_F(KernelReuseTest, UpdatingDiskKernelCacheWorks) { + std::string cache_file_path; + CHECK(tsl::Env::Default()->LocalTempFilename(&cache_file_path)); + { + const CompilationCacheProto proto = [](std::string kernel_name) { + KernelReuseCache cache; + auto [result, was_cached] = cache.GetWithStatus("fingerprint", [&]() { + return KernelReuseCache::Entry{.kernel_name = kernel_name}; + }); + return cache.Export(); + }("k1"); + TF_EXPECT_OK(UpdateDiskKernelCache(cache_file_path, /*do_append=*/false, + proto, + {{.name = "k1", .binary = {5, 6}}})); + } + { + const CompilationCacheProto proto = [](std::string kernel_name) { + KernelReuseCache cache; + auto [result, was_cached] = cache.GetWithStatus("fingerprint", [&]() { + return KernelReuseCache::Entry{.kernel_name = kernel_name}; + }); + return cache.Export(); + }("k2"); + TF_EXPECT_OK(UpdateDiskKernelCache(cache_file_path, /*do_append=*/true, + proto, + {{.name = "k2", .binary = {7, 8}}})); + } + std::string serialized; + TF_EXPECT_OK( + tsl::ReadFileToString(tsl::Env::Default(), cache_file_path, &serialized)); + CompilationCacheProto proto; + EXPECT_TRUE(proto.ParseFromString(std::string(serialized))); + EXPECT_EQ(proto.entries_size(), 2); +} + } // namespace } // namespace gpu } // namespace xla From 85f91e81dd9f24267a99a069e7b74d1ffcde3b36 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tam=C3=A1s=20Danyluk?= Date: Tue, 18 Jun 2024 05:30:43 -0700 Subject: [PATCH 08/59] [XLA:GPU] Support tiling Softmax example PiperOrigin-RevId: 644348830 --- third_party/xla/xla/service/gpu/model/BUILD | 3 + .../xla/service/gpu/model/symbolic_tile.cc | 10 ++- 
.../xla/xla/service/gpu/model/symbolic_tile.h | 3 +- .../gpu/model/symbolic_tile_analysis_test.cc | 77 +++++++++++++++++++ .../service/gpu/model/symbolic_tile_test.cc | 25 +++++- 5 files changed, 113 insertions(+), 5 deletions(-) diff --git a/third_party/xla/xla/service/gpu/model/BUILD b/third_party/xla/xla/service/gpu/model/BUILD index 03dcb308b2367f..be7c6d0d196629 100644 --- a/third_party/xla/xla/service/gpu/model/BUILD +++ b/third_party/xla/xla/service/gpu/model/BUILD @@ -691,11 +691,14 @@ xla_cc_test( ":symbolic_tile_analysis", ":tiled_hlo_computation", ":tiled_hlo_instruction", + "//xla:util", "//xla/hlo/ir:hlo", + "//xla/service:instruction_fusion", "//xla/service/gpu:hlo_traversal", "//xla/tests:hlo_test_base", "//xla/tests:verified_hlo_module", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", diff --git a/third_party/xla/xla/service/gpu/model/symbolic_tile.cc b/third_party/xla/xla/service/gpu/model/symbolic_tile.cc index 93aec8922eec80..f4c4bff3c7fca6 100644 --- a/third_party/xla/xla/service/gpu/model/symbolic_tile.cc +++ b/third_party/xla/xla/service/gpu/model/symbolic_tile.cc @@ -607,13 +607,19 @@ std::optional MergeConstraintMapIfPresentAndCompatible( } /*static*/ std::optional SymbolicTile::FromIndexingMap( - const IndexingMap& indexing_map) { + IndexingMap indexing_map) { VLOG(1) << "SymbolicTile::FromIndexingMap: " << indexing_map.ToString(); // We do not handle indexing maps with pre-existing constraints for now. + // Let's try to simplify the indexing map, because the constraints my be + // redundant. + // TODO(bchetioui): Consider doing the simplification in the caller, not here. 
+ bool did_simplify = indexing_map.Simplify(); + VLOG(1) << "did_simplify: " << did_simplify; if (indexing_map.GetConstraintsCount() != 0) { VLOG(1) << "Deriving symbolic tile from indexing map with pre-existing " - << "constraints might produce spurious constraints. Bailing out."; + << "constraints might produce spurious constraints. Bailing out. " + << indexing_map.ToString(); return std::nullopt; } diff --git a/third_party/xla/xla/service/gpu/model/symbolic_tile.h b/third_party/xla/xla/service/gpu/model/symbolic_tile.h index ddce4de4699a28..f5c3f680eae5c3 100644 --- a/third_party/xla/xla/service/gpu/model/symbolic_tile.h +++ b/third_party/xla/xla/service/gpu/model/symbolic_tile.h @@ -157,8 +157,7 @@ namespace gpu { // simplified later. class SymbolicTile { public: - static std::optional FromIndexingMap( - const IndexingMap& indexing_map); + static std::optional FromIndexingMap(IndexingMap indexing_map); using ConstraintMap = llvm::DenseMap; diff --git a/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis_test.cc b/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis_test.cc index 6350459c0d750b..a0d12abdd1adfd 100644 --- a/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis_test.cc +++ b/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis_test.cc @@ -25,7 +25,9 @@ limitations under the License. #include #include #include "absl/algorithm/container.h" +#include "absl/log/log.h" #include "absl/status/status.h" +#include "absl/strings/str_join.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "mlir/IR/MLIRContext.h" // from @llvm-project @@ -34,8 +36,10 @@ limitations under the License. 
#include "xla/service/gpu/model/indexing_test_utils.h" #include "xla/service/gpu/model/tiled_hlo_computation.h" #include "xla/service/gpu/model/tiled_hlo_instruction.h" +#include "xla/service/instruction_fusion.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/verified_hlo_module.h" +#include "xla/util.h" #include "tsl/platform/statusor.h" namespace xla { @@ -45,7 +49,9 @@ namespace { using detail::GetGoodTilings; using ::testing::ElementsAreArray; using ::testing::ExplainMatchResult; +using ::testing::IsEmpty; using ::testing::Matcher; +using ::testing::Not; using ::testing::SizeIs; using ::testing::status::IsOkAndHolds; using ::testing::status::StatusIs; @@ -83,6 +89,8 @@ class SymbolicTileAnalysisTest : public HloTestBase { if (std::holds_alternative(analysis_or_error)) { return std::get(std::move(analysis_or_error)); } + VLOG(1) << "Cannot analyze module: " + << std::get(analysis_or_error).Explain(); return std::nullopt; } @@ -609,6 +617,75 @@ ENTRY main { {{48, 1}, {48, 2}, {48, 4}})); } +// Logs the tilings if VLOG level 1 is enabled. +// +// Use these arguments to see the log: +// --test_output=all +// --test_arg=--logtostderr +// --test_arg=--vmodule=symbolic_tile_analysis_test=1 +void LogTilingsIfVlog1(absl::Span tilings) { + if (VLOG_IS_ON(1)) { + LOG(INFO) << "Tilings: {"; + for (const SymbolicTileAnalysis::Tiling& tiling : tilings) { + LOG(INFO) << "{" << absl::StrJoin(tiling, ",") << "},"; + } + LOG(INFO) << "}"; + } +} + +TEST_F(SymbolicTileAnalysisTest, GetGoodTilingsWorksForSoftmaxExample) { + // The example is from + // https://github.com/google/paxml/blob/91893818862645f5e9f23b84f530e611551745f6/paxml/contrib/gpu/scripts_gpu/configs.py#L107-L120. 
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( +HloModule m + +region { + param_0 = f32[] parameter(0) + param_1 = f32[] parameter(1) + ROOT maximum = f32[] maximum(param_0, param_1) +} + +region.1 { + param_0 = f32[] parameter(0) + param_1 = f32[] parameter(1) + ROOT add = f32[] add(param_0, param_1) +} + +fused_computation { + param_0 = f32[8192,50304] parameter(0) + bitcast = f32[4,2048,50304] bitcast(param_0) + constant = f32[] constant(-inf) + reduce = f32[8192] reduce(param_0, constant), dimensions={1}, to_apply=region + bitcast.1 = f32[4,2048] bitcast(reduce) + broadcast = f32[4,2048,50304] broadcast(bitcast.1), dimensions={0,1} + subtract = f32[4,2048,50304] subtract(bitcast, broadcast) + exponential = f32[4,2048,50304] exponential(subtract) + constant.1 = f32[] constant(0) + reduce.1 = f32[4,2048] reduce(exponential, constant.1), dimensions={2}, to_apply=region.1 + log = f32[4,2048] log(reduce.1) + broadcast.1 = f32[4,2048,50304] broadcast(log), dimensions={0,1} + ROOT subtract.1 = f32[4,2048,50304] subtract(subtract, broadcast.1) +} + +ENTRY entry_computation { + param_0 = f32[8192,50304] parameter(0) + ROOT fusion = f32[4,2048,50304] fusion(param_0), kind=kCustom, calls=fused_computation, backend_config={"fusion_backend_config":{"kind":"__triton_softmax"}} +} +)")); + + std::optional opt_analysis = + TryAnalyzeModule(module.get()); + ASSERT_TRUE(opt_analysis.has_value()); + const SymbolicTileAnalysis& analysis = opt_analysis.value(); + + TF_ASSERT_OK_AND_ASSIGN( + std::vector good_tilings, + analysis.GetGoodTilings()); + EXPECT_THAT(good_tilings, Not(IsEmpty())); + LogTilingsIfVlog1(good_tilings); +} + } // namespace } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/model/symbolic_tile_test.cc b/third_party/xla/xla/service/gpu/model/symbolic_tile_test.cc index 83471b357e5bd0..8ce8c53184854b 100644 --- a/third_party/xla/xla/service/gpu/model/symbolic_tile_test.cc +++ 
b/third_party/xla/xla/service/gpu/model/symbolic_tile_test.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include +#include #include #include @@ -729,7 +730,7 @@ TEST_F(SymbolicTileTest, ParseAffineMap("(d0) -> (d0 mod 6, d0 mod 6)", &mlir_context_), /*dimensions=*/{DimVar{0, 10}}, /*range_vars=*/{}, /*rt_vars=*/{}); - EXPECT_THAT(SymbolicTile::FromIndexingMap(indexing_map), + EXPECT_THAT(SymbolicTile::FromIndexingMap(std::move(indexing_map)), Optional(MatchSymbolicTileString(R"( Symbolic tile with offset_map: ()[s0] -> (0, 0) @@ -740,6 +741,28 @@ TEST_F(SymbolicTileTest, )"))); } +TEST_F(SymbolicTileTest, + CanPropagateTileWhenPreexistingConstraintsCanBeSimplifiedAway) { + // The example is from + // https://github.com/google/paxml/blob/91893818862645f5e9f23b84f530e611551745f6/paxml/contrib/gpu/scripts_gpu/configs.py#L107-L120. + IndexingMap indexing_map = IndexingMap::FromTensorSizes( + ParseAffineMap("(d0, d1, d2)[s0] -> (d0 * 2048 + d1, s0)", + &mlir_context_), + {4, 2048, 50304}, {50304}); + // This constraint is redundant, because it can be derived from the domains of + // the dimension variables. + indexing_map.AddConstraint(ParseAffineExpr("d0 * 2048 + d1", &mlir_context_), + Interval{0, 8191}); + + EXPECT_THAT(SymbolicTile::FromIndexingMap(indexing_map), + Optional(MatchSymbolicTileString(R"( + Symbolic tile with + offset_map: ()[s0, s1, s2] -> (0, 0) + size_map: ()[s0, s1, s2] -> (s0 * s1, 50304) + stride_map: ()[s0, s1, s2] -> (((-s1 + 2049) floordiv 2048) * ((-((-s0 + 5) floordiv 4) + 1) * 2048) + -((-s1 + 2049) floordiv 2048) + 1, 1) + )"))); +} + } // namespace } // namespace gpu } // namespace xla From db5c56991c2499c50f521e5943e3f33f61440d52 Mon Sep 17 00:00:00 2001 From: Greg Olechwierowicz Date: Tue, 18 Jun 2024 05:54:54 -0700 Subject: [PATCH 09/59] [XLA:GPU][NFC] Move GPU specific latency estimator to a separate file. 
PiperOrigin-RevId: 644353852 --- third_party/xla/xla/service/gpu/BUILD | 6 +- .../xla/xla/service/gpu/gpu_compiler.cc | 1 + .../xla/xla/service/gpu/gpu_hlo_schedule.cc | 91 ------------------ .../xla/xla/service/gpu/gpu_hlo_schedule.h | 2 - .../gpu/gpu_latency_hiding_scheduler.cc | 92 +++++++++++++++++++ .../gpu/gpu_latency_hiding_scheduler.h | 23 +++++ .../xla/service/gpu/nvptx_compiler_test.cc | 1 + 7 files changed, 121 insertions(+), 95 deletions(-) diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index b18bc40aab77e9..90d8940b681515 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -3626,6 +3626,7 @@ cc_library( ":command_buffer_scheduling", ":execution_stream_assignment", ":fusion_pipeline", + ":gpu_latency_hiding_scheduler", ":ir_emitter_context", ":ir_emitter_unnested", ":prepare_hlo_for_ir_emitting_pipeline", @@ -3873,6 +3874,7 @@ xla_test( deps = [ ":gpu_constants", ":gpu_hlo_schedule", + ":gpu_latency_hiding_scheduler", ":nvptx_compiler_impl", "//xla:util", "//xla:xla_proto_cc", @@ -4232,11 +4234,9 @@ cc_library( hdrs = ["gpu_hlo_schedule.h"], deps = [ ":backend_configs_cc", - ":cublas_cudnn", ":gpu_latency_hiding_scheduler", ":gpu_schedule_postprocessing", "//xla:shape_util", - "//xla:statusor", "//xla:util", "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_query", @@ -6260,6 +6260,8 @@ cc_library( hdrs = ["gpu_latency_hiding_scheduler.h"], deps = [ ":backend_configs_cc", + ":cublas_cudnn", + "//xla:shape_util", "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_query", "//xla/service:collective_ops_utils", diff --git a/third_party/xla/xla/service/gpu/gpu_compiler.cc b/third_party/xla/xla/service/gpu/gpu_compiler.cc index 00f11ec73d8dbf..440e5ad600cd7c 100644 --- a/third_party/xla/xla/service/gpu/gpu_compiler.cc +++ b/third_party/xla/xla/service/gpu/gpu_compiler.cc @@ -136,6 +136,7 @@ limitations under the License. 
#include "xla/service/gpu/gpu_executable.h" #include "xla/service/gpu/gpu_float_support.h" #include "xla/service/gpu/gpu_hlo_schedule.h" +#include "xla/service/gpu/gpu_latency_hiding_scheduler.h" #include "xla/service/gpu/gpu_layout_assignment.h" #include "xla/service/gpu/gpu_p2p_pipeliner.h" #include "xla/service/gpu/gpu_reduce_scatter_creator.h" diff --git a/third_party/xla/xla/service/gpu/gpu_hlo_schedule.cc b/third_party/xla/xla/service/gpu/gpu_hlo_schedule.cc index 0b1c9629c5c7f7..07449034b1705f 100644 --- a/third_party/xla/xla/service/gpu/gpu_hlo_schedule.cc +++ b/third_party/xla/xla/service/gpu/gpu_hlo_schedule.cc @@ -21,7 +21,6 @@ limitations under the License. #include #include #include -#include #include #include @@ -37,7 +36,6 @@ limitations under the License. #include "absl/strings/str_join.h" #include "absl/strings/str_split.h" #include "absl/strings/string_view.h" -#include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_input_output_alias_config.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -48,7 +46,6 @@ limitations under the License. #include "xla/service/buffer_value.h" #include "xla/service/collective_ops_utils.h" #include "xla/service/gpu/backend_configs.pb.h" -#include "xla/service/gpu/cublas_cudnn.h" #include "xla/service/gpu/gpu_latency_hiding_scheduler.h" #include "xla/service/gpu/gpu_schedule_postprocessing.h" #include "xla/service/gpu/model/analytical_latency_estimator.h" @@ -59,7 +56,6 @@ limitations under the License. #include "xla/service/profile_guided_latency_estimator.h" #include "xla/shape.h" #include "xla/shape_util.h" -#include "xla/statusor.h" #include "xla/stream_executor/device_description.h" #include "xla/util.h" #include "tsl/platform/env.h" @@ -73,19 +69,6 @@ namespace gpu { namespace { -// A threshold for which we consider AR to be costly perf-wise. 
-static constexpr int64_t kCostlyAllReduceThreshold = 30 * 1024 * 1024; - -// Multiplier which we apply to expand the base cost for the costly AR. -static constexpr int64_t kCostlyAllReduceMultiplier = 4; - -bool IsNopInstruction(const HloInstruction& hlo) { - HloOpcode op = hlo.opcode(); - return op == HloOpcode::kGetTupleElement || op == HloOpcode::kBitcast || - op == HloOpcode::kConstant || op == HloOpcode::kParameter || - hlo.IsEffectiveBitcast(); -} - bool ShouldScheduleAsEarlyAsPossible(const HloInstruction& instr) { switch (instr.opcode()) { case HloOpcode::kAllReduceStart: @@ -280,70 +263,6 @@ SchedulerConfig GetSchedulerConfig(int64_t memory_limit) { return config; } -class GpuLatencyEstimator : public ApproximateLatencyEstimator { - public: - explicit GpuLatencyEstimator( - int64_t pointer_size, - GetCanonicalAsyncOpFunc func = GpuGetCanonicalAsyncOp) - : ApproximateLatencyEstimator(func), pointer_size_(pointer_size) {} - TimeCost NodeCost(const HloInstruction* instr) const override { - if (IsNopInstruction(*instr)) { - return 0.0; - } - // Consider cublas/cuddn/softmax custom calls as medium cost. Since the - // latency between async-start and async-done is 5000 and cost of each - // custom call is 1000, the LHS will try to schedule approximately 5 of - // these in between each start/end pair. - if (instr->opcode() == HloOpcode::kCustomCall) { - if (IsCublasGemm(*instr) || IsCustomCallToDnnConvolution(*instr)) { - return ApproximateLatencyEstimator::kMediumCost; - } - // consider other custom calls as medium cost for now. Keeping the case - // explicitly separate for further tuning. - return ApproximateLatencyEstimator::kMediumCost; - } - return ApproximateLatencyEstimator::NodeCost(instr); - } - - LatencyEstimator::TimeCost GetLatencyBetween( - const HloGraphNode& from, const HloGraphNode& target) const override { - if (IsAsyncPair(from, target)) { - if (from.GetInstr().opcode() == HloOpcode::kRecv) { - // Recv -> RecvDone has a low latency. 
- return ApproximateLatencyEstimator::kLowLatency; - } else if (from.GetInstr().opcode() == HloOpcode::kSend) { - // Send -> SendDone has a very high latency. - return ApproximateLatencyEstimator::kHighLatency * 10; - } - - bool enable_approx_collectives = - from.GetInstr() - .GetModule() - ->config() - .debug_options() - .xla_gpu_enable_approx_costly_collectives(); - bool is_all_reduce = - from.GetInstr().opcode() == HloOpcode::kAllReduceStart; - bool collective_size_exceeds_threshold = - GetSizeOfShape(from.GetInstr().shape(), pointer_size_) > - kCostlyAllReduceThreshold; - if (enable_approx_collectives && is_all_reduce && - collective_size_exceeds_threshold) { - return ApproximateLatencyEstimator::kHighLatency * - kCostlyAllReduceMultiplier; - } - - return ApproximateLatencyEstimator::kHighLatency; - } - // Every other instruction we consider synchronous, which means the - // latency between each of them is always one unit. - return ApproximateLatencyEstimator::kLowLatency; - } - - private: - int64_t pointer_size_; -}; - tensorflow::profiler::ProfiledInstructionsProto GetProfileForFingerprint( tensorflow::profiler::ProfiledInstructionsProto& profile, const std::string& fingerprint) { @@ -533,16 +452,6 @@ absl::Status IsProfileApplicable( return absl::OkStatus(); } -int64_t GetSizeOfShape(const Shape& shape, int pointer_size) { - int64_t size = ShapeUtil::ByteSizeOf(shape, pointer_size); - if (shape.IsTuple() || shape.is_static()) { - return size; - } - // Each dynamic dimension size is represented as a S32. 
- int64_t metadata_size = sizeof(int32_t) * shape.dimensions_size(); - return size + metadata_size; -} - static int64_t GetSchedulerMemoryLimit( const HloModule* module, const se::DeviceDescription& gpu_device_info, int pointer_size); diff --git a/third_party/xla/xla/service/gpu/gpu_hlo_schedule.h b/third_party/xla/xla/service/gpu/gpu_hlo_schedule.h index d20a056494666d..7263eff68eaa13 100644 --- a/third_party/xla/xla/service/gpu/gpu_hlo_schedule.h +++ b/third_party/xla/xla/service/gpu/gpu_hlo_schedule.h @@ -37,8 +37,6 @@ absl::Status IsProfileApplicable( const HloModule* module, const tensorflow::profiler::ProfiledInstructionsProto& profile); -int64_t GetSizeOfShape(const Shape& shape, int pointer_size); - struct ScheduleMetadata { int64_t scheduler_mem_limit; }; diff --git a/third_party/xla/xla/service/gpu/gpu_latency_hiding_scheduler.cc b/third_party/xla/xla/service/gpu/gpu_latency_hiding_scheduler.cc index 33315bfacfbac0..1d5a3b99d49e43 100644 --- a/third_party/xla/xla/service/gpu/gpu_latency_hiding_scheduler.cc +++ b/third_party/xla/xla/service/gpu/gpu_latency_hiding_scheduler.cc @@ -27,13 +27,30 @@ limitations under the License. #include "xla/hlo/utils/hlo_query.h" #include "xla/service/collective_ops_utils.h" #include "xla/service/gpu/backend_configs.pb.h" +#include "xla/service/gpu/cublas_cudnn.h" #include "xla/service/latency_hiding_scheduler.h" +#include "xla/shape.h" +#include "xla/shape_util.h" namespace xla { namespace gpu { namespace { +// A threshold for which we consider AR to be costly perf-wise. +static constexpr int64_t kCostlyAllReduceThreshold = 30 * 1024 * 1024; + +// Multiplier which we apply to expand the base cost for the costly AR. +static constexpr int64_t kCostlyAllReduceMultiplier = 4; + +// Classifies `hlo` instruction as noop or not. 
+bool IsNopInstruction(const HloInstruction& hlo) { + HloOpcode op = hlo.opcode(); + return op == HloOpcode::kGetTupleElement || op == HloOpcode::kBitcast || + op == HloOpcode::kConstant || op == HloOpcode::kParameter || + hlo.IsEffectiveBitcast(); +} + bool IsAsyncComputeOp(const HloInstruction& hlo) { return (hlo.opcode() == HloOpcode::kAsyncStart || hlo.opcode() == HloOpcode::kAsyncDone) && @@ -74,6 +91,16 @@ std::pair GetP2PResourceAndUsage( } // namespace +int64_t GetSizeOfShape(const Shape& shape, int pointer_size) { + int64_t size = ShapeUtil::ByteSizeOf(shape, pointer_size); + if (shape.IsTuple() || shape.is_static()) { + return size; + } + // Each dynamic dimension size is represented as a S32. + int64_t metadata_size = sizeof(int32_t) * shape.dimensions_size(); + return size + metadata_size; +} + CanonicalAsyncOp GpuGetCanonicalAsyncOp(const HloInstruction& hlo) { switch (hlo.opcode()) { case HloOpcode::kSend: @@ -89,6 +116,7 @@ CanonicalAsyncOp GpuGetCanonicalAsyncOp(const HloInstruction& hlo) { } } +// GpuAsyncTrackerBase implementations begin GpuAsyncTrackerBase::GpuAsyncTrackerBase(const SchedulerConfig& config, GetCanonicalAsyncOpFunc func) : AsyncTracker(config, func) {} @@ -132,7 +160,9 @@ void GpuAsyncTrackerBase::PostProcessScheduleGraph( } } } +// GpuAsyncTrackerBase implementations end +// GpuAsyncTracker implementations begin GpuAsyncTracker::GpuAsyncTracker(const SchedulerConfig& config) : GpuAsyncTrackerBase(config) {} @@ -278,5 +308,67 @@ int64_t GpuAsyncTracker::GetNumResourcesPerInstruction( return num_resources - (found ? 
1 : 0); } +// GpuAsyncTracker implementations end + +// GpuLatencyEstimator implementations begin +GpuLatencyEstimator::GpuLatencyEstimator(int64_t pointer_size, + GetCanonicalAsyncOpFunc func) + : ApproximateLatencyEstimator(func), pointer_size_(pointer_size) {} + +ApproximateLatencyEstimator::TimeCost GpuLatencyEstimator::NodeCost( + const HloInstruction* instr) const { + if (IsNopInstruction(*instr)) { + return 0.0; + } + // Consider cublas/cuddn/softmax custom calls as medium cost. Since the + // latency between async-start and async-done is 5000 and cost of each + // custom call is 1000, the LHS will try to schedule approximately 5 of + // these in between each start/end pair. + if (instr->opcode() == HloOpcode::kCustomCall) { + if (IsCublasGemm(*instr) || IsCustomCallToDnnConvolution(*instr)) { + return ApproximateLatencyEstimator::kMediumCost; + } + // consider other custom calls as medium cost for now. Keeping the case + // explicitly separate for further tuning. + return ApproximateLatencyEstimator::kMediumCost; + } + return ApproximateLatencyEstimator::NodeCost(instr); +} + +ApproximateLatencyEstimator::TimeCost GpuLatencyEstimator::GetLatencyBetween( + const HloGraphNode& from, const HloGraphNode& to) const { + if (IsAsyncPair(from, to)) { + if (from.GetInstr().opcode() == HloOpcode::kRecv) { + // Recv -> RecvDone has a low latency. + return ApproximateLatencyEstimator::kLowLatency; + } else if (from.GetInstr().opcode() == HloOpcode::kSend) { + // Send -> SendDone has a very high latency. 
+ return ApproximateLatencyEstimator::kHighLatency * 10; + } + + bool enable_approx_collectives = + from.GetInstr() + .GetModule() + ->config() + .debug_options() + .xla_gpu_enable_approx_costly_collectives(); + bool is_all_reduce = from.GetInstr().opcode() == HloOpcode::kAllReduceStart; + bool collective_size_exceeds_threshold = + GetSizeOfShape(from.GetInstr().shape(), pointer_size_) > + kCostlyAllReduceThreshold; + if (enable_approx_collectives && is_all_reduce && + collective_size_exceeds_threshold) { + return ApproximateLatencyEstimator::kHighLatency * + kCostlyAllReduceMultiplier; + } + + return ApproximateLatencyEstimator::kHighLatency; + } + // Every other instruction we consider synchronous, which means the + // latency between each of them is always one unit. + return ApproximateLatencyEstimator::kLowLatency; +} +// GpuLatencyEstimator implementations end + } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/gpu_latency_hiding_scheduler.h b/third_party/xla/xla/service/gpu/gpu_latency_hiding_scheduler.h index a7b0e0f7546b99..fae9debc8fc291 100644 --- a/third_party/xla/xla/service/gpu/gpu_latency_hiding_scheduler.h +++ b/third_party/xla/xla/service/gpu/gpu_latency_hiding_scheduler.h @@ -21,6 +21,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/service/latency_hiding_scheduler.h" +#include "xla/shape.h" namespace xla { namespace gpu { @@ -29,6 +30,9 @@ namespace gpu { // E.g. AllReduceStart is broken down into Reduce + AsyncStart. CanonicalAsyncOp GpuGetCanonicalAsyncOp(const HloInstruction& hlo); +// Returns size of the `shape` given the `pointer_size`. +int64_t GetSizeOfShape(const Shape& shape, int pointer_size); + // GPU specific resources for latency hiding scheduler. 
// // We use two different set of resources to model the scheduling of asynchronous @@ -95,6 +99,25 @@ class GpuAsyncTracker : public GpuAsyncTrackerBase { int64_t resource_type, const HloInstruction& instr) const override; }; +// GPU approximate latency estimator. It is a set of hardcoded heuristics +// for every instruction and async instruction pairs. +class GpuLatencyEstimator : public ApproximateLatencyEstimator { + public: + explicit GpuLatencyEstimator( + int64_t pointer_size, + GetCanonicalAsyncOpFunc func = GpuGetCanonicalAsyncOp); + + // Uses the approximate node for an instruction `instr`. + TimeCost NodeCost(const HloInstruction* instr) const override; + + // Returns a latency estimation between nodes `from` and `to`. + TimeCost GetLatencyBetween(const HloGraphNode& from, + const HloGraphNode& to) const override; + + private: + int64_t pointer_size_; +}; + } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/nvptx_compiler_test.cc b/third_party/xla/xla/service/gpu/nvptx_compiler_test.cc index 991e209df057a1..642a0cc9eca438 100644 --- a/third_party/xla/xla/service/gpu/nvptx_compiler_test.cc +++ b/third_party/xla/xla/service/gpu/nvptx_compiler_test.cc @@ -29,6 +29,7 @@ limitations under the License. #include "xla/service/buffer_value.h" #include "xla/service/gpu/gpu_constants.h" #include "xla/service/gpu/gpu_hlo_schedule.h" +#include "xla/service/gpu/gpu_latency_hiding_scheduler.h" #include "xla/service/hlo_ordering.h" #include "xla/service/logical_buffer.h" #include "xla/stream_executor/device_description.h" From c4a89ad0e9f9002c7637d3901f962b937b73a322 Mon Sep 17 00:00:00 2001 From: Oleg Shyshkov Date: Tue, 18 Jun 2024 06:25:51 -0700 Subject: [PATCH 10/59] [XLA:GPU] Use absl::Span instead of std::vector to pass tile sizes. Tile sizes are usually small, so it's better to use InlinedVector or SmallVector to store them. 
PiperOrigin-RevId: 644362570 --- .../xla/xla/service/gpu/model/gpu_indexing_performance_model.cc | 2 +- .../xla/xla/service/gpu/model/gpu_indexing_performance_model.h | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.cc b/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.cc index 6b8cc815fc52f0..3e583c84f2d461 100644 --- a/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.cc +++ b/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.cc @@ -254,7 +254,7 @@ absl::StatusOr GpuPerformanceModelWithIndexingAnalysis::EstimateRunTimeForTiledFusion( const HloFusionAdaptor& fusion_adaptor, const LaunchDimensions& launch_dimensions, - const std::vector& tile_sizes) { + absl::Span tile_sizes) { // TODO(b/332714755): Add caching for SymbolicTileAnalysis. SymbolicTileAnalysisOrError analysis_or_error = SymbolicTileAnalysis::AnalyzeFusion(fusion_adaptor, mlir_context_); diff --git a/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.h b/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.h index a1f98a8660d663..b70c29e6b722c3 100644 --- a/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.h +++ b/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.h @@ -74,7 +74,7 @@ class GpuPerformanceModelWithIndexingAnalysis : public GpuPerformanceModelBase { absl::StatusOr EstimateRunTimeForTiledFusion( const HloFusionAdaptor& fusion_adaptor, const LaunchDimensions& launch_dimensions, - const std::vector& output_tile_sizes); + absl::Span output_tile_sizes); // Estimate the run time of producer and consumer fused together, assuming // that they will be emitted with Triton. From 85b905212e917741f247146b8703ecf6b3553bbc Mon Sep 17 00:00:00 2001 From: Christian Sigg Date: Tue, 18 Jun 2024 06:33:50 -0700 Subject: [PATCH 11/59] Move BlockedSparseToMMA pattern from Triton to XLA. 
PiperOrigin-RevId: 644364189 --- .../triton/xla_extensions/sparse_dot.patch | 191 +++++------------- .../triton/xla_extensions/sparse_dot.patch | 191 +++++------------- third_party/xla/xla/service/gpu/BUILD | 1 + .../xla/service/gpu/ir_emitter_triton_cuda.cc | 1 + .../xla/service/gpu/ir_emitter_triton_rocm.cc | 1 + .../tests/sparse_ttg_accelerate_matmul.mlir | 2 +- .../service/gpu/triton_sparse_extensions.cc | 142 ++++++++++++- .../service/gpu/triton_sparse_extensions.h | 1 + 8 files changed, 236 insertions(+), 294 deletions(-) diff --git a/third_party/triton/xla_extensions/sparse_dot.patch b/third_party/triton/xla_extensions/sparse_dot.patch index 307e4cdbb947b4..ae8c5c61522860 100644 --- a/third_party/triton/xla_extensions/sparse_dot.patch +++ b/third_party/triton/xla_extensions/sparse_dot.patch @@ -181,46 +181,48 @@ diff --git a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp b/lib/Dialect index c47558fa6..35f0cca95 100644 --- a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp +++ b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp -@@ -37,7 +37,8 @@ static int getMMAVersionSafe(int computeCapability, DotOp op) { +@@ -37,8 +37,8 @@ static int getMMAVersionSafe(int computeCapability, DotOp op) { return 0; } -SmallVector warpsPerTileV2(DotOp dotOp, const ArrayRef shape, -+template -+SmallVector warpsPerTileV2(DotType dotOp, const ArrayRef shape, - int numWarps) { +- int numWarps) { ++SmallVector ++warpsPerTileV2(Operation *dotOp, const ArrayRef shape, int numWarps) { auto rank = shape.size(); // Early exit for batched matmul -@@ -51,8 +52,8 @@ SmallVector warpsPerTileV2(DotOp dotOp, const ArrayRef shape, + if (rank == 3) +@@ -51,9 +51,8 @@ SmallVector warpsPerTileV2(DotOp dotOp, const ArrayRef shape, auto slices = multiRootGetSlice(dotOp, {filter}, {filter}); bool hasChainedDot = false; for (Operation *op : slices) { - if (isa(op) && (op != dotOp)) { - auto chainedDot = cast(op); -+ if (isa(op) && (op != dotOp)) { -+ auto chainedDot = cast(op); - 
auto resTy = chainedDot.getResult().getType(); +- auto resTy = chainedDot.getResult().getType(); ++ if (dotOp->getName() == op->getName() && op != dotOp) { ++ auto resTy = cast(op->getResult(0).getType()); if (resTy.getRank() != rank) { continue; -@@ -96,12 +97,13 @@ SmallVector warpsPerTileV2(DotOp dotOp, const ArrayRef shape, - return ret; + } +@@ -97,12 +96,13 @@ SmallVector warpsPerTileV2(DotOp dotOp, const ArrayRef shape, } --SmallVector + SmallVector -warpsPerTileV3(DotOp dotOp, const ArrayRef shape, int numWarps, -- const SmallVector &instrShape) { -+template -+SmallVector warpsPerTileV3( -+ DotType dotOp, const ArrayRef shape, int numWarps, -+ const SmallVector &instrShape) { ++warpsPerTileV3(Operation *dotOp, const ArrayRef shape, int numWarps, + const SmallVector &instrShape) { SetVector slices; - mlir::getForwardSlice(dotOp.getResult(), &slices); +- mlir::getForwardSlice(dotOp.getResult(), &slices); - if (llvm::find_if(slices, [](Operation *op) { return isa(op); }) != -+ if (llvm::find_if(slices, [](Operation *op) { return isa(op); }) != - slices.end()) +- slices.end()) ++ mlir::getForwardSlice(dotOp->getResult(0), &slices); ++ if (llvm::find_if(slices, [&](Operation *op) { ++ return dotOp->getName() == op->getName(); ++ }) != slices.end()) return {(unsigned)numWarps, 1}; -@@ -191,6 +193,7 @@ class BlockedToMMA : public mlir::OpRewritePattern { + // For MMAv3, the smallest indivisible unit of warp shape is (4, 1). +@@ -167,6 +167,7 @@ class BlockedToMMA : public mlir::OpRewritePattern { mlir::TypeID::get()); } @@ -228,140 +230,37 @@ index c47558fa6..35f0cca95 100644 // Finds the first different bitwidth in the chain of shape-preserving // unary ops that x depends on. 
// There are two primary scenarios: -@@ -224,14 +227,14 @@ class BlockedToMMA : public mlir::OpRewritePattern { - return origBitWidth; +@@ -206,7 +207,7 @@ public: } --public: - BlockedToMMA(mlir::MLIRContext *context, int computeCapability) - : OpRewritePattern(context), computeCapability(computeCapability) { - } - -- static SmallVector + static SmallVector - getWarpsPerTile(DotOp dotOp, const ArrayRef shape, int version, -- int numWarps, const SmallVector &instrShape) { -+ template -+ static SmallVector getWarpsPerTile( -+ DotType dotOp, const ArrayRef shape, int version, int numWarps, -+ const SmallVector &instrShape) { ++ getWarpsPerTile(Operation *dotOp, const ArrayRef shape, int version, + int numWarps, const SmallVector &instrShape) { switch (version) { case 2: - return warpsPerTileV2(dotOp, shape, numWarps); -@@ -359,6 +362,106 @@ public: - return success(); +@@ -405,6 +406,21 @@ public: } }; -+ -+class SparseBlockedToMMA : public mlir::RewritePattern { -+ public: -+ using SparseDotOp = mlir::triton::gpu::SparseDotOp; -+ using SparseDotMetaEncodingAttr = -+ mlir::triton::gpu::SparseDotMetaEncodingAttr; -+ -+ SparseBlockedToMMA(mlir::MLIRContext *context, int computeCapability) -+ : mlir::RewritePattern(SparseDotOp::getOperationName(), 2, context), -+ computeCapability(computeCapability) {} -+ -+ mlir::LogicalResult matchAndRewrite( -+ mlir::Operation *op, mlir::PatternRewriter &rewriter) const override { -+ auto dotOp = cast(op); -+ auto ctx = op->getContext(); -+ Value a = dotOp.getA(); -+ Value b = dotOp.getB(); -+ -+ // Check data-types and SM compatibility -+ RankedTensorType oldRetType = dotOp.getType(); -+ if (!oldRetType.getEncoding() || -+ isa(oldRetType.getEncoding())) -+ return failure(); -+ -+ assert(computeCapability >= 80 && -+ "SparseDot is supported on Ampere and higher"); -+ bool allowV3 = !triton::tools::getBoolEnv("DISABLE_MMA_V3"); -+ int versionMajor = computeCapability >= 90 && allowV3 ? 
3 : 2; -+ -+ // get MMA encoding for the given number of warps -+ auto retShapePerCTA = getShapePerCTA(oldRetType); -+ auto mod = op->getParentOfType(); -+ int numWarps = TritonGPUDialect::getNumWarps(mod); -+ auto CTALayout = getCTALayout(oldRetType.getEncoding()); -+ -+ auto instrShape = -+ mmaVersionToInstrShape(versionMajor, retShapePerCTA, -+ cast(a.getType()), numWarps); -+ auto warpsPerTile = BlockedToMMA::getWarpsPerTile( -+ dotOp, retShapePerCTA, versionMajor, numWarps, instrShape); -+ NvidiaMmaEncodingAttr mmaEnc = -+ NvidiaMmaEncodingAttr::get(ctx, versionMajor, /*versionMinor=*/0, -+ warpsPerTile, CTALayout, instrShape); -+ auto newRetType = RankedTensorType::get( -+ oldRetType.getShape(), oldRetType.getElementType(), mmaEnc); -+ -+ // convert accumulator -+ auto oldAcc = dotOp.getOperand(2); -+ auto newAcc = rewriter.create(oldAcc.getLoc(), -+ newRetType, oldAcc); -+ -+ if (versionMajor == 2) { -+ int minBitwidth = std::min(BlockedToMMA::computeOrigBitWidth(a), -+ BlockedToMMA::computeOrigBitWidth(b)); -+ int kWidth = 32 / minBitwidth; -+ -+ // convert A operand -+ auto oldAType = cast(a.getType()); -+ auto newAEncoding = -+ DotOperandEncodingAttr::get(ctx, 0, mmaEnc, kWidth); -+ auto newAType = RankedTensorType::get( -+ oldAType.getShape(), oldAType.getElementType(), newAEncoding); -+ a = rewriter.create(a.getLoc(), newAType, a); -+ -+ // convert B operand -+ auto oldBType = cast(b.getType()); -+ auto newBEncoding = -+ DotOperandEncodingAttr::get(ctx, 1, mmaEnc, kWidth); -+ auto newBType = RankedTensorType::get( -+ oldBType.getShape(), oldBType.getElementType(), newBEncoding); -+ b = rewriter.create(b.getLoc(), newBType, b); -+ } else { -+ auto eltType = dotOp.getA().getType().getElementType(); -+ // In MMAV3 tranpose is only supported for f16 and bf16. 
-+ bool allowTranspose = eltType.isF16() || eltType.isBF16(); -+ a = getSharedMemoryMMAOperand(a, rewriter, 0, allowTranspose); -+ b = getSharedMemoryMMAOperand(b, rewriter, 1, allowTranspose); -+ } -+ -+ // convert metadata -+ Value meta = dotOp.getAMeta(); -+ auto oldMetaType = cast(meta.getType()); -+ auto newMetaType = RankedTensorType::get( -+ oldMetaType.getShape(), oldMetaType.getElementType(), -+ SparseDotMetaEncodingAttr::get(ctx, mmaEnc)); -+ meta = -+ rewriter.create(meta.getLoc(), newMetaType, meta); -+ -+ // convert dot instruction -+ auto newDot = rewriter.create(dotOp.getLoc(), newRetType, a, b, -+ newAcc, meta); -+ -+ rewriter.replaceOpWithNewOp(op, oldRetType, -+ newDot.getResult()); -+ return success(); -+ } -+ -+ private: -+ int computeCapability; -+}; - } // namespace - - static Value promoteOperand(OpBuilder &builder, Location loc, Value operand, -@@ -420,6 +523,7 @@ public: - mlir::RewritePatternSet patterns(context); - patterns.add(context, computeCapability); -+ patterns.add(context, computeCapability); - if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed()) { - signalPassFailure(); - } ++// Expose helper functions from BlockedToMMA to be reused for sparse matmul. 
++SmallVector ++getWarpsPerTile(Operation *dotOp, ArrayRef shape, int version, ++ int numWarps, const SmallVector &instrShape) { ++ return BlockedToMMA::getWarpsPerTile(dotOp, shape, version, numWarps, ++ instrShape); ++} ++int computeOrigBitWidth(Value x) { ++ return BlockedToMMA::computeOrigBitWidth(x); ++} ++Value getSharedMemMMAOperand(Value v, mlir::PatternRewriter &rewriter, ++ int opIdx, bool allowTranspose) { ++ return getSharedMemoryMMAOperand(v, rewriter, opIdx, allowTranspose); ++} ++ + } // namespace gpu + } // namespace triton + } // namespace mlir diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp index 5cb992714..cdafdffce 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp diff --git a/third_party/xla/third_party/triton/xla_extensions/sparse_dot.patch b/third_party/xla/third_party/triton/xla_extensions/sparse_dot.patch index 307e4cdbb947b4..ae8c5c61522860 100644 --- a/third_party/xla/third_party/triton/xla_extensions/sparse_dot.patch +++ b/third_party/xla/third_party/triton/xla_extensions/sparse_dot.patch @@ -181,46 +181,48 @@ diff --git a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp b/lib/Dialect index c47558fa6..35f0cca95 100644 --- a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp +++ b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp -@@ -37,7 +37,8 @@ static int getMMAVersionSafe(int computeCapability, DotOp op) { +@@ -37,8 +37,8 @@ static int getMMAVersionSafe(int computeCapability, DotOp op) { return 0; } -SmallVector warpsPerTileV2(DotOp dotOp, const ArrayRef shape, -+template -+SmallVector warpsPerTileV2(DotType dotOp, const ArrayRef shape, - int numWarps) { +- int numWarps) { ++SmallVector ++warpsPerTileV2(Operation *dotOp, const ArrayRef shape, int numWarps) { auto rank = shape.size(); // Early exit for batched matmul -@@ -51,8 +52,8 @@ SmallVector warpsPerTileV2(DotOp dotOp, const ArrayRef shape, 
+ if (rank == 3) +@@ -51,9 +51,8 @@ SmallVector warpsPerTileV2(DotOp dotOp, const ArrayRef shape, auto slices = multiRootGetSlice(dotOp, {filter}, {filter}); bool hasChainedDot = false; for (Operation *op : slices) { - if (isa(op) && (op != dotOp)) { - auto chainedDot = cast(op); -+ if (isa(op) && (op != dotOp)) { -+ auto chainedDot = cast(op); - auto resTy = chainedDot.getResult().getType(); +- auto resTy = chainedDot.getResult().getType(); ++ if (dotOp->getName() == op->getName() && op != dotOp) { ++ auto resTy = cast(op->getResult(0).getType()); if (resTy.getRank() != rank) { continue; -@@ -96,12 +97,13 @@ SmallVector warpsPerTileV2(DotOp dotOp, const ArrayRef shape, - return ret; + } +@@ -97,12 +96,13 @@ SmallVector warpsPerTileV2(DotOp dotOp, const ArrayRef shape, } --SmallVector + SmallVector -warpsPerTileV3(DotOp dotOp, const ArrayRef shape, int numWarps, -- const SmallVector &instrShape) { -+template -+SmallVector warpsPerTileV3( -+ DotType dotOp, const ArrayRef shape, int numWarps, -+ const SmallVector &instrShape) { ++warpsPerTileV3(Operation *dotOp, const ArrayRef shape, int numWarps, + const SmallVector &instrShape) { SetVector slices; - mlir::getForwardSlice(dotOp.getResult(), &slices); +- mlir::getForwardSlice(dotOp.getResult(), &slices); - if (llvm::find_if(slices, [](Operation *op) { return isa(op); }) != -+ if (llvm::find_if(slices, [](Operation *op) { return isa(op); }) != - slices.end()) +- slices.end()) ++ mlir::getForwardSlice(dotOp->getResult(0), &slices); ++ if (llvm::find_if(slices, [&](Operation *op) { ++ return dotOp->getName() == op->getName(); ++ }) != slices.end()) return {(unsigned)numWarps, 1}; -@@ -191,6 +193,7 @@ class BlockedToMMA : public mlir::OpRewritePattern { + // For MMAv3, the smallest indivisible unit of warp shape is (4, 1). 
+@@ -167,6 +167,7 @@ class BlockedToMMA : public mlir::OpRewritePattern { mlir::TypeID::get()); } @@ -228,140 +230,37 @@ index c47558fa6..35f0cca95 100644 // Finds the first different bitwidth in the chain of shape-preserving // unary ops that x depends on. // There are two primary scenarios: -@@ -224,14 +227,14 @@ class BlockedToMMA : public mlir::OpRewritePattern { - return origBitWidth; +@@ -206,7 +207,7 @@ public: } --public: - BlockedToMMA(mlir::MLIRContext *context, int computeCapability) - : OpRewritePattern(context), computeCapability(computeCapability) { - } - -- static SmallVector + static SmallVector - getWarpsPerTile(DotOp dotOp, const ArrayRef shape, int version, -- int numWarps, const SmallVector &instrShape) { -+ template -+ static SmallVector getWarpsPerTile( -+ DotType dotOp, const ArrayRef shape, int version, int numWarps, -+ const SmallVector &instrShape) { ++ getWarpsPerTile(Operation *dotOp, const ArrayRef shape, int version, + int numWarps, const SmallVector &instrShape) { switch (version) { case 2: - return warpsPerTileV2(dotOp, shape, numWarps); -@@ -359,6 +362,106 @@ public: - return success(); +@@ -405,6 +406,21 @@ public: } }; -+ -+class SparseBlockedToMMA : public mlir::RewritePattern { -+ public: -+ using SparseDotOp = mlir::triton::gpu::SparseDotOp; -+ using SparseDotMetaEncodingAttr = -+ mlir::triton::gpu::SparseDotMetaEncodingAttr; -+ -+ SparseBlockedToMMA(mlir::MLIRContext *context, int computeCapability) -+ : mlir::RewritePattern(SparseDotOp::getOperationName(), 2, context), -+ computeCapability(computeCapability) {} -+ -+ mlir::LogicalResult matchAndRewrite( -+ mlir::Operation *op, mlir::PatternRewriter &rewriter) const override { -+ auto dotOp = cast(op); -+ auto ctx = op->getContext(); -+ Value a = dotOp.getA(); -+ Value b = dotOp.getB(); -+ -+ // Check data-types and SM compatibility -+ RankedTensorType oldRetType = dotOp.getType(); -+ if (!oldRetType.getEncoding() || -+ isa(oldRetType.getEncoding())) -+ return failure(); -+ -+ 
assert(computeCapability >= 80 && -+ "SparseDot is supported on Ampere and higher"); -+ bool allowV3 = !triton::tools::getBoolEnv("DISABLE_MMA_V3"); -+ int versionMajor = computeCapability >= 90 && allowV3 ? 3 : 2; -+ -+ // get MMA encoding for the given number of warps -+ auto retShapePerCTA = getShapePerCTA(oldRetType); -+ auto mod = op->getParentOfType(); -+ int numWarps = TritonGPUDialect::getNumWarps(mod); -+ auto CTALayout = getCTALayout(oldRetType.getEncoding()); -+ -+ auto instrShape = -+ mmaVersionToInstrShape(versionMajor, retShapePerCTA, -+ cast(a.getType()), numWarps); -+ auto warpsPerTile = BlockedToMMA::getWarpsPerTile( -+ dotOp, retShapePerCTA, versionMajor, numWarps, instrShape); -+ NvidiaMmaEncodingAttr mmaEnc = -+ NvidiaMmaEncodingAttr::get(ctx, versionMajor, /*versionMinor=*/0, -+ warpsPerTile, CTALayout, instrShape); -+ auto newRetType = RankedTensorType::get( -+ oldRetType.getShape(), oldRetType.getElementType(), mmaEnc); -+ -+ // convert accumulator -+ auto oldAcc = dotOp.getOperand(2); -+ auto newAcc = rewriter.create(oldAcc.getLoc(), -+ newRetType, oldAcc); -+ -+ if (versionMajor == 2) { -+ int minBitwidth = std::min(BlockedToMMA::computeOrigBitWidth(a), -+ BlockedToMMA::computeOrigBitWidth(b)); -+ int kWidth = 32 / minBitwidth; -+ -+ // convert A operand -+ auto oldAType = cast(a.getType()); -+ auto newAEncoding = -+ DotOperandEncodingAttr::get(ctx, 0, mmaEnc, kWidth); -+ auto newAType = RankedTensorType::get( -+ oldAType.getShape(), oldAType.getElementType(), newAEncoding); -+ a = rewriter.create(a.getLoc(), newAType, a); -+ -+ // convert B operand -+ auto oldBType = cast(b.getType()); -+ auto newBEncoding = -+ DotOperandEncodingAttr::get(ctx, 1, mmaEnc, kWidth); -+ auto newBType = RankedTensorType::get( -+ oldBType.getShape(), oldBType.getElementType(), newBEncoding); -+ b = rewriter.create(b.getLoc(), newBType, b); -+ } else { -+ auto eltType = dotOp.getA().getType().getElementType(); -+ // In MMAV3 tranpose is only supported for f16 and 
bf16. -+ bool allowTranspose = eltType.isF16() || eltType.isBF16(); -+ a = getSharedMemoryMMAOperand(a, rewriter, 0, allowTranspose); -+ b = getSharedMemoryMMAOperand(b, rewriter, 1, allowTranspose); -+ } -+ -+ // convert metadata -+ Value meta = dotOp.getAMeta(); -+ auto oldMetaType = cast(meta.getType()); -+ auto newMetaType = RankedTensorType::get( -+ oldMetaType.getShape(), oldMetaType.getElementType(), -+ SparseDotMetaEncodingAttr::get(ctx, mmaEnc)); -+ meta = -+ rewriter.create(meta.getLoc(), newMetaType, meta); -+ -+ // convert dot instruction -+ auto newDot = rewriter.create(dotOp.getLoc(), newRetType, a, b, -+ newAcc, meta); -+ -+ rewriter.replaceOpWithNewOp(op, oldRetType, -+ newDot.getResult()); -+ return success(); -+ } -+ -+ private: -+ int computeCapability; -+}; - } // namespace - - static Value promoteOperand(OpBuilder &builder, Location loc, Value operand, -@@ -420,6 +523,7 @@ public: - mlir::RewritePatternSet patterns(context); - patterns.add(context, computeCapability); -+ patterns.add(context, computeCapability); - if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed()) { - signalPassFailure(); - } ++// Expose helper functions from BlockedToMMA to be reused for sparse matmul. 
++SmallVector ++getWarpsPerTile(Operation *dotOp, ArrayRef shape, int version, ++ int numWarps, const SmallVector &instrShape) { ++ return BlockedToMMA::getWarpsPerTile(dotOp, shape, version, numWarps, ++ instrShape); ++} ++int computeOrigBitWidth(Value x) { ++ return BlockedToMMA::computeOrigBitWidth(x); ++} ++Value getSharedMemMMAOperand(Value v, mlir::PatternRewriter &rewriter, ++ int opIdx, bool allowTranspose) { ++ return getSharedMemoryMMAOperand(v, rewriter, opIdx, allowTranspose); ++} ++ + } // namespace gpu + } // namespace triton + } // namespace mlir diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp index 5cb992714..cdafdffce 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index 90d8940b681515..b61d9785727ec1 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -606,6 +606,7 @@ cc_library( "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", "@triton//:TritonDialects", + "@triton//:TritonGPUToLLVM", "@triton//:TritonGPUTransforms", ], ) diff --git a/third_party/xla/xla/service/gpu/ir_emitter_triton_cuda.cc b/third_party/xla/xla/service/gpu/ir_emitter_triton_cuda.cc index 217ad889416a4a..c8ffbfacbac30b 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter_triton_cuda.cc +++ b/third_party/xla/xla/service/gpu/ir_emitter_triton_cuda.cc @@ -74,6 +74,7 @@ absl::Status CreateTritonPipeline( pm.addPass(mlir::createTritonNvidiaGPUPlanCTAPass(&out_cluster_info)); pm.addPass(mt::gpu::createTritonGPURemoveLayoutConversions()); pm.addPass(mt::gpu::createTritonGPUOptimizeThreadLocality()); + pm.addPass(createSparseBlockedToMMAPass()); pm.addPass(mt::gpu::createTritonGPUAccelerateMatmul()); pm.addPass(mt::gpu::createTritonGPURemoveLayoutConversions()); pm.addPass( diff --git 
a/third_party/xla/xla/service/gpu/ir_emitter_triton_rocm.cc b/third_party/xla/xla/service/gpu/ir_emitter_triton_rocm.cc index 79ef0ce4a5d05c..2a6e6c3c805cd8 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter_triton_rocm.cc +++ b/third_party/xla/xla/service/gpu/ir_emitter_triton_rocm.cc @@ -75,6 +75,7 @@ absl::Status CreateTritonPipeline( pm.addPass(mt::gpu::createTritonGPUCoalesce()); pm.addPass(mt::gpu::createTritonGPURemoveLayoutConversions()); pm.addPass(mt::gpu::createTritonGPUOptimizeThreadLocality()); + pm.addPass(createSparseBlockedToMMAPass()); pm.addPass(mt::gpu::createTritonGPUAccelerateMatmul()); pm.addPass(mt::gpu::createTritonGPURemoveLayoutConversions()); // TODO ROCm Check if we want to compare MI100 and greater diff --git a/third_party/xla/xla/service/gpu/tests/sparse_ttg_accelerate_matmul.mlir b/third_party/xla/xla/service/gpu/tests/sparse_ttg_accelerate_matmul.mlir index f1b0f88932a1f9..65bcbd87ddf130 100644 --- a/third_party/xla/xla/service/gpu/tests/sparse_ttg_accelerate_matmul.mlir +++ b/third_party/xla/xla/service/gpu/tests/sparse_ttg_accelerate_matmul.mlir @@ -1,4 +1,4 @@ -// RUN: triton-opt %s -split-input-file -tritongpu-accelerate-matmul | FileCheck %s +// RUN: sparse-opt %s -split-input-file -sparse-blocked-to-mma | FileCheck %s #blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> // CHECK: #[[MMA:.+]] = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> diff --git a/third_party/xla/xla/service/gpu/triton_sparse_extensions.cc b/third_party/xla/xla/service/gpu/triton_sparse_extensions.cc index d7495d07a793a4..8b05928c17e677 100644 --- a/third_party/xla/xla/service/gpu/triton_sparse_extensions.cc +++ b/third_party/xla/xla/service/gpu/triton_sparse_extensions.cc @@ -15,6 +15,8 @@ limitations under the License. 
#include "xla/service/gpu/triton_sparse_extensions.h" +#include +#include #include #include #include @@ -32,14 +34,27 @@ limitations under the License. #include "mlir/Support/LogicalResult.h" #include "mlir/Support/TypeID.h" #include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/Support/CommandLine.h" -#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" #include "triton/Dialect/TritonGPU/IR/Attributes.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Tools/Sys/GetEnv.hpp" using namespace mlir; // NOLINT(build/namespaces) +// The functions below are defined in AccelerateMatmul.cpp. +namespace mlir::triton::gpu { +SmallVector getWarpsPerTile( + Operation *dotOp, ArrayRef shape, int version, int numWarps, + const SmallVector &instrShape); +int computeOrigBitWidth(Value x); +Value getSharedMemMMAOperand(Value v, mlir::PatternRewriter &rewriter, + int opIdx, bool allowTranspose); +} // namespace mlir::triton::gpu + namespace { struct TritonSparseDotPattern @@ -175,6 +190,126 @@ class AddSparseDotEncodingPass llvm::cl::init(1)}; }; +class SparseBlockedToMMA : public RewritePattern { + using ConvertLayoutOp = triton::gpu::ConvertLayoutOp; + using SparseDotOp = triton::gpu::SparseDotOp; + using SparseDotMetaEncodingAttr = triton::gpu::SparseDotMetaEncodingAttr; + using NvidiaMmaEncodingAttr = triton::gpu::NvidiaMmaEncodingAttr; + + public: + SparseBlockedToMMA(MLIRContext *context, int compute_capability) + : RewritePattern(SparseDotOp::getOperationName(), 2, context), + compute_capability_(compute_capability) {} + + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override { + auto dotOp = cast(op); + auto ctx = op->getContext(); + Value a = dotOp.getA(); + Value b = dotOp.getB(); + + 
// Check data-types and SM compatibility + RankedTensorType oldRetType = dotOp.getType(); + if (!oldRetType.getEncoding() || + isa(oldRetType.getEncoding())) + return failure(); + + assert(compute_capability_ >= 80 && + "SparseDot is supported on Ampere and higher"); + bool allowV3 = !triton::tools::getBoolEnv("DISABLE_MMA_V3"); + int versionMajor = compute_capability_ >= 90 && allowV3 ? 3 : 2; + + // get MMA encoding for the given number of warps + auto retShapePerCTA = triton::gpu::getShapePerCTA(oldRetType); + auto mod = op->getParentOfType(); + int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod); + auto CTALayout = triton::gpu::getCTALayout(oldRetType.getEncoding()); + + auto instrShape = + mmaVersionToInstrShape(versionMajor, retShapePerCTA, + cast(a.getType()), numWarps); + auto warpsPerTile = getWarpsPerTile(dotOp, retShapePerCTA, versionMajor, + numWarps, instrShape); + NvidiaMmaEncodingAttr mmaEnc = + NvidiaMmaEncodingAttr::get(ctx, versionMajor, /*versionMinor=*/0, + warpsPerTile, CTALayout, instrShape); + auto newRetType = RankedTensorType::get( + oldRetType.getShape(), oldRetType.getElementType(), mmaEnc); + + // convert accumulator + auto oldAcc = dotOp.getOperand(2); + auto newAcc = rewriter.create( + oldAcc.getLoc(), newRetType, oldAcc); + + if (versionMajor == 2) { + int minBitwidth = std::min(triton::gpu::computeOrigBitWidth(a), + triton::gpu::computeOrigBitWidth(b)); + int kWidth = 32 / minBitwidth; + + // convert A operand + auto oldAType = cast(a.getType()); + auto newAEncoding = DotOperandEncodingAttr::get(ctx, 0, mmaEnc, kWidth); + auto newAType = RankedTensorType::get( + oldAType.getShape(), oldAType.getElementType(), newAEncoding); + a = rewriter.create(a.getLoc(), newAType, a); + + // convert B operand + auto oldBType = cast(b.getType()); + auto newBEncoding = DotOperandEncodingAttr::get(ctx, 1, mmaEnc, kWidth); + auto newBType = RankedTensorType::get( + oldBType.getShape(), oldBType.getElementType(), newBEncoding); + b = 
rewriter.create(b.getLoc(), newBType, b); + } else { + auto eltType = dotOp.getA().getType().getElementType(); + // In MMAV3 transpose is only supported for f16 and bf16. + bool allowTranspose = eltType.isF16() || eltType.isBF16(); + a = triton::gpu::getSharedMemMMAOperand(a, rewriter, 0, allowTranspose); + b = triton::gpu::getSharedMemMMAOperand(b, rewriter, 1, allowTranspose); + } + + // convert metadata + Value meta = dotOp.getAMeta(); + auto oldMetaType = cast(meta.getType()); + auto newMetaType = RankedTensorType::get( + oldMetaType.getShape(), oldMetaType.getElementType(), + SparseDotMetaEncodingAttr::get(ctx, mmaEnc)); + meta = rewriter.create(meta.getLoc(), newMetaType, meta); + + // convert dot instruction + auto newDot = rewriter.create(dotOp.getLoc(), newRetType, a, b, + newAcc, meta); + + rewriter.replaceOpWithNewOp(op, oldRetType, + newDot.getResult()); + return success(); + } + + private: + int compute_capability_; +}; + +class SparseBlockedToMMAPass + : public PassWrapper> { + public: + SparseBlockedToMMAPass() = default; + + StringRef getArgument() const override { return "sparse-blocked-to-mma"; } + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp module = getOperation(); + auto compute_capability = getNVIDIAComputeCapability(module); + auto pattern = + std::make_unique(context, compute_capability); + RewritePatternSet patterns(context, std::move(pattern)); + if (failed(applyPatternsAndFoldGreedily(module, std::move(patterns)))) { + signalPassFailure(); + } + } + + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SparseBlockedToMMAPass) +}; + } // namespace std::unique_ptr xla::gpu::createAddSparseDotEncodingPass( @@ -183,6 +318,11 @@ std::unique_ptr xla::gpu::createAddSparseDotEncodingPass( num_ctas); } +std::unique_ptr xla::gpu::createSparseBlockedToMMAPass() { + return std::make_unique(); +} + void xla::gpu::registerSparsePasses() { registerPass([] { return std::make_unique(); }); + registerPass([] { return 
std::make_unique(); }); } diff --git a/third_party/xla/xla/service/gpu/triton_sparse_extensions.h b/third_party/xla/xla/service/gpu/triton_sparse_extensions.h index 6826a6b291d361..d3d6989e006c83 100644 --- a/third_party/xla/xla/service/gpu/triton_sparse_extensions.h +++ b/third_party/xla/xla/service/gpu/triton_sparse_extensions.h @@ -27,6 +27,7 @@ namespace xla::gpu { std::unique_ptr createAddSparseDotEncodingPass( int32_t num_warps, int32_t threads_per_warp, int32_t num_ctas); +std::unique_ptr createSparseBlockedToMMAPass(); void registerSparsePasses(); From 1107c80eaeff8e30df4fe29d3d5a8ba2e9641d6a Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Tue, 18 Jun 2024 06:41:53 -0700 Subject: [PATCH 12/59] [XLA:GPU][NFC] Replace `bitcast`s with `reshape`s in `symbolic_tile_test`. `bitcast`s are not meaningful pre-optimizations because intermediate HLO ops do not have a layout at that point. For that reason, incorrect `bitcast`s evade verifier checks in `ParseAndReturnVerifiedModule`. This was hiding a data type mismatch in our tests. Since all the `bitcast`s in `symbolic_tile_test` have `reshape` semantics, we simply replace them with `reshape`s, which is handled well by `ParseAndReturnVerifiedModule`. 
PiperOrigin-RevId: 644366242 --- .../service/gpu/model/symbolic_tile_test.cc | 24 +++++++++---------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/third_party/xla/xla/service/gpu/model/symbolic_tile_test.cc b/third_party/xla/xla/service/gpu/model/symbolic_tile_test.cc index 8ce8c53184854b..34cb38c265303d 100644 --- a/third_party/xla/xla/service/gpu/model/symbolic_tile_test.cc +++ b/third_party/xla/xla/service/gpu/model/symbolic_tile_test.cc @@ -124,7 +124,7 @@ TEST_F(SymbolicTileTest, HloModule m ENTRY e { p0 = f32[1,8,6,4]{3,2,1,0} parameter(0) - ROOT bitcast = f32[48,4]{1,0} bitcast(p0) + ROOT reshape = f32[48,4]{1,0} reshape(p0) } )")); @@ -152,7 +152,7 @@ TEST_F(SymbolicTileTest, HloModule m ENTRY e { p0 = f32[192,4]{1,0} parameter(0) - ROOT bitcast = s8[4,8,6,4]{3,2,1,0} bitcast(p0) + ROOT reshape = f32[4,8,6,4]{3,2,1,0} reshape(p0) } )")); @@ -542,7 +542,7 @@ TEST_F(SymbolicTileTest, CanPropagateTileThroughSplitReshapeOfReverse) { computation { p0 = f32[1,8,6,4]{3,2,1,0} parameter(0) reverse = f32[1,8,6,4]{3,2,1,0} reverse(p0), dimensions={1,2} - ROOT bitcast = f32[48,4]{1,0} bitcast(reverse) + ROOT reshape = f32[48,4]{1,0} reshape(reverse) } ENTRY e { @@ -572,8 +572,8 @@ TEST_F(SymbolicTileTest, HloModule m computation { p0 = f32[1,8,6,4]{3,2,1,0} parameter(0) - bitcast = f32[48,4]{1,0} bitcast(p0) - ROOT slice = f32[5,2]{1,0} slice(bitcast), slice={[18:43:5], [0:4:2]} + reshape = f32[48,4]{1,0} reshape(p0) + ROOT slice = f32[5,2]{1,0} slice(reshape), slice={[18:43:5], [0:4:2]} } ENTRY e { @@ -597,8 +597,8 @@ TEST_F(SymbolicTileTest, HloModule m computation { p0 = f32[1,8,6,4]{3,2,1,0} parameter(0) - bitcast = f32[48,4]{1,0} bitcast(p0) - ROOT slice = f32[5,2]{1,0} slice(bitcast), slice={[20:45:5], [0:4:2]} + reshape = f32[48,4]{1,0} reshape(p0) + ROOT slice = f32[5,2]{1,0} slice(reshape), slice={[20:45:5], [0:4:2]} } ENTRY e { @@ -621,8 +621,8 @@ TEST_F(SymbolicTileTest, computation { p0 = f32[1,6,8,4]{3,2,1,0} parameter(0) transpose = 
f32[1,8,6,4]{3,2,1,0} transpose(p0), dimensions={0,2,1,3} - bitcast = f32[48,4]{1,0} bitcast(transpose) - ROOT slice = f32[5,2]{1,0} slice(bitcast), slice={[18:43:5], [0:4:2]} + reshape = f32[48,4]{1,0} reshape(transpose) + ROOT slice = f32[5,2]{1,0} slice(reshape), slice={[18:43:5], [0:4:2]} } ENTRY e { @@ -646,8 +646,8 @@ TEST_F(SymbolicTileTest, computation { p0 = f32[1,8,6,4]{3,2,1,0} parameter(0) reverse = f32[1,8,6,4]{3,2,1,0} reverse(p0), dimensions={1,2} - bitcast = f32[48,4]{1,0} bitcast(reverse) - ROOT slice = f32[5,2]{1,0} slice(bitcast), slice={[18:43:5], [0:4:2]} + reshape = f32[48,4]{1,0} reshape(reverse) + ROOT slice = f32[5,2]{1,0} slice(reshape), slice={[18:43:5], [0:4:2]} } ENTRY e { @@ -699,7 +699,7 @@ TEST_F(SymbolicTileTest, CanCombineCompatibleConstraints) { HloModule m ENTRY e { p0 = f32[1,8,6,4,8]{4,3,2,1,0} parameter(0) - ROOT bitcast = f32[48,32]{1,0} bitcast(p0) + ROOT reshape = f32[48,32]{1,0} reshape(p0) } )")); From 528cff79c889346cbdef38b86f6631be2cfc48c3 Mon Sep 17 00:00:00 2001 From: Alexander Lyashuk Date: Tue, 18 Jun 2024 07:13:02 -0700 Subject: [PATCH 13/59] [XLA:GPU] Fall back to cuBLASlt call in autotuner when it makes sense When autotuning them, calling cuBLASlt is considered as a fallback. Before this change, only cuBLAS path was considered in the autotuner. This is a temporary change; the GemmRewriter "fp8" parameter will be removed, so only one call will be needed. 
PiperOrigin-RevId: 644374262 --- third_party/xla/xla/service/gpu/BUILD | 2 + .../xla/service/gpu/gemm_fusion_autotuner.cc | 13 ++-- .../service/gpu/gemm_fusion_autotuner_test.cc | 73 +++++++++++++++++++ 3 files changed, 83 insertions(+), 5 deletions(-) diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index b61d9785727ec1..723d6946d22165 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -862,6 +862,7 @@ xla_test( ":backend_configs_cc", ":gemm_fusion", ":gemm_fusion_autotuner", + ":gemm_rewriter", ":ir_emission_utils", ":matmul_utils", "//xla:autotuning_proto_cc", @@ -870,6 +871,7 @@ xla_test( "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", "//xla/pjrt/distributed:key_value_store_interface", + "//xla/service:call_inliner", "//xla/service:executable", "//xla/service:hlo_module_config", "//xla/service:hlo_pass_pipeline", diff --git a/third_party/xla/xla/service/gpu/gemm_fusion_autotuner.cc b/third_party/xla/xla/service/gpu/gemm_fusion_autotuner.cc index f872efc3865101..5a09b09c2f875a 100644 --- a/third_party/xla/xla/service/gpu/gemm_fusion_autotuner.cc +++ b/third_party/xla/xla/service/gpu/gemm_fusion_autotuner.cc @@ -403,11 +403,14 @@ absl::StatusOr> CublasGemmAutotuneExtractor( PrecisionConfig::ALG_DOT_F32_F32_F32); } - GemmRewriter rewriter(config.GetGpuComputeCapability(), toolkit_version); - GpuInstructionFusion fusion_pass( - /*may_duplicate=*/false, config.GetExecutor()->GetDeviceDescription()); - TF_RETURN_IF_ERROR(rewriter.Run(new_module.get()).status()); - TF_RETURN_IF_ERROR(fusion_pass.Run(new_module.get()).status()); + for (bool fp8 : {true, false}) { + GemmRewriter rewriter(config.GetGpuComputeCapability(), toolkit_version, + fp8); + GpuInstructionFusion fusion_pass( + /*may_duplicate=*/false, config.GetExecutor()->GetDeviceDescription()); + TF_RETURN_IF_ERROR(rewriter.Run(new_module.get()).status()); + TF_RETURN_IF_ERROR(fusion_pass.Run(new_module.get()).status()); + } // 
TODO(tdanyluk): Consider running GemmAlgorithmPicker here for better cuBLAS // performance. It is probably not needed on Ampere and later because cuBLAS // ignores the algorithm parameter for those targets. If we run diff --git a/third_party/xla/xla/service/gpu/gemm_fusion_autotuner_test.cc b/third_party/xla/xla/service/gpu/gemm_fusion_autotuner_test.cc index 10c1a641fc12ed..c5195a1506abad 100644 --- a/third_party/xla/xla/service/gpu/gemm_fusion_autotuner_test.cc +++ b/third_party/xla/xla/service/gpu/gemm_fusion_autotuner_test.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include #include +#include #include #include @@ -36,10 +37,12 @@ limitations under the License. #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/pjrt/distributed/key_value_store_interface.h" +#include "xla/service/call_inliner.h" #include "xla/service/executable.h" #include "xla/service/gpu/autotuner_util.h" #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/gemm_fusion.h" +#include "xla/service/gpu/gemm_rewriter.h" #include "xla/service/gpu/ir_emission_utils.h" #include "xla/service/gpu/matmul_utils.h" #include "xla/service/hlo_module_config.h" @@ -566,6 +569,76 @@ class GemmFusionAutotunerDumpTest : public GemmFusionAutotunerTest { } }; +TEST_F(GemmFusionAutotunerDumpTest, Fp8CublasltFallbackSupport) { + const std::string kHloText = R"( +HloModule o + +gemm_fusion { + p0 = f8e4m3fn[64,6144]{1,0} parameter(0) + p1 = f8e4m3fn[64,6144]{1,0} parameter(1) + ROOT %dot.0 = f32[64,64]{1,0} dot(p0, p1), lhs_contracting_dims={1}, rhs_contracting_dims={1} +} + +ENTRY main { + p0 = f8e4m3fn[64,6144]{1,0} parameter(0) + p1 = f8e4m3fn[64,6144]{1,0} parameter(1) + ROOT %dot.0 = f32[64,64]{1,0} fusion(p0, p1), kind=kCustom, calls=gemm_fusion, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"fusion_backend_config":{"kind":"__triton_gemm"},"force_earliest_schedule":false} +})"; + + 
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloText)); + + DebugOptions opts; + AutotuneConfig autotune_config{ + DeviceConfig{backend().default_stream_executor(), + backend().memory_allocator()}, + opts}; + AutotuneCacheKey cache_key(autotune_config.GetModelStr(), + *module->entry_computation()->root_instruction()); + + TF_ASSERT_OK_AND_ASSIGN(AutotuneResults autotune_results_override, + ParseTextProto(R"pb( + version: 3 + results { + device: "..." + hlo: "..." + result { + gemm { algorithm: -1 } + run_time { nanos: 14 } + } + })pb")); + autotune_results_override.mutable_results(0)->set_device( + cache_key.GetModelStr()); + autotune_results_override.mutable_results(0)->set_hlo(cache_key.GetHlo()); + CHECK_OK(AutotunerUtil::LoadAutotuneResults(autotune_results_override)); + + HloPassPipeline pipeline("gemm_autotune"); + tsl::thread::ThreadPool thread_pool(tsl::Env::Default(), "", + tsl::port::MaxParallelism()); + MultiProcessKeyValueStore key_value_store; + pipeline.AddPass(autotune_config, GetToolkitVersion(), + &thread_pool, key_value_store); + pipeline.AddPass(); + for (bool fp8_rewrite : {true, false}) { + pipeline.AddPass(autotune_config.GetGpuComputeCapability(), + GetToolkitVersion(), fp8_rewrite); + } + + EXPECT_OK(HloTestBase::RunHloPass(&pipeline, module.get())); + const bool is_at_least_hopper = + std::holds_alternative( + autotune_config.GetGpuComputeCapability()) && + std::get( + autotune_config.GetGpuComputeCapability()) + .IsAtLeastHopper(); + TF_ASSERT_OK_AND_ASSIGN( + bool filecheck_matches, + RunFileCheck(module->ToString(), is_at_least_hopper + ? 
"// CHECK: __cublas$lt" + : "// CHECK: __cublas$gemm")); + EXPECT_TRUE(filecheck_matches); +} + TEST_F(GemmFusionAutotunerDumpTest, DumpingFusionsWorksWithFallback) { // Computation is chosen such that relatively heavy math operations before the // GEMM are not worth fusing because they would get duplicated many times and From e23a71943504c73e397f3f6ac0806520644daaab Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 18 Jun 2024 07:42:17 -0700 Subject: [PATCH 14/59] Fix bug in array type conversion util PiperOrigin-RevId: 644382031 --- third_party/xla/xla/reference_util.cc | 2 +- third_party/xla/xla/reference_util_test.cc | 12 ++++++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/third_party/xla/xla/reference_util.cc b/third_party/xla/xla/reference_util.cc index fb0348c2f278cd..3078a2d6e6d64f 100644 --- a/third_party/xla/xla/reference_util.cc +++ b/third_party/xla/xla/reference_util.cc @@ -42,7 +42,7 @@ namespace xla { auto result = std::make_unique>(input.height(), input.width()); for (int64_t rowno = 0; rowno < input.height(); ++rowno) { - for (int64_t colno = 0; colno < input.height(); ++colno) { + for (int64_t colno = 0; colno < input.width(); ++colno) { (*result)(rowno, colno) = input(rowno, colno); } } diff --git a/third_party/xla/xla/reference_util_test.cc b/third_party/xla/xla/reference_util_test.cc index 1ade118862480f..320d1cac5e63f7 100644 --- a/third_party/xla/xla/reference_util_test.cc +++ b/third_party/xla/xla/reference_util_test.cc @@ -85,6 +85,18 @@ TEST_F(ReferenceUtilTest, ReduceToRowArray2D) { ErrorSpec(0.0001)); } +TEST_F(ReferenceUtilTest, Array2DF32ToF64Test) { + auto result = ReferenceUtil::Array2DF32ToF64(*matrix_); + ASSERT_EQ(result->height(), matrix_->height()); + ASSERT_EQ(result->width(), matrix_->width()); + for (int64_t rowno = 0; rowno < matrix_->height(); ++rowno) { + for (int64_t colno = 0; colno < matrix_->width(); ++colno) { + EXPECT_EQ(static_cast((*matrix_)(rowno, colno)), + (*result)(rowno, 
colno)); + } + } +} + TEST_F(ReferenceUtilTest, Reduce4Dto1DZeroSizedArray) { auto result = LiteralUtil::CreateR1(ReferenceUtil::Reduce4DTo1D( Array4D(1, 0, 1, 1), /*init=*/0, /*dims=*/{0, 1, 2}, From 65550eb40810b2e1521fef68fe695b8ce8c87063 Mon Sep 17 00:00:00 2001 From: Alexander Belyaev Date: Tue, 18 Jun 2024 07:53:52 -0700 Subject: [PATCH 15/59] [XLA:GPU][MLIR-based emitters] Kill thread tiling for MlirColumnReduce. PiperOrigin-RevId: 644384966 --- .../xla/xla/service/gpu/fusions/reduction.cc | 2 +- .../xla/service/gpu/fusions/reduction_base.cc | 73 +++---- .../xla/service/gpu/fusions/reduction_base.h | 6 +- .../xla/service/gpu/fusions/reduction_mlir.cc | 176 +++++++--------- .../xla/service/gpu/fusions/reduction_mlir.h | 11 +- .../gpu/fusions/reduction_mlir_test.cc | 190 ++++++++++++------ 6 files changed, 260 insertions(+), 198 deletions(-) diff --git a/third_party/xla/xla/service/gpu/fusions/reduction.cc b/third_party/xla/xla/service/gpu/fusions/reduction.cc index 881b17a52c2b40..e4186401e9a61a 100644 --- a/third_party/xla/xla/service/gpu/fusions/reduction.cc +++ b/third_party/xla/xla/service/gpu/fusions/reduction.cc @@ -1172,7 +1172,7 @@ ReductionInfo ReductionInfo::Create(const HloFusionAnalysis& analysis) { } int vector_size = GetVectorSize(analysis, reduction_dimensions, num_threads_x, - reduction_tiling, /*for_mlir=*/false); + reduction_tiling); absl::InlinedVector num_threads{1, num_threads_y, num_threads_x}; absl::InlinedVector tiled_shape{shape[0], shape[1], diff --git a/third_party/xla/xla/service/gpu/fusions/reduction_base.cc b/third_party/xla/xla/service/gpu/fusions/reduction_base.cc index 66b095a61762b7..2db96f7cd49ee0 100644 --- a/third_party/xla/xla/service/gpu/fusions/reduction_base.cc +++ b/third_party/xla/xla/service/gpu/fusions/reduction_base.cc @@ -66,7 +66,7 @@ int RowReductionGetRowsPerWarp(int reduced_dimension_size) { int GetVectorSize(const HloFusionAnalysis& analysis, const ReductionDimensions& reduction_dimensions, - int 
num_threads, Vector3 reduction_tiling, bool for_mlir) { + int num_threads, Vector3 reduction_tiling) { // If the minor dimension is not divisible by 2, we can't currently vectorize. int64_t minor_dim = reduction_dimensions.dimensions.back(); if (minor_dim % 2 != 0) { @@ -77,42 +77,9 @@ int GetVectorSize(const HloFusionAnalysis& analysis, if (num_threads * 2 > minor_dim) { return 1; } - - if (for_mlir) { - // MLIR vectorizes loads/stores explicitly and therefore doesn't need - // unrolling and the associated heuristics. - - // MLIR's vectorization doesn't work with complex types. However, complex - // load/stores are effectively always vectorized and have a size - // of at least 8 bytes, which is sufficient. - for (HloInstructionAdaptor hero : analysis.fusion_heroes()) { - for (HloInstructionAdaptor operand : hero.GetOperands()) { - if (primitive_util::IsComplexType(operand.shape().element_type())) { - return 1; - } - } - } - - // 16 byte vector loads are often slower than 8 byte loads. - if (analysis.input_output_info().smallest_input_dtype_bits >= 32) { - return 2; - } - if (analysis.input_output_info().smallest_input_dtype_bits >= 64) { - return 1; - } - - // Like above, if the size of the minor dimension is not sufficiently large, - // the vectorization is not helpful. - if (num_threads * 4 > minor_dim) { - return 2; - } - return minor_dim % 4 == 0 ? 4 : 2; - } - if (MayPreventVectorization(analysis.fusion())) { return 1; } - if (reduction_dimensions.is_row_reduction) { constexpr int kRowMinorReduced = ReductionDimensions::kRowMinorReducedDimension; @@ -132,10 +99,46 @@ int GetVectorSize(const HloFusionAnalysis& analysis, } return 1; } - return 1; } +int GetVectorSizeForMlir(const HloFusionAnalysis& analysis, + const ReductionDimensions& reduction_dimensions, + int num_threads) { + // If the minor dimension is not divisible by 2, we can't currently vectorize. 
+ int64_t minor_dim = reduction_dimensions.dimensions.back(); + if (minor_dim % 2 != 0) { + return 1; + } + // Only enable vectorization if all threads will still have work. + if (num_threads * 2 > minor_dim) { + return 1; + } + // MLIR's vectorization doesn't work with complex types. However, complex + // load/stores are effectively always vectorized and have a size + // of at least 8 bytes, which is sufficient. + for (HloInstructionAdaptor hero : analysis.fusion_heroes()) { + for (HloInstructionAdaptor operand : hero.GetOperands()) { + if (primitive_util::IsComplexType(operand.shape().element_type())) { + return 1; + } + } + } + // 16 byte vector loads are often slower than 8 byte loads. + if (analysis.input_output_info().smallest_input_dtype_bits >= 32) { + return 2; + } + if (analysis.input_output_info().smallest_input_dtype_bits >= 64) { + return 1; + } + // Like above, if the size of the minor dimension is not sufficiently large, + // the vectorization is not helpful. + if (num_threads * 4 > minor_dim) { + return 2; + } + return minor_dim % 4 == 0 ? 
4 : 2; +} + ReductionGroups GroupDisjointReductions(const HloFusionAnalysis& analysis, bool for_mlir) { const int num_fusion_outputs = analysis.fusion_root_count(); diff --git a/third_party/xla/xla/service/gpu/fusions/reduction_base.h b/third_party/xla/xla/service/gpu/fusions/reduction_base.h index 7bf4437deadd74..6c0c2c787bd242 100644 --- a/third_party/xla/xla/service/gpu/fusions/reduction_base.h +++ b/third_party/xla/xla/service/gpu/fusions/reduction_base.h @@ -45,7 +45,11 @@ int RowReductionGetRowsPerWarp(int reduced_dimension_size); int GetVectorSize(const HloFusionAnalysis& analysis, const ReductionDimensions& reduction_dimensions, - int num_threads, Vector3 reduction_tiling, bool for_mlir); + int num_threads, Vector3 reduction_tiling); + +int GetVectorSizeForMlir(const HloFusionAnalysis& analysis, + const ReductionDimensions& reduction_dimensions, + int num_threads); void AddGroupIdConstraint(IndexingMap& map, int64_t root_index, const ReductionGroups& groups); diff --git a/third_party/xla/xla/service/gpu/fusions/reduction_mlir.cc b/third_party/xla/xla/service/gpu/fusions/reduction_mlir.cc index b3ec1f8e00ee2e..46e1bddf5e5dd8 100644 --- a/third_party/xla/xla/service/gpu/fusions/reduction_mlir.cc +++ b/third_party/xla/xla/service/gpu/fusions/reduction_mlir.cc @@ -229,7 +229,7 @@ absl::Status MlirReductionFusion::EmitEntryFunction( return absl::OkStatus(); } -IndexingMap MlirReductionFusion::ComputeThreadIdToReductionInputIndexing( +IndexingMap MlirRowReductionFusion::ComputeThreadIdToReductionInputIndexing( mlir::MLIRContext* ctx) const { auto rank = input_shape_.size(); @@ -266,7 +266,7 @@ HloValueMap MlirReductionFusion::EmitterState::EmitPerThreadReducedElements( KernelFusionInterface::kIndexingMapBlockIdxDims[1]) .upper = owner.reduction_heroes_.size(); tile_indexing.Simplify(); - bool vectorize = owner.tile_sizes_per_thread_.back() > 1; + bool vectorize = owner.vector_size_ > 1; SmallVector iter_arg_inits; const auto& side_outputs = 
owner.side_output_roots_[group_id]; @@ -459,8 +459,7 @@ MlirRowReductionFusion::MlirRowReductionFusion( } int vector_size = - GetVectorSize(analysis, reduction_dimensions_, num_threads_x, - reduction_tiling, /*for_mlir=*/true); + GetVectorSizeForMlir(analysis, reduction_dimensions_, num_threads_x); num_threads_ = absl::InlinedVector{1, num_threads_y, num_threads_x}; @@ -519,6 +518,7 @@ MlirRowReductionFusion::MlirRowReductionFusion( total_num_blocks_ = Product(num_blocks_); total_num_threads_per_block_ = Product(num_threads_); + vector_size_ = tile_sizes_per_thread_.back(); } std::optional @@ -728,49 +728,28 @@ MlirColumnReductionFusion::MlirColumnReductionFusion( const HloFusionAnalysis& analysis) : MlirReductionFusion(analysis) { CHECK(!reduction_dimensions_.is_row_reduction); - auto shape = reduction_dimensions_.dimensions; - Vector3 reduction_tiling = {1, 128, 1}; - int vector_size = GetVectorSize(analysis, reduction_dimensions_, WarpSize(), - reduction_tiling, /*for_mlir=*/true); - // The vector dimension is a loop, i.e. we use a symbol for it. - num_threads_ = absl::InlinedVector{1, WarpSize(), WarpSize(), 1}; - input_shape_ = {shape[0], shape[1], shape[2] / vector_size, vector_size}; - tile_sizes_per_thread_ = {reduction_tiling[0], reduction_tiling[1], - reduction_tiling[2], vector_size}; - // The indexing map simplifier does not currently handle this correctly, - // leading to loop bounds that are too large. - // TODO(jreiffers): Implement tightening of ranges based on constraints - // instead. For example, based on: - // - // s1 in [0, 127] - // d0 floordiv 32 + s1 * 32 in [0, 63] - // - // Tighten the bound of s1 to [0, 1]. 
- for (int i = 0; i < num_threads_.size() - 1; ++i) { - tile_sizes_per_thread_[i] = - std::min(tile_sizes_per_thread_[i], - CeilOfRatio(input_shape_[i], num_threads_[i])); - } - tile_sizes_per_block_.resize(input_shape_.size()); - num_blocks_.resize(input_shape_.size()); - for (int64_t i = 0; i < input_shape_.size(); ++i) { - tile_sizes_per_block_[i] = tile_sizes_per_thread_[i] * num_threads_[i]; - CHECK_NE(tile_sizes_per_block_[i], 0); - num_blocks_[i] = CeilOfRatio(input_shape_[i], tile_sizes_per_block_[i]); - CHECK_NE(num_blocks_[i], 0); - } - - total_num_blocks_ = Product(num_blocks_); - total_num_threads_per_block_ = Product(num_threads_); + input_shape_ = {reduction_dimensions_.dimensions[0], + reduction_dimensions_.dimensions[1], + reduction_dimensions_.dimensions[2]}; + vector_size_ = + GetVectorSizeForMlir(analysis, reduction_dimensions_, WarpSize()); + num_warps_per_column_ = WarpSize(); + total_num_threads_per_block_ = num_warps_per_column_ * WarpSize(); + + int64_t major_kept_dim = + reduction_dimensions_ + .dimensions[ReductionDimensions::kColMajorKeptDimension]; + int64_t minor_kept_dim = + reduction_dimensions_ + .dimensions[ReductionDimensions::kColMinorKeptDimension]; + num_blocks_per_row_ = CeilOfRatio(minor_kept_dim, WarpSize() * vector_size_); + total_num_blocks_ = major_kept_dim * num_blocks_per_row_; } std::optional MlirColumnReductionFusion::ComputeThreadIdToOutputIndexing( int64_t root_index, MLIRContext* ctx) const { - auto block_offsets = GetBlockOffsetsForTiling( - num_blocks_, tile_sizes_per_block_, input_shape_.size(), ctx); - if (!groups_.is_reduction_root[root_index]) { auto map = ComposeIndexingMaps( ComputeThreadIdToReductionInputIndexing(ctx), @@ -779,59 +758,64 @@ MlirColumnReductionFusion::ComputeThreadIdToOutputIndexing( AddGroupIdConstraint(map, root_index, groups_); return map; } + AffineExpr th_x = getAffineDimExpr(0, ctx); + AffineExpr bl_x = getAffineDimExpr(3, ctx); + AffineExpr s_v = getAffineSymbolExpr(0, ctx); + + auto 
reduced_shape = ShapeUtil::DeleteDimension( + ReductionDimensions::kColReducedDimension, + ShapeUtil::MakeShape(PrimitiveType::F32, input_shape_)); + SmallVector results{ + bl_x.floorDiv(num_blocks_per_row_), + ((bl_x % num_blocks_per_row_) * WarpSize() + th_x.floorDiv(WarpSize())) * + vector_size_ + + s_v}; + IndexingMap map{AffineMap::get(6, 1, results, ctx), + DimVarsFromTensorSizes( + {total_num_threads_per_block_, 1, 1, total_num_blocks_, + static_cast(groups_.grouped_roots.size()), 1}), + RangeVarsFromTensorSizes({vector_size_}), + /*rt_vars=*/{}}; + for (auto [result, dim_size] : + llvm::zip(results, reduced_shape.dimensions())) { + map.AddConstraint(result, {0, dim_size - 1}); + } + map.AddConstraint(th_x % WarpSize(), {0, 0}); + AddGroupIdConstraint(map, root_index, groups_); const auto& hero = analysis_.fusion_hero(root_index).instruction(); - - auto thread_ids = - DelinearizeInBoundsIndex(mlir::getAffineDimExpr(0, ctx), num_threads_); - auto physical_shape = ShapeUtil::DeleteDimensions(hero.dimensions(), hero.operand(0)->shape()); - std::vector dimension_ranges{ - {{0, total_num_threads_per_block_ - 1}}, - {}, - {}, - {{0, total_num_blocks_ - 1}}, - {{0, static_cast(groups_.grouped_roots.size() - 1)}}, - {}, - }; - - constexpr int kColMajorKept = ReductionDimensions::kColMajorKeptDimension; - constexpr int kColMinorKept = ReductionDimensions::kColMinorKeptDimension; - constexpr int kColReduced = ReductionDimensions::kColReducedDimension; - - auto map = [&]() { - mlir::SmallVector projected_dims{ - block_offsets.getResult(kColMajorKept), - block_offsets.getResult(kColMinorKept) + thread_ids[kColReduced]}; - std::vector range_vars; - if (thread_ids.size() == 4) { - int vector_size = tile_sizes_per_thread_.back(); - range_vars.push_back({0, vector_size - 1}); - projected_dims.push_back(mlir::getAffineSymbolExpr(0, ctx)); - } - IndexingMap projected_index( - mlir::AffineMap::get(6, range_vars.size(), projected_dims, ctx), - dimension_ranges, range_vars, 
/*rt_vars=*/{}); - - projected_index.AddConstraint( - mlir::getAffineDimExpr( - KernelFusionInterface::kIndexingMapThreadIdxDims[0], ctx) % - WarpSize(), - {0, 0}); - projected_index.AddConstraint( - projected_index.GetAffineMap().getResult(1), - {0, input_shape_[ReductionDimensions::kColMinorKeptDimension] - 1}); - - return ComposeIndexingMaps( - projected_index, - GetBitcastMap( - ShapeUtil::DeleteDimension( - ReductionDimensions::kColReducedDimension, - ShapeUtil::MakeShape(PrimitiveType::F32, input_shape_)), - physical_shape, ctx)); - }(); + return map * GetBitcastMap(reduced_shape, physical_shape, ctx); +} - AddGroupIdConstraint(map, root_index, groups_); +IndexingMap MlirColumnReductionFusion::ComputeThreadIdToReductionInputIndexing( + mlir::MLIRContext* ctx) const { + AffineExpr th_x = getAffineDimExpr(0, ctx); + AffineExpr bl_x = getAffineDimExpr(3, ctx); + AffineExpr s_e = getAffineSymbolExpr(0, ctx); + AffineExpr s_v = getAffineSymbolExpr(1, ctx); + + int64_t num_col_elements_per_thread = + CeilOfRatio(reduction_dimensions_ + .dimensions[ReductionDimensions::kColReducedDimension], + num_warps_per_column_); + SmallVector results{ + bl_x.floorDiv(num_blocks_per_row_), + th_x.floorDiv(WarpSize()) + s_e * num_warps_per_column_, + ((bl_x % num_blocks_per_row_) * WarpSize() + th_x % WarpSize()) * + vector_size_ + + s_v}; + IndexingMap map{ + AffineMap::get(6, 2, results, ctx), + DimVarsFromTensorSizes( + {total_num_threads_per_block_, 1, 1, total_num_blocks_, + static_cast(groups_.grouped_roots.size()), 1}), + RangeVarsFromTensorSizes({num_col_elements_per_thread, vector_size_}), + /*rt_vars=*/{}}; + for (auto [result, dim_size] : + llvm::zip(results, reduction_dimensions_.dimensions)) { + map.AddConstraint(result, {0, dim_size - 1}); + } return map; } @@ -845,23 +829,15 @@ llvm::SmallVector MlirColumnReductionFusion::EmitReduction( Value cst_true = b.create(b.getOneAttr(b.getI1Type())); Value thread_id = state.thread_and_block_ids[0]; - auto thread_indexing = 
- GetBitcastMap({total_num_threads_per_block_}, - ShapeUtil::MakeShapeWithDescendingLayout(U8, num_threads_), - b.getContext()); - auto thread_ids = - mlir_converter::ApplyIndexing(thread_indexing, {thread_id}, {}, b); - Value lane_id = b.create(); Value warp_id = b.create( thread_id, b.create(WarpSize())); // The number of results per thread. - int64_t vector_size = tile_sizes_per_thread_.back(); - Value vector_size_cst = b.create(vector_size); + Value vector_size_cst = b.create(vector_size_); std::vector shared_tile_size{WarpSize(), - WarpSize() * vector_size + 1}; + WarpSize() * vector_size_ + 1}; Value lane_id_times_v = b.create(lane_id, vector_size_cst); Value warp_id_times_v = b.create(warp_id, vector_size_cst); diff --git a/third_party/xla/xla/service/gpu/fusions/reduction_mlir.h b/third_party/xla/xla/service/gpu/fusions/reduction_mlir.h index 49a1fb4fdd486f..39ff6682c801d9 100644 --- a/third_party/xla/xla/service/gpu/fusions/reduction_mlir.h +++ b/third_party/xla/xla/service/gpu/fusions/reduction_mlir.h @@ -86,8 +86,8 @@ class MlirReductionFusion : public MlirFusionEmitterBase { return first_reduce_->operand(0)->shape(); } - IndexingMap ComputeThreadIdToReductionInputIndexing( - mlir::MLIRContext* ctx) const; + virtual IndexingMap ComputeThreadIdToReductionInputIndexing( + mlir::MLIRContext* ctx) const = 0; // The reduction heroes for each reduction group. 
std::vector> reduction_heroes_; @@ -108,6 +108,7 @@ class MlirReductionFusion : public MlirFusionEmitterBase { absl::InlinedVector num_blocks_; int64_t total_num_blocks_; int64_t total_num_threads_per_block_; + int64_t vector_size_ = -1; ReductionDimensions reduction_dimensions_; ReductionGroups groups_; @@ -125,6 +126,8 @@ class MlirRowReductionFusion : public MlirReductionFusion { int GetRowsPerWarp() const; llvm::SmallVector EmitReduction( int group_id, EmitterState& state) const override; + IndexingMap ComputeThreadIdToReductionInputIndexing( + mlir::MLIRContext* ctx) const override; }; class MlirColumnReductionFusion : public MlirReductionFusion { @@ -137,6 +140,10 @@ class MlirColumnReductionFusion : public MlirReductionFusion { protected: llvm::SmallVector EmitReduction( int group_id, EmitterState& state) const override; + IndexingMap ComputeThreadIdToReductionInputIndexing( + mlir::MLIRContext* ctx) const override; + int64_t num_warps_per_column_; + int64_t num_blocks_per_row_; }; std::unique_ptr CreateMlirReductionFusion( diff --git a/third_party/xla/xla/service/gpu/fusions/reduction_mlir_test.cc b/third_party/xla/xla/service/gpu/fusions/reduction_mlir_test.cc index 66612a856fcc2a..321fadd33b5902 100644 --- a/third_party/xla/xla/service/gpu/fusions/reduction_mlir_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/reduction_mlir_test.cc @@ -301,30 +301,6 @@ TEST_F(MlirRowReductionTest, NonPowerOfTwoRowReduction) { EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); } -TEST_F(MlirRowReductionTest, MixedIndexing) { - constexpr auto kHloString = R"( - HloModule module - add { - p0 = f32[] parameter(0) - p1 = f32[] parameter(1) - ROOT add = f32[] add(p0, p1) - } - fusion { - %param_0 = f32[64,128] parameter(0) - %constant_0 = f32[] constant(0) - %reduce.1 = f32[128] reduce(f32[64,128] %param_0, f32[] %constant_0), dimensions={0}, to_apply=%add - %neg = f32[64,128] negate(f32[64,128] %param_0) - %bitcast = f32[8,8,128]{2,1,0} 
bitcast(f32[64,128] %neg) - %reduce.2 = f32[128] reduce(f32[8,8,128]{2,1,0} %bitcast, f32[] %constant_0), dimensions={0,1}, to_apply=%add - ROOT %tuple.12 = (f32[128], f32[128]) tuple(f32[128] %reduce.1, f32[128] %reduce.2) - } - ENTRY entry { - %param_0 = f32[64,128] parameter(0) - ROOT %fusion = (f32[128], f32[128]) fusion(%param_0), kind=kInput, calls=fusion - })"; - EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); -} - TEST_F(MlirRowReductionTest, NonTrivialEpilogue) { constexpr auto kHloString = R"( HloModule module @@ -638,6 +614,48 @@ TEST_F(MlirColumnReductionTest, ColumnReduction) { c = f32[] constant(0) ROOT fusion = f32[13,321] fusion(a, c), kind=kInput, calls=fused_computation })"; + + auto module = ParseAndReturnVerifiedModule(kHloString).value(); + auto* root = module->entry_computation()->root_instruction(); + auto analysis = AnalyzeFusion(*root, device_info_); + MlirColumnReductionFusion fusion(analysis); + EXPECT_THAT( + fusion.ComputeThreadIdToInputIndexing(0, 0, &mlir_context_)->ToString(), + MatchIndexingString(R"( + (d0, d1, d2, d3, d4, d5)[s0, s1] -> ( + d3 floordiv 11, + d0 floordiv 32 + s0 * 32, + (d3 mod 11) * 32 + d0 mod 32 + s1 + ) + domain: + d0 in [0, 1023] + d1 in [0, 0] + d2 in [0, 0] + d3 in [0, 142] + d4 in [0, 0] + d5 in [0, 0] + s0 in [0, 32] + s1 in [0, 0] + (d3 mod 11) * 32 + d0 mod 32 + s1 in [0, 320] + d0 floordiv 32 + s0 * 32 in [0, 1050] + )")); + EXPECT_THAT( + fusion.ComputeThreadIdToOutputIndexing(0, &mlir_context_)->ToString(), + MatchIndexingString(R"( + (d0, d1, d2, d3, d4, d5)[s0] -> ( + d3 floordiv 11, (d3 mod 11) * 32 + d0 floordiv 32 + s0 + ) + domain: + d0 in [0, 1023] + d1 in [0, 0] + d2 in [0, 0] + d3 in [0, 142] + d4 in [0, 0] + d5 in [0, 0] + s0 in [0, 0] + (d3 mod 11) * 32 + d0 floordiv 32 + s0 in [0, 320] + d0 mod 32 in [0, 0] + )")); TF_ASSERT_OK(EmitAndCheckIR(kHloString, R"( // CHECK: xla_gpu.pure_call @Add_add // CHECK: allocate_shared @@ -672,6 +690,30 @@ 
TEST_F(MlirColumnReductionTest, SmallColumnReduction) { EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); } +TEST_F(MlirColumnReductionTest, MixedIndexing) { + constexpr auto kHloString = R"( + HloModule module + add { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + ROOT add = f32[] add(p0, p1) + } + fusion { + %param_0 = f32[64,128] parameter(0) + %constant_0 = f32[] constant(0) + %reduce.1 = f32[128] reduce(f32[64,128] %param_0, f32[] %constant_0), dimensions={0}, to_apply=%add + %neg = f32[64,128] negate(f32[64,128] %param_0) + %bitcast = f32[8,8,128]{2,1,0} bitcast(f32[64,128] %neg) + %reduce.2 = f32[128] reduce(f32[8,8,128]{2,1,0} %bitcast, f32[] %constant_0), dimensions={0,1}, to_apply=%add + ROOT %tuple.12 = (f32[128], f32[128]) tuple(f32[128] %reduce.1, f32[128] %reduce.2) + } + ENTRY entry { + %param_0 = f32[64,128] parameter(0) + ROOT %fusion = (f32[128], f32[128]) fusion(%param_0), kind=kInput, calls=fusion + })"; + EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); +} + TEST_F(MlirColumnReductionTest, ColumnReductionVectorization) { constexpr auto kHloString = R"( HloModule Test, is_scheduled=true @@ -690,6 +732,44 @@ TEST_F(MlirColumnReductionTest, ColumnReductionVectorization) { c = f32[] constant(0) ROOT fusion = f32[16384] fusion(a, c), kind=kInput, calls=fused_computation })"; + auto module = ParseAndReturnVerifiedModule(kHloString).value(); + auto* root = module->entry_computation()->root_instruction(); + auto analysis = AnalyzeFusion(*root, device_info_); + MlirColumnReductionFusion fusion(analysis); + EXPECT_THAT( + fusion.ComputeThreadIdToInputIndexing(0, 0, &mlir_context_)->ToString(), + MatchIndexingString(R"( + (d0, d1, d2, d3, d4, d5)[s0, s1] -> ( + (d3 floordiv 256) * 2048 + d0 floordiv 32 + s0 * 32, + ((d3 mod 256) * 32 + d0 mod 32) * 2 + s1) + domain: + d0 in [0, 1023] + d1 in [0, 0] + d2 in [0, 0] + d3 in [0, 255] + d4 in [0, 0] + d5 in [0, 0] + s0 in [0, 63] + s1 in [0, 1] + ((d3 mod 256) * 32 + 
d0 mod 32) * 2 + s1 in [0, 16383] + d0 floordiv 32 + s0 * 32 in [0, 2047] + )")); + EXPECT_THAT( + fusion.ComputeThreadIdToOutputIndexing(0, &mlir_context_)->ToString(), + MatchIndexingString(R"( + (d0, d1, d2, d3, d4, d5)[s0] -> + ((d3 floordiv 256) * 16384 + ((d3 mod 256) * 32 + d0 floordiv 32) * 2 + s0) + domain: + d0 in [0, 1023] + d1 in [0, 0] + d2 in [0, 0] + d3 in [0, 255] + d4 in [0, 0] + d5 in [0, 0] + s0 in [0, 1] + ((d3 mod 256) * 32 + d0 floordiv 32) * 2 + s0 in [0, 16383] + d0 mod 32 in [0, 0] + )")); TF_ASSERT_OK(EmitAndCheckIR(kHloString, R"( // CHECK: vector<2xf32> )")); @@ -722,9 +802,8 @@ TEST_F(MlirColumnReductionTest, ColumnReductionVectorization_v4) { } TEST_F(MlirColumnReductionTest, ThreadIndexingColumn_v2) { - auto module = ParseAndReturnVerifiedModule( - absl::Substitute(kColumnVectorizationTemplate, "f32")) - .value(); + const auto kHloString = absl::Substitute(kColumnVectorizationTemplate, "f32"); + auto module = ParseAndReturnVerifiedModule(kHloString).value(); auto* root = module->entry_computation()->root_instruction(); auto analysis = AnalyzeFusion(*root, device_info_); @@ -733,10 +812,10 @@ TEST_F(MlirColumnReductionTest, ThreadIndexingColumn_v2) { EXPECT_THAT( fusion.ComputeThreadIdToInputIndexing(0, 0, &mlir_context_)->ToString(), MatchIndexingString(R"( - (d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3] -> ( + (d0, d1, d2, d3, d4, d5)[s0, s1] -> ( d3 floordiv 24, - d0 floordiv 32 + s1 * 32, - ((d3 mod 24) * 32 + d0 mod 32) * 2 + s3 + d0 floordiv 32 + s0 * 32, + ((d3 mod 24) * 32 + d0 mod 32) * 2 + s1 ) domain: d0 in [0, 1023] @@ -745,12 +824,10 @@ TEST_F(MlirColumnReductionTest, ThreadIndexingColumn_v2) { d3 in [0, 4607] d4 in [0, 0] d5 in [0, 0] - s0 in [0, 0] + s0 in [0, 1] s1 in [0, 1] - s2 in [0, 0] - s3 in [0, 1] - (d3 mod 24) * 32 + d0 mod 32 in [0, 767] - d0 floordiv 32 + s1 * 32 in [0, 63] + ((d3 mod 24) * 32 + d0 mod 32) * 2 + s1 in [0, 1535] + d0 floordiv 32 + s0 * 32 in [0, 63] )")); EXPECT_THAT( 
fusion.ComputeThreadIdToOutputIndexing(0, &mlir_context_)->ToString(), @@ -766,15 +843,14 @@ TEST_F(MlirColumnReductionTest, ThreadIndexingColumn_v2) { d4 in [0, 0] d5 in [0, 0] s0 in [0, 1] - (d3 mod 24) * 32 + d0 floordiv 32 in [0, 767] + ((d3 mod 24) * 32 + d0 floordiv 32) * 2 + s0 in [0, 1535] d0 mod 32 in [0, 0] )")); } TEST_F(MlirColumnReductionTest, ThreadIndexingColumn_v4) { - auto module = ParseAndReturnVerifiedModule( - absl::Substitute(kColumnVectorizationTemplate, "f16")) - .value(); + const auto kHloString = absl::Substitute(kColumnVectorizationTemplate, "f16"); + auto module = ParseAndReturnVerifiedModule(kHloString).value(); auto* root = module->entry_computation()->root_instruction(); auto analysis = AnalyzeFusion(*root, device_info_); @@ -784,10 +860,10 @@ TEST_F(MlirColumnReductionTest, ThreadIndexingColumn_v4) { EXPECT_THAT( fusion.ComputeThreadIdToInputIndexing(0, 0, &mlir_context_)->ToString(), MatchIndexingString(R"( - (d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3] -> ( + (d0, d1, d2, d3, d4, d5)[s0, s1] -> ( d3 floordiv 12, - d0 floordiv 32 + s1 * 32, - ((d3 mod 12) * 32 + d0 mod 32) * 4 + s3 + d0 floordiv 32 + s0 * 32, + ((d3 mod 12) * 32 + d0 mod 32) * 4 + s1 ) domain: d0 in [0, 1023] @@ -796,12 +872,10 @@ TEST_F(MlirColumnReductionTest, ThreadIndexingColumn_v4) { d3 in [0, 2303] d4 in [0, 0] d5 in [0, 0] - s0 in [0, 0] - s1 in [0, 1] - s2 in [0, 0] - s3 in [0, 3] - (d3 mod 12) * 32 + d0 mod 32 in [0, 383] - d0 floordiv 32 + s1 * 32 in [0, 63] + s0 in [0, 1] + s1 in [0, 3] + ((d3 mod 12) * 32 + d0 mod 32) * 4 + s1 in [0, 1535] + d0 floordiv 32 + s0 * 32 in [0, 63] )")); EXPECT_THAT( fusion.ComputeThreadIdToOutputIndexing(0, &mlir_context_)->ToString(), @@ -817,7 +891,7 @@ TEST_F(MlirColumnReductionTest, ThreadIndexingColumn_v4) { d4 in [0, 0] d5 in [0, 0] s0 in [0, 3] - (d3 mod 12) * 32 + d0 floordiv 32 in [0, 383] + ((d3 mod 12) * 32 + d0 floordiv 32) * 4 + s0 in [0, 1535] d0 mod 32 in [0, 0] )")); } @@ -836,10 +910,10 @@ 
TEST_F(MlirColumnReductionTest, ThreadIndexingColumn_Complex) { EXPECT_THAT( fusion.ComputeThreadIdToInputIndexing(0, 0, &mlir_context_)->ToString(), MatchIndexingString(R"( - (d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3] -> ( + (d0, d1, d2, d3, d4, d5)[s0, s1] -> ( d3 floordiv 48, - d0 floordiv 32 + s1 * 32, - (d3 mod 48) * 32 + d0 mod 32 + d0 floordiv 32 + s0 * 32, + (d3 mod 48) * 32 + d0 mod 32 + s1 ) domain: d0 in [0, 1023] @@ -848,12 +922,10 @@ TEST_F(MlirColumnReductionTest, ThreadIndexingColumn_Complex) { d3 in [0, 9215] d4 in [0, 0] d5 in [0, 0] - s0 in [0, 0] - s1 in [0, 1] - s2 in [0, 0] - s3 in [0, 0] - (d3 mod 48) * 32 + d0 mod 32 in [0, 1535] - d0 floordiv 32 + s1 * 32 in [0, 63] + s0 in [0, 1] + s1 in [0, 0] + (d3 mod 48) * 32 + d0 mod 32 + s1 in [0, 1535] + d0 floordiv 32 + s0 * 32 in [0, 63] )")); EXPECT_THAT( fusion.ComputeThreadIdToOutputIndexing(0, &mlir_context_)->ToString(), @@ -869,7 +941,7 @@ TEST_F(MlirColumnReductionTest, ThreadIndexingColumn_Complex) { d4 in [0, 0] d5 in [0, 0] s0 in [0, 0] - (d3 mod 48) * 32 + d0 floordiv 32 in [0, 1535] + (d3 mod 48) * 32 + d0 floordiv 32 + s0 in [0, 1535] d0 mod 32 in [0, 0] )")); } From 686e352b520b91b381a7bf810d304cb8bb9f084b Mon Sep 17 00:00:00 2001 From: Henning Becker Date: Tue, 18 Jun 2024 07:54:32 -0700 Subject: [PATCH 16/59] Temporarily disable cudnn algorithm 14 for all shapes This algorithm is responsible for numerical problems in 4+ models from different customers. It's likely that other customers also have issues that they didn't report yet. Let's disable algo id 14 for all shapes for now until the cuDNN team has a chance to look at the issue. 
PiperOrigin-RevId: 644385128 --- third_party/xla/xla/service/gpu/BUILD | 1 + .../xla/service/gpu/conv_algorithm_picker.cc | 57 +++++++++++++++++++ .../service/gpu/conv_algorithm_picker_test.cc | 16 ++++++ 3 files changed, 74 insertions(+) diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index 723d6946d22165..accbe67948df09 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -2034,6 +2034,7 @@ xla_test( ":autotuner_util", ":conv_algorithm_picker", ":gpu_conv_rewriter", + ":stream_executor_util", "//xla:debug_options_flags", "//xla/hlo/ir:hlo", "//xla/service:pattern_matcher", diff --git a/third_party/xla/xla/service/gpu/conv_algorithm_picker.cc b/third_party/xla/xla/service/gpu/conv_algorithm_picker.cc index 80a731bee65761..5fc5746b32564c 100644 --- a/third_party/xla/xla/service/gpu/conv_algorithm_picker.cc +++ b/third_party/xla/xla/service/gpu/conv_algorithm_picker.cc @@ -24,6 +24,7 @@ limitations under the License. #include #include #include +#include #include #include @@ -36,6 +37,7 @@ limitations under the License. 
#include "absl/synchronization/mutex.h" #include "absl/time/time.h" #include "absl/types/span.h" +#include "xla/autotuning.pb.h" #include "xla/debug_options_flags.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -459,6 +461,48 @@ GpuConvAlgorithmPicker::AutotuneRuntimeArguments::FromInstruction( return runtime_arguments; } +struct CudnnVersionRange { + using TupleVersion = std::tuple; + TupleVersion begin; + TupleVersion end; + + bool IsInRange(const CudnnVersion& other) const { + TupleVersion other_version{other.major(), other.minor(), other.patch()}; + return begin <= other_version && other_version < end; + } + + CudnnVersionRange(const CudnnVersion& begin, const CudnnVersion& end) + : begin(begin.major(), begin.minor(), begin.patch()), + end(end.major(), end.minor(), end.patch()) {} + + CudnnVersionRange(const TupleVersion& begin, const TupleVersion& end) + : begin(begin), end(end) {} +}; + +struct ComputeCapabilityRange { + using TupleComputeCapability = std::tuple; + TupleComputeCapability begin; + TupleComputeCapability end; + + bool IsInRange(const ComputeCapability& other) const { + TupleComputeCapability other_cc{other.major(), other.minor()}; + return begin <= other_cc && other_cc < end; + } +}; + +struct DisabledAlgorithm { + CudnnVersionRange cudnn_version_range; + ComputeCapabilityRange compute_capability_range; + int algo_id; +}; + +// TODO(b/343101418): Remove this once the bug is fixed in upstream cudnn and +// once we updated to that cudnn version. 
+static const DisabledAlgorithm kDisabledAlgorithms[] = { + {/*.cudnn_version_range=*/{/*.begin=*/{9, 0, 0}, /*.end=*/{10, 0, 0}}, + /*.compute_capability_range=*/{/*.begin=*/{6, 0}, /*.end=*/{8, 0}}, + /*.algo_id=*/14}}; + // There are three tiers of errors possible here: returning a failed // absl::StatusOr means autotuning fails immediately; returning an // AutotuneResult with a failure code other than DISQUALIFIED means autotuning @@ -494,6 +538,19 @@ absl::StatusOr GpuConvAlgorithmPicker::AutotuneOneConvRunner( ? std::string(instruction_info->GetHlo()) : ""; + for (const auto& disabled_algo : kDisabledAlgorithms) { + if (disabled_algo.cudnn_version_range.IsInRange( + GetCudnnVersion(stream_exec)) && + disabled_algo.compute_capability_range.IsInRange( + GetComputeCapability(stream_exec)) && + disabled_algo.algo_id == alg.algo_id()) { + LOG(INFO) << "Omitted potentially buggy algorithm " << alg.ToString() + << " for conv " << instr_str; + return make_failure(AutotuneResult::DISQUALIFIED, + "Disqualified for being known-buggy."); + } + } + if (absl::c_linear_search(disabled_algos, alg_key)) { LOG(INFO) << "Omitted potentially buggy algorithm " << alg.ToString() << " for conv " << instr_str; diff --git a/third_party/xla/xla/service/gpu/conv_algorithm_picker_test.cc b/third_party/xla/xla/service/gpu/conv_algorithm_picker_test.cc index 0ec6e7013fe7ab..d9a3a691da0565 100644 --- a/third_party/xla/xla/service/gpu/conv_algorithm_picker_test.cc +++ b/third_party/xla/xla/service/gpu/conv_algorithm_picker_test.cc @@ -16,6 +16,7 @@ limitations under the License. #include "xla/service/gpu/conv_algorithm_picker.h" #include +#include #include #include "absl/strings/string_view.h" @@ -23,6 +24,7 @@ limitations under the License. 
#include "xla/hlo/ir/hlo_instruction.h" #include "xla/service/gpu/autotuner_util.h" #include "xla/service/gpu/gpu_conv_rewriter.h" +#include "xla/service/gpu/stream_executor_util.h" #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" #include "xla/service/platform_util.h" @@ -107,6 +109,20 @@ ENTRY main { conv->shape(), GmockMatch(m::Shape().WithSubshape( {1}, m::Shape().WithElementType(U8).WithDims({new_scratch_bytes})))); + + // Algorithm 14 is disabled for cuDNN 9 on V100 + TF_ASSERT_OK_AND_ASSIGN(auto dnn_version, GetDnnVersionInfo(stream_exec)); + if (dnn_version.major_version() >= 9 && dnn_version.major_version() < 10 && + std::holds_alternative(cc) && + std::get(cc).major == 7 && + std::get(cc).minor == 0) { + EXPECT_TRUE(conv->backend_config() + ->has_cudnn_conv_backend_config() && + conv->backend_config() + ->cudnn_conv_backend_config() + .algorithm() + .algo_id() != 14); + } } } // namespace From 5555ec6252a89c5830e1124267dbb565191d6279 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 18 Jun 2024 08:17:33 -0700 Subject: [PATCH 17/59] Internal change only. 
PiperOrigin-RevId: 644392118 --- tensorflow/core/BUILD | 1 + tensorflow/core/framework/BUILD | 3 +++ 2 files changed, 4 insertions(+) diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index c28a4618154302..e8258c39d0a12a 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -134,6 +134,7 @@ package_group( "//perftools/accelerators/xprof/...", "//quality/webanswers/brain/tokenization/custom_tf_ops/kernels/...", "//smartass/brain/server/...", + "//waymo/ml/deploy/benchmark/...", ], ) diff --git a/tensorflow/core/framework/BUILD b/tensorflow/core/framework/BUILD index dc7960cff2459e..257ce3f9dd2bb3 100644 --- a/tensorflow/core/framework/BUILD +++ b/tensorflow/core/framework/BUILD @@ -37,6 +37,7 @@ default_visibility = [ # copybara:uncomment "//learning/brain/experimental/tfrt/native_lowering/graph_executor:__subpackages__", "//tensorflow/cc/saved_model:__subpackages__", #internal starburst clustering service, + "//waymo/ml/deploy/benchmark:__subpackages__", ] package( @@ -686,6 +687,7 @@ cc_library( "//tensorflow/core:__pkg__", "//tensorflow/core/runtime_fallback:__subpackages__", "//tensorflow/core/tfrt/utils:__subpackages__", + "//waymo/ml/deploy/benchmark:__subpackages__", ], deps = [ ":bounds_check", @@ -1700,6 +1702,7 @@ tf_proto_library( "//tensorflow/core:__subpackages__", "//tensorflow/python:__pkg__", "//tensorflow/security/fuzzing:__subpackages__", + "//waymo/ml/deploy/benchmark:__subpackages__", ], ) From 40856334bf02d9cacc466e8b9fb17f275c254103 Mon Sep 17 00:00:00 2001 From: Zixuan Jiang Date: Tue, 18 Jun 2024 08:21:46 -0700 Subject: [PATCH 18/59] Move `InferDotOperandSharding` from `sharding_propagation.cc` to `hlo_sharding_util`. We may also use it in the SPMD partitioner. This CL only changes the location of a util function, without any behavior change.
PiperOrigin-RevId: 644393231 --- third_party/xla/xla/hlo/utils/BUILD | 6 ++ .../xla/xla/hlo/utils/hlo_sharding_util.cc | 88 +++++++++++++++++++ .../xla/xla/hlo/utils/hlo_sharding_util.h | 14 +++ .../xla/hlo/utils/hlo_sharding_util_test.cc | 63 +++++++++++-- .../xla/xla/service/sharding_propagation.cc | 87 ++---------------- 5 files changed, 171 insertions(+), 87 deletions(-) diff --git a/third_party/xla/xla/hlo/utils/BUILD b/third_party/xla/xla/hlo/utils/BUILD index 27fee52a3af361..7c6c4719a5bd7e 100644 --- a/third_party/xla/xla/hlo/utils/BUILD +++ b/third_party/xla/xla/hlo/utils/BUILD @@ -106,6 +106,7 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/hlo/ir:tile_assignment", "//xla/service:call_graph", + "//xla/service:dot_as_convolution_util", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:btree", "@com_google_absl//absl/container:flat_hash_map", @@ -135,9 +136,14 @@ xla_cc_test( "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/ir:tile_assignment", + "//xla/service:dot_as_convolution_util", + "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", "@com_google_absl//absl/log", + "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:statusor", ], ) diff --git a/third_party/xla/xla/hlo/utils/hlo_sharding_util.cc b/third_party/xla/xla/hlo/utils/hlo_sharding_util.cc index a0aae451fc4e27..d51fa004cc4c13 100644 --- a/third_party/xla/xla/hlo/utils/hlo_sharding_util.cc +++ b/third_party/xla/xla/hlo/utils/hlo_sharding_util.cc @@ -52,6 +52,7 @@ limitations under the License. 
#include "xla/map_util.h" #include "xla/protobuf_util.h" #include "xla/service/call_graph.h" +#include "xla/service/dot_as_convolution_util.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/util.h" @@ -3351,5 +3352,92 @@ std::optional ReturnImprovedShardingImpl( return std::nullopt; } +HloSharding InferDotOperandSharding( + const HloInstruction* dot, int64_t operand_index, + const dot_as_convolution_util::DotConvolutionDimsInfo& dnums, + bool consider_other_operand, bool may_combine_partial_sharding) { + CHECK(dot->opcode() == HloOpcode::kDot || + dot->opcode() == HloOpcode::kConvolution); + CHECK(operand_index == 0 || operand_index == 1); + CHECK(dnums.conv_spatial_dims.empty()); + + auto operand = dot->operand(operand_index); + auto other = dot->operand(1 - operand_index); + std::vector output_dims_to_replicate; + std::vector other_operand_dims_to_replicate; + for (const auto& dim : operand_index == 0 ? dnums.rhs_non_contracting_dims + : dnums.lhs_non_contracting_dims) { + output_dims_to_replicate.push_back(dim.output); + other_operand_dims_to_replicate.push_back(operand_index == 0 ? dim.rhs + : dim.lhs); + } + // If this dot is interpreted from a conv, then contracting dims may have + // corresponding spatial dimensions in the output, and this operand's + // non-contracting dims may have corresponding spatial dims in the other + // operand. + for (const auto& dim : dnums.contracting_dims) { + if (dim.output >= 0) { + output_dims_to_replicate.push_back(dim.output); + } + } + for (const auto& dim : operand_index == 0 ? dnums.lhs_non_contracting_dims + : dnums.rhs_non_contracting_dims) { + int64_t other_dim = operand_index == 0 ? 
dim.rhs : dim.lhs; + if (other_dim >= 0) { + other_operand_dims_to_replicate.push_back(other_dim); + } + } + HloSharding output_other_dims_replicated = + hlo_sharding_util::PartiallyReplicateTiledShardingOnDims( + dot->sharding(), output_dims_to_replicate); + + std::vector output_to_operand_dims(dot->shape().rank(), -1); + std::vector operand_to_output_dims(operand->shape().rank(), -1); + for (const auto& dim : dnums.batch_dims) { + output_to_operand_dims[dim.output] = operand_index == 0 ? dim.lhs : dim.rhs; + operand_to_output_dims[operand_index == 0 ? dim.lhs : dim.rhs] = dim.output; + } + for (const auto& dim : operand_index == 0 ? dnums.lhs_non_contracting_dims + : dnums.rhs_non_contracting_dims) { + output_to_operand_dims[dim.output] = operand_index == 0 ? dim.lhs : dim.rhs; + operand_to_output_dims[operand_index == 0 ? dim.lhs : dim.rhs] = dim.output; + } + auto sharding = *hlo_sharding_util::TransposeShardingWithCollapsedDims( + output_other_dims_replicated, output_to_operand_dims, + operand_to_output_dims); + + if (consider_other_operand && + hlo_sharding_util::IsSpatiallyPartitioned(other)) { + auto other_operand_dims_replicated = + hlo_sharding_util::PartiallyReplicateTiledShardingOnDims( + other->sharding(), other_operand_dims_to_replicate); + + std::vector other_to_operand_dims(other->shape().rank(), -1); + std::vector operand_to_other_dims(operand->shape().rank(), -1); + for (const auto& dim : dnums.batch_dims) { + other_to_operand_dims[operand_index == 0 ? dim.rhs : dim.lhs] = + operand_index == 0 ? dim.lhs : dim.rhs; + operand_to_other_dims[operand_index == 0 ? dim.lhs : dim.rhs] = + operand_index == 0 ? dim.rhs : dim.lhs; + } + for (const auto& dim : dnums.contracting_dims) { + other_to_operand_dims[operand_index == 0 ? dim.rhs : dim.lhs] = + operand_index == 0 ? dim.lhs : dim.rhs; + operand_to_other_dims[operand_index == 0 ? dim.lhs : dim.rhs] = + operand_index == 0 ? 
dim.rhs : dim.lhs; + } + HloSharding sharding_from_other = + *hlo_sharding_util::TransposeShardingWithCollapsedDims( + other_operand_dims_replicated, other_to_operand_dims, + operand_to_other_dims); + if (hlo_sharding_util::MergeSharding(sharding, &sharding_from_other, + may_combine_partial_sharding)) { + sharding = std::move(sharding_from_other); + } + } + + return sharding; +} + } // namespace hlo_sharding_util } // namespace xla diff --git a/third_party/xla/xla/hlo/utils/hlo_sharding_util.h b/third_party/xla/xla/hlo/utils/hlo_sharding_util.h index 311b2c9449844d..2b0970dbcecbe4 100644 --- a/third_party/xla/xla/hlo/utils/hlo_sharding_util.h +++ b/third_party/xla/xla/hlo/utils/hlo_sharding_util.h @@ -34,6 +34,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/ir/hlo_sharding.h" #include "xla/service/call_graph.h" +#include "xla/service/dot_as_convolution_util.h" #include "xla/shape.h" #include "xla/util.h" @@ -522,6 +523,19 @@ std::optional ReturnImprovedShardingImpl( const Shape& to_improved_shape, bool may_combine_partial_sharding, bool allow_aggressive_resharding = false); +// Infers the sharding of the operand of a dot operation. +// +// If `operand_index` is 0, the sharding of the LHS is inferred. If it is 1, +// the sharding of the RHS is inferred. +// +// If `consider_other_operand` is true, the sharding of the other operand is +// considered. `may_combine_partial_sharding` is used when considering other +// operand. 
+HloSharding InferDotOperandSharding( + const HloInstruction* dot, int64_t operand_index, + const dot_as_convolution_util::DotConvolutionDimsInfo& dnums, + bool consider_other_operand, bool may_combine_partial_sharding); + } // namespace hlo_sharding_util } // namespace xla diff --git a/third_party/xla/xla/hlo/utils/hlo_sharding_util_test.cc b/third_party/xla/xla/hlo/utils/hlo_sharding_util_test.cc index c9902e6a6ad14c..87ee25f3bdf737 100644 --- a/third_party/xla/xla/hlo/utils/hlo_sharding_util_test.cc +++ b/third_party/xla/xla/hlo/utils/hlo_sharding_util_test.cc @@ -21,17 +21,22 @@ limitations under the License. #include #include +#include #include "absl/log/log.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/array.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_sharding.h" #include "xla/hlo/ir/tile_assignment.h" +#include "xla/service/dot_as_convolution_util.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/test.h" +#include "xla/tests/hlo_test_base.h" #include "xla/util.h" #include "xla/xla_data.pb.h" +#include "tsl/platform/statusor.h" namespace xla { namespace hlo_sharding_util { @@ -554,9 +559,9 @@ TEST(HloShardingUtilTest, GetManualSubgroupSharding_ManualOnly) { // Expect the device groups are: {0, 2} and {1, 3} EXPECT_THAT(group_sharding.device_groups[0], - testing::ElementsAreArray({0, 2})); + ::testing::ElementsAreArray({0, 2})); EXPECT_THAT(group_sharding.device_groups[1], - testing::ElementsAreArray({1, 3})); + ::testing::ElementsAreArray({1, 3})); } TEST(HloShardingUtilTest, GetManualSubgroupSharding_ManualAndReplicted) { @@ -574,9 +579,9 @@ TEST(HloShardingUtilTest, GetManualSubgroupSharding_ManualAndReplicted) { // Expect the device groups are: {0, 2, 4, 6} and {1, 3, 5, 7} EXPECT_THAT(group_sharding.device_groups[0], - testing::ElementsAreArray({0, 2, 4, 6})); + ::testing::ElementsAreArray({0, 2, 4, 6})); EXPECT_THAT(group_sharding.device_groups[1], - 
testing::ElementsAreArray({1, 3, 5, 7})); + ::testing::ElementsAreArray({1, 3, 5, 7})); } TEST(HloShardingUtilTest, GetManualSubgroupSharding_ReplicatedAndManual) { @@ -594,9 +599,9 @@ TEST(HloShardingUtilTest, GetManualSubgroupSharding_ReplicatedAndManual) { // Expect the device groups are: {0, 1, 4, 5} and {2, 3, 6, 7} EXPECT_THAT(group_sharding.device_groups[0], - testing::ElementsAreArray({0, 1, 4, 5})); + ::testing::ElementsAreArray({0, 1, 4, 5})); EXPECT_THAT(group_sharding.device_groups[1], - testing::ElementsAreArray({2, 3, 6, 7})); + ::testing::ElementsAreArray({2, 3, 6, 7})); } TEST(HloShardingUtilTest, UngroupSharding_ManualOnly) { @@ -1021,6 +1026,52 @@ TEST(HloShardingUtilTest, UntileShape) { ShapeUtil::MakeTupleShape({expected_shape_0, expected_shape_1})); } +using HloShardingUtilTestWithHlo = HloTestBase; + +TEST_F(HloShardingUtilTestWithHlo, InferDotOperandShardingTest) { + absl::string_view hlo_string = R"( + HloModule module + + ENTRY %main.7 { + %p0 = bf16[32,64,128,512] parameter(0), sharding={devices=[8,1,1,4]<=[32]} + %p1 = bf16[32,64,256,512] parameter(1), sharding={devices=[1,1,1,2,16]<=[8,2,2]T(1,0,2) last_tile_dim_replicate} + ROOT %dot.3 = bf16[32,64,128,256] dot(%p0, %p1), lhs_batch_dims={0,1}, rhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_contracting_dims={3}, sharding={devices=[2,2,2,2,2]<=[32] last_tile_dim_replicate} + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + const HloInstruction* dot = module->entry_computation()->root_instruction(); + auto dnums = dot_as_convolution_util::ParseDotGeneralFromDot(dot); + + bool consider_other_operand = true; + bool may_combine_partial_sharding = false; + EXPECT_EQ(InferDotOperandSharding(dot, 0, dnums, consider_other_operand, + may_combine_partial_sharding), + HloSharding::PartialTile(TileAssignment({2, 2, 2, 1, 4}))); + EXPECT_EQ(InferDotOperandSharding(dot, 1, dnums, consider_other_operand, + may_combine_partial_sharding), + 
HloSharding::IotaTile({8, 1, 1, 4})); + + consider_other_operand = true; + may_combine_partial_sharding = true; + EXPECT_EQ(InferDotOperandSharding(dot, 0, dnums, consider_other_operand, + may_combine_partial_sharding), + HloSharding::PartialTile(TileAssignment({2, 2, 2, 2, 2}))); + EXPECT_EQ(InferDotOperandSharding(dot, 1, dnums, consider_other_operand, + may_combine_partial_sharding), + HloSharding::IotaTile({8, 1, 1, 4})); + + consider_other_operand = false; + for (bool may_combine_partial_sharding : {false, true}) { + EXPECT_EQ(InferDotOperandSharding(dot, 0, dnums, consider_other_operand, + may_combine_partial_sharding), + HloSharding::PartialTile(TileAssignment({2, 2, 2, 1, 4}))); + EXPECT_EQ(InferDotOperandSharding(dot, 1, dnums, consider_other_operand, + may_combine_partial_sharding), + HloSharding::PartialTile(TileAssignment( + {2, 2, 2, 1, 4}, {2, 2, 2, 2, 2}, {0, 1, 3, 2, 4}))); + } +} + } // namespace } // namespace hlo_sharding_util } // namespace xla diff --git a/third_party/xla/xla/service/sharding_propagation.cc b/third_party/xla/xla/service/sharding_propagation.cc index 5f547c99481120..ad088e6f15e8bb 100644 --- a/third_party/xla/xla/service/sharding_propagation.cc +++ b/third_party/xla/xla/service/sharding_propagation.cc @@ -587,83 +587,6 @@ bool CanPropagateThroughAtAggressiveLevel(const HloInstruction& inst, return true; } -HloSharding InferDotOperandSharding( - const HloInstruction* instruction, - const dot_as_convolution_util::DotConvolutionDimsInfo& dnums, - int64_t operand_index, bool may_combine_partial_sharding) { - auto operand = instruction->operand(operand_index); - auto other = instruction->operand(1 - operand_index); - std::vector output_dims_to_replicate; - std::vector other_operand_dims_to_replicate; - for (const auto& dim : operand_index == 0 ? dnums.rhs_non_contracting_dims - : dnums.lhs_non_contracting_dims) { - output_dims_to_replicate.push_back(dim.output); - other_operand_dims_to_replicate.push_back(operand_index == 0 ? 
dim.rhs - : dim.lhs); - } - // If this dot is interpreted from a conv, then contracting dims may have - // corresponding spatial dimensions in the output, and this operand's - // non-contracting dims may have corresponding spatial dims in the other - // operand. - for (const auto& dim : dnums.contracting_dims) { - if (dim.output >= 0) { - output_dims_to_replicate.push_back(dim.output); - } - } - for (const auto& dim : operand_index == 0 ? dnums.lhs_non_contracting_dims - : dnums.rhs_non_contracting_dims) { - int64_t other_dim = operand_index == 0 ? dim.rhs : dim.lhs; - if (other_dim >= 0) { - other_operand_dims_to_replicate.push_back(other_dim); - } - } - auto output_other_dims_replicated = - hlo_sharding_util::PartiallyReplicateTiledShardingOnDims( - instruction->sharding(), output_dims_to_replicate); - std::vector output_to_operand_dims(instruction->shape().rank(), -1); - std::vector operand_to_output_dims(operand->shape().rank(), -1); - for (const auto& dim : dnums.batch_dims) { - output_to_operand_dims[dim.output] = operand_index == 0 ? dim.lhs : dim.rhs; - operand_to_output_dims[operand_index == 0 ? dim.lhs : dim.rhs] = dim.output; - } - for (const auto& dim : operand_index == 0 ? dnums.lhs_non_contracting_dims - : dnums.rhs_non_contracting_dims) { - output_to_operand_dims[dim.output] = operand_index == 0 ? dim.lhs : dim.rhs; - operand_to_output_dims[operand_index == 0 ? 
dim.lhs : dim.rhs] = dim.output; - } - auto sharding = *hlo_sharding_util::TransposeShardingWithCollapsedDims( - output_other_dims_replicated, output_to_operand_dims, - operand_to_output_dims); - if (hlo_sharding_util::IsSpatiallyPartitioned(other)) { - auto other_operand_dims_replicated = - hlo_sharding_util::PartiallyReplicateTiledShardingOnDims( - other->sharding(), other_operand_dims_to_replicate); - std::vector other_to_operand_dims(other->shape().rank(), -1); - std::vector operand_to_other_dims(operand->shape().rank(), -1); - for (const auto& dim : dnums.batch_dims) { - other_to_operand_dims[operand_index == 0 ? dim.rhs : dim.lhs] = - operand_index == 0 ? dim.lhs : dim.rhs; - operand_to_other_dims[operand_index == 0 ? dim.lhs : dim.rhs] = - operand_index == 0 ? dim.rhs : dim.lhs; - } - for (const auto& dim : dnums.contracting_dims) { - other_to_operand_dims[operand_index == 0 ? dim.rhs : dim.lhs] = - operand_index == 0 ? dim.lhs : dim.rhs; - operand_to_other_dims[operand_index == 0 ? dim.lhs : dim.rhs] = - operand_index == 0 ? dim.rhs : dim.lhs; - } - HloSharding sharding_from_other = - *hlo_sharding_util::TransposeShardingWithCollapsedDims( - other_operand_dims_replicated, other_to_operand_dims, - operand_to_other_dims); - if (hlo_sharding_util::MergeSharding(sharding, &sharding_from_other, - may_combine_partial_sharding)) { - sharding = std::move(sharding_from_other); - } - } - return sharding; -} - // Checks if two HloShardings have the same metadata attached. 
bool SameShardingMetadata(const HloSharding& a, const HloSharding& b) { DCHECK_EQ(a, b); @@ -1730,8 +1653,9 @@ std::optional ShardingPropagation::GetShardingFromUser( auto dot_dims = dot_as_convolution_util::ParseConvolutionDimsInfo(&user); if (dot_dims.conv_spatial_dims.empty()) { int64_t op_idx = user.operand_index(&instruction); - return InferDotOperandSharding(&user, dot_dims, op_idx, - may_combine_partial_sharding); + return hlo_sharding_util::InferDotOperandSharding( + &user, op_idx, dot_dims, /*consider_other_operand=*/true, + may_combine_partial_sharding); } return std::nullopt; } @@ -1867,8 +1791,9 @@ std::optional ShardingPropagation::GetShardingFromUser( case HloOpcode::kDot: { int64_t op_idx = user.operand_index(&instruction); auto dnums = dot_as_convolution_util::ParseDotGeneralFromDot(&user); - return InferDotOperandSharding(&user, dnums, op_idx, - may_combine_partial_sharding); + return hlo_sharding_util::InferDotOperandSharding( + &user, op_idx, dnums, /*consider_other_operand=*/true, + may_combine_partial_sharding); } case HloOpcode::kReduce: { if (instruction.shape().rank() == 0) { From 6e6641a618e737b06adaaad74b0b0bc00651f32f Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 18 Jun 2024 08:35:16 -0700 Subject: [PATCH 19/59] Integrate LLVM at llvm/llvm-project@52d87de7a42d Updates LLVM usage to match [52d87de7a42d](https://github.com/llvm/llvm-project/commit/52d87de7a42d) PiperOrigin-RevId: 644397020 --- third_party/llvm/generated.patch | 501 +++++++------------------------ third_party/llvm/workspace.bzl | 4 +- 2 files changed, 113 insertions(+), 392 deletions(-) diff --git a/third_party/llvm/generated.patch b/third_party/llvm/generated.patch index 1c001c5121cf1a..765078cd0493f5 100644 --- a/third_party/llvm/generated.patch +++ b/third_party/llvm/generated.patch @@ -1,396 +1,117 @@ Auto generated patch. Do not edit or delete it, even if empty. 
-diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/AMDGPU/abs_i16.ll b/llvm/test/CodeGen/AMDGPU/abs_i16.ll ---- a/llvm/test/CodeGen/AMDGPU/abs_i16.ll -+++ b/llvm/test/CodeGen/AMDGPU/abs_i16.ll -@@ -98,10 +98,10 @@ - ; GFX8-LABEL: v_abs_v2i16: - ; GFX8: ; %bb.0: - ; GFX8-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX8-NEXT: v_lshrrev_b32_e32 v1, 16, v0 --; GFX8-NEXT: v_sub_u16_e32 v2, 0, v1 --; GFX8-NEXT: v_max_i16_sdwa v1, v1, v2 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:DWORD -+; GFX8-NEXT: v_mov_b32_e32 v1, 0 -+; GFX8-NEXT: v_sub_u16_sdwa v1, v1, v0 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:WORD_1 - ; GFX8-NEXT: v_sub_u16_e32 v2, 0, v0 -+; GFX8-NEXT: v_max_i16_sdwa v1, v0, v1 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:WORD_1 src1_sel:DWORD - ; GFX8-NEXT: v_max_i16_e32 v0, v0, v2 - ; GFX8-NEXT: v_or_b32_e32 v0, v0, v1 - ; GFX8-NEXT: s_setpc_b64 s[30:31] -@@ -181,12 +181,12 @@ - ; GFX8-LABEL: v_abs_v3i16: - ; GFX8: ; %bb.0: - ; GFX8-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX8-NEXT: v_lshrrev_b32_e32 v2, 16, v0 --; GFX8-NEXT: v_sub_u16_e32 v3, 0, v2 --; GFX8-NEXT: v_max_i16_sdwa v2, v2, v3 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:DWORD -+; GFX8-NEXT: v_mov_b32_e32 v2, 0 - ; GFX8-NEXT: v_sub_u16_e32 v3, 0, v1 -+; GFX8-NEXT: v_sub_u16_sdwa v2, v2, v0 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:WORD_1 - ; GFX8-NEXT: v_max_i16_e32 v1, v1, v3 - ; GFX8-NEXT: v_sub_u16_e32 v3, 0, v0 -+; GFX8-NEXT: v_max_i16_sdwa v2, v0, v2 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:WORD_1 src1_sel:DWORD - ; GFX8-NEXT: v_max_i16_e32 v0, v0, v3 - ; GFX8-NEXT: v_or_b32_e32 v0, v0, v2 - ; GFX8-NEXT: s_setpc_b64 s[30:31] -@@ -286,18 +286,17 @@ - ; GFX8-LABEL: v_abs_v4i16: - ; GFX8: ; %bb.0: - ; GFX8-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX8-NEXT: v_lshrrev_b32_e32 v2, 16, v1 --; GFX8-NEXT: v_sub_u16_e32 v3, 0, v2 --; GFX8-NEXT: v_max_i16_sdwa v2, v2, v3 dst_sel:WORD_1 
dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:DWORD --; GFX8-NEXT: v_lshrrev_b32_e32 v3, 16, v0 --; GFX8-NEXT: v_sub_u16_e32 v4, 0, v3 --; GFX8-NEXT: v_max_i16_sdwa v3, v3, v4 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:DWORD -+; GFX8-NEXT: v_mov_b32_e32 v2, 0 -+; GFX8-NEXT: v_sub_u16_sdwa v3, v2, v1 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:WORD_1 -+; GFX8-NEXT: v_sub_u16_sdwa v2, v2, v0 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:WORD_1 - ; GFX8-NEXT: v_sub_u16_e32 v4, 0, v1 - ; GFX8-NEXT: v_sub_u16_e32 v5, 0, v0 -+; GFX8-NEXT: v_max_i16_sdwa v3, v1, v3 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:WORD_1 src1_sel:DWORD -+; GFX8-NEXT: v_max_i16_sdwa v2, v0, v2 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:WORD_1 src1_sel:DWORD - ; GFX8-NEXT: v_max_i16_e32 v0, v0, v5 - ; GFX8-NEXT: v_max_i16_e32 v1, v1, v4 --; GFX8-NEXT: v_or_b32_e32 v0, v0, v3 --; GFX8-NEXT: v_or_b32_e32 v1, v1, v2 -+; GFX8-NEXT: v_or_b32_e32 v0, v0, v2 -+; GFX8-NEXT: v_or_b32_e32 v1, v1, v3 - ; GFX8-NEXT: s_setpc_b64 s[30:31] +diff -ruN --strip-trailing-cr a/clang/unittests/Lex/HeaderSearchTest.cpp b/clang/unittests/Lex/HeaderSearchTest.cpp +--- a/clang/unittests/Lex/HeaderSearchTest.cpp ++++ b/clang/unittests/Lex/HeaderSearchTest.cpp +@@ -19,6 +19,8 @@ + #include "clang/Serialization/InMemoryModuleCache.h" + #include "llvm/Support/MemoryBuffer.h" + #include "gtest/gtest.h" ++#include ++#include + + namespace clang { + namespace { +@@ -350,8 +352,8 @@ + std::string TextualPath = "/textual.h"; + }; + +- auto ExternalSource = new MockExternalHeaderFileInfoSource(); +- Search.SetExternalSource(ExternalSource); ++ auto ExternalSource = std::make_unique(); ++ Search.SetExternalSource(ExternalSource.get()); + + // Everything should start out external. 
+ auto ModularFE = AddHeader(ExternalSource->ModularPath); +diff -ruN --strip-trailing-cr a/llvm/test/DebugInfo/Generic/sroa-extract-bits.ll b/llvm/test/DebugInfo/Generic/sroa-extract-bits.ll +--- a/llvm/test/DebugInfo/Generic/sroa-extract-bits.ll ++++ b/llvm/test/DebugInfo/Generic/sroa-extract-bits.ll +@@ -10,11 +10,11 @@ + ; CHECK-SAME: i32 [[ARG:%.*]]) { + ; CHECK-NEXT: entry: + ; CHECK-NEXT: [[PTR_SROA_0_0_EXTRACT_TRUNC:%.*]] = trunc i32 [[ARG]] to i8 +-; CHECK-NEXT: tail call void @llvm.dbg.value(metadata i8 [[PTR_SROA_0_0_EXTRACT_TRUNC]], metadata [[META2:![0-9]+]], metadata !DIExpression(DW_OP_LLVM_extract_bits_zext, 0, 8)), !dbg [[DBG7:![0-9]+]] ++; CHECK-NEXT: #dbg_value(i8 [[PTR_SROA_0_0_EXTRACT_TRUNC]], [[META2:![0-9]+]], !DIExpression(DW_OP_LLVM_extract_bits_zext, 0, 8), [[META7:![0-9]+]]) + ; CHECK-NEXT: [[PTR_SROA_2_0_EXTRACT_SHIFT:%.*]] = lshr i32 [[ARG]], 8 + ; CHECK-NEXT: [[PTR_SROA_2_0_EXTRACT_TRUNC:%.*]] = trunc i32 [[PTR_SROA_2_0_EXTRACT_SHIFT]] to i24 +-; CHECK-NEXT: tail call void @llvm.dbg.value(metadata i24 [[PTR_SROA_2_0_EXTRACT_TRUNC]], metadata [[META8:![0-9]+]], metadata !DIExpression(DW_OP_LLVM_extract_bits_zext, 8, 16)), !dbg [[DBG7]] +-; CHECK-NEXT: tail call void @llvm.dbg.value(metadata i24 [[PTR_SROA_2_0_EXTRACT_TRUNC]], metadata [[META9:![0-9]+]], metadata !DIExpression(DW_OP_LLVM_extract_bits_sext, 0, 8)), !dbg [[DBG7]] ++; CHECK-NEXT: #dbg_value(i24 [[PTR_SROA_2_0_EXTRACT_TRUNC]], [[META8:![0-9]+]], !DIExpression(DW_OP_LLVM_extract_bits_zext, 8, 16), [[META7]]) ++; CHECK-NEXT: #dbg_value(i24 [[PTR_SROA_2_0_EXTRACT_TRUNC]], [[META9:![0-9]+]], !DIExpression(DW_OP_LLVM_extract_bits_sext, 0, 8), [[META7]]) + ; CHECK-NEXT: ret i8 [[PTR_SROA_0_0_EXTRACT_TRUNC]] ; - ; GFX9-LABEL: v_abs_v4i16: -@@ -411,24 +410,22 @@ - ; GFX8-LABEL: v_abs_v6i16: - ; GFX8: ; %bb.0: - ; GFX8-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX8-NEXT: v_lshrrev_b32_e32 v3, 16, v2 --; GFX8-NEXT: v_sub_u16_e32 v4, 0, v3 --; GFX8-NEXT: v_max_i16_sdwa v3, v3, 
v4 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:DWORD --; GFX8-NEXT: v_lshrrev_b32_e32 v4, 16, v1 --; GFX8-NEXT: v_sub_u16_e32 v5, 0, v4 --; GFX8-NEXT: v_max_i16_sdwa v4, v4, v5 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:DWORD --; GFX8-NEXT: v_lshrrev_b32_e32 v5, 16, v0 --; GFX8-NEXT: v_sub_u16_e32 v6, 0, v5 --; GFX8-NEXT: v_max_i16_sdwa v5, v5, v6 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:DWORD -+; GFX8-NEXT: v_mov_b32_e32 v3, 0 -+; GFX8-NEXT: v_sub_u16_sdwa v4, v3, v2 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:WORD_1 -+; GFX8-NEXT: v_sub_u16_sdwa v5, v3, v1 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:WORD_1 -+; GFX8-NEXT: v_sub_u16_sdwa v3, v3, v0 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:WORD_1 - ; GFX8-NEXT: v_sub_u16_e32 v6, 0, v2 - ; GFX8-NEXT: v_sub_u16_e32 v7, 0, v1 - ; GFX8-NEXT: v_sub_u16_e32 v8, 0, v0 -+; GFX8-NEXT: v_max_i16_sdwa v4, v2, v4 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:WORD_1 src1_sel:DWORD -+; GFX8-NEXT: v_max_i16_sdwa v5, v1, v5 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:WORD_1 src1_sel:DWORD -+; GFX8-NEXT: v_max_i16_sdwa v3, v0, v3 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:WORD_1 src1_sel:DWORD - ; GFX8-NEXT: v_max_i16_e32 v0, v0, v8 - ; GFX8-NEXT: v_max_i16_e32 v1, v1, v7 - ; GFX8-NEXT: v_max_i16_e32 v2, v2, v6 --; GFX8-NEXT: v_or_b32_e32 v0, v0, v5 --; GFX8-NEXT: v_or_b32_e32 v1, v1, v4 --; GFX8-NEXT: v_or_b32_e32 v2, v2, v3 -+; GFX8-NEXT: v_or_b32_e32 v0, v0, v3 -+; GFX8-NEXT: v_or_b32_e32 v1, v1, v5 -+; GFX8-NEXT: v_or_b32_e32 v2, v2, v4 - ; GFX8-NEXT: s_setpc_b64 s[30:31] + entry: +@@ -33,14 +33,14 @@ + ; CHECK-SAME: i32 [[ARG1:%.*]], i8 [[ARG2:%.*]]) { + ; CHECK-NEXT: entry: + ; CHECK-NEXT: [[PTR_SROA_0_0_EXTRACT_TRUNC:%.*]] = trunc i32 [[ARG1]] to i8 +-; CHECK-NEXT: tail call void @llvm.dbg.value(metadata i8 [[PTR_SROA_0_0_EXTRACT_TRUNC]], metadata [[META2]], metadata !DIExpression(DW_OP_LLVM_extract_bits_zext, 0, 
8)), !dbg [[DBG7]] ++; CHECK-NEXT: #dbg_value(i8 [[PTR_SROA_0_0_EXTRACT_TRUNC]], [[META2]], !DIExpression(DW_OP_LLVM_extract_bits_zext, 0, 8), [[META7]]) + ; CHECK-NEXT: [[PTR_SROA_2_0_EXTRACT_SHIFT:%.*]] = lshr i32 [[ARG1]], 8 + ; CHECK-NEXT: [[PTR_SROA_2_0_EXTRACT_TRUNC:%.*]] = trunc i32 [[PTR_SROA_2_0_EXTRACT_SHIFT]] to i16 +-; CHECK-NEXT: tail call void @llvm.dbg.value(metadata i16 [[PTR_SROA_2_0_EXTRACT_TRUNC]], metadata [[META9]], metadata !DIExpression(DW_OP_LLVM_extract_bits_sext, 0, 16)), !dbg [[DBG7]] ++; CHECK-NEXT: #dbg_value(i16 [[PTR_SROA_2_0_EXTRACT_TRUNC]], [[META9]], !DIExpression(DW_OP_LLVM_extract_bits_sext, 0, 16), [[META7]]) + ; CHECK-NEXT: [[PTR_SROA_21_0_EXTRACT_SHIFT:%.*]] = lshr i32 [[ARG1]], 24 + ; CHECK-NEXT: [[PTR_SROA_21_0_EXTRACT_TRUNC:%.*]] = trunc i32 [[PTR_SROA_21_0_EXTRACT_SHIFT]] to i8 +-; CHECK-NEXT: tail call void @llvm.dbg.value(metadata i8 [[PTR_SROA_21_0_EXTRACT_TRUNC]], metadata [[META8]], metadata !DIExpression(DW_OP_LLVM_extract_bits_zext, 0, 8)), !dbg [[DBG7]] +-; CHECK-NEXT: tail call void @llvm.dbg.value(metadata i8 [[ARG2]], metadata [[META8]], metadata !DIExpression(DW_OP_LLVM_extract_bits_zext, 0, 8)), !dbg [[DBG7]] ++; CHECK-NEXT: #dbg_value(i8 [[PTR_SROA_21_0_EXTRACT_TRUNC]], [[META8]], !DIExpression(DW_OP_LLVM_extract_bits_zext, 0, 8), [[META7]]) ++; CHECK-NEXT: #dbg_value(i8 [[ARG2]], [[META8]], !DIExpression(DW_OP_LLVM_extract_bits_zext, 0, 8), [[META7]]) + ; CHECK-NEXT: ret i8 [[PTR_SROA_0_0_EXTRACT_TRUNC]] ; - ; GFX9-LABEL: v_abs_v6i16: -@@ -572,30 +569,27 @@ - ; GFX8-LABEL: v_abs_v8i16: - ; GFX8: ; %bb.0: - ; GFX8-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX8-NEXT: v_lshrrev_b32_e32 v4, 16, v3 --; GFX8-NEXT: v_sub_u16_e32 v5, 0, v4 --; GFX8-NEXT: v_max_i16_sdwa v4, v4, v5 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:DWORD --; GFX8-NEXT: v_lshrrev_b32_e32 v5, 16, v2 --; GFX8-NEXT: v_sub_u16_e32 v6, 0, v5 --; GFX8-NEXT: v_max_i16_sdwa v5, v5, v6 dst_sel:WORD_1 dst_unused:UNUSED_PAD 
src0_sel:DWORD src1_sel:DWORD --; GFX8-NEXT: v_lshrrev_b32_e32 v6, 16, v1 --; GFX8-NEXT: v_sub_u16_e32 v7, 0, v6 --; GFX8-NEXT: v_max_i16_sdwa v6, v6, v7 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:DWORD --; GFX8-NEXT: v_lshrrev_b32_e32 v7, 16, v0 --; GFX8-NEXT: v_sub_u16_e32 v8, 0, v7 --; GFX8-NEXT: v_max_i16_sdwa v7, v7, v8 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:DWORD -+; GFX8-NEXT: v_mov_b32_e32 v4, 0 -+; GFX8-NEXT: v_sub_u16_sdwa v5, v4, v3 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:WORD_1 -+; GFX8-NEXT: v_sub_u16_sdwa v6, v4, v2 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:WORD_1 -+; GFX8-NEXT: v_sub_u16_sdwa v7, v4, v1 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:WORD_1 -+; GFX8-NEXT: v_sub_u16_sdwa v4, v4, v0 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:WORD_1 - ; GFX8-NEXT: v_sub_u16_e32 v8, 0, v3 - ; GFX8-NEXT: v_sub_u16_e32 v9, 0, v2 - ; GFX8-NEXT: v_sub_u16_e32 v10, 0, v1 - ; GFX8-NEXT: v_sub_u16_e32 v11, 0, v0 -+; GFX8-NEXT: v_max_i16_sdwa v5, v3, v5 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:WORD_1 src1_sel:DWORD -+; GFX8-NEXT: v_max_i16_sdwa v6, v2, v6 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:WORD_1 src1_sel:DWORD -+; GFX8-NEXT: v_max_i16_sdwa v7, v1, v7 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:WORD_1 src1_sel:DWORD -+; GFX8-NEXT: v_max_i16_sdwa v4, v0, v4 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:WORD_1 src1_sel:DWORD - ; GFX8-NEXT: v_max_i16_e32 v0, v0, v11 - ; GFX8-NEXT: v_max_i16_e32 v1, v1, v10 - ; GFX8-NEXT: v_max_i16_e32 v2, v2, v9 - ; GFX8-NEXT: v_max_i16_e32 v3, v3, v8 --; GFX8-NEXT: v_or_b32_e32 v0, v0, v7 --; GFX8-NEXT: v_or_b32_e32 v1, v1, v6 --; GFX8-NEXT: v_or_b32_e32 v2, v2, v5 --; GFX8-NEXT: v_or_b32_e32 v3, v3, v4 -+; GFX8-NEXT: v_or_b32_e32 v0, v0, v4 -+; GFX8-NEXT: v_or_b32_e32 v1, v1, v7 -+; GFX8-NEXT: v_or_b32_e32 v2, v2, v6 -+; GFX8-NEXT: v_or_b32_e32 v3, v3, v5 - ; GFX8-NEXT: s_setpc_b64 s[30:31] + entry: +@@ 
-81,10 +81,10 @@ + ; CHECK-SAME: i32 [[ARG:%.*]]) { + ; CHECK-NEXT: entry: + ; CHECK-NEXT: [[PTR_SROA_0_0_EXTRACT_TRUNC:%.*]] = trunc i32 [[ARG]] to i16 +-; CHECK-NEXT: tail call void @llvm.dbg.value(metadata i16 [[PTR_SROA_0_0_EXTRACT_TRUNC]], metadata [[META2]], metadata !DIExpression(DW_OP_LLVM_extract_bits_zext, 0, 8)), !dbg [[DBG7]] ++; CHECK-NEXT: #dbg_value(i16 [[PTR_SROA_0_0_EXTRACT_TRUNC]], [[META2]], !DIExpression(DW_OP_LLVM_extract_bits_zext, 0, 8), [[META7]]) + ; CHECK-NEXT: [[PTR_SROA_2_0_EXTRACT_SHIFT:%.*]] = lshr i32 [[ARG]], 16 + ; CHECK-NEXT: [[PTR_SROA_2_0_EXTRACT_TRUNC:%.*]] = trunc i32 [[PTR_SROA_2_0_EXTRACT_SHIFT]] to i16 +-; CHECK-NEXT: tail call void @llvm.dbg.value(metadata i16 [[PTR_SROA_2_0_EXTRACT_TRUNC]], metadata [[META8]], metadata !DIExpression(DW_OP_LLVM_extract_bits_zext, 8, 8)), !dbg [[DBG7]] ++; CHECK-NEXT: #dbg_value(i16 [[PTR_SROA_2_0_EXTRACT_TRUNC]], [[META8]], !DIExpression(DW_OP_LLVM_extract_bits_zext, 8, 8), [[META7]]) + ; CHECK-NEXT: ret i16 [[PTR_SROA_0_0_EXTRACT_TRUNC]] ; - ; GFX9-LABEL: v_abs_v8i16: -@@ -820,30 +814,15 @@ - ; GFX8-LABEL: v_abs_v16i16: - ; GFX8: ; %bb.0: - ; GFX8-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX8-NEXT: v_lshrrev_b32_e32 v8, 16, v7 --; GFX8-NEXT: v_sub_u16_e32 v9, 0, v8 --; GFX8-NEXT: v_max_i16_sdwa v8, v8, v9 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:DWORD --; GFX8-NEXT: v_lshrrev_b32_e32 v9, 16, v6 --; GFX8-NEXT: v_sub_u16_e32 v10, 0, v9 --; GFX8-NEXT: v_max_i16_sdwa v9, v9, v10 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:DWORD --; GFX8-NEXT: v_lshrrev_b32_e32 v10, 16, v5 --; GFX8-NEXT: v_sub_u16_e32 v11, 0, v10 --; GFX8-NEXT: v_max_i16_sdwa v10, v10, v11 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:DWORD --; GFX8-NEXT: v_lshrrev_b32_e32 v11, 16, v4 --; GFX8-NEXT: v_sub_u16_e32 v12, 0, v11 --; GFX8-NEXT: v_max_i16_sdwa v11, v11, v12 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:DWORD --; GFX8-NEXT: v_lshrrev_b32_e32 
v12, 16, v3 --; GFX8-NEXT: v_sub_u16_e32 v13, 0, v12 --; GFX8-NEXT: v_max_i16_sdwa v12, v12, v13 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:DWORD --; GFX8-NEXT: v_lshrrev_b32_e32 v13, 16, v2 --; GFX8-NEXT: v_sub_u16_e32 v14, 0, v13 --; GFX8-NEXT: v_max_i16_sdwa v13, v13, v14 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:DWORD --; GFX8-NEXT: v_lshrrev_b32_e32 v14, 16, v1 --; GFX8-NEXT: v_sub_u16_e32 v15, 0, v14 --; GFX8-NEXT: v_max_i16_sdwa v14, v14, v15 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:DWORD --; GFX8-NEXT: v_lshrrev_b32_e32 v15, 16, v0 --; GFX8-NEXT: v_sub_u16_e32 v16, 0, v15 --; GFX8-NEXT: v_max_i16_sdwa v15, v15, v16 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:DWORD -+; GFX8-NEXT: v_mov_b32_e32 v8, 0 -+; GFX8-NEXT: v_sub_u16_sdwa v9, v8, v7 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:WORD_1 -+; GFX8-NEXT: v_sub_u16_sdwa v10, v8, v6 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:WORD_1 -+; GFX8-NEXT: v_sub_u16_sdwa v11, v8, v5 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:WORD_1 -+; GFX8-NEXT: v_sub_u16_sdwa v12, v8, v4 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:WORD_1 -+; GFX8-NEXT: v_sub_u16_sdwa v13, v8, v3 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:WORD_1 -+; GFX8-NEXT: v_sub_u16_sdwa v14, v8, v2 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:WORD_1 -+; GFX8-NEXT: v_sub_u16_sdwa v15, v8, v1 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:WORD_1 -+; GFX8-NEXT: v_sub_u16_sdwa v8, v8, v0 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:WORD_1 - ; GFX8-NEXT: v_sub_u16_e32 v16, 0, v7 - ; GFX8-NEXT: v_sub_u16_e32 v17, 0, v6 - ; GFX8-NEXT: v_sub_u16_e32 v18, 0, v5 -@@ -852,6 +831,14 @@ - ; GFX8-NEXT: v_sub_u16_e32 v21, 0, v2 - ; GFX8-NEXT: v_sub_u16_e32 v22, 0, v1 - ; GFX8-NEXT: v_sub_u16_e32 v23, 0, v0 -+; GFX8-NEXT: v_max_i16_sdwa v9, v7, v9 dst_sel:WORD_1 dst_unused:UNUSED_PAD 
src0_sel:WORD_1 src1_sel:DWORD -+; GFX8-NEXT: v_max_i16_sdwa v10, v6, v10 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:WORD_1 src1_sel:DWORD -+; GFX8-NEXT: v_max_i16_sdwa v11, v5, v11 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:WORD_1 src1_sel:DWORD -+; GFX8-NEXT: v_max_i16_sdwa v12, v4, v12 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:WORD_1 src1_sel:DWORD -+; GFX8-NEXT: v_max_i16_sdwa v13, v3, v13 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:WORD_1 src1_sel:DWORD -+; GFX8-NEXT: v_max_i16_sdwa v14, v2, v14 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:WORD_1 src1_sel:DWORD -+; GFX8-NEXT: v_max_i16_sdwa v15, v1, v15 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:WORD_1 src1_sel:DWORD -+; GFX8-NEXT: v_max_i16_sdwa v8, v0, v8 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:WORD_1 src1_sel:DWORD - ; GFX8-NEXT: v_max_i16_e32 v0, v0, v23 - ; GFX8-NEXT: v_max_i16_e32 v1, v1, v22 - ; GFX8-NEXT: v_max_i16_e32 v2, v2, v21 -@@ -860,14 +847,14 @@ - ; GFX8-NEXT: v_max_i16_e32 v5, v5, v18 - ; GFX8-NEXT: v_max_i16_e32 v6, v6, v17 - ; GFX8-NEXT: v_max_i16_e32 v7, v7, v16 --; GFX8-NEXT: v_or_b32_e32 v0, v0, v15 --; GFX8-NEXT: v_or_b32_e32 v1, v1, v14 --; GFX8-NEXT: v_or_b32_e32 v2, v2, v13 --; GFX8-NEXT: v_or_b32_e32 v3, v3, v12 --; GFX8-NEXT: v_or_b32_e32 v4, v4, v11 --; GFX8-NEXT: v_or_b32_e32 v5, v5, v10 --; GFX8-NEXT: v_or_b32_e32 v6, v6, v9 --; GFX8-NEXT: v_or_b32_e32 v7, v7, v8 -+; GFX8-NEXT: v_or_b32_e32 v0, v0, v8 -+; GFX8-NEXT: v_or_b32_e32 v1, v1, v15 -+; GFX8-NEXT: v_or_b32_e32 v2, v2, v14 -+; GFX8-NEXT: v_or_b32_e32 v3, v3, v13 -+; GFX8-NEXT: v_or_b32_e32 v4, v4, v12 -+; GFX8-NEXT: v_or_b32_e32 v5, v5, v11 -+; GFX8-NEXT: v_or_b32_e32 v6, v6, v10 -+; GFX8-NEXT: v_or_b32_e32 v7, v7, v9 - ; GFX8-NEXT: s_setpc_b64 s[30:31] + entry: +@@ -104,11 +104,11 @@ + ; CHECK-SAME: i32 [[ARG:%.*]]) { + ; CHECK-NEXT: entry: + ; CHECK-NEXT: [[PTR_SROA_0_0_EXTRACT_TRUNC:%.*]] = trunc i32 [[ARG]] to i8 +-; CHECK-NEXT: tail call void @llvm.dbg.value(metadata i8 
[[PTR_SROA_0_0_EXTRACT_TRUNC]], metadata [[META11:![0-9]+]], metadata !DIExpression()), !dbg [[DBG7]] ++; CHECK-NEXT: #dbg_value(i8 [[PTR_SROA_0_0_EXTRACT_TRUNC]], [[META11:![0-9]+]], !DIExpression(), [[META7]]) + ; CHECK-NEXT: [[PTR_SROA_2_0_EXTRACT_SHIFT:%.*]] = lshr i32 [[ARG]], 8 + ; CHECK-NEXT: [[PTR_SROA_2_0_EXTRACT_TRUNC:%.*]] = trunc i32 [[PTR_SROA_2_0_EXTRACT_SHIFT]] to i24 +-; CHECK-NEXT: tail call void @llvm.dbg.value(metadata i24 [[PTR_SROA_2_0_EXTRACT_TRUNC]], metadata [[META8]], metadata !DIExpression(DW_OP_LLVM_extract_bits_zext, 8, 8)), !dbg [[DBG7]] +-; CHECK-NEXT: tail call void @llvm.dbg.value(metadata i24 [[PTR_SROA_2_0_EXTRACT_TRUNC]], metadata [[META9]], metadata !DIExpression(DW_OP_LLVM_extract_bits_sext, 0, 8)), !dbg [[DBG7]] ++; CHECK-NEXT: #dbg_value(i24 [[PTR_SROA_2_0_EXTRACT_TRUNC]], [[META8]], !DIExpression(DW_OP_LLVM_extract_bits_zext, 8, 8), [[META7]]) ++; CHECK-NEXT: #dbg_value(i24 [[PTR_SROA_2_0_EXTRACT_TRUNC]], [[META9]], !DIExpression(DW_OP_LLVM_extract_bits_sext, 0, 8), [[META7]]) + ; CHECK-NEXT: ret i8 [[PTR_SROA_0_0_EXTRACT_TRUNC]] ; - ; GFX9-LABEL: v_abs_v16i16: -@@ -1267,102 +1254,87 @@ - ; GFX8-LABEL: v_abs_v32i16: - ; GFX8: ; %bb.0: - ; GFX8-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX8-NEXT: v_lshrrev_b32_e32 v16, 16, v15 --; GFX8-NEXT: v_sub_u16_e32 v17, 0, v16 --; GFX8-NEXT: v_max_i16_sdwa v16, v16, v17 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:DWORD --; GFX8-NEXT: v_lshrrev_b32_e32 v17, 16, v14 --; GFX8-NEXT: v_sub_u16_e32 v18, 0, v17 --; GFX8-NEXT: v_max_i16_sdwa v17, v17, v18 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:DWORD --; GFX8-NEXT: v_lshrrev_b32_e32 v18, 16, v13 --; GFX8-NEXT: v_sub_u16_e32 v19, 0, v18 --; GFX8-NEXT: v_max_i16_sdwa v18, v18, v19 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:DWORD --; GFX8-NEXT: v_lshrrev_b32_e32 v19, 16, v12 --; GFX8-NEXT: v_sub_u16_e32 v20, 0, v19 --; GFX8-NEXT: v_max_i16_sdwa v19, v19, v20 dst_sel:WORD_1 
dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:DWORD --; GFX8-NEXT: v_lshrrev_b32_e32 v20, 16, v11 --; GFX8-NEXT: v_sub_u16_e32 v21, 0, v20 --; GFX8-NEXT: v_max_i16_sdwa v20, v20, v21 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:DWORD --; GFX8-NEXT: v_lshrrev_b32_e32 v21, 16, v10 --; GFX8-NEXT: v_sub_u16_e32 v22, 0, v21 --; GFX8-NEXT: v_max_i16_sdwa v21, v21, v22 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:DWORD --; GFX8-NEXT: v_lshrrev_b32_e32 v22, 16, v9 --; GFX8-NEXT: v_sub_u16_e32 v23, 0, v22 --; GFX8-NEXT: v_max_i16_sdwa v22, v22, v23 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:DWORD --; GFX8-NEXT: v_lshrrev_b32_e32 v23, 16, v8 --; GFX8-NEXT: v_sub_u16_e32 v24, 0, v23 --; GFX8-NEXT: v_max_i16_sdwa v23, v23, v24 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:DWORD --; GFX8-NEXT: v_lshrrev_b32_e32 v24, 16, v7 --; GFX8-NEXT: v_sub_u16_e32 v25, 0, v24 --; GFX8-NEXT: v_max_i16_sdwa v24, v24, v25 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:DWORD --; GFX8-NEXT: v_lshrrev_b32_e32 v25, 16, v6 --; GFX8-NEXT: v_sub_u16_e32 v26, 0, v25 --; GFX8-NEXT: v_max_i16_sdwa v25, v25, v26 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:DWORD --; GFX8-NEXT: v_lshrrev_b32_e32 v26, 16, v5 --; GFX8-NEXT: v_sub_u16_e32 v27, 0, v26 --; GFX8-NEXT: v_max_i16_sdwa v26, v26, v27 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:DWORD --; GFX8-NEXT: v_lshrrev_b32_e32 v27, 16, v4 --; GFX8-NEXT: v_sub_u16_e32 v28, 0, v27 --; GFX8-NEXT: v_max_i16_sdwa v27, v27, v28 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:DWORD --; GFX8-NEXT: v_lshrrev_b32_e32 v28, 16, v3 --; GFX8-NEXT: v_sub_u16_e32 v29, 0, v28 --; GFX8-NEXT: v_max_i16_sdwa v28, v28, v29 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:DWORD --; GFX8-NEXT: v_lshrrev_b32_e32 v29, 16, v2 --; GFX8-NEXT: v_sub_u16_e32 v30, 0, v29 --; GFX8-NEXT: v_max_i16_sdwa v29, v29, v30 dst_sel:WORD_1 dst_unused:UNUSED_PAD 
src0_sel:DWORD src1_sel:DWORD --; GFX8-NEXT: v_lshrrev_b32_e32 v30, 16, v1 --; GFX8-NEXT: v_sub_u16_e32 v31, 0, v30 --; GFX8-NEXT: v_max_i16_sdwa v30, v30, v31 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:DWORD --; GFX8-NEXT: v_lshrrev_b32_e32 v31, 16, v0 --; GFX8-NEXT: v_sub_u16_e32 v32, 0, v31 --; GFX8-NEXT: v_max_i16_sdwa v31, v31, v32 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:DWORD --; GFX8-NEXT: v_sub_u16_e32 v32, 0, v0 --; GFX8-NEXT: v_max_i16_e32 v0, v0, v32 --; GFX8-NEXT: v_or_b32_e32 v0, v0, v31 --; GFX8-NEXT: v_sub_u16_e32 v31, 0, v1 --; GFX8-NEXT: v_max_i16_e32 v1, v1, v31 --; GFX8-NEXT: v_or_b32_e32 v1, v1, v30 --; GFX8-NEXT: v_sub_u16_e32 v30, 0, v2 --; GFX8-NEXT: v_max_i16_e32 v2, v2, v30 --; GFX8-NEXT: v_or_b32_e32 v2, v2, v29 --; GFX8-NEXT: v_sub_u16_e32 v29, 0, v3 --; GFX8-NEXT: v_max_i16_e32 v3, v3, v29 --; GFX8-NEXT: v_or_b32_e32 v3, v3, v28 --; GFX8-NEXT: v_sub_u16_e32 v28, 0, v4 --; GFX8-NEXT: v_max_i16_e32 v4, v4, v28 --; GFX8-NEXT: v_or_b32_e32 v4, v4, v27 --; GFX8-NEXT: v_sub_u16_e32 v27, 0, v5 --; GFX8-NEXT: v_max_i16_e32 v5, v5, v27 --; GFX8-NEXT: v_or_b32_e32 v5, v5, v26 --; GFX8-NEXT: v_sub_u16_e32 v26, 0, v6 --; GFX8-NEXT: v_max_i16_e32 v6, v6, v26 --; GFX8-NEXT: v_or_b32_e32 v6, v6, v25 --; GFX8-NEXT: v_sub_u16_e32 v25, 0, v7 --; GFX8-NEXT: v_max_i16_e32 v7, v7, v25 --; GFX8-NEXT: v_or_b32_e32 v7, v7, v24 --; GFX8-NEXT: v_sub_u16_e32 v24, 0, v8 --; GFX8-NEXT: v_max_i16_e32 v8, v8, v24 --; GFX8-NEXT: v_or_b32_e32 v8, v8, v23 --; GFX8-NEXT: v_sub_u16_e32 v23, 0, v9 --; GFX8-NEXT: v_max_i16_e32 v9, v9, v23 --; GFX8-NEXT: v_or_b32_e32 v9, v9, v22 --; GFX8-NEXT: v_sub_u16_e32 v22, 0, v10 --; GFX8-NEXT: v_max_i16_e32 v10, v10, v22 --; GFX8-NEXT: v_or_b32_e32 v10, v10, v21 --; GFX8-NEXT: v_sub_u16_e32 v21, 0, v11 --; GFX8-NEXT: v_max_i16_e32 v11, v11, v21 -+; GFX8-NEXT: v_mov_b32_e32 v16, 0 -+; GFX8-NEXT: v_sub_u16_sdwa v19, v16, v0 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:WORD_1 -+; 
GFX8-NEXT: v_sub_u16_e32 v20, 0, v0 -+; GFX8-NEXT: v_max_i16_sdwa v19, v0, v19 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:WORD_1 src1_sel:DWORD -+; GFX8-NEXT: v_max_i16_e32 v0, v0, v20 -+; GFX8-NEXT: v_sub_u16_sdwa v20, v16, v1 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:WORD_1 -+; GFX8-NEXT: v_or_b32_e32 v0, v0, v19 -+; GFX8-NEXT: v_sub_u16_e32 v19, 0, v1 -+; GFX8-NEXT: v_max_i16_sdwa v20, v1, v20 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:WORD_1 src1_sel:DWORD -+; GFX8-NEXT: v_max_i16_e32 v1, v1, v19 -+; GFX8-NEXT: v_sub_u16_sdwa v19, v16, v2 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:WORD_1 -+; GFX8-NEXT: v_or_b32_e32 v1, v1, v20 -+; GFX8-NEXT: v_sub_u16_e32 v20, 0, v2 -+; GFX8-NEXT: v_max_i16_sdwa v19, v2, v19 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:WORD_1 src1_sel:DWORD -+; GFX8-NEXT: v_max_i16_e32 v2, v2, v20 -+; GFX8-NEXT: v_sub_u16_sdwa v20, v16, v3 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:WORD_1 -+; GFX8-NEXT: v_or_b32_e32 v2, v2, v19 -+; GFX8-NEXT: v_sub_u16_e32 v19, 0, v3 -+; GFX8-NEXT: v_max_i16_sdwa v20, v3, v20 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:WORD_1 src1_sel:DWORD -+; GFX8-NEXT: v_max_i16_e32 v3, v3, v19 -+; GFX8-NEXT: v_sub_u16_sdwa v19, v16, v4 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:WORD_1 -+; GFX8-NEXT: v_or_b32_e32 v3, v3, v20 -+; GFX8-NEXT: v_sub_u16_e32 v20, 0, v4 -+; GFX8-NEXT: v_max_i16_sdwa v19, v4, v19 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:WORD_1 src1_sel:DWORD -+; GFX8-NEXT: v_max_i16_e32 v4, v4, v20 -+; GFX8-NEXT: v_sub_u16_sdwa v20, v16, v5 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:WORD_1 -+; GFX8-NEXT: v_or_b32_e32 v4, v4, v19 -+; GFX8-NEXT: v_sub_u16_e32 v19, 0, v5 -+; GFX8-NEXT: v_max_i16_sdwa v20, v5, v20 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:WORD_1 src1_sel:DWORD -+; GFX8-NEXT: v_max_i16_e32 v5, v5, v19 -+; GFX8-NEXT: v_sub_u16_sdwa v19, v16, v6 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:DWORD 
src1_sel:WORD_1 -+; GFX8-NEXT: v_or_b32_e32 v5, v5, v20 -+; GFX8-NEXT: v_sub_u16_e32 v20, 0, v6 -+; GFX8-NEXT: v_max_i16_sdwa v19, v6, v19 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:WORD_1 src1_sel:DWORD -+; GFX8-NEXT: v_max_i16_e32 v6, v6, v20 -+; GFX8-NEXT: v_sub_u16_sdwa v20, v16, v7 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:WORD_1 -+; GFX8-NEXT: v_or_b32_e32 v6, v6, v19 -+; GFX8-NEXT: v_sub_u16_e32 v19, 0, v7 -+; GFX8-NEXT: v_max_i16_sdwa v20, v7, v20 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:WORD_1 src1_sel:DWORD -+; GFX8-NEXT: v_max_i16_e32 v7, v7, v19 -+; GFX8-NEXT: v_sub_u16_sdwa v19, v16, v8 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:WORD_1 -+; GFX8-NEXT: v_or_b32_e32 v7, v7, v20 -+; GFX8-NEXT: v_sub_u16_e32 v20, 0, v8 -+; GFX8-NEXT: v_max_i16_sdwa v19, v8, v19 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:WORD_1 src1_sel:DWORD -+; GFX8-NEXT: v_max_i16_e32 v8, v8, v20 -+; GFX8-NEXT: v_sub_u16_sdwa v20, v16, v9 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:WORD_1 -+; GFX8-NEXT: v_or_b32_e32 v8, v8, v19 -+; GFX8-NEXT: v_sub_u16_e32 v19, 0, v9 -+; GFX8-NEXT: v_max_i16_sdwa v20, v9, v20 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:WORD_1 src1_sel:DWORD -+; GFX8-NEXT: v_max_i16_e32 v9, v9, v19 -+; GFX8-NEXT: v_sub_u16_sdwa v19, v16, v10 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:WORD_1 -+; GFX8-NEXT: v_or_b32_e32 v9, v9, v20 -+; GFX8-NEXT: v_sub_u16_e32 v20, 0, v10 -+; GFX8-NEXT: v_max_i16_sdwa v19, v10, v19 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:WORD_1 src1_sel:DWORD -+; GFX8-NEXT: v_max_i16_e32 v10, v10, v20 -+; GFX8-NEXT: v_sub_u16_sdwa v20, v16, v11 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:WORD_1 -+; GFX8-NEXT: v_or_b32_e32 v10, v10, v19 -+; GFX8-NEXT: v_sub_u16_e32 v19, 0, v11 -+; GFX8-NEXT: v_max_i16_sdwa v20, v11, v20 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:WORD_1 src1_sel:DWORD -+; GFX8-NEXT: v_max_i16_e32 v11, v11, v19 -+; GFX8-NEXT: v_sub_u16_sdwa 
v17, v16, v15 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:WORD_1 -+; GFX8-NEXT: v_sub_u16_sdwa v18, v16, v14 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:WORD_1 -+; GFX8-NEXT: v_sub_u16_sdwa v19, v16, v13 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:WORD_1 -+; GFX8-NEXT: v_sub_u16_sdwa v16, v16, v12 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:WORD_1 - ; GFX8-NEXT: v_or_b32_e32 v11, v11, v20 - ; GFX8-NEXT: v_sub_u16_e32 v20, 0, v12 -+; GFX8-NEXT: v_max_i16_sdwa v16, v12, v16 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:WORD_1 src1_sel:DWORD - ; GFX8-NEXT: v_max_i16_e32 v12, v12, v20 --; GFX8-NEXT: v_or_b32_e32 v12, v12, v19 --; GFX8-NEXT: v_sub_u16_e32 v19, 0, v13 -+; GFX8-NEXT: v_or_b32_e32 v12, v12, v16 -+; GFX8-NEXT: v_sub_u16_e32 v16, 0, v13 -+; GFX8-NEXT: v_max_i16_sdwa v19, v13, v19 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:WORD_1 src1_sel:DWORD - ; GFX8-NEXT: v_sub_u16_e32 v20, 0, v15 --; GFX8-NEXT: v_max_i16_e32 v13, v13, v19 --; GFX8-NEXT: v_sub_u16_e32 v19, 0, v14 --; GFX8-NEXT: v_max_i16_e32 v14, v14, v19 -+; GFX8-NEXT: v_max_i16_e32 v13, v13, v16 -+; GFX8-NEXT: v_sub_u16_e32 v16, 0, v14 -+; GFX8-NEXT: v_max_i16_sdwa v17, v15, v17 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:WORD_1 src1_sel:DWORD -+; GFX8-NEXT: v_max_i16_sdwa v18, v14, v18 dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:WORD_1 src1_sel:DWORD -+; GFX8-NEXT: v_max_i16_e32 v14, v14, v16 - ; GFX8-NEXT: v_max_i16_e32 v15, v15, v20 --; GFX8-NEXT: v_or_b32_e32 v13, v13, v18 --; GFX8-NEXT: v_or_b32_e32 v14, v14, v17 --; GFX8-NEXT: v_or_b32_e32 v15, v15, v16 -+; GFX8-NEXT: v_or_b32_e32 v13, v13, v19 -+; GFX8-NEXT: v_or_b32_e32 v14, v14, v18 -+; GFX8-NEXT: v_or_b32_e32 v15, v15, v17 - ; GFX8-NEXT: s_setpc_b64 s[30:31] + entry: +@@ -127,14 +127,14 @@ + ; CHECK-SAME: i32 [[ARG1:%.*]], i8 [[ARG2:%.*]]) { + ; CHECK-NEXT: entry: + ; CHECK-NEXT: [[PTR_SROA_0_0_EXTRACT_TRUNC:%.*]] = trunc i32 [[ARG1]] to i8 +-; CHECK-NEXT: tail call void 
@llvm.dbg.value(metadata i8 undef, metadata [[META2]], metadata !DIExpression(DW_OP_LLVM_extract_bits_sext, 0, 8)), !dbg [[DBG7]] ++; CHECK-NEXT: #dbg_value(i8 undef, [[META2]], !DIExpression(DW_OP_LLVM_extract_bits_sext, 0, 8), [[META7]]) + ; CHECK-NEXT: [[PTR_SROA_2_0_EXTRACT_SHIFT:%.*]] = lshr i32 [[ARG1]], 8 + ; CHECK-NEXT: [[PTR_SROA_2_0_EXTRACT_TRUNC:%.*]] = trunc i32 [[PTR_SROA_2_0_EXTRACT_SHIFT]] to i16 +-; CHECK-NEXT: tail call void @llvm.dbg.value(metadata i16 undef, metadata [[META9]], metadata !DIExpression(DW_OP_LLVM_extract_bits_zext, 0, 16)), !dbg [[DBG7]] ++; CHECK-NEXT: #dbg_value(i16 undef, [[META9]], !DIExpression(DW_OP_LLVM_extract_bits_zext, 0, 16), [[META7]]) + ; CHECK-NEXT: [[PTR_SROA_21_0_EXTRACT_SHIFT:%.*]] = lshr i32 [[ARG1]], 24 + ; CHECK-NEXT: [[PTR_SROA_21_0_EXTRACT_TRUNC:%.*]] = trunc i32 [[PTR_SROA_21_0_EXTRACT_SHIFT]] to i8 +-; CHECK-NEXT: tail call void @llvm.dbg.value(metadata i8 undef, metadata [[META8]], metadata !DIExpression(DW_OP_LLVM_extract_bits_sext, 0, 8)), !dbg [[DBG7]] +-; CHECK-NEXT: tail call void @llvm.dbg.value(metadata i8 undef, metadata [[META8]], metadata !DIExpression(DW_OP_LLVM_extract_bits_sext, 0, 8)), !dbg [[DBG7]] ++; CHECK-NEXT: #dbg_value(i8 undef, [[META8]], !DIExpression(DW_OP_LLVM_extract_bits_sext, 0, 8), [[META7]]) ++; CHECK-NEXT: #dbg_value(i8 undef, [[META8]], !DIExpression(DW_OP_LLVM_extract_bits_sext, 0, 8), [[META7]]) + ; CHECK-NEXT: ret i8 [[PTR_SROA_0_0_EXTRACT_TRUNC]] ; - ; GFX9-LABEL: v_abs_v32i16: + entry: +@@ -196,7 +196,7 @@ + ; CHECK: [[META4]] = distinct !DICompileUnit(language: DW_LANG_C_plus_plus_14, file: [[META5:![0-9]+]], isOptimized: false, runtimeVersion: 0, emissionKind: FullDebug) + ; CHECK: [[META5]] = !DIFile(filename: "dbg-bit-piece.cpp", directory: "") + ; CHECK: [[META6]] = !DIBasicType(name: "unsigned int", size: 32, encoding: DW_ATE_unsigned) +-; CHECK: [[DBG7]] = !DILocation(line: 0, scope: [[META3]]) ++; CHECK: [[META7]] = !DILocation(line: 0, scope: [[META3]]) + ; 
CHECK: [[META8]] = !DILocalVariable(name: "z", scope: [[META3]], type: [[META6]]) + ; CHECK: [[META9]] = !DILocalVariable(name: "y", scope: [[META3]], type: [[META10:![0-9]+]]) + ; CHECK: [[META10]] = !DIBasicType(name: "signed int", size: 32, encoding: DW_ATE_signed) diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl index fc1cf70ed11f47..0ef1683dacd8a5 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" - LLVM_COMMIT = "93ffe1792fd9a985b96fee1105b399b5196a15bc" - LLVM_SHA256 = "f92f48acaf9a6543097dc50acb761584ca3da3f2a137d7ad32aa92650b9ea932" + LLVM_COMMIT = "52d87de7a42d608ac1da33795ca0a892f2b53f36" + LLVM_SHA256 = "6c78e61b7b1cef7ca91c9454bc18fb63426040dd06f9639ae9a89758926eec57" tf_http_archive( name = name, From 221220f141dd47a393d8b8dcafd9a3289137fc36 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tam=C3=A1s=20Danyluk?= Date: Tue, 18 Jun 2024 09:08:15 -0700 Subject: [PATCH 20/59] [XLA:GPU] Experimental: Add --xla_gpu_per_fusion_autotune_cache_dir option If the option is set, we will maintain (read/write) a per-fusion autotune cache in the given directory. The directory must exist. Cache invalidation has to be handled by the user (e.g. please use an empty directory if you want to start with an empty cache). XLA version checks must be done by the user (e.g. if you want to cache fusions created with different versions of XLA, please use different directories). (If the using library already has a version handling mechanism, like JAX, then it shouldn't be difficult for them to create separate directories based on that version (and all the parameters which matter to them).) Default: no file based cache. There is minimal support for multiple processes using the same cache - the rename trick is used to avoid writing the same file by multiple processes at the same time or reading incomplete files. 
We use SHA256 hashes in the filenames and assume that no collisions occur. This is a simple implementation to allow people to test it and find good use-cases. If needed we can refine it later. Considered use case: People running [multiple] [similar] models [through JAX]. For example there are 2 similar HLOs that we want to run with JAX (using the same "XLA binary") and it would be nice to reuse the autotune results from the first, if some kernels appear in both. Similarly: Consider the use case of a researcher sitting at a Colab session and making small changes to their model. They should mostly get cache hits! Limitations: It is not recommended to change the cache directory during the run of a process, because then the in-memory and the file based cache can become inconsistent. At least clear the in-memory cache if you change it. When loading results with LoadAutotuneResults[FromFile], they are not written into the cache directory. PiperOrigin-RevId: 644406688 --- third_party/xla/xla/debug_options_flags.cc | 15 ++ third_party/xla/xla/service/gpu/BUILD | 44 ++-- .../xla/xla/service/gpu/autotuner_util.cc | 213 +++++++++++++++--- .../xla/xla/service/gpu/autotuner_util.h | 40 +++- .../xla/service/gpu/autotuner_util_test.cc | 181 ++++++++++++++- .../xla/service/gpu/gemm_fusion_autotuner.cc | 12 +- third_party/xla/xla/xla.proto | 4 +- 7 files changed, 455 insertions(+), 54 deletions(-) diff --git a/third_party/xla/xla/debug_options_flags.cc b/third_party/xla/xla/debug_options_flags.cc index 7bb8e33a62336b..fa4cb08a82a3f7 100644 --- a/third_party/xla/xla/debug_options_flags.cc +++ b/third_party/xla/xla/debug_options_flags.cc @@ -270,6 +270,8 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_gpu_shard_autotuning(false); + opts.set_xla_gpu_per_fusion_autotune_cache_dir(""); + return opts; } @@ -1768,6 +1770,19 @@ void MakeDebugOptionsFlags(std::vector* flag_list, "reused in further compilations; not yet cached kernels are " "compiled as usual and get 
appended to the cache file whenever " "possible.")); + flag_list->push_back(tsl::Flag( + "xla_gpu_per_fusion_autotune_cache_dir", + string_setter_for( + &DebugOptions::set_xla_gpu_per_fusion_autotune_cache_dir), + debug_options->xla_gpu_per_fusion_autotune_cache_dir(), + "Experimental: Maintain a per-fusion autotune cache in the given " + "directory. XLA will try to read existing results when they are needed " + "and write new results when they are determined. The directory must " + "exist. Cache invalidation has to be handled by the user (e.g. please " + "use an empty directory if you want to start with an empty cache). XLA " + "version checks must be done by the user (e.g. if you want to use " + "separate caches for different versions of XLA, please use different " + "directories). Default: no cache.")); } // NOLINT(readability/fn_size) // Allocates flag_values and flag_objects; this function must not be called more diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index accbe67948df09..7b5b1a36d716d8 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -1630,18 +1630,19 @@ cc_library( deps = if_gpu_is_configured([ ":gpu_asm_opts_util", ":stream_executor_util", - "@com_google_absl//absl/base", "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/base", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/hash", "@com_google_absl//absl/log:check", + "@com_google_absl//absl/log", "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", - "//xla/hlo/ir:hlo", - "//xla/service:compilation_environments", - "//xla/stream_executor", - "//xla/stream_executor/gpu:redzone_allocator", + "@llvm-project//llvm:Core", + "@llvm-project//llvm:Support", 
"//xla:autotune_results_proto_cc", "//xla:autotuning_proto_cc", "//xla:shape_util", @@ -1650,16 +1651,19 @@ cc_library( "//xla:types", "//xla:util", "//xla:xla_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/service:compilation_environments", + "//xla/stream_executor:stream_executor_memory_allocator", + "//xla/stream_executor", + "//xla/stream_executor/gpu:redzone_allocator", + "@local_tsl//tsl/platform:base64", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:path", "@local_tsl//tsl/platform:protobuf", "@local_tsl//tsl/platform:statusor", - ]) + [ - "//xla/stream_executor:stream_executor_memory_allocator", - "@com_google_absl//absl/status", - ], + ]), ) # We need a separate target, as runtime executable cannot depend on compilation @@ -5842,23 +5846,31 @@ xla_cc_test( "@com_google_googletest//:gtest", "@com_google_absl//absl/base:log_severity", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/log:scoped_mock_log", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "//xla:autotune_results_proto_cc", + "//xla:autotuning_proto_cc", + "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", - "//xla/stream_executor:platform", + "//xla/hlo/utils:hlo_query", "//xla/stream_executor:platform_manager", + "//xla/stream_executor:platform", "//xla/stream_executor/host:host_platform", "//xla/tests:hlo_test_base", + "//xla/tests:verified_hlo_module", + "//xla/tests:xla_internal_test_main", # Keep outside GPU guard "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:env", - "@local_tsl//tsl/platform:status_matchers", + "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:path", "@local_tsl//tsl/platform:protobuf", - ]) + [ - "//xla/tests:xla_internal_test_main", # Keep outside GPU guard - "@com_google_absl//absl/status", - ], + "@local_tsl//tsl/platform:status_matchers", + 
"@local_tsl//tsl/platform:status", + "@local_tsl//tsl/platform:statusor", + ]), ) cc_library( diff --git a/third_party/xla/xla/service/gpu/autotuner_util.cc b/third_party/xla/xla/service/gpu/autotuner_util.cc index 056cced3f53aaa..56fc9b3b33030d 100644 --- a/third_party/xla/xla/service/gpu/autotuner_util.cc +++ b/third_party/xla/xla/service/gpu/autotuner_util.cc @@ -16,19 +16,26 @@ limitations under the License. #include "xla/service/gpu/autotuner_util.h" #include +#include #include #include +#include #include #include #include "absl/base/const_init.h" #include "absl/base/thread_annotations.h" #include "absl/container/flat_hash_map.h" +#include "absl/log/log.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/match.h" +#include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/SHA256.h" #include "xla/autotune_results.pb.h" #include "xla/autotuning.pb.h" #include "xla/hlo/ir/hlo_clone_context.h" @@ -43,6 +50,7 @@ limitations under the License. #include "xla/stream_executor/gpu/redzone_allocator.h" #include "xla/stream_executor/stream.h" #include "xla/util.h" +#include "tsl/platform/base64.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" @@ -67,8 +75,130 @@ static absl::Mutex autotune_cache_mu(absl::kConstInit); static auto& autotune_cache ABSL_GUARDED_BY(autotune_cache_mu) = *new AutotuneCacheMap(); +absl::StatusOr GetBase64EncodedSha256Hash(absl::string_view s) { + llvm::SHA256 sha256; + sha256.update(llvm::StringRef(s)); + std::array hash = sha256.final(); + // C++ strict aliasing rules allow reinterpret casting to (const) char*. 
+ absl::string_view hash_view(reinterpret_cast(hash.data()), + hash.size()); + std::string base64_encoded_hash; + TF_RETURN_IF_ERROR(tsl::Base64Encode(hash_view, &base64_encoded_hash)); + return base64_encoded_hash; +} + +namespace { + +// Get the path corresponding to the given key. +absl::StatusOr GetCacheFilePath(absl::string_view cache_dir, + const AutotuneCacheKey& key) { + if (cache_dir.empty()) { + return absl::InvalidArgumentError("autotune_cache_dir should not be empty"); + } + + TF_ASSIGN_OR_RETURN(std::string key_hash, + GetBase64EncodedSha256Hash(key.ToString())); + return tsl::io::JoinPath(cache_dir, absl::StrCat(key_hash, ".textproto")); +} + +struct ResultAndInserted { + // The result that ended up in the cache. This is the existing result if + // inserted is false, and the new result if inserted is true. + // + // We return a value, not a pointer, for thread safety reasons. + AutotuneResult result; + // Did we insert the given result into the cache? + bool inserted; +}; + +ResultAndInserted AddResultToInMemoryCache(const AutotuneCacheKey& key, + AutotuneResult result) + ABSL_LOCKS_EXCLUDED(autotune_cache_mu) { + absl::MutexLock lock(&autotune_cache_mu); + auto [it, inserted] = autotune_cache.emplace(key, std::move(result)); + return {it->second, inserted}; +} + +absl::Status AddResultToFileBasedCacheIfEnabled(const AutotuneCacheKey& key, + AutotuneResult result, + std::string_view cache_dir) + ABSL_LOCKS_EXCLUDED(autotune_cache_mu) { + if (cache_dir.empty()) { + return absl::OkStatus(); + } + + TF_ASSIGN_OR_RETURN(const std::string file_path, + GetCacheFilePath(cache_dir, key)); + + VLOG(1) << "Writing autotune result to file: " << file_path; + + std::string result_str; + if (!tsl::protobuf::TextFormat::PrintToString(result, &result_str)) { + return absl::InternalError("Failed to serialize autotune result."); + } + + // Rename trick: Write to a temporary file, then rename it to the final file + // to avoid mingled files when multiple processes are 
writing to the same + // file. Also avoids reading incomplete files. (This may not work on all file + // systems.) + std::string temp_file_path = tsl::io::GetTempFilename(".textproto"); + tsl::Env* default_env = tsl::Env::Default(); + TF_RETURN_IF_ERROR( + tsl::WriteStringToFile(default_env, temp_file_path, result_str)); + return default_env->RenameFile(temp_file_path, file_path); +} + +absl::StatusOr AddResultToCaches(const AutotuneCacheKey& key, + AutotuneResult result, + std::string_view cache_dir) + ABSL_LOCKS_EXCLUDED(autotune_cache_mu) { + ResultAndInserted result_and_inserted = AddResultToInMemoryCache(key, result); + if (result_and_inserted.inserted) { + TF_RETURN_IF_ERROR(AddResultToFileBasedCacheIfEnabled( + key, result_and_inserted.result, cache_dir)); + } + return result_and_inserted; +} + +std::optional TryToFindInInMemoryCache( + const AutotuneCacheKey& key) ABSL_LOCKS_EXCLUDED(autotune_cache_mu) { + absl::MutexLock lock(&autotune_cache_mu); + auto it = autotune_cache.find(key); + if (it == autotune_cache.end()) { + return std::nullopt; + } + return it->second; +} + +absl::StatusOr> +TryToFindInFileBasedCacheIfEnabled(const AutotuneCacheKey& key, + absl::string_view cache_dir) + ABSL_LOCKS_EXCLUDED(autotune_cache_mu) { + if (cache_dir.empty()) { + return std::nullopt; + } + + TF_ASSIGN_OR_RETURN(const std::string file_path, + GetCacheFilePath(cache_dir, key)); + if (!tsl::Env::Default()->FileExists(file_path).ok()) { + VLOG(1) << "Autotune result file not found: " << file_path; + return std::nullopt; + } + + VLOG(1) << "Autotune result file found: " << file_path; + std::string autotune_result_str; + TF_RETURN_IF_ERROR(tsl::ReadFileToString(tsl::Env::Default(), file_path, + &autotune_result_str)); + AutotuneResult result; + if (!tsl::protobuf::TextFormat::ParseFromString(autotune_result_str, + &result)) { + return absl::InvalidArgumentError("Failed to parse autotune result."); + } + return result; +} + // Sort the results so that they're deterministic. 
-static void SortAutotuneResults(AutotuneResults* results) { +void SortAutotuneResults(AutotuneResults* results) { std::sort(results->mutable_results()->pointer_begin(), results->mutable_results()->pointer_end(), [](const auto* a, const auto* b) { @@ -79,6 +209,8 @@ static void SortAutotuneResults(AutotuneResults* results) { }); } +} // namespace + // Serialize `results` to string as a proto. absl::StatusOr AutotuneResultsToString( const AutotuneResults& results, bool as_textproto) { @@ -93,15 +225,16 @@ absl::StatusOr AutotuneResultsToString( return results.SerializeAsString(); } +namespace { // Serialize a single entry to `results`. -static void SerializeAutotuneEntry(AutotuneResults* results, - const AutotuneCacheKey& k, - const AutotuneResult* res) { +void SerializeAutotuneEntry(AutotuneResults* results, const AutotuneCacheKey& k, + const AutotuneResult* res) { auto& entry = *results->add_results(); entry.set_device(std::string(k.GetModelStr())); entry.set_hlo(std::string(k.GetHlo())); *entry.mutable_result() = *res; } +} // namespace /*static*/ absl::Status AutotunerUtil::SerializeAutotuneResults( AutotuneResults* results) { @@ -152,7 +285,8 @@ static void SerializeAutotuneEntry(AutotuneResults* results, return buffer; } -static std::string ToCanonicalString(const HloInstruction* instr) { +namespace { +std::string ToCanonicalString(const HloInstruction* instr) { auto options = HloPrintOptions::Canonical(); if (instr->opcode() != HloOpcode::kFusion) { options.set_print_backend_config(true); @@ -171,21 +305,37 @@ static std::string ToCanonicalString(const HloInstruction* instr) { return instr->called_computations()[0]->ToString(options); } +} // namespace + AutotuneCacheKey::AutotuneCacheKey(absl::string_view model_str, const HloInstruction& instr) : AutotuneCacheKey(model_str, ToCanonicalString(&instr)) {} -static AutotuneResult* TryFindInCache(const AutotuneCacheKey& key) { - absl::MutexLock lock(&autotune_cache_mu); - auto it = autotune_cache.find(key); - if 
(it != autotune_cache.end()) { - // Cache hit. +namespace { +absl::StatusOr> TryFindInCache( + const AutotuneCacheKey& key, absl::string_view cache_dir) + ABSL_LOCKS_EXCLUDED(autotune_cache_mu) { + std::optional opt_result = TryToFindInInMemoryCache(key); + if (opt_result.has_value()) { + if (VLOG_IS_ON(1)) { + LOG(INFO) << "In-memory autotune cache hit"; + } else if (VLOG_IS_ON(2)) { + LOG(INFO) << "In-memory autotune cache hit: key = " << key.ToString(); + } + return opt_result; + } + + TF_ASSIGN_OR_RETURN(opt_result, + TryToFindInFileBasedCacheIfEnabled(key, cache_dir)); + if (opt_result.has_value()) { + AddResultToInMemoryCache(key, opt_result.value()); + if (VLOG_IS_ON(1)) { - LOG(INFO) << "Autotune cache hit"; + LOG(INFO) << "File-based autotune cache hit"; } else if (VLOG_IS_ON(2)) { - LOG(INFO) << "Autotune cache hit: key = " << key.ToString(); + LOG(INFO) << "File-based autotune cache hit: key = " << key.ToString(); } - return &it->second; + return opt_result; } if (VLOG_IS_ON(1)) { @@ -193,31 +343,39 @@ static AutotuneResult* TryFindInCache(const AutotuneCacheKey& key) { } else if (VLOG_IS_ON(2)) { LOG(INFO) << "Autotune cache miss: key = " << key.ToString(); } - return nullptr; + return std::nullopt; } +} // namespace /*static*/ AutotuneCacheKey AutotunerUtil::GetKey( const HloInstruction* instr, const AutotuneConfig& config) { return AutotuneCacheKey(config.GetModelStr(), *instr); } -/*static*/ bool AutotunerUtil::IsInCache(const AutotuneCacheKey& key) { - return TryFindInCache(key) != nullptr; +/*static*/ absl::StatusOr AutotunerUtil::IsInCache( + const AutotuneCacheKey& key, const AutotuneConfig& config) { + TF_ASSIGN_OR_RETURN(std::optional opt_res, + TryFindInCache(key, config.autotune_cache_dir())); + return opt_res.has_value(); } -/*static*/ bool AutotunerUtil::AddResult(const AutotuneCacheKey& key, - AutotuneResult result) { - absl::MutexLock lock(&autotune_cache_mu); - auto [_, inserted] = autotune_cache.emplace(key, std::move(result)); - return 
inserted; +/*static*/ absl::StatusOr AutotunerUtil::AddResult( + const AutotuneCacheKey& key, AutotuneResult result, + const AutotuneConfig& config) { + TF_ASSIGN_OR_RETURN( + ResultAndInserted result_and_inserted, + AddResultToCaches(key, std::move(result), config.autotune_cache_dir())); + return result_and_inserted.inserted; } /*static*/ absl::StatusOr AutotunerUtil::Autotune( const HloInstruction* instr, const AutotuneConfig& config, const AutotuneNoCacheFn& autotune_fn) { const AutotuneCacheKey key = GetKey(instr, config); - if (AutotuneResult* res = TryFindInCache(key)) { - return *res; + TF_ASSIGN_OR_RETURN(std::optional opt_res, + TryFindInCache(key, config.autotune_cache_dir())); + if (opt_res.has_value()) { + return opt_res.value(); } // Cache miss. @@ -230,9 +388,10 @@ static AutotuneResult* TryFindInCache(const AutotuneCacheKey& key) { TF_ASSIGN_OR_RETURN(AutotuneResult autotune_result, autotune_fn()); - absl::MutexLock lock(&autotune_cache_mu); - auto [it, inserted] = autotune_cache.emplace(key, autotune_result); - return it->second; + TF_ASSIGN_OR_RETURN(ResultAndInserted result_and_inserted, + AddResultToCaches(key, std::move(autotune_result), + config.autotune_cache_dir())); + return result_and_inserted.result; } namespace { @@ -275,7 +434,7 @@ bool IsTextProtoPath(absl::string_view file_path) { return AutotuneResultsToString(results, as_textproto); } -/* static */ absl::Status AutotunerUtil::SerializeAutotuneResultsToFile( +/*static*/ absl::Status AutotunerUtil::SerializeAutotuneResultsToFile( const AutotuneResults& results, absl::string_view file_path) { TF_RET_CHECK(!file_path.empty()); TF_RET_CHECK(results.version() > 0) diff --git a/third_party/xla/xla/service/gpu/autotuner_util.h b/third_party/xla/xla/service/gpu/autotuner_util.h index c0604b4f58404d..bba517d16566f0 100644 --- a/third_party/xla/xla/service/gpu/autotuner_util.h +++ b/third_party/xla/xla/service/gpu/autotuner_util.h @@ -107,6 +107,8 @@ class AutotuneConfig { bool 
should_require_complete_aot_autotune_results() const { return require_complete_aot_autotune_results_; } + // Empty string means no cache is used. + const std::string& autotune_cache_dir() const { return autotune_cache_dir_; } AutotuneConfig(const AutotuneConfig& right) : config_(right.config_), @@ -114,7 +116,8 @@ class AutotuneConfig { should_crash_on_check_failure_(right.should_crash_on_check_failure_), exhaustive_tiling_search_(right.exhaustive_tiling_search_), require_complete_aot_autotune_results_( - right.require_complete_aot_autotune_results_) {} + right.require_complete_aot_autotune_results_), + autotune_cache_dir_(right.autotune_cache_dir_) {} AutotuneConfig(const std::variant& config, const DebugOptions& debug_options) @@ -125,7 +128,9 @@ class AutotuneConfig { exhaustive_tiling_search_( debug_options.xla_gpu_exhaustive_tiling_search()), require_complete_aot_autotune_results_( - debug_options.xla_gpu_require_complete_aot_autotune_results()) {} + debug_options.xla_gpu_require_complete_aot_autotune_results()), + autotune_cache_dir_( + debug_options.xla_gpu_per_fusion_autotune_cache_dir()) {} absl::string_view GetModelStr() const { if (auto deviceless_config = std::get_if(&config_)) { @@ -179,6 +184,7 @@ class AutotuneConfig { bool exhaustive_tiling_search_; bool require_complete_aot_autotune_results_; mutable std::unique_ptr allocator_; + std::string autotune_cache_dir_; }; using AutotuneNoCacheFn = std::function()>; @@ -203,14 +209,17 @@ struct AutotunerUtil { // Checks if the key is in the autotune cache. // // Normally, we don't have to use this low level method. - static bool IsInCache(const AutotuneCacheKey& key); + static absl::StatusOr IsInCache(const AutotuneCacheKey& key, + const AutotuneConfig& config); // Adds the result to the autotune cache. // // Returns true if the entry is inserted. // // Normally, we don't have to use this low level method. 
- static bool AddResult(const AutotuneCacheKey& key, AutotuneResult result); + static absl::StatusOr AddResult(const AutotuneCacheKey& key, + AutotuneResult result, + const AutotuneConfig& config); // Creates a RedzoneAllocator from a given config. static absl::StatusOr CreateRedzoneAllocator( @@ -258,10 +267,18 @@ struct AutotunerUtil { static absl::StatusOr SerializeAutotuneResults( bool as_textproto = false); + // Serializes autotune results into the given proto. static absl::Status SerializeAutotuneResults(AutotuneResults* results); + + // Loads autotune results from the given string of bytes. + // + // Warning: The results are only loaded to the in-memory cache. static absl::Status LoadAutotuneResults(absl::string_view data, bool as_textproto = false); + // Loads autotune results from the given proto. + // + // Warning: The results are only loaded to the in-memory cache. static absl::Status LoadAutotuneResults(const AutotuneResults& results); // Serializes autotune results into a file. @@ -281,16 +298,31 @@ struct AutotunerUtil { // If `file_path` ends with ".txt" or ".textproto", then the file is // considered to be in the textproto format, otherwise the binary protobuf // format. + // + // Warning: The results are only loaded to the in-memory cache. static absl::Status LoadAutotuneResultsFromFile(absl::string_view file_path); + // Warning: This only clears the in-memory cache. If you use a file based + // cache you're responsible for clearing the cache directory when you want to. static void ClearAutotuneResults(); + // Warning: This only checks the in-memory cache. If you use a file based + // cache, you're responsible for checking whether the cache directory is + // empty. static bool ResultCacheIsEmpty(); }; absl::StatusOr AutotuneResultsToString( const AutotuneResults& results, bool as_textproto); +// Exposed only for testing. Returns the SHA-256 hash of the input string, +// encoded in base64. 
+// +// SHA-256 was chosen to follow industry best practices and avoid collisions. +// Git is also transitioning to SHA-256. This is probably better than +// tsl::Fingerprint128. +absl::StatusOr GetBase64EncodedSha256Hash(absl::string_view s); + } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/autotuner_util_test.cc b/third_party/xla/xla/service/gpu/autotuner_util_test.cc index 79b382d8455568..d6a00f95e08dcc 100644 --- a/third_party/xla/xla/service/gpu/autotuner_util_test.cc +++ b/third_party/xla/xla/service/gpu/autotuner_util_test.cc @@ -22,25 +22,36 @@ limitations under the License. #include #include #include "absl/container/flat_hash_set.h" +#include "absl/log/check.h" #include "absl/status/status.h" +#include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "xla/autotune_results.pb.h" +#include "xla/autotuning.pb.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/utils/hlo_query.h" #include "xla/stream_executor/platform.h" #include "xla/stream_executor/platform_manager.h" #include "xla/tests/hlo_test_base.h" +#include "xla/xla.pb.h" #include "tsl/lib/core/status_test_util.h" #include "tsl/platform/env.h" -#include "tsl/platform/logging.h" // IWYU pragma: keep +#include "tsl/platform/errors.h" +#include "tsl/platform/logging.h" // IWYU pragma: keep +#include "tsl/platform/path.h" #include "tsl/platform/protobuf.h" // IWYU pragma: keep +#include "tsl/platform/status.h" #include "tsl/platform/status_matchers.h" +#include "tsl/platform/statusor.h" namespace xla { namespace gpu { namespace { +using ::testing::ElementsAre; using ::testing::HasSubstr; using ::testing::IsEmpty; using ::testing::Not; @@ -98,7 +109,7 @@ results { return str; } - std::unique_ptr NewStreamExecutor() { + static std::unique_ptr NewStreamExecutor() { stream_executor::Platform* platform = 
stream_executor::PlatformManager::PlatformWithName("Host").value(); stream_executor::StreamExecutorConfig config(/*ordinal=*/0); @@ -245,6 +256,172 @@ TEST_F(AutotunerUtilTest, OkIfJitAutotuningDisabledButAlreadyLoadedAOT) { }).status()); } +class FileBasedCacheTest : public AutotunerUtilTest { + public: + static std::string ToString(const proto2::Message& message) { + std::string textproto; + CHECK(tsl::protobuf::TextFormat::PrintToString(message, &textproto)); + return textproto; + } + + static std::vector GetFilesInDir( + const absl::string_view cache_dir) { + std::vector files_in_cache; + TF_CHECK_OK(tsl::Env::Default()->GetChildren(std::string(cache_dir), + &files_in_cache)); + return files_in_cache; + } + + static std::string Read(const absl::string_view filepath) { + std::string file_content; + TF_CHECK_OK(tsl::ReadFileToString(tsl::Env::Default(), + std::string(filepath), &file_content)); + return file_content; + } + + static void Write(const absl::string_view filepath, + const absl::string_view content) { + TF_CHECK_OK(tsl::WriteStringToFile(tsl::Env::Default(), + std::string(filepath), content)); + } + + std::unique_ptr executor_ = + NewStreamExecutor(); + std::unique_ptr module_ = + ParseAndReturnVerifiedModule(kHloText).value(); + const HloInstruction* dot_ = hlo_query::GetFirstInstructionWithOpcode( + *module_->entry_computation(), HloOpcode::kDot); + std::string cache_dir_ = [] { + tsl::Env* default_env = tsl::Env::Default(); + std::string cache_dir; + CHECK(default_env->LocalTempFilename(&cache_dir)); + CHECK_OK(default_env->CreateDir(cache_dir)); + return cache_dir; + }(); + AutotuneConfig config_ = AutotuneConfig(DeviceConfig{executor_.get()}, [&] { + DebugOptions options; + options.set_xla_gpu_per_fusion_autotune_cache_dir(cache_dir_); + return options; + }()); + AutotuneCacheKey cache_key_ = AutotunerUtil::GetKey(dot_, config_); + std::string cache_filename_ = [&] { + absl::StatusOr key_hash = + GetBase64EncodedSha256Hash(cache_key_.ToString()); 
+ CHECK_OK(key_hash.status()); + return absl::StrCat(key_hash.value(), ".textproto"); + }(); + std::string cache_file_path_ = tsl::io::JoinPath(cache_dir_, cache_filename_); + const AutotuneResult result1_ = [] { + AutotuneResult result; + result.set_scratch_bytes(1); + return result; + }(); + const AutotuneResult result2_ = [] { + AutotuneResult result; + result.set_scratch_bytes(2); + return result; + }(); +}; + +TEST_F(FileBasedCacheTest, AutotuneWritesResultToTheCacheDir) { + TF_ASSERT_OK_AND_ASSIGN( + AutotuneResult result, + AutotunerUtil::Autotune(dot_, config_, [&] { return result1_; })); + EXPECT_EQ(ToString(result), ToString(result1_)); + + ASSERT_THAT(GetFilesInDir(cache_dir_), ElementsAre(cache_filename_)); + EXPECT_EQ(Read(cache_file_path_), ToString(result1_)); +} + +TEST_F(FileBasedCacheTest, AutotuneReadsResultFromTheCacheDir) { + Write(cache_file_path_, ToString(result1_)); + + bool cache_hit = true; + TF_ASSERT_OK_AND_ASSIGN(AutotuneResult result, + AutotunerUtil::Autotune(dot_, config_, [&] { + cache_hit = false; + return result2_; + })); + + EXPECT_TRUE(cache_hit); + EXPECT_EQ(ToString(result), ToString(result1_)); +} + +TEST_F(FileBasedCacheTest, + RepeatedAutotuneCallsDontReadOrWriteTheCacheFileAgain) { + auto check_autotune_cache_hit = [](const HloInstruction* instr, + const AutotuneConfig& config, + const AutotuneResult& expected_result) { + bool cache_hit = true; + TF_ASSERT_OK_AND_ASSIGN(AutotuneResult result, + AutotunerUtil::Autotune(instr, config, [&] { + cache_hit = false; + AutotuneResult new_result; + new_result.set_scratch_bytes(2); + return new_result; + })); + EXPECT_TRUE(cache_hit); + EXPECT_EQ(ToString(result), ToString(expected_result)); + }; + + Write(cache_file_path_, ToString(result1_)); + check_autotune_cache_hit(dot_, config_, /*expected_result=*/result1_); + + constexpr absl::string_view kPlaceholderContent = "placeholder content"; + Write(cache_file_path_, kPlaceholderContent); + // File was not read again: + 
check_autotune_cache_hit(dot_, config_, /*expected_result=*/result1_); + // File was not written again: + EXPECT_EQ(Read(cache_file_path_), kPlaceholderContent); +} + +TEST_F(FileBasedCacheTest, + IsInCacheReturnsTrueIfTheResultIsInTheFileBasedCache) { + Write(cache_file_path_, ToString(result1_)); + + TF_ASSERT_OK_AND_ASSIGN(bool is_in_cache, + AutotunerUtil::IsInCache(cache_key_, config_)); + + EXPECT_TRUE(is_in_cache); +} + +TEST_F(FileBasedCacheTest, IsInCacheReturnsFalseIfTheResultIsNotInEitherCache) { + TF_ASSERT_OK_AND_ASSIGN(bool is_in_cache, + AutotunerUtil::IsInCache(cache_key_, config_)); + + EXPECT_FALSE(is_in_cache); +} + +TEST_F(FileBasedCacheTest, AddResultAddsTheResultToTheFileBasedCache) { + TF_ASSERT_OK_AND_ASSIGN( + bool added, AutotunerUtil::AddResult(cache_key_, result1_, config_)); + EXPECT_TRUE(added); + + ASSERT_THAT(GetFilesInDir(cache_dir_), ElementsAre(cache_filename_)); + EXPECT_EQ(Read(cache_file_path_), ToString(result1_)); +} + +TEST_F(FileBasedCacheTest, RepeatedAddResultDoesNotWriteTheFileAgain) { + { + TF_ASSERT_OK_AND_ASSIGN( + bool added, AutotunerUtil::AddResult(cache_key_, result1_, config_)); + EXPECT_TRUE(added); + } + ASSERT_THAT(GetFilesInDir(cache_dir_), ElementsAre(cache_filename_)); + EXPECT_EQ(Read(cache_file_path_), ToString(result1_)); + constexpr absl::string_view kPlaceholderContent = "placeholder content"; + Write(cache_file_path_, kPlaceholderContent); + + { + TF_ASSERT_OK_AND_ASSIGN( + bool added, AutotunerUtil::AddResult(cache_key_, result1_, config_)); + EXPECT_FALSE(added); + } + + // File was not written again: + EXPECT_EQ(Read(cache_file_path_), kPlaceholderContent); +} + } // namespace } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/gemm_fusion_autotuner.cc b/third_party/xla/xla/service/gpu/gemm_fusion_autotuner.cc index 5a09b09c2f875a..b11d4e4f2ff5a1 100644 --- a/third_party/xla/xla/service/gpu/gemm_fusion_autotuner.cc +++ 
b/third_party/xla/xla/service/gpu/gemm_fusion_autotuner.cc @@ -247,7 +247,9 @@ class GemmConfigSetCollector : public ConstDfsHloVisitorWithDefault { ++(iterator->second); } - if (AutotunerUtil::IsInCache(key) || handled_fusions_.contains(key)) { + TF_ASSIGN_OR_RETURN(bool is_in_cache, + AutotunerUtil::IsInCache(key, impl_->GetConfig())); + if (is_in_cache || handled_fusions_.contains(key)) { return absl::OkStatus(); } @@ -1097,7 +1099,9 @@ absl::Status GemmFusionAutotunerImpl::Autotune( } const AutotuneCacheKey key = AutotunerUtil::GetKey(fusion, config_); - if (!AutotunerUtil::AddResult(key, std::move(best))) { + TF_ASSIGN_OR_RETURN( + bool added, AutotunerUtil::AddResult(key, std::move(best), config_)); + if (!added) { // In the context of model server, concurrent autotuning is expected and // insertion of identical autotuning keys is accepted. LOG(WARNING) << "AutotunerUtil::AddResult already existed: " @@ -1192,7 +1196,7 @@ absl::StatusOr GemmFusionAutotuner::Run( AutotuneResult res = FromConfig(tilings[0]); *res.mutable_run_time() = tsl::proto_utils::ToDurationProto(absl::ZeroDuration()); - AutotunerUtil::AddResult(key, res); + TF_RETURN_IF_ERROR(AutotunerUtil::AddResult(key, res, config_).status()); } } else if (!debug_options.xla_gpu_override_gemm_autotuner().empty()) { // TODO(gflegar): support overriding with non-Triton configs (cuBLAS, cuDNN) @@ -1207,7 +1211,7 @@ absl::StatusOr GemmFusionAutotuner::Run( *res.mutable_triton() = gemm_key; *res.mutable_run_time() = tsl::proto_utils::ToDurationProto(absl::ZeroDuration()); - AutotunerUtil::AddResult(key, res); + TF_RETURN_IF_ERROR(AutotunerUtil::AddResult(key, res, config_).status()); } } else if (!config_.IsDeviceless()) { TF_ASSIGN_OR_RETURN(std::optional opt_compile_util, diff --git a/third_party/xla/xla/xla.proto b/third_party/xla/xla/xla.proto index 4b150bf7164072..d9e4594d2b4708 100644 --- a/third_party/xla/xla/xla.proto +++ b/third_party/xla/xla/xla.proto @@ -824,7 +824,9 @@ message DebugOptions { // 
unsafe flag. bool xla_gpu_unsafe_pipelined_loop_annotator = 309; - // Next id: 310 + string xla_gpu_per_fusion_autotune_cache_dir = 310; + + // Next id: 311 // Extra options to pass to the compilation backend (e.g. LLVM); specific // interpretation of these values is left to the backend. From 8f9cee48f69e1ecc8798adb0dd89a4aa1c9aa5e2 Mon Sep 17 00:00:00 2001 From: Kevin Gleason Date: Tue, 18 Jun 2024 09:13:17 -0700 Subject: [PATCH 21/59] Adopt ConvertMlirHloToHloModule instead of passing in proto in PJRT PiperOrigin-RevId: 644408173 --- third_party/xla/xla/pjrt/BUILD | 1 + third_party/xla/xla/pjrt/mlir_to_hlo.cc | 15 +++++++++++---- third_party/xla/xla/pjrt/pjrt_c_api_client.cc | 15 +++++---------- 3 files changed, 17 insertions(+), 14 deletions(-) diff --git a/third_party/xla/xla/pjrt/BUILD b/third_party/xla/xla/pjrt/BUILD index 0bb55a7fd0cc31..545bec940f4f95 100644 --- a/third_party/xla/xla/pjrt/BUILD +++ b/third_party/xla/xla/pjrt/BUILD @@ -593,6 +593,7 @@ cc_library( "@llvm-project//mlir:ReconcileUnrealizedCasts", "@llvm-project//mlir:Support", "@llvm-project//mlir:Transforms", + "@local_tsl//tsl/platform:statusor", "@stablehlo//:chlo_ops", "@stablehlo//:register", "@stablehlo//:stablehlo_ops", diff --git a/third_party/xla/xla/pjrt/mlir_to_hlo.cc b/third_party/xla/xla/pjrt/mlir_to_hlo.cc index bc8fd12186e594..afada99b7727aa 100644 --- a/third_party/xla/xla/pjrt/mlir_to_hlo.cc +++ b/third_party/xla/xla/pjrt/mlir_to_hlo.cc @@ -16,6 +16,7 @@ limitations under the License. #include "xla/pjrt/mlir_to_hlo.h" #include +#include #include #include #include @@ -61,6 +62,7 @@ limitations under the License. #include "xla/service/spmd/shardonnay/utils.h" #include "xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.h" #include "xla/util.h" +#include "tsl/platform/statusor.h" namespace xla { @@ -228,7 +230,6 @@ absl::Status MlirToXlaComputation(mlir::ModuleOp module, } } - HloProto proto; // TODO(b/345414638): Delete when we move Shardonnay as the first pass in the // XLA pipeline. 
if (use_tuple_args && GetDebugOptionsFromFlags().xla_use_shardonnay()) { @@ -238,10 +239,16 @@ absl::Status MlirToXlaComputation(mlir::ModuleOp module, mlir::StringAttr::get(context, "t")); use_tuple_args = false; } - TF_RETURN_IF_ERROR( - ConvertMlirHloToHlo(module, &proto, use_tuple_args, return_tuple)); - xla_computation = XlaComputation(std::move(*proto.mutable_hlo_module())); + // create config options use use_tuple_args, return_tuple set: + mlir::MlirToHloConversionOptions options; + options.use_tuple_args = use_tuple_args; + options.return_tuple = return_tuple; + + TF_ASSIGN_OR_RETURN(std::unique_ptr hlo_module, + mlir::ConvertMlirHloToHloModule(module, options)); + + xla_computation = XlaComputation(hlo_module->ToProto()); return absl::OkStatus(); } diff --git a/third_party/xla/xla/pjrt/pjrt_c_api_client.cc b/third_party/xla/xla/pjrt/pjrt_c_api_client.cc index 1160ca0b1ca6ce..5bee5fb9ce7915 100644 --- a/third_party/xla/xla/pjrt/pjrt_c_api_client.cc +++ b/third_party/xla/xla/pjrt/pjrt_c_api_client.cc @@ -1173,17 +1173,12 @@ PjRtCApiExecutable::GetHloModules() const { return xla::Internal("failed to convert to MHLO"); // TODO(jieying): Tuple args should really come from GetCompileOptions (or // equivalent) once implemented. 
- TF_RETURN_IF_ERROR(mlir::ConvertMlirHloToHlo(module.get(), &hlo_proto, - /*use_tuple_args=*/false, - /*return_tuple=*/false)); - xla::DebugOptions debug_options; - TF_ASSIGN_OR_RETURN(xla::HloModuleConfig module_config, - xla::HloModule::CreateModuleConfigFromProto( - hlo_proto.hlo_module(), debug_options)); + mlir::MlirToHloConversionOptions options; + options.return_tuple = false; + TF_ASSIGN_OR_RETURN(std::unique_ptr hlo_module, + mlir::ConvertMlirHloToHloModule(module.get(), options)); + std::vector> out; - TF_ASSIGN_OR_RETURN( - std::unique_ptr hlo_module, - HloModule::CreateFromProto(hlo_proto.hlo_module(), module_config)); out.push_back(std::move(hlo_module)); return out; } From 8c8fb8fdbf3f5107f6646d544e221bae714963ea Mon Sep 17 00:00:00 2001 From: Alexander Belyaev Date: Tue, 18 Jun 2024 09:30:47 -0700 Subject: [PATCH 22/59] [XLA:GPU][MLIR-based emitters] Add more tests for row reduction indexing. PiperOrigin-RevId: 644413897 --- .../gpu/fusions/reduction_mlir_test.cc | 266 ++++++++++++++++++ 1 file changed, 266 insertions(+) diff --git a/third_party/xla/xla/service/gpu/fusions/reduction_mlir_test.cc b/third_party/xla/xla/service/gpu/fusions/reduction_mlir_test.cc index 321fadd33b5902..1bc6d48da5363b 100644 --- a/third_party/xla/xla/service/gpu/fusions/reduction_mlir_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/reduction_mlir_test.cc @@ -230,6 +230,46 @@ TEST_F(MlirRowReductionTest, F64RowReduction) { c = f64[] constant(0) ROOT fusion = f64[100] fusion(a, c), kind=kInput, calls=fused_computation })"; + auto module = ParseAndReturnVerifiedModule(kHloString).value(); + auto* root = module->entry_computation()->root_instruction(); + auto analysis = AnalyzeFusion(*root, device_info_); + MlirRowReductionFusion fusion(analysis); + + EXPECT_THAT( + fusion.ComputeThreadIdToInputIndexing(0, 0, &mlir_context_)->ToString(), + MatchIndexingString(R"( + (d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4] -> ( + d3 * 8 + d0 floordiv 32, + (d0 mod 32 + s2 * 32) * 2 + 
s3) + domain: + d0 in [0, 255] + d1 in [0, 0] + d2 in [0, 0] + d3 in [0, 12] + d4 in [0, 0] + d5 in [0, 0] + s0 in [0, 0] + s1 in [0, 0] + s2 in [0, 1] + s3 in [0, 1] + s4 in [0, 0] + d0 mod 32 + s2 * 32 in [0, 63] + d3 * 8 + d0 floordiv 32 in [0, 99] + )")); + EXPECT_THAT( + fusion.ComputeThreadIdToOutputIndexing(0, &mlir_context_)->ToString(), + MatchIndexingString(R"( + (d0, d1, d2, d3, d4, d5) -> (d3 * 8 + d0 floordiv 32) + domain: + d0 in [0, 255] + d1 in [0, 0] + d2 in [0, 0] + d3 in [0, 12] + d4 in [0, 0] + d5 in [0, 0] + d0 mod 32 in [0, 0] + d3 * 8 + d0 floordiv 32 in [0, 99] + )")); // This reduction is small enough not to require shared memory. TF_ASSERT_OK(EmitAndCheckIR(kHloString, R"( // CHECK-NOT: allocate_shared @@ -256,6 +296,44 @@ TEST_F(MlirRowReductionTest, MultiRowReduction) { c = f32[] constant(0) ROOT fusion = f32[1024] fusion(a, c), kind=kInput, calls=fused_computation })"; + auto module = ParseAndReturnVerifiedModule(kHloString).value(); + auto* root = module->entry_computation()->root_instruction(); + auto analysis = AnalyzeFusion(*root, device_info_); + MlirRowReductionFusion fusion(analysis); + + EXPECT_THAT( + fusion.ComputeThreadIdToInputIndexing(0, 0, &mlir_context_)->ToString(), + MatchIndexingString(R"( + (d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3] -> ( + d3 * 64 + d0 floordiv 4, d0 mod 4) + domain: + d0 in [0, 255] + d1 in [0, 0] + d2 in [0, 0] + d3 in [0, 15] + d4 in [0, 0] + d5 in [0, 0] + s0 in [0, 0] + s1 in [0, 0] + s2 in [0, 0] + s3 in [0, 0] + d0 mod 4 in [0, 3] + d3 * 64 + d0 floordiv 4 in [0, 1023] + )")); + EXPECT_THAT( + fusion.ComputeThreadIdToOutputIndexing(0, &mlir_context_)->ToString(), + MatchIndexingString(R"( + (d0, d1, d2, d3, d4, d5) -> (d3 * 64 + d0 floordiv 4) + domain: + d0 in [0, 255] + d1 in [0, 0] + d2 in [0, 0] + d3 in [0, 15] + d4 in [0, 0] + d5 in [0, 0] + d0 mod 4 in [0, 0] + d3 * 64 + d0 floordiv 4 in [0, 1023] + )")); // Multi-row reductions don't use shared memory. 
TF_ASSERT_OK(EmitAndCheckIR(kHloString, R"( // CHECK: shuffle_reduce {{.*}} to 2 @@ -330,6 +408,42 @@ TEST_F(MlirRowReductionTest, NonTrivialEpilogue) { ROOT %fusion = (f64[], f64[], f64[]) fusion(%p0, %p1), kind=kInput, calls=fusion })"; + auto module = ParseAndReturnVerifiedModule(kHloString).value(); + auto* root = module->entry_computation()->root_instruction(); + auto analysis = AnalyzeFusion(*root, device_info_); + MlirRowReductionFusion fusion(analysis); + + EXPECT_THAT( + fusion.ComputeThreadIdToInputIndexing(0, 0, &mlir_context_)->ToString(), + MatchIndexingString(R"( + (d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3] -> ( + (d0 floordiv 4) * 4 + d0 mod 4) + domain: + d0 in [0, 3] + d1 in [0, 0] + d2 in [0, 0] + d3 in [0, 0] + d4 in [0, 0] + d5 in [0, 0] + s0 in [0, 0] + s1 in [0, 0] + s2 in [0, 0] + s3 in [0, 0] + d0 mod 4 in [0, 3] + )")); + EXPECT_THAT( + fusion.ComputeThreadIdToOutputIndexing(0, &mlir_context_)->ToString(), + MatchIndexingString(R"( + (d0, d1, d2, d3, d4, d5) -> () + domain: + d0 in [0, 3] + d1 in [0, 0] + d2 in [0, 0] + d3 in [0, 0] + d4 in [0, 0] + d5 in [0, 0] + d0 mod 4 in [0, 0] + )")); EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); } @@ -355,6 +469,45 @@ TEST_F(MlirRowReductionTest, SideOutput) { ROOT fusion = (f32[8], f32[8,2048]) fusion(a, c), kind=kInput, calls=fused_computation })"; + auto module = ParseAndReturnVerifiedModule(kHloString).value(); + auto* root = module->entry_computation()->root_instruction(); + auto analysis = AnalyzeFusion(*root, device_info_); + MlirRowReductionFusion fusion(analysis); + + EXPECT_THAT( + fusion.ComputeThreadIdToInputIndexing(0, 0, &mlir_context_)->ToString(), + MatchIndexingString(R"( + (d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4] -> ( + d3 * 2 + d0 floordiv 128, (d0 mod 128 + s2 * 128) * 2 + s3) + domain: + d0 in [0, 255] + d1 in [0, 0] + d2 in [0, 0] + d3 in [0, 3] + d4 in [0, 0] + d5 in [0, 0] + s0 in [0, 0] + s1 in [0, 0] + s2 in [0, 7] + s3 in [0, 1] + s4 in [0, 0] + d0 
mod 128 + s2 * 128 in [0, 1023] + d3 * 2 + d0 floordiv 128 in [0, 7] + )")); + EXPECT_THAT( + fusion.ComputeThreadIdToOutputIndexing(0, &mlir_context_)->ToString(), + MatchIndexingString(R"( + (d0, d1, d2, d3, d4, d5) -> (d3 * 2 + d0 floordiv 128) + domain: + d0 in [0, 255] + d1 in [0, 0] + d2 in [0, 0] + d3 in [0, 3] + d4 in [0, 0] + d5 in [0, 0] + d0 mod 128 in [0, 0] + d3 * 2 + d0 floordiv 128 in [0, 7] + )")); TF_ASSERT_OK(EmitAndCheckIR(kHloString, R"( // CHECK: @fused_computation // CHECK: scf.for @@ -387,6 +540,45 @@ TEST_F(MlirRowReductionTest, UnsignedSideOutput) { ROOT fusion = (u32[8], u32[8,2048]) fusion(a, c), kind=kInput, calls=fused_computation })"; + auto module = ParseAndReturnVerifiedModule(kHloString).value(); + auto* root = module->entry_computation()->root_instruction(); + auto analysis = AnalyzeFusion(*root, device_info_); + MlirRowReductionFusion fusion(analysis); + + EXPECT_THAT( + fusion.ComputeThreadIdToInputIndexing(0, 0, &mlir_context_)->ToString(), + MatchIndexingString(R"( + (d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4] -> ( + d3 * 2 + d0 floordiv 128, (d0 mod 128 + s2 * 128) * 2 + s3) + domain: + d0 in [0, 255] + d1 in [0, 0] + d2 in [0, 0] + d3 in [0, 3] + d4 in [0, 0] + d5 in [0, 0] + s0 in [0, 0] + s1 in [0, 0] + s2 in [0, 7] + s3 in [0, 1] + s4 in [0, 0] + d0 mod 128 + s2 * 128 in [0, 1023] + d3 * 2 + d0 floordiv 128 in [0, 7] + )")); + EXPECT_THAT( + fusion.ComputeThreadIdToOutputIndexing(0, &mlir_context_)->ToString(), + MatchIndexingString(R"( + (d0, d1, d2, d3, d4, d5) -> (d3 * 2 + d0 floordiv 128) + domain: + d0 in [0, 255] + d1 in [0, 0] + d2 in [0, 0] + d3 in [0, 3] + d4 in [0, 0] + d5 in [0, 0] + d0 mod 128 in [0, 0] + d3 * 2 + d0 floordiv 128 in [0, 7] + )")); EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); } @@ -408,7 +600,47 @@ TEST_F(MlirRowReductionTest, BroadcastSideOutput) { %p0 = f32[6,6] parameter(0) ROOT %fusion = (f32[6,6], f32[]) fusion(%p0), kind=kInput, calls=%fusion })"; + auto module = 
ParseAndReturnVerifiedModule(kHloString).value(); + auto* root = module->entry_computation()->root_instruction(); + auto analysis = AnalyzeFusion(*root, device_info_); + MlirRowReductionFusion fusion(analysis); + EXPECT_THAT( + fusion.ComputeThreadIdToInputIndexing(0, 0, &mlir_context_)->ToString(), + MatchIndexingString(R"( + (d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3] -> () + domain: + d0 in [0, 31] + d1 in [0, 0] + d2 in [0, 0] + d3 in [0, 0] + d4 in [0, 0] + d5 in [0, 0] + s0 in [0, 0] + s1 in [0, 0] + s2 in [0, 1] + s3 in [0, 0] + (d0 + s2 * 32) mod 6 in [0, 5] + d0 + s2 * 32 in [0, 35] + )")); + EXPECT_THAT( + fusion.ComputeThreadIdToOutputIndexing(0, &mlir_context_)->ToString(), + MatchIndexingString(R"( + (d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3] -> ( + (d0 + s2 * 32) floordiv 6, (d0 + s2 * 32) mod 6) + domain: + d0 in [0, 31] + d1 in [0, 0] + d2 in [0, 0] + d3 in [0, 0] + d4 in [0, 0] + d5 in [0, 0] + s0 in [0, 0] + s1 in [0, 0] + s2 in [0, 1] + s3 in [0, 0] + d0 + s2 * 32 in [0, 35] + )")); TF_ASSERT_OK(EmitAndCheckIR(kHloString, R"( // CHECK: @fused_computation )")); @@ -443,7 +675,41 @@ TEST_F(MlirRowReductionTest, VariadicMOF) { %p0 = f32[6,6] parameter(0) ROOT %fusion = (f32[], (f32[], f32[]), f32[6,6]) fusion(%p0), kind=kInput, calls=%fusion })"; + auto module = ParseAndReturnVerifiedModule(kHloString).value(); + auto* root = module->entry_computation()->root_instruction(); + auto analysis = AnalyzeFusion(*root, device_info_); + MlirRowReductionFusion fusion(analysis); + EXPECT_THAT( + fusion.ComputeThreadIdToInputIndexing(0, 0, &mlir_context_)->ToString(), + MatchIndexingString(R"( + (d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3] -> ( + (d0 + s2 * 32) floordiv 6, (d0 + s2 * 32) mod 6) + domain: + d0 in [0, 31] + d1 in [0, 0] + d2 in [0, 0] + d3 in [0, 0] + d4 in [0, 0] + d5 in [0, 0] + s0 in [0, 0] + s1 in [0, 0] + s2 in [0, 1] + s3 in [0, 0] + d0 + s2 * 32 in [0, 35] + )")); + EXPECT_THAT( + fusion.ComputeThreadIdToOutputIndexing(0, 
&mlir_context_)->ToString(), + MatchIndexingString(R"( + (d0, d1, d2, d3, d4, d5) -> () + domain: + d0 in [0, 0] + d1 in [0, 0] + d2 in [0, 0] + d3 in [0, 0] + d4 in [0, 0] + d5 in [0, 0] + )")); TF_ASSERT_OK(EmitAndCheckIR(kHloString, R"( // CHECK: @fused_computation )")); From 1ae6cf22040fe8b709d2e796d61fd4722fc235a9 Mon Sep 17 00:00:00 2001 From: Frederik Gossen Date: Tue, 18 Jun 2024 09:57:01 -0700 Subject: [PATCH 23/59] Re-enable tensorflow/compiler/tests:async_comp_test_gpu test PiperOrigin-RevId: 644422132 --- tensorflow/compiler/tests/async_comp_test.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tensorflow/compiler/tests/async_comp_test.py b/tensorflow/compiler/tests/async_comp_test.py index 0dd3eec2e198d3..7d811cf44b6041 100644 --- a/tensorflow/compiler/tests/async_comp_test.py +++ b/tensorflow/compiler/tests/async_comp_test.py @@ -15,7 +15,6 @@ """Tests for asynchronous compilation on the CPU and GPU devices.""" import os -import unittest from tensorflow.core.protobuf import config_pb2 from tensorflow.python.client import session as session_lib @@ -52,7 +51,6 @@ class AsyncCompilationTest(test.TestCase): # Asynchrobnous compilation uses the existing fallback path and existing # compiler. This test only tests that asynchronous compilation is performed. - @unittest.skip("b/263146341 - flaky Kokoro build.") def testAsyncCompilationJit(self): @function.Defun(compiled=True) From bcbab5246a2d4bc9d4145d752e756f1226c32a0d Mon Sep 17 00:00:00 2001 From: Oleg Shyshkov Date: Tue, 18 Jun 2024 11:13:43 -0700 Subject: [PATCH 24/59] [XLA:GPU] Replace TritonFusionAnalysis with SymbolicTileAnalysis in PriorityFusion. TritonFusionAnalysis is deprecated and we should be using SymbolicTileAnalysis. This is preparation to start doing Triton fusions with the new tiling infrastructure. This is a safe change, because the feature is still behind a flag. 
PiperOrigin-RevId: 644451491 --- third_party/xla/xla/service/gpu/BUILD | 2 ++ .../xla/xla/service/gpu/priority_fusion.cc | 28 +++++++++++++------ .../xla/xla/service/gpu/priority_fusion.h | 2 ++ 3 files changed, 23 insertions(+), 9 deletions(-) diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index 7b5b1a36d716d8..c097abd3c30c53 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -2456,6 +2456,7 @@ cc_library( "//xla/service/gpu/model:gpu_hlo_cost_analysis", "//xla/service/gpu/model:gpu_performance_model", "//xla/service/gpu/model:gpu_performance_model_base", + "//xla/service/gpu/model:symbolic_tile_analysis", "//xla/stream_executor:device_description", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", @@ -2467,6 +2468,7 @@ cc_library( "@com_google_absl//absl/synchronization", "@com_google_absl//absl/time", "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", "@local_tsl//tsl/platform:blocking_counter", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:errors", diff --git a/third_party/xla/xla/service/gpu/priority_fusion.cc b/third_party/xla/xla/service/gpu/priority_fusion.cc index 845614ed605503..b079f3a2753754 100644 --- a/third_party/xla/xla/service/gpu/priority_fusion.cc +++ b/third_party/xla/xla/service/gpu/priority_fusion.cc @@ -24,6 +24,7 @@ limitations under the License. #include #include #include +#include #include #include "absl/container/flat_hash_map.h" @@ -36,6 +37,7 @@ limitations under the License. #include "absl/synchronization/mutex.h" #include "absl/time/time.h" #include "llvm/ADT/STLExtras.h" +#include "mlir/IR/MLIRContext.h" // from @llvm-project #include "xla/debug_options_flags.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -51,7 +53,7 @@ limitations under the License. 
#include "xla/service/gpu/model/gpu_hlo_cost_analysis.h" #include "xla/service/gpu/model/gpu_performance_model.h" #include "xla/service/gpu/model/gpu_performance_model_base.h" -#include "xla/service/gpu/triton_fusion_analysis.h" +#include "xla/service/gpu/model/symbolic_tile_analysis.h" #include "xla/service/hlo_graph_dumper.h" #include "xla/service/instruction_fusion.h" #include "xla/shape.h" @@ -130,13 +132,14 @@ class GpuPriorityFusionQueue { const GpuHloCostAnalysis::Options& cost_analysis_options, const se::DeviceDescription* device_info, FusionProcessDumpProto* fusion_process_dump, - tsl::thread::ThreadPool* thread_pool, + tsl::thread::ThreadPool* thread_pool, mlir::MLIRContext* mlir_context, HloFusionAnalysisCache& fusion_analysis_cache, bool triton_softmax_priority_fusion_enabled) : computation_(computation), cost_analysis_(cost_analysis_options, device_info), fusion_process_dump_(fusion_process_dump), thread_pool_(thread_pool), + mlir_context_(mlir_context), fusion_analysis_cache_(fusion_analysis_cache), triton_softmax_priority_fusion_enabled_( triton_softmax_priority_fusion_enabled) { @@ -428,11 +431,16 @@ class GpuPriorityFusionQueue { } } - // TODO(b/316143118): Replace TritonFusionAnalysis with SymbolicTileAnalysis - // once symbolic analysis is ready. 
- if (!TritonFusionAnalysis::ExecuteForProducerConsumer(*producer, *consumer) - .ok()) { - return "triton codegen can't handle the fusion"; + auto fusion = HloFusionAdaptor::ForProducerConsumer(producer, consumer); + + SymbolicTileAnalysisOrError symbolic_tile_analysis_or = + SymbolicTileAnalysis::AnalyzeFusion(*fusion, mlir_context_); + + if (const auto* fusion_decision = + std::get_if(&symbolic_tile_analysis_or)) { + return { + absl::StrCat("Fusion can not be tiled with SymbolicTileAnalysis: ", + fusion_decision->Explain())}; } return {}; @@ -615,6 +623,8 @@ class GpuPriorityFusionQueue { tsl::thread::ThreadPool* thread_pool_; + mlir::MLIRContext* mlir_context_; + HloFusionAnalysisCache& fusion_analysis_cache_; // Caches result of can_fuse for a (producer, consumer) pair. A cache entry is @@ -724,8 +734,8 @@ absl::StatusOr GpuPriorityFusion::Run( auto fusion_queue = std::make_unique( computation, cost_analysis_options_, &device_info_, - fusion_process_dump_.get(), thread_pool_, fusion_analysis_cache_, - triton_softmax_priority_fusion_enabled); + fusion_process_dump_.get(), thread_pool_, &mlir_context_, + fusion_analysis_cache_, triton_softmax_priority_fusion_enabled); while (fusion_queue->DequeueNextProducer()) { auto producer = fusion_queue->current_producer(); diff --git a/third_party/xla/xla/service/gpu/priority_fusion.h b/third_party/xla/xla/service/gpu/priority_fusion.h index c8e45da90cb313..79e2eda54d82fd 100644 --- a/third_party/xla/xla/service/gpu/priority_fusion.h +++ b/third_party/xla/xla/service/gpu/priority_fusion.h @@ -89,6 +89,8 @@ class GpuPriorityFusion : public InstructionFusion { std::unique_ptr fusion_process_dump_; HloFusionAnalysisCache fusion_analysis_cache_; + + mlir::MLIRContext mlir_context_; }; } // namespace gpu From 42860aaccf1e5e08ec1de45a1f1e0e5da90d2b9a Mon Sep 17 00:00:00 2001 From: David Dunleavy Date: Tue, 18 Jun 2024 11:14:25 -0700 Subject: [PATCH 25/59] Change default permissions for github actions workflows to satisfy security 
scorecard PiperOrigin-RevId: 644451735 --- third_party/xla/.github/workflows/buildifier.yml | 1 + third_party/xla/.github/workflows/check_contents.yml | 1 + third_party/xla/.github/workflows/clang_format.yml | 1 + .../xla/.github/workflows/rollback_notification.yml | 4 ++-- .../xla/.github/workflows/trusted-partners.yml | 11 ++++++----- 5 files changed, 11 insertions(+), 7 deletions(-) diff --git a/third_party/xla/.github/workflows/buildifier.yml b/third_party/xla/.github/workflows/buildifier.yml index 00413b1f505d4e..55140675aa28c3 100644 --- a/third_party/xla/.github/workflows/buildifier.yml +++ b/third_party/xla/.github/workflows/buildifier.yml @@ -13,6 +13,7 @@ # limitations under the License. # ============================================================================ name: Buildifier +permissions: read-all on: pull_request: diff --git a/third_party/xla/.github/workflows/check_contents.yml b/third_party/xla/.github/workflows/check_contents.yml index abbf80a18db631..fd38adfd0adda3 100644 --- a/third_party/xla/.github/workflows/check_contents.yml +++ b/third_party/xla/.github/workflows/check_contents.yml @@ -19,6 +19,7 @@ # files once XLA moves out of Tensorflow internally. # TODO(ddunleavy): Update this after METADATA files are consolidated. 
name: Check Contents +permissions: read-all on: pull_request: diff --git a/third_party/xla/.github/workflows/clang_format.yml b/third_party/xla/.github/workflows/clang_format.yml index c262edb2385ad4..0bb143c3c9fc1f 100644 --- a/third_party/xla/.github/workflows/clang_format.yml +++ b/third_party/xla/.github/workflows/clang_format.yml @@ -14,6 +14,7 @@ # ============================================================================ name: Clang Format +permissions: read-all on: pull_request: diff --git a/third_party/xla/.github/workflows/rollback_notification.yml b/third_party/xla/.github/workflows/rollback_notification.yml index 0ef1d434e8f17c..3d3658efdf5bf7 100644 --- a/third_party/xla/.github/workflows/rollback_notification.yml +++ b/third_party/xla/.github/workflows/rollback_notification.yml @@ -13,6 +13,8 @@ # limitations under the License. # ============================================================================ name: Rollback Notification +permissions: + pull-requests: write on: push: branches: @@ -21,8 +23,6 @@ on: jobs: rollback-notification: if: contains(github.event.head_commit.message, 'revert') - permissions: - pull-requests: write runs-on: ubuntu-22.04 defaults: run: diff --git a/third_party/xla/.github/workflows/trusted-partners.yml b/third_party/xla/.github/workflows/trusted-partners.yml index ef3f752727880e..4adef4e7478c3f 100644 --- a/third_party/xla/.github/workflows/trusted-partners.yml +++ b/third_party/xla/.github/workflows/trusted-partners.yml @@ -14,17 +14,18 @@ # ============================================================================== name: Trusted Partner PR + +permissions: + # Needed to attach tags into the PR + issues: write + pull-requests: write + on: pull_request_target: jobs: assign-partner-tags: runs-on: ubuntu-latest - permissions: - # Needed to attach tags into the PR - issues: write - contents: write - pull-requests: write if: | github.event.pull_request.draft == false && github.event.sender.type == 'User' From 
042e9b857d26d0873e548a162fae7bf2d44fc180 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 18 Jun 2024 11:14:56 -0700 Subject: [PATCH 26/59] Clean up TF deps tf_to_xla_attribute_utils PiperOrigin-RevId: 644451921 --- tensorflow/compiler/mlir/lite/core/c/BUILD | 25 ++++++++ .../mlir/lite/core/c/builtin_op_data.h | 27 +++++++++ tensorflow/compiler/mlir/lite/kernels/BUILD | 25 ++++++++ .../compiler/mlir/lite/kernels/padding.h | 59 +++++++++++++++++++ .../mlir/quantization/tensorflow/utils/BUILD | 3 +- .../utils/tf_to_xla_attribute_utils.cc | 13 ++-- 6 files changed, 147 insertions(+), 5 deletions(-) create mode 100644 tensorflow/compiler/mlir/lite/core/c/BUILD create mode 100644 tensorflow/compiler/mlir/lite/core/c/builtin_op_data.h create mode 100644 tensorflow/compiler/mlir/lite/kernels/BUILD create mode 100644 tensorflow/compiler/mlir/lite/kernels/padding.h diff --git a/tensorflow/compiler/mlir/lite/core/c/BUILD b/tensorflow/compiler/mlir/lite/core/c/BUILD new file mode 100644 index 00000000000000..02e10ba9de270b --- /dev/null +++ b/tensorflow/compiler/mlir/lite/core/c/BUILD @@ -0,0 +1,25 @@ +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") +load("//tensorflow/compiler/mlir/lite:build_def.bzl", "tflite_copts") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = [ + "//visibility:public", + ], + licenses = ["notice"], +) + +# LINT.IfChange(cc_library_common) +cc_library( + name = "common", + srcs = [], + hdrs = ["builtin_op_data.h"], + compatible_with = get_compatible_with_portable(), + copts = tflite_copts(), + visibility = [ + "//tensorflow/compiler/mlir/lite/kernels:__pkg__", + "//tensorflow/compiler/mlir/quantization/tensorflow/utils:__pkg__", + ], + alwayslink = 1, # Why?? TODO(b/161243354): eliminate this. 
+) +# LINT.ThenChange(//tensorflow/lite/core/c/BUILD) diff --git a/tensorflow/compiler/mlir/lite/core/c/builtin_op_data.h b/tensorflow/compiler/mlir/lite/core/c/builtin_op_data.h new file mode 100644 index 00000000000000..8185215ba0cf8c --- /dev/null +++ b/tensorflow/compiler/mlir/lite/core/c/builtin_op_data.h @@ -0,0 +1,27 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_CORE_C_BUILTIN_OP_DATA_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_CORE_C_BUILTIN_OP_DATA_H_ + +// LINT.IfChange +typedef enum { + kTfLitePaddingUnknown = 0, + kTfLitePaddingSame, + kTfLitePaddingValid, +} TfLitePadding; + +// LINT.ThenChange(//tensorflow/lite/core/c/builtin_op_data.h) + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_CORE_C_BUILTIN_OP_DATA_H_ diff --git a/tensorflow/compiler/mlir/lite/kernels/BUILD b/tensorflow/compiler/mlir/lite/kernels/BUILD new file mode 100644 index 00000000000000..ccc0433bc66c76 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/kernels/BUILD @@ -0,0 +1,25 @@ +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = [ + "//visibility:public", + ], + licenses = ["notice"], +) + +# LINT.IfChange(cc_library_padding) +cc_library( + name = "padding", + srcs = [], + hdrs = 
["padding.h"], + compatible_with = get_compatible_with_portable(), + visibility = [ + "//tensorflow/compiler/mlir/lite/kernels:__pkg__", + "//tensorflow/compiler/mlir/quantization/tensorflow/utils:__pkg__", + ], + deps = [ + "//tensorflow/compiler/mlir/lite/core/c:common", + ], +) +# LINT.ThenChange(//tensorflow/lite/kernels/BUILD) diff --git a/tensorflow/compiler/mlir/lite/kernels/padding.h b/tensorflow/compiler/mlir/lite/kernels/padding.h new file mode 100644 index 00000000000000..b0dd6daf44dafc --- /dev/null +++ b/tensorflow/compiler/mlir/lite/kernels/padding.h @@ -0,0 +1,59 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_KERNELS_PADDING_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_KERNELS_PADDING_H_ + +// LINT.IfChange +#include "tensorflow/compiler/mlir/lite/core/c/builtin_op_data.h" + +namespace tflite_migration { + +// Matching GetWindowedOutputSize in TensorFlow. 
+inline int ComputeOutSize(TfLitePadding padding, int image_size, + int filter_size, int stride, int dilation_rate = 1) { + int effective_filter_size = (filter_size - 1) * dilation_rate + 1; + + // TODO(b/186448822): This uses 0 since the function has no other way to + // report error case + if (stride == 0) return 0; + + switch (padding) { + case kTfLitePaddingSame: + return (image_size + stride - 1) / stride; + case kTfLitePaddingValid: + return (image_size + stride - effective_filter_size) / stride; + default: + return 0; + } +} + +// It's not guaranteed that padding is symmetric. It's important to keep +// offset for algorithms need all paddings. +inline int ComputePaddingWithOffset(int stride, int dilation_rate, int in_size, + int filter_size, int out_size, + int* offset) { + int effective_filter_size = (filter_size - 1) * dilation_rate + 1; + int total_padding = + ((out_size - 1) * stride + effective_filter_size - in_size); + total_padding = total_padding > 0 ? total_padding : 0; + *offset = total_padding % 2; + return total_padding / 2; +} + +} // namespace tflite_migration + +// LINT.ThenChange(//tensorflow/lite/kernels/padding.h) + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_KERNELS_PADDING_H_ diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/utils/BUILD b/tensorflow/compiler/mlir/quantization/tensorflow/utils/BUILD index 4397b4fc5a3f2d..ddfc905acd2365 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/utils/BUILD +++ b/tensorflow/compiler/mlir/quantization/tensorflow/utils/BUILD @@ -79,9 +79,10 @@ cc_library( hdrs = ["tf_to_xla_attribute_utils.h"], compatible_with = get_compatible_with_portable(), deps = [ + "//tensorflow/compiler/mlir/lite/core/c:common", + "//tensorflow/compiler/mlir/lite/kernels:padding", "//tensorflow/compiler/mlir/quantization/common:attrs_and_constraints", "//tensorflow/compiler/mlir/quantization/tensorflow/cc:constant_fold", - "//tensorflow/lite/kernels:padding", "@com_google_absl//absl/algorithm:container", 
"@com_google_absl//absl/strings:str_format", "@llvm-project//llvm:Support", diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/utils/tf_to_xla_attribute_utils.cc b/tensorflow/compiler/mlir/quantization/tensorflow/utils/tf_to_xla_attribute_utils.cc index b22726de30aeaa..1ed9a143b7437b 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/utils/tf_to_xla_attribute_utils.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/utils/tf_to_xla_attribute_utils.cc @@ -21,10 +21,11 @@ limitations under the License. #include "absl/strings/str_format.h" #include "llvm/ADT/ArrayRef.h" #include "mlir/Support/LLVM.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/core/c/builtin_op_data.h" +#include "tensorflow/compiler/mlir/lite/kernels/padding.h" #include "tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/cc/constant_fold.h" #include "xla/xla_data.pb.h" -#include "tensorflow/lite/kernels/padding.h" namespace mlir::quant { namespace { @@ -188,12 +189,16 @@ Value CalculatePaddingAndPadIfNeeded(OpBuilder &builder, Location loc, int filter_size = filter_shape.getDimSize(i - 1); int stride_i = mlir::cast(strides[i]).getInt(); int dilation_i = mlir::cast(dilations[i]).getInt(); - int out_size = tflite::ComputeOutSize(kTfLitePaddingSame, input_size, - filter_size, stride_i, dilation_i); + + // LINT.IfChange + int out_size = tflite_migration::ComputeOutSize( + kTfLitePaddingSame, input_size, filter_size, stride_i, dilation_i); int offset = 0; - int padding_before = tflite::ComputePaddingWithOffset( + int padding_before = tflite_migration::ComputePaddingWithOffset( stride_i, dilation_i, input_size, filter_size, out_size, &offset); + // LINT.ThenChange(//tensorflow/lite/kernels/padding.h) + int padding_after = padding_before + offset; padding_values[2 * i] = padding_before; padding_values[2 * i + 1] = padding_after; From cc182302758794a658791d56905c90adba26e668 Mon Sep 17 
00:00:00 2001 From: Quentin Khan Date: Tue, 18 Jun 2024 11:33:59 -0700 Subject: [PATCH 27/59] Propagate SDPA as a StableHLO composite op instead of converting it to a custom op. Update the XNNPack delegate to handle the SDPA StableHLO composite op. PiperOrigin-RevId: 644459016 --- .../legalize-stablehlo-tfl-composite.mlir | 21 +- ...alize_stablehlo_composite_to_tfl_custom.cc | 4 +- tensorflow/lite/delegates/xnnpack/BUILD | 12 +- .../lite/delegates/xnnpack/odml_sdpa_test.cc | 118 ++++++----- .../delegates/xnnpack/odml_sdpa_tester.cc | 198 +++++++++--------- .../lite/delegates/xnnpack/odml_sdpa_tester.h | 22 +- .../delegates/xnnpack/xnnpack_delegate.cc | 73 +++++-- 7 files changed, 247 insertions(+), 201 deletions(-) diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-composite.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-composite.mlir index d64b50b72d533f..dc15507fc31257 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-composite.mlir +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-composite.mlir @@ -1,25 +1,6 @@ // RUN: odml-to-stablehlo-opt %s -stablehlo-composite-legalize-tfl-custom | FileCheck %s -// RUN: tf_tfl_translate --enable-hlo-to-tf-conversion --input-mlir %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck %s --check-prefix=CHECK-ROUNDTRIP module { - func.func public @main(%arg0: tensor<3x3xf32>, %arg1: tensor<3x3xf32>, %arg2: tensor<1x100x32x4xf32>, - %arg3: tensor<1x500x4x4xf32>, %arg4: tensor<1x500x4x4xf32>, %arg5: tensor<1x1x100x500xf32>, %arg6: tensor) - -> tensor<1x100x32x4xf32> { - // CHECK-ROUNDTRIP: %0 = "tfl.custom"(%arg2, %arg3, %arg4, %arg5, %arg6) <{custom_code = "odml.scaled_dot_product_attention", custom_option = #tfl}> : (tensor<1x100x32x4xf32>, tensor<1x500x4x4xf32>, tensor<1x500x4x4xf32>, tensor<1x1x100x500xf32>, tensor) -> tensor<1x100x32x4xf32> - %0 = func.call @test_sdpa(%arg2, %arg3, 
%arg4, %arg5, %arg6) : (tensor<1x100x32x4xf32>, tensor<1x500x4x4xf32>, tensor<1x500x4x4xf32>, tensor<1x1x100x500xf32>, tensor) -> tensor<1x100x32x4xf32> - return %0: tensor<1x100x32x4xf32> - } - - // CHECK-LABEL: func.func private @test_sdpa - func.func private @test_sdpa(%arg0: tensor<1x100x32x4xf32>, %arg1: tensor<1x500x4x4xf32>, %arg2: tensor<1x500x4x4xf32>, %arg3: tensor<1x1x100x500xf32>, %arg4: tensor) -> tensor<1x100x32x4xf32> { - // CHECK: %0 = "tfl.custom"(%arg0, %arg1, %arg2, %arg3, %arg4) <{custom_code = "odml.scaled_dot_product_attention", custom_option = #tfl}> : (tensor<1x100x32x4xf32>, tensor<1x500x4x4xf32>, tensor<1x500x4x4xf32>, tensor<1x1x100x500xf32>, tensor) -> tensor<1x100x32x4xf32> - %0 = stablehlo.composite "odml.scaled_dot_product_attention" %arg0, %arg1, %arg2, %arg3, %arg4 {decomposition = @odml.scaled_dot_product_attention.impl} : (tensor<1x100x32x4xf32>, tensor<1x500x4x4xf32>, tensor<1x500x4x4xf32>, tensor<1x1x100x500xf32>, tensor) -> tensor<1x100x32x4xf32> - return %0 : tensor<1x100x32x4xf32> - } - func.func private @odml.scaled_dot_product_attention.impl(%arg0: tensor<1x100x32x4xf32>, %arg1: tensor<1x500x4x4xf32>, %arg2: tensor<1x500x4x4xf32>, %arg3: tensor<1x1x100x500xf32>, %arg4: tensor) -> tensor<1x100x32x4xf32> { - // No decomposition provided for test case. 
- return %arg0 : tensor<1x100x32x4xf32> - } // CHECK-LABEL: func.func private @test_multiple_kv_caches func.func private @test_multiple_kv_caches(%arg0: tensor<1x500x4x4xf32>, %arg1: tensor<1x500x4x4xf32>, %arg2: tensor<100xi64>, %arg3: tensor<1x100x4x4xf32>, %arg4: tensor<1x100x4x4xf32>) -> (tensor<1x500x4x4xf32>, tensor<1x500x4x4xf32>) { @@ -47,4 +28,4 @@ module { return %6, %7 : tensor<1x500x4x4xf32>, tensor<1x500x4x4xf32> } -} \ No newline at end of file +} diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_stablehlo_composite_to_tfl_custom.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_stablehlo_composite_to_tfl_custom.cc index 50521a02c7b907..fdad31b31b4b86 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_stablehlo_composite_to_tfl_custom.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_stablehlo_composite_to_tfl_custom.cc @@ -43,9 +43,7 @@ namespace odml { namespace { bool IsSupportedComposite(::mlir::stablehlo::CompositeOp op) { // List of supported composites to represent using CustomOp. 
- return llvm::is_contained( - {"odml.update_kv_cache", "odml.scaled_dot_product_attention"}, - op.getName()); + return llvm::is_contained({"odml.update_kv_cache"}, op.getName()); } bool IsKVCacheCompositeOp(::mlir::stablehlo::CompositeOp op) { diff --git a/tensorflow/lite/delegates/xnnpack/BUILD b/tensorflow/lite/delegates/xnnpack/BUILD index d83fd8161b744b..eaf105cb86e4b8 100644 --- a/tensorflow/lite/delegates/xnnpack/BUILD +++ b/tensorflow/lite/delegates/xnnpack/BUILD @@ -596,14 +596,22 @@ cc_library( testonly = 1, srcs = ["odml_sdpa_tester.cc"], hdrs = ["odml_sdpa_tester.h"], + data = [ + "odml_sdpa_composite_gqa.tflite", + "odml_sdpa_composite_mha.tflite", + "odml_sdpa_composite_mqa.tflite", + ], deps = [ + "//tensorflow/compiler/mlir/lite/schema:schema_conversion_utils", + "//tensorflow/compiler/mlir/lite/schema:schema_fbs", "//tensorflow/lite:framework", - "//tensorflow/lite:schema_fbs_version", + "//tensorflow/lite:version", + "//tensorflow/lite/c:common", "//tensorflow/lite/core:framework", "//tensorflow/lite/core/c:common", "//tensorflow/lite/core/kernels:builtin_ops", "//tensorflow/lite/experimental/genai:genai_ops", - "//tensorflow/lite/schema:schema_conversion_utils", + "//tensorflow/lite/kernels:kernel_util", "//tensorflow/lite/schema:schema_fbs", "@com_google_googletest//:gtest", "@flatbuffers", diff --git a/tensorflow/lite/delegates/xnnpack/odml_sdpa_test.cc b/tensorflow/lite/delegates/xnnpack/odml_sdpa_test.cc index 843de362400634..bf54f45cf04233 100644 --- a/tensorflow/lite/delegates/xnnpack/odml_sdpa_test.cc +++ b/tensorflow/lite/delegates/xnnpack/odml_sdpa_test.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include +#include #include #include "tensorflow/lite/c/c_api_types.h" @@ -23,65 +24,78 @@ limitations under the License. 
namespace tflite { namespace xnnpack { -TEST(ODMLSDPA, MQA) { - std::unique_ptr - xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), - TfLiteXNNPackDelegateDelete); - - const auto batch = 1; - const auto input_seq_len = 1; - const auto max_seq_len = 64; - const auto q_heads = 32; - const auto kv_heads = 1; - const auto head_dim = 4; // embedding_dim//q_heads - - ODMLSDPATester() - .QueryShape({batch, input_seq_len, q_heads, head_dim}) // q - .KeyShape({batch, max_seq_len, kv_heads, head_dim}) // k - .ValueShape({batch, max_seq_len, kv_heads, head_dim}) // v - .MaskShape({batch, 1, input_seq_len, max_seq_len}) // mask - .Test(xnnpack_delegate.get()); +struct SDPATestParams { + std::string model_name; + std::string custom_test_name; + int batch; + int input_seq_len; + int max_seq_len; + int q_heads; + int kv_heads; + int head_dim; // embedding_dim//q_heads +}; + +std::string TestName(const testing::TestParamInfo& info) { + if (info.param.model_name != kOdmlSdpaCustom) { + return info.param.model_name; + } + return "CustomOp" + info.param.custom_test_name; } -TEST(ODMLSDPA, MHA) { - std::unique_ptr - xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), - TfLiteXNNPackDelegateDelete); - - const auto batch = 1; - const auto input_seq_len = 1; - const auto max_seq_len = 64; - const auto q_heads = 32; - const auto kv_heads = 32; - const auto head_dim = 4; // embedding_dim//q_heads - - ODMLSDPATester() - .QueryShape({batch, input_seq_len, q_heads, head_dim}) // q - .KeyShape({batch, max_seq_len, kv_heads, head_dim}) // k - .ValueShape({batch, max_seq_len, kv_heads, head_dim}) // v - .MaskShape({batch, 1, input_seq_len, max_seq_len}) // mask - .Test(xnnpack_delegate.get()); -} +class SDPATest : public testing::TestWithParam {}; -TEST(ODMLSDPA, GQA) { +TEST_P(SDPATest, CompareWithTFLiteReference) { std::unique_ptr xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), TfLiteXNNPackDelegateDelete); - - const auto batch = 1; - const auto input_seq_len = 1; - const auto 
max_seq_len = 64; - const auto q_heads = 32; - const auto kv_heads = 4; - const auto head_dim = 4; // embedding_dim//q_heads - - ODMLSDPATester() - .QueryShape({batch, input_seq_len, q_heads, head_dim}) // q - .KeyShape({batch, max_seq_len, kv_heads, head_dim}) // k - .ValueShape({batch, max_seq_len, kv_heads, head_dim}) // v - .MaskShape({batch, 1, input_seq_len, max_seq_len}) // mask - .Test(xnnpack_delegate.get()); + const SDPATestParams& p = GetParam(); + + ODMLSDPATester tester(p.model_name); + if (p.model_name == kOdmlSdpaCustom) { + tester + .QueryShape({p.batch, p.input_seq_len, p.q_heads, p.head_dim}) // q + .KeyShape({p.batch, p.max_seq_len, p.kv_heads, p.head_dim}) // k + .ValueShape({p.batch, p.max_seq_len, p.kv_heads, p.head_dim}) // v + .MaskShape({p.batch, 1, p.input_seq_len, p.max_seq_len}); // mask + } + tester.Test(xnnpack_delegate.get()); } +INSTANTIATE_TEST_SUITE_P(SDPA, SDPATest, + testing::Values(SDPATestParams{kOdmlSdpaCompositeMqa}, + SDPATestParams{kOdmlSdpaCompositeMha}, + SDPATestParams{kOdmlSdpaCompositeGqa}, + SDPATestParams{ + kOdmlSdpaCustom, + /*.custom_test_name=*/"MQA", + /*.batch=*/1, + /*.input_seq_len=*/1, + /*.max_seq_len=*/64, + /*.q_heads=*/32, + /*.kv_heads=*/1, + /*.head_dim=*/4, + }, + SDPATestParams{ + kOdmlSdpaCustom, + /*.custom_test_name=*/"MHA", + /*.batch=*/1, + /*.input_seq_len=*/1, + /*.max_seq_len=*/64, + /*.q_heads=*/32, + /*.kv_heads=*/32, + /*.head_dim=*/4, + }, + SDPATestParams{ + kOdmlSdpaCustom, + /*.custom_test_name=*/"GQA", + /*.batch=*/1, + /*.input_seq_len=*/1, + /*.max_seq_len=*/64, + /*.q_heads=*/32, + /*.kv_heads=*/4, + /*.head_dim=*/4, + }), + TestName); + } // namespace xnnpack } // namespace tflite diff --git a/tensorflow/lite/delegates/xnnpack/odml_sdpa_tester.cc b/tensorflow/lite/delegates/xnnpack/odml_sdpa_tester.cc index 704390bc2aa6aa..b840c889f27181 100644 --- a/tensorflow/lite/delegates/xnnpack/odml_sdpa_tester.cc +++ b/tensorflow/lite/delegates/xnnpack/odml_sdpa_tester.cc @@ -15,10 +15,9 @@ 
limitations under the License. #include "tensorflow/lite/delegates/xnnpack/odml_sdpa_tester.h" -#include - #include #include +#include #include #include #include @@ -27,6 +26,7 @@ limitations under the License. #include #include #include +#include #include #include "flatbuffers/flexbuffers.h" @@ -34,10 +34,14 @@ limitations under the License. #include "flatbuffers/buffer.h" // from @flatbuffers #include "flatbuffers/flatbuffer_builder.h" // from @flatbuffers #include "flatbuffers/string.h" // from @flatbuffers +#include "flatbuffers/util.h" // from @flatbuffers +#include "tensorflow/compiler/mlir/lite/schema/schema_conversion_utils.h" +#include "tensorflow/compiler/mlir/lite/schema/schema_generated.h" +#include "tensorflow/lite/core/interpreter_builder.h" #include "tensorflow/lite/core/kernels/register.h" #include "tensorflow/lite/experimental/genai/genai_ops.h" #include "tensorflow/lite/interpreter.h" -#include "tensorflow/lite/schema/schema_conversion_utils.h" +#include "tensorflow/lite/kernels/kernel_util.h" #include "tensorflow/lite/schema/schema_generated.h" #include "tensorflow/lite/version.h" @@ -84,31 +88,16 @@ void ODMLSDPATester::Test(TfLiteDelegate* delegate) const { ASSERT_EQ(delegate_interpreter->ModifyGraphWithDelegate(delegate), kTfLiteOk); - float* delegate_input1_data = - delegate_interpreter->typed_input_tensor(0); - float* delegate_input2_data = - delegate_interpreter->typed_input_tensor(1); - float* delegate_input3_data = - delegate_interpreter->typed_input_tensor(2); - float* delegate_input4_data = - delegate_interpreter->typed_input_tensor(3); - std::generate_n(delegate_input1_data, QuerySize(), std::ref(input_rng)); - std::generate_n(delegate_input2_data, KeySize(), std::ref(input_rng)); - std::generate_n(delegate_input3_data, ValueSize(), std::ref(input_rng)); - std::generate_n(delegate_input4_data, MaskSize(), std::ref(input_rng)); - - float* default_input1_data = - default_interpreter->typed_input_tensor(0); - float* default_input2_data = 
- default_interpreter->typed_input_tensor(1); - float* default_input3_data = - default_interpreter->typed_input_tensor(2); - float* default_input4_data = - default_interpreter->typed_input_tensor(3); - std::copy_n(delegate_input1_data, QuerySize(), default_input1_data); - std::copy_n(delegate_input2_data, KeySize(), default_input2_data); - std::copy_n(delegate_input3_data, ValueSize(), default_input3_data); - std::copy_n(delegate_input4_data, MaskSize(), default_input4_data); + for (size_t i = 0; i < delegate_interpreter->inputs().size(); ++i) { + const TfLiteTensor* delegate_input_tensor = delegate_interpreter->tensor(i); + const size_t num_elts = NumElements(delegate_input_tensor); + float* const delegate_input_data = + delegate_interpreter->typed_input_tensor(i); + float* const default_input_data = + default_interpreter->typed_input_tensor(i); + std::generate_n(delegate_input_data, num_elts, std::ref(input_rng)); + std::copy_n(delegate_input_data, num_elts, default_input_data); + } ASSERT_EQ(delegate_interpreter->Invoke(), kTfLiteOk); ASSERT_EQ(default_interpreter->Invoke(), kTfLiteOk); @@ -117,7 +106,8 @@ void ODMLSDPATester::Test(TfLiteDelegate* delegate) const { delegate_interpreter->typed_output_tensor(0); float* default_output_data = default_interpreter->typed_output_tensor(0); - const int32_t output_size = ComputeSize(OutputShape()); + const int32_t output_size = + NumElements(delegate_interpreter->output_tensor(0)); for (size_t i = 0; i < output_size; i++) { ASSERT_NEAR(default_output_data[i], delegate_output_data[i], @@ -127,76 +117,86 @@ void ODMLSDPATester::Test(TfLiteDelegate* delegate) const { } std::vector ODMLSDPATester::CreateTfLiteModel() const { - flatbuffers::FlatBufferBuilder builder; - flatbuffers::Offset operator_code = CreateOperatorCode( - builder, BuiltinOperator_CUSTOM, - builder.CreateString("odml.scaled_dot_product_attention")); - - const std::array, 1> buffers{{ - CreateBuffer(builder, builder.CreateVector({})), - }}; - - const 
std::array, 5> tensors{{ - CreateTensor(builder, - builder.CreateVector(QueryShape().data(), - QueryShape().size()), - TensorType_FLOAT32), - CreateTensor( - builder, - builder.CreateVector(KeyShape().data(), KeyShape().size()), - TensorType_FLOAT32), - CreateTensor(builder, - builder.CreateVector(ValueShape().data(), - ValueShape().size()), - TensorType_FLOAT32), - CreateTensor( - builder, - builder.CreateVector(MaskShape().data(), MaskShape().size()), - TensorType_FLOAT32), - CreateTensor(builder, - builder.CreateVector(OutputShape().data(), - OutputShape().size()), - TensorType_FLOAT32), - }}; - - auto fbb = std::make_unique(); - float scale = 1 / sqrt(QueryShape().data()[QueryShape().size() - 1]); - fbb->Map([&]() { fbb->Float("scale", scale); }); - fbb->Finish(); - - const std::array op_inputs{{0, 1, 2, 3}}; - const std::array op_outputs{{4}}; - flatbuffers::Offset op = CreateOperator( - builder, /*opcode_index=*/0, - builder.CreateVector(op_inputs.data(), op_inputs.size()), - builder.CreateVector(op_outputs.data(), op_outputs.size()), - tflite::BuiltinOptions_NONE, 0, - builder.CreateVector( - reinterpret_cast(fbb->GetBuffer().data()), - fbb->GetSize())); - - const std::array subgraph_inputs{{0, 1, 2, 3}}; - const std::array subgraph_outputs{{4}}; - flatbuffers::Offset subgraph = CreateSubGraph( - builder, builder.CreateVector(tensors.data(), tensors.size()), - builder.CreateVector(subgraph_inputs.data(), - subgraph_inputs.size()), - builder.CreateVector(subgraph_outputs.data(), - subgraph_outputs.size()), - builder.CreateVector(&op, 1)); - - flatbuffers::Offset description = - builder.CreateString("ODML SDPA model"); - - flatbuffers::Offset model_buffer = CreateModel( - builder, TFLITE_SCHEMA_VERSION, builder.CreateVector(&operator_code, 1), - builder.CreateVector(&subgraph, 1), description, - builder.CreateVector(buffers.data(), buffers.size())); - - builder.Finish(model_buffer); - - return std::vector(builder.GetBufferPointer(), - 
builder.GetBufferPointer() + builder.GetSize()); + if (!model_name_.empty() && model_name_ != kOdmlSdpaCustom) { + const char kTestModelFolder[] = + "/tensorflow/lite/delegates/xnnpack/"; + const std::string test_model = + testing::SrcDir() + kTestModelFolder + model_name_ + ".tflite"; + std::string model_data; + flatbuffers::LoadFile(test_model.c_str(), /*binary=*/true, &model_data); + return std::vector(model_data.begin(), model_data.end()); + } else { + flatbuffers::FlatBufferBuilder builder; + flatbuffers::Offset operator_code = CreateOperatorCode( + builder, BuiltinOperator_CUSTOM, + builder.CreateString("odml.scaled_dot_product_attention")); + + const std::array, 1> buffers{{ + CreateBuffer(builder, builder.CreateVector({})), + }}; + + const std::array, 5> tensors{{ + CreateTensor(builder, + builder.CreateVector(QueryShape().data(), + QueryShape().size()), + TensorType_FLOAT32), + CreateTensor( + builder, + builder.CreateVector(KeyShape().data(), KeyShape().size()), + TensorType_FLOAT32), + CreateTensor(builder, + builder.CreateVector(ValueShape().data(), + ValueShape().size()), + TensorType_FLOAT32), + CreateTensor(builder, + builder.CreateVector(MaskShape().data(), + MaskShape().size()), + TensorType_FLOAT32), + CreateTensor(builder, + builder.CreateVector(OutputShape().data(), + OutputShape().size()), + TensorType_FLOAT32), + }}; + + auto fbb = std::make_unique(); + float scale = 1 / sqrt(QueryShape().data()[QueryShape().size() - 1]); + fbb->Map([&]() { fbb->Float("scale", scale); }); + fbb->Finish(); + + const std::array op_inputs{{0, 1, 2, 3}}; + const std::array op_outputs{{4}}; + flatbuffers::Offset op = CreateOperator( + builder, /*opcode_index=*/0, + builder.CreateVector(op_inputs.data(), op_inputs.size()), + builder.CreateVector(op_outputs.data(), op_outputs.size()), + tflite::BuiltinOptions_NONE, 0, + builder.CreateVector( + reinterpret_cast(fbb->GetBuffer().data()), + fbb->GetSize())); + + const std::array subgraph_inputs{{0, 1, 2, 3}}; + const 
std::array subgraph_outputs{{4}}; + flatbuffers::Offset subgraph = CreateSubGraph( + builder, builder.CreateVector(tensors.data(), tensors.size()), + builder.CreateVector(subgraph_inputs.data(), + subgraph_inputs.size()), + builder.CreateVector(subgraph_outputs.data(), + subgraph_outputs.size()), + builder.CreateVector(&op, 1)); + + flatbuffers::Offset description = + builder.CreateString("ODML SDPA model"); + + flatbuffers::Offset model_buffer = CreateModel( + builder, TFLITE_SCHEMA_VERSION, builder.CreateVector(&operator_code, 1), + builder.CreateVector(&subgraph, 1), description, + builder.CreateVector(buffers.data(), buffers.size())); + + builder.Finish(model_buffer); + + return std::vector(builder.GetBufferPointer(), + builder.GetBufferPointer() + builder.GetSize()); + } } int32_t ODMLSDPATester::ComputeSize(const std::vector& shape) { diff --git a/tensorflow/lite/delegates/xnnpack/odml_sdpa_tester.h b/tensorflow/lite/delegates/xnnpack/odml_sdpa_tester.h index da88b8c890729f..ba13b851bf38b0 100644 --- a/tensorflow/lite/delegates/xnnpack/odml_sdpa_tester.h +++ b/tensorflow/lite/delegates/xnnpack/odml_sdpa_tester.h @@ -17,23 +17,31 @@ limitations under the License. 
#define TENSORFLOW_LITE_DELEGATES_XNNPACK_ODML_SDPA_TESTER_H_ #include -#include -#include +#include +#include #include #include #include -#include "tensorflow/lite/core/c/builtin_op_data.h" +#include "tensorflow/lite/c/common.h" namespace tflite { namespace xnnpack { +constexpr const char kOdmlSdpaCompositeMqa[] = "odml_sdpa_composite_mqa"; +constexpr const char kOdmlSdpaCompositeMha[] = "odml_sdpa_composite_mha"; +constexpr const char kOdmlSdpaCompositeGqa[] = "odml_sdpa_composite_gqa"; +constexpr const char kOdmlSdpaCustom[] = "odml_sdpa_custom"; + class ODMLSDPATester { public: ODMLSDPATester() = default; ODMLSDPATester(const ODMLSDPATester&) = delete; ODMLSDPATester& operator=(const ODMLSDPATester&) = delete; + explicit ODMLSDPATester(const std::string& model_name) + : model_name_(model_name) {}; + inline ODMLSDPATester& QueryShape(std::initializer_list shape) { EXPECT_THAT(shape, testing::Each(testing::Gt(0))); query_shape_ = std::vector(shape.begin(), shape.end()); @@ -68,6 +76,13 @@ class ODMLSDPATester { return *this; } + int32_t Batch() const { return query_shape_[0]; }; + int32_t InputSeqLen() const { return query_shape_[1]; }; + int32_t QHeads() const { return query_shape_[2]; }; + int32_t HeadDim() const { return query_shape_[3]; }; + int32_t MaxSeqLen() const { return key_shape_[1]; }; + int32_t KVHeads() const { return key_shape_[2]; }; + inline const std::vector& MaskShape() const { return mask_shape_; } inline int32_t QuerySize() const { return query_size_; } @@ -95,6 +110,7 @@ class ODMLSDPATester { int32_t key_size_ = 1; int32_t value_size_ = 1; int32_t mask_size_ = 1; + std::string model_name_; }; } // namespace xnnpack diff --git a/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc b/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc index dccf053c266497..e40294c18d08ff 100644 --- a/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc +++ b/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc @@ -58,6 +58,8 @@ namespace tflite { 
namespace xnnpack { namespace { +constexpr char kOdmlSDPA[] = "odml.scaled_dot_product_attention"; + template void SafeCopyCustomData(const TfLiteNode& node, T* target) { const size_t safe_size = @@ -2914,6 +2916,25 @@ class Subgraph { case kTfLiteBuiltinVarHandle: return VisitVarHandleNode(subgraph, delegate, logging_context, node_index, node); + case kTfLiteBuiltinStablehloComposite: { + const TfLiteStablehloCompositeParams* composite_params = + static_cast( + node->builtin_data); + if (strcmp(composite_params->name, kOdmlSDPA) == 0) { + return VisitScaledDotAttentionCompositeNode( + subgraph, delegate, context, node_index, node, context->tensors, + composite_params->attributes, composite_params->attributes_size, + input_output_tensors); + } else { +#ifdef XNNPACK_DELEGATE_ENABLE_LOGGING + TF_LITE_KERNEL_LOG(context, + "unsupported stablehlo.composite operator type " + "\"%s\" in node #%d", + composite_params->name, node_index); +#endif // XNNPACK_DELEGATE_ENABLE_LOGGING + } + return kTfLiteError; + } case kTfLiteBuiltinCustom: { if (strcmp(registration->custom_name, "Convolution2DTransposeBias") == 0) { @@ -2938,28 +2959,11 @@ class Subgraph { return VisitMediaPipeUnpoolingNode( subgraph, delegate, context, node_index, node, context->tensors, &pool_params, input_output_tensors); - } else if (strcmp(registration->custom_name, - "odml.scaled_dot_product_attention") == 0) { - const float* scale_val = nullptr; - // ensure 28 bytes as we expect - // TODO(b/339106680): this reading method may not work for every case. - if (node->custom_initial_data_size == 28 && sizeof(float) == 4) { - // Custom data here is a flexbuffer map. - // byte_width is 4 for our map. - // First 5 values are "scale", then is the float value, and last is - // flexbuffer metadata. 
- const uint8_t* buffer = - reinterpret_cast(node->custom_initial_data); - if (strcmp("scale", reinterpret_cast(buffer)) == 0) { - constexpr size_t kScaleValOffset = 20; - scale_val = - reinterpret_cast(buffer + kScaleValOffset); - } - } - - return VisitDotAttentionNode(subgraph, delegate, context, node_index, - node, context->tensors, scale_val, - input_output_tensors); + } else if (strcmp(registration->custom_name, kOdmlSDPA) == 0) { + return VisitScaledDotAttentionCompositeNode( + subgraph, delegate, context, node_index, node, context->tensors, + reinterpret_cast(node->custom_initial_data), + node->custom_initial_data_size, input_output_tensors); } else { #ifdef XNNPACK_DELEGATE_ENABLE_LOGGING TF_LITE_KERNEL_LOG( @@ -6591,6 +6595,31 @@ class Subgraph { return kTfLiteOk; } + static TfLiteStatus VisitScaledDotAttentionCompositeNode( + xnn_subgraph_t subgraph, const Delegate& delegate, + TfLiteContext* logging_context, int node_index, TfLiteNode* node, + const TfLiteTensor* tensors, const uint8_t* buffer, + const size_t buffer_size, + const std::unordered_map& input_output_tensors) { + const float* scale_val = nullptr; + // ensure 28 bytes as we expect + // TODO(b/339106680): this reading method may not work for every case. + if (buffer_size == 28 && sizeof(float) == 4) { + // Custom data here is a flexbuffer map. + // byte_width is 4 for our map. + // First 5 values are "scale", then is the float value, and last is + // flexbuffer metadata. 
+ if (strcmp("scale", reinterpret_cast(buffer)) == 0) { + constexpr size_t kScaleValOffset = 20; + scale_val = reinterpret_cast(buffer + kScaleValOffset); + } + } + + return VisitDotAttentionNode(subgraph, delegate, logging_context, + node_index, node, tensors, scale_val, + input_output_tensors); + } + static TfLiteStatus VisitDotAttentionNode( xnn_subgraph_t subgraph, const Delegate& delegate, TfLiteContext* logging_context, int node_index, TfLiteNode* node, From 692144385b931799f918670a32e37a0eab339f23 Mon Sep 17 00:00:00 2001 From: Weiyi Wang Date: Tue, 18 Jun 2024 11:43:43 -0700 Subject: [PATCH 28/59] Support array attribute in vhlo seraialization. PiperOrigin-RevId: 644462079 --- .../compiler/mlir/lite/flatbuffer_export.cc | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc index eb12754b217c8b..b73a9902377af9 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc @@ -1517,7 +1517,7 @@ void CreateFlexbufferVector( const std::unique_ptr& flex_builder, std::string& name, const mlir::Attribute& attr) { auto start = flex_builder->StartVector(name.c_str()); - auto array = attr.cast().getValue(); + auto array = attr.cast().getValue(); for (int i = 0; i < array.size(); i++) { if (llvm::isa(array[i])) { @@ -1671,6 +1671,11 @@ Translator::BuildVhloCompositeV1Op(mlir::vhlo::CompositeOpV1 composite_op, flex_builder->Float( name.c_str(), attr.cast().getValue().convertToFloat()); + else if (llvm::isa(attr)) + CreateFlexbufferVector(flex_builder, name, attr); + else + // Unhandled attribute type. 
+ return std::nullopt; } flex_builder->EndMap(map_start); @@ -2161,8 +2166,12 @@ std::optional> Translator::BuildOperator( return BuildVhloPadV1Op(vhlo_op, operands, results, vhlo_type_converter); } if (auto vhlo_op = llvm::dyn_cast(inst)) { - return BuildVhloCompositeV1Op(vhlo_op, operands, results, - inst->getName().getStringRef().str()); + auto op = BuildVhloCompositeV1Op(vhlo_op, operands, results, + inst->getName().getStringRef().str()); + if (!op) + return inst->emitOpError("Failed to build VhloCompositeOpV1."), + std::nullopt; + return op; } // for ops don't have kernels, only serialize when conversion is set to // true From 6156204d3d3b11755916ddd108085d967623aa43 Mon Sep 17 00:00:00 2001 From: Changhui Lin Date: Tue, 18 Jun 2024 11:44:05 -0700 Subject: [PATCH 29/59] Remove the deprecated PjRtClient::LookupAddressableDevice() that takes a raw int. PiperOrigin-RevId: 644462225 --- tensorflow/compiler/jit/BUILD | 3 +++ tensorflow/compiler/jit/pjrt_device_context.cc | 8 ++++++-- tensorflow/compiler/jit/xla_launch_util.cc | 4 +++- tensorflow/compiler/jit/xla_launch_util_test.cc | 12 +++++++----- tensorflow/core/tfrt/gpu/kernel/BUILD | 1 + tensorflow/core/tfrt/gpu/kernel/gpu_runner.cc | 6 ++++-- third_party/xla/xla/pjrt/BUILD | 2 +- .../xla/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc | 6 +++--- third_party/xla/xla/pjrt/cpu/cpu_client.cc | 5 ----- third_party/xla/xla/pjrt/cpu/cpu_client.h | 2 -- third_party/xla/xla/pjrt/pjrt_c_api_client.cc | 5 ----- third_party/xla/xla/pjrt/pjrt_c_api_client.h | 2 -- third_party/xla/xla/pjrt/pjrt_client.h | 7 +------ .../xla/xla/pjrt/pjrt_stream_executor_client.cc | 5 ----- .../xla/xla/pjrt/pjrt_stream_executor_client.h | 5 ++--- third_party/xla/xla/pjrt/tf_pjrt_client.h | 4 ---- third_party/xla/xla/python/BUILD | 1 + third_party/xla/xla/python/dlpack.cc | 10 +++++++--- third_party/xla/xla/python/pjrt_ifrt/pjrt_client.cc | 3 ++- 19 files changed, 41 insertions(+), 50 deletions(-) diff --git a/tensorflow/compiler/jit/BUILD 
b/tensorflow/compiler/jit/BUILD index c11e0300e27b8e..a3934ae99155ca 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -595,6 +595,7 @@ cc_library( "@local_xla//xla:status_macros", "@local_xla//xla/client:local_client", "@local_xla//xla/pjrt:pjrt_client", + "@local_xla//xla/pjrt:pjrt_common", "@local_xla//xla/pjrt:pjrt_future", "@local_xla//xla/pjrt:pjrt_stream_executor_client", "@local_xla//xla/pjrt:tracked_device_buffer", @@ -633,6 +634,7 @@ tf_cc_test( "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:statusor", "@local_xla//xla/pjrt:pjrt_client", + "@local_xla//xla/pjrt:pjrt_common", "@local_xla//xla/pjrt:tfrt_cpu_pjrt_client", "@local_xla//xla/tests:literal_test_util", "@local_xla//xla/tsl/framework:device_id_utils", @@ -1680,6 +1682,7 @@ cc_library( "//tensorflow/core/tfrt/common:create_pjrt_client_util", "@com_google_absl//absl/status", "@local_xla//xla/pjrt:pjrt_client", + "@local_xla//xla/pjrt:pjrt_common", "@local_xla//xla/tsl/c:tsl_status_internal", "@local_xla//xla/tsl/framework:device_id_utils", ], diff --git a/tensorflow/compiler/jit/pjrt_device_context.cc b/tensorflow/compiler/jit/pjrt_device_context.cc index 0a716b0dfb773a..51b6e5770ce592 100644 --- a/tensorflow/compiler/jit/pjrt_device_context.cc +++ b/tensorflow/compiler/jit/pjrt_device_context.cc @@ -25,6 +25,7 @@ limitations under the License. 
#include "tensorflow/compiler/jit/pjrt_tensor_buffer_util.h" #include "tensorflow/compiler/tf2xla/literal_util.h" #include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/pjrt_common.h" #include "xla/tsl/c/tsl_status_internal.h" #include "xla/tsl/framework/device_id_utils.h" #include "tensorflow/core/common_runtime/dma_helper.h" @@ -58,7 +59,8 @@ absl::StatusOr> HostTensorToPjRtBuffer( tsl::GetDeviceIdFromDeviceParsedName(device->parsed_name(), DeviceType(device->device_type()))); TF_ASSIGN_OR_RETURN(xla::PjRtDevice * pjrt_device, - pjrt_client->LookupAddressableDevice(pjrt_device_id)); + pjrt_client->LookupAddressableDevice( + xla::PjRtLocalDeviceId(pjrt_device_id))); auto first_try_buffer = pjrt_client->BufferFromHostBuffer( cpu_tensor->data(), shape.element_type(), shape.dimensions(), /*byte_strides=*/std::nullopt, @@ -265,7 +267,9 @@ void PjRtDeviceToDeviceCopy(DeviceContext* send_dev_context, DeviceType(dst->device_type())) .value(); xla::PjRtDevice* pjrt_dst_device = - (*pjrt_dst_client)->LookupAddressableDevice(pjrt_dst_device_id).value(); + (*pjrt_dst_client) + ->LookupAddressableDevice(xla::PjRtLocalDeviceId(pjrt_dst_device_id)) + .value(); absl::StatusOr> buffer_or = src_device_buffer->CopyToDevice(pjrt_dst_device); diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc index 3920b6f05d1342..d146c8878450ee 100644 --- a/tensorflow/compiler/jit/xla_launch_util.cc +++ b/tensorflow/compiler/jit/xla_launch_util.cc @@ -36,6 +36,7 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "xla/client/local_client.h" #include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/pjrt_common.h" #include "xla/pjrt/pjrt_future.h" #include "xla/pjrt/pjrt_stream_executor_client.h" #include "xla/pjrt/tracked_device_buffer.h" @@ -861,7 +862,8 @@ Status RunPjRtExecutable( tsl::GetDeviceIdFromDeviceParsedName( ctx->device()->parsed_name(), device_type)); TF_ASSIGN_OR_RETURN(xla::PjRtDevice * device, - pjrt_client->LookupAddressableDevice(pjrt_device_id)); + pjrt_client->LookupAddressableDevice( + xla::PjRtLocalDeviceId(pjrt_device_id))); gpu::GpuServingDeviceSelectorResource* device_selector_resource = nullptr; if (device_type == DEVICE_GPU && gpu::kUseGpuServingDeviceSelector) { diff --git a/tensorflow/compiler/jit/xla_launch_util_test.cc b/tensorflow/compiler/jit/xla_launch_util_test.cc index ebc1f9fb7809aa..d19e4fc2548bb1 100644 --- a/tensorflow/compiler/jit/xla_launch_util_test.cc +++ b/tensorflow/compiler/jit/xla_launch_util_test.cc @@ -29,6 +29,7 @@ limitations under the License. 
#include "tensorflow/compiler/jit/variable_info_util.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/pjrt_common.h" #include "xla/pjrt/tfrt_cpu_pjrt_client.h" #include "xla/tests/literal_test_util.h" #include "xla/tsl/framework/allocator.h" @@ -191,8 +192,9 @@ class PjRtExecutionUtilTest : public OpsTestBase { const std::vector& variables, const XlaCompiler::CompilationResult* result, xla::PjRtLoadedExecutable* executable) { - TF_ASSIGN_OR_RETURN(auto pjrt_device, pjrt_client_->LookupAddressableDevice( - device_->parsed_name().id)); + TF_ASSIGN_OR_RETURN(auto pjrt_device, + pjrt_client_->LookupAddressableDevice( + xla::PjRtLocalDeviceId(device_->parsed_name().id))); std::vector executable_args; executable_args.reserve(result->input_mapping.size()); @@ -675,9 +677,9 @@ TEST_F(PjRtExecutionUtilTest, RunPjRtExecutableWithoutCtx) { TF_ASSERT_OK_AND_ASSIGN(const int pjrt_device_id, tsl::GetDeviceIdFromDeviceParsedName( context_->device()->parsed_name(), device_type)); - TF_ASSERT_OK_AND_ASSIGN( - xla::PjRtDevice * pjrt_device, - pjrt_client_->LookupAddressableDevice(pjrt_device_id)); + TF_ASSERT_OK_AND_ASSIGN(xla::PjRtDevice * pjrt_device, + pjrt_client_->LookupAddressableDevice( + xla::PjRtLocalDeviceId(pjrt_device_id))); absl::flat_hash_map variable_snapshots; for (int i = 0; i < variables.size(); i++) { diff --git a/tensorflow/core/tfrt/gpu/kernel/BUILD b/tensorflow/core/tfrt/gpu/kernel/BUILD index 1d778df8861db5..bd4f86131e3117 100644 --- a/tensorflow/core/tfrt/gpu/kernel/BUILD +++ b/tensorflow/core/tfrt/gpu/kernel/BUILD @@ -67,6 +67,7 @@ cc_library( "@local_tsl//tsl/platform:protobuf", "@local_tsl//tsl/platform:statusor", "@local_xla//xla/pjrt:pjrt_client", + "@local_xla//xla/pjrt:pjrt_common", "@local_xla//xla/tsl/framework:device_id", "@local_xla//xla/tsl/framework:serving_device_selector", "@tf_runtime//:hostcontext", diff --git a/tensorflow/core/tfrt/gpu/kernel/gpu_runner.cc 
b/tensorflow/core/tfrt/gpu/kernel/gpu_runner.cc index c2a59ad64dc40e..3143b8bd7821ae 100644 --- a/tensorflow/core/tfrt/gpu/kernel/gpu_runner.cc +++ b/tensorflow/core/tfrt/gpu/kernel/gpu_runner.cc @@ -36,6 +36,7 @@ limitations under the License. #include "tensorflow/compiler/jit/xla_platform_info.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/pjrt_common.h" #include "xla/tsl/framework/device_id.h" #include "xla/tsl/framework/device_id_manager.h" #include "xla/tsl/framework/serving_device_selector.h" @@ -397,8 +398,9 @@ GpuRunner::Run(const GpuRunInputs& run_inputs) { "Execution with collectives is not supported yet."); } - TF_ASSIGN_OR_RETURN(xla::PjRtDevice * pjrt_device, - pjrt_client->LookupAddressableDevice(device_idx)); + TF_ASSIGN_OR_RETURN( + xla::PjRtDevice * pjrt_device, + pjrt_client->LookupAddressableDevice(xla::PjRtLocalDeviceId(device_idx))); TF_ASSIGN_OR_RETURN( std::vector> executable_outputs, RunPjRtExecutable(/*num_missing_prefix_ctx_inputs=*/0, inputs, diff --git a/third_party/xla/xla/pjrt/BUILD b/third_party/xla/xla/pjrt/BUILD index 545bec940f4f95..3ec761e5502451 100644 --- a/third_party/xla/xla/pjrt/BUILD +++ b/third_party/xla/xla/pjrt/BUILD @@ -337,7 +337,7 @@ xla_cc_test( cc_library( name = "pjrt_common", hdrs = ["pjrt_common.h"], - visibility = [":friends"], + visibility = internal_visibility([":friends"]), deps = [ "@local_tsl//tsl/lib/gtl:int_type", ], diff --git a/third_party/xla/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc b/third_party/xla/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc index a8b8070f22dc22..8b4b3f99901ef3 100644 --- a/third_party/xla/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc +++ b/third_party/xla/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc @@ -453,9 +453,9 @@ PJRT_Error* PJRT_Client_LookupAddressableDevice( PJRT_RETURN_IF_ERROR(ActualStructSizeIsGreaterOrEqual( "PJRT_Client_LookupAddressableDevice_Args", PJRT_Client_LookupAddressableDevice_Args_STRUCT_SIZE, args->struct_size)); - 
PJRT_ASSIGN_OR_RETURN( - xla::PjRtDevice * addressable_device, - args->client->client->LookupAddressableDevice(args->local_hardware_id)); + PJRT_ASSIGN_OR_RETURN(xla::PjRtDevice * addressable_device, + args->client->client->LookupAddressableDevice( + xla::PjRtLocalDeviceId(args->local_hardware_id))); args->addressable_device = GetCDevice(args->client, addressable_device); return nullptr; } diff --git a/third_party/xla/xla/pjrt/cpu/cpu_client.cc b/third_party/xla/xla/pjrt/cpu/cpu_client.cc index e19dcdd69641c4..3dc7f4208859b4 100644 --- a/third_party/xla/xla/pjrt/cpu/cpu_client.cc +++ b/third_party/xla/xla/pjrt/cpu/cpu_client.cc @@ -479,11 +479,6 @@ absl::StatusOr TfrtCpuClient::LookupDevice( global_device_id.value()); } -absl::StatusOr TfrtCpuClient::LookupAddressableDevice( - int local_hardware_id) const { - return LookupAddressableDevice(PjRtLocalDeviceId(local_hardware_id)); -} - absl::StatusOr TfrtCpuClient::LookupAddressableDevice( PjRtLocalDeviceId local_device_id) const { for (auto* device : addressable_devices_) { diff --git a/third_party/xla/xla/pjrt/cpu/cpu_client.h b/third_party/xla/xla/pjrt/cpu/cpu_client.h index bb4559aa685d20..41bda465a394cf 100644 --- a/third_party/xla/xla/pjrt/cpu/cpu_client.h +++ b/third_party/xla/xla/pjrt/cpu/cpu_client.h @@ -282,8 +282,6 @@ class TfrtCpuClient final : public PjRtClient { absl::StatusOr LookupDevice( PjRtGlobalDeviceId global_device_id) const override; - absl::StatusOr LookupAddressableDevice( - int local_hardware_id) const override; absl::StatusOr LookupAddressableDevice( PjRtLocalDeviceId local_device_id) const override; diff --git a/third_party/xla/xla/pjrt/pjrt_c_api_client.cc b/third_party/xla/xla/pjrt/pjrt_c_api_client.cc index 5bee5fb9ce7915..3ff6a75b5955f1 100644 --- a/third_party/xla/xla/pjrt/pjrt_c_api_client.cc +++ b/third_party/xla/xla/pjrt/pjrt_c_api_client.cc @@ -332,11 +332,6 @@ absl::StatusOr PjRtCApiClient::LookupDevice( return GetCppDevice(args.device); } -absl::StatusOr 
PjRtCApiClient::LookupAddressableDevice( - int local_hardware_id) const { - return LookupAddressableDevice(PjRtLocalDeviceId(local_hardware_id)); -} - absl::StatusOr PjRtCApiClient::LookupAddressableDevice( PjRtLocalDeviceId local_device_id) const { PJRT_Client_LookupAddressableDevice_Args args; diff --git a/third_party/xla/xla/pjrt/pjrt_c_api_client.h b/third_party/xla/xla/pjrt/pjrt_c_api_client.h index 83d86dd45d131c..a8d2236930843a 100644 --- a/third_party/xla/xla/pjrt/pjrt_c_api_client.h +++ b/third_party/xla/xla/pjrt/pjrt_c_api_client.h @@ -270,8 +270,6 @@ class PjRtCApiClient : public PjRtClient { absl::StatusOr LookupDevice( PjRtGlobalDeviceId global_device_id) const override; - absl::StatusOr LookupAddressableDevice( - int local_hardware_id) const override; absl::StatusOr LookupAddressableDevice( PjRtLocalDeviceId local_device_id) const override; diff --git a/third_party/xla/xla/pjrt/pjrt_client.h b/third_party/xla/xla/pjrt/pjrt_client.h index 0417b2be648975..d9eea97b71cbdb 100644 --- a/third_party/xla/xla/pjrt/pjrt_client.h +++ b/third_party/xla/xla/pjrt/pjrt_client.h @@ -531,12 +531,7 @@ class PjRtClient { PjRtGlobalDeviceId global_device_id) const = 0; // Return an addressable PjRtDevice for a given - // PjRtDevice::local_hardware_id(). - ABSL_DEPRECATED("Use LookupAddressableDevice(PjRtLocalDeviceId) instead") - virtual absl::StatusOr LookupAddressableDevice( - int local_hardware_id) const { - return LookupAddressableDevice(PjRtLocalDeviceId(local_hardware_id)); - } + // PjRtDevice::local_device_id(). 
virtual absl::StatusOr LookupAddressableDevice( PjRtLocalDeviceId local_device_id) const = 0; diff --git a/third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc b/third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc index 2a16fabcc3f367..4c036f2ec74013 100644 --- a/third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc +++ b/third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc @@ -1357,11 +1357,6 @@ PjRtStreamExecutorDevice::GetStreamForExternalReadyEvents() const { return absl::bit_cast(raw_stream); } -absl::StatusOr PjRtStreamExecutorClient::LookupAddressableDevice( - int local_hardware_id) const { - return LookupAddressableDevice(PjRtLocalDeviceId(local_hardware_id)); -} - absl::StatusOr PjRtStreamExecutorClient::LookupAddressableDevice( xla::PjRtLocalDeviceId local_device_id) const { for (auto* device : addressable_devices_) { diff --git a/third_party/xla/xla/pjrt/pjrt_stream_executor_client.h b/third_party/xla/xla/pjrt/pjrt_stream_executor_client.h index 015266588a48b6..7aac763b2d8c1a 100644 --- a/third_party/xla/xla/pjrt/pjrt_stream_executor_client.h +++ b/third_party/xla/xla/pjrt/pjrt_stream_executor_client.h @@ -286,8 +286,6 @@ class PjRtStreamExecutorClient : public PjRtClient { global_device_id.value()); } - absl::StatusOr LookupAddressableDevice( - int local_hardware_id) const override; absl::StatusOr LookupAddressableDevice( PjRtLocalDeviceId local_device_id) const override; @@ -416,7 +414,8 @@ class PjRtStreamExecutorClient : public PjRtClient { LocalDeviceState& device_state(int device_ordinal) const { return *tensorflow::down_cast( - LookupAddressableDevice(device_ordinal).value()) + LookupAddressableDevice(xla::PjRtLocalDeviceId(device_ordinal)) + .value()) ->local_device_state(); } LocalClient* client() const { return client_; } diff --git a/third_party/xla/xla/pjrt/tf_pjrt_client.h b/third_party/xla/xla/pjrt/tf_pjrt_client.h index d1dfd3eace20e7..363c0526f0ba56 100644 --- a/third_party/xla/xla/pjrt/tf_pjrt_client.h +++ 
b/third_party/xla/xla/pjrt/tf_pjrt_client.h @@ -208,10 +208,6 @@ class TfPjRtClient : public PjRtClient { PjRtGlobalDeviceId global_device_id) const override { return wrapped_->LookupDevice(global_device_id); } - absl::StatusOr LookupAddressableDevice( - int local_hardware_id) const override { - return LookupAddressableDevice(PjRtLocalDeviceId(local_hardware_id)); - } absl::StatusOr LookupAddressableDevice( PjRtLocalDeviceId local_device_id) const override { if (wrapped_ == nullptr) { diff --git a/third_party/xla/xla/python/BUILD b/third_party/xla/xla/python/BUILD index 772cb444b32bc8..33597dd696044e 100644 --- a/third_party/xla/xla/python/BUILD +++ b/third_party/xla/xla/python/BUILD @@ -517,6 +517,7 @@ cc_library( "//xla:util", "//xla/pjrt:exceptions", "//xla/pjrt:pjrt_client", + "//xla/pjrt:pjrt_common", "//xla/pjrt:pjrt_compiler", "//xla/pjrt:pjrt_layout", "//xla/python/ifrt", diff --git a/third_party/xla/xla/python/dlpack.cc b/third_party/xla/xla/python/dlpack.cc index 1714eead8aac77..5a22723fedad1d 100644 --- a/third_party/xla/xla/python/dlpack.cc +++ b/third_party/xla/xla/python/dlpack.cc @@ -36,6 +36,7 @@ limitations under the License. 
#include "xla/layout.h" #include "xla/pjrt/exceptions.h" #include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/pjrt_common.h" #include "xla/pjrt/pjrt_compiler.h" #include "xla/pjrt/pjrt_layout.h" #include "xla/python/ifrt/array.h" @@ -270,21 +271,24 @@ absl::StatusOr DeviceForDLDevice(const PjRtClient* cpu_client, "DLPack tensor is on CPU, but no CPU backend was provided."); } TF_RET_CHECK(cpu_client->platform_id() == CpuId()); - return cpu_client->LookupAddressableDevice(context.device_id); + return cpu_client->LookupAddressableDevice( + xla::PjRtLocalDeviceId(context.device_id)); case kDLCUDA: if (gpu_client == nullptr) { return InvalidArgument( "DLPack tensor is on GPU, but no GPU backend was provided."); } TF_RET_CHECK(gpu_client->platform_id() == CudaId()); - return gpu_client->LookupAddressableDevice(context.device_id); + return gpu_client->LookupAddressableDevice( + xla::PjRtLocalDeviceId(context.device_id)); case kDLROCM: if (gpu_client == nullptr) { return InvalidArgument( "DLPack tensor is on GPU, but no GPU backend was provided."); } TF_RET_CHECK(gpu_client->platform_id() == RocmId()); - return gpu_client->LookupAddressableDevice(context.device_id); + return gpu_client->LookupAddressableDevice( + xla::PjRtLocalDeviceId(context.device_id)); default: return InvalidArgument("Unknown/unsupported DLPack device type %d", context.device_type); diff --git a/third_party/xla/xla/python/pjrt_ifrt/pjrt_client.cc b/third_party/xla/xla/python/pjrt_ifrt/pjrt_client.cc index 3b8c8b537c8a5c..e412d9cb87450b 100644 --- a/third_party/xla/xla/python/pjrt_ifrt/pjrt_client.cc +++ b/third_party/xla/xla/python/pjrt_ifrt/pjrt_client.cc @@ -493,7 +493,8 @@ absl::StatusOr PjRtClient::LookupAddressableDevice( int local_hardware_id) const { DCHECK(this); TF_ASSIGN_OR_RETURN(xla::PjRtDevice * pjrt_device, - pjrt_client_->LookupAddressableDevice(local_hardware_id)); + pjrt_client_->LookupAddressableDevice( + xla::PjRtLocalDeviceId(local_hardware_id))); return 
LookupPjRtDevice(pjrt_device); } From 849108484779568084a54abdcce3932ff0ede115 Mon Sep 17 00:00:00 2001 From: Kyle Lucke Date: Tue, 18 Jun 2024 11:45:28 -0700 Subject: [PATCH 30/59] Move StreamExecutor::MemZero processing completely into Stream and its derived classes. PiperOrigin-RevId: 644462658 --- .../stream_executor/stream_executor.cc | 20 ---------------- .../stream_executor_internal.h | 24 +++++++++++++++---- .../xla/xla/backends/interpreter/executor.h | 4 ---- .../xla/stream_executor/cuda/cuda_executor.cc | 10 -------- third_party/xla/xla/stream_executor/gpu/BUILD | 2 ++ .../xla/stream_executor/gpu/gpu_executor.h | 2 -- .../xla/xla/stream_executor/gpu/gpu_stream.cc | 11 +++++++++ .../xla/xla/stream_executor/gpu/gpu_stream.h | 2 ++ .../xla/xla/stream_executor/host/BUILD | 1 + .../xla/stream_executor/host/host_executor.cc | 10 -------- .../xla/stream_executor/host/host_executor.h | 2 -- .../xla/stream_executor/host/host_stream.cc | 8 +++++++ .../xla/stream_executor/host/host_stream.h | 2 ++ .../stream_executor/mock_stream_executor.h | 3 --- .../xla/stream_executor/rocm/rocm_executor.cc | 10 -------- third_party/xla/xla/stream_executor/stream.h | 10 ++++---- .../xla/xla/stream_executor/stream_common.cc | 4 ---- .../xla/xla/stream_executor/stream_common.h | 1 - .../xla/xla/stream_executor/stream_executor.h | 6 ----- .../xla/stream_executor/tpu/tpu_executor.h | 4 ---- 20 files changed, 52 insertions(+), 84 deletions(-) diff --git a/tensorflow/c/experimental/stream_executor/stream_executor.cc b/tensorflow/c/experimental/stream_executor/stream_executor.cc index da7b7105a427d9..ef54b703a68db6 100644 --- a/tensorflow/c/experimental/stream_executor/stream_executor.cc +++ b/tensorflow/c/experimental/stream_executor/stream_executor.cc @@ -155,17 +155,6 @@ absl::Status ValidateSEPlatformRegistrationParams( } #undef TF_VALIDATE_NOT_NULL -// Converts DeviceMemoryBase to a C struct. 
-SP_DeviceMemoryBase DeviceMemoryBaseToC(const DeviceMemoryBase* mem) { - SP_DeviceMemoryBase device_memory_base{SP_DEVICE_MEMORY_BASE_STRUCT_SIZE}; - // `opaque` field inside SP_DeviceMemoryBase is not const. - // Therefore, we need to cast away the constness before setting it. - device_memory_base.opaque = const_cast(mem->opaque()); - device_memory_base.size = mem->size(); - device_memory_base.payload = mem->payload(); - return device_memory_base; -} - DeviceMemoryBase DeviceMemoryBaseFromC(const SP_DeviceMemoryBase& mem) { DeviceMemoryBase base(mem.opaque, mem.size); base.SetPayload(mem.payload); @@ -314,15 +303,6 @@ class CStreamExecutor : public StreamExecutorCommon { size, c_status.get()); return StatusFromTF_Status(c_status.get()); } - absl::Status MemZero(Stream* stream, DeviceMemoryBase* location, - uint64 size) override { - OwnedTFStatus c_status(TF_NewStatus()); - SP_Stream stream_handle = static_cast(stream)->Handle(); - SP_DeviceMemoryBase device_mem = DeviceMemoryBaseToC(location); - stream_executor_->mem_zero(&device_, stream_handle, &device_mem, size, - c_status.get()); - return StatusFromTF_Status(c_status.get()); - } absl::Status Memset(Stream* stream, DeviceMemoryBase* location, uint8 pattern, uint64 size) override { OwnedTFStatus c_status(TF_NewStatus()); diff --git a/tensorflow/c/experimental/stream_executor/stream_executor_internal.h b/tensorflow/c/experimental/stream_executor/stream_executor_internal.h index 8f2b5970f8ab04..21447ba6aa16f5 100644 --- a/tensorflow/c/experimental/stream_executor/stream_executor_internal.h +++ b/tensorflow/c/experimental/stream_executor/stream_executor_internal.h @@ -46,6 +46,17 @@ absl::Status InitStreamExecutorPlugin(SEInitPluginFn init_fn, std::string* device_type, std::string* platform_name); +// Converts DeviceMemoryBase to a C struct. 
+inline SP_DeviceMemoryBase DeviceMemoryBaseToC(const DeviceMemoryBase* mem) { + SP_DeviceMemoryBase device_memory_base{SP_DEVICE_MEMORY_BASE_STRUCT_SIZE}; + // `opaque` field inside SP_DeviceMemoryBase is not const. + // Therefore, we need to cast away the constness before setting it. + device_memory_base.opaque = const_cast(mem->opaque()); + device_memory_base.size = mem->size(); + device_memory_base.payload = mem->payload(); + return device_memory_base; +} + // This file implements core stream executor base classes in terms of // the C API defined in stream_executor.h. A class "CSomething" represents a // "Something" that can be manipulated via calls in the C interface. @@ -166,8 +177,7 @@ class CStream : public StreamCommon { absl::Status Create() { tensorflow::TF_StatusPtr c_status(TF_NewStatus()); stream_executor_->create_stream(device_, &stream_handle_, c_status.get()); - absl::Status s = tensorflow::StatusFromTF_Status(c_status.get()); - return s; + return tensorflow::StatusFromTF_Status(c_status.get()); } void Destroy() { @@ -201,8 +211,14 @@ class CStream : public StreamCommon { tensorflow::TF_StatusPtr c_status(TF_NewStatus()); stream_executor_->wait_for_event(device_, stream_handle_, event_handle, c_status.get()); - absl::Status s = tensorflow::StatusFromTF_Status(c_status.get()); - return s; + return tensorflow::StatusFromTF_Status(c_status.get()); + } + absl::Status MemZero(DeviceMemoryBase* location, uint64_t size) override { + tensorflow::TF_StatusPtr c_status(TF_NewStatus()); + SP_DeviceMemoryBase device_mem = DeviceMemoryBaseToC(location); + stream_executor_->mem_zero(device_, stream_handle_, &device_mem, size, + c_status.get()); + return tensorflow::StatusFromTF_Status(c_status.get()); } SP_Stream Handle() { return stream_handle_; } diff --git a/third_party/xla/xla/backends/interpreter/executor.h b/third_party/xla/xla/backends/interpreter/executor.h index 1ba7dabbed9993..d103d94b99ac47 100644 --- 
a/third_party/xla/xla/backends/interpreter/executor.h +++ b/third_party/xla/xla/backends/interpreter/executor.h @@ -100,10 +100,6 @@ class XlaInterpreterExecutor : public StreamExecutorCommon { return false; } - absl::Status MemZero(Stream *stream, DeviceMemoryBase *location, - uint64_t size) override { - return absl::InternalError("Interpreter can not memzero"); - } absl::Status Memset(Stream *stream, DeviceMemoryBase *location, uint8_t pattern, uint64_t size) override { return absl::InternalError("Interpreter can not memset"); diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc b/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc index b78d54c116451c..68f79ea1f2be9c 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc @@ -671,16 +671,6 @@ absl::Status GpuExecutor::SynchronousMemcpy(void* host_dst, AsCudaDevicePtr(gpu_src), size); } -absl::Status GpuExecutor::MemZero(Stream* stream, DeviceMemoryBase* location, - uint64_t size) { - if (reinterpret_cast(location->opaque()) % 4 == 0 && - size % 4 == 0) { - return Memset32(stream, location, 0x0, size); - } else { - return Memset(stream, location, 0x0, size); - } -} - absl::Status GpuExecutor::Memset(Stream* stream, DeviceMemoryBase* location, uint8_t pattern, uint64_t size) { VLOG(2) << "enqueueing memset8 operation onto stream " << stream diff --git a/third_party/xla/xla/stream_executor/gpu/BUILD b/third_party/xla/xla/stream_executor/gpu/BUILD index 18bab237f87dcf..f7ee2747381c59 100644 --- a/third_party/xla/xla/stream_executor/gpu/BUILD +++ b/third_party/xla/xla/stream_executor/gpu/BUILD @@ -305,6 +305,7 @@ gpu_only_cc_library( deps = [ ":gpu_executor_header", ":gpu_types_header", + "//xla/stream_executor:device_memory", "//xla/stream_executor:platform", "//xla/stream_executor:stream", "//xla/stream_executor:stream_common", @@ -321,6 +322,7 @@ gpu_only_cc_library( ":gpu_event_header", ":gpu_executor_header", 
":gpu_types_header", + "//xla/stream_executor:device_memory", "//xla/stream_executor:event", "//xla/stream_executor:platform", "//xla/stream_executor:stream", diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_executor.h b/third_party/xla/xla/stream_executor/gpu/gpu_executor.h index 31398f60cef67c..775ef19365a99f 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_executor.h +++ b/third_party/xla/xla/stream_executor/gpu/gpu_executor.h @@ -208,8 +208,6 @@ class GpuExecutor : public StreamExecutorCommon { const DeviceMemoryBase& gpu_src, uint64_t size) override; - absl::Status MemZero(Stream* stream, DeviceMemoryBase* location, - uint64_t size) override; absl::Status Memset(Stream* stream, DeviceMemoryBase* location, uint8_t pattern, uint64_t size) override; absl::Status Memset32(Stream* stream, DeviceMemoryBase* location, diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_stream.cc b/third_party/xla/xla/stream_executor/gpu/gpu_stream.cc index ab5531546b4839..536bb1adefc9c4 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_stream.cc +++ b/third_party/xla/xla/stream_executor/gpu/gpu_stream.cc @@ -15,12 +15,14 @@ limitations under the License. 
#include "xla/stream_executor/gpu/gpu_stream.h" +#include #include #include "absl/log/check.h" #include "absl/log/log.h" #include "absl/status/status.h" #include "absl/strings/str_format.h" +#include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/event.h" #include "xla/stream_executor/gpu/gpu_driver.h" #include "xla/stream_executor/gpu/gpu_event.h" @@ -56,6 +58,15 @@ Stream::PlatformSpecificHandle GpuStream::platform_specific_handle() const { return handle; } +absl::Status GpuStream::MemZero(DeviceMemoryBase* location, uint64_t size) { + if (reinterpret_cast(location->opaque()) % 4 == 0 && + size % 4 == 0) { + return Memset32(location, 0x0, size); + } else { + return parent_->Memset(this, location, 0x0, size); + } +} + absl::Status GpuStream::WaitFor(Stream* other) { GpuStream* other_gpu = AsGpuStream(other); GpuEventHandle other_completed_event = *(other_gpu->completed_event()); diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_stream.h b/third_party/xla/xla/stream_executor/gpu/gpu_stream.h index c001bec48ce72a..d3ac630f214ce1 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_stream.h +++ b/third_party/xla/xla/stream_executor/gpu/gpu_stream.h @@ -22,6 +22,7 @@ limitations under the License. #include #include "absl/log/check.h" +#include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/gpu/gpu_executor.h" #include "xla/stream_executor/gpu/gpu_types.h" #include "xla/stream_executor/platform.h" @@ -97,6 +98,7 @@ class GpuStream : public StreamCommon { absl::Status WaitFor(Stream* other) override; absl::Status WaitFor(Event* event) override; absl::Status RecordEvent(Event* event) override; + absl::Status MemZero(DeviceMemoryBase* location, uint64_t size) override; private: GpuExecutor* parent_; // Executor that spawned this stream. 
diff --git a/third_party/xla/xla/stream_executor/host/BUILD b/third_party/xla/xla/stream_executor/host/BUILD index 169d3cbf5152b0..b1b2f89f71072a 100644 --- a/third_party/xla/xla/stream_executor/host/BUILD +++ b/third_party/xla/xla/stream_executor/host/BUILD @@ -81,6 +81,7 @@ cc_library( ], deps = [ ":host_event", + "//xla/stream_executor:device_memory", "//xla/stream_executor:event", "//xla/stream_executor:stream", "//xla/stream_executor:stream_common", diff --git a/third_party/xla/xla/stream_executor/host/host_executor.cc b/third_party/xla/xla/stream_executor/host/host_executor.cc index defa623447f029..0d75e7e19962d6 100644 --- a/third_party/xla/xla/stream_executor/host/host_executor.cc +++ b/third_party/xla/xla/stream_executor/host/host_executor.cc @@ -178,16 +178,6 @@ bool HostExecutor::MemcpyDeviceToDevice(Stream* stream, return true; } -absl::Status HostExecutor::MemZero(Stream* stream, DeviceMemoryBase* location, - uint64_t size) { - void* gpu_mem = location->opaque(); - // Enqueue the [asynchronous] memzero on the stream (HostStream) associated - // with the HostExecutor. 
- AsHostStream(stream)->EnqueueTask( - [gpu_mem, size]() { memset(gpu_mem, 0, size); }); - return absl::OkStatus(); -} - absl::Status HostExecutor::Memset(Stream* stream, DeviceMemoryBase* location, uint8 pattern, uint64_t size) { void* gpu_mem = location->opaque(); diff --git a/third_party/xla/xla/stream_executor/host/host_executor.h b/third_party/xla/xla/stream_executor/host/host_executor.h index c1d08a923cef2e..5339315985d1a8 100644 --- a/third_party/xla/xla/stream_executor/host/host_executor.h +++ b/third_party/xla/xla/stream_executor/host/host_executor.h @@ -98,8 +98,6 @@ class HostExecutor : public StreamExecutorCommon { const DeviceMemoryBase& gpu_src, uint64_t size) override; - absl::Status MemZero(Stream* stream, DeviceMemoryBase* location, - uint64_t size) override; absl::Status Memset(Stream* stream, DeviceMemoryBase* location, uint8_t pattern, uint64_t size) override; absl::Status Memset32(Stream* stream, DeviceMemoryBase* location, diff --git a/third_party/xla/xla/stream_executor/host/host_stream.cc b/third_party/xla/xla/stream_executor/host/host_stream.cc index 38125bcd8980e9..3a08013ee1225b 100644 --- a/third_party/xla/xla/stream_executor/host/host_stream.cc +++ b/third_party/xla/xla/stream_executor/host/host_stream.cc @@ -53,6 +53,14 @@ HostStream::~HostStream() { parent()->DeallocateStream(this); } +absl::Status HostStream::MemZero(DeviceMemoryBase* location, uint64_t size) { + void* gpu_mem = location->opaque(); + // Enqueue the [asynchronous] memzero on the stream (HostStream) associated + // with the HostExecutor. 
+ EnqueueTask([gpu_mem, size]() { memset(gpu_mem, 0, size); }); + return absl::OkStatus(); +} + absl::Status HostStream::WaitFor(Stream* other) { auto event = std::make_shared(); static_cast(other)->EnqueueTask([event]() { event->Notify(); }); diff --git a/third_party/xla/xla/stream_executor/host/host_stream.h b/third_party/xla/xla/stream_executor/host/host_stream.h index f579df98136a6c..105b4704f8609b 100644 --- a/third_party/xla/xla/stream_executor/host/host_stream.h +++ b/third_party/xla/xla/stream_executor/host/host_stream.h @@ -26,6 +26,7 @@ limitations under the License. #include "absl/functional/any_invocable.h" #include "absl/status/status.h" #include "absl/synchronization/mutex.h" +#include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/stream_common.h" #include "tsl/platform/env.h" #include "tsl/platform/thread_annotations.h" @@ -53,6 +54,7 @@ class HostStream : public StreamCommon { absl::Status WaitFor(Stream* other) override; absl::Status WaitFor(Event* event) override; absl::Status RecordEvent(Event* event) override; + absl::Status MemZero(DeviceMemoryBase* location, uint64_t size) override; private: bool WorkAvailable() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); diff --git a/third_party/xla/xla/stream_executor/mock_stream_executor.h b/third_party/xla/xla/stream_executor/mock_stream_executor.h index e5494f9ea45712..f5b1dc722d98ae 100644 --- a/third_party/xla/xla/stream_executor/mock_stream_executor.h +++ b/third_party/xla/xla/stream_executor/mock_stream_executor.h @@ -105,9 +105,6 @@ class MockStreamExecutor : public StreamExecutor { (void* host_dst, const DeviceMemoryBase& device_src, uint64_t size), (override)); - MOCK_METHOD(absl::Status, MemZero, - (Stream * stream, DeviceMemoryBase* location, uint64_t size), - (override)); MOCK_METHOD(absl::Status, Memset, (Stream * stream, DeviceMemoryBase* location, uint8_t pattern, uint64_t size), diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_executor.cc 
b/third_party/xla/xla/stream_executor/rocm/rocm_executor.cc index dc9c30630d0415..32b9fc4267d694 100644 --- a/third_party/xla/xla/stream_executor/rocm/rocm_executor.cc +++ b/third_party/xla/xla/stream_executor/rocm/rocm_executor.cc @@ -572,16 +572,6 @@ absl::Status GpuExecutor::SynchronousMemcpy(void* host_dst, AsROCmDevicePtr(gpu_src), size); } -absl::Status GpuExecutor::MemZero(Stream* stream, DeviceMemoryBase* location, - uint64_t size) { - if (reinterpret_cast(location->opaque()) % 4 == 0 && - size % 4 == 0) { - return Memset32(stream, location, 0x0, size); - } else { - return Memset(stream, location, 0x0, size); - } -} - absl::Status GpuExecutor::Memset(Stream* stream, DeviceMemoryBase* location, uint8 pattern, uint64_t size) { VLOG(2) << "enqueueing memset8 operation onto stream " << stream diff --git a/third_party/xla/xla/stream_executor/stream.h b/third_party/xla/xla/stream_executor/stream.h index c508e8894ffee0..4fd904befad159 100644 --- a/third_party/xla/xla/stream_executor/stream.h +++ b/third_party/xla/xla/stream_executor/stream.h @@ -199,11 +199,13 @@ class Stream { return Memcpy(gpu_dst, gpu_src, size); } - // Entrain onto the stream: a memset of zero at a GPU location of size bytes. - // The location must not be null. - virtual absl::Status MemZero(DeviceMemoryBase *location, uint64_t size) = 0; + // Entrain onto the stream: a memset of zero at a device location of size + // bytes. The location must not be null. + virtual absl::Status MemZero(DeviceMemoryBase *location, uint64_t size) { + return absl::UnimplementedError("MemZero is not supported on this stream."); + } - // Entrain onto the stream: a memset of a 32-bit pattern at a GPU location of + // Entrain onto the stream: a memset of a 32-bit pattern at device location of // size bytes, where bytes must be evenly 32-bit sized (i.e. evenly divisible // by 4). The location must not be null. 
virtual absl::Status Memset32(DeviceMemoryBase *location, uint32_t pattern, diff --git a/third_party/xla/xla/stream_executor/stream_common.cc b/third_party/xla/xla/stream_executor/stream_common.cc index efadd05176fb71..cf4c3ee7efd75e 100644 --- a/third_party/xla/xla/stream_executor/stream_common.cc +++ b/third_party/xla/xla/stream_executor/stream_common.cc @@ -163,10 +163,6 @@ absl::Status StreamCommon::Memcpy(DeviceMemoryBase *gpu_dst, return absl::InternalError("failed to memcpy"); } -absl::Status StreamCommon::MemZero(DeviceMemoryBase *location, uint64_t size) { - return parent_->MemZero(this, location, size); -} - absl::Status StreamCommon::Memset32(DeviceMemoryBase *location, uint32_t pattern, uint64_t size) { return parent_->Memset32(this, location, pattern, size); diff --git a/third_party/xla/xla/stream_executor/stream_common.h b/third_party/xla/xla/stream_executor/stream_common.h index db1ef443ede355..39e3e025fb2245 100644 --- a/third_party/xla/xla/stream_executor/stream_common.h +++ b/third_party/xla/xla/stream_executor/stream_common.h @@ -77,7 +77,6 @@ class StreamCommon : public Stream { uint64_t size) override; absl::Status Memcpy(DeviceMemoryBase *gpu_dst, const DeviceMemoryBase &gpu_src, uint64_t size) override; - absl::Status MemZero(DeviceMemoryBase *location, uint64_t size) override; absl::Status Memset32(DeviceMemoryBase *location, uint32_t pattern, uint64_t size) override; absl::Status BlockHostUntilDone() override TF_LOCKS_EXCLUDED(mu_); diff --git a/third_party/xla/xla/stream_executor/stream_executor.h b/third_party/xla/xla/stream_executor/stream_executor.h index 4f14f718f43472..636dda8581203a 100644 --- a/third_party/xla/xla/stream_executor/stream_executor.h +++ b/third_party/xla/xla/stream_executor/stream_executor.h @@ -223,12 +223,6 @@ class StreamExecutor { return SynchronousMemcpy(host_dst, device_src, size); } - // Enqueues an operation onto stream to zero out size bytes at the given - // device memory location. 
Neither stream nor location may be null. Returns - // whether the operation was successfully enqueued onto the stream. - virtual absl::Status MemZero(Stream* stream, DeviceMemoryBase* location, - uint64_t size) = 0; - // Enqueues an operation onto stream to set 8-bit patterns starting at // location, for byte count given by size. Returns whether the operation was // successfully enqueued onto the stream. diff --git a/third_party/xla/xla/stream_executor/tpu/tpu_executor.h b/third_party/xla/xla/stream_executor/tpu/tpu_executor.h index 2f43c0ba4859af..73785f36cf1e37 100644 --- a/third_party/xla/xla/stream_executor/tpu/tpu_executor.h +++ b/third_party/xla/xla/stream_executor/tpu/tpu_executor.h @@ -140,10 +140,6 @@ class TpuExecutor : public tensorflow::tpu::TpuExecutorInterface { // -- Unimplemented (stubbed out) methods. - absl::Status MemZero(Stream* stream, DeviceMemoryBase* location, - uint64_t size) override { - LOG(FATAL) << "not yet implemented"; - } absl::Status Memset32(Stream* stream, DeviceMemoryBase* location, uint32_t pattern, uint64_t size) override { LOG(FATAL) << "not yet implemented"; From 4fe124f534024e6b7c8d107b21cf3a7c36ec8813 Mon Sep 17 00:00:00 2001 From: pizzud Date: Tue, 18 Jun 2024 11:51:21 -0700 Subject: [PATCH 31/59] [NFC]xla_compile: Report the Status in the CompilationResult even when OK. Doing so allows us to distinguish compile-succeeded-but-crashing-in-exit cases. 
PiperOrigin-RevId: 644464487 --- third_party/xla/xla/service/BUILD | 7 +++++++ third_party/xla/xla/service/dump.cc | 20 +++++++++++++++++++ .../xla/xla/service/xla_compile_result.proto | 2 +- third_party/xla/xla/tools/BUILD | 1 + third_party/xla/xla/tools/xla_compile_lib.cc | 2 +- .../xla/xla/tools/xla_compile_lib_test.cc | 19 ++++++++++++++---- 6 files changed, 45 insertions(+), 6 deletions(-) diff --git a/third_party/xla/xla/service/BUILD b/third_party/xla/xla/service/BUILD index 64a2394e165cf2..520fb05947ef27 100644 --- a/third_party/xla/xla/service/BUILD +++ b/third_party/xla/xla/service/BUILD @@ -612,10 +612,16 @@ cc_library( "//xla:util", "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/synchronization", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:Support", @@ -624,6 +630,7 @@ cc_library( "@local_tsl//tsl/lib/io:zlib_outputbuffer", "@local_tsl//tsl/lib/strings:proto_serialization", "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:path", "@local_tsl//tsl/platform:regexp", "@local_tsl//tsl/platform:status", diff --git a/third_party/xla/xla/service/dump.cc b/third_party/xla/xla/service/dump.cc index b23042a9ace747..bd3013677ffdc2 100644 --- a/third_party/xla/xla/service/dump.cc +++ b/third_party/xla/xla/service/dump.cc @@ -15,24 +15,39 @@ limitations under the License. 
#include "xla/service/dump.h" +#include #include +#include #include +#include #include +#include #include +#include +#include "absl/algorithm/container.h" +#include "absl/base/const_init.h" +#include "absl/container/flat_hash_set.h" #include "absl/functional/any_invocable.h" +#include "absl/log/log.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/ascii.h" +#include "absl/strings/match.h" #include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" #include "llvm/ADT/SmallString.h" #include "llvm/Support/ToolOutputFile.h" #include "llvm/Support/raw_ostream.h" #include "mlir/IR/OperationSupport.h" // from @llvm-project #include "mlir/Support/FileUtilities.h" // from @llvm-project #include "mlir/Transforms/LocationSnapshot.h" // from @llvm-project +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/hlo_graph_dumper.h" #include "xla/service/hlo_proto_util.h" #include "xla/util.h" @@ -40,6 +55,8 @@ limitations under the License. 
#include "tsl/lib/io/zlib_outputbuffer.h" #include "tsl/lib/strings/proto_serialization.h" #include "tsl/platform/env.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/file_system.h" #include "tsl/platform/path.h" #include "tsl/platform/regexp.h" #include "tsl/platform/status.h" @@ -654,6 +671,9 @@ void DumpProtobufToFile(const tsl::protobuf::Message& proto, CanonicalDebugOptions opts(debug_options); tsl::Env* env = tsl::Env::Default(); const std::string& dir = opts.dump_to; + if (dir.empty()) { + return; + } if (!env->IsDirectory(dir).ok()) { auto status = env->RecursivelyCreateDir(dir); if (!status.ok()) { diff --git a/third_party/xla/xla/service/xla_compile_result.proto b/third_party/xla/xla/service/xla_compile_result.proto index ed5982e270f9b2..5846b8b11bacc2 100644 --- a/third_party/xla/xla/service/xla_compile_result.proto +++ b/third_party/xla/xla/service/xla_compile_result.proto @@ -43,7 +43,7 @@ message CompilationResult { // Always set when compilation succeeds. May or may not be set when // compilation fails. optional CompilerPerfStats perf_stats = 2; - // Always set when compilation fails; never set when compilation succeeds. + // Always set even when compilation succeeds. optional tensorflow.StatusProto status = 3; // Collects counters collected during compilation. Not every producer may // include counter support at all or any particular counter. 
diff --git a/third_party/xla/xla/tools/BUILD b/third_party/xla/xla/tools/BUILD index cb563413b70c7f..d7109b0b69f5e2 100644 --- a/third_party/xla/xla/tools/BUILD +++ b/third_party/xla/xla/tools/BUILD @@ -839,6 +839,7 @@ xla_test( "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", "@local_tsl//tsl/protobuf:error_codes_proto_impl_cc", + "@local_tsl//tsl/protobuf:status_proto_cc", ], ) diff --git a/third_party/xla/xla/tools/xla_compile_lib.cc b/third_party/xla/xla/tools/xla_compile_lib.cc index 86db193d31ed1d..85ea437beaef25 100644 --- a/third_party/xla/xla/tools/xla_compile_lib.cc +++ b/third_party/xla/xla/tools/xla_compile_lib.cc @@ -352,8 +352,8 @@ absl::Status XlaCompileMain(const XlaCompileOptions& options) { } auto result = CompileExecutable(std::move(hlo_module), backend, std::move(cfg), compilation_result); + *compilation_result.mutable_status() = tsl::StatusToProto(result.status()); if (!result.ok()) { - *compilation_result.mutable_status() = tsl::StatusToProto(result.status()); return result.status(); } diff --git a/third_party/xla/xla/tools/xla_compile_lib_test.cc b/third_party/xla/xla/tools/xla_compile_lib_test.cc index 18ccf33263cb12..6bf9051f221c83 100644 --- a/third_party/xla/xla/tools/xla_compile_lib_test.cc +++ b/third_party/xla/xla/tools/xla_compile_lib_test.cc @@ -41,6 +41,7 @@ limitations under the License. 
#include "tsl/platform/statusor.h" #include "tsl/platform/test.h" #include "tsl/protobuf/error_codes.pb.h" +#include "tsl/protobuf/status.pb.h" namespace xla { namespace { @@ -173,9 +174,9 @@ TEST_F(XlaCompileLibTest, DISABLED_ON_GPU(MainForCpu)) { module_->ToString())); const std::string output_path = - tsl::io::JoinPath(tsl::testing::TmpDir(), "output"); + tsl::io::JoinPath(tsl::testing::TmpDir(), "cpu_output"); const std::string result_file = - tsl::io::JoinPath(tsl::testing::TmpDir(), "result.pb"); + tsl::io::JoinPath(tsl::testing::TmpDir(), "cpu_result.pb"); XlaCompileOptions options; options.module_path = module_file; @@ -183,6 +184,11 @@ TEST_F(XlaCompileLibTest, DISABLED_ON_GPU(MainForCpu)) { options.platform = "cpu"; options.result_output_file = result_file; TF_EXPECT_OK(XlaCompileMain(options)); + + CompilationResult result; + TF_ASSERT_OK(tsl::ReadBinaryProto(tsl::Env::Default(), result_file, &result)); + EXPECT_TRUE(result.has_status()); + EXPECT_EQ(result.status().code(), tensorflow::error::OK); } TEST_F(XlaCompileLibTest, DISABLED_ON_CPU(MainForGpu)) { @@ -192,9 +198,9 @@ TEST_F(XlaCompileLibTest, DISABLED_ON_CPU(MainForGpu)) { module_->ToString())); const std::string output_path = - tsl::io::JoinPath(tsl::testing::TmpDir(), "output"); + tsl::io::JoinPath(tsl::testing::TmpDir(), "gpu_output"); const std::string result_file = - tsl::io::JoinPath(tsl::testing::TmpDir(), "result.pb"); + tsl::io::JoinPath(tsl::testing::TmpDir(), "gpu_result.pb"); XlaCompileOptions options; options.module_path = module_file; @@ -203,6 +209,11 @@ TEST_F(XlaCompileLibTest, DISABLED_ON_CPU(MainForGpu)) { options.result_output_file = result_file; options.gpu_options.use_attached_device = true; TF_EXPECT_OK(XlaCompileMain(options)); + + CompilationResult result; + TF_ASSERT_OK(tsl::ReadBinaryProto(tsl::Env::Default(), result_file, &result)); + EXPECT_TRUE(result.has_status()); + EXPECT_EQ(result.status().code(), tensorflow::error::OK); } } // namespace From 
5fb87afc6e536b5a2527d0576544aaf4c09761cd Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Tue, 18 Jun 2024 12:22:41 -0700 Subject: [PATCH 32/59] [xla:cpu] Add support for single replica all-reduce + prepare for enabling "real" all-reduce tests PiperOrigin-RevId: 644474476 --- .../xla/xla/service/collective_ops_utils.h | 13 ++ third_party/xla/xla/service/cpu/BUILD | 1 + .../xla/xla/service/cpu/cpu_executable.cc | 6 +- third_party/xla/xla/service/cpu/runtime/BUILD | 28 ++++ .../service/cpu/runtime/all_reduce_thunk.cc | 120 ++++++++++++++---- .../service/cpu/runtime/all_reduce_thunk.h | 30 +++-- .../service/cpu/runtime/collective_thunk.cc | 82 ++++++++++++ .../service/cpu/runtime/collective_thunk.h | 61 +++++++++ .../xla/xla/service/cpu/runtime/thunk.cc | 30 ++++- .../xla/xla/service/cpu/runtime/thunk.h | 29 ++++- .../xla/xla/service/cpu/thunk_emitter.cc | 42 +++++- .../xla/xla/service/gpu/runtime/nccl_api.cc | 17 +-- third_party/xla/xla/tests/BUILD | 6 +- third_party/xla/xla/tests/all_reduce_test.cc | 2 +- 14 files changed, 405 insertions(+), 62 deletions(-) create mode 100644 third_party/xla/xla/service/cpu/runtime/collective_thunk.cc create mode 100644 third_party/xla/xla/service/cpu/runtime/collective_thunk.h diff --git a/third_party/xla/xla/service/collective_ops_utils.h b/third_party/xla/xla/service/collective_ops_utils.h index 702110e0f40f58..5956fdd755eb35 100644 --- a/third_party/xla/xla/service/collective_ops_utils.h +++ b/third_party/xla/xla/service/collective_ops_utils.h @@ -40,6 +40,19 @@ namespace xla { enum class ReductionKind { SUM, PRODUCT, MIN, MAX }; +constexpr std::string_view ReductionKindToString(ReductionKind reduction_kind) { + switch (reduction_kind) { + case ReductionKind::SUM: + return "sum"; + case ReductionKind::PRODUCT: + return "prod"; + case ReductionKind::MIN: + return "min"; + case ReductionKind::MAX: + return "max"; + } +} + // Attempts to match instruction to one of the possible cases for ReductionKind. 
std::optional MatchReductionInstruction( const HloInstruction* hlo); diff --git a/third_party/xla/xla/service/cpu/BUILD b/third_party/xla/xla/service/cpu/BUILD index ca57de781cfc3e..6753c4595dba50 100644 --- a/third_party/xla/xla/service/cpu/BUILD +++ b/third_party/xla/xla/service/cpu/BUILD @@ -809,6 +809,7 @@ cc_library( "//xla:util", "//xla/hlo/ir:hlo", "//xla/service:buffer_assignment", + "//xla/service:collective_ops_utils", "//xla/service:hlo_module_config", "//xla/service/cpu:dot_op_emitter", "//xla/service/cpu/runtime:all_reduce_thunk", diff --git a/third_party/xla/xla/service/cpu/cpu_executable.cc b/third_party/xla/xla/service/cpu/cpu_executable.cc index c83305da34d7fb..4c15f9fa17a869 100644 --- a/third_party/xla/xla/service/cpu/cpu_executable.cc +++ b/third_party/xla/xla/service/cpu/cpu_executable.cc @@ -354,10 +354,14 @@ absl::Status CpuExecutable::ExecuteThunks( profile_counters_size); VLOG(3) << absl::StrFormat(" Profile counters: %p", profile_counters); + // Prepare for executing XLA program collectively. 
+ TF_ASSIGN_OR_RETURN(Thunk::CollectiveExecuteParams collective_execute_params, + Thunk::CollectiveExecuteParams::Create(run_options)); + Thunk::ExecuteParams execute_params = { &*host_kernels_, &allocations, runtime::GetXfeedManager(run_options->device_ordinal()), - run_options->intra_op_thread_pool()}; + run_options->intra_op_thread_pool(), &collective_execute_params}; auto executed_event = thunks_->Execute(execute_params); tsl::BlockUntilReady(executed_event); diff --git a/third_party/xla/xla/service/cpu/runtime/BUILD b/third_party/xla/xla/service/cpu/runtime/BUILD index 72e3d496856abe..1e635a1fb75f4c 100644 --- a/third_party/xla/xla/service/cpu/runtime/BUILD +++ b/third_party/xla/xla/service/cpu/runtime/BUILD @@ -43,8 +43,12 @@ cc_library( hdrs = ["thunk.h"], deps = [ ":buffer_allocations", + "//xla:executable_run_options", + "//xla:util", "//xla/runtime:buffer_use", + "//xla/service:global_device_id", "//xla/service/cpu:cpu_runtime", + "//xla/stream_executor", "//xla/stream_executor/host:host_kernel_c_api", "//xla/tsl/concurrency:async_value", "@com_google_absl//absl/base:core_headers", @@ -168,11 +172,14 @@ cc_library( srcs = ["all_reduce_thunk.cc"], hdrs = ["all_reduce_thunk.h"], deps = [ + ":collective_thunk", ":thunk", "//xla:shape_util", "//xla:util", + "//xla:xla_data_proto_cc", "//xla/runtime:buffer_use", "//xla/service:buffer_assignment", + "//xla/service:collective_ops_utils", "//xla/stream_executor", "//xla/tsl/concurrency:async_value", "@com_google_absl//absl/algorithm:container", @@ -190,6 +197,27 @@ cc_library( ], ) +cc_library( + name = "collective_thunk", + srcs = ["collective_thunk.cc"], + hdrs = ["collective_thunk.h"], + deps = [ + ":thunk", + "//xla:status_macros", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/service:collective_ops_utils", + "//xla/service:computation_placer", + "//xla/service:global_device_id", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/status", + 
"@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:statusor", + ], +) + cc_library( name = "copy_thunk", srcs = ["copy_thunk.cc"], diff --git a/third_party/xla/xla/service/cpu/runtime/all_reduce_thunk.cc b/third_party/xla/xla/service/cpu/runtime/all_reduce_thunk.cc index baee736b245b8c..0f59a38fd6e645 100644 --- a/third_party/xla/xla/service/cpu/runtime/all_reduce_thunk.cc +++ b/third_party/xla/xla/service/cpu/runtime/all_reduce_thunk.cc @@ -16,6 +16,7 @@ limitations under the License. #include "xla/service/cpu/runtime/all_reduce_thunk.h" #include +#include #include #include #include @@ -26,44 +27,94 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_format.h" #include "absl/types/span.h" +#include "xla/primitive_util.h" #include "xla/runtime/buffer_use.h" #include "xla/service/buffer_assignment.h" +#include "xla/service/collective_ops_utils.h" +#include "xla/service/cpu/runtime/collective_thunk.h" #include "xla/service/cpu/runtime/thunk.h" #include "xla/shape.h" #include "xla/stream_executor/device_memory.h" #include "xla/tsl/concurrency/async_value_ref.h" +#include "xla/util.h" #include "tsl/platform/logging.h" #include "tsl/platform/statusor.h" #include "tsl/profiler/lib/traceme.h" namespace xla::cpu { +namespace { + +static bool IsDataTypeSupportedByCollectiveReduce(PrimitiveType datatype) { + switch (datatype) { + case PRED: + case S8: + case U8: + case S16: + case U16: + case S32: + case U32: + case S64: + case U64: + case F16: + case F32: + case F64: + case C64: + case C128: + return true; + default: + return false; + } +} + +} // namespace absl::StatusOr> AllReduceThunk::Create( - Info info, absl::Span source_buffers, + Info info, ReductionKind reduction_kind, OpParams op_params, + absl::Span source_buffers, absl::Span source_shapes, - BufferAllocation::Slice destination_buffer, - const Shape& destination_shape) { - 
return absl::WrapUnique(new AllReduceThunk(std::move(info), source_buffers, - source_shapes, destination_buffer, - destination_shape)); + absl::Span destination_buffers, + absl::Span destination_shapes, bool single_replica) { + auto datatype = source_shapes[0].element_type(); + + // Check that the data types are supported. + if (!IsDataTypeSupportedByCollectiveReduce(datatype)) { + return Unimplemented("AllReduce for datatype '%s' is not supported", + primitive_util::LowercasePrimitiveTypeName(datatype)); + } + + return absl::WrapUnique(new AllReduceThunk( + std::move(info), reduction_kind, op_params, source_buffers, source_shapes, + destination_buffers, destination_shapes, single_replica)); } AllReduceThunk::AllReduceThunk( - Info info, absl::Span source_buffers, + Info info, ReductionKind reduction_kind, OpParams op_params, + absl::Span source_buffers, absl::Span source_shapes, - BufferAllocation::Slice destination_buffer, const Shape& destination_shape) - : Thunk(Kind::kAllReduce, info), + absl::Span destination_buffers, + absl::Span destination_shapes, bool single_replica) + : CollectiveThunk(Kind::kAllReduce, info, op_params), + reduction_kind_(reduction_kind), source_buffers_(source_buffers.begin(), source_buffers.end()), source_shapes_(source_shapes.begin(), source_shapes.end()), - destination_buffer_(destination_buffer), - destination_shape_(destination_shape) {} + destination_buffers_(destination_buffers.begin(), + destination_buffers.end()), + destination_shapes_(destination_shapes.begin(), destination_shapes.end()), + single_replica_(single_replica) {} tsl::AsyncValueRef AllReduceThunk::Execute( const ExecuteParams& params) { tsl::profiler::TraceMe trace([&] { return TraceMeEncode(); }); size_t num_srcs = source_buffers_.size(); - VLOG(3) << absl::StreamFormat("AllReduce: #source_buffers=%d", num_srcs); + size_t num_dsts = destination_buffers_.size(); + DCHECK_EQ(num_srcs, num_dsts) << "Number of src and dst buffers must match"; + + VLOG(3) << 
absl::StreamFormat( + "AllReduce: #source_buffers=%d, #destination_buffers=%d, " + "reduction_kind=%s, single_replica=%v", + num_srcs, num_dsts, ReductionKindToString(reduction_kind_), + single_replica_); absl::InlinedVector source_data(num_srcs); for (int i = 0; i < num_srcs; ++i) { @@ -75,32 +126,53 @@ tsl::AsyncValueRef AllReduceThunk::Execute( source_buffers_[i].ToString(), source_data[i].opaque()); } - TF_ASSIGN_OR_RETURN( - se::DeviceMemoryBase destination_data, - params.buffer_allocations->GetDeviceAddress(destination_buffer_)); - - VLOG(3) << absl::StreamFormat( - " dst: %s in slice %s (%p)", destination_shape_.ToString(true), - destination_buffer_.ToString(), destination_data.opaque()); + absl::InlinedVector destination_data(num_dsts); + for (int i = 0; i < num_dsts; ++i) { + TF_ASSIGN_OR_RETURN( + destination_data[i], + params.buffer_allocations->GetDeviceAddress(destination_buffers_[i])); + VLOG(3) << absl::StreamFormat( + " dst: %s in slice %s (%p)", destination_shapes_[i].ToString(true), + destination_buffers_[i].ToString(), destination_data[i].opaque()); + } // Handle single-replica case by copying the source to the destination. - if (num_srcs == 1) { + if (single_replica_) { DCHECK_EQ(source_data.size(), destination_data.size()); - std::memcpy(destination_data.opaque(), source_data[0].opaque(), - destination_data.size()); + for (int i = 0; i < num_srcs; ++i) { + std::memcpy(destination_data[i].opaque(), source_data[i].opaque(), + destination_data[i].size()); + } return OkExecuteEvent(); } + // For multi-replica case, we need collective parameters to be able to + // perform the all-reduce operation collectively with other replicas. 
+ CollectiveExecuteParams* collective_params = params.collective_params; + if (collective_params == nullptr) { + return Internal( + "Collective parameters are not set for all-reduce operation"); + } + + TF_ASSIGN_OR_RETURN(RendezvousKey key, GetRendezvousKey(*collective_params)); + TF_ASSIGN_OR_RETURN( + int32_t rank, + RankInGlobalDevices(key, collective_params->global_device_id)); + + VLOG(3) << absl::StreamFormat(" rank=%d, key=%s", rank, key.ToString()); + return absl::UnimplementedError("AllReduceThunk::Execute not implemented"); } Thunk::BufferUses AllReduceThunk::buffer_uses() const { BufferUses uses; - uses.reserve(source_buffers_.size() + 1); + uses.reserve(source_buffers_.size() + destination_buffers_.size()); for (auto& source_buffer : source_buffers_) { uses.push_back(BufferUse::Read(source_buffer)); } - uses.push_back(BufferUse::Write(destination_buffer_)); + for (auto& destination_buffer : destination_buffers_) { + uses.push_back(BufferUse::Write(destination_buffer)); + } return uses; } diff --git a/third_party/xla/xla/service/cpu/runtime/all_reduce_thunk.h b/third_party/xla/xla/service/cpu/runtime/all_reduce_thunk.h index f5a8866377ffc0..4fae44c68bfd12 100644 --- a/third_party/xla/xla/service/cpu/runtime/all_reduce_thunk.h +++ b/third_party/xla/xla/service/cpu/runtime/all_reduce_thunk.h @@ -22,36 +22,46 @@ limitations under the License. 
#include "absl/status/statusor.h" #include "absl/types/span.h" #include "xla/service/buffer_assignment.h" -#include "xla/service/cpu/runtime/thunk.h" +#include "xla/service/collective_ops_utils.h" +#include "xla/service/cpu/runtime/collective_thunk.h" #include "xla/shape.h" #include "xla/tsl/concurrency/async_value_ref.h" +#include "xla/xla_data.pb.h" namespace xla::cpu { -class AllReduceThunk final : public Thunk { +class AllReduceThunk final : public CollectiveThunk { public: + using CollectiveThunk::OpParams; + static absl::StatusOr> Create( - Info info, absl::Span source_buffers, + Info info, ReductionKind reduction_kind, OpParams op_params, + absl::Span source_buffers, absl::Span source_shapes, - BufferAllocation::Slice destination_buffer, - const Shape& destination_shape); + absl::Span destination_buffers, + absl::Span destination_shapes, bool single_replica); tsl::AsyncValueRef Execute(const ExecuteParams& params) final; BufferUses buffer_uses() const final; private: - AllReduceThunk(Info info, + AllReduceThunk(Info info, ReductionKind reduction_kind, OpParams op_params, absl::Span source_buffers, absl::Span source_shapes, - BufferAllocation::Slice destination_buffer, - const Shape& destination_shape); + absl::Span destination_buffers, + absl::Span destination_shapes, + bool single_replica); + + ReductionKind reduction_kind_; std::vector source_buffers_; std::vector source_shapes_; - BufferAllocation::Slice destination_buffer_; - Shape destination_shape_; + std::vector destination_buffers_; + std::vector destination_shapes_; + + bool single_replica_; }; } // namespace xla::cpu diff --git a/third_party/xla/xla/service/cpu/runtime/collective_thunk.cc b/third_party/xla/xla/service/cpu/runtime/collective_thunk.cc new file mode 100644 index 00000000000000..21676087e21f38 --- /dev/null +++ b/third_party/xla/xla/service/cpu/runtime/collective_thunk.cc @@ -0,0 +1,82 @@ +/* Copyright 2024 The OpenXLA Authors. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/cpu/runtime/collective_thunk.h" + +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "xla/service/collective_ops_utils.h" +#include "xla/service/computation_placer.h" +#include "xla/service/cpu/runtime/thunk.h" +#include "xla/service/global_device_id.h" +#include "xla/status_macros.h" +#include "xla/util.h" +#include "xla/xla_data.pb.h" +#include "tsl/platform/statusor.h" + +namespace xla::cpu { + +CollectiveThunk::CollectiveThunk(Kind kind, Thunk::Info info, + OpParams op_params) + : Thunk(kind, info), op_params_(std::move(op_params)) {} + +absl::StatusOr CollectiveThunk::GetRendezvousKey( + const Thunk::CollectiveExecuteParams& params) { + TF_RET_CHECK(params.device_assignment) << "Device assignment is null"; + + const DeviceAssignment& device_assignment = *params.device_assignment; + RendezvousKey::CollectiveOpKind op_kind = op_params_.has_channel_id + ? 
RendezvousKey::kCrossModule + : RendezvousKey::kCrossReplica; + + TF_ASSIGN_OR_RETURN( + CollectiveOpGroupMode group_mode, + GetCollectiveOpGroupMode(op_params_.has_channel_id, + op_params_.use_global_device_ids)); + + TF_ASSIGN_OR_RETURN( + std::vector participating_devices, + GetParticipatingDevices(params.global_device_id, device_assignment, + op_params_.group, group_mode)); + + int num_local_participants = participating_devices.size(); + return RendezvousKey{params.run_id, std::move(participating_devices), + num_local_participants, op_kind, op_params_.op_id}; +} + +absl::StatusOr CollectiveThunk::RankInGlobalDevices( + const RendezvousKey& key, GlobalDeviceId device) { + auto it = absl::c_find(key.global_devices, device); + if (it == key.global_devices.end()) { + return InvalidArgument( + "Device %d not present in global devices %s.", device.value(), + absl::StrJoin(key.global_devices, ", ", + [](std::string* out, GlobalDeviceId id) { + absl::StrAppend(out, id.value()); + })); + } + return std::distance(key.global_devices.begin(), it); +} + +} // namespace xla::cpu diff --git a/third_party/xla/xla/service/cpu/runtime/collective_thunk.h b/third_party/xla/xla/service/cpu/runtime/collective_thunk.h new file mode 100644 index 00000000000000..022b65346a26f6 --- /dev/null +++ b/third_party/xla/xla/service/cpu/runtime/collective_thunk.h @@ -0,0 +1,61 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef XLA_SERVICE_CPU_RUNTIME_COLLECTIVE_THUNK_H_ +#define XLA_SERVICE_CPU_RUNTIME_COLLECTIVE_THUNK_H_ + +#include +#include +#include + +#include "absl/status/statusor.h" +#include "xla/service/collective_ops_utils.h" +#include "xla/service/cpu/runtime/thunk.h" +#include "xla/service/global_device_id.h" + +namespace xla::cpu { + +class CollectiveThunk : public Thunk { + using Thunk::Thunk; + + public: + // Parameters of the collective operation behind the collective thunk. We rely + // on them to construct the rendezvous key and to find a thunk "location" in + // the collective operation "clique" (group of communicating devices). + struct OpParams { + int64_t op_id; + bool has_channel_id; + std::optional use_global_device_ids; + std::vector group; + }; + + CollectiveThunk(Kind kind, Thunk::Info info, OpParams op_params); + + const OpParams& op_params() const { return op_params_; } + + protected: + absl::StatusOr GetRendezvousKey( + const Thunk::CollectiveExecuteParams& params); + + absl::StatusOr RankInGlobalDevices(const RendezvousKey& key, + GlobalDeviceId device); + + private: + OpParams op_params_; +}; + +} // namespace xla::cpu + +#endif // XLA_SERVICE_CPU_RUNTIME_COLLECTIVE_THUNK_H_ diff --git a/third_party/xla/xla/service/cpu/runtime/thunk.cc b/third_party/xla/xla/service/cpu/runtime/thunk.cc index eaf30398705c6d..6da98ddb4f5b8a 100644 --- a/third_party/xla/xla/service/cpu/runtime/thunk.cc +++ b/third_party/xla/xla/service/cpu/runtime/thunk.cc @@ -15,15 +15,18 @@ limitations under the License. 
#include "xla/service/cpu/runtime/thunk.h" +#include #include #include #include #include #include -#include "absl/status/status.h" +#include "xla/executable_run_options.h" +#include "xla/service/global_device_id.h" +#include "xla/stream_executor/stream.h" +#include "xla/stream_executor/stream_executor.h" #include "xla/tsl/concurrency/async_value_ref.h" -#include "tsl/platform/errors.h" #include "tsl/platform/logging.h" #include "tsl/profiler/lib/traceme.h" #include "tsl/profiler/lib/traceme_encode.h" @@ -57,6 +60,29 @@ std::string_view Thunk::KindToString(Kind kind) { } } +absl::StatusOr +Thunk::CollectiveExecuteParams::Create( + const ExecutableRunOptions* run_options) { + // Device ordinal must be set by caller and passed in run options, if not, + // we use the device ordinal from the parent StreamExecutor. + int32_t device_ordinal = + run_options->device_ordinal() >= 0 + ? run_options->device_ordinal() + : run_options->stream()->parent()->device_ordinal(); + + return CollectiveExecuteParams{run_options->run_id(), device_ordinal, + GlobalDeviceId(run_options->device_ordinal()), + run_options->device_assignment()}; +} + +Thunk::CollectiveExecuteParams::CollectiveExecuteParams( + RunId run_id, int64_t local_device_ordinal, GlobalDeviceId global_device_id, + const DeviceAssignment* device_assignment) + : run_id(run_id), + local_device_ordinal(local_device_ordinal), + global_device_id(global_device_id), + device_assignment(device_assignment) {} + tsl::AsyncValueRef Thunk::OkExecuteEvent() { static tsl::AsyncValueOwningRef* event = [] { auto* storage = new tsl::internal::AsyncValueStorage(); diff --git a/third_party/xla/xla/service/cpu/runtime/thunk.h b/third_party/xla/xla/service/cpu/runtime/thunk.h index d27846d848afdc..8074412330f873 100644 --- a/third_party/xla/xla/service/cpu/runtime/thunk.h +++ b/third_party/xla/xla/service/cpu/runtime/thunk.h @@ -16,7 +16,6 @@ limitations under the License. 
#ifndef XLA_SERVICE_CPU_RUNTIME_THUNK_H_ #define XLA_SERVICE_CPU_RUNTIME_THUNK_H_ -#include #include #include #include @@ -26,13 +25,13 @@ limitations under the License. #include #include -#include "absl/base/optimization.h" #include "absl/container/inlined_vector.h" -#include "absl/status/status.h" #include "absl/status/statusor.h" +#include "xla/executable_run_options.h" #include "xla/runtime/buffer_use.h" #include "xla/service/cpu/runtime/buffer_allocations.h" #include "xla/service/cpu/xfeed_manager.h" +#include "xla/service/global_device_id.h" #include "xla/stream_executor/host/host_kernel_c_api.h" #include "xla/tsl/concurrency/async_value_ref.h" #include "xla/tsl/concurrency/chain.h" @@ -111,6 +110,29 @@ class Thunk { virtual absl::StatusOr Find(std::string_view name) = 0; }; + //===--------------------------------------------------------------------===// + // CollectiveExecuteParams + //===--------------------------------------------------------------------===// + + // Parameters capturing all the details required for collective execution of + // XLA executables (multiple partitions and replicas). 
+ struct CollectiveExecuteParams { + static absl::StatusOr Create( + const ExecutableRunOptions* run_options); + + RunId run_id; + + int64_t local_device_ordinal; + GlobalDeviceId global_device_id; + + const DeviceAssignment* device_assignment = nullptr; + + private: + CollectiveExecuteParams(RunId run_id, int64_t local_device_ordinal, + GlobalDeviceId global_device_id, + const DeviceAssignment* device_assignment); + }; + //===--------------------------------------------------------------------===// // ExecuteParams //===--------------------------------------------------------------------===// @@ -122,6 +144,7 @@ class Thunk { const BufferAllocations* buffer_allocations = nullptr; runtime::XfeedManager* xfeed = nullptr; const Eigen::ThreadPoolDevice* intra_op_threadpool = nullptr; + CollectiveExecuteParams* collective_params = nullptr; }; // An execute event that becomes ready when all tasks are completed. diff --git a/third_party/xla/xla/service/cpu/thunk_emitter.cc b/third_party/xla/xla/service/cpu/thunk_emitter.cc index aab25ac8a7cc55..fe01f464d28b89 100644 --- a/third_party/xla/xla/service/cpu/thunk_emitter.cc +++ b/third_party/xla/xla/service/cpu/thunk_emitter.cc @@ -15,6 +15,7 @@ limitations under the License. #include "xla/service/cpu/thunk_emitter.h" +#include #include #include #include @@ -32,6 +33,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/ir/hlo_schedule.h" #include "xla/service/buffer_assignment.h" +#include "xla/service/collective_ops_utils.h" #include "xla/service/cpu/dot_op_emitter.h" #include "xla/service/cpu/ir_emitter2.h" #include "xla/service/cpu/runtime/all_reduce_thunk.h" @@ -248,21 +250,51 @@ absl::StatusOr ThunkEmitter::EmitHloInstruction( absl::StatusOr ThunkEmitter::EmitAllReduceThunk( const HloInstruction* instruction) { + auto* all_reduce = Cast(instruction); + + // Check that we recognize the reduction computation attached to a collective. 
+ auto reduction_kind = MatchReductionComputation(all_reduce->to_apply()); + if (!reduction_kind.has_value()) { + return Unimplemented("AllReduce for computation '%s' is not supported", + all_reduce->to_apply()->ToString()); + } + + // Collect buffer slices for all operands. std::vector source_buffers; std::vector source_shapes; - for (const HloInstruction* operand : instruction->operands()) { + for (const HloInstruction* operand : all_reduce->operands()) { TF_ASSIGN_OR_RETURN(source_buffers.emplace_back(), GetAllocationSlice(operand)); source_shapes.push_back(operand->shape()); } - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice destination_buffer, - GetAllocationSlice(instruction)); + // Collect buffer slices for all results. + std::vector destination_buffers; + std::vector destination_shapes; + + for (auto& indexed : ShapeUtil::GetLeafShapes(instruction->shape())) { + TF_ASSIGN_OR_RETURN(destination_buffers.emplace_back(), + GetAllocationSlice(instruction, indexed.index)); + destination_shapes.push_back(indexed.shape); + } + + AllReduceThunk::OpParams op_params = { + /*op_id=*/all_reduce->channel_id().has_value() + ? 
all_reduce->channel_id().value() + : all_reduce->GetModule()->unique_id(), + /*has_channel_id=*/all_reduce->channel_id().has_value(), + /*use_global_device_ids=*/all_reduce->use_global_device_ids(), + /*replica_groups=*/all_reduce->replica_groups(), + }; + + bool single_replica = hlo_module_config_.replica_count() == 1 && + hlo_module_config_.num_partitions() == 1; return ThunkSequence::Of( - ThunkInfo(instruction), source_buffers, source_shapes, destination_buffer, - instruction->shape()); + ThunkInfo(all_reduce), *reduction_kind, std::move(op_params), + source_buffers, source_shapes, destination_buffers, destination_shapes, + single_replica); } absl::StatusOr ThunkEmitter::EmitCallThunk( diff --git a/third_party/xla/xla/service/gpu/runtime/nccl_api.cc b/third_party/xla/xla/service/gpu/runtime/nccl_api.cc index 49abfa4636b44e..79f89590c9d430 100644 --- a/third_party/xla/xla/service/gpu/runtime/nccl_api.cc +++ b/third_party/xla/xla/service/gpu/runtime/nccl_api.cc @@ -165,19 +165,6 @@ static ncclRedOp_t ToNcclReduction(ReductionKind kind) { } } -static std::string_view ToString(ReductionKind reduction_kind) { - switch (reduction_kind) { - case ReductionKind::SUM: - return "sum"; - case ReductionKind::PRODUCT: - return "prod"; - case ReductionKind::MIN: - return "min"; - case ReductionKind::MAX: - return "max"; - } -} - //==-----------------------------------------------------------------------===// // Casting between opaque API structs and NCCL types. 
//==-----------------------------------------------------------------------===// @@ -531,7 +518,7 @@ absl::Status DefaultNcclApi::AllReduce(se::DeviceMemoryBase send_buffer, "stream=%p", stream->parent()->device_ordinal(), send_buffer.opaque(), recv_buffer.opaque(), primitive_util::LowercasePrimitiveTypeName(dtype), - count, ToString(reduction_kind), comm, stream); + count, ReductionKindToString(reduction_kind), comm, stream); TF_ASSIGN_OR_RETURN(ncclDataType_t nccl_dtype, ToNcclDataType(dtype, false)); @@ -573,7 +560,7 @@ absl::Status DefaultNcclApi::ReduceScatter(se::DeviceMemoryBase send_buffer, "stream=%p", stream->parent()->device_ordinal(), send_buffer.opaque(), recv_buffer.opaque(), primitive_util::LowercasePrimitiveTypeName(dtype), - count, ToString(reduction_kind), comm, stream); + count, ReductionKindToString(reduction_kind), comm, stream); TF_ASSIGN_OR_RETURN(ncclDataType_t nccl_dtype, ToNcclDataType(dtype, false)); diff --git a/third_party/xla/xla/tests/BUILD b/third_party/xla/xla/tests/BUILD index 1e69e2061f4e5e..918992c6aeff99 100644 --- a/third_party/xla/xla/tests/BUILD +++ b/third_party/xla/xla/tests/BUILD @@ -2131,12 +2131,16 @@ xla_test( # All reduce is not supported on the interpreter backend. "interpreter", ], - tags = ["test_hlo_pjrt_runner"], + tags = [ + "test_hlo_pjrt_runner", + "test_xla_cpu_thunks", + ], deps = [ ":hlo_test_base", ":test_macros_header", ":xla_internal_test_main", "//xla:literal", + "//xla:literal_util", "//xla:shape_util", "//xla:test", "//xla:test_helpers", diff --git a/third_party/xla/xla/tests/all_reduce_test.cc b/third_party/xla/xla/tests/all_reduce_test.cc index f9dd1f3e8777ee..2d449f31c61ac2 100644 --- a/third_party/xla/xla/tests/all_reduce_test.cc +++ b/third_party/xla/xla/tests/all_reduce_test.cc @@ -16,9 +16,9 @@ limitations under the License. 
#include #include "xla/literal.h" +#include "xla/literal_util.h" #include "xla/shape_util.h" #include "xla/test.h" -#include "xla/test_helpers.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/test_macros.h" From 5f3eacb081e51a15123173de32b1dd1c055613b3 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 18 Jun 2024 12:29:28 -0700 Subject: [PATCH 33/59] Reverts fee3bfc812780f9c01a4fd936f69562a6884582a PiperOrigin-RevId: 644476286 --- tensorflow/core/grappler/BUILD | 1 - tensorflow/core/grappler/op_types.cc | 105 +------------ tensorflow/core/grappler/op_types.h | 9 -- tensorflow/core/grappler/optimizers/BUILD | 9 -- .../grappler/optimizers/constant_folding.cc | 64 +++++++- .../core/grappler/optimizers/remapper.cc | 127 +++------------ .../core/grappler/optimizers/remapper_test.cc | 145 +----------------- tensorflow/core/kernels/BUILD | 1 - .../core/kernels/fused_eigen_output_kernels.h | 12 +- tensorflow/core/kernels/mkl/BUILD | 1 - .../python/grappler/cost_analyzer_test.py | 11 +- 11 files changed, 102 insertions(+), 383 deletions(-) diff --git a/tensorflow/core/grappler/BUILD b/tensorflow/core/grappler/BUILD index 99c53397ee47c8..ae85e69d064649 100644 --- a/tensorflow/core/grappler/BUILD +++ b/tensorflow/core/grappler/BUILD @@ -29,7 +29,6 @@ cc_library( "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", - "@com_google_absl//absl/strings", ], ) diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc index 36a1918256ea22..e0981fe90c8ae9 100644 --- a/tensorflow/core/grappler/op_types.cc +++ b/tensorflow/core/grappler/op_types.cc @@ -15,16 +15,9 @@ limitations under the License. 
#include "tensorflow/core/grappler/op_types.h" -#include -#include - -#include "absl/strings/match.h" #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/op.h" -#include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/framework/types.h" -#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/grappler/utils.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/flatset.h" @@ -34,36 +27,8 @@ limitations under the License. namespace tensorflow { namespace grappler { -namespace { -template -bool AllValuesAre(const TensorProto& proto, const T& value) { - Tensor tensor; - if (!tensor.FromProto(proto)) { - return false; - } - auto values = tensor.flat(); - for (int i = 0; i < tensor.NumElements(); ++i) { - if (values(i) != value) { - return false; - } - } - return true; -} - -#define IS_VALUE_CASE(DTYPE, VALUE) \ - case DTYPE: \ - return AllValuesAre::Type>( \ - tensor, EnumToDataType::Type(VALUE)) - -#define IS_ONES_CASE(TYPE) IS_VALUE_CASE(TYPE, 1) -#define IS_ZEROS_CASE(TYPE) IS_VALUE_CASE(TYPE, 0) - -} // namespace - -bool IsAddV2(const NodeDef& node) { return node.op() == "AddV2"; } - bool IsAdd(const NodeDef& node) { - if (IsAddV2(node)) { + if (node.op() == "AddV2") { return true; } if (node.op() == "Add") { @@ -1033,73 +998,5 @@ bool NeverForwardsInputs(const NodeDef& node) { bool IsXlaLaunch(const NodeDef& node) { return node.op() == "XlaLaunch"; } -bool IsZeroTensor(const TensorProto& tensor, const DataType& dtype) { - switch (dtype) { - IS_ZEROS_CASE(DT_BOOL); - IS_ZEROS_CASE(DT_HALF); - IS_ZEROS_CASE(DT_BFLOAT16); - IS_ZEROS_CASE(DT_FLOAT); - IS_ZEROS_CASE(DT_DOUBLE); - IS_ZEROS_CASE(DT_COMPLEX64); - IS_ZEROS_CASE(DT_COMPLEX128); - IS_ZEROS_CASE(DT_UINT8); - IS_ZEROS_CASE(DT_INT8); - IS_ZEROS_CASE(DT_UINT16); - IS_ZEROS_CASE(DT_INT16); - IS_ZEROS_CASE(DT_INT32); - IS_ZEROS_CASE(DT_INT64); - 
IS_ZEROS_CASE(DT_QINT32); - IS_ZEROS_CASE(DT_QINT16); - IS_ZEROS_CASE(DT_QUINT16); - IS_ZEROS_CASE(DT_QINT8); - IS_ZEROS_CASE(DT_QUINT8); - default: - VLOG(1) << "Unsupported type " << DataTypeString(dtype); - return false; - } - return false; -} - -bool IsZerosNode(const NodeDef& node) { - if (!IsConstant(node)) return false; - if (node.attr().count("dtype") == 0) return false; - return IsZeroTensor(node.attr().at("value").tensor(), - node.attr().at("dtype").type()); -} - -bool IsOneTensor(const TensorProto& tensor, const DataType& dtype) { - switch (dtype) { - IS_ONES_CASE(DT_BOOL); - IS_ONES_CASE(DT_HALF); - IS_ONES_CASE(DT_BFLOAT16); - IS_ONES_CASE(DT_FLOAT); - IS_ONES_CASE(DT_DOUBLE); - IS_ONES_CASE(DT_COMPLEX64); - IS_ONES_CASE(DT_COMPLEX128); - IS_ONES_CASE(DT_UINT8); - IS_ONES_CASE(DT_INT8); - IS_ONES_CASE(DT_UINT16); - IS_ONES_CASE(DT_INT16); - IS_ONES_CASE(DT_INT32); - IS_ONES_CASE(DT_INT64); - IS_ONES_CASE(DT_QINT32); - IS_ONES_CASE(DT_QINT16); - IS_ONES_CASE(DT_QUINT16); - IS_ONES_CASE(DT_QINT8); - IS_ONES_CASE(DT_QUINT8); - default: - VLOG(1) << "Unsupported type " << DataTypeString(dtype); - return false; - } - return false; -} - -bool IsOnesNode(const NodeDef& node) { - if (!IsConstant(node)) return false; - if (node.attr().count("dtype") == 0) return false; - const auto dtype = node.attr().at("dtype").type(); - return IsOneTensor(node.attr().at("value").tensor(), dtype); -} - } // namespace grappler } // end namespace tensorflow diff --git a/tensorflow/core/grappler/op_types.h b/tensorflow/core/grappler/op_types.h index d034b4b204c706..c233b6e9c6b61a 100644 --- a/tensorflow/core/grappler/op_types.h +++ b/tensorflow/core/grappler/op_types.h @@ -23,7 +23,6 @@ limitations under the License. 
namespace tensorflow { namespace grappler { bool IsAdd(const NodeDef& node); -bool IsAddV2(const NodeDef& node); bool IsAddN(const NodeDef& node); bool IsAll(const NodeDef& node); bool IsAngle(const NodeDef& node); @@ -279,14 +278,6 @@ bool IsCastLike(const NodeDef& node); // allocates buffers for its inputs. bool NeverForwardsInputs(const NodeDef& node); -bool IsZeroTensor(const TensorProto& tensor, const DataType& dtype); - -bool IsOnesTensor(const TensorProto& tensor, const DataType& dtype); - -bool IsZerosNode(const NodeDef& node); - -bool IsOnesNode(const NodeDef& node); - } // end namespace grappler } // end namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD index 55f161beb46aa8..e967c46836756d 100644 --- a/tensorflow/core/grappler/optimizers/BUILD +++ b/tensorflow/core/grappler/optimizers/BUILD @@ -890,7 +890,6 @@ tf_kernel_library( deps = [ ":constant_folding", ":graph_optimizer", - "//tensorflow/core:core_cpu_base", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", @@ -899,20 +898,12 @@ tf_kernel_library( "//tensorflow/core/grappler:grappler_item", "//tensorflow/core/grappler:op_types", "//tensorflow/core/grappler:utils", - "//tensorflow/core/grappler/clusters:cluster", "//tensorflow/core/grappler/costs:graph_properties", "//tensorflow/core/grappler/utils:graph_view", "//tensorflow/core/grappler/utils:pattern_utils", "//tensorflow/core/grappler/utils:symbolic_shapes", "//tensorflow/core/grappler/utils:topological_sort", - "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/log", - "@com_google_absl//absl/status", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - "@eigen_archive//:eigen3", ] + if_mkl(["//tensorflow/core/graph:mkl_graph_util"]), ) diff --git 
a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc index 4d24b396a2e836..90eb63a22bcaa7 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding.cc @@ -1788,6 +1788,14 @@ Status ConstantFolding::IsSimplifiableReshape( return absl::OkStatus(); } +#define IS_VALUE_CASE(DTYPE, VALUE) \ + case DTYPE: \ + return AllValuesAre::Type>( \ + node.attr().at("value").tensor(), EnumToDataType::Type(VALUE)) + +#define IS_ONES_CASE(TYPE) IS_VALUE_CASE(TYPE, 1) +#define IS_ZEROS_CASE(TYPE) IS_VALUE_CASE(TYPE, 0) + bool ConstantFolding::IsOnes(const NodeDef& node) const { if (feed_nodes_.find(node.name()) != feed_nodes_.end()) { return false; @@ -1798,7 +1806,33 @@ bool ConstantFolding::IsOnes(const NodeDef& node) const { NodeDef* values = node_map_->GetNode(NodeName(node.input(1))); return values != nullptr && IsOnes(*values); } - return IsOnesNode(node); + if (node.op() != "Const") return false; + if (node.attr().count("dtype") == 0) return false; + const auto dtype = node.attr().at("dtype").type(); + switch (dtype) { + IS_ONES_CASE(DT_BOOL); + IS_ONES_CASE(DT_HALF); + IS_ONES_CASE(DT_BFLOAT16); + IS_ONES_CASE(DT_FLOAT); + IS_ONES_CASE(DT_DOUBLE); + IS_ONES_CASE(DT_COMPLEX64); + IS_ONES_CASE(DT_COMPLEX128); + IS_ONES_CASE(DT_UINT8); + IS_ONES_CASE(DT_INT8); + IS_ONES_CASE(DT_UINT16); + IS_ONES_CASE(DT_INT16); + IS_ONES_CASE(DT_INT32); + IS_ONES_CASE(DT_INT64); + IS_ONES_CASE(DT_QINT32); + IS_ONES_CASE(DT_QINT16); + IS_ONES_CASE(DT_QUINT16); + IS_ONES_CASE(DT_QINT8); + IS_ONES_CASE(DT_QUINT8); + default: + VLOG(1) << "Unsupported type " << DataTypeString(dtype); + return false; + } + return false; } bool ConstantFolding::IsZeros(const NodeDef& node) const { @@ -1811,7 +1845,33 @@ bool ConstantFolding::IsZeros(const NodeDef& node) const { NodeDef* values = node_map_->GetNode(NodeName(node.input(1))); return values != nullptr && IsZeros(*values); } 
- return IsZerosNode(node); + if (!IsConstant(node)) return false; + if (node.attr().count("dtype") == 0) return false; + const auto dtype = node.attr().at("dtype").type(); + switch (dtype) { + IS_ZEROS_CASE(DT_BOOL); + IS_ZEROS_CASE(DT_HALF); + IS_ZEROS_CASE(DT_BFLOAT16); + IS_ZEROS_CASE(DT_FLOAT); + IS_ZEROS_CASE(DT_DOUBLE); + IS_ZEROS_CASE(DT_COMPLEX64); + IS_ZEROS_CASE(DT_COMPLEX128); + IS_ZEROS_CASE(DT_UINT8); + IS_ZEROS_CASE(DT_INT8); + IS_ZEROS_CASE(DT_UINT16); + IS_ZEROS_CASE(DT_INT16); + IS_ZEROS_CASE(DT_INT32); + IS_ZEROS_CASE(DT_INT64); + IS_ZEROS_CASE(DT_QINT32); + IS_ZEROS_CASE(DT_QINT16); + IS_ZEROS_CASE(DT_QUINT16); + IS_ZEROS_CASE(DT_QINT8); + IS_ZEROS_CASE(DT_QUINT8); + default: + VLOG(1) << "Unsupported type " << DataTypeString(dtype); + return false; + } + return false; } bool ConstantFolding::ReplaceOperationWithBroadcastTo( diff --git a/tensorflow/core/grappler/optimizers/remapper.cc b/tensorflow/core/grappler/optimizers/remapper.cc index bde4822ae625f8..da60dacd9aa5e1 100644 --- a/tensorflow/core/grappler/optimizers/remapper.cc +++ b/tensorflow/core/grappler/optimizers/remapper.cc @@ -24,25 +24,11 @@ limitations under the License. 
#include #include -#include "absl/algorithm/container.h" -#include "absl/log/log.h" -#include "absl/status/status.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/string_view.h" -#include "absl/types/span.h" -#include "Eigen/Core" // from @eigen_archive -#include "tensorflow/core/framework/attr_value_util.h" -#include "tensorflow/core/framework/node_def.pb.h" -#include "tensorflow/core/framework/node_def_util.h" -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/framework/tensor.pb.h" +#include "absl/container/flat_hash_set.h" #include "tensorflow/core/framework/tensor_shape.h" -#include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/versions.pb.h" -#include "tensorflow/core/graph/graph.h" -#include "tensorflow/core/grappler/clusters/cluster.h" #include "tensorflow/core/grappler/costs/graph_properties.h" +#include "tensorflow/core/grappler/graph_view.h" #include "tensorflow/core/grappler/grappler_item.h" #include "tensorflow/core/grappler/op_types.h" #include "tensorflow/core/grappler/optimizers/constant_folding.h" @@ -50,14 +36,13 @@ limitations under the License. 
#include "tensorflow/core/grappler/utils/graph_view.h" #include "tensorflow/core/grappler/utils/pattern_utils.h" #include "tensorflow/core/grappler/utils/symbolic_shapes.h" -#include "tensorflow/core/platform/bfloat16.h" +#include "tensorflow/core/grappler/utils/topological_sort.h" +#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/status.h" #include "tensorflow/core/protobuf/rewriter_config.pb.h" #include "tensorflow/core/util/env_var.h" #include "tensorflow/core/util/use_cudnn.h" #include "tsl/platform/errors.h" -#include "tsl/platform/status.h" #ifdef INTEL_MKL #include "tensorflow/core/util/mkl_heuristics.h" #endif // INTEL_MKL @@ -525,24 +510,6 @@ bool RuntimeFusionEnabled(const Cluster* cluster) { return is_enabled; } -bool IsReluSemanticMaximum(const RemapperContext& ctx, - const utils::MutableNodeView& node_view) { - if (!IsMaximum(*node_view.node())) return false; - - const std::vector& input_props = - ctx.graph_properties.GetInputProperties(node_view.node()->name()); - if (input_props.size() < 2) return false; - - for (const auto& input_prop : input_props) { - if (input_prop.shape().dim_size() == 0 && - IsZeroTensor(input_prop.value(), input_prop.dtype())) { - return true; - break; - } - } - return false; -} - bool IsSupportedActivation(const NodeDef& node, const Cluster* cluster) { bool is_default_supported = IsRelu(node) || IsRelu6(node) || IsElu(node) || IsLeakyRelu(node); @@ -700,6 +667,8 @@ bool IsConvOrMatMul(const NodeDef& node) { bool IsBiasSemanticAdd(const RemapperContext& ctx, const utils::MutableNodeView& node_view, int& bias_port) { + if (!IsMKLEnabled()) return false; + const auto* node_def = node_view.node(); if (!NodeIsOnCpu(node_def)) return false; if (!IsAdd(*node_def) || node_view.NumRegularFanins() != 2) return false; @@ -713,6 +682,7 @@ bool IsBiasSemanticAdd(const RemapperContext& ctx, const auto& regular_fanin_1 = node_view.GetRegularFanin(1); const auto* 
node_view_1 = regular_fanin_1.node_view(); const auto* node_def_1 = node_view_1->node(); + if (!IsConvOrMatMul(*node_def_0) && !IsConvOrMatMul(*node_def_1)) return false; @@ -731,11 +701,13 @@ bool IsBiasSemanticAdd(const RemapperContext& ctx, const TensorShapeProto& prot0_shape = props[0].shape(); const TensorShapeProto& prot1_shape = props[1].shape(); + if (prot0_shape.unknown_rank() || prot1_shape.unknown_rank() || prot0_shape.dim_size() < 1 || prot1_shape.dim_size() < 1 || !IsKnown(prot0_shape.dim(prot0_shape.dim_size() - 1)) || !IsKnown(prot1_shape.dim(prot1_shape.dim_size() - 1))) return false; + // Helper function to check Add/AddV2 could be replaced with BiasAdd. const auto is_supported_shape = [&](const TensorShapeProto& shape, @@ -762,6 +734,7 @@ bool IsBiasSemanticAdd(const RemapperContext& ctx, if (ShapesSymbolicallyEqual(prot0_shape, prot1_shape) || !ShapesBroadcastable(prot0_shape, prot1_shape)) return false; + if (IsConvOrMatMul(*node_def_0)) { bias_port = 1; return (is_supported_shape(prot0_shape, prot1_shape)); @@ -883,7 +856,7 @@ bool FindFusedConvWithFusedActivation(const RemapperContext& ctx, return true; } -bool FindContractionWithBiasAddAndActivation( +bool FindContractionWithBiasAndActivation( const RemapperContext& ctx, Cluster* cluster, int node_index, ContractionWithBiasAddAndActivation* matched) { const auto* node_view = ctx.graph_view.GetNode(node_index); @@ -892,9 +865,8 @@ bool FindContractionWithBiasAddAndActivation( if (HasControlFaninOrFanout(*node_view)) return false; const auto* node_def = node_view->node(); - if (!IsSupportedActivation(*node_def, cluster) && - !(IsReluSemanticMaximum(ctx, *node_view))) - return false; + if (!IsSupportedActivation(*node_def, cluster)) return false; + // And input to the activation node must match ContractionWithBiasAdd pattern. 
if (node_view->NumRegularFanins() < 1) return false; const auto& regular_fanin_0 = node_view->GetRegularFanin(0); @@ -908,6 +880,7 @@ bool FindContractionWithBiasAddAndActivation( !HaveSameDataType(node_def, bias_add_node_def) || IsInPreserveSet(ctx, bias_add_node_def)) return false; + // Get the contraction node const auto* contraction_node_view = bias_add_node_view->GetRegularFanin(1 - base.bias_port).node_view(); @@ -3428,12 +3401,7 @@ Status AddFusedContractionNode( CopyConv3DAttributes(contraction, &fused_op, &activation); } - SetFusedOpAttributes( - &fused_op, - {"BiasAdd", - IsReluSemanticMaximum(*ctx, *ctx->graph_view.GetNode(matched.activation)) - ? "Relu" - : activation.op()}); + SetFusedOpAttributes(&fused_op, {"BiasAdd", activation.op()}); utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder(); Status status; @@ -4704,7 +4672,6 @@ Status ReplaceSoftplusTanhAndMulWithMish( // (3) Fusing Conv2D biasadd and relu on GPU // (4) INTEL_MKL specific: Conv2D -> Add or Conv2D -> BiasAdd -> Add. // (5) Fusing side output and/or activation into FusedBatchNormGrad. -// (6) Fusing MatMul + AddV2 + Relu (Maximum(x, 0)) bool RequiresInferredShapes(const RemapperContext& ctx, int node_index, const Cluster* cluster) { // Candidate for a FusedBatchNorm splitting. 
@@ -4842,66 +4809,17 @@ bool RequiresInferredShapes(const RemapperContext& ctx, int node_index, return true; }; - const auto is_add_matmul_candidate = [&]() -> bool { - if (!IsAdd(*node_def) && !IsBiasAdd(*node_def)) return false; - if (node_view->NumRegularFanins() < 2) return false; - const auto& add_fanin_0 = node_view->GetRegularFanin(0); - const auto* add_fanin_0_node_view = add_fanin_0.node_view(); - const auto* add_fanin_0_node_def = add_fanin_0_node_view->node(); - const auto& add_fanin_1 = node_view->GetRegularFanin(1); - const auto* add_fanin_1_node_view = add_fanin_1.node_view(); - const auto* add_fanin_1_node_def = add_fanin_1_node_view->node(); - if (!IsMatMul(*add_fanin_0_node_def) && !IsMatMul(*add_fanin_1_node_def)) - return false; - return true; - }; - - const auto is_maximum_add_matmul_candidate = [&]() -> bool { - if (!IsMaximum(*node_def)) return false; - if (node_view->NumRegularFanins() < 2) return false; - const auto& max_fanin_0 = node_view->GetRegularFanin(0); - const auto* max_fanin_0_node_view = max_fanin_0.node_view(); - const auto* max_fanin_0_node_def = max_fanin_0_node_view->node(); - const auto& max_fanin_1 = node_view->GetRegularFanin(1); - const auto* max_fanin_1_node_view = max_fanin_1.node_view(); - const auto* max_fanin_1_node_def = max_fanin_1_node_view->node(); - const NodeDef* add_node_def = nullptr; - const utils::MutableNodeView* add_node_view = nullptr; - if (IsAdd(*max_fanin_0_node_def) || IsBiasAdd(*max_fanin_0_node_def)) { - add_node_def = max_fanin_0_node_def; - add_node_view = max_fanin_0_node_view; - } else if (IsAdd(*max_fanin_1_node_def) || - IsBiasAdd(*max_fanin_1_node_def)) { - add_node_def = max_fanin_1_node_def; - add_node_view = max_fanin_1_node_view; - } else { - return false; - } - if (add_node_view->NumRegularFanins() < 2) return false; - const auto& add_fanin_0 = add_node_view->GetRegularFanin(0); - const auto* add_fanin_0_node_view = add_fanin_0.node_view(); - const auto* add_fanin_0_node_def = 
add_fanin_0_node_view->node(); - const auto& add_fanin_1 = add_node_view->GetRegularFanin(1); - const auto* add_fanin_1_node_view = add_fanin_1.node_view(); - const auto* add_fanin_1_node_def = add_fanin_1_node_view->node(); - if (!IsMatMul(*add_fanin_0_node_def) && !IsMatMul(*add_fanin_1_node_def)) - return false; - return true; - }; - if (IsMKLEnabled()) return is_batch_norm_candidate() || is_batch_norm_fusion_candidate() || IsContractionWithAdd(ctx, node_index) || is_act_biasadd_conv_candidate() || IsBiasAdd(*node_def) || - IsTranspose(*node_def) || is_maximum_add_matmul_candidate() || - is_add_matmul_candidate(); + IsTranspose(*node_def); return is_act_biasadd_conv_candidate() || is_batch_norm_candidate() || is_batch_norm_fusion_candidate() || is_batch_norm_grad_fusion_candidate() || is_matmul_gelu_exact_fusion_candidate() || - is_act_biasadd_matmul_candidate() || - is_maximum_add_matmul_candidate() || is_add_matmul_candidate(); + is_act_biasadd_matmul_candidate(); } inline bool IsXlaCpuGlobalJitOn() { @@ -4947,6 +4865,7 @@ Status Remapper::Optimize(Cluster* cluster, const GrapplerItem& item, if (invalidated_nodes[i] || nodes_to_delete[i]) { continue; } + // Infer properties lazily in case they are not needed. if (!ctx.inferred_graph_properties && RequiresInferredShapes(ctx, i, cluster)) { @@ -5017,6 +4936,7 @@ Status Remapper::Optimize(Cluster* cluster, const GrapplerItem& item, continue; } #endif + PadWithConv3D pad_with_conv3d; // Remap Pad+{Conv3D,_FusedConv3D} into the _FusedConv3D. if (FindPadWithConv3D(ctx, i, &pad_with_conv3d)) { @@ -5186,13 +5106,12 @@ Status Remapper::Optimize(Cluster* cluster, const GrapplerItem& item, // Remap {Conv2D,DepthwiseConv2D,MatMul,Conv3D}+BiasAdd+Activation into the // _Fused{Conv2D,DepthwiseConv2dNative,MatMul,Conv3D}. 
- - ContractionWithBiasAddAndActivation contract_with_bias_add_and_activation; + ContractionWithBiasAddAndActivation contract_with_bias_and_activation; if (allow_non_differentiable_rewrites && - FindContractionWithBiasAddAndActivation( - ctx, cluster, i, &contract_with_bias_add_and_activation)) { + FindContractionWithBiasAndActivation( + ctx, cluster, i, &contract_with_bias_and_activation)) { TF_RETURN_IF_ERROR( - AddFusedContractionNode(&ctx, contract_with_bias_add_and_activation, + AddFusedContractionNode(&ctx, contract_with_bias_and_activation, &invalidated_nodes, &nodes_to_delete)); continue; } diff --git a/tensorflow/core/grappler/optimizers/remapper_test.cc b/tensorflow/core/grappler/optimizers/remapper_test.cc index 5eb9603dbf0dab..7a02b8283e752e 100644 --- a/tensorflow/core/grappler/optimizers/remapper_test.cc +++ b/tensorflow/core/grappler/optimizers/remapper_test.cc @@ -15,20 +15,8 @@ limitations under the License. #include "tensorflow/core/grappler/optimizers/remapper.h" -#include -#include -#include - -#include "absl/log/log.h" -#include "tensorflow/cc/framework/ops.h" -#include "tensorflow/cc/framework/scope.h" -#include "tensorflow/cc/ops/array_ops.h" -#include "tensorflow/cc/ops/const_op.h" -#include "tensorflow/cc/ops/math_ops.h" -#include "tensorflow/cc/ops/nn_ops.h" #include "tensorflow/cc/ops/nn_ops_internal.h" -#include "tensorflow/cc/ops/string_ops.h" -#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/framework/types.pb.h" @@ -36,13 +24,9 @@ limitations under the License. 
#include "tensorflow/core/grappler/grappler_item.h" #include "tensorflow/core/grappler/utils/graph_view.h" #include "tensorflow/core/grappler/utils/grappler_test.h" -#include "tensorflow/core/platform/env.h" -#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" -#include "tensorflow/core/platform/types.h" #include "tensorflow/core/util/util.h" -#include "tsl/lib/core/status_test_util.h" -#include "tsl/platform/status.h" #if GOOGLE_CUDA #include "third_party/gpus/cudnn/cudnn.h" @@ -2103,131 +2087,6 @@ TEST_F(RemapperFuseMatMulWithBiasAndActivationTest, Bf16) { RunTest(); // NOLINT } -class RemapperFuseMatMulWithBiasSemanticAddAndMaximumTest - : public RemapperTest { - public: - template - void RunTest() { - using ::tensorflow::ops::Placeholder; - for (const string& add_op : {"BiasAdd", "AddV2", "Add"}) { - LOG(ERROR) << "siqiaowu@add_op: " << add_op; - tensorflow::Scope s = tensorflow::Scope::NewRootScope(); - - auto input_shape = ops::Placeholder::Shape({4, 32}); - auto filter_shape_1 = ops::Placeholder::Shape({32, 8}); - auto bias_shape_1 = ops::Placeholder::Shape({8}); - auto filter_shape_2 = ops::Placeholder::Shape({8, 2}); - auto bias_shape_2 = ops::Placeholder::Shape({2}); - - auto input = Placeholder(s.WithOpName("input"), DT_FLOAT, input_shape); - auto filter_1 = - Placeholder(s.WithOpName("filter_1"), DT_FLOAT, filter_shape_1); - auto bias_1 = Placeholder(s.WithOpName("bias_1"), DT_FLOAT, bias_shape_1); - auto filter_2 = - Placeholder(s.WithOpName("filter_2"), DT_FLOAT, filter_shape_2); - auto bias_2 = Placeholder(s.WithOpName("bias_2"), DT_FLOAT, bias_shape_2); - - auto matmul_1 = ops::MatMul(s.WithOpName("matmul_1"), input, filter_1); - Output bias_add_1; - if (add_op == "BiasAdd") - bias_add_1 = ops::BiasAdd(s.WithOpName("bias_add_1"), matmul_1, bias_1); - else if (add_op == "AddV2") - bias_add_1 = ops::AddV2(s.WithOpName("bias_add_1"), matmul_1, bias_1); - else if 
(add_op == "Add") - bias_add_1 = ops::Add(s.WithOpName("bias_add_1"), bias_1, matmul_1); - - auto matmul_2 = - ops::MatMul(s.WithOpName("matmul_2"), bias_add_1, filter_2); - Output bias_add_2; - if (add_op == "BiasAdd") - bias_add_2 = ops::BiasAdd(s.WithOpName("bias_add_2"), matmul_2, bias_2); - else if (add_op == "AddV2") - bias_add_2 = ops::AddV2(s.WithOpName("bias_add_2"), matmul_2, bias_2); - else if (add_op == "Add") - bias_add_2 = ops::Add(s.WithOpName("bias_add_2"), bias_2, matmul_2); - - typedef typename EnumToDataType::Type CType; - auto zeros = ops::Const(s.WithOpName("zeros"), 0.0f, {}); - Output maximum_output = - ops::Maximum(s.WithOpName("maximum"), bias_add_2, zeros); - - auto fetch = s.WithOpName("fetch"); - - ops::Identity(fetch, maximum_output); - - auto input_tensor = GenerateRandomTensor( - TensorShape(input_shape.shape_.dim_sizes())); - auto filter_1_tensor = GenerateRandomTensor( - TensorShape(filter_shape_1.shape_.dim_sizes())); - auto bias_1_tensor = GenerateRandomTensor( - TensorShape(bias_shape_1.shape_.dim_sizes())); - auto filter_2_tensor = GenerateRandomTensor( - TensorShape(filter_shape_2.shape_.dim_sizes())); - auto bias_2_tensor = GenerateRandomTensor( - TensorShape(bias_shape_2.shape_.dim_sizes())); - - GrapplerItem item; - item.fetch = {"fetch"}; - item.feed = {{"input", input_tensor}, - {"filter_1", filter_1_tensor}, - {"bias_1", bias_1_tensor}, - {"filter_2", filter_2_tensor}, - {"bias_2", bias_2_tensor}}; - TF_CHECK_OK(s.ToGraphDef(&item.graph)); - - // Place all nodes on CPU. 
- for (int i = 0; i < item.graph.node_size(); ++i) { - item.graph.mutable_node(i)->set_device("/device:CPU:0"); - } - - Remapper optimizer(RewriterConfig::AGGRESSIVE); - GraphDef output; - TF_CHECK_OK(optimizer.Optimize(nullptr, item, &output)); - - int found = 0; - for (const NodeDef& node : output.node()) { - if (node.name() == "bias_add_1") { - EXPECT_EQ("_FusedMatMul", node.op()); - EXPECT_EQ("input", node.input(0)); - EXPECT_EQ("filter_1", node.input(1)); - EXPECT_EQ(1, node.attr().at("num_args").i()); - EXPECT_EQ("bias_1", node.input(2)); - - const auto fused_ops = node.attr().at("fused_ops").list().s(); - EXPECT_EQ(1, fused_ops.size()); - EXPECT_EQ("BiasAdd", fused_ops[0]); - found++; - } - - if (node.name() == "maximum") { - EXPECT_EQ("_FusedMatMul", node.op()); - EXPECT_EQ("bias_add_1", node.input(0)); - EXPECT_EQ("filter_2", node.input(1)); - EXPECT_EQ(1, node.attr().at("num_args").i()); - EXPECT_EQ("bias_2", node.input(2)); - - const auto fused_ops = node.attr().at("fused_ops").list().s(); - EXPECT_EQ(2, fused_ops.size()); - EXPECT_EQ("BiasAdd", fused_ops[0]); - EXPECT_EQ("Relu", fused_ops[1]); - found++; - } - } - EXPECT_EQ(2, found); - - auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed); - auto tensors = EvaluateNodes(output, item.fetch, item.feed); - EXPECT_EQ(1, tensors_expected.size()); - EXPECT_EQ(1, tensors.size()); - test::ExpectClose(tensors_expected[0], tensors[0], 0, 1e-6); - } - } -}; - -TEST_F(RemapperFuseMatMulWithBiasSemanticAddAndMaximumTest, F32) { - RunTest(); -} - TEST_F(RemapperTest, FuseConv2DWithBatchNorm) { #ifdef DNNL_AARCH64_USE_ACL GTEST_SKIP() << "Skipping test due to different behaviour on AARCH64"; diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index e30825ba05072c..1dc269212965d5 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -824,7 +824,6 @@ cc_library( hdrs = ["fused_eigen_output_kernels.h"], deps = [ "//tensorflow/core:framework", - 
"@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@eigen_archive//:eigen3", ], diff --git a/tensorflow/core/kernels/fused_eigen_output_kernels.h b/tensorflow/core/kernels/fused_eigen_output_kernels.h index 204d97d8d5c8e5..21bcf17df3e9d6 100644 --- a/tensorflow/core/kernels/fused_eigen_output_kernels.h +++ b/tensorflow/core/kernels/fused_eigen_output_kernels.h @@ -28,7 +28,6 @@ limitations under the License. #include -#include "absl/status/status.h" #include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" @@ -414,13 +413,10 @@ Status InitBiasAddArgs(OpKernelContext* context, BiasAddArgs* args, const float* leakyrelu_alpha = nullptr) { // Bias of the following dimensions: [ output_depth ] const Tensor& bias = context->input(2); - for (int i = 0; i < bias.dims() - 1; ++i) { - if (bias.dim_size(i) != 1) { - return errors::InvalidArgument( - "all dimension sizes of bias must be 1 except the last dimension: ", - bias.shape().DebugString()); - } - } + + if (bias.dims() != 1) + return errors::InvalidArgument("bias must be 1-dimensional", + bias.shape().DebugString()); const auto data_ptr = [](const Tensor& tensor) -> const T* { return reinterpret_cast(tensor.tensor_data().data()); diff --git a/tensorflow/core/kernels/mkl/BUILD b/tensorflow/core/kernels/mkl/BUILD index 3dc886b53b1032..fd1ff2abcdea16 100644 --- a/tensorflow/core/kernels/mkl/BUILD +++ b/tensorflow/core/kernels/mkl/BUILD @@ -99,7 +99,6 @@ tf_mkl_kernel_library( ], deps = [ "//tensorflow/core:graph", - "@com_google_absl//absl/container:inlined_vector", ] + MKL_DEPS, ) diff --git a/tensorflow/python/grappler/cost_analyzer_test.py b/tensorflow/python/grappler/cost_analyzer_test.py index 676e7d1a24e380..20da91ea325a1c 100644 --- a/tensorflow/python/grappler/cost_analyzer_test.py +++ b/tensorflow/python/grappler/cost_analyzer_test.py @@ -110,8 +110,17 @@ def testSmallNetworkCost(self): 
self.assertTrue(b"Conv2DBackpropFilter" in report) self.assertTrue(b"Softmax" in report) + # When mkl is enabled, Conv2D and MatMul op followed by + # 1-dimension Add in this graph will be fused, but not + # in the mkl disabled case. + expected_matmul_count = 2 op_types = [b"MatMul", b"Conv2DBackpropFilter"] + if not test_util.IsMklEnabled(): + self.assertTrue(b"Conv2D" in report) + expected_matmul_count = 3 + op_types.append(b"Conv2D") + for op_type in op_types: matcher = re.compile( br"\s+" + op_type + br",\s*(\d+),\s*(\d+),\s*([\d\.eE+-]+)%,\s*" + @@ -122,7 +131,7 @@ def testSmallNetworkCost(self): # upper = int(m.group(5)) lower = int(m.group(6)) if op_type == b"MatMul": - self.assertEqual(2, op_count) + self.assertEqual(expected_matmul_count, op_count) else: self.assertEqual(1, op_count) self.assertTrue(0 <= lower) From 40ee6456441258dacd39b9eab13a424e5e255ed2 Mon Sep 17 00:00:00 2001 From: Pauline Sho Date: Tue, 18 Jun 2024 12:43:22 -0700 Subject: [PATCH 34/59] Upstream flatbuffer utils to read big models PiperOrigin-RevId: 644480154 --- tensorflow/lite/tools/flatbuffer_utils.py | 55 ++++++++++++++++------- 1 file changed, 39 insertions(+), 16 deletions(-) diff --git a/tensorflow/lite/tools/flatbuffer_utils.py b/tensorflow/lite/tools/flatbuffer_utils.py index ce6c7d80f8892d..ad90f04b01d892 100644 --- a/tensorflow/lite/tools/flatbuffer_utils.py +++ b/tensorflow/lite/tools/flatbuffer_utils.py @@ -58,9 +58,38 @@ def read_model(input_tflite_file): raise RuntimeError('Input file not found at %r\n' % input_tflite_file) with gfile.GFile(input_tflite_file, 'rb') as input_file_handle: model_bytearray = bytearray(input_file_handle.read()) + return read_model_from_bytearray(model_bytearray) + + +def read_model_from_bytearray(model_bytearray): + """Reads a tflite model as a python object. + + Args: + model_bytearray: TFLite model in bytearray format. + + Returns: + A python object corresponding to the input tflite file. 
+ """ model = convert_bytearray_to_object(model_bytearray) if sys.byteorder == 'big': byte_swap_tflite_model_obj(model, 'little', 'big') + + # Offset handling for models > 2GB + for buffer in model.buffers: + if buffer.offset: + buffer.data = model_bytearray[buffer.offset : buffer.offset + buffer.size] + buffer.offset = 0 + buffer.size = 0 + for subgraph in model.subgraphs: + for op in subgraph.operators: + if op.largeCustomOptionsOffset: + op.customOptions = model_bytearray[ + op.largeCustomOptionsOffset : op.largeCustomOptionsOffset + + op.largeCustomOptionsSize + ] + op.largeCustomOptionsOffset = 0 + op.largeCustomOptionsSize = 0 + return model @@ -294,14 +323,10 @@ def byte_swap_buffer_content(buffer, chunksize, from_endiness, to_endiness): buffer.data[i : i + chunksize] for i in range(0, len(buffer.data), chunksize) ] - buffer.data = b''.join( - [ - int.from_bytes(byteswap, from_endiness).to_bytes( - chunksize, to_endiness - ) - for byteswap in to_swap - ] - ) + buffer.data = b''.join([ + int.from_bytes(byteswap, from_endiness).to_bytes(chunksize, to_endiness) + for byteswap in to_swap + ]) def byte_swap_string_content(buffer, from_endiness, to_endiness): @@ -314,14 +339,12 @@ def byte_swap_string_content(buffer, from_endiness, to_endiness): """ num_of_strings = int.from_bytes(buffer.data[0:4], from_endiness) string_content = bytearray(buffer.data[4 * (num_of_strings + 2) :]) - prefix_data = b''.join( - [ - int.from_bytes(buffer.data[i : i + 4], from_endiness).to_bytes( - 4, to_endiness - ) - for i in range(0, (num_of_strings + 1) * 4 + 1, 4) - ] - ) + prefix_data = b''.join([ + int.from_bytes(buffer.data[i : i + 4], from_endiness).to_bytes( + 4, to_endiness + ) + for i in range(0, (num_of_strings + 1) * 4 + 1, 4) + ]) buffer.data = prefix_data + string_content From 487d60b06fda32a01bddb4fec6dcc36476a5ab9e Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Tue, 18 Jun 2024 12:49:31 -0700 Subject: [PATCH 35/59] [XLA:GPU] Redirect all Triton normalization 
fusions to the new generic Triton emitter. This change makes the modifications necessary for this to happen. Concretely: 1. We now tag all Triton fusions with kind `__triton` in `softmax_rewriter_triton.cc`; 2. We prevent `SoftmaxRewriterTriton` from creating fusions that we won't know how to tile by calling into `SymbolicTileAnalysis` before creating fusions for a given `DiamondChainDescriptor`. This notably forces us to drop support for some previously supported cases involving *a lot* of `bitcast`s, but those cases were contrived and unlikely to show up in actual models; 3. We add missing simplifications to `SymbolicTileAnalysis` that are necessary to make all our tests pass; 4. To ensure that all our tests now go down the right path, we completely delete the dedicated SoftMax emitter. The legacy infrastructure (specifically `TritonFusionAnalysis` and its tests) still holds references to `kTritonSoftmaxFusionKind`. Since that code is all but dead for normalization fusions, we don't need to change it now---the cleanups of that part of the code will be done in upcoming CLs.
PiperOrigin-RevId: 644481882 --- third_party/xla/xla/service/gpu/BUILD | 4 + .../xla/xla/service/gpu/ir_emitter_triton.cc | 155 ------------------ .../gpu/ir_emitter_triton_large_test.cc | 2 +- .../ir_emitter_triton_parametrized_test.cc | 118 +++---------- .../gpu/model/symbolic_tile_analysis.cc | 3 + .../service/gpu/softmax_rewriter_triton.cc | 89 ++++++++-- .../xla/service/gpu/softmax_rewriter_triton.h | 4 +- .../gpu/softmax_rewriter_triton_test.cc | 26 +-- 8 files changed, 114 insertions(+), 287 deletions(-) diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index c097abd3c30c53..2bdbeb1fb17d21 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -1564,6 +1564,7 @@ cc_library( "//xla/hlo/utils:hlo_query", "//xla/service:hlo_pass", "//xla/service:instruction_fusion", + "//xla/service/gpu/model:symbolic_tile_analysis", "//xla/stream_executor:device_description", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", @@ -1571,7 +1572,9 @@ cc_library( "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@llvm-project//mlir:IR", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:statusor", @@ -2605,6 +2608,7 @@ xla_cc_test( "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:status_matchers", + "@local_tsl//tsl/platform:statusor", ], ) diff --git a/third_party/xla/xla/service/gpu/ir_emitter_triton.cc b/third_party/xla/xla/service/gpu/ir_emitter_triton.cc index d3e02d7e09237d..0dcb453d7655b0 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter_triton.cc +++ b/third_party/xla/xla/service/gpu/ir_emitter_triton.cc @@ -17,7 +17,6 @@ limitations under the License. 
#include #include -#include #include #include #include @@ -2586,155 +2585,6 @@ absl::Status EmitGeneric(mlir::OpBuilder builder, return absl::OkStatus(); } -absl::Status EmitSoftMax(mlir::OpBuilder builder, - absl::string_view libdevice_path, - const se::DeviceDescription& device_info, - const HloFusionInstruction* fusion, - mlir::triton::FuncOp fn, const BlockLevelParameters&) { - const HloComputation* computation = fusion->fused_instructions_computation(); - // TODO(b/332649307): Remove the fallback on the legacy triton analysis once - // the symbolic tile analysis can handle all cases. - TF_ASSIGN_OR_RETURN(TritonFusionAnalysis analysis, - TritonFusionAnalysis::Execute(*computation)); - - const HloInstruction* root = computation->root_instruction(); - auto loc = mlir::NameLoc::get(builder.getStringAttr(root->name())); - ImplicitLocOpBuilder b(loc, builder); - - // Assumptions we make about the matcher: - // * matches Softmax "diamonds" on the last axis, along with any number of - // elementwise operations/bitcasts on any edge - // * within a given fusion, every argument to a Softmax diamond has the same - // shape - // * every reduction is on the last axis - // * the last axis of every reduction parameter has the same length - // * reductions only reduce a single operand - // * all the shapes have canonical layout (logical layout = physical layout) - // * the computation has a single output - // * we tile along a single dimension - - // TODO(bchetioui): allow doing several rows per block (e.g. 
for when rows - // are smaller than the minimum transaction size) - - const HloInstruction* reduce = hlo_query::GetFirstInstructionWithOpcode( - *computation, HloOpcode::kReduce); - - TF_RET_CHECK(reduce != nullptr); - - Shape reduce_input_shape = reduce->operand(0)->shape(); - - TF_RET_CHECK(reduce->opcode() == HloOpcode::kReduce); - TF_RET_CHECK(reduce->dimensions().size() == 1); - TF_RET_CHECK(reduce->dimensions()[0] == reduce_input_shape.rank() - 1); - - int row_len = reduce_input_shape.dimensions_minor(0); - - Value pid = b.create( - b.getI64Type(), b.create(mt::ProgramIDDim::X)); - Value row_stride = CreateConst(b, b.getI32Type(), row_len); - - Value row_offset = b.create( - pid, b.create(b.getI64Type(), row_stride)); - Value zero_offset = CreateConst(b, b.getI64Type(), 0); - - absl::flat_hash_map values_out; - std::vector boundary_checks; - - // block_size must be a power of two. - int result_block_size = pow(2, ceil(log(row_len) / log(2))); - - if (result_block_size != row_len) { - boundary_checks.push_back(0); - } - - // Emits load instructions - for (int param_idx = 0; param_idx < computation->num_parameters(); - ++param_idx) { - HloInstruction* param = computation->parameter_instruction(param_idx); - // Current tiling derivation assigns index 0 to the reduction dimension and - // index 1 to the batch dimension. - auto reduce_iterspec = analysis.IterSpec( - TritonFusionAnalysis::Scope::OUTPUT, param, /*dimension=*/0); - auto batch_iterspec = analysis.IterSpec(TritonFusionAnalysis::Scope::OUTPUT, - param, /*dimension=*/1); - - // Make sure only batch and reduce dims are present in tiling - TF_RET_CHECK(analysis.IterSpec(TritonFusionAnalysis::Scope::OUTPUT, param, - /*dimension=*/2) == nullptr); - - if (!reduce_iterspec) { - // This parameter's broadcast is along the reduce dimension, and so - // each pid uses and broadcasts its own index. 
- - // If batchDimIterSpec is also not present, then this parameter is a - // scalar, in which case we reuse this for each pid with offset. - Value batch_offset = batch_iterspec ? pid : zero_offset; - - values_out[param] = EmitParameterLoad( - b, AddPtr(b, fn.getArgument(param_idx), batch_offset), - boundary_checks); - continue; - } - - TF_RET_CHECK(reduce_iterspec != nullptr); - TF_RET_CHECK(reduce_iterspec->size() == 1); - - // TODO(b/310721908): The below assumes that we tile along a single dim. - int reduce_dim_len = reduce_iterspec->front().count; - int reduce_dim_stride = reduce_iterspec->front().stride; - int slice_offset = reduce_iterspec->front().slice_start; - - // If the batch dimension is present in this parameter's tile, we must make - // sure each batch idx is offset by the correct number of rows. If it is not - // present, then the reduce dim data is reused without any offset. - Value base_offset = batch_iterspec ? row_offset : zero_offset; - - // We assume that the reduced axis of this parameter has length row_len. - // TODO(b/316637896): Relax assumption that param reduce_dim_len == row_len. - TF_RET_CHECK(reduce_dim_len == row_len); - - // block_size must be a power of two. - int block_size = pow(2, ceil(log(reduce_dim_len) / log(2))); - - // Verify that this param contains a single contiguous fragment. - TF_RET_CHECK(reduce_iterspec->front().subfragments.size() == 1); - - Value emitted_tensor = b.create( - /*base=*/AddPtr(b, fn.getArgument(param_idx), base_offset), - /*shape=*/ValueRange{CreateConst(b, b.getI64Type(), reduce_dim_len)}, - /*strides=*/ - ValueRange{CreateConst(b, b.getI64Type(), reduce_dim_stride)}, - /*offsets=*/ValueRange{CreateConst(b, b.getI32Type(), slice_offset)}, - /*tensorShape=*/std::vector{block_size}, - /*order=*/std::vector{0}); - - values_out[param] = EmitParameterLoad(b, emitted_tensor, boundary_checks); - } - - // Dimension 0 is the reduced one by construction and it's the only one - // present in the tile shapes. 
- std::vector tiled_dims = {DimProperties( - /*index=*/0, pid, result_block_size, /*split_value=*/1)}; - TF_ASSIGN_OR_RETURN( - Value result, - EmitScope(b, libdevice_path, device_info, &analysis, - TritonFusionAnalysis::Scope::OUTPUT, tiled_dims, - computation->MakeInstructionPostOrder(), values_out)); - - Value store_tensor = b.create( - /*base=*/AddPtr(b, fn.getArgument(computation->num_parameters()), - row_offset), - /*shape=*/ValueRange{CreateConst(b, b.getI64Type(), row_len)}, - /*strides=*/ValueRange{CreateConst(b, b.getI64Type(), 1)}, - /*offsets=*/ValueRange{CreateConst(b, b.getI32Type(), 0)}, - /*tensorShape=*/std::vector{result_block_size}, - /*order=*/std::vector{0}); - - b.create(store_tensor, result, std::vector{0}, - mt::CacheModifier::NONE, mt::EvictionPolicy::NORMAL); - return absl::OkStatus(); -} - void LoadMlirDialectsForTriton(mlir::MLIRContext& mlir_context) { mlir_context.loadDialect< mt::TritonDialect, mt::gpu::TritonGPUDialect, mlir::arith::ArithDialect, @@ -2833,11 +2683,6 @@ absl::StatusOr> CreateTritonModule( if (fusion_kind == kTritonGemmFusionKind) { TF_RETURN_IF_ERROR(EmitMatMul(b, libdevice_path, device_info, fusion, fn, block_level_parameters)); - } else if (fusion_kind == kTritonSoftmaxFusionKind) { - TF_ASSIGN_OR_RETURN(TritonFusionAnalysis analysis, - TritonFusionAnalysis::Execute(*hlo_computation)); - TF_RETURN_IF_ERROR(EmitSoftMax(b, libdevice_path, device_info, fusion, fn, - block_level_parameters)); } else if (fusion_kind == kTritonFusionKind) { TF_RETURN_IF_ERROR(EmitGeneric(b, libdevice_path, device_info, fusion, fn, block_level_parameters)); diff --git a/third_party/xla/xla/service/gpu/ir_emitter_triton_large_test.cc b/third_party/xla/xla/service/gpu/ir_emitter_triton_large_test.cc index 05a0b78a6d876a..039b0d2d1863c4 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter_triton_large_test.cc +++ b/third_party/xla/xla/service/gpu/ir_emitter_triton_large_test.cc @@ -154,7 +154,7 @@ ENTRY main { ; CHECK: ROOT ; CHECK-SAME: 
fusion(%[[P0]]) ; CHECK-SAME: kind=kCustom -; CHECK-SAME: __triton_softmax +; CHECK-SAME: __triton )"); // Checking that this does not crash should be enough. diff --git a/third_party/xla/xla/service/gpu/ir_emitter_triton_parametrized_test.cc b/third_party/xla/xla/service/gpu/ir_emitter_triton_parametrized_test.cc index e03def6246566c..146b53547a1852 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter_triton_parametrized_test.cc +++ b/third_party/xla/xla/service/gpu/ir_emitter_triton_parametrized_test.cc @@ -838,7 +838,7 @@ ENTRY main { ; CHECK: ROOT ; CHECK-SAME: fusion(%[[param_0]]) ; CHECK-SAME: kind=kCustom -; CHECK-SAME: __triton_softmax +; CHECK-SAME: __triton )"; const std::string hlo_ref = absl::Substitute( @@ -892,7 +892,7 @@ ENTRY main { ; CHECK: ROOT ; CHECK-SAME: fusion(%[[P0]]) ; CHECK-SAME: kind=kCustom -; CHECK-SAME: __triton_softmax +; CHECK-SAME: __triton )"; const std::string hlo_ref = absl::Substitute( @@ -943,7 +943,7 @@ ENTRY main { ; CHECK: ROOT ; CHECK-SAME: fusion(%[[param_0]]) ; CHECK-SAME: kind=kCustom -; CHECK-SAME: __triton_softmax +; CHECK-SAME: __triton )"; const std::string hlo_ref = absl::Substitute( @@ -978,7 +978,7 @@ ENTRY main { ; CHECK: ROOT ; CHECK-SAME: fusion(%[[P0_ENTRY]]) ; CHECK-SAME: kind=kCustom -; CHECK-SAME: __triton_softmax +; CHECK-SAME: __triton )"; MatchOptimizedHlo(hlo_text, hlo_ref); @@ -988,78 +988,6 @@ ENTRY main { ErrorSpec(/*aabs=*/tolerance, /*arel=*/tolerance))); } -TEST_P( - TritonSoftmaxTest, - CanFuseAndEmitSoftmaxWithBatchDimMergingAndSplittingBitcastsOnEveryEdge) { - PrimitiveType data_type = GetParam(); - - if (data_type == F16) { - GTEST_SKIP() << "Exponential op does not support F16."; - } - - const std::string hlo_text_template = R"( -HloModule softmax -max_computation { - arg_0 = $0[] parameter(0) - arg_1 = $0[] parameter(1) - ROOT maximum = $0[] maximum(arg_0, arg_1) -} -add_computation { - arg_0.1 = $0[] parameter(0) - arg_1.1 = $0[] parameter(1) - ROOT add = $0[] add(arg_0.1, arg_1.1) -} - 
-ENTRY main { - param_0 = $0[2,65,125] parameter(0) - bitcasted_param_0 = $0[65,2,125] reshape(param_0) - constant_neg_inf = $0[] constant(-inf) - reduce = $0[65,2]{1,0} reduce(bitcasted_param_0, constant_neg_inf), dimensions={2}, to_apply=max_computation - bitcasted_reduce = $0[130] reshape(reduce) - broadcast = $0[130,125]{1,0} broadcast(bitcasted_reduce), dimensions={0} - bitcasted_broadcast = $0[65,2,125] reshape(broadcast) - subtract = $0[65,2,125]{2,1,0} subtract(bitcasted_param_0, bitcasted_broadcast) - bitcasted_subtract = $0[130,125] reshape(subtract) - exponential = $0[130,125]{1,0} exponential(bitcasted_subtract) - constant_zero = $0[] constant(0) - bitcasted_exponential = $0[2,65,125] reshape(exponential) - second_reduce = $0[2,65]{1,0} reduce(bitcasted_exponential, constant_zero), dimensions={2}, to_apply=add_computation - second_bitcasted_reduce = $0[130] reshape(second_reduce) - second_broadcast = $0[130,125]{1,0} broadcast(second_bitcasted_reduce), dimensions={0} - second_bitcasted_broadcast = $0[2,65,125] reshape(second_broadcast) - ROOT divide = $0[2,65,125]{2,1,0} divide(bitcasted_exponential, second_bitcasted_broadcast) -})"; - const std::string hlo_text = absl::Substitute( - hlo_text_template, primitive_util::LowercasePrimitiveTypeName(data_type)); - const std::string hlo_ref_template = R"( -; CHECK: ENTRY -; CHECK: %[[P0:.*]] = $0[2,65,125]{2,1,0} parameter(0) -; CHECK: ROOT -; CHECK-SAME: fusion(%[[P0]]) -; CHECK-SAME: kind=kCustom -; CHECK-SAME: __triton_softmax -)"; - - const std::string hlo_ref = absl::Substitute( - hlo_ref_template, primitive_util::LowercasePrimitiveTypeName(data_type)); - - MatchOptimizedHlo(hlo_text, hlo_ref); - - float tolerance; - switch (data_type) { - case F32: - tolerance = 1e-6; - break; - case BF16: - tolerance = 2e-4; - break; - default: - ABSL_UNREACHABLE(); - } - EXPECT_TRUE(RunAndCompare(hlo_text, - ErrorSpec(/*aabs=*/tolerance, /*arel=*/tolerance))); -} - TEST_P(TritonSoftmaxTest, 
CanFuseAndEmitDiamondWithMultipleBroadcastDimensions) { PrimitiveType data_type = GetParam(); @@ -1088,7 +1016,7 @@ ENTRY main { ; CHECK: ROOT ; CHECK-SAME: fusion(%[[P0]]) ; CHECK-SAME: kind=kCustom -; CHECK-SAME: __triton_softmax +; CHECK-SAME: __triton )"; const std::string hlo_ref = absl::Substitute( @@ -1156,7 +1084,7 @@ ENTRY main { ; CHECK: ROOT ; CHECK-SAME: fusion(%[[P0]]) ; CHECK-SAME: kind=kCustom -; CHECK-SAME: __triton_softmax +; CHECK-SAME: __triton )"; const std::string hlo_ref = absl::Substitute( @@ -1217,7 +1145,7 @@ ENTRY main { ; CHECK: ROOT ; CHECK-SAME: fusion(%[[P0]]) ; CHECK-SAME: kind=kCustom -; CHECK-SAME: __triton_softmax +; CHECK-SAME: __triton )"; const std::string hlo_ref = absl::Substitute( @@ -1272,7 +1200,7 @@ ENTRY main { ; CHECK: ROOT ; CHECK-SAME: fusion(%[[P0]]) ; CHECK-SAME: kind=kCustom -; CHECK-SAME: __triton_softmax +; CHECK-SAME: __triton )"; const std::string hlo_ref = absl::Substitute( @@ -1326,7 +1254,7 @@ ENTRY main { ; CHECK: ROOT ; CHECK-SAME: fusion(%[[P0]]) ; CHECK-SAME: kind=kCustom -; CHECK-SAME: __triton_softmax +; CHECK-SAME: __triton )"; const std::string hlo_ref = absl::Substitute( @@ -1383,7 +1311,7 @@ ENTRY main { ; CHECK: ROOT ; CHECK-SAME: fusion(%[[P0]]) ; CHECK-SAME: kind=kCustom -; CHECK-SAME: __triton_softmax +; CHECK-SAME: __triton )"; const std::string hlo_ref = absl::Substitute( @@ -1438,7 +1366,7 @@ ENTRY main { ; CHECK: ROOT ; CHECK-SAME: fusion(%[[P0]]) ; CHECK-SAME: kind=kCustom -; CHECK-SAME: __triton_softmax +; CHECK-SAME: __triton )"; MatchOptimizedHlo(hlo_text, hlo_ref); @@ -1491,7 +1419,7 @@ ENTRY main { ; CHECK: ROOT ; CHECK-SAME: fusion(%[[P0]]) ; CHECK-SAME: kind=kCustom -; CHECK-SAME: __triton_softmax +; CHECK-SAME: __triton )"; const std::string hlo_ref = absl::Substitute( @@ -1547,7 +1475,7 @@ ENTRY main { ; CHECK: ROOT ; CHECK-SAME: fusion(%[[P0]]) ; CHECK-SAME: kind=kCustom -; CHECK-SAME: __triton_softmax +; CHECK-SAME: __triton )"; const std::string hlo_ref = absl::Substitute( @@ 
-1606,7 +1534,7 @@ ENTRY main { ; CHECK: ROOT ; CHECK-SAME: fusion(%[[P0]]) ; CHECK-SAME: kind=kCustom -; CHECK-SAME: __triton_softmax +; CHECK-SAME: __triton )"; const std::string hlo_ref = absl::Substitute( @@ -1674,7 +1602,7 @@ ENTRY main { ; CHECK: ROOT ; CHECK-SAME: fusion(%[[P0]]) ; CHECK-SAME: kind=kCustom -; CHECK-SAME: __triton_softmax +; CHECK-SAME: __triton )"; const std::string hlo_ref = absl::Substitute( @@ -1741,7 +1669,7 @@ ENTRY main { ; CHECK: ROOT ; CHECK-SAME: fusion(%[[P0]]) ; CHECK-SAME: kind=kCustom -; CHECK-SAME: __triton_softmax +; CHECK-SAME: __triton )"; const std::string hlo_ref = absl::Substitute( @@ -1841,7 +1769,7 @@ ENTRY main.30 { ; CHECK: ROOT ; CHECK-SAME: fusion(%[[P0]]) ; CHECK-SAME: kind=kCustom -; CHECK-SAME: __triton_softmax +; CHECK-SAME: __triton )"; const std::string hlo_ref = absl::Substitute( @@ -1904,7 +1832,7 @@ ENTRY main { ; CHECK: ROOT ; CHECK-SAME: fusion(%[[P0]]) ; CHECK-SAME: kind=kCustom -; CHECK-SAME: __triton_softmax +; CHECK-SAME: __triton )"; const std::string hlo_ref = absl::Substitute( @@ -1967,7 +1895,7 @@ ENTRY main { ; CHECK: ROOT ; CHECK-SAME: fusion(%[[P0]]) ; CHECK-SAME: kind=kCustom -; CHECK-SAME: __triton_softmax +; CHECK-SAME: __triton )"; const std::string hlo_ref = absl::Substitute( @@ -2026,7 +1954,7 @@ ENTRY main { ; CHECK: ROOT ; CHECK-SAME: fusion(%[[P0]]) ; CHECK-SAME: kind=kCustom -; CHECK-SAME: __triton_softmax +; CHECK-SAME: __triton )"; const std::string hlo_ref = absl::Substitute( @@ -2085,7 +2013,7 @@ ENTRY main { ; CHECK: ROOT ; CHECK-SAME: fusion(%[[P0]]) ; CHECK-SAME: kind=kCustom -; CHECK-SAME: __triton_softmax +; CHECK-SAME: __triton )"; const std::string hlo_ref = absl::Substitute( hlo_ref_template, primitive_util::LowercasePrimitiveTypeName(data_type)); @@ -2144,7 +2072,7 @@ ENTRY main { ; CHECK: ROOT ; CHECK-SAME: fusion(%[[P0]]) ; CHECK-SAME: kind=kCustom -; CHECK-SAME: __triton_softmax +; CHECK-SAME: __triton )"; const std::string hlo_ref = absl::Substitute( hlo_ref_template, 
primitive_util::LowercasePrimitiveTypeName(data_type)); @@ -2203,7 +2131,7 @@ ENTRY main { ; CHECK: ROOT ; CHECK-SAME: fusion(%[[P0]]) ; CHECK-SAME: kind=kCustom -; CHECK-SAME: __triton_softmax +; CHECK-SAME: __triton )"; const std::string hlo_ref = absl::Substitute( hlo_ref_template, primitive_util::LowercasePrimitiveTypeName(data_type)); diff --git a/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis.cc b/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis.cc index 59fafc442f80d2..b2d0847850114b 100644 --- a/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis.cc +++ b/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis.cc @@ -247,6 +247,9 @@ absl::StatusOr ComputeBlockIdToTileOffsetIndexing( IndexingMap operand_indexing_map = ComposeIndexingMaps(tiled_hlo_instruction->indexing_map(), *operand_indexing_map_set.begin()); + operand_indexing_map.Simplify(); + operand_indexing_map.RescaleSymbols(); + operand_indexing_map.RemoveUnusedSymbols(); auto tiled_operand_or = get_tiled_hlo_instruction(operand, std::move(operand_indexing_map)); diff --git a/third_party/xla/xla/service/gpu/softmax_rewriter_triton.cc b/third_party/xla/xla/service/gpu/softmax_rewriter_triton.cc index fc3c08ca2aae8f..d270ad1fbe1a66 100644 --- a/third_party/xla/xla/service/gpu/softmax_rewriter_triton.cc +++ b/third_party/xla/xla/service/gpu/softmax_rewriter_triton.cc @@ -25,6 +25,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" +#include "mlir/IR/MLIRContext.h" // from @llvm-project #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -35,6 +36,7 @@ limitations under the License. 
#include "xla/layout_util.h" #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/ir_emission_utils.h" +#include "xla/service/gpu/model/symbolic_tile_analysis.h" #include "xla/service/gpu/triton_support.h" #include "xla/service/instruction_fusion.h" #include "xla/shape.h" @@ -327,7 +329,15 @@ HloInstruction* FindFirstNonFusibleDiamondProducer( return diamond_producer; } -absl::Status FuseDiamondChainImpl(const DiamondChainDescriptor& diamond_chain) { +// Creates a fusion corresponding to the input diamond chain. The resulting +// fusion instruction is added to the module, but is not yet inserted into the +// graph as a replacement of the original instructions. +// +// TODO(b/347956491): this awkward abstraction is needed to work around +// limitations of HloFusionAdaptor, which underpins the implementation of +// SymbolicTileAnalysis. We need to come up with a better solution. +absl::StatusOr MakeFusionForDiamondChain( + const DiamondChainDescriptor& diamond_chain) { auto [root, producer] = diamond_chain; std::string suggested_name = "triton_softmax"; @@ -366,6 +376,7 @@ absl::Status FuseDiamondChainImpl(const DiamondChainDescriptor& diamond_chain) { } }; create_computation(root); + HloComputation* computation = root->GetModule()->AddComputationAndUnifyNamesAndIds(builder.Build(), /*is_entry=*/false); @@ -376,14 +387,20 @@ absl::Status FuseDiamondChainImpl(const DiamondChainDescriptor& diamond_chain) { computation)); softmax_fusion->GetModule()->SetAndUniquifyInstrName(softmax_fusion, - suggested_name); - + "triton_softmax"); TF_ASSIGN_OR_RETURN(auto gpu_config, softmax_fusion->backend_config()); FusionBackendConfig& backend_config = *gpu_config.mutable_fusion_backend_config(); - backend_config.set_kind(std::string(kTritonSoftmaxFusionKind)); + backend_config.set_kind(std::string(kTritonFusionKind)); TF_RETURN_IF_ERROR(softmax_fusion->set_backend_config(gpu_config)); + return xla::Cast(softmax_fusion); +} + +absl::Status 
FuseDiamondChainImpl(const DiamondChainDescriptor& diamond_chain) { + TF_ASSIGN_OR_RETURN(HloFusionInstruction * softmax_fusion, + MakeFusionForDiamondChain(diamond_chain)); + HloInstruction* root = diamond_chain.root; if (root->IsRoot()) { root->parent()->set_root_instruction(softmax_fusion); @@ -398,7 +415,27 @@ absl::Status FuseDiamondChainImpl(const DiamondChainDescriptor& diamond_chain) { return absl::OkStatus(); } -using DiamondDescriptor = DiamondChainDescriptor; +// Returns `true` if the diamond chain passed as a parameter can be tiled +// correctly using `SymbolicTileAnalysis`. +absl::StatusOr CanSymbolicTileAnalysisTileDiamondChain( + const DiamondChainDescriptor& diamond_chain) { + TF_ASSIGN_OR_RETURN(HloFusionInstruction * softmax_fusion, + MakeFusionForDiamondChain(diamond_chain)); + mlir::MLIRContext context; + SymbolicTileAnalysisOrError symbolic_tile_analysis_or_error = + SymbolicTileAnalysis::AnalyzeComputation( + *softmax_fusion->called_computation(), &context); + + bool can_tile = std::holds_alternative( + symbolic_tile_analysis_or_error); + + TF_RETURN_IF_ERROR(diamond_chain.root->GetModule()->RemoveEmbeddedComputation( + softmax_fusion->called_computation())); + TF_RETURN_IF_ERROR( + diamond_chain.root->parent()->RemoveInstruction(softmax_fusion)); + + return can_tile; +} } // anonymous namespace @@ -477,11 +514,11 @@ SoftmaxRewriterTriton::MatchesTritonCompatibleClosedReductionDiamond( return producer; } -std::vector +absl::StatusOr> SoftmaxRewriterTriton::FindAllFusibleDiamondChains( HloModule& module, const absl::flat_hash_set& execution_threads) const { - std::vector matched_diamonds; + std::vector matched_diamonds; for (HloComputation* comp : module.MakeNonfusionComputations(execution_threads)) { @@ -499,10 +536,22 @@ SoftmaxRewriterTriton::FindAllFusibleDiamondChains( auto producer = MatchesTritonCompatibleClosedReductionDiamond(instr); if (std::holds_alternative(producer)) { - matched_diamonds.push_back(DiamondDescriptor{ - instr, - 
std::get(producer), - }); + DiamondChainDescriptor diamond_chain{ + /*root=*/instr, /*producer=*/std::get(producer)}; + // We filter out the diamond chains that cannot be tiled correctly using + // `SymbolicTileAnalysis`. + TF_ASSIGN_OR_RETURN( + bool can_tile_diamond_chain, + CanSymbolicTileAnalysisTileDiamondChain(diamond_chain)); + if (can_tile_diamond_chain) { + matched_diamonds.push_back(diamond_chain); + } else { + VLOG(5) << "Cannot tile the diamond pattern described by " + << "instructions " << instr->ToString() << " and " + << std::get(producer)->ToString() << "."; + continue; + } + } else { VLOG(5) << "Cannot match the diamond pattern for instruction " << instr->ToString() @@ -512,7 +561,7 @@ SoftmaxRewriterTriton::FindAllFusibleDiamondChains( } if (matched_diamonds.empty()) { - return {}; + return std::vector(); } auto reduction_dimension_size_from_diamond_root = @@ -614,7 +663,17 @@ SoftmaxRewriterTriton::FindAllFusibleDiamondChains( last_trivially_fusible_user(matched_diamonds.back().root), current_fusion_producer}); - return diamond_chains; + // We filter out the diamond chains that cannot be tiled correctly using + // `SymbolicTileAnalysis`. 
+ std::vector filtered_diamond_chains; + for (const DiamondChainDescriptor& diamond_chain : diamond_chains) { + TF_ASSIGN_OR_RETURN(bool can_tile_diamond_chain, + CanSymbolicTileAnalysisTileDiamondChain(diamond_chain)); + if (can_tile_diamond_chain) { + filtered_diamond_chains.push_back(diamond_chain); + } + } + return filtered_diamond_chains; } absl::Status SoftmaxRewriterTriton::FuseDiamondChain( @@ -638,8 +697,8 @@ absl::StatusOr SoftmaxRewriterTriton::Run( cuda_compute_capability->minor, ".")); } - std::vector diamond_chains = - FindAllFusibleDiamondChains(*module, execution_threads); + TF_ASSIGN_OR_RETURN(std::vector diamond_chains, + FindAllFusibleDiamondChains(*module, execution_threads)); if (diamond_chains.empty()) { return false; diff --git a/third_party/xla/xla/service/gpu/softmax_rewriter_triton.h b/third_party/xla/xla/service/gpu/softmax_rewriter_triton.h index 44d32c5717b709..b70f9f56a03334 100644 --- a/third_party/xla/xla/service/gpu/softmax_rewriter_triton.h +++ b/third_party/xla/xla/service/gpu/softmax_rewriter_triton.h @@ -21,6 +21,7 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" @@ -54,7 +55,8 @@ class SoftmaxRewriterTriton : public HloModulePass { // Finds and returns all the fusible diamond chains in the module. The // resulting vector is sorted according to a post-order matching (i.e. within // the same computation, producer diamonds appear before consumer diamonds). 
- std::vector FindAllFusibleDiamondChains( + absl::StatusOr> + FindAllFusibleDiamondChains( HloModule& module, const absl::flat_hash_set& execution_threads) const; diff --git a/third_party/xla/xla/service/gpu/softmax_rewriter_triton_test.cc b/third_party/xla/xla/service/gpu/softmax_rewriter_triton_test.cc index b6730212efb583..70f445de4aac42 100644 --- a/third_party/xla/xla/service/gpu/softmax_rewriter_triton_test.cc +++ b/third_party/xla/xla/service/gpu/softmax_rewriter_triton_test.cc @@ -36,6 +36,7 @@ limitations under the License. #include "tsl/lib/core/status_test_util.h" #include "tsl/platform/errors.h" #include "tsl/platform/status_matchers.h" +#include "tsl/platform/statusor.h" namespace xla { namespace gpu { @@ -52,9 +53,9 @@ absl::StatusOr SoftmaxRewriterTritonMatchAndRewrite( se::GpuComputeCapability gpu_version, HloModule* module) { CHECK_NE(module, nullptr); SoftmaxRewriterTriton softmax_rewriter_triton(gpu_version); - std::vector diamond_chains = - softmax_rewriter_triton.FindAllFusibleDiamondChains( - *module, /*execution_threads=*/{}); + TF_ASSIGN_OR_RETURN(std::vector diamond_chains, + softmax_rewriter_triton.FindAllFusibleDiamondChains( + *module, /*execution_threads=*/{})); for (auto diamond_chain = diamond_chains.rbegin(); diamond_chain != diamond_chains.rend(); ++diamond_chain) { @@ -229,7 +230,7 @@ ENTRY main { } TEST_P(SoftmaxRewriterTritonTest, - CanFuseSoftmaxWithBatchDimMergingAndSplittingBitcastsOnEveryEdge) { + CanNotFuseSoftmaxWhenResultingComputationCanNotBeTiledCorrectly) { PrimitiveType data_type = GetParam(); const std::string hlo_string_template = R"( HloModule softmax @@ -270,23 +271,8 @@ ENTRY main { auto module = ParseAndReturnVerifiedModule(hlo_string).value(); - EXPECT_TRUE( + EXPECT_FALSE( SoftmaxRewriterTritonMatchAndRewrite(gpu_version_, module.get()).value()); - EXPECT_TRUE(verifier().Run(module.get()).status().ok()); - - switch (data_type) { - case F32: - case BF16: - 
EXPECT_THAT(module->entry_computation()->root_instruction(), - GmockMatch(m::Fusion(m::Parameter()))); - break; - case F16: - EXPECT_THAT(module->entry_computation()->root_instruction(), - GmockMatch(m::Bitcast(m::Divide()))); - break; - default: - ABSL_UNREACHABLE(); - } } TEST_P(SoftmaxRewriterTritonTest, CanNotFuseSoftmaxDiamondWithWrongLayout) { From 8f025afe8ee3ba20a9965e135baf8494460f1ba4 Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Tue, 18 Jun 2024 13:11:40 -0700 Subject: [PATCH 36/59] [XLA:GPU] Remove trailing references to deprecated `kTritonSoftmaxFusionKind`. PiperOrigin-RevId: 644488604 --- .../xla/xla/service/gpu/fusions/triton.cc | 6 +----- .../xla/xla/service/gpu/fusions/triton_test.cc | 2 +- third_party/xla/xla/service/gpu/gpu_fusible.cc | 4 ++-- third_party/xla/xla/service/gpu/gpu_fusible.h | 4 ++-- .../xla/xla/service/gpu/hlo_fusion_analysis.cc | 4 +--- .../xla/xla/service/gpu/ir_emission_utils.h | 5 ----- .../gpu/model/gpu_performance_model_base_test.cc | 2 +- .../gpu/model/symbolic_tile_analysis_test.cc | 2 +- .../xla/xla/service/gpu/priority_fusion.cc | 6 +++--- .../xla/xla/service/gpu/priority_fusion_test.cc | 4 ++-- .../service/gpu/triton_fusion_analysis_test.cc | 16 ++++++++-------- .../gpu/triton_fusion_numerics_verifier.cc | 3 +-- 12 files changed, 23 insertions(+), 35 deletions(-) diff --git a/third_party/xla/xla/service/gpu/fusions/triton.cc b/third_party/xla/xla/service/gpu/fusions/triton.cc index e9f14aa86546c7..d983ab056ba213 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton.cc +++ b/third_party/xla/xla/service/gpu/fusions/triton.cc @@ -167,8 +167,7 @@ absl::StatusOr TritonFusion::Emit( TritonWrapperResult triton_wrapper_result; LaunchDimensions launch_dimensions; - if (fusion_kind == kTritonFusionKind || - fusion_kind == kTritonSoftmaxFusionKind) { + if (fusion_kind == kTritonFusionKind) { auto launch_config = *this->launch_config(); launch_dimensions = launch_config.launch_dimensions; @@ -290,9 +289,6 @@ 
std::optional TritonFusion::launch_config() const { // - 1 grid dimension corresponds to all batch dimensions in the HLO. // - 1-2 grid dimension corresponds to block-able dimensions from the HLO. return CalculateSoftMaxLaunchConfig(analysis_.fusion()); - } else if (analysis_.fusion_backend_config().kind() == - kTritonSoftmaxFusionKind) { - return CalculateSoftMaxLaunchConfig(analysis_.fusion()); } // MatMul is not yet supported. diff --git a/third_party/xla/xla/service/gpu/fusions/triton_test.cc b/third_party/xla/xla/service/gpu/fusions/triton_test.cc index abe8355dadd9f5..f714d2df9ebf6e 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/triton_test.cc @@ -64,7 +64,7 @@ TEST_F(TritonFusionTest, TritonSoftmaxFusion) { ENTRY main { param_0 = f32[125]{0} parameter(0) auxiliary_fusion = f32[125,127]{1,0} fusion(param_0), kind=kLoop, calls=auxiliary_computation - ROOT triton_softmax = f32[125,127]{1,0} fusion(auxiliary_fusion), kind=kCustom, calls=triton_softmax_computation, backend_config={"fusion_backend_config":{"kind":"__triton_softmax"}} + ROOT triton_softmax = f32[125,127]{1,0} fusion(auxiliary_fusion), kind=kCustom, calls=triton_softmax_computation, backend_config={"fusion_backend_config":{"kind":"__triton"}} })") .value(); diff --git a/third_party/xla/xla/service/gpu/gpu_fusible.cc b/third_party/xla/xla/service/gpu/gpu_fusible.cc index 848f9680498e16..b445f35feb11e8 100644 --- a/third_party/xla/xla/service/gpu/gpu_fusible.cc +++ b/third_party/xla/xla/service/gpu/gpu_fusible.cc @@ -938,14 +938,14 @@ std::vector GetFusionRoots( return out; } -bool IsTritonSoftmaxFusion(const HloInstruction& instr) { +bool IsGenericTritonFusion(const HloInstruction& instr) { // TODO(b/332649307): Eventually turn this into a generic fusion. 
return instr.opcode() == HloOpcode::kFusion && instr.fusion_kind() == HloInstruction::FusionKind::kCustom && instr.backend_config().ok() && instr.backend_config() ->fusion_backend_config() - .kind() == kTritonSoftmaxFusionKind; + .kind() == kTritonFusionKind; } bool MayPreventVectorization(const HloFusionAdaptor& fusion) { diff --git a/third_party/xla/xla/service/gpu/gpu_fusible.h b/third_party/xla/xla/service/gpu/gpu_fusible.h index 7971e69e05056f..7bc120b7574529 100644 --- a/third_party/xla/xla/service/gpu/gpu_fusible.h +++ b/third_party/xla/xla/service/gpu/gpu_fusible.h @@ -215,8 +215,8 @@ size_t GetOutputSizeOfFusible(const HloInstruction& instr); std::vector GetFusionRoots( const HloComputation& computation); -// Whether the instruction is a Triton Softmax fusion. -bool IsTritonSoftmaxFusion(const HloInstruction& instr); +// Whether the instruction is a generic Triton fusion. +bool IsGenericTritonFusion(const HloInstruction& instr); // Whether the fusion will likely behave poorly with vectorization due to the // instructions it contains. diff --git a/third_party/xla/xla/service/gpu/hlo_fusion_analysis.cc b/third_party/xla/xla/service/gpu/hlo_fusion_analysis.cc index c872f8e5878f14..a027e85ab5cd3a 100644 --- a/third_party/xla/xla/service/gpu/hlo_fusion_analysis.cc +++ b/third_party/xla/xla/service/gpu/hlo_fusion_analysis.cc @@ -39,7 +39,6 @@ limitations under the License. 
#include "xla/service/gpu/reduction_utils.h" #include "xla/shape.h" #include "xla/shape_util.h" -#include "xla/statusor.h" #include "xla/stream_executor/device_description.h" namespace xla { @@ -209,8 +208,7 @@ HloFusionAnalysis::EmitterFusionKind HloFusionAnalysis::GetEmitterFusionKind() } if (fusion_backend_config_.kind() == kTritonFusionKind || - fusion_backend_config_.kind() == kTritonGemmFusionKind || - fusion_backend_config_.kind() == kTritonSoftmaxFusionKind) { + fusion_backend_config_.kind() == kTritonGemmFusionKind) { return EmitterFusionKind::kTriton; } diff --git a/third_party/xla/xla/service/gpu/ir_emission_utils.h b/third_party/xla/xla/service/gpu/ir_emission_utils.h index 99109b54fb4169..73e640eda511d4 100644 --- a/third_party/xla/xla/service/gpu/ir_emission_utils.h +++ b/third_party/xla/xla/service/gpu/ir_emission_utils.h @@ -71,11 +71,6 @@ inline constexpr absl::string_view kTritonFusionKind = "__triton"; // Fusions that use Triton have FusionBackendConfig.kind equal to this string. inline constexpr absl::string_view kTritonGemmFusionKind = "__triton_gemm"; -// SoftmaxRewriterTriton sets backend_config of Triton Softmax custom fusions to -// this string. 
-inline constexpr absl::string_view kTritonSoftmaxFusionKind = - "__triton_softmax"; - inline constexpr absl::string_view kCuDnnFusionKind = "__cudnn$fusion"; inline constexpr absl::string_view kUncompilableFusion = diff --git a/third_party/xla/xla/service/gpu/model/gpu_performance_model_base_test.cc b/third_party/xla/xla/service/gpu/model/gpu_performance_model_base_test.cc index 9acd86ef99a0c1..8888e084c5863f 100644 --- a/third_party/xla/xla/service/gpu/model/gpu_performance_model_base_test.cc +++ b/third_party/xla/xla/service/gpu/model/gpu_performance_model_base_test.cc @@ -241,7 +241,7 @@ ENTRY e { p0 = f32[16,970]{1,0} parameter(0) ROOT r = f32[16,970]{1,0} fusion(p0), kind=kCustom, calls=triton_softmax_computation, - backend_config={"fusion_backend_config": {kind: "__triton_softmax"}} + backend_config={"fusion_backend_config": {kind: "__triton"}} })"; TF_ASSERT_OK_AND_ASSIGN(auto module, diff --git a/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis_test.cc b/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis_test.cc index a0d12abdd1adfd..660c686d5d4fec 100644 --- a/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis_test.cc +++ b/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis_test.cc @@ -670,7 +670,7 @@ fused_computation { ENTRY entry_computation { param_0 = f32[8192,50304] parameter(0) - ROOT fusion = f32[4,2048,50304] fusion(param_0), kind=kCustom, calls=fused_computation, backend_config={"fusion_backend_config":{"kind":"__triton_softmax"}} + ROOT fusion = f32[4,2048,50304] fusion(param_0), kind=kCustom, calls=fused_computation, backend_config={"fusion_backend_config":{"kind":"__triton"}} } )")); diff --git a/third_party/xla/xla/service/gpu/priority_fusion.cc b/third_party/xla/xla/service/gpu/priority_fusion.cc index b079f3a2753754..6abf156d1e83e4 100644 --- a/third_party/xla/xla/service/gpu/priority_fusion.cc +++ b/third_party/xla/xla/service/gpu/priority_fusion.cc @@ -421,7 +421,7 @@ class GpuPriorityFusionQueue 
{ return "triton softmax fusion is not enabled"; } - if (IsTritonSoftmaxFusion(*producer)) { + if (IsGenericTritonFusion(*producer)) { if (!IsFusible(*consumer)) { return "the consumer is not fusible"; } @@ -447,7 +447,7 @@ class GpuPriorityFusionQueue { } FusionDecision CanFuse(HloInstruction* producer, HloInstruction* consumer) { - if (IsTritonSoftmaxFusion(*producer) || IsTritonSoftmaxFusion(*consumer)) { + if (IsGenericTritonFusion(*producer) || IsGenericTritonFusion(*consumer)) { return CanFuseTriton(producer, consumer); } @@ -842,7 +842,7 @@ HloInstruction* GpuPriorityFusion::FuseInstruction( HloInstruction* fusion_instruction, HloInstruction* producer) { HloInstruction* result = fusion_instruction; if (producer->opcode() == HloOpcode::kFusion) { - if (IsTritonSoftmaxFusion(*producer)) { + if (IsGenericTritonFusion(*producer)) { TF_CHECK_OK(fusion_instruction->set_backend_config( *producer->backend_config())); } diff --git a/third_party/xla/xla/service/gpu/priority_fusion_test.cc b/third_party/xla/xla/service/gpu/priority_fusion_test.cc index 5da67641d62855..4f71a51b869b4f 100644 --- a/third_party/xla/xla/service/gpu/priority_fusion_test.cc +++ b/third_party/xla/xla/service/gpu/priority_fusion_test.cc @@ -896,7 +896,7 @@ ENTRY main { param_0 = f32[125]{0} parameter(0) param_1 = f32[125,127]{1,0} parameter(1) producer_fusion = f32[125,127]{1,0} fusion(param_0), kind=kLoop, calls=producer_computation - triton_softmax = f32[125,127]{1,0} fusion(producer_fusion), kind=kCustom, calls=triton_softmax_computation, backend_config={"fusion_backend_config": {"kind":"__triton_softmax"}} + triton_softmax = f32[125,127]{1,0} fusion(producer_fusion), kind=kCustom, calls=triton_softmax_computation, backend_config={"fusion_backend_config": {"kind":"__triton"}} ROOT consumer_fusion = f32[125,127]{1,0} fusion(param_1, triton_softmax), kind=kLoop, calls=consumer_computation })"; TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText)); @@ -910,7 +910,7 @@ 
ENTRY main { HloInstruction* root = module->entry_computation()->root_instruction(); EXPECT_THAT(root, GmockMatch(m::Fusion(m::Parameter(), m::Parameter()))); EXPECT_EQ(root->fusion_kind(), HloInstruction::FusionKind::kCustom); - EXPECT_TRUE(IsTritonSoftmaxFusion(*root)); + EXPECT_TRUE(IsGenericTritonFusion(*root)); } TEST_F(PriorityFusionTest, DoNotFuseInsideReducer) { diff --git a/third_party/xla/xla/service/gpu/triton_fusion_analysis_test.cc b/third_party/xla/xla/service/gpu/triton_fusion_analysis_test.cc index 0ff9e93f6f9bb6..ad61b6771e2c83 100644 --- a/third_party/xla/xla/service/gpu/triton_fusion_analysis_test.cc +++ b/third_party/xla/xla/service/gpu/triton_fusion_analysis_test.cc @@ -818,7 +818,7 @@ ENTRY e { p0 = f32[1,97]{1,0} parameter(0) ROOT r = f32[1,97]{1,0} fusion(p0), kind=kCustom, calls=triton_softmax_computation, - backend_config={"kind":"__triton_softmax"} + backend_config={"kind":"__triton"} })")); const HloComputation* computation = module->entry_computation()->root_instruction()->called_computations()[0]; @@ -892,7 +892,7 @@ ENTRY main { param_0 = f32[8,4,127]{2,1,0} parameter(0) ROOT fusion = f32[4,127]{1,0} fusion(param_0), kind=kCustom, calls=triton_softmax_computation, - backend_config={"kind":"__triton_softmax"} + backend_config={"kind":"__triton"} })")); const HloComputation* computation = @@ -924,7 +924,7 @@ ENTRY main { param_0 = f32[4,127]{1,0} parameter(0) ROOT fusion = f32[8,127]{1,0} fusion(param_0), kind=kCustom, calls=triton_softmax_computation, - backend_config={"kind":"__triton_softmax"} + backend_config={"kind":"__triton"} })")); const HloComputation* computation = @@ -955,7 +955,7 @@ ENTRY main { param_1 = f32[8,16129]{1,0} parameter(0) ROOT fusion = f32[8,127,127]{2,1,0} fusion(param_1), kind=kCustom, calls=triton_softmax_computation, - backend_config={"kind":"__triton_softmax"} + backend_config={"kind":"__triton"} })")); const HloComputation* computation = @@ -989,7 +989,7 @@ ENTRY main { param_1 = f32[1,8,127,128]{3,2,1,0} 
parameter(0) ROOT fusion = f32[8,127,128]{2,1,0} fusion(param_1), kind=kCustom, calls=triton_softmax_computation, - backend_config={"kind":"__triton_softmax"} + backend_config={"kind":"__triton"} })")); const HloComputation* computation = @@ -1020,7 +1020,7 @@ ENTRY main { param_1 = f32[1,2,4,127,128]{4,3,2,1,0} parameter(0) ROOT fusion = f32[8,127,128]{2,1,0} fusion(param_1), kind=kCustom, calls=triton_softmax_computation, - backend_config={"kind":"__triton_softmax"} + backend_config={"kind":"__triton"} })")); const HloComputation* computation = @@ -1054,7 +1054,7 @@ ENTRY main { param_0 = f32[27,260]{1,0} parameter(0) ROOT fusion = f32[4,127]{1,0} fusion(param_0), kind=kCustom, calls=triton_softmax_computation, - backend_config={"kind":"__triton_softmax"} + backend_config={"kind":"__triton"} })")); const HloComputation* computation = @@ -1092,7 +1092,7 @@ ENTRY main { producer_fusion = f32[125,127] fusion(param_0), kind=kLoop, calls=producer_computation ROOT triton_softmax = f32[125,127] fusion(producer_fusion), kind=kCustom, calls=triton_softmax_computation, - backend_config={"fusion_backend_config": {"kind":"__triton_softmax"}} + backend_config={"fusion_backend_config": {"kind":"__triton"}} })")); auto consumer = module->entry_computation()->root_instruction(); diff --git a/third_party/xla/xla/service/gpu/triton_fusion_numerics_verifier.cc b/third_party/xla/xla/service/gpu/triton_fusion_numerics_verifier.cc index 1d4620acb7c138..11aa1e8a966013 100644 --- a/third_party/xla/xla/service/gpu/triton_fusion_numerics_verifier.cc +++ b/third_party/xla/xla/service/gpu/triton_fusion_numerics_verifier.cc @@ -63,8 +63,7 @@ absl::StatusOr AsTritonFusion( fusion->backend_config()); const FusionBackendConfig& backend_config = gpu_config.fusion_backend_config(); - if (backend_config.kind() == kTritonFusionKind || - backend_config.kind() == kTritonSoftmaxFusionKind) { + if (backend_config.kind() == kTritonFusionKind) { return fusion; } return nullptr; From 
58bf3f5d1f9001b20b63f946b190c379a6f5c7db Mon Sep 17 00:00:00 2001 From: Zixuan Jiang Date: Tue, 18 Jun 2024 13:24:58 -0700 Subject: [PATCH 37/59] Reshard LHS and RHS to match output sharding by default to handle dot operations in SPMD partitioner. Before this cl, the default solution is replicating LHS, RHS and conducting the full-shape matmul. With this cl, we reshard LHS and RHS to match the output sharding. The new solution is better than or the same as the old one. PiperOrigin-RevId: 644492679 --- .../xla/xla/service/spmd/dot_handler.cc | 44 +++++++++++++++++-- .../xla/service/spmd/spmd_partitioner_test.cc | 27 ++++++++++++ 2 files changed, 67 insertions(+), 4 deletions(-) diff --git a/third_party/xla/xla/service/spmd/dot_handler.cc b/third_party/xla/xla/service/spmd/dot_handler.cc index 74d141dcc3e4ec..ae63b1636fb82e 100644 --- a/third_party/xla/xla/service/spmd/dot_handler.cc +++ b/third_party/xla/xla/service/spmd/dot_handler.cc @@ -1778,7 +1778,6 @@ absl::StatusOr EmitWindowedDotGeneral( // one at a time. The base shapes and shardings can be changed during the // recursion as we group devices together. So refer to the passed in shapes and // shardings for inputs and output, and do not use shape inference. - absl::StatusOr PartitionBaseCase( PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape, const HloSharding& output_sharding, @@ -2143,7 +2142,8 @@ absl::StatusOr PartitionDot( const SpmdPartitionerOptions& options, SpmdBuilder* b, std::vector* windowed_dot_general_loops, - SpmdPartitioningVisitor* visitor); + SpmdPartitioningVisitor* visitor, + bool reshard_lhs_rhs_to_match_output_sharding = false); absl::StatusOr PartitionDotGroupOnBatchImpl( PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape, @@ -4232,6 +4232,33 @@ absl::StatusOr PartitionDot( return nullptr; } +// Reshard the LHS and RHS to match the output sharding. 
+absl::StatusOr ReshardLHSRHSToMatchOutputSharding( + const PartitionedHlo& lhs, const PartitionedHlo& rhs, + const DotConvolutionDimsInfo& dims_mapping, + absl::FunctionRef( + HloInstruction*, HloInstruction*, SpmdBuilder*, + const Window& conv_window)> + create_sharded_dot, + const Window& conv_window, HloInstruction* original_hlo, SpmdBuilder* b) { + const bool consider_other_operand = false; + const bool may_combine_partial_sharding = false; + const HloSharding infered_lhs_sharding = + hlo_sharding_util::InferDotOperandSharding(original_hlo, 0, dims_mapping, + consider_other_operand, + may_combine_partial_sharding); + const HloSharding infered_rhs_sharding = + hlo_sharding_util::InferDotOperandSharding(original_hlo, 1, dims_mapping, + consider_other_operand, + may_combine_partial_sharding); + + TF_ASSIGN_OR_RETURN( + auto dot, create_sharded_dot(lhs.Reshard(infered_lhs_sharding).hlo(), + rhs.Reshard(infered_rhs_sharding).hlo(), b, + conv_window)); + return dot; +} + absl::StatusOr PartitionDot( const PartitionedHlo& lhs, const PartitionedHlo& rhs, const Shape& output_base_shape, const HloSharding& output_sharding, @@ -4244,7 +4271,8 @@ absl::StatusOr PartitionDot( const SpmdPartitionerOptions& options, SpmdBuilder* b, std::vector* windowed_dot_general_loops, - SpmdPartitioningVisitor* visitor) { + SpmdPartitioningVisitor* visitor, + bool reshard_lhs_rhs_to_match_output_sharding) { // First try partitioning without resharding the groups, then try allow // resharding the groups. for (bool require_matching_devices_to_group : {true, false}) { @@ -4259,6 +4287,12 @@ absl::StatusOr PartitionDot( } } + if (reshard_lhs_rhs_to_match_output_sharding) { + return ReshardLHSRHSToMatchOutputSharding(lhs, rhs, dims_mapping, + create_sharded_dot, conv_window, + original_hlo, b); + } + // Default action. 
TF_ASSIGN_OR_RETURN( auto dot, create_sharded_dot(lhs.Replicate().hlo(), rhs.Replicate().hlo(), @@ -4291,7 +4325,9 @@ absl::Status SpmdPartitioningVisitor::HandleDotHelper( auto partitioned_dot, PartitionDot(lhs, rhs, hlo->shape(), hlo->sharding(), dims_mapping, num_partitions_, create_sharded_dot, conv_window, module_, - hlo, options_, &b_, &windowed_dot_general_loops_, this)); + hlo, options_, &b_, &windowed_dot_general_loops_, this, + /*reshard_lhs_rhs_to_match_output_sharding=*/ + dims_mapping.conv_spatial_dims.empty())); SetPartitionedHlo(hlo, [&] { return partitioned_dot; }); return absl::OkStatus(); } diff --git a/third_party/xla/xla/service/spmd/spmd_partitioner_test.cc b/third_party/xla/xla/service/spmd/spmd_partitioner_test.cc index 83cda050d15837..002935739d188a 100644 --- a/third_party/xla/xla/service/spmd/spmd_partitioner_test.cc +++ b/third_party/xla/xla/service/spmd/spmd_partitioner_test.cc @@ -9079,6 +9079,33 @@ ENTRY entry { EXPECT_THAT(root, dot) << module->ToString(); } +TEST_P(SpmdPartitioningTest, ReshardLHSRHSToMatchDotSharding) { + absl::string_view hlo_string = R"( +HloModule module + +ENTRY %main.7 { + %p0 = bf16[32,97] parameter(0), sharding={devices=[32,1]<=[8,4]T(1,0)} + %p1 = bf16[48,64,97] parameter(1), sharding={devices=[8,4,1]<=[32]} + %dot.0 = bf16[32,48,64] dot(%p0, %p1), lhs_contracting_dims={1}, rhs_contracting_dims={2}, sharding={devices=[4,8,1]<=[8,4]T(1,0)} + %dot.1 = bf16[32,48,64] dot(%p0, %p1), lhs_contracting_dims={1}, rhs_contracting_dims={2}, sharding={devices=[4,4,1,2]<=[8,4]T(1,0) last_tile_dim_replicate} + ROOT %tuple = tuple(%dot.0, %dot.1), sharding={{devices=[4,8,1]<=[8,4]T(1,0)}, {devices=[4,4,1,2]<=[8,4]T(1,0) last_tile_dim_replicate}} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/32)); + VLOG(1) << module->ToString(); + + const auto lhs = AllOf(op::Shape("bf16[8,97]")); + const auto rhs0 = AllOf(op::Shape("bf16[6,64,97]")); + const auto rhs1 = 
AllOf(op::Shape("bf16[12,64,97]")); + auto dot0 = AllOf(op::Shape("bf16[8,6,64]"), op::Dot(lhs, rhs0)); + auto dot1 = AllOf(op::Shape("bf16[8,12,64]"), op::Dot(lhs, rhs1)); + auto tuple = + AllOf(op::Shape("(bf16[8,6,64], bf16[8,12,64])"), op::Tuple(dot0, dot1)); + const auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, tuple); +} + TEST_P(SpmdPartitioningTest, ElementwiseTest_SubgroupSharding_TileToReplicate) { absl::string_view hlo_string = R"( HloModule module From 3ea9253b134238c50725b057d644a7da705d3c70 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 18 Jun 2024 13:41:47 -0700 Subject: [PATCH 38/59] Update tf_type shape to print 0 sized dimensions as 00 to avoid any parsing later confusing 0x as the start of a hex value. PiperOrigin-RevId: 644498184 --- .../tensorflow/tests/freeze_variables.mlir | 2 +- tensorflow/core/ir/types/BUILD | 1 + tensorflow/core/ir/types/dialect.cc | 13 +++- tensorflow/core/ir/types/dialect_test.cc | 62 +++++++++++++++++++ 4 files changed, 74 insertions(+), 4 deletions(-) diff --git a/tensorflow/compiler/mlir/tensorflow/tests/freeze_variables.mlir b/tensorflow/compiler/mlir/tensorflow/tests/freeze_variables.mlir index a458a20d49e510..4e820220598398 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/freeze_variables.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/freeze_variables.mlir @@ -424,7 +424,7 @@ module { } // CHECK: func private @f_callee(%[[ARG0:.*]]: tensor<0xf32>) -> tensor<0xf32> - // CHECK-SAME: tf._input_shapes = [#tf_type.shape<0>] + // CHECK-SAME: tf._input_shapes = [#tf_type.shape<00>] func.func private @f_callee(%arg0: tensor<0xf32>, %arg1: tensor<*x!tf_type.resource>) -> tensor<0xf32> attributes {tf._input_shapes = [#tf_type.shape<0>, #tf_type.shape<>]} { %0 = "tf.ReadVariableOp"(%arg1) : (tensor<*x!tf_type.resource>) -> tensor<0xf32> %1 = "tf.AddV2"(%arg0, %0) : (tensor<0xf32>, tensor<0xf32>) -> tensor<0xf32> diff --git a/tensorflow/core/ir/types/BUILD 
b/tensorflow/core/ir/types/BUILD index 18bb1541d76662..4577967ebc7d19 100644 --- a/tensorflow/core/ir/types/BUILD +++ b/tensorflow/core/ir/types/BUILD @@ -121,6 +121,7 @@ tf_cc_test( ":Dialect", "//tensorflow/core:test", "//tensorflow/core:test_main", + "@com_google_googletest//:gtest_main", "@llvm-project//mlir:IR", "@llvm-project//mlir:Parser", "@llvm-project//mlir:Support", diff --git a/tensorflow/core/ir/types/dialect.cc b/tensorflow/core/ir/types/dialect.cc index bd693656cf6fbe..481c9f0f055204 100644 --- a/tensorflow/core/ir/types/dialect.cc +++ b/tensorflow/core/ir/types/dialect.cc @@ -366,10 +366,17 @@ void ShapeAttr::print(AsmPrinter& os) const { os << "<"; if (hasRank()) { auto print_dim = [&](int64_t dim) { - if (dim != ShapedType::kDynamic) - os << dim; - else + if (dim != ShapedType::kDynamic) { + if (dim == 0) { + // In order to avoid the parseInteger below from confusing a dimension + // list with '0x' as hex integer, we use 00 for a 0 sized dimension. + os << "00"; + } else { + os << dim; + } + } else { os << "?"; + } }; llvm::interleave(getShape(), os, print_dim, "x"); } else { diff --git a/tensorflow/core/ir/types/dialect_test.cc b/tensorflow/core/ir/types/dialect_test.cc index 1fb6537b4684f7..84a301a93ad9d5 100644 --- a/tensorflow/core/ir/types/dialect_test.cc +++ b/tensorflow/core/ir/types/dialect_test.cc @@ -15,6 +15,10 @@ limitations under the License. #include "tensorflow/core/ir/types/dialect.h" +#include +#include + +#include #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project @@ -55,6 +59,64 @@ TEST(TFTypesDialect, TestFuncAttrSubElement) { EXPECT_TRUE(bar_ref == sym_ref); } +TEST(TFTypesDialect, ParsesDimensionListWithZero) { + // Test that a dimension list with zero can be parsed. 
+ const char *const code = R"mlir( + "test.op"() {shape = #tf_type.shape<00x128>} : () -> () +)mlir"; + + MLIRContext context; + context.allowUnregisteredDialects(); + context.getOrLoadDialect(); + OwningOpRef module = + mlir::parseSourceString(code, &context); + Operation &test_op = module->front(); + + auto shape_attr = + mlir::dyn_cast(test_op.getAttr("shape")); + ASSERT_TRUE(shape_attr); + EXPECT_THAT(shape_attr.getShape(), testing::ElementsAre(0, 128)); +} + +TEST(TFTypesDialect, ParsesDimensionListWithQuestionMark) { + // Test that a dimension list with zero can be parsed. + const char *const code = R"mlir( + "test.op"() {shape = #tf_type.shape<0x?x2>} : () -> () +)mlir"; + + MLIRContext context; + context.allowUnregisteredDialects(); + context.getOrLoadDialect(); + OwningOpRef module = + mlir::parseSourceString(code, &context); + Operation &test_op = module->front(); + + auto shape_attr = + mlir::dyn_cast(test_op.getAttr("shape")); + ASSERT_TRUE(shape_attr); + EXPECT_THAT(shape_attr.getShape(), + testing::ElementsAre(0, std::numeric_limits::min(), 2)); +} + +TEST(TFTypesDialect, ParsesDimensionListWithNegativeOne) { + // Test that a dimension list with zero can be parsed. + const char *const code = R"mlir( + "test.op"() {shape = #tf_type.shape<0x-1x2>} : () -> () +)mlir"; + + MLIRContext context; + context.allowUnregisteredDialects(); + context.getOrLoadDialect(); + OwningOpRef module = + mlir::parseSourceString(code, &context); + Operation &test_op = module->front(); + + auto shape_attr = + mlir::dyn_cast(test_op.getAttr("shape")); + ASSERT_TRUE(shape_attr); + EXPECT_THAT(shape_attr.getShape(), testing::ElementsAre(0, -1, 2)); +} + } // namespace } // namespace tfg } // namespace mlir From a35ca23f4ff4dcc90c3fee7576219c7f211d04fe Mon Sep 17 00:00:00 2001 From: Vamsi Manchala Date: Tue, 18 Jun 2024 13:48:16 -0700 Subject: [PATCH 39/59] Explicitly reset tf::SavedModelBundle after done with using it. 
This releases the memory allocated/consumed as a result of saved_model creation. This is needed to lower the HWM memory space and reduces the HWM space by 1x of the model size. PiperOrigin-RevId: 644500154 --- .../python/saved_model_to_tfl_flatbuffer.cc | 2 +- .../lite/python/tf_tfl_flatbuffer_helpers.cc | 8 +++++--- .../lite/python/tf_tfl_flatbuffer_helpers.h | 2 +- .../compiler/mlir/lite/tf_tfl_translate.cc | 5 ++--- .../compiler/mlir/lite/tf_to_tfl_flatbuffer.cc | 17 +++++++++++++---- .../compiler/mlir/lite/tf_to_tfl_flatbuffer.h | 5 +++-- 6 files changed, 25 insertions(+), 14 deletions(-) diff --git a/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc index 085478db128a71..5a69362092503e 100644 --- a/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc +++ b/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc @@ -241,7 +241,7 @@ Status ConvertSavedModelToTFLiteFlatBuffer( // TODO(b/153507667): Pass the session object when importing logic is removed. auto status = internal::ConvertMLIRToTFLiteFlatBuffer( model_flags, toco_flags, std::move(module), pass_config, tags, result, - bundle.get(), quantization_py_function_lib); + std::move(bundle), quantization_py_function_lib); return status; } diff --git a/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc b/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc index a4512d226939f7..369d5e15483714 100644 --- a/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc +++ b/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc @@ -14,6 +14,7 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h" +#include #include #include #include @@ -349,7 +350,7 @@ absl::Status ConvertMLIRToTFLiteFlatBuffer( mlir::OwningOpRef module, const mlir::TFL::PassConfig& pass_config, const std::unordered_set& saved_model_tags, - std::string* result, SavedModelBundle* saved_model_bundle, + std::string* result, std::unique_ptr saved_model_bundle, const PyFunctionLibrary* quantization_py_function_lib) { if (toco_flags.has_dump_graphviz_dir()) { TF_RETURN_IF_ERROR(DumpOpGraphToFile( @@ -374,8 +375,9 @@ absl::Status ConvertMLIRToTFLiteFlatBuffer( auto status = ConvertTFExecutorToTFLOrFlatbuffer( module.get(), /*export_to_mlir=*/false, toco_flags, pass_config_copy, - saved_model_tags, model_flags.saved_model_dir(), saved_model_bundle, - result, /*serialize_stablehlo_ops=*/false, quantization_py_function_lib); + saved_model_tags, model_flags.saved_model_dir(), + std::move(saved_model_bundle), result, /*serialize_stablehlo_ops=*/false, + quantization_py_function_lib); if (toco_flags.has_dump_graphviz_dir()) { TF_RETURN_IF_ERROR(DumpOpGraphToFile( // rename once we enable the new converter feature flag. diff --git a/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h b/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h index a57b3585abadb1..fb6d19b26fbeab 100644 --- a/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h +++ b/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h @@ -61,7 +61,7 @@ Status ConvertMLIRToTFLiteFlatBuffer( mlir::OwningOpRef module, const mlir::TFL::PassConfig& pass_config, const std::unordered_set& saved_model_tags, string* result, - SavedModelBundle* saved_model_bundle, + std::unique_ptr saved_model_bundle, const quantization::PyFunctionLibrary* quantization_py_function_lib); // Give a warning for any unused flags that have been specified. 
diff --git a/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc b/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc index 8d124af7cb246a..edae191747d04b 100644 --- a/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc +++ b/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc @@ -317,11 +317,10 @@ int main(int argc, char **argv) { }); std::string result; - std::optional session = std::nullopt; - if (bundle) session = bundle->GetSession(); auto status = tensorflow::ConvertTFExecutorToTFLOrFlatbuffer( module.value().get(), output_mlir, toco_flags, pass_config, tags, - /*saved_model_dir=*/"", bundle.get(), &result, serialize_stablehlo_ops); + /*saved_model_dir=*/"", std::move(bundle), &result, + serialize_stablehlo_ops); if (!status.ok()) { llvm::errs() << status.message() << '\n'; return kTrFailure; diff --git a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc index dd8b345862e3c8..227de88bb8510b 100644 --- a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc +++ b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc @@ -386,8 +386,9 @@ absl::Status ConvertTFExecutorToTFLOrFlatbuffer( mlir::ModuleOp module, bool export_to_mlir, toco::TocoFlags& toco_flags, const mlir::TFL::PassConfig& pass_config, const std::unordered_set& saved_model_tags, - llvm::StringRef saved_model_dir, SavedModelBundle* saved_model_bundle, - std::string* result, bool serialize_stablehlo_ops, + llvm::StringRef saved_model_dir, + std::unique_ptr saved_model_bundle, std::string* result, + bool serialize_stablehlo_ops, const PyFunctionLibrary* quantization_py_function_lib) { // Explicitly disable dumping Op details on failures. 
module.getContext()->printOpOnDiagnostic(false); @@ -433,8 +434,8 @@ absl::Status ConvertTFExecutorToTFLOrFlatbuffer( if (failed(RunHloToTfConversion( pass_config, saved_model_dir, saved_model_tags, toco_flags.mutable_quantization_config(), - quantization_py_function_lib, saved_model_bundle, pass_manager, - status_handler, module))) { + quantization_py_function_lib, saved_model_bundle.get(), + pass_manager, status_handler, module))) { return status_handler.ConsumeStatus(); } } @@ -454,6 +455,14 @@ absl::Status ConvertTFExecutorToTFLOrFlatbuffer( "converter.experimental_enable_resource_variables = True")); } + // Its safe to reset the saved_model_bundle after variable freezing, as this + // function owns the saved_model_bundle via std::move into a unique_ptr. + saved_model_bundle.reset(); + + // set session to nullptr to avoid invalid access as the session would be + // deleted along with the saved_model_bundle. + session = nullptr; + pass_manager.clear(); AddPostVariableFreezingTFToTFLConversionPasses(saved_model_dir, toco_flags, diff --git a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h index f77912938d8709..4c60c3b8f0cf26 100644 --- a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h +++ b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h @@ -86,8 +86,9 @@ Status ConvertTFExecutorToTFLOrFlatbuffer( mlir::ModuleOp module, bool export_to_mlir, toco::TocoFlags& toco_flags, const mlir::TFL::PassConfig& pass_config, const std::unordered_set& saved_model_tags, - llvm::StringRef saved_model_dir, SavedModelBundle* saved_model_bundle, - std::string* result, bool serialize_stablehlo_ops = false, + llvm::StringRef saved_model_dir, + std::unique_ptr saved_model_bundle, std::string* result, + bool serialize_stablehlo_ops = false, const quantization::PyFunctionLibrary* quantization_py_function_lib = nullptr); } // namespace tensorflow From 7a3fbe7b27304f2532b223ff6e35e23a54276aac Mon Sep 17 00:00:00 2001 From: 
Benjamin Chetioui Date: Tue, 18 Jun 2024 14:04:09 -0700 Subject: [PATCH 40/59] [XLA:GPU] Fix OSS dependency on protobuf descriptor. PiperOrigin-RevId: 644505065 --- third_party/xla/xla/service/gpu/BUILD | 2 +- third_party/xla/xla/service/gpu/triton_support_test.cc | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index 2bdbeb1fb17d21..dcb4a1d9b9bc10 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -1268,7 +1268,6 @@ xla_test( ":triton_fusion_analysis", ":triton_support", ":triton_test_utils", - "//third_party/protobuf", "//xla:xla_data_proto_cc", "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", @@ -1278,6 +1277,7 @@ xla_test( "@com_google_absl//absl/strings", "@com_google_googletest//:gtest_main", "@local_tsl//tsl/lib/core:status_test_util", + "@local_tsl//tsl/platform:protobuf", "@local_tsl//tsl/platform:status_matchers", "@local_tsl//tsl/platform:statusor", ], diff --git a/third_party/xla/xla/service/gpu/triton_support_test.cc b/third_party/xla/xla/service/gpu/triton_support_test.cc index 0d6f2998953362..fb818d6efab617 100644 --- a/third_party/xla/xla/service/gpu/triton_support_test.cc +++ b/third_party/xla/xla/service/gpu/triton_support_test.cc @@ -27,7 +27,6 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/status/status.h" #include "absl/strings/string_view.h" -#include "third_party/protobuf/descriptor.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/gpu/gpu_device_info_for_tests.h" #include "xla/service/gpu/ir_emitter_triton.h" @@ -37,6 +36,7 @@ limitations under the License. #include "xla/xla.pb.h" #include "xla/xla_data.pb.h" #include "tsl/lib/core/status_test_util.h" +#include "tsl/platform/protobuf.h" #include "tsl/platform/status_matchers.h" #include "tsl/platform/statusor.h" From e4d0a29eccfb07579b8b14417f5aa19916c95204 Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Tue, 18 Jun 2024 14:13:39 -0700 Subject: [PATCH 41/59] Remove an unused parameter PiperOrigin-RevId: 644508028 --- .../xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc | 5 ++--- .../xla/xla/hlo/experimental/auto_sharding/auto_sharding.h | 2 +- .../experimental/auto_sharding/auto_sharding_strategy.cc | 6 +++--- 3 files changed, 6 insertions(+), 7 deletions(-) diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc index 6cd3bce9c31518..1282bb36ade687 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc @@ -366,9 +366,8 @@ void FollowArrayOrTokenStrategyGroup( std::unique_ptr HandlePartialReduce( const HloInstruction* ins, const size_t instruction_id, - const bool have_memory_cost, StrategyGroups& strategy_groups, - const ClusterEnvironment& cluster_env, StrategyMap& strategy_map, - const CallGraph& call_graph) { + StrategyGroups& strategy_groups, const ClusterEnvironment& cluster_env, + StrategyMap& strategy_map, const CallGraph& call_graph) { absl::StatusOr reduction_dim = GetPartialReduceReductionDim(ins); CHECK_OK(reduction_dim); const Shape& shape = ins->shape(); diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.h b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.h index f5f91a676d596d..bf5327318161aa 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.h +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.h @@ -295,7 +295,7 @@ std::unique_ptr HandleManuallyShardedInstruction( StrategyGroups& strategy_groups, StrategyMap& strategy_map); std::unique_ptr HandlePartialReduce( - const HloInstruction* ins, size_t instruction_id, bool have_memory_cost, + const HloInstruction* ins, size_t instruction_id, StrategyGroups& strategy_groups, const 
ClusterEnvironment& cluster_env, StrategyMap& strategy_map, const CallGraph& call_graph); diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc index 54eed08bff059e..78e51f866e217b 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc @@ -792,9 +792,9 @@ BuildStrategyAndCost( } else if (IsTopKCustomCall(ins)) { generate_non_following_strategies(false, {0}); } else if (IsPartialReduceCustomCall(ins)) { - strategy_group = HandlePartialReduce( - ins, instruction_id, /* have_memory_cost */ true, strategy_groups, - cluster_env, strategy_map, call_graph); + strategy_group = + HandlePartialReduce(ins, instruction_id, strategy_groups, + cluster_env, strategy_map, call_graph); } else if (OutputInputSameShapes(ins)) { auto* partitioner = GetCustomCallPartitioner(ins->custom_call_target()); From 964fae6c0793c6c967d7935b4d84a57baf268f83 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 18 Jun 2024 14:18:54 -0700 Subject: [PATCH 42/59] [XLA] Add shardings for implicit operands and return values of CaseOp and IfOp. We only add arg shardings if there are result shardings, otherwise it means sharding propagation hasn't been done yet. PiperOrigin-RevId: 644509565 --- .../translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc | 100 ++++++++++--- .../translate/mhlo_to_hlo/tests/sharding.mlir | 140 ++++++++++++++++++ 2 files changed, 217 insertions(+), 23 deletions(-) diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc b/third_party/xla/xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc index 1095eb746bc7b9..4a6ad69d70428e 100644 --- a/third_party/xla/xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc +++ b/third_party/xla/xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc @@ -18,6 +18,7 @@ limitations under the License. 
#include #include #include +#include #include #include #include @@ -677,6 +678,20 @@ std::optional CreateTupleSharding( return sharding; } +// If `ops` has a single element, returns that element. Otherwise, returns +// a tuple instruction with `ops` and attaches a tuple sharding from +// `shardings`. +xla::XlaOp CreateTupleIfMultipleOps( + xla::XlaBuilder* builder, llvm::ArrayRef ops, + llvm::ArrayRef> shardings) { + if (ops.size() == 1) { + return ops[0]; + } + xla::XlaScopedShardingAssignment scoped_sharding( + builder, CreateTupleSharding(shardings)); + return Tuple(builder, ops); +} + // Returns the flattened result shardings of the given `op_sharding`, i.e., // either: // - an empty vector if `op_sharding` is `std::nullopt`. @@ -700,6 +715,20 @@ llvm::SmallVector> GetResultShardings( return res_shardings; } +// Returns the OpSharding of each op in `xla_ops`, or std::nullopt if the op +// doesn't have a sharding. +llvm::SmallVector> GetXlaOpShardings( + llvm::ArrayRef xla_ops) { + llvm::SmallVector> shardings; + shardings.reserve(xla_ops.size()); + for (const xla::XlaOp& xla_op : xla_ops) { + auto sharding = xla_op.builder()->GetOpSharding(xla_op); + assert(sharding.ok() && "can't find XlaOp for argument"); + shardings.push_back(*sharding); + } + return shardings; +} + namespace mlir { namespace { class ConvertToHloModule { @@ -1602,17 +1631,37 @@ LogicalResult ExportXlaOp(IfOp op, OpLoweringContext ctx) { llvm::SmallVector implicit_false_operands( implicit_false_operand_set.begin(), implicit_false_operand_set.end()); + llvm::SmallVector> ret_shardings = + GetResultShardings(ctx.builder->sharding(), op->getNumResults()); + + llvm::SmallVector true_args; + if (failed(GetXlaOps(op, implicit_true_operands, ctx, true_args))) + return failure(); + + llvm::SmallVector false_args; + if (failed(GetXlaOps(op, implicit_false_operands, ctx, false_args))) + return failure(); + + llvm::SmallVector> true_arg_shardings, + false_arg_shardings; + if (!ret_shardings.empty()) { + 
// We only add arg shardings if there are result shardings, otherwise it + // means sharding propagation hasn't been done yet. + true_arg_shardings = GetXlaOpShardings(true_args); + false_arg_shardings = GetXlaOpShardings(false_args); + } + // Create xla parameters for functions corresponding to ifOp regions using the // implicit captures operands. Also export the instructions within those // regions. if (failed(ctx.converter->LowerRegionAsComputation( &op.getTrueBranch(), &true_branch, llvm::ArrayRef(implicit_true_operands), - /*ensure_single_arg*/ true)) || + /*ensure_single_arg*/ true, true_arg_shardings, ret_shardings)) || failed(ctx.converter->LowerRegionAsComputation( &op.getFalseBranch(), &false_branch, llvm::ArrayRef(implicit_false_operands), - /*ensure_single_arg*/ true))) { + /*ensure_single_arg*/ true, false_arg_shardings, ret_shardings))) { return failure(); } @@ -1621,18 +1670,12 @@ LogicalResult ExportXlaOp(IfOp op, OpLoweringContext ctx) { if (failed(GetXlaOp(op.getPred(), value_map, &pred, op))) return failure(); // Create the true branch Xla argument. - llvm::SmallVector true_args; - if (failed(GetXlaOps(op, implicit_true_operands, ctx, true_args))) - return failure(); xla::XlaOp true_arg = - true_args.size() == 1 ? true_args[0] : Tuple(ctx.builder, true_args); + CreateTupleIfMultipleOps(ctx.builder, true_args, true_arg_shardings); // Create the false branch Xla argument. - llvm::SmallVector false_args; - if (failed(GetXlaOps(op, implicit_false_operands, ctx, false_args))) - return failure(); xla::XlaOp false_arg = - false_args.size() == 1 ? false_args[0] : Tuple(ctx.builder, false_args); + CreateTupleIfMultipleOps(ctx.builder, false_args, false_arg_shardings); // Create XLA Conditional op. 
auto ifop = @@ -1673,10 +1716,22 @@ LogicalResult ExportXlaOp(CaseOp op, OpLoweringContext ctx) { llvm::SmallVector implicit_operands( implicit_operand_set.begin(), implicit_operand_set.end()); + llvm::SmallVector> ret_shardings = + GetResultShardings(ctx.builder->sharding(), op->getNumResults()); + // Create the branches[i]'s Xla argument. llvm::SmallVector args; if (failed(GetXlaOps(op, implicit_operands, ctx, args))) return failure(); - branch_operands[i] = args.size() == 1 ? args[0] : Tuple(ctx.builder, args); + + llvm::SmallVector> arg_shardings; + if (!ret_shardings.empty()) { + // We only add arg shardings if there are result shardings, otherwise it + // means sharding propagation hasn't been done yet. + arg_shardings = GetXlaOpShardings(args); + } + + branch_operands[i] = + CreateTupleIfMultipleOps(ctx.builder, args, arg_shardings); // Create xla parameters for functions corresponding to region branches[i] // using the implicit captures operands. Also export the instructions within @@ -1684,7 +1739,7 @@ LogicalResult ExportXlaOp(CaseOp op, OpLoweringContext ctx) { computations_p[i] = &computations[i]; if (failed(ctx.converter->LowerRegionAsComputation( &branches[i], computations_p[i], llvm::ArrayRef(implicit_operands), - /*ensure_single_arg*/ true))) + /*ensure_single_arg*/ true, arg_shardings, ret_shardings))) return failure(); } @@ -3482,10 +3537,6 @@ LogicalResult ConvertToHloModule::LowerBasicBlockAsFunction( // Applicable for mhlo.IfOp or mhlo.CaseOp or mhlo.WhileOp. llvm::SmallVector arg_shapes; - // The arguments of `block` are ignored if `implicit_operands` is set, - // therefore `arg_shardings` should be empty in that case. 
- assert(arg_shardings.empty() || !implicit_operands); - auto args_size = block->getNumArguments(); if (implicit_operands) args_size = implicit_operands->size(); @@ -3512,10 +3563,13 @@ LogicalResult ConvertToHloModule::LowerBasicBlockAsFunction( "arg_tuple"); if (implicit_operands) { - int arg_index = 0; - for (auto implicit_operand : *implicit_operands) - lowering[implicit_operand] = - xla::GetTupleElement(tuple, arg_index++); + for (auto [arg_index, implicit_operand] : + llvm::enumerate(*implicit_operands)) { + xla::XlaScopedShardingAssignment scoped_sharding( + builder, arg_shardings.empty() ? std::nullopt + : arg_shardings[arg_index]); + lowering[implicit_operand] = xla::GetTupleElement(tuple, arg_index); + } } else { for (BlockArgument& arg : block->getArguments()) { auto num = arg.getArgNumber(); @@ -3528,6 +3582,9 @@ LogicalResult ConvertToHloModule::LowerBasicBlockAsFunction( } else if (args_size == 1) { // Save the location information as a name. For example JAX will set the // name of the function argument. Want to preserve these for debugging. + xla::XlaScopedShardingAssignment scoped_sharding( + builder, + arg_shardings.empty() ? std::nullopt : arg_shardings.front()); if (implicit_operands) { mlir::Value arg = (*implicit_operands)[0]; xla::XlaScopedOpMetadataAssignment op_metadata( @@ -3537,9 +3594,6 @@ LogicalResult ConvertToHloModule::LowerBasicBlockAsFunction( mlir::BlockArgument arg = block->getArgument(0); xla::XlaScopedOpMetadataAssignment op_metadata( builder, GetOpNameMetadataFromLocation(arg)); - xla::XlaScopedShardingAssignment scoped_sharding( - builder, - arg_shardings.empty() ? 
std::nullopt : arg_shardings.front()); lowering[arg] = xla::Parameter(builder, 0, arg_shapes[0], "Arg_"); } } else { diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/tests/sharding.mlir b/third_party/xla/xla/translate/mhlo_to_hlo/tests/sharding.mlir index a93bdee50abb8e..295c95075c9fc7 100644 --- a/third_party/xla/xla/translate/mhlo_to_hlo/tests/sharding.mlir +++ b/third_party/xla/xla/translate/mhlo_to_hlo/tests/sharding.mlir @@ -220,3 +220,143 @@ func.func @main(%arg0: tensor, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) } func.return %0#1, %0#2 : tensor<4xf32>, tensor<4xf32> } + +// ----- + +// CHECK-LABEL: HloModule main + +// CHECK: %region_0.8 (arg_tuple.9: (f32[4], f32[4])) -> (f32[4], f32[4]) { +// CHECK-NEXT: %arg_tuple.9 = (f32[4], f32[4]) parameter(0), sharding={{\{}}{devices=[2,2]<=[4] last_tile_dim_replicate}, {replicated}} +// CHECK-NEXT: %get-tuple-element.10 = f32[4] get-tuple-element((f32[4], f32[4]) %arg_tuple.9), index=0, sharding={devices=[2,2]<=[4] last_tile_dim_replicate} +// CHECK-NEXT: %get-tuple-element.11 = f32[4] get-tuple-element((f32[4], f32[4]) %arg_tuple.9), index=1 +// CHECK-NEXT: ROOT %tuple.12 = (f32[4], f32[4]) tuple(f32[4] %get-tuple-element.10, f32[4] %get-tuple-element.11), sharding={{\{}}{replicated}, {devices=[4]<=[4]}} + +// CHECK: %region_1.13 (arg_tuple.14: (f32[4], f32[4])) -> (f32[4], f32[4]) { +// CHECK-NEXT: %arg_tuple.14 = (f32[4], f32[4]) parameter(0), sharding={{\{}}{replicated}, {devices=[4]<=[4]}} +// CHECK-NEXT: %get-tuple-element.15 = f32[4] get-tuple-element((f32[4], f32[4]) %arg_tuple.14), index=0, sharding={replicated} +// CHECK-NEXT: %get-tuple-element.16 = f32[4] get-tuple-element((f32[4], f32[4]) %arg_tuple.14), index=1, sharding={devices=[4]<=[4]} +// CHECK-NEXT: ROOT %tuple.17 = (f32[4], f32[4]) tuple(f32[4] %get-tuple-element.15, f32[4] %get-tuple-element.16), sharding={{\{}}{replicated}, {devices=[4]<=[4]}} + +// CHECK: ENTRY %main.22 (Arg_0.1: s32[], Arg_1.2: f32[4], Arg_2.3: f32[4], Arg_3.4: 
f32[4], Arg_4.5: f32[4]) -> (f32[4], f32[4]) { +// CHECK-NEXT: %Arg_0.1 = s32[] parameter(0) +// CHECK-NEXT: %Arg_1.2 = f32[4] parameter(1), sharding={devices=[2,2]<=[4] last_tile_dim_replicate} +// CHECK-NEXT: %Arg_2.3 = f32[4] parameter(2) +// CHECK-NEXT: %tuple.6 = (f32[4], f32[4]) tuple(f32[4] %Arg_1.2, f32[4] %Arg_2.3), sharding={{\{}}{devices=[2,2]<=[4] last_tile_dim_replicate}, {replicated}} +// CHECK-NEXT: %Arg_3.4 = f32[4] parameter(3), sharding={replicated} +// CHECK-NEXT: %Arg_4.5 = f32[4] parameter(4), sharding={devices=[4]<=[4]} +// CHECK-NEXT: %tuple.7 = (f32[4], f32[4]) tuple(f32[4] %Arg_3.4, f32[4] %Arg_4.5), sharding={{\{}}{replicated}, {devices=[4]<=[4]}} +// CHECK-NEXT: %conditional.18 = (f32[4], f32[4]) conditional(s32[] %Arg_0.1, (f32[4], f32[4]) %tuple.6, (f32[4], f32[4]) %tuple.7), branch_computations={%region_0.8, %region_1.13}, +// CHECK-SAME: sharding={{\{}}{replicated}, {devices=[4]<=[4]}} +// CHECK-NEXT: %get-tuple-element.19 = f32[4] get-tuple-element((f32[4], f32[4]) %conditional.18), index=0, sharding={replicated} +// CHECK-NEXT: %get-tuple-element.20 = f32[4] get-tuple-element((f32[4], f32[4]) %conditional.18), index=1, sharding={devices=[4]<=[4]} +// CHECK-NEXT: ROOT %tuple.21 = (f32[4], f32[4]) tuple(f32[4] %get-tuple-element.19, f32[4] %get-tuple-element.20) + +func.func @main(%arg0: tensor, + %arg1: tensor<4xf32> {mhlo.sharding = "{devices=[2,2]<=[4] last_tile_dim_replicate}"}, + %arg2: tensor<4xf32>, + %arg3: tensor<4xf32> {mhlo.sharding = "{replicated}"}, + %arg4: tensor<4xf32> {mhlo.sharding = "{devices=[4]<=[4]}"}) -> (tensor<4xf32>, tensor<4xf32>) { + %0:2 = "mhlo.case"(%arg0) ( { + mhlo.return %arg1, %arg2 : tensor<4xf32>, tensor<4xf32> + }, { + mhlo.return %arg3, %arg4 : tensor<4xf32>, tensor<4xf32> + }) {mhlo.sharding = "{{replicated},{devices=[4]<=[4]}}"} : (tensor) -> (tensor<4xf32>, tensor<4xf32>) + func.return %0#0, %0#1 : tensor<4xf32>, tensor<4xf32> +} + + +// ----- + +// CHECK-LABEL: HloModule main + +// CHECK: 
%region_0.4 (Arg_.5: f32[4]) -> f32[4] { +// CHECK-NEXT: ROOT %Arg_.5 = f32[4] parameter(0) + +// CHECK: %region_1.6 (Arg_.7: f32[4]) -> f32[4] { +// CHECK-NEXT: ROOT %Arg_.7 = f32[4] parameter(0) + +// CHECK: ENTRY %main.9 (Arg_0.1: s32[], Arg_1.2: f32[4], Arg_2.3: f32[4]) -> f32[4] { +// CHECK-NEXT: %Arg_0.1 = s32[] parameter(0) +// CHECK-NEXT: %Arg_1.2 = f32[4] parameter(1), sharding={devices=[2,2]<=[4] last_tile_dim_replicate} +// CHECK-NEXT: %Arg_2.3 = f32[4] parameter(2) +// CHECK-NEXT: ROOT %conditional.8 = f32[4] conditional(s32[] %Arg_0.1, f32[4] %Arg_1.2, f32[4] %Arg_2.3), branch_computations={%region_0.4, %region_1.6} +func.func @main(%arg0: tensor, + %arg1: tensor<4xf32> {mhlo.sharding = "{devices=[2,2]<=[4] last_tile_dim_replicate}"}, + %arg2: tensor<4xf32>) -> tensor<4xf32> { + %0 = "mhlo.case"(%arg0) ( { + mhlo.return %arg1 : tensor<4xf32> + }, { + mhlo.return %arg2 : tensor<4xf32> + }) : (tensor) -> tensor<4xf32> + func.return %0 : tensor<4xf32> +} + +// ----- + +// CHECK-LABEL: HloModule main + +// CHECK: %region_0.8 (arg_tuple.9: (f32[4], f32[4])) -> (f32[4], f32[4]) { +// CHECK-NEXT: %arg_tuple.9 = (f32[4], f32[4]) parameter(0), sharding={{\{}}{devices=[2,2]<=[4] last_tile_dim_replicate}, {replicated}} +// CHECK-NEXT: %get-tuple-element.10 = f32[4] get-tuple-element((f32[4], f32[4]) %arg_tuple.9), index=0, sharding={devices=[2,2]<=[4] last_tile_dim_replicate} +// CHECK-NEXT: %get-tuple-element.11 = f32[4] get-tuple-element((f32[4], f32[4]) %arg_tuple.9), index=1 +// CHECK-NEXT: ROOT %tuple.12 = (f32[4], f32[4]) tuple(f32[4] %get-tuple-element.10, f32[4] %get-tuple-element.11), sharding={{\{}}{replicated}, {devices=[4]<=[4]}} + +// CHECK: %region_1.13 (arg_tuple.14: (f32[4], f32[4])) -> (f32[4], f32[4]) { +// CHECK-NEXT: %arg_tuple.14 = (f32[4], f32[4]) parameter(0), sharding={{\{}}{replicated}, {devices=[4]<=[4]}} +// CHECK-NEXT: %get-tuple-element.15 = f32[4] get-tuple-element((f32[4], f32[4]) %arg_tuple.14), index=0, sharding={replicated} +// 
CHECK-NEXT: %get-tuple-element.16 = f32[4] get-tuple-element((f32[4], f32[4]) %arg_tuple.14), index=1, sharding={devices=[4]<=[4]} +// CHECK-NEXT: ROOT %tuple.17 = (f32[4], f32[4]) tuple(f32[4] %get-tuple-element.15, f32[4] %get-tuple-element.16), sharding={{\{}}{replicated}, {devices=[4]<=[4]}} + +// CHECK: ENTRY %main.22 (Arg_0.1: pred[], Arg_1.2: f32[4], Arg_2.3: f32[4], Arg_3.4: f32[4], Arg_4.5: f32[4]) -> (f32[4], f32[4]) { +// CHECK-NEXT: %Arg_0.1 = pred[] parameter(0) +// CHECK-NEXT: %Arg_1.2 = f32[4] parameter(1), sharding={devices=[2,2]<=[4] last_tile_dim_replicate} +// CHECK-NEXT: %Arg_2.3 = f32[4] parameter(2) +// CHECK-NEXT: %tuple.6 = (f32[4], f32[4]) tuple(f32[4] %Arg_1.2, f32[4] %Arg_2.3), sharding={{\{}}{devices=[2,2]<=[4] last_tile_dim_replicate}, {replicated}} +// CHECK-NEXT: %Arg_3.4 = f32[4] parameter(3), sharding={replicated} +// CHECK-NEXT: %Arg_4.5 = f32[4] parameter(4), sharding={devices=[4]<=[4]} +// CHECK-NEXT: %tuple.7 = (f32[4], f32[4]) tuple(f32[4] %Arg_3.4, f32[4] %Arg_4.5), sharding={{\{}}{replicated}, {devices=[4]<=[4]}} +// CHECK-NEXT: %conditional.18 = (f32[4], f32[4]) conditional(pred[] %Arg_0.1, (f32[4], f32[4]) %tuple.6, (f32[4], f32[4]) %tuple.7), true_computation=%region_0.8, false_computation=%region_1.13, +// CHECK-SAME: sharding={{\{}}{replicated}, {devices=[4]<=[4]}} +// CHECK-NEXT: %get-tuple-element.19 = f32[4] get-tuple-element((f32[4], f32[4]) %conditional.18), index=0, sharding={replicated} +// CHECK-NEXT: %get-tuple-element.20 = f32[4] get-tuple-element((f32[4], f32[4]) %conditional.18), index=1, sharding={devices=[4]<=[4]} +// CHECK-NEXT: ROOT %tuple.21 = (f32[4], f32[4]) tuple(f32[4] %get-tuple-element.19, f32[4] %get-tuple-element.20) + +func.func @main(%arg0: tensor, + %arg1: tensor<4xf32> {mhlo.sharding = "{devices=[2,2]<=[4] last_tile_dim_replicate}"}, + %arg2: tensor<4xf32>, + %arg3: tensor<4xf32> {mhlo.sharding = "{replicated}"}, + %arg4: tensor<4xf32> {mhlo.sharding = "{devices=[4]<=[4]}"}) -> 
(tensor<4xf32>, tensor<4xf32>) { + %0:2 = "mhlo.if"(%arg0) ( { + mhlo.return %arg1, %arg2 : tensor<4xf32>, tensor<4xf32> + }, { + mhlo.return %arg3, %arg4 : tensor<4xf32>, tensor<4xf32> + }) {mhlo.sharding = "{{replicated},{devices=[4]<=[4]}}"} : (tensor) -> (tensor<4xf32>, tensor<4xf32>) + func.return %0#0, %0#1 : tensor<4xf32>, tensor<4xf32> +} + +// ----- + +// CHECK-LABEL: HloModule main + +// CHECK: %region_0.4 (Arg_.5: f32[4]) -> f32[4] { +// CHECK-NEXT: ROOT %Arg_.5 = f32[4] parameter(0) + +// CHECK: %region_1.6 (Arg_.7: f32[4]) -> f32[4] { +// CHECK-NEXT: ROOT %Arg_.7 = f32[4] parameter(0) + +// CHECK: ENTRY %main.9 (Arg_0.1: pred[], Arg_1.2: f32[4], Arg_2.3: f32[4]) -> f32[4] { +// CHECK-NEXT: %Arg_0.1 = pred[] parameter(0) +// CHECK-NEXT: %Arg_1.2 = f32[4] parameter(1), sharding={devices=[2,2]<=[4] last_tile_dim_replicate} +// CHECK-NEXT: %Arg_2.3 = f32[4] parameter(2) +// CHECK-NEXT: ROOT %conditional.8 = f32[4] conditional(pred[] %Arg_0.1, f32[4] %Arg_1.2, f32[4] %Arg_2.3), true_computation=%region_0.4, false_computation=%region_1.6 + +func.func @main(%arg0: tensor, + %arg1: tensor<4xf32> {mhlo.sharding = "{devices=[2,2]<=[4] last_tile_dim_replicate}"}, + %arg2: tensor<4xf32>) -> tensor<4xf32> { + %0 = "mhlo.if"(%arg0) ( { + mhlo.return %arg1 : tensor<4xf32> + }, { + mhlo.return %arg2 : tensor<4xf32> + }) : (tensor) -> tensor<4xf32> + func.return %0 : tensor<4xf32> +} From 25099d18679c3a654df319824c3866c865bd9baa Mon Sep 17 00:00:00 2001 From: David Dunleavy Date: Tue, 18 Jun 2024 14:35:28 -0700 Subject: [PATCH 43/59] Use newer version of scorecards-analysis for XLA and TensorFlow PiperOrigin-RevId: 644514423 --- .github/workflows/scorecards-analysis.yml | 28 +++++++++--------- .../.github/workflows/scorecards-analysis.yml | 29 +++++++++++-------- 2 files changed, 32 insertions(+), 25 deletions(-) diff --git a/.github/workflows/scorecards-analysis.yml b/.github/workflows/scorecards-analysis.yml index 1c520aa86fd5ac..8ff0613e726863 100644 --- 
a/.github/workflows/scorecards-analysis.yml +++ b/.github/workflows/scorecards-analysis.yml @@ -15,12 +15,15 @@ name: Scorecards supply-chain security on: - # Only the default branch is supported. + # For Branch-Protection check. Only the default branch is supported. See + # https://github.com/ossf/scorecard/blob/main/docs/checks.md#branch-protection branch_protection_rule: + # To guarantee Maintained check is occasionally updated. See + # https://github.com/ossf/scorecard/blob/main/docs/checks.md#maintained schedule: - - cron: '44 15 * * 5' + - cron: '26 3 * * 2' push: - branches: [ master ] + branches: [ "master" ] # Declare default permissions as read only. permissions: read-all @@ -33,35 +36,34 @@ jobs: permissions: # Needed to upload the results to code-scanning dashboard. security-events: write + # Needed to publish results and get a badge (see publish_results below). id-token: write steps: - name: "Checkout code" - uses: actions/checkout@755da8c3cf115ac066823e79a1e1788f8940201b # v3.2.0 + uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # v4.1.1 with: persist-credentials: false - name: "Run analysis" - uses: ossf/scorecard-action@15c10fcf1cf912bd22260bfec67569a359ab87da # v2.1.1 + uses: ossf/scorecard-action@0864cf19026789058feabb7e87baa5f140aac736 # v2.3.1 with: results_file: results.sarif results_format: sarif - # Publish the results to enable scorecard badges. For more details, see - # https://github.com/ossf/scorecard-action#publishing-results. - # For private repositories, `publish_results` will automatically be set to `false`, - # regardless of the value entered here. publish_results: true - # Upload the results as artifacts (optional). + # Upload the results as artifacts (optional). Commenting out will disable uploads of run results in SARIF + # format to the repository Actions tab. 
- name: "Upload artifact" - uses: actions/upload-artifact@83fd05a356d7e2593de66fc9913b3002723633cb # v3.1.1 + uses: actions/upload-artifact@97a0fba1372883ab732affbe8f94b823f91727db # v3.pre.node20 with: name: SARIF file path: results.sarif retention-days: 5 - # Upload the results to GitHub's code scanning dashboard. + # Upload the results to GitHub's code scanning dashboard (optional). + # Commenting out will disable upload of results to your repo's Code Scanning dashboard - name: "Upload to code-scanning" - uses: github/codeql-action/upload-sarif@896079047b4bb059ba6f150a5d87d47dde99e6e5 # v2.11.6 + uses: github/codeql-action/upload-sarif@1b1aada464948af03b950897e5eb522f92603cc2 # v3.24.9 with: sarif_file: results.sarif diff --git a/third_party/xla/.github/workflows/scorecards-analysis.yml b/third_party/xla/.github/workflows/scorecards-analysis.yml index 8fd13f8d91a736..51745c3c99825f 100644 --- a/third_party/xla/.github/workflows/scorecards-analysis.yml +++ b/third_party/xla/.github/workflows/scorecards-analysis.yml @@ -12,13 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ +# This workflow uses actions that are not certified by GitHub. They are provided +# by a third-party and are governed by separate terms of service, privacy +# policy, and support documentation. name: Scorecards supply-chain security on: - # Only the default branch is supported. + # For Branch-Protection check. Only the default branch is supported. See + # https://github.com/ossf/scorecard/blob/main/docs/checks.md#branch-protection branch_protection_rule: + # To guarantee Maintained check is occasionally updated. 
See + # https://github.com/ossf/scorecard/blob/main/docs/checks.md#maintained schedule: - - cron: '55 14 * * 5' + - cron: '26 3 * * 2' push: branches: [ "main" ] @@ -33,35 +39,34 @@ jobs: permissions: # Needed to upload the results to code-scanning dashboard. security-events: write + # Needed to publish results and get a badge (see publish_results below). id-token: write steps: - name: "Checkout code" - uses: actions/checkout@755da8c3cf115ac066823e79a1e1788f8940201b # v3.2.0 + uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # v4.1.1 with: persist-credentials: false - name: "Run analysis" - uses: ossf/scorecard-action@15c10fcf1cf912bd22260bfec67569a359ab87da # v2.1.1 + uses: ossf/scorecard-action@0864cf19026789058feabb7e87baa5f140aac736 # v2.3.1 with: results_file: results.sarif results_format: sarif - # Publish the results to enable scorecard badges. For more details, see - # https://github.com/ossf/scorecard-action#publishing-results. - # For private repositories, `publish_results` will automatically be set to `false`, - # regardless of the value entered here. publish_results: true - # Upload the results as artifacts (optional). + # Upload the results as artifacts (optional). Commenting out will disable uploads of run results in SARIF + # format to the repository Actions tab. - name: "Upload artifact" - uses: actions/upload-artifact@83fd05a356d7e2593de66fc9913b3002723633cb # v3.1.1 + uses: actions/upload-artifact@97a0fba1372883ab732affbe8f94b823f91727db # v3.pre.node20 with: name: SARIF file path: results.sarif retention-days: 5 - # Upload the results to GitHub's code scanning dashboard. + # Upload the results to GitHub's code scanning dashboard (optional). 
+ # Commenting out will disable upload of results to your repo's Code Scanning dashboard - name: "Upload to code-scanning" - uses: github/codeql-action/upload-sarif@896079047b4bb059ba6f150a5d87d47dde99e6e5 # v2.11.6 + uses: github/codeql-action/upload-sarif@1b1aada464948af03b950897e5eb522f92603cc2 # v3.24.9 with: sarif_file: results.sarif From 3eb2492bc6593beb5cd14f616bb4fdca7c7cd474 Mon Sep 17 00:00:00 2001 From: Kevin Gleason Date: Tue, 18 Jun 2024 14:44:41 -0700 Subject: [PATCH 44/59] Remove use of deprecated op UnaryEinsum. PiperOrigin-RevId: 644517312 --- .../compiler/mlir/tensorflow/ir/tf_ops_a_m.cc | 6 ++-- .../mlir/tensorflow/tests/tf-ops.mlir | 10 +++++- .../mlir/tf2xla/tests/legalize-tf.mlir | 3 +- .../mlir/tf2xla/transforms/legalize_tf.cc | 34 +++++++++++-------- 4 files changed, 34 insertions(+), 19 deletions(-) diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc index 36fb36a3d451c6..5ad1642d2f064f 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc @@ -2450,13 +2450,13 @@ LogicalResult DynamicStitchOp::verify() { //===----------------------------------------------------------------------===// // Verifies that, -// * Arity of the op is at most two. +// * Arity of the op is one or two. // // TODO(hinsu): Verify einsum equation attribute. 
LogicalResult EinsumOp::verify() { EinsumOp op = *this; - if (op.getN() > 2) { - return op.emitOpError("supports at most two operands"); + if (op.getN() != 1 && op.getN() != 2) { + return op.emitOpError("must have 1 or 2 operands"); } return success(); } diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir index 7605f0360625fe..7793634a8429c3 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir @@ -3168,8 +3168,16 @@ func.func @testSqueezeOutOfBounds(%arg0: tensor) -> tensor // ----- +func.func @testNullaryEinsum(%arg0: tensor<2x3xf32>){ + // expected-error @+1 {{op must have 1 or 2 operands}} + "tf.Einsum"() {equation = "->"} : () -> (tensor) + func.return +} + +// ----- + func.func @testTernaryEinsum(%arg0: tensor<2x3xf32>){ - // expected-error @+1 {{supports at most two operands}} + // expected-error @+1 {{op must have 1 or 2 operands}} %0 = "tf.Einsum"(%arg0, %arg0, %arg0) {equation = "ab,cd,ef->"} : (tensor<2x3xf32>, tensor<2x3xf32>, tensor<2x3xf32>) -> (tensor<*xf32>) func.return } diff --git a/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf.mlir index 91008b91056d40..fce42a21be3a62 100644 --- a/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf.mlir +++ b/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf.mlir @@ -774,7 +774,8 @@ func.func @einsum(%arg0: tensor<2x3xf32>, %arg1: tensor<3x4xf32>) -> tensor<2x4x // CHECK-LABEL: func @unary_einsum func.func @unary_einsum(%arg0: tensor<2x3xf32>) -> tensor<2x2xf32> { - // CHECK: mhlo.unary_einsum + // CHECK: mhlo.constant{{.*}}1.000000e+00 + // CHECK: mhlo.einsum{{.*}}",ab->aa" %0 = "tf.Einsum"(%arg0) {equation = "ab->aa"} : (tensor<2x3xf32>) -> tensor<2x2xf32> func.return %0: tensor<2x2xf32> } diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf.cc 
b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf.cc index c7dbb1a9e7313e..dca10693d74b01 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf.cc @@ -1994,27 +1994,33 @@ class ConvertMatrixDiagPartV3Op } }; -// Converts TensorFlow EinsumOp to either HLO EinsumOp or UnaryEinsumOp -// depending on arity of the op. +// Converts TensorFlow EinsumOp to HLO EinsumOp class ConvertEinsumOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(TF::EinsumOp op, PatternRewriter &rewriter) const override { - StringAttr equation = op->getAttrOfType("equation"); + // Prepend `,` to equation if unary einsum. + std::string equation_str = op.getEquation().str(); + llvm::SmallVector inputs; + + // Unary einsum prepends `,` to equation and + // creates a scalar constant 1.0 for first operand. if (op.getN() == 1) { - rewriter.replaceOpWithNewOp( - op, op.getType(), *op.getInputs().begin(), equation); - } else if (op.getN() == 2) { - ValueRange inputs = op.getInputs(); - rewriter.replaceOpWithNewOp(op, op.getType(), inputs[0], - inputs[1], equation); - } else { - // TensorFlow EinsumOp verifies that the number of operands are at most - // two. - return failure(); - } + equation_str = "," + equation_str; + inputs.push_back(rewriter.create( + op.getLoc(), hlo::getScalarOfType( + mlir::getElementTypeOrSelf(op.getOperand(0)), 1))); + } + // Insert remaining operands into inputs, TF op verifier requires there be + // 0 or 1 operands. 
+ auto operands = op.getInputs(); + inputs.insert(inputs.end(), operands.begin(), operands.end()); + assert(inputs.size() == 2); + + rewriter.replaceOpWithNewOp(op, op.getType(), inputs[0], + inputs[1], equation_str); return success(); } }; From cb736a66a2c047c96593bd5447d357a242335b88 Mon Sep 17 00:00:00 2001 From: Jing Pu Date: Tue, 18 Jun 2024 14:55:38 -0700 Subject: [PATCH 45/59] Internal BUILD change PiperOrigin-RevId: 644520276 --- tensorflow/compiler/mlir/lite/BUILD | 1 + 1 file changed, 1 insertion(+) diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD index 597ce884f352c7..7e4afbbe358849 100644 --- a/tensorflow/compiler/mlir/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/BUILD @@ -1379,6 +1379,7 @@ cc_library( "//tensorflow/compiler/mlir/lite/quantization/tensorflow:tf_quantization_passes", "//tensorflow/compiler/mlir/lite/stablehlo:compose_uniform_quantized_type_pass", "//tensorflow/compiler/mlir/lite/stablehlo:composite_lowering", + "//tensorflow/compiler/mlir/lite/stablehlo:legalize_stablehlo_composite_to_tfl_custom", "//tensorflow/compiler/mlir/lite/stablehlo:legalize_tf_xla_call_module_to_stablehlo_pass", "//tensorflow/compiler/mlir/lite/stablehlo:rename_entrypoint_to_main", "//tensorflow/compiler/mlir/lite/stablehlo:tf_legalize_hlo", # buildcleaner: keep From 172289414ebd675a1609245e1bdc207147855afb Mon Sep 17 00:00:00 2001 From: Sandeep Dasgupta Date: Tue, 18 Jun 2024 15:10:57 -0700 Subject: [PATCH 46/59] Integrate StableHLO at openxla/stablehlo@f1f49945 PiperOrigin-RevId: 644524932 --- third_party/stablehlo/temporary.patch | 949 +++++++++++++++++- third_party/stablehlo/workspace.bzl | 4 +- .../xla/third_party/stablehlo/temporary.patch | 949 +++++++++++++++++- .../xla/third_party/stablehlo/workspace.bzl | 4 +- .../Dialect/chlo/chlo_legalize_to_mhlo.mlir | 360 ++++++- 5 files changed, 2123 insertions(+), 143 deletions(-) diff --git a/third_party/stablehlo/temporary.patch b/third_party/stablehlo/temporary.patch index 
1db9bf674f5b38..0eb67a7bda2ee2 100755 --- a/third_party/stablehlo/temporary.patch +++ b/third_party/stablehlo/temporary.patch @@ -164,6 +164,893 @@ diff --ruN a/stablehlo/CMakeLists.txt b/stablehlo/CMakeLists.txt #------------------------------------------------------------------------------- # Directory setup +diff --ruN a/stablehlo/build_tools/math/generate_ChloDecompositionPatternsMath.py b/stablehlo/build_tools/math/generate_ChloDecompositionPatternsMath.py +--- stablehlo/build_tools/math/generate_ChloDecompositionPatternsMath.py ++++ stablehlo/build_tools/math/generate_ChloDecompositionPatternsMath.py +@@ -22,54 +22,64 @@ + + + def main(): +- try: +- import functional_algorithms as fa +- except ImportError as msg: +- print(f"Skipping: {msg}") +- return ++ try: ++ import functional_algorithms as fa ++ except ImportError as msg: ++ print(f"Skipping: {msg}") ++ return + +- fa_version = tuple(map(int, fa.__version__.split(".", 4)[:3])) +- if fa_version < (0, 4, 0): +- warnings.warn("functional_algorithm version 0.4.0 or newer is required," +- f" got {fa.__version__}") +- return ++ fa_version = tuple(map(int, fa.__version__.split(".", 4)[:3])) ++ if fa_version < (0, 4, 0): ++ warnings.warn( ++ "functional_algorithm version 0.4.0 or newer is required," ++ f" got {fa.__version__}" ++ ) ++ return + +- output_file = os.path.relpath( +- os.path.normpath( +- os.path.join(os.path.dirname(__file__), "..", "..", "stablehlo", +- "transforms", "ChloDecompositionPatternsMath.td")), +- os.getcwd()) ++ output_file = os.path.relpath( ++ os.path.normpath( ++ os.path.join( ++ os.path.dirname(__file__), ++ "..", ++ "..", ++ "stablehlo", ++ "transforms", ++ "ChloDecompositionPatternsMath.td", ++ ) ++ ), ++ os.getcwd(), ++ ) + +- sources = [] +- target = fa.targets.stablehlo +- for chloname, fname, args in [ +- ("CHLO_AsinOp", "complex_asin", ("z:complex",)), +- ("CHLO_AsinOp", "real_asin", ("x:float",)), +- ]: +- func = getattr(fa.algorithms, fname, None) +- if func is None: +- 
warnings.warn( +- "{fa.algorithms.__name__} does not define {fname}. Skipping.") +- continue +- ctx = fa.Context(paths=[fa.algorithms]) +- graph = ctx.trace(func, *args).implement_missing(target).simplify() +- graph.props.update(name=chloname) +- src = graph.tostring(target) +- sources.append(target.make_comment( +- func.__doc__)) if func.__doc__ else None +- sources[-1] += src +- source = "\n\n".join(sources) + "\n" ++ sources = [] ++ target = fa.targets.stablehlo ++ for chloname, fname, args in [ ++ ("CHLO_AsinOp", "complex_asin", ("z:complex",)), ++ ("CHLO_AsinOp", "real_asin", ("x:float",)), ++ ]: ++ func = getattr(fa.algorithms, fname, None) ++ if func is None: ++ warnings.warn( ++ "{fa.algorithms.__name__} does not define {fname}. Skipping." ++ ) ++ continue ++ ctx = fa.Context(paths=[fa.algorithms]) ++ graph = ctx.trace(func, *args).implement_missing(target).simplify() ++ graph.props.update(name=chloname) ++ src = graph.tostring(target) ++ sources.append(target.make_comment(func.__doc__)) if func.__doc__ else None ++ sources[-1] += src ++ source = "\n\n".join(sources) + "\n" + +- if os.path.isfile(output_file): +- f = open(output_file, "r") +- content = f.read() +- f.close() +- if content.endswith(source): +- print(f"{output_file} is up-to-date.") +- return ++ if os.path.isfile(output_file): ++ f = open(output_file, "r") ++ content = f.read() ++ f.close() ++ if content.endswith(source): ++ print(f"{output_file} is up-to-date.") ++ return + +- f = open(output_file, "w") +- f.write("""\ ++ f = open(output_file, "w") ++ f.write("""\ + /* Copyright 2024 The StableHLO Authors. + + Licensed under the Apache License, Version 2.0 (the "License"); +@@ -86,15 +96,14 @@ + ==============================================================================*/ + + """) +- f.write( +- target.make_comment(f"""\ ++ f.write(target.make_comment(f"""\ + + This file is generated using functional_algorithms tool ({fa.__version__}). 
+ See build_tools/math/README.md for more information.""") + "\n") +- f.write(source) +- f.close() +- print(f"Created {output_file}") ++ f.write(source) ++ f.close() ++ print(f"Created {output_file}") + + + if __name__ == "__main__": +- main() ++ main() +diff --ruN a/stablehlo/build_tools/math/generate_tests.py b/stablehlo/build_tools/math/generate_tests.py +--- stablehlo/build_tools/math/generate_tests.py ++++ stablehlo/build_tools/math/generate_tests.py +@@ -55,374 +55,394 @@ + + + def main(): +- try: +- import functional_algorithms as fa +- except ImportError as msg: +- print(f"Skipping: {msg}") +- return +- +- fa_version = tuple(map(int, fa.__version__.split(".", 4)[:3])) +- if fa_version < (0, 4, 0): +- warnings.warn("functional_algorithm version 0.4.0 or newer is required," +- f" got {fa.__version__}") +- return +- +- target_dir = os.path.relpath( +- os.path.normpath( +- os.path.join(os.path.dirname(__file__), "..", "..", "stablehlo", +- "tests", "math")), os.getcwd()) +- +- flush_subnormals = False +- for op in operations: +- opname = op["name"] +- mpmath_opname = op.get("mpmath_name", opname) +- size_re = size_im = op.get("size", default_size) +- extra_prec_multiplier = op.get("extra_prec_multiplier", +- default_extra_prec_multiplier) +- max_ulp_difference = op.get("max_ulp_difference", +- default_max_ulp_difference) +- +- nmp = fa.utils.numpy_with_mpmath( +- extra_prec_multiplier=extra_prec_multiplier, +- flush_subnormals=flush_subnormals) +- for dtype in [np.complex64, np.complex128, np.float32, np.float64]: +- fi = np.finfo(dtype) +- +- float_dtype = to_float_dtype[dtype] +- finfo = np.finfo(float_dtype) +- +- if dtype in [np.complex64, np.complex128]: +- samples = fa.utils.complex_samples( +- size=(size_re, size_im), +- dtype=dtype, +- include_subnormal=not flush_subnormals).flatten() +- else: +- samples = fa.utils.real_samples( +- size=size_re * size_im, +- dtype=dtype, +- include_subnormal=not flush_subnormals).flatten() +- +- expected = getattr(nmp, 
mpmath_opname)(samples) +- +- module_name = f"{opname}_{dtype.__name__}" +- m = SSA.make_module(module_name) +- +- samples_func = m.make_function("samples", "", mlir_type(samples)) +- samples_func.assign(samples) +- samples_func.return_last() +- +- expected_func = m.make_function("expected", "", mlir_type(expected)) +- expected_func.assign(expected) +- expected_func.return_last() +- +- main_func = m.make_function("main", "", "", "public") +- +- ref_samples = main_func.call("samples") +- actual = main_func.composite(f"chlo.{opname}", ref_samples) +- expected = main_func.call("expected") +- +- main_func.void_call( +- "check.expect_close", +- actual, +- expected, +- f"max_ulp_difference = {max_ulp_difference}", +- atypes=", ".join(map(main_func.get_ref_type, +- [actual, expected])), +- ) +- main_func.void_call("func.return") +- source = str(m).rstrip() + "\n" +- fname = os.path.join(target_dir, f"{module_name}.mlir") +- if os.path.isfile(fname): +- f = open(fname, "r") +- content = f.read() +- f.close() +- if content.endswith(source): +- print(f"{fname} is up-to-date.") +- continue +- +- f = open(fname, "w") +- f.write( +- "// RUN: stablehlo-opt --chlo-legalize-to-stablehlo %s | stablehlo-translate --interpret\n" +- ) +- f.write( +- "// This file is generated, see build_tools/math/README.md for more information.\n" +- ) +- f.write(source) +- f.close() +- print(f"Created {fname}") +- +- # Testing ULP difference +- for dtype in [np.float32, np.float64]: +- fi = np.finfo(dtype) +- +- max_ulp_difference = 0 +- min_ulp_difference = 0 +- +- finfo = np.finfo(dtype) +- module_name = f"ulp_difference_{dtype.__name__}" +- m = SSA.make_module(module_name) +- +- main_func = m.make_function("main", "", "", "public") +- +- def samples_generator(): +- data = [ +- -finfo.max, -1e9 - 1.2, -finfo.smallest_normal, +- -finfo.smallest_subnormal, 0, finfo.smallest_subnormal, +- finfo.smallest_normal, 1.2, 1e9 +- ] +- for expected_ulp_difference in [0, 1, 5, 50]: +- if 
expected_ulp_difference == 0: +- actual = np.array(data + [np.inf, -np.inf, np.nan], +- dtype=dtype) +- else: +- actual = np.array(data, dtype=dtype) +- shifted = actual +- for i in range(expected_ulp_difference): +- shifted = np.nextafter(shifted, np.inf, dtype=dtype) +- label = str(expected_ulp_difference) +- yield actual, shifted, expected_ulp_difference, label +- +- actual = np.array([np.inf] * 5, dtype=dtype) +- shifted = np.array([-np.inf, np.nan, 0, 1.2, finfo.max], +- dtype=dtype) +- yield actual, shifted, 2**64 - 1, "nonfinite" +- +- for actual, shifted, expected_ulp_difference, label in samples_generator( +- ): +- +- actual_func = m.make_function(f"actual_{label}", "", +- mlir_type(actual)) +- actual_func.comment(f'{list(actual)}') +- actual_func.assign(actual) +- actual_func.return_last() +- +- shifted_func = m.make_function(f"shifted_{label}", "", +- mlir_type(shifted)) +- shifted_func.comment(f'{list(shifted)}') +- shifted_func.assign(shifted) +- shifted_func.return_last() +- +- actual_values = main_func.call(f"actual_{label}") +- shifted_values = main_func.call(f"shifted_{label}") +- +- main_func.void_call( +- "check.expect_close", +- actual_values, +- shifted_values, +- f"max_ulp_difference = {expected_ulp_difference}", +- f"min_ulp_difference = {expected_ulp_difference}", +- atypes=", ".join( +- map(main_func.get_ref_type, +- [actual_values, shifted_values])), +- ) +- +- main_func.void_call("func.return") +- source = str(m).rstrip() + "\n" +- fname = os.path.join(target_dir, f"{module_name}.mlir") +- if os.path.isfile(fname): +- f = open(fname, "r") +- content = f.read() +- f.close() +- if content.endswith(source): +- print(f"{fname} is up-to-date.") +- continue +- +- f = open(fname, "w") +- f.write( +- "// RUN: stablehlo-opt --chlo-legalize-to-stablehlo %s | stablehlo-translate --interpret\n" +- ) +- f.write( +- "// This file is generated, see build_tools/math/README.md for more information.\n" +- ) +- f.write(source) ++ try: ++ import 
functional_algorithms as fa ++ except ImportError as msg: ++ print(f"Skipping: {msg}") ++ return ++ ++ fa_version = tuple(map(int, fa.__version__.split(".", 4)[:3])) ++ if fa_version < (0, 4, 0): ++ warnings.warn( ++ "functional_algorithm version 0.4.0 or newer is required," ++ f" got {fa.__version__}" ++ ) ++ return ++ ++ target_dir = os.path.relpath( ++ os.path.normpath( ++ os.path.join( ++ os.path.dirname(__file__), ++ "..", ++ "..", ++ "stablehlo", ++ "tests", ++ "math", ++ ) ++ ), ++ os.getcwd(), ++ ) ++ ++ flush_subnormals = False ++ for op in operations: ++ opname = op["name"] ++ mpmath_opname = op.get("mpmath_name", opname) ++ size_re = size_im = op.get("size", default_size) ++ extra_prec_multiplier = op.get( ++ "extra_prec_multiplier", default_extra_prec_multiplier ++ ) ++ max_ulp_difference = op.get( ++ "max_ulp_difference", default_max_ulp_difference ++ ) ++ ++ nmp = fa.utils.numpy_with_mpmath( ++ extra_prec_multiplier=extra_prec_multiplier, ++ flush_subnormals=flush_subnormals, ++ ) ++ for dtype in [np.complex64, np.complex128, np.float32, np.float64]: ++ fi = np.finfo(dtype) ++ ++ float_dtype = to_float_dtype[dtype] ++ finfo = np.finfo(float_dtype) ++ ++ if dtype in [np.complex64, np.complex128]: ++ samples = fa.utils.complex_samples( ++ size=(size_re, size_im), ++ dtype=dtype, ++ include_subnormal=not flush_subnormals, ++ ).flatten() ++ else: ++ samples = fa.utils.real_samples( ++ size=size_re * size_im, ++ dtype=dtype, ++ include_subnormal=not flush_subnormals, ++ ).flatten() ++ ++ expected = getattr(nmp, mpmath_opname)(samples) ++ ++ module_name = f"{opname}_{dtype.__name__}" ++ m = SSA.make_module(module_name) ++ ++ samples_func = m.make_function("samples", "", mlir_type(samples)) ++ samples_func.assign(samples) ++ samples_func.return_last() ++ ++ expected_func = m.make_function("expected", "", mlir_type(expected)) ++ expected_func.assign(expected) ++ expected_func.return_last() ++ ++ main_func = m.make_function("main", "", "", "public") ++ ++ 
ref_samples = main_func.call("samples") ++ actual = main_func.composite(f"chlo.{opname}", ref_samples) ++ expected = main_func.call("expected") ++ ++ main_func.void_call( ++ "check.expect_close", ++ actual, ++ expected, ++ f"max_ulp_difference = {max_ulp_difference}", ++ atypes=", ".join(map(main_func.get_ref_type, [actual, expected])), ++ ) ++ main_func.void_call("func.return") ++ source = str(m).rstrip() + "\n" ++ fname = os.path.join(target_dir, f"{module_name}.mlir") ++ if os.path.isfile(fname): ++ f = open(fname, "r") ++ content = f.read() + f.close() +- print(f"Created {fname}") ++ if content.endswith(source): ++ print(f"{fname} is up-to-date.") ++ continue ++ ++ f = open(fname, "w") ++ f.write( ++ "// RUN: stablehlo-opt --chlo-legalize-to-stablehlo %s |" ++ " stablehlo-translate --interpret\n" ++ ) ++ f.write( ++ "// This file is generated, see build_tools/math/README.md for more" ++ " information.\n" ++ ) ++ f.write(source) ++ f.close() ++ print(f"Created {fname}") ++ ++ # Testing ULP difference ++ for dtype in [np.float32, np.float64]: ++ fi = np.finfo(dtype) ++ ++ max_ulp_difference = 0 ++ min_ulp_difference = 0 ++ ++ finfo = np.finfo(dtype) ++ module_name = f"ulp_difference_{dtype.__name__}" ++ m = SSA.make_module(module_name) ++ ++ main_func = m.make_function("main", "", "", "public") ++ ++ def samples_generator(): ++ data = [ ++ -finfo.max, ++ -1e9 - 1.2, ++ -finfo.smallest_normal, ++ -finfo.smallest_subnormal, ++ 0, ++ finfo.smallest_subnormal, ++ finfo.smallest_normal, ++ 1.2, ++ 1e9, ++ ] ++ for expected_ulp_difference in [0, 1, 5, 50]: ++ if expected_ulp_difference == 0: ++ actual = np.array(data + [np.inf, -np.inf, np.nan], dtype=dtype) ++ else: ++ actual = np.array(data, dtype=dtype) ++ shifted = actual ++ for i in range(expected_ulp_difference): ++ shifted = np.nextafter(shifted, np.inf, dtype=dtype) ++ label = str(expected_ulp_difference) ++ yield actual, shifted, expected_ulp_difference, label ++ ++ actual = np.array([np.inf] * 5, dtype=dtype) 
++ shifted = np.array([-np.inf, np.nan, 0, 1.2, finfo.max], dtype=dtype) ++ yield actual, shifted, 2**64 - 1, "nonfinite" ++ ++ for actual, shifted, expected_ulp_difference, label in samples_generator(): ++ ++ actual_func = m.make_function(f"actual_{label}", "", mlir_type(actual)) ++ actual_func.comment(f"{list(actual)}") ++ actual_func.assign(actual) ++ actual_func.return_last() ++ ++ shifted_func = m.make_function(f"shifted_{label}", "", mlir_type(shifted)) ++ shifted_func.comment(f"{list(shifted)}") ++ shifted_func.assign(shifted) ++ shifted_func.return_last() ++ ++ actual_values = main_func.call(f"actual_{label}") ++ shifted_values = main_func.call(f"shifted_{label}") ++ ++ main_func.void_call( ++ "check.expect_close", ++ actual_values, ++ shifted_values, ++ f"max_ulp_difference = {expected_ulp_difference}", ++ f"min_ulp_difference = {expected_ulp_difference}", ++ atypes=", ".join( ++ map(main_func.get_ref_type, [actual_values, shifted_values]) ++ ), ++ ) ++ ++ main_func.void_call("func.return") ++ source = str(m).rstrip() + "\n" ++ fname = os.path.join(target_dir, f"{module_name}.mlir") ++ if os.path.isfile(fname): ++ f = open(fname, "r") ++ content = f.read() ++ f.close() ++ if content.endswith(source): ++ print(f"{fname} is up-to-date.") ++ continue ++ ++ f = open(fname, "w") ++ f.write( ++ "// RUN: stablehlo-opt --chlo-legalize-to-stablehlo %s |" ++ " stablehlo-translate --interpret\n" ++ ) ++ f.write( ++ "// This file is generated, see build_tools/math/README.md for more" ++ " information.\n" ++ ) ++ f.write(source) ++ f.close() ++ print(f"Created {fname}") + + + class Block: +- """A data structure used in SSA""" +- +- def __init__(self, parent, prefix, suffix, start_counter=0): +- self.parent = parent +- self.prefix = prefix +- self.suffix = suffix +- self.counter = start_counter +- self.statements = {} +- +- def tostr(self, tab=""): +- lines = [] +- lines.append(tab + self.prefix) +- for i in sorted(self.statements): +- op, expr, typ = self.statements[i] 
+- if op == "//": +- lines.append(f"{tab} {op} {expr}") +- elif typ: +- lines.append(f"{tab} {op} {expr} : {typ}") +- else: +- assert not expr, (op, expr, typ) +- lines.append(f"{tab} {op}") +- lines.append(tab + self.suffix) +- return "\n".join(lines) +- +- def comment(self, message): +- # add comment to code +- self.statements[self.counter] = ("//", message, None) +- self.counter += 1 +- +- def assign(self, expr, typ=None): +- if isinstance(expr, np.ndarray): +- assert typ is None, typ +- typ = mlir_type(expr) +- expr = shlo_constant(expr) +- elif isinstance(expr, str) and typ is not None: +- pass +- elif isinstance(expr, bool) and typ is not None: +- expr = shlo_constant(expr) +- else: +- raise NotImplementedError((expr, typ)) +- target = f"%{self.counter}" +- self.statements[self.counter] = (f"{target} =", expr, typ) +- self.counter += 1 +- return target +- +- def call(self, name, *args): +- # call function created with make_function +- sargs = ", ".join(args) +- return self.assign(f"call @{name}({sargs})", +- typ=self.get_function_type(name)) +- +- def composite(self, name, *args, **options): +- sargs = ", ".join(args) +- atypes = tuple(map(self.get_ref_type, args)) +- rtype = options.get("rtype") +- if rtype is None: +- # assuming the first op argument defines the op type +- rtype = atypes[0] +- sargs = ", ".join(args) +- typ = f'({", ".join(atypes)}) -> {rtype}' +- return self.assign(f'"{name}"({sargs})', typ=typ) +- +- def void_call(self, name, *args, **options): +- # call function that has void return +- if args: +- sargs = ", ".join(args) +- atypes = options.get("atypes") +- if atypes is None: +- atypes = ", ".join(map(self.get_ref_type, args)) +- self.statements[self.counter] = (name, f"{sargs}", f"{atypes}") +- else: +- self.statements[self.counter] = (name, "", "") +- self.counter += 1 +- +- def apply(self, op, *args, **options): +- sargs = ", ".join(args) +- atypes = tuple(map(self.get_ref_type, args)) +- rtype = options.get("rtype") +- if rtype is 
None: +- # assuming the first op argument defines the op type +- rtype = atypes[0] +- typ = f'({", ".join(atypes)}) -> {rtype}' +- return self.assign(f"{op} {sargs}", typ=typ) +- +- def return_last(self): +- ref = f"%{self.counter - 1}" +- self.statements[self.counter] = ("return", ref, self.get_ref_type(ref)) +- self.counter += 1 +- +- @property +- def is_function(self): +- return self.prefix.startwith("func.func") +- +- @property +- def function_name(self): +- if self.prefix.startswith("func.func"): +- i = self.prefix.find("@") +- j = self.prefix.find("(", i) +- assert -1 not in {i, j}, self.prefix +- return self.prefix[i + 1:j] +- +- @property +- def function_type(self): +- if self.prefix.startswith("func.func"): +- i = self.prefix.find("(", self.prefix.find("@")) +- j = self.prefix.find("{", i) +- assert -1 not in {i, j}, self.prefix +- return self.prefix[i:j].strip() +- +- def get_function_type(self, name): +- for block in self.parent.blocks: +- if block.function_name == name: +- return block.function_type +- +- def get_ref_type(self, ref): +- assert ref.startswith("%"), ref +- counter = int(ref[1:]) +- typ = self.statements[counter][-1] +- return typ.rsplit("->", 1)[-1].strip() ++ """A data structure used in SSA""" ++ ++ def __init__(self, parent, prefix, suffix, start_counter=0): ++ self.parent = parent ++ self.prefix = prefix ++ self.suffix = suffix ++ self.counter = start_counter ++ self.statements = {} ++ ++ def tostr(self, tab=""): ++ lines = [] ++ lines.append(tab + self.prefix) ++ for i in sorted(self.statements): ++ op, expr, typ = self.statements[i] ++ if op == "//": ++ lines.append(f"{tab} {op} {expr}") ++ elif typ: ++ lines.append(f"{tab} {op} {expr} : {typ}") ++ else: ++ assert not expr, (op, expr, typ) ++ lines.append(f"{tab} {op}") ++ lines.append(tab + self.suffix) ++ return "\n".join(lines) ++ ++ def comment(self, message): ++ # add comment to code ++ self.statements[self.counter] = ("//", message, None) ++ self.counter += 1 ++ ++ def 
assign(self, expr, typ=None): ++ if isinstance(expr, np.ndarray): ++ assert typ is None, typ ++ typ = mlir_type(expr) ++ expr = shlo_constant(expr) ++ elif isinstance(expr, str) and typ is not None: ++ pass ++ elif isinstance(expr, bool) and typ is not None: ++ expr = shlo_constant(expr) ++ else: ++ raise NotImplementedError((expr, typ)) ++ target = f"%{self.counter}" ++ self.statements[self.counter] = (f"{target} =", expr, typ) ++ self.counter += 1 ++ return target ++ ++ def call(self, name, *args): ++ # call function created with make_function ++ sargs = ", ".join(args) ++ return self.assign( ++ f"call @{name}({sargs})", typ=self.get_function_type(name) ++ ) ++ ++ def composite(self, name, *args, **options): ++ sargs = ", ".join(args) ++ atypes = tuple(map(self.get_ref_type, args)) ++ rtype = options.get("rtype") ++ if rtype is None: ++ # assuming the first op argument defines the op type ++ rtype = atypes[0] ++ sargs = ", ".join(args) ++ typ = f'({", ".join(atypes)}) -> {rtype}' ++ return self.assign(f'"{name}"({sargs})', typ=typ) ++ ++ def void_call(self, name, *args, **options): ++ # call function that has void return ++ if args: ++ sargs = ", ".join(args) ++ atypes = options.get("atypes") ++ if atypes is None: ++ atypes = ", ".join(map(self.get_ref_type, args)) ++ self.statements[self.counter] = (name, f"{sargs}", f"{atypes}") ++ else: ++ self.statements[self.counter] = (name, "", "") ++ self.counter += 1 ++ ++ def apply(self, op, *args, **options): ++ sargs = ", ".join(args) ++ atypes = tuple(map(self.get_ref_type, args)) ++ rtype = options.get("rtype") ++ if rtype is None: ++ # assuming the first op argument defines the op type ++ rtype = atypes[0] ++ typ = f'({", ".join(atypes)}) -> {rtype}' ++ return self.assign(f"{op} {sargs}", typ=typ) ++ ++ def return_last(self): ++ ref = f"%{self.counter - 1}" ++ self.statements[self.counter] = ("return", ref, self.get_ref_type(ref)) ++ self.counter += 1 ++ ++ @property ++ def is_function(self): ++ return 
self.prefix.startwith("func.func") ++ ++ @property ++ def function_name(self): ++ if self.prefix.startswith("func.func"): ++ i = self.prefix.find("@") ++ j = self.prefix.find("(", i) ++ assert -1 not in {i, j}, self.prefix ++ return self.prefix[i + 1 : j] ++ ++ @property ++ def function_type(self): ++ if self.prefix.startswith("func.func"): ++ i = self.prefix.find("(", self.prefix.find("@")) ++ j = self.prefix.find("{", i) ++ assert -1 not in {i, j}, self.prefix ++ return self.prefix[i:j].strip() ++ ++ def get_function_type(self, name): ++ for block in self.parent.blocks: ++ if block.function_name == name: ++ return block.function_type ++ ++ def get_ref_type(self, ref): ++ assert ref.startswith("%"), ref ++ counter = int(ref[1:]) ++ typ = self.statements[counter][-1] ++ return typ.rsplit("->", 1)[-1].strip() + + + class SSA: +- """A light-weight SSA form factory.""" +- +- def __init__(self, prefix, suffix): +- self.prefix = prefix +- self.suffix = suffix +- self.blocks = [] +- +- @classmethod +- def make_module(cls, name): +- return SSA(f"module @{name} {{", "}") +- +- def make_function(self, name, args, rtype, attrs="private"): +- if rtype: +- b = Block(self, f"func.func {attrs} @{name}({args}) -> {rtype} {{", +- "}") +- else: +- b = Block(self, f"func.func {attrs} @{name}({args}) {{", "}") +- self.blocks.append(b) +- return b +- +- def tostr(self, tab=""): +- lines = [] +- lines.append(tab + self.prefix) +- for b in self.blocks: +- lines.extend(b.tostr(tab=tab + " ").split("\n")) +- lines.append(tab + self.suffix) +- return "\n".join(lines) +- +- def __str__(self): +- return self.tostr() ++ """A light-weight SSA form factory.""" ++ ++ def __init__(self, prefix, suffix): ++ self.prefix = prefix ++ self.suffix = suffix ++ self.blocks = [] ++ ++ @classmethod ++ def make_module(cls, name): ++ return SSA(f"module @{name} {{", "}") ++ ++ def make_function(self, name, args, rtype, attrs="private"): ++ if rtype: ++ b = Block(self, f"func.func {attrs} @{name}({args}) -> 
{rtype} {{", "}") ++ else: ++ b = Block(self, f"func.func {attrs} @{name}({args}) {{", "}") ++ self.blocks.append(b) ++ return b ++ ++ def tostr(self, tab=""): ++ lines = [] ++ lines.append(tab + self.prefix) ++ for b in self.blocks: ++ lines.extend(b.tostr(tab=tab + " ").split("\n")) ++ lines.append(tab + self.suffix) ++ return "\n".join(lines) ++ ++ def __str__(self): ++ return self.tostr() + + + def mlir_type(obj): +- if isinstance(obj, np.ndarray): +- s = "x".join(map(str, obj.shape)) +- t = { +- np.bool_: "i1", +- np.float16: "f16", +- np.float32: "f32", +- np.float64: "f64", +- np.complex64: "complex", +- np.complex128: "complex", +- }[obj.dtype.type] +- return f"tensor<{s}x{t}>" ++ if isinstance(obj, np.ndarray): ++ s = "x".join(map(str, obj.shape)) ++ t = { ++ np.bool_: "i1", ++ np.float16: "f16", ++ np.float32: "f32", ++ np.float64: "f64", ++ np.complex64: "complex", ++ np.complex128: "complex", ++ }[obj.dtype.type] ++ return f"tensor<{s}x{t}>" ++ else: ++ raise NotImplementedError(type(obj)) ++ ++ ++def shlo_constant(obj): ++ if isinstance(obj, bool): ++ v = str(obj).lower() ++ return f"stablehlo.constant dense<{v}>" ++ if isinstance(obj, np.ndarray): ++ if obj.dtype == np.bool_: ++ h = "".join(map(lambda n: "%01x" % n, obj.view(np.uint8))).upper() + else: +- raise NotImplementedError(type(obj)) +- +- +-def shlo_constant(obj): +- if isinstance(obj, bool): +- v = str(obj).lower() +- return f"stablehlo.constant dense<{v}>" +- if isinstance(obj, np.ndarray): +- if obj.dtype == np.bool_: +- h = "".join(map(lambda n: "%01x" % n, obj.view(np.uint8))).upper() +- else: +- h = "".join(map(lambda n: "%02x" % n, obj.view(np.uint8))).upper() +- return f'stablehlo.constant dense<"0x{h}">' +- else: +- raise NotImplementedError(type(obj)) ++ h = "".join(map(lambda n: "%02x" % n, obj.view(np.uint8))).upper() ++ return f'stablehlo.constant dense<"0x{h}">' ++ else: ++ raise NotImplementedError(type(obj)) + + + if __name__ == "__main__": +- main() ++ main() diff --ruN 
a/stablehlo/stablehlo/CMakeLists.txt b/stablehlo/stablehlo/CMakeLists.txt --- stablehlo/stablehlo/CMakeLists.txt +++ stablehlo/stablehlo/CMakeLists.txt @@ -175,18 +1062,6 @@ diff --ruN a/stablehlo/stablehlo/CMakeLists.txt b/stablehlo/stablehlo/CMakeLists add_subdirectory(integrations) add_subdirectory(reference) add_subdirectory(tests) -diff --ruN a/stablehlo/stablehlo/api/PortableApi.h b/stablehlo/stablehlo/api/PortableApi.h ---- stablehlo/stablehlo/api/PortableApi.h -+++ stablehlo/stablehlo/api/PortableApi.h -@@ -27,7 +27,7 @@ - - /// Return the current version for portable API. - /// Increments on all meaningful changes to this file. --inline int64_t getApiVersion() { return 6; } -+inline int64_t getApiVersion() { return 7; } - - // Get the current StableHLO version. - // diff --ruN a/stablehlo/stablehlo/experimental/BUILD.bazel b/stablehlo/stablehlo/experimental/BUILD.bazel --- stablehlo/stablehlo/experimental/BUILD.bazel +++ stablehlo/stablehlo/experimental/BUILD.bazel @@ -2955,41 +3830,21 @@ diff --ruN a/stablehlo/stablehlo/experimental/transforms/StablehloRefineShapes.c +} // namespace experimental +} // namespace stablehlo +} // namespace mlir -diff --ruN a/stablehlo/stablehlo/tests/TestUtils.cpp b/stablehlo/stablehlo/tests/TestUtils.cpp ---- stablehlo/stablehlo/tests/TestUtils.cpp -+++ stablehlo/stablehlo/tests/TestUtils.cpp -@@ -176,7 +176,8 @@ - GreedyRewriteConfig config; - config.maxIterations = 1; - config.useTopDownTraversal = true; -- config.enableRegionSimplification = false; -+ config.enableRegionSimplification = -+ mlir::GreedySimplifyRegionLevel::Disabled; - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); - } +diff --ruN a/stablehlo/stablehlo/integrations/python/tests/stablehlo.py b/stablehlo/stablehlo/integrations/python/tests/stablehlo.py +--- stablehlo/stablehlo/integrations/python/tests/stablehlo.py ++++ stablehlo/stablehlo/integrations/python/tests/stablehlo.py +@@ -283,11 +283,13 @@ + expected = arg + arg + assert 
(actual == expected).all() -diff --ruN a/stablehlo/stablehlo/transforms/StablehloCanonicalizeDynamism.cpp b/stablehlo/stablehlo/transforms/StablehloCanonicalizeDynamism.cpp ---- stablehlo/stablehlo/transforms/StablehloCanonicalizeDynamism.cpp -+++ stablehlo/stablehlo/transforms/StablehloCanonicalizeDynamism.cpp -@@ -300,7 +300,7 @@ ++ + @run + def test_get_smaller_version(): + curr_version = stablehlo.get_current_version() + min_version = stablehlo.get_minimum_version() + assert stablehlo.get_smaller_version(curr_version, min_version) == min_version ++ - LogicalResult initialize(MLIRContext* context) override { - config.useTopDownTraversal = true; -- config.enableRegionSimplification = true; -+ config.enableRegionSimplification = GreedySimplifyRegionLevel::Aggressive; - config.maxIterations = 2; - config.maxNumRewrites = GreedyRewriteConfig::kNoLimit; - config.strictMode = GreedyRewriteStrictness::AnyOp; -diff --ruN a/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp b/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp ---- stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp -+++ stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp -@@ -728,7 +728,7 @@ - // There have been recent refactors to applyPatternsAndFoldGreedily - // upstream, and that might be the reason. 
- config.useTopDownTraversal = true; -- config.enableRegionSimplification = true; -+ config.enableRegionSimplification = GreedySimplifyRegionLevel::Aggressive; - config.maxIterations = 2; - config.maxNumRewrites = GreedyRewriteConfig::kNoLimit; - config.strictMode = GreedyRewriteStrictness::AnyOp; + @run + def test_serialization_apis(): diff --git a/third_party/stablehlo/workspace.bzl b/third_party/stablehlo/workspace.bzl index 3b14d447210c95..15648501cf7f1c 100644 --- a/third_party/stablehlo/workspace.bzl +++ b/third_party/stablehlo/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): # LINT.IfChange - STABLEHLO_COMMIT = "dd48ec58d3bb8d674adf56715d4394102538fa84" - STABLEHLO_SHA256 = "beb4ddc7246326d92cc1a864fb387d6c2d103a5feda096f5b296c24a766ea080" + STABLEHLO_COMMIT = "f1f49945f3862a46ecd6c7fc111e9d7843c2b7da" + STABLEHLO_SHA256 = "5e79eaf23075e627c8cbcb4e8572f91fd2db70a4a96722f02490482f26177b4b" # LINT.ThenChange(Google-internal path) tf_http_archive( diff --git a/third_party/xla/third_party/stablehlo/temporary.patch b/third_party/xla/third_party/stablehlo/temporary.patch index 1db9bf674f5b38..0eb67a7bda2ee2 100755 --- a/third_party/xla/third_party/stablehlo/temporary.patch +++ b/third_party/xla/third_party/stablehlo/temporary.patch @@ -164,6 +164,893 @@ diff --ruN a/stablehlo/CMakeLists.txt b/stablehlo/CMakeLists.txt #------------------------------------------------------------------------------- # Directory setup +diff --ruN a/stablehlo/build_tools/math/generate_ChloDecompositionPatternsMath.py b/stablehlo/build_tools/math/generate_ChloDecompositionPatternsMath.py +--- stablehlo/build_tools/math/generate_ChloDecompositionPatternsMath.py ++++ stablehlo/build_tools/math/generate_ChloDecompositionPatternsMath.py +@@ -22,54 +22,64 @@ + + + def main(): +- try: +- import functional_algorithms as fa +- except ImportError as msg: +- print(f"Skipping: {msg}") +- return ++ try: ++ import functional_algorithms as 
fa ++ except ImportError as msg: ++ print(f"Skipping: {msg}") ++ return + +- fa_version = tuple(map(int, fa.__version__.split(".", 4)[:3])) +- if fa_version < (0, 4, 0): +- warnings.warn("functional_algorithm version 0.4.0 or newer is required," +- f" got {fa.__version__}") +- return ++ fa_version = tuple(map(int, fa.__version__.split(".", 4)[:3])) ++ if fa_version < (0, 4, 0): ++ warnings.warn( ++ "functional_algorithm version 0.4.0 or newer is required," ++ f" got {fa.__version__}" ++ ) ++ return + +- output_file = os.path.relpath( +- os.path.normpath( +- os.path.join(os.path.dirname(__file__), "..", "..", "stablehlo", +- "transforms", "ChloDecompositionPatternsMath.td")), +- os.getcwd()) ++ output_file = os.path.relpath( ++ os.path.normpath( ++ os.path.join( ++ os.path.dirname(__file__), ++ "..", ++ "..", ++ "stablehlo", ++ "transforms", ++ "ChloDecompositionPatternsMath.td", ++ ) ++ ), ++ os.getcwd(), ++ ) + +- sources = [] +- target = fa.targets.stablehlo +- for chloname, fname, args in [ +- ("CHLO_AsinOp", "complex_asin", ("z:complex",)), +- ("CHLO_AsinOp", "real_asin", ("x:float",)), +- ]: +- func = getattr(fa.algorithms, fname, None) +- if func is None: +- warnings.warn( +- "{fa.algorithms.__name__} does not define {fname}. Skipping.") +- continue +- ctx = fa.Context(paths=[fa.algorithms]) +- graph = ctx.trace(func, *args).implement_missing(target).simplify() +- graph.props.update(name=chloname) +- src = graph.tostring(target) +- sources.append(target.make_comment( +- func.__doc__)) if func.__doc__ else None +- sources[-1] += src +- source = "\n\n".join(sources) + "\n" ++ sources = [] ++ target = fa.targets.stablehlo ++ for chloname, fname, args in [ ++ ("CHLO_AsinOp", "complex_asin", ("z:complex",)), ++ ("CHLO_AsinOp", "real_asin", ("x:float",)), ++ ]: ++ func = getattr(fa.algorithms, fname, None) ++ if func is None: ++ warnings.warn( ++ "{fa.algorithms.__name__} does not define {fname}. Skipping." 
++ ) ++ continue ++ ctx = fa.Context(paths=[fa.algorithms]) ++ graph = ctx.trace(func, *args).implement_missing(target).simplify() ++ graph.props.update(name=chloname) ++ src = graph.tostring(target) ++ sources.append(target.make_comment(func.__doc__)) if func.__doc__ else None ++ sources[-1] += src ++ source = "\n\n".join(sources) + "\n" + +- if os.path.isfile(output_file): +- f = open(output_file, "r") +- content = f.read() +- f.close() +- if content.endswith(source): +- print(f"{output_file} is up-to-date.") +- return ++ if os.path.isfile(output_file): ++ f = open(output_file, "r") ++ content = f.read() ++ f.close() ++ if content.endswith(source): ++ print(f"{output_file} is up-to-date.") ++ return + +- f = open(output_file, "w") +- f.write("""\ ++ f = open(output_file, "w") ++ f.write("""\ + /* Copyright 2024 The StableHLO Authors. + + Licensed under the Apache License, Version 2.0 (the "License"); +@@ -86,15 +96,14 @@ + ==============================================================================*/ + + """) +- f.write( +- target.make_comment(f"""\ ++ f.write(target.make_comment(f"""\ + + This file is generated using functional_algorithms tool ({fa.__version__}). 
+ See build_tools/math/README.md for more information.""") + "\n") +- f.write(source) +- f.close() +- print(f"Created {output_file}") ++ f.write(source) ++ f.close() ++ print(f"Created {output_file}") + + + if __name__ == "__main__": +- main() ++ main() +diff --ruN a/stablehlo/build_tools/math/generate_tests.py b/stablehlo/build_tools/math/generate_tests.py +--- stablehlo/build_tools/math/generate_tests.py ++++ stablehlo/build_tools/math/generate_tests.py +@@ -55,374 +55,394 @@ + + + def main(): +- try: +- import functional_algorithms as fa +- except ImportError as msg: +- print(f"Skipping: {msg}") +- return +- +- fa_version = tuple(map(int, fa.__version__.split(".", 4)[:3])) +- if fa_version < (0, 4, 0): +- warnings.warn("functional_algorithm version 0.4.0 or newer is required," +- f" got {fa.__version__}") +- return +- +- target_dir = os.path.relpath( +- os.path.normpath( +- os.path.join(os.path.dirname(__file__), "..", "..", "stablehlo", +- "tests", "math")), os.getcwd()) +- +- flush_subnormals = False +- for op in operations: +- opname = op["name"] +- mpmath_opname = op.get("mpmath_name", opname) +- size_re = size_im = op.get("size", default_size) +- extra_prec_multiplier = op.get("extra_prec_multiplier", +- default_extra_prec_multiplier) +- max_ulp_difference = op.get("max_ulp_difference", +- default_max_ulp_difference) +- +- nmp = fa.utils.numpy_with_mpmath( +- extra_prec_multiplier=extra_prec_multiplier, +- flush_subnormals=flush_subnormals) +- for dtype in [np.complex64, np.complex128, np.float32, np.float64]: +- fi = np.finfo(dtype) +- +- float_dtype = to_float_dtype[dtype] +- finfo = np.finfo(float_dtype) +- +- if dtype in [np.complex64, np.complex128]: +- samples = fa.utils.complex_samples( +- size=(size_re, size_im), +- dtype=dtype, +- include_subnormal=not flush_subnormals).flatten() +- else: +- samples = fa.utils.real_samples( +- size=size_re * size_im, +- dtype=dtype, +- include_subnormal=not flush_subnormals).flatten() +- +- expected = getattr(nmp, 
mpmath_opname)(samples) +- +- module_name = f"{opname}_{dtype.__name__}" +- m = SSA.make_module(module_name) +- +- samples_func = m.make_function("samples", "", mlir_type(samples)) +- samples_func.assign(samples) +- samples_func.return_last() +- +- expected_func = m.make_function("expected", "", mlir_type(expected)) +- expected_func.assign(expected) +- expected_func.return_last() +- +- main_func = m.make_function("main", "", "", "public") +- +- ref_samples = main_func.call("samples") +- actual = main_func.composite(f"chlo.{opname}", ref_samples) +- expected = main_func.call("expected") +- +- main_func.void_call( +- "check.expect_close", +- actual, +- expected, +- f"max_ulp_difference = {max_ulp_difference}", +- atypes=", ".join(map(main_func.get_ref_type, +- [actual, expected])), +- ) +- main_func.void_call("func.return") +- source = str(m).rstrip() + "\n" +- fname = os.path.join(target_dir, f"{module_name}.mlir") +- if os.path.isfile(fname): +- f = open(fname, "r") +- content = f.read() +- f.close() +- if content.endswith(source): +- print(f"{fname} is up-to-date.") +- continue +- +- f = open(fname, "w") +- f.write( +- "// RUN: stablehlo-opt --chlo-legalize-to-stablehlo %s | stablehlo-translate --interpret\n" +- ) +- f.write( +- "// This file is generated, see build_tools/math/README.md for more information.\n" +- ) +- f.write(source) +- f.close() +- print(f"Created {fname}") +- +- # Testing ULP difference +- for dtype in [np.float32, np.float64]: +- fi = np.finfo(dtype) +- +- max_ulp_difference = 0 +- min_ulp_difference = 0 +- +- finfo = np.finfo(dtype) +- module_name = f"ulp_difference_{dtype.__name__}" +- m = SSA.make_module(module_name) +- +- main_func = m.make_function("main", "", "", "public") +- +- def samples_generator(): +- data = [ +- -finfo.max, -1e9 - 1.2, -finfo.smallest_normal, +- -finfo.smallest_subnormal, 0, finfo.smallest_subnormal, +- finfo.smallest_normal, 1.2, 1e9 +- ] +- for expected_ulp_difference in [0, 1, 5, 50]: +- if 
expected_ulp_difference == 0: +- actual = np.array(data + [np.inf, -np.inf, np.nan], +- dtype=dtype) +- else: +- actual = np.array(data, dtype=dtype) +- shifted = actual +- for i in range(expected_ulp_difference): +- shifted = np.nextafter(shifted, np.inf, dtype=dtype) +- label = str(expected_ulp_difference) +- yield actual, shifted, expected_ulp_difference, label +- +- actual = np.array([np.inf] * 5, dtype=dtype) +- shifted = np.array([-np.inf, np.nan, 0, 1.2, finfo.max], +- dtype=dtype) +- yield actual, shifted, 2**64 - 1, "nonfinite" +- +- for actual, shifted, expected_ulp_difference, label in samples_generator( +- ): +- +- actual_func = m.make_function(f"actual_{label}", "", +- mlir_type(actual)) +- actual_func.comment(f'{list(actual)}') +- actual_func.assign(actual) +- actual_func.return_last() +- +- shifted_func = m.make_function(f"shifted_{label}", "", +- mlir_type(shifted)) +- shifted_func.comment(f'{list(shifted)}') +- shifted_func.assign(shifted) +- shifted_func.return_last() +- +- actual_values = main_func.call(f"actual_{label}") +- shifted_values = main_func.call(f"shifted_{label}") +- +- main_func.void_call( +- "check.expect_close", +- actual_values, +- shifted_values, +- f"max_ulp_difference = {expected_ulp_difference}", +- f"min_ulp_difference = {expected_ulp_difference}", +- atypes=", ".join( +- map(main_func.get_ref_type, +- [actual_values, shifted_values])), +- ) +- +- main_func.void_call("func.return") +- source = str(m).rstrip() + "\n" +- fname = os.path.join(target_dir, f"{module_name}.mlir") +- if os.path.isfile(fname): +- f = open(fname, "r") +- content = f.read() +- f.close() +- if content.endswith(source): +- print(f"{fname} is up-to-date.") +- continue +- +- f = open(fname, "w") +- f.write( +- "// RUN: stablehlo-opt --chlo-legalize-to-stablehlo %s | stablehlo-translate --interpret\n" +- ) +- f.write( +- "// This file is generated, see build_tools/math/README.md for more information.\n" +- ) +- f.write(source) ++ try: ++ import 
functional_algorithms as fa ++ except ImportError as msg: ++ print(f"Skipping: {msg}") ++ return ++ ++ fa_version = tuple(map(int, fa.__version__.split(".", 4)[:3])) ++ if fa_version < (0, 4, 0): ++ warnings.warn( ++ "functional_algorithm version 0.4.0 or newer is required," ++ f" got {fa.__version__}" ++ ) ++ return ++ ++ target_dir = os.path.relpath( ++ os.path.normpath( ++ os.path.join( ++ os.path.dirname(__file__), ++ "..", ++ "..", ++ "stablehlo", ++ "tests", ++ "math", ++ ) ++ ), ++ os.getcwd(), ++ ) ++ ++ flush_subnormals = False ++ for op in operations: ++ opname = op["name"] ++ mpmath_opname = op.get("mpmath_name", opname) ++ size_re = size_im = op.get("size", default_size) ++ extra_prec_multiplier = op.get( ++ "extra_prec_multiplier", default_extra_prec_multiplier ++ ) ++ max_ulp_difference = op.get( ++ "max_ulp_difference", default_max_ulp_difference ++ ) ++ ++ nmp = fa.utils.numpy_with_mpmath( ++ extra_prec_multiplier=extra_prec_multiplier, ++ flush_subnormals=flush_subnormals, ++ ) ++ for dtype in [np.complex64, np.complex128, np.float32, np.float64]: ++ fi = np.finfo(dtype) ++ ++ float_dtype = to_float_dtype[dtype] ++ finfo = np.finfo(float_dtype) ++ ++ if dtype in [np.complex64, np.complex128]: ++ samples = fa.utils.complex_samples( ++ size=(size_re, size_im), ++ dtype=dtype, ++ include_subnormal=not flush_subnormals, ++ ).flatten() ++ else: ++ samples = fa.utils.real_samples( ++ size=size_re * size_im, ++ dtype=dtype, ++ include_subnormal=not flush_subnormals, ++ ).flatten() ++ ++ expected = getattr(nmp, mpmath_opname)(samples) ++ ++ module_name = f"{opname}_{dtype.__name__}" ++ m = SSA.make_module(module_name) ++ ++ samples_func = m.make_function("samples", "", mlir_type(samples)) ++ samples_func.assign(samples) ++ samples_func.return_last() ++ ++ expected_func = m.make_function("expected", "", mlir_type(expected)) ++ expected_func.assign(expected) ++ expected_func.return_last() ++ ++ main_func = m.make_function("main", "", "", "public") ++ ++ 
ref_samples = main_func.call("samples") ++ actual = main_func.composite(f"chlo.{opname}", ref_samples) ++ expected = main_func.call("expected") ++ ++ main_func.void_call( ++ "check.expect_close", ++ actual, ++ expected, ++ f"max_ulp_difference = {max_ulp_difference}", ++ atypes=", ".join(map(main_func.get_ref_type, [actual, expected])), ++ ) ++ main_func.void_call("func.return") ++ source = str(m).rstrip() + "\n" ++ fname = os.path.join(target_dir, f"{module_name}.mlir") ++ if os.path.isfile(fname): ++ f = open(fname, "r") ++ content = f.read() + f.close() +- print(f"Created {fname}") ++ if content.endswith(source): ++ print(f"{fname} is up-to-date.") ++ continue ++ ++ f = open(fname, "w") ++ f.write( ++ "// RUN: stablehlo-opt --chlo-legalize-to-stablehlo %s |" ++ " stablehlo-translate --interpret\n" ++ ) ++ f.write( ++ "// This file is generated, see build_tools/math/README.md for more" ++ " information.\n" ++ ) ++ f.write(source) ++ f.close() ++ print(f"Created {fname}") ++ ++ # Testing ULP difference ++ for dtype in [np.float32, np.float64]: ++ fi = np.finfo(dtype) ++ ++ max_ulp_difference = 0 ++ min_ulp_difference = 0 ++ ++ finfo = np.finfo(dtype) ++ module_name = f"ulp_difference_{dtype.__name__}" ++ m = SSA.make_module(module_name) ++ ++ main_func = m.make_function("main", "", "", "public") ++ ++ def samples_generator(): ++ data = [ ++ -finfo.max, ++ -1e9 - 1.2, ++ -finfo.smallest_normal, ++ -finfo.smallest_subnormal, ++ 0, ++ finfo.smallest_subnormal, ++ finfo.smallest_normal, ++ 1.2, ++ 1e9, ++ ] ++ for expected_ulp_difference in [0, 1, 5, 50]: ++ if expected_ulp_difference == 0: ++ actual = np.array(data + [np.inf, -np.inf, np.nan], dtype=dtype) ++ else: ++ actual = np.array(data, dtype=dtype) ++ shifted = actual ++ for i in range(expected_ulp_difference): ++ shifted = np.nextafter(shifted, np.inf, dtype=dtype) ++ label = str(expected_ulp_difference) ++ yield actual, shifted, expected_ulp_difference, label ++ ++ actual = np.array([np.inf] * 5, dtype=dtype) 
++ shifted = np.array([-np.inf, np.nan, 0, 1.2, finfo.max], dtype=dtype) ++ yield actual, shifted, 2**64 - 1, "nonfinite" ++ ++ for actual, shifted, expected_ulp_difference, label in samples_generator(): ++ ++ actual_func = m.make_function(f"actual_{label}", "", mlir_type(actual)) ++ actual_func.comment(f"{list(actual)}") ++ actual_func.assign(actual) ++ actual_func.return_last() ++ ++ shifted_func = m.make_function(f"shifted_{label}", "", mlir_type(shifted)) ++ shifted_func.comment(f"{list(shifted)}") ++ shifted_func.assign(shifted) ++ shifted_func.return_last() ++ ++ actual_values = main_func.call(f"actual_{label}") ++ shifted_values = main_func.call(f"shifted_{label}") ++ ++ main_func.void_call( ++ "check.expect_close", ++ actual_values, ++ shifted_values, ++ f"max_ulp_difference = {expected_ulp_difference}", ++ f"min_ulp_difference = {expected_ulp_difference}", ++ atypes=", ".join( ++ map(main_func.get_ref_type, [actual_values, shifted_values]) ++ ), ++ ) ++ ++ main_func.void_call("func.return") ++ source = str(m).rstrip() + "\n" ++ fname = os.path.join(target_dir, f"{module_name}.mlir") ++ if os.path.isfile(fname): ++ f = open(fname, "r") ++ content = f.read() ++ f.close() ++ if content.endswith(source): ++ print(f"{fname} is up-to-date.") ++ continue ++ ++ f = open(fname, "w") ++ f.write( ++ "// RUN: stablehlo-opt --chlo-legalize-to-stablehlo %s |" ++ " stablehlo-translate --interpret\n" ++ ) ++ f.write( ++ "// This file is generated, see build_tools/math/README.md for more" ++ " information.\n" ++ ) ++ f.write(source) ++ f.close() ++ print(f"Created {fname}") + + + class Block: +- """A data structure used in SSA""" +- +- def __init__(self, parent, prefix, suffix, start_counter=0): +- self.parent = parent +- self.prefix = prefix +- self.suffix = suffix +- self.counter = start_counter +- self.statements = {} +- +- def tostr(self, tab=""): +- lines = [] +- lines.append(tab + self.prefix) +- for i in sorted(self.statements): +- op, expr, typ = self.statements[i] 
+- if op == "//": +- lines.append(f"{tab} {op} {expr}") +- elif typ: +- lines.append(f"{tab} {op} {expr} : {typ}") +- else: +- assert not expr, (op, expr, typ) +- lines.append(f"{tab} {op}") +- lines.append(tab + self.suffix) +- return "\n".join(lines) +- +- def comment(self, message): +- # add comment to code +- self.statements[self.counter] = ("//", message, None) +- self.counter += 1 +- +- def assign(self, expr, typ=None): +- if isinstance(expr, np.ndarray): +- assert typ is None, typ +- typ = mlir_type(expr) +- expr = shlo_constant(expr) +- elif isinstance(expr, str) and typ is not None: +- pass +- elif isinstance(expr, bool) and typ is not None: +- expr = shlo_constant(expr) +- else: +- raise NotImplementedError((expr, typ)) +- target = f"%{self.counter}" +- self.statements[self.counter] = (f"{target} =", expr, typ) +- self.counter += 1 +- return target +- +- def call(self, name, *args): +- # call function created with make_function +- sargs = ", ".join(args) +- return self.assign(f"call @{name}({sargs})", +- typ=self.get_function_type(name)) +- +- def composite(self, name, *args, **options): +- sargs = ", ".join(args) +- atypes = tuple(map(self.get_ref_type, args)) +- rtype = options.get("rtype") +- if rtype is None: +- # assuming the first op argument defines the op type +- rtype = atypes[0] +- sargs = ", ".join(args) +- typ = f'({", ".join(atypes)}) -> {rtype}' +- return self.assign(f'"{name}"({sargs})', typ=typ) +- +- def void_call(self, name, *args, **options): +- # call function that has void return +- if args: +- sargs = ", ".join(args) +- atypes = options.get("atypes") +- if atypes is None: +- atypes = ", ".join(map(self.get_ref_type, args)) +- self.statements[self.counter] = (name, f"{sargs}", f"{atypes}") +- else: +- self.statements[self.counter] = (name, "", "") +- self.counter += 1 +- +- def apply(self, op, *args, **options): +- sargs = ", ".join(args) +- atypes = tuple(map(self.get_ref_type, args)) +- rtype = options.get("rtype") +- if rtype is 
None: +- # assuming the first op argument defines the op type +- rtype = atypes[0] +- typ = f'({", ".join(atypes)}) -> {rtype}' +- return self.assign(f"{op} {sargs}", typ=typ) +- +- def return_last(self): +- ref = f"%{self.counter - 1}" +- self.statements[self.counter] = ("return", ref, self.get_ref_type(ref)) +- self.counter += 1 +- +- @property +- def is_function(self): +- return self.prefix.startwith("func.func") +- +- @property +- def function_name(self): +- if self.prefix.startswith("func.func"): +- i = self.prefix.find("@") +- j = self.prefix.find("(", i) +- assert -1 not in {i, j}, self.prefix +- return self.prefix[i + 1:j] +- +- @property +- def function_type(self): +- if self.prefix.startswith("func.func"): +- i = self.prefix.find("(", self.prefix.find("@")) +- j = self.prefix.find("{", i) +- assert -1 not in {i, j}, self.prefix +- return self.prefix[i:j].strip() +- +- def get_function_type(self, name): +- for block in self.parent.blocks: +- if block.function_name == name: +- return block.function_type +- +- def get_ref_type(self, ref): +- assert ref.startswith("%"), ref +- counter = int(ref[1:]) +- typ = self.statements[counter][-1] +- return typ.rsplit("->", 1)[-1].strip() ++ """A data structure used in SSA""" ++ ++ def __init__(self, parent, prefix, suffix, start_counter=0): ++ self.parent = parent ++ self.prefix = prefix ++ self.suffix = suffix ++ self.counter = start_counter ++ self.statements = {} ++ ++ def tostr(self, tab=""): ++ lines = [] ++ lines.append(tab + self.prefix) ++ for i in sorted(self.statements): ++ op, expr, typ = self.statements[i] ++ if op == "//": ++ lines.append(f"{tab} {op} {expr}") ++ elif typ: ++ lines.append(f"{tab} {op} {expr} : {typ}") ++ else: ++ assert not expr, (op, expr, typ) ++ lines.append(f"{tab} {op}") ++ lines.append(tab + self.suffix) ++ return "\n".join(lines) ++ ++ def comment(self, message): ++ # add comment to code ++ self.statements[self.counter] = ("//", message, None) ++ self.counter += 1 ++ ++ def 
assign(self, expr, typ=None): ++ if isinstance(expr, np.ndarray): ++ assert typ is None, typ ++ typ = mlir_type(expr) ++ expr = shlo_constant(expr) ++ elif isinstance(expr, str) and typ is not None: ++ pass ++ elif isinstance(expr, bool) and typ is not None: ++ expr = shlo_constant(expr) ++ else: ++ raise NotImplementedError((expr, typ)) ++ target = f"%{self.counter}" ++ self.statements[self.counter] = (f"{target} =", expr, typ) ++ self.counter += 1 ++ return target ++ ++ def call(self, name, *args): ++ # call function created with make_function ++ sargs = ", ".join(args) ++ return self.assign( ++ f"call @{name}({sargs})", typ=self.get_function_type(name) ++ ) ++ ++ def composite(self, name, *args, **options): ++ sargs = ", ".join(args) ++ atypes = tuple(map(self.get_ref_type, args)) ++ rtype = options.get("rtype") ++ if rtype is None: ++ # assuming the first op argument defines the op type ++ rtype = atypes[0] ++ sargs = ", ".join(args) ++ typ = f'({", ".join(atypes)}) -> {rtype}' ++ return self.assign(f'"{name}"({sargs})', typ=typ) ++ ++ def void_call(self, name, *args, **options): ++ # call function that has void return ++ if args: ++ sargs = ", ".join(args) ++ atypes = options.get("atypes") ++ if atypes is None: ++ atypes = ", ".join(map(self.get_ref_type, args)) ++ self.statements[self.counter] = (name, f"{sargs}", f"{atypes}") ++ else: ++ self.statements[self.counter] = (name, "", "") ++ self.counter += 1 ++ ++ def apply(self, op, *args, **options): ++ sargs = ", ".join(args) ++ atypes = tuple(map(self.get_ref_type, args)) ++ rtype = options.get("rtype") ++ if rtype is None: ++ # assuming the first op argument defines the op type ++ rtype = atypes[0] ++ typ = f'({", ".join(atypes)}) -> {rtype}' ++ return self.assign(f"{op} {sargs}", typ=typ) ++ ++ def return_last(self): ++ ref = f"%{self.counter - 1}" ++ self.statements[self.counter] = ("return", ref, self.get_ref_type(ref)) ++ self.counter += 1 ++ ++ @property ++ def is_function(self): ++ return 
self.prefix.startwith("func.func") ++ ++ @property ++ def function_name(self): ++ if self.prefix.startswith("func.func"): ++ i = self.prefix.find("@") ++ j = self.prefix.find("(", i) ++ assert -1 not in {i, j}, self.prefix ++ return self.prefix[i + 1 : j] ++ ++ @property ++ def function_type(self): ++ if self.prefix.startswith("func.func"): ++ i = self.prefix.find("(", self.prefix.find("@")) ++ j = self.prefix.find("{", i) ++ assert -1 not in {i, j}, self.prefix ++ return self.prefix[i:j].strip() ++ ++ def get_function_type(self, name): ++ for block in self.parent.blocks: ++ if block.function_name == name: ++ return block.function_type ++ ++ def get_ref_type(self, ref): ++ assert ref.startswith("%"), ref ++ counter = int(ref[1:]) ++ typ = self.statements[counter][-1] ++ return typ.rsplit("->", 1)[-1].strip() + + + class SSA: +- """A light-weight SSA form factory.""" +- +- def __init__(self, prefix, suffix): +- self.prefix = prefix +- self.suffix = suffix +- self.blocks = [] +- +- @classmethod +- def make_module(cls, name): +- return SSA(f"module @{name} {{", "}") +- +- def make_function(self, name, args, rtype, attrs="private"): +- if rtype: +- b = Block(self, f"func.func {attrs} @{name}({args}) -> {rtype} {{", +- "}") +- else: +- b = Block(self, f"func.func {attrs} @{name}({args}) {{", "}") +- self.blocks.append(b) +- return b +- +- def tostr(self, tab=""): +- lines = [] +- lines.append(tab + self.prefix) +- for b in self.blocks: +- lines.extend(b.tostr(tab=tab + " ").split("\n")) +- lines.append(tab + self.suffix) +- return "\n".join(lines) +- +- def __str__(self): +- return self.tostr() ++ """A light-weight SSA form factory.""" ++ ++ def __init__(self, prefix, suffix): ++ self.prefix = prefix ++ self.suffix = suffix ++ self.blocks = [] ++ ++ @classmethod ++ def make_module(cls, name): ++ return SSA(f"module @{name} {{", "}") ++ ++ def make_function(self, name, args, rtype, attrs="private"): ++ if rtype: ++ b = Block(self, f"func.func {attrs} @{name}({args}) -> 
{rtype} {{", "}") ++ else: ++ b = Block(self, f"func.func {attrs} @{name}({args}) {{", "}") ++ self.blocks.append(b) ++ return b ++ ++ def tostr(self, tab=""): ++ lines = [] ++ lines.append(tab + self.prefix) ++ for b in self.blocks: ++ lines.extend(b.tostr(tab=tab + " ").split("\n")) ++ lines.append(tab + self.suffix) ++ return "\n".join(lines) ++ ++ def __str__(self): ++ return self.tostr() + + + def mlir_type(obj): +- if isinstance(obj, np.ndarray): +- s = "x".join(map(str, obj.shape)) +- t = { +- np.bool_: "i1", +- np.float16: "f16", +- np.float32: "f32", +- np.float64: "f64", +- np.complex64: "complex", +- np.complex128: "complex", +- }[obj.dtype.type] +- return f"tensor<{s}x{t}>" ++ if isinstance(obj, np.ndarray): ++ s = "x".join(map(str, obj.shape)) ++ t = { ++ np.bool_: "i1", ++ np.float16: "f16", ++ np.float32: "f32", ++ np.float64: "f64", ++ np.complex64: "complex", ++ np.complex128: "complex", ++ }[obj.dtype.type] ++ return f"tensor<{s}x{t}>" ++ else: ++ raise NotImplementedError(type(obj)) ++ ++ ++def shlo_constant(obj): ++ if isinstance(obj, bool): ++ v = str(obj).lower() ++ return f"stablehlo.constant dense<{v}>" ++ if isinstance(obj, np.ndarray): ++ if obj.dtype == np.bool_: ++ h = "".join(map(lambda n: "%01x" % n, obj.view(np.uint8))).upper() + else: +- raise NotImplementedError(type(obj)) +- +- +-def shlo_constant(obj): +- if isinstance(obj, bool): +- v = str(obj).lower() +- return f"stablehlo.constant dense<{v}>" +- if isinstance(obj, np.ndarray): +- if obj.dtype == np.bool_: +- h = "".join(map(lambda n: "%01x" % n, obj.view(np.uint8))).upper() +- else: +- h = "".join(map(lambda n: "%02x" % n, obj.view(np.uint8))).upper() +- return f'stablehlo.constant dense<"0x{h}">' +- else: +- raise NotImplementedError(type(obj)) ++ h = "".join(map(lambda n: "%02x" % n, obj.view(np.uint8))).upper() ++ return f'stablehlo.constant dense<"0x{h}">' ++ else: ++ raise NotImplementedError(type(obj)) + + + if __name__ == "__main__": +- main() ++ main() diff --ruN 
a/stablehlo/stablehlo/CMakeLists.txt b/stablehlo/stablehlo/CMakeLists.txt --- stablehlo/stablehlo/CMakeLists.txt +++ stablehlo/stablehlo/CMakeLists.txt @@ -175,18 +1062,6 @@ diff --ruN a/stablehlo/stablehlo/CMakeLists.txt b/stablehlo/stablehlo/CMakeLists add_subdirectory(integrations) add_subdirectory(reference) add_subdirectory(tests) -diff --ruN a/stablehlo/stablehlo/api/PortableApi.h b/stablehlo/stablehlo/api/PortableApi.h ---- stablehlo/stablehlo/api/PortableApi.h -+++ stablehlo/stablehlo/api/PortableApi.h -@@ -27,7 +27,7 @@ - - /// Return the current version for portable API. - /// Increments on all meaningful changes to this file. --inline int64_t getApiVersion() { return 6; } -+inline int64_t getApiVersion() { return 7; } - - // Get the current StableHLO version. - // diff --ruN a/stablehlo/stablehlo/experimental/BUILD.bazel b/stablehlo/stablehlo/experimental/BUILD.bazel --- stablehlo/stablehlo/experimental/BUILD.bazel +++ stablehlo/stablehlo/experimental/BUILD.bazel @@ -2955,41 +3830,21 @@ diff --ruN a/stablehlo/stablehlo/experimental/transforms/StablehloRefineShapes.c +} // namespace experimental +} // namespace stablehlo +} // namespace mlir -diff --ruN a/stablehlo/stablehlo/tests/TestUtils.cpp b/stablehlo/stablehlo/tests/TestUtils.cpp ---- stablehlo/stablehlo/tests/TestUtils.cpp -+++ stablehlo/stablehlo/tests/TestUtils.cpp -@@ -176,7 +176,8 @@ - GreedyRewriteConfig config; - config.maxIterations = 1; - config.useTopDownTraversal = true; -- config.enableRegionSimplification = false; -+ config.enableRegionSimplification = -+ mlir::GreedySimplifyRegionLevel::Disabled; - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); - } +diff --ruN a/stablehlo/stablehlo/integrations/python/tests/stablehlo.py b/stablehlo/stablehlo/integrations/python/tests/stablehlo.py +--- stablehlo/stablehlo/integrations/python/tests/stablehlo.py ++++ stablehlo/stablehlo/integrations/python/tests/stablehlo.py +@@ -283,11 +283,13 @@ + expected = arg + arg + assert 
(actual == expected).all() -diff --ruN a/stablehlo/stablehlo/transforms/StablehloCanonicalizeDynamism.cpp b/stablehlo/stablehlo/transforms/StablehloCanonicalizeDynamism.cpp ---- stablehlo/stablehlo/transforms/StablehloCanonicalizeDynamism.cpp -+++ stablehlo/stablehlo/transforms/StablehloCanonicalizeDynamism.cpp -@@ -300,7 +300,7 @@ ++ + @run + def test_get_smaller_version(): + curr_version = stablehlo.get_current_version() + min_version = stablehlo.get_minimum_version() + assert stablehlo.get_smaller_version(curr_version, min_version) == min_version ++ - LogicalResult initialize(MLIRContext* context) override { - config.useTopDownTraversal = true; -- config.enableRegionSimplification = true; -+ config.enableRegionSimplification = GreedySimplifyRegionLevel::Aggressive; - config.maxIterations = 2; - config.maxNumRewrites = GreedyRewriteConfig::kNoLimit; - config.strictMode = GreedyRewriteStrictness::AnyOp; -diff --ruN a/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp b/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp ---- stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp -+++ stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp -@@ -728,7 +728,7 @@ - // There have been recent refactors to applyPatternsAndFoldGreedily - // upstream, and that might be the reason. 
- config.useTopDownTraversal = true; -- config.enableRegionSimplification = true; -+ config.enableRegionSimplification = GreedySimplifyRegionLevel::Aggressive; - config.maxIterations = 2; - config.maxNumRewrites = GreedyRewriteConfig::kNoLimit; - config.strictMode = GreedyRewriteStrictness::AnyOp; + @run + def test_serialization_apis(): diff --git a/third_party/xla/third_party/stablehlo/workspace.bzl b/third_party/xla/third_party/stablehlo/workspace.bzl index 3b14d447210c95..15648501cf7f1c 100644 --- a/third_party/xla/third_party/stablehlo/workspace.bzl +++ b/third_party/xla/third_party/stablehlo/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): # LINT.IfChange - STABLEHLO_COMMIT = "dd48ec58d3bb8d674adf56715d4394102538fa84" - STABLEHLO_SHA256 = "beb4ddc7246326d92cc1a864fb387d6c2d103a5feda096f5b296c24a766ea080" + STABLEHLO_COMMIT = "f1f49945f3862a46ecd6c7fc111e9d7843c2b7da" + STABLEHLO_SHA256 = "5e79eaf23075e627c8cbcb4e8572f91fd2db70a4a96722f02490482f26177b4b" # LINT.ThenChange(Google-internal path) tf_http_archive( diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/chlo/chlo_legalize_to_mhlo.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/chlo/chlo_legalize_to_mhlo.mlir index fdc11fedbd392f..82452cc8a98985 100644 --- a/third_party/xla/xla/mlir_hlo/tests/Dialect/chlo/chlo_legalize_to_mhlo.mlir +++ b/third_party/xla/xla/mlir_hlo/tests/Dialect/chlo/chlo_legalize_to_mhlo.mlir @@ -3,16 +3,16 @@ // CHECK-LABEL: func.func @asin_bf16( // CHECK-SAME: %[[TMP_arg0:.*]]: tensor -// CHECK-NEXT: %[[TMP_0:.*]] = mhlo.constant dense<2.000000e+00> : tensor -// CHECK-NEXT: %[[TMP_1:.*]] = mhlo.constant dense<1.000000e+00> : tensor -// CHECK-NEXT: %[[TMP_2:.*]] = mhlo.constant dense<1.000000e+00> : tensor -// CHECK-NEXT: %[[TMP_3:.*]] = mhlo.multiply %[[TMP_arg0]], %[[TMP_arg0]] : tensor -// CHECK-NEXT: %[[TMP_4:.*]] = mhlo.subtract %[[TMP_2]], %[[TMP_3]] : tensor -// CHECK-NEXT: %[[TMP_5:.*]] = mhlo.sqrt %[[TMP_4]] : 
tensor -// CHECK-NEXT: %[[TMP_6:.*]] = mhlo.add %[[TMP_1]], %[[TMP_5]] : tensor -// CHECK-NEXT: %[[TMP_7:.*]] = mhlo.atan2 %[[TMP_arg0]], %[[TMP_6]] : tensor -// CHECK-NEXT: %[[TMP_8:.*]] = mhlo.multiply %[[TMP_0]], %[[TMP_7]] : tensor -// CHECK-NEXT: return %[[TMP_8]] : tensor +// CHECK-NEXT: %[[TMP_0:.*]] = mhlo.constant dense<2.000000e+00> : tensor +// CHECK-NEXT: %[[TMP_1:.*]] = mhlo.constant dense<1.000000e+00> : tensor +// CHECK-NEXT: %[[TMP_2:.*]] = mhlo.subtract %[[TMP_1]], %[[TMP_arg0]] : tensor +// CHECK-NEXT: %[[TMP_3:.*]] = mhlo.add %[[TMP_1]], %[[TMP_arg0]] : tensor +// CHECK-NEXT: %[[TMP_4:.*]] = mhlo.multiply %[[TMP_2]], %[[TMP_3]] : tensor +// CHECK-NEXT: %[[TMP_5:.*]] = mhlo.sqrt %[[TMP_4]] : tensor +// CHECK-NEXT: %[[TMP_6:.*]] = mhlo.add %[[TMP_1]], %[[TMP_5]] : tensor +// CHECK-NEXT: %[[TMP_7:.*]] = mhlo.atan2 %[[TMP_arg0]], %[[TMP_6]] : tensor +// CHECK-NEXT: %[[TMP_8:.*]] = mhlo.multiply %[[TMP_0]], %[[TMP_7]] : tensor +// CHECK-NEXT: return %[[TMP_8]] : tensor func.func @asin_bf16(%arg : tensor) -> tensor { %result = "chlo.asin"(%arg) : (tensor) -> tensor func.return %result : tensor @@ -24,9 +24,9 @@ func.func @asin_bf16(%arg : tensor) -> tensor { // CHECK-SAME: %[[TMP_arg0:.*]]: tensor // CHECK-NEXT: %[[TMP_0:.*]] = mhlo.constant dense<2.000000e+00> : tensor // CHECK-NEXT: %[[TMP_1:.*]] = mhlo.constant dense<1.000000e+00> : tensor -// CHECK-NEXT: %[[TMP_2:.*]] = mhlo.constant dense<1.000000e+00> : tensor -// CHECK-NEXT: %[[TMP_3:.*]] = mhlo.multiply %[[TMP_arg0]], %[[TMP_arg0]] : tensor -// CHECK-NEXT: %[[TMP_4:.*]] = mhlo.subtract %[[TMP_2]], %[[TMP_3]] : tensor +// CHECK-NEXT: %[[TMP_2:.*]] = mhlo.subtract %[[TMP_1]], %[[TMP_arg0]] : tensor +// CHECK-NEXT: %[[TMP_3:.*]] = mhlo.add %[[TMP_1]], %[[TMP_arg0]] : tensor +// CHECK-NEXT: %[[TMP_4:.*]] = mhlo.multiply %[[TMP_2]], %[[TMP_3]] : tensor // CHECK-NEXT: %[[TMP_5:.*]] = mhlo.sqrt %[[TMP_4]] : tensor // CHECK-NEXT: %[[TMP_6:.*]] = mhlo.add %[[TMP_1]], %[[TMP_5]] : tensor // CHECK-NEXT: 
%[[TMP_7:.*]] = mhlo.atan2 %[[TMP_arg0]], %[[TMP_6]] : tensor @@ -43,9 +43,9 @@ func.func @asin_f16(%arg : tensor) -> tensor { // CHECK-SAME: %[[TMP_arg0:.*]]: tensor) -> tensor // CHECK-NEXT: %[[TMP_0:.*]] = mhlo.constant dense<2.000000e+00> : tensor // CHECK-NEXT: %[[TMP_1:.*]] = mhlo.constant dense<1.000000e+00> : tensor -// CHECK-NEXT: %[[TMP_2:.*]] = mhlo.constant dense<1.000000e+00> : tensor -// CHECK-NEXT: %[[TMP_3:.*]] = mhlo.multiply %[[TMP_arg0]], %[[TMP_arg0]] : tensor -// CHECK-NEXT: %[[TMP_4:.*]] = mhlo.subtract %[[TMP_2]], %[[TMP_3]] : tensor +// CHECK-NEXT: %[[TMP_2:.*]] = mhlo.subtract %[[TMP_1]], %[[TMP_arg0]] : tensor +// CHECK-NEXT: %[[TMP_3:.*]] = mhlo.add %[[TMP_1]], %[[TMP_arg0]] : tensor +// CHECK-NEXT: %[[TMP_4:.*]] = mhlo.multiply %[[TMP_2]], %[[TMP_3]] : tensor // CHECK-NEXT: %[[TMP_5:.*]] = mhlo.sqrt %[[TMP_4]] : tensor // CHECK-NEXT: %[[TMP_6:.*]] = mhlo.add %[[TMP_1]], %[[TMP_5]] : tensor // CHECK-NEXT: %[[TMP_7:.*]] = mhlo.atan2 %[[TMP_arg0]], %[[TMP_6]] : tensor @@ -62,9 +62,9 @@ func.func @asin_f32(%arg : tensor) -> tensor { // CHECK-SAME: %[[TMP_arg0:.*]]: tensor) -> tensor // CHECK-NEXT: %[[TMP_0:.*]] = mhlo.constant dense<2.000000e+00> : tensor // CHECK-NEXT: %[[TMP_1:.*]] = mhlo.constant dense<1.000000e+00> : tensor -// CHECK-NEXT: %[[TMP_2:.*]] = mhlo.constant dense<1.000000e+00> : tensor -// CHECK-NEXT: %[[TMP_3:.*]] = mhlo.multiply %[[TMP_arg0]], %[[TMP_arg0]] : tensor -// CHECK-NEXT: %[[TMP_4:.*]] = mhlo.subtract %[[TMP_2]], %[[TMP_3]] : tensor +// CHECK-NEXT: %[[TMP_2:.*]] = mhlo.subtract %[[TMP_1]], %[[TMP_arg0]] : tensor +// CHECK-NEXT: %[[TMP_3:.*]] = mhlo.add %[[TMP_1]], %[[TMP_arg0]] : tensor +// CHECK-NEXT: %[[TMP_4:.*]] = mhlo.multiply %[[TMP_2]], %[[TMP_3]] : tensor // CHECK-NEXT: %[[TMP_5:.*]] = mhlo.sqrt %[[TMP_4]] : tensor // CHECK-NEXT: %[[TMP_6:.*]] = mhlo.add %[[TMP_1]], %[[TMP_5]] : tensor // CHECK-NEXT: %[[TMP_7:.*]] = mhlo.atan2 %[[TMP_arg0]], %[[TMP_6]] : tensor @@ -79,16 +79,141 @@ func.func @asin_f64(%arg 
: tensor) -> tensor { // CHECK-LABEL: func.func @asin_complex_f32( // CHECK-SAME: %[[TMP_arg0:.*]]: tensor>) -> tensor> -// CHECK-NEXT: %[[TMP_0:.*]] = mhlo.constant dense<(2.000000e+00,0.000000e+00)> : tensor> -// CHECK-NEXT: %[[TMP_1:.*]] = mhlo.constant dense<(1.000000e+00,0.000000e+00)> : tensor> -// CHECK-NEXT: %[[TMP_2:.*]] = mhlo.constant dense<(1.000000e+00,0.000000e+00)> : tensor> -// CHECK-NEXT: %[[TMP_3:.*]] = mhlo.multiply %[[TMP_arg0]], %[[TMP_arg0]] : tensor> -// CHECK-NEXT: %[[TMP_4:.*]] = mhlo.subtract %[[TMP_2]], %[[TMP_3]] : tensor> -// CHECK-NEXT: %[[TMP_5:.*]] = mhlo.sqrt %[[TMP_4]] : tensor> -// CHECK-NEXT: %[[TMP_6:.*]] = mhlo.add %[[TMP_1]], %[[TMP_5]] : tensor> -// CHECK-NEXT: %[[TMP_7:.*]] = mhlo.atan2 %[[TMP_arg0]], %[[TMP_6]] : tensor> -// CHECK-NEXT: %[[TMP_8:.*]] = mhlo.multiply %[[TMP_0]], %[[TMP_7]] : tensor> -// CHECK-NEXT: return %[[TMP_8]] : tensor> +// CHECK-NEXT: %[[TMP_0:.*]] = mhlo.real %[[TMP_arg0]] : (tensor>) -> tensor +// CHECK-NEXT: %[[TMP_1:.*]] = mhlo.abs %[[TMP_0]] : tensor +// CHECK-NEXT: %[[TMP_2:.*]] = mhlo.imag %[[TMP_arg0]] : (tensor>) -> tensor +// CHECK-NEXT: %[[TMP_3:.*]] = mhlo.abs %[[TMP_2]] : tensor +// CHECK-NEXT: %[[TMP_4:.*]] = mhlo.maximum %[[TMP_1]], %[[TMP_3]] : tensor +// CHECK-NEXT: %[[TMP_5:.*]] = mhlo.constant dense<3.40282347E+38> : tensor +// CHECK-NEXT: %[[TMP_6:.*]] = mhlo.sqrt %[[TMP_5]] : tensor +// CHECK-NEXT: %[[TMP_7:.*]] = mhlo.constant dense<8.000000e+00> : tensor +// CHECK-NEXT: %[[TMP_8:.*]] = mhlo.divide %[[TMP_6]], %[[TMP_7]] : tensor +// CHECK-NEXT: %[[TMP_9:.*]] = mhlo.compare GE, %[[TMP_4]], %[[TMP_8]] : (tensor, tensor) -> tensor +// CHECK-NEXT: %[[TMP_10:.*]] = mhlo.constant dense<1.000000e+00> : tensor +// CHECK-NEXT: %[[TMP_11:.*]] = mhlo.compare LE, %[[TMP_1]], %[[TMP_10]] : (tensor, tensor) -> tensor +// CHECK-NEXT: %[[TMP_12:.*]] = mhlo.constant dense<5.000000e-01> : tensor +// CHECK-NEXT: %[[TMP_13:.*]] = mhlo.add %[[TMP_1]], %[[TMP_10]] : tensor +// CHECK-NEXT: 
%[[TMP_14:.*]] = mhlo.abs %[[TMP_13]] : tensor +// CHECK-NEXT: %[[TMP_15:.*]] = mhlo.maximum %[[TMP_14]], %[[TMP_3]] : tensor +// CHECK-NEXT: %[[TMP_16:.*]] = mhlo.minimum %[[TMP_14]], %[[TMP_3]] : tensor +// CHECK-NEXT: %[[TMP_17:.*]] = mhlo.compare EQ, %[[TMP_15]], %[[TMP_16]] : (tensor, tensor) -> tensor +// CHECK-NEXT: %[[TMP_18:.*]] = mhlo.constant dense<2.000000e+00> : tensor +// CHECK-NEXT: %[[TMP_19:.*]] = mhlo.sqrt %[[TMP_18]] : tensor +// CHECK-NEXT: %[[TMP_20:.*]] = mhlo.multiply %[[TMP_19]], %[[TMP_15]] : tensor +// CHECK-NEXT: %[[TMP_21:.*]] = mhlo.divide %[[TMP_16]], %[[TMP_15]] : tensor +// CHECK-NEXT: %[[TMP_22:.*]] = mhlo.multiply %[[TMP_21]], %[[TMP_21]] : tensor +// CHECK-NEXT: %[[TMP_23:.*]] = mhlo.add %[[TMP_10]], %[[TMP_22]] : tensor +// CHECK-NEXT: %[[TMP_24:.*]] = mhlo.sqrt %[[TMP_23]] : tensor +// CHECK-NEXT: %[[TMP_25:.*]] = mhlo.compare EQ, %[[TMP_24]], %[[TMP_10]] : (tensor, tensor) -> tensor +// CHECK-NEXT: %[[TMP_26:.*]] = mhlo.constant dense<0.000000e+00> : tensor +// CHECK-NEXT: %[[TMP_27:.*]] = mhlo.compare GT, %[[TMP_22]], %[[TMP_26]] : (tensor, tensor) -> tensor +// CHECK-NEXT: %[[TMP_28:.*]] = mhlo.and %[[TMP_25]], %[[TMP_27]] : tensor +// CHECK-NEXT: %[[TMP_29:.*]] = mhlo.multiply %[[TMP_15]], %[[TMP_22]] : tensor +// CHECK-NEXT: %[[TMP_30:.*]] = mhlo.divide %[[TMP_29]], %[[TMP_18]] : tensor +// CHECK-NEXT: %[[TMP_31:.*]] = mhlo.add %[[TMP_15]], %[[TMP_30]] : tensor +// CHECK-NEXT: %[[TMP_32:.*]] = mhlo.multiply %[[TMP_15]], %[[TMP_24]] : tensor +// CHECK-NEXT: %[[TMP_33:.*]] = mhlo.select %[[TMP_28]], %[[TMP_31]], %[[TMP_32]] : tensor, tensor +// CHECK-NEXT: %[[TMP_34:.*]] = mhlo.select %[[TMP_17]], %[[TMP_20]], %[[TMP_33]] : tensor, tensor +// CHECK-NEXT: %[[TMP_35:.*]] = mhlo.subtract %[[TMP_1]], %[[TMP_10]] : tensor +// CHECK-NEXT: %[[TMP_36:.*]] = mhlo.abs %[[TMP_35]] : tensor +// CHECK-NEXT: %[[TMP_37:.*]] = mhlo.maximum %[[TMP_36]], %[[TMP_3]] : tensor +// CHECK-NEXT: %[[TMP_38:.*]] = mhlo.minimum %[[TMP_36]], %[[TMP_3]] 
: tensor +// CHECK-NEXT: %[[TMP_39:.*]] = mhlo.compare EQ, %[[TMP_37]], %[[TMP_38]] : (tensor, tensor) -> tensor +// CHECK-NEXT: %[[TMP_40:.*]] = mhlo.multiply %[[TMP_19]], %[[TMP_37]] : tensor +// CHECK-NEXT: %[[TMP_41:.*]] = mhlo.divide %[[TMP_38]], %[[TMP_37]] : tensor +// CHECK-NEXT: %[[TMP_42:.*]] = mhlo.multiply %[[TMP_41]], %[[TMP_41]] : tensor +// CHECK-NEXT: %[[TMP_43:.*]] = mhlo.add %[[TMP_10]], %[[TMP_42]] : tensor +// CHECK-NEXT: %[[TMP_44:.*]] = mhlo.sqrt %[[TMP_43]] : tensor +// CHECK-NEXT: %[[TMP_45:.*]] = mhlo.compare EQ, %[[TMP_44]], %[[TMP_10]] : (tensor, tensor) -> tensor +// CHECK-NEXT: %[[TMP_46:.*]] = mhlo.compare GT, %[[TMP_42]], %[[TMP_26]] : (tensor, tensor) -> tensor +// CHECK-NEXT: %[[TMP_47:.*]] = mhlo.and %[[TMP_45]], %[[TMP_46]] : tensor +// CHECK-NEXT: %[[TMP_48:.*]] = mhlo.multiply %[[TMP_37]], %[[TMP_42]] : tensor +// CHECK-NEXT: %[[TMP_49:.*]] = mhlo.divide %[[TMP_48]], %[[TMP_18]] : tensor +// CHECK-NEXT: %[[TMP_50:.*]] = mhlo.add %[[TMP_37]], %[[TMP_49]] : tensor +// CHECK-NEXT: %[[TMP_51:.*]] = mhlo.multiply %[[TMP_37]], %[[TMP_44]] : tensor +// CHECK-NEXT: %[[TMP_52:.*]] = mhlo.select %[[TMP_47]], %[[TMP_50]], %[[TMP_51]] : tensor, tensor +// CHECK-NEXT: %[[TMP_53:.*]] = mhlo.select %[[TMP_39]], %[[TMP_40]], %[[TMP_52]] : tensor, tensor +// CHECK-NEXT: %[[TMP_54:.*]] = mhlo.add %[[TMP_34]], %[[TMP_53]] : tensor +// CHECK-NEXT: %[[TMP_55:.*]] = mhlo.multiply %[[TMP_12]], %[[TMP_54]] : tensor +// CHECK-NEXT: %[[TMP_56:.*]] = mhlo.add %[[TMP_55]], %[[TMP_1]] : tensor +// CHECK-NEXT: %[[TMP_57:.*]] = mhlo.multiply %[[TMP_12]], %[[TMP_56]] : tensor +// CHECK-NEXT: %[[TMP_58:.*]] = mhlo.multiply %[[TMP_3]], %[[TMP_3]] : tensor +// CHECK-NEXT: %[[TMP_59:.*]] = mhlo.add %[[TMP_34]], %[[TMP_13]] : tensor +// CHECK-NEXT: %[[TMP_60:.*]] = mhlo.divide %[[TMP_58]], %[[TMP_59]] : tensor +// CHECK-NEXT: %[[TMP_61:.*]] = mhlo.subtract %[[TMP_53]], %[[TMP_35]] : tensor +// CHECK-NEXT: %[[TMP_62:.*]] = mhlo.add %[[TMP_60]], %[[TMP_61]] : tensor 
+// CHECK-NEXT: %[[TMP_63:.*]] = mhlo.multiply %[[TMP_57]], %[[TMP_62]] : tensor +// CHECK-NEXT: %[[TMP_64:.*]] = mhlo.sqrt %[[TMP_63]] : tensor +// CHECK-NEXT: %[[TMP_65:.*]] = mhlo.divide %[[TMP_57]], %[[TMP_59]] : tensor +// CHECK-NEXT: %[[TMP_66:.*]] = mhlo.add %[[TMP_53]], %[[TMP_35]] : tensor +// CHECK-NEXT: %[[TMP_67:.*]] = mhlo.divide %[[TMP_57]], %[[TMP_66]] : tensor +// CHECK-NEXT: %[[TMP_68:.*]] = mhlo.add %[[TMP_65]], %[[TMP_67]] : tensor +// CHECK-NEXT: %[[TMP_69:.*]] = mhlo.sqrt %[[TMP_68]] : tensor +// CHECK-NEXT: %[[TMP_70:.*]] = mhlo.multiply %[[TMP_3]], %[[TMP_69]] : tensor +// CHECK-NEXT: %[[TMP_71:.*]] = mhlo.select %[[TMP_11]], %[[TMP_64]], %[[TMP_70]] : tensor, tensor +// CHECK-NEXT: %[[TMP_72:.*]] = mhlo.select %[[TMP_9]], %[[TMP_3]], %[[TMP_71]] : tensor, tensor +// CHECK-NEXT: %[[TMP_73:.*]] = mhlo.atan2 %[[TMP_0]], %[[TMP_72]] : tensor +// CHECK-NEXT: %[[TMP_74:.*]] = mhlo.compare LT, %[[TMP_2]], %[[TMP_26]] : (tensor, tensor) -> tensor +// CHECK-NEXT: %[[TMP_75:.*]] = mhlo.constant dense<9.99999995E+11> : tensor +// CHECK-NEXT: %[[TMP_76:.*]] = mhlo.multiply %[[TMP_8]], %[[TMP_75]] : tensor +// CHECK-NEXT: %[[TMP_77:.*]] = mhlo.compare LT, %[[TMP_1]], %[[TMP_76]] : (tensor, tensor) -> tensor +// CHECK-NEXT: %[[TMP_78:.*]] = mhlo.constant dense<9.99999997E-7> : tensor +// CHECK-NEXT: %[[TMP_79:.*]] = mhlo.multiply %[[TMP_8]], %[[TMP_78]] : tensor +// CHECK-NEXT: %[[TMP_80:.*]] = mhlo.constant dense<1.000000e+02> : tensor +// CHECK-NEXT: %[[TMP_81:.*]] = mhlo.multiply %[[TMP_8]], %[[TMP_80]] : tensor +// CHECK-NEXT: %[[TMP_82:.*]] = mhlo.select %[[TMP_77]], %[[TMP_79]], %[[TMP_81]] : tensor, tensor +// CHECK-NEXT: %[[TMP_83:.*]] = mhlo.compare GE, %[[TMP_3]], %[[TMP_82]] : (tensor, tensor) -> tensor +// CHECK-NEXT: %[[TMP_84:.*]] = mhlo.select %[[TMP_83]], %[[TMP_3]], %[[TMP_1]] : tensor, tensor +// CHECK-NEXT: %[[TMP_85:.*]] = mhlo.select %[[TMP_83]], %[[TMP_82]], %[[TMP_8]] : tensor, tensor +// CHECK-NEXT: %[[TMP_86:.*]] = mhlo.compare 
GE, %[[TMP_84]], %[[TMP_85]] : (tensor, tensor) -> tensor +// CHECK-NEXT: %[[TMP_87:.*]] = mhlo.log %[[TMP_18]] : tensor +// CHECK-NEXT: %[[TMP_88:.*]] = mhlo.log %[[TMP_84]] : tensor +// CHECK-NEXT: %[[TMP_89:.*]] = mhlo.add %[[TMP_87]], %[[TMP_88]] : tensor +// CHECK-NEXT: %[[TMP_90:.*]] = mhlo.constant dense<0x7F800000> : tensor +// CHECK-NEXT: %[[TMP_91:.*]] = mhlo.compare EQ, %[[TMP_3]], %[[TMP_90]] : (tensor, tensor) -> tensor +// CHECK-NEXT: %[[TMP_92:.*]] = mhlo.not %[[TMP_91]] : tensor +// CHECK-NEXT: %[[TMP_93:.*]] = mhlo.and %[[TMP_83]], %[[TMP_92]] : tensor +// CHECK-NEXT: %[[TMP_94:.*]] = mhlo.divide %[[TMP_1]], %[[TMP_3]] : tensor +// CHECK-NEXT: %[[TMP_95:.*]] = mhlo.select %[[TMP_93]], %[[TMP_94]], %[[TMP_26]] : tensor, tensor +// CHECK-NEXT: %[[TMP_96:.*]] = mhlo.multiply %[[TMP_95]], %[[TMP_95]] : tensor +// CHECK-NEXT: %[[TMP_97:.*]] = mhlo.log_plus_one %[[TMP_96]] : tensor +// CHECK-NEXT: %[[TMP_98:.*]] = mhlo.multiply %[[TMP_12]], %[[TMP_97]] : tensor +// CHECK-NEXT: %[[TMP_99:.*]] = mhlo.add %[[TMP_89]], %[[TMP_98]] : tensor +// CHECK-NEXT: %[[TMP_100:.*]] = mhlo.constant dense<1.17549435E-38> : tensor +// CHECK-NEXT: %[[TMP_101:.*]] = mhlo.sqrt %[[TMP_100]] : tensor +// CHECK-NEXT: %[[TMP_102:.*]] = mhlo.constant dense<4.000000e+00> : tensor +// CHECK-NEXT: %[[TMP_103:.*]] = mhlo.multiply %[[TMP_101]], %[[TMP_102]] : tensor +// CHECK-NEXT: %[[TMP_104:.*]] = mhlo.compare LT, %[[TMP_3]], %[[TMP_103]] : (tensor, tensor) -> tensor +// CHECK-NEXT: %[[TMP_105:.*]] = mhlo.compare LT, %[[TMP_1]], %[[TMP_10]] : (tensor, tensor) -> tensor +// CHECK-NEXT: %[[TMP_106:.*]] = mhlo.and %[[TMP_104]], %[[TMP_105]] : tensor +// CHECK-NEXT: %[[TMP_107:.*]] = mhlo.multiply %[[TMP_13]], %[[TMP_35]] : tensor +// CHECK-NEXT: %[[TMP_108:.*]] = mhlo.add %[[TMP_55]], %[[TMP_10]] : tensor +// CHECK-NEXT: %[[TMP_109:.*]] = mhlo.divide %[[TMP_107]], %[[TMP_108]] : tensor +// CHECK-NEXT: %[[TMP_110:.*]] = mhlo.negate %[[TMP_109]] : tensor +// CHECK-NEXT: %[[TMP_111:.*]] = 
mhlo.compare GE, %[[TMP_1]], %[[TMP_10]] : (tensor, tensor) -> tensor +// CHECK-NEXT: %[[TMP_112:.*]] = mhlo.multiply %[[TMP_12]], %[[TMP_58]] : tensor +// CHECK-NEXT: %[[TMP_113:.*]] = mhlo.divide %[[TMP_112]], %[[TMP_59]] : tensor +// CHECK-NEXT: %[[TMP_114:.*]] = mhlo.multiply %[[TMP_12]], %[[TMP_66]] : tensor +// CHECK-NEXT: %[[TMP_115:.*]] = mhlo.add %[[TMP_113]], %[[TMP_114]] : tensor +// CHECK-NEXT: %[[TMP_116:.*]] = mhlo.constant dense<1.500000e+00> : tensor +// CHECK-NEXT: %[[TMP_117:.*]] = mhlo.compare LE, %[[TMP_55]], %[[TMP_116]] : (tensor, tensor) -> tensor +// CHECK-NEXT: %[[TMP_118:.*]] = mhlo.divide %[[TMP_112]], %[[TMP_61]] : tensor +// CHECK-NEXT: %[[TMP_119:.*]] = mhlo.add %[[TMP_113]], %[[TMP_118]] : tensor +// CHECK-NEXT: %[[TMP_120:.*]] = mhlo.subtract %[[TMP_55]], %[[TMP_10]] : tensor +// CHECK-NEXT: %[[TMP_121:.*]] = mhlo.select %[[TMP_117]], %[[TMP_119]], %[[TMP_120]] : tensor, tensor +// CHECK-NEXT: %[[TMP_122:.*]] = mhlo.select %[[TMP_111]], %[[TMP_115]], %[[TMP_121]] : tensor, tensor +// CHECK-NEXT: %[[TMP_123:.*]] = mhlo.select %[[TMP_106]], %[[TMP_110]], %[[TMP_122]] : tensor, tensor +// CHECK-NEXT: %[[TMP_124:.*]] = mhlo.multiply %[[TMP_123]], %[[TMP_108]] : tensor +// CHECK-NEXT: %[[TMP_125:.*]] = mhlo.sqrt %[[TMP_124]] : tensor +// CHECK-NEXT: %[[TMP_126:.*]] = mhlo.divide %[[TMP_3]], %[[TMP_125]] : tensor +// CHECK-NEXT: %[[TMP_127:.*]] = mhlo.add %[[TMP_123]], %[[TMP_125]] : tensor +// CHECK-NEXT: %[[TMP_128:.*]] = mhlo.log_plus_one %[[TMP_127]] : tensor +// CHECK-NEXT: %[[TMP_129:.*]] = mhlo.select %[[TMP_106]], %[[TMP_126]], %[[TMP_128]] : tensor, tensor +// CHECK-NEXT: %[[TMP_130:.*]] = mhlo.select %[[TMP_86]], %[[TMP_99]], %[[TMP_129]] : tensor, tensor +// CHECK-NEXT: %[[TMP_131:.*]] = mhlo.negate %[[TMP_130]] : tensor +// CHECK-NEXT: %[[TMP_132:.*]] = mhlo.select %[[TMP_74]], %[[TMP_131]], %[[TMP_130]] : tensor, tensor +// CHECK-NEXT: %[[TMP_133:.*]] = mhlo.complex %[[TMP_73]], %[[TMP_132]] : tensor> +// CHECK-NEXT: return 
%[[TMP_133]] : tensor> func.func @asin_complex_f32(%arg : tensor>) -> tensor> { %result = "chlo.asin"(%arg) : (tensor>) -> tensor> func.return %result : tensor> @@ -98,22 +223,167 @@ func.func @asin_complex_f32(%arg : tensor>) -> tensor> // CHECK-LABEL: func.func @asin_complex_f64_dynamic( // CHECK-SAME: %[[ARG0:.*]]: tensor>) -> tensor> -// CHECK: %[[TWO:.*]] = mhlo.constant dense<(2.000000e+00,0.000000e+00)> -// CHECK: %[[SHAPE:.*]] = shape.shape_of %[[ARG0]] -// CHECK: %[[TWO_BROADCASTED:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[TWO]], %[[SHAPE]]) -// CHECK: %[[ONE:.*]] = mhlo.constant dense<(1.000000e+00,0.000000e+00)> -// CHECK: %[[SHAPE2:.*]] = shape.shape_of %[[ARG0]] -// CHECK: %[[ONE_BROADCASTED:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ONE]], %[[SHAPE2]]) -// CHECK: %[[ONE2:.*]] = mhlo.constant dense<(1.000000e+00,0.000000e+00)> -// CHECK: %[[SHAPE3:.*]] = shape.shape_of %[[ARG0]] -// CHECK: %[[ONE_BROADCASTED2:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ONE2]], %[[SHAPE3]]) -// CHECK: %[[SQUARE:.*]] = mhlo.multiply %[[ARG0]], %[[ARG0]] -// CHECK: %[[SUB:.*]] = mhlo.subtract %[[ONE_BROADCASTED2]], %[[SQUARE]] -// CHECK: %[[SQRT:.*]] = mhlo.sqrt %[[SUB]] -// CHECK: %[[ADD:.*]] = mhlo.add %[[ONE_BROADCASTED]], %[[SQRT]] -// CHECK: %[[ATAN2:.*]] = mhlo.atan2 %[[ARG0]], %[[ADD]] -// CHECK: %[[MUL:.*]] = mhlo.multiply %[[TWO_BROADCASTED]], %[[ATAN2]] -// CHECK: return %[[MUL]] +// CHECK-NEXT: %[[TMP_0:.*]] = mhlo.real %[[ARG0]] : (tensor>) -> tensor +// CHECK-NEXT: %[[TMP_1:.*]] = mhlo.abs %[[TMP_0]] : tensor +// CHECK-NEXT: %[[TMP_2:.*]] = mhlo.imag %[[ARG0]] : (tensor>) -> tensor +// CHECK-NEXT: %[[TMP_3:.*]] = mhlo.abs %[[TMP_2]] : tensor +// CHECK-NEXT: %[[TMP_4:.*]] = mhlo.maximum %[[TMP_1]], %[[TMP_3]] : tensor +// CHECK-NEXT: %[[TMP_5:.*]] = mhlo.constant dense<1.7976931348623157E+308> : tensor +// CHECK-NEXT: %[[TMP_6:.*]] = shape.shape_of %[[TMP_0]] : tensor -> tensor<1xindex> +// CHECK-NEXT: %[[TMP_7:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[TMP_5]], 
%[[TMP_6]]) <{broadcast_dimensions = dense<> : tensor<0xi64>}> : (tensor, tensor<1xindex>) -> tensor +// CHECK-NEXT: %[[TMP_8:.*]] = mhlo.sqrt %[[TMP_7]] : tensor +// CHECK-NEXT: %[[TMP_9:.*]] = mhlo.constant dense<8.000000e+00> : tensor +// CHECK-NEXT: %[[TMP_10:.*]] = shape.shape_of %[[TMP_0]] : tensor -> tensor<1xindex> +// CHECK-NEXT: %[[TMP_11:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[TMP_9]], %[[TMP_10]]) <{broadcast_dimensions = dense<> : tensor<0xi64>}> : (tensor, tensor<1xindex>) -> tensor +// CHECK-NEXT: %[[TMP_12:.*]] = mhlo.divide %[[TMP_8]], %[[TMP_11]] : tensor +// CHECK-NEXT: %[[TMP_13:.*]] = mhlo.compare GE, %[[TMP_4]], %[[TMP_12]] : (tensor, tensor) -> tensor +// CHECK-NEXT: %[[TMP_14:.*]] = mhlo.constant dense<1.000000e+00> : tensor +// CHECK-NEXT: %[[TMP_15:.*]] = shape.shape_of %[[TMP_0]] : tensor -> tensor<1xindex> +// CHECK-NEXT: %[[TMP_16:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[TMP_14]], %[[TMP_15]]) <{broadcast_dimensions = dense<> : tensor<0xi64>}> : (tensor, tensor<1xindex>) -> tensor +// CHECK-NEXT: %[[TMP_17:.*]] = mhlo.compare LE, %[[TMP_1]], %[[TMP_16]] : (tensor, tensor) -> tensor +// CHECK-NEXT: %[[TMP_18:.*]] = mhlo.constant dense<5.000000e-01> : tensor +// CHECK-NEXT: %[[TMP_19:.*]] = shape.shape_of %[[TMP_0]] : tensor -> tensor<1xindex> +// CHECK-NEXT: %[[TMP_20:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[TMP_18]], %[[TMP_19]]) <{broadcast_dimensions = dense<> : tensor<0xi64>}> : (tensor, tensor<1xindex>) -> tensor +// CHECK-NEXT: %[[TMP_21:.*]] = mhlo.add %[[TMP_1]], %[[TMP_16]] : tensor +// CHECK-NEXT: %[[TMP_22:.*]] = mhlo.abs %[[TMP_21]] : tensor +// CHECK-NEXT: %[[TMP_23:.*]] = mhlo.maximum %[[TMP_22]], %[[TMP_3]] : tensor +// CHECK-NEXT: %[[TMP_24:.*]] = mhlo.minimum %[[TMP_22]], %[[TMP_3]] : tensor +// CHECK-NEXT: %[[TMP_25:.*]] = mhlo.compare EQ, %[[TMP_23]], %[[TMP_24]] : (tensor, tensor) -> tensor +// CHECK-NEXT: %[[TMP_26:.*]] = mhlo.constant dense<2.000000e+00> : tensor +// CHECK-NEXT: %[[TMP_27:.*]] = shape.shape_of 
%[[TMP_0]] : tensor -> tensor<1xindex> +// CHECK-NEXT: %[[TMP_28:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[TMP_26]], %[[TMP_27]]) <{broadcast_dimensions = dense<> : tensor<0xi64>}> : (tensor, tensor<1xindex>) -> tensor +// CHECK-NEXT: %[[TMP_29:.*]] = mhlo.sqrt %[[TMP_28]] : tensor +// CHECK-NEXT: %[[TMP_30:.*]] = mhlo.multiply %[[TMP_29]], %[[TMP_23]] : tensor +// CHECK-NEXT: %[[TMP_31:.*]] = mhlo.divide %[[TMP_24]], %[[TMP_23]] : tensor +// CHECK-NEXT: %[[TMP_32:.*]] = mhlo.multiply %[[TMP_31]], %[[TMP_31]] : tensor +// CHECK-NEXT: %[[TMP_33:.*]] = mhlo.add %[[TMP_16]], %[[TMP_32]] : tensor +// CHECK-NEXT: %[[TMP_34:.*]] = mhlo.sqrt %[[TMP_33]] : tensor +// CHECK-NEXT: %[[TMP_35:.*]] = mhlo.compare EQ, %[[TMP_34]], %[[TMP_16]] : (tensor, tensor) -> tensor +// CHECK-NEXT: %[[TMP_36:.*]] = mhlo.constant dense<0.000000e+00> : tensor +// CHECK-NEXT: %[[TMP_37:.*]] = shape.shape_of %[[TMP_0]] : tensor -> tensor<1xindex> +// CHECK-NEXT: %[[TMP_38:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[TMP_36]], %[[TMP_37]]) <{broadcast_dimensions = dense<> : tensor<0xi64>}> : (tensor, tensor<1xindex>) -> tensor +// CHECK-NEXT: %[[TMP_39:.*]] = mhlo.compare GT, %[[TMP_32]], %[[TMP_38]] : (tensor, tensor) -> tensor +// CHECK-NEXT: %[[TMP_40:.*]] = mhlo.and %[[TMP_35]], %[[TMP_39]] : tensor +// CHECK-NEXT: %[[TMP_41:.*]] = mhlo.multiply %[[TMP_23]], %[[TMP_32]] : tensor +// CHECK-NEXT: %[[TMP_42:.*]] = mhlo.divide %[[TMP_41]], %[[TMP_28]] : tensor +// CHECK-NEXT: %[[TMP_43:.*]] = mhlo.add %[[TMP_23]], %[[TMP_42]] : tensor +// CHECK-NEXT: %[[TMP_44:.*]] = mhlo.multiply %[[TMP_23]], %[[TMP_34]] : tensor +// CHECK-NEXT: %[[TMP_45:.*]] = mhlo.select %[[TMP_40]], %[[TMP_43]], %[[TMP_44]] : tensor, tensor +// CHECK-NEXT: %[[TMP_46:.*]] = mhlo.select %[[TMP_25]], %[[TMP_30]], %[[TMP_45]] : tensor, tensor +// CHECK-NEXT: %[[TMP_47:.*]] = mhlo.subtract %[[TMP_1]], %[[TMP_16]] : tensor +// CHECK-NEXT: %[[TMP_48:.*]] = mhlo.abs %[[TMP_47]] : tensor +// CHECK-NEXT: %[[TMP_49:.*]] = mhlo.maximum 
%[[TMP_48]], %[[TMP_3]] : tensor +// CHECK-NEXT: %[[TMP_50:.*]] = mhlo.minimum %[[TMP_48]], %[[TMP_3]] : tensor +// CHECK-NEXT: %[[TMP_51:.*]] = mhlo.compare EQ, %[[TMP_49]], %[[TMP_50]] : (tensor, tensor) -> tensor +// CHECK-NEXT: %[[TMP_52:.*]] = mhlo.multiply %[[TMP_29]], %[[TMP_49]] : tensor +// CHECK-NEXT: %[[TMP_53:.*]] = mhlo.divide %[[TMP_50]], %[[TMP_49]] : tensor +// CHECK-NEXT: %[[TMP_54:.*]] = mhlo.multiply %[[TMP_53]], %[[TMP_53]] : tensor +// CHECK-NEXT: %[[TMP_55:.*]] = mhlo.add %[[TMP_16]], %[[TMP_54]] : tensor +// CHECK-NEXT: %[[TMP_56:.*]] = mhlo.sqrt %[[TMP_55]] : tensor +// CHECK-NEXT: %[[TMP_57:.*]] = mhlo.compare EQ, %[[TMP_56]], %[[TMP_16]] : (tensor, tensor) -> tensor +// CHECK-NEXT: %[[TMP_58:.*]] = mhlo.compare GT, %[[TMP_54]], %[[TMP_38]] : (tensor, tensor) -> tensor +// CHECK-NEXT: %[[TMP_59:.*]] = mhlo.and %[[TMP_57]], %[[TMP_58]] : tensor +// CHECK-NEXT: %[[TMP_60:.*]] = mhlo.multiply %[[TMP_49]], %[[TMP_54]] : tensor +// CHECK-NEXT: %[[TMP_61:.*]] = mhlo.divide %[[TMP_60]], %[[TMP_28]] : tensor +// CHECK-NEXT: %[[TMP_62:.*]] = mhlo.add %[[TMP_49]], %[[TMP_61]] : tensor +// CHECK-NEXT: %[[TMP_63:.*]] = mhlo.multiply %[[TMP_49]], %[[TMP_56]] : tensor +// CHECK-NEXT: %[[TMP_64:.*]] = mhlo.select %[[TMP_59]], %[[TMP_62]], %[[TMP_63]] : tensor, tensor +// CHECK-NEXT: %[[TMP_65:.*]] = mhlo.select %[[TMP_51]], %[[TMP_52]], %[[TMP_64]] : tensor, tensor +// CHECK-NEXT: %[[TMP_66:.*]] = mhlo.add %[[TMP_46]], %[[TMP_65]] : tensor +// CHECK-NEXT: %[[TMP_67:.*]] = mhlo.multiply %[[TMP_20]], %[[TMP_66]] : tensor +// CHECK-NEXT: %[[TMP_68:.*]] = mhlo.add %[[TMP_67]], %[[TMP_1]] : tensor +// CHECK-NEXT: %[[TMP_69:.*]] = mhlo.multiply %[[TMP_20]], %[[TMP_68]] : tensor +// CHECK-NEXT: %[[TMP_70:.*]] = mhlo.multiply %[[TMP_3]], %[[TMP_3]] : tensor +// CHECK-NEXT: %[[TMP_71:.*]] = mhlo.add %[[TMP_46]], %[[TMP_21]] : tensor +// CHECK-NEXT: %[[TMP_72:.*]] = mhlo.divide %[[TMP_70]], %[[TMP_71]] : tensor +// CHECK-NEXT: %[[TMP_73:.*]] = mhlo.subtract 
%[[TMP_65]], %[[TMP_47]] : tensor +// CHECK-NEXT: %[[TMP_74:.*]] = mhlo.add %[[TMP_72]], %[[TMP_73]] : tensor +// CHECK-NEXT: %[[TMP_75:.*]] = mhlo.multiply %[[TMP_69]], %[[TMP_74]] : tensor +// CHECK-NEXT: %[[TMP_76:.*]] = mhlo.sqrt %[[TMP_75]] : tensor +// CHECK-NEXT: %[[TMP_77:.*]] = mhlo.divide %[[TMP_69]], %[[TMP_71]] : tensor +// CHECK-NEXT: %[[TMP_78:.*]] = mhlo.add %[[TMP_65]], %[[TMP_47]] : tensor +// CHECK-NEXT: %[[TMP_79:.*]] = mhlo.divide %[[TMP_69]], %[[TMP_78]] : tensor +// CHECK-NEXT: %[[TMP_80:.*]] = mhlo.add %[[TMP_77]], %[[TMP_79]] : tensor +// CHECK-NEXT: %[[TMP_81:.*]] = mhlo.sqrt %[[TMP_80]] : tensor +// CHECK-NEXT: %[[TMP_82:.*]] = mhlo.multiply %[[TMP_3]], %[[TMP_81]] : tensor +// CHECK-NEXT: %[[TMP_83:.*]] = mhlo.select %[[TMP_17]], %[[TMP_76]], %[[TMP_82]] : tensor, tensor +// CHECK-NEXT: %[[TMP_84:.*]] = mhlo.select %[[TMP_13]], %[[TMP_3]], %[[TMP_83]] : tensor, tensor +// CHECK-NEXT: %[[TMP_85:.*]] = mhlo.atan2 %[[TMP_0]], %[[TMP_84]] : tensor +// CHECK-NEXT: %[[TMP_86:.*]] = mhlo.compare LT, %[[TMP_2]], %[[TMP_38]] : (tensor, tensor) -> tensor +// CHECK-NEXT: %[[TMP_87:.*]] = mhlo.constant dense<1.000000e+12> : tensor +// CHECK-NEXT: %[[TMP_88:.*]] = shape.shape_of %[[TMP_0]] : tensor -> tensor<1xindex> +// CHECK-NEXT: %[[TMP_89:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[TMP_87]], %[[TMP_88]]) <{broadcast_dimensions = dense<> : tensor<0xi64>}> : (tensor, tensor<1xindex>) -> tensor +// CHECK-NEXT: %[[TMP_90:.*]] = mhlo.multiply %[[TMP_12]], %[[TMP_89]] : tensor +// CHECK-NEXT: %[[TMP_91:.*]] = mhlo.compare LT, %[[TMP_1]], %[[TMP_90]] : (tensor, tensor) -> tensor +// CHECK-NEXT: %[[TMP_92:.*]] = mhlo.constant dense<9.9999999999999995E-7> : tensor +// CHECK-NEXT: %[[TMP_93:.*]] = shape.shape_of %[[TMP_0]] : tensor -> tensor<1xindex> +// CHECK-NEXT: %[[TMP_94:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[TMP_92]], %[[TMP_93]]) <{broadcast_dimensions = dense<> : tensor<0xi64>}> : (tensor, tensor<1xindex>) -> tensor +// CHECK-NEXT: %[[TMP_95:.*]] = 
mhlo.multiply %[[TMP_12]], %[[TMP_94]] : tensor +// CHECK-NEXT: %[[TMP_96:.*]] = mhlo.constant dense<1.000000e+02> : tensor +// CHECK-NEXT: %[[TMP_97:.*]] = shape.shape_of %[[TMP_0]] : tensor -> tensor<1xindex> +// CHECK-NEXT: %[[TMP_98:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[TMP_96]], %[[TMP_97]]) <{broadcast_dimensions = dense<> : tensor<0xi64>}> : (tensor, tensor<1xindex>) -> tensor +// CHECK-NEXT: %[[TMP_99:.*]] = mhlo.multiply %[[TMP_12]], %[[TMP_98]] : tensor +// CHECK-NEXT: %[[TMP_100:.*]] = mhlo.select %[[TMP_91]], %[[TMP_95]], %[[TMP_99]] : tensor, tensor +// CHECK-NEXT: %[[TMP_101:.*]] = mhlo.compare GE, %[[TMP_3]], %[[TMP_100]] : (tensor, tensor) -> tensor +// CHECK-NEXT: %[[TMP_102:.*]] = mhlo.select %[[TMP_101]], %[[TMP_3]], %[[TMP_1]] : tensor, tensor +// CHECK-NEXT: %[[TMP_103:.*]] = mhlo.select %[[TMP_101]], %[[TMP_100]], %[[TMP_12]] : tensor, tensor +// CHECK-NEXT: %[[TMP_104:.*]] = mhlo.compare GE, %[[TMP_102]], %[[TMP_103]] : (tensor, tensor) -> tensor +// CHECK-NEXT: %[[TMP_105:.*]] = mhlo.log %[[TMP_28]] : tensor +// CHECK-NEXT: %[[TMP_106:.*]] = mhlo.log %[[TMP_102]] : tensor +// CHECK-NEXT: %[[TMP_107:.*]] = mhlo.add %[[TMP_105]], %[[TMP_106]] : tensor +// CHECK-NEXT: %[[TMP_108:.*]] = mhlo.constant dense<0x7FF0000000000000> : tensor +// CHECK-NEXT: %[[TMP_109:.*]] = shape.shape_of %[[TMP_2]] : tensor -> tensor<1xindex> +// CHECK-NEXT: %[[TMP_110:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[TMP_108]], %[[TMP_109]]) <{broadcast_dimensions = dense<> : tensor<0xi64>}> : (tensor, tensor<1xindex>) -> tensor +// CHECK-NEXT: %[[TMP_111:.*]] = mhlo.compare EQ, %[[TMP_3]], %[[TMP_110]] : (tensor, tensor) -> tensor +// CHECK-NEXT: %[[TMP_112:.*]] = mhlo.not %[[TMP_111]] : tensor +// CHECK-NEXT: %[[TMP_113:.*]] = mhlo.and %[[TMP_101]], %[[TMP_112]] : tensor +// CHECK-NEXT: %[[TMP_114:.*]] = mhlo.divide %[[TMP_1]], %[[TMP_3]] : tensor +// CHECK-NEXT: %[[TMP_115:.*]] = mhlo.select %[[TMP_113]], %[[TMP_114]], %[[TMP_38]] : tensor, tensor +// CHECK-NEXT: 
%[[TMP_116:.*]] = mhlo.multiply %[[TMP_115]], %[[TMP_115]] : tensor +// CHECK-NEXT: %[[TMP_117:.*]] = mhlo.log_plus_one %[[TMP_116]] : tensor +// CHECK-NEXT: %[[TMP_118:.*]] = mhlo.multiply %[[TMP_20]], %[[TMP_117]] : tensor +// CHECK-NEXT: %[[TMP_119:.*]] = mhlo.add %[[TMP_107]], %[[TMP_118]] : tensor +// CHECK-NEXT: %[[TMP_120:.*]] = mhlo.constant dense<2.2250738585072014E-308> : tensor +// CHECK-NEXT: %[[TMP_121:.*]] = shape.shape_of %[[TMP_0]] : tensor -> tensor<1xindex> +// CHECK-NEXT: %[[TMP_122:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[TMP_120]], %[[TMP_121]]) <{broadcast_dimensions = dense<> : tensor<0xi64>}> : (tensor, tensor<1xindex>) -> tensor +// CHECK-NEXT: %[[TMP_123:.*]] = mhlo.sqrt %[[TMP_122]] : tensor +// CHECK-NEXT: %[[TMP_124:.*]] = mhlo.constant dense<4.000000e+00> : tensor +// CHECK-NEXT: %[[TMP_125:.*]] = shape.shape_of %[[TMP_0]] : tensor -> tensor<1xindex> +// CHECK-NEXT: %[[TMP_126:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[TMP_124]], %[[TMP_125]]) <{broadcast_dimensions = dense<> : tensor<0xi64>}> : (tensor, tensor<1xindex>) -> tensor +// CHECK-NEXT: %[[TMP_127:.*]] = mhlo.multiply %[[TMP_123]], %[[TMP_126]] : tensor +// CHECK-NEXT: %[[TMP_128:.*]] = mhlo.compare LT, %[[TMP_3]], %[[TMP_127]] : (tensor, tensor) -> tensor +// CHECK-NEXT: %[[TMP_129:.*]] = mhlo.compare LT, %[[TMP_1]], %[[TMP_16]] : (tensor, tensor) -> tensor +// CHECK-NEXT: %[[TMP_130:.*]] = mhlo.and %[[TMP_128]], %[[TMP_129]] : tensor +// CHECK-NEXT: %[[TMP_131:.*]] = mhlo.multiply %[[TMP_21]], %[[TMP_47]] : tensor +// CHECK-NEXT: %[[TMP_132:.*]] = mhlo.add %[[TMP_67]], %[[TMP_16]] : tensor +// CHECK-NEXT: %[[TMP_133:.*]] = mhlo.divide %[[TMP_131]], %[[TMP_132]] : tensor +// CHECK-NEXT: %[[TMP_134:.*]] = mhlo.negate %[[TMP_133]] : tensor +// CHECK-NEXT: %[[TMP_135:.*]] = mhlo.compare GE, %[[TMP_1]], %[[TMP_16]] : (tensor, tensor) -> tensor +// CHECK-NEXT: %[[TMP_136:.*]] = mhlo.multiply %[[TMP_20]], %[[TMP_70]] : tensor +// CHECK-NEXT: %[[TMP_137:.*]] = mhlo.divide 
%[[TMP_136]], %[[TMP_71]] : tensor +// CHECK-NEXT: %[[TMP_138:.*]] = mhlo.multiply %[[TMP_20]], %[[TMP_78]] : tensor +// CHECK-NEXT: %[[TMP_139:.*]] = mhlo.add %[[TMP_137]], %[[TMP_138]] : tensor +// CHECK-NEXT: %[[TMP_140:.*]] = mhlo.constant dense<1.500000e+00> : tensor +// CHECK-NEXT: %[[TMP_141:.*]] = shape.shape_of %[[TMP_0]] : tensor -> tensor<1xindex> +// CHECK-NEXT: %[[TMP_142:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[TMP_140]], %[[TMP_141]]) <{broadcast_dimensions = dense<> : tensor<0xi64>}> : (tensor, tensor<1xindex>) -> tensor +// CHECK-NEXT: %[[TMP_143:.*]] = mhlo.compare LE, %[[TMP_67]], %[[TMP_142]] : (tensor, tensor) -> tensor +// CHECK-NEXT: %[[TMP_144:.*]] = mhlo.divide %[[TMP_136]], %[[TMP_73]] : tensor +// CHECK-NEXT: %[[TMP_145:.*]] = mhlo.add %[[TMP_137]], %[[TMP_144]] : tensor +// CHECK-NEXT: %[[TMP_146:.*]] = mhlo.subtract %[[TMP_67]], %[[TMP_16]] : tensor +// CHECK-NEXT: %[[TMP_147:.*]] = mhlo.select %[[TMP_143]], %[[TMP_145]], %[[TMP_146]] : tensor, tensor +// CHECK-NEXT: %[[TMP_148:.*]] = mhlo.select %[[TMP_135]], %[[TMP_139]], %[[TMP_147]] : tensor, tensor +// CHECK-NEXT: %[[TMP_149:.*]] = mhlo.select %[[TMP_130]], %[[TMP_134]], %[[TMP_148]] : tensor, tensor +// CHECK-NEXT: %[[TMP_150:.*]] = mhlo.multiply %[[TMP_149]], %[[TMP_132]] : tensor +// CHECK-NEXT: %[[TMP_151:.*]] = mhlo.sqrt %[[TMP_150]] : tensor +// CHECK-NEXT: %[[TMP_152:.*]] = mhlo.divide %[[TMP_3]], %[[TMP_151]] : tensor +// CHECK-NEXT: %[[TMP_153:.*]] = mhlo.add %[[TMP_149]], %[[TMP_151]] : tensor +// CHECK-NEXT: %[[TMP_154:.*]] = mhlo.log_plus_one %[[TMP_153]] : tensor +// CHECK-NEXT: %[[TMP_155:.*]] = mhlo.select %[[TMP_130]], %[[TMP_152]], %[[TMP_154]] : tensor, tensor +// CHECK-NEXT: %[[TMP_156:.*]] = mhlo.select %[[TMP_104]], %[[TMP_119]], %[[TMP_155]] : tensor, tensor +// CHECK-NEXT: %[[TMP_157:.*]] = mhlo.negate %[[TMP_156]] : tensor +// CHECK-NEXT: %[[TMP_158:.*]] = mhlo.select %[[TMP_86]], %[[TMP_157]], %[[TMP_156]] : tensor, tensor +// CHECK-NEXT: %[[TMP_159:.*]] = 
mhlo.complex %[[TMP_85]], %[[TMP_158]] : tensor> +// CHECK-NEXT: return %[[TMP_159]] : tensor> func.func @asin_complex_f64_dynamic(%arg : tensor>) -> tensor> { %result = "chlo.asin"(%arg) : (tensor>) -> tensor> func.return %result : tensor> From df114d8dcdbe878d5998a29ac5c04864e0bac155 Mon Sep 17 00:00:00 2001 From: Juhyun Lee Date: Tue, 18 Jun 2024 15:13:39 -0700 Subject: [PATCH 47/59] Fix variable name. PiperOrigin-RevId: 644525699 --- tensorflow/lite/kernels/embedding_lookup.cc | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tensorflow/lite/kernels/embedding_lookup.cc b/tensorflow/lite/kernels/embedding_lookup.cc index 285bee61156e8a..3ab0f09fcd9f73 100644 --- a/tensorflow/lite/kernels/embedding_lookup.cc +++ b/tensorflow/lite/kernels/embedding_lookup.cc @@ -83,14 +83,14 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TfLiteTensor* output; TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output)); - TfLiteIntArray* outputSize = TfLiteIntArrayCreate(NumDimensions(value)); + TfLiteIntArray* output_size = TfLiteIntArrayCreate(NumDimensions(value)); - outputSize->data[0] = SizeOfDimension(lookup, 0); - outputSize->data[1] = SizeOfDimension(value, 1); + output_size->data[0] = SizeOfDimension(lookup, 0); + output_size->data[1] = SizeOfDimension(value, 1); for (int i = 2; i < NumDimensions(value); i++) { - outputSize->data[i] = SizeOfDimension(value, i); + output_size->data[i] = SizeOfDimension(value, i); } - return context->ResizeTensor(context, output, outputSize); + return context->ResizeTensor(context, output, output_size); } TfLiteStatus EvalSimple(TfLiteContext* context, TfLiteNode* node, From ffede0562ce7ba7be32ffb0d2fa29a47aed66966 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 18 Jun 2024 15:33:07 -0700 Subject: [PATCH 48/59] Add an option to enable shardings where a tensor dim is sharded across more devices than the size of the dimension. 
PiperOrigin-RevId: 644531507 --- .../xla/hlo/experimental/auto_sharding/auto_sharding.cc | 8 ++++---- .../xla/hlo/experimental/auto_sharding/auto_sharding.h | 6 +++--- .../experimental/auto_sharding/auto_sharding_option.cc | 4 ++++ .../hlo/experimental/auto_sharding/auto_sharding_option.h | 4 ++++ .../experimental/auto_sharding/auto_sharding_strategy.cc | 8 +++++--- 5 files changed, 20 insertions(+), 10 deletions(-) diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc index 1282bb36ade687..030ae8c296159b 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc @@ -1641,14 +1641,14 @@ void CheckMemoryCosts(StrategyGroup* strategy_group, const Shape& shape) { } } -void RemoveInvalidShardingsWithShapes( +void RemoveShardingsWhereSmallDimsShardedAcrossManyDevices( const Shape& shape, StrategyGroup* strategy_group, const bool instruction_has_user_sharding) { if (strategy_group->is_tuple) { for (size_t i = 0; i < strategy_group->childs.size(); i++) { - RemoveInvalidShardingsWithShapes(shape.tuple_shapes().at(i), - strategy_group->childs[i].get(), - instruction_has_user_sharding); + RemoveShardingsWhereSmallDimsShardedAcrossManyDevices( + shape.tuple_shapes().at(i), strategy_group->childs[i].get(), + instruction_has_user_sharding); } } else { if (instruction_has_user_sharding && diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.h b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.h index bf5327318161aa..578f70d8820341 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.h +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.h @@ -365,9 +365,9 @@ std::unique_ptr MaybeFollowInsStrategyGroup( const StableHashMap>& pretrimmed_strategy_map); -void RemoveInvalidShardingsWithShapes(const Shape& shape, 
- StrategyGroup* strategy_group, - bool instruction_has_user_sharding); +void RemoveShardingsWhereSmallDimsShardedAcrossManyDevices( + const Shape& shape, StrategyGroup* strategy_group, + bool instruction_has_user_sharding); void ScaleCostsWithExecutionCounts(StrategyGroup* strategy_group, int64_t execution_count); diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_option.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_option.cc index 857110d78483d9..3fce6b4cfc6cfd 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_option.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_option.cc @@ -139,6 +139,10 @@ std::string AutoShardingOption::ToString() const { lines.push_back(absl::StrCat("generate_windowed_einsum_strategies: ", generate_windowed_einsum_strategies)); + lines.push_back( + absl::StrCat("allow_shardings_small_dims_across_many_devices: ", + allow_shardings_small_dims_across_many_devices)); + return absl::StrJoin(lines, "\n"); } diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_option.h b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_option.h index 36247e624d6bcc..2165b9e78d880b 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_option.h +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_option.h @@ -202,6 +202,10 @@ struct AutoShardingOption { // once it is fully implemented. bool generate_windowed_einsum_strategies = false; + // Whether or not to allow shardings where a tensor dim is shared across a + // number of devices larger than the size of the tensor dimension + bool allow_shardings_small_dims_across_many_devices = false; + // Prints a debug string. 
std::string ToString() const; diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc index 78e51f866e217b..2fc47e3c206dec 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc @@ -921,9 +921,11 @@ BuildStrategyAndCost( } } } - RemoveInvalidShardingsWithShapes( - ins->shape(), strategy_group.get(), - /* instruction_has_user_sharding */ ins->has_sharding()); + if (option.allow_shardings_small_dims_across_many_devices) { + RemoveShardingsWhereSmallDimsShardedAcrossManyDevices( + ins->shape(), strategy_group.get(), + /* instruction_has_user_sharding */ ins->has_sharding()); + } if (instruction_execution_counts.contains(ins)) { ScaleCostsWithExecutionCounts(strategy_group.get(), From c50b0dea3637a54533ad86a5cb72cbe54c4eab9d Mon Sep 17 00:00:00 2001 From: Tim Peut Date: Tue, 18 Jun 2024 16:01:54 -0700 Subject: [PATCH 49/59] Set a valid minSdkVersion in dummy manifest. PiperOrigin-RevId: 644539231 --- tensorflow/lite/java/aar_with_jni.bzl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/lite/java/aar_with_jni.bzl b/tensorflow/lite/java/aar_with_jni.bzl index 0e66b5f9bd7e47..808183ad93b16b 100644 --- a/tensorflow/lite/java/aar_with_jni.bzl +++ b/tensorflow/lite/java/aar_with_jni.bzl @@ -30,7 +30,7 @@ cat > $(OUTS) < - + EOF """, From c0e79dad82e082a2530e82cae7db22a5164effc8 Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Tue, 18 Jun 2024 16:25:56 -0700 Subject: [PATCH 50/59] Replace cpu and os based selects with platform based selects PiperOrigin-RevId: 644546054 --- tensorflow/BUILD | 406 ++++++++++++++++++++------------------ tensorflow/tensorflow.bzl | 14 ++ 2 files changed, 228 insertions(+), 192 deletions(-) diff --git a/tensorflow/BUILD b/tensorflow/BUILD index 730f952827f8dc..51c8f1362386a7 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -10,6 +10,7 @@ load( "VERSION", "VERSION_MAJOR", "check_deps", + "config_setting_for_bazel7", "if_google", "if_oss", "if_xla_available", @@ -250,97 +251,85 @@ config_setting( ) # Config setting for determining if we are building for Android. -config_setting( +config_setting_for_bazel7( name = "android", - constraint_values = if_google( - ["//third_party/bazel_platforms/os:android"], - [], - ), - values = if_oss( - {"crosstool_top": "//external:android/crosstool"}, - {}, - ), + constraint_values = ["//third_party/bazel_platforms/os:android"], + legacy_values = {"crosstool_top": "//external:android/crosstool"}, visibility = ["//visibility:public"], ) -config_setting( +config_setting_for_bazel7( name = "android_x86", - constraint_values = if_google( - ["//third_party/bazel_platforms/os:android"], - [], - ), - values = dict( - if_oss( - {"crosstool_top": "//external:android/crosstool"}, - ), - cpu = "x86", - ), + constraint_values = [ + "//third_party/bazel_platforms/cpu:x86_32", + "//third_party/bazel_platforms/os:android", + ], + legacy_values = { + "crosstool_top": "//external:android/crosstool", + "cpu": "x86", + }, visibility = ["//visibility:public"], ) -config_setting( +config_setting_for_bazel7( name = "android_x86_64", - constraint_values = if_google( - ["//third_party/bazel_platforms/os:android"], - [], - ), - values = dict( - if_oss( - {"crosstool_top": "//external:android/crosstool"}, - ), - cpu = "x86_64", - ), + constraint_values = [ + "//third_party/bazel_platforms/cpu:x86_64", + 
"//third_party/bazel_platforms/os:android", + ], + legacy_values = { + "crosstool_top": "//external:android/crosstool", + "cpu": "x86_64", + }, visibility = ["//visibility:public"], ) -config_setting( +config_setting_for_bazel7( name = "android_armeabi", - constraint_values = if_google( - ["//third_party/bazel_platforms/os:android"], - [], - ), - values = dict( - if_oss( - {"crosstool_top": "//external:android/crosstool"}, - ), - cpu = "armeabi", - ), + constraint_values = [ + "//third_party/bazel_platforms/cpu:armv7", + "//third_party/bazel_platforms/os:android", + ], + legacy_values = { + "crosstool_top": "//external:android/crosstool", + "cpu": "armeabi", + }, visibility = ["//visibility:public"], ) # copybara:uncomment_begin(google-only) # config_setting( # name = "chromiumos_x86_64", -# constraint_values = ["//third_party/bazel_platforms/os:chromiumos"], -# values = {"cpu": "k8"}, +# constraint_values = [ +# "//third_party/bazel_platforms/cpu:x86_64", +# "//third_party/bazel_platforms/os:chromiumos", +# ], # visibility = ["//visibility:public"], # ) # # config_setting( # name = "chromiumos_arm64", -# constraint_values = ["//third_party/bazel_platforms/os:chromiumos"], -# values = {"cpu": "arm"}, +# constraint_values = [ +# "//third_party/bazel_platforms/cpu:arm64", +# "//third_party/bazel_platforms/os:chromiumos", +# ], # visibility = ["//visibility:public"], # ) # # config_setting( # name = "chromiumos_armv7", -# constraint_values = ["//third_party/bazel_platforms/os:chromiumos"], -# values = {"cpu": "armeabi-v7a"}, +# constraint_values = [ +# "//third_party/bazel_platforms/cpu:armv7", +# "//third_party/bazel_platforms/os:chromiumos", +# ], # visibility = ["//visibility:public"], # ) # copybara:uncomment_end -config_setting( +config_setting_for_bazel7( name = "emscripten", - constraint_values = if_google( - ["//third_party/bazel_platforms/os:emscripten"], - [], - ), - values = if_oss( - {"crosstool_top": "//external:android/emscripten"}, - {}, - ), + 
constraint_values = ["//third_party/bazel_platforms/os:emscripten"], + legacy_values = {"crosstool_top": "//external:android/emscripten"}, visibility = ["//visibility:public"], ) @@ -353,48 +342,52 @@ config_setting( visibility = ["//visibility:public"], ) -config_setting( +config_setting_for_bazel7( name = "android_arm", - constraint_values = if_google( - ["//third_party/bazel_platforms/os:android"], - [], - ), - values = dict( - if_oss( - {"crosstool_top": "//external:android/crosstool"}, - ), - cpu = "armeabi-v7a", - ), + constraint_values = [ + "//third_party/bazel_platforms/cpu:armv7", + "//third_party/bazel_platforms/os:android", + ], + legacy_values = { + "crosstool_top": "//external:android/crosstool", + "cpu": "armeabi-v7a", + }, visibility = ["//visibility:public"], ) -config_setting( +config_setting_for_bazel7( name = "android_arm64", - constraint_values = if_google( - ["//third_party/bazel_platforms/os:android"], - [], - ), - values = dict( - if_oss( - {"crosstool_top": "//external:android/crosstool"}, - ), - cpu = "arm64-v8a", - ), + constraint_values = [ + "//third_party/bazel_platforms/cpu:arm64", + "//third_party/bazel_platforms/os:android", + ], + legacy_values = { + "crosstool_top": "//external:android/crosstool", + "cpu": "arm64-v8a", + }, visibility = ["//visibility:public"], ) -config_setting( +config_setting_for_bazel7( name = "android_mips", - values = { + constraint_values = [ + "//third_party/bazel_platforms/cpu:mips64", + "//third_party/bazel_platforms/os:android", + ], + legacy_values = { "crosstool_top": "//external:android/crosstool", "cpu": "mips", }, visibility = ["//visibility:public"], ) -config_setting( +config_setting_for_bazel7( name = "android_mips64", - values = { + constraint_values = [ + "//third_party/bazel_platforms/cpu:mips64", + "//third_party/bazel_platforms/os:android", + ], + legacy_values = { "crosstool_top": "//external:android/crosstool", "cpu": "mips64", }, @@ -402,18 +395,12 @@ config_setting( ) # 
TODO(jakeharmon8): Remove in favor of TSL version -config_setting( +config_setting_for_bazel7( name = "windows", # Internal builds query the target OS. - constraint_values = if_google( - ["//third_party/bazel_platforms/os:windows"], - [], - ), + constraint_values = ["//third_party/bazel_platforms/os:windows"], # OSS builds query the CPU type. - values = if_oss( - {"cpu": "x64_windows"}, - {}, - ), + legacy_values = {"cpu": "x64_windows"}, visibility = ["//visibility:public"], ) @@ -427,25 +414,26 @@ config_setting( # "darwin_x86_64". The former shows up when building on a Mac x86_64 host for a Mac x86_64 target. # The latter shows up when cross-compiling for Mac x86_64 from a Mac ARM machine and in internal # Google builds. -config_setting( +config_setting_for_bazel7( name = "macos_x86_64_default", - constraint_values = if_google( - ["//third_party/bazel_platforms/os:macos"], - [], - ), - values = { + constraint_values = [ + "//third_party/bazel_platforms/os:macos", + "//third_party/bazel_platforms/cpu:x86_64", + ], + legacy_values = { "apple_platform_type": "macos", "cpu": "darwin", }, ) -config_setting( +config_setting_for_bazel7( name = "macos_x86_64_crosscompile", - constraint_values = if_google( - ["//third_party/bazel_platforms/os:macos"], - [], - ), - values = { + constraint_values = [ + "//third_party/bazel_platforms/os:macos", + "//third_party/bazel_platforms/cpu:x86_64", + # TODO: introduce cross_compilable_cpu constraint in Bazel + ], + legacy_values = { "apple_platform_type": "macos", "cpu": "darwin_x86_64", }, @@ -460,13 +448,13 @@ selects.config_setting_group( visibility = ["//visibility:public"], ) -config_setting( +config_setting_for_bazel7( name = "macos_arm64", - constraint_values = if_google( - ["//third_party/bazel_platforms/os:macos"], - [], - ), - values = { + constraint_values = [ + "//third_party/bazel_platforms/os:macos", + "//third_party/bazel_platforms/cpu:arm64", + ], + legacy_values = { "apple_platform_type": "macos", "cpu": 
"darwin_arm64", }, @@ -483,138 +471,151 @@ selects.config_setting_group( visibility = ["//visibility:public"], ) -config_setting( +config_setting_for_bazel7( name = "ios", - constraint_values = if_google( - ["//third_party/bazel_platforms/os:ios"], - [], - ), - values = if_oss( - {"apple_platform_type": "ios"}, - {}, - ), + constraint_values = ["//third_party/bazel_platforms/os:ios"], + legacy_values = {"apple_platform_type": "ios"}, visibility = ["//visibility:public"], ) # TODO(jakeharmon8): Remove in favor of TSL version -config_setting( +config_setting_for_bazel7( name = "fuchsia", - constraint_values = if_google( - ["//third_party/bazel_platforms/os:fuchsia"], - [], - ), - values = if_oss( - # TODO(b/149248802) When we have a Fuchsia Bazel SDK update to use the values it sets. - {"cpu": "fuchsia"}, - {}, - ), + constraint_values = ["//third_party/bazel_platforms/os:fuchsia"], + # TODO(b/149248802) When we have a Fuchsia Bazel SDK update to use the values it sets. + legacy_values = {"cpu": "fuchsia"}, visibility = ["//visibility:public"], ) -config_setting( +config_setting_for_bazel7( name = "fuchsia_x86_64", - constraint_values = if_google( - ["//third_party/bazel_platforms/os:fuchsia"], - [], - ), - values = { + constraint_values = [ + "//third_party/bazel_platforms/os:fuchsia", + "//third_party/bazel_platforms/cpu:x86_64", + ], + legacy_values = { "cpu": "x86_64", }, visibility = ["//visibility:public"], ) -config_setting( +config_setting_for_bazel7( name = "ios_x86_64", - constraint_values = if_google( - ["//third_party/bazel_platforms/os:ios"], - [], - ), - values = dict( - if_oss( - {"crosstool_top": "//tools/osx/crosstool:crosstool"}, - ), - cpu = "ios_x86_64", - ), + constraint_values = [ + "//third_party/bazel_platforms/os:ios", + "//third_party/bazel_platforms/cpu:x86_64", + ], + legacy_values = { + "crosstool_top": "//tools/osx/crosstool:crosstool", + "cpu": "ios_x86_64", + }, visibility = ["//visibility:public"], ) -config_setting( 
+config_setting_for_bazel7( name = "chromiumos", - constraint_values = if_google( - ["//third_party/bazel_platforms/os:chromiumos"], - [], - ), - values = if_oss( - {"crosstool_top": "//external:android/chromiumos"}, - {}, - ), + constraint_values = ["//third_party/bazel_platforms/os:chromiumos"], + legacy_values = {"crosstool_top": "//external:android/chromiumos"}, visibility = ["//visibility:public"], ) -config_setting( +config_setting_for_bazel7( name = "linux_aarch64", - values = {"cpu": "aarch64"}, + constraint_values = ["//third_party/bazel_platforms/cpu:aarch64"], + legacy_values = {"cpu": "aarch64"}, visibility = ["//visibility:public"], ) -config_setting( +config_setting_for_bazel7( name = "linux_armhf", - values = {"cpu": "armhf"}, + constraint_values = ["//third_party/bazel_platforms/cpu:armv7e-mf"], + legacy_values = {"cpu": "armhf"}, visibility = ["//visibility:public"], ) -config_setting( +config_setting_for_bazel7( name = "linux_x86_64", - values = {"cpu": "k8"}, + constraint_values = [ + "//third_party/bazel_platforms/cpu:x86_64", + "//third_party/bazel_platforms/os:linux", + ], + legacy_values = {"cpu": "k8"}, visibility = ["//visibility:public"], ) -config_setting( +config_setting_for_bazel7( name = "haswell", - values = {"cpu": "haswell"}, + constraint_values = [ + # TODO: introduce haswell constraint in Bazel + "//third_party/bazel_platforms/cpu:x86_64", + "//third_party/bazel_platforms/os:linux", + ], + legacy_values = {"cpu": "haswell"}, visibility = ["//visibility:public"], ) # This condition takes precedence over :linux_x86_64 -config_setting( +config_setting_for_bazel7( name = "linux_x86_64_no_sse", - values = { - "cpu": "k8", - "copt": "-mno-sse4.2", - }, + constraint_values = [ + "//third_party/bazel_platforms/cpu:x86_64", + "//third_party/bazel_platforms/os:linux", + ], + legacy_values = {"cpu": "k8"}, + values = {"copt": "-mno-sse4.2"}, visibility = ["//visibility:public"], ) # This condition takes precedence over :linux_x86_64 # 
TODO(b/290533709): Remove this with PJRT build rule cleanup. -config_setting( +config_setting_for_bazel7( name = "linux_x86_64_with_weightwatcher", + constraint_values = [ + "//third_party/bazel_platforms/cpu:x86_64", + "//third_party/bazel_platforms/os:linux", + ], define_values = {"tensorflow_weightwatcher": "true"}, - values = {"cpu": "k8"}, + legacy_values = {"cpu": "k8"}, visibility = ["//visibility:public"], ) -config_setting( +config_setting_for_bazel7( name = "linux_ppc64le", - values = {"cpu": "ppc"}, + constraint_values = [ + "//third_party/bazel_platforms/cpu:ppc64le", + "//third_party/bazel_platforms/os:linux", + ], + legacy_values = {"cpu": "ppc"}, visibility = ["//visibility:public"], ) -config_setting( +config_setting_for_bazel7( name = "linux_s390x", - values = {"cpu": "s390x"}, + constraint_values = [ + "//third_party/bazel_platforms/cpu:s390x", + "//third_party/bazel_platforms/os:linux", + ], + legacy_values = {"cpu": "s390x"}, visibility = ["//visibility:public"], ) -config_setting( +config_setting_for_bazel7( name = "linux_mips64", - values = {"cpu": "mips64"}, + constraint_values = [ + "//third_party/bazel_platforms/cpu:mips64", + "//third_party/bazel_platforms/os:linux", + ], + legacy_values = {"cpu": "mips64"}, visibility = ["//visibility:public"], ) -config_setting( +config_setting_for_bazel7( name = "linux_riscv64", - values = {"cpu": "riscv64"}, + constraint_values = [ + "//third_party/bazel_platforms/cpu:riscv64", + "//third_party/bazel_platforms/os:linux", + ], + legacy_values = {"cpu": "riscv64"}, visibility = ["//visibility:public"], ) @@ -634,27 +635,39 @@ config_setting( visibility = ["//visibility:public"], ) -config_setting( +config_setting_for_bazel7( name = "arm", - values = {"cpu": "arm"}, + constraint_values = [ + "//third_party/bazel_platforms/cpu:arm", + ], + legacy_values = {"cpu": "arm"}, visibility = ["//visibility:public"], ) -config_setting( +config_setting_for_bazel7( name = "armeabi", - values = {"cpu": "armeabi"}, + 
constraint_values = [ + "//third_party/bazel_platforms/cpu:armv7", + ], + legacy_values = {"cpu": "armeabi"}, visibility = ["//visibility:public"], ) -config_setting( +config_setting_for_bazel7( name = "armeabi-v7a", - values = {"cpu": "armeabi-v7a"}, + constraint_values = [ + "//third_party/bazel_platforms/cpu:armv7", + ], + legacy_values = {"cpu": "armeabi-v7a"}, visibility = ["//visibility:public"], ) -config_setting( +config_setting_for_bazel7( name = "arm64-v8a", - values = {"cpu": "arm64-v8a"}, + constraint_values = [ + "//third_party/bazel_platforms/cpu:arm64", + ], + legacy_values = {"cpu": "arm64-v8a"}, visibility = ["//visibility:public"], ) @@ -670,9 +683,10 @@ selects.config_setting_group( ], ) -config_setting( +config_setting_for_bazel7( name = "freebsd", - values = {"cpu": "freebsd"}, + constraint_values = ["//third_party/bazel_platforms/os:freebsd"], + legacy_values = {"cpu": "freebsd"}, visibility = ["//visibility:public"], ) @@ -723,24 +737,32 @@ config_setting( visibility = ["//visibility:public"], ) -config_setting( +config_setting_for_bazel7( name = "macos_x86_64_with_framework_shared_object", + constraint_values = [ + "//third_party/bazel_platforms/os:macos", + "//third_party/bazel_platforms/cpu:x86_64", + ], define_values = { "framework_shared_object": "true", }, - values = { + legacy_values = { "apple_platform_type": "macos", "cpu": "darwin", }, visibility = ["//visibility:public"], ) -config_setting( +config_setting_for_bazel7( name = "macos_arm64_with_framework_shared_object", + constraint_values = [ + "//third_party/bazel_platforms/os:macos", + "//third_party/bazel_platforms/cpu:arm64", + ], define_values = { "framework_shared_object": "true", }, - values = { + legacy_values = { "apple_platform_type": "macos", "cpu": "darwin_arm64", }, diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl index 297c3e094bb903..f27f4cee33360e 100644 --- a/tensorflow/tensorflow.bzl +++ b/tensorflow/tensorflow.bzl @@ -129,6 +129,20 @@ def 
if_google(google_value, oss_value = []): """ return oss_value # copybara:comment_replace return google_value +# Used for toolchain resolution, which work either at Google or on Bazel 7 +def config_setting_for_bazel7(*, constraint_values = [], legacy_values = {}, values = {}, **kwargs): + """A config_setting that uses either constraint values or legacy config flag values. + + At the moment it uses constraint values at Google. When Tensorflow upgrades to Bazel 7 + (or flips --incompatible_enable_cc_toolchain_resolution on Bazel <=6) the constraint value + configuration will be ready. + """ + native.config_setting( + constraint_values = if_google(constraint_values, {}), + values = if_google(values, values | legacy_values), + **kwargs + ) + def if_v2(a): return select({ clean_dep("//tensorflow:api_version_2"): a, From a1a5b8eb52ccecc54b2d45195d683c7a1f874015 Mon Sep 17 00:00:00 2001 From: Kuy Mainwaring Date: Tue, 18 Jun 2024 16:52:44 -0700 Subject: [PATCH 51/59] [XLA:GPU] Clang-tidy cleanup for xla/service/ar_crs_combiner.{cc,h} PiperOrigin-RevId: 644553397 --- third_party/xla/xla/service/BUILD | 7 ++++++- third_party/xla/xla/service/ar_crs_combiner.cc | 16 ++++++++++++++-- third_party/xla/xla/service/ar_crs_combiner.h | 10 ++++++++++ 3 files changed, 30 insertions(+), 3 deletions(-) diff --git a/third_party/xla/xla/service/BUILD b/third_party/xla/xla/service/BUILD index 520fb05947ef27..4bae8005059bbc 100644 --- a/third_party/xla/xla/service/BUILD +++ b/third_party/xla/xla/service/BUILD @@ -6818,13 +6818,18 @@ cc_library( "//xla:literal_util", "//xla:shape_util", "//xla:status_macros", - "//xla:types", "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_query", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@local_tsl//tsl/platform:errors", + 
"@local_tsl//tsl/platform:status", + "@local_tsl//tsl/platform:statusor", ], ) diff --git a/third_party/xla/xla/service/ar_crs_combiner.cc b/third_party/xla/xla/service/ar_crs_combiner.cc index 91f4deea605d92..a75acbc2b38498 100644 --- a/third_party/xla/xla/service/ar_crs_combiner.cc +++ b/third_party/xla/xla/service/ar_crs_combiner.cc @@ -15,11 +15,20 @@ limitations under the License. #include "xla/service/ar_crs_combiner.h" -#include +#include +#include +#include #include #include +#include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" @@ -30,9 +39,12 @@ limitations under the License. #include "xla/service/call_graph.h" #include "xla/service/hlo_replication_analysis.h" #include "xla/service/pattern_matcher.h" +#include "xla/shape.h" #include "xla/shape_util.h" #include "xla/status_macros.h" -#include "xla/types.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/status.h" +#include "tsl/platform/statusor.h" namespace xla { namespace { diff --git a/third_party/xla/xla/service/ar_crs_combiner.h b/third_party/xla/xla/service/ar_crs_combiner.h index 60f7841a9c5fa5..3b7dfe809928b5 100644 --- a/third_party/xla/xla/service/ar_crs_combiner.h +++ b/third_party/xla/xla/service/ar_crs_combiner.h @@ -16,9 +16,19 @@ limitations under the License. 
#ifndef XLA_SERVICE_AR_CRS_COMBINER_H_ #define XLA_SERVICE_AR_CRS_COMBINER_H_ +#include +#include +#include +#include +#include + #include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/service/call_graph.h" #include "xla/service/hlo_pass_interface.h" From ad29fd64edb059207fc50bfd9743d628516ea654 Mon Sep 17 00:00:00 2001 From: Kuy Mainwaring Date: Tue, 18 Jun 2024 16:54:16 -0700 Subject: [PATCH 52/59] [XLA:GPU] Clang-tidy cleanup for xla/service/allocation_tracker.cc PiperOrigin-RevId: 644553723 --- third_party/xla/xla/service/BUILD | 7 +++++-- third_party/xla/xla/service/allocation_tracker.cc | 14 ++++++++++---- 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/third_party/xla/xla/service/BUILD b/third_party/xla/xla/service/BUILD index 4bae8005059bbc..99bb153d1d32c1 100644 --- a/third_party/xla/xla/service/BUILD +++ b/third_party/xla/xla/service/BUILD @@ -1693,18 +1693,21 @@ cc_library( hdrs = ["allocation_tracker.h"], deps = [ ":backend", - ":transfer_manager", + ":shaped_buffer", "//xla:shape_util", "//xla:status_macros", "//xla:types", "//xla:util", "//xla:xla_data_proto_cc", + "//xla/stream_executor:device_memory", "//xla/stream_executor:device_memory_allocator", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:statusor", ], ) diff --git a/third_party/xla/xla/service/allocation_tracker.cc b/third_party/xla/xla/service/allocation_tracker.cc index e6b5eea72ff562..95168eba9c6c61 100644 --- a/third_party/xla/xla/service/allocation_tracker.cc +++ 
b/third_party/xla/xla/service/allocation_tracker.cc @@ -15,19 +15,25 @@ limitations under the License. #include "xla/service/allocation_tracker.h" +#include #include +#include +#include #include +#include -#include "absl/strings/str_cat.h" -#include "xla/map_util.h" -#include "xla/service/transfer_manager.h" +#include "absl/status/status.h" +#include "absl/synchronization/mutex.h" +#include "xla/service/shaped_buffer.h" +#include "xla/shape.h" #include "xla/shape_util.h" #include "xla/status_macros.h" +#include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/device_memory_allocator.h" -#include "xla/types.h" #include "xla/util.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" +#include "tsl/platform/statusor.h" namespace xla { From ed006d099424a9da0f93f24fbe2a1a6fcffe4530 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Tue, 18 Jun 2024 18:18:00 -0700 Subject: [PATCH 53/59] Use jax AOT APIs instead of deprecated jax.xla_computation. PiperOrigin-RevId: 644573713 --- tensorflow/lite/python/lite.py | 19 +++++++++++-------- tensorflow/lite/python/util.py | 4 ++-- 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/tensorflow/lite/python/lite.py b/tensorflow/lite/python/lite.py index 826fd6e6e0ec73..cf1162e2499772 100644 --- a/tensorflow/lite/python/lite.py +++ b/tensorflow/lite/python/lite.py @@ -57,7 +57,7 @@ from tensorflow.lite.python.op_hint import is_ophint_converted as _is_ophint_converted from tensorflow.lite.python.op_hint import OpHint # pylint: disable=unused-import from tensorflow.lite.python.optimize import calibrator as _calibrator -from tensorflow.lite.python.util import _xla_computation +from tensorflow.lite.python.util import _jit from tensorflow.lite.python.util import build_debug_info_func as _build_debug_info_func from tensorflow.lite.python.util import convert_debug_info_func as _convert_debug_info_func from tensorflow.lite.python.util import freeze_graph as _freeze_graph @@ -1981,15 +1981,15 @@ 
def convert(self): Raises: ImportError: - If cannot import the xla_computation from jax. + If cannot import the jit from jax. ValueError: No serving function is specified. Input tensors are not specified. The truth value of an array with more than one element is ambiguous. Failed to convert the given Jax function to hlo. """ - if not _xla_computation: - raise ImportError("Cannot import xla_computation from jax.") + if not _jit: + raise ImportError("Cannot import jit from jax.") if not self._serving_funcs: raise ValueError("No serving func is specified.") @@ -2024,10 +2024,13 @@ def convert(self): ordered_inputs.append(tensor) try: - xla_compuation = _xla_computation(self._serving_funcs[0], backend="cpu") - hlo_proto = xla_compuation( - *ordered_inputs - ).as_serialized_hlo_module_proto() + hlo_proto = ( + _jit(self._serving_funcs[0]) + .trace(*ordered_inputs) + .lower(lowering_platforms=("cpu",)) + .compiler_ir("hlo") + .as_serialized_hlo_module_proto() + ) except Exception: # pylint: disable=broad-except raise ValueError("Failed to convert the given Jax function to hlo.") diff --git a/tensorflow/lite/python/util.py b/tensorflow/lite/python/util.py index 7b3589dc082dc9..c0692655c3f127 100644 --- a/tensorflow/lite/python/util.py +++ b/tensorflow/lite/python/util.py @@ -50,9 +50,9 @@ # pylint: disable=g-import-not-at-top # pylint: disable=unused-import try: - from jax import xla_computation as _xla_computation + from jax import jit as _jit except ImportError: - _xla_computation = None + _jit = None # pylint: enable=g-import-not-at-top # pylint: enable=unused-import From e98b73df9b2a6e79d3dee92deb9bbabd303281f2 Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Tue, 18 Jun 2024 19:01:29 -0700 Subject: [PATCH 54/59] Integrate LLVM at llvm/llvm-project@b99d0b344001 Updates LLVM usage to match [b99d0b344001](https://github.com/llvm/llvm-project/commit/b99d0b344001) PiperOrigin-RevId: 644581469 --- third_party/llvm/externc.patch | 24 ---- third_party/llvm/generated.patch | 116 ------------------ third_party/llvm/workspace.bzl | 5 +- .../xla/service/gpu/fusions/reduction_mlir.cc | 4 +- 4 files changed, 4 insertions(+), 145 deletions(-) delete mode 100644 third_party/llvm/externc.patch diff --git a/third_party/llvm/externc.patch b/third_party/llvm/externc.patch deleted file mode 100644 index 96116fad9d8704..00000000000000 --- a/third_party/llvm/externc.patch +++ /dev/null @@ -1,24 +0,0 @@ -diff --git a/mlir/include/mlir-c/Rewrite.h b/mlir/include/mlir-c/Rewrite.h -index 45218a1cd4eb..bed93045f4b5 100644 ---- a/mlir/include/mlir-c/Rewrite.h -+++ b/mlir/include/mlir-c/Rewrite.h -@@ -19,6 +19,10 @@ - #include "mlir-c/Support.h" - #include "mlir/Config/mlir-config.h" - -+#ifdef __cplusplus -+extern "C" { -+#endif -+ - //===----------------------------------------------------------------------===// - /// Opaque type declarations (see mlir-c/IR.h for more details). - //===----------------------------------------------------------------------===// -@@ -57,4 +61,8 @@ mlirRewritePatternSetFromPDLPatternModule(MlirPDLPatternModule op); - - #undef DEFINE_C_API_STRUCT - -+#ifdef __cplusplus -+} -+#endif -+ - #endif // MLIR_C_REWRITE_H diff --git a/third_party/llvm/generated.patch b/third_party/llvm/generated.patch index 765078cd0493f5..509398da979e83 100644 --- a/third_party/llvm/generated.patch +++ b/third_party/llvm/generated.patch @@ -1,117 +1 @@ Auto generated patch. Do not edit or delete it, even if empty. 
-diff -ruN --strip-trailing-cr a/clang/unittests/Lex/HeaderSearchTest.cpp b/clang/unittests/Lex/HeaderSearchTest.cpp ---- a/clang/unittests/Lex/HeaderSearchTest.cpp -+++ b/clang/unittests/Lex/HeaderSearchTest.cpp -@@ -19,6 +19,8 @@ - #include "clang/Serialization/InMemoryModuleCache.h" - #include "llvm/Support/MemoryBuffer.h" - #include "gtest/gtest.h" -+#include -+#include - - namespace clang { - namespace { -@@ -350,8 +352,8 @@ - std::string TextualPath = "/textual.h"; - }; - -- auto ExternalSource = new MockExternalHeaderFileInfoSource(); -- Search.SetExternalSource(ExternalSource); -+ auto ExternalSource = std::make_unique(); -+ Search.SetExternalSource(ExternalSource.get()); - - // Everything should start out external. - auto ModularFE = AddHeader(ExternalSource->ModularPath); -diff -ruN --strip-trailing-cr a/llvm/test/DebugInfo/Generic/sroa-extract-bits.ll b/llvm/test/DebugInfo/Generic/sroa-extract-bits.ll ---- a/llvm/test/DebugInfo/Generic/sroa-extract-bits.ll -+++ b/llvm/test/DebugInfo/Generic/sroa-extract-bits.ll -@@ -10,11 +10,11 @@ - ; CHECK-SAME: i32 [[ARG:%.*]]) { - ; CHECK-NEXT: entry: - ; CHECK-NEXT: [[PTR_SROA_0_0_EXTRACT_TRUNC:%.*]] = trunc i32 [[ARG]] to i8 --; CHECK-NEXT: tail call void @llvm.dbg.value(metadata i8 [[PTR_SROA_0_0_EXTRACT_TRUNC]], metadata [[META2:![0-9]+]], metadata !DIExpression(DW_OP_LLVM_extract_bits_zext, 0, 8)), !dbg [[DBG7:![0-9]+]] -+; CHECK-NEXT: #dbg_value(i8 [[PTR_SROA_0_0_EXTRACT_TRUNC]], [[META2:![0-9]+]], !DIExpression(DW_OP_LLVM_extract_bits_zext, 0, 8), [[META7:![0-9]+]]) - ; CHECK-NEXT: [[PTR_SROA_2_0_EXTRACT_SHIFT:%.*]] = lshr i32 [[ARG]], 8 - ; CHECK-NEXT: [[PTR_SROA_2_0_EXTRACT_TRUNC:%.*]] = trunc i32 [[PTR_SROA_2_0_EXTRACT_SHIFT]] to i24 --; CHECK-NEXT: tail call void @llvm.dbg.value(metadata i24 [[PTR_SROA_2_0_EXTRACT_TRUNC]], metadata [[META8:![0-9]+]], metadata !DIExpression(DW_OP_LLVM_extract_bits_zext, 8, 16)), !dbg [[DBG7]] --; CHECK-NEXT: tail call void @llvm.dbg.value(metadata i24 
[[PTR_SROA_2_0_EXTRACT_TRUNC]], metadata [[META9:![0-9]+]], metadata !DIExpression(DW_OP_LLVM_extract_bits_sext, 0, 8)), !dbg [[DBG7]] -+; CHECK-NEXT: #dbg_value(i24 [[PTR_SROA_2_0_EXTRACT_TRUNC]], [[META8:![0-9]+]], !DIExpression(DW_OP_LLVM_extract_bits_zext, 8, 16), [[META7]]) -+; CHECK-NEXT: #dbg_value(i24 [[PTR_SROA_2_0_EXTRACT_TRUNC]], [[META9:![0-9]+]], !DIExpression(DW_OP_LLVM_extract_bits_sext, 0, 8), [[META7]]) - ; CHECK-NEXT: ret i8 [[PTR_SROA_0_0_EXTRACT_TRUNC]] - ; - entry: -@@ -33,14 +33,14 @@ - ; CHECK-SAME: i32 [[ARG1:%.*]], i8 [[ARG2:%.*]]) { - ; CHECK-NEXT: entry: - ; CHECK-NEXT: [[PTR_SROA_0_0_EXTRACT_TRUNC:%.*]] = trunc i32 [[ARG1]] to i8 --; CHECK-NEXT: tail call void @llvm.dbg.value(metadata i8 [[PTR_SROA_0_0_EXTRACT_TRUNC]], metadata [[META2]], metadata !DIExpression(DW_OP_LLVM_extract_bits_zext, 0, 8)), !dbg [[DBG7]] -+; CHECK-NEXT: #dbg_value(i8 [[PTR_SROA_0_0_EXTRACT_TRUNC]], [[META2]], !DIExpression(DW_OP_LLVM_extract_bits_zext, 0, 8), [[META7]]) - ; CHECK-NEXT: [[PTR_SROA_2_0_EXTRACT_SHIFT:%.*]] = lshr i32 [[ARG1]], 8 - ; CHECK-NEXT: [[PTR_SROA_2_0_EXTRACT_TRUNC:%.*]] = trunc i32 [[PTR_SROA_2_0_EXTRACT_SHIFT]] to i16 --; CHECK-NEXT: tail call void @llvm.dbg.value(metadata i16 [[PTR_SROA_2_0_EXTRACT_TRUNC]], metadata [[META9]], metadata !DIExpression(DW_OP_LLVM_extract_bits_sext, 0, 16)), !dbg [[DBG7]] -+; CHECK-NEXT: #dbg_value(i16 [[PTR_SROA_2_0_EXTRACT_TRUNC]], [[META9]], !DIExpression(DW_OP_LLVM_extract_bits_sext, 0, 16), [[META7]]) - ; CHECK-NEXT: [[PTR_SROA_21_0_EXTRACT_SHIFT:%.*]] = lshr i32 [[ARG1]], 24 - ; CHECK-NEXT: [[PTR_SROA_21_0_EXTRACT_TRUNC:%.*]] = trunc i32 [[PTR_SROA_21_0_EXTRACT_SHIFT]] to i8 --; CHECK-NEXT: tail call void @llvm.dbg.value(metadata i8 [[PTR_SROA_21_0_EXTRACT_TRUNC]], metadata [[META8]], metadata !DIExpression(DW_OP_LLVM_extract_bits_zext, 0, 8)), !dbg [[DBG7]] --; CHECK-NEXT: tail call void @llvm.dbg.value(metadata i8 [[ARG2]], metadata [[META8]], metadata !DIExpression(DW_OP_LLVM_extract_bits_zext, 0, 
8)), !dbg [[DBG7]] -+; CHECK-NEXT: #dbg_value(i8 [[PTR_SROA_21_0_EXTRACT_TRUNC]], [[META8]], !DIExpression(DW_OP_LLVM_extract_bits_zext, 0, 8), [[META7]]) -+; CHECK-NEXT: #dbg_value(i8 [[ARG2]], [[META8]], !DIExpression(DW_OP_LLVM_extract_bits_zext, 0, 8), [[META7]]) - ; CHECK-NEXT: ret i8 [[PTR_SROA_0_0_EXTRACT_TRUNC]] - ; - entry: -@@ -81,10 +81,10 @@ - ; CHECK-SAME: i32 [[ARG:%.*]]) { - ; CHECK-NEXT: entry: - ; CHECK-NEXT: [[PTR_SROA_0_0_EXTRACT_TRUNC:%.*]] = trunc i32 [[ARG]] to i16 --; CHECK-NEXT: tail call void @llvm.dbg.value(metadata i16 [[PTR_SROA_0_0_EXTRACT_TRUNC]], metadata [[META2]], metadata !DIExpression(DW_OP_LLVM_extract_bits_zext, 0, 8)), !dbg [[DBG7]] -+; CHECK-NEXT: #dbg_value(i16 [[PTR_SROA_0_0_EXTRACT_TRUNC]], [[META2]], !DIExpression(DW_OP_LLVM_extract_bits_zext, 0, 8), [[META7]]) - ; CHECK-NEXT: [[PTR_SROA_2_0_EXTRACT_SHIFT:%.*]] = lshr i32 [[ARG]], 16 - ; CHECK-NEXT: [[PTR_SROA_2_0_EXTRACT_TRUNC:%.*]] = trunc i32 [[PTR_SROA_2_0_EXTRACT_SHIFT]] to i16 --; CHECK-NEXT: tail call void @llvm.dbg.value(metadata i16 [[PTR_SROA_2_0_EXTRACT_TRUNC]], metadata [[META8]], metadata !DIExpression(DW_OP_LLVM_extract_bits_zext, 8, 8)), !dbg [[DBG7]] -+; CHECK-NEXT: #dbg_value(i16 [[PTR_SROA_2_0_EXTRACT_TRUNC]], [[META8]], !DIExpression(DW_OP_LLVM_extract_bits_zext, 8, 8), [[META7]]) - ; CHECK-NEXT: ret i16 [[PTR_SROA_0_0_EXTRACT_TRUNC]] - ; - entry: -@@ -104,11 +104,11 @@ - ; CHECK-SAME: i32 [[ARG:%.*]]) { - ; CHECK-NEXT: entry: - ; CHECK-NEXT: [[PTR_SROA_0_0_EXTRACT_TRUNC:%.*]] = trunc i32 [[ARG]] to i8 --; CHECK-NEXT: tail call void @llvm.dbg.value(metadata i8 [[PTR_SROA_0_0_EXTRACT_TRUNC]], metadata [[META11:![0-9]+]], metadata !DIExpression()), !dbg [[DBG7]] -+; CHECK-NEXT: #dbg_value(i8 [[PTR_SROA_0_0_EXTRACT_TRUNC]], [[META11:![0-9]+]], !DIExpression(), [[META7]]) - ; CHECK-NEXT: [[PTR_SROA_2_0_EXTRACT_SHIFT:%.*]] = lshr i32 [[ARG]], 8 - ; CHECK-NEXT: [[PTR_SROA_2_0_EXTRACT_TRUNC:%.*]] = trunc i32 [[PTR_SROA_2_0_EXTRACT_SHIFT]] to i24 --; CHECK-NEXT: 
tail call void @llvm.dbg.value(metadata i24 [[PTR_SROA_2_0_EXTRACT_TRUNC]], metadata [[META8]], metadata !DIExpression(DW_OP_LLVM_extract_bits_zext, 8, 8)), !dbg [[DBG7]] --; CHECK-NEXT: tail call void @llvm.dbg.value(metadata i24 [[PTR_SROA_2_0_EXTRACT_TRUNC]], metadata [[META9]], metadata !DIExpression(DW_OP_LLVM_extract_bits_sext, 0, 8)), !dbg [[DBG7]] -+; CHECK-NEXT: #dbg_value(i24 [[PTR_SROA_2_0_EXTRACT_TRUNC]], [[META8]], !DIExpression(DW_OP_LLVM_extract_bits_zext, 8, 8), [[META7]]) -+; CHECK-NEXT: #dbg_value(i24 [[PTR_SROA_2_0_EXTRACT_TRUNC]], [[META9]], !DIExpression(DW_OP_LLVM_extract_bits_sext, 0, 8), [[META7]]) - ; CHECK-NEXT: ret i8 [[PTR_SROA_0_0_EXTRACT_TRUNC]] - ; - entry: -@@ -127,14 +127,14 @@ - ; CHECK-SAME: i32 [[ARG1:%.*]], i8 [[ARG2:%.*]]) { - ; CHECK-NEXT: entry: - ; CHECK-NEXT: [[PTR_SROA_0_0_EXTRACT_TRUNC:%.*]] = trunc i32 [[ARG1]] to i8 --; CHECK-NEXT: tail call void @llvm.dbg.value(metadata i8 undef, metadata [[META2]], metadata !DIExpression(DW_OP_LLVM_extract_bits_sext, 0, 8)), !dbg [[DBG7]] -+; CHECK-NEXT: #dbg_value(i8 undef, [[META2]], !DIExpression(DW_OP_LLVM_extract_bits_sext, 0, 8), [[META7]]) - ; CHECK-NEXT: [[PTR_SROA_2_0_EXTRACT_SHIFT:%.*]] = lshr i32 [[ARG1]], 8 - ; CHECK-NEXT: [[PTR_SROA_2_0_EXTRACT_TRUNC:%.*]] = trunc i32 [[PTR_SROA_2_0_EXTRACT_SHIFT]] to i16 --; CHECK-NEXT: tail call void @llvm.dbg.value(metadata i16 undef, metadata [[META9]], metadata !DIExpression(DW_OP_LLVM_extract_bits_zext, 0, 16)), !dbg [[DBG7]] -+; CHECK-NEXT: #dbg_value(i16 undef, [[META9]], !DIExpression(DW_OP_LLVM_extract_bits_zext, 0, 16), [[META7]]) - ; CHECK-NEXT: [[PTR_SROA_21_0_EXTRACT_SHIFT:%.*]] = lshr i32 [[ARG1]], 24 - ; CHECK-NEXT: [[PTR_SROA_21_0_EXTRACT_TRUNC:%.*]] = trunc i32 [[PTR_SROA_21_0_EXTRACT_SHIFT]] to i8 --; CHECK-NEXT: tail call void @llvm.dbg.value(metadata i8 undef, metadata [[META8]], metadata !DIExpression(DW_OP_LLVM_extract_bits_sext, 0, 8)), !dbg [[DBG7]] --; CHECK-NEXT: tail call void @llvm.dbg.value(metadata i8 undef, 
metadata [[META8]], metadata !DIExpression(DW_OP_LLVM_extract_bits_sext, 0, 8)), !dbg [[DBG7]] -+; CHECK-NEXT: #dbg_value(i8 undef, [[META8]], !DIExpression(DW_OP_LLVM_extract_bits_sext, 0, 8), [[META7]]) -+; CHECK-NEXT: #dbg_value(i8 undef, [[META8]], !DIExpression(DW_OP_LLVM_extract_bits_sext, 0, 8), [[META7]]) - ; CHECK-NEXT: ret i8 [[PTR_SROA_0_0_EXTRACT_TRUNC]] - ; - entry: -@@ -196,7 +196,7 @@ - ; CHECK: [[META4]] = distinct !DICompileUnit(language: DW_LANG_C_plus_plus_14, file: [[META5:![0-9]+]], isOptimized: false, runtimeVersion: 0, emissionKind: FullDebug) - ; CHECK: [[META5]] = !DIFile(filename: "dbg-bit-piece.cpp", directory: "") - ; CHECK: [[META6]] = !DIBasicType(name: "unsigned int", size: 32, encoding: DW_ATE_unsigned) --; CHECK: [[DBG7]] = !DILocation(line: 0, scope: [[META3]]) -+; CHECK: [[META7]] = !DILocation(line: 0, scope: [[META3]]) - ; CHECK: [[META8]] = !DILocalVariable(name: "z", scope: [[META3]], type: [[META6]]) - ; CHECK: [[META9]] = !DILocalVariable(name: "y", scope: [[META3]], type: [[META10:![0-9]+]]) - ; CHECK: [[META10]] = !DIBasicType(name: "signed int", size: 32, encoding: DW_ATE_signed) diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl index 0ef1683dacd8a5..be194471a81f13 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" - LLVM_COMMIT = "52d87de7a42d608ac1da33795ca0a892f2b53f36" - LLVM_SHA256 = "6c78e61b7b1cef7ca91c9454bc18fb63426040dd06f9639ae9a89758926eec57" + LLVM_COMMIT = "b99d0b34400176cb9183113b96b245400caaf8d8" + LLVM_SHA256 = "9ca100ba202a8048ad478149a31cb3f06e164ec4dc49c17fe043807bea608c69" tf_http_archive( name = name, @@ -19,7 +19,6 @@ def repo(name): patch_file = [ "//third_party/llvm:generated.patch", # Autogenerated, don't remove. 
"//third_party/llvm:build.patch", - "//third_party/llvm:externc.patch", "//third_party/llvm:mathextras.patch", "//third_party/llvm:toolchains.patch", "//third_party/llvm:zstd.patch", diff --git a/third_party/xla/xla/service/gpu/fusions/reduction_mlir.cc b/third_party/xla/xla/service/gpu/fusions/reduction_mlir.cc index 46e1bddf5e5dd8..c9163364e35d4b 100644 --- a/third_party/xla/xla/service/gpu/fusions/reduction_mlir.cc +++ b/third_party/xla/xla/service/gpu/fusions/reduction_mlir.cc @@ -600,7 +600,7 @@ llvm::SmallVector MlirRowReductionFusion::EmitReduction( auto thread_ids = mlir_converter::ApplyIndexing(thread_indexing, {thread_id}, {}, b); - Value lane_id = b.create(); + Value lane_id = b.create(/*upper_bound=*/nullptr); Value warp_id = b.create( thread_ids[ReductionDimensions::kRowMinorReducedDimension], b.create(WarpSize())); @@ -829,7 +829,7 @@ llvm::SmallVector MlirColumnReductionFusion::EmitReduction( Value cst_true = b.create(b.getOneAttr(b.getI1Type())); Value thread_id = state.thread_and_block_ids[0]; - Value lane_id = b.create(); + Value lane_id = b.create(/*upper_bound=*/nullptr); Value warp_id = b.create( thread_id, b.create(WarpSize())); From e13d7ea9645b9505b06be1a5c2e9abf45f2a1ce0 Mon Sep 17 00:00:00 2001 From: Parker Schuh Date: Tue, 18 Jun 2024 19:07:16 -0700 Subject: [PATCH 55/59] Allow using ifrt_client() as a shared_ptr from PyClient. PiperOrigin-RevId: 644582468 --- third_party/xla/xla/python/py_client.h | 3 +++ 1 file changed, 3 insertions(+) diff --git a/third_party/xla/xla/python/py_client.h b/third_party/xla/xla/python/py_client.h index 292447ab34899f..e8eb4aa3c35263 100644 --- a/third_party/xla/xla/python/py_client.h +++ b/third_party/xla/xla/python/py_client.h @@ -63,6 +63,9 @@ class PyClient { virtual ~PyClient(); ifrt::Client* ifrt_client() const { return ifrt_client_.get(); } + const std::shared_ptr& shared_ptr_ifrt_client() const { + return ifrt_client_; + } // Short-term escape hatch to get PjRtClient from PyClient. 
// TODO(hyeontaek): Migrate all users of this method to be agnostic of PjRt. From c625df7f3b56b1ab2b37a2aa252b4868b8742d44 Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Tue, 18 Jun 2024 23:45:24 -0700 Subject: [PATCH 56/59] [XLA:GPU[Rocm] Fix missing import in `ir_emitter_triton_rocm.cc`. PiperOrigin-RevId: 644637700 --- third_party/xla/xla/service/gpu/ir_emitter_triton_rocm.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/third_party/xla/xla/service/gpu/ir_emitter_triton_rocm.cc b/third_party/xla/xla/service/gpu/ir_emitter_triton_rocm.cc index 2a6e6c3c805cd8..b0d5dc5187c33a 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter_triton_rocm.cc +++ b/third_party/xla/xla/service/gpu/ir_emitter_triton_rocm.cc @@ -24,6 +24,7 @@ limitations under the License. #include "xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h" #include "xla/service/gpu/matmul_utils.h" #include "xla/service/gpu/model/tiled_hlo_computation.h" +#include "xla/service/gpu/triton_sparse_extensions.h" #include "xla/service/hlo_module_config.h" #include "tsl/platform/rocm_rocdl_path.h" #include "triton/Conversion/TritonGPUToLLVM/Passes.h" From e7e7273436119bd301a07102b5bdc1cc37c4abf6 Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Tue, 18 Jun 2024 23:46:59 -0700 Subject: [PATCH 57/59] [XLA:GPU] Fetch `EnumDescriptor` utils from `tsl::protobuf` in `triton_support_test.cc`. Otherwise, the OSS test doesn't build. 
PiperOrigin-RevId: 644638056 --- third_party/xla/xla/service/gpu/triton_support_test.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/xla/service/gpu/triton_support_test.cc b/third_party/xla/xla/service/gpu/triton_support_test.cc index fb818d6efab617..6c690bad4c47a9 100644 --- a/third_party/xla/xla/service/gpu/triton_support_test.cc +++ b/third_party/xla/xla/service/gpu/triton_support_test.cc @@ -51,8 +51,8 @@ auto AllXlaDataTypes() { std::vector xla_data_types; std::vector to_filter_out = {PRIMITIVE_TYPE_INVALID, TUPLE, OPAQUE_TYPE, TOKEN}; - const proto2::EnumDescriptor* xla_type_descriptor = - proto2::GetEnumDescriptor(); + const tsl::protobuf::EnumDescriptor* xla_type_descriptor = + tsl::protobuf::GetEnumDescriptor(); for (int enum_ix = 0; enum_ix < xla_type_descriptor->value_count(); ++enum_ix) { xla::PrimitiveType xla_type = static_cast( From 8c71440e230d307d13052091912c4aa39e689d25 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 19 Jun 2024 00:38:02 -0700 Subject: [PATCH 58/59] Reverts c0e79dad82e082a2530e82cae7db22a5164effc8 PiperOrigin-RevId: 644651233 --- tensorflow/BUILD | 406 ++++++++++++++++++-------------------- tensorflow/tensorflow.bzl | 14 -- 2 files changed, 192 insertions(+), 228 deletions(-) diff --git a/tensorflow/BUILD b/tensorflow/BUILD index 51c8f1362386a7..730f952827f8dc 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -10,7 +10,6 @@ load( "VERSION", "VERSION_MAJOR", "check_deps", - "config_setting_for_bazel7", "if_google", "if_oss", "if_xla_available", @@ -251,85 +250,97 @@ config_setting( ) # Config setting for determining if we are building for Android. 
-config_setting_for_bazel7( +config_setting( name = "android", - constraint_values = ["//third_party/bazel_platforms/os:android"], - legacy_values = {"crosstool_top": "//external:android/crosstool"}, + constraint_values = if_google( + ["//third_party/bazel_platforms/os:android"], + [], + ), + values = if_oss( + {"crosstool_top": "//external:android/crosstool"}, + {}, + ), visibility = ["//visibility:public"], ) -config_setting_for_bazel7( +config_setting( name = "android_x86", - constraint_values = [ - "//third_party/bazel_platforms/cpu:x86_32", - "//third_party/bazel_platforms/os:android", - ], - legacy_values = { - "crosstool_top": "//external:android/crosstool", - "cpu": "x86", - }, + constraint_values = if_google( + ["//third_party/bazel_platforms/os:android"], + [], + ), + values = dict( + if_oss( + {"crosstool_top": "//external:android/crosstool"}, + ), + cpu = "x86", + ), visibility = ["//visibility:public"], ) -config_setting_for_bazel7( +config_setting( name = "android_x86_64", - constraint_values = [ - "//third_party/bazel_platforms/cpu:x86_64", - "//third_party/bazel_platforms/os:android", - ], - legacy_values = { - "crosstool_top": "//external:android/crosstool", - "cpu": "x86_64", - }, + constraint_values = if_google( + ["//third_party/bazel_platforms/os:android"], + [], + ), + values = dict( + if_oss( + {"crosstool_top": "//external:android/crosstool"}, + ), + cpu = "x86_64", + ), visibility = ["//visibility:public"], ) -config_setting_for_bazel7( +config_setting( name = "android_armeabi", - constraint_values = [ - "//third_party/bazel_platforms/cpu:armv7", - "//third_party/bazel_platforms/os:android", - ], - legacy_values = { - "crosstool_top": "//external:android/crosstool", - "cpu": "armeabi", - }, + constraint_values = if_google( + ["//third_party/bazel_platforms/os:android"], + [], + ), + values = dict( + if_oss( + {"crosstool_top": "//external:android/crosstool"}, + ), + cpu = "armeabi", + ), visibility = ["//visibility:public"], ) # 
copybara:uncomment_begin(google-only) # config_setting( # name = "chromiumos_x86_64", -# constraint_values = [ -# "//third_party/bazel_platforms/cpu:x86_64", -# "//third_party/bazel_platforms/os:chromiumos", -# ], +# constraint_values = ["//third_party/bazel_platforms/os:chromiumos"], +# values = {"cpu": "k8"}, # visibility = ["//visibility:public"], # ) # # config_setting( # name = "chromiumos_arm64", -# constraint_values = [ -# "//third_party/bazel_platforms/cpu:arm64", -# "//third_party/bazel_platforms/os:chromiumos", -# ], +# constraint_values = ["//third_party/bazel_platforms/os:chromiumos"], +# values = {"cpu": "arm"}, # visibility = ["//visibility:public"], # ) # # config_setting( # name = "chromiumos_armv7", -# constraint_values = [ -# "//third_party/bazel_platforms/cpu:armv7", -# "//third_party/bazel_platforms/os:chromiumos", -# ], +# constraint_values = ["//third_party/bazel_platforms/os:chromiumos"], +# values = {"cpu": "armeabi-v7a"}, # visibility = ["//visibility:public"], # ) # copybara:uncomment_end -config_setting_for_bazel7( +config_setting( name = "emscripten", - constraint_values = ["//third_party/bazel_platforms/os:emscripten"], - legacy_values = {"crosstool_top": "//external:android/emscripten"}, + constraint_values = if_google( + ["//third_party/bazel_platforms/os:emscripten"], + [], + ), + values = if_oss( + {"crosstool_top": "//external:android/emscripten"}, + {}, + ), visibility = ["//visibility:public"], ) @@ -342,52 +353,48 @@ config_setting( visibility = ["//visibility:public"], ) -config_setting_for_bazel7( +config_setting( name = "android_arm", - constraint_values = [ - "//third_party/bazel_platforms/cpu:armv7", - "//third_party/bazel_platforms/os:android", - ], - legacy_values = { - "crosstool_top": "//external:android/crosstool", - "cpu": "armeabi-v7a", - }, + constraint_values = if_google( + ["//third_party/bazel_platforms/os:android"], + [], + ), + values = dict( + if_oss( + {"crosstool_top": "//external:android/crosstool"}, + ), + 
cpu = "armeabi-v7a", + ), visibility = ["//visibility:public"], ) -config_setting_for_bazel7( +config_setting( name = "android_arm64", - constraint_values = [ - "//third_party/bazel_platforms/cpu:arm64", - "//third_party/bazel_platforms/os:android", - ], - legacy_values = { - "crosstool_top": "//external:android/crosstool", - "cpu": "arm64-v8a", - }, + constraint_values = if_google( + ["//third_party/bazel_platforms/os:android"], + [], + ), + values = dict( + if_oss( + {"crosstool_top": "//external:android/crosstool"}, + ), + cpu = "arm64-v8a", + ), visibility = ["//visibility:public"], ) -config_setting_for_bazel7( +config_setting( name = "android_mips", - constraint_values = [ - "//third_party/bazel_platforms/cpu:mips64", - "//third_party/bazel_platforms/os:android", - ], - legacy_values = { + values = { "crosstool_top": "//external:android/crosstool", "cpu": "mips", }, visibility = ["//visibility:public"], ) -config_setting_for_bazel7( +config_setting( name = "android_mips64", - constraint_values = [ - "//third_party/bazel_platforms/cpu:mips64", - "//third_party/bazel_platforms/os:android", - ], - legacy_values = { + values = { "crosstool_top": "//external:android/crosstool", "cpu": "mips64", }, @@ -395,12 +402,18 @@ config_setting_for_bazel7( ) # TODO(jakeharmon8): Remove in favor of TSL version -config_setting_for_bazel7( +config_setting( name = "windows", # Internal builds query the target OS. - constraint_values = ["//third_party/bazel_platforms/os:windows"], + constraint_values = if_google( + ["//third_party/bazel_platforms/os:windows"], + [], + ), # OSS builds query the CPU type. - legacy_values = {"cpu": "x64_windows"}, + values = if_oss( + {"cpu": "x64_windows"}, + {}, + ), visibility = ["//visibility:public"], ) @@ -414,26 +427,25 @@ config_setting( # "darwin_x86_64". The former shows up when building on a Mac x86_64 host for a Mac x86_64 target. 
# The latter shows up when cross-compiling for Mac x86_64 from a Mac ARM machine and in internal # Google builds. -config_setting_for_bazel7( +config_setting( name = "macos_x86_64_default", - constraint_values = [ - "//third_party/bazel_platforms/os:macos", - "//third_party/bazel_platforms/cpu:x86_64", - ], - legacy_values = { + constraint_values = if_google( + ["//third_party/bazel_platforms/os:macos"], + [], + ), + values = { "apple_platform_type": "macos", "cpu": "darwin", }, ) -config_setting_for_bazel7( +config_setting( name = "macos_x86_64_crosscompile", - constraint_values = [ - "//third_party/bazel_platforms/os:macos", - "//third_party/bazel_platforms/cpu:x86_64", - # TODO: introduce cross_compilable_cpu constraint in Bazel - ], - legacy_values = { + constraint_values = if_google( + ["//third_party/bazel_platforms/os:macos"], + [], + ), + values = { "apple_platform_type": "macos", "cpu": "darwin_x86_64", }, @@ -448,13 +460,13 @@ selects.config_setting_group( visibility = ["//visibility:public"], ) -config_setting_for_bazel7( +config_setting( name = "macos_arm64", - constraint_values = [ - "//third_party/bazel_platforms/os:macos", - "//third_party/bazel_platforms/cpu:arm64", - ], - legacy_values = { + constraint_values = if_google( + ["//third_party/bazel_platforms/os:macos"], + [], + ), + values = { "apple_platform_type": "macos", "cpu": "darwin_arm64", }, @@ -471,151 +483,138 @@ selects.config_setting_group( visibility = ["//visibility:public"], ) -config_setting_for_bazel7( +config_setting( name = "ios", - constraint_values = ["//third_party/bazel_platforms/os:ios"], - legacy_values = {"apple_platform_type": "ios"}, + constraint_values = if_google( + ["//third_party/bazel_platforms/os:ios"], + [], + ), + values = if_oss( + {"apple_platform_type": "ios"}, + {}, + ), visibility = ["//visibility:public"], ) # TODO(jakeharmon8): Remove in favor of TSL version -config_setting_for_bazel7( +config_setting( name = "fuchsia", - constraint_values = 
["//third_party/bazel_platforms/os:fuchsia"], - # TODO(b/149248802) When we have a Fuchsia Bazel SDK update to use the values it sets. - legacy_values = {"cpu": "fuchsia"}, + constraint_values = if_google( + ["//third_party/bazel_platforms/os:fuchsia"], + [], + ), + values = if_oss( + # TODO(b/149248802) When we have a Fuchsia Bazel SDK update to use the values it sets. + {"cpu": "fuchsia"}, + {}, + ), visibility = ["//visibility:public"], ) -config_setting_for_bazel7( +config_setting( name = "fuchsia_x86_64", - constraint_values = [ - "//third_party/bazel_platforms/os:fuchsia", - "//third_party/bazel_platforms/cpu:x86_64", - ], - legacy_values = { + constraint_values = if_google( + ["//third_party/bazel_platforms/os:fuchsia"], + [], + ), + values = { "cpu": "x86_64", }, visibility = ["//visibility:public"], ) -config_setting_for_bazel7( +config_setting( name = "ios_x86_64", - constraint_values = [ - "//third_party/bazel_platforms/os:ios", - "//third_party/bazel_platforms/cpu:x86_64", - ], - legacy_values = { - "crosstool_top": "//tools/osx/crosstool:crosstool", - "cpu": "ios_x86_64", - }, + constraint_values = if_google( + ["//third_party/bazel_platforms/os:ios"], + [], + ), + values = dict( + if_oss( + {"crosstool_top": "//tools/osx/crosstool:crosstool"}, + ), + cpu = "ios_x86_64", + ), visibility = ["//visibility:public"], ) -config_setting_for_bazel7( +config_setting( name = "chromiumos", - constraint_values = ["//third_party/bazel_platforms/os:chromiumos"], - legacy_values = {"crosstool_top": "//external:android/chromiumos"}, + constraint_values = if_google( + ["//third_party/bazel_platforms/os:chromiumos"], + [], + ), + values = if_oss( + {"crosstool_top": "//external:android/chromiumos"}, + {}, + ), visibility = ["//visibility:public"], ) -config_setting_for_bazel7( +config_setting( name = "linux_aarch64", - constraint_values = ["//third_party/bazel_platforms/cpu:aarch64"], - legacy_values = {"cpu": "aarch64"}, + values = {"cpu": "aarch64"}, visibility = 
["//visibility:public"], ) -config_setting_for_bazel7( +config_setting( name = "linux_armhf", - constraint_values = ["//third_party/bazel_platforms/cpu:armv7e-mf"], - legacy_values = {"cpu": "armhf"}, + values = {"cpu": "armhf"}, visibility = ["//visibility:public"], ) -config_setting_for_bazel7( +config_setting( name = "linux_x86_64", - constraint_values = [ - "//third_party/bazel_platforms/cpu:x86_64", - "//third_party/bazel_platforms/os:linux", - ], - legacy_values = {"cpu": "k8"}, + values = {"cpu": "k8"}, visibility = ["//visibility:public"], ) -config_setting_for_bazel7( +config_setting( name = "haswell", - constraint_values = [ - # TODO: introduce haswell constraint in Bazel - "//third_party/bazel_platforms/cpu:x86_64", - "//third_party/bazel_platforms/os:linux", - ], - legacy_values = {"cpu": "haswell"}, + values = {"cpu": "haswell"}, visibility = ["//visibility:public"], ) # This condition takes precedence over :linux_x86_64 -config_setting_for_bazel7( +config_setting( name = "linux_x86_64_no_sse", - constraint_values = [ - "//third_party/bazel_platforms/cpu:x86_64", - "//third_party/bazel_platforms/os:linux", - ], - legacy_values = {"cpu": "k8"}, - values = {"copt": "-mno-sse4.2"}, + values = { + "cpu": "k8", + "copt": "-mno-sse4.2", + }, visibility = ["//visibility:public"], ) # This condition takes precedence over :linux_x86_64 # TODO(b/290533709): Remove this with PJRT build rule cleanup. 
-config_setting_for_bazel7( +config_setting( name = "linux_x86_64_with_weightwatcher", - constraint_values = [ - "//third_party/bazel_platforms/cpu:x86_64", - "//third_party/bazel_platforms/os:linux", - ], define_values = {"tensorflow_weightwatcher": "true"}, - legacy_values = {"cpu": "k8"}, + values = {"cpu": "k8"}, visibility = ["//visibility:public"], ) -config_setting_for_bazel7( +config_setting( name = "linux_ppc64le", - constraint_values = [ - "//third_party/bazel_platforms/cpu:ppc64le", - "//third_party/bazel_platforms/os:linux", - ], - legacy_values = {"cpu": "ppc"}, + values = {"cpu": "ppc"}, visibility = ["//visibility:public"], ) -config_setting_for_bazel7( +config_setting( name = "linux_s390x", - constraint_values = [ - "//third_party/bazel_platforms/cpu:s390x", - "//third_party/bazel_platforms/os:linux", - ], - legacy_values = {"cpu": "s390x"}, + values = {"cpu": "s390x"}, visibility = ["//visibility:public"], ) -config_setting_for_bazel7( +config_setting( name = "linux_mips64", - constraint_values = [ - "//third_party/bazel_platforms/cpu:mips64", - "//third_party/bazel_platforms/os:linux", - ], - legacy_values = {"cpu": "mips64"}, + values = {"cpu": "mips64"}, visibility = ["//visibility:public"], ) -config_setting_for_bazel7( +config_setting( name = "linux_riscv64", - constraint_values = [ - "//third_party/bazel_platforms/cpu:riscv64", - "//third_party/bazel_platforms/os:linux", - ], - legacy_values = {"cpu": "riscv64"}, + values = {"cpu": "riscv64"}, visibility = ["//visibility:public"], ) @@ -635,39 +634,27 @@ config_setting( visibility = ["//visibility:public"], ) -config_setting_for_bazel7( +config_setting( name = "arm", - constraint_values = [ - "//third_party/bazel_platforms/cpu:arm", - ], - legacy_values = {"cpu": "arm"}, + values = {"cpu": "arm"}, visibility = ["//visibility:public"], ) -config_setting_for_bazel7( +config_setting( name = "armeabi", - constraint_values = [ - "//third_party/bazel_platforms/cpu:armv7", - ], - legacy_values = 
{"cpu": "armeabi"}, + values = {"cpu": "armeabi"}, visibility = ["//visibility:public"], ) -config_setting_for_bazel7( +config_setting( name = "armeabi-v7a", - constraint_values = [ - "//third_party/bazel_platforms/cpu:armv7", - ], - legacy_values = {"cpu": "armeabi-v7a"}, + values = {"cpu": "armeabi-v7a"}, visibility = ["//visibility:public"], ) -config_setting_for_bazel7( +config_setting( name = "arm64-v8a", - constraint_values = [ - "//third_party/bazel_platforms/cpu:arm64", - ], - legacy_values = {"cpu": "arm64-v8a"}, + values = {"cpu": "arm64-v8a"}, visibility = ["//visibility:public"], ) @@ -683,10 +670,9 @@ selects.config_setting_group( ], ) -config_setting_for_bazel7( +config_setting( name = "freebsd", - constraint_values = ["//third_party/bazel_platforms/os:freebsd"], - legacy_values = {"cpu": "freebsd"}, + values = {"cpu": "freebsd"}, visibility = ["//visibility:public"], ) @@ -737,32 +723,24 @@ config_setting( visibility = ["//visibility:public"], ) -config_setting_for_bazel7( +config_setting( name = "macos_x86_64_with_framework_shared_object", - constraint_values = [ - "//third_party/bazel_platforms/os:macos", - "//third_party/bazel_platforms/cpu:x86_64", - ], define_values = { "framework_shared_object": "true", }, - legacy_values = { + values = { "apple_platform_type": "macos", "cpu": "darwin", }, visibility = ["//visibility:public"], ) -config_setting_for_bazel7( +config_setting( name = "macos_arm64_with_framework_shared_object", - constraint_values = [ - "//third_party/bazel_platforms/os:macos", - "//third_party/bazel_platforms/cpu:arm64", - ], define_values = { "framework_shared_object": "true", }, - legacy_values = { + values = { "apple_platform_type": "macos", "cpu": "darwin_arm64", }, diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl index f27f4cee33360e..297c3e094bb903 100644 --- a/tensorflow/tensorflow.bzl +++ b/tensorflow/tensorflow.bzl @@ -129,20 +129,6 @@ def if_google(google_value, oss_value = []): """ return oss_value # 
copybara:comment_replace return google_value -# Used for toolchain resolution, which work either at Google or on Bazel 7 -def config_setting_for_bazel7(*, constraint_values = [], legacy_values = {}, values = {}, **kwargs): - """A config_setting that uses either constraint values or legacy config flag values. - - At the moment it uses constraint values at Google. When Tensorflow upgrades to Bazel 7 - (or flips --incompatible_enable_cc_toolchain_resolution on Bazel <=6) the constraint value - configuration will be ready. - """ - native.config_setting( - constraint_values = if_google(constraint_values, {}), - values = if_google(values, values | legacy_values), - **kwargs - ) - def if_v2(a): return select({ clean_dep("//tensorflow:api_version_2"): a, From d724a03045a30591c9460790f5fedd9b930ee0ff Mon Sep 17 00:00:00 2001 From: Kyle Lucke Date: Tue, 18 Jun 2024 02:07:17 -0700 Subject: [PATCH 59/59] Stop using xla/statusor.h now that it just contains an alias for absl::Status. In some situations, this meant also changing unrelated files to directly include tsl/platform/statusor.h to get the definitions for TF_ASSIGN_OR_RETURN, etc., where they were getting transitively included for free. 
PiperOrigin-RevId: 644301234 --- tensorflow/core/kernels/scatter_nd_op_test.cc | 2 +- third_party/xla/xla/service/gpu/BUILD | 33 ++++------------- .../service/gpu/cudnn_fused_mha_rewriter.cc | 2 +- .../gpu/cudnn_vectorize_convolutions.cc | 2 +- .../gpu/cudnn_vectorize_convolutions_test.cc | 1 - .../xla/service/gpu/fusions/loop_mlir_test.cc | 36 +++++++++++++++++++ .../xla/service/gpu/gemm_algorithm_picker.cc | 2 +- .../xla/xla/service/gpu/gemm_rewriter.cc | 2 +- .../xla/xla/service/gpu/gpu_hlo_schedule.cc | 1 + .../xla/service/gpu/gpu_p2p_pipeliner_test.cc | 2 +- .../xla/service/gpu/hlo_fusion_analysis.cc | 1 + .../xla/service/gpu/split_k_gemm_rewriter.cc | 2 +- .../service/gpu/stream_attribute_annotator.cc | 2 +- .../xla/xla/service/gpu/topk_splitter.cc | 2 +- .../gpu/triton_fusion_analysis_test.cc | 1 - 15 files changed, 54 insertions(+), 37 deletions(-) diff --git a/tensorflow/core/kernels/scatter_nd_op_test.cc b/tensorflow/core/kernels/scatter_nd_op_test.cc index eb5f42d012ad8c..fd679c35347798 100644 --- a/tensorflow/core/kernels/scatter_nd_op_test.cc +++ b/tensorflow/core/kernels/scatter_nd_op_test.cc @@ -363,7 +363,7 @@ TEST_F(ScatterNdOpConstructionTest, Error_BadIndicesPolicyInvalid) { .Input(FakeInput(DT_INT32)) .Attr("bad_indices_policy", "AN_UNRECOGNIZED_POLICY") .Finalize(node_def())); - EXPECT_NE(InitOp(), OkStatus()); + EXPECT_NE(InitOp(), absl::OkStatus()); } class ScatterNdUpdateBM : public ScatterNdUpdateOpTest { diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index dcb4a1d9b9bc10..afb2944a8e0b59 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -327,7 +327,6 @@ cc_library( "//xla:literal", "//xla:shape_util", "//xla:status_macros", - "//xla:statusor", "//xla:util", "//xla:xla_data_proto_cc", "//xla/ffi:attribute_map", @@ -393,6 +392,7 @@ cc_library( "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", + 
"@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", @@ -507,7 +507,6 @@ cc_library( "//xla:literal", "//xla:shape_util", "//xla:status_macros", - "//xla:statusor", "//xla:util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", @@ -809,7 +808,6 @@ cc_library( "//xla:shape_util", "//xla:status_macros", "//xla/tools:hlo_decomposer_lib", - "//xla:statusor", "//xla:util", "//xla:xla_data_proto_cc", "//xla:xla_proto_cc", @@ -1209,7 +1207,6 @@ cc_library( "//xla:shape_util", "//xla:status", "//xla:status_macros", - "//xla:statusor", "//xla:types", "//xla:util", "//xla:xla_data_proto_cc", @@ -1382,7 +1379,6 @@ xla_cc_test( deps = [ ":gemm_fusion", ":triton_fusion_analysis", - "//xla:statusor", "//xla/hlo/ir:hlo", "//xla/stream_executor:device_description", "//xla/tests:hlo_test_base", @@ -1501,7 +1497,6 @@ cc_library( "//xla:autotuning_proto_cc", "//xla:literal_util", "//xla:shape_util", - "//xla:statusor", "//xla:util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", @@ -1511,6 +1506,7 @@ cc_library( "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:cord", "@com_google_absl//absl/types:span", @@ -1618,7 +1614,6 @@ cc_library( "//xla:util", "//xla:autotuning_proto_cc", "//xla:shape_util", - "//xla:statusor", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:statusor", @@ -1650,7 +1645,6 @@ cc_library( "//xla:autotuning_proto_cc", "//xla:shape_util", "//xla:status_macros", - "//xla:statusor", "//xla:types", "//xla:util", "//xla:xla_proto_cc", @@ -1696,7 +1690,6 @@ cc_library( "//xla/stream_executor/gpu:redzone_allocator", "//xla:executable_run_options", "//xla:shape_util", - "//xla:statusor", "//xla:util", "//xla:xla_proto_cc", 
"@local_tsl//tsl/platform:errors", @@ -1769,7 +1762,6 @@ cc_library( "//xla:autotuning_proto_cc", "//xla:shape_util", "//xla:status_macros", - "//xla:statusor", "//xla:types", "//xla:util", "//xla:xla_data_proto_cc", @@ -1999,7 +1991,6 @@ cc_library( "//xla:debug_options_flags", "//xla:literal_util", "//xla:shape_util", - "//xla:statusor", "//xla:util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", @@ -2096,7 +2087,6 @@ cc_library( ":stream_executor_util", "//xla:shape_util", "//xla:status_macros", - "//xla:statusor", "//xla:types", "//xla:util", "//xla:xla_data_proto_cc", @@ -2105,6 +2095,7 @@ cc_library( "//xla/stream_executor:dnn", "//xla/stream_executor:lazy_op_runner", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@local_tsl//tsl/platform:statusor", ] + if_cuda_is_configured([ @@ -2279,7 +2270,6 @@ cc_library( ]), deps = [ "//xla:comparison_util", - "//xla:statusor", "//xla:types", "//xla:util", "//xla:xla_data_proto_cc", @@ -2317,7 +2307,6 @@ cc_library( "//xla:literal", "//xla:literal_util", "//xla:shape_util", - "//xla:statusor", "//xla:util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", @@ -2830,7 +2819,6 @@ cc_library( ":cudnn_support_utils", ":stream_executor_util", "//xla:shape_util", - "//xla:statusor", "//xla:util", "//xla/client:xla_builder", "//xla/client:xla_computation", @@ -2859,7 +2847,6 @@ xla_cc_test( ":backend_configs_cc", ":cublas_cudnn", ":cudnn_vectorize_convolutions", - "//xla:statusor", "//xla:util", "//xla/service:call_inliner", "//xla/service:hlo_parser", @@ -3503,7 +3490,7 @@ cc_library( "@llvm-project//mlir:Support", "//xla:autotune_results_proto_cc", "//xla:status_macros", - "//xla:statusor", + "@com_google_absl//absl/status:statusor", "//xla:types", "//xla:util", "//xla:xla_data_proto_cc", @@ -3820,7 +3807,6 @@ cc_library( "@llvm-project//llvm:Support", "//xla:autotune_results_proto_cc", "//xla:status_macros", - "//xla:statusor", "//xla:types", "//xla:util", 
"//xla:xla_proto_cc", @@ -4336,7 +4322,6 @@ xla_cc_test( ], deps = [ ":gpu_p2p_pipeliner", - "//xla:statusor", "//xla:util", "//xla/hlo/ir:hlo", "//xla/service:hlo_module_config", @@ -4346,6 +4331,7 @@ xla_cc_test( "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status:statusor", "@com_google_googletest//:gtest", ], ) @@ -4512,13 +4498,13 @@ cc_library( ":ir_emission_utils", ":reduction_utils", "//xla:shape_util", - "//xla:statusor", "//xla/hlo/ir:hlo", "//xla/stream_executor:device_description", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", "@llvm-project//llvm:Support", @@ -4559,7 +4545,6 @@ cc_library( "@eigen_archive//:eigen3", "//xla:shape_util", "//xla:status_macros", - "//xla:statusor", "//xla:util", "//xla/service:hlo_module_config", "//xla/stream_executor", @@ -4779,7 +4764,6 @@ cc_library( ":backend_configs_cc", ":cublas_cudnn", "//xla:shape_util", - "//xla:statusor", "//xla:types", "//xla:util", "//xla:window_util", @@ -4840,7 +4824,6 @@ cc_library( "//xla:permutation_util", "//xla:shape_util", "//xla:status_macros", - "//xla:statusor", "//xla:types", "//xla:util", "//xla:xla_data_proto_cc", @@ -5402,7 +5385,6 @@ cc_library( srcs = if_gpu_is_configured(["make_batch_pointers.cc"]), hdrs = if_gpu_is_configured(["make_batch_pointers.h"]), deps = [ - "//xla:statusor", "//xla:types", "//xla:util", "//xla/stream_executor", @@ -5410,6 +5392,7 @@ cc_library( "//xla/stream_executor:typed_kernel_factory", "//xla/stream_executor/gpu:gpu_stream_header", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:statusor", ] + if_cuda_is_configured([ @@ -5596,7 +5579,6 
@@ cc_library( hdrs = ["topk_splitter.h"], deps = [ "//xla:shape_util", - "//xla:statusor", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/service:hlo_creation_utils", @@ -6041,7 +6023,6 @@ cc_library( deps = [ ":backend_configs_cc", ":gpu_fusible", - "//xla:statusor", "//xla:util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", diff --git a/third_party/xla/xla/service/gpu/cudnn_fused_mha_rewriter.cc b/third_party/xla/xla/service/gpu/cudnn_fused_mha_rewriter.cc index f03fe4f0fac1ab..23b238dad651a3 100644 --- a/third_party/xla/xla/service/gpu/cudnn_fused_mha_rewriter.cc +++ b/third_party/xla/xla/service/gpu/cudnn_fused_mha_rewriter.cc @@ -29,6 +29,7 @@ limitations under the License. #include "absl/log/check.h" #include "absl/log/log.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" #include "absl/strings/string_view.h" @@ -46,7 +47,6 @@ limitations under the License. #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/status_macros.h" -#include "xla/statusor.h" #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/dnn.h" #include "xla/types.h" diff --git a/third_party/xla/xla/service/gpu/cudnn_vectorize_convolutions.cc b/third_party/xla/xla/service/gpu/cudnn_vectorize_convolutions.cc index 4c29e69aa70028..3846f01136a81f 100644 --- a/third_party/xla/xla/service/gpu/cudnn_vectorize_convolutions.cc +++ b/third_party/xla/xla/service/gpu/cudnn_vectorize_convolutions.cc @@ -25,6 +25,7 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/container/inlined_vector.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "xla/client/xla_builder.h" @@ -42,7 +43,6 @@ limitations under the License. 
#include "xla/service/hlo_module_config.h" #include "xla/shape.h" #include "xla/shape_util.h" -#include "xla/statusor.h" #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/dnn.h" #include "xla/util.h" diff --git a/third_party/xla/xla/service/gpu/cudnn_vectorize_convolutions_test.cc b/third_party/xla/xla/service/gpu/cudnn_vectorize_convolutions_test.cc index d448621fbc6bc4..aa15fc73093fc6 100644 --- a/third_party/xla/xla/service/gpu/cudnn_vectorize_convolutions_test.cc +++ b/third_party/xla/xla/service/gpu/cudnn_vectorize_convolutions_test.cc @@ -29,7 +29,6 @@ limitations under the License. #include "xla/service/hlo_parser.h" #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" -#include "xla/statusor.h" #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/dnn.h" #include "xla/tests/hlo_test_base.h" diff --git a/third_party/xla/xla/service/gpu/fusions/loop_mlir_test.cc b/third_party/xla/xla/service/gpu/fusions/loop_mlir_test.cc index 008b01630b363c..6af3fb0a9e173d 100644 --- a/third_party/xla/xla/service/gpu/fusions/loop_mlir_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/loop_mlir_test.cc @@ -181,6 +181,42 @@ TEST_F(MlirLoopFusionTest, ThreadId_Broadcast) { )")); } +TEST_F(MlirLoopFusionTest, Constant_Broadcast) { + auto kHloString = R"( + HloModule module + + bcast { + zero = bf16[] constant(0) + ROOT broadcast = bf16[2,16,48]{2,1,0} broadcast(zero), dimensions={} + } + + ENTRY entry { + ROOT %fusion = bf16[2,16,48]{2,1,0} fusion(), kind=kLoop, calls=bcast + } + )"; + TF_ASSERT_OK(EmitAndCheckIR(kHloString, R"( + // CHECK: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d1 * 1024 + d0)> + // CHECK: #[[MAP1:.*]] = affine_map<(d0, d1) -> (((d1 * 4 + d0 floordiv 256) floordiv 3) mod 2)> + // CHECK: #[[MAP2:.*]] = affine_map<(d0, d1) -> (((d1 * 64 + d0 floordiv 16) floordiv 3) mod 16)> + // CHECK: #[[MAP3:.*]] = affine_map<(d0, d1) -> ((d1 * 1024 + d0) mod 48)> + // CHECK: 
func.func @fused_computation(%[[ARG0:.*]]: tensor<2x16x48xbf16> + // CHECK: %[[UPPER_BOUND:.*]] = arith.constant 1535 : index + // CHECK: %[[THREAD_ID:.*]] = gpu.thread_id + // CHECK: %[[BLOCK_ID:.*]] = gpu.block_id + // CHECK: %[[LINEAR:.*]] = xla_gpu.apply_indexing #[[MAP0]] + // CHECL: %[[IN_BOUNDS:.*]] = arith.cmpi sle, %[[LINEAR]], %[[UPPER_BOUND]] : index + // scf.if %[[IN_BOUNDS]] + // CHECK: %[[I0:.*]] = xla_gpu.apply_indexing #[[MAP1]] + // CHECK: %[[I1:.*]] = xla_gpu.apply_indexing #[[MAP2]] + // CHECK: %[[I2:.*]] = xla_gpu.apply_indexing #[[MAP3]] + // CHECK: %[[BCAST:.*]] = xla_gpu.pure_call @bcast_broadcast + // CHECK: %[[INSERTED:.*]] = tensor.insert %[[BCAST]] into %[[ARG0]][%[[I0]], %[[I1]], %[[I2]]] + // CHECK: func.func private @bcast_broadcast + // CHECK: arith.constant 0.000000e+00 + )")); + EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{0})); +} + TEST_F(MlirLoopFusionTest, NoCodeDuplication) { // This test HLO is copied from // xla/service/fusion_node_indexing_evaluation_test.cc. diff --git a/third_party/xla/xla/service/gpu/gemm_algorithm_picker.cc b/third_party/xla/xla/service/gpu/gemm_algorithm_picker.cc index 2207d9eac9999d..dd10917d812a32 100644 --- a/third_party/xla/xla/service/gpu/gemm_algorithm_picker.cc +++ b/third_party/xla/xla/service/gpu/gemm_algorithm_picker.cc @@ -26,6 +26,7 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" @@ -43,7 +44,6 @@ limitations under the License. 
#include "xla/service/gpu/variant_visitor.h" #include "xla/shape.h" #include "xla/shape_util.h" -#include "xla/statusor.h" #include "xla/stream_executor/blas.h" #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/device_memory.h" diff --git a/third_party/xla/xla/service/gpu/gemm_rewriter.cc b/third_party/xla/xla/service/gpu/gemm_rewriter.cc index e109c6b3fa6d5d..660ae0412c8e2d 100644 --- a/third_party/xla/xla/service/gpu/gemm_rewriter.cc +++ b/third_party/xla/xla/service/gpu/gemm_rewriter.cc @@ -35,6 +35,7 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/log/log.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/hlo/evaluator/hlo_evaluator.h" @@ -57,7 +58,6 @@ limitations under the License. #include "xla/shape_util.h" #include "xla/status.h" #include "xla/status_macros.h" -#include "xla/statusor.h" #include "xla/stream_executor/blas.h" #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/gpu/gpu_blas_lt.h" diff --git a/third_party/xla/xla/service/gpu/gpu_hlo_schedule.cc b/third_party/xla/xla/service/gpu/gpu_hlo_schedule.cc index 07449034b1705f..0463875b4f98a4 100644 --- a/third_party/xla/xla/service/gpu/gpu_hlo_schedule.cc +++ b/third_party/xla/xla/service/gpu/gpu_hlo_schedule.cc @@ -30,6 +30,7 @@ limitations under the License. 
#include "absl/log/check.h" #include "absl/log/log.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/match.h" #include "absl/strings/numbers.h" #include "absl/strings/str_format.h" diff --git a/third_party/xla/xla/service/gpu/gpu_p2p_pipeliner_test.cc b/third_party/xla/xla/service/gpu/gpu_p2p_pipeliner_test.cc index e0ee476ea914b1..1dcab0c47b98c9 100644 --- a/third_party/xla/xla/service/gpu/gpu_p2p_pipeliner_test.cc +++ b/third_party/xla/xla/service/gpu/gpu_p2p_pipeliner_test.cc @@ -22,6 +22,7 @@ limitations under the License. #include #include #include "absl/log/check.h" +#include "absl/status/statusor.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -32,7 +33,6 @@ limitations under the License. #include "xla/service/hlo_parser.h" #include "xla/service/hlo_pass_pipeline.h" #include "xla/service/hlo_verifier.h" -#include "xla/statusor.h" #include "xla/tests/hlo_test_base.h" #include "xla/util.h" diff --git a/third_party/xla/xla/service/gpu/hlo_fusion_analysis.cc b/third_party/xla/xla/service/gpu/hlo_fusion_analysis.cc index a027e85ab5cd3a..345fd8c709ee49 100644 --- a/third_party/xla/xla/service/gpu/hlo_fusion_analysis.cc +++ b/third_party/xla/xla/service/gpu/hlo_fusion_analysis.cc @@ -26,6 +26,7 @@ limitations under the License. #include "absl/container/inlined_vector.h" #include "absl/log/check.h" #include "absl/log/log.h" +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "llvm/ADT/STLExtras.h" diff --git a/third_party/xla/xla/service/gpu/split_k_gemm_rewriter.cc b/third_party/xla/xla/service/gpu/split_k_gemm_rewriter.cc index 4ca4c62a886637..6b0d9ea2fb4854 100644 --- a/third_party/xla/xla/service/gpu/split_k_gemm_rewriter.cc +++ b/third_party/xla/xla/service/gpu/split_k_gemm_rewriter.cc @@ -26,6 +26,7 @@ limitations under the License. 
#include "absl/container/flat_hash_set.h" #include "absl/log/check.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/cord.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" @@ -47,7 +48,6 @@ limitations under the License. #include "xla/service/hlo_creation_utils.h" #include "xla/shape.h" #include "xla/shape_util.h" -#include "xla/statusor.h" #include "xla/util.h" #include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" diff --git a/third_party/xla/xla/service/gpu/stream_attribute_annotator.cc b/third_party/xla/xla/service/gpu/stream_attribute_annotator.cc index 0b8d00984df7a3..7e54ea5aded6a0 100644 --- a/third_party/xla/xla/service/gpu/stream_attribute_annotator.cc +++ b/third_party/xla/xla/service/gpu/stream_attribute_annotator.cc @@ -20,6 +20,7 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -29,7 +30,6 @@ limitations under the License. #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/gpu_fusible.h" #include "xla/service/gpu/runtime/thunk.h" -#include "xla/statusor.h" #include "xla/util.h" #include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" diff --git a/third_party/xla/xla/service/gpu/topk_splitter.cc b/third_party/xla/xla/service/gpu/topk_splitter.cc index 3b8ef207f124d9..d20116dd22dd7c 100644 --- a/third_party/xla/xla/service/gpu/topk_splitter.cc +++ b/third_party/xla/xla/service/gpu/topk_splitter.cc @@ -25,6 +25,7 @@ limitations under the License. #include "absl/log/log.h" #include "absl/numeric/bits.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" @@ -37,7 +38,6 @@ limitations under the License. 
#include "xla/service/hlo_creation_utils.h" #include "xla/shape.h" #include "xla/shape_util.h" -#include "xla/statusor.h" #include "xla/xla_data.pb.h" #include "tsl/platform/statusor.h" diff --git a/third_party/xla/xla/service/gpu/triton_fusion_analysis_test.cc b/third_party/xla/xla/service/gpu/triton_fusion_analysis_test.cc index ad61b6771e2c83..051ae78fa9644b 100644 --- a/third_party/xla/xla/service/gpu/triton_fusion_analysis_test.cc +++ b/third_party/xla/xla/service/gpu/triton_fusion_analysis_test.cc @@ -24,7 +24,6 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/gpu/gemm_fusion.h" -#include "xla/statusor.h" #include "xla/stream_executor/device_description.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/verified_hlo_module.h"