Skip to content

Commit d59e2bf

Browse files
Merge pull request #51 from valentin-kaiser/service
optimize jRPC service method caching and reflection
2 parents d38c207 + 5aa30bc commit d59e2bf

File tree

1 file changed

+171
-64
lines changed

1 file changed

+171
-64
lines changed

web/jrpc/jrpc.go

Lines changed: 171 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
//
1212
// Features:
1313
// - HTTP and WebSocket endpoint support
14-
// - Automatic method resolution and dispatch
14+
// - Automatic method resolution and dispatch with cached lookups
1515
// - Protocol Buffer JSON marshaling/unmarshaling
1616
// - Multiple streaming patterns (unary, server, client, bidirectional)
1717
// - Context enrichment with HTTP and WebSocket components
@@ -21,7 +21,7 @@
2121
// 1. Define your service in a .proto file
2222
// 2. Generate Go code using protoc with the protoc-gen-jrpc plugin
2323
// 3. Implement the generated Server interface
24-
// 4. Create a new jRPC service with jrpc.New(yourServer)
24+
// 4. Create a new jRPC service with jrpc.Register(yourServer)
2525
// 5. Register the HandlerFunc with the web package function WithJRPC
2626
//
2727
// Example:
@@ -50,7 +50,7 @@
5050
// err := web.Instance().
5151
// WithHost("localhost").
5252
// WithPort(8080).
53-
// WithJRPC(jrpc.New(&MyService{})).
53+
// WithJRPC(jrpc.Register(&MyService{})).
5454
// Start().Error
5555
// if err != nil {
5656
// log.Fatal().Err(err).Msg("server exited")
@@ -67,6 +67,7 @@ import (
6767
"net/http"
6868
"reflect"
6969
"strings"
70+
"sync"
7071
"time"
7172

7273
"github.com/gorilla/websocket"
@@ -79,13 +80,42 @@ import (
7980
"google.golang.org/protobuf/types/dynamicpb"
8081
)
8182

83+
// Common error messages to avoid repeated allocations
84+
var (
85+
errMethodNotFound = apperror.NewError("method not found")
86+
errMethodReflectionNotFound = apperror.NewError("method reflection data not found")
87+
errInvalidMethodSignature = errors.New("invalid method signature")
88+
errFirstArgMustBeContext = errors.New("first argument must be context.Context")
89+
errSecondReturnMustBeError = errors.New("second return value must be error")
90+
errRequestMustBePointer = errors.New("request must be a pointer")
91+
errNilRequest = errors.New("nil request")
92+
errExpectedProtoMessage = errors.New("expected proto.Message for request")
93+
94+
// Cached reflection types to avoid repeated type operations
95+
contextType = reflect.TypeOf((*context.Context)(nil)).Elem()
96+
errorType = reflect.TypeOf((*error)(nil)).Elem()
97+
98+
// Reusable marshal options to avoid allocation
99+
marshalOpts = protojson.MarshalOptions{
100+
EmitUnpopulated: true,
101+
UseProtoNames: true,
102+
}
103+
)
104+
82105
// upgrader is the WebSocket upgrader with default options
83106
var upgrader = websocket.Upgrader{
84107
CheckOrigin: func(r *http.Request) bool {
85108
return true // Allow connections from any origin
86109
},
87110
}
88111

112+
// bufferPool provides a pool of byte buffers for JSON operations
113+
var bufferPool = sync.Pool{
114+
New: func() interface{} {
115+
return make([]byte, 0, 1024) // Pre-allocate 1KB buffers
116+
},
117+
}
118+
89119
// ContextKey represents keys for context values
90120
type ContextKey string
91121

@@ -103,6 +133,8 @@ const (
103133
// protocol buffer message handling, and context enrichment.
104134
type Service struct {
105135
Server
136+
methods map[string]*methodInfo // cached method information for faster lookup
137+
types map[protoreflect.FullName]proto.Message // cached message types
106138
}
107139

108140
// Server represents a jRPC service implementation.
@@ -111,10 +143,76 @@ type Server interface {
111143
Descriptor() protoreflect.FileDescriptor
112144
}
113145

114-
// New creates a new jrpc service instance and registers the provided
146+
// methodInfo holds cached reflection and protobuf information for a method
147+
type methodInfo struct {
148+
descriptor protoreflect.MethodDescriptor
149+
method reflect.Value
150+
reflectType reflect.Type
151+
inputType reflect.Type
152+
outputType reflect.Type
153+
messageType proto.Message
154+
validated bool
155+
}
156+
157+
// Register creates a new jrpc service instance and registers the provided
115158
// service implementation. The service implementation has to implement the Descriptor method.
116-
func New(s Server) *Service {
117-
return &Service{Server: s}
159+
// This function builds a method cache for improved lookup performance.
160+
func Register(s Server) *Service {
161+
service := &Service{
162+
Server: s,
163+
methods: make(map[string]*methodInfo),
164+
types: make(map[protoreflect.FullName]proto.Message),
165+
}
166+
167+
sv := reflect.ValueOf(s)
168+
services := s.Descriptor().Services()
169+
for i := 0; i < services.Len(); i++ {
170+
sd := services.Get(i)
171+
methods := sd.Methods()
172+
for j := 0; j < methods.Len(); j++ {
173+
md := methods.Get(j)
174+
mn := string(md.Name())
175+
sn := string(sd.Name())
176+
177+
key := sn + "." + mn
178+
179+
rm := sv.MethodByName(mn)
180+
if !rm.IsValid() {
181+
continue
182+
}
183+
184+
mt := rm.Type()
185+
186+
var pm proto.Message
187+
if mt, err := protoregistry.GlobalTypes.FindMessageByName(md.Input().FullName()); err == nil {
188+
pm = mt.New().Interface()
189+
service.types[md.Input().FullName()] = pm
190+
} else {
191+
pm = dynamicpb.NewMessage(md.Input())
192+
service.types[md.Input().FullName()] = pm
193+
}
194+
195+
var it, ot reflect.Type
196+
if mt.NumIn() >= 2 {
197+
it = mt.In(1)
198+
}
199+
if mt.NumOut() >= 1 {
200+
ot = mt.Out(0)
201+
}
202+
203+
service.methods[key] = &methodInfo{
204+
descriptor: md,
205+
method: rm,
206+
reflectType: mt,
207+
inputType: it,
208+
outputType: ot,
209+
messageType: pm,
210+
validated: false,
211+
}
212+
}
213+
}
214+
215+
return service
118216
}
119217

120218
// SetUpgrader allows setting a custom WebSocket upgrader with specific options.
@@ -153,10 +251,10 @@ func (s *Service) HandlerFunc(w http.ResponseWriter, r *http.Request) {
153251
}
154252

155253
// isWebSocketRequest checks if the HTTP request is requesting a WebSocket upgrade
254+
// Optimized version with reduced string allocations
156255
func (s *Service) isWebSocketRequest(r *http.Request) bool {
157-
connection := strings.ToLower(r.Header.Get("Connection"))
158-
upgrade := strings.ToLower(r.Header.Get("Upgrade"))
159-
return strings.Contains(connection, "upgrade") && upgrade == "websocket"
256+
return strings.Contains(strings.ToLower(r.Header.Get("Connection")), "upgrade") &&
257+
strings.EqualFold(r.Header.Get("Upgrade"), "websocket")
160258
}
161259

162260
// unary processes HTTP POST requests to API endpoints.
@@ -188,20 +286,32 @@ func (s *Service) unary(w http.ResponseWriter, r *http.Request) {
188286
}
189287

190288
if r.ContentLength > 0 {
191-
body, _ := io.ReadAll(r.Body)
192-
if len(body) != int(r.ContentLength) {
193-
http.Error(w, "body length does not match Content-Length", http.StatusBadRequest)
289+
// Use buffer pool for body reading
290+
buf := bufferPool.Get().([]byte)
291+
defer bufferPool.Put(buf[:0])
292+
293+
// Ensure buffer is large enough
294+
if cap(buf) < int(r.ContentLength) {
295+
buf = make([]byte, r.ContentLength)
296+
} else {
297+
buf = buf[:r.ContentLength]
298+
}
299+
300+
_, err := io.ReadFull(r.Body, buf)
301+
if err != nil {
302+
http.Error(w, "failed to read request body", http.StatusBadRequest)
194303
return
195304
}
196-
err = protojson.Unmarshal(body, msg)
305+
306+
err = protojson.Unmarshal(buf, msg)
197307
if err != nil {
198308
http.Error(w, err.Error(), http.StatusBadRequest)
199309
return
200310
}
201311
}
202312
defer apperror.Catch(r.Body.Close, "closing request body failed")
203313

204-
resp, err := s.call(ctx, method, msg)
314+
resp, err := s.call(ctx, service, method, msg)
205315
if err != nil {
206316
http.Error(w, err.Error(), http.StatusInternalServerError)
207317
return
@@ -246,14 +356,15 @@ func (s *Service) websocket(w http.ResponseWriter, r *http.Request, conn *websoc
246356
return
247357
}
248358

249-
m := reflect.ValueOf(s.Server).MethodByName(method)
250-
if !m.IsValid() {
251-
s.closeWS(conn, websocket.CloseInternalServerErr, "method not found")
359+
if !md.method.IsValid() {
360+
s.closeWS(conn, websocket.CloseInternalServerErr, "service not registered")
252361
return
253362
}
254-
mt := m.Type()
255363

256-
streamingType, err := s.validateMethodSignature(mt, md)
364+
m := md.method
365+
mt := md.reflectType
366+
367+
streamingType, err := s.validateMethodSignature(mt, md.descriptor)
257368
if err != nil {
258369
s.closeWS(conn, websocket.CloseInternalServerErr, "invalid method signature: "+err.Error())
259370
return
@@ -350,43 +461,57 @@ func GetWebSocketConn(ctx context.Context) (*websocket.Conn, bool) {
350461
return conn, ok
351462
}
352463

353-
func (s *Service) call(ctx context.Context, method string, req proto.Message) (any, error) {
354-
m := reflect.ValueOf(s.Server).MethodByName(method)
355-
if !m.IsValid() {
356-
return nil, apperror.NewError("method not found")
464+
func (s *Service) call(ctx context.Context, service, method string, req proto.Message) (any, error) {
465+
methodInfo, exists := s.methods[service+"."+method]
466+
if !exists {
467+
return nil, errMethodNotFound
357468
}
358469

359-
mt := m.Type()
360-
if mt.NumIn() != 2 || mt.NumOut() != 2 {
361-
return nil, errors.New("invalid method signature")
362-
}
363-
if !mt.In(0).Implements(reflect.TypeOf((*context.Context)(nil)).Elem()) {
364-
return nil, errors.New("first argument must be context.Context")
470+
if !methodInfo.method.IsValid() {
471+
return nil, errMethodReflectionNotFound
365472
}
366-
if !mt.Out(1).Implements(reflect.TypeOf((*error)(nil)).Elem()) {
367-
return nil, errors.New("second return value must be error")
473+
474+
m := methodInfo.method
475+
mt := methodInfo.reflectType
476+
477+
// Validate method signature if not already validated
478+
if !methodInfo.validated {
479+
if mt.NumIn() != 2 || mt.NumOut() != 2 {
480+
return nil, errInvalidMethodSignature
481+
}
482+
if !mt.In(0).Implements(contextType) {
483+
return nil, errFirstArgMustBeContext
484+
}
485+
if !mt.Out(1).Implements(errorType) {
486+
return nil, errSecondReturnMustBeError
487+
}
488+
methodInfo.validated = true
368489
}
369490

370491
wanted := mt.In(1)
371492
if wanted.Kind() != reflect.Ptr {
372-
return nil, errors.New("request must be a pointer")
493+
return nil, errRequestMustBePointer
373494
}
374495

375496
reqVal := reflect.ValueOf(req)
376497
if !reqVal.IsValid() {
377-
return nil, errors.New("nil request")
498+
return nil, errNilRequest
378499
}
379500

380501
if !reqVal.Type().AssignableTo(wanted) {
381502
// Convert via JSON round-trip using protojson to the expected type.
503+
// Use buffer pool for better performance
504+
buf := bufferPool.Get().([]byte)
505+
defer bufferPool.Put(buf[:0])
506+
382507
reqPtr := reflect.New(wanted.Elem())
383508
b, err := protojson.Marshal(req)
384509
if err != nil {
385510
return nil, err
386511
}
387512
pm, ok := reqPtr.Interface().(proto.Message)
388513
if !ok {
389-
return nil, errors.New("expected proto.Message for request")
514+
return nil, errExpectedProtoMessage
390515
}
391516
if err := protojson.Unmarshal(b, pm); err != nil {
392517
return nil, err
@@ -403,43 +528,31 @@ func (s *Service) call(ctx context.Context, method string, req proto.Message) (a
403528
return res, err
404529
}
405530

406-
func (s *Service) find(service, method string) (protoreflect.MethodDescriptor, error) {
407-
sd := s.Descriptor().Services().ByName(protoreflect.Name(service))
408-
if sd == nil {
409-
return nil, apperror.NewError("service not found")
531+
func (s *Service) find(service, method string) (*methodInfo, error) {
532+
md, exists := s.methods[service+"."+method]
533+
if !exists {
534+
return nil, errMethodNotFound
410535
}
411536

412-
md := sd.Methods().ByName(protoreflect.Name(method))
413-
if md == nil {
414-
return nil, apperror.NewError("method not found")
415-
}
416537
return md, nil
417538
}
418539

419-
func (s *Service) message(md protoreflect.MethodDescriptor) (proto.Message, error) {
420-
mt, err := protoregistry.GlobalTypes.FindMessageByName(md.Input().FullName())
421-
if err != nil {
422-
log.Error().Err(err).Msg("failed to find message type")
423-
return dynamicpb.NewMessage(md.Input()), nil
540+
func (s *Service) message(md *methodInfo) (proto.Message, error) {
541+
if md.messageType == nil {
542+
return nil, apperror.NewError("message type not found")
424543
}
425544

426-
return mt.New().Interface(), nil
545+
return proto.Clone(md.messageType), nil
427546
}
428547

429548
func (s *Service) marshal(v any) ([]byte, error) {
430549
if pm, ok := v.(proto.Message); ok {
431-
return protojson.MarshalOptions{
432-
EmitUnpopulated: true,
433-
UseProtoNames: true,
434-
}.Marshal(pm)
550+
return marshalOpts.Marshal(pm)
435551
}
436552
rv := reflect.ValueOf(v)
437553
if rv.IsValid() && rv.CanAddr() {
438554
if pm, ok := rv.Addr().Interface().(proto.Message); ok {
439-
return protojson.MarshalOptions{
440-
EmitUnpopulated: true,
441-
UseProtoNames: true,
442-
}.Marshal(pm)
555+
return marshalOpts.Marshal(pm)
443556
}
444557
}
445558
// Handle non-pointer values by creating a pointer to them
@@ -450,10 +563,7 @@ func (s *Service) marshal(v any) ([]byte, error) {
450563
newPtr := reflect.New(rv.Type())
451564
newPtr.Elem().Set(rv)
452565
if pm, ok := newPtr.Interface().(proto.Message); ok {
453-
return protojson.MarshalOptions{
454-
EmitUnpopulated: true,
455-
UseProtoNames: true,
456-
}.Marshal(pm)
566+
return marshalOpts.Marshal(pm)
457567
}
458568
}
459569
}
@@ -503,7 +613,7 @@ func (s *Service) handleBidirectionalStream(ctx context.Context, conn *websocket
503613
}
504614

505615
// handleServerStream handles server streaming WebSocket connections
506-
func (s *Service) handleServerStream(ctx context.Context, conn *websocket.Conn, m reflect.Value, mt reflect.Type, md protoreflect.MethodDescriptor) {
616+
func (s *Service) handleServerStream(ctx context.Context, conn *websocket.Conn, m reflect.Value, mt reflect.Type, md *methodInfo) {
507617
outType := mt.In(2)
508618
outPtr := outType.Elem()
509619
out := reflect.MakeChan(outType, 0)
@@ -771,9 +881,6 @@ const (
771881

772882
// validateMethodSignature validates and determines the streaming type of a method
773883
func (s *Service) validateMethodSignature(mt reflect.Type, md protoreflect.MethodDescriptor) (StreamingType, error) {
774-
contextType := reflect.TypeOf((*context.Context)(nil)).Elem()
775-
errorType := reflect.TypeOf((*error)(nil)).Elem()
776-
777884
// Basic validation: must have at least context parameter and error return
778885
if mt.NumIn() < 1 || !mt.In(0).Implements(contextType) {
779886
return StreamingTypeInvalid, errors.New("first parameter must be context.Context")

0 commit comments

Comments
 (0)