@@ -235,8 +235,8 @@ def compare_cursor_description(
235235
236236 def _safe_compare (self , val1 , val2 ):
237237 """
238- Safely compare two values, handling lists, dicts, and complex types .
239-
238+ Safely compare two values, handling Row objects and PyArrow tables .
239+
240240 Returns True if values are equal, False otherwise.
241241 """
242242 try :
@@ -245,28 +245,40 @@ def _safe_compare(self, val1, val2):
245245 return True
246246 if val1 is None or val2 is None :
247247 return False
248-
249- # For lists, tuples, and other sequences (but not strings)
248+
249+ # For Row objects, convert to dictionaries
250+ if hasattr (val1 , "asDict" ) and hasattr (val2 , "asDict" ):
251+ return self ._safe_compare (
252+ val1 .asDict (recursive = True ), val2 .asDict (recursive = True )
253+ )
254+
255+ # For PyArrow arrays/tables
256+ if hasattr (val1 , "to_pylist" ) and hasattr (val2 , "to_pylist" ):
257+ return val1 .to_pylist () == val2 .to_pylist ()
258+
259+ # For lists and tuples
250260 if isinstance (val1 , (list , tuple )) and isinstance (val2 , (list , tuple )):
251261 if len (val1 ) != len (val2 ):
252262 return False
253263 return all (self ._safe_compare (v1 , v2 ) for v1 , v2 in zip (val1 , val2 ))
254-
264+
255265 # For dictionaries
256266 if isinstance (val1 , dict ) and isinstance (val2 , dict ):
257267 if set (val1 .keys ()) != set (val2 .keys ()):
258268 return False
259269 return all (self ._safe_compare (val1 [k ], val2 [k ]) for k in val1 .keys ())
260-
261- # For Row objects (which are tuples with special properties)
262- if hasattr (val1 , 'asDict' ) and hasattr (val2 , 'asDict' ):
263- return self ._safe_compare (val1 .asDict (recursive = True ), val2 .asDict (recursive = True ))
264-
265- # Default comparison
266- return val1 == val2
267- except (ValueError , TypeError ) as e :
268- # If comparison fails (e.g., numpy arrays), convert to string
269- return str (val1 ) == str (val2 )
270+
271+ # Default comparison - ensure we always return a boolean
272+ result = val1 == val2
273+ # If result is not a simple boolean, use bool() to convert it
274+ return bool (result )
275+
276+ except (ValueError , TypeError ):
277+ # Fallback to string comparison for problematic types
278+ try :
279+ return str (val1 ) == str (val2 )
280+ except :
281+ return False
270282
271283 def compare_rows (
272284 self , thrift_rows : List [Row ], sea_rows : List [Row ], result : ComparisonResult
@@ -302,15 +314,17 @@ def compare_rows(
302314 # Check if dictionaries are different by comparing all fields
303315 all_fields = set (thrift_dict .keys ()) | set (sea_dict .keys ())
304316 dicts_differ = False
305-
317+
306318 for field in all_fields :
307319 if field not in thrift_dict or field not in sea_dict :
308320 dicts_differ = True
309321 break
310- elif not self ._safe_compare (thrift_dict .get (field ), sea_dict .get (field )):
322+ elif not self ._safe_compare (
323+ thrift_dict .get (field ), sea_dict .get (field )
324+ ):
311325 dicts_differ = True
312326 break
313-
327+
314328 if dicts_differ :
315329
316330 for field in all_fields :
@@ -353,9 +367,9 @@ def compare_rows(
353367 thrift_values = [m [1 ] for m in mismatches ]
354368 sea_values = [m [2 ] for m in mismatches ]
355369
356- if all (self . _safe_compare ( v , thrift_values [ 0 ]) for v in thrift_values ) and all (
357- self ._safe_compare (v , sea_values [0 ]) for v in sea_values
358- ):
370+ if all (
371+ self ._safe_compare (v , thrift_values [0 ]) for v in thrift_values
372+ ) and all ( self . _safe_compare ( v , sea_values [ 0 ]) for v in sea_values ) :
359373 result .add_difference (
360374 f"Field '{ field } ' value mismatch in all rows" ,
361375 thrift_values [0 ],
0 commit comments