
Commit 0d122e8

variant type detection
1 parent 177c197 commit 0d122e8
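
What the commit adds, in brief: when building a column description, the Thrift backend can now consult the column's Arrow schema field, and a field whose metadata tags it as a Spark VARIANT is reported as a variant column rather than its physical string type. A minimal standalone sketch of that check, with a hypothetical helper name (the real logic lives in ThriftDatabricksClient._col_to_description and works off the full Thrift type descriptor):

import pyarrow

# Metadata key under which Spark records the SQL type name on an Arrow field.
VARIANT_KEY = b"Spark:DataType:SqlName"

def resolve_type_name(field, thrift_type_name):
    # Prefer the VARIANT tag from the Arrow field metadata when present;
    # a missing field or missing metadata leaves the Thrift type unchanged.
    if field is not None and field.metadata:
        if field.metadata.get(VARIANT_KEY) == b"VARIANT":
            return "variant"
    return thrift_type_name

# Mirrors the unit-test fixtures below:
variant_field = pyarrow.field(
    "variant_col", pyarrow.string(),
    metadata={b"Spark:DataType:SqlName": b"VARIANT"},
)
assert resolve_type_name(variant_field, "string") == "variant"
assert resolve_type_name(None, "string") == "string"  # null field: unchanged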

File tree

2 files changed: +25 -20 lines


src/databricks/sql/backend/thrift_backend.py

Lines changed: 19 additions & 14 deletions
@@ -735,7 +735,7 @@ def convert_col(t_column_desc):
         return pyarrow.schema([convert_col(col) for col in t_table_schema.columns])
 
     @staticmethod
-    def _col_to_description(col, field, session_id_hex=None):
+    def _col_to_description(col, field=None, session_id_hex=None):
         type_entry = col.typeDesc.types[0]
 
         if type_entry.primitiveEntry:
@@ -778,7 +778,9 @@ def _col_to_description(col, field, session_id_hex=None):
         return col.columnName, cleaned_type, None, None, precision, scale, None
 
     @staticmethod
-    def _hive_schema_to_description(t_table_schema, schema_bytes=None, session_id_hex=None):
+    def _hive_schema_to_description(
+        t_table_schema, schema_bytes=None, session_id_hex=None
+    ):
         field_dict = {}
         if pyarrow and schema_bytes:
             try:
@@ -788,8 +790,13 @@ def _hive_schema_to_description(t_table_schema, schema_bytes=None, session_id_hex=None):
                     field_dict[field.name] = field
             except Exception as e:
                 logger.debug(f"Could not parse arrow schema: {e}")
+
         return [
-            ThriftDatabricksClient._col_to_description(col, field_dict.get(col.columnName), session_id_hex)
+            ThriftDatabricksClient._col_to_description(
+                col,
+                field_dict.get(col.columnName) if field_dict else None,
+                session_id_hex,
+            )
             for col in t_table_schema.columns
         ]

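For context, the field_dict consulted in the hunk above is built from the Arrow IPC schema bytes that accompany the Thrift schema; the try/except shown is the parse step. A self-contained sketch of that step, assuming schema_bytes is a serialized Arrow schema (the driver additionally guards on pyarrow being importable):

import pyarrow

def build_field_dict(schema_bytes):
    # Deserialize the Arrow IPC schema and index its fields by column name.
    field_dict = {}
    if schema_bytes:
        try:
            schema = pyarrow.ipc.read_schema(pyarrow.py_buffer(schema_bytes))
            for field in schema:
                field_dict[field.name] = field
        except Exception as e:
            print(f"Could not parse arrow schema: {e}")  # logger.debug in the driver
    return field_dict
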
@@ -822,11 +829,6 @@ def _results_message_to_execute_response(self, resp, operation_state):
                 or direct_results.resultSet.hasMoreRows
             )
 
-        description = self._hive_schema_to_description(
-            t_result_set_metadata_resp.schema,
-            self._session_id_hex,
-        )
-
         if pyarrow:
             schema_bytes = (
                 t_result_set_metadata_resp.arrowSchema
@@ -840,7 +842,9 @@ def _results_message_to_execute_response(self, resp, operation_state):
             schema_bytes = None
 
         description = self._hive_schema_to_description(
-            t_result_set_metadata_resp.schema, schema_bytes
+            t_result_set_metadata_resp.schema,
+            schema_bytes,
+            self._session_id_hex,
         )
 
         lz4_compressed = t_result_set_metadata_resp.lz4Compressed
@@ -887,11 +891,6 @@ def get_execution_result(
 
         t_result_set_metadata_resp = resp.resultSetMetadata
 
-        description = self._hive_schema_to_description(
-            t_result_set_metadata_resp.schema,
-            self._session_id_hex,
-        )
-
         if pyarrow:
             schema_bytes = (
                 t_result_set_metadata_resp.arrowSchema
@@ -904,6 +903,12 @@ def get_execution_result(
         else:
             schema_bytes = None
 
+        description = self._hive_schema_to_description(
+            t_result_set_metadata_resp.schema,
+            schema_bytes,
+            self._session_id_hex,
+        )
+
         lz4_compressed = t_result_set_metadata_resp.lz4Compressed
         is_staging_operation = t_result_set_metadata_resp.isStagingOperation
         has_more_rows = resp.hasMoreRows
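
Net effect of the last four hunks: in both _results_message_to_execute_response and get_execution_result, the description is now computed after schema_bytes is known and all three arguments are passed explicitly. In the removed calls, self._session_id_hex was supplied as the second positional argument, which the _hive_schema_to_description signature binds to schema_bytes, so the Arrow field metadata never reached _col_to_description.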

tests/unit/test_thrift_backend.py

Lines changed: 6 additions & 6 deletions
@@ -2370,7 +2370,7 @@ def test_col_to_description_with_variant_type(self):
             metadata={b'Spark:DataType:SqlName': b'VARIANT'}
         )
 
-        result = ThriftBackend._col_to_description(col, field)
+        result = ThriftDatabricksClient._col_to_description(col, field)
 
         # Verify the result has variant as the type
         self.assertEqual(result[0], "variant_col")  # Column name
@@ -2395,7 +2395,7 @@ def test_col_to_description_without_variant_type(self):
             metadata={}
         )
 
-        result = ThriftBackend._col_to_description(col, field)
+        result = ThriftDatabricksClient._col_to_description(col, field)
 
         # Verify the result has string as the type (unchanged)
         self.assertEqual(result[0], "normal_col")  # Column name
@@ -2414,7 +2414,7 @@ def test_col_to_description_with_null_field(self):
         )
 
         # Pass None as the field
-        result = ThriftBackend._col_to_description(col, None)
+        result = ThriftDatabricksClient._col_to_description(col, None)
 
         # Verify the result has string as the type (unchanged)
         self.assertEqual(result[0], "missing_field")  # Column name
@@ -2452,7 +2452,7 @@ def test_hive_schema_to_description_with_arrow_schema(self):
         schema_bytes = arrow_schema.serialize().to_pybytes()
 
         # Get the description
-        description = ThriftBackend._hive_schema_to_description(t_table_schema, schema_bytes)
+        description = ThriftDatabricksClient._hive_schema_to_description(t_table_schema, schema_bytes)
 
         # Verify regular column type
         self.assertEqual(description[0][0], "regular_col")
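
The test above exercises the full path with genuinely serialized bytes. A sketch of that fixture setup, with names mirroring the test (the t_table_schema construction is elided here):

import pyarrow

arrow_schema = pyarrow.schema([
    pyarrow.field("regular_col", pyarrow.string()),
    pyarrow.field(
        "variant_col", pyarrow.string(),
        metadata={b"Spark:DataType:SqlName": b"VARIANT"},
    ),
])
schema_bytes = arrow_schema.serialize().to_pybytes()
# description = ThriftDatabricksClient._hive_schema_to_description(
#     t_table_schema, schema_bytes
# )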
@@ -2473,7 +2473,7 @@ def test_hive_schema_to_description_with_null_schema_bytes(self):
         t_table_schema = ttypes.TTableSchema(columns=columns)
 
         # Get the description with null schema_bytes
-        description = ThriftBackend._hive_schema_to_description(t_table_schema, None)
+        description = ThriftDatabricksClient._hive_schema_to_description(t_table_schema, None)
 
         # Verify column type remains unchanged
         self.assertEqual(description[0][0], "regular_col")
@@ -2493,7 +2493,7 @@ def test_col_to_description_with_malformed_metadata(self):
             metadata={b'Spark:DataType:SqlName': b'Some unexpected value'}
         )
 
-        result = ThriftBackend._col_to_description(col, field)
+        result = ThriftDatabricksClient._col_to_description(col, field)
 
         # Verify the type remains unchanged
         self.assertEqual(result[0], "weird_field")  # Column name
