Commit 6880834

address comments
1 parent 0d122e8 commit 6880834

2 files changed: +99 -151 lines changed

tests/e2e/test_variant_types.py

Lines changed: 26 additions & 15 deletions
@@ -1,6 +1,7 @@
 import pytest
 from datetime import datetime
 import json
+
 try:
     import pyarrow
 except ImportError:
@@ -9,6 +10,8 @@
 from tests.e2e.test_driver import PySQLPytestTestCase
 from tests.e2e.common.predicates import pysql_supports_arrow
 
+
+@pytest.mark.skipif(not pysql_supports_arrow(), reason="Requires arrow support")
 class TestVariantTypes(PySQLPytestTestCase):
     """Tests for the proper detection and handling of VARIANT type columns"""
 
@@ -17,7 +20,7 @@ def variant_table(self, connection_details):
         """A pytest fixture that creates a test table and cleans up after tests"""
         self.arguments = connection_details.copy()
         table_name = "pysql_test_variant_types_table"
-
+
         with self.cursor() as cursor:
             try:
                 # Create the table with variant columns
@@ -30,7 +33,7 @@ def variant_table(self, connection_details):
                     )
                     """
                 )
-
+
                 # Insert test records with different variant values
                 cursor.execute(
                     """
@@ -44,37 +47,45 @@ def variant_table(self, connection_details):
             finally:
                 cursor.execute(f"DROP TABLE IF EXISTS {table_name}")
 
-    @pytest.mark.skipif(not pysql_supports_arrow(), reason="Requires arrow support")
     def test_variant_type_detection(self, variant_table):
         """Test that VARIANT type columns are properly detected in schema"""
         with self.cursor() as cursor:
             cursor.execute(f"SELECT * FROM {variant_table} LIMIT 0")
-
+
             # Verify column types in description
-            assert cursor.description[0][1] == 'int', "Integer column type not correctly identified"
-            assert cursor.description[1][1] == 'variant', "VARIANT column type not correctly identified"
-            assert cursor.description[2][1] == 'string', "String column type not correctly identified"
+            assert (
+                cursor.description[0][1] == "int"
+            ), "Integer column type not correctly identified"
+            assert (
+                cursor.description[1][1] == "variant"
+            ), "VARIANT column type not correctly identified"
+            assert (
+                cursor.description[2][1] == "string"
+            ), "String column type not correctly identified"
 
-    @pytest.mark.skipif(not pysql_supports_arrow(), reason="Requires arrow support")
     def test_variant_data_retrieval(self, variant_table):
         """Test that VARIANT data is properly retrieved and can be accessed as JSON"""
         with self.cursor() as cursor:
             cursor.execute(f"SELECT * FROM {variant_table} ORDER BY id")
             rows = cursor.fetchall()
-
+
             # First row should have a JSON object
            json_obj = rows[0][1]
-            assert isinstance(json_obj, str), "VARIANT column should be returned as string"
+            assert isinstance(
+                json_obj, str
+            ), "VARIANT column should be returned as string"
 
             parsed = json.loads(json_obj)
-            assert parsed.get('name') == 'John'
-            assert parsed.get('age') == 30
+            assert parsed.get("name") == "John"
+            assert parsed.get("age") == 30
 
             # Second row should have a JSON array
             json_array = rows[1][1]
-            assert isinstance(json_array, str), "VARIANT array should be returned as string"
-
+            assert isinstance(
+                json_array, str
+            ), "VARIANT array should be returned as string"
+
             # Parsing to verify it's valid JSON array
             parsed_array = json.loads(json_array)
             assert isinstance(parsed_array, list)
-            assert parsed_array == [1, 2, 3, 4]
+            assert parsed_array == [1, 2, 3, 4]
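
For orientation, the client-side behavior these e2e tests pin down looks roughly like this — a minimal sketch, not part of the diff; the connection parameters and standalone table query are placeholders:

import json

from databricks import sql

# Hypothetical connection values; substitute real workspace details.
with sql.connect(
    server_hostname="...", http_path="...", access_token="..."
) as connection:
    with connection.cursor() as cursor:
        cursor.execute("SELECT variant_col FROM pysql_test_variant_types_table")
        # The backend reports the VARIANT column as "variant" in the
        # DB-API description tuple rather than falling back to "string".
        assert cursor.description[0][1] == "variant"
        # VARIANT values come back as JSON strings, parseable with json.loads.
        parsed = json.loads(cursor.fetchone()[0])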

tests/unit/test_thrift_backend.py

Lines changed: 73 additions & 136 deletions
@@ -2330,7 +2330,7 @@ def test_execute_command_sets_complex_type_fields_correctly(
             [],
             auth_provider=AuthProvider(),
             ssl_options=SSLOptions(),
-            http_client=MagicMock(),
+            http_client=MagicMock(),
             **complex_arg_types,
         )
         thrift_backend.execute_command(
@@ -2356,148 +2356,85 @@
             t_execute_statement_req.useArrowNativeTypes.intervalTypesAsArrow
         )
 
-    def test_col_to_description_with_variant_type(self):
-        # Test variant type detection from Arrow field metadata
-        col = ttypes.TColumnDesc(
-            columnName="variant_col",
-            typeDesc=self._make_type_desc(ttypes.TTypeId.STRING_TYPE),
-        )
-
-        # Create a field with variant type in metadata
-        field = pyarrow.field(
-            "variant_col",
-            pyarrow.string(),
-            metadata={b'Spark:DataType:SqlName': b'VARIANT'}
-        )
-
-        result = ThriftDatabricksClient._col_to_description(col, field)
-
-        # Verify the result has variant as the type
-        self.assertEqual(result[0], "variant_col")  # Column name
-        self.assertEqual(result[1], "variant")  # Type name (should be variant instead of string)
-        self.assertIsNone(result[2])  # No display size
-        self.assertIsNone(result[3])  # No internal size
-        self.assertIsNone(result[4])  # No precision
-        self.assertIsNone(result[5])  # No scale
-        self.assertIsNone(result[6])  # No null ok
-
-    def test_col_to_description_without_variant_type(self):
-        # Test normal column without variant type
-        col = ttypes.TColumnDesc(
-            columnName="normal_col",
-            typeDesc=self._make_type_desc(ttypes.TTypeId.STRING_TYPE),
-        )
-
-        # Create a normal field without variant metadata
-        field = pyarrow.field(
-            "normal_col",
-            pyarrow.string(),
-            metadata={}
-        )
-
-        result = ThriftDatabricksClient._col_to_description(col, field)
-
-        # Verify the result has string as the type (unchanged)
-        self.assertEqual(result[0], "normal_col")  # Column name
-        self.assertEqual(result[1], "string")  # Type name (should be string)
-        self.assertIsNone(result[2])  # No display size
-        self.assertIsNone(result[3])  # No internal size
-        self.assertIsNone(result[4])  # No precision
-        self.assertIsNone(result[5])  # No scale
-        self.assertIsNone(result[6])  # No null ok
-
-    def test_col_to_description_with_null_field(self):
-        # Test handling of null field
-        col = ttypes.TColumnDesc(
-            columnName="missing_field",
-            typeDesc=self._make_type_desc(ttypes.TTypeId.STRING_TYPE),
-        )
-
-        # Pass None as the field
-        result = ThriftDatabricksClient._col_to_description(col, None)
-
-        # Verify the result has string as the type (unchanged)
-        self.assertEqual(result[0], "missing_field")  # Column name
-        self.assertEqual(result[1], "string")  # Type name (should be string)
-        self.assertIsNone(result[2])  # No display size
-        self.assertIsNone(result[3])  # No internal size
-        self.assertIsNone(result[4])  # No precision
-        self.assertIsNone(result[5])  # No scale
-        self.assertIsNone(result[6])  # No null ok
-
-    def test_hive_schema_to_description_with_arrow_schema(self):
-        # Create a table schema with regular and variant columns
-        columns = [
-            ttypes.TColumnDesc(
-                columnName="regular_col",
-                typeDesc=self._make_type_desc(ttypes.TTypeId.STRING_TYPE),
-            ),
-            ttypes.TColumnDesc(
-                columnName="variant_col",
-                typeDesc=self._make_type_desc(ttypes.TTypeId.STRING_TYPE),
-            ),
+    @unittest.skipIf(pyarrow is None, "Requires pyarrow")
+    def test_col_to_description(self):
+        test_cases = [
+            ("variant_col", {b"Spark:DataType:SqlName": b"VARIANT"}, "variant"),
+            ("normal_col", {}, "string"),
+            ("weird_field", {b"Spark:DataType:SqlName": b"Some unexpected value"}, "string"),
+            ("missing_field", None, "string"),  # None field case
         ]
-        t_table_schema = ttypes.TTableSchema(columns=columns)
 
-        # Create an Arrow schema with one variant column
-        fields = [
-            pyarrow.field("regular_col", pyarrow.string()),
-            pyarrow.field(
-                "variant_col",
-                pyarrow.string(),
-                metadata={b'Spark:DataType:SqlName': b'VARIANT'}
-            )
-        ]
-        arrow_schema = pyarrow.schema(fields)
-        schema_bytes = arrow_schema.serialize().to_pybytes()
-
-        # Get the description
-        description = ThriftDatabricksClient._hive_schema_to_description(t_table_schema, schema_bytes)
-
-        # Verify regular column type
-        self.assertEqual(description[0][0], "regular_col")
-        self.assertEqual(description[0][1], "string")
-
-        # Verify variant column type
-        self.assertEqual(description[1][0], "variant_col")
-        self.assertEqual(description[1][1], "variant")
+        for column_name, field_metadata, expected_type in test_cases:
+            with self.subTest(column_name=column_name, expected_type=expected_type):
+                col = ttypes.TColumnDesc(
+                    columnName=column_name,
+                    typeDesc=self._make_type_desc(ttypes.TTypeId.STRING_TYPE),
+                )
 
-    def test_hive_schema_to_description_with_null_schema_bytes(self):
-        # Create a simple table schema
-        columns = [
-            ttypes.TColumnDesc(
-                columnName="regular_col",
-                typeDesc=self._make_type_desc(ttypes.TTypeId.STRING_TYPE),
+                field = (
+                    None
+                    if field_metadata is None
+                    else pyarrow.field(column_name, pyarrow.string(), metadata=field_metadata)
+                )
+
+                result = ThriftDatabricksClient._col_to_description(col, field)
+
+                self.assertEqual(result[0], column_name)
+                self.assertEqual(result[1], expected_type)
+                self.assertIsNone(result[2])
+                self.assertIsNone(result[3])
+                self.assertIsNone(result[4])
+                self.assertIsNone(result[5])
+                self.assertIsNone(result[6])
+
+    @unittest.skipIf(pyarrow is None, "Requires pyarrow")
+    def test_hive_schema_to_description(self):
+        test_cases = [
+            (
+                [
+                    ("regular_col", ttypes.TTypeId.STRING_TYPE),
+                    ("variant_col", ttypes.TTypeId.STRING_TYPE),
+                ],
+                [
+                    ("regular_col", {}),
+                    ("variant_col", {b"Spark:DataType:SqlName": b"VARIANT"}),
+                ],
+                [("regular_col", "string"), ("variant_col", "variant")],
+            ),
+            (
+                [("regular_col", ttypes.TTypeId.STRING_TYPE)],
+                None,  # No arrow schema
+                [("regular_col", "string")],
             ),
         ]
-        t_table_schema = ttypes.TTableSchema(columns=columns)
-
-        # Get the description with null schema_bytes
-        description = ThriftDatabricksClient._hive_schema_to_description(t_table_schema, None)
 
-        # Verify column type remains unchanged
-        self.assertEqual(description[0][0], "regular_col")
-        self.assertEqual(description[0][1], "string")
+        for columns, arrow_fields, expected_types in test_cases:
+            with self.subTest(arrow_fields=arrow_fields is not None):
+                t_table_schema = ttypes.TTableSchema(
+                    columns=[
+                        ttypes.TColumnDesc(
+                            columnName=name, typeDesc=self._make_type_desc(col_type)
+                        )
+                        for name, col_type in columns
+                    ]
+                )
 
-    def test_col_to_description_with_malformed_metadata(self):
-        # Test handling of malformed metadata
-        col = ttypes.TColumnDesc(
-            columnName="weird_field",
-            typeDesc=self._make_type_desc(ttypes.TTypeId.STRING_TYPE),
-        )
-
-        # Create a field with malformed metadata
-        field = pyarrow.field(
-            "weird_field",
-            pyarrow.string(),
-            metadata={b'Spark:DataType:SqlName': b'Some unexpected value'}
-        )
-
-        result = ThriftDatabricksClient._col_to_description(col, field)
-
-        # Verify the type remains unchanged
-        self.assertEqual(result[0], "weird_field")  # Column name
-        self.assertEqual(result[1], "string")  # Type name (should remain string)
+                schema_bytes = None
+                if arrow_fields:
+                    fields = [
+                        pyarrow.field(name, pyarrow.string(), metadata=metadata)
+                        for name, metadata in arrow_fields
+                    ]
+                    schema_bytes = pyarrow.schema(fields).serialize().to_pybytes()
+
+                description = ThriftDatabricksClient._hive_schema_to_description(
+                    t_table_schema, schema_bytes
+                )
+
+                for i, (expected_name, expected_type) in enumerate(expected_types):
+                    self.assertEqual(description[i][0], expected_name)
+                    self.assertEqual(description[i][1], expected_type)
 
 
 if __name__ == "__main__":
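
For context, the Arrow-schema mechanism these unit tests exercise can be reproduced standalone — a minimal sketch of the metadata round-trip only, independent of the driver's `_col_to_description` internals:

import pyarrow
import pyarrow.ipc

# Build a schema whose second field carries the Spark VARIANT marker in
# its Arrow field metadata, as the tests above construct.
fields = [
    pyarrow.field("regular_col", pyarrow.string()),
    pyarrow.field(
        "variant_col",
        pyarrow.string(),
        metadata={b"Spark:DataType:SqlName": b"VARIANT"},
    ),
]
schema_bytes = pyarrow.schema(fields).serialize().to_pybytes()

# Deserialize the IPC bytes and inspect each field's metadata, which is
# a bytes-to-bytes mapping (or None when no metadata is attached).
schema = pyarrow.ipc.read_schema(pyarrow.py_buffer(schema_bytes))
for field in schema:
    metadata = field.metadata or {}
    if metadata.get(b"Spark:DataType:SqlName") == b"VARIANT":
        print(field.name, "-> variant")
    else:
        print(field.name, "->", field.type)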
