Skip to content

Commit 276ab4c

Browse files
caveness authored and tfx-copybara committed
Add a custom_validate_statistics function to the TFDV validation API.
PiperOrigin-RevId: 488706802
1 parent d04d93b commit 276ab4c

File tree

16 files changed

+439
-10
lines changed

16 files changed

+439
-10
lines changed

.bazelrc

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Needed to work with ZetaSQL dependency.
2+
build --cxxopt="-std=c++17"
3+
4+
# Due to the invalid escape sequence in rules_foreign_cc
5+
# (e.g. "\(" in windows_commands.bzl) and the bazel 4.0.0 updates
6+
# (https://github.com/bazelbuild/bazel/commit/73402fa4aa5b9de46c9a4042b75e6fb332ad4a7f).
7+
build --incompatible_restrict_string_escapes=false
8+
9+
# icu@: In create_linking_context: in call to create_linking_context(),
10+
# parameter 'user_link_flags' is deprecated and will be removed soon.
11+
# It may be temporarily re-enabled by setting --incompatible_require_linker_input_cc_api=false
12+
build --incompatible_require_linker_input_cc_api=false
13+

RELEASE.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@
99

1010
## Major Features and Improvements
1111

12+
* Add a `custom_validate_statistics` function to the validation API. Note that
13+
this function is not available on Windows.
14+
1215
## Bug Fixes and Other Changes
1316

1417
* Fix bug in implementation of `semantic_domain_stats_sample_rate`.

WORKSPACE

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,63 @@ load("@rules_proto//proto:repositories.bzl", "rules_proto_dependencies", "rules_
6464
rules_proto_dependencies()
6565
rules_proto_toolchains()
6666

67+
# TODO(b/239095455): Change to using a tfx-bsl workspace macro to load these
68+
# dependencies.
69+
# Needed by zetasql.
70+
PROTOBUF_COMMIT = "fde7cf7358ec7cd69e8db9be4f1fa6a5c431386a" # 3.13.0
71+
http_archive(
72+
name = "com_google_protobuf",
73+
sha256 = "e589e39ef46fb2b3b476b3ca355bd324e5984cbdfac19f0e1625f0042e99c276",
74+
strip_prefix = "protobuf-%s" % PROTOBUF_COMMIT,
75+
urls = [
76+
"https://storage.googleapis.com/grpc-bazel-mirror/github.com/google/protobuf/archive/%s.tar.gz" % PROTOBUF_COMMIT,
77+
"https://github.com/google/protobuf/archive/%s.tar.gz" % PROTOBUF_COMMIT,
78+
],
79+
)
80+
81+
# Needed by abseil-py, which is in turn needed by zetasql.
82+
http_archive(
83+
name = "six_archive",
84+
urls = [
85+
"http://mirror.bazel.build/pypi.python.org/packages/source/s/six/six-1.10.0.tar.gz",
86+
"https://pypi.python.org/packages/source/s/six/six-1.10.0.tar.gz",
87+
],
88+
sha256 = "105f8d68616f8248e24bf0e9372ef04d3cc10104f1980f54d57b2ce73a5ad56a",
89+
strip_prefix = "six-1.10.0",
90+
build_file = "//third_party:six.BUILD"
91+
)
92+
93+
ABSL_COMMIT = "e1d388e7e74803050423d035e4374131b9b57919" # lts_20210324.1
94+
http_archive(
95+
name = "com_google_absl",
96+
urls = ["https://github.com/abseil/abseil-cpp/archive/%s.zip" % ABSL_COMMIT],
97+
sha256 = "baebd1536bec56ae7d7c060c20c01af89ecba2c0b1bc8992b652520655395f94",
98+
strip_prefix = "abseil-cpp-%s" % ABSL_COMMIT,
99+
)
100+
101+
ZETASQL_COMMIT = "5ccb05880e72ab9ff75dd6b05d7b0acce53f1ea2" # 04/22/2021
102+
http_archive(
103+
name = "com_google_zetasql",
104+
urls = ["https://github.com/google/zetasql/archive/%s.zip" % ZETASQL_COMMIT],
105+
strip_prefix = "zetasql-%s" % ZETASQL_COMMIT,
106+
sha256 = "4ca4e45f457926484822701ec15ca4d0172b01d7ce43c0b34c6f3ab98c95b241",
107+
)
108+
109+
load("@com_google_zetasql//bazel:zetasql_deps_step_1.bzl", "zetasql_deps_step_1")
110+
111+
zetasql_deps_step_1()
112+
113+
load("@com_google_zetasql//bazel:zetasql_deps_step_2.bzl", "zetasql_deps_step_2")
114+
115+
zetasql_deps_step_2(
116+
analyzer_deps = True,
117+
evaluator_deps = True,
118+
tools_deps = False,
119+
java_deps = False,
120+
testing_deps = False,
121+
)
122+
123+
67124
# Please add all new TensorFlow Data Validation dependencies in workspace.bzl.
68125
load("//tensorflow_data_validation:workspace.bzl", "tf_data_validation_workspace")
69126

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def finalize_options(self):
7373
'installation instruction.')
7474
self._additional_build_options = []
7575
if platform.system() == 'Darwin':
76-
self._additional_build_options = ['--macos_minimum_os=10.9']
76+
self._additional_build_options = ['--macos_minimum_os=10.14']
7777
elif platform.system() == 'Windows':
7878
self._additional_build_options = ['--copt=-DWIN32_LEAN_AND_MEAN']
7979

tensorflow_data_validation/BUILD

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
licenses(["notice"]) # Apache 2.0
22

3+
load("@bazel_skylib//lib:selects.bzl", "selects")
4+
35
config_setting(
46
name = "windows",
57
constraint_values = [
@@ -13,12 +15,14 @@ sh_binary(
1315
data = select({
1416
":windows": [
1517
"//tensorflow_data_validation/skew/protos:feature_skew_results_pb2.py",
18+
"//tensorflow_data_validation/anomalies/proto:custom_validation_config_pb2.py",
1619
"//tensorflow_data_validation/anomalies/proto:validation_config_pb2.py",
1720
"//tensorflow_data_validation/anomalies/proto:validation_metadata_pb2.py",
1821
"//tensorflow_data_validation/pywrap:tensorflow_data_validation_extension.pyd",
1922
],
2023
"//conditions:default": [
2124
"//tensorflow_data_validation/skew/protos:feature_skew_results_pb2.py",
25+
"//tensorflow_data_validation/anomalies/proto:custom_validation_config_pb2.py",
2226
"//tensorflow_data_validation/anomalies/proto:validation_config_pb2.py",
2327
"//tensorflow_data_validation/anomalies/proto:validation_metadata_pb2.py",
2428
"//tensorflow_data_validation/pywrap:tensorflow_data_validation_extension.so",

tensorflow_data_validation/anomalies/BUILD

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -439,10 +439,11 @@ cc_library(
439439
deps = [
440440
":schema",
441441
":statistics_view",
442-
"//tensorflow_data_validation/anomalies/proto:custom_validation_config_cc_proto",
443-
"//third_party/py/tfx_bsl/cc/statistics:sql_util",
442+
"//tensorflow_data_validation/anomalies/proto:custom_validation_config_proto",
444443
"@com_github_tensorflow_metadata//tensorflow_metadata/proto/v0:metadata_v0_proto_cc_pb2",
444+
"@com_github_tfx_bsl//tfx_bsl/cc/statistics:sql_util",
445445
"@com_google_absl//absl/container:flat_hash_map",
446+
"@com_google_absl//absl/types:optional",
446447
"@org_tensorflow//tensorflow/core:lib",
447448
],
448449
)

tensorflow_data_validation/anomalies/custom_validation.cc

Lines changed: 44 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,16 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
See the License for the specific language governing permissions and
1313
limitations under the License.
1414
==============================================================================*/
15-
1615
#include "tensorflow_data_validation/anomalies/custom_validation.h"
1716

1817
#include "absl/container/flat_hash_map.h"
1918
#include "tensorflow_data_validation/anomalies/schema_util.h"
2019
#include "tensorflow_data_validation/anomalies/statistics_view.h"
21-
#include "third_party/py/tfx_bsl/cc/statistics/sql_util.h"
20+
#include "tfx_bsl/cc/statistics/sql_util.h"
2221
#include "tensorflow/core/lib/core/errors.h"
23-
#include "tensorflow/core/platform/errors.h"
24-
#include "tensorflow/tsl/platform/errors.h"
2522
#include "tensorflow_metadata/proto/v0/anomalies.pb.h"
2623
#include "tensorflow_metadata/proto/v0/path.pb.h"
2724
#include "tensorflow_metadata/proto/v0/statistics.pb.h"
@@ -255,5 +252,48 @@ Status CustomValidateStatistics(
255252
return tensorflow::Status();
256253
}
257254

255+
Status CustomValidateStatisticsWithSerializedInputs(
256+
const std::string& serialized_test_statistics,
257+
const std::string& serialized_base_statistics,
258+
const std::string& serialized_validations,
259+
const std::string& serialized_environment,
260+
std::string* serialized_anomalies_proto) {
261+
metadata::v0::DatasetFeatureStatisticsList test_statistics;
262+
metadata::v0::DatasetFeatureStatisticsList base_statistics;
263+
metadata::v0::DatasetFeatureStatisticsList* base_statistics_ptr = nullptr;
264+
if (!test_statistics.ParseFromString(serialized_test_statistics)) {
265+
return tensorflow::errors::InvalidArgument(
266+
"Failed to parse DatasetFeatureStatistics proto.");
267+
}
268+
if (!serialized_base_statistics.empty()) {
269+
if (!base_statistics.ParseFromString(serialized_base_statistics)) {
270+
return tensorflow::errors::InvalidArgument(
271+
"Failed to parse DatasetFeatureStatistics proto.");
272+
}
273+
base_statistics_ptr = &base_statistics;
274+
}
275+
CustomValidationConfig validations;
276+
if (!validations.ParseFromString(serialized_validations)) {
277+
return tensorflow::errors::InvalidArgument(
278+
"Failed to parse CustomValidationConfig proto.");
279+
}
280+
absl::optional<std::string> environment = absl::nullopt;
281+
if (!serialized_environment.empty()) {
282+
environment = serialized_environment;
283+
}
284+
metadata::v0::Anomalies anomalies;
285+
const tensorflow::Status status =
286+
CustomValidateStatistics(test_statistics, base_statistics_ptr,
287+
validations, environment, &anomalies);
288+
if (!status.ok()) {
289+
return tensorflow::errors::Internal("Failed to run custom validations.");
290+
}
291+
if (!anomalies.SerializeToString(serialized_anomalies_proto)) {
292+
return tensorflow::errors::Internal(
293+
"Failed to serialize Anomalies output proto to string.");
294+
}
295+
return tensorflow::Status();
296+
}
297+
258298
} // namespace data_validation
259299
} // namespace tensorflow

tensorflow_data_validation/anomalies/custom_validation.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ limitations under the License.
1515
#ifndef THIRD_PARTY_PY_TENSORFLOW_DATA_VALIDATION_ANOMALIES_CUSTOM_VALIDATION_H_
1616
#define THIRD_PARTY_PY_TENSORFLOW_DATA_VALIDATION_ANOMALIES_CUSTOM_VALIDATION_H_
1717

18+
#include "absl/types/optional.h"
1819
#include "tensorflow_data_validation/anomalies/proto/custom_validation_config.pb.h"
1920
#include "tensorflow/core/lib/core/status.h"
2021
#include "tensorflow_metadata/proto/v0/anomalies.pb.h"
@@ -33,6 +34,15 @@ Status CustomValidateStatistics(
3334
const CustomValidationConfig& validations,
3435
const absl::optional<string> environment, metadata::v0::Anomalies* result);
3536

37+
// Like CustomValidateStatistics but with serialized inputs. Used for doing
38+
// custom validation in Python.
39+
Status CustomValidateStatisticsWithSerializedInputs(
40+
const std::string& serialized_test_statistics,
41+
const std::string& serialized_base_statistics,
42+
const std::string& serialized_validations,
43+
const std::string& serialized_environment,
44+
std::string* serialized_anomalies_proto);
45+
3646
} // namespace data_validation
3747
} // namespace tensorflow
3848

tensorflow_data_validation/api/validation_api.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import tensorflow as tf
2626
from tensorflow_data_validation import constants
2727
from tensorflow_data_validation import types
28+
from tensorflow_data_validation.anomalies.proto import custom_validation_config_pb2
2829
from tensorflow_data_validation.anomalies.proto import validation_config_pb2
2930
from tensorflow_data_validation.anomalies.proto import validation_metadata_pb2
3031
from tensorflow_data_validation.api import validation_options as vo
@@ -451,6 +452,58 @@ def validate_statistics_internal(
451452
return result
452453

453454

455+
# TODO(b/239095455): Also integrate with validate_statistics.
456+
def custom_validate_statistics(
457+
statistics: statistics_pb2.DatasetFeatureStatisticsList,
458+
validations: custom_validation_config_pb2.CustomValidationConfig,
459+
baseline_statistics: Optional[
460+
statistics_pb2.DatasetFeatureStatisticsList] = None,
461+
environment: Optional[str] = None) -> anomalies_pb2.Anomalies:
462+
"""Validates the input statistics with the user-supplied SQL queries.
463+
464+
If the SQL query from a user-supplied validation returns False, TFDV will
465+
return an anomaly for that validation. In single feature validations, the test
466+
feature will be mapped to `feature` in the SQL query. In two feature
467+
validations, the test feature will be mapped to `feature_test` in the SQL
468+
query, and the base feature will be mapped to `feature_base`.
469+
470+
If an optional `environment` is supplied, TFDV will run validations with
471+
that environment specified and validations with no environment specified.
472+
473+
NOTE: This function is not supported on Windows.
474+
475+
Args:
476+
statistics: A DatasetFeatureStatisticsList protocol buffer that holds the
477+
statistics to validate.
478+
validations: Configuration that specifies the dataset(s) and feature(s) to
479+
validate and the SQL query to use for the validation. The SQL query must
480+
return a boolean value.
481+
baseline_statistics: An optional DatasetFeatureStatisticsList protocol
482+
buffer that holds the baseline statistics used when validating feature
483+
pairs.
484+
environment: If supplied, TFDV will run validations with that
485+
environment specified and validations with no environment specified. If
486+
not supplied, TFDV will run all validations.
487+
Returns:
488+
An Anomalies protocol buffer.
489+
"""
490+
serialized_statistics = statistics.SerializeToString()
491+
serialized_baseline_statistics = (
492+
baseline_statistics.SerializeToString()
493+
if baseline_statistics is not None else '')
494+
serialized_validations = validations.SerializeToString()
495+
environment = '' if environment is None else environment
496+
serialized_anomalies = (
497+
pywrap_tensorflow_data_validation.CustomValidateStatistics(
498+
tf.compat.as_bytes(serialized_statistics),
499+
tf.compat.as_bytes(serialized_baseline_statistics),
500+
tf.compat.as_bytes(serialized_validations),
501+
tf.compat.as_bytes(environment)))
502+
result = anomalies_pb2.Anomalies()
503+
result.ParseFromString(serialized_anomalies)
504+
return result
505+
506+
454507
def _remove_features_missing_common_stats(
455508
stats: statistics_pb2.DatasetFeatureStatistics
456509
) -> statistics_pb2.DatasetFeatureStatistics:

0 commit comments

Comments
 (0)