From f61ff93879927bcce1f56b76bc3d8d81d3d8d20f Mon Sep 17 00:00:00 2001 From: Gunhyun Park Date: Mon, 15 Apr 2024 14:18:56 -0700 Subject: [PATCH] Use `ShapeUtil::HumanString` instead of calling `Shape::ToString` directly in `xla_builder.cc`. PiperOrigin-RevId: 625081373 --- third_party/triton/temporary/pipelining.patch | 472 ++++++++++++++++++ third_party/triton/temporary/series.bzl | 1 + .../triton/temporary/pipelining.patch | 472 ++++++++++++++++++ .../third_party/triton/temporary/series.bzl | 1 + third_party/xla/xla/client/xla_builder.cc | 78 ++- third_party/xla/xla/client/xla_builder.h | 26 +- .../xla/xla/client/xla_builder_test.cc | 86 +++- 7 files changed, 1100 insertions(+), 36 deletions(-) create mode 100644 third_party/triton/temporary/pipelining.patch create mode 100644 third_party/xla/third_party/triton/temporary/pipelining.patch diff --git a/third_party/triton/temporary/pipelining.patch b/third_party/triton/temporary/pipelining.patch new file mode 100644 index 00000000000000..9f5f36aeb5099d --- /dev/null +++ b/third_party/triton/temporary/pipelining.patch @@ -0,0 +1,472 @@ +This is patching changes upstream from different PRs that fix issues with +pipelining internally. Required changes are upto and including this commit +https://github.com/openai/triton/commit/70f0b7b6e333fe2155c79dfa8bec6ad388073670 +The patch can be removed with the integration that includes these changes. + +diff --git a/include/triton/Analysis/Utility.h b/include/triton/Analysis/Utility.h +--- a/include/triton/Analysis/Utility.h ++++ b/include/triton/Analysis/Utility.h +@@ -8,6 +8,18 @@ + + namespace mlir { + ++inline bool isZeroConst(Value v) { ++ auto constantOp = v.getDefiningOp(); ++ if (!constantOp) ++ return false; ++ if (auto denseAttr = dyn_cast(constantOp.getValueAttr())) ++ return denseAttr.isSplat() && denseAttr.getSplatValue().isZero(); ++ if (auto denseAttr = ++ dyn_cast(constantOp.getValueAttr())) ++ return denseAttr.isSplat() && denseAttr.getSplatValue().isZero(); ++ return false; ++} ++ + class ReduceOpHelper { + public: + explicit ReduceOpHelper(triton::ReduceOp op) +diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td +--- a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td ++++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td +@@ -45,6 +45,8 @@ def TTG_AsyncWaitOp : TTG_Op<"async_wait + + let arguments = (ins Variadic:$asyncToken, I32Attr:$num); + ++ let results = (outs TTG_AsyncToken:$retToken); ++ + let assemblyFormat = "$asyncToken attr-dict"; + + let extraClassDeclaration = [{ +@@ -229,10 +231,16 @@ def TTG_LocalLoadOp : TTG_Op<"local_load + let description = [{ + Load a tensor from the local memory descriptor into a distributed tensor. + }]; +- let arguments = (ins TT_MemDescType:$src); ++ let arguments = (ins TT_MemDescType:$src, Optional :$token); ++ ++ let builders = [ ++ OpBuilder<(ins "Type":$retType, "Value":$src), ++ [{ ++ build($_builder, $_state, retType, src, /*token=*/static_cast(nullptr)); ++ }]>]; + + // Use qualified() otherwise "!tt.memdesc" is printed as "". +- let assemblyFormat = [{$src attr-dict `:` qualified(type($src)) `->` type($result)}]; ++ let assemblyFormat = [{$src (`token` $token^)? 
attr-dict `:` qualified(type($src)) `->` type($result)}]; + + let results = (outs TT_Tensor:$result); + } +diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp +--- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp ++++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp +@@ -8,6 +8,7 @@ + #include "mlir/Interfaces/SideEffectInterfaces.h" + #include "mlir/Support/LLVM.h" + #include "triton/Analysis/AxisInfo.h" ++#include "triton/Analysis/Utility.h" + #include "triton/Dialect/Triton/IR/Types.h" + #include "triton/Dialect/Triton/IR/Utility.h" + #include "triton/Dialect/TritonGPU/IR/Attributes.h" +@@ -84,12 +85,13 @@ createAsyncCopy(scf::ForOp &forOp, tt::L + Location loc = loadOp.getLoc(); + Value src = loadOp.getPtr(); + Value mask = loadOp.getMask(); ++ Value other = loadOp.getOther(); + if (!isExpensiveLoadOrStore(loadOp) && opToInfo[loadOp].blockedEncoding) { + // For inexpensive loads that do not directly feed into dot ops + // we want to use optimal layout for the data. + ttg::BlockedEncodingAttr encoding = opToInfo[loadOp].blockedEncoding; + auto convertBlockLayout = [&](Value src) { +- auto ty = src.getType().cast(); ++ auto ty = cast(src.getType()); + auto newTy = + RankedTensorType::get(ty.getShape(), ty.getElementType(), encoding); + auto cvt = +@@ -99,9 +101,11 @@ createAsyncCopy(scf::ForOp &forOp, tt::L + src = convertBlockLayout(src); + if (mask) + mask = convertBlockLayout(mask); ++ if (other) ++ other = convertBlockLayout(other); + } + +- tt::MemDescType allocTy = alloc.getType().cast(); ++ tt::MemDescType allocTy = cast(alloc.getType()); + SmallVector copyOffsets(allocTy.getRank(), zero); + copyOffsets[0] = insertIdx; + tt::MemDescType subviewTy = tt::MemDescType::get( +@@ -110,11 +114,12 @@ createAsyncCopy(scf::ForOp &forOp, tt::L + auto view = + builder.create(loc, subviewTy, alloc, copyOffsets); + Operation *copy = builder.create( +- loc, src, view, mask, loadOp.getOther(), loadOp.getCache(), +- loadOp.getEvict(), loadOp.getIsVolatile()); ++ loc, src, view, mask, other, loadOp.getCache(), loadOp.getEvict(), ++ loadOp.getIsVolatile()); + Operation *commmit = + builder.create(loc, copy->getResult(0)); +- builder.create(loc, commmit->getResult(0), 0); ++ Operation *wait = ++ builder.create(loc, commmit->getResult(0), 0); + + int stage = opToInfo[loadOp].stage; + bool isMMV3Load = opToInfo[loadOp].loadIsMMAV3; +@@ -142,9 +147,21 @@ createAsyncCopy(scf::ForOp &forOp, tt::L + for (auto alloc : allocsToErase) { + alloc.erase(); + } +- auto sharedLoad = +- builder.create(loc, loadOp.getType(), viewLoad); +- loadOp->replaceAllUsesWith(sharedLoad->getResults()); ++ ++ auto sharedLoad = builder.create( ++ loc, loadOp.getType(), viewLoad, wait->getResult(0)); ++ auto result = sharedLoad->getResults(); ++ ++ // Create a select for non-zero other values as they are not handled by ++ // AsyncCopyGlobalToLocalOp for now. 
++ Value other = loadOp.getOther(); ++ if (other && !isZeroConst(other)) { ++ auto select = builder.create( ++ loc, loadOp.getType(), mask, sharedLoad.getResult(), other); ++ result = select->getResults(); ++ } ++ ++ loadOp->replaceAllUsesWith(result); + } + loadOp.erase(); + } +@@ -160,7 +177,7 @@ getSharedEncIfAllUsersAreDotEnc(Value va + if (user->getNumResults() != 1) + return std::nullopt; + if (auto memDesc = +- user->getResult(0).getType().dyn_cast()) { ++ dyn_cast(user->getResult(0).getType())) { + // First time we find a shared encoding in the chain, save it and try to + // use it if it is compatible with the other users. + tempAttr = memDesc.getEncoding().cast(); +@@ -203,7 +220,7 @@ getSharedEncIfAllUsersAreDotEnc(Value va + static ttg::BlockedEncodingAttr + getBlockedEncoding(tt::LoadOp loadOp, tt::ModuleAxisInfoAnalysis &axisInfo) { + Value src = loadOp.getPtr(); +- auto ty = src.getType().cast(); ++ auto ty = cast(src.getType()); + auto mod = loadOp->getParentOfType(); + int numWarps = ttg::TritonGPUDialect::getNumWarps(mod); + int threadsPerWarp = ttg::TritonGPUDialect::getThreadsPerWarp(mod); +@@ -221,7 +238,7 @@ getBlockedEncoding(tt::LoadOp loadOp, tt + + static std::optional + getSharedEncoding(tt::LoadOp loadOp, bool isMMAV3) { +- auto ty = loadOp.getType().cast(); ++ auto ty = cast(loadOp.getType()); + auto ctaLayout = ttg::getCTALayout(ty.getEncoding()); + auto blockedOrder = ttg::getOrder(ty.getEncoding()); + SmallVector order; +@@ -285,11 +302,10 @@ loadOpsToDistanceAndUse(scf::ForOp forOp + if (auto mask = loadOp.getMask()) + vec = std::min(vec, axisInfoAnalysis.getMaskAlignment(mask)); + +- auto tensorTy = ptr.getType().dyn_cast(); ++ auto tensorTy = dyn_cast(ptr.getType()); + if (!tensorTy) + return false; +- auto ty = +- tensorTy.getElementType().cast().getPointeeType(); ++ auto ty = cast(tensorTy.getElementType()).getPointeeType(); + unsigned width = vec * ty.getIntOrFloatBitWidth(); + + // We do not pipeline all loads for the following reasons: +@@ -353,7 +369,7 @@ static bool loadIsMMAv3(tt::LoadOp loadO + + // MMA V3 case. + auto newOrder = sharedEnc.getOrder(); +- auto ty = loadOp.getType().cast(); ++ auto ty = cast(loadOp.getType()); + auto oldOrder = ttg::getOrder(ty.getEncoding()); + + // The operand of MMAv3 is in SharedEncoding and its order should not +@@ -497,7 +513,7 @@ collectOpsToPipeline(scf::ForOp forOp, + static Value createAlloc(scf::ForOp &forOp, tt::LoadOp loadOp, + ttg::SharedEncodingAttr sharedEnc, unsigned distance) { + OpBuilder builder(forOp); +- auto ty = loadOp.getType().cast(); ++ auto ty = cast(loadOp.getType()); + SmallVector bufferShape(ty.getShape().begin(), ty.getShape().end()); + bufferShape.insert(bufferShape.begin(), distance); + Type memdescType = mlir::triton::MemDescType::get( +@@ -669,12 +685,23 @@ createSchedule(scf::ForOp forOp, int num + } + }); + ++ auto getNestedOperands = [](Operation *op) -> SmallVector { ++ SmallVector operands; ++ op->walk([&](Operation *nestedOp) { ++ for (Value operand : nestedOp->getOperands()) { ++ if (operand.getParentBlock()->getParentOp()->isAncestor(nestedOp)) ++ operands.push_back(operand); ++ } ++ }); ++ return operands; ++ }; ++ + // Find dependencies with distance of 1. 
+ SmallVector> distanceOneUsers(numStages); + for (int stage = 0; stage < numStages - 1; stage++) { + auto &group = insertAndDeps[stage]; + for (Operation *op : group) { +- for (Value operand : op->getOperands()) { ++ for (Value operand : getNestedOperands(op)) { + if (auto arg = operand.dyn_cast()) { + if (arg.getArgNumber() > 0 && arg.getOwner() == op->getBlock()) { + auto yieldOp = op->getBlock()->getTerminator(); +@@ -905,7 +932,7 @@ static int minNumInterleavedCommitOps(Op + // Look for consecutive wait ops and combine them into a single wait op. + static void + combineRedundantWaitOps(llvm::SmallSetVector &waitOps) { +- llvm::SmallSetVector toDelete; ++ llvm::MapVector toDelete; + for (auto waitOp : waitOps) { + if (toDelete.count(waitOp)) + continue; +@@ -927,10 +954,13 @@ combineRedundantWaitOps(llvm::SmallSetVe + OpBuilder builder(waitGroup.back()); + auto newWaitOp = builder.create(waitOp.getLoc(), + depTokens, minWaitNumber); +- toDelete.insert(waitGroup.begin(), waitGroup.end()); ++ for (auto waitOp : waitGroup) { ++ toDelete[waitOp] = newWaitOp; ++ } + } + for (auto waitOp : toDelete) { +- waitOp->erase(); ++ waitOp.first->replaceAllUsesWith(waitOp.second); ++ waitOp.first->erase(); + } + } + +@@ -1010,7 +1040,7 @@ static void threadValuesThroughWait(ttng + + for (ttng::DotAsyncOp dot : asyncDots) { + for (Value operand : dot.getOperands()) { +- if (operand.getType().isa()) { ++ if (isa(operand.getType())) { + newOperands.insert(operand); + } + } +@@ -1020,15 +1050,21 @@ static void threadValuesThroughWait(ttng + // values in the operation. + auto newWait = builder.create( + wait.getLoc(), llvm::to_vector(newOperands), wait.getPendings()); ++ ++ auto dominatedByNewWait = [&](OpOperand &operand) { ++ auto opInThisBlock = ++ newWait->getBlock()->findAncestorOpInBlock(*operand.getOwner()); ++ return opInThisBlock && newWait->isBeforeInBlock(opInThisBlock); ++ }; + for (int i = 0; i < origNumOperands; i++) { + Value operand = wait.getResult(i); +- if (!operand.getType().isa()) ++ if (!isa(operand.getType())) + operand.replaceAllUsesWith(newWait.getResult(i)); + } + for (int i = origNumOperands; i < newOperands.size(); i++) { + Value operand = newWait.getOperand(i); +- if (!operand.getType().isa()) +- operand.replaceAllUsesExcept(newWait.getResult(i), newWait); ++ if (!isa(operand.getType())) ++ operand.replaceUsesWithIf(newWait.getResult(i), dominatedByNewWait); + } + wait->erase(); + } +@@ -1047,8 +1083,8 @@ static void threadValuesThroughWait(ttng + // 1. All operands that touch shared memory are multi-buffered, i.e. can't read + // an incomplete value while it's being written asynchronously by a load. + // +-// 2. During iteration i, nothing other than the loop's `yield` reads the +-// result of the dot. ++// 2. If the dot is used by any op in the loop, it must be used under an `if`, ++// and will be synced with a `wait 0` at the beginning of the `if` block. + // + // 3. During iteration i, between the start of the loop up until the first + // `ttng.dot_wait {pendings=0}` op, the result of the dot from iteration i-1 +@@ -1079,7 +1115,7 @@ static std::optional dotCanBeProper + // Rule 1: All shmem operands are multi-buffered. + auto checkOperand = [&](Value operand) { + if (!isa( +- operand.getType().cast().getEncoding())) { ++ cast(operand.getType()).getEncoding())) { + return true; + } + +@@ -1103,17 +1139,41 @@ static std::optional dotCanBeProper + return std::nullopt; + } + +- // Rule 2: The dot should only be used by the for loop's `yield`. 
+- if (!dotOp->hasOneUse() || +- *dotOp->getUsers().begin() != forOp.getBody()->getTerminator()) { +- LDBG("Can't make dot async because it is not used only by the loop's " +- "`yield`."); +- return std::nullopt; ++ // Rule 2: The dot cannot be unconditionally used by any op in the loop. ++ // Uses under `if` are allowed, as can be explicitly synced with a `wait 0`. ++ int iterArgIdx = -1; ++ Value iterArg = nullptr; ++ SmallVector> queue; ++ for (auto &use : dotOp->getUses()) { ++ queue.push_back({use.getOwner(), use.getOperandNumber()}); + } +- +- // The result of the dot becomes this loop carry value. +- auto iterArgIdx = dotOp->getUses().begin()->getOperandNumber(); +- auto iterArg = forOp.getRegionIterArg(iterArgIdx); ++ while (!queue.empty()) { ++ auto [user, argIdx] = queue.pop_back_val(); ++ if (user->getParentOp() == forOp) { ++ if (isa(user)) { ++ if (iterArg) { ++ // The dot is used by the loop's yield, but we can't have any other ++ // uses. ++ return std::nullopt; ++ } ++ iterArgIdx = argIdx; ++ iterArg = forOp.getRegionIterArg(argIdx); ++ continue; ++ } ++ return std::nullopt; ++ } ++ if (auto ifOp = dyn_cast(user->getParentOp())) { ++ if (isa(user)) { ++ // The result is returned by the if, follow it further. ++ auto uses = ifOp.getResult(argIdx).getUses(); ++ for (auto &use : uses) { ++ queue.push_back({use.getOwner(), use.getOperandNumber()}); ++ } ++ } ++ } else { ++ return std::nullopt; ++ } ++ } + + // Rule 3a: Are the only users of the dot's result from iteration i-1 other + // MMAv3 dots? If so, we're done, this dot can be properly async. +@@ -1181,6 +1241,32 @@ static void insertAsyncDotWaitInLoop( + return; + } + ++ // Insert waits before the users of the properly async dots other than loop ++ // yield. ++ for (auto [asyncDot, iterArgIdx] : properlyAsyncDots) { ++ SmallVector uses; ++ for (auto &use : asyncDot->getUses()) { ++ if (auto yieldOp = dyn_cast(use.getOwner())) { ++ continue; ++ } ++ uses.push_back(&use); ++ } ++ ++ DenseMap> blockToUsers; ++ for (auto use : uses) { ++ auto block = use->getOwner()->getBlock(); ++ blockToUsers[block].push_back(use->get()); ++ } ++ ++ for (auto [block, users] : blockToUsers) { ++ OpBuilder builder(block, block->begin()); ++ auto newWait = builder.create(asyncDot->getLoc(), ++ ArrayRef{}, 0); ++ ++ threadValuesThroughWait(newWait, users); ++ } ++ } ++ + // Add the wait right after the last properly-async dot. This only needs to + // wait for all properly-async dots from the i-1'th iteration to complete, IOW + // we wait until there are most `asyncDots.size()` dots in flight. 
+diff --git a/test/TritonGPU/loop-pipeline.mlir b/test/TritonGPU/loop-pipeline.mlir +--- a/test/TritonGPU/loop-pipeline.mlir ++++ b/test/TritonGPU/loop-pipeline.mlir +@@ -349,16 +349,21 @@ tt.func @indirect_bmm_scalar_dist_one(%7 + // CHECK: triton_gpu.async_copy_global_to_local + // CHECK: triton_gpu.async_copy_global_to_local + // CHECK: triton_gpu.async_commit_group ++// CHECK: triton_gpu.async_wait {{.*}} {num = 1 : i32} ++// CHECK: scf.for ++// CHECK: tt.dot + // CHECK: %[[NEXT_BUFFER_1:.*]] = tt.addptr %{{.*}}, {{.*}} + // CHECK: triton_gpu.async_copy_global_to_local %[[NEXT_BUFFER_1]] +-// CHECK: %[[IND_BUFFER_0:.*]] = triton_gpu.memdesc_subview +-// CHECK: %[[IND_BUFFER_1:.*]] = triton_gpu.local_load %[[IND_BUFFER_0]] ++// CHECK-DAG: %[[IND_BUFFER_WAIT_TOKEN:.*]] = triton_gpu.async_wait {{.*}} {num = 1 : i32} ++// CHECK-DAG: %[[IND_BUFFER_0:.*]] = triton_gpu.memdesc_subview ++// CHECK: %[[IND_BUFFER_1:.*]] = triton_gpu.local_load %[[IND_BUFFER_0]] token %[[IND_BUFFER_WAIT_TOKEN]] + // CHECK: %[[IND_BUFFER_2:.*]] = tt.expand_dims %[[IND_BUFFER_1]] {axis = 1 : i32} + // CHECK: %[[IND_BUFFER_3:.*]] = tt.broadcast %[[IND_BUFFER_2]] + // CHECK: %[[IND_BUFFER_4:.*]] = arith.muli {{.*}}, %[[IND_BUFFER_3]] + // CHECK: %[[NEXT_BUFFER_0:.*]] = tt.addptr {{.*}}, %[[IND_BUFFER_4]] + // CHECK: triton_gpu.async_copy_global_to_local %[[NEXT_BUFFER_0]] + // CHECK: triton_gpu.async_wait {{.*}} {num = 1 : i32} ++// CHECK: scf.yield + tt.func @indirect_bmm_vector(%77: tensor<16x16xi64, #BL> {tt.divisibility=16: i32, tt.constancy=16: i32}, + %76: index, + %49: tensor<16x16x!tt.ptr, #AL> {tt.divisibility=16: i32, tt.contiguity=2 : i32}, +diff --git a/test/TritonGPU/reorder-instructions.mlir b/test/TritonGPU/reorder-instructions.mlir +--- a/test/TritonGPU/reorder-instructions.mlir ++++ b/test/TritonGPU/reorder-instructions.mlir +@@ -28,7 +28,7 @@ module attributes {"triton_gpu.num-warps + // CHECK: triton_gpu.async_wait {num = 0 : i32} + // CHECK: triton_gpu.local_dealloc %0 : !tt.memdesc<4x128x64xf16, #shared> + // CHECK: triton_gpu.local_dealloc %1 : !tt.memdesc<4x128x64xf16, #shared> +-// CHECK: %2 = triton_gpu.convert_layout %arg0 : tensor<32x32xf32, #blocked> -> tensor<32x32xf32, #blocked1> ++// CHECK: %3 = triton_gpu.convert_layout %arg0 : tensor<32x32xf32, #blocked> -> tensor<32x32xf32, #blocked1> + #blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [0, 1]}> + #blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [1, 0]}> + #shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [0, 1]}> +diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp +--- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp ++++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp +@@ -333,17 +333,6 @@ static Value faddAccumulate(ConversionPa + return newStruct; + } + +-static bool isZero(Value v) { +- auto constantOp = v.getDefiningOp(); +- if (!constantOp) +- return false; +- if (auto denseAttr = dyn_cast(constantOp.getValueAttr())) +- return denseAttr.isSplat() && denseAttr.getSplatValue().isZero(); +- if (auto denseAttr = +- dyn_cast(constantOp.getValueAttr())) +- return denseAttr.isSplat() && denseAttr.getSplatValue().isZero(); +- return false; +-} + + static SmallVector emitWait(ConversionPatternRewriter &rewriter, + Location loc, SmallVector acc, +@@ 
-402,7 +391,7 @@ LogicalResult convertDot(const LLVMTypeC + int M = 4 * instrShape[0]; + int N = instrShape[1]; + int K = instrShape[2]; +- bool zeroAcc = isZero(c); ++ bool zeroAcc = isZeroConst(c); + auto shapePerCTATile = getShapePerCTATile(mmaEncoding); + int numRepM = ceil(dShapePerCTA[0], shapePerCTATile[0]); + int numRepN = ceil(dShapePerCTA[1], shapePerCTATile[1]); +diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp +--- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp ++++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp +@@ -924,8 +924,11 @@ struct AsyncWaitOpConversion + auto voidTy = void_ty(ctx); + ptxBuilder.launch(rewriter, loc, voidTy); + +- // Safe to remove the op since it doesn't have any return value. +- rewriter.eraseOp(op); ++ // Drop the result token. ++ Value zero = rewriter.create( ++ op.getLoc(), IntegerType::get(op.getContext(), 32), ++ rewriter.getI32IntegerAttr(0)); ++ rewriter.replaceOp(op, zero); + return success(); + } + }; diff --git a/third_party/triton/temporary/series.bzl b/third_party/triton/temporary/series.bzl index a929f2c4a017f0..2dfe0dd1bb695c 100644 --- a/third_party/triton/temporary/series.bzl +++ b/third_party/triton/temporary/series.bzl @@ -6,4 +6,5 @@ internal patch during the next triton integration process. """ temporary_patch_list = [ + "//third_party/triton/temporary:pipelining.patch", ] diff --git a/third_party/xla/third_party/triton/temporary/pipelining.patch b/third_party/xla/third_party/triton/temporary/pipelining.patch new file mode 100644 index 00000000000000..9f5f36aeb5099d --- /dev/null +++ b/third_party/xla/third_party/triton/temporary/pipelining.patch @@ -0,0 +1,472 @@ +This is patching changes upstream from different PRs that fix issues with +pipelining internally. Required changes are upto and including this commit +https://github.com/openai/triton/commit/70f0b7b6e333fe2155c79dfa8bec6ad388073670 +The patch can be removed with the integration that includes these changes. + +diff --git a/include/triton/Analysis/Utility.h b/include/triton/Analysis/Utility.h +--- a/include/triton/Analysis/Utility.h ++++ b/include/triton/Analysis/Utility.h +@@ -8,6 +8,18 @@ + + namespace mlir { + ++inline bool isZeroConst(Value v) { ++ auto constantOp = v.getDefiningOp(); ++ if (!constantOp) ++ return false; ++ if (auto denseAttr = dyn_cast(constantOp.getValueAttr())) ++ return denseAttr.isSplat() && denseAttr.getSplatValue().isZero(); ++ if (auto denseAttr = ++ dyn_cast(constantOp.getValueAttr())) ++ return denseAttr.isSplat() && denseAttr.getSplatValue().isZero(); ++ return false; ++} ++ + class ReduceOpHelper { + public: + explicit ReduceOpHelper(triton::ReduceOp op) +diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td +--- a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td ++++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td +@@ -45,6 +45,8 @@ def TTG_AsyncWaitOp : TTG_Op<"async_wait + + let arguments = (ins Variadic:$asyncToken, I32Attr:$num); + ++ let results = (outs TTG_AsyncToken:$retToken); ++ + let assemblyFormat = "$asyncToken attr-dict"; + + let extraClassDeclaration = [{ +@@ -229,10 +231,16 @@ def TTG_LocalLoadOp : TTG_Op<"local_load + let description = [{ + Load a tensor from the local memory descriptor into a distributed tensor. 
+ }]; +- let arguments = (ins TT_MemDescType:$src); ++ let arguments = (ins TT_MemDescType:$src, Optional :$token); ++ ++ let builders = [ ++ OpBuilder<(ins "Type":$retType, "Value":$src), ++ [{ ++ build($_builder, $_state, retType, src, /*token=*/static_cast(nullptr)); ++ }]>]; + + // Use qualified() otherwise "!tt.memdesc" is printed as "". +- let assemblyFormat = [{$src attr-dict `:` qualified(type($src)) `->` type($result)}]; ++ let assemblyFormat = [{$src (`token` $token^)? attr-dict `:` qualified(type($src)) `->` type($result)}]; + + let results = (outs TT_Tensor:$result); + } +diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp +--- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp ++++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp +@@ -8,6 +8,7 @@ + #include "mlir/Interfaces/SideEffectInterfaces.h" + #include "mlir/Support/LLVM.h" + #include "triton/Analysis/AxisInfo.h" ++#include "triton/Analysis/Utility.h" + #include "triton/Dialect/Triton/IR/Types.h" + #include "triton/Dialect/Triton/IR/Utility.h" + #include "triton/Dialect/TritonGPU/IR/Attributes.h" +@@ -84,12 +85,13 @@ createAsyncCopy(scf::ForOp &forOp, tt::L + Location loc = loadOp.getLoc(); + Value src = loadOp.getPtr(); + Value mask = loadOp.getMask(); ++ Value other = loadOp.getOther(); + if (!isExpensiveLoadOrStore(loadOp) && opToInfo[loadOp].blockedEncoding) { + // For inexpensive loads that do not directly feed into dot ops + // we want to use optimal layout for the data. + ttg::BlockedEncodingAttr encoding = opToInfo[loadOp].blockedEncoding; + auto convertBlockLayout = [&](Value src) { +- auto ty = src.getType().cast(); ++ auto ty = cast(src.getType()); + auto newTy = + RankedTensorType::get(ty.getShape(), ty.getElementType(), encoding); + auto cvt = +@@ -99,9 +101,11 @@ createAsyncCopy(scf::ForOp &forOp, tt::L + src = convertBlockLayout(src); + if (mask) + mask = convertBlockLayout(mask); ++ if (other) ++ other = convertBlockLayout(other); + } + +- tt::MemDescType allocTy = alloc.getType().cast(); ++ tt::MemDescType allocTy = cast(alloc.getType()); + SmallVector copyOffsets(allocTy.getRank(), zero); + copyOffsets[0] = insertIdx; + tt::MemDescType subviewTy = tt::MemDescType::get( +@@ -110,11 +114,12 @@ createAsyncCopy(scf::ForOp &forOp, tt::L + auto view = + builder.create(loc, subviewTy, alloc, copyOffsets); + Operation *copy = builder.create( +- loc, src, view, mask, loadOp.getOther(), loadOp.getCache(), +- loadOp.getEvict(), loadOp.getIsVolatile()); ++ loc, src, view, mask, other, loadOp.getCache(), loadOp.getEvict(), ++ loadOp.getIsVolatile()); + Operation *commmit = + builder.create(loc, copy->getResult(0)); +- builder.create(loc, commmit->getResult(0), 0); ++ Operation *wait = ++ builder.create(loc, commmit->getResult(0), 0); + + int stage = opToInfo[loadOp].stage; + bool isMMV3Load = opToInfo[loadOp].loadIsMMAV3; +@@ -142,9 +147,21 @@ createAsyncCopy(scf::ForOp &forOp, tt::L + for (auto alloc : allocsToErase) { + alloc.erase(); + } +- auto sharedLoad = +- builder.create(loc, loadOp.getType(), viewLoad); +- loadOp->replaceAllUsesWith(sharedLoad->getResults()); ++ ++ auto sharedLoad = builder.create( ++ loc, loadOp.getType(), viewLoad, wait->getResult(0)); ++ auto result = sharedLoad->getResults(); ++ ++ // Create a select for non-zero other values as they are not handled by ++ // AsyncCopyGlobalToLocalOp for now. 
++ Value other = loadOp.getOther(); ++ if (other && !isZeroConst(other)) { ++ auto select = builder.create( ++ loc, loadOp.getType(), mask, sharedLoad.getResult(), other); ++ result = select->getResults(); ++ } ++ ++ loadOp->replaceAllUsesWith(result); + } + loadOp.erase(); + } +@@ -160,7 +177,7 @@ getSharedEncIfAllUsersAreDotEnc(Value va + if (user->getNumResults() != 1) + return std::nullopt; + if (auto memDesc = +- user->getResult(0).getType().dyn_cast()) { ++ dyn_cast(user->getResult(0).getType())) { + // First time we find a shared encoding in the chain, save it and try to + // use it if it is compatible with the other users. + tempAttr = memDesc.getEncoding().cast(); +@@ -203,7 +220,7 @@ getSharedEncIfAllUsersAreDotEnc(Value va + static ttg::BlockedEncodingAttr + getBlockedEncoding(tt::LoadOp loadOp, tt::ModuleAxisInfoAnalysis &axisInfo) { + Value src = loadOp.getPtr(); +- auto ty = src.getType().cast(); ++ auto ty = cast(src.getType()); + auto mod = loadOp->getParentOfType(); + int numWarps = ttg::TritonGPUDialect::getNumWarps(mod); + int threadsPerWarp = ttg::TritonGPUDialect::getThreadsPerWarp(mod); +@@ -221,7 +238,7 @@ getBlockedEncoding(tt::LoadOp loadOp, tt + + static std::optional + getSharedEncoding(tt::LoadOp loadOp, bool isMMAV3) { +- auto ty = loadOp.getType().cast(); ++ auto ty = cast(loadOp.getType()); + auto ctaLayout = ttg::getCTALayout(ty.getEncoding()); + auto blockedOrder = ttg::getOrder(ty.getEncoding()); + SmallVector order; +@@ -285,11 +302,10 @@ loadOpsToDistanceAndUse(scf::ForOp forOp + if (auto mask = loadOp.getMask()) + vec = std::min(vec, axisInfoAnalysis.getMaskAlignment(mask)); + +- auto tensorTy = ptr.getType().dyn_cast(); ++ auto tensorTy = dyn_cast(ptr.getType()); + if (!tensorTy) + return false; +- auto ty = +- tensorTy.getElementType().cast().getPointeeType(); ++ auto ty = cast(tensorTy.getElementType()).getPointeeType(); + unsigned width = vec * ty.getIntOrFloatBitWidth(); + + // We do not pipeline all loads for the following reasons: +@@ -353,7 +369,7 @@ static bool loadIsMMAv3(tt::LoadOp loadO + + // MMA V3 case. + auto newOrder = sharedEnc.getOrder(); +- auto ty = loadOp.getType().cast(); ++ auto ty = cast(loadOp.getType()); + auto oldOrder = ttg::getOrder(ty.getEncoding()); + + // The operand of MMAv3 is in SharedEncoding and its order should not +@@ -497,7 +513,7 @@ collectOpsToPipeline(scf::ForOp forOp, + static Value createAlloc(scf::ForOp &forOp, tt::LoadOp loadOp, + ttg::SharedEncodingAttr sharedEnc, unsigned distance) { + OpBuilder builder(forOp); +- auto ty = loadOp.getType().cast(); ++ auto ty = cast(loadOp.getType()); + SmallVector bufferShape(ty.getShape().begin(), ty.getShape().end()); + bufferShape.insert(bufferShape.begin(), distance); + Type memdescType = mlir::triton::MemDescType::get( +@@ -669,12 +685,23 @@ createSchedule(scf::ForOp forOp, int num + } + }); + ++ auto getNestedOperands = [](Operation *op) -> SmallVector { ++ SmallVector operands; ++ op->walk([&](Operation *nestedOp) { ++ for (Value operand : nestedOp->getOperands()) { ++ if (operand.getParentBlock()->getParentOp()->isAncestor(nestedOp)) ++ operands.push_back(operand); ++ } ++ }); ++ return operands; ++ }; ++ + // Find dependencies with distance of 1. 
+ SmallVector> distanceOneUsers(numStages); + for (int stage = 0; stage < numStages - 1; stage++) { + auto &group = insertAndDeps[stage]; + for (Operation *op : group) { +- for (Value operand : op->getOperands()) { ++ for (Value operand : getNestedOperands(op)) { + if (auto arg = operand.dyn_cast()) { + if (arg.getArgNumber() > 0 && arg.getOwner() == op->getBlock()) { + auto yieldOp = op->getBlock()->getTerminator(); +@@ -905,7 +932,7 @@ static int minNumInterleavedCommitOps(Op + // Look for consecutive wait ops and combine them into a single wait op. + static void + combineRedundantWaitOps(llvm::SmallSetVector &waitOps) { +- llvm::SmallSetVector toDelete; ++ llvm::MapVector toDelete; + for (auto waitOp : waitOps) { + if (toDelete.count(waitOp)) + continue; +@@ -927,10 +954,13 @@ combineRedundantWaitOps(llvm::SmallSetVe + OpBuilder builder(waitGroup.back()); + auto newWaitOp = builder.create(waitOp.getLoc(), + depTokens, minWaitNumber); +- toDelete.insert(waitGroup.begin(), waitGroup.end()); ++ for (auto waitOp : waitGroup) { ++ toDelete[waitOp] = newWaitOp; ++ } + } + for (auto waitOp : toDelete) { +- waitOp->erase(); ++ waitOp.first->replaceAllUsesWith(waitOp.second); ++ waitOp.first->erase(); + } + } + +@@ -1010,7 +1040,7 @@ static void threadValuesThroughWait(ttng + + for (ttng::DotAsyncOp dot : asyncDots) { + for (Value operand : dot.getOperands()) { +- if (operand.getType().isa()) { ++ if (isa(operand.getType())) { + newOperands.insert(operand); + } + } +@@ -1020,15 +1050,21 @@ static void threadValuesThroughWait(ttng + // values in the operation. + auto newWait = builder.create( + wait.getLoc(), llvm::to_vector(newOperands), wait.getPendings()); ++ ++ auto dominatedByNewWait = [&](OpOperand &operand) { ++ auto opInThisBlock = ++ newWait->getBlock()->findAncestorOpInBlock(*operand.getOwner()); ++ return opInThisBlock && newWait->isBeforeInBlock(opInThisBlock); ++ }; + for (int i = 0; i < origNumOperands; i++) { + Value operand = wait.getResult(i); +- if (!operand.getType().isa()) ++ if (!isa(operand.getType())) + operand.replaceAllUsesWith(newWait.getResult(i)); + } + for (int i = origNumOperands; i < newOperands.size(); i++) { + Value operand = newWait.getOperand(i); +- if (!operand.getType().isa()) +- operand.replaceAllUsesExcept(newWait.getResult(i), newWait); ++ if (!isa(operand.getType())) ++ operand.replaceUsesWithIf(newWait.getResult(i), dominatedByNewWait); + } + wait->erase(); + } +@@ -1047,8 +1083,8 @@ static void threadValuesThroughWait(ttng + // 1. All operands that touch shared memory are multi-buffered, i.e. can't read + // an incomplete value while it's being written asynchronously by a load. + // +-// 2. During iteration i, nothing other than the loop's `yield` reads the +-// result of the dot. ++// 2. If the dot is used by any op in the loop, it must be used under an `if`, ++// and will be synced with a `wait 0` at the beginning of the `if` block. + // + // 3. During iteration i, between the start of the loop up until the first + // `ttng.dot_wait {pendings=0}` op, the result of the dot from iteration i-1 +@@ -1079,7 +1115,7 @@ static std::optional dotCanBeProper + // Rule 1: All shmem operands are multi-buffered. + auto checkOperand = [&](Value operand) { + if (!isa( +- operand.getType().cast().getEncoding())) { ++ cast(operand.getType()).getEncoding())) { + return true; + } + +@@ -1103,17 +1139,41 @@ static std::optional dotCanBeProper + return std::nullopt; + } + +- // Rule 2: The dot should only be used by the for loop's `yield`. 
+- if (!dotOp->hasOneUse() || +- *dotOp->getUsers().begin() != forOp.getBody()->getTerminator()) { +- LDBG("Can't make dot async because it is not used only by the loop's " +- "`yield`."); +- return std::nullopt; ++ // Rule 2: The dot cannot be unconditionally used by any op in the loop. ++ // Uses under `if` are allowed, as can be explicitly synced with a `wait 0`. ++ int iterArgIdx = -1; ++ Value iterArg = nullptr; ++ SmallVector> queue; ++ for (auto &use : dotOp->getUses()) { ++ queue.push_back({use.getOwner(), use.getOperandNumber()}); + } +- +- // The result of the dot becomes this loop carry value. +- auto iterArgIdx = dotOp->getUses().begin()->getOperandNumber(); +- auto iterArg = forOp.getRegionIterArg(iterArgIdx); ++ while (!queue.empty()) { ++ auto [user, argIdx] = queue.pop_back_val(); ++ if (user->getParentOp() == forOp) { ++ if (isa(user)) { ++ if (iterArg) { ++ // The dot is used by the loop's yield, but we can't have any other ++ // uses. ++ return std::nullopt; ++ } ++ iterArgIdx = argIdx; ++ iterArg = forOp.getRegionIterArg(argIdx); ++ continue; ++ } ++ return std::nullopt; ++ } ++ if (auto ifOp = dyn_cast(user->getParentOp())) { ++ if (isa(user)) { ++ // The result is returned by the if, follow it further. ++ auto uses = ifOp.getResult(argIdx).getUses(); ++ for (auto &use : uses) { ++ queue.push_back({use.getOwner(), use.getOperandNumber()}); ++ } ++ } ++ } else { ++ return std::nullopt; ++ } ++ } + + // Rule 3a: Are the only users of the dot's result from iteration i-1 other + // MMAv3 dots? If so, we're done, this dot can be properly async. +@@ -1181,6 +1241,32 @@ static void insertAsyncDotWaitInLoop( + return; + } + ++ // Insert waits before the users of the properly async dots other than loop ++ // yield. ++ for (auto [asyncDot, iterArgIdx] : properlyAsyncDots) { ++ SmallVector uses; ++ for (auto &use : asyncDot->getUses()) { ++ if (auto yieldOp = dyn_cast(use.getOwner())) { ++ continue; ++ } ++ uses.push_back(&use); ++ } ++ ++ DenseMap> blockToUsers; ++ for (auto use : uses) { ++ auto block = use->getOwner()->getBlock(); ++ blockToUsers[block].push_back(use->get()); ++ } ++ ++ for (auto [block, users] : blockToUsers) { ++ OpBuilder builder(block, block->begin()); ++ auto newWait = builder.create(asyncDot->getLoc(), ++ ArrayRef{}, 0); ++ ++ threadValuesThroughWait(newWait, users); ++ } ++ } ++ + // Add the wait right after the last properly-async dot. This only needs to + // wait for all properly-async dots from the i-1'th iteration to complete, IOW + // we wait until there are most `asyncDots.size()` dots in flight. 
+diff --git a/test/TritonGPU/loop-pipeline.mlir b/test/TritonGPU/loop-pipeline.mlir +--- a/test/TritonGPU/loop-pipeline.mlir ++++ b/test/TritonGPU/loop-pipeline.mlir +@@ -349,16 +349,21 @@ tt.func @indirect_bmm_scalar_dist_one(%7 + // CHECK: triton_gpu.async_copy_global_to_local + // CHECK: triton_gpu.async_copy_global_to_local + // CHECK: triton_gpu.async_commit_group ++// CHECK: triton_gpu.async_wait {{.*}} {num = 1 : i32} ++// CHECK: scf.for ++// CHECK: tt.dot + // CHECK: %[[NEXT_BUFFER_1:.*]] = tt.addptr %{{.*}}, {{.*}} + // CHECK: triton_gpu.async_copy_global_to_local %[[NEXT_BUFFER_1]] +-// CHECK: %[[IND_BUFFER_0:.*]] = triton_gpu.memdesc_subview +-// CHECK: %[[IND_BUFFER_1:.*]] = triton_gpu.local_load %[[IND_BUFFER_0]] ++// CHECK-DAG: %[[IND_BUFFER_WAIT_TOKEN:.*]] = triton_gpu.async_wait {{.*}} {num = 1 : i32} ++// CHECK-DAG: %[[IND_BUFFER_0:.*]] = triton_gpu.memdesc_subview ++// CHECK: %[[IND_BUFFER_1:.*]] = triton_gpu.local_load %[[IND_BUFFER_0]] token %[[IND_BUFFER_WAIT_TOKEN]] + // CHECK: %[[IND_BUFFER_2:.*]] = tt.expand_dims %[[IND_BUFFER_1]] {axis = 1 : i32} + // CHECK: %[[IND_BUFFER_3:.*]] = tt.broadcast %[[IND_BUFFER_2]] + // CHECK: %[[IND_BUFFER_4:.*]] = arith.muli {{.*}}, %[[IND_BUFFER_3]] + // CHECK: %[[NEXT_BUFFER_0:.*]] = tt.addptr {{.*}}, %[[IND_BUFFER_4]] + // CHECK: triton_gpu.async_copy_global_to_local %[[NEXT_BUFFER_0]] + // CHECK: triton_gpu.async_wait {{.*}} {num = 1 : i32} ++// CHECK: scf.yield + tt.func @indirect_bmm_vector(%77: tensor<16x16xi64, #BL> {tt.divisibility=16: i32, tt.constancy=16: i32}, + %76: index, + %49: tensor<16x16x!tt.ptr, #AL> {tt.divisibility=16: i32, tt.contiguity=2 : i32}, +diff --git a/test/TritonGPU/reorder-instructions.mlir b/test/TritonGPU/reorder-instructions.mlir +--- a/test/TritonGPU/reorder-instructions.mlir ++++ b/test/TritonGPU/reorder-instructions.mlir +@@ -28,7 +28,7 @@ module attributes {"triton_gpu.num-warps + // CHECK: triton_gpu.async_wait {num = 0 : i32} + // CHECK: triton_gpu.local_dealloc %0 : !tt.memdesc<4x128x64xf16, #shared> + // CHECK: triton_gpu.local_dealloc %1 : !tt.memdesc<4x128x64xf16, #shared> +-// CHECK: %2 = triton_gpu.convert_layout %arg0 : tensor<32x32xf32, #blocked> -> tensor<32x32xf32, #blocked1> ++// CHECK: %3 = triton_gpu.convert_layout %arg0 : tensor<32x32xf32, #blocked> -> tensor<32x32xf32, #blocked1> + #blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [0, 1]}> + #blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [1, 0]}> + #shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [0, 1]}> +diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp +--- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp ++++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp +@@ -333,17 +333,6 @@ static Value faddAccumulate(ConversionPa + return newStruct; + } + +-static bool isZero(Value v) { +- auto constantOp = v.getDefiningOp(); +- if (!constantOp) +- return false; +- if (auto denseAttr = dyn_cast(constantOp.getValueAttr())) +- return denseAttr.isSplat() && denseAttr.getSplatValue().isZero(); +- if (auto denseAttr = +- dyn_cast(constantOp.getValueAttr())) +- return denseAttr.isSplat() && denseAttr.getSplatValue().isZero(); +- return false; +-} + + static SmallVector emitWait(ConversionPatternRewriter &rewriter, + Location loc, SmallVector acc, +@@ 
-402,7 +391,7 @@ LogicalResult convertDot(const LLVMTypeC + int M = 4 * instrShape[0]; + int N = instrShape[1]; + int K = instrShape[2]; +- bool zeroAcc = isZero(c); ++ bool zeroAcc = isZeroConst(c); + auto shapePerCTATile = getShapePerCTATile(mmaEncoding); + int numRepM = ceil(dShapePerCTA[0], shapePerCTATile[0]); + int numRepN = ceil(dShapePerCTA[1], shapePerCTATile[1]); +diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp +--- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp ++++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp +@@ -924,8 +924,11 @@ struct AsyncWaitOpConversion + auto voidTy = void_ty(ctx); + ptxBuilder.launch(rewriter, loc, voidTy); + +- // Safe to remove the op since it doesn't have any return value. +- rewriter.eraseOp(op); ++ // Drop the result token. ++ Value zero = rewriter.create( ++ op.getLoc(), IntegerType::get(op.getContext(), 32), ++ rewriter.getI32IntegerAttr(0)); ++ rewriter.replaceOp(op, zero); + return success(); + } + }; diff --git a/third_party/xla/third_party/triton/temporary/series.bzl b/third_party/xla/third_party/triton/temporary/series.bzl index a929f2c4a017f0..2dfe0dd1bb695c 100644 --- a/third_party/xla/third_party/triton/temporary/series.bzl +++ b/third_party/xla/third_party/triton/temporary/series.bzl @@ -6,4 +6,5 @@ internal patch during the next triton integration process. """ temporary_patch_list = [ + "//third_party/triton/temporary:pipelining.patch", ] diff --git a/third_party/xla/xla/client/xla_builder.cc b/third_party/xla/xla/client/xla_builder.cc index b395bd81f4a347..8b6ee8630dfdf2 100644 --- a/third_party/xla/xla/client/xla_builder.cc +++ b/third_party/xla/xla/client/xla_builder.cc @@ -869,7 +869,43 @@ absl::StatusOr XlaBuilder::Build( return OkStatus(); } -XlaOp XlaBuilder::DynamicBroadcastInDim( +XlaOp XlaBuilder::MhloDynamicReshape(XlaOp operand, XlaOp output_shape, + const Shape& shape) { + return ReportErrorOrReturn([&]() -> absl::StatusOr { + TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand)); + if (operand_shape->element_type() != shape.element_type()) { + return InvalidArgument( + "Element type of operand %s and output %s must match", + ShapeUtil::HumanString(*operand_shape), + ShapeUtil::HumanString(shape)); + } + if (operand_shape->is_static() && shape.is_static() && + ShapeUtil::ElementsIn(*operand_shape) != ShapeUtil::ElementsIn(shape)) { + return InvalidArgument( + "MhloDynamicReshape has mismatched element counts: from=%d (%s) " + "to=%d (%s)", + ShapeUtil::ElementsIn(*operand_shape), + ShapeUtil::HumanString(*operand_shape), ShapeUtil::ElementsIn(shape), + ShapeUtil::HumanString(shape)); + } + TF_ASSIGN_OR_RETURN(const Shape* output_shape_shape, + GetShapePtr(output_shape)); + if (output_shape_shape->dimensions(0) != shape.rank()) { + return InvalidArgument( + "output_shape dimension size=%d (%s) and rank of shape=%d (%s) must " + "match", + output_shape_shape->dimensions(0), + ShapeUtil::HumanString(*output_shape_shape), shape.rank(), + ShapeUtil::HumanString(shape)); + } + return xla::CustomCall(operand.builder(), "mhlo.dynamic_reshape", + /*operands=*/{operand, output_shape}, + /*shape=*/shape, + /*opaque=*/""); + }); +}; + +XlaOp XlaBuilder::MhloDynamicBroadcastInDim( const XlaOp operand, const XlaOp output_dimensions, absl::Span broadcast_dimensions, const Shape& output_shape) { return ReportErrorOrReturn([&]() -> absl::StatusOr { @@ -879,7 +915,7 @@ XlaOp 
XlaBuilder::DynamicBroadcastInDim( if (!output_dimensions_shape->IsInteger()) { return InvalidArgument("output_dimensions must be an integer type %s", - output_dimensions_shape->ToString()); + ShapeUtil::HumanString(*output_dimensions_shape)); } if (output_dimensions_shape->rank() != 1) { @@ -954,8 +990,8 @@ absl::StatusOr XlaBuilder::InDimBroadcast( TF_RET_CHECK(operand_shape->is_bounded_dynamic_dimension( it - broadcast_dimensions.begin()) == shape.is_bounded_dynamic_dimension(i)) - << " i: " << i << ", shape: " << shape.ToString() - << ", operand_shape: " << operand_shape->ToString(); + << " i: " << i << ", shape: " << ShapeUtil::HumanString(shape) + << ", operand_shape: " << ShapeUtil::HumanString(*operand_shape); } else { // Non-broadcast dimensions must be static. TF_RET_CHECK(shape.is_static_dimension(i)); @@ -1084,7 +1120,7 @@ absl::StatusOr> ExtractDimensionSizesAndPadOnesToLeft( // Broadcast `scalar` to `output_shape` with all shapes static at runtime. If a // dimension of `output_shape` is dynamic, get the dimension size of the dynamic // dimension from `output` and reshape them to `tensor<1xi32>`. This is used as -// one of the inputs to DynamicBroadcastInDim. +// one of the inputs to MhloDynamicBroadcastInDim. absl::StatusOr BroadcastScalarToOutputShapeWithUnbounded( XlaBuilder* builder, XlaOp scalar, XlaOp output, const Shape& output_shape) { @@ -1100,7 +1136,7 @@ absl::StatusOr BroadcastScalarToOutputShapeWithUnbounded( /*values=*/{static_cast(output_shape.dimensions(i))}) : Reshape(GetDimensionSize(output, i), {1}); } - return DynamicBroadcastInDim( + return MhloDynamicBroadcastInDim( scalar, /*output_dimensions=*/ConcatInDim(builder, output_sizes, 0), {}, output_shape); } @@ -1117,8 +1153,8 @@ absl::StatusOr DegenerateBroadcastWithUnbounded( std::iota(broadcast_dimensions.begin(), broadcast_dimensions.end(), output_shape.rank() - operand_shape->rank()); - return DynamicBroadcastInDim(operand, output_dimensions, broadcast_dimensions, - output_shape); + return MhloDynamicBroadcastInDim(operand, output_dimensions, + broadcast_dimensions, output_shape); } // Helper struct to store the result of `BroadcastToOutputShapeWithUnbounded`. 
@@ -1387,7 +1423,7 @@ XlaOp XlaBuilder::Iota(const Shape& shape, int64_t iota_dimension) { if (!shape.is_static()) { return InvalidArgument( "The output of iota must not have dynamic dimensions: %s", - shape.ToString()); + ShapeUtil::HumanString(shape)); } HloInstructionProto instr; *instr.mutable_shape() = shape.ToProto(); @@ -1479,7 +1515,7 @@ XlaOp XlaBuilder::BroadcastInDim( operand_shape->element_type(), out_dim_size)); TF_RET_CHECK(!output_shape.is_unbounded_dynamic()) << "BroadcastInDim output must shape be static or bounded dynamic " - << output_shape.ToString(); + << ShapeUtil::HumanString(output_shape); int64_t broadcast_rank = broadcast_dimensions.size(); if (operand_shape->rank() != broadcast_rank) { return InvalidArgument( @@ -3164,13 +3200,14 @@ XlaOp XlaBuilder::AllReduceImpl(XlaOp operand, if (layout) { if (!LayoutUtil::HasLayout(*layout)) { return InvalidArgument("shape_with_layout must have the layout set: %s", - layout->ToString()); + ShapeUtil::HumanString(*layout)); } if (!ShapeUtil::Compatible(*layout, *operand_shape)) { return InvalidArgument( "Provided shape_with_layout must be compatible with the " "operand shape: %s vs %s", - layout->ToString(), operand_shape->ToString()); + ShapeUtil::HumanString(*layout), + ShapeUtil::HumanString(*operand_shape)); } instr.set_constrain_layout(true); if (operand_shape->IsTuple() && !inferred_shape.IsTuple()) { @@ -3812,7 +3849,8 @@ XlaOp XlaBuilder::AllToAllTuple( return InvalidArgument( "Provided layout must be compatible with the operands' shape. " "The layout is %s, but operand %d has shape %s.", - layout->ToString(), i, shape.tuple_shapes(i).ToString()); + layout->ToString(), i, + ShapeUtil::HumanString(shape.tuple_shapes(i))); } *(shape.mutable_tuple_shapes(i)->mutable_layout()) = *layout; } @@ -4726,10 +4764,16 @@ XlaOp BroadcastInDim(const XlaOp operand, broadcast_dimensions); } -XlaOp DynamicBroadcastInDim(const XlaOp operand, const XlaOp output_dimensions, - absl::Span broadcast_dimensions, - const Shape& output_shape) { - return operand.builder()->DynamicBroadcastInDim( +XlaOp MhloDynamicReshape(const XlaOp operand, const XlaOp output_shape, + const Shape& shape) { + return operand.builder()->MhloDynamicReshape(operand, output_shape, shape); +} + +XlaOp MhloDynamicBroadcastInDim(const XlaOp operand, + const XlaOp output_dimensions, + absl::Span broadcast_dimensions, + const Shape& output_shape) { + return operand.builder()->MhloDynamicBroadcastInDim( operand, output_dimensions, broadcast_dimensions, output_shape); } diff --git a/third_party/xla/xla/client/xla_builder.h b/third_party/xla/xla/client/xla_builder.h index 585eafc91d72ca..24ca2db618664e 100644 --- a/third_party/xla/xla/client/xla_builder.h +++ b/third_party/xla/xla/client/xla_builder.h @@ -524,9 +524,10 @@ class XlaBuilder { // op from the XlaBuilder. This is only intended for export to MHLO or // StableHLO, and cannot be compiled. Only static output_dimensions are // allowed, and broadcast_dimensions is verified. 
- XlaOp DynamicBroadcastInDim(XlaOp operand, XlaOp output_dimensions, - absl::Span broadcast_dimensions, - const Shape& output_shape); + XlaOp MhloDynamicBroadcastInDim( + XlaOp operand, XlaOp output_dimensions, + absl::Span broadcast_dimensions, + const Shape& output_shape); XlaOp Pad(XlaOp operand, XlaOp padding_value, const PaddingConfig& padding_config); @@ -551,6 +552,9 @@ class XlaBuilder { absl::Span new_size_bounds, const std::vector& dims_are_dynamic); + XlaOp MhloDynamicReshape(XlaOp operand, XlaOp output_shape, + const Shape& shape); + XlaOp Collapse(XlaOp operand, absl::Span dimensions); XlaOp Slice(XlaOp operand, absl::Span start_indices, @@ -1212,7 +1216,7 @@ class XlaBuilder { absl::Span out_dim_size, absl::Span broadcast_dimensions); - friend XlaOp DynamicBroadcastInDim( + friend XlaOp MhloDynamicBroadcastInDim( XlaOp operand, XlaOp output_dimensions, absl::Span broadcast_dimensions, const Shape& output_shape); @@ -1236,6 +1240,9 @@ class XlaBuilder { absl::Span new_size_bounds, const std::vector& dims_are_dynamic); + friend XlaOp MhloDynamicReshape(XlaOp operand, XlaOp output_shape, + const Shape& shape); + friend XlaOp ReshapeWithInferredDimension(XlaOp operand, absl::Span new_sizes, int64_t inferred_dimension); @@ -1918,9 +1925,9 @@ XlaOp BroadcastInDim(XlaOp operand, absl::Span out_dim_size, // StableHLO, and cannot be compiled. See // https://www.tensorflow.org/mlir/hlo_ops#mhlodynamic_broadcast_in_dim_mhlodynamicbroadcastindimop. // for the op semantics. -XlaOp DynamicBroadcastInDim(XlaOp operand, XlaOp output_dimensions, - absl::Span broadcast_dimensions, - const Shape& output_shape); +XlaOp MhloDynamicBroadcastInDim(XlaOp operand, XlaOp output_dimensions, + absl::Span broadcast_dimensions, + const Shape& output_shape); // Copies the input operand to the output. This operation is for internal // purpose and is only used by the compiler for optimization purposes or to @@ -1966,6 +1973,11 @@ XlaOp DynamicReshape(XlaOp operand, absl::Span dim_sizes, absl::Span new_size_bounds, const std::vector& dims_are_dynamic); +// This is an experimental API for creating the mhlo.dynamic_reshape op from the +// XlaBuilder. This is only intended for export to MHLO or StableHLO, and cannot +// be compiled. +XlaOp MhloDynamicReshape(XlaOp operand, XlaOp output_shape, const Shape& shape); + // Enqueues an operation onto the computation that collapses the operand, // from first to last dimension (C order), then reshapes it to the given // dimension sizes. Conceptually, this is a limited form of "shape casting". 
diff --git a/third_party/xla/xla/client/xla_builder_test.cc b/third_party/xla/xla/client/xla_builder_test.cc index c97bd9f21b3674..c4d43bc1dc1c8e 100644 --- a/third_party/xla/xla/client/xla_builder_test.cc +++ b/third_party/xla/xla/client/xla_builder_test.cc @@ -1700,13 +1700,13 @@ TEST(XlaBuilderTest, TopKDimensions) { // Experimental Test //============================================================================// -TEST(XlaBuilderTest, DynamicBroadcastInDimExportSuccess) { +TEST(XlaBuilderTest, MhloDynamicBroadcastInDimExportSuccess) { XlaBuilder b(TestName()); TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[1, ?]")); TF_ASSERT_OK_AND_ASSIGN(const Shape output_dimensions, ParseShape("s32[3]")); TF_ASSERT_OK_AND_ASSIGN(const Shape output_shape, ParseShape("f32[1, 2, 3]")); TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[1, 2, 3]")); - DynamicBroadcastInDim( + MhloDynamicBroadcastInDim( Parameter(&b, 0, operand, "operand"), Parameter(&b, 1, output_dimensions, "output_dimensions"), /*broadcast_dimensions=*/{1, 2}, output_shape); @@ -1717,13 +1717,14 @@ TEST(XlaBuilderTest, DynamicBroadcastInDimExportSuccess) { GmockMatch(m::Op().WithShapeEqualTo(&expected))); } -TEST(XlaBuilderTest, DynamicBroadcastInDimNonBroadcastDimSizeGreaterThanOne) { +TEST(XlaBuilderTest, + MhloDynamicBroadcastInDimNonBroadcastDimSizeGreaterThanOne) { XlaBuilder b(TestName()); TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[2, ?]")); TF_ASSERT_OK_AND_ASSIGN(const Shape output_dimensions, ParseShape("s32[3]")); TF_ASSERT_OK_AND_ASSIGN(const Shape output_shape, ParseShape("f32[2, 2, 3]")); TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[2, 2, 3]")); - DynamicBroadcastInDim( + MhloDynamicBroadcastInDim( Parameter(&b, 0, operand, "operand"), Parameter(&b, 1, output_dimensions, "output_dimensions"), /*broadcast_dimensions=*/{1, 2}, output_shape); @@ -1734,13 +1735,13 @@ TEST(XlaBuilderTest, DynamicBroadcastInDimNonBroadcastDimSizeGreaterThanOne) { GmockMatch(m::Op().WithShapeEqualTo(&expected))); } -TEST(XlaBuilderTest, DynamicBroadcastInDimDynamicResultSize) { +TEST(XlaBuilderTest, MhloDynamicBroadcastInDimDynamicResultSize) { XlaBuilder b(TestName()); TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[1, ?]")); TF_ASSERT_OK_AND_ASSIGN(const Shape output_dimensions, ParseShape("s32[3]")); TF_ASSERT_OK_AND_ASSIGN(const Shape output_shape, ParseShape("f32[1, 2, ?]")); TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[1, 2, ?]")); - DynamicBroadcastInDim( + MhloDynamicBroadcastInDim( Parameter(&b, 0, operand, "operand"), Parameter(&b, 1, output_dimensions, "output_dimensions"), /*broadcast_dimensions=*/{1, 2}, output_shape); @@ -1751,12 +1752,13 @@ TEST(XlaBuilderTest, DynamicBroadcastInDimDynamicResultSize) { GmockMatch(m::Op().WithShapeEqualTo(&expected))); } -TEST(XlaBuilderTest, DynamicBroadcastInDimInvalidOutputDimensionsElementType) { +TEST(XlaBuilderTest, + MhloDynamicBroadcastInDimInvalidOutputDimensionsElementType) { XlaBuilder b(TestName()); TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[2, ?]")); TF_ASSERT_OK_AND_ASSIGN(const Shape output_dimensions, ParseShape("f32[3]")); TF_ASSERT_OK_AND_ASSIGN(const Shape output_shape, ParseShape("f32[2, 3, 3]")); - DynamicBroadcastInDim( + MhloDynamicBroadcastInDim( Parameter(&b, 0, operand, "operand"), Parameter(&b, 1, output_dimensions, "output_dimensions"), /*broadcast_dimensions=*/{1, 2}, output_shape); @@ -1766,13 +1768,13 @@ TEST(XlaBuilderTest, 
DynamicBroadcastInDimInvalidOutputDimensionsElementType) { HasSubstr("output_dimensions must be an integer type f32[3]"))); } -TEST(XlaBuilderTest, DynamicBroadcastInDimInvalidOutputDimensionsRank) { +TEST(XlaBuilderTest, MhloDynamicBroadcastInDimInvalidOutputDimensionsRank) { XlaBuilder b(TestName()); TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[2, ?]")); TF_ASSERT_OK_AND_ASSIGN(const Shape output_dimensions, ParseShape("s32[2, ?]")); TF_ASSERT_OK_AND_ASSIGN(const Shape output_shape, ParseShape("f32[2, 3, 3]")); - DynamicBroadcastInDim( + MhloDynamicBroadcastInDim( Parameter(&b, 0, operand, "operand"), Parameter(&b, 1, output_dimensions, "output_dimensions"), /*broadcast_dimensions=*/{1, 2}, output_shape); @@ -1782,12 +1784,12 @@ TEST(XlaBuilderTest, DynamicBroadcastInDimInvalidOutputDimensionsRank) { HasSubstr("output_dimensions must be rank 1 but got rank 2"))); } -TEST(XlaBuilderTest, DynamicBroadcastInDimIncompatibleBroadcastSize) { +TEST(XlaBuilderTest, MhloDynamicBroadcastInDimIncompatibleBroadcastSize) { XlaBuilder b(TestName()); TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[2, ?]")); TF_ASSERT_OK_AND_ASSIGN(const Shape output_dimensions, ParseShape("s32[3]")); TF_ASSERT_OK_AND_ASSIGN(const Shape output_shape, ParseShape("f32[2, 3, 3]")); - DynamicBroadcastInDim( + MhloDynamicBroadcastInDim( Parameter(&b, 0, operand, "operand"), Parameter(&b, 1, output_dimensions, "output_dimensions"), /*broadcast_dimensions=*/{1, 2}, output_shape); @@ -1797,6 +1799,66 @@ TEST(XlaBuilderTest, DynamicBroadcastInDimIncompatibleBroadcastSize) { "with size of result dimension 1 (3)"))); } +TEST(XlaBuilderTest, MhloDynamicReshapeExportSuccess) { + XlaBuilder b(TestName()); + TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[?, 15]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape output_shape, ParseShape("s32[2]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape shape, ParseShape("f32[?, 15]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[?, 15]")); + MhloDynamicReshape( + /*operand=*/Parameter(&b, 0, operand, "operand"), + /*output_shape=*/Parameter(&b, 1, output_shape, "output_shape"), + /*shape=*/shape); + TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b)); + EXPECT_THAT(module->ToString(), HasSubstr("mhlo.dynamic_reshape")); + EXPECT_THAT(GetRoot(*module), + GmockMatch(m::Op().WithShapeEqualTo(&expected))); +} + +TEST(XlaBuilderTest, MhloDynamicReshapeIncompatibleElementType) { + XlaBuilder b(TestName()); + TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[?, 15]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape output_shape, ParseShape("s32[2]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape shape, ParseShape("s32[?, 15]")); + MhloDynamicReshape( + /*operand=*/Parameter(&b, 0, operand, "operand"), + /*output_shape=*/Parameter(&b, 1, output_shape, "output_shape"), + /*shape=*/shape); + EXPECT_THAT(BuildHloModule(b), + StatusIs(_, HasSubstr("Element type of operand f32[?,15] and " + "output s32[?,15] must match"))); +} + +TEST(XlaBuilderTest, MhloDynamicReshapeElementCountMismatch) { + XlaBuilder b(TestName()); + TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[3, 15]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape output_shape, ParseShape("s32[2]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape shape, ParseShape("f32[4, 15]")); + MhloDynamicReshape( + /*operand=*/Parameter(&b, 0, operand, "operand"), + /*output_shape=*/Parameter(&b, 1, output_shape, "output_shape"), + /*shape=*/shape); + EXPECT_THAT(BuildHloModule(b), + StatusIs(_, 
HasSubstr("MhloDynamicReshape has mismatched " + "element counts: from=45 (f32[3,15]) " + "to=60 (f32[4,15])"))); +} + +TEST(XlaBuilderTest, MhloDynamicReshapeRankMismatch) { + XlaBuilder b(TestName()); + TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[?, 15]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape output_shape, ParseShape("s32[3]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape shape, ParseShape("f32[?, 15]")); + MhloDynamicReshape( + /*operand=*/Parameter(&b, 0, operand, "operand"), + /*output_shape=*/Parameter(&b, 1, output_shape, "output_shape"), + /*shape=*/shape); + EXPECT_THAT( + BuildHloModule(b), + StatusIs(_, HasSubstr("output_shape dimension size=3 (s32[3]) and rank " + "of shape=2 (f32[?,15]) must match"))); +} + //============================================================================// // Unbounded Dynamism Test //============================================================================//