Skip to content

Commit 5aa30bc

Browse files
[FIX] optimize JRPC service with error constants and buffer pool
1 parent 29fe167 commit 5aa30bc

File tree

1 file changed

+57
-50
lines changed

1 file changed

+57
-50
lines changed

web/jrpc/jrpc.go

Lines changed: 57 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,28 @@ import (
8080
"google.golang.org/protobuf/types/dynamicpb"
8181
)
8282

83+
// Common error messages to avoid repeated allocations
84+
var (
85+
errMethodNotFound = apperror.NewError("method not found")
86+
errMethodReflectionNotFound = apperror.NewError("method reflection data not found")
87+
errInvalidMethodSignature = errors.New("invalid method signature")
88+
errFirstArgMustBeContext = errors.New("first argument must be context.Context")
89+
errSecondReturnMustBeError = errors.New("second return value must be error")
90+
errRequestMustBePointer = errors.New("request must be a pointer")
91+
errNilRequest = errors.New("nil request")
92+
errExpectedProtoMessage = errors.New("expected proto.Message for request")
93+
94+
// Cached reflection types to avoid repeated type operations
95+
contextType = reflect.TypeOf((*context.Context)(nil)).Elem()
96+
errorType = reflect.TypeOf((*error)(nil)).Elem()
97+
98+
// Reusable marshal options to avoid allocation
99+
marshalOpts = protojson.MarshalOptions{
100+
EmitUnpopulated: true,
101+
UseProtoNames: true,
102+
}
103+
)
104+
83105
// upgrader is the WebSocket upgrader with default options
84106
var upgrader = websocket.Upgrader{
85107
CheckOrigin: func(r *http.Request) bool {
@@ -112,7 +134,6 @@ const (
112134
type Service struct {
113135
Server
114136
methods map[string]*methodInfo // cached method information for faster lookup
115-
keys map[string]string // pre-computed method keys to avoid string concatenation
116137
types map[protoreflect.FullName]proto.Message // cached message types
117138
}
118139

@@ -140,7 +161,6 @@ func Register(s Server) *Service {
140161
service := &Service{
141162
Server: s,
142163
methods: make(map[string]*methodInfo),
143-
keys: make(map[string]string),
144164
types: make(map[protoreflect.FullName]proto.Message),
145165
}
146166

@@ -155,7 +175,6 @@ func Register(s Server) *Service {
155175
sn := string(sd.Name())
156176

157177
key := sn + "." + mn
158-
service.keys[mn] = key
159178

160179
rm := sv.MethodByName(mn)
161180
if !rm.IsValid() {
@@ -267,20 +286,32 @@ func (s *Service) unary(w http.ResponseWriter, r *http.Request) {
267286
}
268287

269288
if r.ContentLength > 0 {
270-
body, _ := io.ReadAll(r.Body)
271-
if len(body) != int(r.ContentLength) {
272-
http.Error(w, "body length does not match Content-Length", http.StatusBadRequest)
289+
// Use buffer pool for body reading
290+
buf := bufferPool.Get().([]byte)
291+
defer bufferPool.Put(buf[:0])
292+
293+
// Ensure buffer is large enough
294+
if cap(buf) < int(r.ContentLength) {
295+
buf = make([]byte, r.ContentLength)
296+
} else {
297+
buf = buf[:r.ContentLength]
298+
}
299+
300+
_, err := io.ReadFull(r.Body, buf)
301+
if err != nil {
302+
http.Error(w, "failed to read request body", http.StatusBadRequest)
273303
return
274304
}
275-
err = protojson.Unmarshal(body, msg)
305+
306+
err = protojson.Unmarshal(buf, msg)
276307
if err != nil {
277308
http.Error(w, err.Error(), http.StatusBadRequest)
278309
return
279310
}
280311
}
281312
defer apperror.Catch(r.Body.Close, "closing request body failed")
282313

283-
resp, err := s.call(ctx, method, msg)
314+
resp, err := s.call(ctx, service, method, msg)
284315
if err != nil {
285316
http.Error(w, err.Error(), http.StatusInternalServerError)
286317
return
@@ -430,21 +461,14 @@ func GetWebSocketConn(ctx context.Context) (*websocket.Conn, bool) {
430461
return conn, ok
431462
}
432463

433-
func (s *Service) call(ctx context.Context, method string, req proto.Message) (any, error) {
434-
var methodInfo *methodInfo
435-
for key, info := range s.methods {
436-
if strings.HasSuffix(key, "."+method) {
437-
methodInfo = info
438-
break
439-
}
440-
}
441-
442-
if methodInfo == nil {
443-
return nil, apperror.NewError("method not found")
464+
func (s *Service) call(ctx context.Context, service, method string, req proto.Message) (any, error) {
465+
methodInfo, exists := s.methods[service+"."+method]
466+
if !exists {
467+
return nil, errMethodNotFound
444468
}
445469

446470
if !methodInfo.method.IsValid() {
447-
return nil, apperror.NewError("method reflection data not found")
471+
return nil, errMethodReflectionNotFound
448472
}
449473

450474
m := methodInfo.method
@@ -453,25 +477,25 @@ func (s *Service) call(ctx context.Context, method string, req proto.Message) (a
453477
// Validate method signature if not already validated
454478
if !methodInfo.validated {
455479
if mt.NumIn() != 2 || mt.NumOut() != 2 {
456-
return nil, errors.New("invalid method signature")
480+
return nil, errInvalidMethodSignature
457481
}
458-
if !mt.In(0).Implements(reflect.TypeOf((*context.Context)(nil)).Elem()) {
459-
return nil, errors.New("first argument must be context.Context")
482+
if !mt.In(0).Implements(contextType) {
483+
return nil, errFirstArgMustBeContext
460484
}
461-
if !mt.Out(1).Implements(reflect.TypeOf((*error)(nil)).Elem()) {
462-
return nil, errors.New("second return value must be error")
485+
if !mt.Out(1).Implements(errorType) {
486+
return nil, errSecondReturnMustBeError
463487
}
464488
methodInfo.validated = true
465489
}
466490

467491
wanted := mt.In(1)
468492
if wanted.Kind() != reflect.Ptr {
469-
return nil, errors.New("request must be a pointer")
493+
return nil, errRequestMustBePointer
470494
}
471495

472496
reqVal := reflect.ValueOf(req)
473497
if !reqVal.IsValid() {
474-
return nil, errors.New("nil request")
498+
return nil, errNilRequest
475499
}
476500

477501
if !reqVal.Type().AssignableTo(wanted) {
@@ -487,7 +511,7 @@ func (s *Service) call(ctx context.Context, method string, req proto.Message) (a
487511
}
488512
pm, ok := reqPtr.Interface().(proto.Message)
489513
if !ok {
490-
return nil, errors.New("expected proto.Message for request")
514+
return nil, errExpectedProtoMessage
491515
}
492516
if err := protojson.Unmarshal(b, pm); err != nil {
493517
return nil, err
@@ -505,14 +529,9 @@ func (s *Service) call(ctx context.Context, method string, req proto.Message) (a
505529
}
506530

507531
func (s *Service) find(service, method string) (*methodInfo, error) {
508-
key, exists := s.keys[method]
532+
md, exists := s.methods[service+"."+method]
509533
if !exists {
510-
return nil, apperror.NewError("method not found")
511-
}
512-
513-
md, exists := s.methods[key]
514-
if !exists {
515-
return nil, apperror.NewError("method not found")
534+
return nil, errMethodNotFound
516535
}
517536

518537
return md, nil
@@ -528,18 +547,12 @@ func (s *Service) message(md *methodInfo) (proto.Message, error) {
528547

529548
func (s *Service) marshal(v any) ([]byte, error) {
530549
if pm, ok := v.(proto.Message); ok {
531-
return protojson.MarshalOptions{
532-
EmitUnpopulated: true,
533-
UseProtoNames: true,
534-
}.Marshal(pm)
550+
return marshalOpts.Marshal(pm)
535551
}
536552
rv := reflect.ValueOf(v)
537553
if rv.IsValid() && rv.CanAddr() {
538554
if pm, ok := rv.Addr().Interface().(proto.Message); ok {
539-
return protojson.MarshalOptions{
540-
EmitUnpopulated: true,
541-
UseProtoNames: true,
542-
}.Marshal(pm)
555+
return marshalOpts.Marshal(pm)
543556
}
544557
}
545558
// Handle non-pointer values by creating a pointer to them
@@ -550,10 +563,7 @@ func (s *Service) marshal(v any) ([]byte, error) {
550563
newPtr := reflect.New(rv.Type())
551564
newPtr.Elem().Set(rv)
552565
if pm, ok := newPtr.Interface().(proto.Message); ok {
553-
return protojson.MarshalOptions{
554-
EmitUnpopulated: true,
555-
UseProtoNames: true,
556-
}.Marshal(pm)
566+
return marshalOpts.Marshal(pm)
557567
}
558568
}
559569
}
@@ -871,9 +881,6 @@ const (
871881

872882
// validateMethodSignature validates and determines the streaming type of a method
873883
func (s *Service) validateMethodSignature(mt reflect.Type, md protoreflect.MethodDescriptor) (StreamingType, error) {
874-
contextType := reflect.TypeOf((*context.Context)(nil)).Elem()
875-
errorType := reflect.TypeOf((*error)(nil)).Elem()
876-
877884
// Basic validation: must have at least context parameter and error return
878885
if mt.NumIn() < 1 || !mt.In(0).Implements(contextType) {
879886
return StreamingTypeInvalid, errors.New("first parameter must be context.Context")

0 commit comments

Comments
 (0)