@@ -130,14 +130,40 @@ def __init__(
130130 # Extract warehouse ID from http_path
131131 self .warehouse_id = self ._extract_warehouse_id (http_path )
132132
133+ # Extract retry policy parameters
134+ retry_policy = kwargs .get ("_retry_policy" , None )
135+ retry_stop_after_attempts_count = kwargs .get (
136+ "_retry_stop_after_attempts_count" , 30
137+ )
138+ retry_stop_after_attempts_duration = kwargs .get (
139+ "_retry_stop_after_attempts_duration" , 600
140+ )
141+ retry_delay_min = kwargs .get ("_retry_delay_min" , 1 )
142+ retry_delay_max = kwargs .get ("_retry_delay_max" , 60 )
143+ retry_delay_default = kwargs .get ("_retry_delay_default" , 5 )
144+ retry_dangerous_codes = kwargs .get ("_retry_dangerous_codes" , [])
145+
146+ # Create retry policy if not provided
147+ if not retry_policy :
148+ from databricks .sql .auth .retry import DatabricksRetryPolicy
149+
150+ retry_policy = DatabricksRetryPolicy (
151+ delay_min = retry_delay_min ,
152+ delay_max = retry_delay_max ,
153+ stop_after_attempts_count = retry_stop_after_attempts_count ,
154+ stop_after_attempts_duration = retry_stop_after_attempts_duration ,
155+ delay_default = retry_delay_default ,
156+ force_dangerous_codes = retry_dangerous_codes ,
157+ )
158+
133159 # Initialize ThriftHttpClient
134160 thrift_client = THttpClient (
135161 auth_provider = auth_provider ,
136162 uri_or_host = f"https://{ server_hostname } :{ port } " ,
137163 path = http_path ,
138164 ssl_options = ssl_options ,
139165 max_connections = kwargs .get ("max_connections" , 1 ),
140- retry_policy = kwargs . get ( "_retry_stop_after_attempts_count" , 30 ),
166+ retry_policy = retry_policy , # Use the configured retry policy
141167 )
142168
143169 # Set custom headers
@@ -229,22 +255,31 @@ def open_session(
229255 schema = schema ,
230256 )
231257
232- response = self .http_client .post (
233- path = self .SESSION_PATH , data = request_data .to_dict ()
234- )
235-
236- session_response = CreateSessionResponse .from_dict (response )
237- session_id = session_response .session_id
238- if not session_id :
239- raise ServerOperationError (
240- "Failed to create session: No session ID returned" ,
241- {
242- "operation-id" : None ,
243- "diagnostic-info" : None ,
244- },
258+ try :
259+ response = self .http_client .post (
260+ path = self .SESSION_PATH , data = request_data .to_dict ()
245261 )
246262
247- return SessionId .from_sea_session_id (session_id )
263+ session_response = CreateSessionResponse .from_dict (response )
264+ session_id = session_response .session_id
265+ if not session_id :
266+ raise ServerOperationError (
267+ "Failed to create session: No session ID returned" ,
268+ {
269+ "operation-id" : None ,
270+ "diagnostic-info" : None ,
271+ },
272+ )
273+
274+ return SessionId .from_sea_session_id (session_id )
275+ except Exception as e :
276+ # Map exceptions to match Thrift behavior
277+ from databricks .sql .exc import RequestError , OperationalError
278+
279+ if isinstance (e , (RequestError , ServerOperationError )):
280+ raise
281+ else :
282+ raise OperationalError (f"Error opening session: { str (e )} " )
248283
249284 def close_session (self , session_id : SessionId ) -> None :
250285 """
@@ -269,10 +304,25 @@ def close_session(self, session_id: SessionId) -> None:
269304 session_id = sea_session_id ,
270305 )
271306
272- self .http_client .delete (
273- path = self .SESSION_PATH_WITH_ID .format (sea_session_id ),
274- data = request_data .to_dict (),
275- )
307+ try :
308+ self .http_client .delete (
309+ path = self .SESSION_PATH_WITH_ID .format (sea_session_id ),
310+ data = request_data .to_dict (),
311+ )
312+ except Exception as e :
313+ # Map exceptions to match Thrift behavior
314+ from databricks .sql .exc import (
315+ RequestError ,
316+ OperationalError ,
317+ SessionAlreadyClosedError ,
318+ )
319+
320+ if isinstance (e , RequestError ) and "404" in str (e ):
321+ raise SessionAlreadyClosedError ("Session is already closed" )
322+ elif isinstance (e , (RequestError , ServerOperationError )):
323+ raise
324+ else :
325+ raise OperationalError (f"Error closing session: { str (e )} " )
276326
277327 @staticmethod
278328 def get_default_session_configuration_value (name : str ) -> Optional [str ]:
@@ -475,48 +525,57 @@ def execute_command(
475525 result_compression = result_compression ,
476526 )
477527
478- response_data = self .http_client .post (
479- path = self .STATEMENT_PATH , data = request .to_dict ()
480- )
481- response = ExecuteStatementResponse .from_dict (response_data )
482- statement_id = response .statement_id
483- if not statement_id :
484- raise ServerOperationError (
485- "Failed to execute command: No statement ID returned" ,
486- {
487- "operation-id" : None ,
488- "diagnostic-info" : None ,
489- },
528+ try :
529+ response_data = self .http_client .post (
530+ path = self .STATEMENT_PATH , data = request .to_dict ()
490531 )
532+ response = ExecuteStatementResponse .from_dict (response_data )
533+ statement_id = response .statement_id
534+ if not statement_id :
535+ raise ServerOperationError (
536+ "Failed to execute command: No statement ID returned" ,
537+ {
538+ "operation-id" : None ,
539+ "diagnostic-info" : None ,
540+ },
541+ )
491542
492- command_id = CommandId .from_sea_statement_id (statement_id )
543+ command_id = CommandId .from_sea_statement_id (statement_id )
493544
494- # Store the command ID in the cursor
495- cursor .active_command_id = command_id
545+ # Store the command ID in the cursor
546+ cursor .active_command_id = command_id
496547
497- # If async operation, return and let the client poll for results
498- if async_op :
499- return None
548+ # If async operation, return and let the client poll for results
549+ if async_op :
550+ return None
500551
501- # For synchronous operation, wait for the statement to complete
502- status = response .status
503- state = status .state
552+ # For synchronous operation, wait for the statement to complete
553+ status = response .status
554+ state = status .state
504555
505- # Keep polling until we reach a terminal state
506- while state in [CommandState .PENDING , CommandState .RUNNING ]:
507- time .sleep (0.5 ) # add a small delay to avoid excessive API calls
508- state = self .get_query_state (command_id )
556+ # Keep polling until we reach a terminal state
557+ while state in [CommandState .PENDING , CommandState .RUNNING ]:
558+ time .sleep (0.5 ) # add a small delay to avoid excessive API calls
559+ state = self .get_query_state (command_id )
509560
510- if state != CommandState .SUCCEEDED :
511- raise ServerOperationError (
512- f"Statement execution did not succeed: { status .error .message if status .error else 'Unknown error' } " ,
513- {
514- "operation-id" : command_id .to_sea_statement_id (),
515- "diagnostic-info" : None ,
516- },
517- )
561+ if state != CommandState .SUCCEEDED :
562+ raise ServerOperationError (
563+ f"Statement execution did not succeed: { status .error .message if status .error else 'Unknown error' } " ,
564+ {
565+ "operation-id" : command_id .to_sea_statement_id (),
566+ "diagnostic-info" : None ,
567+ },
568+ )
518569
519- return self .get_execution_result (command_id , cursor )
570+ return self .get_execution_result (command_id , cursor )
571+ except Exception as e :
572+ # Map exceptions to match Thrift behavior
573+ from databricks .sql .exc import RequestError , OperationalError
574+
575+ if isinstance (e , (RequestError , ServerOperationError )):
576+ raise
577+ else :
578+ raise OperationalError (f"Error executing command: { str (e )} " )
520579
521580 def cancel_command (self , command_id : CommandId ) -> None :
522581 """
@@ -535,10 +594,25 @@ def cancel_command(self, command_id: CommandId) -> None:
535594 sea_statement_id = command_id .to_sea_statement_id ()
536595
537596 request = CancelStatementRequest (statement_id = sea_statement_id )
538- self .http_client .post (
539- path = self .CANCEL_STATEMENT_PATH_WITH_ID .format (sea_statement_id ),
540- data = request .to_dict (),
541- )
597+ try :
598+ self .http_client .post (
599+ path = self .CANCEL_STATEMENT_PATH_WITH_ID .format (sea_statement_id ),
600+ data = request .to_dict (),
601+ )
602+ except Exception as e :
603+ # Map exceptions to match Thrift behavior
604+ from databricks .sql .exc import RequestError , OperationalError
605+
606+ if isinstance (e , RequestError ) and "404" in str (e ):
607+ # Operation was already closed, so we can ignore this
608+ logger .warning (
609+ f"Attempted to cancel a command that was already closed: { sea_statement_id } "
610+ )
611+ return
612+ elif isinstance (e , (RequestError , ServerOperationError )):
613+ raise
614+ else :
615+ raise OperationalError (f"Error canceling command: { str (e )} " )
542616
543617 def close_command (self , command_id : CommandId ) -> None :
544618 """
@@ -557,10 +631,25 @@ def close_command(self, command_id: CommandId) -> None:
557631 sea_statement_id = command_id .to_sea_statement_id ()
558632
559633 request = CloseStatementRequest (statement_id = sea_statement_id )
560- self .http_client .delete (
561- path = self .STATEMENT_PATH_WITH_ID .format (sea_statement_id ),
562- data = request .to_dict (),
563- )
634+ try :
635+ self .http_client .delete (
636+ path = self .STATEMENT_PATH_WITH_ID .format (sea_statement_id ),
637+ data = request .to_dict (),
638+ )
639+ except Exception as e :
640+ # Map exceptions to match Thrift behavior
641+ from databricks .sql .exc import (
642+ RequestError ,
643+ OperationalError ,
644+ CursorAlreadyClosedError ,
645+ )
646+
647+ if isinstance (e , RequestError ) and "404" in str (e ):
648+ raise CursorAlreadyClosedError ("Cursor is already closed" )
649+ elif isinstance (e , (RequestError , ServerOperationError )):
650+ raise
651+ else :
652+ raise OperationalError (f"Error closing command: { str (e )} " )
564653
565654 def get_query_state (self , command_id : CommandId ) -> CommandState :
566655 """
@@ -582,13 +671,28 @@ def get_query_state(self, command_id: CommandId) -> CommandState:
582671 sea_statement_id = command_id .to_sea_statement_id ()
583672
584673 request = GetStatementRequest (statement_id = sea_statement_id )
585- response_data = self .http_client .get (
586- path = self .STATEMENT_PATH_WITH_ID .format (sea_statement_id ),
587- )
674+ try :
675+ response_data = self .http_client .get (
676+ path = self .STATEMENT_PATH_WITH_ID .format (sea_statement_id ),
677+ )
588678
589- # Parse the response
590- response = GetStatementResponse .from_dict (response_data )
591- return response .status .state
679+ # Parse the response
680+ response = GetStatementResponse .from_dict (response_data )
681+ return response .status .state
682+ except Exception as e :
683+ # Map exceptions to match Thrift behavior
684+ from databricks .sql .exc import RequestError , OperationalError
685+
686+ if isinstance (e , RequestError ) and "404" in str (e ):
687+ # If the operation is not found, it was likely already closed
688+ logger .warning (
689+ f"Operation not found when checking state: { sea_statement_id } "
690+ )
691+ return CommandState .CANCELLED
692+ elif isinstance (e , (RequestError , ServerOperationError )):
693+ raise
694+ else :
695+ raise OperationalError (f"Error getting query state: { str (e )} " )
592696
593697 def get_execution_result (
594698 self ,
@@ -617,30 +721,39 @@ def get_execution_result(
617721 # Create the request model
618722 request = GetStatementRequest (statement_id = sea_statement_id )
619723
620- # Get the statement result
621- response_data = self .http_client .get (
622- path = self .STATEMENT_PATH_WITH_ID .format (sea_statement_id ),
623- )
724+ try :
725+ # Get the statement result
726+ response_data = self .http_client .get (
727+ path = self .STATEMENT_PATH_WITH_ID .format (sea_statement_id ),
728+ )
624729
625- # Create and return a SeaResultSet
626- from databricks .sql .result_set import SeaResultSet
627-
628- # Convert the response to an ExecuteResponse and extract result data
629- (
630- execute_response ,
631- result_data ,
632- manifest ,
633- ) = self ._results_message_to_execute_response (response_data , command_id )
634-
635- return SeaResultSet (
636- connection = cursor .connection ,
637- execute_response = execute_response ,
638- sea_client = self ,
639- buffer_size_bytes = cursor .buffer_size_bytes ,
640- arraysize = cursor .arraysize ,
641- result_data = result_data ,
642- manifest = manifest ,
643- )
730+ # Create and return a SeaResultSet
731+ from databricks .sql .result_set import SeaResultSet
732+
733+ # Convert the response to an ExecuteResponse and extract result data
734+ (
735+ execute_response ,
736+ result_data ,
737+ manifest ,
738+ ) = self ._results_message_to_execute_response (response_data , command_id )
739+
740+ return SeaResultSet (
741+ connection = cursor .connection ,
742+ execute_response = execute_response ,
743+ sea_client = self ,
744+ buffer_size_bytes = cursor .buffer_size_bytes ,
745+ arraysize = cursor .arraysize ,
746+ result_data = result_data ,
747+ manifest = manifest ,
748+ )
749+ except Exception as e :
750+ # Map exceptions to match Thrift behavior
751+ from databricks .sql .exc import RequestError , OperationalError
752+
753+ if isinstance (e , (RequestError , ServerOperationError )):
754+ raise
755+ else :
756+ raise OperationalError (f"Error getting execution result: { str (e )} " )
644757
645758 # == Metadata Operations ==
646759
0 commit comments