
Commit 328c3bd

Merge branch 'main' into sea-http-client
2 parents: 6c5b37f + aee6863

File tree

16 files changed, +806 -143 lines changed


src/databricks/sql/backend/sea/backend.py

Lines changed: 6 additions & 0 deletions
@@ -19,6 +19,7 @@
     WaitTimeout,
     MetadataCommands,
 )
+from databricks.sql.backend.sea.utils.normalize import normalize_sea_type_to_thrift
 from databricks.sql.thrift_api.TCLIService import ttypes

 if TYPE_CHECKING:
@@ -322,6 +323,11 @@ def _extract_description_from_manifest(
             # Format: (name, type_code, display_size, internal_size, precision, scale, null_ok)
             name = col_data.get("name", "")
             type_name = col_data.get("type_name", "")
+
+            # Normalize SEA type to Thrift conventions before any processing
+            type_name = normalize_sea_type_to_thrift(type_name, col_data)
+
+            # Now strip _TYPE suffix and convert to lowercase
             type_name = (
                 type_name[:-5] if type_name.endswith("_TYPE") else type_name
             ).lower()
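
To make the new step concrete: a SEA manifest entry is first normalized to the Thrift type name, then run through the existing _TYPE-suffix handling. A minimal sketch (col_data is a hypothetical manifest entry):

from databricks.sql.backend.sea.utils.normalize import normalize_sea_type_to_thrift

# Hypothetical column entry from a SEA result manifest.
col_data = {"name": "age", "type_name": "LONG"}

type_name = col_data.get("type_name", "")                      # "LONG"
type_name = normalize_sea_type_to_thrift(type_name, col_data)  # "BIGINT"
type_name = (
    type_name[:-5] if type_name.endswith("_TYPE") else type_name
).lower()                                                      # -> "bigint"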

src/databricks/sql/backend/sea/result_set.py

Lines changed: 9 additions & 10 deletions
@@ -92,20 +92,19 @@ def _convert_json_types(self, row: List[str]) -> List[Any]:
         converted_row = []

         for i, value in enumerate(row):
+            column_name = self.description[i][0]
             column_type = self.description[i][1]
             precision = self.description[i][4]
             scale = self.description[i][5]

-            try:
-                converted_value = SqlTypeConverter.convert_value(
-                    value, column_type, precision=precision, scale=scale
-                )
-                converted_row.append(converted_value)
-            except Exception as e:
-                logger.warning(
-                    f"Error converting value '{value}' to {column_type}: {e}"
-                )
-                converted_row.append(value)
+            converted_value = SqlTypeConverter.convert_value(
+                value,
+                column_type,
+                column_name=column_name,
+                precision=precision,
+                scale=scale,
+            )
+            converted_row.append(converted_value)

         return converted_row
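
For context, self.description follows the DB-API 2.0 seven-item layout noted in backend.py, which is why the loop indexes positions 0, 1, 4, and 5. A hypothetical entry:

# (name, type_code, display_size, internal_size, precision, scale, null_ok)
description = [("amount", "decimal", None, None, 10, 2, True)]

column_name = description[0][0]  # "amount"
column_type = description[0][1]  # "decimal"
precision = description[0][4]    # 10
scale = description[0][5]        # 2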

src/databricks/sql/backend/sea/utils/conversion.py

Lines changed: 44 additions & 31 deletions
@@ -50,60 +50,65 @@ def _convert_decimal(

 class SqlType:
     """
-    SQL type constants
+    SQL type constants based on Thrift TTypeId values.

-    The list of types can be found in the SEA REST API Reference:
-    https://docs.databricks.com/api/workspace/statementexecution/executestatement
+    These correspond to the normalized type names that come from the SEA backend
+    after normalize_sea_type_to_thrift processing (lowercase, without _TYPE suffix).
     """

     # Numeric types
-    BYTE = "byte"
-    SHORT = "short"
-    INT = "int"
-    LONG = "long"
-    FLOAT = "float"
-    DOUBLE = "double"
-    DECIMAL = "decimal"
+    TINYINT = "tinyint"  # Maps to TTypeId.TINYINT_TYPE
+    SMALLINT = "smallint"  # Maps to TTypeId.SMALLINT_TYPE
+    INT = "int"  # Maps to TTypeId.INT_TYPE
+    BIGINT = "bigint"  # Maps to TTypeId.BIGINT_TYPE
+    FLOAT = "float"  # Maps to TTypeId.FLOAT_TYPE
+    DOUBLE = "double"  # Maps to TTypeId.DOUBLE_TYPE
+    DECIMAL = "decimal"  # Maps to TTypeId.DECIMAL_TYPE

     # Boolean type
-    BOOLEAN = "boolean"
+    BOOLEAN = "boolean"  # Maps to TTypeId.BOOLEAN_TYPE

     # Date/Time types
-    DATE = "date"
-    TIMESTAMP = "timestamp"
-    INTERVAL = "interval"
+    DATE = "date"  # Maps to TTypeId.DATE_TYPE
+    TIMESTAMP = "timestamp"  # Maps to TTypeId.TIMESTAMP_TYPE
+    INTERVAL_YEAR_MONTH = (
+        "interval_year_month"  # Maps to TTypeId.INTERVAL_YEAR_MONTH_TYPE
+    )
+    INTERVAL_DAY_TIME = "interval_day_time"  # Maps to TTypeId.INTERVAL_DAY_TIME_TYPE

     # String types
-    CHAR = "char"
-    STRING = "string"
+    CHAR = "char"  # Maps to TTypeId.CHAR_TYPE
+    VARCHAR = "varchar"  # Maps to TTypeId.VARCHAR_TYPE
+    STRING = "string"  # Maps to TTypeId.STRING_TYPE

     # Binary type
-    BINARY = "binary"
+    BINARY = "binary"  # Maps to TTypeId.BINARY_TYPE

     # Complex types
-    ARRAY = "array"
-    MAP = "map"
-    STRUCT = "struct"
+    ARRAY = "array"  # Maps to TTypeId.ARRAY_TYPE
+    MAP = "map"  # Maps to TTypeId.MAP_TYPE
+    STRUCT = "struct"  # Maps to TTypeId.STRUCT_TYPE

     # Other types
-    NULL = "null"
-    USER_DEFINED_TYPE = "user_defined_type"
+    NULL = "null"  # Maps to TTypeId.NULL_TYPE
+    UNION = "union"  # Maps to TTypeId.UNION_TYPE
+    USER_DEFINED = "user_defined"  # Maps to TTypeId.USER_DEFINED_TYPE


 class SqlTypeConverter:
     """
     Utility class for converting SQL types to Python types.
-    Based on the types supported by the Databricks SDK.
+    Based on the Thrift TTypeId types after normalization.
     """

     # SQL type to conversion function mapping
     # TODO: complex types
     TYPE_MAPPING: Dict[str, Callable] = {
         # Numeric types
-        SqlType.BYTE: lambda v: int(v),
-        SqlType.SHORT: lambda v: int(v),
+        SqlType.TINYINT: lambda v: int(v),
+        SqlType.SMALLINT: lambda v: int(v),
         SqlType.INT: lambda v: int(v),
-        SqlType.LONG: lambda v: int(v),
+        SqlType.BIGINT: lambda v: int(v),
         SqlType.FLOAT: lambda v: float(v),
         SqlType.DOUBLE: lambda v: float(v),
         SqlType.DECIMAL: _convert_decimal,
@@ -112,30 +117,34 @@ class SqlTypeConverter:
         # Date/Time types
         SqlType.DATE: lambda v: datetime.date.fromisoformat(v),
         SqlType.TIMESTAMP: lambda v: parser.parse(v),
-        SqlType.INTERVAL: lambda v: v,  # Keep as string for now
+        SqlType.INTERVAL_YEAR_MONTH: lambda v: v,  # Keep as string for now
+        SqlType.INTERVAL_DAY_TIME: lambda v: v,  # Keep as string for now
         # String types - no conversion needed
         SqlType.CHAR: lambda v: v,
+        SqlType.VARCHAR: lambda v: v,
         SqlType.STRING: lambda v: v,
         # Binary type
         SqlType.BINARY: lambda v: bytes.fromhex(v),
         # Other types
         SqlType.NULL: lambda v: None,
         # Complex types and user-defined types return as-is
-        SqlType.USER_DEFINED_TYPE: lambda v: v,
+        SqlType.USER_DEFINED: lambda v: v,
     }

     @staticmethod
     def convert_value(
         value: str,
         sql_type: str,
+        column_name: Optional[str],
         **kwargs,
     ) -> object:
         """
         Convert a string value to the appropriate Python type based on SQL type.

         Args:
             value: The string value to convert
-            sql_type: The SQL type (e.g., 'int', 'decimal')
+            sql_type: The SQL type (e.g., 'tinyint', 'decimal')
+            column_name: The name of the column being converted
             **kwargs: Additional keyword arguments for the conversion function

         Returns:
@@ -155,6 +164,10 @@ def convert_value(
                 return converter_func(value, precision, scale)
             else:
                 return converter_func(value)
-        except (ValueError, TypeError, decimal.InvalidOperation) as e:
-            logger.warning(f"Error converting value '{value}' to {sql_type}: {e}")
+        except Exception as e:
+            warning_message = f"Error converting value '{value}' to {sql_type}"
+            if column_name:
+                warning_message += f" in column {column_name}"
+            warning_message += f": {e}"
+            logger.warning(warning_message)
             return value
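
A rough sketch of the resulting behavior (inputs are made up; the import path follows the file location above):

from databricks.sql.backend.sea.utils.conversion import SqlTypeConverter

# Well-formed values convert to native Python types.
SqlTypeConverter.convert_value("42", "bigint", column_name="id")      # -> 42
SqlTypeConverter.convert_value("1.5", "double", column_name="ratio")  # -> 1.5
SqlTypeConverter.convert_value(
    "12.34", "decimal", column_name="amount", precision=10, scale=2
)  # -> a decimal.Decimal value

# A malformed value no longer raises: the broad except above logs a warning
# that names the column, then falls back to the raw string.
SqlTypeConverter.convert_value("oops", "int", column_name="age")  # -> "oops"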
src/databricks/sql/backend/sea/utils/normalize.py

Lines changed: 50 additions & 0 deletions
@@ -0,0 +1,50 @@
+"""
+Type normalization utilities for SEA backend.
+
+This module provides functionality to normalize SEA type names to match
+Thrift type naming conventions.
+"""
+
+from typing import Dict, Any
+
+# SEA types that need to be translated to Thrift types
+# The list of all SEA types is available in the REST reference at:
+# https://docs.databricks.com/api/workspace/statementexecution/executestatement
+# The list of all Thrift types can be found in the ttypes.TTypeId definition
+# The SEA types that do not align with Thrift are explicitly mapped below
+SEA_TO_THRIFT_TYPE_MAP = {
+    "BYTE": "TINYINT",
+    "SHORT": "SMALLINT",
+    "LONG": "BIGINT",
+    "INTERVAL": "INTERVAL",  # Default mapping, will be overridden if type_interval_type is present
+}
+
+
+def normalize_sea_type_to_thrift(type_name: str, col_data: Dict[str, Any]) -> str:
+    """
+    Normalize SEA type names to match Thrift type naming conventions.
+
+    Args:
+        type_name: The type name from SEA (e.g., "BYTE", "LONG", "INTERVAL")
+        col_data: The full column data dictionary from manifest (for accessing type_interval_type)
+
+    Returns:
+        Normalized type name matching Thrift conventions
+    """
+    # Early return if type doesn't need mapping
+    if type_name not in SEA_TO_THRIFT_TYPE_MAP:
+        return type_name
+
+    normalized_type = SEA_TO_THRIFT_TYPE_MAP[type_name]
+
+    # Special handling for interval types
+    if type_name == "INTERVAL":
+        type_interval_type = col_data.get("type_interval_type")
+        if type_interval_type:
+            return (
+                "INTERVAL_YEAR_MONTH"
+                if any(t in type_interval_type.upper() for t in ["YEAR", "MONTH"])
+                else "INTERVAL_DAY_TIME"
+            )
+
+    return normalized_type
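
Illustrative calls (the col_data dicts are hypothetical manifest entries, and the type_interval_type strings are assumed formats):

from databricks.sql.backend.sea.utils.normalize import normalize_sea_type_to_thrift

normalize_sea_type_to_thrift("LONG", {})    # -> "BIGINT"
normalize_sea_type_to_thrift("STRING", {})  # -> "STRING" (already aligned with Thrift)

# Interval types are split on the qualifier when the manifest provides one.
normalize_sea_type_to_thrift("INTERVAL", {"type_interval_type": "YEAR TO MONTH"})  # -> "INTERVAL_YEAR_MONTH"
normalize_sea_type_to_thrift("INTERVAL", {"type_interval_type": "DAY TO SECOND"})  # -> "INTERVAL_DAY_TIME"
normalize_sea_type_to_thrift("INTERVAL", {})                                       # -> "INTERVAL" (no qualifier present)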

src/databricks/sql/client.py

Lines changed: 7 additions & 5 deletions
@@ -248,11 +248,8 @@ def read(self) -> Optional[OAuthToken]:
         self.lz4_compression = kwargs.get("enable_query_result_lz4_compression", True)
         self.use_cloud_fetch = kwargs.get("use_cloud_fetch", True)
         self._cursors = []  # type: List[Cursor]
-
-        self.server_telemetry_enabled = True
-        self.client_telemetry_enabled = kwargs.get("enable_telemetry", False)
-        self.telemetry_enabled = (
-            self.client_telemetry_enabled and self.server_telemetry_enabled
+        self.telemetry_batch_size = kwargs.get(
+            "telemetry_batch_size", TelemetryClientFactory.DEFAULT_BATCH_SIZE
         )

         try:
@@ -285,11 +282,16 @@ def read(self) -> Optional[OAuthToken]:
         )
         self.staging_allowed_local_path = kwargs.get("staging_allowed_local_path", None)

+        self.force_enable_telemetry = kwargs.get("force_enable_telemetry", False)
+        self.enable_telemetry = kwargs.get("enable_telemetry", False)
+        self.telemetry_enabled = TelemetryHelper.is_telemetry_enabled(self)
+
         TelemetryClientFactory.initialize_telemetry_client(
             telemetry_enabled=self.telemetry_enabled,
             session_id_hex=self.get_session_id_hex(),
             auth_provider=self.session.auth_provider,
             host_url=self.session.host,
+            batch_size=self.telemetry_batch_size,
         )

         self._telemetry_client = TelemetryClientFactory.get_telemetry_client(
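
A hedged sketch of the call site: the new kwargs flow through sql.connect into this constructor. Connection details are placeholders, and the exact gating lives in TelemetryHelper.is_telemetry_enabled, which is not shown in this diff.

from databricks import sql

connection = sql.connect(
    server_hostname="example.cloud.databricks.com",  # placeholder
    http_path="/sql/1.0/warehouses/abc123",          # placeholder
    access_token="dapi-example-token",               # placeholder
    enable_telemetry=True,         # client-side opt-in read above
    force_enable_telemetry=False,  # presumably overrides the server-side check when True
    telemetry_batch_size=100,      # arbitrary example; defaults to TelemetryClientFactory.DEFAULT_BATCH_SIZE
)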

src/databricks/sql/cloudfetch/downloader.py

Lines changed: 37 additions & 0 deletions
@@ -54,12 +54,14 @@ class DownloadableResultSettings:
         link_expiry_buffer_secs (int): Time in seconds to prevent download of a link before it expires. Default 0 secs.
         download_timeout (int): Timeout for download requests. Default 60 secs.
         max_consecutive_file_download_retries (int): Number of consecutive download retries before shutting down.
+        min_cloudfetch_download_speed (float): Threshold in MB/s below which to log warning. Default 0.1 MB/s.
     """

     is_lz4_compressed: bool
     link_expiry_buffer_secs: int = 0
     download_timeout: int = 60
     max_consecutive_file_download_retries: int = 0
+    min_cloudfetch_download_speed: float = 0.1


 class ResultSetDownloadHandler:
@@ -100,6 +102,8 @@ def run(self) -> DownloadedFile:
             self.link, self.settings.link_expiry_buffer_secs
         )

+        start_time = time.time()
+
         with self._http_client.execute(
             method=HttpMethod.GET,
             url=self.link.fileLink,
@@ -112,6 +116,13 @@ def run(self) -> DownloadedFile:

             # Save (and decompress if needed) the downloaded file
             compressed_data = response.content
+
+            # Log download metrics
+            download_duration = time.time() - start_time
+            self._log_download_metrics(
+                self.link.fileLink, len(compressed_data), download_duration
+            )
+
             decompressed_data = (
                 ResultSetDownloadHandler._decompress_data(compressed_data)
                 if self.settings.is_lz4_compressed
@@ -138,6 +149,32 @@ def run(self) -> DownloadedFile:
             self.link.rowCount,
         )

+    def _log_download_metrics(
+        self, url: str, bytes_downloaded: int, duration_seconds: float
+    ):
+        """Log download speed metrics at INFO/WARN levels."""
+        # Calculate speed in MB/s (ensure float division for precision)
+        speed_mbps = (float(bytes_downloaded) / (1024 * 1024)) / duration_seconds
+
+        urlEndpoint = url.split("?")[0]
+        # INFO level logging
+        logger.info(
+            "CloudFetch download completed: %.4f MB/s, %d bytes in %.3fs from %s",
+            speed_mbps,
+            bytes_downloaded,
+            duration_seconds,
+            urlEndpoint,
+        )
+
+        # WARN level logging if below threshold
+        if speed_mbps < self.settings.min_cloudfetch_download_speed:
+            logger.warning(
+                "CloudFetch download slower than threshold: %.4f MB/s (threshold: %.1f MB/s) from %s",
+                speed_mbps,
+                self.settings.min_cloudfetch_download_speed,
+                url,
+            )
+
     @staticmethod
     def _validate_link(link: TSparkArrowResultLink, expiry_buffer_secs: int):
         """
