TFLite 2.16.1 conversions fail with "AttributeError: 'Sequential' object has no attribute '_get_save_spec'" #63867

Open
schwefeljm opened this issue Mar 18, 2024 · 20 comments
Assignees: lu-wang-g
Labels: comp:lite (TF Lite related issues), stat:awaiting response (Status - Awaiting response from author), TF 2.16, TFLiteConverter (For issues related to TFLite converter), type:bug

Comments

@schwefeljm commented Mar 18, 2024

1. System information

  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04):
    Debian 13 - Trixie
  • TensorFlow installation (pip package or built from source):
    pip
  • TensorFlow library (version, if pip package or github SHA, if built from source):
    2.16.0-rc0 & 2.16.1

I am brand new to TensorFlow and welcome any suggestions.
I tested the exact same code on 2.15.0.post1 and 2.16.x: it runs on 2.15 and errors out on 2.16.

2. Code

Provide code to help us reproduce your issues using one of the following options:

Option B: Paste your code here or provide a link to a custom end-to-end colab

import pathlib
import random

import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential


def main():

    # Removed dependency on argsparser()
    # args = argsparer()

    # batch_size = args.batchSize # 32
    batch_size = 32
    # img_height = args.height    # 128
    img_height = 128
    # img_width = args.width      # 128
    img_width = 128
    activation1 = 'swish'
    activation2 = 'elu'
    optimizer = tf.keras.optimizers.AdamW(learning_rate=0.0005)
    # epochs = args.epochs        # 1
    epochs = 1

    train_data_dir = pathlib.Path("birds/birds/train/").with_suffix('')
    val_data_dir = pathlib.Path("birds/birds/valid/").with_suffix('')

    # assume all files in training directory are readable image files.
    image_count = len(list(train_data_dir.glob('*/*.*')))

    AUTOTUNE = tf.data.AUTOTUNE

    train_ds = tf.keras.utils.image_dataset_from_directory(
        train_data_dir,
        seed=random.randint(1, 10000),
        image_size=(img_height, img_width),
        batch_size=batch_size)

    val_ds = tf.keras.utils.image_dataset_from_directory(
        val_data_dir,
        # validation_split=0.1,
        # subset="validation",
        seed=random.randint(1, 10000),
        image_size=(img_height, img_width),
        batch_size=batch_size)

    num_classes = len(train_ds.class_names)
    print("Training directory '{}' contains {} images in {} categories.".format(train_data_dir, image_count, num_classes))

    normalization_layer = layers.Rescaling(1. / 255)
    train_ds_norm = train_ds.map(lambda x, y: (normalization_layer(x), y),
                                num_parallel_calls=tf.data.AUTOTUNE,
                                deterministic=False)
    val_ds_norm = val_ds.map(lambda x, y: (normalization_layer(x), y),
                                num_parallel_calls=tf.data.AUTOTUNE,
                                deterministic=False)

    # train_ds = train_ds.shuffle(1000).prefetch(buffer_size=AUTOTUNE)
    # val_ds = val_ds.prefetch(buffer_size=AUTOTUNE)

    data_augmentation = tf.keras.Sequential(
        [
            layers.RandomFlip("horizontal",
                              input_shape=(img_height,
                                           img_width,
                                           3)),
            layers.RandomRotation(0.1),
            layers.RandomZoom(0.1),
        ]
    )

    model = Sequential([
        data_augmentation,
        # layers.Rescaling(1./255, input_shape=(img_height, img_width, 3)),
        # layers.Conv2D(16, 3, padding='same', activation=activation1),
        # layers.MaxPooling2D(),
        layers.Conv2D(32, 3, padding='same', activation=activation1),
        layers.MaxPooling2D(),
        layers.Conv2D(48, 3, padding='same', activation=activation2),
        layers.MaxPooling2D(),
        layers.Conv2D(64, 3, padding='same', activation=activation1),
        layers.MaxPooling2D(),
        layers.Dropout(0.15),
        layers.Flatten(),
        layers.Dense(128, activation=activation2),
        layers.Dense(num_classes)
    ])

    model.compile(optimizer=optimizer,
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

    history = model.fit(
        train_ds_norm,
        validation_data=val_ds_norm,
        epochs=epochs)


    # Convert the model.
    converter = tf.lite.TFLiteConverter.from_keras_model(model)
    converter.experimental_new_converter = True
    converter.target_spec.supported_ops = [
        tf.lite.OpsSet.TFLITE_BUILTINS,  # enable TensorFlow Lite ops.
        tf.lite.OpsSet.SELECT_TF_OPS  # enable TensorFlow ops.
    ]
    tflite_model = converter.convert()

    # Save the model.
    with open('model.tflite', 'wb') as f:
        f.write(tflite_model)


if __name__ == "__main__":
    main()


3. Failure after conversion

Conversion completely fails with "AttributeError: 'Sequential' object has no attribute '_get_save_spec'. Did you mean: '_set_save_spec'?" See below for complete log and traceback.

I have had to revert to 2.15.0.post1 to get the model to convert and save as TFLite.
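For reference, a quick way to confirm which Keras is in the environment (a sketch; TF 2.16.x ships with Keras 3.x, while TF 2.15.x ships with Keras 2.x):

import tensorflow as tf
import keras

# Shows which Keras major version is bundled with the installed TensorFlow.
print(tf.__version__, keras.__version__)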

5. (optional) Any other info / logs

2024-03-18 09:24:25.859843: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-03-18 09:24:25.884237: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2024-03-18 09:24:26.564689: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:998] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
2024-03-18 09:24:26.569165: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:998] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
2024-03-18 09:24:26.569272: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:998] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
2024-03-18 09:24:26.570638: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:998] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
2024-03-18 09:24:26.570740: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:998] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
2024-03-18 09:24:26.570825: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:998] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
2024-03-18 09:24:26.616558: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:998] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
2024-03-18 09:24:26.616674: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:998] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
2024-03-18 09:24:26.616763: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:998] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
2024-03-18 09:24:26.616839: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1928] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 11308 MB memory:  -> device: 0, name: NVIDIA RTX A4000, pci bus id: 0000:01:00.0, compute capability: 8.6
Found 84635 files belonging to 525 classes.
Found 2625 files belonging to 525 classes.
Training directory 'birds/birds/train' contains 84635 images in 525 categories.
/home/jschwefel/repositories/aiml/TensorFlow/tfbase/lib/python3.11/site-packages/keras/src/layers/preprocessing/tf_data_layer.py:19: UserWarning: Do not pass an `input_shape`/`input_dim` argument to a layer. When using Sequential models, prefer using an `Input(shape)` object as the first layer in the model instead.
  super().__init__(**kwargs)
2024-03-18 09:24:29.337903: E tensorflow/core/grappler/optimizers/meta_optimizer.cc:961] layout failed: INVALID_ARGUMENT: Size of values 0 does not match size of permutation 4 @ fanin shape inStatefulPartitionedCall/sequential_1_1/dropout_1/stateless_dropout/SelectV2-2-TransposeNHWCToNCHW-LayoutOptimizer
2024-03-18 09:24:29.937251: I external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:465] Loaded cuDNN version 8907
2645/2645 ━━━━━━━━━━━━━━━━━━━━ 46s 17ms/step - accuracy: 0.0066 - loss: 6.1631 - val_accuracy: 0.0579 - val_loss: 5.2888
Traceback (most recent call last):
  File "/home/jschwefel/repositories/aiml/TensorFlow/Test001/imageclassifier.py", line 129, in <module>
    main()
  File "/home/jschwefel/repositories/aiml/TensorFlow/Test001/imageclassifier.py", line 122, in main
    tflite_model = converter.convert()
                   ^^^^^^^^^^^^^^^^^^^
  File "/home/jschwefel/repositories/aiml/TensorFlow/tfbase/lib/python3.11/site-packages/tensorflow/lite/python/lite.py", line 1175, in wrapper
    return self._convert_and_export_metrics(convert_func, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jschwefel/repositories/aiml/TensorFlow/tfbase/lib/python3.11/site-packages/tensorflow/lite/python/lite.py", line 1129, in _convert_and_export_metrics
    result = convert_func(self, *args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jschwefel/repositories/aiml/TensorFlow/tfbase/lib/python3.11/site-packages/tensorflow/lite/python/lite.py", line 1641, in convert
    self._freeze_keras_model()
  File "/home/jschwefel/repositories/aiml/TensorFlow/tfbase/lib/python3.11/site-packages/tensorflow/lite/python/convert_phase.py", line 215, in wrapper
    raise error from None  # Re-throws the exception.
    ^^^^^^^^^^^^^^^^^^^^^
  File "/home/jschwefel/repositories/aiml/TensorFlow/tfbase/lib/python3.11/site-packages/tensorflow/lite/python/convert_phase.py", line 205, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/jschwefel/repositories/aiml/TensorFlow/tfbase/lib/python3.11/site-packages/tensorflow/lite/python/lite.py", line 1582, in _freeze_keras_model
    input_signature = _model_input_signature(
                      ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jschwefel/repositories/aiml/TensorFlow/tfbase/lib/python3.11/site-packages/tensorflow/lite/python/tflite_keras_util.py", line 84, in model_input_signature
    input_specs = model._get_save_spec(  # pylint: disable=protected-access
                  ^^^^^^^^^^^^^^^^^^^^
AttributeError: 'Sequential' object has no attribute '_get_save_spec'. Did you mean: '_set_save_spec'?
@schwefeljm schwefeljm added the TFLiteConverter For issues related to TFLite converter label Mar 18, 2024
@tilakrayal tilakrayal added comp:lite TF Lite related issues TF 2.16 type:bug Bug labels Mar 19, 2024
@Aloqeely (Contributor) commented:
    model = Sequential([
        data_augmentation,
        # layers.Rescaling(1./255, input_shape=(img_height, img_width, 3)),
        # layers.Conv2D(16, 3, padding='same', activation=activation1),
        # layers.MaxPooling2D(),
        layers.Conv2D(32, 3, padding='same', activation=activation1),
        layers.MaxPooling2D(),
        layers.Conv2D(48, 3, padding='same', activation=activation2),
        layers.MaxPooling2D(),
        layers.Conv2D(64, 3, padding='same', activation=activation1),
        layers.MaxPooling2D(),
        layers.Dropout(0.15),
        layers.Flatten(),
        layers.Dense(128, activation=activation2),
        layers.Dense(num_classes)
    ])

Sequential here might be imported from something other than tf.keras. Please take a look at #39001
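For illustration, forcing the tf.keras imports would look like this (a sketch, assuming `layers` should also come from tf.keras):

import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential  # instead of e.g. `from keras.models import Sequential`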

@LakshmiKalaKadali (Contributor) commented Mar 20, 2024

Hi @schwefeljm,

I am trying to reproduce the code; meanwhile I encountered another error. Is it `args = argparser()`? Please provide complete reproducible code or a link so we can debug the issue.

Thank You

@LakshmiKalaKadali LakshmiKalaKadali added the stat:awaiting response Status - Awaiting response from author label Mar 20, 2024
@schwefeljm (Author) commented:

Hi @schwefeljm,

I am trying to reproduce the code; meanwhile I encountered another error. Is it `args = argparser()`? Please provide complete reproducible code or a link so we can debug the issue.

Thank You
Hi @LakshmiKalaKadali,

I updated the code to remove the dependency on `argsparser()`.

The dataset I used is from https://www.kaggle.com/datasets/gpiosenka/100-bird-species, though I expect it to work on any dataset.

Jason

@google-ml-butler google-ml-butler bot removed the stat:awaiting response Status - Awaiting response from author label Mar 20, 2024
@schwefeljm (Author) commented:

Sequential here might be imported from something other than tf.keras. Please take a look at #39001
@Aloqeely

I went through and forced `tf.keras.models.Sequential` and it had no effect. Thank you for the suggestion, though.

Jason

@pkgoogle commented:

I was able to replicate this on tf-nightly as well as 2.16.1 (gist). I reduced the reproducible sample to what matters, i.e. the training process is actually irrelevant. This appears to happen when using Keras with the TFLite converter from 2.16.1 onward.
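For anyone following along, the reduced repro looks roughly like this (a sketch, assuming TF 2.16.x with its bundled Keras 3; no training or dataset is needed to hit the error):

import tensorflow as tf

# Any untrained Keras 3 model is enough to trigger the failure.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(1),
])

converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()  # AttributeError: 'Sequential' object has no attribute '_get_save_spec'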

Hi @haozha111, can you please take a look?

@pkgoogle pkgoogle added the stat:awaiting tensorflower Status - Awaiting response from tensorflower label Mar 27, 2024
@haozha111 haozha111 assigned lu-wang-g and unassigned haozha111 Mar 27, 2024
@Takudzwamz commented Apr 10, 2024

I'm experiencing the same error here

converter = tf.lite.TFLiteConverter.from_keras_model(model)
tfmodel = converter.convert()
open ("mobilenetv2_finetuned.tflite" , "wb") .write(tfmodel)

Will it help if I downgrade to an older version of TensorFlow?

@schwefeljm (Author) commented:

I'm experiencing the same error here

converter = tf.lite.TFLiteConverter.from_keras_model(model)
tfmodel = converter.convert()
open ("mobilenetv2_finetuned.tflite" , "wb") .write(tfmodel)

Will it help if I downgrade to an older version of TensorFlow?

@Takudzwamz

For me it did.

@pkgoogle commented:

Hi @Takudzwamz, it seems like 2.15 does not exhibit this behavior currently.

@Takudzwamz commented:

@pkgoogle I downgraded to 2.15 and it worked, thanks.

@Lerxp commented Apr 16, 2024

Referencing legacy Keras worked for me: https://blog.tensorflow.org/2024/03/whats-new-in-tensorflow-216.html

@pkgoogle commented:

As @Lerxp stated, using Legacy Keras is the preferred workaround for now:

import os; os.environ["TF_USE_LEGACY_KERAS"] = "1"
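A slightly fuller sketch of that workaround (assumptions: the tf_keras package is installed, and the environment variable is set before tensorflow is imported, otherwise it has no effect):

import os
os.environ["TF_USE_LEGACY_KERAS"] = "1"  # must run before `import tensorflow`

import tensorflow as tf

# tf.keras now resolves to legacy Keras 2, which still has _get_save_spec(),
# so the from_keras_model() path in this issue works again.
model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()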

@gusbernard commented:

I'm having the same issue, but Legacy Keras isn't working either.

@pkgoogle commented:

Hi @gusbernard, do you have a notebook or code snippet to share? Additionally have you tried using 2.15 for now? Thanks for any information you can provide.

@HCzou commented Apr 30, 2024

I met the same problem when using TensorFlow 2.16.1 with Python 3.12.3. I tried to downgrade TensorFlow to 2.15.0, but it does not support Python 3.12.3. I am trying to downgrade Python now...

@rudolflovrencic commented:

I have the same issue when trying to convert a keras model to TFLite exactly like in the docs.

Python     3.12.3
Keras      3.3.3
Tensorflow 2.16.1

@kariiho commented May 24, 2024

I used model.export() to deal with the problem: in Keras 3, export() writes a TensorFlow SavedModel, and TFLiteConverter.from_saved_model() avoids the Keras-model code path that calls _get_save_spec().

# Export the keras model to a saved model format
model.export("saved_model")

# Convert the saved model to TensorFlow Lite
converter = tf.lite.TFLiteConverter.from_saved_model("saved_model")
tflite_model = converter.convert()

# Save the TensorFlow Lite model to a file
with open("model.tflite", "wb") as f:
    f.write(tflite_model)

@pkgoogle commented:

Hi @schwefeljm, have you heard of AI-Edge-Torch? You can find more information here: googleblog. You will not run into this issue if you go this route.

I have created a script for converting your model here:

import torch
import torch.nn as nn
import ai_edge_torch


class CustomCVModel(nn.Module):
    def __init__(self, num_classes, img_height, img_width, activation1=nn.SiLU, activation2=nn.ELU):
        super().__init__()
        
        self.feature_extractor = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding="same"),
            activation1(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 48, kernel_size=3, padding="same"),
            activation2(),
            nn.MaxPool2d(2),
            nn.Conv2d(48, 64, kernel_size=3, padding="same"),
            activation1(),
            nn.MaxPool2d(2),
            nn.Dropout(0.15),
            nn.Flatten(),
        )

        self.classifier = nn.Sequential(
            nn.Linear(64 * (img_height // 8) * (img_width // 8), 128),  
            activation2(),
            nn.Linear(128, num_classes)
        )

    def forward(self, x):
        x = self.feature_extractor(x)
        x = self.classifier(x)
        return x


img_height = 128
img_width = 128

# Instantiate the model and provide a sample input (NCHW) for tracing.
model = CustomCVModel(1000, img_height, img_width, nn.SiLU, nn.ELU)
sample_inputs = (torch.randn(1, 3, img_height, img_width),)

# Convert with AI-Edge-Torch and write the resulting .tflite flatbuffer.
edge_model = ai_edge_torch.convert(model.eval(), sample_inputs)
edge_model.export("custom_cv_model.tflite")

You will still need to modify your training code but I have tested this and the conversion does work w/o issue.

If you want to, you can actually try visualizing the result in model-explorer as well.

Please try them out and let us know if this resolves your issue. If you still need further help, feel free to open a new issue at the respective repo.

@pkgoogle pkgoogle added stat:awaiting response Status - Awaiting response from author and removed stat:awaiting tensorflower Status - Awaiting response from tensorflower labels Jun 11, 2024
@github-actions (bot) commented:

This issue is stale because it has been open for 7 days with no activity. It will be closed if no further activity occurs. Thank you.

@github-actions github-actions bot added the stale This label marks the issue/pr stale - to be closed automatically if no activity label Jun 19, 2024
@roboticvedant commented:

@pkgoogle I downgraded to 2.15 and it worked, thanks.

I am having trouble downgrading my TF. Can you share how you did it and which Python version you used?

@google-ml-butler google-ml-butler bot removed stale This label marks the issue/pr stale - to be closed automatically if no activity stat:awaiting response Status - Awaiting response from author labels Jun 23, 2024
@pkgoogle pkgoogle added the stat:awaiting response Status - Awaiting response from author label Jun 24, 2024
@GuidoBartoli commented:

I have the same issue with a Keras 3 model saved into .keras format and reloaded with keras.models.load_model():

import keras as ks
import tensorflow as tf

model = ks.models.load_model("model.keras")
converter = tf.lite.TFLiteConverter.from_keras_model(model)

This is the error message:

Traceback (most recent call last):
  File "/home/bartoli/Projects/deeplab/deploy.py", line 29, in <module>
    deployment.deploy(args.model, args.output, args.quantize, quant_ds)
  File "/home/bartoli/Projects/deeplab/lib/deployment.py", line 67, in deploy
    converted = converter.convert()
                ^^^^^^^^^^^^^^^^^^^
  File "/home/bartoli/miniconda3/envs/dl/lib/python3.12/site-packages/tensorflow/lite/python/lite.py", line 1175, in wrapper
    return self._convert_and_export_metrics(convert_func, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/bartoli/miniconda3/envs/dl/lib/python3.12/site-packages/tensorflow/lite/python/lite.py", line 1129, in _convert_and_export_metrics
    result = convert_func(self, *args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/bartoli/miniconda3/envs/dl/lib/python3.12/site-packages/tensorflow/lite/python/lite.py", line 1641, in convert
    self._freeze_keras_model()
  File "/home/bartoli/miniconda3/envs/dl/lib/python3.12/site-packages/tensorflow/lite/python/convert_phase.py", line 215, in wrapper
    raise error from None  # Re-throws the exception.
    ^^^^^^^^^^^^^^^^^^^^^
  File "/home/bartoli/miniconda3/envs/dl/lib/python3.12/site-packages/tensorflow/lite/python/convert_phase.py", line 205, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/bartoli/miniconda3/envs/dl/lib/python3.12/site-packages/tensorflow/lite/python/lite.py", line 1582, in _freeze_keras_model
    input_signature = _model_input_signature(
                      ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/bartoli/miniconda3/envs/dl/lib/python3.12/site-packages/tensorflow/lite/python/tflite_keras_util.py", line 84, in model_input_signature
    input_specs = model._get_save_spec(  # pylint: disable=protected-access
                  ^^^^^^^^^^^^^^^^^^^^
AttributeError: 'Sequential' object has no attribute '_get_save_spec'. Did you mean: '_set_save_spec'?

@google-ml-butler google-ml-butler bot removed the stat:awaiting response Status - Awaiting response from author label Jun 28, 2024
@pkgoogle pkgoogle added the stat:awaiting response Status - Awaiting response from author label Jun 28, 2024