diff --git a/deepconsensus/models/model_distillation.py b/deepconsensus/models/model_distillation.py
index 6dd895f..ffc827f 100644
--- a/deepconsensus/models/model_distillation.py
+++ b/deepconsensus/models/model_distillation.py
@@ -156,6 +156,8 @@ def train_model(teacher_model: tf.keras.Model, out_dir: str,
   params = ml_collections.FrozenConfigDict(params)
   model_utils.save_params_as_json(out_dir, params)
   train_dataset, eval_dataset = model_utils.get_datasets(params, strategy)
+  train_iterator = iter(train_dataset)
+  eval_iterator = iter(eval_dataset)
   steps_per_epoch, steps_per_eval = model_utils.get_step_counts(
       params, _EVAL_AND_LOG_EVERY_STEP.value)
   # Number of steps this model will train for.
@@ -172,7 +174,7 @@ def train_model(teacher_model: tf.keras.Model, out_dir: str,
     model_utils.print_model_summary(model, input_shape)
     logging.info('Done building model.')
     # Initialize student model from teacher based on model params.
-    epoch_checkpoint = os.path.join(out_dir, 'epoch_checkpoint.txt')
+    eval_checkpoint = os.path.join(out_dir, 'eval_checkpoint.txt')
     model = init_student_from_teacher(model, teacher_model, params)

     # Calculate the number of steps to decay the learning rate over.
@@ -225,8 +227,8 @@ def compute_loss(labels: tf.Tensor, student_preds: tf.Tensor,
       return losses_dict

     # model, optimizer, and checkpoint must be created under `strategy.scope`.
-    checkpoint, initial_epoch = model_utils.get_checkpoint_and_initial_epoch(
-        model, optimizer, False, out_dir, steps_per_epoch, epoch_checkpoint)  # pytype: disable=wrong-arg-types  # typed-keras
+    checkpoint, initial_epoch, initial_step_train = model_utils.get_checkpoint_and_initial_epoch(
+        model, optimizer, out_dir, eval_checkpoint)  # pytype: disable=wrong-arg-types  # typed-keras

   # Create summary writers
   train_writer = tf.summary.create_file_writer(os.path.join(out_dir, 'train'))
@@ -317,8 +319,6 @@ def distributed_eval_step(iterator):
   log_eval_steps = 3000
   if _EVAL_AND_LOG_EVERY_STEP.value:
     log_train_steps = 1
-  train_iterator = iter(train_dataset)
-  eval_iterator = iter(eval_dataset)

   # Decide the best checkpoint using main eval metric.
   max_main_eval_metric = 0.0
@@ -332,7 +332,7 @@ def distributed_eval_step(iterator):
   for epoch in range(initial_epoch, params['num_epochs']):
     logging.info('Starting to run epoch: %s', epoch)
     train_time_start = datetime.datetime.now()
-    for step_train in range(steps_per_epoch):
+    for step_train in range(initial_step_train, steps_per_epoch):
       distributed_train_step(train_iterator)
       # Log and reset train metrics.
       if optimizer.iterations % log_train_steps == 0:
@@ -365,6 +365,9 @@ def distributed_eval_step(iterator):
         checkpoint_name = model_utils.save_checkpoint(
             checkpoint, out_dir, [eval_loss] + eval_metrics,
             write_checkpoint_metrics)
+        with tf.io.gfile.GFile(eval_checkpoint, 'w') as f:
+          f.write(f'{checkpoint_name}\t{epoch}\t{step_train}')
+
         # Record the best checkpoint based on the main eval metric.
         main_eval_metric_val = float(main_eval_metric.result())
         if main_eval_metric_val >= max_main_eval_metric:
@@ -386,16 +389,8 @@ def distributed_eval_step(iterator):
             steps_per_second=eval_steps_per_second)
         # Reset timer
         train_time_start = datetime.datetime.now()
-    # At the end of an epoch, create a savepoint checkpoint
-    # which will be used to resume training in the event of preemption or
-    # crashes. Intermediate checkpoints can still be used to
-    # select the best checkpoint.
-    epoch_checkpoint_name = model_utils.save_checkpoint(
-        checkpoint, out_dir, [eval_loss] + eval_metrics,
-        write_checkpoint_metrics)
-    with tf.io.gfile.GFile(epoch_checkpoint, 'w') as f:
-      logging.info('Epoch checkpoint: %s %s', epoch_checkpoint_name, epoch + 1)
-      f.write(f'{epoch_checkpoint_name}\t{epoch}')
+
+    initial_step_train = 0


 def train(teacher_model_dir: str,
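Note on the resume bookkeeping introduced above: instead of writing a dedicated end-of-epoch savepoint, the trainer now records, after every eval, which checkpoint was just saved and at which (epoch, step) it was taken, as a single tab-separated line in eval_checkpoint.txt. Recovery granularity therefore improves from once per epoch to once per eval. A minimal standalone round-trip sketch of that sidecar format (not part of the patch; the file name is real, the values are made up):

    # Round-trip of the eval_checkpoint.txt sidecar: one line of the form
    # '<checkpoint_name>\t<epoch>\t<step_train>'. Values are illustrative.
    import os
    import tempfile

    out_dir = tempfile.mkdtemp()
    sidecar = os.path.join(out_dir, 'eval_checkpoint.txt')
    with open(sidecar, 'w') as f:    # mirrors the f.write(...) call above
      f.write('checkpoint-42\t3\t1499')
    with open(sidecar) as f:         # mirrors the parsing in model_utils
      name, epoch, step = f.readline().split('\t')
    assert (name, int(epoch), int(step)) == ('checkpoint-42', 3, 1499)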
diff --git a/deepconsensus/models/model_train_custom_loop.py b/deepconsensus/models/model_train_custom_loop.py
index d2cfd5c..f3452d0 100644
--- a/deepconsensus/models/model_train_custom_loop.py
+++ b/deepconsensus/models/model_train_custom_loop.py
@@ -86,6 +86,8 @@ def train_model(out_dir: str, params: ml_collections.ConfigDict,
   params = ml_collections.FrozenConfigDict(params)
   model_utils.save_params_as_json(out_dir, params)
   train_dataset, eval_dataset = model_utils.get_datasets(params, strategy)
+  train_iterator = iter(train_dataset)
+  eval_iterator = iter(eval_dataset)
   steps_per_epoch, steps_per_eval = model_utils.get_step_counts(
       params, _EVAL_AND_LOG_EVERY_STEP.value)
   # Number of steps this model will train for.
@@ -102,7 +104,7 @@ def train_model(out_dir: str, params: ml_collections.ConfigDict,
     else:
       model = model_utils.get_model(params)
     logging.info('Done building model.')
-    epoch_checkpoint = os.path.join(out_dir, 'epoch_checkpoint.txt')
+    eval_checkpoint = os.path.join(out_dir, 'eval_checkpoint.txt')

     # Calculate the number of steps to decay the learning rate over.
     # Usually this number is the total training steps. However, since we train
@@ -128,8 +130,8 @@ def compute_loss(labels, predictions):
           per_example_loss, global_batch_size=params.batch_size)

     # model, optimizer, and checkpoint must be created under `strategy.scope`.
-    checkpoint, initial_epoch = model_utils.get_checkpoint_and_initial_epoch(
-        model, optimizer, False, out_dir, steps_per_epoch, epoch_checkpoint)  # pytype: disable=wrong-arg-types  # typed-keras
+    checkpoint, initial_epoch, initial_step_train = model_utils.get_checkpoint_and_initial_epoch(
+        model, optimizer, out_dir, eval_checkpoint)  # pytype: disable=wrong-arg-types  # typed-keras

   # Create summary writers
   train_writer = tf.summary.create_file_writer(os.path.join(out_dir, 'train'))
@@ -188,8 +190,6 @@ def distributed_eval_step(iterator):
   log_eval_steps = 3000
   if _EVAL_AND_LOG_EVERY_STEP.value:
     log_train_steps = 1
-  train_iterator = iter(train_dataset)
-  eval_iterator = iter(eval_dataset)

   # Decide the best checkpoint using main eval metric.
   max_main_eval_metric = 0.0
@@ -203,7 +203,7 @@ def distributed_eval_step(iterator):
   for epoch in range(initial_epoch, params['num_epochs']):
     logging.info('Starting to run epoch: %s', epoch)
     train_time_start = datetime.datetime.now()
-    for step_train in range(1, steps_per_epoch + 1):
+    for step_train in range(initial_step_train, steps_per_epoch):
       distributed_train_step(train_iterator)
       # Log and reset train metrics.
       if optimizer.iterations % log_train_steps == 0:
@@ -236,6 +236,9 @@ def distributed_eval_step(iterator):
         checkpoint_name = model_utils.save_checkpoint(
             checkpoint, out_dir, [eval_loss] + eval_metrics,
             write_checkpoint_metrics)
+        with tf.io.gfile.GFile(eval_checkpoint, 'w') as f:
+          f.write(f'{checkpoint_name}\t{epoch}\t{step_train}')
+
         # Record the best checkpoint based on the main eval metric.
         main_eval_metric_val = float(main_eval_metric.result())
         if main_eval_metric_val >= max_main_eval_metric:
@@ -257,16 +260,8 @@ def distributed_eval_step(iterator):
             steps_per_second=eval_steps_per_second)
         # Reset timer
         train_time_start = datetime.datetime.now()
-    # At the end of an epoch, create a savepoint checkpoint
-    # which will be used to resume training in the event of preemption or
-    # crashes. Intermediate checkpoints can still be used to
-    # select the best checkpoint.
-    epoch_checkpoint_name = model_utils.save_checkpoint(
-        checkpoint, out_dir, [eval_loss] + eval_metrics,
-        write_checkpoint_metrics)
-    with tf.io.gfile.GFile(epoch_checkpoint, 'w') as f:
-      logging.info('Epoch checkpoint: %s %s', epoch_checkpoint_name, epoch + 1)
-      f.write(f'{epoch_checkpoint_name}\t{epoch}')
+
+    initial_step_train = 0


 def train(out_dir: str,
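Note on the loop change above: the inner loop moves from 1-based `range(1, steps_per_epoch + 1)` to 0-based `range(initial_step_train, steps_per_epoch)`, and `initial_step_train` is reset to 0 at the end of each epoch, so only the first (resumed) epoch is truncated. A toy sketch of that control flow (all names and numbers below are illustrative, not from the codebase):

    # Demonstrates the resume arithmetic of the new training loop.
    steps_per_epoch = 5
    num_epochs = 3

    def run(initial_epoch=0, initial_step_train=0):
      executed = []
      for epoch in range(initial_epoch, num_epochs):
        for step in range(initial_step_train, steps_per_epoch):
          executed.append((epoch, step))
        # Only the first resumed epoch is truncated; later epochs run in full.
        initial_step_train = 0
      return executed

    # A fresh run covers every (epoch, step) pair exactly once.
    assert len(run()) == num_epochs * steps_per_epoch
    # Resuming from a sidecar recorded at epoch 1, step 2 restarts at step 3.
    assert run(initial_epoch=1, initial_step_train=3)[0] == (1, 3)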
diff --git a/deepconsensus/models/model_train_custom_loop_test.py b/deepconsensus/models/model_train_custom_loop_test.py
index 6b98118..c98023b 100644
--- a/deepconsensus/models/model_train_custom_loop_test.py
+++ b/deepconsensus/models/model_train_custom_loop_test.py
@@ -79,8 +79,6 @@ def test_train_e2e(self, config_name):
     self.assertNotEmpty(json_params)
     best_checkpoint = glob.glob(os.path.join(out_dir, 'best_checkpoint.txt'))
     self.assertNotEmpty(best_checkpoint)
-    epoch_checkpoint = glob.glob(os.path.join(out_dir, 'epoch_checkpoint.txt'))
-    self.assertNotEmpty(epoch_checkpoint)


 if __name__ == '__main__':
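The test drops its assertions on epoch_checkpoint.txt because that file is no longer written. A hypothetical follow-up check (not in this patch) could instead verify the new sidecar once training has run at least one eval; `out_dir` below stands in for the test's output directory:

    # Hypothetical assertion sketch, not part of the patch.
    import glob
    import os

    out_dir = '/tmp/deepconsensus_test_out'  # placeholder test output dir
    matches = glob.glob(os.path.join(out_dir, 'eval_checkpoint.txt'))
    assert matches, 'eval_checkpoint.txt was not written'
    with open(matches[0]) as f:
      _, epoch, step = f.readline().split('\t')
    assert int(epoch) >= 0 and int(step) >= 0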
diff --git a/deepconsensus/models/model_utils.py b/deepconsensus/models/model_utils.py
index 2d9fd53..875eba3 100644
--- a/deepconsensus/models/model_utils.py
+++ b/deepconsensus/models/model_utils.py
@@ -454,32 +454,26 @@ def get_step_counts(params: ml_collections.ConfigDict,

 def get_checkpoint_and_initial_epoch(
     model: tf.keras.models.Model, optimizer: tf.keras.optimizers.Optimizer,
-    reload_from_epoch_start: bool, out_dir: str, steps_per_epoch: int,
-    epoch_checkpoint: str) -> Tuple[tf.train.Checkpoint, int]:
+    out_dir: str, eval_checkpoint: str) -> Tuple[tf.train.Checkpoint, int, int]:
   """Loads a checkpoint if available and sets epoch to start training."""
   initial_epoch = 0
+  initial_step_train = 0
   checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
-  if reload_from_epoch_start:
-    # Load the checkpoint that corresponds to the beginning of the epoch.
-    # TODO.
-    if tf.io.gfile.exists(epoch_checkpoint):
-      with tf.io.gfile.GFile(epoch_checkpoint, 'r') as f:
-        epoch_checkpoint, initial_epoch = f.readline().split('\t')
+  # Load from the latest checkpoint if it exists.
+  latest_checkpoint = tf.train.latest_checkpoint(out_dir)
+  if latest_checkpoint:
+    if tf.io.gfile.exists(eval_checkpoint):
+      with tf.io.gfile.GFile(eval_checkpoint, 'r') as f:
+        checkpoint.restore(latest_checkpoint)
+        eval_checkpoint, initial_epoch, initial_step_train = f.readline().split(
+            '\t')
         initial_epoch = int(initial_epoch)
-        checkpoint.restore(epoch_checkpoint)
-        logging.info('Loading checkpoint %s for epoch %s', epoch_checkpoint,
-                     initial_epoch)
-    else:
-      logging.info('No Epoch checkpoint. Starting from epoch %s', initial_epoch)
-      initial_epoch = 0
-  else:
-    # Load from the latest checkpoint if it exists.
-    latest_checkpoint = tf.train.latest_checkpoint(out_dir)
-    if latest_checkpoint:
-      checkpoint.restore(latest_checkpoint)
-      logging.info('Loaded checkpoint %s', latest_checkpoint)
-    initial_epoch = optimizer.iterations.numpy() // steps_per_epoch
-  return checkpoint, initial_epoch
+        initial_step_train = int(initial_step_train)
+        logging.info('Loaded checkpoint %s (%s) for epoch %s step %s',
+                     latest_checkpoint, eval_checkpoint, initial_epoch,
+                     initial_step_train)
+        initial_step_train += 1
+  return checkpoint, initial_epoch, initial_step_train


 def reset_all_metrics(metrics: List[tf.keras.metrics.Metric]) -> None:
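The reworked helper keys everything off tf.train.latest_checkpoint plus the sidecar, and resumes one step past the recorded position rather than inferring the epoch from optimizer.iterations. A simplified standalone sketch of just the bookkeeping half (assumptions: plain file I/O instead of tf.io.gfile, no tf.train.Checkpoint restore, and an illustrative function name):

    # Sketch of the resume-position logic; the real helper also restores
    # model/optimizer state from latest_checkpoint via tf.train.Checkpoint.
    import os

    def get_initial_position(eval_checkpoint_path, latest_checkpoint):
      initial_epoch, initial_step_train = 0, 0
      if latest_checkpoint and os.path.exists(eval_checkpoint_path):
        with open(eval_checkpoint_path) as f:
          _, epoch, step = f.readline().split('\t')
        initial_epoch, initial_step_train = int(epoch), int(step)
        # Resume one step past the step recorded at checkpoint-save time.
        initial_step_train += 1
      return initial_epoch, initial_step_train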