Commit

Restore from latest checkpoint instead of the checkpoint at epoch beginning.

PiperOrigin-RevId: 485715882
anastasiyabl authored and Copybara-Service committed Nov 2, 2022
1 parent b8dfd6a commit 71f88dd
Showing 3 changed files with 22 additions and 11 deletions.
2 changes: 1 addition & 1 deletion deepconsensus/models/model_distillation.py
@@ -226,7 +226,7 @@ def compute_loss(labels: tf.Tensor, student_preds: tf.Tensor,

   # model, optimizer, and checkpoint must be created under `strategy.scope`.
   checkpoint, initial_epoch = model_utils.get_checkpoint_and_initial_epoch(
-      model, optimizer, epoch_checkpoint)  # pytype: disable=wrong-arg-types  # typed-keras
+      model, optimizer, False, out_dir, steps_per_epoch, epoch_checkpoint)  # pytype: disable=wrong-arg-types  # typed-keras
 
   # Create summary writers
   train_writer = tf.summary.create_file_writer(os.path.join(out_dir, 'train'))
2 changes: 1 addition & 1 deletion deepconsensus/models/model_train_custom_loop.py
@@ -129,7 +129,7 @@ def compute_loss(labels, predictions):

   # model, optimizer, and checkpoint must be created under `strategy.scope`.
   checkpoint, initial_epoch = model_utils.get_checkpoint_and_initial_epoch(
-      model, optimizer, epoch_checkpoint)  # pytype: disable=wrong-arg-types  # typed-keras
+      model, optimizer, False, out_dir, steps_per_epoch, epoch_checkpoint)  # pytype: disable=wrong-arg-types  # typed-keras
 
   # Create summary writers
   train_writer = tf.summary.create_file_writer(os.path.join(out_dir, 'train'))
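Both training entry points thread the new arguments through identically: `False` for the new `reload_from_epoch_start` parameter, plus the output directory and step count the function needs to locate the latest checkpoint and convert its step counter into an epoch. A minimal sketch of the updated call, assuming `model`, `optimizer`, `out_dir`, `steps_per_epoch`, and `epoch_checkpoint` are already defined by the surrounding training loop:

    # Sketch of the updated call site, not verbatim repository code.
    # reload_from_epoch_start=False selects the new behavior: resume from
    # the latest saved checkpoint rather than the epoch-start checkpoint.
    checkpoint, initial_epoch = model_utils.get_checkpoint_and_initial_epoch(
        model, optimizer, False, out_dir, steps_per_epoch, epoch_checkpoint)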
29 changes: 20 additions & 9 deletions deepconsensus/models/model_utils.py
@@ -454,20 +454,31 @@ def get_step_counts(params: ml_collections.ConfigDict,

 def get_checkpoint_and_initial_epoch(
     model: tf.keras.models.Model, optimizer: tf.keras.optimizers.Optimizer,
+    reload_from_epoch_start: bool, out_dir: str, steps_per_epoch: int,
     epoch_checkpoint: str) -> Tuple[tf.train.Checkpoint, int]:
   """Loads a checkpoint if available and sets epoch to start training."""
   initial_epoch = 0
   checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
-  if tf.io.gfile.exists(epoch_checkpoint):
-    with tf.io.gfile.GFile(epoch_checkpoint, 'r') as f:
-      epoch_checkpoint, initial_epoch = f.readline().split('\t')
-      initial_epoch = int(initial_epoch)
-      checkpoint.restore(epoch_checkpoint)
-      logging.info('Loading checkpoint %s for epoch %s', epoch_checkpoint,
-                   initial_epoch)
+  if reload_from_epoch_start:
+    # Load the checkpoint that corresponds to the beginning of the epoch.
+    # TODO.
+    if tf.io.gfile.exists(epoch_checkpoint):
+      with tf.io.gfile.GFile(epoch_checkpoint, 'r') as f:
+        epoch_checkpoint, initial_epoch = f.readline().split('\t')
+        initial_epoch = int(initial_epoch)
+        checkpoint.restore(epoch_checkpoint)
+        logging.info('Loading checkpoint %s for epoch %s', epoch_checkpoint,
+                     initial_epoch)
+    else:
+      logging.info('No Epoch checkpoint. Starting from epoch %s', initial_epoch)
+      initial_epoch = 0
   else:
-    logging.info('No Epoch checkpoint. Starting from epoch %s', initial_epoch)
-    initial_epoch = 0
+    # Load from the latest checkpoint if it exists.
+    latest_checkpoint = tf.train.latest_checkpoint(out_dir)
+    if latest_checkpoint:
+      checkpoint.restore(latest_checkpoint)
+      logging.info('Loaded checkpoint %s', latest_checkpoint)
+      initial_epoch = optimizer.iterations.numpy() // steps_per_epoch
   return checkpoint, initial_epoch


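The new `else` branch is the path both call sites above now take. A self-contained sketch of that resume logic, with hypothetical stand-ins for the directory, model, and step count (illustrative only, not repository code):

    # Minimal sketch of the latest-checkpoint resume path; out_dir,
    # steps_per_epoch, and the toy model are hypothetical stand-ins.
    import tensorflow as tf

    out_dir = '/tmp/ckpts'    # hypothetical checkpoint directory
    steps_per_epoch = 1000    # hypothetical; normally derived from the dataset

    model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
    optimizer = tf.keras.optimizers.Adam()
    checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)

    initial_epoch = 0
    # tf.train.latest_checkpoint returns the prefix of the newest checkpoint
    # in the directory, or None if none exists.
    latest_checkpoint = tf.train.latest_checkpoint(out_dir)
    if latest_checkpoint:
      checkpoint.restore(latest_checkpoint)
      # optimizer.iterations is saved and restored with the checkpoint and
      # counts every applied gradient step, so integer division by
      # steps_per_epoch recovers the epoch to resume from.
      initial_epoch = optimizer.iterations.numpy() // steps_per_epoch

Deriving `initial_epoch` from the restored `optimizer.iterations` is what makes the resume point track the latest save instead of the epoch boundary, which is the behavior change this commit describes.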
