@@ -12,6 +12,7 @@ import (
1212 "fmt"
1313 "io"
1414 "log/slog"
15+ "maps"
1516 "math"
1617 "math/rand/v2"
1718 "net/http"
@@ -40,12 +41,46 @@ type StreamableHTTPHandler struct {
4041 getServer func (* http.Request ) * Server
4142 opts StreamableHTTPOptions
4243
43- onTransportDeletion func (sessionID string ) // for testing only
44+ onTransportDeletion func (sessionID string ) // for testing
4445
45- mu sync.Mutex
46- // TODO: we should store the ServerSession along with the transport, because
47- // we need to cancel keepalive requests when closing the transport.
48- transports map [string ]* StreamableServerTransport // keyed by IDs (from Mcp-Session-Id header)
46+ mu sync.Mutex
47+ sessions map [string ]* sessionInfo // keyed by session ID
48+ }
49+
50+ type sessionInfo struct {
51+ session * ServerSession
52+ transport * StreamableServerTransport
53+
54+ // If timeout is set, automatically close the session after an idle period.
55+ timeout time.Duration
56+ timerMu sync.Mutex
57+ timer * time.Timer
58+ }
59+
60+ // resetTimeout resets the inactivity timer.
61+ func (i * sessionInfo ) resetTimeout () {
62+ if i .timeout <= 0 {
63+ return
64+ }
65+
66+ i .timerMu .Lock ()
67+ defer i .timerMu .Unlock ()
68+
69+ if i .timer == nil {
70+ return
71+ }
72+ // Reset the timer if we successfully stopped it.
73+ i .timer .Reset (i .timeout )
74+ }
75+
76+ // stopTimer stops the inactivity timer.
77+ func (i * sessionInfo ) stopTimer () {
78+ i .timerMu .Lock ()
79+ defer i .timerMu .Unlock ()
80+ if i .timer != nil {
81+ i .timer .Stop ()
82+ i .timer = nil
83+ }
4984}
5085
5186// StreamableHTTPOptions configures the StreamableHTTPHandler.
@@ -72,6 +107,14 @@ type StreamableHTTPOptions struct {
72107 // If nil, do not log.
73108 Logger * slog.Logger
74109
110+ // SessionTimeout configures a timeout for idle sessions.
111+ //
112+ // When sessions receive no new HTTP requests from the client for this
113+ // duration, they are automatically closed.
114+ //
115+ // If SessionTimeout is the zero value, idle sessions are never closed.
116+ SessionTimeout time.Duration
117+
75118 // TODO(rfindley): file a proposal to export this option, or something equivalent.
76119 configureTransport func (req * http.Request , transport * StreamableServerTransport )
77120}
@@ -83,8 +126,8 @@ type StreamableHTTPOptions struct {
83126// If getServer returns nil, a 400 Bad Request will be served.
84127func NewStreamableHTTPHandler (getServer func (* http.Request ) * Server , opts * StreamableHTTPOptions ) * StreamableHTTPHandler {
85128 h := & StreamableHTTPHandler {
86- getServer : getServer ,
87- transports : make (map [string ]* StreamableServerTransport ),
129+ getServer : getServer ,
130+ sessions : make (map [string ]* sessionInfo ),
88131 }
89132 if opts != nil {
90133 h .opts = * opts
@@ -97,20 +140,27 @@ func NewStreamableHTTPHandler(getServer func(*http.Request) *Server, opts *Strea
97140 return h
98141}
99142
100- // closeAll closes all ongoing sessions.
143+ // closeAll closes all ongoing sessions, for tests .
101144//
102145// TODO(rfindley): investigate the best API for callers to configure their
103146// session lifecycle. (?)
104147//
105148// Should we allow passing in a session store? That would allow the handler to
106149// be stateless.
107150func (h * StreamableHTTPHandler ) closeAll () {
151+ // TODO: if we ever expose this outside of tests, we'll need to do better
152+ // than simply collecting sessions while holding the lock: we need to prevent
153+ // new sessions from being added.
154+ //
155+ // Currently, sessions remove themselves from h.sessions when closed, so we
156+ // can't call Close while holding the lock.
108157 h .mu .Lock ()
109- defer h .mu .Unlock ()
110- for _ , s := range h .transports {
111- s .connection .Close ()
158+ sessionInfos := slices .Collect (maps .Values (h .sessions ))
159+ h .sessions = nil
160+ h .mu .Unlock ()
161+ for _ , s := range sessionInfos {
162+ s .session .Close ()
112163 }
113- h .transports = nil
114164}
115165
116166func (h * StreamableHTTPHandler ) ServeHTTP (w http.ResponseWriter , req * http.Request ) {
@@ -141,12 +191,12 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque
141191 }
142192
143193 sessionID := req .Header .Get (sessionIDHeader )
144- var transport * StreamableServerTransport
194+ var sessInfo * sessionInfo
145195 if sessionID != "" {
146196 h .mu .Lock ()
147- transport = h .transports [sessionID ]
197+ sessInfo = h .sessions [sessionID ]
148198 h .mu .Unlock ()
149- if transport == nil && ! h .opts .Stateless {
199+ if sessInfo == nil && ! h .opts .Stateless {
150200 // Unless we're in 'stateless' mode, which doesn't perform any Session-ID
151201 // validation, we require that the session ID matches a known session.
152202 //
@@ -161,11 +211,10 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque
161211 http .Error (w , "Bad Request: DELETE requires an Mcp-Session-Id header" , http .StatusBadRequest )
162212 return
163213 }
164- if transport != nil { // transport may be nil in stateless mode
165- h .mu .Lock ()
166- delete (h .transports , transport .SessionID )
167- h .mu .Unlock ()
168- transport .connection .Close ()
214+ if sessInfo != nil { // sessInfo may be nil in stateless mode
215+ // Closing the session also removes it from h.sessions, due to the
216+ // onClose callback.
217+ sessInfo .session .Close ()
169218 }
170219 w .WriteHeader (http .StatusNoContent )
171220 return
@@ -222,7 +271,7 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque
222271 return
223272 }
224273
225- if transport == nil {
274+ if sessInfo == nil {
226275 server := h .getServer (req )
227276 if server == nil {
228277 // The getServer argument to NewStreamableHTTPHandler returned nil.
@@ -234,7 +283,7 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque
234283 // existing transport.
235284 sessionID = server .opts .GetSessionID ()
236285 }
237- transport = & StreamableServerTransport {
286+ transport : = & StreamableServerTransport {
238287 SessionID : sessionID ,
239288 Stateless : h .opts .Stateless ,
240289 jsonResponse : h .opts .JSONResponse ,
@@ -300,10 +349,13 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque
300349 connectOpts = & ServerSessionOptions {
301350 onClose : func () {
302351 h .mu .Lock ()
303- delete (h .transports , transport .SessionID )
304- h .mu .Unlock ()
305- if h .onTransportDeletion != nil {
306- h .onTransportDeletion (transport .SessionID )
352+ defer h .mu .Unlock ()
353+ if info , ok := h .sessions [transport .SessionID ]; ok {
354+ info .stopTimer ()
355+ delete (h .sessions , transport .SessionID )
356+ if h .onTransportDeletion != nil {
357+ h .onTransportDeletion (transport .SessionID )
358+ }
307359 }
308360 },
309361 }
@@ -312,23 +364,36 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque
312364 // Pass req.Context() here, to allow middleware to add context values.
313365 // The context is detached in the jsonrpc2 library when handling the
314366 // long-running stream.
315- ss , err := server .Connect (req .Context (), transport , connectOpts )
367+ session , err := server .Connect (req .Context (), transport , connectOpts )
316368 if err != nil {
317369 http .Error (w , "failed connection" , http .StatusInternalServerError )
318370 return
319371 }
372+ sessInfo = & sessionInfo {
373+ session : session ,
374+ transport : transport ,
375+ }
320376 if h .opts .Stateless {
321377 // Stateless mode: close the session when the request exits.
322- defer ss .Close () // close the fake session after handling the request
378+ defer session .Close () // close the fake session after handling the request
323379 } else {
324380 // Otherwise, save the transport so that it can be reused
381+
382+ // Clean up the session when it times out.
383+ if h .opts .SessionTimeout > 0 {
384+ sessInfo .timeout = h .opts .SessionTimeout
385+ sessInfo .timer = time .AfterFunc (sessInfo .timeout , func () {
386+ sessInfo .session .Close ()
387+ })
388+ }
325389 h .mu .Lock ()
326- h .transports [transport .SessionID ] = transport
390+ h .sessions [transport .SessionID ] = sessInfo
327391 h .mu .Unlock ()
328392 }
329393 }
330394
331- transport .ServeHTTP (w , req )
395+ sessInfo .resetTimeout ()
396+ sessInfo .transport .ServeHTTP (w , req )
332397}
333398
334399// A StreamableServerTransport implements the server side of the MCP streamable
@@ -1340,9 +1405,12 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e
13401405 go c .handleJSON (requestSummary , resp )
13411406
13421407 case "text/event-stream" :
1343- jsonReq , _ := msg .(* jsonrpc.Request )
1408+ var forCall * jsonrpc.Request
1409+ if jsonReq , ok := msg .(* jsonrpc.Request ); ok && jsonReq .IsCall () {
1410+ forCall = jsonReq
1411+ }
13441412 // TODO: should we cancel this logical SSE request if/when jsonReq is canceled?
1345- go c .handleSSE (requestSummary , resp , false , jsonReq )
1413+ go c .handleSSE (requestSummary , resp , false , forCall )
13461414
13471415 default :
13481416 resp .Body .Close ()
@@ -1392,9 +1460,9 @@ func (c *streamableClientConn) handleJSON(requestSummary string, resp *http.Resp
13921460// handleSSE manages the lifecycle of an SSE connection. It can be either
13931461// persistent (for the main GET listener) or temporary (for a POST response).
13941462//
1395- // If forReq is set, it is the request that initiated the stream, and the
1463+ // If forCall is set, it is the call that initiated the stream, and the
13961464// stream is complete when we receive its response.
1397- func (c * streamableClientConn ) handleSSE (requestSummary string , initialResp * http.Response , persistent bool , forReq * jsonrpc2.Request ) {
1465+ func (c * streamableClientConn ) handleSSE (requestSummary string , initialResp * http.Response , persistent bool , forCall * jsonrpc2.Request ) {
13981466 resp := initialResp
13991467 var lastEventID string
14001468 for {
@@ -1404,7 +1472,7 @@ func (c *streamableClientConn) handleSSE(requestSummary string, initialResp *htt
14041472 // Eventually, if we don't get the response, we should stop trying and
14051473 // fail the request.
14061474 if resp != nil {
1407- eventID , clientClosed := c .processStream (requestSummary , resp , forReq )
1475+ eventID , clientClosed := c .processStream (requestSummary , resp , forCall )
14081476 lastEventID = eventID
14091477
14101478 // If the connection was closed by the client, we're done.
@@ -1467,11 +1535,11 @@ func (c *streamableClientConn) handleSSE(requestSummary string, initialResp *htt
14671535// incoming channel. It returns the ID of the last processed event and a flag
14681536// indicating if the connection was closed by the client. If resp is nil, it
14691537// returns "", false.
1470- func (c * streamableClientConn ) processStream (requestSummary string , resp * http.Response , forReq * jsonrpc.Request ) (lastEventID string , clientClosed bool ) {
1538+ func (c * streamableClientConn ) processStream (requestSummary string , resp * http.Response , forCall * jsonrpc.Request ) (lastEventID string , clientClosed bool ) {
14711539 defer resp .Body .Close ()
14721540 for evt , err := range scanEvents (resp .Body ) {
14731541 if err != nil {
1474- return lastEventID , false
1542+ break
14751543 }
14761544
14771545 if evt .ID != "" {
@@ -1486,10 +1554,10 @@ func (c *streamableClientConn) processStream(requestSummary string, resp *http.R
14861554
14871555 select {
14881556 case c .incoming <- msg :
1489- if jsonResp , ok := msg .(* jsonrpc.Response ); ok && forReq != nil {
1557+ if jsonResp , ok := msg .(* jsonrpc.Response ); ok && forCall != nil {
14901558 // TODO: we should never get a response when forReq is nil (the standalone SSE request).
14911559 // We should detect this case.
1492- if jsonResp .ID == forReq .ID {
1560+ if jsonResp .ID == forCall .ID {
14931561 return "" , true
14941562 }
14951563 }
@@ -1499,7 +1567,20 @@ func (c *streamableClientConn) processStream(requestSummary string, resp *http.R
14991567 }
15001568 }
15011569 // The loop finished without an error, indicating the server closed the stream.
1502- return "" , false
1570+ //
1571+ // If the lastEventID is "", the stream is not retryable and we should
1572+ // report a synthetic error for the call.
1573+ if lastEventID == "" && forCall != nil {
1574+ errmsg := & jsonrpc2.Response {
1575+ ID : forCall .ID ,
1576+ Error : fmt .Errorf ("request terminated without response" ),
1577+ }
1578+ select {
1579+ case c .incoming <- errmsg :
1580+ case <- c .done :
1581+ }
1582+ }
1583+ return lastEventID , false
15031584}
15041585
15051586// reconnect handles the logic of retrying a connection with an exponential
0 commit comments