Commit a3b1942
Use common check function for literal comparisons and suppress warnings.
PiperOrigin-RevId: 538235389
Change-Id: I81255653866d133a8930ef6e657f8b853598debe
pwohlhart authored and Copybara-Service committed Jun 6, 2023
1 parent 006fca7 commit a3b1942
Showing 7 changed files with 66 additions and 31 deletions.
9 changes: 5 additions & 4 deletions tf_agents/agents/ppo/ppo_policy.py
@@ -32,6 +32,7 @@
 from tf_agents.trajectories import policy_step
 from tf_agents.trajectories import time_step as ts
 from tf_agents.typing import types
+from tf_agents.utils import common
 from tf_agents.utils import nest_utils
 from tf_agents.utils import tensor_normalizer
@@ -271,12 +272,12 @@ def _distribution(self, time_step, policy_state, training=False):
     policy_info = ()

     # Disable lint for TF arrays.
-    if (new_policy_state['actor_network_state'] is () and  # pylint: disable=literal-comparison
-        new_policy_state['value_network_state'] is ()):  # pylint: disable=literal-comparison
+    if (not common.safe_has_state(new_policy_state['actor_network_state']) and
+        not common.safe_has_state(new_policy_state['value_network_state'])):
       new_policy_state = ()
-    elif new_policy_state['value_network_state'] is ():  # pylint: disable=literal-comparison
+    elif not common.safe_has_state(new_policy_state['value_network_state']):
       del new_policy_state['value_network_state']
-    elif new_policy_state['actor_network_state'] is ():  # pylint: disable=literal-comparison
+    elif not common.safe_has_state(new_policy_state['actor_network_state']):
       del new_policy_state['actor_network_state']

     return policy_step.PolicyStep(distributions, new_policy_state, policy_info)
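Background on the pattern being replaced (not part of the diff): the old code compared network state against the empty-tuple literal with `is`, which happens to evaluate as intended in CPython because the empty tuple is a cached singleton, but identity comparison against a literal is an implementation detail and emits a SyntaxWarning on Python 3.8+. The more natural `x in (None, (), [])` is avoided because autograph can turn an `in` check on a tensor into a tensor boolean expression (b/158804957). A minimal sketch of the warning that common.py now filters:

    import warnings

    # Compiling an "is" comparison against a literal emits the SyntaxWarning
    # that common.py suppresses below (Python >= 3.8).
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter('always')
        compile("state is ()", "<example>", "eval")

    print(caught[0].category.__name__, caught[0].message)
    # SyntaxWarning "is" with a literal. Did you mean "=="?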
6 changes: 2 additions & 4 deletions tf_agents/bandits/policies/greedy_multi_objective_neural_policy.py
@@ -37,6 +37,7 @@
 from tf_agents.trajectories import policy_step
 from tf_agents.trajectories import time_step as ts
 from tf_agents.typing import types
+from tf_agents.utils import common


 @tf.function
@@ -235,10 +236,7 @@ def _predict(
       ValueError: If the output size of any objective network does not match the
         expected number of actions.
     """
-    # TODO(b/158804957): Use literal comparison because in some strange cases
-    # (tf.function? autograph?) the expression "x in (None, (), [])" gets
-    # converted to a tensor.
-    if policy_state is None or policy_state is () or policy_state is []:  # pylint: disable=literal-comparison
+    if not common.safe_has_state(policy_state):
       policy_state = [()] * self._num_objectives
     predicted_objective_values = []
     updated_policy_state = []
7 changes: 3 additions & 4 deletions tf_agents/keras_layers/rnn_wrapper.py
@@ -20,6 +20,8 @@
 from __future__ import print_function

 import tensorflow as tf
+from tf_agents.utils import common
+


 __all__ = ['RNNWrapper']
@@ -167,11 +169,8 @@ def call(self, inputs, initial_state=None, mask=None, training=False):
     inputs_flat = [tf.expand_dims(t, axis=1) for t in inputs_flat]
     inputs = tf.nest.pack_sequence_as(inputs, inputs_flat)

-    # TODO(b/158804957): tf.function changes "if tensor:" to tensor bool expr.
-    # pylint: disable=literal-comparison
-    if initial_state is None or initial_state is () or initial_state is []:
+    if not common.safe_has_state(initial_state):
       initial_state = self._layer.get_initial_state(inputs)
-    # pylint: enable=literal-comparison

     outputs = self._layer(
         inputs, initial_state=initial_state, mask=mask, training=training)
8 changes: 5 additions & 3 deletions tf_agents/networks/sequential.py
@@ -27,6 +27,7 @@
 from tf_agents.keras_layers import rnn_wrapper
 from tf_agents.networks import network
 from tf_agents.typing import types
+from tf_agents.utils import common


 def _infer_state_specs(
@@ -203,12 +204,13 @@ def call(self, inputs, network_state=(), **kwargs):

       input_state = maybe_network_state

-      # pylint: disable=literal-comparison
       if maybe_network_state is None:
         input_state = layer.get_initial_state(inputs)
-      elif input_state is not () and self._layer_state_is_list[i]:
+      elif (
+          common.safe_has_state(input_state)
+          and self._layer_state_is_list[i]
+      ):
         input_state = list(input_state)
-      # pylint: enable=literal-comparison

       outputs = layer(inputs, input_state, **layer_kwargs)
       inputs, next_state = outputs
46 changes: 34 additions & 12 deletions tf_agents/policies/qtopt_cem_policy.py
@@ -33,6 +33,7 @@
 from tf_agents.trajectories import policy_step
 from tf_agents.trajectories import time_step as ts
 from tf_agents.typing import types
+from tf_agents.utils import common
 from tf_agents.utils import nest_utils

 try:
@@ -320,23 +321,44 @@ def init_best_actions(action_spec):
     best_scores = tf.zeros([batch_size, self._num_elites], dtype=tf.float32)

     # Run the while loop for CEM in-graph.
+    mean_shape = tf.nest.map_structure(
+        lambda m: [None] + m.get_shape()[1:], mean
+    )
+    var_shape = tf.nest.map_structure(lambda v: [None] + v.get_shape()[1:], var)
+    best_action_shape = tf.nest.map_structure(
+        lambda a: [None] + a.get_shape()[1:], best_actions
+    )
+    elites_shape = tf.TensorShape([None, self._num_elites])
+    policy_state_shape = ()
+    if common.safe_has_state(policy_state):
+      policy_state_shape = tf.nest.map_structure(
+          lambda state: state.get_shape(), policy_state
+      )
+
     _, _, _, _, best_actions, best_scores, best_next_policy_state = (
         tf.while_loop(
             cond=cond,
             body=body,
-            loop_vars=[mean, var, 0, iters, best_actions, best_scores,
-                       policy_state],
+            loop_vars=[
+                mean,
+                var,
+                0,
+                iters,
+                best_actions,
+                best_scores,
+                policy_state,
+            ],
             shape_invariants=[
-                tf.nest.map_structure(
-                    lambda m: [None] + m.get_shape()[1:], mean),
-                tf.nest.map_structure(
-                    lambda v: [None] + v.get_shape()[1:], var),
-                tf.TensorShape(()), iters.get_shape(),
-                tf.nest.map_structure(
-                    lambda a: [None] + a.get_shape()[1:], best_actions),
-                tf.TensorShape([None, self._num_elites]),
-                () if policy_state is () else tf.nest.map_structure(  # pylint: disable=literal-comparison
-                    lambda state: state.get_shape(), policy_state)]))
+                mean_shape,
+                var_shape,
+                tf.TensorShape(()),
+                iters.get_shape(),
+                best_action_shape,
+                elites_shape,
+                policy_state_shape,
+            ],
+        )
+    )

     if outer_rank == 2:
       best_actions = tf.nest.map_structure(
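For context on the restructuring above: `tf.while_loop` needs a shape invariant for every loop variable whose shape can change between iterations, and `[None] + shape[1:]` leaves the leading dimension unconstrained. The commit hoists those invariants into named variables and routes the policy-state case through `common.safe_has_state`. A minimal, self-contained sketch of the idiom with toy values (not code from this commit):

    import tensorflow as tf

    x = tf.zeros([1, 2])
    i = tf.constant(0)

    # The leading dimension of x doubles every iteration, so its invariant
    # must be [None, 2] rather than its initial static shape [1, 2].
    _, x = tf.while_loop(
        cond=lambda i, x: i < 3,
        body=lambda i, x: (i + 1, tf.concat([x, x], axis=0)),
        loop_vars=[i, x],
        shape_invariants=[i.get_shape(), tf.TensorShape([None, 2])],
    )
    print(x.shape)  # (8, 2)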
5 changes: 1 addition & 4 deletions tf_agents/policies/tf_policy.py
@@ -309,10 +309,7 @@ def action(self,
         time_step,
         self._time_step_spec,
         message='time_step and time_step_spec structures do not match')
-    # TODO(b/158804957): Use literal comparison because in some strange cases
-    # (tf.function? autograph?) the expression "x not in (None, (), [])" gets
-    # converted to a tensor.
-    if not (policy_state is None or policy_state is () or policy_state is []):  # pylint: disable=literal-comparison
+    if common.safe_has_state(policy_state):
       nest_utils.assert_same_structure(
           policy_state,
           self._policy_state_spec,
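At call sites like this one, the structure check only runs when a real state is present. Illustrative semantics, per the helper's docstring in common.py below (hypothetical values):

    common.safe_has_state(None)  # False: no state, skip the structure check
    common.safe_has_state(())    # False: stateless policy
    common.safe_has_state({'h': tf.zeros([1, 8])})  # True: validate structure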
16 changes: 16 additions & 0 deletions tf_agents/utils/common.py
@@ -26,6 +26,7 @@
 import importlib
 import os
 from typing import Dict, Optional, Text
+import warnings

 from absl import logging
@@ -1436,6 +1437,21 @@ def deduped_network_variables(network, *args):
   return [v for v in network.variables if v not in other_vars]


+# Filter warnings about comparing literals with "is" (b/229309809).
+warnings.filterwarnings(
+    'ignore',
+    '"is" with a literal. Did you mean "=="?',
+    category=SyntaxWarning,
+    module=__name__,
+)
+warnings.filterwarnings(
+    'ignore',
+    '"is not" with a literal. Did you mean "!="?',
+    category=SyntaxWarning,
+    module=__name__,
+)
+
+
 def safe_has_state(state):
   """Safely checks `state not in (None, (), [])`."""
   # TODO(b/158804957): tf.function changes "s in ((),)" to a tensor bool expr.
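The body of `safe_has_state` is cut off in this view. Given its docstring and the TODO, a plausible sketch (an assumption, not the verbatim function) mirrors the literal-`is` pattern it consolidates from the call sites:

    def safe_has_state(state):
      """Safely checks `state not in (None, (), [])`."""
      # Identity comparisons keep this a Python bool even under tf.function;
      # the module-level filters above silence the resulting SyntaxWarnings.
      # (Sketch only: assumes the same literal-`is` checks as the old call sites.)
      return state is not None and state is not () and state is not []  # pylint: disable=literal-comparison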
