Skip to content

Commit 4dd9434

Browse files
Merge branch 'ext-links-sea' into sea-hybrid
2 parents 0e1abfa + 28c6bb1 commit 4dd9434

File tree

10 files changed

+96
-143
lines changed

10 files changed

+96
-143
lines changed

src/databricks/sql/backend/databricks_client.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,6 @@
33
from abc import ABC, abstractmethod
44
from typing import Dict, List, Optional, Any, Union, TYPE_CHECKING
55

6-
from databricks.sql.types import SSLOptions
7-
86
if TYPE_CHECKING:
97
from databricks.sql.client import Cursor
108
from databricks.sql.result_set import ResultSet
@@ -24,13 +22,6 @@ class DatabricksClient(ABC):
2422
- Fetching metadata about catalogs, schemas, tables, and columns
2523
"""
2624

27-
def __init__(self, ssl_options: SSLOptions, **kwargs):
28-
self._use_arrow_native_complex_types = kwargs.get(
29-
"_use_arrow_native_complex_types", True
30-
)
31-
self._max_download_threads = kwargs.get("max_download_threads", 10)
32-
self._ssl_options = ssl_options
33-
3425
# == Connection and Session Management ==
3526
@abstractmethod
3627
def open_session(
@@ -110,6 +101,7 @@ def execute_command(
110101
parameters: List of parameters to bind to the query
111102
async_op: Whether to execute the command asynchronously
112103
enforce_embedded_schema_correctness: Whether to enforce schema correctness
104+
row_limit: Maximum number of rows in the response.
113105
114106
Returns:
115107
If async_op is False, returns a ResultSet object containing the

src/databricks/sql/backend/sea/backend.py

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -124,15 +124,19 @@ def __init__(
124124
http_path,
125125
)
126126

127-
super().__init__(ssl_options=ssl_options, **kwargs)
127+
self._max_download_threads = kwargs.get("max_download_threads", 10)
128+
self._ssl_options = ssl_options
129+
self._use_arrow_native_complex_types = kwargs.get(
130+
"_use_arrow_native_complex_types", True
131+
)
128132

129133
self.use_hybrid_disposition = kwargs.get("use_hybrid_disposition", True)
130134

131135
# Extract warehouse ID from http_path
132136
self.warehouse_id = self._extract_warehouse_id(http_path)
133137

134138
# Initialize HTTP client
135-
self.http_client = SeaHttpClient(
139+
self._http_client = SeaHttpClient(
136140
server_hostname=server_hostname,
137141
port=port,
138142
http_path=http_path,
@@ -224,7 +228,7 @@ def open_session(
224228
schema=schema,
225229
)
226230

227-
response = self.http_client._make_request(
231+
response = self._http_client._make_request(
228232
method="POST", path=self.SESSION_PATH, data=request_data.to_dict()
229233
)
230234

@@ -264,7 +268,7 @@ def close_session(self, session_id: SessionId) -> None:
264268
session_id=sea_session_id,
265269
)
266270

267-
self.http_client._make_request(
271+
self._http_client._make_request(
268272
method="DELETE",
269273
path=self.SESSION_PATH_WITH_ID.format(sea_session_id),
270274
data=request_data.to_dict(),
@@ -443,7 +447,9 @@ def execute_command(
443447
sea_parameters.append(
444448
StatementParameter(
445449
name=param.name,
446-
value=param.value,
450+
value=(
451+
param.value.stringValue if param.value is not None else None
452+
),
447453
type=param.type,
448454
)
449455
)
@@ -477,7 +483,7 @@ def execute_command(
477483
result_compression=result_compression,
478484
)
479485

480-
response_data = self.http_client._make_request(
486+
response_data = self._http_client._make_request(
481487
method="POST", path=self.STATEMENT_PATH, data=request.to_dict()
482488
)
483489
response = ExecuteStatementResponse.from_dict(response_data)
@@ -522,7 +528,7 @@ def cancel_command(self, command_id: CommandId) -> None:
522528
raise ValueError("Not a valid SEA command ID")
523529

524530
request = CancelStatementRequest(statement_id=sea_statement_id)
525-
self.http_client._make_request(
531+
self._http_client._make_request(
526532
method="POST",
527533
path=self.CANCEL_STATEMENT_PATH_WITH_ID.format(sea_statement_id),
528534
data=request.to_dict(),
@@ -547,7 +553,7 @@ def close_command(self, command_id: CommandId) -> None:
547553
raise ValueError("Not a valid SEA command ID")
548554

549555
request = CloseStatementRequest(statement_id=sea_statement_id)
550-
self.http_client._make_request(
556+
self._http_client._make_request(
551557
method="DELETE",
552558
path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id),
553559
data=request.to_dict(),
@@ -575,7 +581,7 @@ def get_query_state(self, command_id: CommandId) -> CommandState:
575581
raise ValueError("Not a valid SEA command ID")
576582

577583
request = GetStatementRequest(statement_id=sea_statement_id)
578-
response_data = self.http_client._make_request(
584+
response_data = self._http_client._make_request(
579585
method="GET",
580586
path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id),
581587
data=request.to_dict(),
@@ -615,7 +621,7 @@ def get_execution_result(
615621
request = GetStatementRequest(statement_id=sea_statement_id)
616622

617623
# Get the statement result
618-
response_data = self.http_client._make_request(
624+
response_data = self._http_client._make_request(
619625
method="GET",
620626
path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id),
621627
data=request.to_dict(),
@@ -649,7 +655,7 @@ def get_chunk_links(
649655
ExternalLink: External link for the chunk
650656
"""
651657

652-
response_data = self.http_client._make_request(
658+
response_data = self._http_client._make_request(
653659
method="GET",
654660
path=self.CHUNK_PATH_WITH_ID_AND_INDEX.format(statement_id, chunk_index),
655661
)

src/databricks/sql/backend/sea/models/responses.py

Lines changed: 24 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -165,34 +165,33 @@ def from_dict(cls, data: Dict[str, Any]) -> "CreateSessionResponse":
165165

166166
@dataclass
167167
class GetChunksResponse:
168-
"""Response from getting chunks for a statement."""
169-
170-
statement_id: str
171-
external_links: List[ExternalLink]
168+
"""
169+
Response from getting chunks for a statement.
170+
171+
The response model can be found in the docs, here:
172+
https://docs.databricks.com/api/workspace/statementexecution/getstatementresultchunkn
173+
"""
174+
175+
data: Optional[List[List[Any]]] = None
176+
external_links: Optional[List[ExternalLink]] = None
177+
byte_count: Optional[int] = None
178+
chunk_index: Optional[int] = None
179+
next_chunk_index: Optional[int] = None
180+
next_chunk_internal_link: Optional[str] = None
181+
row_count: Optional[int] = None
182+
row_offset: Optional[int] = None
172183

173184
@classmethod
174185
def from_dict(cls, data: Dict[str, Any]) -> "GetChunksResponse":
175186
"""Create a GetChunksResponse from a dictionary."""
176-
external_links = []
177-
if "external_links" in data:
178-
for link_data in data["external_links"]:
179-
external_links.append(
180-
ExternalLink(
181-
external_link=link_data.get("external_link", ""),
182-
expiration=link_data.get("expiration", ""),
183-
chunk_index=link_data.get("chunk_index", 0),
184-
byte_count=link_data.get("byte_count", 0),
185-
row_count=link_data.get("row_count", 0),
186-
row_offset=link_data.get("row_offset", 0),
187-
next_chunk_index=link_data.get("next_chunk_index"),
188-
next_chunk_internal_link=link_data.get(
189-
"next_chunk_internal_link"
190-
),
191-
http_headers=link_data.get("http_headers"),
192-
)
193-
)
194-
187+
result = _parse_result({"result": data})
195188
return cls(
196-
statement_id=data.get("statement_id", ""),
197-
external_links=external_links,
189+
data=result.data,
190+
external_links=result.external_links,
191+
byte_count=result.byte_count,
192+
chunk_index=result.chunk_index,
193+
next_chunk_index=result.next_chunk_index,
194+
next_chunk_internal_link=result.next_chunk_internal_link,
195+
row_count=result.row_count,
196+
row_offset=result.row_offset,
198197
)

src/databricks/sql/backend/sea/queue.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
ResultManifest,
2222
)
2323
from databricks.sql.backend.sea.utils.constants import ResultFormat
24-
from databricks.sql.exc import ProgrammingError
24+
from databricks.sql.exc import ProgrammingError, ServerOperationError
2525
from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink
2626
from databricks.sql.types import SSLOptions
2727
from databricks.sql.utils import (
@@ -82,7 +82,7 @@ def build_queue(
8282

8383
# EXTERNAL_LINKS disposition
8484
return SeaCloudFetchQueue(
85-
initial_links=result_data.external_links or [],
85+
result_data=result_data,
8686
max_download_threads=max_download_threads,
8787
ssl_options=ssl_options,
8888
sea_client=sea_client,
@@ -116,13 +116,16 @@ def remaining_rows(self) -> List[List[str]]:
116116
self.cur_row_index += len(slice)
117117
return slice
118118

119+
def close(self):
120+
return
121+
119122

120123
class SeaCloudFetchQueue(CloudFetchQueue):
121124
"""Queue implementation for EXTERNAL_LINKS disposition with ARROW format for SEA backend."""
122125

123126
def __init__(
124127
self,
125-
initial_links: List["ExternalLink"],
128+
result_data: ResultData,
126129
max_download_threads: int,
127130
ssl_options: SSLOptions,
128131
sea_client: "SeaDatabricksClient",
@@ -163,10 +166,12 @@ def __init__(
163166
)
164167
)
165168

169+
initial_links = result_data.external_links
166170
self._chunk_index_to_link = {link.chunk_index: link for link in initial_links}
167171

168172
initial_link = self._chunk_index_to_link.get(0, None)
169173
if not initial_link:
174+
# possibly an empty response
170175
return
171176

172177
self.download_manager = ResultFileDownloadManager(
@@ -177,7 +182,7 @@ def __init__(
177182
)
178183

179184
# Track the current chunk we're processing
180-
self._current_chunk_link = initial_link
185+
self._current_chunk_link = first_link
181186

182187
# Initialize table and position
183188
self.table = self._create_table_from_link(self._current_chunk_link)
@@ -230,10 +235,6 @@ def _create_table_from_link(
230235
) -> Union["pyarrow.Table", None]:
231236
"""Create a table from a link."""
232237

233-
if not self.download_manager:
234-
logger.debug("SeaCloudFetchQueue: No download manager, returning")
235-
return None
236-
237238
thrift_link = self._convert_to_thrift_link(link)
238239
self.download_manager.add_link(thrift_link)
239240

src/databricks/sql/backend/thrift_backend.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -149,8 +149,6 @@ def __init__(
149149
http_path,
150150
)
151151

152-
super().__init__(ssl_options, **kwargs)
153-
154152
port = port or 443
155153
if kwargs.get("_connection_uri"):
156154
uri = kwargs.get("_connection_uri")
@@ -164,13 +162,20 @@ def __init__(
164162
raise ValueError("No valid connection settings.")
165163

166164
self._initialize_retry_args(kwargs)
165+
self._use_arrow_native_complex_types = kwargs.get(
166+
"_use_arrow_native_complex_types", True
167+
)
167168

168169
self._use_arrow_native_decimals = kwargs.get("_use_arrow_native_decimals", True)
169170
self._use_arrow_native_timestamps = kwargs.get(
170171
"_use_arrow_native_timestamps", True
171172
)
172173

173174
# Cloud fetch
175+
self._max_download_threads = kwargs.get("max_download_threads", 10)
176+
177+
self._ssl_options = ssl_options
178+
174179
self._auth_provider = auth_provider
175180

176181
# Connector version 3 retry approach

src/databricks/sql/result_set.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,7 @@ def close(self) -> None:
169169
been closed on the server for some other reason, issue a request to the server to close it.
170170
"""
171171
try:
172+
self.results.close()
172173
if (
173174
self.status != CommandState.CLOSED
174175
and not self.has_been_closed_server_side

src/databricks/sql/utils.py

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,10 @@ def next_n_rows(self, num_rows: int):
4747
def remaining_rows(self):
4848
pass
4949

50+
@abstractmethod
51+
def close(self):
52+
pass
53+
5054

5155
class ThriftResultSetQueueFactory(ABC):
5256
@staticmethod
@@ -159,6 +163,9 @@ def remaining_rows(self):
159163
self.cur_row_index += slice.num_rows
160164
return slice
161165

166+
def close(self):
167+
return
168+
162169

163170
class ArrowQueue(ResultSetQueue):
164171
def __init__(
@@ -196,6 +203,9 @@ def remaining_rows(self) -> "pyarrow.Table":
196203
self.cur_row_index += slice.num_rows
197204
return slice
198205

206+
def close(self):
207+
return
208+
199209

200210
class CloudFetchQueue(ResultSetQueue, ABC):
201211
"""Base class for cloud fetch queues that handle EXTERNAL_LINKS disposition with ARROW format."""
@@ -230,7 +240,12 @@ def __init__(
230240
self.table_row_index = 0
231241

232242
# Initialize download manager
233-
self.download_manager: Optional["ResultFileDownloadManager"] = None
243+
self.download_manager = ResultFileDownloadManager(
244+
links=[],
245+
max_download_threads=max_download_threads,
246+
lz4_compressed=lz4_compressed,
247+
ssl_options=ssl_options,
248+
)
234249

235250
def next_n_rows(self, num_rows: int) -> "pyarrow.Table":
236251
"""
@@ -287,11 +302,8 @@ def remaining_rows(self) -> "pyarrow.Table":
287302

288303
def _create_table_at_offset(self, offset: int) -> Union["pyarrow.Table", None]:
289304
"""Create next table at the given row offset"""
290-
# Create next table by retrieving the logical next downloaded file, or return None to signal end of queue
291-
if not self.download_manager:
292-
logger.debug("CloudFetchQueue: No download manager available")
293-
return None
294305

306+
# Create next table by retrieving the logical next downloaded file, or return None to signal end of queue
295307
downloaded_file = self.download_manager.get_next_downloaded_file(offset)
296308
if not downloaded_file:
297309
logger.debug(
@@ -324,6 +336,9 @@ def _create_empty_table(self) -> "pyarrow.Table":
324336
return pyarrow.Table.from_pydict({})
325337
return create_arrow_table_from_arrow_file(self.schema_bytes, self.description)
326338

339+
def close(self):
340+
self.download_manager._shutdown_manager()
341+
327342

328343
class ThriftCloudFetchQueue(CloudFetchQueue):
329344
"""Queue implementation for EXTERNAL_LINKS disposition with ARROW format for Thrift backend."""
@@ -373,14 +388,7 @@ def __init__(
373388
result_link.startRowOffset, result_link.rowCount
374389
)
375390
)
376-
377-
# Initialize download manager
378-
self.download_manager = ResultFileDownloadManager(
379-
links=self.result_links,
380-
max_download_threads=self.max_download_threads,
381-
lz4_compressed=self.lz4_compressed,
382-
ssl_options=self._ssl_options,
383-
)
391+
self.download_manager.add_link(result_link)
384392

385393
# Initialize table and position
386394
self.table = self._create_next_table()

0 commit comments

Comments
 (0)