Populates dummy chosen arm features in `RewardPredictionBasePolicy._distribution`

Without this, the `distribution` method fails because the policy state spec expects the chosen arm features in the policy info, but the distribution step does not include them.

PiperOrigin-RevId: 474279805
Change-Id: I8caa7e7c178a45ab3739ac51d8b57c13137f7eac
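
For context on the failure described above, the root cause is a structure mismatch between the emitted policy info and its spec. Below is a minimal, hypothetical sketch using plain `tf.nest` (not the actual tf_agents spec objects) of why leaving the field as `()` fails while all-zero dummy values pass:

import tensorflow as tf

# Hypothetical stand-in for a per-arm policy info spec that contains a
# `chosen_arm_features` entry.
info_spec = {'chosen_arm_features': tf.TensorSpec([3], tf.float32)}

# All-zero dummy features have the same nested structure as the spec.
tf.nest.assert_same_structure({'chosen_arm_features': tf.zeros([3])}, info_spec)

# Leaving the field empty, as before this change, is a structure mismatch.
try:
  tf.nest.assert_same_structure({'chosen_arm_features': ()}, info_spec)
except (TypeError, ValueError) as e:
  print('Structure mismatch:', e)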
TF-Agents Team authored and Copybara-Service committed Sep 14, 2022
1 parent 966e564 commit 3cb0802
Showing 2 changed files with 38 additions and 2 deletions.
9 changes: 7 additions & 2 deletions tf_agents/bandits/policies/reward_prediction_base_policy.py
@@ -257,7 +257,12 @@ def _distribution(self, time_step, policy_state):
 
     if self._accepts_per_arm_features:
       # The actual action sampling hasn't happened yet, so we leave
-      # `log_probability` and `chosen_arm_features` empty.
+      # `log_probability` empty and set `chosen_arm_features` to dummy values of
+      # all zeros. We need to save dummy chosen arm features to make the
+      # returned policy step have the same structure as the policy state spec.
+      dummy_chosen_arm_features = tf.nest.map_structure(
+          lambda obs: tf.zeros_like(obs[:, 0, ...]),
+          time_step.observation[bandit_spec_utils.PER_ARM_FEATURE_KEY])
       policy_info = policy_utilities.PerArmPolicyInfo(
           log_probability=(),
           predicted_rewards_mean=(
@@ -267,7 +272,7 @@ def _distribution(self, time_step, policy_state):
           bandit_policy_type=(bandit_policy_values
                               if policy_utilities.InfoFields.BANDIT_POLICY_TYPE
                               in self._emit_policy_info else ()),
-          chosen_arm_features=())
+          chosen_arm_features=dummy_chosen_arm_features)
     else:
       # The actual action sampling hasn't happened yet, so we leave
       # `log_probability` empty.
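As a side note on the shape logic in the new code: each per-arm observation has shape [batch_size, num_actions, per_arm_feature_dim], and indexing out arm 0 before zeroing drops the per-arm axis. A small illustrative sketch (the 2x4x3 shape mirrors the test added below; variable names here are only for illustration):

import tensorflow as tf

batch_size, num_actions, arm_dim = 2, 4, 3
per_arm_obs = tf.random.uniform([batch_size, num_actions, arm_dim])

# Same pattern as `dummy_chosen_arm_features` above: select arm 0 to drop the
# num_actions axis, then replace the values with zeros.
dummy = tf.nest.map_structure(
    lambda obs: tf.zeros_like(obs[:, 0, ...]), per_arm_obs)

print(dummy.shape)  # (2, 3), i.e. [batch_size, per_arm_feature_dim]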
31 changes: 31 additions & 0 deletions tf_agents/bandits/policies/reward_prediction_policies_test.py
@@ -341,6 +341,37 @@ def testPerArmRewards(self, policy_class):
     self.assertAllEqual(p_info.predicted_rewards_mean[:, 0],
                         padded_p_info.predicted_rewards_mean[:, 0])
 
+  @test_cases()
+  def testPerArmPolicyDistribution(self, policy_class):
+    tf.compat.v1.set_random_seed(3000)
+    obs_spec = bandit_spec_utils.create_per_arm_observation_spec(2, 3, 4)
+    time_step_spec = ts.time_step_spec(obs_spec)
+    action_spec = tensor_spec.BoundedTensorSpec((), tf.int32, 0, 3)
+    reward_network = (
+        global_and_arm_feature_network.create_feed_forward_common_tower_network(
+            obs_spec, (4, 3), (3, 4), (4, 2)))
+
+    policy = policy_class(
+        time_step_spec,
+        action_spec,
+        reward_network=reward_network,
+        accepts_per_arm_features=True)
+    action_feature = tf.cast(
+        tf.reshape(tf.random.shuffle(tf.range(24)), shape=[2, 4, 3]),
+        dtype=tf.float32)
+    observations = {
+        bandit_spec_utils.GLOBAL_FEATURE_KEY:
+            tf.constant([[1, 2], [3, 4]], dtype=tf.float32),
+        bandit_spec_utils.PER_ARM_FEATURE_KEY:
+            action_feature
+    }
+    time_step = ts.restart(observations, batch_size=2)
+    distribution_step = policy.distribution(time_step)
+    # Initialize all variables
+    self.evaluate(tf.compat.v1.global_variables_initializer())
+    info = self.evaluate(distribution_step.info)
+    self.assertAllEqual(info.chosen_arm_features.shape, [2, 3])
+
   @test_cases()
   def testPerArmRewardsVariableNumActions(self, policy_class):
     tf.compat.v1.set_random_seed(3000)
