# coding=utf-8
# Copyright 2020 The TF-Agents Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Network Outputting Expected Value and Variance of Rewards."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
from typing import Any, Callable, Optional, Sequence, Text
import gin
import tensorflow as tf # pylint: disable=g-explicit-tensorflow-version-import
from tf_agents.networks import encoding_network
from tf_agents.networks import network
from tf_agents.networks import q_network
from tf_agents.typing import types


class QBanditNetworkResult(
    collections.namedtuple(
        'QBanditNetworkResult', ('q_value_logits', 'log_variance')
    )
):
  """Container for Q-value logits and per-action log-variance predictions."""


@gin.configurable
class HeteroscedasticQNetwork(network.Network):
  """Network Outputting Expected Value and Variance of Rewards."""

  def __init__(
      self,
      input_tensor_spec: types.NestedTensorSpec,
      action_spec: types.NestedTensorSpec,
      preprocessing_layers: Optional[Callable[..., types.Tensor]] = None,
      preprocessing_combiner: Optional[Callable[..., types.Tensor]] = None,
      conv_layer_params: Optional[Sequence[Any]] = None,
      fc_layer_params: Sequence[int] = (75, 40),
      dropout_layer_params: Optional[Sequence[float]] = None,
      activation_fn: Callable[
          [types.Tensor], types.Tensor
      ] = tf.keras.activations.relu,
      kernel_initializer: Optional[tf.keras.initializers.Initializer] = None,
      batch_squash: bool = True,
      min_variance: float = 0.1,
      max_variance: float = 10000.0,
      dtype: tf.DType = tf.float32,
      name: Text = 'HeteroscedasticQNetwork',
  ):
"""Creates an instance of `HeteroscedasticQNetwork`.
Args:
input_tensor_spec: A nest of `tensor_spec.TensorSpec` representing the
input observations.
action_spec: A nest of `tensor_spec.BoundedTensorSpec` representing the
actions.
preprocessing_layers: (Optional.) A nest of `tf.keras.layers.Layer`
representing preprocessing for the different observations. All of these
layers must not be already built. For more details see the documentation
of `networks.EncodingNetwork`.
preprocessing_combiner: (Optional.) A keras layer that takes a flat list
of tensors and combines them. Good options include `tf.keras.layers.Add`
and `tf.keras.layers.Concatenate(axis=-1)`. This layer must not be
already built. For more details see the documentation of
`networks.EncodingNetwork`.
conv_layer_params: Optional list of convolution layers parameters, where
each item is a length-three tuple indicating (filters, kernel_size,
stride).
fc_layer_params: Optional list of fully_connected parameters, where each
item is the number of units in the layer.
dropout_layer_params: Optional list of dropout layer parameters, where
each item is the fraction of input units to drop. The dropout layers are
interleaved with the fully connected layers; there is a dropout layer
after each fully connected layer, except if the entry in the list is
None. This list must have the same length of fc_layer_params, or be
None.
activation_fn: Activation function, e.g. tf.keras.activations.relu.
kernel_initializer: Initializer to use for the kernels of the conv and
dense layers. If none is provided a default variance_scaling_initializer
batch_squash: If True the outer_ranks of the observation are squashed into
the batch dimension. This allow encoding networks to be used with
observations with shape [BxTx...].
min_variance: Float. The minimum allowed predicted variance. Predicted
variances less than this value will be clipped to this value.
max_variance: Float. The maximum allowed predicted variance. Predicted
variances greater than this value will be clipped to this value.
dtype: The dtype to use by the convolution and fully connected layers.
name: A string representing the name of the network.
Raises:
ValueError: If `input_tensor_spec` contains more than one observation. Or
if `action_spec` contains more than one action.
"""
    q_network.validate_specs(action_spec, input_tensor_spec)
    action_spec = tf.nest.flatten(action_spec)[0]
    num_actions = action_spec.maximum - action_spec.minimum + 1
    encoder_input_tensor_spec = input_tensor_spec

    # Shared encoder that embeds the (possibly preprocessed) observation.
    encoder = encoding_network.EncodingNetwork(
        encoder_input_tensor_spec,
        preprocessing_layers=preprocessing_layers,
        preprocessing_combiner=preprocessing_combiner,
        conv_layer_params=conv_layer_params,
        fc_layer_params=fc_layer_params,
        dropout_layer_params=dropout_layer_params,
        activation_fn=activation_fn,
        kernel_initializer=kernel_initializer,
        batch_squash=batch_squash,
        dtype=dtype,
    )

    # Linear head emitting one Q-value logit per action.
    q_value_layer = tf.keras.layers.Dense(
        num_actions,
        activation=None,
        kernel_initializer=tf.random_uniform_initializer(
            minval=-0.03, maxval=0.03
        ),
        bias_initializer=tf.constant_initializer(-0.2),
    )

    super(HeteroscedasticQNetwork, self).__init__(
        input_tensor_spec=input_tensor_spec, state_spec=(), name=name
    )

    self._encoder = encoder
    self._q_value_layer = q_value_layer

    # Linear head emitting the log of the predicted reward variance per
    # action.
    self._log_variance_layer = tf.keras.layers.Dense(
        num_actions,
        activation=None,
        kernel_initializer=tf.random_uniform_initializer(
            minval=-0.03, maxval=0.03
        ),
        dtype=dtype,
    )

    self._min_variance = min_variance
    self._max_variance = max_variance

  def call(self, observation, step_type=None, network_state=()):
    """Runs the given observation through the network.

    Args:
      observation: The observation to provide to the network.
      step_type: The step type for the given observation. See `StepType` in
        time_step.py.
      network_state: A state tuple to pass to the network, mainly used by
        RNNs.

    Returns:
      An instance of `QBanditNetworkResult`.
    """
    state, network_state = self._encoder(
        observation, step_type=step_type, network_state=network_state
    )

    # Clip in log space so the predicted variance stays within
    # [min_variance, max_variance].
    log_variance = tf.clip_by_value(
        self._log_variance_layer(state),
        tf.math.log(self._min_variance),
        tf.math.log(self._max_variance),
    )

    q_value_logits = self._q_value_layer(state)
    result = QBanditNetworkResult(q_value_logits, log_variance)
    return (result, network_state)
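

# The sketch below is illustrative usage, not part of the original module: it
# builds the network for a hypothetical bandit problem (an 8-dimensional float
# observation and 4 discrete actions; both specs are assumptions chosen for
# this example) and runs one forward pass.
if __name__ == '__main__':
  from tf_agents.specs import tensor_spec

  observation_spec = tensor_spec.TensorSpec([8], tf.float32)
  action_spec = tensor_spec.BoundedTensorSpec(
      [], tf.int32, minimum=0, maximum=3
  )
  net = HeteroscedasticQNetwork(observation_spec, action_spec)

  # A batch of two observations; the first call builds the layers.
  result, _ = net(tf.zeros([2, 8]))

  # result.q_value_logits has shape [2, 4]: one expected reward per action.
  # result.log_variance has the same shape; exponentiate to recover the
  # predicted variance, which is clipped to [min_variance, max_variance].
  predicted_variance = tf.exp(result.log_variance)
  print(result.q_value_logits.shape, predicted_variance.shape)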