Skip to content

Commit d8f46cb

Browse files
Merge branch 'sea-migration' into sea-http-client
2 parents ef9283b + 8fbca9d commit d8f46cb

File tree

10 files changed

+299
-109
lines changed

10 files changed

+299
-109
lines changed

examples/experimental/tests/test_sea_async_query.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ def test_sea_async_query_with_cloud_fetch():
4545
use_sea=True,
4646
user_agent_entry="SEA-Test-Client",
4747
use_cloud_fetch=True,
48+
enable_query_result_lz4_compression=False,
4849
)
4950

5051
logger.info(

examples/experimental/tests/test_sea_sync_query.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ def test_sea_sync_query_with_cloud_fetch():
4343
use_sea=True,
4444
user_agent_entry="SEA-Test-Client",
4545
use_cloud_fetch=True,
46+
enable_query_result_lz4_compression=False,
4647
)
4748

4849
logger.info(

src/databricks/sql/backend/sea/backend.py

Lines changed: 70 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@
1818

1919
if TYPE_CHECKING:
2020
from databricks.sql.client import Cursor
21-
from databricks.sql.backend.sea.result_set import SeaResultSet
21+
22+
from databricks.sql.backend.sea.result_set import SeaResultSet
2223

2324
from databricks.sql.backend.databricks_client import DatabricksClient
2425
from databricks.sql.backend.types import (
@@ -130,6 +131,8 @@ def __init__(
130131
"_use_arrow_native_complex_types", True
131132
)
132133

134+
self.use_hybrid_disposition = kwargs.get("use_hybrid_disposition", True)
135+
133136
# Extract warehouse ID from http_path
134137
self.warehouse_id = self._extract_warehouse_id(http_path)
135138

@@ -330,7 +333,7 @@ def _extract_description_from_manifest(
330333
return columns
331334

332335
def _results_message_to_execute_response(
333-
self, response: GetStatementResponse
336+
self, response: Union[ExecuteStatementResponse, GetStatementResponse]
334337
) -> ExecuteResponse:
335338
"""
336339
Convert a SEA response to an ExecuteResponse and extract result data.
@@ -364,6 +367,27 @@ def _results_message_to_execute_response(
364367

365368
return execute_response
366369

370+
def _response_to_result_set(
371+
self,
372+
response: Union[ExecuteStatementResponse, GetStatementResponse],
373+
cursor: Cursor,
374+
) -> SeaResultSet:
375+
"""
376+
Convert a SEA response to a SeaResultSet.
377+
"""
378+
379+
execute_response = self._results_message_to_execute_response(response)
380+
381+
return SeaResultSet(
382+
connection=cursor.connection,
383+
execute_response=execute_response,
384+
sea_client=self,
385+
result_data=response.result,
386+
manifest=response.manifest,
387+
buffer_size_bytes=cursor.buffer_size_bytes,
388+
arraysize=cursor.arraysize,
389+
)
390+
367391
def _check_command_not_in_failed_or_closed_state(
368392
self, state: CommandState, command_id: CommandId
369393
) -> None:
@@ -384,21 +408,24 @@ def _check_command_not_in_failed_or_closed_state(
384408

385409
def _wait_until_command_done(
386410
self, response: ExecuteStatementResponse
387-
) -> CommandState:
411+
) -> Union[ExecuteStatementResponse, GetStatementResponse]:
388412
"""
389413
Wait until a command is done.
390414
"""
391415

392-
state = response.status.state
393-
command_id = CommandId.from_sea_statement_id(response.statement_id)
416+
final_response: Union[ExecuteStatementResponse, GetStatementResponse] = response
417+
418+
state = final_response.status.state
419+
command_id = CommandId.from_sea_statement_id(final_response.statement_id)
394420

395421
while state in [CommandState.PENDING, CommandState.RUNNING]:
396422
time.sleep(self.POLL_INTERVAL_SECONDS)
397-
state = self.get_query_state(command_id)
423+
final_response = self._poll_query(command_id)
424+
state = final_response.status.state
398425

399426
self._check_command_not_in_failed_or_closed_state(state, command_id)
400427

401-
return state
428+
return final_response
402429

403430
def execute_command(
404431
self,
@@ -456,7 +483,11 @@ def execute_command(
456483
ResultFormat.ARROW_STREAM if use_cloud_fetch else ResultFormat.JSON_ARRAY
457484
).value
458485
disposition = (
459-
ResultDisposition.EXTERNAL_LINKS
486+
(
487+
ResultDisposition.HYBRID
488+
if self.use_hybrid_disposition
489+
else ResultDisposition.EXTERNAL_LINKS
490+
)
460491
if use_cloud_fetch
461492
else ResultDisposition.INLINE
462493
).value
@@ -500,8 +531,11 @@ def execute_command(
500531
if async_op:
501532
return None
502533

503-
self._wait_until_command_done(response)
504-
return self.get_execution_result(command_id, cursor)
534+
final_response: Union[ExecuteStatementResponse, GetStatementResponse] = response
535+
if response.status.state != CommandState.SUCCEEDED:
536+
final_response = self._wait_until_command_done(response)
537+
538+
return self._response_to_result_set(final_response, cursor)
505539

506540
def cancel_command(self, command_id: CommandId) -> None:
507541
"""
@@ -553,18 +587,9 @@ def close_command(self, command_id: CommandId) -> None:
553587
data=request.to_dict(),
554588
)
555589

556-
def get_query_state(self, command_id: CommandId) -> CommandState:
590+
def _poll_query(self, command_id: CommandId) -> GetStatementResponse:
557591
"""
558-
Get the state of a running query.
559-
560-
Args:
561-
command_id: Command identifier
562-
563-
Returns:
564-
CommandState: The current state of the command
565-
566-
Raises:
567-
ValueError: If the command ID is invalid
592+
Poll for the current command info.
568593
"""
569594

570595
if command_id.backend_type != BackendType.SEA:
@@ -580,9 +605,25 @@ def get_query_state(self, command_id: CommandId) -> CommandState:
580605
path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id),
581606
data=request.to_dict(),
582607
)
583-
584-
# Parse the response
585608
response = GetStatementResponse.from_dict(response_data)
609+
610+
return response
611+
612+
def get_query_state(self, command_id: CommandId) -> CommandState:
613+
"""
614+
Get the state of a running query.
615+
616+
Args:
617+
command_id: Command identifier
618+
619+
Returns:
620+
CommandState: The current state of the command
621+
622+
Raises:
623+
ProgrammingError: If the command ID is invalid
624+
"""
625+
626+
response = self._poll_query(command_id)
586627
return response.status.state
587628

588629
def get_execution_result(
@@ -604,40 +645,12 @@ def get_execution_result(
604645
ProgrammingError: If the command ID is invalid
605646
"""
606647

607-
if command_id.backend_type != BackendType.SEA:
608-
raise ValueError("Not a valid SEA command ID")
609-
610-
sea_statement_id = command_id.to_sea_statement_id()
611-
if sea_statement_id is None:
612-
raise ValueError("Not a valid SEA command ID")
613-
614-
# Create the request model
615-
request = GetStatementRequest(statement_id=sea_statement_id)
616-
617-
# Get the statement result
618-
response_data = self._http_client._make_request(
619-
method="GET",
620-
path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id),
621-
data=request.to_dict(),
622-
)
623-
response = GetStatementResponse.from_dict(response_data)
624-
625-
# Create and return a SeaResultSet
626-
from databricks.sql.backend.sea.result_set import SeaResultSet
648+
response = self._poll_query(command_id)
649+
return self._response_to_result_set(response, cursor)
627650

628-
execute_response = self._results_message_to_execute_response(response)
629-
630-
return SeaResultSet(
631-
connection=cursor.connection,
632-
execute_response=execute_response,
633-
sea_client=self,
634-
result_data=response.result,
635-
manifest=response.manifest,
636-
buffer_size_bytes=cursor.buffer_size_bytes,
637-
arraysize=cursor.arraysize,
638-
)
639-
640-
def get_chunk_link(self, statement_id: str, chunk_index: int) -> ExternalLink:
651+
def get_chunk_links(
652+
self, statement_id: str, chunk_index: int
653+
) -> List[ExternalLink]:
641654
"""
642655
Get links for chunks starting from the specified index.
643656
Args:
@@ -654,17 +667,7 @@ def get_chunk_link(self, statement_id: str, chunk_index: int) -> ExternalLink:
654667
response = GetChunksResponse.from_dict(response_data)
655668

656669
links = response.external_links or []
657-
link = next((l for l in links if l.chunk_index == chunk_index), None)
658-
if not link:
659-
raise ServerOperationError(
660-
f"No link found for chunk index {chunk_index}",
661-
{
662-
"operation-id": statement_id,
663-
"diagnostic-info": None,
664-
},
665-
)
666-
667-
return link
670+
return links
668671

669672
# == Metadata Operations ==
670673

src/databricks/sql/backend/sea/models/responses.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
These models define the structures used in SEA API responses.
55
"""
66

7+
import base64
78
from typing import Dict, Any, List, Optional
89
from dataclasses import dataclass
910

@@ -91,6 +92,11 @@ def _parse_result(data: Dict[str, Any]) -> ResultData:
9192
)
9293
)
9394

95+
# Handle attachment field - decode from base64 if present
96+
attachment = result_data.get("attachment")
97+
if attachment is not None:
98+
attachment = base64.b64decode(attachment)
99+
94100
return ResultData(
95101
data=result_data.get("data_array"),
96102
external_links=external_links,
@@ -100,7 +106,7 @@ def _parse_result(data: Dict[str, Any]) -> ResultData:
100106
next_chunk_internal_link=result_data.get("next_chunk_internal_link"),
101107
row_count=result_data.get("row_count"),
102108
row_offset=result_data.get("row_offset"),
103-
attachment=result_data.get("attachment"),
109+
attachment=attachment,
104110
)
105111

106112

src/databricks/sql/backend/sea/queue.py

Lines changed: 34 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55

66
from databricks.sql.cloudfetch.download_manager import ResultFileDownloadManager
77

8+
from databricks.sql.cloudfetch.downloader import ResultSetDownloadHandler
9+
810
try:
911
import pyarrow
1012
except ImportError:
@@ -23,7 +25,12 @@
2325
from databricks.sql.exc import ProgrammingError, ServerOperationError
2426
from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink
2527
from databricks.sql.types import SSLOptions
26-
from databricks.sql.utils import CloudFetchQueue, ResultSetQueue
28+
from databricks.sql.utils import (
29+
ArrowQueue,
30+
CloudFetchQueue,
31+
ResultSetQueue,
32+
create_arrow_table_from_arrow_file,
33+
)
2734

2835
import logging
2936

@@ -62,6 +69,18 @@ def build_queue(
6269
# INLINE disposition with JSON_ARRAY format
6370
return JsonQueue(result_data.data)
6471
elif manifest.format == ResultFormat.ARROW_STREAM.value:
72+
if result_data.attachment is not None:
73+
arrow_file = (
74+
ResultSetDownloadHandler._decompress_data(result_data.attachment)
75+
if lz4_compressed
76+
else result_data.attachment
77+
)
78+
arrow_table = create_arrow_table_from_arrow_file(
79+
arrow_file, description
80+
)
81+
logger.debug(f"Created arrow table with {arrow_table.num_rows} rows")
82+
return ArrowQueue(arrow_table, manifest.total_row_count)
83+
6584
# EXTERNAL_LINKS disposition
6685
return SeaCloudFetchQueue(
6786
result_data=result_data,
@@ -150,7 +169,11 @@ def __init__(
150169
)
151170

152171
initial_links = result_data.external_links or []
153-
first_link = next((l for l in initial_links if l.chunk_index == 0), None)
172+
self._chunk_index_to_link = {link.chunk_index: link for link in initial_links}
173+
174+
# Track the current chunk we're processing
175+
self._current_chunk_index = 0
176+
first_link = self._chunk_index_to_link.get(self._current_chunk_index, None)
154177
if not first_link:
155178
# possibly an empty response
156179
return None
@@ -173,21 +196,24 @@ def _convert_to_thrift_link(self, link: ExternalLink) -> TSparkArrowResultLink:
173196
httpHeaders=link.http_headers or {},
174197
)
175198

176-
def _get_chunk_link(self, chunk_index: int) -> Optional[ExternalLink]:
177-
"""Progress to the next chunk link."""
199+
def _get_chunk_link(self, chunk_index: int) -> Optional["ExternalLink"]:
178200
if chunk_index >= self._total_chunk_count:
179201
return None
180202

181-
try:
182-
return self._sea_client.get_chunk_link(self._statement_id, chunk_index)
183-
except Exception as e:
203+
if chunk_index not in self._chunk_index_to_link:
204+
links = self._sea_client.get_chunk_links(self._statement_id, chunk_index)
205+
self._chunk_index_to_link.update({l.chunk_index: l for l in links})
206+
207+
link = self._chunk_index_to_link.get(chunk_index, None)
208+
if not link:
184209
raise ServerOperationError(
185-
f"Error fetching link for chunk {chunk_index}: {e}",
210+
f"Error fetching link for chunk {chunk_index}",
186211
{
187212
"operation-id": self._statement_id,
188213
"diagnostic-info": None,
189214
},
190215
)
216+
return link
191217

192218
def _create_table_from_link(
193219
self, link: ExternalLink

src/databricks/sql/backend/sea/result_set.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44

55
import logging
66

7-
from databricks.sql.backend.sea.backend import SeaDatabricksClient
87
from databricks.sql.backend.sea.models.base import ResultData, ResultManifest
98
from databricks.sql.backend.sea.utils.conversion import SqlTypeConverter
109

@@ -15,6 +14,7 @@
1514

1615
if TYPE_CHECKING:
1716
from databricks.sql.client import Connection
17+
from databricks.sql.backend.sea.backend import SeaDatabricksClient
1818
from databricks.sql.types import Row
1919
from databricks.sql.backend.sea.queue import JsonQueue, SeaResultSetQueueFactory
2020
from databricks.sql.backend.types import ExecuteResponse

src/databricks/sql/backend/sea/utils/constants.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ class ResultFormat(Enum):
2828
class ResultDisposition(Enum):
2929
"""Enum for result disposition values."""
3030

31-
# TODO: add support for hybrid disposition
31+
HYBRID = "INLINE_OR_EXTERNAL_LINKS"
3232
EXTERNAL_LINKS = "EXTERNAL_LINKS"
3333
INLINE = "INLINE"
3434

src/databricks/sql/client.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,10 @@ def __init__(
9999
Connect to a Databricks SQL endpoint or a Databricks cluster.
100100
101101
Parameters:
102+
:param use_sea: `bool`, optional (default is False)
103+
Use the SEA backend instead of the Thrift backend.
104+
:param use_hybrid_disposition: `bool`, optional (default is True)
105+
With the SEA backend and cloud fetch enabled, request the hybrid (INLINE_OR_EXTERNAL_LINKS) disposition instead of EXTERNAL_LINKS.
102106
:param server_hostname: Databricks instance host name.
103107
:param http_path: Http path either to a DBSQL endpoint (e.g. /sql/1.0/endpoints/1234567890abcdef)
104108
or to a DBR interactive cluster (e.g. /sql/protocolv1/o/1234567890123456/1234-123456-slid123)

0 commit comments

Comments
 (0)