From 55ef05ec381606365fd79eb760c6ff00cdfbbcf8 Mon Sep 17 00:00:00 2001 From: danielecook Date: Thu, 12 Jan 2023 15:41:04 -0800 Subject: [PATCH] Internal Updates PiperOrigin-RevId: 501685566 --- deepconsensus/models/model_distillation.py | 9 +++++---- deepconsensus/models/model_train_custom_loop.py | 7 ++++--- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/deepconsensus/models/model_distillation.py b/deepconsensus/models/model_distillation.py index ffc827f..8961486 100644 --- a/deepconsensus/models/model_distillation.py +++ b/deepconsensus/models/model_distillation.py @@ -60,14 +60,12 @@ from ml_collections.config_flags import config_flags import tensorflow as tf + from deepconsensus.models import data_providers from deepconsensus.models import losses_and_metrics from deepconsensus.models import model_utils from deepconsensus.utils import dc_constants - -# pylint: disable=unused-import - FLAGS = flags.FLAGS config_flags.DEFINE_config_file('params', None, 'Training configuration.') _TEACHER_MODEL_DIR = flags.DEFINE_string( @@ -154,6 +152,10 @@ def train_model(teacher_model: tf.keras.Model, out_dir: str, """Trains the model under the given strategy and params.""" # Freeze config dict here to ensure it is hashable. params = ml_collections.FrozenConfigDict(params) + + if out_dir is None: + raise ValueError('--out_dir must be defined.') + model_utils.save_params_as_json(out_dir, params) train_dataset, eval_dataset = model_utils.get_datasets(params, strategy) train_iterator = iter(train_dataset) @@ -433,6 +435,5 @@ def main(unused_args=None): flags.mark_flags_as_required([ 'teacher_model_dir', 'params', - 'out_dir', ]) app.run(main) diff --git a/deepconsensus/models/model_train_custom_loop.py b/deepconsensus/models/model_train_custom_loop.py index f3452d0..afafe9e 100644 --- a/deepconsensus/models/model_train_custom_loop.py +++ b/deepconsensus/models/model_train_custom_loop.py @@ -36,7 +36,6 @@ time blaze run -c opt \ //learning/genomics/deepconsensus/models:model_train_custom_loop -- \ --params ${CONFIG} \ - --out_dir ${OUT_DIR} \ --xm_runlocal \ --alsologtostderr """ @@ -58,7 +57,6 @@ from deepconsensus.models import model_utils from deepconsensus.utils import dc_constants -# pylint: disable=unused-import g-import-not-at-top FLAGS = flags.FLAGS config_flags.DEFINE_config_file('params', None, 'Training configuration.') @@ -84,6 +82,10 @@ def train_model(out_dir: str, params: ml_collections.ConfigDict, """Trains the model under the given strategy and params.""" # Freeze config dict here to ensure it is hashable. params = ml_collections.FrozenConfigDict(params) + + if out_dir is None: + raise ValueError('--out_dir must be defined.') + model_utils.save_params_as_json(out_dir, params) train_dataset, eval_dataset = model_utils.get_datasets(params, strategy) train_iterator = iter(train_dataset) @@ -300,6 +302,5 @@ def main(unused_args=None): if __name__ == '__main__': flags.mark_flags_as_required([ 'params', - 'out_dir', ]) app.run(main)