-
Notifications
You must be signed in to change notification settings - Fork 720
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
PiperOrigin-RevId: 268746875 Change-Id: I7dfc14979c381727021a39531e0fb23b9ccab58d
- Loading branch information
TF-Agents Team
authored and
Copybara-Service
committed
Sep 12, 2019
1 parent
2e2ed1f
commit 56e57e0
Showing
78 changed files
with
11,554 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
# coding=utf-8 | ||
# Copyright 2018 The TF-Agents Authors. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
# coding=utf-8 | ||
# Copyright 2018 The TF-Agents Authors. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
"""Module importing all agents.""" | ||
|
||
from tf_agents.bandits.agents import dropout_thompson_sampling_agent | ||
from tf_agents.bandits.agents import exp3_agent | ||
from tf_agents.bandits.agents import greedy_reward_prediction_agent | ||
from tf_agents.bandits.agents import lin_ucb_agent | ||
from tf_agents.bandits.agents import linear_thompson_sampling_agent | ||
from tf_agents.bandits.agents import neural_epsilon_greedy_agent | ||
from tf_agents.bandits.agents import utils |
105 changes: 105 additions & 0 deletions
105
tf_agents/bandits/agents/dropout_thompson_sampling_agent.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
# coding=utf-8 | ||
# Copyright 2018 The TF-Agents Authors. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
"""A neural network based agent that implements Thompson sampling via dropout. | ||
Implements an agent based on a neural network that predicts arm rewards. | ||
The neural network internally uses dropout to approximate Thompson sampling. | ||
""" | ||
|
||
from __future__ import absolute_import | ||
from __future__ import division | ||
from __future__ import print_function | ||
|
||
import gin | ||
import tensorflow as tf | ||
|
||
from tf_agents.bandits.agents import greedy_reward_prediction_agent | ||
from tf_agents.networks import q_network | ||
|
||
|
||
@gin.configurable
class DropoutThompsonSamplingAgent(
    greedy_reward_prediction_agent.GreedyRewardPredictionAgent):
  """Reward-prediction agent that explores through dropout.

  A `QNetwork` is built from the supplied layer sizes and trained to predict
  arm rewards. Actions are chosen greedily with respect to those predictions;
  the dropout applied inside the network provides the exploration that
  approximates Thompson sampling.
  """

  def __init__(
      self,
      time_step_spec,
      action_spec,
      optimizer,
      # Network params.
      dropout_rate,
      network_layers,
      dropout_only_top_layer=True,
      # Params for training.
      error_loss_fn=tf.compat.v1.losses.mean_squared_error,
      gradient_clipping=None,
      # Params for debugging.
      debug_summaries=False,
      summarize_grads_and_vars=False,
      train_step_counter=None,
      name=None):
    """Creates a Dropout Thompson Sampling Agent.

    Args:
      time_step_spec: A `TimeStep` spec of the expected time_steps.
      action_spec: A nest of `BoundedTensorSpec` representing the actions.
      optimizer: The optimizer to use for training.
      dropout_rate: Float in `(0, 1)`, the dropout rate.
      network_layers: Tuple of ints determining the sizes of the network
        layers.
      dropout_only_top_layer: Boolean; when True (the default) dropout is
        applied only to the last layer, otherwise to every layer.
      error_loss_fn: A function for computing the error loss, taking parameters
        labels, predictions, and weights (any function from tf.losses would
        work). The default is `tf.compat.v1.losses.mean_squared_error`.
      gradient_clipping: A float representing the norm length to clip gradients
        (or None for no clipping.)
      debug_summaries: A Python bool, default False. When True, debug summaries
        are gathered.
      summarize_grads_and_vars: A Python bool, default False. When True,
        gradients and network variable summaries are written during training.
      train_step_counter: An optional `tf.Variable` to increment every time the
        train op is run. Defaults to the `global_step`.
      name: Python str name of this agent. All variables in this module will
        fall under that name. Defaults to the class name.

    Raises:
      ValueError: If the action spec contains more than one action or it is
        not a bounded scalar int32 spec with minimum 0.
    """
    num_layers = len(network_layers)
    # NOTE(review): 'permanent' presumably keeps dropout active outside of
    # training so predictions stay stochastic -- confirm against QNetwork.
    dropout_spec = {'rate': dropout_rate, 'permanent': True}
    if dropout_only_top_layer:
      # No dropout on the lower layers, dropout on the top one only.
      per_layer_dropout = [None] * (num_layers - 1) + [dropout_spec]
    else:
      per_layer_dropout = [dropout_spec] * num_layers

    reward_network = q_network.QNetwork(
        input_tensor_spec=time_step_spec.observation,
        action_spec=action_spec,
        fc_layer_params=network_layers,
        dropout_layer_params=per_layer_dropout)

    super(DropoutThompsonSamplingAgent, self).__init__(
        time_step_spec, action_spec, reward_network, optimizer, error_loss_fn,
        gradient_clipping, debug_summaries, summarize_grads_and_vars,
        train_step_counter, name)
111 changes: 111 additions & 0 deletions
111
tf_agents/bandits/agents/dropout_thompson_sampling_agent_test.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,111 @@ | ||
# coding=utf-8 | ||
# Copyright 2018 The TF-Agents Authors. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
"""Tests for dropout_thompson_sampling_agent.py.""" | ||
|
||
from __future__ import absolute_import | ||
from __future__ import division | ||
from __future__ import print_function | ||
|
||
import numpy as np | ||
import tensorflow as tf | ||
|
||
from tf_agents.bandits.agents import dropout_thompson_sampling_agent | ||
from tf_agents.bandits.drivers import driver_utils | ||
from tf_agents.specs import tensor_spec | ||
from tf_agents.trajectories import policy_step | ||
from tf_agents.trajectories import time_step as ts | ||
|
||
from tensorflow.python.framework import test_util # pylint:disable=g-direct-tensorflow-import # TF internal | ||
|
||
|
||
def _get_initial_and_final_steps(observations, rewards):
  """Builds the FIRST and LAST `TimeStep`s for a batch of bandit interactions.

  Args:
    observations: Array whose leading dimension is the batch size.
    rewards: Per-example rewards placed on the final step.

  Returns:
    A tuple `(initial_step, final_step)`. The initial step has zero reward; the
    final step carries `rewards` and the observations shifted by 100.0.
  """
  batch_size = observations.shape[0]

  def _per_example(value, dtype, name):
    # Constant of shape [batch_size] filled with `value`.
    return tf.constant(value, dtype=dtype, shape=[batch_size], name=name)

  first_step = ts.TimeStep(
      _per_example(ts.StepType.FIRST, tf.int32, 'step_type'),
      _per_example(0.0, tf.float32, 'reward'),
      _per_example(1.0, tf.float32, 'discount'),
      tf.constant(observations, dtype=tf.float32, name='observation'))
  last_step = ts.TimeStep(
      _per_example(ts.StepType.LAST, tf.int32, 'step_type'),
      tf.constant(rewards, dtype=tf.float32, name='reward'),
      _per_example(1.0, tf.float32, 'discount'),
      tf.constant(observations + 100.0, dtype=tf.float32, name='observation'))
  return first_step, last_step
|
||
|
||
def _get_action_step(action):
  """Wraps `action` (converted to a tensor) in a `PolicyStep`."""
  action_tensor = tf.convert_to_tensor(action)
  return policy_step.PolicyStep(action=action_tensor)
|
||
|
||
def _get_experience(initial_step, action_step, final_step):
  """Assembles a bandit trajectory and gives every tensor a 'time' axis.

  Args:
    initial_step: The first `TimeStep` of the interaction.
    action_step: The `PolicyStep` holding the chosen actions.
    final_step: The last `TimeStep` of the interaction.

  Returns:
    The structure produced by `driver_utils.trajectory_for_bandit`, with each
    leaf converted to a tensor and expanded at axis 1.
  """

  def _add_time_dimension(leaf):
    return tf.expand_dims(tf.convert_to_tensor(leaf), 1)

  trajectory = driver_utils.trajectory_for_bandit(
      initial_step, action_step, final_step)
  return tf.nest.map_structure(_add_time_dimension, trajectory)
|
||
|
||
@test_util.run_all_in_graph_and_eager_modes
class AgentTest(tf.test.TestCase):
  """Smoke tests for building and training `DropoutThompsonSamplingAgent`."""

  def setUp(self):
    super(AgentTest, self).setUp()
    tf.compat.v1.enable_resource_variables()
    # Observations are 2-dimensional float vectors; actions are one of 3 arms.
    obs_spec = tensor_spec.TensorSpec([2], tf.float32)
    self._obs_spec = obs_spec
    self._time_step_spec = ts.time_step_spec(obs_spec)
    self._action_spec = tensor_spec.BoundedTensorSpec(
        dtype=tf.int32, shape=(), minimum=0, maximum=2)

  def testCreateAgent(self):
    """The agent can be constructed without an optimizer and has a policy."""
    agent_cls = dropout_thompson_sampling_agent.DropoutThompsonSamplingAgent
    agent = agent_cls(
        self._time_step_spec,
        self._action_spec,
        optimizer=None,
        dropout_rate=0.1,
        network_layers=(20, 20, 20))
    self.assertIsNotNone(agent.policy)

  def testTrainAgent(self):
    """Two consecutive train calls both yield a strictly positive loss."""
    agent_cls = dropout_thompson_sampling_agent.DropoutThompsonSamplingAgent
    agent = agent_cls(
        self._time_step_spec,
        self._action_spec,
        optimizer=tf.compat.v1.train.GradientDescentOptimizer(
            learning_rate=0.1),
        dropout_rate=0.1,
        network_layers=(20, 20, 20),
        dropout_only_top_layer=False)

    observations = np.array([[1, 2], [3, 4]], dtype=np.float32)
    # NOTE(review): actions are float32 although the action spec is int32 --
    # confirm this is intended.
    actions = np.array([0, 1], dtype=np.float32)
    rewards = np.array([0.5, 3.0], dtype=np.float32)

    first_step, last_step = _get_initial_and_final_steps(observations, rewards)
    experience = _get_experience(
        first_step, _get_action_step(actions), last_step)

    loss_before, _ = agent.train(experience, None)
    loss_after, _ = agent.train(experience, None)
    # Variables are initialized only after both train ops have been created.
    self.evaluate(tf.compat.v1.global_variables_initializer())
    for loss in (loss_before, loss_after):
      self.assertAllGreater(self.evaluate(loss), 0)


if __name__ == '__main__':
  tf.test.main()
112 changes: 112 additions & 0 deletions
112
tf_agents/bandits/agents/examples/v1/train_eval_drifting_linear.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,112 @@ | ||
# coding=utf-8 | ||
# Copyright 2018 The TF-Agents Authors. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
"""End-to-end test for bandits against a drifting linear environment. | ||
""" | ||
|
||
from __future__ import absolute_import | ||
from __future__ import division | ||
from __future__ import print_function | ||
|
||
import os | ||
from absl import app | ||
from absl import flags | ||
|
||
import tensorflow as tf | ||
import tensorflow_probability as tfp | ||
from tf_agents.bandits.agents import lin_ucb_agent | ||
from tf_agents.bandits.agents import linear_thompson_sampling_agent as lin_ts_agent | ||
from tf_agents.bandits.agents.examples.v1 import trainer | ||
from tf_agents.bandits.environments import drifting_linear_environment as dle | ||
from tf_agents.bandits.environments import non_stationary_stochastic_environment as nse | ||
from tf_agents.bandits.metrics import tf_metrics as tf_bandit_metrics | ||
|
||
|
||
# Command-line configuration for this example.
# `root_dir` defaults to the value of TEST_UNDECLARED_OUTPUTS_DIR (None when
# that environment variable is unset).
flags.DEFINE_string(
    name='root_dir',
    default=os.getenv('TEST_UNDECLARED_OUTPUTS_DIR'),
    help='Root directory for writing logs/summaries/checkpoints.')
flags.DEFINE_enum(
    name='agent',
    default='LinUCB',
    enum_values=['LinUCB', 'LinTS'],
    help='Which agent to use. Possible values are `LinUCB` and `LinTS`.')

FLAGS = flags.FLAGS
# Short alias for the TensorFlow Probability distributions module.
tfd = tfp.distributions
|
||
|
||
# Problem size.
CONTEXT_DIM = 15  # Length of each observation (context) vector.
NUM_ACTIONS = 5  # Number of arms.
BATCH_SIZE = 8  # Observations generated per environment step.

# Parameters of the distributions driving the environment.
REWARD_NOISE_VARIANCE = 0.01
DRIFT_VARIANCE = 0.01
DRIFT_MEAN = 0.01

# Training schedule.
TRAINING_LOOPS = 200
STEPS_PER_LOOP = 2

# `alpha` passed to the LinUCB agent.
AGENT_ALPHA = 10.0
|
||
|
||
def main(unused_argv):
  """Trains a linear bandit agent on a drifting linear environment.

  Builds a `NonStationaryStochasticEnvironment` driven by
  `DriftingLinearDynamics`, creates the agent selected by `--agent` (`LinUCB`
  or `LinTS`), and hands both to `trainer.train` together with regret and
  suboptimal-arm metrics. Output is written under `--root_dir`.

  Args:
    unused_argv: Positional command-line arguments from `app.run`; ignored.
  """
  # Use the compat.v1 endpoint: the bare `tf.enable_resource_variables` symbol
  # only exists in the TF1 namespace, while `tf.compat.v1` works on both.
  tf.compat.v1.enable_resource_variables()

  with tf.device('/CPU:0'):  # due to b/128333994
    observation_shape = [CONTEXT_DIM]
    overall_shape = [BATCH_SIZE] + observation_shape
    observation_distribution = tfd.Normal(
        loc=tf.zeros(overall_shape), scale=tf.ones(overall_shape))
    action_shape = [NUM_ACTIONS]
    observation_to_reward_shape = observation_shape + action_shape
    observation_to_reward_distribution = tfd.Normal(
        loc=tf.zeros(observation_to_reward_shape),
        scale=tf.ones(observation_to_reward_shape))
    # NOTE(review): the *_VARIANCE constants are passed as `scale`, which
    # `tfd.Normal` treats as a standard deviation -- confirm naming/intent.
    drift_distribution = tfd.Normal(loc=DRIFT_MEAN, scale=DRIFT_VARIANCE)
    additive_reward_distribution = tfd.Normal(
        loc=tf.zeros(action_shape),
        scale=(REWARD_NOISE_VARIANCE * tf.ones(action_shape)))
    environment_dynamics = dle.DriftingLinearDynamics(
        observation_distribution,
        observation_to_reward_distribution,
        drift_distribution,
        additive_reward_distribution)
    environment = nse.NonStationaryStochasticEnvironment(environment_dynamics)

    # `--agent` is an enum flag, so exactly one of these branches is taken.
    if FLAGS.agent == 'LinUCB':
      agent = lin_ucb_agent.LinearUCBAgent(
          time_step_spec=environment.time_step_spec(),
          action_spec=environment.action_spec(),
          alpha=AGENT_ALPHA,
          gamma=0.95,
          emit_log_probability=False,
          dtype=tf.float32)
    elif FLAGS.agent == 'LinTS':
      agent = lin_ts_agent.LinearThompsonSamplingAgent(
          time_step_spec=environment.time_step_spec(),
          action_spec=environment.action_spec(),
          gamma=0.95,
          dtype=tf.float32)

    # Both metrics compare the agent against the optimum computed by the
    # environment dynamics.
    regret_metric = tf_bandit_metrics.RegretMetric(
        environment.environment_dynamics.compute_optimal_reward)
    suboptimal_arms_metric = tf_bandit_metrics.SuboptimalArmsMetric(
        environment.environment_dynamics.compute_optimal_action)

    trainer.train(
        root_dir=FLAGS.root_dir,
        agent=agent,
        environment=environment,
        training_loops=TRAINING_LOOPS,
        steps_per_loop=STEPS_PER_LOOP,
        additional_metrics=[regret_metric, suboptimal_arms_metric])


if __name__ == '__main__':
  app.run(main)
Oops, something went wrong.