1111//
1212// Features:
1313// - HTTP and WebSocket endpoint support
14- // - Automatic method resolution and dispatch
14+ // - Automatic method resolution and dispatch with cached lookups
1515// - Protocol Buffer JSON marshaling/unmarshaling
1616// - Multiple streaming patterns (unary, server, client, bidirectional)
1717// - Context enrichment with HTTP and WebSocket components
2121// 1. Define your service in a .proto file
2222// 2. Generate Go code using protoc with the protoc-gen-jrpc plugin
2323// 3. Implement the generated Server interface
24- // 4. Create a new jRPC service with jrpc.New (yourServer)
24+ // 4. Create a new jRPC service with jrpc.Register (yourServer)
2525// 5. Register the HandlerFunc with the web package function WithJRPC
2626//
2727// Example:
5050// err := web.Instance().
5151// WithHost("localhost").
5252// WithPort(8080).
53- // WithJRPC(jrpc.New (&MyService{})).
53+ // WithJRPC(jrpc.Register (&MyService{})).
5454// Start().Error
5555// if err != nil {
5656// log.Fatal().Err(err).Msg("server exited")
@@ -67,6 +67,7 @@ import (
6767 "net/http"
6868 "reflect"
6969 "strings"
70+ "sync"
7071 "time"
7172
7273 "github.com/gorilla/websocket"
@@ -79,13 +80,42 @@ import (
7980 "google.golang.org/protobuf/types/dynamicpb"
8081)
8182
83+ // Common error messages to avoid repeated allocations
84+ var (
85+ errMethodNotFound = apperror .NewError ("method not found" )
86+ errMethodReflectionNotFound = apperror .NewError ("method reflection data not found" )
87+ errInvalidMethodSignature = errors .New ("invalid method signature" )
88+ errFirstArgMustBeContext = errors .New ("first argument must be context.Context" )
89+ errSecondReturnMustBeError = errors .New ("second return value must be error" )
90+ errRequestMustBePointer = errors .New ("request must be a pointer" )
91+ errNilRequest = errors .New ("nil request" )
92+ errExpectedProtoMessage = errors .New ("expected proto.Message for request" )
93+
94+ // Cached reflection types to avoid repeated type operations
95+ contextType = reflect .TypeOf ((* context .Context )(nil )).Elem ()
96+ errorType = reflect .TypeOf ((* error )(nil )).Elem ()
97+
98+ // Reusable marshal options to avoid allocation
99+ marshalOpts = protojson.MarshalOptions {
100+ EmitUnpopulated : true ,
101+ UseProtoNames : true ,
102+ }
103+ )
104+
82105// upgrader is the WebSocket upgrader with default options
83106var upgrader = websocket.Upgrader {
84107 CheckOrigin : func (r * http.Request ) bool {
85108 return true // Allow connections from any origin
86109 },
87110}
88111
112+ // bufferPool provides a pool of byte buffers for JSON operations
113+ var bufferPool = sync.Pool {
114+ New : func () interface {} {
115+ return make ([]byte , 0 , 1024 ) // Pre-allocate 1KB buffers
116+ },
117+ }
118+
89119// ContextKey represents keys for context values
90120type ContextKey string
91121
@@ -103,6 +133,8 @@ const (
103133// protocol buffer message handling, and context enrichment.
104134type Service struct {
105135 Server
136+ methods map [string ]* methodInfo // cached method information for faster lookup
137+ types map [protoreflect.FullName ]proto.Message // cached message types
106138}
107139
108140// Server represents a jRPC service implementation.
@@ -111,10 +143,76 @@ type Server interface {
111143 Descriptor () protoreflect.FileDescriptor
112144}
113145
114- // New creates a new jrpc service instance and registers the provided
146+ // methodInfo holds cached reflection and protobuf information for a method
147+ type methodInfo struct {
148+ descriptor protoreflect.MethodDescriptor
149+ method reflect.Value
150+ reflectType reflect.Type
151+ inputType reflect.Type
152+ outputType reflect.Type
153+ messageType proto.Message
154+ validated bool
155+ }
156+
157+ // Register creates a new jrpc service instance and registers the provided
115158// service implementation. The service implementation has to implement the Descriptor method.
116- func New (s Server ) * Service {
117- return & Service {Server : s }
159+ // This function builds a method cache for improved lookup performance.
160+ func Register (s Server ) * Service {
161+ service := & Service {
162+ Server : s ,
163+ methods : make (map [string ]* methodInfo ),
164+ types : make (map [protoreflect.FullName ]proto.Message ),
165+ }
166+
167+ sv := reflect .ValueOf (s )
168+ services := s .Descriptor ().Services ()
169+ for i := 0 ; i < services .Len (); i ++ {
170+ sd := services .Get (i )
171+ methods := sd .Methods ()
172+ for j := 0 ; j < methods .Len (); j ++ {
173+ md := methods .Get (j )
174+ mn := string (md .Name ())
175+ sn := string (sd .Name ())
176+
177+ key := sn + "." + mn
178+
179+ rm := sv .MethodByName (mn )
180+ if ! rm .IsValid () {
181+ continue
182+ }
183+
184+ mt := rm .Type ()
185+
186+ var pm proto.Message
187+ if mt , err := protoregistry .GlobalTypes .FindMessageByName (md .Input ().FullName ()); err == nil {
188+ pm = mt .New ().Interface ()
189+ service .types [md .Input ().FullName ()] = pm
190+ } else {
191+ pm = dynamicpb .NewMessage (md .Input ())
192+ service .types [md .Input ().FullName ()] = pm
193+ }
194+
195+ var it , ot reflect.Type
196+ if mt .NumIn () >= 2 {
197+ it = mt .In (1 )
198+ }
199+ if mt .NumOut () >= 1 {
200+ ot = mt .Out (0 )
201+ }
202+
203+ service .methods [key ] = & methodInfo {
204+ descriptor : md ,
205+ method : rm ,
206+ reflectType : mt ,
207+ inputType : it ,
208+ outputType : ot ,
209+ messageType : pm ,
210+ validated : false ,
211+ }
212+ }
213+ }
214+
215+ return service
118216}
119217
120218// SetUpgrader allows setting a custom WebSocket upgrader with specific options.
@@ -153,10 +251,10 @@ func (s *Service) HandlerFunc(w http.ResponseWriter, r *http.Request) {
153251}
154252
155253// isWebSocketRequest checks if the HTTP request is requesting a WebSocket upgrade
254+ // Optimized version with reduced string allocations
156255func (s * Service ) isWebSocketRequest (r * http.Request ) bool {
157- connection := strings .ToLower (r .Header .Get ("Connection" ))
158- upgrade := strings .ToLower (r .Header .Get ("Upgrade" ))
159- return strings .Contains (connection , "upgrade" ) && upgrade == "websocket"
256+ return strings .Contains (strings .ToLower (r .Header .Get ("Connection" )), "upgrade" ) &&
257+ strings .EqualFold (r .Header .Get ("Upgrade" ), "websocket" )
160258}
161259
162260// unary processes HTTP POST requests to API endpoints.
@@ -188,20 +286,32 @@ func (s *Service) unary(w http.ResponseWriter, r *http.Request) {
188286 }
189287
190288 if r .ContentLength > 0 {
191- body , _ := io .ReadAll (r .Body )
192- if len (body ) != int (r .ContentLength ) {
193- http .Error (w , "body length does not match Content-Length" , http .StatusBadRequest )
289+ // Use buffer pool for body reading
290+ buf := bufferPool .Get ().([]byte )
291+ defer bufferPool .Put (buf [:0 ])
292+
293+ // Ensure buffer is large enough
294+ if cap (buf ) < int (r .ContentLength ) {
295+ buf = make ([]byte , r .ContentLength )
296+ } else {
297+ buf = buf [:r .ContentLength ]
298+ }
299+
300+ _ , err := io .ReadFull (r .Body , buf )
301+ if err != nil {
302+ http .Error (w , "failed to read request body" , http .StatusBadRequest )
194303 return
195304 }
196- err = protojson .Unmarshal (body , msg )
305+
306+ err = protojson .Unmarshal (buf , msg )
197307 if err != nil {
198308 http .Error (w , err .Error (), http .StatusBadRequest )
199309 return
200310 }
201311 }
202312 defer apperror .Catch (r .Body .Close , "closing request body failed" )
203313
204- resp , err := s .call (ctx , method , msg )
314+ resp , err := s .call (ctx , service , method , msg )
205315 if err != nil {
206316 http .Error (w , err .Error (), http .StatusInternalServerError )
207317 return
@@ -246,14 +356,15 @@ func (s *Service) websocket(w http.ResponseWriter, r *http.Request, conn *websoc
246356 return
247357 }
248358
249- m := reflect .ValueOf (s .Server ).MethodByName (method )
250- if ! m .IsValid () {
251- s .closeWS (conn , websocket .CloseInternalServerErr , "method not found" )
359+ if ! md .method .IsValid () {
360+ s .closeWS (conn , websocket .CloseInternalServerErr , "service not registered" )
252361 return
253362 }
254- mt := m .Type ()
255363
256- streamingType , err := s .validateMethodSignature (mt , md )
364+ m := md .method
365+ mt := md .reflectType
366+
367+ streamingType , err := s .validateMethodSignature (mt , md .descriptor )
257368 if err != nil {
258369 s .closeWS (conn , websocket .CloseInternalServerErr , "invalid method signature: " + err .Error ())
259370 return
@@ -350,43 +461,57 @@ func GetWebSocketConn(ctx context.Context) (*websocket.Conn, bool) {
350461 return conn , ok
351462}
352463
353- func (s * Service ) call (ctx context.Context , method string , req proto.Message ) (any , error ) {
354- m := reflect . ValueOf ( s . Server ). MethodByName ( method )
355- if ! m . IsValid () {
356- return nil , apperror . NewError ( "method not found" )
464+ func (s * Service ) call (ctx context.Context , service , method string , req proto.Message ) (any , error ) {
465+ methodInfo , exists := s . methods [ service + "." + method ]
466+ if ! exists {
467+ return nil , errMethodNotFound
357468 }
358469
359- mt := m .Type ()
360- if mt .NumIn () != 2 || mt .NumOut () != 2 {
361- return nil , errors .New ("invalid method signature" )
362- }
363- if ! mt .In (0 ).Implements (reflect .TypeOf ((* context .Context )(nil )).Elem ()) {
364- return nil , errors .New ("first argument must be context.Context" )
470+ if ! methodInfo .method .IsValid () {
471+ return nil , errMethodReflectionNotFound
365472 }
366- if ! mt .Out (1 ).Implements (reflect .TypeOf ((* error )(nil )).Elem ()) {
367- return nil , errors .New ("second return value must be error" )
473+
474+ m := methodInfo .method
475+ mt := methodInfo .reflectType
476+
477+ // Validate method signature if not already validated
478+ if ! methodInfo .validated {
479+ if mt .NumIn () != 2 || mt .NumOut () != 2 {
480+ return nil , errInvalidMethodSignature
481+ }
482+ if ! mt .In (0 ).Implements (contextType ) {
483+ return nil , errFirstArgMustBeContext
484+ }
485+ if ! mt .Out (1 ).Implements (errorType ) {
486+ return nil , errSecondReturnMustBeError
487+ }
488+ methodInfo .validated = true
368489 }
369490
370491 wanted := mt .In (1 )
371492 if wanted .Kind () != reflect .Ptr {
372- return nil , errors . New ( "request must be a pointer" )
493+ return nil , errRequestMustBePointer
373494 }
374495
375496 reqVal := reflect .ValueOf (req )
376497 if ! reqVal .IsValid () {
377- return nil , errors . New ( "nil request" )
498+ return nil , errNilRequest
378499 }
379500
380501 if ! reqVal .Type ().AssignableTo (wanted ) {
381502 // Convert via JSON round-trip using protojson to the expected type.
503+ // Use buffer pool for better performance
504+ buf := bufferPool .Get ().([]byte )
505+ defer bufferPool .Put (buf [:0 ])
506+
382507 reqPtr := reflect .New (wanted .Elem ())
383508 b , err := protojson .Marshal (req )
384509 if err != nil {
385510 return nil , err
386511 }
387512 pm , ok := reqPtr .Interface ().(proto.Message )
388513 if ! ok {
389- return nil , errors . New ( "expected proto.Message for request" )
514+ return nil , errExpectedProtoMessage
390515 }
391516 if err := protojson .Unmarshal (b , pm ); err != nil {
392517 return nil , err
@@ -403,43 +528,31 @@ func (s *Service) call(ctx context.Context, method string, req proto.Message) (a
403528 return res , err
404529}
405530
406- func (s * Service ) find (service , method string ) (protoreflect. MethodDescriptor , error ) {
407- sd := s .Descriptor (). Services (). ByName ( protoreflect . Name ( service ))
408- if sd == nil {
409- return nil , apperror . NewError ( "service not found" )
531+ func (s * Service ) find (service , method string ) (* methodInfo , error ) {
532+ md , exists := s .methods [ service + "." + method ]
533+ if ! exists {
534+ return nil , errMethodNotFound
410535 }
411536
412- md := sd .Methods ().ByName (protoreflect .Name (method ))
413- if md == nil {
414- return nil , apperror .NewError ("method not found" )
415- }
416537 return md , nil
417538}
418539
419- func (s * Service ) message (md protoreflect.MethodDescriptor ) (proto.Message , error ) {
420- mt , err := protoregistry .GlobalTypes .FindMessageByName (md .Input ().FullName ())
421- if err != nil {
422- log .Error ().Err (err ).Msg ("failed to find message type" )
423- return dynamicpb .NewMessage (md .Input ()), nil
540+ func (s * Service ) message (md * methodInfo ) (proto.Message , error ) {
541+ if md .messageType == nil {
542+ return nil , apperror .NewError ("message type not found" )
424543 }
425544
426- return mt . New (). Interface ( ), nil
545+ return proto . Clone ( md . messageType ), nil
427546}
428547
429548func (s * Service ) marshal (v any ) ([]byte , error ) {
430549 if pm , ok := v .(proto.Message ); ok {
431- return protojson.MarshalOptions {
432- EmitUnpopulated : true ,
433- UseProtoNames : true ,
434- }.Marshal (pm )
550+ return marshalOpts .Marshal (pm )
435551 }
436552 rv := reflect .ValueOf (v )
437553 if rv .IsValid () && rv .CanAddr () {
438554 if pm , ok := rv .Addr ().Interface ().(proto.Message ); ok {
439- return protojson.MarshalOptions {
440- EmitUnpopulated : true ,
441- UseProtoNames : true ,
442- }.Marshal (pm )
555+ return marshalOpts .Marshal (pm )
443556 }
444557 }
445558 // Handle non-pointer values by creating a pointer to them
@@ -450,10 +563,7 @@ func (s *Service) marshal(v any) ([]byte, error) {
450563 newPtr := reflect .New (rv .Type ())
451564 newPtr .Elem ().Set (rv )
452565 if pm , ok := newPtr .Interface ().(proto.Message ); ok {
453- return protojson.MarshalOptions {
454- EmitUnpopulated : true ,
455- UseProtoNames : true ,
456- }.Marshal (pm )
566+ return marshalOpts .Marshal (pm )
457567 }
458568 }
459569 }
@@ -503,7 +613,7 @@ func (s *Service) handleBidirectionalStream(ctx context.Context, conn *websocket
503613}
504614
505615// handleServerStream handles server streaming WebSocket connections
506- func (s * Service ) handleServerStream (ctx context.Context , conn * websocket.Conn , m reflect.Value , mt reflect.Type , md protoreflect. MethodDescriptor ) {
616+ func (s * Service ) handleServerStream (ctx context.Context , conn * websocket.Conn , m reflect.Value , mt reflect.Type , md * methodInfo ) {
507617 outType := mt .In (2 )
508618 outPtr := outType .Elem ()
509619 out := reflect .MakeChan (outType , 0 )
@@ -771,9 +881,6 @@ const (
771881
772882// validateMethodSignature validates and determines the streaming type of a method
773883func (s * Service ) validateMethodSignature (mt reflect.Type , md protoreflect.MethodDescriptor ) (StreamingType , error ) {
774- contextType := reflect .TypeOf ((* context .Context )(nil )).Elem ()
775- errorType := reflect .TypeOf ((* error )(nil )).Elem ()
776-
777884 // Basic validation: must have at least context parameter and error return
778885 if mt .NumIn () < 1 || ! mt .In (0 ).Implements (contextType ) {
779886 return StreamingTypeInvalid , errors .New ("first parameter must be context.Context" )
0 commit comments