Skip to content

rpc: add SetWebsocketReadLimit in Server #32279

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Aug 19, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions rpc/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ type Server struct {
batchItemLimit int
batchResponseLimit int
httpBodyLimit int
wsReadLimit int64
}

// NewServer creates a new server instance with no registered handlers.
Expand All @@ -62,6 +63,7 @@ func NewServer() *Server {
idgen: randomIDGenerator(),
codecs: make(map[ServerCodec]struct{}),
httpBodyLimit: defaultBodyLimit,
wsReadLimit: wsDefaultReadLimit,
}
server.run.Store(true)
// Register the default service providing meta information about the RPC service such
Expand Down Expand Up @@ -89,6 +91,13 @@ func (s *Server) SetHTTPBodyLimit(limit int) {
s.httpBodyLimit = limit
}

// SetWebsocketReadLimit sets the limit for max message size for Websocket requests.
//
// This method should be called before processing any requests via Websocket server.
func (s *Server) SetWebsocketReadLimit(limit int64) {
s.wsReadLimit = limit
}

// RegisterName creates a service for the given receiver type under the given name. When no
// methods on the given receiver match the criteria to be either an RPC method or a
// subscription an error is returned. Otherwise a new service is created and added to the
Expand Down
88 changes: 88 additions & 0 deletions rpc/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,18 @@ package rpc
import (
"bufio"
"bytes"
"context"
"errors"
"io"
"net"
"net/http/httptest"
"os"
"path/filepath"
"strings"
"testing"
"time"

"github.com/gorilla/websocket"
)

func TestServerRegisterName(t *testing.T) {
Expand Down Expand Up @@ -202,3 +207,86 @@ func TestServerBatchResponseSizeLimit(t *testing.T) {
}
}
}

func TestServerWebsocketReadLimit(t *testing.T) {
t.Parallel()

// Test different read limits
testCases := []struct {
name string
readLimit int64
testSize int
shouldFail bool
}{
{
name: "limit with small request - should succeed",
readLimit: 4096, // generous limit to comfortably allow JSON overhead
testSize: 256, // reasonably small payload
shouldFail: false,
},
{
name: "limit with large request - should fail",
readLimit: 256, // tight limit to trigger server-side read limit
testSize: 1024, // payload that will exceed the limit including JSON overhead
shouldFail: true,
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
// Create server and set read limits
srv := newTestServer()
srv.SetWebsocketReadLimit(tc.readLimit)
defer srv.Stop()

// Start HTTP server with WebSocket handler
httpsrv := httptest.NewServer(srv.WebsocketHandler([]string{"*"}))
defer httpsrv.Close()

wsURL := "ws:" + strings.TrimPrefix(httpsrv.URL, "http:")

// Connect WebSocket client
client, err := DialOptions(context.Background(), wsURL)
if err != nil {
t.Fatalf("can't dial: %v", err)
}
defer client.Close()

// Create large request data - this is what will be limited
largeString := strings.Repeat("A", tc.testSize)

// Send the large string as a parameter in the request
var result echoResult
err = client.Call(&result, "test_echo", largeString, 42, &echoArgs{S: "test"})

if tc.shouldFail {
// Expecting an error due to read limit exceeded
if err == nil {
t.Fatalf("expected error for request size %d with limit %d, but got none", tc.testSize, tc.readLimit)
}
// Be tolerant about the exact error surfaced by gorilla/websocket.
// Prefer a CloseError with code 1009, but accept ErrReadLimit or an error string containing 1009/message too big.
var cerr *websocket.CloseError
if errors.As(err, &cerr) {
if cerr.Code != websocket.CloseMessageTooBig {
t.Fatalf("unexpected websocket close code: have %d want %d (err=%v)", cerr.Code, websocket.CloseMessageTooBig, err)
}
} else if !errors.Is(err, websocket.ErrReadLimit) &&
!strings.Contains(strings.ToLower(err.Error()), "1009") &&
!strings.Contains(strings.ToLower(err.Error()), "message too big") {
// Not the error we expect from exceeding the message size limit.
t.Fatalf("unexpected error for read limit violation: %v", err)
}
} else {
// Expecting success
if err != nil {
t.Fatalf("unexpected error for request size %d with limit %d: %v", tc.testSize, tc.readLimit, err)
}
// Verify the response is correct - the echo should return our string
if result.String != largeString {
t.Fatalf("expected echo result to match input")
}
}
})
}
}
2 changes: 1 addition & 1 deletion rpc/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ func (s *Server) WebsocketHandler(allowedOrigins []string) http.Handler {
log.Debug("WebSocket upgrade failed", "err", err)
return
}
codec := newWebsocketCodec(conn, r.Host, r.Header, wsDefaultReadLimit)
codec := newWebsocketCodec(conn, r.Host, r.Header, s.wsReadLimit)
s.ServeCodec(codec, 0)
})
}
Expand Down