[go: nahoru, domu]

Skip to content

Commit

Permalink
Merge pull request #59880 from Tai78641:pr_fix_strided_slice3
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 535501847
  • Loading branch information
tensorflower-gardener committed May 26, 2023
2 parents 41787b2 + 41a8551 commit d032157
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 24 deletions.
4 changes: 2 additions & 2 deletions tensorflow/compiler/mlir/tosa/tests/tf-to-tosa-pipeline.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -617,8 +617,8 @@ func.func @test_slice(%arg0: tensor<13x21x3xf32>) -> tensor<4x11x1xf32> {

// CHECK-LABEL: test_strided_slice
// CHECK-DAG: %[[VAR0:.*]] = "tosa.slice"(%arg0) <{size = array<i64: 9, 21, 2>, start = array<i64: 4, 0, 1>}
// CHECK-DAG: %[[VAR1:.*]] = "tosa.reshape"(%[[VAR0]]) <{new_shape = array<i64: 9, 1, 7, 3, 2, 1>}
// CHECK-DAG: %[[VAR2:.*]] = "tosa.slice"(%[[VAR1]]) <{size = array<i64: 9, 1, 7, 1, 2, 1>, start = array<i64: 0, 0, 0, 0, 0, 0>}
// CHECK-DAG: %[[VAR1:.*]] = "tosa.reshape"(%[[VAR0]]) <{new_shape = array<i64: 9, 7, 3, 2>}
// CHECK-DAG: %[[VAR2:.*]] = "tosa.slice"(%[[VAR1]]) <{size = array<i64: 9, 7, 1, 2>, start = array<i64: 0, 0, 0, 0>}
// CHECK: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) <{new_shape = array<i64: 9, 7, 2>}
func.func @test_strided_slice(%arg0: tensor<13x21x3xf32>) -> tensor<9x7x2xf32> {
%2 = "tf.Const"() {value = dense<[4, 0, 1]> : tensor<3xi64>} : () -> tensor<3xi64>
Expand Down
16 changes: 8 additions & 8 deletions tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1064,8 +1064,8 @@ func.func @test_slice(%arg0: tensor<13x21x3xf32>) -> tensor<*xf32> {

// CHECK-LABEL: test_strided_slice_simple
// CHECK-DAG: %[[VAR0:.*]] = "tosa.slice"(%arg0) <{size = array<i64: 9, 21, 2>, start = array<i64: 4, 0, 1>}>
// CHECK-DAG: %[[VAR1:.*]] = "tosa.reshape"(%[[VAR0]]) <{new_shape = array<i64: 9, 1, 7, 3, 2, 1>}>
// CHECK-DAG: %[[VAR2:.*]] = "tosa.slice"(%[[VAR1]]) <{size = array<i64: 9, 1, 7, 1, 2, 1>, start = array<i64: 0, 0, 0, 0, 0, 0>}>
// CHECK-DAG: %[[VAR1:.*]] = "tosa.reshape"(%[[VAR0]]) <{new_shape = array<i64: 9, 7, 3, 2>}>
// CHECK-DAG: %[[VAR2:.*]] = "tosa.slice"(%[[VAR1]]) <{size = array<i64: 9, 7, 1, 2>, start = array<i64: 0, 0, 0, 0>}>
// CHECK: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) <{new_shape = array<i64: 9, 7, 2>}>
func.func @test_strided_slice_simple(%arg0: tensor<13x21x3xf32>) -> tensor<*xf32> {
%cst = arith.constant dense<[4, 0, 1]> : tensor<3xi32>
Expand All @@ -1079,8 +1079,8 @@ func.func @test_strided_slice_simple(%arg0: tensor<13x21x3xf32>) -> tensor<*xf32

// CHECK-LABEL: test_strided_slice_simple_negative
// CHECK-DAG: %[[VAR0:.*]] = "tosa.slice"(%arg0) <{size = array<i64: 9, 18, 2>, start = array<i64: 4, 0, 1>}>
// CHECK-DAG: %[[VAR1:.*]] = "tosa.reshape"(%[[VAR0]]) <{new_shape = array<i64: 9, 1, 6, 3, 2, 1>}>
// CHECK-DAG: %[[VAR2:.*]] = "tosa.slice"(%[[VAR1]]) <{size = array<i64: 9, 1, 6, 1, 2, 1>, start = array<i64: 0, 0, 0, 0, 0, 0>}>
// CHECK-DAG: %[[VAR1:.*]] = "tosa.reshape"(%[[VAR0]]) <{new_shape = array<i64: 9, 6, 3, 2>}>
// CHECK-DAG: %[[VAR2:.*]] = "tosa.slice"(%[[VAR1]]) <{size = array<i64: 9, 6, 1, 2>, start = array<i64: 0, 0, 0, 0>}>
// CHECK: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) <{new_shape = array<i64: 9, 6, 2>}>
func.func @test_strided_slice_simple_negative(%arg0: tensor<13x21x3xf32>) -> tensor<*xf32> {
%cst = arith.constant dense<[4, 0, 1]> : tensor<3xi32>
Expand All @@ -1107,8 +1107,8 @@ func.func @test_strided_slice_strideless(%arg0: tensor<13x21x3xf32>) -> tensor<*

// CHECK-LABEL: test_strided_slice_shrink
// CHECK-DAG: %[[VAR0:.*]] = "tosa.slice"(%arg0) <{size = array<i64: 1, 21, 1>, start = array<i64: 4, 0, 1>}>
// CHECK-DAG: %[[VAR1:.*]] = "tosa.reshape"(%[[VAR0]]) <{new_shape = array<i64: 1, 1, 7, 3, 1, 1>}>
// CHECK-DAG: %[[VAR2:.*]] = "tosa.slice"(%[[VAR1]]) <{size = array<i64: 1, 1, 7, 1, 1, 1>, start = array<i64: 0, 0, 0, 0, 0, 0>}>
// CHECK-DAG: %[[VAR1:.*]] = "tosa.reshape"(%[[VAR0]]) <{new_shape = array<i64: 1, 7, 3, 1>}>
// CHECK-DAG: %[[VAR2:.*]] = "tosa.slice"(%[[VAR1]]) <{size = array<i64: 1, 7, 1, 1>, start = array<i64: 0, 0, 0, 0>}>
// CHECK: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) <{new_shape = array<i64: 7>}>
func.func @test_strided_slice_shrink(%arg0: tensor<13x21x3xf32>) -> tensor<*xf32> {
%cst = arith.constant dense<[4, 0, 1]> : tensor<3xi32>
Expand Down Expand Up @@ -1203,8 +1203,8 @@ func.func @test_strided_slice_dynamic_end(%arg0: tensor<10x?x?xf32>) -> tensor<*
%stride = arith.constant dense<[1, 2, -1]> : tensor<3xi32>

// CHECK: %[[SLICE1:.+]] = "tosa.slice"(%arg0) <{size = array<i64: 7, -1, 1>, start = array<i64: 0, 1, 2>}>
// CHECK: %[[RESHAPE1:.+]] = "tosa.reshape"(%[[SLICE1]]) <{new_shape = array<i64: 7, 1, -1, 2, 1, 1>}>
// CHECK: %[[SLICE2:.+]] = "tosa.slice"(%[[RESHAPE1]]) <{size = array<i64: 7, 1, -1, 1, 1, 1>, start = array<i64: 0, 0, 0, 0, 0, 0>}>
// CHECK: %[[RESHAPE1:.+]] = "tosa.reshape"(%[[SLICE1]]) <{new_shape = array<i64: 7, -1, 2, 1>}>
// CHECK: %[[SLICE2:.+]] = "tosa.slice"(%[[RESHAPE1]]) <{size = array<i64: 7, -1, 1, 1>, start = array<i64: 0, 0, 0, 0>}>
// CHECK: %[[RESHAPE2:.+]] = "tosa.reshape"(%[[SLICE2]]) <{new_shape = array<i64: 7, -1>}>
%0 = "tfl.strided_slice"(%arg0, %begin, %end, %stride) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 2 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 4 : i32} : (tensor<10x?x?xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<*xf32>
// CHECK: return %[[RESHAPE2]]
Expand Down
34 changes: 20 additions & 14 deletions tensorflow/compiler/mlir/tosa/transforms/legalize_common.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2279,9 +2279,7 @@ std::optional<Value> convertStridedSliceOp(
// tensor
//
// 2. Reshape2: Reshape the tensor from (1) such that each dimension with
// stride is split into two dimensions of size_i/stride_i, stride_i. A naive
// implementation doubles the input tensor rank, but only dimensions being
// strided actually need to be doubled.
// abs(stride_i) != 1 is split into two dimensions of size_i/abs(stride_i), abs(stride_i).
//
// 3. Slice3: Slice the tensor from (2) such that we select index [0] from
// each of the stride_i dimensions in (2)
Expand Down Expand Up @@ -2324,7 +2322,6 @@ std::optional<Value> convertStridedSliceOp(
int32_t strides_size = strides.size();
for (auto stride : strides) all_strides_one &= abs(stride) == 1;


// If all of the masks are set we can just bypass the entire thing.
const int32_t all_masks_one = (1 << strides_size) - 1;

Expand Down Expand Up @@ -2456,10 +2453,14 @@ std::optional<Value> convertStridedSliceOp(
}

// Step 2: reshape the sliced array
SmallVector<int64_t> a2_shape(input_rank * 2);
SmallVector<int64_t> a2_shape;
for (int i = 0; i < input_rank; ++i) {
a2_shape[i * 2 + 0] = a1_size[i] == -1 ? -1 : a1_size[i] / abs(strides[i]);
a2_shape[i * 2 + 1] = abs(strides[i]);
int64_t abs_stride_i = abs(strides[i]);
a2_shape.push_back(a1_size[i] == -1 ? -1 : a1_size[i] / abs_stride_i);
if (abs_stride_i != 1) {
// only add a stride dimension if abs(strides[i]) != 1
a2_shape.push_back(abs_stride_i);
}
}

auto a2_reshape_op = CreateOpAndInfer<tosa::ReshapeOp>(
Expand All @@ -2470,19 +2471,24 @@ std::optional<Value> convertStridedSliceOp(
tensorflow::ConvertMlirShapeToTF(a2_shape)));

// Step 3: take a slice along the strides
SmallVector<int64_t> a3_begin(input_rank * 2), a3_size(input_rank * 2);
SmallVector<int64_t> a3_begin, a3_size;
for (int i = 0; i < input_rank; ++i) {
a3_begin[i * 2 + 0] = 0;
a3_begin[i * 2 + 1] = 0;
int64_t abs_stride_i = abs(strides[i]);
a3_begin.push_back(0);

if (shrink_axis_mask & (1 << i)) {
a3_size[i * 2 + 0] = 1;
a3_size.push_back(1);
} else {
a3_size[i * 2 + 0] =
(a1_size[i] == -1) ? -1 : (a1_size[i] / abs(strides[i]));
a3_size.push_back((a1_size[i] == -1) ? -1 : (a1_size[i] / abs_stride_i));
}
if (abs_stride_i != 1) {
// previous reshape only adds a stride dimension if abs(strides[i]) != 1
a3_begin.push_back(0);
a3_size.push_back(1);
}
a3_size[i * 2 + 1] = 1;
}
assert(a2_shape.size() == a3_begin.size());
assert(a2_shape.size() == a3_size.size());

auto a3_slice_op = CreateOpAndInfer<tosa::SliceOp>(
rewriter, op->getLoc(),
Expand Down

0 comments on commit d032157

Please sign in to comment.