
Cannot load saved tflite model when model contains dropout. However quantization-aware model is fine. #51848

Closed
HongtaoYang opened this issue Sep 6, 2021 · 4 comments
Labels: comp:lite (TF Lite related issues) · stale (to be closed automatically if no activity) · stat:awaiting response (awaiting response from author) · TF 2.5 (issues related to TF 2.5) · type:bug

Comments

@HongtaoYang

System information

  • Have I written custom code (as opposed to using a stock example script provided in TensorFlow): Yes
  • TensorFlow installed from (source or binary): Binary
  • TensorFlow version (use command below): 2.5.0
  • Python version: 3.8.9
  • CUDA/cuDNN version: CUDA 11.2.0, cuDNN 8.1.1.33

Describe the current behavior
If a model contains dropout layers, it can be converted to tflite, but the resulting tflite model cannot be loaded. The error message is:

ValueError: Did not get operators or tensors in subgraph 1.

Confusingly, if I make the model quantization-aware, the model can be successfully converted to tflite and loaded. Does QAT automatically remove dropout?

Describe the expected behavior
If a tflite model that contains dropout is not supported, then the desired behavior is one of:

  1. Raise an error during tflite conversion instead of at runtime (a fail-fast sketch follows this list), or
  2. Remove the dropout operation during tflite conversion.
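
For option 1, a fail-fast check can live right next to the converter call today; a minimal sketch, assuming the toy_model.tflite file written by the repro below:

import tensorflow as tf

# Smoke-test the converted model immediately after conversion so a broken
# model fails at export time rather than at deployment. The file name is
# the one used in the repro below.
try:
    interpreter = tf.lite.Interpreter(model_path="toy_model.tflite")
    interpreter.allocate_tensors()
except (ValueError, RuntimeError) as err:
    raise RuntimeError(f"Converted tflite model is not loadable: {err}") from err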

Standalone code to reproduce the issue

import numpy as np
import tensorflow as tf
import tensorflow_model_optimization as tfmot


quantize_annotate_layer = tfmot.quantization.keras.quantize_annotate_layer
quantize_apply = tfmot.quantization.keras.quantize_apply



def get_model(quantization_aware: bool = False):
    input_embeddings = tf.keras.Input(shape=(512,), dtype=tf.float32, name="input_embeddings")
    is_training = tf.keras.Input(shape=(), dtype=tf.bool, name="is_training")

    # Passing a tensor as `training` here is the root cause (see the follow-up comment below).
    x = tf.keras.layers.Dropout(rate=0.5)(inputs=input_embeddings, training=is_training)

    if quantization_aware:  # apply quantization to dense layer and return quantization-aware model
        out = quantize_annotate_layer(
            tf.keras.layers.Dense(units=512, name="dense")
        )(x)

        model = quantize_apply(
            tf.keras.Model(
                inputs=[input_embeddings, is_training],
                outputs=[x, out],
                name="toy_model",
            )
        )
    else:  # return vanilla model
        out = tf.keras.layers.Dense(units=512, name="dense")(x)
        model = tf.keras.Model(
            inputs=[input_embeddings, is_training],
            outputs=[x, out],
            name="toy_model",
        )

    return model

def convert_to_tflite(saved_model_path, output_model_path):
    converter = tf.lite.TFLiteConverter.from_saved_model(str(saved_model_path))
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    converter.target_spec.supported_ops = [
        tf.lite.OpsSet.TFLITE_BUILTINS,
        tf.lite.OpsSet.SELECT_TF_OPS,
    ]
    tflite_quant_model = converter.convert()

    with open(output_model_path, "wb") as fh:
        fh.write(tflite_quant_model)



model = get_model(quantization_aware=False)  # False reproduces the error; True runs successfully.
input_data = np.random.rand(16, 512)
embeddings, out = model([input_data, True])

model.save("toy_model")

convert_to_tflite("toy_model", "toy_model.tflite")
interpreter = tf.lite.Interpreter(model_path="toy_model.tflite")

I'm wondering how I should deal with dropout layers in tflite. Since dropout is essential for training, how can I export a trained model without it? This is a similar issue, and I'm not satisfied with the workaround proposed there.

@HongtaoYang (Author) commented Sep 6, 2021

Oh I figured out what I did wrong and also the correct way to use dropout.

Basically, I shouldn't explicitly pass the training argument to the dropout layer; it is handled at model invocation, where we can put the model in training mode with out = model(input_batch, training=True). This way tflite can handle dropout correctly.
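
A minimal sketch of that pattern (it reuses the layer names from the repro above; the explicit is_training input is gone, and only the dense output is returned for brevity):

import numpy as np
import tensorflow as tf

# No explicit `training` input tensor: Keras threads the training flag
# through the functional graph on its own.
inputs = tf.keras.Input(shape=(512,), dtype=tf.float32, name="input_embeddings")
x = tf.keras.layers.Dropout(rate=0.5)(inputs)
out = tf.keras.layers.Dense(units=512, name="dense")(x)
model = tf.keras.Model(inputs=inputs, outputs=out, name="toy_model")

# Training mode is chosen per call instead of being baked into the graph.
_ = model(np.random.rand(16, 512).astype(np.float32), training=True)

# The SavedModel traces the inference path (training=False), so the
# converter can fold dropout away.
model.save("toy_model")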

However, I still don't understand why the quantization-aware model doesn't have the dropout issue.

@tilakrayal added the TF 2.5 and comp:lite labels on Sep 6, 2021
@jvishnuvardhan added the stat:awaiting tensorflower label on Oct 5, 2021
@pjpratik self-assigned this on Feb 20, 2023
@pjpratik (Contributor) commented Mar 9, 2023

@HongtaoYang I was checking if this is still an issue.

I have tried with TF Nightly 2.13.0-dev20230308 and TF MOT 0.7.3. I did not face any error with quantization_aware = False. Please find the gist here.

When the converted model is inspected in netron, the dropout op is indeed removed during conversion.
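
For anyone who wants to confirm this without netron, a minimal sketch (assuming TF 2.7+, where tf.lite.experimental.Analyzer is available, and the toy_model.tflite file produced by the repro above):

import tensorflow as tf

# Prints every subgraph and operator in the converted model; after
# conversion no dropout-related ops should appear, only FULLY_CONNECTED
# (the dense layer) plus any quantize/dequantize ops.
tf.lite.experimental.Analyzer.analyze(model_path="toy_model.tflite")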

Thanks.

@pjpratik added the stat:awaiting response label and removed the stat:awaiting tensorflower label on Mar 9, 2023
@pjpratik added the stale label on Mar 21, 2023
@github-actions

This issue was closed because it has been inactive for 7 days since being marked as stale. Please reopen if you'd like to work on this further.
