Update tf version to >= 2.9 and remove remaining dependencies on tensorflow_models.official.legacy.transformer

PiperOrigin-RevId: 458027974
anastasiyabl authored and Copybara-Service committed Jun 29, 2022
1 parent 83007b0 commit f25f44e
Showing 18 changed files with 168 additions and 612 deletions.
393 changes: 0 additions & 393 deletions deepconsensus/models/legacy_networks.py

This file was deleted.

118 changes: 0 additions & 118 deletions deepconsensus/models/legacy_networks_test.py

This file was deleted.

2 changes: 1 addition & 1 deletion deepconsensus/models/losses_and_metrics_test.py
@@ -614,7 +614,7 @@ def test_distillation_loss_fn(self, batch_size, window_length, temperature,
distill_loss = distill_loss + kl_ij
# Get the distillation loss over the whole window.
distill_loss = distill_loss / window_length
self.assertAlmostEqual(distill_loss, expected_loss[example_ind])
self.assertAlmostEqual(distill_loss, expected_loss[example_ind], places=6)


if __name__ == '__main__':
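For context on the places=6 change above: unittest's assertAlmostEqual rounds the difference between the two values to the given number of decimal places (7 by default) and checks that it is zero, so places=6 simply tolerates slightly more floating-point accumulation error in the distillation loss. A minimal, self-contained sketch of that behavior, with illustrative values not taken from this test:

import unittest


class ToleranceDemo(unittest.TestCase):

  def test_places_controls_tolerance(self):
    expected = 0.1234567
    computed = expected + 4e-7  # illustrative accumulation drift
    # round(4e-7, 6) == 0, so this passes with places=6 but would fail
    # with the default places=7.
    self.assertAlmostEqual(computed, expected, places=6)


if __name__ == '__main__':
  unittest.main()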
78 changes: 10 additions & 68 deletions deepconsensus/models/model_configs.py
@@ -65,46 +65,11 @@ def _set_base_fc_hparams(params):
params.buffer_size = 1000


def _set_base_transformer_v2_hparams(params):
"""Updates given config with base values for the Transformer model."""
# Architecture
params.model_name = 'transformer_v2'
params.add_pos_encoding = True
# Num heads should be divisible by hidden size. This value should be tuned for
# the production setting. TODO: update this parameter after
# tuning.
params.num_heads = 2
params.layer_norm = False
params.dtype = dc_constants.TF_DATA_TYPE
params.condense_transformer_input = False
params.transformer_model_size = 'base'

params.num_channels = 1
params.use_bases = True
params.use_pw = True
params.use_ip = True
params.use_ccs = True
params.use_strand = True
params.use_sn = True
params.per_base_hidden_size = 1
params.pw_hidden_size = 1
params.ip_hidden_size = 1
params.sn_hidden_size = 1
params.strand_hidden_size = 1

# Training
params.batch_size = 256
params.num_epochs = 50
params.learning_rate = 1e-4
params.buffer_size = 1000


def _set_base_transformer_hparams(params):
"""Updates given config with base values for the Transformer model."""
# Architecture
params.model_name = 'transformer'
params.add_pos_encoding = True
params.use_relative_pos_enc = True
# Num heads should be divisible by hidden size. This value should be tuned for
# the production setting. TODO: update this parameter after
# tuning.
@@ -135,28 +100,11 @@ def _set_base_transformer_hparams(params):


def _set_transformer_learned_embeddings_hparams(params):
"""Updates given config with values for the learned embeddings transformer."""
_set_base_transformer_hparams(params)
params.model_name = 'transformer_learn_values'
params.PW_MAX = dc_constants.PW_MAX
params.IP_MAX = dc_constants.IP_MAX
params.STRAND_MAX = dc_constants.STRAND_MAX
params.SN_MAX = dc_constants.SN_MAX
params.per_base_hidden_size = 8
params.pw_hidden_size = 8
params.ip_hidden_size = 8
params.strand_hidden_size = 2
params.sn_hidden_size = 8
params.condense_transformer_input = True
params.transformer_input_size = 280


def _set_transformer_learned_embeddings_v2_hparams(params):
"""Updates given config with values for the learned embeddings transformer."""
# TODO: As we migrate off the legacy code, we might need to
# adjust the params below. For now just making a copy of the previous params.
_set_base_transformer_v2_hparams(params)
params.model_name = 'transformer_learn_values_v2'
_set_base_transformer_hparams(params)
params.model_name = 'transformer_learn_values'
params.PW_MAX = dc_constants.PW_MAX
params.IP_MAX = dc_constants.IP_MAX
params.STRAND_MAX = dc_constants.STRAND_MAX
@@ -170,10 +118,10 @@ def _set_transformer_learned_embeddings_v2_hparams(params):
params.transformer_input_size = 280


def _set_transformer_learned_embeddings_v2_distill_hparams(params):
def _set_transformer_learned_embeddings_distill_hparams(params):
"""Updates given config with values for the distilled transformer."""
_set_transformer_learned_embeddings_v2_hparams(params)
params.model_name = 'transformer_learn_values_v2_distill'
_set_transformer_learned_embeddings_hparams(params)
params.model_name = 'transformer_learn_values_distill'

# Student architecture parameters.
params.num_hidden_layers = 4
@@ -239,14 +187,12 @@ def get_config(config_name: str) -> ml_collections.ConfigDict:
Valid config names must consist of two parts: {model_name}+{dataset_name}. The
"+" must be present as a separator between the two parts. For example,
transformer_learn_bases+ccs is a valid name.
transformer_learn_values+ccs is a valid name.
Valid model names include:
* fc
* transformer (TODO: legacy codebase)
* transformer_learn_values (TODO: legacy codebase)
* transformer_v2
* transformer_learn_values_v2
* transformer
* transformer_learn_values
Valid dataset names include:
* ecoli
@@ -289,16 +235,12 @@ def get_config(config_name: str) -> ml_collections.ConfigDict:
params.limit = -1
if model_config_name == 'fc':
_set_base_fc_hparams(params)
elif model_config_name == 'transformer_v2':
_set_base_transformer_v2_hparams(params)
elif model_config_name == 'transformer':
_set_base_transformer_hparams(params)
elif model_config_name == 'transformer_learn_values_v2':
_set_transformer_learned_embeddings_v2_hparams(params)
elif model_config_name == 'transformer_learn_values':
_set_transformer_learned_embeddings_hparams(params)
elif model_config_name == 'transformer_learn_values_v2_distill':
_set_transformer_learned_embeddings_v2_distill_hparams(params)
elif model_config_name == 'transformer_learn_values_distill':
_set_transformer_learned_embeddings_distill_hparams(params)
else:
raise ValueError('Unknown model_config_name: %s' % model_config_name)

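With the v2 variants folded back in, the names documented in get_config above map directly onto the remaining setters. A minimal usage sketch, assuming only the API visible in this diff (the 'ccs' dataset suffix comes from the docstring examples):

from deepconsensus.models import model_configs

# Config names follow '{model_name}+{dataset_name}', as described above.
params = model_configs.get_config('transformer_learn_values+ccs')
assert params.model_name == 'transformer_learn_values'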
4 changes: 2 additions & 2 deletions deepconsensus/models/model_distillation.py
@@ -30,12 +30,12 @@
Distillation attempts to train a smaller student model that mimics the larger
teacher model.
Currently only transformer_learn_values_v2_distill config is
Currently only transformer_learn_values_distill config is
supported for model training.
Example usage:
CONFIG="//learning/genomics/deepconsensus/models/model_configs.py:transformer_learn_values_v2_distill+ccs"
CONFIG="//learning/genomics/deepconsensus/models/model_configs.py:transformer_learn_values_distill+ccs"
TEACHER_MODEL_DIR=""
OUT_DIR=/tmp
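A short sketch of what the renamed distillation config now provides, based on the model_configs.py hunk earlier in this commit (the '+ccs' suffix is taken from the example usage above):

from deepconsensus.models import model_configs

student_params = model_configs.get_config('transformer_learn_values_distill+ccs')
# The distill config inherits the transformer_learn_values settings and then
# shrinks the student architecture (num_hidden_layers = 4 in the hunk above).
assert student_params.num_hidden_layers == 4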
2 changes: 1 addition & 1 deletion deepconsensus/models/model_distillation_test.py
@@ -41,7 +41,7 @@

class ModelTrainTest(parameterized.TestCase):

@parameterized.parameters(['transformer_learn_values_v2_distill+test'])
@parameterized.parameters(['transformer_learn_values_distill+test'])
def test_train_e2e(self, config_name):
"""Tests that training completes and output files written."""

2 changes: 1 addition & 1 deletion deepconsensus/models/model_inference_test.py
@@ -43,7 +43,7 @@ class ModelInferenceTest(absltest.TestCase):
def test_inference_e2e(self):
"""Tests that inference finishes running and an output file is created."""

config_name = 'transformer_learn_values_v2+test'
config_name = 'transformer_learn_values+test'
out_dir = self.create_tempdir().full_path
checkpoint_path = test_utils.deepconsensus_testdata('model/checkpoint-1')
params = model_configs.get_config(config_name)
2 changes: 1 addition & 1 deletion deepconsensus/models/model_train_custom_loop.py
@@ -30,7 +30,7 @@
To use this binary for training a specific model, the corresponding config file
should be specified as input. Example usage:
CONFIG="//learning/genomics/deepconsensus/models/model_configs.py:transformer_learn_values_v2+ccs"
CONFIG="//learning/genomics/deepconsensus/models/model_configs.py:transformer_learn_values+ccs"
OUT_DIR=/tmp
time blaze run -c opt \
39 changes: 24 additions & 15 deletions deepconsensus/models/model_utils.py
@@ -31,18 +31,17 @@
import json
import logging
import os
from typing import List, Optional, Tuple, Any, Union, Dict
from typing import Any, Dict, List, Optional, Tuple, Union

import ml_collections
import numpy as np
import tensorflow as tf

from deepconsensus.models import data_providers
from deepconsensus.models import legacy_networks
from deepconsensus.models import losses_and_metrics
from deepconsensus.models import networks
from deepconsensus.models import transformer_basic_params
from deepconsensus.utils import dc_constants
from official.nlp.transformer import misc


def get_deepconsensus_loss(
@@ -119,16 +118,8 @@ def get_model(params: ml_collections.ConfigDict) -> tf.keras.Model:
if params.model_name == 'fc':
model = networks.FullyConnectedNet(params)
elif params.model_name == 'transformer':
model = legacy_networks.EncoderOnlyTransformer(params)
# I'm using "_v2" suffix for the new code migrated out of legacy. Feel free
# to suggest more informative names.
elif params.model_name == 'transformer_v2':
model = networks.EncoderOnlyTransformer(params)
elif params.model_name == 'transformer_learn_values':
model = legacy_networks.EncoderOnlyLearnedValuesTransformer(params)
# I'm using "_v2" suffix for the new code migrated out of legacy. Feel free
# to suggest more informative names.
elif 'transformer_learn_values_v2' in params.model_name:
elif 'transformer_learn_values' in params.model_name:
model = networks.EncoderOnlyLearnedValuesTransformer(params)
else:
raise ValueError('Unknown model name: %s' % params.model_name)
@@ -228,8 +219,7 @@ def modify_params(params: ml_collections.ConfigDict,
params.hidden_size += 1

# Set model-specific parameters
if (params.model_name == 'transformer' or
params.model_name == 'transformer_v2'):
if params.model_name == 'transformer':
# Transformer code uses default_batch_size, whereas my code uses
# batch_size, so make sure both are the same.
params.default_batch_size = params.batch_size
@@ -241,14 +231,33 @@ def modify_params(params: ml_collections.ConfigDict,
logging.info('Setting hidden size to transformer_input_size.')
params.hidden_size = params.transformer_input_size
if 'transformer' in params.model_name:
transformer_params = misc.get_model_params(
transformer_params = get_transformer_model_params(
params.transformer_model_size, num_gpus=num_gpus)
# Only add hyperparameters that don't already exist.
for param_name, param_value in transformer_params.items():
if param_name not in params:
params[param_name] = param_value


def get_transformer_model_params(param_set, num_gpus):
"""Gets predefined transformer model params."""
params_map = {
'tiny': transformer_basic_params.TINY_PARAMS,
'base': transformer_basic_params.BASE_PARAMS,
'big': transformer_basic_params.BIG_PARAMS,
}
if num_gpus > 1:
if param_set == 'big':
return transformer_basic_params.BIG_MULTI_GPU_PARAMS.copy()
elif param_set == 'base':
return transformer_basic_params.BASE_MULTI_GPU_PARAMS.copy()
else:
raise ValueError('Not valid params: param_set={} num_gpus={}'.format(
param_set, num_gpus))

return params_map[param_set].copy()


def run_inference_and_write_results(model: tf.keras.Model,
out_dir: str,
params: ml_collections.ConfigDict,
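The new get_transformer_model_params helper above replaces official.nlp.transformer.misc.get_model_params, which this commit stops importing. A minimal sketch of how it is consumed; the 'base' size and num_gpus=1 mirror the defaults used in the configs, and the exact keys in the returned dict live in transformer_basic_params, which is not shown in this diff:

from deepconsensus.models import model_utils

# Returns a copy of the predefined hyperparameter dict for the requested size,
# so callers can add or override entries without mutating the shared constants.
transformer_params = model_utils.get_transformer_model_params('base', num_gpus=1)
for name, value in sorted(transformer_params.items()):
  print(name, value)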
2 changes: 1 addition & 1 deletion deepconsensus/models/model_utils_test.py
@@ -88,7 +88,7 @@ def test_output_dir_created(self):

out_dir = f'/tmp/output_dir/{uuid.uuid1()}'
self.assertFalse(tf.io.gfile.isdir(out_dir))
params = model_configs.get_config('transformer_learn_values_v2+test')
params = model_configs.get_config('transformer_learn_values+test')
model_utils.modify_params(params)
model = model_utils.get_model(params)
checkpoint_path = test_utils.deepconsensus_testdata('model/checkpoint-1')

0 comments on commit f25f44e
