[go: nahoru, domu]

Skip to content

Commit

Permalink
[XLA:CollectivePipeliner] Refactor IsSupportedDynamicUpdateSlice into a separate function.
Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 646676579
  • Loading branch information
seherellis authored and tensorflower-gardener committed Jun 26, 2024
1 parent 22804b0 commit ee3916e
Showing 1 changed file with 145 additions and 95 deletions.
240 changes: 145 additions & 95 deletions third_party/xla/xla/service/collective_pipeliner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -442,8 +442,9 @@ bool IsLoopIterator(const HloInstruction* instr,
// Scavenge operands that are dependencies not included in the ops set and that
// aren't the source_op passed as input parameter and return them in a vector.
std::vector<HloInstruction*> CollectDependenciesToPipeline(
HloInstruction* source_op, absl::Span<HloInstruction* const> ops) {
absl::flat_hash_set<HloInstruction*> formatting_set(ops.begin(), ops.end());
const HloInstruction* source_op, absl::Span<HloInstruction* const> ops) {
absl::flat_hash_set<const HloInstruction*> formatting_set(ops.begin(),
ops.end());
formatting_set.insert(source_op);
std::vector<HloInstruction*> to_return;
absl::flat_hash_set<HloInstruction*> already_inserted;
Expand Down Expand Up @@ -709,6 +710,27 @@ class WhileLoopAnalysis {
int64_t GetMaxPipeliningPerLoop() const { return max_pipelining_per_loop_; }

bool ComputeLoopStatistics();
// Checks if the given dynamic-update-slice is supported for pipelining and
// returns its slice dimension and index in the while tuple if supported.
// Returns std::nullopt if it is not supported, which can happen for several
// reasons:
// - The slice dimension can not be found or is not 0 for forward-sinking.
// - The number of slices does not match the loop iteration count.
// - There is an unexpected shape/size in the overall dependency chain.
// - The buffer to insert into is not a GTE from the loop parameter.
// - The parameter usage is not compatible with the expected pattern.
// - The update slicing is not compatible with the expected pattern.
// - The update index is not monotonic.
// - The output index for the insertion can not be found.
std::optional<std::pair<int64_t, int64_t>> IsSupportedDynamicUpdateSlice(
const HloDynamicUpdateSliceInstruction* dyn_update,
const HloInstruction* instr,
const std::vector<HloInstruction*>& formatting_ops,
CollectivePipeliner::PipeliningDirection direction,
int64_t level_to_operate_on,
const absl::flat_hash_map<int64_t, int64_t>& parameter_gtes_count,
const absl::flat_hash_map<const HloInstruction*, Range>& index_ranges)
const;
void CollectCollectivesToMove(
int64_t level_to_operate_on,
CollectivePipeliner::PipeliningDirection direction,
Expand Down Expand Up @@ -849,6 +871,119 @@ bool WhileLoopAnalysis::ComputeLoopStatistics() {
return true;
}

// Checks whether `dyn_update` (the dynamic-update-slice that consumes the
// pipelined collective `instr`, possibly through `formatting_ops`) matches
// the pattern the pipeliner supports. On success returns
// {slice dimension, output tuple index}; otherwise returns std::nullopt.
// See the declaration comment for the full list of rejection reasons.
std::optional<std::pair<int64_t, int64_t>>
WhileLoopAnalysis::IsSupportedDynamicUpdateSlice(
    const HloDynamicUpdateSliceInstruction* dyn_update,
    const HloInstruction* instr,
    const std::vector<HloInstruction*>& formatting_ops,
    CollectivePipeliner::PipeliningDirection direction,
    int64_t level_to_operate_on,
    const absl::flat_hash_map<int64_t, int64_t>& parameter_gtes_count,
    const absl::flat_hash_map<const HloInstruction*, Range>& index_ranges)
    const {
  HloComputation* while_body = while_->while_body();
  const HloInstruction* loop_parameter =
      while_body->parameter_instructions()[0];
  std::optional<int64_t> sliced_dim = GetSlicedDimension(dyn_update);
  if (!sliced_dim.has_value()) {
    VLOG(5) << "Skipping " << instr->name()
            << " because couldn't find sliced dimension";
    return std::nullopt;
  }
  // Forward-sinking requires slicing along dimension 0 with exactly one slice
  // per loop iteration.
  if (direction == CollectivePipeliner::PipeliningDirection::kForwardSink &&
      (*sliced_dim != 0 || dyn_update->shape().dimensions(0) !=
                               loop_iteration_count_->GetUnsignedValue())) {
    VLOG(5) << "Skipping " << instr->name()
            << " because number of iteration of the loop doesn't match "
               "slices being inserted or slice dim is not 0. slice_dim = "
            << *sliced_dim
            << " loop count = " << loop_iteration_count_->GetUnsignedValue();
    // Bug fix: the message says "Skipping" but the code previously fell
    // through and kept treating this dynamic-update-slice as supported.
    // Reject it, matching every other failed check in this function.
    return std::nullopt;
  }
  if (!process_different_sized_options_) {
    if (!formatting_ops.empty()) {
      if (instr->operand(0)->shape() != formatting_ops.back()->shape()) {
        VLOG(5) << "Skipping " << instr->name()
                << " because operand and last formatting op don't have the "
                   "same shape";
        return std::nullopt;
      }
      auto dependencies_to_pipeline = CollectDependenciesToPipeline(
          instr, absl::MakeConstSpan(formatting_ops));
      bool skip_because_not_same_size = false;
      // If any instruction in the dependency chain is not of the same size
      // then we abort for this instruction.
      for (auto* dependency : dependencies_to_pipeline) {
        if (ShapeUtil::IsEffectiveScalar(dependency->shape())) {
          skip_because_not_same_size = true;
          break;
        }
      }
      if (skip_because_not_same_size) {
        VLOG(5)
            << "Skipping " << instr->name()
            << " because formatting ops do not have the expected shapes/sizes";
        return std::nullopt;
      }
    } else if (instr->operand(0)->shape() != instr->shape()) {
      VLOG(5) << "Skipping " << instr->name()
              << " because instr does not have the same shape as its operand";
      return std::nullopt;
    }
  }
  // At level 0 the buffer being inserted into must be a GTE directly off the
  // loop parameter so that parameter usage can be analyzed below.
  const HloInstruction* to_insert_into = dyn_update->operand(0);
  if (level_to_operate_on == 0 &&
      (to_insert_into->opcode() != HloOpcode::kGetTupleElement ||
       to_insert_into->operand(0) != loop_parameter)) {
    VLOG(5) << "Skipping " << instr->name()
            << " because slice to insert into is not a GTE from input "
               "parameter "
            << to_insert_into->ToString();
    return std::nullopt;
  }
  // If Level is > 0 then we already did our analysis in the previous
  // iteration for safeness of this index to transform.
  if (level_to_operate_on == 0) {
    if (to_insert_into->opcode() == HloOpcode::kGetTupleElement) {
      // GTE for this parameter is not CSEd. Abort because we don't analyze
      // every single use from other GTEs.
      if (parameter_gtes_count.at(to_insert_into->tuple_index()) != 1) {
        VLOG(5) << "Skipping " << instr->name()
                << " because there are multiple parameter GTEs for this slice";
        return std::nullopt;
      }
    }
    const HloInstruction* dyn_update_idx = dyn_update->operand(
        dyn_update->first_index_operand_number() + *sliced_dim);
    // Note: the redundant `level_to_operate_on == 0 &&` re-check that used to
    // guard this condition has been dropped; we are already inside that guard.
    if (!CheckParameterUsageIsCompatible(to_insert_into, dyn_update,
                                         dyn_update_idx, *sliced_dim)) {
      VLOG(5) << "Skipping " << instr->name()
              << " because parameter usage doesn't follow the expected pattern";
      return std::nullopt;
    }
    // All index operands except the sliced dimension's must be constants.
    if (!AllIndicesConstantsExceptOne(
            dyn_update,
            dyn_update->first_index_operand_number() + *sliced_dim)) {
      VLOG(5) << "Skipping " << instr->name()
              << " because update slicing doesn't match expectation";
      return std::nullopt;
    }
    // The slice index must advance monotonically across iterations.
    if (!CheckIndexIsMonotonic(dyn_update_idx, index_ranges)) {
      VLOG(5) << "Skipping " << instr->name()
              << " because update index is not monotonic";
      return std::nullopt;
    }
  }
  // The dynamic-update-slice must feed a unique element of the loop's root
  // tuple; that element index is returned to the caller.
  std::optional<int64_t> output_idx = FindOutputIndexForDynamicUpdateSlice(
      dyn_update, while_body->root_instruction());
  if (!output_idx.has_value()) {
    VLOG(5) << "Skipping " << instr->name()
            << " because couldn't find unique output index for insertion";
    return std::nullopt;
  }
  return std::make_pair(*sliced_dim, *output_idx);
}

void WhileLoopAnalysis::CollectCollectivesToMove(
int64_t level_to_operate_on,
CollectivePipeliner::PipeliningDirection direction,
Expand Down Expand Up @@ -927,100 +1062,15 @@ void WhileLoopAnalysis::CollectCollectivesToMove(
"computation";
continue;
}
std::optional<int64_t> sliced_dim = GetSlicedDimension(dyn_update);
if (!sliced_dim.has_value()) {
VLOG(5) << "Skipping " << instr->name()
<< " because couldn't find sliced dimension";
continue;
}
if (direction == CollectivePipeliner::PipeliningDirection::kForwardSink &&
(*sliced_dim != 0 || dyn_update->shape().dimensions(0) !=
loop_iteration_count_->GetUnsignedValue())) {
VLOG(5) << "Skipping " << instr->name()
<< " because number of iteration of the loop doesn't match "
"slices being inserted or slice dim is not 0. slice_dim = "
<< *sliced_dim << " loop count = "
<< loop_iteration_count_->GetUnsignedValue();
}
if (!process_different_sized_options_) {
if (!formatting_ops.empty()) {
if (instr->operand(0)->shape() != formatting_ops.back()->shape()) {
continue;
}
auto dependencies_to_pipeline = CollectDependenciesToPipeline(
instr, absl::MakeConstSpan(formatting_ops));
bool skip_because_not_same_size = false;
// If any instruction in the dependency chain is not of the same size
// then we abort for this instruction.
for (auto* dependency : dependencies_to_pipeline) {
if (ShapeUtil::IsEffectiveScalar(dependency->shape())) {
skip_because_not_same_size = true;
break;
}
}
if (skip_because_not_same_size) {
continue;
}
} else if (instr->operand(0)->shape() != instr->shape()) {
continue;
}
}
const HloInstruction* to_insert_into = dyn_update->operand(0);
if (level_to_operate_on == 0 &&
(to_insert_into->opcode() != HloOpcode::kGetTupleElement ||
to_insert_into->operand(0) != loop_parameter)) {
VLOG(5) << "Skipping " << instr->name()
<< " because slice to insert into is not a GTE from input "
"parameter "
<< to_insert_into->ToString();
continue;
}
if (dyn_update->user_count() != 1) {
continue;
}
// If Level is > 0 then we already did our analysis in the previous
// iteration for safeness of this index to transform.
if (level_to_operate_on == 0) {
if (to_insert_into->opcode() == HloOpcode::kGetTupleElement) {
// GTE for this parameter is not CSEd. Abort because we don't analyze
// every single use from other GTEs.
if (parameter_gtes_count.at(to_insert_into->tuple_index()) != 1) {
VLOG(5)
<< "Skipping " << instr->name()
<< " because there are multiple parameter GTEs for this slice";
continue;
}
}
HloInstruction* dyn_update_idx = dyn_update->mutable_operand(
dyn_update->first_index_operand_number() + *sliced_dim);
if (level_to_operate_on == 0 &&
!CheckParameterUsageIsCompatible(to_insert_into, dyn_update,
dyn_update_idx, *sliced_dim)) {
VLOG(5)
<< "Skipping " << instr->name()
<< " because parameter usage doesn't follow the expected pattern";
continue;
}
if (!AllIndicesConstantsExceptOne(
dyn_update,
dyn_update->first_index_operand_number() + *sliced_dim)) {
VLOG(5) << "Skipping " << instr->name()
<< " because update slicing doesn't match expectation";
continue;
}
if (!CheckIndexIsMonotonic(dyn_update_idx, index_ranges)) {
VLOG(5) << "Skipping " << instr->name()
<< " because update index is not monotonic";
continue;
}
}
std::optional<int64_t> output_idx = FindOutputIndexForDynamicUpdateSlice(
dyn_update, while_body->root_instruction());
if (!output_idx.has_value()) {
VLOG(5) << "Skipping " << instr->name()
<< " because couldn't find unique output index for insertion";
std::optional<std::pair<int64_t, int64_t>> maybe_dus_info =
IsSupportedDynamicUpdateSlice(dyn_update, instr, formatting_ops,
direction, level_to_operate_on,
parameter_gtes_count, index_ranges);
if (!maybe_dus_info.has_value()) {
continue;
}
int64_t sliced_dim = maybe_dus_info->first;
int64_t output_idx = maybe_dus_info->second;
auto merge_as_formatting =
[this, &instruction_order](
absl::flat_hash_map<const HloInstruction*, int64_t>::iterator it,
Expand Down Expand Up @@ -1057,7 +1107,7 @@ void WhileLoopAnalysis::CollectCollectivesToMove(
}
index_per_dyn_update_slice[dyn_update] = move_infos_.size();
move_infos_.push_back({instr, dyn_update, std::move(formatting_ops),
*sliced_dim, *output_idx});
sliced_dim, output_idx});
} else {
CHECK_EQ(direction, CollectivePipeliner::PipeliningDirection::kBackward);
auto chain_collected = CollectChainsToPushBackwards(
Expand Down

0 comments on commit ee3916e

Please sign in to comment.