Combine a few beam metrics to reduce the number of counters in a pipeline.

PiperOrigin-RevId: 469712221
zoyahav authored and tfx-copybara committed Aug 24, 2022
1 parent fe2d383 commit a206811
Showing 5 changed files with 92 additions and 51 deletions.
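
The core of the change is to report cache activity through a handful of pipeline-level counters instead of one counter per cached dataset. A minimal standalone sketch of that pattern in plain Apache Beam (the metrics namespace here is an assumption; the counter name mirrors the code below):

import apache_beam as beam

_NAMESPACE = 'tfx.Transform'  # assumed namespace for this sketch

def _increment_counter(value, name):
  # One counter increment carrying a precomputed total.
  beam.metrics.Metrics.counter(_NAMESPACE, name).inc(value)

num_encode_cache = 4  # tallied once while the analysis graph is built
with beam.Pipeline() as p:
  _ = (
      p
      | 'CreateSoleCacheEncodeInstrument' >> beam.Create([num_encode_cache])
      | 'InstrumentCacheEncode' >> beam.Map(
          _increment_counter, 'cache_entries_encoded'))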
10 changes: 7 additions & 3 deletions tensorflow_transform/analyzer_nodes.py
@@ -1025,12 +1025,16 @@ def is_partitionable(self):


class InstrumentDatasetCache(
tfx_namedtuple.namedtuple('InstrumentDatasetCache',
['dataset_key', 'label']), nodes.OperationDef):
tfx_namedtuple.namedtuple('InstrumentDatasetCache', [
'input_cache_dataset_keys', 'num_encode_cache', 'num_decode_cache',
'label'
]), nodes.OperationDef):
"""OperationDef instrumenting cached datasets.
Fields:
dataset_key: A dataset key.
input_cache_dataset_keys: Dataset keys for which there is input cache.
num_encode_cache: Number of cache entries encoded.
num_decode_cache: Number of cache entries decoded.
label: A unique label for this operation.
"""
__slots__ = ()
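
With the extra fields, a single InstrumentDatasetCache operation can describe cache usage for the whole pipeline. A hedged construction sketch (field values are illustrative, and real dataset keys are analyzer-cache dataset keys rather than plain strings):

from tensorflow_transform import analyzer_nodes

op = analyzer_nodes.InstrumentDatasetCache(
    input_cache_dataset_keys=['day_1', 'day_2'],  # datasets with decoded cache
    num_encode_cache=4,   # cache entries the pipeline will write
    num_decode_cache=2,   # cache entries the pipeline will read
    label='InstrumentDatasetCache')
assert op.num_encode_cache + op.num_decode_cache == 6  # plain namedtuple fields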
66 changes: 38 additions & 28 deletions tensorflow_transform/beam/analysis_graph_builder.py
@@ -173,29 +173,33 @@ def __init__(self, dataset_keys, cache_dict, tensor_keys_to_paths,
self._cache_dict = cache_dict
self._tensor_keys_to_paths = tensor_keys_to_paths
self._dataset_has_cache_misses = collections.defaultdict(bool)
self._num_encode_cache_nodes = 0
self._num_decode_cache_nodes = 0
self.cache_output_nodes = cache_output_nodes
self._num_phases = num_phases

def _validate_operation_def(self, operation_def):
if operation_def.cache_coder is not None:
if not operation_def.is_partitionable:
raise ValueError(
'Non partitionable OperationDefs cannot be cacheable: {}'.format(
operation_def.label))
'Non partitionable OperationDefs cannot be cacheable: '
f'{operation_def.label}'
)
if operation_def.is_partitionable or operation_def.cache_coder is not None:
if operation_def.num_outputs != 1:
raise ValueError(
'Cacheable OperationDefs must have exactly 1 output: {}'.format(
operation_def.label))
'Cacheable OperationDefs must have exactly 1 output: '
f'{operation_def.label}'
)

def get_detached_sideeffect_leafs(self):
"""Returns a list of sideeffect leaf nodes after the visit is done."""
# If this is a multi-phase analysis, then all datasets have to be read
# anyway, and so we'll not instrument full cache coverage for this case.
if self._num_phases > 1:
return []
result = []
for (dataset_idx, dataset_key) in enumerate(self._sorted_dataset_keys):
dataset_keys_with_decoded_cache = []
for dataset_key in self._sorted_dataset_keys:
# Default to True here: if the dataset_key is not in the cache misses map,
# treat it as having cache misses, since it has not been
# visited in the optimization traversal.
@@ -207,12 +211,18 @@ def get_detached_sideeffect_leafs(self):
cache_dict = self._cache_dict or {}
dataset_cache_entries = cache_dict.get(dataset_key, None)
if dataset_cache_entries is not None and dataset_cache_entries.metadata:
node = nodes.apply_operation(
analyzer_nodes.InstrumentDatasetCache,
dataset_key=dataset_key,
label=f'InstrumentDatasetCache[AnalysisIndex{dataset_idx}]')
result.append(node)
return result
dataset_keys_with_decoded_cache.append(dataset_key)
if (dataset_keys_with_decoded_cache or self._num_encode_cache_nodes or
self._num_decode_cache_nodes):
return [
nodes.apply_operation(
analyzer_nodes.InstrumentDatasetCache,
input_cache_dataset_keys=dataset_keys_with_decoded_cache,
num_encode_cache=self._num_encode_cache_nodes,
num_decode_cache=self._num_decode_cache_nodes,
label='InstrumentDatasetCache')
]
return []

def _make_next_hashed_path(self, parent_hashed_paths, operation_def):
# Making a copy of parent_hashed_paths.
@@ -269,7 +279,7 @@ def visit(self, operation_def, input_values):
next_inputs = nodes.apply_multi_output_operation(
beam_nodes.Flatten,
*disaggregated_input_values,
label='FlattenCache[{}]'.format(operation_def.label))
label=f'FlattenCache[{operation_def.label}]')
else:
# Parent operation output is not cacheable, therefore we can just use
# a flattened view.
@@ -341,30 +351,31 @@ def _apply_operation_on_fine_grained_view(self, operation_def,

for (dataset_idx, dataset_key) in enumerate(self._sorted_dataset_keys):
# We use an index for the label in order to make beam labels more stable.
infix = 'AnalysisIndex{}'.format(dataset_idx)
infix = f'AnalysisIndex{dataset_idx}'
if (operation_def.cache_coder and self._cache_dict.get(
dataset_key, {}).get(cache_entry_key) is not None):
self._dataset_has_cache_misses[dataset_key] |= False
decode_cache = analyzer_nodes.DecodeCache(
dataset_key,
cache_entry_key,
coder=operation_def.cache_coder,
label='DecodeCache[{}][{}]'.format(operation_def.label, infix))
label=f'DecodeCache[{operation_def.label}][{infix}]')
(op_output,) = nodes.OperationNode(decode_cache, tuple()).outputs
self._num_decode_cache_nodes += 1
else:
value_nodes = tuple(v[dataset_key] for v in fine_grained_views)
(op_output,) = nodes.OperationNode(
operation_def._replace(
label='{}[{}]'.format(operation_def.label, infix)),
operation_def._replace(label=f'{operation_def.label}[{infix}]'),
value_nodes).outputs
if operation_def.cache_coder:
self._dataset_has_cache_misses[dataset_key] = True
encode_cache = nodes.apply_operation(
analyzer_nodes.EncodeCache,
op_output,
coder=operation_def.cache_coder,
label='EncodeCache[{}][{}]'.format(operation_def.label, infix))
label=f'EncodeCache[{operation_def.label}][{infix}]')
self.cache_output_nodes[(dataset_key, cache_entry_key)] = encode_cache
self._num_encode_cache_nodes += 1
result_fine_grained_view[dataset_key] = op_output

return result_fine_grained_view
@@ -377,16 +388,15 @@ def _visit_apply_savedmodel_operation(self, operation_def, upstream_views):

fine_grained_view = collections.OrderedDict()
for (dataset_idx, dataset_key) in enumerate(self._sorted_dataset_keys):
infix = 'AnalysisIndex{}'.format(dataset_idx)
infix = f'AnalysisIndex{dataset_idx}'
input_node = nodes.apply_operation(
beam_nodes.ExtractInputForSavedModel,
dataset_key=dataset_key,
label='ExtractInputForSavedModel[{}]'.format(infix))
label=f'ExtractInputForSavedModel[{infix}]')
# We use an index for the label in order to make beam labels more stable.
(fine_grained_view[dataset_key],) = (
nodes.OperationNode(
operation_def._replace(
label='{}[{}]'.format(operation_def.label, infix)),
operation_def._replace(label=f'{operation_def.label}[{infix}]'),
(saved_model_path_upstream_view.flattened_view,
input_node)).outputs)

@@ -404,8 +414,8 @@ def validate_value(self, value):
assert isinstance(value, _OptimizationView), value
if value.fine_grained_view:
assert set(value.fine_grained_view.keys()) == set(
self._sorted_dataset_keys), ('{} != {}'.format(
value.fine_grained_view.keys(), self._sorted_dataset_keys))
self._sorted_dataset_keys
), (f'{value.fine_grained_view.keys()} != {self._sorted_dataset_keys}')


def _perform_cache_optimization(saved_model_future, dataset_keys,
@@ -593,7 +603,7 @@ def preprocessing_fn(input)
label='ExtractInputForSavedModel[FlattenedDataset]')

while not all(sink_tensors_ready.values()):
infix = 'Phase{}'.format(phase)
infix = f'Phase{phase}'
# Determine which table init ops are ready to run in this phase
# Determine which keys of pending_tensor_replacements are ready to run
# in this phase, based in whether their dependencies are ready.
@@ -610,14 +620,14 @@ def preprocessing_fn(input)
*tensor_bindings,
table_initializers=tuple(graph_analyzer.ready_table_initializers),
output_signature=intermediate_output_signature,
label='CreateSavedModelForAnalyzerInputs[{}]'.format(infix))
label=f'CreateSavedModelForAnalyzerInputs[{infix}]')

extracted_values_dict = nodes.apply_operation(
beam_nodes.ApplySavedModel,
saved_model_future,
extracted_input_node,
phase=phase,
label='ApplySavedModel[{}]'.format(infix))
label=f'ApplySavedModel[{infix}]')

translate_visitor.phase = phase
translate_visitor.intermediate_output_signature = (
@@ -643,7 +653,7 @@ def preprocessing_fn(input)
dtype_enum=tensor.dtype.as_datatype_enum,
is_asset_filepath=is_asset_filepath,
label=analyzer_nodes.sanitize_label(
'CreateTensorBinding[{}]'.format(name))))
f'CreateTensorBinding[{name}]')))
sink_tensors_ready[hashable_tensor] = True

analyzers_input_signature.update(intermediate_output_signature)
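
The builder now tallies cache activity while it traverses the graph and emits at most one detached instrumentation leaf. A generic sketch of that tally-then-summarize pattern (not the TFT visitor API itself):

class CacheTally:
  """Accumulates cache statistics during a single graph traversal."""

  def __init__(self):
    self.num_encode_cache = 0
    self.num_decode_cache = 0
    self.datasets_with_decoded_cache = []

  def record_decode(self, dataset_key):
    self.num_decode_cache += 1
    self.datasets_with_decoded_cache.append(dataset_key)

  def record_encode(self):
    self.num_encode_cache += 1

  def summary(self):
    # One record for the whole traversal, mirroring InstrumentDatasetCache.
    return {
        'input_cache_dataset_keys': self.datasets_with_decoded_cache,
        'num_encode_cache': self.num_encode_cache,
        'num_decode_cache': self.num_decode_cache,
    }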
51 changes: 35 additions & 16 deletions tensorflow_transform/beam/analyzer_impls.py
@@ -1325,30 +1325,51 @@ def __init__(self, operation, extra_args):
def expand(self, inputs):
pcoll, = inputs

return (pcoll
| 'Encode' >> beam.Map(self._coder.encode_cache)
| 'Count' >> common.IncrementCounter('cache_entries_encoded'))
return pcoll | 'Encode' >> beam.Map(self._coder.encode_cache)


@common.register_ptransform(analyzer_nodes.InstrumentDatasetCache)
@beam.typehints.with_input_types(beam.pvalue.PBegin)
@beam.typehints.with_output_types(None)
class _InstrumentDatasetCacheImpl(beam.PTransform):
"""Instruments datasets not read due to cache hit."""
"""Instruments pipeline analysis cache usage."""

def __init__(self, operation, extra_args):
self._metadata_pcoll = (
extra_args.cache_pcoll_dict[operation.dataset_key].metadata)
self.pipeline = extra_args.pipeline
self._metadata_pcolls = tuple(extra_args.cache_pcoll_dict[k].metadata
for k in operation.input_cache_dataset_keys)
self._num_encode_cache = operation.num_encode_cache
self._num_decode_cache = operation.num_decode_cache

def _make_and_increment_counter(self, metadata):
if metadata:
beam.metrics.Metrics.counter(common.METRICS_NAMESPACE,
'analysis_input_bytes_from_cache').inc(
metadata.dataset_size)
def _make_and_increment_counter(self, value, name):
beam.metrics.Metrics.counter(common.METRICS_NAMESPACE, name).inc(value)

def expand(self, pbegin):
return (self._metadata_pcoll | 'InstrumentCachedInputBytes' >> beam.Map(
self._make_and_increment_counter))
if self._num_encode_cache > 0:
_ = (
pbegin
| 'CreateSoleCacheEncodeInstrument' >> beam.Create(
[self._num_encode_cache])
| 'InstrumentCacheEncode' >> beam.Map(
self._make_and_increment_counter, 'cache_entries_encoded'))
if self._num_decode_cache > 0:
_ = (
self.pipeline
| 'CreateSoleCacheDecodeInstrument' >> beam.Create(
[self._num_decode_cache])
| 'InstrumentCacheDecode' >> beam.Map(
self._make_and_increment_counter, 'cache_entries_decoded'))
if self._metadata_pcolls:
# Instruments datasets not read due to cache hit.
_ = (
self._metadata_pcolls | beam.Flatten(pipeline=self.pipeline)
| 'ExtractCachedInputBytes' >>
beam.Map(lambda m: m.dataset_size if m else 0)
| 'SumCachedInputBytes' >> beam.CombineGlobally(sum)
| 'InstrumentCachedInputBytes' >> beam.Map(
self._make_and_increment_counter,
'analysis_input_bytes_from_cache'))
return pbegin | 'CreateSoleEmptyOutput' >> beam.Create([])


@common.register_ptransform(analyzer_nodes.DecodeCache)
@@ -1366,9 +1387,7 @@ def __init__(self, operation, extra_args):
def expand(self, pbegin):
del pbegin # unused

return (self._cache_pcoll
| 'Decode' >> beam.Map(self._coder.decode_cache)
| 'Count' >> common.IncrementCounter('cache_entries_decoded'))
return self._cache_pcoll | 'Decode' >> beam.Map(self._coder.decode_cache)


@common.register_ptransform(analyzer_nodes.AddKey)
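
Instead of a per-dataset metadata counter, cached input bytes are now flattened, summed once, and reported through a single counter. A self-contained sketch of that pattern (the namespace constant and the metadata sizes are stand-ins):

import apache_beam as beam

_NAMESPACE = 'tfx.Transform'  # stand-in for common.METRICS_NAMESPACE

def _inc(total, name):
  beam.metrics.Metrics.counter(_NAMESPACE, name).inc(total)

with beam.Pipeline() as p:
  # Stand-ins for per-dataset DatasetCacheMetadata.dataset_size values.
  metadata_sizes = (
      p | 'SizeA' >> beam.Create([1024]),
      p | 'SizeB' >> beam.Create([2048]),
  )
  _ = (
      metadata_sizes
      | 'FlattenSizes' >> beam.Flatten()
      | 'SumCachedInputBytes' >> beam.CombineGlobally(sum)
      | 'InstrumentCachedInputBytes' >> beam.Map(
          _inc, 'analysis_input_bytes_from_cache'))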
1 change: 1 addition & 0 deletions tensorflow_transform/beam/cached_impl_test.py
@@ -418,6 +418,7 @@ def is_partitionable(self):
"FakeChainableCacheable[x/cacheable2][AnalysisIndex0]" -> "EncodeCache[FakeChainableCacheable[x/cacheable2]][AnalysisIndex0]";
"EncodeCache[FakeChainableCacheable[x/cacheable2]][AnalysisIndex1]" [label="{EncodeCache|coder: Not-a-coder-but-thats-ok!|label: EncodeCache[FakeChainableCacheable[x/cacheable2]][AnalysisIndex1]|partitionable: True}"];
"FakeChainableCacheable[x/cacheable2][AnalysisIndex1]" -> "EncodeCache[FakeChainableCacheable[x/cacheable2]][AnalysisIndex1]";
InstrumentDatasetCache [label="{InstrumentDatasetCache|input_cache_dataset_keys: []|num_encode_cache: 4|num_decode_cache: 0|label: InstrumentDatasetCache|partitionable: True}"];
}
""")

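
For completeness, combined counters like the ones above can be read back from a finished pipeline in a test via Beam's metrics API. The snippet below is a hypothetical sketch, not code from cached_impl_test.py:

import apache_beam as beam
from apache_beam.metrics.metric import MetricsFilter

p = beam.Pipeline()
_ = (
    p
    | beam.Create([4])
    | beam.Map(lambda n: beam.metrics.Metrics.counter(
        'tfx.Transform', 'cache_entries_encoded').inc(n)))
result = p.run()
result.wait_until_finish()
counters = result.metrics().query(
    MetricsFilter().with_name('cache_entries_encoded'))['counters']
assert counters and counters[0].committed == 4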
15 changes: 11 additions & 4 deletions tensorflow_transform/beam/impl.py
@@ -1097,16 +1097,23 @@ def expand(self, dataset):
telemetry.TrackRecordBatchBytes(beam_common.METRICS_NAMESPACE,
'analysis_input_bytes'))
else:
bytes_per_dataset = []
for idx, key in enumerate(sorted(input_values_pcoll_dict.keys())):
infix = f'AnalysisIndex{idx}'
if input_values_pcoll_dict[key] is not None:
bytes_per_dataset.append(input_values_pcoll_dict[key]
| f'ExtractInputBytes[{infix}]' >>
telemetry.ExtractRecordBatchBytes())
dataset_metrics[key] = (
input_values_pcoll_dict[key]
| f'InstrumentInputBytes[AnalysisPCollDict][{infix}]' >>
telemetry.TrackRecordBatchBytes(beam_common.METRICS_NAMESPACE,
'analysis_input_bytes')
bytes_per_dataset[-1]
| f'ConstructMetadata[{infix}]' >> beam.Map(
analyzer_cache.DatasetCacheMetadata))
_ = (
bytes_per_dataset
| 'FlattenAnalysisBytes' >> beam.Flatten(pipeline=pipeline)
| 'InstrumentInputBytes[AnalysisPCollDict]' >>
telemetry.IncrementCounter(beam_common.METRICS_NAMESPACE,
'analysis_input_bytes'))

# Gather telemetry on types of input features.
_ = (
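
The same flatten-and-count idea is applied to analysis input bytes: each dataset's byte count is kept as cache metadata and also folded into one shared counter. A sketch with plain Beam standing in for the telemetry helpers (byte values and the metadata dict are illustrative):

import apache_beam as beam

def _inc_analysis_bytes(num_bytes):
  # Assumed namespace; the real code passes beam_common.METRICS_NAMESPACE.
  beam.metrics.Metrics.counter('tfx.Transform', 'analysis_input_bytes').inc(num_bytes)

with beam.Pipeline() as p:
  bytes_per_dataset = [
      p | 'BytesDataset0' >> beam.Create([100]),
      p | 'BytesDataset1' >> beam.Create([250]),
  ]
  # Per-dataset: keep each byte count as (hypothetical) cache metadata.
  dataset_metrics = {
      idx: pcoll | f'ConstructMetadata{idx}' >> beam.Map(
          lambda size: {'dataset_size': size})
      for idx, pcoll in enumerate(bytes_per_dataset)
  }
  # Pipeline-wide: a single counter fed by the flattened byte counts.
  _ = (
      bytes_per_dataset
      | 'FlattenAnalysisBytes' >> beam.Flatten()
      | 'InstrumentInputBytes' >> beam.Map(_inc_analysis_bytes))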
