[XLA:GPU] Use Cost Model to choose tile sizes in SoftmaxRewriterTriton.
After this change, all SoftMax fusions created in the rewriter should have a BlockLevelFusionConfig with the tile parameters to be used in the emitter.

PiperOrigin-RevId: 645033335
olegshyshkov authored and tensorflower-gardener committed Jun 20, 2024
1 parent 93fcc5e commit f05e433
Showing 7 changed files with 284 additions and 148 deletions.
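The gpu_compiler.cc hunk below swaps the pass's constructor argument from a bare se::GpuComputeCapability to the full se::DeviceDescription plus a shape-size function, so the rewriter can run the cost model. A minimal sketch of wiring the pass into a pipeline under these assumptions; the helper AddSoftmaxRewriter and the 8-byte pointer size are illustrative, while the real GpuCompiler passes its own ShapeSizeBytesFunction():

```cpp
#include "xla/service/gpu/softmax_rewriter_triton.h"
#include "xla/service/hlo_cost_analysis.h"
#include "xla/service/hlo_pass_pipeline.h"
#include "xla/shape_util.h"
#include "xla/stream_executor/device_description.h"

// Sketch only: registers the rewriter with a device description and a
// shape-size callback, mirroring the gpu_compiler.cc change below.
void AddSoftmaxRewriter(xla::HloPassPipeline& pipeline,
                        const stream_executor::DeviceDescription& device_info) {
  // A plain ShapeUtil::ByteSizeOf with an assumed 8-byte pointer size stands
  // in for GpuCompiler::ShapeSizeBytesFunction() here.
  xla::HloCostAnalysis::ShapeSizeFunction shape_size =
      [](const xla::Shape& shape) {
        return xla::ShapeUtil::ByteSizeOf(shape, /*pointer_size=*/8);
      };
  pipeline.AddPass<xla::gpu::SoftmaxRewriterTriton>(device_info, shape_size);
}
```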
9 changes: 8 additions & 1 deletion third_party/xla/xla/service/gpu/BUILD
@@ -1555,6 +1555,7 @@ cc_library(
hdrs = ["softmax_rewriter_triton.h"],
deps = [
":backend_configs_cc",
":hlo_traversal",
":ir_emission_utils",
":triton_support",
"//xla:shape_util",
@@ -1563,9 +1564,13 @@
"//xla:xla_data_proto_cc",
"//xla/hlo/ir:hlo",
"//xla/hlo/utils:hlo_query",
"//xla/service:hlo_cost_analysis",
"//xla/service:hlo_pass",
"//xla/service:instruction_fusion",
"//xla/service/gpu/model:fusion_analysis_cache",
"//xla/service/gpu/model:gpu_indexing_performance_model",
"//xla/service/gpu/model:symbolic_tile_analysis",
"//xla/service/gpu/model:tiled_hlo_computation",
"//xla/stream_executor:device_description",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_map",
@@ -2585,12 +2590,14 @@ xla_cc_test(
name = "softmax_rewriter_triton_test",
srcs = ["softmax_rewriter_triton_test.cc"],
deps = [
":backend_configs_cc",
":gpu_device_info_for_tests",
":softmax_rewriter_triton",
"//xla:shape_util",
"//xla/hlo/ir:hlo",
"//xla/service:instruction_fusion",
"//xla/service:pattern_matcher",
"//xla/service:pattern_matcher_gmock",
"//xla/service/gpu/model:gpu_hlo_cost_analysis",
"//xla/stream_executor:device_description",
"//xla/tests:hlo_test_base",
"//xla/tests:xla_internal_test_main", # build_cleaner: keep
3 changes: 2 additions & 1 deletion third_party/xla/xla/service/gpu/gpu_compiler.cc
@@ -1380,7 +1380,8 @@ absl::Status GpuCompiler::OptimizeHloPostLayoutAssignment(
cuda_cc->IsAtLeast(se::CudaComputeCapability::AMPERE)) {
pipeline.AddPass<HloPassFix<GpuAlgebraicSimplifier>>(simplifier_options,
gpu_version);
pipeline.AddPass<SoftmaxRewriterTriton>(gpu_version);
pipeline.AddPass<SoftmaxRewriterTriton>(
gpu_target_config.device_description, ShapeSizeBytesFunction());
}

pipeline.AddPass<ReductionDimensionGrouper>();
third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.cc
@@ -388,7 +388,7 @@ LaunchDimensions GetLaunchDimensionsForTiledFusion(
static_cast<uint64_t>(num_warps * WarpSize())};
}

absl::StatusOr<std::variant<TiledRunTimeData, FusionDecision>>
absl::StatusOr<TiledRunTimeDataOrError>
GpuPerformanceModelWithIndexingAnalysis::TryFindBestTilingForFusion(
const HloFusionAdaptor& fusion_adaptor) {
SymbolicTileAnalysisOrError analysis_or_error =
third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.h
@@ -47,6 +47,8 @@ struct TiledRunTimeData {
BlockLevelParameters block_level_parameters;
};

using TiledRunTimeDataOrError = std::variant<TiledRunTimeData, FusionDecision>;

// Implementation of Cost Model that uses indexing analysis to estimate amount
// of compute and memory access time.
class GpuPerformanceModelWithIndexingAnalysis : public GpuPerformanceModelBase {
@@ -105,8 +107,8 @@ class GpuPerformanceModelWithIndexingAnalysis : public GpuPerformanceModelBase {
// Returns FusionDecision if the fusion can't be tiled or there are no valid
// block level parameters.
// Otherwise returns block level parameters that give the best execution time.
absl::StatusOr<std::variant<TiledRunTimeData, FusionDecision>>
TryFindBestTilingForFusion(const HloFusionAdaptor& fusion_adaptor);
absl::StatusOr<TiledRunTimeDataOrError> TryFindBestTilingForFusion(
const HloFusionAdaptor& fusion_adaptor);

private:
// Returns an estimate how many FLOPs will be used to produce one element of
65 changes: 51 additions & 14 deletions third_party/xla/xla/service/gpu/softmax_rewriter_triton.cc
@@ -14,6 +14,7 @@ limitations under the License.

#include <functional>
#include <string>
#include <utility>
#include <variant>
#include <vector>

@@ -35,8 +36,12 @@ limitations under the License.
#include "xla/hlo/utils/hlo_query.h"
#include "xla/layout_util.h"
#include "xla/service/gpu/backend_configs.pb.h"
#include "xla/service/gpu/hlo_traversal.h"
#include "xla/service/gpu/ir_emission_utils.h"
#include "xla/service/gpu/model/fusion_analysis_cache.h"
#include "xla/service/gpu/model/gpu_indexing_performance_model.h"
#include "xla/service/gpu/model/symbolic_tile_analysis.h"
#include "xla/service/gpu/model/tiled_hlo_computation.h"
#include "xla/service/gpu/triton_support.h"
#include "xla/service/instruction_fusion.h"
#include "xla/shape.h"
@@ -397,11 +402,35 @@ absl::StatusOr<HloFusionInstruction*> MakeFusionForDiamondChain(
return xla::Cast<HloFusionInstruction>(softmax_fusion);
}

absl::Status FuseDiamondChainImpl(const DiamondChainDescriptor& diamond_chain) {
absl::Status FuseDiamondChainImpl(
const DiamondChainDescriptor& diamond_chain,
GpuPerformanceModelWithIndexingAnalysis& indexing_performance_model) {
TF_ASSIGN_OR_RETURN(HloFusionInstruction * softmax_fusion,
MakeFusionForDiamondChain(diamond_chain));
HloInstruction* root = diamond_chain.root;

auto fusion_adaptor = HloFusionAdaptor::ForInstruction(softmax_fusion);

TF_ASSIGN_OR_RETURN(
TiledRunTimeDataOrError tiled_runtime_data_or,
indexing_performance_model.TryFindBestTilingForFusion(*fusion_adaptor));

if (const auto* fusion_decision =
std::get_if<FusionDecision>(&tiled_runtime_data_or)) {
return absl::FailedPreconditionError(absl::StrCat(
"SymbolicTileAnalysis failed. ", fusion_decision->Explain()));
}

TiledRunTimeData tiled_runtime_data =
std::get<TiledRunTimeData>(std::move(tiled_runtime_data_or));

TF_ASSIGN_OR_RETURN(auto backend_config,
softmax_fusion->backend_config<GpuBackendConfig>());
*backend_config.mutable_fusion_backend_config()
->mutable_block_level_fusion_config() =
tiled_runtime_data.block_level_parameters.ToBlockLevelFusionConfig();
TF_RETURN_IF_ERROR(softmax_fusion->set_backend_config(backend_config));

if (root->IsRoot()) {
root->parent()->set_root_instruction(softmax_fusion);
TF_RETURN_IF_ERROR(
@@ -446,7 +475,8 @@ SoftmaxRewriterTriton::MatchesTritonCompatibleClosedReductionDiamond(
return "Root is not elementwise binary.";
}

if (!legacy_triton::IsTritonSupportedInstruction(*instr, gpu_version_)) {
if (!legacy_triton::IsTritonSupportedInstruction(
*instr, device_info_.gpu_compute_capability())) {
return "Root is not supported for Triton instruction.";
}

@@ -455,12 +485,12 @@
HloInstruction* reduce;

if (!TrivialEdge(&broadcast, instr->mutable_operand(1), HloOpcode::kBroadcast,
gpu_version_)) {
device_info_.gpu_compute_capability())) {
return "Could not find a trivial connection from root to a broadcast.";
}

if (!TrivialEdge(&reduce, broadcast->mutable_operand(0), HloOpcode::kReduce,
gpu_version_)) {
device_info_.gpu_compute_capability())) {
return "Could not find a trivial connection from matched broadcast to a "
"reduction.";
}
@@ -471,7 +501,8 @@ SoftmaxRewriterTriton::MatchesTritonCompatibleClosedReductionDiamond(
}

if (CodegenDecision is_supported =
legacy_triton::IsTritonSupportedInstruction(*reduce, gpu_version_);
legacy_triton::IsTritonSupportedInstruction(
*reduce, device_info_.gpu_compute_capability());
!is_supported) {
VLOG(3) << is_supported.Explain();
return is_supported;
@@ -488,7 +519,7 @@ SoftmaxRewriterTriton::MatchesTritonCompatibleClosedReductionDiamond(
return "Broadcast is not along the reduction dimension.";
}

while (IsTriviallyFusible(producer, gpu_version_)) {
while (IsTriviallyFusible(producer, device_info_.gpu_compute_capability())) {
producer = ChooseOperandForFusionProcessing(producer);
}

@@ -497,7 +528,7 @@ SoftmaxRewriterTriton::MatchesTritonCompatibleClosedReductionDiamond(
}

if (!IsTriviallyConnectedProducerOf(producer, instr->mutable_operand(0),
gpu_version_)) {
device_info_.gpu_compute_capability())) {
return "Producer is not trivially connected.";
}

@@ -579,7 +610,8 @@ SoftmaxRewriterTriton::FindAllFusibleDiamondChains(

auto last_trivially_fusible_user = [&](HloInstruction* instr) {
while (HasOneUse(instr) && !instr->IsRoot() &&
IsTriviallyFusible(instr->users().front(), gpu_version_)) {
IsTriviallyFusible(instr->users().front(),
device_info_.gpu_compute_capability())) {
instr = instr->users().front();
}

@@ -588,7 +620,7 @@ SoftmaxRewriterTriton::FindAllFusibleDiamondChains(
// restriction.
if (HasOneUse(instr) && !instr->IsRoot() &&
IsTriviallyFusible(
instr->users().front(), gpu_version_,
instr->users().front(), device_info_.gpu_compute_capability(),
/*num_allowed_users=*/instr->users().front()->user_count())) {
instr = instr->users().front();
}
@@ -615,7 +647,7 @@ SoftmaxRewriterTriton::FindAllFusibleDiamondChains(
diamond_chains.reserve(matched_diamonds.size());

HloInstruction* current_fusion_producer = FindFirstNonFusibleDiamondProducer(
matched_diamonds.front().producer, gpu_version_);
matched_diamonds.front().producer, device_info_.gpu_compute_capability());
int current_reduce_dimension_size =
reduction_dimension_size_from_diamond_root(matched_diamonds.front().root);

@@ -626,7 +658,8 @@ SoftmaxRewriterTriton::FindAllFusibleDiamondChains(
matched_diamonds[diamond_idx - 1].root;

HloInstruction* first_non_fusible_diamond_producer =
FindFirstNonFusibleDiamondProducer(diamond_producer, gpu_version_);
FindFirstNonFusibleDiamondProducer(
diamond_producer, device_info_.gpu_compute_capability());

int diamond_reduce_dimension_size =
reduction_dimension_size_from_diamond_root(diamond_root);
@@ -678,14 +711,18 @@ SoftmaxRewriterTriton::FindAllFusibleDiamondChains(

absl::Status SoftmaxRewriterTriton::FuseDiamondChain(
const DiamondChainDescriptor& diamond_chain) {
return FuseDiamondChainImpl(diamond_chain);
HloFusionAnalysisCache fusion_analysis_cache(device_info_);
GpuPerformanceModelWithIndexingAnalysis indexing_performance_model(
&device_info_, &fusion_analysis_cache, shape_size_, &mlir_context_);

return FuseDiamondChainImpl(diamond_chain, indexing_performance_model);
}

absl::StatusOr<bool> SoftmaxRewriterTriton::Run(
HloModule* module,
const absl::flat_hash_set<absl::string_view>& execution_threads) {
auto cuda_compute_capability =
std::get_if<se::CudaComputeCapability>(&gpu_version_);
auto cuda_compute_capability = std::get_if<se::CudaComputeCapability>(
&device_info_.gpu_compute_capability());
if (!cuda_compute_capability) {
return absl::FailedPreconditionError(
"Triton support is only enabled for CUDA GPUs.");
12 changes: 9 additions & 3 deletions third_party/xla/xla/service/gpu/softmax_rewriter_triton.h
@@ -23,8 +23,10 @@ limitations under the License.
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/service/hlo_cost_analysis.h"
#include "xla/service/hlo_pass_interface.h"
#include "xla/service/instruction_fusion.h"
#include "xla/stream_executor/device_description.h"
@@ -43,8 +45,10 @@ using DiamondMatchingDecision = std::variant<FusionDecision, HloInstruction*>;
// with the Triton-based Softmax emitter.
class SoftmaxRewriterTriton : public HloModulePass {
public:
explicit SoftmaxRewriterTriton(se::GpuComputeCapability gpu_version)
: gpu_version_(gpu_version) {}
explicit SoftmaxRewriterTriton(const se::DeviceDescription& device_info,
HloCostAnalysis::ShapeSizeFunction shape_size)
: device_info_(device_info), shape_size_(shape_size) {}

absl::string_view name() const override { return "triton-softmax-rewriter"; }

using HloPassInterface::Run;
@@ -86,7 +90,9 @@ class SoftmaxRewriterTriton : public HloModulePass {
HloInstruction* instr) const;

private:
se::GpuComputeCapability gpu_version_;
const se::DeviceDescription& device_info_;
const HloCostAnalysis::ShapeSizeFunction shape_size_;
mlir::MLIRContext mlir_context_;
};

} // namespace gpu
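The rewritten fusions now carry their tile parameters in the backend config set by FuseDiamondChainImpl above. A hedged sketch, not part of this commit and not the collapsed test file, of how a caller could check for that config on a resulting fusion, assuming the usual proto getters that match the mutable_* setters used in the .cc diff:

```cpp
#include "xla/hlo/ir/hlo_instructions.h"
#include "xla/service/gpu/backend_configs.pb.h"

// Sketch: reads back the BlockLevelFusionConfig that FuseDiamondChainImpl
// attaches to the softmax fusion's backend config.
bool HasTileParameters(const xla::HloFusionInstruction& fusion) {
  auto gpu_config = fusion.backend_config<xla::gpu::GpuBackendConfig>();
  if (!gpu_config.ok()) return false;
  return gpu_config->fusion_backend_config().has_block_level_fusion_config();
}
```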