Skip to content

Commit e87b52d

Browse files
committed
readability
1 parent 3613cb0 commit e87b52d

File tree

1 file changed

+3
-281
lines changed

1 file changed

+3
-281
lines changed

.github/workflows/token-federation-test.yml

Lines changed: 3 additions & 281 deletions
Original file line number | Diff line number | Diff line change
@@ -62,81 +62,6 @@ jobs:
6262
pip install -e .
6363
pip install pyarrow
6464
65-
- name: Create debugging patch script
66-
run: |
67-
cat > patch_for_debugging.py << 'EOF'
68-
#!/usr/bin/env python3
69-
70-
def patch_code():
71-
with open('src/databricks/sql/auth/token_federation.py', 'r') as f:
72-
content = f.read()
73-
74-
# Add token debugging
75-
modified = content.replace(
76-
'def _exchange_token(self, token, force_refresh=False):',
77-
'def _exchange_token(self, token, force_refresh=False):\n # Debug token info\n import jwt\n try:\n decoded = jwt.decode(token, options={"verify_signature": False})\n print(f"Token issuer: {decoded.get(\'iss\')}")\n print(f"Token subject: {decoded.get(\'sub\')}")\n print(f"Token audience: {decoded.get(\'aud\') if isinstance(decoded.get(\'aud\'), str) else decoded.get(\'aud\', [])[0] if decoded.get(\'aud\') else \'\'}")\n except Exception as e:\n print(f"Unable to decode token: {str(e)}")'
78-
)
79-
80-
# Add verbose request debugging
81-
modified = modified.replace(
82-
'try:\n # Make the token exchange request',
83-
'try:\n import urllib.parse\n # Debug full request\n print(f"Connecting to Databricks at {self.host}")\n print(f"Token endpoint: {self.token_endpoint}")\n print(f"Request parameters: {urllib.parse.urlencode(params)}")\n print(f"Request headers: {headers}")\n # Make the token exchange request'
84-
)
85-
86-
# Add verbose response debugging
87-
modified = modified.replace(
88-
'response = requests.post(self.token_endpoint, data=params, headers=headers)',
89-
'response = requests.post(self.token_endpoint, data=params, headers=headers)\n print(f"Response status: {response.status_code}")\n print(f"Response headers: {dict(response.headers)}")\n print(f"Response body: {response.text}")'
90-
)
91-
92-
# Improve error handling
93-
modified = modified.replace(
94-
'except RequestException as e:',
95-
'except RequestException as e:\n print(f"Failed to perform token exchange: {str(e)}")\n if hasattr(e, "response") and e.response:\n print(f"Error response status: {e.response.status_code}")\n print(f"Error response headers: {dict(e.response.headers)}")\n print(f"Error response text: {e.response.text}")'
96-
)
97-
98-
with open('src/databricks/sql/auth/token_federation.py', 'w') as f:
99-
f.write(modified)
100-
101-
if __name__ == "__main__":
102-
patch_code()
103-
EOF
104-
105-
chmod +x patch_for_debugging.py
106-
107-
- name: Install PyJWT for token debugging
108-
run: pip install pyjwt
109-
110-
- name: Apply debugging patches to token_federation.py
111-
run: python patch_for_debugging.py
112-
113-
- name: Create audience fix patch script
114-
run: |
115-
cat > patch_for_audience_fix.py << 'EOF'
116-
#!/usr/bin/env python3
117-
118-
def patch_code():
119-
with open('src/databricks/sql/auth/token_federation.py', 'r') as f:
120-
content = f.read()
121-
122-
# Fix audience handling
123-
modified = content.replace(
124-
'def _exchange_token(self, token, force_refresh=False):',
125-
'def _exchange_token(self, token, force_refresh=False):\\n # Additional handling for different audience formats\\n import jwt\\n try:\\n # Try both standard and alternative audience formats\\n audience_tried = False\\n \\n def try_with_audience(token, audience):\\n nonlocal audience_tried\\n if audience_tried:\\n return None\\n \\n audience_tried = True\\n decoded = jwt.decode(token, options={\"verify_signature\": False})\\n aud = decoded.get(\"aud\")\\n \\n # Check if aud is a list and convert to string if needed\\n if isinstance(aud, list) and len(aud) > 0:\\n aud = aud[0]\\n \\n # Print audience for debugging\\n print(f\"Original token audience: {aud}\")\\n \\n if aud != audience:\\n print(f\"WARNING: Token audience \'{aud}\' doesn\'t match expected audience \'{audience}\'\")\\n # We won\'t modify the token as that would invalidate the signature\\n \\n return None\\n \\n # We\'re just collecting debugging info, not modifying the token\\n try_with_audience(token, \"https://github.com/databricks\")\\n \\n except Exception as e:\\n print(f\"Audience debug error: {str(e)}\")'
126-
)
127-
128-
with open('src/databricks/sql/auth/token_federation.py', 'w') as f:
129-
f.write(modified)
130-
131-
if __name__ == "__main__":
132-
patch_code()
133-
EOF
134-
135-
chmod +x patch_for_audience_fix.py
136-
137-
- name: Apply audience fix patches
138-
run: python patch_for_audience_fix.py
139-
14065
- name: Get GitHub OIDC token
14166
id: get-id-token
14267
uses: actions/github-script@v7
@@ -146,116 +71,6 @@ jobs:
14671
core.setSecret(token)
14772
core.setOutput('token', token)
14873
149-
- name: Decode and display OIDC token claims
150-
env:
151-
OIDC_TOKEN: ${{ steps.get-id-token.outputs.token }}
152-
run: |
153-
echo "Decoding GitHub OIDC token claims..."
154-
python -c '
155-
import sys, base64, json
156-
157-
token = """$OIDC_TOKEN"""
158-
159-
# Parse the token
160-
try:
161-
header, payload, signature = token.split(".")
162-
163-
# Add padding if needed
164-
payload_padding = payload + "=" * (-len(payload) % 4)
165-
166-
# Decode the payload
167-
decoded_payload = base64.b64decode(payload_padding).decode("utf-8")
168-
claims = json.loads(decoded_payload)
169-
170-
# Print important claims
171-
print("\n=== GITHUB OIDC TOKEN CLAIMS ===")
172-
print(f"Issuer (iss): {claims.get('iss')}")
173-
print(f"Subject (sub): {claims.get('sub')}")
174-
print(f"Audience (aud): {claims.get('aud')}")
175-
print(f"Repository: {claims.get('repository')}")
176-
print(f"Repository owner: {claims.get('repository_owner')}")
177-
print(f"Event name: {claims.get('event_name')}")
178-
print(f"Ref: {claims.get('ref')}")
179-
print(f"Workflow ref: {claims.get('workflow_ref')}")
180-
print("\n=== FULL CLAIMS ===")
181-
print(json.dumps(claims, indent=2))
182-
print("===========================\n")
183-
except Exception as e:
184-
print(f"Failed to decode token: {str(e)}")
185-
'
186-
187-
- name: Debug token exchange with curl
188-
env:
189-
DATABRICKS_HOST_FOR_TF: ${{ github.event_name == 'workflow_dispatch' && inputs.databricks_host || secrets.DATABRICKS_HOST_FOR_TF }}
190-
IDENTITY_FEDERATION_CLIENT_ID: ${{ github.event_name == 'workflow_dispatch' && inputs.identity_federation_client_id || secrets.IDENTITY_FEDERATION_CLIENT_ID }}
191-
OIDC_TOKEN: ${{ steps.get-id-token.outputs.token }}
192-
run: |
193-
echo "Attempting direct token exchange with curl..."
194-
echo "Host: $DATABRICKS_HOST_FOR_TF"
195-
echo "Client ID: $IDENTITY_FEDERATION_CLIENT_ID"
196-
197-
# Debug token claims before making the request
198-
echo "Token claims:"
199-
python3 -c "
200-
import base64, json, sys
201-
token = \"$OIDC_TOKEN\"
202-
parts = token.split(\".\")
203-
if len(parts) >= 2:
204-
padding = \"=\" * (4 - len(parts[1]) % 4)
205-
decoded_bytes = base64.b64decode(parts[1] + padding)
206-
decoded_str = decoded_bytes.decode(\"utf-8\")
207-
claims = json.loads(decoded_str)
208-
print(f\"Token issuer: {claims.get('iss', 'unknown')}\")
209-
print(f\"Token subject: {claims.get('sub', 'unknown')}\")
210-
print(f\"Token audience: {claims.get('aud', 'unknown')}\")
211-
else:
212-
print(\"Invalid token format\")
213-
"
214-
215-
# Create a properly URL-encoded request
216-
echo "Creating token exchange request..."
217-
curl_data=$(cat << 'EOF'
218-
client_id=$IDENTITY_FEDERATION_CLIENT_ID&\
219-
subject_token=$OIDC_TOKEN&\
220-
subject_token_type=urn:ietf:params:oauth:token-type:jwt&\
221-
grant_type=urn:ietf:params:oauth:grant-type:token-exchange&\
222-
scope=sql
223-
EOF
224-
)
225-
226-
# Substitute environment variables in the curl data
227-
curl_data=$(eval echo "$curl_data")
228-
229-
# Print request details (except the token)
230-
echo "Request URL: https://$DATABRICKS_HOST_FOR_TF/oidc/v1/token"
231-
echo "Request data: $(echo "$curl_data" | sed 's/subject_token=.*&/subject_token=REDACTED&/')"
232-
233-
# Make the request with detailed info
234-
echo "Sending request..."
235-
response=$(curl -v -s -X POST "https://$DATABRICKS_HOST_FOR_TF/oidc/v1/token" \
236-
--data-raw "$curl_data" \
237-
-H "Content-Type: application/x-www-form-urlencoded" \
238-
-H "Accept: application/json" \
239-
2>&1)
240-
241-
# Extract and display results
242-
echo "Response:"
243-
echo "$response"
244-
245-
# Extract HTTP status if possible
246-
status_code=$(echo "$response" | grep -o "< HTTP/[0-9.]* [0-9]*" | grep -o "[0-9]*$" || echo "unknown")
247-
echo "HTTP Status Code: $status_code"
248-
249-
# Try to extract and pretty-print the JSON response body if present
250-
response_body=$(echo "$response" | sed -n -e '/^{/,/^}/p' || echo "")
251-
if [ ! -z "$response_body" ]; then
252-
echo "Response body (formatted):"
253-
echo "$response_body" | python3 -m json.tool || echo "$response_body"
254-
fi
255-
256-
# Don't fail the workflow if curl fails
257-
exit 0
258-
25974
- name: Create test script
26075
run: |
26176
cat > test_github_token_federation.py << 'EOF'
@@ -264,7 +79,7 @@ jobs:
26479
"""
26580
Test script for Databricks SQL token federation with GitHub Actions OIDC tokens.
26681
267-
This script demonstrates how to use the Databricks SQL connector with token federation
82+
This script tests the Databricks SQL connector with token federation
26883
using a GitHub Actions OIDC token. It connects to a Databricks SQL warehouse,
26984
runs a simple query, and shows the connected user.
27085
"""
@@ -273,9 +88,7 @@ jobs:
27388
import sys
27489
import json
27590
import base64
276-
import requests
27791
from databricks import sql
278-
import time
27992
28093
def decode_jwt(token):
28194
"""Decode and return the claims from a JWT token."""
@@ -295,69 +108,6 @@ jobs:
295108
print(f"Failed to decode token: {str(e)}")
296109
return None
297110
298-
def test_direct_token_exchange(host, token, client_id, audience=None):
299-
"""Directly test token exchange with the Databricks API."""
300-
try:
301-
url = f"https://{host}/oidc/v1/token"
302-
data = {
303-
"client_id": client_id,
304-
"subject_token": token,
305-
"subject_token_type": "urn:ietf:params:oauth:token-type:jwt",
306-
"grant_type": "urn:ietf:params:oauth:grant-type:token-exchange",
307-
"scope": "sql",
308-
"return_original_token_if_authenticated": "true"
309-
}
310-
311-
headers = {
312-
"Content-Type": "application/x-www-form-urlencoded",
313-
"Accept": "application/json"
314-
}
315-
316-
print(f"Testing direct token exchange with {url}")
317-
print(f"Request parameters: {data}")
318-
319-
# Add debugging info
320-
claims = decode_jwt(token)
321-
if claims:
322-
print(f"Token issuer: {claims.get('iss', 'unknown')}")
323-
print(f"Token subject: {claims.get('sub', 'unknown')}")
324-
print(f"Token audience: {claims.get('aud', 'unknown')}")
325-
326-
# If audience was specified in policy but doesn't match token
327-
if audience and audience != claims.get('aud'):
328-
print("WARNING: Expected audience and token audience don't match")
329-
print(f"Expected: {audience}")
330-
print(f"Actual: {claims.get('aud')}")
331-
332-
# Enable more verbose HTTP debugging
333-
import http.client as http_client
334-
http_client.HTTPConnection.debuglevel = 1
335-
336-
# Log requests library debug info
337-
import logging
338-
logging.basicConfig()
339-
logging.getLogger().setLevel(logging.DEBUG)
340-
requests_log = logging.getLogger("requests.packages.urllib3")
341-
requests_log.setLevel(logging.DEBUG)
342-
requests_log.propagate = True
343-
344-
response = requests.post(url, data=data, headers=headers)
345-
346-
print(f"Status code: {response.status_code}")
347-
print(f"Response headers: {dict(response.headers)}")
348-
print(f"Response content: {response.text}")
349-
350-
if response.status_code == 200:
351-
try:
352-
return json.loads(response.text).get("access_token")
353-
except json.JSONDecodeError:
354-
print("Failed to parse response JSON")
355-
return None
356-
return None
357-
except Exception as e:
358-
print(f"Direct token exchange failed: {str(e)}")
359-
return None
360-
361111
def main():
362112
# Get GitHub OIDC token
363113
github_token = os.environ.get("OIDC_TOKEN")
@@ -374,6 +124,7 @@ jobs:
374124
print("Missing Databricks connection parameters")
375125
sys.exit(1)
376126
127+
# Display token claims for debugging
377128
claims = decode_jwt(github_token)
378129
if claims:
379130
print("\n=== GitHub OIDC Token Claims ===")
@@ -386,38 +137,9 @@ jobs:
386137
print(f"Event name: {claims.get('event_name', 'unknown')}")
387138
print("===============================\n")
388139
389-
# Try token exchange with several possible audience values
390-
audience_values = [
391-
"https://github.com/databricks", # Standard audience for GitHub tokens
392-
"https://github.com", # Alternative audience
393-
None # No audience
394-
]
395-
396-
# Direct token exchange test
397-
access_token = None
398-
for audience in audience_values:
399-
print(f"\n=== Testing Direct Token Exchange (audience={audience}) ===")
400-
result = test_direct_token_exchange(host, github_token, identity_federation_client_id, audience)
401-
if result:
402-
print("Direct token exchange successful!")
403-
access_token = result
404-
token_claims = decode_jwt(result)
405-
if token_claims:
406-
print(f"Databricks token subject: {token_claims.get('sub', 'unknown')}")
407-
break
408-
print(f"Token exchange failed with audience={audience}")
409-
# Add a small delay between attempts
410-
time.sleep(1)
411-
412-
if not access_token:
413-
print("All token exchange attempts failed")
414-
print("=====================================\n")
415-
else:
416-
print("=====================================\n")
417-
418140
try:
419141
# Connect to Databricks using token federation
420-
print(f"\n=== Testing Connection via Connector ===")
142+
print(f"=== Testing Connection via Connector ===")
421143
print(f"Connecting to Databricks at {host}{http_path}")
422144
print(f"Using client ID: {identity_federation_client_id}")
423145

0 commit comments

Comments
 (0)