[go: nahoru, domu]

Skip to content

Commit

Permalink
Internal logging.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 474265383
Change-Id: Ia1f03c503f12cf7a22eba68c6cff699979be5f9a
  • Loading branch information
rodolphejenatton authored and Copybara-Service committed Sep 14, 2022
1 parent a3b58c4 commit a3a1440
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 1 deletion.
5 changes: 4 additions & 1 deletion tf_agents/bandits/agents/examples/v2/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from tf_agents.bandits.replay_buffers import bandit_replay_buffer
from tf_agents.drivers import dynamic_step_driver
from tf_agents.eval import metric_utils
from tf_agents.metrics import export_utils
from tf_agents.metrics import tf_metrics
from tf_agents.policies import policy_saver

Expand Down Expand Up @@ -190,9 +191,11 @@ def baseline_reward_fn(observation, per_action_reward_fns):

summary_writer = tf.summary.create_file_writer(root_dir)
summary_writer.set_as_default()

for i in range(training_loops):
training_loop()
loss_info = training_loop()
metric_utils.log_metrics(metrics)
export_utils.export_metrics(step=i, metrics=metrics, loss_info=loss_info)
for metric in metrics:
metric.tf_summaries(train_step=step_metric.result())
checkpoint_manager.save()
Expand Down
1 change: 1 addition & 0 deletions tf_agents/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
"""Metrics module."""

from tf_agents.metrics import batched_py_metric
from tf_agents.metrics import export_utils
from tf_agents.metrics import py_metric
from tf_agents.metrics import py_metrics
from tf_agents.metrics import tf_metric
Expand Down
35 changes: 35 additions & 0 deletions tf_agents/metrics/export_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# coding=utf-8
# Copyright 2020 The TF-Agents Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utils to export metrics."""

from absl import logging


def export_metrics(step, metrics, loss_info):
  """Logs every metric's current result, then the loss, via `logging.info`.

  Each value is emitted as its own log line of the form
  `[step=<step>] <name> = <value>.`. Metrics are logged in the order they
  appear in `metrics`, followed by one final `loss` entry.

  Args:
    step: Integer denoting the round at which we log the metrics.
    metrics: List of `TF metrics` to log. Each element is expected to expose
      a `name` attribute and a `result()` method.
    loss_info: An instance of `LossInfo` whose value is logged. Only its
      `loss` attribute is read.
  """

  def _labelled_values():
    # Yield lazily so each `result()` is computed right before its own log
    # line. If a metric fails, the lines already emitted stay logged.
    for tracked in metrics:
      yield tracked.name, tracked.result()
    yield 'loss', loss_info.loss

  for label, value in _labelled_values():
    # The f-string formats via `format(value, '')`, which may render some
    # types (e.g. tensors) differently from `%s`/`str()`. Keep it so the
    # logged text is unchanged.
    logging.info(f'[step={step}] {label} = {value}.')

0 comments on commit a3a1440

Please sign in to comment.