Skip to content

Commit 929191b

Browse files
committed
separate py script
1 parent e87b52d commit 929191b

File tree

2 files changed

+105
-115
lines changed

2 files changed

+105
-115
lines changed

.github/workflows/token-federation-test.yml

Lines changed: 2 additions & 115 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ on:
2828
- 'src/databricks/sql/auth/token_federation.py'
2929
- 'src/databricks/sql/auth/auth.py'
3030
- 'examples/token_federation_*.py'
31+
- 'tests/token_federation/github_oidc_test.py'
3132
branches:
3233
- main
3334

@@ -43,11 +44,6 @@ jobs:
4344
labels: linux-ubuntu-latest
4445

4546
steps:
46-
- name: Debug OIDC Claims
47-
uses: github/actions-oidc-debugger@main
48-
with:
49-
audience: '${{ github.server_url }}/${{ github.repository_owner }}'
50-
5147
- name: Checkout code
5248
uses: actions/checkout@v4
5349

@@ -71,120 +67,11 @@ jobs:
7167
core.setSecret(token)
7268
core.setOutput('token', token)
7369
74-
- name: Create test script
75-
run: |
76-
cat > test_github_token_federation.py << 'EOF'
77-
#!/usr/bin/env python3
78-
79-
"""
80-
Test script for Databricks SQL token federation with GitHub Actions OIDC tokens.
81-
82-
This script tests the Databricks SQL connector with token federation
83-
using a GitHub Actions OIDC token. It connects to a Databricks SQL warehouse,
84-
runs a simple query, and shows the connected user.
85-
"""
86-
87-
import os
88-
import sys
89-
import json
90-
import base64
91-
from databricks import sql
92-
93-
def decode_jwt(token):
94-
"""Decode and return the claims from a JWT token."""
95-
try:
96-
parts = token.split(".")
97-
if len(parts) != 3:
98-
raise ValueError("Invalid JWT format")
99-
100-
payload = parts[1]
101-
# Add padding if needed
102-
padding = '=' * (4 - len(payload) % 4)
103-
payload += padding
104-
105-
decoded = base64.b64decode(payload)
106-
return json.loads(decoded)
107-
except Exception as e:
108-
print(f"Failed to decode token: {str(e)}")
109-
return None
110-
111-
def main():
112-
# Get GitHub OIDC token
113-
github_token = os.environ.get("OIDC_TOKEN")
114-
if not github_token:
115-
print("GitHub OIDC token not available")
116-
sys.exit(1)
117-
118-
# Get Databricks connection parameters
119-
host = os.environ.get("DATABRICKS_HOST_FOR_TF")
120-
http_path = os.environ.get("DATABRICKS_HTTP_PATH_FOR_TF")
121-
identity_federation_client_id = os.environ.get("IDENTITY_FEDERATION_CLIENT_ID")
122-
123-
if not host or not http_path:
124-
print("Missing Databricks connection parameters")
125-
sys.exit(1)
126-
127-
# Display token claims for debugging
128-
claims = decode_jwt(github_token)
129-
if claims:
130-
print("\n=== GitHub OIDC Token Claims ===")
131-
print(f"Token issuer: {claims.get('iss')}")
132-
print(f"Token subject: {claims.get('sub')}")
133-
print(f"Token audience: {claims.get('aud')}")
134-
print(f"Token expiration: {claims.get('exp', 'unknown')}")
135-
print(f"Repository: {claims.get('repository', 'unknown')}")
136-
print(f"Workflow ref: {claims.get('workflow_ref', 'unknown')}")
137-
print(f"Event name: {claims.get('event_name', 'unknown')}")
138-
print("===============================\n")
139-
140-
try:
141-
# Connect to Databricks using token federation
142-
print(f"=== Testing Connection via Connector ===")
143-
print(f"Connecting to Databricks at {host}{http_path}")
144-
print(f"Using client ID: {identity_federation_client_id}")
145-
146-
connection_params = {
147-
"server_hostname": host,
148-
"http_path": http_path,
149-
"access_token": github_token,
150-
"auth_type": "token-federation",
151-
"identity_federation_client_id": identity_federation_client_id,
152-
}
153-
154-
print("Connection parameters:")
155-
print(json.dumps({k: v if k != 'access_token' else '***' for k, v in connection_params.items()}, indent=2))
156-
157-
with sql.connect(**connection_params) as connection:
158-
print("Connection established successfully")
159-
160-
# Execute a simple query
161-
cursor = connection.cursor()
162-
cursor.execute("SELECT 1 + 1 as result")
163-
result = cursor.fetchall()
164-
print(f"Query result: {result[0][0]}")
165-
166-
# Show current user
167-
cursor.execute("SELECT current_user() as user")
168-
result = cursor.fetchall()
169-
print(f"Connected as user: {result[0][0]}")
170-
171-
print("Token federation test successful!")
172-
return True
173-
except Exception as e:
174-
print(f"Error connecting to Databricks: {str(e)}")
175-
print("===================================\n")
176-
sys.exit(1)
177-
178-
if __name__ == "__main__":
179-
main()
180-
EOF
181-
chmod +x test_github_token_federation.py
182-
18370
- name: Test token federation with GitHub OIDC token
18471
env:
18572
DATABRICKS_HOST_FOR_TF: ${{ github.event_name == 'workflow_dispatch' && inputs.databricks_host || secrets.DATABRICKS_HOST_FOR_TF }}
18673
DATABRICKS_HTTP_PATH_FOR_TF: ${{ github.event_name == 'workflow_dispatch' && inputs.databricks_http_path || secrets.DATABRICKS_HTTP_PATH_FOR_TF }}
18774
IDENTITY_FEDERATION_CLIENT_ID: ${{ github.event_name == 'workflow_dispatch' && inputs.identity_federation_client_id || secrets.IDENTITY_FEDERATION_CLIENT_ID }}
18875
OIDC_TOKEN: ${{ steps.get-id-token.outputs.token }}
18976
run: |
190-
python test_github_token_federation.py
77+
python tests/token_federation/github_oidc_test.py
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
#!/usr/bin/env python3
2+
3+
"""
4+
Test script for Databricks SQL token federation with GitHub Actions OIDC tokens.
5+
6+
This script tests the Databricks SQL connector with token federation
7+
using a GitHub Actions OIDC token. It connects to a Databricks SQL warehouse,
8+
runs a simple query, and shows the connected user.
9+
"""
10+
11+
import os
12+
import sys
13+
import json
14+
import base64
15+
from databricks import sql
16+
17+
18+
def decode_jwt(token):
19+
"""Decode and return the claims from a JWT token."""
20+
try:
21+
parts = token.split(".")
22+
if len(parts) != 3:
23+
raise ValueError("Invalid JWT format")
24+
25+
payload = parts[1]
26+
# Add padding if needed
27+
padding = '=' * (4 - len(payload) % 4)
28+
payload += padding
29+
30+
decoded = base64.b64decode(payload)
31+
return json.loads(decoded)
32+
except Exception as e:
33+
print(f"Failed to decode token: {str(e)}")
34+
return None
35+
36+
37+
def main():
38+
# Get GitHub OIDC token
39+
github_token = os.environ.get("OIDC_TOKEN")
40+
if not github_token:
41+
print("GitHub OIDC token not available")
42+
sys.exit(1)
43+
44+
# Get Databricks connection parameters
45+
host = os.environ.get("DATABRICKS_HOST_FOR_TF")
46+
http_path = os.environ.get("DATABRICKS_HTTP_PATH_FOR_TF")
47+
identity_federation_client_id = os.environ.get("IDENTITY_FEDERATION_CLIENT_ID")
48+
49+
if not host or not http_path:
50+
print("Missing Databricks connection parameters")
51+
sys.exit(1)
52+
53+
# Display token claims for debugging
54+
claims = decode_jwt(github_token)
55+
if claims:
56+
print("\n=== GitHub OIDC Token Claims ===")
57+
print(f"Token issuer: {claims.get('iss')}")
58+
print(f"Token subject: {claims.get('sub')}")
59+
print(f"Token audience: {claims.get('aud')}")
60+
print(f"Token expiration: {claims.get('exp', 'unknown')}")
61+
print(f"Repository: {claims.get('repository', 'unknown')}")
62+
print(f"Workflow ref: {claims.get('workflow_ref', 'unknown')}")
63+
print(f"Event name: {claims.get('event_name', 'unknown')}")
64+
print("===============================\n")
65+
66+
try:
67+
# Connect to Databricks using token federation
68+
print(f"=== Testing Connection via Connector ===")
69+
print(f"Connecting to Databricks at {host}{http_path}")
70+
print(f"Using client ID: {identity_federation_client_id}")
71+
72+
connection_params = {
73+
"server_hostname": host,
74+
"http_path": http_path,
75+
"access_token": github_token,
76+
"auth_type": "token-federation",
77+
"identity_federation_client_id": identity_federation_client_id,
78+
}
79+
80+
with sql.connect(**connection_params) as connection:
81+
print("Connection established successfully")
82+
83+
# Execute a simple query
84+
cursor = connection.cursor()
85+
cursor.execute("SELECT 1 + 1 as result")
86+
result = cursor.fetchall()
87+
print(f"Query result: {result[0][0]}")
88+
89+
# Show current user
90+
cursor.execute("SELECT current_user() as user")
91+
result = cursor.fetchall()
92+
print(f"Connected as user: {result[0][0]}")
93+
94+
print("Token federation test successful!")
95+
return True
96+
except Exception as e:
97+
print(f"Error connecting to Databricks: {str(e)}")
98+
print("===================================\n")
99+
sys.exit(1)
100+
101+
102+
if __name__ == "__main__":
103+
main()

0 commit comments

Comments
 (0)