Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion mcp/streamable.go
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,7 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque
SessionID: sessionID,
Stateless: h.opts.Stateless,
jsonResponse: h.opts.JSONResponse,
logger: h.opts.Logger,
}

// To support stateless mode, we initialize the session with a default
Expand Down Expand Up @@ -377,6 +378,10 @@ type StreamableServerTransport struct {
// StreamableHTTPOptions.JSONResponse is exported.
jsonResponse bool

// optional logger provided through the [StreamableHTTPOptions.Logger].
//
// TODO(rfindley): logger should be exported, since we want to allow people
// to write their own streamable HTTP handler.
logger *slog.Logger

// connection is non-nil if and only if the transport has been connected.
Expand All @@ -393,7 +398,7 @@ func (t *StreamableServerTransport) Connect(ctx context.Context) (Connection, er
stateless: t.Stateless,
eventStore: t.EventStore,
jsonResponse: t.jsonResponse,
logger: t.logger,
logger: ensureLogger(t.logger), // see #556: must be non-nil
incoming: make(chan jsonrpc.Message, 10),
done: make(chan struct{}),
streams: make(map[string]*stream),
Expand Down
41 changes: 30 additions & 11 deletions mcp/streamable_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ import (
"net/http/httptest"
"net/http/httputil"
"net/url"
"os"
"runtime"
"sort"
"strings"
"sync"
Expand Down Expand Up @@ -91,7 +93,7 @@ func TestStreamableTransports(t *testing.T) {
headerMu sync.Mutex
lastHeader http.Header
)
httpServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
httpServer := httptest.NewServer(mustNotPanic(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
headerMu.Lock()
lastHeader = r.Header
headerMu.Unlock()
Expand All @@ -102,7 +104,7 @@ func TestStreamableTransports(t *testing.T) {
t.Errorf("got cookie %q, want %q", cookie.Value, "test-value")
}
handler.ServeHTTP(w, r)
}))
})))
defer httpServer.Close()

// Create a client and connect it to the server using our StreamableClientTransport.
Expand Down Expand Up @@ -315,7 +317,7 @@ func testClientReplay(t *testing.T, test clientReplayTest) {
return new(CallToolResult), nil, nil
})

realServer := httptest.NewServer(NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, nil))
realServer := httptest.NewServer(mustNotPanic(t, NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, nil)))
defer realServer.Close()
realServerURL, err := url.Parse(realServer.URL)
if err != nil {
Expand All @@ -324,6 +326,7 @@ func testClientReplay(t *testing.T, test clientReplayTest) {

// Configure a proxy that sits between the client and the real server.
proxyHandler := httputil.NewSingleHostReverseProxy(realServerURL)
// note: don't use mustNotPanic here as the proxy WILL panic when killed.
proxy := httptest.NewServer(proxyHandler)
proxyAddr := proxy.Listener.Addr().String() // Get the address to restart it later.

Expand Down Expand Up @@ -434,7 +437,7 @@ func TestServerTransportCleanup(t *testing.T) {
chans[sessionID] <- struct{}{}
}

httpServer := httptest.NewServer(handler)
httpServer := httptest.NewServer(mustNotPanic(t, handler))
defer httpServer.Close()

ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
Expand Down Expand Up @@ -484,7 +487,7 @@ func TestServerInitiatedSSE(t *testing.T) {
notifications := make(chan string)
server := NewServer(testImpl, nil)

httpServer := httptest.NewServer(NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, nil))
httpServer := httptest.NewServer(mustNotPanic(t, NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, nil)))
defer httpServer.Close()

ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
Expand Down Expand Up @@ -857,7 +860,7 @@ func TestStreamableServerTransport(t *testing.T) {
}

func testStreamableHandler(t *testing.T, handler http.Handler, requests []streamableRequest) {
httpServer := httptest.NewServer(handler)
httpServer := httptest.NewServer(mustNotPanic(t, handler))
defer httpServer.Close()

// blocks records request blocks by jsonrpc. ID.
Expand Down Expand Up @@ -1247,7 +1250,7 @@ func TestStreamableStateless(t *testing.T) {

testClientCompatibility := func(t *testing.T, handler http.Handler) {
ctx := context.Background()
httpServer := httptest.NewServer(handler)
httpServer := httptest.NewServer(mustNotPanic(t, handler))
defer httpServer.Close()
cs, err := NewClient(testImpl, nil).Connect(ctx, &StreamableClientTransport{Endpoint: httpServer.URL}, nil)
if err != nil {
Expand Down Expand Up @@ -1332,7 +1335,7 @@ func TestTokenInfo(t *testing.T) {
}, nil
}
handler := auth.RequireBearerToken(verifier, nil)(streamHandler)
httpServer := httptest.NewServer(handler)
httpServer := httptest.NewServer(mustNotPanic(t, handler))
defer httpServer.Close()

transport := &StreamableClientTransport{Endpoint: httpServer.URL}
Expand Down Expand Up @@ -1366,7 +1369,7 @@ func TestStreamableGET(t *testing.T) {
server := NewServer(testImpl, nil)

handler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, nil)
httpServer := httptest.NewServer(handler)
httpServer := httptest.NewServer(mustNotPanic(t, handler))
defer httpServer.Close()

ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
Expand Down Expand Up @@ -1442,7 +1445,7 @@ func TestStreamableClientContextPropagation(t *testing.T) {
defer cancel()
ctx2 := context.WithValue(ctx, testKey, testValue)

server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
server := httptest.NewServer(mustNotPanic(t, http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
switch req.Method {
case "POST":
w.Header().Set("Content-Type", "application/json")
Expand All @@ -1455,7 +1458,7 @@ func TestStreamableClientContextPropagation(t *testing.T) {
case "DELETE":
w.WriteHeader(http.StatusNoContent)
}
}))
})))
defer server.Close()

transport := &StreamableClientTransport{Endpoint: server.URL}
Expand Down Expand Up @@ -1486,3 +1489,19 @@ func TestStreamableClientContextPropagation(t *testing.T) {
}

}

// mustNotPanic is a helper to enforce that test handlers do not panic (see
// issue #556).
func mustNotPanic(t *testing.T, h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
defer func() {
if r := recover(); r != nil {
buf := make([]byte, 1<<20)
n := runtime.Stack(buf, false)
fmt.Fprintf(os.Stderr, "handler panic: %v\n\n%s", r, buf[:n])
t.Errorf("handler panicked: %v", r)
}
}()
h.ServeHTTP(w, req)
})
}