[XLA:GPU] Add initial version of cost model for tiled hlo.
Heuristics for compute and memory access time are shared with GpuPerformanceModel. The main difference is that we assume each tile is read or computed only once, with the results shared between threads via shared memory.
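
For illustration, a minimal standalone sketch of the per-tile accounting this heuristic implies (simplified C++, not the XLA implementation; the function names here are hypothetical):

#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

// Each block reads or computes its tile exactly once, so the number of
// elements touched for one tiled instruction is num_blocks * tile volume.
int64_t NumTileElements(int64_t num_blocks,
                        const std::vector<int64_t>& tile_sizes) {
  int64_t tile_volume =
      std::accumulate(tile_sizes.begin(), tile_sizes.end(), int64_t{1},
                      std::multiplies<int64_t>());
  return num_blocks * tile_volume;
}

// Read volume for one operand tile: element size times elements touched.
int64_t TileBytesRead(int64_t num_blocks,
                      const std::vector<int64_t>& tile_sizes,
                      int64_t element_size_bytes) {
  return element_size_bytes * NumTileElements(num_blocks, tile_sizes);
}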

PiperOrigin-RevId: 644313383
olegshyshkov authored and tensorflower-gardener committed Jun 18, 2024
1 parent dc76516 commit cd75b11
Showing 5 changed files with 258 additions and 2 deletions.
12 changes: 12 additions & 0 deletions third_party/xla/xla/service/gpu/model/BUILD
@@ -261,6 +261,7 @@ xla_cc_test(
name = "gpu_performance_model_test",
srcs = ["gpu_performance_model_test.cc"],
deps = [
":fusion_analysis_cache",
":gpu_hlo_cost_analysis",
":gpu_indexing_performance_model",
":gpu_performance_model",
@@ -341,32 +342,43 @@ cc_library(
hdrs = ["gpu_indexing_performance_model.h"],
deps = [
":coalescing_analysis",
":fusion_analysis_cache",
":gpu_hlo_cost_analysis",
":gpu_performance_model_base",
":hlo_op_profiles",
":indexing_analysis",
":symbolic_tile_analysis",
":tiled_hlo_computation",
"//xla:shape_util",
"//xla:util",
"//xla/hlo/ir:hlo",
"//xla/service:hlo_cost_analysis",
"//xla/service:instruction_fusion",
"//xla/service/gpu:backend_configs_cc",
"//xla/service/gpu:hlo_fusion_analysis",
"//xla/service/gpu:hlo_traversal",
"//xla/service/gpu:launch_dimensions",
"//xla/service/gpu/fusions:triton",
"//xla/stream_executor:device_description",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/time",
"@com_google_absl//absl/types:span",
"@llvm-project//mlir:IR",
"@local_tsl//tsl/platform:status",
"@local_tsl//tsl/platform:statusor",
],
)

xla_cc_test(
name = "gpu_indexing_performance_model_test",
srcs = ["gpu_indexing_performance_model_test.cc"],
deps = [
":fusion_analysis_cache",
":gpu_hlo_cost_analysis",
":gpu_indexing_performance_model",
":gpu_performance_model_base",
third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.cc
@@ -17,15 +17,21 @@ limitations under the License.

#include <algorithm>
#include <cstdint>
#include <utility>
#include <variant>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/log/check.h"
#include "absl/log/log.h"
#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "absl/time/time.h"
#include "absl/types/span.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/service/gpu/backend_configs.pb.h"
#include "xla/service/gpu/fusions/triton.h"
#include "xla/service/gpu/hlo_fusion_analysis.h"
#include "xla/service/gpu/hlo_traversal.h"
#include "xla/service/gpu/launch_dimensions.h"
@@ -34,10 +40,14 @@ limitations under the License.
#include "xla/service/gpu/model/gpu_performance_model_base.h"
#include "xla/service/gpu/model/indexing_analysis.h"
#include "xla/service/gpu/model/indexing_map.h"
#include "xla/service/gpu/model/symbolic_tile_analysis.h"
#include "xla/service/gpu/model/tiled_hlo_computation.h"
#include "xla/service/instruction_fusion.h"
#include "xla/shape.h"
#include "xla/shape_util.h"
#include "xla/util.h"
#include "tsl/platform/status.h"
#include "tsl/platform/statusor.h"

namespace xla {
namespace gpu {
@@ -240,5 +250,107 @@ GpuPerformanceModelWithIndexingAnalysis::EstimateRunTimes(
return {time_unfused, time_fused};
}

absl::StatusOr<EstimateRunTimeData>
GpuPerformanceModelWithIndexingAnalysis::EstimateRunTimeForTiledFusion(
const HloFusionAdaptor& fusion_adaptor,
const LaunchDimensions& launch_dimensions,
const std::vector<int64_t>& tile_sizes) {
// TODO(b/332714755): Add caching for SymbolicTileAnalysis.
SymbolicTileAnalysisOrError analysis_or_error =
SymbolicTileAnalysis::AnalyzeFusion(fusion_adaptor, mlir_context_);
if (!std::holds_alternative<SymbolicTileAnalysis>(analysis_or_error)) {
return absl::FailedPreconditionError(
absl::StrCat("SymbolicTileAnalysis failed. ",
std::get<FusionDecision>(analysis_or_error).Explain()));
}
SymbolicTileAnalysis analysis =
std::get<SymbolicTileAnalysis>(std::move(analysis_or_error));

TF_ASSIGN_OR_RETURN(TiledHloComputation tiled_hlo_computation,
analysis.ComputeTiledHloInstructions(tile_sizes));

absl::flat_hash_map<const HloInstruction*, int64_t> n_bytes_total_map;

int64_t flops = 0;
int64_t bytes_read = 0;

for (const auto& tiled_hlo : tiled_hlo_computation.instructions()) {
// Number of blocks that read or compute this tile.
int64_t num_blocks = tiled_hlo->block_id_to_tile_offsets_indexing()
.GetDimensionBound(0)
.GetLoopTripCount();

// Total number of elements that are read from memory or computed for this
// tile across all blocks.
int64_t num_elements = num_blocks * Product(tiled_hlo->tile_sizes());

const HloInstruction* hlo = tiled_hlo->hlo();

if (fusion_adaptor.ContainsInstruction(hlo)) {
// Tiles inside the computation contribute to the total FLOPs count.
flops += FlopsPerElement(hlo) * num_elements;
} else {
// Tiles of the operands of the fusion contribute to the total memory
// read time.
int64_t element_type_size =
ShapeUtil::ByteSizeOfPrimitiveType(hlo->shape().element_type());
int64_t tile_bytes_read = element_type_size * num_elements;

bytes_read += tile_bytes_read;
n_bytes_total_map[hlo] += tile_bytes_read;
}
}

int64_t num_blocks = launch_dimensions.num_blocks();
absl::Duration read_time = absl::ZeroDuration();
for (const auto& [hlo, n_bytes_total] : n_bytes_total_map) {
int64_t operand_size = shape_size_(hlo->shape());
int64_t n_bytes_net = std::min(operand_size, n_bytes_total);

read_time += ReadTimeWithDRAMHeuristic(
*device_info_, num_blocks, n_bytes_net, n_bytes_total,
/*element_type=*/hlo->shape().element_type(),
/*coalesced=*/true);
}

int64_t bytes_written =
GetShapeSizeRecursive(tiled_hlo_computation.GetRoot()->hlo()->shape());

absl::Duration compute_time =
ComputeTime(*device_info_, flops, launch_dimensions.num_blocks(),
launch_dimensions.num_threads_per_block());
absl::Duration write_time = WriteTime(*device_info_, bytes_written);
absl::Duration memory_access_time = read_time + write_time;
absl::Duration exec_time = CombineComputeAndMemoryAccessTime(
compute_time, memory_access_time,
GpuPerformanceModelOptions::PriorityFusion());

return EstimateRunTimeData{/*flops=*/flops,
/*bytes_read=*/bytes_read,
/*bytes_written=*/bytes_written,
/*read_time=*/read_time,
/*write_time=*/write_time,
/*compute_time=*/compute_time,
/*exec_time=*/exec_time};
}

absl::StatusOr<EstimateRunTimeData>
GpuPerformanceModelWithIndexingAnalysis::EstimateRunTimeForTriton(
const HloInstruction* producer, const HloInstruction* consumer) {
const auto& fusion_analysis =
(consumer == nullptr) ? fusion_analysis_cache_->Get(*producer)
: fusion_analysis_cache_->Get(*producer, *consumer);
auto launch_config = TritonFusion(fusion_analysis).launch_config();

if (!launch_config.has_value()) {
return absl::InvalidArgumentError(
"Could not get launch config for Triton fusion.");
}

return EstimateRunTimeForTiledFusion(fusion_analysis.fusion(),
launch_config->launch_dimensions,
launch_config->output_tile_sizes);
}

} // namespace gpu
} // namespace xla
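
For reference, a minimal usage sketch of the new entry point, mirroring the test added below (device setup and ShapeSizeBytesFunction come from the test fixture; `fusion_instruction` is a hypothetical stand-in):

se::DeviceDescription device_info = TestGpuDeviceInfo::RTXA6000DeviceInfo();
HloFusionAnalysisCache fusion_analysis_cache{device_info};
mlir::MLIRContext mlir_context;

GpuPerformanceModelWithIndexingAnalysis model{
    &device_info, &fusion_analysis_cache, ShapeSizeBytesFunction(),
    &mlir_context};

// Estimate a standalone Triton fusion; `fusion_instruction` is a
// hypothetical HloInstruction* (consumer defaults to nullptr).
absl::StatusOr<EstimateRunTimeData> runtime_data =
    model.EstimateRunTimeForTriton(fusion_instruction);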
third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.h
@@ -16,12 +16,18 @@ limitations under the License.
#ifndef XLA_SERVICE_GPU_MODEL_GPU_INDEXING_PERFORMANCE_MODEL_H_
#define XLA_SERVICE_GPU_MODEL_GPU_INDEXING_PERFORMANCE_MODEL_H_

#include <cstddef>
#include <cstdint>
#include <vector>

#include "absl/status/statusor.h"
#include "absl/types/span.h"
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/service/gpu/hlo_fusion_analysis.h"
#include "xla/service/gpu/hlo_traversal.h"
#include "xla/service/gpu/launch_dimensions.h"
#include "xla/service/gpu/model/fusion_analysis_cache.h"
#include "xla/service/gpu/model/gpu_performance_model_base.h"
#include "xla/service/gpu/model/hlo_op_profiles.h"
#include "xla/service/hlo_cost_analysis.h"
@@ -37,10 +43,12 @@ class GpuPerformanceModelWithIndexingAnalysis : public GpuPerformanceModelBase {
public:
explicit GpuPerformanceModelWithIndexingAnalysis(
const se::DeviceDescription* device_info,
HloFusionAnalysisCache* fusion_analysis_cache,
HloCostAnalysis::ShapeSizeFunction shape_size,
mlir::MLIRContext* mlir_context)
: hlo_op_profile_(&HloOpProfiles::Singleton().GetProfile(device_info)),
device_info_(device_info),
fusion_analysis_cache_(fusion_analysis_cache),
shape_size_(shape_size),
mlir_context_(mlir_context) {}

@@ -57,6 +65,23 @@ class GpuPerformanceModelWithIndexingAnalysis : public GpuPerformanceModelBase {
const HloInstruction* producer,
absl::Span<const HloInstruction* const> fused_consumers = {});

// Estimate the run time of the fusion with the given launch dimensions and
// output tile sizes.
//
// The model uses SymbolicTileAnalysis to build a TiledHloComputation with the
// given tile sizes. This way it can better estimate the amount of memory
// access and computation.
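//
// A hypothetical call shape, with illustrative values only (`fusion_adaptor`
// and `launch_dims` are assumed to exist in the caller):
//   model.EstimateRunTimeForTiledFusion(*fusion_adaptor,
//                                       /*launch_dimensions=*/launch_dims,
//                                       /*output_tile_sizes=*/{1, 911});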
absl::StatusOr<EstimateRunTimeData> EstimateRunTimeForTiledFusion(
const HloFusionAdaptor& fusion_adaptor,
const LaunchDimensions& launch_dimensions,
const std::vector<int64_t>& output_tile_sizes);

// Estimate the run time of producer and consumer fused together, assuming
// that they will be emitted with Triton.
// If consumer is nullptr, estimate run time of the producer alone.
absl::StatusOr<EstimateRunTimeData> EstimateRunTimeForTriton(
const HloInstruction* producer, const HloInstruction* consumer = nullptr);

private:
// Returns an estimate of how many FLOPs will be used to produce one element of
// the output.
@@ -66,6 +91,7 @@ class GpuPerformanceModelWithIndexingAnalysis : public GpuPerformanceModelBase {

const HloOpProfiles::HloOpProfile* hlo_op_profile_;
const se::DeviceDescription* device_info_;
HloFusionAnalysisCache* fusion_analysis_cache_;
HloCostAnalysis::ShapeSizeFunction shape_size_;
mlir::MLIRContext* mlir_context_;
};
third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model_test.cc
@@ -26,6 +26,7 @@ limitations under the License.
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/service/gpu/backend_configs.pb.h"
#include "xla/service/gpu/gpu_device_info_for_tests.h"
#include "xla/service/gpu/model/fusion_analysis_cache.h"
#include "xla/service/gpu/model/gpu_hlo_cost_analysis.h"
#include "xla/service/gpu/model/gpu_performance_model_base.h"
#include "xla/shape.h"
@@ -51,8 +52,10 @@ class GpuIndexingPerformanceModelTest : public HloTestBase {
// The reference times in the test cases below are measured on an A6000 GPU
// by profiling the execution of the HLOs.
se::DeviceDescription device_info_{TestGpuDeviceInfo::RTXA6000DeviceInfo()};
HloFusionAnalysisCache fusion_analysis_cache_{device_info_};
GpuPerformanceModelWithIndexingAnalysis indexing_cost_model_{
&device_info_, ShapeSizeBytesFunction(), &mlir_context_};
&device_info_, &fusion_analysis_cache_, ShapeSizeBytesFunction(),
&mlir_context_};

GpuIndexingPerformanceModelTest() : HloTestBase() {}
};
@@ -167,6 +170,106 @@ ENTRY entry_computation {
EXPECT_NEAR(absl::ToDoubleNanoseconds(runtime_data.exec_time), 58, 1);
}

TEST_F(GpuIndexingPerformanceModelTest,
TritonSoftmaxFusionInstructionIsSupported) {
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"(
HloModule m
add {
Arg_0 = f32[] parameter(0)
Arg_1 = f32[] parameter(1)
ROOT add = f32[] add(Arg_0, Arg_1)
}
triton_softmax_computation {
param_0 = f32[512,911]{1,0} parameter(0)
param_1 = f32[911]{0} parameter(1)
broadcast_0 = f32[512,911]{1,0} broadcast(param_1), dimensions={1}
multiply_0 = f32[512,911]{1,0} multiply(param_0, broadcast_0)
constant_0 = f32[] constant(0)
reduce_0 = f32[512]{0} reduce(multiply_0, constant_0), dimensions={1}, to_apply=add
broadcast_4 = f32[512,911]{1,0} broadcast(reduce_0), dimensions={0}
ROOT multiply = f32[512,911]{1,0} multiply(multiply_0, broadcast_4)
}
ENTRY main {
param_0 = f32[512,911]{1,0} parameter(0)
param_1 = f32[911]{0} parameter(1)
ROOT triton_softmax = f32[512,911]{1,0} fusion(param_0, param_1), kind=kCustom, calls=triton_softmax_computation, backend_config={"fusion_backend_config": {"kind":"__triton"}}
}
)"));
TF_ASSERT_OK_AND_ASSIGN(auto runtime_data,
indexing_cost_model_.EstimateRunTimeForTriton(
module->entry_computation()->root_instruction()));

constexpr int64_t kParam0SizeBytes = 512 * 911 * 4;
constexpr int64_t kParam1SizeBytes = 911 * 4;
constexpr int64_t kOutputSizeBytes = 512 * 911 * 4;

// Each block reads 1 tile of shape [1, 911] from param_0 and full param_1.
// In total param_0 is read once and param_1 is read 512 times.
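// That is (512*911*4) + 512*(911*4) = 1,865,728 + 1,865,728
// = 3,731,456 bytes in total.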
constexpr int64_t kExpectedBytesRead =
kParam0SizeBytes + 512 * kParam1SizeBytes;

EXPECT_EQ(runtime_data.bytes_read, kExpectedBytesRead);
EXPECT_EQ(runtime_data.bytes_written, kOutputSizeBytes);
EXPECT_NEAR(absl::ToDoubleMicroseconds(runtime_data.exec_time), 5, 1);
}

TEST_F(GpuIndexingPerformanceModelTest,
TritonSoftmaxProducerConsumerFusionIsSupported) {
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"(
HloModule m
add {
Arg_0 = f32[] parameter(0)
Arg_1 = f32[] parameter(1)
ROOT add = f32[] add(Arg_0, Arg_1)
}
fusion {
param_0 = f32[512,911] parameter(0)
param_1 = f32[911] parameter(1)
broadcast = f32[512,911] broadcast(param_1), dimensions={1}
ROOT multiply = f32[512,911] multiply(param_0, broadcast)
}
triton_softmax_computation {
param_0 = f32[512,911] parameter(0)
constant_0 = f32[] constant(0)
reduce_0 = f32[512] reduce(param_0, constant_0), dimensions={1}, to_apply=add
broadcast_4 = f32[512,911] broadcast(reduce_0), dimensions={0}
ROOT multiply = f32[512,911] multiply(param_0, broadcast_4)
}
ENTRY main {
param_0 = f32[512,911] parameter(0)
param_1 = f32[911] parameter(1)
fusion.1 = f32[512,911] fusion(param_0, param_1), kind=kLoop, calls=fusion
ROOT triton_softmax = f32[512,911] fusion(fusion.1), kind=kCustom, calls=triton_softmax_computation, backend_config={"fusion_backend_config": {"kind":"__triton"}}
}
)"));
auto consumer = module->entry_computation()->root_instruction();
auto producer = consumer->operand(0);

TF_ASSERT_OK_AND_ASSIGN(
auto runtime_data,
indexing_cost_model_.EstimateRunTimeForTriton(producer, consumer));

constexpr int64_t kParam0SizeBytes = 512 * 911 * 4;
constexpr int64_t kParam1SizeBytes = 911 * 4;
constexpr int64_t kOutputSizeBytes = 512 * 911 * 4;

// Each block reads 1 tile of shape [1, 911] from param_0 and full param_1.
// In total param_0 is read once and param_1 is read 512 times.
constexpr int64_t kExpectedBytesRead =
kParam0SizeBytes + 512 * kParam1SizeBytes;

EXPECT_EQ(runtime_data.bytes_read, kExpectedBytesRead);
EXPECT_EQ(runtime_data.bytes_written, kOutputSizeBytes);
EXPECT_NEAR(absl::ToDoubleMicroseconds(runtime_data.exec_time), 5, 1);
}

} // namespace
} // namespace gpu
} // namespace xla
