Skip to content

Commit a9b9006

Browse files
safer comparison
Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com>
1 parent a1a41dc commit a9b9006

File tree

1 file changed

+51
-6
lines changed

1 file changed

+51
-6
lines changed

examples/experimental/comparator.py

Lines changed: 51 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,41 @@ def compare_cursor_description(
233233
sea_val,
234234
)
235235

236+
def _safe_compare(self, val1, val2):
    """
    Safely compare two values, handling lists, dicts, and complex types.

    Lists and tuples are compared element by element (a list and a tuple
    holding equal elements compare equal), dicts are compared key by key,
    and objects exposing ``asDict`` are compared through their dict form.
    Anything else falls back to ``==``. If that comparison raises
    ``ValueError``/``TypeError`` or yields a result without an unambiguous
    truth value, the string forms of the two values are compared instead.

    Args:
        val1: First value to compare.
        val2: Second value to compare.

    Returns True if values are equal, False otherwise.
    """
    try:
        # Handle None values: equal only when both sides are None
        if val1 is None or val2 is None:
            return val1 is None and val2 is None

        # For lists, tuples, and other sequences (but not strings)
        if isinstance(val1, (list, tuple)) and isinstance(val2, (list, tuple)):
            if len(val1) != len(val2):
                return False
            return all(self._safe_compare(v1, v2) for v1, v2 in zip(val1, val2))

        # For dictionaries
        if isinstance(val1, dict) and isinstance(val2, dict):
            if set(val1) != set(val2):
                return False
            return all(self._safe_compare(val1[k], val2[k]) for k in val1)

        # For objects exposing asDict (e.g. Row-like objects).
        # NOTE(review): anything that is a tuple subclass is already handled
        # by the sequence branch above, so this branch only applies to
        # non-tuple objects that provide asDict -- confirm this is intended.
        if hasattr(val1, "asDict") and hasattr(val2, "asDict"):
            return self._safe_compare(
                val1.asDict(recursive=True), val2.asDict(recursive=True)
            )

        # Default comparison. Coerce to bool *inside* the try block: array-like
        # operands return a non-bool from ``==`` and only raise ValueError when
        # their truth value is evaluated, which would otherwise happen in the
        # caller, outside this handler.
        return bool(val1 == val2)
    except (ValueError, TypeError):
        # If comparison fails (e.g., numpy arrays), convert to string
        return str(val1) == str(val2)
270+
236271
def compare_rows(
237272
self, thrift_rows: List[Row], sea_rows: List[Row], result: ComparisonResult
238273
):
@@ -264,9 +299,19 @@ def compare_rows(
264299
thrift_dict = thrift_row.asDict(recursive=True)
265300
sea_dict = sea_row.asDict(recursive=True)
266301

267-
if thrift_dict != sea_dict:
268-
# Find which fields differ
269-
all_fields = set(thrift_dict.keys()) | set(sea_dict.keys())
302+
# Check if dictionaries are different by comparing all fields
303+
all_fields = set(thrift_dict.keys()) | set(sea_dict.keys())
304+
dicts_differ = False
305+
306+
for field in all_fields:
307+
if field not in thrift_dict or field not in sea_dict:
308+
dicts_differ = True
309+
break
310+
elif not self._safe_compare(thrift_dict.get(field), sea_dict.get(field)):
311+
dicts_differ = True
312+
break
313+
314+
if dicts_differ:
270315

271316
for field in all_fields:
272317
thrift_value = thrift_dict.get(field)
@@ -276,7 +321,7 @@ def compare_rows(
276321
fields_missing_in_thrift.add(field)
277322
elif field not in sea_dict:
278323
fields_missing_in_sea.add(field)
279-
elif thrift_value != sea_value:
324+
elif not self._safe_compare(thrift_value, sea_value):
280325
if field not in field_value_mismatches:
281326
field_value_mismatches[field] = []
282327
field_value_mismatches[field].append(
@@ -308,8 +353,8 @@ def compare_rows(
308353
thrift_values = [m[1] for m in mismatches]
309354
sea_values = [m[2] for m in mismatches]
310355

311-
if all(v == thrift_values[0] for v in thrift_values) and all(
312-
v == sea_values[0] for v in sea_values
356+
if all(self._safe_compare(v, thrift_values[0]) for v in thrift_values) and all(
357+
self._safe_compare(v, sea_values[0]) for v in sea_values
313358
):
314359
result.add_difference(
315360
f"Field '{field}' value mismatch in all rows",

0 commit comments

Comments
 (0)