Commit

Adds an option to tf_agents/bandits/agents/examples/v2/trainer.py for resuming training loops

PiperOrigin-RevId: 484249329
Change-Id: I12e2b3d8cf67a1ece2192bca4e8a17a21366b194
TF-Agents Team authored and Copybara-Service committed Oct 27, 2022
1 parent 6d0ff50 commit 58ffe1e
Showing 2 changed files with 87 additions and 2 deletions.
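As a rough usage sketch (not part of the commit): resuming a run amounts to calling trainer.train again with the same root_dir and the new flag. The environment/agent construction below is hypothetical; only trainer.train and the arguments shown in the diff and test are taken from the source.

# Illustrative sketch only: restarting a bandit training job in the same root_dir.
# `make_environment` and `make_agent` are hypothetical stand-ins for whatever
# construction the caller already uses.
from tf_agents.bandits.agents.examples.v2 import trainer

environment = make_environment()  # hypothetical helper
agent = make_agent(environment)   # hypothetical helper

trainer.train(
    root_dir='/tmp/bandit_run',   # same directory as the earlier, interrupted run
    agent=agent,
    environment=environment,
    training_loops=100,           # total loops, counted from the very first checkpoint
    steps_per_loop=2,
    resume_training_loops=True,   # start at the loop implied by the last checkpoint
)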
24 changes: 22 additions & 2 deletions tf_agents/bandits/agents/examples/v2/trainer.py
@@ -137,7 +137,8 @@ def train(root_dir,
get_replay_buffer_fn=None,
get_training_loop_fn=None,
training_data_spec_transformation_fn=None,
save_policy=True):
save_policy=True,
resume_training_loops=False):
"""Perform `training_loops` iterations of training.
Checkpoint results.
@@ -182,6 +183,9 @@ def baseline_reward_fn(observation, per_action_reward_fns):
training_data_spec_transformation_fn: Optional function that transforms the
data items before they get to the replay buffer.
save_policy: (bool) whether to save the policy or not.
resume_training_loops: A boolean flag indicating whether
`training_loops` should be enforced relative to the initial (True) or
the last (False) checkpoint.
"""

# TODO(b/127641485): create evaluation loop with configurable metrics.
@@ -244,7 +248,23 @@ def baseline_reward_fn(observation, per_action_reward_fns):
summary_writer = tf.summary.create_file_writer(root_dir)
summary_writer.set_as_default()

for i in range(training_loops):
if resume_training_loops:
train_step_count_per_loop = (
steps_per_loop * environment.batch_size * async_steps_per_loop)
last_checkpointed_step = step_metric.result().numpy()
if last_checkpointed_step % train_step_count_per_loop != 0:
raise ValueError(
f'Last checkpointed step is expected to be a multiple of '
'steps_per_loop * batch_size * async_steps_per_loop, but found '
f'otherwise: last checkpointed step: {last_checkpointed_step}, '
f'steps_per_loop: {steps_per_loop}, batch_size: '
f'{environment.batch_size}, async_steps_per_loop: '
f'{async_steps_per_loop}')
starting_loop = last_checkpointed_step // train_step_count_per_loop
else:
starting_loop = 0

for i in range(starting_loop, training_loops):
training_loop(train_step=i, metrics=metrics)
checkpoint_manager.save()
if save_policy & (i % 100 == 0):
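A small worked example of the resume arithmetic added above (numbers are illustrative; the value of async_steps_per_loop is an assumption here):

# How the starting loop index is recovered from the checkpointed step metric.
steps_per_loop = 2
batch_size = 8            # environment.batch_size inside the trainer
async_steps_per_loop = 1  # assumed value for a synchronous setup

train_step_count_per_loop = steps_per_loop * batch_size * async_steps_per_loop  # 16
last_checkpointed_step = 48  # step_metric.result() restored from the checkpoint

# The trainer raises ValueError unless this divides evenly.
assert last_checkpointed_step % train_step_count_per_loop == 0
starting_loop = last_checkpointed_step // train_step_count_per_loop  # 3

# With training_loops=5 and resume_training_loops=True, only loops 3 and 4 run;
# with resume_training_loops=False, all 5 loops would run again on top.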
65 changes: 65 additions & 0 deletions tf_agents/bandits/agents/examples/v2/trainer_test.py
@@ -217,6 +217,71 @@ def testAgentAndEnvironmentRuns(self, environment_name, agent_name):
self.assertEqual(logged.count('SuboptimalArmsMetric'), training_loops)
self.assertEqual(logged.count('loss'), training_loops)

def testResumeTrainLoops(self):
batch_size = 8
training_loops = 3
steps_per_loop = 2
environment_name = 'stationary_stochastic'
agent_name = 'epsGreedy'
environment, _, _ = (
trainer_test_utils.get_environment_and_optimal_functions_by_name(
environment_name, batch_size))
agent = trainer_test_utils.get_agent_by_name(agent_name,
environment.time_step_spec(),
environment.action_spec())
root_dir = tempfile.mkdtemp(dir=os.getenv('TEST_TMPDIR'))

def train(training_loops, resume_training_loops):
trainer.train(
root_dir=root_dir,
agent=agent,
environment=environment,
training_loops=training_loops,
steps_per_loop=steps_per_loop,
resume_training_loops=resume_training_loops)

with mock.patch.object(
export_utils, 'logging', new_callable=MockLog) as mock_logging:
train(training_loops=training_loops, resume_training_loops=True)
logged = mock_logging.as_string()
self.assertEqual(logged.count('loss'), training_loops)
self.assertEqual(logged.count('AverageReturn'), training_loops)

# With `resume_training_loops` set to True, the same `training_loops`
# would not result in more training.
with mock.patch.object(
export_utils, 'logging', new_callable=MockLog) as mock_logging:
train(training_loops=training_loops, resume_training_loops=True)
logged = mock_logging.as_string()
self.assertEqual(logged.count('loss'), 0)
self.assertEqual(logged.count('AverageReturn'), 0)

# With `resume_training_loops` set to True, increasing
# `training_loops` will result in more training.
with mock.patch.object(
export_utils, 'logging', new_callable=MockLog) as mock_logging:
train(training_loops=training_loops + 1, resume_training_loops=True)
logged = mock_logging.as_string()
self.assertEqual(logged.count('loss'), 1)
self.assertEqual(logged.count('AverageReturn'), 1)
expected_num_episodes = (training_loops + 1) * steps_per_loop * batch_size
self.assertEqual(
logged.count(f'NumberOfEpisodes = {expected_num_episodes}'), 1)

# With `resume_training_loops` set to False, `training_loops` of 1
# will result in more training.
with mock.patch.object(
export_utils, 'logging', new_callable=MockLog) as mock_logging:
train(training_loops=1, resume_training_loops=False)
logged = mock_logging.as_string()
self.assertEqual(logged.count('loss'), 1)
self.assertEqual(logged.count('AverageReturn'), 1)
# The number of episodes is expected to accumulate over all trainings using
# the same `root_dir`.
expected_num_episodes = (training_loops + 2) * steps_per_loop * batch_size
self.assertEqual(
logged.count(f'NumberOfEpisodes = {expected_num_episodes}'), 1)


if __name__ == '__main__':
tf.test.main()

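For reference, the NumberOfEpisodes assertions at the end of testResumeTrainLoops follow directly from the parameters set at the top of the test:

# Arithmetic behind the two expected_num_episodes values above.
batch_size = 8
steps_per_loop = 2
training_loops = 3

# After resuming with training_loops + 1 (loops counted from the first checkpoint):
print((training_loops + 1) * steps_per_loop * batch_size)  # 64

# After one more loop with resume_training_loops=False; counts keep accumulating
# in the same root_dir:
print((training_loops + 2) * steps_per_loop * batch_size)  # 80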