Skip to content

Commit 76df22e

Browse files
committed
update and add todo for future work
1 parent a93dd4b commit 76df22e

File tree

4 files changed

+52
-16
lines changed

4 files changed

+52
-16
lines changed

poetry.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -30,7 +30,7 @@ PyJWT = ">=2.0.0"
3030
[tool.poetry.extras]
3131
pyarrow = ["pyarrow"]
3232

33-
[tool.poetry.dev-dependencies]
33+
[tool.poetry.group.dev.dependencies]
3434
pytest = "^7.1.2"
3535
mypy = "^1.10.1"
3636
pylint = ">=2.12.0"

src/databricks/sql/auth/auth.py

Lines changed: 10 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -13,6 +13,8 @@
1313
class AuthType(Enum):
1414
DATABRICKS_OAUTH = "databricks-oauth"
1515
AZURE_OAUTH = "azure-oauth"
16+
# TODO: Token federation should be a feature that works with different auth types,
17+
# not an auth type itself. This will be refactored in a future release.
1618
TOKEN_FEDERATION = "token-federation"
1719
# other supported types (access_token) can be inferred
1820
# we can add more types as needed later
@@ -47,6 +49,10 @@ def __init__(
4749

4850

4951
def get_auth_provider(cfg: ClientContext):
52+
# TODO: In a future refactoring, token federation should be a feature that wraps
53+
# any auth provider, not a separate auth type. The code below treats it as an auth type
54+
# for backward compatibility, but this approach will be revised.
55+
5056
if cfg.credentials_provider:
5157
# If token federation is enabled and credentials provider is provided,
5258
# wrap the credentials provider with DatabricksTokenFederationProvider
@@ -153,6 +159,10 @@ def get_python_sql_connector_auth_provider(hostname: str, **kwargs):
153159
"Please use OAuth or access token instead."
154160
)
155161

162+
# TODO: Future refactoring needed:
163+
# - Add a use_token_federation flag that can be combined with any auth type
164+
# - Remove TOKEN_FEDERATION as an auth_type and properly handle the underlying auth type
165+
# - Maintain backward compatibility during transition
156166
cfg = ClientContext(
157167
hostname=normalize_host_name(hostname),
158168
auth_type=auth_type,

src/databricks/sql/auth/token_federation.py

Lines changed: 40 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -116,7 +116,9 @@ def get_headers() -> Dict[str, str]:
116116
self.external_provider_headers = header_factory()
117117

118118
# Extract the token from the headers
119-
token_info = self._extract_token_info_from_header(self.external_provider_headers)
119+
token_info = self._extract_token_info_from_header(
120+
self.external_provider_headers
121+
)
120122
token_type, access_token = token_info
121123

122124
try:
@@ -139,7 +141,9 @@ def get_headers() -> Dict[str, str]:
139141
return self.external_provider_headers
140142
else:
141143
# Token is from a different host, need to exchange
142-
return self._try_token_exchange_or_fallback(access_token, token_type)
144+
return self._try_token_exchange_or_fallback(
145+
access_token, token_type
146+
)
143147
except Exception as e:
144148
logger.error(f"Failed to process token: {str(e)}")
145149
# Fall back to original headers in case of error
@@ -159,8 +163,10 @@ def _init_oidc_discovery(self):
159163

160164
if self.idp_endpoints:
161165
# Get the OpenID configuration URL
162-
openid_config_url = self.idp_endpoints.get_openid_config_url(self.hostname)
163-
166+
openid_config_url = self.idp_endpoints.get_openid_config_url(
167+
self.hostname
168+
)
169+
164170
# Fetch the OpenID configuration
165171
response = requests.get(openid_config_url)
166172
if response.status_code == 200:
@@ -184,7 +190,9 @@ def _init_oidc_discovery(self):
184190
)
185191
hostname = self._format_hostname(self.hostname)
186192
self.token_endpoint = f"{hostname}oidc/v1/token"
187-
logger.info(f"Using default token endpoint after error: {self.token_endpoint}")
193+
logger.info(
194+
f"Using default token endpoint after error: {self.token_endpoint}"
195+
)
188196

189197
def _format_hostname(self, hostname: str) -> str:
190198
"""Format hostname to ensure it has proper https:// prefix and trailing slash."""
@@ -194,7 +202,9 @@ def _format_hostname(self, hostname: str) -> str:
194202
hostname = f"{hostname}/"
195203
return hostname
196204

197-
def _extract_token_info_from_header(self, headers: Dict[str, str]) -> Tuple[str, str]:
205+
def _extract_token_info_from_header(
206+
self, headers: Dict[str, str]
207+
) -> Tuple[str, str]:
198208
"""Extract token type and token value from authorization header."""
199209
auth_header = headers.get("Authorization")
200210
if not auth_header:
@@ -308,14 +318,20 @@ def _refresh_token(self, access_token: str, token_type: str) -> Dict[str, str]:
308318

309319
# Create new headers with the refreshed token
310320
headers = dict(fresh_headers) # Use the fresh headers as base
311-
headers["Authorization"] = f"{refreshed_token.token_type} {refreshed_token.access_token}"
321+
headers[
322+
"Authorization"
323+
] = f"{refreshed_token.token_type} {refreshed_token.access_token}"
312324
return headers
313325
except Exception as e:
314-
logger.error(f"Token refresh failed, falling back to original token: {str(e)}")
326+
logger.error(
327+
f"Token refresh failed, falling back to original token: {str(e)}"
328+
)
315329
# If refresh fails, fall back to the original headers
316330
return self.external_provider_headers
317331

318-
def _try_token_exchange_or_fallback(self, access_token: str, token_type: str) -> Dict[str, str]:
332+
def _try_token_exchange_or_fallback(
333+
self, access_token: str, token_type: str
334+
) -> Dict[str, str]:
319335
"""Try to exchange the token or fall back to the original token."""
320336
try:
321337
# Parse the token to get claims for IdP-specific adjustments
@@ -331,10 +347,14 @@ def _try_token_exchange_or_fallback(self, access_token: str, token_type: str) ->
331347

332348
# Create new headers with the exchanged token
333349
headers = dict(self.external_provider_headers)
334-
headers["Authorization"] = f"{exchanged_token.token_type} {exchanged_token.access_token}"
350+
headers[
351+
"Authorization"
352+
] = f"{exchanged_token.token_type} {exchanged_token.access_token}"
335353
return headers
336354
except Exception as e:
337-
logger.error(f"Token exchange failed, falling back to using external token: {str(e)}")
355+
logger.error(
356+
f"Token exchange failed, falling back to using external token: {str(e)}"
357+
)
338358
# Fall back to original headers
339359
return self.external_provider_headers
340360

@@ -396,10 +416,14 @@ def _exchange_token(self, access_token: str, idp_type: str = "unknown") -> Token
396416
try:
397417
# Calculate expiry by adding expires_in seconds to current time
398418
expires_in_seconds = int(resp_data["expires_in"])
399-
token.expiry = datetime.now(tz=timezone.utc) + timedelta(seconds=expires_in_seconds)
419+
token.expiry = datetime.now(tz=timezone.utc) + timedelta(
420+
seconds=expires_in_seconds
421+
)
400422
logger.debug(f"Token expiry set from expires_in: {token.expiry}")
401423
except (ValueError, TypeError) as e:
402-
logger.warning(f"Could not parse expires_in from response: {str(e)}")
424+
logger.warning(
425+
f"Could not parse expires_in from response: {str(e)}"
426+
)
403427

404428
# If expires_in wasn't available, try to parse expiry from the token JWT
405429
if token.expiry == datetime.now(tz=timezone.utc):
@@ -408,7 +432,9 @@ def _exchange_token(self, access_token: str, idp_type: str = "unknown") -> Token
408432
exp_time = token_claims.get("exp")
409433
if exp_time:
410434
token.expiry = datetime.fromtimestamp(exp_time, tz=timezone.utc)
411-
logger.debug(f"Token expiry set from JWT exp claim: {token.expiry}")
435+
logger.debug(
436+
f"Token expiry set from JWT exp claim: {token.expiry}"
437+
)
412438
except Exception as e:
413439
logger.warning(f"Could not parse expiry from token: {str(e)}")
414440

0 commit comments

Comments
 (0)