[go: nahoru, domu]

Skip to content

Commit

Permalink
Internal change
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 377944468
  • Loading branch information
saberkun authored and tensorflower-gardener committed Jun 7, 2021
1 parent 1fa648a commit a1fd33c
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions official/nlp/modeling/ops/sampling_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,17 +431,17 @@ def _process_finished_state(

def _continue_search(self, state) -> tf.Tensor:
  """Decides whether the decoding loop should run another step.

  The search continues only while both conditions hold:
    * the current decode index is below `self.max_decode_length`, and
    * at least one entry of the finished flags is still False, i.e. not
      every sampled sequence has reached an EOS.

  Args:
    state: The decoding state; it is indexed with
      `decoding_module.StateKeys.CUR_INDEX` and
      `decoding_module.StateKeys.FINISHED_FLAGS`.

  Returns:
    A boolean tensor that is True when decoding should continue.
  """
  i = state[decoding_module.StateKeys.CUR_INDEX]
  # Have we reached max decoding length?
  not_at_end = tf.less(i, self.max_decode_length)
  # Have all sampled sequences reached an EOS?
  # axis=None reduces over every dimension, producing a single flag.
  all_has_eos = tf.reduce_all(
      state[decoding_module.StateKeys.FINISHED_FLAGS],
      axis=None,
      name="search_finish_cond")
  # Stop early once every sequence is finished, even if the length budget
  # has not been exhausted yet.
  return tf.logical_and(not_at_end, tf.logical_not(all_has_eos))

def _finished_flags(self, topk_ids, state) -> tf.Tensor:
  """Computes the updated finished flags after a decoding step.

  A position is flagged as finished when the newly chosen id equals
  `self.eos_id`, or when it was already flagged in the incoming state.

  Args:
    topk_ids: Ids selected at the current step; compared element-wise with
      `self.eos_id`.
    state: The decoding state; it is indexed with
      `decoding_module.StateKeys.FINISHED_FLAGS`.

  Returns:
    A boolean tensor holding the element-wise OR of the new EOS hits and the
    previous finished flags.
  """
  hit_eos_now = tf.equal(topk_ids, self.eos_id)
  already_finished = state[decoding_module.StateKeys.FINISHED_FLAGS]
  # Once finished, always finished: OR the new hits into the old flags.
  return tf.logical_or(hit_eos_now, already_finished)







0 comments on commit a1fd33c

Please sign in to comment.