Move BlockedSparseToMMA pattern from Triton to XLA.
FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#13781 from openxla:kernel_cache_grow f421890b6548a9a4dfcc5350e791bc8860615dbc
PiperOrigin-RevId: 642959490
chsigg authored and tensorflower-gardener committed Jun 18, 2024
1 parent 44cb866 commit 015a5d6
Showing 24 changed files with 607 additions and 454 deletions.
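
The changes to sparse_dot.patch below rewrite the warp-tiling helpers (warpsPerTileV2, warpsPerTileV3, getWarpsPerTile) to take a plain mlir::Operation * instead of triton::DotOp or a DotType template parameter, so one code path serves both the dense DotOp and the patch's SparseDotOp. A minimal sketch of that dispatch idea, assuming the usual MLIR headers (an illustration, not a copy of the patch):

#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Operation.h"

// Sketch: detect "another dot of the same kind" by comparing operation names,
// as the updated patch does, instead of isa<DotOp>/cast<DotOp>.
static bool isChainedDotOfSameKind(mlir::Operation *rootDot, mlir::Operation *op) {
  return op != rootDot && op->getName() == rootDot->getName();
}

// Sketch: read the result type generically through getResult(0) rather than a
// typed DotOp accessor, so DotOp and SparseDotOp are handled identically.
static mlir::RankedTensorType getDotResultType(mlir::Operation *op) {
  return llvm::cast<mlir::RankedTensorType>(op->getResult(0).getType());
}

Working on Operation * presumably keeps the upstream pass code independent of the SparseDotOp C++ type, which only exists in the XLA extensions.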
191 changes: 45 additions & 146 deletions third_party/triton/xla_extensions/sparse_dot.patch
@@ -181,187 +181,86 @@ diff --git a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp b/lib/Dialect
index c47558fa6..35f0cca95 100644
--- a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
@@ -37,7 +37,8 @@ static int getMMAVersionSafe(int computeCapability, DotOp op) {
@@ -37,8 +37,8 @@ static int getMMAVersionSafe(int computeCapability, DotOp op) {
return 0;
}

-SmallVector<unsigned> warpsPerTileV2(DotOp dotOp, const ArrayRef<int64_t> shape,
+template <typename DotType>
+SmallVector<unsigned> warpsPerTileV2(DotType dotOp, const ArrayRef<int64_t> shape,
int numWarps) {
- int numWarps) {
+SmallVector<unsigned>
+warpsPerTileV2(Operation *dotOp, const ArrayRef<int64_t> shape, int numWarps) {
auto rank = shape.size();
// Early exit for batched matmul
@@ -51,8 +52,8 @@ SmallVector<unsigned> warpsPerTileV2(DotOp dotOp, const ArrayRef<int64_t> shape,
if (rank == 3)
@@ -51,9 +51,8 @@ SmallVector<unsigned> warpsPerTileV2(DotOp dotOp, const ArrayRef<int64_t> shape,
auto slices = multiRootGetSlice(dotOp, {filter}, {filter});
bool hasChainedDot = false;
for (Operation *op : slices) {
- if (isa<DotOp>(op) && (op != dotOp)) {
- auto chainedDot = cast<DotOp>(op);
+ if (isa<DotType>(op) && (op != dotOp)) {
+ auto chainedDot = cast<DotType>(op);
auto resTy = chainedDot.getResult().getType();
- auto resTy = chainedDot.getResult().getType();
+ if (dotOp->getName() == op->getName() && op != dotOp) {
+ auto resTy = cast<RankedTensorType>(op->getResult(0).getType());
if (resTy.getRank() != rank) {
continue;
@@ -96,12 +97,13 @@ SmallVector<unsigned> warpsPerTileV2(DotOp dotOp, const ArrayRef<int64_t> shape,
return ret;
}
@@ -97,12 +96,13 @@ SmallVector<unsigned> warpsPerTileV2(DotOp dotOp, const ArrayRef<int64_t> shape,
}

-SmallVector<unsigned, 2>
SmallVector<unsigned, 2>
-warpsPerTileV3(DotOp dotOp, const ArrayRef<int64_t> shape, int numWarps,
- const SmallVector<unsigned, 3> &instrShape) {
+template <typename DotType>
+SmallVector<unsigned, 2> warpsPerTileV3(
+ DotType dotOp, const ArrayRef<int64_t> shape, int numWarps,
+ const SmallVector<unsigned, 3> &instrShape) {
+warpsPerTileV3(Operation *dotOp, const ArrayRef<int64_t> shape, int numWarps,
const SmallVector<unsigned, 3> &instrShape) {
SetVector<Operation *> slices;
mlir::getForwardSlice(dotOp.getResult(), &slices);
- mlir::getForwardSlice(dotOp.getResult(), &slices);
- if (llvm::find_if(slices, [](Operation *op) { return isa<DotOp>(op); }) !=
+ if (llvm::find_if(slices, [](Operation *op) { return isa<DotType>(op); }) !=
slices.end())
- slices.end())
+ mlir::getForwardSlice(dotOp->getResult(0), &slices);
+ if (llvm::find_if(slices, [&](Operation *op) {
+ return dotOp->getName() == op->getName();
+ }) != slices.end())
return {(unsigned)numWarps, 1};

@@ -191,6 +193,7 @@ class BlockedToMMA : public mlir::OpRewritePattern<DotOp> {
// For MMAv3, the smallest indivisible unit of warp shape is (4, 1).
@@ -167,6 +167,7 @@ class BlockedToMMA : public mlir::OpRewritePattern<DotOp> {
mlir::TypeID::get<arith::ArithDialect>());
}

+public:
// Finds the first different bitwidth in the chain of shape-preserving
// unary ops that x depends on.
// There are two primary scenarios:
@@ -224,14 +227,14 @@ class BlockedToMMA : public mlir::OpRewritePattern<DotOp> {
return origBitWidth;
@@ -206,7 +207,7 @@ public:
}

-public:
BlockedToMMA(mlir::MLIRContext *context, int computeCapability)
: OpRewritePattern<DotOp>(context), computeCapability(computeCapability) {
}

- static SmallVector<unsigned, 3>
static SmallVector<unsigned, 3>
- getWarpsPerTile(DotOp dotOp, const ArrayRef<int64_t> shape, int version,
- int numWarps, const SmallVector<unsigned, 3> &instrShape) {
+ template <typename DotType>
+ static SmallVector<unsigned, 3> getWarpsPerTile(
+ DotType dotOp, const ArrayRef<int64_t> shape, int version, int numWarps,
+ const SmallVector<unsigned, 3> &instrShape) {
+ getWarpsPerTile(Operation *dotOp, const ArrayRef<int64_t> shape, int version,
int numWarps, const SmallVector<unsigned, 3> &instrShape) {
switch (version) {
case 2:
return warpsPerTileV2(dotOp, shape, numWarps);
@@ -359,6 +362,106 @@ public:
return success();
@@ -405,6 +406,21 @@ public:
}
};
+
+class SparseBlockedToMMA : public mlir::RewritePattern {
+ public:
+ using SparseDotOp = mlir::triton::gpu::SparseDotOp;
+ using SparseDotMetaEncodingAttr =
+ mlir::triton::gpu::SparseDotMetaEncodingAttr;
+
+ SparseBlockedToMMA(mlir::MLIRContext *context, int computeCapability)
+ : mlir::RewritePattern(SparseDotOp::getOperationName(), 2, context),
+ computeCapability(computeCapability) {}
+
+ mlir::LogicalResult matchAndRewrite(
+ mlir::Operation *op, mlir::PatternRewriter &rewriter) const override {
+ auto dotOp = cast<SparseDotOp>(op);
+ auto ctx = op->getContext();
+ Value a = dotOp.getA();
+ Value b = dotOp.getB();
+
+ // Check data-types and SM compatibility
+ RankedTensorType oldRetType = dotOp.getType();
+ if (!oldRetType.getEncoding() ||
+ isa<NvidiaMmaEncodingAttr>(oldRetType.getEncoding()))
+ return failure();
+
+ assert(computeCapability >= 80 &&
+ "SparseDot is supported on Ampere and higher");
+ bool allowV3 = !triton::tools::getBoolEnv("DISABLE_MMA_V3");
+ int versionMajor = computeCapability >= 90 && allowV3 ? 3 : 2;
+
+ // get MMA encoding for the given number of warps
+ auto retShapePerCTA = getShapePerCTA(oldRetType);
+ auto mod = op->getParentOfType<mlir::ModuleOp>();
+ int numWarps = TritonGPUDialect::getNumWarps(mod);
+ auto CTALayout = getCTALayout(oldRetType.getEncoding());
+
+ auto instrShape =
+ mmaVersionToInstrShape(versionMajor, retShapePerCTA,
+ cast<RankedTensorType>(a.getType()), numWarps);
+ auto warpsPerTile = BlockedToMMA::getWarpsPerTile(
+ dotOp, retShapePerCTA, versionMajor, numWarps, instrShape);
+ NvidiaMmaEncodingAttr mmaEnc =
+ NvidiaMmaEncodingAttr::get(ctx, versionMajor, /*versionMinor=*/0,
+ warpsPerTile, CTALayout, instrShape);
+ auto newRetType = RankedTensorType::get(
+ oldRetType.getShape(), oldRetType.getElementType(), mmaEnc);
+
+ // convert accumulator
+ auto oldAcc = dotOp.getOperand(2);
+ auto newAcc = rewriter.create<ConvertLayoutOp>(oldAcc.getLoc(),
+ newRetType, oldAcc);
+
+ if (versionMajor == 2) {
+ int minBitwidth = std::min(BlockedToMMA::computeOrigBitWidth(a),
+ BlockedToMMA::computeOrigBitWidth(b));
+ int kWidth = 32 / minBitwidth;
+
+ // convert A operand
+ auto oldAType = cast<RankedTensorType>(a.getType());
+ auto newAEncoding =
+ DotOperandEncodingAttr::get(ctx, 0, mmaEnc, kWidth);
+ auto newAType = RankedTensorType::get(
+ oldAType.getShape(), oldAType.getElementType(), newAEncoding);
+ a = rewriter.create<ConvertLayoutOp>(a.getLoc(), newAType, a);
+
+ // convert B operand
+ auto oldBType = cast<RankedTensorType>(b.getType());
+ auto newBEncoding =
+ DotOperandEncodingAttr::get(ctx, 1, mmaEnc, kWidth);
+ auto newBType = RankedTensorType::get(
+ oldBType.getShape(), oldBType.getElementType(), newBEncoding);
+ b = rewriter.create<ConvertLayoutOp>(b.getLoc(), newBType, b);
+ } else {
+ auto eltType = dotOp.getA().getType().getElementType();
+ // In MMAV3 transpose is only supported for f16 and bf16.
+ bool allowTranspose = eltType.isF16() || eltType.isBF16();
+ a = getSharedMemoryMMAOperand(a, rewriter, 0, allowTranspose);
+ b = getSharedMemoryMMAOperand(b, rewriter, 1, allowTranspose);
+ }
+
+ // convert metadata
+ Value meta = dotOp.getAMeta();
+ auto oldMetaType = cast<RankedTensorType>(meta.getType());
+ auto newMetaType = RankedTensorType::get(
+ oldMetaType.getShape(), oldMetaType.getElementType(),
+ SparseDotMetaEncodingAttr::get(ctx, mmaEnc));
+ meta =
+ rewriter.create<ConvertLayoutOp>(meta.getLoc(), newMetaType, meta);
+
+ // convert dot instruction
+ auto newDot = rewriter.create<SparseDotOp>(dotOp.getLoc(), newRetType, a, b,
+ newAcc, meta);
+
+ rewriter.replaceOpWithNewOp<ConvertLayoutOp>(op, oldRetType,
+ newDot.getResult());
+ return success();
+ }
+
+ private:
+ int computeCapability;
+};
} // namespace

static Value promoteOperand(OpBuilder &builder, Location loc, Value operand,
@@ -420,6 +523,7 @@ public:

mlir::RewritePatternSet patterns(context);
patterns.add<BlockedToMMA>(context, computeCapability);
+ patterns.add<SparseBlockedToMMA>(context, computeCapability);
if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed()) {
signalPassFailure();
}
+// Expose helper functions from BlockedToMMA to be reused for sparse matmul.
+SmallVector<unsigned, 3>
+getWarpsPerTile(Operation *dotOp, ArrayRef<int64_t> shape, int version,
+ int numWarps, const SmallVector<unsigned, 3> &instrShape) {
+ return BlockedToMMA::getWarpsPerTile(dotOp, shape, version, numWarps,
+ instrShape);
+}
+int computeOrigBitWidth(Value x) {
+ return BlockedToMMA::computeOrigBitWidth(x);
+}
+Value getSharedMemMMAOperand(Value v, mlir::PatternRewriter &rewriter,
+ int opIdx, bool allowTranspose) {
+ return getSharedMemoryMMAOperand(v, rewriter, opIdx, allowTranspose);
+}
+
} // namespace gpu
} // namespace triton
} // namespace mlir
diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp
index 5cb992714..cdafdffce 100644
--- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp
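
Per the commit title, the SparseBlockedToMMA pattern itself now lives on the XLA side (in one of the changed files not expanded here) and reuses the helpers that the trimmed-down patch exports from AccelerateMatmul.cpp. The following is only a hypothetical sketch of the declarations an XLA-side pattern could include and call; the signatures mirror the definitions added by the patch, but the placement and surrounding code are assumptions:

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"

namespace mlir {
namespace triton {
namespace gpu {

// Helpers exported from AccelerateMatmul.cpp by the patch; an XLA-side
// SparseBlockedToMMA pattern would call these to pick the warp tiling,
// the operand kWidth, and the shared-memory operands for MMAv3.
llvm::SmallVector<unsigned, 3>
getWarpsPerTile(Operation *dotOp, llvm::ArrayRef<int64_t> shape, int version,
                int numWarps, const llvm::SmallVector<unsigned, 3> &instrShape);
int computeOrigBitWidth(Value x);
Value getSharedMemMMAOperand(Value v, PatternRewriter &rewriter, int opIdx,
                             bool allowTranspose);

}  // namespace gpu
}  // namespace triton
}  // namespace mlir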
