[go: nahoru, domu]

Skip to content

Commit

Permalink
Modify get_indices to handle ccs_bq scores.
Browse files (browse the repository at this point in the history)
PiperOrigin-RevId: 505020547
  • Loading branch information
danielecook authored and Copybara-Service committed Jan 27, 2023
1 parent 3884c05 commit 5c09997
Show file tree
Hide file tree
Showing 5 changed files with 117 additions and 16 deletions.
67 changes: 59 additions & 8 deletions deepconsensus/models/data_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,15 +76,39 @@ def get_total_rows(max_passes: int, use_ccs_bq: bool) -> int:
return (max_passes * 4) + fixed_length


def get_indices(max_passes: int, use_ccs_bq: bool) -> Iterable[Tuple[int, int]]:
  """Returns row indices for each feature in the tf.Example subreads array.

  Features are stacked along the first axis of the subreads array. For each
  feature this function returns the half-open (start, end) row range it
  occupies, suitable for use as `slice(*indices)`.

  Args:
    max_passes: The number of passes used to construct the input example.
    use_ccs_bq: Whether the example contains a row of CCS base quality (BQ)
      scores. When False, the CCS BQ range is the empty range (0, 0) and the
      SN rows start immediately after the CCS row.

  Returns:
    A tuple of seven (start, end) tuples, in this order: bases, PW, IP,
    strand, CCS, CCS BQ, SN.
  """
  # The first four features each contribute one row per pass.
  base_indices = (0, max_passes)
  pw_indices = (max_passes, max_passes * 2)
  ip_indices = (max_passes * 2, max_passes * 3)
  strand_indices = (max_passes * 3, max_passes * 4)
  # A single CCS row follows the per-pass features.
  ccs_indices = (max_passes * 4, max_passes * 4 + 1)
  if use_ccs_bq:
    # One CCS BQ row sits between the CCS row and the four SN rows.
    ccs_bq_indices = (max_passes * 4 + 1, max_passes * 4 + 2)
    sn_indices = (max_passes * 4 + 2, max_passes * 4 + 6)
  else:
    # Empty range: slicing with it selects no rows.
    ccs_bq_indices = (0, 0)
    sn_indices = (max_passes * 4 + 1, max_passes * 4 + 5)
  return (
      base_indices,
      pw_indices,
      ip_indices,
      strand_indices,
      ccs_indices,
      ccs_bq_indices,
      sn_indices,
  )


@tf.function
Expand All @@ -103,13 +127,22 @@ def format_rows(
config_dict.FrozenConfigDict]
) -> tf.Tensor:
"""Returns model input matrix formatted based on input args."""
base_indices, pw_indices, ip_indices, strand_indices, ccs_indices, sn_indices = get_indices(
params.max_passes)
(
base_indices,
pw_indices,
ip_indices,
strand_indices,
ccs_indices,
ccs_bq_indices,
sn_indices,
) = get_indices(params.max_passes, params.use_ccs_bq)

base_rows = subreads[slice(*base_indices)]
pw_rows = subreads[slice(*pw_indices)]
ip_rows = subreads[slice(*ip_indices)]
strand_rows = subreads[slice(*strand_indices)]
ccs_rows = subreads[slice(*ccs_indices)]
ccs_bq_rows = subreads[slice(*ccs_bq_indices)]
sn_rows = subreads[slice(*sn_indices)]

if params.PW_MAX:
Expand All @@ -121,8 +154,26 @@ def format_rows(
if params.SN_MAX:
sn_rows = tf.clip_by_value(
sn_rows, clip_value_min=0, clip_value_max=params.SN_MAX)
rows = tf.concat(
[base_rows, pw_rows, ip_rows, strand_rows, ccs_rows, sn_rows], axis=0)
if params.use_ccs_bq:
features = [
base_rows,
pw_rows,
ip_rows,
strand_rows,
ccs_rows,
ccs_bq_rows,
sn_rows,
]
else:
features = [
base_rows,
pw_rows,
ip_rows,
strand_rows,
ccs_rows,
sn_rows,
]
rows = tf.concat(features, axis=0)
rows.set_shape((params.total_rows, params.max_length, 1))
return rows

Expand Down
40 changes: 38 additions & 2 deletions deepconsensus/models/data_providers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,8 +226,15 @@ def test_get_dataset_with_pw_ip(self, num_epochs, batch_size, inference):
for data in dataset.as_numpy_iterator():
rows = data['rows']
check_not_empty = True
base_indices, pw_indices, ip_indices, strand_indices, ccs_indices, sn_indices = data_providers.get_indices(
params.max_passes)
(
base_indices,
pw_indices,
ip_indices,
strand_indices,
ccs_indices,
_,
sn_indices,
) = data_providers.get_indices(params.max_passes, params.use_ccs_bq)
base_rows = rows[:, slice(*base_indices), :, :]
pw_rows = rows[:, slice(*pw_indices), :, :]
ip_rows = rows[:, slice(*ip_indices), :, :]
Expand Down Expand Up @@ -316,5 +323,34 @@ def test_get_total_rows(self, max_passes, use_ccs_bq, expected_total_rows):
self.assertEqual(total_rows, expected_total_rows)


class GetIndicesTest(parameterized.TestCase):
  """Tests for data_providers.get_indices."""

  @parameterized.named_parameters(
      dict(
          testcase_name='use_ccs_bq',
          use_ccs_bq=True,
          expected_ccs_bq_rows=(81, 82),
          expected_sn_rows=(82, 86),
      ),
      dict(
          testcase_name='no_use_ccs_bq',
          use_ccs_bq=False,
          expected_ccs_bq_rows=(0, 0),
          expected_sn_rows=(81, 85),
      ),
  )
  def test_get_indices(
      self, use_ccs_bq, expected_ccs_bq_rows, expected_sn_rows
  ):
    """Checks the CCS BQ and SN row ranges for an example with 20 passes."""
    num_passes = 20
    # Only the last two ranges depend on use_ccs_bq. The first five are
    # unpacked and discarded, which still enforces a 7-element result.
    (_, _, _, _, _, actual_ccs_bq, actual_sn) = data_providers.get_indices(
        num_passes, use_ccs_bq)
    self.assertEqual(actual_ccs_bq, expected_ccs_bq_rows)
    self.assertEqual(actual_sn, expected_sn_rows)


# Run the test module through absltest when executed as a script.
if __name__ == '__main__':
  absltest.main()
5 changes: 4 additions & 1 deletion deepconsensus/models/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,10 @@ def extract_example_height(dataset_sharded_path: str) -> int:
def get_ccs_from_example(features: tf.Tensor,
params: ml_collections.ConfigDict) -> tf.Tensor:
"""Gets CCS sequence from model input features."""
_, _, _, _, ccs_index, _ = data_providers.get_indices(params['max_passes'])
_, _, _, _, ccs_index, _, _ = data_providers.get_indices(
params.max_passes,
params.use_ccs_bq,
)
# CCS tensor with shape [batch_size, 1, max_length, 1].
ccs = tf.gather(features, tf.range(*ccs_index), axis=1)
# Return CCS tensor with shape [batch_size, max_length].
Expand Down
16 changes: 12 additions & 4 deletions deepconsensus/models/networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,8 +398,18 @@ def encode(self, inputs: tf.Tensor, attention_bias: tf.Tensor,
# [batch_size, length, embedding_size]. Embed each row of the input
# separately and then concatenate.
embedded_inputs = []
base_indices, pw_indices, ip_indices, strand_indices, ccs_indices, sn_indices = data_providers.get_indices(
self.params['max_passes'])
(
base_indices,
pw_indices,
ip_indices,
strand_indices,
ccs_indices,
_,
sn_indices,
) = data_providers.get_indices(
self.params.max_passes,
self.params.use_ccs_bq,
)
if self.params.use_bases:
for i in range(*base_indices):
# Shape: [batch_size, length, per_base_hidden_size]
Expand Down Expand Up @@ -431,8 +441,6 @@ def encode(self, inputs: tf.Tensor, attention_bias: tf.Tensor,
tf.cast(inputs[:, :, i], tf.int32))
embedded_inputs.append(embedded)

# TODO: experiment with computing a weighted average using snr as
# weights to aggregate subread-level embeddings (instead of concatenating).
if self.params.use_sn:
# The last four elements in the last dimension in the inputs tensor
# correspond to the four signal-to-noise ratio scores for A, G, C, T.
Expand Down
5 changes: 4 additions & 1 deletion deepconsensus/utils/colab_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,10 @@ def convert_to_bases(rows: tf.Tensor, label: tf.Tensor,
rows = tf.squeeze(rows)
label = tf.squeeze(label)
deepconsensus_pred = tf.squeeze(deepconsensus_pred)
base_indices, _, _, _, _, _ = data_providers.get_indices(max_passes)
base_indices, _, _, _, _, _, _ = data_providers.get_indices(
max_passes,
use_ccs_bq=False,
)
subread_rows_range = range(*base_indices)
subread_rows = [rows[i, :].numpy() for i in subread_rows_range]
subread_rows = [row for row in subread_rows if np.sum(row) != 0]
Expand Down

0 comments on commit 5c09997

Please sign in to comment.