@@ -80,6 +80,28 @@ import (
8080 "google.golang.org/protobuf/types/dynamicpb"
8181)
8282
83+ // Common error messages to avoid repeated allocations
84+ var (
85+ errMethodNotFound = apperror .NewError ("method not found" )
86+ errMethodReflectionNotFound = apperror .NewError ("method reflection data not found" )
87+ errInvalidMethodSignature = errors .New ("invalid method signature" )
88+ errFirstArgMustBeContext = errors .New ("first argument must be context.Context" )
89+ errSecondReturnMustBeError = errors .New ("second return value must be error" )
90+ errRequestMustBePointer = errors .New ("request must be a pointer" )
91+ errNilRequest = errors .New ("nil request" )
92+ errExpectedProtoMessage = errors .New ("expected proto.Message for request" )
93+
94+ // Cached reflection types to avoid repeated type operations
95+ contextType = reflect .TypeOf ((* context .Context )(nil )).Elem ()
96+ errorType = reflect .TypeOf ((* error )(nil )).Elem ()
97+
98+ // Reusable marshal options to avoid allocation
99+ marshalOpts = protojson.MarshalOptions {
100+ EmitUnpopulated : true ,
101+ UseProtoNames : true ,
102+ }
103+ )
104+
83105// upgrader is the WebSocket upgrader with default options
84106var upgrader = websocket.Upgrader {
85107 CheckOrigin : func (r * http.Request ) bool {
@@ -112,7 +134,6 @@ const (
112134type Service struct {
113135 Server
114136 methods map [string ]* methodInfo // cached method information for faster lookup
115- keys map [string ]string // pre-computed method keys to avoid string concatenation
116137 types map [protoreflect.FullName ]proto.Message // cached message types
117138}
118139
@@ -140,7 +161,6 @@ func Register(s Server) *Service {
140161 service := & Service {
141162 Server : s ,
142163 methods : make (map [string ]* methodInfo ),
143- keys : make (map [string ]string ),
144164 types : make (map [protoreflect.FullName ]proto.Message ),
145165 }
146166
@@ -155,7 +175,6 @@ func Register(s Server) *Service {
155175 sn := string (sd .Name ())
156176
157177 key := sn + "." + mn
158- service .keys [mn ] = key
159178
160179 rm := sv .MethodByName (mn )
161180 if ! rm .IsValid () {
@@ -267,20 +286,32 @@ func (s *Service) unary(w http.ResponseWriter, r *http.Request) {
267286 }
268287
269288 if r .ContentLength > 0 {
270- body , _ := io .ReadAll (r .Body )
271- if len (body ) != int (r .ContentLength ) {
272- http .Error (w , "body length does not match Content-Length" , http .StatusBadRequest )
289+ // Use buffer pool for body reading
290+ buf := bufferPool .Get ().([]byte )
291+ defer bufferPool .Put (buf [:0 ])
292+
293+ // Ensure buffer is large enough
294+ if cap (buf ) < int (r .ContentLength ) {
295+ buf = make ([]byte , r .ContentLength )
296+ } else {
297+ buf = buf [:r .ContentLength ]
298+ }
299+
300+ _ , err := io .ReadFull (r .Body , buf )
301+ if err != nil {
302+ http .Error (w , "failed to read request body" , http .StatusBadRequest )
273303 return
274304 }
275- err = protojson .Unmarshal (body , msg )
305+
306+ err = protojson .Unmarshal (buf , msg )
276307 if err != nil {
277308 http .Error (w , err .Error (), http .StatusBadRequest )
278309 return
279310 }
280311 }
281312 defer apperror .Catch (r .Body .Close , "closing request body failed" )
282313
283- resp , err := s .call (ctx , method , msg )
314+ resp , err := s .call (ctx , service , method , msg )
284315 if err != nil {
285316 http .Error (w , err .Error (), http .StatusInternalServerError )
286317 return
@@ -430,21 +461,14 @@ func GetWebSocketConn(ctx context.Context) (*websocket.Conn, bool) {
430461 return conn , ok
431462}
432463
433- func (s * Service ) call (ctx context.Context , method string , req proto.Message ) (any , error ) {
434- var methodInfo * methodInfo
435- for key , info := range s .methods {
436- if strings .HasSuffix (key , "." + method ) {
437- methodInfo = info
438- break
439- }
440- }
441-
442- if methodInfo == nil {
443- return nil , apperror .NewError ("method not found" )
464+ func (s * Service ) call (ctx context.Context , service , method string , req proto.Message ) (any , error ) {
465+ methodInfo , exists := s .methods [service + "." + method ]
466+ if ! exists {
467+ return nil , errMethodNotFound
444468 }
445469
446470 if ! methodInfo .method .IsValid () {
447- return nil , apperror . NewError ( "method reflection data not found" )
471+ return nil , errMethodReflectionNotFound
448472 }
449473
450474 m := methodInfo .method
@@ -453,25 +477,25 @@ func (s *Service) call(ctx context.Context, method string, req proto.Message) (a
453477 // Validate method signature if not already validated
454478 if ! methodInfo .validated {
455479 if mt .NumIn () != 2 || mt .NumOut () != 2 {
456- return nil , errors . New ( "invalid method signature" )
480+ return nil , errInvalidMethodSignature
457481 }
458- if ! mt .In (0 ).Implements (reflect . TypeOf (( * context . Context )( nil )). Elem () ) {
459- return nil , errors . New ( "first argument must be context.Context" )
482+ if ! mt .In (0 ).Implements (contextType ) {
483+ return nil , errFirstArgMustBeContext
460484 }
461- if ! mt .Out (1 ).Implements (reflect . TypeOf (( * error )( nil )). Elem () ) {
462- return nil , errors . New ( "second return value must be error" )
485+ if ! mt .Out (1 ).Implements (errorType ) {
486+ return nil , errSecondReturnMustBeError
463487 }
464488 methodInfo .validated = true
465489 }
466490
467491 wanted := mt .In (1 )
468492 if wanted .Kind () != reflect .Ptr {
469- return nil , errors . New ( "request must be a pointer" )
493+ return nil , errRequestMustBePointer
470494 }
471495
472496 reqVal := reflect .ValueOf (req )
473497 if ! reqVal .IsValid () {
474- return nil , errors . New ( "nil request" )
498+ return nil , errNilRequest
475499 }
476500
477501 if ! reqVal .Type ().AssignableTo (wanted ) {
@@ -487,7 +511,7 @@ func (s *Service) call(ctx context.Context, method string, req proto.Message) (a
487511 }
488512 pm , ok := reqPtr .Interface ().(proto.Message )
489513 if ! ok {
490- return nil , errors . New ( "expected proto.Message for request" )
514+ return nil , errExpectedProtoMessage
491515 }
492516 if err := protojson .Unmarshal (b , pm ); err != nil {
493517 return nil , err
@@ -505,14 +529,9 @@ func (s *Service) call(ctx context.Context, method string, req proto.Message) (a
505529}
506530
507531func (s * Service ) find (service , method string ) (* methodInfo , error ) {
508- key , exists := s .keys [ method ]
532+ md , exists := s .methods [ service + "." + method ]
509533 if ! exists {
510- return nil , apperror .NewError ("method not found" )
511- }
512-
513- md , exists := s .methods [key ]
514- if ! exists {
515- return nil , apperror .NewError ("method not found" )
534+ return nil , errMethodNotFound
516535 }
517536
518537 return md , nil
@@ -528,18 +547,12 @@ func (s *Service) message(md *methodInfo) (proto.Message, error) {
528547
529548func (s * Service ) marshal (v any ) ([]byte , error ) {
530549 if pm , ok := v .(proto.Message ); ok {
531- return protojson.MarshalOptions {
532- EmitUnpopulated : true ,
533- UseProtoNames : true ,
534- }.Marshal (pm )
550+ return marshalOpts .Marshal (pm )
535551 }
536552 rv := reflect .ValueOf (v )
537553 if rv .IsValid () && rv .CanAddr () {
538554 if pm , ok := rv .Addr ().Interface ().(proto.Message ); ok {
539- return protojson.MarshalOptions {
540- EmitUnpopulated : true ,
541- UseProtoNames : true ,
542- }.Marshal (pm )
555+ return marshalOpts .Marshal (pm )
543556 }
544557 }
545558 // Handle non-pointer values by creating a pointer to them
@@ -550,10 +563,7 @@ func (s *Service) marshal(v any) ([]byte, error) {
550563 newPtr := reflect .New (rv .Type ())
551564 newPtr .Elem ().Set (rv )
552565 if pm , ok := newPtr .Interface ().(proto.Message ); ok {
553- return protojson.MarshalOptions {
554- EmitUnpopulated : true ,
555- UseProtoNames : true ,
556- }.Marshal (pm )
566+ return marshalOpts .Marshal (pm )
557567 }
558568 }
559569 }
@@ -871,9 +881,6 @@ const (
871881
872882// validateMethodSignature validates and determines the streaming type of a method
873883func (s * Service ) validateMethodSignature (mt reflect.Type , md protoreflect.MethodDescriptor ) (StreamingType , error ) {
874- contextType := reflect .TypeOf ((* context .Context )(nil )).Elem ()
875- errorType := reflect .TypeOf ((* error )(nil )).Elem ()
876-
877884 // Basic validation: must have at least context parameter and error return
878885 if mt .NumIn () < 1 || ! mt .In (0 ).Implements (contextType ) {
879886 return StreamingTypeInvalid , errors .New ("first parameter must be context.Context" )
0 commit comments