Skip to content

Commit 9fc4c0c

Browse files
committed
Refactor token expiry handling in DatabricksTokenFederationProvider and enhance unit tests for accurate expiry verification
1 parent 7ab4068 commit 9fc4c0c

File tree

2 files changed

+16
-41
lines changed

2 files changed

+16
-41
lines changed

src/databricks/sql/auth/token_federation.py

Lines changed: 2 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -335,45 +335,13 @@ def _exchange_token(self, access_token: str) -> Token:
335335
token_type = resp_data.get("token_type", "Bearer")
336336
refresh_token = resp_data.get("refresh_token", "")
337337

338-
# Determine token expiry - first try from JWT claims
338+
# Extract expiry from JWT claims
339339
expiry = self._get_expiry_from_jwt(new_access_token)
340-
341-
# If JWT expiry not available, use expires_in from response
342340
if expiry is None:
343-
expiry = self._get_expiry_from_response(resp_data)
344-
345-
# If we still don't have an expiry, we can't proceed
346-
if expiry is None:
347-
raise ValueError(
348-
"Unable to determine token expiry from response or JWT claims"
349-
)
341+
raise ValueError("Unable to determine token expiry from JWT claims")
350342

351343
return Token(new_access_token, token_type, refresh_token, expiry)
352344

353-
def _get_expiry_from_response(
354-
self, resp_data: Dict[str, Any]
355-
) -> Optional[datetime]:
356-
"""
357-
Extract expiry datetime from response data.
358-
359-
Args:
360-
resp_data: Response data from token exchange
361-
362-
Returns:
363-
Optional[datetime]: Expiry datetime if found in response, None otherwise
364-
"""
365-
if "expires_in" not in resp_data or not resp_data["expires_in"]:
366-
return None
367-
368-
try:
369-
expires_in = int(resp_data["expires_in"])
370-
expiry = datetime.now(tz=timezone.utc) + timedelta(seconds=expires_in)
371-
logger.debug(f"Using expiry from expires_in: {expiry}")
372-
return expiry
373-
except (ValueError, TypeError) as e:
374-
logger.warning(f"Invalid expires_in value: {str(e)}")
375-
return None
376-
377345

378346
class SimpleCredentialsProvider(CredentialsProvider):
379347
"""A simple credentials provider that returns a fixed token."""

tests/unit/test_token_federation.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -344,20 +344,26 @@ def test_token_exchange_success(self, federation_provider):
344344
"""Test successful token exchange."""
345345
# Mock successful response
346346
with patch("databricks.sql.auth.token_federation.requests.post") as mock_post:
347+
# Create a token with a valid expiry
348+
expiry_timestamp = int(
349+
(datetime.now(tz=timezone.utc) + timedelta(hours=1)).timestamp()
350+
)
351+
347352
# Configure mock response
348353
mock_response = MagicMock()
349354
mock_response.status_code = 200
350355
mock_response.json.return_value = {
351356
"access_token": "new_token",
352357
"token_type": "Bearer",
353358
"refresh_token": "refresh_value",
354-
"expires_in": 3600,
355359
}
356360
mock_post.return_value = mock_response
357361

358-
# Patch the _get_expiry_from_jwt method to return None (forcing use of expires_in)
362+
# Mock JWT expiry extraction to return a valid expiry
359363
with patch.object(
360-
federation_provider, "_get_expiry_from_jwt", return_value=None
364+
federation_provider,
365+
"_get_expiry_from_jwt",
366+
return_value=datetime.fromtimestamp(expiry_timestamp, tz=timezone.utc),
361367
):
362368
# Call the exchange method
363369
token = federation_provider._exchange_token("original_token")
@@ -367,10 +373,11 @@ def test_token_exchange_success(self, federation_provider):
367373
assert token.token_type == "Bearer"
368374
assert token.refresh_token == "refresh_value"
369375

370-
# Verify expiry time (should be ~1 hour in future)
371-
now = datetime.now(tz=timezone.utc)
372-
assert token.expiry > now
373-
assert token.expiry < now + timedelta(seconds=3601)
376+
# Verify expiry time is correctly set
377+
expiry_datetime = datetime.fromtimestamp(
378+
expiry_timestamp, tz=timezone.utc
379+
)
380+
assert token.expiry == expiry_datetime
374381

375382
def test_token_exchange_failure(self, federation_provider):
376383
"""Test token exchange failure handling."""

0 commit comments

Comments
 (0)