[go: nahoru, domu]

Skip to content

Commit

Permalink
Refactor batch ptransform in PredictionsExtractor.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 489042610
  • Loading branch information
tf-model-analysis-team authored and tfx-copybara committed Nov 16, 2022
1 parent a21dc44 commit 8a2800c
Showing 1 changed file with 7 additions and 6 deletions.
13 changes: 7 additions & 6 deletions tensorflow_model_analysis/extractors/predictions_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,11 +331,12 @@ def _RunInference(extracts: beam.pvalue.PCollection,
  # Beam batch will group single Extracts into a batch. Then
  # merge_extracts will flatten the batch into a single "batched"
  # extract.
- batch_extracts_stage_name = 'BatchSingleExampleExtracts'
  if batch_size is not None:
- extracts |= batch_extracts_stage_name >> beam.BatchElements(
- min_batch_size=batch_size, max_batch_size=batch_size)
+ batch_kwargs = {'min_batch_size': batch_size, 'max_batch_size': batch_size}
  else:
- extracts |= batch_extracts_stage_name >> beam.BatchElements()
- return extracts | 'MergeExtracts' >> beam.Map(
- util.merge_extracts, squeeze_two_dim_vector=False)
+ # Default batch parameters.
+ batch_kwargs = {}
+ return (extracts
+ | 'BatchSingleExampleExtracts' >> beam.BatchElements(**batch_kwargs)
+ | 'MergeExtracts' >> beam.Map(
+ util.merge_extracts, squeeze_two_dim_vector=False))

0 comments on commit 8a2800c

Please sign in to comment.