@@ -56,6 +56,20 @@ def __init__(
5656
5757 def add_headers (self , request_headers : Dict [str , str ]):
5858 """Add authentication headers to the request."""
59+
60+ if self ._cached_token and not self ._cached_token .is_expired ():
61+ request_headers ["Authorization" ] = f"{ self ._cached_token .token_type } { self ._cached_token .access_token } "
62+ return
63+
64+ # Get the external headers first to check if we need token federation
65+ self ._external_headers = {}
66+ self .external_provider .add_headers (self ._external_headers )
67+
68+ # If no Authorization header from external provider, pass through all headers
69+ if "Authorization" not in self ._external_headers :
70+ request_headers .update (self ._external_headers )
71+ return
72+
5973 token = self ._get_token ()
6074 request_headers ["Authorization" ] = f"{ token .token_type } { token .access_token } "
6175
@@ -65,11 +79,7 @@ def _get_token(self) -> Token:
6579 if self ._cached_token and not self ._cached_token .is_expired ():
6680 return self ._cached_token
6781
68- # Get the external token
69- self ._external_headers = {}
70- self .external_provider .add_headers (self ._external_headers )
71-
72- # Extract token from Authorization header
82+ # Extract token from already-fetched headers
7383 auth_header = self ._external_headers .get ("Authorization" , "" )
7484 token_type , access_token = self ._extract_token_from_header (auth_header )
7585
0 commit comments