@@ -130,14 +130,40 @@ def __init__(
130130 # Extract warehouse ID from http_path
131131 self .warehouse_id = self ._extract_warehouse_id (http_path )
132132
133- # Initialize ThriftHttpClient
133+ # Extract retry policy parameters
134+ retry_policy = kwargs .get ("_retry_policy" , None )
135+ retry_stop_after_attempts_count = kwargs .get (
136+ "_retry_stop_after_attempts_count" , 30
137+ )
138+ retry_stop_after_attempts_duration = kwargs .get (
139+ "_retry_stop_after_attempts_duration" , 600
140+ )
141+ retry_delay_min = kwargs .get ("_retry_delay_min" , 1 )
142+ retry_delay_max = kwargs .get ("_retry_delay_max" , 60 )
143+ retry_delay_default = kwargs .get ("_retry_delay_default" , 5 )
144+ retry_dangerous_codes = kwargs .get ("_retry_dangerous_codes" , [])
145+
146+ # Create retry policy if not provided
147+ if not retry_policy :
148+ from databricks .sql .auth .retry import DatabricksRetryPolicy
149+
150+ retry_policy = DatabricksRetryPolicy (
151+ delay_min = retry_delay_min ,
152+ delay_max = retry_delay_max ,
153+ stop_after_attempts_count = retry_stop_after_attempts_count ,
154+ stop_after_attempts_duration = retry_stop_after_attempts_duration ,
155+ delay_default = retry_delay_default ,
156+ force_dangerous_codes = retry_dangerous_codes ,
157+ )
158+
159+ # Initialize ThriftHttpClient with retry policy
134160 thrift_client = THttpClient (
135161 auth_provider = auth_provider ,
136162 uri_or_host = f"https://{ server_hostname } :{ port } " ,
137163 path = http_path ,
138164 ssl_options = ssl_options ,
139165 max_connections = kwargs .get ("max_connections" , 1 ),
140- retry_policy = kwargs . get ( "_retry_stop_after_attempts_count" , 30 ) ,
166+ retry_policy = retry_policy ,
141167 )
142168
143169 # Set custom headers
@@ -475,48 +501,99 @@ def execute_command(
475501 result_compression = result_compression ,
476502 )
477503
478- response_data = self .http_client .post (
479- path = self .STATEMENT_PATH , data = request .to_dict ()
480- )
481- response = ExecuteStatementResponse .from_dict (response_data )
482- statement_id = response .statement_id
483- if not statement_id :
484- raise ServerOperationError (
485- "Failed to execute command: No statement ID returned" ,
486- {
487- "operation-id" : None ,
488- "diagnostic-info" : None ,
489- },
504+ try :
505+ response_data = self .http_client .post (
506+ path = self .STATEMENT_PATH , data = request .to_dict ()
490507 )
508+ response = ExecuteStatementResponse .from_dict (response_data )
509+ statement_id = response .statement_id
510+
511+ if not statement_id :
512+ raise ServerOperationError (
513+ "Failed to execute command: No statement ID returned" ,
514+ {
515+ "operation-id" : None ,
516+ "diagnostic-info" : None ,
517+ },
518+ )
491519
492- command_id = CommandId .from_sea_statement_id (statement_id )
520+ command_id = CommandId .from_sea_statement_id (statement_id )
493521
494- # Store the command ID in the cursor
495- cursor .active_command_id = command_id
522+ # Store the command ID in the cursor
523+ cursor .active_command_id = command_id
496524
497- # If async operation, return and let the client poll for results
498- if async_op :
499- return None
525+ # If async operation, return and let the client poll for results
526+ if async_op :
527+ return None
500528
501- # For synchronous operation, wait for the statement to complete
502- status = response .status
503- state = status .state
529+ # For synchronous operation, wait for the statement to complete
530+ status = response .status
531+ state = status .state
504532
505- # Keep polling until we reach a terminal state
506- while state in [CommandState .PENDING , CommandState .RUNNING ]:
507- time .sleep (0.5 ) # add a small delay to avoid excessive API calls
508- state = self .get_query_state (command_id )
533+ # Keep polling until we reach a terminal state
534+ while state in [CommandState .PENDING , CommandState .RUNNING ]:
535+ time .sleep (0.5 ) # add a small delay to avoid excessive API calls
536+ state = self .get_query_state (command_id )
509537
510- if state != CommandState .SUCCEEDED :
511- raise ServerOperationError (
512- f"Statement execution did not succeed: { status .error .message if status .error else 'Unknown error' } " ,
513- {
514- "operation-id" : command_id .to_sea_statement_id (),
515- "diagnostic-info" : None ,
516- },
517- )
538+ if state != CommandState .SUCCEEDED :
539+ error_message = (
540+ status .error .message if status .error else "Unknown error"
541+ )
542+ error_code = status .error .error_code if status .error else None
543+
544+ # Map error codes to appropriate exceptions to match Thrift behavior
545+ from databricks .sql .exc import (
546+ DatabaseError ,
547+ ProgrammingError ,
548+ OperationalError ,
549+ )
550+
551+ if (
552+ error_code == "SYNTAX_ERROR"
553+ or "syntax error" in error_message .lower ()
554+ ):
555+ raise DatabaseError (
556+ f"Syntax error in SQL statement: { error_message } "
557+ )
558+ elif error_code == "TEMPORARILY_UNAVAILABLE" :
559+ raise OperationalError (
560+ f"Service temporarily unavailable: { error_message } "
561+ )
562+ elif error_code == "PERMISSION_DENIED" :
563+ raise OperationalError (f"Permission denied: { error_message } " )
564+ else :
565+ raise ServerOperationError (
566+ f"Statement execution failed: { error_message } " ,
567+ {
568+ "operation-id" : command_id .to_sea_statement_id (),
569+ "diagnostic-info" : None ,
570+ },
571+ )
518572
519- return self .get_execution_result (command_id , cursor )
573+ return self .get_execution_result (command_id , cursor )
574+
575+ except Exception as e :
576+ # Map exceptions to match Thrift behavior
577+ from databricks .sql .exc import DatabaseError , OperationalError , RequestError
578+
579+ if isinstance (e , (DatabaseError , OperationalError , RequestError )):
580+ # Pass through these exceptions as they're already properly typed
581+ raise
582+ elif "syntax error" in str (e ).lower ():
583+ # Syntax errors
584+ raise DatabaseError (f"Syntax error in SQL statement: { str (e )} " )
585+ elif "permission denied" in str (e ).lower ():
586+ # Permission errors
587+ raise OperationalError (f"Permission denied: { str (e )} " )
588+ elif "database" in str (e ).lower () and "not found" in str (e ).lower ():
589+ # Database not found errors
590+ raise DatabaseError (f"Database not found: { str (e )} " )
591+ elif "table" in str (e ).lower () and "not found" in str (e ).lower ():
592+ # Table not found errors
593+ raise DatabaseError (f"Table not found: { str (e )} " )
594+ else :
595+ # Generic operational errors
596+ raise OperationalError (f"Error executing command: { str (e )} " )
520597
521598 def cancel_command (self , command_id : CommandId ) -> None :
522599 """
0 commit comments