[go: nahoru, domu]

Skip to content

Commit

Permalink
First change in the TFRecord vocabulary support CL chain.
Browse files Browse the repository at this point in the history
This CL moves DatasetInitializer from the core lookup_ops module to tf.data's experimental lookup_ops module. It also introduces table_from_dataset and index_table_from_dataset as helpful wrappers around DatasetInitializer, and updates TF Transform (TFT) to avoid breakage.

PiperOrigin-RevId: 368609540
  • Loading branch information
zoyahav authored and tf-transform-team committed Apr 15, 2021
1 parent 0952922 commit 8b501b5
Show file tree
Hide file tree
Showing 5 changed files with 18 additions and 10 deletions.
3 changes: 1 addition & 2 deletions tensorflow_transform/analyzers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1943,8 +1943,7 @@ def _vocabulary_analyzer_nodes(
) -> common_types.TemporaryAnalyzerOutputType:
"""Internal helper for analyzing vocab. See `vocabulary` doc string."""
if (file_format == 'tfrecord_gzip' and
(not hasattr(tf.lookup.experimental, 'DatasetInitializer') or
tf.version.VERSION < '2.4')):
not tf_utils.is_vocabulary_tfrecord_supported()):
raise ValueError(
'Vocabulary file_format "tfrecord_gzip" requires TF version >= 2.4')

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

# GOOGLE-INITIALIZATION

import tensorflow as tf
from tensorflow_transform import tf_utils
from tensorflow_transform.beam import tft_unit
from tensorflow_transform.beam import vocabulary_integration_test

Expand All @@ -29,8 +29,8 @@ class TFRecordVocabularyIntegrationTest(

def setUp(self):
if (tft_unit.is_external_environment() and
(not hasattr(tf.lookup.experimental, 'DatasetInitializer') or
tf.version.VERSION < '2.4') or tft_unit.is_tf_api_version_1()):
not tf_utils.is_vocabulary_tfrecord_supported() or
tft_unit.is_tf_api_version_1()):
raise unittest.SkipTest('Test requires async DatasetInitializer')
super(TFRecordVocabularyIntegrationTest, self).setUp()

Expand Down
3 changes: 1 addition & 2 deletions tensorflow_transform/mappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1091,8 +1091,7 @@ def apply_vocabulary(
assigned default_value.
"""
if (file_format == 'tfrecord_gzip' and
(not hasattr(tf.lookup.experimental, 'DatasetInitializer') or
tf.version.VERSION < '2.4')):
not tf_utils.is_vocabulary_tfrecord_supported()):
raise ValueError(
'Vocabulary file_format "tfrecord_gzip" requires TF version >= 2.4')
with tf.compat.v1.name_scope(name, 'apply_vocab'):
Expand Down
14 changes: 12 additions & 2 deletions tensorflow_transform/tf_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,6 +456,14 @@ def reorder_histogram(bucket_vocab, counts, boundary_size):
return tf.gather(counts, ordering)


def _tf_version_at_least(version, minimum):
  """Returns whether `version` is numerically at least `minimum`.

  Only the leading dot-separated run of integers in `version` is compared, so
  pre-release or nightly suffixes (e.g. '2.5.0-rc1', '2.6.0-dev20210401') are
  ignored. Comparing integer tuples avoids the lexicographic pitfall of plain
  string comparison, under which '2.10.0' < '2.4'.

  Args:
    version: A version string, e.g. `tf.version.VERSION`.
    minimum: A tuple of ints, e.g. `(2, 4)`.

  Returns:
    True if the parsed version is >= `minimum`. False if `version` does not
    start with a number.
  """
  import re  # pylint: disable=g-import-not-at-top
  match = re.match(r'\d+(?:\.\d+)*', version)
  if match is None:
    return False
  # A longer tuple with an equal prefix compares greater, so (2, 4, 0) >= (2, 4).
  return tuple(int(part) for part in match.group(0).split('.')) >= minimum


# TODO(b/62379925): Remove this once all supported TF versions have
# tf.data.experimental.DatasetInitializer.
def is_vocabulary_tfrecord_supported():
  """Returns whether the 'tfrecord_gzip' vocabulary file format can be used.

  This requires two things:
    * DatasetInitializer is available, either at its new location
      (tf.data.experimental) or at its old one (tf.lookup.experimental).
    * The installed TF version is at least 2.4.

  The version is compared numerically, not as a string, so that releases
  such as 2.10 are correctly treated as newer than 2.4.
  """
  has_dataset_initializer = (
      hasattr(tf.data.experimental, 'DatasetInitializer') or
      hasattr(tf.lookup.experimental, 'DatasetInitializer'))
  return has_dataset_initializer and _tf_version_at_least(
      tf.version.VERSION, (2, 4))


def apply_bucketize_op(
x: tf.Tensor,
boundaries: tf.Tensor,
Expand Down Expand Up @@ -486,9 +494,11 @@ def apply_bucketize_op(
return bucket_indices


# TODO(b/62379925): Remove this once TF 2.3 is no longer supported.
# TODO(b/62379925): Remove this once all supported TF versions have
# tf.data.experimental.DatasetInitializer.
class _DatasetInitializerCompat(
getattr(tf.lookup.experimental, 'DatasetInitializer', object)):
getattr(tf.data.experimental, 'DatasetInitializer',
getattr(tf.lookup.experimental, 'DatasetInitializer', object))):
"""Extends DatasetInitializer when possible and registers the init_op."""

def __init__(self, *args, **kwargs):
Expand Down
2 changes: 1 addition & 1 deletion tensorflow_transform/tf_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1366,7 +1366,7 @@ def test_extend_reduced_batch_with_y_counts(self):
class VocabTFUtilsTest(test_case.TransformTestCase):

def setUp(self):
if (not hasattr(tf.lookup.experimental, 'DatasetInitializer') and
if (not tf_utils.is_vocabulary_tfrecord_supported() and
test_case.is_external_environment()):
raise unittest.SkipTest('Test requires DatasetInitializer')
super(VocabTFUtilsTest, self).setUp()
Expand Down

0 comments on commit 8b501b5

Please sign in to comment.