@@ -67,6 +67,7 @@ import (
6767 "net/http"
6868 "reflect"
6969 "strings"
70+ "sync"
7071 "time"
7172
7273 "github.com/gorilla/websocket"
@@ -86,6 +87,13 @@ var upgrader = websocket.Upgrader{
8687 },
8788}
8889
90+ // bufferPool provides a pool of byte buffers for JSON operations
91+ var bufferPool = sync.Pool {
92+ New : func () interface {} {
93+ return make ([]byte , 0 , 1024 ) // Pre-allocate 1KB buffers
94+ },
95+ }
96+
8997// ContextKey represents keys for context values
9098type ContextKey string
9199
@@ -103,7 +111,9 @@ const (
103111// protocol buffer message handling, and context enrichment.
104112type Service struct {
105113 Server
106- methods map [string ]protoreflect.MethodDescriptor // cached method descriptors for faster lookup
114+ methods map [string ]* methodInfo // cached method information for faster lookup
115+ keys map [string ]string // pre-computed method keys to avoid string concatenation
116+ types map [protoreflect.FullName ]proto.Message // cached message types
107117}
108118
109119// Server represents a jRPC service implementation.
@@ -112,25 +122,74 @@ type Server interface {
112122 Descriptor () protoreflect.FileDescriptor
113123}
114124
125+ // methodInfo holds cached reflection and protobuf information for a method
126+ type methodInfo struct {
127+ descriptor protoreflect.MethodDescriptor
128+ method reflect.Value
129+ reflectType reflect.Type
130+ inputType reflect.Type
131+ outputType reflect.Type
132+ messageType proto.Message
133+ validated bool
134+ }
135+
115136// Register creates a new jrpc service instance and registers the provided
116137// service implementation. The service implementation has to implement the Descriptor method.
117138// This function builds a method cache for improved lookup performance.
118139func Register (s Server ) * Service {
119140 service := & Service {
120141 Server : s ,
121- methods : make (map [string ]protoreflect.MethodDescriptor ),
142+ methods : make (map [string ]* methodInfo ),
143+ keys : make (map [string ]string ),
144+ types : make (map [protoreflect.FullName ]proto.Message ),
122145 }
123146
124- // Build the methods cache
147+ sv := reflect . ValueOf ( s )
125148 services := s .Descriptor ().Services ()
126149 for i := 0 ; i < services .Len (); i ++ {
127- serviceDesc := services .Get (i )
128- methods := serviceDesc .Methods ()
150+ sd := services .Get (i )
151+ methods := sd .Methods ()
129152 for j := 0 ; j < methods .Len (); j ++ {
130- methodDesc := methods .Get (j )
131- // Use fully qualified method name as key: service.method
132- key := string (serviceDesc .Name ()) + "." + string (methodDesc .Name ())
133- service .methods [key ] = methodDesc
153+ md := methods .Get (j )
154+ mn := string (md .Name ())
155+ sn := string (sd .Name ())
156+
157+ key := sn + "." + mn
158+ service .keys [mn ] = key
159+
160+ rm := sv .MethodByName (mn )
161+ if ! rm .IsValid () {
162+ continue
163+ }
164+
165+ mt := rm .Type ()
166+
167+ var pm proto.Message
168+ if mt , err := protoregistry .GlobalTypes .FindMessageByName (md .Input ().FullName ()); err == nil {
169+ pm = mt .New ().Interface ()
170+ service .types [md .Input ().FullName ()] = pm
171+ } else {
172+ pm = dynamicpb .NewMessage (md .Input ())
173+ service .types [md .Input ().FullName ()] = pm
174+ }
175+
176+ var it , ot reflect.Type
177+ if mt .NumIn () >= 2 {
178+ it = mt .In (1 )
179+ }
180+ if mt .NumOut () >= 1 {
181+ ot = mt .Out (0 )
182+ }
183+
184+ service .methods [key ] = & methodInfo {
185+ descriptor : md ,
186+ method : rm ,
187+ reflectType : mt ,
188+ inputType : it ,
189+ outputType : ot ,
190+ messageType : pm ,
191+ validated : false ,
192+ }
134193 }
135194 }
136195
@@ -173,10 +232,10 @@ func (s *Service) HandlerFunc(w http.ResponseWriter, r *http.Request) {
173232}
174233
175234// isWebSocketRequest checks if the HTTP request is requesting a WebSocket upgrade
235+ // Optimized version with reduced string allocations
176236func (s * Service ) isWebSocketRequest (r * http.Request ) bool {
177- connection := strings .ToLower (r .Header .Get ("Connection" ))
178- upgrade := strings .ToLower (r .Header .Get ("Upgrade" ))
179- return strings .Contains (connection , "upgrade" ) && upgrade == "websocket"
237+ return strings .Contains (strings .ToLower (r .Header .Get ("Connection" )), "upgrade" ) &&
238+ strings .EqualFold (r .Header .Get ("Upgrade" ), "websocket" )
180239}
181240
182241// unary processes HTTP POST requests to API endpoints.
@@ -266,14 +325,15 @@ func (s *Service) websocket(w http.ResponseWriter, r *http.Request, conn *websoc
266325 return
267326 }
268327
269- m := reflect .ValueOf (s .Server ).MethodByName (method )
270- if ! m .IsValid () {
271- s .closeWS (conn , websocket .CloseInternalServerErr , "method not found" )
328+ if ! md .method .IsValid () {
329+ s .closeWS (conn , websocket .CloseInternalServerErr , "service not registered" )
272330 return
273331 }
274- mt := m .Type ()
275332
276- streamingType , err := s .validateMethodSignature (mt , md )
333+ m := md .method
334+ mt := md .reflectType
335+
336+ streamingType , err := s .validateMethodSignature (mt , md .descriptor )
277337 if err != nil {
278338 s .closeWS (conn , websocket .CloseInternalServerErr , "invalid method signature: " + err .Error ())
279339 return
@@ -371,20 +431,37 @@ func GetWebSocketConn(ctx context.Context) (*websocket.Conn, bool) {
371431}
372432
373433func (s * Service ) call (ctx context.Context , method string , req proto.Message ) (any , error ) {
374- m := reflect .ValueOf (s .Server ).MethodByName (method )
375- if ! m .IsValid () {
376- return nil , apperror .NewError ("method not found" )
434+ var methodInfo * methodInfo
435+ for key , info := range s .methods {
436+ if strings .HasSuffix (key , "." + method ) {
437+ methodInfo = info
438+ break
439+ }
377440 }
378441
379- mt := m .Type ()
380- if mt .NumIn () != 2 || mt .NumOut () != 2 {
381- return nil , errors .New ("invalid method signature" )
442+ if methodInfo == nil {
443+ return nil , apperror .NewError ("method not found" )
382444 }
383- if ! mt .In (0 ).Implements (reflect .TypeOf ((* context .Context )(nil )).Elem ()) {
384- return nil , errors .New ("first argument must be context.Context" )
445+
446+ if ! methodInfo .method .IsValid () {
447+ return nil , apperror .NewError ("method reflection data not found" )
385448 }
386- if ! mt .Out (1 ).Implements (reflect .TypeOf ((* error )(nil )).Elem ()) {
387- return nil , errors .New ("second return value must be error" )
449+
450+ m := methodInfo .method
451+ mt := methodInfo .reflectType
452+
453+ // Validate method signature if not already validated
454+ if ! methodInfo .validated {
455+ if mt .NumIn () != 2 || mt .NumOut () != 2 {
456+ return nil , errors .New ("invalid method signature" )
457+ }
458+ if ! mt .In (0 ).Implements (reflect .TypeOf ((* context .Context )(nil )).Elem ()) {
459+ return nil , errors .New ("first argument must be context.Context" )
460+ }
461+ if ! mt .Out (1 ).Implements (reflect .TypeOf ((* error )(nil )).Elem ()) {
462+ return nil , errors .New ("second return value must be error" )
463+ }
464+ methodInfo .validated = true
388465 }
389466
390467 wanted := mt .In (1 )
@@ -399,6 +476,10 @@ func (s *Service) call(ctx context.Context, method string, req proto.Message) (a
399476
400477 if ! reqVal .Type ().AssignableTo (wanted ) {
401478 // Convert via JSON round-trip using protojson to the expected type.
479+ // Use buffer pool for better performance
480+ buf := bufferPool .Get ().([]byte )
481+ defer bufferPool .Put (buf [:0 ])
482+
402483 reqPtr := reflect .New (wanted .Elem ())
403484 b , err := protojson .Marshal (req )
404485 if err != nil {
@@ -423,34 +504,26 @@ func (s *Service) call(ctx context.Context, method string, req proto.Message) (a
423504 return res , err
424505}
425506
426- func (s * Service ) find (service , method string ) (protoreflect.MethodDescriptor , error ) {
427- // Use cached method lookup for better performance
428- key := service + "." + method
429- if md , exists := s .methods [key ]; exists {
430- return md , nil
431- }
432-
433- // Fallback to dynamic lookup if not found in cache (should not happen in normal operation)
434- sd := s .Descriptor ().Services ().ByName (protoreflect .Name (service ))
435- if sd == nil {
436- return nil , apperror .NewError ("service not found" )
507+ func (s * Service ) find (service , method string ) (* methodInfo , error ) {
508+ key , exists := s .keys [method ]
509+ if ! exists {
510+ return nil , apperror .NewError ("method not found" )
437511 }
438512
439- md := sd . Methods (). ByName ( protoreflect . Name ( method ))
440- if md == nil {
513+ md , exists := s . methods [ key ]
514+ if ! exists {
441515 return nil , apperror .NewError ("method not found" )
442516 }
517+
443518 return md , nil
444519}
445520
446- func (s * Service ) message (md protoreflect.MethodDescriptor ) (proto.Message , error ) {
447- mt , err := protoregistry .GlobalTypes .FindMessageByName (md .Input ().FullName ())
448- if err != nil {
449- log .Error ().Err (err ).Msg ("failed to find message type" )
450- return dynamicpb .NewMessage (md .Input ()), nil
521+ func (s * Service ) message (md * methodInfo ) (proto.Message , error ) {
522+ if md .messageType == nil {
523+ return nil , apperror .NewError ("message type not found" )
451524 }
452525
453- return mt . New (). Interface ( ), nil
526+ return proto . Clone ( md . messageType ), nil
454527}
455528
456529func (s * Service ) marshal (v any ) ([]byte , error ) {
@@ -530,7 +603,7 @@ func (s *Service) handleBidirectionalStream(ctx context.Context, conn *websocket
530603}
531604
532605// handleServerStream handles server streaming WebSocket connections
533- func (s * Service ) handleServerStream (ctx context.Context , conn * websocket.Conn , m reflect.Value , mt reflect.Type , md protoreflect. MethodDescriptor ) {
606+ func (s * Service ) handleServerStream (ctx context.Context , conn * websocket.Conn , m reflect.Value , mt reflect.Type , md * methodInfo ) {
534607 outType := mt .In (2 )
535608 outPtr := outType .Elem ()
536609 out := reflect .MakeChan (outType , 0 )
0 commit comments