[go: nahoru, domu]

Skip to content

Commit

Permalink
Port keras lookup layers to new adapt, use a StaticHashTable during call
Browse files Browse the repository at this point in the history
This is a significant refactor of the internals of the layer, which will break
SavedModel compatibility with previous versions. The usage of the layer will
remain the same, so a compatible layer can be generated from the same
training script.

This refactor has the following advantages:
 - Static tables can be distributed to end workers in a multi-worker setting,
   allowing more efficient distributed training.
 - File based vocabularies will only be scanned once.
 - Static vocabularies passed on init will be consistently clonable with the
   layer config, rather than clonable only in the file based case.

We now consistently enforce that a vocabulary must be set when calling the
layer on anything besides a keras.Input.

PiperOrigin-RevId: 380645230
Change-Id: I82f5e8dc77b48df409044ce88aca3582ac3658d1
  • Loading branch information
mattdangerw authored and Copybara-Service committed Jun 21, 2021
1 parent f29b375 commit f14f0a5
Showing 1 changed file with 10 additions and 10 deletions.
20 changes: 10 additions & 10 deletions tf_agents/networks/encoding_network_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,31 +301,31 @@ def testNumericKerasInput(self):
self.assertEqual(input_shape, output.shape)

def testKerasIntegerLookup(self):
self.skipTest('Re-enable this test after cl/362119497 on tf-nightly.')
if not tf.executing_eagerly():
self.skipTest('This test is TF2 only.')

key = 'feature_key'
vocab_list = [2, 3, 4]

keras_input = tf.keras.Input(shape=(1,), name=key, dtype=tf.dtypes.int32)
id_input = keras_preprocessing.IntegerLookup(vocabulary=vocab_list)
encoded_input = keras_preprocessing.CategoryEncoding(
max_tokens=len(vocab_list))
id_input = keras_preprocessing.IntegerLookup(
vocabulary=vocab_list, num_oov_indices=0, output_mode='multi_hot')

state_input = [3, 2, 2, 4, 3]
state = {key: tf.expand_dims(state_input, -1)}
input_spec = {key: tensor_spec.TensorSpec([1], tf.int32)}

network = encoding_network.EncodingNetwork(
input_spec,
preprocessing_combiner=tf.keras.Sequential(
[keras_input, id_input, encoded_input]))
preprocessing_combiner=tf.keras.Sequential([keras_input, id_input]))

output, _ = network(state)
expected_shape = (len(state_input), len(vocab_list))
self.assertEqual(expected_shape, output.shape)

def testCombinedKerasPreprocessingLayers(self):
self.skipTest('Re-enable this test after cl/362119497 on tf-nightly.')
if not tf.executing_eagerly():
self.skipTest('This test is TF2 only.')

Expand All @@ -339,10 +339,9 @@ def testCombinedKerasPreprocessingLayers(self):
vocab_list = [2, 3, 4]
inputs[indicator_key] = tf.keras.Input(
shape=(1,), dtype=tf.dtypes.int32, name=indicator_key)
id_input = keras_preprocessing.IntegerLookup(
vocabulary=vocab_list)(inputs[indicator_key])
features[indicator_key] = keras_preprocessing.CategoryEncoding(
max_tokens=len(vocab_list))(id_input)
features[indicator_key] = keras_preprocessing.IntegerLookup(
vocabulary=vocab_list, num_oov_indices=0, output_mode='multi_hot')(
inputs[indicator_key])
state_input = [3, 2, 2, 4, 3]
tensors[indicator_key] = tf.expand_dims(state_input, -1)
specs[indicator_key] = tensor_spec.TensorSpec([1], tf.int32)
Expand All @@ -354,7 +353,8 @@ def testCombinedKerasPreprocessingLayers(self):
inputs[embedding_key] = tf.keras.Input(
shape=(1,), dtype=tf.dtypes.int32, name=embedding_key)
id_input = keras_preprocessing.IntegerLookup(
vocabulary=vocab_list)(inputs[embedding_key])
vocabulary=vocab_list, num_oov_indices=0)(
inputs[embedding_key])
embedding_input = tf.keras.layers.Embedding(
input_dim=len(vocab_list),
output_dim=embedding_dim)(id_input)
Expand Down

0 comments on commit f14f0a5

Please sign in to comment.