@@ -233,6 +233,41 @@ def compare_cursor_description(
233233 sea_val ,
234234 )
235235
236+ def _safe_compare (self , val1 , val2 ):
237+ """
238+ Safely compare two values, handling lists, dicts, and complex types.
239+
240+ Returns True if values are equal, False otherwise.
241+ """
242+ try :
243+ # Handle None values
244+ if val1 is None and val2 is None :
245+ return True
246+ if val1 is None or val2 is None :
247+ return False
248+
249+ # For lists, tuples, and other sequences (but not strings)
250+ if isinstance (val1 , (list , tuple )) and isinstance (val2 , (list , tuple )):
251+ if len (val1 ) != len (val2 ):
252+ return False
253+ return all (self ._safe_compare (v1 , v2 ) for v1 , v2 in zip (val1 , val2 ))
254+
255+ # For dictionaries
256+ if isinstance (val1 , dict ) and isinstance (val2 , dict ):
257+ if set (val1 .keys ()) != set (val2 .keys ()):
258+ return False
259+ return all (self ._safe_compare (val1 [k ], val2 [k ]) for k in val1 .keys ())
260+
261+ # For Row objects (which are tuples with special properties)
262+ if hasattr (val1 , 'asDict' ) and hasattr (val2 , 'asDict' ):
263+ return self ._safe_compare (val1 .asDict (recursive = True ), val2 .asDict (recursive = True ))
264+
265+ # Default comparison
266+ return val1 == val2
267+ except (ValueError , TypeError ) as e :
268+ # If comparison fails (e.g., numpy arrays), convert to string
269+ return str (val1 ) == str (val2 )
270+
236271 def compare_rows (
237272 self , thrift_rows : List [Row ], sea_rows : List [Row ], result : ComparisonResult
238273 ):
@@ -264,9 +299,19 @@ def compare_rows(
264299 thrift_dict = thrift_row .asDict (recursive = True )
265300 sea_dict = sea_row .asDict (recursive = True )
266301
267- if thrift_dict != sea_dict :
268- # Find which fields differ
269- all_fields = set (thrift_dict .keys ()) | set (sea_dict .keys ())
302+ # Check if dictionaries are different by comparing all fields
303+ all_fields = set (thrift_dict .keys ()) | set (sea_dict .keys ())
304+ dicts_differ = False
305+
306+ for field in all_fields :
307+ if field not in thrift_dict or field not in sea_dict :
308+ dicts_differ = True
309+ break
310+ elif not self ._safe_compare (thrift_dict .get (field ), sea_dict .get (field )):
311+ dicts_differ = True
312+ break
313+
314+ if dicts_differ :
270315
271316 for field in all_fields :
272317 thrift_value = thrift_dict .get (field )
@@ -276,7 +321,7 @@ def compare_rows(
276321 fields_missing_in_thrift .add (field )
277322 elif field not in sea_dict :
278323 fields_missing_in_sea .add (field )
279- elif thrift_value != sea_value :
324+ elif not self . _safe_compare ( thrift_value , sea_value ) :
280325 if field not in field_value_mismatches :
281326 field_value_mismatches [field ] = []
282327 field_value_mismatches [field ].append (
@@ -308,8 +353,8 @@ def compare_rows(
308353 thrift_values = [m [1 ] for m in mismatches ]
309354 sea_values = [m [2 ] for m in mismatches ]
310355
311- if all (v == thrift_values [0 ] for v in thrift_values ) and all (
312- v == sea_values [0 ] for v in sea_values
356+ if all (self . _safe_compare ( v , thrift_values [0 ]) for v in thrift_values ) and all (
357+ self . _safe_compare ( v , sea_values [0 ]) for v in sea_values
313358 ):
314359 result .add_difference (
315360 f"Field '{ field } ' value mismatch in all rows" ,
0 commit comments