{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "W7rEsKyWcxmu" }, "source": [ "##### Copyright 2023 The TF-Agents Authors.\n" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "cellView": "form", "execution": { "iopub.execute_input": "2023-12-22T12:17:31.959282Z", "iopub.status.busy": "2023-12-22T12:17:31.958806Z", "iopub.status.idle": "2023-12-22T12:17:31.962844Z", "shell.execute_reply": "2023-12-22T12:17:31.962276Z" }, "id": "nQnmcm0oI1Q-" }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "G6aOV15Wc4HP" }, "source": [ "# Checkpointer and PolicySaver\n", "\n", "\n", " \n", " \n", " \n", " \n", "
\n", " \n", " \n", " View on TensorFlow.org\n", " \n", " \n", " \n", " Run in Google Colab\n", " \n", " \n", " \n", " View source on GitHub\n", " \n", " Download notebook\n", "
" ] }, { "cell_type": "markdown", "metadata": { "id": "M3HE5S3wsMEh" }, "source": [ "## Introduction\n", "\n", "`tf_agents.utils.common.Checkpointer` is a utility to save/load the training state, policy state, and replay_buffer state to/from a local storage.\n", "\n", "`tf_agents.policies.policy_saver.PolicySaver` is a tool to save/load only the policy, and is lighter than `Checkpointer`. You can use `PolicySaver` to deploy the model as well without any knowledge of the code that created the policy.\n", "\n", "In this tutorial, we will use DQN to train a model, then use `Checkpointer` and `PolicySaver` to show how we can store and load the states and model in an interactive way. Note that we will use TF2.0's new saved_model tooling and format for `PolicySaver`.\n" ] }, { "cell_type": "markdown", "metadata": { "id": "vbTrDrX4dkP_" }, "source": [ "## Setup" ] }, { "cell_type": "markdown", "metadata": { "id": "Opk_cVDYdgct" }, "source": [ " If you haven't installed the following dependencies, run:" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "execution": { "iopub.execute_input": "2023-12-22T12:17:31.966631Z", "iopub.status.busy": "2023-12-22T12:17:31.966054Z", "iopub.status.idle": "2023-12-22T12:17:44.962798Z", "shell.execute_reply": "2023-12-22T12:17:44.961942Z" }, "id": "Jv668dKvZmka" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\r", "0% [Working]\r", " \r", "Hit:1 http://us-central1.gce.archive.ubuntu.com/ubuntu focal InRelease\r\n", "\r", "0% [Connecting to security.ubuntu.com (185.125.190.36)] [Connecting to develope\r", " \r", "Hit:2 http://us-central1.gce.archive.ubuntu.com/ubuntu focal-updates InRelease\r\n", "\r", " \r", "Hit:3 http://us-central1.gce.archive.ubuntu.com/ubuntu focal-backports InRelease\r\n", "\r", "0% [Connecting to security.ubuntu.com (185.125.190.36)] [Connected to developer" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " \r", "Get:4 
https://nvidia.github.io/libnvidia-container/stable/ubuntu18.04/amd64 InRelease [1484 B]\r\n", "\r", "0% [Connecting to security.ubuntu.com (185.125.190.36)] [Connecting to apt.llvm\r", "0% [Connecting to security.ubuntu.com (185.125.190.36)] [Connecting to apt.llvm\r", " \r", "Hit:5 https://download.docker.com/linux/ubuntu focal InRelease\r\n", "\r", "0% [Connecting to security.ubuntu.com (185.125.190.36)] [Connecting to apt.llvm\r", " \r", "Hit:6 https://nvidia.github.io/nvidia-container-runtime/stable/ubuntu18.04/amd64 InRelease\r\n", "\r", "0% [Connecting to security.ubuntu.com (185.125.190.36)] [Connecting to apt.llvm\r", " \r", "Hit:7 https://nvidia.github.io/nvidia-docker/ubuntu18.04/amd64 InRelease\r\n", "\r", "0% [Connecting to security.ubuntu.com (185.125.190.36)] [Waiting for headers] [" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "0% [Connecting to security.ubuntu.com (185.125.190.36)] [Waiting for headers] [\r", " \r", "Hit:9 https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64 InRelease\r\n", "\r", "0% [Connecting to security.ubuntu.com (185.125.190.36)] [Connecting to ppa.laun" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " \r", "0% [Waiting for headers] [Waiting for headers]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " \r", "Hit:10 http://security.ubuntu.com/ubuntu focal-security InRelease\r\n", "\r", " \r", "0% [Waiting for headers]\r", " \r", "Hit:11 http://ppa.launchpad.net/deadsnakes/ppa/ubuntu focal InRelease\r\n", "\r", " \r", "0% [Working]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "0% [Waiting for headers]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " \r", "Hit:8 https://apt.llvm.org/focal llvm-toolchain-focal-17 InRelease\r\n", "\r", "0% [Waiting for headers]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "0% [Waiting for headers]\r", " \r", "Hit:12 
http://ppa.launchpad.net/longsleep/golang-backports/ubuntu focal InRelease\r\n", "\r", "0% [Waiting for headers]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "0% [Waiting for headers]\r", " \r", "Hit:13 http://ppa.launchpad.net/openjdk-r/ppa/ubuntu focal InRelease\r\n", "\r", " \r", "0% [Working]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "0% [Working]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "0% [Working]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "0% [Working]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "0% [Working]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "0% [Working]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "0% [Working]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "0% [Working]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "0% [Working]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "100% [Working]\r", " \r", "Fetched 1484 B in 1s (1067 B/s)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "Reading package lists... 0%\r", "\r", "Reading package lists... 0%\r", "\r", "Reading package lists... 0%\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "Reading package lists... 2%\r", "\r", "Reading package lists... 2%\r", "\r", "Reading package lists... 4%\r", "\r", "Reading package lists... 4%\r", "\r", "Reading package lists... 4%\r", "\r", "Reading package lists... 4%\r", "\r", "Reading package lists... 4%\r", "\r", "Reading package lists... 4%\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "Reading package lists... 28%\r", "\r", "Reading package lists... 28%\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "Reading package lists... 41%\r", "\r", "Reading package lists... 41%\r", "\r", "Reading package lists... 41%\r", "\r", "Reading package lists... 
41%\r", "\r", "Reading package lists... 41%\r", "\r", "Reading package lists... 41%\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "Reading package lists... 49%\r", "\r", "Reading package lists... 49%\r", "\r", "Reading package lists... 55%\r", "\r", "Reading package lists... 55%\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "Reading package lists... 58%\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "Reading package lists... 62%\r", "\r", "Reading package lists... 62%\r", "\r", "Reading package lists... 65%\r", "\r", "Reading package lists... 65%\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "Reading package lists... 68%\r", "\r", "Reading package lists... 68%\r", "\r", "Reading package lists... 69%\r", "\r", "Reading package lists... 69%\r", "\r", "Reading package lists... 69%\r", "\r", "Reading package lists... 69%\r", "\r", "Reading package lists... 69%\r", "\r", "Reading package lists... 69%\r", "\r", "Reading package lists... 70%\r", "\r", "Reading package lists... 70%\r", "\r", "Reading package lists... 70%\r", "\r", "Reading package lists... 70%\r", "\r", "Reading package lists... 70%\r", "\r", "Reading package lists... 70%\r", "\r", "Reading package lists... 70%\r", "\r", "Reading package lists... 70%\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "Reading package lists... 77%\r", "\r", "Reading package lists... 77%\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "Reading package lists... 82%\r", "\r", "Reading package lists... 82%\r", "\r", "Reading package lists... 89%\r", "\r", "Reading package lists... 89%\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "Reading package lists... 92%\r", "\r", "Reading package lists... 92%\r", "\r", "Reading package lists... 94%\r", "\r", "Reading package lists... 94%\r", "\r", "Reading package lists... 95%\r", "\r", "Reading package lists... 
95%\r", "\r", "Reading package lists... 95%\r", "\r", "Reading package lists... 95%\r", "\r", "Reading package lists... 95%\r", "\r", "Reading package lists... 95%\r", "\r", "Reading package lists... 95%\r", "\r", "Reading package lists... 95%\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "Reading package lists... 98%\r", "\r", "Reading package lists... 98%\r", "\r", "Reading package lists... 98%\r", "\r", "Reading package lists... 98%\r", "\r", "Reading package lists... 98%\r", "\r", "Reading package lists... 98%\r", "\r", "Reading package lists... 99%\r", "\r", "Reading package lists... 99%\r", "\r", "Reading package lists... 99%\r", "\r", "Reading package lists... 99%\r", "\r", "Reading package lists... 99%\r", "\r", "Reading package lists... 99%\r", "\r", "Reading package lists... 99%\r", "\r", "Reading package lists... 99%\r", "\r", "Reading package lists... 99%\r", "\r", "Reading package lists... 99%\r", "\r", "Reading package lists... 99%\r", "\r", "Reading package lists... 99%\r", "\r", "Reading package lists... 99%\r", "\r", "Reading package lists... 99%\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "Reading package lists... 99%\r", "\r", "Reading package lists... 99%\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "Reading package lists... Done\r", "\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "Reading package lists... 0%\r", "\r", "Reading package lists... 100%\r", "\r", "Reading package lists... Done\r", "\r\n", "\r", "Building dependency tree... 0%\r", "\r", "Building dependency tree... 0%\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "Building dependency tree... 50%\r", "\r", "Building dependency tree... 50%\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "Building dependency tree \r", "\r\n", "\r", "Reading state information... 0%\r", "\r", "Reading state information... 0%\r", "\r", "Reading state information... 
Done\r", "\r\n", "python-opengl is already the newest version (3.1.0+dfsg-2build1).\r\n", "ffmpeg is already the newest version (7:4.2.7-0ubuntu0.1).\r\n", "xvfb is already the newest version (2:1.20.13-1ubuntu1~20.04.12).\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "The following packages were automatically installed and are no longer required:\r\n", " libatasmart4 libblockdev-fs2 libblockdev-loop2 libblockdev-part-err2\r\n", " libblockdev-part2 libblockdev-swap2 libblockdev-utils2 libblockdev2\r\n", " libparted-fs-resize0 libxmlb2\r\n", "Use 'sudo apt autoremove' to remove them.\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "0 upgraded, 0 newly installed, 0 to remove and 115 not upgraded.\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: pyglet in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (2.0.10)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: imageio==2.4.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (2.4.0)\r\n", "Requirement already satisfied: numpy in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from imageio==2.4.0) (1.26.2)\r\n", "Requirement already satisfied: pillow in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from imageio==2.4.0) (10.1.0)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: xvfbwrapper==0.2.9 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (0.2.9)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: tf-agents[reverb] in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (0.19.0)\r\n", "Requirement already satisfied: absl-py>=0.6.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tf-agents[reverb]) (1.4.0)\r\n", "Requirement already satisfied: cloudpickle>=1.3 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tf-agents[reverb]) (3.0.0)\r\n", "Requirement 
already satisfied: gin-config>=0.4.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tf-agents[reverb]) (0.5.0)\r\n", "Requirement already satisfied: gym<=0.23.0,>=0.17.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tf-agents[reverb]) (0.23.0)\r\n", "Requirement already satisfied: numpy>=1.19.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tf-agents[reverb]) (1.26.2)\r\n", "Requirement already satisfied: pillow in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tf-agents[reverb]) (10.1.0)\r\n", "Requirement already satisfied: six>=1.10.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tf-agents[reverb]) (1.16.0)\r\n", "Requirement already satisfied: protobuf>=3.11.3 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tf-agents[reverb]) (3.20.3)\r\n", "Requirement already satisfied: wrapt>=1.11.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tf-agents[reverb]) (1.14.1)\r\n", "Requirement already satisfied: typing-extensions==4.5.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tf-agents[reverb]) (4.5.0)\r\n", "Requirement already satisfied: pygame==2.1.3 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tf-agents[reverb]) (2.1.3)\r\n", "Requirement already satisfied: tensorflow-probability~=0.23.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tf-agents[reverb]) (0.23.0)\r\n", "Requirement already satisfied: rlds in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tf-agents[reverb]) (0.1.8)\r\n", "Requirement already satisfied: dm-reverb~=0.14.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tf-agents[reverb]) (0.14.0)\r\n", "Requirement already satisfied: tensorflow~=2.15.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tf-agents[reverb]) (2.15.0.post1)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: dm-tree in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages 
(from dm-reverb~=0.14.0->tf-agents[reverb]) (0.1.8)\r\n", "Requirement already satisfied: portpicker in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from dm-reverb~=0.14.0->tf-agents[reverb]) (1.6.0)\r\n", "Requirement already satisfied: gym-notices>=0.0.4 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from gym<=0.23.0,>=0.17.0->tf-agents[reverb]) (0.0.8)\r\n", "Requirement already satisfied: importlib-metadata>=4.10.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from gym<=0.23.0,>=0.17.0->tf-agents[reverb]) (7.0.0)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: astunparse>=1.6.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.15.0->tf-agents[reverb]) (1.6.3)\r\n", "Requirement already satisfied: flatbuffers>=23.5.26 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.15.0->tf-agents[reverb]) (23.5.26)\r\n", "Requirement already satisfied: gast!=0.5.0,!=0.5.1,!=0.5.2,>=0.2.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.15.0->tf-agents[reverb]) (0.5.4)\r\n", "Requirement already satisfied: google-pasta>=0.1.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.15.0->tf-agents[reverb]) (0.2.0)\r\n", "Requirement already satisfied: h5py>=2.9.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.15.0->tf-agents[reverb]) (3.10.0)\r\n", "Requirement already satisfied: libclang>=13.0.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.15.0->tf-agents[reverb]) (16.0.6)\r\n", "Requirement already satisfied: ml-dtypes~=0.2.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.15.0->tf-agents[reverb]) (0.2.0)\r\n", "Requirement already satisfied: opt-einsum>=2.3.2 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.15.0->tf-agents[reverb]) (3.3.0)\r\n", "Requirement already satisfied: packaging in 
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.15.0->tf-agents[reverb]) (23.2)\r\n", "Requirement already satisfied: setuptools in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.15.0->tf-agents[reverb]) (69.0.2)\r\n", "Requirement already satisfied: termcolor>=1.1.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.15.0->tf-agents[reverb]) (2.4.0)\r\n", "Requirement already satisfied: tensorflow-io-gcs-filesystem>=0.23.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.15.0->tf-agents[reverb]) (0.35.0)\r\n", "Requirement already satisfied: grpcio<2.0,>=1.24.3 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.15.0->tf-agents[reverb]) (1.60.0)\r\n", "Requirement already satisfied: tensorboard<2.16,>=2.15 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.15.0->tf-agents[reverb]) (2.15.1)\r\n", "Requirement already satisfied: tensorflow-estimator<2.16,>=2.15.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.15.0->tf-agents[reverb]) (2.15.0)\r\n", "Requirement already satisfied: keras<2.16,>=2.15.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.15.0->tf-agents[reverb]) (2.15.0)\r\n", "Requirement already satisfied: decorator in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow-probability~=0.23.0->tf-agents[reverb]) (5.1.1)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: wheel<1.0,>=0.23.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from astunparse>=1.6.0->tensorflow~=2.15.0->tf-agents[reverb]) (0.41.2)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: zipp>=0.5 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from importlib-metadata>=4.10.0->gym<=0.23.0,>=0.17.0->tf-agents[reverb]) (3.17.0)\r\n" ] }, { "name": "stdout", "output_type": "stream", 
"text": [ "Requirement already satisfied: google-auth<3,>=1.6.3 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorboard<2.16,>=2.15->tensorflow~=2.15.0->tf-agents[reverb]) (2.25.2)\r\n", "Requirement already satisfied: google-auth-oauthlib<2,>=0.5 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorboard<2.16,>=2.15->tensorflow~=2.15.0->tf-agents[reverb]) (1.2.0)\r\n", "Requirement already satisfied: markdown>=2.6.8 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorboard<2.16,>=2.15->tensorflow~=2.15.0->tf-agents[reverb]) (3.5.1)\r\n", "Requirement already satisfied: requests<3,>=2.21.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorboard<2.16,>=2.15->tensorflow~=2.15.0->tf-agents[reverb]) (2.31.0)\r\n", "Requirement already satisfied: tensorboard-data-server<0.8.0,>=0.7.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorboard<2.16,>=2.15->tensorflow~=2.15.0->tf-agents[reverb]) (0.7.2)\r\n", "Requirement already satisfied: werkzeug>=1.0.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorboard<2.16,>=2.15->tensorflow~=2.15.0->tf-agents[reverb]) (3.0.1)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: psutil in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from portpicker->dm-reverb~=0.14.0->tf-agents[reverb]) (5.9.7)\r\n", "Requirement already satisfied: cachetools<6.0,>=2.0.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from google-auth<3,>=1.6.3->tensorboard<2.16,>=2.15->tensorflow~=2.15.0->tf-agents[reverb]) (5.3.2)\r\n", "Requirement already satisfied: pyasn1-modules>=0.2.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from google-auth<3,>=1.6.3->tensorboard<2.16,>=2.15->tensorflow~=2.15.0->tf-agents[reverb]) (0.3.0)\r\n", "Requirement already satisfied: rsa<5,>=3.1.4 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from 
google-auth<3,>=1.6.3->tensorboard<2.16,>=2.15->tensorflow~=2.15.0->tf-agents[reverb]) (4.9)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: requests-oauthlib>=0.7.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from google-auth-oauthlib<2,>=0.5->tensorboard<2.16,>=2.15->tensorflow~=2.15.0->tf-agents[reverb]) (1.3.1)\r\n", "Requirement already satisfied: charset-normalizer<4,>=2 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from requests<3,>=2.21.0->tensorboard<2.16,>=2.15->tensorflow~=2.15.0->tf-agents[reverb]) (3.3.2)\r\n", "Requirement already satisfied: idna<4,>=2.5 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from requests<3,>=2.21.0->tensorboard<2.16,>=2.15->tensorflow~=2.15.0->tf-agents[reverb]) (3.6)\r\n", "Requirement already satisfied: urllib3<3,>=1.21.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from requests<3,>=2.21.0->tensorboard<2.16,>=2.15->tensorflow~=2.15.0->tf-agents[reverb]) (2.1.0)\r\n", "Requirement already satisfied: certifi>=2017.4.17 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from requests<3,>=2.21.0->tensorboard<2.16,>=2.15->tensorflow~=2.15.0->tf-agents[reverb]) (2023.11.17)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: MarkupSafe>=2.1.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from werkzeug>=1.0.1->tensorboard<2.16,>=2.15->tensorflow~=2.15.0->tf-agents[reverb]) (2.1.3)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: pyasn1<0.6.0,>=0.4.6 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from pyasn1-modules>=0.2.1->google-auth<3,>=1.6.3->tensorboard<2.16,>=2.15->tensorflow~=2.15.0->tf-agents[reverb]) (0.5.1)\r\n", "Requirement already satisfied: oauthlib>=3.0.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from 
requests-oauthlib>=0.7.0->google-auth-oauthlib<2,>=0.5->tensorboard<2.16,>=2.15->tensorflow~=2.15.0->tf-agents[reverb]) (3.2.2)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: tf-keras in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (2.15.0)\r\n" ] } ], "source": [ "#@test {\"skip\": true}\n", "!sudo apt-get update\n", "!sudo apt-get install -y xvfb ffmpeg python-opengl\n", "!pip install pyglet\n", "!pip install 'imageio==2.4.0'\n", "!pip install 'xvfbwrapper==0.2.9'\n", "!pip install tf-agents[reverb]\n", "!pip install tf-keras" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "execution": { "iopub.execute_input": "2023-12-22T12:17:44.967040Z", "iopub.status.busy": "2023-12-22T12:17:44.966776Z", "iopub.status.idle": "2023-12-22T12:17:44.970661Z", "shell.execute_reply": "2023-12-22T12:17:44.970022Z" }, "id": "WPuD0bMEY9Iz" }, "outputs": [], "source": [ "import os\n", "# Keep using keras-2 (tf-keras) rather than keras-3 (keras).\n", "os.environ['TF_USE_LEGACY_KERAS'] = '1'" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "execution": { "iopub.execute_input": "2023-12-22T12:17:44.973884Z", "iopub.status.busy": "2023-12-22T12:17:44.973671Z", "iopub.status.idle": "2023-12-22T12:17:48.184882Z", "shell.execute_reply": "2023-12-22T12:17:48.184131Z" }, "id": "bQMULMo1dCEn" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2023-12-22 12:17:45.769390: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n", "2023-12-22 12:17:45.769434: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n", "2023-12-22 12:17:45.771084: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS 
factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n" ] } ], "source": [ "from __future__ import absolute_import\n", "from __future__ import division\n", "from __future__ import print_function\n", "\n", "import base64\n", "import imageio\n", "import io\n", "import matplotlib\n", "import matplotlib.pyplot as plt\n", "import os\n", "import shutil\n", "import tempfile\n", "import tensorflow as tf\n", "import zipfile\n", "import IPython\n", "\n", "try:\n", " from google.colab import files\n", "except ImportError:\n", " files = None\n", "from tf_agents.agents.dqn import dqn_agent\n", "from tf_agents.drivers import dynamic_step_driver\n", "from tf_agents.environments import suite_gym\n", "from tf_agents.environments import tf_py_environment\n", "from tf_agents.eval import metric_utils\n", "from tf_agents.metrics import tf_metrics\n", "from tf_agents.networks import q_network\n", "from tf_agents.policies import policy_saver\n", "from tf_agents.policies import py_tf_eager_policy\n", "from tf_agents.policies import random_tf_policy\n", "from tf_agents.replay_buffers import tf_uniform_replay_buffer\n", "from tf_agents.trajectories import trajectory\n", "from tf_agents.utils import common\n", "\n", "tempdir = os.getenv(\"TEST_TMPDIR\", tempfile.gettempdir())" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "execution": { "iopub.execute_input": "2023-12-22T12:17:48.189401Z", "iopub.status.busy": "2023-12-22T12:17:48.188725Z", "iopub.status.idle": "2023-12-22T12:17:48.312017Z", "shell.execute_reply": "2023-12-22T12:17:48.310712Z" }, "id": "AwIqiLdDCX9Q" }, "outputs": [], "source": [ "#@test {\"skip\": true}\n", "# Set up a virtual display for rendering OpenAI gym environments.\n", "import xvfbwrapper\n", "xvfbwrapper.Xvfb(1400, 900, 24).start()" ] }, { "cell_type": "markdown", "metadata": { "id": "AOv_kofIvWnW" }, "source": [ "## DQN agent\n", "We are going to set up a DQN agent, just like in the previous colab. 
The details are hidden by default as they are not a core part of this colab, but you can click on 'SHOW CODE' to see the details." ] }, { "cell_type": "markdown", "metadata": { "id": "cStmaxredFSW" }, "source": [ "### Hyperparameters" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "cellView": "both", "execution": { "iopub.execute_input": "2023-12-22T12:17:48.316590Z", "iopub.status.busy": "2023-12-22T12:17:48.316292Z", "iopub.status.idle": "2023-12-22T12:17:48.321242Z", "shell.execute_reply": "2023-12-22T12:17:48.320318Z" }, "id": "yxFs6QU0dGI_" }, "outputs": [], "source": [ "env_name = \"CartPole-v1\"\n", "\n", "collect_steps_per_iteration = 100\n", "replay_buffer_capacity = 100000\n", "\n", "fc_layer_params = (100,)\n", "\n", "batch_size = 64\n", "learning_rate = 1e-3\n", "log_interval = 5\n", "\n", "num_eval_episodes = 10\n", "eval_interval = 1000" ] }, { "cell_type": "markdown", "metadata": { "id": "w4GR7RDndIOR" }, "source": [ "### Environment" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "execution": { "iopub.execute_input": "2023-12-22T12:17:48.324757Z", "iopub.status.busy": "2023-12-22T12:17:48.324325Z", "iopub.status.idle": "2023-12-22T12:17:48.364840Z", "shell.execute_reply": "2023-12-22T12:17:48.364015Z" }, "id": "fZwK4d-bdI7Z" }, "outputs": [], "source": [ "train_py_env = suite_gym.load(env_name)\n", "eval_py_env = suite_gym.load(env_name)\n", "\n", "train_env = tf_py_environment.TFPyEnvironment(train_py_env)\n", "eval_env = tf_py_environment.TFPyEnvironment(eval_py_env)" ] }, { "cell_type": "markdown", "metadata": { "id": "0AvYRwfkeMvo" }, "source": [ "### Agent" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "execution": { "iopub.execute_input": "2023-12-22T12:17:48.368742Z", "iopub.status.busy": "2023-12-22T12:17:48.368050Z", "iopub.status.idle": "2023-12-22T12:17:50.878726Z", "shell.execute_reply": "2023-12-22T12:17:50.877981Z" }, "id": "cUrFl83ieOvV" }, "outputs": [], "source": [ "#@title\n", 
"q_net = q_network.QNetwork(\n", " train_env.observation_spec(),\n", " train_env.action_spec(),\n", " fc_layer_params=fc_layer_params)\n", "\n", "optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)\n", "\n", "global_step = tf.compat.v1.train.get_or_create_global_step()\n", "\n", "agent = dqn_agent.DqnAgent(\n", " train_env.time_step_spec(),\n", " train_env.action_spec(),\n", " q_network=q_net,\n", " optimizer=optimizer,\n", " td_errors_loss_fn=common.element_wise_squared_loss,\n", " train_step_counter=global_step)\n", "agent.initialize()" ] }, { "cell_type": "markdown", "metadata": { "id": "p8ganoJhdsbn" }, "source": [ "### Data Collection" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "execution": { "iopub.execute_input": "2023-12-22T12:17:50.883217Z", "iopub.status.busy": "2023-12-22T12:17:50.882708Z", "iopub.status.idle": "2023-12-22T12:17:54.156524Z", "shell.execute_reply": "2023-12-22T12:17:54.155740Z" }, "id": "XiT1p78HdtSe" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/autograph/impl/api.py:377: ReplayBuffer.get_next (from tf_agents.replay_buffers.replay_buffer) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use `as_dataset(..., single_deterministic_pass=False) instead.\n" ] } ], "source": [ "#@title\n", "replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(\n", " data_spec=agent.collect_data_spec,\n", " batch_size=train_env.batch_size,\n", " max_length=replay_buffer_capacity)\n", "\n", "collect_driver = dynamic_step_driver.DynamicStepDriver(\n", " train_env,\n", " agent.collect_policy,\n", " observers=[replay_buffer.add_batch],\n", " num_steps=collect_steps_per_iteration)\n", "\n", "# Initial data collection\n", "collect_driver.run()\n", "\n", "# Dataset generates trajectories with shape [BxTx...] 
where\n", "# T = n_step_update + 1.\n", "dataset = replay_buffer.as_dataset(\n", " num_parallel_calls=3, sample_batch_size=batch_size,\n", " num_steps=2).prefetch(3)\n", "\n", "iterator = iter(dataset)" ] }, { "cell_type": "markdown", "metadata": { "id": "8V8bojrKdupW" }, "source": [ "### Train the agent" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "execution": { "iopub.execute_input": "2023-12-22T12:17:54.160722Z", "iopub.status.busy": "2023-12-22T12:17:54.160440Z", "iopub.status.idle": "2023-12-22T12:17:54.165498Z", "shell.execute_reply": "2023-12-22T12:17:54.164865Z" }, "id": "-rDC3leXdvm_" }, "outputs": [], "source": [ "#@title\n", "# (Optional) Optimize by wrapping some of the code in a graph using TF function.\n", "agent.train = common.function(agent.train)\n", "\n", "def train_one_iteration():\n", "\n", " # Collect a few steps using collect_policy and save to the replay buffer.\n", " collect_driver.run()\n", "\n", " # Sample a batch of data from the buffer and update the agent's network.\n", " experience, unused_info = next(iterator)\n", " train_loss = agent.train(experience)\n", "\n", " iteration = agent.train_step_counter.numpy()\n", " print ('iteration: {0} loss: {1}'.format(iteration, train_loss.loss))" ] }, { "cell_type": "markdown", "metadata": { "id": "vgqVaPnUeDAn" }, "source": [ "### Video Generation" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "execution": { "iopub.execute_input": "2023-12-22T12:17:54.168997Z", "iopub.status.busy": "2023-12-22T12:17:54.168485Z", "iopub.status.idle": "2023-12-22T12:17:54.174149Z", "shell.execute_reply": "2023-12-22T12:17:54.173501Z" }, "id": "ZY6w-fcieFDW" }, "outputs": [], "source": [ "#@title\n", "def embed_gif(gif_buffer):\n", " \"\"\"Embeds a gif file in the notebook.\"\"\"\n", " tag = ''.format(base64.b64encode(gif_buffer).decode())\n", " return IPython.display.HTML(tag)\n", "\n", "def run_episodes_and_create_video(policy, eval_tf_env, eval_py_env):\n", " num_episodes = 
3\n", " frames = []\n", " for _ in range(num_episodes):\n", " time_step = eval_tf_env.reset()\n", " frames.append(eval_py_env.render())\n", " while not time_step.is_last():\n", " action_step = policy.action(time_step)\n", " time_step = eval_tf_env.step(action_step.action)\n", " frames.append(eval_py_env.render())\n", " gif_file = io.BytesIO()\n", " imageio.mimsave(gif_file, frames, format='gif', fps=60)\n", " IPython.display.display(embed_gif(gif_file.getvalue()))" ] }, { "cell_type": "markdown", "metadata": { "id": "y-oA8VYJdFdj" }, "source": [ "### Generate a video\n", "Check the performance of the policy by generating a video." ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "execution": { "iopub.execute_input": "2023-12-22T12:17:54.177371Z", "iopub.status.busy": "2023-12-22T12:17:54.176983Z", "iopub.status.idle": "2023-12-22T12:17:55.558737Z", "shell.execute_reply": "2023-12-22T12:17:55.558036Z" }, "id": "FpmPLXWbdG70" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "global_step:\n", "\n" ] }, { "data": { "text/html": [ "" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "print ('global_step:')\n", "print (global_step)\n", "run_episodes_and_create_video(agent.policy, eval_env, eval_py_env)" ] }, { "cell_type": "markdown", "metadata": { "id": "7RPLExsxwnOm" }, "source": [ "## Setup Checkpointer and PolicySaver\n", "\n", "Now we are ready to use Checkpointer and PolicySaver." 
] }, { "cell_type": "markdown", "metadata": { "id": "g-iyQJacfQqO" }, "source": [ "### Checkpointer\n" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "execution": { "iopub.execute_input": "2023-12-22T12:17:55.562820Z", "iopub.status.busy": "2023-12-22T12:17:55.562174Z", "iopub.status.idle": "2023-12-22T12:17:55.569110Z", "shell.execute_reply": "2023-12-22T12:17:55.568454Z" }, "id": "2DzCJZ-6YYbX" }, "outputs": [], "source": [ "checkpoint_dir = os.path.join(tempdir, 'checkpoint')\n", "train_checkpointer = common.Checkpointer(\n", " ckpt_dir=checkpoint_dir,\n", " max_to_keep=1,\n", " agent=agent,\n", " policy=agent.policy,\n", " replay_buffer=replay_buffer,\n", " global_step=global_step\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "MKpWNZM4WE8d" }, "source": [ "### Policy Saver" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "execution": { "iopub.execute_input": "2023-12-22T12:17:55.572662Z", "iopub.status.busy": "2023-12-22T12:17:55.572195Z", "iopub.status.idle": "2023-12-22T12:17:55.651927Z", "shell.execute_reply": "2023-12-22T12:17:55.651205Z" }, "id": "8mDZ_YMUWEY9" }, "outputs": [], "source": [ "policy_dir = os.path.join(tempdir, 'policy')\n", "tf_policy_saver = policy_saver.PolicySaver(agent.policy)" ] }, { "cell_type": "markdown", "metadata": { "id": "1OnANb1Idx8-" }, "source": [ "### Train one iteration" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "execution": { "iopub.execute_input": "2023-12-22T12:17:55.655334Z", "iopub.status.busy": "2023-12-22T12:17:55.655104Z", "iopub.status.idle": "2023-12-22T12:17:58.571309Z", "shell.execute_reply": "2023-12-22T12:17:58.570592Z" }, "id": "ql_D1iq8dl0X" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Training one iteration....\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/util/dispatch.py:1260: calling foldr_v2 (from 
tensorflow.python.ops.functional_ops) with back_prop=False is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "back_prop=False is deprecated. Consider using tf.stop_gradient instead.\n", "Instead of:\n", "results = tf.foldr(fn, elems, back_prop=False)\n", "Use:\n", "results = tf.nest.map_structure(tf.stop_gradient, tf.foldr(fn, elems))\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/util/dispatch.py:1260: calling foldr_v2 (from tensorflow.python.ops.functional_ops) with back_prop=False is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "back_prop=False is deprecated. Consider using tf.stop_gradient instead.\n", "Instead of:\n", "results = tf.foldr(fn, elems, back_prop=False)\n", "Use:\n", "results = tf.nest.map_structure(tf.stop_gradient, tf.foldr(fn, elems))\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "iteration: 1 loss: 0.936349093914032\n" ] } ], "source": [ "#@test {\"skip\": true}\n", "print('Training one iteration....')\n", "train_one_iteration()" ] }, { "cell_type": "markdown", "metadata": { "id": "eSChNSQPlySb" }, "source": [ "### Save to checkpoint" ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "execution": { "iopub.execute_input": "2023-12-22T12:17:58.574712Z", "iopub.status.busy": "2023-12-22T12:17:58.574461Z", "iopub.status.idle": "2023-12-22T12:17:58.651469Z", "shell.execute_reply": "2023-12-22T12:17:58.650645Z" }, "id": "usDm_Wpsl0bu" }, "outputs": [], "source": [ "train_checkpointer.save(global_step)" ] }, { "cell_type": "markdown", "metadata": { "id": "gTQUrKgihuic" }, "source": [ "### Restore checkpoint\n", "\n", "For this to work, the whole set of objects should be recreated the same way as when the checkpoint was created." 
] }, { "cell_type": "code", "execution_count": 17, "metadata": { "execution": { "iopub.execute_input": "2023-12-22T12:17:58.655147Z", "iopub.status.busy": "2023-12-22T12:17:58.654894Z", "iopub.status.idle": "2023-12-22T12:17:58.658354Z", "shell.execute_reply": "2023-12-22T12:17:58.657712Z" }, "id": "l6l3EB-Yhwmz" }, "outputs": [], "source": [ "train_checkpointer.initialize_or_restore()\n", "global_step = tf.compat.v1.train.get_global_step()" ] }, { "cell_type": "markdown", "metadata": { "id": "Nb8_MSE2XjRp" }, "source": [ "Also save the policy and export it to a location." ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "execution": { "iopub.execute_input": "2023-12-22T12:17:58.662009Z", "iopub.status.busy": "2023-12-22T12:17:58.661310Z", "iopub.status.idle": "2023-12-22T12:17:59.022986Z", "shell.execute_reply": "2023-12-22T12:17:59.022352Z" }, "id": "3xHz09WCWjwA" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "WARNING:absl:`0/step_type` is not a valid tf.function parameter name. Sanitizing to `arg_0_step_type`.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:absl:`0/reward` is not a valid tf.function parameter name. Sanitizing to `arg_0_reward`.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:absl:`0/discount` is not a valid tf.function parameter name. Sanitizing to `arg_0_discount`.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:absl:`0/observation` is not a valid tf.function parameter name. Sanitizing to `arg_0_observation`.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:absl:`0/step_type` is not a valid tf.function parameter name. 
Sanitizing to `arg_0_step_type`.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Assets written to: /tmpfs/tmp/policy/assets\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/saved_model/nested_structure_coder.py:458: UserWarning: Encoding a StructuredValue with type tfp.distributions.Deterministic_ACTTypeSpec; loading this StructuredValue will require that this type be imported and registered.\n", " warnings.warn(\"Encoding a StructuredValue with type %s; loading this \"\n", "INFO:tensorflow:Assets written to: /tmpfs/tmp/policy/assets\n" ] } ], "source": [ "tf_policy_saver.save(policy_dir)" ] }, { "cell_type": "markdown", "metadata": { "id": "Mz-xScbuh4Vo" }, "source": [ "The policy can be loaded without any knowledge of what agent or network was used to create it. This makes deployment of the policy much easier.\n", "\n", "Load the saved policy and check how it performs." ] }, { "cell_type": "code", "execution_count": 19, "metadata": { "execution": { "iopub.execute_input": "2023-12-22T12:17:59.026264Z", "iopub.status.busy": "2023-12-22T12:17:59.026032Z", "iopub.status.idle": "2023-12-22T12:18:00.086727Z", "shell.execute_reply": "2023-12-22T12:18:00.086041Z" }, "id": "J6T5KLTMh9ZB" }, "outputs": [ { "data": { "text/html": [ "" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "saved_policy = tf.saved_model.load(policy_dir)\n", "run_episodes_and_create_video(saved_policy, eval_env, eval_py_env)" ] }, { "cell_type": "markdown", "metadata": { "id": "MpE0KKfqjc0c" }, "source": [ "## Export and import\n", "The rest of the colab will help you export/import the checkpointer and policy directories so that you can continue training at a later point and deploy the model without having to train again.\n", "\n", "Now you can go back to 'Train one iteration' and train a few more times so that you can see the 
difference later on. Once you start to see slightly better results, continue below." ] }, { "cell_type": "code", "execution_count": 20, "metadata": { "execution": { "iopub.execute_input": "2023-12-22T12:18:00.090574Z", "iopub.status.busy": "2023-12-22T12:18:00.090314Z", "iopub.status.idle": "2023-12-22T12:18:00.095718Z", "shell.execute_reply": "2023-12-22T12:18:00.095059Z" }, "id": "fd5Cj7DVjfH4" }, "outputs": [], "source": [ "#@title Create zip file and upload zip file (double-click to see the code)\n", "def create_zip_file(dirname, base_filename):\n", " return shutil.make_archive(base_filename, 'zip', dirname)\n", "\n", "def upload_and_unzip_file_to(dirname):\n", " if files is None:\n", " return\n", " uploaded = files.upload()\n", " for fn in uploaded.keys():\n", " print('User uploaded file \"{name}\" with length {length} bytes'.format(\n", " name=fn, length=len(uploaded[fn])))\n", " shutil.rmtree(dirname)\n", " zip_files = zipfile.ZipFile(io.BytesIO(uploaded[fn]), 'r')\n", " zip_files.extractall(dirname)\n", " zip_files.close()" ] }, { "cell_type": "markdown", "metadata": { "id": "hgyy29doHCmL" }, "source": [ "Create a zipped file from the checkpoint directory." ] }, { "cell_type": "code", "execution_count": 21, "metadata": { "execution": { "iopub.execute_input": "2023-12-22T12:18:00.099119Z", "iopub.status.busy": "2023-12-22T12:18:00.098721Z", "iopub.status.idle": "2023-12-22T12:18:00.177200Z", "shell.execute_reply": "2023-12-22T12:18:00.176598Z" }, "id": "nhR8NeWzF4fe" }, "outputs": [], "source": [ "train_checkpointer.save(global_step)\n", "checkpoint_zip_filename = create_zip_file(checkpoint_dir, os.path.join(tempdir, 'exported_cp'))" ] }, { "cell_type": "markdown", "metadata": { "id": "VGEpntTocd2u" }, "source": [ "Download the zip file." 
] }, { "cell_type": "code", "execution_count": 22, "metadata": { "execution": { "iopub.execute_input": "2023-12-22T12:18:00.180644Z", "iopub.status.busy": "2023-12-22T12:18:00.180392Z", "iopub.status.idle": "2023-12-22T12:18:00.183846Z", "shell.execute_reply": "2023-12-22T12:18:00.183167Z" }, "id": "upFxb5k8b4MC" }, "outputs": [], "source": [ "#@test {\"skip\": true}\n", "if files is not None:\n", " files.download(checkpoint_zip_filename) # try again if this fails: https://github.com/googlecolab/colabtools/issues/469" ] }, { "cell_type": "markdown", "metadata": { "id": "VRaZMrn5jLmE" }, "source": [ "After training for some time (10-15 iterations), download the checkpoint zip file,\n", "go to \"Runtime > Restart and run all\" to reset the training,\n", "and come back to this cell. Now you can upload the downloaded zip file\n", "and continue the training." ] }, { "cell_type": "code", "execution_count": 23, "metadata": { "execution": { "iopub.execute_input": "2023-12-22T12:18:00.186822Z", "iopub.status.busy": "2023-12-22T12:18:00.186591Z", "iopub.status.idle": "2023-12-22T12:18:00.190037Z", "shell.execute_reply": "2023-12-22T12:18:00.189465Z" }, "id": "kg-bKgMsF-H_" }, "outputs": [], "source": [ "#@test {\"skip\": true}\n", "upload_and_unzip_file_to(checkpoint_dir)\n", "train_checkpointer.initialize_or_restore()\n", "global_step = tf.compat.v1.train.get_global_step()" ] }, { "cell_type": "markdown", "metadata": { "id": "uXrNax5Zk3vF" }, "source": [ "Once you have uploaded the checkpoint directory, go back to 'Train one iteration' to continue training, or go back to 'Generate a video' to check the performance of the loaded policy." ] }, { "cell_type": "markdown", "metadata": { "id": "OAkvVZ-NeN2j" }, "source": [ "Alternatively, you can save the policy (model) and restore it.\n", "Unlike the checkpointer, you cannot continue training, but you can still deploy the model. Note that the downloaded file is much smaller than that of the checkpointer."
] }, { "cell_type": "code", "execution_count": 24, "metadata": { "execution": { "iopub.execute_input": "2023-12-22T12:18:00.193096Z", "iopub.status.busy": "2023-12-22T12:18:00.192846Z", "iopub.status.idle": "2023-12-22T12:18:00.380151Z", "shell.execute_reply": "2023-12-22T12:18:00.379543Z" }, "id": "s7qMn6D8eiIA" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Assets written to: /tmpfs/tmp/policy/assets\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/saved_model/nested_structure_coder.py:458: UserWarning: Encoding a StructuredValue with type tfp.distributions.Deterministic_ACTTypeSpec; loading this StructuredValue will require that this type be imported and registered.\n", " warnings.warn(\"Encoding a StructuredValue with type %s; loading this \"\n", "INFO:tensorflow:Assets written to: /tmpfs/tmp/policy/assets\n" ] } ], "source": [ "tf_policy_saver.save(policy_dir)\n", "policy_zip_filename = create_zip_file(policy_dir, os.path.join(tempdir, 'exported_policy'))" ] }, { "cell_type": "code", "execution_count": 25, "metadata": { "execution": { "iopub.execute_input": "2023-12-22T12:18:00.383610Z", "iopub.status.busy": "2023-12-22T12:18:00.383342Z", "iopub.status.idle": "2023-12-22T12:18:00.386738Z", "shell.execute_reply": "2023-12-22T12:18:00.386133Z" }, "id": "rrGvCEXwerJj" }, "outputs": [], "source": [ "#@test {\"skip\": true}\n", "if files is not None:\n", " files.download(policy_zip_filename) # try again if this fails: https://github.com/googlecolab/colabtools/issues/469" ] }, { "cell_type": "markdown", "metadata": { "id": "DyC_O_gsgSi5" }, "source": [ "Upload the downloaded policy directory (exported_policy.zip) and check how the saved policy performs." 
] }, { "cell_type": "code", "execution_count": 26, "metadata": { "execution": { "iopub.execute_input": "2023-12-22T12:18:00.390351Z", "iopub.status.busy": "2023-12-22T12:18:00.389734Z", "iopub.status.idle": "2023-12-22T12:18:01.434242Z", "shell.execute_reply": "2023-12-22T12:18:01.433529Z" }, "id": "bgWLimRlXy5z" }, "outputs": [ { "data": { "text/html": [ "" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "#@test {\"skip\": true}\n", "upload_and_unzip_file_to(policy_dir)\n", "saved_policy = tf.saved_model.load(policy_dir)\n", "run_episodes_and_create_video(saved_policy, eval_env, eval_py_env)\n" ] }, { "cell_type": "markdown", "metadata": { "id": "HSehXThTm4af" }, "source": [ "## SavedModelPyTFEagerPolicy\n", "\n", "If you don't want to use a TF policy, you can also use the saved_model directly with the Python environment via `py_tf_eager_policy.SavedModelPyTFEagerPolicy`.\n", "\n", "Note that this only works when eager mode is enabled." ] }, { "cell_type": "code", "execution_count": 27, "metadata": { "execution": { "iopub.execute_input": "2023-12-22T12:18:01.437803Z", "iopub.status.busy": "2023-12-22T12:18:01.437550Z", "iopub.status.idle": "2023-12-22T12:18:02.477866Z", "shell.execute_reply": "2023-12-22T12:18:02.477183Z" }, "id": "iUC5XuLf1jF7" }, "outputs": [ { "data": { "text/html": [ "" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "eager_py_policy = py_tf_eager_policy.SavedModelPyTFEagerPolicy(\n", " policy_dir, eval_py_env.time_step_spec(), eval_py_env.action_spec())\n", "\n", "# Note that we're passing eval_py_env, not eval_env.\n", "run_episodes_and_create_video(eager_py_policy, eval_py_env, eval_py_env)" ] }, { "cell_type": "markdown", "metadata": { "id": "7fvWqfJg00ww" }, "source": [ "## Convert policy to TFLite\n", "\n", "See [TensorFlow Lite converter](https://www.tensorflow.org/lite/convert) for more details."
] }, { "cell_type": "code", "execution_count": 28, "metadata": { "execution": { "iopub.execute_input": "2023-12-22T12:18:02.481809Z", "iopub.status.busy": "2023-12-22T12:18:02.481493Z", "iopub.status.idle": "2023-12-22T12:18:02.721684Z", "shell.execute_reply": "2023-12-22T12:18:02.720887Z" }, "id": "z9zonVBJ0z46" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2023-12-22 12:18:02.611792: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:378] Ignored output_format.\n", "2023-12-22 12:18:02.611828: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:381] Ignored drop_control_dependency.\n", "Summary on the non-converted ops:\n", "---------------------------------\n", " * Accepted dialects: tfl, builtin, func\n", " * Non-Converted Ops: 11, Total Ops 28, % non-converted = 39.29 %\n", " * 11 ARITH ops\n", "\n", "- arith.constant: 11 occurrences (i64: 2, f32: 4, i32: 5)\n", "\n", "\n", "\n", " (i64: 1)\n", " (i32: 1)\n", " (i64: 1)\n", " (i32: 1)\n", " (f32: 2)\n", " (i64: 1)\n", " (i64: 1)\n", " (i64: 1, f32: 1)\n", " (i32: 2)\n", " (i32: 2)\n" ] } ], "source": [ "converter = tf.lite.TFLiteConverter.from_saved_model(policy_dir, signature_keys=[\"action\"])\n", "tflite_policy = converter.convert()\n", "with open(os.path.join(tempdir, 'policy.tflite'), 'wb') as f:\n", " f.write(tflite_policy)" ] }, { "cell_type": "markdown", "metadata": { "id": "rsi3V9QdxJUu" }, "source": [ "### Run inference on TFLite model\n", "\n", "See [TensorFlow Lite Inference](https://tensorflow.org/lite/guide/inference) for more details." 
] }, { "cell_type": "code", "execution_count": 29, "metadata": { "execution": { "iopub.execute_input": "2023-12-22T12:18:02.725083Z", "iopub.status.busy": "2023-12-22T12:18:02.724821Z", "iopub.status.idle": "2023-12-22T12:18:02.730269Z", "shell.execute_reply": "2023-12-22T12:18:02.729589Z" }, "id": "4GeUSWyZxMlN" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'0/discount': 1, '0/observation': 2, '0/reward': 3, '0/step_type': 0}\n" ] } ], "source": [ "import numpy as np\n", "interpreter = tf.lite.Interpreter(os.path.join(tempdir, 'policy.tflite'))\n", "\n", "policy_runner = interpreter.get_signature_runner()\n", "print(policy_runner._inputs)" ] }, { "cell_type": "code", "execution_count": 30, "metadata": { "execution": { "iopub.execute_input": "2023-12-22T12:18:02.733131Z", "iopub.status.busy": "2023-12-22T12:18:02.732895Z", "iopub.status.idle": "2023-12-22T12:18:02.738876Z", "shell.execute_reply": "2023-12-22T12:18:02.738263Z" }, "id": "eVVrdTbRxnOC" }, "outputs": [ { "data": { "text/plain": [ "{'action': array([1])}" ] }, "execution_count": 30, "metadata": {}, "output_type": "execute_result" } ], "source": [ "policy_runner(**{\n", " '0/discount':tf.constant(0.0),\n", " '0/observation':tf.zeros([1,4]),\n", " '0/reward':tf.constant(0.0),\n", " '0/step_type':tf.constant(0)})" ] } ], "metadata": { "colab": { "collapsed_sections": [], "name": "TF-Agent Checkpointer / PolicySaver Colab", "private_outputs": true, "provenance": [ { "file_id": "1soe3ixbJxESeOTxhVcGW1o9ZsU-uRVK7", "timestamp": 1641308536614 }, { "file_id": "https://github.com/tensorflow/agents/blob/master/docs/tutorials/10_checkpointer_policysaver_tutorial.ipynb", "timestamp": 1641307902610 }, { "file_id": "12InF1JXmpmA_qCgRScMO736YjqRoHWxT", "timestamp": 1627303299731 }, { "file_id": "https://github.com/tensorflow/agents/blob/master/docs/tutorials/10_checkpointer_policysaver_tutorial.ipynb", "timestamp": 1627302328422 } ], "toc_visible": true }, "kernelspec": { "display_name": 
"Python 3", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.18" } }, "nbformat": 4, "nbformat_minor": 0 }