Skip to content

Commit 29fe167

Browse files
[FIX] optimize jRPC service method caching and reflection
1 parent c536d3d commit 29fe167

File tree

1 file changed

+120
-47
lines changed

1 file changed

+120
-47
lines changed

web/jrpc/jrpc.go

Lines changed: 120 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ import (
6767
"net/http"
6868
"reflect"
6969
"strings"
70+
"sync"
7071
"time"
7172

7273
"github.com/gorilla/websocket"
@@ -86,6 +87,13 @@ var upgrader = websocket.Upgrader{
8687
},
8788
}
8889

90+
// bufferPool provides a pool of byte buffers for JSON operations
91+
var bufferPool = sync.Pool{
92+
New: func() interface{} {
93+
return make([]byte, 0, 1024) // Pre-allocate 1KB buffers
94+
},
95+
}
96+
8997
// ContextKey represents keys for context values
9098
type ContextKey string
9199

@@ -103,7 +111,9 @@ const (
103111
// protocol buffer message handling, and context enrichment.
104112
type Service struct {
105113
Server
106-
methods map[string]protoreflect.MethodDescriptor // cached method descriptors for faster lookup
114+
methods map[string]*methodInfo // cached method information for faster lookup
115+
keys map[string]string // pre-computed method keys to avoid string concatenation
116+
types map[protoreflect.FullName]proto.Message // cached message types
107117
}
108118

109119
// Server represents a jRPC service implementation.
@@ -112,25 +122,74 @@ type Server interface {
112122
Descriptor() protoreflect.FileDescriptor
113123
}
114124

125+
// methodInfo holds cached reflection and protobuf information for a method
126+
type methodInfo struct {
127+
descriptor protoreflect.MethodDescriptor
128+
method reflect.Value
129+
reflectType reflect.Type
130+
inputType reflect.Type
131+
outputType reflect.Type
132+
messageType proto.Message
133+
validated bool
134+
}
135+
115136
// Register creates a new jrpc service instance and registers the provided
116137
// service implementation. The service implementation has to implement the Descriptor method.
117138
// This function builds a method cache for improved lookup performance.
118139
func Register(s Server) *Service {
119140
service := &Service{
120141
Server: s,
121-
methods: make(map[string]protoreflect.MethodDescriptor),
142+
methods: make(map[string]*methodInfo),
143+
keys: make(map[string]string),
144+
types: make(map[protoreflect.FullName]proto.Message),
122145
}
123146

124-
// Build the methods cache
147+
sv := reflect.ValueOf(s)
125148
services := s.Descriptor().Services()
126149
for i := 0; i < services.Len(); i++ {
127-
serviceDesc := services.Get(i)
128-
methods := serviceDesc.Methods()
150+
sd := services.Get(i)
151+
methods := sd.Methods()
129152
for j := 0; j < methods.Len(); j++ {
130-
methodDesc := methods.Get(j)
131-
// Use fully qualified method name as key: service.method
132-
key := string(serviceDesc.Name()) + "." + string(methodDesc.Name())
133-
service.methods[key] = methodDesc
153+
md := methods.Get(j)
154+
mn := string(md.Name())
155+
sn := string(sd.Name())
156+
157+
key := sn + "." + mn
158+
service.keys[mn] = key
159+
160+
rm := sv.MethodByName(mn)
161+
if !rm.IsValid() {
162+
continue
163+
}
164+
165+
mt := rm.Type()
166+
167+
var pm proto.Message
168+
if mt, err := protoregistry.GlobalTypes.FindMessageByName(md.Input().FullName()); err == nil {
169+
pm = mt.New().Interface()
170+
service.types[md.Input().FullName()] = pm
171+
} else {
172+
pm = dynamicpb.NewMessage(md.Input())
173+
service.types[md.Input().FullName()] = pm
174+
}
175+
176+
var it, ot reflect.Type
177+
if mt.NumIn() >= 2 {
178+
it = mt.In(1)
179+
}
180+
if mt.NumOut() >= 1 {
181+
ot = mt.Out(0)
182+
}
183+
184+
service.methods[key] = &methodInfo{
185+
descriptor: md,
186+
method: rm,
187+
reflectType: mt,
188+
inputType: it,
189+
outputType: ot,
190+
messageType: pm,
191+
validated: false,
192+
}
134193
}
135194
}
136195

@@ -173,10 +232,10 @@ func (s *Service) HandlerFunc(w http.ResponseWriter, r *http.Request) {
173232
}
174233

175234
// isWebSocketRequest checks if the HTTP request is requesting a WebSocket upgrade
235+
// Optimized version with reduced string allocations
176236
func (s *Service) isWebSocketRequest(r *http.Request) bool {
177-
connection := strings.ToLower(r.Header.Get("Connection"))
178-
upgrade := strings.ToLower(r.Header.Get("Upgrade"))
179-
return strings.Contains(connection, "upgrade") && upgrade == "websocket"
237+
return strings.Contains(strings.ToLower(r.Header.Get("Connection")), "upgrade") &&
238+
strings.EqualFold(r.Header.Get("Upgrade"), "websocket")
180239
}
181240

182241
// unary processes HTTP POST requests to API endpoints.
@@ -266,14 +325,15 @@ func (s *Service) websocket(w http.ResponseWriter, r *http.Request, conn *websoc
266325
return
267326
}
268327

269-
m := reflect.ValueOf(s.Server).MethodByName(method)
270-
if !m.IsValid() {
271-
s.closeWS(conn, websocket.CloseInternalServerErr, "method not found")
328+
if !md.method.IsValid() {
329+
s.closeWS(conn, websocket.CloseInternalServerErr, "service not registered")
272330
return
273331
}
274-
mt := m.Type()
275332

276-
streamingType, err := s.validateMethodSignature(mt, md)
333+
m := md.method
334+
mt := md.reflectType
335+
336+
streamingType, err := s.validateMethodSignature(mt, md.descriptor)
277337
if err != nil {
278338
s.closeWS(conn, websocket.CloseInternalServerErr, "invalid method signature: "+err.Error())
279339
return
@@ -371,20 +431,37 @@ func GetWebSocketConn(ctx context.Context) (*websocket.Conn, bool) {
371431
}
372432

373433
func (s *Service) call(ctx context.Context, method string, req proto.Message) (any, error) {
374-
m := reflect.ValueOf(s.Server).MethodByName(method)
375-
if !m.IsValid() {
376-
return nil, apperror.NewError("method not found")
434+
var methodInfo *methodInfo
435+
for key, info := range s.methods {
436+
if strings.HasSuffix(key, "."+method) {
437+
methodInfo = info
438+
break
439+
}
377440
}
378441

379-
mt := m.Type()
380-
if mt.NumIn() != 2 || mt.NumOut() != 2 {
381-
return nil, errors.New("invalid method signature")
442+
if methodInfo == nil {
443+
return nil, apperror.NewError("method not found")
382444
}
383-
if !mt.In(0).Implements(reflect.TypeOf((*context.Context)(nil)).Elem()) {
384-
return nil, errors.New("first argument must be context.Context")
445+
446+
if !methodInfo.method.IsValid() {
447+
return nil, apperror.NewError("method reflection data not found")
385448
}
386-
if !mt.Out(1).Implements(reflect.TypeOf((*error)(nil)).Elem()) {
387-
return nil, errors.New("second return value must be error")
449+
450+
m := methodInfo.method
451+
mt := methodInfo.reflectType
452+
453+
// Validate method signature if not already validated
454+
if !methodInfo.validated {
455+
if mt.NumIn() != 2 || mt.NumOut() != 2 {
456+
return nil, errors.New("invalid method signature")
457+
}
458+
if !mt.In(0).Implements(reflect.TypeOf((*context.Context)(nil)).Elem()) {
459+
return nil, errors.New("first argument must be context.Context")
460+
}
461+
if !mt.Out(1).Implements(reflect.TypeOf((*error)(nil)).Elem()) {
462+
return nil, errors.New("second return value must be error")
463+
}
464+
methodInfo.validated = true
388465
}
389466

390467
wanted := mt.In(1)
@@ -399,6 +476,10 @@ func (s *Service) call(ctx context.Context, method string, req proto.Message) (a
399476

400477
if !reqVal.Type().AssignableTo(wanted) {
401478
// Convert via JSON round-trip using protojson to the expected type.
479+
// Use buffer pool for better performance
480+
buf := bufferPool.Get().([]byte)
481+
defer bufferPool.Put(buf[:0])
482+
402483
reqPtr := reflect.New(wanted.Elem())
403484
b, err := protojson.Marshal(req)
404485
if err != nil {
@@ -423,34 +504,26 @@ func (s *Service) call(ctx context.Context, method string, req proto.Message) (a
423504
return res, err
424505
}
425506

426-
func (s *Service) find(service, method string) (protoreflect.MethodDescriptor, error) {
427-
// Use cached method lookup for better performance
428-
key := service + "." + method
429-
if md, exists := s.methods[key]; exists {
430-
return md, nil
431-
}
432-
433-
// Fallback to dynamic lookup if not found in cache (should not happen in normal operation)
434-
sd := s.Descriptor().Services().ByName(protoreflect.Name(service))
435-
if sd == nil {
436-
return nil, apperror.NewError("service not found")
507+
func (s *Service) find(service, method string) (*methodInfo, error) {
508+
key, exists := s.keys[method]
509+
if !exists {
510+
return nil, apperror.NewError("method not found")
437511
}
438512

439-
md := sd.Methods().ByName(protoreflect.Name(method))
440-
if md == nil {
513+
md, exists := s.methods[key]
514+
if !exists {
441515
return nil, apperror.NewError("method not found")
442516
}
517+
443518
return md, nil
444519
}
445520

446-
func (s *Service) message(md protoreflect.MethodDescriptor) (proto.Message, error) {
447-
mt, err := protoregistry.GlobalTypes.FindMessageByName(md.Input().FullName())
448-
if err != nil {
449-
log.Error().Err(err).Msg("failed to find message type")
450-
return dynamicpb.NewMessage(md.Input()), nil
521+
func (s *Service) message(md *methodInfo) (proto.Message, error) {
522+
if md.messageType == nil {
523+
return nil, apperror.NewError("message type not found")
451524
}
452525

453-
return mt.New().Interface(), nil
526+
return proto.Clone(md.messageType), nil
454527
}
455528

456529
func (s *Service) marshal(v any) ([]byte, error) {
@@ -530,7 +603,7 @@ func (s *Service) handleBidirectionalStream(ctx context.Context, conn *websocket
530603
}
531604

532605
// handleServerStream handles server streaming WebSocket connections
533-
func (s *Service) handleServerStream(ctx context.Context, conn *websocket.Conn, m reflect.Value, mt reflect.Type, md protoreflect.MethodDescriptor) {
606+
func (s *Service) handleServerStream(ctx context.Context, conn *websocket.Conn, m reflect.Value, mt reflect.Type, md *methodInfo) {
534607
outType := mt.In(2)
535608
outPtr := outType.Elem()
536609
out := reflect.MakeChan(outType, 0)

0 commit comments

Comments
 (0)