Skip to content

Commit 76fdb98

Browse files
more fixes
Signed-off-by: Vikrant Puppala <vikrant.puppala@databricks.com>
1 parent 7c33fe4 commit 76fdb98

File tree

8 files changed

+30
-34
lines changed

8 files changed

+30
-34
lines changed

src/databricks/sql/auth/authenticators.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -199,8 +199,7 @@ def __init__(
199199
self.azure_client_secret = azure_client_secret
200200
self.azure_workspace_resource_id = azure_workspace_resource_id
201201
self.azure_tenant_id = azure_tenant_id or get_azure_tenant_id_from_host(
202-
hostname,
203-
http_client
202+
hostname, http_client
204203
)
205204
self._http_client = http_client
206205

src/databricks/sql/auth/common.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -115,18 +115,14 @@ def get_azure_tenant_id_from_host(host: str, http_client) -> str:
115115
login_url = f"{host}/aad/auth"
116116
logger.debug("Loading tenant ID from %s", login_url)
117117

118-
with http_client.request_context(
119-
HttpMethod.GET, login_url
120-
) as resp:
121-
# if resp.status // 100 != 3:
122-
# raise ValueError(
123-
# f"Failed to get tenant ID from {login_url}: expected status code 3xx, got {resp.status}"
124-
# )
118+
with http_client.request_context(HttpMethod.GET, login_url) as resp:
125119
entra_id_endpoint = resp.retries.history[-1].redirect_location
126120
if entra_id_endpoint is None:
127-
raise ValueError(f"No Location header in response from {login_url}: {entra_id_endpoint}")
121+
raise ValueError(
122+
f"No Location header in response from {login_url}: {entra_id_endpoint}"
123+
)
128124

129-
# The Location header has the following form: https://login.microsoftonline.com/<tenant-id>/oauth2/authorize?...
125+
# The final redirect URL has the following form: https://login.microsoftonline.com/<tenant-id>/oauth2/authorize?...
130126
# The domain may change depending on the Azure cloud (e.g. login.microsoftonline.us for US Government cloud).
131127
url = urlparse(entra_id_endpoint)
132128
path_segments = url.path.split("/")

src/databricks/sql/auth/oauth.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -336,15 +336,15 @@ def refresh(self) -> Token:
336336
**self.extra_params,
337337
}
338338
)
339-
340339

341340
response = self._http_client.request(
342341
method=HttpMethod.POST, url=self.token_url, headers=headers, body=data
343342
)
344343
try:
345344
if response.status == 200:
346-
import json
347-
oauth_response = OAuthResponse(**json.loads(response.data.decode('utf-8')))
345+
oauth_response = OAuthResponse(
346+
**json.loads(response.data.decode("utf-8"))
347+
)
348348
return Token(
349349
oauth_response.access_token,
350350
oauth_response.token_type,

src/databricks/sql/auth/retry.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -356,7 +356,13 @@ def should_retry(self, method: str, status_code: int) -> Tuple[bool, str]:
356356

357357
# Request succeeded. Don't retry.
358358
if status_code // 100 <= 3:
359-
return False, "2xx codes are not retried"
359+
return False, "2xx/3xx codes are not retried"
360+
361+
if status_code == 400:
362+
return (
363+
False,
364+
"Received 400 - BAD_REQUEST. Please check the request parameters.",
365+
)
360366

361367
if status_code == 401:
362368
return (

src/databricks/sql/common/unified_http_client.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,9 @@ def request_context(
154154
Yields:
155155
urllib3.HTTPResponse: The HTTP response object
156156
"""
157-
logger.debug("Making %s request to %s", method, urllib.parse.urlparse(url).netloc)
157+
logger.debug(
158+
"Making %s request to %s", method, urllib.parse.urlparse(url).netloc
159+
)
158160

159161
request_headers = self._prepare_headers(headers)
160162

src/databricks/sql/telemetry/telemetry_client.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -187,8 +187,6 @@ def __init__(
187187
self._executor = executor
188188

189189
# Create own HTTP client from client context
190-
from databricks.sql.common.unified_http_client import UnifiedHttpClient
191-
192190
self._http_client = UnifiedHttpClient(client_context)
193191

194192
def _export_event(self, event):

tests/e2e/common/retry_test_mixins.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -346,7 +346,7 @@ def test_retry_dangerous_codes(self, extra_params):
346346

347347
# These http codes are not retried by default
348348
# For some applications, idempotency is not important so we give users a way to force retries anyway
349-
DANGEROUS_CODES = [502, 504, 400]
349+
DANGEROUS_CODES = [502, 504]
350350

351351
additional_settings = {
352352
"_retry_dangerous_codes": DANGEROUS_CODES,

tests/unit/test_auth.py

Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -263,15 +263,11 @@ def test_get_token_success(self, token_source, http_response):
263263
with patch.object(token_source, "_http_client", mock_http_client):
264264
# Create a mock response with the expected format
265265
mock_response = MagicMock()
266-
mock_response.status_code = 200
267-
mock_response.json.return_value = {
268-
"access_token": "abc123",
269-
"token_type": "Bearer",
270-
"refresh_token": None,
271-
}
272-
# Mock the context manager (execute returns context manager)
273-
mock_http_client.execute.return_value.__enter__.return_value = mock_response
274-
mock_http_client.execute.return_value.__exit__.return_value = None
266+
mock_response.status = 200
267+
mock_response.data.decode.return_value = '{"access_token": "abc123", "token_type": "Bearer", "refresh_token": null}'
268+
269+
# Mock the request method to return the response directly
270+
mock_http_client.request.return_value = mock_response
275271

276272
token = token_source.get_token()
277273

@@ -287,12 +283,11 @@ def test_get_token_failure(self, token_source, http_response):
287283
with patch.object(token_source, "_http_client", mock_http_client):
288284
# Create a mock response with error
289285
mock_response = MagicMock()
290-
mock_response.status_code = 400
291-
mock_response.text = "Bad Request"
292-
mock_response.json.return_value = {"error": "invalid_client"}
293-
# Mock the context manager (execute returns context manager)
294-
mock_http_client.execute.return_value.__enter__.return_value = mock_response
295-
mock_http_client.execute.return_value.__exit__.return_value = None
286+
mock_response.status = 400
287+
mock_response.data.decode.return_value = "Bad Request"
288+
289+
# Mock the request method to return the response directly
290+
mock_http_client.request.return_value = mock_response
296291

297292
with pytest.raises(Exception) as e:
298293
token_source.get_token()

0 commit comments

Comments
 (0)