Add base quality calibration in quick_inference.
PiperOrigin-RevId: 455444708
kishwarshafin authored and Copybara-Service committed Jun 16, 2022
1 parent 2af1046 commit 0bb1adb
Showing 3 changed files with 161 additions and 10 deletions.
117 changes: 108 additions & 9 deletions deepconsensus/inference/quick_inference.py
@@ -145,6 +145,25 @@ class DebugStage(enum.Enum):
'on the 3rd GPU, you can set this to 2. By default, if you have GPUs, the '
'lowest index one would be used.')

# The following parameters are for quality score calibration.

flags.DEFINE_string(
    'dc_calibration', '', 'Comma-separated calibration values for a linear '
    'transformation of DeepConsensus base qualities. The values are set as '
    '"threshold,w,b", where threshold is the minimum base quality above which '
    'the linear transformation is applied, w is the coefficient, and b is the '
    'bias term of the linear transformation. Default: 0,1.197654,-0.99781. '
    'Set to "" (empty string) to perform no quality calibration.')
flags.DEFINE_string(
    'ccs_calibration', '', 'Comma-separated calibration values for a linear '
    'transformation of CCS base qualities. The values are set as '
    '"threshold,w,b", where threshold is the minimum base quality above which '
    'the linear transformation is applied, w is the coefficient, and b is the '
    'bias term of the linear transformation. Default: "". '
    'Set to "" (empty string) to perform no quality calibration.')
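
For intuition, here is a minimal sketch (not part of this commit) of the transformation these flags describe, using the default dc_calibration values quoted in the help text above; any score above the threshold is remapped as w * q + b:

# Minimal sketch (not part of this commit): what a calibration string such as
# "0,1.197654,-0.99781" encodes. Scores above the threshold are remapped.
threshold, w, b = 0.0, 1.197654, -0.99781

q = 30.0  # a predicted Phred-scaled base quality
if q > threshold:
  q = w * q + b  # 1.197654 * 30 - 0.99781 ~= 34.93

print(round(q))  # 35 after calibration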


def register_required_flags():
flags.mark_flags_as_required([
@@ -155,6 +174,22 @@ def register_required_flags():
])


@dataclasses.dataclass
class QualityCalibrationValues:
  """A structure that defines variables required for base quality calibration.

  Attributes:
    enabled: If set, calibration is enabled.
    threshold: A threshold value above which the qualities will be calibrated.
    w: Coefficient for the linear transformation.
    b: Bias term for the linear transformation.
  """
  enabled: bool
  threshold: float
  w: float
  b: float


@dataclasses.dataclass
class InferenceOptions:
"""A central place to define options used across various stages of inference.
@@ -170,12 +205,16 @@ class InferenceOptions:
min_quality: Quality threshold to filter final reads.
min_length: Length threshold to filter final reads.
batch_size: Number of examples passed through model at once.
cpus: Number of processes to use for multiprocessing. Must be positive (for
multiprocessing) or 0 (for serial execution).
skip_windows_above: Run the model only when the avg(ccs_base_qual) of the
window is below this value.
use_saved_model: True if the given checkpoint is a saved model, false if it
is a regular checkpoint.
dc_calibration_values: QualityCalibrationValues defining values to be used
for deepconsensus quality calibration.
ccs_calibration_values: QualityCalibrationValues defining values to be used
for ccs quality calibration.
"""
example_width: int
example_height: int
@@ -188,6 +227,8 @@ cpus: int
cpus: int
skip_windows_above: int
use_saved_model: bool
dc_calibration_values: QualityCalibrationValues
ccs_calibration_values: QualityCalibrationValues


timing = []
@@ -212,6 +253,30 @@ def timelog(stage: str,
timing.append(datum)


def calibrate_quality_scores(
    quality_scores: np.ndarray,
    calibration_values: QualityCalibrationValues) -> np.ndarray:
  """Calibrates quality scores using a linear transformation.

  Args:
    quality_scores: An array containing the predicted quality scores.
    calibration_values: Coefficient values of the linear transformation.

  Returns:
    An array of calibrated quality scores.
  """
  if calibration_values.threshold == 0:
    # Skip the O(n) np.where operations when the entire array is calibrated.
    return quality_scores * calibration_values.w + calibration_values.b

  w_values = np.where(quality_scores > calibration_values.threshold,
                      calibration_values.w, 1.0)
  b_values = np.where(quality_scores > calibration_values.threshold,
                      calibration_values.b, 0.0)
  calibrated_score = quality_scores * w_values + b_values
  return calibrated_score
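
A quick usage sketch with illustrative values (threshold=10, w=1.2, b=-1.0, none of which come from this commit), assuming numpy as np and the definitions above: only scores above Q10 are remapped; the rest pass through with w=1.0 and b=0.0.

values = QualityCalibrationValues(enabled=True, threshold=10.0, w=1.2, b=-1.0)
scores = np.array([5.0, 20.0, 30.0])
print(calibrate_quality_scores(scores, values))
# [ 5. 23. 35.] -- 5.0 is at or below the threshold and is left unchanged.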


# TODO Add a unit test for this function. We need to create unit test
# infrastructure that allows us to easily create input data for unit tests.
def batch_examples(feature_dicts: List[Tuple[str, Union[np.ndarray, int,
@@ -285,9 +350,11 @@ def run_model_on_examples(
y_preds = np.argmax(softmax_output, -1)
error_prob = 1 - np.max(softmax_output, axis=-1)
quality_scores = -10 * np.log10(error_prob)
if options.dc_calibration_values.enabled:
  quality_scores = calibrate_quality_scores(quality_scores,
                                            options.dc_calibration_values)
# Cap at the max allowed value and round to the nearest integer.
quality_scores = np.minimum(quality_scores, dc_constants.MAX_QUAL)
quality_scores = np.round(quality_scores, decimals=0)
quality_scores = quality_scores.astype(dtype=np.int32)
for y_pred, qs, window_pos, molecule_name in zip(y_preds, quality_scores,
window_pos_arr,
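
For reference, a minimal sketch (not part of the diff) of the Phred conversion performed above: an error probability p maps to Q = -10 * log10(p), after which the score is capped and rounded.

import numpy as np

error_prob = np.array([0.1, 0.001, 0.00001])
q = -10 * np.log10(error_prob)  # -> [10., 30., 50.]
q = np.minimum(q, 40)  # MAX_QUAL = 40 caps Q50 down to Q40
print(np.round(q).astype(np.int32))  # [10 30 40]
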
@@ -446,17 +513,23 @@ def preprocess(


def process_skipped_window(
    feature_dict: Dict[str, Any],
    options: InferenceOptions) -> stitch_utils.DCModelOutput:
  """Process a window by simply adopting the CCS sequence and base qualities."""
  rows = feature_dict['subreads']
  ccs = rows[-5, :, 0]
  ccs_seq = utils.encoded_sequence_to_string(ccs)
  ccs_quality_scores = feature_dict['ccs_base_quality_scores']
  if options.ccs_calibration_values.enabled:
    ccs_quality_scores = calibrate_quality_scores(
        ccs_quality_scores, options.ccs_calibration_values)
  ccs_quality_scores = np.minimum(ccs_quality_scores, dc_constants.MAX_QUAL)
  ccs_quality_scores = ccs_quality_scores.astype(dtype=np.int32)
  dc_output = stitch_utils.DCModelOutput(
      window_pos=feature_dict['window_pos'],
      molecule_name=feature_dict['name'],
      sequence=ccs_seq,
      quality_string=utils.quality_scores_to_string(ccs_quality_scores))
  return dc_output
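
The quality string written here is presumably the standard FASTQ Phred+33 encoding; the following is a hedged sketch of that convention, not the actual utils.quality_scores_to_string implementation, which may differ.

def quality_scores_to_string_sketch(scores) -> str:
  # Assumed Phred+33 (FASTQ-style) encoding; the real helper may differ.
  return ''.join(chr(int(q) + 33) for q in scores)

print(quality_scores_to_string_sketch([0, 30, 40]))  # '!?I'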


@@ -530,7 +603,7 @@ def inference_on_n_zmws(inputs: Sequence[Tuple[str, str, Sequence[Any]]],
if avg_ccs_base_quality <= options.skip_windows_above:
feature_dicts_for_model.append(window)
else:
dc_output_for_window = process_skipped_window(window, options)
predictions_for_skipped_windows.append(dc_output_for_window)
else:
# Go straight to model without skipping any windows.
@@ -606,6 +679,26 @@ def save_runtime(time_points, output_prefix):
df.to_csv(writer, index=False)


def parse_calibration_string(
    calibration_string: str) -> QualityCalibrationValues:
  """Parses a calibration string and returns threshold, w and b values."""
  # The calibration string is empty, so no calibration will be performed.
  if not calibration_string:
    return QualityCalibrationValues(enabled=False, threshold=0.0, w=1.0, b=0.0)

  parsed_list = calibration_string.split(',')
  if len(parsed_list) != 3:
    raise ValueError('Malformed calibration string. Expected 3 values.',
                     calibration_string)

  calibration_values = QualityCalibrationValues(
      enabled=True,
      threshold=float(parsed_list[0]),
      w=float(parsed_list[1]),
      b=float(parsed_list[2]))
  return calibration_values
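
A usage sketch mirroring the unit tests added below; the non-empty string here is the default quoted in the dc_calibration flag's help text:

values = parse_calibration_string('0,1.197654,-0.99781')
print(values.enabled, values.threshold, values.w, values.b)
# True 0.0 1.197654 -0.99781

# An empty string disables calibration entirely.
print(parse_calibration_string('').enabled)  # False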


def run() -> stitch_utils.OutcomeCounter:
"""Called by main."""
dc_config = preprocess_utils.DcConfig(FLAGS.max_passes, FLAGS.example_width,
@@ -614,6 +707,10 @@ def run() -> stitch_utils.OutcomeCounter:
use_saved_model = (
tf.io.gfile.exists(FLAGS.checkpoint) and
tf.io.gfile.exists(f'{FLAGS.checkpoint}/saved_model.pb'))

dc_calibration_values = parse_calibration_string(FLAGS.dc_calibration)
ccs_calibration_values = parse_calibration_string(FLAGS.ccs_calibration)

options = InferenceOptions(
example_width=FLAGS.example_width,
example_height=dc_config.tensor_height,
@@ -625,7 +722,9 @@
batch_size=FLAGS.batch_size,
cpus=FLAGS.cpus,
skip_windows_above=FLAGS.skip_windows_above,
use_saved_model=use_saved_model,
dc_calibration_values=dc_calibration_values,
ccs_calibration_values=ccs_calibration_values)
outcome_counter = stitch_utils.OutcomeCounter()

# Set up model.
52 changes: 52 additions & 0 deletions deepconsensus/inference/quick_inference_test.py
@@ -110,6 +110,58 @@ def test_end_to_end_multiprocessing(self, cpus, batch_zmws):
self.assertEqual(count, 2)
self.assertEqual(outcomes.success, 2)

@parameterized.parameters(
dict(
calibration_str='',
expected=quick_inference.QualityCalibrationValues(
enabled=False, threshold=0.0, w=1.0, b=0.0),
message='Test 1: Valid empty calibration string.'),
dict(
calibration_str='10,1.0,0.2222',
expected=quick_inference.QualityCalibrationValues(
enabled=True, threshold=10.0, w=1.0, b=0.2222),
message='Test 2: Valid calibration string with positive values.'),
dict(
calibration_str='-10,1.0,0.2222',
expected=quick_inference.QualityCalibrationValues(
enabled=True, threshold=-10.0, w=1.0, b=0.2222),
message='Test 3: Valid calibration string with negative threshold.'),
dict(
calibration_str='-10,-1.0,-0.2222',
expected=quick_inference.QualityCalibrationValues(
enabled=True, threshold=-10.0, w=-1.0, b=-0.2222),
message='Test 4: Valid calibration string with all negative values.'))
@flagsaver.flagsaver
def test_parse_calibration_string(self, calibration_str, expected, message):
"""Tests for parse_calibration_string method."""
returned = quick_inference.parse_calibration_string(calibration_str)
self.assertEqual(returned.enabled, expected.enabled, msg=message)
self.assertEqual(returned.threshold, expected.threshold, msg=message)
self.assertEqual(returned.w, expected.w, msg=message)
self.assertEqual(returned.b, expected.b, msg=message)

@parameterized.parameters(
dict(
calibration_str='ABCD',
message='Test 1: Invalid calibration string ABCD.'),
dict(
calibration_str='A,BC,D',
message='Test 2: Invalid calibration string A,BC,D.'),
dict(
calibration_str='10,1.0',
message='Test 3: Invalid calibration string 10,1.0.'),
dict(
calibration_str='10,AB,1.0',
message='Test 4: Invalid calibration string 10,AB,1.0.'),
dict(
calibration_str='10,0.1.1,1.0',
message='Test 5: Invalid calibration string 10,0.1.1,1.0.'),
)
@flagsaver.flagsaver
def test_parse_calibration_string_exceptions(self, calibration_str, message):
with self.assertRaises(Exception, msg=message):
quick_inference.parse_calibration_string(calibration_str)


if __name__ == '__main__':
absltest.main()
2 changes: 1 addition & 1 deletion deepconsensus/utils/dc_constants.py
@@ -107,5 +107,5 @@ class Strand(int, enum.Enum):
'ccs_base_quality_scores'
]

-MAX_QUAL = 60
+MAX_QUAL = 40
EMPTY_QUAL = 0
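
For reference (not part of the commit): under the standard Phred definition p = 10 ** (-Q / 10), lowering MAX_QUAL from 60 to 40 raises the smallest reportable error probability from 1e-6 to 1e-4.

for q in (40, 60):
  print(q, 10 ** (-q / 10))  # 40 -> 0.0001, 60 -> 1e-06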
