Skip to content

Commit 41b2815

Browse files
committed
added retry handling based on idempotency
1 parent 048fae1 commit 41b2815

File tree

10 files changed

+165
-39
lines changed

10 files changed

+165
-39
lines changed

src/databricks/sql/auth/oauth.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from oauthlib.oauth2.rfc6749.errors import OAuth2Error
1313
from databricks.sql.common.http import HttpMethod, HttpHeader
1414
from databricks.sql.common.http import OAuthResponse
15+
from databricks.sql.auth.retry import CommandType
1516
from databricks.sql.auth.oauth_http_handler import OAuthHttpSingleRequestHandler
1617
from databricks.sql.auth.endpoint import OAuthEndpointCollection
1718
from abc import abstractmethod, ABC
@@ -87,6 +88,8 @@ def __fetch_well_known_config(self, hostname: str):
8788
known_config_url = self.idp_endpoint.get_openid_config_url(hostname)
8889

8990
try:
91+
# Set command type for OAuth configuration request
92+
self.http_client.setRequestType(CommandType.AUTH)
9093
response = self.http_client.request(HttpMethod.GET, url=known_config_url)
9194
# Convert urllib3 response to requests-like response for compatibility
9295
response.status_code = response.status
@@ -195,6 +198,8 @@ def __send_token_request(self, token_request_url, data):
195198
"Accept": "application/json",
196199
"Content-Type": "application/x-www-form-urlencoded",
197200
}
201+
# Set command type for OAuth token request
202+
self.http_client.setRequestType(CommandType.AUTH)
198203
# Use unified HTTP client
199204
response = self.http_client.request(
200205
HttpMethod.POST, url=token_request_url, body=data, headers=headers
@@ -337,6 +342,8 @@ def refresh(self) -> Token:
337342
}
338343
)
339344

345+
# Set command type for OAuth client credentials request
346+
self._http_client.setRequestType(CommandType.AUTH)
340347
response = self._http_client.request(
341348
method=HttpMethod.POST, url=self.token_url, headers=headers, body=data
342349
)

src/databricks/sql/auth/retry.py

Lines changed: 82 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,20 @@
3232

3333

3434
class CommandType(Enum):
35-
EXECUTE_STATEMENT = "ExecuteStatement"
35+
NOT_SET = "NotSet"
36+
OPEN_SESSION = "OpenSession"
3637
CLOSE_SESSION = "CloseSession"
38+
METADATA = "Metadata"
3739
CLOSE_OPERATION = "CloseOperation"
38-
GET_OPERATION_STATUS = "GetOperationStatus"
40+
CANCEL_OPERATION = "CancelOperation"
41+
EXECUTE_STATEMENT = "ExecuteStatement"
42+
FETCH_RESULTS = "FetchResults"
43+
CLOUD_FETCH = "CloudFetch"
44+
AUTH = "Auth"
45+
TELEMETRY_PUSH = "TelemetryPush"
46+
VOLUME_GET = "VolumeGet"
47+
VOLUME_PUT = "VolumePut"
48+
VOLUME_DELETE = "VolumeDelete"
3949
OTHER = "Other"
4050

4151
@classmethod
@@ -45,9 +55,66 @@ def get(cls, value: str):
4555
if valid_command:
4656
return getattr(cls, str(valid_command))
4757
else:
58+
# Map Thrift metadata operations to METADATA type
59+
metadata_operations = {
60+
"GetOperationStatus", "GetResultSetMetadata", "GetTables",
61+
"GetColumns", "GetSchemas", "GetCatalogs", "GetFunctions",
62+
"GetPrimaryKeys", "GetTypeInfo", "GetCrossReference",
63+
"GetImportedKeys", "GetExportedKeys", "GetTableTypes"
64+
}
65+
if value in metadata_operations:
66+
return cls.METADATA
4867
return cls.OTHER
4968

5069

70+
class CommandIdempotency(Enum):
71+
IDEMPOTENT = "idempotent"
72+
NON_IDEMPOTENT = "non_idempotent"
73+
74+
75+
# Mapping of CommandType to CommandIdempotency
76+
# Based on the official idempotency classification
77+
COMMAND_IDEMPOTENCY_MAP = {
78+
# NON-IDEMPOTENT operations (safety first - unknown types are not retried)
79+
CommandType.NOT_SET: CommandIdempotency.NON_IDEMPOTENT,
80+
CommandType.EXECUTE_STATEMENT: CommandIdempotency.NON_IDEMPOTENT,
81+
CommandType.FETCH_RESULTS: CommandIdempotency.NON_IDEMPOTENT,
82+
CommandType.VOLUME_PUT: CommandIdempotency.NON_IDEMPOTENT, # PUT can overwrite files
83+
84+
# IDEMPOTENT operations
85+
CommandType.OPEN_SESSION: CommandIdempotency.IDEMPOTENT,
86+
CommandType.CLOSE_SESSION: CommandIdempotency.IDEMPOTENT,
87+
CommandType.METADATA: CommandIdempotency.IDEMPOTENT,
88+
CommandType.CLOSE_OPERATION: CommandIdempotency.IDEMPOTENT,
89+
CommandType.CANCEL_OPERATION: CommandIdempotency.IDEMPOTENT,
90+
CommandType.CLOUD_FETCH: CommandIdempotency.IDEMPOTENT,
91+
CommandType.AUTH: CommandIdempotency.IDEMPOTENT,
92+
CommandType.TELEMETRY_PUSH: CommandIdempotency.IDEMPOTENT,
93+
CommandType.VOLUME_GET: CommandIdempotency.IDEMPOTENT,
94+
CommandType.VOLUME_DELETE: CommandIdempotency.IDEMPOTENT,
95+
CommandType.OTHER: CommandIdempotency.IDEMPOTENT,
96+
}
97+
98+
# HTTP status codes that should never be retried, even for idempotent requests
99+
# These are client error codes that indicate permanent issues
100+
NON_RETRYABLE_STATUS_CODES = {
101+
400, # Bad Request
102+
401, # Unauthorized
103+
403, # Forbidden
104+
404, # Not Found
105+
405, # Method Not Allowed
106+
409, # Conflict
107+
410, # Gone
108+
411, # Length Required
109+
412, # Precondition Failed
110+
413, # Payload Too Large
111+
414, # URI Too Long
112+
415, # Unsupported Media Type
113+
416, # Range Not Satisfiable
114+
501, # Not Implemented
115+
}
116+
117+
51118
class DatabricksRetryPolicy(Retry):
52119
"""
53120
Implements our v3 retry policy by extending urllib3's robust default retry behaviour.
@@ -354,38 +421,25 @@ def should_retry(self, method: str, status_code: int) -> Tuple[bool, str]:
354421

355422
logger.info(f"Received status code {status_code} for {method} request")
356423

424+
# Get command idempotency for use in multiple conditions below
425+
command_idempotency = COMMAND_IDEMPOTENCY_MAP.get(
426+
self.command_type, CommandIdempotency.NON_IDEMPOTENT
427+
)
428+
357429
# Request succeeded. Don't retry.
358430
if status_code // 100 <= 3:
359431
return False, "2xx/3xx codes are not retried"
360432

361-
if status_code == 400:
362-
return (
363-
False,
364-
"Received 400 - BAD_REQUEST. Please check the request parameters.",
365-
)
366-
367-
if status_code == 401:
368-
return (
369-
False,
370-
"Received 401 - UNAUTHORIZED. Confirm your authentication credentials.",
371-
)
372-
373-
if status_code == 403:
374-
return False, "403 codes are not retried"
375-
376-
# Request failed and server said NotImplemented. This isn't recoverable. Don't retry.
377-
if status_code == 501:
378-
return False, "Received code 501 from server."
379433

380434
# Request failed and this method is not retryable. We only retry POST requests.
381435
if not self._is_method_retryable(method):
382436
return False, "Only POST requests are retried"
383437

384438
# Request failed with 404 and was a metadata request (e.g. GetOperationStatus). This is not recoverable. Don't retry.
385-
if status_code == 404 and self.command_type == CommandType.GET_OPERATION_STATUS:
439+
if status_code == 404 and self.command_type == CommandType.METADATA:
386440
return (
387441
False,
388-
"GetOperationStatus received 404 code from Databricks. Operation was canceled.",
442+
"Metadata request received 404 code from Databricks. Operation was canceled.",
389443
)
390444

391445
# Request failed with 404 because CloseSession returns 404 if you repeat the request.
@@ -408,23 +462,26 @@ def should_retry(self, method: str, status_code: int) -> Tuple[bool, str]:
408462
"CloseOperation received 404 code from Databricks. Cursor is already closed."
409463
)
410464

465+
if status_code in NON_RETRYABLE_STATUS_CODES:
466+
return False, f"Received {status_code} code from Databricks. Operation was canceled."
467+
411468
# Request failed, was a non-idempotent command (e.g. ExecuteStatement) and the command may have reached the server
412469
if (
413-
self.command_type == CommandType.EXECUTE_STATEMENT
470+
command_idempotency == CommandIdempotency.NON_IDEMPOTENT
414471
and status_code not in self.status_forcelist
415472
and status_code not in self.force_dangerous_codes
416473
):
417474
return (
418475
False,
419-
"ExecuteStatement command can only be retried for codes 429 and 503",
476+
"Non Idempotent requests can only be retried for codes 429 and 503",
420477
)
421478

422479
# Request failed with a dangerous code, was a non-idempotent command (e.g. ExecuteStatement), but user forced retries for this
423480
# dangerous code. Note that these lines _are not required_ to make these requests retry. They would
424481
# retry automatically. This code is included only so that we can log the exact reason for the retry.
425482
# This gives users signal that their _retry_dangerous_codes setting actually did something.
426483
if (
427-
self.command_type == CommandType.EXECUTE_STATEMENT
484+
command_idempotency == CommandIdempotency.NON_IDEMPOTENT
428485
and status_code in self.force_dangerous_codes
429486
):
430487
return (

src/databricks/sql/backend/sea/utils/http_client.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -285,13 +285,19 @@ def _get_command_type_from_path(self, path: str, method: str) -> CommandType:
285285
if method == "POST" and path.endswith("/statements"):
286286
return CommandType.EXECUTE_STATEMENT
287287
elif "/cancel" in path:
288-
return CommandType.OTHER # Cancel operation
288+
return CommandType.CANCEL_OPERATION
289289
elif method == "DELETE":
290290
return CommandType.CLOSE_OPERATION
291291
elif method == "GET":
292-
return CommandType.GET_OPERATION_STATUS
292+
# For GET requests on statements, determine if it's fetching results or status
293+
if "/result/chunks/" in path:
294+
return CommandType.FETCH_RESULTS
295+
else:
296+
return CommandType.METADATA # Statement status queries
293297
elif "/sessions" in path:
294-
if method == "DELETE":
298+
if method == "POST" and path.endswith("/sessions"):
299+
return CommandType.OPEN_SESSION
300+
elif method == "DELETE":
295301
return CommandType.CLOSE_SESSION
296302

297-
return CommandType.OTHER
303+
return CommandType.NOT_SET

src/databricks/sql/client.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from databricks.sql.thrift_api.TCLIService import ttypes
2626
from databricks.sql.backend.thrift_backend import ThriftDatabricksClient
2727
from databricks.sql.backend.databricks_client import DatabricksClient
28+
from databricks.sql.auth.retry import CommandType
2829
from databricks.sql.utils import (
2930
ParamEscaper,
3031
inject_parameters,
@@ -774,6 +775,9 @@ def _handle_staging_put(
774775
session_id_hex=self.connection.get_session_id_hex(),
775776
)
776777

778+
# Set command type for volume PUT operation
779+
self.connection.http_client.setRequestType(CommandType.VOLUME_PUT)
780+
777781
with open(local_file, "rb") as fh:
778782
r = self.connection.http_client.request(
779783
HttpMethod.PUT, presigned_url, body=fh.read(), headers=headers
@@ -830,6 +834,9 @@ def _handle_staging_put_stream(
830834
session_id_hex=self.connection.get_session_id_hex(),
831835
)
832836

837+
# Set command type for volume PUT stream operation
838+
self.connection.http_client.setRequestType(CommandType.VOLUME_PUT)
839+
833840
r = self.connection.http_client.request(
834841
HttpMethod.PUT, presigned_url, body=stream.read(), headers=headers
835842
)
@@ -851,6 +858,9 @@ def _handle_staging_get(
851858
session_id_hex=self.connection.get_session_id_hex(),
852859
)
853860

861+
# Set command type for volume GET operation
862+
self.connection.http_client.setRequestType(CommandType.VOLUME_GET)
863+
854864
r = self.connection.http_client.request(
855865
HttpMethod.GET, presigned_url, headers=headers
856866
)
@@ -874,6 +884,9 @@ def _handle_staging_remove(
874884
):
875885
"""Make an HTTP DELETE request to the presigned_url"""
876886

887+
# Set command type for volume DELETE operation
888+
self.connection.http_client.setRequestType(CommandType.VOLUME_DELETE)
889+
877890
r = self.connection.http_client.request(
878891
HttpMethod.DELETE, presigned_url, headers=headers
879892
)

src/databricks/sql/common/unified_http_client.py

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -218,11 +218,37 @@ def _prepare_headers(
218218
def _prepare_retry_policy(self):
219219
"""Set up the retry policy for the current request."""
220220
if isinstance(self._retry_policy, DatabricksRetryPolicy):
221-
# Set command type for HTTP requests to OTHER (not database commands)
222-
self._retry_policy.command_type = CommandType.OTHER
221+
# Only set command type to NOT_SET if it hasn't been explicitly set via setRequestType()
222+
if self._retry_policy.command_type is None:
223+
self._retry_policy.command_type = CommandType.NOT_SET
223224
# Start the retry timer for duration-based retry limits
224225
self._retry_policy.start_retry_timer()
225226

227+
def setRequestType(self, request_type: CommandType):
228+
"""
229+
Set the specific request type for the next HTTP request.
230+
231+
This allows clients to specify what type of operation they're performing
232+
so the retry policy can make appropriate idempotency decisions.
233+
234+
Args:
235+
request_type: The CommandType enum value for this operation
236+
237+
Example:
238+
# For authentication requests (OAuth, etc.)
239+
http_client.setRequestType(CommandType.AUTH)
240+
response = http_client.request(HttpMethod.POST, url, body=data)
241+
242+
# For cloud fetch operations
243+
http_client.setRequestType(CommandType.CLOUD_FETCH)
244+
response = http_client.request(HttpMethod.GET, cloud_url)
245+
"""
246+
if isinstance(self._retry_policy, DatabricksRetryPolicy):
247+
self._retry_policy.command_type = request_type
248+
logger.debug(f"Set request type to: {request_type.value}")
249+
else:
250+
logger.warning(f"Cannot set request type {request_type.value}: retry policy is not DatabricksRetryPolicy")
251+
226252
@contextmanager
227253
def request_context(
228254
self,
@@ -269,6 +295,11 @@ def request_context(
269295
logger.error("HTTP request error: %s", e)
270296
raise RequestError(f"HTTP request error: {e}")
271297
finally:
298+
# Reset command type after request completion to prevent it from affecting subsequent requests
299+
if isinstance(self._retry_policy, DatabricksRetryPolicy):
300+
self._retry_policy.command_type = None
301+
logger.debug("Reset command type after request completion")
302+
272303
if response:
273304
response.close()
274305

tests/e2e/common/large_queries_mixin.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def fetch_rows(self, cursor, row_count, fetchmany_size):
3333
rows = self.get_some_rows(cursor, fetchmany_size)
3434
if not rows:
3535
# Read all the rows, row_count should match
36-
self.assertEqual(n, row_count)
36+
assert n == row_count
3737

3838
num_fetches = max(math.ceil(n / 10000), 1)
3939
latency_ms = int((time.time() - start_time) * 1000 / num_fetches), 1

tests/e2e/common/retry_test_mixins.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -278,7 +278,7 @@ def test_retry_max_count_not_exceeded(self, mock_send_telemetry, extra_params):
278278
THEN the connector issues six requests (original plus five retries)
279279
before raising an exception
280280
"""
281-
with mocked_server_response(status=404) as mock_obj:
281+
with mocked_server_response(status=429) as mock_obj:
282282
with pytest.raises(MaxRetryError) as cm:
283283
extra_params = {**extra_params, **self._retry_policy}
284284
with self.connection(extra_params=extra_params) as conn:

tests/e2e/common/staging_ingestion_tests.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def test_staging_ingestion_life_cycle(self, ingestion_user):
8181
# GET after REMOVE should fail
8282

8383
with pytest.raises(
84-
Error, match="too many 404 error responses"
84+
Error, match="Staging operation over HTTP was unsuccessful: 404"
8585
):
8686
cursor = conn.cursor()
8787
query = f"GET 'stage://tmp/{ingestion_user}/tmp/11/16/file1.csv' TO '{new_temp_path}'"

tests/e2e/common/uc_volume_tests.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def test_uc_volume_life_cycle(self, catalog, schema):
8181
# GET after REMOVE should fail
8282

8383
with pytest.raises(
84-
Error, match="too many 404 error responses"
84+
Error, match="Staging operation over HTTP was unsuccessful: 404"
8585
):
8686
cursor = conn.cursor()
8787
query = f"GET '/Volumes/{catalog}/{schema}/e2etests/file1.csv' TO '{new_temp_path}'"

0 commit comments

Comments
 (0)