1010from requests .exceptions import RequestException
1111
1212from databricks .sql .auth .authenticators import CredentialsProvider , HeaderFactory
13- from databricks .sql .auth .oidc_utils import OIDCDiscoveryUtil
13+ from databricks .sql .auth .oidc_utils import OIDCDiscoveryUtil , is_same_host
1414from databricks .sql .auth .token import Token
1515
1616logger = logging .getLogger (__name__ )
@@ -79,15 +79,6 @@ def __call__(self, *args, **kwargs) -> HeaderFactory:
7979 Configure and return a HeaderFactory that provides authentication headers.
8080 This is called by the ExternalAuthProvider to get headers for authentication.
8181 """
82- # First call the underlying credentials provider to get its headers
83- header_factory = self .credentials_provider (* args , ** kwargs )
84-
85- # Get the standard token endpoint if not already set
86- if self .token_endpoint is None :
87- self .token_endpoint = OIDCDiscoveryUtil .discover_token_endpoint (
88- self .hostname
89- )
90-
9182 # Return a function that will get authentication headers
9283 return self .get_auth_headers
9384
@@ -156,34 +147,6 @@ def _get_expiry_from_jwt(self, token: str) -> Optional[datetime]:
156147
157148 return None
158149
159- def _is_same_host (self , url1 : str , url2 : str ) -> bool :
160- """
161- Check if two URLs have the same host.
162-
163- Args:
164- url1: First URL
165- url2: Second URL
166-
167- Returns:
168- bool: True if hosts are the same, False otherwise
169- """
170- try :
171- # Add protocol if missing to ensure proper parsing
172- if not url1 .startswith (("http://" , "https://" )):
173- url1 = f"https://{ url1 } "
174- if not url2 .startswith (("http://" , "https://" )):
175- url2 = f"https://{ url2 } "
176-
177- # Parse the URLs
178- parsed1 = urlparse (url1 )
179- parsed2 = urlparse (url2 )
180-
181- # Compare the hostnames
182- return parsed1 .netloc .lower () == parsed2 .netloc .lower ()
183- except Exception as e :
184- logger .warning (f"Error comparing hosts: { str (e )} " )
185- return False
186-
187150 def refresh_token (self ) -> Token :
188151 """
189152 Refresh the token and return the new Token object.
@@ -210,24 +173,34 @@ def refresh_token(self) -> Token:
210173 token_claims = self ._parse_jwt_claims (access_token )
211174
212175 # Create new token based on whether it's from the same host or not
213- if self . _is_same_host (token_claims .get ("iss" , "" ), self .hostname ):
176+ if is_same_host (token_claims .get ("iss" , "" ), self .hostname ):
214177 # Token is from the same host, no need to exchange
215178 logger .debug ("Token from same host, creating token without exchange" )
216-
217179 expiry = self ._get_expiry_from_jwt (access_token )
218180 if expiry is None :
219181 raise ValueError ("Could not determine token expiry from JWT" )
220-
221182 new_token = Token (access_token , token_type , "" , expiry )
183+ self .current_token = new_token
184+ return new_token
222185 else :
223186 # Token is from a different host, need to exchange
224187 logger .debug ("Token from different host, exchanging token" )
225- new_token = self ._exchange_token (access_token )
226-
227- # Store the token
228- self .current_token = new_token
229-
230- return new_token
188+ try :
189+ new_token = self ._exchange_token (access_token )
190+ self .current_token = new_token
191+ return new_token
192+ except Exception as e :
193+ logger .error (
194+ f"Token exchange failed: { e } . Using external token as fallback."
195+ )
196+ expiry = self ._get_expiry_from_jwt (access_token )
197+ if expiry is None :
198+ raise ValueError (
199+ "Could not determine token expiry from JWT (after exchange failure)"
200+ )
201+ fallback_token = Token (access_token , token_type , "" , expiry )
202+ self .current_token = fallback_token
203+ return fallback_token
231204
232205 def get_current_token (self ) -> Token :
233206 """
@@ -254,24 +227,19 @@ def get_auth_headers(self) -> Dict[str, str]:
254227 """
255228 Get authorization headers using the current token.
256229
257- This method gets the current token and returns it formatted
258- as authorization headers.
259-
260230 Returns:
261- Dict[str, str]: Authorization headers
231+ Dict[str, str]: Authorization headers (may include extra headers from provider)
262232 """
263233 try :
264234 token = self .get_current_token ()
265- return {"Authorization" : f"{ token .token_type } { token .access_token } " }
235+ # Always get the latest headers from the credentials provider
236+ header_factory = self .credentials_provider ()
237+ headers = dict (header_factory ()) if header_factory else {}
238+ headers ["Authorization" ] = f"{ token .token_type } { token .access_token } "
239+ return headers
266240 except Exception as e :
267241 logger .error (f"Error getting auth headers: { str (e )} " )
268-
269- # Fall back to external headers if available
270- if self .external_headers :
271- return self .external_headers
272-
273- # Return empty dict as a last resort
274- return {}
242+ return dict (self .external_headers ) if self .external_headers else {}
275243
276244 def _send_token_exchange_request (
277245 self , token_exchange_data : Dict [str , str ]
@@ -286,7 +254,7 @@ def _send_token_exchange_request(
286254 Dict[str, Any]: Token exchange response
287255
288256 Raises:
289- ValueError : If token exchange fails
257+ requests.HTTPError : If token exchange fails
290258 """
291259 if not self .token_endpoint :
292260 raise ValueError ("Token endpoint not initialized" )
@@ -296,9 +264,9 @@ def _send_token_exchange_request(
296264 )
297265
298266 if response .status_code != 200 :
299- raise ValueError (
300- f"Token exchange failed with status code { response .status_code } : "
301- f" { response . text } "
267+ raise requests . HTTPError (
268+ f"Token exchange failed with status code { response .status_code } : { response . text } " ,
269+ response = response ,
302270 )
303271
304272 return response .json ()
@@ -316,6 +284,10 @@ def _exchange_token(self, access_token: str) -> Token:
316284 Raises:
317285 ValueError: If token exchange fails
318286 """
287+ if self .token_endpoint is None :
288+ self .token_endpoint = OIDCDiscoveryUtil .discover_token_endpoint (
289+ self .hostname
290+ )
319291 # Prepare the request data
320292 token_exchange_data = dict (self .TOKEN_EXCHANGE_PARAMS )
321293 token_exchange_data ["subject_token" ] = access_token
0 commit comments