Skip to content

Commit 4cc4826

Browse files
committed
Fix bug around handling WeightedFeatures in validation.
PiperOrigin-RevId: 290172331
1 parent 42d55da commit 4cc4826

File tree

3 files changed

+87
-3
lines changed

3 files changed

+87
-3
lines changed

tensorflow_data_validation/anomalies/schema.cc

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ limitations under the License.
4141
#include "tensorflow/core/platform/protobuf.h"
4242
#include "tensorflow/core/platform/types.h"
4343
#include "tensorflow_metadata/proto/v0/anomalies.pb.h"
44+
#include "tensorflow_metadata/proto/v0/schema.pb.h"
4445
#include "tensorflow_metadata/proto/v0/statistics.pb.h"
4546

4647
namespace tensorflow {
@@ -55,6 +56,7 @@ using ::tensorflow::metadata::v0::Feature;
5556
using ::tensorflow::metadata::v0::FeatureNameStatistics;
5657
using ::tensorflow::metadata::v0::SparseFeature;
5758
using ::tensorflow::metadata::v0::StringDomain;
59+
using ::tensorflow::metadata::v0::WeightedFeature;
5860
using PathProto = ::tensorflow::metadata::v0::Path;
5961

6062
constexpr char kTrainingServingSkew[] = "Training/Serving skew";
@@ -205,6 +207,13 @@ tensorflow::Status Schema::UpdateFeature(
205207
Feature* feature = GetExistingFeature(feature_stats_view.GetPath());
206208
SparseFeature* sparse_feature =
207209
GetExistingSparseFeature(feature_stats_view.GetPath());
210+
const WeightedFeature* weighted_feature =
211+
GetExistingWeightedFeature(feature_stats_view.GetPath());
212+
if (weighted_feature != nullptr) {
213+
// TODO(b/141961105): Add validation logic for weighted features.
214+
return Status::OK();
215+
}
216+
208217
if (sparse_feature != nullptr &&
209218
!::tensorflow::data_validation::SparseFeatureIsDeprecated(
210219
*sparse_feature)) {
@@ -554,7 +563,8 @@ tensorflow::metadata::v0::Schema Schema::GetSchema() const { return schema_; }
554563

555564
bool Schema::FeatureExists(const Path& path) {
556565
return GetExistingFeature(path) != nullptr ||
557-
GetExistingSparseFeature(path) != nullptr;
566+
GetExistingSparseFeature(path) != nullptr ||
567+
GetExistingWeightedFeature(path) != nullptr;
558568
}
559569

560570
Feature* Schema::GetExistingFeature(const Path& path) {
@@ -595,6 +605,23 @@ SparseFeature* Schema::GetExistingSparseFeature(const Path& path) {
595605
parent_feature->mutable_struct_domain()->mutable_sparse_feature());
596606
}
597607
}
608+
609+
// Looks up the weighted feature named by `path` among the schema's
// weighted_feature entries. Returns a pointer to the matching entry, or
// nullptr when no entry matches. CHECK-fails if `path` is empty.
const WeightedFeature* Schema::GetExistingWeightedFeature(
    const Path& path) const {
  CHECK(!path.empty());
  // Weighted features are always top-level features, so only a single-step
  // path can refer to one; anything longer falls through to nullptr.
  if (path.size() == 1) {
    const auto& wanted_name = path.last_step();
    for (const WeightedFeature& candidate : schema_.weighted_feature()) {
      if (candidate.name() == wanted_name) {
        return &candidate;
      }
    }
  }
  return nullptr;
}
624+
598625
Feature* Schema::GetNewFeature(const Path& path) {
599626
CHECK(!path.empty());
600627
if (path.size() > 1) {

tensorflow_data_validation/anomalies/schema.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,7 @@ class Schema {
168168
using Feature = tensorflow::metadata::v0::Feature;
169169
using SparseFeature = tensorflow::metadata::v0::SparseFeature;
170170
using StringDomain = tensorflow::metadata::v0::StringDomain;
171+
using WeightedFeature = tensorflow::metadata::v0::WeightedFeature;
171172
// Updates Schema given new data, but only on the columns specified.
172173
// If you have a new, previously unseen column on the list of columns to
173174
// consider, then config is used to create it.
@@ -226,12 +227,15 @@ class Schema {
226227
// values if necessary.
227228
StringDomain* GetStringDomain(const string& name);
228229

229-
// Gets an existing feature, and returns null if it doesn't exist.
230+
// Gets an existing feature, or returns null if it doesn't exist.
230231
Feature* GetExistingFeature(const Path& path);
231232

232-
// Gets an existing sparse feature, and returns null if it doesn't exist.
233+
// Gets an existing sparse feature, or returns null if it doesn't exist.
233234
SparseFeature* GetExistingSparseFeature(const Path& path);
234235

236+
// Gets an existing weighted feature, or returns null if it doesn't exist.
237+
const WeightedFeature* GetExistingWeightedFeature(const Path& path) const;
238+
235239
// Gets a new feature. Assumes that the feature does not already exist.
236240
Feature* GetNewFeature(const Path& path);
237241

tensorflow_data_validation/api/validation_api_test.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -857,6 +857,59 @@ def test_validate_stats(self):
857857
anomalies = validation_api.validate_statistics(statistics, schema)
858858
self._assert_equal_anomalies(anomalies, expected_anomalies)
859859

860+
def test_validate_stats_weighted_feature(self):
  # Smoke test only: it is not meant to verify the anomaly detection logic.
  # It just ensures that the validation API does not fail when the statistics
  # carry custom stats for a weighted feature.
  schema_text = """
      feature {
        name: "f"
      }
      feature {
        name: "w"
      }
      weighted_feature {
        name: "weighted_feature"
        feature {
          step: "f"
        }
        weight_feature {
          step: "w"
        }
      }
      """
  statistics_text = """
      datasets {
        num_examples: 10
        features {
          path { step: 'weighted_feature' }
          custom_stats {
            name: 'missing_weight'
            num: 0.0
          }
          custom_stats {
            name: 'missing_value'
            num: 0.0
          }
          custom_stats {
            name: 'min_weight_length_diff'
            num: 0.0
          }
          custom_stats {
            name: 'max_weight_length_diff'
            num: 0.0
          }
        }
      }
      """
  schema = text_format.Parse(schema_text, schema_pb2.Schema())
  statistics = text_format.Parse(
      statistics_text, statistics_pb2.DatasetFeatureStatisticsList())

  # Validate the stats; the expected result is an empty set of anomalies.
  anomalies = validation_api.validate_statistics(statistics, schema)
  self._assert_equal_anomalies(anomalies, {})
912+
860913
# pylint: disable=line-too-long
861914
_annotated_enum_anomaly_info = """
862915
path {

0 commit comments

Comments
 (0)