Simplified experimental sharding logic for TFDV. Now output shards are guaranteed to contain all statistics for a subset of features.

PiperOrigin-RevId: 425510691
tfx-copybara committed Feb 1, 2022
1 parent 6185c13 commit 0d5d1ce
Showing 8 changed files with 394 additions and 46 deletions.
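This change replaces the experimental_output_type switch ('binary_pb' vs. 'tfrecords') with a single experimental_result_partitions option, and exports a load_sharded_statistics helper from __init__.py (see below). A minimal sketch of configuring the new option, assuming the standard tfdv top-level export of StatsOptions; the partition count of 4 is an arbitrary illustrative value:

import tensorflow_data_validation as tfdv

# Request up to 4 output shards. Each shard is a DatasetFeatureStatisticsList
# holding all statistics for a disjoint subset of features; with the default
# of 1, output is globally combined as before.
options = tfdv.StatsOptions(experimental_result_partitions=4)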
1 change: 1 addition & 0 deletions tensorflow_data_validation/__init__.py
@@ -64,6 +64,7 @@
# Import stats utilities.
from tensorflow_data_validation.utils.stats_util import get_feature_stats
from tensorflow_data_validation.utils.stats_util import get_slice_stats
from tensorflow_data_validation.utils.stats_util import load_sharded_statistics
from tensorflow_data_validation.utils.stats_util import load_statistics
from tensorflow_data_validation.utils.stats_util import load_stats_binary
from tensorflow_data_validation.utils.stats_util import load_stats_text
60 changes: 35 additions & 25 deletions tensorflow_data_validation/statistics/stats_impl.py
@@ -35,6 +35,7 @@
from tensorflow_data_validation.statistics.generators import top_k_uniques_sketch_stats_generator
from tensorflow_data_validation.statistics.generators import top_k_uniques_stats_generator
from tensorflow_data_validation.statistics.generators import weighted_feature_stats_generator
from tensorflow_data_validation.utils import feature_partition_util
from tensorflow_data_validation.utils import slicing_util
from tfx_bsl.arrow import table_util
from tfx_bsl.statistics import merge_util
@@ -121,6 +122,24 @@ def _TrackDistinctSliceKeys(  # pylint: disable=invalid-name
lambda x: _increment_counter('num_distinct_slice_keys', x)))


class _YieldPlaceholderFn(beam.DoFn):
"""Yields a single empty proto if input (count) is zero."""

def process(self, count: int):
if count == 0:
yield (None, statistics_pb2.DatasetFeatureStatistics())


@beam.ptransform_fn
def _AddPlaceholderStatistics( # pylint: disable=invalid-name
statistics: beam.PCollection[Tuple[
types.SliceKey, statistics_pb2.DatasetFeatureStatistics]]):
"""Adds a placeholder empty dataset for empty input, otherwise noop."""
count = statistics | beam.combiners.Count.Globally()
maybe_placeholder = count | beam.ParDo(_YieldPlaceholderFn())
return (statistics, maybe_placeholder) | beam.Flatten()


# This transform will be used by the example validation API to compute
# statistics over anomalous examples. Specifically, it is used to compute
# statistics over examples found for each anomaly (i.e., the anomaly type
@@ -178,31 +197,22 @@ def expand(
_CombinerStatsGeneratorsCombineFn(
combiner_stats_generators, self._options.desired_batch_size)
).with_hot_key_fanout(fanout))

# result_protos is a list of PCollections of (slice key,
# DatasetFeatureStatistics proto) pairs.
if (self._options.experimental_output_type ==
stats_options.OUTPUT_TYPE_BINARY_PB):
# We now flatten the list into a single PCollection, combine the
# DatasetFeatureStatistics protos by key, and then merge the
# DatasetFeatureStatistics protos in the PCollection into a single
# DatasetFeatureStatisticsList proto.
return (result_protos
| 'FlattenFeatureStatistics' >> beam.Flatten()
| 'AddSliceKeyToStatsProto' >> beam.Map(_add_slice_key,
self._is_slicing_enabled)
| 'ToList' >> beam.combiners.ToList()
| 'MergeDatasetFeatureStatisticsProtos' >> beam.Map(
merge_util.merge_dataset_feature_statistics))
else:
# If we're writing sharded data, we can flatten to a single PCollection,
# and wrap each shard into a singleton list.
return (result_protos
| 'FlattenFeatureStatistics' >> beam.Flatten()
| 'AddSliceKeyToStatsProto' >> beam.Map(_add_slice_key,
self._is_slicing_enabled)
| 'MakeDatasetFeatureStatisticsListProto' >>
beam.Map(_make_singleton_dataset_feature_statistics_list_proto))
result_protos = result_protos | 'FlattenFeatureStatistics' >> beam.Flatten()
result_protos = (
result_protos
| 'AddPlaceholderStatistics' >> _AddPlaceholderStatistics()) # pylint: disable=no-value-for-parameter
# Combine result_protos into a configured number of partitions.
return (result_protos
| 'AddSliceKeyToStatsProto' >> beam.Map(_add_slice_key,
self._is_slicing_enabled)
| 'MakeDatasetFeatureStatisticsListProto' >>
beam.Map(_make_singleton_dataset_feature_statistics_list_proto)
| 'SplitIntoFeaturePartitions' >> beam.ParDo(
feature_partition_util.KeyAndSplitByFeatureFn(
self._options.experimental_result_partitions))
| 'MergeStatsProtos' >> beam.CombinePerKey(
merge_util.merge_dataset_feature_statistics_list)
| 'Values' >> beam.Values())


def get_generators(options: stats_options.StatsOptions,
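With this change, the expand() method above always returns a PCollection of DatasetFeatureStatisticsList protos, one per non-empty feature partition. A minimal sketch of how a caller might persist that output; the write step, coder choice, and output prefix are illustrative assumptions and not part of this commit:

import apache_beam as beam
from tensorflow_metadata.proto.v0 import statistics_pb2

# sharded_stats: PCollection[DatasetFeatureStatisticsList] produced by the
# expand() method above.
_ = (
    sharded_stats
    | 'WriteShardedStats' >> beam.io.WriteToTFRecord(
        '/tmp/stats/feature_shard',  # hypothetical output prefix
        coder=beam.coders.ProtoCoder(
            statistics_pb2.DatasetFeatureStatisticsList)))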
2 changes: 1 addition & 1 deletion tensorflow_data_validation/statistics/stats_impl_test.py
@@ -2889,7 +2889,7 @@ def extract_output(self, accumulator):
num_histogram_buckets=2,
num_quantiles_histogram_buckets=2,
enable_semantic_domain_stats=True,
experimental_output_type=stats_options.OUTPUT_TYPE_TFRECORDS),
experimental_result_partitions=999), # 999 >> #features.
'expected_result_proto_text':
_SLICED_STATS_TEST_RESULT,
'expected_shards':
30 changes: 14 additions & 16 deletions tensorflow_data_validation/statistics/stats_options.py
@@ -37,8 +37,6 @@
_SCHEMA_JSON_KEY = 'schema_json'
_PER_FEATURE_WEIGHT_OVERRIDE_JSON_KEY = 'per_feature_weight_override_json'

OUTPUT_TYPE_BINARY_PB = 'binary_pb'
OUTPUT_TYPE_TFRECORDS = 'tfrecords'

# TODO(b/181559345): Currently we use a single epsilon (error tolerance)
# parameter for all histograms. Set this parameter specific to each
@@ -77,7 +75,7 @@ def __init__(
experimental_use_sketch_based_topk_uniques: bool = False,
experimental_slice_functions: Optional[List[types.SliceFunction]] = None,
experimental_slice_sqls: Optional[List[Text]] = None,
experimental_output_type: str = OUTPUT_TYPE_BINARY_PB):
experimental_result_partitions: int = 1):
"""Initializes statistics options.
Args:
@@ -175,10 +173,10 @@ def __init__(
example.country WHERE country = 'USA'" Only one of
experimental_slice_functions or experimental_slice_sqls must be
specified. Note that this option is not supported on Windows.
experimental_output_type: One of 'binary_pb' (default), or 'tfrecords'. If
this is 'binary_pb' the output will be a single binary proto consisting
of all output merged across features, generators, and slices. If this is
'tfrecords', the output will be written in sharded form.
experimental_result_partitions: The number of feature partitions to
combine output DatasetFeatureStatisticsLists into. If set to 1 (default),
output is globally combined. If set to a value greater than one, up to
that many shards are returned, each containing a subset of features.
"""
self.generators = generators
self.feature_allowlist = feature_allowlist
@@ -212,7 +210,7 @@ def __init__(
self.experimental_use_sketch_based_topk_uniques = (
experimental_use_sketch_based_topk_uniques)
self.experimental_slice_sqls = experimental_slice_sqls
self.experimental_output_type = experimental_output_type
self.experimental_result_partitions = experimental_result_partitions

def to_json(self) -> Text:
"""Convert from an object to JSON representation of the __dict__ attribute.
@@ -464,17 +462,17 @@ def experimental_use_sketch_based_topk_uniques(
self._use_sketch_based_topk_uniques = use_sketch_based_topk_uniques

@property
def experimental_output_type(self) -> str:
return self._experimental_output_type
def experimental_result_partitions(self) -> int:
return self._experimental_result_partitions

@experimental_output_type.setter
def experimental_output_type(self, output_type: str) -> None:
if output_type in (OUTPUT_TYPE_BINARY_PB, OUTPUT_TYPE_TFRECORDS):
self._experimental_output_type = output_type
@experimental_result_partitions.setter
def experimental_result_partitions(self, num_partitions: int) -> None:
if num_partitions > 0:
self._experimental_result_partitions = num_partitions
else:
raise ValueError(
'Unsupported output type %s. Must be one of binary_pb, tfrecords.' %
output_type)
'Unsupported experimental_result_partitions <= 0: %d' %
num_partitions)


def _validate_sql(sql_query: Text, schema: schema_pb2.Schema):
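A minimal sketch of the new option's validation, mirroring the setter above; the values are arbitrary:

from tensorflow_data_validation.statistics import stats_options

options = stats_options.StatsOptions()
options.experimental_result_partitions = 3  # any positive integer is accepted
try:
  options.experimental_result_partitions = 0  # rejected by the setter
except ValueError as e:
  print(e)  # Unsupported experimental_result_partitions <= 0: 0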
8 changes: 5 additions & 3 deletions tensorflow_data_validation/statistics/stats_options_test.py
@@ -284,7 +284,7 @@ def test_stats_options_json_round_trip(self):
per_feature_weight_override = {types.FeaturePath(['a']): 'w'}
add_default_generators = True
use_sketch_based_topk_uniques = True
experimental_output_type = 'tfrecords'
experimental_result_partitions = 3

options = stats_options.StatsOptions(
generators=generators,
@@ -310,7 +310,7 @@ def test_stats_options_json_round_trip(self):
per_feature_weight_override=per_feature_weight_override,
add_default_generators=add_default_generators,
experimental_use_sketch_based_topk_uniques=use_sketch_based_topk_uniques,
experimental_output_type=experimental_output_type)
experimental_result_partitions=experimental_result_partitions)

options_json = options.to_json()
options = stats_options.StatsOptions.from_json(options_json)
@@ -346,6 +346,8 @@ def test_stats_options_json_round_trip(self):
self.assertEqual(add_default_generators, options.add_default_generators)
self.assertEqual(use_sketch_based_topk_uniques,
options.experimental_use_sketch_based_topk_uniques)
self.assertEqual(experimental_result_partitions,
options.experimental_result_partitions)

def test_stats_options_from_json(self):
options_json = """{
@@ -373,7 +375,7 @@ def test_stats_options_from_json(self):
"_add_default_generators": true,
"_use_sketch_based_topk_uniques": false,
"_slice_sqls": null,
"_experimental_output_type": "binary_pb"
"_experimental_result_partitions": 1
}"""
actual_options = stats_options.StatsOptions.from_json(options_json)
expected_options_dict = stats_options.StatsOptions().__dict__
80 changes: 80 additions & 0 deletions tensorflow_data_validation/utils/feature_partition_util.py
@@ -16,8 +16,12 @@
import collections
import hashlib
from typing import Any, Iterable, Tuple, Union, FrozenSet, Mapping

import apache_beam as beam
import pyarrow as pa

from tensorflow_data_validation import types
from tensorflow_metadata.proto.v0 import statistics_pb2


class ColumnHasher(object):
@@ -44,6 +48,14 @@ def assign(self, feature_name: Union[bytes, str]) -> int:
self._cache[feature_name] = partition
return partition

def assign_sequence(self, *parts: Union[bytes, str]) -> int:
"""Assigns a feature partition based on a sequence of bytes or strings."""
partition = 0
for part in parts:
partition += self.assign(part)
partition = partition % self.num_partitions
return partition

def __eq__(self, o):
return self.num_partitions == o.num_partitions
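A small sketch of assign_sequence, which folds the per-part assignments into a single partition index modulo num_partitions; the parts and partition count are arbitrary:

from tensorflow_data_validation.utils import feature_partition_util

hasher = feature_partition_util.ColumnHasher(10)
# Equivalent to summing hasher.assign(part) over the parts, mod 10.
partition = hasher.assign_sequence('my_dataset', 'feature', 'subfield')
assert 0 <= partition < 10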

@@ -95,3 +107,71 @@ def generate_feature_partitions(
key = (slice_key, partition)
column_names, columns = features
yield (key, pa.RecordBatch.from_arrays(columns, column_names))


def _copy_with_no_features(
statistics: statistics_pb2.DatasetFeatureStatistics
) -> statistics_pb2.DatasetFeatureStatistics:
"""Return a copy of 'statistics' with no features or cross-features."""
return statistics_pb2.DatasetFeatureStatistics(
name=statistics.name,
num_examples=statistics.num_examples,
weighted_num_examples=statistics.weighted_num_examples)


@beam.typehints.with_input_types(statistics_pb2.DatasetFeatureStatisticsList)
@beam.typehints.with_output_types(
beam.typehints.KV[int, statistics_pb2.DatasetFeatureStatisticsList])
class KeyAndSplitByFeatureFn(beam.DoFn):
"""Breaks a DatasetFeatureStatisticsList into shards keyed by partition index.
Each partition index contains a random (but deterministic across workers)
subset of features and cross features.
"""

def __init__(self, num_partitions: int):
"""Initializes KeyAndSplitByFeatureFn.
Args:
num_partitions: The number of partitions to divide features/cross-features
into. Must be >= 1.
"""
if num_partitions < 1:
raise ValueError('num_partitions must be >= 1.')
if num_partitions != 1:
self._hasher = ColumnHasher(num_partitions)
else:
self._hasher = None

def process(self, statistics: statistics_pb2.DatasetFeatureStatisticsList):
# If the number of partitions is one, or there are no datasets, yield the
# full statistics proto with a placeholder key.
if self._hasher is None or not statistics.datasets:
yield (0, statistics)
return
for dataset in statistics.datasets:
for feature in dataset.features:
if feature.name:
partition = self._hasher.assign_sequence(dataset.name, feature.name)
else:
partition = self._hasher.assign_sequence(dataset.name,
*feature.path.step)
dataset_copy = _copy_with_no_features(dataset)
dataset_copy.features.append(feature)
yield (partition,
statistics_pb2.DatasetFeatureStatisticsList(
datasets=[dataset_copy]))
for cross_feature in dataset.cross_features:
partition = self._hasher.assign_sequence(dataset.name,
*cross_feature.path_x.step,
*cross_feature.path_y.step)
dataset_copy = _copy_with_no_features(dataset)
dataset_copy.cross_features.append(cross_feature)
yield (partition,
statistics_pb2.DatasetFeatureStatisticsList(
datasets=[dataset_copy]))
# If there were no features or cross-features, yield the dataset itself
# into shard 0 to ensure it's not dropped entirely.
if not dataset.features and not dataset.cross_features:
yield (0,
statistics_pb2.DatasetFeatureStatisticsList(datasets=[dataset]))
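A small sketch of the splitting behavior implemented above, assuming a two-feature statistics list and three partitions; the actual partition indices depend on ColumnHasher:

from tensorflow_data_validation.utils import feature_partition_util
from tensorflow_metadata.proto.v0 import statistics_pb2

stats_list = statistics_pb2.DatasetFeatureStatisticsList()
dataset = stats_list.datasets.add(name='all_examples', num_examples=10)
dataset.features.add().path.step.append('feature_a')
dataset.features.add().path.step.append('feature_b')

fn = feature_partition_util.KeyAndSplitByFeatureFn(num_partitions=3)
for partition, shard in fn.process(stats_list):
  # Each shard keeps the dataset's name/num_examples but holds exactly one
  # feature, so merging shards per partition preserves all statistics.
  print(partition, [tuple(f.path.step) for f in shard.datasets[0].features])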