Commit
Temporary fix for TFT analyzers which do not reduce instance dims to work with numpy 1.24

PiperOrigin-RevId: 508126642
zoyahav authored and tfx-copybara committed Feb 8, 2023
1 parent 004e3d0 commit d69e02f
Showing 2 changed files with 25 additions and 4 deletions.
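
For background (not part of the commit): numpy 1.24 removed the long-deprecated implicit creation of object arrays from ragged nested sequences, so np.array() on a list whose elements have mismatched shapes now raises ValueError instead of emitting a VisibleDeprecationWarning. A minimal sketch of that behavior change, using a hypothetical JSON-decoded accumulator value:

import numpy as np

# Hypothetical JSON-decoded accumulator: scalar and empty-list fields mixed.
ragged = [0, [], [], 0.0]

try:
  # numpy < 1.24: VisibleDeprecationWarning, returns an object array.
  # numpy >= 1.24: raises ValueError (inhomogeneous shape).
  arr = np.array(ragged)
except ValueError:
  # An explicit object dtype still works on all numpy versions.
  arr = np.array(ragged, dtype=object)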
14 changes: 11 additions & 3 deletions tensorflow_transform/analyzer_nodes.py
@@ -433,9 +433,17 @@ def encode_cache(self, accumulator):
     return tf.compat.as_bytes(json.dumps(primitive_accumulator))

   def decode_cache(self, encoded_accumulator):
-    return np.array(
-        json.loads(tf.compat.as_text(encoded_accumulator)), dtype=self._dtype
-    )
+    # TODO(b/268341036): Set dtype correctly for combiners for numpy 1.24.
+    try:
+      return np.array(
+          json.loads(tf.compat.as_text(encoded_accumulator)), dtype=self._dtype
+      )
+    except ValueError:
+      if self._dtype != object:
+        return np.array(
+            json.loads(tf.compat.as_text(encoded_accumulator)), dtype=object
+        )
+      raise


 class AnalyzerDef(nodes.OperationDef, metaclass=abc.ABCMeta):
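To illustrate the intended effect of the fallback above, here is a sketch (not from the commit) of a cache round trip for an accumulator whose fields have mismatched shapes; it assumes the coder's default dtype is something other than object, so decode_cache hits the ValueError path under numpy 1.24 and retries with dtype=object:

import numpy as np
from tensorflow_transform import analyzer_nodes
from tensorflow_transform import analyzers

coder = analyzer_nodes.JsonNumpyCacheCoder()
encoded = coder.encode_cache(
    analyzers._WeightedMeanAndVarAccumulator(
        count=np.array(0),
        mean=np.array([], dtype=np.float64),
        variance=np.array([], dtype=np.float64),
        weight=np.array(0.0)))
# With numpy >= 1.24 the first np.array(..., dtype=self._dtype) call raises
# ValueError for this ragged structure, so decode_cache falls back to
# dtype=object instead of failing.
decoded = coder.decode_cache(encoded)
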
15 changes: 14 additions & 1 deletion tensorflow_transform/beam/analyzer_cache_test.py
@@ -81,7 +81,7 @@ def test_validate_dataset_keys(self):
           coder=analyzer_nodes._VocabularyAccumulatorCoder(),
           value=[b'\x8a', 29]),
       dict(
-          testcase_name='_VocabularyAccumulatorCoderClassAccumulator',
+          testcase_name='_WeightedMeanAndVarAccumulatorPerKey',
           coder=analyzer_nodes._VocabularyAccumulatorCoder(),
           value=[
               b'A',
@@ -92,6 +92,19 @@ def test_validate_dataset_keys(self):
                   weight=np.array(0.),
               )
           ]),
+      dict(
+          testcase_name='_WeightedMeanAndVarAccumulatorKeepDims',
+          coder=analyzer_nodes.JsonNumpyCacheCoder(),
+          # TODO(b/268341036): Remove this complication once np 1.24 issue is
+          # fixed.
+          value=analyzer_nodes.JsonNumpyCacheCoder(object).decode_cache(
+              analyzer_nodes.JsonNumpyCacheCoder().encode_cache(
+                  analyzers._WeightedMeanAndVarAccumulator(
+                      count=np.array(0),
+                      mean=np.array([], dtype=np.float64),
+                      variance=np.array([], dtype=np.float64),
+                      weight=np.array(0.0))))
+      ),
       dict(
           testcase_name='_QuantilesAccumulatorCoderClassAccumulator',
           coder=analyzers._QuantilesSketchCacheCoder(),
