diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index 1167b176a..25cd12862 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -38,4 +38,6 @@ jobs: run: pnpm run build:all - name: Publish preview packages - run: pnpm dlx pkg-pr-new publish --packageManager=npm --pnpm './packages/server' './packages/client' + run: + pnpm dlx pkg-pr-new publish --packageManager=npm --pnpm './packages/server' './packages/client' + './packages/server-express' './packages/server-hono' diff --git a/CLAUDE.md b/CLAUDE.md index 2a0b253a6..3caca17b6 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -224,7 +224,7 @@ mcpServer.tool('tool-name', { param: z.string() }, async ({ param }, extra) => { ```typescript // Server -const transport = new StreamableHTTPServerTransport({ sessionIdGenerator: () => randomUUID() }); +const transport = new NodeStreamableHTTPServerTransport({ sessionIdGenerator: () => randomUUID() }); await server.connect(transport); // Client diff --git a/README.md b/README.md index dc0116c96..4d5270287 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,6 @@ # MCP TypeScript SDK -> [!IMPORTANT] -> **This is the `main` branch which contains v2 of the SDK (currently in development, pre-alpha).** +> [!IMPORTANT] **This is the `main` branch which contains v2 of the SDK (currently in development, pre-alpha).** > > We anticipate a stable v2 release in Q1 2026. Until then, **v1.x remains the recommended version** for production use. v1.x will continue to receive bug fixes and security updates for at least 6 months after v2 ships to give people time to upgrade. > diff --git a/common/eslint-config/eslint.config.mjs b/common/eslint-config/eslint.config.mjs index 321f3f6fc..6ac057c69 100644 --- a/common/eslint-config/eslint.config.mjs +++ b/common/eslint-config/eslint.config.mjs @@ -47,6 +47,7 @@ export default defineConfig( '@typescript-eslint/consistent-type-imports': ['error', { disallowTypeAnnotations: false }], 'simple-import-sort/imports': 'warn', 'simple-import-sort/exports': 'warn', + 'import/consistent-type-specifier-style': ['error', 'prefer-top-level'], 'import/no-extraneous-dependencies': [ 'error', { diff --git a/docs/server.md b/docs/server.md index 4d5138e84..800d336db 100644 --- a/docs/server.md +++ b/docs/server.md @@ -70,7 +70,7 @@ For more detailed patterns (stateless vs stateful, JSON response mode, CORS, DNS MCP servers running on localhost are vulnerable to DNS rebinding attacks. Use `createMcpExpressApp()` to create an Express app with DNS rebinding protection enabled by default: ```typescript -import { createMcpExpressApp } from '@modelcontextprotocol/server'; +import { createMcpExpressApp } from '@modelcontextprotocol/server-express'; // Protection auto-enabled (default host is 127.0.0.1) const app = createMcpExpressApp(); @@ -85,7 +85,7 @@ const app = createMcpExpressApp({ host: '0.0.0.0' }); When binding to `0.0.0.0` / `::`, provide an allow-list of hosts: ```typescript -import { createMcpExpressApp } from '@modelcontextprotocol/server'; +import { createMcpExpressApp } from '@modelcontextprotocol/server-express'; const app = createMcpExpressApp({ host: '0.0.0.0', diff --git a/examples/server/README.md b/examples/server/README.md index 310113e45..1e7322b1a 100644 --- a/examples/server/README.md +++ b/examples/server/README.md @@ -1,6 +1,9 @@ # MCP TypeScript SDK Examples (Server) -This directory contains runnable MCP **server** examples built with `@modelcontextprotocol/server`. +This directory contains runnable MCP **server** examples built with `@modelcontextprotocol/server` plus framework adapters: + +- `@modelcontextprotocol/server-express` +- `@modelcontextprotocol/server-hono` For client examples, see [`../client/README.md`](../client/README.md). For guided docs, see [`../../docs/server.md`](../../docs/server.md). @@ -68,7 +71,7 @@ When deploying MCP servers in a horizontally scaled environment (multiple server ### Stateless mode -To enable stateless mode, configure the `StreamableHTTPServerTransport` with: +To enable stateless mode, configure the `NodeStreamableHTTPServerTransport` with: ```typescript sessionIdGenerator: undefined; diff --git a/examples/server/package.json b/examples/server/package.json index a3a3d14c7..cb37d9f40 100644 --- a/examples/server/package.json +++ b/examples/server/package.json @@ -38,6 +38,8 @@ "hono": "catalog:runtimeServerOnly", "@modelcontextprotocol/examples-shared": "workspace:^", "@modelcontextprotocol/server": "workspace:^", + "@modelcontextprotocol/server-express": "workspace:^", + "@modelcontextprotocol/server-hono": "workspace:^", "cors": "catalog:runtimeServerOnly", "express": "catalog:runtimeServerOnly", "zod": "catalog:runtimeShared" diff --git a/examples/server/src/elicitationFormExample.ts b/examples/server/src/elicitationFormExample.ts index f8863c17b..eaeb73c32 100644 --- a/examples/server/src/elicitationFormExample.ts +++ b/examples/server/src/elicitationFormExample.ts @@ -9,8 +9,9 @@ import { randomUUID } from 'node:crypto'; -import { createMcpExpressApp, isInitializeRequest, McpServer, StreamableHTTPServerTransport } from '@modelcontextprotocol/server'; -import { type Request, type Response } from 'express'; +import { isInitializeRequest, McpServer, NodeStreamableHTTPServerTransport } from '@modelcontextprotocol/server'; +import { createMcpExpressApp } from '@modelcontextprotocol/server-express'; +import type { Request, Response } from 'express'; // Create MCP server - it will automatically use AjvJsonSchemaValidator with sensible defaults // The validator supports format validation (email, date, etc.) if ajv-formats is installed @@ -321,7 +322,7 @@ async function main() { const app = createMcpExpressApp(); // Map to store transports by session ID - const transports: { [sessionId: string]: StreamableHTTPServerTransport } = {}; + const transports: { [sessionId: string]: NodeStreamableHTTPServerTransport } = {}; // MCP POST endpoint const mcpPostHandler = async (req: Request, res: Response) => { @@ -331,13 +332,13 @@ async function main() { } try { - let transport: StreamableHTTPServerTransport; + let transport: NodeStreamableHTTPServerTransport; if (sessionId && transports[sessionId]) { // Reuse existing transport for this session transport = transports[sessionId]; } else if (!sessionId && isInitializeRequest(req.body)) { // New initialization request - create new transport - transport = new StreamableHTTPServerTransport({ + transport = new NodeStreamableHTTPServerTransport({ sessionIdGenerator: () => randomUUID(), onsessioninitialized: sessionId => { // Store the transport by session ID when session is initialized diff --git a/examples/server/src/elicitationUrlExample.ts b/examples/server/src/elicitationUrlExample.ts index 99f85d079..51e1344b8 100644 --- a/examples/server/src/elicitationUrlExample.ts +++ b/examples/server/src/elicitationUrlExample.ts @@ -13,15 +13,13 @@ import { setupAuthServer } from '@modelcontextprotocol/examples-shared'; import type { CallToolResult, ElicitRequestURLParams, ElicitResult, OAuthMetadata } from '@modelcontextprotocol/server'; import { checkResourceAllowed, - createMcpExpressApp, getOAuthProtectedResourceMetadataUrl, isInitializeRequest, - mcpAuthMetadataRouter, McpServer, - requireBearerAuth, - StreamableHTTPServerTransport, + NodeStreamableHTTPServerTransport, UrlElicitationRequiredError } from '@modelcontextprotocol/server'; +import { createMcpExpressApp, mcpAuthMetadataRouter, requireBearerAuth } from '@modelcontextprotocol/server-express'; import cors from 'cors'; import type { Request, Response } from 'express'; import express from 'express'; @@ -594,7 +592,7 @@ app.post('/confirm-payment', express.urlencoded(), (req: Request, res: Response) }); // Map to store transports by session ID -const transports: { [sessionId: string]: StreamableHTTPServerTransport } = {}; +const transports: { [sessionId: string]: NodeStreamableHTTPServerTransport } = {}; // Interface for a function that can send an elicitation request type ElicitationSender = (params: ElicitRequestURLParams) => Promise; @@ -613,7 +611,7 @@ const mcpPostHandler = async (req: Request, res: Response) => { console.debug(`Received MCP POST for session: ${sessionId || 'unknown'}`); try { - let transport: StreamableHTTPServerTransport; + let transport: NodeStreamableHTTPServerTransport; if (sessionId && transports[sessionId]) { // Reuse existing transport transport = transports[sessionId]; @@ -621,7 +619,7 @@ const mcpPostHandler = async (req: Request, res: Response) => { const server = getServer(); // New initialization request const eventStore = new InMemoryEventStore(); - transport = new StreamableHTTPServerTransport({ + transport = new NodeStreamableHTTPServerTransport({ sessionIdGenerator: () => randomUUID(), eventStore, // Enable resumability onsessioninitialized: sessionId => { diff --git a/examples/server/src/honoWebStandardStreamableHttp.ts b/examples/server/src/honoWebStandardStreamableHttp.ts index aef1e99e2..f5c59cffe 100644 --- a/examples/server/src/honoWebStandardStreamableHttp.ts +++ b/examples/server/src/honoWebStandardStreamableHttp.ts @@ -10,6 +10,7 @@ import { serve } from '@hono/node-server'; import type { CallToolResult } from '@modelcontextprotocol/server'; import { McpServer, WebStandardStreamableHTTPServerTransport } from '@modelcontextprotocol/server'; +import { mcpStreamableHttpHandler } from '@modelcontextprotocol/server-hono'; import { Hono } from 'hono'; import { cors } from 'hono/cors'; import * as z from 'zod/v4'; @@ -56,7 +57,7 @@ app.use( app.get('/health', c => c.json({ status: 'ok' })); // MCP endpoint -app.all('/mcp', c => transport.handleRequest(c.req.raw)); +app.all('/mcp', mcpStreamableHttpHandler(transport)); // Start the server const PORT = process.env.MCP_PORT ? parseInt(process.env.MCP_PORT, 10) : 3000; diff --git a/examples/server/src/jsonResponseStreamableHttp.ts b/examples/server/src/jsonResponseStreamableHttp.ts index 2199ebfbe..5935ad2c2 100644 --- a/examples/server/src/jsonResponseStreamableHttp.ts +++ b/examples/server/src/jsonResponseStreamableHttp.ts @@ -1,7 +1,8 @@ import { randomUUID } from 'node:crypto'; import type { CallToolResult } from '@modelcontextprotocol/server'; -import { createMcpExpressApp, isInitializeRequest, McpServer, StreamableHTTPServerTransport } from '@modelcontextprotocol/server'; +import { isInitializeRequest, McpServer, NodeStreamableHTTPServerTransport } from '@modelcontextprotocol/server'; +import { createMcpExpressApp } from '@modelcontextprotocol/server-express'; import type { Request, Response } from 'express'; import * as z from 'zod/v4'; @@ -96,21 +97,21 @@ const getServer = () => { const app = createMcpExpressApp(); // Map to store transports by session ID -const transports: { [sessionId: string]: StreamableHTTPServerTransport } = {}; +const transports: { [sessionId: string]: NodeStreamableHTTPServerTransport } = {}; app.post('/mcp', async (req: Request, res: Response) => { console.log('Received MCP request:', req.body); try { // Check for existing session ID const sessionId = req.headers['mcp-session-id'] as string | undefined; - let transport: StreamableHTTPServerTransport; + let transport: NodeStreamableHTTPServerTransport; if (sessionId && transports[sessionId]) { // Reuse existing transport transport = transports[sessionId]; } else if (!sessionId && isInitializeRequest(req.body)) { // New initialization request - use JSON response mode - transport = new StreamableHTTPServerTransport({ + transport = new NodeStreamableHTTPServerTransport({ sessionIdGenerator: () => randomUUID(), enableJsonResponse: true, // Enable JSON response mode onsessioninitialized: sessionId => { diff --git a/examples/server/src/simpleSseServer.ts b/examples/server/src/simpleSseServer.ts index 90561c62f..35b48b69d 100644 --- a/examples/server/src/simpleSseServer.ts +++ b/examples/server/src/simpleSseServer.ts @@ -1,5 +1,6 @@ import type { CallToolResult } from '@modelcontextprotocol/server'; -import { createMcpExpressApp, McpServer, SSEServerTransport } from '@modelcontextprotocol/server'; +import { McpServer, SSEServerTransport } from '@modelcontextprotocol/server'; +import { createMcpExpressApp } from '@modelcontextprotocol/server-express'; import type { Request, Response } from 'express'; import * as z from 'zod/v4'; diff --git a/examples/server/src/simpleStatelessStreamableHttp.ts b/examples/server/src/simpleStatelessStreamableHttp.ts index 3aee2c212..70389275c 100644 --- a/examples/server/src/simpleStatelessStreamableHttp.ts +++ b/examples/server/src/simpleStatelessStreamableHttp.ts @@ -1,5 +1,6 @@ import type { CallToolResult, GetPromptResult, ReadResourceResult } from '@modelcontextprotocol/server'; -import { createMcpExpressApp, McpServer, StreamableHTTPServerTransport } from '@modelcontextprotocol/server'; +import { McpServer, NodeStreamableHTTPServerTransport } from '@modelcontextprotocol/server'; +import { createMcpExpressApp } from '@modelcontextprotocol/server-express'; import type { Request, Response } from 'express'; import * as z from 'zod/v4'; @@ -103,7 +104,7 @@ const app = createMcpExpressApp(); app.post('/mcp', async (req: Request, res: Response) => { const server = getServer(); try { - const transport: StreamableHTTPServerTransport = new StreamableHTTPServerTransport({ + const transport: NodeStreamableHTTPServerTransport = new NodeStreamableHTTPServerTransport({ sessionIdGenerator: undefined }); await server.connect(transport); diff --git a/examples/server/src/simpleStreamableHttp.ts b/examples/server/src/simpleStreamableHttp.ts index 7613e3786..c1656a544 100644 --- a/examples/server/src/simpleStreamableHttp.ts +++ b/examples/server/src/simpleStreamableHttp.ts @@ -11,17 +11,15 @@ import type { } from '@modelcontextprotocol/server'; import { checkResourceAllowed, - createMcpExpressApp, ElicitResultSchema, getOAuthProtectedResourceMetadataUrl, InMemoryTaskMessageQueue, InMemoryTaskStore, isInitializeRequest, - mcpAuthMetadataRouter, McpServer, - requireBearerAuth, - StreamableHTTPServerTransport + NodeStreamableHTTPServerTransport } from '@modelcontextprotocol/server'; +import { createMcpExpressApp, mcpAuthMetadataRouter, requireBearerAuth } from '@modelcontextprotocol/server-express'; import type { Request, Response } from 'express'; import * as z from 'zod/v4'; @@ -590,7 +588,7 @@ if (useOAuth) { } // Map to store transports by session ID -const transports: { [sessionId: string]: StreamableHTTPServerTransport } = {}; +const transports: { [sessionId: string]: NodeStreamableHTTPServerTransport } = {}; // MCP POST endpoint with optional auth const mcpPostHandler = async (req: Request, res: Response) => { @@ -605,14 +603,14 @@ const mcpPostHandler = async (req: Request, res: Response) => { console.log('Authenticated user:', req.auth); } try { - let transport: StreamableHTTPServerTransport; + let transport: NodeStreamableHTTPServerTransport; if (sessionId && transports[sessionId]) { // Reuse existing transport transport = transports[sessionId]; } else if (!sessionId && isInitializeRequest(req.body)) { // New initialization request const eventStore = new InMemoryEventStore(); - transport = new StreamableHTTPServerTransport({ + transport = new NodeStreamableHTTPServerTransport({ sessionIdGenerator: () => randomUUID(), eventStore, // Enable resumability onsessioninitialized: sessionId => { diff --git a/examples/server/src/simpleTaskInteractive.ts b/examples/server/src/simpleTaskInteractive.ts index 956c33f8e..469ecf0c2 100644 --- a/examples/server/src/simpleTaskInteractive.ts +++ b/examples/server/src/simpleTaskInteractive.ts @@ -35,16 +35,16 @@ import type { } from '@modelcontextprotocol/server'; import { CallToolRequestSchema, - createMcpExpressApp, GetTaskPayloadRequestSchema, GetTaskRequestSchema, InMemoryTaskStore, isTerminal, ListToolsRequestSchema, + NodeStreamableHTTPServerTransport, RELATED_TASK_META_KEY, - Server, - StreamableHTTPServerTransport + Server } from '@modelcontextprotocol/server'; +import { createMcpExpressApp } from '@modelcontextprotocol/server-express'; import type { Request, Response } from 'express'; // ============================================================================ @@ -642,7 +642,7 @@ const createServer = (): Server => { const app = createMcpExpressApp(); // Map to store transports by session ID -const transports: { [sessionId: string]: StreamableHTTPServerTransport } = {}; +const transports: { [sessionId: string]: NodeStreamableHTTPServerTransport } = {}; // Helper to check if request is initialize const isInitializeRequest = (body: unknown): boolean => { @@ -654,12 +654,12 @@ app.post('/mcp', async (req: Request, res: Response) => { const sessionId = req.headers['mcp-session-id'] as string | undefined; try { - let transport: StreamableHTTPServerTransport; + let transport: NodeStreamableHTTPServerTransport; if (sessionId && transports[sessionId]) { transport = transports[sessionId]; } else if (!sessionId && isInitializeRequest(req.body)) { - transport = new StreamableHTTPServerTransport({ + transport = new NodeStreamableHTTPServerTransport({ sessionIdGenerator: () => randomUUID(), onsessioninitialized: sid => { console.log(`Session initialized: ${sid}`); diff --git a/examples/server/src/sseAndStreamableHttpCompatibleServer.ts b/examples/server/src/sseAndStreamableHttpCompatibleServer.ts index 335802d0a..bb2636ea3 100644 --- a/examples/server/src/sseAndStreamableHttpCompatibleServer.ts +++ b/examples/server/src/sseAndStreamableHttpCompatibleServer.ts @@ -1,13 +1,8 @@ import { randomUUID } from 'node:crypto'; import type { CallToolResult } from '@modelcontextprotocol/server'; -import { - createMcpExpressApp, - isInitializeRequest, - McpServer, - SSEServerTransport, - StreamableHTTPServerTransport -} from '@modelcontextprotocol/server'; +import { isInitializeRequest, McpServer, NodeStreamableHTTPServerTransport, SSEServerTransport } from '@modelcontextprotocol/server'; +import { createMcpExpressApp } from '@modelcontextprotocol/server-express'; import type { Request, Response } from 'express'; import * as z from 'zod/v4'; @@ -81,7 +76,7 @@ const getServer = () => { const app = createMcpExpressApp(); // Store transports by session ID -const transports: Record = {}; +const transports: Record = {}; //============================================================================= // STREAMABLE HTTP TRANSPORT (PROTOCOL VERSION 2025-11-25) @@ -94,16 +89,16 @@ app.all('/mcp', async (req: Request, res: Response) => { try { // Check for existing session ID const sessionId = req.headers['mcp-session-id'] as string | undefined; - let transport: StreamableHTTPServerTransport; + let transport: NodeStreamableHTTPServerTransport; if (sessionId && transports[sessionId]) { // Check if the transport is of the correct type const existingTransport = transports[sessionId]; - if (existingTransport instanceof StreamableHTTPServerTransport) { + if (existingTransport instanceof NodeStreamableHTTPServerTransport) { // Reuse existing transport transport = existingTransport; } else { - // Transport exists but is not a StreamableHTTPServerTransport (could be SSEServerTransport) + // Transport exists but is not a NodeStreamableHTTPServerTransport (could be SSEServerTransport) res.status(400).json({ jsonrpc: '2.0', error: { @@ -116,7 +111,7 @@ app.all('/mcp', async (req: Request, res: Response) => { } } else if (!sessionId && req.method === 'POST' && isInitializeRequest(req.body)) { const eventStore = new InMemoryEventStore(); - transport = new StreamableHTTPServerTransport({ + transport = new NodeStreamableHTTPServerTransport({ sessionIdGenerator: () => randomUUID(), eventStore, // Enable resumability onsessioninitialized: sessionId => { @@ -191,7 +186,7 @@ app.post('/messages', async (req: Request, res: Response) => { // Reuse existing transport transport = existingTransport; } else { - // Transport exists but is not a SSEServerTransport (could be StreamableHTTPServerTransport) + // Transport exists but is not a SSEServerTransport (could be NodeStreamableHTTPServerTransport) res.status(400).json({ jsonrpc: '2.0', error: { diff --git a/examples/server/src/ssePollingExample.ts b/examples/server/src/ssePollingExample.ts index 4e3d36328..4d0841dee 100644 --- a/examples/server/src/ssePollingExample.ts +++ b/examples/server/src/ssePollingExample.ts @@ -15,7 +15,8 @@ import { randomUUID } from 'node:crypto'; import type { CallToolResult } from '@modelcontextprotocol/server'; -import { createMcpExpressApp, McpServer, StreamableHTTPServerTransport } from '@modelcontextprotocol/server'; +import { McpServer, NodeStreamableHTTPServerTransport } from '@modelcontextprotocol/server'; +import { createMcpExpressApp } from '@modelcontextprotocol/server-express'; import cors from 'cors'; import type { Request, Response } from 'express'; @@ -111,7 +112,7 @@ app.use(cors()); const eventStore = new InMemoryEventStore(); // Track transports by session ID for session reuse -const transports = new Map(); +const transports = new Map(); // Handle all MCP requests app.all('/mcp', async (req: Request, res: Response) => { @@ -121,7 +122,7 @@ app.all('/mcp', async (req: Request, res: Response) => { let transport = sessionId ? transports.get(sessionId) : undefined; if (!transport) { - transport = new StreamableHTTPServerTransport({ + transport = new NodeStreamableHTTPServerTransport({ sessionIdGenerator: () => randomUUID(), eventStore, retryInterval: 2000, // Default retry interval for priming events diff --git a/examples/server/src/standaloneSseWithGetStreamableHttp.ts b/examples/server/src/standaloneSseWithGetStreamableHttp.ts index cceb24299..869d7e859 100644 --- a/examples/server/src/standaloneSseWithGetStreamableHttp.ts +++ b/examples/server/src/standaloneSseWithGetStreamableHttp.ts @@ -1,7 +1,8 @@ import { randomUUID } from 'node:crypto'; import type { ReadResourceResult } from '@modelcontextprotocol/server'; -import { createMcpExpressApp, isInitializeRequest, McpServer, StreamableHTTPServerTransport } from '@modelcontextprotocol/server'; +import { isInitializeRequest, McpServer, NodeStreamableHTTPServerTransport } from '@modelcontextprotocol/server'; +import { createMcpExpressApp } from '@modelcontextprotocol/server-express'; import type { Request, Response } from 'express'; // Create an MCP server with implementation details @@ -11,7 +12,7 @@ const server = new McpServer({ }); // Store transports by session ID to send notifications -const transports: { [sessionId: string]: StreamableHTTPServerTransport } = {}; +const transports: { [sessionId: string]: NodeStreamableHTTPServerTransport } = {}; const addResource = (name: string, content: string) => { const uri = `https://mcp-example.com/dynamic/${encodeURIComponent(name)}`; @@ -41,14 +42,14 @@ app.post('/mcp', async (req: Request, res: Response) => { try { // Check for existing session ID const sessionId = req.headers['mcp-session-id'] as string | undefined; - let transport: StreamableHTTPServerTransport; + let transport: NodeStreamableHTTPServerTransport; if (sessionId && transports[sessionId]) { // Reuse existing transport transport = transports[sessionId]; } else if (!sessionId && isInitializeRequest(req.body)) { // New initialization request - transport = new StreamableHTTPServerTransport({ + transport = new NodeStreamableHTTPServerTransport({ sessionIdGenerator: () => randomUUID(), onsessioninitialized: sessionId => { // Store the transport by session ID when session is initialized diff --git a/examples/server/tsconfig.json b/examples/server/tsconfig.json index 98d3a5b3f..1f72b0199 100644 --- a/examples/server/tsconfig.json +++ b/examples/server/tsconfig.json @@ -6,6 +6,8 @@ "paths": { "*": ["./*"], "@modelcontextprotocol/server": ["./node_modules/@modelcontextprotocol/server/src/index.ts"], + "@modelcontextprotocol/server-express": ["./node_modules/@modelcontextprotocol/server-express/src/index.ts"], + "@modelcontextprotocol/server-hono": ["./node_modules/@modelcontextprotocol/server-hono/src/index.ts"], "@modelcontextprotocol/core": [ "./node_modules/@modelcontextprotocol/server/node_modules/@modelcontextprotocol/core/src/index.ts" ], diff --git a/examples/shared/package.json b/examples/shared/package.json index 8287ca552..2d0f6ebe9 100644 --- a/examples/shared/package.json +++ b/examples/shared/package.json @@ -35,6 +35,7 @@ }, "dependencies": { "@modelcontextprotocol/server": "workspace:^", + "@modelcontextprotocol/server-express": "workspace:^", "express": "catalog:runtimeServerOnly" }, "devDependencies": { diff --git a/examples/shared/src/demoInMemoryOAuthProvider.ts b/examples/shared/src/demoInMemoryOAuthProvider.ts index bcf11dd0c..23b168224 100644 --- a/examples/shared/src/demoInMemoryOAuthProvider.ts +++ b/examples/shared/src/demoInMemoryOAuthProvider.ts @@ -9,8 +9,9 @@ import type { OAuthServerProvider, OAuthTokens } from '@modelcontextprotocol/server'; -import { createOAuthMetadata, InvalidRequestError, mcpAuthRouter, resourceUrlFromServerUrl } from '@modelcontextprotocol/server'; -import type { Request, Response } from 'express'; +import { createOAuthMetadata, InvalidRequestError, resourceUrlFromServerUrl } from '@modelcontextprotocol/server'; +import { mcpAuthRouter } from '@modelcontextprotocol/server-express'; +import type { Request, Response as ExpressResponse } from 'express'; import express from 'express'; export class DemoInMemoryClientsStore implements OAuthRegisteredClientsStore { @@ -47,7 +48,7 @@ export class DemoInMemoryAuthProvider implements OAuthServerProvider { constructor(private validateResource?: (resource?: URL) => boolean) {} - async authorize(client: OAuthClientInformationFull, params: AuthorizationParams, res: Response): Promise { + async authorize(client: OAuthClientInformationFull, params: AuthorizationParams): Promise { const code = randomUUID(); const searchParams = new URLSearchParams({ @@ -64,27 +65,24 @@ export class DemoInMemoryAuthProvider implements OAuthServerProvider { // Simulate a user login // Set a secure HTTP-only session cookie with authorization info - if (res.cookie) { - const authCookieData = { - userId: 'demo_user', - name: 'Demo User', - timestamp: Date.now() - }; - res.cookie('demo_session', JSON.stringify(authCookieData), { - httpOnly: true, - secure: false, // In production, this should be true - sameSite: 'lax', - maxAge: 24 * 60 * 60 * 1000, // 24 hours - for demo purposes - path: '/' // Available to all routes - }); - } + const authCookieData = { + userId: 'demo_user', + name: 'Demo User', + timestamp: Date.now() + }; + const cookieValue = encodeURIComponent(JSON.stringify(authCookieData)); + const maxAgeSeconds = 24 * 60 * 60; // 24 hours - demo only + const setCookie = `demo_session=${cookieValue}; HttpOnly; SameSite=Lax; Max-Age=${maxAgeSeconds}; Path=/`; if (!client.redirect_uris.includes(params.redirectUri)) { throw new InvalidRequestError('Unregistered redirect_uri'); } const targetUrl = new URL(params.redirectUri); targetUrl.search = searchParams.toString(); - res.redirect(targetUrl.toString()); + const redirectResponse = Response.redirect(targetUrl.toString(), 302); + const headers = new Headers(redirectResponse.headers); + headers.append('Set-Cookie', setCookie); + return new Response(null, { status: redirectResponse.status, headers }); } async challengeForAuthorizationCode(client: OAuthClientInformationFull, authorizationCode: string): Promise { @@ -204,7 +202,7 @@ export const setupAuthServer = ({ }) ); - authApp.post('/introspect', async (req: Request, res: Response) => { + authApp.post('/introspect', async (req: Request, res: ExpressResponse) => { try { const { token } = req.body; if (!token) { diff --git a/examples/shared/test/demoInMemoryOAuthProvider.test.ts b/examples/shared/test/demoInMemoryOAuthProvider.test.ts index 4018dddbe..0c1c887aa 100644 --- a/examples/shared/test/demoInMemoryOAuthProvider.test.ts +++ b/examples/shared/test/demoInMemoryOAuthProvider.test.ts @@ -1,21 +1,15 @@ import type { OAuthClientInformationFull } from '@modelcontextprotocol/core'; import type { AuthorizationParams } from '@modelcontextprotocol/server'; import { InvalidRequestError } from '@modelcontextprotocol/server'; -import { createExpressResponseMock } from '@modelcontextprotocol/test-helpers'; -import type { Response } from 'express'; import { beforeEach, describe, expect, it } from 'vitest'; -import { DemoInMemoryAuthProvider, DemoInMemoryClientsStore } from '../src/demoInMemoryOAuthProvider.js'; +import { DemoInMemoryAuthProvider } from '../src/demoInMemoryOAuthProvider.js'; describe('DemoInMemoryAuthProvider', () => { let provider: DemoInMemoryAuthProvider; - let mockResponse: Response & { getRedirectUrl: () => string }; beforeEach(() => { provider = new DemoInMemoryAuthProvider(); - mockResponse = createExpressResponseMock({ trackRedirectUrl: true }) as Response & { - getRedirectUrl: () => string; - }; }); describe('authorize', () => { @@ -26,7 +20,7 @@ describe('DemoInMemoryAuthProvider', () => { scope: 'test-scope' }; - it('should redirect to the requested redirect_uri when valid', async () => { + it('redirects to redirect_uri when valid', async () => { const params: AuthorizationParams = { redirectUri: 'https://example.com/callback', state: 'test-state', @@ -34,18 +28,18 @@ describe('DemoInMemoryAuthProvider', () => { scopes: ['test-scope'] }; - await provider.authorize(validClient, params, mockResponse); - - expect(mockResponse.redirect).toHaveBeenCalled(); - expect(mockResponse.getRedirectUrl()).toBeDefined(); - - const url = new URL(mockResponse.getRedirectUrl()); + const res = await provider.authorize(validClient, params); + expect(res.status).toBe(302); + const location = res.headers.get('location'); + expect(location).toBeTruthy(); + const url = new URL(location!); expect(url.origin + url.pathname).toBe('https://example.com/callback'); expect(url.searchParams.get('state')).toBe('test-state'); - expect(url.searchParams.has('code')).toBe(true); + expect(url.searchParams.get('code')).toBeTruthy(); + expect(res.headers.get('set-cookie')).toContain('demo_session='); }); - it('should throw InvalidRequestError for unregistered redirect_uri', async () => { + it('throws InvalidRequestError for unregistered redirect_uri', async () => { const params: AuthorizationParams = { redirectUri: 'https://evil.com/callback', state: 'test-state', @@ -53,212 +47,8 @@ describe('DemoInMemoryAuthProvider', () => { scopes: ['test-scope'] }; - await expect(provider.authorize(validClient, params, mockResponse)).rejects.toThrow(InvalidRequestError); - - await expect(provider.authorize(validClient, params, mockResponse)).rejects.toThrow('Unregistered redirect_uri'); - - expect(mockResponse.redirect).not.toHaveBeenCalled(); - }); - - it('should generate unique authorization codes for multiple requests', async () => { - const params1: AuthorizationParams = { - redirectUri: 'https://example.com/callback', - state: 'state-1', - codeChallenge: 'challenge-1', - scopes: ['test-scope'] - }; - - const params2: AuthorizationParams = { - redirectUri: 'https://example.com/callback', - state: 'state-2', - codeChallenge: 'challenge-2', - scopes: ['test-scope'] - }; - - await provider.authorize(validClient, params1, mockResponse); - const firstRedirectUrl = mockResponse.getRedirectUrl(); - const firstCode = new URL(firstRedirectUrl).searchParams.get('code'); - - // Reset the mock for the second call - mockResponse = createExpressResponseMock({ trackRedirectUrl: true }) as Response & { - getRedirectUrl: () => string; - }; - await provider.authorize(validClient, params2, mockResponse); - const secondRedirectUrl = mockResponse.getRedirectUrl(); - const secondCode = new URL(secondRedirectUrl).searchParams.get('code'); - - expect(firstCode).toBeDefined(); - expect(secondCode).toBeDefined(); - expect(firstCode).not.toBe(secondCode); - }); - - it('should handle params without state', async () => { - const params: AuthorizationParams = { - redirectUri: 'https://example.com/callback', - codeChallenge: 'test-challenge', - scopes: ['test-scope'] - }; - - await provider.authorize(validClient, params, mockResponse); - - expect(mockResponse.redirect).toHaveBeenCalled(); - expect(mockResponse.getRedirectUrl()).toBeDefined(); - - const url = new URL(mockResponse.getRedirectUrl()); - expect(url.searchParams.has('state')).toBe(false); - expect(url.searchParams.has('code')).toBe(true); - }); - }); - - describe('challengeForAuthorizationCode', () => { - const validClient: OAuthClientInformationFull = { - client_id: 'test-client', - client_secret: 'test-secret', - redirect_uris: ['https://example.com/callback'], - scope: 'test-scope' - }; - - it('should return the code challenge for a valid authorization code', async () => { - const params: AuthorizationParams = { - redirectUri: 'https://example.com/callback', - state: 'test-state', - codeChallenge: 'test-challenge-value', - scopes: ['test-scope'] - }; - - await provider.authorize(validClient, params, mockResponse); - const code = new URL(mockResponse.getRedirectUrl()).searchParams.get('code')!; - - const challenge = await provider.challengeForAuthorizationCode(validClient, code); - expect(challenge).toBe('test-challenge-value'); - }); - - it('should throw error for invalid authorization code', async () => { - await expect(provider.challengeForAuthorizationCode(validClient, 'invalid-code')).rejects.toThrow('Invalid authorization code'); - }); - }); - - describe('exchangeAuthorizationCode', () => { - const validClient: OAuthClientInformationFull = { - client_id: 'test-client', - client_secret: 'test-secret', - redirect_uris: ['https://example.com/callback'], - scope: 'test-scope' - }; - - it('should exchange valid authorization code for tokens', async () => { - const params: AuthorizationParams = { - redirectUri: 'https://example.com/callback', - state: 'test-state', - codeChallenge: 'test-challenge', - scopes: ['test-scope', 'other-scope'] - }; - - await provider.authorize(validClient, params, mockResponse); - const code = new URL(mockResponse.getRedirectUrl()).searchParams.get('code')!; - - const tokens = await provider.exchangeAuthorizationCode(validClient, code); - - expect(tokens).toEqual({ - access_token: expect.any(String), - token_type: 'bearer', - expires_in: 3600, - scope: 'test-scope other-scope' - }); - }); - - it('should throw error for invalid authorization code', async () => { - await expect(provider.exchangeAuthorizationCode(validClient, 'invalid-code')).rejects.toThrow('Invalid authorization code'); - }); - - it('should throw error when client_id does not match', async () => { - const params: AuthorizationParams = { - redirectUri: 'https://example.com/callback', - state: 'test-state', - codeChallenge: 'test-challenge', - scopes: ['test-scope'] - }; - - await provider.authorize(validClient, params, mockResponse); - const code = new URL(mockResponse.getRedirectUrl()).searchParams.get('code')!; - - const differentClient: OAuthClientInformationFull = { - client_id: 'different-client', - client_secret: 'different-secret', - redirect_uris: ['https://example.com/callback'], - scope: 'test-scope' - }; - - await expect(provider.exchangeAuthorizationCode(differentClient, code)).rejects.toThrow( - 'Authorization code was not issued to this client' - ); - }); - - it('should delete authorization code after successful exchange', async () => { - const params: AuthorizationParams = { - redirectUri: 'https://example.com/callback', - state: 'test-state', - codeChallenge: 'test-challenge', - scopes: ['test-scope'] - }; - - await provider.authorize(validClient, params, mockResponse); - const code = new URL(mockResponse.getRedirectUrl()).searchParams.get('code')!; - - // First exchange should succeed - await provider.exchangeAuthorizationCode(validClient, code); - - // Second exchange should fail - await expect(provider.exchangeAuthorizationCode(validClient, code)).rejects.toThrow('Invalid authorization code'); - }); - - it('should validate resource when validateResource is provided', async () => { - const validateResource = vi.fn().mockReturnValue(false); - const strictProvider = new DemoInMemoryAuthProvider(validateResource); - - const params: AuthorizationParams = { - redirectUri: 'https://example.com/callback', - state: 'test-state', - codeChallenge: 'test-challenge', - scopes: ['test-scope'], - resource: new URL('https://invalid-resource.com') - }; - - await strictProvider.authorize(validClient, params, mockResponse); - const code = new URL(mockResponse.getRedirectUrl()).searchParams.get('code')!; - - await expect(strictProvider.exchangeAuthorizationCode(validClient, code)).rejects.toThrow( - 'Invalid resource: https://invalid-resource.com/' - ); - - expect(validateResource).toHaveBeenCalledWith(params.resource); - }); - }); - - describe('DemoInMemoryClientsStore', () => { - let store: DemoInMemoryClientsStore; - - beforeEach(() => { - store = new DemoInMemoryClientsStore(); - }); - - it('should register and retrieve client', async () => { - const client: OAuthClientInformationFull = { - client_id: 'test-client', - client_secret: 'test-secret', - redirect_uris: ['https://example.com/callback'], - scope: 'test-scope' - }; - - await store.registerClient(client); - const retrieved = await store.getClient('test-client'); - - expect(retrieved).toEqual(client); - }); - - it('should return undefined for non-existent client', async () => { - const retrieved = await store.getClient('non-existent'); - expect(retrieved).toBeUndefined(); + await expect(provider.authorize(validClient, params)).rejects.toThrow(InvalidRequestError); + await expect(provider.authorize(validClient, params)).rejects.toThrow('Unregistered redirect_uri'); }); }); }); diff --git a/examples/shared/tsconfig.json b/examples/shared/tsconfig.json index aa994f939..91e368e7a 100644 --- a/examples/shared/tsconfig.json +++ b/examples/shared/tsconfig.json @@ -6,6 +6,7 @@ "paths": { "*": ["./*"], "@modelcontextprotocol/server": ["./node_modules/@modelcontextprotocol/server/src/index.ts"], + "@modelcontextprotocol/server-express": ["./node_modules/@modelcontextprotocol/server-express/src/index.ts"], "@modelcontextprotocol/core": [ "./node_modules/@modelcontextprotocol/server/node_modules/@modelcontextprotocol/core/src/index.ts" ], diff --git a/package.json b/package.json index 2633d5ef2..0b25f9f4e 100644 --- a/package.json +++ b/package.json @@ -15,7 +15,7 @@ "node": ">=20", "pnpm": ">=10.24.0" }, - "packageManager": "pnpm@10.24.0", + "packageManager": "pnpm@10.26.1", "keywords": [ "modelcontextprotocol", "mcp" diff --git a/packages/client/src/client/sse.ts b/packages/client/src/client/sse.ts index bff74986e..4cfa77e17 100644 --- a/packages/client/src/client/sse.ts +++ b/packages/client/src/client/sse.ts @@ -1,6 +1,7 @@ import type { FetchLike, JSONRPCMessage, Transport } from '@modelcontextprotocol/core'; import { createFetchWithInit, JSONRPCMessageSchema, normalizeHeaders } from '@modelcontextprotocol/core'; -import { type ErrorEvent, EventSource, type EventSourceInit } from 'eventsource'; +import type { ErrorEvent, EventSourceInit } from 'eventsource'; +import { EventSource } from 'eventsource'; import type { AuthResult, OAuthClientProvider } from './auth.js'; import { auth, extractWWWAuthenticateParams, UnauthorizedError } from './auth.js'; diff --git a/packages/core/src/shared/protocol.ts b/packages/core/src/shared/protocol.ts index 9c65015d1..c9242e96d 100644 --- a/packages/core/src/shared/protocol.ts +++ b/packages/core/src/shared/protocol.ts @@ -292,14 +292,14 @@ export type RequestHandlerExtra void; /** * Closes the standalone GET SSE stream, triggering client reconnection. - * Only available when using StreamableHTTPServerTransport with eventStore configured. + * Only available when using aStreamableHTTPServerTransport with eventStore configured. * Use this to implement polling behavior for server-initiated notifications. */ closeStandaloneSSEStream?: () => void; diff --git a/packages/core/src/types/types.ts b/packages/core/src/types/types.ts index 35b04745d..f3e1b92a8 100644 --- a/packages/core/src/types/types.ts +++ b/packages/core/src/types/types.ts @@ -2415,13 +2415,13 @@ export interface MessageExtraInfo { /** * Callback to close the SSE stream for this request, triggering client reconnection. - * Only available when using StreamableHTTPServerTransport with eventStore configured. + * Only available when using NodeStreamableHTTPServerTransport with eventStore configured. */ closeSSEStream?: () => void; /** * Callback to close the standalone GET SSE stream, triggering client reconnection. - * Only available when using StreamableHTTPServerTransport with eventStore configured. + * Only available when using NodeStreamableHTTPServerTransport with eventStore configured. */ closeStandaloneSSEStream?: () => void; } diff --git a/packages/server-express/README.md b/packages/server-express/README.md new file mode 100644 index 000000000..27fb348d7 --- /dev/null +++ b/packages/server-express/README.md @@ -0,0 +1,83 @@ +# `@modelcontextprotocol/server-express` + +Express adapters for the MCP TypeScript server SDK. + +This package is the Express-specific companion to [`@modelcontextprotocol/server`](../server/), which is framework-agnostic and uses Web Standard `Request`/`Response` interfaces. + +## Install + +```bash +npm install @modelcontextprotocol/server @modelcontextprotocol/server-express zod +``` + +## Exports + +- `createMcpExpressApp(options?)` +- `hostHeaderValidation(allowedHosts)` +- `localhostHostValidation()` +- `mcpAuthRouter(options)` +- `mcpAuthMetadataRouter(options)` +- `requireBearerAuth(options)` + +## Usage + +### Create an Express app with localhost DNS rebinding protection + +```ts +import { createMcpExpressApp } from '@modelcontextprotocol/server-express'; + +const app = createMcpExpressApp(); // default host is 127.0.0.1; protection enabled +``` + +### Streamable HTTP endpoint (Express) + +```ts +import { McpServer, NodeStreamableHTTPServerTransport } from '@modelcontextprotocol/server'; +import { createMcpExpressApp } from '@modelcontextprotocol/server-express'; + +const app = createMcpExpressApp(); + +app.post('/mcp', async (req, res) => { + const transport = new NodeStreamableHTTPServerTransport(); + await transport.handleRequest(req, res, req.body); +}); +``` + +### OAuth routes (Express) + +`@modelcontextprotocol/server` provides Web-standard auth handlers; this package wraps them as Express routers. + +```ts +import { mcpAuthRouter } from '@modelcontextprotocol/server-express'; +import type { OAuthServerProvider } from '@modelcontextprotocol/server'; +import express from 'express'; + +const provider: OAuthServerProvider = /* ... */; +const app = express(); +app.use(express.json()); + +// MUST be mounted at the app root +app.use( + mcpAuthRouter({ + provider, + issuerUrl: new URL('https://auth.example.com'), + // Optional rate limiting (implemented via express-rate-limit) + rateLimit: { windowMs: 60_000, max: 60 } + }) +); +``` + +### Bearer auth middleware (Express) + +`requireBearerAuth` validates the `Authorization: Bearer ...` header and sets `req.auth` on success. + +```ts +import { requireBearerAuth } from '@modelcontextprotocol/server-express'; +import type { OAuthTokenVerifier } from '@modelcontextprotocol/server'; + +const verifier: OAuthTokenVerifier = /* ... */; + +app.post('/protected', requireBearerAuth({ verifier }), (req, res) => { + res.json({ clientId: req.auth?.clientId }); +}); +``` diff --git a/packages/server-express/eslint.config.mjs b/packages/server-express/eslint.config.mjs new file mode 100644 index 000000000..03d533134 --- /dev/null +++ b/packages/server-express/eslint.config.mjs @@ -0,0 +1,12 @@ +// @ts-check + +import baseConfig from '@modelcontextprotocol/eslint-config'; + +export default [ + ...baseConfig, + { + settings: { + 'import/internal-regex': '^@modelcontextprotocol/(server|core)' + } + } +]; diff --git a/packages/server-express/package.json b/packages/server-express/package.json new file mode 100644 index 000000000..bca9ac505 --- /dev/null +++ b/packages/server-express/package.json @@ -0,0 +1,70 @@ +{ + "name": "@modelcontextprotocol/server-express", + "private": false, + "version": "2.0.0-alpha.0", + "description": "Express adapters for the Model Context Protocol TypeScript server SDK", + "license": "MIT", + "author": "Anthropic, PBC (https://anthropic.com)", + "homepage": "https://modelcontextprotocol.io", + "bugs": "https://github.com/modelcontextprotocol/typescript-sdk/issues", + "type": "module", + "repository": { + "type": "git", + "url": "git+https://github.com/modelcontextprotocol/typescript-sdk.git" + }, + "engines": { + "node": ">=20", + "pnpm": ">=10.24.0" + }, + "packageManager": "pnpm@10.24.0", + "keywords": [ + "modelcontextprotocol", + "mcp", + "express" + ], + "exports": { + ".": { + "types": "./dist/index.d.mts", + "import": "./dist/index.mjs" + } + }, + "files": [ + "dist" + ], + "scripts": { + "typecheck": "tsgo -p tsconfig.json --noEmit", + "build": "tsdown", + "build:watch": "tsdown --watch", + "prepack": "npm run build", + "lint": "eslint src/ && prettier --ignore-path ../../.prettierignore --check .", + "lint:fix": "eslint src/ --fix && prettier --ignore-path ../../.prettierignore --write .", + "check": "npm run typecheck && npm run lint", + "test": "vitest run", + "test:watch": "vitest" + }, + "dependencies": { + "@modelcontextprotocol/server": "workspace:^", + "express": "catalog:runtimeServerOnly", + "express-rate-limit": "catalog:runtimeServerOnly", + "@remix-run/node-fetch-server": "catalog:runtimeServerOnly" + }, + "devDependencies": { + "@modelcontextprotocol/tsconfig": "workspace:^", + "@modelcontextprotocol/vitest-config": "workspace:^", + "@modelcontextprotocol/eslint-config": "workspace:^", + "@eslint/js": "catalog:devTools", + "@types/express": "catalog:devTools", + "@types/express-serve-static-core": "catalog:devTools", + "@types/supertest": "catalog:devTools", + "@typescript/native-preview": "catalog:devTools", + "eslint": "catalog:devTools", + "eslint-config-prettier": "catalog:devTools", + "eslint-plugin-n": "catalog:devTools", + "prettier": "catalog:devTools", + "supertest": "catalog:devTools", + "tsdown": "catalog:devTools", + "typescript": "catalog:devTools", + "typescript-eslint": "catalog:devTools", + "vitest": "catalog:devTools" + } +} diff --git a/packages/server-express/src/auth/bearerAuth.ts b/packages/server-express/src/auth/bearerAuth.ts new file mode 100644 index 000000000..d8d0aad8b --- /dev/null +++ b/packages/server-express/src/auth/bearerAuth.ts @@ -0,0 +1,54 @@ +import { URL } from 'node:url'; + +import type { AuthInfo } from '@modelcontextprotocol/core'; +import type { BearerAuthMiddlewareOptions } from '@modelcontextprotocol/server'; +import { requireBearerAuth as requireBearerAuthWeb } from '@modelcontextprotocol/server'; +import { sendResponse } from '@remix-run/node-fetch-server'; +import type { NextFunction, Request as ExpressRequest, RequestHandler } from 'express'; + +declare module 'express-serve-static-core' { + interface Request { + /** + * Information about the validated access token, if `requireBearerAuth` was used. + */ + auth?: AuthInfo; + } +} + +function expressRequestUrl(req: ExpressRequest): URL { + const host = req.get('host') ?? req.headers.host ?? 'localhost'; + const protocol = req.protocol ?? 'http'; + const path = req.originalUrl ?? req.url ?? '/'; + return new URL(path, `${protocol}://${host}`); +} + +/** + * Express middleware wrapper for the Web-standard `requireBearerAuth` helper. + * + * On success, sets `req.auth` and calls `next()`. + * On failure, writes the JSON error response and ends the request. + */ +export function requireBearerAuth(options: BearerAuthMiddlewareOptions): RequestHandler { + return async (req, res, next: NextFunction) => { + try { + const url = expressRequestUrl(req); + const webReq = new Request(url, { + method: req.method, + headers: { + authorization: req.headers.authorization ?? '' + } + }); + + const result = await requireBearerAuthWeb(webReq, options); + if ('authInfo' in result) { + req.auth = result.authInfo; + next(); + return; + } + + await sendResponse(res, result.response); + } catch (err) { + next(err); + } + }; +} diff --git a/packages/server-express/src/auth/router.ts b/packages/server-express/src/auth/router.ts new file mode 100644 index 000000000..868149efd --- /dev/null +++ b/packages/server-express/src/auth/router.ts @@ -0,0 +1,101 @@ +import type { AuthMetadataOptions, AuthRouterOptions } from '@modelcontextprotocol/server'; +import { + getParsedBody, + mcpAuthMetadataRouter as createWebAuthMetadataRouter, + mcpAuthRouter as createWebAuthRouter, + TooManyRequestsError +} from '@modelcontextprotocol/server'; +import { createRequest, sendResponse } from '@remix-run/node-fetch-server'; +import type { RequestHandler } from 'express'; +import express from 'express'; +import { rateLimit } from 'express-rate-limit'; + +export type ExpressAuthRateLimitOptions = + | false + | { + /** + * Window size in ms (default: 60s) + */ + windowMs?: number; + /** + * Max requests per window per client (default: 60) + */ + max?: number; + }; + +/** + * Express router adapter for the Web-standard `mcpAuthRouter` from `@modelcontextprotocol/server`. + * + * IMPORTANT: This router MUST be mounted at the application root, like: + * + * ```ts + * app.use(mcpAuthRouter(...)) + * ``` + */ +export function mcpAuthRouter(options: AuthRouterOptions & { rateLimit?: ExpressAuthRateLimitOptions }): RequestHandler { + const web = createWebAuthRouter(options); + const router = express.Router(); + + const rateLimitOptions = options.rateLimit; + const limiter = + rateLimitOptions === false + ? undefined + : rateLimit({ + windowMs: rateLimitOptions?.windowMs ?? 60_000, + max: rateLimitOptions?.max ?? 60, + standardHeaders: true, + legacyHeaders: false, + handler: (_req, res) => { + const err = new TooManyRequestsError('Too many requests'); + res.status(429).json(err.toResponseObject()); + } + }); + + const isRateLimitedPath = (path: string): boolean => + path === '/authorize' || path === '/token' || path === '/register' || path === '/revoke'; + + for (const route of web.routes) { + const handlers: RequestHandler[] = []; + if (limiter && isRateLimitedPath(route.path)) { + handlers.push(limiter); + } + handlers.push(async (req, res, next) => { + try { + const webReq = createRequest(req, res); + const parsedBody = req.body !== undefined ? req.body : await getParsedBody(webReq); + const webRes = await route.handler(webReq, { parsedBody }); + await sendResponse(res, webRes); + } catch (err) { + next(err); + } + }); + router.all(route.path, ...handlers); + } + + return router; +} + +/** + * Express router adapter for the Web-standard `mcpAuthMetadataRouter` from `@modelcontextprotocol/server`. + * + * IMPORTANT: This router MUST be mounted at the application root. + */ +export function mcpAuthMetadataRouter(options: AuthMetadataOptions): RequestHandler { + const web = createWebAuthMetadataRouter(options); + const router = express.Router(); + + for (const route of web.routes) { + router.all(route.path, async (req, res, next) => { + try { + const webReq = createRequest(req, res); + const parsedBody = req.body !== undefined ? req.body : await getParsedBody(webReq); + const webRes = await route.handler(webReq, { parsedBody }); + await sendResponse(res, webRes); + } catch (err) { + next(err); + } + }); + } + + return router; +} diff --git a/packages/server/src/server/express.ts b/packages/server-express/src/express.ts similarity index 100% rename from packages/server/src/server/express.ts rename to packages/server-express/src/express.ts diff --git a/packages/server-express/src/index.ts b/packages/server-express/src/index.ts new file mode 100644 index 000000000..3c5b72fe7 --- /dev/null +++ b/packages/server-express/src/index.ts @@ -0,0 +1,4 @@ +export * from './auth/bearerAuth.js'; +export * from './auth/router.js'; +export * from './express.js'; +export * from './middleware/hostHeaderValidation.js'; diff --git a/packages/server-express/src/middleware/hostHeaderValidation.ts b/packages/server-express/src/middleware/hostHeaderValidation.ts new file mode 100644 index 000000000..00ee74e1f --- /dev/null +++ b/packages/server-express/src/middleware/hostHeaderValidation.ts @@ -0,0 +1,52 @@ +import { localhostAllowedHostnames, validateHostHeader } from '@modelcontextprotocol/server'; +import type { NextFunction, Request, RequestHandler, Response } from 'express'; + +/** + * Express middleware for DNS rebinding protection. + * Validates Host header hostname (port-agnostic) against an allowed list. + * + * This is particularly important for servers without authorization or HTTPS, + * such as localhost servers or development servers. DNS rebinding attacks can + * bypass same-origin policy by manipulating DNS to point a domain to a + * localhost address, allowing malicious websites to access your local server. + * + * @param allowedHostnames - List of allowed hostnames (without ports). + * For IPv6, provide the address with brackets (e.g., '[::1]'). + * @returns Express middleware function + * + * @example + * ```typescript + * const middleware = hostHeaderValidation(['localhost', '127.0.0.1', '[::1]']); + * app.use(middleware); + * ``` + */ +export function hostHeaderValidation(allowedHostnames: string[]): RequestHandler { + return (req: Request, res: Response, next: NextFunction) => { + const result = validateHostHeader(req.headers.host, allowedHostnames); + if (!result.ok) { + res.status(403).json({ + jsonrpc: '2.0', + error: { + code: -32000, + message: result.message + }, + id: null + }); + return; + } + next(); + }; +} + +/** + * Convenience middleware for localhost DNS rebinding protection. + * Allows only localhost, 127.0.0.1, and [::1] (IPv6 localhost) hostnames. + * + * @example + * ```typescript + * app.use(localhostHostValidation()); + * ``` + */ +export function localhostHostValidation(): RequestHandler { + return hostHeaderValidation(localhostAllowedHostnames()); +} diff --git a/packages/server/test/server/auth/router.test.ts b/packages/server-express/test/server/auth/router.test.ts similarity index 89% rename from packages/server/test/server/auth/router.test.ts rename to packages/server-express/test/server/auth/router.test.ts index 250fca4c4..9d7638543 100644 --- a/packages/server/test/server/auth/router.test.ts +++ b/packages/server-express/test/server/auth/router.test.ts @@ -1,14 +1,18 @@ -import type { OAuthClientInformationFull, OAuthMetadata, OAuthTokenRevocationRequest, OAuthTokens } from '@modelcontextprotocol/core'; -import type { AuthInfo } from '@modelcontextprotocol/core'; -import { InvalidTokenError } from '@modelcontextprotocol/core'; -import type { Response } from 'express'; +import type { + AuthInfo, + OAuthClientInformationFull, + OAuthMetadata, + OAuthTokenRevocationRequest, + OAuthTokens +} from '@modelcontextprotocol/server'; +import { InvalidTokenError } from '@modelcontextprotocol/server'; import express from 'express'; import supertest from 'supertest'; -import type { OAuthRegisteredClientsStore } from '../../../src/server/auth/clients.js'; -import type { AuthorizationParams, OAuthServerProvider } from '../../../src/server/auth/provider.js'; -import type { AuthMetadataOptions, AuthRouterOptions } from '../../../src/server/auth/router.js'; -import { mcpAuthMetadataRouter, mcpAuthRouter } from '../../../src/server/auth/router.js'; +import type { OAuthRegisteredClientsStore } from '@modelcontextprotocol/server'; +import type { AuthorizationParams, OAuthServerProvider } from '@modelcontextprotocol/server'; +import type { AuthMetadataOptions, AuthRouterOptions } from '@modelcontextprotocol/server'; +import { mcpAuthMetadataRouter, mcpAuthRouter } from '../../../src/auth/router.js'; describe('MCP Auth Router', () => { // Setup mock provider with full capabilities @@ -32,13 +36,13 @@ describe('MCP Auth Router', () => { const mockProvider: OAuthServerProvider = { clientsStore: mockClientStore, - async authorize(client: OAuthClientInformationFull, params: AuthorizationParams, res: Response): Promise { + async authorize(_client: OAuthClientInformationFull, params: AuthorizationParams): Promise { const redirectUrl = new URL(params.redirectUri); redirectUrl.searchParams.set('code', 'mock_auth_code'); if (params.state) { redirectUrl.searchParams.set('state', params.state); } - res.redirect(302, redirectUrl.toString()); + return Response.redirect(redirectUrl.toString(), 302); }, async challengeForAuthorizationCode(): Promise { @@ -95,13 +99,13 @@ describe('MCP Auth Router', () => { } }, - async authorize(client: OAuthClientInformationFull, params: AuthorizationParams, res: Response): Promise { + async authorize(_client: OAuthClientInformationFull, params: AuthorizationParams): Promise { const redirectUrl = new URL(params.redirectUri); redirectUrl.searchParams.set('code', 'mock_auth_code'); if (params.state) { redirectUrl.searchParams.set('state', params.state); } - res.redirect(302, redirectUrl.toString()); + return Response.redirect(redirectUrl.toString(), 302); }, async challengeForAuthorizationCode(): Promise { @@ -321,6 +325,36 @@ describe('MCP Auth Router', () => { expect(response.status).not.toBe(404); }); + it('applies rate limiting to token endpoint (express-rate-limit)', async () => { + // Fresh app with a very low rate limit so we can trigger it deterministically + const limitedApp = express(); + const options = { + provider: mockProvider, + issuerUrl: new URL('https://auth.example.com'), + rateLimit: { windowMs: 60_000, max: 1 } + } as const; + limitedApp.use(mcpAuthRouter(options)); + + const first = await supertest(limitedApp).post('/token').type('form').send({ + client_id: 'valid-client', + client_secret: 'valid-secret', + grant_type: 'authorization_code', + code: 'valid_code', + code_verifier: 'valid_verifier' + }); + expect(first.status).not.toBe(404); + + const second = await supertest(limitedApp).post('/token').type('form').send({ + client_id: 'valid-client', + client_secret: 'valid-secret', + grant_type: 'authorization_code', + code: 'valid_code', + code_verifier: 'valid_verifier' + }); + expect(second.status).toBe(429); + expect(second.body).toEqual(expect.objectContaining({ error: 'too_many_requests' })); + }); + it('routes to registration endpoint', async () => { const response = await supertest(app) .post('/register') diff --git a/packages/server-express/tsconfig.json b/packages/server-express/tsconfig.json new file mode 100644 index 000000000..0d7fdd0c0 --- /dev/null +++ b/packages/server-express/tsconfig.json @@ -0,0 +1,14 @@ +{ + "extends": "@modelcontextprotocol/tsconfig", + "include": ["./"], + "exclude": ["node_modules", "dist"], + "compilerOptions": { + "paths": { + "*": ["./*"], + "@modelcontextprotocol/server": ["./node_modules/@modelcontextprotocol/server/src/index.ts"], + "@modelcontextprotocol/core": [ + "./node_modules/@modelcontextprotocol/server/node_modules/@modelcontextprotocol/core/src/index.ts" + ] + } + } +} diff --git a/packages/server-express/tsdown.config.ts b/packages/server-express/tsdown.config.ts new file mode 100644 index 000000000..c72e7a2c4 --- /dev/null +++ b/packages/server-express/tsdown.config.ts @@ -0,0 +1,23 @@ +import { defineConfig } from 'tsdown'; + +export default defineConfig({ + entry: ['src/index.ts'], + format: ['esm'], + outDir: 'dist', + clean: true, + sourcemap: true, + target: 'esnext', + platform: 'node', + shims: true, + dts: { + resolver: 'tsc', + compilerOptions: { + baseUrl: '.', + paths: { + '@modelcontextprotocol/server': ['../server/src/index.ts'], + '@modelcontextprotocol/core': ['../core/src/index.ts'] + } + } + }, + noExternal: ['@modelcontextprotocol/server', '@modelcontextprotocol/core'] +}); diff --git a/packages/server-express/vitest.config.js b/packages/server-express/vitest.config.js new file mode 100644 index 000000000..496fca320 --- /dev/null +++ b/packages/server-express/vitest.config.js @@ -0,0 +1,3 @@ +import baseConfig from '@modelcontextprotocol/vitest-config'; + +export default baseConfig; diff --git a/packages/server-hono/README.md b/packages/server-hono/README.md new file mode 100644 index 000000000..d2788881a --- /dev/null +++ b/packages/server-hono/README.md @@ -0,0 +1,64 @@ +# `@modelcontextprotocol/server-hono` + +Hono adapters for the MCP TypeScript server SDK. + +This package is the Hono-specific companion to [`@modelcontextprotocol/server`](../server/), which is framework-agnostic and uses Web Standard `Request`/`Response` interfaces. + +## Install + +```bash +npm install @modelcontextprotocol/server @modelcontextprotocol/server-hono hono zod +``` + +## Exports + +- `mcpStreamableHttpHandler(transport)` +- `registerMcpAuthRoutes(app, options)` +- `registerMcpAuthMetadataRoutes(app, options)` +- `hostHeaderValidation(allowedHosts)` +- `localhostHostValidation()` + +## Usage + +### Streamable HTTP endpoint (Hono) + +```ts +import { Hono } from 'hono'; +import { McpServer, WebStandardStreamableHTTPServerTransport } from '@modelcontextprotocol/server'; +import { mcpStreamableHttpHandler } from '@modelcontextprotocol/server-hono'; + +const server = new McpServer({ name: 'my-server', version: '1.0.0' }); +const transport = new WebStandardStreamableHTTPServerTransport(); +await server.connect(transport); + +const app = new Hono(); +app.all('/mcp', mcpStreamableHttpHandler(transport)); +``` + +### OAuth routes (Hono) + +`@modelcontextprotocol/server` provides Web-standard auth handlers; this package mounts them onto a Hono app. + +```ts +import { Hono } from 'hono'; +import type { OAuthServerProvider } from '@modelcontextprotocol/server'; +import { registerMcpAuthRoutes } from '@modelcontextprotocol/server-hono'; + +const provider: OAuthServerProvider = /* ... */; + +const app = new Hono(); +registerMcpAuthRoutes(app, { + provider, + issuerUrl: new URL('https://auth.example.com') +}); +``` + +### Host header validation (DNS rebinding protection) + +```ts +import { Hono } from 'hono'; +import { localhostHostValidation } from '@modelcontextprotocol/server-hono'; + +const app = new Hono(); +app.use('*', localhostHostValidation()); +``` diff --git a/packages/server-hono/eslint.config.mjs b/packages/server-hono/eslint.config.mjs new file mode 100644 index 000000000..03d533134 --- /dev/null +++ b/packages/server-hono/eslint.config.mjs @@ -0,0 +1,12 @@ +// @ts-check + +import baseConfig from '@modelcontextprotocol/eslint-config'; + +export default [ + ...baseConfig, + { + settings: { + 'import/internal-regex': '^@modelcontextprotocol/(server|core)' + } + } +]; diff --git a/packages/server-hono/package.json b/packages/server-hono/package.json new file mode 100644 index 000000000..ac5b01a89 --- /dev/null +++ b/packages/server-hono/package.json @@ -0,0 +1,64 @@ +{ + "name": "@modelcontextprotocol/server-hono", + "private": false, + "version": "2.0.0-alpha.0", + "description": "Hono adapters for the Model Context Protocol TypeScript server SDK", + "license": "MIT", + "author": "Anthropic, PBC (https://anthropic.com)", + "homepage": "https://modelcontextprotocol.io", + "bugs": "https://github.com/modelcontextprotocol/typescript-sdk/issues", + "type": "module", + "repository": { + "type": "git", + "url": "git+https://github.com/modelcontextprotocol/typescript-sdk.git" + }, + "engines": { + "node": ">=20", + "pnpm": ">=10.24.0" + }, + "packageManager": "pnpm@10.24.0", + "keywords": [ + "modelcontextprotocol", + "mcp", + "hono" + ], + "exports": { + ".": { + "types": "./dist/index.d.mts", + "import": "./dist/index.mjs" + } + }, + "files": [ + "dist" + ], + "scripts": { + "typecheck": "tsgo -p tsconfig.json --noEmit", + "build": "tsdown", + "build:watch": "tsdown --watch", + "prepack": "npm run build", + "lint": "eslint src/ && prettier --ignore-path ../../.prettierignore --check .", + "lint:fix": "eslint src/ --fix && prettier --ignore-path ../../.prettierignore --write .", + "check": "npm run typecheck && npm run lint", + "test": "vitest run", + "test:watch": "vitest" + }, + "dependencies": { + "@modelcontextprotocol/server": "workspace:^", + "hono": "catalog:runtimeServerOnly" + }, + "devDependencies": { + "@modelcontextprotocol/tsconfig": "workspace:^", + "@modelcontextprotocol/vitest-config": "workspace:^", + "@modelcontextprotocol/eslint-config": "workspace:^", + "@eslint/js": "catalog:devTools", + "@typescript/native-preview": "catalog:devTools", + "eslint": "catalog:devTools", + "eslint-config-prettier": "catalog:devTools", + "eslint-plugin-n": "catalog:devTools", + "prettier": "catalog:devTools", + "tsdown": "catalog:devTools", + "typescript": "catalog:devTools", + "typescript-eslint": "catalog:devTools", + "vitest": "catalog:devTools" + } +} diff --git a/packages/server-hono/src/auth/bearerAuth.ts b/packages/server-hono/src/auth/bearerAuth.ts new file mode 100644 index 000000000..1258eee11 --- /dev/null +++ b/packages/server-hono/src/auth/bearerAuth.ts @@ -0,0 +1,19 @@ +import type { BearerAuthMiddlewareOptions } from '@modelcontextprotocol/server'; +import { requireBearerAuth as requireBearerAuthWeb } from '@modelcontextprotocol/server'; +import type { MiddlewareHandler } from 'hono'; +/** + * Hono middleware wrapper for the Web-standard `requireBearerAuth` helper. + * + * On success, sets `c.set('auth', authInfo)` and calls `next()`. + * On failure, returns the JSON error response. + */ +export function requireBearerAuth(options: BearerAuthMiddlewareOptions): MiddlewareHandler { + return async (c, next) => { + const result = await requireBearerAuthWeb(c.req.raw, options); + if ('authInfo' in result) { + c.set('auth', result.authInfo); + return await next(); + } + return result.response; + }; +} diff --git a/packages/server-hono/src/auth/router.ts b/packages/server-hono/src/auth/router.ts new file mode 100644 index 000000000..f17765318 --- /dev/null +++ b/packages/server-hono/src/auth/router.ts @@ -0,0 +1,61 @@ +import type { AuthMetadataOptions, AuthRoute, AuthRouterOptions } from '@modelcontextprotocol/server'; +import { + getParsedBody, + mcpAuthMetadataRouter as createWebAuthMetadataRouter, + mcpAuthRouter as createWebAuthRouter +} from '@modelcontextprotocol/server'; +import type { Handler } from 'hono'; +import { Hono } from 'hono'; + +/** + * Hono router adapter for the Web-standard `mcpAuthRouter` from `@modelcontextprotocol/server`. + * + * IMPORTANT: This router MUST be mounted at the application root. + * + * @example + * ```ts + * app.route('/', mcpAuthRouter(...)) + * ``` + */ +export function mcpAuthRouter(options: AuthRouterOptions): Hono { + const web = createWebAuthRouter(options); + const router = new Hono(); + registerRoutes(router, web.routes); + return router; +} + +/** + * Hono router adapter for the Web-standard `mcpAuthMetadataRouter` from `@modelcontextprotocol/server`. + * + * IMPORTANT: This router MUST be mounted at the application root. + */ +export function mcpAuthMetadataRouter(options: AuthMetadataOptions): Hono { + const web = createWebAuthMetadataRouter(options); + const router = new Hono(); + registerRoutes(router, web.routes); + return router; +} + +function registerRoutes(app: Hono, routes: AuthRoute[]): void { + for (const route of routes) { + // Use `all()` so unsupported methods still reach the handler and can return 405, + // matching the Express adapter behavior. + const handler: Handler = async c => { + let parsedBody = c.get('parsedBody'); + if (parsedBody === undefined && c.req.method === 'POST') { + // Parse from a clone so we don't consume the original request stream. + parsedBody = await getParsedBody(c.req.raw.clone()); + } + return route.handler(c.req.raw, { parsedBody }); + }; + app.all(route.path, handler); + } +} + +export function registerMcpAuthRoutes(app: Hono, options: AuthRouterOptions): void { + app.route('/', mcpAuthRouter(options)); +} + +export function registerMcpAuthMetadataRoutes(app: Hono, options: AuthMetadataOptions): void { + app.route('/', mcpAuthMetadataRouter(options)); +} diff --git a/packages/server-hono/src/hono.ts b/packages/server-hono/src/hono.ts new file mode 100644 index 000000000..accf4ab27 --- /dev/null +++ b/packages/server-hono/src/hono.ts @@ -0,0 +1,90 @@ +import type { Context } from 'hono'; +import { Hono } from 'hono'; + +import { hostHeaderValidation, localhostHostValidation } from './middleware/hostHeaderValidation.js'; + +/** + * Options for creating an MCP Hono application. + */ +export interface CreateMcpHonoAppOptions { + /** + * The hostname to bind to. Defaults to '127.0.0.1'. + * When set to '127.0.0.1', 'localhost', or '::1', DNS rebinding protection is automatically enabled. + */ + host?: string; + + /** + * List of allowed hostnames for DNS rebinding protection. + * If provided, host header validation will be applied using this list. + * For IPv6, provide addresses with brackets (e.g., '[::1]'). + * + * This is useful when binding to '0.0.0.0' or '::' but still wanting + * to restrict which hostnames are allowed. + */ + allowedHosts?: string[]; +} + +/** + * Creates a Hono application pre-configured for MCP servers. + * + * When the host is '127.0.0.1', 'localhost', or '::1' (the default is '127.0.0.1'), + * DNS rebinding protection middleware is automatically applied to protect against + * DNS rebinding attacks on localhost servers. + * + * This also installs a small JSON body parsing middleware (similar to `express.json()`) + * that stashes the parsed body into `c.set('parsedBody', ...)` when `Content-Type` includes + * `application/json`. + * + * @param options - Configuration options + * @returns A configured Hono application + */ +export function createMcpHonoApp(options: CreateMcpHonoAppOptions = {}): Hono { + const { host = '127.0.0.1', allowedHosts } = options; + + const app = new Hono(); + + // Similar to `express.json()`: parse JSON bodies and make them available to MCP adapters via `parsedBody`. + app.use('*', async (c: Context, next) => { + // If an upstream middleware already set parsedBody, keep it. + if (c.get('parsedBody') !== undefined) { + return await next(); + } + + const ct = c.req.header('content-type') ?? ''; + if (!ct.includes('application/json')) { + return await next(); + } + + try { + // Parse from a clone so we don't consume the original request stream. + const parsed = await c.req.raw.clone().json(); + c.set('parsedBody', parsed); + } catch { + // Mirror express.json() behavior loosely: reject invalid JSON. + return c.text('Invalid JSON', 400); + } + + return await next(); + }); + + // If allowedHosts is explicitly provided, use that for validation. + if (allowedHosts) { + app.use('*', hostHeaderValidation(allowedHosts)); + } else { + // Apply DNS rebinding protection automatically for localhost hosts. + const localhostHosts = ['127.0.0.1', 'localhost', '::1']; + if (localhostHosts.includes(host)) { + app.use('*', localhostHostValidation()); + } else if (host === '0.0.0.0' || host === '::') { + // Warn when binding to all interfaces without DNS rebinding protection. + // eslint-disable-next-line no-console + console.warn( + `Warning: Server is binding to ${host} without DNS rebinding protection. ` + + 'Consider using the allowedHosts option to restrict allowed hosts, ' + + 'or use authentication to protect your server.' + ); + } + } + + return app; +} diff --git a/packages/server-hono/src/index.ts b/packages/server-hono/src/index.ts new file mode 100644 index 000000000..bc6de4318 --- /dev/null +++ b/packages/server-hono/src/index.ts @@ -0,0 +1,5 @@ +export * from './auth/bearerAuth.js'; +export * from './auth/router.js'; +export * from './hono.js'; +export * from './middleware/hostHeaderValidation.js'; +export * from './streamableHttp.js'; diff --git a/packages/server-hono/src/middleware/hostHeaderValidation.ts b/packages/server-hono/src/middleware/hostHeaderValidation.ts new file mode 100644 index 000000000..8f7b20e88 --- /dev/null +++ b/packages/server-hono/src/middleware/hostHeaderValidation.ts @@ -0,0 +1,33 @@ +import { localhostAllowedHostnames, validateHostHeader } from '@modelcontextprotocol/server'; +import type { MiddlewareHandler } from 'hono'; + +/** + * Hono middleware for DNS rebinding protection. + * Validates Host header hostname (port-agnostic) against an allowed list. + */ +export function hostHeaderValidation(allowedHostnames: string[]): MiddlewareHandler { + return async (c, next) => { + const result = validateHostHeader(c.req.header('host'), allowedHostnames); + if (!result.ok) { + return c.json( + { + jsonrpc: '2.0', + error: { + code: -32000, + message: result.message + }, + id: null + }, + 403 + ); + } + return await next(); + }; +} + +/** + * Convenience middleware for localhost DNS rebinding protection. + */ +export function localhostHostValidation(): MiddlewareHandler { + return hostHeaderValidation(localhostAllowedHostnames()); +} diff --git a/packages/server-hono/src/streamableHttp.ts b/packages/server-hono/src/streamableHttp.ts new file mode 100644 index 000000000..2da1bafcd --- /dev/null +++ b/packages/server-hono/src/streamableHttp.ts @@ -0,0 +1,23 @@ +import type { WebStandardStreamableHTTPServerTransport } from '@modelcontextprotocol/server'; +import { getParsedBody } from '@modelcontextprotocol/server'; +import type { Context, Handler } from 'hono'; + +/** + * Convenience Hono handler for the WebStandard Streamable HTTP transport. + * + * Usage: + * ```ts + * app.all('/mcp', mcpStreamableHttpHandler(transport)) + * ``` + */ +export function mcpStreamableHttpHandler(transport: WebStandardStreamableHTTPServerTransport): Handler { + return async (c: Context) => { + let parsedBody = c.get('parsedBody'); + if (parsedBody === undefined && c.req.method === 'POST') { + // Parse from a clone so we don't consume the original request stream. + parsedBody = await getParsedBody(c.req.raw.clone()); + } + const authInfo = c.get('auth'); + return transport.handleRequest(c.req.raw, { authInfo, parsedBody }); + }; +} diff --git a/packages/server-hono/test/server-hono.test.ts b/packages/server-hono/test/server-hono.test.ts new file mode 100644 index 000000000..130e11c71 --- /dev/null +++ b/packages/server-hono/test/server-hono.test.ts @@ -0,0 +1,279 @@ +import type { AuthorizationParams, OAuthClientInformationFull, OAuthServerProvider, OAuthTokens } from '@modelcontextprotocol/server'; +import type { Context } from 'hono'; +import { Hono } from 'hono'; +import { vi } from 'vitest'; + +import { mcpAuthRouter } from '../src/auth/router.js'; +import { createMcpHonoApp } from '../src/hono.js'; +import { hostHeaderValidation } from '../src/middleware/hostHeaderValidation.js'; +import { mcpStreamableHttpHandler } from '../src/streamableHttp.js'; + +describe('@modelcontextprotocol/server-hono', () => { + test('mcpStreamableHttpHandler delegates to transport.handleRequest (and passes authInfo + parsedBody when set)', async () => { + const calls: { url?: string; method?: string; options?: unknown }[] = []; + + const transport = { + async handleRequest(req: Request, options?: unknown): Promise { + calls.push({ url: req.url, method: req.method, options }); + return new Response('ok', { status: 200, headers: { 'content-type': 'text/plain' } }); + } + }; + + const app = new Hono(); + app.use('/mcp', async (c: Context, next) => { + // Upstream middleware can pre-parse and stash body + auth. + c.set('parsedBody', { hello: 'world' }); + c.set('auth', { + token: 't', + clientId: 'c', + scopes: [], + expiresAt: Math.floor(Date.now() / 1000) + 60 + }); + return await next(); + }); + app.all('/mcp', mcpStreamableHttpHandler(transport as unknown as Parameters[0])); + + const res = await app.request('http://localhost/mcp', { method: 'POST' }); + expect(res.status).toBe(200); + expect(await res.text()).toBe('ok'); + expect(calls).toHaveLength(1); + expect(calls[0]!.method).toBe('POST'); + expect(calls[0]!.url).toBe('http://localhost/mcp'); + expect(calls[0]!.options).toEqual( + expect.objectContaining({ + parsedBody: { hello: 'world' }, + authInfo: expect.objectContaining({ clientId: 'c' }) + }) + ); + }); + + test('hostHeaderValidation blocks invalid Host and allows valid Host', async () => { + const app = new Hono(); + app.use('*', hostHeaderValidation(['localhost'])); + app.get('/health', c => c.text('ok')); + + const bad = await app.request('http://localhost/health', { headers: { Host: 'evil.com:3000' } }); + expect(bad.status).toBe(403); + expect(await bad.json()).toEqual( + expect.objectContaining({ + jsonrpc: '2.0', + error: expect.objectContaining({ + code: -32000 + }), + id: null + }) + ); + + const good = await app.request('http://localhost/health', { headers: { Host: 'localhost:3000' } }); + expect(good.status).toBe(200); + expect(await good.text()).toBe('ok'); + }); + + test('registerMcpAuthRoutes mounts metadata + authorize routes', async () => { + const validClient: OAuthClientInformationFull = { + client_id: 'valid-client', + client_secret: 'valid-secret', + redirect_uris: ['https://example.com/callback'] + }; + + const provider: OAuthServerProvider = { + clientsStore: { + async getClient(clientId: string) { + return clientId === 'valid-client' ? validClient : undefined; + } + }, + async authorize(_client: OAuthClientInformationFull, params: AuthorizationParams): Promise { + const u = new URL(params.redirectUri); + u.searchParams.set('code', 'mock_auth_code'); + if (params.state) u.searchParams.set('state', params.state); + return Response.redirect(u.toString(), 302); + }, + async challengeForAuthorizationCode(): Promise { + return 'mock_challenge'; + }, + async exchangeAuthorizationCode(): Promise { + return { + access_token: 'mock_access_token', + token_type: 'bearer', + expires_in: 3600, + refresh_token: 'mock_refresh_token' + }; + }, + async exchangeRefreshToken(): Promise { + return { + access_token: 'new_mock_access_token', + token_type: 'bearer', + expires_in: 3600, + refresh_token: 'new_mock_refresh_token' + }; + }, + async verifyAccessToken() { + throw new Error('not used'); + } + }; + + const app = new Hono(); + app.route('/', mcpAuthRouter({ provider, issuerUrl: new URL('https://auth.example.com') })); + + const metadata = await app.request('http://localhost/.well-known/oauth-authorization-server', { method: 'GET' }); + expect(metadata.status).toBe(200); + const metaJson = (await metadata.json()) as { issuer?: string; authorization_endpoint?: string }; + expect(metaJson.issuer).toBe('https://auth.example.com/'); + expect(metaJson.authorization_endpoint).toBe('https://auth.example.com/authorize'); + + const authorize = await app.request( + 'http://localhost/authorize?client_id=valid-client&response_type=code&code_challenge=x&code_challenge_method=S256&redirect_uri=https%3A%2F%2Fexample.com%2Fcallback&state=s', + { method: 'GET' } + ); + expect(authorize.status).toBe(302); + const location = authorize.headers.get('location')!; + expect(location).toContain('https://example.com/callback'); + expect(location).toContain('code=mock_auth_code'); + expect(location).toContain('state=s'); + }); + + test('registerMcpAuthRoutes returns 405 (not 404) for unsupported methods', async () => { + const provider: OAuthServerProvider = { + clientsStore: { + async getClient() { + return undefined; + } + }, + async authorize() { + throw new Error('not used'); + }, + async challengeForAuthorizationCode() { + throw new Error('not used'); + }, + async exchangeAuthorizationCode() { + throw new Error('not used'); + }, + async exchangeRefreshToken() { + throw new Error('not used'); + }, + async verifyAccessToken() { + throw new Error('not used'); + } + }; + + const app = new Hono(); + app.route('/', mcpAuthRouter({ provider, issuerUrl: new URL('https://auth.example.com') })); + + const res = await app.request('http://localhost/authorize', { method: 'PUT' }); + expect(res.status).toBe(405); + }); + + test('registerMcpAuthRoutes passes parsedBody to web handlers (POST /authorize works with empty raw body)', async () => { + const validClient: OAuthClientInformationFull = { + client_id: 'valid-client', + client_secret: 'valid-secret', + redirect_uris: ['https://example.com/callback'] + }; + + const provider: OAuthServerProvider = { + clientsStore: { + async getClient(clientId: string) { + return clientId === 'valid-client' ? validClient : undefined; + } + }, + async authorize(_client: OAuthClientInformationFull, params: AuthorizationParams): Promise { + const u = new URL(params.redirectUri); + u.searchParams.set('code', 'mock_auth_code'); + if (params.state) u.searchParams.set('state', params.state); + return Response.redirect(u.toString(), 302); + }, + async challengeForAuthorizationCode(): Promise { + return 'mock_challenge'; + }, + async exchangeAuthorizationCode(): Promise { + return { + access_token: 'mock_access_token', + token_type: 'bearer', + expires_in: 3600, + refresh_token: 'mock_refresh_token' + }; + }, + async exchangeRefreshToken(): Promise { + return { + access_token: 'new_mock_access_token', + token_type: 'bearer', + expires_in: 3600, + refresh_token: 'new_mock_refresh_token' + }; + }, + async verifyAccessToken() { + throw new Error('not used'); + } + }; + + const app = new Hono(); + app.use('/authorize', async (c: Context, next) => { + c.set('parsedBody', { + client_id: 'valid-client', + response_type: 'code', + code_challenge: 'x', + code_challenge_method: 'S256', + redirect_uri: 'https://example.com/callback', + state: 's' + }); + return await next(); + }); + app.route('/', mcpAuthRouter({ provider, issuerUrl: new URL('https://auth.example.com') })); + + const authorize = await app.request('http://localhost/authorize', { method: 'POST' }); + expect(authorize.status).toBe(302); + const location = authorize.headers.get('location')!; + expect(location).toContain('https://example.com/callback'); + expect(location).toContain('code=mock_auth_code'); + expect(location).toContain('state=s'); + }); + + test('createMcpHonoApp enables localhost DNS rebinding protection by default', async () => { + const app = createMcpHonoApp(); + app.get('/health', c => c.text('ok')); + + const bad = await app.request('http://localhost/health', { headers: { Host: 'evil.com:3000' } }); + expect(bad.status).toBe(403); + + const good = await app.request('http://localhost/health', { headers: { Host: 'localhost:3000' } }); + expect(good.status).toBe(200); + }); + + test('createMcpHonoApp uses allowedHosts when provided (even when binding to 0.0.0.0)', async () => { + const warn = vi.spyOn(console, 'warn').mockImplementation(() => {}); + const app = createMcpHonoApp({ host: '0.0.0.0', allowedHosts: ['myapp.local'] }); + warn.mockRestore(); + + app.get('/health', c => c.text('ok')); + + const bad = await app.request('http://localhost/health', { headers: { Host: 'evil.com:3000' } }); + expect(bad.status).toBe(403); + + const good = await app.request('http://localhost/health', { headers: { Host: 'myapp.local:3000' } }); + expect(good.status).toBe(200); + }); + + test('createMcpHonoApp does not apply host validation for 0.0.0.0 without allowedHosts', async () => { + const warn = vi.spyOn(console, 'warn').mockImplementation(() => {}); + const app = createMcpHonoApp({ host: '0.0.0.0' }); + warn.mockRestore(); + + app.get('/health', c => c.text('ok')); + + const res = await app.request('http://localhost/health', { headers: { Host: 'evil.com:3000' } }); + expect(res.status).toBe(200); + }); + + test('createMcpHonoApp parses JSON bodies into parsedBody (express.json()-like)', async () => { + const app = createMcpHonoApp(); + app.post('/echo', (c: Context) => c.json(c.get('parsedBody'))); + + const res = await app.request('http://localhost/echo', { + method: 'POST', + headers: { Host: 'localhost:3000', 'content-type': 'application/json' }, + body: JSON.stringify({ a: 1 }) + }); + expect(res.status).toBe(200); + expect(await res.json()).toEqual({ a: 1 }); + }); +}); diff --git a/packages/server-hono/tsconfig.json b/packages/server-hono/tsconfig.json new file mode 100644 index 000000000..0d7fdd0c0 --- /dev/null +++ b/packages/server-hono/tsconfig.json @@ -0,0 +1,14 @@ +{ + "extends": "@modelcontextprotocol/tsconfig", + "include": ["./"], + "exclude": ["node_modules", "dist"], + "compilerOptions": { + "paths": { + "*": ["./*"], + "@modelcontextprotocol/server": ["./node_modules/@modelcontextprotocol/server/src/index.ts"], + "@modelcontextprotocol/core": [ + "./node_modules/@modelcontextprotocol/server/node_modules/@modelcontextprotocol/core/src/index.ts" + ] + } + } +} diff --git a/packages/server-hono/tsdown.config.ts b/packages/server-hono/tsdown.config.ts new file mode 100644 index 000000000..c72e7a2c4 --- /dev/null +++ b/packages/server-hono/tsdown.config.ts @@ -0,0 +1,23 @@ +import { defineConfig } from 'tsdown'; + +export default defineConfig({ + entry: ['src/index.ts'], + format: ['esm'], + outDir: 'dist', + clean: true, + sourcemap: true, + target: 'esnext', + platform: 'node', + shims: true, + dts: { + resolver: 'tsc', + compilerOptions: { + baseUrl: '.', + paths: { + '@modelcontextprotocol/server': ['../server/src/index.ts'], + '@modelcontextprotocol/core': ['../core/src/index.ts'] + } + } + }, + noExternal: ['@modelcontextprotocol/server', '@modelcontextprotocol/core'] +}); diff --git a/packages/server-hono/vitest.config.js b/packages/server-hono/vitest.config.js new file mode 100644 index 000000000..496fca320 --- /dev/null +++ b/packages/server-hono/vitest.config.js @@ -0,0 +1,3 @@ +import baseConfig from '@modelcontextprotocol/vitest-config'; + +export default baseConfig; diff --git a/packages/server/package.json b/packages/server/package.json index b039751f6..20bd77aae 100644 --- a/packages/server/package.json +++ b/packages/server/package.json @@ -44,14 +44,10 @@ "client": "tsx scripts/cli.ts client" }, "dependencies": { - "content-type": "catalog:runtimeServerOnly", - "cors": "catalog:runtimeServerOnly", "@hono/node-server": "catalog:runtimeServerOnly", - "hono": "catalog:runtimeServerOnly", - "express": "catalog:runtimeServerOnly", - "express-rate-limit": "catalog:runtimeServerOnly", - "raw-body": "catalog:runtimeServerOnly", + "content-type": "catalog:runtimeServerOnly", "pkce-challenge": "catalog:runtimeShared", + "raw-body": "catalog:runtimeServerOnly", "zod": "catalog:runtimeShared", "zod-to-json-schema": "catalog:runtimeShared" }, @@ -68,13 +64,13 @@ } }, "devDependencies": { + "@cfworker/json-schema": "catalog:runtimeShared", + "@eslint/js": "catalog:devTools", "@modelcontextprotocol/core": "workspace:^", - "@modelcontextprotocol/tsconfig": "workspace:^", - "@modelcontextprotocol/vitest-config": "workspace:^", "@modelcontextprotocol/eslint-config": "workspace:^", "@modelcontextprotocol/test-helpers": "workspace:^", - "@cfworker/json-schema": "catalog:runtimeShared", - "@eslint/js": "catalog:devTools", + "@modelcontextprotocol/tsconfig": "workspace:^", + "@modelcontextprotocol/vitest-config": "workspace:^", "@types/content-type": "catalog:devTools", "@types/cors": "catalog:devTools", "@types/cross-spawn": "catalog:devTools", diff --git a/packages/server/src/index.ts b/packages/server/src/index.ts index 4b0c42053..667ba49d3 100644 --- a/packages/server/src/index.ts +++ b/packages/server/src/index.ts @@ -1,6 +1,6 @@ export * from './server/completable.js'; -export * from './server/express.js'; export * from './server/mcp.js'; +export * from './server/middleware/hostHeaderValidation.js'; export * from './server/server.js'; export * from './server/sse.js'; export * from './server/stdio.js'; diff --git a/packages/server/src/server/auth/handlers/authorize.ts b/packages/server/src/server/auth/handlers/authorize.ts index 65875529e..ecffee114 100644 --- a/packages/server/src/server/auth/handlers/authorize.ts +++ b/packages/server/src/server/auth/handlers/authorize.ts @@ -1,20 +1,12 @@ -import { InvalidClientError, InvalidRequestError, OAuthError, ServerError, TooManyRequestsError } from '@modelcontextprotocol/core'; -import type { RequestHandler } from 'express'; -import express from 'express'; -import type { Options as RateLimitOptions } from 'express-rate-limit'; -import { rateLimit } from 'express-rate-limit'; +import { InvalidClientError, InvalidRequestError, OAuthError, ServerError } from '@modelcontextprotocol/core'; import * as z from 'zod/v4'; -import { allowedMethods } from '../middleware/allowedMethods.js'; import type { OAuthServerProvider } from '../provider.js'; +import type { WebHandler } from '../web.js'; +import { getParsedBody, jsonResponse, methodNotAllowedResponse, noStoreHeaders } from '../web.js'; export type AuthorizationHandlerOptions = { provider: OAuthServerProvider; - /** - * Rate limiting configuration for the authorization endpoint. - * Set to false to disable rate limiting for this endpoint. - */ - rateLimit?: Partial | false; }; // Parameters that must be validated in order to issue redirects. @@ -36,28 +28,18 @@ const RequestAuthorizationParamsSchema = z.object({ resource: z.string().url().optional() }); -export function authorizationHandler({ provider, rateLimit: rateLimitConfig }: AuthorizationHandlerOptions): RequestHandler { - // Create a router to apply middleware - const router = express.Router(); - router.use(allowedMethods(['GET', 'POST'])); - router.use(express.urlencoded({ extended: false })); - - // Apply rate limiting unless explicitly disabled - if (rateLimitConfig !== false) { - router.use( - rateLimit({ - windowMs: 15 * 60 * 1000, // 15 minutes - max: 100, // 100 requests per windowMs - standardHeaders: true, - legacyHeaders: false, - message: new TooManyRequestsError('You have exceeded the rate limit for authorization requests').toResponseObject(), - ...rateLimitConfig - }) - ); - } - - router.all('/', async (req, res) => { - res.setHeader('Cache-Control', 'no-store'); +export function authorizationHandler({ provider }: AuthorizationHandlerOptions): WebHandler { + return async (req, ctx) => { + const noStore = noStoreHeaders(); + + if (req.method !== 'GET' && req.method !== 'POST') { + const resp = methodNotAllowedResponse(req, ['GET', 'POST']); + const body = await resp.text(); + return new Response(body, { + status: resp.status, + headers: { ...Object.fromEntries(resp.headers.entries()), ...noStore } + }); + } // In the authorization flow, errors are split into two categories: // 1. Pre-redirect errors (direct response with 400) @@ -66,7 +48,9 @@ export function authorizationHandler({ provider, rateLimit: rateLimitConfig }: A // Phase 1: Validate client_id and redirect_uri. Any errors here must be direct responses. let client_id, redirect_uri, client; try { - const result = ClientAuthorizationParamsSchema.safeParse(req.method === 'POST' ? req.body : req.query); + const source = + req.method === 'POST' ? await getParsedBody(req, ctx) : Object.fromEntries(new URL(req.url).searchParams.entries()); + const result = ClientAuthorizationParamsSchema.safeParse(source); if (!result.success) { throw new InvalidRequestError(result.error.message); } @@ -97,20 +81,20 @@ export function authorizationHandler({ provider, rateLimit: rateLimitConfig }: A // user anyway. if (error instanceof OAuthError) { const status = error instanceof ServerError ? 500 : 400; - res.status(status).json(error.toResponseObject()); + return jsonResponse(error.toResponseObject(), { status, headers: noStore }); } else { const serverError = new ServerError('Internal Server Error'); - res.status(500).json(serverError.toResponseObject()); + return jsonResponse(serverError.toResponseObject(), { status: 500, headers: noStore }); } - - return; } // Phase 2: Validate other parameters. Any errors here should go into redirect responses. let state; try { // Parse and validate authorization parameters - const parseResult = RequestAuthorizationParamsSchema.safeParse(req.method === 'POST' ? req.body : req.query); + const source = + req.method === 'POST' ? await getParsedBody(req, ctx) : Object.fromEntries(new URL(req.url).searchParams.entries()); + const parseResult = RequestAuthorizationParamsSchema.safeParse(source); if (!parseResult.success) { throw new InvalidRequestError(parseResult.error.message); } @@ -125,29 +109,28 @@ export function authorizationHandler({ provider, rateLimit: rateLimitConfig }: A } // All validation passed, proceed with authorization - await provider.authorize( - client, - { - state, - scopes: requestedScopes, - redirectUri: redirect_uri!, // TODO: Someone to look at. Strict tsconfig showed this could be undefined, while the return type is string. - codeChallenge: code_challenge, - resource: resource ? new URL(resource) : undefined - }, - res - ); + const providerResponse = await provider.authorize(client, { + state, + scopes: requestedScopes, + redirectUri: redirect_uri!, // TODO: Someone to look at. Strict tsconfig showed this could be undefined, while the return type is string. + codeChallenge: code_challenge, + resource: resource ? new URL(resource) : undefined + }); + const headers = new Headers(providerResponse.headers); + headers.set('Cache-Control', 'no-store'); + return new Response(providerResponse.body, { status: providerResponse.status, headers }); } catch (error) { // Post-redirect errors - redirect with error parameters if (error instanceof OAuthError) { - res.redirect(302, createErrorRedirect(redirect_uri!, error, state)); + const location = createErrorRedirect(redirect_uri!, error, state); + return new Response(null, { status: 302, headers: { Location: location, ...noStore } }); } else { const serverError = new ServerError('Internal Server Error'); - res.redirect(302, createErrorRedirect(redirect_uri!, serverError, state)); + const location = createErrorRedirect(redirect_uri!, serverError, state); + return new Response(null, { status: 302, headers: { Location: location, ...noStore } }); } } - }); - - return router; + }; } /** diff --git a/packages/server/src/server/auth/handlers/metadata.ts b/packages/server/src/server/auth/handlers/metadata.ts index 529a6e57a..fea42a8cb 100644 --- a/packages/server/src/server/auth/handlers/metadata.ts +++ b/packages/server/src/server/auth/handlers/metadata.ts @@ -1,21 +1,33 @@ import type { OAuthMetadata, OAuthProtectedResourceMetadata } from '@modelcontextprotocol/core'; -import cors from 'cors'; -import type { RequestHandler } from 'express'; -import express from 'express'; -import { allowedMethods } from '../middleware/allowedMethods.js'; +import type { WebHandler } from '../web.js'; +import { corsHeaders, corsPreflightResponse, jsonResponse, methodNotAllowedResponse } from '../web.js'; -export function metadataHandler(metadata: OAuthMetadata | OAuthProtectedResourceMetadata): RequestHandler { - // Nested router so we can configure middleware and restrict HTTP method - const router = express.Router(); +export function metadataHandler(metadata: OAuthMetadata | OAuthProtectedResourceMetadata): WebHandler { + const cors = { + allowOrigin: '*', + allowMethods: ['GET', 'OPTIONS'], + allowHeaders: ['Content-Type', 'Authorization'], + maxAgeSeconds: 60 * 60 * 24 + } as const; - // Configure CORS to allow any origin, to make accessible to web-based MCP clients - router.use(cors()); + return async req => { + if (req.method === 'OPTIONS') { + return corsPreflightResponse(cors); + } + if (req.method !== 'GET') { + const resp = methodNotAllowedResponse(req, ['GET', 'OPTIONS']); + // Add CORS headers for consistency with successful responses. + const body = await resp.text(); + return new Response(body, { + status: resp.status, + headers: { ...Object.fromEntries(resp.headers.entries()), ...corsHeaders(cors) } + }); + } - router.use(allowedMethods(['GET', 'OPTIONS'])); - router.get('/', (req, res) => { - res.status(200).json(metadata); - }); - - return router; + return jsonResponse(metadata, { + status: 200, + headers: corsHeaders(cors) + }); + }; } diff --git a/packages/server/src/server/auth/handlers/register.ts b/packages/server/src/server/auth/handlers/register.ts index a78154d48..4433a1b5b 100644 --- a/packages/server/src/server/auth/handlers/register.ts +++ b/packages/server/src/server/auth/handlers/register.ts @@ -1,21 +1,11 @@ import crypto from 'node:crypto'; import type { OAuthClientInformationFull } from '@modelcontextprotocol/core'; -import { - InvalidClientMetadataError, - OAuthClientMetadataSchema, - OAuthError, - ServerError, - TooManyRequestsError -} from '@modelcontextprotocol/core'; -import cors from 'cors'; -import type { RequestHandler } from 'express'; -import express from 'express'; -import type { Options as RateLimitOptions } from 'express-rate-limit'; -import { rateLimit } from 'express-rate-limit'; +import { InvalidClientMetadataError, OAuthClientMetadataSchema, OAuthError, ServerError } from '@modelcontextprotocol/core'; import type { OAuthRegisteredClientsStore } from '../clients.js'; -import { allowedMethods } from '../middleware/allowedMethods.js'; +import type { WebHandler } from '../web.js'; +import { corsHeaders, corsPreflightResponse, getParsedBody, jsonResponse, methodNotAllowedResponse, noStoreHeaders } from '../web.js'; export type ClientRegistrationHandlerOptions = { /** @@ -30,13 +20,6 @@ export type ClientRegistrationHandlerOptions = { */ clientSecretExpirySeconds?: number; - /** - * Rate limiting configuration for the client registration endpoint. - * Set to false to disable rate limiting for this endpoint. - * Registration endpoints are particularly sensitive to abuse and should be rate limited. - */ - rateLimit?: Partial | false; - /** * Whether to generate a client ID before calling the client registration endpoint. * @@ -50,41 +33,37 @@ const DEFAULT_CLIENT_SECRET_EXPIRY_SECONDS = 30 * 24 * 60 * 60; // 30 days export function clientRegistrationHandler({ clientsStore, clientSecretExpirySeconds = DEFAULT_CLIENT_SECRET_EXPIRY_SECONDS, - rateLimit: rateLimitConfig, clientIdGeneration = true -}: ClientRegistrationHandlerOptions): RequestHandler { +}: ClientRegistrationHandlerOptions): WebHandler { if (!clientsStore.registerClient) { throw new Error('Client registration store does not support registering clients'); } - // Nested router so we can configure middleware and restrict HTTP method - const router = express.Router(); - - // Configure CORS to allow any origin, to make accessible to web-based MCP clients - router.use(cors()); - - router.use(allowedMethods(['POST'])); - router.use(express.json()); - - // Apply rate limiting unless explicitly disabled - stricter limits for registration - if (rateLimitConfig !== false) { - router.use( - rateLimit({ - windowMs: 60 * 60 * 1000, // 1 hour - max: 20, // 20 requests per hour - stricter as registration is sensitive - standardHeaders: true, - legacyHeaders: false, - message: new TooManyRequestsError('You have exceeded the rate limit for client registration requests').toResponseObject(), - ...rateLimitConfig - }) - ); - } + const cors = { + allowOrigin: '*', + allowMethods: ['POST', 'OPTIONS'], + allowHeaders: ['Content-Type', 'Authorization'], + maxAgeSeconds: 60 * 60 * 24 + } as const; - router.post('/', async (req, res) => { - res.setHeader('Cache-Control', 'no-store'); + return async (req, ctx) => { + const baseHeaders = { ...corsHeaders(cors), ...noStoreHeaders() }; + + if (req.method === 'OPTIONS') { + return corsPreflightResponse(cors); + } + if (req.method !== 'POST') { + const resp = methodNotAllowedResponse(req, ['POST', 'OPTIONS']); + const body = await resp.text(); + return new Response(body, { + status: resp.status, + headers: { ...Object.fromEntries(resp.headers.entries()), ...baseHeaders } + }); + } try { - const parseResult = OAuthClientMetadataSchema.safeParse(req.body); + const rawBody = await getParsedBody(req, ctx); + const parseResult = OAuthClientMetadataSchema.safeParse(rawBody); if (!parseResult.success) { throw new InvalidClientMetadataError(parseResult.error.message); } @@ -113,17 +92,14 @@ export function clientRegistrationHandler({ } clientInfo = await clientsStore.registerClient!(clientInfo); - res.status(201).json(clientInfo); + return jsonResponse(clientInfo, { status: 201, headers: baseHeaders }); } catch (error) { if (error instanceof OAuthError) { const status = error instanceof ServerError ? 500 : 400; - res.status(status).json(error.toResponseObject()); - } else { - const serverError = new ServerError('Internal Server Error'); - res.status(500).json(serverError.toResponseObject()); + return jsonResponse(error.toResponseObject(), { status, headers: baseHeaders }); } + const serverError = new ServerError('Internal Server Error'); + return jsonResponse(serverError.toResponseObject(), { status: 500, headers: baseHeaders }); } - }); - - return router; + }; } diff --git a/packages/server/src/server/auth/handlers/revoke.ts b/packages/server/src/server/auth/handlers/revoke.ts index c7c9f8a6a..e4814345d 100644 --- a/packages/server/src/server/auth/handlers/revoke.ts +++ b/packages/server/src/server/auth/handlers/revoke.ts @@ -1,87 +1,59 @@ -import { - InvalidRequestError, - OAuthError, - OAuthTokenRevocationRequestSchema, - ServerError, - TooManyRequestsError -} from '@modelcontextprotocol/core'; -import cors from 'cors'; -import type { RequestHandler } from 'express'; -import express from 'express'; -import type { Options as RateLimitOptions } from 'express-rate-limit'; -import { rateLimit } from 'express-rate-limit'; +import { InvalidRequestError, OAuthError, OAuthTokenRevocationRequestSchema, ServerError } from '@modelcontextprotocol/core'; -import { allowedMethods } from '../middleware/allowedMethods.js'; import { authenticateClient } from '../middleware/clientAuth.js'; import type { OAuthServerProvider } from '../provider.js'; +import type { WebHandler } from '../web.js'; +import { corsHeaders, corsPreflightResponse, getParsedBody, jsonResponse, methodNotAllowedResponse, noStoreHeaders } from '../web.js'; export type RevocationHandlerOptions = { provider: OAuthServerProvider; - /** - * Rate limiting configuration for the token revocation endpoint. - * Set to false to disable rate limiting for this endpoint. - */ - rateLimit?: Partial | false; }; -export function revocationHandler({ provider, rateLimit: rateLimitConfig }: RevocationHandlerOptions): RequestHandler { +export function revocationHandler({ provider }: RevocationHandlerOptions): WebHandler { if (!provider.revokeToken) { throw new Error('Auth provider does not support revoking tokens'); } - // Nested router so we can configure middleware and restrict HTTP method - const router = express.Router(); + const cors = { + allowOrigin: '*', + allowMethods: ['POST', 'OPTIONS'], + allowHeaders: ['Content-Type', 'Authorization'], + maxAgeSeconds: 60 * 60 * 24 + } as const; - // Configure CORS to allow any origin, to make accessible to web-based MCP clients - router.use(cors()); + return async (req, ctx) => { + const baseHeaders = { ...corsHeaders(cors), ...noStoreHeaders() }; - router.use(allowedMethods(['POST'])); - router.use(express.urlencoded({ extended: false })); - - // Apply rate limiting unless explicitly disabled - if (rateLimitConfig !== false) { - router.use( - rateLimit({ - windowMs: 15 * 60 * 1000, // 15 minutes - max: 50, // 50 requests per windowMs - standardHeaders: true, - legacyHeaders: false, - message: new TooManyRequestsError('You have exceeded the rate limit for token revocation requests').toResponseObject(), - ...rateLimitConfig - }) - ); - } - - // Authenticate and extract client details - router.use(authenticateClient({ clientsStore: provider.clientsStore })); - - router.post('/', async (req, res) => { - res.setHeader('Cache-Control', 'no-store'); + if (req.method === 'OPTIONS') { + return corsPreflightResponse(cors); + } + if (req.method !== 'POST') { + const resp = methodNotAllowedResponse(req, ['POST', 'OPTIONS']); + const body = await resp.text(); + return new Response(body, { + status: resp.status, + headers: { ...Object.fromEntries(resp.headers.entries()), ...baseHeaders } + }); + } try { - const parseResult = OAuthTokenRevocationRequestSchema.safeParse(req.body); + const rawBody = await getParsedBody(req, ctx); + const parseResult = OAuthTokenRevocationRequestSchema.safeParse(rawBody); if (!parseResult.success) { throw new InvalidRequestError(parseResult.error.message); } - const client = req.client; - if (!client) { - // This should never happen - throw new ServerError('Internal Server Error'); - } + const client = await authenticateClient(rawBody, { clientsStore: provider.clientsStore }); await provider.revokeToken!(client, parseResult.data); - res.status(200).json({}); + return jsonResponse({}, { status: 200, headers: baseHeaders }); } catch (error) { if (error instanceof OAuthError) { const status = error instanceof ServerError ? 500 : 400; - res.status(status).json(error.toResponseObject()); - } else { - const serverError = new ServerError('Internal Server Error'); - res.status(500).json(serverError.toResponseObject()); + return jsonResponse(error.toResponseObject(), { status, headers: baseHeaders }); } + const serverError = new ServerError('Internal Server Error'); + return jsonResponse(serverError.toResponseObject(), { status: 500, headers: baseHeaders }); } - }); - - return router; + }; } diff --git a/packages/server/src/server/auth/handlers/token.ts b/packages/server/src/server/auth/handlers/token.ts index 3b7941294..6dcdfd8b1 100644 --- a/packages/server/src/server/auth/handlers/token.ts +++ b/packages/server/src/server/auth/handlers/token.ts @@ -1,30 +1,14 @@ -import { - InvalidGrantError, - InvalidRequestError, - OAuthError, - ServerError, - TooManyRequestsError, - UnsupportedGrantTypeError -} from '@modelcontextprotocol/core'; -import cors from 'cors'; -import type { RequestHandler } from 'express'; -import express from 'express'; -import type { Options as RateLimitOptions } from 'express-rate-limit'; -import { rateLimit } from 'express-rate-limit'; +import { InvalidGrantError, InvalidRequestError, OAuthError, ServerError, UnsupportedGrantTypeError } from '@modelcontextprotocol/core'; import { verifyChallenge } from 'pkce-challenge'; import * as z from 'zod/v4'; -import { allowedMethods } from '../middleware/allowedMethods.js'; import { authenticateClient } from '../middleware/clientAuth.js'; import type { OAuthServerProvider } from '../provider.js'; +import type { WebHandler } from '../web.js'; +import { corsHeaders, corsPreflightResponse, getParsedBody, jsonResponse, methodNotAllowedResponse, noStoreHeaders } from '../web.js'; export type TokenHandlerOptions = { provider: OAuthServerProvider; - /** - * Rate limiting configuration for the token endpoint. - * Set to false to disable rate limiting for this endpoint. - */ - rateLimit?: Partial | false; }; const TokenRequestSchema = z.object({ @@ -44,53 +28,43 @@ const RefreshTokenGrantSchema = z.object({ resource: z.string().url().optional() }); -export function tokenHandler({ provider, rateLimit: rateLimitConfig }: TokenHandlerOptions): RequestHandler { - // Nested router so we can configure middleware and restrict HTTP method - const router = express.Router(); +export function tokenHandler({ provider }: TokenHandlerOptions): WebHandler { + const cors = { + allowOrigin: '*', + allowMethods: ['POST', 'OPTIONS'], + allowHeaders: ['Content-Type', 'Authorization'], + maxAgeSeconds: 60 * 60 * 24 + } as const; - // Configure CORS to allow any origin, to make accessible to web-based MCP clients - router.use(cors()); + return async (req, ctx) => { + const baseHeaders = { ...corsHeaders(cors), ...noStoreHeaders() }; - router.use(allowedMethods(['POST'])); - router.use(express.urlencoded({ extended: false })); - - // Apply rate limiting unless explicitly disabled - if (rateLimitConfig !== false) { - router.use( - rateLimit({ - windowMs: 15 * 60 * 1000, // 15 minutes - max: 50, // 50 requests per windowMs - standardHeaders: true, - legacyHeaders: false, - message: new TooManyRequestsError('You have exceeded the rate limit for token requests').toResponseObject(), - ...rateLimitConfig - }) - ); - } - - // Authenticate and extract client details - router.use(authenticateClient({ clientsStore: provider.clientsStore })); - - router.post('/', async (req, res) => { - res.setHeader('Cache-Control', 'no-store'); + if (req.method === 'OPTIONS') { + return corsPreflightResponse(cors); + } + if (req.method !== 'POST') { + const resp = methodNotAllowedResponse(req, ['POST', 'OPTIONS']); + const body = await resp.text(); + return new Response(body, { + status: resp.status, + headers: { ...Object.fromEntries(resp.headers.entries()), ...baseHeaders } + }); + } try { - const parseResult = TokenRequestSchema.safeParse(req.body); + const rawBody = await getParsedBody(req, ctx); + const parseResult = TokenRequestSchema.safeParse(rawBody); if (!parseResult.success) { throw new InvalidRequestError(parseResult.error.message); } const { grant_type } = parseResult.data; - const client = req.client; - if (!client) { - // This should never happen - throw new ServerError('Internal Server Error'); - } + const client = await authenticateClient(rawBody, { clientsStore: provider.clientsStore }); switch (grant_type) { case 'authorization_code': { - const parseResult = AuthorizationCodeGrantSchema.safeParse(req.body); + const parseResult = AuthorizationCodeGrantSchema.safeParse(rawBody); if (!parseResult.success) { throw new InvalidRequestError(parseResult.error.message); } @@ -116,12 +90,11 @@ export function tokenHandler({ provider, rateLimit: rateLimitConfig }: TokenHand redirect_uri, resource ? new URL(resource) : undefined ); - res.status(200).json(tokens); - break; + return jsonResponse(tokens, { status: 200, headers: baseHeaders }); } case 'refresh_token': { - const parseResult = RefreshTokenGrantSchema.safeParse(req.body); + const parseResult = RefreshTokenGrantSchema.safeParse(rawBody); if (!parseResult.success) { throw new InvalidRequestError(parseResult.error.message); } @@ -135,8 +108,7 @@ export function tokenHandler({ provider, rateLimit: rateLimitConfig }: TokenHand scopes, resource ? new URL(resource) : undefined ); - res.status(200).json(tokens); - break; + return jsonResponse(tokens, { status: 200, headers: baseHeaders }); } // Additional auth methods will not be added on the server side of the SDK. case 'client_credentials': @@ -146,13 +118,10 @@ export function tokenHandler({ provider, rateLimit: rateLimitConfig }: TokenHand } catch (error) { if (error instanceof OAuthError) { const status = error instanceof ServerError ? 500 : 400; - res.status(status).json(error.toResponseObject()); - } else { - const serverError = new ServerError('Internal Server Error'); - res.status(500).json(serverError.toResponseObject()); + return jsonResponse(error.toResponseObject(), { status, headers: baseHeaders }); } + const serverError = new ServerError('Internal Server Error'); + return jsonResponse(serverError.toResponseObject(), { status: 500, headers: baseHeaders }); } - }); - - return router; + }; } diff --git a/packages/server/src/server/auth/index.ts b/packages/server/src/server/auth/index.ts index 5369224cf..2b176805b 100644 --- a/packages/server/src/server/auth/index.ts +++ b/packages/server/src/server/auth/index.ts @@ -10,3 +10,4 @@ export * from './middleware/clientAuth.js'; export * from './provider.js'; export * from './providers/proxyProvider.js'; export * from './router.js'; +export * from './web.js'; diff --git a/packages/server/src/server/auth/middleware/allowedMethods.ts b/packages/server/src/server/auth/middleware/allowedMethods.ts index 72c076ec4..5c5245690 100644 --- a/packages/server/src/server/auth/middleware/allowedMethods.ts +++ b/packages/server/src/server/auth/middleware/allowedMethods.ts @@ -1,20 +1,20 @@ import { MethodNotAllowedError } from '@modelcontextprotocol/core'; -import type { RequestHandler } from 'express'; + +import { jsonResponse } from '../web.js'; /** - * Middleware to handle unsupported HTTP methods with a 405 Method Not Allowed response. + * Helper to handle unsupported HTTP methods with a 405 Method Not Allowed response. * * @param allowedMethods Array of allowed HTTP methods for this endpoint (e.g., ['GET', 'POST']) - * @returns Express middleware that returns a 405 error if method not in allowed list + * @returns Response if method not in allowed list, otherwise undefined */ -export function allowedMethods(allowedMethods: string[]): RequestHandler { - return (req, res, next) => { - if (allowedMethods.includes(req.method)) { - next(); - return; - } - - const error = new MethodNotAllowedError(`The method ${req.method} is not allowed for this endpoint`); - res.status(405).set('Allow', allowedMethods.join(', ')).json(error.toResponseObject()); - }; +export function allowedMethods(allowedMethods: string[], req: Request): Response | undefined { + if (allowedMethods.includes(req.method)) { + return undefined; + } + const error = new MethodNotAllowedError(`The method ${req.method} is not allowed for this endpoint`); + return jsonResponse(error.toResponseObject(), { + status: 405, + headers: { Allow: allowedMethods.join(', ') } + }); } diff --git a/packages/server/src/server/auth/middleware/bearerAuth.ts b/packages/server/src/server/auth/middleware/bearerAuth.ts index 1a16de1a9..853e400f6 100644 --- a/packages/server/src/server/auth/middleware/bearerAuth.ts +++ b/packages/server/src/server/auth/middleware/bearerAuth.ts @@ -1,8 +1,8 @@ import type { AuthInfo } from '@modelcontextprotocol/core'; import { InsufficientScopeError, InvalidTokenError, OAuthError, ServerError } from '@modelcontextprotocol/core'; -import type { RequestHandler } from 'express'; import type { OAuthTokenVerifier } from '../provider.js'; +import { jsonResponse } from '../web.js'; export type BearerAuthMiddlewareOptions = { /** @@ -21,83 +21,81 @@ export type BearerAuthMiddlewareOptions = { resourceMetadataUrl?: string; }; -declare module 'express-serve-static-core' { - interface Request { - /** - * Information about the validated access token, if the `requireBearerAuth` middleware was used. - */ - auth?: AuthInfo; - } -} - /** - * Middleware that requires a valid Bearer token in the Authorization header. + * Validates a Bearer token in the Authorization header. * - * This will validate the token with the auth provider and add the resulting auth info to the request object. - * - * If resourceMetadataUrl is provided, it will be included in the WWW-Authenticate header - * for 401 responses as per the OAuth 2.0 Protected Resource Metadata spec. + * Returns either `{ authInfo }` on success or `{ response }` on failure. */ -export function requireBearerAuth({ verifier, requiredScopes = [], resourceMetadataUrl }: BearerAuthMiddlewareOptions): RequestHandler { - return async (req, res, next) => { - try { - const authHeader = req.headers.authorization; - if (!authHeader) { - throw new InvalidTokenError('Missing Authorization header'); - } +export async function requireBearerAuth( + req: Request, + { verifier, requiredScopes = [], resourceMetadataUrl }: BearerAuthMiddlewareOptions +): Promise<{ authInfo: AuthInfo } | { response: Response }> { + try { + const authHeader = req.headers.get('authorization'); + if (!authHeader) { + throw new InvalidTokenError('Missing Authorization header'); + } - const [type, token] = authHeader.split(' '); - if (type!.toLowerCase() !== 'bearer' || !token) { - throw new InvalidTokenError("Invalid Authorization header format, expected 'Bearer TOKEN'"); - } + const [type, token] = authHeader.split(' '); + if (type!.toLowerCase() !== 'bearer' || !token) { + throw new InvalidTokenError("Invalid Authorization header format, expected 'Bearer TOKEN'"); + } - const authInfo = await verifier.verifyAccessToken(token); + const authInfo = await verifier.verifyAccessToken(token); - // Check if token has the required scopes (if any) - if (requiredScopes.length > 0) { - const hasAllScopes = requiredScopes.every(scope => authInfo.scopes.includes(scope)); + // Check if token has the required scopes (if any) + if (requiredScopes.length > 0) { + const hasAllScopes = requiredScopes.every(scope => authInfo.scopes.includes(scope)); - if (!hasAllScopes) { - throw new InsufficientScopeError('Insufficient scope'); - } + if (!hasAllScopes) { + throw new InsufficientScopeError('Insufficient scope'); } + } - // Check if the token is set to expire or if it is expired - if (typeof authInfo.expiresAt !== 'number' || isNaN(authInfo.expiresAt)) { - throw new InvalidTokenError('Token has no expiration time'); - } else if (authInfo.expiresAt < Date.now() / 1000) { - throw new InvalidTokenError('Token has expired'); + // Check if the token is set to expire or if it is expired + if (typeof authInfo.expiresAt !== 'number' || isNaN(authInfo.expiresAt)) { + throw new InvalidTokenError('Token has no expiration time'); + } else if (authInfo.expiresAt < Date.now() / 1000) { + throw new InvalidTokenError('Token has expired'); + } + + return { authInfo }; + } catch (error) { + // Build WWW-Authenticate header parts + const buildWwwAuthHeader = (errorCode: string, message: string): string => { + let header = `Bearer error="${errorCode}", error_description="${message}"`; + if (requiredScopes.length > 0) { + header += `, scope="${requiredScopes.join(' ')}"`; + } + if (resourceMetadataUrl) { + header += `, resource_metadata="${resourceMetadataUrl}"`; } + return header; + }; - req.auth = authInfo; - next(); - } catch (error) { - // Build WWW-Authenticate header parts - const buildWwwAuthHeader = (errorCode: string, message: string): string => { - let header = `Bearer error="${errorCode}", error_description="${message}"`; - if (requiredScopes.length > 0) { - header += `, scope="${requiredScopes.join(' ')}"`; - } - if (resourceMetadataUrl) { - header += `, resource_metadata="${resourceMetadataUrl}"`; - } - return header; + if (error instanceof InvalidTokenError) { + return { + response: jsonResponse(error.toResponseObject(), { + status: 401, + headers: { 'WWW-Authenticate': buildWwwAuthHeader(error.errorCode, error.message) } + }) }; - - if (error instanceof InvalidTokenError) { - res.set('WWW-Authenticate', buildWwwAuthHeader(error.errorCode, error.message)); - res.status(401).json(error.toResponseObject()); - } else if (error instanceof InsufficientScopeError) { - res.set('WWW-Authenticate', buildWwwAuthHeader(error.errorCode, error.message)); - res.status(403).json(error.toResponseObject()); - } else if (error instanceof ServerError) { - res.status(500).json(error.toResponseObject()); - } else if (error instanceof OAuthError) { - res.status(400).json(error.toResponseObject()); - } else { - const serverError = new ServerError('Internal Server Error'); - res.status(500).json(serverError.toResponseObject()); - } } - }; + if (error instanceof InsufficientScopeError) { + return { + response: jsonResponse(error.toResponseObject(), { + status: 403, + headers: { 'WWW-Authenticate': buildWwwAuthHeader(error.errorCode, error.message) } + }) + }; + } + if (error instanceof ServerError) { + return { response: jsonResponse(error.toResponseObject(), { status: 500 }) }; + } + if (error instanceof OAuthError) { + return { response: jsonResponse(error.toResponseObject(), { status: 400 }) }; + } + const serverError = new ServerError('Internal Server Error'); + return { response: jsonResponse(serverError.toResponseObject(), { status: 500 }) }; + } } diff --git a/packages/server/src/server/auth/middleware/clientAuth.ts b/packages/server/src/server/auth/middleware/clientAuth.ts index ac4bc8b79..9da271e35 100644 --- a/packages/server/src/server/auth/middleware/clientAuth.ts +++ b/packages/server/src/server/auth/middleware/clientAuth.ts @@ -1,6 +1,5 @@ import type { OAuthClientInformationFull } from '@modelcontextprotocol/core'; -import { InvalidClientError, InvalidRequestError, OAuthError, ServerError } from '@modelcontextprotocol/core'; -import type { RequestHandler } from 'express'; +import { InvalidClientError, InvalidRequestError } from '@modelcontextprotocol/core'; import * as z from 'zod/v4'; import type { OAuthRegisteredClientsStore } from '../clients.js'; @@ -17,49 +16,35 @@ const ClientAuthenticatedRequestSchema = z.object({ client_secret: z.string().optional() }); -declare module 'express-serve-static-core' { - interface Request { - /** - * The authenticated client for this request, if the `authenticateClient` middleware was used. - */ - client?: OAuthClientInformationFull; +/** + * Parses and validates client credentials from a request body, returning the authenticated client. + * + * Throws an OAuthError (or ServerError) on failure. + */ +export async function authenticateClient( + body: unknown, + { clientsStore }: ClientAuthenticationMiddlewareOptions +): Promise { + const result = ClientAuthenticatedRequestSchema.safeParse(body); + if (!result.success) { + throw new InvalidRequestError(String(result.error)); } -} - -export function authenticateClient({ clientsStore }: ClientAuthenticationMiddlewareOptions): RequestHandler { - return async (req, res, next) => { - try { - const result = ClientAuthenticatedRequestSchema.safeParse(req.body); - if (!result.success) { - throw new InvalidRequestError(String(result.error)); - } - const { client_id, client_secret } = result.data; - const client = await clientsStore.getClient(client_id); - if (!client) { - throw new InvalidClientError('Invalid client_id'); - } - if (client.client_secret) { - if (!client_secret) { - throw new InvalidClientError('Client secret is required'); - } - if (client.client_secret !== client_secret) { - throw new InvalidClientError('Invalid client_secret'); - } - if (client.client_secret_expires_at && client.client_secret_expires_at < Math.floor(Date.now() / 1000)) { - throw new InvalidClientError('Client secret has expired'); - } - } - - req.client = client; - next(); - } catch (error) { - if (error instanceof OAuthError) { - const status = error instanceof ServerError ? 500 : 400; - res.status(status).json(error.toResponseObject()); - } else { - const serverError = new ServerError('Internal Server Error'); - res.status(500).json(serverError.toResponseObject()); - } + const { client_id, client_secret } = result.data; + const client = await clientsStore.getClient(client_id); + if (!client) { + throw new InvalidClientError('Invalid client_id'); + } + if (client.client_secret) { + if (!client_secret) { + throw new InvalidClientError('Client secret is required'); } - }; + if (client.client_secret !== client_secret) { + throw new InvalidClientError('Invalid client_secret'); + } + if (client.client_secret_expires_at && client.client_secret_expires_at < Math.floor(Date.now() / 1000)) { + throw new InvalidClientError('Client secret has expired'); + } + } + + return client; } diff --git a/packages/server/src/server/auth/provider.ts b/packages/server/src/server/auth/provider.ts index 6d27fb792..d7dc395d1 100644 --- a/packages/server/src/server/auth/provider.ts +++ b/packages/server/src/server/auth/provider.ts @@ -1,5 +1,4 @@ import type { AuthInfo, OAuthClientInformationFull, OAuthTokenRevocationRequest, OAuthTokens } from '@modelcontextprotocol/core'; -import type { Response } from 'express'; import type { OAuthRegisteredClientsStore } from './clients.js'; @@ -27,7 +26,7 @@ export interface OAuthServerProvider { * - In the successful case, the redirect MUST include the `code` and `state` (if present) query parameters. * - In the error case, the redirect MUST include the `error` query parameter, and MAY include an optional `error_description` query parameter. */ - authorize(client: OAuthClientInformationFull, params: AuthorizationParams, res: Response): Promise; + authorize(client: OAuthClientInformationFull, params: AuthorizationParams): Promise; /** * Returns the `codeChallenge` that was used when the indicated authorization began. diff --git a/packages/server/src/server/auth/providers/proxyProvider.ts b/packages/server/src/server/auth/providers/proxyProvider.ts index 0688754c0..230e8766e 100644 --- a/packages/server/src/server/auth/providers/proxyProvider.ts +++ b/packages/server/src/server/auth/providers/proxyProvider.ts @@ -1,6 +1,5 @@ import type { AuthInfo, FetchLike, OAuthClientInformationFull, OAuthTokenRevocationRequest, OAuthTokens } from '@modelcontextprotocol/core'; import { OAuthClientInformationFullSchema, OAuthTokensSchema, ServerError } from '@modelcontextprotocol/core'; -import type { Response } from 'express'; import type { OAuthRegisteredClientsStore } from '../clients.js'; import type { AuthorizationParams, OAuthServerProvider } from '../provider.js'; @@ -112,7 +111,7 @@ export class ProxyOAuthServerProvider implements OAuthServerProvider { }; } - async authorize(client: OAuthClientInformationFull, params: AuthorizationParams, res: Response): Promise { + async authorize(client: OAuthClientInformationFull, params: AuthorizationParams): Promise { // Start with required OAuth parameters const targetUrl = new URL(this._endpoints.authorizationUrl); const searchParams = new URLSearchParams({ @@ -129,7 +128,7 @@ export class ProxyOAuthServerProvider implements OAuthServerProvider { if (params.resource) searchParams.set('resource', params.resource.href); targetUrl.search = searchParams.toString(); - res.redirect(targetUrl.toString()); + return Response.redirect(targetUrl.toString(), 302); } async challengeForAuthorizationCode(_client: OAuthClientInformationFull, _authorizationCode: string): Promise { diff --git a/packages/server/src/server/auth/router.ts b/packages/server/src/server/auth/router.ts index ba8b030e0..61ed79806 100644 --- a/packages/server/src/server/auth/router.ts +++ b/packages/server/src/server/auth/router.ts @@ -1,6 +1,4 @@ import type { OAuthMetadata, OAuthProtectedResourceMetadata } from '@modelcontextprotocol/core'; -import type { RequestHandler } from 'express'; -import express from 'express'; import type { AuthorizationHandlerOptions } from './handlers/authorize.js'; import { authorizationHandler } from './handlers/authorize.js'; @@ -12,6 +10,7 @@ import { revocationHandler } from './handlers/revoke.js'; import type { TokenHandlerOptions } from './handlers/token.js'; import { tokenHandler } from './handlers/token.js'; import type { OAuthServerProvider } from './provider.js'; +import type { WebHandler } from './web.js'; // Check for dev mode flag that allows HTTP issuer URLs (for development/testing only) const allowInsecureIssuerUrl = @@ -67,6 +66,24 @@ export type AuthRouterOptions = { tokenOptions?: Omit; }; +export type AuthRoute = { + path: string; + methods: string[]; + handler: WebHandler; +}; + +export type WebAuthRouter = { + /** + * List of concrete routes (absolute paths) that should be mounted at the application root. + */ + routes: AuthRoute[]; + + /** + * Convenience dispatcher that matches on `new URL(req.url).pathname` and calls the correct handler. + */ + handle: WebHandler; +}; + const checkIssuerUrl = (issuer: URL): void => { // Technically RFC 8414 does not permit a localhost HTTPS exemption, but this will be necessary for ease of testing if (issuer.protocol !== 'https:' && issuer.hostname !== 'localhost' && issuer.hostname !== '127.0.0.1' && !allowInsecureIssuerUrl) { @@ -124,55 +141,62 @@ export const createOAuthMetadata = (options: { * Installs standard MCP authorization server endpoints, including dynamic client registration and token revocation (if supported). * Also advertises standard authorization server metadata, for easier discovery of supported configurations by clients. * Note: if your MCP server is only a resource server and not an authorization server, use mcpAuthMetadataRouter instead. - * - * By default, rate limiting is applied to all endpoints to prevent abuse. - * - * This router MUST be installed at the application root, like so: - * - * const app = express(); - * app.use(mcpAuthRouter(...)); */ -export function mcpAuthRouter(options: AuthRouterOptions): RequestHandler { +export function mcpAuthRouter(options: AuthRouterOptions): WebAuthRouter { const oauthMetadata = createOAuthMetadata(options); - - const router = express.Router(); - - router.use( - new URL(oauthMetadata.authorization_endpoint).pathname, - authorizationHandler({ provider: options.provider, ...options.authorizationOptions }) - ); - - router.use(new URL(oauthMetadata.token_endpoint).pathname, tokenHandler({ provider: options.provider, ...options.tokenOptions })); - - router.use( - mcpAuthMetadataRouter({ - oauthMetadata, - // Prefer explicit RS; otherwise fall back to AS baseUrl, then to issuer (back-compat) - resourceServerUrl: options.resourceServerUrl ?? options.baseUrl ?? new URL(oauthMetadata.issuer), - serviceDocumentationUrl: options.serviceDocumentationUrl, - scopesSupported: options.scopesSupported, - resourceName: options.resourceName - }) - ); + const routes: AuthRoute[] = []; + + routes.push({ + path: new URL(oauthMetadata.authorization_endpoint).pathname, + methods: ['GET', 'POST'], + handler: authorizationHandler({ provider: options.provider, ...options.authorizationOptions }) + }); + + routes.push({ + path: new URL(oauthMetadata.token_endpoint).pathname, + methods: ['POST', 'OPTIONS'], + handler: tokenHandler({ provider: options.provider, ...options.tokenOptions }) + }); + + const metadataRouter = mcpAuthMetadataRouter({ + oauthMetadata, + // Prefer explicit RS; otherwise fall back to AS baseUrl, then to issuer (back-compat) + resourceServerUrl: options.resourceServerUrl ?? options.baseUrl ?? new URL(oauthMetadata.issuer), + serviceDocumentationUrl: options.serviceDocumentationUrl, + scopesSupported: options.scopesSupported, + resourceName: options.resourceName + }); + routes.push(...metadataRouter.routes); if (oauthMetadata.registration_endpoint) { - router.use( - new URL(oauthMetadata.registration_endpoint).pathname, - clientRegistrationHandler({ + routes.push({ + path: new URL(oauthMetadata.registration_endpoint).pathname, + methods: ['POST', 'OPTIONS'], + handler: clientRegistrationHandler({ clientsStore: options.provider.clientsStore, ...options.clientRegistrationOptions }) - ); + }); } if (oauthMetadata.revocation_endpoint) { - router.use( - new URL(oauthMetadata.revocation_endpoint).pathname, - revocationHandler({ provider: options.provider, ...options.revocationOptions }) - ); + routes.push({ + path: new URL(oauthMetadata.revocation_endpoint).pathname, + methods: ['POST', 'OPTIONS'], + handler: revocationHandler({ provider: options.provider, ...options.revocationOptions }) + }); } - return router; + const handle: WebHandler = async (req, ctx) => { + const pathname = new URL(req.url).pathname; + const route = routes.find(r => r.path === pathname); + if (!route) { + return new Response('Not Found', { status: 404 }); + } + return route.handler(req, ctx); + }; + + return { routes, handle }; } export type AuthMetadataOptions = { @@ -203,11 +227,9 @@ export type AuthMetadataOptions = { resourceName?: string; }; -export function mcpAuthMetadataRouter(options: AuthMetadataOptions): express.Router { +export function mcpAuthMetadataRouter(options: AuthMetadataOptions): WebAuthRouter { checkIssuerUrl(new URL(options.oauthMetadata.issuer)); - const router = express.Router(); - const protectedResourceMetadata: OAuthProtectedResourceMetadata = { resource: options.resourceServerUrl.href, @@ -220,12 +242,24 @@ export function mcpAuthMetadataRouter(options: AuthMetadataOptions): express.Rou // Serve PRM at the path-specific URL per RFC 9728 const rsPath = new URL(options.resourceServerUrl.href).pathname; - router.use(`/.well-known/oauth-protected-resource${rsPath === '/' ? '' : rsPath}`, metadataHandler(protectedResourceMetadata)); - - // Always add this for OAuth Authorization Server metadata per RFC 8414 - router.use('/.well-known/oauth-authorization-server', metadataHandler(options.oauthMetadata)); + const prmPath = `/.well-known/oauth-protected-resource${rsPath === '/' ? '' : rsPath}`; + + const routes: AuthRoute[] = [ + { path: prmPath, methods: ['GET', 'OPTIONS'], handler: metadataHandler(protectedResourceMetadata) }, + // Always add this for OAuth Authorization Server metadata per RFC 8414 + { path: '/.well-known/oauth-authorization-server', methods: ['GET', 'OPTIONS'], handler: metadataHandler(options.oauthMetadata) } + ]; + + const handle: WebHandler = async (req, ctx) => { + const pathname = new URL(req.url).pathname; + const route = routes.find(r => r.path === pathname); + if (!route) { + return new Response('Not Found', { status: 404 }); + } + return route.handler(req, ctx); + }; - return router; + return { routes, handle }; } /** diff --git a/packages/server/src/server/auth/web.ts b/packages/server/src/server/auth/web.ts new file mode 100644 index 000000000..e461e9711 --- /dev/null +++ b/packages/server/src/server/auth/web.ts @@ -0,0 +1,92 @@ +import { MethodNotAllowedError } from '@modelcontextprotocol/core'; + +export type HeaderMap = Record; + +export type WebHandlerContext = { + /** + * Optional pre-parsed request body from an upstream framework. + * If provided, handlers will use this instead of reading from the Request stream. + */ + parsedBody?: unknown; +}; + +export type WebHandler = (req: Request, ctx?: WebHandlerContext) => Promise; + +export function jsonResponse(body: unknown, init?: { status?: number; headers?: HeaderMap }): Response { + const headers: HeaderMap = { 'Content-Type': 'application/json' }; + if (init?.headers) { + Object.assign(headers, init.headers); + } + return new Response(JSON.stringify(body), { + status: init?.status ?? 200, + headers + }); +} + +export function noStoreHeaders(): HeaderMap { + return { 'Cache-Control': 'no-store' }; +} + +export async function getParsedBody(req: Request, ctx?: WebHandlerContext): Promise { + if (ctx?.parsedBody !== undefined) { + return ctx.parsedBody; + } + + const ct = req.headers.get('content-type') ?? ''; + + if (ct.includes('application/json')) { + return await req.json(); + } + + if (ct.includes('application/x-www-form-urlencoded')) { + const text = await req.text(); + return objectFromUrlEncoded(text); + } + + // Empty bodies are treated as empty objects. + const text = await req.text(); + if (!text) return {}; + + // If content-type is missing/unknown, fall back to treating it as urlencoded-like. + return objectFromUrlEncoded(text); +} + +export function objectFromUrlEncoded(body: string): Record { + const params = new URLSearchParams(body); + const out: Record = {}; + for (const [k, v] of params.entries()) out[k] = v; + return out; +} + +export function methodNotAllowedResponse(req: Request, allowed: string[]): Response { + const error = new MethodNotAllowedError(`The method ${req.method} is not allowed for this endpoint`); + return jsonResponse(error.toResponseObject(), { + status: 405, + headers: { Allow: allowed.join(', ') } + }); +} + +export type CorsOptions = { + allowOrigin?: string; + allowMethods: readonly string[]; + allowHeaders?: readonly string[]; + exposeHeaders?: readonly string[]; + maxAgeSeconds?: number; +}; + +export function corsHeaders(options: CorsOptions): HeaderMap { + return { + 'Access-Control-Allow-Origin': options.allowOrigin ?? '*', + 'Access-Control-Allow-Methods': options.allowMethods.join(', '), + 'Access-Control-Allow-Headers': (options.allowHeaders ?? ['Content-Type', 'Authorization']).join(', '), + ...(options.exposeHeaders ? { 'Access-Control-Expose-Headers': options.exposeHeaders.join(', ') } : {}), + ...(options.maxAgeSeconds !== undefined ? { 'Access-Control-Max-Age': String(options.maxAgeSeconds) } : {}) + }; +} + +export function corsPreflightResponse(options: CorsOptions): Response { + return new Response(null, { + status: 204, + headers: corsHeaders(options) + }); +} diff --git a/packages/server/src/server/middleware/hostHeaderValidation.ts b/packages/server/src/server/middleware/hostHeaderValidation.ts index f46178db3..e4d13ecf5 100644 --- a/packages/server/src/server/middleware/hostHeaderValidation.ts +++ b/packages/server/src/server/middleware/hostHeaderValidation.ts @@ -1,79 +1,67 @@ -import type { NextFunction, Request, RequestHandler, Response } from 'express'; +export type HostHeaderValidationResult = + | { ok: true; hostname: string } + | { + ok: false; + errorCode: 'missing_host' | 'invalid_host_header' | 'invalid_host'; + message: string; + hostHeader?: string; + hostname?: string; + }; /** - * Express middleware for DNS rebinding protection. - * Validates Host header hostname (port-agnostic) against an allowed list. + * Parse and validate a Host header against an allowlist of hostnames (port-agnostic). * - * This is particularly important for servers without authorization or HTTPS, - * such as localhost servers or development servers. DNS rebinding attacks can - * bypass same-origin policy by manipulating DNS to point a domain to a - * localhost address, allowing malicious websites to access your local server. - * - * @param allowedHostnames - List of allowed hostnames (without ports). - * For IPv6, provide the address with brackets (e.g., '[::1]'). - * @returns Express middleware function - * - * @example - * ```typescript - * const middleware = hostHeaderValidation(['localhost', '127.0.0.1', '[::1]']); - * app.use(middleware); - * ``` + * - Input host header may include a port (e.g. `localhost:3000`) or IPv6 brackets (e.g. `[::1]:3000`). + * - Allowlist items should be hostnames only (no ports). For IPv6, include brackets (e.g. `[::1]`). */ -export function hostHeaderValidation(allowedHostnames: string[]): RequestHandler { - return (req: Request, res: Response, next: NextFunction) => { - const hostHeader = req.headers.host; - if (!hostHeader) { - res.status(403).json({ - jsonrpc: '2.0', - error: { - code: -32000, - message: 'Missing Host header' - }, - id: null - }); - return; - } +export function validateHostHeader(hostHeader: string | null | undefined, allowedHostnames: string[]): HostHeaderValidationResult { + if (!hostHeader) { + return { ok: false, errorCode: 'missing_host', message: 'Missing Host header' }; + } - // Use URL API to parse hostname (handles IPv4, IPv6, and regular hostnames) - let hostname: string; - try { - hostname = new URL(`http://${hostHeader}`).hostname; - } catch { - res.status(403).json({ - jsonrpc: '2.0', - error: { - code: -32000, - message: `Invalid Host header: ${hostHeader}` - }, - id: null - }); - return; - } + // Use URL API to parse hostname (handles IPv4, IPv6, and regular hostnames) + let hostname: string; + try { + hostname = new URL(`http://${hostHeader}`).hostname; + } catch { + return { ok: false, errorCode: 'invalid_host_header', message: `Invalid Host header: ${hostHeader}`, hostHeader }; + } - if (!allowedHostnames.includes(hostname)) { - res.status(403).json({ - jsonrpc: '2.0', - error: { - code: -32000, - message: `Invalid Host: ${hostname}` - }, - id: null - }); - return; - } - next(); - }; + if (!allowedHostnames.includes(hostname)) { + return { ok: false, errorCode: 'invalid_host', message: `Invalid Host: ${hostname}`, hostHeader, hostname }; + } + + return { ok: true, hostname }; } /** - * Convenience middleware for localhost DNS rebinding protection. - * Allows only localhost, 127.0.0.1, and [::1] (IPv6 localhost) hostnames. - * + * Convenience allowlist for localhost DNS rebinding protection. + */ +export function localhostAllowedHostnames(): string[] { + return ['localhost', '127.0.0.1', '[::1]']; +} + +/** + * Web-standard Request helper for DNS rebinding protection. * @example - * ```typescript - * app.use(localhostHostValidation()); - * ``` + * const result = validateHostHeader(req.headers.get('host'), ['localhost']) */ -export function localhostHostValidation(): RequestHandler { - return hostHeaderValidation(['localhost', '127.0.0.1', '[::1]']); +export function hostHeaderValidationResponse(req: Request, allowedHostnames: string[]): Response | undefined { + const result = validateHostHeader(req.headers.get('host'), allowedHostnames); + if (result.ok) return undefined; + + return new Response( + JSON.stringify({ + jsonrpc: '2.0', + error: { + code: -32000, + message: result.message + }, + id: null + }), + { + status: 403, + headers: { 'Content-Type': 'application/json' } + } + ); } diff --git a/packages/server/src/server/sse.ts b/packages/server/src/server/sse.ts index 4fd0fa1d6..06d418b2d 100644 --- a/packages/server/src/server/sse.ts +++ b/packages/server/src/server/sse.ts @@ -16,24 +16,24 @@ export interface SSEServerTransportOptions { /** * List of allowed host header values for DNS rebinding protection. * If not specified, host validation is disabled. - * @deprecated Use the `hostHeaderValidation` middleware from `@modelcontextprotocol/sdk/server/middleware/hostHeaderValidation.js` instead, - * or use `createMcpExpressApp` from `@modelcontextprotocol/sdk/server/express.js` which includes localhost protection by default. + * @deprecated Use the `hostHeaderValidationResponse` helper from `@modelcontextprotocol/server`, + * or use `createMcpExpressApp` from `@modelcontextprotocol/server-express` which includes localhost protection by default. */ allowedHosts?: string[]; /** * List of allowed origin header values for DNS rebinding protection. * If not specified, origin validation is disabled. - * @deprecated Use the `hostHeaderValidation` middleware from `@modelcontextprotocol/sdk/server/middleware/hostHeaderValidation.js` instead, - * or use `createMcpExpressApp` from `@modelcontextprotocol/sdk/server/express.js` which includes localhost protection by default. + * @deprecated Use the `hostHeaderValidationResponse` helper from `@modelcontextprotocol/server`, + * or use `createMcpExpressApp` from `@modelcontextprotocol/server-express` which includes localhost protection by default. */ allowedOrigins?: string[]; /** * Enable DNS rebinding protection (requires allowedHosts and/or allowedOrigins to be configured). * Default is false for backwards compatibility. - * @deprecated Use the `hostHeaderValidation` middleware from `@modelcontextprotocol/sdk/server/middleware/hostHeaderValidation.js` instead, - * or use `createMcpExpressApp` from `@modelcontextprotocol/sdk/server/express.js` which includes localhost protection by default. + * @deprecated Use the `hostHeaderValidationResponse` helper from `@modelcontextprotocol/server`, + * or use `createMcpExpressApp` from `@modelcontextprotocol/server-express` which includes localhost protection by default. */ enableDnsRebindingProtection?: boolean; } @@ -42,7 +42,7 @@ export interface SSEServerTransportOptions { * Server transport for SSE: this will send messages over an SSE connection and receive messages from HTTP POST requests. * * This transport is only available in Node.js environments. - * @deprecated SSEServerTransport is deprecated. Use StreamableHTTPServerTransport instead. + * @deprecated SSEServerTransport is deprecated. Use NodeStreamableHTTPServerTransport instead. */ export class SSEServerTransport implements Transport { private _sseResponse?: ServerResponse; diff --git a/packages/server/src/server/streamableHttp.ts b/packages/server/src/server/streamableHttp.ts index f9ee07ca8..10a990196 100644 --- a/packages/server/src/server/streamableHttp.ts +++ b/packages/server/src/server/streamableHttp.ts @@ -59,7 +59,7 @@ export type StreamableHTTPServerTransportOptions = WebStandardStreamableHTTPServ * - No Session ID is included in any responses * - No session validation is performed */ -export class StreamableHTTPServerTransport implements Transport { +export class NodeStreamableHTTPServerTransport implements Transport { private _webStandardTransport: WebStandardStreamableHTTPServerTransport; private _requestListener: ReturnType; // Store auth and parsedBody per request for passing through to handleRequest diff --git a/packages/server/src/server/webStandardStreamableHttp.ts b/packages/server/src/server/webStandardStreamableHttp.ts index 082c904e1..73ab30808 100644 --- a/packages/server/src/server/webStandardStreamableHttp.ts +++ b/packages/server/src/server/webStandardStreamableHttp.ts @@ -4,7 +4,7 @@ * This is the core transport implementation using Web Standard APIs (Request, Response, ReadableStream). * It can run on any runtime that supports Web Standards: Node.js 18+, Cloudflare Workers, Deno, Bun, etc. * - * For Node.js Express/HTTP compatibility, use `StreamableHTTPServerTransport` which wraps this transport. + * For Node.js Express/HTTP compatibility, use `NodeStreamableHTTPServerTransport` which wraps this transport. */ import { TextEncoder } from 'node:util'; diff --git a/packages/server/test/server/auth/handlers/authorize.test.ts b/packages/server/test/server/auth/handlers/authorize.test.ts index b84de3bc3..c5943915c 100644 --- a/packages/server/test/server/auth/handlers/authorize.test.ts +++ b/packages/server/test/server/auth/handlers/authorize.test.ts @@ -1,60 +1,41 @@ -import type { AuthInfo, OAuthClientInformationFull, OAuthTokens } from '@modelcontextprotocol/core'; -import { InvalidTokenError } from '@modelcontextprotocol/core'; -import type { Response } from 'express'; -import express from 'express'; -import supertest from 'supertest'; +import type { OAuthClientInformationFull, OAuthTokens } from '@modelcontextprotocol/core'; import type { OAuthRegisteredClientsStore } from '../../../../src/server/auth/clients.js'; -import type { AuthorizationHandlerOptions } from '../../../../src/server/auth/handlers/authorize.js'; import { authorizationHandler } from '../../../../src/server/auth/handlers/authorize.js'; import type { AuthorizationParams, OAuthServerProvider } from '../../../../src/server/auth/provider.js'; -describe('Authorization Handler', () => { - // Mock client data +describe('authorizationHandler (web)', () => { const validClient: OAuthClientInformationFull = { client_id: 'valid-client', client_secret: 'valid-secret', - redirect_uris: ['https://example.com/callback'], - scope: 'profile email' + redirect_uris: ['https://example.com/callback'] }; const multiRedirectClient: OAuthClientInformationFull = { client_id: 'multi-redirect-client', client_secret: 'valid-secret', - redirect_uris: ['https://example.com/callback1', 'https://example.com/callback2'], - scope: 'profile email' + redirect_uris: ['https://example.com/callback1', 'https://example.com/callback2'] }; - // Mock client store - const mockClientStore: OAuthRegisteredClientsStore = { - async getClient(clientId: string): Promise { - if (clientId === 'valid-client') { - return validClient; - } else if (clientId === 'multi-redirect-client') { - return multiRedirectClient; - } + const clientsStore: OAuthRegisteredClientsStore = { + async getClient(clientId: string) { + if (clientId === 'valid-client') return validClient; + if (clientId === 'multi-redirect-client') return multiRedirectClient; return undefined; } }; - // Mock provider - const mockProvider: OAuthServerProvider = { - clientsStore: mockClientStore, - - async authorize(client: OAuthClientInformationFull, params: AuthorizationParams, res: Response): Promise { - // Mock implementation - redirects to redirectUri with code and state - const redirectUrl = new URL(params.redirectUri); - redirectUrl.searchParams.set('code', 'mock_auth_code'); - if (params.state) { - redirectUrl.searchParams.set('state', params.state); - } - res.redirect(302, redirectUrl.toString()); + const provider: OAuthServerProvider = { + clientsStore, + async authorize(_client, params: AuthorizationParams): Promise { + const u = new URL(params.redirectUri); + u.searchParams.set('code', 'mock_auth_code'); + if (params.state) u.searchParams.set('state', params.state); + return Response.redirect(u.toString(), 302); }, - async challengeForAuthorizationCode(): Promise { return 'mock_challenge'; }, - async exchangeAuthorizationCode(): Promise { return { access_token: 'mock_access_token', @@ -63,7 +44,6 @@ describe('Authorization Handler', () => { refresh_token: 'mock_refresh_token' }; }, - async exchangeRefreshToken(): Promise { return { access_token: 'new_mock_access_token', @@ -72,225 +52,53 @@ describe('Authorization Handler', () => { refresh_token: 'new_mock_refresh_token' }; }, - - async verifyAccessToken(token: string): Promise { - if (token === 'valid_token') { - return { - token, - clientId: 'valid-client', - scopes: ['read', 'write'], - expiresAt: Date.now() / 1000 + 3600 - }; - } - throw new InvalidTokenError('Token is invalid or expired'); - }, - - async revokeToken(): Promise { - // Do nothing in mock + async verifyAccessToken() { + throw new Error('not used'); } }; - // Setup express app with handler - let app: express.Express; - let options: AuthorizationHandlerOptions; - - beforeEach(() => { - app = express(); - options = { provider: mockProvider }; - const handler = authorizationHandler(options); - app.use('/authorize', handler); + it('returns 405 for unsupported methods', async () => { + const handler = authorizationHandler({ provider }); + const res = await handler(new Request('http://localhost/authorize', { method: 'PUT' })); + expect(res.status).toBe(405); }); - describe('HTTP method validation', () => { - it('rejects non-GET/POST methods', async () => { - const response = await supertest(app).put('/authorize').query({ client_id: 'valid-client' }); - - expect(response.status).toBe(405); // Method not allowed response from handler - }); + it('returns 400 if client does not exist', async () => { + const handler = authorizationHandler({ provider }); + const res = await handler( + new Request( + 'http://localhost/authorize?client_id=missing&response_type=code&code_challenge=x&code_challenge_method=S256&redirect_uri=https%3A%2F%2Fexample.com%2Fcallback', + { method: 'GET' } + ) + ); + expect(res.status).toBe(400); + expect(await res.json()).toEqual(expect.objectContaining({ error: 'invalid_client' })); }); - describe('Client validation', () => { - it('requires client_id parameter', async () => { - const response = await supertest(app).get('/authorize'); - - expect(response.status).toBe(400); - expect(response.text).toContain('client_id'); - }); - - it('validates that client exists', async () => { - const response = await supertest(app).get('/authorize').query({ client_id: 'nonexistent-client' }); - - expect(response.status).toBe(400); - }); + it('redirects with a code on valid request (single redirect_uri inferred)', async () => { + const handler = authorizationHandler({ provider }); + const res = await handler( + new Request( + 'http://localhost/authorize?client_id=valid-client&response_type=code&code_challenge=challenge123&code_challenge_method=S256', + { method: 'GET' } + ) + ); + expect(res.status).toBe(302); + const location = res.headers.get('location')!; + expect(location).toContain('https://example.com/callback'); + expect(location).toContain('code=mock_auth_code'); + expect(res.headers.get('cache-control')).toBe('no-store'); }); - describe('Redirect URI validation', () => { - it('uses the only redirect_uri if client has just one and none provided', async () => { - const response = await supertest(app).get('/authorize').query({ - client_id: 'valid-client', - response_type: 'code', - code_challenge: 'challenge123', - code_challenge_method: 'S256' - }); - - expect(response.status).toBe(302); - const location = new URL(response.header.location!); - expect(location.origin + location.pathname).toBe('https://example.com/callback'); - }); - - it('requires redirect_uri if client has multiple', async () => { - const response = await supertest(app).get('/authorize').query({ - client_id: 'multi-redirect-client', - response_type: 'code', - code_challenge: 'challenge123', - code_challenge_method: 'S256' - }); - - expect(response.status).toBe(400); - }); - - it('validates redirect_uri against client registered URIs', async () => { - const response = await supertest(app).get('/authorize').query({ - client_id: 'valid-client', - redirect_uri: 'https://malicious.com/callback', - response_type: 'code', - code_challenge: 'challenge123', - code_challenge_method: 'S256' - }); - - expect(response.status).toBe(400); - }); - - it('accepts valid redirect_uri that client registered with', async () => { - const response = await supertest(app).get('/authorize').query({ - client_id: 'valid-client', - redirect_uri: 'https://example.com/callback', - response_type: 'code', - code_challenge: 'challenge123', - code_challenge_method: 'S256' - }); - - expect(response.status).toBe(302); - const location = new URL(response.header.location!); - expect(location.origin + location.pathname).toBe('https://example.com/callback'); - }); - }); - - describe('Authorization request validation', () => { - it('requires response_type=code', async () => { - const response = await supertest(app).get('/authorize').query({ - client_id: 'valid-client', - redirect_uri: 'https://example.com/callback', - response_type: 'token', // invalid - we only support code flow - code_challenge: 'challenge123', - code_challenge_method: 'S256' - }); - - expect(response.status).toBe(302); - const location = new URL(response.header.location!); - expect(location.searchParams.get('error')).toBe('invalid_request'); - }); - - it('requires code_challenge parameter', async () => { - const response = await supertest(app).get('/authorize').query({ - client_id: 'valid-client', - redirect_uri: 'https://example.com/callback', - response_type: 'code', - code_challenge_method: 'S256' - // Missing code_challenge - }); - - expect(response.status).toBe(302); - const location = new URL(response.header.location!); - expect(location.searchParams.get('error')).toBe('invalid_request'); - }); - - it('requires code_challenge_method=S256', async () => { - const response = await supertest(app).get('/authorize').query({ - client_id: 'valid-client', - redirect_uri: 'https://example.com/callback', - response_type: 'code', - code_challenge: 'challenge123', - code_challenge_method: 'plain' // Only S256 is supported - }); - - expect(response.status).toBe(302); - const location = new URL(response.header.location!); - expect(location.searchParams.get('error')).toBe('invalid_request'); - }); - }); - - describe('Resource parameter validation', () => { - it('propagates resource parameter', async () => { - const mockProviderWithResource = vi.spyOn(mockProvider, 'authorize'); - - const response = await supertest(app).get('/authorize').query({ - client_id: 'valid-client', - redirect_uri: 'https://example.com/callback', - response_type: 'code', - code_challenge: 'challenge123', - code_challenge_method: 'S256', - resource: 'https://api.example.com/resource' - }); - - expect(response.status).toBe(302); - expect(mockProviderWithResource).toHaveBeenCalledWith( - validClient, - expect.objectContaining({ - resource: new URL('https://api.example.com/resource'), - redirectUri: 'https://example.com/callback', - codeChallenge: 'challenge123' - }), - expect.any(Object) - ); - }); - }); - - describe('Successful authorization', () => { - it('handles successful authorization with all parameters', async () => { - const response = await supertest(app).get('/authorize').query({ - client_id: 'valid-client', - redirect_uri: 'https://example.com/callback', - response_type: 'code', - code_challenge: 'challenge123', - code_challenge_method: 'S256', - scope: 'profile email', - state: 'xyz789' - }); - - expect(response.status).toBe(302); - const location = new URL(response.header.location!); - expect(location.origin + location.pathname).toBe('https://example.com/callback'); - expect(location.searchParams.get('code')).toBe('mock_auth_code'); - expect(location.searchParams.get('state')).toBe('xyz789'); - }); - - it('preserves state parameter in response', async () => { - const response = await supertest(app).get('/authorize').query({ - client_id: 'valid-client', - redirect_uri: 'https://example.com/callback', - response_type: 'code', - code_challenge: 'challenge123', - code_challenge_method: 'S256', - state: 'state-value-123' - }); - - expect(response.status).toBe(302); - const location = new URL(response.header.location!); - expect(location.searchParams.get('state')).toBe('state-value-123'); - }); - - it('handles POST requests the same as GET', async () => { - const response = await supertest(app).post('/authorize').type('form').send({ - client_id: 'valid-client', - response_type: 'code', - code_challenge: 'challenge123', - code_challenge_method: 'S256' - }); - - expect(response.status).toBe(302); - const location = new URL(response.header.location!); - expect(location.searchParams.has('code')).toBe(true); - }); + it('requires redirect_uri if client has multiple redirect URIs', async () => { + const handler = authorizationHandler({ provider }); + const res = await handler( + new Request( + 'http://localhost/authorize?client_id=multi-redirect-client&response_type=code&code_challenge=challenge123&code_challenge_method=S256', + { method: 'GET' } + ) + ); + expect(res.status).toBe(400); + expect(await res.json()).toEqual(expect.objectContaining({ error: 'invalid_request' })); }); }); diff --git a/packages/server/test/server/auth/handlers/metadata.test.ts b/packages/server/test/server/auth/handlers/metadata.test.ts index 0dc51e51d..722320925 100644 --- a/packages/server/test/server/auth/handlers/metadata.test.ts +++ b/packages/server/test/server/auth/handlers/metadata.test.ts @@ -1,6 +1,4 @@ import type { OAuthMetadata } from '@modelcontextprotocol/core'; -import express from 'express'; -import supertest from 'supertest'; import { metadataHandler } from '../../../../src/server/auth/handlers/metadata.js'; @@ -18,62 +16,65 @@ describe('Metadata Handler', () => { code_challenge_methods_supported: ['S256'] }; - let app: express.Express; - - beforeEach(() => { - // Setup express app with metadata handler - app = express(); - app.use('/.well-known/oauth-authorization-server', metadataHandler(exampleMetadata)); - }); - it('requires GET method', async () => { - const response = await supertest(app).post('/.well-known/oauth-authorization-server').send({}); + const handler = metadataHandler(exampleMetadata); + const res = await handler(new Request('http://localhost/.well-known/oauth-authorization-server', { method: 'POST' })); - expect(response.status).toBe(405); - expect(response.headers.allow).toBe('GET, OPTIONS'); - expect(response.body).toEqual({ + expect(res.status).toBe(405); + expect(res.headers.get('allow')).toBe('GET, OPTIONS'); + expect(await res.json()).toEqual({ error: 'method_not_allowed', error_description: 'The method POST is not allowed for this endpoint' }); }); it('returns the metadata object', async () => { - const response = await supertest(app).get('/.well-known/oauth-authorization-server'); + const handler = metadataHandler(exampleMetadata); + const res = await handler(new Request('http://localhost/.well-known/oauth-authorization-server', { method: 'GET' })); - expect(response.status).toBe(200); - expect(response.body).toEqual(exampleMetadata); + expect(res.status).toBe(200); + expect(await res.json()).toEqual(exampleMetadata); }); it('includes CORS headers in response', async () => { - const response = await supertest(app).get('/.well-known/oauth-authorization-server').set('Origin', 'https://example.com'); + const handler = metadataHandler(exampleMetadata); + const res = await handler( + new Request('http://localhost/.well-known/oauth-authorization-server', { + method: 'GET', + headers: { Origin: 'https://example.com' } + }) + ); - expect(response.header['access-control-allow-origin']).toBe('*'); + expect(res.headers.get('access-control-allow-origin')).toBe('*'); }); it('supports OPTIONS preflight requests', async () => { - const response = await supertest(app) - .options('/.well-known/oauth-authorization-server') - .set('Origin', 'https://example.com') - .set('Access-Control-Request-Method', 'GET'); + const handler = metadataHandler(exampleMetadata); + const res = await handler( + new Request('http://localhost/.well-known/oauth-authorization-server', { + method: 'OPTIONS', + headers: { + Origin: 'https://example.com', + 'Access-Control-Request-Method': 'GET' + } + }) + ); - expect(response.status).toBe(204); - expect(response.header['access-control-allow-origin']).toBe('*'); + expect(res.status).toBe(204); + expect(res.headers.get('access-control-allow-origin')).toBe('*'); }); it('works with minimal metadata', async () => { - // Setup a new express app with minimal metadata - const minimalApp = express(); const minimalMetadata: OAuthMetadata = { issuer: 'https://auth.example.com', authorization_endpoint: 'https://auth.example.com/authorize', token_endpoint: 'https://auth.example.com/token', response_types_supported: ['code'] }; - minimalApp.use('/.well-known/oauth-authorization-server', metadataHandler(minimalMetadata)); - - const response = await supertest(minimalApp).get('/.well-known/oauth-authorization-server'); + const handler = metadataHandler(minimalMetadata); + const res = await handler(new Request('http://localhost/.well-known/oauth-authorization-server', { method: 'GET' })); - expect(response.status).toBe(200); - expect(response.body).toEqual(minimalMetadata); + expect(res.status).toBe(200); + expect(await res.json()).toEqual(minimalMetadata); }); }); diff --git a/packages/server/test/server/auth/handlers/register.test.ts b/packages/server/test/server/auth/handlers/register.test.ts index b10e048ed..6a2ffcd11 100644 --- a/packages/server/test/server/auth/handlers/register.test.ts +++ b/packages/server/test/server/auth/handlers/register.test.ts @@ -1,274 +1,39 @@ -import type { OAuthClientInformationFull, OAuthClientMetadata } from '@modelcontextprotocol/core'; -import express from 'express'; -import supertest from 'supertest'; -import type { MockInstance } from 'vitest'; +import type { OAuthClientInformationFull } from '@modelcontextprotocol/core'; import type { OAuthRegisteredClientsStore } from '../../../../src/server/auth/clients.js'; -import type { ClientRegistrationHandlerOptions } from '../../../../src/server/auth/handlers/register.js'; import { clientRegistrationHandler } from '../../../../src/server/auth/handlers/register.js'; -describe('Client Registration Handler', () => { - // Mock client store with registration support - const mockClientStoreWithRegistration: OAuthRegisteredClientsStore = { - async getClient(_clientId: string): Promise { - return undefined; - }, - - async registerClient(client: OAuthClientInformationFull): Promise { - // Return the client info as-is in the mock - return client; - } - }; - - // Mock client store without registration support - const mockClientStoreWithoutRegistration: OAuthRegisteredClientsStore = { - async getClient(_clientId: string): Promise { - return undefined; - } - // No registerClient method - }; - - describe('Handler creation', () => { - it('throws error if client store does not support registration', () => { - const options: ClientRegistrationHandlerOptions = { - clientsStore: mockClientStoreWithoutRegistration - }; - - expect(() => clientRegistrationHandler(options)).toThrow('does not support registering clients'); - }); - - it('creates handler if client store supports registration', () => { - const options: ClientRegistrationHandlerOptions = { - clientsStore: mockClientStoreWithRegistration - }; - - expect(() => clientRegistrationHandler(options)).not.toThrow(); - }); - }); - - describe('Request handling', () => { - let app: express.Express; - let spyRegisterClient: MockInstance; - - beforeEach(() => { - // Setup express app with registration handler - app = express(); - const options: ClientRegistrationHandlerOptions = { - clientsStore: mockClientStoreWithRegistration, - clientSecretExpirySeconds: 86400 // 1 day for testing - }; - - app.use('/register', clientRegistrationHandler(options)); - - // Spy on the registerClient method - spyRegisterClient = vi.spyOn(mockClientStoreWithRegistration, 'registerClient'); - }); - - afterEach(() => { - spyRegisterClient.mockRestore(); - }); - - it('requires POST method', async () => { - const response = await supertest(app) - .get('/register') - .send({ - redirect_uris: ['https://example.com/callback'] - }); - - expect(response.status).toBe(405); - expect(response.headers.allow).toBe('POST'); - expect(response.body).toEqual({ - error: 'method_not_allowed', - error_description: 'The method GET is not allowed for this endpoint' - }); - expect(spyRegisterClient).not.toHaveBeenCalled(); - }); - - it('validates required client metadata', async () => { - const response = await supertest(app).post('/register').send({ - // Missing redirect_uris (required) - client_name: 'Test Client' - }); - - expect(response.status).toBe(400); - expect(response.body.error).toBe('invalid_client_metadata'); - expect(spyRegisterClient).not.toHaveBeenCalled(); - }); - - it('validates redirect URIs format', async () => { - const response = await supertest(app) - .post('/register') - .send({ - redirect_uris: ['invalid-url'] // Invalid URL format - }); - - expect(response.status).toBe(400); - expect(response.body.error).toBe('invalid_client_metadata'); - expect(response.body.error_description).toContain('redirect_uris'); - expect(spyRegisterClient).not.toHaveBeenCalled(); - }); - - it('successfully registers client with minimal metadata', async () => { - const clientMetadata: OAuthClientMetadata = { - redirect_uris: ['https://example.com/callback'] - }; - - const response = await supertest(app).post('/register').send(clientMetadata); - - expect(response.status).toBe(201); - - // Verify the generated client information - expect(response.body.client_id).toBeDefined(); - expect(response.body.client_secret).toBeDefined(); - expect(response.body.client_id_issued_at).toBeDefined(); - expect(response.body.client_secret_expires_at).toBeDefined(); - expect(response.body.redirect_uris).toEqual(['https://example.com/callback']); - - // Verify client was registered - expect(spyRegisterClient).toHaveBeenCalledTimes(1); - }); - - it('sets client_secret to undefined for token_endpoint_auth_method=none', async () => { - const clientMetadata: OAuthClientMetadata = { - redirect_uris: ['https://example.com/callback'], - token_endpoint_auth_method: 'none' - }; - - const response = await supertest(app).post('/register').send(clientMetadata); - - expect(response.status).toBe(201); - expect(response.body.client_secret).toBeUndefined(); - expect(response.body.client_secret_expires_at).toBeUndefined(); - }); - - it('sets client_secret_expires_at for public clients only', async () => { - // Test for public client (token_endpoint_auth_method not 'none') - const publicClientMetadata: OAuthClientMetadata = { - redirect_uris: ['https://example.com/callback'], - token_endpoint_auth_method: 'client_secret_basic' - }; - - const publicResponse = await supertest(app).post('/register').send(publicClientMetadata); - - expect(publicResponse.status).toBe(201); - expect(publicResponse.body.client_secret).toBeDefined(); - expect(publicResponse.body.client_secret_expires_at).toBeDefined(); - - // Test for non-public client (token_endpoint_auth_method is 'none') - const nonPublicClientMetadata: OAuthClientMetadata = { - redirect_uris: ['https://example.com/callback'], - token_endpoint_auth_method: 'none' - }; - - const nonPublicResponse = await supertest(app).post('/register').send(nonPublicClientMetadata); - - expect(nonPublicResponse.status).toBe(201); - expect(nonPublicResponse.body.client_secret).toBeUndefined(); - expect(nonPublicResponse.body.client_secret_expires_at).toBeUndefined(); - }); - - it('sets expiry based on clientSecretExpirySeconds', async () => { - // Create handler with custom expiry time - const customApp = express(); - const options: ClientRegistrationHandlerOptions = { - clientsStore: mockClientStoreWithRegistration, - clientSecretExpirySeconds: 3600 // 1 hour - }; - - customApp.use('/register', clientRegistrationHandler(options)); - - const response = await supertest(customApp) - .post('/register') - .send({ - redirect_uris: ['https://example.com/callback'] - }); - - expect(response.status).toBe(201); - - // Verify the expiration time (~1 hour from now) - const issuedAt = response.body.client_id_issued_at; - const expiresAt = response.body.client_secret_expires_at; - expect(expiresAt - issuedAt).toBe(3600); - }); - - it('sets no expiry when clientSecretExpirySeconds=0', async () => { - // Create handler with no expiry - const customApp = express(); - const options: ClientRegistrationHandlerOptions = { - clientsStore: mockClientStoreWithRegistration, - clientSecretExpirySeconds: 0 // No expiry - }; - - customApp.use('/register', clientRegistrationHandler(options)); - - const response = await supertest(customApp) - .post('/register') - .send({ - redirect_uris: ['https://example.com/callback'] - }); - - expect(response.status).toBe(201); - expect(response.body.client_secret_expires_at).toBe(0); - }); - - it('sets no client_id when clientIdGeneration=false', async () => { - // Create handler with no expiry - const customApp = express(); - const options: ClientRegistrationHandlerOptions = { - clientsStore: mockClientStoreWithRegistration, - clientIdGeneration: false - }; - - customApp.use('/register', clientRegistrationHandler(options)); - - const response = await supertest(customApp) - .post('/register') - .send({ - redirect_uris: ['https://example.com/callback'] - }); - - expect(response.status).toBe(201); - expect(response.body.client_id).toBeUndefined(); - expect(response.body.client_id_issued_at).toBeUndefined(); - }); - - it('handles client with all metadata fields', async () => { - const fullClientMetadata: OAuthClientMetadata = { - redirect_uris: ['https://example.com/callback'], - token_endpoint_auth_method: 'client_secret_basic', - grant_types: ['authorization_code', 'refresh_token'], - response_types: ['code'], - client_name: 'Test Client', - client_uri: 'https://example.com', - logo_uri: 'https://example.com/logo.png', - scope: 'profile email', - contacts: ['dev@example.com'], - tos_uri: 'https://example.com/tos', - policy_uri: 'https://example.com/privacy', - jwks_uri: 'https://example.com/jwks', - software_id: 'test-software', - software_version: '1.0.0' - }; - - const response = await supertest(app).post('/register').send(fullClientMetadata); - - expect(response.status).toBe(201); - - // Verify all metadata was preserved - Object.entries(fullClientMetadata).forEach(([key, value]) => { - expect(response.body[key]).toEqual(value); - }); - }); - - it('includes CORS headers in response', async () => { - const response = await supertest(app) - .post('/register') - .set('Origin', 'https://example.com') - .send({ +describe('clientRegistrationHandler (web)', () => { + it('returns 201 and client info when registration is supported', async () => { + const clientsStore: OAuthRegisteredClientsStore = { + async getClient() { + return undefined; + }, + async registerClient(client: Omit) { + // In real implementation, server may generate ids; here return minimal. + return { + ...client, + client_id: 'generated-client', + client_id_issued_at: Math.floor(Date.now() / 1000), + redirect_uris: (client as any).redirect_uris ?? [] + } as unknown as OAuthClientInformationFull; + } + }; + + const handler = clientRegistrationHandler({ clientsStore }); + + const res = await handler( + new Request('http://localhost/register', { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ redirect_uris: ['https://example.com/callback'] - }); + }) + }) + ); - expect(response.header['access-control-allow-origin']).toBe('*'); - }); + expect(res.status).toBe(201); + const body = (await res.json()) as { client_id?: string }; + expect(body.client_id).toBeDefined(); }); }); diff --git a/packages/server/test/server/auth/handlers/revoke.test.ts b/packages/server/test/server/auth/handlers/revoke.test.ts index 61ff51b24..d960f26c3 100644 --- a/packages/server/test/server/auth/handlers/revoke.test.ts +++ b/packages/server/test/server/auth/handlers/revoke.test.ts @@ -1,233 +1,72 @@ -import type { AuthInfo, OAuthClientInformationFull, OAuthTokenRevocationRequest, OAuthTokens } from '@modelcontextprotocol/core'; -import { InvalidTokenError } from '@modelcontextprotocol/core'; -import type { Response } from 'express'; -import express from 'express'; -import supertest from 'supertest'; -import type { MockInstance } from 'vitest'; +import type { OAuthClientInformationFull, OAuthTokenRevocationRequest, OAuthTokens } from '@modelcontextprotocol/core'; import type { OAuthRegisteredClientsStore } from '../../../../src/server/auth/clients.js'; -import type { RevocationHandlerOptions } from '../../../../src/server/auth/handlers/revoke.js'; import { revocationHandler } from '../../../../src/server/auth/handlers/revoke.js'; import type { AuthorizationParams, OAuthServerProvider } from '../../../../src/server/auth/provider.js'; -describe('Revocation Handler', () => { - // Mock client data - const validClient: OAuthClientInformationFull = { - client_id: 'valid-client', - client_secret: 'valid-secret', - redirect_uris: ['https://example.com/callback'] - }; - - // Mock client store - const mockClientStore: OAuthRegisteredClientsStore = { - async getClient(clientId: string): Promise { - if (clientId === 'valid-client') { - return validClient; +describe('revocationHandler (web)', () => { + it('returns 200 on successful revocation', async () => { + const validClient: OAuthClientInformationFull = { + client_id: 'valid-client', + client_secret: 'valid-secret', + redirect_uris: ['https://example.com/callback'] + }; + + const clientsStore: OAuthRegisteredClientsStore = { + async getClient(clientId: string) { + return clientId === 'valid-client' ? validClient : undefined; } - return undefined; - } - }; - - // Mock provider with revocation capability - const mockProviderWithRevocation: OAuthServerProvider = { - clientsStore: mockClientStore, - - async authorize(client: OAuthClientInformationFull, params: AuthorizationParams, res: Response): Promise { - res.redirect('https://example.com/callback?code=mock_auth_code'); - }, - - async challengeForAuthorizationCode(): Promise { - return 'mock_challenge'; - }, - - async exchangeAuthorizationCode(): Promise { - return { - access_token: 'mock_access_token', - token_type: 'bearer', - expires_in: 3600, - refresh_token: 'mock_refresh_token' - }; - }, - - async exchangeRefreshToken(): Promise { - return { - access_token: 'new_mock_access_token', - token_type: 'bearer', - expires_in: 3600, - refresh_token: 'new_mock_refresh_token' - }; - }, - - async verifyAccessToken(token: string): Promise { - if (token === 'valid_token') { + }; + + const provider: OAuthServerProvider = { + clientsStore, + async authorize(_client: OAuthClientInformationFull, _params: AuthorizationParams): Promise { + return Response.redirect('https://example.com', 302); + }, + async challengeForAuthorizationCode(): Promise { + return 'mock'; + }, + async exchangeAuthorizationCode(): Promise { return { - token, - clientId: 'valid-client', - scopes: ['read', 'write'], - expiresAt: Date.now() / 1000 + 3600 + access_token: 'mock_access_token', + token_type: 'bearer', + expires_in: 3600, + refresh_token: 'mock_refresh_token' }; - } - throw new InvalidTokenError('Token is invalid or expired'); - }, - - async revokeToken(_client: OAuthClientInformationFull, _request: OAuthTokenRevocationRequest): Promise { - // Success - do nothing in mock - } - }; - - // Mock provider without revocation capability - const mockProviderWithoutRevocation: OAuthServerProvider = { - clientsStore: mockClientStore, - - async authorize(client: OAuthClientInformationFull, params: AuthorizationParams, res: Response): Promise { - res.redirect('https://example.com/callback?code=mock_auth_code'); - }, - - async challengeForAuthorizationCode(): Promise { - return 'mock_challenge'; - }, - - async exchangeAuthorizationCode(): Promise { - return { - access_token: 'mock_access_token', - token_type: 'bearer', - expires_in: 3600, - refresh_token: 'mock_refresh_token' - }; - }, - - async exchangeRefreshToken(): Promise { - return { - access_token: 'new_mock_access_token', - token_type: 'bearer', - expires_in: 3600, - refresh_token: 'new_mock_refresh_token' - }; - }, - - async verifyAccessToken(token: string): Promise { - if (token === 'valid_token') { + }, + async exchangeRefreshToken(): Promise { return { - token, - clientId: 'valid-client', - scopes: ['read', 'write'], - expiresAt: Date.now() / 1000 + 3600 + access_token: 'new_mock_access_token', + token_type: 'bearer', + expires_in: 3600, + refresh_token: 'new_mock_refresh_token' }; + }, + async verifyAccessToken() { + throw new Error('not used'); + }, + async revokeToken(_client: OAuthClientInformationFull, _request: OAuthTokenRevocationRequest): Promise { + // ok } - throw new InvalidTokenError('Token is invalid or expired'); - } - - // No revokeToken method - }; - - describe('Handler creation', () => { - it('throws error if provider does not support token revocation', () => { - const options: RevocationHandlerOptions = { provider: mockProviderWithoutRevocation }; - expect(() => revocationHandler(options)).toThrow('does not support revoking tokens'); - }); - - it('creates handler if provider supports token revocation', () => { - const options: RevocationHandlerOptions = { provider: mockProviderWithRevocation }; - expect(() => revocationHandler(options)).not.toThrow(); - }); - }); - - describe('Request handling', () => { - let app: express.Express; - let spyRevokeToken: MockInstance; - - beforeEach(() => { - // Setup express app with revocation handler - app = express(); - const options: RevocationHandlerOptions = { provider: mockProviderWithRevocation }; - app.use('/revoke', revocationHandler(options)); - - // Spy on the revokeToken method - spyRevokeToken = vi.spyOn(mockProviderWithRevocation, 'revokeToken'); - }); - - afterEach(() => { - spyRevokeToken.mockRestore(); - }); - - it('requires POST method', async () => { - const response = await supertest(app).get('/revoke').send({ - client_id: 'valid-client', - client_secret: 'valid-secret', - token: 'token_to_revoke' - }); - - expect(response.status).toBe(405); - expect(response.headers.allow).toBe('POST'); - expect(response.body).toEqual({ - error: 'method_not_allowed', - error_description: 'The method GET is not allowed for this endpoint' - }); - expect(spyRevokeToken).not.toHaveBeenCalled(); - }); - - it('requires token parameter', async () => { - const response = await supertest(app).post('/revoke').type('form').send({ - client_id: 'valid-client', - client_secret: 'valid-secret' - // Missing token - }); - - expect(response.status).toBe(400); - expect(response.body.error).toBe('invalid_request'); - expect(spyRevokeToken).not.toHaveBeenCalled(); - }); - - it('authenticates client before revoking token', async () => { - const response = await supertest(app).post('/revoke').type('form').send({ - client_id: 'invalid-client', - client_secret: 'wrong-secret', - token: 'token_to_revoke' - }); - - expect(response.status).toBe(400); - expect(response.body.error).toBe('invalid_client'); - expect(spyRevokeToken).not.toHaveBeenCalled(); - }); - - it('successfully revokes token', async () => { - const response = await supertest(app).post('/revoke').type('form').send({ - client_id: 'valid-client', - client_secret: 'valid-secret', - token: 'token_to_revoke' - }); - - expect(response.status).toBe(200); - expect(response.body).toEqual({}); // Empty response on success - expect(spyRevokeToken).toHaveBeenCalledTimes(1); - expect(spyRevokeToken).toHaveBeenCalledWith(validClient, { - token: 'token_to_revoke' - }); - }); - - it('accepts optional token_type_hint', async () => { - const response = await supertest(app).post('/revoke').type('form').send({ - client_id: 'valid-client', - client_secret: 'valid-secret', - token: 'token_to_revoke', - token_type_hint: 'refresh_token' - }); - - expect(response.status).toBe(200); - expect(spyRevokeToken).toHaveBeenCalledWith(validClient, { - token: 'token_to_revoke', - token_type_hint: 'refresh_token' - }); - }); - - it('includes CORS headers in response', async () => { - const response = await supertest(app).post('/revoke').type('form').set('Origin', 'https://example.com').send({ - client_id: 'valid-client', - client_secret: 'valid-secret', - token: 'token_to_revoke' - }); - - expect(response.header['access-control-allow-origin']).toBe('*'); - }); + }; + + const handler = revocationHandler({ provider }); + + const body = new URLSearchParams({ + client_id: 'valid-client', + client_secret: 'valid-secret', + token: 'token_to_revoke' + }).toString(); + + const res = await handler( + new Request('http://localhost/revoke', { + method: 'POST', + headers: { 'Content-Type': 'application/x-www-form-urlencoded' }, + body + }) + ); + + expect(res.status).toBe(200); + expect(await res.json()).toEqual({}); }); }); diff --git a/packages/server/test/server/auth/handlers/token.test.ts b/packages/server/test/server/auth/handlers/token.test.ts index 02eab891f..d99e4b39a 100644 --- a/packages/server/test/server/auth/handlers/token.test.ts +++ b/packages/server/test/server/auth/handlers/token.test.ts @@ -1,481 +1,116 @@ -import type { AuthInfo, OAuthClientInformationFull, OAuthTokenRevocationRequest, OAuthTokens } from '@modelcontextprotocol/core'; -import { InvalidGrantError, InvalidTokenError } from '@modelcontextprotocol/core'; -import type { Response } from 'express'; -import express from 'express'; +import type { AuthInfo, OAuthClientInformationFull, OAuthTokens } from '@modelcontextprotocol/core'; +import { InvalidGrantError } from '@modelcontextprotocol/core'; import * as pkceChallenge from 'pkce-challenge'; -import supertest from 'supertest'; -import { type Mock } from 'vitest'; import type { OAuthRegisteredClientsStore } from '../../../../src/server/auth/clients.js'; -import type { TokenHandlerOptions } from '../../../../src/server/auth/handlers/token.js'; import { tokenHandler } from '../../../../src/server/auth/handlers/token.js'; import type { AuthorizationParams, OAuthServerProvider } from '../../../../src/server/auth/provider.js'; -import { ProxyOAuthServerProvider } from '../../../../src/server/auth/providers/proxyProvider.js'; -// Mock pkce-challenge vi.mock('pkce-challenge', () => ({ - verifyChallenge: vi.fn().mockImplementation(async (verifier, challenge) => { - return verifier === 'valid_verifier' && challenge === 'mock_challenge'; - }) + verifyChallenge: vi.fn() })); -const mockTokens = { - access_token: 'mock_access_token', - token_type: 'bearer', - expires_in: 3600, - refresh_token: 'mock_refresh_token' -}; - -const mockTokensWithIdToken = { - ...mockTokens, - id_token: 'mock_id_token' -}; - -describe('Token Handler', () => { - // Mock client data +describe('tokenHandler (web)', () => { const validClient: OAuthClientInformationFull = { client_id: 'valid-client', client_secret: 'valid-secret', redirect_uris: ['https://example.com/callback'] }; - // Mock client store - const mockClientStore: OAuthRegisteredClientsStore = { - async getClient(clientId: string): Promise { - if (clientId === 'valid-client') { - return validClient; - } - return undefined; + const clientsStore: OAuthRegisteredClientsStore = { + async getClient(clientId: string) { + return clientId === 'valid-client' ? validClient : undefined; } }; - // Mock provider - let mockProvider: OAuthServerProvider; - let app: express.Express; - - beforeEach(() => { - // Create fresh mocks for each test - mockProvider = { - clientsStore: mockClientStore, - - async authorize(client: OAuthClientInformationFull, params: AuthorizationParams, res: Response): Promise { - res.redirect('https://example.com/callback?code=mock_auth_code'); - }, - - async challengeForAuthorizationCode(client: OAuthClientInformationFull, authorizationCode: string): Promise { - if (authorizationCode === 'valid_code') { - return 'mock_challenge'; - } else if (authorizationCode === 'expired_code') { - throw new InvalidGrantError('The authorization code has expired'); - } - throw new InvalidGrantError('The authorization code is invalid'); - }, - - async exchangeAuthorizationCode(client: OAuthClientInformationFull, authorizationCode: string): Promise { - if (authorizationCode === 'valid_code') { - return mockTokens; - } - throw new InvalidGrantError('The authorization code is invalid or has expired'); - }, - - async exchangeRefreshToken(client: OAuthClientInformationFull, refreshToken: string, scopes?: string[]): Promise { - if (refreshToken === 'valid_refresh_token') { - const response: OAuthTokens = { - access_token: 'new_mock_access_token', - token_type: 'bearer', - expires_in: 3600, - refresh_token: 'new_mock_refresh_token' - }; - - if (scopes) { - response.scope = scopes.join(' '); - } - - return response; - } - throw new InvalidGrantError('The refresh token is invalid or has expired'); - }, - - async verifyAccessToken(token: string): Promise { - if (token === 'valid_token') { - return { - token, - clientId: 'valid-client', - scopes: ['read', 'write'], - expiresAt: Date.now() / 1000 + 3600 - }; - } - throw new InvalidTokenError('Token is invalid or expired'); - }, - - async revokeToken(_client: OAuthClientInformationFull, _request: OAuthTokenRevocationRequest): Promise { - // Do nothing in mock - } - }; - - // Mock PKCE verification - (pkceChallenge.verifyChallenge as Mock).mockImplementation(async (verifier: string, challenge: string) => { - return verifier === 'valid_verifier' && challenge === 'mock_challenge'; - }); - - // Setup express app with token handler - app = express(); - const options: TokenHandlerOptions = { provider: mockProvider }; - app.use('/token', tokenHandler(options)); - }); - - describe('Basic request validation', () => { - it('requires POST method', async () => { - const response = await supertest(app).get('/token').send({ - client_id: 'valid-client', - client_secret: 'valid-secret', - grant_type: 'authorization_code' - }); - - expect(response.status).toBe(405); - expect(response.headers.allow).toBe('POST'); - expect(response.body).toEqual({ - error: 'method_not_allowed', - error_description: 'The method GET is not allowed for this endpoint' - }); - }); - - it('requires grant_type parameter', async () => { - const response = await supertest(app).post('/token').type('form').send({ - client_id: 'valid-client', - client_secret: 'valid-secret' - // Missing grant_type - }); - - expect(response.status).toBe(400); - expect(response.body.error).toBe('invalid_request'); - }); - - it('rejects unsupported grant types', async () => { - const response = await supertest(app).post('/token').type('form').send({ - client_id: 'valid-client', - client_secret: 'valid-secret', - grant_type: 'password' // Unsupported grant type - }); - - expect(response.status).toBe(400); - expect(response.body.error).toBe('unsupported_grant_type'); - }); - }); - - describe('Client authentication', () => { - it('requires valid client credentials', async () => { - const response = await supertest(app).post('/token').type('form').send({ - client_id: 'invalid-client', - client_secret: 'wrong-secret', - grant_type: 'authorization_code' - }); - - expect(response.status).toBe(400); - expect(response.body.error).toBe('invalid_client'); - }); - - it('accepts valid client credentials', async () => { - const response = await supertest(app).post('/token').type('form').send({ - client_id: 'valid-client', - client_secret: 'valid-secret', - grant_type: 'authorization_code', - code: 'valid_code', - code_verifier: 'valid_verifier' - }); - - expect(response.status).toBe(200); - }); - }); - - describe('Authorization code grant', () => { - it('requires code parameter', async () => { - const response = await supertest(app).post('/token').type('form').send({ - client_id: 'valid-client', - client_secret: 'valid-secret', - grant_type: 'authorization_code', - // Missing code - code_verifier: 'valid_verifier' - }); - - expect(response.status).toBe(400); - expect(response.body.error).toBe('invalid_request'); - }); - - it('requires code_verifier parameter', async () => { - const response = await supertest(app).post('/token').type('form').send({ - client_id: 'valid-client', - client_secret: 'valid-secret', - grant_type: 'authorization_code', - code: 'valid_code' - // Missing code_verifier - }); - - expect(response.status).toBe(400); - expect(response.body.error).toBe('invalid_request'); - }); - - it('verifies code_verifier against challenge', async () => { - // Setup invalid verifier - (pkceChallenge.verifyChallenge as Mock).mockResolvedValueOnce(false); - - const response = await supertest(app).post('/token').type('form').send({ - client_id: 'valid-client', - client_secret: 'valid-secret', - grant_type: 'authorization_code', - code: 'valid_code', - code_verifier: 'invalid_verifier' - }); - - expect(response.status).toBe(400); - expect(response.body.error).toBe('invalid_grant'); - expect(response.body.error_description).toContain('code_verifier'); - }); - - it('rejects expired or invalid authorization codes', async () => { - const response = await supertest(app).post('/token').type('form').send({ - client_id: 'valid-client', - client_secret: 'valid-secret', - grant_type: 'authorization_code', - code: 'expired_code', - code_verifier: 'valid_verifier' - }); - - expect(response.status).toBe(400); - expect(response.body.error).toBe('invalid_grant'); - }); - - it('returns tokens for valid code exchange', async () => { - const mockExchangeCode = vi.spyOn(mockProvider, 'exchangeAuthorizationCode'); - const response = await supertest(app).post('/token').type('form').send({ - client_id: 'valid-client', - client_secret: 'valid-secret', - resource: 'https://api.example.com/resource', - grant_type: 'authorization_code', - code: 'valid_code', - code_verifier: 'valid_verifier' - }); - - expect(response.status).toBe(200); - expect(response.body.access_token).toBe('mock_access_token'); - expect(response.body.token_type).toBe('bearer'); - expect(response.body.expires_in).toBe(3600); - expect(response.body.refresh_token).toBe('mock_refresh_token'); - expect(mockExchangeCode).toHaveBeenCalledWith( - validClient, - 'valid_code', - undefined, // code_verifier is undefined after PKCE validation - undefined, // redirect_uri - new URL('https://api.example.com/resource') // resource parameter - ); - }); - - it('returns id token in code exchange if provided', async () => { - mockProvider.exchangeAuthorizationCode = async ( - client: OAuthClientInformationFull, - authorizationCode: string - ): Promise => { - if (authorizationCode === 'valid_code') { - return mockTokensWithIdToken; - } - throw new InvalidGrantError('The authorization code is invalid or has expired'); + const provider: OAuthServerProvider = { + clientsStore, + async authorize(_client: OAuthClientInformationFull, _params: AuthorizationParams): Promise { + return Response.redirect('https://example.com/callback?code=mock_auth_code', 302); + }, + async challengeForAuthorizationCode(_client: OAuthClientInformationFull, authorizationCode: string): Promise { + if (authorizationCode === 'valid_code') return 'mock_challenge'; + throw new InvalidGrantError('The authorization code is invalid'); + }, + async exchangeAuthorizationCode(): Promise { + return { + access_token: 'mock_access_token', + token_type: 'bearer', + expires_in: 3600, + refresh_token: 'mock_refresh_token' }; + }, + async exchangeRefreshToken(): Promise { + return { + access_token: 'new_mock_access_token', + token_type: 'bearer', + expires_in: 3600, + refresh_token: 'new_mock_refresh_token' + }; + }, + async verifyAccessToken(token: string): Promise { + return { + token, + clientId: 'valid-client', + scopes: [], + expiresAt: Math.floor(Date.now() / 1000) + 3600 + }; + } + }; - const response = await supertest(app).post('/token').type('form').send({ - client_id: 'valid-client', - client_secret: 'valid-secret', - grant_type: 'authorization_code', - code: 'valid_code', - code_verifier: 'valid_verifier' - }); - - expect(response.status).toBe(200); - expect(response.body.id_token).toBe('mock_id_token'); - }); - - it('passes through code verifier when using proxy provider', async () => { - const originalFetch = global.fetch; - - try { - global.fetch = vi.fn().mockResolvedValue({ - ok: true, - json: () => Promise.resolve(mockTokens) - }); - - const proxyProvider = new ProxyOAuthServerProvider({ - endpoints: { - authorizationUrl: 'https://example.com/authorize', - tokenUrl: 'https://example.com/token' - }, - verifyAccessToken: async token => ({ - token, - clientId: 'valid-client', - scopes: ['read', 'write'], - expiresAt: Date.now() / 1000 + 3600 - }), - getClient: async clientId => (clientId === 'valid-client' ? validClient : undefined) - }); - - const proxyApp = express(); - const options: TokenHandlerOptions = { provider: proxyProvider }; - proxyApp.use('/token', tokenHandler(options)); - - const response = await supertest(proxyApp).post('/token').type('form').send({ - client_id: 'valid-client', - client_secret: 'valid-secret', - grant_type: 'authorization_code', - code: 'valid_code', - code_verifier: 'any_verifier', - redirect_uri: 'https://example.com/callback' - }); - - expect(response.status).toBe(200); - expect(response.body.access_token).toBe('mock_access_token'); - - expect(global.fetch).toHaveBeenCalledWith( - 'https://example.com/token', - expect.objectContaining({ - method: 'POST', - headers: { - 'Content-Type': 'application/x-www-form-urlencoded' - }, - body: expect.stringContaining('code_verifier=any_verifier') - }) - ); - } finally { - global.fetch = originalFetch; - } - }); - - it('passes through redirect_uri when using proxy provider', async () => { - const originalFetch = global.fetch; - - try { - global.fetch = vi.fn().mockResolvedValue({ - ok: true, - json: () => Promise.resolve(mockTokens) - }); - - const proxyProvider = new ProxyOAuthServerProvider({ - endpoints: { - authorizationUrl: 'https://example.com/authorize', - tokenUrl: 'https://example.com/token' - }, - verifyAccessToken: async token => ({ - token, - clientId: 'valid-client', - scopes: ['read', 'write'], - expiresAt: Date.now() / 1000 + 3600 - }), - getClient: async clientId => (clientId === 'valid-client' ? validClient : undefined) - }); - - const proxyApp = express(); - const options: TokenHandlerOptions = { provider: proxyProvider }; - proxyApp.use('/token', tokenHandler(options)); - - const redirectUri = 'https://example.com/callback'; - const response = await supertest(proxyApp).post('/token').type('form').send({ - client_id: 'valid-client', - client_secret: 'valid-secret', - grant_type: 'authorization_code', - code: 'valid_code', - code_verifier: 'any_verifier', - redirect_uri: redirectUri - }); - - expect(response.status).toBe(200); - expect(response.body.access_token).toBe('mock_access_token'); - - expect(global.fetch).toHaveBeenCalledWith( - 'https://example.com/token', - expect.objectContaining({ - method: 'POST', - headers: { - 'Content-Type': 'application/x-www-form-urlencoded' - }, - body: expect.stringContaining(`redirect_uri=${encodeURIComponent(redirectUri)}`) - }) - ); - } finally { - global.fetch = originalFetch; - } - }); + beforeEach(() => { + vi.clearAllMocks(); }); - describe('Refresh token grant', () => { - it('requires refresh_token parameter', async () => { - const response = await supertest(app).post('/token').type('form').send({ - client_id: 'valid-client', - client_secret: 'valid-secret', - grant_type: 'refresh_token' - // Missing refresh_token - }); - - expect(response.status).toBe(400); - expect(response.body.error).toBe('invalid_request'); - }); - - it('rejects invalid refresh tokens', async () => { - const response = await supertest(app).post('/token').type('form').send({ - client_id: 'valid-client', - client_secret: 'valid-secret', - grant_type: 'refresh_token', - refresh_token: 'invalid_refresh_token' - }); - - expect(response.status).toBe(400); - expect(response.body.error).toBe('invalid_grant'); - }); - - it('returns new tokens for valid refresh token', async () => { - const mockExchangeRefresh = vi.spyOn(mockProvider, 'exchangeRefreshToken'); - const response = await supertest(app).post('/token').type('form').send({ - client_id: 'valid-client', - client_secret: 'valid-secret', - resource: 'https://api.example.com/resource', - grant_type: 'refresh_token', - refresh_token: 'valid_refresh_token' - }); - - expect(response.status).toBe(200); - expect(response.body.access_token).toBe('new_mock_access_token'); - expect(response.body.token_type).toBe('bearer'); - expect(response.body.expires_in).toBe(3600); - expect(response.body.refresh_token).toBe('new_mock_refresh_token'); - expect(mockExchangeRefresh).toHaveBeenCalledWith( - validClient, - 'valid_refresh_token', - undefined, // scopes - new URL('https://api.example.com/resource') // resource parameter - ); - }); - - it('respects requested scopes on refresh', async () => { - const response = await supertest(app).post('/token').type('form').send({ - client_id: 'valid-client', - client_secret: 'valid-secret', - grant_type: 'refresh_token', - refresh_token: 'valid_refresh_token', - scope: 'profile email' - }); - - expect(response.status).toBe(200); - expect(response.body.scope).toBe('profile email'); - }); + it('returns tokens for authorization_code grant when PKCE passes', async () => { + (pkceChallenge.verifyChallenge as unknown as ReturnType).mockResolvedValue(true); + const handler = tokenHandler({ provider }); + + const body = new URLSearchParams({ + client_id: 'valid-client', + client_secret: 'valid-secret', + grant_type: 'authorization_code', + code: 'valid_code', + code_verifier: 'valid_verifier' + }).toString(); + + const res = await handler( + new Request('http://localhost/token', { + method: 'POST', + headers: { 'Content-Type': 'application/x-www-form-urlencoded' }, + body + }) + ); + + expect(res.status).toBe(200); + expect(await res.json()).toEqual( + expect.objectContaining({ + access_token: 'mock_access_token' + }) + ); }); - describe('CORS support', () => { - it('includes CORS headers in response', async () => { - const response = await supertest(app).post('/token').type('form').set('Origin', 'https://example.com').send({ - client_id: 'valid-client', - client_secret: 'valid-secret', - grant_type: 'authorization_code', - code: 'valid_code', - code_verifier: 'valid_verifier' - }); - - expect(response.header['access-control-allow-origin']).toBe('*'); - }); + it('returns 400 when PKCE fails', async () => { + (pkceChallenge.verifyChallenge as unknown as ReturnType).mockResolvedValue(false); + const handler = tokenHandler({ provider }); + + const body = new URLSearchParams({ + client_id: 'valid-client', + client_secret: 'valid-secret', + grant_type: 'authorization_code', + code: 'valid_code', + code_verifier: 'bad_verifier' + }).toString(); + + const res = await handler( + new Request('http://localhost/token', { + method: 'POST', + headers: { 'Content-Type': 'application/x-www-form-urlencoded' }, + body + }) + ); + + expect(res.status).toBe(400); + expect(await res.json()).toEqual(expect.objectContaining({ error: 'invalid_grant' })); }); }); diff --git a/packages/server/test/server/auth/middleware/allowedMethods.test.ts b/packages/server/test/server/auth/middleware/allowedMethods.test.ts index 40e9c3b1f..3cea847a1 100644 --- a/packages/server/test/server/auth/middleware/allowedMethods.test.ts +++ b/packages/server/test/server/auth/middleware/allowedMethods.test.ts @@ -1,77 +1,29 @@ -import type { Request, Response } from 'express'; -import express from 'express'; -import request from 'supertest'; - import { allowedMethods } from '../../../../src/server/auth/middleware/allowedMethods.js'; describe('allowedMethods', () => { - let app: express.Express; - - beforeEach(() => { - app = express(); - - // Set up a test router with a GET handler and 405 middleware - const router = express.Router(); - - router.get('/test', (req, res) => { - res.status(200).send('GET success'); - }); - - // Add method not allowed middleware for all other methods - router.all('/test', allowedMethods(['GET'])); - - app.use(router); - }); - - test('allows specified HTTP method', async () => { - const response = await request(app).get('/test'); - expect(response.status).toBe(200); - expect(response.text).toBe('GET success'); - }); - - test('returns 405 for unspecified HTTP methods', async () => { - const methods = ['post', 'put', 'delete', 'patch']; - - for (const method of methods) { - // @ts-expect-error - dynamic method call - const response = await request(app)[method]('/test'); - expect(response.status).toBe(405); - expect(response.body).toEqual({ - error: 'method_not_allowed', - error_description: `The method ${method.toUpperCase()} is not allowed for this endpoint` - }); - } + test('returns undefined for allowed HTTP method', () => { + const req = new Request('http://localhost/test', { method: 'GET' }); + const res = allowedMethods(['GET'], req); + expect(res).toBeUndefined(); }); - test('includes Allow header with specified methods', async () => { - const response = await request(app).post('/test'); - expect(response.headers.allow).toBe('GET'); - }); - - test('works with multiple allowed methods', async () => { - const multiMethodApp = express(); - const router = express.Router(); - - router.get('/multi', (req: Request, res: Response) => { - res.status(200).send('GET'); + test('returns 405 response for disallowed HTTP method', async () => { + const req = new Request('http://localhost/test', { method: 'POST' }); + const res = allowedMethods(['GET'], req); + expect(res).toBeDefined(); + expect(res!.status).toBe(405); + expect(res!.headers.get('allow')).toBe('GET'); + expect(await res!.json()).toEqual({ + error: 'method_not_allowed', + error_description: 'The method POST is not allowed for this endpoint' }); - router.post('/multi', (req: Request, res: Response) => { - res.status(200).send('POST'); - }); - router.all('/multi', allowedMethods(['GET', 'POST'])); - - multiMethodApp.use(router); - - // Allowed methods should work - const getResponse = await request(multiMethodApp).get('/multi'); - expect(getResponse.status).toBe(200); - - const postResponse = await request(multiMethodApp).post('/multi'); - expect(postResponse.status).toBe(200); + }); - // Unallowed methods should return 405 - const putResponse = await request(multiMethodApp).put('/multi'); - expect(putResponse.status).toBe(405); - expect(putResponse.headers.allow).toBe('GET, POST'); + test('supports multiple allowed methods', async () => { + const req = new Request('http://localhost/test', { method: 'PUT' }); + const res = allowedMethods(['GET', 'POST'], req); + expect(res).toBeDefined(); + expect(res!.status).toBe(405); + expect(res!.headers.get('allow')).toBe('GET, POST'); }); }); diff --git a/packages/server/test/server/auth/middleware/bearerAuth.test.ts b/packages/server/test/server/auth/middleware/bearerAuth.test.ts index 7b464bbff..9b0ead3ef 100644 --- a/packages/server/test/server/auth/middleware/bearerAuth.test.ts +++ b/packages/server/test/server/auth/middleware/bearerAuth.test.ts @@ -1,502 +1,118 @@ import type { AuthInfo } from '@modelcontextprotocol/core'; -import { CustomOAuthError, InsufficientScopeError, InvalidTokenError, ServerError } from '@modelcontextprotocol/core'; -import { createExpressResponseMock } from '@modelcontextprotocol/test-helpers'; -import type { Request, Response } from 'express'; -import type { Mock } from 'vitest'; +import { InsufficientScopeError, InvalidTokenError, ServerError } from '@modelcontextprotocol/core'; import { requireBearerAuth } from '../../../../src/server/auth/middleware/bearerAuth.js'; import type { OAuthTokenVerifier } from '../../../../src/server/auth/provider.js'; -// Mock verifier -const mockVerifyAccessToken = vi.fn(); -const mockVerifier: OAuthTokenVerifier = { - verifyAccessToken: mockVerifyAccessToken -}; - -describe('requireBearerAuth middleware', () => { - let mockRequest: Partial; - let mockResponse: Partial; - let nextFunction: Mock; +describe('requireBearerAuth (web)', () => { + const verifyAccessToken = vi.fn(); + const verifier: OAuthTokenVerifier = { verifyAccessToken }; beforeEach(() => { - mockRequest = { - headers: {} - }; - mockResponse = createExpressResponseMock(); - nextFunction = vi.fn(); - vi.spyOn(console, 'error').mockImplementation(() => {}); - }); - - afterEach(() => { vi.clearAllMocks(); }); - it('should call next when token is valid', async () => { - const validAuthInfo: AuthInfo = { - token: 'valid-token', - clientId: 'client-123', - scopes: ['read', 'write'], - expiresAt: Math.floor(Date.now() / 1000) + 3600 // Token expires in an hour - }; - mockVerifyAccessToken.mockResolvedValue(validAuthInfo); - - mockRequest.headers = { - authorization: 'Bearer valid-token' - }; - - const middleware = requireBearerAuth({ verifier: mockVerifier }); - await middleware(mockRequest as Request, mockResponse as Response, nextFunction); - - expect(mockVerifyAccessToken).toHaveBeenCalledWith('valid-token'); - expect(mockRequest.auth).toEqual(validAuthInfo); - expect(nextFunction).toHaveBeenCalled(); - expect(mockResponse.status).not.toHaveBeenCalled(); - expect(mockResponse.json).not.toHaveBeenCalled(); - }); - - it.each([ - [100], // Token expired 100 seconds ago - [0] // Token expires at the same time as now - ])('should reject expired tokens (expired %s seconds ago)', async (expiredSecondsAgo: number) => { - const expiresAt = Math.floor(Date.now() / 1000) - expiredSecondsAgo; - const expiredAuthInfo: AuthInfo = { - token: 'expired-token', - clientId: 'client-123', - scopes: ['read', 'write'], - expiresAt - }; - mockVerifyAccessToken.mockResolvedValue(expiredAuthInfo); - - mockRequest.headers = { - authorization: 'Bearer expired-token' - }; - - const middleware = requireBearerAuth({ verifier: mockVerifier }); - await middleware(mockRequest as Request, mockResponse as Response, nextFunction); - - expect(mockVerifyAccessToken).toHaveBeenCalledWith('expired-token'); - expect(mockResponse.status).toHaveBeenCalledWith(401); - expect(mockResponse.set).toHaveBeenCalledWith('WWW-Authenticate', expect.stringContaining('Bearer error="invalid_token"')); - expect(mockResponse.json).toHaveBeenCalledWith( - expect.objectContaining({ error: 'invalid_token', error_description: 'Token has expired' }) - ); - expect(nextFunction).not.toHaveBeenCalled(); - }); - - it.each([ - [undefined], // Token has no expiration time - [NaN] // Token has no expiration time - ])('should reject tokens with no expiration time (expiresAt: %s)', async (expiresAt: number | undefined) => { - const noExpirationAuthInfo: AuthInfo = { - token: 'no-expiration-token', - clientId: 'client-123', - scopes: ['read', 'write'], - expiresAt - }; - mockVerifyAccessToken.mockResolvedValue(noExpirationAuthInfo); - - mockRequest.headers = { - authorization: 'Bearer expired-token' - }; - - const middleware = requireBearerAuth({ verifier: mockVerifier }); - await middleware(mockRequest as Request, mockResponse as Response, nextFunction); - - expect(mockVerifyAccessToken).toHaveBeenCalledWith('expired-token'); - expect(mockResponse.status).toHaveBeenCalledWith(401); - expect(mockResponse.set).toHaveBeenCalledWith('WWW-Authenticate', expect.stringContaining('Bearer error="invalid_token"')); - expect(mockResponse.json).toHaveBeenCalledWith( - expect.objectContaining({ error: 'invalid_token', error_description: 'Token has no expiration time' }) - ); - expect(nextFunction).not.toHaveBeenCalled(); - }); - - it('should accept non-expired tokens', async () => { - const nonExpiredAuthInfo: AuthInfo = { - token: 'valid-token', - clientId: 'client-123', - scopes: ['read', 'write'], - expiresAt: Math.floor(Date.now() / 1000) + 3600 // Token expires in an hour - }; - mockVerifyAccessToken.mockResolvedValue(nonExpiredAuthInfo); - - mockRequest.headers = { - authorization: 'Bearer valid-token' - }; - - const middleware = requireBearerAuth({ verifier: mockVerifier }); - await middleware(mockRequest as Request, mockResponse as Response, nextFunction); - - expect(mockVerifyAccessToken).toHaveBeenCalledWith('valid-token'); - expect(mockRequest.auth).toEqual(nonExpiredAuthInfo); - expect(nextFunction).toHaveBeenCalled(); - expect(mockResponse.status).not.toHaveBeenCalled(); - expect(mockResponse.json).not.toHaveBeenCalled(); - }); - - it('should require specific scopes when configured', async () => { - const authInfo: AuthInfo = { - token: 'valid-token', - clientId: 'client-123', - scopes: ['read'] - }; - mockVerifyAccessToken.mockResolvedValue(authInfo); - - mockRequest.headers = { - authorization: 'Bearer valid-token' - }; - - const middleware = requireBearerAuth({ - verifier: mockVerifier, - requiredScopes: ['read', 'write'] - }); - - await middleware(mockRequest as Request, mockResponse as Response, nextFunction); - - expect(mockVerifyAccessToken).toHaveBeenCalledWith('valid-token'); - expect(mockResponse.status).toHaveBeenCalledWith(403); - expect(mockResponse.set).toHaveBeenCalledWith('WWW-Authenticate', expect.stringContaining('Bearer error="insufficient_scope"')); - expect(mockResponse.json).toHaveBeenCalledWith( - expect.objectContaining({ error: 'insufficient_scope', error_description: 'Insufficient scope' }) - ); - expect(nextFunction).not.toHaveBeenCalled(); - }); - - it('should accept token with all required scopes', async () => { - const authInfo: AuthInfo = { - token: 'valid-token', - clientId: 'client-123', - scopes: ['read', 'write', 'admin'], - expiresAt: Math.floor(Date.now() / 1000) + 3600 // Token expires in an hour - }; - mockVerifyAccessToken.mockResolvedValue(authInfo); - - mockRequest.headers = { - authorization: 'Bearer valid-token' - }; - - const middleware = requireBearerAuth({ - verifier: mockVerifier, - requiredScopes: ['read', 'write'] - }); - - await middleware(mockRequest as Request, mockResponse as Response, nextFunction); - - expect(mockVerifyAccessToken).toHaveBeenCalledWith('valid-token'); - expect(mockRequest.auth).toEqual(authInfo); - expect(nextFunction).toHaveBeenCalled(); - expect(mockResponse.status).not.toHaveBeenCalled(); - expect(mockResponse.json).not.toHaveBeenCalled(); - }); - - it('should return 401 when no Authorization header is present', async () => { - const middleware = requireBearerAuth({ verifier: mockVerifier }); - await middleware(mockRequest as Request, mockResponse as Response, nextFunction); - - expect(mockVerifyAccessToken).not.toHaveBeenCalled(); - expect(mockResponse.status).toHaveBeenCalledWith(401); - expect(mockResponse.set).toHaveBeenCalledWith('WWW-Authenticate', expect.stringContaining('Bearer error="invalid_token"')); - expect(mockResponse.json).toHaveBeenCalledWith( - expect.objectContaining({ error: 'invalid_token', error_description: 'Missing Authorization header' }) - ); - expect(nextFunction).not.toHaveBeenCalled(); - }); - - it('should return 401 when Authorization header format is invalid', async () => { - mockRequest.headers = { - authorization: 'InvalidFormat' + it('returns authInfo on success', async () => { + const info: AuthInfo = { + token: 't', + clientId: 'c', + scopes: ['read'], + expiresAt: Math.floor(Date.now() / 1000) + 3600 }; + verifyAccessToken.mockResolvedValue(info); - const middleware = requireBearerAuth({ verifier: mockVerifier }); - await middleware(mockRequest as Request, mockResponse as Response, nextFunction); + const req = new Request('http://localhost/x', { headers: { Authorization: 'Bearer t' } }); + const result = await requireBearerAuth(req, { verifier }); - expect(mockVerifyAccessToken).not.toHaveBeenCalled(); - expect(mockResponse.status).toHaveBeenCalledWith(401); - expect(mockResponse.set).toHaveBeenCalledWith('WWW-Authenticate', expect.stringContaining('Bearer error="invalid_token"')); - expect(mockResponse.json).toHaveBeenCalledWith( - expect.objectContaining({ - error: 'invalid_token', - error_description: "Invalid Authorization header format, expected 'Bearer TOKEN'" - }) - ); - expect(nextFunction).not.toHaveBeenCalled(); + expect('authInfo' in result).toBe(true); + if ('authInfo' in result) { + expect(result.authInfo).toEqual(info); + } }); - it('should return 401 when token verification fails with InvalidTokenError', async () => { - mockRequest.headers = { - authorization: 'Bearer invalid-token' - }; - - mockVerifyAccessToken.mockRejectedValue(new InvalidTokenError('Token expired')); - - const middleware = requireBearerAuth({ verifier: mockVerifier }); - await middleware(mockRequest as Request, mockResponse as Response, nextFunction); - - expect(mockVerifyAccessToken).toHaveBeenCalledWith('invalid-token'); - expect(mockResponse.status).toHaveBeenCalledWith(401); - expect(mockResponse.set).toHaveBeenCalledWith('WWW-Authenticate', expect.stringContaining('Bearer error="invalid_token"')); - expect(mockResponse.json).toHaveBeenCalledWith( - expect.objectContaining({ error: 'invalid_token', error_description: 'Token expired' }) - ); - expect(nextFunction).not.toHaveBeenCalled(); - }); - - it('should return 403 when access token has insufficient scopes', async () => { - mockRequest.headers = { - authorization: 'Bearer valid-token' - }; + it('returns 401 when missing Authorization header', async () => { + const req = new Request('http://localhost/x'); + const result = await requireBearerAuth(req, { verifier }); - mockVerifyAccessToken.mockRejectedValue(new InsufficientScopeError('Required scopes: read, write')); - - const middleware = requireBearerAuth({ verifier: mockVerifier }); - await middleware(mockRequest as Request, mockResponse as Response, nextFunction); - - expect(mockVerifyAccessToken).toHaveBeenCalledWith('valid-token'); - expect(mockResponse.status).toHaveBeenCalledWith(403); - expect(mockResponse.set).toHaveBeenCalledWith('WWW-Authenticate', expect.stringContaining('Bearer error="insufficient_scope"')); - expect(mockResponse.json).toHaveBeenCalledWith( - expect.objectContaining({ error: 'insufficient_scope', error_description: 'Required scopes: read, write' }) - ); - expect(nextFunction).not.toHaveBeenCalled(); + expect('response' in result).toBe(true); + if ('response' in result) { + expect(result.response.status).toBe(401); + expect(result.response.headers.get('www-authenticate')).toContain('Bearer error="invalid_token"'); + expect(await result.response.json()).toEqual( + expect.objectContaining({ error: 'invalid_token', error_description: 'Missing Authorization header' }) + ); + } }); - it('should return 500 when a ServerError occurs', async () => { - mockRequest.headers = { - authorization: 'Bearer valid-token' - }; - - mockVerifyAccessToken.mockRejectedValue(new ServerError('Internal server issue')); - - const middleware = requireBearerAuth({ verifier: mockVerifier }); - await middleware(mockRequest as Request, mockResponse as Response, nextFunction); + it('returns 401 when verifier throws InvalidTokenError', async () => { + verifyAccessToken.mockRejectedValue(new InvalidTokenError('bad')); + const req = new Request('http://localhost/x', { headers: { Authorization: 'Bearer t' } }); + const result = await requireBearerAuth(req, { verifier }); - expect(mockVerifyAccessToken).toHaveBeenCalledWith('valid-token'); - expect(mockResponse.status).toHaveBeenCalledWith(500); - expect(mockResponse.json).toHaveBeenCalledWith( - expect.objectContaining({ error: 'server_error', error_description: 'Internal server issue' }) - ); - expect(nextFunction).not.toHaveBeenCalled(); + expect('response' in result).toBe(true); + if ('response' in result) { + expect(result.response.status).toBe(401); + } }); - it('should return 400 for generic OAuthError', async () => { - mockRequest.headers = { - authorization: 'Bearer valid-token' + it('returns 403 when scopes are insufficient', async () => { + const info: AuthInfo = { + token: 't', + clientId: 'c', + scopes: ['read'], + expiresAt: Math.floor(Date.now() / 1000) + 3600 }; + verifyAccessToken.mockResolvedValue(info); - mockVerifyAccessToken.mockRejectedValue(new CustomOAuthError('custom_error', 'Some OAuth error')); + const req = new Request('http://localhost/x', { headers: { Authorization: 'Bearer t' } }); + const result = await requireBearerAuth(req, { verifier, requiredScopes: ['read', 'write'] }); - const middleware = requireBearerAuth({ verifier: mockVerifier }); - await middleware(mockRequest as Request, mockResponse as Response, nextFunction); - - expect(mockVerifyAccessToken).toHaveBeenCalledWith('valid-token'); - expect(mockResponse.status).toHaveBeenCalledWith(400); - expect(mockResponse.json).toHaveBeenCalledWith( - expect.objectContaining({ error: 'custom_error', error_description: 'Some OAuth error' }) - ); - expect(nextFunction).not.toHaveBeenCalled(); + expect('response' in result).toBe(true); + if ('response' in result) { + expect(result.response.status).toBe(403); + expect(result.response.headers.get('www-authenticate')).toContain('Bearer error="insufficient_scope"'); + expect(await result.response.json()).toEqual( + expect.objectContaining({ error: 'insufficient_scope', error_description: 'Insufficient scope' }) + ); + } }); - it('should return 500 when unexpected error occurs', async () => { - mockRequest.headers = { - authorization: 'Bearer valid-token' - }; - - mockVerifyAccessToken.mockRejectedValue(new Error('Unexpected error')); + it('returns 500 when verifier throws ServerError', async () => { + verifyAccessToken.mockRejectedValue(new ServerError('boom')); + const req = new Request('http://localhost/x', { headers: { Authorization: 'Bearer t' } }); + const result = await requireBearerAuth(req, { verifier }); - const middleware = requireBearerAuth({ verifier: mockVerifier }); - await middleware(mockRequest as Request, mockResponse as Response, nextFunction); - - expect(mockVerifyAccessToken).toHaveBeenCalledWith('valid-token'); - expect(mockResponse.status).toHaveBeenCalledWith(500); - expect(mockResponse.json).toHaveBeenCalledWith( - expect.objectContaining({ error: 'server_error', error_description: 'Internal Server Error' }) - ); - expect(nextFunction).not.toHaveBeenCalled(); + expect('response' in result).toBe(true); + if ('response' in result) { + expect(result.response.status).toBe(500); + } }); - describe('with requiredScopes in WWW-Authenticate header', () => { - it('should include scope in WWW-Authenticate header for 401 responses when requiredScopes is provided', async () => { - mockRequest.headers = {}; - - const middleware = requireBearerAuth({ - verifier: mockVerifier, - requiredScopes: ['read', 'write'] - }); - await middleware(mockRequest as Request, mockResponse as Response, nextFunction); - - expect(mockResponse.status).toHaveBeenCalledWith(401); - expect(mockResponse.set).toHaveBeenCalledWith( - 'WWW-Authenticate', - 'Bearer error="invalid_token", error_description="Missing Authorization header", scope="read write"' - ); - expect(nextFunction).not.toHaveBeenCalled(); - }); - - it('should include scope in WWW-Authenticate header for 403 insufficient scope responses', async () => { - const authInfo: AuthInfo = { - token: 'valid-token', - clientId: 'client-123', - scopes: ['read'] - }; - mockVerifyAccessToken.mockResolvedValue(authInfo); - - mockRequest.headers = { - authorization: 'Bearer valid-token' - }; - - const middleware = requireBearerAuth({ - verifier: mockVerifier, - requiredScopes: ['read', 'write'] - }); - - await middleware(mockRequest as Request, mockResponse as Response, nextFunction); - - expect(mockResponse.status).toHaveBeenCalledWith(403); - expect(mockResponse.set).toHaveBeenCalledWith( - 'WWW-Authenticate', - 'Bearer error="insufficient_scope", error_description="Insufficient scope", scope="read write"' - ); - expect(nextFunction).not.toHaveBeenCalled(); + it('includes scope and resource_metadata in WWW-Authenticate when provided', async () => { + verifyAccessToken.mockRejectedValue(new InvalidTokenError('bad')); + const req = new Request('http://localhost/x', { headers: { Authorization: 'Bearer t' } }); + const result = await requireBearerAuth(req, { + verifier, + requiredScopes: ['read', 'write'], + resourceMetadataUrl: 'https://example.com/.well-known/oauth-protected-resource' }); - it('should include both scope and resource_metadata in WWW-Authenticate header when both are provided', async () => { - mockRequest.headers = {}; - - const resourceMetadataUrl = 'https://api.example.com/.well-known/oauth-protected-resource'; - const middleware = requireBearerAuth({ - verifier: mockVerifier, - requiredScopes: ['admin'], - resourceMetadataUrl - }); - await middleware(mockRequest as Request, mockResponse as Response, nextFunction); - - expect(mockResponse.status).toHaveBeenCalledWith(401); - expect(mockResponse.set).toHaveBeenCalledWith( - 'WWW-Authenticate', - `Bearer error="invalid_token", error_description="Missing Authorization header", scope="admin", resource_metadata="${resourceMetadataUrl}"` - ); - expect(nextFunction).not.toHaveBeenCalled(); - }); + expect('response' in result).toBe(true); + if ('response' in result) { + const header = result.response.headers.get('www-authenticate') ?? ''; + expect(header).toContain('scope="read write"'); + expect(header).toContain('resource_metadata="https://example.com/.well-known/oauth-protected-resource"'); + } }); - describe('with resourceMetadataUrl', () => { - const resourceMetadataUrl = 'https://api.example.com/.well-known/oauth-protected-resource'; - - it('should include resource_metadata in WWW-Authenticate header for 401 responses', async () => { - mockRequest.headers = {}; - - const middleware = requireBearerAuth({ verifier: mockVerifier, resourceMetadataUrl }); - await middleware(mockRequest as Request, mockResponse as Response, nextFunction); - - expect(mockResponse.status).toHaveBeenCalledWith(401); - expect(mockResponse.set).toHaveBeenCalledWith( - 'WWW-Authenticate', - `Bearer error="invalid_token", error_description="Missing Authorization header", resource_metadata="${resourceMetadataUrl}"` - ); - expect(nextFunction).not.toHaveBeenCalled(); - }); + it('passes through InsufficientScopeError from verifier as 403', async () => { + verifyAccessToken.mockRejectedValue(new InsufficientScopeError('nope')); + const req = new Request('http://localhost/x', { headers: { Authorization: 'Bearer t' } }); + const result = await requireBearerAuth(req, { verifier }); - it('should include resource_metadata in WWW-Authenticate header when token verification fails', async () => { - mockRequest.headers = { - authorization: 'Bearer invalid-token' - }; - - mockVerifyAccessToken.mockRejectedValue(new InvalidTokenError('Token expired')); - - const middleware = requireBearerAuth({ verifier: mockVerifier, resourceMetadataUrl }); - await middleware(mockRequest as Request, mockResponse as Response, nextFunction); - - expect(mockResponse.status).toHaveBeenCalledWith(401); - expect(mockResponse.set).toHaveBeenCalledWith( - 'WWW-Authenticate', - `Bearer error="invalid_token", error_description="Token expired", resource_metadata="${resourceMetadataUrl}"` - ); - expect(nextFunction).not.toHaveBeenCalled(); - }); - - it('should include resource_metadata in WWW-Authenticate header for insufficient scope errors', async () => { - mockRequest.headers = { - authorization: 'Bearer valid-token' - }; - - mockVerifyAccessToken.mockRejectedValue(new InsufficientScopeError('Required scopes: admin')); - - const middleware = requireBearerAuth({ verifier: mockVerifier, resourceMetadataUrl }); - await middleware(mockRequest as Request, mockResponse as Response, nextFunction); - - expect(mockResponse.status).toHaveBeenCalledWith(403); - expect(mockResponse.set).toHaveBeenCalledWith( - 'WWW-Authenticate', - `Bearer error="insufficient_scope", error_description="Required scopes: admin", resource_metadata="${resourceMetadataUrl}"` - ); - expect(nextFunction).not.toHaveBeenCalled(); - }); - - it('should include resource_metadata when token is expired', async () => { - const expiredAuthInfo: AuthInfo = { - token: 'expired-token', - clientId: 'client-123', - scopes: ['read', 'write'], - expiresAt: Math.floor(Date.now() / 1000) - 100 - }; - mockVerifyAccessToken.mockResolvedValue(expiredAuthInfo); - - mockRequest.headers = { - authorization: 'Bearer expired-token' - }; - - const middleware = requireBearerAuth({ verifier: mockVerifier, resourceMetadataUrl }); - await middleware(mockRequest as Request, mockResponse as Response, nextFunction); - - expect(mockResponse.status).toHaveBeenCalledWith(401); - expect(mockResponse.set).toHaveBeenCalledWith( - 'WWW-Authenticate', - `Bearer error="invalid_token", error_description="Token has expired", resource_metadata="${resourceMetadataUrl}"` - ); - expect(nextFunction).not.toHaveBeenCalled(); - }); - - it('should include resource_metadata when scope check fails', async () => { - const authInfo: AuthInfo = { - token: 'valid-token', - clientId: 'client-123', - scopes: ['read'] - }; - mockVerifyAccessToken.mockResolvedValue(authInfo); - - mockRequest.headers = { - authorization: 'Bearer valid-token' - }; - - const middleware = requireBearerAuth({ - verifier: mockVerifier, - requiredScopes: ['read', 'write'], - resourceMetadataUrl - }); - - await middleware(mockRequest as Request, mockResponse as Response, nextFunction); - - expect(mockResponse.status).toHaveBeenCalledWith(403); - expect(mockResponse.set).toHaveBeenCalledWith( - 'WWW-Authenticate', - `Bearer error="insufficient_scope", error_description="Insufficient scope", scope="read write", resource_metadata="${resourceMetadataUrl}"` - ); - expect(nextFunction).not.toHaveBeenCalled(); - }); - - it('should not affect server errors (no WWW-Authenticate header)', async () => { - mockRequest.headers = { - authorization: 'Bearer valid-token' - }; - - mockVerifyAccessToken.mockRejectedValue(new ServerError('Internal server issue')); - - const middleware = requireBearerAuth({ verifier: mockVerifier, resourceMetadataUrl }); - await middleware(mockRequest as Request, mockResponse as Response, nextFunction); - - expect(mockResponse.status).toHaveBeenCalledWith(500); - expect(mockResponse.set).not.toHaveBeenCalledWith('WWW-Authenticate', expect.anything()); - expect(nextFunction).not.toHaveBeenCalled(); - }); + expect('response' in result).toBe(true); + if ('response' in result) { + expect(result.response.status).toBe(403); + } }); }); diff --git a/packages/server/test/server/auth/middleware/clientAuth.test.ts b/packages/server/test/server/auth/middleware/clientAuth.test.ts index 55a00f0c2..0ee9aae0a 100644 --- a/packages/server/test/server/auth/middleware/clientAuth.test.ts +++ b/packages/server/test/server/auth/middleware/clientAuth.test.ts @@ -1,12 +1,11 @@ import type { OAuthClientInformationFull } from '@modelcontextprotocol/core'; -import express from 'express'; -import supertest from 'supertest'; +import { InvalidClientError, InvalidRequestError } from '@modelcontextprotocol/core'; import type { OAuthRegisteredClientsStore } from '../../../../src/server/auth/clients.js'; import type { ClientAuthenticationMiddlewareOptions } from '../../../../src/server/auth/middleware/clientAuth.js'; import { authenticateClient } from '../../../../src/server/auth/middleware/clientAuth.js'; -describe('clientAuth middleware', () => { +describe('authenticateClient', () => { // Mock client store const mockClientStore: OAuthRegisteredClientsStore = { async getClient(clientId: string): Promise { @@ -35,100 +34,92 @@ describe('clientAuth middleware', () => { } }; - // Setup Express app with middleware - let app: express.Express; let options: ClientAuthenticationMiddlewareOptions; beforeEach(() => { - app = express(); - app.use(express.json()); - options = { clientsStore: mockClientStore }; - - // Setup route with client auth - app.post('/protected', authenticateClient(options), (req, res) => { - res.status(200).json({ success: true, client: req.client }); - }); }); it('authenticates valid client credentials', async () => { - const response = await supertest(app).post('/protected').send({ - client_id: 'valid-client', - client_secret: 'valid-secret' - }); - - expect(response.status).toBe(200); - expect(response.body.success).toBe(true); - expect(response.body.client.client_id).toBe('valid-client'); + const client = await authenticateClient( + { + client_id: 'valid-client', + client_secret: 'valid-secret' + }, + options + ); + + expect(client.client_id).toBe('valid-client'); }); it('rejects invalid client_id', async () => { - const response = await supertest(app).post('/protected').send({ - client_id: 'non-existent-client', - client_secret: 'some-secret' - }); - - expect(response.status).toBe(400); - expect(response.body.error).toBe('invalid_client'); - expect(response.body.error_description).toBe('Invalid client_id'); + await expect( + authenticateClient( + { + client_id: 'non-existent-client', + client_secret: 'some-secret' + }, + options + ) + ).rejects.toBeInstanceOf(InvalidClientError); }); it('rejects invalid client_secret', async () => { - const response = await supertest(app).post('/protected').send({ - client_id: 'valid-client', - client_secret: 'wrong-secret' - }); - - expect(response.status).toBe(400); - expect(response.body.error).toBe('invalid_client'); - expect(response.body.error_description).toBe('Invalid client_secret'); + await expect( + authenticateClient( + { + client_id: 'valid-client', + client_secret: 'wrong-secret' + }, + options + ) + ).rejects.toBeInstanceOf(InvalidClientError); }); it('rejects missing client_id', async () => { - const response = await supertest(app).post('/protected').send({ - client_secret: 'valid-secret' - }); - - expect(response.status).toBe(400); - expect(response.body.error).toBe('invalid_request'); + await expect( + authenticateClient( + { + client_secret: 'valid-secret' + }, + options + ) + ).rejects.toBeInstanceOf(InvalidRequestError); }); it('allows missing client_secret if client has none', async () => { - const response = await supertest(app).post('/protected').send({ - client_id: 'expired-client' - }); - - // Since the client has no secret, this should pass without providing one - expect(response.status).toBe(200); + const client = await authenticateClient( + { + client_id: 'expired-client' + }, + options + ); + expect(client.client_id).toBe('expired-client'); }); it('rejects request when client secret has expired', async () => { - const response = await supertest(app).post('/protected').send({ - client_id: 'client-with-expired-secret', - client_secret: 'expired-secret' - }); - - expect(response.status).toBe(400); - expect(response.body.error).toBe('invalid_client'); - expect(response.body.error_description).toBe('Client secret has expired'); - }); - - it('handles malformed request body', async () => { - const response = await supertest(app).post('/protected').send('not-json-format'); - - expect(response.status).toBe(400); + await expect( + authenticateClient( + { + client_id: 'client-with-expired-secret', + client_secret: 'expired-secret' + }, + options + ) + ).rejects.toBeInstanceOf(InvalidClientError); }); - // Testing request with extra fields to ensure they're ignored it('ignores extra fields in request', async () => { - const response = await supertest(app).post('/protected').send({ - client_id: 'valid-client', - client_secret: 'valid-secret', - extra_field: 'should be ignored' - }); - - expect(response.status).toBe(200); + const client = await authenticateClient( + { + client_id: 'valid-client', + client_secret: 'valid-secret', + extra_field: 'ignored' + }, + options + ); + expect(client.client_id).toBe('valid-client'); }); }); diff --git a/packages/server/test/server/auth/providers/proxyProvider.test.ts b/packages/server/test/server/auth/providers/proxyProvider.test.ts index 375179e5b..143cfa78d 100644 --- a/packages/server/test/server/auth/providers/proxyProvider.test.ts +++ b/packages/server/test/server/auth/providers/proxyProvider.test.ts @@ -1,6 +1,5 @@ import type { AuthInfo, OAuthClientInformationFull, OAuthTokens } from '@modelcontextprotocol/core'; import { InsufficientScopeError, InvalidTokenError, ServerError } from '@modelcontextprotocol/core'; -import type { Response } from 'express'; import { type Mock } from 'vitest'; import type { ProxyOptions } from '../../../../src/server/auth/providers/proxyProvider.js'; @@ -14,11 +13,6 @@ describe('Proxy OAuth Server Provider', () => { redirect_uris: ['https://example.com/callback'] }; - // Mock response object - const mockResponse = { - redirect: vi.fn() - } as unknown as Response; - // Mock provider functions const mockVerifyToken = vi.fn(); const mockGetClient = vi.fn(); @@ -81,17 +75,13 @@ describe('Proxy OAuth Server Provider', () => { describe('authorization', () => { it('redirects to authorization endpoint with correct parameters', async () => { - await provider.authorize( - validClient, - { - redirectUri: 'https://example.com/callback', - codeChallenge: 'test-challenge', - state: 'test-state', - scopes: ['read', 'write'], - resource: new URL('https://api.example.com/resource') - }, - mockResponse - ); + const response = await provider.authorize(validClient, { + redirectUri: 'https://example.com/callback', + codeChallenge: 'test-challenge', + state: 'test-state', + scopes: ['read', 'write'], + resource: new URL('https://api.example.com/resource') + }); const expectedUrl = new URL('https://auth.example.com/authorize'); expectedUrl.searchParams.set('client_id', 'test-client'); @@ -103,7 +93,8 @@ describe('Proxy OAuth Server Provider', () => { expectedUrl.searchParams.set('scope', 'read write'); expectedUrl.searchParams.set('resource', 'https://api.example.com/resource'); - expect(mockResponse.redirect).toHaveBeenCalledWith(expectedUrl.toString()); + expect(response.status).toBe(302); + expect(response.headers.get('location')).toBe(expectedUrl.toString()); }); }); diff --git a/packages/server/test/server/streamableHttp.test.ts b/packages/server/test/server/streamableHttp.test.ts index d8c6388e4..57e47668b 100644 --- a/packages/server/test/server/streamableHttp.test.ts +++ b/packages/server/test/server/streamableHttp.test.ts @@ -15,9 +15,10 @@ import type { import { listenOnRandomPort } from '@modelcontextprotocol/test-helpers'; import { McpServer } from '../../src/server/mcp.js'; -import { StreamableHTTPServerTransport } from '../../src/server/streamableHttp.js'; +import { NodeStreamableHTTPServerTransport } from '../../src/server/streamableHttp.js'; import type { EventId, EventStore, StreamId } from '../../src/server/webStandardStreamableHttp.js'; -import { type ZodMatrixEntry, zodTestMatrix } from './__fixtures__/zodTestMatrix.js'; +import type { ZodMatrixEntry } from './__fixtures__/zodTestMatrix.js'; +import { zodTestMatrix } from './__fixtures__/zodTestMatrix.js'; async function getFreePort() { return new Promise(res => { @@ -34,7 +35,7 @@ async function getFreePort() { } /** - * Test server configuration for StreamableHTTPServerTransport tests + * Test server configuration for NodeStreamableHTTPServerTransport tests */ interface TestServerConfig { sessionIdGenerator: (() => string) | undefined; @@ -49,7 +50,7 @@ interface TestServerConfig { /** * Helper to stop test server */ -async function stopTestServer({ server, transport }: { server: Server; transport: StreamableHTTPServerTransport }): Promise { +async function stopTestServer({ server, transport }: { server: Server; transport: NodeStreamableHTTPServerTransport }): Promise { // First close the transport to ensure all SSE streams are closed await transport.close(); @@ -153,7 +154,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { */ async function createTestServer(config: TestServerConfig = { sessionIdGenerator: () => randomUUID() }): Promise<{ server: Server; - transport: StreamableHTTPServerTransport; + transport: NodeStreamableHTTPServerTransport; mcpServer: McpServer; baseUrl: URL; }> { @@ -168,7 +169,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { } ); - const transport = new StreamableHTTPServerTransport({ + const transport = new NodeStreamableHTTPServerTransport({ sessionIdGenerator: config.sessionIdGenerator, enableJsonResponse: config.enableJsonResponse ?? false, eventStore: config.eventStore, @@ -202,7 +203,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { */ async function createTestAuthServer(config: TestServerConfig = { sessionIdGenerator: () => randomUUID() }): Promise<{ server: Server; - transport: StreamableHTTPServerTransport; + transport: NodeStreamableHTTPServerTransport; mcpServer: McpServer; baseUrl: URL; }> { @@ -217,7 +218,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { } ); - const transport = new StreamableHTTPServerTransport({ + const transport = new NodeStreamableHTTPServerTransport({ sessionIdGenerator: config.sessionIdGenerator, enableJsonResponse: config.enableJsonResponse ?? false, eventStore: config.eventStore, @@ -247,10 +248,10 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { } const { z } = entry; - describe('StreamableHTTPServerTransport', () => { + describe('NodeStreamableHTTPServerTransport', () => { let server: Server; let mcpServer: McpServer; - let transport: StreamableHTTPServerTransport; + let transport: NodeStreamableHTTPServerTransport; let baseUrl: URL; let sessionId: string; @@ -979,9 +980,9 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { }); }); - describe('StreamableHTTPServerTransport with AuthInfo', () => { + describe('NodeStreamableHTTPServerTransport with AuthInfo', () => { let server: Server; - let transport: StreamableHTTPServerTransport; + let transport: NodeStreamableHTTPServerTransport; let baseUrl: URL; let sessionId: string; @@ -1079,9 +1080,9 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { }); // Test JSON Response Mode - describe('StreamableHTTPServerTransport with JSON Response Mode', () => { + describe('NodeStreamableHTTPServerTransport with JSON Response Mode', () => { let server: Server; - let transport: StreamableHTTPServerTransport; + let transport: NodeStreamableHTTPServerTransport; let baseUrl: URL; let sessionId: string; @@ -1166,9 +1167,9 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { }); // Test pre-parsed body handling - describe('StreamableHTTPServerTransport with pre-parsed body', () => { + describe('NodeStreamableHTTPServerTransport with pre-parsed body', () => { let server: Server; - let transport: StreamableHTTPServerTransport; + let transport: NodeStreamableHTTPServerTransport; let baseUrl: URL; let sessionId: string; let parsedBody: unknown = null; @@ -1302,9 +1303,9 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { }); // Test resumability support - describe('StreamableHTTPServerTransport with resumability', () => { + describe('NodeStreamableHTTPServerTransport with resumability', () => { let server: Server; - let transport: StreamableHTTPServerTransport; + let transport: NodeStreamableHTTPServerTransport; let baseUrl: URL; let sessionId: string; let mcpServer: McpServer; @@ -1538,9 +1539,9 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { }); // Test stateless mode - describe('StreamableHTTPServerTransport in stateless mode', () => { + describe('NodeStreamableHTTPServerTransport in stateless mode', () => { let server: Server; - let transport: StreamableHTTPServerTransport; + let transport: NodeStreamableHTTPServerTransport; let baseUrl: URL; beforeEach(async () => { @@ -1626,9 +1627,9 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { }); // Test SSE priming events for POST streams - describe('StreamableHTTPServerTransport POST SSE priming events', () => { + describe('NodeStreamableHTTPServerTransport POST SSE priming events', () => { let server: Server; - let transport: StreamableHTTPServerTransport; + let transport: NodeStreamableHTTPServerTransport; let baseUrl: URL; let sessionId: string; let mcpServer: McpServer; @@ -2327,7 +2328,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { }); // Test onsessionclosed callback - describe('StreamableHTTPServerTransport onsessionclosed callback', () => { + describe('NodeStreamableHTTPServerTransport onsessionclosed callback', () => { it('should call onsessionclosed callback when session is closed via DELETE', async () => { const mockCallback = vi.fn(); @@ -2486,7 +2487,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { }); // Test async callbacks for onsessioninitialized and onsessionclosed - describe('StreamableHTTPServerTransport async callbacks', () => { + describe('NodeStreamableHTTPServerTransport async callbacks', () => { it('should support async onsessioninitialized callback', async () => { const initializationOrder: string[] = []; @@ -2693,9 +2694,9 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { }); // Test DNS rebinding protection - describe('StreamableHTTPServerTransport DNS rebinding protection', () => { + describe('NodeStreamableHTTPServerTransport DNS rebinding protection', () => { let server: Server; - let transport: StreamableHTTPServerTransport; + let transport: NodeStreamableHTTPServerTransport; let baseUrl: URL; afterEach(async () => { @@ -2931,7 +2932,7 @@ async function createTestServerWithDnsProtection(config: { enableDnsRebindingProtection?: boolean; }): Promise<{ server: Server; - transport: StreamableHTTPServerTransport; + transport: NodeStreamableHTTPServerTransport; mcpServer: McpServer; baseUrl: URL; }> { @@ -2948,7 +2949,7 @@ async function createTestServerWithDnsProtection(config: { }); } - const transport = new StreamableHTTPServerTransport({ + const transport = new NodeStreamableHTTPServerTransport({ sessionIdGenerator: config.sessionIdGenerator, allowedHosts: config.allowedHosts, allowedOrigins: config.allowedOrigins, diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 92dbf8253..c2b9d0835 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -89,6 +89,9 @@ catalogs: '@hono/node-server': specifier: ^1.19.7 version: 1.19.7 + '@remix-run/node-fetch-server': + specifier: ^0.13.0 + version: 0.13.0 content-type: specifier: ^1.0.5 version: 1.0.5 @@ -99,8 +102,8 @@ catalogs: specifier: ^5.0.1 version: 5.1.0 express-rate-limit: - specifier: ^7.5.0 - version: 7.5.1 + specifier: ^8.2.1 + version: 8.2.1 hono: specifier: ^4.11.1 version: 4.11.1 @@ -305,6 +308,12 @@ importers: '@modelcontextprotocol/server': specifier: workspace:^ version: link:../../packages/server + '@modelcontextprotocol/server-express': + specifier: workspace:^ + version: link:../../packages/server-express + '@modelcontextprotocol/server-hono': + specifier: workspace:^ + version: link:../../packages/server-hono cors: specifier: catalog:runtimeServerOnly version: 2.8.5 @@ -342,6 +351,9 @@ importers: '@modelcontextprotocol/server': specifier: workspace:^ version: link:../../packages/server + '@modelcontextprotocol/server-express': + specifier: workspace:^ + version: link:../../packages/server-express express: specifier: catalog:runtimeServerOnly version: 5.1.0 @@ -561,18 +573,6 @@ importers: content-type: specifier: catalog:runtimeServerOnly version: 1.0.5 - cors: - specifier: catalog:runtimeServerOnly - version: 2.8.5 - express: - specifier: catalog:runtimeServerOnly - version: 5.1.0 - express-rate-limit: - specifier: catalog:runtimeServerOnly - version: 7.5.1(express@5.1.0) - hono: - specifier: catalog:runtimeServerOnly - version: 4.11.1 pkce-challenge: specifier: catalog:runtimeShared version: 5.0.0 @@ -662,6 +662,122 @@ importers: specifier: catalog:devTools version: 4.0.9(@types/node@24.10.3)(tsx@4.20.6) + packages/server-express: + dependencies: + '@modelcontextprotocol/server': + specifier: workspace:^ + version: link:../server + '@remix-run/node-fetch-server': + specifier: catalog:runtimeServerOnly + version: 0.13.0 + express: + specifier: catalog:runtimeServerOnly + version: 5.1.0 + express-rate-limit: + specifier: catalog:runtimeServerOnly + version: 8.2.1(express@5.1.0) + devDependencies: + '@eslint/js': + specifier: catalog:devTools + version: 9.39.1 + '@modelcontextprotocol/eslint-config': + specifier: workspace:^ + version: link:../../common/eslint-config + '@modelcontextprotocol/tsconfig': + specifier: workspace:^ + version: link:../../common/tsconfig + '@modelcontextprotocol/vitest-config': + specifier: workspace:^ + version: link:../../common/vitest-config + '@types/express': + specifier: catalog:devTools + version: 5.0.5 + '@types/express-serve-static-core': + specifier: catalog:devTools + version: 5.1.0 + '@types/supertest': + specifier: catalog:devTools + version: 6.0.3 + '@typescript/native-preview': + specifier: catalog:devTools + version: 7.0.0-dev.20251218.3 + eslint: + specifier: catalog:devTools + version: 9.39.1 + eslint-config-prettier: + specifier: catalog:devTools + version: 10.1.8(eslint@9.39.1) + eslint-plugin-n: + specifier: catalog:devTools + version: 17.23.1(eslint@9.39.1)(typescript@5.9.3) + prettier: + specifier: catalog:devTools + version: 3.6.2 + supertest: + specifier: catalog:devTools + version: 7.1.4 + tsdown: + specifier: catalog:devTools + version: 0.18.0(@typescript/native-preview@7.0.0-dev.20251218.3)(typescript@5.9.3) + typescript: + specifier: catalog:devTools + version: 5.9.3 + typescript-eslint: + specifier: catalog:devTools + version: 8.49.0(eslint@9.39.1)(typescript@5.9.3) + vitest: + specifier: catalog:devTools + version: 4.0.9(@types/node@24.10.3)(tsx@4.20.6) + + packages/server-hono: + dependencies: + '@modelcontextprotocol/server': + specifier: workspace:^ + version: link:../server + hono: + specifier: catalog:runtimeServerOnly + version: 4.11.1 + devDependencies: + '@eslint/js': + specifier: catalog:devTools + version: 9.39.1 + '@modelcontextprotocol/eslint-config': + specifier: workspace:^ + version: link:../../common/eslint-config + '@modelcontextprotocol/tsconfig': + specifier: workspace:^ + version: link:../../common/tsconfig + '@modelcontextprotocol/vitest-config': + specifier: workspace:^ + version: link:../../common/vitest-config + '@typescript/native-preview': + specifier: catalog:devTools + version: 7.0.0-dev.20251218.3 + eslint: + specifier: catalog:devTools + version: 9.39.1 + eslint-config-prettier: + specifier: catalog:devTools + version: 10.1.8(eslint@9.39.1) + eslint-plugin-n: + specifier: catalog:devTools + version: 17.23.1(eslint@9.39.1)(typescript@5.9.3) + prettier: + specifier: catalog:devTools + version: 3.6.2 + tsdown: + specifier: catalog:devTools + version: 0.18.0(@typescript/native-preview@7.0.0-dev.20251218.3)(typescript@5.9.3) + typescript: + specifier: catalog:devTools + version: 5.9.3 + typescript-eslint: + specifier: catalog:devTools + version: 8.49.0(eslint@9.39.1)(typescript@5.9.3) + vitest: + specifier: catalog:devTools + version: 4.0.9(@types/node@24.10.3)(tsx@4.20.6) + test/helpers: devDependencies: '@modelcontextprotocol/client': @@ -703,6 +819,9 @@ importers: '@modelcontextprotocol/server': specifier: workspace:^ version: link:../../packages/server + '@modelcontextprotocol/server-express': + specifier: workspace:^ + version: link:../../packages/server-express '@modelcontextprotocol/test-helpers': specifier: workspace:^ version: link:../helpers @@ -1097,6 +1216,9 @@ packages: '@quansync/fs@1.0.0': resolution: {integrity: sha512-4TJ3DFtlf1L5LDMaM6CanJ/0lckGNtJcMjQ1NAV6zDmA0tEHKZtxNKin8EgPaVX1YzljbxckyT2tJrpQKAtngQ==} + '@remix-run/node-fetch-server@0.13.0': + resolution: {integrity: sha512-1EsNo0ZpgXu/90AWoRZf/oE3RVTUS80tiTUpt+hv5pjtAkw7icN4WskDwz/KdAw5ARbJLMhZBrO1NqThmy/McA==} + '@rolldown/binding-android-arm64@1.0.0-beta.53': resolution: {integrity: sha512-Ok9V8o7o6YfSdTTYA/uHH30r3YtOxLD6G3wih/U9DO0ucBBFq8WPt/DslU53OgfteLRHITZny9N/qCUxMf9kjQ==} engines: {node: ^20.19.0 || >=22.12.0} @@ -2105,8 +2227,8 @@ packages: resolution: {integrity: sha512-JhFGDVJ7tmDJItKhYgJCGLOWjuK9vPxiXoUFLwLDc99NlmklilbiQJwoctZtt13+xMw91MCk/REan6MWHqDjyA==} engines: {node: '>=12.0.0'} - express-rate-limit@7.5.1: - resolution: {integrity: sha512-7iN8iPMDzOMHPUYllBEsQdWVB6fPDMPqwjBaFrgr4Jgr/+okjvzAy+UHlYYL/Vs0OsOrMkwS6PJDkFlJwoxUnw==} + express-rate-limit@8.2.1: + resolution: {integrity: sha512-PCZEIEIxqwhzw4KF0n7QF4QqruVTcF73O5kFKUnGOyjbCCgizBBiFaYpd/fnBLUMPw/BWw9OsiN7GgrNYr7j6g==} engines: {node: '>= 16'} peerDependencies: express: '>= 4.11' @@ -2349,6 +2471,10 @@ packages: resolution: {integrity: sha512-4gd7VpWNQNB4UKKCFFVcp1AVv+FMOgs9NKzjHKusc8jTMhd5eL1NqQqOpE0KzMds804/yHlglp3uxgluOqAPLw==} engines: {node: '>= 0.4'} + ip-address@10.0.1: + resolution: {integrity: sha512-NWv9YLW4PoW2B7xtzaS3NCot75m6nK7Icdv0o3lfMceJVRfSoQwqD4wEH5rLwoKJwUiZ/rfpiVBhnaF0FK4HoA==} + engines: {node: '>= 12'} + ipaddr.js@1.9.1: resolution: {integrity: sha512-0KI/607xoxSToH7GjN1FfSbLoU0+btTicjsQSWQlh/hZykN8KpmMf7uYwPW3R+akZ6R/w18ZlXSHBYXiYUPO3g==} engines: {node: '>= 0.10'} @@ -2480,10 +2606,6 @@ packages: resolution: {integrity: sha512-PMSmkqxr106Xa156c2M265Z+FTrPl+oxd/rgOQy2tijQeK5TxQ43psO1ZCwhVOSdnn+RzkzlRz/eY4BgJBYVpg==} hasBin: true - js-yaml@4.1.0: - resolution: {integrity: sha512-wpxZs9NoxZaJESJGIZTyDEaYpl0FKSA+FB9aJiyemKhMwkxQg63h4T1KJgUGHpTqPDNRcmmYLugrRjJlBtWvRA==} - hasBin: true - js-yaml@4.1.1: resolution: {integrity: sha512-qQKT4zQxXl8lLwBtHMWwaTcGfFOZviOJet3Oy/xmGk2gZH677CJM9EvtfdSkgWcATZhj/55JZ0rmy3myCT5lsA==} hasBin: true @@ -3663,7 +3785,7 @@ snapshots: globals: 14.0.0 ignore: 5.3.2 import-fresh: 3.3.1 - js-yaml: 4.1.0 + js-yaml: 4.1.1 minimatch: 3.1.2 strip-json-comments: 3.1.1 transitivePeerDependencies: @@ -3768,6 +3890,8 @@ snapshots: dependencies: quansync: 1.0.0 + '@remix-run/node-fetch-server@0.13.0': {} + '@rolldown/binding-android-arm64@1.0.0-beta.53': optional: true @@ -4815,9 +4939,10 @@ snapshots: expect-type@1.2.2: {} - express-rate-limit@7.5.1(express@5.1.0): + express-rate-limit@8.2.1(express@5.1.0): dependencies: express: 5.1.0 + ip-address: 10.0.1 express@5.1.0: dependencies: @@ -5092,6 +5217,8 @@ snapshots: hasown: 2.0.2 side-channel: 1.1.0 + ip-address@10.0.1: {} + ipaddr.js@1.9.1: {} is-array-buffer@3.0.5: @@ -5225,10 +5352,6 @@ snapshots: argparse: 1.0.10 esprima: 4.0.1 - js-yaml@4.1.0: - dependencies: - argparse: 2.0.1 - js-yaml@4.1.1: dependencies: argparse: 2.0.1 diff --git a/pnpm-workspace.yaml b/pnpm-workspace.yaml index 12bae8326..a7222dd71 100644 --- a/pnpm-workspace.yaml +++ b/pnpm-workspace.yaml @@ -19,8 +19,9 @@ catalogs: content-type: ^1.0.5 cors: ^2.8.5 express: ^5.0.1 - express-rate-limit: ^7.5.0 + express-rate-limit: ^8.2.1 raw-body: ^3.0.0 + '@remix-run/node-fetch-server': ^0.13.0 runtimeClientOnly: jose: ^6.1.1 cross-spawn: ^7.0.5 diff --git a/test/integration/package.json b/test/integration/package.json index e709e431a..baa099bab 100644 --- a/test/integration/package.json +++ b/test/integration/package.json @@ -35,6 +35,7 @@ "@modelcontextprotocol/core": "workspace:^", "@modelcontextprotocol/client": "workspace:^", "@modelcontextprotocol/server": "workspace:^", + "@modelcontextprotocol/server-express": "workspace:^", "zod": "catalog:runtimeShared", "vitest": "catalog:devTools", "supertest": "catalog:devTools", diff --git a/test/integration/test/issues/test_1277_zod_v4_description.test.ts b/test/integration/test/issues/test_1277_zod_v4_description.test.ts index fe58cfcd5..75a61cb36 100644 --- a/test/integration/test/issues/test_1277_zod_v4_description.test.ts +++ b/test/integration/test/issues/test_1277_zod_v4_description.test.ts @@ -9,7 +9,8 @@ import { Client } from '@modelcontextprotocol/client'; import { InMemoryTransport, ListPromptsResultSchema } from '@modelcontextprotocol/core'; import { McpServer } from '@modelcontextprotocol/server'; -import { type ZodMatrixEntry, zodTestMatrix } from '@modelcontextprotocol/test-helpers'; +import type { ZodMatrixEntry } from '@modelcontextprotocol/test-helpers'; +import { zodTestMatrix } from '@modelcontextprotocol/test-helpers'; describe.each(zodTestMatrix)('Issue #1277: $zodVersionLabel', (entry: ZodMatrixEntry) => { const { z } = entry; diff --git a/test/integration/test/server.test.ts b/test/integration/test/server.test.ts index fcac6cc45..30a2c03c4 100644 --- a/test/integration/test/server.test.ts +++ b/test/integration/test/server.test.ts @@ -30,7 +30,9 @@ import { SetLevelRequestSchema, SUPPORTED_PROTOCOL_VERSIONS } from '@modelcontextprotocol/core'; -import { createMcpExpressApp, InMemoryTaskMessageQueue, InMemoryTaskStore, McpServer, Server } from '@modelcontextprotocol/server'; +import { InMemoryTaskStore, McpServer, Server } from '@modelcontextprotocol/server'; +import { createMcpExpressApp } from '@modelcontextprotocol/server-express'; +import type { Request, Response } from 'express'; import supertest from 'supertest'; import * as z3 from 'zod/v3'; import * as z4 from 'zod/v4'; @@ -2066,7 +2068,7 @@ describe('createMcpExpressApp', () => { test('should parse JSON bodies', async () => { const app = createMcpExpressApp({ host: '0.0.0.0' }); // Disable host validation for this test - app.post('/test', (req, res) => { + app.post('/test', (req: Request, res: Response) => { res.json({ received: req.body }); }); @@ -2078,7 +2080,7 @@ describe('createMcpExpressApp', () => { test('should reject requests with invalid Host header by default', async () => { const app = createMcpExpressApp(); - app.post('/test', (_req, res) => { + app.post('/test', (_req: Request, res: Response) => { res.json({ success: true }); }); @@ -2097,7 +2099,7 @@ describe('createMcpExpressApp', () => { test('should allow requests with localhost Host header', async () => { const app = createMcpExpressApp(); - app.post('/test', (_req, res) => { + app.post('/test', (_req: Request, res: Response) => { res.json({ success: true }); }); @@ -2109,7 +2111,7 @@ describe('createMcpExpressApp', () => { test('should allow requests with 127.0.0.1 Host header', async () => { const app = createMcpExpressApp(); - app.post('/test', (_req, res) => { + app.post('/test', (_req: Request, res: Response) => { res.json({ success: true }); }); @@ -2121,7 +2123,7 @@ describe('createMcpExpressApp', () => { test('should not apply host validation when host is 0.0.0.0', async () => { const app = createMcpExpressApp({ host: '0.0.0.0' }); - app.post('/test', (_req, res) => { + app.post('/test', (_req: Request, res: Response) => { res.json({ success: true }); }); @@ -2134,7 +2136,7 @@ describe('createMcpExpressApp', () => { test('should apply host validation when host is explicitly localhost', async () => { const app = createMcpExpressApp({ host: 'localhost' }); - app.post('/test', (_req, res) => { + app.post('/test', (_req: Request, res: Response) => { res.json({ success: true }); }); @@ -2146,7 +2148,7 @@ describe('createMcpExpressApp', () => { test('should allow requests with IPv6 localhost Host header', async () => { const app = createMcpExpressApp(); - app.post('/test', (_req, res) => { + app.post('/test', (_req: Request, res: Response) => { res.json({ success: true }); }); @@ -2158,7 +2160,7 @@ describe('createMcpExpressApp', () => { test('should apply host validation when host is ::1 (IPv6 localhost)', async () => { const app = createMcpExpressApp({ host: '::1' }); - app.post('/test', (_req, res) => { + app.post('/test', (_req: Request, res: Response) => { res.json({ success: true }); }); @@ -2185,7 +2187,7 @@ describe('createMcpExpressApp', () => { test('should use custom allowedHosts when provided', async () => { const warnSpy = vi.spyOn(console, 'warn').mockImplementation(() => {}); const app = createMcpExpressApp({ host: '0.0.0.0', allowedHosts: ['myapp.local', 'localhost'] }); - app.post('/test', (_req, res) => { + app.post('/test', (_req: Request, res: Response) => { res.json({ success: true }); }); @@ -2205,7 +2207,7 @@ describe('createMcpExpressApp', () => { test('should override default localhost validation when allowedHosts is provided', async () => { // Even though host is localhost, we're using custom allowedHosts const app = createMcpExpressApp({ host: 'localhost', allowedHosts: ['custom.local'] }); - app.post('/test', (_req, res) => { + app.post('/test', (_req: Request, res: Response) => { res.json({ success: true }); }); diff --git a/test/integration/test/server/mcp.test.ts b/test/integration/test/server/mcp.test.ts index f7bcececc..90e7152aa 100644 --- a/test/integration/test/server/mcp.test.ts +++ b/test/integration/test/server/mcp.test.ts @@ -1,26 +1,28 @@ import { Client } from '@modelcontextprotocol/client'; -import { getDisplayName, InMemoryTaskStore, InMemoryTransport, UriTemplate } from '@modelcontextprotocol/core'; +import type { CallToolResult, Notification, TextContent } from '@modelcontextprotocol/core'; import { - type CallToolResult, CallToolResultSchema, CompleteResultSchema, ElicitRequestSchema, ErrorCode, + getDisplayName, GetPromptResultSchema, + InMemoryTaskStore, + InMemoryTransport, ListPromptsResultSchema, ListResourcesResultSchema, ListResourceTemplatesResultSchema, ListToolsResultSchema, LoggingMessageNotificationSchema, - type Notification, ReadResourceResultSchema, - type TextContent, + UriTemplate, UrlElicitationRequiredError } from '@modelcontextprotocol/core'; import { completable } from '../../../../packages/server/src/server/completable.js'; import { McpServer, ResourceTemplate } from '../../../../packages/server/src/server/mcp.js'; -import { type ZodMatrixEntry, zodTestMatrix } from '../../../../packages/server/test/server/__fixtures__/zodTestMatrix.js'; +import type { ZodMatrixEntry } from '../../../../packages/server/test/server/__fixtures__/zodTestMatrix.js'; +import { zodTestMatrix } from '../../../../packages/server/test/server/__fixtures__/zodTestMatrix.js'; function createLatch() { let latch = false; diff --git a/test/integration/test/stateManagementStreamableHttp.test.ts b/test/integration/test/stateManagementStreamableHttp.test.ts index c33100efa..72180b688 100644 --- a/test/integration/test/stateManagementStreamableHttp.test.ts +++ b/test/integration/test/stateManagementStreamableHttp.test.ts @@ -1,5 +1,6 @@ import { randomUUID } from 'node:crypto'; -import { createServer, type Server } from 'node:http'; +import type { Server } from 'node:http'; +import { createServer } from 'node:http'; import { Client, StreamableHTTPClientTransport } from '@modelcontextprotocol/client'; import { @@ -9,10 +10,10 @@ import { ListResourcesResultSchema, ListToolsResultSchema, McpServer, - StreamableHTTPServerTransport + NodeStreamableHTTPServerTransport } from '@modelcontextprotocol/server'; -import { listenOnRandomPort } from '@modelcontextprotocol/test-helpers'; -import { type ZodMatrixEntry, zodTestMatrix } from '@modelcontextprotocol/test-helpers'; +import type { ZodMatrixEntry } from '@modelcontextprotocol/test-helpers'; +import { listenOnRandomPort, zodTestMatrix } from '@modelcontextprotocol/test-helpers'; describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { const { z } = entry; @@ -68,7 +69,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { ); // Create transport with or without session management - const serverTransport = new StreamableHTTPServerTransport({ + const serverTransport = new NodeStreamableHTTPServerTransport({ sessionIdGenerator: withSessionManagement ? () => randomUUID() // With session management, generate UUID : undefined // Without session management, return undefined @@ -89,7 +90,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { describe('Stateless Mode', () => { let server: Server; let mcpServer: McpServer; - let serverTransport: StreamableHTTPServerTransport; + let serverTransport: NodeStreamableHTTPServerTransport; let baseUrl: URL; beforeEach(async () => { @@ -253,7 +254,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { describe('Stateful Mode', () => { let server: Server; let mcpServer: McpServer; - let serverTransport: StreamableHTTPServerTransport; + let serverTransport: NodeStreamableHTTPServerTransport; let baseUrl: URL; beforeEach(async () => { diff --git a/test/integration/test/taskLifecycle.test.ts b/test/integration/test/taskLifecycle.test.ts index 216479e93..324da6aa2 100644 --- a/test/integration/test/taskLifecycle.test.ts +++ b/test/integration/test/taskLifecycle.test.ts @@ -1,5 +1,6 @@ import { randomUUID } from 'node:crypto'; -import { createServer, type Server } from 'node:http'; +import type { Server } from 'node:http'; +import { createServer } from 'node:http'; import { Client, StreamableHTTPClientTransport } from '@modelcontextprotocol/client'; import type { TaskRequestOptions } from '@modelcontextprotocol/server'; @@ -13,8 +14,8 @@ import { InMemoryTaskStore, McpError, McpServer, + NodeStreamableHTTPServerTransport, RELATED_TASK_META_KEY, - StreamableHTTPServerTransport, TaskSchema } from '@modelcontextprotocol/server'; import { listenOnRandomPort, waitForTaskStatus } from '@modelcontextprotocol/test-helpers'; @@ -23,7 +24,7 @@ import { z } from 'zod'; describe('Task Lifecycle Integration Tests', () => { let server: Server; let mcpServer: McpServer; - let serverTransport: StreamableHTTPServerTransport; + let serverTransport: NodeStreamableHTTPServerTransport; let baseUrl: URL; let taskStore: InMemoryTaskStore; @@ -188,7 +189,7 @@ describe('Task Lifecycle Integration Tests', () => { ); // Create transport - serverTransport = new StreamableHTTPServerTransport({ + serverTransport = new NodeStreamableHTTPServerTransport({ sessionIdGenerator: () => randomUUID() }); diff --git a/test/integration/test/taskResumability.test.ts b/test/integration/test/taskResumability.test.ts index 1e4d8a0fd..db60e2d4e 100644 --- a/test/integration/test/taskResumability.test.ts +++ b/test/integration/test/taskResumability.test.ts @@ -1,14 +1,15 @@ import { randomUUID } from 'node:crypto'; -import { createServer, type Server } from 'node:http'; +import type { Server } from 'node:http'; +import { createServer } from 'node:http'; import { Client, StreamableHTTPClientTransport } from '@modelcontextprotocol/client'; +import type { EventStore, JSONRPCMessage } from '@modelcontextprotocol/server'; import { CallToolResultSchema, LoggingMessageNotificationSchema, McpServer, - StreamableHTTPServerTransport + NodeStreamableHTTPServerTransport } from '@modelcontextprotocol/server'; -import type { EventStore, JSONRPCMessage } from '@modelcontextprotocol/server'; import type { ZodMatrixEntry } from '@modelcontextprotocol/test-helpers'; import { listenOnRandomPort, zodTestMatrix } from '@modelcontextprotocol/test-helpers'; @@ -51,7 +52,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { describe('Transport resumability', () => { let server: Server; let mcpServer: McpServer; - let serverTransport: StreamableHTTPServerTransport; + let serverTransport: NodeStreamableHTTPServerTransport; let baseUrl: URL; let eventStore: InMemoryEventStore; @@ -117,7 +118,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { ); // Create a transport with the event store - serverTransport = new StreamableHTTPServerTransport({ + serverTransport = new NodeStreamableHTTPServerTransport({ sessionIdGenerator: () => randomUUID(), eventStore }); diff --git a/test/integration/test/title.test.ts b/test/integration/test/title.test.ts index 4eec82335..97348c117 100644 --- a/test/integration/test/title.test.ts +++ b/test/integration/test/title.test.ts @@ -1,7 +1,8 @@ import { Client } from '@modelcontextprotocol/client'; import { InMemoryTransport } from '@modelcontextprotocol/core'; import { McpServer, ResourceTemplate, Server } from '@modelcontextprotocol/server'; -import { type ZodMatrixEntry, zodTestMatrix } from '@modelcontextprotocol/test-helpers'; +import type { ZodMatrixEntry } from '@modelcontextprotocol/test-helpers'; +import { zodTestMatrix } from '@modelcontextprotocol/test-helpers'; describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { const { z } = entry; diff --git a/test/integration/tsconfig.json b/test/integration/tsconfig.json index f69a602fd..666fc0509 100644 --- a/test/integration/tsconfig.json +++ b/test/integration/tsconfig.json @@ -8,6 +8,7 @@ "@modelcontextprotocol/core": ["./node_modules/@modelcontextprotocol/core/src/index.ts"], "@modelcontextprotocol/client": ["./node_modules/@modelcontextprotocol/client/src/index.ts"], "@modelcontextprotocol/server": ["./node_modules/@modelcontextprotocol/server/src/index.ts"], + "@modelcontextprotocol/server-express": ["./node_modules/@modelcontextprotocol/server-express/src/index.ts"], "@modelcontextprotocol/vitest-config": ["./node_modules/@modelcontextprotocol/vitest-config/tsconfig.json"], "@modelcontextprotocol/test-helpers": ["./node_modules/@modelcontextprotocol/test-helpers/src/index.ts"] }