Output counters in inference mode.
PiperOrigin-RevId: 480485704
akolesnikov authored and Copybara-Service committed Oct 11, 2022
1 parent e0e6d28 commit 4662140
Showing 1 changed file with 24 additions and 5 deletions.
deepconsensus/inference/quick_inference.py: 24 additions & 5 deletions
@@ -35,10 +35,12 @@
   --output=predictions.fastq
 """
 
+import collections
 import concurrent.futures
 import dataclasses
 import enum
 import itertools
+import json
 import multiprocessing
 import os
 import time
@@ -445,7 +447,7 @@ def initialize_model(
 
 def preprocess(
     one_zmw: Tuple[str, List[pre_lib.Read], pre_lib.DcConfig]
-) -> List[Dict[str, Any]]:
+) -> Tuple[List[Dict[str, Any]], Optional[Any]]:
   """Preprocess input data for one ZMW into windows of features.
 
   This is often run from multiple processes in parallel, which creates some
@@ -457,18 +459,18 @@ def preprocess(
   Returns:
     A list of feature dictionaries, one for each window.
     A list of time log dictionaries, one for each of the two stages.
+    collections.Counter with inference counters.
   """
   zmw, subreads, dc_config = one_zmw
 
   dc_whole_zmw = pre_lib.subreads_to_dc_example(
       subreads=subreads, ccs_seqname=zmw, dc_config=dc_config)
   if dc_whole_zmw is None or FLAGS.end_after_stage == DebugStage.DC_INPUT:
-    return []
+    return ([], None)
 
   # One feature dictionary per window/example.
   feature_dicts = [x.to_features_dict() for x in dc_whole_zmw.iter_examples()]
-  return feature_dicts
+  return feature_dicts, dc_whole_zmw.counter
 
 
 def process_skipped_window(
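For callers, the contract change is that preprocess now returns a pair: the early-exit path yields ([], None) rather than [], so the counter slot can be None and needs a guard before merging (a bare += on None would fail, since Counter.__iadd__ iterates the right-hand operand). A minimal, stdlib-only sketch of that handling; the counter keys are illustrative, not the actual keys produced by pre_lib:

    import collections

    stats_counter = collections.Counter()

    # Simulated per-ZMW results: one normal ZMW, one that hit the early exit.
    outputs = [
        ([{'window': 0}, {'window': 1}], collections.Counter(examples=2)),
        ([], None),
    ]

    for feature_dicts, counter in outputs:
      if counter is not None:  # early-exit ZMWs contribute no counts
        stats_counter += counter

    print(stats_counter)  # Counter({'examples': 2})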
@@ -496,6 +498,7 @@ def process_skipped_window(
   return dc_output
 
 
+# TODO: Combine outcome_counter and stats_counter.
 def inference_on_n_zmws(
     inputs: Sequence[Tuple[str, str, Sequence[Any]]],
     model: tf.keras.Model,
@@ -504,6 +507,7 @@
     options: InferenceOptions,
     batch_name: str,
     outcome_counter: stitch_utils.OutcomeCounter,
+    stats_counter: Any,
     pool: Optional[concurrent.futures.ProcessPoolExecutor] = None) -> None:
   """Runs the full inference process on a batch of ZMWs and writes to fastq.
@@ -516,6 +520,7 @@
     options: Some options that apply to various stages of the inference run.
     batch_name: Name of batch used for runtime metrics.
     outcome_counter: Counts outcomes for each ZMW.
+    stats_counter: Global counter to gather example statistics.
     pool: Process pool to run the preprocessing on. If None or empty,
       preprocessing will be done sequentially on the main process.
   """
@@ -529,8 +534,10 @@
     # Each call to preprocess gets one ZMW from inputs.
     outputs = list(pool.map(preprocess, inputs))
 
-  feature_dicts_for_zmws = outputs
+  feature_dicts_for_zmws, counters = zip(*outputs)
   num_zmws = len(feature_dicts_for_zmws)
+  for counter in counters:
+    stats_counter += counter
 
   batch_total_examples = sum([len(zmw) for zmw in feature_dicts_for_zmws])
   batch_total_subreads = sum([len(subreads) for _, subreads, _ in inputs])
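Here zip(*outputs) transposes the list of (feature_dicts, counter) pairs into two parallel tuples, after which each per-ZMW counter is folded into the global one key by key. A quick illustration with made-up values:

    import collections

    outputs = [
        ([{'w': 0}], collections.Counter(examples=1, subreads=7)),
        ([{'w': 0}, {'w': 1}], collections.Counter(examples=2, subreads=9)),
    ]
    feature_dicts_for_zmws, counters = zip(*outputs)

    stats_counter = collections.Counter()
    for counter in counters:
      stats_counter += counter  # merges counts key by key

    print(len(feature_dicts_for_zmws))  # 2 ZMWs
    print(stats_counter)                # Counter({'subreads': 16, 'examples': 3})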
@@ -661,6 +668,14 @@ def save_runtime(time_points, output_prefix):
     df.to_csv(writer, index=False)
 
 
+# TODO: Add annotation for Counter once we move to Python 3.9
+def save_counters(counter: Any, output_prefix: str):
+  """Output statistics into a file."""
+  json_stats = json.dumps(counter, indent=True)
+  with tf.io.gfile.GFile(f'{output_prefix}.json', 'w') as writer:
+    writer.write(json_stats)
+
+
 def run() -> stitch_utils.OutcomeCounter:
   """Performs an inference run."""

@@ -708,6 +723,7 @@ def run() -> stitch_utils.OutcomeCounter:
       dc_calibration_values=dc_calibration_values,
       ccs_calibration_values=ccs_calibration_values)
   outcome_counter = stitch_utils.OutcomeCounter()
+  stats_counter = collections.Counter()
 
   pool = None
   if options.cpus > 0:
@@ -768,6 +784,7 @@
           options=options,
           batch_name=str(batch_count),
           outcome_counter=outcome_counter,
+          stats_counter=stats_counter,
           pool=pool)
       batch_count += 1
       stored_n_zmws = []
@@ -783,6 +800,7 @@
         options=options,
         batch_name=str(batch_count),
         outcome_counter=outcome_counter,
+        stats_counter=stats_counter,
         pool=pool)
 
   if pool:
@@ -794,6 +812,7 @@
       time.time() - before_all_zmws)
   logging.info('Outcome counts: %s', outcome_counter)
   save_runtime(time_points=timing, output_prefix=f'{output_fname}.runtime')
+  save_counters(stats_counter, output_prefix=f'{output_fname}.inference')
   return outcome_counter
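Since save_counters appends .json to its prefix, the accumulated stats land on disk as {output_fname}.inference.json next to the runtime output. Reading them back is plain JSON; a sketch with a hypothetical path (the actual filename depends on how run() derives output_fname):

    import json

    # Hypothetical path: run() passes output_prefix=f'{output_fname}.inference',
    # and save_counters() appends '.json'.
    with open('predictions.inference.json') as f:
      stats = json.load(f)  # a plain dict of the accumulated counts

    print(stats)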

