Skip to content

Commit 4d0b51d

Browse files
caveness authored and tfx-copybara committed
Add support for running custom validations in validate_statistics().
PiperOrigin-RevId: 488721853
1 parent 276ab4c commit 4d0b51d

File tree

3 files changed

+161
-5
lines changed

3 files changed

+161
-5
lines changed

RELEASE.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,9 @@
99

1010
## Major Features and Improvements
1111

12-
* Add a `custom_validate_statistics` function to the validation API. Note that
13-
this function is not available on Windows.
12+
* Add a `custom_validate_statistics` function to the validation API, and
13+
support passing custom validations to `validate_statistics`. Note that
14+
custom validation is not supported on Windows.
1415

1516
## Bug Fixes and Other Changes
1617

tensorflow_data_validation/api/validation_api.py

Lines changed: 82 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from __future__ import division
1919
from __future__ import print_function
2020

21+
import itertools
2122
import logging
2223
from typing import Callable, Iterable, List, Optional, Text, Tuple, Set
2324
import apache_beam as beam
@@ -58,6 +59,8 @@
5859
anomalies_pb2.AnomalyInfo.DATASET_HIGH_NUM_EXAMPLES,
5960
])
6061

62+
_MULTIPLE_ERRORS = 'Multiple errors'
63+
6164

6265
def infer_schema(
6366
statistics: statistics_pb2.DatasetFeatureStatisticsList,
@@ -189,6 +192,47 @@ def update_schema(schema: schema_pb2.Schema,
189192
return result
190193

191194

195+
def _merge_descriptions(
    anomaly_info: anomalies_pb2.AnomalyInfo,
    other_anomaly_info: Optional[anomalies_pb2.AnomalyInfo]) -> str:
  """Joins the reason descriptions of one or two AnomalyInfos.

  Args:
    anomaly_info: The AnomalyInfo whose reason descriptions come first.
    other_anomaly_info: An optional second AnomalyInfo. When provided, its
      reason descriptions are appended after those of `anomaly_info`.

  Returns:
    All collected reason descriptions, in order, separated by single spaces.
  """
  reasons = anomaly_info.reason
  if other_anomaly_info is not None:
    reasons = itertools.chain(reasons, other_anomaly_info.reason)
  return ' '.join(reason.description for reason in reasons)
207+
208+
209+
def _merge_custom_anomalies(
    anomalies: anomalies_pb2.Anomalies,
    custom_anomalies: anomalies_pb2.Anomalies) -> anomalies_pb2.Anomalies:
  """Merges custom_anomalies with anomalies.

  `anomalies` is updated in place and is also the return value.

  For a key present in both inputs, the existing entry is marked as having
  multiple errors, its description becomes the concatenation of all reason
  descriptions from both entries, its severity becomes the maximum of the two,
  and the custom reasons are appended to its reasons.

  For a key present only in `custom_anomalies`, the custom entry is copied and
  its top-level description and short description are derived from its
  reasons. If the custom entry has no reasons, the copied top-level fields are
  left untouched.

  Args:
    anomalies: The Anomalies proto to merge into.
    custom_anomalies: The Anomalies proto produced by custom validation.

  Returns:
    The updated `anomalies` proto.
  """
  for key, custom_anomaly_info in custom_anomalies.anomaly_info.items():
    if key in anomalies.anomaly_info:
      # If the key is found in both inputs, we know it has multiple errors.
      merged_info = anomalies.anomaly_info[key]
      merged_info.short_description = _MULTIPLE_ERRORS
      # The description must be built before the custom reasons are appended
      # below; otherwise the custom descriptions would be included twice.
      merged_info.description = _merge_descriptions(merged_info,
                                                    custom_anomaly_info)
      merged_info.severity = max(merged_info.severity,
                                 custom_anomaly_info.severity)
      merged_info.reason.extend(custom_anomaly_info.reason)
      continue

    # As in the original code, indexing the map with a new key is what creates
    # the entry that the custom anomaly is copied into.
    merged_info = anomalies.anomaly_info[key]
    merged_info.CopyFrom(custom_anomaly_info)
    num_reasons = len(custom_anomaly_info.reason)
    if num_reasons == 0:
      # There is nothing to derive the top-level fields from. Keep whatever was
      # copied instead of raising IndexError on reason[0].
      continue
    # Also populate top-level descriptions.
    merged_info.description = _merge_descriptions(custom_anomaly_info, None)
    if num_reasons > 1:
      merged_info.short_description = _MULTIPLE_ERRORS
    else:
      merged_info.short_description = (
          custom_anomaly_info.reason[0].short_description)
  return anomalies
234+
235+
192236
def validate_statistics(
193237
statistics: statistics_pb2.DatasetFeatureStatisticsList,
194238
schema: schema_pb2.Schema,
@@ -197,6 +241,8 @@ def validate_statistics(
197241
statistics_pb2.DatasetFeatureStatisticsList] = None,
198242
serving_statistics: Optional[
199243
statistics_pb2.DatasetFeatureStatisticsList] = None,
244+
custom_validation_config: Optional[
245+
custom_validation_config_pb2.CustomValidationConfig] = None
200246
) -> anomalies_pb2.Anomalies:
201247
"""Validates the input statistics against the provided input schema.
202248
@@ -248,6 +294,14 @@ def validate_statistics(
248294
distribution skew between current data and serving data. Configuration
249295
for skew detection can be done by specifying a `skew_comparator` in the
250296
schema.
297+
custom_validation_config: An optional config that can be used to specify
298+
custom validations to perform. If doing single-feature validations,
299+
the test feature will come from `statistics` and will be mapped to
300+
`feature` in the SQL query. If doing feature pair validations, the test
301+
feature will come from `statistics` and will be mapped to `feature_test`
302+
in the SQL query, and the base feature will come from
303+
`previous_statistics` and will be mapped to `feature_base` in the SQL
304+
query. Custom validations are not supported on Windows.
251305
252306
Returns:
253307
An Anomalies protocol buffer.
@@ -270,7 +324,9 @@ def validate_statistics(
270324
% type(previous_statistics).__name__)
271325

272326
return validate_statistics_internal(statistics, schema, environment,
273-
previous_statistics, serving_statistics)
327+
previous_statistics, serving_statistics,
328+
None, None, False,
329+
custom_validation_config)
274330

275331

276332
def validate_statistics_internal(
@@ -284,7 +340,9 @@ def validate_statistics_internal(
284340
previous_version_statistics: Optional[
285341
statistics_pb2.DatasetFeatureStatisticsList] = None,
286342
validation_options: Optional[vo.ValidationOptions] = None,
287-
enable_diff_regions: bool = False
343+
enable_diff_regions: bool = False,
344+
custom_validation_config: Optional[
345+
custom_validation_config_pb2.CustomValidationConfig] = None
288346
) -> anomalies_pb2.Anomalies:
289347
"""Validates the input statistics against the provided input schema.
290348
@@ -341,6 +399,14 @@ def validate_statistics_internal(
341399
enable_diff_regions: Specifies whether to include a comparison between the
342400
existing schema and the fixed schema in the Anomalies protocol buffer
343401
output.
402+
custom_validation_config: An optional config that can be used to specify
403+
custom validations to perform. If doing single-feature validations,
404+
the test feature will come from `statistics` and will be mapped to
405+
`feature` in the SQL query. If doing feature pair validations, the test
406+
feature will come from `statistics` and will be mapped to `feature_test`
407+
in the SQL query, and the base feature will come from
408+
`previous_span_statistics` and will be mapped to `feature_base` in the SQL
409+
query. Custom validations are not supported on Windows.
344410
345411
Returns:
346412
An Anomalies protocol buffer.
@@ -449,10 +515,23 @@ def validate_statistics_internal(
449515
# Parse the serialized Anomalies proto.
450516
result = anomalies_pb2.Anomalies()
451517
result.ParseFromString(anomalies_proto_string)
518+
519+
if custom_validation_config is not None:
520+
serialized_previous_statistics = previous_span_statistics.SerializeToString(
521+
) if previous_span_statistics is not None else ''
522+
custom_anomalies_string = (
523+
pywrap_tensorflow_data_validation.CustomValidateStatistics(
524+
tf.compat.as_bytes(statistics.SerializeToString()),
525+
tf.compat.as_bytes(serialized_previous_statistics),
526+
tf.compat.as_bytes(custom_validation_config.SerializeToString()),
527+
tf.compat.as_bytes(environment)))
528+
custom_anomalies = anomalies_pb2.Anomalies()
529+
custom_anomalies.ParseFromString(custom_anomalies_string)
530+
result = _merge_custom_anomalies(result, custom_anomalies)
531+
452532
return result
453533

454534

455-
# TODO(b/239095455): Also integrate with validate_statistics.
456535
def custom_validate_statistics(
457536
statistics: statistics_pb2.DatasetFeatureStatisticsList,
458537
validations: custom_validation_config_pb2.CustomValidationConfig,

tensorflow_data_validation/api/validation_api_test.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2087,6 +2087,82 @@ def test_validate_stats_invalid_previous_version_stats_multiple_datasets(
20872087
schema,
20882088
previous_version_statistics=previous_version_stats)
20892089

2090+
# Custom validation uses ZetaSQL, which cannot be compiled on Windows.
2091+
@unittest.skipIf(
2092+
sys.platform.startswith('win'),
2093+
'Custom validation is not supported on Windows.')
2094+
def test_validate_stats_with_custom_validations(self):
2095+
statistics = text_format.Parse(
2096+
"""
2097+
datasets{
2098+
num_examples: 10
2099+
features {
2100+
path { step: 'annotated_enum' }
2101+
type: STRING
2102+
string_stats {
2103+
common_stats {
2104+
num_missing: 3
2105+
num_non_missing: 7
2106+
min_num_values: 1
2107+
max_num_values: 1
2108+
}
2109+
unique: 3
2110+
rank_histogram {
2111+
buckets {
2112+
label: "D"
2113+
sample_count: 1
2114+
}
2115+
}
2116+
}
2117+
}
2118+
}
2119+
""", statistics_pb2.DatasetFeatureStatisticsList())
2120+
schema = text_format.Parse(
2121+
"""
2122+
feature {
2123+
name: 'annotated_enum'
2124+
type: BYTES
2125+
unique_constraints {
2126+
min: 4
2127+
max: 4
2128+
}
2129+
}
2130+
""", schema_pb2.Schema())
2131+
validation_config = text_format.Parse("""
2132+
feature_validations {
2133+
feature_path { step: 'annotated_enum' }
2134+
validations {
2135+
sql_expression: 'feature.string_stats.common_stats.num_missing < 3'
2136+
severity: WARNING
2137+
description: 'Feature has too many missing.'
2138+
}
2139+
}
2140+
""", custom_validation_config_pb2.CustomValidationConfig())
2141+
expected_anomalies = {
2142+
'annotated_enum':
2143+
text_format.Parse(
2144+
"""
2145+
path { step: 'annotated_enum' }
2146+
short_description: 'Multiple errors'
2147+
description: 'Expected at least 4 unique values but found only 3. Custom validation triggered anomaly. Query: feature.string_stats.common_stats.num_missing < 3 Test dataset: default slice'
2148+
severity: ERROR
2149+
reason {
2150+
type: FEATURE_TYPE_LOW_UNIQUE
2151+
short_description: 'Low number of unique values'
2152+
description: 'Expected at least 4 unique values but found only 3.'
2153+
}
2154+
reason {
2155+
type: CUSTOM_VALIDATION
2156+
short_description: 'Feature has too many missing.'
2157+
description: 'Custom validation triggered anomaly. Query: feature.string_stats.common_stats.num_missing < 3 Test dataset: default slice'
2158+
}
2159+
""", anomalies_pb2.AnomalyInfo())
2160+
}
2161+
anomalies = validation_api.validate_statistics(statistics, schema, None,
2162+
None, None,
2163+
validation_config)
2164+
self._assert_equal_anomalies(anomalies, expected_anomalies)
2165+
20902166
def test_validate_stats_internal_with_previous_version_stats(self):
20912167
statistics = text_format.Parse(
20922168
"""

0 commit comments

Comments
 (0)