Add type annotations to the Bandit environments.
PiperOrigin-RevId: 328897856
Change-Id: Ib48d2cf49835ae02212f9ce60633144aa619c0b8
efiko authored and Copybara-Service committed Aug 28, 2020
1 parent 58c8f85 commit d68902e
Showing 21 changed files with 221 additions and 139 deletions.
22 changes: 14 additions & 8 deletions tf_agents/bandits/environments/bandit_py_environment.py
@@ -16,15 +16,18 @@
"""Base class for Bandit Python environments."""
from __future__ import absolute_import
from __future__ import division
# Using Type Annotations.
from __future__ import print_function

import abc
from typing import Optional
import numpy as np

import tensorflow as tf # pylint: disable=g-explicit-tensorflow-version-import

from tf_agents.environments import py_environment
from tf_agents.trajectories import time_step as ts
from tf_agents.typing import types


class BanditPyEnvironment(py_environment.PyEnvironment):
@@ -40,13 +43,16 @@ class BanditPyEnvironment(py_environment.PyEnvironment):
returned by step(action) will contain the reward and the next observation.
"""

def __init__(self, observation_spec, action_spec, reward_spec=None):
def __init__(self,
observation_spec: types.NestedArray,
action_spec: types.NestedArray,
reward_spec: Optional[types.NestedArray] = None):
self._observation_spec = observation_spec
self._action_spec = action_spec
self._reward_spec = reward_spec
super(BanditPyEnvironment, self).__init__()

def _reset(self):
def _reset(self) -> ts.TimeStep:
"""Returns a time step containing an observation.
It should not be overridden by Bandit environment implementations.
@@ -57,7 +63,7 @@ def _reset(self):
return ts.restart(self._observe(), batch_size=self.batch_size,
reward_spec=self.reward_spec())

def _step(self, action):
def _step(self, action: types.NestedArray) -> ts.TimeStep:
"""Returns a time step containing the reward for the action taken.
The returning time step also contains the next observation.
@@ -74,21 +80,21 @@ def _step(self, action):
reward = self._apply_action(action)
return ts.termination(self._observe(), reward)

def action_spec(self):
def action_spec(self) -> types.NestedArraySpec:
return self._action_spec

def observation_spec(self):
def observation_spec(self) -> types.NestedArraySpec:
return self._observation_spec

def reward_spec(self):
def reward_spec(self) -> types.NestedArraySpec:
return self._reward_spec

def _empty_observation(self):
return tf.nest.map_structure(lambda x: np.zeros(x.shape, x.dtype),
self.observation_spec())

@abc.abstractmethod
def _apply_action(self, action):
def _apply_action(self, action: types.NestedArray) -> types.Float:
"""Applies `action` to the Environment and returns the corresponding reward.
Args:
@@ -100,5 +106,5 @@ def _apply_action(self, action):
"""

@abc.abstractmethod
def _observe(self):
def _observe(self) -> types.NestedArray:
"""Returns an observation."""
18 changes: 12 additions & 6 deletions tf_agents/bandits/environments/bandit_tf_environment.py
@@ -17,15 +17,18 @@

from __future__ import absolute_import
from __future__ import division
# Using Type Annotations.
from __future__ import print_function

import abc
from typing import Optional
import six

import tensorflow as tf # pylint: disable=g-explicit-tensorflow-version-import

from tf_agents.environments import tf_environment
from tf_agents.trajectories import time_step as ts
from tf_agents.typing import types
from tf_agents.utils import common
from tf_agents.utils import nest_utils

@@ -57,7 +60,10 @@ class BanditTFEnvironment(tf_environment.TFEnvironment):
```
"""

def __init__(self, time_step_spec=None, action_spec=None, batch_size=1):
def __init__(self,
time_step_spec: Optional[types.NestedArray] = None,
action_spec: Optional[types.NestedArray] = None,
batch_size: Optional[types.Int] = 1):
"""Initialize instances of `BanditTFEnvironment`.
Args:
@@ -94,7 +100,7 @@ def _update_time_step(self, time_step):
self._time_step_variables, time_step)

@common.function()
def _current_time_step(self):
def _current_time_step(self) -> ts.TimeStep:
def true_fn():
return tf.nest.map_structure(tf.identity, self._time_step_variables)
def false_fn():
@@ -104,7 +110,7 @@ def false_fn():
return tf.cond(self._reset_called, true_fn, false_fn)

@common.function
def _reset(self):
def _reset(self) -> ts.TimeStep:
current_time_step = ts.restart(
self._observe(), batch_size=self.batch_size,
reward_spec=self.time_step_spec().reward)
@@ -113,16 +119,16 @@ def _reset(self):
return current_time_step

@common.function
def _step(self, action):
def _step(self, action: types.NestedArray) -> ts.TimeStep:
reward = self._apply_action(action)
current_time_step = ts.termination(self._observe(), reward)
self._update_time_step(current_time_step)
return current_time_step

@abc.abstractmethod
def _apply_action(self, action):
def _apply_action(self, action: types.NestedArray) -> types.Float:
"""Returns a reward for the given action."""

@abc.abstractmethod
def _observe(self):
def _observe(self) -> types.NestedTensor:
"""Returns an observation."""
2 changes: 0 additions & 2 deletions tf_agents/bandits/environments/bandit_tf_environment_test.py
@@ -26,7 +26,6 @@
from tf_agents.specs import tensor_spec
from tf_agents.trajectories import time_step as ts
from tf_agents.utils import common
from tensorflow.python.framework import test_util # pylint:disable=g-direct-tensorflow-import # TF internal


class ZerosEnvironment(bandit_tf_environment.BanditTFEnvironment):
@@ -73,7 +72,6 @@ def _observe(self):
return tf.zeros(observation_shape)


@test_util.run_all_in_graph_and_eager_modes
class BanditTFEnvironmentTest(tf.test.TestCase, parameterized.TestCase):

@parameterized.named_parameters(
tf_agents/bandits/environments/bernoulli_action_mask_tf_environment.py
@@ -38,15 +38,19 @@ def join_fn(context, mask):

from __future__ import absolute_import
from __future__ import division
# Using Type Annotations.
from __future__ import print_function

from typing import Callable

import gin
import tensorflow as tf # pylint: disable=g-explicit-tensorflow-version-import
import tensorflow_probability as tfp

from tf_agents.bandits.environments import bandit_tf_environment
from tf_agents.bandits.policies import policy_utilities
from tf_agents.trajectories import time_step as ts
from tf_agents.typing import types
from tf_agents.utils import common

tfd = tfp.distributions
@@ -70,8 +74,11 @@ class BernoulliActionMaskTFEnvironment(bandit_tf_environment.BanditTFEnvironment
):
"""An environment wrapper that adds action masks to observations."""

def __init__(self, original_environment, action_constraint_join_fn,
action_probability):
def __init__(self,
original_environment: bandit_tf_environment.BanditTFEnvironment,
action_constraint_join_fn: Callable[
[types.TensorSpec, types.TensorSpec], types.TensorSpec],
action_probability: float):
"""Initializes a `BernoulliActionMaskTFEnvironment`.
Args:
tf_agents/bandits/environments/bernoulli_action_mask_tf_environment_test.py
@@ -26,12 +26,10 @@
from tf_agents.bandits.environments import bernoulli_action_mask_tf_environment as masked_tf_env
from tf_agents.bandits.environments import random_bandit_environment
from tf_agents.specs import tensor_spec
from tensorflow.python.framework import test_util # pylint: disable=g-direct-tensorflow-import # TF internal

tfd = tfp.distributions


@test_util.run_all_in_graph_and_eager_modes
class BernoulliActionMaskTfEnvironmentTest(tf.test.TestCase,
parameterized.TestCase):

16 changes: 11 additions & 5 deletions tf_agents/bandits/environments/bernoulli_py_environment.py
@@ -14,11 +14,15 @@
# limitations under the License.

"""Class implementation of Python Bernoulli Bandit environment."""
# Using Type Annotations.
from typing import Optional, Sequence

import gin
import numpy as np

from tf_agents.bandits.environments import bandit_py_environment
from tf_agents.specs import array_spec
from tf_agents.typing import types


@gin.configurable
@@ -32,7 +36,9 @@ class BernoulliPyEnvironment(bandit_py_environment.BanditPyEnvironment):
Russo et al. (https://web.stanford.edu/~bvr/pubs/TS_Tutorial.pdf)
"""

def __init__(self, means, batch_size=1):
def __init__(self,
means: Sequence[types.Float],
batch_size: Optional[types.Int] = 1):
"""Initializes a Bernoulli Bandit environment.
Args:
@@ -56,18 +62,18 @@ def __init__(self, means, batch_size=1):
shape=(1,), dtype=np.int32, name='observation')
super(BernoulliPyEnvironment, self).__init__(observation_spec, action_spec)

def _observe(self):
def _observe(self) -> types.NestedArray:
return np.zeros(
shape=[self._batch_size] + list(self.observation_spec().shape),
dtype=self.observation_spec().dtype)

def _apply_action(self, action):
def _apply_action(self, action: types.NestedArray) -> types.Float:
return [np.floor(self._means[i] + np.random.random()) for i in action]

@property
def batched(self):
def batched(self) -> bool:
return True

@property
def batch_size(self):
def batch_size(self) -> types.Int:
return self._batch_size
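
A short usage sketch for the annotated `BernoulliPyEnvironment`; the arm means, batch size, and chosen actions below are arbitrary example values, not taken from the commit.

```python
import numpy as np

from tf_agents.bandits.environments import bernoulli_py_environment

# Three arms with success probabilities 0.1, 0.5 and 0.8; two parallel plays.
env = bernoulli_py_environment.BernoulliPyEnvironment(
    means=[0.1, 0.5, 0.8], batch_size=2)

time_step = env.reset()
# Both parallel plays pull arm 2; each reward is a Bernoulli(0.8) sample.
time_step = env.step(np.array([2, 2], dtype=np.int32))
print(time_step.reward)
```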
25 changes: 18 additions & 7 deletions tf_agents/bandits/environments/classification_environment.py
@@ -17,16 +17,21 @@

from __future__ import absolute_import
from __future__ import division
# Using Type Annotations.
from __future__ import print_function

from typing import Optional

import gin
import tensorflow as tf # pylint: disable=g-explicit-tensorflow-version-import
import tensorflow_probability as tfp
from tf_agents.bandits.environments import bandit_tf_environment as bte
from tf_agents.specs import tensor_spec
from tf_agents.trajectories import time_step
from tf_agents.typing import types
from tf_agents.utils import eager_utils


tfd = tfp.distributions


@@ -56,9 +61,15 @@ def _batched_table_lookup(tbl, row, col):
class ClassificationBanditEnvironment(bte.BanditTFEnvironment):
"""An environment based on an arbitrary classification problem."""

def __init__(self, dataset, reward_distribution, batch_size,
label_dtype_cast=None, shuffle_buffer_size=None,
repeat_dataset=True, prefetch_size=None, seed=None):
def __init__(self,
dataset: tf.data.Dataset,
reward_distribution: types.Distribution,
batch_size: types.Int,
label_dtype_cast: Optional[tf.DType] = None,
shuffle_buffer_size: Optional[types.Int] = None,
repeat_dataset: Optional[bool] = True,
prefetch_size: Optional[types.Int] = None,
seed: Optional[types.Int] = None):
"""Initialize `ClassificationBanditEnvironment`.
Args:
@@ -136,7 +147,7 @@ def __init__(self, dataset, reward_distribution, batch_size,
reward_means, axis=1, output_type=self._action_spec.dtype)
self._optimal_reward_table = tf.reduce_max(reward_means, axis=1)

def _observe(self):
def _observe(self) -> types.NestedTensor:
context, lbl = eager_utils.get_next(self._data_iterator)
self._previous_label.assign(self._current_label)
self._current_label.assign(tf.reshape(
@@ -145,16 +156,16 @@ def _observe(self):
context,
shape=[self._batch_size] + self._time_step_spec.observation.shape)

def _apply_action(self, action):
def _apply_action(self, action: types.NestedTensor) -> types.NestedTensor:
action = tf.reshape(
action, shape=[self._batch_size] + self._action_spec.shape)
reward_samples = self._reward_distribution.sample(tf.shape(action))
return _batched_table_lookup(reward_samples, self._current_label, action)

def compute_optimal_action(self):
def compute_optimal_action(self) -> types.NestedTensor:
return tf.gather(
params=self._optimal_action_table, indices=self._previous_label)

def compute_optimal_reward(self):
def compute_optimal_reward(self) -> types.NestedTensor:
return tf.gather(
params=self._optimal_reward_table, indices=self._previous_label)
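
As a usage sketch for the annotated constructor, the snippet below builds a toy two-class dataset and a deterministic reward table in the same style as `deterministic_reward_distribution` in the test file that follows. The dataset contents, sizes, and batch size are made up for illustration and are not part of the commit.

```python
import tensorflow as tf
import tensorflow_probability as tfp

from tf_agents.bandits.environments import classification_environment as ce

tfd = tfp.distributions

# Toy two-class problem: 100 two-dimensional contexts with random labels.
contexts = tf.random.normal([100, 2])
labels = tf.random.uniform([100], maxval=2, dtype=tf.int32)
dataset = tf.data.Dataset.from_tensor_slices((contexts, labels))

# Deterministic reward table: reward 1 for choosing the true label, else 0.
reward_distribution = tfd.Independent(
    tfd.Deterministic(tf.eye(2)), reinterpreted_batch_ndims=2)

env = ce.ClassificationBanditEnvironment(
    dataset=dataset, reward_distribution=reward_distribution, batch_size=4)

time_step = env.reset()  # Yields a batch of 4 observations of shape [2].
```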
tf_agents/bandits/environments/classification_environment_test.py
@@ -26,7 +26,6 @@
import tensorflow_probability as tfp

from tf_agents.bandits.environments import classification_environment as ce
from tensorflow.python.framework import test_util # pylint: disable=g-direct-tensorflow-import # TF internal

tfd = tfp.distributions

Expand All @@ -37,7 +36,6 @@ def deterministic_reward_distribution(reward_table):
reinterpreted_batch_ndims=2)


@test_util.run_all_in_graph_and_eager_modes
class ClassificationEnvironmentTest(tf.test.TestCase, parameterized.TestCase):

@parameterized.named_parameters(