From 59584492005b9daa2e44643f099a747a9d6ad4a8 Mon Sep 17 00:00:00 2001 From: Konstantin Konstantinov Date: Sat, 20 Dec 2025 08:56:45 +0200 Subject: [PATCH 1/5] PoC: split away server express and hono deps into server-express and server-hono --- common/eslint-config/eslint.config.mjs | 1 + docs/server.md | 4 +- examples/server/README.md | 5 +- examples/server/package.json | 2 + examples/server/src/elicitationFormExample.ts | 5 +- examples/server/src/elicitationUrlExample.ts | 4 +- .../src/honoWebStandardStreamableHttp.ts | 3 +- .../server/src/jsonResponseStreamableHttp.ts | 3 +- examples/server/src/simpleSseServer.ts | 3 +- .../src/simpleStatelessStreamableHttp.ts | 3 +- examples/server/src/simpleStreamableHttp.ts | 4 +- examples/server/src/simpleTaskInteractive.ts | 2 +- .../sseAndStreamableHttpCompatibleServer.ts | 9 +- examples/server/src/ssePollingExample.ts | 3 +- .../src/standaloneSseWithGetStreamableHttp.ts | 3 +- examples/server/tsconfig.json | 2 + examples/shared/package.json | 1 + .../shared/src/demoInMemoryOAuthProvider.ts | 36 +- .../test/demoInMemoryOAuthProvider.test.ts | 234 +------- examples/shared/tsconfig.json | 1 + package.json | 2 +- packages/client/src/client/sse.ts | 3 +- packages/server-express/README.md | 81 +++ packages/server-express/eslint.config.mjs | 12 + packages/server-express/package.json | 67 +++ .../server-express/src/auth/bearerAuth.ts | 62 ++ packages/server-express/src/auth/router.ts | 172 ++++++ .../server => server-express/src}/express.ts | 0 packages/server-express/src/index.ts | 4 + .../src/middleware/hostHeaderValidation.ts | 52 ++ .../test/server/auth/router.test.ts | 28 +- packages/server-express/tsconfig.json | 14 + packages/server-express/tsdown.config.ts | 23 + packages/server-express/vitest.config.js | 3 + packages/server-hono/README.md | 64 ++ packages/server-hono/eslint.config.mjs | 12 + packages/server-hono/package.json | 63 ++ packages/server-hono/src/auth/router.ts | 33 ++ packages/server-hono/src/index.ts | 3 + .../src/middleware/hostHeaderValidation.ts | 33 ++ packages/server-hono/src/streamableHttp.ts | 14 + packages/server-hono/test/server-hono.test.ts | 114 ++++ packages/server-hono/tsconfig.json | 14 + packages/server-hono/tsdown.config.ts | 23 + packages/server-hono/vitest.config.js | 3 + packages/server/package.json | 5 - packages/server/src/index.ts | 2 +- .../src/server/auth/handlers/authorize.ts | 112 ++-- .../src/server/auth/handlers/metadata.ts | 42 +- .../src/server/auth/handlers/register.ts | 106 ++-- .../server/src/server/auth/handlers/revoke.ts | 109 ++-- .../server/src/server/auth/handlers/token.ts | 120 ++-- packages/server/src/server/auth/index.ts | 1 + .../server/auth/middleware/allowedMethods.ts | 26 +- .../src/server/auth/middleware/bearerAuth.ts | 132 +++-- .../src/server/auth/middleware/clientAuth.ts | 75 +-- packages/server/src/server/auth/provider.ts | 3 +- .../server/auth/providers/proxyProvider.ts | 5 +- packages/server/src/server/auth/router.ts | 126 ++-- packages/server/src/server/auth/web.ts | 140 +++++ .../server/middleware/hostHeaderValidation.ts | 122 ++-- packages/server/src/server/sse.ts | 12 +- packages/server/src/server/streamableHttp.ts | 142 ++++- .../server/auth/handlers/authorize.test.ts | 300 ++-------- .../server/auth/handlers/metadata.test.ts | 65 ++- .../server/auth/handlers/register.test.ts | 297 +--------- .../test/server/auth/handlers/revoke.test.ts | 279 ++------- .../test/server/auth/handlers/token.test.ts | 545 +++-------------- .../auth/middleware/allowedMethods.test.ts | 88 +-- .../server/auth/middleware/bearerAuth.test.ts | 548 +++--------------- .../server/auth/middleware/clientAuth.test.ts | 133 ++--- .../auth/providers/proxyProvider.test.ts | 27 +- pnpm-lock.yaml | 163 ++++-- test/integration/package.json | 1 + .../test_1277_zod_v4_description.test.ts | 3 +- test/integration/test/server.test.ts | 24 +- test/integration/test/server/mcp.test.ts | 12 +- .../stateManagementStreamableHttp.test.ts | 7 +- test/integration/test/taskLifecycle.test.ts | 3 +- .../integration/test/taskResumability.test.ts | 3 +- test/integration/test/title.test.ts | 3 +- test/integration/tsconfig.json | 1 + 82 files changed, 2339 insertions(+), 2670 deletions(-) create mode 100644 packages/server-express/README.md create mode 100644 packages/server-express/eslint.config.mjs create mode 100644 packages/server-express/package.json create mode 100644 packages/server-express/src/auth/bearerAuth.ts create mode 100644 packages/server-express/src/auth/router.ts rename packages/{server/src/server => server-express/src}/express.ts (100%) create mode 100644 packages/server-express/src/index.ts create mode 100644 packages/server-express/src/middleware/hostHeaderValidation.ts rename packages/{server => server-express}/test/server/auth/router.test.ts (95%) create mode 100644 packages/server-express/tsconfig.json create mode 100644 packages/server-express/tsdown.config.ts create mode 100644 packages/server-express/vitest.config.js create mode 100644 packages/server-hono/README.md create mode 100644 packages/server-hono/eslint.config.mjs create mode 100644 packages/server-hono/package.json create mode 100644 packages/server-hono/src/auth/router.ts create mode 100644 packages/server-hono/src/index.ts create mode 100644 packages/server-hono/src/middleware/hostHeaderValidation.ts create mode 100644 packages/server-hono/src/streamableHttp.ts create mode 100644 packages/server-hono/test/server-hono.test.ts create mode 100644 packages/server-hono/tsconfig.json create mode 100644 packages/server-hono/tsdown.config.ts create mode 100644 packages/server-hono/vitest.config.js create mode 100644 packages/server/src/server/auth/web.ts diff --git a/common/eslint-config/eslint.config.mjs b/common/eslint-config/eslint.config.mjs index 321f3f6fc..6ac057c69 100644 --- a/common/eslint-config/eslint.config.mjs +++ b/common/eslint-config/eslint.config.mjs @@ -47,6 +47,7 @@ export default defineConfig( '@typescript-eslint/consistent-type-imports': ['error', { disallowTypeAnnotations: false }], 'simple-import-sort/imports': 'warn', 'simple-import-sort/exports': 'warn', + 'import/consistent-type-specifier-style': ['error', 'prefer-top-level'], 'import/no-extraneous-dependencies': [ 'error', { diff --git a/docs/server.md b/docs/server.md index 4d5138e84..800d336db 100644 --- a/docs/server.md +++ b/docs/server.md @@ -70,7 +70,7 @@ For more detailed patterns (stateless vs stateful, JSON response mode, CORS, DNS MCP servers running on localhost are vulnerable to DNS rebinding attacks. Use `createMcpExpressApp()` to create an Express app with DNS rebinding protection enabled by default: ```typescript -import { createMcpExpressApp } from '@modelcontextprotocol/server'; +import { createMcpExpressApp } from '@modelcontextprotocol/server-express'; // Protection auto-enabled (default host is 127.0.0.1) const app = createMcpExpressApp(); @@ -85,7 +85,7 @@ const app = createMcpExpressApp({ host: '0.0.0.0' }); When binding to `0.0.0.0` / `::`, provide an allow-list of hosts: ```typescript -import { createMcpExpressApp } from '@modelcontextprotocol/server'; +import { createMcpExpressApp } from '@modelcontextprotocol/server-express'; const app = createMcpExpressApp({ host: '0.0.0.0', diff --git a/examples/server/README.md b/examples/server/README.md index 310113e45..bb1216a04 100644 --- a/examples/server/README.md +++ b/examples/server/README.md @@ -1,6 +1,9 @@ # MCP TypeScript SDK Examples (Server) -This directory contains runnable MCP **server** examples built with `@modelcontextprotocol/server`. +This directory contains runnable MCP **server** examples built with `@modelcontextprotocol/server` plus framework adapters: + +- `@modelcontextprotocol/server-express` +- `@modelcontextprotocol/server-hono` For client examples, see [`../client/README.md`](../client/README.md). For guided docs, see [`../../docs/server.md`](../../docs/server.md). diff --git a/examples/server/package.json b/examples/server/package.json index a3a3d14c7..cb37d9f40 100644 --- a/examples/server/package.json +++ b/examples/server/package.json @@ -38,6 +38,8 @@ "hono": "catalog:runtimeServerOnly", "@modelcontextprotocol/examples-shared": "workspace:^", "@modelcontextprotocol/server": "workspace:^", + "@modelcontextprotocol/server-express": "workspace:^", + "@modelcontextprotocol/server-hono": "workspace:^", "cors": "catalog:runtimeServerOnly", "express": "catalog:runtimeServerOnly", "zod": "catalog:runtimeShared" diff --git a/examples/server/src/elicitationFormExample.ts b/examples/server/src/elicitationFormExample.ts index f8863c17b..567975662 100644 --- a/examples/server/src/elicitationFormExample.ts +++ b/examples/server/src/elicitationFormExample.ts @@ -9,8 +9,9 @@ import { randomUUID } from 'node:crypto'; -import { createMcpExpressApp, isInitializeRequest, McpServer, StreamableHTTPServerTransport } from '@modelcontextprotocol/server'; -import { type Request, type Response } from 'express'; +import { isInitializeRequest, McpServer, StreamableHTTPServerTransport } from '@modelcontextprotocol/server'; +import { createMcpExpressApp } from '@modelcontextprotocol/server-express'; +import type { Request, Response } from 'express'; // Create MCP server - it will automatically use AjvJsonSchemaValidator with sensible defaults // The validator supports format validation (email, date, etc.) if ajv-formats is installed diff --git a/examples/server/src/elicitationUrlExample.ts b/examples/server/src/elicitationUrlExample.ts index 99f85d079..79ba49a17 100644 --- a/examples/server/src/elicitationUrlExample.ts +++ b/examples/server/src/elicitationUrlExample.ts @@ -13,15 +13,13 @@ import { setupAuthServer } from '@modelcontextprotocol/examples-shared'; import type { CallToolResult, ElicitRequestURLParams, ElicitResult, OAuthMetadata } from '@modelcontextprotocol/server'; import { checkResourceAllowed, - createMcpExpressApp, getOAuthProtectedResourceMetadataUrl, isInitializeRequest, - mcpAuthMetadataRouter, McpServer, - requireBearerAuth, StreamableHTTPServerTransport, UrlElicitationRequiredError } from '@modelcontextprotocol/server'; +import { createMcpExpressApp, mcpAuthMetadataRouter, requireBearerAuth } from '@modelcontextprotocol/server-express'; import cors from 'cors'; import type { Request, Response } from 'express'; import express from 'express'; diff --git a/examples/server/src/honoWebStandardStreamableHttp.ts b/examples/server/src/honoWebStandardStreamableHttp.ts index aef1e99e2..f5c59cffe 100644 --- a/examples/server/src/honoWebStandardStreamableHttp.ts +++ b/examples/server/src/honoWebStandardStreamableHttp.ts @@ -10,6 +10,7 @@ import { serve } from '@hono/node-server'; import type { CallToolResult } from '@modelcontextprotocol/server'; import { McpServer, WebStandardStreamableHTTPServerTransport } from '@modelcontextprotocol/server'; +import { mcpStreamableHttpHandler } from '@modelcontextprotocol/server-hono'; import { Hono } from 'hono'; import { cors } from 'hono/cors'; import * as z from 'zod/v4'; @@ -56,7 +57,7 @@ app.use( app.get('/health', c => c.json({ status: 'ok' })); // MCP endpoint -app.all('/mcp', c => transport.handleRequest(c.req.raw)); +app.all('/mcp', mcpStreamableHttpHandler(transport)); // Start the server const PORT = process.env.MCP_PORT ? parseInt(process.env.MCP_PORT, 10) : 3000; diff --git a/examples/server/src/jsonResponseStreamableHttp.ts b/examples/server/src/jsonResponseStreamableHttp.ts index 2199ebfbe..44155ea9d 100644 --- a/examples/server/src/jsonResponseStreamableHttp.ts +++ b/examples/server/src/jsonResponseStreamableHttp.ts @@ -1,7 +1,8 @@ import { randomUUID } from 'node:crypto'; import type { CallToolResult } from '@modelcontextprotocol/server'; -import { createMcpExpressApp, isInitializeRequest, McpServer, StreamableHTTPServerTransport } from '@modelcontextprotocol/server'; +import { isInitializeRequest, McpServer, StreamableHTTPServerTransport } from '@modelcontextprotocol/server'; +import { createMcpExpressApp } from '@modelcontextprotocol/server-express'; import type { Request, Response } from 'express'; import * as z from 'zod/v4'; diff --git a/examples/server/src/simpleSseServer.ts b/examples/server/src/simpleSseServer.ts index 90561c62f..35b48b69d 100644 --- a/examples/server/src/simpleSseServer.ts +++ b/examples/server/src/simpleSseServer.ts @@ -1,5 +1,6 @@ import type { CallToolResult } from '@modelcontextprotocol/server'; -import { createMcpExpressApp, McpServer, SSEServerTransport } from '@modelcontextprotocol/server'; +import { McpServer, SSEServerTransport } from '@modelcontextprotocol/server'; +import { createMcpExpressApp } from '@modelcontextprotocol/server-express'; import type { Request, Response } from 'express'; import * as z from 'zod/v4'; diff --git a/examples/server/src/simpleStatelessStreamableHttp.ts b/examples/server/src/simpleStatelessStreamableHttp.ts index 3aee2c212..0f3a78e63 100644 --- a/examples/server/src/simpleStatelessStreamableHttp.ts +++ b/examples/server/src/simpleStatelessStreamableHttp.ts @@ -1,5 +1,6 @@ import type { CallToolResult, GetPromptResult, ReadResourceResult } from '@modelcontextprotocol/server'; -import { createMcpExpressApp, McpServer, StreamableHTTPServerTransport } from '@modelcontextprotocol/server'; +import { McpServer, StreamableHTTPServerTransport } from '@modelcontextprotocol/server'; +import { createMcpExpressApp } from '@modelcontextprotocol/server-express'; import type { Request, Response } from 'express'; import * as z from 'zod/v4'; diff --git a/examples/server/src/simpleStreamableHttp.ts b/examples/server/src/simpleStreamableHttp.ts index 7613e3786..f550ed7d7 100644 --- a/examples/server/src/simpleStreamableHttp.ts +++ b/examples/server/src/simpleStreamableHttp.ts @@ -11,17 +11,15 @@ import type { } from '@modelcontextprotocol/server'; import { checkResourceAllowed, - createMcpExpressApp, ElicitResultSchema, getOAuthProtectedResourceMetadataUrl, InMemoryTaskMessageQueue, InMemoryTaskStore, isInitializeRequest, - mcpAuthMetadataRouter, McpServer, - requireBearerAuth, StreamableHTTPServerTransport } from '@modelcontextprotocol/server'; +import { createMcpExpressApp, mcpAuthMetadataRouter, requireBearerAuth } from '@modelcontextprotocol/server-express'; import type { Request, Response } from 'express'; import * as z from 'zod/v4'; diff --git a/examples/server/src/simpleTaskInteractive.ts b/examples/server/src/simpleTaskInteractive.ts index 956c33f8e..4685f33f5 100644 --- a/examples/server/src/simpleTaskInteractive.ts +++ b/examples/server/src/simpleTaskInteractive.ts @@ -35,7 +35,6 @@ import type { } from '@modelcontextprotocol/server'; import { CallToolRequestSchema, - createMcpExpressApp, GetTaskPayloadRequestSchema, GetTaskRequestSchema, InMemoryTaskStore, @@ -45,6 +44,7 @@ import { Server, StreamableHTTPServerTransport } from '@modelcontextprotocol/server'; +import { createMcpExpressApp } from '@modelcontextprotocol/server-express'; import type { Request, Response } from 'express'; // ============================================================================ diff --git a/examples/server/src/sseAndStreamableHttpCompatibleServer.ts b/examples/server/src/sseAndStreamableHttpCompatibleServer.ts index 335802d0a..3ea3b71db 100644 --- a/examples/server/src/sseAndStreamableHttpCompatibleServer.ts +++ b/examples/server/src/sseAndStreamableHttpCompatibleServer.ts @@ -1,13 +1,8 @@ import { randomUUID } from 'node:crypto'; import type { CallToolResult } from '@modelcontextprotocol/server'; -import { - createMcpExpressApp, - isInitializeRequest, - McpServer, - SSEServerTransport, - StreamableHTTPServerTransport -} from '@modelcontextprotocol/server'; +import { isInitializeRequest, McpServer, SSEServerTransport, StreamableHTTPServerTransport } from '@modelcontextprotocol/server'; +import { createMcpExpressApp } from '@modelcontextprotocol/server-express'; import type { Request, Response } from 'express'; import * as z from 'zod/v4'; diff --git a/examples/server/src/ssePollingExample.ts b/examples/server/src/ssePollingExample.ts index 4e3d36328..e7da09ecb 100644 --- a/examples/server/src/ssePollingExample.ts +++ b/examples/server/src/ssePollingExample.ts @@ -15,7 +15,8 @@ import { randomUUID } from 'node:crypto'; import type { CallToolResult } from '@modelcontextprotocol/server'; -import { createMcpExpressApp, McpServer, StreamableHTTPServerTransport } from '@modelcontextprotocol/server'; +import { McpServer, StreamableHTTPServerTransport } from '@modelcontextprotocol/server'; +import { createMcpExpressApp } from '@modelcontextprotocol/server-express'; import cors from 'cors'; import type { Request, Response } from 'express'; diff --git a/examples/server/src/standaloneSseWithGetStreamableHttp.ts b/examples/server/src/standaloneSseWithGetStreamableHttp.ts index cceb24299..f9fb426cd 100644 --- a/examples/server/src/standaloneSseWithGetStreamableHttp.ts +++ b/examples/server/src/standaloneSseWithGetStreamableHttp.ts @@ -1,7 +1,8 @@ import { randomUUID } from 'node:crypto'; import type { ReadResourceResult } from '@modelcontextprotocol/server'; -import { createMcpExpressApp, isInitializeRequest, McpServer, StreamableHTTPServerTransport } from '@modelcontextprotocol/server'; +import { isInitializeRequest, McpServer, StreamableHTTPServerTransport } from '@modelcontextprotocol/server'; +import { createMcpExpressApp } from '@modelcontextprotocol/server-express'; import type { Request, Response } from 'express'; // Create an MCP server with implementation details diff --git a/examples/server/tsconfig.json b/examples/server/tsconfig.json index 98d3a5b3f..1f72b0199 100644 --- a/examples/server/tsconfig.json +++ b/examples/server/tsconfig.json @@ -6,6 +6,8 @@ "paths": { "*": ["./*"], "@modelcontextprotocol/server": ["./node_modules/@modelcontextprotocol/server/src/index.ts"], + "@modelcontextprotocol/server-express": ["./node_modules/@modelcontextprotocol/server-express/src/index.ts"], + "@modelcontextprotocol/server-hono": ["./node_modules/@modelcontextprotocol/server-hono/src/index.ts"], "@modelcontextprotocol/core": [ "./node_modules/@modelcontextprotocol/server/node_modules/@modelcontextprotocol/core/src/index.ts" ], diff --git a/examples/shared/package.json b/examples/shared/package.json index 8287ca552..2d0f6ebe9 100644 --- a/examples/shared/package.json +++ b/examples/shared/package.json @@ -35,6 +35,7 @@ }, "dependencies": { "@modelcontextprotocol/server": "workspace:^", + "@modelcontextprotocol/server-express": "workspace:^", "express": "catalog:runtimeServerOnly" }, "devDependencies": { diff --git a/examples/shared/src/demoInMemoryOAuthProvider.ts b/examples/shared/src/demoInMemoryOAuthProvider.ts index bcf11dd0c..23b168224 100644 --- a/examples/shared/src/demoInMemoryOAuthProvider.ts +++ b/examples/shared/src/demoInMemoryOAuthProvider.ts @@ -9,8 +9,9 @@ import type { OAuthServerProvider, OAuthTokens } from '@modelcontextprotocol/server'; -import { createOAuthMetadata, InvalidRequestError, mcpAuthRouter, resourceUrlFromServerUrl } from '@modelcontextprotocol/server'; -import type { Request, Response } from 'express'; +import { createOAuthMetadata, InvalidRequestError, resourceUrlFromServerUrl } from '@modelcontextprotocol/server'; +import { mcpAuthRouter } from '@modelcontextprotocol/server-express'; +import type { Request, Response as ExpressResponse } from 'express'; import express from 'express'; export class DemoInMemoryClientsStore implements OAuthRegisteredClientsStore { @@ -47,7 +48,7 @@ export class DemoInMemoryAuthProvider implements OAuthServerProvider { constructor(private validateResource?: (resource?: URL) => boolean) {} - async authorize(client: OAuthClientInformationFull, params: AuthorizationParams, res: Response): Promise { + async authorize(client: OAuthClientInformationFull, params: AuthorizationParams): Promise { const code = randomUUID(); const searchParams = new URLSearchParams({ @@ -64,27 +65,24 @@ export class DemoInMemoryAuthProvider implements OAuthServerProvider { // Simulate a user login // Set a secure HTTP-only session cookie with authorization info - if (res.cookie) { - const authCookieData = { - userId: 'demo_user', - name: 'Demo User', - timestamp: Date.now() - }; - res.cookie('demo_session', JSON.stringify(authCookieData), { - httpOnly: true, - secure: false, // In production, this should be true - sameSite: 'lax', - maxAge: 24 * 60 * 60 * 1000, // 24 hours - for demo purposes - path: '/' // Available to all routes - }); - } + const authCookieData = { + userId: 'demo_user', + name: 'Demo User', + timestamp: Date.now() + }; + const cookieValue = encodeURIComponent(JSON.stringify(authCookieData)); + const maxAgeSeconds = 24 * 60 * 60; // 24 hours - demo only + const setCookie = `demo_session=${cookieValue}; HttpOnly; SameSite=Lax; Max-Age=${maxAgeSeconds}; Path=/`; if (!client.redirect_uris.includes(params.redirectUri)) { throw new InvalidRequestError('Unregistered redirect_uri'); } const targetUrl = new URL(params.redirectUri); targetUrl.search = searchParams.toString(); - res.redirect(targetUrl.toString()); + const redirectResponse = Response.redirect(targetUrl.toString(), 302); + const headers = new Headers(redirectResponse.headers); + headers.append('Set-Cookie', setCookie); + return new Response(null, { status: redirectResponse.status, headers }); } async challengeForAuthorizationCode(client: OAuthClientInformationFull, authorizationCode: string): Promise { @@ -204,7 +202,7 @@ export const setupAuthServer = ({ }) ); - authApp.post('/introspect', async (req: Request, res: Response) => { + authApp.post('/introspect', async (req: Request, res: ExpressResponse) => { try { const { token } = req.body; if (!token) { diff --git a/examples/shared/test/demoInMemoryOAuthProvider.test.ts b/examples/shared/test/demoInMemoryOAuthProvider.test.ts index 4018dddbe..0c1c887aa 100644 --- a/examples/shared/test/demoInMemoryOAuthProvider.test.ts +++ b/examples/shared/test/demoInMemoryOAuthProvider.test.ts @@ -1,21 +1,15 @@ import type { OAuthClientInformationFull } from '@modelcontextprotocol/core'; import type { AuthorizationParams } from '@modelcontextprotocol/server'; import { InvalidRequestError } from '@modelcontextprotocol/server'; -import { createExpressResponseMock } from '@modelcontextprotocol/test-helpers'; -import type { Response } from 'express'; import { beforeEach, describe, expect, it } from 'vitest'; -import { DemoInMemoryAuthProvider, DemoInMemoryClientsStore } from '../src/demoInMemoryOAuthProvider.js'; +import { DemoInMemoryAuthProvider } from '../src/demoInMemoryOAuthProvider.js'; describe('DemoInMemoryAuthProvider', () => { let provider: DemoInMemoryAuthProvider; - let mockResponse: Response & { getRedirectUrl: () => string }; beforeEach(() => { provider = new DemoInMemoryAuthProvider(); - mockResponse = createExpressResponseMock({ trackRedirectUrl: true }) as Response & { - getRedirectUrl: () => string; - }; }); describe('authorize', () => { @@ -26,7 +20,7 @@ describe('DemoInMemoryAuthProvider', () => { scope: 'test-scope' }; - it('should redirect to the requested redirect_uri when valid', async () => { + it('redirects to redirect_uri when valid', async () => { const params: AuthorizationParams = { redirectUri: 'https://example.com/callback', state: 'test-state', @@ -34,18 +28,18 @@ describe('DemoInMemoryAuthProvider', () => { scopes: ['test-scope'] }; - await provider.authorize(validClient, params, mockResponse); - - expect(mockResponse.redirect).toHaveBeenCalled(); - expect(mockResponse.getRedirectUrl()).toBeDefined(); - - const url = new URL(mockResponse.getRedirectUrl()); + const res = await provider.authorize(validClient, params); + expect(res.status).toBe(302); + const location = res.headers.get('location'); + expect(location).toBeTruthy(); + const url = new URL(location!); expect(url.origin + url.pathname).toBe('https://example.com/callback'); expect(url.searchParams.get('state')).toBe('test-state'); - expect(url.searchParams.has('code')).toBe(true); + expect(url.searchParams.get('code')).toBeTruthy(); + expect(res.headers.get('set-cookie')).toContain('demo_session='); }); - it('should throw InvalidRequestError for unregistered redirect_uri', async () => { + it('throws InvalidRequestError for unregistered redirect_uri', async () => { const params: AuthorizationParams = { redirectUri: 'https://evil.com/callback', state: 'test-state', @@ -53,212 +47,8 @@ describe('DemoInMemoryAuthProvider', () => { scopes: ['test-scope'] }; - await expect(provider.authorize(validClient, params, mockResponse)).rejects.toThrow(InvalidRequestError); - - await expect(provider.authorize(validClient, params, mockResponse)).rejects.toThrow('Unregistered redirect_uri'); - - expect(mockResponse.redirect).not.toHaveBeenCalled(); - }); - - it('should generate unique authorization codes for multiple requests', async () => { - const params1: AuthorizationParams = { - redirectUri: 'https://example.com/callback', - state: 'state-1', - codeChallenge: 'challenge-1', - scopes: ['test-scope'] - }; - - const params2: AuthorizationParams = { - redirectUri: 'https://example.com/callback', - state: 'state-2', - codeChallenge: 'challenge-2', - scopes: ['test-scope'] - }; - - await provider.authorize(validClient, params1, mockResponse); - const firstRedirectUrl = mockResponse.getRedirectUrl(); - const firstCode = new URL(firstRedirectUrl).searchParams.get('code'); - - // Reset the mock for the second call - mockResponse = createExpressResponseMock({ trackRedirectUrl: true }) as Response & { - getRedirectUrl: () => string; - }; - await provider.authorize(validClient, params2, mockResponse); - const secondRedirectUrl = mockResponse.getRedirectUrl(); - const secondCode = new URL(secondRedirectUrl).searchParams.get('code'); - - expect(firstCode).toBeDefined(); - expect(secondCode).toBeDefined(); - expect(firstCode).not.toBe(secondCode); - }); - - it('should handle params without state', async () => { - const params: AuthorizationParams = { - redirectUri: 'https://example.com/callback', - codeChallenge: 'test-challenge', - scopes: ['test-scope'] - }; - - await provider.authorize(validClient, params, mockResponse); - - expect(mockResponse.redirect).toHaveBeenCalled(); - expect(mockResponse.getRedirectUrl()).toBeDefined(); - - const url = new URL(mockResponse.getRedirectUrl()); - expect(url.searchParams.has('state')).toBe(false); - expect(url.searchParams.has('code')).toBe(true); - }); - }); - - describe('challengeForAuthorizationCode', () => { - const validClient: OAuthClientInformationFull = { - client_id: 'test-client', - client_secret: 'test-secret', - redirect_uris: ['https://example.com/callback'], - scope: 'test-scope' - }; - - it('should return the code challenge for a valid authorization code', async () => { - const params: AuthorizationParams = { - redirectUri: 'https://example.com/callback', - state: 'test-state', - codeChallenge: 'test-challenge-value', - scopes: ['test-scope'] - }; - - await provider.authorize(validClient, params, mockResponse); - const code = new URL(mockResponse.getRedirectUrl()).searchParams.get('code')!; - - const challenge = await provider.challengeForAuthorizationCode(validClient, code); - expect(challenge).toBe('test-challenge-value'); - }); - - it('should throw error for invalid authorization code', async () => { - await expect(provider.challengeForAuthorizationCode(validClient, 'invalid-code')).rejects.toThrow('Invalid authorization code'); - }); - }); - - describe('exchangeAuthorizationCode', () => { - const validClient: OAuthClientInformationFull = { - client_id: 'test-client', - client_secret: 'test-secret', - redirect_uris: ['https://example.com/callback'], - scope: 'test-scope' - }; - - it('should exchange valid authorization code for tokens', async () => { - const params: AuthorizationParams = { - redirectUri: 'https://example.com/callback', - state: 'test-state', - codeChallenge: 'test-challenge', - scopes: ['test-scope', 'other-scope'] - }; - - await provider.authorize(validClient, params, mockResponse); - const code = new URL(mockResponse.getRedirectUrl()).searchParams.get('code')!; - - const tokens = await provider.exchangeAuthorizationCode(validClient, code); - - expect(tokens).toEqual({ - access_token: expect.any(String), - token_type: 'bearer', - expires_in: 3600, - scope: 'test-scope other-scope' - }); - }); - - it('should throw error for invalid authorization code', async () => { - await expect(provider.exchangeAuthorizationCode(validClient, 'invalid-code')).rejects.toThrow('Invalid authorization code'); - }); - - it('should throw error when client_id does not match', async () => { - const params: AuthorizationParams = { - redirectUri: 'https://example.com/callback', - state: 'test-state', - codeChallenge: 'test-challenge', - scopes: ['test-scope'] - }; - - await provider.authorize(validClient, params, mockResponse); - const code = new URL(mockResponse.getRedirectUrl()).searchParams.get('code')!; - - const differentClient: OAuthClientInformationFull = { - client_id: 'different-client', - client_secret: 'different-secret', - redirect_uris: ['https://example.com/callback'], - scope: 'test-scope' - }; - - await expect(provider.exchangeAuthorizationCode(differentClient, code)).rejects.toThrow( - 'Authorization code was not issued to this client' - ); - }); - - it('should delete authorization code after successful exchange', async () => { - const params: AuthorizationParams = { - redirectUri: 'https://example.com/callback', - state: 'test-state', - codeChallenge: 'test-challenge', - scopes: ['test-scope'] - }; - - await provider.authorize(validClient, params, mockResponse); - const code = new URL(mockResponse.getRedirectUrl()).searchParams.get('code')!; - - // First exchange should succeed - await provider.exchangeAuthorizationCode(validClient, code); - - // Second exchange should fail - await expect(provider.exchangeAuthorizationCode(validClient, code)).rejects.toThrow('Invalid authorization code'); - }); - - it('should validate resource when validateResource is provided', async () => { - const validateResource = vi.fn().mockReturnValue(false); - const strictProvider = new DemoInMemoryAuthProvider(validateResource); - - const params: AuthorizationParams = { - redirectUri: 'https://example.com/callback', - state: 'test-state', - codeChallenge: 'test-challenge', - scopes: ['test-scope'], - resource: new URL('https://invalid-resource.com') - }; - - await strictProvider.authorize(validClient, params, mockResponse); - const code = new URL(mockResponse.getRedirectUrl()).searchParams.get('code')!; - - await expect(strictProvider.exchangeAuthorizationCode(validClient, code)).rejects.toThrow( - 'Invalid resource: https://invalid-resource.com/' - ); - - expect(validateResource).toHaveBeenCalledWith(params.resource); - }); - }); - - describe('DemoInMemoryClientsStore', () => { - let store: DemoInMemoryClientsStore; - - beforeEach(() => { - store = new DemoInMemoryClientsStore(); - }); - - it('should register and retrieve client', async () => { - const client: OAuthClientInformationFull = { - client_id: 'test-client', - client_secret: 'test-secret', - redirect_uris: ['https://example.com/callback'], - scope: 'test-scope' - }; - - await store.registerClient(client); - const retrieved = await store.getClient('test-client'); - - expect(retrieved).toEqual(client); - }); - - it('should return undefined for non-existent client', async () => { - const retrieved = await store.getClient('non-existent'); - expect(retrieved).toBeUndefined(); + await expect(provider.authorize(validClient, params)).rejects.toThrow(InvalidRequestError); + await expect(provider.authorize(validClient, params)).rejects.toThrow('Unregistered redirect_uri'); }); }); }); diff --git a/examples/shared/tsconfig.json b/examples/shared/tsconfig.json index aa994f939..91e368e7a 100644 --- a/examples/shared/tsconfig.json +++ b/examples/shared/tsconfig.json @@ -6,6 +6,7 @@ "paths": { "*": ["./*"], "@modelcontextprotocol/server": ["./node_modules/@modelcontextprotocol/server/src/index.ts"], + "@modelcontextprotocol/server-express": ["./node_modules/@modelcontextprotocol/server-express/src/index.ts"], "@modelcontextprotocol/core": [ "./node_modules/@modelcontextprotocol/server/node_modules/@modelcontextprotocol/core/src/index.ts" ], diff --git a/package.json b/package.json index 82dd50d74..3d37f5c70 100644 --- a/package.json +++ b/package.json @@ -15,7 +15,7 @@ "node": ">=20", "pnpm": ">=10.24.0" }, - "packageManager": "pnpm@10.24.0", + "packageManager": "pnpm@10.26.1", "keywords": [ "modelcontextprotocol", "mcp" diff --git a/packages/client/src/client/sse.ts b/packages/client/src/client/sse.ts index bff74986e..4cfa77e17 100644 --- a/packages/client/src/client/sse.ts +++ b/packages/client/src/client/sse.ts @@ -1,6 +1,7 @@ import type { FetchLike, JSONRPCMessage, Transport } from '@modelcontextprotocol/core'; import { createFetchWithInit, JSONRPCMessageSchema, normalizeHeaders } from '@modelcontextprotocol/core'; -import { type ErrorEvent, EventSource, type EventSourceInit } from 'eventsource'; +import type { ErrorEvent, EventSourceInit } from 'eventsource'; +import { EventSource } from 'eventsource'; import type { AuthResult, OAuthClientProvider } from './auth.js'; import { auth, extractWWWAuthenticateParams, UnauthorizedError } from './auth.js'; diff --git a/packages/server-express/README.md b/packages/server-express/README.md new file mode 100644 index 000000000..c1b10a9c7 --- /dev/null +++ b/packages/server-express/README.md @@ -0,0 +1,81 @@ +# `@modelcontextprotocol/server-express` + +Express adapters for the MCP TypeScript server SDK. + +This package is the Express-specific companion to [`@modelcontextprotocol/server`](../server/), which is framework-agnostic and uses Web Standard `Request`/`Response` interfaces. + +## Install + +```bash +npm install @modelcontextprotocol/server @modelcontextprotocol/server-express zod +``` + +## Exports + +- `createMcpExpressApp(options?)` +- `hostHeaderValidation(allowedHosts)` +- `localhostHostValidation()` +- `mcpAuthRouter(options)` +- `mcpAuthMetadataRouter(options)` +- `requireBearerAuth(options)` + +## Usage + +### Create an Express app with localhost DNS rebinding protection + +```ts +import { createMcpExpressApp } from '@modelcontextprotocol/server-express'; + +const app = createMcpExpressApp(); // default host is 127.0.0.1; protection enabled +``` + +### Streamable HTTP endpoint (Express) + +```ts +import { McpServer, StreamableHTTPServerTransport } from '@modelcontextprotocol/server'; +import { createMcpExpressApp } from '@modelcontextprotocol/server-express'; + +const app = createMcpExpressApp(); + +app.post('/mcp', async (req, res) => { + const transport = new StreamableHTTPServerTransport(); + await transport.handleRequest(req, res, req.body); +}); +``` + +### OAuth routes (Express) + +`@modelcontextprotocol/server` provides Web-standard auth handlers; this package wraps them as Express routers. + +```ts +import { mcpAuthRouter } from '@modelcontextprotocol/server-express'; +import type { OAuthServerProvider } from '@modelcontextprotocol/server'; +import express from 'express'; + +const provider: OAuthServerProvider = /* ... */; +const app = express(); +app.use(express.json()); + +// MUST be mounted at the app root +app.use( + mcpAuthRouter({ + provider, + issuerUrl: new URL('https://auth.example.com') + }) +); +``` + +### Bearer auth middleware (Express) + +`requireBearerAuth` validates the `Authorization: Bearer ...` header and sets `req.auth` on success. + +```ts +import { requireBearerAuth } from '@modelcontextprotocol/server-express'; +import type { OAuthTokenVerifier } from '@modelcontextprotocol/server'; + +const verifier: OAuthTokenVerifier = /* ... */; + +app.post('/protected', requireBearerAuth({ verifier }), (req, res) => { + res.json({ clientId: req.auth?.clientId }); +}); +``` diff --git a/packages/server-express/eslint.config.mjs b/packages/server-express/eslint.config.mjs new file mode 100644 index 000000000..03d533134 --- /dev/null +++ b/packages/server-express/eslint.config.mjs @@ -0,0 +1,12 @@ +// @ts-check + +import baseConfig from '@modelcontextprotocol/eslint-config'; + +export default [ + ...baseConfig, + { + settings: { + 'import/internal-regex': '^@modelcontextprotocol/(server|core)' + } + } +]; diff --git a/packages/server-express/package.json b/packages/server-express/package.json new file mode 100644 index 000000000..d7d31cb1e --- /dev/null +++ b/packages/server-express/package.json @@ -0,0 +1,67 @@ +{ + "name": "@modelcontextprotocol/server-express", + "version": "2.0.0-alpha.0", + "description": "Express adapters for the Model Context Protocol TypeScript server SDK", + "license": "MIT", + "author": "Anthropic, PBC (https://anthropic.com)", + "homepage": "https://modelcontextprotocol.io", + "bugs": "https://github.com/modelcontextprotocol/typescript-sdk/issues", + "type": "module", + "repository": { + "type": "git", + "url": "git+https://github.com/modelcontextprotocol/typescript-sdk.git" + }, + "engines": { + "node": ">=20", + "pnpm": ">=10.24.0" + }, + "packageManager": "pnpm@10.24.0", + "keywords": [ + "modelcontextprotocol", + "mcp", + "express" + ], + "exports": { + ".": { + "types": "./dist/index.d.mts", + "import": "./dist/index.mjs" + } + }, + "files": [ + "dist" + ], + "scripts": { + "typecheck": "tsgo -p tsconfig.json --noEmit", + "build": "tsdown", + "build:watch": "tsdown --watch", + "prepack": "npm run build", + "lint": "eslint src/ && prettier --ignore-path ../../.prettierignore --check .", + "lint:fix": "eslint src/ --fix && prettier --ignore-path ../../.prettierignore --write .", + "check": "npm run typecheck && npm run lint", + "test": "vitest run", + "test:watch": "vitest" + }, + "dependencies": { + "@modelcontextprotocol/server": "workspace:^", + "express": "catalog:runtimeServerOnly" + }, + "devDependencies": { + "@modelcontextprotocol/tsconfig": "workspace:^", + "@modelcontextprotocol/vitest-config": "workspace:^", + "@modelcontextprotocol/eslint-config": "workspace:^", + "@eslint/js": "catalog:devTools", + "@types/express": "catalog:devTools", + "@types/express-serve-static-core": "catalog:devTools", + "@types/supertest": "catalog:devTools", + "@typescript/native-preview": "catalog:devTools", + "eslint": "catalog:devTools", + "eslint-config-prettier": "catalog:devTools", + "eslint-plugin-n": "catalog:devTools", + "prettier": "catalog:devTools", + "supertest": "catalog:devTools", + "tsdown": "catalog:devTools", + "typescript": "catalog:devTools", + "typescript-eslint": "catalog:devTools", + "vitest": "catalog:devTools" + } +} diff --git a/packages/server-express/src/auth/bearerAuth.ts b/packages/server-express/src/auth/bearerAuth.ts new file mode 100644 index 000000000..a923ff796 --- /dev/null +++ b/packages/server-express/src/auth/bearerAuth.ts @@ -0,0 +1,62 @@ +import { URL } from 'node:url'; + +import type { AuthInfo } from '@modelcontextprotocol/core'; +import type { BearerAuthMiddlewareOptions } from '@modelcontextprotocol/server'; +import { requireBearerAuth as requireBearerAuthWeb } from '@modelcontextprotocol/server'; +import type { NextFunction, Request as ExpressRequest, RequestHandler, Response as ExpressResponse } from 'express'; + +declare module 'express-serve-static-core' { + interface Request { + /** + * Information about the validated access token, if `requireBearerAuth` was used. + */ + auth?: AuthInfo; + } +} + +function expressRequestUrl(req: ExpressRequest): URL { + const host = req.get('host') ?? req.headers.host ?? 'localhost'; + const protocol = req.protocol ?? 'http'; + const path = req.originalUrl ?? req.url ?? '/'; + return new URL(path, `${protocol}://${host}`); +} + +async function writeWebResponse(res: ExpressResponse, webResponse: Response): Promise { + res.status(webResponse.status); + for (const [k, v] of webResponse.headers.entries()) { + res.setHeader(k, v); + } + const bodyText = await webResponse.text(); + res.send(bodyText); +} + +/** + * Express middleware wrapper for the Web-standard `requireBearerAuth` helper. + * + * On success, sets `req.auth` and calls `next()`. + * On failure, writes the JSON error response and ends the request. + */ +export function requireBearerAuth(options: BearerAuthMiddlewareOptions): RequestHandler { + return async (req, res, next: NextFunction) => { + try { + const url = expressRequestUrl(req); + const webReq = new Request(url, { + method: req.method, + headers: { + authorization: req.headers.authorization ?? '' + } + }); + + const result = await requireBearerAuthWeb(webReq, options); + if ('authInfo' in result) { + req.auth = result.authInfo; + next(); + return; + } + + await writeWebResponse(res, result.response); + } catch (err) { + next(err); + } + }; +} diff --git a/packages/server-express/src/auth/router.ts b/packages/server-express/src/auth/router.ts new file mode 100644 index 000000000..c50d2007d --- /dev/null +++ b/packages/server-express/src/auth/router.ts @@ -0,0 +1,172 @@ +import type { IncomingMessage } from 'node:http'; +import { Readable } from 'node:stream'; +import { URL } from 'node:url'; + +import type { AuthMetadataOptions, AuthRouterOptions, WebHandlerContext } from '@modelcontextprotocol/server'; +import { mcpAuthMetadataRouter as createWebAuthMetadataRouter, mcpAuthRouter as createWebAuthRouter } from '@modelcontextprotocol/server'; +import type { RequestHandler, Response as ExpressResponse } from 'express'; +import express from 'express'; + +type ExpressRequestLike = IncomingMessage & { + method: string; + headers: Record; + originalUrl?: string; + url?: string; + protocol?: string; + // express adds this when trust proxy is enabled + ip?: string; + body?: unknown; + get?: (name: string) => string | undefined; +}; + +function expressRequestUrl(req: ExpressRequestLike): URL { + const host = req.get?.('host') ?? req.headers.host ?? 'localhost'; + const proto = req.protocol ?? 'http'; + const path = req.originalUrl ?? req.url ?? '/'; + return new URL(path, `${proto}://${host}`); +} + +function toHeaders(req: ExpressRequestLike): Headers { + const headers = new Headers(); + for (const [key, value] of Object.entries(req.headers)) { + if (value === undefined) continue; + if (Array.isArray(value)) { + headers.set(key, value.join(', ')); + } else { + headers.set(key, value); + } + } + return headers; +} + +async function readBody(req: IncomingMessage): Promise { + const chunks: Buffer[] = []; + for await (const chunk of req) { + chunks.push(Buffer.isBuffer(chunk) ? chunk : Buffer.from(chunk)); + } + return Buffer.concat(chunks); +} + +async function expressToWebRequest(req: ExpressRequestLike, parsedBodyProvided: boolean): Promise { + const url = expressRequestUrl(req); + const headers = toHeaders(req); + + // If upstream body parsing ran, the Node stream is likely consumed. + if (parsedBodyProvided) { + return new Request(url, { method: req.method, headers }); + } + + if (req.method === 'GET' || req.method === 'HEAD') { + return new Request(url, { method: req.method, headers }); + } + + const body = await readBody(req); + return new Request(url, { method: req.method, headers, body }); +} + +async function writeWebResponse(res: ExpressResponse, webResponse: Response): Promise { + res.status(webResponse.status); + + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const getSetCookie = (webResponse.headers as any).getSetCookie as (() => string[]) | undefined; + const setCookies = typeof getSetCookie === 'function' ? getSetCookie.call(webResponse.headers) : undefined; + + for (const [key, value] of webResponse.headers.entries()) { + if (key.toLowerCase() === 'set-cookie' && setCookies?.length) continue; + res.setHeader(key, value); + } + + if (setCookies?.length) { + res.setHeader('set-cookie', setCookies); + } + + res.flushHeaders?.(); + + if (!webResponse.body) { + res.end(); + return; + } + + await new Promise((resolve, reject) => { + const readable = Readable.fromWeb(webResponse.body as unknown as ReadableStream); + readable.on('error', err => { + try { + res.destroy(err as Error); + } catch { + // ignore + } + reject(err); + }); + res.on('error', reject); + res.on('close', () => { + try { + readable.destroy(); + } catch { + // ignore + } + }); + readable.pipe(res); + res.on('finish', () => resolve()); + }); +} + +function toHandlerContext(req: ExpressRequestLike): WebHandlerContext { + return { + parsedBody: req.body, + clientAddress: req.ip + }; +} + +/** + * Express router adapter for the Web-standard `mcpAuthRouter` from `@modelcontextprotocol/server`. + * + * IMPORTANT: This router MUST be mounted at the application root, like: + * + * ```ts + * app.use(mcpAuthRouter(...)) + * ``` + */ +export function mcpAuthRouter(options: AuthRouterOptions): RequestHandler { + const web = createWebAuthRouter(options); + const router = express.Router(); + + for (const route of web.routes) { + router.all(route.path, async (req, res, next) => { + try { + const parsedBodyProvided = (req as ExpressRequestLike).body !== undefined; + const webReq = await expressToWebRequest(req as ExpressRequestLike, parsedBodyProvided); + const webRes = await route.handler(webReq, toHandlerContext(req as ExpressRequestLike)); + await writeWebResponse(res, webRes); + } catch (err) { + next(err); + } + }); + } + + return router; +} + +/** + * Express router adapter for the Web-standard `mcpAuthMetadataRouter` from `@modelcontextprotocol/server`. + * + * IMPORTANT: This router MUST be mounted at the application root. + */ +export function mcpAuthMetadataRouter(options: AuthMetadataOptions): RequestHandler { + const web = createWebAuthMetadataRouter(options); + const router = express.Router(); + + for (const route of web.routes) { + router.all(route.path, async (req, res, next) => { + try { + const parsedBodyProvided = (req as ExpressRequestLike).body !== undefined; + const webReq = await expressToWebRequest(req as ExpressRequestLike, parsedBodyProvided); + const webRes = await route.handler(webReq, toHandlerContext(req as ExpressRequestLike)); + await writeWebResponse(res, webRes); + } catch (err) { + next(err); + } + }); + } + + return router; +} diff --git a/packages/server/src/server/express.ts b/packages/server-express/src/express.ts similarity index 100% rename from packages/server/src/server/express.ts rename to packages/server-express/src/express.ts diff --git a/packages/server-express/src/index.ts b/packages/server-express/src/index.ts new file mode 100644 index 000000000..3c5b72fe7 --- /dev/null +++ b/packages/server-express/src/index.ts @@ -0,0 +1,4 @@ +export * from './auth/bearerAuth.js'; +export * from './auth/router.js'; +export * from './express.js'; +export * from './middleware/hostHeaderValidation.js'; diff --git a/packages/server-express/src/middleware/hostHeaderValidation.ts b/packages/server-express/src/middleware/hostHeaderValidation.ts new file mode 100644 index 000000000..00ee74e1f --- /dev/null +++ b/packages/server-express/src/middleware/hostHeaderValidation.ts @@ -0,0 +1,52 @@ +import { localhostAllowedHostnames, validateHostHeader } from '@modelcontextprotocol/server'; +import type { NextFunction, Request, RequestHandler, Response } from 'express'; + +/** + * Express middleware for DNS rebinding protection. + * Validates Host header hostname (port-agnostic) against an allowed list. + * + * This is particularly important for servers without authorization or HTTPS, + * such as localhost servers or development servers. DNS rebinding attacks can + * bypass same-origin policy by manipulating DNS to point a domain to a + * localhost address, allowing malicious websites to access your local server. + * + * @param allowedHostnames - List of allowed hostnames (without ports). + * For IPv6, provide the address with brackets (e.g., '[::1]'). + * @returns Express middleware function + * + * @example + * ```typescript + * const middleware = hostHeaderValidation(['localhost', '127.0.0.1', '[::1]']); + * app.use(middleware); + * ``` + */ +export function hostHeaderValidation(allowedHostnames: string[]): RequestHandler { + return (req: Request, res: Response, next: NextFunction) => { + const result = validateHostHeader(req.headers.host, allowedHostnames); + if (!result.ok) { + res.status(403).json({ + jsonrpc: '2.0', + error: { + code: -32000, + message: result.message + }, + id: null + }); + return; + } + next(); + }; +} + +/** + * Convenience middleware for localhost DNS rebinding protection. + * Allows only localhost, 127.0.0.1, and [::1] (IPv6 localhost) hostnames. + * + * @example + * ```typescript + * app.use(localhostHostValidation()); + * ``` + */ +export function localhostHostValidation(): RequestHandler { + return hostHeaderValidation(localhostAllowedHostnames()); +} diff --git a/packages/server/test/server/auth/router.test.ts b/packages/server-express/test/server/auth/router.test.ts similarity index 95% rename from packages/server/test/server/auth/router.test.ts rename to packages/server-express/test/server/auth/router.test.ts index 250fca4c4..7a6c09690 100644 --- a/packages/server/test/server/auth/router.test.ts +++ b/packages/server-express/test/server/auth/router.test.ts @@ -1,14 +1,18 @@ -import type { OAuthClientInformationFull, OAuthMetadata, OAuthTokenRevocationRequest, OAuthTokens } from '@modelcontextprotocol/core'; -import type { AuthInfo } from '@modelcontextprotocol/core'; -import { InvalidTokenError } from '@modelcontextprotocol/core'; -import type { Response } from 'express'; +import type { + AuthInfo, + OAuthClientInformationFull, + OAuthMetadata, + OAuthTokenRevocationRequest, + OAuthTokens +} from '@modelcontextprotocol/server'; +import { InvalidTokenError } from '@modelcontextprotocol/server'; import express from 'express'; import supertest from 'supertest'; -import type { OAuthRegisteredClientsStore } from '../../../src/server/auth/clients.js'; -import type { AuthorizationParams, OAuthServerProvider } from '../../../src/server/auth/provider.js'; -import type { AuthMetadataOptions, AuthRouterOptions } from '../../../src/server/auth/router.js'; -import { mcpAuthMetadataRouter, mcpAuthRouter } from '../../../src/server/auth/router.js'; +import type { OAuthRegisteredClientsStore } from '@modelcontextprotocol/server'; +import type { AuthorizationParams, OAuthServerProvider } from '@modelcontextprotocol/server'; +import type { AuthMetadataOptions, AuthRouterOptions } from '@modelcontextprotocol/server'; +import { mcpAuthMetadataRouter, mcpAuthRouter } from '../../../src/auth/router.js'; describe('MCP Auth Router', () => { // Setup mock provider with full capabilities @@ -32,13 +36,13 @@ describe('MCP Auth Router', () => { const mockProvider: OAuthServerProvider = { clientsStore: mockClientStore, - async authorize(client: OAuthClientInformationFull, params: AuthorizationParams, res: Response): Promise { + async authorize(_client: OAuthClientInformationFull, params: AuthorizationParams): Promise { const redirectUrl = new URL(params.redirectUri); redirectUrl.searchParams.set('code', 'mock_auth_code'); if (params.state) { redirectUrl.searchParams.set('state', params.state); } - res.redirect(302, redirectUrl.toString()); + return Response.redirect(redirectUrl.toString(), 302); }, async challengeForAuthorizationCode(): Promise { @@ -95,13 +99,13 @@ describe('MCP Auth Router', () => { } }, - async authorize(client: OAuthClientInformationFull, params: AuthorizationParams, res: Response): Promise { + async authorize(_client: OAuthClientInformationFull, params: AuthorizationParams): Promise { const redirectUrl = new URL(params.redirectUri); redirectUrl.searchParams.set('code', 'mock_auth_code'); if (params.state) { redirectUrl.searchParams.set('state', params.state); } - res.redirect(302, redirectUrl.toString()); + return Response.redirect(redirectUrl.toString(), 302); }, async challengeForAuthorizationCode(): Promise { diff --git a/packages/server-express/tsconfig.json b/packages/server-express/tsconfig.json new file mode 100644 index 000000000..0d7fdd0c0 --- /dev/null +++ b/packages/server-express/tsconfig.json @@ -0,0 +1,14 @@ +{ + "extends": "@modelcontextprotocol/tsconfig", + "include": ["./"], + "exclude": ["node_modules", "dist"], + "compilerOptions": { + "paths": { + "*": ["./*"], + "@modelcontextprotocol/server": ["./node_modules/@modelcontextprotocol/server/src/index.ts"], + "@modelcontextprotocol/core": [ + "./node_modules/@modelcontextprotocol/server/node_modules/@modelcontextprotocol/core/src/index.ts" + ] + } + } +} diff --git a/packages/server-express/tsdown.config.ts b/packages/server-express/tsdown.config.ts new file mode 100644 index 000000000..c72e7a2c4 --- /dev/null +++ b/packages/server-express/tsdown.config.ts @@ -0,0 +1,23 @@ +import { defineConfig } from 'tsdown'; + +export default defineConfig({ + entry: ['src/index.ts'], + format: ['esm'], + outDir: 'dist', + clean: true, + sourcemap: true, + target: 'esnext', + platform: 'node', + shims: true, + dts: { + resolver: 'tsc', + compilerOptions: { + baseUrl: '.', + paths: { + '@modelcontextprotocol/server': ['../server/src/index.ts'], + '@modelcontextprotocol/core': ['../core/src/index.ts'] + } + } + }, + noExternal: ['@modelcontextprotocol/server', '@modelcontextprotocol/core'] +}); diff --git a/packages/server-express/vitest.config.js b/packages/server-express/vitest.config.js new file mode 100644 index 000000000..496fca320 --- /dev/null +++ b/packages/server-express/vitest.config.js @@ -0,0 +1,3 @@ +import baseConfig from '@modelcontextprotocol/vitest-config'; + +export default baseConfig; diff --git a/packages/server-hono/README.md b/packages/server-hono/README.md new file mode 100644 index 000000000..d2788881a --- /dev/null +++ b/packages/server-hono/README.md @@ -0,0 +1,64 @@ +# `@modelcontextprotocol/server-hono` + +Hono adapters for the MCP TypeScript server SDK. + +This package is the Hono-specific companion to [`@modelcontextprotocol/server`](../server/), which is framework-agnostic and uses Web Standard `Request`/`Response` interfaces. + +## Install + +```bash +npm install @modelcontextprotocol/server @modelcontextprotocol/server-hono hono zod +``` + +## Exports + +- `mcpStreamableHttpHandler(transport)` +- `registerMcpAuthRoutes(app, options)` +- `registerMcpAuthMetadataRoutes(app, options)` +- `hostHeaderValidation(allowedHosts)` +- `localhostHostValidation()` + +## Usage + +### Streamable HTTP endpoint (Hono) + +```ts +import { Hono } from 'hono'; +import { McpServer, WebStandardStreamableHTTPServerTransport } from '@modelcontextprotocol/server'; +import { mcpStreamableHttpHandler } from '@modelcontextprotocol/server-hono'; + +const server = new McpServer({ name: 'my-server', version: '1.0.0' }); +const transport = new WebStandardStreamableHTTPServerTransport(); +await server.connect(transport); + +const app = new Hono(); +app.all('/mcp', mcpStreamableHttpHandler(transport)); +``` + +### OAuth routes (Hono) + +`@modelcontextprotocol/server` provides Web-standard auth handlers; this package mounts them onto a Hono app. + +```ts +import { Hono } from 'hono'; +import type { OAuthServerProvider } from '@modelcontextprotocol/server'; +import { registerMcpAuthRoutes } from '@modelcontextprotocol/server-hono'; + +const provider: OAuthServerProvider = /* ... */; + +const app = new Hono(); +registerMcpAuthRoutes(app, { + provider, + issuerUrl: new URL('https://auth.example.com') +}); +``` + +### Host header validation (DNS rebinding protection) + +```ts +import { Hono } from 'hono'; +import { localhostHostValidation } from '@modelcontextprotocol/server-hono'; + +const app = new Hono(); +app.use('*', localhostHostValidation()); +``` diff --git a/packages/server-hono/eslint.config.mjs b/packages/server-hono/eslint.config.mjs new file mode 100644 index 000000000..03d533134 --- /dev/null +++ b/packages/server-hono/eslint.config.mjs @@ -0,0 +1,12 @@ +// @ts-check + +import baseConfig from '@modelcontextprotocol/eslint-config'; + +export default [ + ...baseConfig, + { + settings: { + 'import/internal-regex': '^@modelcontextprotocol/(server|core)' + } + } +]; diff --git a/packages/server-hono/package.json b/packages/server-hono/package.json new file mode 100644 index 000000000..33f633d40 --- /dev/null +++ b/packages/server-hono/package.json @@ -0,0 +1,63 @@ +{ + "name": "@modelcontextprotocol/server-hono", + "version": "2.0.0-alpha.0", + "description": "Hono adapters for the Model Context Protocol TypeScript server SDK", + "license": "MIT", + "author": "Anthropic, PBC (https://anthropic.com)", + "homepage": "https://modelcontextprotocol.io", + "bugs": "https://github.com/modelcontextprotocol/typescript-sdk/issues", + "type": "module", + "repository": { + "type": "git", + "url": "git+https://github.com/modelcontextprotocol/typescript-sdk.git" + }, + "engines": { + "node": ">=20", + "pnpm": ">=10.24.0" + }, + "packageManager": "pnpm@10.24.0", + "keywords": [ + "modelcontextprotocol", + "mcp", + "hono" + ], + "exports": { + ".": { + "types": "./dist/index.d.mts", + "import": "./dist/index.mjs" + } + }, + "files": [ + "dist" + ], + "scripts": { + "typecheck": "tsgo -p tsconfig.json --noEmit", + "build": "tsdown", + "build:watch": "tsdown --watch", + "prepack": "npm run build", + "lint": "eslint src/ && prettier --ignore-path ../../.prettierignore --check .", + "lint:fix": "eslint src/ --fix && prettier --ignore-path ../../.prettierignore --write .", + "check": "npm run typecheck && npm run lint", + "test": "vitest run", + "test:watch": "vitest" + }, + "dependencies": { + "@modelcontextprotocol/server": "workspace:^", + "hono": "catalog:runtimeServerOnly" + }, + "devDependencies": { + "@modelcontextprotocol/tsconfig": "workspace:^", + "@modelcontextprotocol/vitest-config": "workspace:^", + "@modelcontextprotocol/eslint-config": "workspace:^", + "@eslint/js": "catalog:devTools", + "@typescript/native-preview": "catalog:devTools", + "eslint": "catalog:devTools", + "eslint-config-prettier": "catalog:devTools", + "eslint-plugin-n": "catalog:devTools", + "prettier": "catalog:devTools", + "tsdown": "catalog:devTools", + "typescript": "catalog:devTools", + "typescript-eslint": "catalog:devTools", + "vitest": "catalog:devTools" + } +} diff --git a/packages/server-hono/src/auth/router.ts b/packages/server-hono/src/auth/router.ts new file mode 100644 index 000000000..4c61c1d2c --- /dev/null +++ b/packages/server-hono/src/auth/router.ts @@ -0,0 +1,33 @@ +import type { AuthMetadataOptions, AuthRoute, AuthRouterOptions } from '@modelcontextprotocol/server'; +import { mcpAuthMetadataRouter as createWebAuthMetadataRouter, mcpAuthRouter as createWebAuthRouter } from '@modelcontextprotocol/server'; +import type { Handler, Hono } from 'hono'; + +export type RegisterMcpAuthRoutesOptions = AuthRouterOptions; + +/** + * Registers the standard MCP OAuth endpoints on a Hono app. + * + * IMPORTANT: These routes MUST be mounted at the application root. + */ +export function registerMcpAuthRoutes(app: Hono, options: RegisterMcpAuthRoutesOptions): void { + const web = createWebAuthRouter(options); + registerRoutes(app, web.routes); +} + +/** + * Registers only the auth metadata endpoints (RFC 8414 + RFC 9728) on a Hono app. + * + * IMPORTANT: These routes MUST be mounted at the application root. + */ +export function registerMcpAuthMetadataRoutes(app: Hono, options: AuthMetadataOptions): void { + const web = createWebAuthMetadataRouter(options); + registerRoutes(app, web.routes); +} + +function registerRoutes(app: Hono, routes: AuthRoute[]): void { + for (const route of routes) { + // Hono's `on()` expects methods like 'GET', 'POST', etc. + const handler: Handler = c => route.handler(c.req.raw); + app.on(route.methods, route.path, handler); + } +} diff --git a/packages/server-hono/src/index.ts b/packages/server-hono/src/index.ts new file mode 100644 index 000000000..5a7cb5129 --- /dev/null +++ b/packages/server-hono/src/index.ts @@ -0,0 +1,3 @@ +export * from './auth/router.js'; +export * from './middleware/hostHeaderValidation.js'; +export * from './streamableHttp.js'; diff --git a/packages/server-hono/src/middleware/hostHeaderValidation.ts b/packages/server-hono/src/middleware/hostHeaderValidation.ts new file mode 100644 index 000000000..8f7b20e88 --- /dev/null +++ b/packages/server-hono/src/middleware/hostHeaderValidation.ts @@ -0,0 +1,33 @@ +import { localhostAllowedHostnames, validateHostHeader } from '@modelcontextprotocol/server'; +import type { MiddlewareHandler } from 'hono'; + +/** + * Hono middleware for DNS rebinding protection. + * Validates Host header hostname (port-agnostic) against an allowed list. + */ +export function hostHeaderValidation(allowedHostnames: string[]): MiddlewareHandler { + return async (c, next) => { + const result = validateHostHeader(c.req.header('host'), allowedHostnames); + if (!result.ok) { + return c.json( + { + jsonrpc: '2.0', + error: { + code: -32000, + message: result.message + }, + id: null + }, + 403 + ); + } + return await next(); + }; +} + +/** + * Convenience middleware for localhost DNS rebinding protection. + */ +export function localhostHostValidation(): MiddlewareHandler { + return hostHeaderValidation(localhostAllowedHostnames()); +} diff --git a/packages/server-hono/src/streamableHttp.ts b/packages/server-hono/src/streamableHttp.ts new file mode 100644 index 000000000..d81960713 --- /dev/null +++ b/packages/server-hono/src/streamableHttp.ts @@ -0,0 +1,14 @@ +import type { WebStandardStreamableHTTPServerTransport } from '@modelcontextprotocol/server'; +import type { Context, Handler } from 'hono'; + +/** + * Convenience Hono handler for the WebStandard Streamable HTTP transport. + * + * Usage: + * ```ts + * app.all('/mcp', mcpStreamableHttpHandler(transport)) + * ``` + */ +export function mcpStreamableHttpHandler(transport: WebStandardStreamableHTTPServerTransport): Handler { + return (c: Context) => transport.handleRequest(c.req.raw); +} diff --git a/packages/server-hono/test/server-hono.test.ts b/packages/server-hono/test/server-hono.test.ts new file mode 100644 index 000000000..8b143411b --- /dev/null +++ b/packages/server-hono/test/server-hono.test.ts @@ -0,0 +1,114 @@ +import type { AuthorizationParams, OAuthClientInformationFull, OAuthServerProvider, OAuthTokens } from '@modelcontextprotocol/server'; +import { Hono } from 'hono'; + +import { registerMcpAuthRoutes } from '../src/auth/router.js'; +import { hostHeaderValidation } from '../src/middleware/hostHeaderValidation.js'; +import { mcpStreamableHttpHandler } from '../src/streamableHttp.js'; + +describe('@modelcontextprotocol/server-hono', () => { + test('mcpStreamableHttpHandler delegates to transport.handleRequest', async () => { + const calls: { url?: string; method?: string }[] = []; + + const transport = { + async handleRequest(req: Request): Promise { + calls.push({ url: req.url, method: req.method }); + return new Response('ok', { status: 200, headers: { 'content-type': 'text/plain' } }); + } + }; + + const app = new Hono(); + app.all('/mcp', mcpStreamableHttpHandler(transport as unknown as Parameters[0])); + + const res = await app.request('http://localhost/mcp', { method: 'POST' }); + expect(res.status).toBe(200); + expect(await res.text()).toBe('ok'); + expect(calls).toHaveLength(1); + expect(calls[0]!.method).toBe('POST'); + expect(calls[0]!.url).toBe('http://localhost/mcp'); + }); + + test('hostHeaderValidation blocks invalid Host and allows valid Host', async () => { + const app = new Hono(); + app.use('*', hostHeaderValidation(['localhost'])); + app.get('/health', c => c.text('ok')); + + const bad = await app.request('http://localhost/health', { headers: { Host: 'evil.com:3000' } }); + expect(bad.status).toBe(403); + expect(await bad.json()).toEqual( + expect.objectContaining({ + jsonrpc: '2.0', + error: expect.objectContaining({ + code: -32000 + }), + id: null + }) + ); + + const good = await app.request('http://localhost/health', { headers: { Host: 'localhost:3000' } }); + expect(good.status).toBe(200); + expect(await good.text()).toBe('ok'); + }); + + test('registerMcpAuthRoutes mounts metadata + authorize routes', async () => { + const validClient: OAuthClientInformationFull = { + client_id: 'valid-client', + client_secret: 'valid-secret', + redirect_uris: ['https://example.com/callback'] + }; + + const provider: OAuthServerProvider = { + clientsStore: { + async getClient(clientId: string) { + return clientId === 'valid-client' ? validClient : undefined; + } + }, + async authorize(_client: OAuthClientInformationFull, params: AuthorizationParams): Promise { + const u = new URL(params.redirectUri); + u.searchParams.set('code', 'mock_auth_code'); + if (params.state) u.searchParams.set('state', params.state); + return Response.redirect(u.toString(), 302); + }, + async challengeForAuthorizationCode(): Promise { + return 'mock_challenge'; + }, + async exchangeAuthorizationCode(): Promise { + return { + access_token: 'mock_access_token', + token_type: 'bearer', + expires_in: 3600, + refresh_token: 'mock_refresh_token' + }; + }, + async exchangeRefreshToken(): Promise { + return { + access_token: 'new_mock_access_token', + token_type: 'bearer', + expires_in: 3600, + refresh_token: 'new_mock_refresh_token' + }; + }, + async verifyAccessToken() { + throw new Error('not used'); + } + }; + + const app = new Hono(); + registerMcpAuthRoutes(app, { provider, issuerUrl: new URL('https://auth.example.com') }); + + const metadata = await app.request('http://localhost/.well-known/oauth-authorization-server', { method: 'GET' }); + expect(metadata.status).toBe(200); + const metaJson = (await metadata.json()) as { issuer?: string; authorization_endpoint?: string }; + expect(metaJson.issuer).toBe('https://auth.example.com/'); + expect(metaJson.authorization_endpoint).toBe('https://auth.example.com/authorize'); + + const authorize = await app.request( + 'http://localhost/authorize?client_id=valid-client&response_type=code&code_challenge=x&code_challenge_method=S256&redirect_uri=https%3A%2F%2Fexample.com%2Fcallback&state=s', + { method: 'GET' } + ); + expect(authorize.status).toBe(302); + const location = authorize.headers.get('location')!; + expect(location).toContain('https://example.com/callback'); + expect(location).toContain('code=mock_auth_code'); + expect(location).toContain('state=s'); + }); +}); diff --git a/packages/server-hono/tsconfig.json b/packages/server-hono/tsconfig.json new file mode 100644 index 000000000..0d7fdd0c0 --- /dev/null +++ b/packages/server-hono/tsconfig.json @@ -0,0 +1,14 @@ +{ + "extends": "@modelcontextprotocol/tsconfig", + "include": ["./"], + "exclude": ["node_modules", "dist"], + "compilerOptions": { + "paths": { + "*": ["./*"], + "@modelcontextprotocol/server": ["./node_modules/@modelcontextprotocol/server/src/index.ts"], + "@modelcontextprotocol/core": [ + "./node_modules/@modelcontextprotocol/server/node_modules/@modelcontextprotocol/core/src/index.ts" + ] + } + } +} diff --git a/packages/server-hono/tsdown.config.ts b/packages/server-hono/tsdown.config.ts new file mode 100644 index 000000000..c72e7a2c4 --- /dev/null +++ b/packages/server-hono/tsdown.config.ts @@ -0,0 +1,23 @@ +import { defineConfig } from 'tsdown'; + +export default defineConfig({ + entry: ['src/index.ts'], + format: ['esm'], + outDir: 'dist', + clean: true, + sourcemap: true, + target: 'esnext', + platform: 'node', + shims: true, + dts: { + resolver: 'tsc', + compilerOptions: { + baseUrl: '.', + paths: { + '@modelcontextprotocol/server': ['../server/src/index.ts'], + '@modelcontextprotocol/core': ['../core/src/index.ts'] + } + } + }, + noExternal: ['@modelcontextprotocol/server', '@modelcontextprotocol/core'] +}); diff --git a/packages/server-hono/vitest.config.js b/packages/server-hono/vitest.config.js new file mode 100644 index 000000000..496fca320 --- /dev/null +++ b/packages/server-hono/vitest.config.js @@ -0,0 +1,3 @@ +import baseConfig from '@modelcontextprotocol/vitest-config'; + +export default baseConfig; diff --git a/packages/server/package.json b/packages/server/package.json index b039751f6..4f32cf171 100644 --- a/packages/server/package.json +++ b/packages/server/package.json @@ -45,11 +45,6 @@ }, "dependencies": { "content-type": "catalog:runtimeServerOnly", - "cors": "catalog:runtimeServerOnly", - "@hono/node-server": "catalog:runtimeServerOnly", - "hono": "catalog:runtimeServerOnly", - "express": "catalog:runtimeServerOnly", - "express-rate-limit": "catalog:runtimeServerOnly", "raw-body": "catalog:runtimeServerOnly", "pkce-challenge": "catalog:runtimeShared", "zod": "catalog:runtimeShared", diff --git a/packages/server/src/index.ts b/packages/server/src/index.ts index 8c9b9af5f..bbf3c2f59 100644 --- a/packages/server/src/index.ts +++ b/packages/server/src/index.ts @@ -1,7 +1,7 @@ export * from './server/completable.js'; -export * from './server/express.js'; export * from './server/inMemoryEventStore.js'; export * from './server/mcp.js'; +export * from './server/middleware/hostHeaderValidation.js'; export * from './server/server.js'; export * from './server/sse.js'; export * from './server/stdio.js'; diff --git a/packages/server/src/server/auth/handlers/authorize.ts b/packages/server/src/server/auth/handlers/authorize.ts index 65875529e..df2702f88 100644 --- a/packages/server/src/server/auth/handlers/authorize.ts +++ b/packages/server/src/server/auth/handlers/authorize.ts @@ -1,12 +1,9 @@ import { InvalidClientError, InvalidRequestError, OAuthError, ServerError, TooManyRequestsError } from '@modelcontextprotocol/core'; -import type { RequestHandler } from 'express'; -import express from 'express'; -import type { Options as RateLimitOptions } from 'express-rate-limit'; -import { rateLimit } from 'express-rate-limit'; import * as z from 'zod/v4'; -import { allowedMethods } from '../middleware/allowedMethods.js'; import type { OAuthServerProvider } from '../provider.js'; +import type { WebHandler } from '../web.js'; +import { getClientAddress, getParsedBody, InMemoryRateLimiter, jsonResponse, methodNotAllowedResponse, noStoreHeaders } from '../web.js'; export type AuthorizationHandlerOptions = { provider: OAuthServerProvider; @@ -14,7 +11,7 @@ export type AuthorizationHandlerOptions = { * Rate limiting configuration for the authorization endpoint. * Set to false to disable rate limiting for this endpoint. */ - rateLimit?: Partial | false; + rateLimit?: Partial<{ windowMs: number; max: number }> | false; }; // Parameters that must be validated in order to issue redirects. @@ -36,28 +33,44 @@ const RequestAuthorizationParamsSchema = z.object({ resource: z.string().url().optional() }); -export function authorizationHandler({ provider, rateLimit: rateLimitConfig }: AuthorizationHandlerOptions): RequestHandler { - // Create a router to apply middleware - const router = express.Router(); - router.use(allowedMethods(['GET', 'POST'])); - router.use(express.urlencoded({ extended: false })); - - // Apply rate limiting unless explicitly disabled - if (rateLimitConfig !== false) { - router.use( - rateLimit({ - windowMs: 15 * 60 * 1000, // 15 minutes - max: 100, // 100 requests per windowMs - standardHeaders: true, - legacyHeaders: false, - message: new TooManyRequestsError('You have exceeded the rate limit for authorization requests').toResponseObject(), - ...rateLimitConfig - }) - ); - } +export function authorizationHandler({ provider, rateLimit: rateLimitConfig }: AuthorizationHandlerOptions): WebHandler { + const limiter = + rateLimitConfig === false + ? undefined + : new InMemoryRateLimiter({ + windowMs: rateLimitConfig?.windowMs ?? 15 * 60 * 1000, + max: rateLimitConfig?.max ?? 100 + }); + + return async (req, ctx) => { + const noStore = noStoreHeaders(); + + // Rate limit by client address where possible (best-effort). + if (limiter) { + const key = `${getClientAddress(req, ctx) ?? 'global'}:authorize`; + const rl = limiter.consume(key); + if (!rl.allowed) { + return jsonResponse( + new TooManyRequestsError('You have exceeded the rate limit for authorization requests').toResponseObject(), + { + status: 429, + headers: { + ...noStore, + ...(rl.retryAfterSeconds ? { 'Retry-After': String(rl.retryAfterSeconds) } : {}) + } + } + ); + } + } - router.all('/', async (req, res) => { - res.setHeader('Cache-Control', 'no-store'); + if (req.method !== 'GET' && req.method !== 'POST') { + const resp = methodNotAllowedResponse(req, ['GET', 'POST']); + const body = await resp.text(); + return new Response(body, { + status: resp.status, + headers: { ...Object.fromEntries(resp.headers.entries()), ...noStore } + }); + } // In the authorization flow, errors are split into two categories: // 1. Pre-redirect errors (direct response with 400) @@ -66,7 +79,9 @@ export function authorizationHandler({ provider, rateLimit: rateLimitConfig }: A // Phase 1: Validate client_id and redirect_uri. Any errors here must be direct responses. let client_id, redirect_uri, client; try { - const result = ClientAuthorizationParamsSchema.safeParse(req.method === 'POST' ? req.body : req.query); + const source = + req.method === 'POST' ? await getParsedBody(req, ctx) : Object.fromEntries(new URL(req.url).searchParams.entries()); + const result = ClientAuthorizationParamsSchema.safeParse(source); if (!result.success) { throw new InvalidRequestError(result.error.message); } @@ -97,20 +112,20 @@ export function authorizationHandler({ provider, rateLimit: rateLimitConfig }: A // user anyway. if (error instanceof OAuthError) { const status = error instanceof ServerError ? 500 : 400; - res.status(status).json(error.toResponseObject()); + return jsonResponse(error.toResponseObject(), { status, headers: noStore }); } else { const serverError = new ServerError('Internal Server Error'); - res.status(500).json(serverError.toResponseObject()); + return jsonResponse(serverError.toResponseObject(), { status: 500, headers: noStore }); } - - return; } // Phase 2: Validate other parameters. Any errors here should go into redirect responses. let state; try { // Parse and validate authorization parameters - const parseResult = RequestAuthorizationParamsSchema.safeParse(req.method === 'POST' ? req.body : req.query); + const source = + req.method === 'POST' ? await getParsedBody(req, ctx) : Object.fromEntries(new URL(req.url).searchParams.entries()); + const parseResult = RequestAuthorizationParamsSchema.safeParse(source); if (!parseResult.success) { throw new InvalidRequestError(parseResult.error.message); } @@ -125,29 +140,28 @@ export function authorizationHandler({ provider, rateLimit: rateLimitConfig }: A } // All validation passed, proceed with authorization - await provider.authorize( - client, - { - state, - scopes: requestedScopes, - redirectUri: redirect_uri!, // TODO: Someone to look at. Strict tsconfig showed this could be undefined, while the return type is string. - codeChallenge: code_challenge, - resource: resource ? new URL(resource) : undefined - }, - res - ); + const providerResponse = await provider.authorize(client, { + state, + scopes: requestedScopes, + redirectUri: redirect_uri!, // TODO: Someone to look at. Strict tsconfig showed this could be undefined, while the return type is string. + codeChallenge: code_challenge, + resource: resource ? new URL(resource) : undefined + }); + const headers = new Headers(providerResponse.headers); + headers.set('Cache-Control', 'no-store'); + return new Response(providerResponse.body, { status: providerResponse.status, headers }); } catch (error) { // Post-redirect errors - redirect with error parameters if (error instanceof OAuthError) { - res.redirect(302, createErrorRedirect(redirect_uri!, error, state)); + const location = createErrorRedirect(redirect_uri!, error, state); + return new Response(null, { status: 302, headers: { Location: location, ...noStore } }); } else { const serverError = new ServerError('Internal Server Error'); - res.redirect(302, createErrorRedirect(redirect_uri!, serverError, state)); + const location = createErrorRedirect(redirect_uri!, serverError, state); + return new Response(null, { status: 302, headers: { Location: location, ...noStore } }); } } - }); - - return router; + }; } /** diff --git a/packages/server/src/server/auth/handlers/metadata.ts b/packages/server/src/server/auth/handlers/metadata.ts index 529a6e57a..fea42a8cb 100644 --- a/packages/server/src/server/auth/handlers/metadata.ts +++ b/packages/server/src/server/auth/handlers/metadata.ts @@ -1,21 +1,33 @@ import type { OAuthMetadata, OAuthProtectedResourceMetadata } from '@modelcontextprotocol/core'; -import cors from 'cors'; -import type { RequestHandler } from 'express'; -import express from 'express'; -import { allowedMethods } from '../middleware/allowedMethods.js'; +import type { WebHandler } from '../web.js'; +import { corsHeaders, corsPreflightResponse, jsonResponse, methodNotAllowedResponse } from '../web.js'; -export function metadataHandler(metadata: OAuthMetadata | OAuthProtectedResourceMetadata): RequestHandler { - // Nested router so we can configure middleware and restrict HTTP method - const router = express.Router(); +export function metadataHandler(metadata: OAuthMetadata | OAuthProtectedResourceMetadata): WebHandler { + const cors = { + allowOrigin: '*', + allowMethods: ['GET', 'OPTIONS'], + allowHeaders: ['Content-Type', 'Authorization'], + maxAgeSeconds: 60 * 60 * 24 + } as const; - // Configure CORS to allow any origin, to make accessible to web-based MCP clients - router.use(cors()); + return async req => { + if (req.method === 'OPTIONS') { + return corsPreflightResponse(cors); + } + if (req.method !== 'GET') { + const resp = methodNotAllowedResponse(req, ['GET', 'OPTIONS']); + // Add CORS headers for consistency with successful responses. + const body = await resp.text(); + return new Response(body, { + status: resp.status, + headers: { ...Object.fromEntries(resp.headers.entries()), ...corsHeaders(cors) } + }); + } - router.use(allowedMethods(['GET', 'OPTIONS'])); - router.get('/', (req, res) => { - res.status(200).json(metadata); - }); - - return router; + return jsonResponse(metadata, { + status: 200, + headers: corsHeaders(cors) + }); + }; } diff --git a/packages/server/src/server/auth/handlers/register.ts b/packages/server/src/server/auth/handlers/register.ts index a78154d48..fa54644c3 100644 --- a/packages/server/src/server/auth/handlers/register.ts +++ b/packages/server/src/server/auth/handlers/register.ts @@ -8,14 +8,19 @@ import { ServerError, TooManyRequestsError } from '@modelcontextprotocol/core'; -import cors from 'cors'; -import type { RequestHandler } from 'express'; -import express from 'express'; -import type { Options as RateLimitOptions } from 'express-rate-limit'; -import { rateLimit } from 'express-rate-limit'; import type { OAuthRegisteredClientsStore } from '../clients.js'; -import { allowedMethods } from '../middleware/allowedMethods.js'; +import type { WebHandler } from '../web.js'; +import { + corsHeaders, + corsPreflightResponse, + getClientAddress, + getParsedBody, + InMemoryRateLimiter, + jsonResponse, + methodNotAllowedResponse, + noStoreHeaders +} from '../web.js'; export type ClientRegistrationHandlerOptions = { /** @@ -35,7 +40,7 @@ export type ClientRegistrationHandlerOptions = { * Set to false to disable rate limiting for this endpoint. * Registration endpoints are particularly sensitive to abuse and should be rate limited. */ - rateLimit?: Partial | false; + rateLimit?: Partial<{ windowMs: number; max: number }> | false; /** * Whether to generate a client ID before calling the client registration endpoint. @@ -52,39 +57,61 @@ export function clientRegistrationHandler({ clientSecretExpirySeconds = DEFAULT_CLIENT_SECRET_EXPIRY_SECONDS, rateLimit: rateLimitConfig, clientIdGeneration = true -}: ClientRegistrationHandlerOptions): RequestHandler { +}: ClientRegistrationHandlerOptions): WebHandler { if (!clientsStore.registerClient) { throw new Error('Client registration store does not support registering clients'); } - // Nested router so we can configure middleware and restrict HTTP method - const router = express.Router(); - - // Configure CORS to allow any origin, to make accessible to web-based MCP clients - router.use(cors()); - - router.use(allowedMethods(['POST'])); - router.use(express.json()); - - // Apply rate limiting unless explicitly disabled - stricter limits for registration - if (rateLimitConfig !== false) { - router.use( - rateLimit({ - windowMs: 60 * 60 * 1000, // 1 hour - max: 20, // 20 requests per hour - stricter as registration is sensitive - standardHeaders: true, - legacyHeaders: false, - message: new TooManyRequestsError('You have exceeded the rate limit for client registration requests').toResponseObject(), - ...rateLimitConfig - }) - ); - } + const limiter = + rateLimitConfig === false + ? undefined + : new InMemoryRateLimiter({ + windowMs: rateLimitConfig?.windowMs ?? 60 * 60 * 1000, + max: rateLimitConfig?.max ?? 20 + }); + + const cors = { + allowOrigin: '*', + allowMethods: ['POST', 'OPTIONS'], + allowHeaders: ['Content-Type', 'Authorization'], + maxAgeSeconds: 60 * 60 * 24 + } as const; + + return async (req, ctx) => { + const baseHeaders = { ...corsHeaders(cors), ...noStoreHeaders() }; + + if (req.method === 'OPTIONS') { + return corsPreflightResponse(cors); + } + if (req.method !== 'POST') { + const resp = methodNotAllowedResponse(req, ['POST', 'OPTIONS']); + const body = await resp.text(); + return new Response(body, { + status: resp.status, + headers: { ...Object.fromEntries(resp.headers.entries()), ...baseHeaders } + }); + } - router.post('/', async (req, res) => { - res.setHeader('Cache-Control', 'no-store'); + if (limiter) { + const key = `${getClientAddress(req, ctx) ?? 'global'}:register`; + const rl = limiter.consume(key); + if (!rl.allowed) { + return jsonResponse( + new TooManyRequestsError('You have exceeded the rate limit for client registration requests').toResponseObject(), + { + status: 429, + headers: { + ...baseHeaders, + ...(rl.retryAfterSeconds ? { 'Retry-After': String(rl.retryAfterSeconds) } : {}) + } + } + ); + } + } try { - const parseResult = OAuthClientMetadataSchema.safeParse(req.body); + const rawBody = await getParsedBody(req, ctx); + const parseResult = OAuthClientMetadataSchema.safeParse(rawBody); if (!parseResult.success) { throw new InvalidClientMetadataError(parseResult.error.message); } @@ -113,17 +140,14 @@ export function clientRegistrationHandler({ } clientInfo = await clientsStore.registerClient!(clientInfo); - res.status(201).json(clientInfo); + return jsonResponse(clientInfo, { status: 201, headers: baseHeaders }); } catch (error) { if (error instanceof OAuthError) { const status = error instanceof ServerError ? 500 : 400; - res.status(status).json(error.toResponseObject()); - } else { - const serverError = new ServerError('Internal Server Error'); - res.status(500).json(serverError.toResponseObject()); + return jsonResponse(error.toResponseObject(), { status, headers: baseHeaders }); } + const serverError = new ServerError('Internal Server Error'); + return jsonResponse(serverError.toResponseObject(), { status: 500, headers: baseHeaders }); } - }); - - return router; + }; } diff --git a/packages/server/src/server/auth/handlers/revoke.ts b/packages/server/src/server/auth/handlers/revoke.ts index c7c9f8a6a..30c611cff 100644 --- a/packages/server/src/server/auth/handlers/revoke.ts +++ b/packages/server/src/server/auth/handlers/revoke.ts @@ -5,15 +5,20 @@ import { ServerError, TooManyRequestsError } from '@modelcontextprotocol/core'; -import cors from 'cors'; -import type { RequestHandler } from 'express'; -import express from 'express'; -import type { Options as RateLimitOptions } from 'express-rate-limit'; -import { rateLimit } from 'express-rate-limit'; -import { allowedMethods } from '../middleware/allowedMethods.js'; import { authenticateClient } from '../middleware/clientAuth.js'; import type { OAuthServerProvider } from '../provider.js'; +import type { WebHandler } from '../web.js'; +import { + corsHeaders, + corsPreflightResponse, + getClientAddress, + getParsedBody, + InMemoryRateLimiter, + jsonResponse, + methodNotAllowedResponse, + noStoreHeaders +} from '../web.js'; export type RevocationHandlerOptions = { provider: OAuthServerProvider; @@ -21,67 +26,79 @@ export type RevocationHandlerOptions = { * Rate limiting configuration for the token revocation endpoint. * Set to false to disable rate limiting for this endpoint. */ - rateLimit?: Partial | false; + rateLimit?: Partial<{ windowMs: number; max: number }> | false; }; -export function revocationHandler({ provider, rateLimit: rateLimitConfig }: RevocationHandlerOptions): RequestHandler { +export function revocationHandler({ provider, rateLimit: rateLimitConfig }: RevocationHandlerOptions): WebHandler { if (!provider.revokeToken) { throw new Error('Auth provider does not support revoking tokens'); } - // Nested router so we can configure middleware and restrict HTTP method - const router = express.Router(); - - // Configure CORS to allow any origin, to make accessible to web-based MCP clients - router.use(cors()); + const limiter = + rateLimitConfig === false + ? undefined + : new InMemoryRateLimiter({ + windowMs: rateLimitConfig?.windowMs ?? 15 * 60 * 1000, + max: rateLimitConfig?.max ?? 50 + }); - router.use(allowedMethods(['POST'])); - router.use(express.urlencoded({ extended: false })); + const cors = { + allowOrigin: '*', + allowMethods: ['POST', 'OPTIONS'], + allowHeaders: ['Content-Type', 'Authorization'], + maxAgeSeconds: 60 * 60 * 24 + } as const; - // Apply rate limiting unless explicitly disabled - if (rateLimitConfig !== false) { - router.use( - rateLimit({ - windowMs: 15 * 60 * 1000, // 15 minutes - max: 50, // 50 requests per windowMs - standardHeaders: true, - legacyHeaders: false, - message: new TooManyRequestsError('You have exceeded the rate limit for token revocation requests').toResponseObject(), - ...rateLimitConfig - }) - ); - } + return async (req, ctx) => { + const baseHeaders = { ...corsHeaders(cors), ...noStoreHeaders() }; - // Authenticate and extract client details - router.use(authenticateClient({ clientsStore: provider.clientsStore })); + if (req.method === 'OPTIONS') { + return corsPreflightResponse(cors); + } + if (req.method !== 'POST') { + const resp = methodNotAllowedResponse(req, ['POST', 'OPTIONS']); + const body = await resp.text(); + return new Response(body, { + status: resp.status, + headers: { ...Object.fromEntries(resp.headers.entries()), ...baseHeaders } + }); + } - router.post('/', async (req, res) => { - res.setHeader('Cache-Control', 'no-store'); + if (limiter) { + const key = `${getClientAddress(req, ctx) ?? 'global'}:revoke`; + const rl = limiter.consume(key); + if (!rl.allowed) { + return jsonResponse( + new TooManyRequestsError('You have exceeded the rate limit for token revocation requests').toResponseObject(), + { + status: 429, + headers: { + ...baseHeaders, + ...(rl.retryAfterSeconds ? { 'Retry-After': String(rl.retryAfterSeconds) } : {}) + } + } + ); + } + } try { - const parseResult = OAuthTokenRevocationRequestSchema.safeParse(req.body); + const rawBody = await getParsedBody(req, ctx); + const parseResult = OAuthTokenRevocationRequestSchema.safeParse(rawBody); if (!parseResult.success) { throw new InvalidRequestError(parseResult.error.message); } - const client = req.client; - if (!client) { - // This should never happen - throw new ServerError('Internal Server Error'); - } + const client = await authenticateClient(rawBody, { clientsStore: provider.clientsStore }); await provider.revokeToken!(client, parseResult.data); - res.status(200).json({}); + return jsonResponse({}, { status: 200, headers: baseHeaders }); } catch (error) { if (error instanceof OAuthError) { const status = error instanceof ServerError ? 500 : 400; - res.status(status).json(error.toResponseObject()); - } else { - const serverError = new ServerError('Internal Server Error'); - res.status(500).json(serverError.toResponseObject()); + return jsonResponse(error.toResponseObject(), { status, headers: baseHeaders }); } + const serverError = new ServerError('Internal Server Error'); + return jsonResponse(serverError.toResponseObject(), { status: 500, headers: baseHeaders }); } - }); - - return router; + }; } diff --git a/packages/server/src/server/auth/handlers/token.ts b/packages/server/src/server/auth/handlers/token.ts index 3b7941294..096a10ff3 100644 --- a/packages/server/src/server/auth/handlers/token.ts +++ b/packages/server/src/server/auth/handlers/token.ts @@ -6,17 +6,22 @@ import { TooManyRequestsError, UnsupportedGrantTypeError } from '@modelcontextprotocol/core'; -import cors from 'cors'; -import type { RequestHandler } from 'express'; -import express from 'express'; -import type { Options as RateLimitOptions } from 'express-rate-limit'; -import { rateLimit } from 'express-rate-limit'; import { verifyChallenge } from 'pkce-challenge'; import * as z from 'zod/v4'; -import { allowedMethods } from '../middleware/allowedMethods.js'; import { authenticateClient } from '../middleware/clientAuth.js'; import type { OAuthServerProvider } from '../provider.js'; +import type { WebHandler } from '../web.js'; +import { + corsHeaders, + corsPreflightResponse, + getClientAddress, + getParsedBody, + InMemoryRateLimiter, + jsonResponse, + methodNotAllowedResponse, + noStoreHeaders +} from '../web.js'; export type TokenHandlerOptions = { provider: OAuthServerProvider; @@ -24,7 +29,7 @@ export type TokenHandlerOptions = { * Rate limiting configuration for the token endpoint. * Set to false to disable rate limiting for this endpoint. */ - rateLimit?: Partial | false; + rateLimit?: Partial<{ windowMs: number; max: number }> | false; }; const TokenRequestSchema = z.object({ @@ -44,53 +49,65 @@ const RefreshTokenGrantSchema = z.object({ resource: z.string().url().optional() }); -export function tokenHandler({ provider, rateLimit: rateLimitConfig }: TokenHandlerOptions): RequestHandler { - // Nested router so we can configure middleware and restrict HTTP method - const router = express.Router(); - - // Configure CORS to allow any origin, to make accessible to web-based MCP clients - router.use(cors()); - - router.use(allowedMethods(['POST'])); - router.use(express.urlencoded({ extended: false })); - - // Apply rate limiting unless explicitly disabled - if (rateLimitConfig !== false) { - router.use( - rateLimit({ - windowMs: 15 * 60 * 1000, // 15 minutes - max: 50, // 50 requests per windowMs - standardHeaders: true, - legacyHeaders: false, - message: new TooManyRequestsError('You have exceeded the rate limit for token requests').toResponseObject(), - ...rateLimitConfig - }) - ); - } - - // Authenticate and extract client details - router.use(authenticateClient({ clientsStore: provider.clientsStore })); +export function tokenHandler({ provider, rateLimit: rateLimitConfig }: TokenHandlerOptions): WebHandler { + const limiter = + rateLimitConfig === false + ? undefined + : new InMemoryRateLimiter({ + windowMs: rateLimitConfig?.windowMs ?? 15 * 60 * 1000, + max: rateLimitConfig?.max ?? 50 + }); + + const cors = { + allowOrigin: '*', + allowMethods: ['POST', 'OPTIONS'], + allowHeaders: ['Content-Type', 'Authorization'], + maxAgeSeconds: 60 * 60 * 24 + } as const; + + return async (req, ctx) => { + const baseHeaders = { ...corsHeaders(cors), ...noStoreHeaders() }; + + if (req.method === 'OPTIONS') { + return corsPreflightResponse(cors); + } + if (req.method !== 'POST') { + const resp = methodNotAllowedResponse(req, ['POST', 'OPTIONS']); + const body = await resp.text(); + return new Response(body, { + status: resp.status, + headers: { ...Object.fromEntries(resp.headers.entries()), ...baseHeaders } + }); + } - router.post('/', async (req, res) => { - res.setHeader('Cache-Control', 'no-store'); + if (limiter) { + const key = `${getClientAddress(req, ctx) ?? 'global'}:token`; + const rl = limiter.consume(key); + if (!rl.allowed) { + return jsonResponse(new TooManyRequestsError('You have exceeded the rate limit for token requests').toResponseObject(), { + status: 429, + headers: { + ...baseHeaders, + ...(rl.retryAfterSeconds ? { 'Retry-After': String(rl.retryAfterSeconds) } : {}) + } + }); + } + } try { - const parseResult = TokenRequestSchema.safeParse(req.body); + const rawBody = await getParsedBody(req, ctx); + const parseResult = TokenRequestSchema.safeParse(rawBody); if (!parseResult.success) { throw new InvalidRequestError(parseResult.error.message); } const { grant_type } = parseResult.data; - const client = req.client; - if (!client) { - // This should never happen - throw new ServerError('Internal Server Error'); - } + const client = await authenticateClient(rawBody, { clientsStore: provider.clientsStore }); switch (grant_type) { case 'authorization_code': { - const parseResult = AuthorizationCodeGrantSchema.safeParse(req.body); + const parseResult = AuthorizationCodeGrantSchema.safeParse(rawBody); if (!parseResult.success) { throw new InvalidRequestError(parseResult.error.message); } @@ -116,12 +133,11 @@ export function tokenHandler({ provider, rateLimit: rateLimitConfig }: TokenHand redirect_uri, resource ? new URL(resource) : undefined ); - res.status(200).json(tokens); - break; + return jsonResponse(tokens, { status: 200, headers: baseHeaders }); } case 'refresh_token': { - const parseResult = RefreshTokenGrantSchema.safeParse(req.body); + const parseResult = RefreshTokenGrantSchema.safeParse(rawBody); if (!parseResult.success) { throw new InvalidRequestError(parseResult.error.message); } @@ -135,8 +151,7 @@ export function tokenHandler({ provider, rateLimit: rateLimitConfig }: TokenHand scopes, resource ? new URL(resource) : undefined ); - res.status(200).json(tokens); - break; + return jsonResponse(tokens, { status: 200, headers: baseHeaders }); } // Additional auth methods will not be added on the server side of the SDK. case 'client_credentials': @@ -146,13 +161,10 @@ export function tokenHandler({ provider, rateLimit: rateLimitConfig }: TokenHand } catch (error) { if (error instanceof OAuthError) { const status = error instanceof ServerError ? 500 : 400; - res.status(status).json(error.toResponseObject()); - } else { - const serverError = new ServerError('Internal Server Error'); - res.status(500).json(serverError.toResponseObject()); + return jsonResponse(error.toResponseObject(), { status, headers: baseHeaders }); } + const serverError = new ServerError('Internal Server Error'); + return jsonResponse(serverError.toResponseObject(), { status: 500, headers: baseHeaders }); } - }); - - return router; + }; } diff --git a/packages/server/src/server/auth/index.ts b/packages/server/src/server/auth/index.ts index 5369224cf..2b176805b 100644 --- a/packages/server/src/server/auth/index.ts +++ b/packages/server/src/server/auth/index.ts @@ -10,3 +10,4 @@ export * from './middleware/clientAuth.js'; export * from './provider.js'; export * from './providers/proxyProvider.js'; export * from './router.js'; +export * from './web.js'; diff --git a/packages/server/src/server/auth/middleware/allowedMethods.ts b/packages/server/src/server/auth/middleware/allowedMethods.ts index 72c076ec4..5c5245690 100644 --- a/packages/server/src/server/auth/middleware/allowedMethods.ts +++ b/packages/server/src/server/auth/middleware/allowedMethods.ts @@ -1,20 +1,20 @@ import { MethodNotAllowedError } from '@modelcontextprotocol/core'; -import type { RequestHandler } from 'express'; + +import { jsonResponse } from '../web.js'; /** - * Middleware to handle unsupported HTTP methods with a 405 Method Not Allowed response. + * Helper to handle unsupported HTTP methods with a 405 Method Not Allowed response. * * @param allowedMethods Array of allowed HTTP methods for this endpoint (e.g., ['GET', 'POST']) - * @returns Express middleware that returns a 405 error if method not in allowed list + * @returns Response if method not in allowed list, otherwise undefined */ -export function allowedMethods(allowedMethods: string[]): RequestHandler { - return (req, res, next) => { - if (allowedMethods.includes(req.method)) { - next(); - return; - } - - const error = new MethodNotAllowedError(`The method ${req.method} is not allowed for this endpoint`); - res.status(405).set('Allow', allowedMethods.join(', ')).json(error.toResponseObject()); - }; +export function allowedMethods(allowedMethods: string[], req: Request): Response | undefined { + if (allowedMethods.includes(req.method)) { + return undefined; + } + const error = new MethodNotAllowedError(`The method ${req.method} is not allowed for this endpoint`); + return jsonResponse(error.toResponseObject(), { + status: 405, + headers: { Allow: allowedMethods.join(', ') } + }); } diff --git a/packages/server/src/server/auth/middleware/bearerAuth.ts b/packages/server/src/server/auth/middleware/bearerAuth.ts index 1a16de1a9..853e400f6 100644 --- a/packages/server/src/server/auth/middleware/bearerAuth.ts +++ b/packages/server/src/server/auth/middleware/bearerAuth.ts @@ -1,8 +1,8 @@ import type { AuthInfo } from '@modelcontextprotocol/core'; import { InsufficientScopeError, InvalidTokenError, OAuthError, ServerError } from '@modelcontextprotocol/core'; -import type { RequestHandler } from 'express'; import type { OAuthTokenVerifier } from '../provider.js'; +import { jsonResponse } from '../web.js'; export type BearerAuthMiddlewareOptions = { /** @@ -21,83 +21,81 @@ export type BearerAuthMiddlewareOptions = { resourceMetadataUrl?: string; }; -declare module 'express-serve-static-core' { - interface Request { - /** - * Information about the validated access token, if the `requireBearerAuth` middleware was used. - */ - auth?: AuthInfo; - } -} - /** - * Middleware that requires a valid Bearer token in the Authorization header. + * Validates a Bearer token in the Authorization header. * - * This will validate the token with the auth provider and add the resulting auth info to the request object. - * - * If resourceMetadataUrl is provided, it will be included in the WWW-Authenticate header - * for 401 responses as per the OAuth 2.0 Protected Resource Metadata spec. + * Returns either `{ authInfo }` on success or `{ response }` on failure. */ -export function requireBearerAuth({ verifier, requiredScopes = [], resourceMetadataUrl }: BearerAuthMiddlewareOptions): RequestHandler { - return async (req, res, next) => { - try { - const authHeader = req.headers.authorization; - if (!authHeader) { - throw new InvalidTokenError('Missing Authorization header'); - } +export async function requireBearerAuth( + req: Request, + { verifier, requiredScopes = [], resourceMetadataUrl }: BearerAuthMiddlewareOptions +): Promise<{ authInfo: AuthInfo } | { response: Response }> { + try { + const authHeader = req.headers.get('authorization'); + if (!authHeader) { + throw new InvalidTokenError('Missing Authorization header'); + } - const [type, token] = authHeader.split(' '); - if (type!.toLowerCase() !== 'bearer' || !token) { - throw new InvalidTokenError("Invalid Authorization header format, expected 'Bearer TOKEN'"); - } + const [type, token] = authHeader.split(' '); + if (type!.toLowerCase() !== 'bearer' || !token) { + throw new InvalidTokenError("Invalid Authorization header format, expected 'Bearer TOKEN'"); + } - const authInfo = await verifier.verifyAccessToken(token); + const authInfo = await verifier.verifyAccessToken(token); - // Check if token has the required scopes (if any) - if (requiredScopes.length > 0) { - const hasAllScopes = requiredScopes.every(scope => authInfo.scopes.includes(scope)); + // Check if token has the required scopes (if any) + if (requiredScopes.length > 0) { + const hasAllScopes = requiredScopes.every(scope => authInfo.scopes.includes(scope)); - if (!hasAllScopes) { - throw new InsufficientScopeError('Insufficient scope'); - } + if (!hasAllScopes) { + throw new InsufficientScopeError('Insufficient scope'); } + } - // Check if the token is set to expire or if it is expired - if (typeof authInfo.expiresAt !== 'number' || isNaN(authInfo.expiresAt)) { - throw new InvalidTokenError('Token has no expiration time'); - } else if (authInfo.expiresAt < Date.now() / 1000) { - throw new InvalidTokenError('Token has expired'); + // Check if the token is set to expire or if it is expired + if (typeof authInfo.expiresAt !== 'number' || isNaN(authInfo.expiresAt)) { + throw new InvalidTokenError('Token has no expiration time'); + } else if (authInfo.expiresAt < Date.now() / 1000) { + throw new InvalidTokenError('Token has expired'); + } + + return { authInfo }; + } catch (error) { + // Build WWW-Authenticate header parts + const buildWwwAuthHeader = (errorCode: string, message: string): string => { + let header = `Bearer error="${errorCode}", error_description="${message}"`; + if (requiredScopes.length > 0) { + header += `, scope="${requiredScopes.join(' ')}"`; + } + if (resourceMetadataUrl) { + header += `, resource_metadata="${resourceMetadataUrl}"`; } + return header; + }; - req.auth = authInfo; - next(); - } catch (error) { - // Build WWW-Authenticate header parts - const buildWwwAuthHeader = (errorCode: string, message: string): string => { - let header = `Bearer error="${errorCode}", error_description="${message}"`; - if (requiredScopes.length > 0) { - header += `, scope="${requiredScopes.join(' ')}"`; - } - if (resourceMetadataUrl) { - header += `, resource_metadata="${resourceMetadataUrl}"`; - } - return header; + if (error instanceof InvalidTokenError) { + return { + response: jsonResponse(error.toResponseObject(), { + status: 401, + headers: { 'WWW-Authenticate': buildWwwAuthHeader(error.errorCode, error.message) } + }) }; - - if (error instanceof InvalidTokenError) { - res.set('WWW-Authenticate', buildWwwAuthHeader(error.errorCode, error.message)); - res.status(401).json(error.toResponseObject()); - } else if (error instanceof InsufficientScopeError) { - res.set('WWW-Authenticate', buildWwwAuthHeader(error.errorCode, error.message)); - res.status(403).json(error.toResponseObject()); - } else if (error instanceof ServerError) { - res.status(500).json(error.toResponseObject()); - } else if (error instanceof OAuthError) { - res.status(400).json(error.toResponseObject()); - } else { - const serverError = new ServerError('Internal Server Error'); - res.status(500).json(serverError.toResponseObject()); - } } - }; + if (error instanceof InsufficientScopeError) { + return { + response: jsonResponse(error.toResponseObject(), { + status: 403, + headers: { 'WWW-Authenticate': buildWwwAuthHeader(error.errorCode, error.message) } + }) + }; + } + if (error instanceof ServerError) { + return { response: jsonResponse(error.toResponseObject(), { status: 500 }) }; + } + if (error instanceof OAuthError) { + return { response: jsonResponse(error.toResponseObject(), { status: 400 }) }; + } + const serverError = new ServerError('Internal Server Error'); + return { response: jsonResponse(serverError.toResponseObject(), { status: 500 }) }; + } } diff --git a/packages/server/src/server/auth/middleware/clientAuth.ts b/packages/server/src/server/auth/middleware/clientAuth.ts index ac4bc8b79..9da271e35 100644 --- a/packages/server/src/server/auth/middleware/clientAuth.ts +++ b/packages/server/src/server/auth/middleware/clientAuth.ts @@ -1,6 +1,5 @@ import type { OAuthClientInformationFull } from '@modelcontextprotocol/core'; -import { InvalidClientError, InvalidRequestError, OAuthError, ServerError } from '@modelcontextprotocol/core'; -import type { RequestHandler } from 'express'; +import { InvalidClientError, InvalidRequestError } from '@modelcontextprotocol/core'; import * as z from 'zod/v4'; import type { OAuthRegisteredClientsStore } from '../clients.js'; @@ -17,49 +16,35 @@ const ClientAuthenticatedRequestSchema = z.object({ client_secret: z.string().optional() }); -declare module 'express-serve-static-core' { - interface Request { - /** - * The authenticated client for this request, if the `authenticateClient` middleware was used. - */ - client?: OAuthClientInformationFull; +/** + * Parses and validates client credentials from a request body, returning the authenticated client. + * + * Throws an OAuthError (or ServerError) on failure. + */ +export async function authenticateClient( + body: unknown, + { clientsStore }: ClientAuthenticationMiddlewareOptions +): Promise { + const result = ClientAuthenticatedRequestSchema.safeParse(body); + if (!result.success) { + throw new InvalidRequestError(String(result.error)); } -} - -export function authenticateClient({ clientsStore }: ClientAuthenticationMiddlewareOptions): RequestHandler { - return async (req, res, next) => { - try { - const result = ClientAuthenticatedRequestSchema.safeParse(req.body); - if (!result.success) { - throw new InvalidRequestError(String(result.error)); - } - const { client_id, client_secret } = result.data; - const client = await clientsStore.getClient(client_id); - if (!client) { - throw new InvalidClientError('Invalid client_id'); - } - if (client.client_secret) { - if (!client_secret) { - throw new InvalidClientError('Client secret is required'); - } - if (client.client_secret !== client_secret) { - throw new InvalidClientError('Invalid client_secret'); - } - if (client.client_secret_expires_at && client.client_secret_expires_at < Math.floor(Date.now() / 1000)) { - throw new InvalidClientError('Client secret has expired'); - } - } - - req.client = client; - next(); - } catch (error) { - if (error instanceof OAuthError) { - const status = error instanceof ServerError ? 500 : 400; - res.status(status).json(error.toResponseObject()); - } else { - const serverError = new ServerError('Internal Server Error'); - res.status(500).json(serverError.toResponseObject()); - } + const { client_id, client_secret } = result.data; + const client = await clientsStore.getClient(client_id); + if (!client) { + throw new InvalidClientError('Invalid client_id'); + } + if (client.client_secret) { + if (!client_secret) { + throw new InvalidClientError('Client secret is required'); } - }; + if (client.client_secret !== client_secret) { + throw new InvalidClientError('Invalid client_secret'); + } + if (client.client_secret_expires_at && client.client_secret_expires_at < Math.floor(Date.now() / 1000)) { + throw new InvalidClientError('Client secret has expired'); + } + } + + return client; } diff --git a/packages/server/src/server/auth/provider.ts b/packages/server/src/server/auth/provider.ts index 6d27fb792..d7dc395d1 100644 --- a/packages/server/src/server/auth/provider.ts +++ b/packages/server/src/server/auth/provider.ts @@ -1,5 +1,4 @@ import type { AuthInfo, OAuthClientInformationFull, OAuthTokenRevocationRequest, OAuthTokens } from '@modelcontextprotocol/core'; -import type { Response } from 'express'; import type { OAuthRegisteredClientsStore } from './clients.js'; @@ -27,7 +26,7 @@ export interface OAuthServerProvider { * - In the successful case, the redirect MUST include the `code` and `state` (if present) query parameters. * - In the error case, the redirect MUST include the `error` query parameter, and MAY include an optional `error_description` query parameter. */ - authorize(client: OAuthClientInformationFull, params: AuthorizationParams, res: Response): Promise; + authorize(client: OAuthClientInformationFull, params: AuthorizationParams): Promise; /** * Returns the `codeChallenge` that was used when the indicated authorization began. diff --git a/packages/server/src/server/auth/providers/proxyProvider.ts b/packages/server/src/server/auth/providers/proxyProvider.ts index 0688754c0..230e8766e 100644 --- a/packages/server/src/server/auth/providers/proxyProvider.ts +++ b/packages/server/src/server/auth/providers/proxyProvider.ts @@ -1,6 +1,5 @@ import type { AuthInfo, FetchLike, OAuthClientInformationFull, OAuthTokenRevocationRequest, OAuthTokens } from '@modelcontextprotocol/core'; import { OAuthClientInformationFullSchema, OAuthTokensSchema, ServerError } from '@modelcontextprotocol/core'; -import type { Response } from 'express'; import type { OAuthRegisteredClientsStore } from '../clients.js'; import type { AuthorizationParams, OAuthServerProvider } from '../provider.js'; @@ -112,7 +111,7 @@ export class ProxyOAuthServerProvider implements OAuthServerProvider { }; } - async authorize(client: OAuthClientInformationFull, params: AuthorizationParams, res: Response): Promise { + async authorize(client: OAuthClientInformationFull, params: AuthorizationParams): Promise { // Start with required OAuth parameters const targetUrl = new URL(this._endpoints.authorizationUrl); const searchParams = new URLSearchParams({ @@ -129,7 +128,7 @@ export class ProxyOAuthServerProvider implements OAuthServerProvider { if (params.resource) searchParams.set('resource', params.resource.href); targetUrl.search = searchParams.toString(); - res.redirect(targetUrl.toString()); + return Response.redirect(targetUrl.toString(), 302); } async challengeForAuthorizationCode(_client: OAuthClientInformationFull, _authorizationCode: string): Promise { diff --git a/packages/server/src/server/auth/router.ts b/packages/server/src/server/auth/router.ts index ba8b030e0..083657250 100644 --- a/packages/server/src/server/auth/router.ts +++ b/packages/server/src/server/auth/router.ts @@ -1,6 +1,4 @@ import type { OAuthMetadata, OAuthProtectedResourceMetadata } from '@modelcontextprotocol/core'; -import type { RequestHandler } from 'express'; -import express from 'express'; import type { AuthorizationHandlerOptions } from './handlers/authorize.js'; import { authorizationHandler } from './handlers/authorize.js'; @@ -12,6 +10,7 @@ import { revocationHandler } from './handlers/revoke.js'; import type { TokenHandlerOptions } from './handlers/token.js'; import { tokenHandler } from './handlers/token.js'; import type { OAuthServerProvider } from './provider.js'; +import type { WebHandler } from './web.js'; // Check for dev mode flag that allows HTTP issuer URLs (for development/testing only) const allowInsecureIssuerUrl = @@ -67,6 +66,24 @@ export type AuthRouterOptions = { tokenOptions?: Omit; }; +export type AuthRoute = { + path: string; + methods: string[]; + handler: WebHandler; +}; + +export type WebAuthRouter = { + /** + * List of concrete routes (absolute paths) that should be mounted at the application root. + */ + routes: AuthRoute[]; + + /** + * Convenience dispatcher that matches on `new URL(req.url).pathname` and calls the correct handler. + */ + handle: WebHandler; +}; + const checkIssuerUrl = (issuer: URL): void => { // Technically RFC 8414 does not permit a localhost HTTPS exemption, but this will be necessary for ease of testing if (issuer.protocol !== 'https:' && issuer.hostname !== 'localhost' && issuer.hostname !== '127.0.0.1' && !allowInsecureIssuerUrl) { @@ -126,53 +143,62 @@ export const createOAuthMetadata = (options: { * Note: if your MCP server is only a resource server and not an authorization server, use mcpAuthMetadataRouter instead. * * By default, rate limiting is applied to all endpoints to prevent abuse. - * - * This router MUST be installed at the application root, like so: - * - * const app = express(); - * app.use(mcpAuthRouter(...)); */ -export function mcpAuthRouter(options: AuthRouterOptions): RequestHandler { +export function mcpAuthRouter(options: AuthRouterOptions): WebAuthRouter { const oauthMetadata = createOAuthMetadata(options); - - const router = express.Router(); - - router.use( - new URL(oauthMetadata.authorization_endpoint).pathname, - authorizationHandler({ provider: options.provider, ...options.authorizationOptions }) - ); - - router.use(new URL(oauthMetadata.token_endpoint).pathname, tokenHandler({ provider: options.provider, ...options.tokenOptions })); - - router.use( - mcpAuthMetadataRouter({ - oauthMetadata, - // Prefer explicit RS; otherwise fall back to AS baseUrl, then to issuer (back-compat) - resourceServerUrl: options.resourceServerUrl ?? options.baseUrl ?? new URL(oauthMetadata.issuer), - serviceDocumentationUrl: options.serviceDocumentationUrl, - scopesSupported: options.scopesSupported, - resourceName: options.resourceName - }) - ); + const routes: AuthRoute[] = []; + + routes.push({ + path: new URL(oauthMetadata.authorization_endpoint).pathname, + methods: ['GET', 'POST'], + handler: authorizationHandler({ provider: options.provider, ...options.authorizationOptions }) + }); + + routes.push({ + path: new URL(oauthMetadata.token_endpoint).pathname, + methods: ['POST', 'OPTIONS'], + handler: tokenHandler({ provider: options.provider, ...options.tokenOptions }) + }); + + const metadataRouter = mcpAuthMetadataRouter({ + oauthMetadata, + // Prefer explicit RS; otherwise fall back to AS baseUrl, then to issuer (back-compat) + resourceServerUrl: options.resourceServerUrl ?? options.baseUrl ?? new URL(oauthMetadata.issuer), + serviceDocumentationUrl: options.serviceDocumentationUrl, + scopesSupported: options.scopesSupported, + resourceName: options.resourceName + }); + routes.push(...metadataRouter.routes); if (oauthMetadata.registration_endpoint) { - router.use( - new URL(oauthMetadata.registration_endpoint).pathname, - clientRegistrationHandler({ + routes.push({ + path: new URL(oauthMetadata.registration_endpoint).pathname, + methods: ['POST', 'OPTIONS'], + handler: clientRegistrationHandler({ clientsStore: options.provider.clientsStore, ...options.clientRegistrationOptions }) - ); + }); } if (oauthMetadata.revocation_endpoint) { - router.use( - new URL(oauthMetadata.revocation_endpoint).pathname, - revocationHandler({ provider: options.provider, ...options.revocationOptions }) - ); + routes.push({ + path: new URL(oauthMetadata.revocation_endpoint).pathname, + methods: ['POST', 'OPTIONS'], + handler: revocationHandler({ provider: options.provider, ...options.revocationOptions }) + }); } - return router; + const handle: WebHandler = async (req, ctx) => { + const pathname = new URL(req.url).pathname; + const route = routes.find(r => r.path === pathname); + if (!route) { + return new Response('Not Found', { status: 404 }); + } + return route.handler(req, ctx); + }; + + return { routes, handle }; } export type AuthMetadataOptions = { @@ -203,11 +229,9 @@ export type AuthMetadataOptions = { resourceName?: string; }; -export function mcpAuthMetadataRouter(options: AuthMetadataOptions): express.Router { +export function mcpAuthMetadataRouter(options: AuthMetadataOptions): WebAuthRouter { checkIssuerUrl(new URL(options.oauthMetadata.issuer)); - const router = express.Router(); - const protectedResourceMetadata: OAuthProtectedResourceMetadata = { resource: options.resourceServerUrl.href, @@ -220,12 +244,24 @@ export function mcpAuthMetadataRouter(options: AuthMetadataOptions): express.Rou // Serve PRM at the path-specific URL per RFC 9728 const rsPath = new URL(options.resourceServerUrl.href).pathname; - router.use(`/.well-known/oauth-protected-resource${rsPath === '/' ? '' : rsPath}`, metadataHandler(protectedResourceMetadata)); - - // Always add this for OAuth Authorization Server metadata per RFC 8414 - router.use('/.well-known/oauth-authorization-server', metadataHandler(options.oauthMetadata)); + const prmPath = `/.well-known/oauth-protected-resource${rsPath === '/' ? '' : rsPath}`; + + const routes: AuthRoute[] = [ + { path: prmPath, methods: ['GET', 'OPTIONS'], handler: metadataHandler(protectedResourceMetadata) }, + // Always add this for OAuth Authorization Server metadata per RFC 8414 + { path: '/.well-known/oauth-authorization-server', methods: ['GET', 'OPTIONS'], handler: metadataHandler(options.oauthMetadata) } + ]; + + const handle: WebHandler = async (req, ctx) => { + const pathname = new URL(req.url).pathname; + const route = routes.find(r => r.path === pathname); + if (!route) { + return new Response('Not Found', { status: 404 }); + } + return route.handler(req, ctx); + }; - return router; + return { routes, handle }; } /** diff --git a/packages/server/src/server/auth/web.ts b/packages/server/src/server/auth/web.ts new file mode 100644 index 000000000..a16ada4ef --- /dev/null +++ b/packages/server/src/server/auth/web.ts @@ -0,0 +1,140 @@ +import { MethodNotAllowedError } from '@modelcontextprotocol/core'; + +export type HeaderMap = Record; + +export type WebHandlerContext = { + /** + * Optional pre-parsed request body from an upstream framework. + * If provided, handlers will use this instead of reading from the Request stream. + */ + parsedBody?: unknown; + + /** + * Optional client address for rate limiting (e.g., IP). + */ + clientAddress?: string; +}; + +export type WebHandler = (req: Request, ctx?: WebHandlerContext) => Promise; + +export function jsonResponse(body: unknown, init?: { status?: number; headers?: HeaderMap }): Response { + const headers: HeaderMap = { 'Content-Type': 'application/json' }; + if (init?.headers) { + Object.assign(headers, init.headers); + } + return new Response(JSON.stringify(body), { + status: init?.status ?? 200, + headers + }); +} + +export function noStoreHeaders(): HeaderMap { + return { 'Cache-Control': 'no-store' }; +} + +export function getClientAddress(req: Request, ctx?: WebHandlerContext): string | undefined { + if (ctx?.clientAddress) return ctx.clientAddress; + const xff = req.headers.get('x-forwarded-for'); + if (xff) return xff.split(',')[0]?.trim(); + return undefined; +} + +export async function getParsedBody(req: Request, ctx?: WebHandlerContext): Promise { + if (ctx?.parsedBody !== undefined) { + return ctx.parsedBody; + } + + const ct = req.headers.get('content-type') ?? ''; + + if (ct.includes('application/json')) { + return await req.json(); + } + + if (ct.includes('application/x-www-form-urlencoded')) { + const text = await req.text(); + return objectFromUrlEncoded(text); + } + + // Empty bodies are treated as empty objects. + const text = await req.text(); + if (!text) return {}; + + // If content-type is missing/unknown, fall back to treating it as urlencoded-like. + return objectFromUrlEncoded(text); +} + +export function objectFromUrlEncoded(body: string): Record { + const params = new URLSearchParams(body); + const out: Record = {}; + for (const [k, v] of params.entries()) out[k] = v; + return out; +} + +export function methodNotAllowedResponse(req: Request, allowed: string[]): Response { + const error = new MethodNotAllowedError(`The method ${req.method} is not allowed for this endpoint`); + return jsonResponse(error.toResponseObject(), { + status: 405, + headers: { Allow: allowed.join(', ') } + }); +} + +export type CorsOptions = { + allowOrigin?: string; + allowMethods: readonly string[]; + allowHeaders?: readonly string[]; + exposeHeaders?: readonly string[]; + maxAgeSeconds?: number; +}; + +export function corsHeaders(options: CorsOptions): HeaderMap { + return { + 'Access-Control-Allow-Origin': options.allowOrigin ?? '*', + 'Access-Control-Allow-Methods': options.allowMethods.join(', '), + 'Access-Control-Allow-Headers': (options.allowHeaders ?? ['Content-Type', 'Authorization']).join(', '), + ...(options.exposeHeaders ? { 'Access-Control-Expose-Headers': options.exposeHeaders.join(', ') } : {}), + ...(options.maxAgeSeconds !== undefined ? { 'Access-Control-Max-Age': String(options.maxAgeSeconds) } : {}) + }; +} + +export function corsPreflightResponse(options: CorsOptions): Response { + return new Response(null, { + status: 204, + headers: corsHeaders(options) + }); +} + +export type InMemoryRateLimitConfig = { + windowMs: number; + max: number; +}; + +type RateState = { windowStart: number; count: number }; + +/** + * Minimal in-memory rate limiter for single-process deployments. + * Not suitable for distributed setups without an external store. + */ +export class InMemoryRateLimiter { + private _state = new Map(); + + constructor(private _config: InMemoryRateLimitConfig) {} + + consume(key: string): { allowed: boolean; retryAfterSeconds?: number } { + const now = Date.now(); + const windowStart = now - (now % this._config.windowMs); + const existing = this._state.get(key); + + if (!existing || existing.windowStart !== windowStart) { + this._state.set(key, { windowStart, count: 1 }); + return { allowed: true }; + } + + if (existing.count >= this._config.max) { + const retryAfterMs = windowStart + this._config.windowMs - now; + return { allowed: false, retryAfterSeconds: Math.max(1, Math.ceil(retryAfterMs / 1000)) }; + } + + existing.count += 1; + return { allowed: true }; + } +} diff --git a/packages/server/src/server/middleware/hostHeaderValidation.ts b/packages/server/src/server/middleware/hostHeaderValidation.ts index f46178db3..e4d13ecf5 100644 --- a/packages/server/src/server/middleware/hostHeaderValidation.ts +++ b/packages/server/src/server/middleware/hostHeaderValidation.ts @@ -1,79 +1,67 @@ -import type { NextFunction, Request, RequestHandler, Response } from 'express'; +export type HostHeaderValidationResult = + | { ok: true; hostname: string } + | { + ok: false; + errorCode: 'missing_host' | 'invalid_host_header' | 'invalid_host'; + message: string; + hostHeader?: string; + hostname?: string; + }; /** - * Express middleware for DNS rebinding protection. - * Validates Host header hostname (port-agnostic) against an allowed list. + * Parse and validate a Host header against an allowlist of hostnames (port-agnostic). * - * This is particularly important for servers without authorization or HTTPS, - * such as localhost servers or development servers. DNS rebinding attacks can - * bypass same-origin policy by manipulating DNS to point a domain to a - * localhost address, allowing malicious websites to access your local server. - * - * @param allowedHostnames - List of allowed hostnames (without ports). - * For IPv6, provide the address with brackets (e.g., '[::1]'). - * @returns Express middleware function - * - * @example - * ```typescript - * const middleware = hostHeaderValidation(['localhost', '127.0.0.1', '[::1]']); - * app.use(middleware); - * ``` + * - Input host header may include a port (e.g. `localhost:3000`) or IPv6 brackets (e.g. `[::1]:3000`). + * - Allowlist items should be hostnames only (no ports). For IPv6, include brackets (e.g. `[::1]`). */ -export function hostHeaderValidation(allowedHostnames: string[]): RequestHandler { - return (req: Request, res: Response, next: NextFunction) => { - const hostHeader = req.headers.host; - if (!hostHeader) { - res.status(403).json({ - jsonrpc: '2.0', - error: { - code: -32000, - message: 'Missing Host header' - }, - id: null - }); - return; - } +export function validateHostHeader(hostHeader: string | null | undefined, allowedHostnames: string[]): HostHeaderValidationResult { + if (!hostHeader) { + return { ok: false, errorCode: 'missing_host', message: 'Missing Host header' }; + } - // Use URL API to parse hostname (handles IPv4, IPv6, and regular hostnames) - let hostname: string; - try { - hostname = new URL(`http://${hostHeader}`).hostname; - } catch { - res.status(403).json({ - jsonrpc: '2.0', - error: { - code: -32000, - message: `Invalid Host header: ${hostHeader}` - }, - id: null - }); - return; - } + // Use URL API to parse hostname (handles IPv4, IPv6, and regular hostnames) + let hostname: string; + try { + hostname = new URL(`http://${hostHeader}`).hostname; + } catch { + return { ok: false, errorCode: 'invalid_host_header', message: `Invalid Host header: ${hostHeader}`, hostHeader }; + } - if (!allowedHostnames.includes(hostname)) { - res.status(403).json({ - jsonrpc: '2.0', - error: { - code: -32000, - message: `Invalid Host: ${hostname}` - }, - id: null - }); - return; - } - next(); - }; + if (!allowedHostnames.includes(hostname)) { + return { ok: false, errorCode: 'invalid_host', message: `Invalid Host: ${hostname}`, hostHeader, hostname }; + } + + return { ok: true, hostname }; } /** - * Convenience middleware for localhost DNS rebinding protection. - * Allows only localhost, 127.0.0.1, and [::1] (IPv6 localhost) hostnames. - * + * Convenience allowlist for localhost DNS rebinding protection. + */ +export function localhostAllowedHostnames(): string[] { + return ['localhost', '127.0.0.1', '[::1]']; +} + +/** + * Web-standard Request helper for DNS rebinding protection. * @example - * ```typescript - * app.use(localhostHostValidation()); - * ``` + * const result = validateHostHeader(req.headers.get('host'), ['localhost']) */ -export function localhostHostValidation(): RequestHandler { - return hostHeaderValidation(['localhost', '127.0.0.1', '[::1]']); +export function hostHeaderValidationResponse(req: Request, allowedHostnames: string[]): Response | undefined { + const result = validateHostHeader(req.headers.get('host'), allowedHostnames); + if (result.ok) return undefined; + + return new Response( + JSON.stringify({ + jsonrpc: '2.0', + error: { + code: -32000, + message: result.message + }, + id: null + }), + { + status: 403, + headers: { 'Content-Type': 'application/json' } + } + ); } diff --git a/packages/server/src/server/sse.ts b/packages/server/src/server/sse.ts index 4fd0fa1d6..44117d0dd 100644 --- a/packages/server/src/server/sse.ts +++ b/packages/server/src/server/sse.ts @@ -16,24 +16,24 @@ export interface SSEServerTransportOptions { /** * List of allowed host header values for DNS rebinding protection. * If not specified, host validation is disabled. - * @deprecated Use the `hostHeaderValidation` middleware from `@modelcontextprotocol/sdk/server/middleware/hostHeaderValidation.js` instead, - * or use `createMcpExpressApp` from `@modelcontextprotocol/sdk/server/express.js` which includes localhost protection by default. + * @deprecated Use the `hostHeaderValidationResponse` helper from `@modelcontextprotocol/server`, + * or use `createMcpExpressApp` from `@modelcontextprotocol/server-express` which includes localhost protection by default. */ allowedHosts?: string[]; /** * List of allowed origin header values for DNS rebinding protection. * If not specified, origin validation is disabled. - * @deprecated Use the `hostHeaderValidation` middleware from `@modelcontextprotocol/sdk/server/middleware/hostHeaderValidation.js` instead, - * or use `createMcpExpressApp` from `@modelcontextprotocol/sdk/server/express.js` which includes localhost protection by default. + * @deprecated Use the `hostHeaderValidationResponse` helper from `@modelcontextprotocol/server`, + * or use `createMcpExpressApp` from `@modelcontextprotocol/server-express` which includes localhost protection by default. */ allowedOrigins?: string[]; /** * Enable DNS rebinding protection (requires allowedHosts and/or allowedOrigins to be configured). * Default is false for backwards compatibility. - * @deprecated Use the `hostHeaderValidation` middleware from `@modelcontextprotocol/sdk/server/middleware/hostHeaderValidation.js` instead, - * or use `createMcpExpressApp` from `@modelcontextprotocol/sdk/server/express.js` which includes localhost protection by default. + * @deprecated Use the `hostHeaderValidationResponse` helper from `@modelcontextprotocol/server`, + * or use `createMcpExpressApp` from `@modelcontextprotocol/server-express` which includes localhost protection by default. */ enableDnsRebindingProtection?: boolean; } diff --git a/packages/server/src/server/streamableHttp.ts b/packages/server/src/server/streamableHttp.ts index f9ee07ca8..354f640f9 100644 --- a/packages/server/src/server/streamableHttp.ts +++ b/packages/server/src/server/streamableHttp.ts @@ -8,8 +8,9 @@ */ import type { IncomingMessage, ServerResponse } from 'node:http'; +import { Readable } from 'node:stream'; +import { URL } from 'node:url'; -import { getRequestListener } from '@hono/node-server'; import type { AuthInfo, JSONRPCMessage, MessageExtraInfo, RequestId, Transport } from '@modelcontextprotocol/core'; import type { WebStandardStreamableHTTPServerTransportOptions } from './webStandardStreamableHttp.js'; @@ -22,12 +23,117 @@ import { WebStandardStreamableHTTPServerTransport } from './webStandardStreamabl */ export type StreamableHTTPServerTransportOptions = WebStandardStreamableHTTPServerTransportOptions; +type NodeToWebRequestOptions = { + parsedBody?: unknown; +}; + +function getRequestUrl(req: IncomingMessage): URL { + const host = req.headers.host ?? 'localhost'; + const isTls = Boolean((req.socket as { encrypted?: boolean } | undefined)?.encrypted); + const protocol = isTls ? 'https' : 'http'; + const path = req.url ?? '/'; + return new URL(path, `${protocol}://${host}`); +} + +function toHeaders(req: IncomingMessage): Headers { + const headers = new Headers(); + for (const [key, value] of Object.entries(req.headers)) { + if (value === undefined) continue; + if (Array.isArray(value)) { + // Preserve multi-value headers as a comma-joined value. + // (Set-Cookie does not appear on requests; this is fine here.) + headers.set(key, value.join(', ')); + } else { + headers.set(key, value); + } + } + return headers; +} + +async function readBody(req: IncomingMessage): Promise { + const chunks: Buffer[] = []; + for await (const chunk of req) { + chunks.push(Buffer.isBuffer(chunk) ? chunk : Buffer.from(chunk)); + } + return Buffer.concat(chunks); +} + +async function nodeToWebRequest(req: IncomingMessage, options?: NodeToWebRequestOptions): Promise { + const url = getRequestUrl(req); + const method = req.method ?? 'GET'; + const headers = toHeaders(req); + + // If an upstream framework already parsed the body, the IncomingMessage stream + // may be consumed; rely on parsedBody instead of trying to read again. + if (options?.parsedBody !== undefined) { + return new Request(url, { method, headers }); + } + + // Only attach bodies for methods that can carry one. + if (method === 'GET' || method === 'HEAD') { + return new Request(url, { method, headers }); + } + + const body = await readBody(req); + return new Request(url, { method, headers, body }); +} + +function writeWebResponse(res: ServerResponse, webResponse: Response): Promise { + res.statusCode = webResponse.status; + + // Prefer undici's multi Set-Cookie support when available. + // Note: must call with the correct `this` (undici brand-checks Headers). + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const getSetCookie = (webResponse.headers as any).getSetCookie as (() => string[]) | undefined; + const setCookies = typeof getSetCookie === 'function' ? getSetCookie.call(webResponse.headers) : undefined; + + for (const [key, value] of webResponse.headers.entries()) { + // We'll handle Set-Cookie separately if we have structured values. + if (key.toLowerCase() === 'set-cookie' && setCookies?.length) continue; + res.setHeader(key, value); + } + + if (setCookies?.length) { + res.setHeader('set-cookie', setCookies); + } + + // Node requires writing headers before streaming body. + res.flushHeaders?.(); + + if (!webResponse.body) { + res.end(); + return Promise.resolve(); + } + + return new Promise((resolve, reject) => { + const readable = Readable.fromWeb(webResponse.body as unknown as ReadableStream); + readable.on('error', err => { + try { + res.destroy(err as Error); + } catch { + // ignore + } + reject(err); + }); + res.on('error', reject); + res.on('close', () => { + try { + readable.destroy(); + } catch { + // ignore + } + }); + readable.pipe(res); + res.on('finish', () => resolve()); + }); +} + /** * Server transport for Streamable HTTP: this implements the MCP Streamable HTTP transport specification. * It supports both SSE streaming and direct HTTP responses. * * This is a wrapper around `WebStandardStreamableHTTPServerTransport` that provides Node.js HTTP compatibility. - * It uses the `@hono/node-server` library to convert between Node.js HTTP and Web Standard APIs. + * It converts between Node.js HTTP (IncomingMessage/ServerResponse) and Web Standard Request/Response. * * Usage example: * @@ -61,23 +167,9 @@ export type StreamableHTTPServerTransportOptions = WebStandardStreamableHTTPServ */ export class StreamableHTTPServerTransport implements Transport { private _webStandardTransport: WebStandardStreamableHTTPServerTransport; - private _requestListener: ReturnType; - // Store auth and parsedBody per request for passing through to handleRequest - private _requestContext: WeakMap = new WeakMap(); constructor(options: StreamableHTTPServerTransportOptions = {}) { this._webStandardTransport = new WebStandardStreamableHTTPServerTransport(options); - - // Create a request listener that wraps the web standard transport - // getRequestListener converts Node.js HTTP to Web Standard and properly handles SSE streaming - this._requestListener = getRequestListener(async (webRequest: Request) => { - // Get context if available (set during handleRequest) - const context = this._requestContext.get(webRequest); - return this._webStandardTransport.handleRequest(webRequest, { - authInfo: context?.authInfo, - parsedBody: context?.parsedBody - }); - }); } /** @@ -153,21 +245,13 @@ export class StreamableHTTPServerTransport implements Transport { * @param parsedBody - Optional pre-parsed body from body-parser middleware */ async handleRequest(req: IncomingMessage & { auth?: AuthInfo }, res: ServerResponse, parsedBody?: unknown): Promise { - // Store context for this request to pass through auth and parsedBody - // We need to intercept the request creation to attach this context const authInfo = req.auth; - - // Create a custom handler that includes our context - const handler = getRequestListener(async (webRequest: Request) => { - return this._webStandardTransport.handleRequest(webRequest, { - authInfo, - parsedBody - }); + const webRequest = await nodeToWebRequest(req, { parsedBody }); + const webResponse = await this._webStandardTransport.handleRequest(webRequest, { + authInfo, + parsedBody }); - - // Delegate to the request listener which handles all the Node.js <-> Web Standard conversion - // including proper SSE streaming support - await handler(req, res); + await writeWebResponse(res, webResponse); } /** diff --git a/packages/server/test/server/auth/handlers/authorize.test.ts b/packages/server/test/server/auth/handlers/authorize.test.ts index b84de3bc3..e5e65d72c 100644 --- a/packages/server/test/server/auth/handlers/authorize.test.ts +++ b/packages/server/test/server/auth/handlers/authorize.test.ts @@ -1,60 +1,41 @@ -import type { AuthInfo, OAuthClientInformationFull, OAuthTokens } from '@modelcontextprotocol/core'; -import { InvalidTokenError } from '@modelcontextprotocol/core'; -import type { Response } from 'express'; -import express from 'express'; -import supertest from 'supertest'; +import type { OAuthClientInformationFull, OAuthTokens } from '@modelcontextprotocol/core'; import type { OAuthRegisteredClientsStore } from '../../../../src/server/auth/clients.js'; -import type { AuthorizationHandlerOptions } from '../../../../src/server/auth/handlers/authorize.js'; import { authorizationHandler } from '../../../../src/server/auth/handlers/authorize.js'; import type { AuthorizationParams, OAuthServerProvider } from '../../../../src/server/auth/provider.js'; -describe('Authorization Handler', () => { - // Mock client data +describe('authorizationHandler (web)', () => { const validClient: OAuthClientInformationFull = { client_id: 'valid-client', client_secret: 'valid-secret', - redirect_uris: ['https://example.com/callback'], - scope: 'profile email' + redirect_uris: ['https://example.com/callback'] }; const multiRedirectClient: OAuthClientInformationFull = { client_id: 'multi-redirect-client', client_secret: 'valid-secret', - redirect_uris: ['https://example.com/callback1', 'https://example.com/callback2'], - scope: 'profile email' + redirect_uris: ['https://example.com/callback1', 'https://example.com/callback2'] }; - // Mock client store - const mockClientStore: OAuthRegisteredClientsStore = { - async getClient(clientId: string): Promise { - if (clientId === 'valid-client') { - return validClient; - } else if (clientId === 'multi-redirect-client') { - return multiRedirectClient; - } + const clientsStore: OAuthRegisteredClientsStore = { + async getClient(clientId: string) { + if (clientId === 'valid-client') return validClient; + if (clientId === 'multi-redirect-client') return multiRedirectClient; return undefined; } }; - // Mock provider - const mockProvider: OAuthServerProvider = { - clientsStore: mockClientStore, - - async authorize(client: OAuthClientInformationFull, params: AuthorizationParams, res: Response): Promise { - // Mock implementation - redirects to redirectUri with code and state - const redirectUrl = new URL(params.redirectUri); - redirectUrl.searchParams.set('code', 'mock_auth_code'); - if (params.state) { - redirectUrl.searchParams.set('state', params.state); - } - res.redirect(302, redirectUrl.toString()); + const provider: OAuthServerProvider = { + clientsStore, + async authorize(_client, params: AuthorizationParams): Promise { + const u = new URL(params.redirectUri); + u.searchParams.set('code', 'mock_auth_code'); + if (params.state) u.searchParams.set('state', params.state); + return Response.redirect(u.toString(), 302); }, - async challengeForAuthorizationCode(): Promise { return 'mock_challenge'; }, - async exchangeAuthorizationCode(): Promise { return { access_token: 'mock_access_token', @@ -63,7 +44,6 @@ describe('Authorization Handler', () => { refresh_token: 'mock_refresh_token' }; }, - async exchangeRefreshToken(): Promise { return { access_token: 'new_mock_access_token', @@ -72,225 +52,53 @@ describe('Authorization Handler', () => { refresh_token: 'new_mock_refresh_token' }; }, - - async verifyAccessToken(token: string): Promise { - if (token === 'valid_token') { - return { - token, - clientId: 'valid-client', - scopes: ['read', 'write'], - expiresAt: Date.now() / 1000 + 3600 - }; - } - throw new InvalidTokenError('Token is invalid or expired'); - }, - - async revokeToken(): Promise { - // Do nothing in mock + async verifyAccessToken() { + throw new Error('not used'); } }; - // Setup express app with handler - let app: express.Express; - let options: AuthorizationHandlerOptions; - - beforeEach(() => { - app = express(); - options = { provider: mockProvider }; - const handler = authorizationHandler(options); - app.use('/authorize', handler); + it('returns 405 for unsupported methods', async () => { + const handler = authorizationHandler({ provider }); + const res = await handler(new Request('http://localhost/authorize', { method: 'PUT' })); + expect(res.status).toBe(405); }); - describe('HTTP method validation', () => { - it('rejects non-GET/POST methods', async () => { - const response = await supertest(app).put('/authorize').query({ client_id: 'valid-client' }); - - expect(response.status).toBe(405); // Method not allowed response from handler - }); + it('returns 400 if client does not exist', async () => { + const handler = authorizationHandler({ provider }); + const res = await handler( + new Request( + 'http://localhost/authorize?client_id=missing&response_type=code&code_challenge=x&code_challenge_method=S256&redirect_uri=https%3A%2F%2Fexample.com%2Fcallback', + { method: 'GET' } + ) + ); + expect(res.status).toBe(400); + expect(await res.json()).toEqual(expect.objectContaining({ error: 'invalid_client' })); }); - describe('Client validation', () => { - it('requires client_id parameter', async () => { - const response = await supertest(app).get('/authorize'); - - expect(response.status).toBe(400); - expect(response.text).toContain('client_id'); - }); - - it('validates that client exists', async () => { - const response = await supertest(app).get('/authorize').query({ client_id: 'nonexistent-client' }); - - expect(response.status).toBe(400); - }); + it('redirects with a code on valid request (single redirect_uri inferred)', async () => { + const handler = authorizationHandler({ provider, rateLimit: false }); + const res = await handler( + new Request( + 'http://localhost/authorize?client_id=valid-client&response_type=code&code_challenge=challenge123&code_challenge_method=S256', + { method: 'GET' } + ) + ); + expect(res.status).toBe(302); + const location = res.headers.get('location')!; + expect(location).toContain('https://example.com/callback'); + expect(location).toContain('code=mock_auth_code'); + expect(res.headers.get('cache-control')).toBe('no-store'); }); - describe('Redirect URI validation', () => { - it('uses the only redirect_uri if client has just one and none provided', async () => { - const response = await supertest(app).get('/authorize').query({ - client_id: 'valid-client', - response_type: 'code', - code_challenge: 'challenge123', - code_challenge_method: 'S256' - }); - - expect(response.status).toBe(302); - const location = new URL(response.header.location!); - expect(location.origin + location.pathname).toBe('https://example.com/callback'); - }); - - it('requires redirect_uri if client has multiple', async () => { - const response = await supertest(app).get('/authorize').query({ - client_id: 'multi-redirect-client', - response_type: 'code', - code_challenge: 'challenge123', - code_challenge_method: 'S256' - }); - - expect(response.status).toBe(400); - }); - - it('validates redirect_uri against client registered URIs', async () => { - const response = await supertest(app).get('/authorize').query({ - client_id: 'valid-client', - redirect_uri: 'https://malicious.com/callback', - response_type: 'code', - code_challenge: 'challenge123', - code_challenge_method: 'S256' - }); - - expect(response.status).toBe(400); - }); - - it('accepts valid redirect_uri that client registered with', async () => { - const response = await supertest(app).get('/authorize').query({ - client_id: 'valid-client', - redirect_uri: 'https://example.com/callback', - response_type: 'code', - code_challenge: 'challenge123', - code_challenge_method: 'S256' - }); - - expect(response.status).toBe(302); - const location = new URL(response.header.location!); - expect(location.origin + location.pathname).toBe('https://example.com/callback'); - }); - }); - - describe('Authorization request validation', () => { - it('requires response_type=code', async () => { - const response = await supertest(app).get('/authorize').query({ - client_id: 'valid-client', - redirect_uri: 'https://example.com/callback', - response_type: 'token', // invalid - we only support code flow - code_challenge: 'challenge123', - code_challenge_method: 'S256' - }); - - expect(response.status).toBe(302); - const location = new URL(response.header.location!); - expect(location.searchParams.get('error')).toBe('invalid_request'); - }); - - it('requires code_challenge parameter', async () => { - const response = await supertest(app).get('/authorize').query({ - client_id: 'valid-client', - redirect_uri: 'https://example.com/callback', - response_type: 'code', - code_challenge_method: 'S256' - // Missing code_challenge - }); - - expect(response.status).toBe(302); - const location = new URL(response.header.location!); - expect(location.searchParams.get('error')).toBe('invalid_request'); - }); - - it('requires code_challenge_method=S256', async () => { - const response = await supertest(app).get('/authorize').query({ - client_id: 'valid-client', - redirect_uri: 'https://example.com/callback', - response_type: 'code', - code_challenge: 'challenge123', - code_challenge_method: 'plain' // Only S256 is supported - }); - - expect(response.status).toBe(302); - const location = new URL(response.header.location!); - expect(location.searchParams.get('error')).toBe('invalid_request'); - }); - }); - - describe('Resource parameter validation', () => { - it('propagates resource parameter', async () => { - const mockProviderWithResource = vi.spyOn(mockProvider, 'authorize'); - - const response = await supertest(app).get('/authorize').query({ - client_id: 'valid-client', - redirect_uri: 'https://example.com/callback', - response_type: 'code', - code_challenge: 'challenge123', - code_challenge_method: 'S256', - resource: 'https://api.example.com/resource' - }); - - expect(response.status).toBe(302); - expect(mockProviderWithResource).toHaveBeenCalledWith( - validClient, - expect.objectContaining({ - resource: new URL('https://api.example.com/resource'), - redirectUri: 'https://example.com/callback', - codeChallenge: 'challenge123' - }), - expect.any(Object) - ); - }); - }); - - describe('Successful authorization', () => { - it('handles successful authorization with all parameters', async () => { - const response = await supertest(app).get('/authorize').query({ - client_id: 'valid-client', - redirect_uri: 'https://example.com/callback', - response_type: 'code', - code_challenge: 'challenge123', - code_challenge_method: 'S256', - scope: 'profile email', - state: 'xyz789' - }); - - expect(response.status).toBe(302); - const location = new URL(response.header.location!); - expect(location.origin + location.pathname).toBe('https://example.com/callback'); - expect(location.searchParams.get('code')).toBe('mock_auth_code'); - expect(location.searchParams.get('state')).toBe('xyz789'); - }); - - it('preserves state parameter in response', async () => { - const response = await supertest(app).get('/authorize').query({ - client_id: 'valid-client', - redirect_uri: 'https://example.com/callback', - response_type: 'code', - code_challenge: 'challenge123', - code_challenge_method: 'S256', - state: 'state-value-123' - }); - - expect(response.status).toBe(302); - const location = new URL(response.header.location!); - expect(location.searchParams.get('state')).toBe('state-value-123'); - }); - - it('handles POST requests the same as GET', async () => { - const response = await supertest(app).post('/authorize').type('form').send({ - client_id: 'valid-client', - response_type: 'code', - code_challenge: 'challenge123', - code_challenge_method: 'S256' - }); - - expect(response.status).toBe(302); - const location = new URL(response.header.location!); - expect(location.searchParams.has('code')).toBe(true); - }); + it('requires redirect_uri if client has multiple redirect URIs', async () => { + const handler = authorizationHandler({ provider, rateLimit: false }); + const res = await handler( + new Request( + 'http://localhost/authorize?client_id=multi-redirect-client&response_type=code&code_challenge=challenge123&code_challenge_method=S256', + { method: 'GET' } + ) + ); + expect(res.status).toBe(400); + expect(await res.json()).toEqual(expect.objectContaining({ error: 'invalid_request' })); }); }); diff --git a/packages/server/test/server/auth/handlers/metadata.test.ts b/packages/server/test/server/auth/handlers/metadata.test.ts index 0dc51e51d..722320925 100644 --- a/packages/server/test/server/auth/handlers/metadata.test.ts +++ b/packages/server/test/server/auth/handlers/metadata.test.ts @@ -1,6 +1,4 @@ import type { OAuthMetadata } from '@modelcontextprotocol/core'; -import express from 'express'; -import supertest from 'supertest'; import { metadataHandler } from '../../../../src/server/auth/handlers/metadata.js'; @@ -18,62 +16,65 @@ describe('Metadata Handler', () => { code_challenge_methods_supported: ['S256'] }; - let app: express.Express; - - beforeEach(() => { - // Setup express app with metadata handler - app = express(); - app.use('/.well-known/oauth-authorization-server', metadataHandler(exampleMetadata)); - }); - it('requires GET method', async () => { - const response = await supertest(app).post('/.well-known/oauth-authorization-server').send({}); + const handler = metadataHandler(exampleMetadata); + const res = await handler(new Request('http://localhost/.well-known/oauth-authorization-server', { method: 'POST' })); - expect(response.status).toBe(405); - expect(response.headers.allow).toBe('GET, OPTIONS'); - expect(response.body).toEqual({ + expect(res.status).toBe(405); + expect(res.headers.get('allow')).toBe('GET, OPTIONS'); + expect(await res.json()).toEqual({ error: 'method_not_allowed', error_description: 'The method POST is not allowed for this endpoint' }); }); it('returns the metadata object', async () => { - const response = await supertest(app).get('/.well-known/oauth-authorization-server'); + const handler = metadataHandler(exampleMetadata); + const res = await handler(new Request('http://localhost/.well-known/oauth-authorization-server', { method: 'GET' })); - expect(response.status).toBe(200); - expect(response.body).toEqual(exampleMetadata); + expect(res.status).toBe(200); + expect(await res.json()).toEqual(exampleMetadata); }); it('includes CORS headers in response', async () => { - const response = await supertest(app).get('/.well-known/oauth-authorization-server').set('Origin', 'https://example.com'); + const handler = metadataHandler(exampleMetadata); + const res = await handler( + new Request('http://localhost/.well-known/oauth-authorization-server', { + method: 'GET', + headers: { Origin: 'https://example.com' } + }) + ); - expect(response.header['access-control-allow-origin']).toBe('*'); + expect(res.headers.get('access-control-allow-origin')).toBe('*'); }); it('supports OPTIONS preflight requests', async () => { - const response = await supertest(app) - .options('/.well-known/oauth-authorization-server') - .set('Origin', 'https://example.com') - .set('Access-Control-Request-Method', 'GET'); + const handler = metadataHandler(exampleMetadata); + const res = await handler( + new Request('http://localhost/.well-known/oauth-authorization-server', { + method: 'OPTIONS', + headers: { + Origin: 'https://example.com', + 'Access-Control-Request-Method': 'GET' + } + }) + ); - expect(response.status).toBe(204); - expect(response.header['access-control-allow-origin']).toBe('*'); + expect(res.status).toBe(204); + expect(res.headers.get('access-control-allow-origin')).toBe('*'); }); it('works with minimal metadata', async () => { - // Setup a new express app with minimal metadata - const minimalApp = express(); const minimalMetadata: OAuthMetadata = { issuer: 'https://auth.example.com', authorization_endpoint: 'https://auth.example.com/authorize', token_endpoint: 'https://auth.example.com/token', response_types_supported: ['code'] }; - minimalApp.use('/.well-known/oauth-authorization-server', metadataHandler(minimalMetadata)); - - const response = await supertest(minimalApp).get('/.well-known/oauth-authorization-server'); + const handler = metadataHandler(minimalMetadata); + const res = await handler(new Request('http://localhost/.well-known/oauth-authorization-server', { method: 'GET' })); - expect(response.status).toBe(200); - expect(response.body).toEqual(minimalMetadata); + expect(res.status).toBe(200); + expect(await res.json()).toEqual(minimalMetadata); }); }); diff --git a/packages/server/test/server/auth/handlers/register.test.ts b/packages/server/test/server/auth/handlers/register.test.ts index b10e048ed..1d3673c3f 100644 --- a/packages/server/test/server/auth/handlers/register.test.ts +++ b/packages/server/test/server/auth/handlers/register.test.ts @@ -1,274 +1,39 @@ -import type { OAuthClientInformationFull, OAuthClientMetadata } from '@modelcontextprotocol/core'; -import express from 'express'; -import supertest from 'supertest'; -import type { MockInstance } from 'vitest'; +import type { OAuthClientInformationFull } from '@modelcontextprotocol/core'; import type { OAuthRegisteredClientsStore } from '../../../../src/server/auth/clients.js'; -import type { ClientRegistrationHandlerOptions } from '../../../../src/server/auth/handlers/register.js'; import { clientRegistrationHandler } from '../../../../src/server/auth/handlers/register.js'; -describe('Client Registration Handler', () => { - // Mock client store with registration support - const mockClientStoreWithRegistration: OAuthRegisteredClientsStore = { - async getClient(_clientId: string): Promise { - return undefined; - }, - - async registerClient(client: OAuthClientInformationFull): Promise { - // Return the client info as-is in the mock - return client; - } - }; - - // Mock client store without registration support - const mockClientStoreWithoutRegistration: OAuthRegisteredClientsStore = { - async getClient(_clientId: string): Promise { - return undefined; - } - // No registerClient method - }; - - describe('Handler creation', () => { - it('throws error if client store does not support registration', () => { - const options: ClientRegistrationHandlerOptions = { - clientsStore: mockClientStoreWithoutRegistration - }; - - expect(() => clientRegistrationHandler(options)).toThrow('does not support registering clients'); - }); - - it('creates handler if client store supports registration', () => { - const options: ClientRegistrationHandlerOptions = { - clientsStore: mockClientStoreWithRegistration - }; - - expect(() => clientRegistrationHandler(options)).not.toThrow(); - }); - }); - - describe('Request handling', () => { - let app: express.Express; - let spyRegisterClient: MockInstance; - - beforeEach(() => { - // Setup express app with registration handler - app = express(); - const options: ClientRegistrationHandlerOptions = { - clientsStore: mockClientStoreWithRegistration, - clientSecretExpirySeconds: 86400 // 1 day for testing - }; - - app.use('/register', clientRegistrationHandler(options)); - - // Spy on the registerClient method - spyRegisterClient = vi.spyOn(mockClientStoreWithRegistration, 'registerClient'); - }); - - afterEach(() => { - spyRegisterClient.mockRestore(); - }); - - it('requires POST method', async () => { - const response = await supertest(app) - .get('/register') - .send({ - redirect_uris: ['https://example.com/callback'] - }); - - expect(response.status).toBe(405); - expect(response.headers.allow).toBe('POST'); - expect(response.body).toEqual({ - error: 'method_not_allowed', - error_description: 'The method GET is not allowed for this endpoint' - }); - expect(spyRegisterClient).not.toHaveBeenCalled(); - }); - - it('validates required client metadata', async () => { - const response = await supertest(app).post('/register').send({ - // Missing redirect_uris (required) - client_name: 'Test Client' - }); - - expect(response.status).toBe(400); - expect(response.body.error).toBe('invalid_client_metadata'); - expect(spyRegisterClient).not.toHaveBeenCalled(); - }); - - it('validates redirect URIs format', async () => { - const response = await supertest(app) - .post('/register') - .send({ - redirect_uris: ['invalid-url'] // Invalid URL format - }); - - expect(response.status).toBe(400); - expect(response.body.error).toBe('invalid_client_metadata'); - expect(response.body.error_description).toContain('redirect_uris'); - expect(spyRegisterClient).not.toHaveBeenCalled(); - }); - - it('successfully registers client with minimal metadata', async () => { - const clientMetadata: OAuthClientMetadata = { - redirect_uris: ['https://example.com/callback'] - }; - - const response = await supertest(app).post('/register').send(clientMetadata); - - expect(response.status).toBe(201); - - // Verify the generated client information - expect(response.body.client_id).toBeDefined(); - expect(response.body.client_secret).toBeDefined(); - expect(response.body.client_id_issued_at).toBeDefined(); - expect(response.body.client_secret_expires_at).toBeDefined(); - expect(response.body.redirect_uris).toEqual(['https://example.com/callback']); - - // Verify client was registered - expect(spyRegisterClient).toHaveBeenCalledTimes(1); - }); - - it('sets client_secret to undefined for token_endpoint_auth_method=none', async () => { - const clientMetadata: OAuthClientMetadata = { - redirect_uris: ['https://example.com/callback'], - token_endpoint_auth_method: 'none' - }; - - const response = await supertest(app).post('/register').send(clientMetadata); - - expect(response.status).toBe(201); - expect(response.body.client_secret).toBeUndefined(); - expect(response.body.client_secret_expires_at).toBeUndefined(); - }); - - it('sets client_secret_expires_at for public clients only', async () => { - // Test for public client (token_endpoint_auth_method not 'none') - const publicClientMetadata: OAuthClientMetadata = { - redirect_uris: ['https://example.com/callback'], - token_endpoint_auth_method: 'client_secret_basic' - }; - - const publicResponse = await supertest(app).post('/register').send(publicClientMetadata); - - expect(publicResponse.status).toBe(201); - expect(publicResponse.body.client_secret).toBeDefined(); - expect(publicResponse.body.client_secret_expires_at).toBeDefined(); - - // Test for non-public client (token_endpoint_auth_method is 'none') - const nonPublicClientMetadata: OAuthClientMetadata = { - redirect_uris: ['https://example.com/callback'], - token_endpoint_auth_method: 'none' - }; - - const nonPublicResponse = await supertest(app).post('/register').send(nonPublicClientMetadata); - - expect(nonPublicResponse.status).toBe(201); - expect(nonPublicResponse.body.client_secret).toBeUndefined(); - expect(nonPublicResponse.body.client_secret_expires_at).toBeUndefined(); - }); - - it('sets expiry based on clientSecretExpirySeconds', async () => { - // Create handler with custom expiry time - const customApp = express(); - const options: ClientRegistrationHandlerOptions = { - clientsStore: mockClientStoreWithRegistration, - clientSecretExpirySeconds: 3600 // 1 hour - }; - - customApp.use('/register', clientRegistrationHandler(options)); - - const response = await supertest(customApp) - .post('/register') - .send({ - redirect_uris: ['https://example.com/callback'] - }); - - expect(response.status).toBe(201); - - // Verify the expiration time (~1 hour from now) - const issuedAt = response.body.client_id_issued_at; - const expiresAt = response.body.client_secret_expires_at; - expect(expiresAt - issuedAt).toBe(3600); - }); - - it('sets no expiry when clientSecretExpirySeconds=0', async () => { - // Create handler with no expiry - const customApp = express(); - const options: ClientRegistrationHandlerOptions = { - clientsStore: mockClientStoreWithRegistration, - clientSecretExpirySeconds: 0 // No expiry - }; - - customApp.use('/register', clientRegistrationHandler(options)); - - const response = await supertest(customApp) - .post('/register') - .send({ - redirect_uris: ['https://example.com/callback'] - }); - - expect(response.status).toBe(201); - expect(response.body.client_secret_expires_at).toBe(0); - }); - - it('sets no client_id when clientIdGeneration=false', async () => { - // Create handler with no expiry - const customApp = express(); - const options: ClientRegistrationHandlerOptions = { - clientsStore: mockClientStoreWithRegistration, - clientIdGeneration: false - }; - - customApp.use('/register', clientRegistrationHandler(options)); - - const response = await supertest(customApp) - .post('/register') - .send({ - redirect_uris: ['https://example.com/callback'] - }); - - expect(response.status).toBe(201); - expect(response.body.client_id).toBeUndefined(); - expect(response.body.client_id_issued_at).toBeUndefined(); - }); - - it('handles client with all metadata fields', async () => { - const fullClientMetadata: OAuthClientMetadata = { - redirect_uris: ['https://example.com/callback'], - token_endpoint_auth_method: 'client_secret_basic', - grant_types: ['authorization_code', 'refresh_token'], - response_types: ['code'], - client_name: 'Test Client', - client_uri: 'https://example.com', - logo_uri: 'https://example.com/logo.png', - scope: 'profile email', - contacts: ['dev@example.com'], - tos_uri: 'https://example.com/tos', - policy_uri: 'https://example.com/privacy', - jwks_uri: 'https://example.com/jwks', - software_id: 'test-software', - software_version: '1.0.0' - }; - - const response = await supertest(app).post('/register').send(fullClientMetadata); - - expect(response.status).toBe(201); - - // Verify all metadata was preserved - Object.entries(fullClientMetadata).forEach(([key, value]) => { - expect(response.body[key]).toEqual(value); - }); - }); - - it('includes CORS headers in response', async () => { - const response = await supertest(app) - .post('/register') - .set('Origin', 'https://example.com') - .send({ +describe('clientRegistrationHandler (web)', () => { + it('returns 201 and client info when registration is supported', async () => { + const clientsStore: OAuthRegisteredClientsStore = { + async getClient() { + return undefined; + }, + async registerClient(client: Omit) { + // In real implementation, server may generate ids; here return minimal. + return { + ...client, + client_id: 'generated-client', + client_id_issued_at: Math.floor(Date.now() / 1000), + redirect_uris: (client as any).redirect_uris ?? [] + } as unknown as OAuthClientInformationFull; + } + }; + + const handler = clientRegistrationHandler({ clientsStore, rateLimit: false }); + + const res = await handler( + new Request('http://localhost/register', { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ redirect_uris: ['https://example.com/callback'] - }); + }) + }) + ); - expect(response.header['access-control-allow-origin']).toBe('*'); - }); + expect(res.status).toBe(201); + const body = (await res.json()) as { client_id?: string }; + expect(body.client_id).toBeDefined(); }); }); diff --git a/packages/server/test/server/auth/handlers/revoke.test.ts b/packages/server/test/server/auth/handlers/revoke.test.ts index 61ff51b24..d9598bb43 100644 --- a/packages/server/test/server/auth/handlers/revoke.test.ts +++ b/packages/server/test/server/auth/handlers/revoke.test.ts @@ -1,233 +1,72 @@ -import type { AuthInfo, OAuthClientInformationFull, OAuthTokenRevocationRequest, OAuthTokens } from '@modelcontextprotocol/core'; -import { InvalidTokenError } from '@modelcontextprotocol/core'; -import type { Response } from 'express'; -import express from 'express'; -import supertest from 'supertest'; -import type { MockInstance } from 'vitest'; +import type { OAuthClientInformationFull, OAuthTokenRevocationRequest, OAuthTokens } from '@modelcontextprotocol/core'; import type { OAuthRegisteredClientsStore } from '../../../../src/server/auth/clients.js'; -import type { RevocationHandlerOptions } from '../../../../src/server/auth/handlers/revoke.js'; import { revocationHandler } from '../../../../src/server/auth/handlers/revoke.js'; import type { AuthorizationParams, OAuthServerProvider } from '../../../../src/server/auth/provider.js'; -describe('Revocation Handler', () => { - // Mock client data - const validClient: OAuthClientInformationFull = { - client_id: 'valid-client', - client_secret: 'valid-secret', - redirect_uris: ['https://example.com/callback'] - }; - - // Mock client store - const mockClientStore: OAuthRegisteredClientsStore = { - async getClient(clientId: string): Promise { - if (clientId === 'valid-client') { - return validClient; +describe('revocationHandler (web)', () => { + it('returns 200 on successful revocation', async () => { + const validClient: OAuthClientInformationFull = { + client_id: 'valid-client', + client_secret: 'valid-secret', + redirect_uris: ['https://example.com/callback'] + }; + + const clientsStore: OAuthRegisteredClientsStore = { + async getClient(clientId: string) { + return clientId === 'valid-client' ? validClient : undefined; } - return undefined; - } - }; - - // Mock provider with revocation capability - const mockProviderWithRevocation: OAuthServerProvider = { - clientsStore: mockClientStore, - - async authorize(client: OAuthClientInformationFull, params: AuthorizationParams, res: Response): Promise { - res.redirect('https://example.com/callback?code=mock_auth_code'); - }, - - async challengeForAuthorizationCode(): Promise { - return 'mock_challenge'; - }, - - async exchangeAuthorizationCode(): Promise { - return { - access_token: 'mock_access_token', - token_type: 'bearer', - expires_in: 3600, - refresh_token: 'mock_refresh_token' - }; - }, - - async exchangeRefreshToken(): Promise { - return { - access_token: 'new_mock_access_token', - token_type: 'bearer', - expires_in: 3600, - refresh_token: 'new_mock_refresh_token' - }; - }, - - async verifyAccessToken(token: string): Promise { - if (token === 'valid_token') { + }; + + const provider: OAuthServerProvider = { + clientsStore, + async authorize(_client: OAuthClientInformationFull, _params: AuthorizationParams): Promise { + return Response.redirect('https://example.com', 302); + }, + async challengeForAuthorizationCode(): Promise { + return 'mock'; + }, + async exchangeAuthorizationCode(): Promise { return { - token, - clientId: 'valid-client', - scopes: ['read', 'write'], - expiresAt: Date.now() / 1000 + 3600 + access_token: 'mock_access_token', + token_type: 'bearer', + expires_in: 3600, + refresh_token: 'mock_refresh_token' }; - } - throw new InvalidTokenError('Token is invalid or expired'); - }, - - async revokeToken(_client: OAuthClientInformationFull, _request: OAuthTokenRevocationRequest): Promise { - // Success - do nothing in mock - } - }; - - // Mock provider without revocation capability - const mockProviderWithoutRevocation: OAuthServerProvider = { - clientsStore: mockClientStore, - - async authorize(client: OAuthClientInformationFull, params: AuthorizationParams, res: Response): Promise { - res.redirect('https://example.com/callback?code=mock_auth_code'); - }, - - async challengeForAuthorizationCode(): Promise { - return 'mock_challenge'; - }, - - async exchangeAuthorizationCode(): Promise { - return { - access_token: 'mock_access_token', - token_type: 'bearer', - expires_in: 3600, - refresh_token: 'mock_refresh_token' - }; - }, - - async exchangeRefreshToken(): Promise { - return { - access_token: 'new_mock_access_token', - token_type: 'bearer', - expires_in: 3600, - refresh_token: 'new_mock_refresh_token' - }; - }, - - async verifyAccessToken(token: string): Promise { - if (token === 'valid_token') { + }, + async exchangeRefreshToken(): Promise { return { - token, - clientId: 'valid-client', - scopes: ['read', 'write'], - expiresAt: Date.now() / 1000 + 3600 + access_token: 'new_mock_access_token', + token_type: 'bearer', + expires_in: 3600, + refresh_token: 'new_mock_refresh_token' }; + }, + async verifyAccessToken() { + throw new Error('not used'); + }, + async revokeToken(_client: OAuthClientInformationFull, _request: OAuthTokenRevocationRequest): Promise { + // ok } - throw new InvalidTokenError('Token is invalid or expired'); - } - - // No revokeToken method - }; - - describe('Handler creation', () => { - it('throws error if provider does not support token revocation', () => { - const options: RevocationHandlerOptions = { provider: mockProviderWithoutRevocation }; - expect(() => revocationHandler(options)).toThrow('does not support revoking tokens'); - }); - - it('creates handler if provider supports token revocation', () => { - const options: RevocationHandlerOptions = { provider: mockProviderWithRevocation }; - expect(() => revocationHandler(options)).not.toThrow(); - }); - }); - - describe('Request handling', () => { - let app: express.Express; - let spyRevokeToken: MockInstance; - - beforeEach(() => { - // Setup express app with revocation handler - app = express(); - const options: RevocationHandlerOptions = { provider: mockProviderWithRevocation }; - app.use('/revoke', revocationHandler(options)); - - // Spy on the revokeToken method - spyRevokeToken = vi.spyOn(mockProviderWithRevocation, 'revokeToken'); - }); - - afterEach(() => { - spyRevokeToken.mockRestore(); - }); - - it('requires POST method', async () => { - const response = await supertest(app).get('/revoke').send({ - client_id: 'valid-client', - client_secret: 'valid-secret', - token: 'token_to_revoke' - }); - - expect(response.status).toBe(405); - expect(response.headers.allow).toBe('POST'); - expect(response.body).toEqual({ - error: 'method_not_allowed', - error_description: 'The method GET is not allowed for this endpoint' - }); - expect(spyRevokeToken).not.toHaveBeenCalled(); - }); - - it('requires token parameter', async () => { - const response = await supertest(app).post('/revoke').type('form').send({ - client_id: 'valid-client', - client_secret: 'valid-secret' - // Missing token - }); - - expect(response.status).toBe(400); - expect(response.body.error).toBe('invalid_request'); - expect(spyRevokeToken).not.toHaveBeenCalled(); - }); - - it('authenticates client before revoking token', async () => { - const response = await supertest(app).post('/revoke').type('form').send({ - client_id: 'invalid-client', - client_secret: 'wrong-secret', - token: 'token_to_revoke' - }); - - expect(response.status).toBe(400); - expect(response.body.error).toBe('invalid_client'); - expect(spyRevokeToken).not.toHaveBeenCalled(); - }); - - it('successfully revokes token', async () => { - const response = await supertest(app).post('/revoke').type('form').send({ - client_id: 'valid-client', - client_secret: 'valid-secret', - token: 'token_to_revoke' - }); - - expect(response.status).toBe(200); - expect(response.body).toEqual({}); // Empty response on success - expect(spyRevokeToken).toHaveBeenCalledTimes(1); - expect(spyRevokeToken).toHaveBeenCalledWith(validClient, { - token: 'token_to_revoke' - }); - }); - - it('accepts optional token_type_hint', async () => { - const response = await supertest(app).post('/revoke').type('form').send({ - client_id: 'valid-client', - client_secret: 'valid-secret', - token: 'token_to_revoke', - token_type_hint: 'refresh_token' - }); - - expect(response.status).toBe(200); - expect(spyRevokeToken).toHaveBeenCalledWith(validClient, { - token: 'token_to_revoke', - token_type_hint: 'refresh_token' - }); - }); - - it('includes CORS headers in response', async () => { - const response = await supertest(app).post('/revoke').type('form').set('Origin', 'https://example.com').send({ - client_id: 'valid-client', - client_secret: 'valid-secret', - token: 'token_to_revoke' - }); - - expect(response.header['access-control-allow-origin']).toBe('*'); - }); + }; + + const handler = revocationHandler({ provider, rateLimit: false }); + + const body = new URLSearchParams({ + client_id: 'valid-client', + client_secret: 'valid-secret', + token: 'token_to_revoke' + }).toString(); + + const res = await handler( + new Request('http://localhost/revoke', { + method: 'POST', + headers: { 'Content-Type': 'application/x-www-form-urlencoded' }, + body + }) + ); + + expect(res.status).toBe(200); + expect(await res.json()).toEqual({}); }); }); diff --git a/packages/server/test/server/auth/handlers/token.test.ts b/packages/server/test/server/auth/handlers/token.test.ts index 02eab891f..d0cb0ca24 100644 --- a/packages/server/test/server/auth/handlers/token.test.ts +++ b/packages/server/test/server/auth/handlers/token.test.ts @@ -1,481 +1,116 @@ -import type { AuthInfo, OAuthClientInformationFull, OAuthTokenRevocationRequest, OAuthTokens } from '@modelcontextprotocol/core'; -import { InvalidGrantError, InvalidTokenError } from '@modelcontextprotocol/core'; -import type { Response } from 'express'; -import express from 'express'; +import type { AuthInfo, OAuthClientInformationFull, OAuthTokens } from '@modelcontextprotocol/core'; +import { InvalidGrantError } from '@modelcontextprotocol/core'; import * as pkceChallenge from 'pkce-challenge'; -import supertest from 'supertest'; -import { type Mock } from 'vitest'; import type { OAuthRegisteredClientsStore } from '../../../../src/server/auth/clients.js'; -import type { TokenHandlerOptions } from '../../../../src/server/auth/handlers/token.js'; import { tokenHandler } from '../../../../src/server/auth/handlers/token.js'; import type { AuthorizationParams, OAuthServerProvider } from '../../../../src/server/auth/provider.js'; -import { ProxyOAuthServerProvider } from '../../../../src/server/auth/providers/proxyProvider.js'; -// Mock pkce-challenge vi.mock('pkce-challenge', () => ({ - verifyChallenge: vi.fn().mockImplementation(async (verifier, challenge) => { - return verifier === 'valid_verifier' && challenge === 'mock_challenge'; - }) + verifyChallenge: vi.fn() })); -const mockTokens = { - access_token: 'mock_access_token', - token_type: 'bearer', - expires_in: 3600, - refresh_token: 'mock_refresh_token' -}; - -const mockTokensWithIdToken = { - ...mockTokens, - id_token: 'mock_id_token' -}; - -describe('Token Handler', () => { - // Mock client data +describe('tokenHandler (web)', () => { const validClient: OAuthClientInformationFull = { client_id: 'valid-client', client_secret: 'valid-secret', redirect_uris: ['https://example.com/callback'] }; - // Mock client store - const mockClientStore: OAuthRegisteredClientsStore = { - async getClient(clientId: string): Promise { - if (clientId === 'valid-client') { - return validClient; - } - return undefined; + const clientsStore: OAuthRegisteredClientsStore = { + async getClient(clientId: string) { + return clientId === 'valid-client' ? validClient : undefined; } }; - // Mock provider - let mockProvider: OAuthServerProvider; - let app: express.Express; - - beforeEach(() => { - // Create fresh mocks for each test - mockProvider = { - clientsStore: mockClientStore, - - async authorize(client: OAuthClientInformationFull, params: AuthorizationParams, res: Response): Promise { - res.redirect('https://example.com/callback?code=mock_auth_code'); - }, - - async challengeForAuthorizationCode(client: OAuthClientInformationFull, authorizationCode: string): Promise { - if (authorizationCode === 'valid_code') { - return 'mock_challenge'; - } else if (authorizationCode === 'expired_code') { - throw new InvalidGrantError('The authorization code has expired'); - } - throw new InvalidGrantError('The authorization code is invalid'); - }, - - async exchangeAuthorizationCode(client: OAuthClientInformationFull, authorizationCode: string): Promise { - if (authorizationCode === 'valid_code') { - return mockTokens; - } - throw new InvalidGrantError('The authorization code is invalid or has expired'); - }, - - async exchangeRefreshToken(client: OAuthClientInformationFull, refreshToken: string, scopes?: string[]): Promise { - if (refreshToken === 'valid_refresh_token') { - const response: OAuthTokens = { - access_token: 'new_mock_access_token', - token_type: 'bearer', - expires_in: 3600, - refresh_token: 'new_mock_refresh_token' - }; - - if (scopes) { - response.scope = scopes.join(' '); - } - - return response; - } - throw new InvalidGrantError('The refresh token is invalid or has expired'); - }, - - async verifyAccessToken(token: string): Promise { - if (token === 'valid_token') { - return { - token, - clientId: 'valid-client', - scopes: ['read', 'write'], - expiresAt: Date.now() / 1000 + 3600 - }; - } - throw new InvalidTokenError('Token is invalid or expired'); - }, - - async revokeToken(_client: OAuthClientInformationFull, _request: OAuthTokenRevocationRequest): Promise { - // Do nothing in mock - } - }; - - // Mock PKCE verification - (pkceChallenge.verifyChallenge as Mock).mockImplementation(async (verifier: string, challenge: string) => { - return verifier === 'valid_verifier' && challenge === 'mock_challenge'; - }); - - // Setup express app with token handler - app = express(); - const options: TokenHandlerOptions = { provider: mockProvider }; - app.use('/token', tokenHandler(options)); - }); - - describe('Basic request validation', () => { - it('requires POST method', async () => { - const response = await supertest(app).get('/token').send({ - client_id: 'valid-client', - client_secret: 'valid-secret', - grant_type: 'authorization_code' - }); - - expect(response.status).toBe(405); - expect(response.headers.allow).toBe('POST'); - expect(response.body).toEqual({ - error: 'method_not_allowed', - error_description: 'The method GET is not allowed for this endpoint' - }); - }); - - it('requires grant_type parameter', async () => { - const response = await supertest(app).post('/token').type('form').send({ - client_id: 'valid-client', - client_secret: 'valid-secret' - // Missing grant_type - }); - - expect(response.status).toBe(400); - expect(response.body.error).toBe('invalid_request'); - }); - - it('rejects unsupported grant types', async () => { - const response = await supertest(app).post('/token').type('form').send({ - client_id: 'valid-client', - client_secret: 'valid-secret', - grant_type: 'password' // Unsupported grant type - }); - - expect(response.status).toBe(400); - expect(response.body.error).toBe('unsupported_grant_type'); - }); - }); - - describe('Client authentication', () => { - it('requires valid client credentials', async () => { - const response = await supertest(app).post('/token').type('form').send({ - client_id: 'invalid-client', - client_secret: 'wrong-secret', - grant_type: 'authorization_code' - }); - - expect(response.status).toBe(400); - expect(response.body.error).toBe('invalid_client'); - }); - - it('accepts valid client credentials', async () => { - const response = await supertest(app).post('/token').type('form').send({ - client_id: 'valid-client', - client_secret: 'valid-secret', - grant_type: 'authorization_code', - code: 'valid_code', - code_verifier: 'valid_verifier' - }); - - expect(response.status).toBe(200); - }); - }); - - describe('Authorization code grant', () => { - it('requires code parameter', async () => { - const response = await supertest(app).post('/token').type('form').send({ - client_id: 'valid-client', - client_secret: 'valid-secret', - grant_type: 'authorization_code', - // Missing code - code_verifier: 'valid_verifier' - }); - - expect(response.status).toBe(400); - expect(response.body.error).toBe('invalid_request'); - }); - - it('requires code_verifier parameter', async () => { - const response = await supertest(app).post('/token').type('form').send({ - client_id: 'valid-client', - client_secret: 'valid-secret', - grant_type: 'authorization_code', - code: 'valid_code' - // Missing code_verifier - }); - - expect(response.status).toBe(400); - expect(response.body.error).toBe('invalid_request'); - }); - - it('verifies code_verifier against challenge', async () => { - // Setup invalid verifier - (pkceChallenge.verifyChallenge as Mock).mockResolvedValueOnce(false); - - const response = await supertest(app).post('/token').type('form').send({ - client_id: 'valid-client', - client_secret: 'valid-secret', - grant_type: 'authorization_code', - code: 'valid_code', - code_verifier: 'invalid_verifier' - }); - - expect(response.status).toBe(400); - expect(response.body.error).toBe('invalid_grant'); - expect(response.body.error_description).toContain('code_verifier'); - }); - - it('rejects expired or invalid authorization codes', async () => { - const response = await supertest(app).post('/token').type('form').send({ - client_id: 'valid-client', - client_secret: 'valid-secret', - grant_type: 'authorization_code', - code: 'expired_code', - code_verifier: 'valid_verifier' - }); - - expect(response.status).toBe(400); - expect(response.body.error).toBe('invalid_grant'); - }); - - it('returns tokens for valid code exchange', async () => { - const mockExchangeCode = vi.spyOn(mockProvider, 'exchangeAuthorizationCode'); - const response = await supertest(app).post('/token').type('form').send({ - client_id: 'valid-client', - client_secret: 'valid-secret', - resource: 'https://api.example.com/resource', - grant_type: 'authorization_code', - code: 'valid_code', - code_verifier: 'valid_verifier' - }); - - expect(response.status).toBe(200); - expect(response.body.access_token).toBe('mock_access_token'); - expect(response.body.token_type).toBe('bearer'); - expect(response.body.expires_in).toBe(3600); - expect(response.body.refresh_token).toBe('mock_refresh_token'); - expect(mockExchangeCode).toHaveBeenCalledWith( - validClient, - 'valid_code', - undefined, // code_verifier is undefined after PKCE validation - undefined, // redirect_uri - new URL('https://api.example.com/resource') // resource parameter - ); - }); - - it('returns id token in code exchange if provided', async () => { - mockProvider.exchangeAuthorizationCode = async ( - client: OAuthClientInformationFull, - authorizationCode: string - ): Promise => { - if (authorizationCode === 'valid_code') { - return mockTokensWithIdToken; - } - throw new InvalidGrantError('The authorization code is invalid or has expired'); + const provider: OAuthServerProvider = { + clientsStore, + async authorize(_client: OAuthClientInformationFull, _params: AuthorizationParams): Promise { + return Response.redirect('https://example.com/callback?code=mock_auth_code', 302); + }, + async challengeForAuthorizationCode(_client: OAuthClientInformationFull, authorizationCode: string): Promise { + if (authorizationCode === 'valid_code') return 'mock_challenge'; + throw new InvalidGrantError('The authorization code is invalid'); + }, + async exchangeAuthorizationCode(): Promise { + return { + access_token: 'mock_access_token', + token_type: 'bearer', + expires_in: 3600, + refresh_token: 'mock_refresh_token' }; + }, + async exchangeRefreshToken(): Promise { + return { + access_token: 'new_mock_access_token', + token_type: 'bearer', + expires_in: 3600, + refresh_token: 'new_mock_refresh_token' + }; + }, + async verifyAccessToken(token: string): Promise { + return { + token, + clientId: 'valid-client', + scopes: [], + expiresAt: Math.floor(Date.now() / 1000) + 3600 + }; + } + }; - const response = await supertest(app).post('/token').type('form').send({ - client_id: 'valid-client', - client_secret: 'valid-secret', - grant_type: 'authorization_code', - code: 'valid_code', - code_verifier: 'valid_verifier' - }); - - expect(response.status).toBe(200); - expect(response.body.id_token).toBe('mock_id_token'); - }); - - it('passes through code verifier when using proxy provider', async () => { - const originalFetch = global.fetch; - - try { - global.fetch = vi.fn().mockResolvedValue({ - ok: true, - json: () => Promise.resolve(mockTokens) - }); - - const proxyProvider = new ProxyOAuthServerProvider({ - endpoints: { - authorizationUrl: 'https://example.com/authorize', - tokenUrl: 'https://example.com/token' - }, - verifyAccessToken: async token => ({ - token, - clientId: 'valid-client', - scopes: ['read', 'write'], - expiresAt: Date.now() / 1000 + 3600 - }), - getClient: async clientId => (clientId === 'valid-client' ? validClient : undefined) - }); - - const proxyApp = express(); - const options: TokenHandlerOptions = { provider: proxyProvider }; - proxyApp.use('/token', tokenHandler(options)); - - const response = await supertest(proxyApp).post('/token').type('form').send({ - client_id: 'valid-client', - client_secret: 'valid-secret', - grant_type: 'authorization_code', - code: 'valid_code', - code_verifier: 'any_verifier', - redirect_uri: 'https://example.com/callback' - }); - - expect(response.status).toBe(200); - expect(response.body.access_token).toBe('mock_access_token'); - - expect(global.fetch).toHaveBeenCalledWith( - 'https://example.com/token', - expect.objectContaining({ - method: 'POST', - headers: { - 'Content-Type': 'application/x-www-form-urlencoded' - }, - body: expect.stringContaining('code_verifier=any_verifier') - }) - ); - } finally { - global.fetch = originalFetch; - } - }); - - it('passes through redirect_uri when using proxy provider', async () => { - const originalFetch = global.fetch; - - try { - global.fetch = vi.fn().mockResolvedValue({ - ok: true, - json: () => Promise.resolve(mockTokens) - }); - - const proxyProvider = new ProxyOAuthServerProvider({ - endpoints: { - authorizationUrl: 'https://example.com/authorize', - tokenUrl: 'https://example.com/token' - }, - verifyAccessToken: async token => ({ - token, - clientId: 'valid-client', - scopes: ['read', 'write'], - expiresAt: Date.now() / 1000 + 3600 - }), - getClient: async clientId => (clientId === 'valid-client' ? validClient : undefined) - }); - - const proxyApp = express(); - const options: TokenHandlerOptions = { provider: proxyProvider }; - proxyApp.use('/token', tokenHandler(options)); - - const redirectUri = 'https://example.com/callback'; - const response = await supertest(proxyApp).post('/token').type('form').send({ - client_id: 'valid-client', - client_secret: 'valid-secret', - grant_type: 'authorization_code', - code: 'valid_code', - code_verifier: 'any_verifier', - redirect_uri: redirectUri - }); - - expect(response.status).toBe(200); - expect(response.body.access_token).toBe('mock_access_token'); - - expect(global.fetch).toHaveBeenCalledWith( - 'https://example.com/token', - expect.objectContaining({ - method: 'POST', - headers: { - 'Content-Type': 'application/x-www-form-urlencoded' - }, - body: expect.stringContaining(`redirect_uri=${encodeURIComponent(redirectUri)}`) - }) - ); - } finally { - global.fetch = originalFetch; - } - }); + beforeEach(() => { + vi.clearAllMocks(); }); - describe('Refresh token grant', () => { - it('requires refresh_token parameter', async () => { - const response = await supertest(app).post('/token').type('form').send({ - client_id: 'valid-client', - client_secret: 'valid-secret', - grant_type: 'refresh_token' - // Missing refresh_token - }); - - expect(response.status).toBe(400); - expect(response.body.error).toBe('invalid_request'); - }); - - it('rejects invalid refresh tokens', async () => { - const response = await supertest(app).post('/token').type('form').send({ - client_id: 'valid-client', - client_secret: 'valid-secret', - grant_type: 'refresh_token', - refresh_token: 'invalid_refresh_token' - }); - - expect(response.status).toBe(400); - expect(response.body.error).toBe('invalid_grant'); - }); - - it('returns new tokens for valid refresh token', async () => { - const mockExchangeRefresh = vi.spyOn(mockProvider, 'exchangeRefreshToken'); - const response = await supertest(app).post('/token').type('form').send({ - client_id: 'valid-client', - client_secret: 'valid-secret', - resource: 'https://api.example.com/resource', - grant_type: 'refresh_token', - refresh_token: 'valid_refresh_token' - }); - - expect(response.status).toBe(200); - expect(response.body.access_token).toBe('new_mock_access_token'); - expect(response.body.token_type).toBe('bearer'); - expect(response.body.expires_in).toBe(3600); - expect(response.body.refresh_token).toBe('new_mock_refresh_token'); - expect(mockExchangeRefresh).toHaveBeenCalledWith( - validClient, - 'valid_refresh_token', - undefined, // scopes - new URL('https://api.example.com/resource') // resource parameter - ); - }); - - it('respects requested scopes on refresh', async () => { - const response = await supertest(app).post('/token').type('form').send({ - client_id: 'valid-client', - client_secret: 'valid-secret', - grant_type: 'refresh_token', - refresh_token: 'valid_refresh_token', - scope: 'profile email' - }); - - expect(response.status).toBe(200); - expect(response.body.scope).toBe('profile email'); - }); + it('returns tokens for authorization_code grant when PKCE passes', async () => { + (pkceChallenge.verifyChallenge as unknown as ReturnType).mockResolvedValue(true); + const handler = tokenHandler({ provider, rateLimit: false }); + + const body = new URLSearchParams({ + client_id: 'valid-client', + client_secret: 'valid-secret', + grant_type: 'authorization_code', + code: 'valid_code', + code_verifier: 'valid_verifier' + }).toString(); + + const res = await handler( + new Request('http://localhost/token', { + method: 'POST', + headers: { 'Content-Type': 'application/x-www-form-urlencoded' }, + body + }) + ); + + expect(res.status).toBe(200); + expect(await res.json()).toEqual( + expect.objectContaining({ + access_token: 'mock_access_token' + }) + ); }); - describe('CORS support', () => { - it('includes CORS headers in response', async () => { - const response = await supertest(app).post('/token').type('form').set('Origin', 'https://example.com').send({ - client_id: 'valid-client', - client_secret: 'valid-secret', - grant_type: 'authorization_code', - code: 'valid_code', - code_verifier: 'valid_verifier' - }); - - expect(response.header['access-control-allow-origin']).toBe('*'); - }); + it('returns 400 when PKCE fails', async () => { + (pkceChallenge.verifyChallenge as unknown as ReturnType).mockResolvedValue(false); + const handler = tokenHandler({ provider, rateLimit: false }); + + const body = new URLSearchParams({ + client_id: 'valid-client', + client_secret: 'valid-secret', + grant_type: 'authorization_code', + code: 'valid_code', + code_verifier: 'bad_verifier' + }).toString(); + + const res = await handler( + new Request('http://localhost/token', { + method: 'POST', + headers: { 'Content-Type': 'application/x-www-form-urlencoded' }, + body + }) + ); + + expect(res.status).toBe(400); + expect(await res.json()).toEqual(expect.objectContaining({ error: 'invalid_grant' })); }); }); diff --git a/packages/server/test/server/auth/middleware/allowedMethods.test.ts b/packages/server/test/server/auth/middleware/allowedMethods.test.ts index 40e9c3b1f..3cea847a1 100644 --- a/packages/server/test/server/auth/middleware/allowedMethods.test.ts +++ b/packages/server/test/server/auth/middleware/allowedMethods.test.ts @@ -1,77 +1,29 @@ -import type { Request, Response } from 'express'; -import express from 'express'; -import request from 'supertest'; - import { allowedMethods } from '../../../../src/server/auth/middleware/allowedMethods.js'; describe('allowedMethods', () => { - let app: express.Express; - - beforeEach(() => { - app = express(); - - // Set up a test router with a GET handler and 405 middleware - const router = express.Router(); - - router.get('/test', (req, res) => { - res.status(200).send('GET success'); - }); - - // Add method not allowed middleware for all other methods - router.all('/test', allowedMethods(['GET'])); - - app.use(router); - }); - - test('allows specified HTTP method', async () => { - const response = await request(app).get('/test'); - expect(response.status).toBe(200); - expect(response.text).toBe('GET success'); - }); - - test('returns 405 for unspecified HTTP methods', async () => { - const methods = ['post', 'put', 'delete', 'patch']; - - for (const method of methods) { - // @ts-expect-error - dynamic method call - const response = await request(app)[method]('/test'); - expect(response.status).toBe(405); - expect(response.body).toEqual({ - error: 'method_not_allowed', - error_description: `The method ${method.toUpperCase()} is not allowed for this endpoint` - }); - } + test('returns undefined for allowed HTTP method', () => { + const req = new Request('http://localhost/test', { method: 'GET' }); + const res = allowedMethods(['GET'], req); + expect(res).toBeUndefined(); }); - test('includes Allow header with specified methods', async () => { - const response = await request(app).post('/test'); - expect(response.headers.allow).toBe('GET'); - }); - - test('works with multiple allowed methods', async () => { - const multiMethodApp = express(); - const router = express.Router(); - - router.get('/multi', (req: Request, res: Response) => { - res.status(200).send('GET'); + test('returns 405 response for disallowed HTTP method', async () => { + const req = new Request('http://localhost/test', { method: 'POST' }); + const res = allowedMethods(['GET'], req); + expect(res).toBeDefined(); + expect(res!.status).toBe(405); + expect(res!.headers.get('allow')).toBe('GET'); + expect(await res!.json()).toEqual({ + error: 'method_not_allowed', + error_description: 'The method POST is not allowed for this endpoint' }); - router.post('/multi', (req: Request, res: Response) => { - res.status(200).send('POST'); - }); - router.all('/multi', allowedMethods(['GET', 'POST'])); - - multiMethodApp.use(router); - - // Allowed methods should work - const getResponse = await request(multiMethodApp).get('/multi'); - expect(getResponse.status).toBe(200); - - const postResponse = await request(multiMethodApp).post('/multi'); - expect(postResponse.status).toBe(200); + }); - // Unallowed methods should return 405 - const putResponse = await request(multiMethodApp).put('/multi'); - expect(putResponse.status).toBe(405); - expect(putResponse.headers.allow).toBe('GET, POST'); + test('supports multiple allowed methods', async () => { + const req = new Request('http://localhost/test', { method: 'PUT' }); + const res = allowedMethods(['GET', 'POST'], req); + expect(res).toBeDefined(); + expect(res!.status).toBe(405); + expect(res!.headers.get('allow')).toBe('GET, POST'); }); }); diff --git a/packages/server/test/server/auth/middleware/bearerAuth.test.ts b/packages/server/test/server/auth/middleware/bearerAuth.test.ts index 7b464bbff..9b0ead3ef 100644 --- a/packages/server/test/server/auth/middleware/bearerAuth.test.ts +++ b/packages/server/test/server/auth/middleware/bearerAuth.test.ts @@ -1,502 +1,118 @@ import type { AuthInfo } from '@modelcontextprotocol/core'; -import { CustomOAuthError, InsufficientScopeError, InvalidTokenError, ServerError } from '@modelcontextprotocol/core'; -import { createExpressResponseMock } from '@modelcontextprotocol/test-helpers'; -import type { Request, Response } from 'express'; -import type { Mock } from 'vitest'; +import { InsufficientScopeError, InvalidTokenError, ServerError } from '@modelcontextprotocol/core'; import { requireBearerAuth } from '../../../../src/server/auth/middleware/bearerAuth.js'; import type { OAuthTokenVerifier } from '../../../../src/server/auth/provider.js'; -// Mock verifier -const mockVerifyAccessToken = vi.fn(); -const mockVerifier: OAuthTokenVerifier = { - verifyAccessToken: mockVerifyAccessToken -}; - -describe('requireBearerAuth middleware', () => { - let mockRequest: Partial; - let mockResponse: Partial; - let nextFunction: Mock; +describe('requireBearerAuth (web)', () => { + const verifyAccessToken = vi.fn(); + const verifier: OAuthTokenVerifier = { verifyAccessToken }; beforeEach(() => { - mockRequest = { - headers: {} - }; - mockResponse = createExpressResponseMock(); - nextFunction = vi.fn(); - vi.spyOn(console, 'error').mockImplementation(() => {}); - }); - - afterEach(() => { vi.clearAllMocks(); }); - it('should call next when token is valid', async () => { - const validAuthInfo: AuthInfo = { - token: 'valid-token', - clientId: 'client-123', - scopes: ['read', 'write'], - expiresAt: Math.floor(Date.now() / 1000) + 3600 // Token expires in an hour - }; - mockVerifyAccessToken.mockResolvedValue(validAuthInfo); - - mockRequest.headers = { - authorization: 'Bearer valid-token' - }; - - const middleware = requireBearerAuth({ verifier: mockVerifier }); - await middleware(mockRequest as Request, mockResponse as Response, nextFunction); - - expect(mockVerifyAccessToken).toHaveBeenCalledWith('valid-token'); - expect(mockRequest.auth).toEqual(validAuthInfo); - expect(nextFunction).toHaveBeenCalled(); - expect(mockResponse.status).not.toHaveBeenCalled(); - expect(mockResponse.json).not.toHaveBeenCalled(); - }); - - it.each([ - [100], // Token expired 100 seconds ago - [0] // Token expires at the same time as now - ])('should reject expired tokens (expired %s seconds ago)', async (expiredSecondsAgo: number) => { - const expiresAt = Math.floor(Date.now() / 1000) - expiredSecondsAgo; - const expiredAuthInfo: AuthInfo = { - token: 'expired-token', - clientId: 'client-123', - scopes: ['read', 'write'], - expiresAt - }; - mockVerifyAccessToken.mockResolvedValue(expiredAuthInfo); - - mockRequest.headers = { - authorization: 'Bearer expired-token' - }; - - const middleware = requireBearerAuth({ verifier: mockVerifier }); - await middleware(mockRequest as Request, mockResponse as Response, nextFunction); - - expect(mockVerifyAccessToken).toHaveBeenCalledWith('expired-token'); - expect(mockResponse.status).toHaveBeenCalledWith(401); - expect(mockResponse.set).toHaveBeenCalledWith('WWW-Authenticate', expect.stringContaining('Bearer error="invalid_token"')); - expect(mockResponse.json).toHaveBeenCalledWith( - expect.objectContaining({ error: 'invalid_token', error_description: 'Token has expired' }) - ); - expect(nextFunction).not.toHaveBeenCalled(); - }); - - it.each([ - [undefined], // Token has no expiration time - [NaN] // Token has no expiration time - ])('should reject tokens with no expiration time (expiresAt: %s)', async (expiresAt: number | undefined) => { - const noExpirationAuthInfo: AuthInfo = { - token: 'no-expiration-token', - clientId: 'client-123', - scopes: ['read', 'write'], - expiresAt - }; - mockVerifyAccessToken.mockResolvedValue(noExpirationAuthInfo); - - mockRequest.headers = { - authorization: 'Bearer expired-token' - }; - - const middleware = requireBearerAuth({ verifier: mockVerifier }); - await middleware(mockRequest as Request, mockResponse as Response, nextFunction); - - expect(mockVerifyAccessToken).toHaveBeenCalledWith('expired-token'); - expect(mockResponse.status).toHaveBeenCalledWith(401); - expect(mockResponse.set).toHaveBeenCalledWith('WWW-Authenticate', expect.stringContaining('Bearer error="invalid_token"')); - expect(mockResponse.json).toHaveBeenCalledWith( - expect.objectContaining({ error: 'invalid_token', error_description: 'Token has no expiration time' }) - ); - expect(nextFunction).not.toHaveBeenCalled(); - }); - - it('should accept non-expired tokens', async () => { - const nonExpiredAuthInfo: AuthInfo = { - token: 'valid-token', - clientId: 'client-123', - scopes: ['read', 'write'], - expiresAt: Math.floor(Date.now() / 1000) + 3600 // Token expires in an hour - }; - mockVerifyAccessToken.mockResolvedValue(nonExpiredAuthInfo); - - mockRequest.headers = { - authorization: 'Bearer valid-token' - }; - - const middleware = requireBearerAuth({ verifier: mockVerifier }); - await middleware(mockRequest as Request, mockResponse as Response, nextFunction); - - expect(mockVerifyAccessToken).toHaveBeenCalledWith('valid-token'); - expect(mockRequest.auth).toEqual(nonExpiredAuthInfo); - expect(nextFunction).toHaveBeenCalled(); - expect(mockResponse.status).not.toHaveBeenCalled(); - expect(mockResponse.json).not.toHaveBeenCalled(); - }); - - it('should require specific scopes when configured', async () => { - const authInfo: AuthInfo = { - token: 'valid-token', - clientId: 'client-123', - scopes: ['read'] - }; - mockVerifyAccessToken.mockResolvedValue(authInfo); - - mockRequest.headers = { - authorization: 'Bearer valid-token' - }; - - const middleware = requireBearerAuth({ - verifier: mockVerifier, - requiredScopes: ['read', 'write'] - }); - - await middleware(mockRequest as Request, mockResponse as Response, nextFunction); - - expect(mockVerifyAccessToken).toHaveBeenCalledWith('valid-token'); - expect(mockResponse.status).toHaveBeenCalledWith(403); - expect(mockResponse.set).toHaveBeenCalledWith('WWW-Authenticate', expect.stringContaining('Bearer error="insufficient_scope"')); - expect(mockResponse.json).toHaveBeenCalledWith( - expect.objectContaining({ error: 'insufficient_scope', error_description: 'Insufficient scope' }) - ); - expect(nextFunction).not.toHaveBeenCalled(); - }); - - it('should accept token with all required scopes', async () => { - const authInfo: AuthInfo = { - token: 'valid-token', - clientId: 'client-123', - scopes: ['read', 'write', 'admin'], - expiresAt: Math.floor(Date.now() / 1000) + 3600 // Token expires in an hour - }; - mockVerifyAccessToken.mockResolvedValue(authInfo); - - mockRequest.headers = { - authorization: 'Bearer valid-token' - }; - - const middleware = requireBearerAuth({ - verifier: mockVerifier, - requiredScopes: ['read', 'write'] - }); - - await middleware(mockRequest as Request, mockResponse as Response, nextFunction); - - expect(mockVerifyAccessToken).toHaveBeenCalledWith('valid-token'); - expect(mockRequest.auth).toEqual(authInfo); - expect(nextFunction).toHaveBeenCalled(); - expect(mockResponse.status).not.toHaveBeenCalled(); - expect(mockResponse.json).not.toHaveBeenCalled(); - }); - - it('should return 401 when no Authorization header is present', async () => { - const middleware = requireBearerAuth({ verifier: mockVerifier }); - await middleware(mockRequest as Request, mockResponse as Response, nextFunction); - - expect(mockVerifyAccessToken).not.toHaveBeenCalled(); - expect(mockResponse.status).toHaveBeenCalledWith(401); - expect(mockResponse.set).toHaveBeenCalledWith('WWW-Authenticate', expect.stringContaining('Bearer error="invalid_token"')); - expect(mockResponse.json).toHaveBeenCalledWith( - expect.objectContaining({ error: 'invalid_token', error_description: 'Missing Authorization header' }) - ); - expect(nextFunction).not.toHaveBeenCalled(); - }); - - it('should return 401 when Authorization header format is invalid', async () => { - mockRequest.headers = { - authorization: 'InvalidFormat' + it('returns authInfo on success', async () => { + const info: AuthInfo = { + token: 't', + clientId: 'c', + scopes: ['read'], + expiresAt: Math.floor(Date.now() / 1000) + 3600 }; + verifyAccessToken.mockResolvedValue(info); - const middleware = requireBearerAuth({ verifier: mockVerifier }); - await middleware(mockRequest as Request, mockResponse as Response, nextFunction); + const req = new Request('http://localhost/x', { headers: { Authorization: 'Bearer t' } }); + const result = await requireBearerAuth(req, { verifier }); - expect(mockVerifyAccessToken).not.toHaveBeenCalled(); - expect(mockResponse.status).toHaveBeenCalledWith(401); - expect(mockResponse.set).toHaveBeenCalledWith('WWW-Authenticate', expect.stringContaining('Bearer error="invalid_token"')); - expect(mockResponse.json).toHaveBeenCalledWith( - expect.objectContaining({ - error: 'invalid_token', - error_description: "Invalid Authorization header format, expected 'Bearer TOKEN'" - }) - ); - expect(nextFunction).not.toHaveBeenCalled(); + expect('authInfo' in result).toBe(true); + if ('authInfo' in result) { + expect(result.authInfo).toEqual(info); + } }); - it('should return 401 when token verification fails with InvalidTokenError', async () => { - mockRequest.headers = { - authorization: 'Bearer invalid-token' - }; - - mockVerifyAccessToken.mockRejectedValue(new InvalidTokenError('Token expired')); - - const middleware = requireBearerAuth({ verifier: mockVerifier }); - await middleware(mockRequest as Request, mockResponse as Response, nextFunction); - - expect(mockVerifyAccessToken).toHaveBeenCalledWith('invalid-token'); - expect(mockResponse.status).toHaveBeenCalledWith(401); - expect(mockResponse.set).toHaveBeenCalledWith('WWW-Authenticate', expect.stringContaining('Bearer error="invalid_token"')); - expect(mockResponse.json).toHaveBeenCalledWith( - expect.objectContaining({ error: 'invalid_token', error_description: 'Token expired' }) - ); - expect(nextFunction).not.toHaveBeenCalled(); - }); - - it('should return 403 when access token has insufficient scopes', async () => { - mockRequest.headers = { - authorization: 'Bearer valid-token' - }; + it('returns 401 when missing Authorization header', async () => { + const req = new Request('http://localhost/x'); + const result = await requireBearerAuth(req, { verifier }); - mockVerifyAccessToken.mockRejectedValue(new InsufficientScopeError('Required scopes: read, write')); - - const middleware = requireBearerAuth({ verifier: mockVerifier }); - await middleware(mockRequest as Request, mockResponse as Response, nextFunction); - - expect(mockVerifyAccessToken).toHaveBeenCalledWith('valid-token'); - expect(mockResponse.status).toHaveBeenCalledWith(403); - expect(mockResponse.set).toHaveBeenCalledWith('WWW-Authenticate', expect.stringContaining('Bearer error="insufficient_scope"')); - expect(mockResponse.json).toHaveBeenCalledWith( - expect.objectContaining({ error: 'insufficient_scope', error_description: 'Required scopes: read, write' }) - ); - expect(nextFunction).not.toHaveBeenCalled(); + expect('response' in result).toBe(true); + if ('response' in result) { + expect(result.response.status).toBe(401); + expect(result.response.headers.get('www-authenticate')).toContain('Bearer error="invalid_token"'); + expect(await result.response.json()).toEqual( + expect.objectContaining({ error: 'invalid_token', error_description: 'Missing Authorization header' }) + ); + } }); - it('should return 500 when a ServerError occurs', async () => { - mockRequest.headers = { - authorization: 'Bearer valid-token' - }; - - mockVerifyAccessToken.mockRejectedValue(new ServerError('Internal server issue')); - - const middleware = requireBearerAuth({ verifier: mockVerifier }); - await middleware(mockRequest as Request, mockResponse as Response, nextFunction); + it('returns 401 when verifier throws InvalidTokenError', async () => { + verifyAccessToken.mockRejectedValue(new InvalidTokenError('bad')); + const req = new Request('http://localhost/x', { headers: { Authorization: 'Bearer t' } }); + const result = await requireBearerAuth(req, { verifier }); - expect(mockVerifyAccessToken).toHaveBeenCalledWith('valid-token'); - expect(mockResponse.status).toHaveBeenCalledWith(500); - expect(mockResponse.json).toHaveBeenCalledWith( - expect.objectContaining({ error: 'server_error', error_description: 'Internal server issue' }) - ); - expect(nextFunction).not.toHaveBeenCalled(); + expect('response' in result).toBe(true); + if ('response' in result) { + expect(result.response.status).toBe(401); + } }); - it('should return 400 for generic OAuthError', async () => { - mockRequest.headers = { - authorization: 'Bearer valid-token' + it('returns 403 when scopes are insufficient', async () => { + const info: AuthInfo = { + token: 't', + clientId: 'c', + scopes: ['read'], + expiresAt: Math.floor(Date.now() / 1000) + 3600 }; + verifyAccessToken.mockResolvedValue(info); - mockVerifyAccessToken.mockRejectedValue(new CustomOAuthError('custom_error', 'Some OAuth error')); + const req = new Request('http://localhost/x', { headers: { Authorization: 'Bearer t' } }); + const result = await requireBearerAuth(req, { verifier, requiredScopes: ['read', 'write'] }); - const middleware = requireBearerAuth({ verifier: mockVerifier }); - await middleware(mockRequest as Request, mockResponse as Response, nextFunction); - - expect(mockVerifyAccessToken).toHaveBeenCalledWith('valid-token'); - expect(mockResponse.status).toHaveBeenCalledWith(400); - expect(mockResponse.json).toHaveBeenCalledWith( - expect.objectContaining({ error: 'custom_error', error_description: 'Some OAuth error' }) - ); - expect(nextFunction).not.toHaveBeenCalled(); + expect('response' in result).toBe(true); + if ('response' in result) { + expect(result.response.status).toBe(403); + expect(result.response.headers.get('www-authenticate')).toContain('Bearer error="insufficient_scope"'); + expect(await result.response.json()).toEqual( + expect.objectContaining({ error: 'insufficient_scope', error_description: 'Insufficient scope' }) + ); + } }); - it('should return 500 when unexpected error occurs', async () => { - mockRequest.headers = { - authorization: 'Bearer valid-token' - }; - - mockVerifyAccessToken.mockRejectedValue(new Error('Unexpected error')); + it('returns 500 when verifier throws ServerError', async () => { + verifyAccessToken.mockRejectedValue(new ServerError('boom')); + const req = new Request('http://localhost/x', { headers: { Authorization: 'Bearer t' } }); + const result = await requireBearerAuth(req, { verifier }); - const middleware = requireBearerAuth({ verifier: mockVerifier }); - await middleware(mockRequest as Request, mockResponse as Response, nextFunction); - - expect(mockVerifyAccessToken).toHaveBeenCalledWith('valid-token'); - expect(mockResponse.status).toHaveBeenCalledWith(500); - expect(mockResponse.json).toHaveBeenCalledWith( - expect.objectContaining({ error: 'server_error', error_description: 'Internal Server Error' }) - ); - expect(nextFunction).not.toHaveBeenCalled(); + expect('response' in result).toBe(true); + if ('response' in result) { + expect(result.response.status).toBe(500); + } }); - describe('with requiredScopes in WWW-Authenticate header', () => { - it('should include scope in WWW-Authenticate header for 401 responses when requiredScopes is provided', async () => { - mockRequest.headers = {}; - - const middleware = requireBearerAuth({ - verifier: mockVerifier, - requiredScopes: ['read', 'write'] - }); - await middleware(mockRequest as Request, mockResponse as Response, nextFunction); - - expect(mockResponse.status).toHaveBeenCalledWith(401); - expect(mockResponse.set).toHaveBeenCalledWith( - 'WWW-Authenticate', - 'Bearer error="invalid_token", error_description="Missing Authorization header", scope="read write"' - ); - expect(nextFunction).not.toHaveBeenCalled(); - }); - - it('should include scope in WWW-Authenticate header for 403 insufficient scope responses', async () => { - const authInfo: AuthInfo = { - token: 'valid-token', - clientId: 'client-123', - scopes: ['read'] - }; - mockVerifyAccessToken.mockResolvedValue(authInfo); - - mockRequest.headers = { - authorization: 'Bearer valid-token' - }; - - const middleware = requireBearerAuth({ - verifier: mockVerifier, - requiredScopes: ['read', 'write'] - }); - - await middleware(mockRequest as Request, mockResponse as Response, nextFunction); - - expect(mockResponse.status).toHaveBeenCalledWith(403); - expect(mockResponse.set).toHaveBeenCalledWith( - 'WWW-Authenticate', - 'Bearer error="insufficient_scope", error_description="Insufficient scope", scope="read write"' - ); - expect(nextFunction).not.toHaveBeenCalled(); + it('includes scope and resource_metadata in WWW-Authenticate when provided', async () => { + verifyAccessToken.mockRejectedValue(new InvalidTokenError('bad')); + const req = new Request('http://localhost/x', { headers: { Authorization: 'Bearer t' } }); + const result = await requireBearerAuth(req, { + verifier, + requiredScopes: ['read', 'write'], + resourceMetadataUrl: 'https://example.com/.well-known/oauth-protected-resource' }); - it('should include both scope and resource_metadata in WWW-Authenticate header when both are provided', async () => { - mockRequest.headers = {}; - - const resourceMetadataUrl = 'https://api.example.com/.well-known/oauth-protected-resource'; - const middleware = requireBearerAuth({ - verifier: mockVerifier, - requiredScopes: ['admin'], - resourceMetadataUrl - }); - await middleware(mockRequest as Request, mockResponse as Response, nextFunction); - - expect(mockResponse.status).toHaveBeenCalledWith(401); - expect(mockResponse.set).toHaveBeenCalledWith( - 'WWW-Authenticate', - `Bearer error="invalid_token", error_description="Missing Authorization header", scope="admin", resource_metadata="${resourceMetadataUrl}"` - ); - expect(nextFunction).not.toHaveBeenCalled(); - }); + expect('response' in result).toBe(true); + if ('response' in result) { + const header = result.response.headers.get('www-authenticate') ?? ''; + expect(header).toContain('scope="read write"'); + expect(header).toContain('resource_metadata="https://example.com/.well-known/oauth-protected-resource"'); + } }); - describe('with resourceMetadataUrl', () => { - const resourceMetadataUrl = 'https://api.example.com/.well-known/oauth-protected-resource'; - - it('should include resource_metadata in WWW-Authenticate header for 401 responses', async () => { - mockRequest.headers = {}; - - const middleware = requireBearerAuth({ verifier: mockVerifier, resourceMetadataUrl }); - await middleware(mockRequest as Request, mockResponse as Response, nextFunction); - - expect(mockResponse.status).toHaveBeenCalledWith(401); - expect(mockResponse.set).toHaveBeenCalledWith( - 'WWW-Authenticate', - `Bearer error="invalid_token", error_description="Missing Authorization header", resource_metadata="${resourceMetadataUrl}"` - ); - expect(nextFunction).not.toHaveBeenCalled(); - }); + it('passes through InsufficientScopeError from verifier as 403', async () => { + verifyAccessToken.mockRejectedValue(new InsufficientScopeError('nope')); + const req = new Request('http://localhost/x', { headers: { Authorization: 'Bearer t' } }); + const result = await requireBearerAuth(req, { verifier }); - it('should include resource_metadata in WWW-Authenticate header when token verification fails', async () => { - mockRequest.headers = { - authorization: 'Bearer invalid-token' - }; - - mockVerifyAccessToken.mockRejectedValue(new InvalidTokenError('Token expired')); - - const middleware = requireBearerAuth({ verifier: mockVerifier, resourceMetadataUrl }); - await middleware(mockRequest as Request, mockResponse as Response, nextFunction); - - expect(mockResponse.status).toHaveBeenCalledWith(401); - expect(mockResponse.set).toHaveBeenCalledWith( - 'WWW-Authenticate', - `Bearer error="invalid_token", error_description="Token expired", resource_metadata="${resourceMetadataUrl}"` - ); - expect(nextFunction).not.toHaveBeenCalled(); - }); - - it('should include resource_metadata in WWW-Authenticate header for insufficient scope errors', async () => { - mockRequest.headers = { - authorization: 'Bearer valid-token' - }; - - mockVerifyAccessToken.mockRejectedValue(new InsufficientScopeError('Required scopes: admin')); - - const middleware = requireBearerAuth({ verifier: mockVerifier, resourceMetadataUrl }); - await middleware(mockRequest as Request, mockResponse as Response, nextFunction); - - expect(mockResponse.status).toHaveBeenCalledWith(403); - expect(mockResponse.set).toHaveBeenCalledWith( - 'WWW-Authenticate', - `Bearer error="insufficient_scope", error_description="Required scopes: admin", resource_metadata="${resourceMetadataUrl}"` - ); - expect(nextFunction).not.toHaveBeenCalled(); - }); - - it('should include resource_metadata when token is expired', async () => { - const expiredAuthInfo: AuthInfo = { - token: 'expired-token', - clientId: 'client-123', - scopes: ['read', 'write'], - expiresAt: Math.floor(Date.now() / 1000) - 100 - }; - mockVerifyAccessToken.mockResolvedValue(expiredAuthInfo); - - mockRequest.headers = { - authorization: 'Bearer expired-token' - }; - - const middleware = requireBearerAuth({ verifier: mockVerifier, resourceMetadataUrl }); - await middleware(mockRequest as Request, mockResponse as Response, nextFunction); - - expect(mockResponse.status).toHaveBeenCalledWith(401); - expect(mockResponse.set).toHaveBeenCalledWith( - 'WWW-Authenticate', - `Bearer error="invalid_token", error_description="Token has expired", resource_metadata="${resourceMetadataUrl}"` - ); - expect(nextFunction).not.toHaveBeenCalled(); - }); - - it('should include resource_metadata when scope check fails', async () => { - const authInfo: AuthInfo = { - token: 'valid-token', - clientId: 'client-123', - scopes: ['read'] - }; - mockVerifyAccessToken.mockResolvedValue(authInfo); - - mockRequest.headers = { - authorization: 'Bearer valid-token' - }; - - const middleware = requireBearerAuth({ - verifier: mockVerifier, - requiredScopes: ['read', 'write'], - resourceMetadataUrl - }); - - await middleware(mockRequest as Request, mockResponse as Response, nextFunction); - - expect(mockResponse.status).toHaveBeenCalledWith(403); - expect(mockResponse.set).toHaveBeenCalledWith( - 'WWW-Authenticate', - `Bearer error="insufficient_scope", error_description="Insufficient scope", scope="read write", resource_metadata="${resourceMetadataUrl}"` - ); - expect(nextFunction).not.toHaveBeenCalled(); - }); - - it('should not affect server errors (no WWW-Authenticate header)', async () => { - mockRequest.headers = { - authorization: 'Bearer valid-token' - }; - - mockVerifyAccessToken.mockRejectedValue(new ServerError('Internal server issue')); - - const middleware = requireBearerAuth({ verifier: mockVerifier, resourceMetadataUrl }); - await middleware(mockRequest as Request, mockResponse as Response, nextFunction); - - expect(mockResponse.status).toHaveBeenCalledWith(500); - expect(mockResponse.set).not.toHaveBeenCalledWith('WWW-Authenticate', expect.anything()); - expect(nextFunction).not.toHaveBeenCalled(); - }); + expect('response' in result).toBe(true); + if ('response' in result) { + expect(result.response.status).toBe(403); + } }); }); diff --git a/packages/server/test/server/auth/middleware/clientAuth.test.ts b/packages/server/test/server/auth/middleware/clientAuth.test.ts index 55a00f0c2..0ee9aae0a 100644 --- a/packages/server/test/server/auth/middleware/clientAuth.test.ts +++ b/packages/server/test/server/auth/middleware/clientAuth.test.ts @@ -1,12 +1,11 @@ import type { OAuthClientInformationFull } from '@modelcontextprotocol/core'; -import express from 'express'; -import supertest from 'supertest'; +import { InvalidClientError, InvalidRequestError } from '@modelcontextprotocol/core'; import type { OAuthRegisteredClientsStore } from '../../../../src/server/auth/clients.js'; import type { ClientAuthenticationMiddlewareOptions } from '../../../../src/server/auth/middleware/clientAuth.js'; import { authenticateClient } from '../../../../src/server/auth/middleware/clientAuth.js'; -describe('clientAuth middleware', () => { +describe('authenticateClient', () => { // Mock client store const mockClientStore: OAuthRegisteredClientsStore = { async getClient(clientId: string): Promise { @@ -35,100 +34,92 @@ describe('clientAuth middleware', () => { } }; - // Setup Express app with middleware - let app: express.Express; let options: ClientAuthenticationMiddlewareOptions; beforeEach(() => { - app = express(); - app.use(express.json()); - options = { clientsStore: mockClientStore }; - - // Setup route with client auth - app.post('/protected', authenticateClient(options), (req, res) => { - res.status(200).json({ success: true, client: req.client }); - }); }); it('authenticates valid client credentials', async () => { - const response = await supertest(app).post('/protected').send({ - client_id: 'valid-client', - client_secret: 'valid-secret' - }); - - expect(response.status).toBe(200); - expect(response.body.success).toBe(true); - expect(response.body.client.client_id).toBe('valid-client'); + const client = await authenticateClient( + { + client_id: 'valid-client', + client_secret: 'valid-secret' + }, + options + ); + + expect(client.client_id).toBe('valid-client'); }); it('rejects invalid client_id', async () => { - const response = await supertest(app).post('/protected').send({ - client_id: 'non-existent-client', - client_secret: 'some-secret' - }); - - expect(response.status).toBe(400); - expect(response.body.error).toBe('invalid_client'); - expect(response.body.error_description).toBe('Invalid client_id'); + await expect( + authenticateClient( + { + client_id: 'non-existent-client', + client_secret: 'some-secret' + }, + options + ) + ).rejects.toBeInstanceOf(InvalidClientError); }); it('rejects invalid client_secret', async () => { - const response = await supertest(app).post('/protected').send({ - client_id: 'valid-client', - client_secret: 'wrong-secret' - }); - - expect(response.status).toBe(400); - expect(response.body.error).toBe('invalid_client'); - expect(response.body.error_description).toBe('Invalid client_secret'); + await expect( + authenticateClient( + { + client_id: 'valid-client', + client_secret: 'wrong-secret' + }, + options + ) + ).rejects.toBeInstanceOf(InvalidClientError); }); it('rejects missing client_id', async () => { - const response = await supertest(app).post('/protected').send({ - client_secret: 'valid-secret' - }); - - expect(response.status).toBe(400); - expect(response.body.error).toBe('invalid_request'); + await expect( + authenticateClient( + { + client_secret: 'valid-secret' + }, + options + ) + ).rejects.toBeInstanceOf(InvalidRequestError); }); it('allows missing client_secret if client has none', async () => { - const response = await supertest(app).post('/protected').send({ - client_id: 'expired-client' - }); - - // Since the client has no secret, this should pass without providing one - expect(response.status).toBe(200); + const client = await authenticateClient( + { + client_id: 'expired-client' + }, + options + ); + expect(client.client_id).toBe('expired-client'); }); it('rejects request when client secret has expired', async () => { - const response = await supertest(app).post('/protected').send({ - client_id: 'client-with-expired-secret', - client_secret: 'expired-secret' - }); - - expect(response.status).toBe(400); - expect(response.body.error).toBe('invalid_client'); - expect(response.body.error_description).toBe('Client secret has expired'); - }); - - it('handles malformed request body', async () => { - const response = await supertest(app).post('/protected').send('not-json-format'); - - expect(response.status).toBe(400); + await expect( + authenticateClient( + { + client_id: 'client-with-expired-secret', + client_secret: 'expired-secret' + }, + options + ) + ).rejects.toBeInstanceOf(InvalidClientError); }); - // Testing request with extra fields to ensure they're ignored it('ignores extra fields in request', async () => { - const response = await supertest(app).post('/protected').send({ - client_id: 'valid-client', - client_secret: 'valid-secret', - extra_field: 'should be ignored' - }); - - expect(response.status).toBe(200); + const client = await authenticateClient( + { + client_id: 'valid-client', + client_secret: 'valid-secret', + extra_field: 'ignored' + }, + options + ); + expect(client.client_id).toBe('valid-client'); }); }); diff --git a/packages/server/test/server/auth/providers/proxyProvider.test.ts b/packages/server/test/server/auth/providers/proxyProvider.test.ts index 375179e5b..143cfa78d 100644 --- a/packages/server/test/server/auth/providers/proxyProvider.test.ts +++ b/packages/server/test/server/auth/providers/proxyProvider.test.ts @@ -1,6 +1,5 @@ import type { AuthInfo, OAuthClientInformationFull, OAuthTokens } from '@modelcontextprotocol/core'; import { InsufficientScopeError, InvalidTokenError, ServerError } from '@modelcontextprotocol/core'; -import type { Response } from 'express'; import { type Mock } from 'vitest'; import type { ProxyOptions } from '../../../../src/server/auth/providers/proxyProvider.js'; @@ -14,11 +13,6 @@ describe('Proxy OAuth Server Provider', () => { redirect_uris: ['https://example.com/callback'] }; - // Mock response object - const mockResponse = { - redirect: vi.fn() - } as unknown as Response; - // Mock provider functions const mockVerifyToken = vi.fn(); const mockGetClient = vi.fn(); @@ -81,17 +75,13 @@ describe('Proxy OAuth Server Provider', () => { describe('authorization', () => { it('redirects to authorization endpoint with correct parameters', async () => { - await provider.authorize( - validClient, - { - redirectUri: 'https://example.com/callback', - codeChallenge: 'test-challenge', - state: 'test-state', - scopes: ['read', 'write'], - resource: new URL('https://api.example.com/resource') - }, - mockResponse - ); + const response = await provider.authorize(validClient, { + redirectUri: 'https://example.com/callback', + codeChallenge: 'test-challenge', + state: 'test-state', + scopes: ['read', 'write'], + resource: new URL('https://api.example.com/resource') + }); const expectedUrl = new URL('https://auth.example.com/authorize'); expectedUrl.searchParams.set('client_id', 'test-client'); @@ -103,7 +93,8 @@ describe('Proxy OAuth Server Provider', () => { expectedUrl.searchParams.set('scope', 'read write'); expectedUrl.searchParams.set('resource', 'https://api.example.com/resource'); - expect(mockResponse.redirect).toHaveBeenCalledWith(expectedUrl.toString()); + expect(response.status).toBe(302); + expect(response.headers.get('location')).toBe(expectedUrl.toString()); }); }); diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 4e127340c..ed80684a8 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -98,9 +98,6 @@ catalogs: express: specifier: ^5.0.1 version: 5.1.0 - express-rate-limit: - specifier: ^7.5.0 - version: 7.5.1 hono: specifier: ^4.11.1 version: 4.11.1 @@ -302,6 +299,12 @@ importers: '@modelcontextprotocol/server': specifier: workspace:^ version: link:../../packages/server + '@modelcontextprotocol/server-express': + specifier: workspace:^ + version: link:../../packages/server-express + '@modelcontextprotocol/server-hono': + specifier: workspace:^ + version: link:../../packages/server-hono cors: specifier: catalog:runtimeServerOnly version: 2.8.5 @@ -339,6 +342,9 @@ importers: '@modelcontextprotocol/server': specifier: workspace:^ version: link:../../packages/server + '@modelcontextprotocol/server-express': + specifier: workspace:^ + version: link:../../packages/server-express express: specifier: catalog:runtimeServerOnly version: 5.1.0 @@ -552,24 +558,9 @@ importers: packages/server: dependencies: - '@hono/node-server': - specifier: catalog:runtimeServerOnly - version: 1.19.7(hono@4.11.1) content-type: specifier: catalog:runtimeServerOnly version: 1.0.5 - cors: - specifier: catalog:runtimeServerOnly - version: 2.8.5 - express: - specifier: catalog:runtimeServerOnly - version: 5.1.0 - express-rate-limit: - specifier: catalog:runtimeServerOnly - version: 7.5.1(express@5.1.0) - hono: - specifier: catalog:runtimeServerOnly - version: 4.11.1 pkce-challenge: specifier: catalog:runtimeShared version: 5.0.0 @@ -659,6 +650,119 @@ importers: specifier: catalog:devTools version: 4.0.9(@types/node@24.10.3)(tsx@4.20.6) + packages/server-express: + dependencies: + '@modelcontextprotocol/core': + specifier: workspace:^ + version: link:../core + '@modelcontextprotocol/server': + specifier: workspace:^ + version: link:../server + express: + specifier: catalog:runtimeServerOnly + version: 5.1.0 + devDependencies: + '@eslint/js': + specifier: catalog:devTools + version: 9.39.1 + '@modelcontextprotocol/eslint-config': + specifier: workspace:^ + version: link:../../common/eslint-config + '@modelcontextprotocol/tsconfig': + specifier: workspace:^ + version: link:../../common/tsconfig + '@modelcontextprotocol/vitest-config': + specifier: workspace:^ + version: link:../../common/vitest-config + '@types/express': + specifier: catalog:devTools + version: 5.0.5 + '@types/express-serve-static-core': + specifier: catalog:devTools + version: 5.1.0 + '@types/supertest': + specifier: catalog:devTools + version: 6.0.3 + '@typescript/native-preview': + specifier: catalog:devTools + version: 7.0.0-dev.20251218.3 + eslint: + specifier: catalog:devTools + version: 9.39.1 + eslint-config-prettier: + specifier: catalog:devTools + version: 10.1.8(eslint@9.39.1) + eslint-plugin-n: + specifier: catalog:devTools + version: 17.23.1(eslint@9.39.1)(typescript@5.9.3) + prettier: + specifier: catalog:devTools + version: 3.6.2 + supertest: + specifier: catalog:devTools + version: 7.1.4 + tsdown: + specifier: catalog:devTools + version: 0.18.0(@typescript/native-preview@7.0.0-dev.20251218.3)(typescript@5.9.3) + typescript: + specifier: catalog:devTools + version: 5.9.3 + typescript-eslint: + specifier: catalog:devTools + version: 8.49.0(eslint@9.39.1)(typescript@5.9.3) + vitest: + specifier: catalog:devTools + version: 4.0.9(@types/node@24.10.3)(tsx@4.20.6) + + packages/server-hono: + dependencies: + '@modelcontextprotocol/server': + specifier: workspace:^ + version: link:../server + hono: + specifier: catalog:runtimeServerOnly + version: 4.11.1 + devDependencies: + '@eslint/js': + specifier: catalog:devTools + version: 9.39.1 + '@modelcontextprotocol/eslint-config': + specifier: workspace:^ + version: link:../../common/eslint-config + '@modelcontextprotocol/tsconfig': + specifier: workspace:^ + version: link:../../common/tsconfig + '@modelcontextprotocol/vitest-config': + specifier: workspace:^ + version: link:../../common/vitest-config + '@typescript/native-preview': + specifier: catalog:devTools + version: 7.0.0-dev.20251218.3 + eslint: + specifier: catalog:devTools + version: 9.39.1 + eslint-config-prettier: + specifier: catalog:devTools + version: 10.1.8(eslint@9.39.1) + eslint-plugin-n: + specifier: catalog:devTools + version: 17.23.1(eslint@9.39.1)(typescript@5.9.3) + prettier: + specifier: catalog:devTools + version: 3.6.2 + tsdown: + specifier: catalog:devTools + version: 0.18.0(@typescript/native-preview@7.0.0-dev.20251218.3)(typescript@5.9.3) + typescript: + specifier: catalog:devTools + version: 5.9.3 + typescript-eslint: + specifier: catalog:devTools + version: 8.49.0(eslint@9.39.1)(typescript@5.9.3) + vitest: + specifier: catalog:devTools + version: 4.0.9(@types/node@24.10.3)(tsx@4.20.6) + test/helpers: devDependencies: '@modelcontextprotocol/client': @@ -700,6 +804,9 @@ importers: '@modelcontextprotocol/server': specifier: workspace:^ version: link:../../packages/server + '@modelcontextprotocol/server-express': + specifier: workspace:^ + version: link:../../packages/server-express '@modelcontextprotocol/test-helpers': specifier: workspace:^ version: link:../helpers @@ -2089,12 +2196,6 @@ packages: resolution: {integrity: sha512-JhFGDVJ7tmDJItKhYgJCGLOWjuK9vPxiXoUFLwLDc99NlmklilbiQJwoctZtt13+xMw91MCk/REan6MWHqDjyA==} engines: {node: '>=12.0.0'} - express-rate-limit@7.5.1: - resolution: {integrity: sha512-7iN8iPMDzOMHPUYllBEsQdWVB6fPDMPqwjBaFrgr4Jgr/+okjvzAy+UHlYYL/Vs0OsOrMkwS6PJDkFlJwoxUnw==} - engines: {node: '>= 16'} - peerDependencies: - express: '>= 4.11' - express@5.1.0: resolution: {integrity: sha512-DT9ck5YIRU+8GYzzU5kT3eHGA5iL+1Zd0EutOmTE9Dtk+Tvuzd23VBU+ec7HPNSTxXYO55gPV/hq4pSBJDjFpA==} engines: {node: '>= 18'} @@ -2464,10 +2565,6 @@ packages: resolution: {integrity: sha512-PMSmkqxr106Xa156c2M265Z+FTrPl+oxd/rgOQy2tijQeK5TxQ43psO1ZCwhVOSdnn+RzkzlRz/eY4BgJBYVpg==} hasBin: true - js-yaml@4.1.0: - resolution: {integrity: sha512-wpxZs9NoxZaJESJGIZTyDEaYpl0FKSA+FB9aJiyemKhMwkxQg63h4T1KJgUGHpTqPDNRcmmYLugrRjJlBtWvRA==} - hasBin: true - js-yaml@4.1.1: resolution: {integrity: sha512-qQKT4zQxXl8lLwBtHMWwaTcGfFOZviOJet3Oy/xmGk2gZH677CJM9EvtfdSkgWcATZhj/55JZ0rmy3myCT5lsA==} hasBin: true @@ -3614,7 +3711,7 @@ snapshots: globals: 14.0.0 ignore: 5.3.2 import-fresh: 3.3.1 - js-yaml: 4.1.0 + js-yaml: 4.1.1 minimatch: 3.1.2 strip-json-comments: 3.1.1 transitivePeerDependencies: @@ -4762,10 +4859,6 @@ snapshots: expect-type@1.2.2: {} - express-rate-limit@7.5.1(express@5.1.0): - dependencies: - express: 5.1.0 - express@5.1.0: dependencies: accepts: 2.0.0 @@ -5172,10 +5265,6 @@ snapshots: argparse: 1.0.10 esprima: 4.0.1 - js-yaml@4.1.0: - dependencies: - argparse: 2.0.1 - js-yaml@4.1.1: dependencies: argparse: 2.0.1 diff --git a/test/integration/package.json b/test/integration/package.json index e709e431a..baa099bab 100644 --- a/test/integration/package.json +++ b/test/integration/package.json @@ -35,6 +35,7 @@ "@modelcontextprotocol/core": "workspace:^", "@modelcontextprotocol/client": "workspace:^", "@modelcontextprotocol/server": "workspace:^", + "@modelcontextprotocol/server-express": "workspace:^", "zod": "catalog:runtimeShared", "vitest": "catalog:devTools", "supertest": "catalog:devTools", diff --git a/test/integration/test/issues/test_1277_zod_v4_description.test.ts b/test/integration/test/issues/test_1277_zod_v4_description.test.ts index fe58cfcd5..75a61cb36 100644 --- a/test/integration/test/issues/test_1277_zod_v4_description.test.ts +++ b/test/integration/test/issues/test_1277_zod_v4_description.test.ts @@ -9,7 +9,8 @@ import { Client } from '@modelcontextprotocol/client'; import { InMemoryTransport, ListPromptsResultSchema } from '@modelcontextprotocol/core'; import { McpServer } from '@modelcontextprotocol/server'; -import { type ZodMatrixEntry, zodTestMatrix } from '@modelcontextprotocol/test-helpers'; +import type { ZodMatrixEntry } from '@modelcontextprotocol/test-helpers'; +import { zodTestMatrix } from '@modelcontextprotocol/test-helpers'; describe.each(zodTestMatrix)('Issue #1277: $zodVersionLabel', (entry: ZodMatrixEntry) => { const { z } = entry; diff --git a/test/integration/test/server.test.ts b/test/integration/test/server.test.ts index fcac6cc45..30a2c03c4 100644 --- a/test/integration/test/server.test.ts +++ b/test/integration/test/server.test.ts @@ -30,7 +30,9 @@ import { SetLevelRequestSchema, SUPPORTED_PROTOCOL_VERSIONS } from '@modelcontextprotocol/core'; -import { createMcpExpressApp, InMemoryTaskMessageQueue, InMemoryTaskStore, McpServer, Server } from '@modelcontextprotocol/server'; +import { InMemoryTaskStore, McpServer, Server } from '@modelcontextprotocol/server'; +import { createMcpExpressApp } from '@modelcontextprotocol/server-express'; +import type { Request, Response } from 'express'; import supertest from 'supertest'; import * as z3 from 'zod/v3'; import * as z4 from 'zod/v4'; @@ -2066,7 +2068,7 @@ describe('createMcpExpressApp', () => { test('should parse JSON bodies', async () => { const app = createMcpExpressApp({ host: '0.0.0.0' }); // Disable host validation for this test - app.post('/test', (req, res) => { + app.post('/test', (req: Request, res: Response) => { res.json({ received: req.body }); }); @@ -2078,7 +2080,7 @@ describe('createMcpExpressApp', () => { test('should reject requests with invalid Host header by default', async () => { const app = createMcpExpressApp(); - app.post('/test', (_req, res) => { + app.post('/test', (_req: Request, res: Response) => { res.json({ success: true }); }); @@ -2097,7 +2099,7 @@ describe('createMcpExpressApp', () => { test('should allow requests with localhost Host header', async () => { const app = createMcpExpressApp(); - app.post('/test', (_req, res) => { + app.post('/test', (_req: Request, res: Response) => { res.json({ success: true }); }); @@ -2109,7 +2111,7 @@ describe('createMcpExpressApp', () => { test('should allow requests with 127.0.0.1 Host header', async () => { const app = createMcpExpressApp(); - app.post('/test', (_req, res) => { + app.post('/test', (_req: Request, res: Response) => { res.json({ success: true }); }); @@ -2121,7 +2123,7 @@ describe('createMcpExpressApp', () => { test('should not apply host validation when host is 0.0.0.0', async () => { const app = createMcpExpressApp({ host: '0.0.0.0' }); - app.post('/test', (_req, res) => { + app.post('/test', (_req: Request, res: Response) => { res.json({ success: true }); }); @@ -2134,7 +2136,7 @@ describe('createMcpExpressApp', () => { test('should apply host validation when host is explicitly localhost', async () => { const app = createMcpExpressApp({ host: 'localhost' }); - app.post('/test', (_req, res) => { + app.post('/test', (_req: Request, res: Response) => { res.json({ success: true }); }); @@ -2146,7 +2148,7 @@ describe('createMcpExpressApp', () => { test('should allow requests with IPv6 localhost Host header', async () => { const app = createMcpExpressApp(); - app.post('/test', (_req, res) => { + app.post('/test', (_req: Request, res: Response) => { res.json({ success: true }); }); @@ -2158,7 +2160,7 @@ describe('createMcpExpressApp', () => { test('should apply host validation when host is ::1 (IPv6 localhost)', async () => { const app = createMcpExpressApp({ host: '::1' }); - app.post('/test', (_req, res) => { + app.post('/test', (_req: Request, res: Response) => { res.json({ success: true }); }); @@ -2185,7 +2187,7 @@ describe('createMcpExpressApp', () => { test('should use custom allowedHosts when provided', async () => { const warnSpy = vi.spyOn(console, 'warn').mockImplementation(() => {}); const app = createMcpExpressApp({ host: '0.0.0.0', allowedHosts: ['myapp.local', 'localhost'] }); - app.post('/test', (_req, res) => { + app.post('/test', (_req: Request, res: Response) => { res.json({ success: true }); }); @@ -2205,7 +2207,7 @@ describe('createMcpExpressApp', () => { test('should override default localhost validation when allowedHosts is provided', async () => { // Even though host is localhost, we're using custom allowedHosts const app = createMcpExpressApp({ host: 'localhost', allowedHosts: ['custom.local'] }); - app.post('/test', (_req, res) => { + app.post('/test', (_req: Request, res: Response) => { res.json({ success: true }); }); diff --git a/test/integration/test/server/mcp.test.ts b/test/integration/test/server/mcp.test.ts index f7bcececc..90e7152aa 100644 --- a/test/integration/test/server/mcp.test.ts +++ b/test/integration/test/server/mcp.test.ts @@ -1,26 +1,28 @@ import { Client } from '@modelcontextprotocol/client'; -import { getDisplayName, InMemoryTaskStore, InMemoryTransport, UriTemplate } from '@modelcontextprotocol/core'; +import type { CallToolResult, Notification, TextContent } from '@modelcontextprotocol/core'; import { - type CallToolResult, CallToolResultSchema, CompleteResultSchema, ElicitRequestSchema, ErrorCode, + getDisplayName, GetPromptResultSchema, + InMemoryTaskStore, + InMemoryTransport, ListPromptsResultSchema, ListResourcesResultSchema, ListResourceTemplatesResultSchema, ListToolsResultSchema, LoggingMessageNotificationSchema, - type Notification, ReadResourceResultSchema, - type TextContent, + UriTemplate, UrlElicitationRequiredError } from '@modelcontextprotocol/core'; import { completable } from '../../../../packages/server/src/server/completable.js'; import { McpServer, ResourceTemplate } from '../../../../packages/server/src/server/mcp.js'; -import { type ZodMatrixEntry, zodTestMatrix } from '../../../../packages/server/test/server/__fixtures__/zodTestMatrix.js'; +import type { ZodMatrixEntry } from '../../../../packages/server/test/server/__fixtures__/zodTestMatrix.js'; +import { zodTestMatrix } from '../../../../packages/server/test/server/__fixtures__/zodTestMatrix.js'; function createLatch() { let latch = false; diff --git a/test/integration/test/stateManagementStreamableHttp.test.ts b/test/integration/test/stateManagementStreamableHttp.test.ts index c33100efa..6839cba6b 100644 --- a/test/integration/test/stateManagementStreamableHttp.test.ts +++ b/test/integration/test/stateManagementStreamableHttp.test.ts @@ -1,5 +1,6 @@ import { randomUUID } from 'node:crypto'; -import { createServer, type Server } from 'node:http'; +import type { Server } from 'node:http'; +import { createServer } from 'node:http'; import { Client, StreamableHTTPClientTransport } from '@modelcontextprotocol/client'; import { @@ -11,8 +12,8 @@ import { McpServer, StreamableHTTPServerTransport } from '@modelcontextprotocol/server'; -import { listenOnRandomPort } from '@modelcontextprotocol/test-helpers'; -import { type ZodMatrixEntry, zodTestMatrix } from '@modelcontextprotocol/test-helpers'; +import type { ZodMatrixEntry } from '@modelcontextprotocol/test-helpers'; +import { listenOnRandomPort, zodTestMatrix } from '@modelcontextprotocol/test-helpers'; describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { const { z } = entry; diff --git a/test/integration/test/taskLifecycle.test.ts b/test/integration/test/taskLifecycle.test.ts index 216479e93..d644db48e 100644 --- a/test/integration/test/taskLifecycle.test.ts +++ b/test/integration/test/taskLifecycle.test.ts @@ -1,5 +1,6 @@ import { randomUUID } from 'node:crypto'; -import { createServer, type Server } from 'node:http'; +import type { Server } from 'node:http'; +import { createServer } from 'node:http'; import { Client, StreamableHTTPClientTransport } from '@modelcontextprotocol/client'; import type { TaskRequestOptions } from '@modelcontextprotocol/server'; diff --git a/test/integration/test/taskResumability.test.ts b/test/integration/test/taskResumability.test.ts index 8dfd3a65a..5947649e4 100644 --- a/test/integration/test/taskResumability.test.ts +++ b/test/integration/test/taskResumability.test.ts @@ -1,5 +1,6 @@ import { randomUUID } from 'node:crypto'; -import { createServer, type Server } from 'node:http'; +import type { Server } from 'node:http'; +import { createServer } from 'node:http'; import { Client, StreamableHTTPClientTransport } from '@modelcontextprotocol/client'; import { diff --git a/test/integration/test/title.test.ts b/test/integration/test/title.test.ts index 4eec82335..97348c117 100644 --- a/test/integration/test/title.test.ts +++ b/test/integration/test/title.test.ts @@ -1,7 +1,8 @@ import { Client } from '@modelcontextprotocol/client'; import { InMemoryTransport } from '@modelcontextprotocol/core'; import { McpServer, ResourceTemplate, Server } from '@modelcontextprotocol/server'; -import { type ZodMatrixEntry, zodTestMatrix } from '@modelcontextprotocol/test-helpers'; +import type { ZodMatrixEntry } from '@modelcontextprotocol/test-helpers'; +import { zodTestMatrix } from '@modelcontextprotocol/test-helpers'; describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { const { z } = entry; diff --git a/test/integration/tsconfig.json b/test/integration/tsconfig.json index f69a602fd..666fc0509 100644 --- a/test/integration/tsconfig.json +++ b/test/integration/tsconfig.json @@ -8,6 +8,7 @@ "@modelcontextprotocol/core": ["./node_modules/@modelcontextprotocol/core/src/index.ts"], "@modelcontextprotocol/client": ["./node_modules/@modelcontextprotocol/client/src/index.ts"], "@modelcontextprotocol/server": ["./node_modules/@modelcontextprotocol/server/src/index.ts"], + "@modelcontextprotocol/server-express": ["./node_modules/@modelcontextprotocol/server-express/src/index.ts"], "@modelcontextprotocol/vitest-config": ["./node_modules/@modelcontextprotocol/vitest-config/tsconfig.json"], "@modelcontextprotocol/test-helpers": ["./node_modules/@modelcontextprotocol/test-helpers/src/index.ts"] } From 354fb43bc376a6df543066c3c882a7ce776ec949 Mon Sep 17 00:00:00 2001 From: Konstantin Konstantinov Date: Sat, 20 Dec 2025 13:39:20 +0200 Subject: [PATCH 2/5] remove rate limiting from core server, move to express only --- packages/server-express/README.md | 4 +- packages/server-express/package.json | 3 +- packages/server-express/src/auth/router.ts | 50 ++++++++++++++++-- .../test/server/auth/router.test.ts | 30 +++++++++++ .../src/server/auth/handlers/authorize.ts | 37 ++----------- .../src/server/auth/handlers/register.ts | 52 +------------------ .../server/src/server/auth/handlers/revoke.ts | 51 ++---------------- .../server/src/server/auth/handlers/token.ts | 49 ++--------------- packages/server/src/server/auth/router.ts | 2 - packages/server/src/server/auth/web.ts | 48 ----------------- .../server/auth/handlers/authorize.test.ts | 4 +- .../server/auth/handlers/register.test.ts | 2 +- .../test/server/auth/handlers/revoke.test.ts | 2 +- .../test/server/auth/handlers/token.test.ts | 4 +- pnpm-lock.yaml | 26 ++++++++-- pnpm-workspace.yaml | 2 +- 16 files changed, 121 insertions(+), 245 deletions(-) diff --git a/packages/server-express/README.md b/packages/server-express/README.md index c1b10a9c7..7721cb16e 100644 --- a/packages/server-express/README.md +++ b/packages/server-express/README.md @@ -60,7 +60,9 @@ app.use(express.json()); app.use( mcpAuthRouter({ provider, - issuerUrl: new URL('https://auth.example.com') + issuerUrl: new URL('https://auth.example.com'), + // Optional rate limiting (implemented via express-rate-limit) + rateLimit: { windowMs: 60_000, max: 60 } }) ); ``` diff --git a/packages/server-express/package.json b/packages/server-express/package.json index d7d31cb1e..8979c37e3 100644 --- a/packages/server-express/package.json +++ b/packages/server-express/package.json @@ -43,7 +43,8 @@ }, "dependencies": { "@modelcontextprotocol/server": "workspace:^", - "express": "catalog:runtimeServerOnly" + "express": "catalog:runtimeServerOnly", + "express-rate-limit": "catalog:runtimeServerOnly" }, "devDependencies": { "@modelcontextprotocol/tsconfig": "workspace:^", diff --git a/packages/server-express/src/auth/router.ts b/packages/server-express/src/auth/router.ts index c50d2007d..b367dc46c 100644 --- a/packages/server-express/src/auth/router.ts +++ b/packages/server-express/src/auth/router.ts @@ -3,9 +3,14 @@ import { Readable } from 'node:stream'; import { URL } from 'node:url'; import type { AuthMetadataOptions, AuthRouterOptions, WebHandlerContext } from '@modelcontextprotocol/server'; -import { mcpAuthMetadataRouter as createWebAuthMetadataRouter, mcpAuthRouter as createWebAuthRouter } from '@modelcontextprotocol/server'; +import { + mcpAuthMetadataRouter as createWebAuthMetadataRouter, + mcpAuthRouter as createWebAuthRouter, + TooManyRequestsError +} from '@modelcontextprotocol/server'; import type { RequestHandler, Response as ExpressResponse } from 'express'; import express from 'express'; +import { rateLimit } from 'express-rate-limit'; type ExpressRequestLike = IncomingMessage & { method: string; @@ -112,11 +117,23 @@ async function writeWebResponse(res: ExpressResponse, webResponse: Response): Pr function toHandlerContext(req: ExpressRequestLike): WebHandlerContext { return { - parsedBody: req.body, - clientAddress: req.ip + parsedBody: req.body }; } +export type ExpressAuthRateLimitOptions = + | false + | { + /** + * Window size in ms (default: 60s) + */ + windowMs?: number; + /** + * Max requests per window per client (default: 60) + */ + max?: number; + }; + /** * Express router adapter for the Web-standard `mcpAuthRouter` from `@modelcontextprotocol/server`. * @@ -126,12 +143,34 @@ function toHandlerContext(req: ExpressRequestLike): WebHandlerContext { * app.use(mcpAuthRouter(...)) * ``` */ -export function mcpAuthRouter(options: AuthRouterOptions): RequestHandler { +export function mcpAuthRouter(options: AuthRouterOptions & { rateLimit?: ExpressAuthRateLimitOptions }): RequestHandler { const web = createWebAuthRouter(options); const router = express.Router(); + const rateLimitOptions = options.rateLimit; + const limiter = + rateLimitOptions === false + ? undefined + : rateLimit({ + windowMs: rateLimitOptions?.windowMs ?? 60_000, + max: rateLimitOptions?.max ?? 60, + standardHeaders: true, + legacyHeaders: false, + handler: (_req, res) => { + const err = new TooManyRequestsError('Too many requests'); + res.status(429).json(err.toResponseObject()); + } + }); + + const isRateLimitedPath = (path: string): boolean => + path === '/authorize' || path === '/token' || path === '/register' || path === '/revoke'; + for (const route of web.routes) { - router.all(route.path, async (req, res, next) => { + const handlers: RequestHandler[] = []; + if (limiter && isRateLimitedPath(route.path)) { + handlers.push(limiter); + } + handlers.push(async (req, res, next) => { try { const parsedBodyProvided = (req as ExpressRequestLike).body !== undefined; const webReq = await expressToWebRequest(req as ExpressRequestLike, parsedBodyProvided); @@ -141,6 +180,7 @@ export function mcpAuthRouter(options: AuthRouterOptions): RequestHandler { next(err); } }); + router.all(route.path, ...handlers); } return router; diff --git a/packages/server-express/test/server/auth/router.test.ts b/packages/server-express/test/server/auth/router.test.ts index 7a6c09690..9d7638543 100644 --- a/packages/server-express/test/server/auth/router.test.ts +++ b/packages/server-express/test/server/auth/router.test.ts @@ -325,6 +325,36 @@ describe('MCP Auth Router', () => { expect(response.status).not.toBe(404); }); + it('applies rate limiting to token endpoint (express-rate-limit)', async () => { + // Fresh app with a very low rate limit so we can trigger it deterministically + const limitedApp = express(); + const options = { + provider: mockProvider, + issuerUrl: new URL('https://auth.example.com'), + rateLimit: { windowMs: 60_000, max: 1 } + } as const; + limitedApp.use(mcpAuthRouter(options)); + + const first = await supertest(limitedApp).post('/token').type('form').send({ + client_id: 'valid-client', + client_secret: 'valid-secret', + grant_type: 'authorization_code', + code: 'valid_code', + code_verifier: 'valid_verifier' + }); + expect(first.status).not.toBe(404); + + const second = await supertest(limitedApp).post('/token').type('form').send({ + client_id: 'valid-client', + client_secret: 'valid-secret', + grant_type: 'authorization_code', + code: 'valid_code', + code_verifier: 'valid_verifier' + }); + expect(second.status).toBe(429); + expect(second.body).toEqual(expect.objectContaining({ error: 'too_many_requests' })); + }); + it('routes to registration endpoint', async () => { const response = await supertest(app) .post('/register') diff --git a/packages/server/src/server/auth/handlers/authorize.ts b/packages/server/src/server/auth/handlers/authorize.ts index df2702f88..ecffee114 100644 --- a/packages/server/src/server/auth/handlers/authorize.ts +++ b/packages/server/src/server/auth/handlers/authorize.ts @@ -1,17 +1,12 @@ -import { InvalidClientError, InvalidRequestError, OAuthError, ServerError, TooManyRequestsError } from '@modelcontextprotocol/core'; +import { InvalidClientError, InvalidRequestError, OAuthError, ServerError } from '@modelcontextprotocol/core'; import * as z from 'zod/v4'; import type { OAuthServerProvider } from '../provider.js'; import type { WebHandler } from '../web.js'; -import { getClientAddress, getParsedBody, InMemoryRateLimiter, jsonResponse, methodNotAllowedResponse, noStoreHeaders } from '../web.js'; +import { getParsedBody, jsonResponse, methodNotAllowedResponse, noStoreHeaders } from '../web.js'; export type AuthorizationHandlerOptions = { provider: OAuthServerProvider; - /** - * Rate limiting configuration for the authorization endpoint. - * Set to false to disable rate limiting for this endpoint. - */ - rateLimit?: Partial<{ windowMs: number; max: number }> | false; }; // Parameters that must be validated in order to issue redirects. @@ -33,36 +28,10 @@ const RequestAuthorizationParamsSchema = z.object({ resource: z.string().url().optional() }); -export function authorizationHandler({ provider, rateLimit: rateLimitConfig }: AuthorizationHandlerOptions): WebHandler { - const limiter = - rateLimitConfig === false - ? undefined - : new InMemoryRateLimiter({ - windowMs: rateLimitConfig?.windowMs ?? 15 * 60 * 1000, - max: rateLimitConfig?.max ?? 100 - }); - +export function authorizationHandler({ provider }: AuthorizationHandlerOptions): WebHandler { return async (req, ctx) => { const noStore = noStoreHeaders(); - // Rate limit by client address where possible (best-effort). - if (limiter) { - const key = `${getClientAddress(req, ctx) ?? 'global'}:authorize`; - const rl = limiter.consume(key); - if (!rl.allowed) { - return jsonResponse( - new TooManyRequestsError('You have exceeded the rate limit for authorization requests').toResponseObject(), - { - status: 429, - headers: { - ...noStore, - ...(rl.retryAfterSeconds ? { 'Retry-After': String(rl.retryAfterSeconds) } : {}) - } - } - ); - } - } - if (req.method !== 'GET' && req.method !== 'POST') { const resp = methodNotAllowedResponse(req, ['GET', 'POST']); const body = await resp.text(); diff --git a/packages/server/src/server/auth/handlers/register.ts b/packages/server/src/server/auth/handlers/register.ts index fa54644c3..4433a1b5b 100644 --- a/packages/server/src/server/auth/handlers/register.ts +++ b/packages/server/src/server/auth/handlers/register.ts @@ -1,26 +1,11 @@ import crypto from 'node:crypto'; import type { OAuthClientInformationFull } from '@modelcontextprotocol/core'; -import { - InvalidClientMetadataError, - OAuthClientMetadataSchema, - OAuthError, - ServerError, - TooManyRequestsError -} from '@modelcontextprotocol/core'; +import { InvalidClientMetadataError, OAuthClientMetadataSchema, OAuthError, ServerError } from '@modelcontextprotocol/core'; import type { OAuthRegisteredClientsStore } from '../clients.js'; import type { WebHandler } from '../web.js'; -import { - corsHeaders, - corsPreflightResponse, - getClientAddress, - getParsedBody, - InMemoryRateLimiter, - jsonResponse, - methodNotAllowedResponse, - noStoreHeaders -} from '../web.js'; +import { corsHeaders, corsPreflightResponse, getParsedBody, jsonResponse, methodNotAllowedResponse, noStoreHeaders } from '../web.js'; export type ClientRegistrationHandlerOptions = { /** @@ -35,13 +20,6 @@ export type ClientRegistrationHandlerOptions = { */ clientSecretExpirySeconds?: number; - /** - * Rate limiting configuration for the client registration endpoint. - * Set to false to disable rate limiting for this endpoint. - * Registration endpoints are particularly sensitive to abuse and should be rate limited. - */ - rateLimit?: Partial<{ windowMs: number; max: number }> | false; - /** * Whether to generate a client ID before calling the client registration endpoint. * @@ -55,21 +33,12 @@ const DEFAULT_CLIENT_SECRET_EXPIRY_SECONDS = 30 * 24 * 60 * 60; // 30 days export function clientRegistrationHandler({ clientsStore, clientSecretExpirySeconds = DEFAULT_CLIENT_SECRET_EXPIRY_SECONDS, - rateLimit: rateLimitConfig, clientIdGeneration = true }: ClientRegistrationHandlerOptions): WebHandler { if (!clientsStore.registerClient) { throw new Error('Client registration store does not support registering clients'); } - const limiter = - rateLimitConfig === false - ? undefined - : new InMemoryRateLimiter({ - windowMs: rateLimitConfig?.windowMs ?? 60 * 60 * 1000, - max: rateLimitConfig?.max ?? 20 - }); - const cors = { allowOrigin: '*', allowMethods: ['POST', 'OPTIONS'], @@ -92,23 +61,6 @@ export function clientRegistrationHandler({ }); } - if (limiter) { - const key = `${getClientAddress(req, ctx) ?? 'global'}:register`; - const rl = limiter.consume(key); - if (!rl.allowed) { - return jsonResponse( - new TooManyRequestsError('You have exceeded the rate limit for client registration requests').toResponseObject(), - { - status: 429, - headers: { - ...baseHeaders, - ...(rl.retryAfterSeconds ? { 'Retry-After': String(rl.retryAfterSeconds) } : {}) - } - } - ); - } - } - try { const rawBody = await getParsedBody(req, ctx); const parseResult = OAuthClientMetadataSchema.safeParse(rawBody); diff --git a/packages/server/src/server/auth/handlers/revoke.ts b/packages/server/src/server/auth/handlers/revoke.ts index 30c611cff..e4814345d 100644 --- a/packages/server/src/server/auth/handlers/revoke.ts +++ b/packages/server/src/server/auth/handlers/revoke.ts @@ -1,47 +1,19 @@ -import { - InvalidRequestError, - OAuthError, - OAuthTokenRevocationRequestSchema, - ServerError, - TooManyRequestsError -} from '@modelcontextprotocol/core'; +import { InvalidRequestError, OAuthError, OAuthTokenRevocationRequestSchema, ServerError } from '@modelcontextprotocol/core'; import { authenticateClient } from '../middleware/clientAuth.js'; import type { OAuthServerProvider } from '../provider.js'; import type { WebHandler } from '../web.js'; -import { - corsHeaders, - corsPreflightResponse, - getClientAddress, - getParsedBody, - InMemoryRateLimiter, - jsonResponse, - methodNotAllowedResponse, - noStoreHeaders -} from '../web.js'; +import { corsHeaders, corsPreflightResponse, getParsedBody, jsonResponse, methodNotAllowedResponse, noStoreHeaders } from '../web.js'; export type RevocationHandlerOptions = { provider: OAuthServerProvider; - /** - * Rate limiting configuration for the token revocation endpoint. - * Set to false to disable rate limiting for this endpoint. - */ - rateLimit?: Partial<{ windowMs: number; max: number }> | false; }; -export function revocationHandler({ provider, rateLimit: rateLimitConfig }: RevocationHandlerOptions): WebHandler { +export function revocationHandler({ provider }: RevocationHandlerOptions): WebHandler { if (!provider.revokeToken) { throw new Error('Auth provider does not support revoking tokens'); } - const limiter = - rateLimitConfig === false - ? undefined - : new InMemoryRateLimiter({ - windowMs: rateLimitConfig?.windowMs ?? 15 * 60 * 1000, - max: rateLimitConfig?.max ?? 50 - }); - const cors = { allowOrigin: '*', allowMethods: ['POST', 'OPTIONS'], @@ -64,23 +36,6 @@ export function revocationHandler({ provider, rateLimit: rateLimitConfig }: Revo }); } - if (limiter) { - const key = `${getClientAddress(req, ctx) ?? 'global'}:revoke`; - const rl = limiter.consume(key); - if (!rl.allowed) { - return jsonResponse( - new TooManyRequestsError('You have exceeded the rate limit for token revocation requests').toResponseObject(), - { - status: 429, - headers: { - ...baseHeaders, - ...(rl.retryAfterSeconds ? { 'Retry-After': String(rl.retryAfterSeconds) } : {}) - } - } - ); - } - } - try { const rawBody = await getParsedBody(req, ctx); const parseResult = OAuthTokenRevocationRequestSchema.safeParse(rawBody); diff --git a/packages/server/src/server/auth/handlers/token.ts b/packages/server/src/server/auth/handlers/token.ts index 096a10ff3..6dcdfd8b1 100644 --- a/packages/server/src/server/auth/handlers/token.ts +++ b/packages/server/src/server/auth/handlers/token.ts @@ -1,35 +1,14 @@ -import { - InvalidGrantError, - InvalidRequestError, - OAuthError, - ServerError, - TooManyRequestsError, - UnsupportedGrantTypeError -} from '@modelcontextprotocol/core'; +import { InvalidGrantError, InvalidRequestError, OAuthError, ServerError, UnsupportedGrantTypeError } from '@modelcontextprotocol/core'; import { verifyChallenge } from 'pkce-challenge'; import * as z from 'zod/v4'; import { authenticateClient } from '../middleware/clientAuth.js'; import type { OAuthServerProvider } from '../provider.js'; import type { WebHandler } from '../web.js'; -import { - corsHeaders, - corsPreflightResponse, - getClientAddress, - getParsedBody, - InMemoryRateLimiter, - jsonResponse, - methodNotAllowedResponse, - noStoreHeaders -} from '../web.js'; +import { corsHeaders, corsPreflightResponse, getParsedBody, jsonResponse, methodNotAllowedResponse, noStoreHeaders } from '../web.js'; export type TokenHandlerOptions = { provider: OAuthServerProvider; - /** - * Rate limiting configuration for the token endpoint. - * Set to false to disable rate limiting for this endpoint. - */ - rateLimit?: Partial<{ windowMs: number; max: number }> | false; }; const TokenRequestSchema = z.object({ @@ -49,15 +28,7 @@ const RefreshTokenGrantSchema = z.object({ resource: z.string().url().optional() }); -export function tokenHandler({ provider, rateLimit: rateLimitConfig }: TokenHandlerOptions): WebHandler { - const limiter = - rateLimitConfig === false - ? undefined - : new InMemoryRateLimiter({ - windowMs: rateLimitConfig?.windowMs ?? 15 * 60 * 1000, - max: rateLimitConfig?.max ?? 50 - }); - +export function tokenHandler({ provider }: TokenHandlerOptions): WebHandler { const cors = { allowOrigin: '*', allowMethods: ['POST', 'OPTIONS'], @@ -80,20 +51,6 @@ export function tokenHandler({ provider, rateLimit: rateLimitConfig }: TokenHand }); } - if (limiter) { - const key = `${getClientAddress(req, ctx) ?? 'global'}:token`; - const rl = limiter.consume(key); - if (!rl.allowed) { - return jsonResponse(new TooManyRequestsError('You have exceeded the rate limit for token requests').toResponseObject(), { - status: 429, - headers: { - ...baseHeaders, - ...(rl.retryAfterSeconds ? { 'Retry-After': String(rl.retryAfterSeconds) } : {}) - } - }); - } - } - try { const rawBody = await getParsedBody(req, ctx); const parseResult = TokenRequestSchema.safeParse(rawBody); diff --git a/packages/server/src/server/auth/router.ts b/packages/server/src/server/auth/router.ts index 083657250..61ed79806 100644 --- a/packages/server/src/server/auth/router.ts +++ b/packages/server/src/server/auth/router.ts @@ -141,8 +141,6 @@ export const createOAuthMetadata = (options: { * Installs standard MCP authorization server endpoints, including dynamic client registration and token revocation (if supported). * Also advertises standard authorization server metadata, for easier discovery of supported configurations by clients. * Note: if your MCP server is only a resource server and not an authorization server, use mcpAuthMetadataRouter instead. - * - * By default, rate limiting is applied to all endpoints to prevent abuse. */ export function mcpAuthRouter(options: AuthRouterOptions): WebAuthRouter { const oauthMetadata = createOAuthMetadata(options); diff --git a/packages/server/src/server/auth/web.ts b/packages/server/src/server/auth/web.ts index a16ada4ef..e461e9711 100644 --- a/packages/server/src/server/auth/web.ts +++ b/packages/server/src/server/auth/web.ts @@ -8,11 +8,6 @@ export type WebHandlerContext = { * If provided, handlers will use this instead of reading from the Request stream. */ parsedBody?: unknown; - - /** - * Optional client address for rate limiting (e.g., IP). - */ - clientAddress?: string; }; export type WebHandler = (req: Request, ctx?: WebHandlerContext) => Promise; @@ -32,13 +27,6 @@ export function noStoreHeaders(): HeaderMap { return { 'Cache-Control': 'no-store' }; } -export function getClientAddress(req: Request, ctx?: WebHandlerContext): string | undefined { - if (ctx?.clientAddress) return ctx.clientAddress; - const xff = req.headers.get('x-forwarded-for'); - if (xff) return xff.split(',')[0]?.trim(); - return undefined; -} - export async function getParsedBody(req: Request, ctx?: WebHandlerContext): Promise { if (ctx?.parsedBody !== undefined) { return ctx.parsedBody; @@ -102,39 +90,3 @@ export function corsPreflightResponse(options: CorsOptions): Response { headers: corsHeaders(options) }); } - -export type InMemoryRateLimitConfig = { - windowMs: number; - max: number; -}; - -type RateState = { windowStart: number; count: number }; - -/** - * Minimal in-memory rate limiter for single-process deployments. - * Not suitable for distributed setups without an external store. - */ -export class InMemoryRateLimiter { - private _state = new Map(); - - constructor(private _config: InMemoryRateLimitConfig) {} - - consume(key: string): { allowed: boolean; retryAfterSeconds?: number } { - const now = Date.now(); - const windowStart = now - (now % this._config.windowMs); - const existing = this._state.get(key); - - if (!existing || existing.windowStart !== windowStart) { - this._state.set(key, { windowStart, count: 1 }); - return { allowed: true }; - } - - if (existing.count >= this._config.max) { - const retryAfterMs = windowStart + this._config.windowMs - now; - return { allowed: false, retryAfterSeconds: Math.max(1, Math.ceil(retryAfterMs / 1000)) }; - } - - existing.count += 1; - return { allowed: true }; - } -} diff --git a/packages/server/test/server/auth/handlers/authorize.test.ts b/packages/server/test/server/auth/handlers/authorize.test.ts index e5e65d72c..c5943915c 100644 --- a/packages/server/test/server/auth/handlers/authorize.test.ts +++ b/packages/server/test/server/auth/handlers/authorize.test.ts @@ -76,7 +76,7 @@ describe('authorizationHandler (web)', () => { }); it('redirects with a code on valid request (single redirect_uri inferred)', async () => { - const handler = authorizationHandler({ provider, rateLimit: false }); + const handler = authorizationHandler({ provider }); const res = await handler( new Request( 'http://localhost/authorize?client_id=valid-client&response_type=code&code_challenge=challenge123&code_challenge_method=S256', @@ -91,7 +91,7 @@ describe('authorizationHandler (web)', () => { }); it('requires redirect_uri if client has multiple redirect URIs', async () => { - const handler = authorizationHandler({ provider, rateLimit: false }); + const handler = authorizationHandler({ provider }); const res = await handler( new Request( 'http://localhost/authorize?client_id=multi-redirect-client&response_type=code&code_challenge=challenge123&code_challenge_method=S256', diff --git a/packages/server/test/server/auth/handlers/register.test.ts b/packages/server/test/server/auth/handlers/register.test.ts index 1d3673c3f..6a2ffcd11 100644 --- a/packages/server/test/server/auth/handlers/register.test.ts +++ b/packages/server/test/server/auth/handlers/register.test.ts @@ -20,7 +20,7 @@ describe('clientRegistrationHandler (web)', () => { } }; - const handler = clientRegistrationHandler({ clientsStore, rateLimit: false }); + const handler = clientRegistrationHandler({ clientsStore }); const res = await handler( new Request('http://localhost/register', { diff --git a/packages/server/test/server/auth/handlers/revoke.test.ts b/packages/server/test/server/auth/handlers/revoke.test.ts index d9598bb43..d960f26c3 100644 --- a/packages/server/test/server/auth/handlers/revoke.test.ts +++ b/packages/server/test/server/auth/handlers/revoke.test.ts @@ -50,7 +50,7 @@ describe('revocationHandler (web)', () => { } }; - const handler = revocationHandler({ provider, rateLimit: false }); + const handler = revocationHandler({ provider }); const body = new URLSearchParams({ client_id: 'valid-client', diff --git a/packages/server/test/server/auth/handlers/token.test.ts b/packages/server/test/server/auth/handlers/token.test.ts index d0cb0ca24..d99e4b39a 100644 --- a/packages/server/test/server/auth/handlers/token.test.ts +++ b/packages/server/test/server/auth/handlers/token.test.ts @@ -64,7 +64,7 @@ describe('tokenHandler (web)', () => { it('returns tokens for authorization_code grant when PKCE passes', async () => { (pkceChallenge.verifyChallenge as unknown as ReturnType).mockResolvedValue(true); - const handler = tokenHandler({ provider, rateLimit: false }); + const handler = tokenHandler({ provider }); const body = new URLSearchParams({ client_id: 'valid-client', @@ -92,7 +92,7 @@ describe('tokenHandler (web)', () => { it('returns 400 when PKCE fails', async () => { (pkceChallenge.verifyChallenge as unknown as ReturnType).mockResolvedValue(false); - const handler = tokenHandler({ provider, rateLimit: false }); + const handler = tokenHandler({ provider }); const body = new URLSearchParams({ client_id: 'valid-client', diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index ed80684a8..5a91b23b4 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -98,6 +98,9 @@ catalogs: express: specifier: ^5.0.1 version: 5.1.0 + express-rate-limit: + specifier: ^8.2.1 + version: 8.2.1 hono: specifier: ^4.11.1 version: 4.11.1 @@ -652,15 +655,15 @@ importers: packages/server-express: dependencies: - '@modelcontextprotocol/core': - specifier: workspace:^ - version: link:../core '@modelcontextprotocol/server': specifier: workspace:^ version: link:../server express: specifier: catalog:runtimeServerOnly version: 5.1.0 + express-rate-limit: + specifier: catalog:runtimeServerOnly + version: 8.2.1(express@5.1.0) devDependencies: '@eslint/js': specifier: catalog:devTools @@ -2196,6 +2199,12 @@ packages: resolution: {integrity: sha512-JhFGDVJ7tmDJItKhYgJCGLOWjuK9vPxiXoUFLwLDc99NlmklilbiQJwoctZtt13+xMw91MCk/REan6MWHqDjyA==} engines: {node: '>=12.0.0'} + express-rate-limit@8.2.1: + resolution: {integrity: sha512-PCZEIEIxqwhzw4KF0n7QF4QqruVTcF73O5kFKUnGOyjbCCgizBBiFaYpd/fnBLUMPw/BWw9OsiN7GgrNYr7j6g==} + engines: {node: '>= 16'} + peerDependencies: + express: '>= 4.11' + express@5.1.0: resolution: {integrity: sha512-DT9ck5YIRU+8GYzzU5kT3eHGA5iL+1Zd0EutOmTE9Dtk+Tvuzd23VBU+ec7HPNSTxXYO55gPV/hq4pSBJDjFpA==} engines: {node: '>= 18'} @@ -2434,6 +2443,10 @@ packages: resolution: {integrity: sha512-4gd7VpWNQNB4UKKCFFVcp1AVv+FMOgs9NKzjHKusc8jTMhd5eL1NqQqOpE0KzMds804/yHlglp3uxgluOqAPLw==} engines: {node: '>= 0.4'} + ip-address@10.0.1: + resolution: {integrity: sha512-NWv9YLW4PoW2B7xtzaS3NCot75m6nK7Icdv0o3lfMceJVRfSoQwqD4wEH5rLwoKJwUiZ/rfpiVBhnaF0FK4HoA==} + engines: {node: '>= 12'} + ipaddr.js@1.9.1: resolution: {integrity: sha512-0KI/607xoxSToH7GjN1FfSbLoU0+btTicjsQSWQlh/hZykN8KpmMf7uYwPW3R+akZ6R/w18ZlXSHBYXiYUPO3g==} engines: {node: '>= 0.10'} @@ -4859,6 +4872,11 @@ snapshots: expect-type@1.2.2: {} + express-rate-limit@8.2.1(express@5.1.0): + dependencies: + express: 5.1.0 + ip-address: 10.0.1 + express@5.1.0: dependencies: accepts: 2.0.0 @@ -5132,6 +5150,8 @@ snapshots: hasown: 2.0.2 side-channel: 1.1.0 + ip-address@10.0.1: {} + ipaddr.js@1.9.1: {} is-array-buffer@3.0.5: diff --git a/pnpm-workspace.yaml b/pnpm-workspace.yaml index 12bae8326..55aac1aba 100644 --- a/pnpm-workspace.yaml +++ b/pnpm-workspace.yaml @@ -19,7 +19,7 @@ catalogs: content-type: ^1.0.5 cors: ^2.8.5 express: ^5.0.1 - express-rate-limit: ^7.5.0 + express-rate-limit: ^8.2.1 raw-body: ^3.0.0 runtimeClientOnly: jose: ^6.1.1 From 564aed24ede75fba2bbee0a121ae90344aed07e2 Mon Sep 17 00:00:00 2001 From: Konstantin Konstantinov Date: Sat, 20 Dec 2025 13:48:27 +0200 Subject: [PATCH 3/5] rename StreamableHttpServerTransport to NodeStreamableHttpServerTransport, add server-express, server-hono to pkg.pr.new --- .github/workflows/publish.yml | 2 +- CLAUDE.md | 2 +- examples/server/README.md | 2 +- examples/server/src/elicitationFormExample.ts | 8 +-- examples/server/src/elicitationUrlExample.ts | 8 +-- .../server/src/jsonResponseStreamableHttp.ts | 8 +-- .../src/simpleStatelessStreamableHttp.ts | 4 +- examples/server/src/simpleStreamableHttp.ts | 8 +-- examples/server/src/simpleTaskInteractive.ts | 8 +-- .../sseAndStreamableHttpCompatibleServer.ts | 14 ++--- examples/server/src/ssePollingExample.ts | 6 +-- .../src/standaloneSseWithGetStreamableHttp.ts | 8 +-- packages/core/src/shared/protocol.ts | 4 +- packages/core/src/types/types.ts | 4 +- packages/server-express/README.md | 4 +- packages/server-express/package.json | 1 + packages/server-hono/package.json | 1 + packages/server/src/server/sse.ts | 2 +- packages/server/src/server/streamableHttp.ts | 12 ++--- .../src/server/webStandardStreamableHttp.ts | 2 +- .../server/test/server/streamableHttp.test.ts | 54 +++++++++---------- .../stateManagementStreamableHttp.test.ts | 8 +-- test/integration/test/taskLifecycle.test.ts | 6 +-- .../integration/test/taskResumability.test.ts | 6 +-- 24 files changed, 92 insertions(+), 90 deletions(-) diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index 1167b176a..8e69fc8a7 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -38,4 +38,4 @@ jobs: run: pnpm run build:all - name: Publish preview packages - run: pnpm dlx pkg-pr-new publish --packageManager=npm --pnpm './packages/server' './packages/client' + run: pnpm dlx pkg-pr-new publish --packageManager=npm --pnpm './packages/server' './packages/client' './packages/server-express' './packages/server-hono' diff --git a/CLAUDE.md b/CLAUDE.md index 2a0b253a6..3caca17b6 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -224,7 +224,7 @@ mcpServer.tool('tool-name', { param: z.string() }, async ({ param }, extra) => { ```typescript // Server -const transport = new StreamableHTTPServerTransport({ sessionIdGenerator: () => randomUUID() }); +const transport = new NodeStreamableHTTPServerTransport({ sessionIdGenerator: () => randomUUID() }); await server.connect(transport); // Client diff --git a/examples/server/README.md b/examples/server/README.md index bb1216a04..1e7322b1a 100644 --- a/examples/server/README.md +++ b/examples/server/README.md @@ -71,7 +71,7 @@ When deploying MCP servers in a horizontally scaled environment (multiple server ### Stateless mode -To enable stateless mode, configure the `StreamableHTTPServerTransport` with: +To enable stateless mode, configure the `NodeStreamableHTTPServerTransport` with: ```typescript sessionIdGenerator: undefined; diff --git a/examples/server/src/elicitationFormExample.ts b/examples/server/src/elicitationFormExample.ts index 567975662..eaeb73c32 100644 --- a/examples/server/src/elicitationFormExample.ts +++ b/examples/server/src/elicitationFormExample.ts @@ -9,7 +9,7 @@ import { randomUUID } from 'node:crypto'; -import { isInitializeRequest, McpServer, StreamableHTTPServerTransport } from '@modelcontextprotocol/server'; +import { isInitializeRequest, McpServer, NodeStreamableHTTPServerTransport } from '@modelcontextprotocol/server'; import { createMcpExpressApp } from '@modelcontextprotocol/server-express'; import type { Request, Response } from 'express'; @@ -322,7 +322,7 @@ async function main() { const app = createMcpExpressApp(); // Map to store transports by session ID - const transports: { [sessionId: string]: StreamableHTTPServerTransport } = {}; + const transports: { [sessionId: string]: NodeStreamableHTTPServerTransport } = {}; // MCP POST endpoint const mcpPostHandler = async (req: Request, res: Response) => { @@ -332,13 +332,13 @@ async function main() { } try { - let transport: StreamableHTTPServerTransport; + let transport: NodeStreamableHTTPServerTransport; if (sessionId && transports[sessionId]) { // Reuse existing transport for this session transport = transports[sessionId]; } else if (!sessionId && isInitializeRequest(req.body)) { // New initialization request - create new transport - transport = new StreamableHTTPServerTransport({ + transport = new NodeStreamableHTTPServerTransport({ sessionIdGenerator: () => randomUUID(), onsessioninitialized: sessionId => { // Store the transport by session ID when session is initialized diff --git a/examples/server/src/elicitationUrlExample.ts b/examples/server/src/elicitationUrlExample.ts index 79ba49a17..51e1344b8 100644 --- a/examples/server/src/elicitationUrlExample.ts +++ b/examples/server/src/elicitationUrlExample.ts @@ -16,7 +16,7 @@ import { getOAuthProtectedResourceMetadataUrl, isInitializeRequest, McpServer, - StreamableHTTPServerTransport, + NodeStreamableHTTPServerTransport, UrlElicitationRequiredError } from '@modelcontextprotocol/server'; import { createMcpExpressApp, mcpAuthMetadataRouter, requireBearerAuth } from '@modelcontextprotocol/server-express'; @@ -592,7 +592,7 @@ app.post('/confirm-payment', express.urlencoded(), (req: Request, res: Response) }); // Map to store transports by session ID -const transports: { [sessionId: string]: StreamableHTTPServerTransport } = {}; +const transports: { [sessionId: string]: NodeStreamableHTTPServerTransport } = {}; // Interface for a function that can send an elicitation request type ElicitationSender = (params: ElicitRequestURLParams) => Promise; @@ -611,7 +611,7 @@ const mcpPostHandler = async (req: Request, res: Response) => { console.debug(`Received MCP POST for session: ${sessionId || 'unknown'}`); try { - let transport: StreamableHTTPServerTransport; + let transport: NodeStreamableHTTPServerTransport; if (sessionId && transports[sessionId]) { // Reuse existing transport transport = transports[sessionId]; @@ -619,7 +619,7 @@ const mcpPostHandler = async (req: Request, res: Response) => { const server = getServer(); // New initialization request const eventStore = new InMemoryEventStore(); - transport = new StreamableHTTPServerTransport({ + transport = new NodeStreamableHTTPServerTransport({ sessionIdGenerator: () => randomUUID(), eventStore, // Enable resumability onsessioninitialized: sessionId => { diff --git a/examples/server/src/jsonResponseStreamableHttp.ts b/examples/server/src/jsonResponseStreamableHttp.ts index 44155ea9d..5935ad2c2 100644 --- a/examples/server/src/jsonResponseStreamableHttp.ts +++ b/examples/server/src/jsonResponseStreamableHttp.ts @@ -1,7 +1,7 @@ import { randomUUID } from 'node:crypto'; import type { CallToolResult } from '@modelcontextprotocol/server'; -import { isInitializeRequest, McpServer, StreamableHTTPServerTransport } from '@modelcontextprotocol/server'; +import { isInitializeRequest, McpServer, NodeStreamableHTTPServerTransport } from '@modelcontextprotocol/server'; import { createMcpExpressApp } from '@modelcontextprotocol/server-express'; import type { Request, Response } from 'express'; import * as z from 'zod/v4'; @@ -97,21 +97,21 @@ const getServer = () => { const app = createMcpExpressApp(); // Map to store transports by session ID -const transports: { [sessionId: string]: StreamableHTTPServerTransport } = {}; +const transports: { [sessionId: string]: NodeStreamableHTTPServerTransport } = {}; app.post('/mcp', async (req: Request, res: Response) => { console.log('Received MCP request:', req.body); try { // Check for existing session ID const sessionId = req.headers['mcp-session-id'] as string | undefined; - let transport: StreamableHTTPServerTransport; + let transport: NodeStreamableHTTPServerTransport; if (sessionId && transports[sessionId]) { // Reuse existing transport transport = transports[sessionId]; } else if (!sessionId && isInitializeRequest(req.body)) { // New initialization request - use JSON response mode - transport = new StreamableHTTPServerTransport({ + transport = new NodeStreamableHTTPServerTransport({ sessionIdGenerator: () => randomUUID(), enableJsonResponse: true, // Enable JSON response mode onsessioninitialized: sessionId => { diff --git a/examples/server/src/simpleStatelessStreamableHttp.ts b/examples/server/src/simpleStatelessStreamableHttp.ts index 0f3a78e63..70389275c 100644 --- a/examples/server/src/simpleStatelessStreamableHttp.ts +++ b/examples/server/src/simpleStatelessStreamableHttp.ts @@ -1,5 +1,5 @@ import type { CallToolResult, GetPromptResult, ReadResourceResult } from '@modelcontextprotocol/server'; -import { McpServer, StreamableHTTPServerTransport } from '@modelcontextprotocol/server'; +import { McpServer, NodeStreamableHTTPServerTransport } from '@modelcontextprotocol/server'; import { createMcpExpressApp } from '@modelcontextprotocol/server-express'; import type { Request, Response } from 'express'; import * as z from 'zod/v4'; @@ -104,7 +104,7 @@ const app = createMcpExpressApp(); app.post('/mcp', async (req: Request, res: Response) => { const server = getServer(); try { - const transport: StreamableHTTPServerTransport = new StreamableHTTPServerTransport({ + const transport: NodeStreamableHTTPServerTransport = new NodeStreamableHTTPServerTransport({ sessionIdGenerator: undefined }); await server.connect(transport); diff --git a/examples/server/src/simpleStreamableHttp.ts b/examples/server/src/simpleStreamableHttp.ts index f550ed7d7..c1656a544 100644 --- a/examples/server/src/simpleStreamableHttp.ts +++ b/examples/server/src/simpleStreamableHttp.ts @@ -17,7 +17,7 @@ import { InMemoryTaskStore, isInitializeRequest, McpServer, - StreamableHTTPServerTransport + NodeStreamableHTTPServerTransport } from '@modelcontextprotocol/server'; import { createMcpExpressApp, mcpAuthMetadataRouter, requireBearerAuth } from '@modelcontextprotocol/server-express'; import type { Request, Response } from 'express'; @@ -588,7 +588,7 @@ if (useOAuth) { } // Map to store transports by session ID -const transports: { [sessionId: string]: StreamableHTTPServerTransport } = {}; +const transports: { [sessionId: string]: NodeStreamableHTTPServerTransport } = {}; // MCP POST endpoint with optional auth const mcpPostHandler = async (req: Request, res: Response) => { @@ -603,14 +603,14 @@ const mcpPostHandler = async (req: Request, res: Response) => { console.log('Authenticated user:', req.auth); } try { - let transport: StreamableHTTPServerTransport; + let transport: NodeStreamableHTTPServerTransport; if (sessionId && transports[sessionId]) { // Reuse existing transport transport = transports[sessionId]; } else if (!sessionId && isInitializeRequest(req.body)) { // New initialization request const eventStore = new InMemoryEventStore(); - transport = new StreamableHTTPServerTransport({ + transport = new NodeStreamableHTTPServerTransport({ sessionIdGenerator: () => randomUUID(), eventStore, // Enable resumability onsessioninitialized: sessionId => { diff --git a/examples/server/src/simpleTaskInteractive.ts b/examples/server/src/simpleTaskInteractive.ts index 4685f33f5..db4058054 100644 --- a/examples/server/src/simpleTaskInteractive.ts +++ b/examples/server/src/simpleTaskInteractive.ts @@ -42,7 +42,7 @@ import { ListToolsRequestSchema, RELATED_TASK_META_KEY, Server, - StreamableHTTPServerTransport + NodeStreamableHTTPServerTransport } from '@modelcontextprotocol/server'; import { createMcpExpressApp } from '@modelcontextprotocol/server-express'; import type { Request, Response } from 'express'; @@ -642,7 +642,7 @@ const createServer = (): Server => { const app = createMcpExpressApp(); // Map to store transports by session ID -const transports: { [sessionId: string]: StreamableHTTPServerTransport } = {}; +const transports: { [sessionId: string]: NodeStreamableHTTPServerTransport } = {}; // Helper to check if request is initialize const isInitializeRequest = (body: unknown): boolean => { @@ -654,12 +654,12 @@ app.post('/mcp', async (req: Request, res: Response) => { const sessionId = req.headers['mcp-session-id'] as string | undefined; try { - let transport: StreamableHTTPServerTransport; + let transport: NodeStreamableHTTPServerTransport; if (sessionId && transports[sessionId]) { transport = transports[sessionId]; } else if (!sessionId && isInitializeRequest(req.body)) { - transport = new StreamableHTTPServerTransport({ + transport = new NodeStreamableHTTPServerTransport({ sessionIdGenerator: () => randomUUID(), onsessioninitialized: sid => { console.log(`Session initialized: ${sid}`); diff --git a/examples/server/src/sseAndStreamableHttpCompatibleServer.ts b/examples/server/src/sseAndStreamableHttpCompatibleServer.ts index 3ea3b71db..d54a5287c 100644 --- a/examples/server/src/sseAndStreamableHttpCompatibleServer.ts +++ b/examples/server/src/sseAndStreamableHttpCompatibleServer.ts @@ -1,7 +1,7 @@ import { randomUUID } from 'node:crypto'; import type { CallToolResult } from '@modelcontextprotocol/server'; -import { isInitializeRequest, McpServer, SSEServerTransport, StreamableHTTPServerTransport } from '@modelcontextprotocol/server'; +import { isInitializeRequest, McpServer, SSEServerTransport, NodeStreamableHTTPServerTransport } from '@modelcontextprotocol/server'; import { createMcpExpressApp } from '@modelcontextprotocol/server-express'; import type { Request, Response } from 'express'; import * as z from 'zod/v4'; @@ -76,7 +76,7 @@ const getServer = () => { const app = createMcpExpressApp(); // Store transports by session ID -const transports: Record = {}; +const transports: Record = {}; //============================================================================= // STREAMABLE HTTP TRANSPORT (PROTOCOL VERSION 2025-11-25) @@ -89,16 +89,16 @@ app.all('/mcp', async (req: Request, res: Response) => { try { // Check for existing session ID const sessionId = req.headers['mcp-session-id'] as string | undefined; - let transport: StreamableHTTPServerTransport; + let transport: NodeStreamableHTTPServerTransport; if (sessionId && transports[sessionId]) { // Check if the transport is of the correct type const existingTransport = transports[sessionId]; - if (existingTransport instanceof StreamableHTTPServerTransport) { + if (existingTransport instanceof NodeStreamableHTTPServerTransport) { // Reuse existing transport transport = existingTransport; } else { - // Transport exists but is not a StreamableHTTPServerTransport (could be SSEServerTransport) + // Transport exists but is not a NodeStreamableHTTPServerTransport (could be SSEServerTransport) res.status(400).json({ jsonrpc: '2.0', error: { @@ -111,7 +111,7 @@ app.all('/mcp', async (req: Request, res: Response) => { } } else if (!sessionId && req.method === 'POST' && isInitializeRequest(req.body)) { const eventStore = new InMemoryEventStore(); - transport = new StreamableHTTPServerTransport({ + transport = new NodeStreamableHTTPServerTransport({ sessionIdGenerator: () => randomUUID(), eventStore, // Enable resumability onsessioninitialized: sessionId => { @@ -186,7 +186,7 @@ app.post('/messages', async (req: Request, res: Response) => { // Reuse existing transport transport = existingTransport; } else { - // Transport exists but is not a SSEServerTransport (could be StreamableHTTPServerTransport) + // Transport exists but is not a SSEServerTransport (could be NodeStreamableHTTPServerTransport) res.status(400).json({ jsonrpc: '2.0', error: { diff --git a/examples/server/src/ssePollingExample.ts b/examples/server/src/ssePollingExample.ts index e7da09ecb..4d0841dee 100644 --- a/examples/server/src/ssePollingExample.ts +++ b/examples/server/src/ssePollingExample.ts @@ -15,7 +15,7 @@ import { randomUUID } from 'node:crypto'; import type { CallToolResult } from '@modelcontextprotocol/server'; -import { McpServer, StreamableHTTPServerTransport } from '@modelcontextprotocol/server'; +import { McpServer, NodeStreamableHTTPServerTransport } from '@modelcontextprotocol/server'; import { createMcpExpressApp } from '@modelcontextprotocol/server-express'; import cors from 'cors'; import type { Request, Response } from 'express'; @@ -112,7 +112,7 @@ app.use(cors()); const eventStore = new InMemoryEventStore(); // Track transports by session ID for session reuse -const transports = new Map(); +const transports = new Map(); // Handle all MCP requests app.all('/mcp', async (req: Request, res: Response) => { @@ -122,7 +122,7 @@ app.all('/mcp', async (req: Request, res: Response) => { let transport = sessionId ? transports.get(sessionId) : undefined; if (!transport) { - transport = new StreamableHTTPServerTransport({ + transport = new NodeStreamableHTTPServerTransport({ sessionIdGenerator: () => randomUUID(), eventStore, retryInterval: 2000, // Default retry interval for priming events diff --git a/examples/server/src/standaloneSseWithGetStreamableHttp.ts b/examples/server/src/standaloneSseWithGetStreamableHttp.ts index f9fb426cd..869d7e859 100644 --- a/examples/server/src/standaloneSseWithGetStreamableHttp.ts +++ b/examples/server/src/standaloneSseWithGetStreamableHttp.ts @@ -1,7 +1,7 @@ import { randomUUID } from 'node:crypto'; import type { ReadResourceResult } from '@modelcontextprotocol/server'; -import { isInitializeRequest, McpServer, StreamableHTTPServerTransport } from '@modelcontextprotocol/server'; +import { isInitializeRequest, McpServer, NodeStreamableHTTPServerTransport } from '@modelcontextprotocol/server'; import { createMcpExpressApp } from '@modelcontextprotocol/server-express'; import type { Request, Response } from 'express'; @@ -12,7 +12,7 @@ const server = new McpServer({ }); // Store transports by session ID to send notifications -const transports: { [sessionId: string]: StreamableHTTPServerTransport } = {}; +const transports: { [sessionId: string]: NodeStreamableHTTPServerTransport } = {}; const addResource = (name: string, content: string) => { const uri = `https://mcp-example.com/dynamic/${encodeURIComponent(name)}`; @@ -42,14 +42,14 @@ app.post('/mcp', async (req: Request, res: Response) => { try { // Check for existing session ID const sessionId = req.headers['mcp-session-id'] as string | undefined; - let transport: StreamableHTTPServerTransport; + let transport: NodeStreamableHTTPServerTransport; if (sessionId && transports[sessionId]) { // Reuse existing transport transport = transports[sessionId]; } else if (!sessionId && isInitializeRequest(req.body)) { // New initialization request - transport = new StreamableHTTPServerTransport({ + transport = new NodeStreamableHTTPServerTransport({ sessionIdGenerator: () => randomUUID(), onsessioninitialized: sessionId => { // Store the transport by session ID when session is initialized diff --git a/packages/core/src/shared/protocol.ts b/packages/core/src/shared/protocol.ts index 9c65015d1..0a7e6aa1f 100644 --- a/packages/core/src/shared/protocol.ts +++ b/packages/core/src/shared/protocol.ts @@ -292,14 +292,14 @@ export type RequestHandlerExtra void; /** * Closes the standalone GET SSE stream, triggering client reconnection. - * Only available when using StreamableHTTPServerTransport with eventStore configured. + * Only available when using NodeStreamableHTTPServerTransport with eventStore configured. * Use this to implement polling behavior for server-initiated notifications. */ closeStandaloneSSEStream?: () => void; diff --git a/packages/core/src/types/types.ts b/packages/core/src/types/types.ts index 35b04745d..f3e1b92a8 100644 --- a/packages/core/src/types/types.ts +++ b/packages/core/src/types/types.ts @@ -2415,13 +2415,13 @@ export interface MessageExtraInfo { /** * Callback to close the SSE stream for this request, triggering client reconnection. - * Only available when using StreamableHTTPServerTransport with eventStore configured. + * Only available when using NodeStreamableHTTPServerTransport with eventStore configured. */ closeSSEStream?: () => void; /** * Callback to close the standalone GET SSE stream, triggering client reconnection. - * Only available when using StreamableHTTPServerTransport with eventStore configured. + * Only available when using NodeStreamableHTTPServerTransport with eventStore configured. */ closeStandaloneSSEStream?: () => void; } diff --git a/packages/server-express/README.md b/packages/server-express/README.md index 7721cb16e..27fb348d7 100644 --- a/packages/server-express/README.md +++ b/packages/server-express/README.md @@ -32,13 +32,13 @@ const app = createMcpExpressApp(); // default host is 127.0.0.1; protection enab ### Streamable HTTP endpoint (Express) ```ts -import { McpServer, StreamableHTTPServerTransport } from '@modelcontextprotocol/server'; +import { McpServer, NodeStreamableHTTPServerTransport } from '@modelcontextprotocol/server'; import { createMcpExpressApp } from '@modelcontextprotocol/server-express'; const app = createMcpExpressApp(); app.post('/mcp', async (req, res) => { - const transport = new StreamableHTTPServerTransport(); + const transport = new NodeStreamableHTTPServerTransport(); await transport.handleRequest(req, res, req.body); }); ``` diff --git a/packages/server-express/package.json b/packages/server-express/package.json index 8979c37e3..51fde8931 100644 --- a/packages/server-express/package.json +++ b/packages/server-express/package.json @@ -1,5 +1,6 @@ { "name": "@modelcontextprotocol/server-express", + "private": false, "version": "2.0.0-alpha.0", "description": "Express adapters for the Model Context Protocol TypeScript server SDK", "license": "MIT", diff --git a/packages/server-hono/package.json b/packages/server-hono/package.json index 33f633d40..ac5b01a89 100644 --- a/packages/server-hono/package.json +++ b/packages/server-hono/package.json @@ -1,5 +1,6 @@ { "name": "@modelcontextprotocol/server-hono", + "private": false, "version": "2.0.0-alpha.0", "description": "Hono adapters for the Model Context Protocol TypeScript server SDK", "license": "MIT", diff --git a/packages/server/src/server/sse.ts b/packages/server/src/server/sse.ts index 44117d0dd..06d418b2d 100644 --- a/packages/server/src/server/sse.ts +++ b/packages/server/src/server/sse.ts @@ -42,7 +42,7 @@ export interface SSEServerTransportOptions { * Server transport for SSE: this will send messages over an SSE connection and receive messages from HTTP POST requests. * * This transport is only available in Node.js environments. - * @deprecated SSEServerTransport is deprecated. Use StreamableHTTPServerTransport instead. + * @deprecated SSEServerTransport is deprecated. Use NodeStreamableHTTPServerTransport instead. */ export class SSEServerTransport implements Transport { private _sseResponse?: ServerResponse; diff --git a/packages/server/src/server/streamableHttp.ts b/packages/server/src/server/streamableHttp.ts index 354f640f9..65b39c52c 100644 --- a/packages/server/src/server/streamableHttp.ts +++ b/packages/server/src/server/streamableHttp.ts @@ -17,11 +17,11 @@ import type { WebStandardStreamableHTTPServerTransportOptions } from './webStand import { WebStandardStreamableHTTPServerTransport } from './webStandardStreamableHttp.js'; /** - * Configuration options for StreamableHTTPServerTransport + * Configuration options for NodeStreamableHTTPServerTransport * * This is an alias for WebStandardStreamableHTTPServerTransportOptions for backward compatibility. */ -export type StreamableHTTPServerTransportOptions = WebStandardStreamableHTTPServerTransportOptions; +export type NodeStreamableHTTPServerTransportOptions = WebStandardStreamableHTTPServerTransportOptions; type NodeToWebRequestOptions = { parsedBody?: unknown; @@ -139,12 +139,12 @@ function writeWebResponse(res: ServerResponse, webResponse: Response): Promise randomUUID(), * }); * * // Stateless mode - explicitly set session ID to undefined - * const statelessTransport = new StreamableHTTPServerTransport({ + * const statelessTransport = new NodeStreamableHTTPServerTransport({ * sessionIdGenerator: undefined, * }); * @@ -165,10 +165,10 @@ function writeWebResponse(res: ServerResponse, webResponse: Response): Promise string) | undefined; @@ -49,7 +49,7 @@ interface TestServerConfig { /** * Helper to stop test server */ -async function stopTestServer({ server, transport }: { server: Server; transport: StreamableHTTPServerTransport }): Promise { +async function stopTestServer({ server, transport }: { server: Server; transport: NodeStreamableHTTPServerTransport }): Promise { // First close the transport to ensure all SSE streams are closed await transport.close(); @@ -153,7 +153,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { */ async function createTestServer(config: TestServerConfig = { sessionIdGenerator: () => randomUUID() }): Promise<{ server: Server; - transport: StreamableHTTPServerTransport; + transport: NodeStreamableHTTPServerTransport; mcpServer: McpServer; baseUrl: URL; }> { @@ -168,7 +168,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { } ); - const transport = new StreamableHTTPServerTransport({ + const transport = new NodeStreamableHTTPServerTransport({ sessionIdGenerator: config.sessionIdGenerator, enableJsonResponse: config.enableJsonResponse ?? false, eventStore: config.eventStore, @@ -202,7 +202,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { */ async function createTestAuthServer(config: TestServerConfig = { sessionIdGenerator: () => randomUUID() }): Promise<{ server: Server; - transport: StreamableHTTPServerTransport; + transport: NodeStreamableHTTPServerTransport; mcpServer: McpServer; baseUrl: URL; }> { @@ -217,7 +217,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { } ); - const transport = new StreamableHTTPServerTransport({ + const transport = new NodeStreamableHTTPServerTransport({ sessionIdGenerator: config.sessionIdGenerator, enableJsonResponse: config.enableJsonResponse ?? false, eventStore: config.eventStore, @@ -247,10 +247,10 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { } const { z } = entry; - describe('StreamableHTTPServerTransport', () => { + describe('NodeStreamableHTTPServerTransport', () => { let server: Server; let mcpServer: McpServer; - let transport: StreamableHTTPServerTransport; + let transport: NodeStreamableHTTPServerTransport; let baseUrl: URL; let sessionId: string; @@ -979,9 +979,9 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { }); }); - describe('StreamableHTTPServerTransport with AuthInfo', () => { + describe('NodeStreamableHTTPServerTransport with AuthInfo', () => { let server: Server; - let transport: StreamableHTTPServerTransport; + let transport: NodeStreamableHTTPServerTransport; let baseUrl: URL; let sessionId: string; @@ -1079,9 +1079,9 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { }); // Test JSON Response Mode - describe('StreamableHTTPServerTransport with JSON Response Mode', () => { + describe('NodeStreamableHTTPServerTransport with JSON Response Mode', () => { let server: Server; - let transport: StreamableHTTPServerTransport; + let transport: NodeStreamableHTTPServerTransport; let baseUrl: URL; let sessionId: string; @@ -1166,9 +1166,9 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { }); // Test pre-parsed body handling - describe('StreamableHTTPServerTransport with pre-parsed body', () => { + describe('NodeStreamableHTTPServerTransport with pre-parsed body', () => { let server: Server; - let transport: StreamableHTTPServerTransport; + let transport: NodeStreamableHTTPServerTransport; let baseUrl: URL; let sessionId: string; let parsedBody: unknown = null; @@ -1302,9 +1302,9 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { }); // Test resumability support - describe('StreamableHTTPServerTransport with resumability', () => { + describe('NodeStreamableHTTPServerTransport with resumability', () => { let server: Server; - let transport: StreamableHTTPServerTransport; + let transport: NodeStreamableHTTPServerTransport; let baseUrl: URL; let sessionId: string; let mcpServer: McpServer; @@ -1538,9 +1538,9 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { }); // Test stateless mode - describe('StreamableHTTPServerTransport in stateless mode', () => { + describe('NodeStreamableHTTPServerTransport in stateless mode', () => { let server: Server; - let transport: StreamableHTTPServerTransport; + let transport: NodeStreamableHTTPServerTransport; let baseUrl: URL; beforeEach(async () => { @@ -1626,9 +1626,9 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { }); // Test SSE priming events for POST streams - describe('StreamableHTTPServerTransport POST SSE priming events', () => { + describe('NodeStreamableHTTPServerTransport POST SSE priming events', () => { let server: Server; - let transport: StreamableHTTPServerTransport; + let transport: NodeStreamableHTTPServerTransport; let baseUrl: URL; let sessionId: string; let mcpServer: McpServer; @@ -2327,7 +2327,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { }); // Test onsessionclosed callback - describe('StreamableHTTPServerTransport onsessionclosed callback', () => { + describe('NodeStreamableHTTPServerTransport onsessionclosed callback', () => { it('should call onsessionclosed callback when session is closed via DELETE', async () => { const mockCallback = vi.fn(); @@ -2486,7 +2486,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { }); // Test async callbacks for onsessioninitialized and onsessionclosed - describe('StreamableHTTPServerTransport async callbacks', () => { + describe('NodeStreamableHTTPServerTransport async callbacks', () => { it('should support async onsessioninitialized callback', async () => { const initializationOrder: string[] = []; @@ -2693,9 +2693,9 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { }); // Test DNS rebinding protection - describe('StreamableHTTPServerTransport DNS rebinding protection', () => { + describe('NodeStreamableHTTPServerTransport DNS rebinding protection', () => { let server: Server; - let transport: StreamableHTTPServerTransport; + let transport: NodeStreamableHTTPServerTransport; let baseUrl: URL; afterEach(async () => { @@ -2931,7 +2931,7 @@ async function createTestServerWithDnsProtection(config: { enableDnsRebindingProtection?: boolean; }): Promise<{ server: Server; - transport: StreamableHTTPServerTransport; + transport: NodeStreamableHTTPServerTransport; mcpServer: McpServer; baseUrl: URL; }> { @@ -2948,7 +2948,7 @@ async function createTestServerWithDnsProtection(config: { }); } - const transport = new StreamableHTTPServerTransport({ + const transport = new NodeStreamableHTTPServerTransport({ sessionIdGenerator: config.sessionIdGenerator, allowedHosts: config.allowedHosts, allowedOrigins: config.allowedOrigins, diff --git a/test/integration/test/stateManagementStreamableHttp.test.ts b/test/integration/test/stateManagementStreamableHttp.test.ts index 6839cba6b..72180b688 100644 --- a/test/integration/test/stateManagementStreamableHttp.test.ts +++ b/test/integration/test/stateManagementStreamableHttp.test.ts @@ -10,7 +10,7 @@ import { ListResourcesResultSchema, ListToolsResultSchema, McpServer, - StreamableHTTPServerTransport + NodeStreamableHTTPServerTransport } from '@modelcontextprotocol/server'; import type { ZodMatrixEntry } from '@modelcontextprotocol/test-helpers'; import { listenOnRandomPort, zodTestMatrix } from '@modelcontextprotocol/test-helpers'; @@ -69,7 +69,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { ); // Create transport with or without session management - const serverTransport = new StreamableHTTPServerTransport({ + const serverTransport = new NodeStreamableHTTPServerTransport({ sessionIdGenerator: withSessionManagement ? () => randomUUID() // With session management, generate UUID : undefined // Without session management, return undefined @@ -90,7 +90,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { describe('Stateless Mode', () => { let server: Server; let mcpServer: McpServer; - let serverTransport: StreamableHTTPServerTransport; + let serverTransport: NodeStreamableHTTPServerTransport; let baseUrl: URL; beforeEach(async () => { @@ -254,7 +254,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { describe('Stateful Mode', () => { let server: Server; let mcpServer: McpServer; - let serverTransport: StreamableHTTPServerTransport; + let serverTransport: NodeStreamableHTTPServerTransport; let baseUrl: URL; beforeEach(async () => { diff --git a/test/integration/test/taskLifecycle.test.ts b/test/integration/test/taskLifecycle.test.ts index d644db48e..5e3dfc408 100644 --- a/test/integration/test/taskLifecycle.test.ts +++ b/test/integration/test/taskLifecycle.test.ts @@ -15,7 +15,7 @@ import { McpError, McpServer, RELATED_TASK_META_KEY, - StreamableHTTPServerTransport, + NodeStreamableHTTPServerTransport, TaskSchema } from '@modelcontextprotocol/server'; import { listenOnRandomPort, waitForTaskStatus } from '@modelcontextprotocol/test-helpers'; @@ -24,7 +24,7 @@ import { z } from 'zod'; describe('Task Lifecycle Integration Tests', () => { let server: Server; let mcpServer: McpServer; - let serverTransport: StreamableHTTPServerTransport; + let serverTransport: NodeStreamableHTTPServerTransport; let baseUrl: URL; let taskStore: InMemoryTaskStore; @@ -189,7 +189,7 @@ describe('Task Lifecycle Integration Tests', () => { ); // Create transport - serverTransport = new StreamableHTTPServerTransport({ + serverTransport = new NodeStreamableHTTPServerTransport({ sessionIdGenerator: () => randomUUID() }); diff --git a/test/integration/test/taskResumability.test.ts b/test/integration/test/taskResumability.test.ts index 5947649e4..7a1b15707 100644 --- a/test/integration/test/taskResumability.test.ts +++ b/test/integration/test/taskResumability.test.ts @@ -8,7 +8,7 @@ import { InMemoryEventStore, LoggingMessageNotificationSchema, McpServer, - StreamableHTTPServerTransport + NodeStreamableHTTPServerTransport } from '@modelcontextprotocol/server'; import type { ZodMatrixEntry } from '@modelcontextprotocol/test-helpers'; import { listenOnRandomPort, zodTestMatrix } from '@modelcontextprotocol/test-helpers'; @@ -18,7 +18,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { describe('Transport resumability', () => { let server: Server; let mcpServer: McpServer; - let serverTransport: StreamableHTTPServerTransport; + let serverTransport: NodeStreamableHTTPServerTransport; let baseUrl: URL; let eventStore: InMemoryEventStore; @@ -84,7 +84,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { ); // Create a transport with the event store - serverTransport = new StreamableHTTPServerTransport({ + serverTransport = new NodeStreamableHTTPServerTransport({ sessionIdGenerator: () => randomUUID(), eventStore }); From 223fcb067b2d85adc80c91cf6277c760f8a59756 Mon Sep 17 00:00:00 2001 From: Konstantin Konstantinov Date: Sat, 20 Dec 2025 15:19:26 +0200 Subject: [PATCH 4/5] move back to hono/node-server for mapping incoming node request to web request --- .github/workflows/publish.yml | 4 +- README.md | 3 +- examples/server/src/simpleTaskInteractive.ts | 4 +- .../sseAndStreamableHttpCompatibleServer.ts | 2 +- packages/core/src/shared/protocol.ts | 4 +- packages/server-express/package.json | 3 +- .../server-express/src/auth/bearerAuth.ts | 14 +- packages/server-express/src/auth/router.ts | 135 ++-------------- packages/server/package.json | 11 +- packages/server/src/server/streamableHttp.ts | 152 ++++-------------- .../server/test/server/streamableHttp.test.ts | 3 +- pnpm-lock.yaml | 14 ++ pnpm-workspace.yaml | 1 + test/integration/test/taskLifecycle.test.ts | 2 +- .../integration/test/taskResumability.test.ts | 2 +- 15 files changed, 85 insertions(+), 269 deletions(-) diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index 8e69fc8a7..25cd12862 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -38,4 +38,6 @@ jobs: run: pnpm run build:all - name: Publish preview packages - run: pnpm dlx pkg-pr-new publish --packageManager=npm --pnpm './packages/server' './packages/client' './packages/server-express' './packages/server-hono' + run: + pnpm dlx pkg-pr-new publish --packageManager=npm --pnpm './packages/server' './packages/client' + './packages/server-express' './packages/server-hono' diff --git a/README.md b/README.md index dc0116c96..4d5270287 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,6 @@ # MCP TypeScript SDK -> [!IMPORTANT] -> **This is the `main` branch which contains v2 of the SDK (currently in development, pre-alpha).** +> [!IMPORTANT] **This is the `main` branch which contains v2 of the SDK (currently in development, pre-alpha).** > > We anticipate a stable v2 release in Q1 2026. Until then, **v1.x remains the recommended version** for production use. v1.x will continue to receive bug fixes and security updates for at least 6 months after v2 ships to give people time to upgrade. > diff --git a/examples/server/src/simpleTaskInteractive.ts b/examples/server/src/simpleTaskInteractive.ts index db4058054..469ecf0c2 100644 --- a/examples/server/src/simpleTaskInteractive.ts +++ b/examples/server/src/simpleTaskInteractive.ts @@ -40,9 +40,9 @@ import { InMemoryTaskStore, isTerminal, ListToolsRequestSchema, + NodeStreamableHTTPServerTransport, RELATED_TASK_META_KEY, - Server, - NodeStreamableHTTPServerTransport + Server } from '@modelcontextprotocol/server'; import { createMcpExpressApp } from '@modelcontextprotocol/server-express'; import type { Request, Response } from 'express'; diff --git a/examples/server/src/sseAndStreamableHttpCompatibleServer.ts b/examples/server/src/sseAndStreamableHttpCompatibleServer.ts index d54a5287c..bb2636ea3 100644 --- a/examples/server/src/sseAndStreamableHttpCompatibleServer.ts +++ b/examples/server/src/sseAndStreamableHttpCompatibleServer.ts @@ -1,7 +1,7 @@ import { randomUUID } from 'node:crypto'; import type { CallToolResult } from '@modelcontextprotocol/server'; -import { isInitializeRequest, McpServer, SSEServerTransport, NodeStreamableHTTPServerTransport } from '@modelcontextprotocol/server'; +import { isInitializeRequest, McpServer, NodeStreamableHTTPServerTransport, SSEServerTransport } from '@modelcontextprotocol/server'; import { createMcpExpressApp } from '@modelcontextprotocol/server-express'; import type { Request, Response } from 'express'; import * as z from 'zod/v4'; diff --git a/packages/core/src/shared/protocol.ts b/packages/core/src/shared/protocol.ts index 0a7e6aa1f..c9242e96d 100644 --- a/packages/core/src/shared/protocol.ts +++ b/packages/core/src/shared/protocol.ts @@ -292,14 +292,14 @@ export type RequestHandlerExtra void; /** * Closes the standalone GET SSE stream, triggering client reconnection. - * Only available when using NodeStreamableHTTPServerTransport with eventStore configured. + * Only available when using aStreamableHTTPServerTransport with eventStore configured. * Use this to implement polling behavior for server-initiated notifications. */ closeStandaloneSSEStream?: () => void; diff --git a/packages/server-express/package.json b/packages/server-express/package.json index 51fde8931..bca9ac505 100644 --- a/packages/server-express/package.json +++ b/packages/server-express/package.json @@ -45,7 +45,8 @@ "dependencies": { "@modelcontextprotocol/server": "workspace:^", "express": "catalog:runtimeServerOnly", - "express-rate-limit": "catalog:runtimeServerOnly" + "express-rate-limit": "catalog:runtimeServerOnly", + "@remix-run/node-fetch-server": "catalog:runtimeServerOnly" }, "devDependencies": { "@modelcontextprotocol/tsconfig": "workspace:^", diff --git a/packages/server-express/src/auth/bearerAuth.ts b/packages/server-express/src/auth/bearerAuth.ts index a923ff796..d8d0aad8b 100644 --- a/packages/server-express/src/auth/bearerAuth.ts +++ b/packages/server-express/src/auth/bearerAuth.ts @@ -3,7 +3,8 @@ import { URL } from 'node:url'; import type { AuthInfo } from '@modelcontextprotocol/core'; import type { BearerAuthMiddlewareOptions } from '@modelcontextprotocol/server'; import { requireBearerAuth as requireBearerAuthWeb } from '@modelcontextprotocol/server'; -import type { NextFunction, Request as ExpressRequest, RequestHandler, Response as ExpressResponse } from 'express'; +import { sendResponse } from '@remix-run/node-fetch-server'; +import type { NextFunction, Request as ExpressRequest, RequestHandler } from 'express'; declare module 'express-serve-static-core' { interface Request { @@ -21,15 +22,6 @@ function expressRequestUrl(req: ExpressRequest): URL { return new URL(path, `${protocol}://${host}`); } -async function writeWebResponse(res: ExpressResponse, webResponse: Response): Promise { - res.status(webResponse.status); - for (const [k, v] of webResponse.headers.entries()) { - res.setHeader(k, v); - } - const bodyText = await webResponse.text(); - res.send(bodyText); -} - /** * Express middleware wrapper for the Web-standard `requireBearerAuth` helper. * @@ -54,7 +46,7 @@ export function requireBearerAuth(options: BearerAuthMiddlewareOptions): Request return; } - await writeWebResponse(res, result.response); + await sendResponse(res, result.response); } catch (err) { next(err); } diff --git a/packages/server-express/src/auth/router.ts b/packages/server-express/src/auth/router.ts index b367dc46c..868149efd 100644 --- a/packages/server-express/src/auth/router.ts +++ b/packages/server-express/src/auth/router.ts @@ -1,126 +1,15 @@ -import type { IncomingMessage } from 'node:http'; -import { Readable } from 'node:stream'; -import { URL } from 'node:url'; - -import type { AuthMetadataOptions, AuthRouterOptions, WebHandlerContext } from '@modelcontextprotocol/server'; +import type { AuthMetadataOptions, AuthRouterOptions } from '@modelcontextprotocol/server'; import { + getParsedBody, mcpAuthMetadataRouter as createWebAuthMetadataRouter, mcpAuthRouter as createWebAuthRouter, TooManyRequestsError } from '@modelcontextprotocol/server'; -import type { RequestHandler, Response as ExpressResponse } from 'express'; +import { createRequest, sendResponse } from '@remix-run/node-fetch-server'; +import type { RequestHandler } from 'express'; import express from 'express'; import { rateLimit } from 'express-rate-limit'; -type ExpressRequestLike = IncomingMessage & { - method: string; - headers: Record; - originalUrl?: string; - url?: string; - protocol?: string; - // express adds this when trust proxy is enabled - ip?: string; - body?: unknown; - get?: (name: string) => string | undefined; -}; - -function expressRequestUrl(req: ExpressRequestLike): URL { - const host = req.get?.('host') ?? req.headers.host ?? 'localhost'; - const proto = req.protocol ?? 'http'; - const path = req.originalUrl ?? req.url ?? '/'; - return new URL(path, `${proto}://${host}`); -} - -function toHeaders(req: ExpressRequestLike): Headers { - const headers = new Headers(); - for (const [key, value] of Object.entries(req.headers)) { - if (value === undefined) continue; - if (Array.isArray(value)) { - headers.set(key, value.join(', ')); - } else { - headers.set(key, value); - } - } - return headers; -} - -async function readBody(req: IncomingMessage): Promise { - const chunks: Buffer[] = []; - for await (const chunk of req) { - chunks.push(Buffer.isBuffer(chunk) ? chunk : Buffer.from(chunk)); - } - return Buffer.concat(chunks); -} - -async function expressToWebRequest(req: ExpressRequestLike, parsedBodyProvided: boolean): Promise { - const url = expressRequestUrl(req); - const headers = toHeaders(req); - - // If upstream body parsing ran, the Node stream is likely consumed. - if (parsedBodyProvided) { - return new Request(url, { method: req.method, headers }); - } - - if (req.method === 'GET' || req.method === 'HEAD') { - return new Request(url, { method: req.method, headers }); - } - - const body = await readBody(req); - return new Request(url, { method: req.method, headers, body }); -} - -async function writeWebResponse(res: ExpressResponse, webResponse: Response): Promise { - res.status(webResponse.status); - - // eslint-disable-next-line @typescript-eslint/no-explicit-any - const getSetCookie = (webResponse.headers as any).getSetCookie as (() => string[]) | undefined; - const setCookies = typeof getSetCookie === 'function' ? getSetCookie.call(webResponse.headers) : undefined; - - for (const [key, value] of webResponse.headers.entries()) { - if (key.toLowerCase() === 'set-cookie' && setCookies?.length) continue; - res.setHeader(key, value); - } - - if (setCookies?.length) { - res.setHeader('set-cookie', setCookies); - } - - res.flushHeaders?.(); - - if (!webResponse.body) { - res.end(); - return; - } - - await new Promise((resolve, reject) => { - const readable = Readable.fromWeb(webResponse.body as unknown as ReadableStream); - readable.on('error', err => { - try { - res.destroy(err as Error); - } catch { - // ignore - } - reject(err); - }); - res.on('error', reject); - res.on('close', () => { - try { - readable.destroy(); - } catch { - // ignore - } - }); - readable.pipe(res); - res.on('finish', () => resolve()); - }); -} - -function toHandlerContext(req: ExpressRequestLike): WebHandlerContext { - return { - parsedBody: req.body - }; -} - export type ExpressAuthRateLimitOptions = | false | { @@ -172,10 +61,10 @@ export function mcpAuthRouter(options: AuthRouterOptions & { rateLimit?: Express } handlers.push(async (req, res, next) => { try { - const parsedBodyProvided = (req as ExpressRequestLike).body !== undefined; - const webReq = await expressToWebRequest(req as ExpressRequestLike, parsedBodyProvided); - const webRes = await route.handler(webReq, toHandlerContext(req as ExpressRequestLike)); - await writeWebResponse(res, webRes); + const webReq = createRequest(req, res); + const parsedBody = req.body !== undefined ? req.body : await getParsedBody(webReq); + const webRes = await route.handler(webReq, { parsedBody }); + await sendResponse(res, webRes); } catch (err) { next(err); } @@ -198,10 +87,10 @@ export function mcpAuthMetadataRouter(options: AuthMetadataOptions): RequestHand for (const route of web.routes) { router.all(route.path, async (req, res, next) => { try { - const parsedBodyProvided = (req as ExpressRequestLike).body !== undefined; - const webReq = await expressToWebRequest(req as ExpressRequestLike, parsedBodyProvided); - const webRes = await route.handler(webReq, toHandlerContext(req as ExpressRequestLike)); - await writeWebResponse(res, webRes); + const webReq = createRequest(req, res); + const parsedBody = req.body !== undefined ? req.body : await getParsedBody(webReq); + const webRes = await route.handler(webReq, { parsedBody }); + await sendResponse(res, webRes); } catch (err) { next(err); } diff --git a/packages/server/package.json b/packages/server/package.json index 4f32cf171..20bd77aae 100644 --- a/packages/server/package.json +++ b/packages/server/package.json @@ -44,9 +44,10 @@ "client": "tsx scripts/cli.ts client" }, "dependencies": { + "@hono/node-server": "catalog:runtimeServerOnly", "content-type": "catalog:runtimeServerOnly", - "raw-body": "catalog:runtimeServerOnly", "pkce-challenge": "catalog:runtimeShared", + "raw-body": "catalog:runtimeServerOnly", "zod": "catalog:runtimeShared", "zod-to-json-schema": "catalog:runtimeShared" }, @@ -63,13 +64,13 @@ } }, "devDependencies": { + "@cfworker/json-schema": "catalog:runtimeShared", + "@eslint/js": "catalog:devTools", "@modelcontextprotocol/core": "workspace:^", - "@modelcontextprotocol/tsconfig": "workspace:^", - "@modelcontextprotocol/vitest-config": "workspace:^", "@modelcontextprotocol/eslint-config": "workspace:^", "@modelcontextprotocol/test-helpers": "workspace:^", - "@cfworker/json-schema": "catalog:runtimeShared", - "@eslint/js": "catalog:devTools", + "@modelcontextprotocol/tsconfig": "workspace:^", + "@modelcontextprotocol/vitest-config": "workspace:^", "@types/content-type": "catalog:devTools", "@types/cors": "catalog:devTools", "@types/cross-spawn": "catalog:devTools", diff --git a/packages/server/src/server/streamableHttp.ts b/packages/server/src/server/streamableHttp.ts index 65b39c52c..10a990196 100644 --- a/packages/server/src/server/streamableHttp.ts +++ b/packages/server/src/server/streamableHttp.ts @@ -8,143 +8,37 @@ */ import type { IncomingMessage, ServerResponse } from 'node:http'; -import { Readable } from 'node:stream'; -import { URL } from 'node:url'; +import { getRequestListener } from '@hono/node-server'; import type { AuthInfo, JSONRPCMessage, MessageExtraInfo, RequestId, Transport } from '@modelcontextprotocol/core'; import type { WebStandardStreamableHTTPServerTransportOptions } from './webStandardStreamableHttp.js'; import { WebStandardStreamableHTTPServerTransport } from './webStandardStreamableHttp.js'; /** - * Configuration options for NodeStreamableHTTPServerTransport + * Configuration options for StreamableHTTPServerTransport * * This is an alias for WebStandardStreamableHTTPServerTransportOptions for backward compatibility. */ -export type NodeStreamableHTTPServerTransportOptions = WebStandardStreamableHTTPServerTransportOptions; - -type NodeToWebRequestOptions = { - parsedBody?: unknown; -}; - -function getRequestUrl(req: IncomingMessage): URL { - const host = req.headers.host ?? 'localhost'; - const isTls = Boolean((req.socket as { encrypted?: boolean } | undefined)?.encrypted); - const protocol = isTls ? 'https' : 'http'; - const path = req.url ?? '/'; - return new URL(path, `${protocol}://${host}`); -} - -function toHeaders(req: IncomingMessage): Headers { - const headers = new Headers(); - for (const [key, value] of Object.entries(req.headers)) { - if (value === undefined) continue; - if (Array.isArray(value)) { - // Preserve multi-value headers as a comma-joined value. - // (Set-Cookie does not appear on requests; this is fine here.) - headers.set(key, value.join(', ')); - } else { - headers.set(key, value); - } - } - return headers; -} - -async function readBody(req: IncomingMessage): Promise { - const chunks: Buffer[] = []; - for await (const chunk of req) { - chunks.push(Buffer.isBuffer(chunk) ? chunk : Buffer.from(chunk)); - } - return Buffer.concat(chunks); -} - -async function nodeToWebRequest(req: IncomingMessage, options?: NodeToWebRequestOptions): Promise { - const url = getRequestUrl(req); - const method = req.method ?? 'GET'; - const headers = toHeaders(req); - - // If an upstream framework already parsed the body, the IncomingMessage stream - // may be consumed; rely on parsedBody instead of trying to read again. - if (options?.parsedBody !== undefined) { - return new Request(url, { method, headers }); - } - - // Only attach bodies for methods that can carry one. - if (method === 'GET' || method === 'HEAD') { - return new Request(url, { method, headers }); - } - - const body = await readBody(req); - return new Request(url, { method, headers, body }); -} - -function writeWebResponse(res: ServerResponse, webResponse: Response): Promise { - res.statusCode = webResponse.status; - - // Prefer undici's multi Set-Cookie support when available. - // Note: must call with the correct `this` (undici brand-checks Headers). - // eslint-disable-next-line @typescript-eslint/no-explicit-any - const getSetCookie = (webResponse.headers as any).getSetCookie as (() => string[]) | undefined; - const setCookies = typeof getSetCookie === 'function' ? getSetCookie.call(webResponse.headers) : undefined; - - for (const [key, value] of webResponse.headers.entries()) { - // We'll handle Set-Cookie separately if we have structured values. - if (key.toLowerCase() === 'set-cookie' && setCookies?.length) continue; - res.setHeader(key, value); - } - - if (setCookies?.length) { - res.setHeader('set-cookie', setCookies); - } - - // Node requires writing headers before streaming body. - res.flushHeaders?.(); - - if (!webResponse.body) { - res.end(); - return Promise.resolve(); - } - - return new Promise((resolve, reject) => { - const readable = Readable.fromWeb(webResponse.body as unknown as ReadableStream); - readable.on('error', err => { - try { - res.destroy(err as Error); - } catch { - // ignore - } - reject(err); - }); - res.on('error', reject); - res.on('close', () => { - try { - readable.destroy(); - } catch { - // ignore - } - }); - readable.pipe(res); - res.on('finish', () => resolve()); - }); -} +export type StreamableHTTPServerTransportOptions = WebStandardStreamableHTTPServerTransportOptions; /** * Server transport for Streamable HTTP: this implements the MCP Streamable HTTP transport specification. * It supports both SSE streaming and direct HTTP responses. * * This is a wrapper around `WebStandardStreamableHTTPServerTransport` that provides Node.js HTTP compatibility. - * It converts between Node.js HTTP (IncomingMessage/ServerResponse) and Web Standard Request/Response. + * It uses the `@hono/node-server` library to convert between Node.js HTTP and Web Standard APIs. * * Usage example: * * ```typescript * // Stateful mode - server sets the session ID - * const statefulTransport = new NodeStreamableHTTPServerTransport({ + * const statefulTransport = new StreamableHTTPServerTransport({ * sessionIdGenerator: () => randomUUID(), * }); * * // Stateless mode - explicitly set session ID to undefined - * const statelessTransport = new NodeStreamableHTTPServerTransport({ + * const statelessTransport = new StreamableHTTPServerTransport({ * sessionIdGenerator: undefined, * }); * @@ -167,9 +61,23 @@ function writeWebResponse(res: ServerResponse, webResponse: Response): Promise; + // Store auth and parsedBody per request for passing through to handleRequest + private _requestContext: WeakMap = new WeakMap(); - constructor(options: NodeStreamableHTTPServerTransportOptions = {}) { + constructor(options: StreamableHTTPServerTransportOptions = {}) { this._webStandardTransport = new WebStandardStreamableHTTPServerTransport(options); + + // Create a request listener that wraps the web standard transport + // getRequestListener converts Node.js HTTP to Web Standard and properly handles SSE streaming + this._requestListener = getRequestListener(async (webRequest: Request) => { + // Get context if available (set during handleRequest) + const context = this._requestContext.get(webRequest); + return this._webStandardTransport.handleRequest(webRequest, { + authInfo: context?.authInfo, + parsedBody: context?.parsedBody + }); + }); } /** @@ -245,13 +153,21 @@ export class NodeStreamableHTTPServerTransport implements Transport { * @param parsedBody - Optional pre-parsed body from body-parser middleware */ async handleRequest(req: IncomingMessage & { auth?: AuthInfo }, res: ServerResponse, parsedBody?: unknown): Promise { + // Store context for this request to pass through auth and parsedBody + // We need to intercept the request creation to attach this context const authInfo = req.auth; - const webRequest = await nodeToWebRequest(req, { parsedBody }); - const webResponse = await this._webStandardTransport.handleRequest(webRequest, { - authInfo, - parsedBody + + // Create a custom handler that includes our context + const handler = getRequestListener(async (webRequest: Request) => { + return this._webStandardTransport.handleRequest(webRequest, { + authInfo, + parsedBody + }); }); - await writeWebResponse(res, webResponse); + + // Delegate to the request listener which handles all the Node.js <-> Web Standard conversion + // including proper SSE streaming support + await handler(req, res); } /** diff --git a/packages/server/test/server/streamableHttp.test.ts b/packages/server/test/server/streamableHttp.test.ts index 5a5230940..57e47668b 100644 --- a/packages/server/test/server/streamableHttp.test.ts +++ b/packages/server/test/server/streamableHttp.test.ts @@ -17,7 +17,8 @@ import { listenOnRandomPort } from '@modelcontextprotocol/test-helpers'; import { McpServer } from '../../src/server/mcp.js'; import { NodeStreamableHTTPServerTransport } from '../../src/server/streamableHttp.js'; import type { EventId, EventStore, StreamId } from '../../src/server/webStandardStreamableHttp.js'; -import { type ZodMatrixEntry, zodTestMatrix } from './__fixtures__/zodTestMatrix.js'; +import type { ZodMatrixEntry } from './__fixtures__/zodTestMatrix.js'; +import { zodTestMatrix } from './__fixtures__/zodTestMatrix.js'; async function getFreePort() { return new Promise(res => { diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index bca9f21f7..c2b9d0835 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -89,6 +89,9 @@ catalogs: '@hono/node-server': specifier: ^1.19.7 version: 1.19.7 + '@remix-run/node-fetch-server': + specifier: ^0.13.0 + version: 0.13.0 content-type: specifier: ^1.0.5 version: 1.0.5 @@ -564,6 +567,9 @@ importers: packages/server: dependencies: + '@hono/node-server': + specifier: catalog:runtimeServerOnly + version: 1.19.7(hono@4.11.1) content-type: specifier: catalog:runtimeServerOnly version: 1.0.5 @@ -661,6 +667,9 @@ importers: '@modelcontextprotocol/server': specifier: workspace:^ version: link:../server + '@remix-run/node-fetch-server': + specifier: catalog:runtimeServerOnly + version: 0.13.0 express: specifier: catalog:runtimeServerOnly version: 5.1.0 @@ -1207,6 +1216,9 @@ packages: '@quansync/fs@1.0.0': resolution: {integrity: sha512-4TJ3DFtlf1L5LDMaM6CanJ/0lckGNtJcMjQ1NAV6zDmA0tEHKZtxNKin8EgPaVX1YzljbxckyT2tJrpQKAtngQ==} + '@remix-run/node-fetch-server@0.13.0': + resolution: {integrity: sha512-1EsNo0ZpgXu/90AWoRZf/oE3RVTUS80tiTUpt+hv5pjtAkw7icN4WskDwz/KdAw5ARbJLMhZBrO1NqThmy/McA==} + '@rolldown/binding-android-arm64@1.0.0-beta.53': resolution: {integrity: sha512-Ok9V8o7o6YfSdTTYA/uHH30r3YtOxLD6G3wih/U9DO0ucBBFq8WPt/DslU53OgfteLRHITZny9N/qCUxMf9kjQ==} engines: {node: ^20.19.0 || >=22.12.0} @@ -3878,6 +3890,8 @@ snapshots: dependencies: quansync: 1.0.0 + '@remix-run/node-fetch-server@0.13.0': {} + '@rolldown/binding-android-arm64@1.0.0-beta.53': optional: true diff --git a/pnpm-workspace.yaml b/pnpm-workspace.yaml index 55aac1aba..a7222dd71 100644 --- a/pnpm-workspace.yaml +++ b/pnpm-workspace.yaml @@ -21,6 +21,7 @@ catalogs: express: ^5.0.1 express-rate-limit: ^8.2.1 raw-body: ^3.0.0 + '@remix-run/node-fetch-server': ^0.13.0 runtimeClientOnly: jose: ^6.1.1 cross-spawn: ^7.0.5 diff --git a/test/integration/test/taskLifecycle.test.ts b/test/integration/test/taskLifecycle.test.ts index 5e3dfc408..324da6aa2 100644 --- a/test/integration/test/taskLifecycle.test.ts +++ b/test/integration/test/taskLifecycle.test.ts @@ -14,8 +14,8 @@ import { InMemoryTaskStore, McpError, McpServer, - RELATED_TASK_META_KEY, NodeStreamableHTTPServerTransport, + RELATED_TASK_META_KEY, TaskSchema } from '@modelcontextprotocol/server'; import { listenOnRandomPort, waitForTaskStatus } from '@modelcontextprotocol/test-helpers'; diff --git a/test/integration/test/taskResumability.test.ts b/test/integration/test/taskResumability.test.ts index 4e4625561..db60e2d4e 100644 --- a/test/integration/test/taskResumability.test.ts +++ b/test/integration/test/taskResumability.test.ts @@ -3,13 +3,13 @@ import type { Server } from 'node:http'; import { createServer } from 'node:http'; import { Client, StreamableHTTPClientTransport } from '@modelcontextprotocol/client'; +import type { EventStore, JSONRPCMessage } from '@modelcontextprotocol/server'; import { CallToolResultSchema, LoggingMessageNotificationSchema, McpServer, NodeStreamableHTTPServerTransport } from '@modelcontextprotocol/server'; -import type { EventStore, JSONRPCMessage } from '@modelcontextprotocol/server'; import type { ZodMatrixEntry } from '@modelcontextprotocol/test-helpers'; import { listenOnRandomPort, zodTestMatrix } from '@modelcontextprotocol/test-helpers'; From aaeff28619a2c4d93551d958db1cab3ad945f417 Mon Sep 17 00:00:00 2001 From: Konstantin Konstantinov Date: Sat, 20 Dec 2025 15:37:19 +0200 Subject: [PATCH 5/5] hono-server updates --- packages/server-hono/src/auth/bearerAuth.ts | 19 ++ packages/server-hono/src/auth/router.ts | 58 ++++-- packages/server-hono/src/hono.ts | 90 +++++++++ packages/server-hono/src/index.ts | 2 + packages/server-hono/src/streamableHttp.ts | 11 +- packages/server-hono/test/server-hono.test.ts | 177 +++++++++++++++++- 6 files changed, 335 insertions(+), 22 deletions(-) create mode 100644 packages/server-hono/src/auth/bearerAuth.ts create mode 100644 packages/server-hono/src/hono.ts diff --git a/packages/server-hono/src/auth/bearerAuth.ts b/packages/server-hono/src/auth/bearerAuth.ts new file mode 100644 index 000000000..1258eee11 --- /dev/null +++ b/packages/server-hono/src/auth/bearerAuth.ts @@ -0,0 +1,19 @@ +import type { BearerAuthMiddlewareOptions } from '@modelcontextprotocol/server'; +import { requireBearerAuth as requireBearerAuthWeb } from '@modelcontextprotocol/server'; +import type { MiddlewareHandler } from 'hono'; +/** + * Hono middleware wrapper for the Web-standard `requireBearerAuth` helper. + * + * On success, sets `c.set('auth', authInfo)` and calls `next()`. + * On failure, returns the JSON error response. + */ +export function requireBearerAuth(options: BearerAuthMiddlewareOptions): MiddlewareHandler { + return async (c, next) => { + const result = await requireBearerAuthWeb(c.req.raw, options); + if ('authInfo' in result) { + c.set('auth', result.authInfo); + return await next(); + } + return result.response; + }; +} diff --git a/packages/server-hono/src/auth/router.ts b/packages/server-hono/src/auth/router.ts index 4c61c1d2c..f17765318 100644 --- a/packages/server-hono/src/auth/router.ts +++ b/packages/server-hono/src/auth/router.ts @@ -1,33 +1,61 @@ import type { AuthMetadataOptions, AuthRoute, AuthRouterOptions } from '@modelcontextprotocol/server'; -import { mcpAuthMetadataRouter as createWebAuthMetadataRouter, mcpAuthRouter as createWebAuthRouter } from '@modelcontextprotocol/server'; -import type { Handler, Hono } from 'hono'; - -export type RegisterMcpAuthRoutesOptions = AuthRouterOptions; +import { + getParsedBody, + mcpAuthMetadataRouter as createWebAuthMetadataRouter, + mcpAuthRouter as createWebAuthRouter +} from '@modelcontextprotocol/server'; +import type { Handler } from 'hono'; +import { Hono } from 'hono'; /** - * Registers the standard MCP OAuth endpoints on a Hono app. + * Hono router adapter for the Web-standard `mcpAuthRouter` from `@modelcontextprotocol/server`. + * + * IMPORTANT: This router MUST be mounted at the application root. * - * IMPORTANT: These routes MUST be mounted at the application root. + * @example + * ```ts + * app.route('/', mcpAuthRouter(...)) + * ``` */ -export function registerMcpAuthRoutes(app: Hono, options: RegisterMcpAuthRoutesOptions): void { +export function mcpAuthRouter(options: AuthRouterOptions): Hono { const web = createWebAuthRouter(options); - registerRoutes(app, web.routes); + const router = new Hono(); + registerRoutes(router, web.routes); + return router; } /** - * Registers only the auth metadata endpoints (RFC 8414 + RFC 9728) on a Hono app. + * Hono router adapter for the Web-standard `mcpAuthMetadataRouter` from `@modelcontextprotocol/server`. * - * IMPORTANT: These routes MUST be mounted at the application root. + * IMPORTANT: This router MUST be mounted at the application root. */ -export function registerMcpAuthMetadataRoutes(app: Hono, options: AuthMetadataOptions): void { +export function mcpAuthMetadataRouter(options: AuthMetadataOptions): Hono { const web = createWebAuthMetadataRouter(options); - registerRoutes(app, web.routes); + const router = new Hono(); + registerRoutes(router, web.routes); + return router; } function registerRoutes(app: Hono, routes: AuthRoute[]): void { for (const route of routes) { - // Hono's `on()` expects methods like 'GET', 'POST', etc. - const handler: Handler = c => route.handler(c.req.raw); - app.on(route.methods, route.path, handler); + // Use `all()` so unsupported methods still reach the handler and can return 405, + // matching the Express adapter behavior. + const handler: Handler = async c => { + let parsedBody = c.get('parsedBody'); + if (parsedBody === undefined && c.req.method === 'POST') { + // Parse from a clone so we don't consume the original request stream. + parsedBody = await getParsedBody(c.req.raw.clone()); + } + return route.handler(c.req.raw, { parsedBody }); + }; + app.all(route.path, handler); } } + +export function registerMcpAuthRoutes(app: Hono, options: AuthRouterOptions): void { + app.route('/', mcpAuthRouter(options)); +} + +export function registerMcpAuthMetadataRoutes(app: Hono, options: AuthMetadataOptions): void { + app.route('/', mcpAuthMetadataRouter(options)); +} diff --git a/packages/server-hono/src/hono.ts b/packages/server-hono/src/hono.ts new file mode 100644 index 000000000..accf4ab27 --- /dev/null +++ b/packages/server-hono/src/hono.ts @@ -0,0 +1,90 @@ +import type { Context } from 'hono'; +import { Hono } from 'hono'; + +import { hostHeaderValidation, localhostHostValidation } from './middleware/hostHeaderValidation.js'; + +/** + * Options for creating an MCP Hono application. + */ +export interface CreateMcpHonoAppOptions { + /** + * The hostname to bind to. Defaults to '127.0.0.1'. + * When set to '127.0.0.1', 'localhost', or '::1', DNS rebinding protection is automatically enabled. + */ + host?: string; + + /** + * List of allowed hostnames for DNS rebinding protection. + * If provided, host header validation will be applied using this list. + * For IPv6, provide addresses with brackets (e.g., '[::1]'). + * + * This is useful when binding to '0.0.0.0' or '::' but still wanting + * to restrict which hostnames are allowed. + */ + allowedHosts?: string[]; +} + +/** + * Creates a Hono application pre-configured for MCP servers. + * + * When the host is '127.0.0.1', 'localhost', or '::1' (the default is '127.0.0.1'), + * DNS rebinding protection middleware is automatically applied to protect against + * DNS rebinding attacks on localhost servers. + * + * This also installs a small JSON body parsing middleware (similar to `express.json()`) + * that stashes the parsed body into `c.set('parsedBody', ...)` when `Content-Type` includes + * `application/json`. + * + * @param options - Configuration options + * @returns A configured Hono application + */ +export function createMcpHonoApp(options: CreateMcpHonoAppOptions = {}): Hono { + const { host = '127.0.0.1', allowedHosts } = options; + + const app = new Hono(); + + // Similar to `express.json()`: parse JSON bodies and make them available to MCP adapters via `parsedBody`. + app.use('*', async (c: Context, next) => { + // If an upstream middleware already set parsedBody, keep it. + if (c.get('parsedBody') !== undefined) { + return await next(); + } + + const ct = c.req.header('content-type') ?? ''; + if (!ct.includes('application/json')) { + return await next(); + } + + try { + // Parse from a clone so we don't consume the original request stream. + const parsed = await c.req.raw.clone().json(); + c.set('parsedBody', parsed); + } catch { + // Mirror express.json() behavior loosely: reject invalid JSON. + return c.text('Invalid JSON', 400); + } + + return await next(); + }); + + // If allowedHosts is explicitly provided, use that for validation. + if (allowedHosts) { + app.use('*', hostHeaderValidation(allowedHosts)); + } else { + // Apply DNS rebinding protection automatically for localhost hosts. + const localhostHosts = ['127.0.0.1', 'localhost', '::1']; + if (localhostHosts.includes(host)) { + app.use('*', localhostHostValidation()); + } else if (host === '0.0.0.0' || host === '::') { + // Warn when binding to all interfaces without DNS rebinding protection. + // eslint-disable-next-line no-console + console.warn( + `Warning: Server is binding to ${host} without DNS rebinding protection. ` + + 'Consider using the allowedHosts option to restrict allowed hosts, ' + + 'or use authentication to protect your server.' + ); + } + } + + return app; +} diff --git a/packages/server-hono/src/index.ts b/packages/server-hono/src/index.ts index 5a7cb5129..bc6de4318 100644 --- a/packages/server-hono/src/index.ts +++ b/packages/server-hono/src/index.ts @@ -1,3 +1,5 @@ +export * from './auth/bearerAuth.js'; export * from './auth/router.js'; +export * from './hono.js'; export * from './middleware/hostHeaderValidation.js'; export * from './streamableHttp.js'; diff --git a/packages/server-hono/src/streamableHttp.ts b/packages/server-hono/src/streamableHttp.ts index d81960713..2da1bafcd 100644 --- a/packages/server-hono/src/streamableHttp.ts +++ b/packages/server-hono/src/streamableHttp.ts @@ -1,4 +1,5 @@ import type { WebStandardStreamableHTTPServerTransport } from '@modelcontextprotocol/server'; +import { getParsedBody } from '@modelcontextprotocol/server'; import type { Context, Handler } from 'hono'; /** @@ -10,5 +11,13 @@ import type { Context, Handler } from 'hono'; * ``` */ export function mcpStreamableHttpHandler(transport: WebStandardStreamableHTTPServerTransport): Handler { - return (c: Context) => transport.handleRequest(c.req.raw); + return async (c: Context) => { + let parsedBody = c.get('parsedBody'); + if (parsedBody === undefined && c.req.method === 'POST') { + // Parse from a clone so we don't consume the original request stream. + parsedBody = await getParsedBody(c.req.raw.clone()); + } + const authInfo = c.get('auth'); + return transport.handleRequest(c.req.raw, { authInfo, parsedBody }); + }; } diff --git a/packages/server-hono/test/server-hono.test.ts b/packages/server-hono/test/server-hono.test.ts index 8b143411b..130e11c71 100644 --- a/packages/server-hono/test/server-hono.test.ts +++ b/packages/server-hono/test/server-hono.test.ts @@ -1,22 +1,36 @@ import type { AuthorizationParams, OAuthClientInformationFull, OAuthServerProvider, OAuthTokens } from '@modelcontextprotocol/server'; +import type { Context } from 'hono'; import { Hono } from 'hono'; +import { vi } from 'vitest'; -import { registerMcpAuthRoutes } from '../src/auth/router.js'; +import { mcpAuthRouter } from '../src/auth/router.js'; +import { createMcpHonoApp } from '../src/hono.js'; import { hostHeaderValidation } from '../src/middleware/hostHeaderValidation.js'; import { mcpStreamableHttpHandler } from '../src/streamableHttp.js'; describe('@modelcontextprotocol/server-hono', () => { - test('mcpStreamableHttpHandler delegates to transport.handleRequest', async () => { - const calls: { url?: string; method?: string }[] = []; + test('mcpStreamableHttpHandler delegates to transport.handleRequest (and passes authInfo + parsedBody when set)', async () => { + const calls: { url?: string; method?: string; options?: unknown }[] = []; const transport = { - async handleRequest(req: Request): Promise { - calls.push({ url: req.url, method: req.method }); + async handleRequest(req: Request, options?: unknown): Promise { + calls.push({ url: req.url, method: req.method, options }); return new Response('ok', { status: 200, headers: { 'content-type': 'text/plain' } }); } }; const app = new Hono(); + app.use('/mcp', async (c: Context, next) => { + // Upstream middleware can pre-parse and stash body + auth. + c.set('parsedBody', { hello: 'world' }); + c.set('auth', { + token: 't', + clientId: 'c', + scopes: [], + expiresAt: Math.floor(Date.now() / 1000) + 60 + }); + return await next(); + }); app.all('/mcp', mcpStreamableHttpHandler(transport as unknown as Parameters[0])); const res = await app.request('http://localhost/mcp', { method: 'POST' }); @@ -25,6 +39,12 @@ describe('@modelcontextprotocol/server-hono', () => { expect(calls).toHaveLength(1); expect(calls[0]!.method).toBe('POST'); expect(calls[0]!.url).toBe('http://localhost/mcp'); + expect(calls[0]!.options).toEqual( + expect.objectContaining({ + parsedBody: { hello: 'world' }, + authInfo: expect.objectContaining({ clientId: 'c' }) + }) + ); }); test('hostHeaderValidation blocks invalid Host and allows valid Host', async () => { @@ -93,7 +113,7 @@ describe('@modelcontextprotocol/server-hono', () => { }; const app = new Hono(); - registerMcpAuthRoutes(app, { provider, issuerUrl: new URL('https://auth.example.com') }); + app.route('/', mcpAuthRouter({ provider, issuerUrl: new URL('https://auth.example.com') })); const metadata = await app.request('http://localhost/.well-known/oauth-authorization-server', { method: 'GET' }); expect(metadata.status).toBe(200); @@ -111,4 +131,149 @@ describe('@modelcontextprotocol/server-hono', () => { expect(location).toContain('code=mock_auth_code'); expect(location).toContain('state=s'); }); + + test('registerMcpAuthRoutes returns 405 (not 404) for unsupported methods', async () => { + const provider: OAuthServerProvider = { + clientsStore: { + async getClient() { + return undefined; + } + }, + async authorize() { + throw new Error('not used'); + }, + async challengeForAuthorizationCode() { + throw new Error('not used'); + }, + async exchangeAuthorizationCode() { + throw new Error('not used'); + }, + async exchangeRefreshToken() { + throw new Error('not used'); + }, + async verifyAccessToken() { + throw new Error('not used'); + } + }; + + const app = new Hono(); + app.route('/', mcpAuthRouter({ provider, issuerUrl: new URL('https://auth.example.com') })); + + const res = await app.request('http://localhost/authorize', { method: 'PUT' }); + expect(res.status).toBe(405); + }); + + test('registerMcpAuthRoutes passes parsedBody to web handlers (POST /authorize works with empty raw body)', async () => { + const validClient: OAuthClientInformationFull = { + client_id: 'valid-client', + client_secret: 'valid-secret', + redirect_uris: ['https://example.com/callback'] + }; + + const provider: OAuthServerProvider = { + clientsStore: { + async getClient(clientId: string) { + return clientId === 'valid-client' ? validClient : undefined; + } + }, + async authorize(_client: OAuthClientInformationFull, params: AuthorizationParams): Promise { + const u = new URL(params.redirectUri); + u.searchParams.set('code', 'mock_auth_code'); + if (params.state) u.searchParams.set('state', params.state); + return Response.redirect(u.toString(), 302); + }, + async challengeForAuthorizationCode(): Promise { + return 'mock_challenge'; + }, + async exchangeAuthorizationCode(): Promise { + return { + access_token: 'mock_access_token', + token_type: 'bearer', + expires_in: 3600, + refresh_token: 'mock_refresh_token' + }; + }, + async exchangeRefreshToken(): Promise { + return { + access_token: 'new_mock_access_token', + token_type: 'bearer', + expires_in: 3600, + refresh_token: 'new_mock_refresh_token' + }; + }, + async verifyAccessToken() { + throw new Error('not used'); + } + }; + + const app = new Hono(); + app.use('/authorize', async (c: Context, next) => { + c.set('parsedBody', { + client_id: 'valid-client', + response_type: 'code', + code_challenge: 'x', + code_challenge_method: 'S256', + redirect_uri: 'https://example.com/callback', + state: 's' + }); + return await next(); + }); + app.route('/', mcpAuthRouter({ provider, issuerUrl: new URL('https://auth.example.com') })); + + const authorize = await app.request('http://localhost/authorize', { method: 'POST' }); + expect(authorize.status).toBe(302); + const location = authorize.headers.get('location')!; + expect(location).toContain('https://example.com/callback'); + expect(location).toContain('code=mock_auth_code'); + expect(location).toContain('state=s'); + }); + + test('createMcpHonoApp enables localhost DNS rebinding protection by default', async () => { + const app = createMcpHonoApp(); + app.get('/health', c => c.text('ok')); + + const bad = await app.request('http://localhost/health', { headers: { Host: 'evil.com:3000' } }); + expect(bad.status).toBe(403); + + const good = await app.request('http://localhost/health', { headers: { Host: 'localhost:3000' } }); + expect(good.status).toBe(200); + }); + + test('createMcpHonoApp uses allowedHosts when provided (even when binding to 0.0.0.0)', async () => { + const warn = vi.spyOn(console, 'warn').mockImplementation(() => {}); + const app = createMcpHonoApp({ host: '0.0.0.0', allowedHosts: ['myapp.local'] }); + warn.mockRestore(); + + app.get('/health', c => c.text('ok')); + + const bad = await app.request('http://localhost/health', { headers: { Host: 'evil.com:3000' } }); + expect(bad.status).toBe(403); + + const good = await app.request('http://localhost/health', { headers: { Host: 'myapp.local:3000' } }); + expect(good.status).toBe(200); + }); + + test('createMcpHonoApp does not apply host validation for 0.0.0.0 without allowedHosts', async () => { + const warn = vi.spyOn(console, 'warn').mockImplementation(() => {}); + const app = createMcpHonoApp({ host: '0.0.0.0' }); + warn.mockRestore(); + + app.get('/health', c => c.text('ok')); + + const res = await app.request('http://localhost/health', { headers: { Host: 'evil.com:3000' } }); + expect(res.status).toBe(200); + }); + + test('createMcpHonoApp parses JSON bodies into parsedBody (express.json()-like)', async () => { + const app = createMcpHonoApp(); + app.post('/echo', (c: Context) => c.json(c.get('parsedBody'))); + + const res = await app.request('http://localhost/echo', { + method: 'POST', + headers: { Host: 'localhost:3000', 'content-type': 'application/json' }, + body: JSON.stringify({ a: 1 }) + }); + expect(res.status).toBe(200); + expect(await res.json()).toEqual({ a: 1 }); + }); });