@@ -39,7 +39,15 @@ def __init__(
3939 self .access_token = access_token
4040 self .token_type = token_type
4141 self .refresh_token = refresh_token
42- self .expiry = expiry or datetime .now (tz = timezone .utc )
42+
43+ # Ensure expiry is timezone-aware
44+ if expiry is None :
45+ self .expiry = datetime .now (tz = timezone .utc )
46+ elif expiry .tzinfo is None :
47+ # Convert naive datetime to aware datetime
48+ self .expiry = expiry .replace (tzinfo = timezone .utc )
49+ else :
50+ self .expiry = expiry
4351
4452 def is_expired (self ) -> bool :
4553 """Check if the token is expired."""
@@ -129,7 +137,9 @@ def get_headers() -> Dict[str, str]:
129137 and self .last_exchanged_token .needs_refresh ()
130138 ):
131139 # The token is approaching expiry, try to refresh
132- logger .debug ("Exchanged token approaching expiry, refreshing..." )
140+ logger .info (
141+ "Exchanged token approaching expiry, refreshing with fresh external token..."
142+ )
133143 return self ._refresh_token (access_token , token_type )
134144
135145 # Parse the JWT to get claims
@@ -138,14 +148,16 @@ def get_headers() -> Dict[str, str]:
138148 # Check if token needs to be exchanged
139149 if self ._is_same_host (token_claims .get ("iss" , "" ), self .hostname ):
140150 # Token is from the same host, no need to exchange
151+ logger .debug ("Token from same host, no exchange needed" )
141152 return self .external_provider_headers
142153 else :
143154 # Token is from a different host, need to exchange
155+ logger .debug ("Token from different host, attempting exchange" )
144156 return self ._try_token_exchange_or_fallback (
145157 access_token , token_type
146158 )
147159 except Exception as e :
148- logger .error (f"Failed to process token: { str (e )} " )
160+ logger .error (f"Error processing token: { str (e )} " )
149161 # Fall back to original headers in case of error
150162 return self .external_provider_headers
151163
@@ -238,25 +250,6 @@ def _parse_jwt_claims(self, token: str) -> Dict[str, Any]:
238250 logger .error (f"Failed to parse JWT: { str (e )} " )
239251 raise
240252
241- def _detect_idp_from_claims (self , token_claims : Dict [str , Any ]) -> str :
242- """
243- Detect the identity provider type from token claims.
244-
245- This can be used to adjust token exchange parameters based on the IdP.
246- """
247- issuer = token_claims .get ("iss" , "" )
248-
249- if "login.microsoftonline.com" in issuer or "sts.windows.net" in issuer :
250- return "azure"
251- elif "token.actions.githubusercontent.com" in issuer :
252- return "github"
253- elif "accounts.google.com" in issuer :
254- return "google"
255- elif "cognito-idp" in issuer and "amazonaws.com" in issuer :
256- return "aws"
257- else :
258- return "unknown"
259-
260253 def _is_same_host (self , url1 : str , url2 : str ) -> bool :
261254 """Check if two URLs have the same host."""
262255 try :
@@ -283,7 +276,9 @@ def _refresh_token(self, access_token: str, token_type: str) -> Dict[str, str]:
283276 The headers with the fresh token
284277 """
285278 try :
286- logger .info ("Refreshing expired token by getting a new external token" )
279+ logger .info (
280+ "Refreshing token using proactive approach (getting fresh external token first)"
281+ )
287282
288283 # Get a fresh token from the underlying credentials provider
289284 # instead of reusing the same access_token
@@ -303,14 +298,14 @@ def _refresh_token(self, access_token: str, token_type: str) -> Dict[str, str]:
303298 fresh_token_type = parts [0 ]
304299 fresh_access_token = parts [1 ]
305300
306- logger . debug ( "Got fresh external token" )
307-
308- # Now process the fresh token
309- token_claims = self . _parse_jwt_claims ( fresh_access_token )
310- idp_type = self . _detect_idp_from_claims ( token_claims )
301+ # Check if we got the same token back
302+ if fresh_access_token == access_token :
303+ logger . warning (
304+ "Credentials provider returned the same token during refresh"
305+ )
311306
312307 # Perform a new token exchange with the fresh token
313- refreshed_token = self ._exchange_token (fresh_access_token , idp_type )
308+ refreshed_token = self ._exchange_token (fresh_access_token )
314309
315310 # Update the stored token
316311 self .last_exchanged_token = refreshed_token
@@ -321,6 +316,10 @@ def _refresh_token(self, access_token: str, token_type: str) -> Dict[str, str]:
321316 headers [
322317 "Authorization"
323318 ] = f"{ refreshed_token .token_type } { refreshed_token .access_token } "
319+
320+ logger .info (
321+ f"Successfully refreshed token, new expiry: { refreshed_token .expiry } "
322+ )
324323 return headers
325324 except Exception as e :
326325 logger .error (
@@ -334,12 +333,8 @@ def _try_token_exchange_or_fallback(
334333 ) -> Dict [str , str ]:
335334 """Try to exchange the token or fall back to the original token."""
336335 try :
337- # Parse the token to get claims for IdP-specific adjustments
338- token_claims = self ._parse_jwt_claims (access_token )
339- idp_type = self ._detect_idp_from_claims (token_claims )
340-
341336 # Exchange the token
342- exchanged_token = self ._exchange_token (access_token , idp_type )
337+ exchanged_token = self ._exchange_token (access_token )
343338
344339 # Store the exchanged token for potential refresh later
345340 self .last_exchanged_token = exchanged_token
@@ -358,13 +353,12 @@ def _try_token_exchange_or_fallback(
358353 # Fall back to original headers
359354 return self .external_provider_headers
360355
361- def _exchange_token (self , access_token : str , idp_type : str = "unknown" ) -> Token :
356+ def _exchange_token (self , access_token : str ) -> Token :
362357 """
363358 Exchange an external token for a Databricks token.
364359
365360 Args:
366361 access_token: The external token to exchange
367- idp_type: The detected identity provider type (azure, github, etc.)
368362
369363 Returns:
370364 A Token object containing the exchanged token
@@ -384,14 +378,6 @@ def _exchange_token(self, access_token: str, idp_type: str = "unknown") -> Token
384378 if self .identity_federation_client_id :
385379 params ["client_id" ] = self .identity_federation_client_id
386380
387- # Make IdP-specific adjustments
388- if idp_type == "azure" :
389- # For Azure AD, add special handling if needed
390- pass
391- elif idp_type == "github" :
392- # For GitHub Actions, add special handling if needed
393- pass
394-
395381 # Set up headers
396382 headers = {"Accept" : "*/*" , "Content-Type" : "application/x-www-form-urlencoded" }
397383
@@ -441,7 +427,7 @@ def _exchange_token(self, access_token: str, idp_type: str = "unknown") -> Token
441427 return token
442428 except RequestException as e :
443429 logger .error (f"Failed to perform token exchange: { str (e )} " )
444- raise
430+ raise ValueError ( f"Request error during token exchange: { str ( e ) } " )
445431
446432
447433class SimpleCredentialsProvider (CredentialsProvider ):
0 commit comments