1+ import logging
2+ import json
3+ from datetime import datetime , timedelta
4+ from typing import Optional , Dict , Tuple
5+ from urllib .parse import urlparse , urlencode
6+ import jwt
7+ import requests
8+
9+ from databricks .sql .auth .authenticators import AuthProvider
10+ from databricks .sql .auth .common import AuthType
11+ from databricks .sql .common .http import HttpMethod
12+
13+ logger = logging .getLogger (__name__ )
14+
15+
16+ class TokenFederationProvider (AuthProvider ):
17+ """
18+ Implementation of Token Federation for Databricks SQL Python driver.
19+
20+ This provider exchanges third-party access tokens for Databricks in-house tokens
21+ when the token issuer is different from the Databricks host.
22+ """
23+
24+ TOKEN_EXCHANGE_ENDPOINT = "/oidc/v1/token"
25+ TOKEN_EXCHANGE_GRANT_TYPE = "urn:ietf:params:oauth:grant-type:token-exchange"
26+ TOKEN_EXCHANGE_SUBJECT_TYPE = "urn:ietf:params:oauth:token-type:jwt"
27+
28+ def __init__ (
29+ self ,
30+ hostname : str ,
31+ external_provider : AuthProvider ,
32+ http_client = None ,
33+ identity_federation_client_id : Optional [str ] = None ,
34+ ):
35+ """
36+ Initialize the Token Federation Provider.
37+
38+ Args:
39+ hostname: The Databricks workspace hostname
40+ external_provider: The external authentication provider
41+ http_client: HTTP client for making requests
42+ identity_federation_client_id: Optional client ID for token federation
43+ """
44+ self .hostname = self ._normalize_hostname (hostname )
45+ self .external_provider = external_provider
46+ self .http_client = http_client or requests .Session ()
47+ self .identity_federation_client_id = identity_federation_client_id
48+
49+ self ._cached_token = None
50+ self ._cached_token_expiry = None
51+ self ._external_headers = {}
52+
53+ def add_headers (self , request_headers : Dict [str , str ]):
54+ """Add authentication headers to the request."""
55+ token_info = self ._get_token ()
56+ request_headers ["Authorization" ] = f"{ token_info ['token_type' ]} { token_info ['access_token' ]} "
57+
58+ def _get_token (self ) -> Dict [str , str ]:
59+ """Get or refresh the authentication token."""
60+ # Check if cached token is still valid
61+ if self ._is_token_valid ():
62+ return self ._cached_token
63+
64+ # Get the external token
65+ self ._external_headers = {}
66+ self .external_provider .add_headers (self ._external_headers )
67+
68+ # Extract token from Authorization header
69+ auth_header = self ._external_headers .get ("Authorization" , "" )
70+ token_type , access_token = self ._extract_token_from_header (auth_header )
71+
72+ # Check if token exchange is needed
73+ if self ._should_exchange_token (access_token ):
74+ try :
75+ exchanged_token = self ._exchange_token (access_token )
76+ self ._cache_token (exchanged_token )
77+ return exchanged_token
78+ except Exception as e :
79+ logger .warning (f"Token exchange failed, using external token: { e } " )
80+ # Fall back to using the external token
81+
82+ # Use external token directly
83+ token_info = {
84+ "access_token" : access_token ,
85+ "token_type" : token_type ,
86+ }
87+ self ._cache_token (token_info )
88+ return token_info
89+
90+ def _should_exchange_token (self , access_token : str ) -> bool :
91+ """Check if the token should be exchanged based on issuer."""
92+ try :
93+ # Decode JWT without verification to check issuer
94+ decoded = jwt .decode (access_token , options = {"verify_signature" : False })
95+ issuer = decoded .get ("iss" , "" )
96+
97+ # Check if issuer host is different from Databricks host
98+ return not self ._is_same_host (issuer , self .hostname )
99+ except Exception as e :
100+ logger .debug (f"Failed to decode JWT token: { e } " )
101+ return False
102+
103+ def _exchange_token (self , access_token : str ) -> Dict [str , str ]:
104+ """Exchange the external token for a Databricks token."""
105+ token_url = f"{ self .hostname .rstrip ('/' )} { self .TOKEN_EXCHANGE_ENDPOINT } "
106+
107+ # Prepare the token exchange request
108+ data = {
109+ "grant_type" : self .TOKEN_EXCHANGE_GRANT_TYPE ,
110+ "subject_token" : access_token ,
111+ "subject_token_type" : self .TOKEN_EXCHANGE_SUBJECT_TYPE ,
112+ "scope" : "sql" ,
113+ "return_original_token_if_authenticated" : "true" ,
114+ }
115+
116+ # Add client_id if provided
117+ if self .identity_federation_client_id :
118+ data ["client_id" ] = self .identity_federation_client_id
119+
120+ headers = {
121+ "Content-Type" : "application/x-www-form-urlencoded" ,
122+ "Accept" : "*/*" ,
123+ }
124+
125+ # Encode data as URL-encoded form
126+ body = urlencode (data )
127+
128+ # Make the token exchange request using UnifiedHttpClient API
129+ response = self .http_client .request (
130+ HttpMethod .POST , url = token_url , body = body , headers = headers
131+ )
132+
133+ # Parse the response
134+ token_response = json .loads (response .data .decode ())
135+
136+ return {
137+ "access_token" : token_response ["access_token" ],
138+ "token_type" : token_response .get ("token_type" , "Bearer" ),
139+ "expires_in" : token_response .get ("expires_in" ),
140+ }
141+
142+ def _extract_token_from_header (self , auth_header : str ) -> Tuple [str , str ]:
143+ """Extract token type and access token from Authorization header."""
144+ if not auth_header :
145+ raise ValueError ("Authorization header is missing" )
146+
147+ parts = auth_header .split (" " , 1 )
148+ if len (parts ) != 2 :
149+ raise ValueError ("Invalid Authorization header format" )
150+
151+ return parts [0 ], parts [1 ]
152+
153+ def _is_same_host (self , url1 : str , url2 : str ) -> bool :
154+ """Check if two URLs have the same host."""
155+ try :
156+ host1 = urlparse (url1 ).netloc
157+ host2 = urlparse (url2 ).netloc
158+ return host1 == host2
159+ except Exception as e :
160+ logger .debug (f"Failed to parse URLs: { e } " )
161+ return False
162+
163+ def _normalize_hostname (self , hostname : str ) -> str :
164+ """Normalize the hostname to include scheme and trailing slash."""
165+ if not hostname .startswith ("http://" ) and not hostname .startswith ("https://" ):
166+ hostname = f"https://{ hostname } "
167+ if not hostname .endswith ("/" ):
168+ hostname = f"{ hostname } /"
169+ return hostname
170+
171+ def _cache_token (self , token_info : Dict [str , str ]):
172+ """Cache the token with its expiry time."""
173+ self ._cached_token = token_info
174+
175+ # Calculate expiry time
176+ if "expires_in" in token_info :
177+ expires_in = int (token_info ["expires_in" ])
178+ # Set expiry with a 1-minute buffer
179+ self ._cached_token_expiry = datetime .now () + timedelta (seconds = expires_in - 60 )
180+ else :
181+ # Try to get expiry from JWT
182+ try :
183+ decoded = jwt .decode (
184+ token_info ["access_token" ],
185+ options = {"verify_signature" : False }
186+ )
187+ exp = decoded .get ("exp" )
188+ if exp :
189+ self ._cached_token_expiry = datetime .fromtimestamp (exp ) - timedelta (minutes = 1 )
190+ else :
191+ # Default to 1 hour if no expiry info
192+ self ._cached_token_expiry = datetime .now () + timedelta (hours = 1 )
193+ except :
194+ # Default to 1 hour if we can't decode
195+ self ._cached_token_expiry = datetime .now () + timedelta (hours = 1 )
196+
197+ def _is_token_valid (self ) -> bool :
198+ """Check if the cached token is still valid."""
199+ if not self ._cached_token or not self ._cached_token_expiry :
200+ return False
201+ return datetime .now () < self ._cached_token_expiry
202+
203+
204+ class ExternalTokenProvider (AuthProvider ):
205+ """
206+ A simple provider that wraps an external credentials provider for token federation.
207+ """
208+
209+ def __init__ (self , credentials_provider ):
210+ """
211+ Initialize with an external credentials provider.
212+
213+ Args:
214+ credentials_provider: A callable that returns authentication headers
215+ """
216+ self .credentials_provider = credentials_provider
217+ self ._header_factory = None
218+
219+ def add_headers (self , request_headers : Dict [str , str ]):
220+ """Add headers from the external provider."""
221+ if self ._header_factory is None :
222+ self ._header_factory = self .credentials_provider ()
223+
224+ headers = self ._header_factory ()
225+ for key , value in headers .items ():
226+ request_headers [key ] = value
0 commit comments