1818
1919if TYPE_CHECKING :
2020 from databricks .sql .client import Cursor
21- from databricks .sql .backend .sea .result_set import SeaResultSet
21+
22+ from databricks .sql .backend .sea .result_set import SeaResultSet
2223
2324from databricks .sql .backend .databricks_client import DatabricksClient
2425from databricks .sql .backend .types import (
@@ -130,6 +131,8 @@ def __init__(
130131 "_use_arrow_native_complex_types" , True
131132 )
132133
134+ self .use_hybrid_disposition = kwargs .get ("use_hybrid_disposition" , True )
135+
133136 # Extract warehouse ID from http_path
134137 self .warehouse_id = self ._extract_warehouse_id (http_path )
135138
@@ -330,7 +333,7 @@ def _extract_description_from_manifest(
330333 return columns
331334
332335 def _results_message_to_execute_response (
333- self , response : GetStatementResponse
336+ self , response : Union [ ExecuteStatementResponse , GetStatementResponse ]
334337 ) -> ExecuteResponse :
335338 """
336339 Convert a SEA response to an ExecuteResponse and extract result data.
@@ -364,6 +367,27 @@ def _results_message_to_execute_response(
364367
365368 return execute_response
366369
370+ def _response_to_result_set (
371+ self ,
372+ response : Union [ExecuteStatementResponse , GetStatementResponse ],
373+ cursor : Cursor ,
374+ ) -> SeaResultSet :
375+ """
376+ Convert a SEA response to a SeaResultSet.
377+ """
378+
379+ execute_response = self ._results_message_to_execute_response (response )
380+
381+ return SeaResultSet (
382+ connection = cursor .connection ,
383+ execute_response = execute_response ,
384+ sea_client = self ,
385+ result_data = response .result ,
386+ manifest = response .manifest ,
387+ buffer_size_bytes = cursor .buffer_size_bytes ,
388+ arraysize = cursor .arraysize ,
389+ )
390+
367391 def _check_command_not_in_failed_or_closed_state (
368392 self , state : CommandState , command_id : CommandId
369393 ) -> None :
@@ -384,21 +408,24 @@ def _check_command_not_in_failed_or_closed_state(
384408
385409 def _wait_until_command_done (
386410 self , response : ExecuteStatementResponse
387- ) -> CommandState :
411+ ) -> Union [ ExecuteStatementResponse , GetStatementResponse ] :
388412 """
389413 Wait until a command is done.
390414 """
391415
392- state = response .status .state
393- command_id = CommandId .from_sea_statement_id (response .statement_id )
416+ final_response : Union [ExecuteStatementResponse , GetStatementResponse ] = response
417+
418+ state = final_response .status .state
419+ command_id = CommandId .from_sea_statement_id (final_response .statement_id )
394420
395421 while state in [CommandState .PENDING , CommandState .RUNNING ]:
396422 time .sleep (self .POLL_INTERVAL_SECONDS )
397- state = self .get_query_state (command_id )
423+ final_response = self ._poll_query (command_id )
424+ state = final_response .status .state
398425
399426 self ._check_command_not_in_failed_or_closed_state (state , command_id )
400427
401- return state
428+ return final_response
402429
403430 def execute_command (
404431 self ,
@@ -456,7 +483,11 @@ def execute_command(
456483 ResultFormat .ARROW_STREAM if use_cloud_fetch else ResultFormat .JSON_ARRAY
457484 ).value
458485 disposition = (
459- ResultDisposition .EXTERNAL_LINKS
486+ (
487+ ResultDisposition .HYBRID
488+ if self .use_hybrid_disposition
489+ else ResultDisposition .EXTERNAL_LINKS
490+ )
460491 if use_cloud_fetch
461492 else ResultDisposition .INLINE
462493 ).value
@@ -500,8 +531,11 @@ def execute_command(
500531 if async_op :
501532 return None
502533
503- self ._wait_until_command_done (response )
504- return self .get_execution_result (command_id , cursor )
534+ final_response : Union [ExecuteStatementResponse , GetStatementResponse ] = response
535+ if response .status .state != CommandState .SUCCEEDED :
536+ final_response = self ._wait_until_command_done (response )
537+
538+ return self ._response_to_result_set (final_response , cursor )
505539
506540 def cancel_command (self , command_id : CommandId ) -> None :
507541 """
@@ -553,18 +587,9 @@ def close_command(self, command_id: CommandId) -> None:
553587 data = request .to_dict (),
554588 )
555589
556- def get_query_state (self , command_id : CommandId ) -> CommandState :
590+ def _poll_query (self , command_id : CommandId ) -> GetStatementResponse :
557591 """
558- Get the state of a running query.
559-
560- Args:
561- command_id: Command identifier
562-
563- Returns:
564- CommandState: The current state of the command
565-
566- Raises:
567- ValueError: If the command ID is invalid
592+ Poll for the current command info.
568593 """
569594
570595 if command_id .backend_type != BackendType .SEA :
@@ -580,9 +605,25 @@ def get_query_state(self, command_id: CommandId) -> CommandState:
580605 path = self .STATEMENT_PATH_WITH_ID .format (sea_statement_id ),
581606 data = request .to_dict (),
582607 )
583-
584- # Parse the response
585608 response = GetStatementResponse .from_dict (response_data )
609+
610+ return response
611+
612+ def get_query_state (self , command_id : CommandId ) -> CommandState :
613+ """
614+ Get the state of a running query.
615+
616+ Args:
617+ command_id: Command identifier
618+
619+ Returns:
620+ CommandState: The current state of the command
621+
622+ Raises:
623+ ProgrammingError: If the command ID is invalid
624+ """
625+
626+ response = self ._poll_query (command_id )
586627 return response .status .state
587628
588629 def get_execution_result (
@@ -604,40 +645,12 @@ def get_execution_result(
604645 ValueError: If the command ID is invalid
605646 """
606647
607- if command_id .backend_type != BackendType .SEA :
608- raise ValueError ("Not a valid SEA command ID" )
609-
610- sea_statement_id = command_id .to_sea_statement_id ()
611- if sea_statement_id is None :
612- raise ValueError ("Not a valid SEA command ID" )
613-
614- # Create the request model
615- request = GetStatementRequest (statement_id = sea_statement_id )
616-
617- # Get the statement result
618- response_data = self ._http_client ._make_request (
619- method = "GET" ,
620- path = self .STATEMENT_PATH_WITH_ID .format (sea_statement_id ),
621- data = request .to_dict (),
622- )
623- response = GetStatementResponse .from_dict (response_data )
624-
625- # Create and return a SeaResultSet
626- from databricks .sql .backend .sea .result_set import SeaResultSet
648+ response = self ._poll_query (command_id )
649+ return self ._response_to_result_set (response , cursor )
627650
628- execute_response = self ._results_message_to_execute_response (response )
629-
630- return SeaResultSet (
631- connection = cursor .connection ,
632- execute_response = execute_response ,
633- sea_client = self ,
634- result_data = response .result ,
635- manifest = response .manifest ,
636- buffer_size_bytes = cursor .buffer_size_bytes ,
637- arraysize = cursor .arraysize ,
638- )
639-
640- def get_chunk_link (self , statement_id : str , chunk_index : int ) -> ExternalLink :
651+ def get_chunk_links (
652+ self , statement_id : str , chunk_index : int
653+ ) -> List [ExternalLink ]:
641654 """
642655 Get links for chunks starting from the specified index.
643656 Args:
@@ -654,17 +667,7 @@ def get_chunk_link(self, statement_id: str, chunk_index: int) -> ExternalLink:
654667 response = GetChunksResponse .from_dict (response_data )
655668
656669 links = response .external_links or []
657- link = next ((l for l in links if l .chunk_index == chunk_index ), None )
658- if not link :
659- raise ServerOperationError (
660- f"No link found for chunk index { chunk_index } " ,
661- {
662- "operation-id" : statement_id ,
663- "diagnostic-info" : None ,
664- },
665- )
666-
667- return link
670+ return links
668671
669672 # == Metadata Operations ==
670673
0 commit comments