Skip to content

Commit a636bdc

Browse files
type normalisation for SEA
Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com>
1 parent 83e45ae commit a636bdc

File tree

3 files changed

+114
-0
lines changed

3 files changed

+114
-0
lines changed

src/databricks/sql/backend/sea/backend.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
WaitTimeout,
2020
MetadataCommands,
2121
)
22+
from databricks.sql.backend.sea.utils.normalize import normalize_sea_type_to_thrift
2223
from databricks.sql.thrift_api.TCLIService import ttypes
2324

2425
if TYPE_CHECKING:
@@ -322,6 +323,11 @@ def _extract_description_from_manifest(
322323
# Format: (name, type_code, display_size, internal_size, precision, scale, null_ok)
323324
name = col_data.get("name", "")
324325
type_name = col_data.get("type_name", "")
326+
327+
# Normalize SEA type to Thrift conventions before any processing
328+
type_name = normalize_sea_type_to_thrift(type_name, col_data)
329+
330+
# Now strip _TYPE suffix and convert to lowercase
325331
type_name = (
326332
type_name[:-5] if type_name.endswith("_TYPE") else type_name
327333
).lower()
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
"""
2+
Type normalization utilities for SEA backend.
3+
4+
This module provides functionality to normalize SEA type names to match
5+
Thrift type naming conventions.
6+
"""
7+
8+
from typing import Dict, Any
9+
10+
# Mapping of SEA types that need to be translated to Thrift conventions
11+
SEA_TO_THRIFT_TYPE_MAP = {
12+
"BYTE": "TINYINT",
13+
"SHORT": "SMALLINT",
14+
"LONG": "BIGINT",
15+
"INTERVAL": "INTERVAL", # Default mapping, will be overridden if type_interval_type is present
16+
}
17+
18+
19+
def normalize_sea_type_to_thrift(type_name: str, col_data: Dict[str, Any]) -> str:
20+
"""
21+
Normalize SEA type names to match Thrift type naming conventions.
22+
23+
Args:
24+
type_name: The type name from SEA (e.g., "BYTE", "LONG", "INTERVAL")
25+
col_data: The full column data dictionary from manifest (for accessing type_interval_type)
26+
27+
Returns:
28+
Normalized type name matching Thrift conventions
29+
"""
30+
# Early return if type doesn't need mapping
31+
if type_name not in SEA_TO_THRIFT_TYPE_MAP:
32+
return type_name
33+
34+
normalized_type = SEA_TO_THRIFT_TYPE_MAP[type_name]
35+
36+
# Special handling for interval types
37+
if type_name == "INTERVAL":
38+
type_interval_type = col_data.get("type_interval_type")
39+
if type_interval_type:
40+
if any(t in type_interval_type.upper() for t in ["YEAR", "MONTH"]):
41+
return "INTERVAL_YEAR_MONTH"
42+
elif any(
43+
t in type_interval_type.upper()
44+
for t in ["DAY", "HOUR", "MINUTE", "SECOND"]
45+
):
46+
return "INTERVAL_DAY_TIME"
47+
48+
return normalized_type

tests/unit/test_sea_backend.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -550,6 +550,66 @@ def test_extract_description_from_manifest(self, sea_client):
550550
assert description[1][1] == "int" # type_code
551551
assert description[1][6] is None # null_ok
552552

553+
def test_extract_description_from_manifest_with_type_normalization(
554+
self, sea_client
555+
):
556+
"""Test _extract_description_from_manifest with SEA to Thrift type normalization."""
557+
manifest_obj = MagicMock()
558+
manifest_obj.schema = {
559+
"columns": [
560+
{
561+
"name": "byte_col",
562+
"type_name": "BYTE",
563+
},
564+
{
565+
"name": "short_col",
566+
"type_name": "SHORT",
567+
},
568+
{
569+
"name": "long_col",
570+
"type_name": "LONG",
571+
},
572+
{
573+
"name": "interval_ym_col",
574+
"type_name": "INTERVAL",
575+
"type_interval_type": "YEAR TO MONTH",
576+
},
577+
{
578+
"name": "interval_dt_col",
579+
"type_name": "INTERVAL",
580+
"type_interval_type": "DAY TO SECOND",
581+
},
582+
{
583+
"name": "interval_default_col",
584+
"type_name": "INTERVAL",
585+
# No type_interval_type field
586+
},
587+
]
588+
}
589+
590+
description = sea_client._extract_description_from_manifest(manifest_obj)
591+
assert description is not None
592+
assert len(description) == 6
593+
594+
# Check normalized types
595+
assert description[0][0] == "byte_col"
596+
assert description[0][1] == "tinyint" # BYTE -> tinyint
597+
598+
assert description[1][0] == "short_col"
599+
assert description[1][1] == "smallint" # SHORT -> smallint
600+
601+
assert description[2][0] == "long_col"
602+
assert description[2][1] == "bigint" # LONG -> bigint
603+
604+
assert description[3][0] == "interval_ym_col"
605+
assert description[3][1] == "interval_year_month" # INTERVAL with YEAR/MONTH
606+
607+
assert description[4][0] == "interval_dt_col"
608+
assert description[4][1] == "interval_day_time" # INTERVAL with DAY/TIME
609+
610+
assert description[5][0] == "interval_default_col"
611+
assert description[5][1] == "interval" # INTERVAL without subtype
612+
553613
def test_filter_session_configuration(self):
554614
"""Test that _filter_session_configuration converts all values to strings."""
555615
session_config = {

0 commit comments

Comments
 (0)