[go: nahoru, domu]

Skip to content

Commit

Permalink
This change is to provide users with the option `canonicalizing_inf_a…
Browse files Browse the repository at this point in the history
…s_min_max_float` to canonicalize boundary values in the IR replacing -Inf/Inf with MIN/MAX float value. An option is provided to the user to enable this option.

PiperOrigin-RevId: 630372176
  • Loading branch information
tensorflower-gardener committed Jun 22, 2024
1 parent eea8412 commit fe6647a
Show file tree
Hide file tree
Showing 18 changed files with 343 additions and 13 deletions.
27 changes: 27 additions & 0 deletions tensorflow/compiler/mlir/lite/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -646,6 +646,32 @@ tf_cc_test(
],
)

# Library for the boundary-value canonicalization pass implemented in
# transforms/canonicalize_boundary_value_pass.cc. The pass is declared in
# transforms/passes.h, which is why that header is exported here.
cc_library(
name = "canonicalize_boundary_value",
srcs = [
"transforms/canonicalize_boundary_value_pass.cc",
],
hdrs = [
"transforms/passes.h",
],
deps = [
":tensorflow_lite",
":tensorflow_lite_passes_inc_gen",
"//tensorflow/compiler/mlir/quantization/common/quantization_lib:quantization_config",
"//tensorflow/compiler/mlir/tensorflow",
"@com_google_absl//absl/container:flat_hash_set",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:DialectUtils",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TransformUtils",
"@stablehlo//:stablehlo_ops",
],
)

cc_library(
name = "tensorflow_lite_legalize_tf",
srcs = [
Expand Down Expand Up @@ -1368,6 +1394,7 @@ cc_library(
"tf_tfl_passes.h",
],
deps = [
":canonicalize_boundary_value",
":common",
":fake_quant_utils",
":tensorflow_lite_d2s",
Expand Down
4 changes: 4 additions & 0 deletions tensorflow/compiler/mlir/lite/common/tfl_pass_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,10 @@ struct PassConfig {

// Enables the attempt to directly lower composites into tflite ops.
bool enable_composite_direct_lowering = true;

// When set to true, converts -Inf/+Inf to the MIN/MAX finite float value so
// that the converted output only contains finite values.
bool canonicalizing_inf_as_min_max_float = true;
};

inline llvm::raw_ostream& operator<<(llvm::raw_ostream& os,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,8 @@ absl::Status ConvertGraphDefToTFLiteFlatBuffer(
pass_config.guarantee_all_funcs_one_use =
toco_flags.guarantee_all_funcs_one_use();
pass_config.enable_stablehlo_conversion = toco_flags.convert_to_stablehlo();
pass_config.canonicalizing_inf_as_min_max_float =
toco_flags.canonicalizing_inf_as_min_max_float();

// StableHLO Quantizer is not supported for GraphDef inputs, so
// quantization_py_function_lib is set to nullptr.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,8 @@ Status ConvertSavedModelToTFLiteFlatBuffer(
pass_config.enable_stablehlo_quantizer = toco_flags.has_quantization_config();
pass_config.enable_composite_direct_lowering =
toco_flags.enable_composite_direct_lowering();
pass_config.canonicalizing_inf_as_min_max_float =
toco_flags.canonicalizing_inf_as_min_max_float();

if (toco_flags.qdq_conversion_mode() == "STATIC") {
pass_config.quant_specs.qdq_conversion_mode =
Expand Down
2 changes: 2 additions & 0 deletions tensorflow/compiler/mlir/lite/stablehlo/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,7 @@ cc_library(
":tf_stablehlo",
":unfold_splat_constant_pass",
":unfuse_batch_norm_pass",
"//tensorflow/compiler/mlir/lite:canonicalize_boundary_value",
"//tensorflow/compiler/mlir/quantization/stablehlo:bridge_passes",
"//tensorflow/compiler/mlir/tensorflow/transforms:tensorflow_passes",
"//tensorflow/compiler/mlir/tensorflow/transforms:tf_saved_model_passes",
Expand Down Expand Up @@ -813,6 +814,7 @@ tf_cc_binary(
"//tensorflow/cc/saved_model:loader",
"//tensorflow/compiler/mlir:init_mlir",
"//tensorflow/compiler/mlir:passes",
"//tensorflow/compiler/mlir/lite:canonicalize_boundary_value",
"//tensorflow/compiler/mlir/lite:flatbuffer_export",
"//tensorflow/compiler/mlir/lite:tf_to_tfl_flatbuffer",
"//tensorflow/compiler/mlir/quantization/tensorflow:quantize_preprocess",
Expand Down
18 changes: 6 additions & 12 deletions tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo.cc
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/reduce.h"
#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/util.h"
#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h"
#include "tensorflow/compiler/mlir/lite/utils/utils.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
Expand Down Expand Up @@ -1998,7 +1999,7 @@ class ConvertReduceOpToTfMax
if (mlir::isa<FloatType>(type)) {
APFloat const_value(.0);
if (failed(GetConstantSplatValue(init_value, const_value)) ||
!const_value.isInfinity() || !const_value.isNegative())
!TFL::IsNegInfiniteValue(const_value))
return failure();
} else if (mlir::isa<IntegerType>(type) && type.isSignlessInteger()) {
APInt const_value;
Expand All @@ -2023,7 +2024,7 @@ class ConvertReduceOpToTfMin
if (mlir::isa<FloatType>(type)) {
APFloat const_value(.0);
if (failed(GetConstantSplatValue(init_value, const_value)) ||
!const_value.isInfinity() || const_value.isNegative())
!TFL::IsPosInfiniteValue(const_value))
return failure();
} else if (mlir::isa<IntegerType>(type) && type.isSignlessInteger()) {
APInt const_value;
Expand Down Expand Up @@ -2081,7 +2082,7 @@ class ConvertReduceOpToTfArgmax
return false;
if (mlir::isa<FloatType>(element_type)) {
auto value = *attr.value_begin<APFloat>();
return value.isNegative() && value.isInfinity();
return TFL::IsNegInfiniteValue(value);
} else if (element_type.isInteger(1)) {
auto value = *attr.value_begin<APInt>();
return value.isZero();
Expand All @@ -2105,7 +2106,7 @@ class ConvertReduceOpToTfArgmin
return false;
if (mlir::isa<FloatType>(element_type)) {
auto value = *attr.value_begin<APFloat>();
return !value.isNegative() && value.isInfinity();
return TFL::IsPosInfiniteValue(value);
} else if (element_type.isInteger(1)) {
auto value = *attr.value_begin<APInt>();
return value.isZero();
Expand Down Expand Up @@ -2596,14 +2597,7 @@ class ConvertMaxPoolOp : public OpConversionPattern<mhlo::ReduceWindowOp> {
}

APFloat element = float_value.getValues<APFloat>()[0];
if (!element.isInfinity()) {
return false;
}
if (!element.isNegative()) {
return false;
}

return true;
return TFL::IsNegInfiniteValue(element);
}

LogicalResult replaceWithMaxPool(mhlo::ReduceWindowOp op, Value input,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/rename_entrypoint_to_main.h"
#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/smuggle_disallowed_ops.h"
#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/tf_stablehlo_pass.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_passes.h"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
// RUN: tf-opt %s --canonicalize-boundary-value --split-input-file | FileCheck %s

// CHECK-LABEL: func.func @clamp_neg_inf_f32() -> tensor<f32> {
// CHECK: %[[CONST:.*]] = stablehlo.constant dense<-3.40282347E+38> : tensor<f32>
// CHECK: return %[[CONST]] : tensor<f32>

func.func @clamp_neg_inf_f32() -> tensor<f32> {
%ret = stablehlo.constant dense<0xFF800000> : tensor<f32>
return %ret : tensor<f32>
}

// -----

// CHECK-LABEL: func.func @clamp_pos_inf_f32() -> tensor<f32> {
// CHECK: %[[CONST:.*]] = stablehlo.constant dense<3.40282347E+38> : tensor<f32>
// CHECK: return %[[CONST]] : tensor<f32>
func.func @clamp_pos_inf_f32() -> tensor<f32> {
%ret = stablehlo.constant dense<0x7F800000> : tensor<f32>
return %ret : tensor<f32>
}

// -----

// CHECK-LABEL: func.func @clamp_pos_inf_f32_tensor() -> tensor<1x4xf32> {
// CHECK: %[[CONST:.*]] = stablehlo.constant dense<{{\[\[}}3.40282347E+38, 1.000000e+01, 2.000000e+01, -3.40282347E+38]]> : tensor<1x4xf32>
// CHECK: return %[[CONST]] : tensor<1x4xf32>
func.func @clamp_pos_inf_f32_tensor() -> tensor<1x4xf32> {
%ret = stablehlo.constant dense<[[0x7F800000, 10.0, 20.0, 0xFF800000]]> : tensor<1x4xf32>
return %ret : tensor<1x4xf32>
}

// -----

// CHECK-LABEL: func.func @clamp_neg_inf_f16() -> tensor<f16> {
// CHECK: %[[CONST:.*]] = stablehlo.constant dense<-6.550400e+04> : tensor<f16>
// CHECK: return %[[CONST]] : tensor<f16>
func.func @clamp_neg_inf_f16() -> tensor<f16> {
%ret = stablehlo.constant dense<0xFC00> : tensor<f16>
return %ret : tensor<f16>
}
// -----

// bf16 has 1 sign, 8 exponent and 7 mantissa bits, so -Inf is encoded as
// 0xFF80 (0xF800 is the finite value -2^113 and would never be clamped).
// CHECK-LABEL: func.func @clamp_neg_inf_bf16() -> tensor<bf16> {
// CHECK: %[[CONST:.*]] = stablehlo.constant dense<-3.389530e+38> : tensor<bf16>
// CHECK: return %[[CONST]] : tensor<bf16>
func.func @clamp_neg_inf_bf16() -> tensor<bf16> {
  %ret = stablehlo.constant dense<0xFF80> : tensor<bf16>
  return %ret : tensor<bf16>
}

// -----

// CHECK-LABEL: func.func @clamp_pos_inf_f16() -> tensor<f16> {
// CHECK: %[[CONST:.*]] = stablehlo.constant dense<6.550400e+04> : tensor<f16>
// CHECK: return %[[CONST]] : tensor<f16>
func.func @clamp_pos_inf_f16() -> tensor<f16> {
%ret = stablehlo.constant dense<0x7C00> : tensor<f16>
return %ret : tensor<f16>
}

// -----

// CHECK-LABEL: func.func @clamp_pos_inf_f16_tensor() -> tensor<1x4xf16> {
// CHECK: %[[CONST:.*]] = stablehlo.constant dense<{{\[\[}}6.550400e+04, 1.000000e+01, 2.000000e+01, -6.550400e+04]]> : tensor<1x4xf16>
// CHECK: return %[[CONST]] : tensor<1x4xf16>
func.func @clamp_pos_inf_f16_tensor() -> tensor<1x4xf16> {
%ret = stablehlo.constant dense<[[0x7C00, 10.0, 20.0, 0xFC00]]> : tensor<1x4xf16>
return %ret : tensor<1x4xf16>
}

// -----

// CHECK-LABEL: func.func @clamp_pos_inf_f16_tensor_tf_const() -> tensor<3xf16> {
// CHECK: %[[CONST:.*]] = "tf.Const"() <{value = dense<6.550400e+04> : tensor<3xf16>}> : () -> tensor<3xf16>
// CHECK: return %[[CONST]] : tensor<3xf16>
func.func @clamp_pos_inf_f16_tensor_tf_const() -> tensor<3xf16> {
%ret = "tf.Const"() <{value = dense<0x7C00> : tensor<3xf16>}> : () -> tensor<3xf16>
return %ret : tensor<3xf16>
}


// -----

// CHECK-LABEL: func.func @clamp_pos_inf_f16_arith_op() -> f16 {
// CHECK: %[[CONST:.*]] = arith.constant 6.550400e+04 : f16
// CHECK: return %[[CONST]] : f16
func.func @clamp_pos_inf_f16_arith_op() -> f16 {
%ret = arith.constant 0x7C00 : f16
return %ret : f16
}

10 changes: 10 additions & 0 deletions tensorflow/compiler/mlir/lite/tf_tfl_passes.cc
Original file line number Diff line number Diff line change
Expand Up @@ -500,6 +500,11 @@ void AddPostVariableFreezingTFToTFLConversionPasses(
pass_manager->addPass(mlir::TFL::CreateReduceTypePrecisionPass());
}

// This pass should always run before the end of the model conversion, but
// not after the CreateSplitMergedOperandsPass below.
if (pass_config.canonicalizing_inf_as_min_max_float)
pass_manager->addPass(mlir::TFL::CreateCanonicalizeBoundaryValuePass());

// This pass should be always at the end of the model
// conversion (even after quantization). Some TFL ops like unidirectional
// sequence lstm will have stateful operands and some optimization passes
Expand All @@ -514,7 +519,12 @@ void AddPostVariableFreezingTFToTFLConversionPasses(
// model dialect.
pass_manager->addPass(
mlir::TFL::CreateInsertCallOnceOpFromSessionInitializerPass());
} else {
// This pass should always run before the end of the model conversion.
if (pass_config.canonicalizing_inf_as_min_max_float)
pass_manager->addPass(mlir::TFL::CreateCanonicalizeBoundaryValuePass());
}

if (pass_config.unfold_large_splat_constant) {
pass_manager->addPass(mlir::TFL::CreateUnfoldLargeSplatConstantPass());
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <memory>
#include <utility>

#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Casting.h"
#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/BuiltinAttributeInterfaces.h" // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/IR/Diagnostics.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "mlir/Transforms/DialectConversion.h" // from @llvm-project
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project
#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"

namespace mlir {
namespace TFL {
namespace {

#define DEBUG_TYPE "canonicalize-boundary-value"

#define GEN_PASS_DEF_CANONICALIZEBOUNDARYVALUEPASS
#include "tensorflow/compiler/mlir/lite/transforms/passes.h.inc"

// Pass that rewrites infinite float constants into finite boundary values.
// The rewrite logic itself lives in runOnOperation().
class CanonicalizeBoundaryValuePass
    : public impl::CanonicalizeBoundaryValuePassBase<
          CanonicalizeBoundaryValuePass> {
 private:
  // Entry point invoked by the pass manager.
  void runOnOperation() override;
};

// Clamps constant -Inf/Inf to the MIN/MAX finite value of the float type.
//
// `OpTy` is a constant-like op exposing `getValueAttr()` and buildable from a
// single value attribute. Two value forms are handled:
//   * a scalar `FloatAttr` that is infinite, and
//   * a float `ElementsAttr` containing at least one infinite element.
// In every other case (no infinity, non-float element type, or an elements
// attribute that cannot be iterated as `APFloat`) the pattern returns
// failure() and leaves the op untouched.
template <typename OpTy>
struct ClampInfToMinMaxFloat : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy const_op,
                                PatternRewriter& rewriter) const override {
    Attribute attr = const_op.getValueAttr();

    // Scalar case: a single infinite float value.
    if (auto float_attr = llvm::dyn_cast<FloatAttr>(attr)) {
      if (float_attr.getValue().isInfinity()) {
        FloatType float_type = llvm::dyn_cast<FloatType>(const_op.getType());
        if (!float_type) return failure();
        rewriter.replaceOpWithNewOp<OpTy>(
            const_op, rewriter.getFloatAttr(
                          float_type, APFloat::getLargest(
                                          float_type.getFloatSemantics(),
                                          float_attr.getValue().isNegative())));
        return success();
      }
    }

    // Tensor case: an elements attribute with a float element type.
    ElementsAttr tensor_attr = llvm::dyn_cast<ElementsAttr>(attr);
    if (!tensor_attr) return failure();

    ShapedType tensor_type = llvm::cast<ShapedType>(tensor_attr.getType());
    auto float_type = llvm::dyn_cast<FloatType>(tensor_type.getElementType());
    if (!float_type) return failure();

    // Not every ElementsAttr supports iteration as APFloat; getValues<APFloat>()
    // would assert on those, so probe first and decline to match instead.
    auto vals_orig = tensor_attr.tryGetValues<APFloat>();
    if (failed(vals_orig)) return failure();

    // If all values are finite, no need to rewrite. Returning failure here is
    // also what lets the greedy driver reach a fixed point after one rewrite.
    if (llvm::none_of(*vals_orig,
                      [](const APFloat& val) { return val.isInfinity(); }))
      return failure();

    SmallVector<APFloat> vals_new;
    vals_new.reserve(tensor_type.getNumElements());
    for (const APFloat& val : *vals_orig) {
      vals_new.push_back(
          val.isInfinity()
              ? APFloat::getLargest(float_type.getFloatSemantics(),
                                    val.isNegative())
              : val);
    }
    rewriter.replaceOpWithNewOp<OpTy>(
        const_op, DenseElementsAttr::get(tensor_type, vals_new));
    return success();
  }
};

void CanonicalizeBoundaryValuePass::runOnOperation() {
auto* ctx = &getContext();

RewritePatternSet patterns(ctx);
patterns.add<ClampInfToMinMaxFloat<stablehlo::ConstantOp>,
ClampInfToMinMaxFloat<TF::ConstOp>,
ClampInfToMinMaxFloat<arith::ConstantOp>>(ctx);
if (failed(
applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) {
return signalPassFailure();
}
}

} // end namespace

// Factory for the boundary-value canonicalization pass; the caller owns the
// returned pass instance.
std::unique_ptr<OperationPass<ModuleOp>> CreateCanonicalizeBoundaryValuePass() {
  std::unique_ptr<OperationPass<ModuleOp>> pass =
      std::make_unique<CanonicalizeBoundaryValuePass>();
  return pass;
}

} // end namespace TFL
} // end namespace mlir
Loading

0 comments on commit fe6647a

Please sign in to comment.