1818
1919if TYPE_CHECKING :
2020 from databricks .sql .client import Cursor
21- from databricks .sql .backend .sea .result_set import SeaResultSet
21+
22+ from databricks .sql .backend .sea .result_set import SeaResultSet
2223
2324from databricks .sql .backend .databricks_client import DatabricksClient
2425from databricks .sql .backend .types import (
5051
5152
5253def _filter_session_configuration (
53- session_configuration : Optional [Dict [str , str ]]
54- ) -> Optional [ Dict [str , str ] ]:
54+ session_configuration : Optional [Dict [str , Any ]],
55+ ) -> Dict [str , str ]:
5556 if not session_configuration :
56- return None
57+ return {}
5758
5859 filtered_session_configuration = {}
5960 ignored_configs : Set [str ] = set ()
6061
6162 for key , value in session_configuration .items ():
6263 if key .upper () in ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP :
63- filtered_session_configuration [key .lower ()] = value
64+ filtered_session_configuration [key .lower ()] = str ( value )
6465 else :
6566 ignored_configs .add (key )
6667
@@ -124,15 +125,19 @@ def __init__(
124125 http_path ,
125126 )
126127
127- super ().__init__ (ssl_options = ssl_options , ** kwargs )
128+ self ._max_download_threads = kwargs .get ("max_download_threads" , 10 )
129+ self ._ssl_options = ssl_options
130+ self ._use_arrow_native_complex_types = kwargs .get (
131+ "_use_arrow_native_complex_types" , True
132+ )
128133
129134 self .use_hybrid_disposition = kwargs .get ("use_hybrid_disposition" , True )
130135
131136 # Extract warehouse ID from http_path
132137 self .warehouse_id = self ._extract_warehouse_id (http_path )
133138
134139 # Initialize HTTP client
135- self .http_client = SeaHttpClient (
140+ self ._http_client = SeaHttpClient (
136141 server_hostname = server_hostname ,
137142 port = port ,
138143 http_path = http_path ,
@@ -186,7 +191,7 @@ def max_download_threads(self) -> int:
186191
187192 def open_session (
188193 self ,
189- session_configuration : Optional [Dict [str , str ]],
194+ session_configuration : Optional [Dict [str , Any ]],
190195 catalog : Optional [str ],
191196 schema : Optional [str ],
192197 ) -> SessionId :
@@ -224,7 +229,7 @@ def open_session(
224229 schema = schema ,
225230 )
226231
227- response = self .http_client ._make_request (
232+ response = self ._http_client ._make_request (
228233 method = "POST" , path = self .SESSION_PATH , data = request_data .to_dict ()
229234 )
230235
@@ -264,7 +269,7 @@ def close_session(self, session_id: SessionId) -> None:
264269 session_id = sea_session_id ,
265270 )
266271
267- self .http_client ._make_request (
272+ self ._http_client ._make_request (
268273 method = "DELETE" ,
269274 path = self .SESSION_PATH_WITH_ID .format (sea_session_id ),
270275 data = request_data .to_dict (),
@@ -328,7 +333,7 @@ def _extract_description_from_manifest(
328333 return columns
329334
330335 def _results_message_to_execute_response (
331- self , response : GetStatementResponse
336+ self , response : Union [ ExecuteStatementResponse , GetStatementResponse ]
332337 ) -> ExecuteResponse :
333338 """
334339 Convert a SEA response to an ExecuteResponse and extract result data.
@@ -362,6 +367,27 @@ def _results_message_to_execute_response(
362367
363368 return execute_response
364369
370+ def _response_to_result_set (
371+ self ,
372+ response : Union [ExecuteStatementResponse , GetStatementResponse ],
373+ cursor : Cursor ,
374+ ) -> SeaResultSet :
375+ """
376+ Convert a SEA response to a SeaResultSet.
377+ """
378+
379+ execute_response = self ._results_message_to_execute_response (response )
380+
381+ return SeaResultSet (
382+ connection = cursor .connection ,
383+ execute_response = execute_response ,
384+ sea_client = self ,
385+ result_data = response .result ,
386+ manifest = response .manifest ,
387+ buffer_size_bytes = cursor .buffer_size_bytes ,
388+ arraysize = cursor .arraysize ,
389+ )
390+
365391 def _check_command_not_in_failed_or_closed_state (
366392 self , state : CommandState , command_id : CommandId
367393 ) -> None :
@@ -382,21 +408,24 @@ def _check_command_not_in_failed_or_closed_state(
382408
383409 def _wait_until_command_done (
384410 self , response : ExecuteStatementResponse
385- ) -> CommandState :
411+ ) -> Union [ ExecuteStatementResponse , GetStatementResponse ] :
386412 """
387413 Wait until a command is done.
388414 """
389415
390- state = response .status .state
391- command_id = CommandId .from_sea_statement_id (response .statement_id )
416+ final_response : Union [ExecuteStatementResponse , GetStatementResponse ] = response
417+
418+ state = final_response .status .state
419+ command_id = CommandId .from_sea_statement_id (final_response .statement_id )
392420
393421 while state in [CommandState .PENDING , CommandState .RUNNING ]:
394422 time .sleep (self .POLL_INTERVAL_SECONDS )
395- state = self .get_query_state (command_id )
423+ final_response = self ._poll_query (command_id )
424+ state = final_response .status .state
396425
397426 self ._check_command_not_in_failed_or_closed_state (state , command_id )
398427
399- return state
428+ return final_response
400429
401430 def execute_command (
402431 self ,
@@ -443,7 +472,9 @@ def execute_command(
443472 sea_parameters .append (
444473 StatementParameter (
445474 name = param .name ,
446- value = param .value ,
475+ value = (
476+ param .value .stringValue if param .value is not None else None
477+ ),
447478 type = param .type ,
448479 )
449480 )
@@ -477,7 +508,7 @@ def execute_command(
477508 result_compression = result_compression ,
478509 )
479510
480- response_data = self .http_client ._make_request (
511+ response_data = self ._http_client ._make_request (
481512 method = "POST" , path = self .STATEMENT_PATH , data = request .to_dict ()
482513 )
483514 response = ExecuteStatementResponse .from_dict (response_data )
@@ -500,8 +531,11 @@ def execute_command(
500531 if async_op :
501532 return None
502533
503- self ._wait_until_command_done (response )
504- return self .get_execution_result (command_id , cursor )
534+ final_response : Union [ExecuteStatementResponse , GetStatementResponse ] = response
535+ if response .status .state != CommandState .SUCCEEDED :
536+ final_response = self ._wait_until_command_done (response )
537+
538+ return self ._response_to_result_set (final_response , cursor )
505539
506540 def cancel_command (self , command_id : CommandId ) -> None :
507541 """
@@ -522,7 +556,7 @@ def cancel_command(self, command_id: CommandId) -> None:
522556 raise ValueError ("Not a valid SEA command ID" )
523557
524558 request = CancelStatementRequest (statement_id = sea_statement_id )
525- self .http_client ._make_request (
559+ self ._http_client ._make_request (
526560 method = "POST" ,
527561 path = self .CANCEL_STATEMENT_PATH_WITH_ID .format (sea_statement_id ),
528562 data = request .to_dict (),
@@ -547,24 +581,15 @@ def close_command(self, command_id: CommandId) -> None:
547581 raise ValueError ("Not a valid SEA command ID" )
548582
549583 request = CloseStatementRequest (statement_id = sea_statement_id )
550- self .http_client ._make_request (
584+ self ._http_client ._make_request (
551585 method = "DELETE" ,
552586 path = self .STATEMENT_PATH_WITH_ID .format (sea_statement_id ),
553587 data = request .to_dict (),
554588 )
555589
556- def get_query_state (self , command_id : CommandId ) -> CommandState :
590+ def _poll_query (self , command_id : CommandId ) -> GetStatementResponse :
557591 """
558- Get the state of a running query.
559-
560- Args:
561- command_id: Command identifier
562-
563- Returns:
564- CommandState: The current state of the command
565-
566- Raises:
567- ValueError: If the command ID is invalid
592+ Poll for the current command info.
568593 """
569594
570595 if command_id .backend_type != BackendType .SEA :
@@ -575,14 +600,30 @@ def get_query_state(self, command_id: CommandId) -> CommandState:
575600 raise ValueError ("Not a valid SEA command ID" )
576601
577602 request = GetStatementRequest (statement_id = sea_statement_id )
578- response_data = self .http_client ._make_request (
603+ response_data = self ._http_client ._make_request (
579604 method = "GET" ,
580605 path = self .STATEMENT_PATH_WITH_ID .format (sea_statement_id ),
581606 data = request .to_dict (),
582607 )
583-
584- # Parse the response
585608 response = GetStatementResponse .from_dict (response_data )
609+
610+ return response
611+
612+ def get_query_state (self , command_id : CommandId ) -> CommandState :
613+ """
614+ Get the state of a running query.
615+
616+ Args:
617+ command_id: Command identifier
618+
619+ Returns:
620+ CommandState: The current state of the command
621+
622+ Raises:
623+ ProgrammingError: If the command ID is invalid
624+ """
625+
626+ response = self ._poll_query (command_id )
586627 return response .status .state
587628
588629 def get_execution_result (
@@ -604,38 +645,8 @@ def get_execution_result(
604645 ValueError: If the command ID is invalid
605646 """
606647
607- if command_id .backend_type != BackendType .SEA :
608- raise ValueError ("Not a valid SEA command ID" )
609-
610- sea_statement_id = command_id .to_sea_statement_id ()
611- if sea_statement_id is None :
612- raise ValueError ("Not a valid SEA command ID" )
613-
614- # Create the request model
615- request = GetStatementRequest (statement_id = sea_statement_id )
616-
617- # Get the statement result
618- response_data = self .http_client ._make_request (
619- method = "GET" ,
620- path = self .STATEMENT_PATH_WITH_ID .format (sea_statement_id ),
621- data = request .to_dict (),
622- )
623- response = GetStatementResponse .from_dict (response_data )
624-
625- # Create and return a SeaResultSet
626- from databricks .sql .backend .sea .result_set import SeaResultSet
627-
628- execute_response = self ._results_message_to_execute_response (response )
629-
630- return SeaResultSet (
631- connection = cursor .connection ,
632- execute_response = execute_response ,
633- sea_client = self ,
634- result_data = response .result ,
635- manifest = response .manifest ,
636- buffer_size_bytes = cursor .buffer_size_bytes ,
637- arraysize = cursor .arraysize ,
638- )
648+ response = self ._poll_query (command_id )
649+ return self ._response_to_result_set (response , cursor )
639650
640651 def get_chunk_links (
641652 self , statement_id : str , chunk_index : int
@@ -649,13 +660,13 @@ def get_chunk_links(
649660 ExternalLink: External link for the chunk
650661 """
651662
652- response_data = self .http_client ._make_request (
663+ response_data = self ._http_client ._make_request (
653664 method = "GET" ,
654665 path = self .CHUNK_PATH_WITH_ID_AND_INDEX .format (statement_id , chunk_index ),
655666 )
656667 response = GetChunksResponse .from_dict (response_data )
657668
658- links = response .external_links
669+ links = response .external_links or []
659670 return links
660671
661672 # == Metadata Operations ==
0 commit comments