[XLA:CollectivePipeliner] Refactor IsSupportedDynamicUpdateSlice into a separate function. #70405

Merged: 1 commit, Jun 27, 2024
[XLA:CollectivePipeliner] Refactor IsSupportedDynamicUpdateSlice into a separate function.

PiperOrigin-RevId: 647377136
seherellis authored and tensorflower-gardener committed Jun 27, 2024
commit db9b0a9ac60908801d1a8ecf576b3fbad389be03
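
In broad strokes, the commit replaces a run of inline eligibility checks that each bailed out with `continue` by a const member function returning `std::optional<std::pair<int64_t, int64_t>>`. Below is a minimal, self-contained sketch of that pattern; `Candidate`, `IsSupported`, and `Collect` are hypothetical stand-ins for illustration, not XLA identifiers.

// Sketch of the refactoring pattern (hypothetical names, not XLA code).
#include <cstdint>
#include <optional>
#include <utility>
#include <vector>

struct Candidate {
  int64_t sliced_dim;
  int64_t output_idx;
};

// All eligibility checks live in one helper; it returns the computed payload
// on success and std::nullopt as soon as any check fails.
std::optional<std::pair<int64_t, int64_t>> IsSupported(const Candidate& c) {
  if (c.sliced_dim != 0) return std::nullopt;  // unsupported slice dimension
  if (c.output_idx < 0) return std::nullopt;   // no output index found
  return std::make_pair(c.sliced_dim, c.output_idx);
}

void Collect(const std::vector<Candidate>& candidates,
             std::vector<std::pair<int64_t, int64_t>>* out) {
  for (const Candidate& c : candidates) {
    // The caller now branches once instead of repeating ad-hoc `continue`s.
    std::optional<std::pair<int64_t, int64_t>> info = IsSupported(c);
    if (!info.has_value()) continue;
    out->push_back(*info);
  }
}

The payoff of this shape is that the validity conditions are documented and tested in one place, and the collection loop reads as intent rather than as a wall of early exits.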
240 changes: 145 additions & 95 deletions third_party/xla/xla/service/collective_pipeliner.cc
@@ -442,8 +442,9 @@ bool IsLoopIterator(const HloInstruction* instr,
// Scavenge operands that are dependencies not included in the ops set and that
// aren't the source_op passed as input parameter and return them in a vector.
std::vector<HloInstruction*> CollectDependenciesToPipeline(
-    HloInstruction* source_op, absl::Span<HloInstruction* const> ops) {
-  absl::flat_hash_set<HloInstruction*> formatting_set(ops.begin(), ops.end());
+    const HloInstruction* source_op, absl::Span<HloInstruction* const> ops) {
+  absl::flat_hash_set<const HloInstruction*> formatting_set(ops.begin(),
+                                                            ops.end());
  formatting_set.insert(source_op);
  std::vector<HloInstruction*> to_return;
  absl::flat_hash_set<HloInstruction*> already_inserted;
@@ -709,6 +710,27 @@ class WhileLoopAnalysis {
  int64_t GetMaxPipeliningPerLoop() const { return max_pipelining_per_loop_; }

  bool ComputeLoopStatistics();
+  // Checks whether the given dynamic-update-slice is supported for pipelining
+  // and, if so, returns its slice dimension and its index in the while tuple.
+  // Returns std::nullopt if it is not supported, which can happen for several
+  // reasons:
+  // - The slice dimension cannot be found, or is not 0 for forward-sinking.
+  // - The number of slices does not match the loop iteration count.
+  // - There is an unexpected shape/size in the overall dependency chain.
+  // - The buffer to insert into is not a GTE from the loop parameter.
+  // - The parameter usage is not compatible with the expected pattern.
+  // - The update slicing is not compatible with the expected pattern.
+  // - The update index is not monotonic.
+  // - The output index for the insertion cannot be found.
+  std::optional<std::pair<int64_t, int64_t>> IsSupportedDynamicUpdateSlice(
+      const HloDynamicUpdateSliceInstruction* dyn_update,
+      const HloInstruction* instr,
+      const std::vector<HloInstruction*>& formatting_ops,
+      CollectivePipeliner::PipeliningDirection direction,
+      int64_t level_to_operate_on,
+      const absl::flat_hash_map<int64_t, int64_t>& parameter_gtes_count,
+      const absl::flat_hash_map<const HloInstruction*, Range>& index_ranges)
+      const;
  void CollectCollectivesToMove(
      int64_t level_to_operate_on,
      CollectivePipeliner::PipeliningDirection direction,
@@ -849,6 +871,119 @@ bool WhileLoopAnalysis::ComputeLoopStatistics() {
  return true;
}

+std::optional<std::pair<int64_t, int64_t>>
+WhileLoopAnalysis::IsSupportedDynamicUpdateSlice(
+    const HloDynamicUpdateSliceInstruction* dyn_update,
+    const HloInstruction* instr,
+    const std::vector<HloInstruction*>& formatting_ops,
+    CollectivePipeliner::PipeliningDirection direction,
+    int64_t level_to_operate_on,
+    const absl::flat_hash_map<int64_t, int64_t>& parameter_gtes_count,
+    const absl::flat_hash_map<const HloInstruction*, Range>& index_ranges)
+    const {
+  HloComputation* while_body = while_->while_body();
+  const HloInstruction* loop_parameter =
+      while_body->parameter_instructions()[0];
+  std::optional<int64_t> sliced_dim = GetSlicedDimension(dyn_update);
+  if (!sliced_dim.has_value()) {
+    VLOG(5) << "Skipping " << instr->name()
+            << " because couldn't find sliced dimension";
+    return std::nullopt;
+  }
+  if (direction == CollectivePipeliner::PipeliningDirection::kForwardSink &&
+      (*sliced_dim != 0 || dyn_update->shape().dimensions(0) !=
+                               loop_iteration_count_->GetUnsignedValue())) {
+    VLOG(5) << "Skipping " << instr->name()
+            << " because the number of iterations of the loop doesn't match "
+               "the slices being inserted or the slice dim is not 0. "
+               "slice_dim = "
+            << *sliced_dim
+            << " loop count = " << loop_iteration_count_->GetUnsignedValue();
+  }
+  if (!process_different_sized_options_) {
+    if (!formatting_ops.empty()) {
+      if (instr->operand(0)->shape() != formatting_ops.back()->shape()) {
+        VLOG(5) << "Skipping " << instr->name()
+                << " because operand and last formatting op don't have the "
+                   "same shape";
+        return std::nullopt;
+      }
+      auto dependencies_to_pipeline = CollectDependenciesToPipeline(
+          instr, absl::MakeConstSpan(formatting_ops));
+      bool skip_because_not_same_size = false;
+      // If any instruction in the dependency chain is not of the same size,
+      // we abort for this instruction.
+      for (auto* dependency : dependencies_to_pipeline) {
+        if (ShapeUtil::IsEffectiveScalar(dependency->shape())) {
+          skip_because_not_same_size = true;
+          break;
+        }
+      }
+      if (skip_because_not_same_size) {
+        VLOG(5)
+            << "Skipping " << instr->name()
+            << " because formatting ops do not have the expected shapes/sizes";
+        return std::nullopt;
+      }
+    } else if (instr->operand(0)->shape() != instr->shape()) {
+      VLOG(5) << "Skipping " << instr->name()
+              << " because instr does not have the same shape as its operand";
+      return std::nullopt;
+    }
+  }
+  const HloInstruction* to_insert_into = dyn_update->operand(0);
+  if (level_to_operate_on == 0 &&
+      (to_insert_into->opcode() != HloOpcode::kGetTupleElement ||
+       to_insert_into->operand(0) != loop_parameter)) {
+    VLOG(5) << "Skipping " << instr->name()
+            << " because slice to insert into is not a GTE from input "
+               "parameter "
+            << to_insert_into->ToString();
+    return std::nullopt;
+  }
+  // If the level is > 0, we already did this analysis in the previous
+  // iteration to ensure this index is safe to transform.
+  if (level_to_operate_on == 0) {
+    if (to_insert_into->opcode() == HloOpcode::kGetTupleElement) {
+      // The GTE for this parameter is not CSEd. Abort because we don't
+      // analyze every single use from other GTEs.
+      if (parameter_gtes_count.at(to_insert_into->tuple_index()) != 1) {
+        VLOG(5) << "Skipping " << instr->name()
+                << " because there are multiple parameter GTEs for this slice";
+        return std::nullopt;
+      }
+    }
+    const HloInstruction* dyn_update_idx = dyn_update->operand(
+        dyn_update->first_index_operand_number() + *sliced_dim);
+    if (level_to_operate_on == 0 &&
+        !CheckParameterUsageIsCompatible(to_insert_into, dyn_update,
+                                         dyn_update_idx, *sliced_dim)) {
+      VLOG(5) << "Skipping " << instr->name()
+              << " because parameter usage doesn't follow the expected "
+                 "pattern";
+      return std::nullopt;
+    }
+    if (!AllIndicesConstantsExceptOne(
+            dyn_update,
+            dyn_update->first_index_operand_number() + *sliced_dim)) {
+      VLOG(5) << "Skipping " << instr->name()
+              << " because update slicing doesn't match expectation";
+      return std::nullopt;
+    }
+    if (!CheckIndexIsMonotonic(dyn_update_idx, index_ranges)) {
+      VLOG(5) << "Skipping " << instr->name()
+              << " because update index is not monotonic";
+      return std::nullopt;
+    }
+  }
+  std::optional<int64_t> output_idx = FindOutputIndexForDynamicUpdateSlice(
+      dyn_update, while_body->root_instruction());
+  if (!output_idx.has_value()) {
+    VLOG(5) << "Skipping " << instr->name()
+            << " because couldn't find unique output index for insertion";
+    return std::nullopt;
+  }
+  return std::make_pair(*sliced_dim, *output_idx);
+}
+
void WhileLoopAnalysis::CollectCollectivesToMove(
    int64_t level_to_operate_on,
    CollectivePipeliner::PipeliningDirection direction,
@@ -927,100 +1062,15 @@ void WhileLoopAnalysis::CollectCollectivesToMove(
"computation";
continue;
}
-    std::optional<int64_t> sliced_dim = GetSlicedDimension(dyn_update);
-    if (!sliced_dim.has_value()) {
-      VLOG(5) << "Skipping " << instr->name()
-              << " because couldn't find sliced dimension";
-      continue;
-    }
-    if (direction == CollectivePipeliner::PipeliningDirection::kForwardSink &&
-        (*sliced_dim != 0 || dyn_update->shape().dimensions(0) !=
-                                 loop_iteration_count_->GetUnsignedValue())) {
-      VLOG(5) << "Skipping " << instr->name()
-              << " because number of iteration of the loop doesn't match "
-                 "slices being inserted or slice dim is not 0. slice_dim = "
-              << *sliced_dim << " loop count = "
-              << loop_iteration_count_->GetUnsignedValue();
-    }
-    if (!process_different_sized_options_) {
-      if (!formatting_ops.empty()) {
-        if (instr->operand(0)->shape() != formatting_ops.back()->shape()) {
-          continue;
-        }
-        auto dependencies_to_pipeline = CollectDependenciesToPipeline(
-            instr, absl::MakeConstSpan(formatting_ops));
-        bool skip_because_not_same_size = false;
-        // If any instruction in the dependency chain is not of the same size
-        // then we abort for this instruction.
-        for (auto* dependency : dependencies_to_pipeline) {
-          if (ShapeUtil::IsEffectiveScalar(dependency->shape())) {
-            skip_because_not_same_size = true;
-            break;
-          }
-        }
-        if (skip_because_not_same_size) {
-          continue;
-        }
-      } else if (instr->operand(0)->shape() != instr->shape()) {
-        continue;
-      }
-    }
-    const HloInstruction* to_insert_into = dyn_update->operand(0);
-    if (level_to_operate_on == 0 &&
-        (to_insert_into->opcode() != HloOpcode::kGetTupleElement ||
-         to_insert_into->operand(0) != loop_parameter)) {
-      VLOG(5) << "Skipping " << instr->name()
-              << " because slice to insert into is not a GTE from input "
-                 "parameter "
-              << to_insert_into->ToString();
-      continue;
-    }
-    if (dyn_update->user_count() != 1) {
-      continue;
-    }
-    // If Level is > 0 then we already did our analysis in the previous
-    // iteration for safeness of this index to transform.
-    if (level_to_operate_on == 0) {
-      if (to_insert_into->opcode() == HloOpcode::kGetTupleElement) {
-        // GTE for this parameter is not CSEd. Abort because we don't analyze
-        // every single use from other GTEs.
-        if (parameter_gtes_count.at(to_insert_into->tuple_index()) != 1) {
-          VLOG(5)
-              << "Skipping " << instr->name()
-              << " because there are multiple parameter GTEs for this slice";
-          continue;
-        }
-      }
-      HloInstruction* dyn_update_idx = dyn_update->mutable_operand(
-          dyn_update->first_index_operand_number() + *sliced_dim);
-      if (level_to_operate_on == 0 &&
-          !CheckParameterUsageIsCompatible(to_insert_into, dyn_update,
-                                           dyn_update_idx, *sliced_dim)) {
-        VLOG(5)
-            << "Skipping " << instr->name()
-            << " because parameter usage doesn't follow the expected pattern";
-        continue;
-      }
-      if (!AllIndicesConstantsExceptOne(
-              dyn_update,
-              dyn_update->first_index_operand_number() + *sliced_dim)) {
-        VLOG(5) << "Skipping " << instr->name()
-                << " because update slicing doesn't match expectation";
-        continue;
-      }
-      if (!CheckIndexIsMonotonic(dyn_update_idx, index_ranges)) {
-        VLOG(5) << "Skipping " << instr->name()
-                << " because update index is not monotonic";
-        continue;
-      }
-    }
-    std::optional<int64_t> output_idx = FindOutputIndexForDynamicUpdateSlice(
-        dyn_update, while_body->root_instruction());
-    if (!output_idx.has_value()) {
-      VLOG(5) << "Skipping " << instr->name()
-              << " because couldn't find unique output index for insertion";
+    std::optional<std::pair<int64_t, int64_t>> maybe_dus_info =
+        IsSupportedDynamicUpdateSlice(dyn_update, instr, formatting_ops,
+                                      direction, level_to_operate_on,
+                                      parameter_gtes_count, index_ranges);
+    if (!maybe_dus_info.has_value()) {
      continue;
    }
+    int64_t sliced_dim = maybe_dus_info->first;
+    int64_t output_idx = maybe_dus_info->second;
auto merge_as_formatting =
[this, &instruction_order](
absl::flat_hash_map<const HloInstruction*, int64_t>::iterator it,
@@ -1057,7 +1107,7 @@ void WhileLoopAnalysis::CollectCollectivesToMove(
}
index_per_dyn_update_slice[dyn_update] = move_infos_.size();
move_infos_.push_back({instr, dyn_update, std::move(formatting_ops),
-                           *sliced_dim, *output_idx});
+                           sliced_dim, output_idx});
} else {
CHECK_EQ(direction, CollectivePipeliner::PipeliningDirection::kBackward);
auto chain_collected = CollectChainsToPushBackwards(
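
As a usage note: a caller of an optional<pair> API like the one introduced here can also unpack the result with C++17 structured bindings, rather than going through maybe_dus_info->first / ->second as the diff does. A self-contained sketch, again using a hypothetical IsSupported stand-in rather than the real XLA function:

#include <cstdint>
#include <iostream>
#include <optional>
#include <utility>

// Hypothetical stand-in for IsSupportedDynamicUpdateSlice.
std::optional<std::pair<int64_t, int64_t>> IsSupported(int64_t dim,
                                                       int64_t idx) {
  if (dim != 0) return std::nullopt;
  return std::make_pair(dim, idx);
}

int main() {
  // Structured bindings unpack the pair in a single declaration.
  if (auto info = IsSupported(/*dim=*/0, /*idx=*/3)) {
    auto [sliced_dim, output_idx] = *info;
    std::cout << sliced_dim << " " << output_idx << "\n";  // prints "0 3"
  }
}

Either spelling is correct; named local variables, as in the diff, keep the call site readable when the pair's fields are not self-describing.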