# Implements Token Federation for Python Driver #16

name: Token Federation Test
# Exercises token federation in the databricks-sql-python connector with
# GitHub Actions OIDC tokens, so that CI/CD authentication keeps working.
on:
  # Manual runs: the caller supplies the workspace details as inputs.
  workflow_dispatch:
    inputs:
      databricks_host:
        description: 'Databricks host URL (e.g., example.cloud.databricks.com)'
        required: true
      databricks_http_path:
        description: 'Databricks HTTP path (e.g., /sql/1.0/warehouses/abc123)'
        required: true
      identity_federation_client_id:
        description: 'Identity federation client ID'
        required: true
  # Pull requests targeting main.
  # NOTE(review): unlike the push trigger, no `paths` filter is set here, so
  # every PR against main starts this workflow - confirm that is intended.
  pull_request:
    branches:
      - main
  # Pushes to main that touch the token federation sources or examples.
  push:
    branches:
      - main
    paths:
      - 'src/databricks/sql/auth/token_federation.py'
      - 'src/databricks/sql/auth/auth.py'
      - 'examples/token_federation_*.py'
permissions:
  # Read-only access to the repository contents.
  contents: read
  # Needed so the job can request a GitHub OIDC token.
  id-token: write
jobs:
# Single job; its steps patch the connector with debug output, request a
# GitHub OIDC token and try to exchange/use it against Databricks.
test-token-federation:
# Runner is selected by runner group plus label.
runs-on:
group: databricks-protected-runner-group
labels: linux-ubuntu-latest
# The step list follows.
steps:
- name: Debug OIDC Claims
  # NOTE(review): the action is referenced by the mutable `main` ref rather
  # than a tag or commit SHA.
  uses: github/actions-oidc-debugger@main
  with:
    # Audience passed to the debugger: <server URL>/<repository owner>.
    audience: '${{ github.server_url }}/${{ github.repository_owner }}'
# Fetch the repository so the connector sources can be installed and patched.
- name: Checkout code
  uses: actions/checkout@v4
- name: Set up Python 3.9
  uses: actions/setup-python@v5
  with:
    # Quoted so the version stays a string instead of a YAML float.
    python-version: '3.9'
- name: Install dependencies
  run: |
    # Upgrade pip, then install the connector in editable mode plus pyarrow.
    python -m pip install --upgrade pip
    pip install -e .
    pip install pyarrow
# Writes patch_for_debugging.py, which injects verbose debug prints into the
# connector's token federation module. The heredoc is quoted ('EOF'), so the
# shell passes the Python source through verbatim.
- name: Create debugging patch script
  run: |
    cat > patch_for_debugging.py << 'EOF'
    #!/usr/bin/env python3
    """Inject verbose debug output into the token federation module."""

    TARGET = 'src/databricks/sql/auth/token_federation.py'

    # Snippets are written at column 0; inject_after() re-indents them so they
    # line up with the code they are attached to in the target file.
    TOKEN_DEBUG = """
    # Debug token info
    import jwt
    try:
        decoded = jwt.decode(token, options={"verify_signature": False})
        print(f"Token issuer: {decoded.get('iss')}")
        print(f"Token subject: {decoded.get('sub')}")
        print(f"Token audience: {decoded.get('aud') if isinstance(decoded.get('aud'), str) else decoded.get('aud', [])[0] if decoded.get('aud') else ''}")
    except Exception as e:
        print(f"Unable to decode token: {str(e)}")
    """

    REQUEST_DEBUG = """
    import urllib.parse
    # Debug full request
    print(f"Connecting to Databricks at {self.host}")
    print(f"Token endpoint: {self.token_endpoint}")
    print(f"Request parameters: {urllib.parse.urlencode(params)}")
    print(f"Request headers: {headers}")
    """

    # A successful response carries the exchanged access token, which is not a
    # registered secret, so its body is kept out of the CI log.
    RESPONSE_DEBUG = """
    print(f"Response status: {response.status_code}")
    print(f"Response headers: {dict(response.headers)}")
    if response.status_code == 200:
        print("Response body: <omitted: contains the exchanged access token>")
    else:
        print(f"Response body: {response.text}")
    """

    # requests.Response is falsy for 4xx/5xx, so the response is compared
    # against None instead of being tested for truthiness.
    ERROR_DEBUG = """
    print(f"Failed to perform token exchange: {str(e)}")
    if getattr(e, "response", None) is not None:
        print(f"Error response status: {e.response.status_code}")
        print(f"Error response headers: {dict(e.response.headers)}")
        print(f"Error response text: {e.response.text}")
    """


    def inject_after(lines, anchor, snippet, extra_indent=0, followed_by=None):
        """Return ``lines`` with ``snippet`` inserted after each anchor line.

        lines        -- the file content split on newlines
        anchor       -- stripped text of the line to attach to
        snippet      -- column-0 source text to insert
        extra_indent -- spaces added on top of the anchor line's indentation
                        (4 when the snippet belongs inside the anchor's block;
                        assumes 4-space indentation in the target file)
        followed_by  -- if given, only anchors whose next line has this
                        stripped text are used
        """
        addition = snippet.strip('\n').split('\n')
        patched = []
        hits = 0
        for idx, line in enumerate(lines):
            patched.append(line)
            if line.strip() != anchor:
                continue
            if followed_by is not None:
                upcoming = lines[idx + 1].strip() if idx + 1 < len(lines) else None
                if upcoming != followed_by:
                    continue
            hits += 1
            pad = line[:len(line) - len(line.lstrip())] + ' ' * extra_indent
            patched.extend(pad + extra if extra else extra for extra in addition)
        if not hits:
            # Best effort: a missing anchor must not abort the workflow.
            print(f"WARNING: anchor not found, nothing injected: {anchor!r}")
        return patched


    def patch_code():
        """Read the target module, add the debug snippets and write it back."""
        with open(TARGET, 'r') as f:
            lines = f.read().split('\n')
        # Add token debugging
        lines = inject_after(
            lines,
            'def _exchange_token(self, token, force_refresh=False):',
            TOKEN_DEBUG,
            extra_indent=4,
        )
        # Add verbose request debugging
        lines = inject_after(
            lines,
            'try:',
            REQUEST_DEBUG,
            extra_indent=4,
            followed_by='# Make the token exchange request',
        )
        # Add verbose response debugging
        lines = inject_after(
            lines,
            'response = requests.post(self.token_endpoint, data=params, headers=headers)',
            RESPONSE_DEBUG,
        )
        # Improve error handling
        lines = inject_after(
            lines,
            'except RequestException as e:',
            ERROR_DEBUG,
            extra_indent=4,
        )
        with open(TARGET, 'w') as f:
            f.write('\n'.join(lines))


    if __name__ == "__main__":
        patch_code()
    EOF
    chmod +x patch_for_debugging.py
# The injected debug code imports `jwt`, which the PyJWT package provides.
- name: Install PyJWT for token debugging
  run: |
    pip install pyjwt
# Runs the generated script, which rewrites token_federation.py in place.
- name: Apply debugging patches to token_federation.py
  run: |
    python patch_for_debugging.py
# Writes patch_for_audience_fix.py, which prepends audience diagnostics to
# _exchange_token. The heredoc is quoted ('EOF'), so the Python source reaches
# the file verbatim; the snippet is therefore a real multi-line string rather
# than a literal full of escaped newlines.
- name: Create audience fix patch script
  run: |
    cat > patch_for_audience_fix.py << 'EOF'
    #!/usr/bin/env python3
    """Prepend audience diagnostics to the connector's _exchange_token."""
    import re

    TARGET = 'src/databricks/sql/auth/token_federation.py'
    SIGNATURE = 'def _exchange_token(self, token, force_refresh=False):'

    # Written at column 0; re-indented to the method body when injected.
    AUDIENCE_DEBUG = """
    # Additional handling for different audience formats
    import jwt
    try:
        # Try both standard and alternative audience formats
        audience_tried = False

        def try_with_audience(token, audience):
            nonlocal audience_tried
            if audience_tried:
                return None

            audience_tried = True
            decoded = jwt.decode(token, options={"verify_signature": False})
            aud = decoded.get("aud")

            # Check if aud is a list and convert to string if needed
            if isinstance(aud, list) and len(aud) > 0:
                aud = aud[0]

            # Print audience for debugging
            print(f"Original token audience: {aud}")

            if aud != audience:
                print(f"WARNING: Token audience '{aud}' doesn't match expected audience '{audience}'")
                # We won't modify the token as that would invalidate the signature

            return None

        # We're just collecting debugging info, not modifying the token
        try_with_audience(token, "https://github.com/databricks")

    except Exception as e:
        print(f"Audience debug error: {str(e)}")
    """


    def patch_code():
        """Insert AUDIENCE_DEBUG right after the _exchange_token signature."""
        with open(TARGET, 'r') as f:
            content = f.read()

        def with_debug(match):
            # Body indentation = indentation of the `def` line + 4 spaces
            # (assumes the target file uses 4-space indentation).
            pad = match.group(1) + '    '
            body = '\n'.join(
                pad + line if line else line
                for line in AUDIENCE_DEBUG.strip('\n').split('\n')
            )
            return match.group(0) + '\n' + body

        # Fix audience handling
        modified, hits = re.subn(
            r'^([ \t]*)' + re.escape(SIGNATURE),
            with_debug,
            content,
            flags=re.MULTILINE,
        )
        if not hits:
            # Best effort: a missing signature must not abort the workflow.
            print(f"WARNING: signature not found, nothing injected: {SIGNATURE!r}")
        with open(TARGET, 'w') as f:
            f.write(modified)


    if __name__ == "__main__":
        patch_code()
    EOF
    chmod +x patch_for_audience_fix.py
# Runs the generated script, which rewrites token_federation.py in place.
- name: Apply audience fix patches
  run: |
    python patch_for_audience_fix.py
- name: Get GitHub OIDC token
  id: get-id-token
  uses: actions/github-script@v7
  with:
    script: |
      // Request an OIDC token for the Databricks audience, register it as a
      // secret so it is masked in logs, then expose it as the `token` output.
      const idToken = await core.getIDToken('https://github.com/databricks')
      core.setSecret(idToken)
      core.setOutput('token', idToken)
# Prints the claims of the GitHub OIDC token. The Python source is fed through
# a quoted heredoc and reads the token from the environment, so neither shell
# quoting nor the token's contents can alter the script.
- name: Decode and display OIDC token claims
  env:
    OIDC_TOKEN: ${{ steps.get-id-token.outputs.token }}
  run: |
    echo "Decoding GitHub OIDC token claims..."
    python - << 'PY'
    import base64
    import json
    import os

    token = os.environ.get("OIDC_TOKEN", "")
    # Parse the token
    try:
        header, payload, signature = token.split(".")
        # JWT segments are unpadded base64url: restore the padding, then
        # decode with the URL-safe alphabet.
        payload_padding = payload + "=" * (-len(payload) % 4)
        decoded_payload = base64.urlsafe_b64decode(payload_padding).decode("utf-8")
        claims = json.loads(decoded_payload)
        # Print important claims
        print("\n=== GITHUB OIDC TOKEN CLAIMS ===")
        print(f"Issuer (iss): {claims.get('iss')}")
        print(f"Subject (sub): {claims.get('sub')}")
        print(f"Audience (aud): {claims.get('aud')}")
        print(f"Repository: {claims.get('repository')}")
        print(f"Repository owner: {claims.get('repository_owner')}")
        print(f"Event name: {claims.get('event_name')}")
        print(f"Ref: {claims.get('ref')}")
        print(f"Workflow ref: {claims.get('workflow_ref')}")
        print("\n=== FULL CLAIMS ===")
        print(json.dumps(claims, indent=2))
        print("===========================\n")
    except Exception as e:
        # Diagnostic step only: report the problem instead of failing the job.
        print(f"Failed to decode token: {str(e)}")
    PY
# Best-effort diagnostic: performs the token exchange directly with curl and
# reports the outcome. The step never fails the workflow.
- name: Debug token exchange with curl
  env:
    DATABRICKS_HOST_FOR_TF: ${{ github.event_name == 'workflow_dispatch' && inputs.databricks_host || secrets.DATABRICKS_HOST_FOR_TF }}
    IDENTITY_FEDERATION_CLIENT_ID: ${{ github.event_name == 'workflow_dispatch' && inputs.identity_federation_client_id || secrets.IDENTITY_FEDERATION_CLIENT_ID }}
    OIDC_TOKEN: ${{ steps.get-id-token.outputs.token }}
  run: |
    echo "Attempting direct token exchange with curl..."
    echo "Host: $DATABRICKS_HOST_FOR_TF"
    echo "Client ID: $IDENTITY_FEDERATION_CLIENT_ID"

    # Debug token claims before making the request. The token is read from the
    # environment instead of being interpolated into the Python source.
    echo "Token claims:"
    python3 - << 'PY' || echo "Unable to decode token claims"
    import base64
    import json
    import os

    parts = os.environ.get("OIDC_TOKEN", "").split(".")
    if len(parts) >= 2:
        # JWT segments are unpadded base64url.
        padding = "=" * (-len(parts[1]) % 4)
        decoded_bytes = base64.urlsafe_b64decode(parts[1] + padding)
        claims = json.loads(decoded_bytes.decode("utf-8"))
        print(f"Token issuer: {claims.get('iss', 'unknown')}")
        print(f"Token subject: {claims.get('sub', 'unknown')}")
        print(f"Token audience: {claims.get('aud', 'unknown')}")
    else:
        print("Invalid token format")
    PY

    # Print request details (except the token)
    token_url="https://$DATABRICKS_HOST_FOR_TF/oidc/v1/token"
    echo "Request URL: $token_url"
    echo "Request data: client_id=$IDENTITY_FEDERATION_CLIENT_ID&subject_token=REDACTED&subject_token_type=urn:ietf:params:oauth:token-type:jwt&grant_type=urn:ietf:params:oauth:grant-type:token-exchange&scope=sql"

    # Make the request with detailed info. Each form field is passed with
    # --data-urlencode so curl does the encoding; no eval or hand-built query
    # string is involved. Verbose output goes straight to the log, the body to
    # a temp file, and only the status code is captured.
    echo "Sending request..."
    response_file=$(mktemp)
    status_code=$(curl -v -sS -o "$response_file" -w '%{http_code}' -X POST "$token_url" \
      -H "Content-Type: application/x-www-form-urlencoded" \
      -H "Accept: application/json" \
      --data-urlencode "client_id=$IDENTITY_FEDERATION_CLIENT_ID" \
      --data-urlencode "subject_token=$OIDC_TOKEN" \
      --data-urlencode "subject_token_type=urn:ietf:params:oauth:token-type:jwt" \
      --data-urlencode "grant_type=urn:ietf:params:oauth:grant-type:token-exchange" \
      --data-urlencode "scope=sql") || status_code="unknown"
    echo "HTTP Status Code: $status_code"

    # Pretty-print the response body, masking any issued tokens so they do not
    # end up in the job log.
    echo "Response body (formatted):"
    python3 - "$response_file" << 'PY' || echo "Unable to display response body"
    import json
    import sys

    with open(sys.argv[1], encoding="utf-8", errors="replace") as fh:
        raw = fh.read()
    try:
        body = json.loads(raw)
    except ValueError:
        # Not JSON (e.g. an HTML error page): show it as received.
        print(raw)
    else:
        if isinstance(body, dict):
            for secret_key in ("access_token", "refresh_token", "id_token"):
                if secret_key in body:
                    body[secret_key] = "REDACTED"
        print(json.dumps(body, indent=2))
    PY
    rm -f "$response_file"

    # Don't fail the workflow if curl fails
    exit 0
# Writes test_github_token_federation.py. The heredoc is quoted ('EOF'), so the
# Python source reaches the file verbatim.
- name: Create test script
  run: |
    cat > test_github_token_federation.py << 'EOF'
    #!/usr/bin/env python3
    """
    Test script for Databricks SQL token federation with GitHub Actions OIDC tokens.
    This script demonstrates how to use the Databricks SQL connector with token federation
    using a GitHub Actions OIDC token. It connects to a Databricks SQL warehouse,
    runs a simple query, and shows the connected user.
    """
    import os
    import sys
    import json
    import base64
    import requests
    from databricks import sql
    import time


    def decode_jwt(token):
        """Decode and return the claims from a JWT token.

        Returns the payload as a dict, or None (after printing the reason)
        when the token cannot be decoded. The signature is not verified.
        """
        try:
            parts = token.split(".")
            if len(parts) != 3:
                raise ValueError("Invalid JWT format")
            payload = parts[1]
            # JWT segments are unpadded base64url: restore the padding, then
            # decode with the URL-safe alphabet.
            payload += '=' * (-len(payload) % 4)
            decoded = base64.urlsafe_b64decode(payload)
            return json.loads(decoded)
        except Exception as e:
            print(f"Failed to decode token: {str(e)}")
            return None


    def test_direct_token_exchange(host, token, client_id, audience=None):
        """Directly test token exchange with the Databricks API.

        host      -- Databricks host name (without scheme)
        token     -- the GitHub OIDC token used as subject token
        client_id -- identity federation client ID
        audience  -- expected audience; only compared with the token's `aud`
                     claim for diagnostics, it is not sent in the request

        Returns the exchanged access token, or None on any failure.
        """
        try:
            url = f"https://{host}/oidc/v1/token"
            data = {
                "client_id": client_id,
                "subject_token": token,
                "subject_token_type": "urn:ietf:params:oauth:token-type:jwt",
                "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange",
                "scope": "sql",
                "return_original_token_if_authenticated": "true"
            }
            headers = {
                "Content-Type": "application/x-www-form-urlencoded",
                "Accept": "application/json"
            }
            print(f"Testing direct token exchange with {url}")
            # Mask the subject token, as is done for the connection parameters.
            printable = {k: ('***' if k == 'subject_token' else v) for k, v in data.items()}
            print(f"Request parameters: {printable}")
            # Add debugging info
            claims = decode_jwt(token)
            if claims:
                print(f"Token issuer: {claims.get('iss', 'unknown')}")
                print(f"Token subject: {claims.get('sub', 'unknown')}")
                print(f"Token audience: {claims.get('aud', 'unknown')}")
                # If audience was specified in policy but doesn't match token
                if audience and audience != claims.get('aud'):
                    print("WARNING: Expected audience and token audience don't match")
                    print(f"Expected: {audience}")
                    print(f"Actual: {claims.get('aud')}")
            # Enable more verbose HTTP debugging
            import http.client as http_client
            http_client.HTTPConnection.debuglevel = 1
            # Log requests library debug info
            import logging
            logging.basicConfig()
            logging.getLogger().setLevel(logging.DEBUG)
            requests_log = logging.getLogger("requests.packages.urllib3")
            requests_log.setLevel(logging.DEBUG)
            requests_log.propagate = True
            # A timeout keeps an unresponsive endpoint from hanging the job.
            response = requests.post(url, data=data, headers=headers, timeout=30)
            print(f"Status code: {response.status_code}")
            print(f"Response headers: {dict(response.headers)}")
            if response.status_code == 200:
                # The body holds the exchanged access token: keep it out of the log.
                print("Response content: <omitted: contains the exchanged access token>")
                try:
                    return json.loads(response.text).get("access_token")
                except json.JSONDecodeError:
                    print("Failed to parse response JSON")
                    return None
            print(f"Response content: {response.text}")
            return None
        except Exception as e:
            print(f"Direct token exchange failed: {str(e)}")
            return None


    def main():
        """Run the direct exchange attempts, then connect through the connector.

        Exits with status 1 when the token or the connection parameters are
        missing, or when the connector-based connection fails.
        """
        # Get GitHub OIDC token
        github_token = os.environ.get("OIDC_TOKEN")
        if not github_token:
            print("GitHub OIDC token not available")
            sys.exit(1)
        # Get Databricks connection parameters
        host = os.environ.get("DATABRICKS_HOST_FOR_TF")
        http_path = os.environ.get("DATABRICKS_HTTP_PATH_FOR_TF")
        identity_federation_client_id = os.environ.get("IDENTITY_FEDERATION_CLIENT_ID")
        if not host or not http_path:
            print("Missing Databricks connection parameters")
            sys.exit(1)
        claims = decode_jwt(github_token)
        if claims:
            print("\n=== GitHub OIDC Token Claims ===")
            print(f"Token issuer: {claims.get('iss')}")
            print(f"Token subject: {claims.get('sub')}")
            print(f"Token audience: {claims.get('aud')}")
            print(f"Token expiration: {claims.get('exp', 'unknown')}")
            print(f"Repository: {claims.get('repository', 'unknown')}")
            print(f"Workflow ref: {claims.get('workflow_ref', 'unknown')}")
            print(f"Event name: {claims.get('event_name', 'unknown')}")
            print("===============================\n")
        # Try token exchange with several possible audience values
        audience_values = [
            "https://github.com/databricks",  # Standard audience for GitHub tokens
            "https://github.com",  # Alternative audience
            None  # No audience
        ]
        # Direct token exchange test
        access_token = None
        for audience in audience_values:
            print(f"\n=== Testing Direct Token Exchange (audience={audience}) ===")
            result = test_direct_token_exchange(host, github_token, identity_federation_client_id, audience)
            if result:
                print("Direct token exchange successful!")
                access_token = result
                token_claims = decode_jwt(result)
                if token_claims:
                    print(f"Databricks token subject: {token_claims.get('sub', 'unknown')}")
                break
            print(f"Token exchange failed with audience={audience}")
            # Add a small delay between attempts
            time.sleep(1)
        if not access_token:
            print("All token exchange attempts failed")
            print("=====================================\n")
        else:
            print("=====================================\n")
        try:
            # Connect to Databricks using token federation
            print(f"\n=== Testing Connection via Connector ===")
            print(f"Connecting to Databricks at {host}{http_path}")
            print(f"Using client ID: {identity_federation_client_id}")
            connection_params = {
                "server_hostname": host,
                "http_path": http_path,
                "access_token": github_token,
                "auth_type": "token-federation",
                "identity_federation_client_id": identity_federation_client_id,
            }
            print("Connection parameters:")
            print(json.dumps({k: v if k != 'access_token' else '***' for k, v in connection_params.items()}, indent=2))
            with sql.connect(**connection_params) as connection:
                print("Connection established successfully")
                # Execute a simple query
                cursor = connection.cursor()
                cursor.execute("SELECT 1 + 1 as result")
                result = cursor.fetchall()
                print(f"Query result: {result[0][0]}")
                # Show current user
                cursor.execute("SELECT current_user() as user")
                result = cursor.fetchall()
                print(f"Connected as user: {result[0][0]}")
                print("Token federation test successful!")
                return True
        except Exception as e:
            print(f"Error connecting to Databricks: {str(e)}")
            print("===================================\n")
            sys.exit(1)


    if __name__ == "__main__":
        main()
    EOF
    chmod +x test_github_token_federation.py
# Runs the generated test script. For manual (workflow_dispatch) runs the
# connection details come from the inputs; otherwise from repository secrets.
- name: Test token federation with GitHub OIDC token
  env:
    OIDC_TOKEN: ${{ steps.get-id-token.outputs.token }}
    DATABRICKS_HOST_FOR_TF: ${{ github.event_name == 'workflow_dispatch' && inputs.databricks_host || secrets.DATABRICKS_HOST_FOR_TF }}
    DATABRICKS_HTTP_PATH_FOR_TF: ${{ github.event_name == 'workflow_dispatch' && inputs.databricks_http_path || secrets.DATABRICKS_HTTP_PATH_FOR_TF }}
    IDENTITY_FEDERATION_CLIENT_ID: ${{ github.event_name == 'workflow_dispatch' && inputs.identity_federation_client_id || secrets.IDENTITY_FEDERATION_CLIENT_ID }}
  run: |
    python test_github_token_federation.py