[go: nahoru, domu]

Skip to content

Commit

Permalink
Add flag for generating tf examples with ccs base quality scores.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 505802316
  • Loading branch information
danielecook authored and Copybara-Service committed Jan 30, 2023
1 parent d482ca0 commit c0b43d5
Show file tree
Hide file tree
Showing 22 changed files with 199 additions and 60 deletions.
5 changes: 4 additions & 1 deletion deepconsensus/models/data_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,10 @@ def parse_example(proto_string: Dict[str, tf.Tensor],
proto_features = PROTO_FEATURES_INFERENCE
else:
proto_features = PROTO_FEATURES_TRAIN
if not proto_features['ccs_base_quality_scores'].shape:
# Set the correct dimensionality for ccs_base_quality scores.
if not proto_features['ccs_base_quality_scores'].shape or proto_features[
'ccs_base_quality_scores'].shape[0] != max_length:
proto_features['ccs_base_quality_scores'].shape.clear()
proto_features['ccs_base_quality_scores'].shape.append(max_length)
parsed_features = tf.io.parse_single_example(
serialized=proto_string, features=proto_features)
Expand Down
66 changes: 48 additions & 18 deletions deepconsensus/preprocess/pre_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,18 +411,18 @@ def __repr__(self):


def dc_config_from_shape(subreads_shape: Tuple[int, int, int],
use_bq: bool = False) -> 'DcConfig':
use_ccs_bq: bool = False) -> 'DcConfig':
"""Creates a DcConfig object based on subread shape and base quality usage.
Args:
subreads_shape: The shape of a subreads input.
use_bq: Boolean indicating whether to use base quality scores.
use_ccs_bq: Boolean indicating whether to use base quality scores.
Returns:
A DcConfig object.
"""
height, width, _ = subreads_shape
if use_bq:
if use_ccs_bq:
fixed_height = 6
else:
fixed_height = 5
Expand All @@ -431,7 +431,7 @@ def dc_config_from_shape(subreads_shape: Tuple[int, int, int],
len(DcConfig.n_subread_features))
if remainder != 0:
raise ValueError(f'Invalid subreads shape {subreads_shape!r}.')
return DcConfig(max_passes, width, use_bq)
return DcConfig(max_passes, width, use_ccs_bq)


class DcConfig:
Expand All @@ -441,8 +441,8 @@ class DcConfig:
max_passes: Number of passes to incorporate into examples.
max_length: Width of example.
feature_rows: A dictionary indicating features and corresponding height.
use_bq: Boolean indicating whether to incorporate base quality scores into
model input.
use_ccs_bq: Boolean indicating whether to incorporate base quality scores
into model input.
feature_indices: Calculated indices for each feature.
"""

Expand All @@ -451,7 +451,12 @@ class DcConfig:
# Features with n_rows = n_subreads.
n_subread_features = ['bases', 'pw', 'ip', 'strand']

def __init__(self, max_passes: int, max_length: int, use_bq: bool = False):
def __init__(
self,
max_passes: int,
max_length: int,
use_ccs_bq: bool = False,
):
self.max_passes = max_passes
self.max_length = max_length
self.feature_rows = {
Expand All @@ -460,12 +465,12 @@ def __init__(self, max_passes: int, max_length: int, use_bq: bool = False):
'ip': max_passes,
'strand': max_passes,
'ccs': 1,
'ccs_bq': 1 if use_bq else 0,
'ccs_bq': 1 if use_ccs_bq else 0,
'sn': 4
}
# Sets slices indicating rows for each feature type.
self.feature_indices = dict()
self.use_bq = use_bq
self.use_ccs_bq = use_ccs_bq
i_rows = 0
for k, v in self.feature_rows.items():
self.feature_indices[k] = slice(i_rows, i_rows + self.feature_rows[k])
Expand Down Expand Up @@ -715,7 +720,7 @@ def repeat(feature):

# ccs features.
data[ccs_idx] = self.ccs.bases_encoded
if self.config.use_bq:
if self.config.use_ccs_bq:
data[ccs_bq_idx] = self.ccs.base_quality_scores

# Format sn rows.
Expand Down Expand Up @@ -872,10 +877,26 @@ def set_feature(feature, shape):
return feature


def tf_example_to_features_dict(tf_example_proto_str, inference=False):
"""Convert tf.Example to features_dict."""
def tf_example_to_features_dict(tf_example_proto_str: Dict[str, Any],
inference: bool = False,
use_ccs_bq: bool = False,
max_length: int = 100) -> Dict[str, Any]:
"""Converts tf.Example to features_dict.
Args:
tf_example_proto_str: Input str-encoded tf.Example.
inference: Bool indicating whether to only load inference-relevant fields.
use_ccs_bq: Bool indicating whether subreads contain base quality scores.
max_length: The width of the tf.Example.
Returns:
A dictionary containing tf.Example Tensor elements.
"""
features = data_providers.parse_example(
tf_example_proto_str, inference=inference)
tf_example_proto_str,
inference=inference,
max_length=max_length,
)

for key, val in features.items():
if tf.executing_eagerly():
Expand All @@ -891,13 +912,22 @@ def tf_example_to_features_dict(tf_example_proto_str, inference=False):

features['subreads'] = set_feature(features['subreads/encoded'],
features['subreads/shape'])
dc_config = dc_config_from_shape(features['subreads/shape'])
dc_config = dc_config_from_shape(
features['subreads/shape'],
use_ccs_bq,
)
# Get a default config and overwrite with specified values
params = model_configs.get_config()
params.max_length = int(dc_config.max_length)
params.max_passes = int(dc_config.max_passes)
features['subreads'] = data_providers.format_rows(
features['subreads'], params=params)
with params.unlocked():
params.use_ccs_bq = use_ccs_bq
params.max_length = int(dc_config.max_length)
params.max_passes = int(dc_config.max_passes)
params.total_rows = data_providers.get_total_rows(
params.max_passes,
use_ccs_bq,
)
features['subreads'] = data_providers.format_rows(features['subreads'],
params)
del features['subreads/encoded']
if not inference:
features['label'] = set_feature(features['label/encoded'],
Expand Down
57 changes: 43 additions & 14 deletions deepconsensus/preprocess/pre_lib_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -926,47 +926,47 @@ class TestDcConfigFromShape(parameterized.TestCase):

@parameterized.named_parameters(
dict(
testcase_name='standard shape use_bq=True',
testcase_name='standard shape use_ccs_bq=True',
shape=(86, 120, -1),
use_bq=True,
use_ccs_bq=True,
expected_max_passes=20,
expected_sn_rows=slice(82, 86),
),
dict(
testcase_name='expanded shape use_bq=True',
testcase_name='expanded shape use_ccs_bq=True',
shape=(106, 120, -1),
use_bq=True,
use_ccs_bq=True,
expected_max_passes=25,
expected_sn_rows=slice(102, 106),
),
dict(
testcase_name='standard shape use_bq=False',
testcase_name='standard shape use_ccs_bq=False',
shape=(85, 120, -1),
use_bq=False,
use_ccs_bq=False,
expected_max_passes=20,
expected_sn_rows=slice(81, 85),
),
dict(
testcase_name='expanded shape use_bq=False',
testcase_name='expanded shape use_ccs_bq=False',
shape=(105, 120, -1),
use_bq=False,
use_ccs_bq=False,
expected_max_passes=25,
expected_sn_rows=slice(101, 105),
))
def test_dc_config_from_shape(
self,
shape,
use_bq,
use_ccs_bq,
expected_max_passes,
expected_sn_rows,
):
dc_config = pre_lib.dc_config_from_shape(shape, use_bq)
dc_config = pre_lib.dc_config_from_shape(shape, use_ccs_bq)
self.assertEqual(dc_config.max_passes, expected_max_passes)
self.assertEqual(dc_config.indices('sn'), expected_sn_rows)

def test_incompatible_shape(self):
with self.assertRaisesRegex(ValueError, 'Invalid subreads shape'):
pre_lib.dc_config_from_shape((7, 7, -1), use_bq=True)
pre_lib.dc_config_from_shape((7, 7, -1), use_ccs_bq=True)


class TestDcExampleFunctionality(parameterized.TestCase):
Expand All @@ -976,7 +976,7 @@ def test_dc_example_functions(self):
dc_config = pre_lib.DcConfig(
max_passes=20,
max_length=max_length,
use_bq=True,
use_ccs_bq=True,
)
# First, generate a bunch of reads
read_set = []
Expand Down Expand Up @@ -1127,13 +1127,19 @@ def test_tf_example_train(self):
# Serialize tf example and reverse.
tf_example_str = tf_example.SerializePartialToString()
parsed_example = data_providers.parse_example(
tf_example_str, inference=False, max_length=dc_config.max_length)
tf_example_str,
inference=False,
max_length=dc_config.max_length,
)
self.assertSetEqual(
set(parsed_example.keys()),
set(data_providers.PROTO_FEATURES_TRAIN.keys()))

# Compare tf example converted back to DcExample
features = pre_lib.tf_example_to_features_dict(tf_example_str)
features = pre_lib.tf_example_to_features_dict(
tf_example_str,
max_length=dc_config.max_length,
)
window_2_rev = pre_lib.from_features_dict(features)

# Compare reversed values.
Expand Down Expand Up @@ -1259,5 +1265,28 @@ def test_ccs_smart_windows(self, segment_set, window_widths,
self.assertCountEqual(examples, expected_examples)


class TestTfExamplesToFeaturesDict(parameterized.TestCase):
  """Decodes stored testdata tf.Examples with tf_example_to_features_dict."""

  def _check_first_train_example(self, testdata_pattern, use_ccs_bq):
    """Decodes the first 'train' example and compares stored vs. real shape.

    Args:
      testdata_pattern: Testdata-relative path containing an `@split`
        placeholder, passed to test_utils.deepconsensus_testdata.
      use_ccs_bq: Forwarded to tf_example_to_features_dict; must match how the
        stored examples were generated.
    """
    pattern = test_utils.deepconsensus_testdata(testdata_pattern)
    first = test_utils.load_dataset(pattern, 'train')[0]
    decoded = pre_lib.tf_example_to_features_dict(
        first, inference=False, use_ccs_bq=use_ccs_bq)
    # The serialized 'subreads/shape' must describe the decoded tensor exactly.
    self.assertListEqual(
        list(decoded['subreads/shape']), list(decoded['subreads'].shape))

  def test_tf_examples_to_features_dict(self):
    """Examples generated without CCS base qualities decode consistently."""
    self._check_first_train_example(
        'human_1m/tf_examples/@split/@split.tfrecord.gz', use_ccs_bq=False)

  def test_tf_examples_bq_to_features_dict(self):
    """Examples generated with CCS base qualities decode consistently."""
    self._check_first_train_example(
        'human_1m/tf_examples_bq/@split/@split.tfrecord.gz', use_ccs_bq=True)


if __name__ == '__main__':
absltest.main()
12 changes: 9 additions & 3 deletions deepconsensus/preprocess/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,9 @@
'use_ccs_smart_windows', False,
'If true, CCS smart window widths are used to partition '
'subreads into windows.')
_USE_CCS_BQ = flags.DEFINE_bool(
'use_ccs_bq', False,
'If true, incorporate CCS Base Quality scores into tf.examples.')
# The following just need to match the training parameters.
_MAX_PASSES = flags.DEFINE_integer('max_passes', 20,
'Maximum subreads in each input.')
Expand All @@ -118,9 +121,9 @@ def wrap(*args, **kwargs):
try:
result = f(*args, **kwargs)
return result
except: # pylint: disable=bare-except
except Exception as exc:
logging.exception('Error in function %s.', f.__name__)
raise Exception('Error in worker process')
raise exc

return wrap

Expand Down Expand Up @@ -240,7 +243,10 @@ def main(unused_argv) -> None:
queue = manager.Queue()

dc_config = pre_lib.DcConfig(
max_passes=_MAX_PASSES.value, max_length=_EXAMPLE_WIDTH.value)
max_passes=_MAX_PASSES.value,
max_length=_EXAMPLE_WIDTH.value,
use_ccs_bq=_USE_CCS_BQ.value,
)

proc_feeder, main_counter = pre_lib.create_proc_feeder(
subreads_to_ccs=FLAGS.subreads_to_ccs,
Expand Down
57 changes: 44 additions & 13 deletions deepconsensus/preprocess/preprocess_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,12 @@

from absl import flags
from absl.testing import absltest
from absl.testing import flagsaver
from absl.testing import parameterized
import tensorflow as tf

from deepconsensus.preprocess import pre_lib
from deepconsensus.preprocess import preprocess
from deepconsensus.utils import test_utils
from deepconsensus.utils.test_utils import deepconsensus_testdata
from absl import app

Expand All @@ -50,14 +52,6 @@ def load_summary(tmp_dir, path):
return json.load(open(summary_path, 'r'))


def load_dataset(output, dataset):
# Load inference, train, eval, or test tfrecord.gz files.
tf_record = output.replace('@split', dataset)
dataset = tf.data.TFRecordDataset(tf_record, compression_type='GZIP')
examples = list(dataset.as_numpy_iterator())
return examples


def get_unique_zmws(examples):
zmws = []
for example in examples:
Expand All @@ -80,7 +74,7 @@ def test_e2e_inference(self, n_cpus):
output = os.path.join(tmp_dir, 'tf-@split.tfrecord.gz')
FLAGS.output = output
preprocess.main([])
examples = load_dataset(output, 'inference')
examples = test_utils.load_dataset(output, 'inference')
features = pre_lib.tf_example_to_features_dict(examples[0], inference=True)

# Check that window_pos incr. monotonically for each ZMW.
Expand Down Expand Up @@ -117,9 +111,9 @@ def test_e2e_train(self, n_cpus):
output = os.path.join(tmp_dir, 'tf-@split.tfrecord.gz')
FLAGS.output = output
preprocess.main([])
train_examples = load_dataset(output, 'train')
eval_examples = load_dataset(output, 'eval')
test_examples = load_dataset(output, 'test')
train_examples = test_utils.load_dataset(output, 'train')
eval_examples = test_utils.load_dataset(output, 'eval')
test_examples = test_utils.load_dataset(output, 'test')
all_examples = train_examples + eval_examples + test_examples

# Check that window_pos incr. monotonically for each ZMW.
Expand Down Expand Up @@ -163,6 +157,43 @@ def test_e2e_train(self, n_cpus):
self.assertSameElements(features['subreads'].shape,
features['subreads/shape'])

def test_invalid_tf_examples(self):
"""Tests for proper error thrown when loading improprer tf example."""
output = os.path.join(self.create_tempdir(), 'tf-@split.tfrecord.gz')
with flagsaver.flagsaver(
subreads_to_ccs=testdata('human_1m/subreads_to_ccs.bam'),
ccs_bam=testdata('human_1m/ccs.bam'),
use_ccs_bq=True,
cpus=0,
limit=1,
output=output):
preprocess.main([])
examples = test_utils.load_dataset(output, 'train')
with self.assertRaisesRegex(ValueError, 'Invalid subreads shape'):
_ = pre_lib.tf_example_to_features_dict(
examples[0], inference=True, use_ccs_bq=False)

def test_bq_tf_examples(self):
"""Tests preprocessing inference with base quality score features."""
output = os.path.join(self.create_tempdir(), 'tf-@split.tfrecord.gz')
with flagsaver.flagsaver(
subreads_to_ccs=testdata('human_1m/subreads_to_ccs.bam'),
ccs_bam=testdata('human_1m/ccs.bam'),
use_ccs_bq=True,
cpus=0,
limit=1,
output=output):
preprocess.main([])
examples = test_utils.load_dataset(output, 'inference')

features = pre_lib.tf_example_to_features_dict(
examples[0],
inference=True,
use_ccs_bq=True,
)
self.assertEqual(list(features['subreads/shape']), [86, 100, 1])
self.assertEqual(list(features['subreads'].shape), [86, 100, 1])


if __name__ == '__main__':
absltest.main()
Binary file modified deepconsensus/testdata/human_1m/ccs.bam
Binary file not shown.
Binary file modified deepconsensus/testdata/human_1m/subreads_to_ccs.bam
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file modified deepconsensus/testdata/human_1m/tf_examples/test/test.tfrecord.gz
Binary file not shown.
Binary file modified deepconsensus/testdata/human_1m/tf_examples/train/train.tfrecord.gz
Binary file not shown.
Binary file not shown.
Loading

0 comments on commit c0b43d5

Please sign in to comment.