anderspkd/tf_train_quantized

Quantization-aware training and conversion to tflite files

This folder contains a couple of scripts that can be used to train models and convert the results into tflite files.

Currently, only MNIST models are supported, but it should be easy to use train_mnist.py as a starting point for other problem areas (e.g., CIFAR10).

Quickstart

Run

$ virtualenv --python $(which python3) venv
$ source venv/bin/activate
$ pip install tensorflow

And then, for example,

$ ./train_and_quantize mnist_simple

Consult models.py to see which models can be trained.

Step-by-step

Training

The script train_mnist.py performs quantization aware training given the name of one of the models defined in models.py (see the Models section further down).
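Conceptually, quantization aware training inserts "fake quantization" ops into the graph so that the forward pass already sees the rounding error that 8-bit inference will introduce. The sketch below is illustrative only (all names are hypothetical; TensorFlow inserts the real ops itself) and shows the affine quantize/dequantize round trip those ops perform:

```python
# Illustrative sketch of affine (asymmetric) fake quantization, the
# round trip that quantization-aware training simulates in the forward
# pass. Names are hypothetical; TensorFlow inserts equivalent ops itself.

def quantize_params(min_val, max_val, num_bits=8):
    """Derive the scale and zero point for a [min_val, max_val] range."""
    qmin, qmax = 0, 2 ** num_bits - 1
    scale = (max_val - min_val) / (qmax - qmin)
    zero_point = int(round(qmin - min_val / scale))
    return scale, zero_point

def fake_quant(x, scale, zero_point, num_bits=8):
    """Round x onto the integer grid, clamp, and map back to float."""
    qmin, qmax = 0, 2 ** num_bits - 1
    q = int(round(x / scale)) + zero_point
    q = max(qmin, min(qmax, q))        # clamp to the representable range
    return (q - zero_point) * scale    # dequantize back to float

scale, zp = quantize_params(-1.0, 1.0)
print(fake_quant(0.3, scale, zp))      # close to 0.3, but not exact
```

The rounding error is bounded by the scale, and training with these ops present is what lets the model adapt to that error before conversion.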

Example usage:

$ ./train_mnist.py
usage: train_mnist.py [-h] [-m name] [-l] [--epochs epochs]
                      [--checkpoint-dir dir] [--freeze model name]

trains a model with quantization aware training

optional arguments:
  -h, --help            show this help message and exit
  -m name, --model-name name
                        name of model. Must be defined in models.py
  -l, --list-models     lists available models
  --epochs epochs       number of epochs for training. Default is 1
  --checkpoint-dir dir  directory to save checkpoint information. Default is
                        "./chkpt/checkpoints"
  --freeze model name   freezes the model
$ ./train_mnist.py -m simple
[snip 🦀]

60000/60000 [==============================] - 2s 29us/sample - loss: 0.4387 - acc: 0.8832
10000/10000 [==============================] - 1s 54us/sample - loss: 0.2266 - acc: 0.9360

Evaluation results:
test loss 0.22657053579986094
test accuracy 0.936

This generates some checkpoints in the folder ./chkpt/ that will be needed later.

Converting checkpoints to a frozen graph def

The script checkpoint2pb.py takes the checkpoints from the previous step and creates a frozen graph def file:

$ ./checkpoint2pb.py
usage: checkpoint2pb.py [-h] model_name checkpoints
checkpoint2pb.py: error: the following arguments are required: model_name, checkpoints
$ ./checkpoint2pb.py mnist_simple chkpt/checkpoints
[snip 🦀]
-----------------------
writing mnist_simple.pb
input_arrays: flatten_input
output_arrays: dense_1/Softmax

Notice that the model name needs to be prefixed with the name of the group it belongs to (here, mnist).
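The convention can be sketched as follows (a hypothetical helper, not part of the repo's scripts): the full model name is the group name and the model name joined with an underscore.

```python
# Hypothetical helper (not in the repo) illustrating the naming
# convention: the full model name is "<group>_<model>", e.g. the
# "simple" model in the "mnist" group is addressed as "mnist_simple".

def split_model_name(full_name):
    group, _, model = full_name.partition("_")
    return group, model

print(split_model_name("mnist_simple"))   # ('mnist', 'simple')
```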

Converting a frozen graph to a tflite file

Finally, the .pb file can be converted into a fully optimized and quantized model file.

$ ./pb2tflite.sh
usage: ./pb2tflite.sh [frozen_model.pb] [input_arrays] [output_arrays]
$ ./pb2tflite.sh mnist_simple.pb flatten_input dense_1/Softmax
converting "mnist_simple.pb" to "mnist_simple.tflite"
[snip 🦀]
done!
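As a quick sanity check on the result (a hedged sketch, not one of the repo's scripts): .tflite files are FlatBuffers, and a FlatBuffer's 4-byte file identifier, "TFL3" for TFLite models, is stored at byte offset 4.

```python
# Hedged sanity check (not one of the repo's scripts): a .tflite file is
# a FlatBuffer whose 4-byte file identifier "TFL3" sits at byte offset 4.

def looks_like_tflite(data):
    """Return True if the byte string plausibly starts a .tflite file."""
    return len(data) >= 8 and data[4:8] == b"TFL3"

# Synthetic header with the identifier in place, versus arbitrary bytes:
print(looks_like_tflite(b"\x18\x00\x00\x00TFL3"))   # True
print(looks_like_tflite(b"not a tflite file"))      # False
```

In practice you would call this on the first few bytes of the freshly written mnist_simple.tflite.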

Models

All models should be defined in models.py. See existing models for how to go about defining a new model.

MNIST

Two models are currently defined for the MNIST dataset:

Simple

Two fully connected layers with relu activation.

Simple2

One convolution and two fully connected layers with relu6 activation.