From d7c03debf02c66ae4375080fb2c2400ce6284844 Mon Sep 17 00:00:00 2001
From: Aart Bik
Date: Fri, 3 Mar 2023 13:15:46 -0800
Subject: [PATCH] [mlir][mhlo][sparse] use OSS general utility to test for "sparse" op

PiperOrigin-RevId: 513905390
---
 .../lower_general_dot/lower_general_dot.cc | 12 +-----------
 .../sparse_rewriting/sparse_rewriting.cc   | 15 ++-------------
 2 files changed, 3 insertions(+), 24 deletions(-)

diff --git a/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/lower_general_dot/lower_general_dot.cc b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/lower_general_dot/lower_general_dot.cc
index fcba9fd6b75c0e..db766421796e87 100644
--- a/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/lower_general_dot/lower_general_dot.cc
+++ b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/lower_general_dot/lower_general_dot.cc
@@ -42,16 +42,6 @@ namespace mhlo {
 
 namespace {
 
-bool hasAnySparseOperandOrResult(Operation *op) {
-  bool anySparseIn = llvm::any_of(op->getOperands().getTypes(), [](Type t) {
-    return sparse_tensor::getSparseTensorEncoding(t) != nullptr;
-  });
-  bool anySparseOut = llvm::any_of(op->getResults().getTypes(), [](Type t) {
-    return sparse_tensor::getSparseTensorEncoding(t) != nullptr;
-  });
-  return anySparseIn || anySparseOut;
-}
-
 Value transposeReshape(Value arg, Location loc,
                        llvm::ArrayRef<int64_t> leftDims,
                        llvm::ArrayRef<int64_t> rightDims,
@@ -262,7 +252,7 @@ struct GeneralDotConvert : public OpRewritePattern<DotGeneralOp> {
     // For any sparse situation, don't use any of the following rules, since
     // transposing and reshaping is not without cost. Instead, rely on the
     // default linalg lowering that follows later in the pipeline.
-    if (hasAnySparseOperandOrResult(op)) return failure();
+    if (sparse_tensor::hasAnySparseOperandOrResult(op)) return failure();
 
     // Compute the, possibly, transposed-reshaped operands.
     lhs = llvm::cast<TypedValue<RankedTensorType>>(processDotArg(
diff --git a/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/sparse_rewriting/sparse_rewriting.cc b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/sparse_rewriting/sparse_rewriting.cc
index 5a374600911959..ccf97ab4a684df 100644
--- a/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/sparse_rewriting/sparse_rewriting.cc
+++ b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/sparse_rewriting/sparse_rewriting.cc
@@ -34,17 +34,6 @@ namespace mhlo {
 
 namespace {
 
-// Whether the operation takes sparse input or produces sparse output.
-bool hasAnySparseOperandOrResult(Operation *op) {
-  bool anySparseIn = llvm::any_of(op->getOperands().getTypes(), [](Type t) {
-    return sparse_tensor::getSparseTensorEncoding(t) != nullptr;
-  });
-  bool anySparseOut = llvm::any_of(op->getResults().getTypes(), [](Type t) {
-    return sparse_tensor::getSparseTensorEncoding(t) != nullptr;
-  });
-  return anySparseIn || anySparseOut;
-}
-
 /// Approves subsuming sparse types into operation.
 // TODO(b/231360416): replace this list with "supports sparsity" trait?
 bool canFuseWithSparseConvert(Operation *op) {
@@ -98,7 +87,7 @@ struct SparseElementWiseConvertConverter
 
   LogicalResult matchAndRewrite(mhlo::ConvertOp op,
                                 PatternRewriter &rewriter) const override {
-    if (hasAnySparseOperandOrResult(op)) {
+    if (sparse_tensor::hasAnySparseOperandOrResult(op)) {
       // Uses sparse_tensor::ConvertOp to do element-wise value conversion.
       rewriter.replaceOpWithNewOp<sparse_tensor::ConvertOp>(
          op, op.getResult().getType(), op.getOperand());
@@ -118,7 +107,7 @@ struct SparseConcatenateConverter
   LogicalResult matchAndRewrite(mhlo::ConcatenateOp op,
                                 PatternRewriter &rewriter) const override {
     auto resultType = op.getResult().getType();
-    if (hasAnySparseOperandOrResult(op)) {
+    if (sparse_tensor::hasAnySparseOperandOrResult(op)) {
       // If there is any sparse input, lower to sparse_tensor.concatenate
       // directly.
      rewriter.replaceOpWithNewOp<sparse_tensor::ConcatenateOp>(
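
Note (not part of the patch): the two deleted local helpers implemented the same check that the upstream MLIR SparseTensor dialect exports as mlir::sparse_tensor::hasAnySparseOperandOrResult, which is what the patch switches to. The following is a minimal sketch, assuming the usual upstream MLIR headers; the hand-rolled variant mirrors the removed helper for comparison, and the names anySparse/anySparseManual are illustrative only.

  // Sketch only: equivalent sparse-operand/result checks, assuming the
  // upstream MLIR SparseTensor dialect headers (not part of this patch).
  #include "llvm/ADT/STLExtras.h"
  #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
  #include "mlir/IR/Operation.h"

  // What the patch switches to: the OSS utility from the sparse_tensor dialect.
  static bool anySparse(mlir::Operation *op) {
    return mlir::sparse_tensor::hasAnySparseOperandOrResult(op);
  }

  // Hand-rolled equivalent, mirroring the deleted local helper: true iff any
  // operand or result type carries a sparse tensor encoding.
  static bool anySparseManual(mlir::Operation *op) {
    auto isSparse = [](mlir::Type t) {
      return mlir::sparse_tensor::getSparseTensorEncoding(t) != nullptr;
    };
    return llvm::any_of(op->getOperandTypes(), isSparse) ||
           llvm::any_of(op->getResultTypes(), isSparse);
  }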