Training improvements.
PiperOrigin-RevId: 473371952
danielecook authored and Copybara-Service committed Sep 9, 2022
1 parent dcd2080 commit 8b9268c
Showing 4 changed files with 62 additions and 15 deletions.
deepconsensus/models/model_distillation.py (15 changes: 14 additions & 1 deletion)
@@ -169,6 +169,7 @@ def train_model(teacher_model: tf.keras.Model, out_dir: str,
model_utils.print_model_summary(model, input_shape)
logging.info('Done building model.')
# Initialize student model from teacher based on model params.
epoch_checkpoint = os.path.join(out_dir, 'epoch_checkpoint.txt')
model = init_student_from_teacher(model, teacher_model, params)
optimizer = tf.keras.optimizers.Adam(learning_rate=params.learning_rate)
train_loss = tf.keras.metrics.Mean(name='train/loss')
@@ -214,7 +215,7 @@ def compute_loss(labels: tf.Tensor, student_preds: tf.Tensor,

# model, optimizer, and checkpoint must be created under `strategy.scope`.
checkpoint, initial_epoch = model_utils.get_checkpoint_and_initial_epoch(
model, optimizer, out_dir, steps_per_epoch) # pytype: disable=wrong-arg-types # typed-keras
model, optimizer, epoch_checkpoint) # pytype: disable=wrong-arg-types # typed-keras

# Create summary writers
train_writer = tf.summary.create_file_writer(os.path.join(out_dir, 'train'))
@@ -333,6 +334,7 @@ def distributed_eval_step(iterator):
with train_writer.as_default():
model_utils.log_and_save_metrics(
epoch=epoch,
num_epochs=params['num_epochs'],
step=step_train,
total_steps=steps_per_epoch,
optimizer=optimizer,
@@ -367,6 +369,7 @@ def distributed_eval_step(iterator):
with eval_writer.as_default():
model_utils.log_and_save_metrics(
epoch=epoch,
num_epochs=params['num_epochs'],
step=step_eval,
total_steps=steps_per_eval,
optimizer=optimizer,
@@ -375,6 +378,16 @@ def distributed_eval_step(iterator):
steps_per_second=eval_steps_per_second)
# Reset timer
train_time_start = datetime.datetime.now()
# At the end of an epoch, create a savepoint checkpoint
# which will be used to resume training in the event of preemption or
# crashes. Intermediate checkpoints can still be used to
# select the best checkpoint.
epoch_checkpoint_name = model_utils.save_checkpoint(
checkpoint, out_dir, [eval_loss] + eval_metrics,
write_checkpoint_metrics)
with tf.io.gfile.GFile(epoch_checkpoint, 'w') as f:
logging.info('Epoch checkpoint: %s %s', epoch_checkpoint_name, epoch + 1)
f.write(f'{epoch_checkpoint_name}\t{epoch}')


def train(teacher_model_dir: str,
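Note on the savepoint format introduced above: epoch_checkpoint.txt holds a single tab-separated record pointing at the latest end-of-epoch checkpoint. Below is a minimal sketch of the write/read round trip, with hypothetical helper names and an assumed checkpoint prefix; in the actual loop the prefix comes from model_utils.save_checkpoint.

from typing import Tuple

import tensorflow as tf


def write_epoch_savepoint(epoch_checkpoint: str, checkpoint_name: str,
                          epoch: int) -> None:
  # Record "<checkpoint prefix>\t<epoch>" so a restarted job knows both which
  # checkpoint to restore and where to resume.
  with tf.io.gfile.GFile(epoch_checkpoint, 'w') as f:
    f.write(f'{checkpoint_name}\t{epoch}')


def read_epoch_savepoint(epoch_checkpoint: str) -> Tuple[str, int]:
  # Mirrors the parsing done in model_utils.get_checkpoint_and_initial_epoch
  # further down in this commit.
  with tf.io.gfile.GFile(epoch_checkpoint, 'r') as f:
    checkpoint_name, epoch = f.readline().split('\t')
  return checkpoint_name, int(epoch)


# write_epoch_savepoint('/out/epoch_checkpoint.txt', '/out/checkpoint-3', 2)
# read_epoch_savepoint('/out/epoch_checkpoint.txt')  # -> ('/out/checkpoint-3', 2)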
deepconsensus/models/model_train_custom_loop.py (15 changes: 14 additions & 1 deletion)
@@ -99,6 +99,7 @@ def train_model(out_dir: str, params: ml_collections.ConfigDict,
else:
model = model_utils.get_model(params)
logging.info('Done building model.')
epoch_checkpoint = os.path.join(out_dir, 'epoch_checkpoint.txt')
optimizer = tf.keras.optimizers.Adam(learning_rate=params.learning_rate)
train_loss = tf.keras.metrics.Mean(name='train/loss')
train_metrics = model_utils.get_deepconsensus_metrics(name_prefix='train/')
@@ -119,7 +120,7 @@ def compute_loss(labels, predictions):

# model, optimizer, and checkpoint must be created under `strategy.scope`.
checkpoint, initial_epoch = model_utils.get_checkpoint_and_initial_epoch(
model, optimizer, out_dir, steps_per_epoch) # pytype: disable=wrong-arg-types # typed-keras
model, optimizer, epoch_checkpoint) # pytype: disable=wrong-arg-types # typed-keras

# Create summary writers
train_writer = tf.summary.create_file_writer(os.path.join(out_dir, 'train'))
@@ -206,6 +207,7 @@ def distributed_eval_step(iterator):
with train_writer.as_default():
model_utils.log_and_save_metrics(
epoch=epoch,
num_epochs=params['num_epochs'],
step=step_train,
total_steps=steps_per_epoch,
optimizer=optimizer,
@@ -240,6 +242,7 @@ def distributed_eval_step(iterator):
with eval_writer.as_default():
model_utils.log_and_save_metrics(
epoch=epoch,
num_epochs=params['num_epochs'],
step=step_eval,
total_steps=steps_per_eval,
optimizer=optimizer,
@@ -248,6 +251,16 @@ def distributed_eval_step(iterator):
steps_per_second=eval_steps_per_second)
# Reset timer
train_time_start = datetime.datetime.now()
# At the end of an epoch, create a savepoint checkpoint
# which will be used to resume training in the event of preemption or
# crashes. Intermediate checkpoints can still be used to
# select the best checkpoint.
epoch_checkpoint_name = model_utils.save_checkpoint(
checkpoint, out_dir, [eval_loss] + eval_metrics,
write_checkpoint_metrics)
with tf.io.gfile.GFile(epoch_checkpoint, 'w') as f:
logging.info('Epoch checkpoint: %s %s', epoch_checkpoint_name, epoch + 1)
f.write(f'{epoch_checkpoint_name}\t{epoch}')


def train(out_dir: str,
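The non-distillation loop gets the same end-of-epoch savepoint. For context, here is a sketch of the plain tf.train.Checkpoint calls this machinery builds on; model_utils.save_checkpoint is assumed to wrap something like checkpoint.save() and to return the resulting prefix, which is consistent with how its return value is restored elsewhere in this commit but is not spelled out in the diff.

import os

import tensorflow as tf

# Stand-in model/optimizer for the ones the loop builds under strategy.scope().
model = tf.keras.Sequential([tf.keras.layers.Dense(4)])
model.build((None, 8))
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)
checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)

out_dir = '/tmp/demo_ckpts'  # hypothetical output directory
# save() appends a counter and returns the full prefix, e.g.
# '/tmp/demo_ckpts/checkpoint-1'; a string like this is what the loops above
# write into epoch_checkpoint.txt.
saved_prefix = checkpoint.save(os.path.join(out_dir, 'checkpoint'))

# After a restart, the same objects are rebuilt and restored from that prefix.
checkpoint.restore(saved_prefix).expect_partial()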
deepconsensus/models/model_train_custom_loop_test.py (17 changes: 12 additions & 5 deletions)
@@ -67,13 +67,20 @@ def test_train_e2e(self, config_name):
eval_event_file = glob.glob(os.path.join(out_dir, 'eval/*event*'))
self.assertLen(train_event_file, 1)
self.assertLen(eval_event_file, 1)
checkpoint_files = glob.glob(os.path.join(out_dir, 'checkpoint*'))
# +2 here for checkpoint and checkpoint_metrics.tsv
self.assertLen(checkpoint_files, params.num_epochs * 2 + 2)
checkpoint_metrics = glob.glob(
os.path.join(out_dir, 'checkpoint_metrics.tsv'))
self.assertLen(checkpoint_metrics, 1)
checkpoint_files = glob.glob(os.path.join(out_dir, 'checkpoint*index'))
self.assertNotEmpty(checkpoint_files)
checkpoint_metrics = glob.glob(
os.path.join(out_dir, 'checkpoint_metrics.tsv'))
self.assertNotEmpty(checkpoint_metrics)
json_params = glob.glob(os.path.join(out_dir, 'params.json'))
self.assertLen(json_params, 1)
self.assertNotEmpty(json_params)
best_checkpoint = glob.glob(os.path.join(out_dir, 'best_checkpoint.txt'))
self.assertLen(best_checkpoint, 1)
self.assertNotEmpty(best_checkpoint)
epoch_checkpoint = glob.glob(os.path.join(out_dir, 'epoch_checkpoint.txt'))
self.assertNotEmpty(epoch_checkpoint)


if __name__ == '__main__':
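On the relaxed test assertions: each tf.train.Checkpoint.save() under the 'checkpoint' prefix typically adds an .index/.data pair, so counting '*index' files counts saved checkpoints without hard-coding how many epochs (plus the extra end-of-epoch savepoints) a given config produces. A small illustration with a hypothetical output directory:

import glob
import os

out_dir = '/tmp/demo_out'  # hypothetical test output directory

# Typical contents after a short run: a 'checkpoint' bookkeeping file,
# checkpoint-N.index / checkpoint-N.data-00000-of-00001 per saved checkpoint,
# plus checkpoint_metrics.tsv, best_checkpoint.txt and epoch_checkpoint.txt.
index_files = glob.glob(os.path.join(out_dir, 'checkpoint*index'))
print(f'{len(index_files)} checkpoint(s) found in {out_dir}')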
deepconsensus/models/model_utils.py (30 changes: 22 additions & 8 deletions)
@@ -381,15 +381,20 @@ def get_step_counts(params: ml_collections.ConfigDict,

def get_checkpoint_and_initial_epoch(
model: tf.keras.models.Model, optimizer: tf.keras.optimizers.Optimizer,
out_dir: str, steps_per_epoch: int) -> Tuple[tf.train.Checkpoint, int]:
epoch_checkpoint: str) -> Tuple[tf.train.Checkpoint, int]:
"""Loads a checkpoint if available and sets epoch to start training."""
checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
latest_checkpoint = tf.train.latest_checkpoint(out_dir)
initial_epoch = 0
if latest_checkpoint:
checkpoint.restore(latest_checkpoint)
logging.info('Loaded checkpoint %s', latest_checkpoint)
initial_epoch = optimizer.iterations.numpy() // steps_per_epoch
checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
if tf.io.gfile.exists(epoch_checkpoint):
with tf.io.gfile.GFile(epoch_checkpoint, 'r') as f:
epoch_checkpoint, initial_epoch = f.readline().split('\t')
initial_epoch = int(initial_epoch)
checkpoint.restore(epoch_checkpoint)
logging.info('Loading checkpoint %s for epoch %s', epoch_checkpoint,
initial_epoch)
else:
logging.info('No Epoch checkpoint. Starting from epoch %s', initial_epoch)
initial_epoch = 0
return checkpoint, initial_epoch


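Resume flow, sketched: the training loops above write epoch_checkpoint.txt after each epoch, and get_checkpoint_and_initial_epoch reads it back on startup. Only the call itself comes from this commit; the surrounding loop and run_one_epoch are assumptions about how initial_epoch is consumed.

# Inside the training script, under strategy.scope():
epoch_checkpoint = os.path.join(out_dir, 'epoch_checkpoint.txt')
checkpoint, initial_epoch = model_utils.get_checkpoint_and_initial_epoch(
    model, optimizer, epoch_checkpoint)

# Assumed continuation: skip epochs that finished before the preemption.
for epoch in range(initial_epoch, params['num_epochs']):
  run_one_epoch(epoch)  # hypothetical stand-in for the train/eval loop body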
@@ -399,7 +404,8 @@ def reset_all_metrics(metrics: List[tf.keras.metrics.Metric]) -> None:
metric.reset_states()


def log_and_save_metrics(epoch: int, step: int, total_steps: int,
def log_and_save_metrics(epoch: int, num_epochs: int, step: int,
total_steps: int,
optimizer: tf.keras.optimizers.Optimizer,
metrics: List[tf.keras.metrics.Metric], training: bool,
steps_per_second: float) -> None:
@@ -408,8 +414,16 @@ def log_and_save_metrics(epoch: int, step: int, total_steps: int,
'epoch: %d step: %d of %d metrics: %s', epoch, step, total_steps,
' '.join(f'{metric.name}= {metric.result()}' for metric in metrics))

overall_progress = optimizer.iterations.numpy() / (total_steps * num_epochs)


if training:
tf.summary.scalar('learning_rate', optimizer.lr, step=optimizer.iterations)
tf.summary.scalar('progress/epoch', epoch, step=optimizer.iterations)
tf.summary.scalar(
'progress/overall_progress',
overall_progress,
step=optimizer.iterations)
for metric in metrics:
tf.summary.scalar(metric.name, metric.result(), step=optimizer.iterations)
metric.reset_states()
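The overall_progress scalar added here is the fraction of planned optimizer steps completed, with total_steps being steps_per_epoch on the training path. A worked example under those assumptions:

steps_per_epoch = 1000   # total_steps passed in from the training loop
num_epochs = 5
completed_steps = 2500   # optimizer.iterations.numpy() at logging time

overall_progress = completed_steps / (steps_per_epoch * num_epochs)
print(overall_progress)  # 0.5: halfway through the 5-epoch run, mid-way
                         # through epoch 2 (0-indexed), plotted in TensorBoard
                         # alongside the progress/epoch scalar.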
