diff --git a/auth_spec.png b/auth_spec.png new file mode 100644 index 000000000..9bd276c46 Binary files /dev/null and b/auth_spec.png differ diff --git a/docs/configuration.mdx b/docs/configuration.mdx index 250841824..b3eacf702 100644 --- a/docs/configuration.mdx +++ b/docs/configuration.mdx @@ -123,6 +123,49 @@ mcp-agent uses two configuration files: +## OAuth Configuration + +MCP Agent exposes two complementary OAuth configuration blocks: + +- `authorization` describes how the MCP Agent server validates inbound bearer tokens and publishes protected resource metadata. +- `oauth` configures delegated authorization when the agent connects to downstream MCP servers. + +```yaml +authorization: + enabled: true + issuer_url: https://auth.example.com + resource_server_url: https://agent.example.com/mcp + required_scopes: ["mcp.read", "mcp.write"] + introspection_endpoint: https://auth.example.com/oauth/introspect + introspection_client_id: ${INTROSPECTION_CLIENT_ID} + introspection_client_secret: ${INTROSPECTION_CLIENT_SECRET} + +oauth: + callback_base_url: https://agent.example.com + flow_timeout_seconds: 180 + token_store: + backend: memory # use "redis" for multi-instance deployments + refresh_leeway_seconds: 90 + +mcp: + servers: + github: + transport: streamable_http + url: https://github.mcp.example.com/mcp + auth: + oauth: + enabled: true + scopes: ["repo", "user:email"] + client_id: ${GITHUB_MCP_CLIENT_ID} + client_secret: ${GITHUB_MCP_CLIENT_SECRET} + redirect_uri_options: + - https://agent.example.com/internal/oauth/callback +``` + +- When `authorization.enabled` is true the MCP server advertises `/.well-known/oauth-protected-resource` and enforces bearer tokens using the provided introspection or JWKS configuration. +- `oauth` enables delegated authorization flows; the default in-memory token store is ideal for local development while Redis is recommended for production clusters. +- Downstream servers opt into OAuth via `mcp.servers..auth.oauth`. Supplying a `client_id`/`client_secret` allows immediate usage; support for dynamic client registration is planned as a follow-up. + ## Configuration Reference ### Execution Engine diff --git a/docs/oauth_support_design.md b/docs/oauth_support_design.md new file mode 100644 index 000000000..5aea9b1f2 --- /dev/null +++ b/docs/oauth_support_design.md @@ -0,0 +1,108 @@ +# MCP Agent OAuth Support + +## Goals +- Protect MCP Agent Cloud servers using OAuth 2.1 so MCP clients obtain tokens via standard flows. +- Enable MCP Agent runtimes to authenticate to downstream MCP servers that require OAuth access tokens. +- Provide pluggable token storage for both local development (in-memory) and multi-instance deployments (Redis planned). +- Maintain compatibility with MCP Authorization spec (RFC 8414, RFC 9728, OAuth 2.1 + PKCE, Resource Indicators) and the proposed delegated authorization SEP. + +## Architecture Overview + +### Components +1. **Auth Server Integration** – Configure the FastMCP instance with `AuthSettings` and a custom `TokenVerifier` that calls MCP Agent Cloud auth services. +2. **Protected Resource Metadata** – Serve `/.well-known/oauth-protected-resource` using FastMCP hooks so clients can discover the auth server. +3. **Access Token Validation** – Enforce bearer tokens on every inbound MCP request via `RequireAuthMiddleware`, populating the request context with the authenticated user. +4. 
**OAuth Token Service** – New `mcp_agent.oauth` package with: + - `TokenStore`/`TokenRecord` abstractions + - `InMemoryTokenStore` and Redis-backed implementation (second pass) + - `TokenManager` orchestration (acquire, refresh, revoke) + - `OAuthHttpxAuth` for attaching tokens to downstream HTTP transports + - `AuthorizationFlowCoordinator` that interacts with the user via MCP `auth/request` +5. **Delegated Authorization UI Flow** – Extend the gateway/session relay so servers can send `auth/request` messages to MCP clients, capturing authorization codes via either: + - Client-returned callback URL (preferred, works with SEP-capable clients) + - MCP Agent hosted callback endpoint (`/internal/oauth/callback/{flow_id}`) as a fallback / native-app style loopback. +6. **Configuration Surface** – Extend `Settings` and per-server `MCPServerAuthSettings` to describe OAuth behaviour (scopes, preferred auth server, redirect URIs, etc.) and global token-store configuration. + +### Key Data Flow +1. **Inbound Requests** + - Client presents bearer token ⇒ `BearerAuthBackend` + `MCPAgentTokenVerifier` introspect token. + - Verified token populates context with `OAuthUserIdentity` (provider + subject + email). + - Context is propagated into workflows/sessions so downstream OAuth flows know the acting user. + +2. **Outbound HTTP (downstream MCP server)** + - `ServerRegistry` detects `auth.oauth` configuration. + - Wraps HTTP transport with `OAuthHttpxAuth` which requests an access token from `TokenManager`. + - `TokenManager` checks store; if missing/expired ⇒ `AuthorizationFlowCoordinator` performs RFC 9728 discovery, PKCE, delegated browser flow through MCP client, exchanges code for tokens, caches result. + - Requests automatically retry after token refresh when a response returns 401/invalid token. + +3. **Token Storage** + - Tokens stored per `(user_identity, resource, authorization_server)` tuple with metadata (scopes, expiry, refresh token, provider claims). + - Store implements optimistic locking to avoid concurrent refresh storms. + - Pluggable backend (`InMemoryTokenStore` initial, Redis follow-up). + +## Module Plan + +``` +src/mcp_agent/oauth/ + __init__.py + identity.py # OAuthUserIdentity, helpers to extract from auth context + records.py # TokenRecord dataclass/pydantic model + store/base.py # TokenStore protocol + store/in_memory.py # Default store + manager.py # TokenManager (get/refresh/invalidate) + flow.py # AuthorizationFlowCoordinator + http/auth.py # OAuthHttpxAuth (httpx.Auth implementation) + metadata.py # RFC 8414 + RFC 9728 discovery helpers + pkce.py # PKCE + state utilities + errors.py # Custom exception hierarchy +``` + +Integration touchpoints: +- `mcp_agent/config.py` – add OAuth settings models. +- `mcp_agent/core/context.py` – add `current_user`, `token_manager`, `token_store`, `oauth_config` fields. +- `mcp_agent/app.py` – initialize token store/manager based on settings. +- `mcp_agent/server/app_server.py` – configure FastMCP auth settings, register callback route, surface user identity, extend relay to handle `auth/request`. +- `mcp_agent/mcp/mcp_server_registry.py` & `mcp_agent/mcp/mcp_connection_manager.py` – wire `OAuthHttpxAuth` into HTTP transports and expose helper for manual token teardown. +- `mcp_agent/mcp/client_proxy.py` – add proxy helpers for `auth/request`. +- `SessionProxy` – add direct request helper for `auth/request` and ensure Temporal flow support. +- `examples/mcp_agent_server/*` – demonstrate configuration changes. 
+- Tests – new suite exercising token store, metadata discovery, flow orchestration (with mocked HTTP + client responses). + +## OAuth Flow Details +1. **Discovery** + - If downstream server responds 401 with `WWW-Authenticate`, parse for `resource_metadata` ⇒ GET metadata ⇒ determine auth server URL(s). + - Fetch authorization server metadata (RFC 8414). + - Perform optional dynamic client registration when configured and supported. + +2. **Authorization Request** + - Generate PKCE challenge/verifier, secure `state`, choose `redirect_uri`. + - Build authorization URL including `resource` parameter (RFC 8707) + requested scopes. + - Invoke `auth/request` via SessionProxy → MCP client opens browser. + +3. **Callback Handling** + - Preferred: MCP client returns callback URL payload via request result. + - Fallback: Authorization server redirects to `/internal/oauth/callback/{flow_id}`. + - Coordinator validates `state`, extracts `code` (and errors). + +4. **Token Exchange / Storage** + - POST token endpoint with code + PKCE verifier + resource. + - Store access token, refresh token, expiry, scope, provider metadata. + - Associate tokens with user identity for reuse. + +5. **Refresh / Revocation** + - Manager refreshes when expiry within configurable grace window. + - Invalidate token on refresh failure or when server responses indicate revocation. + - Provide method to revoke tokens via authorization server when supported. + +## Open Questions / Follow-ups +- Redis-backed `TokenStore` (requires deployment secrets) – planned second phase. +- How LastMile auth server exposes token introspection + JWKS; need concrete endpoint specs to finalize `MCPAgentTokenVerifier`. +- MCP client adoption of `auth/request` SEP – need capability detection; until widely supported we rely on hosted callback fallback & manual instructions. +- Access control DSL (include/exclude by email/domain) – to be evaluated once token identity payload finalized. + +## Testing Strategy +- Unit tests for token store concurrency + expiry handling. +- Metadata discovery + PKCE generation (pure python tests). +- Integration-style test for delegated flow using mocked HTTP server + fake MCP client (ensures `auth/request` plumbing works end-to-end). +- Tests around server 401 enforcement + WWW-Authenticate header. + diff --git a/examples/oauth/README.md b/examples/oauth/README.md new file mode 100644 index 000000000..dd966ba1f --- /dev/null +++ b/examples/oauth/README.md @@ -0,0 +1,19 @@ +# OAUTH scenarios + +## preconfigured + +In this case, a token is hard-coded into the configuration. +This is useful for testing or when the token is static. + +## workflow_pre_auth + +In this case, the client can call a `workflows_pre_auth` tool before calling a workflow to seed the tokens. +This is useful when the client can do the auth step, but the workflow cannot (e.g. because it runs async). +There is a slight hack employed here: since we don't have oauth for the mcp app, we do not have a user. +Since we need a user to store the token against, we create a synthetic user and use that. + +## dynamic_auth + +In this case, no tokens are provided, and the calls comes back to the client to do the auth step. +Currently implemented as an elicitation request (to align with the future elicit URL scheme). +I have not achieved full end-to-end flow here. 
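+
+For reference, a minimal sketch of the preconfigured scenario described above (illustrative
+only; it assumes the static token is supplied under `mcp.servers.<name>.auth.oauth`, mirroring
+the secrets template in `examples/oauth/preconfigured`):
+
+```yaml
+mcp:
+  servers:
+    github:
+      auth:
+        oauth:
+          enabled: true
+          # Hypothetical placeholder; a real static token would live in
+          # mcp_agent.secrets.yaml and never be committed to source control.
+          access_token: "ghp_example_static_token"
+```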
\ No newline at end of file diff --git a/examples/oauth/dynamic_auth/client.py b/examples/oauth/dynamic_auth/client.py new file mode 100644 index 000000000..8280e933a --- /dev/null +++ b/examples/oauth/dynamic_auth/client.py @@ -0,0 +1,166 @@ +import asyncio +import time + +from datetime import timedelta +from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream +from mcp import ClientSession +from mcp.types import CallToolResult, LoggingMessageNotificationParams +from mcp_agent.app import MCPApp +from mcp_agent.config import MCPServerSettings +from mcp_agent.core.context import Context +from mcp_agent.mcp.gen_client import gen_client +from mcp_agent.mcp.mcp_agent_client_session import MCPAgentClientSession +from mcp_agent.human_input.console_handler import console_input_callback +from mcp_agent.elicitation.handler import console_elicitation_callback + +from rich import print + +try: + from exceptiongroup import ExceptionGroup as _ExceptionGroup # Python 3.10 backport +except Exception: # pragma: no cover + _ExceptionGroup = None # type: ignore +try: + from anyio import BrokenResourceError as _BrokenResourceError +except Exception: # pragma: no cover + _BrokenResourceError = None # type: ignore + + +async def main(): + # Create MCPApp to get the server registry + app = MCPApp( + name="workflow_mcp_client", + human_input_callback=console_input_callback, + elicitation_callback=console_elicitation_callback, + ) + async with app.run() as client_app: + logger = client_app.logger + context = client_app.context + + # Connect to the workflow server + logger.info("Connecting to workflow server...") + + # Override the server configuration to point to our local script + context.server_registry.registry["basic_agent_server"] = MCPServerSettings( + name="basic_agent_server", + description="Local workflow server running the basic agent example", + transport="sse", + url="http://127.0.0.1:8000/sse", + ) + + # Define a logging callback to receive server-side log notifications + async def on_server_log(params: LoggingMessageNotificationParams) -> None: + level = params.level.upper() + name = params.logger or "server" + print(f"[SERVER LOG] [{level}] [{name}] {params.data}") + + # Provide a client session factory that installs our logging callback + # and prints non-logging notifications to the console + class ConsolePrintingClientSession(MCPAgentClientSession): + async def _received_notification(self, notification): # type: ignore[override] + try: + method = getattr(notification.root, "method", None) + except Exception: + method = None + + # Avoid duplicating server log prints (handled by logging_callback) + if method and method != "notifications/message": + try: + data = notification.model_dump() + except Exception: + data = str(notification) + print(f"[SERVER NOTIFY] {method}: {data}") + + return await super()._received_notification(notification) + + def make_session( + read_stream: MemoryObjectReceiveStream, + write_stream: MemoryObjectSendStream, + read_timeout_seconds: timedelta | None, + context: Context | None = None, + ) -> ClientSession: + return ConsolePrintingClientSession( + read_stream=read_stream, + write_stream=write_stream, + read_timeout_seconds=read_timeout_seconds, + logging_callback=on_server_log, + context=context, + ) + + try: + async with gen_client( + "basic_agent_server", + context.server_registry, + client_session_factory=make_session, + ) as server: + try: + await server.set_logging_level("info") + except Exception: + # Older servers may not support 
logging capability + print("[client] Server does not support logging/setLevel") + + # List available tools + tools_result = await server.list_tools() + logger.info( + "Available tools:", + data={"tools": [tool.name for tool in tools_result.tools]}, + ) + + print( + await server.call_tool("github_org_search", {"query": "lastmileai"}) + ) + except Exception as e: + # Tolerate benign shutdown races from stdio client (BrokenResourceError within ExceptionGroup) + if _ExceptionGroup is not None and isinstance(e, _ExceptionGroup): + subs = getattr(e, "exceptions", []) or [] + if ( + _BrokenResourceError is not None + and subs + and all(isinstance(se, _BrokenResourceError) for se in subs) + ): + logger.debug("Ignored BrokenResourceError from stdio shutdown") + else: + raise + elif _BrokenResourceError is not None and isinstance( + e, _BrokenResourceError + ): + logger.debug("Ignored BrokenResourceError from stdio shutdown") + elif "BrokenResourceError" in str(e): + logger.debug( + "Ignored BrokenResourceError from stdio shutdown (string match)" + ) + else: + raise + # Nudge cleanup of subprocess transports before the loop closes to avoid + # 'Event loop is closed' from BaseSubprocessTransport.__del__ on GC. + try: + await asyncio.sleep(0) + except Exception: + pass + try: + import gc + + gc.collect() + except Exception: + pass + + +def _tool_result_to_json(tool_result: CallToolResult): + if tool_result.content and len(tool_result.content) > 0: + text = tool_result.content[0].text + try: + # Try to parse the response as JSON if it's a string + import json + + return json.loads(text) + except (json.JSONDecodeError, TypeError): + # If it's not valid JSON, just use the text + return None + + +if __name__ == "__main__": + start = time.time() + asyncio.run(main()) + end = time.time() + t = end - start + + print(f"Total run time: {t:.2f}s") diff --git a/examples/oauth/dynamic_auth/main.py b/examples/oauth/dynamic_auth/main.py new file mode 100644 index 000000000..5772c595d --- /dev/null +++ b/examples/oauth/dynamic_auth/main.py @@ -0,0 +1,173 @@ +""" +Workflow MCP Server Example + +This example demonstrates three approaches to creating agents and workflows: +1. Traditional workflow-based approach with manual agent creation +2. Programmatic agent configuration using AgentConfig +3. Declarative agent configuration using FastMCPApp decorators +""" + +import asyncio +import json +import os +from pydantic import AnyHttpUrl + +from mcp.server.fastmcp import FastMCP + +from mcp_agent.app import MCPApp +from mcp_agent.server.app_server import create_mcp_server_for_app + +from mcp_agent.config import ( + MCPServerSettings, + Settings, + LoggerSettings, + MCPSettings, + MCPServerAuthSettings, + MCPOAuthClientSettings, + OAuthSettings, + OAuthTokenStoreSettings, + TemporalSettings, +) + +# Note: This is purely optional: +# if not provided, a default FastMCP server will be created by MCPApp using create_mcp_server_for_app() +mcp = FastMCP(name="basic_agent_server", instructions="My basic agent server example.") + + +# Get client id and secret from environment variables +client_id = os.getenv("GITHUB_CLIENT_ID") +client_secret = os.getenv("GITHUB_CLIENT_SECRET") + +if not client_id or not client_secret: + print( + "\nGitHub client id and/or secret not found in GITHUB_CLIENT_Id and GITHUB_CLIENT_SECRET " + "environment variables." + ) + print("\nTo create these:") + print("\n1. Open your profile on github.com and navigate to 'Developer Settings'") + print("\n2. 
Create a new OAuth app and create a client secret for it.") + print("\n3. Create environment variables:") + print("\nexport GITHUB_CLIENT_ID='your_client_id_here'") + print("\nexport GITHUB_CLIENT_SECRET='your_client_secret_here'") + + +settings = Settings( + execution_engine="temporal", + temporal=TemporalSettings( + host="localhost:7233", + namespace="default", + task_queue="mcp-agent", + max_concurrent_activities=10, + ), + logger=LoggerSettings(level="info"), + oauth=OAuthSettings( + callback_base_url=AnyHttpUrl("http://localhost:8000"), + flow_timeout_seconds=300, + token_store=OAuthTokenStoreSettings(refresh_leeway_seconds=60), + ), + mcp=MCPSettings( + servers={ + "github": MCPServerSettings( + name="github", + transport="streamable_http", + url="https://api.githubcopilot.com/mcp/", + auth=MCPServerAuthSettings( + oauth=MCPOAuthClientSettings( + client_id=client_id, + client_secret=client_secret, + use_internal_callback=True, + enabled=True, + scopes=[ + "read:org", # Required for search_orgs tool + "public_repo", # Access to public repositories + "user:email", # User information access + ], + authorization_server=AnyHttpUrl( + "https://github.com/login/oauth" + ), + resource=AnyHttpUrl("https://api.githubcopilot.com/mcp"), + ) + ), + ) + } + ), +) + +# Define the MCPApp instance. The server created for this app will advertise the +# MCP logging capability and forward structured logs upstream to connected clients. +app = MCPApp( + name="basic_agent_server", + description="Basic agent server example", + mcp=mcp, + settings=settings, +) + + +@app.workflow_task(name="github_org_search_activity") +async def github_org_search_activity(query: str) -> str: + from mcp_agent.mcp.gen_client import gen_client + + print("running activity)") + try: + async with gen_client( + "github", server_registry=app.context.server_registry, context=app.context + ) as github_client: + print("got client") + result = await github_client.call_tool( + "search_orgs", + {"query": query, "perPage": 10, "sort": "best-match", "order": "desc"}, + ) + + organizations = [] + if result.content: + for content_item in result.content: + if hasattr(content_item, "text"): + try: + data = json.loads(content_item.text) + if isinstance(data, dict) and "items" in data: + organizations.extend(data["items"]) + elif isinstance(data, list): + organizations.extend(data) + except json.JSONDecodeError: + pass + + print(f"Organizations: {organizations}") + return str(organizations) + except Exception as e: + import traceback + + traceback.print_exc() + return f"Error: {e}" + + +@app.tool(name="github_org_search") +async def github_org_search(query: str) -> str: + if app._logger and hasattr(app._logger, "_bound_context"): + app._logger._bound_context = app.context + + result = await app.executor.execute(github_org_search_activity, query) + print(f"Result: {result}") + + return result + + +async def main(): + async with app.run() as agent_app: + # Log registered workflows and agent configurations + agent_app.logger.info(f"Creating MCP server for {agent_app.name}") + + agent_app.logger.info("Registered workflows:") + for workflow_id in agent_app.workflows: + agent_app.logger.info(f" - {workflow_id}") + + # Create the MCP server that exposes both workflows and agent configurations, + # optionally using custom FastMCP settings + mcp_server = create_mcp_server_for_app(agent_app) + agent_app.logger.info(f"MCP Server settings: {mcp_server.settings}") + + # Run the server + await mcp_server.run_sse_async() + + +if __name__ == "__main__": + 
asyncio.run(main()) diff --git a/examples/oauth/dynamic_auth/worker.py b/examples/oauth/dynamic_auth/worker.py new file mode 100644 index 000000000..39b2a3c67 --- /dev/null +++ b/examples/oauth/dynamic_auth/worker.py @@ -0,0 +1,31 @@ +""" +Worker script for the Temporal workflow example. +This script starts a Temporal worker that can execute workflows and activities. +Run this script in a separate terminal window before running the main.py script. + +This leverages the TemporalExecutor's start_worker method to handle the worker setup. +""" + +import asyncio +import logging + + +from mcp_agent.executor.temporal import create_temporal_worker_for_app + +from main import app + +# Initialize logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +async def main(): + """ + Start a Temporal worker for the example workflows using the app's executor. + """ + async with create_temporal_worker_for_app(app) as worker: + await worker.run() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/oauth/preconfigured/README.md b/examples/oauth/preconfigured/README.md new file mode 100644 index 000000000..846eded46 --- /dev/null +++ b/examples/oauth/preconfigured/README.md @@ -0,0 +1,402 @@ +# OAuth MCP Agent Example + +This example demonstrates how to build MCP agents that use OAuth authentication to access OAuth-protected MCP servers, specifically showing integration with the GitHub MCP server. + +## 📋 Overview + +This example includes: + +- **Basic OAuth Integration** - Connect to OAuth-protected MCP servers +- **GitHub Organization Search** - Use the `search_orgs` tool from GitHub MCP server +- **Workflow Pre-Authorization** - Demonstrate the new `workflow_pre_auth` endpoint +- **Interactive OAuth Flow** - Complete OAuth setup and token management +- **Production-Ready Configuration** - Comprehensive config with security best practices + +## 🚀 Quick Start + +### 1. Prerequisites + +```bash +# Install Python dependencies +pip install -r requirements.txt + +# Install GitHub MCP server +uvx install github-mcp-server + +# Optional: Install additional MCP servers +npm install -g @modelcontextprotocol/server-filesystem +uvx install mcp-server-fetch +``` + +### 2. Set Up GitHub OAuth App + +1. Go to [GitHub Settings → Developer settings → OAuth Apps](https://github.com/settings/applications/new) +2. Click **"New OAuth App"** +3. Fill in the details: + - **Application name**: `MCP Agent OAuth Example` + - **Homepage URL**: `https://github.com/yourusername/your-repo` + - **Authorization callback URL**: `http://localhost:8080/oauth/callback` +4. Click **"Register application"** +5. Copy the **Client ID** and generate a **Client Secret** + +### 3. Configure Secrets + +```bash +# Copy the secrets template +cp mcp_agent.secrets.yaml.example mcp_agent.secrets.yaml + +# Edit with your credentials +nano mcp_agent.secrets.yaml +``` + +Add your GitHub OAuth app credentials: + +```yaml +mcp: + servers: + github: + auth: + oauth: + client_id: "your_github_oauth_app_client_id_here" + client_secret: "your_github_oauth_app_client_secret_here" + access_token: "your_github_access_token_here" # Optional: skip OAuth flow +``` + +### 4. 
Run the Examples + +#### Basic OAuth Example +```bash +python main.py +``` + +#### Interactive OAuth Setup +```bash +python oauth_demo.py +``` + +#### Workflow with Pre-Authorization +```bash +python workflow_example.py +``` + +## 📁 File Structure + +``` +examples/oauth/ +├── README.md # This file +├── main.py # Basic OAuth MCP agent +├── workflow_example.py # Workflow with pre-auth demo +├── oauth_demo.py # Interactive OAuth flow +├── mcp_agent.config.yaml # Agent configuration +├── mcp_agent.secrets.yaml.example # Secrets template +└── requirements.txt # Python dependencies +``` + +## 🔐 Authentication Methods + +### Method 1: OAuth Flow (Recommended) + +Full OAuth 2.0 flow with refresh tokens: + +1. **Run Interactive Setup**: + ```bash + GITHUB_CLIENT_ID=your_id GITHUB_CLIENT_SECRET=your_secret python oauth_demo.py + ``` + +2. **Follow Browser Flow**: Authorize the application in your browser + +3. **Token Storage**: Tokens are automatically stored for reuse + +### Method 2: Personal Access Token (Development) + +For development and testing, you can use a GitHub Personal Access Token: + +1. Go to [GitHub Settings → Developer settings → Personal access tokens](https://github.com/settings/tokens) +2. Generate a token with scopes: `read:org`, `public_repo`, `user:email` +3. Configure in `mcp_agent.secrets.yaml`: + +```yaml +mcp: + servers: + github: + auth: + api_key: "ghp_your_personal_access_token_here" +``` + +## 🔄 Workflow Pre-Authorization + +The `workflow_pre_auth` endpoint allows you to pre-store OAuth tokens for workflows: + +### 1. Start MCP Agent Server + +```bash +mcp-agent server --config mcp_agent.config.yaml +``` + +### 2. Pre-Authorize Tokens + +```bash +curl -X POST http://localhost:8080/tools/workflows-pre-auth \ + -H "Content-Type: application/json" \ + -d '{ + "workflow_name": "github_analysis_workflow", + "tokens": [ + { + "access_token": "your_github_access_token", + "refresh_token": "your_refresh_token", + "server_name": "github", + "scopes": ["read:org", "public_repo"], + "authorization_server": "https://github.com/login/oauth/authorize" + } + ] + }' +``` + +### 3. 
Run Workflow + +```bash +curl -X POST http://localhost:8080/tools/workflows-run \ + -H "Content-Type: application/json" \ + -d '{ + "workflow_name": "analyze_github_ecosystem", + "run_parameters": { + "focus_areas": ["AI/ML", "cloud", "security"], + "include_details": true + } + }' +``` + +## 🛠 Examples Explained + +### main.py - Basic OAuth Integration + +Demonstrates: +- Connecting to GitHub MCP server with OAuth +- Using the `search_orgs` tool +- Error handling and token management +- Both single-use and persistent connections + +Key features: +```python +async with gen_client("github", server_registry=context.server_registry) as github_client: + result = await github_client.call_tool("search_orgs", { + "query": "microsoft", + "perPage": 10, + "sort": "best-match" + }) +``` + +### workflow_example.py - Advanced Workflow + +Demonstrates: +- Custom agent (`GitHubOrganizationAnalyzer`) +- Workflow with `@app.async_tool` decorator +- Pre-authorization token usage +- Comprehensive GitHub ecosystem analysis + +Key features: +```python +@app.async_tool +async def analyze_github_ecosystem( + app_ctx: Context, + focus_areas: List[str], + include_details: bool = True +) -> Dict[str, Any]: + # Uses pre-authorized tokens automatically + analyzer = GitHubOrganizationAnalyzer(context=app_ctx) + return await analyzer.analyze_organizations(queries, include_details) +``` + +### oauth_demo.py - Interactive OAuth Setup + +Demonstrates: +- Complete OAuth 2.0 flow +- Local callback server +- Token testing and validation +- Token persistence + +Key features: +- Automatic browser opening +- CSRF protection with state parameter +- Token testing with GitHub API +- Export tokens for MCP agent use + +## ⚙ Configuration Details + +### OAuth Configuration + +```yaml +oauth: + token_store: + backend: memory # or 'redis' for production + refresh_leeway_seconds: 60 + flow_timeout_seconds: 300 + callback_base_url: http://localhost:8080 + +mcp: + servers: + github: + auth: + oauth: + enabled: true + scopes: ["read:org", "public_repo", "user:email"] + authorization_server: "https://github.com/login/oauth/authorize" + resource: "https://api.github.com" +``` + +### GitHub Scopes Required + +| Scope | Purpose | Required | +|-------|---------|----------| +| `read:org` | Search organizations | ✅ Yes | +| `public_repo` | Access public repositories | ✅ Yes | +| `user:email` | User information | ⚠ Recommended | +| `repo` | Private repositories | ❌ Optional | + +## 🔧 Production Deployment + +### Redis Token Storage + +For production with multiple processes: + +```yaml +oauth: + token_store: + backend: redis + redis_url: "redis://localhost:6379" + redis_prefix: "mcp_agent:oauth_tokens" +``` + +### Environment Variables + +```bash +export GITHUB_CLIENT_ID="your_client_id" +export GITHUB_CLIENT_SECRET="your_client_secret" +export REDIS_URL="redis://localhost:6379" +export OPENAI_API_KEY="your_openai_key" +``` + +### Security Best Practices + +1. **Never commit secrets** - Use `.gitignore` for `mcp_agent.secrets.yaml` +2. **Rotate tokens regularly** - Set up token refresh workflows +3. **Minimal scopes** - Only request necessary permissions +4. **Secure storage** - Use Redis or encrypted storage in production +5. 
**HTTPS callbacks** - Use HTTPS URLs for production OAuth callbacks + +## 🐛 Troubleshooting + +### Common Issues + +#### OAuth Flow Fails +``` +Error: OAuth error: access_denied +``` +**Solution**: Check callback URL matches OAuth app configuration + +#### Token Test Fails +``` +Error: Token test failed: 401 +``` +**Solution**: Verify token scopes and GitHub app permissions + +#### MCP Server Connection Fails +``` +Error: GitHub MCP server not found +``` +**Solution**: Install GitHub MCP server with `uvx install github-mcp-server` + +#### Import Errors +``` +ImportError: No module named 'mcp_agent' +``` +**Solution**: Install with `pip install mcp-agent[oauth]` + +### Debug Mode + +Enable detailed logging: + +```yaml +logger: + level: debug + debug_oauth: true +``` + +### Testing Tokens + +Test your GitHub token manually: + +```bash +curl -H "Authorization: Bearer YOUR_TOKEN" https://api.github.com/user +``` + +## 📚 Advanced Usage + +### Multiple OAuth Providers + +```yaml +mcp: + servers: + github: + auth: + oauth: + client_id: "github_client_id" + # ... GitHub config + + slack: + auth: + oauth: + client_id: "slack_client_id" + # ... Slack config +``` + +### Custom Token Refresh + +```python +async def refresh_github_token(old_token: str) -> str: + # Custom token refresh logic + async with aiohttp.ClientSession() as session: + # ... refresh implementation + return new_token +``` + +### Workflow Chaining + +```python +@app.async_tool +async def multi_step_analysis(app_ctx: Context, orgs: List[str]): + # Step 1: Search organizations + github_results = await search_organizations(orgs) + + # Step 2: Analyze with different service + analysis = await analyze_with_ai(github_results) + + # Step 3: Store results + await store_results(analysis) + + return analysis +``` + +## 🤝 Contributing + +1. Fork the repository +2. Create your feature branch +3. Add tests for new functionality +4. Ensure all examples work +5. Submit a pull request + +## 📄 License + +This example is part of the MCP Agent project and follows the same license terms. + +## 🔗 Related Resources + +- [MCP Agent Documentation](../../README.md) +- [GitHub MCP Server](https://github.com/github/github-mcp-server) +- [OAuth 2.0 Specification](https://tools.ietf.org/html/rfc6749) +- [GitHub OAuth Apps](https://docs.github.com/en/developers/apps/building-oauth-apps) +- [Model Context Protocol](https://modelcontextprotocol.io) + +--- + +For questions or issues with this example, please check the [main repository issues](../../issues) or create a new issue with the `oauth-example` label. \ No newline at end of file diff --git a/examples/oauth/preconfigured/main.py b/examples/oauth/preconfigured/main.py new file mode 100644 index 000000000..12fd60356 --- /dev/null +++ b/examples/oauth/preconfigured/main.py @@ -0,0 +1,184 @@ +""" +OAuth MCP Agent Example + +This example demonstrates how to use an MCP agent with OAuth authentication +to access the GitHub MCP server and search for organizations. 
+ +Features demonstrated: +- OAuth flow setup and configuration +- Connecting to GitHub MCP server with OAuth +- Using the search_orgs tool +- Error handling and token refresh +""" + +import asyncio +import json +import logging +from typing import Any, Dict, List + +from mcp_agent.app import MCPApp +from mcp_agent.mcp.gen_client import gen_client + +# Create the MCP app with OAuth configuration +app = MCPApp(name="oauth_github_example") + + +async def search_github_orgs(query: str, limit: int = 5) -> List[Dict[str, Any]]: + """ + Search for GitHub organizations using the GitHub MCP server with OAuth. + + Args: + query: Search query (e.g., 'microsoft', 'location:california') + limit: Maximum number of results to return + + Returns: + List of organization data from GitHub + """ + async with app.run() as github_app: + context = github_app.context + logger = github_app.logger + + logger.info(f"Searching GitHub organizations for: '{query}'") + + try: + # Connect to the GitHub MCP server with OAuth + async with gen_client( + "github", server_registry=context.server_registry, context=context + ) as github_client: + logger.info("Connected to GitHub MCP server with OAuth") + + # List available tools to verify connection + tools_result = await github_client.list_tools() + logger.info(f"Available tools: {len(tools_result.tools)} tools found") + + # Find the search_orgs tool + search_orgs_tool = None + for tool in tools_result.tools: + if tool.name == "search_orgs": + search_orgs_tool = tool + break + + if not search_orgs_tool: + logger.error("search_orgs tool not found") + return [] + + logger.info(f"Found search_orgs tool: {search_orgs_tool.description}") + + # Call the search_orgs tool + result = await github_client.call_tool( + "search_orgs", + { + "query": query, + "perPage": min(limit, 100), # GitHub API max is 100 + "sort": "best-match", + "order": "desc", + }, + ) + + logger.info("Search completed, processing results...") + + # Parse and return the results + if result.content: + # The result content should contain the organization data + organizations = [] + for content_item in result.content: + if hasattr(content_item, "text"): + try: + # Try to parse as JSON if it's structured data + data = json.loads(content_item.text) + if isinstance(data, dict) and "items" in data: + organizations.extend(data["items"][:limit]) + elif isinstance(data, list): + organizations.extend(data[:limit]) + else: + organizations.append(data) + except json.JSONDecodeError: + # If not JSON, treat as text description + organizations.append({"description": content_item.text}) + + logger.info(f"Found {len(organizations)} organizations") + return organizations[:limit] + + return [] + + except Exception as e: + logger.error(f"Error searching GitHub organizations: {e}") + # Check if it's an OAuth-related error + if "authentication" in str(e).lower() or "oauth" in str(e).lower(): + logger.error( + "Authentication failed. Please check your OAuth configuration " + "and ensure your GitHub token is valid." + ) + raise + + +async def demonstrate_org_search(): + """ + Demonstrate searching for different types of organizations. + """ + search_queries = [ + "microsoft", + # "location:california", + # "created:>=2020-01-01", + # "language:python", + # "repositories:>100" + ] + + for query in search_queries: + print(f"\n{'=' * 50}") + print(f"Searching for: {query}") + print("=" * 50) + + try: + orgs = await search_github_orgs(query, limit=3) + + if orgs: + for i, org in enumerate(orgs, 1): + print(f"\n{i}. 
{org.get('login', 'Unknown')}") + if "description" in org and org["description"]: + print(f" Description: {org['description']}") + if "html_url" in org: + print(f" URL: {org['html_url']}") + if "public_repos" in org: + print(f" Public repos: {org['public_repos']}") + if "location" in org and org["location"]: + print(f" Location: {org['location']}") + else: + print("No organizations found.") + + except Exception as e: + print(f"Error: {e}") + continue + + +async def main(): + """ + Main function demonstrating various OAuth MCP agent usage patterns. + """ + print("OAuth MCP Agent Example - GitHub Organization Search") + print("=" * 60) + + try: + # Demonstrate basic organization search + await demonstrate_org_search() + + except Exception as e: + print(f"\nExample failed with error: {e}") + print("\nPlease ensure:") + print("1. You have configured your GitHub OAuth app correctly") + print("2. Your mcp_agent.secrets.yaml file contains valid OAuth credentials") + print("3. The GitHub MCP server is properly installed and accessible") + print("4. Your OAuth token has the required scopes (read:org)") + + # Log the full error for debugging + logging.exception("Full error details:") + + +if __name__ == "__main__": + # Set up logging to show detailed information + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + ) + + asyncio.run(main()) diff --git a/examples/oauth/preconfigured/mcp_agent.config.yaml b/examples/oauth/preconfigured/mcp_agent.config.yaml new file mode 100644 index 000000000..61fa80456 --- /dev/null +++ b/examples/oauth/preconfigured/mcp_agent.config.yaml @@ -0,0 +1,115 @@ +$schema: ../../schema/mcp-agent.config.schema.json + +# OAuth MCP Agent Configuration +# This configuration demonstrates how to set up OAuth authentication +# for accessing OAuth-protected MCP servers like the GitHub MCP server + +execution_engine: asyncio + +# Logging configuration +logger: + type: console + level: debug + path: "./mcp-agent.log" + # Enable detailed OAuth and authentication logging + debug_oauth: true + +# OpenTelemetry tracing (optional) +tracing: + enabled: true + service_name: "OAuth-MCP-Agent-Example" + exporters: ["console"] + +# OAuth-specific configuration +oauth: + # Token store configuration + token_store: + backend: memory # Use 'redis' for production with multiple processes + refresh_leeway_seconds: 60 # Refresh tokens 60 seconds before expiry + + # OAuth flow timeout (5 minutes) + flow_timeout_seconds: 300 + + # Callback base URL for OAuth flows (adjust if running behind proxy) + callback_base_url: http://localhost:8080 + +# MCP server configurations +mcp: + servers: + # GitHub MCP Server with OAuth authentication + github: + # Use the remote GitHub MCP server + transport: "streamable_http" + url: "https://api.githubcopilot.com/mcp/" + + # OAuth authentication configuration + auth: + oauth: + enabled: true + # Required scopes for GitHub MCP server functionality + scopes: + - "read:org" # Required for search_orgs tool + - "public_repo" # Access to public repositories + - "user:email" # User information access + + # GitHub OAuth endpoints + authorization_server: "https://github.com/login/oauth" + resource: "https://api.githubcopilot.com/mcp" + + # Use internal callback handling (recommended) + use_internal_callback: true + + # Example: Additional MCP server (filesystem for local operations) + filesystem: + command: "npx" + args: ["-y", "@modelcontextprotocol/server-filesystem", "."] + # No auth required for local filesystem 
server + + # Example: Fetch server (no auth required) + fetch: + command: "uvx" + args: ["mcp-server-fetch"] + +# Authorization server configuration (if exposing this agent as an MCP server) +authorization: + enabled: false # Set to true if you want to expose this agent with OAuth protection + # issuer_url: "https://your-oauth-provider.com" + # resource_server_url: "https://your-agent-server.com" + # required_scopes: ["agent:read", "agent:execute"] + +# Model configuration (for LLM-based agents) +openai: + # API keys are stored in mcp_agent.secrets.yaml + model: "gpt-4" + temperature: 0.1 + max_tokens: 2000 + +anthropic: + # API keys are stored in mcp_agent.secrets.yaml + model: "claude-3-sonnet-20240229" + max_tokens: 2000 + +# Agent-specific settings +agents: + # Configuration for the GitHubOrganizationAnalyzer agent + default_timeout: 30 + max_retries: 3 + + # GitHub-specific settings + github: + max_orgs_per_query: 10 + default_search_sort: "best-match" + include_detailed_analysis: true + +# Workflow configuration +workflows: + # Default workflow timeout (30 minutes) + default_timeout_minutes: 30 + + # Enable workflow pre-authorization + enable_pre_auth: true + + # Workflow-specific settings + github_analysis: + max_concurrent_searches: 5 + analysis_depth: "detailed" # "basic" or "detailed" diff --git a/examples/oauth/preconfigured/mcp_agent.secrets.yaml.example b/examples/oauth/preconfigured/mcp_agent.secrets.yaml.example new file mode 100644 index 000000000..1feea2c6f --- /dev/null +++ b/examples/oauth/preconfigured/mcp_agent.secrets.yaml.example @@ -0,0 +1,41 @@ +$schema: ../../schema/mcp-agent.config.schema.json + +# OAuth MCP Agent Secrets Template +# +# This file contains sensitive configuration like API keys and OAuth credentials. +# Copy this file to mcp_agent.secrets.yaml and fill in your actual values. +# + +# OpenAI API configuration +openai: + api_key: "sk-your-openai-api-key-here" + +# Anthropic API configuration +anthropic: + api_key: "sk-ant-api03-your-anthropic-api-key-here" + +# GitHub OAuth App Configuration +# To set up a GitHub OAuth app: +# 1. Go to GitHub Settings > Developer settings > OAuth Apps +# 2. Click "New OAuth App" +# 3. Fill in the application details: +# - Application name: "MCP Agent OAuth Example" +# - Homepage URL: "https://github.com/your-username/your-repo" +# - Authorization callback URL: "http://localhost:8080/oauth/callback" +# 4. Click "Register application" +# 5. Copy the Client ID and generate a Client Secret +# 6. Fill in the values below + +mcp: + servers: + github: + auth: + oauth: + # GitHub OAuth App credentials + client_id: "your-github-client-id-here" + client_secret: "your-github-client-secret-here" + + # Optional: Pre-configured access token (from oauth_demo.py or manual setup) + # If you have a valid access token, you can specify it here to skip the OAuth flow + access_token: "your_github_access_token_here" + # refresh_token: "your_refresh_token_here" # If available diff --git a/examples/oauth/preconfigured/oauth_demo.py b/examples/oauth/preconfigured/oauth_demo.py new file mode 100644 index 000000000..37fbd6846 --- /dev/null +++ b/examples/oauth/preconfigured/oauth_demo.py @@ -0,0 +1,400 @@ +""" +Standalone OAuth Flow Demonstration + +This script demonstrates the OAuth flow for authenticating with GitHub +and storing tokens for use with MCP agents. 
This is useful for: +- Understanding the OAuth process +- Testing token acquisition and storage +- Debugging authentication issues +- Setting up tokens for the first time + +Run this script interactively to authenticate with GitHub and store tokens. +""" + +import asyncio +import json +import os +import secrets +import time +import urllib.parse +import webbrowser +from typing import Any, Dict, Optional + +import aiohttp +from aiohttp import web + + +class GitHubOAuthDemo: + """ + Demonstration of GitHub OAuth flow for MCP agent authentication. + """ + + def __init__(self, client_id: str, client_secret: str, redirect_uri: str = None): + """ + Initialize OAuth demo. + + Args: + client_id: GitHub OAuth app client ID + client_secret: GitHub OAuth app client secret + redirect_uri: OAuth redirect URI (defaults to localhost) + """ + self.client_id = client_id + self.client_secret = client_secret + self.redirect_uri = ( + redirect_uri or "http://localhost:8080/internal/oauth/callback" + ) + self.state = secrets.token_urlsafe(32) # CSRF protection + self.access_token: Optional[str] = None + self.refresh_token: Optional[str] = None + self.token_expires_at: Optional[float] = None + + def get_authorization_url(self, scopes: list = None) -> str: + """ + Generate the GitHub authorization URL. + + Args: + scopes: List of OAuth scopes to request + + Returns: + Authorization URL for the user to visit + """ + if scopes is None: + scopes = ["read:org", "public_repo"] # Default scopes for GitHub MCP server + + params = { + "client_id": self.client_id, + "redirect_uri": self.redirect_uri, + "scope": " ".join(scopes), + "state": self.state, + "response_type": "code", + } + + base_url = "https://github.com/login/oauth/authorize" + return f"{base_url}?{urllib.parse.urlencode(params)}" + + async def exchange_code_for_token(self, code: str, state: str) -> Dict[str, Any]: + """ + Exchange authorization code for access token. + + Args: + code: Authorization code from GitHub + state: State parameter for CSRF protection + + Returns: + Token response data + + Raises: + ValueError: If state doesn't match or token exchange fails + """ + if state != self.state: + raise ValueError("State parameter mismatch - possible CSRF attack") + + token_url = "https://github.com/login/oauth/access_token" + data = { + "client_id": self.client_id, + "client_secret": self.client_secret, + "code": code, + "redirect_uri": self.redirect_uri, + } + + headers = { + "Accept": "application/json", + "User-Agent": "MCP-Agent-OAuth-Demo/1.0", + } + + async with aiohttp.ClientSession() as session: + async with session.post(token_url, data=data, headers=headers) as response: + if response.status != 200: + raise ValueError(f"Token exchange failed: {response.status}") + + token_data = await response.json() + + if "error" in token_data: + raise ValueError(f"Token error: {token_data['error_description']}") + + self.access_token = token_data.get("access_token") + self.refresh_token = token_data.get("refresh_token") + + # Calculate expiration time + expires_in = token_data.get("expires_in") + if expires_in: + self.token_expires_at = time.time() + expires_in + else: + # GitHub tokens typically don't expire, but set a far future date + self.token_expires_at = time.time() + (365 * 24 * 3600) # 1 year + + return token_data + + async def test_token(self) -> Dict[str, Any]: + """ + Test the access token by making a simple GitHub API call. 
+ + Returns: + User information from GitHub API + """ + if not self.access_token: + raise ValueError("No access token available") + + headers = { + "Authorization": f"Bearer {self.access_token}", + "Accept": "application/vnd.github.v3+json", + "User-Agent": "MCP-Agent-OAuth-Demo/1.0", + } + + async with aiohttp.ClientSession() as session: + # Test with a simple user info call + async with session.get( + "https://api.github.com/user", headers=headers + ) as response: + if response.status != 200: + raise ValueError(f"Token test failed: {response.status}") + + user_data = await response.json() + return user_data + + def get_token_for_mcp_agent(self) -> Dict[str, Any]: + """ + Get token data in the format expected by MCP agent workflow_pre_auth. + + Returns: + Token data dictionary for MCP agent + """ + if not self.access_token: + raise ValueError("No access token available") + + return { + "access_token": self.access_token, + "refresh_token": self.refresh_token, + "server_name": "github", + "scopes": ["read:org", "public_repo"], + "expires_at": self.token_expires_at, + "authorization_server": "https://github.com/login/oauth/authorize", + "token_type": "Bearer", + } + + async def save_token_to_file(self, filename: str = "github_oauth_token.json"): + """ + Save the token to a JSON file for later use. + + Args: + filename: File to save token data + """ + token_data = self.get_token_for_mcp_agent() + + with open(filename, "w") as f: + json.dump(token_data, f, indent=2) + + print(f"Token saved to {filename}") + print( + "You can now use this token with the MCP agent workflow_pre_auth endpoint." + ) + + async def run_oauth_flow(self, scopes: list = None) -> Dict[str, Any]: + """ + Run the complete OAuth flow interactively. + + Args: + scopes: OAuth scopes to request + + Returns: + Complete token data + """ + print("Starting GitHub OAuth flow...") + print("=" * 50) + + # Step 1: Generate authorization URL + auth_url = self.get_authorization_url(scopes) + print("\n1. Please visit this URL to authorize the application:") + print(f" {auth_url}") + print("\n2. After authorization, you'll be redirected to a callback URL.") + + # Step 2: Start local server to handle callback + callback_received = asyncio.Event() + callback_data = {} + + async def handle_callback(request): + nonlocal callback_data + try: + code = request.query.get("code") + state = request.query.get("state") + error = request.query.get("error") + + if error: + callback_data["error"] = error + callback_data["error_description"] = request.query.get( + "error_description", "" + ) + else: + callback_data["code"] = code + callback_data["state"] = state + + callback_received.set() + + # Return a simple success page + html = """ + + +

+                <h1>OAuth Authorization Complete</h1>
+                <p>You can close this window and return to the terminal.</p>
+ + + + """ + return web.Response(text=html, content_type="text/html") + + except Exception as e: + print(f"Error in callback handler: {e}") + return web.Response(text=f"Error: {e}", status=500) + + # Start local server + app = web.Application() + app.router.add_get("/internal/oauth/callback", handle_callback) + + runner = web.AppRunner(app) + await runner.setup() + site = web.TCPSite(runner, "localhost", 8080) + await site.start() + + print("\n3. Local callback server started on http://localhost:8080") + print(" Opening browser to authorization URL...") + + # Open browser automatically + try: + webbrowser.open(auth_url) + except Exception: + print(" (Could not open browser automatically)") + + print("\n4. Waiting for authorization callback...") + + # Wait for callback with timeout + try: + await asyncio.wait_for( + callback_received.wait(), timeout=300 + ) # 5 minute timeout + except asyncio.TimeoutError: + print(" Timeout waiting for authorization callback") + await runner.cleanup() + raise ValueError("OAuth flow timeout") + + await runner.cleanup() + + # Step 3: Handle callback result + if "error" in callback_data: + raise ValueError( + f"OAuth error: {callback_data['error']} - {callback_data.get('error_description', '')}" + ) + + code = callback_data.get("code") + state = callback_data.get("state") + + if not code: + raise ValueError("No authorization code received") + + print("5. Authorization code received, exchanging for access token...") + + # Step 4: Exchange code for token + await self.exchange_code_for_token(code, state) + print(" ✓ Access token obtained successfully") + + # Step 5: Test the token + print("6. Testing access token...") + try: + user_info = await self.test_token() + username = user_info.get("login", "unknown") + print(f" ✓ Token test successful - authenticated as: {username}") + except Exception as e: + print(f" ⚠ Token test failed: {e}") + + return self.get_token_for_mcp_agent() + + +async def main(): + """ + Main interactive OAuth demonstration. + """ + print("GitHub OAuth Demo for MCP Agent") + print("=" * 40) + + # Check for environment variables + client_id = os.getenv("GITHUB_CLIENT_ID") + client_secret = os.getenv("GITHUB_CLIENT_SECRET") + + if not client_id or not client_secret: + print("\nTo use this demo, you need to set up a GitHub OAuth App:") + print("1. Go to GitHub Settings > Developer settings > OAuth Apps") + print("2. Click 'New OAuth App'") + print( + "3. Set Authorization callback URL to: http://localhost:8080/internal/oauth/callback" + ) + print("4. Set environment variables:") + print(" export GITHUB_CLIENT_ID='your_client_id'") + print(" export GITHUB_CLIENT_SECRET='your_client_secret'") + print("\nAlternatively, you can pass them as command line arguments.") + return + + try: + # Create OAuth demo instance + oauth_demo = GitHubOAuthDemo(client_id, client_secret) + + # Run the OAuth flow + scopes = ["read:org", "public_repo", "user:email"] + token_data = await oauth_demo.run_oauth_flow(scopes) + + print(token_data) + + print("\n" + "=" * 50) + print("OAuth Flow Completed Successfully!") + print("=" * 50) + + print("\nToken Details:") + print(f" Access Token: {token_data['access_token'][:20]}...") + print(f" Expires At: {time.ctime(token_data['expires_at'])}") + print(f" Scopes: {', '.join(token_data['scopes'])}") + + # Save token to file + save_choice = input("\nSave token to file? 
(y/n): ").lower().strip() + if save_choice in ["y", "yes"]: + filename = input( + "Enter filename (default: github_oauth_token.json): " + ).strip() + if not filename: + filename = "github_oauth_token.json" + await oauth_demo.save_token_to_file(filename) + + # Show usage instructions + print("\n" + "=" * 50) + print("Next Steps:") + print("=" * 50) + print("1. Use this token with the MCP agent workflow_pre_auth endpoint:") + + example_usage = { + "workflow_name": "github_analysis_workflow", + "tokens": [token_data], + } + print(f" {json.dumps(example_usage, indent=2)}") + + print("\n2. Or configure it in your mcp_agent.secrets.yaml:") + secrets_example = { + "mcp": { + "servers": { + "github": { + "auth": { + "oauth": { + "access_token": token_data["access_token"], + "refresh_token": token_data.get("refresh_token"), + "scopes": token_data["scopes"], + } + } + } + } + } + } + print(f" {json.dumps(secrets_example, indent=2)}") + + except Exception as e: + print(f"\nOAuth flow failed: {e}") + print("Please check your GitHub OAuth app configuration and try again.") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/oauth/preconfigured/requirements.txt b/examples/oauth/preconfigured/requirements.txt new file mode 100644 index 000000000..12966f2cf --- /dev/null +++ b/examples/oauth/preconfigured/requirements.txt @@ -0,0 +1,37 @@ +# OAuth MCP Agent Example Requirements +# +# This file specifies the Python dependencies needed to run the OAuth examples. +# Install with: pip install -r requirements.txt + +# Core MCP Agent with OAuth support +mcp-agent[oauth,cli]>=0.1.27 + +# Additional dependencies for OAuth flow and HTTP handling +aiohttp>=3.11.13 +# For OAuth flow demonstration and token management +urllib3>=2.0.0 + +# GitHub MCP Server (install separately with uvx or npm) +# Note: The GitHub MCP server is installed via uvx, not pip +# Run: uvx install github-mcp-server + +# Optional: Enhanced JSON handling +ujson>=5.0.0 + +# Optional: Enhanced async support +trio>=0.30.0 + +# Optional: Redis for production token storage +redis>=4.0.0 + +# Development and testing dependencies (optional) +pytest>=7.4.0 +pytest-asyncio>=0.21.1 + +# For the interactive OAuth demo +webbrowser # Standard library - no explicit install needed + +# Note: Additional MCP servers can be installed as needed: +# - Filesystem server: npm install -g @modelcontextprotocol/server-filesystem +# - Fetch server: uvx install mcp-server-fetch +# - Other servers: See https://github.com/modelcontextprotocol/servers \ No newline at end of file diff --git a/examples/oauth/preconfigured/workflow_example.py b/examples/oauth/preconfigured/workflow_example.py new file mode 100644 index 000000000..e161e41a3 --- /dev/null +++ b/examples/oauth/preconfigured/workflow_example.py @@ -0,0 +1,437 @@ +""" +OAuth Workflow Pre-Authorization Example + +This example demonstrates how to use the workflow_pre_auth endpoint to +pre-store OAuth tokens for a workflow, then execute the workflow to +access GitHub MCP server tools. 
+ +Features demonstrated: +- Using the workflow_pre_auth endpoint to store OAuth tokens +- Creating a workflow that uses pre-authorized tokens +- Accessing multiple MCP servers with different tokens +- Error handling for token expiration and OAuth issues +""" + +import asyncio +import json +import time +from typing import Any, Dict, List + +from mcp_agent.app import MCPApp +from mcp_agent.agents.agent import Agent +from mcp_agent.core.context import Context + + +class GitHubOrganizationAnalyzer(Agent): + """ + An agent that analyzes GitHub organizations using pre-authorized OAuth tokens. + """ + + def __init__(self, context: Context): + super().__init__(context=context) + self.name = "github_org_analyzer" + + async def analyze_organizations( + self, queries: List[str], detailed_analysis: bool = True + ) -> Dict[str, Any]: + """ + Analyze multiple organizations based on search queries. + + Args: + queries: List of search queries for organizations + detailed_analysis: Whether to fetch detailed information + + Returns: + Analysis results for all organizations + """ + logger = self.context.logger + results = {"organizations": [], "summary": {}, "errors": []} + + try: + # The OAuth tokens should be pre-authorized for this workflow + # and available through the context + logger.info(f"Starting analysis of {len(queries)} organization queries") + + for query in queries: + try: + orgs = await self._search_organizations(query) + + for org in orgs: + org_analysis = { + "query": query, + "organization": org.get("login", "unknown"), + "description": org.get("description", ""), + "url": org.get("html_url", ""), + "public_repos": org.get("public_repos", 0), + "followers": org.get("followers", 0), + "location": org.get("location", ""), + "created_at": org.get("created_at", ""), + } + + if detailed_analysis: + # Add more detailed analysis + org_analysis.update( + await self._analyze_organization_details(org) + ) + + results["organizations"].append(org_analysis) + + except Exception as e: + error_msg = f"Error processing query '{query}': {e}" + logger.error(error_msg) + results["errors"].append(error_msg) + + # Generate summary + results["summary"] = self._generate_summary(results["organizations"]) + + logger.info( + f"Analysis completed: {len(results['organizations'])} organizations analyzed" + ) + return results + + except Exception as e: + logger.error(f"Organization analysis failed: {e}") + raise + + async def _search_organizations(self, query: str) -> List[Dict[str, Any]]: + """Search for organizations using the GitHub MCP server.""" + from mcp_agent.mcp.gen_client import gen_client + + async with gen_client( + "github", server_registry=self.context.server_registry, context=self.context + ) as github_client: + result = await github_client.call_tool( + "search_orgs", + {"query": query, "perPage": 10, "sort": "best-match", "order": "desc"}, + ) + + organizations = [] + if result.content: + for content_item in result.content: + if hasattr(content_item, "text"): + try: + data = json.loads(content_item.text) + if isinstance(data, dict) and "items" in data: + organizations.extend(data["items"]) + elif isinstance(data, list): + organizations.extend(data) + except json.JSONDecodeError: + pass + + return organizations + + async def _analyze_organization_details( + self, org: Dict[str, Any] + ) -> Dict[str, Any]: + """Analyze detailed information about an organization.""" + details = { + "activity_score": self._calculate_activity_score(org), + "size_category": self._categorize_size(org.get("public_repos", 0)), + 
"age_years": self._calculate_age(org.get("created_at", "")), + } + + return details + + def _calculate_activity_score(self, org: Dict[str, Any]) -> float: + """Calculate a simple activity score based on available metrics.""" + repos = org.get("public_repos", 0) + followers = org.get("followers", 0) + + # Simple scoring algorithm + score = (repos * 0.1) + (followers * 0.01) + return min(score, 100.0) # Cap at 100 + + def _categorize_size(self, repo_count: int) -> str: + """Categorize organization size based on repository count.""" + if repo_count < 10: + return "small" + elif repo_count < 50: + return "medium" + elif repo_count < 200: + return "large" + else: + return "enterprise" + + def _calculate_age(self, created_at: str) -> float: + """Calculate organization age in years.""" + if not created_at: + return 0.0 + + try: + from datetime import datetime + + created = datetime.fromisoformat(created_at.replace("Z", "+00:00")) + now = datetime.now(created.tzinfo) + return (now - created).days / 365.25 + except Exception: + return 0.0 + + def _generate_summary(self, organizations: List[Dict[str, Any]]) -> Dict[str, Any]: + """Generate summary statistics from organization analysis.""" + if not organizations: + return {"total": 0, "message": "No organizations analyzed"} + + total_repos = sum(org.get("public_repos", 0) for org in organizations) + total_followers = sum(org.get("followers", 0) for org in organizations) + + size_categories = {} + for org in organizations: + category = org.get("size_category", "unknown") + size_categories[category] = size_categories.get(category, 0) + 1 + + return { + "total_organizations": len(organizations), + "total_public_repos": total_repos, + "total_followers": total_followers, + "average_repos_per_org": total_repos / len(organizations), + "size_distribution": size_categories, + "top_organizations": sorted( + organizations, key=lambda x: x.get("activity_score", 0), reverse=True + )[:5], + } + + +# Create workflow using the @app.async_tool decorator +app = MCPApp(name="oauth_workflow_example") + + +@app.async_tool +async def analyze_github_ecosystem( + app_ctx: Context, focus_areas: List[str], include_details: bool = True +) -> Dict[str, Any]: + """ + Analyze the GitHub ecosystem based on focus areas. + + This workflow demonstrates using pre-authorized OAuth tokens + to analyze organizations across different domains. 
+ + Args: + focus_areas: Areas to focus on (e.g., ["AI/ML", "cloud", "security"]) + include_details: Whether to include detailed analysis + + Returns: + Comprehensive analysis of the GitHub ecosystem + """ + logger = app_ctx.logger + logger.info(f"Starting GitHub ecosystem analysis for: {focus_areas}") + + # Create the analyzer agent + analyzer = GitHubOrganizationAnalyzer(context=app_ctx) + + # Map focus areas to search queries + query_mapping = { + "AI/ML": [ + "machine-learning", + "artificial-intelligence", + "deep-learning", + "tensorflow", + "pytorch", + ], + "cloud": ["cloud-computing", "aws", "azure", "kubernetes", "docker"], + "security": ["cybersecurity", "security", "encryption", "vulnerability"], + "web": ["web-development", "javascript", "react", "vue", "angular"], + "mobile": ["mobile-development", "android", "ios", "react-native", "flutter"], + "data": ["data-science", "analytics", "big-data", "database", "sql"], + "devtools": ["developer-tools", "ci-cd", "testing", "monitoring", "automation"], + } + + all_queries = [] + for area in focus_areas: + queries = query_mapping.get(area.lower(), [area.lower()]) + all_queries.extend(queries) + + # Remove duplicates while preserving order + unique_queries = list(dict.fromkeys(all_queries)) + + logger.info(f"Executing {len(unique_queries)} organization searches") + + try: + # Perform the analysis + analysis_results = await analyzer.analyze_organizations( + queries=unique_queries, detailed_analysis=include_details + ) + + # Add ecosystem-level insights + ecosystem_analysis = { + "focus_areas": focus_areas, + "timestamp": time.time(), + "queries_executed": unique_queries, + "results": analysis_results, + "insights": _generate_ecosystem_insights(analysis_results), + } + + logger.info("GitHub ecosystem analysis completed successfully") + return ecosystem_analysis + + except Exception as e: + logger.error(f"Ecosystem analysis failed: {e}") + raise + + +def _generate_ecosystem_insights(results: Dict[str, Any]) -> Dict[str, Any]: + """Generate high-level insights from the ecosystem analysis.""" + organizations = results.get("organizations", []) + + if not organizations: + return {"message": "No data available for insights"} + + # Find trends and patterns + insights = { + "dominant_languages": _analyze_language_trends(organizations), + "geographic_distribution": _analyze_geographic_distribution(organizations), + "maturity_analysis": _analyze_organization_maturity(organizations), + "activity_patterns": _analyze_activity_patterns(organizations), + } + + return insights + + +def _analyze_language_trends(organizations: List[Dict[str, Any]]) -> Dict[str, Any]: + """Analyze programming language trends from organization data.""" + # This is a simplified example - in a real implementation, + # you might use additional GitHub API calls to get language data + return { + "message": "Language trend analysis would require additional API calls", + "suggestion": "Use repository listing and language detection APIs", + } + + +def _analyze_geographic_distribution( + organizations: List[Dict[str, Any]], +) -> Dict[str, Any]: + """Analyze geographic distribution of organizations.""" + locations = {} + for org in organizations: + location = org.get("location", "").strip() + if location: + locations[location] = locations.get(location, 0) + 1 + + return { + "total_with_location": len( + [org for org in organizations if org.get("location")] + ), + "top_locations": dict( + sorted(locations.items(), key=lambda x: x[1], reverse=True)[:10] + ), + } + + +def 
_analyze_organization_maturity( + organizations: List[Dict[str, Any]], +) -> Dict[str, Any]: + """Analyze the maturity of organizations.""" + mature_count = sum(1 for org in organizations if org.get("age_years", 0) > 5) + established_count = sum( + 1 for org in organizations if 2 <= org.get("age_years", 0) <= 5 + ) + new_count = sum(1 for org in organizations if org.get("age_years", 0) < 2) + + return { + "mature_orgs": mature_count, # > 5 years + "established_orgs": established_count, # 2-5 years + "new_orgs": new_count, # < 2 years + "maturity_ratio": mature_count / len(organizations) if organizations else 0, + } + + +def _analyze_activity_patterns(organizations: List[Dict[str, Any]]) -> Dict[str, Any]: + """Analyze activity patterns across organizations.""" + if not organizations: + return {} + + activity_scores = [org.get("activity_score", 0) for org in organizations] + + return { + "average_activity": sum(activity_scores) / len(activity_scores), + "high_activity_count": sum(1 for score in activity_scores if score > 75), + "low_activity_count": sum(1 for score in activity_scores if score < 25), + "activity_distribution": { + "high": sum(1 for score in activity_scores if score > 75), + "medium": sum(1 for score in activity_scores if 25 <= score <= 75), + "low": sum(1 for score in activity_scores if score < 25), + }, + } + + +async def demonstrate_pre_auth_workflow(): + """ + Demonstrate the workflow with pre-authorization. + """ + print("OAuth Workflow Pre-Authorization Example") + print("=" * 50) + + # Note: In a real scenario, you would use the MCP agent server + # to call the workflow_pre_auth endpoint before running the workflow + print("\n1. Pre-authorization step:") + print(" Before running this workflow, you should pre-authorize OAuth tokens:") + print(" Use the workflow_pre_auth endpoint with the following structure:") + + example_tokens = [ + { + "access_token": "github_oauth_access_token_here", + "refresh_token": "github_oauth_refresh_token_here", + "server_name": "github", + "scopes": ["read:org", "public_repo"], + "authorization_server": "https://github.com/login/oauth/authorize", + } + ] + + print(f" Token structure: {json.dumps(example_tokens, indent=2)}") + + print("\n2. Running workflow with pre-authorized tokens:") + + try: + async with app.run() as workflow_app: + # Simulate workflow execution + # In practice, this would be called through the MCP agent server + context = workflow_app.context + + result = await analyze_github_ecosystem( + app_ctx=context, + focus_areas=["AI/ML", "cloud", "security"], + include_details=True, + ) + + print("\n3. Workflow Results:") + print(f" - Focus areas analyzed: {result['focus_areas']}") + print(f" - Queries executed: {len(result['queries_executed'])}") + print( + f" - Organizations found: {result['results']['summary'].get('total_organizations', 0)}" + ) + + if result["results"]["errors"]: + print(f" - Errors encountered: {len(result['results']['errors'])}") + + print("\n4. 
Ecosystem Insights:") + insights = result["insights"] + if "geographic_distribution" in insights: + top_locations = insights["geographic_distribution"].get( + "top_locations", {} + ) + if top_locations: + print(f" - Top locations: {list(top_locations.keys())[:3]}") + + if "maturity_analysis" in insights: + maturity = insights["maturity_analysis"] + print(f" - Mature organizations: {maturity.get('mature_orgs', 0)}") + print(f" - Maturity ratio: {maturity.get('maturity_ratio', 0):.2%}") + + except Exception as e: + print(f" Workflow failed: {e}") + print("\n This is expected if OAuth tokens are not properly configured.") + print(" To run this example successfully:") + print(" 1. Set up a GitHub OAuth app") + print(" 2. Configure mcp_agent.config.yaml with OAuth settings") + print(" 3. Use workflow_pre_auth to store valid tokens") + print(" 4. Run the workflow through the MCP agent server") + + +async def main(): + """ + Main function demonstrating the workflow pre-authorization pattern. + """ + await demonstrate_pre_auth_workflow() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/oauth/workflow_pre_auth/client.py b/examples/oauth/workflow_pre_auth/client.py new file mode 100644 index 000000000..4d9ae1a7c --- /dev/null +++ b/examples/oauth/workflow_pre_auth/client.py @@ -0,0 +1,204 @@ +import asyncio +import time +import os +import sys + +from datetime import timedelta +from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream +from mcp import ClientSession +from mcp.types import CallToolResult, LoggingMessageNotificationParams +from mcp_agent.app import MCPApp +from mcp_agent.config import MCPServerSettings +from mcp_agent.core.context import Context +from mcp_agent.mcp.gen_client import gen_client +from mcp_agent.mcp.mcp_agent_client_session import MCPAgentClientSession +from mcp_agent.human_input.console_handler import console_input_callback +from mcp_agent.elicitation.handler import console_elicitation_callback + +from rich import print + +try: + from exceptiongroup import ExceptionGroup as _ExceptionGroup # Python 3.10 backport +except Exception: # pragma: no cover + _ExceptionGroup = None # type: ignore +try: + from anyio import BrokenResourceError as _BrokenResourceError +except Exception: # pragma: no cover + _BrokenResourceError = None # type: ignore + +# Get GitHub access token from environment or ask user +access_token = os.getenv("GITHUB_ACCESS_TOKEN") + +if not access_token: + print("\nGitHub access token not found in environment variable GITHUB_ACCESS_TOKEN") + print("\nTo get a GitHub access token:") + print("1. Run the oauth_demo.py script from examples/oauth/ to get a fresh token") + print("2. Or go to GitHub Settings > Developer settings > Personal access tokens") + print("3. 
Create a token with 'read:org' and 'public_repo' scopes") + print("\nThen set the token:") + print("export GITHUB_ACCESS_TOKEN='your_token_here'") + +# Verify token format +if not access_token.startswith(("gho_", "ghp_", "github_pat_")): + print( + f"Warning: Token doesn't look like a GitHub token (got: {access_token[:10]}...)" + ) + print("GitHub tokens usually start with 'gho_', 'ghp_', or 'github_pat_'") + + +async def main(): + # Create MCPApp to get the server registry + app = MCPApp( + name="workflow_mcp_client", + human_input_callback=console_input_callback, + elicitation_callback=console_elicitation_callback, + ) + async with app.run() as client_app: + logger = client_app.logger + context = client_app.context + + # Connect to the workflow server + logger.info("Connecting to workflow server...") + + # Override the server configuration to point to our local script + context.server_registry.registry["basic_agent_server"] = MCPServerSettings( + name="basic_agent_server", + description="Local workflow server running the basic agent example", + transport="sse", + url="http://127.0.0.1:8000/sse", + # command="uv", + # args=["run", "main.py"], + ) + + # Define a logging callback to receive server-side log notifications + async def on_server_log(params: LoggingMessageNotificationParams) -> None: + level = params.level.upper() + name = params.logger or "server" + print(f"[SERVER LOG] [{level}] [{name}] {params.data}") + + # Provide a client session factory that installs our logging callback + # and prints non-logging notifications to the console + class ConsolePrintingClientSession(MCPAgentClientSession): + async def _received_notification(self, notification): # type: ignore[override] + try: + method = getattr(notification.root, "method", None) + except Exception: + method = None + + # Avoid duplicating server log prints (handled by logging_callback) + if method and method != "notifications/message": + try: + data = notification.model_dump() + except Exception: + data = str(notification) + print(f"[SERVER NOTIFY] {method}: {data}") + + return await super()._received_notification(notification) + + def make_session( + read_stream: MemoryObjectReceiveStream, + write_stream: MemoryObjectSendStream, + read_timeout_seconds: timedelta | None, + context: Context | None = None, + ) -> ClientSession: + return ConsolePrintingClientSession( + read_stream=read_stream, + write_stream=write_stream, + read_timeout_seconds=read_timeout_seconds, + logging_callback=on_server_log, + context=context, + ) + + try: + async with gen_client( + "basic_agent_server", + context.server_registry, + client_session_factory=make_session, + ) as server: + try: + await server.set_logging_level("info") + except Exception: + # Older servers may not support logging capability + print("[client] Server does not support logging/setLevel") + + # List available tools + tools_result = await server.list_tools() + logger.info( + "Available tools:", + data={"tools": [tool.name for tool in tools_result.tools]}, + ) + + if len(sys.argv) < 2 or sys.argv[1] != "--skip-pre-auth": + print("Performing pre-auth") + await server.call_tool( + "workflows-pre-auth", + arguments={ + "workflow_name": "github_org_search", + "tokens": [ + { + "access_token": access_token, + "server_name": "github", + } + ], + }, + ) + + print( + await server.call_tool("github_org_search", {"query": "lastmileai"}) + ) + except Exception as e: + # Tolerate benign shutdown races from stdio client (BrokenResourceError within ExceptionGroup) + if _ExceptionGroup is not None 
and isinstance(e, _ExceptionGroup):
+                    subs = getattr(e, "exceptions", []) or []
+                    if (
+                        _BrokenResourceError is not None
+                        and subs
+                        and all(isinstance(se, _BrokenResourceError) for se in subs)
+                    ):
+                        logger.debug("Ignored BrokenResourceError from stdio shutdown")
+                    else:
+                        raise
+                elif _BrokenResourceError is not None and isinstance(
+                    e, _BrokenResourceError
+                ):
+                    logger.debug("Ignored BrokenResourceError from stdio shutdown")
+                elif "BrokenResourceError" in str(e):
+                    logger.debug(
+                        "Ignored BrokenResourceError from stdio shutdown (string match)"
+                    )
+                else:
+                    raise
+        # Nudge cleanup of subprocess transports before the loop closes to avoid
+        # 'Event loop is closed' from BaseSubprocessTransport.__del__ on GC.
+        try:
+            await asyncio.sleep(0)
+        except Exception:
+            pass
+        try:
+            import gc
+
+            gc.collect()
+        except Exception:
+            pass
+
+
+def _tool_result_to_json(tool_result: CallToolResult):
+    if tool_result.content and len(tool_result.content) > 0:
+        text = tool_result.content[0].text
+        try:
+            # Try to parse the response as JSON if it's a string
+            import json
+
+            return json.loads(text)
+        except (json.JSONDecodeError, TypeError):
+            # If it's not valid JSON, return None
+            return None
+
+
+if __name__ == "__main__":
+    start = time.time()
+    asyncio.run(main())
+    end = time.time()
+    t = end - start
+
+    print(f"Total run time: {t:.2f}s")
diff --git a/examples/oauth/workflow_pre_auth/main.py b/examples/oauth/workflow_pre_auth/main.py
new file mode 100644
index 000000000..8966a43dc
--- /dev/null
+++ b/examples/oauth/workflow_pre_auth/main.py
@@ -0,0 +1,134 @@
+"""
+Workflow Pre-Authorization MCP Server Example
+
+This server exposes a `github_org_search` tool that calls the OAuth-protected
+GitHub MCP server. Access tokens are supplied ahead of time by the client via
+the `workflows-pre-auth` tool (see client.py) instead of being acquired
+through an interactive authorization flow.
+"""
+
+import asyncio
+import json
+from typing import Optional
+from pydantic import AnyHttpUrl
+
+from mcp.server.fastmcp import FastMCP
+from mcp_agent.core.context import Context as AppContext
+
+from mcp_agent.app import MCPApp
+from mcp_agent.server.app_server import create_mcp_server_for_app
+from mcp_agent.mcp.gen_client import gen_client
+from mcp_agent.config import (
+    MCPServerSettings,
+    Settings,
+    LoggerSettings,
+    MCPSettings,
+    MCPServerAuthSettings,
+    MCPOAuthClientSettings,
+)
+
+# Note: providing a FastMCP instance is purely optional; if omitted, MCPApp
+# creates a default server via create_mcp_server_for_app().
+mcp = FastMCP(name="basic_agent_server", instructions="My basic agent server example.")
+
+
+settings = Settings(
+    execution_engine="asyncio",
+    logger=LoggerSettings(level="info"),
+    mcp=MCPSettings(
+        servers={
+            "github": MCPServerSettings(
+                name="github",
+                transport="streamable_http",
+                url="https://api.githubcopilot.com/mcp/",
+                auth=MCPServerAuthSettings(
+                    oauth=MCPOAuthClientSettings(
+                        enabled=True,
+                        scopes=[
+                            "read:org",  # Required for search_orgs tool
+                            "public_repo",  # Access to public repositories
+                            "user:email",  # User information access
+                        ],
+                        authorization_server=AnyHttpUrl(
+                            "https://github.com/login/oauth"
+                        ),
+                        resource=AnyHttpUrl("https://api.githubcopilot.com/mcp"),
+                    )
+                ),
+            )
+        }
+    ),
+)
+
+# Define the MCPApp instance. The server created for this app will advertise the
+# MCP logging capability and forward structured logs upstream to connected clients.
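+# The OAuth settings above declare the "github" server as OAuth-protected. In
+# this example the access token is supplied at runtime by the client through
+# the `workflows-pre-auth` tool (see client.py); the token manager checks for
+# such pre-configured tokens before starting a delegated authorization flow.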
+app = MCPApp( + name="basic_agent_server", + description="Basic agent server example", + mcp=mcp, + settings=settings, +) + + +@app.tool(name="github_org_search") +async def github_org_search(query: str, app_ctx: Optional[AppContext] = None) -> str: + # Use the context's app if available for proper logging with upstream_session + _app = app_ctx.app if app_ctx else app + # Ensure the app's logger is bound to the current context with upstream_session + if _app._logger and hasattr(_app._logger, "_bound_context"): + _app._logger._bound_context = app_ctx + + try: + async with gen_client( + "github", server_registry=app_ctx.server_registry, context=app_ctx + ) as github_client: + result = await github_client.call_tool( + "search_orgs", + {"query": query, "perPage": 10, "sort": "best-match", "order": "desc"}, + ) + + organizations = [] + if result.content: + for content_item in result.content: + if hasattr(content_item, "text"): + try: + data = json.loads(content_item.text) + if isinstance(data, dict) and "items" in data: + organizations.extend(data["items"]) + elif isinstance(data, list): + organizations.extend(data) + except json.JSONDecodeError: + pass + + return str(organizations) + except Exception: + import traceback + + return f"Error: {traceback.format_exc()}" + + +async def main(): + async with app.run() as agent_app: + # Log registered workflows and agent configurations + agent_app.logger.info(f"Creating MCP server for {agent_app.name}") + + agent_app.logger.info("Registered workflows:") + for workflow_id in agent_app.workflows: + agent_app.logger.info(f" - {workflow_id}") + + # Create the MCP server that exposes both workflows and agent configurations, + # optionally using custom FastMCP settings + mcp_server = create_mcp_server_for_app(agent_app) + agent_app.logger.info(f"MCP Server settings: {mcp_server.settings}") + + # Run the server + # await mcp_server.run_stdio_async() + await mcp_server.run_sse_async() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/oauth/workflow_pre_auth/worker.py b/examples/oauth/workflow_pre_auth/worker.py new file mode 100644 index 000000000..39b2a3c67 --- /dev/null +++ b/examples/oauth/workflow_pre_auth/worker.py @@ -0,0 +1,31 @@ +""" +Worker script for the Temporal workflow example. +This script starts a Temporal worker that can execute workflows and activities. +Run this script in a separate terminal window before running the main.py script. + +This leverages the TemporalExecutor's start_worker method to handle the worker setup. +""" + +import asyncio +import logging + + +from mcp_agent.executor.temporal import create_temporal_worker_for_app + +from main import app + +# Initialize logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +async def main(): + """ + Start a Temporal worker for the example workflows using the app's executor. 
+ """ + async with create_temporal_worker_for_app(app) as worker: + await worker.run() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/src/mcp_agent/app.py b/src/mcp_agent/app.py index 2ada0483b..f6fe7946a 100644 --- a/src/mcp_agent/app.py +++ b/src/mcp_agent/app.py @@ -3,7 +3,7 @@ import sys import functools -from types import MethodType +from types import MethodType, FunctionType from typing import ( Any, Dict, @@ -40,6 +40,8 @@ from mcp_agent.tracing.telemetry import get_tracer from mcp_agent.utils.common import unwrap from mcp_agent.workflows.llm.llm_selector import ModelSelector +from mcp_agent.oauth.manager import TokenManager +from mcp_agent.oauth.store import InMemoryTokenStore from mcp_agent.workflows.factory import load_agent_specs_from_dir @@ -253,6 +255,29 @@ async def initialize(self): # Store a reference to this app instance in the context for easier access self._context.app = self + # Initialize OAuth token management helpers if configured + oauth_settings = None + try: + if self._context.config: + oauth_settings = self._context.config.oauth + except Exception: + oauth_settings = None + + if oauth_settings: + self.logger.debug("Initializing OAuth token management") + token_store = InMemoryTokenStore() + token_manager = TokenManager( + token_store=token_store, + settings=oauth_settings, + ) + self._context.token_store = token_store + self._context.token_manager = token_manager + + # Check for pre-configured tokens and store them with synthetic users + await self._initialize_preconfigured_tokens(token_manager) + else: + self.logger.debug("No OAuth settings found, skipping OAuth initialization") + # Provide a safe default bound context for loggers created after init without explicit context try: set_default_bound_context(self._context) @@ -330,6 +355,60 @@ async def initialize(self): }, ) + async def _initialize_preconfigured_tokens(self, token_manager): + """Check for pre-configured OAuth tokens and store them with a single synthetic user.""" + + mcp_config = getattr(self._context.config, "mcp", None) + if not mcp_config or not getattr(mcp_config, "servers", None): + self.logger.debug( + "No MCP servers found in config, skipping token initialization" + ) + return + + servers = mcp_config.servers + self.logger.debug(f"Found MCP servers in config: {list(servers.keys())}") + + servers_with_tokens = [] + + # First pass: check which servers have pre-configured tokens + for server_name, server_config in servers.items(): + if not hasattr(server_config, "auth") or not server_config.auth: + self.logger.debug( + f"Server '{server_name}' has no auth config, skipping" + ) + continue + + oauth_config = getattr(server_config.auth, "oauth", None) + + if ( + not oauth_config + or not oauth_config.enabled + or not oauth_config.access_token + ): + continue + + self.logger.debug(f"Server '{server_name}' has pre-configured OAuth token") + servers_with_tokens.append((server_name, server_config)) + + # If we have any servers with pre-configured tokens, create a single synthetic user + if servers_with_tokens: + synthetic_user = ( + token_manager.create_default_user_for_preconfigured_tokens() + ) + self._context.current_user = synthetic_user + self.logger.info( + f"Created synthetic user for pre-configured OAuth tokens: {synthetic_user.cache_key}" + ) + + # Second pass: store all tokens using the same synthetic user + for server_name, server_config in servers_with_tokens: + self.logger.info( + f"Storing pre-configured OAuth token for server: {server_name}" + ) + await 
token_manager.store_preconfigured_token( + server_name, server_config, synthetic_user + ) + async def get_token_node(self): """Return the root app token node, if available.""" if not self._context or not getattr(self._context, "token_counter", None): @@ -561,6 +640,25 @@ async def wrapper(*args, **kwargs): # Fall back to the original function return await fn(*args, **kwargs) + # Ensure the wrapper shares the original function's globals so that + # string annotations (from __future__ import annotations) continue to + # resolve against the workflow module rather than mcp_agent.app. + original_globals = getattr(fn, "__globals__", None) + if original_globals is not None and wrapper.__globals__ is not original_globals: + rebuilt_wrapper = FunctionType( + wrapper.__code__, + original_globals, + name=wrapper.__name__, + argdefs=wrapper.__defaults__, + closure=wrapper.__closure__, + ) + rebuilt_wrapper.__kwdefaults__ = wrapper.__kwdefaults__ + rebuilt_wrapper.__annotations__ = wrapper.__annotations__ + rebuilt_wrapper.__dict__.update(wrapper.__dict__) + rebuilt_wrapper = functools.update_wrapper(rebuilt_wrapper, fn) + rebuilt_wrapper.__wrapped__ = fn + wrapper = rebuilt_wrapper + return wrapper def _create_workflow_from_function( diff --git a/src/mcp_agent/config.py b/src/mcp_agent/config.py index 3d296bb42..72ef61118 100644 --- a/src/mcp_agent/config.py +++ b/src/mcp_agent/config.py @@ -13,6 +13,7 @@ from pydantic import ( AliasChoices, + AnyHttpUrl, BaseModel, ConfigDict, Field, @@ -29,6 +30,95 @@ class MCPServerAuthSettings(BaseModel): """Represents authentication configuration for a server.""" api_key: str | None = None + oauth: Optional["MCPOAuthClientSettings"] = None + + model_config = ConfigDict(extra="allow", arbitrary_types_allowed=True) + + +class MCPAuthorizationServerSettings(BaseModel): + """Configuration for exposing the MCP Agent server as an OAuth protected resource.""" + + enabled: bool = False + issuer_url: AnyHttpUrl | None = None + resource_server_url: AnyHttpUrl | None = None + service_documentation_url: AnyHttpUrl | None = None + required_scopes: List[str] = Field(default_factory=list) + jwks_uri: AnyHttpUrl | None = None + introspection_endpoint: AnyHttpUrl | None = None + introspection_client_id: str | None = None + introspection_client_secret: str | None = None + token_cache_ttl_seconds: int = Field(300, ge=0) + + # RFC 9068 audience validation settings + # TODO: this should really depend on the app_id, or config_id so that we can enforce unique values. + # To be removed and replaced with a fixed value once we have app_id/config_id support + expected_audiences: List[str] = Field(default_factory=list) + """List of audience values this resource server accepts. + MUST be configured to comply with RFC 9068 audience validation. 
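+    Typically this should list the resource server URL(s) that tokens are issued for.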
+ Audience validation is always enforced when authorization is enabled.""" + + model_config = ConfigDict(extra="allow", arbitrary_types_allowed=True) + + @model_validator(mode="after") + def _validate_required_urls(self) -> "MCPAuthorizationServerSettings": + if self.enabled: + missing = [] + if self.issuer_url is None: + missing.append("issuer_url") + if self.resource_server_url is None: + missing.append("resource_server_url") + # Validate audience configuration for RFC 9068 compliance + if not self.expected_audiences: + missing.append("expected_audiences (required for RFC 9068 compliance)") + if missing: + raise ValueError( + " | ".join(missing) + " must be set when authorization is enabled" + ) + return self + + +class MCPOAuthClientSettings(BaseModel): + """Configuration for authenticating to downstream OAuth-protected MCP servers.""" + + enabled: bool = False + scopes: List[str] = Field(default_factory=list) + resource: AnyHttpUrl | None = None + authorization_server: AnyHttpUrl | None = None + client_id: str | None = None + client_secret: str | None = None + # Support for pre-configured access tokens (bypasses OAuth flow) + access_token: str | None = None + refresh_token: str | None = None + expires_at: float | None = None + token_type: str = "Bearer" + redirect_uri_options: List[str] = Field(default_factory=list) + extra_authorize_params: Dict[str, str] = Field(default_factory=dict) + extra_token_params: Dict[str, str] = Field(default_factory=dict) + require_pkce: bool = True + use_internal_callback: bool = True + + model_config = ConfigDict(extra="allow", arbitrary_types_allowed=True) + + +class OAuthTokenStoreSettings(BaseModel): + """Settings for OAuth token persistence.""" + + backend: Literal["memory", "redis"] = "memory" + redis_url: str | None = None + redis_prefix: str = "mcp_agent:oauth_tokens" + refresh_leeway_seconds: int = Field(60, ge=0) + + model_config = ConfigDict(extra="allow", arbitrary_types_allowed=True) + + +class OAuthSettings(BaseModel): + """Global OAuth-related settings for MCP Agent.""" + + token_store: OAuthTokenStoreSettings = Field( + default_factory=OAuthTokenStoreSettings + ) + flow_timeout_seconds: int = Field(300, ge=30) + callback_base_url: AnyHttpUrl | None = None model_config = ConfigDict(extra="allow", arbitrary_types_allowed=True) @@ -860,6 +950,12 @@ class Settings(BaseSettings): agents: SubagentSettings | None = SubagentSettings() """Settings for defining and loading subagents for the MCP Agent application""" + authorization: MCPAuthorizationServerSettings | None = None + """Settings for exposing this MCP application as an OAuth protected resource""" + + oauth: OAuthSettings | None = Field(default_factory=OAuthSettings) + """Global OAuth client configuration (token store, delegated auth defaults)""" + def __eq__(self, other): # type: ignore[override] if not isinstance(other, Settings): return NotImplemented diff --git a/src/mcp_agent/core/context.py b/src/mcp_agent/core/context.py index d449c938a..cc7bee78e 100644 --- a/src/mcp_agent/core/context.py +++ b/src/mcp_agent/core/context.py @@ -33,6 +33,7 @@ from mcp_agent.workflows.llm.llm_selector import ModelSelector from mcp_agent.logging.logger import get_logger from mcp_agent.tracing.token_counter import TokenCounter +from mcp_agent.oauth.identity import OAuthUserIdentity if TYPE_CHECKING: @@ -42,6 +43,8 @@ from mcp_agent.executor.workflow_signal import SignalWaitCallback from mcp_agent.executor.workflow_registry import WorkflowRegistry from mcp_agent.app import MCPApp + from 
mcp_agent.oauth.manager import TokenManager + from mcp_agent.oauth.store import TokenStore else: # Runtime placeholders for the types AgentSpec = Any @@ -50,6 +53,8 @@ SignalWaitCallback = Any WorkflowRegistry = Any MCPApp = Any + TokenManager = Any + TokenStore = Any logger = get_logger(__name__) @@ -93,6 +98,13 @@ class Context(BaseModel): gateway_url: str | None = None gateway_token: str | None = None + # Current authenticated user (set when acting as an MCP server) + current_user: Optional[OAuthUserIdentity] = None + + # OAuth helpers for downstream servers + token_store: Optional[TokenStore] = None + token_manager: Optional[TokenManager] = None + model_config = ConfigDict( extra="allow", arbitrary_types_allowed=True, # Tell Pydantic to defer type evaluation @@ -256,6 +268,14 @@ async def cleanup_context(shutdown_logger: bool = False): shutdown_logger: If True, completely shutdown OTEL infrastructure. If False, just cleanup app-specific resources. """ + global _global_context + + if _global_context and getattr(_global_context, "token_manager", None): + try: + await _global_context.token_manager.aclose() # type: ignore[call-arg] + except Exception: + pass + if shutdown_logger: # Shutdown logging and telemetry completely await LoggingConfig.shutdown() diff --git a/src/mcp_agent/mcp/mcp_server_registry.py b/src/mcp_agent/mcp/mcp_server_registry.py index a28735373..04a8cad08 100644 --- a/src/mcp_agent/mcp/mcp_server_registry.py +++ b/src/mcp_agent/mcp/mcp_server_registry.py @@ -31,6 +31,7 @@ from mcp_agent.logging.logger import get_logger from mcp_agent.mcp.mcp_agent_client_session import MCPAgentClientSession +from mcp_agent.oauth.http import OAuthHttpxAuth from mcp_agent.mcp.mcp_connection_manager import MCPConnectionManager if TYPE_CHECKING: @@ -215,6 +216,24 @@ async def start_server( kwargs["sse_read_timeout"] = sse_read_timeout # For Streamable HTTP, we get an additional callback for session ID + auth_handler = None + oauth_cfg = config.auth.oauth if config.auth else None + if oauth_cfg and oauth_cfg.enabled: + if context is None or getattr(context, "token_manager", None) is None: + logger.warning( + f"{server_name}: OAuth configured but token manager not available; skipping auth" + ) + else: + auth_handler = OAuthHttpxAuth( + token_manager=context.token_manager, + context=context, + server_name=server_name, + server_config=config, + scopes=oauth_cfg.scopes, + ) + if auth_handler: + kwargs["auth"] = auth_handler + async with streamablehttp_client( **kwargs, ) as (read_stream, write_stream, session_id_callback): diff --git a/src/mcp_agent/oauth/__init__.py b/src/mcp_agent/oauth/__init__.py new file mode 100644 index 000000000..7612bd080 --- /dev/null +++ b/src/mcp_agent/oauth/__init__.py @@ -0,0 +1,19 @@ +"""OAuth support utilities for MCP Agent. + +Modules export their own public APIs; this package file avoids importing them +eagerly to sidestep circular dependencies during initialization. 
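+Import submodules directly (e.g. "from mcp_agent.oauth.manager import TokenManager").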
+""" + +__all__ = [ + "access_token", + "callbacks", + "errors", + "flow", + "http", + "identity", + "manager", + "metadata", + "pkce", + "records", + "store", +] diff --git a/src/mcp_agent/oauth/access_token.py b/src/mcp_agent/oauth/access_token.py new file mode 100644 index 000000000..0ec75a190 --- /dev/null +++ b/src/mcp_agent/oauth/access_token.py @@ -0,0 +1,113 @@ +"""Extended access token model for MCP Agent authorization flows.""" + +from __future__ import annotations + +from datetime import datetime, timezone +from typing import Any, Dict, Iterable, List + +from mcp.server.auth.provider import AccessToken + + +class MCPAccessToken(AccessToken): + """Access token enriched with identity and claim metadata.""" + + subject: str | None = None + email: str | None = None + issuer: str | None = None + resource_indicator: str | None = None + claims: Dict[str, Any] | None = None + audiences: List[str] | None = None + + @classmethod + def from_introspection( + cls, + token: str, + payload: Dict[str, Any], + *, + resource_hint: str | None = None, + ) -> "MCPAccessToken": + """Build an access token instance from an OAuth 2.0 introspection response.""" + client_id = _first_non_empty( + payload.get("client_id"), + payload.get("clientId"), + payload.get("cid"), + ) + scope_value = payload.get("scope") or payload.get("scp") + if isinstance(scope_value, str): + scopes: List[str] = [s for s in scope_value.split() if s] + elif isinstance(scope_value, Iterable): + scopes = [str(item) for item in scope_value] + else: + scopes = [] + + # Enhanced audience extraction for RFC 9068 compliance + audiences = _extract_all_audiences(payload) + audience_value = audiences[0] if audiences else None + resource = resource_hint or audience_value + + expires_at = payload.get("exp") + + return cls( + token=token, + client_id=str(client_id) if client_id is not None else "", + scopes=scopes, + expires_at=expires_at, + resource=resource, + subject=_first_non_empty(payload.get("sub"), payload.get("subject")), + email=_first_non_empty( + payload.get("email"), payload.get("preferred_username") + ), + issuer=payload.get("iss"), + resource_indicator=resource, + audiences=audiences, + claims=payload, + ) + + def is_expired(self, *, leeway_seconds: int = 0) -> bool: + """Return True if token is expired considering optional leeway.""" + if self.expires_at is None: + return False + now = datetime.now(tz=timezone.utc).timestamp() + return now >= (self.expires_at - leeway_seconds) + + def validate_audience(self, expected_audiences: List[str]) -> bool: + """Validate this token's audience claims against expected values per RFC 9068.""" + if not self.audiences: + return False + if not expected_audiences: + return False + + return bool(set(expected_audiences).intersection(set(self.audiences))) + + +def _extract_all_audiences(payload: Dict[str, Any]) -> List[str]: + """Extract all audience values from token payload per RFC 9068.""" + audiences = [] + + # Extract from 'aud' claim + aud_claim = payload.get("aud") + if aud_claim: + if isinstance(aud_claim, str): + audiences.append(aud_claim) + elif isinstance(aud_claim, (list, tuple)): + audiences.extend([str(aud) for aud in aud_claim if aud]) + + # Extract from 'resource' claim (OAuth 2.0 resource indicators) + resource_claim = payload.get("resource") + if resource_claim: + if isinstance(resource_claim, str): + audiences.append(resource_claim) + elif isinstance(resource_claim, (list, tuple)): + audiences.extend([str(res) for res in resource_claim if res]) + + return list(set(audiences)) # 
Remove duplicates + + +def _first_non_empty(*values: Any) -> Any | None: + for value in values: + if value is None: + continue + if isinstance(value, str) and not value: + continue + return value + return None diff --git a/src/mcp_agent/oauth/callbacks.py b/src/mcp_agent/oauth/callbacks.py new file mode 100644 index 000000000..8bfa99957 --- /dev/null +++ b/src/mcp_agent/oauth/callbacks.py @@ -0,0 +1,54 @@ +"""Callback coordination for delegated OAuth flows.""" + +from __future__ import annotations + +import asyncio +from typing import Any, Dict + + +class OAuthCallbackRegistry: + """Manage asynchronous delivery of OAuth authorization callbacks.""" + + def __init__(self) -> None: + self._pending: Dict[str, asyncio.Future[Dict[str, Any]]] = {} + self._lock = asyncio.Lock() + + async def create_handle(self, flow_id: str) -> asyncio.Future[Dict[str, Any]]: + """Create (or reuse) a future associated with a flow identifier.""" + async with self._lock: + future = self._pending.get(flow_id) + if future is None or future.done(): + loop = asyncio.get_running_loop() + future = loop.create_future() + self._pending[flow_id] = future + return future + + async def deliver(self, flow_id: str, payload: Dict[str, Any]) -> bool: + """Set the result for a pending flow, returning False when no listener exists.""" + async with self._lock: + future = self._pending.get(flow_id) + if future is None: + # print all entries in _pending for debugging + return False + if not future.done(): + future.set_result(payload) + return True + + async def fail(self, flow_id: str, exc: Exception) -> bool: + async with self._lock: + future = self._pending.get(flow_id) + if future is None: + return False + if not future.done(): + future.set_exception(exc) + return True + + async def discard(self, flow_id: str) -> None: + async with self._lock: + future = self._pending.pop(flow_id, None) + if future and not future.done(): + future.cancel() + + +# Global registry used by server + flow coordinator +callback_registry = OAuthCallbackRegistry() diff --git a/src/mcp_agent/oauth/errors.py b/src/mcp_agent/oauth/errors.py new file mode 100644 index 000000000..3b1b5ce7a --- /dev/null +++ b/src/mcp_agent/oauth/errors.py @@ -0,0 +1,21 @@ +"""Custom exception types for OAuth workflows.""" + + +class OAuthFlowError(Exception): + """Base class for OAuth-related failures.""" + + +class AuthorizationDeclined(OAuthFlowError): + """Raised when the user declines an authorization request.""" + + +class CallbackTimeoutError(OAuthFlowError): + """Raised when the delegated authorization callback is not received in time.""" + + +class TokenRefreshError(OAuthFlowError): + """Raised when refreshing an access token fails irrecoverably.""" + + +class MissingUserIdentityError(OAuthFlowError): + """Raised when an OAuth flow is attempted without a known user identity.""" diff --git a/src/mcp_agent/oauth/flow.py b/src/mcp_agent/oauth/flow.py new file mode 100644 index 000000000..021b9c6a7 --- /dev/null +++ b/src/mcp_agent/oauth/flow.py @@ -0,0 +1,238 @@ +"""Delegated OAuth authorization flow coordinator.""" + +from __future__ import annotations + +import asyncio +import time +import uuid +from json import JSONDecodeError +from typing import Any, Dict, Sequence +from urllib.parse import parse_qs, urlparse + +import httpx +from mcp.shared.auth import OAuthMetadata, ProtectedResourceMetadata +from mcp.server.session import ServerSession + +from mcp_agent.config import MCPOAuthClientSettings, OAuthSettings +from mcp_agent.core.context import Context +from 
mcp_agent.logging.logger import get_logger +from mcp_agent.oauth.callbacks import callback_registry +from mcp_agent.oauth.errors import ( + AuthorizationDeclined, + CallbackTimeoutError, + MissingUserIdentityError, + OAuthFlowError, +) +from mcp_agent.oauth.identity import OAuthUserIdentity +from mcp_agent.oauth.pkce import ( + generate_code_challenge, + generate_code_verifier, + generate_state, +) +from mcp_agent.oauth.records import TokenRecord + +logger = get_logger(__name__) + + +class AuthorizationFlowCoordinator: + """Handles the interactive OAuth Authorization Code flow via MCP clients.""" + + def __init__(self, *, http_client: httpx.AsyncClient, settings: OAuthSettings): + self._http_client = http_client + self._settings = settings + + async def authorize( + self, + *, + context: Context, + user: OAuthUserIdentity, + server_name: str, + oauth_config: MCPOAuthClientSettings, + resource: str, + authorization_server_url: str, + resource_metadata: ProtectedResourceMetadata, + auth_metadata: OAuthMetadata, + scopes: Sequence[str], + ) -> TokenRecord: + if not user: + raise MissingUserIdentityError( + "Cannot begin OAuth flow without authenticated MCP user" + ) + + client_id = oauth_config.client_id + if not client_id: + raise OAuthFlowError( + f"No OAuth client_id configured for server '{server_name}'." + ) + + redirect_options = list(oauth_config.redirect_uri_options or []) + flow_id = uuid.uuid4().hex + internal_redirect = None + if oauth_config.use_internal_callback and self._settings.callback_base_url: + internal_redirect = f"{str(self._settings.callback_base_url).rstrip('/')}/internal/oauth/callback/{flow_id}" + redirect_options.insert(0, internal_redirect) + + if not redirect_options: + raise OAuthFlowError( + "No redirect URI options configured for OAuth authorization flow" + ) + + redirect_uri = redirect_options[0] + + code_verifier = generate_code_verifier() + code_challenge = generate_code_challenge(code_verifier) + state = generate_state() + scope_param = " ".join(scopes) + + params = { + "response_type": "code", + "client_id": client_id, + "redirect_uri": redirect_uri, + "scope": scope_param, + "state": state, + "code_challenge": code_challenge, + "code_challenge_method": "S256", + "resource": resource, + } + + # add extra params if any + if oauth_config.extra_authorize_params: + params.update(oauth_config.extra_authorize_params) + + import urllib.parse + + authorize_url = httpx.URL( + str(auth_metadata.authorization_endpoint) + + "?" 
+ + urllib.parse.urlencode(params) + ) + + callback_future = None + if internal_redirect is not None: + callback_future = await callback_registry.create_handle(flow_id) + + request_payload = { + "url": str(authorize_url), + "message": f"Authorization required for {server_name}", + "redirect_uri_options": redirect_options, + "flow_id": flow_id, + } + + result = await _send_auth_request(context, request_payload) + + try: + if result and result.get("url"): + callback_data = _parse_callback_params(result["url"]) + if callback_future is not None: + await callback_registry.discard(flow_id) + elif result and result.get("code"): + callback_data = result + if callback_future is not None: + await callback_registry.discard(flow_id) + elif callback_future is not None: + timeout = self._settings.flow_timeout_seconds or 300 + try: + callback_data = await asyncio.wait_for( + callback_future, timeout=timeout + ) + except asyncio.TimeoutError as exc: + raise CallbackTimeoutError( + f"Timed out waiting for OAuth callback after {timeout} seconds" + ) from exc + else: + raise AuthorizationDeclined( + "Authorization request was declined by the user" + ) + finally: + if callback_future is not None: + await callback_registry.discard(flow_id) + + error = callback_data.get("error") + if error: + description = callback_data.get("error_description") or error + raise OAuthFlowError(f"Authorization server returned error: {description}") + + returned_state = callback_data.get("state") + if returned_state != state: + raise OAuthFlowError("State mismatch detected in OAuth callback") + + authorization_code = callback_data.get("code") + if not authorization_code: + raise OAuthFlowError("Authorization callback did not include code") + + token_endpoint = str(auth_metadata.token_endpoint) + data: Dict[str, Any] = { + "grant_type": "authorization_code", + "code": authorization_code, + "redirect_uri": redirect_uri, + "client_id": client_id, + "code_verifier": code_verifier, + "resource": resource, + } + if scope_param: + data["scope"] = scope_param + if oauth_config.extra_token_params: + data.update(oauth_config.extra_token_params) + + auth = None + if oauth_config.client_secret: + data["client_secret"] = oauth_config.client_secret + + token_response = await self._http_client.post( + token_endpoint, data=data, auth=auth, headers={"Accept": "application/json"} + ) + token_response.raise_for_status() + + try: + callback_data = token_response.json() + except JSONDecodeError: + callback_data = _parse_callback_params("?" 
+ token_response.text) + + access_token = callback_data.get("access_token") + if not access_token: + raise OAuthFlowError("Token endpoint response missing access_token") + refresh_token = callback_data.get("refresh_token") + expires_in = callback_data.get("expires_in") + expires_at = None + if isinstance(expires_in, (int, float)): + expires_at = time.time() + float(expires_in) + + scope_from_payload = callback_data.get("scope") + if isinstance(scope_from_payload, str) and scope_from_payload.strip(): + effective_scopes = tuple(scope_from_payload.split()) + else: + effective_scopes = tuple(scopes) + + return TokenRecord( + access_token=access_token, + refresh_token=refresh_token, + expires_at=expires_at, + scopes=effective_scopes, + token_type=str(callback_data.get("token_type", "Bearer")), + resource=resource, + authorization_server=authorization_server_url, + metadata={"raw": token_response.text}, + ) + + +def _parse_callback_params(url: str) -> Dict[str, str]: + parsed = urlparse(url) + params = {} + params.update({k: v[-1] for k, v in parse_qs(parsed.query).items()}) + if parsed.fragment: + params.update({k: v[-1] for k, v in parse_qs(parsed.fragment).items()}) + return params + + +async def _send_auth_request( + context: Context, payload: Dict[str, Any] +) -> Dict[str, Any]: + session = getattr(context, "upstream_session", None) + + if session and isinstance(session, ServerSession): + rpc = getattr(session, "rpc", None) + if rpc and hasattr(rpc, "request"): + return await rpc.request("auth/request", payload) + raise AuthorizationDeclined( + "No upstream MCP session available to prompt user for authorization" + ) diff --git a/src/mcp_agent/oauth/http/__init__.py b/src/mcp_agent/oauth/http/__init__.py new file mode 100644 index 000000000..60479ab9e --- /dev/null +++ b/src/mcp_agent/oauth/http/__init__.py @@ -0,0 +1,5 @@ +"""HTTP client helpers for OAuth flows.""" + +from .auth import OAuthHttpxAuth + +__all__ = ["OAuthHttpxAuth"] diff --git a/src/mcp_agent/oauth/http/auth.py b/src/mcp_agent/oauth/http/auth.py new file mode 100644 index 000000000..a061e02ff --- /dev/null +++ b/src/mcp_agent/oauth/http/auth.py @@ -0,0 +1,78 @@ +"""httpx.Auth adapter that acquires tokens via TokenManager.""" + +from __future__ import annotations + +import httpx + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from mcp_agent.oauth.manager import TokenManager + from mcp_agent.core.context import Context + + +class OAuthHttpxAuth(httpx.Auth): + requires_request_body = True + + def __init__( + self, + *, + token_manager: "TokenManager", + context: "Context", + server_name: str, + server_config, + scopes=None, + ) -> None: + self._token_manager = token_manager + self._context = context + self._server_name = server_name + self._server_config = server_config + self._scopes = list(scopes) if scopes is not None else None + + async def async_auth_flow(self, request: httpx.Request): + try: + token_record = await self._token_manager.ensure_access_token( + context=self._context, + server_name=self._server_name, + server_config=self._server_config, + scopes=self._scopes, + ) + except Exception: + raise + request.headers["Authorization"] = ( + f"{token_record.token_type} {token_record.access_token}" + ) + response = yield request + + if response.status_code != 401: + return + + user = self._context.current_user + if user is None: + return + + await self._token_manager.invalidate( + user=user, + resource=token_record.resource or "", + authorization_server=token_record.authorization_server, + 
scopes=token_record.scopes, + ) + + refreshed_record = await self._token_manager.ensure_access_token( + context=self._context, + server_name=self._server_name, + server_config=self._server_config, + scopes=self._scopes, + ) + + # Create a new request with the refreshed token + retry_request = httpx.Request( + method=request.method, + url=request.url, + headers=request.headers.copy(), + content=request.content, + ) + retry_request.headers["Authorization"] = ( + f"{refreshed_record.token_type} {refreshed_record.access_token}" + ) + yield retry_request diff --git a/src/mcp_agent/oauth/identity.py b/src/mcp_agent/oauth/identity.py new file mode 100644 index 000000000..54c6d6fed --- /dev/null +++ b/src/mcp_agent/oauth/identity.py @@ -0,0 +1,46 @@ +"""Utilities for representing authenticated MCP users.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Dict + +from .access_token import MCPAccessToken + + +@dataclass(frozen=True) +class OAuthUserIdentity: + """Canonical identifier for an authenticated user within MCP Agent.""" + + provider: str + subject: str + email: str | None = None + claims: Dict[str, Any] | None = None + + @property + def cache_key(self) -> str: + """Return a deterministic cache key for token storage.""" + return f"{self.provider}:{self.subject}" + + @classmethod + def from_access_token( + cls, token: MCPAccessToken | None + ) -> "OAuthUserIdentity" | None: + """Build an identity from an enriched access token.""" + if token is None: + return None + subject = token.subject or _claim(token, "sub") + if not subject: + return None + provider = token.issuer or _claim(token, "iss") or "unknown" + email = ( + token.email or _claim(token, "email") or _claim(token, "preferred_username") + ) + claims = token.claims or {} + return cls(provider=provider, subject=subject, email=email, claims=claims) + + +def _claim(token: MCPAccessToken, key: str) -> Any | None: + if not token.claims: + return None + return token.claims.get(key) diff --git a/src/mcp_agent/oauth/manager.py b/src/mcp_agent/oauth/manager.py new file mode 100644 index 000000000..5a5a14396 --- /dev/null +++ b/src/mcp_agent/oauth/manager.py @@ -0,0 +1,385 @@ +"""Token management for downstream OAuth-protected MCP servers.""" + +from __future__ import annotations + +import asyncio +import time +from collections import defaultdict +from typing import Dict, Iterable, Sequence, TYPE_CHECKING + +import httpx +from httpx import URL + +from mcp_agent.config import MCPOAuthClientSettings, OAuthSettings +from mcp_agent.logging.logger import get_logger +from mcp_agent.oauth.errors import ( + OAuthFlowError, + TokenRefreshError, +) +from mcp_agent.oauth.flow import AuthorizationFlowCoordinator +from mcp_agent.oauth.identity import OAuthUserIdentity +from mcp_agent.oauth.metadata import ( + fetch_authorization_server_metadata, + fetch_resource_metadata, + normalize_resource, + select_authorization_server, +) +from mcp_agent.oauth.records import TokenRecord +from mcp_agent.oauth.store import ( + InMemoryTokenStore, + TokenStore, + TokenStoreKey, + scope_fingerprint, +) + +if TYPE_CHECKING: + from mcp_agent.core.context import Context + +logger = get_logger(__name__) + + +def create_default_user_for_preconfigured_tokens( + session_id: str | None = None, +) -> "OAuthUserIdentity": + """Create a synthetic user identity for pre-configured tokens.""" + from mcp_agent.oauth.identity import OAuthUserIdentity + + return OAuthUserIdentity( + provider="mcp-agent", + 
subject=f"preconfigured-tokens-{session_id}" + if session_id + else "preconfigured-tokens", + claims={ + "token_source": "preconfigured", + "description": "Synthetic user for pre-configured OAuth tokens", + }, + ) + + +class TokenManager: + """High-level orchestrator for acquiring and refreshing OAuth tokens.""" + + def __init__( + self, + *, + http_client: httpx.AsyncClient | None = None, + token_store: TokenStore | None = None, + settings: OAuthSettings | None = None, + ) -> None: + self._settings = settings or OAuthSettings() + self._token_store = token_store or InMemoryTokenStore() + self._http_client = http_client or httpx.AsyncClient(timeout=30.0) + self._own_http_client = http_client is None + self._flow = AuthorizationFlowCoordinator( + http_client=self._http_client, settings=self._settings + ) + self._locks: Dict[TokenStoreKey, asyncio.Lock] = defaultdict(asyncio.Lock) + self._resource_metadata_cache: Dict[str, tuple[float, object]] = {} + self._auth_metadata_cache: Dict[str, tuple[float, object]] = {} + + async def store_preconfigured_token( + self, server_name: str, server_config, synthetic_user: "OAuthUserIdentity" + ) -> None: + """Store a pre-configured token in the token store.""" + oauth_config = server_config.auth.oauth + + # Create token record + resource_str = ( + str(oauth_config.resource) + if oauth_config.resource + else getattr(server_config, "url", None) + ) + auth_server_str = ( + str(oauth_config.authorization_server) + if oauth_config.authorization_server + else None + ) + + from datetime import datetime, timezone + + record = TokenRecord( + access_token=oauth_config.access_token, + refresh_token=oauth_config.refresh_token, + scopes=tuple(oauth_config.scopes or []), + expires_at=oauth_config.expires_at, + token_type=oauth_config.token_type, + resource=resource_str, + authorization_server=auth_server_str, + obtained_at=datetime.now(tz=timezone.utc).timestamp(), + metadata={"server_name": server_name, "pre_configured": True}, + ) + + # Create storage key + key = TokenStoreKey( + user_key=synthetic_user.cache_key, + resource=resource_str or "", + authorization_server=auth_server_str, + scope_fingerprint=scope_fingerprint(oauth_config.scopes or []), + ) + + # Store the token + logger.debug( + f"Storing token with key: user_key={key.user_key}, resource={key.resource}, auth_server={key.authorization_server}, scope_fingerprint={key.scope_fingerprint}" + ) + await self._token_store.set(key, record) + + async def ensure_access_token( + self, + *, + context: "Context", + server_name: str, + server_config, + scopes: Iterable[str] | None = None, + ) -> TokenRecord: + oauth_config: MCPOAuthClientSettings | None = None + if server_config and server_config.auth: + oauth_config = getattr(server_config.auth, "oauth", None) + if not oauth_config or not oauth_config.enabled: + raise OAuthFlowError( + f"Server '{server_name}' is not configured for OAuth authentication" + ) + + user = context.current_user + + # Use the same key construction logic as store_preconfigured_token to ensure consistency + resource_str = ( + str(oauth_config.resource) + if oauth_config.resource + else getattr(server_config, "url", None) + ) + auth_server_str = ( + str(oauth_config.authorization_server) + if oauth_config.authorization_server + else None + ) + scope_list = ( + list(scopes) if scopes is not None else list(oauth_config.scopes or []) + ) + + # check for a globally configure token + key = TokenStoreKey( + user_key=create_default_user_for_preconfigured_tokens().cache_key, + resource=resource_str, + 
authorization_server=auth_server_str, + scope_fingerprint=scope_fingerprint(scope_list), + ) + + lock = self._locks[key] + + async with lock: + record = await self._token_store.get(key) + if record: + return record + + # there is no global token, look for a user specific one + key = TokenStoreKey( + user_key=user.cache_key, + resource=resource_str, + authorization_server=auth_server_str, + scope_fingerprint=scope_fingerprint(scope_list), + ) + + lock = self._locks[key] + async with lock: + record = await self._token_store.get(key) + leeway = ( + self._settings.token_store.refresh_leeway_seconds + if self._settings and self._settings.token_store + else 60 + ) + if record and not record.is_expired(leeway_seconds=leeway): + return record + + # If token exists but expired, try to refresh it + if record and record.refresh_token: + # For refresh, we need OAuth metadata + resource_hint = ( + str(oauth_config.resource) + if oauth_config.resource + else getattr(server_config, "url", None) + ) + server_url = getattr(server_config, "url", None) + resource = normalize_resource(resource_hint, server_url) + + # Get OAuth metadata for token refresh + parsed_resource = URL(resource) + metadata_url = str( + parsed_resource.copy_with( + path="/.well-known/oauth-protected-resource" + + parsed_resource.path + ) + ) + resource_metadata = await self._get_resource_metadata(metadata_url) + auth_server_url = select_authorization_server( + resource_metadata, str(oauth_config.authorization_server) + ) + auth_metadata = await self._get_authorization_metadata(auth_server_url) + + try: + refreshed = await self._refresh_token( + record, + oauth_config=oauth_config, + auth_metadata=auth_metadata, + resource=resource, + scopes=scope_list, + ) + except TokenRefreshError: + await self._token_store.delete(key) + else: + if refreshed: + await self._token_store.set(key, refreshed) + return refreshed + await self._token_store.delete(key) + + # Need to run full authorization flow - only if no token found or refresh failed + if not record: + resource_hint = ( + str(oauth_config.resource) + if oauth_config.resource + else getattr(server_config, "url", None) + ) + server_url = getattr(server_config, "url", None) + resource = normalize_resource(resource_hint, server_url) + + # Get OAuth metadata for full authorization flow + parsed_resource = URL(resource) + metadata_url = str( + parsed_resource.copy_with( + path="/.well-known/oauth-protected-resource" + + parsed_resource.path + ) + ) + resource_metadata = await self._get_resource_metadata(metadata_url) + auth_server_url = select_authorization_server( + resource_metadata, str(oauth_config.authorization_server) + ) + auth_metadata = await self._get_authorization_metadata(auth_server_url) + + record = await self._flow.authorize( + context=context, + user=user, + server_name=server_name, + oauth_config=oauth_config, + resource=resource, + authorization_server_url=auth_server_url, + resource_metadata=resource_metadata, + auth_metadata=auth_metadata, + scopes=scope_list, + ) + await self._token_store.set(key, record) + return record + + # If we reach here, we have an expired token with no refresh token + # Return it anyway - the caller will handle 401s + return record + + async def invalidate( + self, + *, + user: OAuthUserIdentity, + resource: str, + authorization_server: str | None, + scopes: Iterable[str], + ) -> None: + key = TokenStoreKey( + user_key=user.cache_key, + resource=resource, + authorization_server=authorization_server, + scope_fingerprint=scope_fingerprint(scopes), + ) + 
await self._token_store.delete(key) + + async def _refresh_token( + self, + record: TokenRecord, + *, + oauth_config: MCPOAuthClientSettings, + auth_metadata, + resource: str, + scopes: Sequence[str], + ) -> TokenRecord | None: + if not record.refresh_token: + return None + + token_endpoint = str(auth_metadata.token_endpoint) + data = { + "grant_type": "refresh_token", + "refresh_token": record.refresh_token, + "client_id": oauth_config.client_id, + "resource": resource, + } + if scopes: + data["scope"] = " ".join(scopes) + if oauth_config.client_secret: + data["client_secret"] = oauth_config.client_secret + if oauth_config.extra_token_params: + data.update(oauth_config.extra_token_params) + + try: + response = await self._http_client.post(token_endpoint, data=data) + except httpx.HTTPError as exc: + logger.warning("Refresh token request failed", exc_info=True) + raise TokenRefreshError(str(exc)) from exc + + if response.status_code != 200: + logger.warning( + "Refresh token request returned non-success status", + data={"status_code": response.status_code}, + ) + return None + + payload = response.json() + new_access = payload.get("access_token") + if not new_access: + return None + new_refresh = payload.get("refresh_token", record.refresh_token) + expires_in = payload.get("expires_in") + new_expires = record.expires_at + if isinstance(expires_in, (int, float)): + new_expires = time.time() + float(expires_in) + + scope_from_payload = payload.get("scope") + if isinstance(scope_from_payload, str) and scope_from_payload.strip(): + scopes_tuple = tuple(scope_from_payload.split()) + else: + scopes_tuple = record.scopes + + return TokenRecord( + access_token=new_access, + refresh_token=new_refresh, + expires_at=new_expires, + scopes=scopes_tuple, + token_type=str(payload.get("token_type", record.token_type)), + resource=record.resource, + authorization_server=record.authorization_server, + metadata={"raw": payload}, + ) + + async def _get_resource_metadata(self, url: str): + cached = self._resource_metadata_cache.get(url) + if cached and time.time() - cached[0] < 300: + return cached[1] + metadata = await fetch_resource_metadata(self._http_client, url) + self._resource_metadata_cache[url] = (time.time(), metadata) + return metadata + + async def _get_authorization_metadata(self, url: str): + cached = self._auth_metadata_cache.get(url) + if cached and time.time() - cached[0] < 300: + return cached[1] + # Construct OAuth authorization server metadata URL + parsed_url = URL(url) + metadata_url = str( + parsed_url.copy_with( + path="/.well-known/oauth-authorization-server" + parsed_url.path + ) + ) + metadata = await fetch_authorization_server_metadata( + self._http_client, metadata_url + ) + self._auth_metadata_cache[url] = (time.time(), metadata) + return metadata + + async def aclose(self) -> None: + if self._own_http_client: + await self._http_client.aclose() diff --git a/src/mcp_agent/oauth/metadata.py b/src/mcp_agent/oauth/metadata.py new file mode 100644 index 000000000..c4986b011 --- /dev/null +++ b/src/mcp_agent/oauth/metadata.py @@ -0,0 +1,60 @@ +"""Helpers for OAuth metadata discovery.""" + +from __future__ import annotations + +from typing import List + +import httpx +from mcp.shared.auth import OAuthMetadata, ProtectedResourceMetadata + +from mcp_agent.logging.logger import get_logger + +logger = get_logger(__name__) + + +async def fetch_resource_metadata( + client: httpx.AsyncClient, + resource_metadata_url: str, +) -> ProtectedResourceMetadata: + response = await 
client.get(resource_metadata_url) + response.raise_for_status() + data = response.json() + return ProtectedResourceMetadata.model_validate(data) + + +async def fetch_authorization_server_metadata( + client: httpx.AsyncClient, + metadata_url: str, +) -> OAuthMetadata: + response = await client.get(metadata_url) + response.raise_for_status() + return OAuthMetadata.model_validate(response.json()) + + +def select_authorization_server( + metadata: ProtectedResourceMetadata, + preferred: str | None = None, +) -> str: + candidates: List[str] = [str(url) for url in (metadata.authorization_servers or [])] + if not candidates: + raise ValueError( + "Protected resource metadata did not include authorization servers" + ) + + if preferred and preferred in candidates: + return preferred + + if preferred: + logger.warning( + "Preferred authorization server not listed; falling back to first entry", + data={"preferred": preferred, "candidates": candidates}, + ) + return candidates[0] + + +def normalize_resource(resource: str | None, fallback: str | None) -> str: + if resource: + return resource.rstrip("/") + if fallback: + return fallback.rstrip("/") + raise ValueError("Unable to determine resource identifier for OAuth flow") diff --git a/src/mcp_agent/oauth/pkce.py b/src/mcp_agent/oauth/pkce.py new file mode 100644 index 000000000..1709dc332 --- /dev/null +++ b/src/mcp_agent/oauth/pkce.py @@ -0,0 +1,27 @@ +"""PKCE utilities.""" + +from __future__ import annotations + +import base64 +import hashlib +import secrets + + +def generate_code_verifier(length: int = 64) -> str: + if length < 43 or length > 128: + raise ValueError("PKCE code verifier length must be between 43 and 128") + # token_urlsafe returns ~1.3 chars per byte; adjust to reach desired length + needed_bytes = int(length * 0.8) + 1 + verifier = secrets.token_urlsafe(needed_bytes) + if len(verifier) < length: + verifier = (verifier + secrets.token_urlsafe(needed_bytes))[:length] + return verifier[:length] + + +def generate_code_challenge(verifier: str) -> str: + digest = hashlib.sha256(verifier.encode()).digest() + return base64.urlsafe_b64encode(digest).rstrip(b"=").decode() + + +def generate_state(length: int = 32) -> str: + return secrets.token_urlsafe(length) diff --git a/src/mcp_agent/oauth/records.py b/src/mcp_agent/oauth/records.py new file mode 100644 index 000000000..2a7c47be0 --- /dev/null +++ b/src/mcp_agent/oauth/records.py @@ -0,0 +1,46 @@ +"""Shared record types for OAuth token management.""" + +from __future__ import annotations + +from datetime import datetime, timezone +from typing import Any, Dict, Tuple + +from pydantic import BaseModel, Field + + +class TokenRecord(BaseModel): + """Persisted token bundle for a user/resource/authorization server combination.""" + + access_token: str + refresh_token: str | None = None + scopes: Tuple[str, ...] 
= () + expires_at: float | None = None + token_type: str = "Bearer" + resource: str | None = None + authorization_server: str | None = None + obtained_at: float = Field( + default_factory=lambda: datetime.now(tz=timezone.utc).timestamp() + ) + metadata: Dict[str, Any] = Field(default_factory=dict) + + def is_expired(self, *, leeway_seconds: int = 0) -> bool: + if self.expires_at is None: + return False + now = datetime.now(tz=timezone.utc).timestamp() + return now >= (self.expires_at - leeway_seconds) + + def with_tokens( + self, + *, + access_token: str, + refresh_token: str | None, + expires_at: float | None, + ) -> "TokenRecord": + return self.model_copy( + update={ + "access_token": access_token, + "refresh_token": refresh_token, + "expires_at": expires_at, + "obtained_at": datetime.now(tz=timezone.utc).timestamp(), + } + ) diff --git a/src/mcp_agent/oauth/store/__init__.py b/src/mcp_agent/oauth/store/__init__.py new file mode 100644 index 000000000..226953f43 --- /dev/null +++ b/src/mcp_agent/oauth/store/__init__.py @@ -0,0 +1,11 @@ +"""Token store implementations.""" + +from .base import TokenStore, TokenStoreKey, scope_fingerprint +from .in_memory import InMemoryTokenStore + +__all__ = [ + "TokenStore", + "TokenStoreKey", + "scope_fingerprint", + "InMemoryTokenStore", +] diff --git a/src/mcp_agent/oauth/store/base.py b/src/mcp_agent/oauth/store/base.py new file mode 100644 index 000000000..a03f7da83 --- /dev/null +++ b/src/mcp_agent/oauth/store/base.py @@ -0,0 +1,33 @@ +"""Abstract token store definition.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Iterable, Protocol + +from ..records import TokenRecord + + +@dataclass(frozen=True) +class TokenStoreKey: + """Uniquely identifies a cached token.""" + + user_key: str + resource: str + authorization_server: str | None + scope_fingerprint: str + + +def scope_fingerprint(scopes: Iterable[str]) -> str: + """Return a deterministic fingerprint for a scope list.""" + return " ".join(sorted({scope.strip() for scope in scopes if scope})) + + +class TokenStore(Protocol): + """Persistence interface for OAuth tokens.""" + + async def get(self, key: TokenStoreKey) -> TokenRecord | None: ... + + async def set(self, key: TokenStoreKey, record: TokenRecord) -> None: ... + + async def delete(self, key: TokenStoreKey) -> None: ... 
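The `TokenStore` protocol above is what the planned Redis backend would implement for multi-instance deployments. The following is only an illustrative sketch (not part of this diff): it assumes `redis.asyncio` is available and that `TokenRecord` is serialized with its pydantic v2 JSON helpers; the key layout and `RedisTokenStore` name are hypothetical.

```python
# Illustrative only: a possible Redis-backed TokenStore, not included in this change.
from __future__ import annotations

import redis.asyncio as redis

from mcp_agent.oauth.records import TokenRecord
from mcp_agent.oauth.store.base import TokenStore, TokenStoreKey


class RedisTokenStore(TokenStore):
    """Hypothetical multi-instance store keyed on the same TokenStoreKey tuple."""

    def __init__(self, client: redis.Redis, *, prefix: str = "mcp-agent:oauth") -> None:
        self._client = client
        self._prefix = prefix

    def _redis_key(self, key: TokenStoreKey) -> str:
        # Flatten the frozen dataclass into a deterministic string key.
        return "|".join(
            [
                self._prefix,
                key.user_key,
                key.resource,
                key.authorization_server or "-",
                key.scope_fingerprint,
            ]
        )

    async def get(self, key: TokenStoreKey) -> TokenRecord | None:
        raw = await self._client.get(self._redis_key(key))
        return TokenRecord.model_validate_json(raw) if raw else None

    async def set(self, key: TokenStoreKey, record: TokenRecord) -> None:
        # Persist the full record; expiry is still enforced via TokenRecord.is_expired.
        await self._client.set(self._redis_key(key), record.model_dump_json())

    async def delete(self, key: TokenStoreKey) -> None:
        await self._client.delete(self._redis_key(key))
```

Because the fingerprinted scope set is part of the key, tokens requested with different scopes never collide, mirroring the behaviour of `InMemoryTokenStore` below.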
diff --git a/src/mcp_agent/oauth/store/in_memory.py b/src/mcp_agent/oauth/store/in_memory.py new file mode 100644 index 000000000..d1aa40460 --- /dev/null +++ b/src/mcp_agent/oauth/store/in_memory.py @@ -0,0 +1,30 @@ +"""In-memory token store for local development and testing.""" + +from __future__ import annotations + +import asyncio +from typing import Dict + +from .base import TokenStore, TokenStoreKey +from ..records import TokenRecord + + +class InMemoryTokenStore(TokenStore): + def __init__(self) -> None: + self._records: Dict[TokenStoreKey, TokenRecord] = {} + self._lock = asyncio.Lock() + + async def get(self, key: TokenStoreKey) -> TokenRecord | None: + async with self._lock: + record = self._records.get(key) + if record is None: + return None + return record + + async def set(self, key: TokenStoreKey, record: TokenRecord) -> None: + async with self._lock: + self._records[key] = record + + async def delete(self, key: TokenStoreKey) -> None: + async with self._lock: + self._records.pop(key, None) diff --git a/src/mcp_agent/server/app_server.py b/src/mcp_agent/server/app_server.py index ec56574c1..3bbf6e59a 100644 --- a/src/mcp_agent/server/app_server.py +++ b/src/mcp_agent/server/app_server.py @@ -10,15 +10,22 @@ import os import secrets import asyncio +from pydantic import BaseModel, Field from mcp.server.fastmcp import Context as MCPContext, FastMCP +from mcp.server.fastmcp.server import AuthSettings +from mcp.server.auth.middleware.auth_context import ( + AuthenticatedUser, + auth_context_var, +) from starlette.requests import Request -from starlette.responses import JSONResponse +from starlette.responses import HTMLResponse, JSONResponse from mcp.server.fastmcp.exceptions import ToolError from mcp.server.fastmcp.tools import Tool as FastTool from mcp_agent.app import MCPApp from mcp_agent.agents.agent import Agent +from mcp_agent.config import MCPOAuthClientSettings from mcp_agent.core.context_dependent import ContextDependent from mcp_agent.executor.workflow import Workflow from mcp_agent.executor.workflow_registry import ( @@ -30,6 +37,13 @@ from mcp_agent.logging.logger import get_logger from mcp_agent.logging.logger import LoggingConfig from mcp_agent.mcp.mcp_server_registry import ServerRegistry +from mcp_agent.oauth.identity import OAuthUserIdentity +from mcp_agent.oauth.callbacks import callback_registry +from mcp_agent.oauth.errors import ( + CallbackTimeoutError, +) +from mcp_agent.oauth.manager import create_default_user_for_preconfigured_tokens +from mcp_agent.server.token_verifier import MCPAgentTokenVerifier if TYPE_CHECKING: from mcp_agent.core.context import Context @@ -178,6 +192,40 @@ def _set_upstream_from_request_ctx_if_available(ctx: MCPContext) -> None: # ctx.session property might raise ValueError if context not available pass + # Capture authenticated user information if available + identity: OAuthUserIdentity | None = None + try: + auth_user = auth_context_var.get() + except LookupError: + auth_user = None + + if isinstance(auth_user, AuthenticatedUser): + access_token = getattr(auth_user, "access_token", None) + if access_token is not None: + # Prefer enriched token instances but fall back to raw data if necessary + try: + from mcp_agent.oauth.access_token import MCPAccessToken + + if isinstance(access_token, MCPAccessToken): + identity = OAuthUserIdentity.from_access_token(access_token) + else: + token_dict = getattr(access_token, "model_dump", None) + if callable(token_dict): + maybe_token = MCPAccessToken.model_validate( + access_token.model_dump() + 
) + identity = OAuthUserIdentity.from_access_token(maybe_token) + except Exception: + identity = None + + if not identity: + # Try create identity from session id + try: + session_id = ctx.request_context.request.query_params.get("session_id") + identity = create_default_user_for_preconfigured_tokens(session_id) + except Exception: + identity = None + if session is not None: app: MCPApp | None = _get_attached_app(ctx.fastmcp) if app is not None and getattr(app, "context", None) is not None: @@ -185,9 +233,15 @@ def _set_upstream_from_request_ctx_if_available(ctx: MCPContext) -> None: # Previously captured; no need to keep old value # Use direct assignment for Pydantic model app.context.upstream_session = session + app.context.current_user = identity return else: return + else: + # Update identity even if we failed to resolve a session + app: MCPApp | None = _get_attached_app(ctx.fastmcp) + if app is not None and getattr(app, "context", None) is not None: + app.context.current_user = identity def _resolve_workflows_and_context( @@ -322,6 +376,34 @@ def create_mcp_server_for_app(app: MCPApp, **kwargs: Any) -> FastMCP: A configured FastMCP server instance """ + auth_settings_config = None + try: + if app.context and app.context.config: + auth_settings_config = app.context.config.authorization + except Exception: + auth_settings_config = None + + effective_auth_settings: AuthSettings | None = None + token_verifier: MCPAgentTokenVerifier | None = None + owns_token_verifier = False + if auth_settings_config and auth_settings_config.enabled: + try: + effective_auth_settings = AuthSettings( + issuer_url=auth_settings_config.issuer_url, # type: ignore[arg-type] + resource_server_url=auth_settings_config.resource_server_url, # type: ignore[arg-type] + service_documentation_url=auth_settings_config.service_documentation_url, # type: ignore[arg-type] + required_scopes=auth_settings_config.required_scopes or None, + ) + token_verifier = MCPAgentTokenVerifier(auth_settings_config) + except Exception as exc: + logger.error( + "Failed to configure authorization server integration", + exc_info=True, + data={"error": str(exc)}, + ) + effective_auth_settings = None + token_verifier = None + # Create a lifespan function specific to this app @asynccontextmanager async def app_specific_lifespan(mcp: FastMCP) -> AsyncIterator[ServerContext]: @@ -341,7 +423,11 @@ async def app_specific_lifespan(mcp: FastMCP) -> AsyncIterator[ServerContext]: yield server_context finally: # Don't clean up the MCPApp here - let the caller handle that - pass + if owns_token_verifier and token_verifier is not None: + try: + await token_verifier.aclose() + except Exception: + pass # Helper: install internal HTTP routes (not MCP tools) def _install_internal_routes(mcp_server: FastMCP) -> None: @@ -358,6 +444,41 @@ def _get_fallback_upstream_session() -> Any | None: return None return None + @mcp_server.custom_route( + "/internal/oauth/callback/{flow_id}", + methods=["GET", "POST"], + include_in_schema=False, + ) + async def _oauth_callback(request: Request): + flow_id = request.path_params.get("flow_id") + if not flow_id: + return JSONResponse({"error": "missing_flow_id"}, status_code=400) + + payload: Dict[str, Any] = {} + try: + payload.update({k: v for k, v in request.query_params.multi_items()}) + except Exception: + payload.update(dict(request.query_params)) + + if request.method.upper() == "POST": + content_type = request.headers.get("content-type", "") + try: + if "application/json" in content_type: + body_data = await 
request.json() + else: + form = await request.form() + body_data = {k: v for k, v in form.multi_items()} + except Exception: + body_data = {} + payload.update(body_data) + + delivered = await callback_registry.deliver(flow_id, payload) + if not delivered: + return JSONResponse({"error": "unknown_flow"}, status_code=404) + + html = """

+ Authorization complete.
+
+ You may close this window and return to MCP Agent.
+
""" + return HTMLResponse(html) + @mcp_server.custom_route( "/internal/session/by-run/{execution_id}/notify", methods=["POST"], @@ -609,6 +730,39 @@ async def _handle_specific_request( request=req, result_type=EmptyResult ) # type: ignore[attr-defined] return result.model_dump(by_alias=True, mode="json", exclude_none=True) + elif method == "auth/request": + # TODO: special handling of auth request, should be replaced by future URL elicitation + class AuthToken(BaseModel): + confirmation: str = Field( + description="Please press enter to confirm this message has been received" + ) + + flow_id = params["flow_id"] + callback_future = await callback_registry.create_handle(flow_id) + + req = ElicitRequest( + method="elicitation/create", + params=ElicitRequestParams( + message=params["message"] + "\n\n" + params["url"], + requestedSchema=AuthToken.model_json_schema(), + ), + ) + + result = await session.send_request( + request=req, result_type=ElicitResult + ) # type: ignore[attr-defined] + + timeout = 300 + try: + callback_data = await asyncio.wait_for( + callback_future, timeout=timeout + ) + except asyncio.TimeoutError as exc: + raise CallbackTimeoutError( + f"Timed out waiting for OAuth callback after {timeout} seconds" + ) from exc + + return callback_data else: raise ValueError(f"unsupported method: {method}") @@ -1039,6 +1193,11 @@ async def _internal_human_prompts(request: Request): except Exception: pass else: + if "auth" not in kwargs and effective_auth_settings is not None: + kwargs["auth"] = effective_auth_settings + if "token_verifier" not in kwargs and token_verifier is not None: + kwargs["token_verifier"] = token_verifier + owns_token_verifier = True mcp = FastMCP( name=app.name or "mcp_agent_server", # TODO: saqadri (MAC) - create a much more detailed description @@ -1381,6 +1540,172 @@ async def cancel_workflow( return result + @mcp.tool(name="workflows-pre-auth") + async def workflow_pre_auth( + ctx: MCPContext, workflow_name: str, tokens: List[Dict[str, Any]] + ) -> Dict[str, Any]: + """ + Pre-authorize OAuth tokens for a workflow to use with MCP servers. + + Stores OAuth tokens that the workflow can use when connecting to various MCP servers. + This allows workflows to authenticate with external services without requiring + interactive OAuth flows during execution. + + Args: + workflow_name: The name of the workflow that will use these tokens. + tokens: List of OAuth token objects, each containing: + - access_token (str): The OAuth access token + - refresh_token (str, optional): The OAuth refresh token + - server_name (str): Name/identifier of the MCP server + - scopes (List[str], optional): List of OAuth scopes + - expires_at (float, optional): Token expiration timestamp + - authorization_server (str, optional): Authorization server URL + + Returns: + Dictionary with success status and count of stored tokens. 
+ """ + # Ensure upstream session is available for any logs + try: + _set_upstream_from_request_ctx_if_available(ctx) + except Exception: + pass + + workflows_dict, app_context = _resolve_workflows_and_context(ctx) + if not workflows_dict or not app_context: + raise ToolError("Server context not available for MCPApp Server.") + + if workflow_name not in workflows_dict: + raise ToolError(f"Workflow '{workflow_name}' not found.") + + if not app_context.token_store: + raise ToolError("Token storage not available.") + + if not tokens: + raise ToolError("At least one token must be provided.") + + stored_count = 0 + errors = [] + + try: + for i, token_data in enumerate(tokens): + try: + # Validate required fields + if not isinstance(token_data, dict): + errors.append(f"Token {i}: must be a dictionary") + continue + + access_token = token_data.get("access_token") + server_name = token_data.get("server_name") + + if not access_token: + errors.append( + f"Token {i}: missing required 'access_token' field" + ) + continue + + if not server_name: + errors.append( + f"Token {i}: missing required 'server_name' field" + ) + continue + + server_config = app_context.server_registry.get_server_config( + server_name + ) + if not server_config: + errors.append( + f"Token {i}: server '{server_name}' not recognized" + ) + continue + + oauth_config: MCPOAuthClientSettings | None = None + if server_config and server_config.auth: + oauth_config = getattr(server_config.auth, "oauth", None) + if not oauth_config or not oauth_config.enabled: + errors.append( + f"Token {i}: Server '{server_name}' is not configured for OAuth authentication" + ) + continue + + # Create TokenRecord + from mcp_agent.oauth.records import TokenRecord + from mcp_agent.oauth.store.base import ( + TokenStoreKey, + scope_fingerprint, + ) + + resource_str = ( + str(oauth_config.resource) + if oauth_config.resource + else getattr(server_config, "url", None) + ) + auth_server_str = ( + str(oauth_config.authorization_server) + if oauth_config.authorization_server + else None + ) + scope_list = list(oauth_config.scopes or []) + + token_record = TokenRecord( + access_token=access_token, + refresh_token=token_data.get("refresh_token"), + scopes=tuple(scope_list), + expires_at=token_data.get("expires_at"), + token_type=token_data.get("token_type", "Bearer"), + resource=server_name, + authorization_server=auth_server_str, + metadata={"workflow_name": workflow_name}, + ) + + str(oauth_config.resource) if oauth_config.resource else getattr( + server_config, "url", None + ) + # Create storage key using current user + store_key = TokenStoreKey( + user_key=app_context.current_user.cache_key, + resource=resource_str, + authorization_server=auth_server_str, + scope_fingerprint=scope_fingerprint(scope_list), + ) + + # Store the token + await app_context.token_store.set(store_key, token_record) + stored_count += 1 + except Exception as e: + errors.append(f"Token {i}: {str(e)}") + logger.error( + f"Error storing token {i} for workflow '{workflow_name}': {e}" + ) + + if errors and stored_count == 0: + raise ToolError( + f"Failed to store any tokens. 
Errors: {'; '.join(errors)}" + ) + + result = { + "success": True, + "workflow_name": workflow_name, + "stored_tokens": stored_count, + "total_tokens": len(tokens), + } + + if errors: + result["errors"] = errors + result["partial_success"] = True + + logger.info( + f"Pre-authorization completed for workflow '{workflow_name}': " + f"{stored_count}/{len(tokens)} tokens stored" + ) + + return result + + except Exception as e: + logger.error( + f"Error in workflow pre-authorization for '{workflow_name}': {e}" + ) + raise ToolError(f"Failed to store tokens: {str(e)}") + # endregion return mcp diff --git a/src/mcp_agent/server/token_verifier.py b/src/mcp_agent/server/token_verifier.py new file mode 100644 index 000000000..799341585 --- /dev/null +++ b/src/mcp_agent/server/token_verifier.py @@ -0,0 +1,200 @@ +"""Token verification for MCP Agent Cloud authorization server.""" + +from __future__ import annotations + +import asyncio +from datetime import datetime, timezone +from typing import Any, Dict, List + +import httpx + +from mcp.server.auth.provider import AccessToken +from mcp.server.auth.provider import TokenVerifier + +from mcp_agent.config import MCPAuthorizationServerSettings +from mcp_agent.logging.logger import get_logger +from mcp_agent.oauth.access_token import MCPAccessToken + +logger = get_logger(__name__) + + +class MCPAgentTokenVerifier(TokenVerifier): + """Verify bearer tokens issued by the MCP Agent Cloud authorization server.""" + + def __init__(self, settings: MCPAuthorizationServerSettings): + if not settings.introspection_endpoint: + raise ValueError( + "introspection_endpoint must be configured to verify tokens" + ) + self._settings = settings + timeout = httpx.Timeout(10.0) + self._client = httpx.AsyncClient(timeout=timeout) + self._cache: Dict[str, MCPAccessToken] = {} + self._lock = asyncio.Lock() + + async def verify_token(self, token: str) -> AccessToken | None: # type: ignore[override] + cached = self._cache.get(token) + if cached and not cached.is_expired(leeway_seconds=30): + return cached + + async with self._lock: + # Double-check cache after acquiring lock to avoid duplicate refresh + cached = self._cache.get(token) + if cached and not cached.is_expired(leeway_seconds=30): + return cached + + verified = await self._introspect(token) + if verified: + self._cache[token] = verified + else: + self._cache.pop(token, None) + return verified + + async def _introspect(self, token: str) -> MCPAccessToken | None: + data = {"token": token} + auth = None + if ( + self._settings.introspection_client_id + and self._settings.introspection_client_secret + ): + auth = httpx.BasicAuth( + self._settings.introspection_client_id, + self._settings.introspection_client_secret, + ) + + try: + response = await self._client.post( + str(self._settings.introspection_endpoint), + data=data, + auth=auth, + ) + except httpx.HTTPError as exc: + logger.warning(f"Token introspection request failed: {exc}") + return None + + if response.status_code != 200: + logger.warning( + "Token introspection returned non-success status", + data={"status_code": response.status_code}, + ) + return None + + try: + payload: Dict[str, Any] = response.json() + except ValueError: + logger.warning("Token introspection response was not valid JSON") + return None + + if not payload.get("active"): + return None + + if self._settings.issuer_url and payload.get("iss"): + if str(payload.get("iss")) != str(self._settings.issuer_url): + logger.warning( + "Token issuer mismatch", + data={ + "expected": 
str(self._settings.issuer_url), + "actual": payload.get("iss"), + }, + ) + return None + + # RFC 9068 Audience Validation (always enforced) + token_audiences = self._extract_audiences(payload) + if not self._validate_audiences(token_audiences): + logger.warning( + "Token audience validation failed", + data={ + "token_audiences": token_audiences, + "expected_audiences": self._settings.expected_audiences, + }, + ) + return None + + token_model = MCPAccessToken.from_introspection( + token, + payload, + resource_hint=str(self._settings.resource_server_url) + if self._settings.resource_server_url + else None, + ) + + # Respect cache TTL limit if configured + ttl_seconds = max(0, self._settings.token_cache_ttl_seconds or 0) + if ttl_seconds and token_model.expires_at is not None: + now_ts = datetime.now(tz=timezone.utc).timestamp() + cache_limit = now_ts + ttl_seconds + token_model.expires_at = min(token_model.expires_at, cache_limit) + + # Optionally enforce required scopes + required_scopes = self._settings.required_scopes or [] + missing = [ + scope for scope in required_scopes if scope not in token_model.scopes + ] + if missing: + logger.warning( + "Token missing required scopes", + data={"missing_scopes": missing}, + ) + return None + + return token_model + + def _extract_audiences(self, payload: Dict[str, Any]) -> List[str]: + """Extract audience values from token payload according to RFC 9068.""" + audiences = [] + + # Check both 'aud' and 'resource' claims (OAuth 2.0 resource indicators) + aud_claim = payload.get("aud") + resource_claim = payload.get("resource") + + # Handle 'aud' claim (can be string or array) + if aud_claim: + if isinstance(aud_claim, str): + audiences.append(aud_claim) + elif isinstance(aud_claim, (list, tuple)): + audiences.extend([str(aud) for aud in aud_claim if aud]) + + # Handle 'resource' claim (OAuth 2.0 resource indicator) + if resource_claim: + if isinstance(resource_claim, str): + audiences.append(resource_claim) + elif isinstance(resource_claim, (list, tuple)): + audiences.extend([str(res) for res in resource_claim if res]) + + return list(set(audiences)) # Remove duplicates + + def _validate_audiences(self, token_audiences: List[str]) -> bool: + """Validate token audiences against expected values per RFC 9068.""" + if not token_audiences: + logger.warning("Token contains no audience claims") + return False + + if not self._settings.expected_audiences: + logger.warning("No expected audiences configured for validation") + return False + + # RFC 9068: Token MUST contain at least one expected audience + valid_audiences = set(self._settings.expected_audiences) + token_audience_set = set(token_audiences) + + if not valid_audiences.intersection(token_audience_set): + logger.warning( + "Token audience validation failed - no matching audiences", + data={ + "token_audiences": list(token_audience_set), + "valid_audiences": list(valid_audiences), + }, + ) + return False + + return True + + async def aclose(self) -> None: + await self._client.aclose() + + async def __aenter__(self) -> "MCPAgentTokenVerifier": + return self + + async def __aexit__(self, exc_type, exc, tb) -> None: + await self.aclose() diff --git a/tests/test_audience_validation.py b/tests/test_audience_validation.py new file mode 100644 index 000000000..392cf4e18 --- /dev/null +++ b/tests/test_audience_validation.py @@ -0,0 +1,250 @@ +"""Test audience validation functionality for RFC 9068 compliance.""" + +import pytest +from unittest.mock import Mock, AsyncMock +import httpx +from mcp_agent.config 
import MCPAuthorizationServerSettings +from mcp_agent.server.token_verifier import MCPAgentTokenVerifier +from mcp_agent.oauth.access_token import MCPAccessToken, _extract_all_audiences + + +@pytest.mark.asyncio +async def test_audience_validation_success(): + """Test successful audience validation with matching audiences.""" + settings = MCPAuthorizationServerSettings( + enabled=True, + issuer_url="https://auth.example.com", + resource_server_url="https://api.example.com", + introspection_endpoint="https://auth.example.com/introspect", + expected_audiences=["https://api.example.com", "api.example.com"], + ) + + # Mock successful introspection response with valid audience + payload = { + "active": True, + "aud": ["https://api.example.com", "other.example.com"], + "sub": "user123", + "exp": 1234567890, + "iss": "https://auth.example.com/", + } + + token = MCPAccessToken.from_introspection("test_token", payload) + assert token.validate_audience(settings.expected_audiences) is True + + +@pytest.mark.asyncio +async def test_audience_validation_failure(): + """Test audience validation failure with non-matching audiences.""" + settings = MCPAuthorizationServerSettings( + enabled=True, + issuer_url="https://auth.example.com", + resource_server_url="https://api.example.com", + introspection_endpoint="https://auth.example.com/introspect", + expected_audiences=["https://api.example.com"], + ) + + payload = { + "active": True, + "aud": ["https://malicious.example.com"], # Wrong audience + "sub": "user123", + "exp": 1234567890, + "iss": "https://auth.example.com/", + } + + token = MCPAccessToken.from_introspection("test_token", payload) + assert token.validate_audience(settings.expected_audiences) is False + + +@pytest.mark.asyncio +async def test_resource_claim_audience_validation(): + """Test audience validation using OAuth 2.0 resource indicators.""" + settings = MCPAuthorizationServerSettings( + enabled=True, + issuer_url="https://auth.example.com", + resource_server_url="https://api.example.com", + introspection_endpoint="https://auth.example.com/introspect", + expected_audiences=["https://api.example.com"], + ) + + # Token with resource claim instead of aud claim + payload = { + "active": True, + "resource": "https://api.example.com", # OAuth 2.0 resource indicator + "sub": "user123", + "exp": 1234567890, + "iss": "https://auth.example.com/", + } + + token = MCPAccessToken.from_introspection("test_token", payload) + assert token.validate_audience(settings.expected_audiences) is True + + +@pytest.mark.asyncio +async def test_multiple_audiences_extraction(): + """Test extraction of multiple audiences from both aud and resource claims.""" + payload = { + "aud": ["https://api1.example.com", "https://api2.example.com"], + "resource": "https://api3.example.com", + } + + audiences = _extract_all_audiences(payload) + expected = { + "https://api1.example.com", + "https://api2.example.com", + "https://api3.example.com", + } + assert set(audiences) == expected + + +@pytest.mark.asyncio +async def test_audience_extraction_string_values(): + """Test extraction when aud and resource are strings rather than arrays.""" + payload = { + "aud": "https://api1.example.com", + "resource": "https://api2.example.com", + } + + audiences = _extract_all_audiences(payload) + expected = {"https://api1.example.com", "https://api2.example.com"} + assert set(audiences) == expected + + +@pytest.mark.asyncio +async def test_empty_audience_validation(): + """Test validation fails when no audiences are present.""" + payload = { + 
"active": True, + "sub": "user123", + "exp": 1234567890, + "iss": "https://auth.example.com/", + # No aud or resource claims + } + + token = MCPAccessToken.from_introspection("test_token", payload) + assert token.validate_audience(["https://api.example.com"]) is False + + +def test_configuration_validation(): + """Test that configuration validation always enforces audience settings.""" + # Should raise error when no audiences configured (always enforced now) + with pytest.raises(ValueError, match="expected_audiences.*required for RFC 9068"): + MCPAuthorizationServerSettings( + enabled=True, + issuer_url="https://auth.example.com", + resource_server_url="https://api.example.com", + expected_audiences=[], # Empty list should always fail + ) + + # Should succeed with proper configuration + settings = MCPAuthorizationServerSettings( + enabled=True, + issuer_url="https://auth.example.com", + resource_server_url="https://api.example.com", + expected_audiences=["https://api.example.com"], + ) + assert "https://api.example.com" in settings.expected_audiences + + +@pytest.mark.asyncio +async def test_token_verifier_audience_validation_integration(): + """Test full integration of audience validation in token verifier.""" + settings = MCPAuthorizationServerSettings( + enabled=True, + issuer_url="https://auth.example.com", + resource_server_url="https://api.example.com", + introspection_endpoint="https://auth.example.com/introspect", + introspection_client_id="test-client", + introspection_client_secret="test-secret", + expected_audiences=["https://api.example.com"], + ) + + verifier = MCPAgentTokenVerifier(settings) + + # Mock HTTP client + mock_client = Mock(spec=httpx.AsyncClient) + + # Mock successful response with valid audience + valid_response = Mock() + valid_response.status_code = 200 + valid_response.json.return_value = { + "active": True, + "aud": "https://api.example.com", + "sub": "user123", + "exp": 1234567890, + "iss": "https://auth.example.com/", + } + mock_client.post = AsyncMock(return_value=valid_response) + verifier._client = mock_client + + # Should succeed with valid audience + token = await verifier._introspect("valid_token") + assert token is not None + assert "https://api.example.com" in token.audiences + + # Mock response with invalid audience + invalid_response = Mock() + invalid_response.status_code = 200 + invalid_response.json.return_value = { + "active": True, + "aud": "https://malicious.example.com", # Wrong audience + "sub": "user123", + "exp": 1234567890, + "iss": "https://auth.example.com/", + } + mock_client.post = AsyncMock(return_value=invalid_response) + + # Should fail with invalid audience + token = await verifier._introspect("invalid_token") + assert token is None + + +def test_audience_extraction_edge_cases(): + """Test audience extraction handles edge cases properly.""" + # Empty payload + assert _extract_all_audiences({}) == [] + + # None values + assert _extract_all_audiences({"aud": None, "resource": None}) == [] + + # Mixed empty and valid values + payload = { + "aud": ["", "https://valid.com", None], + "resource": ["https://another.com", ""], + } + audiences = _extract_all_audiences(payload) + expected = {"https://valid.com", "https://another.com"} + assert set(audiences) == expected + + # Duplicate values should be removed + payload = { + "aud": ["https://api.com", "https://api.com"], + "resource": "https://api.com", + } + audiences = _extract_all_audiences(payload) + assert audiences == ["https://api.com"] + + +@pytest.mark.asyncio +async def 
test_partial_audience_match(): + """Test that partial audience matches are sufficient for validation.""" + settings = MCPAuthorizationServerSettings( + enabled=True, + issuer_url="https://auth.example.com", + resource_server_url="https://api.example.com", + introspection_endpoint="https://auth.example.com/introspect", + expected_audiences=["https://api.example.com", "https://other-api.com"], + strict_audience_validation=True, + ) + + # Token has one matching and one non-matching audience + payload = { + "active": True, + "aud": ["https://api.example.com", "https://unrelated.com"], + "sub": "user123", + "exp": 1234567890, + "iss": "https://auth.example.com/", + } + + token = MCPAccessToken.from_introspection("test_token", payload) + # Should succeed because at least one audience matches + assert token.validate_audience(settings.expected_audiences) is True diff --git a/tests/test_oauth_utils.py b/tests/test_oauth_utils.py new file mode 100644 index 000000000..cd3de139f --- /dev/null +++ b/tests/test_oauth_utils.py @@ -0,0 +1,86 @@ +import time +import pathlib +import sys + +import pytest + +PROJECT_ROOT = pathlib.Path(__file__).resolve().parents[1] +SRC_ROOT = PROJECT_ROOT / "src" +if str(SRC_ROOT) not in sys.path: + sys.path.insert(0, str(SRC_ROOT)) + +try: + from mcp_agent.oauth.metadata import normalize_resource, select_authorization_server + from mcp_agent.oauth.records import TokenRecord + from mcp_agent.oauth.store import ( + InMemoryTokenStore, + TokenStoreKey, + scope_fingerprint, + ) + from mcp.shared.auth import ProtectedResourceMetadata +except ModuleNotFoundError: # pragma: no cover - optional dependency + pytest.skip("MCP SDK not installed", allow_module_level=True) + + +def test_scope_fingerprint_ordering(): + scopes = ["email", "profile", "email"] + fingerprint = scope_fingerprint(scopes) + assert fingerprint == "email profile" + + +def test_token_record_expiry(): + record = TokenRecord( + access_token="tok", + expires_at=time.time() + 5, + ) + assert not record.is_expired(leeway_seconds=0) + assert record.is_expired(leeway_seconds=10) + + +@pytest.mark.asyncio +async def test_in_memory_token_store_round_trip(): + store = InMemoryTokenStore() + key = TokenStoreKey( + user_key="provider:subject", + resource="https://example.com", + authorization_server="https://auth.example.com", + scope_fingerprint="scope", + ) + record = TokenRecord(access_token="abc123") + + await store.set(key, record) + fetched = await store.get(key) + assert fetched.access_token == record.access_token + await store.delete(key) + assert await store.get(key) is None + + +def test_select_authorization_server_prefers_explicit(): + metadata = ProtectedResourceMetadata( + resource="https://example.com", + authorization_servers=[ + "https://auth1.example.com", + "https://auth2.example.com", + ], + ) + # URLs get normalized with trailing slashes by pydantic + assert ( + select_authorization_server(metadata, "https://auth2.example.com/") + == "https://auth2.example.com/" + ) + assert ( + select_authorization_server(metadata, "https://unknown.example.com") + == "https://auth1.example.com/" # Falls back to first, which gets normalized + ) + + +def test_normalize_resource_with_fallback(): + assert ( + normalize_resource("https://example.com/api", None) == "https://example.com/api" + ) + assert ( + normalize_resource(None, "https://fallback.example.com") + == "https://fallback.example.com" + ) + with pytest.raises(ValueError): + normalize_resource(None, None)
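The PKCE helpers introduced in `src/mcp_agent/oauth/pkce.py` are not covered by the tests above. A short companion test along the following lines could verify the S256 behaviour; it is an illustrative sketch only, and the test names are hypothetical rather than part of this diff.

```python
# Illustrative companion tests for the new PKCE helpers (not part of this change).
import base64
import hashlib

import pytest

from mcp_agent.oauth.pkce import (
    generate_code_challenge,
    generate_code_verifier,
    generate_state,
)


def test_code_verifier_length_bounds():
    # RFC 7636 requires 43-128 characters; anything outside that range is rejected.
    with pytest.raises(ValueError):
        generate_code_verifier(length=42)
    with pytest.raises(ValueError):
        generate_code_verifier(length=129)
    assert len(generate_code_verifier(length=64)) == 64


def test_code_challenge_is_s256_of_verifier():
    # The challenge must be the unpadded base64url-encoded SHA-256 of the verifier.
    verifier = generate_code_verifier()
    expected = (
        base64.urlsafe_b64encode(hashlib.sha256(verifier.encode()).digest())
        .rstrip(b"=")
        .decode()
    )
    assert generate_code_challenge(verifier) == expected


def test_state_values_differ_between_calls():
    assert generate_state() != generate_state()
```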