[go: top, home]

Skip to content

Commit

Permalink
Replacing deprecated beam.RemoveDuplicates with beam.Distinct, adding some pytype annotations to some methods, as well as using np.divide in order to avoid division by 0 issues.
Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 397700305
  • Loading branch information
zoyahav authored and tf-transform-team committed Sep 20, 2021
1 parent 9c36b78 commit 33ad15c
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 8 deletions.
28 changes: 20 additions & 8 deletions tensorflow_transform/beam/analyzer_impls.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,8 @@ def map_key_to_count_and_term(kv, key_fn):


# Experimental
def sum_labeled_weights(accs):
def sum_labeled_weights(
accs: List[Tuple[float, List[float]]]) -> Tuple[float, List[float]]:
"""Sums up a collection of labeled-weight tables.
Args:
Expand Down Expand Up @@ -413,7 +414,7 @@ def is_valid_string(kv):

result = ((result, coverage_counts)
| 'MergeStandardAndCoverageArms' >> beam.Flatten()
| 'RemoveDuplicates' >> beam.RemoveDuplicates())
| 'RemoveDuplicates' >> beam.Distinct())

return result

Expand Down Expand Up @@ -500,7 +501,8 @@ def fingerprint_sort_fn(kv):
return (wait_for_vocabulary_transform,)


def _flatten_value_to_list(batch_values):
def _flatten_value_to_list(
batch_values: Tuple[np.ndarray, ...]) -> Iterable[Any]:
"""Converts an N-D dense or sparse batch to a 1-D list."""
batch_value, = batch_values

Expand All @@ -510,7 +512,8 @@ def _flatten_value_to_list(batch_values):
return batch_value.tolist()


def _flatten_value_and_weights_to_list_of_tuples(batch_values):
def _flatten_value_and_weights_to_list_of_tuples(
batch_values: Tuple[np.ndarray, ...]) -> Iterable[Any]:
"""Converts a batch of vocabulary and weights to a list of KV tuples."""
batch_value, weights = batch_values

Expand All @@ -523,7 +526,8 @@ def _flatten_value_and_weights_to_list_of_tuples(batch_values):


# Experimental
def _flatten_value_and_labeled_weights_to_list_of_tuples(batch_values):
def _flatten_value_and_labeled_weights_to_list_of_tuples(
batch_values: Tuple[np.ndarray, ...]) -> Iterable[Any]:
"""Converts a batch of vocabulary and labeled weights to a list of KV tuples.
Args:
Expand Down Expand Up @@ -1032,7 +1036,8 @@ def _merge_outputs_by_key(keys_and_outputs, outputs_dtype):
dtype=dtype.as_numpy_dtype))


def _make_strictly_increasing_boundaries_rows(boundary_matrix):
def _make_strictly_increasing_boundaries_rows(
boundary_matrix: np.ndarray) -> np.ndarray:
"""Converts a 2-d array of increasing rows to strictly increasing rows.
Args:
Expand Down Expand Up @@ -1060,7 +1065,9 @@ def _make_strictly_increasing_boundaries_rows(boundary_matrix):
return np.insert(corrected_boundaries, 0, boundary_matrix[:, 0], axis=1)


def _join_boundary_rows(boundary_matrix):
def _join_boundary_rows(
boundary_matrix: np.ndarray
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
"""Joins boundaries per key, by scaling and shifting them.
This returns a new list of boundaries which is composed from the given 2-d
Expand Down Expand Up @@ -1090,7 +1097,12 @@ def _join_boundary_rows(boundary_matrix):
# Max boundary for each row.
max_boundary = np.max(boundary_matrix, axis=1)

scale = 1.0 / (max_boundary - min_boundary)
boundary_difference = max_boundary - min_boundary
scale = np.divide(
1.0,
boundary_difference,
out=np.ones_like(boundary_difference),
where=boundary_difference != 0)

# Shifts what would shift values so that when applied to min[key_id] we
# get: min[key_id] * scale[key_id] + shift[key_id] = key_id
Expand Down
7 changes: 7 additions & 0 deletions tensorflow_transform/beam/analyzer_impls_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,13 @@ def testMakeStrictlyIncreasingBoundariesRows(self,
expected_scales=np.array([0.333333333, 333333.333, 0.166666667]),
expected_shifts=np.array([0, -999999, 1.66666667]),
expected_num_buckets=np.array(5)),
dict(
testcase_name='SingleBoundary',
input_boundaries=np.array([[1], [2]]),
expected_boundaries=np.array([0]),
expected_scales=np.array([1., 1.]),
expected_shifts=np.array([-1, -1]),
expected_num_buckets=np.array(2)),
)
def testJoinBoundarieRows(self, input_boundaries, expected_boundaries,
expected_scales, expected_shifts,
Expand Down

0 comments on commit 33ad15c

Please sign in to comment.