Commit

Switch to using tf.int64 for holding the number of samples per action.
PiperOrigin-RevId: 407108824
Change-Id: I15648efdd3f2aaac1be69694a0ccca64e7fc1693
efiko authored and Copybara-Service committed Nov 2, 2021
1 parent f2fbd29 commit c060bf0
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion tf_agents/bandits/agents/greedy_reward_prediction_agent.py
@@ -245,7 +245,7 @@ def _train(self, experience, weights):
     # Compute the number of samples for each action in the current batch.
     actions_flattened = tf.reshape(experience.action, [-1])
     num_samples_per_action_current = [
-        tf.reduce_sum(tf.cast(tf.equal(actions_flattened, k), tf.int32))
+        tf.reduce_sum(tf.cast(tf.equal(actions_flattened, k), tf.int64))
         for k in range(self._num_actions)]
     # Update the number of samples for each action.
     for a, b in zip(self._num_samples_list, num_samples_per_action_current):

@@ -514,7 +514,7 @@ def testNumSamplesList(self):
     for k in range(3):
       num_samples_list.append(
           tf.compat.v2.Variable(
-              tf.zeros([], dtype=tf.int32), name='num_samples_{}'.format(k)))
+              tf.zeros([], dtype=tf.int64), name='num_samples_{}'.format(k)))
     agent = greedy_agent.GreedyRewardPredictionAgent(
         self._time_step_spec,
         self._action_spec,
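For context, here is a small, self-contained sketch (not part of the commit) of the pattern both hunks touch: per-action sample counters created as tf.int64 variables and incremented with int64 batch counts, presumably so a very long training run cannot overflow a 32-bit counter. The batch of actions and the standalone variable names are illustrative, not the agent's actual attributes.

import tensorflow as tf

num_actions = 3

# Running per-action sample counters, created with the int64 dtype used
# after this change (mirrors the test's num_samples_list construction).
num_samples_list = [
    tf.Variable(tf.zeros([], dtype=tf.int64), name='num_samples_{}'.format(k))
    for k in range(num_actions)
]

# Illustrative flattened batch of chosen actions.
actions_flattened = tf.constant([0, 2, 2, 1, 0, 2])

# Count occurrences of each action in the batch, cast to int64 so the
# result matches the counters' dtype (mirrors the _train hunk above).
num_samples_per_action_current = [
    tf.reduce_sum(tf.cast(tf.equal(actions_flattened, k), tf.int64))
    for k in range(num_actions)
]

# Accumulate the batch counts into the running totals.
for counter, batch_count in zip(num_samples_list, num_samples_per_action_current):
  counter.assign_add(batch_count)

print([int(c.numpy()) for c in num_samples_list])  # [2, 1, 3]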
