[go: nahoru, domu]

Skip to content

Commit

Permalink
Do not partially quantize the inputs of UnidirectionalSequenceLSTMOp
Browse files Browse the repository at this point in the history
The TFLite runtime doesn't currently support hybrid/partial quantization of the
inputs for the UnidirectionalSequenceLSTMOp. As a consequence, we must ensure
quantization occurs atomically for all of the input constants. In cases where
any of the input constants do not meet the minimum element count for
quantization eligibility (as defined by the `minimum_elements_for_weights`
quantization config parameter), quantization is skipped for this op entirely.

PiperOrigin-RevId: 534562296
  • Loading branch information
arfaian authored and tensorflower-gardener committed May 23, 2023
1 parent 24260f2 commit b8d7b00
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
// RUN: tf-opt %s -tfl-prepare-quantize-dynamic-range="enable-float16-quantization" | FileCheck --check-prefix=Float16 %s
// RUN: tf-opt %s -tfl-prepare-quantize-dynamic-range="enable-custom-op-quantization=CustomTestOp=1-3,CustomTestOp3=3" | FileCheck --check-prefix=CustomOp %s
// RUN: tf-opt %s -tfl-prepare-quantize-dynamic-range="min-elements-for-weights=4000 enable-custom-op-quantization=CustomTestOp=1-3,CustomTestOp3=3" | FileCheck --check-prefix=MinElement %s
// RUN: tf-opt %s -tfl-prepare-quantize-dynamic-range="min-elements-for-weights=19" | FileCheck --check-prefix=LSTMOpQuantized %s
// RUN: tf-opt %s -tfl-prepare-quantize-dynamic-range="min-elements-for-weights=21" | FileCheck --check-prefix=LSTMOpNotQuantized %s

// CHECK-LABEL: QuantizeConv2D
// PerTensor-LABEL: QuantizeConv2D
Expand Down Expand Up @@ -409,3 +411,41 @@ func.func @LargeFloat16Constants(%arg0: tensor<1x224x224x3xf32>) -> tensor<1x112
// Float16-DAG: %[[w:.*]] = arith.constant dense<6.550400e+04> : tensor<64x3x3x3xf16>
// Float16-DAG: %[[b:.*]] = arith.constant dense<-6.550400e+04> : tensor<64xf16>
}

// LSTMOpQuantized-LABEL: LSTMOpNotPartiallyQuantized
// LSTMOpNotQuantized-LABEL: LSTMOpNotPartiallyQuantized
// Regression test for atomic (all-or-nothing) dynamic-range quantization of
// tfl.unidirectional_sequence_lstm. The smallest weight constants here are
// tensor<20xf32> (20 elements), while the recurrent weights are
// tensor<20x20xf32> (400 elements). The two RUN configurations bracket the
// 20-element constants: with min-elements-for-weights=19 every weight
// qualifies and all are quantized; with min-elements-for-weights=21 the
// 20-element constants fall below the threshold, so the pass must leave the
// entire op unquantized rather than quantize only the larger weights.
func.func @LSTMOpNotPartiallyQuantized(%arg0: tensor<1x28x28xf32>) -> tensor<1x28x20xf32> {
// Placeholder for the optional operands that are absent in this test.
%cst_2 = "tfl.no_value"() {value = unit} : () -> none
// 400-element weight constant — above both thresholds used by the RUN lines.
%cst_3 = arith.constant dense<1.0> : tensor<20x20xf32>
// 20-element constant — this is the operand that straddles the two
// min-elements-for-weights settings (19 vs. 21).
%cst_7 = arith.constant dense<1.0> : tensor<20xf32>
// Recurrent and cell state inputs carry quantfork.stats rather than being
// plain constants, so they are not subject to the weight-size threshold.
%recurrent_input = arith.constant dense<1.0> : tensor<1x20xf32>
%recurrent_stats = "quantfork.stats"(%recurrent_input) {layerStats = dense<[-2.0, 1.0]> : tensor<2xf32>} : (tensor<1x20xf32>) -> tensor<1x20xf32>
%cell_input = arith.constant dense<1.0> : tensor<1x20xf32>
%cell_stats = "quantfork.stats"(%cell_input) {layerStats = dense<[-2.73090601, 7.94872093]> : tensor<2xf32>} : (tensor<1x20xf32>) -> tensor<1x20xf32>
%0 = "tfl.unidirectional_sequence_lstm"(%arg0,
%cst_3, %cst_3, %cst_3, %cst_3,
%cst_3, %cst_3, %cst_3, %cst_3,
%cst_7, %cst_7, %cst_7,
%cst_7, %cst_7, %cst_7, %cst_7,
%cst_3, %cst_2,
%recurrent_stats, %cell_stats,
%cst_2, %cst_2, %cst_2, %cst_2) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = false}
: ( tensor<1x28x28xf32>,
tensor<20x20xf32>, tensor<20x20xf32>, tensor<20x20xf32>, tensor<20x20xf32>,
tensor<20x20xf32>, tensor<20x20xf32>, tensor<20x20xf32>, tensor<20x20xf32>,
tensor<20xf32>, tensor<20xf32>, tensor<20xf32>,
tensor<20xf32>, tensor<20xf32>, tensor<20xf32>, tensor<20xf32>,
tensor<20x20xf32>, none,
tensor<1x20xf32>, tensor<1x20xf32>,
none, none, none, none) -> tensor<1x28x20xf32>
%1 = "quantfork.stats"(%0) {layerStats = dense<[-1.0, 2.0]> : tensor<2xf32>} : (tensor<1x28x20xf32>) -> tensor<1x28x20xf32>
func.return %1 : tensor<1x28x20xf32>

// Threshold 19 — every weight operand feeds through a tfl.dequantize, i.e.
// the op was fully quantized.
// LSTMOpQuantized-DAG: %[[dq1:.*]] = "tfl.dequantize"({{.*}}) : (tensor<20x20x!quant.uniform<i8<-127:127>:f32, 0.0078740157480314959>>) -> tensor<20x20xf32>
// LSTMOpQuantized-DAG: %[[dq3:.*]] = "tfl.dequantize"({{.*}}) : (tensor<20x!quant.uniform<i8<-127:127>:f32, 0.0078740157480314959>>) -> tensor<20xf32>
// LSTMOpQuantized: %[[lstm:.*]] = "tfl.unidirectional_sequence_lstm"(%arg0, %[[dq1]], %[[dq1]], %[[dq1]], %[[dq1]], %[[dq1]], %[[dq1]], %[[dq1]], %[[dq1]], %[[dq3]], %[[dq3]], %[[dq3]], %cst_0, %cst_0, %cst_0, %cst_0, %[[dq1]], %0, %cst_1, %cst_1, %0, %0, %0, %0)

// Threshold 21 — the 20-element constants are ineligible, so NO operand may
// be quantized: all weights must remain plain arith.constant values.
// LSTMOpNotQuantized-DAG: %[[cst_1:.*]] = arith.constant dense<1.000000e+00> : tensor<20x20xf32>
// LSTMOpNotQuantized-DAG: %[[cst_3:.*]] = arith.constant dense<1.000000e+00> : tensor<20xf32>
// LSTMOpNotQuantized: %[[lstm:.*]] = "tfl.unidirectional_sequence_lstm"(%arg0, %[[cst_1]], %[[cst_1]], %[[cst_1]], %[[cst_1]], %[[cst_1]], %[[cst_1]], %[[cst_1]], %[[cst_1]], %[[cst_3]], %[[cst_3]], %[[cst_3]], %cst_0, %cst_0, %cst_0, %cst_0, %[[cst_1]], %0, %cst_1, %cst_1, %0, %0, %0, %0)
}
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,10 @@ class PrepareDynamicRangeQuantizePass
void runOnOperation() override;

private:
// Keeps track of ops whose inputs cannot be quantized due to not meeting the
// minimum_elements_for_weights threshold. Prevents emitting duplicate
// warnings for the same op, once deemed ineligible for quantization.
llvm::SetVector<Operation*> visited_nonquantizable_ops_;
quant::QuantizationSpecs quant_specs_;
};

Expand All @@ -95,8 +99,10 @@ class PrepareDynamicRangeQuantizableOp
: public OpRewritePattern<arith::ConstantOp> {
public:
explicit PrepareDynamicRangeQuantizableOp(
MLIRContext* context, const quant::QuantizationSpecs& quant_specs)
MLIRContext* context, const quant::QuantizationSpecs& quant_specs,
llvm::SetVector<Operation*>* const visited_nonquantizable_ops)
: OpRewritePattern<arith::ConstantOp>(context),
visited_nonquantizable_ops_(visited_nonquantizable_ops),
quant_specs_(quant_specs) {}

LogicalResult matchAndRewrite(arith::ConstantOp op,
Expand Down Expand Up @@ -129,6 +135,8 @@ class PrepareDynamicRangeQuantizableOp
}

private:
llvm::SetVector<Operation*>* const visited_nonquantizable_ops_;

// Check if the operand_index is included in the quantizable_indices.
bool isQuantizableIndex(const int operand_index,
const std::vector<int>& quantizable_indices) const {
Expand All @@ -142,6 +150,10 @@ class PrepareDynamicRangeQuantizableOp
// specification for checking the support. For custom ops, it checks the
// provided map.
bool hasInt8QuantizableOperandAt(Operation* op, int operand_index) const {
if (visited_nonquantizable_ops_->contains(op)) {
return false;
}

if (auto custom_op = llvm::dyn_cast_or_null<CustomOp>(op)) {
std::string op_name = custom_op.getCustomCode().str();
auto custom_map_iter = quant_specs_.custom_map.find(op_name);
Expand All @@ -152,7 +164,53 @@ class PrepareDynamicRangeQuantizableOp
llvm::dyn_cast<DynamicRangeQuantizedOpInterface>(op)) {
const auto& quantizable_indices =
quantizable_op.GetQuantizableOperandIndices();
return isQuantizableIndex(operand_index, quantizable_indices);

if (!isQuantizableIndex(operand_index, quantizable_indices)) {
return false;
}

// Special case handling for UnidirectionalSequenceLSTMOp, which doesn't
// support partial quantization of its inputs.
// Below, we check all of the input constants for the
// UnidirectionalSequenceLSTMOp to see if any of them would not be
// quantized due to not meeting the minimum_elements_for_weights
// threshold. Should we find any, we don't quantize any of the ops.
if (!llvm::dyn_cast<UnidirectionalSequenceLSTMOp>(op)) {
return true;
}

for (int qi : quantizable_indices) {
auto const_op = llvm::dyn_cast_or_null<arith::ConstantOp>(
op->getOperand(qi).getDefiningOp());
if (!const_op) {
continue;
}

DenseFPElementsAttr attr;
if (!matchPattern(const_op->getResult(0), m_Constant(&attr))) {
continue;
}

if (attr.dyn_cast<DenseFPElementsAttr>().size() >=
quant_specs_.minimum_elements_for_weights) {
continue;
}

visited_nonquantizable_ops_->insert(op);
op->emitWarning(
"Skipped quantization for UnidirectionalSequenceLSTMOp. Partial "
"quantization of inputs for UnidirectionalSequenceLSTMOp is not "
"supported. The operand ")
<< const_op->getName().getStringRef().str() << " at index " << qi
<< " was not quantized because it has "
<< attr.dyn_cast<DenseFPElementsAttr>().size()
<< " elements which is fewer than the "
"`minimum_elements_for_weights` threshold of "
<< quant_specs_.minimum_elements_for_weights;
return false;
}

return true;
}
return false;
}
Expand Down Expand Up @@ -427,7 +485,8 @@ void PrepareDynamicRangeQuantizePass::runOnOperation() {
removeAllStatsOp(func);

RewritePatternSet patterns(&getContext());
patterns.add<PrepareDynamicRangeQuantizableOp>(ctx, quant_specs_);
patterns.add<PrepareDynamicRangeQuantizableOp>(ctx, quant_specs_,
&visited_nonquantizable_ops_);
(void)applyPatternsAndFoldGreedily(func, std::move(patterns));

ConvertMlirQuantOpsToTFLQuantOps(func);
Expand Down

0 comments on commit b8d7b00

Please sign in to comment.