
Refactors tf_agents.bandits.examples.v2.trainer.train
Refactors the replay buffer and training loop creation logic out of the `train` function. This allows users to inject a customized training loop function and/or replay buffer.

PiperOrigin-RevId: 479607009
Change-Id: I596115570ab008e245a9a96000a51a74602c24ef
TF-Agents Team authored and Copybara-Service committed Oct 7, 2022
1 parent 78aedf6 commit 4935903
Showing 1 changed file with 35 additions and 17 deletions.
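
The new hooks are easiest to see from the call site. Below is a minimal sketch of injecting a customized replay buffer through the new `get_replay_buffer_fn` parameter; `my_agent`, `my_env`, and the `root_dir` value are placeholders for an agent, a batched bandit environment, and an output directory set up elsewhere, and the import path for `bandit_replay_buffer` follows this repository's layout. The parameter names and the factory signature come from the diff that follows.

from tf_agents.bandits.agents.examples.v2 import trainer
from tf_agents.bandits.replay_buffers import bandit_replay_buffer


def my_replay_buffer_fn(data_spec, batch_size, steps_per_loop):
  # Same contract as the default `get_replay_buffer`, but with room for
  # two loops' worth of steps instead of one.
  return bandit_replay_buffer.BanditReplayBuffer(
      data_spec=data_spec,
      batch_size=batch_size,
      max_length=2 * steps_per_loop)


trainer.train(
    root_dir='/tmp/bandit_training',  # hypothetical output directory
    agent=my_agent,                   # a bandit `TFAgent`, built elsewhere
    environment=my_env,               # a batched bandit `TFEnvironment`
    training_loops=100,
    steps_per_loop=2,
    get_replay_buffer_fn=my_replay_buffer_fn)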
--- a/tf_agents/bandits/agents/examples/v2/trainer.py
+++ b/tf_agents/bandits/agents/examples/v2/trainer.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

r"""Generic TF-Agents training function for bandits."""
"""Generic TF-Agents training function for bandits."""

from __future__ import absolute_import
from __future__ import division
@@ -37,26 +37,26 @@
CHECKPOINT_FILE_PREFIX = 'ckpt'


-def get_replay_buffer(data_spec,
-                      batch_size,
-                      steps_per_loop):
+def get_replay_buffer(data_spec, batch_size, steps_per_loop):
   """Return a `TFUniformReplayBuffer` for the given `agent`."""
-  buf = bandit_replay_buffer.BanditReplayBuffer(
-      data_spec=data_spec,
-      batch_size=batch_size,
-      max_length=steps_per_loop)
-  return buf
+  return bandit_replay_buffer.BanditReplayBuffer(
+      data_spec=data_spec, batch_size=batch_size, max_length=steps_per_loop)


def set_expected_shape(experience, num_steps):
  """Sets expected shape."""
  def set_time_dim(input_tensor, steps):
    tensor_shape = input_tensor.shape.as_list()
    if len(tensor_shape) < 2:
      raise ValueError(
          'input_tensor is expected to be of rank-2, but found otherwise: '
          f'input_tensor={input_tensor}, tensor_shape={tensor_shape}')
    tensor_shape[1] = steps
    input_tensor.set_shape(tensor_shape)
  tf.nest.map_structure(lambda t: set_time_dim(t, num_steps), experience)


-def get_training_loop_fn(driver, replay_buffer, agent, steps):
+def get_training_loop(driver, replay_buffer, agent, steps):
"""Returns a `tf.function` that runs the driver and training loops.
Args:
@@ -66,19 +66,22 @@ def get_training_loop_fn(driver, replay_buffer, agent, steps):
    steps: an integer indicating how many driver steps should be
      executed and presented to the trainer during each training loop.
  """
-  def training_loop():
+
+  def training_loop(train_step):
+    """Runs a single training loop."""
+    del train_step  # unused
    driver.run()
    batch_size = driver.env.batch_size
    dataset = replay_buffer.as_dataset(
        sample_batch_size=batch_size,
        num_steps=steps,
        single_deterministic_pass=True)
-    experience, unused_info = tf.data.experimental.get_single_element(dataset)
+    experience, unused_buffer_info = tf.data.Dataset.get_single_element(dataset)
    set_expected_shape(experience, steps)
    loss_info = agent.train(experience)
    replay_buffer.clear()
    return loss_info

  return training_loop


@@ -108,6 +111,8 @@ def train(root_dir,
          training_loops,
          steps_per_loop,
          additional_metrics=(),
+          get_replay_buffer_fn=None,
+          get_training_loop_fn=None,
          training_data_spec_transformation_fn=None,
          save_policy=True):
  """Perform `training_loops` iterations of training.
@@ -133,6 +138,15 @@ def baseline_reward_fn(observation, per_action_reward_fns):
    additional_metrics: Tuple of metric objects to log, in addition to default
      metrics `NumberOfEpisodes`, `AverageReturnMetric`, and
      `AverageEpisodeLengthMetric`.
+    get_replay_buffer_fn: An optional function that creates a replay buffer by
+      taking a data_spec, batch size, and the number of steps per loop. Note
+      that the returned replay buffer will be passed to `get_training_loop_fn`
+      below to generate a training loop function. If `None`, the
+      `get_replay_buffer` function defined in this module will be used.
+    get_training_loop_fn: An optional function that constructs the training
+      loop function executing a single train step. This function takes a
+      driver, a replay buffer, an agent and the number of steps per loop. If
+      `None`, the `get_training_loop` function defined in this module is used.
    training_data_spec_transformation_fn: Optional function that transforms the
      data items before they get to the replay buffer.
    save_policy: (bool) whether to save the policy or not.
@@ -144,8 +158,10 @@ def baseline_reward_fn(observation, per_action_reward_fns):
  else:
    data_spec = training_data_spec_transformation_fn(
        agent.policy.trajectory_spec)
-  replay_buffer = get_replay_buffer(data_spec, environment.batch_size,
-                                    steps_per_loop)
+  if get_replay_buffer_fn is None:
+    get_replay_buffer_fn = get_replay_buffer
+  replay_buffer = get_replay_buffer_fn(data_spec, environment.batch_size,
+                                       steps_per_loop)

  # `step_metric` records the number of individual rounds of bandit interaction;
  # that is, (number of trajectories) * batch_size.
@@ -180,8 +196,10 @@ def baseline_reward_fn(observation, per_action_reward_fns):
      num_steps=steps_per_loop * environment.batch_size,
      observers=observers)

-  training_loop = get_training_loop_fn(
-      driver, replay_buffer, agent, steps_per_loop)
+  if get_training_loop_fn is None:
+    get_training_loop_fn = get_training_loop
+  training_loop = get_training_loop_fn(driver, replay_buffer, agent,
+                                       steps_per_loop)
  checkpoint_manager = restore_and_get_checkpoint_manager(
      root_dir, agent, metrics, step_metric)
  train_step_counter = tf.compat.v1.train.get_or_create_global_step()
@@ -193,7 +211,7 @@ def baseline_reward_fn(observation, per_action_reward_fns):
  summary_writer.set_as_default()

  for i in range(training_loops):
-    loss_info = training_loop()
+    loss_info = training_loop(train_step=i)
    metric_utils.log_metrics(metrics)
    export_utils.export_metrics(step=i, metrics=metrics, loss_info=loss_info)
    for metric in metrics:
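
For the training-loop hook, a factory only needs to match the `(driver, replay_buffer, agent, steps)` signature and return a callable that accepts `train_step`. As a hypothetical illustration, the sketch below mirrors the default `get_training_loop` from this diff but uses the `train_step` argument for logging instead of discarding it:

import tensorflow as tf

from tf_agents.bandits.agents.examples.v2 import trainer


def logging_training_loop_fn(driver, replay_buffer, agent, steps):
  """Like the default `get_training_loop`, but logs the loop index."""

  def training_loop(train_step):
    # The loop index that `train` passes in as `train_step=i`.
    tf.print('Starting training loop', train_step)
    driver.run()
    dataset = replay_buffer.as_dataset(
        sample_batch_size=driver.env.batch_size,
        num_steps=steps,
        single_deterministic_pass=True)
    experience, _ = tf.data.Dataset.get_single_element(dataset)
    trainer.set_expected_shape(experience, steps)
    loss_info = agent.train(experience)
    replay_buffer.clear()
    return loss_info

  return training_loop

Passing `get_training_loop_fn=logging_training_loop_fn` to `train` drops in for the default, since `train` forwards the same four arguments and invokes the resulting loop with `train_step=i`.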
