Skip to content

Commit 753347d

Browse files
committed
mcp: add StreamableHTTPOptions.SessionTimeout
Add a timeout option for the streamable HTTP handler that automatically cleans up idle sessions. Also, fix a bug in the streamable client, where we hang on a request even though the client can never get a response (because the HTTP request terminated without a response or Last-Event-Id). Fixes #499
1 parent cfa7a51 commit 753347d

File tree

3 files changed

+187
-43
lines changed

3 files changed

+187
-43
lines changed

mcp/server.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -825,7 +825,7 @@ func (s *Server) disconnect(cc *ServerSession) {
825825
type ServerSessionOptions struct {
826826
State *ServerSessionState
827827

828-
onClose func()
828+
onClose func() // used to clean up associated resources
829829
}
830830

831831
// Connect connects the MCP server over the given transport and starts handling

mcp/streamable.go

Lines changed: 121 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import (
1212
"fmt"
1313
"io"
1414
"log/slog"
15+
"maps"
1516
"math"
1617
"math/rand/v2"
1718
"net/http"
@@ -40,12 +41,46 @@ type StreamableHTTPHandler struct {
4041
getServer func(*http.Request) *Server
4142
opts StreamableHTTPOptions
4243

43-
onTransportDeletion func(sessionID string) // for testing only
44+
onTransportDeletion func(sessionID string) // for testing
4445

45-
mu sync.Mutex
46-
// TODO: we should store the ServerSession along with the transport, because
47-
// we need to cancel keepalive requests when closing the transport.
48-
transports map[string]*StreamableServerTransport // keyed by IDs (from Mcp-Session-Id header)
46+
mu sync.Mutex
47+
sessions map[string]*sessionInfo // keyed by session ID
48+
}
49+
50+
type sessionInfo struct {
51+
session *ServerSession
52+
transport *StreamableServerTransport
53+
54+
// If timeout is set, automatically close the session after an idle period.
55+
timeout time.Duration
56+
timerMu sync.Mutex
57+
timer *time.Timer
58+
}
59+
60+
// resetTimeout resets the inactivity timer.
61+
func (i *sessionInfo) resetTimeout() {
62+
if i.timeout <= 0 {
63+
return
64+
}
65+
66+
i.timerMu.Lock()
67+
defer i.timerMu.Unlock()
68+
69+
if i.timer == nil {
70+
return
71+
}
72+
// Reset the timer if we successfully stopped it.
73+
i.timer.Reset(i.timeout)
74+
}
75+
76+
// stopTimer stops the inactivity timer.
77+
func (i *sessionInfo) stopTimer() {
78+
i.timerMu.Lock()
79+
defer i.timerMu.Unlock()
80+
if i.timer != nil {
81+
i.timer.Stop()
82+
i.timer = nil
83+
}
4984
}
5085

5186
// StreamableHTTPOptions configures the StreamableHTTPHandler.
@@ -77,6 +112,14 @@ type StreamableHTTPOptions struct {
77112
// If set, EventStore will be used to persist stream events and replay them
78113
// upon stream resumption.
79114
EventStore EventStore
115+
116+
// SessionTimeout configures a timeout for idle sessions.
117+
//
118+
// When sessions receive no new HTTP requests from the client for this
119+
// duration, they are automatically closed.
120+
//
121+
// If SessionTimeout is the zero value, idle sessions are never closed.
122+
SessionTimeout time.Duration
80123
}
81124

82125
// NewStreamableHTTPHandler returns a new [StreamableHTTPHandler].
@@ -86,8 +129,8 @@ type StreamableHTTPOptions struct {
86129
// If getServer returns nil, a 400 Bad Request will be served.
87130
func NewStreamableHTTPHandler(getServer func(*http.Request) *Server, opts *StreamableHTTPOptions) *StreamableHTTPHandler {
88131
h := &StreamableHTTPHandler{
89-
getServer: getServer,
90-
transports: make(map[string]*StreamableServerTransport),
132+
getServer: getServer,
133+
sessions: make(map[string]*sessionInfo),
91134
}
92135
if opts != nil {
93136
h.opts = *opts
@@ -100,20 +143,27 @@ func NewStreamableHTTPHandler(getServer func(*http.Request) *Server, opts *Strea
100143
return h
101144
}
102145

103-
// closeAll closes all ongoing sessions.
146+
// closeAll closes all ongoing sessions, for tests.
104147
//
105148
// TODO(rfindley): investigate the best API for callers to configure their
106149
// session lifecycle. (?)
107150
//
108151
// Should we allow passing in a session store? That would allow the handler to
109152
// be stateless.
110153
func (h *StreamableHTTPHandler) closeAll() {
154+
// TODO: if we ever expose this outside of tests, we'll need to do better
155+
// than simply collecting sessions while holding the lock: we need to prevent
156+
// new sessions from being added.
157+
//
158+
// Currently, sessions remove themselves from h.sessions when closed, so we
159+
// can't call Close while holding the lock.
111160
h.mu.Lock()
112-
defer h.mu.Unlock()
113-
for _, s := range h.transports {
114-
s.connection.Close()
161+
sessionInfos := slices.Collect(maps.Values(h.sessions))
162+
h.sessions = nil
163+
h.mu.Unlock()
164+
for _, s := range sessionInfos {
165+
s.session.Close()
115166
}
116-
h.transports = nil
117167
}
118168

119169
func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
@@ -144,12 +194,12 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque
144194
}
145195

146196
sessionID := req.Header.Get(sessionIDHeader)
147-
var transport *StreamableServerTransport
197+
var sessInfo *sessionInfo
148198
if sessionID != "" {
149199
h.mu.Lock()
150-
transport = h.transports[sessionID]
200+
sessInfo = h.sessions[sessionID]
151201
h.mu.Unlock()
152-
if transport == nil && !h.opts.Stateless {
202+
if sessInfo == nil && !h.opts.Stateless {
153203
// Unless we're in 'stateless' mode, which doesn't perform any Session-ID
154204
// validation, we require that the session ID matches a known session.
155205
//
@@ -164,11 +214,10 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque
164214
http.Error(w, "Bad Request: DELETE requires an Mcp-Session-Id header", http.StatusBadRequest)
165215
return
166216
}
167-
if transport != nil { // transport may be nil in stateless mode
168-
h.mu.Lock()
169-
delete(h.transports, transport.SessionID)
170-
h.mu.Unlock()
171-
transport.connection.Close()
217+
if sessInfo != nil { // sessInfo may be nil in stateless mode
218+
// Closing the session also removes it from h.sessions, due to the
219+
// onClose callback.
220+
sessInfo.session.Close()
172221
}
173222
w.WriteHeader(http.StatusNoContent)
174223
return
@@ -225,7 +274,7 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque
225274
return
226275
}
227276

228-
if transport == nil {
277+
if sessInfo == nil {
229278
server := h.getServer(req)
230279
if server == nil {
231280
// The getServer argument to NewStreamableHTTPHandler returned nil.
@@ -237,7 +286,7 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque
237286
// existing transport.
238287
sessionID = server.opts.GetSessionID()
239288
}
240-
transport = &StreamableServerTransport{
289+
transport := &StreamableServerTransport{
241290
SessionID: sessionID,
242291
Stateless: h.opts.Stateless,
243292
EventStore: h.opts.EventStore,
@@ -301,10 +350,13 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque
301350
connectOpts = &ServerSessionOptions{
302351
onClose: func() {
303352
h.mu.Lock()
304-
delete(h.transports, transport.SessionID)
305-
h.mu.Unlock()
306-
if h.onTransportDeletion != nil {
307-
h.onTransportDeletion(transport.SessionID)
353+
defer h.mu.Unlock()
354+
if info, ok := h.sessions[transport.SessionID]; ok {
355+
info.stopTimer()
356+
delete(h.sessions, transport.SessionID)
357+
if h.onTransportDeletion != nil {
358+
h.onTransportDeletion(transport.SessionID)
359+
}
308360
}
309361
},
310362
}
@@ -313,23 +365,36 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque
313365
// Pass req.Context() here, to allow middleware to add context values.
314366
// The context is detached in the jsonrpc2 library when handling the
315367
// long-running stream.
316-
ss, err := server.Connect(req.Context(), transport, connectOpts)
368+
session, err := server.Connect(req.Context(), transport, connectOpts)
317369
if err != nil {
318370
http.Error(w, "failed connection", http.StatusInternalServerError)
319371
return
320372
}
373+
sessInfo = &sessionInfo{
374+
session: session,
375+
transport: transport,
376+
}
321377
if h.opts.Stateless {
322378
// Stateless mode: close the session when the request exits.
323-
defer ss.Close() // close the fake session after handling the request
379+
defer session.Close() // close the fake session after handling the request
324380
} else {
325381
// Otherwise, save the transport so that it can be reused
382+
383+
// Clean up the session when it times out.
384+
if h.opts.SessionTimeout > 0 {
385+
sessInfo.timeout = h.opts.SessionTimeout
386+
sessInfo.timer = time.AfterFunc(sessInfo.timeout, func() {
387+
sessInfo.session.Close()
388+
})
389+
}
326390
h.mu.Lock()
327-
h.transports[transport.SessionID] = transport
391+
h.sessions[transport.SessionID] = sessInfo
328392
h.mu.Unlock()
329393
}
330394
}
331395

332-
transport.ServeHTTP(w, req)
396+
sessInfo.resetTimeout()
397+
sessInfo.transport.ServeHTTP(w, req)
333398
}
334399

335400
// A StreamableServerTransport implements the server side of the MCP streamable
@@ -1383,9 +1448,12 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e
13831448
go c.handleJSON(requestSummary, resp)
13841449

13851450
case "text/event-stream":
1386-
jsonReq, _ := msg.(*jsonrpc.Request)
1451+
var forCall *jsonrpc.Request
1452+
if jsonReq, ok := msg.(*jsonrpc.Request); ok && jsonReq.IsCall() {
1453+
forCall = jsonReq
1454+
}
13871455
// TODO: should we cancel this logical SSE request if/when jsonReq is canceled?
1388-
go c.handleSSE(requestSummary, resp, false, jsonReq)
1456+
go c.handleSSE(requestSummary, resp, false, forCall)
13891457

13901458
default:
13911459
resp.Body.Close()
@@ -1435,9 +1503,9 @@ func (c *streamableClientConn) handleJSON(requestSummary string, resp *http.Resp
14351503
// handleSSE manages the lifecycle of an SSE connection. It can be either
14361504
// persistent (for the main GET listener) or temporary (for a POST response).
14371505
//
1438-
// If forReq is set, it is the request that initiated the stream, and the
1506+
// If forCall is set, it is the call that initiated the stream, and the
14391507
// stream is complete when we receive its response.
1440-
func (c *streamableClientConn) handleSSE(requestSummary string, initialResp *http.Response, persistent bool, forReq *jsonrpc2.Request) {
1508+
func (c *streamableClientConn) handleSSE(requestSummary string, initialResp *http.Response, persistent bool, forCall *jsonrpc2.Request) {
14411509
resp := initialResp
14421510
var lastEventID string
14431511
for {
@@ -1447,7 +1515,7 @@ func (c *streamableClientConn) handleSSE(requestSummary string, initialResp *htt
14471515
// Eventually, if we don't get the response, we should stop trying and
14481516
// fail the request.
14491517
if resp != nil {
1450-
eventID, clientClosed := c.processStream(requestSummary, resp, forReq)
1518+
eventID, clientClosed := c.processStream(requestSummary, resp, forCall)
14511519
lastEventID = eventID
14521520

14531521
// If the connection was closed by the client, we're done.
@@ -1510,11 +1578,11 @@ func (c *streamableClientConn) handleSSE(requestSummary string, initialResp *htt
15101578
// incoming channel. It returns the ID of the last processed event and a flag
15111579
// indicating if the connection was closed by the client. If resp is nil, it
15121580
// returns "", false.
1513-
func (c *streamableClientConn) processStream(requestSummary string, resp *http.Response, forReq *jsonrpc.Request) (lastEventID string, clientClosed bool) {
1581+
func (c *streamableClientConn) processStream(requestSummary string, resp *http.Response, forCall *jsonrpc.Request) (lastEventID string, clientClosed bool) {
15141582
defer resp.Body.Close()
15151583
for evt, err := range scanEvents(resp.Body) {
15161584
if err != nil {
1517-
return lastEventID, false
1585+
break
15181586
}
15191587

15201588
if evt.ID != "" {
@@ -1529,10 +1597,10 @@ func (c *streamableClientConn) processStream(requestSummary string, resp *http.R
15291597

15301598
select {
15311599
case c.incoming <- msg:
1532-
if jsonResp, ok := msg.(*jsonrpc.Response); ok && forReq != nil {
1600+
if jsonResp, ok := msg.(*jsonrpc.Response); ok && forCall != nil {
15331601
// TODO: we should never get a response when forReq is nil (the standalone SSE request).
15341602
// We should detect this case.
1535-
if jsonResp.ID == forReq.ID {
1603+
if jsonResp.ID == forCall.ID {
15361604
return "", true
15371605
}
15381606
}
@@ -1542,7 +1610,20 @@ func (c *streamableClientConn) processStream(requestSummary string, resp *http.R
15421610
}
15431611
}
15441612
// The loop finished without an error, indicating the server closed the stream.
1545-
return "", false
1613+
//
1614+
// If the lastEventID is "", the stream is not retryable and we should
1615+
// report a synthetic error for the call.
1616+
if lastEventID == "" && forCall != nil {
1617+
errmsg := &jsonrpc2.Response{
1618+
ID: forCall.ID,
1619+
Error: fmt.Errorf("request terminated without response"),
1620+
}
1621+
select {
1622+
case c.incoming <- errmsg:
1623+
case <-c.done:
1624+
}
1625+
}
1626+
return lastEventID, false
15461627
}
15471628

15481629
// reconnect handles the logic of retrying a connection with an exponential

0 commit comments

Comments
 (0)