From 30f93671f3be67c5ef29563649b711829b18200d Mon Sep 17 00:00:00 2001 From: Carlo Alberto Ferraris Date: Fri, 17 Sep 2021 16:11:22 +0900 Subject: [PATCH] add a method filter --- filter/client_interceptors.go | 43 +++++++++++ filter/client_interceptors_test.go | 87 ++++++++++++++++++++++ filter/matchlist.go | 19 +++++ filter/matchlist_test.go | 55 ++++++++++++++ filter/server_interceptors.go | 65 ++++++++++++++++ filter/server_interceptors_test.go | 115 +++++++++++++++++++++++++++++ 6 files changed, 384 insertions(+) create mode 100644 filter/client_interceptors.go create mode 100644 filter/client_interceptors_test.go create mode 100644 filter/matchlist.go create mode 100644 filter/matchlist_test.go create mode 100644 filter/server_interceptors.go create mode 100644 filter/server_interceptors_test.go diff --git a/filter/client_interceptors.go b/filter/client_interceptors.go new file mode 100644 index 000000000..c65a346be --- /dev/null +++ b/filter/client_interceptors.go @@ -0,0 +1,43 @@ +package filter + +import ( + "context" + + "google.golang.org/grpc" +) + +// UnaryClientMethods returns an interceptor that applies the provided interceptor only to outgoing unary calls to the specified methods. +// The allowlist parameter specifies whether the provided list of methods is to be treated as an allowlist (true) or a denylist (false). +// If it is an allowlist the interceptor will be applied only to the methods in the list; if it is a denylist the interceptor will be applied only to methods not in the list. +// The methods must be specified using the full name (e.g. "/package.service/method"). +func UnaryClientMethods(interceptor grpc.UnaryClientInterceptor, allowlist bool, methods ...string) grpc.UnaryClientInterceptor { + if interceptor == nil { + panic("nil interceptor") + } + m := newMatchlist(methods, allowlist) + + return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { + if m.match(method) { + return interceptor(ctx, method, req, reply, cc, invoker, opts...) + } + return invoker(ctx, method, req, reply, cc, opts...) + } +} + +// StreamClientMethods returns an interceptor that applies the provided interceptor only to outgoing unary calls to the specified methods. +// The allowlist parameter specifies whether the provided list of methods is to be treated as an allowlist (true) or a denylist (false). +// If it is an allowlist the interceptor will be applied only to the methods in the list; if it is a denylist the interceptor will be applied only to methods not in the list. +// The methods must be specified using the full name (e.g. "/package.service/method"). +func StreamClientMethods(interceptor grpc.StreamClientInterceptor, allowlist bool, methods ...string) grpc.StreamClientInterceptor { + if interceptor == nil { + panic("nil interceptor") + } + m := newMatchlist(methods, allowlist) + + return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { + if m.match(method) { + return interceptor(ctx, desc, cc, method, streamer, opts...) + } + return streamer(ctx, desc, cc, method, opts...) + } +} diff --git a/filter/client_interceptors_test.go b/filter/client_interceptors_test.go new file mode 100644 index 000000000..7e18c2438 --- /dev/null +++ b/filter/client_interceptors_test.go @@ -0,0 +1,87 @@ +package filter_test + +import ( + "context" + "testing" + + "github.com/grpc-ecosystem/go-grpc-middleware/filter" + grpc_testing "github.com/grpc-ecosystem/go-grpc-middleware/testing" + pb_testproto "github.com/grpc-ecosystem/go-grpc-middleware/testing/testproto" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + "google.golang.org/grpc" +) + +type noopUnaryClientInterceptor struct { + called bool +} + +func (i *noopUnaryClientInterceptor) intercept(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { + i.called = true + return invoker(ctx, method, req, reply, cc, opts...) +} + +type noopStreamClientInterceptor struct { + called bool +} + +func (i *noopStreamClientInterceptor) intercept(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { + i.called = true + return streamer(ctx, desc, cc, method, opts...) +} + +func TestClientMethods(t *testing.T) { + service := &someService{ + TestPingService: grpc_testing.TestPingService{T: t}, + } + si := &noopStreamClientInterceptor{} + ui := &noopUnaryClientInterceptor{} + suite.Run(t, &ClientFilterSuite{ + srv: service, + si: si, + ui: ui, + InterceptorTestSuite: &grpc_testing.InterceptorTestSuite{ + TestService: service, + ClientOpts: []grpc.DialOption{ + grpc.WithUnaryInterceptor(filter.UnaryClientMethods(ui.intercept, true, "/mwitkow.testproto.TestService/Ping")), + grpc.WithStreamInterceptor(filter.StreamClientMethods(si.intercept, true, "/mwitkow.testproto.TestService/PingStream")), + }, + }, + }) +} + +type ClientFilterSuite struct { + *grpc_testing.InterceptorTestSuite + srv *someService + si *noopStreamClientInterceptor + ui *noopUnaryClientInterceptor +} + +func (s *ClientFilterSuite) SetupTest() { + s.srv.pingCalled = false + s.srv.pingEmptyCalled = false + s.srv.pingStreamCalled = false + s.si.called = false + s.ui.called = false +} + +func (s *ClientFilterSuite) TestUnary_CallAllowedUnaryMethod() { + res, err := s.Client.Ping(s.SimpleCtx(), &pb_testproto.PingRequest{Value: "hello"}) + require.NoError(s.T(), err) + require.Equal(s.T(), res.Value, "hello") + require.True(s.T(), s.srv.pingCalled) + require.False(s.T(), s.srv.pingEmptyCalled) + require.False(s.T(), s.srv.pingStreamCalled) + require.True(s.T(), s.ui.called) // allowed + require.False(s.T(), s.si.called) +} + +func (s *ClientFilterSuite) TestUnary_CallDisallowedUnaryMethod() { + _, err := s.Client.PingEmpty(s.SimpleCtx(), &pb_testproto.Empty{}) + require.NoError(s.T(), err) + require.False(s.T(), s.srv.pingCalled) + require.True(s.T(), s.srv.pingEmptyCalled) + require.False(s.T(), s.srv.pingStreamCalled) + require.False(s.T(), s.ui.called) // disallowed + require.False(s.T(), s.si.called) +} diff --git a/filter/matchlist.go b/filter/matchlist.go new file mode 100644 index 000000000..83ec2b43c --- /dev/null +++ b/filter/matchlist.go @@ -0,0 +1,19 @@ +package filter + +type matchlist struct { + m map[string]struct{} + p bool +} + +func newMatchlist(s []string, matchPresence bool) *matchlist { + m := make(map[string]struct{}, len(s)) + for _, e := range s { + m[e] = struct{}{} + } + return &matchlist{m, matchPresence} +} + +func (m *matchlist) match(s string) bool { + _, found := m.m[s] + return found == m.p +} diff --git a/filter/matchlist_test.go b/filter/matchlist_test.go new file mode 100644 index 000000000..e745d4f00 --- /dev/null +++ b/filter/matchlist_test.go @@ -0,0 +1,55 @@ +package filter + +import ( + "fmt" + "strconv" + "strings" + "testing" +) + +func TestMatchlist(t *testing.T) { + cases := map[string]struct { + list []string + presence bool + match string + res bool + }{ + "positive match": {[]string{"a", "b"}, true, "a", true}, + "positive match 2": {[]string{"a", "b"}, true, "b", true}, + "positive no match": {[]string{"a", "b"}, true, "c", false}, + "positive no match case insensitive": {[]string{"a", "b"}, true, "A", false}, + "negative match": {[]string{"a", "b"}, false, "a", false}, + "negative match 2": {[]string{"a", "b"}, false, "b", false}, + "negative no match": {[]string{"a", "b"}, false, "c", true}, + "negative no match case insensitive": {[]string{"a", "b"}, false, "A", true}, + + "positive empty list": {[]string{}, true, "a", false}, + "negative empty list": {[]string{}, false, "a", true}, + } + for n, c := range cases { + t.Run(n, func(t *testing.T) { + t.Log(c.list, c.match, c.presence, c.res) + m := newMatchlist(c.list, c.presence) + r := m.match(c.match) + if r != c.res { + t.Error("wrong result") + } + }) + } +} + +func BenchmarkMatchlist(b *testing.B) { + for _, i := range []int{0, 1, 2, 3, 4, 5, 6, 8, 10, 15, 20, 25, 30, 40, 50, 75, 100, 300, 1000} { + var s []string + for j := 0; j < i; j++ { + s = append(s, fmt.Sprintf("%30d", j)) + } + m := newMatchlist(s, true) + c := strings.Repeat(" ", 30) + b.Run(strconv.Itoa(i), func(b *testing.B) { + for j := 0; j < b.N; j++ { + _ = m.match(c) + } + }) + } +} diff --git a/filter/server_interceptors.go b/filter/server_interceptors.go new file mode 100644 index 000000000..ffa0ad231 --- /dev/null +++ b/filter/server_interceptors.go @@ -0,0 +1,65 @@ +package filter + +import ( + "context" + + "google.golang.org/grpc" +) + +// UnaryServerMethods returns an interceptor that applies the provided interceptor only to incoming unary calls to the specified methods. +// The allowlist parameter specifies whether the provided list of methods is to be treated as an allowlist (true) or a denylist (false). +// If it is an allowlist the interceptor will be applied only to the methods in the list; if it is a denylist the interceptor will be applied only to methods not in the list. +// The methods must be specified using the full name (e.g. "/package.service/method"). +func UnaryServerMethods(interceptor grpc.UnaryServerInterceptor, allowlist bool, methods ...string) grpc.UnaryServerInterceptor { + if interceptor == nil { + panic("nil interceptor") + } + m := newMatchlist(methods, allowlist) + + return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { + if m.match(info.FullMethod) { + return interceptor(ctx, req, info, handler) + } + return handler(ctx, req) + } +} + +/* +func UnaryServerMethodsInterceptor(interceptor grpc.UnaryServerInterceptor, allowlist bool, methods ...string) grpc.UnaryServerInterceptor { + m := newMatchlist(methods, allowlist) + + return UnaryServerConditionInterceptor(interceptor, func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo) bool { + return m.match(info.FullMethod) + }) +} + +func UnaryServerConditionInterceptor(interceptor grpc.UnaryServerInterceptor, condition func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo) bool) grpc.UnaryServerInterceptor { + if interceptor == nil { + panic("nil interceptor") + } + return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { + if condition(ctx, req, info) { + return interceptor(ctx, req, info, handler) + } + return handler(ctx, req) + } +} +*/ + +// StreamServerMethods returns an interceptor that applies the provided interceptor only to incoming stream calls to the specified methods. +// The allowlist parameter specifies whether the provided list of methods is to be treated as an allowlist (true) or a denylist (false). +// If it is an allowlist the interceptor will be applied only to the methods in the list; if it is a denylist the interceptor will be applied only to methods not in the list. +// The methods must be specified using the full name (e.g. "/package.service/method"). +func StreamServerMethods(interceptor grpc.StreamServerInterceptor, allowlist bool, methods ...string) grpc.StreamServerInterceptor { + if interceptor == nil { + panic("nil interceptor") + } + m := newMatchlist(methods, allowlist) + + return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { + if m.match(info.FullMethod) { + return interceptor(srv, ss, info, handler) + } + return handler(srv, ss) + } +} diff --git a/filter/server_interceptors_test.go b/filter/server_interceptors_test.go new file mode 100644 index 000000000..3ab512294 --- /dev/null +++ b/filter/server_interceptors_test.go @@ -0,0 +1,115 @@ +package filter_test + +import ( + "context" + "testing" + + "github.com/grpc-ecosystem/go-grpc-middleware/filter" + grpc_testing "github.com/grpc-ecosystem/go-grpc-middleware/testing" + pb_testproto "github.com/grpc-ecosystem/go-grpc-middleware/testing/testproto" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + "google.golang.org/grpc" +) + +type someService struct { + grpc_testing.TestPingService + pingCalled bool + pingEmptyCalled bool + pingStreamCalled bool +} + +func (s *someService) Ping(ctx context.Context, ping *pb_testproto.PingRequest) (*pb_testproto.PingResponse, error) { + s.pingCalled = true + return s.TestPingService.Ping(ctx, ping) +} + +func (s *someService) PingEmpty(ctx context.Context, empty *pb_testproto.Empty) (*pb_testproto.PingResponse, error) { + s.pingEmptyCalled = true + return s.TestPingService.PingEmpty(ctx, empty) +} + +func (s *someService) PingStream(stream pb_testproto.TestService_PingStreamServer) error { + s.pingStreamCalled = true + return s.TestPingService.PingStream(stream) +} + +type noopUnaryServerInterceptor struct { + called bool +} + +func (i *noopUnaryServerInterceptor) intercept(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { + i.called = true + return handler(ctx, req) +} + +type noopStreamServerInterceptor struct { + called bool +} + +func (i *noopStreamServerInterceptor) intercept(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { + i.called = true + return handler(srv, ss) +} + +func TestServerMethods(t *testing.T) { + service := &someService{ + TestPingService: grpc_testing.TestPingService{T: t}, + } + si := &noopStreamServerInterceptor{} + ui := &noopUnaryServerInterceptor{} + suite.Run(t, &FilterSuite{ + srv: service, + si: si, + ui: ui, + InterceptorTestSuite: &grpc_testing.InterceptorTestSuite{ + TestService: service, + /* + ClientOpts: []grpc.DialOption{ + grpc.WithStreamInterceptor(filter.StreamClientMethod()), + grpc.WithUnaryInterceptor(filter.UnaryClientMethod()), + }, + */ + ServerOpts: []grpc.ServerOption{ + grpc.UnaryInterceptor(filter.UnaryServerMethods(ui.intercept, true, "/mwitkow.testproto.TestService/Ping")), + grpc.StreamInterceptor(filter.StreamServerMethods(si.intercept, true, "/mwitkow.testproto.TestService/PingStream")), + }, + }, + }) +} + +type FilterSuite struct { + *grpc_testing.InterceptorTestSuite + srv *someService + si *noopStreamServerInterceptor + ui *noopUnaryServerInterceptor +} + +func (s *FilterSuite) SetupTest() { + s.srv.pingCalled = false + s.srv.pingEmptyCalled = false + s.srv.pingStreamCalled = false + s.si.called = false + s.ui.called = false +} + +func (s *FilterSuite) TestUnary_CallAllowedUnaryMethod() { + res, err := s.Client.Ping(s.SimpleCtx(), &pb_testproto.PingRequest{Value: "hello"}) + require.NoError(s.T(), err) + require.Equal(s.T(), res.Value, "hello") + require.True(s.T(), s.srv.pingCalled) + require.False(s.T(), s.srv.pingEmptyCalled) + require.False(s.T(), s.srv.pingStreamCalled) + require.True(s.T(), s.ui.called) // allowed + require.False(s.T(), s.si.called) +} + +func (s *FilterSuite) TestUnary_CallDisallowedUnaryMethod() { + _, err := s.Client.PingEmpty(s.SimpleCtx(), &pb_testproto.Empty{}) + require.NoError(s.T(), err) + require.False(s.T(), s.srv.pingCalled) + require.True(s.T(), s.srv.pingEmptyCalled) + require.False(s.T(), s.srv.pingStreamCalled) + require.False(s.T(), s.ui.called) // disallowed + require.False(s.T(), s.si.called) +}