Re-land: [XLA:GPU] Store fusion_roots and fusion_heroes as HloInstructionAdaptor in HloFusionAnalysis.

Reverts 23ac61b

PiperOrigin-RevId: 635737157
olegshyshkov authored and tensorflower-gardener committed May 21, 2024
1 parent f30bdc8 commit 104615d
Showing 20 changed files with 155 additions and 128 deletions.
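In short: HloFusionAnalysis::fusion_roots() and fusion_heroes() now hand out HloInstructionAdaptor values instead of const HloInstruction* pointers, so call sites switch from root->shape() to root.shape() and reach the wrapped instruction through root.instruction() where a raw pointer is still required (e.g. as a hash-map key). Below is a minimal, self-contained sketch of that access pattern, using hypothetical stand-in types rather than the real classes from xla/hlo/ir and xla/service/gpu/hlo_traversal.h:

#include <iostream>
#include <string>
#include <vector>

// Hypothetical stand-ins for illustration only; the real HloInstruction and
// HloInstructionAdaptor live in the XLA source tree.
struct HloInstruction {
  std::string name;
  std::string shape;  // The real accessor returns const Shape&.
};

class HloInstructionAdaptor {
 public:
  explicit HloInstructionAdaptor(const HloInstruction& instruction)
      : instruction_(&instruction) {}

  // Common accessors are forwarded, so callers use value syntax.
  const std::string& shape() const { return instruction_->shape; }

  // Escape hatch for call sites that still need the wrapped instruction.
  const HloInstruction& instruction() const { return *instruction_; }

 private:
  const HloInstruction* instruction_;  // Non-owning.
};

int main() {
  HloInstruction concat{"concatenate", "f32[128]"};
  std::vector<HloInstructionAdaptor> fusion_roots{
      HloInstructionAdaptor(concat)};

  for (const HloInstructionAdaptor& root : fusion_roots) {
    std::cout << root.shape() << "\n";                // was: root->shape()
    const HloInstruction* raw = &root.instruction();  // was: just root
    std::cout << raw->name << "\n";
  }
}

The hunks below apply that rewrite mechanically across the GPU fusion emitters; where a single root is needed, the indexed accessor analysis.fusion_root(i) replaces the old analysis.fusion_roots()[i].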
1 change: 1 addition & 0 deletions third_party/xla/xla/service/gpu/BUILD
@@ -4436,6 +4436,7 @@ cc_library(
         "@com_google_absl//absl/log",
         "@com_google_absl//absl/log:check",
         "@com_google_absl//absl/strings:string_view",
+        "@com_google_absl//absl/types:span",
         "@llvm-project//llvm:Support",
     ],
 )
6 changes: 6 additions & 0 deletions third_party/xla/xla/service/gpu/fusions/BUILD
@@ -732,6 +732,7 @@ cc_library(
         "//xla/hlo/ir:hlo",
         "//xla/service:buffer_assignment",
         "//xla/service/gpu:hlo_fusion_analysis",
+        "//xla/service/gpu:hlo_traversal",
         "//xla/service/gpu:ir_emission_utils",
         "//xla/service/gpu:ir_emitter",
         "//xla/service/gpu:ir_emitter_context",
@@ -812,6 +813,7 @@ xla_cc_test(
         "//xla/service/gpu:gpu_device_info_for_tests",
         "//xla/service/gpu:hlo_fusion_analysis",
         "//xla/service/gpu:ir_emitter_context",
+        "//xla/service/gpu/model:indexing_analysis",
         "//xla/service/gpu/model:indexing_test_utils",
         "//xla/stream_executor:device_description",
         "//xla/tests:hlo_test_base",
@@ -833,6 +835,7 @@ cc_library(
         "//xla:shape_util",
         "//xla/hlo/ir:hlo",
         "//xla/service/gpu:hlo_fusion_analysis",
+        "//xla/service/gpu:hlo_traversal",
         "//xla/service/gpu:ir_emission_utils",
         "//xla/service/gpu:reduction_utils",
         "//xla/service/gpu/fusions/mlir:computation_partitioner",
@@ -845,6 +848,7 @@ cc_library(
         "@com_google_absl//absl/algorithm:container",
         "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/log:check",
         "@com_google_absl//absl/status",
         "@com_google_absl//absl/types:span",
         "@llvm-project//llvm:Support",
@@ -969,6 +973,7 @@ cc_library(
         ":fusion_emitter",
         ":tiling_util",
         "//xla:permutation_util",
+        "//xla:shape_util",
         "//xla:status",
         "//xla:util",
         "//xla/hlo/ir:hlo",
@@ -1058,6 +1063,7 @@ cc_library(
         "//xla:xla_data_proto_cc",
         "//xla/hlo/ir:hlo",
         "//xla/service/gpu:hlo_fusion_analysis",
+        "//xla/service/gpu:hlo_traversal",
         "//xla/service/gpu:launch_dimensions",
         "//xla/service/gpu/fusions/mlir:computation_partitioner",
         "//xla/service/gpu/fusions/mlir:elemental_hlo_to_mlir",
5 changes: 3 additions & 2 deletions third_party/xla/xla/service/gpu/fusions/concatenate.cc
@@ -111,8 +111,9 @@ absl::Status ConcatenateFusion::EmitKernel(
   for (const auto& [output, root] :
        llvm::zip_equal(outputs, analysis_.fusion_roots())) {
     llvm_ir::IrArray::Index root_index = result_index.SourceIndexOfBitcast(
-        concat.shape(), root->shape(), builder);
-    TF_ASSIGN_OR_RETURN(auto generator, fused_emitter.GetGenerator(*root));
+        concat.shape(), root.shape(), builder);
+    TF_ASSIGN_OR_RETURN(auto generator,
+                        fused_emitter.GetGenerator(root.instruction()));
     TF_ASSIGN_OR_RETURN(llvm::Value * value, generator(root_index));
     output.EmitWriteArrayElement(root_index, value, builder);
   }
7 changes: 4 additions & 3 deletions third_party/xla/xla/service/gpu/fusions/fusions.cc
@@ -77,7 +77,7 @@ bool IsParameterOrGteOfParameter(const HloInstruction* instr) {
 
 bool IsDynamicUpdateSliceFusion(const HloFusionAnalysis& analysis) {
   return absl::c_all_of(
-      analysis.fusion_root_adaptors(), [](const HloInstructionAdaptor& root) {
+      analysis.fusion_roots(), [](const HloInstructionAdaptor& root) {
         return root.opcode() == HloOpcode::kDynamicUpdateSlice ||
                (root.opcode() == HloOpcode::kBitcast &&
                 root.GetOperand(0).opcode() == HloOpcode::kDynamicUpdateSlice);
@@ -89,7 +89,8 @@ bool IsDynamicUpdateSliceFusion(const HloFusionAnalysis& analysis) {
 std::optional<absl::StatusOr<std::unique_ptr<FusionInterface>>>
 HloFusionInfo::GetCopyFusion() const {
   std::vector<BufferAllocation::Slice> src_buffers;
-  for (auto* root : analysis().fusion_roots()) {
+  for (const HloInstructionAdaptor& root_adaptor : analysis().fusion_roots()) {
+    const HloInstruction* root = &root_adaptor.instruction();
     if (root->opcode() != HloOpcode::kCopy ||
         root->operand(0)->opcode() != HloOpcode::kParameter ||
         !LayoutUtil::Equal(root->operand(0)->shape().layout(),
@@ -127,7 +128,7 @@ HloFusionInfo::GetCopyFusion() const {
 
 bool HloFusionInfo::CanEmitDynamicUpdateSliceInPlace() const {
   auto ret = CanEmitFusedDynamicUpdateSliceInPlaceForGpu(
-      instr_, buffer_assignment_, analysis().fusion_root_adaptors());
+      instr_, buffer_assignment_, analysis().fusion_roots());
   return ret.ok() && *ret;
 }
 
@@ -65,8 +65,8 @@ class InPlaceDynamicUpdateSliceFusion : public KernelFusionEmitterBase {
  public:
   explicit InPlaceDynamicUpdateSliceFusion(const HloFusionAnalysis& analysis)
       : analysis_(analysis),
-        dus_ops_(GetOutputDefiningDynamicUpdateSlices(
-            analysis.fusion_root_adaptors())) {}
+        dus_ops_(
+            GetOutputDefiningDynamicUpdateSlices(analysis.fusion_roots())) {}
   LaunchDimensions launch_dimensions() const override;
 
   std::optional<IndexingMap> ComputeThreadIdToOutputIndexing(
@@ -153,9 +153,9 @@ absl::Status MlirInPlaceDynamicUpdateSliceFusion::EmitEntryFunction(
         ProvideParameter(root_computation, dus_instr, kDUSUpdateIndex,
                          input_indices, call_targets, entry_function, b);
     // Handle bitcasts under the DUS.
-    if (dus_instr->shape() != root->shape()) {
+    if (dus_instr->shape() != root.shape()) {
       update_indices = ApplyAffineMap(
-          GetBitcastMap(dus_instr->shape(), root->shape(), b.getContext())
+          GetBitcastMap(dus_instr->shape(), root.shape(), b.getContext())
               .GetAffineMap(),
           update_indices, {}, b);
     }
@@ -47,8 +47,8 @@ class MlirInPlaceDynamicUpdateSliceFusion : public MlirFusionEmitterBase {
   explicit MlirInPlaceDynamicUpdateSliceFusion(
       const HloFusionAnalysis& analysis)
       : analysis_(analysis),
-        dus_ops_(GetOutputDefiningDynamicUpdateSlices(
-            analysis.fusion_root_adaptors())) {}
+        dus_ops_(
+            GetOutputDefiningDynamicUpdateSlices(analysis.fusion_roots())) {}
 
   LaunchDimensions launch_dimensions() const override;
 
22 changes: 15 additions & 7 deletions third_party/xla/xla/service/gpu/fusions/input_slices_mlir.cc
@@ -41,6 +41,7 @@ limitations under the License.
 #include "xla/service/gpu/fusions/mlir/computation_partitioner.h"
 #include "xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h"
 #include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h"
+#include "xla/service/gpu/hlo_traversal.h"
 #include "xla/service/gpu/launch_dimensions.h"
 #include "xla/service/gpu/model/indexing_analysis.h"
 #include "xla/service/gpu/model/indexing_map.h"
@@ -69,11 +70,16 @@ MlirInputSlicesFusion::ComputeThreadIdToOutputIndexing(
 std::vector<mlir_converter::EpilogueSpecification>
 MlirInputSlicesFusion::GetEpilogues(const HloFusionInstruction& fusion,
                                     mlir::MLIRContext* mlir_context) const {
+  std::vector<const HloInstruction*> roots;
+  roots.reserve(analysis_.fusion_root_count());
+  for (const auto& root : analysis_.fusion_roots()) {
+    roots.push_back(&root.instruction());
+  }
+
   // We don't actually use epilogues here, but this is how we tell the base
   // class not to emit code for the slices.
   return {mlir_converter::EpilogueSpecification::FromOutputIndexing(
-      analysis_, analysis_.fusion_roots(), analysis_.fusion_roots(), *this,
-      mlir_context)};
+      analysis_, roots, roots, *this, mlir_context)};
 }
 
 LaunchDimensions MlirInputSlicesFusion::launch_dimensions() const {
@@ -94,7 +100,8 @@ absl::Status MlirInputSlicesFusion::EmitEntryFunction(
   builder.setInsertionPointToStart(entry_function.addEntryBlock());
 
   auto launch_dims = launch_dimensions();
-  const auto& shape = analysis_.fusion_roots().front()->operand(0)->shape();
+  const auto& shape =
+      analysis_.fusion_root(0).instruction().operand(0)->shape();
   auto input_indexing = GetDefaultThreadIdIndexingMap(
       launch_dims, unroll_factor_, shape, builder.getContext());
 
@@ -115,8 +122,8 @@ absl::Status MlirInputSlicesFusion::EmitEntryFunction(
   result_tensors.reserve(output_tensor_args.size());
 
   absl::flat_hash_map<const HloInstruction*, mlir::Value> input_values;
-  for (const auto* root : analysis_.fusion_roots()) {
-    const auto* arg = root->operand(0);
+  for (const HloInstructionAdaptor& root : analysis_.fusion_roots()) {
+    const auto* arg = root.instruction().operand(0);
     if (auto& value = input_values[arg]; !value) {
       value =
           builder.create<PureCallOp>(call_targets(arg), input_operands)
@@ -137,8 +144,9 @@ absl::Status MlirInputSlicesFusion::EmitEntryFunction(
             auto output_indices = mlir_converter::ApplyAffineMap(
                 output_indexing->GetAffineMap(), dim_values, symbol_values,
                 then_builder);
-            const auto* arg =
-                analysis_.fusion_roots()[output_index]->operand(0);
+            const auto* arg = analysis_.fusion_root(output_index)
+                                  .instruction()
+                                  .operand(0);
             auto inserted = then_builder.create<mlir::tensor::InsertOp>(
                 input_values[arg], output, output_indices);
             then_builder.create<mlir::scf::YieldOp>(inserted.getResult());
14 changes: 7 additions & 7 deletions third_party/xla/xla/service/gpu/fusions/loop_mlir.cc
@@ -65,7 +65,7 @@ std::optional<IndexingMap> MlirLoopFusion::ComputeThreadIdToOutputIndexing(
   auto launch_dims = launch_dimensions();
   return GetDefaultThreadIdIndexingMap(
       launch_dims, config_.unroll_factor,
-      GetIndexShape(analysis_.fusion_roots()[root_index]->shape()), ctx);
+      GetIndexShape(analysis_.fusion_root(root_index).shape()), ctx);
 }
 
 std::optional<IndexingMap> MlirLoopFusion::ComputeThreadIdToInputIndexing(
@@ -93,8 +93,8 @@ std::optional<IndexingMap> MlirLoopFusion::ComputeThreadIdToInputIndexing(
 
 LaunchDimensions MlirLoopFusion::launch_dimensions() const {
   return CalculateLaunchDimensions(
-      GetIndexShape(analysis_.fusion_roots()[0]->shape()),
-      analysis_.device_info(), config_);
+      GetIndexShape(analysis_.fusion_root(0).shape()), analysis_.device_info(),
+      config_);
 }
 
 absl::Status MlirLoopFusion::EmitEntryFunction(
@@ -113,13 +113,13 @@ absl::Status MlirLoopFusion::EmitEntryFunction(
   auto output_tensor_args =
       entry_function.getArguments().drop_front(num_inputs);
   llvm::SmallVector<const Shape*> result_shapes;
-  for (const auto* root : analysis_.fusion_roots()) {
-    if (root->shape().IsTuple()) {
-      for (const auto& shape : root->shape().tuple_shapes()) {
+  for (const HloInstructionAdaptor& root : analysis_.fusion_roots()) {
+    if (root.shape().IsTuple()) {
+      for (const auto& shape : root.shape().tuple_shapes()) {
         result_shapes.push_back(&shape);
       }
     } else {
-      result_shapes.push_back(&root->shape());
+      result_shapes.push_back(&root.shape());
    }
   }
 
@@ -131,11 +131,11 @@ EpilogueSpecification EpilogueSpecification::FromOutputIndexing(
       root_to_hero;
   for (auto [root, hero] :
        llvm::zip(analysis.fusion_roots(), analysis.fusion_heroes())) {
-    root_to_hero[root] = hero;
+    root_to_hero[&root.instruction()] = &hero.instruction();
   }
   absl::flat_hash_map<const HloInstruction*, int> root_to_index;
   for (auto [index, root] : llvm::enumerate(analysis.fusion_roots())) {
-    root_to_index[root] = root_to_index.size();
+    root_to_index[&root.instruction()] = root_to_index.size();
   }
 
   result.root_indexing.reserve(roots.size());
17 changes: 9 additions & 8 deletions third_party/xla/xla/service/gpu/fusions/reduction.cc
@@ -53,6 +53,7 @@ limitations under the License.
 #include "xla/service/gpu/fusions/thunk_util.h"
 #include "xla/service/gpu/fusions/tiling_util.h"
 #include "xla/service/gpu/hlo_fusion_analysis.h"
+#include "xla/service/gpu/hlo_traversal.h"
 #include "xla/service/gpu/ir_emission_utils.h"
 #include "xla/service/gpu/ir_emitter_context.h"
 #include "xla/service/gpu/ir_emitter_nested.h"
@@ -147,10 +148,10 @@ class ReductionEmitter {
         index_ty_(GetIndexType(fusion, reduction_codegen_info.GetTiling(),
                                elemental_emitter_.builder())) {
     for (auto hero : analysis.fusion_heroes()) {
-      if (hero->opcode() == HloOpcode::kReduce) {
-        for (int i = 0; i < hero->operand_count() / 2; ++i) {
+      if (hero.opcode() == HloOpcode::kReduce) {
+        for (int i = 0; i < hero.instruction().operand_count() / 2; ++i) {
           CHECK(LayoutUtil::IsMonotonicWithDim0Major(
-              hero->operand(i)->shape().layout()))
+              hero.instruction().operand(i)->shape().layout()))
               << "reduction-layout-normalizer must run before code generation";
         }
       }
@@ -1012,10 +1013,10 @@ absl::StatusOr<FusionEmissionResult> ReductionEmitter::EmitInitializers() {
         return absl::OkStatus();
       }));
 
-  absl::Span<const HloInstruction* const> fusion_roots =
+  absl::Span<HloInstructionAdaptor const> fusion_roots =
       analysis_.fusion_roots();
   for (int i = 0; i < fusion_roots.size(); ++i) {
-    const HloInstruction* fusion_root = fusion_roots[i];
+    const HloInstruction* fusion_root = &fusion_roots[i].instruction();
 
     if (IsReductionFromOrToContiguousDimensions(*fusion_root)) {
       TF_ASSIGN_OR_RETURN(
@@ -1046,9 +1047,9 @@ absl::Status ReductionEmitter::EmitKernel(
   ReductionOutputMap result_ir_arrays;
 
   int ir_arrays_idx = 0;
-  for (const HloInstruction* root : analysis_.fusion_roots()) {
-    int get_num_results = GetNumOutputs(root->shape());
-    result_ir_arrays[root] =
+  for (const HloInstructionAdaptor& root : analysis_.fusion_roots()) {
+    int get_num_results = GetNumOutputs(root.shape());
+    result_ir_arrays[&root.instruction()] =
         absl::MakeSpan(outputs).subspan(ir_arrays_idx, get_num_results);
     ir_arrays_idx += get_num_results;
   }
8 changes: 4 additions & 4 deletions third_party/xla/xla/service/gpu/fusions/reduction_base.cc
@@ -134,7 +134,7 @@ ReductionGroups GroupDisjointReductions(const HloFusionAnalysis& analysis,
     disjoint_sets[root].Get() = root;
     reachable_outputs[root].insert(root);
     result.is_reduction_root.push_back(
-        IsRealReductionHero(root.instruction(), *hero));
+        IsRealReductionHero(root.instruction(), hero.instruction()));
     if (result.is_reduction_root.back()) {
       roots_with_reduction.insert(root);
     } else if (first_non_reduction_root) {
@@ -337,7 +337,7 @@ std::optional<IndexingMap> ReductionInfo::ComputeThreadIdToOutputIndexing(
   auto map = ComposeIndexingMaps(
       GetIndexingMapForTiling(tiling_, ctx),
       GetBitcastMap(tiling_.GetXlaShape(),
-                    analysis_.fusion_roots()[root_index]->shape(), ctx));
+                    analysis_.fusion_root(root_index).shape(), ctx));
   AddGroupIdConstraint(map, root_index, ctx);
   return map;
 }
@@ -431,8 +431,8 @@ std::optional<IndexingMap> ReductionInfo::ComputeThreadIdToInputIndexing(
   if (!groups_.is_reduction_root[root_index]) {
     return ComposeIndexingMaps(
         *ComputeThreadIdToOutputIndexing(root_index, ctx),
-        *ComputeOutputToInputIndexing(analysis_.fusion_roots()[root_index], 0,
-                                      ctx)
+        *ComputeOutputToInputIndexing(
+             &analysis_.fusion_root(root_index).instruction(), 0, ctx)
              .indexing_maps[hero_operand_index]
              .begin());
   }
@@ -26,6 +26,7 @@ limitations under the License.
 #include "xla/service/gpu/gpu_device_info_for_tests.h"
 #include "xla/service/gpu/hlo_fusion_analysis.h"
 #include "xla/service/gpu/ir_emitter_context.h"
+#include "xla/service/gpu/model/indexing_analysis.h"
 #include "xla/service/gpu/model/indexing_test_utils.h"
 #include "xla/stream_executor/device_description.h"
 #include "xla/tests/hlo_test_base.h"
@@ -492,8 +493,8 @@ TEST_F(ReductionTest, TwoGroups) {
   FakeReductionFusion fusion(analysis);
 
   EXPECT_THAT(fusion.reduction_info().GetGroups().grouped_roots,
-              ElementsAre(ElementsAre(analysis.fusion_roots()[0]),
-                          ElementsAre(analysis.fusion_roots()[1])));
+              ElementsAre(ElementsAre(&analysis.fusion_root(0).instruction()),
+                          ElementsAre(&analysis.fusion_root(1).instruction())));
 }
 
 TEST_F(ReductionTest, OneGroup) {
12 changes: 8 additions & 4 deletions third_party/xla/xla/service/gpu/fusions/reduction_mlir.cc
@@ -47,10 +47,12 @@ limitations under the License.
 #include "xla/service/gpu/fusions/mlir/type_util.h"
 #include "xla/service/gpu/fusions/reduction_base.h"
 #include "xla/service/gpu/hlo_fusion_analysis.h"
+#include "xla/service/gpu/hlo_traversal.h"
 #include "xla/service/gpu/ir_emission_utils.h"
 #include "xla/service/gpu/model/indexing_analysis.h"
 #include "xla/service/gpu/model/indexing_map.h"
 #include "xla/service/gpu/reduction_utils.h"
+#include "xla/shape.h"
 #include "xla/shape_util.h"
 
 namespace xla {
@@ -80,9 +82,9 @@ struct MlirReductionFusion::EmitterState {
         computation(computations.FindPartitionedComputation(
             fusion.fused_instructions_computation())) {
     int index = 0;
-    for (auto root : owner.analysis().fusion_roots()) {
-      fusion_result_index_starts[root] = index;
-      index += root->shape().IsTuple() ? root->shape().tuple_shapes_size() : 1;
+    for (const auto& root : owner.analysis().fusion_roots()) {
+      fusion_result_index_starts[&root.instruction()] = index;
+      index += root.shape().IsTuple() ? root.shape().tuple_shapes_size() : 1;
     }
   }
 
@@ -132,9 +134,11 @@ MlirReductionFusion::MlirReductionFusion(const HloFusionAnalysis& analysis)
   reduction_roots_.resize(num_groups);
 
   absl::flat_hash_set<const HloInstruction*> seen_heroes;
-  for (auto [root, hero, is_reduction, group_id] :
+  for (auto [root_adaptor, hero_adaptor, is_reduction, group_id] :
        llvm::zip(analysis.fusion_roots(), analysis.fusion_heroes(),
                  groups.is_reduction_root, groups.group_id_per_root)) {
+    const HloInstruction* root = &root_adaptor.instruction();
+    const HloInstruction* hero = &hero_adaptor.instruction();
     if (is_reduction) {
       if (seen_heroes.insert(hero).second) {
         reduction_heroes_[group_id].push_back(hero);