[go: nahoru, domu]

Skip to content

Commit

Permalink
Silence some pytype errors.
Browse files (browse the repository at this point in the history)
PiperOrigin-RevId: 529115288
Change-Id: I793e047817eb559b675fbe69d64d05d31f3b1b9c
  • Loading branch information
rchen152 authored and Copybara-Service committed May 3, 2023
1 parent e315eac commit f7bfb6a
Show file tree
Hide file tree
Showing 19 changed files with 26 additions and 26 deletions.
Original file line number | Diff line number | Diff line change
Expand Up @@ -223,7 +223,7 @@ def _scalarize(self, transformed_multi_objectives: tf.Tensor) -> tf.Tensor:
{'weights': self._weights})
return tf.reduce_sum(transformed_multi_objectives * self._weights, axis=1)

def set_parameters(self, weights: tf.Tensor):
def set_parameters(self, weights: tf.Tensor): # pytype: disable=signature-mismatch # overriding-parameter-count-checks
"""Set the scalarization parameter of the LinearScalarizer.
Args:
Expand Down Expand Up @@ -280,7 +280,7 @@ def _scalarize(self, transformed_multi_objectives: tf.Tensor) -> tf.Tensor:
(transformed_multi_objectives - self._reference_point) * self._weights,
axis=-1)

def set_parameters(self, weights: tf.Tensor, reference_point: tf.Tensor):
def set_parameters(self, weights: tf.Tensor, reference_point: tf.Tensor): # pytype: disable=signature-mismatch # overriding-parameter-count-checks
"""Set the scalarization parameters for the ChebyshevScalarizer.
Args:
Expand Down Expand Up @@ -397,7 +397,7 @@ def _scalarize(self, transformed_multi_objectives: tf.Tensor) -> tf.Tensor:
transformed_multi_objectives.dtype.max),
axis=1)

def set_parameters(self, direction: tf.Tensor,
def set_parameters(self, direction: tf.Tensor, # pytype: disable=signature-mismatch # overriding-parameter-count-checks
transform_params: Dict[str, tf.Tensor]):
"""Set the scalarization parameters for the HyperVolumeScalarizer.
Expand Down
2 changes: 1 addition & 1 deletion tf_agents/bandits/policies/neural_linucb_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,7 @@ def _distribution(self, time_step, policy_state):
raise NotImplementedError(
'This policy outputs an action and not a distribution.')

def _action(self, time_step, policy_state, seed):
def _action(self, time_step, policy_state, seed): # pytype: disable=signature-mismatch # overriding-parameter-count-checks
observation = time_step.observation
if self.observation_and_action_constraint_splitter is not None:
observation, _ = self.observation_and_action_constraint_splitter(
Expand Down
2 changes: 1 addition & 1 deletion tf_agents/drivers/py_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def __init__(
self._max_episodes = max_episodes or np.inf
self._end_episode_on_boundary = end_episode_on_boundary

def run(
def run( # pytype: disable=signature-mismatch # overriding-parameter-count-checks
self,
time_step: ts.TimeStep,
policy_state: types.NestedArray = ()
Expand Down
2 changes: 1 addition & 1 deletion tf_agents/drivers/tf_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def __init__(
if not disable_tf_function:
self.run = common.function(self.run, autograph=True)

def run(
def run( # pytype: disable=signature-mismatch # overriding-parameter-count-checks
self, time_step: ts.TimeStep,
policy_state: types.NestedTensor = ()
) -> Tuple[ts.TimeStep, types.NestedTensor]:
Expand Down
2 changes: 1 addition & 1 deletion tf_agents/environments/atari_preprocessing_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def step(self, action):
self.ale.screen_value -= 2
return (self.get_observation(), reward, is_terminal, unused)

def render(self, mode):
def render(self, mode): # pytype: disable=signature-mismatch # overriding-parameter-count-checks
pass


Expand Down
2 changes: 1 addition & 1 deletion tf_agents/environments/tf_py_environment_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def action_spec(self):
def observation_spec(self):
return specs.ArraySpec([], np.int64, name='observation')

def render(self, mode):
def render(self, mode): # pytype: disable=signature-mismatch # overriding-parameter-count-checks
assert isinstance(mode, (str, Text)), 'Got: {}'.format(type(mode))
if mode == 'rgb_array':
return np.ones((4, 4, 3), dtype=np.uint8)
Expand Down
2 changes: 1 addition & 1 deletion tf_agents/policies/batched_py_policy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def __init__(self,
action_spec=action_spec,
policy_state_spec=policy_state_spec)

def _action(self,
def _action(self, # pytype: disable=signature-mismatch # overriding-parameter-count-checks
time_step: ts.TimeStep,
policy_state: types.NestedArray) -> ps.PolicyStep:
random_action = array_spec.sample_spec_nest(
Expand Down
2 changes: 1 addition & 1 deletion tf_agents/policies/epsilon_greedy_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def _get_epsilon(self):
else:
return self._epsilon

def _action(self, time_step, policy_state, seed):
def _action(self, time_step, policy_state, seed): # pytype: disable=signature-mismatch # overriding-parameter-count-checks
seed_stream = tfp.util.SeedStream(seed=seed, salt='epsilon_greedy')
greedy_action = self._greedy_policy.action(time_step, policy_state)
random_action = self._random_policy.action(time_step, (), seed_stream())
Expand Down
2 changes: 1 addition & 1 deletion tf_agents/policies/fixed_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def _get_policy_info_and_action(self, time_step):
self._action_value)
return policy_info, action

def _action(self, time_step, policy_state, seed):
def _action(self, time_step, policy_state, seed): # pytype: disable=signature-mismatch # overriding-parameter-count-checks
del seed
policy_info, action = self._get_policy_info_and_action(time_step)
return policy_step.PolicyStep(action, policy_state, policy_info)
Expand Down
2 changes: 1 addition & 1 deletion tf_agents/policies/gaussian_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def _create_normal_distribution(action_spec):
def _variables(self):
return self._wrapped_policy.variables()

def _action(self, time_step, policy_state, seed):
def _action(self, time_step, policy_state, seed): # pytype: disable=signature-mismatch # overriding-parameter-count-checks
seed_stream = tfp.util.SeedStream(seed=seed, salt='gaussian_noise')

action_step = self._wrapped_policy.action(time_step, policy_state,
Expand Down
2 changes: 1 addition & 1 deletion tf_agents/policies/greedy_policy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def __init__(self, distribution, time_step_spec, action_spec, name=None):
super(DistributionPolicy, self).__init__(
time_step_spec, action_spec, name=name)

def _action(self, time_step, policy_state, seed):
def _action(self, time_step, policy_state, seed): # pytype: disable=signature-mismatch # overriding-parameter-count-checks
raise NotImplementedError('Not implemented.')

def _distribution(self, time_step, policy_state):
Expand Down
2 changes: 1 addition & 1 deletion tf_agents/policies/ou_noise_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def _validate_action_spec(action_spec):
def _variables(self):
return self._wrapped_policy.variables()

def _action(self, time_step, policy_state, seed):
def _action(self, time_step, policy_state, seed): # pytype: disable=signature-mismatch # overriding-parameter-count-checks
seed_stream = tfp.util.SeedStream(seed=seed, salt='ou_noise')

def _create_ou_process(action_spec):
Expand Down
2 changes: 1 addition & 1 deletion tf_agents/policies/policy_info_updater_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def _update_info(self, step):
current_info.update(self._updater_fn(step))
return policy_step.PolicyStep(step.action, step.state, current_info)

def _action(self, time_step, policy_state, seed):
def _action(self, time_step, policy_state, seed): # pytype: disable=signature-mismatch # overriding-parameter-count-checks
action_step = self._wrapped_policy.action(time_step, policy_state, seed)
return self._update_info(action_step)

Expand Down
2 changes: 1 addition & 1 deletion tf_agents/policies/policy_info_updater_wrapper_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def __init__(self,
info_spec=info_spec,
name=name)

def _action(self, time_step, policy_state, seed):
def _action(self, time_step, policy_state, seed): # pytype: disable=signature-mismatch # overriding-parameter-count-checks
return policy_step.PolicyStep(tf.constant(1., shape=(1,)), policy_state,
{'test_info': tf.constant(2, shape=(1,))})

Expand Down
4 changes: 2 additions & 2 deletions tf_agents/policies/policy_saver_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,10 @@ def __init__(self):
step_type=(), reward=(), discount=(), observation=()),
action_spec=())

def _action(self, **kwargs):
def _action(self, **kwargs): # pytype: disable=signature-mismatch # overriding-parameter-count-checks
return policy_step.PolicyStep((), ())

def _distribution(self, **kwargs):
def _distribution(self, **kwargs): # pytype: disable=signature-mismatch # overriding-parameter-count-checks
raise NotImplementedError('_distribution has not been implemented.')


Expand Down
2 changes: 1 addition & 1 deletion tf_agents/policies/random_tf_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def __init__(self, time_step_spec: ts.TimeStep,
def _variables(self):
return []

def _action(self, time_step, policy_state, seed):
def _action(self, time_step, policy_state, seed): # pytype: disable=signature-mismatch # overriding-parameter-count-checks
observation_and_action_constraint_splitter = (
self.observation_and_action_constraint_splitter)

Expand Down
2 changes: 1 addition & 1 deletion tf_agents/policies/temporal_action_smoothing.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def _distribution(self, time_step, policy_state):
raise NotImplementedError(
'`distribution` not implemented for TemporalActionSmoothingWrapper.')

def _action(self, time_step, policy_state, seed):
def _action(self, time_step, policy_state, seed): # pytype: disable=signature-mismatch # overriding-parameter-count-checks
# Get action from the wrapped policy.
wrapped_policy_state, moving_average = policy_state
wrapped_policy_step = self._wrapped_policy.action(time_step,
Expand Down
4 changes: 2 additions & 2 deletions tf_agents/policies/temporal_action_smoothing_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,11 @@ def __init__(self, time_step_spec, action_spec):
policy_state_spec=action_spec,
)

def _action(self, time_step, policy_state, seed):
def _action(self, time_step, policy_state, seed): # pytype: disable=signature-mismatch # overriding-parameter-count-checks
actions = tf.nest.map_structure(lambda t: t + 1, policy_state)
return policy_step.PolicyStep(actions, actions, ())

def _distribution(self):
def _distribution(self): # pytype: disable=signature-mismatch # overriding-parameter-count-checks
return policy_step.PolicyStep(())


Expand Down
8 changes: 4 additions & 4 deletions tf_agents/policies/tf_policy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def __init__(self, init_var_value, var_scope, name=None):
def _variables(self):
return self._variables_list

def _action(self, time_step, policy_state, seed):
def _action(self, time_step, policy_state, seed): # pytype: disable=signature-mismatch # overriding-parameter-count-checks
return policy_step.PolicyStep(())

def _distribution(self, time_step, policy_state):
Expand All @@ -74,7 +74,7 @@ def __init__(self):
action_spec = tensor_spec.BoundedTensorSpec([1], tf.int32, 0, 1)
super(TFPolicyMismatchedDtypes, self).__init__(time_step_spec, action_spec)

def _action(self, time_step, policy_state, seed):
def _action(self, time_step, policy_state, seed): # pytype: disable=signature-mismatch # overriding-parameter-count-checks
# This action's dtype intentionally doesn't match action_spec's dtype.
return policy_step.PolicyStep(action=tf.constant([0], dtype=tf.int64))

Expand All @@ -95,7 +95,7 @@ def __init__(self):
super(TFPolicyMismatchedDtypesListAction,
self).__init__(time_step_spec, action_spec)

def _action(self, time_step, policy_state, seed):
def _action(self, time_step, policy_state, seed): # pytype: disable=signature-mismatch # overriding-parameter-count-checks
# This time, the action is a list where only the second dtype doesn't match.
return policy_step.PolicyStep(action=[
tf.constant([0], dtype=tf.int64),
Expand All @@ -108,7 +108,7 @@ def _distribution(self, time_step, policy_state):

class TfPassThroughPolicy(tf_policy.TFPolicy):

def _action(self, time_step, policy_state, seed):
def _action(self, time_step, policy_state, seed): # pytype: disable=signature-mismatch # overriding-parameter-count-checks
distributions = self._distribution(time_step, policy_state)
actions = tf.nest.map_structure(lambda d: d.sample(), distributions.action)
return policy_step.PolicyStep(actions, policy_state, ())
Expand Down

0 comments on commit f7bfb6a

Please sign in to comment.