From b8d7b00627ba39e2e84fdcbd1217d620a7836a3c Mon Sep 17 00:00:00 2001 From: Arian Arfaian Date: Tue, 23 May 2023 14:37:07 -0700 Subject: [PATCH] Do not partially quantize the inputs of UnidirectionalSequenceLSTMOp The TFLite runtime doesn't currently support hybrid/partial quantization of the inputs for the UnidirectionalSequenceLSTMOp. As a consequence, we must ensure quantization occurs atomically for all of the input constants. In cases where any of the input constants do not meet the minimum element count for quantization eligibility (as defined by the `minimum_elements_for_weights` quantization config parameter) then skip quantization for this op entirely. PiperOrigin-RevId: 534562296 --- .../tests/prepare-quantize-dynamic-range.mlir | 40 ++++++++++++ .../prepare_quantize_dynamic_range.cc | 65 ++++++++++++++++++- 2 files changed, 102 insertions(+), 3 deletions(-) diff --git a/tensorflow/compiler/mlir/lite/tests/prepare-quantize-dynamic-range.mlir b/tensorflow/compiler/mlir/lite/tests/prepare-quantize-dynamic-range.mlir index b549b5645157d4..01ed79e5a63f30 100644 --- a/tensorflow/compiler/mlir/lite/tests/prepare-quantize-dynamic-range.mlir +++ b/tensorflow/compiler/mlir/lite/tests/prepare-quantize-dynamic-range.mlir @@ -3,6 +3,8 @@ // RUN: tf-opt %s -tfl-prepare-quantize-dynamic-range="enable-float16-quantization" | FileCheck --check-prefix=Float16 %s // RUN: tf-opt %s -tfl-prepare-quantize-dynamic-range="enable-custom-op-quantization=CustomTestOp=1-3,CustomTestOp3=3" | FileCheck --check-prefix=CustomOp %s // RUN: tf-opt %s -tfl-prepare-quantize-dynamic-range="min-elements-for-weights=4000 enable-custom-op-quantization=CustomTestOp=1-3,CustomTestOp3=3" | FileCheck --check-prefix=MinElement %s +// RUN: tf-opt %s -tfl-prepare-quantize-dynamic-range="min-elements-for-weights=19" | FileCheck --check-prefix=LSTMOpQuantized %s +// RUN: tf-opt %s -tfl-prepare-quantize-dynamic-range="min-elements-for-weights=21" | FileCheck --check-prefix=LSTMOpNotQuantized %s // CHECK-LABEL: QuantizeConv2D // PerTensor-LABEL: QuantizeConv2D @@ -409,3 +411,41 @@ func.func @LargeFloat16Constants(%arg0: tensor<1x224x224x3xf32>) -> tensor<1x112 // Float16-DAG: %[[w:.*]] = arith.constant dense<6.550400e+04> : tensor<64x3x3x3xf16> // Float16-DAG: %[[b:.*]] = arith.constant dense<-6.550400e+04> : tensor<64xf16> } + +// LSTMOpQuantized-LABEL: LSTMOpNotPartiallyQuantized +// LSTMOpNotQuantized-LABEL: LSTMOpNotPartiallyQuantized +func.func @LSTMOpNotPartiallyQuantized(%arg0: tensor<1x28x28xf32>) -> tensor<1x28x20xf32> { + %cst_2 = "tfl.no_value"() {value = unit} : () -> none + %cst_3 = arith.constant dense<1.0> : tensor<20x20xf32> + %cst_7 = arith.constant dense<1.0> : tensor<20xf32> + %recurrent_input = arith.constant dense<1.0> : tensor<1x20xf32> + %recurrent_stats = "quantfork.stats"(%recurrent_input) {layerStats = dense<[-2.0, 1.0]> : tensor<2xf32>} : (tensor<1x20xf32>) -> tensor<1x20xf32> + %cell_input = arith.constant dense<1.0> : tensor<1x20xf32> + %cell_stats = "quantfork.stats"(%cell_input) {layerStats = dense<[-2.73090601, 7.94872093]> : tensor<2xf32>} : (tensor<1x20xf32>) -> tensor<1x20xf32> + %0 = "tfl.unidirectional_sequence_lstm"(%arg0, + %cst_3, %cst_3, %cst_3, %cst_3, + %cst_3, %cst_3, %cst_3, %cst_3, + %cst_7, %cst_7, %cst_7, + %cst_7, %cst_7, %cst_7, %cst_7, + %cst_3, %cst_2, + %recurrent_stats, %cell_stats, + %cst_2, %cst_2, %cst_2, %cst_2) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = false} + : ( tensor<1x28x28xf32>, + tensor<20x20xf32>, tensor<20x20xf32>, tensor<20x20xf32>, tensor<20x20xf32>, + tensor<20x20xf32>, tensor<20x20xf32>, tensor<20x20xf32>, tensor<20x20xf32>, + tensor<20xf32>, tensor<20xf32>, tensor<20xf32>, + tensor<20xf32>, tensor<20xf32>, tensor<20xf32>, tensor<20xf32>, + tensor<20x20xf32>, none, + tensor<1x20xf32>, tensor<1x20xf32>, + none, none, none, none) -> tensor<1x28x20xf32> + %1 = "quantfork.stats"(%0) {layerStats = dense<[-1.0, 2.0]> : tensor<2xf32>} : (tensor<1x28x20xf32>) -> tensor<1x28x20xf32> + func.return %1 : tensor<1x28x20xf32> + +// LSTMOpQuantized-DAG: %[[dq1:.*]] = "tfl.dequantize"({{.*}}) : (tensor<20x20x!quant.uniform:f32, 0.0078740157480314959>>) -> tensor<20x20xf32> +// LSTMOpQuantized-DAG: %[[dq3:.*]] = "tfl.dequantize"({{.*}}) : (tensor<20x!quant.uniform:f32, 0.0078740157480314959>>) -> tensor<20xf32> +// LSTMOpQuantized: %[[lstm:.*]] = "tfl.unidirectional_sequence_lstm"(%arg0, %[[dq1]], %[[dq1]], %[[dq1]], %[[dq1]], %[[dq1]], %[[dq1]], %[[dq1]], %[[dq1]], %[[dq3]], %[[dq3]], %[[dq3]], %cst_0, %cst_0, %cst_0, %cst_0, %[[dq1]], %0, %cst_1, %cst_1, %0, %0, %0, %0) + +// LSTMOpNotQuantized-DAG: %[[cst_1:.*]] = arith.constant dense<1.000000e+00> : tensor<20x20xf32> +// LSTMOpNotQuantized-DAG: %[[cst_3:.*]] = arith.constant dense<1.000000e+00> : tensor<20xf32> +// LSTMOpNotQuantized: %[[lstm:.*]] = "tfl.unidirectional_sequence_lstm"(%arg0, %[[cst_1]], %[[cst_1]], %[[cst_1]], %[[cst_1]], %[[cst_1]], %[[cst_1]], %[[cst_1]], %[[cst_1]], %[[cst_3]], %[[cst_3]], %[[cst_3]], %cst_0, %cst_0, %cst_0, %cst_0, %[[cst_1]], %0, %cst_1, %cst_1, %0, %0, %0, %0) +} diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_quantize_dynamic_range.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_quantize_dynamic_range.cc index a19c29a666f06e..951748b31273f3 100644 --- a/tensorflow/compiler/mlir/lite/transforms/prepare_quantize_dynamic_range.cc +++ b/tensorflow/compiler/mlir/lite/transforms/prepare_quantize_dynamic_range.cc @@ -84,6 +84,10 @@ class PrepareDynamicRangeQuantizePass void runOnOperation() override; private: + // Keeps track of ops whose inputs cannot be quantized due to not meeting the + // minimum_elements_for_weights threshold. Prevents emitting duplicate + // warnings for the same op, once deemed ineligible for quantization. + llvm::SetVector visited_nonquantizable_ops_; quant::QuantizationSpecs quant_specs_; }; @@ -95,8 +99,10 @@ class PrepareDynamicRangeQuantizableOp : public OpRewritePattern { public: explicit PrepareDynamicRangeQuantizableOp( - MLIRContext* context, const quant::QuantizationSpecs& quant_specs) + MLIRContext* context, const quant::QuantizationSpecs& quant_specs, + llvm::SetVector* const visited_nonquantizable_ops) : OpRewritePattern(context), + visited_nonquantizable_ops_(visited_nonquantizable_ops), quant_specs_(quant_specs) {} LogicalResult matchAndRewrite(arith::ConstantOp op, @@ -129,6 +135,8 @@ class PrepareDynamicRangeQuantizableOp } private: + llvm::SetVector* const visited_nonquantizable_ops_; + // Check if the operand_index is included in the quantizable_indices. bool isQuantizableIndex(const int operand_index, const std::vector& quantizable_indices) const { @@ -142,6 +150,10 @@ class PrepareDynamicRangeQuantizableOp // specification for checking the support. For custom ops, it checks the // provided map. bool hasInt8QuantizableOperandAt(Operation* op, int operand_index) const { + if (visited_nonquantizable_ops_->contains(op)) { + return false; + } + if (auto custom_op = llvm::dyn_cast_or_null(op)) { std::string op_name = custom_op.getCustomCode().str(); auto custom_map_iter = quant_specs_.custom_map.find(op_name); @@ -152,7 +164,53 @@ class PrepareDynamicRangeQuantizableOp llvm::dyn_cast(op)) { const auto& quantizable_indices = quantizable_op.GetQuantizableOperandIndices(); - return isQuantizableIndex(operand_index, quantizable_indices); + + if (!isQuantizableIndex(operand_index, quantizable_indices)) { + return false; + } + + // Special case handling for UnidirectionalSequenceLSTMOp, which doesn't + // support partial quantization of its inputs. + // Below, we check all of the input constants for the + // UnidirectionalSequenceLSTMOp to see if any of them would not be + // quantized due to not meeting the minimum_elements_for_weights + // threshold. Should we find any, we don't quantize any of the ops. + if (!llvm::dyn_cast(op)) { + return true; + } + + for (int qi : quantizable_indices) { + auto const_op = llvm::dyn_cast_or_null( + op->getOperand(qi).getDefiningOp()); + if (!const_op) { + continue; + } + + DenseFPElementsAttr attr; + if (!matchPattern(const_op->getResult(0), m_Constant(&attr))) { + continue; + } + + if (attr.dyn_cast().size() >= + quant_specs_.minimum_elements_for_weights) { + continue; + } + + visited_nonquantizable_ops_->insert(op); + op->emitWarning( + "Skipped quantization for UnidirectionalSequenceLSTMOp. Partial " + "quantization of inputs for UnidirectionalSequenceLSTMOp is not " + "supported. The operand ") + << const_op->getName().getStringRef().str() << " at index " << qi + << " was not quantized because it has " + << attr.dyn_cast().size() + << " elements which is fewer than the " + "`minimum_elements_for_weights` threshold of " + << quant_specs_.minimum_elements_for_weights; + return false; + } + + return true; } return false; } @@ -427,7 +485,8 @@ void PrepareDynamicRangeQuantizePass::runOnOperation() { removeAllStatsOp(func); RewritePatternSet patterns(&getContext()); - patterns.add(ctx, quant_specs_); + patterns.add(ctx, quant_specs_, + &visited_nonquantizable_ops_); (void)applyPatternsAndFoldGreedily(func, std::move(patterns)); ConvertMlirQuantOpsToTFLQuantOps(func);