From cd75b11191cf2d6bb0cb7558887afd839b28e7da Mon Sep 17 00:00:00 2001 From: Oleg Shyshkov Date: Tue, 18 Jun 2024 02:54:57 -0700 Subject: [PATCH] [XLA:GPU] Add initial version of cost model for tiled hlo. Heuristics for compute and memory access time are shared with GpuPerformanceModel. The main difference is that we assume that each tile is read or computed only once and results are shared between threads via shared memory. PiperOrigin-RevId: 644313383 --- third_party/xla/xla/service/gpu/model/BUILD | 12 ++ .../model/gpu_indexing_performance_model.cc | 112 ++++++++++++++++++ .../model/gpu_indexing_performance_model.h | 26 ++++ .../gpu_indexing_performance_model_test.cc | 105 +++++++++++++++- .../gpu/model/gpu_performance_model_test.cc | 5 +- 5 files changed, 258 insertions(+), 2 deletions(-) diff --git a/third_party/xla/xla/service/gpu/model/BUILD b/third_party/xla/xla/service/gpu/model/BUILD index 4ece9de1054dee..9da56e03fffedf 100644 --- a/third_party/xla/xla/service/gpu/model/BUILD +++ b/third_party/xla/xla/service/gpu/model/BUILD @@ -261,6 +261,7 @@ xla_cc_test( name = "gpu_performance_model_test", srcs = ["gpu_performance_model_test.cc"], deps = [ + ":fusion_analysis_cache", ":gpu_hlo_cost_analysis", ":gpu_indexing_performance_model", ":gpu_performance_model", @@ -341,25 +342,35 @@ cc_library( hdrs = ["gpu_indexing_performance_model.h"], deps = [ ":coalescing_analysis", + ":fusion_analysis_cache", ":gpu_hlo_cost_analysis", ":gpu_performance_model_base", ":hlo_op_profiles", ":indexing_analysis", + ":symbolic_tile_analysis", + ":tiled_hlo_computation", "//xla:shape_util", "//xla:util", "//xla/hlo/ir:hlo", "//xla/service:hlo_cost_analysis", + "//xla/service:instruction_fusion", "//xla/service/gpu:backend_configs_cc", "//xla/service/gpu:hlo_fusion_analysis", "//xla/service/gpu:hlo_traversal", "//xla/service/gpu:launch_dimensions", + "//xla/service/gpu/fusions:triton", "//xla/stream_executor:device_description", + "@com_google_absl//absl/container:flat_hash_map", 
"@com_google_absl//absl/log", "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", "@com_google_absl//absl/time", "@com_google_absl//absl/types:span", "@llvm-project//mlir:IR", "@local_tsl//tsl/platform:status", + "@local_tsl//tsl/platform:statusor", ], ) @@ -367,6 +378,7 @@ xla_cc_test( name = "gpu_indexing_performance_model_test", srcs = ["gpu_indexing_performance_model_test.cc"], deps = [ + ":fusion_analysis_cache", ":gpu_hlo_cost_analysis", ":gpu_indexing_performance_model", ":gpu_performance_model_base", diff --git a/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.cc b/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.cc index 3ab064f4391f41..6b8cc815fc52f0 100644 --- a/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.cc +++ b/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.cc @@ -17,15 +17,21 @@ limitations under the License. #include #include +#include +#include #include +#include "absl/container/flat_hash_map.h" #include "absl/log/check.h" #include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" #include "absl/time/time.h" #include "absl/types/span.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/gpu/backend_configs.pb.h" +#include "xla/service/gpu/fusions/triton.h" #include "xla/service/gpu/hlo_fusion_analysis.h" #include "xla/service/gpu/hlo_traversal.h" #include "xla/service/gpu/launch_dimensions.h" @@ -34,10 +40,14 @@ limitations under the License. 
#include "xla/service/gpu/model/gpu_performance_model_base.h" #include "xla/service/gpu/model/indexing_analysis.h" #include "xla/service/gpu/model/indexing_map.h" +#include "xla/service/gpu/model/symbolic_tile_analysis.h" +#include "xla/service/gpu/model/tiled_hlo_computation.h" +#include "xla/service/instruction_fusion.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/util.h" #include "tsl/platform/status.h" +#include "tsl/platform/statusor.h" namespace xla { namespace gpu {
@@ -240,5 +250,124 @@ GpuPerformanceModelWithIndexingAnalysis::EstimateRunTimes(
   return {time_unfused, time_fused};
 }
 
+// Estimates the run time of `fusion_adaptor` emitted as a tiled computation
+// with the given launch dimensions and output tile sizes.
+//
+// Each tile is counted once for every block that reads or computes it.
+// Tiles of instructions inside the fusion add to the FLOP count; tiles of the
+// fusion operands add to the bytes read.
+//
+// Returns FailedPreconditionError if SymbolicTileAnalysis fails, and forwards
+// any error from computing the tiled HLO instructions.
+absl::StatusOr<EstimateRunTimeData>
+GpuPerformanceModelWithIndexingAnalysis::EstimateRunTimeForTiledFusion(
+    const HloFusionAdaptor& fusion_adaptor,
+    const LaunchDimensions& launch_dimensions,
+    const std::vector<int64_t>& output_tile_sizes) {
+  // TODO(b/332714755): Add caching for SymbolicTileAnalysis.
+  SymbolicTileAnalysisOrError analysis_or_error =
+      SymbolicTileAnalysis::AnalyzeFusion(fusion_adaptor, mlir_context_);
+  if (!std::holds_alternative<SymbolicTileAnalysis>(analysis_or_error)) {
+    return absl::FailedPreconditionError(
+        absl::StrCat("SymbolicTileAnalysis failed. ",
+                     std::get<FusionDecision>(analysis_or_error).Explain()));
+  }
+  SymbolicTileAnalysis analysis =
+      std::get<SymbolicTileAnalysis>(std::move(analysis_or_error));
+
+  TF_ASSIGN_OR_RETURN(TiledHloComputation tiled_hlo_computation,
+                      analysis.ComputeTiledHloInstructions(output_tile_sizes));
+
+  // Total bytes requested from each fusion operand, summed over all of its
+  // tiles and all blocks.
+  absl::flat_hash_map<const HloInstruction*, int64_t> n_bytes_total_map;
+
+  int64_t flops = 0;
+  int64_t bytes_read = 0;
+
+  for (const auto& tiled_hlo : tiled_hlo_computation.instructions()) {
+    // Number of blocks that read or compute this tile.
+    int64_t tile_num_blocks = tiled_hlo->block_id_to_tile_offsets_indexing()
+                                  .GetDimensionBound(0)
+                                  .GetLoopTripCount();
+
+    // Total number of elements that are read from memory or computed for this
+    // tile across all blocks.
+    int64_t num_elements = tile_num_blocks * Product(tiled_hlo->tile_sizes());
+
+    const HloInstruction* hlo = tiled_hlo->hlo();
+
+    if (fusion_adaptor.ContainsInstruction(hlo)) {
+      // Tiles inside the computation contribute to the total FLOPs count.
+      flops += FlopsPerElement(hlo) * num_elements;
+    } else {
+      // Tiles of the operands of the fusion contribute to the total memory
+      // read time.
+      int64_t element_type_size =
+          ShapeUtil::ByteSizeOfPrimitiveType(hlo->shape().element_type());
+      int64_t tile_bytes_read = element_type_size * num_elements;
+
+      bytes_read += tile_bytes_read;
+      n_bytes_total_map[hlo] += tile_bytes_read;
+    }
+  }
+
+  const int64_t num_blocks = launch_dimensions.num_blocks();
+  absl::Duration read_time = absl::ZeroDuration();
+  for (const auto& [hlo, n_bytes_total] : n_bytes_total_map) {
+    // The net amount of data read from an operand is capped by its size.
+    int64_t operand_size = shape_size_(hlo->shape());
+    int64_t n_bytes_net = std::min(operand_size, n_bytes_total);
+
+    read_time += ReadTimeWithDRAMHeuristic(
+        *device_info_, num_blocks, n_bytes_net, n_bytes_total,
+        /*element_type=*/hlo->shape().element_type(),
+        /*coalesced=*/true);
+  }
+
+  int64_t bytes_written =
+      GetShapeSizeRecursive(tiled_hlo_computation.GetRoot()->hlo()->shape());
+
+  absl::Duration compute_time =
+      ComputeTime(*device_info_, flops, num_blocks,
+                  launch_dimensions.num_threads_per_block());
+  absl::Duration write_time = WriteTime(*device_info_, bytes_written);
+  absl::Duration memory_access_time = read_time + write_time;
+  absl::Duration exec_time = CombineComputeAndMemoryAccessTime(
+      compute_time, memory_access_time,
+      GpuPerformanceModelOptions::PriorityFusion());
+
+  return EstimateRunTimeData{/*flops=*/flops,
+                             /*bytes_read=*/bytes_read,
+                             /*bytes_written=*/bytes_written,
+                             /*read_time=*/read_time,
+                             /*write_time=*/write_time,
+                             /*compute_time=*/compute_time,
+                             /*exec_time=*/exec_time};
+}
+
+// Estimates the run time of `producer` alone (when `consumer` is nullptr) or
+// of `producer` and `consumer` together, using the fusion analysis from the
+// cache and the launch config that TritonFusion reports for it.
+//
+// Returns InvalidArgumentError if no launch config is available.
+absl::StatusOr<EstimateRunTimeData>
+GpuPerformanceModelWithIndexingAnalysis::EstimateRunTimeForTriton(
+    const HloInstruction* producer, const HloInstruction* consumer) {
+  const auto& fusion_analysis =
+      (consumer == nullptr) ? fusion_analysis_cache_->Get(*producer)
+                            : fusion_analysis_cache_->Get(*producer, *consumer);
+  auto launch_config = TritonFusion(fusion_analysis).launch_config();
+
+  if (!launch_config.has_value()) {
+    return absl::InvalidArgumentError(
+        "Could not get launch config for Triton fusion.");
+  }
+
+  return EstimateRunTimeForTiledFusion(fusion_analysis.fusion(),
+                                       launch_config->launch_dimensions,
+                                       launch_config->output_tile_sizes);
+}
+
 }  // namespace gpu
 }  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.h b/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.h index 14d7e520a820d3..a1f98a8660d663 100644 --- a/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.h +++ b/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.h @@ -16,12 +16,18 @@ limitations under the License. #ifndef XLA_SERVICE_GPU_MODEL_GPU_INDEXING_PERFORMANCE_MODEL_H_ #define XLA_SERVICE_GPU_MODEL_GPU_INDEXING_PERFORMANCE_MODEL_H_ +#include #include +#include +#include "absl/status/statusor.h" #include "absl/types/span.h" #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "xla/hlo/ir/hlo_instruction.h" #include "xla/service/gpu/hlo_fusion_analysis.h" +#include "xla/service/gpu/hlo_traversal.h" +#include "xla/service/gpu/launch_dimensions.h" +#include "xla/service/gpu/model/fusion_analysis_cache.h" #include "xla/service/gpu/model/gpu_performance_model_base.h" #include "xla/service/gpu/model/hlo_op_profiles.h" #include "xla/service/hlo_cost_analysis.h" @@ -37,10 +43,12 @@ class GpuPerformanceModelWithIndexingAnalysis : public GpuPerformanceModelBase { public: explicit GpuPerformanceModelWithIndexingAnalysis( const se::DeviceDescription* device_info, + HloFusionAnalysisCache* fusion_analysis_cache, HloCostAnalysis::ShapeSizeFunction shape_size, mlir::MLIRContext* mlir_context) :
hlo_op_profile_(&HloOpProfiles::Singleton().GetProfile(device_info)), device_info_(device_info), + fusion_analysis_cache_(fusion_analysis_cache), shape_size_(shape_size), mlir_context_(mlir_context) {}
@@ -57,6 +65,29 @@ class GpuPerformanceModelWithIndexingAnalysis : public GpuPerformanceModelBase {
       const HloInstruction* producer,
       absl::Span<const HloInstruction* const> fused_consumers = {});
 
+  // Estimates the run time of the fusion with the given launch dimensions and
+  // output tile sizes.
+  //
+  // The model uses SymbolicTileAnalysis to build a TiledHloComputation with the
+  // given tile sizes. This way it can better estimate the amount of memory
+  // access and computation.
+  //
+  // Returns an error if SymbolicTileAnalysis fails for `fusion_adaptor` or if
+  // tiled HLO instructions cannot be computed for `output_tile_sizes`.
+  absl::StatusOr<EstimateRunTimeData> EstimateRunTimeForTiledFusion(
+      const HloFusionAdaptor& fusion_adaptor,
+      const LaunchDimensions& launch_dimensions,
+      const std::vector<int64_t>& output_tile_sizes);
+
+  // Estimates the run time of producer and consumer fused together, assuming
+  // that they will be emitted with Triton.
+  // If consumer is nullptr, estimates the run time of the producer alone.
+  //
+  // `producer` must not be nullptr. Returns an error if no launch config can
+  // be obtained for the Triton fusion.
+  absl::StatusOr<EstimateRunTimeData> EstimateRunTimeForTriton(
+      const HloInstruction* producer, const HloInstruction* consumer = nullptr);
+
  private:
   // Returns an estimate how many FLOPs will be used to produce one element of
   // the output.
@@ -66,6 +97,7 @@ class GpuPerformanceModelWithIndexingAnalysis : public GpuPerformanceModelBase { const HloOpProfiles::HloOpProfile* hlo_op_profile_; const se::DeviceDescription* device_info_; + HloFusionAnalysisCache* fusion_analysis_cache_; HloCostAnalysis::ShapeSizeFunction shape_size_; mlir::MLIRContext* mlir_context_; }; diff --git a/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model_test.cc b/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model_test.cc index 5e52685762e524..57468cf258d521 100644 --- a/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model_test.cc +++ b/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model_test.cc @@ -26,6 +26,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/gpu_device_info_for_tests.h" +#include "xla/service/gpu/model/fusion_analysis_cache.h" #include "xla/service/gpu/model/gpu_hlo_cost_analysis.h" #include "xla/service/gpu/model/gpu_performance_model_base.h" #include "xla/shape.h" @@ -51,8 +52,10 @@ class GpuIndexingPerformanceModelTest : public HloTestBase { // The reference times in the test cases below are measured // on A6000 by profiling the execution of the HLOs.
se::DeviceDescription device_info_{TestGpuDeviceInfo::RTXA6000DeviceInfo()}; + HloFusionAnalysisCache fusion_analysis_cache_{device_info_}; GpuPerformanceModelWithIndexingAnalysis indexing_cost_model_{ - &device_info_, ShapeSizeBytesFunction(), &mlir_context_}; + &device_info_, &fusion_analysis_cache_, ShapeSizeBytesFunction(), + &mlir_context_}; GpuIndexingPerformanceModelTest() : HloTestBase() {} };
@@ -167,6 +170,105 @@ ENTRY entry_computation {
   EXPECT_NEAR(absl::ToDoubleNanoseconds(runtime_data.exec_time), 58, 1);
 }
 
+TEST_F(GpuIndexingPerformanceModelTest,
+       TritonSoftmaxFusionInstructionIsSupported) {
+  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"(
+HloModule m
+
+add {
+  Arg_0 = f32[] parameter(0)
+  Arg_1 = f32[] parameter(1)
+  ROOT add = f32[] add(Arg_0, Arg_1)
+}
+
+triton_softmax_computation {
+  param_0 = f32[512,911]{1,0} parameter(0)
+  param_1 = f32[911]{0} parameter(1)
+  broadcast_0 = f32[512,911]{1,0} broadcast(param_1), dimensions={1}
+  multiply_0 = f32[512,911]{1,0} multiply(param_0, broadcast_0)
+  constant_0 = f32[] constant(0)
+  reduce_0 = f32[512]{0} reduce(multiply_0, constant_0), dimensions={1}, to_apply=add
+  broadcast_4 = f32[512,911]{1,0} broadcast(reduce_0), dimensions={0}
+  ROOT multiply = f32[512,911]{1,0} multiply(multiply_0, broadcast_4)
+}
+
+ENTRY main {
+  param_0 = f32[512,911]{1,0} parameter(0)
+  param_1 = f32[911]{0} parameter(1)
+  ROOT triton_softmax = f32[512,911]{1,0} fusion(param_0, param_1), kind=kCustom, calls=triton_softmax_computation, backend_config={"fusion_backend_config": {"kind":"__triton"}}
+}
+)"));
+  const HloInstruction* triton_fusion =
+      module->entry_computation()->root_instruction();
+  TF_ASSERT_OK_AND_ASSIGN(
+      auto estimate,
+      indexing_cost_model_.EstimateRunTimeForTriton(triton_fusion));
+
+  // Every block reads one [1, 911] tile of param_0 plus the whole of param_1,
+  // so across 512 blocks param_0 is read once and param_1 512 times.
+  constexpr int64_t kNumBlocks = 512;
+  constexpr int64_t kRowBytes = 911 * 4;
+  constexpr int64_t kMatrixBytes = kNumBlocks * kRowBytes;
+  constexpr int64_t kExpectedBytesRead = kMatrixBytes + kNumBlocks * kRowBytes;
+
+  EXPECT_EQ(estimate.bytes_read, kExpectedBytesRead);
+  EXPECT_EQ(estimate.bytes_written, kMatrixBytes);
+  EXPECT_NEAR(absl::ToDoubleMicroseconds(estimate.exec_time), 5, 1);
+}
+
+TEST_F(GpuIndexingPerformanceModelTest,
+       TritonSoftmaxProducerConsumerFusionIsSupported) {
+  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"(
+HloModule m
+
+add {
+  Arg_0 = f32[] parameter(0)
+  Arg_1 = f32[] parameter(1)
+  ROOT add = f32[] add(Arg_0, Arg_1)
+}
+
+fusion {
+  param_0 = f32[512,911] parameter(0)
+  param_1 = f32[911] parameter(1)
+  broadcast = f32[512,911] broadcast(param_1), dimensions={1}
+  ROOT multiply = f32[512,911] multiply(param_0, broadcast)
+}
+
+triton_softmax_computation {
+  param_0 = f32[512,911] parameter(0)
+  constant_0 = f32[] constant(0)
+  reduce_0 = f32[512] reduce(param_0, constant_0), dimensions={1}, to_apply=add
+  broadcast_4 = f32[512,911] broadcast(reduce_0), dimensions={0}
+  ROOT multiply = f32[512,911] multiply(param_0, broadcast_4)
+}
+
+ENTRY main {
+  param_0 = f32[512,911] parameter(0)
+  param_1 = f32[911] parameter(1)
+  fusion.1 = f32[512,911] fusion(param_0, param_1), kind=kLoop, calls=fusion
+  ROOT triton_softmax = f32[512,911] fusion(fusion.1), kind=kCustom, calls=triton_softmax_computation, backend_config={"fusion_backend_config": {"kind":"__triton"}}
+}
+)"));
+  const HloInstruction* consumer =
+      module->entry_computation()->root_instruction();
+  const HloInstruction* producer = consumer->operand(0);
+
+  TF_ASSERT_OK_AND_ASSIGN(
+      auto estimate,
+      indexing_cost_model_.EstimateRunTimeForTriton(producer, consumer));
+
+  // Every block reads one [1, 911] tile of param_0 plus the whole of param_1,
+  // so across 512 blocks param_0 is read once and param_1 512 times.
+  constexpr int64_t kNumBlocks = 512;
+  constexpr int64_t kRowBytes = 911 * 4;
+  constexpr int64_t kMatrixBytes = kNumBlocks * kRowBytes;
+  constexpr int64_t kExpectedBytesRead = kMatrixBytes + kNumBlocks * kRowBytes;
+
+  EXPECT_EQ(estimate.bytes_read, kExpectedBytesRead);
+  EXPECT_EQ(estimate.bytes_written, kMatrixBytes);
+  EXPECT_NEAR(absl::ToDoubleMicroseconds(estimate.exec_time), 5, 1);
+}
+
 }  // namespace
 }  // namespace gpu
 }  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/model/gpu_performance_model_test.cc b/third_party/xla/xla/service/gpu/model/gpu_performance_model_test.cc index c8c42019d2efdb..e76f783f477e95 100644 --- a/third_party/xla/xla/service/gpu/model/gpu_performance_model_test.cc +++ b/third_party/xla/xla/service/gpu/model/gpu_performance_model_test.cc @@ -31,6 +31,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/gpu_device_info_for_tests.h" +#include "xla/service/gpu/model/fusion_analysis_cache.h" #include "xla/service/gpu/model/gpu_hlo_cost_analysis.h" #include "xla/service/gpu/model/gpu_indexing_performance_model.h" #include "xla/service/gpu/model/gpu_performance_model_base.h" @@ -79,10 +80,12 @@ class GpuPerformanceModelTest : public HloTestBase { // The reference times in the test cases below are measured // on A6000 by profiling the execution of the HLOs. se::DeviceDescription device_info_{TestGpuDeviceInfo::RTXA6000DeviceInfo()}; + HloFusionAnalysisCache fusion_analysis_cache_{device_info_}; GpuHloCostAnalysis analysis_{options_, &device_info_}; GpuPerformanceModelWithIndexingAnalysis indexing_cost_model_{ - &device_info_, ShapeSizeBytesFunction(), &mlir_context_}; + &device_info_, &fusion_analysis_cache_, ShapeSizeBytesFunction(), + &mlir_context_}; GpuPerformanceModelTest() : HloTestBase() {} };