diff --git a/third_party/xla/xla/service/gpu/model/BUILD b/third_party/xla/xla/service/gpu/model/BUILD
index 4ece9de1054dee..9da56e03fffedf 100644
--- a/third_party/xla/xla/service/gpu/model/BUILD
+++ b/third_party/xla/xla/service/gpu/model/BUILD
@@ -261,6 +261,7 @@ xla_cc_test(
     name = "gpu_performance_model_test",
     srcs = ["gpu_performance_model_test.cc"],
     deps = [
+        ":fusion_analysis_cache",
         ":gpu_hlo_cost_analysis",
         ":gpu_indexing_performance_model",
         ":gpu_performance_model",
@@ -341,25 +342,35 @@ cc_library(
     hdrs = ["gpu_indexing_performance_model.h"],
     deps = [
         ":coalescing_analysis",
+        ":fusion_analysis_cache",
         ":gpu_hlo_cost_analysis",
         ":gpu_performance_model_base",
         ":hlo_op_profiles",
         ":indexing_analysis",
+        ":symbolic_tile_analysis",
+        ":tiled_hlo_computation",
         "//xla:shape_util",
         "//xla:util",
         "//xla/hlo/ir:hlo",
         "//xla/service:hlo_cost_analysis",
+        "//xla/service:instruction_fusion",
         "//xla/service/gpu:backend_configs_cc",
         "//xla/service/gpu:hlo_fusion_analysis",
         "//xla/service/gpu:hlo_traversal",
         "//xla/service/gpu:launch_dimensions",
+        "//xla/service/gpu/fusions:triton",
         "//xla/stream_executor:device_description",
+        "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/log",
         "@com_google_absl//absl/log:check",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/time",
         "@com_google_absl//absl/types:span",
         "@llvm-project//mlir:IR",
         "@local_tsl//tsl/platform:status",
+        "@local_tsl//tsl/platform:statusor",
     ],
 )
 
@@ -367,6 +378,7 @@ xla_cc_test(
     name = "gpu_indexing_performance_model_test",
     srcs = ["gpu_indexing_performance_model_test.cc"],
    deps = [
+        ":fusion_analysis_cache",
         ":gpu_hlo_cost_analysis",
         ":gpu_indexing_performance_model",
         ":gpu_performance_model_base",
diff --git a/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.cc b/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.cc
index 3ab064f4391f41..6b8cc815fc52f0 100644
--- a/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.cc
+++ b/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.cc
@@ -17,15 +17,21 @@ limitations under the License.
 
 #include <algorithm>
 #include <cstdint>
+#include <utility>
+#include <variant>
 #include <vector>
 
+#include "absl/container/flat_hash_map.h"
 #include "absl/log/check.h"
 #include "absl/log/log.h"
+#include "absl/status/status.h"
+#include "absl/strings/str_cat.h"
 #include "absl/time/time.h"
 #include "absl/types/span.h"
 #include "xla/hlo/ir/hlo_instruction.h"
 #include "xla/hlo/ir/hlo_opcode.h"
 #include "xla/service/gpu/backend_configs.pb.h"
+#include "xla/service/gpu/fusions/triton.h"
 #include "xla/service/gpu/hlo_fusion_analysis.h"
 #include "xla/service/gpu/hlo_traversal.h"
 #include "xla/service/gpu/launch_dimensions.h"
@@ -34,10 +40,14 @@ limitations under the License.
#include "xla/service/gpu/model/gpu_performance_model_base.h" #include "xla/service/gpu/model/indexing_analysis.h" #include "xla/service/gpu/model/indexing_map.h" +#include "xla/service/gpu/model/symbolic_tile_analysis.h" +#include "xla/service/gpu/model/tiled_hlo_computation.h" +#include "xla/service/instruction_fusion.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/util.h" #include "tsl/platform/status.h" +#include "tsl/platform/statusor.h" namespace xla { namespace gpu { @@ -240,5 +250,107 @@ GpuPerformanceModelWithIndexingAnalysis::EstimateRunTimes( return {time_unfused, time_fused}; } +absl::StatusOr +GpuPerformanceModelWithIndexingAnalysis::EstimateRunTimeForTiledFusion( + const HloFusionAdaptor& fusion_adaptor, + const LaunchDimensions& launch_dimensions, + const std::vector& tile_sizes) { + // TODO(b/332714755): Add caching for SymbolicTileAnalysis. + SymbolicTileAnalysisOrError analysis_or_error = + SymbolicTileAnalysis::AnalyzeFusion(fusion_adaptor, mlir_context_); + if (!std::holds_alternative(analysis_or_error)) { + return absl::FailedPreconditionError( + absl::StrCat("SymbolicTileAnalysis failed. ", + std::get(analysis_or_error).Explain())); + } + SymbolicTileAnalysis analysis = + std::get(std::move(analysis_or_error)); + + TF_ASSIGN_OR_RETURN(TiledHloComputation tiled_hlo_computation, + analysis.ComputeTiledHloInstructions(tile_sizes)); + + absl::flat_hash_map n_bytes_total_map; + + int64_t flops = 0; + int64_t bytes_read = 0; + + for (const auto& tiled_hlo : tiled_hlo_computation.instructions()) { + // Number of blocks that read or compute this tile. + int64_t num_blocks = tiled_hlo->block_id_to_tile_offsets_indexing() + .GetDimensionBound(0) + .GetLoopTripCount(); + + // Total number of elements that are read from memory or computed for this + // tile across all blocks. + int64_t num_elements = num_blocks * Product(tiled_hlo->tile_sizes()); + + const HloInstruction* hlo = tiled_hlo->hlo(); + + if (fusion_adaptor.ContainsInstruction(hlo)) { + // Tiles inside the computation contribute to the total FLOPs count. + flops += FlopsPerElement(hlo) * num_elements; + } else { + // Tiles of the operands of the fusion contribute to the total memory + // read time. 
+      int64_t element_type_size =
+          ShapeUtil::ByteSizeOfPrimitiveType(hlo->shape().element_type());
+      int64_t tile_bytes_read = element_type_size * num_elements;
+
+      bytes_read += tile_bytes_read;
+      n_bytes_total_map[hlo] += tile_bytes_read;
+    }
+  }
+
+  int64_t num_blocks = launch_dimensions.num_blocks();
+  absl::Duration read_time = absl::ZeroDuration();
+  for (const auto& [hlo, n_bytes_total] : n_bytes_total_map) {
+    int64_t operand_size = shape_size_(hlo->shape());
+    int64_t n_bytes_net = std::min(operand_size, n_bytes_total);
+
+    read_time += ReadTimeWithDRAMHeuristic(
+        *device_info_, num_blocks, n_bytes_net, n_bytes_total,
+        /*element_type=*/hlo->shape().element_type(),
+        /*coalesced=*/true);
+  }
+
+  int64_t bytes_written =
+      GetShapeSizeRecursive(tiled_hlo_computation.GetRoot()->hlo()->shape());
+
+  absl::Duration compute_time =
+      ComputeTime(*device_info_, flops, launch_dimensions.num_blocks(),
+                  launch_dimensions.num_threads_per_block());
+  absl::Duration write_time = WriteTime(*device_info_, bytes_written);
+  absl::Duration memory_access_time = read_time + write_time;
+  absl::Duration exec_time = CombineComputeAndMemoryAccessTime(
+      compute_time, memory_access_time,
+      GpuPerformanceModelOptions::PriorityFusion());
+
+  return EstimateRunTimeData{/*flops=*/flops,
+                             /*bytes_read=*/bytes_read,
+                             /*bytes_written=*/bytes_written,
+                             /*read_time=*/read_time,
+                             /*write_time=*/write_time,
+                             /*compute_time=*/compute_time,
+                             /*exec_time=*/exec_time};
+}
+
+absl::StatusOr<EstimateRunTimeData>
+GpuPerformanceModelWithIndexingAnalysis::EstimateRunTimeForTriton(
+    const HloInstruction* producer, const HloInstruction* consumer) {
+  const auto& fusion_analysis =
+      (consumer == nullptr) ? fusion_analysis_cache_->Get(*producer)
+                            : fusion_analysis_cache_->Get(*producer, *consumer);
+  auto launch_config = TritonFusion(fusion_analysis).launch_config();
+
+  if (!launch_config.has_value()) {
+    return absl::InvalidArgumentError(
+        "Could not get launch config for Triton fusion.");
+  }
+
+  return EstimateRunTimeForTiledFusion(fusion_analysis.fusion(),
+                                       launch_config->launch_dimensions,
+                                       launch_config->output_tile_sizes);
+}
+
 }  // namespace gpu
 }  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.h b/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.h
index 14d7e520a820d3..a1f98a8660d663 100644
--- a/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.h
+++ b/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.h
@@ -16,12 +16,18 @@ limitations under the License.
 #ifndef XLA_SERVICE_GPU_MODEL_GPU_INDEXING_PERFORMANCE_MODEL_H_
 #define XLA_SERVICE_GPU_MODEL_GPU_INDEXING_PERFORMANCE_MODEL_H_
 
+#include <cstdint>
 #include <optional>
+#include <vector>
 
+#include "absl/status/statusor.h"
 #include "absl/types/span.h"
 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
 #include "xla/hlo/ir/hlo_instruction.h"
 #include "xla/service/gpu/hlo_fusion_analysis.h"
+#include "xla/service/gpu/hlo_traversal.h"
+#include "xla/service/gpu/launch_dimensions.h"
+#include "xla/service/gpu/model/fusion_analysis_cache.h"
 #include "xla/service/gpu/model/gpu_performance_model_base.h"
 #include "xla/service/gpu/model/hlo_op_profiles.h"
 #include "xla/service/hlo_cost_analysis.h"
@@ -37,10 +43,12 @@ class GpuPerformanceModelWithIndexingAnalysis : public GpuPerformanceModelBase {
  public:
   explicit GpuPerformanceModelWithIndexingAnalysis(
       const se::DeviceDescription* device_info,
+      HloFusionAnalysisCache* fusion_analysis_cache,
       HloCostAnalysis::ShapeSizeFunction shape_size,
       mlir::MLIRContext* mlir_context)
       : hlo_op_profile_(&HloOpProfiles::Singleton().GetProfile(device_info)),
         device_info_(device_info),
+        fusion_analysis_cache_(fusion_analysis_cache),
         shape_size_(shape_size),
         mlir_context_(mlir_context) {}
 
@@ -57,6 +65,23 @@ class GpuPerformanceModelWithIndexingAnalysis : public GpuPerformanceModelBase {
       const HloInstruction* producer,
       absl::Span<const HloInstruction* const> fused_consumers = {});
 
+  // Estimate the run time of the fusion with the given launch dimensions and
+  // output tile sizes.
+  //
+  // The model uses SymbolicTileAnalysis to build a TiledHloComputation with the
+  // given tile sizes. This way it can better estimate the amount of memory
+  // access and computation.
+  absl::StatusOr<EstimateRunTimeData> EstimateRunTimeForTiledFusion(
+      const HloFusionAdaptor& fusion_adaptor,
+      const LaunchDimensions& launch_dimensions,
+      const std::vector<int64_t>& output_tile_sizes);
+
+  // Estimate the run time of producer and consumer fused together, assuming
+  // that they will be emitted with Triton.
+  // If consumer is nullptr, estimate run time of the producer alone.
+  absl::StatusOr<EstimateRunTimeData> EstimateRunTimeForTriton(
+      const HloInstruction* producer, const HloInstruction* consumer = nullptr);
+
  private:
   // Returns an estimate how many FLOPs will be used to produce one element of
   // the output.
@@ -66,6 +91,7 @@ class GpuPerformanceModelWithIndexingAnalysis : public GpuPerformanceModelBase {
 
   const HloOpProfiles::HloOpProfile* hlo_op_profile_;
   const se::DeviceDescription* device_info_;
+  HloFusionAnalysisCache* fusion_analysis_cache_;
   HloCostAnalysis::ShapeSizeFunction shape_size_;
   mlir::MLIRContext* mlir_context_;
 };
diff --git a/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model_test.cc b/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model_test.cc
index 5e52685762e524..57468cf258d521 100644
--- a/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model_test.cc
+++ b/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model_test.cc
@@ -26,6 +26,7 @@ limitations under the License.
 #include "xla/hlo/ir/hlo_instruction.h"
 #include "xla/service/gpu/backend_configs.pb.h"
 #include "xla/service/gpu/gpu_device_info_for_tests.h"
+#include "xla/service/gpu/model/fusion_analysis_cache.h"
 #include "xla/service/gpu/model/gpu_hlo_cost_analysis.h"
 #include "xla/service/gpu/model/gpu_performance_model_base.h"
 #include "xla/shape.h"
@@ -51,8 +52,10 @@ class GpuIndexingPerformanceModelTest : public HloTestBase {
   // The reference times in the test cases below are measured
   // on A6000 by profiling the execution of the HLOs.
   se::DeviceDescription device_info_{TestGpuDeviceInfo::RTXA6000DeviceInfo()};
+  HloFusionAnalysisCache fusion_analysis_cache_{device_info_};
   GpuPerformanceModelWithIndexingAnalysis indexing_cost_model_{
-      &device_info_, ShapeSizeBytesFunction(), &mlir_context_};
+      &device_info_, &fusion_analysis_cache_, ShapeSizeBytesFunction(),
+      &mlir_context_};
 
   GpuIndexingPerformanceModelTest() : HloTestBase() {}
 };
@@ -167,6 +170,106 @@ ENTRY entry_computation {
   EXPECT_NEAR(absl::ToDoubleNanoseconds(runtime_data.exec_time), 58, 1);
 }
 
+TEST_F(GpuIndexingPerformanceModelTest,
+       TritonSoftmaxFusionInstructionIsSupported) {
+  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"(
+HloModule m
+
+add {
+  Arg_0 = f32[] parameter(0)
+  Arg_1 = f32[] parameter(1)
+  ROOT add = f32[] add(Arg_0, Arg_1)
+}
+
+triton_softmax_computation {
+  param_0 = f32[512,911]{1,0} parameter(0)
+  param_1 = f32[911]{0} parameter(1)
+  broadcast_0 = f32[512,911]{1,0} broadcast(param_1), dimensions={1}
+  multiply_0 = f32[512,911]{1,0} multiply(param_0, broadcast_0)
+  constant_0 = f32[] constant(0)
+  reduce_0 = f32[512]{0} reduce(multiply_0, constant_0), dimensions={1}, to_apply=add
+  broadcast_4 = f32[512,911]{1,0} broadcast(reduce_0), dimensions={0}
+  ROOT multiply = f32[512,911]{1,0} multiply(multiply_0, broadcast_4)
+}
+
+ENTRY main {
+  param_0 = f32[512,911]{1,0} parameter(0)
+  param_1 = f32[911]{0} parameter(1)
+  ROOT triton_softmax = f32[512,911]{1,0} fusion(param_0, param_1), kind=kCustom, calls=triton_softmax_computation, backend_config={"fusion_backend_config": {"kind":"__triton"}}
+}
+)"));
+  TF_ASSERT_OK_AND_ASSIGN(auto runtime_data,
+                          indexing_cost_model_.EstimateRunTimeForTriton(
+                              module->entry_computation()->root_instruction()));
+
+  constexpr int64_t kParam0SizeBytes = 512 * 911 * 4;
+  constexpr int64_t kParam1SizeBytes = 911 * 4;
+  constexpr int64_t kOutputSizeBytes = 512 * 911 * 4;
+
+  // Each block reads 1 tile of shape [1, 911] from param_0 and full param_1.
+  // In total param_0 is read once and param_1 is read 512 times.
+  constexpr int64_t kExpectedBytesRead =
+      kParam0SizeBytes + 512 * kParam1SizeBytes;
+
+  EXPECT_EQ(runtime_data.bytes_read, kExpectedBytesRead);
+  EXPECT_EQ(runtime_data.bytes_written, kOutputSizeBytes);
+  EXPECT_NEAR(absl::ToDoubleMicroseconds(runtime_data.exec_time), 5, 1);
+}
+
+TEST_F(GpuIndexingPerformanceModelTest,
+       TritonSoftmaxProducerConsumerFusionIsSupported) {
+  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"(
+HloModule m
+
+add {
+  Arg_0 = f32[] parameter(0)
+  Arg_1 = f32[] parameter(1)
+  ROOT add = f32[] add(Arg_0, Arg_1)
+}
+
+fusion {
+  param_0 = f32[512,911] parameter(0)
+  param_1 = f32[911] parameter(1)
+  broadcast = f32[512,911] broadcast(param_1), dimensions={1}
+  ROOT multiply = f32[512,911] multiply(param_0, broadcast)
+}
+
+triton_softmax_computation {
+  param_0 = f32[512,911] parameter(0)
+  constant_0 = f32[] constant(0)
+  reduce_0 = f32[512] reduce(param_0, constant_0), dimensions={1}, to_apply=add
+  broadcast_4 = f32[512,911] broadcast(reduce_0), dimensions={0}
+  ROOT multiply = f32[512,911] multiply(param_0, broadcast_4)
+}
+
+ENTRY main {
+  param_0 = f32[512,911] parameter(0)
+  param_1 = f32[911] parameter(1)
+  fusion.1 = f32[512,911] fusion(param_0, param_1), kind=kLoop, calls=fusion
+  ROOT triton_softmax = f32[512,911] fusion(fusion.1), kind=kCustom, calls=triton_softmax_computation, backend_config={"fusion_backend_config": {"kind":"__triton"}}
+}
+)"));
+  auto consumer = module->entry_computation()->root_instruction();
+  auto producer = consumer->operand(0);
+
+  TF_ASSERT_OK_AND_ASSIGN(
+      auto runtime_data,
+      indexing_cost_model_.EstimateRunTimeForTriton(producer, consumer));
+
+  constexpr int64_t kParam0SizeBytes = 512 * 911 * 4;
+  constexpr int64_t kParam1SizeBytes = 911 * 4;
+  constexpr int64_t kOutputSizeBytes = 512 * 911 * 4;
+
+  // Each block reads 1 tile of shape [1, 911] from param_0 and full param_1.
+  // In total param_0 is read once and param_1 is read 512 times.
+  constexpr int64_t kExpectedBytesRead =
+      kParam0SizeBytes + 512 * kParam1SizeBytes;
+
+  EXPECT_EQ(runtime_data.bytes_read, kExpectedBytesRead);
+  EXPECT_EQ(runtime_data.bytes_written, kOutputSizeBytes);
+  EXPECT_NEAR(absl::ToDoubleMicroseconds(runtime_data.exec_time), 5, 1);
+}
+
 }  // namespace
 }  // namespace gpu
 }  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/model/gpu_performance_model_test.cc b/third_party/xla/xla/service/gpu/model/gpu_performance_model_test.cc
index c8c42019d2efdb..e76f783f477e95 100644
--- a/third_party/xla/xla/service/gpu/model/gpu_performance_model_test.cc
+++ b/third_party/xla/xla/service/gpu/model/gpu_performance_model_test.cc
@@ -31,6 +31,7 @@ limitations under the License.
 #include "xla/hlo/ir/hlo_opcode.h"
 #include "xla/service/gpu/backend_configs.pb.h"
 #include "xla/service/gpu/gpu_device_info_for_tests.h"
+#include "xla/service/gpu/model/fusion_analysis_cache.h"
 #include "xla/service/gpu/model/gpu_hlo_cost_analysis.h"
 #include "xla/service/gpu/model/gpu_indexing_performance_model.h"
 #include "xla/service/gpu/model/gpu_performance_model_base.h"
@@ -79,10 +80,12 @@ class GpuPerformanceModelTest : public HloTestBase {
   // The reference times in the test cases below are measured
   // on A6000 by profiling the execution of the HLOs.
   se::DeviceDescription device_info_{TestGpuDeviceInfo::RTXA6000DeviceInfo()};
+  HloFusionAnalysisCache fusion_analysis_cache_{device_info_};
   GpuHloCostAnalysis analysis_{options_, &device_info_};
 
   GpuPerformanceModelWithIndexingAnalysis indexing_cost_model_{
-      &device_info_, ShapeSizeBytesFunction(), &mlir_context_};
+      &device_info_, &fusion_analysis_cache_, ShapeSizeBytesFunction(),
+      &mlir_context_};
 
   GpuPerformanceModelTest() : HloTestBase() {}
 };
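
Note (outside the patch): the updated test fixtures above show the intended wiring for the new API. Below is a minimal usage sketch of the same flow outside a test fixture. The helper name EstimateTritonFusionRunTime and the explicit shape-size lambda are illustrative assumptions; the tests use HloTestBase's ShapeSizeBytesFunction() and the test-only RTXA6000 device description instead.

// Illustrative sketch only -- not part of the patch. Shows how the extended
// constructor threads HloFusionAnalysisCache through the model and how the
// new EstimateRunTimeForTriton entry point is called.
absl::StatusOr<EstimateRunTimeData> EstimateTritonFusionRunTime(
    const HloInstruction* triton_fusion) {
  // The model borrows these objects; they must outlive it.
  se::DeviceDescription device_info = TestGpuDeviceInfo::RTXA6000DeviceInfo();
  HloFusionAnalysisCache fusion_analysis_cache(device_info);
  mlir::MLIRContext mlir_context;

  // Assumed shape-size function for illustration; any
  // HloCostAnalysis::ShapeSizeFunction works here.
  HloCostAnalysis::ShapeSizeFunction shape_size = [](const Shape& shape) {
    return ShapeUtil::ByteSizeOf(shape, /*pointer_size=*/8);
  };

  GpuPerformanceModelWithIndexingAnalysis model(
      &device_info, &fusion_analysis_cache, shape_size, &mlir_context);

  // With the default null consumer the producer is costed alone. Internally
  // this resolves the Triton launch config and delegates to
  // EstimateRunTimeForTiledFusion with the config's output tile sizes.
  return model.EstimateRunTimeForTriton(triton_fusion);
}

Threading the cache through the constructor lets repeated producer/consumer cost queries reuse HloFusionAnalysis results instead of recomputing them per call.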