[go: nahoru, domu]

Skip to content

Commit

Permalink
Return proper value for dynamic size when replacing Einsum with Batch…
Browse files Browse the repository at this point in the history
…MatMul.

Prior to this change, the added test case would result in a BatchMatMul op
with an incorrect result shape of `2x?x0x1xf32`.

The value representing dynamic size has changed from `-1` to
`std::numeric_limits<int64_t>::min()`.
The conditional fixed here is meant to return the value representing dynamic
size as soon as a dynamic dimension is detected in the provided shape. Because
it still compared against `-1`, it previously failed to detect the new value
and continued to multiply int64 min by subsequent shape values. Due to integer
overflow, this resulted in a return value of `0` if int64 min was multiplied
by an even value (or int64 min if multiplied by an odd value).

PiperOrigin-RevId: 495961505
  • Loading branch information
arfaian authored and tensorflower-gardener committed Dec 16, 2022
1 parent f1447ef commit 2065b6f
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 1 deletion.
10 changes: 10 additions & 0 deletions tensorflow/compiler/mlir/tensorflow/tests/einsum.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,16 @@ func.func @einsum_matmul(%arg0: tensor<7x9xf32>, %arg1: tensor<9x5xf32>) -> tens
// CHECK: return %[[v0]] : tensor<7x5xf32>
}

// Regression test: lowering a tf.Einsum whose operands have dynamic dimensions
// must keep those dimensions dynamic in the resulting tf.BatchMatMulV2 type
// (tensor<2x?x?x1xf32>) instead of producing a bogus static extent.
func.func @einsum_matmul_dynamic_size(%arg0: tensor<2x?x?x?xf32>, %arg1: tensor<2x?xf32>) -> tensor<2x?x?x1xf32> {
  %0 = "tf.Einsum"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", equation = "bxyc,bx->bxyc"} : (tensor<2x?x?x?xf32>, tensor<2x?xf32>) -> tensor<2x?x?x1xf32>
  func.return %0 : tensor<2x?x?x1xf32>
  // CHECK-LABEL: einsum_matmul_dynamic_size
  // The captured pattern variables are reused below so the test does not
  // depend on the exact SSA value names chosen by the printer.
  // CHECK-DAG: %[[cst:.*]] = arith.constant dense<[2, -1, 1, 1]> : tensor<4xi64>
  // CHECK: %[[v0:.*]] = "tf.Reshape"(%arg1, %[[cst]]) : (tensor<2x?xf32>, tensor<4xi64>) -> tensor<2x?x1x1xf32>
  // CHECK: %[[v1:.*]] = "tf.BatchMatMulV2"(%arg0, %[[v0]]) {adj_x = false, adj_y = false} : (tensor<2x?x?x?xf32>, tensor<2x?x1x1xf32>) -> tensor<2x?x?x1xf32>
  // CHECK: return %[[v1]] : tensor<2x?x?x1xf32>
}

func.func @einsum_broadcast(%arg0: tensor<3x4x5xf32>, %arg1: tensor<5x6xf32>) -> tensor<3x4x6xf32> {
%0 = "tf.Einsum"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", equation = "ijk,km->ijm"}: (tensor<3x4x5xf32>, tensor<5x6xf32>) -> tensor<3x4x6xf32>
func.return %0 : tensor<3x4x6xf32>
Expand Down
3 changes: 2 additions & 1 deletion tensorflow/compiler/mlir/tensorflow/transforms/einsum.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ limitations under the License.
#include <string>
#include <tuple>
#include <utility>
#include <vector>

#include "absl/memory/memory.h"
#include "llvm/ADT/ArrayRef.h"
Expand Down Expand Up @@ -489,7 +490,7 @@ inline int64_t ProdShapeWithIndexInTuple(
int64_t prod_shape = 1;
for (auto index_tuple : index_tuples) {
const int64_t shape_i = shape[std::get<I>(index_tuple)];
if (shape_i == -1) return -1;
if (ShapedType::isDynamic(shape_i)) return ShapedType::kDynamic;
prod_shape *= shape_i;
}
return prod_shape;
Expand Down

0 comments on commit 2065b6f

Please sign in to comment.