Supports multiple batches per training loop in `tf_agents.bandits.examples.v2.trainer.train`

This change introduces a new configurable parameter, `ASYNC_STEPS_PER_LOOP`, to the bandits benchmark code. The parameter is an integer that configures the training loop as follows: within a single training loop iteration, the driver first runs this many times, with metrics exported after every run; the agent is then trained over this many batches sampled from the replay buffer, with loss info exported after every batch. This enables more accurate simulation of offline training, where the policy is updated at intervals rather than continuously.

Note that the total number of train steps will be `TRAINING_LOOPS * ASYNC_STEPS_PER_LOOP`.
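
For illustration, a condensed sketch of the per-iteration control flow this change produces (schematic only, assuming the driver, agent, replay buffer, metrics, and export helpers are constructed as in `trainer.py`):

  # Sketch of one outer training loop iteration; not the literal trainer code.
  for i in range(TRAINING_LOOPS):
    # Phase 1: the driver runs ASYNC_STEPS_PER_LOOP times, with metrics
    # exported after every run.
    for batch_id in range(ASYNC_STEPS_PER_LOOP):
      driver.run()
      export_metrics(step=i * ASYNC_STEPS_PER_LOOP + batch_id, metrics=metrics)
    # Phase 2: the agent trains on ASYNC_STEPS_PER_LOOP batches sampled from
    # the replay buffer, with loss info exported after every batch.
    for batch_id in range(ASYNC_STEPS_PER_LOOP):
      experience = next(replay_buffer_iterator)  # hypothetical iterator name
      loss_info = agent.train(experience)
      export_loss_info(
          step=i * ASYNC_STEPS_PER_LOOP + batch_id, loss_info=loss_info)
    replay_buffer.clear()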

PiperOrigin-RevId: 481989490
Change-Id: Ib4fbb8db0257f775fa4a9d80a79f42a987f6ca08
TF-Agents Team authored and Copybara-Service committed Oct 18, 2022
1 parent 4935903 commit 3a4c0f2
Showing 3 changed files with 97 additions and 38 deletions.
91 changes: 63 additions & 28 deletions tf_agents/bandits/agents/examples/v2/trainer.py
@@ -37,10 +37,26 @@
CHECKPOINT_FILE_PREFIX = 'ckpt'


def get_replay_buffer(data_spec, batch_size, steps_per_loop):
def export_metrics(step, metrics):
"""Exports metrics."""
metric_utils.log_metrics(metrics)
export_utils.export_metrics(step=step, metrics=metrics)
for metric in metrics:
metric.tf_summaries(train_step=step)


def export_loss_info(step, loss_info):
"""Exports loss info."""
export_utils.export_metrics(step=step, metrics=[], loss_info=loss_info)


def get_replay_buffer(data_spec, batch_size, steps_per_loop,
async_steps_per_loop):
"""Return a `TFUniformReplayBuffer` for the given `agent`."""
return bandit_replay_buffer.BanditReplayBuffer(
data_spec=data_spec, batch_size=batch_size, max_length=steps_per_loop)
data_spec=data_spec,
batch_size=batch_size,
max_length=steps_per_loop * async_steps_per_loop)
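
The enlarged `max_length` is the capacity needed to retain every trajectory collected before the agent trains at the end of the loop; a quick worked example with illustrative values:

  # Illustrative values, not from this diff: 3 driver runs of 2 steps each
  # must all fit in the buffer, so it needs 6 time steps per environment.
  steps_per_loop = 2
  async_steps_per_loop = 3
  max_length = steps_per_loop * async_steps_per_loop
  assert max_length == 6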


def set_expected_shape(experience, num_steps):
@@ -56,7 +72,8 @@ def set_time_dim(input_tensor, steps):
tf.nest.map_structure(lambda t: set_time_dim(t, num_steps), experience)


def get_training_loop(driver, replay_buffer, agent, steps):
def get_training_loop(driver, replay_buffer, agent, steps,
async_steps_per_loop):
"""Returns a `tf.function` that runs the driver and training loops.
Args:
@@ -65,22 +82,32 @@ def get_training_loop(driver, replay_buffer, agent, steps):
agent: an instance of `TFAgent`.
steps: an integer indicating how many driver steps should be
executed and presented to the trainer during each training loop.
async_steps_per_loop: an integer. In each training loop, the driver runs
this many times, and then the agent gets asynchronously trained over this
many batches sampled from the replay buffer.
"""

def training_loop(train_step):
"""Returns a `tf.function` that runs the training loop."""
del train_step # unused
driver.run()
def training_loop(train_step, metrics):
"""Returns a function that runs a single training loop and logs metrics."""
for batch_id in range(async_steps_per_loop):
driver.run()
export_metrics(
step=train_step * async_steps_per_loop + batch_id, metrics=metrics)
batch_size = driver.env.batch_size
dataset = replay_buffer.as_dataset(
sample_batch_size=batch_size,
num_steps=steps,
single_deterministic_pass=True)
experience, unused_buffer_info = tf.data.Dataset.get_single_element(dataset)
set_expected_shape(experience, steps)
loss_info = agent.train(experience)
dataset_it = iter(
replay_buffer.as_dataset(
sample_batch_size=batch_size,
num_steps=steps,
single_deterministic_pass=True))
for batch_id in range(async_steps_per_loop):
experience, unused_buffer_info = dataset_it.get_next()
set_expected_shape(experience, steps)
loss_info = agent.train(experience)
export_loss_info(
step=train_step * async_steps_per_loop + batch_id,
loss_info=loss_info)

replay_buffer.clear()
return loss_info

return training_loop
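
A hedged sketch of how the returned closure is used, mirroring the call sites in `train()` below (the surrounding objects are assumed to exist):

  # Assumes driver, replay_buffer, agent, metrics, and the loop counts are
  # already constructed, as in train() below.
  training_loop = get_training_loop(driver, replay_buffer, agent,
                                    steps_per_loop, async_steps_per_loop)
  for i in range(training_loops):
    training_loop(train_step=i, metrics=metrics)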

@@ -110,6 +137,7 @@ def train(root_dir,
environment,
training_loops,
steps_per_loop,
async_steps_per_loop=None,
additional_metrics=(),
get_replay_buffer_fn=None,
get_training_loop_fn=None,
@@ -134,18 +162,27 @@ def baseline_reward_fn(observation, per_action_reward_fns):
environment: an instance of `TFEnvironment`.
training_loops: an integer indicating how many training loops should be run.
steps_per_loop: an integer indicating how many driver steps should be
executed and presented to the trainer during each training loop.
executed in a single driver run.
async_steps_per_loop: an optional integer for simulating offline or
asynchronous training: In each training loop iteration, the driver runs
this many times, each executing `steps_per_loop` driver steps, and then
the agent gets asynchronously trained over this many batches sampled from
the replay buffer. When unset or set to 1, the function performs
synchronous training, where the agent gets trained on a single batch
immediately after the driver runs.
additional_metrics: Tuple of metric objects to log, in addition to default
metrics `NumberOfEpisodes`, `AverageReturnMetric`, and
`AverageEpisodeLengthMetric`.
get_replay_buffer_fn: An optional function that creates a replay buffer by
taking a data_spec, batch size, and the number of steps per loop. Note
that the returned replay buffer will be passed to `get_training_loop_fn`
below to generate a training loop function. If `None`, the
`get_replay_buffer` function defined in this module will be used.
taking a data_spec, batch size, the number of driver steps per loop, and
the number of asynchronous training steps per loop. Note that the returned
replay buffer will be passed to `get_training_loop_fn` below to generate a
training loop function. If `None`, the `get_replay_buffer` function
defined in this module will be used.
get_training_loop_fn: An optional function that constructs the training
loop function executing a single train step. This function takes a driver,
a replay buffer, an agent and the number of steps per loop. If `None`, the
a replay buffer, an agent, the number of driver steps per loop, and the
number of asynchronous training steps per loop. If `None`, the
`get_training_loop` function defined in this module will be used.
training_data_spec_transformation_fn: Optional function that transforms the
data items before they get to the replay buffer.
@@ -158,10 +195,12 @@ def baseline_reward_fn(observation, per_action_reward_fns):
else:
data_spec = training_data_spec_transformation_fn(
agent.policy.trajectory_spec)
if async_steps_per_loop is None:
async_steps_per_loop = 1
if get_replay_buffer_fn is None:
get_replay_buffer_fn = get_replay_buffer
replay_buffer = get_replay_buffer_fn(data_spec, environment.batch_size,
steps_per_loop)
steps_per_loop, async_steps_per_loop)

# `step_metric` records the number of individual rounds of bandit interaction;
# that is, (number of trajectories) * batch_size.
@@ -199,7 +238,7 @@ def baseline_reward_fn(observation, per_action_reward_fns):
if get_training_loop_fn is None:
get_training_loop_fn = get_training_loop
training_loop = get_training_loop_fn(driver, replay_buffer, agent,
steps_per_loop)
steps_per_loop, async_steps_per_loop)
checkpoint_manager = restore_and_get_checkpoint_manager(
root_dir, agent, metrics, step_metric)
train_step_counter = tf.compat.v1.train.get_or_create_global_step()
@@ -211,11 +250,7 @@ def baseline_reward_fn(observation, per_action_reward_fns):
summary_writer.set_as_default()

for i in range(training_loops):
loss_info = training_loop(train_step=i)
metric_utils.log_metrics(metrics)
export_utils.export_metrics(step=i, metrics=metrics, loss_info=loss_info)
for metric in metrics:
metric.tf_summaries(train_step=step_metric.result())
training_loop(train_step=i, metrics=metrics)
checkpoint_manager.save()
if save_policy & (i % 100 == 0):
saver.save(os.path.join(root_dir, 'policy_%d' % step_metric.result()))
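
For reference, a minimal sketch of calling the updated entry point; the agent, environment, and metric objects are assumed to be built as in the tests below, and all values are illustrative:

  # With async_steps_per_loop=4, each loop performs 4 driver runs followed by
  # 4 train batches, for a total of 100 * 4 = 400 train steps.
  trainer.train(
      root_dir='/tmp/bandit_training',  # hypothetical output directory
      agent=agent,
      environment=environment,
      training_loops=100,
      steps_per_loop=2,
      async_steps_per_loop=4,
      additional_metrics=[regret_metric, suboptimal_arms_metric])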
37 changes: 30 additions & 7 deletions tf_agents/bandits/agents/examples/v2/trainer_test.py
@@ -22,7 +22,9 @@
import functools
import os
import tempfile
from unittest import mock

from absl import logging
from absl.testing import parameterized
import tensorflow as tf # pylint: disable=g-explicit-tensorflow-version-import
import tensorflow_probability as tfp
@@ -35,6 +37,7 @@
from tf_agents.bandits.environments import wheel_py_environment
from tf_agents.bandits.metrics import tf_metrics as tf_bandit_metrics
from tf_agents.environments import tf_py_environment
from tf_agents.metrics import export_utils
from tf_agents.specs import tensor_spec

tfd = tfp.distributions
@@ -101,6 +104,20 @@ def get_environment_and_optimal_functions_by_name(environment_name, batch_size):
return (environment, optimal_reward_fn, optimal_action_fn)


class MockLog(mock.Mock):

def __init__(self, *args, **kwargs):
super(MockLog, self).__init__(*args, **kwargs)
self.lines = []

def info(self, message, *args):
self.lines.append(message % args)
logging.info(message, *args)

def as_string(self):
return '\n'.join(self.lines)


class TrainerTest(tf.test.TestCase, parameterized.TestCase):

@parameterized.named_parameters(
@@ -186,13 +203,19 @@ def testAgentAndEnvironmentRuns(self, environment_name, agent_name):
regret_metric = tf_bandit_metrics.RegretMetric(optimal_reward_fn)
suboptimal_arms_metric = tf_bandit_metrics.SuboptimalArmsMetric(
optimal_action_fn)
trainer.train(
root_dir=tempfile.mkdtemp(dir=os.getenv('TEST_TMPDIR')),
agent=agent,
environment=environment,
training_loops=training_loops,
steps_per_loop=steps_per_loop,
additional_metrics=[regret_metric, suboptimal_arms_metric])
with mock.patch.object(
export_utils, 'logging', new_callable=MockLog) as mock_logging:
trainer.train(
root_dir=tempfile.mkdtemp(dir=os.getenv('TEST_TMPDIR')),
agent=agent,
environment=environment,
training_loops=training_loops,
steps_per_loop=steps_per_loop,
additional_metrics=[regret_metric, suboptimal_arms_metric])
logged = mock_logging.as_string()
self.assertEqual(logged.count('RegretMetric'), training_loops)
self.assertEqual(logged.count('SuboptimalArmsMetric'), training_loops)
self.assertEqual(logged.count('loss'), training_loops)


if __name__ == '__main__':
7 changes: 4 additions & 3 deletions tf_agents/metrics/export_utils.py
@@ -18,18 +18,19 @@
from absl import logging


def export_metrics(step, metrics, loss_info):
def export_metrics(step, metrics, loss_info=None):
"""Exports the metrics and loss information to logging.info.
Args:
step: Integer denoting the round at which we log the metrics.
metrics: List of `TF metrics` to log.
loss_info: An instance of `LossInfo` whose value is logged.
loss_info: An optional instance of `LossInfo` whose value is logged.
"""
def logging_at_step_fn(name, value):
logging_msg = f'[step={step}] {name} = {value}.'
logging.info(logging_msg)

for metric in metrics:
logging_at_step_fn(metric.name, metric.result())
logging_at_step_fn('loss', loss_info.loss)
if loss_info is not None:
logging_at_step_fn('loss', loss_info.loss)
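
With `loss_info` now optional, both call patterns used by `trainer.py` are valid; a small sketch with illustrative values:

  # Metrics-only export, as in export_metrics() in trainer.py:
  export_metrics(step=3, metrics=train_metrics)
  # Loss-only export, as in export_loss_info() in trainer.py:
  export_metrics(step=3, metrics=[], loss_info=loss_info)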
