Some improvements to ptransform_analyzer:
 - Made the input PCollection dynamically typed, allowing input type hints to specify how many np.ndarrays are expected in each PCollection item (see the sketch after the changed-files summary below).
 - Allowed custom analyzers to return a DoOutputsTuple directly instead of wrapping it in a tuple.
 - Generalized how the docstring example accesses the value for a specific key.
 - Expanded the documentation to explicitly mention that the outputs of the analyzer must match the given output_* parameters.
 - Added a test for multi-dimensional analyzer output.

PiperOrigin-RevId: 410028986
zoyahav authored and tf-transform-team committed Nov 15, 2021
1 parent 043a328 commit 34b1c79
Showing 6 changed files with 94 additions and 48 deletions.
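The headline change is the dynamically typed input PCollection: a custom analyzer can now pin down how many input arrays it expects with an ordinary type hint on expand(). Below is a minimal sketch of the idea, using a hypothetical TwoInputSum transform that is not part of this commit:

from typing import Tuple

import apache_beam as beam
import numpy as np


class TwoInputSum(beam.PTransform):
  """Hypothetical analyzer: the type hint declares exactly two input arrays."""

  def expand(self, pcoll: beam.PCollection[Tuple[np.ndarray, np.ndarray]]):
    # Pair up the two batched inputs, then sum each column across all batches.
    return (pcoll
            | beam.FlatMap(lambda batches: list(zip(*batches)))
            | beam.CombineGlobally(lambda values: np.sum(list(values), axis=0)))

Under the old fixed Tuple[np.ndarray, ...] hint the element arity was opaque; the commit removes that fixed hint, as the analyzer_impls.py and impl.py diffs below show.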
1 change: 0 additions & 1 deletion tensorflow_transform/beam/analyzer_impls.py
@@ -834,7 +834,6 @@ def extract_output(self, accumulator):
     return self._combiner.extract_output(accumulator)
 
 
-@beam.typehints.with_input_types(Tuple[np.ndarray, ...])
 class _CombinerWrapper(beam.CombineFn):
   """Class to wrap a analyzer_nodes.Combiner as a beam.CombineFn."""
 
18 changes: 15 additions & 3 deletions tensorflow_transform/beam/impl.py
@@ -87,6 +87,7 @@
 # resolved.
 from tfx_bsl.types import tfx_namedtuple
 
+from tensorflow.python.framework import ops  # pylint: disable=g-direct-tensorflow-import
 from tensorflow_metadata.proto.v0 import schema_pb2
 
 # TODO(b/123325923): Fix the key type here to agree with the actual keys.
@@ -535,10 +536,17 @@ def _get_tensor_replacement_map(graph, *tensor_bindings):
"""Get Tensor replacement map."""
tensor_replacement_map = {}

is_graph_mode = not ops.executing_eagerly_outside_functions()
for tensor_binding in tensor_bindings:
assert isinstance(tensor_binding, _TensorBinding), tensor_binding
value = tensor_binding.value
# TODO(b/160294509): tf.constant doesn't accept List[np.ndarray] in TF 1.15
# graph mode. Remove this condition.
if (is_graph_mode and isinstance(value, list) and
any(isinstance(x, np.ndarray) for x in value)):
value = np.asarray(tensor_binding.value)
replacement_tensor = tf.constant(
tensor_binding.value, tf.dtypes.as_dtype(tensor_binding.dtype_enum))
value, tf.dtypes.as_dtype(tensor_binding.dtype_enum))
if graph is not None and tensor_binding.is_asset_filepath:
graph.add_to_collection(tf.compat.v1.GraphKeys.ASSET_FILEPATHS,
replacement_tensor)
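
The new branch works around the TF 1.15 graph-mode limitation noted in the TODO above: tf.constant rejects a plain Python list of np.ndarrays there, so the binding value is first stacked into a single array. A rough standalone illustration with made-up values, not taken from the commit:

import numpy as np
import tensorflow as tf

# Made-up binding value: a list of per-instance arrays.
value = [np.array([1, 2]), np.array([3, 4])]
if isinstance(value, list) and any(isinstance(x, np.ndarray) for x in value):
  value = np.asarray(value)  # one (2, 2) ndarray that tf.constant accepts
replacement_tensor = tf.constant(value, tf.dtypes.as_dtype(np.int64))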
@@ -792,7 +800,6 @@ def _convert_to_numpy(input_dict):
 @beam.typehints.with_input_types(Dict[str,
                                       Union[np.ndarray,
                                             tf.compat.v1.SparseTensorValue]])
-@beam.typehints.with_output_types(Tuple[np.ndarray, ...])
 class _ExtractFromDictImpl(beam.PTransform):
   """Implements ExtractFromDict by extracting the configured keys."""
 
@@ -807,7 +814,12 @@ def extract_keys(input_dict, keys):
       return (tuple(input_dict[k] for k in keys)
               if isinstance(keys, tuple) else input_dict[keys])
 
-    return pcoll | 'ExtractKeys' >> beam.Map(extract_keys, keys=self._keys)
+    if isinstance(self._keys, tuple):
+      output_type = Tuple[(np.ndarray,) * len(self._keys)]
+    else:
+      output_type = np.ndarray
+    return pcoll | 'ExtractKeys' >> beam.Map(
+        extract_keys, keys=self._keys).with_output_types(output_type)
 
 
 @beam_common.register_ptransform(beam_nodes.Flatten)
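With the fixed output hint removed, _ExtractFromDictImpl now derives its output type from the configured keys at pipeline-construction time. A hedged sketch of the same pattern in a throwaway pipeline, with a hypothetical keys value:

from typing import Tuple

import apache_beam as beam
import numpy as np

keys = ('x', 'y')  # hypothetical; a single string key would yield np.ndarray
if isinstance(keys, tuple):
  output_type = Tuple[(np.ndarray,) * len(keys)]  # Tuple[np.ndarray, np.ndarray]
else:
  output_type = np.ndarray

with beam.Pipeline() as pipeline:
  _ = (pipeline
       | beam.Create([{'x': np.array([1]), 'y': np.array([2])}])
       | 'ExtractKeys' >> beam.Map(
           lambda d: tuple(d[k] for k in keys)).with_output_types(output_type)
       | beam.Map(print))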
74 changes: 48 additions & 26 deletions tensorflow_transform/beam/impl_test.py
@@ -3237,35 +3237,26 @@ def analyzer_fn(inputs):
 
   class _SumCombiner(tft_beam.experimental.PTransformAnalyzer):
 
-    def __init__(self, is_list_output):
+    def __init__(self):
       super().__init__()
       self.base_temp_dir_in_expand = None
-      self._is_list_output = is_list_output
 
-    # TODO(zoyahav): Add an test where the returned np.arrays are multi-dimensional.
     def _extract_outputs(self, sums):
-      if self._is_list_output:
-        sums = sums.tolist()
       return [beam.pvalue.TaggedOutput('0', sums[0]),
               beam.pvalue.TaggedOutput('1', sums[1])]
 
-    def expand(self, pcoll: beam.PCollection[Tuple[np.ndarray, ...]]):
+    def expand(self, pcoll: beam.PCollection[Tuple[np.ndarray, np.ndarray]]):
       self.base_temp_dir_in_expand = self.base_temp_dir
-      outputs = (
-          pcoll
-          | beam.FlatMap(lambda arrays: list(zip(*arrays)))
-          | beam.CombineGlobally(lambda values: np.sum(list(values), axis=0))
-          | beam.FlatMap(self._extract_outputs).with_outputs('0', '1'))
-      return tuple(outputs)
+      return (pcoll
+              | beam.FlatMap(lambda baches: list(zip(*baches)))
+              |
+              beam.CombineGlobally(lambda values: np.sum(list(values), axis=0))
+              | beam.FlatMap(self._extract_outputs).with_outputs('0', '1'))
 
-  @tft_unit.named_parameters(
-      dict(testcase_name='ArrayOutput', is_list_output=False),
-      dict(testcase_name='ListOutput', is_list_output=True),
-  )
-  def testPTransformAnalyzer(self, is_list_output):
+  def testPTransformAnalyzer(self):
     self._SkipIfOutputRecordBatches()
 
-    sum_combiner = self._SumCombiner(is_list_output)
+    sum_combiner = self._SumCombiner()
 
     def analyzer_fn(inputs):
       outputs = tft.experimental.ptransform_analyzer([inputs['x'], inputs['y']],
@@ -3274,8 +3265,6 @@ def analyzer_fn(inputs):
                                                      [[], []])
       return {'x_sum': outputs[0], 'y_sum': outputs[1]}
 
-    # NOTE: We force 10 batches: data has 100 elements and we request a batch
-    # size of 10.
     input_data = [{'x': 1, 'y': i} for i in range(100)]
     input_metadata = tft_unit.metadata_from_feature_spec({
         'x': tf.io.FixedLenFeature([], tf.int64),
@@ -3286,16 +3275,49 @@ def analyzer_fn(inputs):
         'y_sum': np.array(4950, np.int64)
     }
     self.assertIsNone(sum_combiner.base_temp_dir_in_expand)
-    self.assertAnalyzerOutputs(
-        input_data,
-        input_metadata,
-        analyzer_fn,
-        expected_outputs,
-        desired_batch_size=10)
+    self.assertAnalyzerOutputs(input_data, input_metadata, analyzer_fn,
+                               expected_outputs)
     self.assertIsNotNone(sum_combiner.base_temp_dir_in_expand)
     self.assertStartsWith(sum_combiner.base_temp_dir_in_expand,
                           self.get_temp_dir())
 
+  @tft_unit.named_parameters(
+      dict(
+          testcase_name='ArrayOutput',
+          output_fn=lambda x: np.array(x, np.int64)),
+      dict(testcase_name='ListOutput', output_fn=list),
+  )
+  def testPTransformAnalyzerMultiDimOutput(self, output_fn):
+    self._SkipIfOutputRecordBatches()
+
+    class _SimpleSumCombiner(tft_beam.experimental.PTransformAnalyzer):
+
+      def expand(self, pcoll: beam.PCollection[Tuple[np.ndarray, np.ndarray]]):
+        return (
+            pcoll
+            | beam.FlatMap(lambda baches: list(zip(*baches)))
+            | beam.CombineGlobally(lambda values: np.sum(list(values), axis=0))
+            | beam.combiners.ToList()
+            | beam.Map(output_fn))
+
+    sum_combiner = _SimpleSumCombiner()
+
+    def analyzer_fn(inputs):
+      outputs, = tft.experimental.ptransform_analyzer(
+          [inputs['x'], inputs['y']], sum_combiner, [tf.int64], [[1, 2]])
+      return {'x_y_sums': outputs}
+
+    input_data = [{'x': 1, 'y': i} for i in range(100)]
+    input_metadata = tft_unit.metadata_from_feature_spec({
+        'x': tf.io.FixedLenFeature([], tf.int64),
+        'y': tf.io.FixedLenFeature([], tf.int64)
+    })
+    expected_outputs = {
+        'x_y_sums': np.array([[100, 4950]], np.int64),
+    }
+    self.assertAnalyzerOutputs(input_data, input_metadata, analyzer_fn,
+                               expected_outputs)
+
   @unittest.skipIf(not common.IS_ANNOTATIONS_PB_AVAILABLE,
                    'Schema annotations are not available')
   def testSavedModelWithAnnotations(self):
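Note that the reworked _SumCombiner.expand returns the result of with_outputs('0', '1') directly. That value is a beam.pvalue.DoOutputsTuple, which the traverser now accepts without the old tuple(...) wrapping (see the nodes.py change below). A small sketch of what such a return value looks like, in a hypothetical standalone pipeline:

import apache_beam as beam


def route(pair):
  yield beam.pvalue.TaggedOutput('0', pair[0])
  yield beam.pvalue.TaggedOutput('1', pair[1])


with beam.Pipeline() as pipeline:
  outputs = (pipeline
             | beam.Create([(1, 2)])
             | beam.FlatMap(route).with_outputs('0', '1'))
  # outputs is a beam.pvalue.DoOutputsTuple: indexable by tag and iterable,
  # but it has no __len__, hence the hasattr guard added in nodes.py.
  _ = outputs['0'] | 'PrintFirst' >> beam.Map(print)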
35 changes: 21 additions & 14 deletions tensorflow_transform/experimental/analyzers.py
@@ -77,22 +77,26 @@ def ptransform_analyzer(
   Example:
 
   >>> class MeanPerKey(beam.PTransform):
-  ...   def expand(self, pcoll: beam.PCollection[Tuple[np.ndarray, ...]]):
-  ...     # Returning a single PCollection since this analyzer has 1 output.
+  ...   def expand(self, pcoll: beam.PCollection[Tuple[np.ndarray, np.ndarray]]):
+  ...     def extract_output(key_value_pairs):
+  ...       keys, values = zip(*key_value_pairs)
+  ...       return [beam.TaggedOutput('keys', keys),
+  ...               beam.TaggedOutput('values', values)]
   ...     return (pcoll
-  ...             | 'ZipAndFlatten' >> beam.FlatMap(lambda arrays: list(zip(*arrays)))
+  ...             | 'ZipAndFlatten' >> beam.FlatMap(lambda batches: list(zip(*batches)))
   ...             | 'MeanPerKey' >> beam.CombinePerKey(beam.combiners.MeanCombineFn())
   ...             | 'ToList' >> beam.combiners.ToList()
-  ...             | 'SortedMeansByKey' >>
-  ...             beam.Map(lambda kv_list: [v for _, v in sorted(kv_list)]))
+  ...             | 'Extract' >> beam.FlatMap(extract_output).with_outputs(
+  ...                 'keys', 'values'))
   >>> def preprocessing_fn(inputs):
   ...   outputs = tft.experimental.ptransform_analyzer(
   ...       inputs=[inputs['s'], inputs['x']],
   ...       ptransform=MeanPerKey(),
-  ...       output_dtypes=[tf.float32],
-  ...       output_shapes=[[2]])
-  ...   (mean_per_key,) = outputs
-  ...   return { 'x/mean_a': inputs['x'] / mean_per_key[0] }
+  ...       output_dtypes=[tf.string, tf.float32],
+  ...       output_shapes=[[2], [2]])
+  ...   (keys, means) = outputs
+  ...   mean_a = tf.reshape(tf.gather(means, tf.where(keys == 'a')), [])
+  ...   return { 'x/mean_a': inputs['x'] / mean_a }
   >>> raw_data = [dict(x=1, s='a'), dict(x=8, s='b'), dict(x=3, s='a')]
   >>> feature_spec = dict(
   ...     x=tf.io.FixedLenFeature([], tf.float32),
@@ -111,11 +115,14 @@ def ptransform_analyzer(
     inputs: An ordered collection of input `Tensor`s.
     ptransform: A Beam PTransform that accepts a Beam PCollection where each
       element is a tuple of `ndarray`s. Each element in the tuple contains a
-      batch of values for the corresponding input tensor of the analyzer. It
-      returns a `PCollection`, or a tuple of `PCollections`, each containing a
-      single element which is an `ndarray` or a list. It may inherit from
-      `tft_beam.experimental.PTransformAnalyzer` if access to a temp base
-      directory is needed.
+      batch of values for the corresponding input tensor of the analyzer and
+      maintain their shapes and dtypes.
+      It returns a `PCollection`, or a tuple of `PCollections`, each containing
+      a single element which is an `ndarray` or a list of primitive types. The
+      contents of these output `PCollection`s must be consistent with the given
+      values of `output_dtypes` and `output_shapes`.
+      It may inherit from `tft_beam.experimental.PTransformAnalyzer` if access
+      to a temp base directory is needed.
     output_dtypes: An ordered collection of TensorFlow dtypes of the output of
       the analyzer.
     output_shapes: An ordered collection of shapes of the output of the
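The expanded docstring pins down the output contract: one single-element output PCollection per entry in output_dtypes and output_shapes, with matching dtype and shape. As a rough sketch, assuming output_dtypes=[tf.int64] and output_shapes=[[1, 2]] as in the new test above, a conforming analyzer could look like this hypothetical FixedShapeAnalyzer:

import apache_beam as beam
import numpy as np


class FixedShapeAnalyzer(beam.PTransform):
  """Hypothetical analyzer with one output matching [tf.int64] / [[1, 2]]."""

  def expand(self, pcoll):
    return (pcoll
            | beam.combiners.Count.Globally()  # collapse input to one element
            | beam.Map(lambda _: np.zeros((1, 2), np.int64)))  # dtype and shape match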
12 changes: 9 additions & 3 deletions tensorflow_transform/nodes.py
@@ -276,13 +276,19 @@ def _visit_operation(self, operation: OperationNode):
     # Expect a tuple of outputs. Since ValueNode and OperationDef are both
     # subclasses of tuple, we also explicitly disallow them, since returning
     # a single ValueNode or OperationDef is almost certainly an error.
-    if (not isinstance(output_values, tuple) or
-        isinstance(output_values, (ValueNode, OperationDef))):
+    try:
+      _ = iter(output_values)
+      output_iterable = not isinstance(output_values, str)
+    except TypeError:
+      output_iterable = False
+    if (not output_iterable or isinstance(output_values,
+                                          (ValueNode, OperationDef))):
       raise ValueError(
           'When running operation {} expected visitor to return a tuple, got '
           '{} of type {}'.format(operation.operation_def.label, output_values,
                                  type(output_values)))
-    if len(output_values) != len(outputs):
+    # DoOutputsTuple doesn't work with len().
+    if hasattr(output_values, '__len__') and len(output_values) != len(outputs):
       raise ValueError(
           'Operation {} has {} outputs but visitor returned {} values: '
           '{}'.format(operation.operation_def, len(outputs),
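The duck-typed check exists because DoOutputsTuple is iterable but is neither a tuple subclass nor sized, so the old isinstance test and an unconditional len() would both reject it. A tiny sketch with a hypothetical stand-in class, not from the commit:

class FakeDoOutputsTuple:
  """Hypothetical stand-in: iterable, not a tuple, and without __len__."""

  def __iter__(self):
    yield 'output_0'
    yield 'output_1'


output_values = FakeDoOutputsTuple()
try:
  _ = iter(output_values)
  output_iterable = not isinstance(output_values, str)
except TypeError:
  output_iterable = False

assert output_iterable                         # passes the new iterability check
assert not hasattr(output_values, '__len__')   # so the len() comparison is skipped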
2 changes: 1 addition & 1 deletion tensorflow_transform/nodes_test.py
@@ -181,7 +181,7 @@ def testTraverserOutputsNotATuple(self):
     a = nodes.apply_operation(_Constant, value='a', label='Constant[a]')
 
     mock_visitor = mock.MagicMock()
-    mock_visitor.visit.side_effect = ['not a tuple']
+    mock_visitor.visit.side_effect = [42]
 
     with self.assertRaisesRegexp(
         ValueError, r'expected visitor to return a tuple, got'):