Skip to content

Commit a93dd4b

Browse files
committed
clean up
1 parent 34413f3 commit a93dd4b

File tree

3 files changed

+143
-66
lines changed

3 files changed

+143
-66
lines changed

.github/workflows/token-federation-test.yml

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
name: Token Federation Test
22

3-
# This workflow tests token federation functionality with GitHub Actions OIDC tokens
4-
# in the databricks-sql-python connector to ensure CI/CD functionality
5-
3+
# Tests token federation functionality with GitHub Actions OIDC tokens
64
on:
75
# Manual trigger with required inputs
86
workflow_dispatch:
@@ -17,31 +15,34 @@ on:
1715
description: 'Identity federation client ID'
1816
required: true
1917

20-
# Automatically run on PR that changes token federation files
18+
# Run on PRs that might affect token federation
2119
pull_request:
22-
branches:
23-
- main
20+
branches: [main]
21+
paths:
22+
- 'src/databricks/sql/auth/**'
23+
- 'examples/token_federation_*.py'
24+
- 'tests/token_federation/**'
25+
- '.github/workflows/token-federation-test.yml'
2426

2527
# Run on push to main that affects token federation
2628
push:
29+
branches: [main]
2730
paths:
28-
- 'src/databricks/sql/auth/token_federation.py'
29-
- 'src/databricks/sql/auth/auth.py'
31+
- 'src/databricks/sql/auth/**'
3032
- 'examples/token_federation_*.py'
31-
- 'tests/token_federation/github_oidc_test.py'
32-
branches:
33-
- main
33+
- 'tests/token_federation/**'
34+
- '.github/workflows/token-federation-test.yml'
3435

3536
permissions:
36-
# Required for GitHub OIDC token
37-
id-token: write
37+
id-token: write # Required for GitHub OIDC token
3838
contents: read
3939

4040
jobs:
4141
test-token-federation:
42+
name: Test Token Federation
4243
runs-on:
43-
group: databricks-protected-runner-group
44-
labels: linux-ubuntu-latest
44+
group: databricks-protected-runner-group
45+
labels: linux-ubuntu-latest
4546

4647
steps:
4748
- name: Checkout code
@@ -51,6 +52,7 @@ jobs:
5152
uses: actions/setup-python@v5
5253
with:
5354
python-version: '3.9'
55+
cache: 'pip'
5456

5557
- name: Install dependencies
5658
run: |
@@ -73,5 +75,4 @@ jobs:
7375
DATABRICKS_HTTP_PATH_FOR_TF: ${{ github.event_name == 'workflow_dispatch' && inputs.databricks_http_path || secrets.DATABRICKS_HTTP_PATH_FOR_TF }}
7476
IDENTITY_FEDERATION_CLIENT_ID: ${{ github.event_name == 'workflow_dispatch' && inputs.identity_federation_client_id || secrets.IDENTITY_FEDERATION_CLIENT_ID }}
7577
OIDC_TOKEN: ${{ steps.get-id-token.outputs.token }}
76-
run: |
77-
python tests/token_federation/github_oidc_test.py
78+
run: python tests/token_federation/github_oidc_test.py

tests/token_federation/github_oidc_test.py

Lines changed: 103 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,27 @@
1212
import sys
1313
import json
1414
import base64
15+
import logging
1516
from databricks import sql
1617

1718

19+
logging.basicConfig(
20+
level=logging.INFO,
21+
format="%(asctime)s - %(levelname)s - %(message)s"
22+
)
23+
logger = logging.getLogger(__name__)
24+
25+
1826
def decode_jwt(token):
19-
"""Decode and return the claims from a JWT token."""
27+
"""
28+
Decode and return the claims from a JWT token.
29+
30+
Args:
31+
token: The JWT token string
32+
33+
Returns:
34+
dict: The decoded token claims or None if decoding fails
35+
"""
2036
try:
2137
parts = token.split(".")
2238
if len(parts) != 3:
@@ -30,72 +46,121 @@ def decode_jwt(token):
3046
decoded = base64.b64decode(payload)
3147
return json.loads(decoded)
3248
except Exception as e:
33-
print(f"Failed to decode token: {str(e)}")
49+
logger.error(f"Failed to decode token: {str(e)}")
3450
return None
3551

3652

37-
def main():
38-
# Get GitHub OIDC token
53+
def get_environment_variables():
54+
"""
55+
Get required environment variables for the test.
56+
57+
Returns:
58+
tuple: (github_token, host, http_path, identity_federation_client_id)
59+
60+
Raises:
61+
SystemExit: If the OIDC token, host, or HTTP path is missing (the client ID is not validated)
62+
"""
3963
github_token = os.environ.get("OIDC_TOKEN")
4064
if not github_token:
41-
print("GitHub OIDC token not available")
65+
logger.error("GitHub OIDC token not available")
4266
sys.exit(1)
4367

44-
# Get Databricks connection parameters
4568
host = os.environ.get("DATABRICKS_HOST_FOR_TF")
4669
http_path = os.environ.get("DATABRICKS_HTTP_PATH_FOR_TF")
4770
identity_federation_client_id = os.environ.get("IDENTITY_FEDERATION_CLIENT_ID")
4871

4972
if not host or not http_path:
50-
print("Missing Databricks connection parameters")
73+
logger.error("Missing Databricks connection parameters")
5174
sys.exit(1)
5275

53-
# Display token claims for debugging
54-
claims = decode_jwt(github_token)
55-
if claims:
56-
print("\n=== GitHub OIDC Token Claims ===")
57-
print(f"Token issuer: {claims.get('iss')}")
58-
print(f"Token subject: {claims.get('sub')}")
59-
print(f"Token audience: {claims.get('aud')}")
60-
print(f"Token expiration: {claims.get('exp', 'unknown')}")
61-
print(f"Repository: {claims.get('repository', 'unknown')}")
62-
print(f"Workflow ref: {claims.get('workflow_ref', 'unknown')}")
63-
print(f"Event name: {claims.get('event_name', 'unknown')}")
64-
print("===============================\n")
65-
66-
try:
67-
# Connect to Databricks using token federation
68-
print(f"=== Testing Connection via Connector ===")
69-
print(f"Connecting to Databricks at {host}{http_path}")
70-
print(f"Using client ID: {identity_federation_client_id}")
76+
return github_token, host, http_path, identity_federation_client_id
77+
78+
79+
def display_token_info(claims):
80+
"""Display token claims for debugging."""
81+
if not claims:
82+
logger.warning("No token claims available to display")
83+
return
7184

72-
connection_params = {
73-
"server_hostname": host,
74-
"http_path": http_path,
75-
"access_token": github_token,
76-
"auth_type": "token-federation",
77-
"identity_federation_client_id": identity_federation_client_id,
78-
}
85+
logger.info("=== GitHub OIDC Token Claims ===")
86+
logger.info(f"Token issuer: {claims.get('iss')}")
87+
logger.info(f"Token subject: {claims.get('sub')}")
88+
logger.info(f"Token audience: {claims.get('aud')}")
89+
logger.info(f"Token expiration: {claims.get('exp', 'unknown')}")
90+
logger.info(f"Repository: {claims.get('repository', 'unknown')}")
91+
logger.info(f"Workflow ref: {claims.get('workflow_ref', 'unknown')}")
92+
logger.info(f"Event name: {claims.get('event_name', 'unknown')}")
93+
logger.info("===============================")
94+
95+
96+
def test_databricks_connection(host, http_path, github_token, identity_federation_client_id):
97+
"""
98+
Test connection to Databricks using token federation.
99+
100+
Args:
101+
host: Databricks host
102+
http_path: Databricks HTTP path
103+
github_token: GitHub OIDC token
104+
identity_federation_client_id: Identity federation client ID
79105
106+
Returns:
107+
bool: True if the test is successful, False otherwise
108+
"""
109+
logger.info("=== Testing Connection via Connector ===")
110+
logger.info(f"Connecting to Databricks at {host}{http_path}")
111+
logger.info(f"Using client ID: {identity_federation_client_id}")
112+
113+
connection_params = {
114+
"server_hostname": host,
115+
"http_path": http_path,
116+
"access_token": github_token,
117+
"auth_type": "token-federation",
118+
"identity_federation_client_id": identity_federation_client_id,
119+
}
120+
121+
try:
80122
with sql.connect(**connection_params) as connection:
81-
print("Connection established successfully")
123+
logger.info("Connection established successfully")
82124

83125
# Execute a simple query
84126
cursor = connection.cursor()
85127
cursor.execute("SELECT 1 + 1 as result")
86128
result = cursor.fetchall()
87-
print(f"Query result: {result[0][0]}")
129+
logger.info(f"Query result: {result[0][0]}")
88130

89131
# Show current user
90132
cursor.execute("SELECT current_user() as user")
91133
result = cursor.fetchall()
92-
print(f"Connected as user: {result[0][0]}")
134+
logger.info(f"Connected as user: {result[0][0]}")
93135

94-
print("Token federation test successful!")
136+
logger.info("Token federation test successful!")
95137
return True
96138
except Exception as e:
97-
print(f"Error connecting to Databricks: {str(e)}")
98-
print("===================================\n")
139+
logger.error(f"Error connecting to Databricks: {str(e)}")
140+
return False
141+
142+
143+
def main():
144+
"""Main entry point for the test script."""
145+
try:
146+
# Get environment variables
147+
github_token, host, http_path, identity_federation_client_id = get_environment_variables()
148+
149+
# Display token claims
150+
claims = decode_jwt(github_token)
151+
display_token_info(claims)
152+
153+
# Test Databricks connection
154+
success = test_databricks_connection(
155+
host, http_path, github_token, identity_federation_client_id
156+
)
157+
158+
if not success:
159+
logger.error("Token federation test failed")
160+
sys.exit(1)
161+
162+
except Exception as e:
163+
logger.error(f"Unexpected error: {str(e)}")
99164
sys.exit(1)
100165

101166

tests/unit/test_token_federation.py

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@
1313
Token,
1414
DatabricksTokenFederationProvider,
1515
SimpleCredentialsProvider,
16-
create_token_federation_provider
16+
create_token_federation_provider,
17+
TOKEN_REFRESH_BUFFER_SECONDS
1718
)
1819

1920

@@ -47,12 +48,12 @@ def test_token_needs_refresh(self):
4748
self.assertTrue(token.needs_refresh())
4849

4950
# Token with expiry in the near future (within refresh buffer)
50-
near_future = datetime.now(tz=timezone.utc) + timedelta(minutes=4)
51+
near_future = datetime.now(tz=timezone.utc) + timedelta(seconds=TOKEN_REFRESH_BUFFER_SECONDS - 60)
5152
token = Token("access_token", "Bearer", expiry=near_future)
5253
self.assertTrue(token.needs_refresh())
5354

5455
# Token with expiry just beyond the refresh buffer
55-
far_future = datetime.now(tz=timezone.utc) + timedelta(hours=1)
56+
far_future = datetime.now(tz=timezone.utc) + timedelta(seconds=TOKEN_REFRESH_BUFFER_SECONDS + 60)
5657
token = Token("access_token", "Bearer", expiry=far_future)
5758
self.assertFalse(token.needs_refresh())
5859

@@ -118,22 +119,30 @@ def test_init_oidc_discovery(self, mock_get_endpoints, mock_requests_get):
118119
@patch('databricks.sql.auth.token_federation.DatabricksTokenFederationProvider._parse_jwt_claims')
119120
@patch('databricks.sql.auth.token_federation.DatabricksTokenFederationProvider._exchange_token')
120121
@patch('databricks.sql.auth.token_federation.DatabricksTokenFederationProvider._is_same_host')
121-
def test_token_refresh(self, mock_is_same_host, mock_exchange_token, mock_parse_jwt):
122+
@patch('databricks.sql.auth.token_federation.DatabricksTokenFederationProvider._detect_idp_from_claims')
123+
def test_token_refresh(self, mock_detect_idp, mock_is_same_host, mock_exchange_token, mock_parse_jwt):
122124
"""Test token refresh functionality for approaching expiry."""
123125
# Set up mocks
124126
mock_parse_jwt.return_value = {"iss": "https://login.microsoftonline.com/tenant"}
125127
mock_is_same_host.return_value = False
128+
mock_detect_idp.return_value = "azure"
126129

127-
# Create a mock credentials provider that can return different tokens
130+
# Create mock credentials provider that can return different tokens for different calls
128131
mock_creds_provider = MagicMock()
129-
# Initial token factory
132+
133+
# First call returns initial_token, second call returns fresh_token
134+
initial_headers = {"Authorization": "Bearer initial_token"}
135+
fresh_headers = {"Authorization": "Bearer fresh_token"}
136+
137+
# Set up initial header factory
130138
initial_header_factory = MagicMock()
131-
initial_header_factory.return_value = {"Authorization": "Bearer initial_token"}
132-
# Fresh token factory for refresh
139+
initial_header_factory.return_value = initial_headers
140+
141+
# Set up fresh header factory for second call
133142
fresh_header_factory = MagicMock()
134-
fresh_header_factory.return_value = {"Authorization": "Bearer fresh_token"}
143+
fresh_header_factory.return_value = fresh_headers
135144

136-
# Configure the mock to return different header factories on consecutive calls
145+
# Configure the mock to return the initial factory on the first call and the fresh factory on the second
137146
mock_creds_provider.side_effect = [initial_header_factory, fresh_header_factory]
138147

139148
# Set up the token federation provider
@@ -157,9 +166,11 @@ def test_token_refresh(self, mock_is_same_host, mock_exchange_token, mock_parse_
157166

158167
# Reset the mocks to track the next call
159168
mock_exchange_token.reset_mock()
169+
mock_creds_provider.reset_mock()
170+
mock_creds_provider.return_value = fresh_header_factory
160171

161172
# Now simulate an approaching expiry
162-
near_expiry = datetime.now(tz=timezone.utc) + timedelta(minutes=4)
173+
near_expiry = datetime.now(tz=timezone.utc) + timedelta(seconds=TOKEN_REFRESH_BUFFER_SECONDS - 60)
163174
federation_provider.last_exchanged_token = Token(
164175
"exchanged_token_1", "Bearer", expiry=near_expiry
165176
)

0 commit comments

Comments
 (0)