Skip to content

Commit d54ba93

Browse files
committed
Enhance token federation refresh to get fresh external tokens
1 parent de48411 commit d54ba93

File tree

3 files changed

+196
-11
lines changed

3 files changed

+196
-11
lines changed

src/databricks/sql/auth/token_federation.py

Lines changed: 34 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -283,33 +283,56 @@ def _is_same_host(self, url1: str, url2: str) -> bool:
283283

284284
def _refresh_token(self, access_token: str, token_type: str) -> Dict[str, str]:
285285
"""
286-
Attempt to refresh an expired token.
286+
Attempt to refresh an expired token by first getting a fresh external token
287+
and then exchanging it for a new Databricks token.
287288
288-
For most OAuth implementations, refreshing involves a new token exchange
289-
with the latest external token.
289+
This implementation follows the JDBC driver approach by first requesting
290+
a fresh token from the underlying credentials provider before performing
291+
the token exchange.
290292
291293
Args:
292-
access_token: The original external access token
294+
access_token: The original external access token (will be replaced)
293295
token_type: The token type (Bearer, etc.)
294296
295297
Returns:
296298
The headers with the fresh token
297299
"""
298300
try:
299-
logger.info("Refreshing expired token via new token exchange")
300-
# For most federation implementations, refresh is just a new token exchange
301-
token_claims = self._parse_jwt_claims(access_token)
301+
logger.info("Refreshing expired token by getting a new external token")
302+
303+
# ENHANCEMENT: Get a fresh token from the underlying credentials provider
304+
# instead of reusing the same access_token
305+
fresh_headers = self.credentials_provider()()
306+
307+
# Extract the fresh token from the headers
308+
auth_header = fresh_headers.get("Authorization", "")
309+
if not auth_header:
310+
logger.error("No Authorization header in fresh headers")
311+
return self.external_provider_headers
312+
313+
parts = auth_header.split(" ", 1)
314+
if len(parts) != 2:
315+
logger.error(f"Invalid Authorization header format: {auth_header}")
316+
return self.external_provider_headers
317+
318+
fresh_token_type = parts[0]
319+
fresh_access_token = parts[1]
320+
321+
logger.debug("Got fresh external token")
322+
323+
# Now process the fresh token
324+
token_claims = self._parse_jwt_claims(fresh_access_token)
302325
idp_type = self._detect_idp_from_claims(token_claims)
303326

304-
# Perform a new token exchange
305-
refreshed_token = self._exchange_token(access_token, idp_type)
327+
# Perform a new token exchange with the fresh token
328+
refreshed_token = self._exchange_token(fresh_access_token, idp_type)
306329

307330
# Update the stored token
308331
self.last_exchanged_token = refreshed_token
309-
self.last_external_token = access_token
332+
self.last_external_token = fresh_access_token
310333

311334
# Create new headers with the refreshed token
312-
headers = dict(self.external_provider_headers)
335+
headers = dict(fresh_headers) # Use the fresh headers as base
313336
headers[
314337
"Authorization"
315338
] = f"{refreshed_token.token_type} {refreshed_token.access_token}"

tests/unit/test_token_federation.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,63 @@ def test_init_oidc_discovery(self, mock_get_endpoints, mock_requests_get):
114114
federation_provider.token_endpoint = None
115115
federation_provider._init_oidc_discovery()
116116
self.assertEqual(federation_provider.token_endpoint, "https://example.com/oidc/v1/token")
117+
118+
@patch('databricks.sql.auth.token_federation.DatabricksTokenFederationProvider._parse_jwt_claims')
119+
@patch('databricks.sql.auth.token_federation.DatabricksTokenFederationProvider._exchange_token')
120+
@patch('databricks.sql.auth.token_federation.DatabricksTokenFederationProvider._is_same_host')
121+
def test_token_refresh(self, mock_is_same_host, mock_exchange_token, mock_parse_jwt):
122+
"""Test token refresh functionality for approaching expiry."""
123+
# Set up mocks
124+
mock_parse_jwt.return_value = {"iss": "https://login.microsoftonline.com/tenant"}
125+
mock_is_same_host.return_value = False
126+
127+
# Create a simple credentials provider that returns a fixed token
128+
external_token = "test_token"
129+
creds_provider = SimpleCredentialsProvider(external_token)
130+
131+
# Set up the token federation provider
132+
federation_provider = DatabricksTokenFederationProvider(
133+
creds_provider, "example.com", "client_id"
134+
)
135+
136+
# Mock the token exchange to return a known token
137+
future_time = datetime.now(tz=timezone.utc) + timedelta(hours=1)
138+
mock_exchange_token.return_value = Token(
139+
"exchanged_token_1", "Bearer", expiry=future_time
140+
)
141+
142+
# First call to get initial headers and token - this should trigger an exchange
143+
headers_factory = federation_provider()
144+
headers = headers_factory()
145+
146+
# Verify the exchange happened
147+
mock_exchange_token.assert_called_with(external_token, "azure")
148+
self.assertEqual(headers["Authorization"], "Bearer exchanged_token_1")
149+
150+
# Reset the mocks to track the next call
151+
mock_exchange_token.reset_mock()
152+
153+
# Now simulate an approaching expiry
154+
near_expiry = datetime.now(tz=timezone.utc) + timedelta(minutes=4)
155+
federation_provider.last_exchanged_token = Token(
156+
"exchanged_token_1", "Bearer", expiry=near_expiry
157+
)
158+
federation_provider.last_external_token = external_token
159+
160+
# Set up the mock to return a different token for the refresh
161+
mock_exchange_token.return_value = Token(
162+
"exchanged_token_2", "Bearer", expiry=future_time
163+
)
164+
165+
# Make a second call which should trigger refresh
166+
headers = headers_factory()
167+
168+
# Verify the token was exchanged with the SAME external token (current implementation)
169+
# This is different from the JDBC driver approach which gets a fresh token
170+
mock_exchange_token.assert_called_once_with(external_token, "azure")
171+
172+
# Verify the headers contain the new token
173+
self.assertEqual(headers["Authorization"], "Bearer exchanged_token_2")
117174

118175

119176
class TestTokenFederationFactory(unittest.TestCase):
Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
#!/usr/bin/env python3
2+
3+
"""
4+
Unit tests for the JDBC-style token refresh in Databricks SQL connector.
5+
6+
This test verifies that the token federation implementation follows the JDBC driver's approach
7+
of getting a fresh external token before exchanging it for a Databricks token during refresh.
8+
"""
9+
10+
import unittest
11+
from unittest.mock import patch, MagicMock
12+
from datetime import datetime, timezone, timedelta
13+
14+
from databricks.sql.auth.token_federation import (
15+
DatabricksTokenFederationProvider,
16+
Token
17+
)
18+
19+
20+
class RefreshingCredentialsProvider:
21+
"""
22+
A credentials provider that returns different tokens on each call.
23+
This simulates providers like Azure AD that can refresh their tokens.
24+
"""
25+
26+
def __init__(self):
27+
self.call_count = 0
28+
29+
def auth_type(self):
30+
return "bearer"
31+
32+
def __call__(self, *args, **kwargs):
33+
def get_headers():
34+
self.call_count += 1
35+
# Return a different token each time to simulate fresh tokens
36+
return {"Authorization": f"Bearer fresh_token_{self.call_count}"}
37+
return get_headers
38+
39+
40+
class TestJdbcStyleTokenRefresh(unittest.TestCase):
41+
"""Tests for the JDBC-style token refresh implementation."""
42+
43+
@patch('databricks.sql.auth.token_federation.DatabricksTokenFederationProvider._parse_jwt_claims')
44+
@patch('databricks.sql.auth.token_federation.DatabricksTokenFederationProvider._exchange_token')
45+
@patch('databricks.sql.auth.token_federation.DatabricksTokenFederationProvider._is_same_host')
46+
def test_refresh_gets_fresh_token(self, mock_is_same_host, mock_exchange_token, mock_parse_jwt):
47+
"""Test that token refresh first gets a fresh external token."""
48+
# Set up mocks
49+
mock_parse_jwt.return_value = {"iss": "https://login.microsoftonline.com/tenant"}
50+
mock_is_same_host.return_value = False
51+
52+
# Create a credentials provider that returns different tokens on each call
53+
refreshing_provider = RefreshingCredentialsProvider()
54+
55+
# Set up the token federation provider
56+
federation_provider = DatabricksTokenFederationProvider(
57+
refreshing_provider, "example.com", "client_id"
58+
)
59+
60+
# Set up mock for token exchange
61+
future_time = datetime.now(tz=timezone.utc) + timedelta(hours=1)
62+
mock_exchange_token.return_value = Token(
63+
"exchanged_token_1", "Bearer", expiry=future_time
64+
)
65+
66+
# First call to get initial headers and token
67+
headers_factory = federation_provider()
68+
headers = headers_factory()
69+
70+
# Verify the first exchange happened
71+
mock_exchange_token.assert_called_with("fresh_token_1", "azure")
72+
self.assertEqual(headers["Authorization"], "Bearer exchanged_token_1")
73+
self.assertEqual(refreshing_provider.call_count, 1)
74+
75+
# Reset the mock to track the next call
76+
mock_exchange_token.reset_mock()
77+
78+
# Now simulate an approaching expiry
79+
near_expiry = datetime.now(tz=timezone.utc) + timedelta(minutes=4)
80+
federation_provider.last_exchanged_token = Token(
81+
"exchanged_token_1", "Bearer", expiry=near_expiry
82+
)
83+
federation_provider.last_external_token = "fresh_token_1"
84+
85+
# Set up the mock to return a different token for the refresh
86+
mock_exchange_token.return_value = Token(
87+
"exchanged_token_2", "Bearer", expiry=future_time
88+
)
89+
90+
# Make a second call which should trigger refresh
91+
headers = headers_factory()
92+
93+
# With JDBC-style implementation:
94+
# 1. Should call credentials provider again to get fresh token
95+
self.assertEqual(refreshing_provider.call_count, 2)
96+
97+
# 2. Should exchange the FRESH token (fresh_token_2), not the stored one (fresh_token_1)
98+
mock_exchange_token.assert_called_once_with("fresh_token_2", "azure")
99+
100+
# 3. Should return headers with the new Databricks token
101+
self.assertEqual(headers["Authorization"], "Bearer exchanged_token_2")
102+
103+
104+
if __name__ == "__main__":
105+
unittest.main()

0 commit comments

Comments
 (0)