[go: up, home]

Skip to content

Commit

Permalink
Make the Boltzmann-Gumbel agent XLA-compatible.
Browse files (browse the repository at this point in the history)
PiperOrigin-RevId: 407054947
Change-Id: I3b5d05045bf77a6307f2453a4809da5efeda4dce
  • Loading branch information
efiko authored and Copybara-Service committed Nov 2, 2021
1 parent db66f82 commit f2fbd29
Showing 1 changed file with 3 additions and 7 deletions.
10 changes: 3 additions & 7 deletions tf_agents/bandits/agents/greedy_reward_prediction_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,13 +244,9 @@ def _train(self, experience, weights):
if not self._accepts_per_arm_features and self._num_samples_list:
# Compute the number of samples for each action in the current batch.
actions_flattened = tf.reshape(experience.action, [-1])
num_samples_per_action_current = tf.unstack(
tf.reshape(
tf.math.bincount(
actions_flattened,
minlength=self._num_actions, maxlength=self._num_actions,
dtype=tf.int32),
[self._num_actions]), axis=-1)
num_samples_per_action_current = [
tf.reduce_sum(tf.cast(tf.equal(actions_flattened, k), tf.int32))
for k in range(self._num_actions)]
# Update the number of samples for each action.
for a, b in zip(self._num_samples_list, num_samples_per_action_current):
tf.compat.v1.assign_add(a, b)
Expand Down

0 comments on commit f2fbd29

Please sign in to comment.