55import re
66from typing import Any , Dict , Tuple , List , Optional , Union , TYPE_CHECKING , Set
77
8- from databricks .sql .backend .sea .models .base import ResultManifest
8+ from databricks .sql .backend .sea .models .base import ExternalLink , ResultManifest
99from databricks .sql .backend .sea .models .responses import GetStatementResponse
1010from databricks .sql .backend .sea .utils .constants import (
1111 ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP ,
2929 BackendType ,
3030 ExecuteResponse ,
3131)
32- from databricks .sql .exc import DatabaseError , ProgrammingError , ServerOperationError
32+ from databricks .sql .exc import DatabaseError , ServerOperationError
3333from databricks .sql .backend .sea .utils .http_client import SeaHttpClient
3434from databricks .sql .types import SSLOptions
3535
4444 ExecuteStatementResponse ,
4545 CreateSessionResponse ,
4646)
47+ from databricks .sql .backend .sea .models .responses import GetChunksResponse
4748
4849logger = logging .getLogger (__name__ )
4950
5051
5152def _filter_session_configuration (
52- session_configuration : Optional [Dict [str , str ]]
53- ) -> Optional [ Dict [str , str ] ]:
53+ session_configuration : Optional [Dict [str , Any ]],
54+ ) -> Dict [str , str ]:
5455 if not session_configuration :
55- return None
56+ return {}
5657
5758 filtered_session_configuration = {}
5859 ignored_configs : Set [str ] = set ()
5960
6061 for key , value in session_configuration .items ():
6162 if key .upper () in ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP :
62- filtered_session_configuration [key .lower ()] = value
63+ filtered_session_configuration [key .lower ()] = str ( value )
6364 else :
6465 ignored_configs .add (key )
6566
@@ -88,6 +89,7 @@ class SeaDatabricksClient(DatabricksClient):
8889 STATEMENT_PATH = BASE_PATH + "statements"
8990 STATEMENT_PATH_WITH_ID = STATEMENT_PATH + "/{}"
9091 CANCEL_STATEMENT_PATH_WITH_ID = STATEMENT_PATH + "/{}/cancel"
92+ CHUNK_PATH_WITH_ID_AND_INDEX = STATEMENT_PATH + "/{}/result/chunks/{}"
9193
9294 # SEA constants
9395 POLL_INTERVAL_SECONDS = 0.2
@@ -123,18 +125,24 @@ def __init__(
123125 )
124126
125127 self ._max_download_threads = kwargs .get ("max_download_threads" , 10 )
128+ self ._ssl_options = ssl_options
129+ self ._use_arrow_native_complex_types = kwargs .get (
130+ "_use_arrow_native_complex_types" , True
131+ )
132+
133+ self .use_hybrid_disposition = kwargs .get ("use_hybrid_disposition" , True )
126134
127135 # Extract warehouse ID from http_path
128136 self .warehouse_id = self ._extract_warehouse_id (http_path )
129137
130138 # Initialize HTTP client
131- self .http_client = SeaHttpClient (
139+ self ._http_client = SeaHttpClient (
132140 server_hostname = server_hostname ,
133141 port = port ,
134142 http_path = http_path ,
135143 http_headers = http_headers ,
136144 auth_provider = auth_provider ,
137- ssl_options = ssl_options ,
145+ ssl_options = self . _ssl_options ,
138146 ** kwargs ,
139147 )
140148
@@ -173,7 +181,7 @@ def _extract_warehouse_id(self, http_path: str) -> str:
173181 f"Note: SEA only works for warehouses."
174182 )
175183 logger .error (error_message )
176- raise ProgrammingError (error_message )
184+ raise ValueError (error_message )
177185
178186 @property
179187 def max_download_threads (self ) -> int :
@@ -182,7 +190,7 @@ def max_download_threads(self) -> int:
182190
183191 def open_session (
184192 self ,
185- session_configuration : Optional [Dict [str , str ]],
193+ session_configuration : Optional [Dict [str , Any ]],
186194 catalog : Optional [str ],
187195 schema : Optional [str ],
188196 ) -> SessionId :
@@ -220,7 +228,7 @@ def open_session(
220228 schema = schema ,
221229 )
222230
223- response = self .http_client ._make_request (
231+ response = self ._http_client ._make_request (
224232 method = "POST" , path = self .SESSION_PATH , data = request_data .to_dict ()
225233 )
226234
@@ -245,7 +253,7 @@ def close_session(self, session_id: SessionId) -> None:
245253 session_id: The session identifier returned by open_session()
246254
247255 Raises:
248- ProgrammingError : If the session ID is invalid
256+ ValueError : If the session ID is invalid
249257 OperationalError: If there's an error closing the session
250258 """
251259
@@ -260,7 +268,7 @@ def close_session(self, session_id: SessionId) -> None:
260268 session_id = sea_session_id ,
261269 )
262270
263- self .http_client ._make_request (
271+ self ._http_client ._make_request (
264272 method = "DELETE" ,
265273 path = self .SESSION_PATH_WITH_ID .format (sea_session_id ),
266274 data = request_data .to_dict (),
@@ -342,7 +350,7 @@ def _results_message_to_execute_response(
342350
343351 # Check for compression
344352 lz4_compressed = (
345- response .manifest .result_compression == ResultCompression .LZ4_FRAME
353+ response .manifest .result_compression == ResultCompression .LZ4_FRAME . value
346354 )
347355
348356 execute_response = ExecuteResponse (
@@ -451,7 +459,7 @@ def execute_command(
451459 enforce_embedded_schema_correctness: Whether to enforce schema correctness
452460
453461 Returns:
454- ResultSet : A SeaResultSet instance for the executed command
462+ SeaResultSet : A SeaResultSet instance for the executed command
455463 """
456464
457465 if session_id .backend_type != BackendType .SEA :
@@ -477,7 +485,11 @@ def execute_command(
477485 ResultFormat .ARROW_STREAM if use_cloud_fetch else ResultFormat .JSON_ARRAY
478486 ).value
479487 disposition = (
480- ResultDisposition .EXTERNAL_LINKS
488+ (
489+ ResultDisposition .HYBRID
490+ if self .use_hybrid_disposition
491+ else ResultDisposition .EXTERNAL_LINKS
492+ )
481493 if use_cloud_fetch
482494 else ResultDisposition .INLINE
483495 ).value
@@ -498,7 +510,7 @@ def execute_command(
498510 result_compression = result_compression ,
499511 )
500512
501- response_data = self .http_client ._make_request (
513+ response_data = self ._http_client ._make_request (
502514 method = "POST" , path = self .STATEMENT_PATH , data = request .to_dict ()
503515 )
504516 response = ExecuteStatementResponse .from_dict (response_data )
@@ -535,7 +547,7 @@ def cancel_command(self, command_id: CommandId) -> None:
535547 command_id: Command identifier to cancel
536548
537549 Raises:
538- ProgrammingError : If the command ID is invalid
550+ ValueError : If the command ID is invalid
539551 """
540552
541553 if command_id .backend_type != BackendType .SEA :
@@ -546,7 +558,7 @@ def cancel_command(self, command_id: CommandId) -> None:
546558 raise ValueError ("Not a valid SEA command ID" )
547559
548560 request = CancelStatementRequest (statement_id = sea_statement_id )
549- self .http_client ._make_request (
561+ self ._http_client ._make_request (
550562 method = "POST" ,
551563 path = self .CANCEL_STATEMENT_PATH_WITH_ID .format (sea_statement_id ),
552564 data = request .to_dict (),
@@ -560,7 +572,7 @@ def close_command(self, command_id: CommandId) -> None:
560572 command_id: Command identifier to close
561573
562574 Raises:
563- ProgrammingError : If the command ID is invalid
575+ ValueError : If the command ID is invalid
564576 """
565577
566578 if command_id .backend_type != BackendType .SEA :
@@ -571,7 +583,7 @@ def close_command(self, command_id: CommandId) -> None:
571583 raise ValueError ("Not a valid SEA command ID" )
572584
573585 request = CloseStatementRequest (statement_id = sea_statement_id )
574- self .http_client ._make_request (
586+ self ._http_client ._make_request (
575587 method = "DELETE" ,
576588 path = self .STATEMENT_PATH_WITH_ID .format (sea_statement_id ),
577589 data = request .to_dict (),
@@ -590,7 +602,7 @@ def _poll_query(self, command_id: CommandId) -> GetStatementResponse:
590602 raise ValueError ("Not a valid SEA command ID" )
591603
592604 request = GetStatementRequest (statement_id = sea_statement_id )
593- response_data = self .http_client ._make_request (
605+ response_data = self ._http_client ._make_request (
594606 method = "GET" ,
595607 path = self .STATEMENT_PATH_WITH_ID .format (sea_statement_id ),
596608 data = request .to_dict (),
@@ -638,6 +650,27 @@ def get_execution_result(
638650 response = self ._poll_query (command_id )
639651 return self ._response_to_result_set (response , cursor )
640652
653+ def get_chunk_links (
654+ self , statement_id : str , chunk_index : int
655+ ) -> List [ExternalLink ]:
656+ """
657+ Get links for chunks starting from the specified index.
658+ Args:
659+ statement_id: The statement ID
660+ chunk_index: The starting chunk index
661+ Returns:
662+ ExternalLink: External link for the chunk
663+ """
664+
665+ response_data = self ._http_client ._make_request (
666+ method = "GET" ,
667+ path = self .CHUNK_PATH_WITH_ID_AND_INDEX .format (statement_id , chunk_index ),
668+ )
669+ response = GetChunksResponse .from_dict (response_data )
670+
671+ links = response .external_links or []
672+ return links
673+
641674 # == Metadata Operations ==
642675
643676 def get_catalogs (
0 commit comments