[XLA:GPU] Use priority fusion in TritonGemmAutotunerExtractor.
PiperOrigin-RevId: 644336930
olegshyshkov authored and tensorflower-gardener committed Jun 18, 2024
1 parent 19b82d6 commit 4b12044
Showing 2 changed files with 26 additions and 17 deletions.
third_party/xla/xla/service/gpu/BUILD: 3 additions & 0 deletions
@@ -836,6 +836,9 @@ cc_library(
         "@local_tsl//tsl/profiler/lib:scoped_annotation",
         "//xla/tsl/util/proto:proto_utils",
         "//xla/service/gpu:hlo_traversal",
+        ":fusion_wrapper",
+        ":priority_fusion",
+        "//xla/service/gpu/model:gpu_hlo_cost_analysis",
         "//xla/stream_executor:stream_executor_memory_allocator",
         "@local_tsl//tsl/platform:path",
     ]),
third_party/xla/xla/service/gpu/gemm_fusion_autotuner.cc: 23 additions & 17 deletions
@@ -62,13 +62,15 @@ limitations under the License.
 #include "xla/service/gpu/backend_configs.pb.h"
 #include "xla/service/gpu/buffer_comparator.h"
 #include "xla/service/gpu/cudnn_fusion_compiler.h"
+#include "xla/service/gpu/fusion_wrapper.h"
 #include "xla/service/gpu/gemm_rewriter.h"
 #include "xla/service/gpu/gpu_float_support.h"
-#include "xla/service/gpu/gpu_fusible.h"
 #include "xla/service/gpu/hlo_traversal.h"
 #include "xla/service/gpu/instruction_fusion.h"
 #include "xla/service/gpu/ir_emission_utils.h"
 #include "xla/service/gpu/matmul_utils.h"
+#include "xla/service/gpu/model/gpu_hlo_cost_analysis.h"
+#include "xla/service/gpu/priority_fusion.h"
 #include "xla/service/gpu/split_k_gemm_rewriter.h"
 #include "xla/service/gpu/stream_executor_util.h"
 #include "xla/service/hlo_module_config.h"
@@ -355,22 +357,26 @@ absl::StatusOr<std::unique_ptr<HloModule>> TritonGemmAutotuneExtractor(
                                  BF16);
     FloatNormalization float_normalization(&bf16_support);
     TF_RETURN_IF_ERROR(float_normalization.Run(new_module.get()).status());
-    GpuInstructionFusion instruction_fusion(/*may_duplicate=*/false,
-                                            gpu_device_info);
-    TF_RETURN_IF_ERROR(instruction_fusion.Run(new_module.get()).status());
-    HloInstruction* root = entry_computation->root_instruction();
-    // If the instruction fusion pass above skipped the reduction, turn it
-    // into a fusion for a universal set of arguments for execution.
-    if (root->opcode() == HloOpcode::kReduce) {
-      HloInstruction* fusion_instruction =
-          entry_computation->AddInstruction(HloInstruction::CreateFusion(
-              root->shape(), ChooseFusionKind(*root, *root), root));
-      HloInstruction* init_value = root->mutable_operand(1);
-      TF_CHECK_OK(
-          entry_computation->ReplaceInstruction(root, fusion_instruction));
-      fusion_instruction->FuseInstruction(init_value);
-      TF_CHECK_OK(entry_computation->RemoveInstruction(init_value));
-    }
+
+    auto shape_size_function = [&](const Shape& shape) {
+      // The real pointer size is set in GpuCompiler. In HloCostAnalysis, the
+      // pointer size is used only to determine the size of tuple types. We
+      // shouldn't have any tuples in the autotuned module, so it's safe to use
+      // a constant here, instead of piping the real value.
+      constexpr int64_t kPointerSize = 8;
+      return ShapeUtil::ByteSizeOf(shape, kPointerSize);
+    };
+    GpuPriorityFusion priority_fusion(
+        /*thread_pool=*/nullptr, gpu_device_info,
+        GpuHloCostAnalysis::Options{/*shape_size=*/shape_size_function,
+                                    /*per_second_rates=*/{},
+                                    /*count_multiple_input_accesses=*/true});
+    TF_RETURN_IF_ERROR(priority_fusion.Run(new_module.get()).status());
+
+    // If the priority fusion pass above skipped some instructions, turn them
+    // into fusions.
+    FusionWrapper fusion_wrapper;
+    TF_RETURN_IF_ERROR(fusion_wrapper.Run(new_module.get()).status());
   }
   return new_module;
 }
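
In short, the extractor previously ran GpuInstructionFusion and then hand-wrapped a leftover root reduction into a fusion; it now runs GpuPriorityFusion followed by FusionWrapper, presumably so the autotuned module is fused the same way the main compilation pipeline would fuse it. The sketch below collects the new fusion sequence into one standalone helper, assembled from the diff above; the helper name FuseExtractedModule, the free-function framing, and the se::DeviceDescription parameter type are illustrative assumptions, not part of the commit.

#include "absl/status/status.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/service/gpu/fusion_wrapper.h"
#include "xla/service/gpu/model/gpu_hlo_cost_analysis.h"
#include "xla/service/gpu/priority_fusion.h"
#include "xla/shape_util.h"
#include "xla/stream_executor/device_description.h"
#include "tsl/platform/errors.h"

namespace xla::gpu {

// Hypothetical helper mirroring the new TritonGemmAutotuneExtractor logic:
// fuse the extracted module with the priority-based heuristic, then wrap any
// instructions the pass skipped into trivial fusions.
absl::Status FuseExtractedModule(HloModule* module,
                                 const se::DeviceDescription& gpu_device_info) {
  auto shape_size_function = [](const Shape& shape) {
    // The pointer size only affects tuple sizes in HloCostAnalysis, and the
    // autotuned module is not expected to contain tuples, so a constant is
    // sufficient here.
    constexpr int64_t kPointerSize = 8;
    return ShapeUtil::ByteSizeOf(shape, kPointerSize);
  };
  GpuPriorityFusion priority_fusion(
      /*thread_pool=*/nullptr, gpu_device_info,
      GpuHloCostAnalysis::Options{/*shape_size=*/shape_size_function,
                                  /*per_second_rates=*/{},
                                  /*count_multiple_input_accesses=*/true});
  TF_RETURN_IF_ERROR(priority_fusion.Run(module).status());

  // Turn any remaining unfused instructions into fusions.
  FusionWrapper fusion_wrapper;
  return fusion_wrapper.Run(module).status();
}

}  // namespace xla::gpu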