[XLA:GPU] Better approximation for costly AR in LHS.
PiperOrigin-RevId: 642552801
golechwierowicz authored and tensorflower-gardener committed Jun 12, 2024
1 parent b94a1d2 commit 2da33f0
Showing 5 changed files with 105 additions and 4 deletions.
8 changes: 8 additions & 0 deletions third_party/xla/xla/debug_options_flags.cc
@@ -128,6 +128,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() {
   opts.set_xla_gpu_enable_all_gather_combine_by_dim(true);
   opts.set_xla_gpu_enable_reduce_scatter_combine_by_dim(true);
   opts.set_xla_gpu_enable_all_reduce_splitter(true);
+  opts.set_xla_gpu_enable_approx_costly_collectives(false);
 
   opts.set_xla_gpu_enable_reassociation_for_converted_ar(true);
 
@@ -1086,6 +1087,13 @@ void MakeDebugOptionsFlags(std::vector<tsl::Flag>* flag_list,
       debug_options->xla_gpu_enable_all_reduce_splitter(),
       "Splits cross-device all reduce into logical reduce scatter followed by "
       "dynamic slice and all reduce."));
+  flag_list->push_back(tsl::Flag(
+      "xla_gpu_enable_approx_costly_collectives",
+      bool_setter_for(
+          &DebugOptions::set_xla_gpu_enable_approx_costly_collectives),
+      debug_options->xla_gpu_enable_approx_costly_collectives(),
+      "Enables more accurate latency approximation of collectives. Used in "
+      "`ApproximateLatencyEstimator` scheduler."));
  flag_list->push_back(tsl::Flag(
       "xla_gpu_all_reduce_blueconnect_num_devices_per_host",
       int32_setter_for(
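The new option defaults to false. A hedged sketch of opting a module in programmatically, using the setter registered above (GetDebugOptionsFromFlags() and HloModuleConfig::set_debug_options() are existing XLA APIs; the include paths and surrounding wiring here are assumptions for illustration):

    // Sketch: enable the costly-collectives approximation for one module.
    #include "xla/debug_options_flags.h"        // GetDebugOptionsFromFlags()
    #include "xla/service/hlo_module_config.h"  // HloModuleConfig

    void EnableApproxCostlyCollectives(xla::HloModuleConfig& config) {
      xla::DebugOptions opts = xla::GetDebugOptionsFromFlags();
      opts.set_xla_gpu_enable_approx_costly_collectives(true);
      config.set_debug_options(opts);
    }

The same switch is reachable from the environment via the flag registered here, e.g. XLA_FLAGS=--xla_gpu_enable_approx_costly_collectives=true.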
9 changes: 8 additions & 1 deletion third_party/xla/xla/service/gpu/gpu_compiler.cc
@@ -870,6 +870,12 @@ absl::Status RunCollectiveOptimizationPasses(
   }
   if (debug_options.xla_gpu_enable_pipelined_collectives() ||
       debug_options.xla_gpu_enable_pipelined_all_gather()) {
+    // TODO(b/346702380): This constraint relaxation breaks some near-optimal
+    // schedules for async LHS. This is just the mitigation; the proper fix is
+    // to add a heuristic to the LHS scheduler which would prefer paths with
+    // more costly collectives first.
+    bool acceptable_loop_invariant_op_in_chain =
+        !debug_options.xla_gpu_enable_approx_costly_collectives();
     CollectivePipeliner::Config config{
         /*level_to_operate_on=*/0,
         /*max_pipelining_per_loop=*/INT64_MAX,
@@ -885,7 +891,8 @@ absl::Status RunCollectiveOptimizationPasses(
         /*should_allow_control_dependencies=*/false,
         /*postprocess_backward_peeled_op=*/std::nullopt,
         /*postprocess_backward_rotated_op=*/std::nullopt,
-        /*acceptable_loop_invariant_op_in_chain=*/true};
+        acceptable_loop_invariant_op_in_chain,
+    };
     collectives_pipeline.AddPass<CollectivePipeliner>(config);
   }
   if (debug_options.xla_gpu_enable_pipelined_collectives() ||
32 changes: 30 additions & 2 deletions third_party/xla/xla/service/gpu/gpu_hlo_schedule.cc
@@ -71,6 +71,12 @@ namespace gpu {
 
 namespace {
 
+// A threshold (in bytes) above which we consider an AR costly perf-wise.
+static constexpr int64_t kCostlyAllReduceThreshold = 30 * 1024 * 1024;
+
+// Multiplier applied to expand the base cost for a costly AR.
+static constexpr int64_t kCostlyAllReduceMultiplier = 4;
+
 bool IsNopInstruction(const HloInstruction& hlo) {
   HloOpcode op = hlo.opcode();
   return op == HloOpcode::kGetTupleElement || op == HloOpcode::kBitcast ||
@@ -550,8 +556,9 @@ class GpuAsyncTracker : public GpuAsyncTrackerBase {
 class GpuLatencyEstimator : public ApproximateLatencyEstimator {
  public:
   explicit GpuLatencyEstimator(
+      int64_t pointer_size,
       GetCanonicalAsyncOpFunc func = GpuGetCanonicalAsyncOp)
-      : ApproximateLatencyEstimator(func) {}
+      : ApproximateLatencyEstimator(func), pointer_size_(pointer_size) {}
   TimeCost NodeCost(const HloInstruction* instr) const override {
     if (IsNopInstruction(*instr)) {
       return 0.0;
@@ -582,12 +589,32 @@ class GpuLatencyEstimator : public ApproximateLatencyEstimator {
         return ApproximateLatencyEstimator::kHighLatency * 10;
       }
 
+      bool enable_approx_collectives =
+          from.GetInstr()
+              .GetModule()
+              ->config()
+              .debug_options()
+              .xla_gpu_enable_approx_costly_collectives();
+      bool is_all_reduce =
+          from.GetInstr().opcode() == HloOpcode::kAllReduceStart;
+      bool collective_size_exceeds_threshold =
+          GetSizeOfShape(from.GetInstr().shape(), pointer_size_) >
+          kCostlyAllReduceThreshold;
+      if (enable_approx_collectives && is_all_reduce &&
+          collective_size_exceeds_threshold) {
+        return ApproximateLatencyEstimator::kHighLatency *
+               kCostlyAllReduceMultiplier;
+      }
+
       return ApproximateLatencyEstimator::kHighLatency;
     }
     // Every other instruction we consider synchronous, which means the
     // latency between each of them is always one unit.
     return ApproximateLatencyEstimator::kLowLatency;
   }
+
+ private:
+  int64_t pointer_size_;
 };
 
 tensorflow::profiler::ProfiledInstructionsProto GetProfileForFingerprint(
@@ -828,7 +855,8 @@ absl::StatusOr<ScheduleMetadata> ScheduleGpuModule(
   }
 
   SchedulerConfig config = GetSchedulerConfig(memory_limit);
-  auto gpu_latency_estimator = std::make_unique<GpuLatencyEstimator>();
+  auto gpu_latency_estimator =
+      std::make_unique<GpuLatencyEstimator>(pointer_size);
 
   std::unique_ptr<LatencyEstimator> latency_estimator;
   std::optional<tensorflow::profiler::ProfiledInstructionsProto> profile =
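For readers skimming the hunk above, a condensed standalone restatement of the added rule (threshold and multiplier copied from this file; kHighLatency's real value lives in ApproximateLatencyEstimator, so 5000.0 below is only a stand-in, and the helper takes the shape's byte size directly instead of calling GetSizeOfShape):

    #include <cstdint>

    // Sketch of the new latency rule: an all-reduce-start whose result
    // exceeds ~30 MiB is charged 4x the usual high latency, so the LHS
    // scheduler tries to hide more independent work behind it.
    double ApproxCollectiveLatency(bool is_all_reduce_start,
                                   int64_t shape_bytes, bool approx_enabled) {
      constexpr double kHighLatency = 5000.0;  // stand-in value
      constexpr int64_t kCostlyAllReduceThreshold = 30 * 1024 * 1024;
      constexpr int64_t kCostlyAllReduceMultiplier = 4;
      if (approx_enabled && is_all_reduce_start &&
          shape_bytes > kCostlyAllReduceThreshold) {
        return kHighLatency * kCostlyAllReduceMultiplier;
      }
      return kHighLatency;
    }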
56 changes: 56 additions & 0 deletions third_party/xla/xla/service/gpu/gpu_hlo_schedule_test.cc
@@ -338,6 +338,62 @@ TEST_F(GpuHloScheduleTest, LHSCostModel) {
   EXPECT_TRUE(HasValidFingerprint(module.get()));
 }
 
+TEST_F(GpuHloScheduleTest, LHSCostModelCostlyAR) {
+  const char* hlo_text = R"(
+  HloModule AsyncAR
+
+  apply_op {
+    x = bf16[] parameter(0)
+    y = bf16[] parameter(1)
+    ROOT apply_op = bf16[] add(x, y)
+  }
+
+  ENTRY ar {
+    p0 = bf16[32505856] parameter(0)
+    p1 = f32[32, 32] parameter(1)
+    p2 = f32[32, 32] parameter(2)
+    dot0 = f32[32,32]{1,0} custom-call(p1, p2), custom_call_target="__cublas$gemm"
+    dot1 = f32[32,32]{1,0} custom-call(dot0, p2), custom_call_target="__cublas$gemm"
+    dot2 = f32[32,32]{1,0} custom-call(dot1, p2), custom_call_target="__cublas$gemm"
+    dot3 = f32[32,32]{1,0} custom-call(dot2, p2), custom_call_target="__cublas$gemm"
+    dot4 = f32[32,32]{1,0} custom-call(dot3, p2), custom_call_target="__cublas$gemm"
+    dot5 = f32[32,32]{1,0} custom-call(dot4, p2), custom_call_target="__cublas$gemm"
+    dot6 = f32[32,32]{1,0} custom-call(dot5, p2), custom_call_target="__cublas$gemm"
+
+    ar-start = bf16[32505856] all-reduce-start(p0), to_apply=apply_op
+    ar-done = bf16[32505856] all-reduce-done(ar-start)
+
+    ROOT t = (bf16[32505856], f32[32,32]) tuple(ar-done, dot6)
+  })";
+
+  TF_ASSERT_OK_AND_ASSIGN(
+      auto module,
+      ParseAndReturnVerifiedModule(
+          hlo_text, GetModuleConfig(/*enable_latency_hiding_scheduler=*/true)));
+  SequentialHloOrdering order = BuildHloOrdering(module.get());
+
+  HloComputation* entry = module->entry_computation();
+  std::vector<int64_t> count_between_pairs;
+  bool in_between = false;
+  for (const HloInstruction* inst :
+       order.SequentialOrder(*entry)->instructions()) {
+    if (inst->opcode() == HloOpcode::kAllReduceStart) {
+      in_between = true;
+      count_between_pairs.push_back(0);
+    } else if (inst->opcode() == HloOpcode::kAllReduceDone) {
+      in_between = false;
+    } else if (in_between && inst->opcode() == HloOpcode::kCustomCall) {
+      count_between_pairs.back()++;
+    }
+  }
+
+  EXPECT_EQ(count_between_pairs.size(), 1);
+  // We pack 7 medium-cost operations in with the costly AR; by default we
+  // pack in at most 5.
+  EXPECT_EQ(count_between_pairs[0], 7);
+  EXPECT_TRUE(HasValidFingerprint(module.get()));
+}
+
 TEST_F(GpuHloScheduleTest, ProfileGuidedCostModel) {
   const char* hlo_text = R"(
 HloModule AsyncAR
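Why the all-reduce in LHSCostModelCostlyAR qualifies as costly: a back-of-the-envelope check, runnable on its own (assumes bf16 is 2 bytes per element, which matches XLA's bf16):

    #include <cstdint>
    #include <iostream>

    int main() {
      // Threshold and multiplier as defined in gpu_hlo_schedule.cc above.
      constexpr int64_t kCostlyAllReduceThreshold = 30 * 1024 * 1024;  // 31,457,280 B
      constexpr int64_t kCostlyAllReduceMultiplier = 4;
      // bf16[32505856]: 32,505,856 elements * 2 bytes each.
      constexpr int64_t kArBytes = 32505856LL * 2;  // 65,011,712 B, ~62 MiB
      std::cout << "costly: " << (kArBytes > kCostlyAllReduceThreshold)  // 1
                << ", multiplier: " << kCostlyAllReduceMultiplier << "\n";
      return 0;
    }

Under the 4x charge the scheduler finds it profitable to place all seven cuBLAS custom-calls between ar-start and ar-done; with the flat estimate it packs in at most five.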
4 changes: 3 additions & 1 deletion third_party/xla/xla/xla.proto
@@ -805,7 +805,9 @@ message DebugOptions {
 
   bool xla_gpu_shard_autotuning = 304;
 
-  // Next id: 305
+  bool xla_gpu_enable_approx_costly_collectives = 305;
+
+  // Next id: 306
 
   // Extra options to pass to the compilation backend (e.g. LLVM); specific
   // interpretation of these values is left to the backend.
