[go: up, home]

Skip to content

Commit

Permalink
Internal Updates
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 501685566
  • Loading branch information
danielecook authored and Copybara-Service committed Jan 12, 2023
1 parent d9d709c commit 55ef05e
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 7 deletions.
9 changes: 5 additions & 4 deletions deepconsensus/models/model_distillation.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,14 +60,12 @@
from ml_collections.config_flags import config_flags
import tensorflow as tf


from deepconsensus.models import data_providers
from deepconsensus.models import losses_and_metrics
from deepconsensus.models import model_utils
from deepconsensus.utils import dc_constants


# pylint: disable=unused-import

FLAGS = flags.FLAGS
config_flags.DEFINE_config_file('params', None, 'Training configuration.')
_TEACHER_MODEL_DIR = flags.DEFINE_string(
Expand Down Expand Up @@ -154,6 +152,10 @@ def train_model(teacher_model: tf.keras.Model, out_dir: str,
"""Trains the model under the given strategy and params."""
# Freeze config dict here to ensure it is hashable.
params = ml_collections.FrozenConfigDict(params)

if out_dir is None:
raise ValueError('--out_dir must be defined.')

model_utils.save_params_as_json(out_dir, params)
train_dataset, eval_dataset = model_utils.get_datasets(params, strategy)
train_iterator = iter(train_dataset)
Expand Down Expand Up @@ -433,6 +435,5 @@ def main(unused_args=None):
flags.mark_flags_as_required([
'teacher_model_dir',
'params',
'out_dir',
])
app.run(main)
7 changes: 4 additions & 3 deletions deepconsensus/models/model_train_custom_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
time blaze run -c opt \
//learning/genomics/deepconsensus/models:model_train_custom_loop -- \
--params ${CONFIG} \
--out_dir ${OUT_DIR} \
--xm_runlocal \
--alsologtostderr
"""
Expand All @@ -58,7 +57,6 @@
from deepconsensus.models import model_utils
from deepconsensus.utils import dc_constants

# pylint: disable=unused-import g-import-not-at-top

FLAGS = flags.FLAGS
config_flags.DEFINE_config_file('params', None, 'Training configuration.')
Expand All @@ -84,6 +82,10 @@ def train_model(out_dir: str, params: ml_collections.ConfigDict,
"""Trains the model under the given strategy and params."""
# Freeze config dict here to ensure it is hashable.
params = ml_collections.FrozenConfigDict(params)

if out_dir is None:
raise ValueError('--out_dir must be defined.')

model_utils.save_params_as_json(out_dir, params)
train_dataset, eval_dataset = model_utils.get_datasets(params, strategy)
train_iterator = iter(train_dataset)
Expand Down Expand Up @@ -300,6 +302,5 @@ def main(unused_args=None):
if __name__ == '__main__':
flags.mark_flags_as_required([
'params',
'out_dir',
])
app.run(main)

0 comments on commit 55ef05e

Please sign in to comment.