Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions rpc/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ type Server struct {
batchItemLimit int
batchResponseLimit int
httpBodyLimit int
wsReadLimit int64
}

// NewServer creates a new server instance with no registered handlers.
Expand All @@ -62,6 +63,7 @@ func NewServer() *Server {
idgen: randomIDGenerator(),
codecs: make(map[ServerCodec]struct{}),
httpBodyLimit: defaultBodyLimit,
wsReadLimit: wsDefaultReadLimit,
}
server.run.Store(true)
// Register the default service providing meta information about the RPC service such
Expand Down Expand Up @@ -89,6 +91,13 @@ func (s *Server) SetHTTPBodyLimit(limit int) {
s.httpBodyLimit = limit
}

// SetWebsocketReadLimit sets the limit for max message size for Websocket requests.
//
// This method should be called before processing any requests via Websocket server.
func (s *Server) SetWebsocketReadLimit(limit int64) {
s.wsReadLimit = limit
}

// RegisterName creates a service for the given receiver type under the given name. When no
// methods on the given receiver match the criteria to be either an RPC method or a
// subscription an error is returned. Otherwise a new service is created and added to the
Expand Down
76 changes: 76 additions & 0 deletions rpc/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@ package rpc
import (
"bufio"
"bytes"
"context"
"io"
"net"
"net/http/httptest"
"os"
"path/filepath"
"strings"
Expand Down Expand Up @@ -202,3 +204,77 @@ func TestServerBatchResponseSizeLimit(t *testing.T) {
}
}
}

func TestServerWebsocketReadLimit(t *testing.T) {
t.Parallel()

// Test different read limits
testCases := []struct {
name string
readLimit int64
testSize int
shouldFail bool
}{
{
name: "limit with small request - should succeed",
readLimit: 2048,
testSize: 500, // Small request data
shouldFail: false,
},
{
name: "limit with large request - should fail",
readLimit: 2048,
testSize: 5000, // Large request data that should exceed limit
shouldFail: true,
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
// Create server and set read limits
srv := newTestServer()
srv.SetWebsocketReadLimit(tc.readLimit)
defer srv.Stop()

// Start HTTP server with WebSocket handler
httpsrv := httptest.NewServer(srv.WebsocketHandler([]string{"*"}))
defer httpsrv.Close()

wsURL := "ws:" + strings.TrimPrefix(httpsrv.URL, "http:")

// Connect WebSocket client
client, err := DialOptions(context.Background(), wsURL)
if err != nil {
t.Fatalf("can't dial: %v", err)
}
defer client.Close()

// Create large request data - this is what will be limited
largeString := strings.Repeat("A", tc.testSize)

// Send the large string as a parameter in the request
var result echoResult
err = client.Call(&result, "test_echo", largeString, 42, &echoArgs{S: "test"})

if tc.shouldFail {
// Expecting an error due to read limit exceeded
if err == nil {
t.Fatalf("expected error for request size %d with limit %d, but got none", tc.testSize, tc.readLimit)
}
// Check if it's the expected message size limit error
if !strings.Contains(err.Error(), "message too big") {
t.Fatalf("expected 'message too big' error, got: %v", err)
}
} else {
// Expecting success
if err != nil {
t.Fatalf("unexpected error for request size %d with limit %d: %v", tc.testSize, tc.readLimit, err)
}
// Verify the response is correct - the echo should return our string
if result.String != largeString {
t.Fatalf("expected echo result to match input")
}
}
})
}
}
2 changes: 1 addition & 1 deletion rpc/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ func (s *Server) WebsocketHandler(allowedOrigins []string) http.Handler {
log.Debug("WebSocket upgrade failed", "err", err)
return
}
codec := newWebsocketCodec(conn, r.Host, r.Header, wsDefaultReadLimit)
codec := newWebsocketCodec(conn, r.Host, r.Header, s.wsReadLimit)
s.ServeCodec(codec, 0)
})
}
Expand Down
Loading