[go: nahoru, domu]

Skip to content

Commit

Permalink
Adds doctest and documentation for using TFT cache when using TFT directly.
Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 372308410
  • Loading branch information
zoyahav authored and tf-transform-team committed May 6, 2021
1 parent 9d2f7c3 commit df21282
Showing 1 changed file with 45 additions and 11 deletions.
56 changes: 45 additions & 11 deletions tensorflow_transform/beam/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -1137,7 +1137,7 @@ def expand(self, dataset):


class AnalyzeDatasetWithCache(_AnalyzeDatasetCommon):
"""Takes a preprocessing_fn and computes the relevant statistics.
r"""Takes a preprocessing_fn and computes the relevant statistics.
WARNING: This is experimental.
Expand All @@ -1146,16 +1146,50 @@ class AnalyzeDatasetWithCache(_AnalyzeDatasetCommon):
will write out cache for statistics that it does compute whenever possible.
Example use:
```
pcoll_cache_dict = (pipeline
| tft.analyzer_cache.ReadAnalysisCacheFromFS(cache_dir, dataset_keys))
transform_fn, cache_output = (
(input_data_pcoll_dict, pcoll_cache_dict, input_metadata)
| tft_beam.AnalyzeDatasetWithCache(preprocessing_fn))
_ = (
cache_output
| tft.analyzer_cache.WriteAnalysisCacheToFS(pipeline, cache_dir))
```
>>> span_0_key = tft_beam.analyzer_cache.DatasetKey('span-0')
>>> cache_dir = tempfile.mkdtemp()
>>> output_path = os.path.join(tempfile.mkdtemp(), 'result')
>>> def preprocessing_fn(inputs):
... x = inputs['x']
... return {'x_mean': tft.mean(x, name='x') + tf.zeros_like(x)}
>>> feature_spec = {'x': tf.io.FixedLenFeature([], tf.float32)}
>>> input_metadata = dataset_metadata.DatasetMetadata(
... schema_utils.schema_from_feature_spec(feature_spec))
>>> input_data_dict_0 = {span_0_key: [{'x': x} for x in range(6)]}
>>> input_data_dict_1 = {span_0_key: [{'x': x} for x in range(6, 11)]}
>>> empty_input_cache = {}
>>> with tft_beam.Context(temp_dir=tempfile.mkdtemp()):
... with beam.Pipeline() as p:
... # Iteration #0:
... transform_fn, output_cache = (
... (input_data_dict_0, empty_input_cache, input_metadata)
... | tft_beam.AnalyzeDatasetWithCache(preprocessing_fn))
... output_cache | tft_beam.analyzer_cache.WriteAnalysisCacheToFS(
... p, cache_dir)
...
... # Iteration #1:
... input_cache = p | tft_beam.analyzer_cache.ReadAnalysisCacheFromFS(
... cache_dir, [span_0_key])
... transform_fn, output_cache = (
... (input_data_dict_1, input_cache, input_metadata)
... | tft_beam.AnalyzeDatasetWithCache(preprocessing_fn))
... output_cache | tft_beam.analyzer_cache.WriteAnalysisCacheToFS(
... p, cache_dir)
...
... # Applying the accumulated transformation:
... transform_data = p | beam.Create(input_data_dict_0[span_0_key])
... transformed_dataset = (
... ((transform_data, input_metadata), transform_fn)
... | tft_beam.TransformDataset())
... transformed_data, transformed_metadata = transformed_dataset
... (transformed_data
... | beam.combiners.Sample.FixedSizeGlobally(1)
... | beam.io.WriteToText(output_path, shard_name_template=''))
>>> with open(output_path) as f:
... f.read()
"[{'x_mean': 5.0}]\n"
"""

def _make_parent_dataset(self, dataset):
Expand Down

0 comments on commit df21282

Please sign in to comment.