[go: nahoru, domu]

Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Automated Code Change #70910

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 0 additions & 35 deletions third_party/xla/xla/service/gpu/triton_fusion_analysis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -213,41 +213,6 @@ absl::StatusOr<TritonFusionAnalysis> TritonFusionAnalysis::Execute(
return analysis;
}

// Runs the Triton fusion analysis on the fusion that would result from fusing
// `producer` into `consumer`, leaving the original HLO module untouched.
//
// Both instructions are copied into a scratch module by
// ExtractProducerConsumerIntoNewModule. The copied producer is then fused into
// the copied consumer if that is already a fusion, or otherwise into a
// temporary fusion wrapping the copied consumer, and Execute() is run on the
// fused computation of the resulting fusion.
//
// Returns the status of Execute() (absl::OkStatus() when a valid tiling is
// found), or an internal error if the copied instructions cannot be located in
// the scratch module.
absl::Status TritonFusionAnalysis::ExecuteForProducerConsumer(
    const HloInstruction& producer, const HloInstruction& consumer,
    int split_k) {
  // TODO(shyshkov): Use HloFusionAdaptor to avoid the need to materialize the
  // hlo fusion.
  std::unique_ptr<HloModule> new_module =
      ExtractProducerConsumerIntoNewModule(producer, consumer);

  // The copies are looked up by the original instruction names. Guard against
  // a missing copy (GetInstructionWithName presumably returns null when no
  // instruction has that name) instead of dereferencing null below.
  auto* new_producer =
      new_module->entry_computation()->GetInstructionWithName(producer.name());
  auto* new_consumer =
      new_module->entry_computation()->GetInstructionWithName(consumer.name());
  if (new_producer == nullptr || new_consumer == nullptr) {
    return absl::InternalError(
        "Failed to find the extracted producer or consumer instruction.");
  }

  // Owns the temporary fusion when the consumer is not already a fusion; that
  // wrapper is never added to any computation.
  std::unique_ptr<HloInstruction> fusion_instruction_holder;
  HloInstruction* fusion_instruction = nullptr;
  if (new_consumer->opcode() == HloOpcode::kFusion) {
    fusion_instruction = new_consumer;
  } else {
    // fusion_kind() may only be queried on fusion instructions, so fall back
    // to kCustom when the producer is not a fusion either. The wrapper's kind
    // does not reach Execute(), which only receives the fused computation.
    const HloInstruction::FusionKind fusion_kind =
        new_producer->opcode() == HloOpcode::kFusion
            ? new_producer->fusion_kind()
            : HloInstruction::FusionKind::kCustom;
    fusion_instruction_holder = HloInstruction::CreateFusion(
        new_consumer->shape(), fusion_kind, new_consumer);
    fusion_instruction = fusion_instruction_holder.get();
  }

  // Try to merge the producer into the candidate fusion.
  if (new_producer->opcode() == HloOpcode::kFusion) {
    fusion_instruction->MergeFusionInstruction(new_producer);
  } else {
    fusion_instruction->FuseInstruction(new_producer);
  }

  auto* fused_computation =
      fusion_instruction->fused_instructions_computation();
  return Execute(*fused_computation, split_k).status();
}

absl::Status TritonFusionAnalysis::ExecuteForDotFusion(
const HloInstruction& dot, const int split_k) {
DotRequirements lhs_requirements(kNoSplitRequirement);
Expand Down
8 changes: 0 additions & 8 deletions third_party/xla/xla/service/gpu/triton_fusion_analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,6 @@ class TritonFusionAnalysis {
static absl::StatusOr<TritonFusionAnalysis> Execute(
const HloComputation& computation, int split_k = 1);

// Executes the analysis of a producer-consumer fusion. Returns
// absl::OkStatus() if the analysis can find a valid tiling for the
// producer-consumer fusion; otherwise returns the analysis error.
// `split_k` indicates whether this operation was converted to the split-K
// form and tells the analysis how to interpret the batch dimensions.
// Neither `producer` nor `consumer` is modified: the analysis runs on copies
// extracted into a scratch module.
static absl::Status ExecuteForProducerConsumer(const HloInstruction& producer,
const HloInstruction& consumer,
int split_k = 1);

// A scope is an HLO graph that can be tiled efficiently using same or
// compatible tile shapes on all operations. GEMM fusion has 3 or 4 scopes
// defined by left operand, right operand, optional meta (third operand) and
Expand Down
47 changes: 0 additions & 47 deletions third_party/xla/xla/tools/hlo_decomposer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -175,53 +175,6 @@ std::unique_ptr<HloModule> ExtractInstructionIntoNewModule(
return new_hlo_module;
}

std::unique_ptr<HloModule> ExtractProducerConsumerIntoNewModule(
const HloInstruction& producer, const HloInstruction& consumer) {
auto new_hlo_module =
std::make_unique<HloModule>("extracted", HloModuleConfig{},
std::make_unique<CompilationEnvironments>(
consumer.GetModule()->comp_envs()));
int parameter_number = 0;
HloComputation::Builder builder("entry_computation");
HloCloneContext clone_context(new_hlo_module.get());
absl::InlinedVector<HloInstruction*, 8> producer_operands;
for (const HloInstruction* operand : producer.operands()) {
HloInstruction* new_parameter =
builder.AddInstruction(HloInstruction::CreateParameter(
parameter_number, operand->shape(), operand->name()));
++parameter_number;

producer_operands.push_back(new_parameter);
}

HloInstruction* new_producer =
builder.AddInstruction(producer.CloneWithNewOperands(
producer.shape(), producer_operands, &clone_context));

absl::flat_hash_map<const HloInstruction*, HloInstruction*> operand_map;
operand_map.emplace(&producer, new_producer);

absl::InlinedVector<HloInstruction*, 8> consumer_operands;
for (const HloInstruction* operand : consumer.operands()) {
auto it = operand_map.find(operand);
if (it != operand_map.end()) {
consumer_operands.push_back(it->second);
} else {
HloInstruction* new_parameter =
builder.AddInstruction(HloInstruction::CreateParameter(
parameter_number, operand->shape(), operand->name()));
++parameter_number;

consumer_operands.push_back(new_parameter);
}
}
builder.AddInstruction(consumer.CloneWithNewOperands(
consumer.shape(), consumer_operands, &clone_context));

new_hlo_module->AddEntryComputationWithLayouts(builder.Build());
return new_hlo_module;
}

std::unique_ptr<HloModule> ExtractComputationIntoNewModule(
const HloComputation& computation) {
auto new_hlo_module =
Expand Down
5 changes: 0 additions & 5 deletions third_party/xla/xla/tools/hlo_decomposer.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,6 @@ std::unique_ptr<HloModule> ExtractInstructionIntoNewModule(
std::unique_ptr<HloModule> ExtractInstructionIntoNewModule(
const std::vector<HloInstruction*>& instructions);

// Extracts the producer and consumer HLO instructions into a new HLO module,
// replacing their operands with parameter instructions. The clone of
// `producer` feeds the clone of `consumer`. Every other operand of either
// instruction becomes its own parameter, so an operand shared by both is
// represented by two distinct parameters.
std::unique_ptr<HloModule> ExtractProducerConsumerIntoNewModule(
const HloInstruction& producer, const HloInstruction& consumer);

// Extracts an HLO computation into a new HLO module, using its clone as the
// root computation.
std::unique_ptr<HloModule> ExtractComputationIntoNewModule(
Expand Down