[go: nahoru, domu]

Skip to content

Commit

Permalink
Validate that all nodes in the TFT graph have unique labels. There shouldn't be a duplicate label unless there's a bug in TFT code, or caused by custom PTransform analyzers.
Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 397702599
  • Loading branch information
zoyahav authored and tf-transform-team committed Sep 20, 2021
1 parent 33ad15c commit f9aa0cc
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 1 deletion.
2 changes: 1 addition & 1 deletion tensorflow_transform/analyzers.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,7 +521,7 @@ def _min_and_max_per_key(
TypeError: If the type of `x` is not supported.
"""
if key is None:
raise ValueError('A key is required for _mean_and_var_per_key')
raise ValueError('A key is required for _min_and_max_per_key')

if not reduce_instance_dims:
raise NotImplementedError('Per-key elementwise reduction not supported')
Expand Down
9 changes: 9 additions & 0 deletions tensorflow_transform/beam/analysis_graph_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,8 +117,17 @@ class _ReadyVisitor(nodes.Visitor):

def __init__(self, graph_analyzer):
self._graph_analyzer = graph_analyzer
self._visited_operation_def_labels = set()

def _validate_operation_label_uniqueness(self, operation_def):
assert operation_def.label not in self._visited_operation_def_labels, (
f'An operation with label {operation_def.label} '
'already exists in the operations graph.')
self._visited_operation_def_labels.add(operation_def.label)

def visit(self, operation_def, input_values):
self._validate_operation_label_uniqueness(operation_def)

if isinstance(operation_def, analyzer_nodes.TensorSource):
is_ready = all(self._graph_analyzer.ready_to_run(tensor)
for tensor in operation_def.tensors)
Expand Down
37 changes: 37 additions & 0 deletions tensorflow_transform/beam/analysis_graph_builder_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from tensorflow_transform import analyzer_nodes
from tensorflow_transform import impl_helper
from tensorflow_transform import nodes
from tensorflow_transform import tf2_utils
from tensorflow_transform.beam import analysis_graph_builder
from tensorflow_transform import test_case
# TODO(https://issues.apache.org/jira/browse/SPARK-22674): Switch to
Expand Down Expand Up @@ -524,6 +525,42 @@ def mocked_make_cache_entry_key(_):
self.WriteRenderedDotFile(dot_string)
self.assertCountEqual(cache_entry_keys, [mocked_cache_entry_key])

def test_duplicate_label_error(self):
  """Verifies that building the analysis graph rejects duplicate labels."""

  def _preprocessing_fn(inputs):

    # Minimal OperationDef whose only payload is its label, so we can
    # force two nodes in the graph to collide on 'SameLabel'.
    class _Analyzer(
        tfx_namedtuple.namedtuple('_Analyzer', ['label']),
        nodes.OperationDef):
      pass

    source = nodes.apply_operation(
        analyzer_nodes.TensorSource, tensors=[inputs['x']])
    # Chain two operations that share the same label.
    first = nodes.apply_operation(_Analyzer, source, label='SameLabel')
    second = nodes.apply_operation(_Analyzer, first, label='SameLabel')
    x_chained = analyzer_nodes.bind_future_as_tensor(
        second, analyzer_nodes.TensorInfo(tf.float32, (17, 27), None))
    return {'x_chained': x_chained}

  feature_spec = {'x': tf.io.FixedLenFeature([], tf.float32)}
  use_tf_compat_v1 = tf2_utils.use_tf_compat_v1(False)
  if use_tf_compat_v1:
    specs = feature_spec
  else:
    specs = impl_helper.get_type_specs_from_feature_specs(feature_spec)
  temp_dir = os.path.join(self.get_temp_dir(), self._testMethodName)
  graph, structured_inputs, structured_outputs = (
      impl_helper.trace_preprocessing_function(
          _preprocessing_fn,
          specs,
          use_tf_compat_v1=use_tf_compat_v1,
          base_temp_dir=temp_dir))
  # The duplicate-label check in the graph builder should trip here.
  with self.assertRaisesRegex(AssertionError, 'SameLabel'):
    _ = analysis_graph_builder.build(graph, structured_inputs,
                                     structured_outputs)


# Run the module's tests when executed directly (e.g. `python ..._test.py`).
if __name__ == '__main__':
  test_case.main()

0 comments on commit f9aa0cc

Please sign in to comment.