Skip to content

Commit d00e3c8

Browse files
fix e2e
Signed-off-by: Vikrant Puppala <vikrant.puppala@databricks.com>
1 parent 3155211 commit d00e3c8

File tree

9 files changed: +61 additions, -35 deletions

src/databricks/sql/auth/common.py

Lines changed: 2 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -69,12 +69,10 @@ def __init__(
6969
# HTTP client configuration
7070
self.ssl_options = ssl_options
7171
self.socket_timeout = socket_timeout
72-
self.retry_stop_after_attempts_count = retry_stop_after_attempts_count or 30
72+
self.retry_stop_after_attempts_count = retry_stop_after_attempts_count or 5
7373
self.retry_delay_min = retry_delay_min or 1.0
7474
self.retry_delay_max = retry_delay_max or 60.0
75-
self.retry_stop_after_attempts_duration = (
76-
retry_stop_after_attempts_duration or 900.0
77-
)
75+
self.retry_stop_after_attempts_duration = retry_stop_after_attempts_duration or 900.0
7876
self.retry_delay_default = retry_delay_default or 5.0
7977
self.retry_dangerous_codes = retry_dangerous_codes or []
8078
self.http_proxy = http_proxy

src/databricks/sql/auth/retry.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -294,7 +294,7 @@ def sleep_for_retry(self, response: BaseHTTPResponse) -> bool:
294294
else:
295295
proposed_wait = self.get_backoff_time()
296296

297-
proposed_wait = max(proposed_wait, self.delay_max)
297+
proposed_wait = min(proposed_wait, self.delay_max)
298298
self.check_proposed_wait(proposed_wait)
299299
logger.debug(f"Retrying after {proposed_wait} seconds")
300300
time.sleep(proposed_wait)
@@ -355,8 +355,8 @@ def should_retry(self, method: str, status_code: int) -> Tuple[bool, str]:
355355
logger.info(f"Received status code {status_code} for {method} request")
356356

357357
# Request succeeded. Don't retry.
358-
if status_code == 200:
359-
return False, "200 codes are not retried"
358+
if status_code // 100 == 2:
359+
return False, "2xx codes are not retried"
360360

361361
if status_code == 401:
362362
return (

src/databricks/sql/backend/thrift_backend.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -194,7 +194,7 @@ def __init__(
194194

195195
if _max_redirects:
196196
if _max_redirects > self._retry_stop_after_attempts_count:
197-
logger.warn(
197+
logger.warning(
198198
"_retry_max_redirects > _retry_stop_after_attempts_count so it will have no affect!"
199199
)
200200
urllib3_kwargs = {"redirect": _max_redirects}

src/databricks/sql/client.py

Lines changed: 8 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -376,21 +376,17 @@ def _build_client_context(self, server_hostname: str, **kwargs):
376376
hostname=server_hostname,
377377
ssl_options=ssl_options,
378378
socket_timeout=kwargs.get("_socket_timeout"),
379-
retry_stop_after_attempts_count=kwargs.get(
380-
"_retry_stop_after_attempts_count", 30
381-
),
382-
retry_delay_min=kwargs.get("_retry_delay_min", 1.0),
383-
retry_delay_max=kwargs.get("_retry_delay_max", 60.0),
384-
retry_stop_after_attempts_duration=kwargs.get(
385-
"_retry_stop_after_attempts_duration", 900.0
386-
),
387-
retry_delay_default=kwargs.get("_retry_delay_default", 1.0),
388-
retry_dangerous_codes=kwargs.get("_retry_dangerous_codes", []),
379+
retry_stop_after_attempts_count=kwargs.get("_retry_stop_after_attempts_count"),
380+
retry_delay_min=kwargs.get("_retry_delay_min"),
381+
retry_delay_max=kwargs.get("_retry_delay_max"),
382+
retry_stop_after_attempts_duration=kwargs.get("_retry_stop_after_attempts_duration"),
383+
retry_delay_default=kwargs.get("_retry_delay_default"),
384+
retry_dangerous_codes=kwargs.get("_retry_dangerous_codes"),
389385
http_proxy=kwargs.get("_http_proxy"),
390386
proxy_username=kwargs.get("_proxy_username"),
391387
proxy_password=kwargs.get("_proxy_password"),
392-
pool_connections=kwargs.get("_pool_connections", 1),
393-
pool_maxsize=kwargs.get("_pool_maxsize", 1),
388+
pool_connections=kwargs.get("_pool_connections"),
389+
pool_maxsize=kwargs.get("_pool_maxsize"),
394390
user_agent=user_agent,
395391
)
396392

src/databricks/sql/common/unified_http_client.py

Lines changed: 21 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,7 @@
99
from urllib3.util import make_headers
1010
from urllib3.exceptions import MaxRetryError
1111

12-
from databricks.sql.auth.retry import DatabricksRetryPolicy
12+
from databricks.sql.auth.retry import DatabricksRetryPolicy, CommandType
1313
from databricks.sql.exc import RequestError
1414

1515
logger = logging.getLogger(__name__)
@@ -33,6 +33,7 @@ def __init__(self, client_context):
3333
"""
3434
self.config = client_context
3535
self._pool_manager = None
36+
self._retry_policy = None
3637
self._setup_pool_manager()
3738

3839
def _setup_pool_manager(self):
@@ -69,20 +70,25 @@ def _setup_pool_manager(self):
6970
)
7071

7172
# Create retry policy
72-
retry_policy = DatabricksRetryPolicy(
73+
self._retry_policy = DatabricksRetryPolicy(
7374
delay_min=self.config.retry_delay_min,
7475
delay_max=self.config.retry_delay_max,
7576
stop_after_attempts_count=self.config.retry_stop_after_attempts_count,
7677
stop_after_attempts_duration=self.config.retry_stop_after_attempts_duration,
7778
delay_default=self.config.retry_delay_default,
7879
force_dangerous_codes=self.config.retry_dangerous_codes,
7980
)
81+
82+
# Initialize the required attributes that DatabricksRetryPolicy expects
83+
# but doesn't initialize in its constructor
84+
self._retry_policy._command_type = None
85+
self._retry_policy._retry_start_time = None
8086

8187
# Common pool manager kwargs
8288
pool_kwargs = {
8389
"num_pools": self.config.pool_connections,
8490
"maxsize": self.config.pool_maxsize,
85-
"retries": retry_policy,
91+
"retries": self._retry_policy,
8692
"timeout": urllib3.Timeout(
8793
connect=self.config.socket_timeout, read=self.config.socket_timeout
8894
)
@@ -119,6 +125,14 @@ def _prepare_headers(
119125

120126
return request_headers
121127

128+
def _prepare_retry_policy(self):
129+
"""Set up the retry policy for the current request."""
130+
if isinstance(self._retry_policy, DatabricksRetryPolicy):
131+
# Set command type for HTTP requests to OTHER (not database commands)
132+
self._retry_policy.command_type = CommandType.OTHER
133+
# Start the retry timer for duration-based retry limits
134+
self._retry_policy.start_retry_timer()
135+
122136
@contextmanager
123137
def request_context(
124138
self, method: str, url: str, headers: Optional[Dict[str, str]] = None, **kwargs
@@ -138,6 +152,10 @@ def request_context(
138152
logger.debug("Making %s request to %s", method, url)
139153

140154
request_headers = self._prepare_headers(headers)
155+
156+
# Prepare retry policy for this request
157+
self._prepare_retry_policy()
158+
141159
response = None
142160

143161
try:

tests/e2e/common/retry_test_mixins.py

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -247,6 +247,7 @@ def test_retry_exponential_backoff(self, mock_send_telemetry, extra_params):
247247
"""
248248
retry_policy = self._retry_policy.copy()
249249
retry_policy["_retry_delay_min"] = 1
250+
retry_policy["_retry_delay_max"] = 10
250251

251252
time_start = time.time()
252253
with mocked_server_response(
@@ -282,9 +283,11 @@ def test_retry_max_duration_not_exceeded(self, extra_params):
282283
WHEN the server sends a Retry-After header of 60 seconds
283284
THEN the connector raises a MaxRetryDurationError
284285
"""
286+
retry_policy = self._retry_policy.copy()
287+
retry_policy["_retry_delay_max"] = 60
285288
with mocked_server_response(status=429, headers={"Retry-After": "60"}):
286289
with pytest.raises(RequestError) as cm:
287-
extra_params = {**extra_params, **self._retry_policy}
290+
extra_params = {**extra_params, **retry_policy}
288291
with self.connection(extra_params=extra_params) as conn:
289292
pass
290293
assert isinstance(cm.value.args[1], MaxRetryDurationError)

tests/e2e/common/staging_ingestion_tests.py

Lines changed: 7 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -68,15 +68,19 @@ def test_staging_ingestion_life_cycle(self, ingestion_user):
6868
# REMOVE should succeed
6969

7070
remove_query = f"REMOVE 'stage://tmp/{ingestion_user}/tmp/11/16/file1.csv'"
71-
72-
with self.connection(extra_params={"staging_allowed_local_path": "/"}) as conn:
71+
# Use minimal retry settings to fail fast for staging operations
72+
extra_params = {
73+
"staging_allowed_local_path": "/",
74+
"_retry_stop_after_attempts_count": 1,
75+
}
76+
with self.connection(extra_params=extra_params) as conn:
7377
cursor = conn.cursor()
7478
cursor.execute(remove_query)
7579

7680
# GET after REMOVE should fail
7781

7882
with pytest.raises(
79-
Error, match="Staging operation over HTTP was unsuccessful: 404"
83+
Error, match="too many 404 error responses"
8084
):
8185
cursor = conn.cursor()
8286
query = f"GET 'stage://tmp/{ingestion_user}/tmp/11/16/file1.csv' TO '{new_temp_path}'"

tests/e2e/common/uc_volume_tests.py

Lines changed: 7 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -68,14 +68,19 @@ def test_uc_volume_life_cycle(self, catalog, schema):
6868

6969
remove_query = f"REMOVE '/Volumes/{catalog}/{schema}/e2etests/file1.csv'"
7070

71-
with self.connection(extra_params={"staging_allowed_local_path": "/"}) as conn:
71+
# Use minimal retry settings to fail fast
72+
extra_params = {
73+
"staging_allowed_local_path": "/",
74+
"_retry_stop_after_attempts_count": 1,
75+
}
76+
with self.connection(extra_params=extra_params) as conn:
7277
cursor = conn.cursor()
7378
cursor.execute(remove_query)
7479

7580
# GET after REMOVE should fail
7681

7782
with pytest.raises(
78-
Error, match="Staging operation over HTTP was unsuccessful: 404"
83+
Error, match="too many 404 error responses"
7984
):
8085
cursor = conn.cursor()
8186
query = f"GET '/Volumes/{catalog}/{schema}/e2etests/file1.csv' TO '{new_temp_path}'"

tests/e2e/test_driver.py

Lines changed: 8 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -60,12 +60,14 @@
6060
unsafe_logger.addHandler(logging.FileHandler("./tests-unsafe.log"))
6161

6262
# manually decorate DecimalTestsMixin to need arrow support
63-
for name in loader.getTestCaseNames(DecimalTestsMixin, "test_"):
64-
fn = getattr(DecimalTestsMixin, name)
65-
decorated = skipUnless(pysql_supports_arrow(), "Decimal tests need arrow support")(
66-
fn
67-
)
68-
setattr(DecimalTestsMixin, name, decorated)
63+
test_loader = loader.TestLoader()
64+
for name in test_loader.getTestCaseNames(DecimalTestsMixin):
65+
if name.startswith("test_"):
66+
fn = getattr(DecimalTestsMixin, name)
67+
decorated = skipUnless(pysql_supports_arrow(), "Decimal tests need arrow support")(
68+
fn
69+
)
70+
setattr(DecimalTestsMixin, name, decorated)
6971

7072

7173
class PySQLPytestTestCase:

0 commit comments

Comments
 (0)