{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "Tce3stUlHN0L" }, "source": [ "##### Copyright 2019 The TensorFlow Authors.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "id": "tuOe1ymfHZPu" }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "MfBg1C5NB3X0" }, "source": [ "# Custom training loop with Keras and MultiWorkerMirroredStrategy" ] },
{ "cell_type": "markdown", "metadata": { "id": "xHxb-dlhMIzW" }, "source": [ "## Overview\n", "\n", "This tutorial demonstrates how to perform multi-worker distributed training with a Keras model and with [custom training loops](https://www.tensorflow.org/guide/keras/writing_a_training_loop_from_scratch) using the `tf.distribute.Strategy` API. The training loop is distributed via `tf.distribute.MultiWorkerMirroredStrategy`, so that a `tf.keras` model designed to run on a [single worker](custom_training.ipynb) can seamlessly work on multiple workers with minimal code changes. Custom training loops provide flexibility and greater control over training, while also making it easier to debug the model. Learn more about [writing a basic training loop](../../guide/basic_training_loops.ipynb), [writing a training loop from scratch](https://www.tensorflow.org/guide/keras/writing_a_training_loop_from_scratch), and [custom training](../customization/custom_training_walkthrough.ipynb).\n", "\n", "If you are looking for how to use `MultiWorkerMirroredStrategy` with `tf.keras.Model.fit`, refer to the [Multi-worker training with Keras](multi_worker_with_keras.ipynb) tutorial instead.\n", "\n", "For a deeper understanding of the `tf.distribute.Strategy` APIs, the [Distributed training in TensorFlow](../../guide/distributed_training.ipynb) guide provides an overview of the distribution strategies that TensorFlow supports." ] },
{ "cell_type": "markdown", "metadata": { "id": "MUXex9ctTuDB" }, "source": [ "## Setup\n", "\n", "First, some necessary imports." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "bnYxvfLD-LW-" }, "outputs": [], "source": [ "import json\n", "import os\n", "import sys" ] }, { "cell_type": "markdown", "metadata": { "id": "Zz0EY91y3mxy" }, "source": [ "Before importing TensorFlow, make a few changes to the environment:\n", "* Disable all GPUs. This prevents errors caused by all workers trying to use the same GPU. 
In a real-world application, each worker would be on a different machine." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "685pbYEY3jGC" }, "outputs": [], "source": [ "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"-1\"" ] }, { "cell_type": "markdown", "metadata": { "id": "7X1MS6385BWi" }, "source": [ "* Reset the `'TF_CONFIG'` environment variable (you'll see more about this later)." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "WEJLYa2_7OZF" }, "outputs": [], "source": [ "os.environ.pop('TF_CONFIG', None)" ] }, { "cell_type": "markdown", "metadata": { "id": "Rd4L9Ii77SS8" }, "source": [ "* Make sure that the current directory is on Python's path. This allows the notebook to import the files written by `%%writefile` later.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "hPBuZUNSZmrQ" }, "outputs": [], "source": [ "if '.' not in sys.path:\n", " sys.path.insert(0, '.')" ] }, { "cell_type": "markdown", "metadata": { "id": "pDhHuMjb7bfU" }, "source": [ "Now import TensorFlow." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "vHNvttzV43sA" }, "outputs": [], "source": [ "import tensorflow as tf" ] }, { "cell_type": "markdown", "metadata": { "id": "0S2jpf6Sx50i" }, "source": [ "### Dataset and model definition" ] }, { "cell_type": "markdown", "metadata": { "id": "fLW6D2TzvC-4" }, "source": [ "Next, create an `mnist.py` file with a simple model and dataset setup. 
This Python file will be used by the worker processes in this tutorial:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "dma_wUAxZqo2" }, "outputs": [], "source": [ "%%writefile mnist.py\n", "\n", "import os\n", "import tensorflow as tf\n", "import numpy as np\n", "\n", "def mnist_dataset(batch_size):\n", "  (x_train, y_train), _ = tf.keras.datasets.mnist.load_data()\n", "  # The `x` arrays are in uint8 and have values in the range [0, 255].\n", "  # You need to convert them to float32 with values in the range [0, 1].\n", "  x_train = x_train / np.float32(255)\n", "  y_train = y_train.astype(np.int64)\n", "  train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))\n", "  return train_dataset\n", "\n", "def dataset_fn(global_batch_size, input_context):\n", "  batch_size = input_context.get_per_replica_batch_size(global_batch_size)\n", "  dataset = mnist_dataset(batch_size)\n", "  # Shard before shuffling: each input pipeline shuffles independently, so\n", "  # sharding after an unseeded shuffle could give workers overlapping subsets.\n", "  dataset = dataset.shard(input_context.num_input_pipelines,\n", "                          input_context.input_pipeline_id)\n", "  dataset = dataset.shuffle(60000).batch(batch_size)\n", "  return dataset\n", "\n", "def build_cnn_model():\n", "  regularizer = tf.keras.regularizers.L2(1e-5)\n", "  return tf.keras.Sequential([\n", "      tf.keras.Input(shape=(28, 28)),\n", "      tf.keras.layers.Reshape(target_shape=(28, 28, 1)),\n", "      tf.keras.layers.Conv2D(32, 3,\n", "                             activation='relu',\n", "                             kernel_regularizer=regularizer),\n", "      tf.keras.layers.Flatten(),\n", "      tf.keras.layers.Dense(128,\n", "                            activation='relu',\n", "                            kernel_regularizer=regularizer),\n", "      tf.keras.layers.Dense(10, kernel_regularizer=regularizer)\n", "  ])" ] },
{ "cell_type": "markdown", "metadata": { "id": "JmgZwwymxqt5" }, "source": [ "## Multi-worker configuration\n", "\n", "Now let's enter the world of multi-worker training. In TensorFlow, the `'TF_CONFIG'` environment variable is required for training on multiple machines. Each machine may have a different role. 
The `'TF_CONFIG'` variable used below is a JSON string that specifies the cluster configuration on each worker that is part of the cluster. This is the default method for specifying a cluster, using `tf.distribute.cluster_resolver.TFConfigClusterResolver`, but there are other options available in the `tf.distribute.cluster_resolver` module. Learn more about setting up the `'TF_CONFIG'` variable in the [Distributed training guide](../../guide/distributed_training.ipynb)." ] }, { "cell_type": "markdown", "metadata": { "id": "SS8WhvRhe_Ya" }, "source": [ "### Describe your cluster\n", "\n", "Here is an example configuration:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "XK1eTYvSZiX7" }, "outputs": [], "source": [ "tf_config = {\n", "    'cluster': {\n", "        'worker': ['localhost:12345', 'localhost:23456']\n", "    },\n", "    'task': {'type': 'worker', 'index': 0}\n", "}" ] },
{ "cell_type": "markdown", "metadata": { "id": "JjgwJbPKZkJL" }, "source": [ "Note that `tf_config` is just a local variable in Python. To use it for training configuration, serialize it as JSON and place it in the `'TF_CONFIG'` environment variable. Here is the same `'TF_CONFIG'` serialized as a JSON string:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "yY-T0YDQZjbu" }, "outputs": [], "source": [ "json.dumps(tf_config)" ] },
{ "cell_type": "markdown", "metadata": { "id": "AUBmYRZqxthH" }, "source": [ "There are two components of `'TF_CONFIG'`: `'cluster'` and `'task'`.\n", "\n", "* `'cluster'` is the same for all workers and provides information about the training cluster, which is a dict consisting of different types of jobs, such as `'worker'`. In multi-worker training with `MultiWorkerMirroredStrategy`, there is usually one `'worker'` that takes on a little more responsibility, such as saving checkpoints and writing summary files for TensorBoard, in addition to what a regular `'worker'` does. 
Such a worker is referred to as the `'chief'` worker, and it is customary that the `'worker'` with `'index'` 0 is appointed as the chief worker.\n", "\n", "* `'task'` provides information about the current task and is different on each worker. It specifies the `'type'` and `'index'` of that worker." ] }, { "cell_type": "markdown", "metadata": { "id": "8YFpxrcsZ2xG" }, "source": [ "In this example, you set the task `'type'` to `'worker'` and the task `'index'` to `0`. This machine is the first worker, so it is appointed as the chief worker and does more work than the others. Note that the other machines need to have the `'TF_CONFIG'` environment variable set as well, with the same `'cluster'` dict but a different task `'type'` or task `'index'`, depending on the roles of those machines.\n" ] },
{ "cell_type": "markdown", "metadata": { "id": "aogb74kHxynz" }, "source": [ "For illustration purposes, this tutorial shows how to set a `'TF_CONFIG'` with two workers on `'localhost'`. In practice, you would create multiple workers on external IP addresses/ports and set `'TF_CONFIG'` on each worker appropriately.\n", "\n", "This example uses two workers. The first worker's `'TF_CONFIG'` is shown above. For the second worker, set `tf_config['task']['index']=1`." ] }, { "cell_type": "markdown", "metadata": { "id": "cIlkfWmjz1PG" }, "source": [ "### Environment variables and subprocesses in notebooks" ] }, { "cell_type": "markdown", "metadata": { "id": "FcjAbuGY1ACJ" }, "source": [ "Subprocesses inherit environment variables from their parent. 
So if you set an environment variable in this Jupyter notebook process:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "PH2gHn2_0_U8" }, "outputs": [], "source": [ "os.environ['GREETINGS'] = 'Hello TensorFlow!'" ] }, { "cell_type": "markdown", "metadata": { "id": "gQkIX-cg18md" }, "source": [ "you can then access the environment variable from a subprocess:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "pquKO6IA18G5" }, "outputs": [], "source": [ "%%bash\n", "echo ${GREETINGS}" ] }, { "cell_type": "markdown", "metadata": { "id": "af6BCA-Y2fpz" }, "source": [ "In the next section, you'll use this to pass the `'TF_CONFIG'` to the worker subprocesses. You would never really launch your jobs this way, but it's sufficient for the purpose of this tutorial: to demonstrate a minimal multi-worker example." ] },
{ "cell_type": "markdown", "metadata": { "id": "UhNtHfuxCGVy" }, "source": [ "## MultiWorkerMirroredStrategy\n", "\n", "Before training the model, first create an instance of `tf.distribute.MultiWorkerMirroredStrategy`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "1uFSHCJXMrQ-" }, "outputs": [], "source": [ "strategy = tf.distribute.MultiWorkerMirroredStrategy()" ] }, { "cell_type": "markdown", "metadata": { "id": "N0iv7SyyAohc" }, "source": [ "Note: `'TF_CONFIG'` is parsed and TensorFlow's gRPC servers are started when you call `tf.distribute.MultiWorkerMirroredStrategy`. Therefore, you must set the `'TF_CONFIG'` environment variable before you instantiate a `tf.distribute.Strategy`. To save time, this illustrative example does not set `'TF_CONFIG'` before creating the strategy, so no servers need to start. You can find a full example in the last section of this tutorial." ] }, { "cell_type": "markdown", "metadata": { "id": "TS4S-faBHHam" }, "source": [ "Use `tf.distribute.Strategy.scope` to specify that a strategy should be used when building your model. 
This allows the strategy to control things like variable placement: it creates copies of all variables in the model's layers on each device across all workers." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "nXV49tG1_opc" }, "outputs": [], "source": [ "import mnist\n", "with strategy.scope():\n", "  # Model building needs to be within `strategy.scope()`.\n", "  multi_worker_model = mnist.build_cnn_model()" ] }, { "cell_type": "markdown", "metadata": { "id": "DSYkM-on6r3Y" }, "source": [ "## Shard your data across workers\n", "\n", "In multi-worker training, _dataset sharding_ is needed to ensure convergence and reproducibility. Sharding means handing each worker a subset of the entire dataset, which helps create an experience similar to training on a single worker. The example below uses `tf.distribute.Strategy.distribute_datasets_from_function`, which does not shard the data automatically. Instead, `mnist.dataset_fn` shards it explicitly, using the `num_input_pipelines` and `input_pipeline_id` of the `tf.distribute.InputContext` that it receives. If you instead distribute a dataset with `tf.distribute.Strategy.experimental_distribute_dataset`, `tf.distribute` auto-shards it, and you can customize this by setting the `tf.data.experimental.AutoShardPolicy` of the `tf.data.experimental.DistributeOptions`. To learn more, refer to the _Sharding_ section of the [Distributed input tutorial](input.ipynb)."
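, "\n",
"\n",
"For comparison, here is a minimal sketch of customizing the auto-sharding policy. This applies when you pass a `tf.data.Dataset` to `tf.distribute.Strategy.experimental_distribute_dataset`, not when you use `distribute_datasets_from_function` as this tutorial does. The sketch assumes the `strategy` and `global_batch_size` defined in this notebook. `tf.data.Dataset.range(1024)` is only a stand-in for real training data:\n",
"\n",
"```python\n",
"# A stand-in dataset. With `experimental_distribute_dataset`, batch by the\n",
"# global batch size; `tf.distribute` then splits each batch across replicas.\n",
"dataset = tf.data.Dataset.range(1024).batch(global_batch_size)\n",
"\n",
"options = tf.data.Options()\n",
"# Shard by element (DATA) instead of by input file. The default policy, AUTO,\n",
"# tries FILE-based sharding first and falls back to DATA when the dataset is\n",
"# not read from files, as is the case here.\n",
"options.experimental_distribute.auto_shard_policy = (\n",
"    tf.data.experimental.AutoShardPolicy.DATA)\n",
"dataset = dataset.with_options(options)\n",
"\n",
"dist_dataset = strategy.experimental_distribute_dataset(dataset)\n",
"```\n",
"\n",
"With `DATA` sharding, each worker still processes the whole dataset and discards the elements that belong to other workers. This policy is simple, but it does extra input work."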
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "65-p36pt6rUF" }, "outputs": [], "source": [ "per_worker_batch_size = 64\n", "num_workers = len(tf_config['cluster']['worker'])\n", "global_batch_size = per_worker_batch_size * num_workers\n", "\n", "with strategy.scope():\n", "  multi_worker_dataset = strategy.distribute_datasets_from_function(\n", "      lambda input_context: mnist.dataset_fn(global_batch_size, input_context))" ] }, { "cell_type": "markdown", "metadata": { "id": "rkNzSR3g60iP" }, "source": [ "## Define a custom training loop and train the model\n", "\n", "Specify an optimizer:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "NoMr4_zTeKSn" }, "outputs": [], "source": [ "with strategy.scope():\n", "  # The creation of optimizer and train_accuracy needs to be in\n", "  # `strategy.scope()` as well, since they create variables.\n", "  optimizer = tf.keras.optimizers.RMSprop(learning_rate=0.001)\n", "  train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(\n", "      name='train_accuracy')" ] }, { "cell_type": "markdown", "metadata": { "id": "RmrDcAii4B5O" }, "source": [ "Define a training step with `tf.function`:\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "znXWN5S3eUDB" }, "outputs": [], "source": [ "@tf.function\n", "def train_step(iterator):\n", "  \"\"\"Training step function.\"\"\"\n", "\n", "  def step_fn(inputs):\n", "    \"\"\"Per-replica step function.\"\"\"\n", "    x, y = inputs\n", "    with tf.GradientTape() as tape:\n", "      predictions = multi_worker_model(x, training=True)\n", "      per_example_loss = tf.keras.losses.SparseCategoricalCrossentropy(\n", "          from_logits=True,\n", "          reduction=tf.keras.losses.Reduction.NONE)(y, predictions)\n", "      # `compute_average_loss` divides by the global batch size and\n", "      # `scale_regularization_loss` by the number of replicas, so summing the\n", "      # per-replica losses below gives the loss for the whole global batch.\n", "      loss = tf.nn.compute_average_loss(per_example_loss)\n", "      model_losses = multi_worker_model.losses\n", "      if model_losses:\n", "        loss += tf.nn.scale_regularization_loss(tf.add_n(model_losses))\n", "\n", "    grads = tape.gradient(loss, multi_worker_model.trainable_variables)\n", "    optimizer.apply_gradients(\n", "        zip(grads, multi_worker_model.trainable_variables))\n", "    train_accuracy.update_state(y, predictions)\n", "    return loss\n", "\n", "  per_replica_losses = strategy.run(step_fn, args=(next(iterator),))\n", "  return strategy.reduce(\n", "      tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)" ] },
{ "cell_type": "markdown", "metadata": { "id": "eFXHsUVBy0Rx" }, "source": [ "### Checkpoint saving and restoring\n", "\n", "As you write a custom training loop, you need to handle [checkpoint saving](../../guide/checkpoint.ipynb) manually instead of relying on a Keras callback. Note that for `MultiWorkerMirroredStrategy`, saving a checkpoint or a complete model requires the participation of all workers, because attempting to save only on the chief worker could lead to a deadlock. Workers also need to write to different paths to avoid overwriting each other. Here's an example of how to configure the directories:" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "LcFO6x1KyjhI" }, "outputs": [], "source": [ "from multiprocessing import util\n", "checkpoint_dir = os.path.join(util.get_temp_dir(), 'ckpt')\n", "\n", "def _is_chief(task_type, task_id, cluster_spec):\n", "  return (task_type is None\n", "          or task_type == 'chief'\n", "          or (task_type == 'worker'\n", "              and task_id == 0\n", "              and \"chief\" not in cluster_spec.as_dict()))\n", "\n", "def _get_temp_dir(dirpath, task_id):\n", "  base_dirpath = 'workertemp_' + str(task_id)\n", "  temp_dir = os.path.join(dirpath, base_dirpath)\n", "  tf.io.gfile.makedirs(temp_dir)\n", "  return temp_dir\n", "\n", "def write_filepath(filepath, task_type, task_id, cluster_spec):\n", "  dirpath = os.path.dirname(filepath)\n", "  base = os.path.basename(filepath)\n", "  if not _is_chief(task_type, task_id, cluster_spec):\n", "    dirpath = _get_temp_dir(dirpath, task_id)\n", "  return os.path.join(dirpath, base)" ] },
{ "cell_type": "markdown", "metadata": { "id": "nrcdPHtG4ObO" }, "source": [ "Create a `tf.train.Checkpoint` that tracks the model and the `epoch` and `step_in_epoch` counters, and manage it with a `tf.train.CheckpointManager` so that only the latest checkpoint is preserved (`max_to_keep=1`):" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "4rURT2pI4aqV" }, "outputs": [], "source": [ "epoch = tf.Variable(\n", "    initial_value=tf.constant(0, dtype=tf.dtypes.int64), name='epoch')\n", "step_in_epoch = tf.Variable(\n", "    initial_value=tf.constant(0, dtype=tf.dtypes.int64),\n", "    name='step_in_epoch')\n", "task_type, task_id = (strategy.cluster_resolver.task_type,\n", "                      strategy.cluster_resolver.task_id)\n", "# Normally, you don't need to manually instantiate a `ClusterSpec`, but in this\n", "# illustrative example you did not set `'TF_CONFIG'` before initializing the\n", "# strategy. Check out the next section for \"real-world\" usage.\n", "cluster_spec = tf.train.ClusterSpec(tf_config['cluster'])\n", "\n", "checkpoint = tf.train.Checkpoint(\n", "    model=multi_worker_model, epoch=epoch, step_in_epoch=step_in_epoch)\n", "\n", "write_checkpoint_dir = write_filepath(checkpoint_dir, task_type, task_id,\n", "                                      cluster_spec)\n", "checkpoint_manager = tf.train.CheckpointManager(\n", "    checkpoint, directory=write_checkpoint_dir, max_to_keep=1)" ] },
{ "cell_type": "markdown", "metadata": { "id": "RO7cbN40XD5v" }, "source": [ "Now, when you need to restore a checkpoint, you can find the latest saved checkpoint with the `tf.train.latest_checkpoint` function. Because only the chief worker writes to `checkpoint_dir`, every worker restores from there. (`tf.train.CheckpointManager.restore_or_initialize` is an alternative, but it only looks in the manager's own `directory`, which for non-chief workers is a temporary directory that the training loop deletes after each save.)"
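, "\n",
"\n",
"To see where each worker writes, here is a small sketch that reuses the `tf_config`, `checkpoint_dir`, and `write_filepath` defined above. The variable name `example_cluster_spec` is only illustrative. For the two-worker cluster in this example, worker 0 (the chief) writes directly to `checkpoint_dir`. Worker 1 writes to `<parent>/workertemp_1/ckpt`, where `<parent>` is the directory that contains `checkpoint_dir`:\n",
"\n",
"```python\n",
"example_cluster_spec = tf.train.ClusterSpec(tf_config['cluster'])\n",
"for index in range(len(tf_config['cluster']['worker'])):\n",
"  # For non-chief workers, `write_filepath` also creates the temporary\n",
"  # directory as a side effect.\n",
"  path = write_filepath(checkpoint_dir, 'worker', index, example_cluster_spec)\n",
"  print(index, path)\n",
"```"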
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "gniynaQj6HMV" }, "outputs": [], "source": [ "latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)\n", "if latest_checkpoint:\n", "  checkpoint.restore(latest_checkpoint)" ] }, { "cell_type": "markdown", "metadata": { "id": "1j9JuI-h6ObW" }, "source": [ "After restoring the checkpoint, you can continue training with your custom training loop." ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "kZzXZCh45FY6" }, "outputs": [], "source": [ "num_epochs = 3\n", "num_steps_per_epoch = 70\n", "\n", "while epoch.numpy() < num_epochs:\n", "  iterator = iter(multi_worker_dataset)\n", "  total_loss = 0.0\n", "  num_batches = 0\n", "\n", "  while step_in_epoch.numpy() < num_steps_per_epoch:\n", "    total_loss += train_step(iterator)\n", "    num_batches += 1\n", "    step_in_epoch.assign_add(1)\n", "\n", "  train_loss = total_loss / num_batches\n", "  print('Epoch: %d, accuracy: %f, train_loss: %f.'\n", "        %(epoch.numpy(), train_accuracy.result(), train_loss))\n", "\n", "  train_accuracy.reset_states()\n", "\n", "  # Advance the counters before saving, so that a restored checkpoint resumes\n", "  # at the start of the next epoch rather than at the end of a finished one.\n", "  epoch.assign_add(1)\n", "  step_in_epoch.assign(0)\n", "\n", "  # Every worker calls `save`; non-chief workers then remove the checkpoints\n", "  # they wrote to their temporary directories.\n", "  checkpoint_manager.save()\n", "  if not _is_chief(task_type, task_id, cluster_spec):\n", "    tf.io.gfile.rmtree(write_checkpoint_dir)" ] },
{ "cell_type": "markdown", "metadata": { "id": "0W1Osks466DE" }, "source": [ "## Complete code at a glance" ] }, { "cell_type": "markdown", "metadata": { "id": "jfYpmIxO6Jck" }, "source": [ "To sum up all the procedures discussed so far:\n", "\n", "1. Create worker processes.\n", "2. Pass a `'TF_CONFIG'` to each worker process.\n", "3. Let each worker process run the script below, which contains the training code."
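, "\n",
"\n",
"Outside a notebook, these three steps could be sketched with a small plain-Python launcher like the one below. It is an illustrative alternative to the `%%bash` cells used later in this tutorial, not part of the tutorial's workflow. It assumes that `main.py` (written below) is in the current directory and that both `localhost` ports in the example cluster are free. The `job_<index>.log` file names simply follow the `job_0.log` convention used later:\n",
"\n",
"```python\n",
"import json\n",
"import os\n",
"import subprocess\n",
"import sys\n",
"\n",
"cluster = {'worker': ['localhost:12345', 'localhost:23456']}\n",
"\n",
"processes, logs = [], []\n",
"for index in range(len(cluster['worker'])):\n",
"  # Steps 1 and 2: one process per worker, each with the same 'cluster'\n",
"  # but its own 'task' index in `TF_CONFIG`.\n",
"  env = dict(os.environ)\n",
"  env['TF_CONFIG'] = json.dumps(\n",
"      {'cluster': cluster, 'task': {'type': 'worker', 'index': index}})\n",
"  log = open(f'job_{index}.log', 'w')\n",
"  logs.append(log)\n",
"  # Step 3: every worker runs the same training script.\n",
"  processes.append(subprocess.Popen(\n",
"      [sys.executable, 'main.py'], env=env,\n",
"      stdout=log, stderr=subprocess.STDOUT))\n",
"\n",
"# Training starts once all workers are up; wait for each one to finish.\n",
"for process, log in zip(processes, logs):\n",
"  process.wait()\n",
"  log.close()\n",
"```\n",
"\n",
"Because each child's environment is copied from the parent process, settings made earlier in this notebook, such as `CUDA_VISIBLE_DEVICES`, are inherited as well."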
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "MIDCESkVzN6M" }, "outputs": [], "source": [ "%%writefile main.py\n", "#@title File: `main.py`\n", "import os\n", "import json\n", "import tensorflow as tf\n", "import mnist\n", "from multiprocessing import util\n", "\n", "per_worker_batch_size = 64\n", "tf_config = json.loads(os.environ['TF_CONFIG'])\n", "num_workers = len(tf_config['cluster']['worker'])\n", "global_batch_size = per_worker_batch_size * num_workers\n", "\n", "num_epochs = 3\n", "num_steps_per_epoch = 70\n", "\n", "# Checkpoint saving and restoring\n", "def _is_chief(task_type, task_id, cluster_spec):\n", "  return (task_type is None\n", "          or task_type == 'chief'\n", "          or (task_type == 'worker'\n", "              and task_id == 0\n", "              and 'chief' not in cluster_spec.as_dict()))\n", "\n", "def _get_temp_dir(dirpath, task_id):\n", "  base_dirpath = 'workertemp_' + str(task_id)\n", "  temp_dir = os.path.join(dirpath, base_dirpath)\n", "  tf.io.gfile.makedirs(temp_dir)\n", "  return temp_dir\n", "\n", "def write_filepath(filepath, task_type, task_id, cluster_spec):\n", "  dirpath = os.path.dirname(filepath)\n", "  base = os.path.basename(filepath)\n", "  if not _is_chief(task_type, task_id, cluster_spec):\n", "    dirpath = _get_temp_dir(dirpath, task_id)\n", "  return os.path.join(dirpath, base)\n", "\n", "checkpoint_dir = os.path.join(util.get_temp_dir(), 'ckpt')\n", "\n", "# Define Strategy\n", "strategy = tf.distribute.MultiWorkerMirroredStrategy()\n", "\n", "with strategy.scope():\n", "  # Model building/compiling needs to be within `tf.distribute.Strategy.scope`.\n", "  multi_worker_model = mnist.build_cnn_model()\n", "\n", "  multi_worker_dataset = strategy.distribute_datasets_from_function(\n", "      lambda input_context: mnist.dataset_fn(global_batch_size, input_context))\n", "  optimizer = tf.keras.optimizers.RMSprop(learning_rate=0.001)\n", "  train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(\n", "      name='train_accuracy')\n", "\n",
"@tf.function\n", "def train_step(iterator):\n", "  \"\"\"Training step function.\"\"\"\n", "\n", "  def step_fn(inputs):\n", "    \"\"\"Per-replica step function.\"\"\"\n", "    x, y = inputs\n", "    with tf.GradientTape() as tape:\n", "      predictions = multi_worker_model(x, training=True)\n", "      per_example_loss = tf.keras.losses.SparseCategoricalCrossentropy(\n", "          from_logits=True,\n", "          reduction=tf.keras.losses.Reduction.NONE)(y, predictions)\n", "      loss = tf.nn.compute_average_loss(per_example_loss)\n", "      model_losses = multi_worker_model.losses\n", "      if model_losses:\n", "        loss += tf.nn.scale_regularization_loss(tf.add_n(model_losses))\n", "\n", "    grads = tape.gradient(loss, multi_worker_model.trainable_variables)\n", "    optimizer.apply_gradients(\n", "        zip(grads, multi_worker_model.trainable_variables))\n", "    train_accuracy.update_state(y, predictions)\n", "\n", "    return loss\n", "\n", "  per_replica_losses = strategy.run(step_fn, args=(next(iterator),))\n", "  return strategy.reduce(\n", "      tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)\n", "\n",
"epoch = tf.Variable(\n", "    initial_value=tf.constant(0, dtype=tf.dtypes.int64), name='epoch')\n", "step_in_epoch = tf.Variable(\n", "    initial_value=tf.constant(0, dtype=tf.dtypes.int64),\n", "    name='step_in_epoch')\n", "\n", "task_type, task_id, cluster_spec = (strategy.cluster_resolver.task_type,\n", "                                    strategy.cluster_resolver.task_id,\n", "                                    strategy.cluster_resolver.cluster_spec())\n", "\n", "checkpoint = tf.train.Checkpoint(\n", "    model=multi_worker_model, epoch=epoch, step_in_epoch=step_in_epoch)\n", "\n", "write_checkpoint_dir = write_filepath(checkpoint_dir, task_type, task_id,\n", "                                      cluster_spec)\n", "checkpoint_manager = tf.train.CheckpointManager(\n", "    checkpoint, directory=write_checkpoint_dir, max_to_keep=1)\n", "\n", "# Restore the latest checkpoint, if there is one\n", "latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)\n", "if latest_checkpoint:\n", "  checkpoint.restore(latest_checkpoint)\n", "\n",
"# Resume the custom training loop (CTL)\n", "while epoch.numpy() < num_epochs:\n", "  iterator = iter(multi_worker_dataset)\n", "  total_loss = 0.0\n", "  num_batches = 0\n", "\n", "  while step_in_epoch.numpy() < num_steps_per_epoch:\n", "    total_loss += train_step(iterator)\n", "    num_batches += 1\n", "    step_in_epoch.assign_add(1)\n", "\n", "  train_loss = total_loss / num_batches\n", "  print('Epoch: %d, accuracy: %f, train_loss: %f.'\n", "        %(epoch.numpy(), train_accuracy.result(), train_loss))\n", "\n", "  train_accuracy.reset_states()\n", "\n", "  # Advance the counters before saving, so that a restored checkpoint resumes\n", "  # at the start of the next epoch.\n", "  epoch.assign_add(1)\n", "  step_in_epoch.assign(0)\n", "\n", "  checkpoint_manager.save()\n", "  if not _is_chief(task_type, task_id, cluster_spec):\n", "    tf.io.gfile.rmtree(write_checkpoint_dir)" ] },
{ "cell_type": "markdown", "metadata": { "id": "ItVOvPN1qnZ6" }, "source": [ "The current directory now contains both Python files:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "bi6x05Sr60O9" }, "outputs": [], "source": [ "%%bash\n", "ls *.py" ] }, { "cell_type": "markdown", "metadata": { "id": "qmEEStPS6vR_" }, "source": [ "Next, JSON-serialize the `'TF_CONFIG'` and add it to the environment variables:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "9uu3g7vV7Bbt" }, "outputs": [], "source": [ "os.environ['TF_CONFIG'] = json.dumps(tf_config)" ] }, { "cell_type": "markdown", "metadata": { "id": "MsY3dQLK7jdf" }, "source": [ "Now, you can launch a worker process that will run `main.py` and use the `'TF_CONFIG'`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "txMXaq8d8N_S" }, "outputs": [], "source": [ "# First, kill any previous runs.\n", "%killbgscripts" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "qnSma_Ck7r-r" }, "outputs": [], "source": [ "%%bash --bg\n", "python main.py &> job_0.log" ] }, { "cell_type": "markdown", "metadata": { "id": "ZChyazqS7v0P" }, "source": [ "There are a few things to note about the above command:\n", "\n", "1. 
It uses `%%bash`, which is a [notebook \"magic\"](https://ipython.readthedocs.io/en/stable/interactive/magics.html), to run some bash commands.\n", "2. It uses the `--bg` flag to run the `bash` process in the background, because this worker will not terminate on its own: it waits for all the workers to join before it starts training.\n", "\n", "The backgrounded worker process won't print its output to this notebook. The `&>` redirects its output to a file, so that you can inspect what happened.\n", "\n", "Wait a few seconds for the process to start up:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Hm2yrULE9281" }, "outputs": [], "source": [ "import time\n", "time.sleep(20)" ] }, { "cell_type": "markdown", "metadata": { "id": "ZFPoNxg_9_Mx" }, "source": [ "Now, check the output in the worker's log file so far:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "vZEOuVgQ9-hn" }, "outputs": [], "source": [ "%%bash\n", "cat job_0.log" ] }, { "cell_type": "markdown", "metadata": { "id": "RqZhVF7L_KOy" }, "source": [ "The last line of the log file should say: `Started server with target: grpc://localhost:12345`. The first worker is now ready and is waiting for all the other worker(s) to be ready before proceeding." ] }, { "cell_type": "markdown", "metadata": { "id": "Pi8vPNNA_l4a" }, "source": [ "Update the `tf_config` for the second worker's process to pick up:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "lAiYkkPu_Jqd" }, "outputs": [], "source": [ "tf_config['task']['index'] = 1\n", "os.environ['TF_CONFIG'] = json.dumps(tf_config)" ] }, { "cell_type": "markdown", "metadata": { "id": "0AshGVO0_x0w" }, "source": [ "Now launch the second worker. 
This starts the training, since all the workers are now active (so there's no need to background this process):" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "_ESVtyQ9_xjx" }, "outputs": [], "source": [ "%%bash\n", "python main.py > /dev/null 2>&1" ] }, { "cell_type": "markdown", "metadata": { "id": "hX4FA2O2AuAn" }, "source": [ "If you recheck the logs written by the first worker, you'll notice that it participated in training the model:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "rc6hw3yTBKXX" }, "outputs": [], "source": [ "%%bash\n", "cat job_0.log" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "sG5_1UgrgniF" }, "outputs": [], "source": [ "# Delete the `'TF_CONFIG'`, and kill any background tasks so they don't affect the next section.\n", "os.environ.pop('TF_CONFIG', None)\n", "%killbgscripts" ] }, { "cell_type": "markdown", "metadata": { "id": "bhxMXa0AaZkK" }, "source": [ "## Multi-worker training in depth\n", "\n", "This tutorial has demonstrated a custom training loop workflow for a multi-worker setup. Detailed descriptions of other topics, many of which also apply to custom training loops, are available in the [Multi-worker training with Keras (`tf.keras.Model.fit`)](multi_worker_with_keras.ipynb) tutorial." ] }, { "cell_type": "markdown", "metadata": { "id": "ega2hdOQEmy_" }, "source": [ "## Learn more\n", "\n", "1. The [Distributed training in TensorFlow](../../guide/distributed_training.ipynb) guide provides an overview of the available distribution strategies.\n", "2. The [Official models](https://github.com/tensorflow/models/tree/master/official), many of which can be configured to run with multiple distribution strategies.\n", "3. 
The [Performance section](../../guide/function.ipynb) in the `tf.function` guide provides information about other strategies and [tools](../../guide/profiler.md) you can use to optimize the performance of your TensorFlow models.\n" ] } ], "metadata": { "colab": { "collapsed_sections": [], "name": "multi_worker_with_ctl.ipynb", "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 0 }