@@ -130,35 +130,14 @@ def __init__(
130130 # Extract warehouse ID from http_path
131131 self .warehouse_id = self ._extract_warehouse_id (http_path )
132132
133- # Extract retry policy parameters
134- retry_policy = kwargs .get ("_retry_policy" , None )
135- retry_stop_after_attempts_count = kwargs .get ("_retry_stop_after_attempts_count" , 30 )
136- retry_stop_after_attempts_duration = kwargs .get ("_retry_stop_after_attempts_duration" , 600 )
137- retry_delay_min = kwargs .get ("_retry_delay_min" , 1 )
138- retry_delay_max = kwargs .get ("_retry_delay_max" , 60 )
139- retry_delay_default = kwargs .get ("_retry_delay_default" , 5 )
140- retry_dangerous_codes = kwargs .get ("_retry_dangerous_codes" , [])
141-
142- # Create retry policy if not provided
143- if not retry_policy :
144- from databricks .sql .auth .retry import DatabricksRetryPolicy
145- retry_policy = DatabricksRetryPolicy (
146- delay_min = retry_delay_min ,
147- delay_max = retry_delay_max ,
148- stop_after_attempts_count = retry_stop_after_attempts_count ,
149- stop_after_attempts_duration = retry_stop_after_attempts_duration ,
150- delay_default = retry_delay_default ,
151- force_dangerous_codes = retry_dangerous_codes ,
152- )
153-
154- # Initialize ThriftHttpClient with retry policy
133+ # Initialize ThriftHttpClient
155134 thrift_client = THttpClient (
156135 auth_provider = auth_provider ,
157136 uri_or_host = f"https://{ server_hostname } :{ port } " ,
158137 path = http_path ,
159138 ssl_options = ssl_options ,
160139 max_connections = kwargs .get ("max_connections" , 1 ),
161- retry_policy = retry_policy ,
140+ retry_policy = kwargs . get ( "_retry_stop_after_attempts_count" , 30 ) ,
162141 )
163142
164143 # Set custom headers
@@ -415,7 +394,7 @@ def _results_message_to_execute_response(self, sea_response, command_id):
415394 description = description ,
416395 has_been_closed_server_side = False ,
417396 lz4_compressed = lz4_compressed ,
418- is_staging_operation = manifest_obj . is_volume_operation ,
397+ is_staging_operation = False ,
419398 arrow_schema_bytes = None ,
420399 result_format = manifest_obj .format ,
421400 )
@@ -496,56 +475,48 @@ def execute_command(
496475 result_compression = result_compression ,
497476 )
498477
499- try :
500- response_data = self .http_client .post (
501- path = self .STATEMENT_PATH , data = request .to_dict ()
478+ response_data = self .http_client .post (
479+ path = self .STATEMENT_PATH , data = request .to_dict ()
480+ )
481+ response = ExecuteStatementResponse .from_dict (response_data )
482+ statement_id = response .statement_id
483+ if not statement_id :
484+ raise ServerOperationError (
485+ "Failed to execute command: No statement ID returned" ,
486+ {
487+ "operation-id" : None ,
488+ "diagnostic-info" : None ,
489+ },
502490 )
503- response = ExecuteStatementResponse .from_dict (response_data )
504- statement_id = response .statement_id
505- if not statement_id :
506- raise ServerOperationError (
507- "Failed to execute command: No statement ID returned" ,
508- {
509- "operation-id" : None ,
510- "diagnostic-info" : None ,
511- },
512- )
513491
514- command_id = CommandId .from_sea_statement_id (statement_id )
492+ command_id = CommandId .from_sea_statement_id (statement_id )
515493
516- # Store the command ID in the cursor
517- cursor .active_command_id = command_id
494+ # Store the command ID in the cursor
495+ cursor .active_command_id = command_id
518496
519- # If async operation, return and let the client poll for results
520- if async_op :
521- return None
497+ # If async operation, return and let the client poll for results
498+ if async_op :
499+ return None
522500
523- # For synchronous operation, wait for the statement to complete
524- status = response .status
525- state = status .state
501+ # For synchronous operation, wait for the statement to complete
502+ status = response .status
503+ state = status .state
526504
527- # Keep polling until we reach a terminal state
528- while state in [CommandState .PENDING , CommandState .RUNNING ]:
529- time .sleep (0.5 ) # add a small delay to avoid excessive API calls
530- state = self .get_query_state (command_id )
505+ # Keep polling until we reach a terminal state
506+ while state in [CommandState .PENDING , CommandState .RUNNING ]:
507+ time .sleep (0.5 ) # add a small delay to avoid excessive API calls
508+ state = self .get_query_state (command_id )
531509
532- if state != CommandState .SUCCEEDED :
533- raise ServerOperationError (
534- f"Statement execution did not succeed: { status .error .message if status .error else 'Unknown error' } " ,
535- {
536- "operation-id" : command_id .to_sea_statement_id (),
537- "diagnostic-info" : None ,
538- },
539- )
510+ if state != CommandState .SUCCEEDED :
511+ raise ServerOperationError (
512+ f"Statement execution did not succeed: { status .error .message if status .error else 'Unknown error' } " ,
513+ {
514+ "operation-id" : command_id .to_sea_statement_id (),
515+ "diagnostic-info" : None ,
516+ },
517+ )
540518
541- return self .get_execution_result (command_id , cursor )
542- except Exception as e :
543- # Map exceptions to match Thrift behavior
544- from databricks .sql .exc import RequestError , OperationalError
545- if isinstance (e , (ServerOperationError , RequestError )):
546- raise
547- else :
548- raise OperationalError (f"Error executing command: { str (e )} " )
519+ return self .get_execution_result (command_id , cursor )
549520
550521 def cancel_command (self , command_id : CommandId ) -> None :
551522 """
0 commit comments