

Commit

This change canonicalizes boundary values in the IR, replacing -Inf/+Inf with the MIN/MAX finite float value.

PiperOrigin-RevId: 630372176
tensorflower-gardener committed May 22, 2024
1 parent c662838 commit b59bae9
Showing 18 changed files with 309 additions and 13 deletions.
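At its core, the change rewrites each infinite element of a floating-point constant to the largest finite value of the same sign. A minimal standalone sketch of that substitution using LLVM's APFloat (the function name here is illustrative, not from the commit):

#include "llvm/ADT/APFloat.h"

// Canonicalize one element of a float constant: +/-Inf becomes the largest
// finite value of the same sign; finite values pass through unchanged.
llvm::APFloat CanonicalizeElement(const llvm::APFloat& val) {
  if (!val.isInfinity()) return val;
  return llvm::APFloat::getLargest(val.getSemantics(), val.isNegative());
}

This is the per-element mapping the new pass applies across StableHLO, TF, and Arith constant ops, as the diff below shows.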
27 changes: 27 additions & 0 deletions tensorflow/compiler/mlir/lite/BUILD
@@ -644,6 +644,32 @@ tf_cc_test(
],
)

cc_library(
name = "canonicalize_boundary_value",
srcs = [
"transforms/canonicalize_boundary_value_pass.cc",
],
hdrs = [
"transforms/passes.h",
],
deps = [
":tensorflow_lite",
":tensorflow_lite_passes_inc_gen",
"//tensorflow/compiler/mlir/quantization/common/quantization_lib:quantization_config",
"//tensorflow/compiler/mlir/tensorflow",
"@com_google_absl//absl/container:flat_hash_set",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:DialectUtils",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TransformUtils",
"@stablehlo//:stablehlo_ops",
],
)

cc_library(
name = "tensorflow_lite_legalize_tf",
srcs = [
@@ -1365,6 +1391,7 @@ cc_library(
"tf_tfl_passes.h",
],
deps = [
":canonicalize_boundary_value",
":common",
":fake_quant_utils",
":tensorflow_lite_d2s",
4 changes: 4 additions & 0 deletions tensorflow/compiler/mlir/lite/common/tfl_pass_config.h
@@ -102,6 +102,10 @@ struct PassConfig {

// Enables the attempt to directly lower composites into tflite ops.
bool enable_composite_direct_lowering = true;

// When set to true, converts +Inf/-Inf to the MIN/MAX float value so that
// the converter's output contains only finite values.
bool canonicalizing_inf_as_min_max_float = false;
};

inline llvm::raw_ostream& operator<<(llvm::raw_ostream& os,
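Enabling the behavior is a one-field opt-in on PassConfig. A minimal sketch, assuming a config and pass manager already built by the converter (ConfigureForFiniteOutput is a hypothetical helper, not part of this commit):

#include "mlir/Pass/PassManager.h"
#include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"

// Hypothetical wiring: set the new flag, then schedule the pass exactly as
// tf_to_tfl_flatbuffer.cc does further below.
void ConfigureForFiniteOutput(mlir::TFL::PassConfig& pass_config,
                              mlir::PassManager& pass_manager) {
  pass_config.canonicalizing_inf_as_min_max_float = true;
  if (pass_config.canonicalizing_inf_as_min_max_float)
    pass_manager.addPass(mlir::TFL::CreateCanonicalizeBoundaryValuePass());
}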
@@ -212,6 +212,8 @@ Status ConvertSavedModelToTFLiteFlatBuffer(
pass_config.enable_stablehlo_quantizer = toco_flags.has_quantization_config();
pass_config.enable_composite_direct_lowering =
toco_flags.enable_composite_direct_lowering();
pass_config.canonicalizing_inf_as_min_max_float =
toco_flags.canonicalizing_inf_as_min_max_float();

if (toco_flags.qdq_conversion_mode() == "STATIC") {
pass_config.quant_specs.qdq_conversion_mode =
2 changes: 2 additions & 0 deletions tensorflow/compiler/mlir/lite/stablehlo/BUILD
@@ -203,6 +203,7 @@ cc_library(
":tf_stablehlo",
":unfold_splat_constant_pass",
":unfuse_batch_norm_pass",
"//tensorflow/compiler/mlir/lite:canonicalize_boundary_value",
"//tensorflow/compiler/mlir/quantization/stablehlo:bridge_passes",
"//tensorflow/compiler/mlir/tensorflow/transforms:tensorflow_passes",
"//tensorflow/compiler/mlir/tensorflow/transforms:tf_saved_model_passes",
@@ -814,6 +815,7 @@ tf_cc_binary(
"//tensorflow/cc/saved_model:loader",
"//tensorflow/compiler/mlir:init_mlir",
"//tensorflow/compiler/mlir:passes",
"//tensorflow/compiler/mlir/lite:canonicalize_boundary_value",
"//tensorflow/compiler/mlir/lite:flatbuffer_export",
"//tensorflow/compiler/mlir/lite:tf_to_tfl_flatbuffer",
"//tensorflow/compiler/mlir/quantization/tensorflow:quantize_preprocess",
@@ -55,6 +55,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_util.h"
#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/transforms.h"
#include "tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_quant_ops.h"
#include "tensorflow/compiler/mlir/quantization/tensorflow/quantize_preprocess.h"
#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h"
18 changes: 6 additions & 12 deletions tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo.cc
@@ -69,6 +69,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/reduce.h"
#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/util.h"
#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h"
#include "tensorflow/compiler/mlir/lite/utils/utils.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
@@ -1998,7 +1999,7 @@ class ConvertReduceOpToTfMax
if (mlir::isa<FloatType>(type)) {
APFloat const_value(.0);
if (failed(GetConstantSplatValue(init_value, const_value)) ||
!const_value.isInfinity() || !const_value.isNegative())
!TFL::IsMinFloatValue(const_value))
return failure();
} else if (mlir::isa<IntegerType>(type) && type.isSignlessInteger()) {
APInt const_value;
@@ -2023,7 +2024,7 @@ class ConvertReduceOpToTfMin
if (mlir::isa<FloatType>(type)) {
APFloat const_value(.0);
if (failed(GetConstantSplatValue(init_value, const_value)) ||
!const_value.isInfinity() || const_value.isNegative())
!TFL::IsMaxFloatValue(const_value))
return failure();
} else if (mlir::isa<IntegerType>(type) && type.isSignlessInteger()) {
APInt const_value;
@@ -2081,7 +2082,7 @@ class ConvertReduceOpToTfArgmax
return false;
if (mlir::isa<FloatType>(element_type)) {
auto value = *attr.value_begin<APFloat>();
return value.isNegative() && value.isInfinity();
return TFL::IsMinFloatValue(value);
} else if (element_type.isInteger(1)) {
auto value = *attr.value_begin<APInt>();
return value.isZero();
@@ -2105,7 +2106,7 @@ class ConvertReduceOpToTfArgmin
return false;
if (mlir::isa<FloatType>(element_type)) {
auto value = *attr.value_begin<APFloat>();
return !value.isNegative() && value.isInfinity();
return TFL::IsMaxFloatValue(value);
} else if (element_type.isInteger(1)) {
auto value = *attr.value_begin<APInt>();
return value.isZero();
@@ -2596,14 +2597,7 @@ class ConvertMaxPoolOp : public OpConversionPattern<mhlo::ReduceWindowOp> {
}

APFloat element = float_value.getValues<APFloat>()[0];
if (!element.isInfinity()) {
return false;
}
if (!element.isNegative()) {
return false;
}

return true;
return TFL::IsMinFloatValue(element);
}

LogicalResult replaceWithMaxPool(mhlo::ReduceWindowOp op, Value input,
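Each of the matchers above previously required the reduction's init value to be exactly +/-Inf. Since the new pass may already have rewritten those infinities into the lowest/largest finite value, the matchers now delegate to the shared TFL::IsMinFloatValue/IsMaxFloatValue helpers (defined in utils.h below), which accept either form.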
@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/rename_entrypoint_to_main.h"
#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/smuggle_disallowed_ops.h"
#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/tf_stablehlo_pass.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_passes.h"
@@ -0,0 +1,79 @@
// RUN: tf-opt %s --canonicalize-boundary-value --split-input-file | FileCheck %s

// CHECK-LABEL: func.func @clamp_neg_inf_f32() -> tensor<f32> {
// CHECK: %[[CONST:.*]] = stablehlo.constant dense<-3.40282347E+38> : tensor<f32>
// CHECK: return %[[CONST]] : tensor<f32>

func.func @clamp_neg_inf_f32() -> tensor<f32> {
%ret = stablehlo.constant dense<0xFF800000> : tensor<f32>
return %ret : tensor<f32>
}

// -----

// CHECK-LABEL: func.func @clamp_pos_inf_f32() -> tensor<f32> {
// CHECK: %[[CONST:.*]] = stablehlo.constant dense<3.40282347E+38> : tensor<f32>
// CHECK: return %[[CONST]] : tensor<f32>
func.func @clamp_pos_inf_f32() -> tensor<f32> {
%ret = stablehlo.constant dense<0x7F800000> : tensor<f32>
return %ret : tensor<f32>
}

// -----

// CHECK-LABEL: func.func @clamp_pos_inf_f32_tensor() -> tensor<1x4xf32> {
// CHECK: %[[CONST:.*]] = stablehlo.constant dense<{{\[\[}}3.40282347E+38, 1.000000e+01, 2.000000e+01, -3.40282347E+38]]> : tensor<1x4xf32>
// CHECK: return %[[CONST]] : tensor<1x4xf32>
func.func @clamp_pos_inf_f32_tensor() -> tensor<1x4xf32> {
%ret = stablehlo.constant dense<[[0x7F800000, 10.0, 20.0, 0xFF800000]]> : tensor<1x4xf32>
return %ret : tensor<1x4xf32>
}

// -----

// CHECK-LABEL: func.func @clamp_neg_inf_f16() -> tensor<f16> {
// CHECK: %[[CONST:.*]] = stablehlo.constant dense<-6.550400e+04> : tensor<f16>
// CHECK: return %[[CONST]] : tensor<f16>
func.func @clamp_neg_inf_f16() -> tensor<f16> {
%ret = stablehlo.constant dense<0xFC00> : tensor<f16>
return %ret : tensor<f16>
}
// -----

// CHECK-LABEL: func.func @clamp_neg_inf_bf16() -> tensor<bf16> {
// CHECK: %[[CONST:.*]] = stablehlo.constant dense<-1.038460e+34> : tensor<bf16>
// CHECK: return %[[CONST]] : tensor<bf16>
func.func @clamp_neg_inf_bf16() -> tensor<bf16> {
%ret = stablehlo.constant dense<0xF800> : tensor<bf16>
return %ret : tensor<bf16>
}

// -----

// CHECK-LABEL: func.func @clamp_pos_inf_f16() -> tensor<f16> {
// CHECK: %[[CONST:.*]] = stablehlo.constant dense<6.550400e+04> : tensor<f16>
// CHECK: return %[[CONST]] : tensor<f16>
func.func @clamp_pos_inf_f16() -> tensor<f16> {
%ret = stablehlo.constant dense<0x7C00> : tensor<f16>
return %ret : tensor<f16>
}

// -----

// CHECK-LABEL: func.func @clamp_pos_inf_f16_tensor() -> tensor<1x4xf16> {
// CHECK: %[[CONST:.*]] = stablehlo.constant dense<{{\[\[}}6.550400e+04, 1.000000e+01, 2.000000e+01, -6.550400e+04]]> : tensor<1x4xf16>
// CHECK: return %[[CONST]] : tensor<1x4xf16>
func.func @clamp_pos_inf_f16_tensor() -> tensor<1x4xf16> {
%ret = stablehlo.constant dense<[[0x7C00, 10.0, 20.0, 0xFC00]]> : tensor<1x4xf16>
return %ret : tensor<1x4xf16>
}

// -----

// CHECK-LABEL: func.func @clamp_pos_inf_f16_tensor_tf_const() -> tensor<3xf16> {
// CHECK: %[[CONST:.*]] = "tf.Const"() <{value = dense<6.550400e+04> : tensor<3xf16>}> : () -> tensor<3xf16>
// CHECK: return %[[CONST]] : tensor<3xf16>
func.func @clamp_pos_inf_f16_tensor_tf_const() -> tensor<3xf16> {
%ret = "tf.Const"() <{value = dense<0x7C00> : tensor<3xf16>}> : () -> tensor<3xf16>
return %ret : tensor<3xf16>
}
4 changes: 4 additions & 0 deletions tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc
@@ -360,6 +360,8 @@ absl::Status ConvertTFExecutorToStablehloFlatbuffer(
return status_handler.ConsumeStatus();
}
pass_manager.clear();
if (pass_config.canonicalizing_inf_as_min_max_float)
pass_manager.addPass(mlir::TFL::CreateCanonicalizeBoundaryValuePass());
pass_manager.addPass(mlir::odml::createLegalizeStablehloToVhloPass());
if (failed(pass_manager.run(module))) {
return status_handler.Combine(
@@ -498,6 +500,8 @@ absl::Status ConvertTFExecutorToTFLOrFlatbuffer(
MetadataForReducedPrecisionSupport(quant_specs.support_mask));
}
pass_manager.clear();
if (pass_config.canonicalizing_inf_as_min_max_float)
pass_manager.addPass(mlir::TFL::CreateCanonicalizeBoundaryValuePass());
pass_manager.addPass(mlir::odml::createLegalizeStablehloToVhloPass());
if (failed(pass_manager.run(module))) {
return status_handler.Combine(
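In both conversion paths the pass is scheduled, behind the new flag, immediately before StableHLO-to-VHLO legalization, so any -Inf/+Inf constants are canonicalized to finite values before the module is serialized.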
104 changes: 104 additions & 0 deletions tensorflow/compiler/mlir/lite/transforms/canonicalize_boundary_value_pass.cc
@@ -0,0 +1,104 @@
/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <memory>
#include <utility>

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "llvm/ADT/STLExtras.h"
#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/BuiltinAttributeInterfaces.h" // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/IR/Diagnostics.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "mlir/Transforms/DialectConversion.h" // from @llvm-project
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project
#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"

namespace mlir {
namespace TFL {
namespace {

#define DEBUG_TYPE "canonicalize-boundary-value"

#define GEN_PASS_DEF_CANONICALIZEBOUNDARYVALUEPASS
#include "tensorflow/compiler/mlir/lite/transforms/passes.h.inc"

class CanonicalizeBoundaryValuePass
: public impl::CanonicalizeBoundaryValuePassBase<
CanonicalizeBoundaryValuePass> {
void runOnOperation() override;
};

// Clamps constant -Inf/+Inf elements to the MIN/MAX finite float value.
template <typename OpTy>
struct ClampInfToMinMaxFloat : public OpRewritePattern<OpTy> {
using OpRewritePattern<OpTy>::OpRewritePattern;

LogicalResult matchAndRewrite(OpTy const_op,
PatternRewriter& rewriter) const override {
ElementsAttr tensor_attr =
const_op.getValueAttr().template cast<ElementsAttr>();
ShapedType tensor_type = tensor_attr.getShapedType();
auto float_type = dyn_cast<mlir::FloatType>(tensor_type.getElementType());
if (!float_type) return failure();

auto vals_orig = tensor_attr.getValues<APFloat>();
// If all values are finite, no need to rewrite.
if (llvm::all_of(vals_orig, [&](APFloat val) { return !val.isInfinity(); }))
return failure();

SmallVector<APFloat> vals_new(llvm::map_range(vals_orig, [&](APFloat val) {
return val.isInfinity()
? APFloat::getLargest(float_type.getFloatSemantics(),
val.isNegative())
: val;
}));
rewriter.replaceOpWithNewOp<OpTy>(
const_op, DenseElementsAttr::get(tensor_type, vals_new));
return success();
}
};

void CanonicalizeBoundaryValuePass::runOnOperation() {
auto* ctx = &getContext();

RewritePatternSet patterns(ctx);
patterns.add<ClampInfToMinMaxFloat<stablehlo::ConstantOp>,
ClampInfToMinMaxFloat<TF::ConstOp>,
ClampInfToMinMaxFloat<arith::ConstantOp>>(ctx);
if (failed(
applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) {
return signalPassFailure();
}
}

} // end namespace

std::unique_ptr<OperationPass<ModuleOp>> CreateCanonicalizeBoundaryValuePass() {
return std::make_unique<CanonicalizeBoundaryValuePass>();
}

} // end namespace TFL
} // end namespace mlir
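The pass runs the greedy pattern-rewrite driver with one ClampInfToMinMaxFloat instantiation per constant-carrying dialect (StableHLO, TF, and Arith); each match rewrites a single constant op in place, so a single traversal of the module suffices.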
3 changes: 3 additions & 0 deletions tensorflow/compiler/mlir/lite/transforms/passes.h
Original file line number Diff line number Diff line change
@@ -237,6 +237,9 @@ std::unique_ptr<OperationPass<ModuleOp>> CreateReduceTypePrecisionPass();
// so redundant ones may be grouped and removed.
std::unique_ptr<OperationPass<ModuleOp>> CreatePushTransposeThroughEwisePass();

// Creates a pass that canonicalizes boundary values.
std::unique_ptr<OperationPass<ModuleOp>> CreateCanonicalizeBoundaryValuePass();

// Creates a pass that brings operations into the same order as graph_info.cc.
std::unique_ptr<OperationPass<func::FuncOp>>
CreatePartitionedTopologicalSortPass();
5 changes: 5 additions & 0 deletions tensorflow/compiler/mlir/lite/transforms/passes.td
@@ -476,3 +476,8 @@ def PushTransposeThroughEwisePass : Pass<"push-transpose-through-ewise", "mlir::
let dependentDialects = ["TFL::TensorFlowLiteDialect"];
}

def CanonicalizeBoundaryValuePass : Pass<"canonicalize-boundary-value", "mlir::ModuleOp"> {
let summary = "Canonicalize the IR representations of boundary values";
let dependentDialects = ["mlir::stablehlo::StablehloDialect", "TF::TensorFlowDialect", "mlir::arith::ArithDialect"];
let constructor = "CreateCanonicalizeBoundaryValuePass()";
}
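Registering the pass here exposes it to tf-opt as --canonicalize-boundary-value, which is how the new lit test above exercises it.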
12 changes: 12 additions & 0 deletions tensorflow/compiler/mlir/lite/utils/utils.h
@@ -41,6 +41,18 @@ using mlir::Operation;
using mlir::ShapedType;
using mlir::Value;

// Returns true if `value` is the minimum float value, i.e. -Inf or the most
// negative finite value.
inline bool IsMinFloatValue(APFloat value) {
if (!value.isNegative()) return false;
return value.isLargest() || value.isInfinity();
}

// Returns true if `value` is the maximum float value, i.e. +Inf or the
// largest finite value.
inline bool IsMaxFloatValue(APFloat value) {
if (value.isNegative()) return false;
return value.isLargest() || value.isInfinity();
}

// Returns true if every tensor value in `values` has a static shape and all
// the shapes are identical.
inline bool OpHasSameStaticShapes(Operation* op) {
auto values = op->getOperands();
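A short sketch of what the new helpers accept, under IEEE single-precision semantics and with the helpers' namespace in scope (illustrative only, not from the commit):

#include "llvm/ADT/APFloat.h"

// Both -Inf and the most negative finite value count as the "min" float;
// symmetrically for "max". All four checks below return true.
bool DemoBoundaryHelpers() {
  const llvm::fltSemantics& sem = llvm::APFloat::IEEEsingle();
  return IsMinFloatValue(llvm::APFloat::getInf(sem, /*Negative=*/true)) &&
         IsMinFloatValue(llvm::APFloat::getLargest(sem, /*Negative=*/true)) &&
         IsMaxFloatValue(llvm::APFloat::getInf(sem, /*Negative=*/false)) &&
         IsMaxFloatValue(llvm::APFloat::getLargest(sem, /*Negative=*/false));
}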