[mlir][mhlo][sparse] use OSS general utility to test for "sparse" op
PiperOrigin-RevId: 513905390
aartbik authored and tensorflower-gardener committed Mar 6, 2023
1 parent 02c7f3f commit d7c03de
Showing 2 changed files with 3 additions and 24 deletions.
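For reference, the "OSS general utility" named in the title is the shared helper in MLIR's sparse_tensor dialect that the two file-local copies removed below duplicated. A minimal sketch of its assumed upstream form follows; the header location, the inline definitions, and the companion helpers are assumptions inferred from the removed code, not part of this commit:

// Assumed upstream declarations (sketch only), presumably in
// mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h.
namespace mlir {
namespace sparse_tensor {

/// Returns true iff the operation has any sparse operand.
inline bool hasAnySparseOperand(Operation *op) {
  return llvm::any_of(op->getOperands().getTypes(), [](Type t) {
    return getSparseTensorEncoding(t) != nullptr;
  });
}

/// Returns true iff the operation has any sparse result.
inline bool hasAnySparseResult(Operation *op) {
  return llvm::any_of(op->getResults().getTypes(), [](Type t) {
    return getSparseTensorEncoding(t) != nullptr;
  });
}

/// Returns true iff the operation has any sparse operand or result.
inline bool hasAnySparseOperandOrResult(Operation *op) {
  return hasAnySparseOperand(op) || hasAnySparseResult(op);
}

} // namespace sparse_tensor
} // namespace mlir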
File 1 of 2:

@@ -42,16 +42,6 @@ namespace mhlo {
 
 namespace {
 
-bool hasAnySparseOperandOrResult(Operation *op) {
-  bool anySparseIn = llvm::any_of(op->getOperands().getTypes(), [](Type t) {
-    return sparse_tensor::getSparseTensorEncoding(t) != nullptr;
-  });
-  bool anySparseOut = llvm::any_of(op->getResults().getTypes(), [](Type t) {
-    return sparse_tensor::getSparseTensorEncoding(t) != nullptr;
-  });
-  return anySparseIn || anySparseOut;
-}
-
 Value transposeReshape(Value arg, Location loc,
                        llvm::ArrayRef<int64_t> leftDims,
                        llvm::ArrayRef<int64_t> rightDims,
@@ -262,7 +252,7 @@ struct GeneralDotConvert : public OpRewritePattern<DotGeneralOp> {
     // For any sparse situation, don't use any of the following rules, since
    // transposing and reshaping is not without cost. Instead, rely on the
    // default linalg lowering that follows later in the pipeline.
-    if (hasAnySparseOperandOrResult(op)) return failure();
+    if (sparse_tensor::hasAnySparseOperandOrResult(op)) return failure();
 
     // Compute the, possibly, transposed-reshaped operands.
     lhs = llvm::cast<mlir::TypedValue<mlir::TensorType>>(processDotArg(
File 2 of 2:

@@ -34,17 +34,6 @@ namespace mhlo {
 
 namespace {
 
-// Whether the operation takes sparse input or produces sparse output.
-bool hasAnySparseOperandOrResult(Operation *op) {
-  bool anySparseIn = llvm::any_of(op->getOperands().getTypes(), [](Type t) {
-    return sparse_tensor::getSparseTensorEncoding(t) != nullptr;
-  });
-  bool anySparseOut = llvm::any_of(op->getResults().getTypes(), [](Type t) {
-    return sparse_tensor::getSparseTensorEncoding(t) != nullptr;
-  });
-  return anySparseIn || anySparseOut;
-}
-
 /// Approves subsuming sparse types into operation.
 // TODO(b/231360416): replace this list with "supports sparsity" trait?
 bool canFuseWithSparseConvert(Operation *op) {
@@ -98,7 +87,7 @@ struct SparseElementWiseConvertConverter
 
   LogicalResult matchAndRewrite(mhlo::ConvertOp op,
                                 PatternRewriter &rewriter) const override {
-    if (hasAnySparseOperandOrResult(op)) {
+    if (sparse_tensor::hasAnySparseOperandOrResult(op)) {
       // Uses sparse_tensor::ConvertOp to do element-wise value conversion.
       rewriter.replaceOpWithNewOp<sparse_tensor::ConvertOp>(
           op, op.getResult().getType(), op.getOperand());
@@ -118,7 +107,7 @@ struct SparseConcatenateConverter
   LogicalResult matchAndRewrite(mhlo::ConcatenateOp op,
                                 PatternRewriter &rewriter) const override {
     auto resultType = op.getResult().getType();
-    if (hasAnySparseOperandOrResult(op)) {
+    if (sparse_tensor::hasAnySparseOperandOrResult(op)) {
       // If there is any sparse input, lower to sparse_tensor.concatenate
       // directly.
       rewriter.replaceOpWithNewOp<sparse_tensor::ConcatenateOp>(
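As a usage sketch, any downstream pattern can now guard against sparse inputs or outputs with the shared helper instead of a file-local copy. Everything here except sparse_tensor::hasAnySparseOperandOrResult is hypothetical:

// Hypothetical pattern (MyOp and the rewrite body are illustrative only).
struct DenseOnlyRewrite : public OpRewritePattern<MyOp> {
  using OpRewritePattern<MyOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(MyOp op,
                                PatternRewriter &rewriter) const override {
    // Bail out on sparse cases and leave them to the default linalg
    // lowering later in the pipeline, mirroring GeneralDotConvert above.
    if (sparse_tensor::hasAnySparseOperandOrResult(op)) return failure();
    // ... dense-only rewrite logic would go here ...
    return success();
  }
};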
