Skip to content

Commit 85d0cd9

Browse files
committed
addresses comments
1 parent 9fc4c0c commit 85d0cd9

File tree

3 files changed

+60
-67
lines changed

3 files changed

+60
-67
lines changed

src/databricks/sql/auth/oidc_utils.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import logging
22
import requests
33
from typing import Optional
4+
from urllib.parse import urlparse
45

56
from databricks.sql.auth.endpoint import (
67
get_oauth_endpoints,
@@ -56,3 +57,19 @@ def format_hostname(hostname: str) -> str:
5657
if not hostname.endswith("/"):
5758
hostname = f"{hostname}/"
5859
return hostname
60+
61+
62+
def is_same_host(url1: str, url2: str) -> bool:
63+
"""
64+
Check if two URLs have the same host.
65+
"""
66+
try:
67+
if not url1.startswith(("http://", "https://")):
68+
url1 = f"https://{url1}"
69+
if not url2.startswith(("http://", "https://")):
70+
url2 = f"https://{url2}"
71+
parsed1 = urlparse(url1)
72+
parsed2 = urlparse(url2)
73+
return parsed1.netloc.lower() == parsed2.netloc.lower()
74+
except Exception:
75+
return False

src/databricks/sql/auth/token_federation.py

Lines changed: 35 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from requests.exceptions import RequestException
1111

1212
from databricks.sql.auth.authenticators import CredentialsProvider, HeaderFactory
13-
from databricks.sql.auth.oidc_utils import OIDCDiscoveryUtil
13+
from databricks.sql.auth.oidc_utils import OIDCDiscoveryUtil, is_same_host
1414
from databricks.sql.auth.token import Token
1515

1616
logger = logging.getLogger(__name__)
@@ -79,15 +79,6 @@ def __call__(self, *args, **kwargs) -> HeaderFactory:
7979
Configure and return a HeaderFactory that provides authentication headers.
8080
This is called by the ExternalAuthProvider to get headers for authentication.
8181
"""
82-
# First call the underlying credentials provider to get its headers
83-
header_factory = self.credentials_provider(*args, **kwargs)
84-
85-
# Get the standard token endpoint if not already set
86-
if self.token_endpoint is None:
87-
self.token_endpoint = OIDCDiscoveryUtil.discover_token_endpoint(
88-
self.hostname
89-
)
90-
9182
# Return a function that will get authentication headers
9283
return self.get_auth_headers
9384

@@ -156,34 +147,6 @@ def _get_expiry_from_jwt(self, token: str) -> Optional[datetime]:
156147

157148
return None
158149

159-
def _is_same_host(self, url1: str, url2: str) -> bool:
160-
"""
161-
Check if two URLs have the same host.
162-
163-
Args:
164-
url1: First URL
165-
url2: Second URL
166-
167-
Returns:
168-
bool: True if hosts are the same, False otherwise
169-
"""
170-
try:
171-
# Add protocol if missing to ensure proper parsing
172-
if not url1.startswith(("http://", "https://")):
173-
url1 = f"https://{url1}"
174-
if not url2.startswith(("http://", "https://")):
175-
url2 = f"https://{url2}"
176-
177-
# Parse the URLs
178-
parsed1 = urlparse(url1)
179-
parsed2 = urlparse(url2)
180-
181-
# Compare the hostnames
182-
return parsed1.netloc.lower() == parsed2.netloc.lower()
183-
except Exception as e:
184-
logger.warning(f"Error comparing hosts: {str(e)}")
185-
return False
186-
187150
def refresh_token(self) -> Token:
188151
"""
189152
Refresh the token and return the new Token object.
@@ -210,24 +173,34 @@ def refresh_token(self) -> Token:
210173
token_claims = self._parse_jwt_claims(access_token)
211174

212175
# Create new token based on whether it's from the same host or not
213-
if self._is_same_host(token_claims.get("iss", ""), self.hostname):
176+
if is_same_host(token_claims.get("iss", ""), self.hostname):
214177
# Token is from the same host, no need to exchange
215178
logger.debug("Token from same host, creating token without exchange")
216-
217179
expiry = self._get_expiry_from_jwt(access_token)
218180
if expiry is None:
219181
raise ValueError("Could not determine token expiry from JWT")
220-
221182
new_token = Token(access_token, token_type, "", expiry)
183+
self.current_token = new_token
184+
return new_token
222185
else:
223186
# Token is from a different host, need to exchange
224187
logger.debug("Token from different host, exchanging token")
225-
new_token = self._exchange_token(access_token)
226-
227-
# Store the token
228-
self.current_token = new_token
229-
230-
return new_token
188+
try:
189+
new_token = self._exchange_token(access_token)
190+
self.current_token = new_token
191+
return new_token
192+
except Exception as e:
193+
logger.error(
194+
f"Token exchange failed: {e}. Using external token as fallback."
195+
)
196+
expiry = self._get_expiry_from_jwt(access_token)
197+
if expiry is None:
198+
raise ValueError(
199+
"Could not determine token expiry from JWT (after exchange failure)"
200+
)
201+
fallback_token = Token(access_token, token_type, "", expiry)
202+
self.current_token = fallback_token
203+
return fallback_token
231204

232205
def get_current_token(self) -> Token:
233206
"""
@@ -254,24 +227,19 @@ def get_auth_headers(self) -> Dict[str, str]:
254227
"""
255228
Get authorization headers using the current token.
256229
257-
This method gets the current token and returns it formatted
258-
as authorization headers.
259-
260230
Returns:
261-
Dict[str, str]: Authorization headers
231+
Dict[str, str]: Authorization headers (may include extra headers from provider)
262232
"""
263233
try:
264234
token = self.get_current_token()
265-
return {"Authorization": f"{token.token_type} {token.access_token}"}
235+
# Always get the latest headers from the credentials provider
236+
header_factory = self.credentials_provider()
237+
headers = dict(header_factory()) if header_factory else {}
238+
headers["Authorization"] = f"{token.token_type} {token.access_token}"
239+
return headers
266240
except Exception as e:
267241
logger.error(f"Error getting auth headers: {str(e)}")
268-
269-
# Fall back to external headers if available
270-
if self.external_headers:
271-
return self.external_headers
272-
273-
# Return empty dict as a last resort
274-
return {}
242+
return dict(self.external_headers) if self.external_headers else {}
275243

276244
def _send_token_exchange_request(
277245
self, token_exchange_data: Dict[str, str]
@@ -286,7 +254,7 @@ def _send_token_exchange_request(
286254
Dict[str, Any]: Token exchange response
287255
288256
Raises:
289-
ValueError: If token exchange fails
257+
requests.HTTPError: If token exchange fails
290258
"""
291259
if not self.token_endpoint:
292260
raise ValueError("Token endpoint not initialized")
@@ -296,9 +264,9 @@ def _send_token_exchange_request(
296264
)
297265

298266
if response.status_code != 200:
299-
raise ValueError(
300-
f"Token exchange failed with status code {response.status_code}: "
301-
f"{response.text}"
267+
raise requests.HTTPError(
268+
f"Token exchange failed with status code {response.status_code}: {response.text}",
269+
response=response,
302270
)
303271

304272
return response.json()
@@ -316,6 +284,10 @@ def _exchange_token(self, access_token: str) -> Token:
316284
Raises:
317285
ValueError: If token exchange fails
318286
"""
287+
if self.token_endpoint is None:
288+
self.token_endpoint = OIDCDiscoveryUtil.discover_token_endpoint(
289+
self.hostname
290+
)
319291
# Prepare the request data
320292
token_exchange_data = dict(self.TOKEN_EXCHANGE_PARAMS)
321293
token_exchange_data["subject_token"] = access_token

tests/unit/test_token_federation.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ def mock_dependencies(self):
145145
"databricks.sql.auth.token_federation.DatabricksTokenFederationProvider._exchange_token"
146146
) as mock_exchange:
147147
with patch(
148-
"databricks.sql.auth.token_federation.DatabricksTokenFederationProvider._is_same_host"
148+
"databricks.sql.auth.oidc_utils.is_same_host"
149149
) as mock_is_same_host:
150150
with patch(
151151
"databricks.sql.auth.token_federation.requests.post"
@@ -179,9 +179,11 @@ def test_provider_initialization(self, federation_provider):
179179
("databricks.com", "https://databricks.com", True),
180180
],
181181
)
182-
def test_is_same_host(self, federation_provider, url1, url2, expected):
182+
def test_is_same_host(self, url1, url2, expected):
183183
"""Test host comparison logic with various URL formats."""
184-
assert federation_provider._is_same_host(url1, url2) is expected
184+
from databricks.sql.auth.oidc_utils import is_same_host
185+
186+
assert is_same_host(url1, url2) is expected
185187

186188
@pytest.mark.parametrize(
187189
"headers,expected_result,should_raise",
@@ -389,7 +391,9 @@ def test_token_exchange_failure(self, federation_provider):
389391
mock_post.return_value = mock_response
390392

391393
# Call the method and expect an exception
394+
import requests
395+
392396
with pytest.raises(
393-
ValueError, match="Token exchange failed with status code 401"
397+
requests.HTTPError, match="Token exchange failed with status code 401"
394398
):
395399
federation_provider._exchange_token("original_token")

0 commit comments

Comments
 (0)