[XLA:GPU] Add a method to the Cost Model to estimate the best tiling for a fusion.

The idea is to iterate over all the tiling options from `SymbolicTileAnalysis::GetGoodTilings()`, run the Cost Model on each tiling, and choose the one with the best execution time.

This method will be used in TritonSoftMaxRewriter and PriorityFusion to assign a block-level parameters config to fusions.
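
As an illustration (editorial note, not part of this commit), a caller in one of those passes might consume the new method roughly as sketched below; `indexing_cost_model` and `fusion_instruction` are assumed to already exist in the pass.

// Editorial sketch only, under the assumptions stated above.
auto fusion_adaptor = HloFusionAdaptor::ForInstruction(fusion_instruction);

TF_ASSIGN_OR_RETURN(
    auto tiling_result,
    indexing_cost_model.TryFindBestTilingForFusion(*fusion_adaptor));

if (const auto* decision = std::get_if<FusionDecision>(&tiling_result)) {
  // The fusion can't be tiled, or no valid tiling was found; fall back.
  VLOG(2) << "No block-level tiling: " << decision->Explain();
} else {
  const auto& data = std::get<TiledRunTimeData>(tiling_result);
  // data.block_level_parameters.output_tile_sizes and .num_warps describe the
  // best tiling found; data.runtime_data.exec_time is its estimated run time.
}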

PiperOrigin-RevId: 644763688
olegshyshkov authored and tensorflower-gardener committed Jun 19, 2024
1 parent 8e5c47c commit d1766d9
Showing 4 changed files with 196 additions and 18 deletions.
2 changes: 2 additions & 0 deletions third_party/xla/xla/service/gpu/model/BUILD
@@ -357,6 +357,7 @@ cc_library(
"//xla/service/gpu:backend_configs_cc",
"//xla/service/gpu:hlo_fusion_analysis",
"//xla/service/gpu:hlo_traversal",
"//xla/service/gpu:ir_emission_utils",
"//xla/service/gpu:launch_dimensions",
"//xla/service/gpu/fusions:triton",
"//xla/stream_executor:device_description",
@@ -386,6 +387,7 @@ xla_cc_test(
"//xla/hlo/ir:hlo",
"//xla/service/gpu:backend_configs_cc",
"//xla/service/gpu:gpu_device_info_for_tests",
"//xla/service/gpu:hlo_traversal",
"//xla/stream_executor:device_description",
"//xla/tests:hlo_test_base",
"//xla/tests:xla_internal_test_main",
121 changes: 103 additions & 18 deletions third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.cc
@@ -17,6 +17,7 @@ limitations under the License.

#include <algorithm>
#include <cstdint>
#include <optional>
#include <utility>
#include <variant>
#include <vector>
@@ -34,6 +35,7 @@ limitations under the License.
#include "xla/service/gpu/fusions/triton.h"
#include "xla/service/gpu/hlo_fusion_analysis.h"
#include "xla/service/gpu/hlo_traversal.h"
#include "xla/service/gpu/ir_emission_utils.h"
#include "xla/service/gpu/launch_dimensions.h"
#include "xla/service/gpu/model/coalescing_analysis.h"
#include "xla/service/gpu/model/gpu_hlo_cost_analysis.h"
@@ -250,25 +252,11 @@ GpuPerformanceModelWithIndexingAnalysis::EstimateRunTimes(
return {time_unfused, time_fused};
}

absl::StatusOr<EstimateRunTimeData>
GpuPerformanceModelWithIndexingAnalysis::EstimateRunTimeForTiledFusion(
EstimateRunTimeData
GpuPerformanceModelWithIndexingAnalysis::EstimateRunTimeForTiledHloComputation(
const HloFusionAdaptor& fusion_adaptor,
const LaunchDimensions& launch_dimensions,
absl::Span<const int64_t> tile_sizes) {
// TODO(b/332714755): Add caching for SymbolicTileAnalysis.
SymbolicTileAnalysisOrError analysis_or_error =
SymbolicTileAnalysis::AnalyzeFusion(fusion_adaptor, mlir_context_);
if (!std::holds_alternative<SymbolicTileAnalysis>(analysis_or_error)) {
return absl::FailedPreconditionError(
absl::StrCat("SymbolicTileAnalysis failed. ",
std::get<FusionDecision>(analysis_or_error).Explain()));
}
SymbolicTileAnalysis analysis =
std::get<SymbolicTileAnalysis>(std::move(analysis_or_error));

TF_ASSIGN_OR_RETURN(TiledHloComputation tiled_hlo_computation,
analysis.ComputeTiledHloInstructions(tile_sizes));

const TiledHloComputation& tiled_hlo_computation,
const LaunchDimensions& launch_dimensions) {
absl::flat_hash_map<const HloInstruction*, int64_t> n_bytes_total_map;

int64_t flops = 0;
@@ -334,6 +322,29 @@ GpuPerformanceModelWithIndexingAnalysis::EstimateRunTimeForTiledFusion(
/*exec_time=*/exec_time};
}

absl::StatusOr<EstimateRunTimeData>
GpuPerformanceModelWithIndexingAnalysis::EstimateRunTimeForTiledFusion(
const HloFusionAdaptor& fusion_adaptor,
const LaunchDimensions& launch_dimensions,
absl::Span<const int64_t> tile_sizes) {
// TODO(b/332714755): Add caching for SymbolicTileAnalysis.
SymbolicTileAnalysisOrError analysis_or_error =
SymbolicTileAnalysis::AnalyzeFusion(fusion_adaptor, mlir_context_);
if (const auto* fusion_decision =
std::get_if<FusionDecision>(&analysis_or_error)) {
return absl::FailedPreconditionError(absl::StrCat(
"SymbolicTileAnalysis failed. ", fusion_decision->Explain()));
}
SymbolicTileAnalysis analysis =
std::get<SymbolicTileAnalysis>(std::move(analysis_or_error));

TF_ASSIGN_OR_RETURN(TiledHloComputation tiled_hlo_computation,
analysis.ComputeTiledHloInstructions(tile_sizes));

return EstimateRunTimeForTiledHloComputation(
fusion_adaptor, tiled_hlo_computation, launch_dimensions);
}

absl::StatusOr<EstimateRunTimeData>
GpuPerformanceModelWithIndexingAnalysis::EstimateRunTimeForTriton(
const HloInstruction* producer, const HloInstruction* consumer) {
@@ -352,5 +363,79 @@ GpuPerformanceModelWithIndexingAnalysis::EstimateRunTimeForTriton(
launch_config->output_tile_sizes);
}

// Returns the number of warps to use based on the tile size. The numbers were
// originally selected from the Triton SoftMax reduction row length.
// TODO(b/332714755): Make it smarter.
int64_t GetNumWarps(int64_t tile_size) {
if (tile_size <= 512) return 1;
if (tile_size <= 1024) return 2;
if (tile_size <= 16384) return 4;
if (tile_size <= 32768) return 8;
if (tile_size <= 65536) return 16;
return 32;
}

LaunchDimensions GetLaunchDimensionsForTiledFusion(
const TiledHloComputation& tiled_hlo_computation) {
const auto* tiled_root = tiled_hlo_computation.GetRoot();
int64_t num_blocks = tiled_root->block_id_to_tile_offsets_indexing()
.GetDimensionBound(0)
.GetLoopTripCount();

int64_t num_warps = GetNumWarps(Product(tiled_root->tile_sizes()));

return {static_cast<uint64_t>(num_blocks),
static_cast<uint64_t>(num_warps * WarpSize())};
}

absl::StatusOr<std::variant<TiledRunTimeData, FusionDecision>>
GpuPerformanceModelWithIndexingAnalysis::TryFindBestTilingForFusion(
const HloFusionAdaptor& fusion_adaptor) {
SymbolicTileAnalysisOrError analysis_or_error =
SymbolicTileAnalysis::AnalyzeFusion(fusion_adaptor, mlir_context_);

if (const auto* fusion_decision =
std::get_if<FusionDecision>(&analysis_or_error)) {
return *fusion_decision;
}

SymbolicTileAnalysis analysis =
std::get<SymbolicTileAnalysis>(std::move(analysis_or_error));

TF_ASSIGN_OR_RETURN(auto tilings, analysis.GetGoodTilings());

std::optional<TiledRunTimeData> best_tiled_run_time_data;

for (const auto& tiling : tilings) {
TF_ASSIGN_OR_RETURN(TiledHloComputation tiled_hlo_computation,
analysis.ComputeTiledHloInstructions(tiling));

LaunchDimensions launch_dimensions =
GetLaunchDimensionsForTiledFusion(tiled_hlo_computation);

EstimateRunTimeData estimate_run_time_data =
EstimateRunTimeForTiledHloComputation(
fusion_adaptor, tiled_hlo_computation, launch_dimensions);

if (!best_tiled_run_time_data.has_value() ||
estimate_run_time_data.exec_time <
best_tiled_run_time_data->runtime_data.exec_time) {
BlockLevelParameters block_level_parameters;
block_level_parameters.output_tile_sizes =
std::vector<int64_t>(tiling.begin(), tiling.end());
block_level_parameters.num_warps =
launch_dimensions.num_threads_per_block() / WarpSize();

best_tiled_run_time_data =
TiledRunTimeData{estimate_run_time_data, block_level_parameters};
}
}

if (!best_tiled_run_time_data.has_value()) {
return FusionDecision("No valid tilings found.");
}
return *best_tiled_run_time_data;
}

} // namespace gpu
} // namespace xla
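
Editorial worked example for the heuristics above (not part of the commit, and assuming a warp size of 32): for the f32[512,911] Triton SoftMax fusion in the new test below, the tiling selected is [4, 911], so Product(tile_sizes) = 4 * 911 = 3644 and GetNumWarps returns 4 (since 3644 <= 16384). GetLaunchDimensionsForTiledFusion then produces 512 / 4 = 128 blocks of 4 * 32 = 128 threads each, which matches the 128 blocks and num_warps = 4 expected by the test.
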
third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.h
@@ -18,6 +18,7 @@ limitations under the License.

#include <cstddef>
#include <cstdint>
#include <variant>
#include <vector>

#include "absl/status/statusor.h"
@@ -30,13 +31,22 @@ limitations under the License.
#include "xla/service/gpu/model/fusion_analysis_cache.h"
#include "xla/service/gpu/model/gpu_performance_model_base.h"
#include "xla/service/gpu/model/hlo_op_profiles.h"
#include "xla/service/gpu/model/symbolic_tile_analysis.h"
#include "xla/service/gpu/model/tiled_hlo_computation.h"
#include "xla/service/hlo_cost_analysis.h"
#include "xla/service/instruction_fusion.h"
#include "xla/shape.h"
#include "xla/stream_executor/device_description.h"

namespace xla {
namespace gpu {

// Contains information about block-level parameters and run time of a fusion.
struct TiledRunTimeData {
EstimateRunTimeData runtime_data;
BlockLevelParameters block_level_parameters;
};

// Implementation of Cost Model that uses indexing analysis to estimate amount
// of compute and memory access time.
class GpuPerformanceModelWithIndexingAnalysis : public GpuPerformanceModelBase {
@@ -65,6 +75,11 @@ class GpuPerformanceModelWithIndexingAnalysis : public GpuPerformanceModelBase {
const HloInstruction* producer,
absl::Span<const HloInstruction* const> fused_consumers = {});

EstimateRunTimeData EstimateRunTimeForTiledHloComputation(
const HloFusionAdaptor& fusion_adaptor,
const TiledHloComputation& tiled_hlo_computation,
const LaunchDimensions& launch_dimensions);

// Estimate the run time of the fusion with the given launch dimensions and
// output tile sizes.
//
@@ -82,6 +97,17 @@ class GpuPerformanceModelWithIndexingAnalysis : public GpuPerformanceModelBase {
absl::StatusOr<EstimateRunTimeData> EstimateRunTimeForTriton(
const HloInstruction* producer, const HloInstruction* consumer = nullptr);

// Estimates the best tile sizes for the given fusion. Iterates over all the
// good tile sizes provided by SymbolicTileAnalysis and estimates the run time
// for each of them.
//
// Returns an error status if there is an error that we can't recover from.
// Returns a FusionDecision if the fusion can't be tiled or there are no valid
// block level parameters.
// Otherwise returns the block level parameters that give the best execution
// time.
absl::StatusOr<std::variant<TiledRunTimeData, FusionDecision>>
TryFindBestTilingForFusion(const HloFusionAdaptor& fusion_adaptor);

private:
// Returns an estimate how many FLOPs will be used to produce one element of
// the output.
third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model_test.cc
@@ -17,7 +17,9 @@ limitations under the License.

#include <cstdint>
#include <memory>
#include <variant>

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "absl/strings/string_view.h"
#include "absl/time/time.h"
@@ -26,6 +28,7 @@ limitations under the License.
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/service/gpu/backend_configs.pb.h"
#include "xla/service/gpu/gpu_device_info_for_tests.h"
#include "xla/service/gpu/hlo_traversal.h"
#include "xla/service/gpu/model/fusion_analysis_cache.h"
#include "xla/service/gpu/model/gpu_hlo_cost_analysis.h"
#include "xla/service/gpu/model/gpu_performance_model_base.h"
@@ -39,6 +42,8 @@ namespace xla {
namespace gpu {
namespace {

using ::testing::ElementsAre;

class GpuIndexingPerformanceModelTest : public HloTestBase {
GpuHloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const {
return [&](const Shape& shape) {
@@ -270,6 +275,66 @@ ENTRY main {
EXPECT_NEAR(absl::ToDoubleMicroseconds(runtime_data.exec_time), 5, 1);
}

TEST_F(GpuIndexingPerformanceModelTest,
EstimateBestTiling_TritonSoftmax_IsSupported) {
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"(
HloModule m
add {
Arg_0 = f32[] parameter(0)
Arg_1 = f32[] parameter(1)
ROOT add = f32[] add(Arg_0, Arg_1)
}
triton_softmax_computation {
param_0 = f32[512,911]{1,0} parameter(0)
param_1 = f32[911]{0} parameter(1)
broadcast_0 = f32[512,911]{1,0} broadcast(param_1), dimensions={1}
multiply_0 = f32[512,911]{1,0} multiply(param_0, broadcast_0)
constant_0 = f32[] constant(0)
reduce_0 = f32[512]{0} reduce(multiply_0, constant_0), dimensions={1}, to_apply=add
broadcast_4 = f32[512,911]{1,0} broadcast(reduce_0), dimensions={0}
ROOT multiply = f32[512,911]{1,0} multiply(multiply_0, broadcast_4)
}
ENTRY main {
param_0 = f32[512,911]{1,0} parameter(0)
param_1 = f32[911]{0} parameter(1)
ROOT triton_softmax = f32[512,911]{1,0} fusion(param_0, param_1), kind=kCustom, calls=triton_softmax_computation, backend_config={"fusion_backend_config": {"kind":"__triton"}}
}
)"));
auto fusion_adaptor = HloFusionAdaptor::ForInstruction(
module->entry_computation()->root_instruction());

TF_ASSERT_OK_AND_ASSIGN(
auto tiling_result,
indexing_cost_model_.TryFindBestTilingForFusion(*fusion_adaptor));

ASSERT_TRUE(std::holds_alternative<TiledRunTimeData>(tiling_result));

auto tiled_runtime_data = std::get<TiledRunTimeData>(tiling_result);

constexpr int64_t kParam0SizeBytes = 512 * 911 * 4;
constexpr int64_t kParam1SizeBytes = 911 * 4;
constexpr int64_t kOutputSizeBytes = 512 * 911 * 4;

// Launch grid consists of 128 blocks. Each block reads 1 tile of shape [4,
// 911] from param_0 and full param_1. In total param_0 is read once and
// param_1 is read 128 times.
constexpr int64_t kExpectedBytesRead =
kParam0SizeBytes + 128 * kParam1SizeBytes;

EXPECT_THAT(tiled_runtime_data.block_level_parameters.output_tile_sizes,
ElementsAre(4, 911));
EXPECT_EQ(tiled_runtime_data.block_level_parameters.num_warps, 4);

EXPECT_EQ(tiled_runtime_data.runtime_data.bytes_read, kExpectedBytesRead);
EXPECT_EQ(tiled_runtime_data.runtime_data.bytes_written, kOutputSizeBytes);
EXPECT_NEAR(
absl::ToDoubleMicroseconds(tiled_runtime_data.runtime_data.exec_time), 5,
1);
}

} // namespace
} // namespace gpu
} // namespace xla
