Fix inconsistent training steps.
PiperOrigin-RevId: 496754412
Authored by the Genomics team in Google Health; committed by Copybara-Service on Dec 20, 2022.
Parent: 0023941, commit: 09d3660
Showing 4 changed files with 38 additions and 56 deletions.
deepconsensus/models/model_distillation.py (27 changes: 11 additions, 16 deletions)
@@ -156,6 +156,8 @@ def train_model(teacher_model: tf.keras.Model, out_dir: str,
   params = ml_collections.FrozenConfigDict(params)
   model_utils.save_params_as_json(out_dir, params)
   train_dataset, eval_dataset = model_utils.get_datasets(params, strategy)
+  train_iterator = iter(train_dataset)
+  eval_iterator = iter(eval_dataset)
   steps_per_epoch, steps_per_eval = model_utils.get_step_counts(
       params, _EVAL_AND_LOG_EVERY_STEP.value)
   # Number of steps this model will train for.
@@ -172,7 +174,7 @@ def train_model(teacher_model: tf.keras.Model, out_dir: str,
     model_utils.print_model_summary(model, input_shape)
   logging.info('Done building model.')
   # Initialize student model from teacher based on model params.
-  epoch_checkpoint = os.path.join(out_dir, 'epoch_checkpoint.txt')
+  eval_checkpoint = os.path.join(out_dir, 'eval_checkpoint.txt')
   model = init_student_from_teacher(model, teacher_model, params)

   # Calculate the number of steps to decay the learning rate over.
@@ -225,8 +227,8 @@ def compute_loss(labels: tf.Tensor, student_preds: tf.Tensor,
     return losses_dict

   # model, optimizer, and checkpoint must be created under `strategy.scope`.
-  checkpoint, initial_epoch = model_utils.get_checkpoint_and_initial_epoch(
-      model, optimizer, False, out_dir, steps_per_epoch, epoch_checkpoint)  # pytype: disable=wrong-arg-types  # typed-keras
+  checkpoint, initial_epoch, initial_step_train = model_utils.get_checkpoint_and_initial_epoch(
+      model, optimizer, out_dir, eval_checkpoint)  # pytype: disable=wrong-arg-types  # typed-keras

   # Create summary writers
   train_writer = tf.summary.create_file_writer(os.path.join(out_dir, 'train'))
@@ -317,8 +319,6 @@ def distributed_eval_step(iterator):
     log_eval_steps = 3000
   if _EVAL_AND_LOG_EVERY_STEP.value:
     log_train_steps = 1
-  train_iterator = iter(train_dataset)
-  eval_iterator = iter(eval_dataset)

   # Decide the best checkpoint using main eval metric.
   max_main_eval_metric = 0.0
@@ -332,7 +332,7 @@ def distributed_eval_step(iterator):
   for epoch in range(initial_epoch, params['num_epochs']):
     logging.info('Starting to run epoch: %s', epoch)
     train_time_start = datetime.datetime.now()
-    for step_train in range(steps_per_epoch):
+    for step_train in range(initial_step_train, steps_per_epoch):
       distributed_train_step(train_iterator)
       # Log and reset train metrics.
       if optimizer.iterations % log_train_steps == 0:
@@ -365,6 +365,9 @@ def distributed_eval_step(iterator):
         checkpoint_name = model_utils.save_checkpoint(
             checkpoint, out_dir, [eval_loss] + eval_metrics,
             write_checkpoint_metrics)
+        with tf.io.gfile.GFile(eval_checkpoint, 'w') as f:
+          f.write(f'{checkpoint_name}\t{epoch}\t{step_train}')
+
         # Record the best checkpoint based on the main eval metric.
         main_eval_metric_val = float(main_eval_metric.result())
         if main_eval_metric_val >= max_main_eval_metric:
@@ -386,16 +389,8 @@ def distributed_eval_step(iterator):
             steps_per_second=eval_steps_per_second)
         # Reset timer
         train_time_start = datetime.datetime.now()
-    # At the end of an epoch, create a savepoint checkpoint
-    # which will be used to resume training in the event of preemption or
-    # crashes. Intermediate checkpoints can still be used to
-    # select the best checkpoint.
-    epoch_checkpoint_name = model_utils.save_checkpoint(
-        checkpoint, out_dir, [eval_loss] + eval_metrics,
-        write_checkpoint_metrics)
-    with tf.io.gfile.GFile(epoch_checkpoint, 'w') as f:
-      logging.info('Epoch checkpoint: %s %s', epoch_checkpoint_name, epoch + 1)
-      f.write(f'{epoch_checkpoint_name}\t{epoch}')
+
+    initial_step_train = 0


 def train(teacher_model_dir: str,
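
The new eval_checkpoint.txt record is the core of this change: after every evaluation the loop writes a single tab-separated line (checkpoint name, epoch, step) that is later read back to resume mid-epoch. A minimal standalone sketch of that round trip follows; the helper names are hypothetical, plain open() stands in for tf.io.gfile, and this is not the DeepConsensus implementation.

import os
from typing import Optional, Tuple


def write_eval_record(eval_checkpoint: str, checkpoint_name: str, epoch: int,
                      step: int) -> None:
  """Writes the same tab-separated record as the diff above: name, epoch, step."""
  with open(eval_checkpoint, 'w') as f:
    f.write(f'{checkpoint_name}\t{epoch}\t{step}')


def read_eval_record(eval_checkpoint: str) -> Optional[Tuple[str, int, int]]:
  """Returns (checkpoint_name, initial_epoch, initial_step_train) or None."""
  if not os.path.exists(eval_checkpoint):
    return None
  with open(eval_checkpoint, 'r') as f:
    name, epoch, step = f.readline().split('\t')
  # Resume from the step after the one that was checkpointed, mirroring the
  # `initial_step_train += 1` in model_utils.py further down.
  return name, int(epoch), int(step) + 1

On restart, the caller would restore the named checkpoint (the real code uses tf.train.latest_checkpoint plus tf.train.Checkpoint.restore) and start the inner loop at the returned step.
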
deepconsensus/models/model_train_custom_loop.py (27 changes: 11 additions, 16 deletions)
@@ -86,6 +86,8 @@ def train_model(out_dir: str, params: ml_collections.ConfigDict,
   params = ml_collections.FrozenConfigDict(params)
   model_utils.save_params_as_json(out_dir, params)
   train_dataset, eval_dataset = model_utils.get_datasets(params, strategy)
+  train_iterator = iter(train_dataset)
+  eval_iterator = iter(eval_dataset)
   steps_per_epoch, steps_per_eval = model_utils.get_step_counts(
       params, _EVAL_AND_LOG_EVERY_STEP.value)
   # Number of steps this model will train for.
@@ -102,7 +104,7 @@ def train_model(out_dir: str, params: ml_collections.ConfigDict,
   else:
     model = model_utils.get_model(params)
   logging.info('Done building model.')
-  epoch_checkpoint = os.path.join(out_dir, 'epoch_checkpoint.txt')
+  eval_checkpoint = os.path.join(out_dir, 'eval_checkpoint.txt')

   # Calculate the number of steps to decay the learning rate over.
   # Usually this number is the total training steps. However, since we train
@@ -128,8 +130,8 @@ def compute_loss(labels, predictions):
         per_example_loss, global_batch_size=params.batch_size)

   # model, optimizer, and checkpoint must be created under `strategy.scope`.
-  checkpoint, initial_epoch = model_utils.get_checkpoint_and_initial_epoch(
-      model, optimizer, False, out_dir, steps_per_epoch, epoch_checkpoint)  # pytype: disable=wrong-arg-types  # typed-keras
+  checkpoint, initial_epoch, initial_step_train = model_utils.get_checkpoint_and_initial_epoch(
+      model, optimizer, out_dir, eval_checkpoint)  # pytype: disable=wrong-arg-types  # typed-keras

   # Create summary writers
   train_writer = tf.summary.create_file_writer(os.path.join(out_dir, 'train'))
@@ -188,8 +190,6 @@ def distributed_eval_step(iterator):
     log_eval_steps = 3000
   if _EVAL_AND_LOG_EVERY_STEP.value:
     log_train_steps = 1
-  train_iterator = iter(train_dataset)
-  eval_iterator = iter(eval_dataset)

   # Decide the best checkpoint using main eval metric.
   max_main_eval_metric = 0.0
@@ -203,7 +203,7 @@ def distributed_eval_step(iterator):
   for epoch in range(initial_epoch, params['num_epochs']):
     logging.info('Starting to run epoch: %s', epoch)
     train_time_start = datetime.datetime.now()
-    for step_train in range(1, steps_per_epoch + 1):
+    for step_train in range(initial_step_train, steps_per_epoch):
       distributed_train_step(train_iterator)
       # Log and reset train metrics.
       if optimizer.iterations % log_train_steps == 0:
@@ -236,6 +236,9 @@ def distributed_eval_step(iterator):
         checkpoint_name = model_utils.save_checkpoint(
             checkpoint, out_dir, [eval_loss] + eval_metrics,
             write_checkpoint_metrics)
+        with tf.io.gfile.GFile(eval_checkpoint, 'w') as f:
+          f.write(f'{checkpoint_name}\t{epoch}\t{step_train}')
+
         # Record the best checkpoint based on the main eval metric.
         main_eval_metric_val = float(main_eval_metric.result())
         if main_eval_metric_val >= max_main_eval_metric:
@@ -257,16 +260,8 @@ def distributed_eval_step(iterator):
             steps_per_second=eval_steps_per_second)
         # Reset timer
         train_time_start = datetime.datetime.now()
-    # At the end of an epoch, create a savepoint checkpoint
-    # which will be used to resume training in the event of preemption or
-    # crashes. Intermediate checkpoints can still be used to
-    # select the best checkpoint.
-    epoch_checkpoint_name = model_utils.save_checkpoint(
-        checkpoint, out_dir, [eval_loss] + eval_metrics,
-        write_checkpoint_metrics)
-    with tf.io.gfile.GFile(epoch_checkpoint, 'w') as f:
-      logging.info('Epoch checkpoint: %s %s', epoch_checkpoint_name, epoch + 1)
-      f.write(f'{epoch_checkpoint_name}\t{epoch}')
+
+    initial_step_train = 0


 def train(out_dir: str,
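
Before this commit, the two trainers numbered their steps differently: model_distillation.py looped over range(steps_per_epoch) while model_train_custom_loop.py looped over range(1, steps_per_epoch + 1). That mismatch is presumably the "inconsistent training steps" of the commit title, and it matters once step_train is persisted and reused for resuming. A toy illustration, not repository code:

steps_per_epoch = 5
print(list(range(steps_per_epoch)))         # old distillation loop: [0, 1, 2, 3, 4]
print(list(range(1, steps_per_epoch + 1)))  # old custom loop:       [1, 2, 3, 4, 5]
print(list(range(0, steps_per_epoch)))      # unified loop with initial_step_train = 0
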
deepconsensus/models/model_train_custom_loop_test.py (2 changes: 0 additions, 2 deletions)
@@ -79,8 +79,6 @@ def test_train_e2e(self, config_name):
     self.assertNotEmpty(json_params)
     best_checkpoint = glob.glob(os.path.join(out_dir, 'best_checkpoint.txt'))
     self.assertNotEmpty(best_checkpoint)
-    epoch_checkpoint = glob.glob(os.path.join(out_dir, 'epoch_checkpoint.txt'))
-    self.assertNotEmpty(epoch_checkpoint)


 if __name__ == '__main__':
deepconsensus/models/model_utils.py (38 changes: 16 additions, 22 deletions)
@@ -454,32 +454,26 @@ def get_step_counts(params: ml_collections.ConfigDict,

 def get_checkpoint_and_initial_epoch(
     model: tf.keras.models.Model, optimizer: tf.keras.optimizers.Optimizer,
-    reload_from_epoch_start: bool, out_dir: str, steps_per_epoch: int,
-    epoch_checkpoint: str) -> Tuple[tf.train.Checkpoint, int]:
+    out_dir: str, eval_checkpoint: str) -> Tuple[tf.train.Checkpoint, int, int]:
   """Loads a checkpoint if available and sets epoch to start training."""
   initial_epoch = 0
+  initial_step_train = 0
   checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
-  if reload_from_epoch_start:
-    # Load the checkpoint that corresponds to the beginning of the epoch.
-    # TODO.
-    if tf.io.gfile.exists(epoch_checkpoint):
-      with tf.io.gfile.GFile(epoch_checkpoint, 'r') as f:
-        epoch_checkpoint, initial_epoch = f.readline().split('\t')
+  # Load from the latest checkpoint if it exists.
+  latest_checkpoint = tf.train.latest_checkpoint(out_dir)
+  if latest_checkpoint:
+    if tf.io.gfile.exists(eval_checkpoint):
+      with tf.io.gfile.GFile(eval_checkpoint, 'r') as f:
+        checkpoint.restore(latest_checkpoint)
+        eval_checkpoint, initial_epoch, initial_step_train = f.readline().split(
+            '\t')
         initial_epoch = int(initial_epoch)
-      checkpoint.restore(epoch_checkpoint)
-      logging.info('Loading checkpoint %s for epoch %s', epoch_checkpoint,
-                   initial_epoch)
-    else:
-      logging.info('No Epoch checkpoint. Starting from epoch %s', initial_epoch)
-      initial_epoch = 0
-  else:
-    # Load from the latest checkpoint if it exists.
-    latest_checkpoint = tf.train.latest_checkpoint(out_dir)
-    if latest_checkpoint:
-      checkpoint.restore(latest_checkpoint)
-      logging.info('Loaded checkpoint %s', latest_checkpoint)
-      initial_epoch = optimizer.iterations.numpy() // steps_per_epoch
-  return checkpoint, initial_epoch
+        initial_step_train = int(initial_step_train)
+        logging.info('Loaded checkpoint %s (%s) for epoch %s step %s',
+                     latest_checkpoint, eval_checkpoint, initial_epoch,
+                     initial_step_train)
+        initial_step_train += 1
+  return checkpoint, initial_epoch, initial_step_train


 def reset_all_metrics(metrics: List[tf.keras.metrics.Metric]) -> None:
Expand Down
