Implements Token Federation for Python Driver #16
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
name: Token Federation Test

# This workflow tests token federation functionality with GitHub Actions OIDC tokens
# in the databricks-sql-python connector to ensure CI/CD functionality

on:
  # Manual trigger with required inputs
  workflow_dispatch:
    inputs:
      databricks_host:
        description: 'Databricks host URL (e.g., example.cloud.databricks.com)'
        required: true
      databricks_http_path:
        description: 'Databricks HTTP path (e.g., /sql/1.0/warehouses/abc123)'
        required: true
      identity_federation_client_id:
        description: 'Identity federation client ID'
        required: true
  # Automatically run on PR that changes token federation files.
  # The path filter mirrors the one on `push` so unrelated PRs do not trigger
  # a run that requests an OIDC token.
  pull_request:
    branches:
      - main
    paths:
      - 'src/databricks/sql/auth/token_federation.py'
      - 'src/databricks/sql/auth/auth.py'
      - 'examples/token_federation_*.py'
  # Run on push to main that affects token federation
  push:
    paths:
      - 'src/databricks/sql/auth/token_federation.py'
      - 'src/databricks/sql/auth/auth.py'
      - 'examples/token_federation_*.py'
    branches:
      - main

permissions:
  # Required for GitHub OIDC token
  id-token: write
  contents: read
# Single job. It is scheduled on the runner group `databricks-protected-runner-group`
# (label `linux-ubuntu-latest`) and first prepares a Python 3.9 environment.
| jobs: | |
| test-token-federation: | |
| runs-on: | |
| group: databricks-protected-runner-group | |
| labels: linux-ubuntu-latest | |
| steps: | |
# Runs the OIDC debugger action with the audience "<server_url>/<repository_owner>"
# (presumably to print the token claims; the action's behavior is not visible here).
# NOTE(review): the action is referenced by the mutable `@main` ref while this
# workflow holds `id-token: write`; consider pinning to a reviewed tag or commit SHA.
| - name: Debug OIDC Claims | |
| uses: github/actions-oidc-debugger@main | |
| with: | |
| audience: '${{ github.server_url }}/${{ github.repository_owner }}' | |
| - name: Checkout code | |
| uses: actions/checkout@v4 | |
| - name: Set up Python 3.9 | |
| uses: actions/setup-python@v5 | |
| with: | |
| python-version: '3.9' | |
# Installs the connector from the checked-out tree in editable mode (`pip install -e .`),
# plus pyarrow.
| - name: Install dependencies | |
| run: | | |
| python -m pip install --upgrade pip | |
| pip install -e . | |
| pip install pyarrow | |
# Writes a helper script (patch_for_debugging.py) through a quoted heredoc. Its
# patch_code() reads src/databricks/sql/auth/token_federation.py, uses str.replace()
# to append print-based diagnostics after four anchor strings (the _exchange_token
# signature, the "Make the token exchange request" comment, the requests.post call,
# and the RequestException handler), then writes the file back in place.
# NOTE(review): str.replace() leaves the text unchanged when an anchor is absent, so
# a drifted anchor silently disables that diagnostic.
# NOTE(review): the injected code prints `response.text` from the token endpoint;
# that body presumably contains the exchanged access token. Confirm this output is
# acceptable in the job log.
| - name: Create debugging patch script | |
| run: | | |
| cat > patch_for_debugging.py << 'EOF' | |
| #!/usr/bin/env python3 | |
| def patch_code(): | |
| with open('src/databricks/sql/auth/token_federation.py', 'r') as f: | |
| content = f.read() | |
| # Add token debugging | |
| modified = content.replace( | |
| 'def _exchange_token(self, token, force_refresh=False):', | |
| 'def _exchange_token(self, token, force_refresh=False):\n # Debug token info\n import jwt\n try:\n decoded = jwt.decode(token, options={"verify_signature": False})\n print(f"Token issuer: {decoded.get(\'iss\')}")\n print(f"Token subject: {decoded.get(\'sub\')}")\n print(f"Token audience: {decoded.get(\'aud\') if isinstance(decoded.get(\'aud\'), str) else decoded.get(\'aud\', [])[0] if decoded.get(\'aud\') else \'\'}")\n except Exception as e:\n print(f"Unable to decode token: {str(e)}")' | |
| ) | |
| # Add verbose request debugging | |
| modified = modified.replace( | |
| 'try:\n # Make the token exchange request', | |
| 'try:\n import urllib.parse\n # Debug full request\n print(f"Connecting to Databricks at {self.host}")\n print(f"Token endpoint: {self.token_endpoint}")\n print(f"Request parameters: {urllib.parse.urlencode(params)}")\n print(f"Request headers: {headers}")\n # Make the token exchange request' | |
| ) | |
| # Add verbose response debugging | |
| modified = modified.replace( | |
| 'response = requests.post(self.token_endpoint, data=params, headers=headers)', | |
| 'response = requests.post(self.token_endpoint, data=params, headers=headers)\n print(f"Response status: {response.status_code}")\n print(f"Response headers: {dict(response.headers)}")\n print(f"Response body: {response.text}")' | |
| ) | |
| # Improve error handling | |
| modified = modified.replace( | |
| 'except RequestException as e:', | |
| 'except RequestException as e:\n print(f"Failed to perform token exchange: {str(e)}")\n if hasattr(e, "response") and e.response:\n print(f"Error response status: {e.response.status_code}")\n print(f"Error response headers: {dict(e.response.headers)}")\n print(f"Error response text: {e.response.text}")' | |
| ) | |
| with open('src/databricks/sql/auth/token_federation.py', 'w') as f: | |
| f.write(modified) | |
| if __name__ == "__main__": | |
| patch_code() | |
| EOF | |
| chmod +x patch_for_debugging.py | |
# PyJWT is installed because the injected debug code does `import jwt`.
| - name: Install PyJWT for token debugging | |
| run: pip install pyjwt | |
# Runs the helper against the checked-out source tree.
| - name: Apply debugging patches to token_federation.py | |
| run: python patch_for_debugging.py | |
# Writes a second helper (patch_for_audience_fix.py). Its patch_code() replaces the
# _exchange_token signature line in token_federation.py with the signature followed
# by code that decodes the token without signature verification and prints a warning
# when its `aud` differs from "https://github.com/databricks". The injected code only
# reports; it does not modify the token.
# NOTE(review): this replacement text spells line breaks as `\\n`, whereas the
# debugging patch uses `\n`. Inside the quoted heredoc that looks like it yields a
# literal backslash-n in the patched file rather than a newline; confirm against the
# raw workflow file.
# NOTE(review): this anchors on the same signature line as the debugging patch, so
# when both are applied this insertion lands between the signature and the earlier
# debug code.
| - name: Create audience fix patch script | |
| run: | | |
| cat > patch_for_audience_fix.py << 'EOF' | |
| #!/usr/bin/env python3 | |
| def patch_code(): | |
| with open('src/databricks/sql/auth/token_federation.py', 'r') as f: | |
| content = f.read() | |
| # Fix audience handling | |
| modified = content.replace( | |
| 'def _exchange_token(self, token, force_refresh=False):', | |
| 'def _exchange_token(self, token, force_refresh=False):\\n # Additional handling for different audience formats\\n import jwt\\n try:\\n # Try both standard and alternative audience formats\\n audience_tried = False\\n \\n def try_with_audience(token, audience):\\n nonlocal audience_tried\\n if audience_tried:\\n return None\\n \\n audience_tried = True\\n decoded = jwt.decode(token, options={\"verify_signature\": False})\\n aud = decoded.get(\"aud\")\\n \\n # Check if aud is a list and convert to string if needed\\n if isinstance(aud, list) and len(aud) > 0:\\n aud = aud[0]\\n \\n # Print audience for debugging\\n print(f\"Original token audience: {aud}\")\\n \\n if aud != audience:\\n print(f\"WARNING: Token audience \'{aud}\' doesn\'t match expected audience \'{audience}\'\")\\n # We won\'t modify the token as that would invalidate the signature\\n \\n return None\\n \\n # We\'re just collecting debugging info, not modifying the token\\n try_with_audience(token, \"https://github.com/databricks\")\\n \\n except Exception as e:\\n print(f\"Audience debug error: {str(e)}\")' | |
| ) | |
| with open('src/databricks/sql/auth/token_federation.py', 'w') as f: | |
| f.write(modified) | |
| if __name__ == "__main__": | |
| patch_code() | |
| EOF | |
| chmod +x patch_for_audience_fix.py | |
# Runs the audience helper against the checked-out source tree.
| - name: Apply audience fix patches | |
| run: python patch_for_audience_fix.py | |
# Requests an OIDC ID token for the audience 'https://github.com/databricks',
# registers it as a secret with core.setSecret, and exposes it as the step output
# `token` (later steps read it as steps.get-id-token.outputs.token).
| - name: Get GitHub OIDC token | |
| id: get-id-token | |
| uses: actions/github-script@v7 | |
| with: | |
| script: | | |
| const token = await core.getIDToken('https://github.com/databricks') | |
| core.setSecret(token) | |
| core.setOutput('token', token) | |
# Decodes the payload segment of the OIDC token and prints selected claims,
# followed by the full claim set. Purely diagnostic: decoding failures are
# reported but do not fail the step.
- name: Decode and display OIDC token claims
  env:
    OIDC_TOKEN: ${{ steps.get-id-token.outputs.token }}
  run: |
    echo "Decoding GitHub OIDC token claims..."
    # The script is fed through a quoted heredoc so the shell passes it verbatim:
    # Python's own quotes survive, and the token is read from the environment
    # instead of relying on shell expansion inside the program text.
    python - <<'PY'
    import base64
    import json
    import os

    token = os.environ.get("OIDC_TOKEN", "")
    # Parse the token
    try:
        header, payload, signature = token.split(".")
        # JWT segments are unpadded base64url; restore padding before decoding.
        payload += "=" * (-len(payload) % 4)
        claims = json.loads(base64.urlsafe_b64decode(payload).decode("utf-8"))
        # Print important claims
        print("\n=== GITHUB OIDC TOKEN CLAIMS ===")
        print(f"Issuer (iss): {claims.get('iss')}")
        print(f"Subject (sub): {claims.get('sub')}")
        print(f"Audience (aud): {claims.get('aud')}")
        print(f"Repository: {claims.get('repository')}")
        print(f"Repository owner: {claims.get('repository_owner')}")
        print(f"Event name: {claims.get('event_name')}")
        print(f"Ref: {claims.get('ref')}")
        print(f"Workflow ref: {claims.get('workflow_ref')}")
        print("\n=== FULL CLAIMS ===")
        print(json.dumps(claims, indent=2))
        print("===========================\n")
    except Exception as e:
        print(f"Failed to decode token: {str(e)}")
    PY
# Best-effort diagnostic: performs the token exchange directly with curl and
# reports the HTTP status and a redacted view of the response. The step always
# exits 0 so a failed exchange here does not fail the workflow.
- name: Debug token exchange with curl
  env:
    DATABRICKS_HOST_FOR_TF: ${{ github.event_name == 'workflow_dispatch' && inputs.databricks_host || secrets.DATABRICKS_HOST_FOR_TF }}
    IDENTITY_FEDERATION_CLIENT_ID: ${{ github.event_name == 'workflow_dispatch' && inputs.identity_federation_client_id || secrets.IDENTITY_FEDERATION_CLIENT_ID }}
    OIDC_TOKEN: ${{ steps.get-id-token.outputs.token }}
  run: |
    echo "Attempting direct token exchange with curl..."
    echo "Host: $DATABRICKS_HOST_FOR_TF"
    echo "Client ID: $IDENTITY_FEDERATION_CLIENT_ID"

    # Debug token claims before making the request. The token is read from the
    # environment so it is never interpolated into program text.
    echo "Token claims:"
    python3 - <<'PY' || echo "Unable to decode token claims"
    import base64
    import json
    import os

    token = os.environ.get("OIDC_TOKEN", "")
    parts = token.split(".")
    if len(parts) >= 2:
        # JWT segments are unpadded base64url.
        segment = parts[1] + "=" * (-len(parts[1]) % 4)
        claims = json.loads(base64.urlsafe_b64decode(segment).decode("utf-8"))
        print(f"Token issuer: {claims.get('iss', 'unknown')}")
        print(f"Token subject: {claims.get('sub', 'unknown')}")
        print(f"Token audience: {claims.get('aud', 'unknown')}")
    else:
        print("Invalid token format")
    PY

    # Print request details (except the token)
    echo "Request URL: https://$DATABRICKS_HOST_FOR_TF/oidc/v1/token"
    echo "Request data: client_id=$IDENTITY_FEDERATION_CLIENT_ID&subject_token=REDACTED&subject_token_type=urn:ietf:params:oauth:token-type:jwt&grant_type=urn:ietf:params:oauth:grant-type:token-exchange&scope=sql"

    # Let curl build and URL-encode the form body. This avoids `eval` on a string
    # containing `&` (which the shell would treat as a background operator) and
    # keeps the token out of any re-parsed shell text.
    echo "Sending request..."
    response_file=$(mktemp)
    status_code=$(curl -s -o "$response_file" -w '%{http_code}' \
      -X POST "https://$DATABRICKS_HOST_FOR_TF/oidc/v1/token" \
      --data-urlencode "client_id=$IDENTITY_FEDERATION_CLIENT_ID" \
      --data-urlencode "subject_token=$OIDC_TOKEN" \
      --data-urlencode "subject_token_type=urn:ietf:params:oauth:token-type:jwt" \
      --data-urlencode "grant_type=urn:ietf:params:oauth:grant-type:token-exchange" \
      --data-urlencode "scope=sql" \
      -H "Content-Type: application/x-www-form-urlencoded" \
      -H "Accept: application/json") || status_code="unknown"
    echo "HTTP Status Code: $status_code"

    # Pretty-print the JSON body with token fields redacted: a successful response
    # carries a usable access token, which must not be written to the job log.
    if [ -s "$response_file" ]; then
      echo "Response body (formatted):"
      python3 - "$response_file" <<'PY' || echo "Unable to format response body"
    import json
    import sys

    with open(sys.argv[1], encoding="utf-8") as fh:
        raw = fh.read()
    try:
        body = json.loads(raw)
    except ValueError:
        # Not JSON (e.g. an HTML error page): show it as-is.
        print(raw)
    else:
        if isinstance(body, dict):
            for key in ("access_token", "refresh_token"):
                if key in body:
                    body[key] = "REDACTED"
        print(json.dumps(body, indent=2))
    PY
    fi
    rm -f "$response_file"

    # Don't fail the workflow if curl fails
    exit 0
# Writes test_github_token_federation.py through a quoted heredoc ('EOF'), so the
# shell performs no expansion on the script text. The rows below are the script's
# shebang, module docstring and imports; its functions follow.
| - name: Create test script | |
| run: | | |
| cat > test_github_token_federation.py << 'EOF' | |
| #!/usr/bin/env python3 | |
| """ | |
| Test script for Databricks SQL token federation with GitHub Actions OIDC tokens. | |
| This script demonstrates how to use the Databricks SQL connector with token federation | |
| using a GitHub Actions OIDC token. It connects to a Databricks SQL warehouse, | |
| runs a simple query, and shows the connected user. | |
| """ | |
| import os | |
| import sys | |
| import json | |
| import base64 | |
| import requests | |
| from databricks import sql | |
| import time | |
def decode_jwt(token):
    """Decode and return the claims from a JWT token.

    The signature is NOT verified; this is for diagnostics only.

    Args:
        token: Compact JWT string of the form ``header.payload.signature``.

    Returns:
        The payload parsed from JSON (normally a dict of claims), or ``None``
        if the token is malformed or cannot be decoded. The failure reason is
        printed rather than raised.
    """
    try:
        parts = token.split(".")
        if len(parts) != 3:
            raise ValueError("Invalid JWT format")
        payload = parts[1]
        # JWT segments are unpadded base64url; add only the padding that is
        # actually missing (0-3 characters).
        payload += "=" * (-len(payload) % 4)
        # urlsafe_b64decode maps "-" and "_" correctly. The standard decoder
        # silently drops them, which corrupts the payload.
        decoded = base64.urlsafe_b64decode(payload)
        return json.loads(decoded)
    except Exception as e:
        print(f"Failed to decode token: {str(e)}")
        return None
def test_direct_token_exchange(host, token, client_id, audience=None):
    """Directly test token exchange with the Databricks API.

    Posts a token-exchange form to ``https://{host}/oidc/v1/token`` and prints
    diagnostics along the way. Credentials are redacted in everything printed.

    Args:
        host: Databricks workspace hostname (no scheme).
        token: Subject JWT to exchange (the GitHub OIDC token).
        client_id: Identity federation client ID sent as ``client_id``.
        audience: Optional expected audience. It is only compared with the
            token's ``aud`` claim to print a warning; it is not sent in the
            request.

    Returns:
        The ``access_token`` from a 200 JSON response, or ``None`` on any
        failure (non-200 status, unparsable body, or request error).
    """
    import logging

    # Fields of the token response that must never be written to the log.
    sensitive_keys = ("access_token", "refresh_token", "id_token")

    try:
        url = f"https://{host}/oidc/v1/token"
        data = {
            "client_id": client_id,
            "subject_token": token,
            "subject_token_type": "urn:ietf:params:oauth:token-type:jwt",
            "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange",
            "scope": "sql",
            "return_original_token_if_authenticated": "true",
        }
        headers = {
            "Content-Type": "application/x-www-form-urlencoded",
            "Accept": "application/json",
        }

        print(f"Testing direct token exchange with {url}")
        # Show the request shape without echoing the subject token itself.
        print(f"Request parameters: {dict(data, subject_token='***')}")

        # Add debugging info
        claims = decode_jwt(token)
        if claims:
            print(f"Token issuer: {claims.get('iss', 'unknown')}")
            print(f"Token subject: {claims.get('sub', 'unknown')}")
            print(f"Token audience: {claims.get('aud', 'unknown')}")

            # `aud` may be a single string or a list of strings.
            token_aud = claims.get("aud")
            accepted = token_aud if isinstance(token_aud, list) else [token_aud]
            if audience and audience not in accepted:
                print("WARNING: Expected audience and token audience don't match")
                print(f"Expected: {audience}")
                print(f"Actual: {token_aud}")

        # Connection-level logging only. The process-wide http.client debuglevel
        # is deliberately not enabled: it echoes raw request headers and bodies,
        # including those of later connector calls in this process.
        logging.basicConfig()
        logging.getLogger("urllib3").setLevel(logging.DEBUG)

        # A timeout keeps a hung endpoint from stalling the CI job indefinitely.
        response = requests.post(url, data=data, headers=headers, timeout=30)
        print(f"Status code: {response.status_code}")
        print(f"Response headers: {dict(response.headers)}")

        try:
            body = json.loads(response.text)
        except ValueError:
            body = None

        if isinstance(body, dict):
            printable = {
                key: "***" if key in sensitive_keys else value
                for key, value in body.items()
            }
            print(f"Response content: {json.dumps(printable)}")
        else:
            # Not a JSON object (e.g. an HTML error page): safe to show verbatim.
            print(f"Response content: {response.text}")

        if response.status_code != 200:
            return None
        if not isinstance(body, dict):
            print("Failed to parse response JSON")
            return None
        return body.get("access_token")
    except Exception as e:
        print(f"Direct token exchange failed: {str(e)}")
        return None
def main():
    """Run the token federation smoke test end to end.

    Reads ``OIDC_TOKEN``, ``DATABRICKS_HOST_FOR_TF``,
    ``DATABRICKS_HTTP_PATH_FOR_TF`` and ``IDENTITY_FEDERATION_CLIENT_ID`` from
    the environment, prints the OIDC token's claims, attempts a direct token
    exchange for diagnostics, then connects through the SQL connector with
    ``auth_type="token-federation"`` and runs two trivial queries.

    Returns:
        ``True`` when the connector test succeeds.

    Exits the process with status 1 when the token or connection parameters
    are missing, or when the connector test raises. A failed direct exchange
    is reported but is not fatal.
    """
    # Get GitHub OIDC token
    github_token = os.environ.get("OIDC_TOKEN")
    if not github_token:
        print("GitHub OIDC token not available")
        sys.exit(1)

    # Get Databricks connection parameters
    host = os.environ.get("DATABRICKS_HOST_FOR_TF")
    http_path = os.environ.get("DATABRICKS_HTTP_PATH_FOR_TF")
    identity_federation_client_id = os.environ.get("IDENTITY_FEDERATION_CLIENT_ID")
    if not host or not http_path:
        print("Missing Databricks connection parameters")
        sys.exit(1)

    claims = decode_jwt(github_token)
    if claims:
        print("\n=== GitHub OIDC Token Claims ===")
        print(f"Token issuer: {claims.get('iss')}")
        print(f"Token subject: {claims.get('sub')}")
        print(f"Token audience: {claims.get('aud')}")
        print(f"Token expiration: {claims.get('exp', 'unknown')}")
        print(f"Repository: {claims.get('repository', 'unknown')}")
        print(f"Workflow ref: {claims.get('workflow_ref', 'unknown')}")
        print(f"Event name: {claims.get('event_name', 'unknown')}")
        print("===============================\n")

    # Try token exchange with several possible audience values.
    # NOTE(review): test_direct_token_exchange only uses `audience` for its
    # mismatch warning and does not send it, so each attempt issues the same
    # request. Confirm whether the retries are intended.
    audience_values = [
        "https://github.com/databricks",  # Standard audience for GitHub tokens
        "https://github.com",  # Alternative audience
        None,  # No audience
    ]

    # Direct token exchange test
    access_token = None
    for audience in audience_values:
        print(f"\n=== Testing Direct Token Exchange (audience={audience}) ===")
        result = test_direct_token_exchange(
            host, github_token, identity_federation_client_id, audience
        )
        if result:
            print("Direct token exchange successful!")
            access_token = result
            token_claims = decode_jwt(result)
            if token_claims:
                print(f"Databricks token subject: {token_claims.get('sub', 'unknown')}")
            break
        print(f"Token exchange failed with audience={audience}")
        # Add a small delay between attempts
        time.sleep(1)

    if not access_token:
        print("All token exchange attempts failed")
    print("=====================================\n")

    try:
        # Connect to Databricks using token federation
        print("\n=== Testing Connection via Connector ===")
        print(f"Connecting to Databricks at {host}{http_path}")
        print(f"Using client ID: {identity_federation_client_id}")
        connection_params = {
            "server_hostname": host,
            "http_path": http_path,
            "access_token": github_token,
            "auth_type": "token-federation",
            "identity_federation_client_id": identity_federation_client_id,
        }
        print("Connection parameters:")
        # Mask the token before echoing the parameters.
        print(json.dumps({k: v if k != 'access_token' else '***' for k, v in connection_params.items()}, indent=2))

        with sql.connect(**connection_params) as connection:
            print("Connection established successfully")
            # The cursor is closed when the inner block exits, even if a query raises.
            with connection.cursor() as cursor:
                # Execute a simple query
                cursor.execute("SELECT 1 + 1 as result")
                result = cursor.fetchall()
                print(f"Query result: {result[0][0]}")

                # Show current user
                cursor.execute("SELECT current_user() as user")
                result = cursor.fetchall()
                print(f"Connected as user: {result[0][0]}")
            print("Token federation test successful!")
            return True
    except Exception as e:
        print(f"Error connecting to Databricks: {str(e)}")
        print("===================================\n")
        sys.exit(1)


if __name__ == "__main__":
    main()
| EOF | |
| chmod +x test_github_token_federation.py | |
# Runs the generated script. Host, HTTP path and client ID come from the
# workflow_dispatch inputs on manual runs and from `secrets.*` otherwise;
# OIDC_TOKEN is the `token` output of the get-id-token step.
| - name: Test token federation with GitHub OIDC token | |
| env: | |
| DATABRICKS_HOST_FOR_TF: ${{ github.event_name == 'workflow_dispatch' && inputs.databricks_host || secrets.DATABRICKS_HOST_FOR_TF }} | |
| DATABRICKS_HTTP_PATH_FOR_TF: ${{ github.event_name == 'workflow_dispatch' && inputs.databricks_http_path || secrets.DATABRICKS_HTTP_PATH_FOR_TF }} | |
| IDENTITY_FEDERATION_CLIENT_ID: ${{ github.event_name == 'workflow_dispatch' && inputs.identity_federation_client_id || secrets.IDENTITY_FEDERATION_CLIENT_ID }} | |
| OIDC_TOKEN: ${{ steps.get-id-token.outputs.token }} | |
| run: | | |
| python test_github_token_federation.py | |