@@ -2330,7 +2330,7 @@ def test_execute_command_sets_complex_type_fields_correctly(
             [],
             auth_provider=AuthProvider(),
             ssl_options=SSLOptions(),
-            http_client=MagicMock(),
+            http_client=MagicMock(),
             **complex_arg_types,
         )
         thrift_backend.execute_command(
@@ -2356,148 +2356,85 @@ def test_execute_command_sets_complex_type_fields_correctly(
             t_execute_statement_req.useArrowNativeTypes.intervalTypesAsArrow
         )
 
-    def test_col_to_description_with_variant_type(self):
-        # Test variant type detection from Arrow field metadata
-        col = ttypes.TColumnDesc(
-            columnName="variant_col",
-            typeDesc=self._make_type_desc(ttypes.TTypeId.STRING_TYPE),
-        )
-
-        # Create a field with variant type in metadata
-        field = pyarrow.field(
-            "variant_col",
-            pyarrow.string(),
-            metadata={b'Spark:DataType:SqlName': b'VARIANT'}
-        )
-
-        result = ThriftDatabricksClient._col_to_description(col, field)
-
-        # Verify the result has variant as the type
-        self.assertEqual(result[0], "variant_col")  # Column name
-        self.assertEqual(result[1], "variant")  # Type name (should be variant instead of string)
-        self.assertIsNone(result[2])  # No display size
-        self.assertIsNone(result[3])  # No internal size
-        self.assertIsNone(result[4])  # No precision
-        self.assertIsNone(result[5])  # No scale
-        self.assertIsNone(result[6])  # No null ok
-
-    def test_col_to_description_without_variant_type(self):
-        # Test normal column without variant type
-        col = ttypes.TColumnDesc(
-            columnName="normal_col",
-            typeDesc=self._make_type_desc(ttypes.TTypeId.STRING_TYPE),
-        )
-
-        # Create a normal field without variant metadata
-        field = pyarrow.field(
-            "normal_col",
-            pyarrow.string(),
-            metadata={}
-        )
-
-        result = ThriftDatabricksClient._col_to_description(col, field)
-
-        # Verify the result has string as the type (unchanged)
-        self.assertEqual(result[0], "normal_col")  # Column name
-        self.assertEqual(result[1], "string")  # Type name (should be string)
-        self.assertIsNone(result[2])  # No display size
-        self.assertIsNone(result[3])  # No internal size
-        self.assertIsNone(result[4])  # No precision
-        self.assertIsNone(result[5])  # No scale
-        self.assertIsNone(result[6])  # No null ok
-
-    def test_col_to_description_with_null_field(self):
-        # Test handling of null field
-        col = ttypes.TColumnDesc(
-            columnName="missing_field",
-            typeDesc=self._make_type_desc(ttypes.TTypeId.STRING_TYPE),
-        )
-
-        # Pass None as the field
-        result = ThriftDatabricksClient._col_to_description(col, None)
-
-        # Verify the result has string as the type (unchanged)
-        self.assertEqual(result[0], "missing_field")  # Column name
-        self.assertEqual(result[1], "string")  # Type name (should be string)
-        self.assertIsNone(result[2])  # No display size
-        self.assertIsNone(result[3])  # No internal size
-        self.assertIsNone(result[4])  # No precision
-        self.assertIsNone(result[5])  # No scale
-        self.assertIsNone(result[6])  # No null ok
-
-    def test_hive_schema_to_description_with_arrow_schema(self):
-        # Create a table schema with regular and variant columns
-        columns = [
-            ttypes.TColumnDesc(
-                columnName="regular_col",
-                typeDesc=self._make_type_desc(ttypes.TTypeId.STRING_TYPE),
-            ),
-            ttypes.TColumnDesc(
-                columnName="variant_col",
-                typeDesc=self._make_type_desc(ttypes.TTypeId.STRING_TYPE),
-            ),
+    @unittest.skipIf(pyarrow is None, "Requires pyarrow")
+    def test_col_to_description(self):
+        test_cases = [
+            ("variant_col", {b"Spark:DataType:SqlName": b"VARIANT"}, "variant"),
+            ("normal_col", {}, "string"),
+            ("weird_field", {b"Spark:DataType:SqlName": b"Some unexpected value"}, "string"),
+            ("missing_field", None, "string"),  # None field case
         ]
-        t_table_schema = ttypes.TTableSchema(columns=columns)
 
-        # Create an Arrow schema with one variant column
-        fields = [
-            pyarrow.field("regular_col", pyarrow.string()),
-            pyarrow.field(
-                "variant_col",
-                pyarrow.string(),
-                metadata={b'Spark:DataType:SqlName': b'VARIANT'}
-            )
-        ]
-        arrow_schema = pyarrow.schema(fields)
-        schema_bytes = arrow_schema.serialize().to_pybytes()
-
-        # Get the description
-        description = ThriftDatabricksClient._hive_schema_to_description(t_table_schema, schema_bytes)
-
-        # Verify regular column type
-        self.assertEqual(description[0][0], "regular_col")
-        self.assertEqual(description[0][1], "string")
-
-        # Verify variant column type
-        self.assertEqual(description[1][0], "variant_col")
-        self.assertEqual(description[1][1], "variant")
+        for column_name, field_metadata, expected_type in test_cases:
+            with self.subTest(column_name=column_name, expected_type=expected_type):
+                col = ttypes.TColumnDesc(
+                    columnName=column_name,
+                    typeDesc=self._make_type_desc(ttypes.TTypeId.STRING_TYPE),
+                )
 
-    def test_hive_schema_to_description_with_null_schema_bytes(self):
-        # Create a simple table schema
-        columns = [
-            ttypes.TColumnDesc(
-                columnName="regular_col",
-                typeDesc=self._make_type_desc(ttypes.TTypeId.STRING_TYPE),
+                field = (
+                    None
+                    if field_metadata is None
+                    else pyarrow.field(column_name, pyarrow.string(), metadata=field_metadata)
+                )
+
+                result = ThriftDatabricksClient._col_to_description(col, field)
+
+                self.assertEqual(result[0], column_name)
+                self.assertEqual(result[1], expected_type)
+                self.assertIsNone(result[2])
+                self.assertIsNone(result[3])
+                self.assertIsNone(result[4])
+                self.assertIsNone(result[5])
+                self.assertIsNone(result[6])
+
+    @unittest.skipIf(pyarrow is None, "Requires pyarrow")
+    def test_hive_schema_to_description(self):
+        test_cases = [
+            (
+                [
+                    ("regular_col", ttypes.TTypeId.STRING_TYPE),
+                    ("variant_col", ttypes.TTypeId.STRING_TYPE),
+                ],
+                [
+                    ("regular_col", {}),
+                    ("variant_col", {b"Spark:DataType:SqlName": b"VARIANT"}),
+                ],
+                [("regular_col", "string"), ("variant_col", "variant")],
+            ),
+            (
+                [("regular_col", ttypes.TTypeId.STRING_TYPE)],
+                None,  # No arrow schema
+                [("regular_col", "string")],
             ),
         ]
-        t_table_schema = ttypes.TTableSchema(columns=columns)
-
-        # Get the description with null schema_bytes
-        description = ThriftDatabricksClient._hive_schema_to_description(t_table_schema, None)
 
-        # Verify column type remains unchanged
-        self.assertEqual(description[0][0], "regular_col")
-        self.assertEqual(description[0][1], "string")
+        for columns, arrow_fields, expected_types in test_cases:
+            with self.subTest(arrow_fields=arrow_fields is not None):
+                t_table_schema = ttypes.TTableSchema(
+                    columns=[
+                        ttypes.TColumnDesc(
+                            columnName=name, typeDesc=self._make_type_desc(col_type)
+                        )
+                        for name, col_type in columns
+                    ]
+                )
 
-    def test_col_to_description_with_malformed_metadata(self):
-        # Test handling of malformed metadata
-        col = ttypes.TColumnDesc(
-            columnName="weird_field",
-            typeDesc=self._make_type_desc(ttypes.TTypeId.STRING_TYPE),
-        )
-
-        # Create a field with malformed metadata
-        field = pyarrow.field(
-            "weird_field",
-            pyarrow.string(),
-            metadata={b'Spark:DataType:SqlName': b'Some unexpected value'}
-        )
-
-        result = ThriftDatabricksClient._col_to_description(col, field)
-
-        # Verify the type remains unchanged
-        self.assertEqual(result[0], "weird_field")  # Column name
-        self.assertEqual(result[1], "string")  # Type name (should remain string)
+                schema_bytes = None
+                if arrow_fields:
+                    fields = [
+                        pyarrow.field(name, pyarrow.string(), metadata=metadata)
+                        for name, metadata in arrow_fields
+                    ]
+                    schema_bytes = pyarrow.schema(fields).serialize().to_pybytes()
+
+                description = ThriftDatabricksClient._hive_schema_to_description(
+                    t_table_schema, schema_bytes
+                )
+
+                for i, (expected_name, expected_type) in enumerate(expected_types):
+                    self.assertEqual(description[i][0], expected_name)
+                    self.assertEqual(description[i][1], expected_type)
 
 
 if __name__ == "__main__":
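
For reference, here is a minimal sketch of the behavior the new parametrized tests pin down, inferred purely from the assertions above; the actual ThriftDatabricksClient._col_to_description implementation may be structured differently. A column's PEP 249 description keeps the Thrift-derived type name unless its paired Arrow field carries the Spark:DataType:SqlName metadata entry with the exact value VARIANT.

# Hypothetical sketch (the name col_to_description_sketch is illustrative,
# not the driver's real API); logic inferred from the test assertions above.
def col_to_description_sketch(column_name, thrift_type_name, field):
    # `field` is an optional pyarrow.Field; Field.metadata is a
    # bytes -> bytes dict (or None) carrying Spark's SQL type hints.
    type_name = thrift_type_name
    if field is not None and field.metadata:
        # Only the exact value b"VARIANT" overrides the type; any other
        # value, or a missing key/field, leaves the Thrift name unchanged.
        if field.metadata.get(b"Spark:DataType:SqlName") == b"VARIANT":
            type_name = "variant"
    # PEP 249 7-tuple: name, type_code, display_size, internal_size,
    # precision, scale, null_ok -- the tests expect the last five to be None.
    return (column_name, type_name, None, None, None, None, None)

The ("missing_field", None, "string") case and the schema_bytes = None case both exercise the fallback path: with no Arrow field (or no Arrow schema at all), every column keeps its Thrift type name.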