| Status | Accepted |
| :--- | :--- |
| **Author(s)** | Kathy Wu (kathywu@google.com) |
| **Sponsor** | Karmel (karmel@google.com) |
| **Updated** | 2019-05-30 |
Add ability to (1) save a TensorFlow SavedModel from a Keras model and (2) load a Keras model from a SavedModel. This is a standalone serialization format, which allows models with custom objects to be saved/loaded without having the original code (currently required by the HDF5 and JSON formats).
Goals:
- Add SavedModel format for saving Keras models.
- Support comprehensive model serialization/deserialization.
- Define model serialization coverage (what parts of the model are saved to the SavedModel).
- Define a serialization pattern that is reusable across different frameworks (TensorFlow Hub, tf.Module, etc.).
- Define SavedModel deserialization (properties of the reconstructed model).
Non-goals:
- Serializing custom training loops. Only the standard `model.compile` to `model.fit` use case is covered. Custom training loops may still be saved by wrapping them in a `tf.function` and passing that to the `signatures` argument, or by setting the function as an attribute of the model (e.g. `model.train_loop = ...`).
- Changing existing serialization formats.
Keras is a high-level API for defining models while using a TensorFlow graph in the backend. The current model serialization format stores details for recreating the model object and its internal layers; the Python objects define the operations to add to the graph. This works well for saving and loading models that solely use pre-defined layers and follow the Keras Sequential or Functional API. If the model contains custom objects, their implementations must be passed to the `tf.keras.models.load_model` function.
The SavedModel format saves the TensorFlow graph, so it is capable of serializing custom objects, and deserializing without needing the original implementation. This allows Keras models to move freely between different systems and languages (Python, JS, and C++ to an extent).
Familiarity with the terminology used in RFC: SavedModel saving/loading in 2.X is recommended.
Related works:
Symbol changes
- `tf.keras.Model.save` / `tf.keras.models.save_model`: Add additional arguments `save_format` and `signatures`:

  `model.save(path, include_optimizer=None, save_format=None, signatures=None)`

  - `include_optimizer`: (for SavedModel format) Whether or not the SavedModel should contain the optimizer. Note that even if this argument is `False`, any compiled losses and metrics are saved, since they are separate from the optimizer.
  - `save_format`: Either `h5` or `tf`. Specifies the format for saving the model. If left as `None`, the save format will default to `tf` unless the path ends with `.h5`, `.hdf5`, or `.keras`.
  - `signatures`: Signatures to attach to the exported SavedModel. If `None` (default), the model's `call()` will be traced and used as the default signature. The `signatures` argument in `tf.saved_model.save` has a more comprehensive description.
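The `save_format` defaulting rule can be sketched as a small standalone helper (hypothetical code illustrating the rule, not the actual Keras implementation):

```python
def infer_save_format(path, save_format=None):
    """Sketch of the proposed save-format defaulting rule."""
    if save_format is not None:
        return save_format  # an explicit save_format always wins
    # These extensions select the HDF5-based Keras format.
    if path.endswith(('.h5', '.hdf5', '.keras')):
        return 'h5'
    # Otherwise, default to the SavedModel ('tf') format.
    return 'tf'
```

Under this rule, `model.save('/tmp/model')` would produce a SavedModel, while `model.save('/tmp/model.h5')` would produce an HDF5 file.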
Behavior changes
- `tf.keras.models.load_model`: Currently only loads h5 files. This will be modified to automatically detect SavedModels. In both cases, a Keras model object is returned. The root object in the SavedModel must at least contain the shared endpoints; otherwise, an error is raised.
- `tf.saved_model.save`, `tf.keras.models.save_model`, and `model.save`: These functions have consistent saving results. Additional dependencies/functions are serialized to the SavedModel. See Serialization coverage.

Note: `tf.saved_model.load` (`tf.saved_model.load_v2` in TF 1.13+) will not change. Instead of returning a Keras model like `tf.keras.models.load_model`, it will return a generic object with similar attributes (the Deserialization details section briefly describes the generic object).
Alternatives considered:

- `model.export`: An alternative to `model.save` that would contain export-exclusive arguments such as optimization (precision, model pruning, frozen graphs) or target framework (TF JS, TF Lite, TF Hub, etc.). Perhaps this method will be added in the future, but for saving/loading using SavedModel, modifying `model.save` is more intuitive.
- Directly return a Keras model instead of a generic object from `tf.saved_model.load`. This makes it difficult for other frameworks built on top of TensorFlow to work with SavedModels saved from Keras. Handling Keras objects comes with subtle issues when not using the Keras API (for example, Keras objects often utilize a backend graph/session, even in eager mode).
The `signatures` argument in `model.save` and `tf.saved_model.save` allows models saved to SavedModel to define signatures (see RFC: SavedModel saving/loading in 2.X). Signatures are primarily used for serving, but can also be used to save and load unattached tf.functions.
The following example saves a model with signatures to predict from raw inputs, or from tf.Examples:
```python
model = tf.keras.Sequential()
model.add(tf.keras.layers.Dense(16, input_shape=(32,)))
model.add(tf.keras.layers.Dense(5, activation=tf.nn.softmax))

@tf.function(input_signature=[tf.TensorSpec([None, 32], dtype=tf.float32)])
def model_predict(input_batch):
  return {'outputs': model(input_batch, training=False)}

@tf.function(input_signature=[tf.TensorSpec([None], dtype=tf.string)])
def classify(serialized):
  feature_spec = {'input_batch': tf.io.FixedLenFeature([32], tf.float32)}
  deserialized = tf.io.parse_example(serialized, feature_spec)
  scores = model_predict(deserialized['input_batch'])['outputs']
  return {'classes': tf.argmax(scores, 1), 'scores': scores}

model.save(
    '/tmp/keras_model',
    signatures={'classify': classify, 'predict': model_predict})
```
And later, loading the SavedModel:
```python
model = tf.keras.models.load_model('/tmp/keras_model')
model.signatures  # {'classify': tf.function, 'predict': tf.function}
```
If the `signatures` argument is left empty, then a default signature is created from the traced model call function. In the future, we may consider always exporting the default signature (even if the `signatures` argument is set).
The goal is simple: save a Keras layer or model object to SavedModel, and load it back in the same state. "Same state" is defined as equivalence of all of the following:
- variables (list of all variables, trainable variables, and non-trainable variables)
- forward pass (call function)
- losses
- metrics
- child layers (and their states)
- python attributes
  - name
  - trainable
  - config
  - dtype
  - batch_input_shape
  - input_spec
- (if compiled model) compiled arguments
  - optimizer (hyperparameters, iterations)
  - loss functions
  - metrics
All of the above attributes are serialized in the SavedModel as either metadata, checkpointable objects, or functions.
Checkpointable objects and functions are as defined in RFC: SavedModel Save/Load in 2.x (see Concrete function graph and Checkpointable in the Terminology section).
Metadata will be a newly added property of checkpointable objects in the SavedModel. Checkpointable objects may save arbitrary strings to this field. Python attributes such as layer names are saved to the metadata.
See Detailed Design for more info on how each attribute is saved.
Reconstructing a Keras model is a matter of remapping the saved attributes (listed above) to the original attributes. A few more details are in the Detailed design.
Keras models contain many Python-defined components. For example, the `trainable_variables` property lists the model's trainable variables by recursively retrieving the trainable variables from each of the child layers. Another example is `model.call`, a Python function that calls child layers and adds ops to the backend graph.
Only TensorFlow checkpointable objects and functions can be serialized to SavedModel. Serializing a Keras model as-is results in a checkpointable object that does not resemble a Keras model at all. Thus, extra checkpointable objects and functions must be created during serialization.
The following checkpointable objects and functions are attached to the saved Keras model:
- `variables`: List of all variables in this layer and sublayers.
- `trainable_variables`: List of all trainable variables in this layer and sublayers.
- `non_trainable_variables`: List of all non-trainable variables in this layer and sublayers.
- `regularization_losses`: List of unconditional loss functions in this layer and sublayers. Each function takes no arguments and returns a scalar tensor.
- `layers`: Flat list of all sublayers (does not include metrics, even though Metric subclasses Layer).
- `metrics`: List of all metric layers attached to this layer and sublayers.
- `_variables`: List of all variables owned by this object (and not sublayers).
- `__call__`: Returns the outputs of the call function.
- `call_and_return_conditional_losses`: Returns the outputs of the call function, as well as a list of input-dependent losses (does not include the activity regularizer loss).
- `call_and_return_all_conditional_losses`: A function that calls the model and returns the outputs along with all input-dependent losses. Unlike `call_and_return_conditional_losses`, the losses returned by this function include the activity regularizer and any compiled losses.
- `activity_regularizer_fn`: Activity regularization function.
- `compile_losses`: List of loss functions added during `model.compile`.
- `compile_metrics`: List of metric objects added during `model.compile`.
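The relationship between the loss-returning endpoints can be illustrated with toy stand-in functions (a hypothetical sketch, not the real Keras implementation):

```python
def call_fn(x):
    # Stand-in for the model's forward pass.
    return 2 * x

def activity_regularizer_fn(outputs):
    # Stand-in activity regularizer: penalizes large outputs.
    return 0.01 * outputs

def call_and_return_conditional_losses(x):
    # Outputs plus input-dependent losses, excluding the activity regularizer.
    return call_fn(x), [0.1 * x]

def call_and_return_all_conditional_losses(x):
    # Same outputs, but the loss list additionally includes the activity
    # regularizer (and, in a compiled model, any compiled losses).
    outputs, losses = call_and_return_conditional_losses(x)
    return outputs, losses + [activity_regularizer_fn(outputs)]
```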
The optimizer is a checkpointable object, so it is automatically saved to the SavedModel.
Public vs private variables (variables vs _variables)
The public attributes are exported so that all variables/trainable variables/etc. may be accessed without the Keras Python logic that recursively traverses all the sublayers. The private variable attribute is exported so that when the model is deserialized:
- it is clear which objects own variables
- variables are guaranteed to be in the same order. This is important for Keras models, which use layer and weight order for certain operations (e.g. saving/loading to HDF5).
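The owned-versus-recursive distinction can be sketched with a minimal toy layer (hypothetical class, not the Keras implementation):

```python
class ToyLayer:
    """Minimal sketch of owned (`_variables`) vs. collected (`variables`)."""

    def __init__(self, name, owned_variables=(), sublayers=()):
        self.name = name
        self._variables = list(owned_variables)  # owned by this object only
        self.layers = list(sublayers)

    @property
    def variables(self):
        # Owned variables first, then sublayer variables, preserving the
        # layer/weight ordering that HDF5 saving/loading relies on.
        collected = list(self._variables)
        for layer in self.layers:
            collected.extend(layer.variables)
        return collected
```

Here an outer layer owning `'w0'` with a sublayer owning `'w1'` reports `_variables == ['w0']` but `variables == ['w0', 'w1']`.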
Call functions
Two versions of the call function are exported. Exporting `__call__` enables the `model(inputs)` syntax to be retained in the generic object created by `tf.saved_model.load`. `call_and_return_conditional_losses` is exported for model deserialization.
The generic loader, `tf.saved_model.load`, creates a generic object with attributes as saved in the SavedModel. Loading the saved Keras model (with the added checkpointable objects and functions listed above) will produce:

`GenericObject obj → .variables, .trainable_variables, etc.`

`GenericObject` is similar to a Keras model object, but lacks the `.fit`, `.evaluate`, and `.predict` methods.
When reconstructing a Keras model, the saved attributes are remapped to the original names. The exception is the call function, which uses `call_and_return_conditional_losses` instead of `__call__`.
Having shared endpoints allows models to be used interchangeably between different frameworks. The following properties were proposed by arnoegw@google.com and andresp@google.com, from the TF Hub team, as common endpoints shared by all modules:
- `__call__`: A function that takes inputs to the model and returns outputs.
- `variables`: List of all variables in the model.
- `trainable_variables`: List of all trainable variables in the model.
- `regularization_losses`: List of callables that return a scalar tensor.
- `call_and_return_all_conditional_losses`: A function that calls the model and returns the outputs along with all input-dependent losses. Unlike `call_and_return_conditional_losses`, the losses returned by this function include the activity regularizer and any compiled losses.
All losses contained in the model are split between `regularization_losses` and `call_and_return_all_conditional_losses`.

Any SavedModel with these endpoints defined may be loaded as a Keras model using `tf.keras.models.load_model`.
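A loader's endpoint check could then be sketched as follows (hypothetical helper; the endpoint names are those listed above):

```python
# Endpoints the root object must expose to be loadable as a Keras model.
SHARED_ENDPOINTS = (
    '__call__',
    'variables',
    'trainable_variables',
    'regularization_losses',
    'call_and_return_all_conditional_losses',
)

def has_shared_endpoints(root):
    """Return True if the loaded root object exposes every shared endpoint."""
    return all(hasattr(root, name) for name in SHARED_ENDPOINTS)
```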
Keras SavedModels are backward and forward compatible across minor TensorFlow versions (similar to GraphDef). Therefore, checkpointable object and tf.function attributes will not be removed from the SavedModel. New attributes can be added if additional serialization is requested.
- Do we all agree on the API changes?
  - Regarding the `include_optimizer` argument: If this is set to False, then the optimizer will not be included, but if the model has been compiled, the eval graphs (compiled losses and metrics) are still saved.
  - Should we always save the default model signature? No (at least not initially). This can be added later on, but the change should also be added to `tf.saved_model.save`.
- Are there other aspects of the model that should be serialized/deserialized?
  - Add an additional method that includes all input-dependent losses (`call_and_return_all_conditional_losses`). This combines the losses generated from `call_and_return_conditional_losses`, `compile_losses`, and `activity_regularizer_fn`.
- Should `tf.saved_model.load` return a generic object or a Keras model? Generic object.
- Syncing common endpoints with tf.Module: tf.Module is extremely open-ended, so just the `.variables` attribute should be synced between tf.Module SavedModels and Keras SavedModels. Note about `.trainable_variables`: Keras layers and tf.Modules have different definitions of trainable variables.