Skip to content

Commit 022deb2

Browse files
committed
address review comments
1 parent 753347d commit 022deb2

File tree

4 files changed

+93
-14
lines changed

4 files changed

+93
-14
lines changed

mcp/client.go

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111
"iter"
1212
"slices"
1313
"sync"
14+
"sync/atomic"
1415
"time"
1516

1617
"github.com/google/jsonschema-go/jsonschema"
@@ -177,7 +178,11 @@ func (c *Client) Connect(ctx context.Context, t Transport, _ *ClientSessionOptio
177178
// Call [ClientSession.Close] to close the connection, or await server
178179
// termination with [ClientSession.Wait].
179180
type ClientSession struct {
180-
onClose func()
181+
// Ensure that onClose is called at most once.
182+
// We defensively use an atomic CompareAndSwap rather than a sync.Once, in case the
183+
// onClose callback triggers a re-entrant call to Close.
184+
calledOnClose atomic.Bool
185+
onClose func()
181186

182187
conn *jsonrpc2.Connection
183188
client *Client
@@ -205,6 +210,8 @@ func (cs *ClientSession) ID() string {
205210
// Close performs a graceful close of the connection, preventing new requests
206211
// from being handled, and waiting for ongoing requests to return. Close then
207212
// terminates the connection.
213+
//
214+
// Close is idempotent and concurrency safe.
208215
func (cs *ClientSession) Close() error {
209216
// Note: keepaliveCancel access is safe without a mutex because:
210217
// 1. keepaliveCancel is only written once during startKeepalive (happens-before all Close calls)
@@ -216,7 +223,7 @@ func (cs *ClientSession) Close() error {
216223
}
217224
err := cs.conn.Close()
218225

219-
if cs.onClose != nil {
226+
if cs.onClose != nil && cs.calledOnClose.CompareAndSwap(false, true) {
220227
cs.onClose()
221228
}
222229

mcp/server.go

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ import (
1919
"reflect"
2020
"slices"
2121
"sync"
22+
"sync/atomic"
2223
"time"
2324

2425
"github.com/google/jsonschema-go/jsonschema"
@@ -920,7 +921,11 @@ func newServerRequest[P Params](ss *ServerSession, params P) *ServerRequest[P] {
920921
// Call [ServerSession.Close] to close the connection, or await client
921922
// termination with [ServerSession.Wait].
922923
type ServerSession struct {
923-
onClose func()
924+
// Ensure that onClose is called at most once.
925+
// We defensively use an atomic CompareAndSwap rather than a sync.Once, in case the
926+
// onClose callback triggers a re-entrant call to Close.
927+
calledOnClose atomic.Bool
928+
onClose func()
924929

925930
server *Server
926931
conn *jsonrpc2.Connection
@@ -1185,6 +1190,8 @@ func (ss *ServerSession) setLevel(_ context.Context, params *SetLoggingLevelPara
11851190
// Close performs a graceful shutdown of the connection, preventing new
11861191
// requests from being handled, and waiting for ongoing requests to return.
11871192
// Close then terminates the connection.
1193+
//
1194+
// Close is idempotent and concurrency safe.
11881195
func (ss *ServerSession) Close() error {
11891196
if ss.keepaliveCancel != nil {
11901197
// Note: keepaliveCancel access is safe without a mutex because:
@@ -1196,7 +1203,7 @@ func (ss *ServerSession) Close() error {
11961203
}
11971204
err := ss.conn.Close()
11981205

1199-
if ss.onClose != nil {
1206+
if ss.onClose != nil && ss.calledOnClose.CompareAndSwap(false, true) {
12001207
ss.onClose()
12011208
}
12021209

mcp/streamable.go

Lines changed: 40 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -54,11 +54,14 @@ type sessionInfo struct {
5454
// If timeout is set, automatically close the session after an idle period.
5555
timeout time.Duration
5656
timerMu sync.Mutex
57+
refs int // reference count
5758
timer *time.Timer
5859
}
5960

60-
// resetTimeout resets the inactivity timer.
61-
func (i *sessionInfo) resetTimeout() {
61+
// startPOST signals that a POST request for this session is starting (which
62+
// carries a client->server message), stopping the session timeout if it was
63+
// running.
64+
func (i *sessionInfo) startPOST() {
6265
if i.timeout <= 0 {
6366
return
6467
}
@@ -67,13 +70,36 @@ func (i *sessionInfo) resetTimeout() {
6770
defer i.timerMu.Unlock()
6871

6972
if i.timer == nil {
73+
return // timer stopped permanently
74+
}
75+
if i.refs == 0 {
76+
i.timer.Stop()
77+
}
78+
i.refs++
79+
}
80+
81+
// endPOST sigals that a request for this session is ending, starting the
82+
// timeout if there are no other requests running.
83+
func (i *sessionInfo) endPOST() {
84+
if i.timeout <= 0 {
7085
return
7186
}
72-
// Reset the timer if we successfully stopped it.
73-
i.timer.Reset(i.timeout)
87+
88+
i.timerMu.Lock()
89+
defer i.timerMu.Unlock()
90+
91+
if i.timer == nil {
92+
return // timer stopped permanently
93+
}
94+
95+
i.refs--
96+
assert(i.refs >= 0, "negative ref count")
97+
if i.refs == 0 {
98+
i.timer.Reset(i.timeout)
99+
}
74100
}
75101

76-
// stopTimer stops the inactivity timer.
102+
// stopTimer stops the inactivity timer permanently.
77103
func (i *sessionInfo) stopTimer() {
78104
i.timerMu.Lock()
79105
defer i.timerMu.Unlock()
@@ -374,13 +400,17 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque
374400
session: session,
375401
transport: transport,
376402
}
403+
377404
if h.opts.Stateless {
378405
// Stateless mode: close the session when the request exits.
379406
defer session.Close() // close the fake session after handling the request
380407
} else {
381408
// Otherwise, save the transport so that it can be reused
382409

383410
// Clean up the session when it times out.
411+
//
412+
// Note that the timer here may fire multiple times, but
413+
// sessInfo.session.Close is idempotent.
384414
if h.opts.SessionTimeout > 0 {
385415
sessInfo.timeout = h.opts.SessionTimeout
386416
sessInfo.timer = time.AfterFunc(sessInfo.timeout, func() {
@@ -393,7 +423,11 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque
393423
}
394424
}
395425

396-
sessInfo.resetTimeout()
426+
if req.Method == http.MethodPost {
427+
sessInfo.startPOST()
428+
defer sessInfo.endPOST()
429+
}
430+
397431
sessInfo.transport.ServeHTTP(w, req)
398432
}
399433

mcp/streamable_client_test.go

Lines changed: 35 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@ type streamableResponse struct {
3434
body string // or ""
3535
optional bool // if set, request need not be sent
3636
wantProtocolVersion string // if "", unchecked
37-
callback func() // if set, called after the request is handled
3837
}
3938

4039
type fakeResponses map[streamableRequestKey]*streamableResponse
@@ -96,9 +95,6 @@ func (s *fakeStreamableServer) ServeHTTP(w http.ResponseWriter, req *http.Reques
9695
http.Error(w, "no response", http.StatusInternalServerError)
9796
return
9897
}
99-
if resp.callback != nil {
100-
defer resp.callback()
101-
}
10298
for k, v := range resp.header {
10399
w.Header().Set(k, v)
104100
}
@@ -411,3 +407,38 @@ func TestStreamableClientStrictness(t *testing.T) {
411407
})
412408
}
413409
}
410+
411+
func TestStreamableClientUnresumableRequest(t *testing.T) {
412+
// This test verifies that the client fails fast when making a request that
413+
// is unresumable, because it does not contain any events.
414+
ctx := context.Background()
415+
fake := &fakeStreamableServer{
416+
t: t,
417+
responses: fakeResponses{
418+
{"POST", "", methodInitialize}: {
419+
header: header{
420+
"Content-Type": "text/event-stream",
421+
sessionIDHeader: "123",
422+
},
423+
body: "",
424+
},
425+
{"DELETE", "123", ""}: {optional: true},
426+
},
427+
}
428+
httpServer := httptest.NewServer(fake)
429+
defer httpServer.Close()
430+
431+
transport := &StreamableClientTransport{Endpoint: httpServer.URL}
432+
client := NewClient(testImpl, nil)
433+
cs, err := client.Connect(ctx, transport, nil)
434+
if err == nil {
435+
cs.Close()
436+
t.Fatalf("Connect succeeded unexpectedly")
437+
}
438+
// This may be a bit of a change detector, but for now check that we're
439+
// actually exercising the early failure codepath.
440+
msg := "terminated without response"
441+
if !strings.Contains(err.Error(), msg) {
442+
t.Errorf("Connect: got error %v, want containing %q", err, msg)
443+
}
444+
}

0 commit comments

Comments
 (0)