Modify total_rows calculation to account for base quality scores.
PiperOrigin-RevId: 504955748
danielecook authored and Copybara-Service committed Jan 26, 2023
1 parent da89777 commit 3884c05
Showing 9 changed files with 127 additions and 70 deletions.
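The core of the change: data_providers.get_total_rows now takes a use_ccs_bq argument, and the result is computed once and stored on params.total_rows instead of being recomputed at each call site. A minimal usage sketch (the numbers match the new parameterized test in data_providers_test.py below):

from deepconsensus.models import data_providers

# One extra fixed row is reserved when CCS base-quality scores are used.
assert data_providers.get_total_rows(20, use_ccs_bq=False) == 85
assert data_providers.get_total_rows(20, use_ccs_bq=True) == 86
assert data_providers.get_total_rows(25, use_ccs_bq=False) == 105
assert data_providers.get_total_rows(25, use_ccs_bq=True) == 106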
12 changes: 2 additions & 10 deletions deepconsensus/inference/quick_inference.py
@@ -50,7 +50,6 @@
from absl import flags
from absl import logging
from ml_collections.config_dict import config_dict
from ml_collections.config_flags import config_flags
import numpy as np
import pandas as pd
import pysam
@@ -95,9 +94,6 @@ class DebugStage(enum.Enum):
'(e.g. "/path/to/model_directory/checkpoint-50"), '
'or to a saved model directory, (e.g. "/path/to/model_directory") '
'which is the directory that contains a saved_model.pb')
config_flags.DEFINE_config_file(
'params', None, 'params.json configuration file. By default, '
'/path/to/model_directory/params.json is used.')

# TODO Find out if this flag is needed here. Currently it is added to
# avoid the pipeline failing. max_length should be correctly read from the checkpoint.
@@ -439,8 +435,7 @@ def initialize_model(
# If you don't do this, then assert_existing_objects_matched will not
# raise an error even if the wrong checkpoint is used.
# Some context here: b/148023980.
row_size = data_providers.get_total_rows(params.max_passes)
input_shape = (1, row_size, params.max_length, params.num_channels)
input_shape = (1, params.total_rows, params.max_length, params.num_channels)
model_utils.print_model_summary(model, input_shape)
checkpoint.restore(
checkpoint_path).expect_partial().assert_existing_objects_matched()
@@ -698,10 +693,7 @@ def run() -> stitch_utils.OutcomeCounter:
tf.io.gfile.exists(f'{FLAGS.checkpoint}/saved_model.pb'))

# Load model parameters
if not FLAGS.params:
params = model_utils.read_params_from_json(checkpoint_path=FLAGS.checkpoint)
else:
params = FLAGS.params
params = model_utils.read_params_from_json(checkpoint_path=FLAGS.checkpoint)

dc_config = pre_lib.DcConfig(params.max_passes, params.max_length)

4 changes: 1 addition & 3 deletions deepconsensus/models/convert_to_saved_model.py
@@ -39,7 +39,6 @@
from absl import logging
import tensorflow as tf

from deepconsensus.models import data_providers
from deepconsensus.models import model_utils
from tensorflow.python.platform import gfile

@@ -79,8 +78,7 @@ def initialize_model(checkpoint_path: str) -> Optional[tf.keras.Model]:
# work as expected. If you don't do this, then assert_existing_objects_matched
# will not raise an error even if the wrong checkpoint is used.
# Some context here: b/148023980.
row_size = data_providers.get_total_rows(params.max_passes)
input_shape = (1, row_size, params.max_length, params.num_channels)
input_shape = (1, params.total_rows, params.max_length, params.num_channels)
model_utils.print_model_summary(model, input_shape)
checkpoint.restore(
checkpoint_path).expect_partial().assert_existing_objects_matched()
39 changes: 16 additions & 23 deletions deepconsensus/models/data_providers.py
@@ -37,7 +37,6 @@

from deepconsensus.utils import dc_constants


# Define base fields for TFRecords.
PROTO_FEATURES_INFERENCE = {
'name': tf.io.FixedLenFeature(shape=[1], dtype=tf.string),
@@ -57,27 +56,24 @@
}, **PROTO_FEATURES_INFERENCE)


def get_total_rows(max_passes: int) -> int:
"""Returns total rows in input tf.Examples.
Update if other signals added.
def get_total_rows(max_passes: int, use_ccs_bq: bool) -> int:
"""Calculates the number of rows in input examples.
For each of `max_subreads`, we have four pieces of information: bases, PW, IP,
and strand. We also have one row for CCS, and four rows for SN (in that
order).
The information is structured as follows:
Bases: rows 0 to (params.max_passes - 1)
PW: rows (params.max_passes) to (params.max_passes * 2 - 1)
IP: rows (params.max_passes * 2) to (params.max_passes * 3 - 1)
Strand: rows (params.max_passes * 3) to (params.max_passes * 4 - 1)
CCS+SN: rows (params.max_passes * 4) to (params.max_passes * 4 + 5)
The number of rows is based on max_passes, which scales the dynamic features
(Bases, PW, IP, Strand), plus rows for a number of fixed-size features. CCS
base qualities are optionally included as a feature, which changes the number
of fixed-length rows.
Args:
max_passes: Maximum number of subreads to show. Space is made for them all
even though few examples will have enough subreads to fill these rows.
Returns: Total number of rows in the full example.
use_ccs_bq: Bool indicating whether CCS Base Quality Scores are being used.
Returns:
Total number of rows in the full example.
"""
return (max_passes * 4) + 5
fixed_length = 6 if use_ccs_bq else 5
return (max_passes * 4) + fixed_length


def get_indices(max_passes: int) -> Iterable[Tuple[int, int]]:
@@ -127,8 +123,7 @@ def format_rows(
sn_rows, clip_value_min=0, clip_value_max=params.SN_MAX)
rows = tf.concat(
[base_rows, pw_rows, ip_rows, strand_rows, ccs_rows, sn_rows], axis=0)
num_rows = get_total_rows(params.max_passes)
rows.set_shape((num_rows, params.max_length, 1))
rows.set_shape((params.total_rows, params.max_length, 1))
return rows


@@ -257,8 +252,8 @@ def get_dataset(file_pattern: str,
inference: Whether to parse tf.Examples for inference or training.
limit: Max number of examples to get. Set to -1 for no limit.
drop_remainder: Passed to TFRecordDataset.batch
example_label_tuple: If True, output simplified format for training/eval
as (rows, label)
example_label_tuple: If True, output simplified format for training/eval as
(rows, label)
Returns:
A dataset for which each batch has the following elements:
@@ -272,9 +267,7 @@

def _process_input_helper(proto_string: tf.Tensor) -> Dict[str, tf.Tensor]:
return process_input(
proto_string=proto_string,
params=params,
inference=inference)
proto_string=proto_string, params=params, inference=inference)

file_patterns = create_glob_list(file_pattern)
ds = tf.data.TFRecordDataset(file_patterns, compression_type='GZIP')
13 changes: 13 additions & 0 deletions deepconsensus/models/data_providers_test.py
@@ -303,5 +303,18 @@ def test_remove_internal_gaps_and_shift(self):
self.assertEqual(result, expected)


class GetTotalRowsTest(parameterized.TestCase):

@parameterized.parameters(
[20, False, 85],
[20, True, 86],
[25, False, 105],
[25, True, 106],
)
def test_get_total_rows(self, max_passes, use_ccs_bq, expected_total_rows):
total_rows = data_providers.get_total_rows(max_passes, use_ccs_bq)
self.assertEqual(total_rows, expected_total_rows)


if __name__ == '__main__':
absltest.main()
42 changes: 28 additions & 14 deletions deepconsensus/models/model_configs.py
@@ -31,6 +31,7 @@

from typing import Optional
import ml_collections
from ml_collections import config_dict

# Do not add any additional imports to the config.
# It can lead to circular dependencies easily and should not be necessary
@@ -46,18 +47,13 @@ def _set_base_fc_hparams(params):
params.fc_size = [256, 512, 256, 128]
params.fc_dropout = 0.0

params.use_bases = True
params.use_pw = True
params.use_ip = True
params.use_strand = True
params.use_ccs = True
params.use_sn = True
params.num_channels = 1

params.per_base_hidden_size = 1
params.pw_hidden_size = 1
params.ip_hidden_size = 1
params.strand_hidden_size = 1
params.ccs_bq_hidden_size = 1
params.sn_hidden_size = 1

# Training
@@ -97,16 +93,11 @@ def _set_base_transformer_hparams(params):
params.attn_win_size = 12

params.num_channels = 1
params.use_bases = True
params.use_pw = True
params.use_ip = True
params.use_ccs = True
params.use_strand = True
params.use_sn = True
params.per_base_hidden_size = 1
params.pw_hidden_size = 1
params.ip_hidden_size = 1
params.sn_hidden_size = 1
params.ccs_bq_hidden_size = 1
params.strand_hidden_size = 1

# Dropout values (only used when training).
@@ -135,15 +126,15 @@

def _set_transformer_learned_embeddings_hparams(params):
"""Updates given config with values for the learned embeddings transformer."""
# TODO: As we migrate off the legacy code, we might need to
# adjust the params below. For now just making a copy of the previous params.
_set_base_transformer_hparams(params)
params.model_name = 'transformer_learn_values'
params.per_base_hidden_size = 8
params.pw_hidden_size = 8
params.ip_hidden_size = 8
params.strand_hidden_size = 2
params.sn_hidden_size = 8
params.ccs_bq_hidden_size = 8

params.condense_transformer_input = True
params.transformer_input_size = 280

@@ -260,12 +251,35 @@ def get_config(config_name: Optional[str] = None) -> ml_collections.ConfigDict:
# Used for generating replicates.
params.trial = 1

# Defaults for backward compatibility.
# Older models initialize with the default config, so set those values here
# for older models.
params.rezero = False

# Base config
params.PW_MAX = 255
params.IP_MAX = 255
params.SN_MAX = 500
params.CCS_BQ_MAX = 95
params.STRAND_MAX = 2

# Features
params.use_bases = True
params.use_pw = True
params.use_ip = True
params.use_strand = True
params.use_sn = True
params.use_ccs = True
params.use_ccs_bq = False
params.per_base_hidden_size = 1
params.pw_hidden_size = 1
params.ip_hidden_size = 1
params.sn_hidden_size = 1
params.strand_hidden_size = 1
params.ccs_bq_hidden_size = 1

params.total_rows = config_dict.placeholder(int)

# Specify common configs here.
params.vocab_size = 5
params.tensorboard_update_freq = 'batch'
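The new total_rows field is declared with config_dict.placeholder(int), so it is typed but left unset until max_passes and use_ccs_bq are known; model_utils.modify_params (further down) fills it in. A minimal sketch of that pattern, assuming only the ml_collections API:

from ml_collections import config_dict

params = config_dict.ConfigDict()
params.max_passes = 20          # illustrative value
params.use_ccs_bq = False
params.total_rows = config_dict.placeholder(int)  # typed, but unset for now

# Later (e.g. in modify_params) the placeholder is filled in:
params.total_rows = (params.max_passes * 4) + (6 if params.use_ccs_bq else 5)
assert params.total_rows == 85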
9 changes: 3 additions & 6 deletions deepconsensus/models/model_distillation.py
@@ -90,8 +90,7 @@ def init_student_from_teacher(
student_model: tf.keras.Model, teacher_model: tf.keras.Model,
params: ml_collections.ConfigDict) -> tf.keras.Model:
"""Initialize student model using teacher model weights based on params."""
row_size = data_providers.get_total_rows(params.max_passes)
input_shape = (1, row_size, params.max_length, params.num_channels)
input_shape = (1, params.total_rows, params.max_length, params.num_channels)
model_utils.print_model_summary(teacher_model, input_shape)
if params.init_encoder_stack:
teacher2student_encoder_map = dict(
@@ -137,8 +136,7 @@ def get_teacher_model(checkpoint_path: str,
# If you don't do this, then assert_existing_objects_matched will not
# raise an error even if the wrong checkpoint is used.
# Some context here: b/148023980.
row_size = data_providers.get_total_rows(params.max_passes)
input_shape = (1, row_size, params.max_length, params.num_channels)
input_shape = (1, params.total_rows, params.max_length, params.num_channels)
model_utils.print_model_summary(model, input_shape)
checkpoint.restore(
checkpoint_path).expect_partial().assert_existing_objects_matched()
@@ -171,8 +169,7 @@ def train_model(teacher_model: tf.keras.Model, out_dir: str,
model = model_utils.get_model(params)
# Note that the `print_model_summary` is necessary because we need to run a
# forward pass with the model to be able to initialize student from teacher.
row_size = data_providers.get_total_rows(params.max_passes)
input_shape = (1, row_size, params.max_length, params.num_channels)
input_shape = (1, params.total_rows, params.max_length, params.num_channels)
model_utils.print_model_summary(model, input_shape)
logging.info('Done building model.')
# Initialize student model from teacher based on model params.
4 changes: 1 addition & 3 deletions deepconsensus/models/model_inference.py
@@ -49,7 +49,6 @@
from ml_collections.config_flags import config_flags
import tensorflow as tf

from deepconsensus.models import data_providers
from deepconsensus.models import losses_and_metrics
from deepconsensus.models import model_utils

@@ -84,8 +83,7 @@ def run_inference(out_dir: str, params: ml_collections.ConfigDict,

with strategy.scope():
model = model_utils.get_model(params)
row_size = data_providers.get_total_rows(params.max_passes)
input_shape = (1, row_size, params.max_length, params.num_channels)
input_shape = (1, params.total_rows, params.max_length, params.num_channels)
model_utils.print_model_summary(model, input_shape)
checkpoint = tf.train.Checkpoint(model=model)
# Need to run a forward pass with the model in order for
35 changes: 25 additions & 10 deletions deepconsensus/models/model_utils.py
@@ -212,7 +212,7 @@ def load_dataset_summary(dataset_path: str) -> Tuple[str, Dict[str, Any]]:
return dataset_summary_path, dataset_summary


def del_param(params, name):
def _del_param(params, name):
if name in params:
del params[name]

@@ -247,11 +247,11 @@ def modify_params(params: ml_collections.ConfigDict,
with params.unlocked():
if not is_training:
# Only allow dataset specification in params when in training mode.
del_param(params, 'tf_dataset')
del_param(params, 'train_path')
del_param(params, 'eval_path')
del_param(params, 'test_path')
del_param(params, 'inference_path')
_del_param(params, 'tf_dataset')
_del_param(params, 'train_path')
_del_param(params, 'eval_path')
_del_param(params, 'test_path')
_del_param(params, 'inference_path')
# Set dataset if tf_dataset is set.
if 'tf_dataset' in params and params.tf_dataset:
set_dataset(params)
@@ -287,16 +287,24 @@ def modify_params(params: ml_collections.ConfigDict,
if not hasattr(params, 'max_length'):
raise ValueError('No params.max_length provided.')

# Calculate the total number of rows.
params.total_rows = data_providers.get_total_rows(
params.max_passes,
params.use_ccs_bq,
)

if 'transformer_learn_values' in params.model_name:
dim = ((params.use_bases * params.per_base_hidden_size) +
(params.use_pw * params.pw_hidden_size) +
(params.use_ip * params.ip_hidden_size) +
(params.use_strand * params.strand_hidden_size))
(params.use_strand * params.strand_hidden_size) +
(params.use_ccs_bq * params.ccs_bq_hidden_size))
params.hidden_size = ((params.max_passes * dim) +
(params.use_sn * params.sn_hidden_size * 4) +
(params.use_ccs * params.per_base_hidden_size))
(params.use_ccs * params.per_base_hidden_size) +
(params.use_ccs_bq * params.ccs_bq_hidden_size) +
(params.use_sn * params.sn_hidden_size * 4))
else:
params.hidden_size = data_providers.get_total_rows(params.max_passes)
params.hidden_size = params.total_rows

if 'transformer' in params.model_name and params.hidden_size % 2 != 0:
params.hidden_size += 1
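For orientation, a worked example of the updated hidden_size arithmetic for the learned-embeddings model, assuming max_passes=20 (illustrative), the hidden sizes from _set_transformer_learned_embeddings_hparams (8 for bases/PW/IP/SN/CCS BQ, 2 for strand), and the get_config feature defaults (all flags on except use_ccs_bq):

# use_ccs_bq=False (default): the CCS BQ terms contribute nothing.
dim = 8 + 8 + 8 + 2 + 0                   # bases + PW + IP + strand + CCS BQ
hidden_size = 20 * dim + 8 + 0 + 8 * 4    # passes*dim + CCS + CCS BQ + 4 SN rows
assert hidden_size == 560                 # even, so no +1 adjustment is needed

# use_ccs_bq=True, following the formula above literally:
dim = 8 + 8 + 8 + 2 + 8
hidden_size = 20 * dim + 8 + 8 + 8 * 4
assert hidden_size == 728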
Expand Down Expand Up @@ -410,6 +418,13 @@ def read_params_from_json(checkpoint_path: str) -> ml_collections.ConfigDict:
'that is not present in params.json'), b_param,
param_set[b_param])
param_set.update(json_params)

# Calculate the total number of rows for backward compatibility.
param_set.total_rows = data_providers.get_total_rows(
param_set.max_passes,
param_set.use_ccs_bq,
)

return param_set
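Older params.json files predate the CCS BQ feature, so use_ccs_bq falls back to the get_config() default (False) and the recomputed total_rows matches the old row layout. A hedged sketch with a hypothetical checkpoint path:

# Hypothetical older model whose params.json omits use_ccs_bq.
params = model_utils.read_params_from_json(
    checkpoint_path='/path/to/old_model_directory')
# The get_config() default (use_ccs_bq=False) applies, so total_rows equals the
# pre-change value of (max_passes * 4) + 5.
assert params.total_rows == params.max_passes * 4 + 5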


