Commit a618115

Merge branch 'sea-hybrid' into sea-decouple-link-fetch
2 parents: d038d84 + 2701e5d

10 files changed: +253 additions, -81 deletions


src/databricks/sql/backend/databricks_client.py

Lines changed: 1 addition & 9 deletions
@@ -3,8 +3,6 @@
 from abc import ABC, abstractmethod
 from typing import Dict, List, Optional, Any, Union, TYPE_CHECKING
 
-from databricks.sql.types import SSLOptions
-
 if TYPE_CHECKING:
     from databricks.sql.client import Cursor
     from databricks.sql.result_set import ResultSet
@@ -24,13 +22,6 @@ class DatabricksClient(ABC):
     - Fetching metadata about catalogs, schemas, tables, and columns
     """
 
-    def __init__(self, ssl_options: SSLOptions, **kwargs):
-        self._use_arrow_native_complex_types = kwargs.get(
-            "_use_arrow_native_complex_types", True
-        )
-        self._max_download_threads = kwargs.get("max_download_threads", 10)
-        self._ssl_options = ssl_options
-
     # == Connection and Session Management ==
     @abstractmethod
     def open_session(
@@ -110,6 +101,7 @@ def execute_command(
             parameters: List of parameters to bind to the query
             async_op: Whether to execute the command asynchronously
             enforce_embedded_schema_correctness: Whether to enforce schema correctness
+            row_limit: Maximum number of rows in the response.
 
         Returns:
             If async_op is False, returns a ResultSet object containing the
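
Taken together, these hunks turn DatabricksClient into a pure interface: the SSL options, download-thread count, and Arrow complex-type flag are no longer set in the abstract base, so each concrete backend initializes them itself (see the sea/backend.py and thrift_backend.py hunks below). A minimal sketch of the resulting shape, using illustrative stand-in classes rather than the connector's real ones:

from abc import ABC, abstractmethod


class Client(ABC):
    """Pure interface: no __init__, no transport state (stand-in for DatabricksClient)."""

    @abstractmethod
    def open_session(self):
        ...


class ConcreteBackend(Client):
    """Stand-in for a concrete backend such as the SEA or Thrift client."""

    def __init__(self, ssl_options, **kwargs):
        # Each backend now owns its transport knobs instead of inheriting
        # them from the abstract base via super().__init__().
        self._ssl_options = ssl_options
        self._max_download_threads = kwargs.get("max_download_threads", 10)
        self._use_arrow_native_complex_types = kwargs.get(
            "_use_arrow_native_complex_types", True
        )

    def open_session(self):
        return "session-id"


backend = ConcreteBackend(ssl_options=None, max_download_threads=4)
print(backend._max_download_threads)  # 4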

src/databricks/sql/backend/sea/backend.py

Lines changed: 18 additions & 12 deletions
@@ -124,15 +124,19 @@ def __init__(
             http_path,
         )
 
-        super().__init__(ssl_options=ssl_options, **kwargs)
+        self._max_download_threads = kwargs.get("max_download_threads", 10)
+        self._ssl_options = ssl_options
+        self._use_arrow_native_complex_types = kwargs.get(
+            "_use_arrow_native_complex_types", True
+        )
 
         self.use_hybrid_disposition = kwargs.get("use_hybrid_disposition", True)
 
         # Extract warehouse ID from http_path
         self.warehouse_id = self._extract_warehouse_id(http_path)
 
         # Initialize HTTP client
-        self.http_client = SeaHttpClient(
+        self._http_client = SeaHttpClient(
            server_hostname=server_hostname,
            port=port,
            http_path=http_path,
@@ -224,7 +228,7 @@ def open_session(
             schema=schema,
         )
 
-        response = self.http_client._make_request(
+        response = self._http_client._make_request(
            method="POST", path=self.SESSION_PATH, data=request_data.to_dict()
        )
 
@@ -264,7 +268,7 @@ def close_session(self, session_id: SessionId) -> None:
             session_id=sea_session_id,
         )
 
-        self.http_client._make_request(
+        self._http_client._make_request(
            method="DELETE",
            path=self.SESSION_PATH_WITH_ID.format(sea_session_id),
            data=request_data.to_dict(),
@@ -443,7 +447,9 @@ def execute_command(
             sea_parameters.append(
                 StatementParameter(
                     name=param.name,
-                    value=param.value,
+                    value=(
+                        param.value.stringValue if param.value is not None else None
+                    ),
                     type=param.type,
                 )
             )
@@ -477,7 +483,7 @@ def execute_command(
             result_compression=result_compression,
         )
 
-        response_data = self.http_client._make_request(
+        response_data = self._http_client._make_request(
            method="POST", path=self.STATEMENT_PATH, data=request.to_dict()
        )
        response = ExecuteStatementResponse.from_dict(response_data)
@@ -522,7 +528,7 @@ def cancel_command(self, command_id: CommandId) -> None:
             raise ValueError("Not a valid SEA command ID")
 
         request = CancelStatementRequest(statement_id=sea_statement_id)
-        self.http_client._make_request(
+        self._http_client._make_request(
            method="POST",
            path=self.CANCEL_STATEMENT_PATH_WITH_ID.format(sea_statement_id),
            data=request.to_dict(),
@@ -547,7 +553,7 @@ def close_command(self, command_id: CommandId) -> None:
             raise ValueError("Not a valid SEA command ID")
 
         request = CloseStatementRequest(statement_id=sea_statement_id)
-        self.http_client._make_request(
+        self._http_client._make_request(
            method="DELETE",
            path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id),
            data=request.to_dict(),
@@ -575,7 +581,7 @@ def get_query_state(self, command_id: CommandId) -> CommandState:
             raise ValueError("Not a valid SEA command ID")
 
         request = GetStatementRequest(statement_id=sea_statement_id)
-        response_data = self.http_client._make_request(
+        response_data = self._http_client._make_request(
            method="GET",
            path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id),
            data=request.to_dict(),
@@ -615,7 +621,7 @@ def get_execution_result(
         request = GetStatementRequest(statement_id=sea_statement_id)
 
         # Get the statement result
-        response_data = self.http_client._make_request(
+        response_data = self._http_client._make_request(
            method="GET",
            path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id),
            data=request.to_dict(),
@@ -649,13 +655,13 @@ def get_chunk_links(
             ExternalLink: External link for the chunk
         """
 
-        response_data = self.http_client._make_request(
+        response_data = self._http_client._make_request(
            method="GET",
            path=self.CHUNK_PATH_WITH_ID_AND_INDEX.format(statement_id, chunk_index),
        )
        response = GetChunksResponse.from_dict(response_data)
 
-        links = response.external_links
+        links = response.external_links or []
        return links
 
    # == Metadata Operations ==
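
The last hunk stops get_chunk_links from handing a None value back to callers when the response carries no external_links. A small self-contained sketch of that normalization, using stand-in types rather than the SEA client's real classes:

from dataclasses import dataclass
from typing import List, Optional


@dataclass
class ChunkResponse:
    """Stand-in for GetChunksResponse; external_links may be absent."""

    external_links: Optional[List[str]] = None


def chunk_links(response: ChunkResponse) -> List[str]:
    # Normalize a missing field to an empty list so callers can iterate
    # without a None check (mirrors `response.external_links or []` above).
    return response.external_links or []


print(chunk_links(ChunkResponse()))                       # []
print(chunk_links(ChunkResponse(external_links=["u1"])))  # ['u1']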

src/databricks/sql/backend/sea/models/responses.py

Lines changed: 25 additions & 26 deletions
@@ -5,7 +5,7 @@
 """
 
 import base64
-from typing import Dict, Any, List
+from typing import Dict, Any, List, Optional
 from dataclasses import dataclass
 
 from databricks.sql.backend.types import CommandState
@@ -165,34 +165,33 @@ def from_dict(cls, data: Dict[str, Any]) -> "CreateSessionResponse":
 
 @dataclass
 class GetChunksResponse:
-    """Response from getting chunks for a statement."""
-
-    statement_id: str
-    external_links: List[ExternalLink]
+    """
+    Response from getting chunks for a statement.
+
+    The response model can be found in the docs, here:
+    https://docs.databricks.com/api/workspace/statementexecution/getstatementresultchunkn
+    """
+
+    data: Optional[List[List[Any]]] = None
+    external_links: Optional[List[ExternalLink]] = None
+    byte_count: Optional[int] = None
+    chunk_index: Optional[int] = None
+    next_chunk_index: Optional[int] = None
+    next_chunk_internal_link: Optional[str] = None
+    row_count: Optional[int] = None
+    row_offset: Optional[int] = None
 
     @classmethod
     def from_dict(cls, data: Dict[str, Any]) -> "GetChunksResponse":
         """Create a GetChunksResponse from a dictionary."""
-        external_links = []
-        if "external_links" in data:
-            for link_data in data["external_links"]:
-                external_links.append(
-                    ExternalLink(
-                        external_link=link_data.get("external_link", ""),
-                        expiration=link_data.get("expiration", ""),
-                        chunk_index=link_data.get("chunk_index", 0),
-                        byte_count=link_data.get("byte_count", 0),
-                        row_count=link_data.get("row_count", 0),
-                        row_offset=link_data.get("row_offset", 0),
-                        next_chunk_index=link_data.get("next_chunk_index"),
-                        next_chunk_internal_link=link_data.get(
-                            "next_chunk_internal_link"
-                        ),
-                        http_headers=link_data.get("http_headers"),
-                    )
-                )
-
+        result = _parse_result({"result": data})
         return cls(
-            statement_id=data.get("statement_id", ""),
-            external_links=external_links,
+            data=result.data,
+            external_links=result.external_links,
+            byte_count=result.byte_count,
+            chunk_index=result.chunk_index,
+            next_chunk_index=result.next_chunk_index,
+            next_chunk_internal_link=result.next_chunk_internal_link,
+            row_count=result.row_count,
+            row_offset=result.row_offset,
        )
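
GetChunksResponse is now a collection of optional fields that mirrors the chunk payload of the Statement Execution API, with parsing delegated to the shared _parse_result helper. A self-contained sketch of the same "all-optional dataclass plus from_dict" pattern, with a hypothetical ChunkInfo stand-in instead of the real model:

from dataclasses import dataclass, fields
from typing import Any, Dict, Optional


@dataclass
class ChunkInfo:
    """Hypothetical stand-in for the all-optional response model."""

    byte_count: Optional[int] = None
    chunk_index: Optional[int] = None
    row_count: Optional[int] = None
    row_offset: Optional[int] = None

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "ChunkInfo":
        # Keys missing from the payload simply stay None, so partial
        # API responses parse without special-casing.
        names = {f.name for f in fields(cls)}
        return cls(**{k: v for k, v in data.items() if k in names})


print(ChunkInfo.from_dict({"row_count": 100, "chunk_index": 0}))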

src/databricks/sql/backend/sea/queue.py

Lines changed: 15 additions & 4 deletions
@@ -22,7 +22,7 @@
     ResultManifest,
 )
 from databricks.sql.backend.sea.utils.constants import ResultFormat
-from databricks.sql.exc import ProgrammingError
+from databricks.sql.exc import ProgrammingError, ServerOperationError
 from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink
 from databricks.sql.types import SSLOptions
 from databricks.sql.utils import (
@@ -83,7 +83,7 @@ def build_queue(
 
         # EXTERNAL_LINKS disposition
         return SeaCloudFetchQueue(
-            initial_links=result_data.external_links or [],
+            result_data=result_data,
            max_download_threads=max_download_threads,
            ssl_options=ssl_options,
            sea_client=sea_client,
@@ -117,6 +117,9 @@ def remaining_rows(self) -> List[List[str]]:
         self.cur_row_index += len(slice)
         return slice
 
+    def close(self):
+        return
+
 
 class LinkFetcher:
     def __init__(
@@ -218,7 +221,7 @@ class SeaCloudFetchQueue(CloudFetchQueue):
 
     def __init__(
         self,
-        initial_links: List["ExternalLink"],
+        result_data: ResultData,
        max_download_threads: int,
        ssl_options: SSLOptions,
        sea_client: "SeaDatabricksClient",
@@ -252,14 +255,22 @@ def __init__(
 
         self._sea_client = sea_client
         self._statement_id = statement_id
+        self._total_chunk_count = total_chunk_count
 
         logger.debug(
             "SeaCloudFetchQueue: Initialize CloudFetch loader for statement {}, total chunks: {}".format(
                 statement_id, total_chunk_count
             )
         )
 
-        if total_chunk_count < 1:
+        initial_links = result_data.external_links or []
+        self._chunk_index_to_link = {link.chunk_index: link for link in initial_links}
+
+        # Track the current chunk we're processing
+        self._current_chunk_index = 0
+        first_link = self._chunk_index_to_link.get(self._current_chunk_index, None)
+        if not first_link:
+            # possibly an empty response
             return
 
         self.current_chunk_index = 0
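
SeaCloudFetchQueue now receives the whole ResultData, builds a chunk_index to link map from whatever links came back with the statement, and treats a missing link for chunk 0 as an empty result. A rough sketch of that bookkeeping with stand-in types (the real queue would fall back to the backend's get_chunk_links for indexes it does not hold locally):

from dataclasses import dataclass
from typing import Dict, List, Optional


@dataclass
class Link:
    """Stand-in for ExternalLink."""

    chunk_index: int
    url: str


class ChunkTracker:
    """Illustrative version of the chunk_index -> link bookkeeping."""

    def __init__(self, initial_links: Optional[List[Link]]):
        links = initial_links or []  # an empty result carries no links
        self._chunk_index_to_link: Dict[int, Link] = {
            link.chunk_index: link for link in links
        }
        self._current_chunk_index = 0

    def current_link(self) -> Optional[Link]:
        # None means the chunk is not known locally yet; the real queue
        # would ask the backend for it via get_chunk_links.
        return self._chunk_index_to_link.get(self._current_chunk_index)


print(ChunkTracker([Link(0, "https://example.com/chunk-0")]).current_link())
print(ChunkTracker([]).current_link())  # None: treat as an empty response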

src/databricks/sql/backend/thrift_backend.py

Lines changed: 7 additions & 2 deletions
@@ -149,8 +149,6 @@ def __init__(
             http_path,
         )
 
-        super().__init__(ssl_options, **kwargs)
-
         port = port or 443
         if kwargs.get("_connection_uri"):
             uri = kwargs.get("_connection_uri")
@@ -164,13 +162,20 @@
             raise ValueError("No valid connection settings.")
 
         self._initialize_retry_args(kwargs)
+        self._use_arrow_native_complex_types = kwargs.get(
+            "_use_arrow_native_complex_types", True
+        )
 
         self._use_arrow_native_decimals = kwargs.get("_use_arrow_native_decimals", True)
         self._use_arrow_native_timestamps = kwargs.get(
             "_use_arrow_native_timestamps", True
         )
 
         # Cloud fetch
+        self._max_download_threads = kwargs.get("max_download_threads", 10)
+
+        self._ssl_options = ssl_options
+
         self._auth_provider = auth_provider
 
         # Connector version 3 retry approach

src/databricks/sql/result_set.py

Lines changed: 1 addition & 0 deletions
@@ -169,6 +169,7 @@ def close(self) -> None:
         been closed on the server for some other reason, issue a request to the server to close it.
         """
         try:
+            self.results.close()
             if (
                 self.status != CommandState.CLOSED
                 and not self.has_been_closed_server_side
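
With every queue implementation now exposing close(), ResultSet.close() can release client-side resources first and only then perform the server-side cleanup it already did. A simplified order-of-operations sketch with stand-in classes (the real method also checks CommandState and whether the connection is still open):

class FakeQueue:
    """Stand-in for a ResultSetQueue with the new close() hook."""

    def close(self):
        print("queue: release local buffers and download threads")


class FakeResultSet:
    """Illustrative close() ordering only; not databricks.sql.result_set."""

    def __init__(self):
        self.results = FakeQueue()
        self.has_been_closed_server_side = False

    def close(self):
        try:
            # 1) release client-side resources first (the new line above)
            self.results.close()
            # 2) then close the command on the server, as before
            if not self.has_been_closed_server_side:
                print("backend: close_command(...)")
        finally:
            self.has_been_closed_server_side = True


FakeResultSet().close()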

src/databricks/sql/utils.py

Lines changed: 21 additions & 13 deletions
@@ -47,6 +47,10 @@ def next_n_rows(self, num_rows: int):
     def remaining_rows(self):
         pass
 
+    @abstractmethod
+    def close(self):
+        pass
+
 
 class ThriftResultSetQueueFactory(ABC):
     @staticmethod
@@ -159,6 +163,9 @@ def remaining_rows(self):
         self.cur_row_index += slice.num_rows
         return slice
 
+    def close(self):
+        return
+
 
 class ArrowQueue(ResultSetQueue):
     def __init__(
@@ -196,6 +203,9 @@ def remaining_rows(self) -> "pyarrow.Table":
         self.cur_row_index += slice.num_rows
         return slice
 
+    def close(self):
+        return
+
 
 class CloudFetchQueue(ResultSetQueue, ABC):
     """Base class for cloud fetch queues that handle EXTERNAL_LINKS disposition with ARROW format."""
@@ -230,7 +240,12 @@ def __init__(
         self.table_row_index = 0
 
         # Initialize download manager
-        self.download_manager: Optional["ResultFileDownloadManager"] = None
+        self.download_manager = ResultFileDownloadManager(
+            links=[],
+            max_download_threads=max_download_threads,
+            lz4_compressed=lz4_compressed,
+            ssl_options=ssl_options,
+        )
 
     def next_n_rows(self, num_rows: int) -> "pyarrow.Table":
         """
@@ -287,11 +302,8 @@ def remaining_rows(self) -> "pyarrow.Table":
 
     def _create_table_at_offset(self, offset: int) -> Union["pyarrow.Table", None]:
         """Create next table at the given row offset"""
-        # Create next table by retrieving the logical next downloaded file, or return None to signal end of queue
-        if not self.download_manager:
-            logger.debug("CloudFetchQueue: No download manager available")
-            return None
 
+        # Create next table by retrieving the logical next downloaded file, or return None to signal end of queue
         downloaded_file = self.download_manager.get_next_downloaded_file(offset)
         if not downloaded_file:
             logger.debug(
@@ -324,6 +336,9 @@ def _create_empty_table(self) -> "pyarrow.Table":
             return pyarrow.Table.from_pydict({})
         return create_arrow_table_from_arrow_file(self.schema_bytes, self.description)
 
+    def close(self):
+        self.download_manager._shutdown_manager()
+
 
 class ThriftCloudFetchQueue(CloudFetchQueue):
     """Queue implementation for EXTERNAL_LINKS disposition with ARROW format for Thrift backend."""
@@ -373,14 +388,7 @@ def __init__(
                     result_link.startRowOffset, result_link.rowCount
                 )
             )
-
-        # Initialize download manager
-        self.download_manager = ResultFileDownloadManager(
-            links=self.result_links,
-            max_download_threads=self.max_download_threads,
-            lz4_compressed=self.lz4_compressed,
-            ssl_options=self._ssl_options,
-        )
+            self.download_manager.add_link(result_link)
 
         # Initialize table and position
         self.table = self._create_next_table()
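
The net effect in utils.py is that a CloudFetchQueue always owns a download manager: it is created eagerly with an empty link list, ThriftCloudFetchQueue feeds links in through add_link, and close() shuts it down, which is what makes the unconditional results.close() call above safe. A compact lifecycle sketch with stand-in classes (DownloadManager here is illustrative, not ResultFileDownloadManager):

from typing import List


class DownloadManager:
    """Illustrative stand-in for ResultFileDownloadManager."""

    def __init__(self, links: List[str]):
        self.links = list(links)

    def add_link(self, link: str) -> None:
        self.links.append(link)

    def shutdown(self) -> None:
        print(f"shutting down, {len(self.links)} link(s) tracked")


class Queue:
    """Sketch of the new queue lifecycle."""

    def __init__(self):
        # Created eagerly with no links, so later code never has to ask
        # whether a download manager exists.
        self.download_manager = DownloadManager(links=[])

    def add_result_link(self, link: str) -> None:
        self.download_manager.add_link(link)

    def close(self) -> None:
        self.download_manager.shutdown()


q = Queue()
q.add_result_link("https://example.com/part-0")
q.close()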
