[go: nahoru, domu]

Skip to content

Commit

Permalink
Don't record if tape.could_possibly_record() returns False
Browse files Browse the repository at this point in the history
  • Loading branch information
davisyoshida committed Dec 13, 2020
1 parent 0eca5eb commit 544d53f
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def inner(*args, _checkpoint=False, _watch_vars=None, _force_seed=False, **kwarg
else:
seed = random.randint(1, 1<<31)

if _checkpoint:
if _checkpoint and tape.could_possibly_record():
if _watch_vars is None:
_watch_vars = []

Expand All @@ -41,11 +41,14 @@ def inner(*args, _checkpoint=False, _watch_vars=None, _force_seed=False, **kwarg
tf.random.set_seed(seed)

result = f(*args, **kwargs)

flat_result = nest.flatten(result)
# No idea what the point of this is but they do it in tf.custom_gradient so I'm doing it too
flat_result = [tf.identity(x) for x in flat_result]
output = nest.pack_sequence_as(result, flat_result)
del flat_inputs
del result
del unique_inputs
del unique_vars

def grad(*output_grads):
with tf.GradientTape() as g:
Expand Down

0 comments on commit 544d53f

Please sign in to comment.