Skip to content

Commit 26a198b

Browse files
committed
Handle schema proto serialization.
PiperOrigin-RevId: 299223398
1 parent 646e1dd commit 26a198b

File tree

2 files changed

+86
-29
lines changed

2 files changed

+86
-29
lines changed

tensorflow_data_validation/statistics/stats_options.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from tensorflow_data_validation.statistics.generators import stats_generator
2828
from typing import List, Optional, Text
2929

30+
from google.protobuf import json_format
3031
from tensorflow_metadata.proto.v0 import schema_pb2
3132

3233

@@ -145,14 +146,17 @@ def to_json(self) -> Text:
145146
Custom generators and slice_functions are skipped, meaning that they will
146147
not be used when running TFDV in a setting where the stats options have been
147148
json-serialized, first. This will happen in the case where TFDV is run as a
148-
TFX component.
149+
TFX component. The schema proto will be json_encoded.
149150
150151
Returns:
151152
A JSON representation of a filtered version of __dict__.
152153
"""
153154
options_dict = copy.copy(self.__dict__)
154155
options_dict['_slice_functions'] = None
155156
options_dict['_generators'] = None
157+
if self.schema:
158+
del options_dict['_schema']
159+
options_dict['schema_json'] = json_format.MessageToJson(self.schema)
156160
return json.dumps(options_dict)
157161

158162
@classmethod
@@ -167,8 +171,13 @@ def from_json(cls, options_json: Text) -> 'StatsOptions':
167171
A StatsOptions instance constructed by setting the __dict__ attribute to
168172
the deserialized value of options_json.
169173
"""
174+
options_dict = json.loads(options_json)
175+
if 'schema_json' in options_dict:
176+
options_dict['_schema'] = json_format.Parse(options_dict['schema_json'],
177+
schema_pb2.Schema())
178+
del options_dict['schema_json']
170179
options = cls()
171-
options.__dict__ = json.loads(options_json)
180+
options.__dict__ = options_dict
172181
return options
173182

174183
@property

tensorflow_data_validation/statistics/stats_options_test.py

Lines changed: 75 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -17,15 +17,15 @@
1717
from __future__ import division
1818
from __future__ import print_function
1919

20-
import json
21-
2220
from absl.testing import absltest
2321
from absl.testing import parameterized
2422
from tensorflow_data_validation import types
2523
from tensorflow_data_validation.statistics import stats_options
2624
from tensorflow_data_validation.statistics.generators import lift_stats_generator
2725
from tensorflow_data_validation.utils import slicing_util
2826

27+
from tensorflow.python.util.protobuf import compare # pylint: disable=g-direct-tensorflow-import
28+
from tensorflow_metadata.proto.v0 import schema_pb2
2929

3030
INVALID_STATS_OPTIONS = [
3131
{
@@ -220,19 +220,81 @@ def test_stats_options(self, stats_options_kwargs, exception_type,
220220
with self.assertRaisesRegexp(exception_type, error_message):
221221
stats_options.StatsOptions(**stats_options_kwargs)
222222

223-
def test_stats_options_to_json(self):
223+
def test_stats_options_json_round_trip(self):
224+
generators = [
225+
lift_stats_generator.LiftStatsGenerator(
226+
schema=None,
227+
y_path=types.FeaturePath(['label']),
228+
x_paths=[types.FeaturePath(['feature'])])
229+
]
230+
feature_whitelist = ['a']
231+
schema = schema_pb2.Schema(feature=[schema_pb2.Feature(name='f')])
232+
label_feature = 'label'
233+
weight_feature = 'weight'
234+
slice_functions = [slicing_util.get_feature_value_slicer({'b': None})]
235+
sample_rate = 0.01
236+
num_top_values = 21
237+
frequency_threshold = 2
238+
weighted_frequency_threshold = 2.0
239+
num_rank_histogram_buckets = 1001
240+
num_values_histogram_buckets = 11
241+
num_histogram_buckets = 11
242+
num_quantiles_histogram_buckets = 11
243+
epsilon = 0.02
244+
infer_type_from_schema = True
245+
desired_batch_size = 100
246+
enable_semantic_domain_stats = True
247+
semantic_domain_stats_sample_rate = 0.1
248+
224249
options = stats_options.StatsOptions(
225-
generators=[
226-
lift_stats_generator.LiftStatsGenerator(
227-
schema=None,
228-
y_path=types.FeaturePath(['label']),
229-
x_paths=[types.FeaturePath(['feature'])])
230-
],
231-
slice_functions=[slicing_util.get_feature_value_slicer({'b': None})])
250+
generators=generators,
251+
feature_whitelist=feature_whitelist,
252+
schema=schema,
253+
label_feature=label_feature,
254+
weight_feature=weight_feature,
255+
slice_functions=slice_functions,
256+
sample_rate=sample_rate,
257+
num_top_values=num_top_values,
258+
frequency_threshold=frequency_threshold,
259+
weighted_frequency_threshold=weighted_frequency_threshold,
260+
num_rank_histogram_buckets=num_rank_histogram_buckets,
261+
num_values_histogram_buckets=num_values_histogram_buckets,
262+
num_histogram_buckets=num_histogram_buckets,
263+
num_quantiles_histogram_buckets=num_quantiles_histogram_buckets,
264+
epsilon=epsilon,
265+
infer_type_from_schema=infer_type_from_schema,
266+
desired_batch_size=desired_batch_size,
267+
enable_semantic_domain_stats=enable_semantic_domain_stats,
268+
semantic_domain_stats_sample_rate=semantic_domain_stats_sample_rate)
269+
232270
options_json = options.to_json()
233-
options_dict = json.loads(options_json)
234-
self.assertIsNone(options_dict['_generators'])
235-
self.assertIsNone(options_dict['_slice_functions'])
271+
options = stats_options.StatsOptions.from_json(options_json)
272+
273+
self.assertIsNone(options.generators)
274+
self.assertEqual(feature_whitelist, options.feature_whitelist)
275+
compare.assertProtoEqual(self, schema, options.schema)
276+
self.assertEqual(label_feature, options.label_feature)
277+
self.assertEqual(weight_feature, options.weight_feature)
278+
self.assertIsNone(options.slice_functions)
279+
self.assertEqual(sample_rate, options.sample_rate)
280+
self.assertEqual(num_top_values, options.num_top_values)
281+
self.assertEqual(frequency_threshold, options.frequency_threshold)
282+
self.assertEqual(weighted_frequency_threshold,
283+
options.weighted_frequency_threshold)
284+
self.assertEqual(num_rank_histogram_buckets,
285+
options.num_rank_histogram_buckets)
286+
self.assertEqual(num_values_histogram_buckets,
287+
options.num_values_histogram_buckets)
288+
self.assertEqual(num_histogram_buckets, options.num_histogram_buckets)
289+
self.assertEqual(num_quantiles_histogram_buckets,
290+
options.num_quantiles_histogram_buckets)
291+
self.assertEqual(epsilon, options.epsilon)
292+
self.assertEqual(infer_type_from_schema, options.infer_type_from_schema)
293+
self.assertEqual(desired_batch_size, options.desired_batch_size)
294+
self.assertEqual(enable_semantic_domain_stats,
295+
options.enable_semantic_domain_stats)
296+
self.assertEqual(semantic_domain_stats_sample_rate,
297+
options.semantic_domain_stats_sample_rate)
236298

237299
def test_stats_options_from_json(self):
238300
options_json = """{
@@ -261,20 +323,6 @@ def test_stats_options_from_json(self):
261323
expected_options_dict = stats_options.StatsOptions().__dict__
262324
self.assertEqual(expected_options_dict, actual_options.__dict__)
263325

264-
def test_stats_options_json_round_trip(self):
265-
options = stats_options.StatsOptions(
266-
generators=[
267-
lift_stats_generator.LiftStatsGenerator(
268-
schema=None,
269-
y_path=types.FeaturePath(['label']),
270-
x_paths=[types.FeaturePath(['feature'])])
271-
],
272-
slice_functions=[slicing_util.get_feature_value_slicer({'b': None})])
273-
options_json = options.to_json()
274-
options_from_json = stats_options.StatsOptions.from_json(options_json)
275-
self.assertIsNone(options_from_json.generators)
276-
self.assertIsNone(options_from_json.slice_functions)
277-
278326

279327
if __name__ == '__main__':
280328
absltest.main()

0 commit comments

Comments
 (0)