|
17 | 17 | from __future__ import division
|
18 | 18 | from __future__ import print_function
|
19 | 19 |
|
20 |
| -import json |
21 |
| - |
22 | 20 | from absl.testing import absltest
|
23 | 21 | from absl.testing import parameterized
|
24 | 22 | from tensorflow_data_validation import types
|
25 | 23 | from tensorflow_data_validation.statistics import stats_options
|
26 | 24 | from tensorflow_data_validation.statistics.generators import lift_stats_generator
|
27 | 25 | from tensorflow_data_validation.utils import slicing_util
|
28 | 26 |
|
| 27 | +from tensorflow.python.util.protobuf import compare # pylint: disable=g-direct-tensorflow-import |
| 28 | +from tensorflow_metadata.proto.v0 import schema_pb2 |
29 | 29 |
|
30 | 30 | INVALID_STATS_OPTIONS = [
|
31 | 31 | {
|
@@ -220,19 +220,81 @@ def test_stats_options(self, stats_options_kwargs, exception_type,
|
220 | 220 | with self.assertRaisesRegexp(exception_type, error_message):
|
221 | 221 | stats_options.StatsOptions(**stats_options_kwargs)
|
222 | 222 |
|
223 |
| - def test_stats_options_to_json(self): |
| 223 | + def test_stats_options_json_round_trip(self): |
| 224 | + generators = [ |
| 225 | + lift_stats_generator.LiftStatsGenerator( |
| 226 | + schema=None, |
| 227 | + y_path=types.FeaturePath(['label']), |
| 228 | + x_paths=[types.FeaturePath(['feature'])]) |
| 229 | + ] |
| 230 | + feature_whitelist = ['a'] |
| 231 | + schema = schema_pb2.Schema(feature=[schema_pb2.Feature(name='f')]) |
| 232 | + label_feature = 'label' |
| 233 | + weight_feature = 'weight' |
| 234 | + slice_functions = [slicing_util.get_feature_value_slicer({'b': None})] |
| 235 | + sample_rate = 0.01 |
| 236 | + num_top_values = 21 |
| 237 | + frequency_threshold = 2 |
| 238 | + weighted_frequency_threshold = 2.0 |
| 239 | + num_rank_histogram_buckets = 1001 |
| 240 | + num_values_histogram_buckets = 11 |
| 241 | + num_histogram_buckets = 11 |
| 242 | + num_quantiles_histogram_buckets = 11 |
| 243 | + epsilon = 0.02 |
| 244 | + infer_type_from_schema = True |
| 245 | + desired_batch_size = 100 |
| 246 | + enable_semantic_domain_stats = True |
| 247 | + semantic_domain_stats_sample_rate = 0.1 |
| 248 | + |
224 | 249 | options = stats_options.StatsOptions(
|
225 |
| - generators=[ |
226 |
| - lift_stats_generator.LiftStatsGenerator( |
227 |
| - schema=None, |
228 |
| - y_path=types.FeaturePath(['label']), |
229 |
| - x_paths=[types.FeaturePath(['feature'])]) |
230 |
| - ], |
231 |
| - slice_functions=[slicing_util.get_feature_value_slicer({'b': None})]) |
| 250 | + generators=generators, |
| 251 | + feature_whitelist=feature_whitelist, |
| 252 | + schema=schema, |
| 253 | + label_feature=label_feature, |
| 254 | + weight_feature=weight_feature, |
| 255 | + slice_functions=slice_functions, |
| 256 | + sample_rate=sample_rate, |
| 257 | + num_top_values=num_top_values, |
| 258 | + frequency_threshold=frequency_threshold, |
| 259 | + weighted_frequency_threshold=weighted_frequency_threshold, |
| 260 | + num_rank_histogram_buckets=num_rank_histogram_buckets, |
| 261 | + num_values_histogram_buckets=num_values_histogram_buckets, |
| 262 | + num_histogram_buckets=num_histogram_buckets, |
| 263 | + num_quantiles_histogram_buckets=num_quantiles_histogram_buckets, |
| 264 | + epsilon=epsilon, |
| 265 | + infer_type_from_schema=infer_type_from_schema, |
| 266 | + desired_batch_size=desired_batch_size, |
| 267 | + enable_semantic_domain_stats=enable_semantic_domain_stats, |
| 268 | + semantic_domain_stats_sample_rate=semantic_domain_stats_sample_rate) |
| 269 | + |
232 | 270 | options_json = options.to_json()
|
233 |
| - options_dict = json.loads(options_json) |
234 |
| - self.assertIsNone(options_dict['_generators']) |
235 |
| - self.assertIsNone(options_dict['_slice_functions']) |
| 271 | + options = stats_options.StatsOptions.from_json(options_json) |
| 272 | + |
| 273 | + self.assertIsNone(options.generators) |
| 274 | + self.assertEqual(feature_whitelist, options.feature_whitelist) |
| 275 | + compare.assertProtoEqual(self, schema, options.schema) |
| 276 | + self.assertEqual(label_feature, options.label_feature) |
| 277 | + self.assertEqual(weight_feature, options.weight_feature) |
| 278 | + self.assertIsNone(options.slice_functions) |
| 279 | + self.assertEqual(sample_rate, options.sample_rate) |
| 280 | + self.assertEqual(num_top_values, options.num_top_values) |
| 281 | + self.assertEqual(frequency_threshold, options.frequency_threshold) |
| 282 | + self.assertEqual(weighted_frequency_threshold, |
| 283 | + options.weighted_frequency_threshold) |
| 284 | + self.assertEqual(num_rank_histogram_buckets, |
| 285 | + options.num_rank_histogram_buckets) |
| 286 | + self.assertEqual(num_values_histogram_buckets, |
| 287 | + options.num_values_histogram_buckets) |
| 288 | + self.assertEqual(num_histogram_buckets, options.num_histogram_buckets) |
| 289 | + self.assertEqual(num_quantiles_histogram_buckets, |
| 290 | + options.num_quantiles_histogram_buckets) |
| 291 | + self.assertEqual(epsilon, options.epsilon) |
| 292 | + self.assertEqual(infer_type_from_schema, options.infer_type_from_schema) |
| 293 | + self.assertEqual(desired_batch_size, options.desired_batch_size) |
| 294 | + self.assertEqual(enable_semantic_domain_stats, |
| 295 | + options.enable_semantic_domain_stats) |
| 296 | + self.assertEqual(semantic_domain_stats_sample_rate, |
| 297 | + options.semantic_domain_stats_sample_rate) |
236 | 298 |
|
237 | 299 | def test_stats_options_from_json(self):
|
238 | 300 | options_json = """{
|
@@ -261,20 +323,6 @@ def test_stats_options_from_json(self):
|
261 | 323 | expected_options_dict = stats_options.StatsOptions().__dict__
|
262 | 324 | self.assertEqual(expected_options_dict, actual_options.__dict__)
|
263 | 325 |
|
264 |
| - def test_stats_options_json_round_trip(self): |
265 |
| - options = stats_options.StatsOptions( |
266 |
| - generators=[ |
267 |
| - lift_stats_generator.LiftStatsGenerator( |
268 |
| - schema=None, |
269 |
| - y_path=types.FeaturePath(['label']), |
270 |
| - x_paths=[types.FeaturePath(['feature'])]) |
271 |
| - ], |
272 |
| - slice_functions=[slicing_util.get_feature_value_slicer({'b': None})]) |
273 |
| - options_json = options.to_json() |
274 |
| - options_from_json = stats_options.StatsOptions.from_json(options_json) |
275 |
| - self.assertIsNone(options_from_json.generators) |
276 |
| - self.assertIsNone(options_from_json.slice_functions) |
277 |
| - |
278 | 326 |
|
279 | 327 | if __name__ == '__main__':
|
280 | 328 | absltest.main()
|
0 commit comments