From b59bae9f746714b24a3d2fe9586c01de3e2e26eb Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 3 May 2024 06:25:32 -0700 Subject: [PATCH] This change is to canonicalize boundary values in the IR replacing -Inf/Inf with MIN/MAX float value. PiperOrigin-RevId: 630372176 --- tensorflow/compiler/mlir/lite/BUILD | 27 +++++ .../mlir/lite/common/tfl_pass_config.h | 4 + .../python/saved_model_to_tfl_flatbuffer.cc | 2 + tensorflow/compiler/mlir/lite/stablehlo/BUILD | 2 + .../mlir/lite/stablehlo/odml_to_stablehlo.cc | 1 + .../lite/stablehlo/transforms/legalize_hlo.cc | 18 +-- .../lite/stablehlo/transforms/transforms.cc | 1 + .../tests/canonicalize_boundary_value.mlir | 79 +++++++++++++ .../mlir/lite/tf_to_tfl_flatbuffer.cc | 4 + .../canonicalize_boundary_value_pass.cc | 104 ++++++++++++++++++ .../compiler/mlir/lite/transforms/passes.h | 3 + .../compiler/mlir/lite/transforms/passes.td | 5 + tensorflow/compiler/mlir/lite/utils/utils.h | 12 ++ tensorflow/lite/python/convert.py | 6 + tensorflow/lite/python/lite.py | 4 + tensorflow/lite/python/lite_v2_test.py | 38 +++++++ tensorflow/lite/python/lite_v2_test_util.py | 6 + tensorflow/lite/toco/toco_flags.proto | 6 +- 18 files changed, 309 insertions(+), 13 deletions(-) create mode 100644 tensorflow/compiler/mlir/lite/tests/canonicalize_boundary_value.mlir create mode 100644 tensorflow/compiler/mlir/lite/transforms/canonicalize_boundary_value_pass.cc diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD index 41344262ef4f0a..b23cf780d62850 100644 --- a/tensorflow/compiler/mlir/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/BUILD @@ -644,6 +644,32 @@ tf_cc_test( ], ) +cc_library( + name = "canonicalize_boundary_value", + srcs = [ + "transforms/canonicalize_boundary_value_pass.cc", + ], + hdrs = [ + "transforms/passes.h", + ], + deps = [ + ":tensorflow_lite", + ":tensorflow_lite_passes_inc_gen", + "//tensorflow/compiler/mlir/quantization/common/quantization_lib:quantization_config", + 
"//tensorflow/compiler/mlir/tensorflow", + "@com_google_absl//absl/container:flat_hash_set", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:DialectUtils", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + "@stablehlo//:stablehlo_ops", + ], +) + cc_library( name = "tensorflow_lite_legalize_tf", srcs = [ @@ -1365,6 +1391,7 @@ cc_library( "tf_tfl_passes.h", ], deps = [ + ":canonicalize_boundary_value", ":common", ":fake_quant_utils", ":tensorflow_lite_d2s", diff --git a/tensorflow/compiler/mlir/lite/common/tfl_pass_config.h b/tensorflow/compiler/mlir/lite/common/tfl_pass_config.h index 1149d7841b38fd..e82b9240b0bb10 100644 --- a/tensorflow/compiler/mlir/lite/common/tfl_pass_config.h +++ b/tensorflow/compiler/mlir/lite/common/tfl_pass_config.h @@ -102,6 +102,10 @@ struct PassConfig { // Enables the attempt to directly lower composites into tflite ops. bool enable_composite_direct_lowering = true; + + // When set to true, replace -Inf/+Inf in float constants with the MIN/MAX + // finite float value, so converted constants only contain finite values.
+ bool canonicalizing_inf_as_min_max_float = false; }; inline llvm::raw_ostream& operator<<(llvm::raw_ostream& os, diff --git a/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc index 085478db128a71..79b8ec47877e4a 100644 --- a/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc +++ b/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc @@ -212,6 +212,8 @@ Status ConvertSavedModelToTFLiteFlatBuffer( pass_config.enable_stablehlo_quantizer = toco_flags.has_quantization_config(); pass_config.enable_composite_direct_lowering = toco_flags.enable_composite_direct_lowering(); + pass_config.canonicalizing_inf_as_min_max_float = + toco_flags.canonicalizing_inf_as_min_max_float(); if (toco_flags.qdq_conversion_mode() == "STATIC") { pass_config.quant_specs.qdq_conversion_mode = diff --git a/tensorflow/compiler/mlir/lite/stablehlo/BUILD b/tensorflow/compiler/mlir/lite/stablehlo/BUILD index e4001d4c08b695..48277aa6705b83 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/BUILD +++ b/tensorflow/compiler/mlir/lite/stablehlo/BUILD @@ -203,6 +203,7 @@ cc_library( ":tf_stablehlo", ":unfold_splat_constant_pass", ":unfuse_batch_norm_pass", + "//tensorflow/compiler/mlir/lite:canonicalize_boundary_value", "//tensorflow/compiler/mlir/quantization/stablehlo:bridge_passes", "//tensorflow/compiler/mlir/tensorflow/transforms:tensorflow_passes", "//tensorflow/compiler/mlir/tensorflow/transforms:tf_saved_model_passes", @@ -814,6 +815,7 @@ tf_cc_binary( "//tensorflow/cc/saved_model:loader", "//tensorflow/compiler/mlir:init_mlir", "//tensorflow/compiler/mlir:passes", + "//tensorflow/compiler/mlir/lite:canonicalize_boundary_value", "//tensorflow/compiler/mlir/lite:flatbuffer_export", "//tensorflow/compiler/mlir/lite:tf_to_tfl_flatbuffer", "//tensorflow/compiler/mlir/quantization/tensorflow:quantize_preprocess", diff --git 
a/tensorflow/compiler/mlir/lite/stablehlo/odml_to_stablehlo.cc b/tensorflow/compiler/mlir/lite/stablehlo/odml_to_stablehlo.cc index 28afcd43a03218..d593843f9d76db 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/odml_to_stablehlo.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/odml_to_stablehlo.cc @@ -55,6 +55,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_util.h" #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/transforms.h" #include "tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h" +#include "tensorflow/compiler/mlir/lite/transforms/passes.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_quant_ops.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/quantize_preprocess.h" #include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h" diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo.cc index 96081a2b2b1bd8..6dc68fa3b7ef54 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo.cc @@ -69,6 +69,7 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/reduce.h" #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/util.h" #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h" +#include "tensorflow/compiler/mlir/lite/utils/utils.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" @@ -1998,7 +1999,7 @@ class ConvertReduceOpToTfMax if (mlir::isa(type)) { APFloat const_value(.0); if (failed(GetConstantSplatValue(init_value, const_value)) || - !const_value.isInfinity() || !const_value.isNegative()) + !TFL::IsMinFloatValue(const_value)) return failure(); } else if (mlir::isa(type) && type.isSignlessInteger()) { APInt const_value; @@ -2023,7 +2024,7 @@ class ConvertReduceOpToTfMin if (mlir::isa(type)) { APFloat const_value(.0); if (failed(GetConstantSplatValue(init_value, const_value)) || - !const_value.isInfinity() || const_value.isNegative()) + !TFL::IsMaxFloatValue(const_value)) return failure(); } else if (mlir::isa(type) && type.isSignlessInteger()) { APInt const_value; @@ -2081,7 +2082,7 @@ class ConvertReduceOpToTfArgmax return false; if (mlir::isa(element_type)) { auto value = *attr.value_begin(); - return value.isNegative() && value.isInfinity(); + return TFL::IsMinFloatValue(value); } else if (element_type.isInteger(1)) { auto value = *attr.value_begin(); return value.isZero(); @@ -2105,7 +2106,7 @@ class ConvertReduceOpToTfArgmin return false; if (mlir::isa(element_type)) { auto value = *attr.value_begin(); - return !value.isNegative() && value.isInfinity(); + return TFL::IsMaxFloatValue(value); } else if (element_type.isInteger(1)) { auto value = *attr.value_begin(); return value.isZero(); @@ -2596,14 +2597,7 @@ class ConvertMaxPoolOp : public OpConversionPattern { } APFloat element = float_value.getValues()[0]; - if (!element.isInfinity()) { - return false; - } 
- if (!element.isNegative()) { - return false; - } - - return true; + return TFL::IsMinFloatValue(element); } LogicalResult replaceWithMaxPool(mhlo::ReduceWindowOp op, Value input, diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/transforms.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/transforms.cc index fdbf12538f230e..4ec630a5850ed6 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/transforms.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/transforms.cc @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/rename_entrypoint_to_main.h" #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/smuggle_disallowed_ops.h" #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/tf_stablehlo_pass.h" +#include "tensorflow/compiler/mlir/lite/transforms/passes.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/passes.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_passes.h" diff --git a/tensorflow/compiler/mlir/lite/tests/canonicalize_boundary_value.mlir b/tensorflow/compiler/mlir/lite/tests/canonicalize_boundary_value.mlir new file mode 100644 index 00000000000000..a94588454687c8 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/tests/canonicalize_boundary_value.mlir @@ -0,0 +1,79 @@ +// RUN: tf-opt %s --canonicalize-boundary-value --split-input-file | FileCheck %s + +// CHECK-LABEL: func.func @clamp_neg_inf_f32() -> tensor { +// CHECK: %[[CONST:.*]] = stablehlo.constant dense<-3.40282347E+38> : tensor +// CHECK: return %[[CONST]] : tensor + +func.func @clamp_neg_inf_f32() -> tensor { + %ret = stablehlo.constant dense<0xFF800000> : tensor + return %ret : tensor +} + +// ----- + +// CHECK-LABEL: func.func @clamp_pos_inf_f32() -> tensor { +// CHECK: %[[CONST:.*]] = stablehlo.constant dense<3.40282347E+38> : tensor +// CHECK: return %[[CONST]] : 
tensor +func.func @clamp_pos_inf_f32() -> tensor { + %ret = stablehlo.constant dense<0x7F800000> : tensor + return %ret : tensor +} + +// ----- + +// CHECK-LABEL: func.func @clamp_pos_inf_f32_tensor() -> tensor<1x4xf32> { +// CHECK: %[[CONST:.*]] = stablehlo.constant dense<{{\[\[}}3.40282347E+38, 1.000000e+01, 2.000000e+01, -3.40282347E+38]]> : tensor<1x4xf32> +// CHECK: return %[[CONST]] : tensor<1x4xf32> +func.func @clamp_pos_inf_f32_tensor() -> tensor<1x4xf32> { + %ret = stablehlo.constant dense<[[0x7F800000, 10.0, 20.0, 0xFF800000]]> : tensor<1x4xf32> + return %ret : tensor<1x4xf32> +} + +// ----- + +// CHECK-LABEL: func.func @clamp_neg_inf_f16() -> tensor { +// CHECK: %[[CONST:.*]] = stablehlo.constant dense<-6.550400e+04> : tensor +// CHECK: return %[[CONST]] : tensor +func.func @clamp_neg_inf_f16() -> tensor { + %ret = stablehlo.constant dense<0xFC00> : tensor + return %ret : tensor +} +// ----- + +// CHECK-LABEL: func.func @clamp_neg_inf_bf16() -> tensor { +// CHECK: %[[CONST:.*]] = stablehlo.constant dense<-1.038460e+34> : tensor +// CHECK: return %[[CONST]] : tensor +func.func @clamp_neg_inf_bf16() -> tensor { + %ret = stablehlo.constant dense<0xF800> : tensor + return %ret : tensor +} + +// ----- + +// CHECK-LABEL: func.func @clamp_pos_inf_f16() -> tensor { +// CHECK: %[[CONST:.*]] = stablehlo.constant dense<6.550400e+04> : tensor +// CHECK: return %[[CONST]] : tensor +func.func @clamp_pos_inf_f16() -> tensor { + %ret = stablehlo.constant dense<0x7C00> : tensor + return %ret : tensor +} + +// ----- + +// CHECK-LABEL: func.func @clamp_pos_inf_f16_tensor() -> tensor<1x4xf16> { +// CHECK: %[[CONST:.*]] = stablehlo.constant dense<{{\[\[}}6.550400e+04, 1.000000e+01, 2.000000e+01, -6.550400e+04]]> : tensor<1x4xf16> +// CHECK: return %[[CONST]] : tensor<1x4xf16> +func.func @clamp_pos_inf_f16_tensor() -> tensor<1x4xf16> { + %ret = stablehlo.constant dense<[[0x7C00, 10.0, 20.0, 0xFC00]]> : tensor<1x4xf16> + return %ret : tensor<1x4xf16> +} + +// ----- + +// 
CHECK-LABEL: func.func @clamp_pos_inf_f16_tensor_tf_const() -> tensor<3xf16> { +// CHECK: %[[CONST:.*]] = "tf.Const"() <{value = dense<6.550400e+04> : tensor<3xf16>}> : () -> tensor<3xf16> +// CHECK: return %[[CONST]] : tensor<3xf16> +func.func @clamp_pos_inf_f16_tensor_tf_const() -> tensor<3xf16> { + %ret = "tf.Const"() <{value = dense<0x7C00> : tensor<3xf16>}> : () -> tensor<3xf16> + return %ret : tensor<3xf16> +} diff --git a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc index dd8b345862e3c8..cba05ebded6840 100644 --- a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc +++ b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc @@ -360,6 +360,8 @@ absl::Status ConvertTFExecutorToStablehloFlatbuffer( return status_handler.ConsumeStatus(); } pass_manager.clear(); + if (pass_config.canonicalizing_inf_as_min_max_float) + pass_manager.addPass(mlir::TFL::CreateCanonicalizeBoundaryValuePass()); pass_manager.addPass(mlir::odml::createLegalizeStablehloToVhloPass()); if (failed(pass_manager.run(module))) { return status_handler.Combine( @@ -498,6 +500,8 @@ absl::Status ConvertTFExecutorToTFLOrFlatbuffer( MetadataForReducedPrecisionSupport(quant_specs.support_mask)); } pass_manager.clear(); + if (pass_config.canonicalizing_inf_as_min_max_float) + pass_manager.addPass(mlir::TFL::CreateCanonicalizeBoundaryValuePass()); pass_manager.addPass(mlir::odml::createLegalizeStablehloToVhloPass()); if (failed(pass_manager.run(module))) { return status_handler.Combine( diff --git a/tensorflow/compiler/mlir/lite/transforms/canonicalize_boundary_value_pass.cc b/tensorflow/compiler/mlir/lite/transforms/canonicalize_boundary_value_pass.cc new file mode 100644 index 00000000000000..e444b69f74b30f --- /dev/null +++ b/tensorflow/compiler/mlir/lite/transforms/canonicalize_boundary_value_pass.cc @@ -0,0 +1,104 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. 
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+#include <utility>
+
+#include "llvm/ADT/STLExtras.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"  // from @llvm-project
+#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
+#include "mlir/IR/Builders.h"  // from @llvm-project
+#include "mlir/IR/BuiltinAttributeInterfaces.h"  // from @llvm-project
+#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
+#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypeInterfaces.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
+#include "mlir/IR/Diagnostics.h"  // from @llvm-project
+#include "mlir/IR/PatternMatch.h"  // from @llvm-project
+#include "mlir/Pass/Pass.h"  // from @llvm-project
+#include "mlir/Support/LLVM.h"  // from @llvm-project
+#include "mlir/Support/LogicalResult.h"  // from @llvm-project
+#include "mlir/Transforms/DialectConversion.h"  // from @llvm-project
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
+#include "stablehlo/dialect/StablehloOps.h"  // from @stablehlo
+#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
+
+namespace mlir {
+namespace TFL {
+namespace {
+
+#define DEBUG_TYPE "canonicalize-boundary-value"
+
+#define GEN_PASS_DEF_CANONICALIZEBOUNDARYVALUEPASS
+#include "tensorflow/compiler/mlir/lite/transforms/passes.h.inc"
+
+// Replaces -Inf/+Inf elements of float tensor constants (stablehlo.constant,
+// tf.Const and arith.constant) with the lowest/largest finite value of the
+// element type.
+class CanonicalizeBoundaryValuePass
+    : public impl::CanonicalizeBoundaryValuePassBase<
+          CanonicalizeBoundaryValuePass> {
+  void runOnOperation() override;
+};
+
+// Returns `val` with -Inf/+Inf replaced by the finite value of largest
+// magnitude and the same sign (-MAX/MAX) in `val`'s own float semantics.
+// Finite values and NaNs are returned unchanged.
+APFloat ClampInfToMinMaxFloatValue(const APFloat& val) {
+  if (!val.isInfinity()) return val;
+  return APFloat::getLargest(val.getSemantics(), val.isNegative());
+}
+
+// Clamps constant -Inf/Inf to MIN/MAX float value.
+template <typename OpTy>
+struct ClampInfToMinMaxFloat : public OpRewritePattern<OpTy> {
+  using OpRewritePattern<OpTy>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(OpTy const_op,
+                                PatternRewriter& rewriter) const override {
+    // arith.constant can also carry scalar attributes (e.g. `0 : index` or a
+    // scalar FloatAttr); an unchecked cast to ElementsAttr would assert on
+    // those, so only tensor constants are considered.
+    ElementsAttr tensor_attr = dyn_cast<ElementsAttr>(const_op.getValueAttr());
+    if (!tensor_attr) return failure();
+    ShapedType tensor_type = tensor_attr.getShapedType();
+    if (!isa<FloatType>(tensor_type.getElementType())) return failure();
+
+    // Not every ElementsAttr implementation can be iterated as APFloat;
+    // leave such constants untouched instead of asserting.
+    auto vals_orig = tensor_attr.tryGetValues<APFloat>();
+    if (failed(vals_orig)) return failure();
+
+    // A splat only needs its single value inspected and rewritten, which
+    // avoids materializing every element of a large broadcast constant.
+    if (tensor_attr.isSplat() && tensor_attr.getNumElements() > 0) {
+      APFloat splat_val = *vals_orig->begin();
+      if (!splat_val.isInfinity()) return failure();
+      APFloat clamped = ClampInfToMinMaxFloatValue(splat_val);
+      rewriter.replaceOpWithNewOp<OpTy>(
+          const_op,
+          DenseElementsAttr::get(tensor_type, llvm::ArrayRef<APFloat>(clamped)));
+      return success();
+    }
+
+    // If no value is infinite (all finite or NaN), there is nothing to do.
+    if (llvm::none_of(*vals_orig,
+                      [](const APFloat& val) { return val.isInfinity(); }))
+      return failure();
+
+    SmallVector<APFloat> vals_new(
+        llvm::map_range(*vals_orig, ClampInfToMinMaxFloatValue));
+    rewriter.replaceOpWithNewOp<OpTy>(
+        const_op, DenseElementsAttr::get(tensor_type, vals_new));
+    return success();
+  }
+};
+
+void CanonicalizeBoundaryValuePass::runOnOperation() {
+  MLIRContext* ctx = &getContext();
+
+  RewritePatternSet patterns(ctx);
+  patterns.add<ClampInfToMinMaxFloat<stablehlo::ConstantOp>,
+               ClampInfToMinMaxFloat<TF::ConstOp>,
+               ClampInfToMinMaxFloat<arith::ConstantOp>>(ctx);
+  if (failed(
+          applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) {
+    return signalPassFailure();
+  }
+}
+
+}  // end namespace
+
+std::unique_ptr<OperationPass<ModuleOp>> CreateCanonicalizeBoundaryValuePass() {
+  return std::make_unique<CanonicalizeBoundaryValuePass>();
+}
+
+}  // end namespace TFL
+}  // end namespace mlir
diff --git a/tensorflow/compiler/mlir/lite/transforms/passes.h b/tensorflow/compiler/mlir/lite/transforms/passes.h
index 9e79d40db85fbd..a75808d68000a5 100644
--- a/tensorflow/compiler/mlir/lite/transforms/passes.h
+++ b/tensorflow/compiler/mlir/lite/transforms/passes.h
@@ -237,6 +237,9 @@ std::unique_ptr> CreateReduceTypePrecisionPass();
 // so redudant ones may be
grouped and removed.
 std::unique_ptr> CreatePushTransposeThroughEwisePass();
 
+// Creates a pass that canonicalizes boundary float values, e.g. +/-Inf.
+std::unique_ptr<OperationPass<ModuleOp>> CreateCanonicalizeBoundaryValuePass();
+
 // Creates a pass that brings operations into the same order as graph_info.cc.
 std::unique_ptr> CreatePartitionedTopologicalSortPass();
 
diff --git a/tensorflow/compiler/mlir/lite/transforms/passes.td b/tensorflow/compiler/mlir/lite/transforms/passes.td
index b2ab947b3895b3..9714a2f78de486 100644
--- a/tensorflow/compiler/mlir/lite/transforms/passes.td
+++ b/tensorflow/compiler/mlir/lite/transforms/passes.td
@@ -476,3 +476,8 @@ def PushTransposeThroughEwisePass : Pass<"push-transpose-through-ewise", "mlir::
   let dependentDialects = ["TFL::TensorFlowLiteDialect"];
 }
 
+def CanonicalizeBoundaryValuePass : Pass<"canonicalize-boundary-value", "mlir::ModuleOp"> {
+  let summary = "Canonicalize the IR representations of boundary values";
+  let dependentDialects = ["mlir::stablehlo::StablehloDialect", "TF::TensorFlowDialect", "mlir::arith::ArithDialect"];
+  let constructor = "CreateCanonicalizeBoundaryValuePass()";
+}
diff --git a/tensorflow/compiler/mlir/lite/utils/utils.h b/tensorflow/compiler/mlir/lite/utils/utils.h
index d73bf37ebd748a..0a3248e7d7c62e 100644
--- a/tensorflow/compiler/mlir/lite/utils/utils.h
+++ b/tensorflow/compiler/mlir/lite/utils/utils.h
@@ -41,6 +41,18 @@ using mlir::Operation;
 using mlir::ShapedType;
 using mlir::Value;
 
+// Returns true if `value` is -Inf or the lowest (most negative) finite value.
+inline bool IsMinFloatValue(const APFloat& value) {
+  if (!value.isNegative()) return false;
+  return value.isLargest() || value.isInfinity();
+}
+
+// Returns true if `value` is +Inf or the largest (most positive) finite value.
+inline bool IsMaxFloatValue(const APFloat& value) {
+  if (value.isNegative()) return false;
+  return value.isLargest() || value.isInfinity();
+}
+
 // Returns true if all tensor value in `values` has static shape and same shape.
inline bool OpHasSameStaticShapes(Operation* op) { auto values = op->getOperands(); diff --git a/tensorflow/lite/python/convert.py b/tensorflow/lite/python/convert.py index 32fc17a1ce5ae3..c12b8f3682377c 100644 --- a/tensorflow/lite/python/convert.py +++ b/tensorflow/lite/python/convert.py @@ -596,6 +596,7 @@ def build_conversion_flags( qdq_conversion_mode=None, disable_per_channel_quantization_for_dense_layers=False, enable_composite_direct_lowering=False, + canonicalizing_inf_as_min_max_float=False, **_, ): """Builds protocol buffer describing a conversion of a model. @@ -727,6 +728,8 @@ def build_conversion_flags( layers. The flag works only for integer quantized model. enable_composite_direct_lowering: If set, attempts to lower composite ops directly to tflite ops. + canonicalizing_inf_as_min_max_float: When set to true, convert +Inf/-Inf to + MIN/MAX float value and output of converter only contains finite values. Returns: conversion_flags: protocol buffer describing the conversion process. 
@@ -850,6 +853,9 @@ def build_conversion_flags( conversion_flags.enable_composite_direct_lowering = ( enable_composite_direct_lowering ) + conversion_flags.canonicalizing_inf_as_min_max_float = ( + canonicalizing_inf_as_min_max_float + ) return conversion_flags diff --git a/tensorflow/lite/python/lite.py b/tensorflow/lite/python/lite.py index b6cfbaa60f632e..c5503a1307582f 100644 --- a/tensorflow/lite/python/lite.py +++ b/tensorflow/lite/python/lite.py @@ -672,6 +672,7 @@ def __init__(self): self._experimental_qdq_conversion_mode = None self._experimental_disable_per_channel_quantization_for_dense_layers = False self._experimental_enable_composite_direct_lowering = False + self._experimental_canonicalizing_inf_as_min_max_float = False # Debug parameters self.ir_dump_dir = None @@ -832,6 +833,9 @@ def _get_base_converter_args(self): "enable_composite_direct_lowering": ( self._experimental_enable_composite_direct_lowering ), + "canonicalizing_inf_as_min_max_float": ( + self._experimental_canonicalizing_inf_as_min_max_float + ), } if self.saved_model_dir: diff --git a/tensorflow/lite/python/lite_v2_test.py b/tensorflow/lite/python/lite_v2_test.py index afd4a46ec76a00..d2e21365c95734 100644 --- a/tensorflow/lite/python/lite_v2_test.py +++ b/tensorflow/lite/python/lite_v2_test.py @@ -5519,5 +5519,43 @@ def testSavedModelSignatureDefs(self): self.assertEqual(list(signature_defs['mul_add']['outputs']), ['output_0']) +class BoundaryValueTest(lite_v2_test_util.ModelTest): + + @parameterized.named_parameters( + ('EnableCanonicalizeInfAsMaxMinFloatFromSavedModel', True, True), + ('DisableCanonicalizeInfAsMaxMinFloatFromSavedModel', False, True), + ('EnableCanonicalizeInfAsMaxMinFloatFromConcreteFunc', True, False), + ('DisableCanonicalizeInfAsMaxMinFloatFromConcreteFunc', False, False), + ) + @test_util.run_v2_only + def testFloatBoundaryValue(self, is_canonicalized, is_from_saved_model): + root = self._getInfFloatModel() + input_data = None + concrete_func = 
root.f.get_concrete_function(input_data) + + mdl = tf.Module() + mdl.f = concrete_func + + def _get_converter() -> lite.TFLiteConverterV2: + if is_from_saved_model: + save_dir = os.path.join(self.get_temp_dir(), 'saved_model') + tf.saved_model.save(mdl, save_dir) + return lite.TFLiteConverterV2.from_saved_model(save_dir) + return lite.TFLiteConverterV2.from_concrete_functions( + [concrete_func], root + ) + + converter = _get_converter() + converter._experimental_canonicalizing_inf_as_min_max_float = ( + is_canonicalized + ) + tflite_model = converter.convert() + + # Check output value from converted model. + expected_value = [np.finfo(np.float32).max if is_canonicalized else np.inf] + actual_value = self._evaluateTFLiteModel(tflite_model, [input_data]) + self.assertEqual(expected_value, actual_value) + + if __name__ == '__main__': test.main() diff --git a/tensorflow/lite/python/lite_v2_test_util.py b/tensorflow/lite/python/lite_v2_test_util.py index 2a4a781990e708..e831325364610c 100644 --- a/tensorflow/lite/python/lite_v2_test_util.py +++ b/tensorflow/lite/python/lite_v2_test_util.py @@ -287,3 +287,9 @@ def calibration_gen(): outputs = ReadAssign(number_of_states)(inputs) model = tf.keras.Model(inputs, outputs) return model, calibration_gen + + def _getInfFloatModel(self): + root = autotrackable.AutoTrackable() + root.v = constant_op.constant([np.inf], shape=(), dtype=dtypes.float32) + root.f = def_function.function(lambda x: root.v) + return root diff --git a/tensorflow/lite/toco/toco_flags.proto b/tensorflow/lite/toco/toco_flags.proto index 1760841a333f6a..9430c43a51fd0a 100644 --- a/tensorflow/lite/toco/toco_flags.proto +++ b/tensorflow/lite/toco/toco_flags.proto @@ -41,7 +41,7 @@ enum FileFormat { // of as properties of models, instead describing how models are to be // processed in the context of the present tooling job. // -// Next ID to use: 64. +// Next ID to use: 65. 
message TocoFlags { // Input file format optional FileFormat input_format = 1; @@ -360,4 +360,8 @@ message TocoFlags { // Enables the attempt to directly lower composites into tflite ops. // WARNING: Experimental interface, subject to change. optional bool enable_composite_direct_lowering = 63 [default = false]; + + // When set to true, replace -Inf/+Inf in float constants with the MIN/MAX + // finite float value, so converted constants only contain finite values. + optional bool canonicalizing_inf_as_min_max_float = 64 [default = false]; }