Commit ea8ae9f

Merge branch 'main' into robust-metadata-sea
2 parents: 11f2c54 + 59d28b0

11 files changed: 260 additions, 53 deletions

src/databricks/sql/backend/sea/utils/filters.py

Lines changed: 6 additions & 30 deletions
@@ -70,43 +70,19 @@ def _create_filtered_manifest(result_set: SeaResultSet, new_row_count: int):
             result_set: Original result set to copy manifest from
             new_row_count: New total row count for filtered data
 
-        Returns:
-            Updated manifest copy
-        """
-        filtered_manifest = deepcopy(result_set.manifest)
-        filtered_manifest.total_row_count = new_row_count
-        return filtered_manifest
-
-    @staticmethod
-    def _create_filtered_result_set(
-        result_set: SeaResultSet,
-        result_data: ResultData,
-        row_count: int,
-    ) -> "SeaResultSet":
-        """
-        Create a new filtered SeaResultSet with the provided data.
-
-        Args:
-            result_set: Original result set to copy parameters from
-            result_data: New result data for the filtered set
-            row_count: Number of rows in the filtered data
-
-        Returns:
-            New filtered SeaResultSet
-        """
+        from databricks.sql.backend.sea.backend import SeaDatabricksClient
         from databricks.sql.backend.sea.result_set import SeaResultSet
 
-        execute_response = ResultSetFilter._create_execute_response(result_set)
-        filtered_manifest = ResultSetFilter._create_filtered_manifest(
-            result_set, row_count
-        )
+        # Create a new SeaResultSet with the filtered data
+        manifest = result_set.manifest
+        manifest.total_row_count = len(filtered_rows)
 
-        return SeaResultSet(
+        filtered_result_set = SeaResultSet(
             connection=result_set.connection,
             execute_response=execute_response,
             sea_client=cast(SeaDatabricksClient, result_set.backend),
             result_data=result_data,
-            manifest=filtered_manifest,
+            manifest=manifest,
             buffer_size_bytes=result_set.buffer_size_bytes,
             arraysize=result_set.arraysize,
         )
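
The net effect of this hunk: the filtered result set now reuses the original manifest, mutated in place, instead of going through deepcopy and two helper methods. A minimal sketch of that trade-off, using a hypothetical Manifest stand-in rather than the real SEA classes:

# Sketch only: Manifest is a stand-in, not the SEA manifest class.
from copy import deepcopy
from dataclasses import dataclass


@dataclass
class Manifest:
    total_row_count: int


def filtered_manifest_copy(manifest: Manifest, new_count: int) -> Manifest:
    updated = deepcopy(manifest)          # old approach: source stays untouched
    updated.total_row_count = new_count
    return updated


def filtered_manifest_in_place(manifest: Manifest, new_count: int) -> Manifest:
    manifest.total_row_count = new_count  # new approach: mutate and reuse
    return manifest


m = Manifest(total_row_count=100)
assert filtered_manifest_copy(m, 10).total_row_count == 10
assert m.total_row_count == 100           # copy left the original intact
assert filtered_manifest_in_place(m, 10) is m
assert m.total_row_count == 10            # in-place update changed the original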

src/databricks/sql/backend/thrift_backend.py

Lines changed: 0 additions & 5 deletions
@@ -1040,7 +1040,6 @@ def execute_command(
             max_download_threads=self.max_download_threads,
             ssl_options=self._ssl_options,
             has_more_rows=has_more_rows,
-            session_id_hex=self._session_id_hex,
         )
 
     def get_catalogs(
@@ -1079,7 +1078,6 @@ def get_catalogs(
             max_download_threads=self.max_download_threads,
             ssl_options=self._ssl_options,
             has_more_rows=has_more_rows,
-            session_id_hex=self._session_id_hex,
         )
 
     def get_schemas(
@@ -1124,7 +1122,6 @@ def get_schemas(
             max_download_threads=self.max_download_threads,
             ssl_options=self._ssl_options,
             has_more_rows=has_more_rows,
-            session_id_hex=self._session_id_hex,
         )
 
     def get_tables(
@@ -1173,7 +1170,6 @@ def get_tables(
             max_download_threads=self.max_download_threads,
             ssl_options=self._ssl_options,
             has_more_rows=has_more_rows,
-            session_id_hex=self._session_id_hex,
         )
 
     def get_columns(
@@ -1222,7 +1218,6 @@ def get_columns(
             max_download_threads=self.max_download_threads,
             ssl_options=self._ssl_options,
             has_more_rows=has_more_rows,
-            session_id_hex=self._session_id_hex,
         )
 
     def _handle_execute_response(self, resp, cursor):

src/databricks/sql/cloudfetch/downloader.py

Lines changed: 37 additions & 0 deletions
@@ -54,12 +54,14 @@ class DownloadableResultSettings:
         link_expiry_buffer_secs (int): Time in seconds to prevent download of a link before it expires. Default 0 secs.
         download_timeout (int): Timeout for download requests. Default 60 secs.
         max_consecutive_file_download_retries (int): Number of consecutive download retries before shutting down.
+        min_cloudfetch_download_speed (float): Threshold in MB/s below which to log warning. Default 0.1 MB/s.
     """
 
     is_lz4_compressed: bool
     link_expiry_buffer_secs: int = 0
    download_timeout: int = 60
     max_consecutive_file_download_retries: int = 0
+    min_cloudfetch_download_speed: float = 0.1
 
 
 class ResultSetDownloadHandler:
@@ -100,6 +102,8 @@ def run(self) -> DownloadedFile:
             self.link, self.settings.link_expiry_buffer_secs
         )
 
+        start_time = time.time()
+
         with self._http_client.execute(
             method=HttpMethod.GET,
             url=self.link.fileLink,
@@ -112,6 +116,13 @@
 
         # Save (and decompress if needed) the downloaded file
         compressed_data = response.content
+
+        # Log download metrics
+        download_duration = time.time() - start_time
+        self._log_download_metrics(
+            self.link.fileLink, len(compressed_data), download_duration
+        )
+
         decompressed_data = (
             ResultSetDownloadHandler._decompress_data(compressed_data)
             if self.settings.is_lz4_compressed
@@ -138,6 +149,32 @@
             self.link.rowCount,
         )
 
+    def _log_download_metrics(
+        self, url: str, bytes_downloaded: int, duration_seconds: float
+    ):
+        """Log download speed metrics at INFO/WARN levels."""
+        # Calculate speed in MB/s (ensure float division for precision)
+        speed_mbps = (float(bytes_downloaded) / (1024 * 1024)) / duration_seconds
+
+        urlEndpoint = url.split("?")[0]
+        # INFO level logging
+        logger.info(
+            "CloudFetch download completed: %.4f MB/s, %d bytes in %.3fs from %s",
+            speed_mbps,
+            bytes_downloaded,
+            duration_seconds,
+            urlEndpoint,
+        )
+
+        # WARN level logging if below threshold
+        if speed_mbps < self.settings.min_cloudfetch_download_speed:
+            logger.warning(
+                "CloudFetch download slower than threshold: %.4f MB/s (threshold: %.1f MB/s) from %s",
+                speed_mbps,
+                self.settings.min_cloudfetch_download_speed,
+                url,
+            )
+
     @staticmethod
     def _validate_link(link: TSparkArrowResultLink, expiry_buffer_secs: int):
         """

src/databricks/sql/common/http.py

Lines changed: 70 additions & 1 deletion
@@ -5,8 +5,10 @@
 import threading
 from dataclasses import dataclass
 from contextlib import contextmanager
-from typing import Generator
+from typing import Generator, Optional
 import logging
+from requests.adapters import HTTPAdapter
+from databricks.sql.auth.retry import DatabricksRetryPolicy, CommandType
 
 logger = logging.getLogger(__name__)
 
@@ -81,3 +83,70 @@ def execute(
 
     def close(self):
         self.session.close()
+
+
+class TelemetryHTTPAdapter(HTTPAdapter):
+    """
+    Custom HTTP adapter to prepare our DatabricksRetryPolicy before each request.
+    This ensures the retry timer is started and the command type is set correctly,
+    allowing the policy to manage its state for the duration of the request retries.
+    """
+
+    def send(self, request, **kwargs):
+        self.max_retries.command_type = CommandType.OTHER
+        self.max_retries.start_retry_timer()
+        return super().send(request, **kwargs)
+
+
+class TelemetryHttpClient:  # TODO: Unify all the http clients in the PySQL Connector
+    """Singleton HTTP client for sending telemetry data."""
+
+    _instance: Optional["TelemetryHttpClient"] = None
+    _lock = threading.Lock()
+
+    TELEMETRY_RETRY_STOP_AFTER_ATTEMPTS_COUNT = 3
+    TELEMETRY_RETRY_DELAY_MIN = 1.0
+    TELEMETRY_RETRY_DELAY_MAX = 10.0
+    TELEMETRY_RETRY_STOP_AFTER_ATTEMPTS_DURATION = 30.0
+
+    def __init__(self):
+        """Initializes the session and mounts the custom retry adapter."""
+        retry_policy = DatabricksRetryPolicy(
+            delay_min=self.TELEMETRY_RETRY_DELAY_MIN,
+            delay_max=self.TELEMETRY_RETRY_DELAY_MAX,
+            stop_after_attempts_count=self.TELEMETRY_RETRY_STOP_AFTER_ATTEMPTS_COUNT,
+            stop_after_attempts_duration=self.TELEMETRY_RETRY_STOP_AFTER_ATTEMPTS_DURATION,
+            delay_default=1.0,
+            force_dangerous_codes=[],
+        )
+        adapter = TelemetryHTTPAdapter(max_retries=retry_policy)
+        self.session = requests.Session()
+        self.session.mount("https://", adapter)
+        self.session.mount("http://", adapter)
+
+    @classmethod
+    def get_instance(cls) -> "TelemetryHttpClient":
+        """Get the singleton instance of the TelemetryHttpClient."""
+        if cls._instance is None:
+            with cls._lock:
+                if cls._instance is None:
+                    logger.debug("Initializing singleton TelemetryHttpClient")
+                    cls._instance = TelemetryHttpClient()
+        return cls._instance
+
+    def post(self, url: str, **kwargs) -> requests.Response:
+        """
+        Executes a POST request using the configured session.
+
+        This is a blocking call intended to be run in a background thread.
+        """
+        logger.debug("Executing telemetry POST request to: %s", url)
+        return self.session.post(url, **kwargs)
+
+    def close(self):
+        """Closes the underlying requests.Session."""
+        logger.debug("Closing TelemetryHttpClient session.")
+        self.session.close()
+        # Clear the instance to allow for re-initialization if needed
+        with TelemetryHttpClient._lock:
+            TelemetryHttpClient._instance = None
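
A short usage sketch of the new client, assuming only what the diff defines; the endpoint URL and payload here are placeholders, not real Databricks values:

from databricks.sql.common.http import TelemetryHttpClient

client = TelemetryHttpClient.get_instance()           # lazily builds the session
assert client is TelemetryHttpClient.get_instance()   # double-checked locking singleton

response = client.post(
    "https://example.com/telemetry-ext",              # placeholder endpoint
    data='{"events": []}',                            # placeholder payload
    headers={"Content-Type": "application/json"},
)
client.close()  # closes the session and clears the singleton instance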

src/databricks/sql/exc.py

Lines changed: 2 additions & 2 deletions
@@ -2,8 +2,6 @@
 import logging
 
 logger = logging.getLogger(__name__)
-from databricks.sql.telemetry.telemetry_client import TelemetryClientFactory
-
 
 ### PEP-249 Mandated ###
 # https://peps.python.org/pep-0249/#exceptions
@@ -22,6 +20,8 @@ def __init__(
 
         error_name = self.__class__.__name__
         if session_id_hex:
+            from databricks.sql.telemetry.telemetry_client import TelemetryClientFactory
+
             telemetry_client = TelemetryClientFactory.get_telemetry_client(
                 session_id_hex
             )
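
Moving the import inside the `if session_id_hex:` branch is the standard way to break an import cycle between exc.py and telemetry_client.py: the name is resolved at call time, after both modules have finished loading. A two-file sketch of the pattern, with hypothetical modules a.py and b.py:

# --- a.py ---
import b                           # safe: b's import of a is deferred

def log_error(msg):
    print("error:", msg)

def fail(msg):
    b.handle(msg)

# --- b.py ---
def handle(msg):
    # Imported at call time, so the a -> b -> a cycle never bites at import time.
    from a import log_error
    log_error(msg)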

src/databricks/sql/result_set.py

Lines changed: 11 additions & 8 deletions
@@ -170,7 +170,11 @@ def close(self) -> None:
         been closed on the server for some other reason, issue a request to the server to close it.
         """
         try:
-            self.results.close()
+            if self.results is not None:
+                self.results.close()
+            else:
+                logger.warning("result set close: queue not initialized")
+
             if (
                 self.status != CommandState.CLOSED
                 and not self.has_been_closed_server_side
@@ -193,7 +197,6 @@ def __init__(
         connection: Connection,
         execute_response: ExecuteResponse,
         thrift_client: ThriftDatabricksClient,
-        session_id_hex: Optional[str],
         buffer_size_bytes: int = 104857600,
         arraysize: int = 10000,
         use_cloud_fetch: bool = True,
@@ -217,7 +220,7 @@
         :param ssl_options: SSL options for cloud fetch
         :param has_more_rows: Whether there are more rows to fetch
         """
-        self.num_downloaded_chunks = 0
+        self.num_chunks = 0
 
         # Initialize ThriftResultSet-specific attributes
         self._use_cloud_fetch = use_cloud_fetch
@@ -237,12 +240,12 @@
             lz4_compressed=execute_response.lz4_compressed,
             description=execute_response.description,
             ssl_options=ssl_options,
-            session_id_hex=session_id_hex,
+            session_id_hex=connection.get_session_id_hex(),
             statement_id=execute_response.command_id.to_hex_guid(),
-            chunk_id=self.num_downloaded_chunks,
+            chunk_id=self.num_chunks,
         )
         if t_row_set.resultLinks:
-            self.num_downloaded_chunks += len(t_row_set.resultLinks)
+            self.num_chunks += len(t_row_set.resultLinks)
 
         # Call parent constructor with common attributes
         super().__init__(
@@ -275,11 +278,11 @@ def _fill_results_buffer(self):
             arrow_schema_bytes=self._arrow_schema_bytes,
             description=self.description,
             use_cloud_fetch=self._use_cloud_fetch,
-            chunk_id=self.num_downloaded_chunks,
+            chunk_id=self.num_chunks,
         )
         self.results = results
         self.has_more_rows = has_more_rows
-        self.num_downloaded_chunks += result_links_count
+        self.num_chunks += result_links_count
 
     def _convert_columnar_table(self, table):
         column_names = [c[0] for c in self.description]
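
The new guard in close() covers the case where the fetch queue was never built: self.results stays None until results are fetched, so close() must not assume it exists. A minimal sketch with a hypothetical stand-in class:

import logging

logger = logging.getLogger(__name__)


class SketchResultSet:
    def __init__(self):
        self.results = None  # populated lazily when results are fetched

    def close(self):
        if self.results is not None:
            self.results.close()
        else:
            logger.warning("result set close: queue not initialized")


SketchResultSet().close()  # logs a warning instead of raising AttributeError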

src/databricks/sql/telemetry/telemetry_client.py

Lines changed: 4 additions & 2 deletions
@@ -1,9 +1,9 @@
 import threading
 import time
-import requests
 import logging
 from concurrent.futures import ThreadPoolExecutor
 from typing import Dict, Optional
+from databricks.sql.common.http import TelemetryHttpClient
 from databricks.sql.telemetry.models.event import (
     TelemetryEvent,
     DriverSystemConfiguration,
@@ -159,6 +159,7 @@ def __init__(
         self._driver_connection_params = None
         self._host_url = host_url
         self._executor = executor
+        self._http_client = TelemetryHttpClient.get_instance()
 
     def _export_event(self, event):
         """Add an event to the batch queue and flush if batch is full"""
@@ -207,7 +208,7 @@ def _send_telemetry(self, events):
         try:
             logger.debug("Submitting telemetry request to thread pool")
             future = self._executor.submit(
-                requests.post,
+                self._http_client.post,
                 url,
                 data=request.to_json(),
                 headers=headers,
@@ -433,6 +434,7 @@ def close(session_id_hex):
             )
             try:
                 TelemetryClientFactory._executor.shutdown(wait=True)
+                TelemetryHttpClient.close()
             except Exception as e:
                 logger.debug("Failed to shutdown thread pool executor: %s", e)
             TelemetryClientFactory._executor = None
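
Telemetry posts are handed to the factory's thread pool, so the driver's hot path never blocks on network I/O. A sketch of that fire-and-forget path, with stand-in names for the executor and client:

from concurrent.futures import Future, ThreadPoolExecutor


def send_telemetry(executor: ThreadPoolExecutor, http_client, url: str,
                   payload: str, headers: dict) -> Future:
    # submit() returns immediately; the retry-aware POST runs on a worker
    # thread, so the caller only waits if it explicitly calls future.result().
    return executor.submit(http_client.post, url, data=payload, headers=headers)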

tests/unit/test_client.py

Lines changed: 1 addition & 1 deletion
@@ -266,7 +266,7 @@ def test_negative_fetch_throws_exception(self):
         mock_backend.fetch_results.return_value = (Mock(), False, 0)
 
         result_set = ThriftResultSet(
-            Mock(), Mock(), mock_backend, session_id_hex=Mock()
+            Mock(), Mock(), mock_backend
         )
 
         with self.assertRaises(ValueError) as e:
