Commit e687ecb

This change canonicalizes boundary values in the IR, replacing -Inf/Inf with the MIN/MAX float value.

PiperOrigin-RevId: 630372176
tensorflower-gardener committed May 9, 2024
1 parent 6c94115 commit e687ecb
Showing 12 changed files with 268 additions and 43 deletions.
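The commit's effect, in IEEE-754 terms: infinite splat constants are rewritten to the nearest finite boundary of the same float type. Below is a minimal standalone illustration of that mapping using LLVM's APFloat; it is not code from this commit.

#include <cstdio>
#include "llvm/ADT/APFloat.h"

// Map -inf to the lowest finite f32 (and +inf to the largest), which is the
// canonicalization rule this change applies to splat constants.
int main() {
  const llvm::fltSemantics &f32 = llvm::APFloat::IEEEsingle();
  llvm::APFloat neg_inf = llvm::APFloat::getInf(f32, /*Negative=*/true);
  llvm::APFloat bound = llvm::APFloat::getLargest(f32, neg_inf.isNegative());
  std::printf("%f -> %e\n", neg_inf.convertToFloat(), bound.convertToFloat());
  // Prints: -inf -> -3.402823e+38
  return 0;
}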
tensorflow/compiler/mlir/lite/BUILD (26 additions, 0 deletions)
@@ -644,6 +644,31 @@ tf_cc_test(
],
)

+cc_library(
+name = "canonicalize_boundary_value",
+srcs = [
+"transforms/canonicalize_boundary_value_pass.cc",
+],
+hdrs = [
+"transforms/passes.h",
+],
+deps = [
+":tensorflow_lite",
+":tensorflow_lite_passes_inc_gen",
+"//tensorflow/compiler/mlir/quantization/common/quantization_lib:quantization_config",
+"//tensorflow/compiler/mlir/tensorflow",
+"@com_google_absl//absl/container:flat_hash_set",
+"@llvm-project//llvm:Support",
+"@llvm-project//mlir:DialectUtils",
+"@llvm-project//mlir:FuncDialect",
+"@llvm-project//mlir:IR",
+"@llvm-project//mlir:Pass",
+"@llvm-project//mlir:Support",
+"@llvm-project//mlir:TransformUtils",
+"@stablehlo//:stablehlo_ops",
+],
+)

cc_library(
name = "tensorflow_lite_legalize_tf",
srcs = [
@@ -1380,6 +1405,7 @@ cc_library(
"tf_tfl_passes.h",
],
deps = [
+":canonicalize_boundary_value",
":common",
":fake_quant_utils",
":tensorflow_lite_d2s",
tensorflow/compiler/mlir/lite/stablehlo/BUILD (1 addition, 0 deletions)
@@ -203,6 +203,7 @@ cc_library(
":tf_stablehlo",
":unfold_splat_constant_pass",
":unfuse_batch_norm_pass",
+"//tensorflow/compiler/mlir/lite:canonicalize_boundary_value",
"//tensorflow/compiler/mlir/quantization/stablehlo:bridge_passes",
"//tensorflow/compiler/mlir/tensorflow/transforms:tensorflow_passes",
"//tensorflow/compiler/mlir/tensorflow/transforms:tf_saved_model_passes",
tensorflow/compiler/mlir/lite/stablehlo/tests/legalize_hlo.mlir (24 additions, 24 deletions)
@@ -2974,8 +2974,8 @@ func.func @convert_int_reduce_to_sum(%arg0: tensor<1x256xi32>) -> tensor<1xi32>
// CHECK: return %[[VAL_3]] : tensor<1xf32>
// CHECK: }
func.func @convert_reduce_to_max(%arg0: tensor<1x256xf32>) -> tensor<1xf32> {
-// "0xFF800000" represents -INF for f32.
-%0 = mhlo.constant dense<0xFF800000> : tensor<f32>
+// "0xFF7FFFFF" represents MIN for f32.
+%0 = mhlo.constant dense<0xFF7FFFFF> : tensor<f32>
%1 = "mhlo.reduce"(%arg0, %0) ({
^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
%2 = mhlo.maximum %arg1, %arg2 : tensor<f32>
@@ -3007,8 +3007,8 @@ func.func @convert_reduce_to_max_int(%arg0: tensor<1x4xi32>) -> tensor<1xi32> {
// CHECK: return %[[VAL_3]] : tensor<1xf32>
// CHECK: }
func.func @convert_reduce_to_min(%arg0: tensor<1x256xf32>) -> tensor<1xf32> {
-// "0x7F800000" represents INF for f32.
-%0 = mhlo.constant dense<0x7F800000> : tensor<f32>
+// "0x7F7FFFFF" represents the MAX for f32.
+%0 = mhlo.constant dense<0x7F7FFFFF> : tensor<f32>
%1 = "mhlo.reduce"(%arg0, %0) ({
^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
%2 = mhlo.minimum %arg1, %arg2 : tensor<f32>
@@ -3299,8 +3299,8 @@ func.func @convert_avgpool_reshape_broadcast(%arg0: tensor<4x16x16x8xf32>) -> te
// CHECK: return %[[VAL_1]] : tensor<4x7x7x8xf32>
// CHECK: }
func.func @convert_maxpool_valid(%arg0: tensor<4x16x16x8xf32>) -> tensor<4x7x7x8xf32> {
-// "0xFF800000" represents -INF for f32.
-%0 = mhlo.constant dense<0xFF800000> : tensor<f32>
+// "0xFF7FFFFF" represents the MIN for f32.
+%0 = mhlo.constant dense<0xFF7FFFFF> : tensor<f32>
%1 = "mhlo.reduce_window"(%arg0, %0) ({
^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
%5 = mhlo.maximum %arg1, %arg2 : tensor<f32>
@@ -3320,8 +3320,8 @@ func.func @convert_maxpool_valid(%arg0: tensor<4x16x16x8xf32>) -> tensor<4x7x7x8
// CHECK: return %[[VAL_1]] : tensor<4x3x7x7xf32>
// CHECK: }
func.func @convert_maxpool_valid_channel_first(%arg0: tensor<4x3x16x16xf32>) -> tensor<4x3x7x7xf32> {
-// "0xFF800000" represents -INF for f32.
-%0 = mhlo.constant dense<0xFF800000> : tensor<f32>
+// "0xFF7FFFFF" represents the MIN for f32.
+%0 = mhlo.constant dense<0xFF7FFFFF> : tensor<f32>
%1 = "mhlo.reduce_window"(%arg0, %0) ({
^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
%2 = mhlo.maximum %arg1, %arg2 : tensor<f32>
@@ -3341,8 +3341,8 @@ func.func @convert_maxpool_valid_channel_first(%arg0: tensor<4x3x16x16xf32>) ->
// CHECK: return %[[VAL_1]] : tensor<4x7x7x7x8xf32>
// CHECK: }
func.func @convert_maxpool_valid_3d(%arg0: tensor<4x16x16x16x8xf32>) -> tensor<4x7x7x7x8xf32> {
-// "0xFF800000" represents -INF for f32.
-%0 = mhlo.constant dense<0xFF800000> : tensor<f32>
+// "0xFF7FFFFF" represents the MIN for f32.
+%0 = mhlo.constant dense<0xFF7FFFFF> : tensor<f32>
%1 = "mhlo.reduce_window"(%arg0, %0) ({
^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
%5 = mhlo.maximum %arg1, %arg2 : tensor<f32>
@@ -3362,8 +3362,8 @@ func.func @convert_maxpool_valid_3d(%arg0: tensor<4x16x16x16x8xf32>) -> tensor<4
// CHECK: return %[[VAL_1]] : tensor<4x3x7x7x7xf32>
// CHECK: }
func.func @convert_maxpool_valid_3d_channel_first(%arg0: tensor<4x3x16x16x16xf32>) -> tensor<4x3x7x7x7xf32> {
-// "0xFF800000" represents -INF for f32.
-%0 = mhlo.constant dense<0xFF800000> : tensor<f32>
+// "0xFF7FFFFF" represents the MIN for f32.
+%0 = mhlo.constant dense<0xFF7FFFFF> : tensor<f32>
%1 = "mhlo.reduce_window"(%arg0, %0) ({
^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
%2 = mhlo.maximum %arg1, %arg2 : tensor<f32>
@@ -3383,8 +3383,8 @@ func.func @convert_maxpool_valid_3d_channel_first(%arg0: tensor<4x3x16x16x16xf32
// CHECK: return %[[VAL_1]] : tensor<4x8x8x8xf32>
// CHECK: }
func.func @convert_maxpool_same(%arg0: tensor<4x16x16x8xf32>) -> tensor<4x8x8x8xf32> {
-// "0xFF800000" represents -INF for f32.
-%0 = mhlo.constant dense<0xFF800000> : tensor<f32>
+// "0xFF7FFFFF" represents the MIN for f32.
+%0 = mhlo.constant dense<0xFF7FFFFF> : tensor<f32>
%1 = "mhlo.reduce_window"(%arg0, %0) ({
^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
%6 = mhlo.maximum %arg1, %arg2 : tensor<f32>
@@ -4198,7 +4198,7 @@ func.func @convert_pytorch_argmax(%arg0: tensor<1x9xi32>) -> tensor<1xi32> {
// CHECK: return %[[VAL_10]], %[[VAL_11]]
// CHECK: }
func.func @convert_argmax(%arg0: tensor<4x32x256xf32>) -> (tensor<4x32xf32>, tensor<4x32xi32>) {
-%0 = mhlo.constant dense<0xFF800000> : tensor<f32>
+%0 = mhlo.constant dense<0xFF7FFFFF> : tensor<f32>
%1 = mhlo.constant dense<0> : tensor<i32>
%2 = "mhlo.iota"() <{iota_dimension = 0 : i64}> : () -> tensor<256xi32>
%3 = "mhlo.broadcast_in_dim"(%2) <{broadcast_dimensions = dense<2> : tensor<1xi64>}> : (tensor<256xi32>) -> tensor<4x32x256xi32>
@@ -4220,7 +4220,7 @@ func.func @convert_argmax(%arg0: tensor<4x32x256xf32>) -> (tensor<4x32xf32>, ten

// CHECK-LABEL: func @convert_argmax_constant(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<2x2x4xf32>) -> (tensor<2x2xf32>, tensor<2x2xi32>) {
-// CHECK-DAG: %[[VAL_1:.*]] = "tf.Const"() <{value = dense<0xFF800000> : tensor<f32>}> : () -> tensor<f32>
+// CHECK-DAG: %[[VAL_1:.*]] = "tf.Const"() <{value = dense<-3.40282347E+38> : tensor<f32>}> : () -> tensor<f32>
// CHECK-DAG: %[[VAL_2:.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
// CHECK-DAG: %[[VAL_3:.*]] = "tf.Const"() <{value = dense<{{\[\[}}[0, 1, 2, 3], [0, 1, 2, 3]], {{\[\[}}0, 1, 2, 3], [0, 1, 2, 3]]]> : tensor<2x2x4xi32>}> : () -> tensor<2x2x4xi32>
// CHECK-DAG: %[[VAL_4:.*]] = arith.constant dense<2> : tensor<1xi32>
@@ -4229,7 +4229,7 @@ func.func @convert_argmax(%arg0: tensor<4x32x256xf32>) -> (tensor<4x32xf32>, ten
// CHECK: return %[[VAL_5]], %[[VAL_6]] : tensor<2x2xf32>, tensor<2x2xi32>
// CHECK: }
func.func @convert_argmax_constant(%arg0: tensor<2x2x4xf32>) -> (tensor<2x2xf32>, tensor<2x2xi32>) {
-%0 = mhlo.constant dense<0xFF800000> : tensor<f32>
+%0 = mhlo.constant dense<0xFF7FFFFF> : tensor<f32>
%1 = mhlo.constant dense<0> : tensor<i32>
%3 = mhlo.constant dense<[[[0, 1, 2, 3], [0, 1, 2, 3]], [[0, 1, 2, 3], [0, 1, 2, 3]]]> : tensor<2x2x4xi32>
%4:2 = "mhlo.reduce"(%arg0, %3, %0, %1) ({
@@ -4250,7 +4250,7 @@ func.func @convert_argmax_constant(%arg0: tensor<2x2x4xf32>) -> (tensor<2x2xf32>

// CHECK-LABEL: func @convert_argmax_constant_non_z_axis(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<4x4xf32>) -> (tensor<4xf32>, tensor<4xi32>) {
-// CHECK-DAG: %[[VAL_1:.*]] = "tf.Const"() <{value = dense<0xFF800000> : tensor<f32>}> : () -> tensor<f32>
+// CHECK-DAG: %[[VAL_1:.*]] = "tf.Const"() <{value = dense<-3.40282347E+38> : tensor<f32>}> : () -> tensor<f32>
// CHECK-DAG: %[[VAL_2:.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
// CHECK-DAG: %[[VAL_3:.*]] = "tf.Const"() <{value = dense<{{\[\[}}0, 0, 0, 0], [1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3]]> : tensor<4x4xi32>}> : () -> tensor<4x4xi32>
// CHECK-DAG: %[[VAL_4:.*]] = arith.constant dense<0> : tensor<1xi32>
@@ -4259,7 +4259,7 @@ func.func @convert_argmax_constant(%arg0: tensor<2x2x4xf32>) -> (tensor<2x2xf32>
// CHECK: return %[[VAL_5]], %[[VAL_6]] : tensor<4xf32>, tensor<4xi32>
// CHECK: }
func.func @convert_argmax_constant_non_z_axis(%arg0: tensor<4x4xf32>) -> (tensor<4xf32>, tensor<4xi32>) {
-%0 = mhlo.constant dense<0xFF800000> : tensor<f32>
+%0 = mhlo.constant dense<0xFF7FFFFF> : tensor<f32>
%1 = mhlo.constant dense<0> : tensor<i32>
%3 = mhlo.constant dense<[[0, 0, 0, 0], [1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3]]> : tensor<4x4xi32>
%4:2 = "mhlo.reduce"(%arg0, %3, %0, %1) ({
@@ -4318,7 +4318,7 @@ func.func @convert_argmax_bool(%arg0: tensor<2xi1>) -> tensor<i32> {
// CHECK: return %[[VAL_10]], %[[VAL_11]]
// CHECK: }
func.func @convert_argmin(%arg0: tensor<4x32x256xf32>) -> (tensor<4x32xf32>, tensor<4x32xi32>) {
-%0 = mhlo.constant dense<0x7F800000> : tensor<f32>
+%0 = mhlo.constant dense<0x7F7FFFFF> : tensor<f32>
%1 = mhlo.constant dense<0> : tensor<i32>
%2 = "mhlo.iota"() <{iota_dimension = 0 : i64}> : () -> tensor<256xi32>
%3 = "mhlo.broadcast_in_dim"(%2) <{broadcast_dimensions = dense<2> : tensor<1xi64>}> : (tensor<256xi32>) -> tensor<4x32x256xi32>
@@ -4368,7 +4368,7 @@ func.func @convert_argmin_i16(%arg0: tensor<2xi16>) -> (tensor<i16>, tensor<i32>

// CHECK-LABEL: func @convert_argmin_constant(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<2x2x4xf32>) -> (tensor<2x2xf32>, tensor<2x2xi32>) {
-// CHECK-DAG: %[[VAL_1:.*]] = "tf.Const"() <{value = dense<0x7F800000> : tensor<f32>}> : () -> tensor<f32>
+// CHECK-DAG: %[[VAL_1:.*]] = "tf.Const"() <{value = dense<3.40282347E+38> : tensor<f32>}> : () -> tensor<f32>
// CHECK-DAG: %[[VAL_2:.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
// CHECK-DAG: %[[VAL_3:.*]] = "tf.Const"() <{value = dense<{{\[\[}}[0, 1, 2, 3], [0, 1, 2, 3]], {{\[\[}}0, 1, 2, 3], [0, 1, 2, 3]]]> : tensor<2x2x4xi32>}> : () -> tensor<2x2x4xi32>
// CHECK-DAG: %[[VAL_4:.*]] = arith.constant dense<2> : tensor<1xi32>
@@ -4377,7 +4377,7 @@ func.func @convert_argmin_i16(%arg0: tensor<2xi16>) -> (tensor<i16>, tensor<i32>
// CHECK: return %[[VAL_5]], %[[VAL_6]] : tensor<2x2xf32>, tensor<2x2xi32>
// CHECK: }
func.func @convert_argmin_constant(%arg0: tensor<2x2x4xf32>) -> (tensor<2x2xf32>, tensor<2x2xi32>) {
-%0 = mhlo.constant dense<0x7F800000> : tensor<f32>
+%0 = mhlo.constant dense<0x7F7FFFFF> : tensor<f32>
%1 = mhlo.constant dense<0> : tensor<i32>
%3 = mhlo.constant dense<[[[0, 1, 2, 3], [0, 1, 2, 3]], [[0, 1, 2, 3], [0, 1, 2, 3]]]> : tensor<2x2x4xi32>
%4:2 = "mhlo.reduce"(%arg0, %3, %0, %1) ({
@@ -4429,7 +4429,7 @@ func.func @convert_argmin_bool(%arg0: tensor<2xi1>) -> tensor<i32> {

// CHECK-LABEL: func @convert_argmax_with_reshaped_iota(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x32x1xf32>) -> (tensor<1x1xf32>, tensor<1x1xi32>) {
-// CHECK-DAG: %[[VAL_1:.*]] = "tf.Const"() <{value = dense<0xFF800000> : tensor<f32>}> : () -> tensor<f32>
+// CHECK-DAG: %[[VAL_1:.*]] = "tf.Const"() <{value = dense<-3.40282347E+38> : tensor<f32>}> : () -> tensor<f32>
// CHECK-DAG: %[[VAL_2:.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
// CHECK-DAG: %[[VAL_3:.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
// CHECK-DAG: %[[VAL_4:.*]] = "tf.Const"() <{value = dense<32> : tensor<i32>}> : () -> tensor<i32>
@@ -4443,7 +4443,7 @@ func.func @convert_argmin_bool(%arg0: tensor<2xi1>) -> tensor<i32> {
// CHECK: return %[[VAL_10]], %[[VAL_11]] : tensor<1x1xf32>, tensor<1x1xi32>
// CHECK: }
func.func @convert_argmax_with_reshaped_iota(%arg0: tensor<1x32x1xf32>) -> (tensor<1x1xf32>, tensor<1x1xi32>) {
-%0 = mhlo.constant dense<0xFF800000> : tensor<f32>
+%0 = mhlo.constant dense<0xFF7FFFFF> : tensor<f32>
%1 = mhlo.constant dense<0> : tensor<i32>
%2 = "mhlo.iota"() <{iota_dimension = 0 : i64}> : () -> tensor<32xi32>
%3 = "mhlo.reshape"(%2) : (tensor<32xi32>) -> tensor<1x32x1xi32>
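The hex constants rewritten throughout these tests follow the IEEE-754 single-precision bit layout: 0xFF800000 and 0x7F800000 encode -Inf and +Inf, while 0xFF7FFFFF and 0x7F7FFFFF encode the lowest and largest finite f32, the -3.40282347E+38 and 3.40282347E+38 the updated CHECK lines spell out. A small standalone program (not part of the commit) reproduces the encodings:

#include <cstdint>
#include <cstdio>
#include <cstring>
#include <limits>

// Print the bit patterns used in the FileCheck tests above.
static std::uint32_t Bits(float f) {
  std::uint32_t u;
  std::memcpy(&u, &f, sizeof u);  // bit-cast without undefined behavior
  return u;
}

int main() {
  std::printf("-inf     = 0x%08X\n", Bits(-std::numeric_limits<float>::infinity()));  // 0xFF800000
  std::printf("-FLT_MAX = 0x%08X\n", Bits(std::numeric_limits<float>::lowest()));     // 0xFF7FFFFF
  std::printf("+inf     = 0x%08X\n", Bits(std::numeric_limits<float>::infinity()));   // 0x7F800000
  std::printf("+FLT_MAX = 0x%08X\n", Bits(std::numeric_limits<float>::max()));        // 0x7F7FFFFF
  return 0;
}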
tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo.cc (6 additions, 12 deletions)
@@ -69,6 +69,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/reduce.h"
#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/util.h"
#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h"
+#include "tensorflow/compiler/mlir/lite/utils/utils.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
@@ -1998,7 +1999,7 @@ class ConvertReduceOpToTfMax
if (mlir::isa<FloatType>(type)) {
APFloat const_value(.0);
if (failed(GetConstantSplatValue(init_value, const_value)) ||
-!const_value.isInfinity() || !const_value.isNegative())
+!TFL::IsMinFloatValue(const_value))
return failure();
} else if (mlir::isa<IntegerType>(type) && type.isSignlessInteger()) {
APInt const_value;
@@ -2023,7 +2024,7 @@ class ConvertReduceOpToTfMin
if (mlir::isa<FloatType>(type)) {
APFloat const_value(.0);
if (failed(GetConstantSplatValue(init_value, const_value)) ||
-!const_value.isInfinity() || const_value.isNegative())
+!TFL::IsMaxFloatValue(const_value))
return failure();
} else if (mlir::isa<IntegerType>(type) && type.isSignlessInteger()) {
APInt const_value;
@@ -2081,7 +2082,7 @@ class ConvertReduceOpToTfArgmax
return false;
if (mlir::isa<FloatType>(element_type)) {
auto value = *attr.value_begin<APFloat>();
-return value.isNegative() && value.isInfinity();
+return TFL::IsMinFloatValue(value);
} else if (element_type.isInteger(1)) {
auto value = *attr.value_begin<APInt>();
return value.isZero();
@@ -2105,7 +2106,7 @@ class ConvertReduceOpToTfArgmin
return false;
if (mlir::isa<FloatType>(element_type)) {
auto value = *attr.value_begin<APFloat>();
-return !value.isNegative() && value.isInfinity();
+return TFL::IsMaxFloatValue(value);
} else if (element_type.isInteger(1)) {
auto value = *attr.value_begin<APInt>();
return value.isZero();
@@ -2596,14 +2597,7 @@ class ConvertMaxPoolOp : public OpConversionPattern<mhlo::ReduceWindowOp> {
}

APFloat element = float_value.getValues<APFloat>()[0];
-if (!element.isInfinity()) {
-return false;
-}
-if (!element.isNegative()) {
-return false;
-}
-
-return true;
+return TFL::IsMinFloatValue(element);
}

LogicalResult replaceWithMaxPool(mhlo::ReduceWindowOp op, Value input,
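The rewritten matchers above delegate to TFL::IsMinFloatValue and TFL::IsMaxFloatValue, pulled in through the newly added tensorflow/compiler/mlir/lite/utils/utils.h include. That header's body is not part of the visible diff, so the following is only a plausible sketch of the helpers, assuming they test for the lowest/largest finite value of the operand's own float semantics:

#include "llvm/ADT/APFloat.h"

// Hypothetical implementations, inferred from the call sites above; the
// actual utils.h definitions are not shown in this diff.
inline bool IsMinFloatValue(const llvm::APFloat &value) {
  // Lowest finite value of the semantics, e.g. 0xFF7FFFFF for f32.
  return value.bitwiseIsEqual(
      llvm::APFloat::getLargest(value.getSemantics(), /*Negative=*/true));
}

inline bool IsMaxFloatValue(const llvm::APFloat &value) {
  // Largest finite value of the semantics, e.g. 0x7F7FFFFF for f32.
  return value.bitwiseIsEqual(
      llvm::APFloat::getLargest(value.getSemantics(), /*Negative=*/false));
}

Checking against the value's own semantics rather than hard-coded f32 constants would make the matchers work for any float width (f16, bf16, f64, ...), which is consistent with how the call sites pass through the reduction operand's element type.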
@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/rename_entrypoint_to_main.h"
#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/smuggle_disallowed_ops.h"
#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/tf_stablehlo_pass.h"
+#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_passes.h"
@@ -38,6 +39,7 @@ void AddTFToStablehloPasses(OpPassManager& pm, bool skip_resize,

// if the input is a call_xla_module, then unwrap the content
pm.addPass(mlir::odml::CreateLegalizeTFXlaCallModuleToStablehloPass());
+pm.addPass(mlir::TFL::CreateCanonicalizeBoundaryValuePass());
// TODO: b/230572023 - Consider improving shape inference for While op instead
// of dropping the attribute. This need not be correct for models not trained
// on TPU.
@@ -64,6 +66,7 @@ void AddTFToStablehloPasses(OpPassManager& pm, bool skip_resize,
pm.addPass(mlir::tf_saved_model::CreateOptimizeGlobalTensorsPass());
pm.addPass(mlir::tf_saved_model::CreateFreezeGlobalTensorsPass(
/*allow_mutable_tensors=*/true));
+pm.addPass(mlir::TFL::CreateCanonicalizeBoundaryValuePass());

// Generic MLIR optimization passes.
pm.addPass(mlir::createCanonicalizerPass());
@@ -84,6 +87,7 @@ void AddTFToStablehloPasses(OpPassManager& pm, bool skip_resize,
pm.addNestedPass<func::FuncOp>(CreateSmuggleDisallowedOpsPass());
pm.addPass(mlir::createCanonicalizerPass());
}
+pm.addPass(mlir::TFL::CreateCanonicalizeBoundaryValuePass());
}

void AddMhloOptimizationPasses(OpPassManager& pm,
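transforms/canonicalize_boundary_value_pass.cc itself falls in the portion of the diff that did not load on this page. Judging from the BUILD deps (stablehlo_ops, TransformUtils) and the registration points above, its core is presumably a rewrite pattern along these lines; names and structure here are assumptions, not the commit's actual code:

#include "llvm/ADT/APFloat.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LogicalResult.h"
#include "stablehlo/dialect/StablehloOps.h"

namespace {
// Rewrite +/-Inf splat constants to the largest/lowest finite value of the
// same element type, the canonical form the legalize_hlo tests now expect.
struct CanonicalizeBoundaryValue
    : public mlir::OpRewritePattern<mlir::stablehlo::ConstantOp> {
  using OpRewritePattern::OpRewritePattern;

  mlir::LogicalResult matchAndRewrite(
      mlir::stablehlo::ConstantOp op,
      mlir::PatternRewriter &rewriter) const override {
    auto attr = mlir::dyn_cast<mlir::DenseFPElementsAttr>(op.getValue());
    if (!attr || !attr.isSplat()) return mlir::failure();

    llvm::APFloat splat = attr.getSplatValue<llvm::APFloat>();
    if (!splat.isInfinity()) return mlir::failure();

    // Keep the sign: -inf becomes the lowest finite value, +inf the largest.
    llvm::APFloat finite = llvm::APFloat::getLargest(
        splat.getSemantics(), /*Negative=*/splat.isNegative());
    rewriter.replaceOpWithNewOp<mlir::stablehlo::ConstantOp>(
        op, mlir::DenseElementsAttr::get(
                attr.getType(), llvm::ArrayRef<llvm::APFloat>(finite)));
    return mlir::success();
  }
};
}  // namespace

Scheduling the pass at several pipeline stages, as the hunks above do, would guard against earlier passes or constant folding re-materializing infinite constants between stages; that reading is consistent with the three registration points but is not stated in the commit.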