Fix Sum op kernel bug for models with int8/uint8 input.

Fixes two bugs:
* The non-quantized sum kernel was being called for models with quantized
  inputs, producing incorrect results due to integer overflow.
* The optimized version of QuantizedMeanOrSum was not clamping values after
  executing the computation, resulting in overflow in some cases (see the
  sketch after this list).
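
A minimal standalone sketch of the second issue, with a made-up scale, zero
point, and input values (not TFLite code); it mirrors the updated test
expectations below, where sums that exceed the representable output range now
clamp instead of wrapping:

  #include <algorithm>
  #include <cmath>
  #include <cstdint>
  #include <iostream>

  int main() {
    const float scale = 1.0f / 255.0f;          // hypothetical output scale
    const int32_t output_zero_point = 0;        // hypothetical zero point
    const float real_sum = 0.4f + 0.3f + 0.5f;  // 1.2, above the 1.0 that a
                                                // uint8 output can represent
    const int32_t q =
        static_cast<int32_t>(std::round(real_sum / scale)) + output_zero_point;
    const uint8_t unclamped = static_cast<uint8_t>(q);  // 306 wraps to 50
    const uint8_t clamped = static_cast<uint8_t>(
        std::min<int32_t>(std::max<int32_t>(q, 0), 255));  // saturates at 255
    std::cout << "raw: " << q << " unclamped: " << +unclamped
              << " clamped: " << +clamped << "\n";
  }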

For the first issue, the predicate that decides whether to use the
quantization-aware kernels for the sum op was too strict. Before this change,
inputs had to be of type int8/uint8 and the input and output had to differ in
scale or zero point in order to qualify. The latter criterion has been removed.
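
A condensed restatement of the predicate change (a sketch only; the booleans
mirror the eight_bit_quantized and same_scale locals computed in the old
EvalSum, shown in the reduce.cc hunks below):

  // Old rule: an int8/uint8 sum whose input and output shared a scale and
  // zero point fell through to the non-quantized kernel and could overflow.
  bool UsedQuantizedSumBefore(bool eight_bit_quantized, bool same_scale) {
    return eight_bit_quantized && !same_scale;
  }

  // New rule: any int8/uint8 input takes the quantization-aware path.
  bool UsesQuantizedSumAfter(bool eight_bit_quantized) {
    return eight_bit_quantized;
  }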

One additional change worth noting (it is arguably not a bug): previously,
only the reference version of QuantizedMeanOrSum was ever called. Now either
the reference or the optimized version is called, depending on the specified
kernel_type.
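
A condensed sketch of the new dispatch; KernelType here is a stand-in enum and
the kernel calls are reduced to labels, with the full argument lists in the
reduce.cc hunk below:

  #include <cstdio>

  enum KernelType { kReference, kGenericOptimized };

  // Condensed shape of the new QuantizedMeanOrSum wrapper in reduce.cc.
  void QuantizedSumDispatch(KernelType kernel_type) {
    if (kernel_type == kGenericOptimized) {
      std::puts("optimized_ops::QuantizedMeanOrSum");  // now reachable
    } else {
      std::puts("reference_ops::QuantizedMeanOrSum");  // previous behavior
    }
  }

  int main() {
    QuantizedSumDispatch(kReference);
    QuantizedSumDispatch(kGenericOptimized);
  }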

PiperOrigin-RevId: 538835574
arfaian authored and tensorflower-gardener committed Jun 8, 2023
1 parent 61c4900 commit ca55d1c
Showing 3 changed files with 62 additions and 66 deletions.
tensorflow/lite/kernels/internal/optimized/reduce.h (6 additions, 3 deletions)
@@ -19,6 +19,7 @@ limitations under the License.
 
 #include <algorithm>
 #include <limits>
+#include <vector>
 
 #include "ruy/profiler/instrumentation.h"  // from @ruy
 #include "tensorflow/lite/kernels/cpu_backend_threadpool.h"
@@ -407,6 +408,8 @@ bool QuantizedMeanOrSum(const T* input_data, int32_t input_zero_point,
                         const int* axis, const int num_axis_dimensions,
                         bool keep_dims, int* normalized_dims,
                         int* resolved_axis, U* temp_sum, bool compute_sum) {
+  const int32_t kMinValue = std::numeric_limits<T>::min();
+  const int32_t kMaxValue = std::numeric_limits<T>::max();
   ruy::profiler::ScopeLabel label(compute_sum ? "QuantizedSum"
                                               : "QuantizedMean");
   // Reset output data.
@@ -467,9 +470,9 @@ bool QuantizedMeanOrSum(const T* input_data, int32_t input_zero_point,
   if (compute_sum) {
     const float bias = -input_zero_point * scale * num_elements_in_axis;
     for (size_t idx = 0; idx < num_outputs; ++idx) {
-      const U value =
-          static_cast<U>(TfLiteRound(temp_sum[idx] * scale + bias)) +
-          output_zero_point;
+      U value = static_cast<U>(TfLiteRound(temp_sum[idx] * scale + bias)) +
+                output_zero_point;
+      value = std::min(std::max(value, kMinValue), kMaxValue);
       output_data[idx] = static_cast<T>(value);
     }
   } else {
tensorflow/lite/kernels/reduce.cc (37 additions, 53 deletions)
@@ -450,27 +450,40 @@ TfLiteStatus Mean(TfLiteContext* context, const OpContext* op_context,
 
 template <typename T>
 TfLiteStatus QuantizedMeanOrSum(TfLiteContext* context,
-                                const OpContext* op_context, int* temp_index,
-                                int* resolved_axis, int* temp_sum,
-                                KernelType kernel_type, bool compute_sum) {
-  int num_axis = static_cast<int>(NumElements(op_context->axis));
-  auto args = std::tuple(
-      GetTensorData<T>(op_context->input), op_context->input->params.zero_point,
-      op_context->input->params.scale, &op_context->input->dims->data[0],
-      op_context->input->dims->size, GetTensorData<T>(op_context->output),
-      op_context->output->params.zero_point, op_context->output->params.scale,
-      &op_context->output->dims->data[0], op_context->output->dims->size,
-      GetTensorData<int>(op_context->axis), num_axis,
-      op_context->params->keep_dims, temp_index, resolved_axis, temp_sum,
-      compute_sum);
-  if (kernel_type == kReference) {
+                                const OpContext& op_context,
+                                const OpData* op_data, TfLiteTensor* temp_index,
+                                TfLiteTensor* resolved_axis,
+                                TfLiteTensor* temp_sum, KernelType kernel_type,
+                                bool compute_sum) {
+  int num_axis = static_cast<int>(NumElements(op_context.axis));
+  if (kernel_type == kGenericOptimized) {
     TF_LITE_ENSURE(
         context,
-        std::apply(reference_ops::QuantizedMeanOrSum<T, int32_t>, args));
+        optimized_ops::QuantizedMeanOrSum(
+            GetTensorData<T>(op_context.input),
+            op_context.input->params.zero_point, op_context.input->params.scale,
+            op_context.input->dims->data, op_context.input->dims->size,
+            GetTensorData<T>(op_context.output),
+            op_context.output->params.zero_point,
+            op_context.output->params.scale, op_context.output->dims->data,
+            op_context.output->dims->size, GetTensorData<int>(op_context.axis),
+            num_axis, op_context.params->keep_dims,
+            GetTensorData<int>(temp_index), GetTensorData<int>(resolved_axis),
+            GetTensorData<int32_t>(temp_sum), compute_sum));
   } else {
     TF_LITE_ENSURE(
         context,
-        std::apply(optimized_ops::QuantizedMeanOrSum<T, int32_t>, args));
+        reference_ops::QuantizedMeanOrSum(
+            GetTensorData<uint8_t>(op_context.input),
+            op_context.input->params.zero_point, op_context.input->dims->data,
+            op_context.input->dims->size,
+            GetTensorData<uint8_t>(op_context.output), op_data->multiplier,
+            op_data->shift, op_context.output->params.zero_point,
+            op_context.output->dims->data, op_context.output->dims->size,
+            GetTensorData<int>(op_context.axis), num_axis,
+            op_context.params->keep_dims, GetTensorData<int>(temp_index),
+            GetTensorData<int>(resolved_axis), GetTensorData<int32_t>(temp_sum),
+            compute_sum));
   }
   return kTfLiteOk;
 }
@@ -899,19 +912,12 @@ TfLiteStatus EvalGeneric(TfLiteContext* context, TfLiteNode* node) {
 template <KernelType kernel_type>
 TfLiteStatus EvalSum(TfLiteContext* context, TfLiteNode* node) {
   OpContext op_context(context, node);
-  const OpData* op_data = reinterpret_cast<const OpData*>(node->user_data);
   ruy::profiler::ScopeLabel label("Sum");
   const auto& input = op_context.input;
-  const auto& output = op_context.output;
-  const bool same_scale =
-      (input->params.scale == output->params.scale &&
-       input->params.zero_point == output->params.zero_point);
   const bool eight_bit_quantized =
       input->type == kTfLiteUInt8 || input->type == kTfLiteInt8;
-  const bool need_rescale = (eight_bit_quantized && !same_scale);
-  if (need_rescale) {
-    // Rescaling 8bit reduce sum.
-    int num_axis = static_cast<int>(NumElements(op_context.axis));
+  if (eight_bit_quantized) {
+    const OpData* op_data = reinterpret_cast<const OpData*>(node->user_data);
     TfLiteTensor* temp_index;
     TF_LITE_ENSURE_OK(
         context, GetTemporarySafe(context, node, /*index=*/0, &temp_index));
@@ -931,36 +937,14 @@ TfLiteStatus EvalSum(TfLiteContext* context, TfLiteNode* node) {
     }
 
     if (input->type == kTfLiteUInt8) {
-      TF_LITE_ENSURE(
-          context,
-          reference_ops::QuantizedMeanOrSum(
-              GetTensorData<uint8_t>(op_context.input),
-              op_context.input->params.zero_point, op_context.input->dims->data,
-              op_context.input->dims->size,
-              GetTensorData<uint8_t>(op_context.output), op_data->multiplier,
-              op_data->shift, op_context.output->params.zero_point,
-              op_context.output->dims->data, op_context.output->dims->size,
-              GetTensorData<int>(op_context.axis), num_axis,
-              op_context.params->keep_dims, GetTensorData<int>(temp_index),
-              GetTensorData<int>(resolved_axis),
-              GetTensorData<int32_t>(temp_sum),
-              /*compute_sum=*/true));
+      return QuantizedMeanOrSum<uint8_t>(context, op_context, op_data,
+                                         temp_index, resolved_axis, temp_sum,
+                                         kernel_type, /*compute_sum=*/true);
     }
     if (input->type == kTfLiteInt8) {
-      TF_LITE_ENSURE(
-          context,
-          reference_ops::QuantizedMeanOrSum(
-              GetTensorData<int8_t>(op_context.input),
-              op_context.input->params.zero_point, op_context.input->dims->data,
-              op_context.input->dims->size,
-              GetTensorData<int8_t>(op_context.output), op_data->multiplier,
-              op_data->shift, op_context.output->params.zero_point,
-              op_context.output->dims->data, op_context.output->dims->size,
-              GetTensorData<int>(op_context.axis), num_axis,
-              op_context.params->keep_dims, GetTensorData<int>(temp_index),
-              GetTensorData<int>(resolved_axis),
-              GetTensorData<int32_t>(temp_sum),
-              /*compute_sum=*/true));
+      return QuantizedMeanOrSum<int8_t>(context, op_context, op_data,
+                                        temp_index, resolved_axis, temp_sum,
+                                        kernel_type, /*compute_sum=*/true);
     }
   } else {
     return EvalGeneric<kernel_type, kSum>(context, node);
tensorflow/lite/kernels/reduce_test.cc (19 additions, 10 deletions)
@@ -792,9 +792,11 @@ TEST(ConstUint8SumOpTest, NotKeepDims) {
   m.QuantizeAndPopulate<uint8_t>(m.Input(), data);
   ASSERT_EQ(m.Invoke(), kTfLiteOk);
   EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2}));
-  EXPECT_THAT(m.GetDequantizedOutput<uint8_t>(),
-              ElementsAreArray(
-                  ArrayFloatNear({-0.823529, -0.815686}, kQuantizedTolerance)));
+  EXPECT_THAT(
+      m.GetDequantizedOutput<uint8_t>(),
+      // 0.4 + 0.3 + 0.5 = 1.0 (clamped)
+      // 0.2 + 0.4 + 0.6 = 1.0 (clamped)
+      ElementsAreArray(ArrayFloatNear({1.0, 1.0}, kQuantizedTolerance)));
 }
 
 TEST(ConstUint8SumOpTest, NotKeepDimsRescaling) {
@@ -840,9 +842,12 @@ TEST(ConstUint8SumOpTest, KeepDims) {
   m.QuantizeAndPopulate<uint8_t>(m.Input(), data);
   ASSERT_EQ(m.Invoke(), kTfLiteOk);
   EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 1}));
-  EXPECT_THAT(m.GetDequantizedOutput<uint8_t>(),
-              ElementsAreArray(ArrayFloatNear({-0.407843, -0.313726, 0.0941177},
-                                              kQuantizedTolerance)));
+  EXPECT_THAT(
+      m.GetDequantizedOutput<uint8_t>(),
+      // 0.4 + 0.2 = 0.6
+      // 0.3 + 0.4 = 0.7
+      // 0.5 + 0.6 = 1.0 (clamped)
+      ElementsAreArray(ArrayFloatNear({0.6, 0.7, 1.0}, kQuantizedTolerance)));
 }
 
 TEST(DynamicUint8SumOpTest, NotKeepDims) {
@@ -856,9 +861,11 @@ TEST(DynamicUint8SumOpTest, NotKeepDims) {
   m.QuantizeAndPopulate<uint8_t>(m.Input(), data);
   ASSERT_EQ(m.Invoke(), kTfLiteOk);
   EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2}));
-  EXPECT_THAT(m.GetDequantizedOutput<uint8_t>(),
-              ElementsAreArray(
-                  ArrayFloatNear({1.48235, 1.64706}, kQuantizedTolerance)));
+  EXPECT_THAT(
+      m.GetDequantizedOutput<uint8_t>(),
+      // 1.3 + -4.8 = -3.5
+      // -3.6 + 0.24 = -3.36
+      ElementsAreArray(ArrayFloatNear({-3.5, -3.36}, kQuantizedTolerance)));
 }
 
 TEST(DynamicUint8SumOpTest, KeepDims) {
@@ -874,7 +881,9 @@ TEST(DynamicUint8SumOpTest, KeepDims) {
   EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2}));
   EXPECT_THAT(
       m.GetDequantizedOutput<uint8_t>(),
-      ElementsAreArray(ArrayFloatNear({6.47059, 10.698}, kQuantizedTolerance)));
+      // 11.14 + 7.423 = 12.0 (clamped)
+      // -0.14 + 0.879 = 0.739
+      ElementsAreArray(ArrayFloatNear({12.0, 0.739}, kQuantizedTolerance)));
 }
 
 TEST(ConstInt8SumOpTest, Rescale) {
