Skip to content

Commit 547b5c1

Browse files
authored
mcp: fix panics with nil logger; ensure test handlers do not panic (#557)
PR 501 added logging to the streamable server, but didn't pass a non-nil logger to the streamable transport (an oversight). However, tests still passed due to net/http's panic recovery. Update test servers to enforce that handlers do not panic, and fix the panic by: 1. Passing along the configured logger to the session transport. 2. Using ensureLogger in Connect, since we technically allow StreamableServerTransports to be constructed by the user. Fixes #556
1 parent 4439658 commit 547b5c1

File tree

2 files changed

+36
-12
lines changed

2 files changed

+36
-12
lines changed

mcp/streamable.go

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,7 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque
236236
SessionID: sessionID,
237237
Stateless: h.opts.Stateless,
238238
jsonResponse: h.opts.JSONResponse,
239+
logger: h.opts.Logger,
239240
}
240241

241242
// To support stateless mode, we initialize the session with a default
@@ -377,6 +378,10 @@ type StreamableServerTransport struct {
377378
// StreamableHTTPOptions.JSONResponse is exported.
378379
jsonResponse bool
379380

381+
// optional logger provided through the [StreamableHTTPOptions.Logger].
382+
//
383+
// TODO(rfindley): logger should be exported, since we want to allow people
384+
// to write their own streamable HTTP handler.
380385
logger *slog.Logger
381386

382387
// connection is non-nil if and only if the transport has been connected.
@@ -393,7 +398,7 @@ func (t *StreamableServerTransport) Connect(ctx context.Context) (Connection, er
393398
stateless: t.Stateless,
394399
eventStore: t.EventStore,
395400
jsonResponse: t.jsonResponse,
396-
logger: t.logger,
401+
logger: ensureLogger(t.logger), // see #556: must be non-nil
397402
incoming: make(chan jsonrpc.Message, 10),
398403
done: make(chan struct{}),
399404
streams: make(map[string]*stream),

mcp/streamable_test.go

Lines changed: 30 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ import (
1818
"net/http/httptest"
1919
"net/http/httputil"
2020
"net/url"
21+
"os"
22+
"runtime"
2123
"sort"
2224
"strings"
2325
"sync"
@@ -91,7 +93,7 @@ func TestStreamableTransports(t *testing.T) {
9193
headerMu sync.Mutex
9294
lastHeader http.Header
9395
)
94-
httpServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
96+
httpServer := httptest.NewServer(mustNotPanic(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
9597
headerMu.Lock()
9698
lastHeader = r.Header
9799
headerMu.Unlock()
@@ -102,7 +104,7 @@ func TestStreamableTransports(t *testing.T) {
102104
t.Errorf("got cookie %q, want %q", cookie.Value, "test-value")
103105
}
104106
handler.ServeHTTP(w, r)
105-
}))
107+
})))
106108
defer httpServer.Close()
107109

108110
// Create a client and connect it to the server using our StreamableClientTransport.
@@ -315,7 +317,7 @@ func testClientReplay(t *testing.T, test clientReplayTest) {
315317
return new(CallToolResult), nil, nil
316318
})
317319

318-
realServer := httptest.NewServer(NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, nil))
320+
realServer := httptest.NewServer(mustNotPanic(t, NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, nil)))
319321
defer realServer.Close()
320322
realServerURL, err := url.Parse(realServer.URL)
321323
if err != nil {
@@ -324,6 +326,7 @@ func testClientReplay(t *testing.T, test clientReplayTest) {
324326

325327
// Configure a proxy that sits between the client and the real server.
326328
proxyHandler := httputil.NewSingleHostReverseProxy(realServerURL)
329+
// note: don't use mustNotPanic here as the proxy WILL panic when killed.
327330
proxy := httptest.NewServer(proxyHandler)
328331
proxyAddr := proxy.Listener.Addr().String() // Get the address to restart it later.
329332

@@ -434,7 +437,7 @@ func TestServerTransportCleanup(t *testing.T) {
434437
chans[sessionID] <- struct{}{}
435438
}
436439

437-
httpServer := httptest.NewServer(handler)
440+
httpServer := httptest.NewServer(mustNotPanic(t, handler))
438441
defer httpServer.Close()
439442

440443
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
@@ -484,7 +487,7 @@ func TestServerInitiatedSSE(t *testing.T) {
484487
notifications := make(chan string)
485488
server := NewServer(testImpl, nil)
486489

487-
httpServer := httptest.NewServer(NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, nil))
490+
httpServer := httptest.NewServer(mustNotPanic(t, NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, nil)))
488491
defer httpServer.Close()
489492

490493
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
@@ -857,7 +860,7 @@ func TestStreamableServerTransport(t *testing.T) {
857860
}
858861

859862
func testStreamableHandler(t *testing.T, handler http.Handler, requests []streamableRequest) {
860-
httpServer := httptest.NewServer(handler)
863+
httpServer := httptest.NewServer(mustNotPanic(t, handler))
861864
defer httpServer.Close()
862865

863866
// blocks records request blocks by jsonrpc. ID.
@@ -1247,7 +1250,7 @@ func TestStreamableStateless(t *testing.T) {
12471250

12481251
testClientCompatibility := func(t *testing.T, handler http.Handler) {
12491252
ctx := context.Background()
1250-
httpServer := httptest.NewServer(handler)
1253+
httpServer := httptest.NewServer(mustNotPanic(t, handler))
12511254
defer httpServer.Close()
12521255
cs, err := NewClient(testImpl, nil).Connect(ctx, &StreamableClientTransport{Endpoint: httpServer.URL}, nil)
12531256
if err != nil {
@@ -1332,7 +1335,7 @@ func TestTokenInfo(t *testing.T) {
13321335
}, nil
13331336
}
13341337
handler := auth.RequireBearerToken(verifier, nil)(streamHandler)
1335-
httpServer := httptest.NewServer(handler)
1338+
httpServer := httptest.NewServer(mustNotPanic(t, handler))
13361339
defer httpServer.Close()
13371340

13381341
transport := &StreamableClientTransport{Endpoint: httpServer.URL}
@@ -1366,7 +1369,7 @@ func TestStreamableGET(t *testing.T) {
13661369
server := NewServer(testImpl, nil)
13671370

13681371
handler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, nil)
1369-
httpServer := httptest.NewServer(handler)
1372+
httpServer := httptest.NewServer(mustNotPanic(t, handler))
13701373
defer httpServer.Close()
13711374

13721375
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
@@ -1442,7 +1445,7 @@ func TestStreamableClientContextPropagation(t *testing.T) {
14421445
defer cancel()
14431446
ctx2 := context.WithValue(ctx, testKey, testValue)
14441447

1445-
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
1448+
server := httptest.NewServer(mustNotPanic(t, http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
14461449
switch req.Method {
14471450
case "POST":
14481451
w.Header().Set("Content-Type", "application/json")
@@ -1455,7 +1458,7 @@ func TestStreamableClientContextPropagation(t *testing.T) {
14551458
case "DELETE":
14561459
w.WriteHeader(http.StatusNoContent)
14571460
}
1458-
}))
1461+
})))
14591462
defer server.Close()
14601463

14611464
transport := &StreamableClientTransport{Endpoint: server.URL}
@@ -1486,3 +1489,19 @@ func TestStreamableClientContextPropagation(t *testing.T) {
14861489
}
14871490

14881491
}
1492+
1493+
// mustNotPanic is a helper to enforce that test handlers do not panic (see
1494+
// issue #556).
1495+
func mustNotPanic(t *testing.T, h http.Handler) http.Handler {
1496+
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
1497+
defer func() {
1498+
if r := recover(); r != nil {
1499+
buf := make([]byte, 1<<20)
1500+
n := runtime.Stack(buf, false)
1501+
fmt.Fprintf(os.Stderr, "handler panic: %v\n\n%s", r, buf[:n])
1502+
t.Errorf("handler panicked: %v", r)
1503+
}
1504+
}()
1505+
h.ServeHTTP(w, req)
1506+
})
1507+
}

0 commit comments

Comments
 (0)