
Allow int16 input/output even when not using 16x8 quantization mode #56615

Closed
hguihot opened this issue Jun 29, 2022 · 6 comments
Labels: awaiting PR merge · comp:lite (TF Lite related issues) · TFLiteConverter (For issues related to TFLite converter) · type:feature (Feature requests)

Comments

hguihot commented Jun 29, 2022

Currently the valid types can be either int16 or int8/uint8, but not a combination of both. Some models could, for example, contain a custom op returning an int16 tensor as a model output, and converting such a model to TFLite fails. It looks like always adding _dtypes.int16 to the list of supported types when quant_mode.is_integer_quantization() is true would be enough to make it work.
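
For reference, here is a rough sketch of the kind of change being proposed, in the spirit of the converter's I/O type validation. The helper name and surrounding structure are illustrative assumptions, not the actual lite.py code; only _dtypes.int16 and quant_mode.is_integer_quantization() come from the report above.

# Illustrative sketch only: the function name and structure are assumptions.
from tensorflow.python.framework import dtypes as _dtypes

def _supported_io_types(quant_mode):
    # float32 is always a valid inference input/output type.
    all_types = [_dtypes.float32]
    if quant_mode.is_integer_quantization():
        # Proposed behavior: always allow int16 alongside int8/uint8 instead
        # of gating it behind the 16x8 quantization mode.
        all_types += [_dtypes.int8, _dtypes.uint8, _dtypes.int16]
    return all_types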

hguihot added the TFLiteConverter label on Jun 29, 2022
mohantym added the comp:lite and type:feature labels on Jun 29, 2022
mohantym (Contributor) commented Jun 29, 2022

Hi @hguihot! Could you also share minimal standalone code? That will help expedite the issue. Thank you!

mohantym added the stat:awaiting response label on Jun 29, 2022
hguihot (Author) commented Jun 29, 2022

Here is an example with two convolutions that share the same 8-bit input; one has an 8-bit output and the other a 16-bit output.

import tensorflow as tf

def quant(x, num_bits=8):
    # Fake-quantize x to num_bits in the range [-1, 1] (narrow_range=False).
    return tf.quantization.fake_quant_with_min_max_args(x, -1, 1, num_bits, False)

class QConv(tf.keras.layers.Conv2D):
    # Conv2D with fake-quantized weights and activations.
    def __init__(self, filters, kernel_size, weight_quantizer, activation_quantizer):
        self.weight_quantizer = weight_quantizer
        self.activation_quantizer = activation_quantizer
        super().__init__(filters=filters, kernel_size=kernel_size)

    def call(self, bottom):
        return self.activation_quantizer(self.convolution_op(bottom, self.weight_quantizer(self.kernel)))

tf.keras.backend.set_image_data_format("channels_last")
input_tensor = tf.keras.Input(shape=(64, 64, 3), batch_size=1)
quantized_input_tensor = quant(input_tensor, num_bits=8)

# 8->8bit convolution
layer = QConv(filters=32, kernel_size=3, weight_quantizer=quant, activation_quantizer=lambda x: quant(x, num_bits=8))
output8 = layer(quantized_input_tensor)

# 8->16bit convolution
layer = QConv(filters=32, kernel_size=3, weight_quantizer=quant, activation_quantizer=lambda x: quant(x, num_bits=16))
output16 = layer(quantized_input_tensor)

model = tf.keras.Model(inputs=[input_tensor], outputs=[output8, output16])

train_save_path = "/tmp/debug_model"
convert_model_path = "/tmp/converted.tflite"

model.save(train_save_path)
converter = tf.lite.TFLiteConverter.from_saved_model(train_save_path)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.inference_input_type = tf.int8
converter.inference_output_type = tf.int16
tflite_model = converter.convert()

with open(convert_model_path, "wb") as f:
    f.write(tflite_model)

Two changes were actually required to make the conversion succeed.
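
For completeness, once the conversion goes through, the converted model's input/output dtypes can be checked with the standard tf.lite.Interpreter API. This is a minimal sketch; it assumes the file written above at /tmp/converted.tflite.

import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path="/tmp/converted.tflite")
interpreter.allocate_tensors()

# Expect an int8 input and one int8 plus one int16 output for the model above.
for detail in interpreter.get_input_details():
    print("input: ", detail["name"], detail["dtype"])
for detail in interpreter.get_output_details():
    print("output:", detail["name"], detail["dtype"])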

google-ml-butler bot removed the stat:awaiting response label on Jun 29, 2022
mohantym (Contributor) commented

Hi @sachinprasadhs! Could you look at this feature request? Gist attached for reference. Thank you!

mohantym assigned sachinprasadhs and unassigned themselves on Jun 30, 2022
sachinprasadhs added the stat:awaiting tensorflower label on Jul 1, 2022
hguihot (Author) commented Aug 31, 2022

Any update?

pjpratik added a commit that referenced this issue on Feb 2, 2023
Per feature request #56615, allow _dtypes.int16 even when 16x8 quantization is not used, so that custom ops returning 16-bit outputs can benefit.
pjpratik (Contributor) commented Feb 3, 2023

We created PR #59526 to enable support for dtypes.int16. The issue will be closed once it is merged. Thanks!

pjpratik added the awaiting PR merge label and removed the stat:awaiting tensorflower label on Feb 3, 2023
pjpratik self-assigned this on Feb 3, 2023
sachinprasadhs removed their assignment on Feb 16, 2023
pjpratik (Contributor) commented

Hi @hguihot

Support for int16 has been added in commit 33d76ac.

Closing this issue as resolved. Please reopen if you'd like to work on this further.
