[go: nahoru, domu]

Skip to content

Commit

Permalink
Add saved_model export format.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 308775711
  • Loading branch information
ziyeqinghan authored and Copybara-Service committed Apr 28, 2020
1 parent f1dc9b6 commit 689abda
Show file tree
Hide file tree
Showing 16 changed files with 296 additions and 200 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Model export format such as saved_model / tflite."""
"""Export format such as saved_model / tflite."""

from __future__ import absolute_import
from __future__ import division
Expand All @@ -22,6 +22,8 @@


@unique
class ModelExportFormat(Enum):
TFLITE = 0
SAVEDMODEL = 1
class ExportFormat(Enum):
TFLITE = "TFLITE"
SAVED_MODEL = "SAVED_MODEL"
LABEL = "LABEL"
VOCAB = "VOCAB"
Original file line number Diff line number Diff line change
Expand Up @@ -20,19 +20,17 @@

import numpy as np
import tensorflow.compat.v2 as tf
from tensorflow_examples.lite.model_maker.core import model_export_format as mef
from tensorflow_examples.lite.model_maker.core.task import custom_model


class ClassificationModel(custom_model.CustomModel):
""""The abstract base class that represents a Tensorflow classification model."""

def __init__(self, model_export_format, model_spec, index_to_label,
num_classes, shuffle, train_whole_model):
def __init__(self, model_spec, index_to_label, num_classes, shuffle,
train_whole_model):
"""Initialize a instance with data, deploy mode and other related parameters.
Args:
model_export_format: Model export format such as saved_model / tflite.
model_spec: Specification for the model.
index_to_label: A list that map from index to label class name.
num_classes: Number of label classes.
Expand All @@ -41,12 +39,7 @@ def __init__(self, model_export_format, model_spec, index_to_label,
classification layer on top. Otherwise, only train the top
classification layer.
"""
if model_export_format != mef.ModelExportFormat.TFLITE:
raise ValueError('Model export format %s is not supported currently.' %
str(model_export_format))

super(ClassificationModel, self).__init__(model_export_format, model_spec,
shuffle)
super(ClassificationModel, self).__init__(model_spec, shuffle)
self.index_to_label = index_to_label
self.num_classes = num_classes
self.train_whole_model = train_whole_model
Expand Down Expand Up @@ -90,28 +83,10 @@ def predict_top_k(self, data, k=1, batch_size=32):

return label_prob

def _export_tflite(self,
tflite_filename,
label_filename,
quantized=False,
quantization_steps=None,
representative_data=None):
"""Converts the retrained model to tflite format and saves it.
def _export_labels(self, label_filepath):
if label_filepath is None:
raise ValueError("Label filepath couldn't be None when exporting labels.")

Args:
tflite_filename: File name to save tflite model.
label_filename: File name to save labels.
quantized: boolean, if True, save quantized model.
quantization_steps: Number of post-training quantization calibration steps
to run. Used only if `quantized` is True.
representative_data: Representative data used for post-training
quantization. Used only if `quantized` is True.
"""
super(ClassificationModel,
self)._export_tflite(tflite_filename, quantized, quantization_steps,
representative_data)

with tf.io.gfile.GFile(label_filename, 'w') as f:
tf.compat.v1.logging.info('Saving labels in %s.', label_filepath)
with tf.io.gfile.GFile(label_filepath, 'w') as f:
f.write('\n'.join(self.index_to_label))

tf.compat.v1.logging.info('Saved labels in %s.', label_filename)
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@
from __future__ import division
from __future__ import print_function

import os

import tensorflow.compat.v2 as tf
from tensorflow_examples.lite.model_maker.core import model_export_format as mef
from tensorflow_examples.lite.model_maker.core import test_util
from tensorflow_examples.lite.model_maker.core.task import classification_model

Expand All @@ -36,27 +37,29 @@ def evaluate(self, data, **kwargs):

class ClassificationModelTest(tf.test.TestCase):

def test_predict_top_k(self):
input_shape = [24, 24, 3]
num_classes = 2
model = MockClassificationModel(
model_export_format=mef.ModelExportFormat.TFLITE,
def setUp(self):
super(ClassificationModelTest, self).setUp()
self.num_classes = 2
self.model = MockClassificationModel(
model_spec=None,
index_to_label=['pos', 'neg'],
num_classes=2,
train_whole_model=False,
shuffle=False)
model.model = test_util.build_model(input_shape, num_classes)
data = test_util.get_dataloader(2, input_shape, num_classes)

topk_results = model.predict_top_k(data, k=2, batch_size=1)
def test_predict_top_k(self):
input_shape = [24, 24, 3]
self.model.model = test_util.build_model(input_shape, self.num_classes)
data = test_util.get_dataloader(2, input_shape, self.num_classes)

topk_results = self.model.predict_top_k(data, k=2, batch_size=1)
for topk_result in topk_results:
top1_result, top2_result = topk_result[0], topk_result[1]
top1_label, top1_prob = top1_result[0], top1_result[1]
top2_label, top2_prob = top2_result[0], top2_result[1]

self.assertIn(top1_label, model.index_to_label)
self.assertIn(top2_label, model.index_to_label)
self.assertIn(top1_label, self.model.index_to_label)
self.assertIn(top2_label, self.model.index_to_label)
self.assertNotEqual(top1_label, top2_label)

self.assertLessEqual(top1_prob, 1)
Expand All @@ -65,6 +68,13 @@ def test_predict_top_k(self):

self.assertEqual(top1_prob + top2_prob, 1.0)

def test_export_labels(self):
labels_output_file = os.path.join(self.get_temp_dir(), 'label')
self.model._export_labels(labels_output_file)
with tf.io.gfile.GFile(labels_output_file, 'r') as f:
labels = [label.strip() for label in f]
self.assertEqual(labels, ['pos', 'neg'])


if __name__ == '__main__':
tf.test.main()
52 changes: 39 additions & 13 deletions tensorflow_examples/lite/model_maker/core/task/custom_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@

import tensorflow.compat.v2 as tf
from tensorflow_examples.lite.model_maker.core import compat
from tensorflow_examples.lite.model_maker.core import model_export_format as mef

DEFAULT_QUANTIZATION_STEPS = 2000

Expand All @@ -41,19 +40,13 @@ def representative_dataset_gen():
class CustomModel(abc.ABC):
""""The abstract base class that represents a Tensorflow classification model."""

def __init__(self, model_export_format, model_spec, shuffle):
def __init__(self, model_spec, shuffle):
"""Initialize a instance with data, deploy mode and other related parameters.
Args:
model_export_format: Model export format such as saved_model / tflite.
model_spec: Specification for the model.
shuffle: Whether the data should be shuffled.
"""
if model_export_format != mef.ModelExportFormat.TFLITE:
raise ValueError('Model export format %s is not supported currently.' %
str(model_export_format))

self.model_export_format = model_export_format
self.model_spec = model_spec
self.shuffle = shuffle
self.model = None
Expand Down Expand Up @@ -104,21 +97,56 @@ def _gen_dataset(self,
ds = ds.prefetch(tf.data.experimental.AUTOTUNE)
return ds

def _export_saved_model(self,
filepath,
overwrite=True,
include_optimizer=True,
save_format=None,
signatures=None,
options=None):
"""Saves the model to Tensorflow SavedModel or a single HDF5 file.
Args:
filepath: String, path to SavedModel or H5 file to save the model.
overwrite: Whether to silently overwrite any existing file at the target
location, or provide the user with a manual prompt.
include_optimizer: If True, save optimizer's state together.
save_format: Either 'tf' or 'h5', indicating whether to save the model to
Tensorflow SavedModel or HDF5. Defaults to 'tf' in TF 2.X, and 'h5' in
TF 1.X.
signatures: Signatures to save with the SavedModel. Applicable to the 'tf'
format only. Please see the `signatures` argument in
`tf.saved_model.save` for details.
options: Optional `tf.saved_model.SaveOptions` object that specifies
options for saving to SavedModel.
"""
if filepath is None:
raise ValueError(
"SavedModel filepath couldn't be None when exporting to SavedModel.")
self.model.save(filepath, overwrite, include_optimizer, save_format,
signatures, options)

def _export_tflite(self,
tflite_filename,
tflite_filepath,
quantized=False,
quantization_steps=None,
representative_data=None):
"""Converts the retrained model to tflite format and saves it.
Args:
tflite_filename: File name to save tflite model.
tflite_filepath: File path to save tflite model.
quantized: boolean, if True, save quantized model.
quantization_steps: Number of post-training quantization calibration steps
to run. Used only if `quantized` is True.
representative_data: Representative data used for post-training
quantization. Used only if `quantized` is True.
"""
if tflite_filepath is None:
raise ValueError(
"TFLite filepath couldn't be None when exporting to tflite.")

tf.compat.v1.logging.info('Exporting to tflite model in %s.',
tflite_filepath)
temp_dir = None
if compat.get_tf_behavior() == 1:
temp_dir = tempfile.TemporaryDirectory()
Expand Down Expand Up @@ -149,7 +177,5 @@ def _export_tflite(self,
if temp_dir:
temp_dir.cleanup()

with tf.io.gfile.GFile(tflite_filename, 'wb') as f:
with tf.io.gfile.GFile(tflite_filepath, 'wb') as f:
f.write(tflite_model)

tf.compat.v1.logging.info('Export to tflite model in %s.', tflite_filename)
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@

import numpy as np
import tensorflow.compat.v2 as tf
from tensorflow_examples.lite.model_maker.core import model_export_format as mef
from tensorflow_examples.lite.model_maker.core import test_util
from tensorflow_examples.lite.model_maker.core.task import custom_model

Expand All @@ -42,7 +41,6 @@ class CustomModelTest(tf.test.TestCase):
def setUp(self):
super(CustomModelTest, self).setUp()
self.model = MockCustomModel(
model_export_format=mef.ModelExportFormat.TFLITE,
model_spec=None,
shuffle=False)

Expand All @@ -67,6 +65,13 @@ def test_export_tflite(self):
self.model._export_tflite(tflite_file)
self._test_tflite(self.model.model, tflite_file, input_dim)

def test_export_saved_model(self):
self.model.model = test_util.build_model(input_shape=[4], num_classes=2)
saved_model_filepath = os.path.join(self.get_temp_dir(), 'saved_model/')
self.model._export_saved_model(saved_model_filepath)
self.assertTrue(os.path.isdir(saved_model_filepath))
self.assertNotEqual(len(os.listdir(saved_model_filepath)), 0)

def test_export_tflite_quantized(self):
input_dim = 4
num_classes = 2
Expand Down
Loading

0 comments on commit 689abda

Please sign in to comment.