Skip to content

Commit ab8752c

Browse files
committed
Improve beam proto test assertions.
PiperOrigin-RevId: 299384176
1 parent ab10366 commit ab8752c

File tree

1 file changed

+37
-21
lines changed

1 file changed

+37
-21
lines changed

tensorflow_data_validation/utils/test_util.py

Lines changed: 37 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -18,6 +18,8 @@
1818

1919
from __future__ import print_function
2020

21+
import traceback
22+
2123
from absl.testing import absltest
2224
import apache_beam as beam
2325
from apache_beam.testing import util
@@ -47,19 +49,24 @@ def _matcher(actual):
4749
"""Matcher function for comparing the example dicts."""
4850
try:
4951
# Check number of examples.
50-
test.assertEqual(len(actual), len(expected))
51-
52+
test.assertLen(actual, len(expected))
5253
for i in range(len(actual)):
5354
for key in actual[i]:
5455
# Check each feature value.
5556
if isinstance(expected[i][key], np.ndarray):
56-
test.assertEqual(actual[i][key].dtype, expected[i][key].dtype)
57+
test.assertEqual(
58+
expected[i][key].dtype, actual[i][key].dtype,
59+
'Expected dtype {}, found {} in actual[{}][{}]: {}'.format(
60+
expected[i][key].dtype, actual[i][key].dtype, i, key,
61+
actual[i][key]))
5762
np.testing.assert_equal(actual[i][key], expected[i][key])
5863
else:
59-
test.assertEqual(actual[i][key], expected[i][key])
64+
test.assertEqual(
65+
expected[i][key], actual[i][key],
66+
'Unexpected value of actual[{}][{}]'.format(i, key))
6067

61-
except AssertionError as e:
62-
raise util.BeamAssertException('Failed assert: ' + str(e))
68+
except AssertionError:
69+
raise util.BeamAssertException(traceback.format_exc())
6370

6471
return _matcher
6572

@@ -81,8 +88,9 @@ def make_dataset_feature_stats_list_proto_equal_fn(
8188
def _matcher(actual: List[statistics_pb2.DatasetFeatureStatisticsList]):
8289
"""Matcher function for comparing DatasetFeatureStatisticsList proto."""
8390
try:
84-
test.assertEqual(len(actual), 1)
85-
test.assertEqual(len(actual[0].datasets), len(expected_result.datasets))
91+
test.assertLen(actual, 1,
92+
'Expected exactly one DatasetFeatureStatisticsList')
93+
test.assertLen(actual[0].datasets, len(expected_result.datasets))
8694

8795
sorted_actual_datasets = sorted(actual[0].datasets, key=lambda d: d.name)
8896
sorted_expected_datasets = sorted(expected_result.datasets,
@@ -92,8 +100,8 @@ def _matcher(actual: List[statistics_pb2.DatasetFeatureStatisticsList]):
92100
assert_dataset_feature_stats_proto_equal(test,
93101
sorted_actual_datasets[i],
94102
sorted_expected_datasets[i])
95-
except AssertionError as e:
96-
raise util.BeamAssertException('Failed assert: ' + str(e))
103+
except AssertionError:
104+
raise util.BeamAssertException(traceback.format_exc())
97105

98106
return _matcher
99107

@@ -109,21 +117,21 @@ def assert_feature_proto_equal(
109117
expected: The expected feature proto.
110118
"""
111119

112-
test.assertEqual(len(actual.custom_stats), len(expected.custom_stats))
120+
test.assertLen(actual.custom_stats, len(expected.custom_stats))
113121
expected_custom_stats = {}
114122
for expected_custom_stat in expected.custom_stats:
115123
expected_custom_stats[expected_custom_stat.name] = expected_custom_stat
116124

117125
for actual_custom_stat in actual.custom_stats:
118-
test.assertTrue(actual_custom_stat.name in expected_custom_stats)
126+
test.assertIn(actual_custom_stat.name, expected_custom_stats)
119127
expected_custom_stat = expected_custom_stats[actual_custom_stat.name]
120128
compare.assertProtoEqual(
121-
test, actual_custom_stat, expected_custom_stat, normalize_numbers=True)
129+
test, expected_custom_stat, actual_custom_stat, normalize_numbers=True)
122130
del actual.custom_stats[:]
123131
del expected.custom_stats[:]
124132

125133
# Compare the rest of the proto without numeric custom stats
126-
compare.assertProtoEqual(test, actual, expected, normalize_numbers=True)
134+
compare.assertProtoEqual(test, expected, actual, normalize_numbers=True)
127135

128136

129137
def assert_dataset_feature_stats_proto_equal(
@@ -139,9 +147,14 @@ def assert_dataset_feature_stats_proto_equal(
139147
actual: The actual DatasetFeatureStatistics proto.
140148
expected: The expected DatasetFeatureStatistics proto.
141149
"""
142-
test.assertEqual(actual.name, expected.name)
143-
test.assertEqual(actual.num_examples, expected.num_examples)
144-
test.assertEqual(len(actual.features), len(expected.features))
150+
test.assertEqual(
151+
expected.name, actual.name, 'Expected name to be {}, found {} in '
152+
'DatasetFeatureStatistics {}'.format(expected.name, actual.name, actual))
153+
test.assertEqual(
154+
expected.num_examples, actual.num_examples,
155+
'Expected num_examples to be {}, found {} in DatasetFeatureStatsitics {}'
156+
.format(expected.num_examples, actual.num_examples, actual))
157+
test.assertLen(actual.features, len(expected.features))
145158

146159
expected_features = {}
147160
for feature in expected.features:
@@ -382,13 +395,16 @@ def _matcher(actual_tables):
382395
"""Arrow tables matcher fn."""
383396
test.assertLen(actual_tables, len(expected_tables))
384397
for i in range(len(expected_tables)):
385-
test.assertEqual(actual_tables[i].num_columns,
386-
expected_tables[i].num_columns)
398+
test.assertEqual(
399+
expected_tables[i].num_columns, actual_tables[i].num_columns,
400+
'Expected {} columns, found {} in table {}'.format(
401+
expected_tables[i].num_columns, actual_tables[i].num_columns,
402+
actual_tables[i]))
387403
for column_name, expected_column in zip(
388404
expected_tables[i].schema.names, expected_tables[i].columns):
389405
actual_column = actual_tables[i].column(column_name)
390-
test.assertEqual(len(actual_column.data.chunks),
391-
len(expected_column.data.chunks))
406+
test.assertLen(actual_column.data.chunks,
407+
len(expected_column.data.chunks))
392408
for j in range(len(expected_column.data.chunks)):
393409
actual_chunk = actual_column.data.chunk(j)
394410
expected_chunk = expected_column.data.chunk(j)

0 commit comments

Comments
 (0)