[go: up, home]

Skip to content

Commit

Permalink
Revert pred type conversion rewrite.
Browse files Browse the repository at this point in the history
This breaks int4_test.

Reverts 4d538ab

PiperOrigin-RevId: 646795348
  • Loading branch information
jreiffers authored and tensorflower-gardener committed Jun 26, 2024
1 parent b1709b4 commit 9a84833
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 152 deletions.
101 changes: 0 additions & 101 deletions third_party/xla/xla/service/gpu/fusions/mlir/simplify_arith.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,8 @@ limitations under the License.

#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
#include "mlir/Dialect/GPU/IR/GPUDialect.h" // from @llvm-project
#include "mlir/Dialect/Utils/StaticValueUtils.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/IR/Value.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project
Expand Down Expand Up @@ -127,101 +125,6 @@ struct RewriteMinSi : mlir::OpRewritePattern<mlir::arith::MinSIOp> {
}
};

// Returns the `arith.trunci` that ultimately narrows `value`, or a null op if
// there is none.
//
// `arith.extui` ops are looked through, so for a zero-extended value the
// result is the trunc feeding the extension. When `value` is itself produced
// by a trunc whose operand traces back (through extui/trunci only) to an
// earlier trunc that is at most as wide, that earlier trunc is returned
// instead of the outer one.
//
// NOTE: in both of those cases the returned op's result type is narrower than
// `value`'s type. Callers must compare the two types before treating the
// trunc's result as interchangeable with `value`.
mlir::arith::TruncIOp FindDefiningTrunc(mlir::Value value) {
  // Skip over zero-extensions and keep searching from their operand.
  if (auto zext = value.getDefiningOp<mlir::arith::ExtUIOp>()) {
    return FindDefiningTrunc(zext.getOperand());
  }

  auto outer = value.getDefiningOp<mlir::arith::TruncIOp>();
  if (!outer) {
    // Neither an extui nor a trunci: there is no defining truncation (null).
    return outer;
  }

  // Prefer an earlier trunc in the chain if it is no wider than `outer`.
  mlir::arith::TruncIOp inner = FindDefiningTrunc(outer.getOperand());
  if (!inner) {
    return outer;
  }
  const unsigned inner_bits = inner.getType().getIntOrFloatBitWidth();
  const unsigned outer_bits = outer.getType().getIntOrFloatBitWidth();
  return inner_bits <= outer_bits ? inner : outer;
}

// Rewrites trunc-bitwise to bitwise-trunc.
//
// For pred reductions, we generate code like this:
//
//   %1 = arith.trunci %0 : i32 to i1
//   %2 = arith.ori %1, %x
//   %3 = arith.extui %2 : i1 to i32
//   %4 = gpu.shuffle %3
//
// By swapping the trunc with the or, we get a trunc-ext-shuffle sequence, which
// can be rewritten to shuffle-trunc-ext. If there is another copy of the
// pattern afterwards, we can push the truncs/exts further down.
//
// Concretely, with T = trunci to the op's type and Z = extui to the wide type:
//   op(T(a), T(b)) -> T(op(a, b))
//   op(T(a), b)    -> T(op(a, Z(b)))   (and symmetrically for the lhs)
// This only holds when the trunc produces exactly the type `op` operates on.
// `FindDefiningTrunc` looks through `extui` (and through wider-to-narrower
// trunc chains), so it can return a trunc whose narrow result was
// zero-extended before reaching `op`, e.g.
//   %lo = arith.extui (arith.trunci %x : i8 to i4) : i4 to i8
//   %r  = arith.ori %lo, %y : i8
// Rewriting that would drop the zeroed high bits of %lo. It would also build
// invalid same-width `extui`/`trunci` ops (i8 to i8). Such matches are
// rejected.
template <typename Op>
struct RewriteTruncBitExt : mlir::OpRewritePattern<Op> {
  using mlir::OpRewritePattern<Op>::OpRewritePattern;

  mlir::LogicalResult matchAndRewrite(
      Op op, mlir::PatternRewriter& rewriter) const override {
    mlir::Value lhs = op.getLhs();
    mlir::Value rhs = op.getRhs();

    mlir::arith::TruncIOp trunci_lhs = FindDefiningTrunc(lhs);
    mlir::arith::TruncIOp trunci_rhs = FindDefiningTrunc(rhs);
    if (!trunci_lhs && !trunci_rhs) {
      return rewriter.notifyMatchFailure(op, "no truncation");
    }

    auto narrow_type = trunci_lhs ? trunci_lhs.getType() : trunci_rhs.getType();
    auto wide_type =
        (trunci_lhs ? trunci_lhs : trunci_rhs).getOperand().getType();

    // The trunc's result must be the value `op` actually consumes. Otherwise
    // an extui (or a further narrowing) sits in between, and the rewrite
    // would be both type-incorrect and semantically wrong.
    if (narrow_type != op.getType()) {
      return rewriter.notifyMatchFailure(
          op, "truncation result type differs from op type");
    }
    // When both sides are truncated, they must truncate from the same wide
    // type to the same narrow type so the new op can combine their operands.
    if (trunci_rhs && (trunci_rhs.getOperand().getType() != wide_type ||
                       trunci_rhs.getType() != narrow_type)) {
      return rewriter.notifyMatchFailure(op, "mismatched truncation types");
    }

    // Non-truncated operands are zero-extended to the wide type. The high
    // bits are irrelevant because the result is truncated back afterwards.
    mlir::Value new_lhs = trunci_lhs ? trunci_lhs.getOperand()
                                     : rewriter.create<mlir::arith::ExtUIOp>(
                                           op.getLoc(), wide_type, lhs);
    mlir::Value new_rhs = trunci_rhs ? trunci_rhs.getOperand()
                                     : rewriter.create<mlir::arith::ExtUIOp>(
                                           op.getLoc(), wide_type, rhs);
    mlir::Value new_op = rewriter.create<Op>(op.getLoc(), new_lhs, new_rhs);
    rewriter.replaceOpWithNewOp<mlir::arith::TruncIOp>(op, op.getType(),
                                                       new_op);

    return mlir::success();
  }
};

// Rewrites trunc-ext-shuffle to shuffle-trunc-ext. This pattern is designed to
// work together with RewriteTruncBitExt to optimize pred reductions.
//
// Matches `gpu.shuffle(extui(trunci(%w)))` where the extui result type equals
// the type of %w. The shuffle is changed to operate on %w directly, and the
// trunci/extui pair is re-created after the shuffle on its first result. All
// other users of that first result are redirected to the re-created extui.
struct RewriteTruncExtShuffle
    : public mlir::OpRewritePattern<mlir::gpu::ShuffleOp> {
  using OpRewritePattern::OpRewritePattern;

  mlir::LogicalResult matchAndRewrite(
      mlir::gpu::ShuffleOp op, mlir::PatternRewriter& rewriter) const override {
    auto widen = op.getOperand(0).getDefiningOp<mlir::arith::ExtUIOp>();
    if (!widen) {
      return rewriter.notifyMatchFailure(op, "no ext");
    }

    auto narrow = widen.getOperand().getDefiningOp<mlir::arith::TruncIOp>();
    const bool round_trips =
        narrow && narrow.getOperand().getType() == widen.getType();
    if (!round_trips) {
      return rewriter.notifyMatchFailure(op, "no trunc or type mismatch");
    }

    auto shuffled = op.getResult(0);
    auto loc = op.getLoc();

    // Re-create trunc + ext right after the shuffle, applied to its result.
    rewriter.setInsertionPointAfter(op);
    auto post_trunc =
        rewriter.create<mlir::arith::TruncIOp>(loc, narrow.getType(), shuffled);
    auto post_ext = rewriter.create<mlir::arith::ExtUIOp>(
        loc, widen.getType(), post_trunc.getResult());

    // Shuffle the un-truncated value instead, then reroute the users (except
    // the new trunc, which must keep reading the shuffle result).
    rewriter.modifyOpInPlace(
        op, [&]() { op->setOperand(0, narrow.getOperand()); });
    rewriter.replaceAllUsesExcept(shuffled, post_ext, post_trunc);
    return mlir::success();
  }
};

void AnnotateRanges(mlir::func::FuncOp func) {
func->walk([](mlir::Operation* op) {
if (op->getNumResults() != 1) {
Expand Down Expand Up @@ -273,10 +176,6 @@ class SimplifyArithPass
mlir::RewritePatternSet patterns(&getContext());
AnnotateRanges(getOperation());
patterns.add<RewriteCmpI, RewriteMaxSi, RewriteMinSi>(&getContext());
patterns
.add<RewriteTruncBitExt<mlir::arith::OrIOp>,
RewriteTruncBitExt<mlir::arith::AndIOp>, RewriteTruncExtShuffle>(
&getContext());
if (mlir::failed(mlir::applyPatternsAndFoldGreedily(getOperation(),
std::move(patterns)))) {
signalPassFailure();
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: mlir_fusions_opt %s -split-input-file -xla-gpu-simplify-arith -cse -canonicalize | FileCheck %s
// RUN: mlir_fusions_opt %s -split-input-file -xla-gpu-simplify-arith -canonicalize | FileCheck %s

module {
func.func @unknown(%arg0: index {xla.range = [0 : index, 42 : index]}) -> i1 {
Expand All @@ -13,6 +13,7 @@ module {

// -----


module {
func.func @true(%arg0: index {xla.range = [12 : index, 42 : index]}) -> i1 {
%c5 = arith.constant 5 : index
Expand Down Expand Up @@ -153,53 +154,3 @@ module {
// CHECK-LABEL: @minsi_add
// CHECK-SAME: (%[[ARG0:.*]]: index {{.*}}, %[[ARG1:.*]]: index {{.*}})
// CHECK-NEXT: return %[[ARG1]]

// -----

module {
// i1 OR-reduction built from five `gpu.shuffle down` steps (offsets 16, 8, 4,
// 2, 1; width 32). Before each shuffle the running i1 value is zero-extended
// to i32 (arith.extui). Each shuffle result is truncated back to i1
// (arith.trunci) before being OR-ed in. This extui/trunci interleaving is the
// input shape targeted by the -xla-gpu-simplify-arith trunc/ext rewrites.
func.func @pred_reduce(%in: i1) -> i1 {
%c1_i32 = arith.constant 1 : i32
%c2_i32 = arith.constant 2 : i32
%c4_i32 = arith.constant 4 : i32
%c8_i32 = arith.constant 8 : i32
%c16_i32 = arith.constant 16 : i32
%c32_i32 = arith.constant 32 : i32
%0 = arith.extui %in : i1 to i32
%shuffleResult, %valid = gpu.shuffle down %0, %c16_i32, %c32_i32 : i32
%1 = arith.trunci %shuffleResult : i32 to i1
%2 = arith.ori %in, %1 : i1
%3 = arith.extui %2 : i1 to i32
%shuffleResult_0, %valid_1 = gpu.shuffle down %3, %c8_i32, %c32_i32 : i32
%4 = arith.trunci %shuffleResult_0 : i32 to i1
%5 = arith.ori %2, %4 : i1
%6 = arith.extui %5 : i1 to i32
%shuffleResult_2, %valid_3 = gpu.shuffle down %6, %c4_i32, %c32_i32 : i32
%7 = arith.trunci %shuffleResult_2 : i32 to i1
%8 = arith.ori %5, %7 : i1
%9 = arith.extui %8 : i1 to i32
%shuffleResult_4, %valid_5 = gpu.shuffle down %9, %c2_i32, %c32_i32 : i32
%10 = arith.trunci %shuffleResult_4 : i32 to i1
%11 = arith.ori %8, %10 : i1
%12 = arith.extui %11 : i1 to i32
%shuffleResult_6, %valid_7 = gpu.shuffle down %12, %c1_i32, %c32_i32 : i32
%13 = arith.trunci %shuffleResult_6 : i32 to i1
%14 = arith.ori %11, %13 : i1
return %14 : i1
}
}

// CHECK-LABEL: @pred_reduce
// CHECK-SAME: (%[[IN:.*]]: i1)
// CHECK: %[[IN_EXT:.*]] = arith.extui %[[IN]]
// CHECK-NEXT: %[[SHUFFLE0:.*]], {{.*}} = gpu.shuffle down %[[IN_EXT]]
// CHECK-NEXT: %[[OR0:.*]] = arith.ori %[[IN_EXT]], %[[SHUFFLE0]]
// CHECK-NEXT: %[[SHUFFLE1:.*]], {{.*}} = gpu.shuffle down %[[OR0]]
// CHECK-NEXT: %[[OR1:.*]] = arith.ori %[[OR0]], %[[SHUFFLE1]]
// CHECK-NEXT: %[[SHUFFLE2:.*]], {{.*}} = gpu.shuffle down %[[OR1]]
// CHECK-NEXT: %[[OR2:.*]] = arith.ori %[[OR1]], %[[SHUFFLE2]]
// CHECK-NEXT: %[[SHUFFLE3:.*]], {{.*}} = gpu.shuffle down %[[OR2]]
// CHECK-NEXT: %[[OR3:.*]] = arith.ori %[[OR2]], %[[SHUFFLE3]]
// CHECK-NEXT: %[[SHUFFLE4:.*]], {{.*}} = gpu.shuffle down %[[OR3]]
// CHECK-NEXT: %[[OR4:.*]] = arith.ori %[[OR3]], %[[SHUFFLE4]]
// CHECK-NEXT: %[[RET:.*]] = arith.trunci %[[OR4]]
// CHECK-NEXT: return %[[RET]]

0 comments on commit 9a84833

Please sign in to comment.