[go: up, home]

Skip to content

Commit

Permalink
Address and remove remaining TF 1.15 support related TODOs.
Browse files (browse the repository at this point in the history)
PiperOrigin-RevId: 493487255
  • Loading branch information
zoyahav authored and tfx-copybara committed Dec 7, 2022
1 parent 3147adf commit 1fa2ed8
Show file tree
Hide file tree
Showing 8 changed files with 14 additions and 47 deletions.
9 changes: 2 additions & 7 deletions tensorflow_transform/experimental/mappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,14 +325,9 @@ def _to_term_document_one_hot(
# frequency only cares the existence of a term in a document, not the
# occurrence frequency within that document.
# Hashing (<batch_index>, <vocab_index>) pairs for dedup.
# TODO(b/160294509): Switch to tf.raw_ops.UniqueV2 to avoid hashing for tf1
# tf.raw_ops.UniqueV2 always results in rank-1 tensor placeholder in tf1, even
# when the input is 2D. This causes issues when applying UniqueV2 to
# (<batch_index>, <vocab_index>) 2D tensor along axis=0 and using the uniqued
# 2D tensor as indices for sparse tensors in tf1. See b/160294509#comment8.
multiplier = vocab_size + 1
unique_flatten_indices, _ = tf.unique(batch_indices * multiplier +
vocab_indices)
unique_flatten_indices, _ = tf.raw_ops.UniqueV2(
x=batch_indices * multiplier + vocab_indices, axis=[0])
unique_batch_indices = tf.cast(
tf.math.divide(unique_flatten_indices, multiplier), dtype=tf.int64)
unique_vocab_indices = tf.math.mod(unique_flatten_indices, multiplier)
Expand Down
10 changes: 1 addition & 9 deletions tensorflow_transform/graph_tools_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1042,7 +1042,6 @@ class GraphToolsTestUniquePath(test_case.TransformTestCase):
}),
dict(
testcase_name='_y_function_of_x_with_raw_ops_while',
skip_test_check_fn=test_case.skip_if_external_environment,
create_graph_fn=_create_graph_with_y_function_of_x_with_tf_while,
feeds=['x'],
replaced_tensors_ready={'x': False},
Expand Down Expand Up @@ -1092,7 +1091,6 @@ class GraphToolsTestUniquePath(test_case.TransformTestCase):
}),
dict(
testcase_name='_y_function_of_x_with_tf_while',
skip_test_check_fn=test_case.skip_if_external_environment,
create_graph_fn=_create_graph_with_tf_function_while,
feeds=['x'],
replaced_tensors_ready={'x': False},
Expand Down Expand Up @@ -1226,13 +1224,7 @@ def testGetUniquePath(self,
create_graph_fn,
feeds,
replaced_tensors_ready,
expected_calls_dict,
skip_test_check_fn=None):

# TODO(b/160294509): Remove this condition when TFT no longer supports TF<2.
if skip_test_check_fn:
skip_test_check_fn('This test is not currently supported.')

expected_calls_dict):
with tf.compat.v1.Graph().as_default() as graph:
tensors = create_graph_fn()
replaced_tensors_ready = [(tensors[name], ready)
Expand Down
6 changes: 2 additions & 4 deletions tensorflow_transform/output_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,8 +443,7 @@ def __init__(self,
self._exported_as_v1 = exported_as_v1
self._saved_model_loader_value = None
self._loaded_saved_model_graph = None
# TODO(b/160294509): Use tf.compat.v1 when we stop supporting TF 1.15.
if ops.executing_eagerly_outside_functions():
if tf.compat.v1.executing_eagerly_outside_functions():
# The model must be tracked by assigning to an attribute of the Keras
# layer. Hence, we track the attributes of _saved_model_loader here as
# well.
Expand All @@ -470,8 +469,7 @@ def _saved_model_loader(self) -> saved_transform_io_v2.SavedModelLoader:
self._tft_output.transform_savedmodel_dir)
self._loaded_saved_model_graph = ops.get_default_graph()

# TODO(b/160294509): Use tf.compat.v1 when we stop supporting TF 1.15.
if ops.executing_eagerly_outside_functions():
if tf.compat.v1.executing_eagerly_outside_functions():
return self._saved_model_loader_value
else:
assert not self._exported_as_v1
Expand Down
4 changes: 1 addition & 3 deletions tensorflow_transform/saved/saved_transform_io_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,9 +128,7 @@ def __init__(self, saved_model_dir: str):
defined in `../constants.py` ('transform' and 'transform_signature',
respectively).
"""
# TODO(b/160294509): Stop using tf.compat.v2 when TF1.15 support is
# dropped.
imported = tf.compat.v2.saved_model.load(saved_model_dir)
imported = tf.saved_model.load(saved_model_dir)
load_v2_in_compat = constants.TRANSFORM_SIGNATURE in imported.signatures
if load_v2_in_compat:
restored_function = imported.signatures[constants.TRANSFORM_SIGNATURE]
Expand Down
10 changes: 4 additions & 6 deletions tensorflow_transform/test_case.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,9 +112,8 @@ def _make_placeholder(tensor_spec):
return tf.compat.v1.sparse_placeholder(
shape=tensor_spec.shape, dtype=tensor_spec.dtype)
if isinstance(tensor_spec, tf.RaggedTensorSpec):
# TODO(b/160294509): Switch to public APIs once TF 1 support is dropped.
return tf.compat.v1.ragged.placeholder(
tensor_spec._dtype, tensor_spec._ragged_rank, value_shape=()) # pylint: disable=protected-access
tensor_spec.dtype, tensor_spec.ragged_rank, value_shape=())
else:
return tf.compat.v1.placeholder(
shape=tensor_spec.shape, dtype=tensor_spec.dtype)
Expand Down Expand Up @@ -164,8 +163,7 @@ def _wrap_as_constant(value, tensor_spec):
values=tf.constant(value.values, dtype=tensor_spec.dtype),
dense_shape=tf.constant(value.dense_shape, dtype=tf.int64))
elif isinstance(tensor_spec, tf.RaggedTensorSpec):
# TODO(b/160294509): Switch to public APIs once TF 1 support is dropped.
result = _ragged_value_as_constant(value, tensor_spec._dtype) # pylint: disable=protected-access
result = _ragged_value_as_constant(value, tensor_spec.dtype)
else:
result = tf.constant(value, dtype=tensor_spec.dtype)
result.shape.assert_is_compatible_with(tensor_spec.shape)
Expand Down Expand Up @@ -299,10 +297,10 @@ def _assertValuesCloseOrEqual(self, a_value, b_value, msg=None):
if (isinstance(a_value, (bytes, str)) or isinstance(a_value, list) and
a_value and isinstance(a_value[0], (bytes, str)) or
isinstance(a_value, np.ndarray) and a_value.dtype == object):
self.assertAllEqual(a_value, b_value)
self.assertAllEqual(a_value, b_value, msg=msg)
else:
# TODO(varshaan): Change atol only for tests for which 1e-6 is too strict.
self.assertAllClose(a_value, b_value, atol=1e-5)
self.assertAllClose(a_value, b_value, atol=1e-5, msg=msg)

def AssertVocabularyContents(self, vocab_file_path, file_contents):
if vocab_file_path.endswith('.tfrecord.gz'):
Expand Down
4 changes: 1 addition & 3 deletions tensorflow_transform/tf2_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,20 +21,18 @@
from tensorflow_transform import common_types
# pylint: disable=g-direct-tensorflow-import
from tensorflow.python import tf2
from tensorflow.python.framework import ops
from tensorflow.python.framework.func_graph import FuncGraph
# pylint: enable=g-direct-tensorflow-import


def use_tf_compat_v1(force_tf_compat_v1: bool) -> bool:
"""Evaluate from environment variables if TF should be used in compat.v1 mode."""
major, _, _ = tf.version.VERSION.split('.')
# TODO(b/160294509): Use tf.compat.v1 when we stop supporting TF 1.15.
# If tf.enable_v2_behavior has been called, but eager execution has been
# disabled, force compat v1 behavior. Hence, check
# `executing_eagerly_outside_functions` as well.
return (force_tf_compat_v1 or int(major) < 2 or not tf2.enabled() or
not ops.executing_eagerly_outside_functions())
not tf.compat.v1.executing_eagerly_outside_functions())


def strip_and_get_tensors_and_control_dependencies(
Expand Down
15 changes: 3 additions & 12 deletions tensorflow_transform/tf_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,13 +323,11 @@ def hashable_tensor_or_op(tensor_or_op):
if isinstance(tensor_or_op, tf.Tensor):
return tensor_or_op.experimental_ref()
if isinstance(tensor_or_op, composite_tensor.CompositeTensor):
# TODO(b/160294509): Use tf.type_spec_from_value - only available in TF 2.
return _CompositeTensorRef(
type_spec=tensor_or_op._type_spec, # pylint: disable=protected-access
type_spec=tf.type_spec_from_value(tensor_or_op),
list_of_refs=tuple(
hashable_tensor_or_op(component) for component in tf.nest.flatten(
tensor_or_op, expand_composites=True)
))
tensor_or_op, expand_composites=True)))
return tensor_or_op


Expand Down Expand Up @@ -432,10 +430,7 @@ def reduce_batch_count(x: common_types.TensorType,
dense_shape=x.dense_shape)
# TODO(b/178189903): Remove this once we no longer lose static shape
# information.
# TODO(b/160294509): Remove the hasattr contition once TFT no longer
# supports TF<2.
if hasattr(x, '_dense_shape_default'):
ones_like._dense_shape_default = x._dense_shape_default # pylint: disable=protected-access
ones_like._dense_shape_default = x._dense_shape_default # pylint: disable=protected-access
return _sparse_reduce_batch_keep_shape(tf.sparse.reduce_sum, ones_like)
elif isinstance(x, tf.RaggedTensor):
if reduce_instance_dims:
Expand Down Expand Up @@ -697,10 +692,6 @@ def _split_vocabulary_entries(batched_vocab_lines):
if isinstance(split, tf.RaggedTensor):
split_tensor = split.to_tensor()
return split_tensor[:, 1], split_tensor[:, 0]
# TODO(b/160294509): Remove this condition when TFT no longer supports TF<2.
elif isinstance(split, tf.SparseTensor):
split_tensor = tf.sparse.to_dense(split)
return split_tensor[:, 1], split_tensor[:, 0]
else:
return split[1], split[0]

Expand Down
3 changes: 0 additions & 3 deletions tensorflow_transform/tf_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2404,7 +2404,4 @@ def foo(input_tensor):


if __name__ == '__main__':
# TODO(b/160294509): Remove this once this is enabled by default in all
# supported TF versions.
tf.compat.v1.enable_v2_tensorshape()
test_case.main()

0 comments on commit 1fa2ed8

Please sign in to comment.