Skip to content

Commit d30480d

Browse files
stantonktzolovchemicL
committed
feat: Add transport context extraction support to all MCP servers (#477)
- Add McpTransportContextExtractor to WebFlux/WebMVC SSE and Streamable transport providers - Enable extraction of HTTP transport metadata (headers, etc.) for use during request processing - Pass transport context through reactive context chain using McpTransportContext.KEY - Add contextExtractor() builder methods for configuring custom extractors - Update HttpServlet transport providers with same context extraction capability - Modify McpServerSession to properly propagate transport context to handlers - Add test coverage with TEST_CONTEXT_EXTRACTOR in integration tests This allows MCP feature implementations to access HTTP transport level metadata that was present at request time, enabling use cases like authentication, request tracing, and custom header processing. Signed-off-by: Christian Tzolov <christian.tzolov@broadcom.com> Co-authored-by: Christian Tzolov <christian.tzolov@broadcom.com> Co-authored-by: Dariusz Jędrzejczyk <dariusz.jedrzejczyk@broadcom.com>
1 parent a6a8c4f commit d30480d

File tree

14 files changed

+374
-24
lines changed

14 files changed

+374
-24
lines changed

mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java

Lines changed: 54 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,10 @@
1111

1212
import com.fasterxml.jackson.core.type.TypeReference;
1313
import com.fasterxml.jackson.databind.ObjectMapper;
14+
15+
import io.modelcontextprotocol.server.DefaultMcpTransportContext;
16+
import io.modelcontextprotocol.server.McpTransportContext;
17+
import io.modelcontextprotocol.server.McpTransportContextExtractor;
1418
import io.modelcontextprotocol.spec.McpError;
1519
import io.modelcontextprotocol.spec.McpSchema;
1620
import io.modelcontextprotocol.spec.McpServerSession;
@@ -115,6 +119,8 @@ public class WebFluxSseServerTransportProvider implements McpServerTransportProv
115119
*/
116120
private final ConcurrentHashMap<String, McpServerSession> sessions = new ConcurrentHashMap<>();
117121

122+
private McpTransportContextExtractor<ServerRequest> contextExtractor;
123+
118124
/**
119125
* Flag indicating if the transport is shutting down.
120126
*/
@@ -194,15 +200,38 @@ public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String baseU
194200
@Deprecated
195201
public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint,
196202
String sseEndpoint, Duration keepAliveInterval) {
203+
this(objectMapper, baseUrl, messageEndpoint, sseEndpoint, keepAliveInterval,
204+
(serverRequest, context) -> context);
205+
}
206+
207+
/**
208+
* Constructs a new WebFlux SSE server transport provider instance.
209+
* @param objectMapper The ObjectMapper to use for JSON serialization/deserialization
210+
* of MCP messages. Must not be null.
211+
* @param baseUrl webflux message base path
212+
* @param messageEndpoint The endpoint URI where clients should send their JSON-RPC
213+
* messages. This endpoint will be communicated to clients during SSE connection
214+
* setup. Must not be null.
215+
* @param sseEndpoint The SSE endpoint path. Must not be null.
216+
* @param keepAliveInterval The interval for sending keep-alive pings to clients.
217+
* @param contextExtractor The context extractor to use for extracting MCP transport
218+
* context from HTTP requests. Must not be null.
219+
* @throws IllegalArgumentException if either parameter is null
220+
*/
221+
private WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint,
222+
String sseEndpoint, Duration keepAliveInterval,
223+
McpTransportContextExtractor<ServerRequest> contextExtractor) {
197224
Assert.notNull(objectMapper, "ObjectMapper must not be null");
198225
Assert.notNull(baseUrl, "Message base path must not be null");
199226
Assert.notNull(messageEndpoint, "Message endpoint must not be null");
200227
Assert.notNull(sseEndpoint, "SSE endpoint must not be null");
228+
Assert.notNull(contextExtractor, "Context extractor must not be null");
201229

202230
this.objectMapper = objectMapper;
203231
this.baseUrl = baseUrl;
204232
this.messageEndpoint = messageEndpoint;
205233
this.sseEndpoint = sseEndpoint;
234+
this.contextExtractor = contextExtractor;
206235
this.routerFunction = RouterFunctions.route()
207236
.GET(this.sseEndpoint, this::handleSseConnection)
208237
.POST(this.messageEndpoint, this::handleMessage)
@@ -315,6 +344,8 @@ private Mono<ServerResponse> handleSseConnection(ServerRequest request) {
315344
return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).bodyValue("Server is shutting down");
316345
}
317346

347+
McpTransportContext transportContext = this.contextExtractor.extract(request, new DefaultMcpTransportContext());
348+
318349
return ServerResponse.ok()
319350
.contentType(MediaType.TEXT_EVENT_STREAM)
320351
.body(Flux.<ServerSentEvent<?>>create(sink -> {
@@ -336,7 +367,7 @@ private Mono<ServerResponse> handleSseConnection(ServerRequest request) {
336367
logger.debug("Session {} cancelled", sessionId);
337368
sessions.remove(sessionId);
338369
});
339-
}), ServerSentEvent.class);
370+
}).contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)), ServerSentEvent.class);
340371
}
341372

342373
/**
@@ -370,6 +401,8 @@ private Mono<ServerResponse> handleMessage(ServerRequest request) {
370401
.bodyValue(new McpError("Session not found: " + request.queryParam("sessionId").get()));
371402
}
372403

404+
McpTransportContext transportContext = this.contextExtractor.extract(request, new DefaultMcpTransportContext());
405+
373406
return request.bodyToMono(String.class).flatMap(body -> {
374407
try {
375408
McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, body);
@@ -386,7 +419,7 @@ private Mono<ServerResponse> handleMessage(ServerRequest request) {
386419
logger.error("Failed to deserialize message: {}", e.getMessage());
387420
return ServerResponse.badRequest().bodyValue(new McpError("Invalid message format"));
388421
}
389-
});
422+
}).contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext));
390423
}
391424

392425
private class WebFluxMcpSessionTransport implements McpServerTransport {
@@ -458,6 +491,8 @@ public static class Builder {
458491

459492
private Duration keepAliveInterval;
460493

494+
private McpTransportContextExtractor<ServerRequest> contextExtractor = (serverRequest, context) -> context;
495+
461496
/**
462497
* Sets the ObjectMapper to use for JSON serialization/deserialization of MCP
463498
* messages.
@@ -519,6 +554,22 @@ public Builder keepAliveInterval(Duration keepAliveInterval) {
519554
return this;
520555
}
521556

557+
/**
558+
* Sets the context extractor that allows providing the MCP feature
559+
* implementations to inspect HTTP transport level metadata that was present at
560+
* HTTP request processing time. This allows to extract custom headers and other
561+
* useful data for use during execution later on in the process.
562+
* @param contextExtractor The contextExtractor to fill in a
563+
* {@link McpTransportContext}.
564+
* @return this builder instance
565+
* @throws IllegalArgumentException if contextExtractor is null
566+
*/
567+
public Builder contextExtractor(McpTransportContextExtractor<ServerRequest> contextExtractor) {
568+
Assert.notNull(contextExtractor, "contextExtractor must not be null");
569+
this.contextExtractor = contextExtractor;
570+
return this;
571+
}
572+
522573
/**
523574
* Builds a new instance of {@link WebFluxSseServerTransportProvider} with the
524575
* configured settings.
@@ -530,7 +581,7 @@ public WebFluxSseServerTransportProvider build() {
530581
Assert.notNull(messageEndpoint, "Message endpoint must be set");
531582

532583
return new WebFluxSseServerTransportProvider(objectMapper, baseUrl, messageEndpoint, sseEndpoint,
533-
keepAliveInterval);
584+
keepAliveInterval, contextExtractor);
534585
}
535586

536587
}

mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStreamableServerTransportProvider.java

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,9 @@ private Mono<ServerResponse> handleGet(ServerRequest request) {
191191
String lastId = request.headers().asHttpHeaders().getFirst(HttpHeaders.LAST_EVENT_ID);
192192
return ServerResponse.ok()
193193
.contentType(MediaType.TEXT_EVENT_STREAM)
194-
.body(session.replay(lastId), ServerSentEvent.class);
194+
.body(session.replay(lastId)
195+
.contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)),
196+
ServerSentEvent.class);
195197
}
196198

197199
return ServerResponse.ok()
@@ -202,7 +204,9 @@ private Mono<ServerResponse> handleGet(ServerRequest request) {
202204
McpStreamableServerSession.McpStreamableServerSessionStream listeningStream = session
203205
.listeningStream(sessionTransport);
204206
sink.onDispose(listeningStream::close);
205-
}), ServerSentEvent.class);
207+
// TODO Clarify why the outer context is not present in the
208+
// Flux.create sink?
209+
}).contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)), ServerSentEvent.class);
206210

207211
}).contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext));
208212
}
@@ -282,7 +286,10 @@ else if (message instanceof McpSchema.JSONRPCRequest jsonrpcRequest) {
282286
return true;
283287
}).contextWrite(sink.contextView()).subscribe();
284288
sink.onCancel(streamSubscription);
285-
}), ServerSentEvent.class);
289+
// TODO Clarify why the outer context is not present in the
290+
// Flux.create sink?
291+
}).contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)),
292+
ServerSentEvent.class);
286293
}
287294
else {
288295
return ServerResponse.badRequest().bodyValue(new McpError("Unknown message type"));

mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import org.springframework.http.server.reactive.ReactorHttpHandlerAdapter;
1414
import org.springframework.web.reactive.function.client.WebClient;
1515
import org.springframework.web.reactive.function.server.RouterFunctions;
16+
import org.springframework.web.reactive.function.server.ServerRequest;
1617

1718
import com.fasterxml.jackson.databind.ObjectMapper;
1819

@@ -22,6 +23,7 @@
2223
import io.modelcontextprotocol.server.McpServer;
2324
import io.modelcontextprotocol.server.McpServer.AsyncSpecification;
2425
import io.modelcontextprotocol.server.McpServer.SingleSessionSyncSpecification;
26+
import io.modelcontextprotocol.server.McpTransportContextExtractor;
2527
import io.modelcontextprotocol.server.TestUtil;
2628
import io.modelcontextprotocol.server.transport.WebFluxSseServerTransportProvider;
2729
import reactor.netty.DisposableServer;
@@ -40,6 +42,11 @@ class WebFluxSseIntegrationTests extends AbstractMcpClientServerIntegrationTests
4042

4143
private WebFluxSseServerTransportProvider mcpServerTransportProvider;
4244

45+
static McpTransportContextExtractor<ServerRequest> TEST_CONTEXT_EXTRACTOR = (r, tc) -> {
46+
tc.put("important", "value");
47+
return tc;
48+
};
49+
4350
@Override
4451
protected void prepareClients(int port, String mcpEndpoint) {
4552

@@ -75,6 +82,7 @@ public void before() {
7582
.objectMapper(new ObjectMapper())
7683
.messageEndpoint(CUSTOM_MESSAGE_ENDPOINT)
7784
.sseEndpoint(CUSTOM_SSE_ENDPOINT)
85+
.contextExtractor(TEST_CONTEXT_EXTRACTOR)
7886
.build();
7987

8088
HttpHandler httpHandler = RouterFunctions.toHttpHandler(mcpServerTransportProvider.getRouterFunction());

mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxStreamableIntegrationTests.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import org.springframework.http.server.reactive.ReactorHttpHandlerAdapter;
1414
import org.springframework.web.reactive.function.client.WebClient;
1515
import org.springframework.web.reactive.function.server.RouterFunctions;
16+
import org.springframework.web.reactive.function.server.ServerRequest;
1617

1718
import com.fasterxml.jackson.databind.ObjectMapper;
1819

@@ -22,6 +23,7 @@
2223
import io.modelcontextprotocol.server.McpServer;
2324
import io.modelcontextprotocol.server.McpServer.AsyncSpecification;
2425
import io.modelcontextprotocol.server.McpServer.SyncSpecification;
26+
import io.modelcontextprotocol.server.McpTransportContextExtractor;
2527
import io.modelcontextprotocol.server.TestUtil;
2628
import io.modelcontextprotocol.server.transport.WebFluxStreamableServerTransportProvider;
2729
import reactor.netty.DisposableServer;
@@ -38,6 +40,11 @@ class WebFluxStreamableIntegrationTests extends AbstractMcpClientServerIntegrati
3840

3941
private WebFluxStreamableServerTransportProvider mcpStreamableServerTransportProvider;
4042

43+
static McpTransportContextExtractor<ServerRequest> TEST_CONTEXT_EXTRACTOR = (r, tc) -> {
44+
tc.put("important", "value");
45+
return tc;
46+
};
47+
4148
@Override
4249
protected void prepareClients(int port, String mcpEndpoint) {
4350

@@ -71,6 +78,7 @@ public void before() {
7178
this.mcpStreamableServerTransportProvider = WebFluxStreamableServerTransportProvider.builder()
7279
.objectMapper(new ObjectMapper())
7380
.messageEndpoint(CUSTOM_MESSAGE_ENDPOINT)
81+
.contextExtractor(TEST_CONTEXT_EXTRACTOR)
7482
.build();
7583

7684
HttpHandler httpHandler = RouterFunctions

mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java

Lines changed: 57 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,10 @@
1313

1414
import com.fasterxml.jackson.core.type.TypeReference;
1515
import com.fasterxml.jackson.databind.ObjectMapper;
16+
17+
import io.modelcontextprotocol.server.DefaultMcpTransportContext;
18+
import io.modelcontextprotocol.server.McpTransportContext;
19+
import io.modelcontextprotocol.server.McpTransportContextExtractor;
1620
import io.modelcontextprotocol.spec.McpError;
1721
import io.modelcontextprotocol.spec.McpSchema;
1822
import io.modelcontextprotocol.spec.McpServerTransport;
@@ -106,6 +110,8 @@ public class WebMvcSseServerTransportProvider implements McpServerTransportProvi
106110
*/
107111
private final ConcurrentHashMap<String, McpServerSession> sessions = new ConcurrentHashMap<>();
108112

113+
private McpTransportContextExtractor<ServerRequest> contextExtractor;
114+
109115
/**
110116
* Flag indicating if the transport is shutting down.
111117
*/
@@ -177,23 +183,47 @@ public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String baseUr
177183
* messages via HTTP POST. This endpoint will be communicated to clients through the
178184
* SSE connection's initial endpoint event.
179185
* @param sseEndpoint The endpoint URI where clients establish their SSE connections.
180-
* * @param keepAliveInterval The interval for sending keep-alive messages to
186+
* @param keepAliveInterval The interval for sending keep-alive messages to clients.
181187
* @throws IllegalArgumentException if any parameter is null
182188
* @deprecated Use the builder {@link #builder()} instead for better configuration
183189
* options.
184190
*/
185191
@Deprecated
186192
public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint,
187193
String sseEndpoint, Duration keepAliveInterval) {
194+
this(objectMapper, baseUrl, messageEndpoint, sseEndpoint, keepAliveInterval,
195+
(serverRequest, context) -> context);
196+
}
197+
198+
/**
199+
* Constructs a new WebMvcSseServerTransportProvider instance.
200+
* @param objectMapper The ObjectMapper to use for JSON serialization/deserialization
201+
* of messages.
202+
* @param baseUrl The base URL for the message endpoint, used to construct the full
203+
* endpoint URL for clients.
204+
* @param messageEndpoint The endpoint URI where clients should send their JSON-RPC
205+
* messages via HTTP POST. This endpoint will be communicated to clients through the
206+
* SSE connection's initial endpoint event.
207+
* @param sseEndpoint The endpoint URI where clients establish their SSE connections.
208+
* @param keepAliveInterval The interval for sending keep-alive messages to clients.
209+
* @param contextExtractor The contextExtractor to fill in a
210+
* {@link McpTransportContext}.
211+
* @throws IllegalArgumentException if any parameter is null
212+
*/
213+
private WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint,
214+
String sseEndpoint, Duration keepAliveInterval,
215+
McpTransportContextExtractor<ServerRequest> contextExtractor) {
188216
Assert.notNull(objectMapper, "ObjectMapper must not be null");
189217
Assert.notNull(baseUrl, "Message base URL must not be null");
190218
Assert.notNull(messageEndpoint, "Message endpoint must not be null");
191219
Assert.notNull(sseEndpoint, "SSE endpoint must not be null");
220+
Assert.notNull(contextExtractor, "Context extractor must not be null");
192221

193222
this.objectMapper = objectMapper;
194223
this.baseUrl = baseUrl;
195224
this.messageEndpoint = messageEndpoint;
196225
this.sseEndpoint = sseEndpoint;
226+
this.contextExtractor = contextExtractor;
197227
this.routerFunction = RouterFunctions.route()
198228
.GET(this.sseEndpoint, this::handleSseConnection)
199229
.POST(this.messageEndpoint, this::handleMessage)
@@ -367,11 +397,17 @@ private ServerResponse handleMessage(ServerRequest request) {
367397
}
368398

369399
try {
400+
final McpTransportContext transportContext = this.contextExtractor.extract(request,
401+
new DefaultMcpTransportContext());
402+
370403
String body = request.body(String.class);
371404
McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, body);
372405

373406
// Process the message through the session's handle method
374-
session.handle(message).block(); // Block for WebMVC compatibility
407+
session.handle(message).contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)).block(); // Block
408+
// for
409+
// WebMVC
410+
// compatibility
375411

376412
return ServerResponse.ok().build();
377413
}
@@ -517,6 +553,8 @@ public static class Builder {
517553

518554
private Duration keepAliveInterval;
519555

556+
private McpTransportContextExtractor<ServerRequest> contextExtractor = (serverRequest, context) -> context;
557+
520558
/**
521559
* Sets the JSON object mapper to use for message serialization/deserialization.
522560
* @param objectMapper The object mapper to use
@@ -576,6 +614,22 @@ public Builder keepAliveInterval(Duration keepAliveInterval) {
576614
return this;
577615
}
578616

617+
/**
618+
* Sets the context extractor that allows providing the MCP feature
619+
* implementations to inspect HTTP transport level metadata that was present at
620+
* HTTP request processing time. This allows to extract custom headers and other
621+
* useful data for use during execution later on in the process.
622+
* @param contextExtractor The contextExtractor to fill in a
623+
* {@link McpTransportContext}.
624+
* @return this builder instance
625+
* @throws IllegalArgumentException if contextExtractor is null
626+
*/
627+
public Builder contextExtractor(McpTransportContextExtractor<ServerRequest> contextExtractor) {
628+
Assert.notNull(contextExtractor, "contextExtractor must not be null");
629+
this.contextExtractor = contextExtractor;
630+
return this;
631+
}
632+
579633
/**
580634
* Builds a new instance of WebMvcSseServerTransportProvider with the configured
581635
* settings.
@@ -587,7 +641,7 @@ public WebMvcSseServerTransportProvider build() {
587641
throw new IllegalStateException("MessageEndpoint must be set");
588642
}
589643
return new WebMvcSseServerTransportProvider(objectMapper, baseUrl, messageEndpoint, sseEndpoint,
590-
keepAliveInterval);
644+
keepAliveInterval, contextExtractor);
591645
}
592646

593647
}

mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import org.springframework.web.reactive.function.client.WebClient;
1818
import org.springframework.web.servlet.config.annotation.EnableWebMvc;
1919
import org.springframework.web.servlet.function.RouterFunction;
20+
import org.springframework.web.servlet.function.ServerRequest;
2021
import org.springframework.web.servlet.function.ServerResponse;
2122

2223
import com.fasterxml.jackson.databind.ObjectMapper;
@@ -39,6 +40,11 @@ class WebMvcSseIntegrationTests extends AbstractMcpClientServerIntegrationTests
3940

4041
private WebMvcSseServerTransportProvider mcpServerTransportProvider;
4142

43+
static McpTransportContextExtractor<ServerRequest> TEST_CONTEXT_EXTRACTOR = (r, tc) -> {
44+
tc.put("important", "value");
45+
return tc;
46+
};
47+
4248
@Override
4349
protected void prepareClients(int port, String mcpEndpoint) {
4450

@@ -60,6 +66,7 @@ public WebMvcSseServerTransportProvider webMvcSseServerTransportProvider() {
6066
return WebMvcSseServerTransportProvider.builder()
6167
.objectMapper(new ObjectMapper())
6268
.messageEndpoint(MESSAGE_ENDPOINT)
69+
.contextExtractor(TEST_CONTEXT_EXTRACTOR)
6370
.build();
6471
}
6572

0 commit comments

Comments
 (0)