@@ -130,14 +130,35 @@ def __init__(
130130 # Extract warehouse ID from http_path
131131 self .warehouse_id = self ._extract_warehouse_id (http_path )
132132
133- # Initialize ThriftHttpClient
133+ # Extract retry policy parameters
134+ retry_policy = kwargs .get ("_retry_policy" , None )
135+ retry_stop_after_attempts_count = kwargs .get ("_retry_stop_after_attempts_count" , 30 )
136+ retry_stop_after_attempts_duration = kwargs .get ("_retry_stop_after_attempts_duration" , 600 )
137+ retry_delay_min = kwargs .get ("_retry_delay_min" , 1 )
138+ retry_delay_max = kwargs .get ("_retry_delay_max" , 60 )
139+ retry_delay_default = kwargs .get ("_retry_delay_default" , 5 )
140+ retry_dangerous_codes = kwargs .get ("_retry_dangerous_codes" , [])
141+
142+ # Create retry policy if not provided
143+ if not retry_policy :
144+ from databricks .sql .auth .retry import DatabricksRetryPolicy
145+ retry_policy = DatabricksRetryPolicy (
146+ delay_min = retry_delay_min ,
147+ delay_max = retry_delay_max ,
148+ stop_after_attempts_count = retry_stop_after_attempts_count ,
149+ stop_after_attempts_duration = retry_stop_after_attempts_duration ,
150+ delay_default = retry_delay_default ,
151+ force_dangerous_codes = retry_dangerous_codes ,
152+ )
153+
154+ # Initialize ThriftHttpClient with retry policy
134155 thrift_client = THttpClient (
135156 auth_provider = auth_provider ,
136157 uri_or_host = f"https://{ server_hostname } :{ port } " ,
137158 path = http_path ,
138159 ssl_options = ssl_options ,
139160 max_connections = kwargs .get ("max_connections" , 1 ),
140- retry_policy = kwargs . get ( "_retry_stop_after_attempts_count" , 30 ) ,
161+ retry_policy = retry_policy ,
141162 )
142163
143164 # Set custom headers
@@ -394,7 +415,7 @@ def _results_message_to_execute_response(self, sea_response, command_id):
394415 description = description ,
395416 has_been_closed_server_side = False ,
396417 lz4_compressed = lz4_compressed ,
397- is_staging_operation = False ,
418+ is_staging_operation = manifest_obj . is_volume_operation ,
398419 arrow_schema_bytes = None ,
399420 result_format = manifest_obj .format ,
400421 )
@@ -475,48 +496,56 @@ def execute_command(
475496 result_compression = result_compression ,
476497 )
477498
478- response_data = self .http_client .post (
479- path = self .STATEMENT_PATH , data = request .to_dict ()
480- )
481- response = ExecuteStatementResponse .from_dict (response_data )
482- statement_id = response .statement_id
483- if not statement_id :
484- raise ServerOperationError (
485- "Failed to execute command: No statement ID returned" ,
486- {
487- "operation-id" : None ,
488- "diagnostic-info" : None ,
489- },
499+ try :
500+ response_data = self .http_client .post (
501+ path = self .STATEMENT_PATH , data = request .to_dict ()
490502 )
503+ response = ExecuteStatementResponse .from_dict (response_data )
504+ statement_id = response .statement_id
505+ if not statement_id :
506+ raise ServerOperationError (
507+ "Failed to execute command: No statement ID returned" ,
508+ {
509+ "operation-id" : None ,
510+ "diagnostic-info" : None ,
511+ },
512+ )
491513
492- command_id = CommandId .from_sea_statement_id (statement_id )
514+ command_id = CommandId .from_sea_statement_id (statement_id )
493515
494- # Store the command ID in the cursor
495- cursor .active_command_id = command_id
516+ # Store the command ID in the cursor
517+ cursor .active_command_id = command_id
496518
497- # If async operation, return and let the client poll for results
498- if async_op :
499- return None
519+ # If async operation, return and let the client poll for results
520+ if async_op :
521+ return None
500522
501- # For synchronous operation, wait for the statement to complete
502- status = response .status
503- state = status .state
523+ # For synchronous operation, wait for the statement to complete
524+ status = response .status
525+ state = status .state
504526
505- # Keep polling until we reach a terminal state
506- while state in [CommandState .PENDING , CommandState .RUNNING ]:
507- time .sleep (0.5 ) # add a small delay to avoid excessive API calls
508- state = self .get_query_state (command_id )
527+ # Keep polling until we reach a terminal state
528+ while state in [CommandState .PENDING , CommandState .RUNNING ]:
529+ time .sleep (0.5 ) # add a small delay to avoid excessive API calls
530+ state = self .get_query_state (command_id )
509531
510- if state != CommandState .SUCCEEDED :
511- raise ServerOperationError (
512- f"Statement execution did not succeed: { status .error .message if status .error else 'Unknown error' } " ,
513- {
514- "operation-id" : command_id .to_sea_statement_id (),
515- "diagnostic-info" : None ,
516- },
517- )
532+ if state != CommandState .SUCCEEDED :
533+ raise ServerOperationError (
534+ f"Statement execution did not succeed: { status .error .message if status .error else 'Unknown error' } " ,
535+ {
536+ "operation-id" : command_id .to_sea_statement_id (),
537+ "diagnostic-info" : None ,
538+ },
539+ )
518540
519- return self .get_execution_result (command_id , cursor )
541+ return self .get_execution_result (command_id , cursor )
542+ except Exception as e :
543+ # Map exceptions to match Thrift behavior
544+ from databricks .sql .exc import RequestError , OperationalError
545+ if isinstance (e , (ServerOperationError , RequestError )):
546+ raise
547+ else :
548+ raise OperationalError (f"Error executing command: { str (e )} " )
520549
521550 def cancel_command (self , command_id : CommandId ) -> None :
522551 """
0 commit comments