Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions pkg/gofr/grpc/log.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
"time"

"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/propagation"

Check failure on line 13 in pkg/gofr/grpc/log.go

View workflow job for this annotation

GitHub Actions / Example Unit Testing (v1.24)🛠

"go.opentelemetry.io/otel/propagation" imported and not used

Check failure on line 13 in pkg/gofr/grpc/log.go

View workflow job for this annotation

GitHub Actions / Example Unit Testing (v1.25)🛠

"go.opentelemetry.io/otel/propagation" imported and not used

Check failure on line 13 in pkg/gofr/grpc/log.go

View workflow job for this annotation

GitHub Actions / PKG Unit Testing (v1.25)🛠

"go.opentelemetry.io/otel/propagation" imported and not used

Check failure on line 13 in pkg/gofr/grpc/log.go

View workflow job for this annotation

GitHub Actions / PKG Unit Testing (v1.25)🛠

"go.opentelemetry.io/otel/propagation" imported and not used

Check failure on line 13 in pkg/gofr/grpc/log.go

View workflow job for this annotation

GitHub Actions / PKG Unit Testing (v1.24)🛠

"go.opentelemetry.io/otel/propagation" imported and not used

Check failure on line 13 in pkg/gofr/grpc/log.go

View workflow job for this annotation

GitHub Actions / PKG Unit Testing (v1.24)🛠

"go.opentelemetry.io/otel/propagation" imported and not used

Check failure on line 13 in pkg/gofr/grpc/log.go

View workflow job for this annotation

GitHub Actions / Example Unit Testing (v1.23)🛠

"go.opentelemetry.io/otel/propagation" imported and not used

Check failure on line 13 in pkg/gofr/grpc/log.go

View workflow job for this annotation

GitHub Actions / PKG Unit Testing (v1.23)🛠

"go.opentelemetry.io/otel/propagation" imported and not used

Check failure on line 13 in pkg/gofr/grpc/log.go

View workflow job for this annotation

GitHub Actions / PKG Unit Testing (v1.23)🛠

"go.opentelemetry.io/otel/propagation" imported and not used

Check failure on line 13 in pkg/gofr/grpc/log.go

View workflow job for this annotation

GitHub Actions / Code Quality🎖️

"go.opentelemetry.io/otel/propagation" imported and not used (typecheck)
"go.opentelemetry.io/otel/trace"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
Expand Down Expand Up @@ -179,6 +180,12 @@
func initializeSpanContext(ctx context.Context) context.Context {
md, _ := metadata.FromIncomingContext(ctx)

ctx = otel.GetTextMapPropagator().Extract(ctx, metadataCarrier(md))

if trace.SpanContextFromContext(ctx).IsValid() {
return ctx
}

traceIDHex := getMetadataValue(md, "x-gofr-traceid")
spanIDHex := getMetadataValue(md, "x-gofr-spanid")

Expand Down Expand Up @@ -304,3 +311,25 @@

return ""
}

type metadataCarrier metadata.MD

func (m metadataCarrier) Get(key string) string {
values := metadata.MD(m).Get(key)
if len(values) > 0 {
return values[0]
}
return ""
}

func (m metadataCarrier) Set(key string, value string) {
metadata.MD(m).Set(key, value)
}

func (m metadataCarrier) Keys() []string {
keys := make([]string, 0, len(m))
for k := range m {
keys = append(keys, k)
}
return keys
}
57 changes: 57 additions & 0 deletions pkg/gofr/grpc/middleware/apikey_auth.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
package middleware

import (
"context"
"crypto/subtle"

"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
)

// APIKeyAuthUnaryInterceptor returns a unary interceptor that validates requests using API Key Authentication.
func APIKeyAuthUnaryInterceptor(apiKeys ...string) grpc.UnaryServerInterceptor {
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
if err := validateAPIKey(ctx, apiKeys); err != nil {
return nil, err
}
return handler(ctx, req)
}
}

// APIKeyAuthStreamInterceptor returns a stream interceptor that validates requests using API Key Authentication.
func APIKeyAuthStreamInterceptor(apiKeys ...string) grpc.StreamServerInterceptor {
return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
if err := validateAPIKey(ss.Context(), apiKeys); err != nil {
return err
}
return handler(srv, ss)
}
}

func validateAPIKey(ctx context.Context, validKeys []string) error {
md, ok := metadata.FromIncomingContext(ctx)
if !ok {
return status.Error(codes.Unauthenticated, "missing metadata")
}

// Check for x-api-key
values, ok := md["x-api-key"]
if !ok || len(values) == 0 {
return status.Error(codes.Unauthenticated, "missing x-api-key header")
}

apiKey := values[0]

for _, key := range validKeys {
if subtle.ConstantTimeCompare([]byte(apiKey), []byte(key)) == 1 {
return nil
}
}

// Constant time compare with dummy key to mitigate timing attacks
subtle.ConstantTimeCompare([]byte(apiKey), []byte("dummy"))

return status.Error(codes.Unauthenticated, "invalid api key")
}
220 changes: 220 additions & 0 deletions pkg/gofr/grpc/middleware/auth_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,220 @@
package middleware

import (
"context"
"crypto/rand"
"crypto/rsa"
"encoding/base64"
"fmt"
"testing"

"github.com/golang-jwt/jwt/v5"
"github.com/stretchr/testify/assert"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"

httpMiddleware "gofr.dev/pkg/gofr/http/middleware"
)

type mockKeyProvider struct {
key *rsa.PublicKey
}

func (m *mockKeyProvider) Get(kid string) *rsa.PublicKey {
if kid == "valid-kid" {
return m.key
}
return nil
}

func TestBasicAuthUnaryInterceptor(t *testing.T) {
users := map[string]string{"user": "pass"}
interceptor := BasicAuthUnaryInterceptor(users)

tests := []struct {
name string
ctx context.Context
expectedErr error
}{
{
name: "No Metadata",
ctx: context.Background(),
expectedErr: status.Error(codes.Unauthenticated, "missing metadata"),
},
{
name: "No Authorization Header",
ctx: metadata.NewIncomingContext(context.Background(), metadata.MD{}),
expectedErr: status.Error(codes.Unauthenticated, "missing authorization header"),
},
{
name: "Invalid Format",
ctx: metadata.NewIncomingContext(context.Background(), metadata.MD{
"authorization": []string{"Bearer token"},
}),
expectedErr: status.Error(codes.Unauthenticated, "invalid authorization header format"),
},
{
name: "Invalid Base64",
ctx: metadata.NewIncomingContext(context.Background(), metadata.MD{
"authorization": []string{"Basic invalid-base64"},
}),
expectedErr: status.Error(codes.Unauthenticated, "invalid base64 credentials"),
},
{
name: "Invalid Credentials Format",
ctx: metadata.NewIncomingContext(context.Background(), metadata.MD{
"authorization": []string{"Basic " + base64.StdEncoding.EncodeToString([]byte("user"))},
}),
expectedErr: status.Error(codes.Unauthenticated, "invalid credentials format"),
},
{
name: "Wrong Password",
ctx: metadata.NewIncomingContext(context.Background(), metadata.MD{
"authorization": []string{"Basic " + base64.StdEncoding.EncodeToString([]byte("user:wrong"))},
}),
expectedErr: status.Error(codes.Unauthenticated, "invalid credentials"),
},
{
name: "Wrong User",
ctx: metadata.NewIncomingContext(context.Background(), metadata.MD{
"authorization": []string{"Basic " + base64.StdEncoding.EncodeToString([]byte("wrong:pass"))},
}),
expectedErr: status.Error(codes.Unauthenticated, "invalid credentials"),
},
{
name: "Success",
ctx: metadata.NewIncomingContext(context.Background(), metadata.MD{
"authorization": []string{"Basic " + base64.StdEncoding.EncodeToString([]byte("user:pass"))},
}),
expectedErr: nil,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := interceptor(tt.ctx, nil, nil, func(ctx context.Context, req interface{}) (interface{}, error) {
return nil, nil
})
assert.Equal(t, tt.expectedErr, err)
})
}
}

func TestAPIKeyAuthUnaryInterceptor(t *testing.T) {
keys := []string{"valid-key"}
interceptor := APIKeyAuthUnaryInterceptor(keys...)

tests := []struct {
name string
ctx context.Context
expectedErr error
}{
{
name: "No Metadata",
ctx: context.Background(),
expectedErr: status.Error(codes.Unauthenticated, "missing metadata"),
},
{
name: "No API Key Header",
ctx: metadata.NewIncomingContext(context.Background(), metadata.MD{}),
expectedErr: status.Error(codes.Unauthenticated, "missing x-api-key header"),
},
{
name: "Invalid Key",
ctx: metadata.NewIncomingContext(context.Background(), metadata.MD{
"x-api-key": []string{"invalid-key"},
}),
expectedErr: status.Error(codes.Unauthenticated, "invalid api key"),
},
{
name: "Success",
ctx: metadata.NewIncomingContext(context.Background(), metadata.MD{
"x-api-key": []string{"valid-key"},
}),
expectedErr: nil,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := interceptor(tt.ctx, nil, nil, func(ctx context.Context, req interface{}) (interface{}, error) {
return nil, nil
})
assert.Equal(t, tt.expectedErr, err)
})
}
}

func TestOAuthUnaryInterceptor(t *testing.T) {
// Generate RSA key
privateKey, _ := rsa.GenerateKey(rand.Reader, 2048)
publicKey := &privateKey.PublicKey

provider := &mockKeyProvider{key: publicKey}
interceptor := OAuthUnaryInterceptor(provider)

// Create valid token
token := jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{
"sub": "user",
})
token.Header["kid"] = "valid-kid"
validToken, _ := token.SignedString(privateKey)

tests := []struct {
name string
ctx context.Context
expectedErr error
}{
{
name: "No Metadata",
ctx: context.Background(),
expectedErr: status.Error(codes.Unauthenticated, "missing metadata"),
},
{
name: "No Authorization Header",
ctx: metadata.NewIncomingContext(context.Background(), metadata.MD{}),
expectedErr: status.Error(codes.Unauthenticated, "missing authorization header"),
},
{
name: "Invalid Format",
ctx: metadata.NewIncomingContext(context.Background(), metadata.MD{
"authorization": []string{"Token " + validToken},
}),
expectedErr: status.Error(codes.Unauthenticated, "invalid authorization header format"),
},
{
name: "Invalid Token",
ctx: metadata.NewIncomingContext(context.Background(), metadata.MD{
"authorization": []string{"Bearer invalid-token"},
}),
expectedErr: status.Error(codes.Unauthenticated, "jwt expected"),
},
{
name: "Success",
ctx: metadata.NewIncomingContext(context.Background(), metadata.MD{
"authorization": []string{"Bearer " + validToken},
}),
expectedErr: nil,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := interceptor(tt.ctx, nil, nil, func(ctx context.Context, req interface{}) (interface{}, error) {
// Check if claims are in context
if tt.expectedErr == nil {
claims := ctx.Value(httpMiddleware.JWTClaim)
assert.NotNil(t, claims)
}
return nil, nil
})
if tt.expectedErr != nil {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
})
}
}
Loading
Loading