[go: nahoru, domu]

Skip to content

Commit

Permalink
Add tfjs to image classification and text classification models
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 345168536
  • Loading branch information
lintian06 authored and Copybara-Service committed Dec 2, 2020
1 parent ced21f3 commit c7f6c3e
Show file tree
Hide file tree
Showing 8 changed files with 52 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class ClassificationModel(custom_model.CustomModel):

DEFAULT_EXPORT_FORMAT = (ExportFormat.TFLITE, ExportFormat.LABEL)
ALLOWED_EXPORT_FORMAT = (ExportFormat.TFLITE, ExportFormat.LABEL,
ExportFormat.SAVED_MODEL)
ExportFormat.SAVED_MODEL, ExportFormat.TFJS)

def __init__(self, model_spec, index_to_label, shuffle, train_whole_model):
"""Initialize a instance with data, deploy mode and other related parameters.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ def export(self,
label_filename='labels.txt',
vocab_filename='vocab',
saved_model_filename='saved_model',
tfjs_folder_name='tfjs',
export_format=None,
**kwargs):
"""Converts the retrained model based on `export_format`.
Expand All @@ -138,11 +139,13 @@ def export(self,
{export_dir}/{tflite_filename}.
label_filename: File name to save labels. The full export path is
{export_dir}/{label_filename}.
vocab_filename: File name to save vocabulary. The full export path is
vocab_filename: File name to save vocabulary. The full export path is
{export_dir}/{vocab_filename}.
saved_model_filename: Path to SavedModel or H5 file to save the model. The
full export path is
{export_dir}/{saved_model_filename}/{saved_model.pb|assets|variables}.
tfjs_folder_name: Folder name to save tfjs model. The full export path is
{export_dir}/{tfjs_folder_name}.
export_format: List of export format that could be saved_model, tflite,
label, vocab.
**kwargs: Other parameters like `quantized` for TFLITE model.
Expand All @@ -168,7 +171,8 @@ def export(self,
**export_saved_model_kwargs)

if ExportFormat.TFJS in export_format:
self._export_tfjs(export_dir)
tfjs_output_path = os.path.join(export_dir, tfjs_folder_name)
self._export_tfjs(tfjs_output_path)

if ExportFormat.VOCAB in export_format:
if with_metadata:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
from tensorflow_examples.lite.model_maker.core import test_util
from tensorflow_examples.lite.model_maker.core.export_format import ExportFormat
from tensorflow_examples.lite.model_maker.core.task import custom_model
from tensorflow_examples.lite.model_maker.core.task import model_util


class MockCustomModel(custom_model.CustomModel):
Expand Down Expand Up @@ -102,10 +101,10 @@ def test_export(self):
self._check_nonempty_file(os.path.join(export_path, 'model.tflite'))
self._check_nonempty_dir(os.path.join(export_path, 'saved_model'))

if model_util.HAS_TFJS:
export_path = os.path.join(self.get_temp_dir(), 'export4/')
self.model.export(export_path, export_format=[ExportFormat.TFJS])
self._check_nonempty_file(os.path.join(export_path, 'model.json'))
export_path = os.path.join(self.get_temp_dir(), 'export4/')
self.model.export(export_path, export_format=[ExportFormat.TFJS])
expected_file = os.path.join(export_path, 'tfjs', 'model.json')
self._check_nonempty_file(expected_file)


if __name__ == '__main__':
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ def test_mobilenetv2_model(self):
self._test_export_to_tflite_with_metadata(model)
self._test_export_to_saved_model(model)
self._test_export_labels(model)
self._test_export_to_tfjs(model)

@test_util.test_in_tf_1
def test_mobilenetv2_model_create_v1_incompatible(self):
Expand Down Expand Up @@ -110,6 +111,7 @@ def test_efficientnetlite0_model(self):
self._test_export_to_tflite_quantized(model, self.train_data)
self._test_export_to_tflite_with_metadata(
model, expected_json_file='efficientnet_lite0_metadata.json')
self._test_export_to_tfjs(model)

@test_util.test_in_tf_1and2
def test_efficientnetlite0_model_without_training(self):
Expand All @@ -130,6 +132,7 @@ def test_resnet_50_model(self):
self._test_export_to_tflite(model)
self._test_export_to_tflite_quantized(model, self.train_data)
self._test_export_to_tflite_with_metadata(model)
self._test_export_to_tfjs(model)

def _test_predict_top_k(self, model, threshold=0.0):
topk = model.predict_top_k(self.test_data, batch_size=4)
Expand Down Expand Up @@ -217,6 +220,13 @@ def _test_export_to_saved_model(self, model):
self.assertTrue(os.path.isdir(save_model_output_path))
self.assertNotEqual(len(os.listdir(save_model_output_path)), 0)

def _test_export_to_tfjs(self, model):
output_path = os.path.join(self.get_temp_dir(), 'tfjs')
model.export(self.get_temp_dir(), export_format=ExportFormat.TFJS)

self.assertTrue(os.path.isdir(output_path))
self.assertNotEqual(len(os.listdir(output_path)), 0)


if __name__ == '__main__':
# Load compressed models from tensorflow_hub
Expand Down
18 changes: 4 additions & 14 deletions tensorflow_examples/lite/model_maker/core/task/model_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,8 @@
import numpy as np
import tensorflow as tf
from tensorflow_examples.lite.model_maker.core import compat
from tensorflowjs.converters import converter as tfjs_converter

# TODO(tianlin): Conditional import only if tensorflowjs is installed, because
# tensorflowjs requires a stable `tensorflow` package rather than `tf-nightly`.
try:
from tensorflowjs.converters import converter as tfjs_converter # pylint: disable=g-import-not-at-top
HAS_TFJS = True
except ImportError as e:
HAS_TFJS = False

DEFAULT_SCALE, DEFAULT_ZERO_POINT = 0, 0

Expand Down Expand Up @@ -218,15 +212,13 @@ def export_tfjs(keras_or_saved_model, output_dir, **kwargs):
output_dir: Output TF.js model dir.
**kwargs: Other options.
"""
if not HAS_TFJS:
return

# For Keras model, creates a saved model first in a temp dir. Otherwise,
# convert directly.
is_keras = isinstance(keras_or_saved_model, tf.keras.Model)
with _create_temp_dir(is_keras) as temp_dir_name:
if is_keras:
keras_or_saved_model.save(temp_dir_name, save_format='tf')
keras_or_saved_model.save(
temp_dir_name, include_optimizer=False, save_format='tf')
path = temp_dir_name
else:
path = keras_or_saved_model
Expand All @@ -235,8 +227,6 @@ def export_tfjs(keras_or_saved_model, output_dir, **kwargs):


def load_tfjs_keras_model(model_path):
if not HAS_TFJS:
raise ImportError('tensorflowjs is required to load this model. Please run '
'`pip install tensorflowjs` to install.')
"""Loads tfjs keras model from path."""
return tfjs_converter.keras_tfjs_loader.load_keras_model(
model_path, load_weights=True)
Original file line number Diff line number Diff line change
Expand Up @@ -103,10 +103,9 @@ def test_export_tfjs(self):

output_dir = os.path.join(self.get_temp_dir(), 'tfjs')
model_util.export_tfjs(model, output_dir)
if model_util.HAS_TFJS:
self.assertTrue(os.path.exists(output_dir))
expected_model_json = os.path.join(output_dir, 'model.json')
self.assertTrue(os.path.exists(expected_model_json))
self.assertTrue(os.path.exists(output_dir))
expected_model_json = os.path.join(output_dir, 'model.json')
self.assertTrue(os.path.exists(expected_model_json))

@test_util.test_in_tf_1and2
def test_export_tfjs_saved_model(self):
Expand All @@ -116,12 +115,12 @@ def test_export_tfjs_saved_model(self):

saved_model_dir = os.path.join(self.get_temp_dir(), 'saved_model_for_js')
model.save(saved_model_dir)

output_dir = os.path.join(self.get_temp_dir(), 'tfjs')
model_util.export_tfjs(saved_model_dir, output_dir)
if model_util.HAS_TFJS:
self.assertTrue(os.path.exists(output_dir))
expected_model_json = os.path.join(output_dir, 'model.json')
self.assertTrue(os.path.exists(expected_model_json))
self.assertTrue(os.path.exists(output_dir))
expected_model_json = os.path.join(output_dir, 'model.json')
self.assertTrue(os.path.exists(expected_model_json))


if __name__ == '__main__':
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,8 @@ class TextClassifier(classification_model.ClassificationModel):
DEFAULT_EXPORT_FORMAT = (ExportFormat.TFLITE, ExportFormat.LABEL,
ExportFormat.VOCAB)
ALLOWED_EXPORT_FORMAT = (ExportFormat.TFLITE, ExportFormat.LABEL,
ExportFormat.VOCAB, ExportFormat.SAVED_MODEL)
ExportFormat.VOCAB, ExportFormat.SAVED_MODEL,
ExportFormat.TFJS)

def __init__(self,
model_spec,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,16 @@ def test_mobilebert_model(self):
self._test_export_to_tflite(model, threshold=0.0, atol=1e-2)
self._test_export_to_tflite_quant(model)

@test_util.test_in_tf_2
def test_mobilebert_model_without_training_for_tfjs(self):
model_spec = ms.mobilebert_classifier_spec(
seq_len=2, trainable=False, default_batch_size=1)
all_data = text_dataloader.TextClassifierDataLoader.from_folder(
self.text_dir, model_spec=model_spec)
self.train_data, self.test_data = all_data.split(0.5)
with self.assertRaises(Exception): # Raise an error when reloading model.
self._test_model_without_training(model_spec)

@test_util.test_in_tf_2
def test_average_wordvec_model(self):
model_spec = ms.AverageWordVecModelSpec(seq_len=2)
Expand Down Expand Up @@ -134,6 +144,7 @@ def _test_model_without_training(self, model_spec):
self.train_data, model_spec=model_spec, do_train=False)
self._test_accuracy(model, threshold=0.0)
self._test_export_to_tflite(model, threshold=0.0)
self._test_export_to_tfjs(model)

def _test_accuracy(self, model, threshold=1.0):
_, accuracy = model.evaluate(self.test_data)
Expand Down Expand Up @@ -228,6 +239,13 @@ def _test_export_to_saved_model(self, model):
self.assertTrue(os.path.isdir(save_model_output_path))
self.assertNotEmpty(os.listdir(save_model_output_path))

def _test_export_to_tfjs(self, model):
output_path = os.path.join(self.get_temp_dir(), 'tfjs')
model.export(self.get_temp_dir(), export_format=ExportFormat.TFJS)

self.assertTrue(os.path.isdir(output_path))
self.assertNotEmpty(os.listdir(output_path))

def _test_export_to_tflite_quant(self, model):
tflite_filename = 'model_quant.tflite'
tflite_output_file = os.path.join(self.get_temp_dir(), tflite_filename)
Expand Down

0 comments on commit c7f6c3e

Please sign in to comment.