[go: nahoru, domu]

Skip to content

Commit

Permalink
Simplify interleaved scaling testcases in impl_test
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 380161525
  • Loading branch information
zoyahav authored and tf-transform-team committed Jun 18, 2021
1 parent b878c66 commit dbc2ecf
Showing 1 changed file with 24 additions and 54 deletions.
78 changes: 24 additions & 54 deletions tensorflow_transform/beam/impl_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -727,15 +727,9 @@ def testScaleUnitInterval(self, elementwise):

def preprocessing_fn(inputs):
  """Scales 'x' and 'y' with a single `tft.scale_to_0_1` call.

  The two features are stacked side by side along axis 1 so one analyzer
  handles both. The result is then unstacked back into two outputs.

  Args:
    inputs: Mapping from feature name to tensor; 'x' and 'y' are read.

  Returns:
    Dict with keys 'x_scaled' and 'y_scaled'.
  """
  stacked_input = tf.stack([inputs['x'], inputs['y']], axis=1)
  # `elementwise` is the enclosing test's parameter, forwarded unchanged.
  # Per the TFT docs, True scales each stacked column independently and
  # False shares one min/max across both.
  result = tft.scale_to_0_1(stacked_input, elementwise=elementwise)
  x_scaled, y_scaled = tf.unstack(result, axis=1)
  return {'x_scaled': x_scaled, 'y_scaled': y_scaled}

input_data = [{
Expand Down Expand Up @@ -795,16 +789,10 @@ def testScaleUnitIntervalPerKey(self):

def preprocessing_fn(inputs):
  """Scales 'x' and 'y' per key with a single `tft.scale_to_0_1_per_key` call.

  The two features are stacked along axis 1 and scaled with inputs['key']
  as the key. The result is then unstacked back into two outputs.

  Args:
    inputs: Mapping from feature name to tensor; 'x', 'y' and 'key' are read.

  Returns:
    Dict with keys 'x_scaled' and 'y_scaled'.
  """
  stacked_input = tf.stack([inputs['x'], inputs['y']], axis=1)
  # elementwise=False: per the TFT docs, both columns share statistics
  # within each key rather than being scaled column by column.
  result = tft.scale_to_0_1_per_key(
      stacked_input, inputs['key'], elementwise=False)
  x_scaled, y_scaled = tf.unstack(result, axis=1)
  return {'x_scaled': x_scaled, 'y_scaled': y_scaled}

input_data = [{
Expand Down Expand Up @@ -868,17 +856,10 @@ def preprocessing_fn(inputs):
def testScaleMinMax(self, elementwise):
def preprocessing_fn(inputs):
  """Scales 'x' and 'y' into [-1, 1] with a single `tft.scale_by_min_max` call.

  The two features are stacked along axis 1 so one analyzer handles both.
  The result is then unstacked back into two outputs.

  Args:
    inputs: Mapping from feature name to tensor; 'x' and 'y' are read.

  Returns:
    Dict with keys 'x_scaled' and 'y_scaled'.
  """
  stacked_input = tf.stack([inputs['x'], inputs['y']], axis=1)
  # `elementwise` is the enclosing test's parameter, forwarded unchanged.
  result = tft.scale_by_min_max(
      stacked_input, output_min=-1, output_max=1, elementwise=elementwise)
  x_scaled, y_scaled = tf.unstack(result, axis=1)
  return {'x_scaled': x_scaled, 'y_scaled': y_scaled}

input_data = [{
Expand Down Expand Up @@ -945,19 +926,15 @@ def preprocessing_fn(inputs):
def testScaleMinMaxPerKey(self, key_vocabulary_filename):
def preprocessing_fn(inputs):
  """Scales 'x' and 'y' into [-1, 1] per key with one analyzer call.

  The two features are stacked along axis 1 and passed to
  `tft.scale_by_min_max_per_key` keyed by inputs['key']. The result is then
  unstacked back into two outputs.

  Args:
    inputs: Mapping from feature name to tensor; 'x', 'y' and 'key' are read.

  Returns:
    Dict with keys 'x_scaled' and 'y_scaled'.
  """
  stacked_input = tf.stack([inputs['x'], inputs['y']], axis=1)
  # `key_vocabulary_filename` is the enclosing test's parameter, forwarded
  # unchanged.
  result = tft.scale_by_min_max_per_key(
      stacked_input,
      inputs['key'],
      output_min=-1,
      output_max=1,
      elementwise=False,
      key_vocabulary_filename=key_vocabulary_filename)
  x_scaled, y_scaled = tf.unstack(result, axis=1)
  return {'x_scaled': x_scaled, 'y_scaled': y_scaled}

input_data = [{
Expand Down Expand Up @@ -1226,17 +1203,10 @@ def testScaleMinMaxConstantElementwise(self):

def preprocessing_fn(inputs):
  """Scales 'x' and 'y' into [0, 10] column by column with one analyzer call.

  The two features are stacked along axis 1 and passed to
  `tft.scale_by_min_max` with elementwise=True. The result is then unstacked
  back into two outputs.

  Args:
    inputs: Mapping from feature name to tensor; 'x' and 'y' are read.

  Returns:
    Dict with keys 'x_scaled' and 'y_scaled'.
  """
  stacked_input = tf.stack([inputs['x'], inputs['y']], axis=1)
  # elementwise=True: per the TFT docs, each stacked column is scaled with
  # its own min/max.
  result = tft.scale_by_min_max(
      stacked_input, output_min=0, output_max=10, elementwise=True)
  x_scaled, y_scaled = tf.unstack(result, axis=1)
  return {'x_scaled': x_scaled, 'y_scaled': y_scaled}

input_data = [{
Expand Down

0 comments on commit dbc2ecf

Please sign in to comment.