Skip to content

Commit 82e9c4f

Browse files
remove hardcoding in SqlType
Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com>
1 parent 7d3174f commit 82e9c4f

File tree

3 files changed

+46
-39
lines changed

3 files changed

+46
-39
lines changed

src/databricks/sql/backend/sea/result_set.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -313,17 +313,17 @@ def _prepare_column_mapping(self) -> None:
313313
for new_idx, result_column in enumerate(self._metadata_columns or []):
314314
# Find the corresponding SEA column
315315
if (
316-
result_column.result_set_column_name
317-
and result_column.result_set_column_name in sea_column_indices
316+
result_column.sea_col_name
317+
and result_column.sea_col_name in sea_column_indices
318318
):
319-
old_idx = sea_column_indices[result_column.result_set_column_name]
319+
old_idx = sea_column_indices[result_column.sea_col_name]
320320
self._column_index_mapping[new_idx] = old_idx
321321
# Use the original column metadata but with JDBC name
322322
old_col = self.description[old_idx]
323323
new_description.append(
324324
(
325-
result_column.column_name, # JDBC name
326-
result_column.column_type, # Expected type
325+
result_column.thrift_col_name, # JDBC name
326+
result_column.thrift_col_type, # Expected type
327327
old_col[2], # display_size
328328
old_col[3], # internal_size
329329
old_col[4], # precision
@@ -335,8 +335,8 @@ def _prepare_column_mapping(self) -> None:
335335
# Column doesn't exist in SEA - add with None values
336336
new_description.append(
337337
(
338-
result_column.column_name,
339-
result_column.column_type,
338+
result_column.thrift_col_name,
339+
result_column.thrift_col_type,
340340
None,
341341
None,
342342
None,
@@ -380,7 +380,7 @@ def _normalise_arrow_metadata_cols(self, table: "pyarrow.Table") -> "pyarrow.Tab
380380
null_array = pyarrow.nulls(table.num_rows)
381381
new_columns.append(null_array)
382382

383-
column_names.append(result_column.column_name)
383+
column_names.append(result_column.thrift_col_name)
384384

385385
return pyarrow.Table.from_arrays(new_columns, names=column_names)
386386

src/databricks/sql/backend/sea/utils/conversion.py

Lines changed: 32 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
from dateutil import parser
1212
from typing import Callable, Dict, Optional
1313

14+
from databricks.sql.thrift_api.TCLIService import ttypes
15+
1416
logger = logging.getLogger(__name__)
1517

1618

@@ -56,43 +58,49 @@ class SqlType:
5658
after normalize_sea_type_to_thrift processing (lowercase, without _TYPE suffix).
5759
"""
5860

61+
@staticmethod
62+
def _get_type_name(thrift_type_id: int) -> str:
63+
type_name = ttypes.TTypeId._VALUES_TO_NAMES[thrift_type_id]
64+
type_name = type_name.lower()
65+
if type_name.endswith("_type"):
66+
type_name = type_name[:-5]
67+
return type_name
68+
5969
# Numeric types
60-
TINYINT = "tinyint" # Maps to TTypeId.TINYINT_TYPE
61-
SMALLINT = "smallint" # Maps to TTypeId.SMALLINT_TYPE
62-
INT = "int" # Maps to TTypeId.INT_TYPE
63-
BIGINT = "bigint" # Maps to TTypeId.BIGINT_TYPE
64-
FLOAT = "float" # Maps to TTypeId.FLOAT_TYPE
65-
DOUBLE = "double" # Maps to TTypeId.DOUBLE_TYPE
66-
DECIMAL = "decimal" # Maps to TTypeId.DECIMAL_TYPE
70+
TINYINT = _get_type_name(ttypes.TTypeId.TINYINT_TYPE)
71+
SMALLINT = _get_type_name(ttypes.TTypeId.SMALLINT_TYPE)
72+
INT = _get_type_name(ttypes.TTypeId.INT_TYPE)
73+
BIGINT = _get_type_name(ttypes.TTypeId.BIGINT_TYPE)
74+
FLOAT = _get_type_name(ttypes.TTypeId.FLOAT_TYPE)
75+
DOUBLE = _get_type_name(ttypes.TTypeId.DOUBLE_TYPE)
76+
DECIMAL = _get_type_name(ttypes.TTypeId.DECIMAL_TYPE)
6777

6878
# Boolean type
69-
BOOLEAN = "boolean" # Maps to TTypeId.BOOLEAN_TYPE
79+
BOOLEAN = _get_type_name(ttypes.TTypeId.BOOLEAN_TYPE)
7080

7181
# Date/Time types
72-
DATE = "date" # Maps to TTypeId.DATE_TYPE
73-
TIMESTAMP = "timestamp" # Maps to TTypeId.TIMESTAMP_TYPE
74-
INTERVAL_YEAR_MONTH = (
75-
"interval_year_month" # Maps to TTypeId.INTERVAL_YEAR_MONTH_TYPE
76-
)
77-
INTERVAL_DAY_TIME = "interval_day_time" # Maps to TTypeId.INTERVAL_DAY_TIME_TYPE
82+
DATE = _get_type_name(ttypes.TTypeId.DATE_TYPE)
83+
TIMESTAMP = _get_type_name(ttypes.TTypeId.TIMESTAMP_TYPE)
84+
INTERVAL_YEAR_MONTH = _get_type_name(ttypes.TTypeId.INTERVAL_YEAR_MONTH_TYPE)
85+
INTERVAL_DAY_TIME = _get_type_name(ttypes.TTypeId.INTERVAL_DAY_TIME_TYPE)
7886

7987
# String types
80-
CHAR = "char" # Maps to TTypeId.CHAR_TYPE
81-
VARCHAR = "varchar" # Maps to TTypeId.VARCHAR_TYPE
82-
STRING = "string" # Maps to TTypeId.STRING_TYPE
88+
CHAR = _get_type_name(ttypes.TTypeId.CHAR_TYPE)
89+
VARCHAR = _get_type_name(ttypes.TTypeId.VARCHAR_TYPE)
90+
STRING = _get_type_name(ttypes.TTypeId.STRING_TYPE)
8391

8492
# Binary type
85-
BINARY = "binary" # Maps to TTypeId.BINARY_TYPE
93+
BINARY = _get_type_name(ttypes.TTypeId.BINARY_TYPE)
8694

8795
# Complex types
88-
ARRAY = "array" # Maps to TTypeId.ARRAY_TYPE
89-
MAP = "map" # Maps to TTypeId.MAP_TYPE
90-
STRUCT = "struct" # Maps to TTypeId.STRUCT_TYPE
96+
ARRAY = _get_type_name(ttypes.TTypeId.ARRAY_TYPE)
97+
MAP = _get_type_name(ttypes.TTypeId.MAP_TYPE)
98+
STRUCT = _get_type_name(ttypes.TTypeId.STRUCT_TYPE)
9199

92100
# Other types
93-
NULL = "null" # Maps to TTypeId.NULL_TYPE
94-
UNION = "union" # Maps to TTypeId.UNION_TYPE
95-
USER_DEFINED = "user_defined" # Maps to TTypeId.USER_DEFINED_TYPE
101+
NULL = _get_type_name(ttypes.TTypeId.NULL_TYPE)
102+
UNION = _get_type_name(ttypes.TTypeId.UNION_TYPE)
103+
USER_DEFINED = _get_type_name(ttypes.TTypeId.USER_DEFINED_TYPE)
96104

97105

98106
class SqlTypeConverter:

src/databricks/sql/backend/sea/utils/result_column.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,11 @@ class ResultColumn:
88
Represents a mapping between JDBC specification column names and actual result set column names.
99
1010
Attributes:
11-
column_name: JDBC specification column name (e.g., "TABLE_CAT")
12-
result_set_column_name: Server result column name from SEA (e.g., "catalog")
13-
column_type: SQL type code from databricks.sql.types
14-
transform_value: Optional function to transform values for this column
11+
thrift_col_name: Column name as returned by Thrift (e.g., "TABLE_CAT")
12+
sea_col_name: Server result column name from SEA (e.g., "catalog")
13+
thrift_col_type: SQL type name
1514
"""
1615

17-
column_name: str
18-
result_set_column_name: Optional[str] # None if SEA doesn't return this column
19-
column_type: str
16+
thrift_col_name: str
17+
sea_col_name: Optional[str] # None if SEA doesn't return this column
18+
thrift_col_type: str

0 commit comments

Comments
 (0)