diff --git a/tensorflow_transform/analyzers.py b/tensorflow_transform/analyzers.py
index 0ef6bf2e..1608c915 100644
--- a/tensorflow_transform/analyzers.py
+++ b/tensorflow_transform/analyzers.py
@@ -200,7 +200,7 @@ def _apply_cacheable_combiner_per_key(
 def _apply_cacheable_combiner_per_key_large(
     combiner: analyzer_nodes.Combiner, key_vocabulary_filename: str,
     *tensor_inputs: common_types.TensorType
-) -> Union[tf.Tensor, common_types.Asset]:
+) -> Union[tf.Tensor, tf.saved_model.Asset]:
   """Similar to above but saves the combined result to a file."""
   input_values_node = analyzer_nodes.get_input_tensors_value_nodes(
       tensor_inputs)
@@ -1072,7 +1072,7 @@ def _mean_and_var_per_key(
     output_dtype: Optional[tf.DType] = None,
     key_vocabulary_filename: Optional[str] = None
 ) -> Union[Tuple[tf.Tensor, tf.Tensor, tf.Tensor], tf.Tensor,
-           common_types.Asset]:
+           tf.saved_model.Asset]:
   """`mean_and_var` by group, specified by key.

   Args:
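The signatures above now name `tf.saved_model.Asset` directly; that is the type these analyzers return when the combined result is written to a vocabulary file rather than materialized as a tensor. A minimal sketch of what an `Asset` provides, assuming a TF >= 2.0 eager runtime (the temp-file setup is illustrative only):

```python
import tempfile

import tensorflow as tf

# An Asset wraps a file path so the file is tracked and copied into the
# `assets/` directory when the enclosing object is exported as a SavedModel.
with tempfile.NamedTemporaryFile(mode='w', suffix='.txt', delete=False) as f:
  f.write('hello')
  path = f.name

asset = tf.saved_model.Asset(path)
# `asset_path` is a string tensor holding the (possibly remapped) file path.
print(asset.asset_path.numpy())
```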
diff --git a/tensorflow_transform/beam/impl_test.py b/tensorflow_transform/beam/impl_test.py
index e0f0d9ae..cfda233f 100644
--- a/tensorflow_transform/beam/impl_test.py
+++ b/tensorflow_transform/beam/impl_test.py
@@ -14,7 +14,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import contextlib
 import itertools
 import math
 import os
@@ -562,9 +561,6 @@ def preprocessing_fn(inputs):
         expected_metadata)

   def testPyFuncs(self):
-    if not tft_unit.is_tf_api_version_1():
-      raise unittest.SkipTest('Test disabled when TF 2.x behavior enabled.')
-
     def my_multiply(x, y):
       return x*y
@@ -628,14 +624,11 @@ def preprocessing_fn(inputs):
     })
     self.assertAnalyzeAndTransformResults(
         input_data, input_metadata, preprocessing_fn, expected_data,
-        expected_metadata)
+        expected_metadata, force_tf_compat_v1=True)

   def testAssertsNoReturnPyFunc(self):
     # Asserts that apply_pyfunc raises an exception if the passed function does
     # not return anything.
-    if not tft_unit.is_tf_api_version_1():
-      raise unittest.SkipTest('Test disabled when TF 2.x behavior enabled.')
-
     self._SkipIfOutputRecordBatches()

     def bad_func():
@@ -684,7 +677,8 @@ def preprocessing_fn(inputs):
         preprocessing_fn,
         expected_data,
         expected_metadata,
-        desired_batch_size=batch_size)
+        desired_batch_size=batch_size,
+        force_tf_compat_v1=True)

   def testWithUnicode(self):
     def preprocessing_fn(inputs):
@@ -4714,12 +4708,6 @@ def testEmptySchema(self):
         preprocessing_fn=lambda inputs: inputs)

   # pyformat: disable
   def testLoadKerasModelInPreprocessingFn(self):
-
-    if tft_unit.is_tf_api_version_1():
-      raise unittest.SkipTest(
-          '`tft.make_and_track_object` is only supported when TF2 behavior is '
-          'enabled.')
-
     def _create_model(features, target):
       inputs = [
           tf.keras.Input(shape=(1,), name=f, dtype=tf.float32) for f in features
@@ -4797,11 +4785,8 @@ def preprocessing_fn(inputs):
         'f3': 1
     }]

-    with contextlib.ExitStack() as stack:
-      if not tft_unit.is_tf_api_version_1():
-        stack.enter_context(
-            self.assertRaisesRegex(
-                RuntimeError, 'analyzers.*appears to be non-deterministic'))
+    with self.assertRaisesRegex(  # pylint: disable=g-error-prone-assert-raises
+        RuntimeError, 'analyzers.*appears to be non-deterministic'):
       self.assertAnalyzeAndTransformResults(input_data, input_metadata,
                                             preprocessing_fn, expected_outputs)
diff --git a/tensorflow_transform/beam/tft_unit.py b/tensorflow_transform/beam/tft_unit.py
index f9c80edd..fec1cdc3 100644
--- a/tensorflow_transform/beam/tft_unit.py
+++ b/tensorflow_transform/beam/tft_unit.py
@@ -37,7 +37,6 @@
 cross_parameters = test_case.cross_parameters
 named_parameters = test_case.named_parameters
 cross_named_parameters = test_case.cross_named_parameters
-is_tf_api_version_1 = test_case.is_tf_api_version_1
 is_external_environment = test_case.is_external_environment
 skip_if_not_tf2 = test_case.skip_if_not_tf2
 SkipTest = test_case.SkipTest
diff --git a/tensorflow_transform/beam/vocabulary_tfrecord_gzip_integration_test.py b/tensorflow_transform/beam/vocabulary_tfrecord_gzip_integration_test.py
index bfce0e9a..6f44e3ca 100644
--- a/tensorflow_transform/beam/vocabulary_tfrecord_gzip_integration_test.py
+++ b/tensorflow_transform/beam/vocabulary_tfrecord_gzip_integration_test.py
@@ -39,8 +39,7 @@ def setUp(self):
     mock_is_vocabulary_tfrecord_supported.side_effect = lambda: True

     if (tft_unit.is_external_environment() and
-        not tf_utils.is_vocabulary_tfrecord_supported() or
-        tft_unit.is_tf_api_version_1()):
+        not tf_utils.is_vocabulary_tfrecord_supported()):
       raise unittest.SkipTest('Test requires async DatasetInitializer')
     super().setUp()
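The py_func tests above now run under both stacks by pinning `force_tf_compat_v1=True` instead of skipping outside TF1. The reason: `tf.compat.v1.py_func` wraps arbitrary Python code in a graph op that cannot be serialized into a SavedModel re-executable by native TF2, so `tft.apply_pyfunc` is only supported when the preprocessing_fn is traced through the compat.v1 path. A minimal sketch of the pattern these tests exercise, assuming float features `x` and `y`:

```python
import tensorflow as tf
import tensorflow_transform as tft


def _my_multiply(x, y):
  # Runs as plain Python at execution time; opaque to graph serialization.
  return x * y


def preprocessing_fn(inputs):
  return {
      'x_times_y': tft.apply_pyfunc(
          _my_multiply, tf.float32, True, 'my_multiply',
          inputs['x'], inputs['y']),
  }
```

Pipelines (and tests) using this pattern must run with `force_tf_compat_v1=True`, e.g. via `tft_beam.Context`, since the resulting transform graph can only be loaded through the compat.v1 Session path.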
 BucketBoundariesType = Union[tf.Tensor, Iterable[Union[int, float]]]
@@ -53,7 +45,7 @@
                                tf.compat.v1.ragged.RaggedTensorValue]
 TensorValueType = Union[tf.Tensor, np.ndarray, SparseTensorValueType,
                         RaggedTensorValueType]
-TemporaryAnalyzerOutputType = Union[tf.Tensor, Asset]
+TemporaryAnalyzerOutputType = Union[tf.Tensor, tf.saved_model.Asset]

 VocabularyFileFormatType = Literal['text', 'tfrecord_gzip']
diff --git a/tensorflow_transform/graph_tools_test.py b/tensorflow_transform/graph_tools_test.py
index aa3ca774..1eab906c 100644
--- a/tensorflow_transform/graph_tools_test.py
+++ b/tensorflow_transform/graph_tools_test.py
@@ -32,12 +32,6 @@
 mock = tf.compat.v1.test.mock

-def _skip_if_external_environment_or_v1_api(reason):
-  test_case.skip_if_external_environment(reason)
-  if test_case.is_tf_api_version_1():
-    raise test_case.SkipTest(reason)
-
-
 def _create_lookup_table_from_file(filename):
   initializer = tf.lookup.TextFileInitializer(
       filename,
@@ -1098,7 +1092,7 @@ class GraphToolsTestUniquePath(test_case.TransformTestCase):
       }),
       dict(
           testcase_name='_y_function_of_x_with_tf_while',
-          skip_test_check_fn=_skip_if_external_environment_or_v1_api,
+          skip_test_check_fn=test_case.skip_if_external_environment,
           create_graph_fn=_create_graph_with_tf_function_while,
           feeds=['x'],
           replaced_tensors_ready={'x': False},
diff --git a/tensorflow_transform/output_wrapper.py b/tensorflow_transform/output_wrapper.py
index 3391d5fe..bd23908c 100644
--- a/tensorflow_transform/output_wrapper.py
+++ b/tensorflow_transform/output_wrapper.py
@@ -30,7 +30,6 @@
 from tensorflow_transform.tf_metadata import schema_utils

 # pylint: disable=g-direct-tensorflow-import
-from tensorflow.python import tf2
 from tensorflow.python.framework import ops
 from tensorflow.tools.docs import doc_controls
 # pylint: enable=g-direct-tensorflow-import
@@ -427,40 +426,8 @@ def post_transform_statistics_path(self) -> str:
     return os.path.join(
         self._transform_output_dir, self.POST_TRANSFORM_FEATURE_STATS_PATH)


-# TODO(zoyahav): Use register_keras_serializable directly once we no longer support
-# TF<2.1.
-def _maybe_register_keras_serializable(package):
-  if hasattr(tf.keras.utils, 'register_keras_serializable'):
-    return tf.keras.utils.register_keras_serializable(package=package)
-  else:
-    return lambda cls: cls
-
-
-def _check_tensorflow_version():
-  """Check that we're using a compatible TF version.
-
-  Raises a warning if either Tensorflow version is less that 2.0 or TF 2.x is
-  not enabled.
-
-  If TF 2.x is enabled, but version is < TF 2.3, raises a warning to indicate
-  that resources may not be initialized.
-  """
-  major, minor, _ = tf.version.VERSION.split('.')
-  if not (int(major) >= 2 and tf2.enabled()):
-    tf.compat.v1.logging.warning(
-        'Tensorflow version (%s) found. TransformFeaturesLayer is supported '
-        'only for TF 2.x with TF 2.x behaviors enabled and may not work as '
-        'intended.', tf.version.VERSION)
-  elif int(major) == 2 and int(minor) < 3:
-    # TODO(varshaan): Log a more specific warning.
-    tf.compat.v1.logging.warning(
-        'Tensorflow version (%s) found. TransformFeaturesLayer may not work '
-        'as intended if the SavedModel contains an initialization op.',
-        tf.version.VERSION)
-
-
 # TODO(b/162055065): Possibly switch back to inherit from Layer when possible.
-@_maybe_register_keras_serializable(package='TensorFlowTransform')
+@tf.keras.utils.register_keras_serializable(package='TensorFlowTransform')
 class TransformFeaturesLayer(tf.keras.Model):
   """A Keras layer for applying a tf.Transform output to input layers."""
@@ -478,7 +445,6 @@ def __init__(self,
     self._loaded_saved_model_graph = None
     # TODO(b/160294509): Use tf.compat.v1 when we stop supporting TF 1.15.
     if ops.executing_eagerly_outside_functions():
-      _check_tensorflow_version()
       # The model must be tracked by assigning to an attribute of the Keras
       # layer. Hence, we track the attributes of _saved_model_loader here as
       # well.
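`tf.keras.utils.register_keras_serializable` has existed since TF 2.1, so with TF 1.x support gone the `hasattr` fallback removed above is dead code and the decorator can be applied directly. Registration records the class under a `package>ClassName` key so Keras can recover it by name when deserializing a config, without `custom_objects`. A short sketch with a hypothetical layer (`MyPackage` and `DoubleLayer` are illustrative names):

```python
import tensorflow as tf


@tf.keras.utils.register_keras_serializable(package='MyPackage')
class DoubleLayer(tf.keras.layers.Layer):
  """Hypothetical layer used only to illustrate registration."""

  def call(self, inputs):
    return inputs * 2.0


# Keras stores the class under 'MyPackage>DoubleLayer' and can look it up
# by that name during deserialization.
name = tf.keras.utils.get_registered_name(DoubleLayer)
assert tf.keras.utils.get_registered_object(name) is DoubleLayer
```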
diff --git a/tensorflow_transform/saved/saved_transform_io_v2.py b/tensorflow_transform/saved/saved_transform_io_v2.py
index ab1cb75b..d236abc6 100644
--- a/tensorflow_transform/saved/saved_transform_io_v2.py
+++ b/tensorflow_transform/saved/saved_transform_io_v2.py
@@ -522,7 +522,7 @@ def _variable_creator(next_creator, **kwargs):
       initializers.append(resource._initializer)  # pylint: disable=protected-access
   module.initializers = initializers
   module.assets = [
-      common_types.Asset(asset_filepath) for asset_filepath in
+      tf.saved_model.Asset(asset_filepath) for asset_filepath in
       concrete_fn.graph.get_collection(tf.compat.v1.GraphKeys.ASSET_FILEPATHS)
   ]
   return concrete_fn
diff --git a/tensorflow_transform/test_case.py b/tensorflow_transform/test_case.py
index 53212081..f67cd0c3 100644
--- a/tensorflow_transform/test_case.py
+++ b/tensorflow_transform/test_case.py
@@ -35,10 +35,6 @@
 SkipTest = unittest.SkipTest


-def is_tf_api_version_1():
-  return hasattr(tf, 'Session')
-
-
 def cross_named_parameters(*args):
   """Cross a list of lists of dicts suitable for @named_parameters.
@@ -239,8 +235,7 @@ def skip_if_external_environment(reason):

 def skip_if_not_tf2(reason):
-  major, _, _ = tf.version.VERSION.split('.')
-  if not (int(major) >= 2 and tf2.enabled()) or is_tf_api_version_1():
+  if not tf2.enabled():
     raise unittest.SkipTest(reason)
diff --git a/tensorflow_transform/tf_utils_test.py b/tensorflow_transform/tf_utils_test.py
index fa4ad019..9b26aad1 100644
--- a/tensorflow_transform/tf_utils_test.py
+++ b/tensorflow_transform/tf_utils_test.py
@@ -2292,9 +2292,6 @@ def test_split_vocabulary_entries(self):
     self.assertAllEqual(self.evaluate(keys), np.array(expected_keys))
     self.assertAllEqual(self.evaluate(values), np.array(expected_values))

-  @unittest.skipIf(
-      test_case.is_tf_api_version_1(),
-      'TFRecord vocabulary dataset tests require TF API version>1')
   def test_read_tfrecord_vocabulary_dataset(self):
     vocab_file = os.path.join(self.get_temp_dir(), 'vocab.tfrecord.gz')
     contents = [b'a', b'b', b'c']
@@ -2346,9 +2343,6 @@ def test_read_tfrecord_vocabulary_dataset(self):
           return_indicator_as_value=True,
           has_indicator=True),
   ])
-  @unittest.skipIf(
-      test_case.is_tf_api_version_1(),
-      'TFRecord vocabulary dataset tests require TF API version>1')
   def test_make_tfrecord_vocabulary_dataset(self, contents, expected,
                                             key_dtype, value_dtype,
                                             return_indicator_as_value,
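Note that `skip_if_not_tf2` keeps the `tf2.enabled()` check even though every supported TF release is now 2.x: the version parsing and the `tf.Session` probe are redundant, but TF2 behavior can still be turned off at runtime. A quick sketch of the distinction, assuming a stock TF 2.x install:

```python
import tensorflow as tf
from tensorflow.python import tf2  # pylint: disable=g-direct-tensorflow-import

print(tf.version.VERSION)  # Always a 2.x version now, e.g. '2.11.0'.
print(tf2.enabled())       # True in a fresh process...

tf.compat.v1.disable_v2_behavior()
print(tf2.enabled())       # ...but False once v2 behavior is disabled,
                           # which is why the runtime check remains.
```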