From 015a5d623899b64afad0747404c0ab9590239f4e Mon Sep 17 00:00:00 2001 From: Christian Sigg Date: Thu, 13 Jun 2024 06:04:10 -0700 Subject: [PATCH] Move BlockedSparseToMMA pattern from Triton to XLA. FUTURE_COPYBARA_INTEGRATE_REVIEW=https://github.com/openxla/xla/pull/13781 from openxla:kernel_cache_grow f421890b6548a9a4dfcc5350e791bc8860615dbc PiperOrigin-RevId: 642959490 --- .../triton/xla_extensions/sparse_dot.patch | 191 +++++------------- .../triton/xla_extensions/sparse_dot.patch | 191 +++++------------- third_party/xla/xla/debug_options_flags.cc | 16 +- third_party/xla/xla/service/gpu/BUILD | 12 +- .../xla/xla/service/gpu/gpu_compiler.cc | 69 +++---- .../xla/xla/service/gpu/gpu_compiler_test.cc | 22 +- .../xla/xla/service/gpu/gpu_hlo_schedule.cc | 91 --------- .../xla/xla/service/gpu/gpu_hlo_schedule.h | 2 - .../gpu/gpu_latency_hiding_scheduler.cc | 92 +++++++++ .../gpu/gpu_latency_hiding_scheduler.h | 23 +++ .../xla/service/gpu/ir_emitter_triton_cuda.cc | 1 + .../xla/service/gpu/ir_emitter_triton_rocm.cc | 1 + .../xla/xla/service/gpu/kernel_reuse_cache.cc | 36 ++++ .../xla/xla/service/gpu/kernel_reuse_cache.h | 12 ++ .../service/gpu/kernel_reuse_cache_test.cc | 38 +++- third_party/xla/xla/service/gpu/model/BUILD | 3 + .../xla/service/gpu/model/symbolic_tile.cc | 10 +- .../xla/xla/service/gpu/model/symbolic_tile.h | 3 +- .../gpu/model/symbolic_tile_analysis_test.cc | 77 +++++++ .../service/gpu/model/symbolic_tile_test.cc | 25 ++- .../xla/service/gpu/nvptx_compiler_test.cc | 1 + .../tests/sparse_ttg_accelerate_matmul.mlir | 2 +- .../service/gpu/triton_sparse_extensions.cc | 142 ++++++++++++- .../service/gpu/triton_sparse_extensions.h | 1 + 24 files changed, 607 insertions(+), 454 deletions(-) diff --git a/third_party/triton/xla_extensions/sparse_dot.patch b/third_party/triton/xla_extensions/sparse_dot.patch index 307e4cdbb947b4..ae8c5c61522860 100644 --- a/third_party/triton/xla_extensions/sparse_dot.patch +++ b/third_party/triton/xla_extensions/sparse_dot.patch @@ -181,46 +181,48 @@ diff --git a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp b/lib/Dialect index c47558fa6..35f0cca95 100644 --- a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp +++ b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp -@@ -37,7 +37,8 @@ static int getMMAVersionSafe(int computeCapability, DotOp op) { +@@ -37,8 +37,8 @@ static int getMMAVersionSafe(int computeCapability, DotOp op) { return 0; } -SmallVector warpsPerTileV2(DotOp dotOp, const ArrayRef shape, -+template -+SmallVector warpsPerTileV2(DotType dotOp, const ArrayRef shape, - int numWarps) { +- int numWarps) { ++SmallVector ++warpsPerTileV2(Operation *dotOp, const ArrayRef shape, int numWarps) { auto rank = shape.size(); // Early exit for batched matmul -@@ -51,8 +52,8 @@ SmallVector warpsPerTileV2(DotOp dotOp, const ArrayRef shape, + if (rank == 3) +@@ -51,9 +51,8 @@ SmallVector warpsPerTileV2(DotOp dotOp, const ArrayRef shape, auto slices = multiRootGetSlice(dotOp, {filter}, {filter}); bool hasChainedDot = false; for (Operation *op : slices) { - if (isa(op) && (op != dotOp)) { - auto chainedDot = cast(op); -+ if (isa(op) && (op != dotOp)) { -+ auto chainedDot = cast(op); - auto resTy = chainedDot.getResult().getType(); +- auto resTy = chainedDot.getResult().getType(); ++ if (dotOp->getName() == op->getName() && op != dotOp) { ++ auto resTy = cast(op->getResult(0).getType()); if (resTy.getRank() != rank) { continue; -@@ -96,12 +97,13 @@ SmallVector warpsPerTileV2(DotOp dotOp, const ArrayRef shape, - return ret; + } +@@ -97,12 
+96,13 @@ SmallVector warpsPerTileV2(DotOp dotOp, const ArrayRef shape, } --SmallVector + SmallVector -warpsPerTileV3(DotOp dotOp, const ArrayRef shape, int numWarps, -- const SmallVector &instrShape) { -+template -+SmallVector warpsPerTileV3( -+ DotType dotOp, const ArrayRef shape, int numWarps, -+ const SmallVector &instrShape) { ++warpsPerTileV3(Operation *dotOp, const ArrayRef shape, int numWarps, + const SmallVector &instrShape) { SetVector slices; - mlir::getForwardSlice(dotOp.getResult(), &slices); +- mlir::getForwardSlice(dotOp.getResult(), &slices); - if (llvm::find_if(slices, [](Operation *op) { return isa(op); }) != -+ if (llvm::find_if(slices, [](Operation *op) { return isa(op); }) != - slices.end()) +- slices.end()) ++ mlir::getForwardSlice(dotOp->getResult(0), &slices); ++ if (llvm::find_if(slices, [&](Operation *op) { ++ return dotOp->getName() == op->getName(); ++ }) != slices.end()) return {(unsigned)numWarps, 1}; -@@ -191,6 +193,7 @@ class BlockedToMMA : public mlir::OpRewritePattern { + // For MMAv3, the smallest indivisible unit of warp shape is (4, 1). +@@ -167,6 +167,7 @@ class BlockedToMMA : public mlir::OpRewritePattern { mlir::TypeID::get()); } @@ -228,140 +230,37 @@ index c47558fa6..35f0cca95 100644 // Finds the first different bitwidth in the chain of shape-preserving // unary ops that x depends on. // There are two primary scenarios: -@@ -224,14 +227,14 @@ class BlockedToMMA : public mlir::OpRewritePattern { - return origBitWidth; +@@ -206,7 +207,7 @@ public: } --public: - BlockedToMMA(mlir::MLIRContext *context, int computeCapability) - : OpRewritePattern(context), computeCapability(computeCapability) { - } - -- static SmallVector + static SmallVector - getWarpsPerTile(DotOp dotOp, const ArrayRef shape, int version, -- int numWarps, const SmallVector &instrShape) { -+ template -+ static SmallVector getWarpsPerTile( -+ DotType dotOp, const ArrayRef shape, int version, int numWarps, -+ const SmallVector &instrShape) { ++ getWarpsPerTile(Operation *dotOp, const ArrayRef shape, int version, + int numWarps, const SmallVector &instrShape) { switch (version) { case 2: - return warpsPerTileV2(dotOp, shape, numWarps); -@@ -359,6 +362,106 @@ public: - return success(); +@@ -405,6 +406,21 @@ public: } }; -+ -+class SparseBlockedToMMA : public mlir::RewritePattern { -+ public: -+ using SparseDotOp = mlir::triton::gpu::SparseDotOp; -+ using SparseDotMetaEncodingAttr = -+ mlir::triton::gpu::SparseDotMetaEncodingAttr; -+ -+ SparseBlockedToMMA(mlir::MLIRContext *context, int computeCapability) -+ : mlir::RewritePattern(SparseDotOp::getOperationName(), 2, context), -+ computeCapability(computeCapability) {} -+ -+ mlir::LogicalResult matchAndRewrite( -+ mlir::Operation *op, mlir::PatternRewriter &rewriter) const override { -+ auto dotOp = cast(op); -+ auto ctx = op->getContext(); -+ Value a = dotOp.getA(); -+ Value b = dotOp.getB(); -+ -+ // Check data-types and SM compatibility -+ RankedTensorType oldRetType = dotOp.getType(); -+ if (!oldRetType.getEncoding() || -+ isa(oldRetType.getEncoding())) -+ return failure(); -+ -+ assert(computeCapability >= 80 && -+ "SparseDot is supported on Ampere and higher"); -+ bool allowV3 = !triton::tools::getBoolEnv("DISABLE_MMA_V3"); -+ int versionMajor = computeCapability >= 90 && allowV3 ? 
3 : 2; -+ -+ // get MMA encoding for the given number of warps -+ auto retShapePerCTA = getShapePerCTA(oldRetType); -+ auto mod = op->getParentOfType(); -+ int numWarps = TritonGPUDialect::getNumWarps(mod); -+ auto CTALayout = getCTALayout(oldRetType.getEncoding()); -+ -+ auto instrShape = -+ mmaVersionToInstrShape(versionMajor, retShapePerCTA, -+ cast(a.getType()), numWarps); -+ auto warpsPerTile = BlockedToMMA::getWarpsPerTile( -+ dotOp, retShapePerCTA, versionMajor, numWarps, instrShape); -+ NvidiaMmaEncodingAttr mmaEnc = -+ NvidiaMmaEncodingAttr::get(ctx, versionMajor, /*versionMinor=*/0, -+ warpsPerTile, CTALayout, instrShape); -+ auto newRetType = RankedTensorType::get( -+ oldRetType.getShape(), oldRetType.getElementType(), mmaEnc); -+ -+ // convert accumulator -+ auto oldAcc = dotOp.getOperand(2); -+ auto newAcc = rewriter.create(oldAcc.getLoc(), -+ newRetType, oldAcc); -+ -+ if (versionMajor == 2) { -+ int minBitwidth = std::min(BlockedToMMA::computeOrigBitWidth(a), -+ BlockedToMMA::computeOrigBitWidth(b)); -+ int kWidth = 32 / minBitwidth; -+ -+ // convert A operand -+ auto oldAType = cast(a.getType()); -+ auto newAEncoding = -+ DotOperandEncodingAttr::get(ctx, 0, mmaEnc, kWidth); -+ auto newAType = RankedTensorType::get( -+ oldAType.getShape(), oldAType.getElementType(), newAEncoding); -+ a = rewriter.create(a.getLoc(), newAType, a); -+ -+ // convert B operand -+ auto oldBType = cast(b.getType()); -+ auto newBEncoding = -+ DotOperandEncodingAttr::get(ctx, 1, mmaEnc, kWidth); -+ auto newBType = RankedTensorType::get( -+ oldBType.getShape(), oldBType.getElementType(), newBEncoding); -+ b = rewriter.create(b.getLoc(), newBType, b); -+ } else { -+ auto eltType = dotOp.getA().getType().getElementType(); -+ // In MMAV3 tranpose is only supported for f16 and bf16. -+ bool allowTranspose = eltType.isF16() || eltType.isBF16(); -+ a = getSharedMemoryMMAOperand(a, rewriter, 0, allowTranspose); -+ b = getSharedMemoryMMAOperand(b, rewriter, 1, allowTranspose); -+ } -+ -+ // convert metadata -+ Value meta = dotOp.getAMeta(); -+ auto oldMetaType = cast(meta.getType()); -+ auto newMetaType = RankedTensorType::get( -+ oldMetaType.getShape(), oldMetaType.getElementType(), -+ SparseDotMetaEncodingAttr::get(ctx, mmaEnc)); -+ meta = -+ rewriter.create(meta.getLoc(), newMetaType, meta); -+ -+ // convert dot instruction -+ auto newDot = rewriter.create(dotOp.getLoc(), newRetType, a, b, -+ newAcc, meta); -+ -+ rewriter.replaceOpWithNewOp(op, oldRetType, -+ newDot.getResult()); -+ return success(); -+ } -+ -+ private: -+ int computeCapability; -+}; - } // namespace - - static Value promoteOperand(OpBuilder &builder, Location loc, Value operand, -@@ -420,6 +523,7 @@ public: - mlir::RewritePatternSet patterns(context); - patterns.add(context, computeCapability); -+ patterns.add(context, computeCapability); - if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed()) { - signalPassFailure(); - } ++// Expose helper functions from BlockedToMMA to be reused for sparse matmul. 
++SmallVector ++getWarpsPerTile(Operation *dotOp, ArrayRef shape, int version, ++ int numWarps, const SmallVector &instrShape) { ++ return BlockedToMMA::getWarpsPerTile(dotOp, shape, version, numWarps, ++ instrShape); ++} ++int computeOrigBitWidth(Value x) { ++ return BlockedToMMA::computeOrigBitWidth(x); ++} ++Value getSharedMemMMAOperand(Value v, mlir::PatternRewriter &rewriter, ++ int opIdx, bool allowTranspose) { ++ return getSharedMemoryMMAOperand(v, rewriter, opIdx, allowTranspose); ++} ++ + } // namespace gpu + } // namespace triton + } // namespace mlir diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp index 5cb992714..cdafdffce 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp diff --git a/third_party/xla/third_party/triton/xla_extensions/sparse_dot.patch b/third_party/xla/third_party/triton/xla_extensions/sparse_dot.patch index 307e4cdbb947b4..ae8c5c61522860 100644 --- a/third_party/xla/third_party/triton/xla_extensions/sparse_dot.patch +++ b/third_party/xla/third_party/triton/xla_extensions/sparse_dot.patch @@ -181,46 +181,48 @@ diff --git a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp b/lib/Dialect index c47558fa6..35f0cca95 100644 --- a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp +++ b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp -@@ -37,7 +37,8 @@ static int getMMAVersionSafe(int computeCapability, DotOp op) { +@@ -37,8 +37,8 @@ static int getMMAVersionSafe(int computeCapability, DotOp op) { return 0; } -SmallVector warpsPerTileV2(DotOp dotOp, const ArrayRef shape, -+template -+SmallVector warpsPerTileV2(DotType dotOp, const ArrayRef shape, - int numWarps) { +- int numWarps) { ++SmallVector ++warpsPerTileV2(Operation *dotOp, const ArrayRef shape, int numWarps) { auto rank = shape.size(); // Early exit for batched matmul -@@ -51,8 +52,8 @@ SmallVector warpsPerTileV2(DotOp dotOp, const ArrayRef shape, + if (rank == 3) +@@ -51,9 +51,8 @@ SmallVector warpsPerTileV2(DotOp dotOp, const ArrayRef shape, auto slices = multiRootGetSlice(dotOp, {filter}, {filter}); bool hasChainedDot = false; for (Operation *op : slices) { - if (isa(op) && (op != dotOp)) { - auto chainedDot = cast(op); -+ if (isa(op) && (op != dotOp)) { -+ auto chainedDot = cast(op); - auto resTy = chainedDot.getResult().getType(); +- auto resTy = chainedDot.getResult().getType(); ++ if (dotOp->getName() == op->getName() && op != dotOp) { ++ auto resTy = cast(op->getResult(0).getType()); if (resTy.getRank() != rank) { continue; -@@ -96,12 +97,13 @@ SmallVector warpsPerTileV2(DotOp dotOp, const ArrayRef shape, - return ret; + } +@@ -97,12 +96,13 @@ SmallVector warpsPerTileV2(DotOp dotOp, const ArrayRef shape, } --SmallVector + SmallVector -warpsPerTileV3(DotOp dotOp, const ArrayRef shape, int numWarps, -- const SmallVector &instrShape) { -+template -+SmallVector warpsPerTileV3( -+ DotType dotOp, const ArrayRef shape, int numWarps, -+ const SmallVector &instrShape) { ++warpsPerTileV3(Operation *dotOp, const ArrayRef shape, int numWarps, + const SmallVector &instrShape) { SetVector slices; - mlir::getForwardSlice(dotOp.getResult(), &slices); +- mlir::getForwardSlice(dotOp.getResult(), &slices); - if (llvm::find_if(slices, [](Operation *op) { return isa(op); }) != -+ if (llvm::find_if(slices, [](Operation *op) { return isa(op); }) != - slices.end()) +- slices.end()) ++ mlir::getForwardSlice(dotOp->getResult(0), &slices); ++ if (llvm::find_if(slices, [&](Operation *op) { ++ 
return dotOp->getName() == op->getName(); ++ }) != slices.end()) return {(unsigned)numWarps, 1}; -@@ -191,6 +193,7 @@ class BlockedToMMA : public mlir::OpRewritePattern { + // For MMAv3, the smallest indivisible unit of warp shape is (4, 1). +@@ -167,6 +167,7 @@ class BlockedToMMA : public mlir::OpRewritePattern { mlir::TypeID::get()); } @@ -228,140 +230,37 @@ index c47558fa6..35f0cca95 100644 // Finds the first different bitwidth in the chain of shape-preserving // unary ops that x depends on. // There are two primary scenarios: -@@ -224,14 +227,14 @@ class BlockedToMMA : public mlir::OpRewritePattern { - return origBitWidth; +@@ -206,7 +207,7 @@ public: } --public: - BlockedToMMA(mlir::MLIRContext *context, int computeCapability) - : OpRewritePattern(context), computeCapability(computeCapability) { - } - -- static SmallVector + static SmallVector - getWarpsPerTile(DotOp dotOp, const ArrayRef shape, int version, -- int numWarps, const SmallVector &instrShape) { -+ template -+ static SmallVector getWarpsPerTile( -+ DotType dotOp, const ArrayRef shape, int version, int numWarps, -+ const SmallVector &instrShape) { ++ getWarpsPerTile(Operation *dotOp, const ArrayRef shape, int version, + int numWarps, const SmallVector &instrShape) { switch (version) { case 2: - return warpsPerTileV2(dotOp, shape, numWarps); -@@ -359,6 +362,106 @@ public: - return success(); +@@ -405,6 +406,21 @@ public: } }; -+ -+class SparseBlockedToMMA : public mlir::RewritePattern { -+ public: -+ using SparseDotOp = mlir::triton::gpu::SparseDotOp; -+ using SparseDotMetaEncodingAttr = -+ mlir::triton::gpu::SparseDotMetaEncodingAttr; -+ -+ SparseBlockedToMMA(mlir::MLIRContext *context, int computeCapability) -+ : mlir::RewritePattern(SparseDotOp::getOperationName(), 2, context), -+ computeCapability(computeCapability) {} -+ -+ mlir::LogicalResult matchAndRewrite( -+ mlir::Operation *op, mlir::PatternRewriter &rewriter) const override { -+ auto dotOp = cast(op); -+ auto ctx = op->getContext(); -+ Value a = dotOp.getA(); -+ Value b = dotOp.getB(); -+ -+ // Check data-types and SM compatibility -+ RankedTensorType oldRetType = dotOp.getType(); -+ if (!oldRetType.getEncoding() || -+ isa(oldRetType.getEncoding())) -+ return failure(); -+ -+ assert(computeCapability >= 80 && -+ "SparseDot is supported on Ampere and higher"); -+ bool allowV3 = !triton::tools::getBoolEnv("DISABLE_MMA_V3"); -+ int versionMajor = computeCapability >= 90 && allowV3 ? 
3 : 2; -+ -+ // get MMA encoding for the given number of warps -+ auto retShapePerCTA = getShapePerCTA(oldRetType); -+ auto mod = op->getParentOfType(); -+ int numWarps = TritonGPUDialect::getNumWarps(mod); -+ auto CTALayout = getCTALayout(oldRetType.getEncoding()); -+ -+ auto instrShape = -+ mmaVersionToInstrShape(versionMajor, retShapePerCTA, -+ cast(a.getType()), numWarps); -+ auto warpsPerTile = BlockedToMMA::getWarpsPerTile( -+ dotOp, retShapePerCTA, versionMajor, numWarps, instrShape); -+ NvidiaMmaEncodingAttr mmaEnc = -+ NvidiaMmaEncodingAttr::get(ctx, versionMajor, /*versionMinor=*/0, -+ warpsPerTile, CTALayout, instrShape); -+ auto newRetType = RankedTensorType::get( -+ oldRetType.getShape(), oldRetType.getElementType(), mmaEnc); -+ -+ // convert accumulator -+ auto oldAcc = dotOp.getOperand(2); -+ auto newAcc = rewriter.create(oldAcc.getLoc(), -+ newRetType, oldAcc); -+ -+ if (versionMajor == 2) { -+ int minBitwidth = std::min(BlockedToMMA::computeOrigBitWidth(a), -+ BlockedToMMA::computeOrigBitWidth(b)); -+ int kWidth = 32 / minBitwidth; -+ -+ // convert A operand -+ auto oldAType = cast(a.getType()); -+ auto newAEncoding = -+ DotOperandEncodingAttr::get(ctx, 0, mmaEnc, kWidth); -+ auto newAType = RankedTensorType::get( -+ oldAType.getShape(), oldAType.getElementType(), newAEncoding); -+ a = rewriter.create(a.getLoc(), newAType, a); -+ -+ // convert B operand -+ auto oldBType = cast(b.getType()); -+ auto newBEncoding = -+ DotOperandEncodingAttr::get(ctx, 1, mmaEnc, kWidth); -+ auto newBType = RankedTensorType::get( -+ oldBType.getShape(), oldBType.getElementType(), newBEncoding); -+ b = rewriter.create(b.getLoc(), newBType, b); -+ } else { -+ auto eltType = dotOp.getA().getType().getElementType(); -+ // In MMAV3 tranpose is only supported for f16 and bf16. -+ bool allowTranspose = eltType.isF16() || eltType.isBF16(); -+ a = getSharedMemoryMMAOperand(a, rewriter, 0, allowTranspose); -+ b = getSharedMemoryMMAOperand(b, rewriter, 1, allowTranspose); -+ } -+ -+ // convert metadata -+ Value meta = dotOp.getAMeta(); -+ auto oldMetaType = cast(meta.getType()); -+ auto newMetaType = RankedTensorType::get( -+ oldMetaType.getShape(), oldMetaType.getElementType(), -+ SparseDotMetaEncodingAttr::get(ctx, mmaEnc)); -+ meta = -+ rewriter.create(meta.getLoc(), newMetaType, meta); -+ -+ // convert dot instruction -+ auto newDot = rewriter.create(dotOp.getLoc(), newRetType, a, b, -+ newAcc, meta); -+ -+ rewriter.replaceOpWithNewOp(op, oldRetType, -+ newDot.getResult()); -+ return success(); -+ } -+ -+ private: -+ int computeCapability; -+}; - } // namespace - - static Value promoteOperand(OpBuilder &builder, Location loc, Value operand, -@@ -420,6 +523,7 @@ public: - mlir::RewritePatternSet patterns(context); - patterns.add(context, computeCapability); -+ patterns.add(context, computeCapability); - if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed()) { - signalPassFailure(); - } ++// Expose helper functions from BlockedToMMA to be reused for sparse matmul. 
++SmallVector ++getWarpsPerTile(Operation *dotOp, ArrayRef shape, int version, ++ int numWarps, const SmallVector &instrShape) { ++ return BlockedToMMA::getWarpsPerTile(dotOp, shape, version, numWarps, ++ instrShape); ++} ++int computeOrigBitWidth(Value x) { ++ return BlockedToMMA::computeOrigBitWidth(x); ++} ++Value getSharedMemMMAOperand(Value v, mlir::PatternRewriter &rewriter, ++ int opIdx, bool allowTranspose) { ++ return getSharedMemoryMMAOperand(v, rewriter, opIdx, allowTranspose); ++} ++ + } // namespace gpu + } // namespace triton + } // namespace mlir diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp index 5cb992714..cdafdffce 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp diff --git a/third_party/xla/xla/debug_options_flags.cc b/third_party/xla/xla/debug_options_flags.cc index 84594a623d9bcb..7bb8e33a62336b 100644 --- a/third_party/xla/xla/debug_options_flags.cc +++ b/third_party/xla/xla/debug_options_flags.cc @@ -1760,14 +1760,14 @@ void MakeDebugOptionsFlags(std::vector* flag_list, debug_options->xla_gpu_shard_autotuning(), "Shard autotuning between participating compiler processes (typically in " "multi-host setups) and join the results when it's done.")); - flag_list->push_back(tsl::Flag( - "xla_gpu_kernel_cache_file", - string_setter_for(&DebugOptions::set_xla_gpu_kernel_cache_file), - debug_options->xla_gpu_kernel_cache_file(), - "Path to a file to cache compiled kernels. If the file doesn't exist " - "write the compilation cache of the first compiled HLO module into it." - "Once the file exists, further compilations will read it to reuse " - "the kernels, but not write it. This behavior may change later.")); + flag_list->push_back( + tsl::Flag("xla_gpu_kernel_cache_file", + string_setter_for(&DebugOptions::set_xla_gpu_kernel_cache_file), + debug_options->xla_gpu_kernel_cache_file(), + "Path to a file to cache compiled kernels. 
Cached kernels get " + "reused in further compilations; not yet cached kernels are " + "compiled as usual and get appended to the cache file whenever " + "possible.")); } // NOLINT(readability/fn_size) // Allocates flag_values and flag_objects; this function must not be called more diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index 638f9dc37b7007..b61d9785727ec1 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -606,6 +606,7 @@ cc_library( "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", "@triton//:TritonDialects", + "@triton//:TritonGPUToLLVM", "@triton//:TritonGPUTransforms", ], ) @@ -3425,6 +3426,7 @@ cc_library( ":custom_kernel_fusion_rewriter", ":dot_dimension_sorter", ":dot_operand_converter", + ":double_buffer_loop_unrolling", ":executable_proto_cc", ":fusion_merger", ":fusion_wrapper", @@ -3451,7 +3453,7 @@ cc_library( ":instruction_fusion", ":ir_emission_utils", ":ir_emitter", - ":double_buffer_loop_unrolling", + ":kernel_reuse_cache", ":matmul_utils", ":metrics", ":move_copy_to_users", @@ -3625,6 +3627,7 @@ cc_library( ":command_buffer_scheduling", ":execution_stream_assignment", ":fusion_pipeline", + ":gpu_latency_hiding_scheduler", ":ir_emitter_context", ":ir_emitter_unnested", ":prepare_hlo_for_ir_emitting_pipeline", @@ -3872,6 +3875,7 @@ xla_test( deps = [ ":gpu_constants", ":gpu_hlo_schedule", + ":gpu_latency_hiding_scheduler", ":nvptx_compiler_impl", "//xla:util", "//xla:xla_proto_cc", @@ -4231,11 +4235,9 @@ cc_library( hdrs = ["gpu_hlo_schedule.h"], deps = [ ":backend_configs_cc", - ":cublas_cudnn", ":gpu_latency_hiding_scheduler", ":gpu_schedule_postprocessing", "//xla:shape_util", - "//xla:statusor", "//xla:util", "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_query", @@ -5730,7 +5732,9 @@ xla_cc_test( deps = [ ":kernel_reuse_cache", "//xla/tests:xla_internal_test_main", + "@com_google_googletest//:gtest", "@local_tsl//tsl/lib/core:status_test_util", + "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:test", ], ) @@ -6257,6 +6261,8 @@ cc_library( hdrs = ["gpu_latency_hiding_scheduler.h"], deps = [ ":backend_configs_cc", + ":cublas_cudnn", + "//xla:shape_util", "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_query", "//xla/service:collective_ops_utils", diff --git a/third_party/xla/xla/service/gpu/gpu_compiler.cc b/third_party/xla/xla/service/gpu/gpu_compiler.cc index 601619ee6ee973..440e5ad600cd7c 100644 --- a/third_party/xla/xla/service/gpu/gpu_compiler.cc +++ b/third_party/xla/xla/service/gpu/gpu_compiler.cc @@ -136,6 +136,7 @@ limitations under the License. #include "xla/service/gpu/gpu_executable.h" #include "xla/service/gpu/gpu_float_support.h" #include "xla/service/gpu/gpu_hlo_schedule.h" +#include "xla/service/gpu/gpu_latency_hiding_scheduler.h" #include "xla/service/gpu/gpu_layout_assignment.h" #include "xla/service/gpu/gpu_p2p_pipeliner.h" #include "xla/service/gpu/gpu_reduce_scatter_creator.h" @@ -147,6 +148,7 @@ limitations under the License. 
#include "xla/service/gpu/ir_emission_utils.h" #include "xla/service/gpu/ir_emitter_context.h" #include "xla/service/gpu/ir_emitter_unnested.h" +#include "xla/service/gpu/kernel_reuse_cache.h" #include "xla/service/gpu/matmul_utils.h" #include "xla/service/gpu/metrics.h" #include "xla/service/gpu/model/gpu_cost_model_stats_collection.h" @@ -1922,12 +1924,7 @@ absl::StatusOr GpuCompiler::CompileAndLink( std::string ptx_snippets; std::vector> binaries_to_link; binaries_to_link.reserve(compile_results.size()); - struct NamedBinary { - // The string is the function name or empty just like for llvm_modules. - std::string name; - std::vector binary; - }; - std::vector binaries_to_cache; + std::vector binaries_to_cache; binaries_to_cache.reserve(single_function_module_count); for (const auto& [name, maybe_result] : compile_results) { TF_ASSIGN_OR_RETURN(auto result, maybe_result); @@ -1948,51 +1945,40 @@ absl::StatusOr GpuCompiler::CompileAndLink( return FailedPrecondition("File path can not be resolved: %s", cache_path); } - CompilationCacheProto& cache = + // current_cache contains new kernels from the current compilation and + // kernels to reuse from previous compilations if some were loaded from the + // cache file. + const CompilationCacheProto& current_cache = compile_module_results.kernel_compilation_cache; - if (tsl::Env::Default()->FileExists(resolved_path).ok()) { + const bool cache_file_exists = + tsl::Env::Default()->FileExists(resolved_path).ok(); + if (cache_file_exists) { + // Pick reused binaries from previous compilations needed to link the + // current executable. int loaded_kernel_count = 0; - for (const auto& [name, entry] : cache.entries()) { - if (llvm_module->getFunction(name)) { - VLOG(5) - << "Skipping cached " << name - << " in favor of the just compiled kernel with the same name."; - CHECK(entry.binary().empty()); + for (const auto& [name, entry] : current_cache.entries()) { + if (llvm_module->getFunction(name) != nullptr) { + VLOG(5) << "Using the just compiled kernel for " << name; + TF_RET_CHECK(entry.binary().empty()) + << name + << " is a just compiled kernel and is not expected to have a " + "binary yet."; continue; } const uint8_t* binary = reinterpret_cast(entry.binary().data()); binaries_to_link.push_back( std::vector(binary, binary + entry.binary().size())); - VLOG(5) << "Loaded " << name << ": " << entry.binary().size(); + VLOG(5) << "Using " << name << " from cache: " << entry.binary().size(); ++loaded_kernel_count; } - VLOG(2) << "Loaded " << loaded_kernel_count << " / " - << cache.entries_size() << " cached kernels."; - } else { - auto entries = cache.mutable_entries(); - for (const auto& [name, binary] : binaries_to_cache) { - auto it = entries->find(name); - if (it == entries->end()) { - continue; - } - it->second.set_binary(reinterpret_cast(binary.data()), - binary.size()); - VLOG(5) << "Cached kernels: " << name << ": " << binary.size(); - } - for (auto it = entries->begin(); it != entries->end();) { - if (it->second.binary().empty()) { - it = entries->erase(it); - } else { - ++it; - } - } - if (cache.entries_size() > 0) { - TF_RETURN_IF_ERROR(tsl::WriteStringToFile( - tsl::Env::Default(), resolved_path, cache.SerializeAsString())); - VLOG(2) << "Stored " << cache.entries_size() << " / " - << binaries_to_cache.size(); - } + VLOG(2) << "Using " << loaded_kernel_count << " / " + << current_cache.entries_size() << " cached kernels."; + } + if (!binaries_to_cache.empty()) { + TF_RETURN_IF_ERROR( + UpdateDiskKernelCache(resolved_path, 
/*do_append=*/cache_file_exists, + current_cache, binaries_to_cache)); } } @@ -2007,6 +1993,7 @@ absl::StatusOr GpuCompiler::CompileAndLink( return maybe_backend_result.status(); } VLOG(4) << "Binary size after linking [B]: " << maybe_backend_result->size(); + compile_module_results.kernel_compilation_cache.Clear(); return BackendCompileResult{ptx_snippets, std::move(*maybe_backend_result)}; } diff --git a/third_party/xla/xla/service/gpu/gpu_compiler_test.cc b/third_party/xla/xla/service/gpu/gpu_compiler_test.cc index 643a9ff7e906b0..8c11571a731e62 100644 --- a/third_party/xla/xla/service/gpu/gpu_compiler_test.cc +++ b/third_party/xla/xla/service/gpu/gpu_compiler_test.cc @@ -559,6 +559,7 @@ CHECK-SAME: frontend_attributes={_xla_send_recv_pipeline="0"} class KernelCacheTest : public HloTestBase { public: void SetUp() override { + CHECK(tsl::Env::Default()->LocalTempFilename(&cache_file_name_)); HloModuleConfig config; config.set_debug_options(GetDebugOptionsForTest()); TF_ASSERT_OK_AND_ASSIGN(bool can_use_link_modules, @@ -568,8 +569,8 @@ class KernelCacheTest : public HloTestBase { GTEST_SKIP() << "Caching compiled kernels requires support of linking."; } } + DebugOptions GetDebugOptionsForTest() override { - CHECK(tsl::Env::Default()->LocalTempFilename(&cache_file_name_)); DebugOptions debug_options = HloTestBase::GetDebugOptionsForTest(); debug_options.set_xla_gpu_kernel_cache_file(cache_file_name_); debug_options.set_xla_gpu_enable_llvm_module_compilation_parallelism(true); @@ -583,16 +584,16 @@ class KernelCacheTest : public HloTestBase { return true; } - bool NonEmptyCacheExists() { + int CacheEntryCount() { if (!CacheFileExists()) { - return false; + return 0; } std::string serialized; TF_EXPECT_OK(tsl::ReadFileToString(tsl::Env::Default(), cache_file_name_, &serialized)); CompilationCacheProto proto; EXPECT_TRUE(proto.ParseFromString(std::string(serialized))); - return proto.entries_size() > 0; + return proto.entries_size(); } std::string cache_file_name_; @@ -609,9 +610,10 @@ TEST_F(KernelCacheTest, CacheIsGenerated) { EXPECT_FALSE(CacheFileExists()); EXPECT_TRUE(Run(kHloText, /*run_hlo_passes=*/false)); // First run generates a cache - EXPECT_TRUE(NonEmptyCacheExists()); + EXPECT_EQ(CacheEntryCount(), 1); // Second run - with cache file EXPECT_TRUE(Run(kHloText, /*run_hlo_passes=*/false)); + EXPECT_EQ(CacheEntryCount(), 1); } TEST_F(KernelCacheTest, NoCacheIsGeneratedWithoutCompiledKernels) { @@ -626,10 +628,10 @@ TEST_F(KernelCacheTest, NoCacheIsGeneratedWithoutCompiledKernels) { EXPECT_FALSE(CacheFileExists()); } -TEST_F(KernelCacheTest, UsingCacheFromAnotherModuleDoesNotFail) { +TEST_F(KernelCacheTest, CacheGrowsWithNewKernels) { EXPECT_FALSE(CacheFileExists()); EXPECT_TRUE(Run(kHloText, /*run_hlo_passes=*/false)); - EXPECT_TRUE(NonEmptyCacheExists()); + EXPECT_EQ(CacheEntryCount(), 1); // Second run - with cache file and another HLO EXPECT_TRUE(Run(R"( ENTRY e { @@ -637,6 +639,7 @@ TEST_F(KernelCacheTest, UsingCacheFromAnotherModuleDoesNotFail) { ROOT _ = s8[] multiply(p, p) })", /*run_hlo_passes=*/false)); + EXPECT_EQ(CacheEntryCount(), 2); } class KernelCacheTestSingleThreaded : public KernelCacheTest { @@ -651,8 +654,9 @@ class KernelCacheTestSingleThreaded : public KernelCacheTest { TEST_F(KernelCacheTestSingleThreaded, CacheIsGenerated) { EXPECT_FALSE(CacheFileExists()); EXPECT_TRUE(Run(kHloText, /*run_hlo_passes=*/false)); - EXPECT_TRUE(NonEmptyCacheExists()); + EXPECT_EQ(CacheEntryCount(), 1); EXPECT_TRUE(Run(kHloText, /*run_hlo_passes=*/false)); + 
EXPECT_EQ(CacheEntryCount(), 1); } class NoKernelCacheTest : public KernelCacheTest { @@ -666,7 +670,7 @@ class NoKernelCacheTest : public KernelCacheTest { TEST_F(NoKernelCacheTest, NoCacheWithoutCompilationParallelism) { EXPECT_TRUE(Run(kHloText, /*run_hlo_passes=*/false)); - EXPECT_FALSE(NonEmptyCacheExists()); + EXPECT_FALSE(CacheFileExists()); } } // namespace diff --git a/third_party/xla/xla/service/gpu/gpu_hlo_schedule.cc b/third_party/xla/xla/service/gpu/gpu_hlo_schedule.cc index 0b1c9629c5c7f7..07449034b1705f 100644 --- a/third_party/xla/xla/service/gpu/gpu_hlo_schedule.cc +++ b/third_party/xla/xla/service/gpu/gpu_hlo_schedule.cc @@ -21,7 +21,6 @@ limitations under the License. #include #include #include -#include #include #include @@ -37,7 +36,6 @@ limitations under the License. #include "absl/strings/str_join.h" #include "absl/strings/str_split.h" #include "absl/strings/string_view.h" -#include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_input_output_alias_config.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -48,7 +46,6 @@ limitations under the License. #include "xla/service/buffer_value.h" #include "xla/service/collective_ops_utils.h" #include "xla/service/gpu/backend_configs.pb.h" -#include "xla/service/gpu/cublas_cudnn.h" #include "xla/service/gpu/gpu_latency_hiding_scheduler.h" #include "xla/service/gpu/gpu_schedule_postprocessing.h" #include "xla/service/gpu/model/analytical_latency_estimator.h" @@ -59,7 +56,6 @@ limitations under the License. #include "xla/service/profile_guided_latency_estimator.h" #include "xla/shape.h" #include "xla/shape_util.h" -#include "xla/statusor.h" #include "xla/stream_executor/device_description.h" #include "xla/util.h" #include "tsl/platform/env.h" @@ -73,19 +69,6 @@ namespace gpu { namespace { -// A threshold for which we consider AR to be costly perf-wise. -static constexpr int64_t kCostlyAllReduceThreshold = 30 * 1024 * 1024; - -// Multiplier which we apply to expand the base cost for the costly AR. -static constexpr int64_t kCostlyAllReduceMultiplier = 4; - -bool IsNopInstruction(const HloInstruction& hlo) { - HloOpcode op = hlo.opcode(); - return op == HloOpcode::kGetTupleElement || op == HloOpcode::kBitcast || - op == HloOpcode::kConstant || op == HloOpcode::kParameter || - hlo.IsEffectiveBitcast(); -} - bool ShouldScheduleAsEarlyAsPossible(const HloInstruction& instr) { switch (instr.opcode()) { case HloOpcode::kAllReduceStart: @@ -280,70 +263,6 @@ SchedulerConfig GetSchedulerConfig(int64_t memory_limit) { return config; } -class GpuLatencyEstimator : public ApproximateLatencyEstimator { - public: - explicit GpuLatencyEstimator( - int64_t pointer_size, - GetCanonicalAsyncOpFunc func = GpuGetCanonicalAsyncOp) - : ApproximateLatencyEstimator(func), pointer_size_(pointer_size) {} - TimeCost NodeCost(const HloInstruction* instr) const override { - if (IsNopInstruction(*instr)) { - return 0.0; - } - // Consider cublas/cuddn/softmax custom calls as medium cost. Since the - // latency between async-start and async-done is 5000 and cost of each - // custom call is 1000, the LHS will try to schedule approximately 5 of - // these in between each start/end pair. - if (instr->opcode() == HloOpcode::kCustomCall) { - if (IsCublasGemm(*instr) || IsCustomCallToDnnConvolution(*instr)) { - return ApproximateLatencyEstimator::kMediumCost; - } - // consider other custom calls as medium cost for now. Keeping the case - // explicitly separate for further tuning. 
- return ApproximateLatencyEstimator::kMediumCost; - } - return ApproximateLatencyEstimator::NodeCost(instr); - } - - LatencyEstimator::TimeCost GetLatencyBetween( - const HloGraphNode& from, const HloGraphNode& target) const override { - if (IsAsyncPair(from, target)) { - if (from.GetInstr().opcode() == HloOpcode::kRecv) { - // Recv -> RecvDone has a low latency. - return ApproximateLatencyEstimator::kLowLatency; - } else if (from.GetInstr().opcode() == HloOpcode::kSend) { - // Send -> SendDone has a very high latency. - return ApproximateLatencyEstimator::kHighLatency * 10; - } - - bool enable_approx_collectives = - from.GetInstr() - .GetModule() - ->config() - .debug_options() - .xla_gpu_enable_approx_costly_collectives(); - bool is_all_reduce = - from.GetInstr().opcode() == HloOpcode::kAllReduceStart; - bool collective_size_exceeds_threshold = - GetSizeOfShape(from.GetInstr().shape(), pointer_size_) > - kCostlyAllReduceThreshold; - if (enable_approx_collectives && is_all_reduce && - collective_size_exceeds_threshold) { - return ApproximateLatencyEstimator::kHighLatency * - kCostlyAllReduceMultiplier; - } - - return ApproximateLatencyEstimator::kHighLatency; - } - // Every other instruction we consider synchronous, which means the - // latency between each of them is always one unit. - return ApproximateLatencyEstimator::kLowLatency; - } - - private: - int64_t pointer_size_; -}; - tensorflow::profiler::ProfiledInstructionsProto GetProfileForFingerprint( tensorflow::profiler::ProfiledInstructionsProto& profile, const std::string& fingerprint) { @@ -533,16 +452,6 @@ absl::Status IsProfileApplicable( return absl::OkStatus(); } -int64_t GetSizeOfShape(const Shape& shape, int pointer_size) { - int64_t size = ShapeUtil::ByteSizeOf(shape, pointer_size); - if (shape.IsTuple() || shape.is_static()) { - return size; - } - // Each dynamic dimension size is represented as a S32. - int64_t metadata_size = sizeof(int32_t) * shape.dimensions_size(); - return size + metadata_size; -} - static int64_t GetSchedulerMemoryLimit( const HloModule* module, const se::DeviceDescription& gpu_device_info, int pointer_size); diff --git a/third_party/xla/xla/service/gpu/gpu_hlo_schedule.h b/third_party/xla/xla/service/gpu/gpu_hlo_schedule.h index d20a056494666d..7263eff68eaa13 100644 --- a/third_party/xla/xla/service/gpu/gpu_hlo_schedule.h +++ b/third_party/xla/xla/service/gpu/gpu_hlo_schedule.h @@ -37,8 +37,6 @@ absl::Status IsProfileApplicable( const HloModule* module, const tensorflow::profiler::ProfiledInstructionsProto& profile); -int64_t GetSizeOfShape(const Shape& shape, int pointer_size); - struct ScheduleMetadata { int64_t scheduler_mem_limit; }; diff --git a/third_party/xla/xla/service/gpu/gpu_latency_hiding_scheduler.cc b/third_party/xla/xla/service/gpu/gpu_latency_hiding_scheduler.cc index 33315bfacfbac0..1d5a3b99d49e43 100644 --- a/third_party/xla/xla/service/gpu/gpu_latency_hiding_scheduler.cc +++ b/third_party/xla/xla/service/gpu/gpu_latency_hiding_scheduler.cc @@ -27,13 +27,30 @@ limitations under the License. #include "xla/hlo/utils/hlo_query.h" #include "xla/service/collective_ops_utils.h" #include "xla/service/gpu/backend_configs.pb.h" +#include "xla/service/gpu/cublas_cudnn.h" #include "xla/service/latency_hiding_scheduler.h" +#include "xla/shape.h" +#include "xla/shape_util.h" namespace xla { namespace gpu { namespace { +// A threshold for which we consider AR to be costly perf-wise. 
+static constexpr int64_t kCostlyAllReduceThreshold = 30 * 1024 * 1024; + +// Multiplier which we apply to expand the base cost for the costly AR. +static constexpr int64_t kCostlyAllReduceMultiplier = 4; + +// Classifies `hlo` instruction as noop or not. +bool IsNopInstruction(const HloInstruction& hlo) { + HloOpcode op = hlo.opcode(); + return op == HloOpcode::kGetTupleElement || op == HloOpcode::kBitcast || + op == HloOpcode::kConstant || op == HloOpcode::kParameter || + hlo.IsEffectiveBitcast(); +} + bool IsAsyncComputeOp(const HloInstruction& hlo) { return (hlo.opcode() == HloOpcode::kAsyncStart || hlo.opcode() == HloOpcode::kAsyncDone) && @@ -74,6 +91,16 @@ std::pair GetP2PResourceAndUsage( } // namespace +int64_t GetSizeOfShape(const Shape& shape, int pointer_size) { + int64_t size = ShapeUtil::ByteSizeOf(shape, pointer_size); + if (shape.IsTuple() || shape.is_static()) { + return size; + } + // Each dynamic dimension size is represented as a S32. + int64_t metadata_size = sizeof(int32_t) * shape.dimensions_size(); + return size + metadata_size; +} + CanonicalAsyncOp GpuGetCanonicalAsyncOp(const HloInstruction& hlo) { switch (hlo.opcode()) { case HloOpcode::kSend: @@ -89,6 +116,7 @@ CanonicalAsyncOp GpuGetCanonicalAsyncOp(const HloInstruction& hlo) { } } +// GpuAsyncTrackerBase implementations begin GpuAsyncTrackerBase::GpuAsyncTrackerBase(const SchedulerConfig& config, GetCanonicalAsyncOpFunc func) : AsyncTracker(config, func) {} @@ -132,7 +160,9 @@ void GpuAsyncTrackerBase::PostProcessScheduleGraph( } } } +// GpuAsyncTrackerBase implementations end +// GpuAsyncTracker implementations begin GpuAsyncTracker::GpuAsyncTracker(const SchedulerConfig& config) : GpuAsyncTrackerBase(config) {} @@ -278,5 +308,67 @@ int64_t GpuAsyncTracker::GetNumResourcesPerInstruction( return num_resources - (found ? 1 : 0); } +// GpuAsyncTracker implementations end + +// GpuLatencyEstimator implementations begin +GpuLatencyEstimator::GpuLatencyEstimator(int64_t pointer_size, + GetCanonicalAsyncOpFunc func) + : ApproximateLatencyEstimator(func), pointer_size_(pointer_size) {} + +ApproximateLatencyEstimator::TimeCost GpuLatencyEstimator::NodeCost( + const HloInstruction* instr) const { + if (IsNopInstruction(*instr)) { + return 0.0; + } + // Consider cublas/cuddn/softmax custom calls as medium cost. Since the + // latency between async-start and async-done is 5000 and cost of each + // custom call is 1000, the LHS will try to schedule approximately 5 of + // these in between each start/end pair. + if (instr->opcode() == HloOpcode::kCustomCall) { + if (IsCublasGemm(*instr) || IsCustomCallToDnnConvolution(*instr)) { + return ApproximateLatencyEstimator::kMediumCost; + } + // consider other custom calls as medium cost for now. Keeping the case + // explicitly separate for further tuning. + return ApproximateLatencyEstimator::kMediumCost; + } + return ApproximateLatencyEstimator::NodeCost(instr); +} + +ApproximateLatencyEstimator::TimeCost GpuLatencyEstimator::GetLatencyBetween( + const HloGraphNode& from, const HloGraphNode& to) const { + if (IsAsyncPair(from, to)) { + if (from.GetInstr().opcode() == HloOpcode::kRecv) { + // Recv -> RecvDone has a low latency. + return ApproximateLatencyEstimator::kLowLatency; + } else if (from.GetInstr().opcode() == HloOpcode::kSend) { + // Send -> SendDone has a very high latency. 
+ return ApproximateLatencyEstimator::kHighLatency * 10; + } + + bool enable_approx_collectives = + from.GetInstr() + .GetModule() + ->config() + .debug_options() + .xla_gpu_enable_approx_costly_collectives(); + bool is_all_reduce = from.GetInstr().opcode() == HloOpcode::kAllReduceStart; + bool collective_size_exceeds_threshold = + GetSizeOfShape(from.GetInstr().shape(), pointer_size_) > + kCostlyAllReduceThreshold; + if (enable_approx_collectives && is_all_reduce && + collective_size_exceeds_threshold) { + return ApproximateLatencyEstimator::kHighLatency * + kCostlyAllReduceMultiplier; + } + + return ApproximateLatencyEstimator::kHighLatency; + } + // Every other instruction we consider synchronous, which means the + // latency between each of them is always one unit. + return ApproximateLatencyEstimator::kLowLatency; +} +// GpuLatencyEstimator implementations end + } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/gpu_latency_hiding_scheduler.h b/third_party/xla/xla/service/gpu/gpu_latency_hiding_scheduler.h index a7b0e0f7546b99..fae9debc8fc291 100644 --- a/third_party/xla/xla/service/gpu/gpu_latency_hiding_scheduler.h +++ b/third_party/xla/xla/service/gpu/gpu_latency_hiding_scheduler.h @@ -21,6 +21,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/service/latency_hiding_scheduler.h" +#include "xla/shape.h" namespace xla { namespace gpu { @@ -29,6 +30,9 @@ namespace gpu { // E.g. AllReduceStart is broken down into Reduce + AsyncStart. CanonicalAsyncOp GpuGetCanonicalAsyncOp(const HloInstruction& hlo); +// Returns size of the `shape` given the `pointer_size`. +int64_t GetSizeOfShape(const Shape& shape, int pointer_size); + // GPU specific resources for latency hiding scheduler. // // We use two different set of resources to model the scheduling of asynchronous @@ -95,6 +99,25 @@ class GpuAsyncTracker : public GpuAsyncTrackerBase { int64_t resource_type, const HloInstruction& instr) const override; }; +// GPU approximate latency estimator. It is a set of hardcoded heuristics +// for every instruction and async instruction pairs. +class GpuLatencyEstimator : public ApproximateLatencyEstimator { + public: + explicit GpuLatencyEstimator( + int64_t pointer_size, + GetCanonicalAsyncOpFunc func = GpuGetCanonicalAsyncOp); + + // Uses the approximate node for an instruction `instr`. + TimeCost NodeCost(const HloInstruction* instr) const override; + + // Returns a latency estimation between nodes `from` and `to`. 
+ TimeCost GetLatencyBetween(const HloGraphNode& from, + const HloGraphNode& to) const override; + + private: + int64_t pointer_size_; +}; + } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/ir_emitter_triton_cuda.cc b/third_party/xla/xla/service/gpu/ir_emitter_triton_cuda.cc index 217ad889416a4a..c8ffbfacbac30b 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter_triton_cuda.cc +++ b/third_party/xla/xla/service/gpu/ir_emitter_triton_cuda.cc @@ -74,6 +74,7 @@ absl::Status CreateTritonPipeline( pm.addPass(mlir::createTritonNvidiaGPUPlanCTAPass(&out_cluster_info)); pm.addPass(mt::gpu::createTritonGPURemoveLayoutConversions()); pm.addPass(mt::gpu::createTritonGPUOptimizeThreadLocality()); + pm.addPass(createSparseBlockedToMMAPass()); pm.addPass(mt::gpu::createTritonGPUAccelerateMatmul()); pm.addPass(mt::gpu::createTritonGPURemoveLayoutConversions()); pm.addPass( diff --git a/third_party/xla/xla/service/gpu/ir_emitter_triton_rocm.cc b/third_party/xla/xla/service/gpu/ir_emitter_triton_rocm.cc index 79ef0ce4a5d05c..2a6e6c3c805cd8 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter_triton_rocm.cc +++ b/third_party/xla/xla/service/gpu/ir_emitter_triton_rocm.cc @@ -75,6 +75,7 @@ absl::Status CreateTritonPipeline( pm.addPass(mt::gpu::createTritonGPUCoalesce()); pm.addPass(mt::gpu::createTritonGPURemoveLayoutConversions()); pm.addPass(mt::gpu::createTritonGPUOptimizeThreadLocality()); + pm.addPass(createSparseBlockedToMMAPass()); pm.addPass(mt::gpu::createTritonGPUAccelerateMatmul()); pm.addPass(mt::gpu::createTritonGPURemoveLayoutConversions()); // TODO ROCm Check if we want to compare MI100 and greater diff --git a/third_party/xla/xla/service/gpu/kernel_reuse_cache.cc b/third_party/xla/xla/service/gpu/kernel_reuse_cache.cc index 0b29670e6cb43f..0c1d97af215688 100644 --- a/third_party/xla/xla/service/gpu/kernel_reuse_cache.cc +++ b/third_party/xla/xla/service/gpu/kernel_reuse_cache.cc @@ -139,6 +139,42 @@ CompilationCacheProto KernelReuseCache::Export() const { return proto; } +absl::Status UpdateDiskKernelCache( + absl::string_view path, const bool do_append, + const CompilationCacheProto& current_cache, + absl::Span binaries_to_cache) { + CompilationCacheProto disk_cache; + if (do_append) { + std::string serialized; + TF_RETURN_IF_ERROR(tsl::ReadFileToString(tsl::Env::Default(), + std::string(path), &serialized)); + if (!disk_cache.ParseFromString(std::string(serialized))) { + return Internal("Failed to parse serialized CompilationCacheProto."); + } + } + auto entries = disk_cache.mutable_entries(); + int stored_kernel_count = 0; + for (const auto& [name, binary] : binaries_to_cache) { + auto it_current = current_cache.entries().find(name); + TF_RET_CHECK(it_current != current_cache.entries().end()); + auto [it_disk, inserted] = entries->insert({name, it_current->second}); + TF_RET_CHECK(inserted); + TF_RET_CHECK(!binary.empty()); + it_disk->second.set_binary(reinterpret_cast(binary.data()), + binary.size()); + VLOG(5) << "Cached kernel: " << name << ": " << binary.size(); + ++stored_kernel_count; + } + if (stored_kernel_count > 0) { + TF_RETURN_IF_ERROR(tsl::WriteStringToFile(tsl::Env::Default(), + std::string(path), + disk_cache.SerializeAsString())); + VLOG(2) << "Stored " << stored_kernel_count << " / " + << binaries_to_cache.size() << " kernels in the cache file."; + } + return absl::OkStatus(); +} + std::pair, bool> KernelReuseCache::GetWithStatus( const HloComputation* fused_computation, diff --git 
a/third_party/xla/xla/service/gpu/kernel_reuse_cache.h b/third_party/xla/xla/service/gpu/kernel_reuse_cache.h index a66a5fac70dd50..bf165b5a7033f0 100644 --- a/third_party/xla/xla/service/gpu/kernel_reuse_cache.h +++ b/third_party/xla/xla/service/gpu/kernel_reuse_cache.h @@ -45,6 +45,10 @@ class KernelReuseCache { int64_t shmem_bytes = 0; std::string binary; }; + struct NamedBinary { + std::string name; + std::vector binary; + }; absl::Status Load(const CompilationCacheProto& proto); // Exporting skips kernels that were loaded but not used during emission. @@ -88,6 +92,14 @@ class KernelReuseCache { absl::flat_hash_set hits_; }; +// Add kernels to the cache file. Binaries are taken from binaries_to_cache, +// all other kernel properties are taken from current_cache. +// do_append makes an existing file be loaded first. +absl::Status UpdateDiskKernelCache( + absl::string_view path, bool do_append, + const CompilationCacheProto& current_cache, + absl::Span binaries_to_cache); + // Calculates the fingerprint of a (fused_computation, kernel_arguments, // discriminator) tuple. // diff --git a/third_party/xla/xla/service/gpu/kernel_reuse_cache_test.cc b/third_party/xla/xla/service/gpu/kernel_reuse_cache_test.cc index 19b0c0d0d3f3c7..75c9b009f7d93d 100644 --- a/third_party/xla/xla/service/gpu/kernel_reuse_cache_test.cc +++ b/third_party/xla/xla/service/gpu/kernel_reuse_cache_test.cc @@ -14,8 +14,9 @@ limitations under the License. ==============================================================================*/ #include "xla/service/gpu/kernel_reuse_cache.h" +#include #include "tsl/lib/core/status_test_util.h" -#include "tsl/platform/test.h" +#include "tsl/platform/env.h" namespace xla { namespace gpu { @@ -39,6 +40,41 @@ TEST_F(KernelReuseTest, ExportAndLoadWork) { EXPECT_FALSE(cache.IsEmpty()); } +TEST_F(KernelReuseTest, UpdatingDiskKernelCacheWorks) { + std::string cache_file_path; + CHECK(tsl::Env::Default()->LocalTempFilename(&cache_file_path)); + { + const CompilationCacheProto proto = [](std::string kernel_name) { + KernelReuseCache cache; + auto [result, was_cached] = cache.GetWithStatus("fingerprint", [&]() { + return KernelReuseCache::Entry{.kernel_name = kernel_name}; + }); + return cache.Export(); + }("k1"); + TF_EXPECT_OK(UpdateDiskKernelCache(cache_file_path, /*do_append=*/false, + proto, + {{.name = "k1", .binary = {5, 6}}})); + } + { + const CompilationCacheProto proto = [](std::string kernel_name) { + KernelReuseCache cache; + auto [result, was_cached] = cache.GetWithStatus("fingerprint", [&]() { + return KernelReuseCache::Entry{.kernel_name = kernel_name}; + }); + return cache.Export(); + }("k2"); + TF_EXPECT_OK(UpdateDiskKernelCache(cache_file_path, /*do_append=*/true, + proto, + {{.name = "k2", .binary = {7, 8}}})); + } + std::string serialized; + TF_EXPECT_OK( + tsl::ReadFileToString(tsl::Env::Default(), cache_file_path, &serialized)); + CompilationCacheProto proto; + EXPECT_TRUE(proto.ParseFromString(std::string(serialized))); + EXPECT_EQ(proto.entries_size(), 2); +} + } // namespace } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/model/BUILD b/third_party/xla/xla/service/gpu/model/BUILD index 03dcb308b2367f..be7c6d0d196629 100644 --- a/third_party/xla/xla/service/gpu/model/BUILD +++ b/third_party/xla/xla/service/gpu/model/BUILD @@ -691,11 +691,14 @@ xla_cc_test( ":symbolic_tile_analysis", ":tiled_hlo_computation", ":tiled_hlo_instruction", + "//xla:util", "//xla/hlo/ir:hlo", + "//xla/service:instruction_fusion", 
"//xla/service/gpu:hlo_traversal", "//xla/tests:hlo_test_base", "//xla/tests:verified_hlo_module", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", diff --git a/third_party/xla/xla/service/gpu/model/symbolic_tile.cc b/third_party/xla/xla/service/gpu/model/symbolic_tile.cc index 93aec8922eec80..f4c4bff3c7fca6 100644 --- a/third_party/xla/xla/service/gpu/model/symbolic_tile.cc +++ b/third_party/xla/xla/service/gpu/model/symbolic_tile.cc @@ -607,13 +607,19 @@ std::optional MergeConstraintMapIfPresentAndCompatible( } /*static*/ std::optional SymbolicTile::FromIndexingMap( - const IndexingMap& indexing_map) { + IndexingMap indexing_map) { VLOG(1) << "SymbolicTile::FromIndexingMap: " << indexing_map.ToString(); // We do not handle indexing maps with pre-existing constraints for now. + // Let's try to simplify the indexing map, because the constraints my be + // redundant. + // TODO(bchetioui): Consider doing the simplification in the caller, not here. + bool did_simplify = indexing_map.Simplify(); + VLOG(1) << "did_simplify: " << did_simplify; if (indexing_map.GetConstraintsCount() != 0) { VLOG(1) << "Deriving symbolic tile from indexing map with pre-existing " - << "constraints might produce spurious constraints. Bailing out."; + << "constraints might produce spurious constraints. Bailing out. " + << indexing_map.ToString(); return std::nullopt; } diff --git a/third_party/xla/xla/service/gpu/model/symbolic_tile.h b/third_party/xla/xla/service/gpu/model/symbolic_tile.h index ddce4de4699a28..f5c3f680eae5c3 100644 --- a/third_party/xla/xla/service/gpu/model/symbolic_tile.h +++ b/third_party/xla/xla/service/gpu/model/symbolic_tile.h @@ -157,8 +157,7 @@ namespace gpu { // simplified later. class SymbolicTile { public: - static std::optional FromIndexingMap( - const IndexingMap& indexing_map); + static std::optional FromIndexingMap(IndexingMap indexing_map); using ConstraintMap = llvm::DenseMap; diff --git a/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis_test.cc b/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis_test.cc index 6350459c0d750b..a0d12abdd1adfd 100644 --- a/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis_test.cc +++ b/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis_test.cc @@ -25,7 +25,9 @@ limitations under the License. #include #include #include "absl/algorithm/container.h" +#include "absl/log/log.h" #include "absl/status/status.h" +#include "absl/strings/str_join.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "mlir/IR/MLIRContext.h" // from @llvm-project @@ -34,8 +36,10 @@ limitations under the License. 
#include "xla/service/gpu/model/indexing_test_utils.h" #include "xla/service/gpu/model/tiled_hlo_computation.h" #include "xla/service/gpu/model/tiled_hlo_instruction.h" +#include "xla/service/instruction_fusion.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/verified_hlo_module.h" +#include "xla/util.h" #include "tsl/platform/statusor.h" namespace xla { @@ -45,7 +49,9 @@ namespace { using detail::GetGoodTilings; using ::testing::ElementsAreArray; using ::testing::ExplainMatchResult; +using ::testing::IsEmpty; using ::testing::Matcher; +using ::testing::Not; using ::testing::SizeIs; using ::testing::status::IsOkAndHolds; using ::testing::status::StatusIs; @@ -83,6 +89,8 @@ class SymbolicTileAnalysisTest : public HloTestBase { if (std::holds_alternative(analysis_or_error)) { return std::get(std::move(analysis_or_error)); } + VLOG(1) << "Cannot analyze module: " + << std::get(analysis_or_error).Explain(); return std::nullopt; } @@ -609,6 +617,75 @@ ENTRY main { {{48, 1}, {48, 2}, {48, 4}})); } +// Logs the tilings if VLOG level 1 is enabled. +// +// Use these arguments to see the log: +// --test_output=all +// --test_arg=--logtostderr +// --test_arg=--vmodule=symbolic_tile_analysis_test=1 +void LogTilingsIfVlog1(absl::Span tilings) { + if (VLOG_IS_ON(1)) { + LOG(INFO) << "Tilings: {"; + for (const SymbolicTileAnalysis::Tiling& tiling : tilings) { + LOG(INFO) << "{" << absl::StrJoin(tiling, ",") << "},"; + } + LOG(INFO) << "}"; + } +} + +TEST_F(SymbolicTileAnalysisTest, GetGoodTilingsWorksForSoftmaxExample) { + // The example is from + // https://github.com/google/paxml/blob/91893818862645f5e9f23b84f530e611551745f6/paxml/contrib/gpu/scripts_gpu/configs.py#L107-L120. + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( +HloModule m + +region { + param_0 = f32[] parameter(0) + param_1 = f32[] parameter(1) + ROOT maximum = f32[] maximum(param_0, param_1) +} + +region.1 { + param_0 = f32[] parameter(0) + param_1 = f32[] parameter(1) + ROOT add = f32[] add(param_0, param_1) +} + +fused_computation { + param_0 = f32[8192,50304] parameter(0) + bitcast = f32[4,2048,50304] bitcast(param_0) + constant = f32[] constant(-inf) + reduce = f32[8192] reduce(param_0, constant), dimensions={1}, to_apply=region + bitcast.1 = f32[4,2048] bitcast(reduce) + broadcast = f32[4,2048,50304] broadcast(bitcast.1), dimensions={0,1} + subtract = f32[4,2048,50304] subtract(bitcast, broadcast) + exponential = f32[4,2048,50304] exponential(subtract) + constant.1 = f32[] constant(0) + reduce.1 = f32[4,2048] reduce(exponential, constant.1), dimensions={2}, to_apply=region.1 + log = f32[4,2048] log(reduce.1) + broadcast.1 = f32[4,2048,50304] broadcast(log), dimensions={0,1} + ROOT subtract.1 = f32[4,2048,50304] subtract(subtract, broadcast.1) +} + +ENTRY entry_computation { + param_0 = f32[8192,50304] parameter(0) + ROOT fusion = f32[4,2048,50304] fusion(param_0), kind=kCustom, calls=fused_computation, backend_config={"fusion_backend_config":{"kind":"__triton_softmax"}} +} +)")); + + std::optional opt_analysis = + TryAnalyzeModule(module.get()); + ASSERT_TRUE(opt_analysis.has_value()); + const SymbolicTileAnalysis& analysis = opt_analysis.value(); + + TF_ASSERT_OK_AND_ASSIGN( + std::vector good_tilings, + analysis.GetGoodTilings()); + EXPECT_THAT(good_tilings, Not(IsEmpty())); + LogTilingsIfVlog1(good_tilings); +} + } // namespace } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/model/symbolic_tile_test.cc 
index 83471b357e5bd0..8ce8c53184854b 100644
--- a/third_party/xla/xla/service/gpu/model/symbolic_tile_test.cc
+++ b/third_party/xla/xla/service/gpu/model/symbolic_tile_test.cc
@@ -17,6 +17,7 @@ limitations under the License.
 
 #include
 #include
+#include
 #include
 #include
 
@@ -729,7 +730,7 @@ TEST_F(SymbolicTileTest,
       ParseAffineMap("(d0) -> (d0 mod 6, d0 mod 6)", &mlir_context_),
       /*dimensions=*/{DimVar{0, 10}}, /*range_vars=*/{}, /*rt_vars=*/{});
-  EXPECT_THAT(SymbolicTile::FromIndexingMap(indexing_map),
+  EXPECT_THAT(SymbolicTile::FromIndexingMap(std::move(indexing_map)),
               Optional(MatchSymbolicTileString(R"(
       Symbolic tile with
         offset_map: ()[s0] -> (0, 0)
@@ -740,6 +741,28 @@ TEST_F(SymbolicTileTest,
       )")));
 }
 
+TEST_F(SymbolicTileTest,
+       CanPropagateTileWhenPreexistingConstraintsCanBeSimplifiedAway) {
+  // The example is from
+  // https://github.com/google/paxml/blob/91893818862645f5e9f23b84f530e611551745f6/paxml/contrib/gpu/scripts_gpu/configs.py#L107-L120.
+  IndexingMap indexing_map = IndexingMap::FromTensorSizes(
+      ParseAffineMap("(d0, d1, d2)[s0] -> (d0 * 2048 + d1, s0)",
+                     &mlir_context_),
+      {4, 2048, 50304}, {50304});
+  // This constraint is redundant, because it can be derived from the domains of
+  // the dimension variables.
+  indexing_map.AddConstraint(ParseAffineExpr("d0 * 2048 + d1", &mlir_context_),
+                             Interval{0, 8191});
+
+  EXPECT_THAT(SymbolicTile::FromIndexingMap(indexing_map),
+              Optional(MatchSymbolicTileString(R"(
+      Symbolic tile with
+        offset_map: ()[s0, s1, s2] -> (0, 0)
+        size_map: ()[s0, s1, s2] -> (s0 * s1, 50304)
+        stride_map: ()[s0, s1, s2] -> (((-s1 + 2049) floordiv 2048) * ((-((-s0 + 5) floordiv 4) + 1) * 2048) + -((-s1 + 2049) floordiv 2048) + 1, 1)
+      )")));
+}
+
 }  // namespace
 }  // namespace gpu
 }  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/nvptx_compiler_test.cc b/third_party/xla/xla/service/gpu/nvptx_compiler_test.cc
index 991e209df057a1..642a0cc9eca438 100644
--- a/third_party/xla/xla/service/gpu/nvptx_compiler_test.cc
+++ b/third_party/xla/xla/service/gpu/nvptx_compiler_test.cc
@@ -29,6 +29,7 @@ limitations under the License.
#include "xla/service/buffer_value.h" #include "xla/service/gpu/gpu_constants.h" #include "xla/service/gpu/gpu_hlo_schedule.h" +#include "xla/service/gpu/gpu_latency_hiding_scheduler.h" #include "xla/service/hlo_ordering.h" #include "xla/service/logical_buffer.h" #include "xla/stream_executor/device_description.h" diff --git a/third_party/xla/xla/service/gpu/tests/sparse_ttg_accelerate_matmul.mlir b/third_party/xla/xla/service/gpu/tests/sparse_ttg_accelerate_matmul.mlir index f1b0f88932a1f9..65bcbd87ddf130 100644 --- a/third_party/xla/xla/service/gpu/tests/sparse_ttg_accelerate_matmul.mlir +++ b/third_party/xla/xla/service/gpu/tests/sparse_ttg_accelerate_matmul.mlir @@ -1,4 +1,4 @@ -// RUN: triton-opt %s -split-input-file -tritongpu-accelerate-matmul | FileCheck %s +// RUN: sparse-opt %s -split-input-file -sparse-blocked-to-mma | FileCheck %s #blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> // CHECK: #[[MMA:.+]] = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> diff --git a/third_party/xla/xla/service/gpu/triton_sparse_extensions.cc b/third_party/xla/xla/service/gpu/triton_sparse_extensions.cc index d7495d07a793a4..8b05928c17e677 100644 --- a/third_party/xla/xla/service/gpu/triton_sparse_extensions.cc +++ b/third_party/xla/xla/service/gpu/triton_sparse_extensions.cc @@ -15,6 +15,8 @@ limitations under the License. #include "xla/service/gpu/triton_sparse_extensions.h" +#include +#include #include #include #include @@ -32,14 +34,27 @@ limitations under the License. #include "mlir/Support/LogicalResult.h" #include "mlir/Support/TypeID.h" #include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/Support/CommandLine.h" -#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" #include "triton/Dialect/TritonGPU/IR/Attributes.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Tools/Sys/GetEnv.hpp" using namespace mlir; // NOLINT(build/namespaces) +// The functions below are defined in AccelerateMatmul.cpp. 
+namespace mlir::triton::gpu {
+SmallVector getWarpsPerTile(
+    Operation *dotOp, ArrayRef shape, int version, int numWarps,
+    const SmallVector &instrShape);
+int computeOrigBitWidth(Value x);
+Value getSharedMemMMAOperand(Value v, mlir::PatternRewriter &rewriter,
+                             int opIdx, bool allowTranspose);
+}  // namespace mlir::triton::gpu
+
 namespace {
 
 struct TritonSparseDotPattern
@@ -175,6 +190,126 @@ class AddSparseDotEncodingPass
       llvm::cl::init(1)};
 };
 
+class SparseBlockedToMMA : public RewritePattern {
+  using ConvertLayoutOp = triton::gpu::ConvertLayoutOp;
+  using SparseDotOp = triton::gpu::SparseDotOp;
+  using SparseDotMetaEncodingAttr = triton::gpu::SparseDotMetaEncodingAttr;
+  using NvidiaMmaEncodingAttr = triton::gpu::NvidiaMmaEncodingAttr;
+
+ public:
+  SparseBlockedToMMA(MLIRContext *context, int compute_capability)
+      : RewritePattern(SparseDotOp::getOperationName(), 2, context),
+        compute_capability_(compute_capability) {}
+
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override {
+    auto dotOp = cast(op);
+    auto ctx = op->getContext();
+    Value a = dotOp.getA();
+    Value b = dotOp.getB();
+
+    // Check data-types and SM compatibility
+    RankedTensorType oldRetType = dotOp.getType();
+    if (!oldRetType.getEncoding() ||
+        isa(oldRetType.getEncoding()))
+      return failure();
+
+    assert(compute_capability_ >= 80 &&
+           "SparseDot is supported on Ampere and higher");
+    bool allowV3 = !triton::tools::getBoolEnv("DISABLE_MMA_V3");
+    int versionMajor = compute_capability_ >= 90 && allowV3 ? 3 : 2;
+
+    // get MMA encoding for the given number of warps
+    auto retShapePerCTA = triton::gpu::getShapePerCTA(oldRetType);
+    auto mod = op->getParentOfType();
+    int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
+    auto CTALayout = triton::gpu::getCTALayout(oldRetType.getEncoding());
+
+    auto instrShape =
+        mmaVersionToInstrShape(versionMajor, retShapePerCTA,
+                               cast(a.getType()), numWarps);
+    auto warpsPerTile = getWarpsPerTile(dotOp, retShapePerCTA, versionMajor,
+                                        numWarps, instrShape);
+    NvidiaMmaEncodingAttr mmaEnc =
+        NvidiaMmaEncodingAttr::get(ctx, versionMajor, /*versionMinor=*/0,
+                                   warpsPerTile, CTALayout, instrShape);
+    auto newRetType = RankedTensorType::get(
+        oldRetType.getShape(), oldRetType.getElementType(), mmaEnc);
+
+    // convert accumulator
+    auto oldAcc = dotOp.getOperand(2);
+    auto newAcc = rewriter.create(
+        oldAcc.getLoc(), newRetType, oldAcc);
+
+    if (versionMajor == 2) {
+      int minBitwidth = std::min(triton::gpu::computeOrigBitWidth(a),
+                                 triton::gpu::computeOrigBitWidth(b));
+      int kWidth = 32 / minBitwidth;
+
+      // convert A operand
+      auto oldAType = cast(a.getType());
+      auto newAEncoding = DotOperandEncodingAttr::get(ctx, 0, mmaEnc, kWidth);
+      auto newAType = RankedTensorType::get(
+          oldAType.getShape(), oldAType.getElementType(), newAEncoding);
+      a = rewriter.create(a.getLoc(), newAType, a);
+
+      // convert B operand
+      auto oldBType = cast(b.getType());
+      auto newBEncoding = DotOperandEncodingAttr::get(ctx, 1, mmaEnc, kWidth);
+      auto newBType = RankedTensorType::get(
+          oldBType.getShape(), oldBType.getElementType(), newBEncoding);
+      b = rewriter.create(b.getLoc(), newBType, b);
+    } else {
+      auto eltType = dotOp.getA().getType().getElementType();
+      // In MMAV3 transpose is only supported for f16 and bf16.
+      bool allowTranspose = eltType.isF16() || eltType.isBF16();
+      a = triton::gpu::getSharedMemMMAOperand(a, rewriter, 0, allowTranspose);
+      b = triton::gpu::getSharedMemMMAOperand(b, rewriter, 1, allowTranspose);
+    }
+
+    // convert metadata
+    Value meta = dotOp.getAMeta();
+    auto oldMetaType = cast(meta.getType());
+    auto newMetaType = RankedTensorType::get(
+        oldMetaType.getShape(), oldMetaType.getElementType(),
+        SparseDotMetaEncodingAttr::get(ctx, mmaEnc));
+    meta = rewriter.create(meta.getLoc(), newMetaType, meta);
+
+    // convert dot instruction
+    auto newDot = rewriter.create(dotOp.getLoc(), newRetType, a, b,
+                                  newAcc, meta);
+
+    rewriter.replaceOpWithNewOp(op, oldRetType,
+                                newDot.getResult());
+    return success();
+  }
+
+ private:
+  int compute_capability_;
+};
+
+class SparseBlockedToMMAPass
+    : public PassWrapper> {
+ public:
+  SparseBlockedToMMAPass() = default;
+
+  StringRef getArgument() const override { return "sparse-blocked-to-mma"; }
+
+  void runOnOperation() override {
+    MLIRContext *context = &getContext();
+    ModuleOp module = getOperation();
+    auto compute_capability = getNVIDIAComputeCapability(module);
+    auto pattern =
+        std::make_unique(context, compute_capability);
+    RewritePatternSet patterns(context, std::move(pattern));
+    if (failed(applyPatternsAndFoldGreedily(module, std::move(patterns)))) {
+      signalPassFailure();
+    }
+  }
+
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SparseBlockedToMMAPass)
+};
+
 }  // namespace
 
 std::unique_ptr xla::gpu::createAddSparseDotEncodingPass(
@@ -183,6 +318,11 @@ std::unique_ptr xla::gpu::createAddSparseDotEncodingPass(
       num_ctas);
 }
 
+std::unique_ptr xla::gpu::createSparseBlockedToMMAPass() {
+  return std::make_unique();
+}
+
 void xla::gpu::registerSparsePasses() {
   registerPass([] { return std::make_unique(); });
+  registerPass([] { return std::make_unique(); });
 }
diff --git a/third_party/xla/xla/service/gpu/triton_sparse_extensions.h b/third_party/xla/xla/service/gpu/triton_sparse_extensions.h
index 6826a6b291d361..d3d6989e006c83 100644
--- a/third_party/xla/xla/service/gpu/triton_sparse_extensions.h
+++ b/third_party/xla/xla/service/gpu/triton_sparse_extensions.h
@@ -27,6 +27,7 @@ namespace xla::gpu {
 
 std::unique_ptr createAddSparseDotEncodingPass(
     int32_t num_warps, int32_t threads_per_warp, int32_t num_ctas);
+std::unique_ptr createSparseBlockedToMMAPass();
 
 void registerSparsePasses();
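
Reviewer note, not part of the patch: a minimal sketch of how the new factory could be
wired into an mlir::PassManager. Only createSparseBlockedToMMAPass() comes from this
patch; the wrapper function name below is hypothetical, and the pass derives the compute
capability from the module itself, so no arguments are needed.

  #include "mlir/IR/BuiltinOps.h"
  #include "mlir/Pass/PassManager.h"
  #include "mlir/Support/LogicalResult.h"
  #include "xla/service/gpu/triton_sparse_extensions.h"

  // Hypothetical helper: runs only the sparse blocked-to-MMA rewrite on `module`.
  mlir::LogicalResult RunSparseBlockedToMMA(mlir::ModuleOp module) {
    mlir::PassManager pm(module.getContext());
    // Registration via registerSparsePasses() is only needed for textual
    // pipelines (e.g. `-sparse-blocked-to-mma` in the .mlir test above);
    // adding the pass programmatically works without it.
    pm.addPass(xla::gpu::createSparseBlockedToMMAPass());
    return pm.run(module);
  }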