diff --git a/pkg/gofr/grpc/log.go b/pkg/gofr/grpc/log.go index c6b38da736..49ffca2f5d 100644 --- a/pkg/gofr/grpc/log.go +++ b/pkg/gofr/grpc/log.go @@ -10,6 +10,7 @@ import ( "time" "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/propagation" "go.opentelemetry.io/otel/trace" "google.golang.org/grpc" "google.golang.org/grpc/codes" @@ -179,6 +180,12 @@ func ObservabilityInterceptor(logger Logger, metrics Metrics) grpc.UnaryServerIn func initializeSpanContext(ctx context.Context) context.Context { md, _ := metadata.FromIncomingContext(ctx) + ctx = otel.GetTextMapPropagator().Extract(ctx, metadataCarrier(md)) + + if trace.SpanContextFromContext(ctx).IsValid() { + return ctx + } + traceIDHex := getMetadataValue(md, "x-gofr-traceid") spanIDHex := getMetadataValue(md, "x-gofr-spanid") @@ -304,3 +311,25 @@ func getMetadataValue(md metadata.MD, key string) string { return "" } + +type metadataCarrier metadata.MD + +func (m metadataCarrier) Get(key string) string { + values := metadata.MD(m).Get(key) + if len(values) > 0 { + return values[0] + } + return "" +} + +func (m metadataCarrier) Set(key string, value string) { + metadata.MD(m).Set(key, value) +} + +func (m metadataCarrier) Keys() []string { + keys := make([]string, 0, len(m)) + for k := range m { + keys = append(keys, k) + } + return keys +} diff --git a/pkg/gofr/grpc/middleware/apikey_auth.go b/pkg/gofr/grpc/middleware/apikey_auth.go new file mode 100644 index 0000000000..94b96f8e0d --- /dev/null +++ b/pkg/gofr/grpc/middleware/apikey_auth.go @@ -0,0 +1,57 @@ +package middleware + +import ( + "context" + "crypto/subtle" + + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/status" +) + +// APIKeyAuthUnaryInterceptor returns a unary interceptor that validates requests using API Key Authentication. +func APIKeyAuthUnaryInterceptor(apiKeys ...string) grpc.UnaryServerInterceptor { + return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { + if err := validateAPIKey(ctx, apiKeys); err != nil { + return nil, err + } + return handler(ctx, req) + } +} + +// APIKeyAuthStreamInterceptor returns a stream interceptor that validates requests using API Key Authentication. +func APIKeyAuthStreamInterceptor(apiKeys ...string) grpc.StreamServerInterceptor { + return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { + if err := validateAPIKey(ss.Context(), apiKeys); err != nil { + return err + } + return handler(srv, ss) + } +} + +func validateAPIKey(ctx context.Context, validKeys []string) error { + md, ok := metadata.FromIncomingContext(ctx) + if !ok { + return status.Error(codes.Unauthenticated, "missing metadata") + } + + // Check for x-api-key + values, ok := md["x-api-key"] + if !ok || len(values) == 0 { + return status.Error(codes.Unauthenticated, "missing x-api-key header") + } + + apiKey := values[0] + + for _, key := range validKeys { + if subtle.ConstantTimeCompare([]byte(apiKey), []byte(key)) == 1 { + return nil + } + } + + // Constant time compare with dummy key to mitigate timing attacks + subtle.ConstantTimeCompare([]byte(apiKey), []byte("dummy")) + + return status.Error(codes.Unauthenticated, "invalid api key") +} diff --git a/pkg/gofr/grpc/middleware/auth_test.go b/pkg/gofr/grpc/middleware/auth_test.go new file mode 100644 index 0000000000..25e39536b6 --- /dev/null +++ b/pkg/gofr/grpc/middleware/auth_test.go @@ -0,0 +1,220 @@ +package middleware + +import ( + "context" + "crypto/rand" + "crypto/rsa" + "encoding/base64" + "fmt" + "testing" + + "github.com/golang-jwt/jwt/v5" + "github.com/stretchr/testify/assert" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/status" + + httpMiddleware "gofr.dev/pkg/gofr/http/middleware" +) + +type mockKeyProvider struct { + key *rsa.PublicKey +} + +func (m *mockKeyProvider) Get(kid string) *rsa.PublicKey { + if kid == "valid-kid" { + return m.key + } + return nil +} + +func TestBasicAuthUnaryInterceptor(t *testing.T) { + users := map[string]string{"user": "pass"} + interceptor := BasicAuthUnaryInterceptor(users) + + tests := []struct { + name string + ctx context.Context + expectedErr error + }{ + { + name: "No Metadata", + ctx: context.Background(), + expectedErr: status.Error(codes.Unauthenticated, "missing metadata"), + }, + { + name: "No Authorization Header", + ctx: metadata.NewIncomingContext(context.Background(), metadata.MD{}), + expectedErr: status.Error(codes.Unauthenticated, "missing authorization header"), + }, + { + name: "Invalid Format", + ctx: metadata.NewIncomingContext(context.Background(), metadata.MD{ + "authorization": []string{"Bearer token"}, + }), + expectedErr: status.Error(codes.Unauthenticated, "invalid authorization header format"), + }, + { + name: "Invalid Base64", + ctx: metadata.NewIncomingContext(context.Background(), metadata.MD{ + "authorization": []string{"Basic invalid-base64"}, + }), + expectedErr: status.Error(codes.Unauthenticated, "invalid base64 credentials"), + }, + { + name: "Invalid Credentials Format", + ctx: metadata.NewIncomingContext(context.Background(), metadata.MD{ + "authorization": []string{"Basic " + base64.StdEncoding.EncodeToString([]byte("user"))}, + }), + expectedErr: status.Error(codes.Unauthenticated, "invalid credentials format"), + }, + { + name: "Wrong Password", + ctx: metadata.NewIncomingContext(context.Background(), metadata.MD{ + "authorization": []string{"Basic " + base64.StdEncoding.EncodeToString([]byte("user:wrong"))}, + }), + expectedErr: status.Error(codes.Unauthenticated, "invalid credentials"), + }, + { + name: "Wrong User", + ctx: metadata.NewIncomingContext(context.Background(), metadata.MD{ + "authorization": []string{"Basic " + base64.StdEncoding.EncodeToString([]byte("wrong:pass"))}, + }), + expectedErr: status.Error(codes.Unauthenticated, "invalid credentials"), + }, + { + name: "Success", + ctx: metadata.NewIncomingContext(context.Background(), metadata.MD{ + "authorization": []string{"Basic " + base64.StdEncoding.EncodeToString([]byte("user:pass"))}, + }), + expectedErr: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := interceptor(tt.ctx, nil, nil, func(ctx context.Context, req interface{}) (interface{}, error) { + return nil, nil + }) + assert.Equal(t, tt.expectedErr, err) + }) + } +} + +func TestAPIKeyAuthUnaryInterceptor(t *testing.T) { + keys := []string{"valid-key"} + interceptor := APIKeyAuthUnaryInterceptor(keys...) + + tests := []struct { + name string + ctx context.Context + expectedErr error + }{ + { + name: "No Metadata", + ctx: context.Background(), + expectedErr: status.Error(codes.Unauthenticated, "missing metadata"), + }, + { + name: "No API Key Header", + ctx: metadata.NewIncomingContext(context.Background(), metadata.MD{}), + expectedErr: status.Error(codes.Unauthenticated, "missing x-api-key header"), + }, + { + name: "Invalid Key", + ctx: metadata.NewIncomingContext(context.Background(), metadata.MD{ + "x-api-key": []string{"invalid-key"}, + }), + expectedErr: status.Error(codes.Unauthenticated, "invalid api key"), + }, + { + name: "Success", + ctx: metadata.NewIncomingContext(context.Background(), metadata.MD{ + "x-api-key": []string{"valid-key"}, + }), + expectedErr: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := interceptor(tt.ctx, nil, nil, func(ctx context.Context, req interface{}) (interface{}, error) { + return nil, nil + }) + assert.Equal(t, tt.expectedErr, err) + }) + } +} + +func TestOAuthUnaryInterceptor(t *testing.T) { + // Generate RSA key + privateKey, _ := rsa.GenerateKey(rand.Reader, 2048) + publicKey := &privateKey.PublicKey + + provider := &mockKeyProvider{key: publicKey} + interceptor := OAuthUnaryInterceptor(provider) + + // Create valid token + token := jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{ + "sub": "user", + }) + token.Header["kid"] = "valid-kid" + validToken, _ := token.SignedString(privateKey) + + tests := []struct { + name string + ctx context.Context + expectedErr error + }{ + { + name: "No Metadata", + ctx: context.Background(), + expectedErr: status.Error(codes.Unauthenticated, "missing metadata"), + }, + { + name: "No Authorization Header", + ctx: metadata.NewIncomingContext(context.Background(), metadata.MD{}), + expectedErr: status.Error(codes.Unauthenticated, "missing authorization header"), + }, + { + name: "Invalid Format", + ctx: metadata.NewIncomingContext(context.Background(), metadata.MD{ + "authorization": []string{"Token " + validToken}, + }), + expectedErr: status.Error(codes.Unauthenticated, "invalid authorization header format"), + }, + { + name: "Invalid Token", + ctx: metadata.NewIncomingContext(context.Background(), metadata.MD{ + "authorization": []string{"Bearer invalid-token"}, + }), + expectedErr: status.Error(codes.Unauthenticated, "jwt expected"), + }, + { + name: "Success", + ctx: metadata.NewIncomingContext(context.Background(), metadata.MD{ + "authorization": []string{"Bearer " + validToken}, + }), + expectedErr: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := interceptor(tt.ctx, nil, nil, func(ctx context.Context, req interface{}) (interface{}, error) { + // Check if claims are in context + if tt.expectedErr == nil { + claims := ctx.Value(httpMiddleware.JWTClaim) + assert.NotNil(t, claims) + } + return nil, nil + }) + if tt.expectedErr != nil { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} diff --git a/pkg/gofr/grpc/middleware/basic_auth.go b/pkg/gofr/grpc/middleware/basic_auth.go new file mode 100644 index 0000000000..04c8c6aec6 --- /dev/null +++ b/pkg/gofr/grpc/middleware/basic_auth.go @@ -0,0 +1,74 @@ +package middleware + +import ( + "context" + "crypto/subtle" + "encoding/base64" + "strings" + + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/status" +) + +// BasicAuthUnaryInterceptor returns a unary interceptor that validates requests using Basic Authentication. +func BasicAuthUnaryInterceptor(users map[string]string) grpc.UnaryServerInterceptor { + return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { + if err := validateBasicAuth(ctx, users); err != nil { + return nil, err + } + return handler(ctx, req) + } +} + +// BasicAuthStreamInterceptor returns a stream interceptor that validates requests using Basic Authentication. +func BasicAuthStreamInterceptor(users map[string]string) grpc.StreamServerInterceptor { + return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { + if err := validateBasicAuth(ss.Context(), users); err != nil { + return err + } + return handler(srv, ss) + } +} + +func validateBasicAuth(ctx context.Context, users map[string]string) error { + md, ok := metadata.FromIncomingContext(ctx) + if !ok { + return status.Error(codes.Unauthenticated, "missing metadata") + } + + authHeader, ok := md["authorization"] + if !ok || len(authHeader) == 0 { + return status.Error(codes.Unauthenticated, "missing authorization header") + } + + // Basic + parts := strings.SplitN(authHeader[0], " ", 2) + if len(parts) != 2 || !strings.EqualFold(parts[0], "Basic") { + return status.Error(codes.Unauthenticated, "invalid authorization header format") + } + + payload, err := base64.StdEncoding.DecodeString(parts[1]) + if err != nil { + return status.Error(codes.Unauthenticated, "invalid base64 credentials") + } + + username, password, found := strings.Cut(string(payload), ":") + if !found { + return status.Error(codes.Unauthenticated, "invalid credentials format") + } + + expectedPass, ok := users[username] + if !ok { + // Use dummy comparison to prevent timing attacks + subtle.ConstantTimeCompare([]byte(password), []byte("dummy")) + return status.Error(codes.Unauthenticated, "invalid credentials") + } + + if subtle.ConstantTimeCompare([]byte(password), []byte(expectedPass)) != 1 { + return status.Error(codes.Unauthenticated, "invalid credentials") + } + + return nil +} diff --git a/pkg/gofr/grpc/middleware/oauth.go b/pkg/gofr/grpc/middleware/oauth.go new file mode 100644 index 0000000000..bccb9187b7 --- /dev/null +++ b/pkg/gofr/grpc/middleware/oauth.go @@ -0,0 +1,105 @@ +package middleware + +import ( + "context" + "errors" + "fmt" + "regexp" + "strings" + + "github.com/golang-jwt/jwt/v5" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/status" + + httpMiddleware "gofr.dev/pkg/gofr/http/middleware" +) + +const ( + jwtRegexPattern = "^[A-Za-z0-9-_]+\\.[A-Za-z0-9-_]+\\.[A-Za-z0-9-_]+$" +) + +// OAuthUnaryInterceptor returns a unary interceptor that validates requests using OAuth. +func OAuthUnaryInterceptor(key httpMiddleware.PublicKeyProvider, options ...jwt.ParserOption) grpc.UnaryServerInterceptor { + regex := regexp.MustCompile(jwtRegexPattern) + options = append(options, jwt.WithIssuedAt()) + + return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { + claims, err := validateOAuth(ctx, key, regex, options...) + if err != nil { + return nil, err + } + + newCtx := context.WithValue(ctx, httpMiddleware.JWTClaim, claims) + return handler(newCtx, req) + } +} + +// OAuthStreamInterceptor returns a stream interceptor that validates requests using OAuth. +func OAuthStreamInterceptor(key httpMiddleware.PublicKeyProvider, options ...jwt.ParserOption) grpc.StreamServerInterceptor { + regex := regexp.MustCompile(jwtRegexPattern) + options = append(options, jwt.WithIssuedAt()) + + return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { + claims, err := validateOAuth(ss.Context(), key, regex, options...) + if err != nil { + return err + } + + // We need to wrap the stream to inject the new context containing the claims. + wrapped := &wrappedStream{ss, context.WithValue(ss.Context(), httpMiddleware.JWTClaim, claims)} + return handler(srv, wrapped) + } +} + +type wrappedStream struct { + grpc.ServerStream + ctx context.Context +} + +func (w *wrappedStream) Context() context.Context { + return w.ctx +} + +func validateOAuth(ctx context.Context, key httpMiddleware.PublicKeyProvider, regex *regexp.Regexp, options ...jwt.ParserOption) (jwt.Claims, error) { + md, ok := metadata.FromIncomingContext(ctx) + if !ok { + return nil, status.Error(codes.Unauthenticated, "missing metadata") + } + + authHeader, ok := md["authorization"] + if !ok || len(authHeader) == 0 { + return nil, status.Error(codes.Unauthenticated, "missing authorization header") + } + + // Bearer + parts := strings.SplitN(authHeader[0], " ", 2) + if len(parts) != 2 || !strings.EqualFold(parts[0], "Bearer") { + return nil, status.Error(codes.Unauthenticated, "invalid authorization header format") + } + + tokenString := parts[1] + if !regex.MatchString(tokenString) { + return nil, status.Error(codes.Unauthenticated, "jwt expected") + } + + token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) { + kid := token.Header["kid"] + jwks := key.Get(fmt.Sprint(kid)) + if jwks == nil { + return nil, errors.New("key not found") + } + return jwks, nil + }, options...) + + if err != nil { + return nil, status.Errorf(codes.Unauthenticated, "invalid token: %v", err) + } + + if !token.Valid { + return nil, status.Error(codes.Unauthenticated, "invalid token") + } + + return token.Claims, nil +}