[go: nahoru, domu]

Skip to content

Commit

Permalink
[XLA:GPU] Add flag for fully unrolling while loops
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 633196820
  • Loading branch information
golechwierowicz authored and tensorflower-gardener committed May 13, 2024
1 parent ea106b7 commit 676078f
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 23 deletions.
2 changes: 2 additions & 0 deletions third_party/xla/xla/debug_options_flags.cc
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,8 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() {
opts.set_xla_gpu_cublas_fallback(true);
opts.set_xla_gpu_cudnn_gemm_fusion_level(0);
opts.set_xla_gpu_enable_while_loop_double_buffering(false);
opts.set_xla_gpu_enable_while_loop_unrolling(
DebugOptions::WHILE_LOOP_UNROLLING_NO_UNROLL);
opts.set_xla_gpu_ensure_minor_dot_contraction_dims(false);
opts.set_xla_gpu_filter_kernels_spilling_registers_on_autotuning(true);
opts.set_xla_gpu_llvm_verification_level(0);
Expand Down
53 changes: 31 additions & 22 deletions third_party/xla/xla/service/gpu/gpu_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1066,36 +1066,27 @@ absl::Status RunPostFusionPasses(
HloModule* hlo_module,
std::function<absl::Status(HloPassPipeline*, const DebugOptions&)>
add_custom_kernel_replacement_passes) {
const DebugOptions& opts = hlo_module->config().debug_options();

HloPassPipeline pipeline("post-fusion optimization");
pipeline.AddPass<RenameFusions>();
pipeline.AddPass<AllGatherCombiner>(
hlo_module->config()
.debug_options()
.xla_gpu_all_gather_combine_threshold_bytes(),
opts.xla_gpu_all_gather_combine_threshold_bytes(),
/*combine_threshold_count=*/256,
hlo_module->config()
.debug_options()
.xla_gpu_enable_all_gather_combine_by_dim());
opts.xla_gpu_enable_all_gather_combine_by_dim());
pipeline.AddPass<AllReduceCombiner>(
hlo_module->config()
.debug_options()
.xla_gpu_all_reduce_combine_threshold_bytes(),
opts.xla_gpu_all_reduce_combine_threshold_bytes(),
/*combine_threshold_count=*/256);
pipeline.AddPass<ReduceScatterCombiner>(
hlo_module->config()
.debug_options()
.xla_gpu_reduce_scatter_combine_threshold_bytes(),
opts.xla_gpu_reduce_scatter_combine_threshold_bytes(),
/*combine_threshold_count=*/256,
hlo_module->config()
.debug_options()
.xla_gpu_enable_reduce_scatter_combine_by_dim());
opts.xla_gpu_enable_reduce_scatter_combine_by_dim());

if (hlo_module->config().debug_options().xla_gpu_all_reduce_contiguous()) {
if (opts.xla_gpu_all_reduce_contiguous()) {
pipeline.AddPass<AllReduceContiguous>();
}

TF_RETURN_IF_ERROR(add_custom_kernel_replacement_passes(
&pipeline, hlo_module->config().debug_options()));
TF_RETURN_IF_ERROR(add_custom_kernel_replacement_passes(&pipeline, opts));

int32_t blueconnect_num_devices_per_host =
hlo_module->config()
Expand All @@ -1105,10 +1096,28 @@ absl::Status RunPostFusionPasses(
pipeline.AddPass<AllReduceBlueConnect>(blueconnect_num_devices_per_host);
}

if (hlo_module->config()
.debug_options()
.xla_gpu_enable_while_loop_double_buffering()) {
pipeline.AddPass<LoopDoubleBufferTransformer>();
std::optional<LoopDoubleBufferTransformer::UnrollStrategy> unroll_strategy =
std::nullopt;
// Support old flag.
if (opts.xla_gpu_enable_while_loop_double_buffering()) {
unroll_strategy =
LoopDoubleBufferTransformer::UnrollStrategy::kDoubleBuffer;
}
// Support new flag setting style, override the old one.
if (opts.xla_gpu_enable_while_loop_unrolling() ==
DebugOptions::WHILE_LOOP_UNROLLING_DOUBLE_BUFFER) {
unroll_strategy =
LoopDoubleBufferTransformer::UnrollStrategy::kDoubleBuffer;
}
if (opts.xla_gpu_enable_while_loop_unrolling() ==
DebugOptions::WHILE_LOOP_UNROLLING_FULL_UNROLL) {
LOG_IF(WARNING, unroll_strategy != std::nullopt)
<< "Overriding double buffering set via "
"`xla_gpu_enable_while_loop_double_buffering` flag.";
unroll_strategy = LoopDoubleBufferTransformer::UnrollStrategy::kFullUnroll;
}
if (unroll_strategy != std::nullopt) {
pipeline.AddPass<LoopDoubleBufferTransformer>(*unroll_strategy);
pipeline.AddPass<TupleSimplifier>();
pipeline.AddPass<HloDCE>();
}
Expand Down
14 changes: 13 additions & 1 deletion third_party/xla/xla/xla.proto
Original file line number Diff line number Diff line change
Expand Up @@ -668,6 +668,18 @@ message DebugOptions {
// Enable double buffering for loops.
bool xla_gpu_enable_while_loop_double_buffering = 248;

// Selects how (and whether) the GPU compiler unrolls `while` loops.
// Successor to the boolean `xla_gpu_enable_while_loop_double_buffering`
// flag: when set to `FULL_UNROLL`, this enum overrides that legacy flag
// (a warning is logged if both are set).
enum WhileLoopUnrolling {
// Leave while loops as-is; no unrolling is performed.
WHILE_LOOP_UNROLLING_NO_UNROLL = 0;
// Has the same effect as setting
// `xla_gpu_enable_while_loop_double_buffering`.
WHILE_LOOP_UNROLLING_DOUBLE_BUFFER = 1;
// Enables full loop unrolling using the same strategy as `DOUBLE_BUFFER`.
WHILE_LOOP_UNROLLING_FULL_UNROLL = 2;
}

// Determines the while-loop unrolling scheme.
WhileLoopUnrolling xla_gpu_enable_while_loop_unrolling = 294;

// Change the layout of the second triton dot operand to be column major.
// Only works for (bf16 x bf16) -> bf16.
bool xla_gpu_ensure_minor_dot_contraction_dims = 249;
Expand Down Expand Up @@ -767,7 +779,7 @@ message DebugOptions {
// Base length to rewrite the reduce window to, no rewrite if set to 0.
int64 xla_reduce_window_rewrite_base_length = 293;

// Next id: 294
// Next id: 295

// Extra options to pass to the compilation backend (e.g. LLVM); specific
// interpretation of these values is left to the backend.
Expand Down

0 comments on commit 676078f

Please sign in to comment.