Commit b98fbad

zwestrick authored and tfx-copybara committed
Experimental support for passing a schema to TFXIO for StatsGen.
PiperOrigin-RevId: 492221403
1 parent ad31543 commit b98fbad

File tree: 6 files changed (+100 / -11 lines)

tensorflow_data_validation/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -56,6 +56,7 @@
 
 
 # Import schema utilities.
+from tensorflow_data_validation.utils.schema_util import generate_dummy_schema_with_paths
 from tensorflow_data_validation.utils.schema_util import get_domain
 from tensorflow_data_validation.utils.schema_util import get_feature
 from tensorflow_data_validation.utils.schema_util import load_schema_text

tensorflow_data_validation/statistics/stats_impl.py

Lines changed: 6 additions & 2 deletions

@@ -16,7 +16,7 @@
 
 import math
 import random
-from typing import Any, Callable, cast, Dict, Iterable, List, Optional, Text, Tuple
+from typing import Any, Callable, cast, Dict, Iterable, List, Optional, Text, Tuple, Union
 
 import apache_beam as beam
 import pyarrow as pa

@@ -465,7 +465,8 @@ def _schema_has_natural_language_domains(schema: schema_pb2.Schema) -> bool:
 
 def _filter_features(
     record_batch: pa.RecordBatch,
-    feature_allowlist: List[types.FeatureName]) -> pa.RecordBatch:
+    feature_allowlist: Union[List[types.FeatureName], List[types.FeaturePath]]
+) -> pa.RecordBatch:
   """Removes features that are not on the allowlist.
 
   Args:

@@ -478,6 +479,9 @@ def _filter_features(
   columns_to_select = []
   column_names_to_select = []
   for feature_name in feature_allowlist:
+    if isinstance(feature_name, types.FeaturePath):
+      # TODO(b/255895499): Support paths.
+      raise NotImplementedError
     col = arrow_util.get_column(record_batch, feature_name, missing_ok=True)
     if col is None:
       continue
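
The _filter_features change above only widens the type annotation: string feature names remain the only supported allowlist entries, and a types.FeaturePath entry currently raises NotImplementedError. The standalone sketch below (plain pyarrow, hypothetical column names, no TFDV-internal arrow_util helpers) illustrates the name-based filtering the function performs.

# Standalone illustration of allowlist filtering on a pyarrow RecordBatch.
# Column names are hypothetical; TFDV's _filter_features does equivalent
# name-based selection via its internal arrow_util.get_column helper.
import pyarrow as pa

record_batch = pa.RecordBatch.from_arrays(
    [pa.array([[1], [2]]), pa.array([['a'], ['b']])], ['age', 'name'])
feature_allowlist = ['name', 'missing_feature']

names, columns = [], []
for feature_name in feature_allowlist:
  index = record_batch.schema.get_field_index(feature_name)
  if index < 0:
    continue  # Features missing from the batch are skipped, not an error.
  names.append(feature_name)
  columns.append(record_batch.column(index))

filtered = pa.RecordBatch.from_arrays(columns, names)
print(filtered.schema.names)  # ['name']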

tensorflow_data_validation/statistics/stats_options.py

Lines changed: 27 additions & 8 deletions

@@ -22,7 +22,7 @@
 import json
 import logging
 import types as python_types
-from typing import Dict, List, Optional, Text
+from typing import Dict, List, Optional, Text, Union
 
 from tensorflow_data_validation import types
 from tensorflow_data_validation.statistics.generators import stats_generator

@@ -76,14 +76,17 @@ def __init__(
                                                  types.FeatureName]] = None,
       vocab_paths: Optional[Dict[types.VocabName, types.VocabPath]] = None,
       add_default_generators: bool = True,
-      feature_allowlist: Optional[List[types.FeatureName]] = None,
+      # TODO(b/255895499): Support "from schema" for feature_allowlist.
+      feature_allowlist: Optional[Union[List[types.FeatureName],
+                                        List[types.FeaturePath]]] = None,
       experimental_use_sketch_based_topk_uniques: Optional[bool] = None,
       use_sketch_based_topk_uniques: Optional[bool] = None,
       experimental_slice_functions: Optional[List[types.SliceFunction]] = None,
       experimental_slice_sqls: Optional[List[Text]] = None,
       experimental_result_partitions: int = 1,
       experimental_num_feature_partitions: int = 1,
-      slicing_config: Optional[slicing_spec_pb2.SlicingConfig] = None):
+      slicing_config: Optional[slicing_spec_pb2.SlicingConfig] = None,
+      experimental_filter_read_paths: bool = False):
     """Initializes statistics options.
 
     Args:

@@ -151,7 +154,7 @@ def __init__(
        (controlled by `enable_semantic_domain_stats`) and 4) schema-based
        generators that are enabled based on information provided in the schema.
      feature_allowlist: An optional list of names of the features to calculate
-       statistics for.
+       statistics for, or a list of paths.
      experimental_use_sketch_based_topk_uniques: Deprecated, prefer
        use_sketch_based_topk_uniques.
      use_sketch_based_topk_uniques: if True, use the sketch based

@@ -193,8 +196,11 @@ def __init__(
        number of features in a dataset, and never more than the available beam
        parallelism.
      slicing_config: an optional SlicingConfig. SlicingConfig includes
-      slicing_specs specified with feature keys, feature values or slicing
-      SQL queries.
+       slicing_specs specified with feature keys, feature values or slicing
+       SQL queries.
+     experimental_filter_read_paths: If provided, tries to push down either
+       paths passed via feature_allowlist or via the schema (in that priority)
+       to the underlying read operation. Support depends on the file reader.
     """
     self.generators = generators
     self.feature_allowlist = feature_allowlist

@@ -241,6 +247,7 @@ def __init__(
     self.experimental_num_feature_partitions = experimental_num_feature_partitions
     self.experimental_result_partitions = experimental_result_partitions
     self.slicing_config = slicing_config
+    self.experimental_filter_read_paths = experimental_filter_read_paths
 
   def to_json(self) -> Text:
     """Convert from an object to JSON representation of the __dict__ attribute.

@@ -340,12 +347,16 @@ def generators(
     self._generators = generators
 
   @property
-  def feature_allowlist(self) -> Optional[List[types.FeatureName]]:
+  def feature_allowlist(
+      self
+  ) -> Optional[Union[List[types.FeatureName], List[types.FeaturePath]]]:
     return self._feature_allowlist
 
   @feature_allowlist.setter
   def feature_allowlist(
-      self, feature_allowlist: Optional[List[types.FeatureName]]) -> None:
+      self, feature_allowlist: Optional[Union[List[types.FeatureName],
+                                              List[types.FeaturePath]]]
+  ) -> None:
     if feature_allowlist is not None and not isinstance(feature_allowlist,
                                                         list):
       raise TypeError('feature_allowlist is of type %s, should be a list.' %

@@ -554,6 +565,14 @@ def experimental_num_feature_partitions(self,
       raise ValueError('experimental_num_feature_partitions must be > 0.')
     self._experimental_num_feature_partitions = feature_partitions
 
+  @property
+  def experimental_filter_read_paths(self) -> bool:
+    return self._experimental_filter_read_paths
+
+  @experimental_filter_read_paths.setter
+  def experimental_filter_read_paths(self, filter_read: bool) -> None:
+    self._experimental_filter_read_paths = filter_read
+
 
 def _validate_sql(sql_query: Text, schema: schema_pb2.Schema):
   arrow_schema = example_coder.ExamplesToRecordBatchDecoder(
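
Taken together, the StatsOptions changes let callers pass FeaturePath objects in feature_allowlist and opt in to pushing those paths down to the reader. A minimal usage sketch follows; the feature names are made up, and a FeaturePath allowlist still hits the NotImplementedError in stats_impl.py above, so this is forward-looking rather than fully supported.

# Sketch only: FeaturePath allowlists are accepted by StatsOptions but not yet
# supported end to end (see the TODO(b/255895499) markers above). Feature
# names are hypothetical.
import tensorflow_data_validation as tfdv
from tensorflow_data_validation import types

options = tfdv.StatsOptions(
    feature_allowlist=[
        types.FeaturePath(['document', 'tokens']),
        types.FeaturePath(['label']),
    ],
    # Ask the underlying reader to prune to the allowlisted paths when it can.
    experimental_filter_read_paths=True)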

tensorflow_data_validation/statistics/stats_options_test.py

Lines changed: 2 additions & 1 deletion

@@ -474,7 +474,8 @@ def test_stats_options_from_json(self,
         "_slice_sqls": null,
         "_experimental_result_partitions": 1,
         "_experimental_num_feature_partitions": 1,
-        "_slicing_config": null
+        "_slicing_config": null,
+        "_experimental_filter_read_paths": false
     """
     options_json += type_name_line + '}'
     if want_exception:
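
The test above only checks that the new private attribute appears in the serialized form. A quick round-trip along the same lines, assuming StatsOptions.to_json and StatsOptions.from_json behave as this test exercises them, would look like this:

# Round-trip sketch: the new flag should survive to_json()/from_json(),
# matching the "_experimental_filter_read_paths" key added in the test above.
from tensorflow_data_validation.statistics import stats_options

options = stats_options.StatsOptions(experimental_filter_read_paths=True)
restored = stats_options.StatsOptions.from_json(options.to_json())
assert restored.experimental_filter_read_paths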

tensorflow_data_validation/utils/schema_util.py

Lines changed: 34 additions & 0 deletions

@@ -374,3 +374,37 @@ def _recursion_helper(
   result = []
   _recursion_helper(types.FeaturePath([]), schema.feature, result)
   return result
+
+
+def _paths_to_tree(paths: List[types.FeaturePath]):
+  """Convert paths to recursively nested dict."""
+  nested_dict = lambda: collections.defaultdict(nested_dict)
+
+  result = nested_dict()
+
+  def _add(tree, path):
+    if not path:
+      return
+    children = tree[path[0]]
+    _add(children, path[1:])
+
+  for path in paths:
+    _add(result, path.steps())
+  return result
+
+
+def generate_dummy_schema_with_paths(
+    paths: List[types.FeaturePath]) -> schema_pb2.Schema:
+  """Generate a schema with the requested paths and no other information."""
+  schema = schema_pb2.Schema()
+  tree = _paths_to_tree(paths)
+
+  def _add(container, name, children):
+    container.feature.add(name=name)
+    if children:
+      for child_name, grandchildren in children.items():
+        _add(container.feature[-1].struct_domain, child_name, grandchildren)
+
+  for name, children in tree.items():
+    _add(schema, name, children)
+  return schema
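
generate_dummy_schema_with_paths, now re-exported from the package root per the __init__.py change above, builds a bare schema whose only content is the requested feature structure; nested path steps become features with a struct_domain. A small usage sketch with made-up paths:

# Usage sketch with illustrative paths; the produced schema carries names only,
# no types, domains, or presence information.
from tensorflow_data_validation import types
from tensorflow_data_validation.utils import schema_util

schema = schema_util.generate_dummy_schema_with_paths([
    types.FeaturePath(['user_id']),
    types.FeaturePath(['session', 'duration_ms']),
    types.FeaturePath(['session', 'page']),
])
# Result: feature "user_id", plus feature "session" whose struct_domain
# contains features "duration_ms" and "page".
print(schema)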

tensorflow_data_validation/utils/schema_util_test.py

Lines changed: 30 additions & 0 deletions

@@ -754,5 +754,35 @@ def test_look_up_feature(self):
         schema_util.look_up_feature('feature2', container), feature_2)
     self.assertIsNone(schema_util.look_up_feature('feature3', container), None)
 
+  def test_generate_dummy_schema_with_paths(self):
+    schema = text_format.Parse(
+        """
+        feature {
+          name: "foo"
+        }
+        feature {
+          name: "bar"
+        }
+        feature {
+          name: "baz"
+          struct_domain: {
+            feature {
+              name: "zip"
+            }
+            feature {
+              name: "zop"
+            }
+          }
+        }
+        """, schema_pb2.Schema())
+    self.assertEqual(
+        schema_util.generate_dummy_schema_with_paths([
+            types.FeaturePath(['foo']),
+            types.FeaturePath(['bar']),
+            types.FeaturePath(['baz', 'zip']),
+            types.FeaturePath(['baz', 'zop'])
+        ]), schema)
+
+
 if __name__ == '__main__':
   absltest.main()

0 commit comments
