Commit b887ccc

Add a bit of pytype to nodes.py where makes sense.

PiperOrigin-RevId: 364572690
zoyahav authored and tf-transform-team committed Mar 23, 2021
1 parent 327ba1e commit b887ccc
Showing 5 changed files with 28 additions and 27 deletions.
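In practice, "adding pytype" means PEP 484 type annotations that pytype, Google's static type checker, can verify at build time. A minimal before/after sketch in the style of this commit, adapted from the `get_field_str` method in the diff below (the standalone `obj` parameter is my adaptation for a self-contained example, not the original signature):

# Before: parameter and return types are only implied by the docstring.
def get_field_str(obj, field_name):
  """Returns a str representation of the requested field."""
  return getattr(obj, field_name)


# After: explicit annotations let pytype check every call site statically.
def get_field_str_typed(obj: object, field_name: str) -> str:
  """Returns a str representation of the requested field."""
  return getattr(obj, field_name)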
4 changes: 2 additions & 2 deletions tensorflow_transform/analyzers.py

@@ -1553,8 +1553,8 @@ def _get_vocab_filename(vocab_filename, store_frequency):
   Args:
     vocab_filename: The file name for the vocabulary file. If none, the
-      "uniques" scope name in the context of this graph will be used as the file
-      name.
+      "vocabulary" scope name in the context of this graph will be used as the
+      file name.
     store_frequency: A bool that is true when the vocabulary for which this
       generates a filename stores term frequency. False otherwise.
6 changes: 2 additions & 4 deletions tensorflow_transform/graph_tools_test.py

@@ -20,8 +20,8 @@
 import abc
 import os
 import tempfile

-# GOOGLE-INITIALIZATION
+from future.utils import with_metaclass

 import six

@@ -296,9 +296,7 @@ def larger_than_100(x):
   return {'x': x, 'y': larger_than_100(x)}


-class _Matcher(object):
-
-  __metaclass__ = abc.ABCMeta
+class _Matcher(with_metaclass(abc.ABCMeta, object)):

   def _future_proof(self, value):
     if isinstance(value, (six.text_type, str, bytes)):
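For context on the `_Matcher` change: assigning `__metaclass__` inside the class body only takes effect on Python 2; Python 3 silently ignores the attribute, so abstract methods would not be enforced there. `with_metaclass` is the 2-and-3 compatible spelling. A minimal sketch with hypothetical class names:

import abc

from future.utils import with_metaclass


class Py2OnlyAbstract(object):
  # Python 3 ignores this attribute entirely, so the class is a plain
  # `object` subclass there and abstract methods are not enforced.
  __metaclass__ = abc.ABCMeta

  @abc.abstractmethod
  def match(self, value):
    raise NotImplementedError


class AlwaysAbstract(with_metaclass(abc.ABCMeta, object)):
  # with_metaclass routes class creation through ABCMeta on both Python
  # versions, so instantiating without overriding `match` raises TypeError.

  @abc.abstractmethod
  def match(self, value):
    raise NotImplementedError


Py2OnlyAbstract()   # Succeeds on Python 3, despite the abstract method.
# AlwaysAbstract()  # TypeError on both Python 2 and Python 3.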
7 changes: 3 additions & 4 deletions tensorflow_transform/mappers.py

@@ -1078,8 +1078,8 @@ def apply_vocabulary(
       Otherwise it is assigned the `default_value`.
     lookup_fn: Optional lookup function, if specified it should take a tensor
       and a deferred vocab filename as an input and return a lookup `op` along
-      with the table size, by default `apply_vocab` constructs a StaticHashTable
-      for the table lookup.
+      with the table size, by default `apply_vocabulary` constructs a
+      StaticHashTable for the table lookup.
     file_format: (Optional) A str. The format of the given vocabulary.
       Accepted formats are: 'tfrecord_gzip', 'text'.
       The default value is 'text'.

@@ -1872,8 +1872,7 @@ def _apply_buckets_with_keys(
   key_values = key.values if isinstance(key, tf.SparseTensor) else key

   x_values = tf.cast(x_values, tf.float32)
-  # Convert `key_values` to indices in key_vocab. We must use apply_function
-  # since this uses a Table.
+  # Convert `key_values` to indices in key_vocab.
   key_indices = tf_utils.lookup_key(key_values, key_vocab)

   adjusted_key_indices = tf.where(
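To make the `lookup_fn` contract above concrete, here is a minimal sketch of a custom lookup function that returns a lookup op together with the table size, mirroring the default StaticHashTable behavior the docstring describes. The dtypes and the TextFileInitializer configuration are my assumptions for illustration, not taken from this commit:

import tensorflow as tf


def my_lookup_fn(x, deferred_vocab_filename):
  # Map each string in `x` to an int64 index from the vocabulary file,
  # one term per line, with -1 for out-of-vocabulary values.
  initializer = tf.lookup.TextFileInitializer(
      deferred_vocab_filename,
      key_dtype=tf.string,
      key_index=tf.lookup.TextFileIndex.WHOLE_LINE,
      value_dtype=tf.int64,
      value_index=tf.lookup.TextFileIndex.LINE_NUMBER)
  table = tf.lookup.StaticHashTable(initializer, default_value=-1)
  return table.lookup(x), table.size()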
6 changes: 4 additions & 2 deletions tensorflow_transform/mappers_test.py

@@ -523,7 +523,8 @@ def testToTFIDF(self):
         [(3/5), (1/5), (1/5), (1/2), (1/2)],
         [2, 4])
     reduced_term_freq = tf.constant([[2, 1, 1, 1]])
-    output_tensor = mappers._to_tfidf(term_freq, reduced_term_freq, 2, True)
+    output_tensor = mappers._to_tfidf(term_freq, reduced_term_freq,
+                                      tf.constant(2), True)
     log_3_over_2 = 1.4054651
     self.assertSparseOutput(
         expected_indices=[[0, 0], [0, 1], [0, 2], [1, 0], [1, 3]],

@@ -539,7 +540,8 @@ def testToTFIDFNotSmooth(self):
         [(3/5), (1/5), (1/5), (1/2), (1/2)],
         [2, 4])
     reduced_term_freq = tf.constant([[2, 1, 1, 1]])
-    output_tensor = mappers._to_tfidf(term_freq, reduced_term_freq, 2, False)
+    output_tensor = mappers._to_tfidf(term_freq, reduced_term_freq,
+                                      tf.constant(2), False)
     log_2_over_1 = 1.6931471
     self.assertSparseOutput(
         expected_indices=[[0, 0], [0, 1], [0, 2], [1, 0], [1, 3]],
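The expected constants in these tests check out if `_to_tfidf` applies the usual smoothed and unsmoothed IDF formulas, 1 + ln((1 + N) / (1 + df)) and 1 + ln(N / df), with corpus size N = 2 and document frequency df = 1; that formula choice is my reading, not stated in this diff:

import math

# Smoothed IDF, N=2, df=1: matches log_3_over_2 in testToTFIDF.
print(1 + math.log((1 + 2) / (1 + 1)))  # 1.4054651081081644
# Unsmoothed IDF, N=2, df=1: matches log_2_over_1 in testToTFIDFNotSmooth.
print(1 + math.log(2 / 1))              # 1.6931471805599454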
32 changes: 17 additions & 15 deletions tensorflow_transform/nodes.py

@@ -30,6 +30,8 @@
 import abc
 import collections
+from typing import Collection, Optional, Tuple
+
 from future.utils import with_metaclass
 import pydot
 # TODO(https://issues.apache.org/jira/browse/SPARK-22674): Switch to

@@ -48,7 +50,7 @@ class ValueNode(
     value_index: The index of this value in the outputs of `parent_operation`.
   """

-  def __init__(self, parent_operation, value_index):
+  def __init__(self, parent_operation, value_index: int):
     if not isinstance(parent_operation, OperationNode):
       raise TypeError(
           'parent_operation must be a OperationNode, got {} of type {}'.format(

@@ -78,21 +80,21 @@ class OperationDef(with_metaclass(abc.ABCMeta, object)):
   """

   @property
-  def num_outputs(self):
+  def num_outputs(self) -> int:
     """The number of outputs returned by this operation."""
     return 1

   @abc.abstractproperty
-  def label(self):
+  def label(self) -> str:
     """A unique label for this operation in the graph."""
     pass

-  def get_field_str(self, field_name):
+  def get_field_str(self, field_name: str) -> str:
     """Returns a str representation of the requested field."""
     return getattr(self, field_name)

   @property
-  def is_partitionable(self):
+  def is_partitionable(self) -> bool:
     """If True, means that this operation can be applied on partitioned data.

     Being able to be applied on partitioned data means that partitioning the

@@ -107,7 +109,7 @@ def is_partitionable(self):
     return False

   @property
-  def cache_coder(self):
+  def cache_coder(self) -> Optional[object]:
     """A CacheCoder object used to cache outputs returned by this operation.

     If this doesn't return None, then:

@@ -226,7 +228,7 @@ def visit(self, operation_def, input_values):
 class Traverser(object):
   """Class to traverse the DAG of nodes."""

-  def __init__(self, visitor):
+  def __init__(self, visitor: Visitor):
     """Init method for Traverser.

     Args:

@@ -236,7 +238,7 @@ def __init__(self, visitor):
     self._stack = []
     self._visitor = visitor

-  def visit_value_node(self, value_node):
+  def visit_value_node(self, value_node: ValueNode):
     """Visit a value node, and return a corresponding value.

     Args:

@@ -248,7 +250,7 @@ def visit_value_node(self, value_node):
     """
     return self._maybe_visit_value_node(value_node)

-  def _maybe_visit_value_node(self, value_node):
+  def _maybe_visit_value_node(self, value_node: ValueNode):
     """Visit a value node if not cached, and return a corresponding value.

     Args:

@@ -262,7 +264,7 @@ def _maybe_visit_value_node(self, value_node):
       self._visit_operation(value_node.parent_operation)
     return self._cached_value_nodes_values[value_node]

-  def _visit_operation(self, operation):
+  def _visit_operation(self, operation: OperationNode):
     """Visit an `OperationNode`."""
     if operation in self._stack:
       cycle = self._stack[self._stack.index(operation):] + [operation]

@@ -295,7 +297,7 @@ def _visit_operation(self, operation):
       self._cached_value_nodes_values[output] = value


-def _escape(line):
+def _escape(line: str) -> str:
   for char in '<>{}':
     line = line.replace(char, '\\%s' % char)
   return line

@@ -312,10 +314,10 @@ def __init__(self):
     self._dot_graph.set_node_defaults(shape='Mrecord')
     super(_PrintGraphVisitor, self).__init__()

-  def get_dot_graph(self):
+  def get_dot_graph(self) -> pydot.Dot:
     return self._dot_graph

-  def visit(self, operation_def, input_nodes):
+  def visit(self, operation_def, input_nodes) -> Tuple[pydot.Node, ...]:
     num_outputs = operation_def.num_outputs
     node_name = operation_def.label

@@ -346,11 +348,11 @@ def visit(self, operation_def, input_nodes):
         pydot.Node(obj_dict={'name': '"{}":{}'.format(node_name, idx)})
         for idx in range(num_outputs))

-  def validate_value(self, value):
+  def validate_value(self, value: pydot.Node):
     assert isinstance(value, pydot.Node)


-def get_dot_graph(leaf_nodes):
+def get_dot_graph(leaf_nodes: Collection[ValueNode]) -> pydot.Dot:
   """Utility to print a graph in a human readable manner.

   The format resembles a sequence of calls to apply_operation or
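For orientation, the newly annotated `Traverser` drives a classic visitor pattern over the TFT operation DAG. Here is a minimal sketch of how the pieces fit together; `_CountVisitor` is hypothetical, while `Visitor`, `Traverser`, and the `visit`/`validate_value` hooks come from nodes.py:

from tensorflow_transform import nodes


class _CountVisitor(nodes.Visitor):
  """Hypothetical visitor that counts operations reachable from a leaf."""

  def __init__(self):
    self.count = 0

  def visit(self, operation_def, input_values):
    self.count += 1
    # Return one value per declared output so downstream nodes have inputs.
    return (operation_def.label,) * operation_def.num_outputs

  def validate_value(self, value):
    assert isinstance(value, str)


visitor = _CountVisitor()
traverser = nodes.Traverser(visitor)
# Given a ValueNode `leaf`, traversal memoizes visited nodes, so an
# operation shared by several outputs is visited exactly once:
# traverser.visit_value_node(leaf)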
