Commit 6d0ff50
Makes some helper functions in `bandits.agents.examples.v2.trainer.py` private

They are not to be used outside of this module.

PiperOrigin-RevId: 483413990
Change-Id: I10c5765e4d7f9daf4ef215565afed44c1e394d96
TF-Agents Team authored and Copybara-Service committed Oct 24, 2022
1 parent 868f276 commit 6d0ff50
Showing 1 changed file with 16 additions and 21 deletions.
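
For context, the rename follows the standard Python convention: a leading underscore marks a module-level name as internal, and `from module import *` skips such names unless they appear in `__all__`. A generic illustration of the convention (not code from this commit):

    # convention_demo.py: illustrative only, unrelated to tf_agents.

    def public_helper():
      """Part of the module's intended public surface."""
      return _private_helper() + 1


    def _private_helper():
      """Leading underscore marks this name as module-internal.

      `from convention_demo import *` will not bind this name unless it is
      listed in `__all__`; linters and API-doc generators likewise treat it
      as private by convention.
      """
      return 41

The underscore is advisory rather than enforced, which is why the commit also updates every call site inside trainer.py to the new names.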
tf_agents/bandits/agents/examples/v2/trainer.py (16 additions, 21 deletions)
@@ -37,21 +37,8 @@
 CHECKPOINT_FILE_PREFIX = 'ckpt'
 
 
-def export_metrics(step, metrics):
-  """Exports metrics."""
-  metric_utils.log_metrics(metrics)
-  export_utils.export_metrics(step=step, metrics=metrics)
-  for metric in metrics:
-    metric.tf_summaries(train_step=step)
-
-
-def export_loss_info(step, loss_info):
-  """Exports loss info."""
-  export_utils.export_metrics(step=step, metrics=[], loss_info=loss_info)
-
-
-def get_replay_buffer(data_spec, batch_size, steps_per_loop,
-                      async_steps_per_loop):
+def _get_replay_buffer(data_spec, batch_size, steps_per_loop,
+                       async_steps_per_loop):
   """Return a `TFUniformReplayBuffer` for the given `agent`."""
   return bandit_replay_buffer.BanditReplayBuffer(
       data_spec=data_spec,
@@ -72,8 +59,8 @@ def set_time_dim(input_tensor, steps):
   tf.nest.map_structure(lambda t: set_time_dim(t, num_steps), experience)
 
 
-def get_training_loop(driver, replay_buffer, agent, steps,
-                      async_steps_per_loop):
+def _get_training_loop(driver, replay_buffer, agent, steps,
+                       async_steps_per_loop):
   """Returns a `tf.function` that runs the driver and training loops.
 
   Args:
@@ -87,11 +74,18 @@ def get_training_loop(driver, replay_buffer, agent, steps,
       many batches sampled from the replay buffer.
   """
 
+  def _export_metrics_and_summaries(step, metrics):
+    """Exports metrics and tf summaries."""
+    metric_utils.log_metrics(metrics)
+    export_utils.export_metrics(step=step, metrics=metrics)
+    for metric in metrics:
+      metric.tf_summaries(train_step=step)
+
   def training_loop(train_step, metrics):
     """Returns a function that runs a single training loop and logs metrics."""
     for batch_id in range(async_steps_per_loop):
       driver.run()
-      export_metrics(
+      _export_metrics_and_summaries(
           step=train_step * async_steps_per_loop + batch_id, metrics=metrics)
     batch_size = driver.env.batch_size
     dataset_it = iter(
@@ -103,8 +97,9 @@ def training_loop(train_step, metrics):
       experience, unused_buffer_info = dataset_it.get_next()
       set_expected_shape(experience, steps)
       loss_info = agent.train(experience)
-      export_loss_info(
+      export_utils.export_metrics(
           step=train_step * async_steps_per_loop + batch_id,
+          metrics=[],
           loss_info=loss_info)
 
     replay_buffer.clear()
@@ -198,7 +193,7 @@ def baseline_reward_fn(observation, per_action_reward_fns):
   if async_steps_per_loop is None:
     async_steps_per_loop = 1
   if get_replay_buffer_fn is None:
-    get_replay_buffer_fn = get_replay_buffer
+    get_replay_buffer_fn = _get_replay_buffer
   replay_buffer = get_replay_buffer_fn(data_spec, environment.batch_size,
                                        steps_per_loop, async_steps_per_loop)
 
@@ -236,7 +231,7 @@ def baseline_reward_fn(observation, per_action_reward_fns):
       observers=observers)
 
   if get_training_loop_fn is None:
-    get_training_loop_fn = get_training_loop
+    get_training_loop_fn = _get_training_loop
   training_loop = get_training_loop_fn(driver, replay_buffer, agent,
                                        steps_per_loop, async_steps_per_loop)
   checkpoint_manager = restore_and_get_checkpoint_manager(
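
With the helpers private, callers that previously imported `get_replay_buffer` or `get_training_loop` should instead inject their own factories through the `get_replay_buffer_fn` and `get_training_loop_fn` hooks the module already falls back from when they are `None`. A minimal sketch of such a caller, assuming the `bandit_replay_buffer` import path used in the diff; the `max_length` sizing and the enclosing `trainer.train(...)` entry point with its keyword arguments are assumptions, not part of this commit:

    from tf_agents.bandits.replay_buffers import bandit_replay_buffer


    def my_replay_buffer_fn(data_spec, batch_size, steps_per_loop,
                            async_steps_per_loop):
      """Drop-in factory with the same signature as the private helper.

      Uses the same BanditReplayBuffer class as the module; sizing the
      buffer to one synchronous loop's worth of asynchronous batches is
      an assumption about the helper's intent, since the diff truncates
      the constructor call after `data_spec`.
      """
      return bandit_replay_buffer.BanditReplayBuffer(
          data_spec=data_spec,
          batch_size=batch_size,
          max_length=steps_per_loop * async_steps_per_loop)


    # Hypothetical call site: `agent` and `environment` are assumed to be
    # constructed elsewhere, and every keyword other than
    # `get_replay_buffer_fn` is an assumption about the trainer's signature.
    # trainer.train(
    #     root_dir='/tmp/bandit_example',
    #     agent=agent,
    #     environment=environment,
    #     training_loops=100,
    #     steps_per_loop=2,
    #     get_replay_buffer_fn=my_replay_buffer_fn)

Passing a factory rather than overriding a module function keeps the customization explicit at the call site, which is presumably why the hooks exist and the bare helpers could be withdrawn.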
