[go: nahoru, domu]

Skip to content

Commit

Permalink
Avoid erroneous warnings regarding table initialization inside tf.ini…
Browse files Browse the repository at this point in the history
…t_scope when calling InitializableGraphAnalyzer from private methods (such as get_dependent_inputs).

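`get_dependent_inputs` is mostly called with a partial graph in which assets are of type `tf.Tensor`, even when running in a native TF2 environment, so the table-initialization warning it would otherwise trigger is erroneous.
This change also updates the expected type of the `deferred_vocab_filename_tensor` parameter of `tft.apply_vocabulary` from `tf.Tensor` to `common_types.TemporaryAnalyzerOutputType`.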

PiperOrigin-RevId: 387127346
  • Loading branch information
zoyahav authored and tf-transform-team committed Jul 27, 2021
1 parent ad64484 commit 2e7a25b
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 10 deletions.
28 changes: 19 additions & 9 deletions tensorflow_transform/graph_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import collections
import itertools
import uuid
from absl import logging

import tensorflow as tf
from tensorflow_transform import analyzer_nodes
Expand Down Expand Up @@ -569,13 +570,12 @@ def __init__(self,
continue

if isinstance(graph, tf_func_graph.FuncGraph):
tf.compat.v1.logging.warning('Tables initialized inside a tf.function '
'will be re-initialized on every '
'invocation of the function. This '
're-initialization can have significant '
'impact on performance. Consider lifting '
'them out of the graph context using '
'`tf.init_scope`.')
self._log_warning('Tables initialized inside a tf.function will be'
' re-initialized on every invocation of the function.'
' This re-initialization can have significant impact'
' on performance. Consider lifting them out of the'
' graph context using `tf.init_scope`.: {}'.format(
table_init_op_or_tensor.name))

table_init_op, table_input_ops = (
self._get_table_init_op_and_inputs(table_init_op_or_tensor))
Expand All @@ -593,6 +593,9 @@ def __init__(self,
self._graph_analyzer = _GraphAnalyzer(complete_source_info_dict,
translate_path_fn, graph)

def _log_warning(self, message: str):
logging.warning(message)

def _get_table_init_op_and_inputs(self, table_init_op_or_tensor):
"""Get a tuple of table init op and keys for its input ops."""
# If a TF2 exported SavedModel with a table is loaded inside the
Expand Down Expand Up @@ -759,6 +762,13 @@ def get_dependent_inputs(self, tensor_or_op):
return result


class _QuietInitializableGraphAnalyzer(InitializableGraphAnalyzer):
  """An `InitializableGraphAnalyzer` variant that emits no warnings.

  Intended for internal dependency tracing (e.g. `get_dependent_inputs`), where
  the analyzed graph may be partial and table-initialization warnings would be
  misleading to users.
  """

  def _log_warning(self, message: str):
    """Silently discards `message` instead of logging it."""
    del message  # Unused on purpose: this analyzer suppresses all warnings.


def get_dependent_inputs(graph, input_tensors, output_tensors):
"""Returns tensors in input_tensors that (transitively) produce output_tensors.
Expand Down Expand Up @@ -790,8 +800,8 @@ def get_dependent_inputs(graph, input_tensors, output_tensors):
# tensors doesn't affect the correctness of dependencies tracing.
tensor_sinks = graph.get_collection(analyzer_nodes.TENSOR_REPLACEMENTS)
sink_tensors_ready = [(sink.tensor, False) for sink in tensor_sinks]
graph_analyzer = InitializableGraphAnalyzer(graph, input_tensors,
sink_tensors_ready)
graph_analyzer = _QuietInitializableGraphAnalyzer(graph, input_tensors,
sink_tensors_ready)
dependent_inputs = {}
for output_tensor in output_container:
dependent_inputs.update(graph_analyzer.get_dependent_inputs(output_tensor))
Expand Down
2 changes: 1 addition & 1 deletion tensorflow_transform/mappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1041,7 +1041,7 @@ def compute_and_apply_vocabulary(
@common.log_api_use(common.MAPPER_COLLECTION)
def apply_vocabulary(
x: common_types.ConsistentTensorType,
deferred_vocab_filename_tensor: tf.Tensor,
deferred_vocab_filename_tensor: common_types.TemporaryAnalyzerOutputType,
default_value: Optional[Any] = -1,
num_oov_buckets: Optional[int] = 0,
lookup_fn: Optional[Callable[[common_types.TensorType, tf.Tensor],
Expand Down

0 comments on commit 2e7a25b

Please sign in to comment.