Commit 8aee8d3

mcp: clean up stream metadata after completion (#592)
Address longstanding TODOs to clean up stream metadata after streams are complete:

- Interpret a missing stream as a complete stream.
- Remove the request->stream mapping when the request completes. Any writes after that point will fail, as there can be no recipient.

Tested using examples/client/loadtest, and by updating TestStreamableTransports to check in-band and out-of-band behavior.
1 parent 8fe64fc commit 8aee8d3

File tree: 3 files changed, +121 −42 lines
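The heart of the change is the lifecycle rule stated in the new comments: a stream that is absent from the session's maps is, by definition, complete. Below is a minimal, self-contained model of that bookkeeping with hypothetical types; the real streamableServerConn in the diff tracks more state, guards its maps with a mutex, and only deletes a stream once all of its requests have been answered.

package main

import (
	"errors"
	"fmt"
)

// toySession is a hypothetical reduction of the bookkeeping: a stream
// missing from streams is treated as complete, and the request->stream
// mapping is dropped as soon as the response is written.
type toySession struct {
	streams        map[string]bool  // stream ID -> open?
	requestStreams map[int64]string // request ID -> stream ID
}

// respond writes the response for reqID, then removes both mappings so
// that any later write addressed to reqID is rejected.
func (s *toySession) respond(reqID int64) error {
	streamID, ok := s.requestStreams[reqID]
	if !ok {
		return errors.New("write to closed stream")
	}
	delete(s.requestStreams, reqID)
	delete(s.streams, streamID) // the real code waits for *all* responses on the stream
	return nil
}

func main() {
	s := &toySession{
		streams:        map[string]bool{"s1": true},
		requestStreams: map[int64]string{1: "s1"},
	}
	fmt.Println(s.respond(1)) // <nil>: first response succeeds
	fmt.Println(s.respond(1)) // write to closed stream: mapping was removed
}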

examples/client/loadtest/main.go

Lines changed: 4 additions & 1 deletion
@@ -34,6 +34,7 @@ var (
 	timeout = flag.Duration("timeout", 1*time.Second, "request timeout")
 	qps     = flag.Int("qps", 100, "tool calls per second, per worker")
 	verbose = flag.Bool("v", false, "if set, enable verbose logging")
+	cleanup = flag.Bool("cleanup", true, "whether to clean up sessions at the end of the test")
 )

 func main() {
@@ -76,7 +77,9 @@ func main() {
 	if err != nil {
 		log.Fatal(err)
 	}
-	defer cs.Close()
+	if *cleanup {
+		defer cs.Close()
+	}

 	ticker := time.NewTicker(1 * time.Second / time.Duration(*qps))
 	defer ticker.Stop()
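The new flag makes it possible to leave sessions unclosed, so that server-side cleanup (rather than client-driven teardown) can be observed under load. A plausible invocation from the repository root, assuming only the flags shown above exist:

	go run ./examples/client/loadtest -qps=50 -cleanup=false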

mcp/streamable.go

Lines changed: 71 additions & 29 deletions
@@ -450,22 +450,14 @@ type streamableServerConn struct {
 	// handled.

 	// streams holds the logical streams for this session, keyed by their ID.
-	// TODO: streams are never deleted, so the memory for a connection grows without
-	// bound. If we deleted a stream when the response is sent, we would lose the ability
-	// to replay if there was a cut just before the response was transmitted.
-	// Perhaps we could have a TTL for streams that starts just after the response.
 	//
-	// TODO(rfindley): Once all responses have been received for a stream, we can
-	// remove it as it is no longer necessary, even if the client wants to replay
-	// messages.
+	// Lifecycle: streams persist until all of their responses are received from
+	// the server.
 	streams map[string]*stream

 	// requestStreams maps incoming requests to their logical stream ID.
 	//
-	// Lifecycle: requestStreams persist for the duration of the session.
-	//
-	// TODO(rfindley): clean up once requests are handled. See the TODO for
-	// streams above.
+	// Lifecycle: requestStreams persist until their response is received.
 	requestStreams map[jsonrpc.ID]string
 }

@@ -641,17 +633,39 @@ func (c *streamableServerConn) writeEvent(w http.ResponseWriter, stream *stream,
 // all messages, so that no delivery or storage of new messages occurs while
 // the stream is still replaying.
 func (c *streamableServerConn) acquireStream(ctx context.Context, w http.ResponseWriter, streamID string, lastIdx *int) (*stream, chan struct{}) {
+	// if tempStream is set, the stream is done and we're just replaying messages.
+	//
+	// We record a temporary stream to claim exclusive replay rights.
+	tempStream := false
 	c.mu.Lock()
-	stream, ok := c.streams[streamID]
-	c.mu.Unlock()
+	s, ok := c.streams[streamID]
 	if !ok {
-		http.Error(w, "unknown stream", http.StatusBadRequest)
-		return nil, nil
+		// The stream is logically done, but claim exclusive rights to replay it by
+		// adding a temporary entry in the streams map.
+		//
+		// We create this entry with a non-nil deliver function, to ensure it isn't
+		// claimed by another request before we lock it below.
+		tempStream = true
+		s = &stream{
+			id:      streamID,
+			deliver: func([]byte, bool) error { return nil },
+		}
+		c.streams[streamID] = s
+
+		// Since this stream is transient, we must clean up after replaying.
+		defer func() {
+			c.mu.Lock()
+			delete(c.streams, streamID)
+			c.mu.Unlock()
+		}()
 	}
+	c.mu.Unlock()

-	stream.mu.Lock()
-	defer stream.mu.Unlock()
-	if stream.deliver != nil {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	// Check that this stream wasn't claimed by another request.
+	if !tempStream && s.deliver != nil {
 		http.Error(w, "stream ID conflicts with ongoing stream", http.StatusConflict)
 		return nil, nil
 	}
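The temporary-entry trick above is a general pattern: to get exclusive replay rights without holding c.mu for the whole replay, the handler publishes a placeholder under the lock and removes it in a deferred cleanup. A stripped-down sketch of the same idea, with hypothetical types:

package main

import (
	"fmt"
	"sync"
)

type registry struct {
	mu    sync.Mutex
	slots map[string]bool
}

// claim returns false if another goroutine holds the slot; otherwise it
// claims the slot and returns a release function for deferred cleanup.
func (r *registry) claim(id string) (release func(), ok bool) {
	r.mu.Lock()
	defer r.mu.Unlock()
	if r.slots[id] {
		return nil, false
	}
	r.slots[id] = true
	return func() {
		r.mu.Lock()
		defer r.mu.Unlock()
		delete(r.slots, id)
	}, true
}

func main() {
	r := &registry{slots: make(map[string]bool)}
	release, ok := r.claim("stream-1")
	fmt.Println(ok) // true: first claimant wins
	if ok {
		defer release()
	}
	_, ok = r.claim("stream-1")
	fmt.Println(ok) // false: already claimed
}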
@@ -664,7 +678,7 @@ func (c *streamableServerConn) acquireStream(ctx context.Context, w http.Respons
 	// messages, and registered our delivery function.
 	var toReplay [][]byte
 	if c.eventStore != nil {
-		for data, err := range c.eventStore.After(ctx, c.SessionID(), stream.id, *lastIdx) {
+		for data, err := range c.eventStore.After(ctx, c.SessionID(), s.id, *lastIdx) {
 			if err != nil {
 				// We can't replay events, perhaps because the underlying event store
 				// has garbage collected its storage.
@@ -685,7 +699,7 @@ func (c *streamableServerConn) acquireStream(ctx context.Context, w http.Respons
 	w.Header().Set("Content-Type", "text/event-stream") // Accept checked in [StreamableHTTPHandler]
 	w.Header().Set("Connection", "keep-alive")

-	if stream.id == "" {
+	if s.id == "" {
 		// Issue #410: the standalone SSE stream is likely not to receive messages
 		// for a long time. Ensure that headers are flushed.
 		w.WriteHeader(http.StatusOK)
@@ -695,30 +709,30 @@ func (c *streamableServerConn) acquireStream(ctx context.Context, w http.Respons
 	}

 	for _, data := range toReplay {
-		if err := c.writeEvent(w, stream, data, lastIdx); err != nil {
+		if err := c.writeEvent(w, s, data, lastIdx); err != nil {
 			return nil, nil
 		}
 	}

-	if stream.doneLocked() {
+	if tempStream || s.doneLocked() {
 		// Nothing more to do.
 		return nil, nil
 	}

-	// Finally register a delivery function and unlock the stream, allowing the
-	// connection to write new events.
+	// The stream is not done: register a delivery function before the stream is
+	// unlocked, allowing the connection to write new events.
 	done := make(chan struct{})
-	stream.deliver = func(data []byte, final bool) error {
+	s.deliver = func(data []byte, final bool) error {
 		if err := ctx.Err(); err != nil {
 			return err
 		}
-		err := c.writeEvent(w, stream, data, lastIdx)
+		err := c.writeEvent(w, s, data, lastIdx)
 		if final {
 			close(done)
 		}
 		return err
 	}
-	return stream, done
+	return s, done
 }

 // servePOST handles an incoming message, and replies with either an outgoing
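The replay loop uses a Go 1.23 range-over-func iterator: eventStore.After yields every stored event with index greater than lastIdx. A minimal in-memory stand-in with assumed types (the SDK's real EventStore takes a context, session ID, and stream ID, and can fail if its storage was garbage collected):

package main

import (
	"fmt"
	"iter"
)

type memStore struct {
	events [][]byte
}

// After yields the events with index greater than lastIdx, paired with
// a nil error, mirroring the iterator shape used in the diff above.
func (s *memStore) After(lastIdx int) iter.Seq2[[]byte, error] {
	return func(yield func([]byte, error) bool) {
		for i := lastIdx + 1; i < len(s.events); i++ {
			if !yield(s.events[i], nil) {
				return
			}
		}
	}
}

func main() {
	store := &memStore{events: [][]byte{[]byte("e0"), []byte("e1"), []byte("e2")}}
	for data, err := range store.After(0) {
		if err != nil {
			break // in the real code: fall back, storage may have been collected
		}
		fmt.Printf("replay: %s\n", data) // prints e1, then e2
	}
}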
@@ -1009,13 +1023,23 @@ func (c *streamableServerConn) Write(ctx context.Context, msg jsonrpc.Message) e
 			s = c.streams[streamID]
 		}
 	} else {
-		s = c.streams[""]
+		s = c.streams[""] // standalone SSE stream
+	}
+	if responseTo.IsValid() {
+		// Once we've responded to a request, disallow related messages by removing
+		// the stream association. This also releases memory.
+		delete(c.requestStreams, responseTo)
 	}
 	sessionClosed := c.isDone
 	c.mu.Unlock()

 	if s == nil {
-		return fmt.Errorf("%w: no stream for request", jsonrpc2.ErrRejected)
+		// The request was made in the context of an ongoing request, but that
+		// request is complete.
+		//
+		// In the future, we could be less strict and allow the request to land on
+		// the standalone SSE stream.
+		return fmt.Errorf("%w: write to closed stream", jsonrpc2.ErrRejected)
 	}
 	if sessionClosed {
 		return errors.New("session is closed")
@@ -1024,10 +1048,28 @@ func (c *streamableServerConn) Write(ctx context.Context, msg jsonrpc.Message) e
 	s.mu.Lock()
 	defer s.mu.Unlock()
 	if s.doneLocked() {
+		// It's possible that the stream was completed in between getting s above,
+		// and acquiring the stream lock. In order to avoid acquiring s.mu while
+		// holding c.mu, we check the terminal condition again.
 		return fmt.Errorf("%w: write to closed stream", jsonrpc2.ErrRejected)
 	}
+	// Perform accounting on responses.
 	if responseTo.IsValid() {
+		if _, ok := s.requests[responseTo]; !ok {
+			panic(fmt.Sprintf("internal error: stream %v: response to untracked request %v", s.id, responseTo))
+		}
+		if s.id == "" {
+			// This should be guaranteed not to happen by the stream resolution logic
+			// above, but be defensive: we don't ever want to delete the standalone
+			// stream.
+			panic("internal error: response on standalone stream")
+		}
 		delete(s.requests, responseTo)
+		if len(s.requests) == 0 {
+			c.mu.Lock()
+			delete(c.streams, s.id)
+			c.mu.Unlock()
+		}
 	}

 	delivered := false
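The doneLocked re-check exists because Write deliberately never holds c.mu and s.mu at the same time: the stream is looked up under the connection lock, that lock is released, and the terminal condition is checked again under the stream lock, since the stream may complete in the gap. A condensed sketch of that ordering, with hypothetical types:

package main

import (
	"errors"
	"fmt"
	"sync"
)

type stream struct {
	mu   sync.Mutex
	done bool
}

type conn struct {
	mu      sync.Mutex
	streams map[string]*stream
}

func (c *conn) write(id string) error {
	c.mu.Lock()
	s := c.streams[id]
	c.mu.Unlock() // release before taking s.mu, keeping lock order simple

	if s == nil {
		return errors.New("write to closed stream") // absent == complete
	}
	s.mu.Lock()
	defer s.mu.Unlock()
	// The stream may have completed between the lookup above and
	// acquiring s.mu, so re-check the terminal condition here.
	if s.done {
		return errors.New("write to closed stream")
	}
	// ... deliver the message ...
	return nil
}

func main() {
	c := &conn{streams: map[string]*stream{"s1": {}}}
	fmt.Println(c.write("s1")) // <nil>
	c.streams["s1"].done = true // single goroutine here, so no locking needed
	fmt.Println(c.write("s1")) // write to closed stream
}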

mcp/streamable_test.go

Lines changed: 46 additions & 12 deletions
@@ -73,23 +73,53 @@ func TestStreamableTransports(t *testing.T) {
 		return nil, nil, nil
 	}
 	AddTool(server, &Tool{Name: "hang"}, hang)
+	// We use sampling to test server->client requests, both before and after
+	// the related client->server request completes.
+	sampleDone := make(chan struct{})
+	var sampleWG sync.WaitGroup
 	AddTool(server, &Tool{Name: "sample"}, func(ctx context.Context, req *CallToolRequest, args any) (*CallToolResult, any, error) {
+		type testCase struct {
+			label       string
+			ctx         context.Context
+			wantSuccess bool
+		}
+		testSample := func(tc testCase) {
+			res, err := req.Session.CreateMessage(tc.ctx, &CreateMessageParams{})
+			if gotSuccess := err == nil; gotSuccess != tc.wantSuccess {
+				t.Errorf("%s: CreateMessage success=%v, want %v", tc.label, gotSuccess, tc.wantSuccess)
+			}
+			if err != nil {
+				return
+			}
+			if g, w := res.Model, "aModel"; g != w {
+				t.Errorf("%s: got model %q, want %q", tc.label, g, w)
+			}
+		}
 		// Test that we can make sampling requests during tool handling.
 		//
 		// Try this on both the request context and a background context, so
 		// that messages may be delivered on either the POST or GET connection.
-		for _, ctx := range map[string]context.Context{
-			"request context":    ctx,
-			"background context": context.Background(),
+		for _, test := range []testCase{
+			{"request context", ctx, true},
+			{"background context", context.Background(), true},
 		} {
-			res, err := req.Session.CreateMessage(ctx, &CreateMessageParams{})
-			if err != nil {
-				return nil, nil, err
-			}
-			if g, w := res.Model, "aModel"; g != w {
-				return nil, nil, fmt.Errorf("got %q, want %q", g, w)
-			}
+			testSample(test)
 		}
+		// Now, spin off a goroutine that runs after the sampling request, to
+		// check behavior when the client request has completed.
+		sampleWG.Add(1)
+		go func() {
+			defer sampleWG.Done()
+			<-sampleDone
+			// Test that sampling requests in the tool context fail outside of
+			// tool handling, but succeed on the background context.
+			for _, test := range []testCase{
+				{"request context", ctx, false},
+				{"background context", context.Background(), true},
+			} {
+				testSample(test)
+			}
+		}()
 		return &CallToolResult{}, nil, nil
 	})

@@ -191,15 +221,19 @@ func TestStreamableTransports(t *testing.T) {
 		t.Fatal("timeout waiting for cancellation")
 	}

-	// The "sampling" tool should be able to issue sampling requests during
-	// tool operation.
+	// The "sampling" tool checks the validity of server->client requests
+	// both within and without the tool context.
 	result, err := session.CallTool(ctx, &CallToolParams{
 		Name:      "sample",
 		Arguments: map[string]any{},
 	})
 	if err != nil {
 		t.Fatal(err)
 	}
+	// Run the out-of-band checks.
+	close(sampleDone)
+	sampleWG.Wait()
+
 	if result.IsError {
 		t.Fatalf("tool failed: %s", result.Content[0].(*TextContent).Text)
 	}
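The test's sampleDone channel and sampleWG pair form a simple gate: the goroutine queued inside the tool handler blocks until the tool call has returned to the client, then runs the out-of-band assertions, and the test waits for it before checking results. The gate in isolation (a sketch, not the test itself):

package main

import (
	"fmt"
	"sync"
)

func main() {
	done := make(chan struct{})
	var wg sync.WaitGroup

	wg.Add(1)
	go func() {
		defer wg.Done()
		<-done // block until the in-band phase has completed
		fmt.Println("out-of-band checks run here")
	}()

	// ... in-band phase: the tool call completes ...
	close(done) // signal the queued work
	wg.Wait()   // wait for it before asserting
}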
