diff --git a/CHANGELOG.md b/CHANGELOG.md index 4b5afb99f..a899dafaa 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -385,6 +385,28 @@ * Added unit test coverage for local_eval_sets_manager.py ([174afb3](https://github.com/google/adk-python/commit/174afb3975bdc7e5f10c26f3eebb17d2efa0dd59)) * Extract common options for `adk web` and `adk api_server` ([01965bd](https://github.com/google/adk-python/commit/01965bdd74a9dbdb0ce91a924db8dee5961478b8)) +## [Unreleased] + +### Added +- **OAuth2 Client Credentials Flow Support**: Added comprehensive support for OAuth2 client credentials flow across ADK authentication infrastructure + - Enhanced `OAuth2CredentialExchanger` to detect and handle client credentials flow automatically + - Updated `OAuth2CredentialRefresher` to properly refresh client credentials tokens (by re-exchange) + - Improved `create_oauth2_session` utility to support client credentials session creation + - Enhanced `OAuthGrantType.from_flow()` method with better flow detection and documentation + - MCPToolset now supports OAuth2 client credentials authentication out-of-the-box + - Added comprehensive unit tests for client credentials functionality + - Added example usage in MCP stdio server agent sample + +### Changed +- `OAuth2CredentialExchanger.exchange()` now supports both authorization code and client credentials flows +- `OAuth2CredentialRefresher.refresh()` automatically detects grant type and uses appropriate refresh strategy +- `OAuthGrantType.from_flow()` return type changed to `Optional[OAuthGrantType]` for better type safety + +### Technical Details +- Client credentials flow prioritized over authorization code when both are configured +- Automatic token exchange without user interaction for machine-to-machine authentication +- Proper error handling and fallback for unsupported or misconfigured flows + ## 1.1.1 ### Features diff --git a/contributing/samples/mcp_oauth2_client_credentials_agent/README.md b/contributing/samples/mcp_oauth2_client_credentials_agent/README.md new file mode 100644 index 000000000..13e9c76e9 --- /dev/null +++ b/contributing/samples/mcp_oauth2_client_credentials_agent/README.md @@ -0,0 +1,465 @@ +# OAuth2 Client Credentials Flow with Automatic Discovery Sample + +## Overview + +This sample demonstrates the **OAuth2 client credentials authentication flow** with **automatic OAuth discovery** for MCP (Model Context Protocol) servers. It showcases enterprise-grade authentication that "just works" out of the box while providing flexibility for custom configurations. + +## Key Features + +- šŸš€ **Automatic OAuth Discovery**: Zero-configuration OAuth2 setup for HTTP-based MCP connections +- šŸ”§ **RFC 8414 Compliance**: Two-stage OAuth discovery following industry standards +- šŸ” **Complete Client Credentials Flow**: Full token exchange using the authlib library +- šŸ“ **Production-Ready**: Appropriate logging and error handling +- šŸŽÆ **Multiple Scenarios**: From simple automatic setup to advanced custom configurations + +## How It Works + +The ADK MCPToolset automatically: + +1. **Extracts base URL** from MCP connection parameters (e.g., `http://localhost:9204/mcp/` → `http://localhost:9204`) +2. **Enables OAuth discovery** for HTTP-based connections (StreamableHTTP, SSE) +3. **Discovers OAuth endpoints** via RFC 8414 two-stage process: + - Query `.well-known/oauth-protected-resource` to find authorization server + - Query authorization server's `.well-known/oauth-authorization-server` for token endpoint +4. **Exchanges client credentials** for access tokens using the discovered token endpoint +5. **Authenticates all MCP requests** with the obtained access tokens + +## Sample Scenarios + +### 1. Automatic OAuth Discovery (Simplest Case) + +The most common usage - everything happens automatically: + +```python +MCPToolset( + connection_params=StreamableHTTPConnectionParams( + url='http://localhost:9204/mcp/', + ), + auth_credential=create_oauth2_credential( + client_id='your_client_id', + client_secret='your_client_secret' + ), + # Optional: Define scopes (will be used during discovery) + auth_scheme=create_oauth2_scheme( + token_url="", # Empty - will be automatically discovered + scopes={"api:read": "Read access", "api:write": "Write access"} + ), + # No auth_discovery parameter needed - automatic discovery enabled! +) +``` + +### 2. Custom OAuth Discovery Configuration + +Override default discovery behavior: + +```python +MCPToolset( + connection_params=StreamableHTTPConnectionParams( + url='http://localhost:9204/mcp/', + ), + auth_credential=create_oauth2_credential( + client_id='your_client_id', + client_secret='your_client_secret' + ), + # Custom OAuth discovery configuration + auth_discovery=MCPAuthDiscovery( + base_url='http://auth-server.example.com:9205', # Different auth server + timeout=15.0, # Custom timeout + enabled=True + ), +) +``` + +### 3. No Auth Scheme (Minimal Configuration) + +Let discovery create the complete OAuth scheme: + +```python +MCPToolset( + connection_params=StreamableHTTPConnectionParams( + url='http://localhost:9204/mcp/', + ), + auth_credential=create_oauth2_credential( + client_id='your_client_id', + client_secret='your_client_secret' + ), + # No auth_scheme - discovery will create one + # No auth_discovery - automatic discovery enabled +) +``` + +### 4. Disabled OAuth Discovery (Manual Configuration) + +Traditional OAuth2 setup without automatic discovery: + +```python +MCPToolset( + connection_params=StreamableHTTPConnectionParams( + url='http://localhost:9204/mcp/', + ), + auth_credential=create_oauth2_credential( + client_id='your_client_id', + client_secret='your_client_secret' + ), + # Complete OAuth2 scheme with known token endpoint + auth_scheme=create_oauth2_scheme( + token_url="http://localhost:9204/token", # Known endpoint + scopes={"api:read": "Read access"} + ), + # Explicitly disable OAuth discovery + auth_discovery=MCPAuthDiscovery( + base_url="http://localhost:9204", + enabled=False + ), +) +``` + +### 5. Multiple MCP Servers + +Access multiple secured MCP servers with different configurations: + +```python +tools=[ + # Server 1: Automatic discovery + MCPToolset( + connection_params=StreamableHTTPConnectionParams( + url='http://server1.example.com:9204/mcp/', + ), + auth_credential=create_oauth2_credential( + client_id='server1_client', + client_secret='server1_secret' + ), + ), + + # Server 2: Custom discovery server + MCPToolset( + connection_params=StreamableHTTPConnectionParams( + url='http://server2.example.com:8080/api/mcp/', + ), + auth_credential=create_oauth2_credential( + client_id='server2_client', + client_secret='server2_secret' + ), + auth_discovery=MCPAuthDiscovery( + base_url='http://auth.server2.example.com:9000', + ), + ), +] +``` + +### 6. Self-Signed Certificates (SSL Verification Disabled) + +For development environments using self-signed SSL certificates: + +```python +MCPToolset( + connection_params=StreamableHTTPConnectionParams( + url='https://localhost:9204/mcp/', # HTTPS with self-signed cert + ), + auth_credential=create_oauth2_credential( + client_id='your_client_id', + client_secret='your_client_secret' + ), + # Override just SSL verification - base_url auto-extracted from connection_params + auth_discovery=MCPAuthDiscovery( + verify_ssl=False, # Only override SSL verification + # base_url auto-extracted as "https://localhost:9204" from connection_params + ), +) +``` + +āš ļø **Security Warning**: Only disable SSL verification in development environments with self-signed certificates. Never disable SSL verification in production! + +### 7. Custom Settings Without Base URL Override + +Override multiple discovery settings while letting MCPToolset auto-extract the base_url: + +```python +MCPToolset( + connection_params=StreamableHTTPConnectionParams( + url='http://localhost:9204/mcp/', + ), + auth_credential=create_oauth2_credential( + client_id='your_client_id', + client_secret='your_client_secret' + ), + # Override multiple settings - base_url auto-extracted + auth_discovery=MCPAuthDiscovery( + timeout=15.0, # Custom timeout + verify_ssl=True, # Explicit SSL verification (default) + enabled=True # Explicit enabled (default) + # base_url auto-extracted as "http://localhost:9204" from connection_params + ), +) +``` + +## Key Benefits + +1. **Zero Configuration**: OAuth discovery works out-of-the-box for HTTP connections +2. **Smart Auto-Extraction**: Base URL automatically extracted from MCP connection parameters +3. **Override Only What You Need**: No need to duplicate base_url when overriding other settings +4. **Standards Compliant**: Follows RFC 8414 OAuth2 Authorization Server Metadata +5. **Production Ready**: Appropriate logging levels and comprehensive error handling +6. **Backwards Compatible**: Existing OAuth2 configurations continue to work +7. **Flexible**: From automatic discovery to complete manual control +8. **Secure**: Uses industry-standard OAuth2 client credentials flow + +## Prerequisites + +### OAuth2 Server Setup + +You need an OAuth2 server that supports: + +1. **RFC 8414 OAuth discovery endpoints**: + - `/.well-known/oauth-protected-resource` + - `/.well-known/oauth-authorization-server` + +2. **Client credentials grant type** (`client_credentials`) + +3. **client_secret_post** authentication method (credentials in form body) + +### Self-Signed SSL Certificates + +If you're using self-signed SSL certificates for development: + +1. **Generate self-signed certificate**: + ```bash + # Generate private key + openssl genrsa -out server.key 2048 + + # Generate certificate + openssl req -new -x509 -key server.key -out server.crt -days 365 \ + -subj "/C=US/ST=CA/L=SF/O=Dev/CN=localhost" + ``` + +2. **Configure MCPAuthDiscovery**: + ```python + auth_discovery=MCPAuthDiscovery( + base_url='https://localhost:9204', + verify_ssl=False, # Disable for self-signed certs + ) + ``` + +3. **Test manually**: + ```bash + # Test with curl (skip SSL verification) + curl -k https://localhost:9204/.well-known/oauth-protected-resource + ``` + +### Environment Variables + +Set your OAuth2 credentials: + +```bash +export OAUTH2_CLIENT_ID="your_client_id" +export OAUTH2_CLIENT_SECRET="your_client_secret" + +# For multi-server setup: +export SERVER1_CLIENT_ID="server1_client_id" +export SERVER1_CLIENT_SECRET="server1_client_secret" +export SERVER2_CLIENT_ID="server2_client_id" +export SERVER2_CLIENT_SECRET="server2_client_secret" +``` + +## Running the Sample + +1. **Start the Mock OAuth2 Server** (included in this sample): + + **For HTTP (simple testing):** + ```bash + python mock_oauth_server.py + ``` + + **For HTTPS with self-signed certificates:** + ```bash + # First generate self-signed certificates + openssl genrsa -out server.key 2048 + openssl req -new -x509 -key server.key -out server.crt -days 365 \ + -subj "/C=US/ST=CA/L=SF/O=Dev/CN=localhost" + + # Then start server with SSL + python mock_oauth_server.py --ssl-keyfile server.key --ssl-certfile server.crt + ``` + + This starts a test OAuth2 server with: + - RFC 8414 discovery endpoints + - Client credentials grant support + - Demo client credentials (see server output) + - Optional HTTPS support for testing SSL configurations + +2. **Set environment variables** with OAuth2 credentials: + ```bash + export OAUTH2_CLIENT_ID="demo_client_id" + export OAUTH2_CLIENT_SECRET="demo_client_secret" + ``` + +3. **Run the agent**: + ```bash + python -m google.adk.cli.chatbot --agent-module contributing.samples.mcp_oauth2_client_credentials_agent + ``` + +4. **Try different agents** by modifying the `root_agent` variable in `agent.py`: + - `automatic_discovery_agent` (default) + - `custom_discovery_agent` + - `discovery_only_agent` + - `manual_config_agent` + - `multi_server_agent` + - `self_signed_ssl_agent` + - `custom_settings_agent` + +## Testing with the Mock OAuth2 Server + +The included `mock_oauth_server.py` provides a complete OAuth2 server for testing: + +### Server Features + +- **RFC 8414 Discovery Endpoints**: + - `/.well-known/oauth-protected-resource` + - `/.well-known/oauth-authorization-server` +- **OAuth2 Token Endpoint**: `/token` +- **Token Validation**: `/validate` (for debugging) +- **Health Check**: `/health` + +### Demo Clients + +The mock server includes these demo clients: + +| Client ID | Client Secret | Scopes | +|-----------|---------------|--------| +| `demo_client_id` | `demo_client_secret` | `api:read`, `api:write` | +| `server1_client` | `server1_secret` | `api:read` | +| `server2_client` | `server2_secret` | `api:read`, `api:write` | + +### Testing the OAuth Flow Manually + +You can test the OAuth flow manually using curl: + +1. **Discover OAuth Configuration**: + ```bash + curl http://localhost:9204/.well-known/oauth-protected-resource + curl http://localhost:9204/.well-known/oauth-authorization-server + ``` + +2. **Request Access Token**: + ```bash + curl -X POST http://localhost:9204/token \ + -d "grant_type=client_credentials" \ + -d "client_id=demo_client_id" \ + -d "client_secret=demo_client_secret" \ + -d "scope=api:read api:write" + ``` + +3. **Validate Token**: + ```bash + curl -H "Authorization: Bearer YOUR_TOKEN" \ + http://localhost:9204/validate + ``` + +## Debugging OAuth Discovery + +Enable debug logging to trace the OAuth discovery process: + +```python +import logging +logging.getLogger("google_adk").setLevel(logging.DEBUG) +``` + +You'll see detailed logs like: +``` +DEBUG:google_adk: šŸš€ Starting OAuth discovery process +DEBUG:google_adk: šŸ” Attempting OAuth discovery at server root: http://localhost:9204 +DEBUG:google_adk: āœ… OAuth discovery successful - updating tokenUrl in existing scheme +DEBUG:google_adk: šŸ” Performing OAuth token exchange for session authentication +DEBUG:google_adk: āœ… Successfully obtained access token for session +``` + +## Troubleshooting + +### SSL Certificate Issues + +If you encounter SSL certificate verification errors with self-signed certificates: + +``` +ssl.SSLCertVerificationError: [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: self-signed certificate +``` + +**Solution**: Set `verify_ssl=False` in your MCPAuthDiscovery configuration: + +```python +auth_discovery=MCPAuthDiscovery( + verify_ssl=False # Disables SSL verification for self-signed certificates +) +``` + +### OAuth2Session Constructor Errors + +If you encounter OAuth2Session constructor conflicts: + +``` +authlib.oauth2.client.OAuth2Client.__init__() got multiple values for keyword argument 'session' +``` + +This issue has been resolved in the latest version. The ADK now properly sets SSL verification on the OAuth2Session object without constructor conflicts. + +### Authentication Failures + +If OAuth token exchange fails: + +1. **Check client credentials**: Ensure your `client_id` and `client_secret` are correct +2. **Verify discovery endpoints**: Confirm your server supports RFC 8414 discovery +3. **Test manually**: Use curl to test OAuth discovery and token exchange +4. **Enable debug logging**: Check detailed OAuth flow logs for specific errors + +### Testing Your Setup + +You can test the complete OAuth flow manually: + +1. **Test Discovery Endpoints**: + ```bash + # Test discovery (use -k for self-signed certificates) + curl -k https://localhost:9204/.well-known/oauth-protected-resource + curl -k https://localhost:9204/.well-known/oauth-authorization-server + ``` + +2. **Test Token Exchange**: + ```bash + # Test client credentials flow (use -k for self-signed certificates) + curl -k -X POST https://localhost:9204/token \ + -d "grant_type=client_credentials" \ + -d "client_id=demo_client_id" \ + -d "client_secret=demo_client_secret" \ + -d "scope=api:read api:write" + ``` + +3. **Verify Token**: + ```bash + # Test token validation (replace YOUR_TOKEN with actual token) + curl -k -H "Authorization: Bearer YOUR_TOKEN" \ + https://localhost:9204/validate + ``` + +## Benefits + +1. **Zero Configuration**: OAuth discovery works out-of-the-box for HTTP connections +2. **Standards Compliant**: Follows RFC 8414 OAuth2 Authorization Server Metadata +3. **Production Ready**: Appropriate logging levels and comprehensive error handling +4. **Backwards Compatible**: Existing OAuth2 configurations continue to work +5. **Flexible**: From automatic discovery to complete manual control +6. **Secure**: Uses industry-standard OAuth2 client credentials flow + +## Common Use Cases + +- **Enterprise API Integration**: Secure access to internal MCP services +- **Multi-tenant Applications**: Different OAuth configurations per tenant +- **Development vs Production**: Automatic discovery in dev, manual config in prod +- **Microservices Architecture**: Multiple MCP servers with centralized auth +- **Third-party Integrations**: Secure access to external MCP providers + +## Next Steps + +- Explore the different agent configurations in `agent.py` +- Try connecting to your own OAuth2-enabled MCP server +- Experiment with custom discovery configurations +- Implement your own MCP tools with OAuth2 authentication + +This sample demonstrates the power of **convention over configuration** while maintaining full flexibility for advanced use cases. The OAuth2 client credentials flow with automatic discovery makes enterprise authentication simple and reliable! šŸš€ \ No newline at end of file diff --git a/contributing/samples/mcp_oauth2_client_credentials_agent/__init__.py b/contributing/samples/mcp_oauth2_client_credentials_agent/__init__.py new file mode 100644 index 000000000..b6b4adcdb --- /dev/null +++ b/contributing/samples/mcp_oauth2_client_credentials_agent/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from . import agent \ No newline at end of file diff --git a/contributing/samples/mcp_oauth2_client_credentials_agent/agent.py b/contributing/samples/mcp_oauth2_client_credentials_agent/agent.py new file mode 100644 index 000000000..f67e18661 --- /dev/null +++ b/contributing/samples/mcp_oauth2_client_credentials_agent/agent.py @@ -0,0 +1,329 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +OAuth2 Client Credentials Flow with Automatic Discovery Sample + +This sample demonstrates the OAuth2 client credentials authentication flow +with automatic OAuth discovery for MCP servers. It shows multiple scenarios +from basic automatic discovery to custom configurations. +""" + +import os +from fastapi.openapi.models import OAuth2, OAuthFlows, OAuthFlowClientCredentials + +from google.adk.agents.llm_agent import LlmAgent +from google.adk.auth.auth_credential import AuthCredential, AuthCredentialTypes, OAuth2Auth +from google.adk.tools.mcp_tool.mcp_auth_discovery import MCPAuthDiscovery +from google.adk.tools.mcp_tool.mcp_session_manager import StreamableHTTPConnectionParams +from google.adk.tools.mcp_tool.mcp_toolset import MCPToolset + + +def create_oauth2_credential(client_id: str, client_secret: str) -> AuthCredential: + """Helper function to create OAuth2 credentials.""" + return AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + oauth2=OAuth2Auth( + client_id=client_id, + client_secret=client_secret + ) + ) + + +def create_oauth2_scheme(token_url: str, scopes: dict[str, str]) -> OAuth2: + """Helper function to create OAuth2 auth scheme.""" + return OAuth2( + flows=OAuthFlows( + clientCredentials=OAuthFlowClientCredentials( + tokenUrl=token_url, + scopes=scopes + ) + ) + ) + + +# ============================================================================= +# Scenario 1: Automatic OAuth Discovery (Simplest Case) +# ============================================================================= + +# This is the simplest way to use OAuth2 client credentials with MCP. +# MCPToolset automatically: +# 1. Extracts base URL from the MCP connection (http://localhost:9204) +# 2. Enables OAuth discovery for HTTP-based connections +# 3. Discovers OAuth endpoints via RFC 8414 two-stage process +# 4. Exchanges client credentials for access tokens +# 5. Authenticates all MCP requests + +automatic_discovery_agent = LlmAgent( + model='gemini-2.0-flash', + name='oauth2_automatic_agent', + instruction=""" +You are an assistant that can access secured MCP tools using OAuth2 authentication. +The OAuth2 discovery and authentication happens automatically behind the scenes. +Help users by calling the available MCP tools as needed. + """, + tools=[ + MCPToolset( + # MCP server connection + connection_params=StreamableHTTPConnectionParams( + url='http://localhost:9204/mcp/', + ), + # OAuth2 credentials (client_id and client_secret) + auth_credential=create_oauth2_credential( + client_id=os.getenv('OAUTH2_CLIENT_ID', 'demo_client_id'), + client_secret=os.getenv('OAUTH2_CLIENT_SECRET', 'demo_client_secret') + ), + # Optional: Define scopes in auth_scheme (will be used during discovery) + auth_scheme=create_oauth2_scheme( + token_url="", # Empty - will be automatically discovered + scopes={ + "api:read": "Read access to API", + "api:write": "Write access to API" + } + ), + # No auth_discovery parameter needed - automatic discovery enabled! + ) + ], +) + + +# ============================================================================= +# Scenario 2: Custom OAuth Discovery Configuration +# ============================================================================= + +# Sometimes you need to customize the OAuth discovery process: +# - Use a different OAuth server than the MCP server +# - Adjust discovery timeout +# - Specify exact discovery endpoints + +custom_discovery_agent = LlmAgent( + model='gemini-2.0-flash', + name='oauth2_custom_discovery_agent', + instruction=""" +You are an assistant using custom OAuth2 discovery configuration. +This demonstrates how to override default discovery behavior. + """, + tools=[ + MCPToolset( + connection_params=StreamableHTTPConnectionParams( + url='http://localhost:9204/mcp/', + ), + auth_credential=create_oauth2_credential( + client_id=os.getenv('OAUTH2_CLIENT_ID', 'demo_client_id'), + client_secret=os.getenv('OAUTH2_CLIENT_SECRET', 'demo_client_secret') + ), + auth_scheme=create_oauth2_scheme( + token_url="", # Empty - will be discovered + scopes={"api:read": "Read access"} + ), + # Custom OAuth discovery configuration + auth_discovery=MCPAuthDiscovery( + base_url='http://auth-server.example.com:9205', # Custom auth server + timeout=15.0, # Custom timeout + enabled=True + ), + ) + ], +) + + +# ============================================================================= +# Scenario 3: No Auth Scheme (Discovery Creates Complete Scheme) +# ============================================================================= + +# When no auth_scheme is provided, OAuth discovery will create a complete +# OAuth2 scheme with discovered endpoints and default scopes + +discovery_only_agent = LlmAgent( + model='gemini-2.0-flash', + name='oauth2_discovery_only_agent', + instruction=""" +You are an assistant where OAuth2 discovery creates the complete auth scheme. +This shows the most minimal configuration possible. + """, + tools=[ + MCPToolset( + connection_params=StreamableHTTPConnectionParams( + url='http://localhost:9204/mcp/', + ), + auth_credential=create_oauth2_credential( + client_id=os.getenv('OAUTH2_CLIENT_ID', 'demo_client_id'), + client_secret=os.getenv('OAUTH2_CLIENT_SECRET', 'demo_client_secret') + ), + # No auth_scheme - discovery will create one + # No auth_discovery - automatic discovery enabled + ) + ], +) + + +# ============================================================================= +# Scenario 4: Disabled OAuth Discovery (Manual Configuration) +# ============================================================================= + +# Sometimes you want to disable automatic discovery and provide +# the complete OAuth configuration manually + +manual_config_agent = LlmAgent( + model='gemini-2.0-flash', + name='oauth2_manual_config_agent', + instruction=""" +You are an assistant using manual OAuth2 configuration without discovery. +This demonstrates traditional OAuth2 setup without automatic discovery. + """, + tools=[ + MCPToolset( + connection_params=StreamableHTTPConnectionParams( + url='http://localhost:9204/mcp/', + ), + auth_credential=create_oauth2_credential( + client_id=os.getenv('OAUTH2_CLIENT_ID', 'demo_client_id'), + client_secret=os.getenv('OAUTH2_CLIENT_SECRET', 'demo_client_secret') + ), + # Complete OAuth2 scheme with known token endpoint + auth_scheme=create_oauth2_scheme( + token_url="http://localhost:9204/token", # Known token endpoint + scopes={"api:read": "Read access", "api:write": "Write access"} + ), + # Explicitly disable OAuth discovery + auth_discovery=MCPAuthDiscovery( + base_url="http://localhost:9204", + enabled=False # Discovery disabled + ), + ) + ], +) + + +# ============================================================================= +# Scenario 5: Multiple MCP Servers with Different Auth Configurations +# ============================================================================= + +# This demonstrates using multiple MCP servers with different OAuth setups + +multi_server_agent = LlmAgent( + model='gemini-2.0-flash', + name='oauth2_multi_server_agent', + instruction=""" +You are an assistant that can access multiple secured MCP servers, +each with their own OAuth2 configuration. Choose the appropriate +toolset based on the user's needs. + """, + tools=[ + # Server 1: Automatic discovery + MCPToolset( + connection_params=StreamableHTTPConnectionParams( + url='http://server1.example.com:9204/mcp/', + ), + auth_credential=create_oauth2_credential( + client_id=os.getenv('SERVER1_CLIENT_ID', 'server1_client'), + client_secret=os.getenv('SERVER1_CLIENT_SECRET', 'server1_secret') + ), + tool_filter=['list_tools', 'get_info'], # Simple string list filter + ), + + # Server 2: Custom discovery server + MCPToolset( + connection_params=StreamableHTTPConnectionParams( + url='http://server2.example.com:8080/api/mcp/', + ), + auth_credential=create_oauth2_credential( + client_id=os.getenv('SERVER2_CLIENT_ID', 'server2_client'), + client_secret=os.getenv('SERVER2_CLIENT_SECRET', 'server2_secret') + ), + auth_discovery=MCPAuthDiscovery( + base_url='http://auth.server2.example.com:9000', + timeout=20.0 + ), + tool_filter=['query_data', 'update_records'], # Simple string list filter + ), + ], +) + + +# ============================================================================= +# Scenario 6: Self-Signed Certificates (SSL Verification Disabled) +# ============================================================================= + +# This scenario demonstrates using OAuth2 discovery with self-signed certificates +# by disabling SSL certificate verification - useful for development environments +# Note: base_url is auto-extracted from connection_params, only verify_ssl is overridden + +self_signed_ssl_agent = LlmAgent( + model='gemini-2.0-flash', + name='oauth2_self_signed_ssl_agent', + instruction=""" +You are an assistant that can access MCP servers using self-signed SSL certificates. +SSL certificate verification is disabled for development environments. + """, + tools=[ + MCPToolset( + connection_params=StreamableHTTPConnectionParams( + url='https://localhost:9204/mcp/', # HTTPS with self-signed cert + ), + auth_credential=create_oauth2_credential( + client_id=os.getenv('OAUTH2_CLIENT_ID', 'demo_client_id'), + client_secret=os.getenv('OAUTH2_CLIENT_SECRET', 'demo_client_secret') + ), + # Override just SSL verification - base_url auto-extracted from connection_params + auth_discovery=MCPAuthDiscovery( + verify_ssl=False, # Only override SSL verification + # base_url auto-extracted as "https://localhost:9204" from connection_params + ), + ) + ], +) + + +# ============================================================================= +# Scenario 7: Custom Settings Without Base URL Override +# ============================================================================= + +# This scenario shows overriding multiple discovery settings while letting +# MCPToolset auto-extract the base_url from connection parameters + +custom_settings_agent = LlmAgent( + model='gemini-2.0-flash', + name='oauth2_custom_settings_agent', + instruction=""" +You are an assistant with custom OAuth discovery settings. +The base URL is automatically extracted from the MCP connection. + """, + tools=[ + MCPToolset( + connection_params=StreamableHTTPConnectionParams( + url='http://localhost:9204/mcp/', + ), + auth_credential=create_oauth2_credential( + client_id=os.getenv('OAUTH2_CLIENT_ID', 'demo_client_id'), + client_secret=os.getenv('OAUTH2_CLIENT_SECRET', 'demo_client_secret') + ), + # Override multiple settings - base_url auto-extracted + auth_discovery=MCPAuthDiscovery( + timeout=15.0, # Custom timeout + verify_ssl=True, # Explicit SSL verification (default) + enabled=True # Explicit enabled (default) + # base_url auto-extracted as "http://localhost:9204" from connection_params + ), + ) + ], +) + + +# ============================================================================= +# Default agent for the sample (most commonly used scenario) +# ============================================================================= + +# The default agent demonstrates the most common usage: automatic OAuth discovery +root_agent = automatic_discovery_agent \ No newline at end of file diff --git a/contributing/samples/mcp_oauth2_client_credentials_agent/mock_oauth_server.py b/contributing/samples/mcp_oauth2_client_credentials_agent/mock_oauth_server.py new file mode 100644 index 000000000..052c036a3 --- /dev/null +++ b/contributing/samples/mcp_oauth2_client_credentials_agent/mock_oauth_server.py @@ -0,0 +1,296 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Mock OAuth2 Server for Testing OAuth2 Client Credentials Flow + +This is a minimal OAuth2 server implementation that supports: +- RFC 8414 OAuth discovery endpoints +- Client credentials grant type +- Testing the ADK OAuth2 functionality + +DO NOT use this in production - it's for demonstration purposes only! +""" + +import asyncio +import json +import logging +import time +from typing import Dict, Any + +from fastapi import FastAPI, Form, HTTPException, Request +from fastapi.responses import JSONResponse +import uvicorn + + +# Configure logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +app = FastAPI(title="Mock OAuth2 Server for ADK Testing") + +# Simple in-memory client store (DO NOT use in production!) +CLIENTS = { + "demo_client_id": { + "client_secret": "demo_client_secret", + "scopes": ["api:read", "api:write"] + }, + "server1_client": { + "client_secret": "server1_secret", + "scopes": ["api:read"] + }, + "server2_client": { + "client_secret": "server2_secret", + "scopes": ["api:read", "api:write"] + } +} + +# Simple token store (DO NOT use in production!) +TOKENS: Dict[str, Dict[str, Any]] = {} + + +@app.get("/.well-known/oauth-protected-resource") +async def oauth_protected_resource(request: Request): + """ + RFC 8414 OAuth Protected Resource Discovery + + This endpoint tells clients where to find the authorization server. + """ + base_url = f"{request.url.scheme}://{request.url.netloc}" + + logger.info(f"šŸ” OAuth protected resource discovery requested from {request.client.host if request.client else 'unknown'}") + + return JSONResponse({ + "authorization_servers": [base_url] + }) + + +@app.get("/.well-known/oauth-authorization-server") +async def oauth_authorization_server(request: Request): + """ + RFC 8414 OAuth Authorization Server Metadata + + This endpoint provides metadata about the authorization server capabilities. + """ + base_url = f"{request.url.scheme}://{request.url.netloc}" + + logger.info(f"šŸ” OAuth authorization server metadata requested from {request.client.host if request.client else 'unknown'}") + + return JSONResponse({ + "issuer": base_url, + "token_endpoint": f"{base_url}/token", + "grant_types_supported": ["client_credentials"], + "token_endpoint_auth_methods_supported": ["client_secret_post"], + "scopes_supported": ["api:read", "api:write"], + "response_types_supported": ["token"] + }) + + +@app.post("/token") +async def token_endpoint( + request: Request, + grant_type: str = Form(...), + client_id: str = Form(...), + client_secret: str = Form(...), + scope: str = Form(None) +): + """ + OAuth2 Token Endpoint - Client Credentials Grant + + This endpoint exchanges client credentials for access tokens. + """ + logger.info(f"šŸ” Token request from {request.client.host if request.client else 'unknown'}") + logger.info(f" Grant Type: {grant_type}") + logger.info(f" Client ID: {client_id}") + logger.info(f" Scopes: {scope or 'default'}") + + # Validate grant type + if grant_type != "client_credentials": + logger.warning(f"āŒ Unsupported grant type: {grant_type}") + raise HTTPException( + status_code=400, + detail={ + "error": "unsupported_grant_type", + "error_description": "Only client_credentials grant type is supported" + } + ) + + # Validate client credentials + if client_id not in CLIENTS: + logger.warning(f"āŒ Unknown client ID: {client_id}") + raise HTTPException( + status_code=401, + detail={ + "error": "invalid_client", + "error_description": "Unknown client ID" + } + ) + + client = CLIENTS[client_id] + if client["client_secret"] != client_secret: + logger.warning(f"āŒ Invalid client secret for client: {client_id}") + raise HTTPException( + status_code=401, + detail={ + "error": "invalid_client", + "error_description": "Invalid client secret" + } + ) + + # Validate scopes (if provided) + requested_scopes = scope.split() if scope else client["scopes"] + for requested_scope in requested_scopes: + if requested_scope not in client["scopes"]: + logger.warning(f"āŒ Unauthorized scope '{requested_scope}' for client: {client_id}") + raise HTTPException( + status_code=400, + detail={ + "error": "invalid_scope", + "error_description": f"Scope '{requested_scope}' not authorized for client" + } + ) + + # Generate access token (simple approach for demo) + access_token = f"demo_token_{client_id}_{int(time.time())}" + expires_in = 3600 # 1 hour + + # Store token (in production, use proper token storage) + TOKENS[access_token] = { + "client_id": client_id, + "scopes": requested_scopes, + "expires_at": time.time() + expires_in + } + + logger.info(f"āœ… Successfully issued token for client: {client_id}") + logger.info(f" Token: {access_token[:20]}...") + logger.info(f" Scopes: {' '.join(requested_scopes)}") + + return JSONResponse({ + "access_token": access_token, + "token_type": "Bearer", + "expires_in": expires_in, + "scope": " ".join(requested_scopes) + }) + + +@app.get("/validate") +async def validate_token(request: Request): + """ + Token validation endpoint (for debugging purposes) + + This endpoint allows you to validate tokens issued by this server. + """ + auth_header = request.headers.get("Authorization") + if not auth_header or not auth_header.startswith("Bearer "): + raise HTTPException(status_code=401, detail="Missing or invalid Authorization header") + + token = auth_header[7:] # Remove "Bearer " prefix + + if token not in TOKENS: + raise HTTPException(status_code=401, detail="Invalid token") + + token_info = TOKENS[token] + + # Check if token is expired + if time.time() > token_info["expires_at"]: + del TOKENS[token] # Clean up expired token + raise HTTPException(status_code=401, detail="Token expired") + + logger.info(f"āœ… Valid token for client: {token_info['client_id']}") + + return JSONResponse({ + "valid": True, + "client_id": token_info["client_id"], + "scopes": token_info["scopes"], + "expires_at": token_info["expires_at"] + }) + + +@app.get("/health") +async def health_check(): + """Health check endpoint.""" + return JSONResponse({"status": "healthy", "server": "Mock OAuth2 Server"}) + + +@app.get("/") +async def root(): + """Root endpoint with server information.""" + return JSONResponse({ + "server": "Mock OAuth2 Server for ADK Testing", + "endpoints": { + "discovery": "/.well-known/oauth-protected-resource", + "metadata": "/.well-known/oauth-authorization-server", + "token": "/token", + "validate": "/validate", + "health": "/health" + }, + "demo_clients": { + "demo_client_id": "demo_client_secret", + "server1_client": "server1_secret", + "server2_client": "server2_secret" + }, + "note": "This is a demo server - DO NOT use in production!" + }) + + +async def main(): + """Run the mock OAuth2 server.""" + import argparse + + parser = argparse.ArgumentParser(description="Mock OAuth2 Server for ADK Testing") + parser.add_argument("--ssl-keyfile", help="SSL private key file for HTTPS") + parser.add_argument("--ssl-certfile", help="SSL certificate file for HTTPS") + parser.add_argument("--host", default="0.0.0.0", help="Host to bind to (default: 0.0.0.0)") + parser.add_argument("--port", type=int, default=9204, help="Port to bind to (default: 9204)") + args = parser.parse_args() + + protocol = "https" if args.ssl_keyfile and args.ssl_certfile else "http" + + print("šŸš€ Starting Mock OAuth2 Server for ADK Testing") + print(f"šŸ“ Server will be available at: {protocol}://{args.host}:{args.port}") + print(f"šŸ” Discovery endpoint: {protocol}://{args.host}:{args.port}/.well-known/oauth-protected-resource") + print(f"šŸ” Token endpoint: {protocol}://{args.host}:{args.port}/token") + + if protocol == "https": + print("šŸ”’ HTTPS mode enabled with SSL certificates") + print(f" SSL Key: {args.ssl_keyfile}") + print(f" SSL Cert: {args.ssl_certfile}") + print("āš ļø If using self-signed certificates, set verify_ssl=False in MCPAuthDiscovery") + else: + print("šŸ”“ HTTP mode (no SSL)") + + print("āš ļø This is for testing only - DO NOT use in production!") + print() + print("Demo clients:") + for client_id, client_info in CLIENTS.items(): + print(f" • {client_id} / {client_info['client_secret']} (scopes: {', '.join(client_info['scopes'])})") + print() + + config = uvicorn.Config( + app, + host=args.host, + port=args.port, + log_level="info", + ssl_keyfile=args.ssl_keyfile, + ssl_certfile=args.ssl_certfile + ) + server = uvicorn.Server(config) + await server.serve() + + +if __name__ == "__main__": + try: + asyncio.run(main()) + except KeyboardInterrupt: + print("\nšŸ›‘ Mock OAuth2 Server stopped") \ No newline at end of file diff --git a/contributing/samples/mcp_stdio_server_agent/README.md b/contributing/samples/mcp_stdio_server_agent/README.md new file mode 100644 index 000000000..bc662647d --- /dev/null +++ b/contributing/samples/mcp_stdio_server_agent/README.md @@ -0,0 +1,243 @@ +# MCP Server Agent Sample + +This sample demonstrates how to use MCPToolset in ADK, including both basic file system access and OAuth2 client credentials authentication for external APIs. + +## Features + +### 1. Basic MCP Integration (Stdio) +- File system access using the filesystem MCP server +- Read-only operations (configured with tool filters) +- Local directory access + +### 2. OAuth2 Client Credentials Authentication (New!) +- Machine-to-machine authentication for external APIs +- Automatic token exchange and refresh +- Secure access to protected MCP servers + +## Setup + +### Basic File System Access + +The basic example works out of the box: + +```bash +# Install dependencies +npm install -g @modelcontextprotocol/server-filesystem + +# Run the agent +python agent.py +``` + +### OAuth2 Client Credentials Setup + +To use the OAuth2 client credentials functionality: + +1. **Get OAuth2 Credentials**: Obtain client ID and secret from your OAuth provider +2. **Configure Environment Variables**: + ```bash + export OAUTH_CLIENT_ID="your_client_id_here" + export OAUTH_CLIENT_SECRET="your_client_secret_here" + ``` + +3. **Update Configuration**: Uncomment and configure the OAuth2 section in `agent.py`: + +```python +# Configuration - replace with your actual values +CLIENT_ID = os.getenv("OAUTH_CLIENT_ID", "your_client_id_here") +CLIENT_SECRET = os.getenv("OAUTH_CLIENT_SECRET", "your_client_secret_here") +TOKEN_URL = "https://your-oauth-provider.com/token" +MCP_SERVER_URL = "https://your-mcp-server.com" +SCOPES = ["read", "write"] + +# Define OAuth2 client credentials scheme +auth_scheme = OAuth2( + flows=OAuthFlows( + clientCredentials=OAuthFlowClientCredentials( + tokenUrl=TOKEN_URL, + scopes={ + "read": "Read access to resources", + "write": "Write access to resources" + } + ) + ) +) + +# Provide client credentials (ADK will automatically handle token exchange) +auth_credential = AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + oauth2=OAuth2Auth( + client_id=CLIENT_ID, + client_secret=CLIENT_SECRET + ) +) + +# Create MCPToolset with OAuth2 client credentials authentication +return MCPToolset( + connection_params=StreamableHTTPConnectionParams( + url=MCP_SERVER_URL, + timeout=30 + ), + auth_scheme=auth_scheme, + auth_credential=auth_credential +) +``` + +## OAuth2 Client Credentials Flow + +### How It Works + +1. **No User Interaction**: Client credentials flow is designed for machine-to-machine authentication +2. **Automatic Token Exchange**: ADK automatically exchanges your client ID/secret for access tokens +3. **Token Refresh**: ADK handles token expiration and refresh automatically +4. **Secure Headers**: Access tokens are automatically added to MCP server requests + +### Flow Diagram + +``` +Client Application + ↓ + ADK Framework + ↓ (client_id, client_secret) + OAuth2 Provider + ↓ (access_token) + MCP Server + ↓ (API response) + Client Application +``` + +### Key Advantages + +- **Secure**: No user credentials stored, only OAuth2 tokens +- **Automatic**: No manual token management required +- **Standards Compliant**: Uses standard OAuth2 client credentials flow +- **Flexible**: Works with any OAuth2-compliant provider + +## Usage Examples + +### Basic File Operations + +```python +# The agent can access local files +agent.run("List the files in the current directory") +agent.run("Read the content of README.md") +``` + +### OAuth2-Protected API Access + +```python +# Once OAuth2 is configured, the agent can access protected APIs +agent.run("Get my user profile from the API") +agent.run("List my resources") +agent.run("Create a new item") +``` + +## Supported OAuth2 Providers + +This implementation works with any OAuth2 provider that supports client credentials flow: + +- **Google Cloud Platform**: For accessing Google APIs +- **Microsoft Azure**: For accessing Microsoft Graph and other APIs +- **Auth0**: For custom applications +- **Okta**: For enterprise applications +- **Custom OAuth2 Servers**: Any RFC 6749 compliant server + +## Security Best Practices + +1. **Environment Variables**: Store client secrets in environment variables, not code +2. **Least Privilege**: Request only the scopes you actually need +3. **Token Rotation**: Let ADK handle automatic token refresh +4. **HTTPS Only**: Always use HTTPS for token endpoints and MCP servers +5. **Secret Management**: Use proper secret management systems in production + +## Troubleshooting + +### Common Issues + +1. **Invalid Client Credentials** + - Verify your client ID and secret + - Check that the client is authorized for the specified scopes + +2. **Token Exchange Failures** + - Verify the token URL is correct + - Check network connectivity to the OAuth provider + +3. **MCP Server Authentication Errors** + - Ensure the MCP server expects Bearer token authentication + - Verify the server accepts your access token format + +### Debug Logging + +Enable debug logging to see OAuth2 flow details: + +```python +import logging +logging.getLogger("google_adk").setLevel(logging.DEBUG) +``` + +## Advanced Configuration + +### Custom Scopes + +```python +auth_scheme = OAuth2( + flows=OAuthFlows( + clientCredentials=OAuthFlowClientCredentials( + tokenUrl=TOKEN_URL, + scopes={ + "read:users": "Read user information", + "write:data": "Write application data", + "admin:settings": "Manage application settings" + } + ) + ) +) +``` + +### Multiple MCP Servers + +```python +# Create multiple toolsets with different authentication +public_mcp_toolset = MCPToolset(connection_params=public_params) +authenticated_mcp_toolset = MCPToolset( + connection_params=auth_params, + auth_scheme=auth_scheme, + auth_credential=auth_credential +) + +# Use both in the same agent +agent = LlmAgent( + tools=[public_mcp_toolset, authenticated_mcp_toolset] +) +``` + +## Migration from Manual Token Management + +If you were previously managing OAuth2 tokens manually: + +### Before (Manual) +```python +# Manual token management (error-prone) +access_token = get_access_token_manually(client_id, client_secret) +headers = {"Authorization": f"Bearer {access_token}"} +# Need to handle refresh manually +``` + +### After (ADK Automatic) +```python +# ADK handles everything automatically +mcp_toolset = MCPToolset( + connection_params=connection_params, + auth_scheme=oauth2_scheme, + auth_credential=oauth2_credential +) +# Tokens are managed automatically +``` + +## Contributing + +Found an issue or want to improve OAuth2 support? Please: + +1. Check existing issues in the ADK repository +2. Create a detailed bug report or feature request +3. Include your OAuth2 provider details (if relevant) +4. Provide minimal reproduction examples \ No newline at end of file diff --git a/contributing/samples/mcp_stdio_server_agent/agent.py b/contributing/samples/mcp_stdio_server_agent/agent.py index fe8b75c21..3958161c4 100755 --- a/contributing/samples/mcp_stdio_server_agent/agent.py +++ b/contributing/samples/mcp_stdio_server_agent/agent.py @@ -16,51 +16,158 @@ import os from google.adk.agents.llm_agent import LlmAgent -from google.adk.tools.mcp_tool import StdioConnectionParams +from google.adk.tools.mcp_tool import StdioConnectionParams, StreamableHTTPConnectionParams from google.adk.tools.mcp_tool.mcp_toolset import MCPToolset from mcp import StdioServerParameters +# Example 1: Basic MCP Toolset with Stdio (existing example) _allowed_path = os.path.dirname(os.path.abspath(__file__)) -root_agent = LlmAgent( - model='gemini-2.0-flash', - name='enterprise_assistant', - instruction=f"""\ -Help user accessing their file systems. - -Allowed directory: {_allowed_path} - """, - tools=[ - MCPToolset( - connection_params=StdioConnectionParams( - server_params=StdioServerParameters( - command='npx', - args=[ - '-y', # Arguments for the command - '@modelcontextprotocol/server-filesystem', - _allowed_path, - ], - ), - timeout=5, - ), - # don't want agent to do write operation - # you can also do below - # tool_filter=lambda tool, ctx=None: tool.name - # not in [ - # 'write_file', - # 'edit_file', - # 'create_directory', - # 'move_file', - # ], - tool_filter=[ - 'read_file', - 'read_multiple_files', - 'list_directory', - 'directory_tree', - 'search_files', - 'get_file_info', - 'list_allowed_directories', +basic_mcp_toolset = MCPToolset( + connection_params=StdioConnectionParams( + server_params=StdioServerParameters( + command='npx', + args=[ + '-y', # Arguments for the command + '@modelcontextprotocol/server-filesystem', + _allowed_path, ], - ) + ), + timeout=5, + ), + # don't want agent to do write operation + # you can also do below + # tool_filter=lambda tool, ctx=None: tool.name + # not in [ + # 'write_file', + # 'edit_file', + # 'create_directory', + # 'move_file', + # ], + tool_filter=[ + 'read_file', + 'read_multiple_files', + 'list_directory', + 'directory_tree', + 'search_files', + 'get_file_info', + 'list_allowed_directories', ], ) + +# Example 2: MCP Toolset with OAuth2 Client Credentials + Auto Discovery (new functionality!) +def create_oauth2_auto_discovery_mcp_toolset(): + """ + Example of creating an MCPToolset with automatic OAuth discovery. + + This demonstrates the powerful new OAuth discovery feature added to ADK. + No manual token URL configuration needed! + """ + from google.adk.auth.auth_credential import AuthCredential, AuthCredentialTypes, OAuth2Auth + + # Configuration - replace with your actual values + CLIENT_ID = os.getenv("OAUTH_CLIENT_ID", "your_client_id_here") + CLIENT_SECRET = os.getenv("OAUTH_CLIENT_SECRET", "your_client_secret_here") + SERVER_URL = "https://your-mcp-server.com" # Your OAuth-protected MCP server + + # Create auth credential (no auth scheme needed for auto-discovery!) + auth_credential = AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + oauth2=OAuth2Auth( + client_id=CLIENT_ID, + client_secret=CLIENT_SECRET, + ) + ) + + # Create MCPToolset with automatic OAuth discovery + # ADK will automatically: + # 1. Query .well-known/oauth-protected-resource + # 2. Query .well-known/oauth-authorization-server + # 3. Extract token endpoints and create OAuth2 scheme + # 4. Handle token exchange and refresh seamlessly + toolset = MCPToolset( + connection_params=StreamableHTTPConnectionParams( + url=SERVER_URL + ), + auth_credential=auth_credential, + auto_discover_oauth=True, # Enable automatic discovery! + discovery_timeout=10.0, # Discovery timeout (optional) + discovery_scopes=["read", "write"], # Scopes to request (optional) + tool_filter=["list_tools", "call_tool"] # Optional tool filtering + ) + + return toolset + +# Example 3: MCP Toolset with Manual OAuth2 Client Credentials (traditional approach) +def create_oauth2_manual_mcp_toolset(): + """ + Example of creating an MCPToolset with manual OAuth configuration. + + Use this approach when you need explicit control over OAuth endpoints + or when the server doesn't support discovery. + """ + from fastapi.openapi.models import OAuth2, OAuthFlows, OAuthFlowClientCredentials + from google.adk.auth.auth_credential import AuthCredential, AuthCredentialTypes, OAuth2Auth + + # Configuration - replace with your actual values + CLIENT_ID = os.getenv("OAUTH_CLIENT_ID", "your_client_id_here") + CLIENT_SECRET = os.getenv("OAUTH_CLIENT_SECRET", "your_client_secret_here") + TOKEN_URL = "https://auth.example.com/token" + SERVER_URL = "https://your-mcp-server.com" + + # Manually create OAuth2 auth scheme + auth_scheme = OAuth2( + flows=OAuthFlows( + clientCredentials=OAuthFlowClientCredentials( + tokenUrl=TOKEN_URL, + scopes={ + "read": "Read access", + "write": "Write access" + } + ) + ) + ) + + # Create auth credential + auth_credential = AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + oauth2=OAuth2Auth( + client_id=CLIENT_ID, + client_secret=CLIENT_SECRET, + ) + ) + + # Create MCPToolset with explicit OAuth configuration + toolset = MCPToolset( + connection_params=StreamableHTTPConnectionParams( + url=SERVER_URL + ), + auth_scheme=auth_scheme, # Explicit auth scheme + auth_credential=auth_credential, + auto_discover_oauth=False, # Disable discovery when using explicit scheme + ) + + return toolset + +# Create the authenticated toolset (will be None if not configured) +authenticated_toolset = create_oauth2_auto_discovery_mcp_toolset() + +# Create agent with both file system and (optionally) authenticated MCP toolset +if authenticated_toolset: + # Both file system and authenticated toolsets + toolsets = [basic_mcp_toolset, authenticated_toolset] +else: + # Just file system toolset + toolsets = [basic_mcp_toolset] + +agent = LlmAgent( + model='gemini-2.0-flash-exp', + name='file_mcp_agent', + instruction=""" +You are a helpful assistant with access to file system operations via MCP tools. +You can read files, list directories, and perform other file operations as requested. + +If OAuth-protected MCP tools are available, you can also access authenticated services. + """, + tools=toolsets, +) diff --git a/src/google/adk/auth/__init__.py b/src/google/adk/auth/__init__.py index 49fba3768..daa520a3b 100644 --- a/src/google/adk/auth/__init__.py +++ b/src/google/adk/auth/__init__.py @@ -12,11 +12,39 @@ # See the License for the specific language governing permissions and # limitations under the License. +"""Auth configurations for Google ADK.""" + from .auth_credential import AuthCredential from .auth_credential import AuthCredentialTypes from .auth_credential import OAuth2Auth +from .auth_credential import ServiceAccount from .auth_handler import AuthHandler from .auth_schemes import AuthScheme -from .auth_schemes import AuthSchemeType +from .auth_schemes import OAuthGrantType from .auth_schemes import OpenIdConnectWithConfig -from .auth_tool import AuthConfig +from .credential_manager import CredentialManager + +# OAuth discovery utilities - imported conditionally to avoid circular imports +try: + from .oauth2_discovery_util import create_oauth_scheme_from_discovery + _discovery_available = True +except ImportError: + _discovery_available = False + +__all__ = [ + "AuthCredential", + "AuthCredentialTypes", + "AuthHandler", + "AuthScheme", + "CredentialManager", + "OAuthGrantType", + "OAuth2Auth", + "OpenIdConnectWithConfig", + "ServiceAccount", +] + +# Add discovery utilities to __all__ if available +if _discovery_available: + __all__.extend([ + "create_oauth_scheme_from_discovery", + ]) diff --git a/src/google/adk/auth/auth_schemes.py b/src/google/adk/auth/auth_schemes.py index baccf648d..f238dbb91 100644 --- a/src/google/adk/auth/auth_schemes.py +++ b/src/google/adk/auth/auth_schemes.py @@ -50,16 +50,32 @@ class OAuthGrantType(str, Enum): PASSWORD = "password" @staticmethod - def from_flow(flow: OAuthFlows) -> "OAuthGrantType": - """Converts an OAuthFlows object to a OAuthGrantType.""" + def from_flow(flow: OAuthFlows) -> Optional["OAuthGrantType"]: + """Converts an OAuthFlows object to a OAuthGrantType. + + Determines the OAuth2 grant type based on which flow is configured + in the OAuthFlows object. Prioritizes client credentials as it's the + most specific flow for machine-to-machine authentication. + + Args: + flow: The OAuthFlows object containing flow configurations. + + Returns: + The corresponding OAuthGrantType, or None if no recognized flow is found. + """ + # Prioritize client credentials for machine-to-machine authentication if flow.clientCredentials: return OAuthGrantType.CLIENT_CREDENTIALS + # Authorization code flow for interactive user authentication if flow.authorizationCode: return OAuthGrantType.AUTHORIZATION_CODE + # Implicit flow (less secure, deprecated in OAuth 2.1) if flow.implicit: return OAuthGrantType.IMPLICIT + # Password flow (not recommended for security reasons) if flow.password: return OAuthGrantType.PASSWORD + # No recognized flow found return None diff --git a/src/google/adk/auth/credential_manager.py b/src/google/adk/auth/credential_manager.py index c5dae9f51..b6f1c97ab 100644 --- a/src/google/adk/auth/credential_manager.py +++ b/src/google/adk/auth/credential_manager.py @@ -14,6 +14,7 @@ from __future__ import annotations +import logging from typing import Optional from ..agents.callback_context import CallbackContext @@ -27,6 +28,8 @@ from .refresher.base_credential_refresher import BaseCredentialRefresher from .refresher.credential_refresher_registry import CredentialRefresherRegistry +logger = logging.getLogger("google_adk." + __name__) + @experimental class CredentialManager: @@ -76,9 +79,14 @@ def __init__( self._refresher_registry = CredentialRefresherRegistry() # Register default exchangers and refreshers - # TODO: support service account credential exchanger + # Register OAuth2 exchanger for client credentials and authorization code flows + from .exchanger.oauth2_credential_exchanger import OAuth2CredentialExchanger from .refresher.oauth2_credential_refresher import OAuth2CredentialRefresher + oauth2_exchanger = OAuth2CredentialExchanger() + self._exchanger_registry.register(AuthCredentialTypes.OAUTH2, oauth2_exchanger) + self._exchanger_registry.register(AuthCredentialTypes.OPEN_ID_CONNECT, oauth2_exchanger) + oauth2_refresher = OAuth2CredentialRefresher() self._refresher_registry.register( AuthCredentialTypes.OAUTH2, oauth2_refresher @@ -87,6 +95,8 @@ def __init__( AuthCredentialTypes.OPEN_ID_CONNECT, oauth2_refresher ) + # TODO: support service account credential exchanger + def register_credential_exchanger( self, credential_type: AuthCredentialTypes, @@ -104,59 +114,104 @@ async def request_credential(self, callback_context: CallbackContext) -> None: callback_context.request_credential(self._auth_config) async def get_auth_credential( - self, callback_context: CallbackContext + self, callback_context: CallbackContext, verify_ssl: bool = True ) -> Optional[AuthCredential]: - """Load and prepare authentication credential through a structured workflow.""" + """Load and prepare authentication credential through a structured workflow. + + Args: + callback_context: The callback context for credential operations. + verify_ssl: Whether to verify SSL certificates during OAuth operations (default: True). + Set to False for self-signed certificates. + + Returns: + The prepared authentication credential, or None if unavailable. + """ + + logger.debug("šŸ”„ CredentialManager.get_auth_credential() called") # Step 1: Validate credential configuration + logger.debug("šŸ” Step 1: Validating credential configuration") await self._validate_credential() + logger.debug("āœ… Step 1: Credential validation passed") # Step 2: Check if credential is already ready (no processing needed) + logger.debug("šŸ” Step 2: Checking if credential is already ready") if self._is_credential_ready(): + logger.debug("āœ… Step 2: Credential is ready, returning raw credential") return self._auth_config.raw_auth_credential + logger.debug("āœ… Step 2: Credential needs processing") # Step 3: Try to load existing processed credential + logger.debug("šŸ” Step 3: Loading existing processed credential") credential = await self._load_existing_credential(callback_context) + if credential: + logger.debug("āœ… Step 3: Found existing credential") + if credential.oauth2 and credential.oauth2.access_token: + logger.debug("āœ… Existing credential has access_token, skipping exchange") + else: + logger.debug("āš ļø Existing credential has no access_token") + else: + logger.debug("āœ… Step 3: No existing credential found") # Step 4: If no existing credential, load from auth response # TODO instead of load from auth response, we can store auth response in # credential service. was_from_auth_response = False if not credential: + logger.debug("šŸ” Step 4: Loading from auth response") credential = await self._load_from_auth_response(callback_context) - was_from_auth_response = True + if credential: + logger.debug("āœ… Step 4: Found credential from auth response") + was_from_auth_response = True + else: + logger.debug("āœ… Step 4: No credential from auth response") # Step 5: If still no credential available, return None if not credential: - return None + # For OAuth2 client credentials, fallback to raw credential + if (self._auth_config.raw_auth_credential and + self._auth_config.raw_auth_credential.auth_type == AuthCredentialTypes.OAUTH2): + logger.debug("āœ… Step 5: Using raw OAuth2 credential for client credentials flow") + credential = self._auth_config.raw_auth_credential + else: + logger.debug("āŒ Step 5: No credential available, returning None") + return None + logger.debug("āœ… Step 5: Credential available, proceeding to exchange") # Step 6: Exchange credential if needed (e.g., service account to access token) - credential, was_exchanged = await self._exchange_credential(credential) + logger.debug("šŸ” Step 6: Starting credential exchange") + credential, was_exchanged = await self._exchange_credential(credential, verify_ssl) + logger.debug(f"āœ… Step 6: Exchange completed, was_exchanged={was_exchanged}") # Step 7: Refresh credential if expired + logger.debug("šŸ” Step 7: Checking if refresh needed") was_refreshed = False if not was_exchanged: credential, was_refreshed = await self._refresh_credential(credential) + logger.debug(f"āœ… Step 7: Refresh completed, was_refreshed={was_refreshed}") + else: + logger.debug("āœ… Step 7: Skipping refresh since credential was exchanged") # Step 8: Save credential if it was modified if was_from_auth_response or was_exchanged or was_refreshed: + logger.debug("šŸ” Step 8: Saving modified credential") await self._save_credential(callback_context, credential) + logger.debug("āœ… Step 8: Credential saved") return credential async def _load_existing_credential( self, callback_context: CallbackContext ) -> Optional[AuthCredential]: - """Load existing credential from credential service or cached exchanged credential.""" - - # Try loading from credential service first + """Load existing credential.""" + # First try to load from credential service credential = await self._load_from_credential_service(callback_context) if credential: return credential - # Check if we have a cached exchanged credential - if self._auth_config.exchanged_auth_credential: - return self._auth_config.exchanged_auth_credential + # Then try to load from context + if hasattr(callback_context, "_auth_credential") and callback_context._auth_credential: + return callback_context._auth_credential return None @@ -178,16 +233,40 @@ async def _load_from_auth_response( return callback_context.get_auth_response(self._auth_config) async def _exchange_credential( - self, credential: AuthCredential + self, credential: AuthCredential, verify_ssl: bool = True ) -> tuple[AuthCredential, bool]: - """Exchange credential if needed and return the credential and whether it was exchanged.""" + """Exchange credential if needed and return the credential and whether it was exchanged. + + Args: + credential: The credential to exchange. + verify_ssl: Whether to verify SSL certificates during OAuth operations (default: True). + + Returns: + Tuple of (exchanged_credential, was_exchanged). + """ + logger.debug(f"šŸ”„ _exchange_credential called for credential type: {credential.auth_type}") + exchanger = self._exchanger_registry.get_exchanger(credential.auth_type) if not exchanger: + logger.debug(f"āŒ No exchanger found for credential type: {credential.auth_type}") return credential, False - exchanged_credential = await exchanger.exchange( - credential, self._auth_config.auth_scheme - ) + logger.debug(f"āœ… Found exchanger: {type(exchanger).__name__}") + logger.debug("šŸš€ Calling exchanger.exchange()") + + # Check if exchanger supports verify_ssl parameter (OAuth2CredentialExchanger does) + from .exchanger.oauth2_credential_exchanger import OAuth2CredentialExchanger + if isinstance(exchanger, OAuth2CredentialExchanger): + exchanged_credential = await exchanger.exchange( + credential, self._auth_config.auth_scheme, verify_ssl + ) + else: + # Fallback for other exchangers that don't support verify_ssl + exchanged_credential = await exchanger.exchange( + credential, self._auth_config.auth_scheme + ) + + logger.debug("āœ… Exchanger.exchange() completed") return exchanged_credential, True async def _refresh_credential( diff --git a/src/google/adk/auth/exchanger/oauth2_credential_exchanger.py b/src/google/adk/auth/exchanger/oauth2_credential_exchanger.py index 4231a7c1e..b29c4c930 100644 --- a/src/google/adk/auth/exchanger/oauth2_credential_exchanger.py +++ b/src/google/adk/auth/exchanger/oauth2_credential_exchanger.py @@ -19,6 +19,7 @@ import logging from typing import Optional +from fastapi.openapi.models import OAuth2 from google.adk.auth.auth_credential import AuthCredential from google.adk.auth.auth_schemes import AuthScheme from google.adk.auth.auth_schemes import OAuthGrantType @@ -32,6 +33,7 @@ try: from authlib.integrations.requests_client import OAuth2Session + import requests AUTHLIB_AVAILABLE = True except ImportError: @@ -42,20 +44,23 @@ @experimental class OAuth2CredentialExchanger(BaseCredentialExchanger): - """Exchanges OAuth2 credentials from authorization responses.""" + """Exchanges OAuth2 credentials from authorization responses or client credentials.""" @override async def exchange( self, auth_credential: AuthCredential, auth_scheme: Optional[AuthScheme] = None, + verify_ssl: bool = True, ) -> AuthCredential: - """Exchange OAuth2 credential from authorization response. + """Exchange OAuth2 credential based on the flow type. if credential exchange failed, the original credential will be returned. Args: auth_credential: The OAuth2 credential to exchange. auth_scheme: The OAuth2 authentication scheme. + verify_ssl: Whether to verify SSL certificates (default: True). + Set to False for self-signed certificates. Returns: The exchanged credential with access token. @@ -63,7 +68,10 @@ async def exchange( Raises: CredentialExchangError: If auth_scheme is missing. """ + logger.debug("šŸ”„ OAuth2CredentialExchanger.exchange() called") + if not auth_scheme: + logger.error("āŒ auth_scheme is missing") raise CredentialExchangError( "auth_scheme is required for OAuth2 credential exchange" ) @@ -74,13 +82,115 @@ async def exchange( # The client using this tool can decide to exchange the credential # themselves using other lib. logger.warning( - "authlib is not available, skipping OAuth2 credential exchange." + "āŒ authlib is not available, skipping OAuth2 credential exchange." ) return auth_credential + logger.debug("āœ… authlib is available") + if auth_credential.oauth2 and auth_credential.oauth2.access_token: + logger.debug("āœ… credential already has access_token, no exchange needed") + return auth_credential + + logger.debug("šŸ” credential needs token exchange") + + # Determine the OAuth2 flow type + grant_type = self._get_grant_type(auth_scheme) + logger.debug(f"šŸŽÆ detected grant type: {grant_type}") + + if grant_type == OAuthGrantType.CLIENT_CREDENTIALS: + logger.debug("šŸš€ starting client credentials exchange") + return await self._exchange_client_credentials(auth_credential, auth_scheme, verify_ssl) + elif grant_type == OAuthGrantType.AUTHORIZATION_CODE: + logger.debug("šŸš€ starting authorization code exchange") + return await self._exchange_authorization_code(auth_credential, auth_scheme) + else: + logger.warning(f"āŒ Unsupported OAuth2 grant type: {grant_type}") return auth_credential + def _get_grant_type(self, auth_scheme: AuthScheme) -> Optional[OAuthGrantType]: + """Determine the OAuth2 grant type from the auth scheme.""" + if isinstance(auth_scheme, OAuth2) and auth_scheme.flows: + return OAuthGrantType.from_flow(auth_scheme.flows) + return None + + async def _exchange_client_credentials( + self, + auth_credential: AuthCredential, + auth_scheme: AuthScheme, + verify_ssl: bool + ) -> AuthCredential: + """Handle OAuth2 client credentials flow.""" + + logger.debug("šŸ” _exchange_client_credentials() called") + + if not isinstance(auth_scheme, OAuth2) or not auth_scheme.flows.clientCredentials: + logger.warning("āŒ No client credentials flow configuration found") + return auth_credential + + flow = auth_scheme.flows.clientCredentials + token_url = flow.tokenUrl + scopes = list(flow.scopes.keys()) if flow.scopes else [] + + logger.debug(f"šŸŽÆ token_url: {token_url}") + logger.debug(f"šŸŽÆ scopes: {scopes}") + + if not auth_credential.oauth2 or not auth_credential.oauth2.client_id or not auth_credential.oauth2.client_secret: + logger.error("āŒ Client ID and secret required for client credentials flow") + return auth_credential + + logger.debug(f"āœ… client_id: {auth_credential.oauth2.client_id}") + logger.debug("āœ… client_secret: [REDACTED]") + + try: + logger.debug("šŸš€ Creating OAuth2Session for client credentials") + + # Create OAuth2 session for client credentials + # Use client_secret_post to send credentials in form body, not HTTP Basic Auth + client = OAuth2Session( + auth_credential.oauth2.client_id, + auth_credential.oauth2.client_secret, + scope=" ".join(scopes), + token_endpoint_auth_method='client_secret_post' + ) + + # Set SSL verification on the OAuth2Session (which inherits from requests.Session) + client.verify = verify_ssl + + if not verify_ssl: + logger.debug("āš ļø SSL certificate verification disabled") + # Suppress SSL warnings when verification is disabled + import urllib3 + urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) + + logger.debug(f"šŸ“” Making POST request to token endpoint: {token_url} (verify_ssl={verify_ssl})") + # Fetch token using client credentials grant + tokens = client.fetch_token( + token_url, + grant_type=OAuthGrantType.CLIENT_CREDENTIALS, + ) + + logger.debug("āœ… Successfully received tokens from server") + logger.debug(f"šŸ”‘ received tokens: {list(tokens.keys())}") + + # Update credential with tokens + update_credential_with_tokens(auth_credential, tokens) + logger.debug("āœ… Successfully exchanged OAuth2 client credentials") + + except Exception as e: + logger.error(f"āŒ Failed to exchange OAuth2 client credentials: {e}") + logger.exception("Exception details:") + return auth_credential + + return auth_credential + + async def _exchange_authorization_code( + self, + auth_credential: AuthCredential, + auth_scheme: AuthScheme + ) -> AuthCredential: + """Handle OAuth2 authorization code flow (existing logic).""" + client, token_endpoint = create_oauth2_session(auth_scheme, auth_credential) if not client: logger.warning("Could not create OAuth2 session for token exchange") @@ -89,15 +199,15 @@ async def exchange( try: tokens = client.fetch_token( token_endpoint, - authorization_response=auth_credential.oauth2.auth_response_uri, - code=auth_credential.oauth2.auth_code, + authorization_response=auth_credential.oauth2.auth_response_uri if auth_credential.oauth2 else None, + code=auth_credential.oauth2.auth_code if auth_credential.oauth2 else None, grant_type=OAuthGrantType.AUTHORIZATION_CODE, ) update_credential_with_tokens(auth_credential, tokens) - logger.debug("Successfully exchanged OAuth2 tokens") + logger.debug("Successfully exchanged OAuth2 authorization code") except Exception as e: # TODO reconsider whether we should raise errors in this case - logger.error("Failed to exchange OAuth2 tokens: %s", e) + logger.error("Failed to exchange OAuth2 authorization code: %s", e) # Return original credential on failure return auth_credential diff --git a/src/google/adk/auth/oauth2_credential_util.py b/src/google/adk/auth/oauth2_credential_util.py index cc315bd29..f8751c7eb 100644 --- a/src/google/adk/auth/oauth2_credential_util.py +++ b/src/google/adk/auth/oauth2_credential_util.py @@ -55,15 +55,18 @@ def create_oauth2_session( if not hasattr(auth_scheme, "token_endpoint"): return None, None token_endpoint = auth_scheme.token_endpoint - scopes = auth_scheme.scopes + scopes = auth_scheme.scopes or [] elif isinstance(auth_scheme, OAuth2): - if ( - not auth_scheme.flows.authorizationCode - or not auth_scheme.flows.authorizationCode.tokenUrl - ): + # Handle client credentials flow + if auth_scheme.flows.clientCredentials: + token_endpoint = auth_scheme.flows.clientCredentials.tokenUrl + scopes = list(auth_scheme.flows.clientCredentials.scopes.keys()) if auth_scheme.flows.clientCredentials.scopes else [] + # Handle authorization code flow + elif auth_scheme.flows.authorizationCode: + token_endpoint = auth_scheme.flows.authorizationCode.tokenUrl + scopes = list(auth_scheme.flows.authorizationCode.scopes.keys()) if auth_scheme.flows.authorizationCode.scopes else [] + else: return None, None - token_endpoint = auth_scheme.flows.authorizationCode.tokenUrl - scopes = list(auth_scheme.flows.authorizationCode.scopes.keys()) else: return None, None @@ -75,16 +78,28 @@ def create_oauth2_session( ): return None, None - return ( - OAuth2Session( - auth_credential.oauth2.client_id, - auth_credential.oauth2.client_secret, - scope=" ".join(scopes), - redirect_uri=auth_credential.oauth2.redirect_uri, - state=auth_credential.oauth2.state, - ), - token_endpoint, - ) + # For client credentials flow, we don't need redirect_uri or state + if isinstance(auth_scheme, OAuth2) and auth_scheme.flows.clientCredentials: + return ( + OAuth2Session( + auth_credential.oauth2.client_id, + auth_credential.oauth2.client_secret, + scope=" ".join(scopes), + ), + token_endpoint, + ) + else: + # For authorization code flow, include redirect_uri and state + return ( + OAuth2Session( + auth_credential.oauth2.client_id, + auth_credential.oauth2.client_secret, + scope=" ".join(scopes), + redirect_uri=auth_credential.oauth2.redirect_uri, + state=auth_credential.oauth2.state, + ), + token_endpoint, + ) @experimental @@ -97,11 +112,24 @@ def update_credential_with_tokens( auth_credential: The authentication credential to update. tokens: The OAuth2Token object containing new token information. """ - auth_credential.oauth2.access_token = tokens.get("access_token") - auth_credential.oauth2.refresh_token = tokens.get("refresh_token") - auth_credential.oauth2.expires_at = ( - int(tokens.get("expires_at")) if tokens.get("expires_at") else None - ) - auth_credential.oauth2.expires_in = ( - int(tokens.get("expires_in")) if tokens.get("expires_in") else None - ) + if not auth_credential.oauth2: + return + + # Cast token values to appropriate types + access_token = tokens.get("access_token") + auth_credential.oauth2.access_token = str(access_token) if access_token else None + + refresh_token = tokens.get("refresh_token") + auth_credential.oauth2.refresh_token = str(refresh_token) if refresh_token else None + + expires_at = tokens.get("expires_at") + try: + auth_credential.oauth2.expires_at = int(expires_at) if expires_at is not None else None + except (ValueError, TypeError): + auth_credential.oauth2.expires_at = None + + expires_in = tokens.get("expires_in") + try: + auth_credential.oauth2.expires_in = int(expires_in) if expires_in is not None else None + except (ValueError, TypeError): + auth_credential.oauth2.expires_in = None diff --git a/src/google/adk/auth/oauth2_discovery_util.py b/src/google/adk/auth/oauth2_discovery_util.py new file mode 100644 index 000000000..8e0068155 --- /dev/null +++ b/src/google/adk/auth/oauth2_discovery_util.py @@ -0,0 +1,186 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""OAuth2 discovery utilities for automatic configuration discovery.""" + +from __future__ import annotations + +import json +import logging +from typing import Any, Dict, Optional + +import httpx +from fastapi.openapi.models import OAuth2 +from fastapi.openapi.models import OAuthFlowClientCredentials +from fastapi.openapi.models import OAuthFlows +from google.adk.utils.feature_decorator import experimental + +logger = logging.getLogger(__name__) + +# OAuth Discovery Constants +OAUTH_PROTECTED_RESOURCE_DISCOVERY = ".well-known/oauth-protected-resource" +OAUTH_AUTHORIZATION_SERVER_DISCOVERY = ".well-known/oauth-authorization-server" + + +@experimental +def _validate_oauth_discovery_response(config: Dict[str, Any]) -> bool: + """ + Validate OAuth discovery response contains required fields. + + Args: + config: The discovered OAuth configuration + + Returns: + True if configuration is valid, False otherwise + """ + # For oauth-protected-resource discovery + if "authorization_servers" in config: + return isinstance(config["authorization_servers"], list) and bool(config["authorization_servers"]) + + # For oauth-authorization-server discovery + if "token_endpoint" in config: + return isinstance(config["token_endpoint"], str) and bool(config["token_endpoint"]) + + return False + + +@experimental +async def create_oauth_scheme_from_discovery( + base_url: str, + scopes: Optional[list[str]] = None, + timeout: float = 10.0, + verify_ssl: bool = True +) -> Optional[OAuth2]: + """ + Create an OAuth2 auth scheme by automatically discovering OAuth configuration. + + Implements RFC 8414 two-stage discovery: + 1. Query .well-known/oauth-protected-resource to find authorization server + 2. Query authorization server's .well-known/oauth-authorization-server for token endpoint + + Args: + base_url: The base URL to discover OAuth configuration for + scopes: List of OAuth scopes to request + timeout: Discovery request timeout in seconds + verify_ssl: Whether to verify SSL certificates (default: True). + Set to False for self-signed certificates. + + Returns: + OAuth2 auth scheme with discovered configuration, or None if discovery fails + """ + # Stage 1: Try to find authorization server from protected resource endpoint + protected_resource_config = await _query_oauth_endpoint( + base_url, OAUTH_PROTECTED_RESOURCE_DISCOVERY, timeout, verify_ssl + ) + + token_endpoint = None + + if protected_resource_config and "authorization_servers" in protected_resource_config: + # Stage 2: Query the authorization server's oauth-authorization-server endpoint + auth_servers = protected_resource_config["authorization_servers"] + if auth_servers: + auth_server_url = auth_servers[0] + logger.debug(f"Found authorization server: {auth_server_url}") + + # Specifically query the authorization server's oauth-authorization-server endpoint + auth_server_config = await _query_oauth_endpoint( + auth_server_url, OAUTH_AUTHORIZATION_SERVER_DISCOVERY, timeout, verify_ssl + ) + + if auth_server_config and "token_endpoint" in auth_server_config: + token_endpoint = auth_server_config["token_endpoint"] + logger.debug(f"Discovered token endpoint: {token_endpoint}") + else: + logger.warning(f"Authorization server {auth_server_url} did not provide token_endpoint") + # Fallback: assume standard /token endpoint + token_endpoint = f"{auth_server_url.rstrip('/')}/token" + logger.debug(f"Using fallback token endpoint: {token_endpoint}") + else: + # Fallback: Try direct authorization server discovery at base URL + logger.debug(f"No oauth-protected-resource found, trying direct authorization server discovery at {base_url}") + auth_server_config = await _query_oauth_endpoint( + base_url, OAUTH_AUTHORIZATION_SERVER_DISCOVERY, timeout, verify_ssl + ) + + if auth_server_config and "token_endpoint" in auth_server_config: + token_endpoint = auth_server_config["token_endpoint"] + logger.debug(f"Discovered token endpoint via direct discovery: {token_endpoint}") + + if not token_endpoint: + logger.warning("Could not determine token endpoint from OAuth discovery") + return None + + # Create scopes dictionary + scopes = scopes or ["read", "write"] + scopes_dict = {scope: f"Access to {scope}" for scope in scopes} + + # Create OAuth2 scheme with client credentials flow + return OAuth2( + flows=OAuthFlows( + clientCredentials=OAuthFlowClientCredentials( + tokenUrl=token_endpoint, + scopes=scopes_dict + ) + ) + ) + + +@experimental +async def _query_oauth_endpoint( + base_url: str, + endpoint_path: str, + timeout: float, + verify_ssl: bool = True +) -> Optional[Dict[str, Any]]: + """ + Query a specific OAuth discovery endpoint. + + Args: + base_url: The base URL of the server + endpoint_path: The discovery endpoint path (e.g., ".well-known/oauth-protected-resource") + timeout: Request timeout in seconds + verify_ssl: Whether to verify SSL certificates (default: True). + Set to False for self-signed certificates. + + Returns: + Dictionary containing the discovery response, or None if failed + """ + discovery_url = f"{base_url.rstrip('/')}/{endpoint_path}" + + async with httpx.AsyncClient(timeout=timeout, verify=verify_ssl) as client: + try: + logger.debug(f"Querying OAuth endpoint: {discovery_url} (verify_ssl={verify_ssl})") + + response = await client.get(discovery_url) + response.raise_for_status() + + config = response.json() + logger.debug(f"Successfully got response from {discovery_url}") + + # Validate response has expected structure + if _validate_oauth_discovery_response(config): + return config + else: + logger.warning(f"Invalid OAuth discovery response from {discovery_url}") + return None + + except httpx.HTTPStatusError as e: + logger.debug(f"OAuth endpoint {discovery_url} returned HTTP {e.response.status_code}") + return None + except (httpx.RequestError, json.JSONDecodeError) as e: + logger.debug(f"Failed to query OAuth endpoint {discovery_url}: {e}") + return None + except Exception as e: + logger.warning(f"Unexpected error querying OAuth endpoint {discovery_url}: {e}") + return None \ No newline at end of file diff --git a/src/google/adk/auth/refresher/oauth2_credential_refresher.py b/src/google/adk/auth/refresher/oauth2_credential_refresher.py index 02d8ebfb7..a92435ef3 100644 --- a/src/google/adk/auth/refresher/oauth2_credential_refresher.py +++ b/src/google/adk/auth/refresher/oauth2_credential_refresher.py @@ -20,8 +20,10 @@ import logging from typing import Optional +from fastapi.openapi.models import OAuth2 from google.adk.auth.auth_credential import AuthCredential from google.adk.auth.auth_schemes import AuthScheme +from google.adk.auth.auth_schemes import OAuthGrantType from google.adk.auth.oauth2_credential_util import create_oauth2_session from google.adk.auth.oauth2_credential_util import update_credential_with_tokens from google.adk.utils.feature_decorator import experimental @@ -99,28 +101,77 @@ async def refresh( if not auth_credential.oauth2: return auth_credential + # Check if token is expired if OAuth2Token({ "expires_at": auth_credential.oauth2.expires_at, "expires_in": auth_credential.oauth2.expires_in, }).is_expired(): - client, token_endpoint = create_oauth2_session( - auth_scheme, auth_credential - ) - if not client: - logger.warning("Could not create OAuth2 session for token refresh") + + # Determine the OAuth2 flow type + grant_type = self._get_grant_type(auth_scheme) + + if grant_type == OAuthGrantType.CLIENT_CREDENTIALS: + return await self._refresh_client_credentials(auth_credential, auth_scheme) + elif grant_type == OAuthGrantType.AUTHORIZATION_CODE: + return await self._refresh_authorization_code(auth_credential, auth_scheme) + else: + logger.warning(f"Unsupported OAuth2 grant type for refresh: {grant_type}") return auth_credential - try: - tokens = client.refresh_token( - url=token_endpoint, - refresh_token=auth_credential.oauth2.refresh_token, - ) - update_credential_with_tokens(auth_credential, tokens) - logger.debug("Successfully refreshed OAuth2 tokens") - except Exception as e: - # TODO reconsider whether we should raise error when refresh failed. - logger.error("Failed to refresh OAuth2 tokens: %s", e) - # Return original credential on failure - return auth_credential + return auth_credential + + def _get_grant_type(self, auth_scheme: AuthScheme) -> Optional[OAuthGrantType]: + """Determine the OAuth2 grant type from the auth scheme.""" + if isinstance(auth_scheme, OAuth2) and auth_scheme.flows: + return OAuthGrantType.from_flow(auth_scheme.flows) + return None + + async def _refresh_client_credentials( + self, + auth_credential: AuthCredential, + auth_scheme: AuthScheme + ) -> AuthCredential: + """Refresh client credentials by getting a new token (no refresh tokens in client credentials flow).""" + + # For client credentials, "refresh" means getting a new token + # Import here to avoid circular imports + from ..exchanger.oauth2_credential_exchanger import OAuth2CredentialExchanger + + try: + exchanger = OAuth2CredentialExchanger() + # Clear the access token to force re-exchange + if auth_credential.oauth2: + auth_credential.oauth2.access_token = None + return await exchanger.exchange(auth_credential, auth_scheme) + except Exception as e: + logger.error("Failed to refresh OAuth2 client credentials: %s", e) + return auth_credential + + async def _refresh_authorization_code( + self, + auth_credential: AuthCredential, + auth_scheme: AuthScheme + ) -> AuthCredential: + """Refresh authorization code credentials using refresh token.""" + + client, token_endpoint = create_oauth2_session( + auth_scheme, auth_credential + ) + if not client: + logger.warning("Could not create OAuth2 session for token refresh") + return auth_credential + + try: + tokens = client.refresh_token( + url=token_endpoint, + refresh_token=auth_credential.oauth2.refresh_token if auth_credential.oauth2 else None, + ) + update_credential_with_tokens(auth_credential, tokens) + logger.debug("Successfully refreshed OAuth2 authorization code tokens") + except Exception as e: + # TODO reconsider whether we should raise error when refresh failed. + logger.error("Failed to refresh OAuth2 authorization code tokens: %s", e) + # Return original credential on failure + return auth_credential return auth_credential diff --git a/src/google/adk/tools/mcp_tool/mcp_auth_discovery.py b/src/google/adk/tools/mcp_tool/mcp_auth_discovery.py new file mode 100644 index 000000000..df1acf46b --- /dev/null +++ b/src/google/adk/tools/mcp_tool/mcp_auth_discovery.py @@ -0,0 +1,115 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""OAuth discovery configuration for MCP tools.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import List, Optional + + +@dataclass +class MCPAuthDiscovery: + """Configuration for OAuth2 discovery in MCP tools. + + This class encapsulates parameters needed for automatic OAuth2 discovery, + providing a clean API for configuring how MCPToolset should discover + OAuth2 token endpoints. + + Attributes: + base_url: The base server URL for OAuth discovery endpoints (e.g., "http://server:9204"). + If None, MCPToolset will automatically extract the base URL from connection parameters. + OAuth .well-known endpoints will be queried at this URL root. + timeout: Timeout in seconds for OAuth discovery requests (default: 10.0). + enabled: Whether OAuth discovery is enabled (default: True). + verify_ssl: Whether to verify SSL certificates during discovery (default: True). + Set to False for self-signed certificates or development environments. + + Note: + OAuth scopes should be specified in the auth_scheme parameter of MCPToolset, + not in this discovery configuration. Discovery only finds token endpoints. + + Examples: + >>> # Override just SSL verification (base_url auto-extracted) + >>> discovery = MCPAuthDiscovery(verify_ssl=False) + >>> toolset = MCPToolset( + ... connection_params=StreamableHTTPConnectionParams(url="https://localhost:9204/mcp/"), + ... auth_credential=credential, + ... auth_discovery=discovery # base_url will be auto-extracted as "https://localhost:9204" + ... ) + + >>> # Override multiple settings + >>> discovery = MCPAuthDiscovery( + ... verify_ssl=False, + ... timeout=15.0 + ... ) + + >>> # Explicit base_url (override auto-extraction) + >>> discovery = MCPAuthDiscovery( + ... base_url="https://auth-server.example.com", + ... verify_ssl=False + ... ) + + >>> # For production with valid SSL certificates + >>> discovery = MCPAuthDiscovery( + ... base_url="https://api.example.com" + ... ) + + >>> # For development with self-signed certificates + >>> discovery = MCPAuthDiscovery( + ... base_url="https://localhost:9204", + ... verify_ssl=False # Disable SSL verification + ... ) + + >>> # Scopes go in auth scheme, not discovery config + >>> auth_scheme = OAuth2( + ... flows=OAuthFlows( + ... clientCredentials=OAuthFlowClientCredentials( + ... tokenUrl="", # Will be discovered + ... scopes={"read": "Read access", "write": "Write access"} + ... ) + ... ) + ... ) + >>> toolset = MCPToolset( + ... connection_params=StreamableHTTPConnectionParams(url="https://localhost:9204/mcp/"), + ... auth_scheme=auth_scheme, + ... auth_credential=credential, + ... auth_discovery=discovery + ... ) + """ + + base_url: Optional[str] = None + timeout: float = 10.0 + enabled: bool = True + verify_ssl: bool = True + + def __post_init__(self): + """Validate and normalize configuration after initialization.""" + # Normalize base URL - remove trailing slash (if provided) + if self.base_url: + self.base_url = self.base_url.rstrip('/') + + # Validate timeout + if self.timeout <= 0: + raise ValueError("Discovery timeout must be positive") + + @property + def is_enabled(self) -> bool: + """Check if OAuth discovery is enabled and properly configured. + + Returns True if discovery is enabled. Note that base_url can be None + since MCPToolset will auto-extract it from connection parameters. + """ + return self.enabled diff --git a/src/google/adk/tools/mcp_tool/mcp_tool.py b/src/google/adk/tools/mcp_tool/mcp_tool.py index af4616cae..95369dd72 100644 --- a/src/google/adk/tools/mcp_tool/mcp_tool.py +++ b/src/google/adk/tools/mcp_tool/mcp_tool.py @@ -84,14 +84,23 @@ def __init__( Raises: ValueError: If mcp_tool or mcp_session_manager is None. """ + logger.debug(f"šŸ”§ MCPTool.__init__() called for '{mcp_tool.name}'") + logger.debug(f"šŸ” auth_scheme provided: {auth_scheme is not None}") + logger.debug(f"šŸ” auth_credential provided: {auth_credential is not None}") + + auth_config = None + if auth_scheme: + auth_config = AuthConfig( + auth_scheme=auth_scheme, raw_auth_credential=auth_credential + ) + logger.debug("āœ… Created AuthConfig - CredentialManager will be initialized") + else: + logger.warning("āŒ No auth_scheme provided - no CredentialManager will be created") + super().__init__( name=mcp_tool.name, description=mcp_tool.description if mcp_tool.description else "", - auth_config=AuthConfig( - auth_scheme=auth_scheme, raw_auth_credential=auth_credential - ) - if auth_scheme - else None, + auth_config=auth_config, ) self._mcp_tool = mcp_tool self._mcp_session_manager = mcp_session_manager diff --git a/src/google/adk/tools/mcp_tool/mcp_toolset.py b/src/google/adk/tools/mcp_tool/mcp_toolset.py index 2fc9d640a..18c8351a7 100644 --- a/src/google/adk/tools/mcp_tool/mcp_toolset.py +++ b/src/google/adk/tools/mcp_tool/mcp_toolset.py @@ -21,12 +21,15 @@ from typing import TextIO from typing import Union +from fastapi.openapi.models import OAuth2 from ...agents.readonly_context import ReadonlyContext from ...auth.auth_credential import AuthCredential from ...auth.auth_schemes import AuthScheme +from ...auth.oauth2_discovery_util import create_oauth_scheme_from_discovery from ..base_tool import BaseTool from ..base_toolset import BaseToolset from ..base_toolset import ToolPredicate +from .mcp_auth_discovery import MCPAuthDiscovery from .mcp_session_manager import MCPSessionManager from .mcp_session_manager import retry_on_closed_resource from .mcp_session_manager import SseConnectionParams @@ -61,16 +64,43 @@ class MCPToolset(BaseToolset): that can be used by an agent. It properly implements the BaseToolset interface for easy integration with the agent framework. + **OAuth Discovery by Default**: MCPToolset automatically attempts OAuth2 discovery + for HTTP-based connections (StreamableHTTP, SSE) unless explicitly disabled. This + provides seamless authentication setup without manual configuration. + Usage:: + # Basic usage with automatic OAuth discovery (default behavior) toolset = MCPToolset( - connection_params=StdioServerParameters( - command='npx', - args=["-y", "@modelcontextprotocol/server-filesystem"], + connection_params=StreamableHTTPConnectionParams( + url='http://localhost:9204/mcp/', ), + auth_credential=oauth2_credential, # OAuth discovery will find token endpoint tool_filter=['read_file', 'list_directory'] # Optional: filter specific tools ) + # Explicit OAuth discovery configuration (overrides default) + toolset = MCPToolset( + connection_params=StreamableHTTPConnectionParams( + url='http://localhost:9204/mcp/', + ), + auth_credential=oauth2_credential, + auth_discovery=MCPAuthDiscovery( + base_url='http://custom-auth-server:9205', # Different auth server + timeout=15.0 + ), + tool_filter=['read_file', 'list_directory'] + ) + + # Disable OAuth discovery completely + toolset = MCPToolset( + connection_params=StreamableHTTPConnectionParams( + url='http://localhost:9204/mcp/', + ), + auth_credential=oauth2_credential, + auth_discovery=MCPAuthDiscovery(enabled=False), # Explicitly disabled + ) + # Use in an agent agent = LlmAgent( model='gemini-2.0-flash', @@ -97,6 +127,7 @@ def __init__( errlog: TextIO = sys.stderr, auth_scheme: Optional[AuthScheme] = None, auth_credential: Optional[AuthCredential] = None, + auth_discovery: Optional[MCPAuthDiscovery] = None, ): """Initializes the MCPToolset. @@ -113,8 +144,12 @@ def __init__( list of tool names to include - A ToolPredicate function for custom filtering logic errlog: TextIO stream for error logging. - auth_scheme: The auth scheme of the tool for tool calling + auth_scheme: The auth scheme of the tool for tool calling. If not provided + and OAuth discovery succeeds, a discovered scheme will be used. auth_credential: The auth credential of the tool for tool calling + auth_discovery: Optional OAuth discovery configuration. If not provided, + automatic OAuth discovery will be enabled for HTTP-based connections using + the server's base URL. Set to MCPAuthDiscovery(enabled=False) to disable. """ super().__init__(tool_filter=tool_filter) @@ -131,6 +166,162 @@ def __init__( ) self._auth_scheme = auth_scheme self._auth_credential = auth_credential + + # Default OAuth discovery behavior: Auto-enable for HTTP connections + if auth_discovery is None: + auth_discovery = self._create_default_auth_discovery() + + self._auth_discovery = auth_discovery + self._oauth_discovery_attempted = False + + def _create_default_auth_discovery(self) -> MCPAuthDiscovery: + """Create default OAuth discovery configuration from connection parameters. + + Returns: + MCPAuthDiscovery instance configured for the connection, or disabled + if the connection type doesn't support OAuth discovery. + """ + # Extract base URL from HTTP-based connection parameters + base_url = None + + if isinstance(self._connection_params, StreamableHTTPConnectionParams): + # Extract server root from MCP URL + full_url = self._connection_params.url + from urllib.parse import urlparse + parsed = urlparse(full_url) + base_url = f"{parsed.scheme}://{parsed.netloc}" + logger.debug(f"Auto-detected OAuth discovery base URL: {base_url} from MCP URL: {full_url}") + + elif isinstance(self._connection_params, SseConnectionParams): + # Extract server root from SSE URL + full_url = self._connection_params.url + from urllib.parse import urlparse + parsed = urlparse(full_url) + base_url = f"{parsed.scheme}://{parsed.netloc}" + logger.debug(f"Auto-detected OAuth discovery base URL: {base_url} from SSE URL: {full_url}") + + if base_url: + logger.debug(f"āœ… Enabling default OAuth discovery for HTTP connection at: {base_url}") + return MCPAuthDiscovery( + base_url=base_url, + timeout=10.0, + enabled=True + ) + else: + # For Stdio connections, OAuth discovery is not applicable + logger.debug("āŒ Disabling OAuth discovery for non-HTTP connection (Stdio)") + return MCPAuthDiscovery( + enabled=False + ) + + async def _perform_oauth_discovery(self) -> None: + """Perform OAuth discovery if enabled and not already attempted.""" + logger.debug("šŸ” _perform_oauth_discovery() called") + logger.debug(f"šŸ” auth_discovery: {self._auth_discovery}") + logger.debug(f"šŸ” current auth_scheme: {self._auth_scheme}") + + if ( + not self._auth_discovery or not self._auth_discovery.is_enabled + or self._oauth_discovery_attempted + ): + logger.debug("āŒ OAuth discovery skipped (not enabled or already attempted)") + return + + # Check if we need discovery even when auth_scheme is provided + needs_discovery = False + + if self._auth_scheme is None: + # No auth scheme provided - definitely need discovery + needs_discovery = True + logger.debug("šŸŽÆ OAuth discovery needed: no auth scheme provided") + elif isinstance(self._auth_scheme, OAuth2): + # Check if OAuth2 scheme has client credentials flow with empty/invalid tokenUrl + if (self._auth_scheme.flows and + self._auth_scheme.flows.clientCredentials and + (not self._auth_scheme.flows.clientCredentials.tokenUrl or + self._auth_scheme.flows.clientCredentials.tokenUrl.strip() == "")): + needs_discovery = True + logger.debug("šŸŽÆ OAuth discovery needed: empty tokenUrl in existing scheme") + + if not needs_discovery: + logger.debug("āŒ OAuth discovery not needed") + return + + self._oauth_discovery_attempted = True + logger.debug("šŸš€ Starting OAuth discovery process") + + # Determine the discovery base URL + if self._auth_discovery.base_url: + # Use explicitly configured base URL + base_url = self._auth_discovery.base_url + logger.debug(f"Using explicitly configured discovery base URL: {base_url}") + else: + # Auto-extract base URL from connection parameters (same logic as _create_default_auth_discovery) + base_url = None + if isinstance(self._connection_params, StreamableHTTPConnectionParams): + # Extract server root from HTTP URL + full_url = self._connection_params.url + from urllib.parse import urlparse + parsed = urlparse(full_url) + base_url = f"{parsed.scheme}://{parsed.netloc}" + logger.debug(f"Auto-extracted OAuth discovery base URL: {base_url} from MCP URL: {full_url}") + + elif isinstance(self._connection_params, SseConnectionParams): + # Extract server root from SSE URL + full_url = self._connection_params.url + from urllib.parse import urlparse + parsed = urlparse(full_url) + base_url = f"{parsed.scheme}://{parsed.netloc}" + logger.debug(f"Auto-extracted OAuth discovery base URL: {base_url} from SSE URL: {full_url}") + + if not base_url: + logger.debug("āŒ Cannot auto-extract base URL for OAuth discovery (non-HTTP connection)") + return + + try: + logger.debug(f"šŸ” Attempting OAuth discovery at server root: {base_url}") + + # Extract scopes from existing auth scheme if available + discovery_scopes = None + if (isinstance(self._auth_scheme, OAuth2) and + self._auth_scheme.flows and + self._auth_scheme.flows.clientCredentials and + self._auth_scheme.flows.clientCredentials.scopes): + # Use scopes from existing auth scheme + discovery_scopes = list(self._auth_scheme.flows.clientCredentials.scopes.keys()) + logger.debug(f"Using scopes from auth scheme: {discovery_scopes}") + + discovered_scheme = await create_oauth_scheme_from_discovery( + base_url=base_url, + scopes=discovery_scopes, + timeout=self._auth_discovery.timeout, + verify_ssl=self._auth_discovery.verify_ssl + ) + + if discovered_scheme: + if self._auth_scheme is None: + # No existing scheme - use discovered scheme entirely + logger.debug("āœ… OAuth discovery successful - using discovered configuration") + self._auth_scheme = discovered_scheme + else: + # Existing scheme with empty tokenUrl - merge discovered tokenUrl + logger.debug("āœ… OAuth discovery successful - updating tokenUrl in existing scheme") + if (isinstance(self._auth_scheme, OAuth2) and + self._auth_scheme.flows and + self._auth_scheme.flows.clientCredentials and + isinstance(discovered_scheme, OAuth2) and + discovered_scheme.flows and + discovered_scheme.flows.clientCredentials): + # Update the tokenUrl with discovered value + self._auth_scheme.flows.clientCredentials.tokenUrl = discovered_scheme.flows.clientCredentials.tokenUrl + logger.debug(f"Updated tokenUrl to: {discovered_scheme.flows.clientCredentials.tokenUrl}") + else: + logger.debug("āŒ OAuth discovery failed - no valid configuration found") + + except Exception as e: + logger.warning(f"āŒ OAuth discovery failed with error: {e}") + + logger.debug(f"āœ… OAuth discovery completed. Final auth_scheme: {self._auth_scheme}") @retry_on_closed_resource async def get_tools( @@ -146,14 +337,80 @@ async def get_tools( Returns: List[BaseTool]: A list of tools available under the specified context. """ - # Get session from session manager - session = await self._mcp_session_manager.create_session() + # Perform OAuth discovery if needed + await self._perform_oauth_discovery() + + # Perform OAuth token exchange before session creation if we have auth + session_headers = None + if self._auth_scheme and self._auth_credential: + logger.debug("šŸ” Performing OAuth token exchange for session authentication") + + # Get verify_ssl setting from auth_discovery configuration + verify_ssl = True + if self._auth_discovery and hasattr(self._auth_discovery, 'verify_ssl'): + verify_ssl = self._auth_discovery.verify_ssl + + # Create a temporary CredentialManager to exchange tokens + from ...auth.auth_tool import AuthConfig + from ...auth.credential_manager import CredentialManager + + auth_config = AuthConfig( + auth_scheme=self._auth_scheme, + raw_auth_credential=self._auth_credential + ) + + credential_manager = CredentialManager(auth_config) + + # Create a dummy callback context for token exchange + # This is a simplified approach - in a full implementation this would come from the agent + class DummyCallbackContext: + def __init__(self): + from ...agents.readonly_context import ReadonlyContext + self._invocation_context = type('obj', (object,), { + 'credential_service': None, + 'app_name': 'mcp_toolset', + 'user_id': 'system' + })() + + def get_auth_response(self, auth_config): + """Return None since we're using raw credentials for client credentials flow.""" + return None + + async def load_credential(self, auth_config): + """Return None since no stored credentials.""" + return None + + async def save_credential(self, auth_config): + """No-op for dummy context.""" + pass + + dummy_context = DummyCallbackContext() + + try: + # Exchange credentials to get access token with SSL verification setting + exchanged_credential = await credential_manager.get_auth_credential(dummy_context, verify_ssl) + + if exchanged_credential and exchanged_credential.oauth2 and exchanged_credential.oauth2.access_token: + logger.debug(f"āœ… Successfully obtained access token for session") + session_headers = {"Authorization": f"Bearer {exchanged_credential.oauth2.access_token}"} + else: + logger.debug("āŒ Failed to obtain access token for session") + except Exception as e: + logger.debug(f"āŒ OAuth token exchange failed: {e}") + + # Get session from session manager with OAuth headers + session = await self._mcp_session_manager.create_session(headers=session_headers) # Fetch available tools from the MCP server + logger.debug("šŸ” Calling session.list_tools()") tools_response: ListToolsResult = await session.list_tools() + logger.debug(f"āœ… Retrieved {len(tools_response.tools)} tools from MCP server") # Apply filtering based on context and tool_filter tools = [] + logger.debug(f"šŸ” Creating MCPTools with auth_scheme: {self._auth_scheme}") + logger.debug(f"šŸ” Auth credential: {self._auth_credential}") + for tool in tools_response.tools: mcp_tool = MCPTool( mcp_tool=tool, @@ -161,8 +418,15 @@ async def get_tools( auth_scheme=self._auth_scheme, auth_credential=self._auth_credential, ) + + logger.debug(f"āœ… Created MCPTool '{tool.name}' with auth_config: {mcp_tool._credentials_manager is not None}") - if self._is_tool_selected(mcp_tool, readonly_context): + # Handle None readonly_context for _is_tool_selected method + if readonly_context is None: + # When no context provided, include tool if tool_filter allows it + if not self.tool_filter or (isinstance(self.tool_filter, list) and mcp_tool.name in self.tool_filter): + tools.append(mcp_tool) + elif self._is_tool_selected(mcp_tool, readonly_context): tools.append(mcp_tool) return tools diff --git a/tests/unittests/auth/exchanger/test_oauth2_credential_exchanger.py b/tests/unittests/auth/exchanger/test_oauth2_credential_exchanger.py index 31288511e..8f82104d3 100644 --- a/tests/unittests/auth/exchanger/test_oauth2_credential_exchanger.py +++ b/tests/unittests/auth/exchanger/test_oauth2_credential_exchanger.py @@ -17,9 +17,14 @@ from unittest.mock import patch from authlib.oauth2.rfc6749 import OAuth2Token +from fastapi.openapi.models import OAuth2 +from fastapi.openapi.models import OAuthFlowClientCredentials +from fastapi.openapi.models import OAuthFlowAuthorizationCode +from fastapi.openapi.models import OAuthFlows from google.adk.auth.auth_credential import AuthCredential from google.adk.auth.auth_credential import AuthCredentialTypes from google.adk.auth.auth_credential import OAuth2Auth +from google.adk.auth.auth_schemes import OAuthGrantType from google.adk.auth.auth_schemes import OpenIdConnectWithConfig from google.adk.auth.exchanger.base_credential_exchanger import CredentialExchangError from google.adk.auth.exchanger.oauth2_credential_exchanger import OAuth2CredentialExchanger @@ -218,3 +223,349 @@ async def test_exchange_authlib_not_available(self): # Should return original credential when authlib is not available assert result == credential assert result.oauth2.access_token is None + + +class TestOAuth2CredentialExchangerSSLVerification: + """Test suite for OAuth2CredentialExchanger SSL verification functionality.""" + + @patch("google.adk.auth.exchanger.oauth2_credential_exchanger.OAuth2Session") + @pytest.mark.asyncio + async def test_exchange_client_credentials_ssl_verification_enabled(self, mock_oauth2_session): + """Test client credentials exchange with SSL verification enabled (default).""" + from fastapi.openapi.models import OAuth2, OAuthFlows, OAuthFlowClientCredentials + + # Setup mock + mock_client = Mock() + mock_oauth2_session.return_value = mock_client + mock_tokens = OAuth2Token({ + "access_token": "test_access_token", + "expires_at": int(time.time()) + 3600, + "expires_in": 3600, + }) + mock_client.fetch_token.return_value = mock_tokens + + # Create OAuth2 scheme with client credentials flow + scheme = OAuth2( + flows=OAuthFlows( + clientCredentials=OAuthFlowClientCredentials( + tokenUrl="https://example.com/token", + scopes={"read": "Read access"} + ) + ) + ) + + credential = AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + oauth2=OAuth2Auth( + client_id="test_client_id", + client_secret="test_client_secret", + ), + ) + + exchanger = OAuth2CredentialExchanger() + result = await exchanger.exchange(credential, scheme, verify_ssl=True) + + # Verify SSL verification is enabled by default + assert hasattr(mock_client, 'verify') + assert mock_client.verify is True + + # Verify token exchange was successful + assert result.oauth2.access_token == "test_access_token" + + @patch("google.adk.auth.exchanger.oauth2_credential_exchanger.urllib3") + @patch("google.adk.auth.exchanger.oauth2_credential_exchanger.OAuth2Session") + @pytest.mark.asyncio + async def test_exchange_client_credentials_ssl_verification_disabled(self, mock_oauth2_session, mock_urllib3): + """Test client credentials exchange with SSL verification disabled.""" + from fastapi.openapi.models import OAuth2, OAuthFlows, OAuthFlowClientCredentials + + # Setup mock + mock_client = Mock() + mock_oauth2_session.return_value = mock_client + mock_tokens = OAuth2Token({ + "access_token": "test_access_token", + "expires_at": int(time.time()) + 3600, + "expires_in": 3600, + }) + mock_client.fetch_token.return_value = mock_tokens + + # Create OAuth2 scheme with client credentials flow + scheme = OAuth2( + flows=OAuthFlows( + clientCredentials=OAuthFlowClientCredentials( + tokenUrl="https://localhost:9204/token", # Self-signed SSL scenario + scopes={"read": "Read access"} + ) + ) + ) + + credential = AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + oauth2=OAuth2Auth( + client_id="test_client_id", + client_secret="test_client_secret", + ), + ) + + exchanger = OAuth2CredentialExchanger() + result = await exchanger.exchange(credential, scheme, verify_ssl=False) + + # Verify SSL verification is disabled + assert hasattr(mock_client, 'verify') + assert mock_client.verify is False + + # Verify SSL warnings are suppressed + mock_urllib3.disable_warnings.assert_called_once_with(mock_urllib3.exceptions.InsecureRequestWarning) + + # Verify token exchange was successful + assert result.oauth2.access_token == "test_access_token" + + @patch("google.adk.auth.exchanger.oauth2_credential_exchanger.OAuth2Session") + @pytest.mark.asyncio + async def test_exchange_client_credentials_ssl_verification_default_true(self, mock_oauth2_session): + """Test that SSL verification defaults to True when not specified.""" + from fastapi.openapi.models import OAuth2, OAuthFlows, OAuthFlowClientCredentials + + # Setup mock + mock_client = Mock() + mock_oauth2_session.return_value = mock_client + mock_tokens = OAuth2Token({ + "access_token": "test_access_token", + "expires_at": int(time.time()) + 3600, + "expires_in": 3600, + }) + mock_client.fetch_token.return_value = mock_tokens + + # Create OAuth2 scheme with client credentials flow + scheme = OAuth2( + flows=OAuthFlows( + clientCredentials=OAuthFlowClientCredentials( + tokenUrl="https://example.com/token", + scopes={"read": "Read access"} + ) + ) + ) + + credential = AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + oauth2=OAuth2Auth( + client_id="test_client_id", + client_secret="test_client_secret", + ), + ) + + exchanger = OAuth2CredentialExchanger() + # Call without verify_ssl parameter - should default to True + result = await exchanger.exchange(credential, scheme) + + # Verify SSL verification defaults to True + assert hasattr(mock_client, 'verify') + assert mock_client.verify is True + + # Verify token exchange was successful + assert result.oauth2.access_token == "test_access_token" + + +class TestOAuth2CredentialExchangerClientCredentials: + """Test suite for OAuth2CredentialExchanger client credentials flow.""" + + @patch("google.adk.auth.exchanger.oauth2_credential_exchanger.OAuth2Session") + @pytest.mark.asyncio + async def test_exchange_client_credentials_success(self, mock_oauth2_session): + """Test successful client credentials token exchange.""" + from fastapi.openapi.models import OAuth2, OAuthFlows, OAuthFlowClientCredentials + + # Setup mock + mock_client = Mock() + mock_oauth2_session.return_value = mock_client + mock_tokens = OAuth2Token({ + "access_token": "client_creds_access_token", + "expires_at": int(time.time()) + 3600, + "expires_in": 3600, + }) + mock_client.fetch_token.return_value = mock_tokens + + # Create OAuth2 scheme with client credentials flow + scheme = OAuth2( + flows=OAuthFlows( + clientCredentials=OAuthFlowClientCredentials( + tokenUrl="https://example.com/token", + scopes={"read": "Read access", "write": "Write access"} + ) + ) + ) + + credential = AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + oauth2=OAuth2Auth( + client_id="test_client_id", + client_secret="test_client_secret", + ), + ) + + exchanger = OAuth2CredentialExchanger() + result = await exchanger.exchange(credential, scheme) + + # Verify token exchange was successful + assert result.oauth2.access_token == "client_creds_access_token" + # Client credentials flow doesn't provide refresh tokens + assert result.oauth2.refresh_token is None or result.oauth2.refresh_token == "None" + + # Verify the correct grant type was used + mock_client.fetch_token.assert_called_once() + call_args = mock_client.fetch_token.call_args + assert call_args[1]["grant_type"] == "client_credentials" + + @pytest.mark.asyncio + async def test_exchange_client_credentials_missing_client_secret(self): + """Test client credentials exchange with missing client secret.""" + from fastapi.openapi.models import OAuth2, OAuthFlows, OAuthFlowClientCredentials + + scheme = OAuth2( + flows=OAuthFlows( + clientCredentials=OAuthFlowClientCredentials( + tokenUrl="https://example.com/token", + scopes={"read": "Read access"} + ) + ) + ) + + credential = AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + oauth2=OAuth2Auth( + client_id="test_client_id", + # Missing client_secret + ), + ) + + exchanger = OAuth2CredentialExchanger() + result = await exchanger.exchange(credential, scheme) + + # Should return original credential when client secret is missing + assert result == credential + assert result.oauth2.access_token is None + + @patch("google.adk.auth.exchanger.oauth2_credential_exchanger.OAuth2Session") + @pytest.mark.asyncio + async def test_exchange_client_credentials_token_fetch_failure(self, mock_oauth2_session): + """Test client credentials exchange when token fetch fails.""" + from fastapi.openapi.models import OAuth2, OAuthFlows, OAuthFlowClientCredentials + + # Setup mock to raise exception during fetch_token + mock_client = Mock() + mock_oauth2_session.return_value = mock_client + mock_client.fetch_token.side_effect = Exception("Token fetch failed") + + scheme = OAuth2( + flows=OAuthFlows( + clientCredentials=OAuthFlowClientCredentials( + tokenUrl="https://example.com/token", + scopes={"read": "Read access"} + ) + ) + ) + + credential = AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + oauth2=OAuth2Auth( + client_id="test_client_id", + client_secret="test_client_secret", + ), + ) + + exchanger = OAuth2CredentialExchanger() + result = await exchanger.exchange(credential, scheme) + + # Should return original credential when fetch_token fails + assert result == credential + assert result.oauth2.access_token is None + mock_client.fetch_token.assert_called_once() + + def test_get_grant_type_client_credentials(self): + """Test grant type detection for client credentials flow.""" + from fastapi.openapi.models import OAuth2, OAuthFlows, OAuthFlowClientCredentials + + scheme = OAuth2( + flows=OAuthFlows( + clientCredentials=OAuthFlowClientCredentials( + tokenUrl="https://example.com/token", + scopes={"read": "Read access"} + ) + ) + ) + + exchanger = OAuth2CredentialExchanger() + grant_type = exchanger._get_grant_type(scheme) + + assert grant_type == OAuthGrantType.CLIENT_CREDENTIALS + + def test_get_grant_type_authorization_code(self): + """Test grant type detection for authorization code flow.""" + from fastapi.openapi.models import OAuth2, OAuthFlows, OAuthFlowAuthorizationCode + + scheme = OAuth2( + flows=OAuthFlows( + authorizationCode=OAuthFlowAuthorizationCode( + authorizationUrl="https://example.com/auth", + tokenUrl="https://example.com/token", + scopes={"read": "Read access"} + ) + ) + ) + + exchanger = OAuth2CredentialExchanger() + grant_type = exchanger._get_grant_type(scheme) + + assert grant_type == OAuthGrantType.AUTHORIZATION_CODE + + def test_get_grant_type_mixed_flows_prioritizes_client_credentials(self): + """Test that client credentials is prioritized when multiple flows are present.""" + from fastapi.openapi.models import OAuth2, OAuthFlows, OAuthFlowClientCredentials, OAuthFlowAuthorizationCode + + # Create scheme with both client credentials and authorization code flows + scheme = OAuth2( + flows=OAuthFlows( + clientCredentials=OAuthFlowClientCredentials( + tokenUrl="https://example.com/token", + scopes={"read": "Read access"} + ), + authorizationCode=OAuthFlowAuthorizationCode( + authorizationUrl="https://example.com/auth", + tokenUrl="https://example.com/token", + scopes={"read": "Read access"} + ) + ) + ) + + exchanger = OAuth2CredentialExchanger() + grant_type = exchanger._get_grant_type(scheme) + + # Should prioritize client credentials + assert grant_type == OAuthGrantType.CLIENT_CREDENTIALS + + def test_get_grant_type_no_flows(self): + """Test grant type detection when no flows are configured.""" + from fastapi.openapi.models import OAuth2, OAuthFlows + + scheme = OAuth2(flows=OAuthFlows()) + + exchanger = OAuth2CredentialExchanger() + grant_type = exchanger._get_grant_type(scheme) + + assert grant_type is None + + def test_get_grant_type_non_oauth2_scheme(self): + """Test grant type detection for non-OAuth2 schemes.""" + from google.adk.auth.auth_schemes import OpenIdConnectWithConfig + + scheme = OpenIdConnectWithConfig( + openId_connect_url="https://example.com/.well-known/openid_configuration", + authorization_endpoint="https://example.com/auth", + token_endpoint="https://example.com/token", + scopes=["openid"], + ) + + exchanger = OAuth2CredentialExchanger() + grant_type = exchanger._get_grant_type(scheme) + + assert grant_type is None diff --git a/tests/unittests/auth/test_credential_manager_oauth2_integration.py b/tests/unittests/auth/test_credential_manager_oauth2_integration.py new file mode 100644 index 000000000..8a14253d9 --- /dev/null +++ b/tests/unittests/auth/test_credential_manager_oauth2_integration.py @@ -0,0 +1,331 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for OAuth2CredentialExchanger integration in CredentialManager.""" + +from unittest.mock import patch +import pytest + +from fastapi.openapi.models import OAuth2, OAuthFlows, OAuthFlowClientCredentials +from google.adk.auth.auth_credential import AuthCredential, AuthCredentialTypes, OAuth2Auth +from google.adk.auth.auth_tool import AuthConfig +from google.adk.auth.credential_manager import CredentialManager +from google.adk.auth.exchanger.oauth2_credential_exchanger import OAuth2CredentialExchanger + + +class TestCredentialManagerOAuth2Integration: + """Test OAuth2CredentialExchanger integration with CredentialManager.""" + + def test_oauth2_credential_exchanger_registration(self): + """Test that OAuth2CredentialExchanger is registered in CredentialManager.""" + auth_scheme = OAuth2( + flows=OAuthFlows( + clientCredentials=OAuthFlowClientCredentials( + tokenUrl="http://localhost:9204/token", + scopes={"api:read": "Read access"} + ) + ) + ) + + auth_credential = AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + oauth2=OAuth2Auth( + client_id="test_client_id", + client_secret="test_client_secret" + ) + ) + + auth_config = AuthConfig( + auth_scheme=auth_scheme, + raw_auth_credential=auth_credential + ) + + credential_manager = CredentialManager(auth_config) + + # Verify OAuth2CredentialExchanger is registered + exchanger = credential_manager._exchanger_registry.get_exchanger(AuthCredentialTypes.OAUTH2) + assert exchanger is not None + assert isinstance(exchanger, OAuth2CredentialExchanger) + + @patch('google.adk.auth.exchanger.oauth2_credential_exchanger.OAuth2Session') + @pytest.mark.asyncio + async def test_oauth2_credential_exchange_flow(self, mock_oauth2_session): + """Test complete OAuth2 credential exchange flow through CredentialManager.""" + # Setup mock OAuth2Session + mock_session_instance = mock_oauth2_session.return_value + mock_session_instance.fetch_token.return_value = { + 'access_token': 'test_access_token', + 'token_type': 'Bearer', + 'expires_in': 3600 + } + + auth_scheme = OAuth2( + flows=OAuthFlows( + clientCredentials=OAuthFlowClientCredentials( + tokenUrl="http://localhost:9204/token", + scopes={"api:read": "Read access"} + ) + ) + ) + + auth_credential = AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + oauth2=OAuth2Auth( + client_id="test_client_id", + client_secret="test_client_secret" + ) + ) + + auth_config = AuthConfig( + auth_scheme=auth_scheme, + raw_auth_credential=auth_credential + ) + + credential_manager = CredentialManager(auth_config) + + # Mock callback context + class MockCallbackContext: + def get_auth_response(self, auth_config): + return None + async def load_credential(self, auth_config): + return None + async def save_credential(self, auth_config): + pass + + callback_context = MockCallbackContext() + + # Perform credential exchange + result_credential = await credential_manager.get_auth_credential(callback_context) # type: ignore + + # Verify OAuth2Session was called with correct parameters + mock_oauth2_session.assert_called_once_with( + "test_client_id", + "test_client_secret", + scope="api:read", + token_endpoint_auth_method='client_secret_post' + ) + + # Verify fetch_token was called with correct parameters + mock_session_instance.fetch_token.assert_called_once_with( + "http://localhost:9204/token", + grant_type="client_credentials" + ) + + # Verify result credential has access token + assert result_credential is not None + assert result_credential.oauth2 is not None + assert result_credential.oauth2.access_token == "test_access_token" + + def test_oauth2_credential_fallback_when_no_exchanger_found(self): + """Test that CredentialManager falls back to raw OAuth2 credential when needed.""" + auth_scheme = OAuth2( + flows=OAuthFlows( + clientCredentials=OAuthFlowClientCredentials( + tokenUrl="http://localhost:9204/token", + scopes={"api:read": "Read access"} + ) + ) + ) + + auth_credential = AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + oauth2=OAuth2Auth( + client_id="test_client_id", + client_secret="test_client_secret" + ) + ) + + auth_config = AuthConfig( + auth_scheme=auth_scheme, + raw_auth_credential=auth_credential + ) + + credential_manager = CredentialManager(auth_config) + + # Verify that the raw OAuth2 credential fallback logic exists + # This is tested indirectly by verifying the _is_credential_ready method + assert not credential_manager._is_credential_ready() + + # Verify that OAuth2 credentials are handled appropriately + assert auth_credential.auth_type == AuthCredentialTypes.OAUTH2 + assert auth_credential.oauth2 is not None + assert auth_credential.oauth2.client_id == "test_client_id" + + async def test_oauth2_exchanger_processing_with_token(self): + """Test that OAuth2CredentialExchanger properly processes credentials with existing tokens.""" + + from google.adk.auth.auth_tool import AuthConfig + from google.adk.auth.credential_manager import CredentialManager + from google.adk.auth.auth_credential import AuthCredential, AuthCredentialTypes, OAuth2Auth + from fastapi.openapi.models import OAuth2, OAuthFlows, OAuthFlowClientCredentials + + # Create OAuth2 credential with existing token + oauth2_credential = AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + oauth2=OAuth2Auth( + client_id="test_client", + client_secret="test_secret", + access_token="existing_token" + ) + ) + + # Create OAuth2 auth scheme + auth_scheme = OAuth2( + flows=OAuthFlows( + clientCredentials=OAuthFlowClientCredentials( + tokenUrl="https://example.com/token", + scopes={"read": "Read access"} + ) + ) + ) + + auth_config = AuthConfig( + raw_auth_credential=oauth2_credential, + auth_scheme=auth_scheme + ) + + manager = CredentialManager(auth_config) + + # Verify OAuth2CredentialExchanger is registered + exchanger = manager._exchanger_registry.get_exchanger(AuthCredentialTypes.OAUTH2) + assert exchanger is not None + assert isinstance(exchanger, OAuth2CredentialExchanger) + + + @patch('google.adk.auth.exchanger.oauth2_credential_exchanger.OAuth2Session') + async def test_credential_manager_ssl_verification_enabled(self, mock_oauth2_session): + """Test CredentialManager passes verify_ssl=True to OAuth2CredentialExchanger.""" + from google.adk.auth.auth_tool import AuthConfig + from google.adk.auth.credential_manager import CredentialManager + from google.adk.auth.auth_credential import AuthCredential, AuthCredentialTypes, OAuth2Auth + from google.adk.agents.readonly_context import ReadonlyContext + from fastapi.openapi.models import OAuth2, OAuthFlows, OAuthFlowClientCredentials + + # Setup mock + mock_client = Mock() + mock_oauth2_session.return_value = mock_client + mock_tokens = { + "access_token": "test_access_token", + "expires_at": int(time.time()) + 3600, + "expires_in": 3600, + } + mock_client.fetch_token.return_value = mock_tokens + + # Create OAuth2 credential + oauth2_credential = AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + oauth2=OAuth2Auth( + client_id="test_client", + client_secret="test_secret" + ) + ) + + # Create OAuth2 auth scheme + auth_scheme = OAuth2( + flows=OAuthFlows( + clientCredentials=OAuthFlowClientCredentials( + tokenUrl="https://example.com/token", + scopes={"read": "Read access"} + ) + ) + ) + + auth_config = AuthConfig( + raw_auth_credential=oauth2_credential, + auth_scheme=auth_scheme + ) + + manager = CredentialManager(auth_config) + + # Create mock callback context + mock_context = Mock() + mock_context._invocation_context = Mock() + mock_context._invocation_context.credential_service = None + mock_context.get_auth_response = Mock(return_value=None) + mock_context.load_credential = AsyncMock(return_value=None) + mock_context.save_credential = AsyncMock() + + # Test with SSL verification enabled (default) + result = await manager.get_auth_credential(mock_context, verify_ssl=True) + + # Verify SSL verification is enabled + assert hasattr(mock_client, 'verify') + assert mock_client.verify is True + assert result is not None + + + @patch('google.adk.auth.exchanger.oauth2_credential_exchanger.urllib3') + @patch('google.adk.auth.exchanger.oauth2_credential_exchanger.OAuth2Session') + async def test_credential_manager_ssl_verification_disabled(self, mock_oauth2_session, mock_urllib3): + """Test CredentialManager passes verify_ssl=False to OAuth2CredentialExchanger.""" + from google.adk.auth.auth_tool import AuthConfig + from google.adk.auth.credential_manager import CredentialManager + from google.adk.auth.auth_credential import AuthCredential, AuthCredentialTypes, OAuth2Auth + from fastapi.openapi.models import OAuth2, OAuthFlows, OAuthFlowClientCredentials + + # Setup mock + mock_client = Mock() + mock_oauth2_session.return_value = mock_client + mock_tokens = { + "access_token": "test_access_token", + "expires_at": int(time.time()) + 3600, + "expires_in": 3600, + } + mock_client.fetch_token.return_value = mock_tokens + + # Create OAuth2 credential + oauth2_credential = AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + oauth2=OAuth2Auth( + client_id="test_client", + client_secret="test_secret" + ) + ) + + # Create OAuth2 auth scheme + auth_scheme = OAuth2( + flows=OAuthFlows( + clientCredentials=OAuthFlowClientCredentials( + tokenUrl="https://localhost:9204/token", # Self-signed SSL scenario + scopes={"read": "Read access"} + ) + ) + ) + + auth_config = AuthConfig( + raw_auth_credential=oauth2_credential, + auth_scheme=auth_scheme + ) + + manager = CredentialManager(auth_config) + + # Create mock callback context + mock_context = Mock() + mock_context._invocation_context = Mock() + mock_context._invocation_context.credential_service = None + mock_context.get_auth_response = Mock(return_value=None) + mock_context.load_credential = AsyncMock(return_value=None) + mock_context.save_credential = AsyncMock() + + # Test with SSL verification disabled (for self-signed certificates) + result = await manager.get_auth_credential(mock_context, verify_ssl=False) + + # Verify SSL verification is disabled + assert hasattr(mock_client, 'verify') + assert mock_client.verify is False + + # Verify SSL warnings are suppressed + mock_urllib3.disable_warnings.assert_called_once_with(mock_urllib3.exceptions.InsecureRequestWarning) + + assert result is not None \ No newline at end of file diff --git a/tests/unittests/auth/test_oauth2_discovery_util.py b/tests/unittests/auth/test_oauth2_discovery_util.py new file mode 100644 index 000000000..5806a38dc --- /dev/null +++ b/tests/unittests/auth/test_oauth2_discovery_util.py @@ -0,0 +1,439 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for OAuth2 discovery utilities.""" + +import json +from unittest.mock import AsyncMock, Mock, patch + +import httpx +import pytest +from fastapi.openapi.models import OAuth2, OAuthFlows, OAuthFlowClientCredentials + +# Import from the correct path +import sys +sys.path.insert(0, '../../../src') + +from google.adk.auth.oauth2_discovery_util import ( + discover_oauth_configuration, + create_oauth_scheme_from_discovery, + _validate_oauth_discovery_response, + OAUTH_PROTECTED_RESOURCE_DISCOVERY, + OAUTH_AUTHORIZATION_SERVER_DISCOVERY, +) + + +class TestDiscoverOAuthConfiguration: + """Test suite for discover_oauth_configuration function.""" + + @patch("google.adk.auth.oauth2_discovery_util.httpx.AsyncClient") + @pytest.mark.asyncio + async def test_discovery_oauth_protected_resource_success(self, mock_async_client): + """Test successful OAuth discovery using oauth-protected-resource endpoint.""" + # Setup mock response + mock_response = Mock() + mock_response.json.return_value = { + "authorization_servers": ["https://auth.example.com"] + } + mock_response.raise_for_status = Mock() + + mock_client = AsyncMock() + mock_client.get.return_value = mock_response + mock_async_client.return_value.__aenter__.return_value = mock_client + + # Test discovery + result = await discover_oauth_configuration("https://api.example.com") + + # Verify result + assert result == {"authorization_servers": ["https://auth.example.com"]} + + # Verify first endpoint was called + mock_client.get.assert_called_with( + f"https://api.example.com/{OAUTH_PROTECTED_RESOURCE_DISCOVERY}" + ) + + @patch("google.adk.auth.oauth2_discovery_util.httpx.AsyncClient") + @pytest.mark.asyncio + async def test_discovery_oauth_authorization_server_success(self, mock_async_client): + """Test successful OAuth discovery using oauth-authorization-server endpoint.""" + # Setup mock to fail on first endpoint, succeed on second + mock_response1 = Mock() + mock_response1.raise_for_status.side_effect = httpx.HTTPStatusError( + "Not found", request=Mock(), response=Mock(status_code=404) + ) + + mock_response2 = Mock() + mock_response2.json.return_value = { + "token_endpoint": "https://auth.example.com/token" + } + mock_response2.raise_for_status = Mock() + + mock_client = AsyncMock() + mock_client.get.side_effect = [mock_response1, mock_response2] + mock_async_client.return_value.__aenter__.return_value = mock_client + + # Test discovery + result = await discover_oauth_configuration("https://api.example.com") + + # Verify result + assert result == {"token_endpoint": "https://auth.example.com/token"} + + # Verify both endpoints were called + assert mock_client.get.call_count == 2 + + @patch("google.adk.auth.oauth2_discovery_util.httpx.AsyncClient") + @pytest.mark.asyncio + async def test_discovery_failure_all_endpoints(self, mock_async_client): + """Test OAuth discovery failure when all endpoints fail.""" + mock_response = Mock() + mock_response.raise_for_status.side_effect = httpx.HTTPStatusError( + "Not found", request=Mock(), response=Mock(status_code=404) + ) + + mock_client = AsyncMock() + mock_client.get.return_value = mock_response + mock_async_client.return_value.__aenter__.return_value = mock_client + + # Test discovery + result = await discover_oauth_configuration("https://api.example.com") + + # Verify result + assert result is None + + # Verify both endpoints were tried + assert mock_client.get.call_count == 2 + + @patch("google.adk.auth.oauth2_discovery_util.httpx.AsyncClient") + @pytest.mark.asyncio + async def test_discovery_invalid_response(self, mock_async_client): + """Test OAuth discovery with invalid response format.""" + mock_response = Mock() + mock_response.json.return_value = {"invalid": "response"} + mock_response.raise_for_status = Mock() + + mock_client = AsyncMock() + mock_client.get.return_value = mock_response + mock_async_client.return_value.__aenter__.return_value = mock_client + + # Test discovery + result = await discover_oauth_configuration("https://api.example.com") + + # Verify result is None due to validation failure + assert result is None + + @patch("google.adk.auth.oauth2_discovery_util.httpx.AsyncClient") + @pytest.mark.asyncio + async def test_discovery_json_decode_error(self, mock_async_client): + """Test OAuth discovery with JSON decode error.""" + mock_response = Mock() + mock_response.json.side_effect = json.JSONDecodeError("Invalid JSON", "", 0) + mock_response.raise_for_status = Mock() + + mock_client = AsyncMock() + mock_client.get.return_value = mock_response + mock_async_client.return_value.__aenter__.return_value = mock_client + + # Test discovery + result = await discover_oauth_configuration("https://api.example.com") + + # Verify result is None due to JSON error + assert result is None + + @pytest.mark.asyncio + async def test_discovery_with_custom_timeout(self): + """Test OAuth discovery with custom timeout.""" + with patch("google.adk.auth.oauth2_discovery_util.httpx.AsyncClient") as mock_async_client: + mock_client = AsyncMock() + mock_async_client.return_value.__aenter__.return_value = mock_client + + await discover_oauth_configuration("https://api.example.com", timeout=5.0) + + # Verify timeout was passed to AsyncClient + mock_async_client.assert_called_with(timeout=5.0) + + +class TestValidateOAuthDiscoveryResponse: + """Test suite for _validate_oauth_discovery_response function.""" + + def test_validate_oauth_protected_resource_valid(self): + """Test validation of valid oauth-protected-resource response.""" + config = {"authorization_servers": ["https://auth.example.com"]} + assert _validate_oauth_discovery_response(config) is True + + def test_validate_oauth_protected_resource_empty_list(self): + """Test validation of oauth-protected-resource response with empty server list.""" + config = {"authorization_servers": []} + assert _validate_oauth_discovery_response(config) is False + + def test_validate_oauth_authorization_server_valid(self): + """Test validation of valid oauth-authorization-server response.""" + config = {"token_endpoint": "https://auth.example.com/token"} + assert _validate_oauth_discovery_response(config) is True + + def test_validate_oauth_authorization_server_empty_endpoint(self): + """Test validation of oauth-authorization-server response with empty endpoint.""" + config = {"token_endpoint": ""} + assert _validate_oauth_discovery_response(config) is False + + def test_validate_invalid_response(self): + """Test validation of invalid response format.""" + config = {"invalid": "response"} + assert _validate_oauth_discovery_response(config) is False + + def test_validate_mixed_valid_response(self): + """Test validation prioritizes oauth-protected-resource over authorization-server.""" + config = { + "authorization_servers": ["https://auth.example.com"], + "token_endpoint": "https://auth.example.com/token" + } + # Should return True because authorization_servers is checked first + assert _validate_oauth_discovery_response(config) is True + + +class TestCreateOAuthSchemeFromDiscovery: + """Test suite for create_oauth_scheme_from_discovery function.""" + + @patch("google.adk.auth.oauth2_discovery_util.discover_oauth_configuration") + @pytest.mark.asyncio + async def test_create_scheme_oauth_protected_resource(self, mock_discover): + """Test OAuth scheme creation from oauth-protected-resource discovery.""" + # Setup mock discovery response + mock_discover.return_value = { + "authorization_servers": ["https://auth.example.com"] + } + + # Mock nested discovery for auth server + with patch("google.adk.auth.oauth2_discovery_util.discover_oauth_configuration") as mock_nested_discover: + mock_nested_discover.side_effect = [ + {"authorization_servers": ["https://auth.example.com"]}, # First call + {"token_endpoint": "https://auth.example.com/token"} # Nested call for auth server + ] + + result = await create_oauth_scheme_from_discovery("https://api.example.com") + + # Verify result + assert isinstance(result, OAuth2) + assert result.flows.clientCredentials.tokenUrl == "https://auth.example.com/token" + assert "read" in result.flows.clientCredentials.scopes + assert "write" in result.flows.clientCredentials.scopes + + @patch("google.adk.auth.oauth2_discovery_util.discover_oauth_configuration") + @pytest.mark.asyncio + async def test_create_scheme_oauth_authorization_server(self, mock_discover): + """Test OAuth scheme creation from oauth-authorization-server discovery.""" + mock_discover.return_value = { + "token_endpoint": "https://auth.example.com/token" + } + + result = await create_oauth_scheme_from_discovery("https://api.example.com") + + # Verify result + assert isinstance(result, OAuth2) + assert result.flows.clientCredentials.tokenUrl == "https://auth.example.com/token" + assert "read" in result.flows.clientCredentials.scopes + assert "write" in result.flows.clientCredentials.scopes + + @patch("google.adk.auth.oauth2_discovery_util.discover_oauth_configuration") + @pytest.mark.asyncio + async def test_create_scheme_custom_scopes(self, mock_discover): + """Test OAuth scheme creation with custom scopes.""" + mock_discover.return_value = { + "token_endpoint": "https://auth.example.com/token" + } + + custom_scopes = ["admin", "user:read"] + result = await create_oauth_scheme_from_discovery( + "https://api.example.com", + scopes=custom_scopes + ) + + # Verify custom scopes + assert isinstance(result, OAuth2) + assert "admin" in result.flows.clientCredentials.scopes + assert "user:read" in result.flows.clientCredentials.scopes + assert "read" not in result.flows.clientCredentials.scopes + + @patch("google.adk.auth.oauth2_discovery_util.discover_oauth_configuration") + @pytest.mark.asyncio + async def test_create_scheme_discovery_failure(self, mock_discover): + """Test OAuth scheme creation when discovery fails.""" + mock_discover.return_value = None + + result = await create_oauth_scheme_from_discovery("https://api.example.com") + + # Verify result is None + assert result is None + + @patch("google.adk.auth.oauth2_discovery_util.discover_oauth_configuration") + @pytest.mark.asyncio + async def test_create_scheme_fallback_token_endpoint(self, mock_discover): + """Test OAuth scheme creation with fallback token endpoint.""" + # Setup mock to fail on nested discovery + mock_discover.side_effect = [ + {"authorization_servers": ["https://auth.example.com"]}, # First call + None # Nested call fails + ] + + result = await create_oauth_scheme_from_discovery("https://api.example.com") + + # Verify fallback endpoint is used + assert isinstance(result, OAuth2) + assert result.flows.clientCredentials.tokenUrl == "https://auth.example.com/token" + + @patch("google.adk.auth.oauth2_discovery_util.discover_oauth_configuration") + @pytest.mark.asyncio + async def test_create_scheme_no_token_endpoint(self, mock_discover): + """Test OAuth scheme creation when no token endpoint can be determined.""" + mock_discover.return_value = {"invalid": "response"} + + result = await create_oauth_scheme_from_discovery("https://api.example.com") + + # Verify result is None + assert result is None + + +class TestOAuth2DiscoverySSLVerification: + """Test suite for SSL verification functionality in OAuth2 discovery.""" + + @patch("google.adk.auth.oauth2_discovery_util.httpx.AsyncClient") + @pytest.mark.asyncio + async def test_discovery_ssl_verification_enabled(self, mock_async_client): + """Test OAuth discovery with SSL verification enabled (default).""" + # Setup mock client + mock_client = AsyncMock() + mock_async_client.return_value.__aenter__.return_value = mock_client + + # Setup mock response for oauth-protected-resource + mock_response = Mock() + mock_response.json.return_value = { + "authorization_servers": ["https://auth.example.com"] + } + mock_response.raise_for_status = Mock() + mock_client.get.return_value = mock_response + + result = await create_oauth_scheme_from_discovery( + "https://api.example.com", + scopes=["read"], + timeout=10.0, + verify_ssl=True + ) + + # Verify AsyncClient was created with SSL verification enabled + mock_async_client.assert_called_with(timeout=10.0, verify=True) + + # Verify discovery was attempted + mock_client.get.assert_called() + + @patch("google.adk.auth.oauth2_discovery_util.httpx.AsyncClient") + @pytest.mark.asyncio + async def test_discovery_ssl_verification_disabled(self, mock_async_client): + """Test OAuth discovery with SSL verification disabled for self-signed certificates.""" + # Setup mock client + mock_client = AsyncMock() + mock_async_client.return_value.__aenter__.return_value = mock_client + + # Setup mock response for oauth-protected-resource + mock_response = Mock() + mock_response.json.return_value = { + "authorization_servers": ["https://localhost:9204"] + } + mock_response.raise_for_status = Mock() + mock_client.get.return_value = mock_response + + result = await create_oauth_scheme_from_discovery( + "https://localhost:9204", # Self-signed SSL scenario + scopes=["read"], + timeout=10.0, + verify_ssl=False # Disable SSL verification + ) + + # Verify AsyncClient was created with SSL verification disabled + mock_async_client.assert_called_with(timeout=10.0, verify=False) + + # Verify discovery was attempted + mock_client.get.assert_called() + + @patch("google.adk.auth.oauth2_discovery_util.httpx.AsyncClient") + @pytest.mark.asyncio + async def test_discovery_ssl_verification_default_true(self, mock_async_client): + """Test that SSL verification defaults to True when not specified.""" + # Setup mock client + mock_client = AsyncMock() + mock_async_client.return_value.__aenter__.return_value = mock_client + + # Setup mock response for oauth-protected-resource + mock_response = Mock() + mock_response.json.return_value = { + "authorization_servers": ["https://auth.example.com"] + } + mock_response.raise_for_status = Mock() + mock_client.get.return_value = mock_response + + # Call without verify_ssl parameter - should default to True + result = await create_oauth_scheme_from_discovery( + "https://api.example.com", + scopes=["read"], + timeout=10.0 + ) + + # Verify AsyncClient was created with SSL verification enabled by default + mock_async_client.assert_called_with(timeout=10.0, verify=True) + + # Verify discovery was attempted + mock_client.get.assert_called() + + @patch("google.adk.auth.oauth2_discovery_util.httpx.AsyncClient") + @pytest.mark.asyncio + async def test_discovery_two_stage_ssl_verification(self, mock_async_client): + """Test that SSL verification is applied to both stages of OAuth discovery.""" + # Setup mock client + mock_client = AsyncMock() + mock_async_client.return_value.__aenter__.return_value = mock_client + + # Setup mock responses for two-stage discovery + protected_resource_response = Mock() + protected_resource_response.json.return_value = { + "authorization_servers": ["https://localhost:9204"] + } + protected_resource_response.raise_for_status = Mock() + + auth_server_response = Mock() + auth_server_response.json.return_value = { + "token_endpoint": "https://localhost:9204/token" + } + auth_server_response.raise_for_status = Mock() + + # Mock both calls in sequence + mock_client.get.side_effect = [protected_resource_response, auth_server_response] + + result = await create_oauth_scheme_from_discovery( + "https://localhost:9204", + scopes=["read"], + timeout=10.0, + verify_ssl=False # Disable SSL verification for self-signed certs + ) + + # Verify AsyncClient was created with SSL verification disabled for both stages + # Should be called twice - once for each stage of discovery + assert mock_async_client.call_count >= 2 + for call in mock_async_client.call_args_list: + args, kwargs = call + assert kwargs.get('verify') is False + + # Verify both discovery calls were made + assert mock_client.get.call_count == 2 + + # Verify result + assert isinstance(result, OAuth2) + assert result.flows.clientCredentials.tokenUrl == "https://localhost:9204/token" \ No newline at end of file diff --git a/tests/unittests/tools/mcp_tool/test_mcp_auth_discovery.py b/tests/unittests/tools/mcp_tool/test_mcp_auth_discovery.py new file mode 100644 index 000000000..eada47e80 --- /dev/null +++ b/tests/unittests/tools/mcp_tool/test_mcp_auth_discovery.py @@ -0,0 +1,247 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for MCP Auth Discovery configuration.""" + +import pytest +import sys + +# Skip all tests in this module if Python version is less than 3.10 +pytestmark = pytest.mark.skipif( + sys.version_info < (3, 10), reason="MCP tool requires Python 3.10+" +) + +try: + from google.adk.tools.mcp_tool.mcp_auth_discovery import MCPAuthDiscovery +except ImportError: + if sys.version_info >= (3, 10): + raise + # Create dummy class for older Python versions + class MCPAuthDiscovery: + pass + + +class TestMCPAuthDiscovery: + """Test suite for MCPAuthDiscovery configuration class.""" + + def test_basic_initialization(self): + """Test basic MCPAuthDiscovery initialization with required parameters.""" + discovery = MCPAuthDiscovery( + base_url="http://localhost:9204", + timeout=10.0, + enabled=True + ) + + assert discovery.base_url == "http://localhost:9204" + assert discovery.timeout == 10.0 + assert discovery.enabled is True + assert discovery.is_enabled is True + + def test_mcp_auth_discovery_defaults(self): + """Test MCPAuthDiscovery with default values.""" + discovery = MCPAuthDiscovery(base_url="http://localhost:9204") + + assert discovery.base_url == "http://localhost:9204" + assert discovery.timeout == 10.0 + assert discovery.enabled is True + assert discovery.verify_ssl is True + assert discovery.is_enabled is True + + + def test_mcp_auth_discovery_custom_values(self): + """Test MCPAuthDiscovery with custom values.""" + discovery = MCPAuthDiscovery( + base_url="https://custom-server:8080/", + timeout=15.0, + enabled=False, + verify_ssl=False + ) + + assert discovery.base_url == "https://custom-server:8080" # Trailing slash removed + assert discovery.timeout == 15.0 + assert discovery.enabled is False + assert discovery.verify_ssl is False + assert discovery.is_enabled is False # Disabled + + + def test_mcp_auth_discovery_self_signed_ssl(self): + """Test MCPAuthDiscovery configured for self-signed SSL certificates.""" + discovery = MCPAuthDiscovery( + base_url="https://localhost:9204", + verify_ssl=False, # Disable SSL verification for self-signed certs + timeout=5.0 + ) + + assert discovery.base_url == "https://localhost:9204" + assert discovery.verify_ssl is False + assert discovery.timeout == 5.0 + assert discovery.enabled is True + assert discovery.is_enabled is True + + def test_url_normalization(self): + """Test that base URLs are properly normalized (trailing slash removed).""" + discovery = MCPAuthDiscovery(base_url="http://localhost:9204/") + + assert discovery.base_url == "http://localhost:9204" # Trailing slash removed + + def test_complex_url_normalization(self): + """Test URL normalization with complex paths.""" + discovery = MCPAuthDiscovery(base_url="http://server.example.com:8080/api/v1/") + + assert discovery.base_url == "http://server.example.com:8080/api/v1" + + def test_disabled_discovery(self): + """Test MCPAuthDiscovery when explicitly disabled.""" + discovery = MCPAuthDiscovery( + base_url="http://localhost:9204", + enabled=False + ) + + assert discovery.enabled is False + assert discovery.is_enabled is False + + def test_empty_base_url_validation(self): + """Test that empty base URL raises ValueError.""" + with pytest.raises(ValueError, match="Discovery base_url is required and cannot be empty"): + MCPAuthDiscovery(base_url="") + + def test_whitespace_base_url_validation(self): + """Test that whitespace-only base URL raises ValueError.""" + with pytest.raises(ValueError, match="Discovery base_url is required and cannot be empty"): + MCPAuthDiscovery(base_url=" ") + + def test_none_base_url_validation(self): + """Test that None base URL raises ValueError.""" + with pytest.raises(ValueError, match="Discovery base_url is required and cannot be empty"): + MCPAuthDiscovery(base_url=None) # type: ignore + + def test_negative_timeout_validation(self): + """Test that negative timeout raises ValueError.""" + with pytest.raises(ValueError, match="Discovery timeout must be positive"): + MCPAuthDiscovery(base_url="http://localhost:9204", timeout=-1.0) + + def test_zero_timeout_validation(self): + """Test that zero timeout raises ValueError.""" + with pytest.raises(ValueError, match="Discovery timeout must be positive"): + MCPAuthDiscovery(base_url="http://localhost:9204", timeout=0.0) + + def test_is_enabled_property_with_empty_url(self): + """Test is_enabled property returns False when base_url is effectively empty.""" + # This test verifies the property works correctly even if validation is bypassed + discovery = MCPAuthDiscovery.__new__(MCPAuthDiscovery) + discovery.base_url = "" + discovery.enabled = True + + assert discovery.is_enabled is False + + def test_valid_timeout_values(self): + """Test various valid timeout values.""" + test_timeouts = [0.1, 1.0, 5.0, 30.0, 120.0] + + for timeout in test_timeouts: + discovery = MCPAuthDiscovery( + base_url="http://localhost:9204", + timeout=timeout + ) + assert discovery.timeout == timeout + + def test_https_url(self): + """Test MCPAuthDiscovery with HTTPS URL.""" + discovery = MCPAuthDiscovery(base_url="https://secure.example.com") + + assert discovery.base_url == "https://secure.example.com" + assert discovery.is_enabled is True + + def test_url_with_port(self): + """Test MCPAuthDiscovery with URL including port.""" + discovery = MCPAuthDiscovery(base_url="http://localhost:8080") + + assert discovery.base_url == "http://localhost:8080" + assert discovery.is_enabled is True + + def test_dataclass_equality(self): + """Test that MCPAuthDiscovery instances with same values are equal.""" + discovery1 = MCPAuthDiscovery( + base_url="http://localhost:9204", + timeout=10.0, + enabled=True + ) + discovery2 = MCPAuthDiscovery( + base_url="http://localhost:9204", + timeout=10.0, + enabled=True + ) + + assert discovery1 == discovery2 + + def test_dataclass_inequality(self): + """Test that MCPAuthDiscovery instances with different values are not equal.""" + discovery1 = MCPAuthDiscovery(base_url="http://localhost:9204") + discovery2 = MCPAuthDiscovery(base_url="http://localhost:9205") + + assert discovery1 != discovery2 + + def test_string_representation(self): + """Test string representation of MCPAuthDiscovery.""" + discovery = MCPAuthDiscovery( + base_url="http://localhost:9204", + timeout=15.0, + enabled=True + ) + + repr_str = repr(discovery) + assert "MCPAuthDiscovery" in repr_str + assert "http://localhost:9204" in repr_str + assert "15.0" in repr_str + assert "True" in repr_str + + +def test_mcp_auth_discovery_optional_base_url(): + """Test MCPAuthDiscovery with optional base_url (None).""" + discovery = MCPAuthDiscovery( + verify_ssl=False, + timeout=15.0 + ) + + assert discovery.base_url is None + assert discovery.verify_ssl is False + assert discovery.timeout == 15.0 + assert discovery.enabled is True + assert discovery.is_enabled is True # Should be enabled even without base_url + + +def test_mcp_auth_discovery_override_ssl_only(): + """Test MCPAuthDiscovery overriding only SSL verification.""" + discovery = MCPAuthDiscovery(verify_ssl=False) + + assert discovery.base_url is None + assert discovery.verify_ssl is False + assert discovery.timeout == 10.0 # Default + assert discovery.enabled is True # Default + assert discovery.is_enabled is True + + +def test_mcp_auth_discovery_multiple_overrides(): + """Test MCPAuthDiscovery overriding multiple settings without base_url.""" + discovery = MCPAuthDiscovery( + timeout=20.0, + verify_ssl=False, + enabled=True + ) + + assert discovery.base_url is None + assert discovery.timeout == 20.0 + assert discovery.verify_ssl is False + assert discovery.enabled is True + assert discovery.is_enabled is True \ No newline at end of file diff --git a/tests/unittests/tools/mcp_tool/test_mcp_toolset_oauth_discovery.py b/tests/unittests/tools/mcp_tool/test_mcp_toolset_oauth_discovery.py new file mode 100644 index 000000000..f7a17c7da --- /dev/null +++ b/tests/unittests/tools/mcp_tool/test_mcp_toolset_oauth_discovery.py @@ -0,0 +1,384 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for MCPToolset OAuth discovery functionality.""" + +import pytest +import sys +from unittest.mock import AsyncMock, Mock, patch + +# Skip all tests in this module if Python version is less than 3.10 +pytestmark = pytest.mark.skipif( + sys.version_info < (3, 10), reason="MCP tool requires Python 3.10+" +) + +try: + from fastapi.openapi.models import OAuth2, OAuthFlows, OAuthFlowClientCredentials + from google.adk.auth.auth_credential import AuthCredential, AuthCredentialTypes, OAuth2Auth + from google.adk.tools.mcp_tool.mcp_auth_discovery import MCPAuthDiscovery + from google.adk.tools.mcp_tool.mcp_toolset import MCPToolset + from google.adk.tools.mcp_tool.mcp_session_manager import ( + StreamableHTTPConnectionParams, + SseConnectionParams, + StdioConnectionParams + ) + from mcp import StdioServerParameters +except ImportError: + if sys.version_info >= (3, 10): + raise + # Create dummy classes for older Python versions + class OAuth2: + pass + class OAuthFlows: + pass + class OAuthFlowClientCredentials: + pass + class AuthCredential: + pass + class AuthCredentialTypes: + OAUTH2 = "oauth2" + class OAuth2Auth: + pass + class MCPAuthDiscovery: + pass + class MCPToolset: + pass + class StreamableHTTPConnectionParams: + pass + class SseConnectionParams: + pass + class StdioConnectionParams: + pass + class StdioServerParameters: + pass + + +class TestMCPToolsetOAuthDiscovery: + """Test suite for MCPToolset OAuth discovery functionality.""" + + def test_default_auth_discovery_streamable_http(self): + """Test that MCPToolset creates default MCPAuthDiscovery for StreamableHTTP connections.""" + connection_params = StreamableHTTPConnectionParams( + url="http://localhost:9204/mcp/" + ) + + toolset = MCPToolset(connection_params=connection_params) + + # Verify default auth discovery was created + assert toolset._auth_discovery is not None + assert toolset._auth_discovery.base_url == "http://localhost:9204" + assert toolset._auth_discovery.enabled is True + assert toolset._auth_discovery.timeout == 10.0 + + def test_default_auth_discovery_sse(self): + """Test that MCPToolset creates default MCPAuthDiscovery for SSE connections.""" + connection_params = SseConnectionParams( + url="http://server.example.com:8080/sse/" + ) + + toolset = MCPToolset(connection_params=connection_params) + + # Verify default auth discovery was created + assert toolset._auth_discovery is not None + assert toolset._auth_discovery.base_url == "http://server.example.com:8080" + assert toolset._auth_discovery.enabled is True + + def test_default_auth_discovery_stdio_disabled(self): + """Test that MCPToolset disables OAuth discovery for Stdio connections.""" + connection_params = StdioConnectionParams( + server_params=StdioServerParameters(command="test", args=[]) + ) + + toolset = MCPToolset(connection_params=connection_params) + + # Verify auth discovery is disabled for stdio + assert toolset._auth_discovery is not None + assert toolset._auth_discovery.enabled is False + + def test_explicit_auth_discovery_overrides_default(self): + """Test that explicit auth_discovery parameter overrides default behavior.""" + connection_params = StreamableHTTPConnectionParams( + url="http://localhost:9204/mcp/" + ) + + custom_discovery = MCPAuthDiscovery( + base_url="http://custom-auth.example.com", + timeout=15.0 + ) + + toolset = MCPToolset( + connection_params=connection_params, + auth_discovery=custom_discovery + ) + + # Verify custom discovery is used + assert toolset._auth_discovery is custom_discovery + assert toolset._auth_discovery.base_url == "http://custom-auth.example.com" + assert toolset._auth_discovery.timeout == 15.0 + + def test_disabled_auth_discovery_override(self): + """Test that explicitly disabled auth_discovery overrides default enabling.""" + connection_params = StreamableHTTPConnectionParams( + url="http://localhost:9204/mcp/" + ) + + disabled_discovery = MCPAuthDiscovery( + base_url="http://localhost:9204", + enabled=False + ) + + toolset = MCPToolset( + connection_params=connection_params, + auth_discovery=disabled_discovery + ) + + # Verify discovery is disabled despite HTTP connection + assert toolset._auth_discovery.enabled is False + assert toolset._auth_discovery.is_enabled is False + + @patch("google.adk.tools.mcp_tool.mcp_toolset.create_oauth_scheme_from_discovery") + @pytest.mark.asyncio + async def test_oauth_discovery_not_attempted_when_disabled(self, mock_discovery): + """Test that OAuth discovery is not attempted when disabled.""" + connection_params = StreamableHTTPConnectionParams( + url="http://localhost:9204/mcp/" + ) + + toolset = MCPToolset( + connection_params=connection_params, + auth_discovery=MCPAuthDiscovery( + base_url="http://localhost:9204", + enabled=False + ) + ) + + # Call the discovery method + await toolset._perform_oauth_discovery() + + # Verify discovery was not attempted + mock_discovery.assert_not_called() + assert toolset._oauth_discovery_attempted is False + + @patch("google.adk.tools.mcp_tool.mcp_toolset.create_oauth_scheme_from_discovery") + @pytest.mark.asyncio + async def test_oauth_discovery_with_empty_token_url(self, mock_discovery): + """Test OAuth discovery when auth_scheme has empty tokenUrl.""" + mock_discovery.return_value = OAuth2( + flows=OAuthFlows( + clientCredentials=OAuthFlowClientCredentials( + tokenUrl="http://localhost:9204/token", + scopes={"api:read": "Read access"} + ) + ) + ) + + connection_params = StreamableHTTPConnectionParams( + url="http://localhost:9204/mcp/" + ) + + auth_scheme = OAuth2( + flows=OAuthFlows( + clientCredentials=OAuthFlowClientCredentials( + tokenUrl="", # Empty tokenUrl triggers discovery + scopes={"api:read": "Read access"} + ) + ) + ) + + toolset = MCPToolset( + connection_params=connection_params, + auth_scheme=auth_scheme + ) + + # Call the discovery method + await toolset._perform_oauth_discovery() + + # Verify discovery was attempted with correct parameters + mock_discovery.assert_called_once_with( + base_url="http://localhost:9204", + scopes=["api:read"], + timeout=10.0 + ) + + # Verify tokenUrl was updated + assert isinstance(toolset._auth_scheme, OAuth2) + assert toolset._auth_scheme.flows.clientCredentials.tokenUrl == "http://localhost:9204/token" + assert toolset._oauth_discovery_attempted is True + + @patch("google.adk.tools.mcp_tool.mcp_toolset.create_oauth_scheme_from_discovery") + @pytest.mark.asyncio + async def test_oauth_discovery_with_no_auth_scheme(self, mock_discovery): + """Test OAuth discovery when no auth_scheme is provided.""" + mock_discovery.return_value = OAuth2( + flows=OAuthFlows( + clientCredentials=OAuthFlowClientCredentials( + tokenUrl="http://localhost:9204/token", + scopes={} + ) + ) + ) + + connection_params = StreamableHTTPConnectionParams( + url="http://localhost:9204/mcp/" + ) + + toolset = MCPToolset(connection_params=connection_params) + + # Call the discovery method + await toolset._perform_oauth_discovery() + + # Verify discovery was attempted + mock_discovery.assert_called_once_with( + base_url="http://localhost:9204", + scopes=None, + timeout=10.0 + ) + + # Verify auth_scheme was set from discovery + assert toolset._auth_scheme is not None + assert isinstance(toolset._auth_scheme, OAuth2) + assert toolset._auth_scheme.flows.clientCredentials.tokenUrl == "http://localhost:9204/token" + + @patch("google.adk.tools.mcp_tool.mcp_toolset.create_oauth_scheme_from_discovery") + @pytest.mark.asyncio + async def test_oauth_discovery_failure_handling(self, mock_discovery): + """Test handling of OAuth discovery failures.""" + mock_discovery.return_value = None # Discovery fails + + connection_params = StreamableHTTPConnectionParams( + url="http://localhost:9204/mcp/" + ) + + toolset = MCPToolset(connection_params=connection_params) + + # Call the discovery method + await toolset._perform_oauth_discovery() + + # Verify discovery was attempted but auth_scheme remains None + mock_discovery.assert_called_once() + assert toolset._auth_scheme is None + assert toolset._oauth_discovery_attempted is True + + @patch("google.adk.tools.mcp_tool.mcp_toolset.create_oauth_scheme_from_discovery") + @pytest.mark.asyncio + async def test_oauth_discovery_exception_handling(self, mock_discovery): + """Test handling of exceptions during OAuth discovery.""" + mock_discovery.side_effect = Exception("Discovery failed") + + connection_params = StreamableHTTPConnectionParams( + url="http://localhost:9204/mcp/" + ) + + toolset = MCPToolset(connection_params=connection_params) + + # Call the discovery method - should not raise + await toolset._perform_oauth_discovery() + + # Verify discovery was attempted and exception was handled + mock_discovery.assert_called_once() + assert toolset._oauth_discovery_attempted is True + + @pytest.mark.asyncio + async def test_oauth_discovery_only_attempted_once(self): + """Test that OAuth discovery is only attempted once even with multiple calls.""" + connection_params = StreamableHTTPConnectionParams( + url="http://localhost:9204/mcp/" + ) + + toolset = MCPToolset(connection_params=connection_params) + + with patch("google.adk.tools.mcp_tool.mcp_toolset.create_oauth_scheme_from_discovery") as mock_discovery: + mock_discovery.return_value = None + + # Call discovery multiple times + await toolset._perform_oauth_discovery() + await toolset._perform_oauth_discovery() + await toolset._perform_oauth_discovery() + + # Verify discovery was only called once + assert mock_discovery.call_count == 1 + assert toolset._oauth_discovery_attempted is True + + def test_url_parsing_with_complex_paths(self): + """Test URL parsing with complex paths extracts correct base URL.""" + connection_params = StreamableHTTPConnectionParams( + url="https://api.example.com:8443/services/v1/mcp/stream" + ) + + toolset = MCPToolset(connection_params=connection_params) + + # Verify correct base URL extraction + assert toolset._auth_discovery.base_url == "https://api.example.com:8443" + + def test_url_parsing_with_query_parameters(self): + """Test URL parsing ignores query parameters and fragments.""" + connection_params = StreamableHTTPConnectionParams( + url="http://localhost:9204/mcp/?version=1.0&debug=true#section" + ) + + toolset = MCPToolset(connection_params=connection_params) + + # Verify query parameters and fragments are ignored + assert toolset._auth_discovery.base_url == "http://localhost:9204" + + @patch("google.adk.auth.credential_manager.CredentialManager.get_auth_credential") + @patch("google.adk.tools.mcp_tool.mcp_session_manager.MCPSessionManager.create_session") + @pytest.mark.asyncio + async def test_token_exchange_before_session_creation(self, mock_create_session, mock_get_credential): + """Test that token exchange happens before session creation.""" + # Setup mocks + mock_credential = Mock() + mock_credential.oauth2.access_token = "test_access_token" + mock_get_credential.return_value = mock_credential + + mock_session = AsyncMock() + mock_session.list_tools.return_value = Mock(tools=[]) + mock_create_session.return_value = mock_session + + connection_params = StreamableHTTPConnectionParams( + url="http://localhost:9204/mcp/" + ) + + auth_scheme = OAuth2( + flows=OAuthFlows( + clientCredentials=OAuthFlowClientCredentials( + tokenUrl="http://localhost:9204/token", + scopes={"api:read": "Read access"} + ) + ) + ) + + auth_credential = AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + oauth2=OAuth2Auth(client_id="test_id", client_secret="test_secret") + ) + + toolset = MCPToolset( + connection_params=connection_params, + auth_scheme=auth_scheme, + auth_credential=auth_credential + ) + + # Call get_tools to trigger token exchange and session creation + with patch.object(toolset, '_perform_oauth_discovery'): + await toolset.get_tools(readonly_context=None) + + # Verify token exchange was called + mock_get_credential.assert_called_once() + + # Verify session was created with Authorization header + mock_create_session.assert_called_once() + call_args = mock_create_session.call_args + headers = call_args.kwargs.get('headers', {}) + assert headers.get('Authorization') == 'Bearer test_access_token' \ No newline at end of file