[go: nahoru, domu]

Skip to content

Commit

Permalink
Propagate dtypes of analyzer outputs to be replaced in the output transformation graph.
Browse files Browse the repository at this point in the history

This fixes an issue where, if a ptransform analyzer returned (for example) a list of primitive Python ints, the resulting output tensor ended up with dtype tf.int32, which is invalid, instead of the dtype of the tensor it replaces.

PiperOrigin-RevId: 409936221
  • Loading branch information
zoyahav authored and tf-transform-team committed Nov 15, 2021
1 parent 7612590 commit 043a328
Show file tree
Hide file tree
Showing 8 changed files with 92 additions and 76 deletions.
3 changes: 3 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@

* Raise a RuntimeError if order of analyzers in traced Tensorflow Graph is
non-deterministic in TF2.
* Fix an issue where a `tft.experimental.ptransform_analyzer`'s output dtype
could be propagated incorrectly if the output was a Python primitive (e.g. a
list of ints) rather than an `np.ndarray`.

## Breaking Changes

Expand Down
3 changes: 2 additions & 1 deletion tensorflow_transform/beam/analysis_graph_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -603,7 +603,8 @@ def preprocessing_fn(input)
nodes.apply_operation(
beam_nodes.CreateTensorBinding,
translated_value_node,
tensor=str(tensor.name),
tensor_name=str(tensor.name),
dtype_enum=tensor.dtype.as_datatype_enum,
is_asset_filepath=is_asset_filepath,
label=analyzer_nodes.sanitize_label(
'CreateTensorBinding[{}]'.format(name))))
Expand Down
36 changes: 18 additions & 18 deletions tensorflow_transform/beam/analysis_graph_builder_test.py

Large diffs are not rendered by default.

8 changes: 5 additions & 3 deletions tensorflow_transform/beam/beam_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,9 @@


class CreateTensorBinding(
tfx_namedtuple.namedtuple('CreateTensorBinding',
['tensor', 'is_asset_filepath', 'label']),
tfx_namedtuple.namedtuple(
'CreateTensorBinding',
['tensor_name', 'dtype_enum', 'is_asset_filepath', 'label']),
nodes.OperationDef):
"""An operation that represents creating a tensor binding from a value.
Expand All @@ -60,8 +61,9 @@ class CreateTensorBinding(
to create a tensor binding.
Attributes:
tensor: The name of the tensor that the given value should replace as a
tensor_name: The name of the tensor that the given value should replace as a
constant tensor.
dtype_enum: The dtype of the tensor, as a TF `types_pb2.DataType` enum value.
is_asset_filepath: If true, then the replaced value will be added to the
ASSET_FILEPATHS collection if exporting a TF1 Graph.
label: A unique label for this operation.
Expand Down
28 changes: 14 additions & 14 deletions tensorflow_transform/beam/cached_impl_test.py

Large diffs are not rendered by default.

60 changes: 30 additions & 30 deletions tensorflow_transform/beam/combiner_packing_util_test.py

Large diffs are not rendered by default.

14 changes: 9 additions & 5 deletions tensorflow_transform/beam/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -507,7 +507,8 @@ def _convert_to_record_batch(


_TensorBinding = tfx_namedtuple.namedtuple(
'_TensorBinding', ['value', 'tensor_name', 'is_asset_filepath'])
'_TensorBinding',
['value', 'tensor_name', 'dtype_enum', 'is_asset_filepath'])


@beam_common.register_ptransform(beam_nodes.CreateTensorBinding)
Expand All @@ -519,13 +520,15 @@ class _CreateTensorBindingsImpl(beam.PTransform):

def __init__(self, operation, extra_args):
del extra_args
self._tensor = operation.tensor
self._dtype_enum = operation.dtype_enum
self._tensor_name = operation.tensor_name
self._is_asset_file = operation.is_asset_filepath

def expand(self, inputs):
pcoll, = inputs
return pcoll | 'ToTensorBinding' >> beam.Map(_TensorBinding, self._tensor,
self._is_asset_file)
return pcoll | 'ToTensorBinding' >> beam.Map(
_TensorBinding, self._tensor_name, self._dtype_enum,
self._is_asset_file)


def _get_tensor_replacement_map(graph, *tensor_bindings):
Expand All @@ -534,7 +537,8 @@ def _get_tensor_replacement_map(graph, *tensor_bindings):

for tensor_binding in tensor_bindings:
assert isinstance(tensor_binding, _TensorBinding), tensor_binding
replacement_tensor = tf.constant(tensor_binding.value)
replacement_tensor = tf.constant(
tensor_binding.value, tf.dtypes.as_dtype(tensor_binding.dtype_enum))
if graph is not None and tensor_binding.is_asset_filepath:
graph.add_to_collection(tf.compat.v1.GraphKeys.ASSET_FILEPATHS,
replacement_tensor)
Expand Down
16 changes: 11 additions & 5 deletions tensorflow_transform/beam/impl_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3237,13 +3237,15 @@ def analyzer_fn(inputs):

class _SumCombiner(tft_beam.experimental.PTransformAnalyzer):

def __init__(self):
def __init__(self, is_list_output):
super().__init__()
self.base_temp_dir_in_expand = None
self._is_list_output = is_list_output

# TODO(zoyahav): Add a test where the returned np.arrays are multi-dimensional.
@staticmethod
def _extract_outputs(sums):
def _extract_outputs(self, sums):
if self._is_list_output:
sums = sums.tolist()
return [beam.pvalue.TaggedOutput('0', sums[0]),
beam.pvalue.TaggedOutput('1', sums[1])]

Expand All @@ -3256,10 +3258,14 @@ def expand(self, pcoll: beam.PCollection[Tuple[np.ndarray, ...]]):
| beam.FlatMap(self._extract_outputs).with_outputs('0', '1'))
return tuple(outputs)

def testPTransformAnalyzer(self):
@tft_unit.named_parameters(
dict(testcase_name='ArrayOutput', is_list_output=False),
dict(testcase_name='ListOutput', is_list_output=True),
)
def testPTransformAnalyzer(self, is_list_output):
self._SkipIfOutputRecordBatches()

sum_combiner = self._SumCombiner()
sum_combiner = self._SumCombiner(is_list_output)

def analyzer_fn(inputs):
outputs = tft.experimental.ptransform_analyzer([inputs['x'], inputs['y']],
Expand Down

0 comments on commit 043a328

Please sign in to comment.