Skip to content

Commit e414d52

Browse files
Merge branch 'sea-migration' into sea-optimise-success
2 parents 70e22f8 + 640cc82 commit e414d52

22 files changed

+1510
-339
lines changed

examples/experimental/tests/test_sea_async_query.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ def test_sea_async_query_with_cloud_fetch():
4545
use_sea=True,
4646
user_agent_entry="SEA-Test-Client",
4747
use_cloud_fetch=True,
48+
enable_query_result_lz4_compression=False,
4849
)
4950

5051
logger.info(

examples/experimental/tests/test_sea_sync_query.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ def test_sea_sync_query_with_cloud_fetch():
4343
use_sea=True,
4444
user_agent_entry="SEA-Test-Client",
4545
use_cloud_fetch=True,
46+
enable_query_result_lz4_compression=False,
4647
)
4748

4849
logger.info(

src/databricks/sql/backend/sea/backend.py

Lines changed: 55 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import re
66
from typing import Any, Dict, Tuple, List, Optional, Union, TYPE_CHECKING, Set
77

8-
from databricks.sql.backend.sea.models.base import ResultManifest
8+
from databricks.sql.backend.sea.models.base import ExternalLink, ResultManifest
99
from databricks.sql.backend.sea.models.responses import GetStatementResponse
1010
from databricks.sql.backend.sea.utils.constants import (
1111
ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP,
@@ -29,7 +29,7 @@
2929
BackendType,
3030
ExecuteResponse,
3131
)
32-
from databricks.sql.exc import DatabaseError, ProgrammingError, ServerOperationError
32+
from databricks.sql.exc import DatabaseError, ServerOperationError
3333
from databricks.sql.backend.sea.utils.http_client import SeaHttpClient
3434
from databricks.sql.types import SSLOptions
3535

@@ -44,22 +44,23 @@
4444
ExecuteStatementResponse,
4545
CreateSessionResponse,
4646
)
47+
from databricks.sql.backend.sea.models.responses import GetChunksResponse
4748

4849
logger = logging.getLogger(__name__)
4950

5051

5152
def _filter_session_configuration(
52-
session_configuration: Optional[Dict[str, str]]
53-
) -> Optional[Dict[str, str]]:
53+
session_configuration: Optional[Dict[str, Any]],
54+
) -> Dict[str, str]:
5455
if not session_configuration:
55-
return None
56+
return {}
5657

5758
filtered_session_configuration = {}
5859
ignored_configs: Set[str] = set()
5960

6061
for key, value in session_configuration.items():
6162
if key.upper() in ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP:
62-
filtered_session_configuration[key.lower()] = value
63+
filtered_session_configuration[key.lower()] = str(value)
6364
else:
6465
ignored_configs.add(key)
6566

@@ -88,6 +89,7 @@ class SeaDatabricksClient(DatabricksClient):
8889
STATEMENT_PATH = BASE_PATH + "statements"
8990
STATEMENT_PATH_WITH_ID = STATEMENT_PATH + "/{}"
9091
CANCEL_STATEMENT_PATH_WITH_ID = STATEMENT_PATH + "/{}/cancel"
92+
CHUNK_PATH_WITH_ID_AND_INDEX = STATEMENT_PATH + "/{}/result/chunks/{}"
9193

9294
# SEA constants
9395
POLL_INTERVAL_SECONDS = 0.2
@@ -123,18 +125,24 @@ def __init__(
123125
)
124126

125127
self._max_download_threads = kwargs.get("max_download_threads", 10)
128+
self._ssl_options = ssl_options
129+
self._use_arrow_native_complex_types = kwargs.get(
130+
"_use_arrow_native_complex_types", True
131+
)
132+
133+
self.use_hybrid_disposition = kwargs.get("use_hybrid_disposition", True)
126134

127135
# Extract warehouse ID from http_path
128136
self.warehouse_id = self._extract_warehouse_id(http_path)
129137

130138
# Initialize HTTP client
131-
self.http_client = SeaHttpClient(
139+
self._http_client = SeaHttpClient(
132140
server_hostname=server_hostname,
133141
port=port,
134142
http_path=http_path,
135143
http_headers=http_headers,
136144
auth_provider=auth_provider,
137-
ssl_options=ssl_options,
145+
ssl_options=self._ssl_options,
138146
**kwargs,
139147
)
140148

@@ -173,7 +181,7 @@ def _extract_warehouse_id(self, http_path: str) -> str:
173181
f"Note: SEA only works for warehouses."
174182
)
175183
logger.error(error_message)
176-
raise ProgrammingError(error_message)
184+
raise ValueError(error_message)
177185

178186
@property
179187
def max_download_threads(self) -> int:
@@ -182,7 +190,7 @@ def max_download_threads(self) -> int:
182190

183191
def open_session(
184192
self,
185-
session_configuration: Optional[Dict[str, str]],
193+
session_configuration: Optional[Dict[str, Any]],
186194
catalog: Optional[str],
187195
schema: Optional[str],
188196
) -> SessionId:
@@ -220,7 +228,7 @@ def open_session(
220228
schema=schema,
221229
)
222230

223-
response = self.http_client._make_request(
231+
response = self._http_client._make_request(
224232
method="POST", path=self.SESSION_PATH, data=request_data.to_dict()
225233
)
226234

@@ -245,7 +253,7 @@ def close_session(self, session_id: SessionId) -> None:
245253
session_id: The session identifier returned by open_session()
246254
247255
Raises:
248-
ProgrammingError: If the session ID is invalid
256+
ValueError: If the session ID is invalid
249257
OperationalError: If there's an error closing the session
250258
"""
251259

@@ -260,7 +268,7 @@ def close_session(self, session_id: SessionId) -> None:
260268
session_id=sea_session_id,
261269
)
262270

263-
self.http_client._make_request(
271+
self._http_client._make_request(
264272
method="DELETE",
265273
path=self.SESSION_PATH_WITH_ID.format(sea_session_id),
266274
data=request_data.to_dict(),
@@ -342,7 +350,7 @@ def _results_message_to_execute_response(
342350

343351
# Check for compression
344352
lz4_compressed = (
345-
response.manifest.result_compression == ResultCompression.LZ4_FRAME
353+
response.manifest.result_compression == ResultCompression.LZ4_FRAME.value
346354
)
347355

348356
execute_response = ExecuteResponse(
@@ -451,7 +459,7 @@ def execute_command(
451459
enforce_embedded_schema_correctness: Whether to enforce schema correctness
452460
453461
Returns:
454-
ResultSet: A SeaResultSet instance for the executed command
462+
SeaResultSet: A SeaResultSet instance for the executed command
455463
"""
456464

457465
if session_id.backend_type != BackendType.SEA:
@@ -477,7 +485,11 @@ def execute_command(
477485
ResultFormat.ARROW_STREAM if use_cloud_fetch else ResultFormat.JSON_ARRAY
478486
).value
479487
disposition = (
480-
ResultDisposition.EXTERNAL_LINKS
488+
(
489+
ResultDisposition.HYBRID
490+
if self.use_hybrid_disposition
491+
else ResultDisposition.EXTERNAL_LINKS
492+
)
481493
if use_cloud_fetch
482494
else ResultDisposition.INLINE
483495
).value
@@ -498,7 +510,7 @@ def execute_command(
498510
result_compression=result_compression,
499511
)
500512

501-
response_data = self.http_client._make_request(
513+
response_data = self._http_client._make_request(
502514
method="POST", path=self.STATEMENT_PATH, data=request.to_dict()
503515
)
504516
response = ExecuteStatementResponse.from_dict(response_data)
@@ -535,7 +547,7 @@ def cancel_command(self, command_id: CommandId) -> None:
535547
command_id: Command identifier to cancel
536548
537549
Raises:
538-
ProgrammingError: If the command ID is invalid
550+
ValueError: If the command ID is invalid
539551
"""
540552

541553
if command_id.backend_type != BackendType.SEA:
@@ -546,7 +558,7 @@ def cancel_command(self, command_id: CommandId) -> None:
546558
raise ValueError("Not a valid SEA command ID")
547559

548560
request = CancelStatementRequest(statement_id=sea_statement_id)
549-
self.http_client._make_request(
561+
self._http_client._make_request(
550562
method="POST",
551563
path=self.CANCEL_STATEMENT_PATH_WITH_ID.format(sea_statement_id),
552564
data=request.to_dict(),
@@ -560,7 +572,7 @@ def close_command(self, command_id: CommandId) -> None:
560572
command_id: Command identifier to close
561573
562574
Raises:
563-
ProgrammingError: If the command ID is invalid
575+
ValueError: If the command ID is invalid
564576
"""
565577

566578
if command_id.backend_type != BackendType.SEA:
@@ -571,7 +583,7 @@ def close_command(self, command_id: CommandId) -> None:
571583
raise ValueError("Not a valid SEA command ID")
572584

573585
request = CloseStatementRequest(statement_id=sea_statement_id)
574-
self.http_client._make_request(
586+
self._http_client._make_request(
575587
method="DELETE",
576588
path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id),
577589
data=request.to_dict(),
@@ -590,7 +602,7 @@ def _poll_query(self, command_id: CommandId) -> GetStatementResponse:
590602
raise ValueError("Not a valid SEA command ID")
591603

592604
request = GetStatementRequest(statement_id=sea_statement_id)
593-
response_data = self.http_client._make_request(
605+
response_data = self._http_client._make_request(
594606
method="GET",
595607
path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id),
596608
data=request.to_dict(),
@@ -638,6 +650,27 @@ def get_execution_result(
638650
response = self._poll_query(command_id)
639651
return self._response_to_result_set(response, cursor)
640652

653+
def get_chunk_links(
654+
self, statement_id: str, chunk_index: int
655+
) -> List[ExternalLink]:
656+
"""
657+
Get links for chunks starting from the specified index.
658+
Args:
659+
statement_id: The statement ID
660+
chunk_index: The starting chunk index
661+
Returns:
662+
ExternalLink: External link for the chunk
663+
"""
664+
665+
response_data = self._http_client._make_request(
666+
method="GET",
667+
path=self.CHUNK_PATH_WITH_ID_AND_INDEX.format(statement_id, chunk_index),
668+
)
669+
response = GetChunksResponse.from_dict(response_data)
670+
671+
links = response.external_links or []
672+
return links
673+
641674
# == Metadata Operations ==
642675

643676
def get_catalogs(

src/databricks/sql/backend/sea/models/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
ExecuteStatementResponse,
2828
GetStatementResponse,
2929
CreateSessionResponse,
30+
GetChunksResponse,
3031
)
3132

3233
__all__ = [
@@ -49,4 +50,5 @@
4950
"ExecuteStatementResponse",
5051
"GetStatementResponse",
5152
"CreateSessionResponse",
53+
"GetChunksResponse",
5254
]

src/databricks/sql/backend/sea/models/responses.py

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
These models define the structures used in SEA API responses.
55
"""
66

7-
from typing import Dict, Any
7+
import base64
8+
from typing import Dict, Any, List, Optional
89
from dataclasses import dataclass
910

1011
from databricks.sql.backend.types import CommandState
@@ -91,6 +92,11 @@ def _parse_result(data: Dict[str, Any]) -> ResultData:
9192
)
9293
)
9394

95+
# Handle attachment field - decode from base64 if present
96+
attachment = result_data.get("attachment")
97+
if attachment is not None:
98+
attachment = base64.b64decode(attachment)
99+
94100
return ResultData(
95101
data=result_data.get("data_array"),
96102
external_links=external_links,
@@ -100,7 +106,7 @@ def _parse_result(data: Dict[str, Any]) -> ResultData:
100106
next_chunk_internal_link=result_data.get("next_chunk_internal_link"),
101107
row_count=result_data.get("row_count"),
102108
row_offset=result_data.get("row_offset"),
103-
attachment=result_data.get("attachment"),
109+
attachment=attachment,
104110
)
105111

106112

@@ -154,3 +160,37 @@ class CreateSessionResponse:
154160
def from_dict(cls, data: Dict[str, Any]) -> "CreateSessionResponse":
155161
"""Create a CreateSessionResponse from a dictionary."""
156162
return cls(session_id=data.get("session_id", ""))
163+
164+
165+
@dataclass
166+
class GetChunksResponse:
167+
"""
168+
Response from getting chunks for a statement.
169+
170+
The response model can be found in the docs, here:
171+
https://docs.databricks.com/api/workspace/statementexecution/getstatementresultchunkn
172+
"""
173+
174+
data: Optional[List[List[Any]]] = None
175+
external_links: Optional[List[ExternalLink]] = None
176+
byte_count: Optional[int] = None
177+
chunk_index: Optional[int] = None
178+
next_chunk_index: Optional[int] = None
179+
next_chunk_internal_link: Optional[str] = None
180+
row_count: Optional[int] = None
181+
row_offset: Optional[int] = None
182+
183+
@classmethod
184+
def from_dict(cls, data: Dict[str, Any]) -> "GetChunksResponse":
185+
"""Create a GetChunksResponse from a dictionary."""
186+
result = _parse_result({"result": data})
187+
return cls(
188+
data=result.data,
189+
external_links=result.external_links,
190+
byte_count=result.byte_count,
191+
chunk_index=result.chunk_index,
192+
next_chunk_index=result.next_chunk_index,
193+
next_chunk_internal_link=result.next_chunk_internal_link,
194+
row_count=result.row_count,
195+
row_offset=result.row_offset,
196+
)

0 commit comments

Comments
 (0)