Adding a test case validating the expected TFT graph for a 2-phase analysis which is fully covered by cache.

PiperOrigin-RevId: 405878964
zoyahav authored and tf-transform-team committed Oct 27, 2021
1 parent b06e87b commit 2ddd0df
Showing 1 changed file with 99 additions and 11 deletions.
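For context, the fixture under test exercises a two-phase analysis: the phase-0 analyzers (a mean-and-var combine over `x` and a vocabulary over `s`) depend only on raw inputs, so their per-span accumulators are cacheable, while the phase-1 analyzer (a mean-and-var combine over the squared deviations) consumes a phase-0 result and requires a second pass. "Fully covered by cache" means every phase-0 analyzer has a cache entry for every input span, so the optimized graph begins at DecodeCache nodes rather than an ApplySavedModel pass over the raw data for phase 0. Below is a minimal sketch of such a preprocessing_fn, reconstructed from the node labels in the expected graph; the actual fixture, _preprocessing_fn_for_common_optimize_traversal, is not shown in this diff and may differ in detail.

import tensorflow as tf
import tensorflow_transform as tft


def sketch_two_phase_preprocessing_fn(inputs):
  # Phase 0: analyzers over raw inputs only; their accumulators can be
  # cached per input span.
  _ = tft.vocabulary(inputs['s'], name='vocabulary')
  x = inputs['x']
  x_mean = tft.mean(x, name='x')
  # Phase 1: depends on the phase-0 mean, so a second analysis pass over
  # the data is still needed even when phase 0 is served from cache.
  x_square_deviations = tf.square(x - x_mean)
  x_var = tft.mean(x_square_deviations, name='x_square_deviations')
  return {'x_normalized': (x - x_mean) / tf.sqrt(x_var)}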
110 changes: 99 additions & 11 deletions tensorflow_transform/beam/cached_impl_test.py
@@ -18,13 +18,15 @@
import functools
import os
import struct
+from typing import Callable, Mapping, List
import apache_beam as beam
from apache_beam.testing import util as beam_test_util
import numpy as np

import tensorflow as tf
import tensorflow_transform as tft
from tensorflow_transform import analyzer_nodes
+from tensorflow_transform import common_types
from tensorflow_transform import impl_helper
from tensorflow_transform import nodes
import tensorflow_transform.beam as tft_beam
@@ -73,10 +75,10 @@ def _preprocessing_fn_for_common_optimize_traversal(inputs):
's': tf.io.FixedLenFeature([], tf.string)
},
preprocessing_fn=_preprocessing_fn_for_common_optimize_traversal,
-    dataset_input_cache_dict={
+    dataset_input_cache_dicts=[{
         _make_cache_key(b'CacheableCombineAccumulate[x#mean_and_var]'):
             'cache hit',
-    },
+    }],
expected_dot_graph_str=r"""digraph G {
directed=True;
node [shape=Mrecord];
@@ -165,6 +167,83 @@ def _preprocessing_fn_for_common_optimize_traversal(inputs):
}
""")

_OPTIMIZE_TRAVERSAL_MULTI_PHASE_FULL_CACHE_HIT_CASE = dict(
testcase_name='multi_phase_full_cache_coverage',
feature_spec={
'x': tf.io.FixedLenFeature([], tf.float32),
's': tf.io.FixedLenFeature([], tf.string)
},
preprocessing_fn=_preprocessing_fn_for_common_optimize_traversal,
dataset_input_cache_dicts=[{
_make_cache_key(b'CacheableCombineAccumulate[x#mean_and_var]'):
'cache hit',
_make_cache_key(b'VocabularyAccumulate[vocabulary]'):
'cache hit',
}] * 2,
expected_dot_graph_str=r"""digraph G {
directed=True;
node [shape=Mrecord];
"DecodeCache[VocabularyAccumulate[vocabulary]][AnalysisIndex0]" [label="{DecodeCache|dataset_key: DatasetKey(key='span-0')|cache_key: \<bytes\>|coder: \<_VocabularyAccumulatorCoder\>|label: DecodeCache[VocabularyAccumulate[vocabulary]][AnalysisIndex0]|partitionable: True}"];
"DecodeCache[VocabularyAccumulate[vocabulary]][AnalysisIndex1]" [label="{DecodeCache|dataset_key: DatasetKey(key='span-1')|cache_key: \<bytes\>|coder: \<_VocabularyAccumulatorCoder\>|label: DecodeCache[VocabularyAccumulate[vocabulary]][AnalysisIndex1]|partitionable: True}"];
"FlattenCache[VocabularyMerge[vocabulary]]" [label="{Flatten|label: FlattenCache[VocabularyMerge[vocabulary]]|partitionable: True}"];
"DecodeCache[VocabularyAccumulate[vocabulary]][AnalysisIndex0]" -> "FlattenCache[VocabularyMerge[vocabulary]]";
"DecodeCache[VocabularyAccumulate[vocabulary]][AnalysisIndex1]" -> "FlattenCache[VocabularyMerge[vocabulary]]";
"VocabularyMerge[vocabulary]" [label="{VocabularyMerge|vocab_ordering_type: 1|use_adjusted_mutual_info: False|min_diff_from_avg: None|label: VocabularyMerge[vocabulary]}"];
"FlattenCache[VocabularyMerge[vocabulary]]" -> "VocabularyMerge[vocabulary]";
"VocabularyCount[vocabulary]" [label="{VocabularyCount|label: VocabularyCount[vocabulary]}"];
"VocabularyMerge[vocabulary]" -> "VocabularyCount[vocabulary]";
"CreateTensorBinding[vocabulary#vocab_vocabulary_unpruned_vocab_size]" [label="{CreateTensorBinding|tensor: vocabulary/vocab_vocabulary_unpruned_vocab_size:0|is_asset_filepath: False|label: CreateTensorBinding[vocabulary#vocab_vocabulary_unpruned_vocab_size]}"];
"VocabularyCount[vocabulary]" -> "CreateTensorBinding[vocabulary#vocab_vocabulary_unpruned_vocab_size]";
"VocabularyPrune[vocabulary]" [label="{VocabularyPrune|top_k: None|frequency_threshold: 0|informativeness_threshold: -inf|coverage_top_k: None|coverage_frequency_threshold: 0|coverage_informativeness_threshold: -inf|key_fn: None|filter_newline_characters: True|input_dtype: string|label: VocabularyPrune[vocabulary]}"];
"VocabularyMerge[vocabulary]" -> "VocabularyPrune[vocabulary]";
"VocabularyOrderAndWrite[vocabulary]" [label="{VocabularyOrderAndWrite|vocab_filename: vocab_vocabulary|store_frequency: False|input_dtype: string|label: VocabularyOrderAndWrite[vocabulary]|fingerprint_shuffle: False|file_format: text|input_is_sorted: False}"];
"VocabularyPrune[vocabulary]" -> "VocabularyOrderAndWrite[vocabulary]";
"CreateTensorBinding[vocabulary#Placeholder]" [label="{CreateTensorBinding|tensor: vocabulary/Placeholder:0|is_asset_filepath: True|label: CreateTensorBinding[vocabulary#Placeholder]}"];
"VocabularyOrderAndWrite[vocabulary]" -> "CreateTensorBinding[vocabulary#Placeholder]";
"DecodeCache[CacheableCombineAccumulate[x#mean_and_var]][AnalysisIndex0]" [label="{DecodeCache|dataset_key: DatasetKey(key='span-0')|cache_key: \<bytes\>|coder: \<JsonNumpyCacheCoder\>|label: DecodeCache[CacheableCombineAccumulate[x#mean_and_var]][AnalysisIndex0]|partitionable: True}"];
"DecodeCache[CacheableCombineAccumulate[x#mean_and_var]][AnalysisIndex1]" [label="{DecodeCache|dataset_key: DatasetKey(key='span-1')|cache_key: \<bytes\>|coder: \<JsonNumpyCacheCoder\>|label: DecodeCache[CacheableCombineAccumulate[x#mean_and_var]][AnalysisIndex1]|partitionable: True}"];
"FlattenCache[CacheableCombineMerge[x#mean_and_var]]" [label="{Flatten|label: FlattenCache[CacheableCombineMerge[x#mean_and_var]]|partitionable: True}"];
"DecodeCache[CacheableCombineAccumulate[x#mean_and_var]][AnalysisIndex0]" -> "FlattenCache[CacheableCombineMerge[x#mean_and_var]]";
"DecodeCache[CacheableCombineAccumulate[x#mean_and_var]][AnalysisIndex1]" -> "FlattenCache[CacheableCombineMerge[x#mean_and_var]]";
"CacheableCombineMerge[x#mean_and_var]" [label="{CacheableCombineMerge|combiner: \<WeightedMeanAndVarCombiner\>|label: CacheableCombineMerge[x#mean_and_var]}"];
"FlattenCache[CacheableCombineMerge[x#mean_and_var]]" -> "CacheableCombineMerge[x#mean_and_var]";
"ExtractCombineMergeOutputs[x#mean_and_var]" [label="{ExtractCombineMergeOutputs|output_tensor_info_list: [TensorInfo(dtype=tf.float32, shape=(), temporary_asset_value=None), TensorInfo(dtype=tf.float32, shape=(), temporary_asset_value=None)]|label: ExtractCombineMergeOutputs[x#mean_and_var]|{<0>0|<1>1}}"];
"CacheableCombineMerge[x#mean_and_var]" -> "ExtractCombineMergeOutputs[x#mean_and_var]";
"CreateTensorBinding[x#mean_and_var#Placeholder]" [label="{CreateTensorBinding|tensor: x/mean_and_var/Placeholder:0|is_asset_filepath: False|label: CreateTensorBinding[x#mean_and_var#Placeholder]}"];
"ExtractCombineMergeOutputs[x#mean_and_var]":0 -> "CreateTensorBinding[x#mean_and_var#Placeholder]";
"CreateTensorBinding[x#mean_and_var#Placeholder_1]" [label="{CreateTensorBinding|tensor: x/mean_and_var/Placeholder_1:0|is_asset_filepath: False|label: CreateTensorBinding[x#mean_and_var#Placeholder_1]}"];
"ExtractCombineMergeOutputs[x#mean_and_var]":1 -> "CreateTensorBinding[x#mean_and_var#Placeholder_1]";
"CreateSavedModelForAnalyzerInputs[Phase1]" [label="{CreateSavedModel|table_initializers: 0|output_signature: OrderedDict([('x_square_deviations/mean_and_var/Cast_1', \"Tensor\<shape: [], \<dtype: 'float32'\>\>\"), ('x_square_deviations/mean_and_var/div_no_nan', \"Tensor\<shape: [], \<dtype: 'float32'\>\>\"), ('x_square_deviations/mean_and_var/div_no_nan_1', \"Tensor\<shape: [], \<dtype: 'float32'\>\>\"), ('x_square_deviations/mean_and_var/zeros', \"Tensor\<shape: [], \<dtype: 'float32'\>\>\")])|label: CreateSavedModelForAnalyzerInputs[Phase1]}"];
"CreateTensorBinding[vocabulary#vocab_vocabulary_unpruned_vocab_size]" -> "CreateSavedModelForAnalyzerInputs[Phase1]";
"CreateTensorBinding[vocabulary#Placeholder]" -> "CreateSavedModelForAnalyzerInputs[Phase1]";
"CreateTensorBinding[x#mean_and_var#Placeholder]" -> "CreateSavedModelForAnalyzerInputs[Phase1]";
"CreateTensorBinding[x#mean_and_var#Placeholder_1]" -> "CreateSavedModelForAnalyzerInputs[Phase1]";
"ExtractInputForSavedModel[FlattenedDataset]" [label="{ExtractInputForSavedModel|dataset_key: DatasetKey(key='FlattenedDataset')|label: ExtractInputForSavedModel[FlattenedDataset]}"];
"ApplySavedModel[Phase1]" [label="{ApplySavedModel|phase: 1|label: ApplySavedModel[Phase1]|partitionable: True}"];
"CreateSavedModelForAnalyzerInputs[Phase1]" -> "ApplySavedModel[Phase1]";
"ExtractInputForSavedModel[FlattenedDataset]" -> "ApplySavedModel[Phase1]";
"TensorSource[x_square_deviations#mean_and_var]" [label="{ExtractFromDict|keys: ('x_square_deviations/mean_and_var/Cast_1', 'x_square_deviations/mean_and_var/div_no_nan', 'x_square_deviations/mean_and_var/div_no_nan_1', 'x_square_deviations/mean_and_var/zeros')|label: TensorSource[x_square_deviations#mean_and_var]|partitionable: True}"];
"ApplySavedModel[Phase1]" -> "TensorSource[x_square_deviations#mean_and_var]";
"CacheableCombineAccumulate[x_square_deviations#mean_and_var]" [label="{CacheableCombineAccumulate|combiner: \<WeightedMeanAndVarCombiner\>|label: CacheableCombineAccumulate[x_square_deviations#mean_and_var]|partitionable: True}"];
"TensorSource[x_square_deviations#mean_and_var]" -> "CacheableCombineAccumulate[x_square_deviations#mean_and_var]";
"CacheableCombineMerge[x_square_deviations#mean_and_var]" [label="{CacheableCombineMerge|combiner: \<WeightedMeanAndVarCombiner\>|label: CacheableCombineMerge[x_square_deviations#mean_and_var]}"];
"CacheableCombineAccumulate[x_square_deviations#mean_and_var]" -> "CacheableCombineMerge[x_square_deviations#mean_and_var]";
"ExtractCombineMergeOutputs[x_square_deviations#mean_and_var]" [label="{ExtractCombineMergeOutputs|output_tensor_info_list: [TensorInfo(dtype=tf.float32, shape=(), temporary_asset_value=None), TensorInfo(dtype=tf.float32, shape=(), temporary_asset_value=None)]|label: ExtractCombineMergeOutputs[x_square_deviations#mean_and_var]|{<0>0|<1>1}}"];
"CacheableCombineMerge[x_square_deviations#mean_and_var]" -> "ExtractCombineMergeOutputs[x_square_deviations#mean_and_var]";
"CreateTensorBinding[x_square_deviations#mean_and_var#Placeholder]" [label="{CreateTensorBinding|tensor: x_square_deviations/mean_and_var/Placeholder:0|is_asset_filepath: False|label: CreateTensorBinding[x_square_deviations#mean_and_var#Placeholder]}"];
"ExtractCombineMergeOutputs[x_square_deviations#mean_and_var]":0 -> "CreateTensorBinding[x_square_deviations#mean_and_var#Placeholder]";
"CreateTensorBinding[x_square_deviations#mean_and_var#Placeholder_1]" [label="{CreateTensorBinding|tensor: x_square_deviations/mean_and_var/Placeholder_1:0|is_asset_filepath: False|label: CreateTensorBinding[x_square_deviations#mean_and_var#Placeholder_1]}"];
"ExtractCombineMergeOutputs[x_square_deviations#mean_and_var]":1 -> "CreateTensorBinding[x_square_deviations#mean_and_var#Placeholder_1]";
CreateSavedModel [label="{CreateSavedModel|table_initializers: 0|output_signature: OrderedDict([('x_normalized', \"Tensor\<shape: [None], \<dtype: 'float32'\>\>\")])|label: CreateSavedModel}"];
"CreateTensorBinding[vocabulary#vocab_vocabulary_unpruned_vocab_size]" -> CreateSavedModel;
"CreateTensorBinding[vocabulary#Placeholder]" -> CreateSavedModel;
"CreateTensorBinding[x#mean_and_var#Placeholder]" -> CreateSavedModel;
"CreateTensorBinding[x#mean_and_var#Placeholder_1]" -> CreateSavedModel;
"CreateTensorBinding[x_square_deviations#mean_and_var#Placeholder]" -> CreateSavedModel;
"CreateTensorBinding[x_square_deviations#mean_and_var#Placeholder_1]" -> CreateSavedModel;
}
""")

_TF_VERSION_NAMED_PARAMETERS = [
dict(testcase_name='CompatV1', use_tf_compat_v1=True),
dict(testcase_name='V2', use_tf_compat_v1=False),
@@ -264,7 +343,7 @@ def is_partitionable(self):
testcase_name='generalized_chained_ptransforms',
feature_spec={'x': tf.io.FixedLenFeature([], tf.float32)},
preprocessing_fn=_preprocessing_fn_for_generalized_chained_ptransforms,
-    dataset_input_cache_dict=None,
+    dataset_input_cache_dicts=None,
expected_dot_graph_str=r"""digraph G {
directed=True;
node [shape=Mrecord];
@@ -334,6 +413,7 @@ def is_partitionable(self):

_OPTIMIZE_TRAVERSAL_TEST_CASES = [
_OPTIMIZE_TRAVERSAL_COMMON_CASE,
+    _OPTIMIZE_TRAVERSAL_MULTI_PHASE_FULL_CACHE_HIT_CASE,
_OPTIMIZE_TRAVERSAL_GENERALIZED_CHAINED_PTRANSFORMS_CASE,
]

@@ -1006,18 +1086,26 @@ def preprocessing_fn(inputs):

@tft_unit.named_parameters(*_OPTIMIZE_TRAVERSAL_TEST_CASES)
@mock_out_cache_hash
-  def test_optimize_traversal(self, feature_spec, preprocessing_fn,
-                              dataset_input_cache_dict, expected_dot_graph_str):
-    span_0_key, span_1_key = analyzer_cache.DatasetKey(
-        'span-0'), analyzer_cache.DatasetKey('span-1')
-    if dataset_input_cache_dict is not None:
-      cache = {span_0_key: dataset_input_cache_dict}
+  def test_optimize_traversal(
+      self, feature_spec: Mapping[str, common_types.FeatureSpecType],
+      preprocessing_fn: Callable[[Mapping[str, common_types.TensorType]],
+                                 Mapping[str, common_types.TensorType]],
+      dataset_input_cache_dicts: List[Mapping[str, str]],
+      expected_dot_graph_str: str):
+    dataset_keys = [
+        analyzer_cache.DatasetKey('span-0'),
+        analyzer_cache.DatasetKey('span-1')
+    ]
+    if dataset_input_cache_dicts is not None:
+      cache = {
+          key: cache_dict
+          for key, cache_dict in zip(dataset_keys, dataset_input_cache_dicts)
+      }
     else:
       cache = {}
     dot_string = self._publish_rendered_dot_graph_file(preprocessing_fn,
                                                        feature_spec,
-                                                       {span_0_key, span_1_key},
-                                                       cache)
+                                                       set(dataset_keys), cache)

     self.assertSameElements(
         expected_dot_graph_str.split('\n'),
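The signature change above, from a single dataset_input_cache_dict (applied only to span-0) to a list of dataset_input_cache_dicts zipped against the dataset keys, is what lets the new test case declare cache hits for both spans. As an illustration only (this mapping is constructed by the test at runtime, not spelled out in the diff), the cache built for the multi_phase_full_cache_coverage case expands to:

# `analyzer_cache` is tensorflow_transform.beam.analyzer_cache;
# `_make_cache_key` is the helper defined earlier in this test module.
cache = {
    analyzer_cache.DatasetKey('span-0'): {
        _make_cache_key(b'CacheableCombineAccumulate[x#mean_and_var]'):
            'cache hit',
        _make_cache_key(b'VocabularyAccumulate[vocabulary]'): 'cache hit',
    },
    analyzer_cache.DatasetKey('span-1'): {
        _make_cache_key(b'CacheableCombineAccumulate[x#mean_and_var]'):
            'cache hit',
        _make_cache_key(b'VocabularyAccumulate[vocabulary]'): 'cache hit',
    },
}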
