[go: nahoru, domu]

Skip to content

Commit

Permalink
[MHLO] Add support for batching dims when legalizing mhlo.gather to linalg
Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 639017063
  • Loading branch information
tensorflower-gardener committed May 31, 2024
1 parent 69c4700 commit 2feb38b
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ limitations under the License.
#include <utility>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "mhlo/IR/hlo_ops.h"
#include "mhlo/transforms/map_mhlo_to_scalar_op.h"
Expand Down Expand Up @@ -4017,11 +4018,6 @@ struct GatherConversion : public OpConversionPattern<mhlo::GatherOp> {
LogicalResult matchAndRewrite(
mhlo::GatherOp gatherOp, OpAdaptor adaptor,
ConversionPatternRewriter& rewriter) const final {
// TODO: b/342172264 - Implement handling of batching dims.
if (!gatherOp.getDimensionNumbers().getOperandBatchingDims().empty() ||
!gatherOp.getDimensionNumbers().getStartIndicesBatchingDims().empty())
return failure();

Location loc = gatherOp.getLoc();

Value startIndices = adaptor.getStartIndices();
Expand All @@ -4048,6 +4044,10 @@ struct GatherConversion : public OpConversionPattern<mhlo::GatherOp> {
gatherOp.getDimensionNumbers().getOffsetDims();
ArrayRef<int64_t> collapsedSliceDims =
gatherOp.getDimensionNumbers().getCollapsedSliceDims();
ArrayRef<int64_t> operandBatchingDims =
gatherOp.getDimensionNumbers().getOperandBatchingDims();
ArrayRef<int64_t> startIndicesBatchingDims =
gatherOp.getDimensionNumbers().getStartIndicesBatchingDims();
ArrayRef<int64_t> startIndexMap =
gatherOp.getDimensionNumbers().getStartIndexMap();

Expand Down Expand Up @@ -4128,12 +4128,25 @@ struct GatherConversion : public OpConversionPattern<mhlo::GatherOp> {
for (const auto& it : llvm::enumerate(startIndexMap))
remappedIndexFromIndices[it.value()] = indexFromStartIndices[it.index()];

// Now we construct the index based on the operand/start_indices batching
// dimensions.
SmallVector<Value> indexFromBatching(operandRank, constants[0]);
for (auto [operandDim, indicesDim] :
llvm::zip_equal(operandBatchingDims, startIndicesBatchingDims)) {
indexFromBatching[operandDim] =
gatherIndex[indicesDim + (indicesDim < indexVectorDim ? 0 : 1)];
}

auto isCollapsedOrBatching = [&](int64_t dim) {
return llvm::is_contained(collapsedSliceDims, dim) ||
llvm::is_contained(operandBatchingDims, dim);
};

// Now we construct the index based on the offset. First we need to remap
// the offset dimensions by dropping the collapsed indices.
SmallVector<unsigned> remappedOffsetDims;
for (int i = 0; i < operandRank; ++i)
if (!llvm::is_contained(collapsedSliceDims, i))
remappedOffsetDims.push_back(i);
if (!isCollapsedOrBatching(i)) remappedOffsetDims.push_back(i);

assert(remappedOffsetDims.size() == offsetDims.size());

Expand All @@ -4142,7 +4155,7 @@ struct GatherConversion : public OpConversionPattern<mhlo::GatherOp> {
// Compute the size of the output shape dimension corresponding to this
// index dimension. If it's collapsed set it to 1.
Value outputDimSize = constants[1];
if (!llvm::is_contained(collapsedSliceDims, i)) {
if (!isCollapsedOrBatching(i)) {
outputDimSize = rewriter.createOrFold<tensor::DimOp>(
loc, emptyOp, offsetDims[operandIndexDim++]);
}
Expand Down Expand Up @@ -4171,12 +4184,15 @@ struct GatherConversion : public OpConversionPattern<mhlo::GatherOp> {
for (unsigned k = 0; k < offsetDims.size(); ++k)
indexFromOffset[remappedOffsetDims[k]] = linalgIndices[offsetDims[k]];

// Now we add together our two indices to get the final index into the
// Now we add together our three indices to get the final index into the
// operand.
SmallVector<Value> combinedIndex;
for (int i = 0; i < operandRank; ++i)
combinedIndex.push_back(rewriter.createOrFold<arith::AddIOp>(
loc, rewriter.getIndexType(), remappedIndexFromIndices[i],
loc, rewriter.getIndexType(),
rewriter.createOrFold<arith::AddIOp>(loc, rewriter.getIndexType(),
remappedIndexFromIndices[i],
indexFromBatching[i]),
indexFromOffset[i]));

Value extractOperand;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4719,6 +4719,50 @@ func.func @gather(%operand : tensor<1x4x8xi32>, %start_indices : tensor<1x8x2xi3

// -----

// Exercises the mhlo.gather -> linalg lowering when explicit batching
// dimensions are present: operand dim 0 (size 5) is paired with
// start_indices dim 1 (size 5) via operand_batching_dims /
// start_indices_batching_dims, so that component of the operand index is
// taken from the loop (batch) index rather than from an extracted start
// index. Only operand dim 1 is gathered (start_index_map = [1]) and then
// collapsed (collapsed_slice_dims = [1]); operand dim 2 is the offset
// dim carried through to output dim 2 (offset_dims = [2]).
// `someattr` checks that unregistered attributes survive the conversion.
func.func @gather_batching_dims(%operand : tensor<5x4x8xi32>, %start_indices : tensor<8x5x1xi32>) -> tensor<8x5x8xi32> {
%res = "mhlo.gather"(%operand, %start_indices) {
dimension_numbers = #mhlo.gather<
collapsed_slice_dims = [1],
operand_batching_dims = [0],
start_indices_batching_dims = [1],
index_vector_dim = 2,
offset_dims = [2],
start_index_map = [1]
>,
indices_are_sorted = false,
slice_sizes = dense<[1, 1, 8]> : tensor<3xi64>,
someattr
} : (tensor<5x4x8xi32>, tensor<8x5x1xi32>) -> tensor<8x5x8xi32>
func.return %res : tensor<8x5x8xi32>
}

// CHECK: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK-LABEL: func @gather_batching_dims(
// CHECK-SAME: %[[OPERAND:[a-zA-Z0-9_]+]]
// CHECK-SAME: %[[START_INDICES:[a-zA-Z0-9_]+]]
// CHECK-SAME: )
// CHECK-DAG: %[[C0:.+]] = arith.constant 0
// CHECK-DAG: %[[C3:.+]] = arith.constant 3
// CHECK-DAG: %[[INIT:.+]] = tensor.empty() : tensor<8x5x8xi32>
// CHECK: %[[RES:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP0]]],
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]
// CHECK-SAME: outs(%[[INIT]] : tensor<8x5x8xi32>)
// CHECK-SAME: {someattr}
// CHECK: ^bb0
// CHECK-DAG: %[[IDX0:.+]] = linalg.index 0
// CHECK-DAG: %[[IDX1:.+]] = linalg.index 1
// CHECK-DAG: %[[IDX2:.+]] = linalg.index 2
// CHECK-DAG: %[[S0_INT:.+]] = tensor.extract %[[START_INDICES]][%[[IDX0]], %[[IDX1]], %[[C0]]] : tensor<8x5x1xi32>
// CHECK-DAG: %[[S0:.+]] = arith.index_cast %[[S0_INT]] : i32 to index
// CHECK-DAG: %[[CLAMP0:.+]] = arith.maxsi %[[S0]], %[[C0]] : index
// CHECK-DAG: %[[IN0:.+]] = arith.minsi %[[CLAMP0]], %[[C3]]
// CHECK: %[[Y:.+]] = tensor.extract %[[OPERAND]][%[[IDX1]], %[[IN0]], %[[IDX2]]] : tensor<5x4x8xi32>
// CHECK: linalg.yield %[[Y]] : i32
// CHECK-DAG: return %[[RES]]

// -----

func.func @gather_unsigned_index(
%operand : tensor<1x4x8xi32>, %start_indices : tensor<1x8x2xui32>)
-> tensor<1x8x8xi32> {
Expand Down

0 comments on commit 2feb38b

Please sign in to comment.