18
18
from __future__ import division
19
19
from __future__ import print_function
20
20
21
+ import itertools
21
22
import logging
22
23
from typing import Callable , Iterable , List , Optional , Text , Tuple , Set
23
24
import apache_beam as beam
58
59
anomalies_pb2 .AnomalyInfo .DATASET_HIGH_NUM_EXAMPLES ,
59
60
])
60
61
62
+ _MULTIPLE_ERRORS = 'Multiple errors'
63
+
61
64
62
65
def infer_schema (
63
66
statistics : statistics_pb2 .DatasetFeatureStatisticsList ,
@@ -189,6 +192,47 @@ def update_schema(schema: schema_pb2.Schema,
189
192
return result
190
193
191
194
195
+ def _merge_descriptions (
196
+ anomaly_info : anomalies_pb2 .AnomalyInfo ,
197
+ other_anomaly_info : Optional [anomalies_pb2 .AnomalyInfo ]) -> str :
198
+ """Merges anomaly descriptions."""
199
+ descriptions = []
200
+ if other_anomaly_info is not None :
201
+ for reason in itertools .chain (anomaly_info .reason ,
202
+ other_anomaly_info .reason ):
203
+ descriptions .append (reason .description )
204
+ else :
205
+ descriptions = [reason .description for reason in anomaly_info .reason ]
206
+ return ' ' .join (descriptions )
207
+
208
+
209
def _merge_custom_anomalies(
    anomalies: anomalies_pb2.Anomalies,
    custom_anomalies: anomalies_pb2.Anomalies) -> anomalies_pb2.Anomalies:
  """Merges custom_anomalies with anomalies.

  Mutates `anomalies` in place: every entry of `custom_anomalies.anomaly_info`
  is folded into `anomalies.anomaly_info`, either by merging with an existing
  entry under the same key or by copying it in wholesale.

  Args:
    anomalies: Anomalies proto to merge into (modified in place).
    custom_anomalies: Anomalies produced by custom validation to merge from.

  Returns:
    The same (mutated) `anomalies` proto, for caller convenience.
  """
  for key, custom_anomaly_info in custom_anomalies.anomaly_info.items():
    if key in anomalies.anomaly_info:
      # If the key is found in both inputs, we know it has multiple errors.
      anomalies.anomaly_info[key].short_description = _MULTIPLE_ERRORS
      anomalies.anomaly_info[key].description = _merge_descriptions(
          anomalies.anomaly_info[key], custom_anomaly_info)
      # Keep the higher of the two severities (numeric enum comparison —
      # presumably higher enum value means more severe; confirm against
      # anomalies.proto).
      anomalies.anomaly_info[key].severity = max(
          anomalies.anomaly_info[key].severity, custom_anomaly_info.severity)
      anomalies.anomaly_info[key].reason.extend(custom_anomaly_info.reason)
    else:
      # Protobuf map access auto-creates the entry; copy the custom info in.
      anomalies.anomaly_info[key].CopyFrom(custom_anomaly_info)
      # Also populate top-level descriptions.
      anomalies.anomaly_info[key].description = _merge_descriptions(
          custom_anomaly_info, None)
      if len(anomalies.anomaly_info[key].reason) > 1:
        anomalies.anomaly_info[key].short_description = _MULTIPLE_ERRORS
      else:
        # NOTE(review): assumes `custom_anomaly_info.reason` is non-empty —
        # an empty reason list would raise IndexError here; confirm that
        # custom validation always attaches at least one reason.
        anomalies.anomaly_info[
            key].short_description = custom_anomaly_info.reason[
                0].short_description
  return anomalies
234
+
235
+
192
236
def validate_statistics (
193
237
statistics : statistics_pb2 .DatasetFeatureStatisticsList ,
194
238
schema : schema_pb2 .Schema ,
@@ -197,6 +241,8 @@ def validate_statistics(
197
241
statistics_pb2 .DatasetFeatureStatisticsList ] = None ,
198
242
serving_statistics : Optional [
199
243
statistics_pb2 .DatasetFeatureStatisticsList ] = None ,
244
+ custom_validation_config : Optional [
245
+ custom_validation_config_pb2 .CustomValidationConfig ] = None
200
246
) -> anomalies_pb2 .Anomalies :
201
247
"""Validates the input statistics against the provided input schema.
202
248
@@ -248,6 +294,14 @@ def validate_statistics(
248
294
distribution skew between current data and serving data. Configuration
249
295
for skew detection can be done by specifying a `skew_comparator` in the
250
296
schema.
297
+ custom_validation_config: An optional config that can be used to specify
298
+ custom validations to perform. If doing single-feature validations,
299
+ the test feature will come from `statistics` and will be mapped to
300
+ `feature` in the SQL query. If doing feature pair validations, the test
301
+ feature will come from `statistics` and will be mapped to `feature_test`
302
+ in the SQL query, and the base feature will come from
303
+ `previous_statistics` and will be mapped to `feature_base` in the SQL
304
+ query. Custom validations are not supported on Windows.
251
305
252
306
Returns:
253
307
An Anomalies protocol buffer.
@@ -270,7 +324,9 @@ def validate_statistics(
270
324
% type (previous_statistics ).__name__ )
271
325
272
326
return validate_statistics_internal (statistics , schema , environment ,
273
- previous_statistics , serving_statistics )
327
+ previous_statistics , serving_statistics ,
328
+ None , None , False ,
329
+ custom_validation_config )
274
330
275
331
276
332
def validate_statistics_internal (
@@ -284,7 +340,9 @@ def validate_statistics_internal(
284
340
previous_version_statistics : Optional [
285
341
statistics_pb2 .DatasetFeatureStatisticsList ] = None ,
286
342
validation_options : Optional [vo .ValidationOptions ] = None ,
287
- enable_diff_regions : bool = False
343
+ enable_diff_regions : bool = False ,
344
+ custom_validation_config : Optional [
345
+ custom_validation_config_pb2 .CustomValidationConfig ] = None
288
346
) -> anomalies_pb2 .Anomalies :
289
347
"""Validates the input statistics against the provided input schema.
290
348
@@ -341,6 +399,14 @@ def validate_statistics_internal(
341
399
enable_diff_regions: Specifies whether to include a comparison between the
342
400
existing schema and the fixed schema in the Anomalies protocol buffer
343
401
output.
402
+ custom_validation_config: An optional config that can be used to specify
403
+ custom validations to perform. If doing single-feature validations,
404
+ the test feature will come from `statistics` and will be mapped to
405
+ `feature` in the SQL query. If doing feature pair validations, the test
406
+ feature will come from `statistics` and will be mapped to `feature_test`
407
+ in the SQL query, and the base feature will come from
408
+ `previous_statistics` and will be mapped to `feature_base` in the SQL
409
+ query. Custom validations are not supported on Windows.
344
410
345
411
Returns:
346
412
An Anomalies protocol buffer.
@@ -449,10 +515,23 @@ def validate_statistics_internal(
449
515
# Parse the serialized Anomalies proto.
450
516
result = anomalies_pb2 .Anomalies ()
451
517
result .ParseFromString (anomalies_proto_string )
518
+
519
+ if custom_validation_config is not None :
520
+ serialized_previous_statistics = previous_span_statistics .SerializeToString (
521
+ ) if previous_span_statistics is not None else ''
522
+ custom_anomalies_string = (
523
+ pywrap_tensorflow_data_validation .CustomValidateStatistics (
524
+ tf .compat .as_bytes (statistics .SerializeToString ()),
525
+ tf .compat .as_bytes (serialized_previous_statistics ),
526
+ tf .compat .as_bytes (custom_validation_config .SerializeToString ()),
527
+ tf .compat .as_bytes (environment )))
528
+ custom_anomalies = anomalies_pb2 .Anomalies ()
529
+ custom_anomalies .ParseFromString (custom_anomalies_string )
530
+ result = _merge_custom_anomalies (result , custom_anomalies )
531
+
452
532
return result
453
533
454
534
455
- # TODO(b/239095455): Also integrate with validate_statistics.
456
535
def custom_validate_statistics (
457
536
statistics : statistics_pb2 .DatasetFeatureStatisticsList ,
458
537
validations : custom_validation_config_pb2 .CustomValidationConfig ,
0 commit comments