diff --git a/tensorflow/compiler/mlir/tensorflow/tests/unroll-batch-matmul.mlir b/tensorflow/compiler/mlir/tensorflow/tests/unroll-batch-matmul.mlir index 5cbcc21080e417..ec2d36d126775b 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/unroll-batch-matmul.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/unroll-batch-matmul.mlir @@ -448,6 +448,17 @@ func.func @batchMatMulV2MatrixAdjXY(%arg0: tensor<5x4xf32>, %arg1: tensor<6x5xf3 // CHECK: return %[[MATMUL_1]] : tensor<4x6xf32> } +// ----- + +func.func @batchMatMulV2DynamicSize(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = "tf.BatchMatMulV2"(%arg0, %arg1) {adj_x = false, adj_y = false} : (tensor, tensor) -> tensor + func.return %0 : tensor + + // CHECK-LABEL: batchMatMulV2DynamicSize + // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%arg0, %arg1) {transpose_a = false, transpose_b = false} : (tensor, tensor) -> tensor + // CHECK: return %[[MATMUL_1]] : tensor +} + // ----- // ==== V3 tests ==== diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/unroll_batch_matmul.cc b/tensorflow/compiler/mlir/tensorflow/transforms/unroll_batch_matmul.cc index a15101584a1fa6..abdd1a83d516eb 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/unroll_batch_matmul.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/unroll_batch_matmul.cc @@ -198,8 +198,8 @@ LogicalResult ConvertTFBatchMatMulOp::matchAndRewrite( std::swap(rhs_shape[rhs_dims - 1], rhs_shape[rhs_dims - 2]); } - const int rows = lhs_shape[lhs_dims - 2]; - const int cols = rhs_shape[rhs_dims - 1]; + const int64_t rows = lhs_shape[lhs_dims - 2]; + const int64_t cols = rhs_shape[rhs_dims - 1]; if (lhs_shape[lhs_dims - 1] != rhs_shape[rhs_dims - 2]) { // Input dimensions must be compatible for multiplication.