Skip to content

Commit 47bb758

Browse files
Merge branch 'sea-decouple-link-fetch' into sea-link-expiry
2 parents a18be78 + 00db613 commit 47bb758

File tree

11 files changed

+824
-181
lines changed

11 files changed

+824
-181
lines changed

src/databricks/sql/backend/databricks_client.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,6 @@
33
from abc import ABC, abstractmethod
44
from typing import Dict, List, Optional, Any, Union, TYPE_CHECKING
55

6-
from databricks.sql.types import SSLOptions
7-
86
if TYPE_CHECKING:
97
from databricks.sql.client import Cursor
108
from databricks.sql.result_set import ResultSet
@@ -24,13 +22,6 @@ class DatabricksClient(ABC):
2422
- Fetching metadata about catalogs, schemas, tables, and columns
2523
"""
2624

27-
def __init__(self, ssl_options: SSLOptions, **kwargs):
28-
self._use_arrow_native_complex_types = kwargs.get(
29-
"_use_arrow_native_complex_types", True
30-
)
31-
self._max_download_threads = kwargs.get("max_download_threads", 10)
32-
self._ssl_options = ssl_options
33-
3425
# == Connection and Session Management ==
3526
@abstractmethod
3627
def open_session(
@@ -110,6 +101,7 @@ def execute_command(
110101
parameters: List of parameters to bind to the query
111102
async_op: Whether to execute the command asynchronously
112103
enforce_embedded_schema_correctness: Whether to enforce schema correctness
104+
row_limit: Maximum number of rows in the response.
113105
114106
Returns:
115107
If async_op is False, returns a ResultSet object containing the

src/databricks/sql/backend/sea/backend.py

Lines changed: 81 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@
1818

1919
if TYPE_CHECKING:
2020
from databricks.sql.client import Cursor
21-
from databricks.sql.backend.sea.result_set import SeaResultSet
21+
22+
from databricks.sql.backend.sea.result_set import SeaResultSet
2223

2324
from databricks.sql.backend.databricks_client import DatabricksClient
2425
from databricks.sql.backend.types import (
@@ -50,17 +51,17 @@
5051

5152

5253
def _filter_session_configuration(
53-
session_configuration: Optional[Dict[str, str]]
54-
) -> Optional[Dict[str, str]]:
54+
session_configuration: Optional[Dict[str, Any]],
55+
) -> Dict[str, str]:
5556
if not session_configuration:
56-
return None
57+
return {}
5758

5859
filtered_session_configuration = {}
5960
ignored_configs: Set[str] = set()
6061

6162
for key, value in session_configuration.items():
6263
if key.upper() in ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP:
63-
filtered_session_configuration[key.lower()] = value
64+
filtered_session_configuration[key.lower()] = str(value)
6465
else:
6566
ignored_configs.add(key)
6667

@@ -124,15 +125,19 @@ def __init__(
124125
http_path,
125126
)
126127

127-
super().__init__(ssl_options=ssl_options, **kwargs)
128+
self._max_download_threads = kwargs.get("max_download_threads", 10)
129+
self._ssl_options = ssl_options
130+
self._use_arrow_native_complex_types = kwargs.get(
131+
"_use_arrow_native_complex_types", True
132+
)
128133

129134
self.use_hybrid_disposition = kwargs.get("use_hybrid_disposition", True)
130135

131136
# Extract warehouse ID from http_path
132137
self.warehouse_id = self._extract_warehouse_id(http_path)
133138

134139
# Initialize HTTP client
135-
self.http_client = SeaHttpClient(
140+
self._http_client = SeaHttpClient(
136141
server_hostname=server_hostname,
137142
port=port,
138143
http_path=http_path,
@@ -186,7 +191,7 @@ def max_download_threads(self) -> int:
186191

187192
def open_session(
188193
self,
189-
session_configuration: Optional[Dict[str, str]],
194+
session_configuration: Optional[Dict[str, Any]],
190195
catalog: Optional[str],
191196
schema: Optional[str],
192197
) -> SessionId:
@@ -224,7 +229,7 @@ def open_session(
224229
schema=schema,
225230
)
226231

227-
response = self.http_client._make_request(
232+
response = self._http_client._make_request(
228233
method="POST", path=self.SESSION_PATH, data=request_data.to_dict()
229234
)
230235

@@ -264,7 +269,7 @@ def close_session(self, session_id: SessionId) -> None:
264269
session_id=sea_session_id,
265270
)
266271

267-
self.http_client._make_request(
272+
self._http_client._make_request(
268273
method="DELETE",
269274
path=self.SESSION_PATH_WITH_ID.format(sea_session_id),
270275
data=request_data.to_dict(),
@@ -328,7 +333,7 @@ def _extract_description_from_manifest(
328333
return columns
329334

330335
def _results_message_to_execute_response(
331-
self, response: GetStatementResponse
336+
self, response: Union[ExecuteStatementResponse, GetStatementResponse]
332337
) -> ExecuteResponse:
333338
"""
334339
Convert a SEA response to an ExecuteResponse and extract result data.
@@ -362,6 +367,27 @@ def _results_message_to_execute_response(
362367

363368
return execute_response
364369

370+
def _response_to_result_set(
371+
self,
372+
response: Union[ExecuteStatementResponse, GetStatementResponse],
373+
cursor: Cursor,
374+
) -> SeaResultSet:
375+
"""
376+
Convert a SEA response to a SeaResultSet.
377+
"""
378+
379+
execute_response = self._results_message_to_execute_response(response)
380+
381+
return SeaResultSet(
382+
connection=cursor.connection,
383+
execute_response=execute_response,
384+
sea_client=self,
385+
result_data=response.result,
386+
manifest=response.manifest,
387+
buffer_size_bytes=cursor.buffer_size_bytes,
388+
arraysize=cursor.arraysize,
389+
)
390+
365391
def _check_command_not_in_failed_or_closed_state(
366392
self, state: CommandState, command_id: CommandId
367393
) -> None:
@@ -382,21 +408,24 @@ def _check_command_not_in_failed_or_closed_state(
382408

383409
def _wait_until_command_done(
384410
self, response: ExecuteStatementResponse
385-
) -> CommandState:
411+
) -> Union[ExecuteStatementResponse, GetStatementResponse]:
386412
"""
387413
Wait until a command is done.
388414
"""
389415

390-
state = response.status.state
391-
command_id = CommandId.from_sea_statement_id(response.statement_id)
416+
final_response: Union[ExecuteStatementResponse, GetStatementResponse] = response
417+
418+
state = final_response.status.state
419+
command_id = CommandId.from_sea_statement_id(final_response.statement_id)
392420

393421
while state in [CommandState.PENDING, CommandState.RUNNING]:
394422
time.sleep(self.POLL_INTERVAL_SECONDS)
395-
state = self.get_query_state(command_id)
423+
final_response = self._poll_query(command_id)
424+
state = final_response.status.state
396425

397426
self._check_command_not_in_failed_or_closed_state(state, command_id)
398427

399-
return state
428+
return final_response
400429

401430
def execute_command(
402431
self,
@@ -443,7 +472,9 @@ def execute_command(
443472
sea_parameters.append(
444473
StatementParameter(
445474
name=param.name,
446-
value=param.value,
475+
value=(
476+
param.value.stringValue if param.value is not None else None
477+
),
447478
type=param.type,
448479
)
449480
)
@@ -477,7 +508,7 @@ def execute_command(
477508
result_compression=result_compression,
478509
)
479510

480-
response_data = self.http_client._make_request(
511+
response_data = self._http_client._make_request(
481512
method="POST", path=self.STATEMENT_PATH, data=request.to_dict()
482513
)
483514
response = ExecuteStatementResponse.from_dict(response_data)
@@ -500,8 +531,11 @@ def execute_command(
500531
if async_op:
501532
return None
502533

503-
self._wait_until_command_done(response)
504-
return self.get_execution_result(command_id, cursor)
534+
final_response: Union[ExecuteStatementResponse, GetStatementResponse] = response
535+
if response.status.state != CommandState.SUCCEEDED:
536+
final_response = self._wait_until_command_done(response)
537+
538+
return self._response_to_result_set(final_response, cursor)
505539

506540
def cancel_command(self, command_id: CommandId) -> None:
507541
"""
@@ -522,7 +556,7 @@ def cancel_command(self, command_id: CommandId) -> None:
522556
raise ValueError("Not a valid SEA command ID")
523557

524558
request = CancelStatementRequest(statement_id=sea_statement_id)
525-
self.http_client._make_request(
559+
self._http_client._make_request(
526560
method="POST",
527561
path=self.CANCEL_STATEMENT_PATH_WITH_ID.format(sea_statement_id),
528562
data=request.to_dict(),
@@ -547,24 +581,15 @@ def close_command(self, command_id: CommandId) -> None:
547581
raise ValueError("Not a valid SEA command ID")
548582

549583
request = CloseStatementRequest(statement_id=sea_statement_id)
550-
self.http_client._make_request(
584+
self._http_client._make_request(
551585
method="DELETE",
552586
path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id),
553587
data=request.to_dict(),
554588
)
555589

556-
def get_query_state(self, command_id: CommandId) -> CommandState:
590+
def _poll_query(self, command_id: CommandId) -> GetStatementResponse:
557591
"""
558-
Get the state of a running query.
559-
560-
Args:
561-
command_id: Command identifier
562-
563-
Returns:
564-
CommandState: The current state of the command
565-
566-
Raises:
567-
ValueError: If the command ID is invalid
592+
Poll for the current command info.
568593
"""
569594

570595
if command_id.backend_type != BackendType.SEA:
@@ -575,14 +600,30 @@ def get_query_state(self, command_id: CommandId) -> CommandState:
575600
raise ValueError("Not a valid SEA command ID")
576601

577602
request = GetStatementRequest(statement_id=sea_statement_id)
578-
response_data = self.http_client._make_request(
603+
response_data = self._http_client._make_request(
579604
method="GET",
580605
path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id),
581606
data=request.to_dict(),
582607
)
583-
584-
# Parse the response
585608
response = GetStatementResponse.from_dict(response_data)
609+
610+
return response
611+
612+
def get_query_state(self, command_id: CommandId) -> CommandState:
613+
"""
614+
Get the state of a running query.
615+
616+
Args:
617+
command_id: Command identifier
618+
619+
Returns:
620+
CommandState: The current state of the command
621+
622+
Raises:
623+
ProgrammingError: If the command ID is invalid
624+
"""
625+
626+
response = self._poll_query(command_id)
586627
return response.status.state
587628

588629
def get_execution_result(
@@ -604,38 +645,8 @@ def get_execution_result(
604645
ValueError: If the command ID is invalid
605646
"""
606647

607-
if command_id.backend_type != BackendType.SEA:
608-
raise ValueError("Not a valid SEA command ID")
609-
610-
sea_statement_id = command_id.to_sea_statement_id()
611-
if sea_statement_id is None:
612-
raise ValueError("Not a valid SEA command ID")
613-
614-
# Create the request model
615-
request = GetStatementRequest(statement_id=sea_statement_id)
616-
617-
# Get the statement result
618-
response_data = self.http_client._make_request(
619-
method="GET",
620-
path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id),
621-
data=request.to_dict(),
622-
)
623-
response = GetStatementResponse.from_dict(response_data)
624-
625-
# Create and return a SeaResultSet
626-
from databricks.sql.backend.sea.result_set import SeaResultSet
627-
628-
execute_response = self._results_message_to_execute_response(response)
629-
630-
return SeaResultSet(
631-
connection=cursor.connection,
632-
execute_response=execute_response,
633-
sea_client=self,
634-
result_data=response.result,
635-
manifest=response.manifest,
636-
buffer_size_bytes=cursor.buffer_size_bytes,
637-
arraysize=cursor.arraysize,
638-
)
648+
response = self._poll_query(command_id)
649+
return self._response_to_result_set(response, cursor)
639650

640651
def get_chunk_links(
641652
self, statement_id: str, chunk_index: int
@@ -649,13 +660,13 @@ def get_chunk_links(
649660
ExternalLink: External link for the chunk
650661
"""
651662

652-
response_data = self.http_client._make_request(
663+
response_data = self._http_client._make_request(
653664
method="GET",
654665
path=self.CHUNK_PATH_WITH_ID_AND_INDEX.format(statement_id, chunk_index),
655666
)
656667
response = GetChunksResponse.from_dict(response_data)
657668

658-
links = response.external_links
669+
links = response.external_links or []
659670
return links
660671

661672
# == Metadata Operations ==

0 commit comments

Comments
 (0)