Skip to content

Commit 02d16d2

Browse files
committed
token federation for python driver
1 parent bcab1df commit 02d16d2

File tree

4 files changed

+581
-6
lines changed

4 files changed

+581
-6
lines changed

src/databricks/sql/auth/auth.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,17 @@
88
AzureServicePrincipalCredentialProvider,
99
)
1010
from databricks.sql.auth.common import AuthType, ClientContext
11+
from databricks.sql.auth.token_federation import TokenFederationProvider, ExternalTokenProvider
1112

1213

1314
def get_auth_provider(cfg: ClientContext, http_client):
15+
# Determine the base auth provider
16+
base_provider = None
17+
1418
if cfg.credentials_provider:
15-
return ExternalAuthProvider(cfg.credentials_provider)
19+
base_provider = ExternalAuthProvider(cfg.credentials_provider)
1620
elif cfg.auth_type == AuthType.AZURE_SP_M2M.value:
17-
return ExternalAuthProvider(
21+
base_provider = ExternalAuthProvider(
1822
AzureServicePrincipalCredentialProvider(
1923
cfg.hostname,
2024
cfg.azure_client_id,
@@ -29,7 +33,7 @@ def get_auth_provider(cfg: ClientContext, http_client):
2933
assert cfg.oauth_client_id is not None
3034
assert cfg.oauth_scopes is not None
3135

32-
return DatabricksOAuthProvider(
36+
base_provider = DatabricksOAuthProvider(
3337
cfg.hostname,
3438
cfg.oauth_persistence,
3539
cfg.oauth_redirect_port_range,
@@ -39,17 +43,17 @@ def get_auth_provider(cfg: ClientContext, http_client):
3943
cfg.auth_type,
4044
)
4145
elif cfg.access_token is not None:
42-
return AccessTokenAuthProvider(cfg.access_token)
46+
base_provider = AccessTokenAuthProvider(cfg.access_token)
4347
elif cfg.use_cert_as_auth and cfg.tls_client_cert_file:
4448
# no op authenticator. authentication is performed using ssl certificate outside of headers
45-
return AuthProvider()
49+
base_provider = AuthProvider()
4650
else:
4751
if (
4852
cfg.oauth_redirect_port_range is not None
4953
and cfg.oauth_client_id is not None
5054
and cfg.oauth_scopes is not None
5155
):
52-
return DatabricksOAuthProvider(
56+
base_provider = DatabricksOAuthProvider(
5357
cfg.hostname,
5458
cfg.oauth_persistence,
5559
cfg.oauth_redirect_port_range,
@@ -60,6 +64,17 @@ def get_auth_provider(cfg: ClientContext, http_client):
6064
)
6165
else:
6266
raise RuntimeError("No valid authentication settings!")
67+
68+
# Wrap with token federation if enabled
69+
if cfg.enable_token_federation and base_provider:
70+
return TokenFederationProvider(
71+
hostname=cfg.hostname,
72+
external_provider=base_provider,
73+
http_client=http_client,
74+
identity_federation_client_id=cfg.identity_federation_client_id,
75+
)
76+
77+
return base_provider
6378

6479

6580
PYSQL_OAUTH_SCOPES = ["sql", "offline_access"]
@@ -114,5 +129,8 @@ def get_python_sql_connector_auth_provider(hostname: str, http_client, **kwargs)
114129
else redirect_port_range,
115130
oauth_persistence=kwargs.get("experimental_oauth_persistence"),
116131
credentials_provider=kwargs.get("credentials_provider"),
132+
# Token federation parameters
133+
enable_token_federation=kwargs.get("enable_token_federation", False),
134+
identity_federation_client_id=kwargs.get("identity_federation_client_id"),
117135
)
118136
return get_auth_provider(cfg, http_client)

src/databricks/sql/auth/common.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,9 @@ def __init__(
3737
tls_client_cert_file: Optional[str] = None,
3838
oauth_persistence=None,
3939
credentials_provider=None,
40+
# Token federation parameters
41+
enable_token_federation: bool = False,
42+
identity_federation_client_id: Optional[str] = None,
4043
# HTTP client configuration parameters
4144
ssl_options=None, # SSLOptions type
4245
socket_timeout: Optional[float] = None,
@@ -65,6 +68,9 @@ def __init__(
6568
self.tls_client_cert_file = tls_client_cert_file
6669
self.oauth_persistence = oauth_persistence
6770
self.credentials_provider = credentials_provider
71+
# Token federation
72+
self.enable_token_federation = enable_token_federation
73+
self.identity_federation_client_id = identity_federation_client_id
6874

6975
# HTTP client configuration
7076
self.ssl_options = ssl_options
src/databricks/sql/auth/token_federation.py

Lines changed: 226 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,226 @@
1+
import logging
2+
import json
3+
from datetime import datetime, timedelta
4+
from typing import Optional, Dict, Tuple
5+
from urllib.parse import urlparse, urlencode
6+
import jwt
7+
import requests
8+
9+
from databricks.sql.auth.authenticators import AuthProvider
10+
from databricks.sql.auth.common import AuthType
11+
from databricks.sql.common.http import HttpMethod
12+
13+
logger = logging.getLogger(__name__)
14+
15+
16+
class TokenFederationProvider(AuthProvider):
17+
"""
18+
Implementation of Token Federation for Databricks SQL Python driver.
19+
20+
This provider exchanges third-party access tokens for Databricks in-house tokens
21+
when the token issuer is different from the Databricks host.
22+
"""
23+
24+
TOKEN_EXCHANGE_ENDPOINT = "/oidc/v1/token"
25+
TOKEN_EXCHANGE_GRANT_TYPE = "urn:ietf:params:oauth:grant-type:token-exchange"
26+
TOKEN_EXCHANGE_SUBJECT_TYPE = "urn:ietf:params:oauth:token-type:jwt"
27+
28+
def __init__(
    self,
    hostname: str,
    external_provider: AuthProvider,
    http_client=None,
    identity_federation_client_id: Optional[str] = None,
):
    """
    Set up a provider that federates tokens obtained from another provider.

    Args:
        hostname: Databricks workspace hostname; stored in normalized form
            (see ``_normalize_hostname``).
        external_provider: Provider whose ``add_headers`` yields the
            third-party ``Authorization`` header.
        http_client: Client used for the token exchange request. When falsy,
            a ``requests.Session`` is created instead.
        identity_federation_client_id: Optional client ID sent along with
            the token exchange request.
    """
    self.hostname = self._normalize_hostname(hostname)

    # NOTE(review): _exchange_token calls ``request(HttpMethod.POST, url=...,
    # body=...)`` and reads ``response.data``; a plain requests.Session does
    # not appear to support that call shape -- confirm the fallback is usable.
    self.http_client = http_client if http_client else requests.Session()
    self.external_provider = external_provider
    self.identity_federation_client_id = identity_federation_client_id

    # Token cache state; populated by _cache_token.
    self._cached_token = None
    self._cached_token_expiry = None
    # Headers most recently collected from the external provider.
    self._external_headers = {}
52+
53+
def add_headers(self, request_headers: Dict[str, str]):
    """
    Write the ``Authorization`` header for the current token.

    The token comes from ``_get_token``; any existing ``Authorization``
    entry in ``request_headers`` is overwritten.
    """
    current = self._get_token()
    scheme = current["token_type"]
    credential = current["access_token"]
    request_headers["Authorization"] = f"{scheme} {credential}"
57+
58+
def _get_token(self) -> Dict[str, str]:
59+
"""Get or refresh the authentication token."""
60+
# Check if cached token is still valid
61+
if self._is_token_valid():
62+
return self._cached_token
63+
64+
# Get the external token
65+
self._external_headers = {}
66+
self.external_provider.add_headers(self._external_headers)
67+
68+
# Extract token from Authorization header
69+
auth_header = self._external_headers.get("Authorization", "")
70+
token_type, access_token = self._extract_token_from_header(auth_header)
71+
72+
# Check if token exchange is needed
73+
if self._should_exchange_token(access_token):
74+
try:
75+
exchanged_token = self._exchange_token(access_token)
76+
self._cache_token(exchanged_token)
77+
return exchanged_token
78+
except Exception as e:
79+
logger.warning(f"Token exchange failed, using external token: {e}")
80+
# Fall back to using the external token
81+
82+
# Use external token directly
83+
token_info = {
84+
"access_token": access_token,
85+
"token_type": token_type,
86+
}
87+
self._cache_token(token_info)
88+
return token_info
89+
90+
def _should_exchange_token(self, access_token: str) -> bool:
    """
    Decide whether ``access_token`` must be exchanged.

    The token is decoded as a JWT without signature verification, purely to
    read its ``iss`` claim. An exchange is requested when the issuer's host
    differs from the configured Databricks host.

    Returns:
        True when the issuer host differs from ``self.hostname``; False when
        it matches or when the token cannot be decoded.
    """
    try:
        claims = jwt.decode(access_token, options={"verify_signature": False})
        token_issuer = claims.get("iss", "")
        issued_by_workspace = self._is_same_host(token_issuer, self.hostname)
        return not issued_by_workspace
    except Exception as e:
        # Tokens that cannot be decoded are passed through unexchanged.
        logger.debug(f"Failed to decode JWT token: {e}")
        return False
102+
103+
def _exchange_token(self, access_token: str) -> Dict[str, str]:
    """
    Exchange an external token for a Databricks token.

    Sends a form-encoded POST to ``<hostname>/oidc/v1/token`` using the
    token-exchange grant type, with the external token as the subject token.

    Args:
        access_token: The external token to exchange.

    Returns:
        A dict with ``access_token``, ``token_type`` (``"Bearer"`` when the
        response omits it) and ``expires_in`` (``None`` when omitted).

    Raises:
        KeyError: If the response JSON has no ``access_token``.
    """
    endpoint = self.hostname.rstrip("/") + self.TOKEN_EXCHANGE_ENDPOINT

    form = {
        "grant_type": self.TOKEN_EXCHANGE_GRANT_TYPE,
        "subject_token": access_token,
        "subject_token_type": self.TOKEN_EXCHANGE_SUBJECT_TYPE,
        "scope": "sql",
        "return_original_token_if_authenticated": "true",
    }
    client_id = self.identity_federation_client_id
    if client_id:
        form["client_id"] = client_id

    # NOTE(review): the HTTP status is not inspected here; a failed exchange
    # only surfaces if the body is not JSON or lacks "access_token".
    response = self.http_client.request(
        HttpMethod.POST,
        url=endpoint,
        body=urlencode(form),
        headers={
            "Content-Type": "application/x-www-form-urlencoded",
            "Accept": "*/*",
        },
    )

    payload = json.loads(response.data.decode())
    return {
        "access_token": payload["access_token"],
        "token_type": payload.get("token_type", "Bearer"),
        "expires_in": payload.get("expires_in"),
    }
141+
142+
def _extract_token_from_header(self, auth_header: str) -> Tuple[str, str]:
143+
"""Extract token type and access token from Authorization header."""
144+
if not auth_header:
145+
raise ValueError("Authorization header is missing")
146+
147+
parts = auth_header.split(" ", 1)
148+
if len(parts) != 2:
149+
raise ValueError("Invalid Authorization header format")
150+
151+
return parts[0], parts[1]
152+
153+
def _is_same_host(self, url1: str, url2: str) -> bool:
154+
"""Check if two URLs have the same host."""
155+
try:
156+
host1 = urlparse(url1).netloc
157+
host2 = urlparse(url2).netloc
158+
return host1 == host2
159+
except Exception as e:
160+
logger.debug(f"Failed to parse URLs: {e}")
161+
return False
162+
163+
def _normalize_hostname(self, hostname: str) -> str:
164+
"""Normalize the hostname to include scheme and trailing slash."""
165+
if not hostname.startswith("http://") and not hostname.startswith("https://"):
166+
hostname = f"https://{hostname}"
167+
if not hostname.endswith("/"):
168+
hostname = f"{hostname}/"
169+
return hostname
170+
171+
def _cache_token(self, token_info: Dict[str, str]):
172+
"""Cache the token with its expiry time."""
173+
self._cached_token = token_info
174+
175+
# Calculate expiry time
176+
if "expires_in" in token_info:
177+
expires_in = int(token_info["expires_in"])
178+
# Set expiry with a 1-minute buffer
179+
self._cached_token_expiry = datetime.now() + timedelta(seconds=expires_in - 60)
180+
else:
181+
# Try to get expiry from JWT
182+
try:
183+
decoded = jwt.decode(
184+
token_info["access_token"],
185+
options={"verify_signature": False}
186+
)
187+
exp = decoded.get("exp")
188+
if exp:
189+
self._cached_token_expiry = datetime.fromtimestamp(exp) - timedelta(minutes=1)
190+
else:
191+
# Default to 1 hour if no expiry info
192+
self._cached_token_expiry = datetime.now() + timedelta(hours=1)
193+
except:
194+
# Default to 1 hour if we can't decode
195+
self._cached_token_expiry = datetime.now() + timedelta(hours=1)
196+
197+
def _is_token_valid(self) -> bool:
198+
"""Check if the cached token is still valid."""
199+
if not self._cached_token or not self._cached_token_expiry:
200+
return False
201+
return datetime.now() < self._cached_token_expiry
202+
203+
204+
class ExternalTokenProvider(AuthProvider):
205+
"""
206+
A simple provider that wraps an external credentials provider for token federation.
207+
"""
208+
209+
def __init__(self, credentials_provider):
210+
"""
211+
Initialize with an external credentials provider.
212+
213+
Args:
214+
credentials_provider: A callable that returns authentication headers
215+
"""
216+
self.credentials_provider = credentials_provider
217+
self._header_factory = None
218+
219+
def add_headers(self, request_headers: Dict[str, str]):
    """
    Copy the external provider's headers into ``request_headers``.

    The header factory is obtained from ``credentials_provider`` on first
    use and reused afterwards; it is invoked on every call. Existing keys in
    ``request_headers`` are overwritten by the provider's values.
    """
    factory = self._header_factory
    if factory is None:
        factory = self._header_factory = self.credentials_provider()

    request_headers.update(factory())

0 commit comments

Comments
 (0)