[go: nahoru, domu]

Skip to content

Commit

Permalink
Continue to remove TF1 related logic from TFT
Browse files: browse the repository at this point in the history
PiperOrigin-RevId: 493101025
  • Loading branch information
zoyahav authored and tfx-copybara committed Dec 5, 2022
1 parent 7ca243a commit 1c82cf3
Show file tree
Hide file tree
Showing 10 changed files with 13 additions and 89 deletions.
4 changes: 2 additions & 2 deletions tensorflow_transform/analyzers.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ def _apply_cacheable_combiner_per_key(
def _apply_cacheable_combiner_per_key_large(
combiner: analyzer_nodes.Combiner, key_vocabulary_filename: str,
*tensor_inputs: common_types.TensorType
) -> Union[tf.Tensor, common_types.Asset]:
) -> Union[tf.Tensor, tf.saved_model.Asset]:
"""Similar to above but saves the combined result to a file."""
input_values_node = analyzer_nodes.get_input_tensors_value_nodes(
tensor_inputs)
Expand Down Expand Up @@ -1072,7 +1072,7 @@ def _mean_and_var_per_key(
output_dtype: Optional[tf.DType] = None,
key_vocabulary_filename: Optional[str] = None
) -> Union[Tuple[tf.Tensor, tf.Tensor, tf.Tensor], tf.Tensor,
common_types.Asset]:
tf.saved_model.Asset]:
"""`mean_and_var` by group, specified by key.
Args:
Expand Down
25 changes: 5 additions & 20 deletions tensorflow_transform/beam/impl_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import contextlib
import itertools
import math
import os
Expand Down Expand Up @@ -562,9 +561,6 @@ def preprocessing_fn(inputs):
expected_metadata)

def testPyFuncs(self):
if not tft_unit.is_tf_api_version_1():
raise unittest.SkipTest('Test disabled when TF 2.x behavior enabled.')

def my_multiply(x, y):
return x*y

Expand Down Expand Up @@ -628,14 +624,11 @@ def preprocessing_fn(inputs):
})
self.assertAnalyzeAndTransformResults(
input_data, input_metadata, preprocessing_fn, expected_data,
expected_metadata)
expected_metadata, force_tf_compat_v1=True)

def testAssertsNoReturnPyFunc(self):
# Asserts that apply_pyfunc raises an exception if the passed function does
# not return anything.
if not tft_unit.is_tf_api_version_1():
raise unittest.SkipTest('Test disabled when TF 2.x behavior enabled.')

self._SkipIfOutputRecordBatches()

def bad_func():
Expand Down Expand Up @@ -684,7 +677,8 @@ def preprocessing_fn(inputs):
preprocessing_fn,
expected_data,
expected_metadata,
desired_batch_size=batch_size)
desired_batch_size=batch_size,
force_tf_compat_v1=True)

def testWithUnicode(self):
def preprocessing_fn(inputs):
Expand Down Expand Up @@ -4714,12 +4708,6 @@ def testEmptySchema(self):
preprocessing_fn=lambda inputs: inputs) # pyformat: disable

def testLoadKerasModelInPreprocessingFn(self):

if tft_unit.is_tf_api_version_1():
raise unittest.SkipTest(
'`tft.make_and_track_object` is only supported when TF2 behavior is '
'enabled.')

def _create_model(features, target):
inputs = [
tf.keras.Input(shape=(1,), name=f, dtype=tf.float32) for f in features
Expand Down Expand Up @@ -4797,11 +4785,8 @@ def preprocessing_fn(inputs):
'f3': 1
}]

with contextlib.ExitStack() as stack:
if not tft_unit.is_tf_api_version_1():
stack.enter_context(
self.assertRaisesRegex(
RuntimeError, 'analyzers.*appears to be non-deterministic'))
with self.assertRaisesRegex( # pylint: disable=g-error-prone-assert-raises
RuntimeError, 'analyzers.*appears to be non-deterministic'):
self.assertAnalyzeAndTransformResults(input_data, input_metadata,
preprocessing_fn, expected_outputs)

Expand Down
1 change: 0 additions & 1 deletion tensorflow_transform/beam/tft_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@
cross_parameters = test_case.cross_parameters
named_parameters = test_case.named_parameters
cross_named_parameters = test_case.cross_named_parameters
is_tf_api_version_1 = test_case.is_tf_api_version_1
is_external_environment = test_case.is_external_environment
skip_if_not_tf2 = test_case.skip_if_not_tf2
SkipTest = test_case.SkipTest
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,7 @@ def setUp(self):
mock_is_vocabulary_tfrecord_supported.side_effect = lambda: True

if (tft_unit.is_external_environment() and
not tf_utils.is_vocabulary_tfrecord_supported() or
tft_unit.is_tf_api_version_1()):
not tf_utils.is_vocabulary_tfrecord_supported()):
raise unittest.SkipTest('Test requires async DatasetInitializer')
super().setUp()

Expand Down
10 changes: 1 addition & 9 deletions tensorflow_transform/common_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,6 @@

from tensorflow_metadata.proto.v0 import schema_pb2

# TODO(b/160294509): Stop using tracking.TrackableAsset when TF1.15 support is
# dropped.
if hasattr(tf.saved_model, 'Asset'):
Asset = tf.saved_model.Asset # pylint: disable=invalid-name
else:
from tensorflow.python.training.tracking import tracking # pylint: disable=g-direct-tensorflow-import, g-import-not-at-top
Asset = tracking.TrackableAsset # pylint: disable=invalid-name

# TODO(b/185719271): Define BucketBoundariesType at module level of mappers.py.
BucketBoundariesType = Union[tf.Tensor, Iterable[Union[int, float]]]

Expand All @@ -53,7 +45,7 @@
tf.compat.v1.ragged.RaggedTensorValue]
TensorValueType = Union[tf.Tensor, np.ndarray, SparseTensorValueType,
RaggedTensorValueType]
TemporaryAnalyzerOutputType = Union[tf.Tensor, Asset]
TemporaryAnalyzerOutputType = Union[tf.Tensor, tf.saved_model.Asset]
VocabularyFileFormatType = Literal['text', 'tfrecord_gzip']


Expand Down
8 changes: 1 addition & 7 deletions tensorflow_transform/graph_tools_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,6 @@
mock = tf.compat.v1.test.mock


def _skip_if_external_environment_or_v1_api(reason):
test_case.skip_if_external_environment(reason)
if test_case.is_tf_api_version_1():
raise test_case.SkipTest(reason)


def _create_lookup_table_from_file(filename):
initializer = tf.lookup.TextFileInitializer(
filename,
Expand Down Expand Up @@ -1098,7 +1092,7 @@ class GraphToolsTestUniquePath(test_case.TransformTestCase):
}),
dict(
testcase_name='_y_function_of_x_with_tf_while',
skip_test_check_fn=_skip_if_external_environment_or_v1_api,
skip_test_check_fn=test_case.skip_if_external_environment,
create_graph_fn=_create_graph_with_tf_function_while,
feeds=['x'],
replaced_tensors_ready={'x': False},
Expand Down
36 changes: 1 addition & 35 deletions tensorflow_transform/output_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
from tensorflow_transform.tf_metadata import schema_utils

# pylint: disable=g-direct-tensorflow-import
from tensorflow.python import tf2
from tensorflow.python.framework import ops
from tensorflow.tools.docs import doc_controls
# pylint: enable=g-direct-tensorflow-import
Expand Down Expand Up @@ -427,40 +426,8 @@ def post_transform_statistics_path(self) -> str:
self._transform_output_dir, self.POST_TRANSFORM_FEATURE_STATS_PATH)


# TODO(zoyahav): Use register_keras_serializable directly once we no longer support
# TF<2.1.
def _maybe_register_keras_serializable(package):
if hasattr(tf.keras.utils, 'register_keras_serializable'):
return tf.keras.utils.register_keras_serializable(package=package)
else:
return lambda cls: cls


def _check_tensorflow_version():
"""Check that we're using a compatible TF version.
Raises a warning if either Tensorflow version is less than 2.0 or TF 2.x is
not enabled.
If TF 2.x is enabled, but version is < TF 2.3, raises a warning to indicate
that resources may not be initialized.
"""
major, minor, _ = tf.version.VERSION.split('.')
if not (int(major) >= 2 and tf2.enabled()):
tf.compat.v1.logging.warning(
'Tensorflow version (%s) found. TransformFeaturesLayer is supported '
'only for TF 2.x with TF 2.x behaviors enabled and may not work as '
'intended.', tf.version.VERSION)
elif int(major) == 2 and int(minor) < 3:
# TODO(varshaan): Log a more specific warning.
tf.compat.v1.logging.warning(
'Tensorflow version (%s) found. TransformFeaturesLayer may not work '
'as intended if the SavedModel contains an initialization op.',
tf.version.VERSION)


# TODO(b/162055065): Possibly switch back to inherit from Layer when possible.
@_maybe_register_keras_serializable(package='TensorFlowTransform')
@tf.keras.utils.register_keras_serializable(package='TensorFlowTransform')
class TransformFeaturesLayer(tf.keras.Model):
"""A Keras layer for applying a tf.Transform output to input layers."""

Expand All @@ -478,7 +445,6 @@ def __init__(self,
self._loaded_saved_model_graph = None
# TODO(b/160294509): Use tf.compat.v1 when we stop supporting TF 1.15.
if ops.executing_eagerly_outside_functions():
_check_tensorflow_version()
# The model must be tracked by assigning to an attribute of the Keras
# layer. Hence, we track the attributes of _saved_model_loader here as
# well.
Expand Down
2 changes: 1 addition & 1 deletion tensorflow_transform/saved/saved_transform_io_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,7 +522,7 @@ def _variable_creator(next_creator, **kwargs):
initializers.append(resource._initializer) # pylint: disable=protected-access
module.initializers = initializers
module.assets = [
common_types.Asset(asset_filepath) for asset_filepath in
tf.saved_model.Asset(asset_filepath) for asset_filepath in
concrete_fn.graph.get_collection(tf.compat.v1.GraphKeys.ASSET_FILEPATHS)
]
return concrete_fn
Expand Down
7 changes: 1 addition & 6 deletions tensorflow_transform/test_case.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,6 @@
SkipTest = unittest.SkipTest


def is_tf_api_version_1():
return hasattr(tf, 'Session')


def cross_named_parameters(*args):
"""Cross a list of lists of dicts suitable for @named_parameters.
Expand Down Expand Up @@ -239,8 +235,7 @@ def skip_if_external_environment(reason):


def skip_if_not_tf2(reason):
major, _, _ = tf.version.VERSION.split('.')
if not (int(major) >= 2 and tf2.enabled()) or is_tf_api_version_1():
if not tf2.enabled():
raise unittest.SkipTest(reason)


Expand Down
6 changes: 0 additions & 6 deletions tensorflow_transform/tf_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2292,9 +2292,6 @@ def test_split_vocabulary_entries(self):
self.assertAllEqual(self.evaluate(keys), np.array(expected_keys))
self.assertAllEqual(self.evaluate(values), np.array(expected_values))

@unittest.skipIf(
test_case.is_tf_api_version_1(),
'TFRecord vocabulary dataset tests require TF API version>1')
def test_read_tfrecord_vocabulary_dataset(self):
vocab_file = os.path.join(self.get_temp_dir(), 'vocab.tfrecord.gz')
contents = [b'a', b'b', b'c']
Expand Down Expand Up @@ -2346,9 +2343,6 @@ def test_read_tfrecord_vocabulary_dataset(self):
return_indicator_as_value=True,
has_indicator=True),
])
@unittest.skipIf(
test_case.is_tf_api_version_1(),
'TFRecord vocabulary dataset tests require TF API version>1')
def test_make_tfrecord_vocabulary_dataset(self, contents, expected, key_dtype,
value_dtype,
return_indicator_as_value,
Expand Down

0 comments on commit 1c82cf3

Please sign in to comment.