[go: nahoru, domu]

Skip to content

Commit

Permalink
Showcase tft_unit in simple example, and validate the expected transformed data.
Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 422775631
  • Loading branch information
zoyahav authored and tfx-copybara committed Jan 19, 2022
1 parent cd42be7 commit ff04e67
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 30 deletions.
68 changes: 39 additions & 29 deletions examples/simple_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,41 +22,51 @@
from tensorflow_transform.tf_metadata import dataset_metadata
from tensorflow_transform.tf_metadata import schema_utils

# Schema of the raw input rows: a scalar string feature 's' and two scalar
# float32 features 'y' and 'x' (FixedLenFeature with shape []).
_RAW_DATA_METADATA = dataset_metadata.DatasetMetadata(
    schema_utils.schema_from_feature_spec(
        dict(
            s=tf.io.FixedLenFeature([], tf.string),
            y=tf.io.FixedLenFeature([], tf.float32),
            x=tf.io.FixedLenFeature([], tf.float32),
        )))

# Raw input rows for the example, one dict per row keyed like the schema in
# _RAW_DATA_METADATA. 'hello' appears twice and 'world' once.
_RAW_DATA = [
    {'x': 1, 'y': 1, 's': 'hello'},
    {'x': 2, 'y': 2, 's': 'world'},
    {'x': 3, 'y': 3, 's': 'hello'},
]

def main():
def preprocessing_fn(inputs):
"""Preprocess input columns into transformed columns."""
x = inputs['x']
y = inputs['y']
s = inputs['s']
x_centered = x - tft.mean(x)
y_normalized = tft.scale_to_0_1(y)
s_integerized = tft.compute_and_apply_vocabulary(s)
x_centered_times_y_normalized = (x_centered * y_normalized)
return {
'x_centered': x_centered,
'y_normalized': y_normalized,
'x_centered_times_y_normalized': x_centered_times_y_normalized,
's_integerized': s_integerized
}

raw_data = [
{'x': 1, 'y': 1, 's': 'hello'},
{'x': 2, 'y': 2, 's': 'world'},
{'x': 3, 'y': 3, 's': 'hello'}
]
def _preprocessing_fn(inputs):
  """Preprocess input columns into transformed columns.

  Args:
    inputs: mapping from the raw feature names 'x', 'y' and 's' to their
      tensors.

  Returns:
    A dict with 'x_centered' (x minus tft.mean(x)), 'y_normalized'
    (tft.scale_to_0_1(y)), their product 'x_centered_times_y_normalized', and
    's_integerized' (tft.compute_and_apply_vocabulary(s)).
  """
  raw_x, raw_y, raw_s = inputs['x'], inputs['y'], inputs['s']
  # The tft calls are kept in the original order (mean, scale, vocabulary,
  # then the product) so graph construction order does not change.
  centered_x = raw_x - tft.mean(raw_x)
  scaled_y = tft.scale_to_0_1(raw_y)
  vocab_ids = tft.compute_and_apply_vocabulary(raw_s)
  product = centered_x * scaled_y
  return dict(
      x_centered=centered_x,
      y_normalized=scaled_y,
      x_centered_times_y_normalized=product,
      s_integerized=vocab_ids,
  )

raw_data_metadata = dataset_metadata.DatasetMetadata(
schema_utils.schema_from_feature_spec({
's': tf.io.FixedLenFeature([], tf.string),
'y': tf.io.FixedLenFeature([], tf.float32),
'x': tf.io.FixedLenFeature([], tf.float32),
}))

def main():

with tft_beam.Context(temp_dir=tempfile.mkdtemp()):
transformed_dataset, transform_fn = ( # pylint: disable=unused-variable
(raw_data, raw_data_metadata) | tft_beam.AnalyzeAndTransformDataset(
preprocessing_fn))
(_RAW_DATA, _RAW_DATA_METADATA)
| tft_beam.AnalyzeAndTransformDataset(_preprocessing_fn))

transformed_data, transformed_metadata = transformed_dataset # pylint: disable=unused-variable

Expand Down
34 changes: 33 additions & 1 deletion examples/simple_example_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,42 @@
"""Tests for simple_example."""

import tensorflow as tf
from tensorflow_transform.beam import tft_unit
import simple_example


class SimpleExampleTest(tf.test.TestCase):
# Expected output of simple_example._preprocessing_fn on simple_example._RAW_DATA.
# Rows are listed in reverse input order (x = 3, 2, 1).
# NOTE(review): this presumes the tft_unit comparison is order-insensitive --
# confirm. The last row's product is -0.0 because it is -1.0 * 0.0.
_EXPECTED_TRANSFORMED_OUTPUT = [
    dict(x_centered=1.0, y_normalized=1.0,
         x_centered_times_y_normalized=1.0, s_integerized=0),
    dict(x_centered=0.0, y_normalized=0.5,
         x_centered_times_y_normalized=0.0, s_integerized=1),
    dict(x_centered=-1.0, y_normalized=0.0,
         x_centered_times_y_normalized=-0.0, s_integerized=0),
]


class SimpleExampleTest(tft_unit.TransformTestCase):
  """Checks simple_example's preprocessing_fn against known transformed rows."""

  def test_preprocessing_fn(self):
    """Analyzes and transforms the example data, then compares the output."""
    raw_data = simple_example._RAW_DATA
    raw_metadata = simple_example._RAW_DATA_METADATA
    preprocessing_fn = simple_example._preprocessing_fn
    self.assertAnalyzeAndTransformResults(
        raw_data, raw_metadata, preprocessing_fn, _EXPECTED_TRANSFORMED_OUTPUT)


class SimpleMainTest(tf.test.TestCase):
  """Smoke test for simple_example's entry point."""

  def testMainDoesNotCrash(self):
    """Runs simple_example.main(); there are no assertions, so it passes if no exception is raised."""
    simple_example.main()
Expand Down

0 comments on commit ff04e67

Please sign in to comment.