[go: nahoru, domu]

Skip to content

Commit

Permalink
Add RG tags to BAM output
Browse files (browse the repository at this point in the history)
PiperOrigin-RevId: 476519144
  • Loading branch information
danielecook authored and Copybara-Service committed Sep 24, 2022
1 parent (7e710a1); commit 5f75175
Show file tree
Hide file tree
Showing 6 changed files with 31 additions and 11 deletions.
21 changes: 15 additions and 6 deletions in deepconsensus/inference/quick_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,7 @@ def run_model_on_examples(
ec_arr = data['ec']
np_num_passes_arr = data['np_num_passes']
rq_arr = data['rq']
rg_arr = data['rg']

y_preds = np.argmax(softmax_output, -1)
error_prob = 1 - np.max(softmax_output, axis=-1)
Expand All @@ -355,15 +356,16 @@ def run_model_on_examples(
quality_scores = np.minimum(quality_scores, dc_constants.MAX_QUAL)
quality_scores = np.round(quality_scores, decimals=0)
quality_scores = quality_scores.astype(dtype=np.int32)
for y_pred, qs, window_pos, molecule_name, ec, np_, rq in zip(
for y_pred, qs, window_pos, molecule_name, ec, np_, rq, rg in zip(
y_preds, quality_scores, window_pos_arr, molecule_name_arr, ec_arr,
np_num_passes_arr, rq_arr):
np_num_passes_arr, rq_arr, rg_arr):
dc_output = stitch_utils.DCModelOutput(
window_pos=window_pos,
molecule_name=molecule_name,
ec=ec,
np_num_passes=np_,
rq=rq)
rq=rq,
rg=rg)
y_pred_bases = ''.join(
np.vectorize(dc_constants.VOCAB.__getitem__)(y_pred))
quality_string = utils.quality_scores_to_string(qs)
Expand Down Expand Up @@ -528,7 +530,8 @@ def process_skipped_window(
quality_string=utils.quality_scores_to_string(ccs_quality_scores),
ec=feature_dict['ec'],
np_num_passes=feature_dict['np_num_passes'],
rq=feature_dict['rq'])
rq=feature_dict['rq'],
rg=feature_dict['rg'])
return dc_output


Expand Down Expand Up @@ -663,10 +666,15 @@ def percent_of_examples(numerator):
record.query_name = name
record.query_sequence = seq
record.query_qualities = pysam.qualitystring_to_array(qual)
record.flag = 4 # unmapped.
record.mapping_quality = 255
zmw = int(name.split('/')[1])
record.set_tags([
('ec', predictions_for_zmw[0].ec or -1, 'f'),
('np', predictions_for_zmw[0].np_num_passes, 'i'),
('rq', predictions_for_zmw[0].rq, 'f'),
('RG', predictions_for_zmw[0].rg, 'Z'),
('zm', zmw, 'i'),
])
output_writer.write(record)

Expand Down Expand Up @@ -786,8 +794,9 @@ def run() -> stitch_utils.OutcomeCounter:
if output_fname.endswith('.fq') or output_fname.endswith('.fastq'):
output_writer = gfile.Open(output_fname, 'wb')
else:
header = {'HD': {'VN': '1.0'}}
output_writer = pysam.AlignmentFile(output_fname, 'wb', header=header)
ccs_bam_header = pysam.AlignmentFile(FLAGS.ccs_bam, check_sq=False).header
output_writer = pysam.AlignmentFile(
output_fname, 'wb', header=ccs_bam_header)

input_file_generator = stream_bam(
subreads_to_ccs=FLAGS.subreads_to_ccs,
Expand Down
5 changes: 3 additions and 2 deletions in deepconsensus/models/data_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def format_rows(
def process_feature_dict(
features: Dict[str, Union[np.ndarray, int, bytes]],
params: Union[config_dict.ConfigDict, config_dict.FrozenConfigDict]
) -> Dict[str, Union[np.ndarray, int, bytes]]:
) -> Dict[str, Union[np.ndarray, int, bytes, str]]:
"""Parses a serialized tf.Example to return an input, label, and metadata.
Args:
Expand Down Expand Up @@ -165,7 +165,8 @@ def process_feature_dict(
'ccs_base_quality_scores': features['ccs_base_quality_scores'],
'ec': features['ec'],
'np_num_passes': features['np_num_passes'],
'rq': features['rq']
'rq': features['rq'],
'rg': features['rg']
}
return features

Expand Down
1 change: 1 addition and 0 deletions in deepconsensus/postprocess/stitch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ class DCModelOutput:
ec: float
np_num_passes: int
rq: float
rg: str
sequence: Optional[str] = None
quality_string: Optional[str] = None

Expand Down
3 changes: 2 additions and 1 deletion in deepconsensus/postprocess/stitch_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ def fake_model_output(start: int, window_size: int):
quality_string='!' * window_size,
ec=2.5,
np_num_passes=2,
rq=0.98)
rq=0.98,
rg='test_rg')


def fake_model_outputs(window_size: int, num_windows: int):
Expand Down
10 changes: 9 additions and 1 deletion in deepconsensus/preprocess/pre_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ class Read(abc.Sequence):
ec: Optional[float] = None # effective coverage
np_num_passes: Optional[int] = None # number of passes
rq: Optional[float] = None # predicted concordance
rg: Optional[str] = None

# base_quality_scores are only used for the ccs read.
base_quality_scores: np.ndarray = np.empty(0, dtype=np.uint8)
Expand Down Expand Up @@ -287,6 +288,7 @@ def ccs_slice(self, start, end) -> 'Read':
ec=self.ec,
np_num_passes=self.np_num_passes,
rq=self.rq,
rg=self.rg,
ccs_idx=self.ccs_idx[ccs_slice],
truth_idx=self.truth_idx[ccs_slice],
truth_range=self.truth_range)
Expand All @@ -307,6 +309,7 @@ def pad(self, pad_width):
ec=self.ec,
np_num_passes=self.np_num_passes,
rq=self.rq,
rg=self.rg,
ccs_idx=right_pad(self.ccs_idx, pad_width, -1),
truth_idx=right_pad(self.truth_idx, pad_width, -1),
truth_range=self.truth_range)
Expand All @@ -333,6 +336,7 @@ def remove_gaps(self, pad_width: int) -> Union['Read', None]:
ec=self.ec,
np_num_passes=self.np_num_passes,
rq=self.rq,
rg=self.rg,
ccs_idx=self.ccs_idx[keep],
truth_idx=self.truth_idx[keep],
truth_range=self.truth_range).pad(pad_width)
Expand All @@ -357,6 +361,7 @@ def __getitem__(self, r_slice: Union[slice, int]) -> 'Read':
ec=self.ec,
np_num_passes=self.np_num_passes,
rq=self.rq,
rg=self.rg,
ccs_idx=self.ccs_idx[r_slice],
truth_idx=self.truth_idx[r_slice])

Expand Down Expand Up @@ -605,7 +610,8 @@ def to_features_dict(self):
'ccs_base_quality_scores': self.ccs.base_quality_scores,
'ec': self.ccs.ec,
'np_num_passes': self.ccs.np_num_passes,
'rq': self.ccs.rq
'rq': self.ccs.rq,
'rg': self.ccs.rg,
}
return features

Expand Down Expand Up @@ -787,6 +793,7 @@ def get_tag(read: pysam.AlignedSegment, tag_name: str) -> Any:
ec = get_tag(ccs_bam_read, 'ec')
np_num_passes = get_tag(ccs_bam_read, 'np')
rq = get_tag(ccs_bam_read, 'rq')
rg = get_tag(ccs_bam_read, 'RG')

return Read(
name=ccs_bam_read.qname,
Expand All @@ -798,6 +805,7 @@ def get_tag(read: pysam.AlignedSegment, tag_name: str) -> Any:
ec=ec,
np_num_passes=np_num_passes,
rq=rq,
rg=rg,
strand=dc_constants.Strand.UNKNOWN,
base_quality_scores=np.array(ccs_bam_read.query_qualities),
ccs_idx=np.arange(len(ccs_seq)))
Expand Down
2 changes: 1 addition and 1 deletion in deepconsensus/utils/dc_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ class Strand(int, enum.Enum):
# List of features in DC examples.
DC_FEATURES = [
'rows', 'label', 'num_passes', 'window_pos', 'name',
'ccs_base_quality_scores', 'ec', 'np_num_passes', 'rq'
'ccs_base_quality_scores', 'ec', 'np_num_passes', 'rq', 'rg'
]

MAX_QUAL = 40
Expand Down

0 comments on commit 5f75175

Please sign in to comment.