@@ -12,6 +12,7 @@ import (
1212 "fmt"
1313 "io"
1414 "log/slog"
15+ "maps"
1516 "math"
1617 "math/rand/v2"
1718 "net/http"
@@ -40,12 +41,46 @@ type StreamableHTTPHandler struct {
4041 getServer func (* http.Request ) * Server
4142 opts StreamableHTTPOptions
4243
43- onTransportDeletion func (sessionID string ) // for testing only
44+ onTransportDeletion func (sessionID string ) // for testing
4445
45- mu sync.Mutex
46- // TODO: we should store the ServerSession along with the transport, because
47- // we need to cancel keepalive requests when closing the transport.
48- transports map [string ]* StreamableServerTransport // keyed by IDs (from Mcp-Session-Id header)
46+ mu sync.Mutex
47+ sessions map [string ]* sessionInfo // keyed by session ID
48+ }
49+
50+ type sessionInfo struct {
51+ session * ServerSession
52+ transport * StreamableServerTransport
53+
54+ // If timeout is set, automatically close the session after an idle period.
55+ timeout time.Duration
56+ timerMu sync.Mutex
57+ timer * time.Timer
58+ }
59+
60+ // resetTimeout resets the inactivity timer.
61+ func (i * sessionInfo ) resetTimeout () {
62+ if i .timeout <= 0 {
63+ return
64+ }
65+
66+ i .timerMu .Lock ()
67+ defer i .timerMu .Unlock ()
68+
69+ if i .timer == nil {
70+ return
71+ }
72+ // Reset the timer if we successfully stopped it.
73+ i .timer .Reset (i .timeout )
74+ }
75+
76+ // stopTimer stops the inactivity timer.
77+ func (i * sessionInfo ) stopTimer () {
78+ i .timerMu .Lock ()
79+ defer i .timerMu .Unlock ()
80+ if i .timer != nil {
81+ i .timer .Stop ()
82+ i .timer = nil
83+ }
4984}
5085
5186// StreamableHTTPOptions configures the StreamableHTTPHandler.
@@ -77,6 +112,14 @@ type StreamableHTTPOptions struct {
77112 // If set, EventStore will be used to persist stream events and replay them
78113 // upon stream resumption.
79114 EventStore EventStore
115+
116+ // SessionTimeout configures a timeout for idle sessions.
117+ //
118+ // When sessions receive no new HTTP requests from the client for this
119+ // duration, they are automatically closed.
120+ //
121+ // If SessionTimeout is the zero value, idle sessions are never closed.
122+ SessionTimeout time.Duration
80123}
81124
82125// NewStreamableHTTPHandler returns a new [StreamableHTTPHandler].
@@ -86,8 +129,8 @@ type StreamableHTTPOptions struct {
86129// If getServer returns nil, a 400 Bad Request will be served.
87130func NewStreamableHTTPHandler (getServer func (* http.Request ) * Server , opts * StreamableHTTPOptions ) * StreamableHTTPHandler {
88131 h := & StreamableHTTPHandler {
89- getServer : getServer ,
90- transports : make (map [string ]* StreamableServerTransport ),
132+ getServer : getServer ,
133+ sessions : make (map [string ]* sessionInfo ),
91134 }
92135 if opts != nil {
93136 h .opts = * opts
@@ -100,20 +143,27 @@ func NewStreamableHTTPHandler(getServer func(*http.Request) *Server, opts *Strea
100143 return h
101144}
102145
103- // closeAll closes all ongoing sessions.
146+ // closeAll closes all ongoing sessions, for tests .
104147//
105148// TODO(rfindley): investigate the best API for callers to configure their
106149// session lifecycle. (?)
107150//
108151// Should we allow passing in a session store? That would allow the handler to
109152// be stateless.
110153func (h * StreamableHTTPHandler ) closeAll () {
154+ // TODO: if we ever expose this outside of tests, we'll need to do better
155+ // than simply collecting sessions while holding the lock: we need to prevent
156+ // new sessions from being added.
157+ //
158+ // Currently, sessions remove themselves from h.sessions when closed, so we
159+ // can't call Close while holding the lock.
111160 h .mu .Lock ()
112- defer h .mu .Unlock ()
113- for _ , s := range h .transports {
114- s .connection .Close ()
161+ sessionInfos := slices .Collect (maps .Values (h .sessions ))
162+ h .sessions = nil
163+ h .mu .Unlock ()
164+ for _ , s := range sessionInfos {
165+ s .session .Close ()
115166 }
116- h .transports = nil
117167}
118168
119169func (h * StreamableHTTPHandler ) ServeHTTP (w http.ResponseWriter , req * http.Request ) {
@@ -144,12 +194,12 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque
144194 }
145195
146196 sessionID := req .Header .Get (sessionIDHeader )
147- var transport * StreamableServerTransport
197+ var sessInfo * sessionInfo
148198 if sessionID != "" {
149199 h .mu .Lock ()
150- transport = h .transports [sessionID ]
200+ sessInfo = h .sessions [sessionID ]
151201 h .mu .Unlock ()
152- if transport == nil && ! h .opts .Stateless {
202+ if sessInfo == nil && ! h .opts .Stateless {
153203 // Unless we're in 'stateless' mode, which doesn't perform any Session-ID
154204 // validation, we require that the session ID matches a known session.
155205 //
@@ -164,11 +214,10 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque
164214 http .Error (w , "Bad Request: DELETE requires an Mcp-Session-Id header" , http .StatusBadRequest )
165215 return
166216 }
167- if transport != nil { // transport may be nil in stateless mode
168- h .mu .Lock ()
169- delete (h .transports , transport .SessionID )
170- h .mu .Unlock ()
171- transport .connection .Close ()
217+ if sessInfo != nil { // sessInfo may be nil in stateless mode
218+ // Closing the session also removes it from h.sessions, due to the
219+ // onClose callback.
220+ sessInfo .session .Close ()
172221 }
173222 w .WriteHeader (http .StatusNoContent )
174223 return
@@ -225,7 +274,7 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque
225274 return
226275 }
227276
228- if transport == nil {
277+ if sessInfo == nil {
229278 server := h .getServer (req )
230279 if server == nil {
231280 // The getServer argument to NewStreamableHTTPHandler returned nil.
@@ -237,7 +286,7 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque
237286 // existing transport.
238287 sessionID = server .opts .GetSessionID ()
239288 }
240- transport = & StreamableServerTransport {
289+ transport : = & StreamableServerTransport {
241290 SessionID : sessionID ,
242291 Stateless : h .opts .Stateless ,
243292 EventStore : h .opts .EventStore ,
@@ -301,10 +350,13 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque
301350 connectOpts = & ServerSessionOptions {
302351 onClose : func () {
303352 h .mu .Lock ()
304- delete (h .transports , transport .SessionID )
305- h .mu .Unlock ()
306- if h .onTransportDeletion != nil {
307- h .onTransportDeletion (transport .SessionID )
353+ defer h .mu .Unlock ()
354+ if info , ok := h .sessions [transport .SessionID ]; ok {
355+ info .stopTimer ()
356+ delete (h .sessions , transport .SessionID )
357+ if h .onTransportDeletion != nil {
358+ h .onTransportDeletion (transport .SessionID )
359+ }
308360 }
309361 },
310362 }
@@ -313,23 +365,36 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque
313365 // Pass req.Context() here, to allow middleware to add context values.
314366 // The context is detached in the jsonrpc2 library when handling the
315367 // long-running stream.
316- ss , err := server .Connect (req .Context (), transport , connectOpts )
368+ session , err := server .Connect (req .Context (), transport , connectOpts )
317369 if err != nil {
318370 http .Error (w , "failed connection" , http .StatusInternalServerError )
319371 return
320372 }
373+ sessInfo = & sessionInfo {
374+ session : session ,
375+ transport : transport ,
376+ }
321377 if h .opts .Stateless {
322378 // Stateless mode: close the session when the request exits.
323- defer ss .Close () // close the fake session after handling the request
379+ defer session .Close () // close the fake session after handling the request
324380 } else {
325381 // Otherwise, save the transport so that it can be reused
382+
383+ // Clean up the session when it times out.
384+ if h .opts .SessionTimeout > 0 {
385+ sessInfo .timeout = h .opts .SessionTimeout
386+ sessInfo .timer = time .AfterFunc (sessInfo .timeout , func () {
387+ sessInfo .session .Close ()
388+ })
389+ }
326390 h .mu .Lock ()
327- h .transports [transport .SessionID ] = transport
391+ h .sessions [transport .SessionID ] = sessInfo
328392 h .mu .Unlock ()
329393 }
330394 }
331395
332- transport .ServeHTTP (w , req )
396+ sessInfo .resetTimeout ()
397+ sessInfo .transport .ServeHTTP (w , req )
333398}
334399
335400// A StreamableServerTransport implements the server side of the MCP streamable
@@ -1383,9 +1448,12 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e
13831448 go c .handleJSON (requestSummary , resp )
13841449
13851450 case "text/event-stream" :
1386- jsonReq , _ := msg .(* jsonrpc.Request )
1451+ var forCall * jsonrpc.Request
1452+ if jsonReq , ok := msg .(* jsonrpc.Request ); ok && jsonReq .IsCall () {
1453+ forCall = jsonReq
1454+ }
13871455 // TODO: should we cancel this logical SSE request if/when jsonReq is canceled?
1388- go c .handleSSE (requestSummary , resp , false , jsonReq )
1456+ go c .handleSSE (requestSummary , resp , false , forCall )
13891457
13901458 default :
13911459 resp .Body .Close ()
@@ -1435,9 +1503,9 @@ func (c *streamableClientConn) handleJSON(requestSummary string, resp *http.Resp
14351503// handleSSE manages the lifecycle of an SSE connection. It can be either
14361504// persistent (for the main GET listener) or temporary (for a POST response).
14371505//
1438- // If forReq is set, it is the request that initiated the stream, and the
1506+ // If forCall is set, it is the call that initiated the stream, and the
14391507// stream is complete when we receive its response.
1440- func (c * streamableClientConn ) handleSSE (requestSummary string , initialResp * http.Response , persistent bool , forReq * jsonrpc2.Request ) {
1508+ func (c * streamableClientConn ) handleSSE (requestSummary string , initialResp * http.Response , persistent bool , forCall * jsonrpc2.Request ) {
14411509 resp := initialResp
14421510 var lastEventID string
14431511 for {
@@ -1447,7 +1515,7 @@ func (c *streamableClientConn) handleSSE(requestSummary string, initialResp *htt
14471515 // Eventually, if we don't get the response, we should stop trying and
14481516 // fail the request.
14491517 if resp != nil {
1450- eventID , clientClosed := c .processStream (requestSummary , resp , forReq )
1518+ eventID , clientClosed := c .processStream (requestSummary , resp , forCall )
14511519 lastEventID = eventID
14521520
14531521 // If the connection was closed by the client, we're done.
@@ -1510,11 +1578,11 @@ func (c *streamableClientConn) handleSSE(requestSummary string, initialResp *htt
15101578// incoming channel. It returns the ID of the last processed event and a flag
15111579// indicating if the connection was closed by the client. If resp is nil, it
15121580// returns "", false.
1513- func (c * streamableClientConn ) processStream (requestSummary string , resp * http.Response , forReq * jsonrpc.Request ) (lastEventID string , clientClosed bool ) {
1581+ func (c * streamableClientConn ) processStream (requestSummary string , resp * http.Response , forCall * jsonrpc.Request ) (lastEventID string , clientClosed bool ) {
15141582 defer resp .Body .Close ()
15151583 for evt , err := range scanEvents (resp .Body ) {
15161584 if err != nil {
1517- return lastEventID , false
1585+ break
15181586 }
15191587
15201588 if evt .ID != "" {
@@ -1529,10 +1597,10 @@ func (c *streamableClientConn) processStream(requestSummary string, resp *http.R
15291597
15301598 select {
15311599 case c .incoming <- msg :
1532- if jsonResp , ok := msg .(* jsonrpc.Response ); ok && forReq != nil {
1600+ if jsonResp , ok := msg .(* jsonrpc.Response ); ok && forCall != nil {
15331601 // TODO: we should never get a response when forReq is nil (the standalone SSE request).
15341602 // We should detect this case.
1535- if jsonResp .ID == forReq .ID {
1603+ if jsonResp .ID == forCall .ID {
15361604 return "" , true
15371605 }
15381606 }
@@ -1542,7 +1610,20 @@ func (c *streamableClientConn) processStream(requestSummary string, resp *http.R
15421610 }
15431611 }
15441612 // The loop finished without an error, indicating the server closed the stream.
1545- return "" , false
1613+ //
1614+ // If the lastEventID is "", the stream is not retryable and we should
1615+ // report a synthetic error for the call.
1616+ if lastEventID == "" && forCall != nil {
1617+ errmsg := & jsonrpc2.Response {
1618+ ID : forCall .ID ,
1619+ Error : fmt .Errorf ("request terminated without response" ),
1620+ }
1621+ select {
1622+ case c .incoming <- errmsg :
1623+ case <- c .done :
1624+ }
1625+ }
1626+ return lastEventID , false
15461627}
15471628
15481629// reconnect handles the logic of retrying a connection with an exponential
0 commit comments