{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "Tce3stUlHN0L" }, "source": [ "##### Copyright 2020 The TensorFlow Authors." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "cellView": "form", "execution": { "iopub.execute_input": "2023-10-17T11:52:43.029472Z", "iopub.status.busy": "2023-10-17T11:52:43.028962Z", "iopub.status.idle": "2023-10-17T11:52:43.032766Z", "shell.execute_reply": "2023-10-17T11:52:43.032083Z" }, "id": "tuOe1ymfHZPu" }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "qFdPvlXBOdUN" }, "source": [ "# Image classification with Model Garden
" ] }, { "cell_type": "markdown", "metadata": { "id": "Ta_nFXaVAqLD" }, "source": [ "This tutorial fine-tunes a Residual Network (ResNet) from the TensorFlow [Model Garden](https://github.com/tensorflow/models) package (`tensorflow-models`) to classify images in the [CIFAR-10](https://www.cs.toronto.edu/~kriz/cifar.html) dataset.\n", "\n", "Model Garden contains a collection of state-of-the-art vision models, implemented with TensorFlow's high-level APIs. The implementations demonstrate best practices for modeling, letting users take full advantage of TensorFlow for their research and product development.\n", "\n", "This tutorial uses [ResNet](https://arxiv.org/pdf/1512.03385.pdf), a state-of-the-art image classifier. Specifically, it uses ResNet-18, a convolutional neural network with 18 layers.\n", "\n", "This tutorial demonstrates how to:\n", "1. Use models from the TensorFlow Models package.\n", "2. Fine-tune a pre-built ResNet for image classification.\n", "3. Export the tuned ResNet model." ] }, { "cell_type": "markdown", "metadata": { "id": "G2FlaQcEPOER" }, "source": [ "## Setup\n", "\n", "Install and import the necessary modules." ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "execution": { "iopub.execute_input": "2023-10-17T11:52:43.036406Z", "iopub.status.busy": "2023-10-17T11:52:43.036004Z", "iopub.status.idle": "2023-10-17T11:52:53.249823Z", "shell.execute_reply": "2023-10-17T11:52:53.248930Z" }, "id": "XvWfdCrvrV5W" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[33mWARNING: There was an error checking the latest version of pip.\u001b[0m\u001b[33m\r\n", "\u001b[0m" ] } ], "source": [ "!pip install -U -q \"tf-models-official\"" ] }, { "cell_type": "markdown", "metadata": { "id": "CKYMTPjOE400" }, "source": [ "Import TensorFlow, TensorFlow Datasets, and a few helper libraries."
] }, { "cell_type": "code", "execution_count": 3, "metadata": { "execution": { "iopub.execute_input": "2023-10-17T11:52:53.254966Z", "iopub.status.busy": "2023-10-17T11:52:53.254312Z", "iopub.status.idle": "2023-10-17T11:52:56.468885Z", "shell.execute_reply": "2023-10-17T11:52:56.468153Z" }, "id": "Wlon1uoIowmZ" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2023-10-17 11:52:54.005237: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n", "2023-10-17 11:52:54.005294: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n", "2023-10-17 11:52:54.005338: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n" ] } ], "source": [ "import pprint\n", "import tempfile\n", "\n", "from IPython import display\n", "import matplotlib.pyplot as plt\n", "\n", "import tensorflow as tf\n", "import tensorflow_datasets as tfds" ] }, { "cell_type": "markdown", "metadata": { "id": "AVTs0jDd1b24" }, "source": [ "The `tensorflow_models` package contains the ResNet vision model, and the `official.vision.serving` module contains the function to save and export the tuned model." ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "execution": { "iopub.execute_input": "2023-10-17T11:52:56.473112Z", "iopub.status.busy": "2023-10-17T11:52:56.472740Z", "iopub.status.idle": "2023-10-17T11:52:57.684827Z", "shell.execute_reply": "2023-10-17T11:52:57.683792Z" }, "id": "NHT1iiIiBzlC" }, "outputs": [], "source": [ "import tensorflow_models as tfm\n", "\n", "# These are not in the tfm public API for v2.9.
They will be available in v2.10\n", "from official.vision.serving import export_saved_model_lib\n", "import official.core.train_lib" ] }, { "cell_type": "markdown", "metadata": { "id": "aKv3wdqkQ8FU" }, "source": [ "## Configure the ResNet-18 model for the CIFAR-10 dataset" ] }, { "cell_type": "markdown", "metadata": { "id": "5iN8mHEJjKYE" }, "source": [ "The CIFAR-10 dataset contains 60,000 color images in 10 mutually exclusive classes, with 6,000 images in each class.\n", "\n", "In Model Garden, the collections of parameters that define a model are called *configs*. Model Garden can create a config based on a known set of parameters via a [factory](https://en.wikipedia.org/wiki/Factory_method_pattern).\n", "\n", "Use the `resnet_imagenet` factory configuration, as defined by `tfm.vision.configs.image_classification.image_classification_imagenet`. The configuration is set up to train ResNet to converge on [ImageNet](https://www.image-net.org/)." ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "execution": { "iopub.execute_input": "2023-10-17T11:52:57.690032Z", "iopub.status.busy": "2023-10-17T11:52:57.689085Z", "iopub.status.idle": "2023-10-17T11:52:59.390071Z", "shell.execute_reply": "2023-10-17T11:52:59.389431Z" }, "id": "1M77f88Dj2Td" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2023-10-17 11:52:59.285390: W tensorflow/core/common_runtime/gpu/gpu_device.cc:2211] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.\n", "Skipping registering GPU devices...\n" ] }, { "data": { "text/plain": [ "tfds.core.DatasetInfo(\n", " name='cifar10',\n", " full_name='cifar10/3.0.2',\n", " description=\"\"\"\n", " The CIFAR-10 dataset consists of 60000 32x32 colour images in 10 classes, with 6000 images per class.
There are 50000 training images and 10000 test images.\n", " \"\"\",\n", " homepage='https://www.cs.toronto.edu/~kriz/cifar.html',\n", " data_dir='gs://tensorflow-datasets/datasets/cifar10/3.0.2',\n", " file_format=tfrecord,\n", " download_size=162.17 MiB,\n", " dataset_size=132.40 MiB,\n", " features=FeaturesDict({\n", " 'id': Text(shape=(), dtype=string),\n", " 'image': Image(shape=(32, 32, 3), dtype=uint8),\n", " 'label': ClassLabel(shape=(), dtype=int64, num_classes=10),\n", " }),\n", " supervised_keys=('image', 'label'),\n", " disable_shuffling=False,\n", " splits={\n", " 'test': ,\n", " 'train': ,\n", " },\n", " citation=\"\"\"@TECHREPORT{Krizhevsky09learningmultiple,\n", " author = {Alex Krizhevsky},\n", " title = {Learning multiple layers of features from tiny images},\n", " institution = {},\n", " year = {2009}\n", " }\"\"\",\n", ")" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "exp_config = tfm.core.exp_factory.get_exp_config('resnet_imagenet')\n", "tfds_name = 'cifar10'\n", "ds, ds_info = tfds.load(\n", "    tfds_name,\n", "    with_info=True)\n", "ds_info" ] }, { "cell_type": "markdown", "metadata": { "id": "U6PVwXA-j3E7" }, "source": [ "Adjust the model and dataset configurations so that they work with CIFAR-10 (`cifar10`)."
] }, { "cell_type": "code", "execution_count": 6, "metadata": { "execution": { "iopub.execute_input": "2023-10-17T11:52:59.393899Z", "iopub.status.busy": "2023-10-17T11:52:59.393642Z", "iopub.status.idle": "2023-10-17T11:52:59.399820Z", "shell.execute_reply": "2023-10-17T11:52:59.399255Z" }, "id": "YWI7faVStQaV" }, "outputs": [], "source": [ "# Configure model\n", "exp_config.task.model.num_classes = 10\n", "exp_config.task.model.input_size = list(ds_info.features[\"image\"].shape)\n", "exp_config.task.model.backbone.resnet.model_id = 18\n", "\n", "# Configure training and testing data\n", "batch_size = 128\n", "\n", "exp_config.task.train_data.input_path = ''\n", "exp_config.task.train_data.tfds_name = tfds_name\n", "exp_config.task.train_data.tfds_split = 'train'\n", "exp_config.task.train_data.global_batch_size = batch_size\n", "\n", "exp_config.task.validation_data.input_path = ''\n", "exp_config.task.validation_data.tfds_name = tfds_name\n", "exp_config.task.validation_data.tfds_split = 'test'\n", "exp_config.task.validation_data.global_batch_size = batch_size\n" ] }, { "cell_type": "markdown", "metadata": { "id": "DE3ggKzzTD56" }, "source": [ "Adjust the trainer configuration." 
] }, { "cell_type": "code", "execution_count": 7, "metadata": { "execution": { "iopub.execute_input": "2023-10-17T11:52:59.403420Z", "iopub.status.busy": "2023-10-17T11:52:59.402985Z", "iopub.status.idle": "2023-10-17T11:52:59.409947Z", "shell.execute_reply": "2023-10-17T11:52:59.409272Z" }, "id": "inE_-4UGkLud" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Running on CPU is slow, so only train for a few steps.\n" ] } ], "source": [ "logical_device_names = [logical_device.name for logical_device in tf.config.list_logical_devices()]\n", "\n", "if 'GPU' in ''.join(logical_device_names):\n", " print('This may be broken in Colab.')\n", " device = 'GPU'\n", "elif 'TPU' in ''.join(logical_device_names):\n", " print('This may be broken in Colab.')\n", " device = 'TPU'\n", "else:\n", " print('Running on CPU is slow, so only train for a few steps.')\n", " device = 'CPU'\n", "\n", "if device=='CPU':\n", " train_steps = 20\n", " exp_config.trainer.steps_per_loop = 5\n", "else:\n", " train_steps=5000\n", " exp_config.trainer.steps_per_loop = 100\n", "\n", "exp_config.trainer.summary_interval = 100\n", "exp_config.trainer.checkpoint_interval = train_steps\n", "exp_config.trainer.validation_interval = 1000\n", "exp_config.trainer.validation_steps = ds_info.splits['test'].num_examples // batch_size\n", "exp_config.trainer.train_steps = train_steps\n", "exp_config.trainer.optimizer_config.learning_rate.type = 'cosine'\n", "exp_config.trainer.optimizer_config.learning_rate.cosine.decay_steps = train_steps\n", "exp_config.trainer.optimizer_config.learning_rate.cosine.initial_learning_rate = 0.1\n", "exp_config.trainer.optimizer_config.warmup.linear.warmup_steps = 100" ] }, { "cell_type": "markdown", "metadata": { "id": "5mTcDnBiTOYD" }, "source": [ "Print the modified configuration." 
] }, { "cell_type": "code", "execution_count": 8, "metadata": { "execution": { "iopub.execute_input": "2023-10-17T11:52:59.413309Z", "iopub.status.busy": "2023-10-17T11:52:59.412932Z", "iopub.status.idle": "2023-10-17T11:52:59.422408Z", "shell.execute_reply": "2023-10-17T11:52:59.421780Z" }, "id": "tuVfxSBCTK-y" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'runtime': {'all_reduce_alg': None,\n", " 'batchnorm_spatial_persistent': False,\n", " 'dataset_num_private_threads': None,\n", " 'default_shard_dim': -1,\n", " 'distribution_strategy': 'mirrored',\n", " 'enable_xla': True,\n", " 'gpu_thread_mode': None,\n", " 'loss_scale': None,\n", " 'mixed_precision_dtype': None,\n", " 'num_cores_per_replica': 1,\n", " 'num_gpus': 0,\n", " 'num_packs': 1,\n", " 'per_gpu_thread_count': 0,\n", " 'run_eagerly': False,\n", " 'task_index': -1,\n", " 'tpu': None,\n", " 'tpu_enable_xla_dynamic_padder': None,\n", " 'use_tpu_mp_strategy': False,\n", " 'worker_hosts': None},\n", " 'task': {'allow_image_summary': False,\n", " 'differential_privacy_config': None,\n", " 'eval_input_partition_dims': [],\n", " 'evaluation': {'precision_and_recall_thresholds': None,\n", " 'report_per_class_precision_and_recall': False,\n", " 'top_k': 5},\n", " 'freeze_backbone': False,\n", " 'init_checkpoint': None,\n", " 'init_checkpoint_modules': 'all',\n", " 'losses': {'l2_weight_decay': 0.0001,\n", " 'label_smoothing': 0.0,\n", " 'loss_weight': 1.0,\n", " 'one_hot': True,\n", " 'soft_labels': False,\n", " 'use_binary_cross_entropy': False},\n", " 'model': {'add_head_batch_norm': False,\n", " 'backbone': {'resnet': {'bn_trainable': True,\n", " 'depth_multiplier': 1.0,\n", " 'model_id': 18,\n", " 'replace_stem_max_pool': False,\n", " 'resnetd_shortcut': False,\n", " 'scale_stem': True,\n", " 'se_ratio': 0.0,\n", " 'stem_type': 'v0',\n", " 'stochastic_depth_drop_rate': 0.0},\n", " 'type': 'resnet'},\n", " 'dropout_rate': 0.0,\n", " 'input_size': [32, 32, 3],\n", " 
'kernel_initializer': 'random_uniform',\n", " 'norm_activation': {'activation': 'relu',\n", " 'norm_epsilon': 1e-05,\n", " 'norm_momentum': 0.9,\n", " 'use_sync_bn': False},\n", " 'num_classes': 10,\n", " 'output_softmax': False},\n", " 'model_output_keys': [],\n", " 'name': None,\n", " 'train_data': {'apply_tf_data_service_before_batching': False,\n", " 'aug_crop': True,\n", " 'aug_policy': None,\n", " 'aug_rand_hflip': True,\n", " 'aug_type': None,\n", " 'autotune_algorithm': None,\n", " 'block_length': 1,\n", " 'cache': False,\n", " 'center_crop_fraction': 0.875,\n", " 'color_jitter': 0.0,\n", " 'crop_area_range': (0.08, 1.0),\n", " 'cycle_length': 10,\n", " 'decode_jpeg_only': True,\n", " 'decoder': {'simple_decoder': {'attribute_names': [],\n", " 'mask_binarize_threshold': None,\n", " 'regenerate_source_id': False},\n", " 'type': 'simple_decoder'},\n", " 'deterministic': None,\n", " 'drop_remainder': True,\n", " 'dtype': 'float32',\n", " 'enable_shared_tf_data_service_between_parallel_trainers': False,\n", " 'enable_tf_data_service': False,\n", " 'file_type': 'tfrecord',\n", " 'global_batch_size': 128,\n", " 'image_field_key': 'image/encoded',\n", " 'input_path': '',\n", " 'is_multilabel': False,\n", " 'is_training': True,\n", " 'label_field_key': 'image/class/label',\n", " 'mixup_and_cutmix': None,\n", " 'prefetch_buffer_size': None,\n", " 'randaug_magnitude': 10,\n", " 'random_erasing': None,\n", " 'repeated_augment': None,\n", " 'seed': None,\n", " 'sharding': True,\n", " 'shuffle_buffer_size': 10000,\n", " 'tf_data_service_address': None,\n", " 'tf_data_service_job_name': None,\n", " 'tf_resize_method': 'bilinear',\n", " 'tfds_as_supervised': False,\n", " 'tfds_data_dir': '',\n", " 'tfds_name': 'cifar10',\n", " 'tfds_skip_decoding_feature': '',\n", " 'tfds_split': 'train',\n", " 'three_augment': False,\n", " 'trainer_id': None,\n", " 'weights': None},\n", " 'train_input_partition_dims': [],\n", " 'validation_data': {'apply_tf_data_service_before_batching': 
False,\n", " 'aug_crop': True,\n", " 'aug_policy': None,\n", " 'aug_rand_hflip': True,\n", " 'aug_type': None,\n", " 'autotune_algorithm': None,\n", " 'block_length': 1,\n", " 'cache': False,\n", " 'center_crop_fraction': 0.875,\n", " 'color_jitter': 0.0,\n", " 'crop_area_range': (0.08, 1.0),\n", " 'cycle_length': 10,\n", " 'decode_jpeg_only': True,\n", " 'decoder': {'simple_decoder': {'attribute_names': [],\n", " 'mask_binarize_threshold': None,\n", " 'regenerate_source_id': False},\n", " 'type': 'simple_decoder'},\n", " 'deterministic': None,\n", " 'drop_remainder': True,\n", " 'dtype': 'float32',\n", " 'enable_shared_tf_data_service_between_parallel_trainers': False,\n", " 'enable_tf_data_service': False,\n", " 'file_type': 'tfrecord',\n", " 'global_batch_size': 128,\n", " 'image_field_key': 'image/encoded',\n", " 'input_path': '',\n", " 'is_multilabel': False,\n", " 'is_training': False,\n", " 'label_field_key': 'image/class/label',\n", " 'mixup_and_cutmix': None,\n", " 'prefetch_buffer_size': None,\n", " 'randaug_magnitude': 10,\n", " 'random_erasing': None,\n", " 'repeated_augment': None,\n", " 'seed': None,\n", " 'sharding': True,\n", " 'shuffle_buffer_size': 10000,\n", " 'tf_data_service_address': None,\n", " 'tf_data_service_job_name': None,\n", " 'tf_resize_method': 'bilinear',\n", " 'tfds_as_supervised': False,\n", " 'tfds_data_dir': '',\n", " 'tfds_name': 'cifar10',\n", " 'tfds_skip_decoding_feature': '',\n", " 'tfds_split': 'test',\n", " 'three_augment': False,\n", " 'trainer_id': None,\n", " 'weights': None}},\n", " 'trainer': {'allow_tpu_summary': False,\n", " 'best_checkpoint_eval_metric': '',\n", " 'best_checkpoint_export_subdir': '',\n", " 'best_checkpoint_metric_comp': 'higher',\n", " 'checkpoint_interval': 20,\n", " 'continuous_eval_timeout': 3600,\n", " 'eval_tf_function': True,\n", " 'eval_tf_while_loop': False,\n", " 'loss_upper_bound': 1000000.0,\n", " 'max_to_keep': 5,\n", " 'optimizer_config': {'ema': None,\n", " 'learning_rate': 
{'cosine': {'alpha': 0.0,\n", " 'decay_steps': 20,\n", " 'initial_learning_rate': 0.1,\n", " 'name': 'CosineDecay',\n", " 'offset': 0},\n", " 'type': 'cosine'},\n", " 'optimizer': {'sgd': {'clipnorm': None,\n", " 'clipvalue': None,\n", " 'decay': 0.0,\n", " 'global_clipnorm': None,\n", " 'momentum': 0.9,\n", " 'name': 'SGD',\n", " 'nesterov': False},\n", " 'type': 'sgd'},\n", " 'warmup': {'linear': {'name': 'linear',\n", " 'warmup_learning_rate': 0,\n", " 'warmup_steps': 100},\n", " 'type': 'linear'}},\n", " 'preemption_on_demand_checkpoint': True,\n", " 'recovery_begin_steps': 0,\n", " 'recovery_max_trials': 0,\n", " 'steps_per_loop': 5,\n", " 'summary_interval': 100,\n", " 'train_steps': 20,\n", " 'train_tf_function': True,\n", " 'train_tf_while_loop': True,\n", " 'validation_interval': 1000,\n", " 'validation_steps': 78,\n", " 'validation_summary_subdir': 'validation'}}\n" ] }, { "data": { "application/javascript": [ "google.colab.output.setIframeHeight('300px');" ], "text/plain": [ "" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pprint.pprint(exp_config.as_dict())\n", "\n", "display.Javascript(\"google.colab.output.setIframeHeight('300px');\")" ] }, { "cell_type": "markdown", "metadata": { "id": "w7_X0UHaRF2m" }, "source": [ "Set up the distribution strategy." 
] }, { "cell_type": "code", "execution_count": 9, "metadata": { "execution": { "iopub.execute_input": "2023-10-17T11:52:59.425876Z", "iopub.status.busy": "2023-10-17T11:52:59.425427Z", "iopub.status.idle": "2023-10-17T11:52:59.430927Z", "shell.execute_reply": "2023-10-17T11:52:59.430334Z" }, "id": "ykL14FIbTaSt" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Warning: this will be really slow.\n" ] } ], "source": [ "logical_device_names = [logical_device.name for logical_device in tf.config.list_logical_devices()]\n", "\n", "if exp_config.runtime.mixed_precision_dtype == tf.float16:\n", " tf.keras.mixed_precision.set_global_policy('mixed_float16')\n", "\n", "if 'GPU' in ''.join(logical_device_names):\n", " distribution_strategy = tf.distribute.MirroredStrategy()\n", "elif 'TPU' in ''.join(logical_device_names):\n", " tf.tpu.experimental.initialize_tpu_system()\n", " tpu = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='/device:TPU_SYSTEM:0')\n", " distribution_strategy = tf.distribute.TPUStrategy(tpu)\n", "else:\n", " print('Warning: this will be really slow.')\n", " distribution_strategy = tf.distribute.OneDeviceStrategy(logical_device_names[0])" ] }, { "cell_type": "markdown", "metadata": { "id": "W4k5YH5pTjaK" }, "source": [ "Create the `Task` object (`tfm.core.base_task.Task`) from the `config_definitions.TaskConfig`.\n", "\n", "The `Task` object has all the methods necessary for building the dataset, building the model, and running training and evaluation. These methods are driven by `tfm.core.train_lib.run_experiment`."
] }, { "cell_type": "code", "execution_count": 10, "metadata": { "execution": { "iopub.execute_input": "2023-10-17T11:52:59.434412Z", "iopub.status.busy": "2023-10-17T11:52:59.433877Z", "iopub.status.idle": "2023-10-17T11:52:59.437908Z", "shell.execute_reply": "2023-10-17T11:52:59.437318Z" }, "id": "6MgYSH0PtUaW" }, "outputs": [], "source": [ "with distribution_strategy.scope():\n", " model_dir = tempfile.mkdtemp()\n", " task = tfm.core.task_factory.get_task(exp_config.task, logging_dir=model_dir)\n", "\n", "# tf.keras.utils.plot_model(task.build_model(), show_shapes=True)" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "execution": { "iopub.execute_input": "2023-10-17T11:52:59.441333Z", "iopub.status.busy": "2023-10-17T11:52:59.440885Z", "iopub.status.idle": "2023-10-17T11:53:02.392113Z", "shell.execute_reply": "2023-10-17T11:53:02.391341Z" }, "id": "IFXEZYdzBKoX" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "images.shape: (128, 32, 32, 3) images.dtype: tf.float32\n", "labels.shape: (128,) labels.dtype: tf.int32\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "2023-10-17 11:53:02.248801: W tensorflow/core/kernels/data/cache_dataset_ops.cc:854] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. 
You should use `dataset.take(k).cache().repeat()` instead.\n" ] } ], "source": [ "for images, labels in task.build_inputs(exp_config.task.train_data).take(1):\n", " print()\n", " print(f'images.shape: {str(images.shape):16} images.dtype: {images.dtype!r}')\n", " print(f'labels.shape: {str(labels.shape):16} labels.dtype: {labels.dtype!r}')" ] }, { "cell_type": "markdown", "metadata": { "id": "yrwxnGDaRU0U" }, "source": [ "## Visualize the training data" ] }, { "cell_type": "markdown", "metadata": { "id": "683c255c6c52" }, "source": [ "The data loader applies z-score normalization using \n", "`preprocess_ops.normalize_image(image, offset=MEAN_RGB, scale=STDDEV_RGB)`, so standard tools can't display the images returned by the dataset directly. The visualization code needs to rescale the data into the [0, 1] range." ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "execution": { "iopub.execute_input": "2023-10-17T11:53:02.396806Z", "iopub.status.busy": "2023-10-17T11:53:02.396199Z", "iopub.status.idle": "2023-10-17T11:53:02.426521Z", "shell.execute_reply": "2023-10-17T11:53:02.425896Z" }, "id": "PdmOz2EC0Nx2" }, "outputs": [], "source": [ "plt.hist(images.numpy().flatten());" ] }, { "cell_type": "markdown", "metadata": { "id": "7a8582ebde7b" }, "source": [ "Use `ds_info` (which is an instance of `tfds.core.DatasetInfo`) to look up the text description of each class ID."
] }, { "cell_type": "code", "execution_count": 13, "metadata": { "execution": { "iopub.execute_input": "2023-10-17T11:53:02.429830Z", "iopub.status.busy": "2023-10-17T11:53:02.429607Z", "iopub.status.idle": "2023-10-17T11:53:02.433999Z", "shell.execute_reply": "2023-10-17T11:53:02.433370Z" }, "id": "Wq4Wq_CuDG3Q" }, "outputs": [ { "data": { "text/plain": [ "'automobile'" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "label_info = ds_info.features['label']\n", "label_info.int2str(1)" ] }, { "cell_type": "markdown", "metadata": { "id": "8c652a6fdbcf" }, "source": [ "Visualize a batch of the data." ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "execution": { "iopub.execute_input": "2023-10-17T11:53:02.436913Z", "iopub.status.busy": "2023-10-17T11:53:02.436700Z", "iopub.status.idle": "2023-10-17T11:53:02.442302Z", "shell.execute_reply": "2023-10-17T11:53:02.441685Z" }, "id": "ZKfTxytf1l0d" }, "outputs": [], "source": [ "def show_batch(images, labels, predictions=None):\n", "  plt.figure(figsize=(10, 10))\n", "  # Rescale the z-score-normalized images into the [0, 1] range for display.\n", "  img_min = images.numpy().min()\n", "  img_max = images.numpy().max()\n", "  delta = img_max - img_min\n", "\n", "  for i in range(12):\n", "    plt.subplot(6, 6, i + 1)\n", "    plt.imshow((images[i] - img_min) / delta)\n", "    if predictions is None:\n", "      plt.title(label_info.int2str(labels[i]))\n", "    else:\n", "      # Green title for a correct prediction, red for an incorrect one.\n", "      if labels[i] == predictions[i]:\n", "        color = 'g'\n", "      else:\n", "        color = 'r'\n", "      plt.title(label_info.int2str(predictions[i]), color=color)\n", "    plt.axis(\"off\")" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "execution": { "iopub.execute_input": "2023-10-17T11:53:02.444915Z", "iopub.status.busy": "2023-10-17T11:53:02.444689Z", "iopub.status.idle": "2023-10-17T11:53:04.806488Z", "shell.execute_reply": "2023-10-17T11:53:04.805735Z" }, "id": "xkA5h_RBtYYU" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2023-10-17 11:53:04.198417: W
tensorflow/core/kernels/data/cache_dataset_ops.cc:854] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.\n" ] } ], "source": [ "plt.figure(figsize=(10, 10))\n", "for images, labels in task.build_inputs(exp_config.task.train_data).take(1):\n", " show_batch(images, labels)" ] }, { "cell_type": "markdown", "metadata": { "id": "v_A9VnL2RbXP" }, "source": [ "## Visualize the testing data" ] }, { "cell_type": "markdown", "metadata": { "id": "AXovuumW_I2z" }, "source": [ "Visualize a batch of images from the validation dataset." ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "execution": { "iopub.execute_input": "2023-10-17T11:53:04.811306Z", "iopub.status.busy": "2023-10-17T11:53:04.810658Z", "iopub.status.idle": "2023-10-17T11:53:07.146253Z", "shell.execute_reply": "2023-10-17T11:53:07.145503Z" }, "id": "Ma-_Eb-nte9A" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2023-10-17 11:53:07.007846: W tensorflow/core/kernels/data/cache_dataset_ops.cc:854] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. 
You should use `dataset.take(k).cache().repeat()` instead.\n" ] } ], "source": [ "plt.figure(figsize=(10, 10));\n", "for images, labels in task.build_inputs(exp_config.task.validation_data).take(1):\n", " show_batch(images, labels)" ] }, { "cell_type": "markdown", "metadata": { "id": "ihKJt2FHRi2N" }, "source": [ "## Train and evaluate" ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "execution": { "iopub.execute_input": "2023-10-17T11:53:07.150096Z", "iopub.status.busy": "2023-10-17T11:53:07.149838Z", "iopub.status.idle": "2023-10-17T11:53:47.507619Z", "shell.execute_reply": "2023-10-17T11:53:47.506891Z" }, "id": "0AFMNvYxtjXx" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "restoring or initializing model...\n", "INFO:tensorflow:Customized initialization is done through the passed `init_fn`.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:Customized initialization is done through the passed `init_fn`.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "train | step: 0 | training until step 20...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "2023-10-17 11:53:09.849007: W tensorflow/core/framework/dataset.cc:959] Input of GeneratorDatasetOp::Dataset will not be optimized because the dataset does not implement the AsGraphDefInternal() method needed to apply optimizations.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "train | step: 5 | steps/sec: 0.5 | output: \n", " {'accuracy': 0.103125,\n", " 'learning_rate': 0.0,\n", " 'top_5_accuracy': 0.4828125,\n", " 'training_loss': 2.7998607}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "saved checkpoint to /tmpfs/tmp/tmpu0ate1h5/ckpt-5.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "train | step: 10 | steps/sec: 0.8 | output: \n", " {'accuracy': 0.0828125,\n", " 'learning_rate': 0.0,\n", " 'top_5_accuracy': 0.4984375,\n", " 'training_loss': 2.8205295}\n" ] }, { "name": 
"stdout", "output_type": "stream", "text": [ "train | step: 15 | steps/sec: 0.8 | output: \n", " {'accuracy': 0.0921875,\n", " 'learning_rate': 0.0,\n", " 'top_5_accuracy': 0.503125,\n", " 'training_loss': 2.8169343}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "train | step: 20 | steps/sec: 0.8 | output: \n", " {'accuracy': 0.1015625,\n", " 'learning_rate': 0.0,\n", " 'top_5_accuracy': 0.45,\n", " 'training_loss': 2.8760865}\n", " eval | step: 20 | running 78 steps of evaluation...\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " eval | step: 20 | steps/sec: 24.4 | eval time: 3.2 sec | output: \n", " {'accuracy': 0.09485176,\n", " 'steps_per_second': 24.40085348913806,\n", " 'top_5_accuracy': 0.49589342,\n", " 'validation_loss': 2.5864375}\n", "saved checkpoint to /tmpfs/tmp/tmpu0ate1h5/ckpt-20.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "2023-10-17 11:53:43.844533: W tensorflow/core/common_runtime/gpu/gpu_device.cc:2211] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.\n", "Skipping registering GPU devices...\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/ops/nn_ops.py:5253: tensor_shape_from_node_def_name (from tensorflow.python.framework.graph_util_impl) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "This API was designed for TensorFlow v1. 
See https://www.tensorflow.org/guide/migrate for instructions on how to migrate your code to TensorFlow v2.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/ops/nn_ops.py:5253: tensor_shape_from_node_def_name (from tensorflow.python.framework.graph_util_impl) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "This API was designed for TensorFlow v1. See https://www.tensorflow.org/guide/migrate for instructions on how to migrate your code to TensorFlow v2.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " eval | step: 20 | running 78 steps of evaluation...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "2023-10-17 11:53:45.627213: W tensorflow/core/framework/dataset.cc:959] Input of GeneratorDatasetOp::Dataset will not be optimized because the dataset does not implement the AsGraphDefInternal() method needed to apply optimizations.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " eval | step: 20 | steps/sec: 40.1 | eval time: 1.9 sec | output: \n", " {'accuracy': 0.09485176,\n", " 'steps_per_second': 40.14298727815298,\n", " 'top_5_accuracy': 0.49589342,\n", " 'validation_loss': 2.5864375}\n" ] } ], "source": [ "model, eval_logs = tfm.core.train_lib.run_experiment(\n", " distribution_strategy=distribution_strategy,\n", " task=task,\n", " mode='train_and_eval',\n", " params=exp_config,\n", " model_dir=model_dir,\n", " run_post_eval=True)" ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "execution": { "iopub.execute_input": "2023-10-17T11:53:47.511575Z", "iopub.status.busy": "2023-10-17T11:53:47.510951Z", "iopub.status.idle": "2023-10-17T11:53:47.514403Z", "shell.execute_reply": "2023-10-17T11:53:47.513723Z" }, "id": "gCcHMQYhozmA" }, "outputs": [], "source": [ "# tf.keras.utils.plot_model(model, show_shapes=True)" ] }, { "cell_type": "markdown", "metadata": 
{ "id": "L7nVfxlBA8Gb" }, "source": [ "Print the `accuracy`, `top_5_accuracy`, and `validation_loss` evaluation metrics." ] }, { "cell_type": "code", "execution_count": 19, "metadata": { "execution": { "iopub.execute_input": "2023-10-17T11:53:47.518093Z", "iopub.status.busy": "2023-10-17T11:53:47.517450Z", "iopub.status.idle": "2023-10-17T11:53:47.521653Z", "shell.execute_reply": "2023-10-17T11:53:47.520949Z" }, "id": "0124f938a1b9" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "accuracy : 0.095\n", "top_5_accuracy : 0.496\n", "validation_loss : 2.586\n", "steps_per_second : 40.143\n" ] } ], "source": [ "for key, value in eval_logs.items():\n", "  if isinstance(value, tf.Tensor):\n", "    value = value.numpy()\n", "  print(f'{key:20}: {value:.3f}')" ] }, { "cell_type": "markdown", "metadata": { "id": "TDys5bZ1zsml" }, "source": [ "Run a batch of the processed training data through the model and view the results." ] }, { "cell_type": "code", "execution_count": 20, "metadata": { "execution": { "iopub.execute_input": "2023-10-17T11:53:47.525508Z", "iopub.status.busy": "2023-10-17T11:53:47.525037Z", "iopub.status.idle": "2023-10-17T11:53:51.012412Z", "shell.execute_reply": "2023-10-17T11:53:51.011736Z" }, "id": "GhI7zR-Uz1JT" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2023-10-17 11:53:49.840600: W tensorflow/core/kernels/data/cache_dataset_ops.cc:854] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "1/4 [======>.......................] 
- ETA: 1s" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "4/4 [==============================] - 1s 13ms/step\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "2023-10-17 11:53:50.778301: W tensorflow/core/kernels/data/cache_dataset_ops.cc:854] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.\n" ] } ], "source": [ "for images, labels in task.build_inputs(exp_config.task.train_data).take(1):\n", "  predictions = model.predict(images)\n", "  predictions = tf.argmax(predictions, axis=-1)\n", "\n", "show_batch(images, labels, tf.cast(predictions, tf.int32))\n", "\n", "if device=='CPU':\n", "  plt.suptitle('The model was only trained for a few steps; it is not expected to do well.')" ] }, { "cell_type": "markdown", "metadata": { "id": "fkE9locGTBgt" }, "source": [ "## Export a SavedModel" ] }, { "cell_type": "markdown", "metadata": { "id": "9669d08c91af" }, "source": [ "The `keras.Model` object returned by `train_lib.run_experiment` expects its input images to be normalized the same way the dataset loader normalizes them, with `preprocess_ops.normalize_image(image, offset=MEAN_RGB, scale=STDDEV_RGB)`, which applies per-channel RGB mean and standard-deviation constants. 
This export function handles those details, so you can pass `tf.uint8` images and get the correct results.\n" ] }, { "cell_type": "code", "execution_count": 21, "metadata": { "execution": { "iopub.execute_input": "2023-10-17T11:53:51.016823Z", "iopub.status.busy": "2023-10-17T11:53:51.016207Z", "iopub.status.idle": "2023-10-17T11:53:58.157183Z", "shell.execute_reply": "2023-10-17T11:53:58.156381Z" }, "id": "AQCFa7BvtmDg" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Assets written to: ./export/assets\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:Assets written to: ./export/assets\n" ] } ], "source": [ "# Saving and exporting the trained model\n", "export_saved_model_lib.export_inference_graph(\n", " input_type='image_tensor',\n", " batch_size=1,\n", " input_image_size=[32, 32],\n", " params=exp_config,\n", " checkpoint_path=tf.train.latest_checkpoint(model_dir),\n", " export_dir='./export/')" ] }, { "cell_type": "markdown", "metadata": { "id": "vVr6DxNqTyLZ" }, "source": [ "Test the exported model." ] }, { "cell_type": "code", "execution_count": 22, "metadata": { "execution": { "iopub.execute_input": "2023-10-17T11:53:58.162073Z", "iopub.status.busy": "2023-10-17T11:53:58.161829Z", "iopub.status.idle": "2023-10-17T11:53:59.920578Z", "shell.execute_reply": "2023-10-17T11:53:59.919739Z" }, "id": "gP7nOvrftsB0" }, "outputs": [], "source": [ "# Importing SavedModel\n", "imported = tf.saved_model.load('./export/')\n", "model_fn = imported.signatures['serving_default']" ] }, { "cell_type": "markdown", "metadata": { "id": "GiOp2WVIUNUZ" }, "source": [ "Visualize the predictions." 
] }, { "cell_type": "code", "execution_count": 23, "metadata": { "execution": { "iopub.execute_input": "2023-10-17T11:53:59.924731Z", "iopub.status.busy": "2023-10-17T11:53:59.924272Z", "iopub.status.idle": "2023-10-17T11:54:02.215391Z", "shell.execute_reply": "2023-10-17T11:54:02.214653Z" }, "id": "BTRMrZQAN4mk" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2023-10-17 11:54:01.438509: W tensorflow/core/kernels/data/cache_dataset_ops.cc:854] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.\n" ] } ], "source": [ "plt.figure(figsize=(10, 10))\n", "for data in tfds.load('cifar10', split='test').batch(12).take(1):\n", "  predictions = []\n", "  for image in data['image']:\n", "    # The export above used batch_size=1, so run one image at a time.\n", "    index = tf.argmax(model_fn(image[tf.newaxis, ...])['logits'], axis=1)[0]\n", "    predictions.append(index)\n", "  show_batch(data['image'], data['label'], predictions)\n", "\n", "  if device=='CPU':\n", "    plt.suptitle('The model was only trained for a few steps; it is not expected to do better than random.')" ] } ], "metadata": { "colab": { "name": "classification_with_model_garden.ipynb", "provenance": [], "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.18" } }, "nbformat": 4, "nbformat_minor": 0 }