Skip to content

Commit 2e12935

Browse files
committed
general improvements
1 parent 29f95f2 commit 2e12935

File tree

5 files changed

+751
-531
lines changed

5 files changed

+751
-531
lines changed
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
import logging
2+
import requests
3+
from typing import Optional
4+
5+
from databricks.sql.auth.endpoint import (
6+
get_oauth_endpoints,
7+
infer_cloud_from_host,
8+
)
9+
10+
logger = logging.getLogger(__name__)
11+
12+
13+
class OIDCDiscoveryUtil:
14+
"""
15+
Utility class for OIDC discovery operations.
16+
17+
This class handles discovery of OIDC endpoints through standard
18+
discovery mechanisms, with fallback to default endpoints if needed.
19+
"""
20+
21+
# Standard token endpoint path for Databricks workspaces
22+
DEFAULT_TOKEN_PATH = "oidc/v1/token"
23+
24+
@staticmethod
25+
def discover_token_endpoint(hostname: str) -> str:
26+
"""
27+
Get the token endpoint for the given Databricks hostname.
28+
29+
For Databricks workspaces, the token endpoint is always at host/oidc/v1/token.
30+
31+
Args:
32+
hostname: The hostname to get token endpoint for
33+
34+
Returns:
35+
str: The token endpoint URL
36+
"""
37+
# Format the hostname and return the standard endpoint
38+
hostname = OIDCDiscoveryUtil.format_hostname(hostname)
39+
token_endpoint = f"{hostname}{OIDCDiscoveryUtil.DEFAULT_TOKEN_PATH}"
40+
logger.info(f"Using token endpoint: {token_endpoint}")
41+
return token_endpoint
42+
43+
@staticmethod
44+
def format_hostname(hostname: str) -> str:
45+
"""
46+
Format hostname to ensure it has proper https:// prefix and trailing slash.
47+
48+
Args:
49+
hostname: The hostname to format
50+
51+
Returns:
52+
str: The formatted hostname
53+
"""
54+
if not hostname.startswith("https://"):
55+
hostname = f"https://{hostname}"
56+
if not hostname.endswith("/"):
57+
hostname = f"{hostname}/"
58+
return hostname

src/databricks/sql/auth/token.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
"""
2+
Token class for authentication tokens with expiry handling.
3+
"""
4+
5+
from datetime import datetime, timezone, timedelta
6+
from typing import Optional
7+
8+
9+
class Token:
10+
"""
11+
Represents an OAuth token with expiry information.
12+
13+
This class handles token state including expiry calculation.
14+
"""
15+
16+
# Minimum time buffer before expiry to consider a token still valid (in seconds)
17+
MIN_VALIDITY_BUFFER = 10
18+
19+
def __init__(
20+
self,
21+
access_token: str,
22+
token_type: str,
23+
refresh_token: str = "",
24+
expiry: Optional[datetime] = None,
25+
):
26+
"""
27+
Initialize a Token object.
28+
29+
Args:
30+
access_token: The access token string
31+
token_type: The token type (usually "Bearer")
32+
refresh_token: Optional refresh token
33+
expiry: Token expiry datetime, must be provided
34+
35+
Raises:
36+
ValueError: If no expiry is provided
37+
"""
38+
self.access_token = access_token
39+
self.token_type = token_type
40+
self.refresh_token = refresh_token
41+
42+
# Ensure we have an expiry time
43+
if expiry is None:
44+
raise ValueError("Token expiry must be provided")
45+
46+
# Ensure expiry is timezone-aware
47+
if expiry.tzinfo is None:
48+
# Convert naive datetime to aware datetime
49+
self.expiry = expiry.replace(tzinfo=timezone.utc)
50+
else:
51+
self.expiry = expiry
52+
53+
def is_valid(self) -> bool:
54+
"""
55+
Check if the token is valid (has at least MIN_VALIDITY_BUFFER seconds before expiry).
56+
57+
Returns:
58+
bool: True if the token is valid, False otherwise
59+
"""
60+
buffer = timedelta(seconds=self.MIN_VALIDITY_BUFFER)
61+
return datetime.now(tz=timezone.utc) + buffer < self.expiry
62+
63+
def __str__(self) -> str:
64+
"""Return the token as a string in the format used for Authorization headers."""
65+
return f"{self.token_type} {self.access_token}"

0 commit comments

Comments
 (0)