diff --git a/docs/advanced-guide/middlewares/page.md b/docs/advanced-guide/middlewares/page.md index 7bce611ef3..79f782c369 100644 --- a/docs/advanced-guide/middlewares/page.md +++ b/docs/advanced-guide/middlewares/page.md @@ -67,3 +67,52 @@ func main() { } ``` +## Rate Limiter Middleware in GoFr + +GoFr provides a built-in rate limiter middleware to protect your API from abuse and ensure fair resource distribution. +It uses a token bucket algorithm for smooth rate limiting with configurable burst capacity. + +### Features + +- **Token Bucket Algorithm**: Allows smooth rate limiting with configurable burst capacity +- **Per-IP Rate Limiting**: Each client IP gets its own rate limit (configurable) +- **Health Check Exemption**: `/.well-known/alive` and `/.well-known/health` endpoints are automatically exempt +- **Prometheus Metrics**: Track rate limit violations via `app_http_rate_limit_exceeded_total` counter +- **429 Status Code**: Returns standard HTTP 429 (Too Many Requests) when limit is exceeded + +### Configuration + +```go +import ( + "gofr.dev/pkg/gofr" + "gofr.dev/pkg/gofr/http/middleware" +) + +func main() { + app := gofr.New() + + // Configure rate limiter + rateLimiterConfig := middleware.RateLimiterConfig{ + RequestsPerSecond: 5, // Average requests per second + Burst: 10, // Maximum burst size + PerIP: true, // Enable per-IP limiting + } + + // Add rate limiter middleware + app.UseMiddleware(middleware.RateLimiter(rateLimiterConfig, app.Metrics())) + + app.GET("/api/resource", handler) + app.Run() +} +``` + +### Parameters + +- `RequestsPerSecond`: Average number of requests allowed per second +- `Burst`: Maximum number of requests that can be made in a burst (allows temporary spikes) +- `PerIP`: Set to `true` for per-IP limiting (recommended) or `false` for global rate limit across all clients +- `TrustedProxies`: *(Optional)* Set to `true` to trust `X-Forwarded-For` and `X-Real-IP` headers for IP extraction. Only enable when behind a trusted reverse proxy. + +> **Security Warning**: Only set `TrustedProxies: true` if your application is behind a trusted reverse proxy (nginx, ALB, etc.). +> Without a trusted proxy, clients can spoof headers to bypass rate limits. + diff --git a/pkg/gofr/http/errors.go b/pkg/gofr/http/errors.go index a117046600..24b177cf9a 100644 --- a/pkg/gofr/http/errors.go +++ b/pkg/gofr/http/errors.go @@ -163,6 +163,21 @@ func (ErrorPanicRecovery) LogLevel() logging.Level { return logging.ERROR } +// ErrorTooManyRequests represents an error when rate limit is exceeded. +type ErrorTooManyRequests struct{} + +func (ErrorTooManyRequests) Error() string { + return "rate limit exceeded" +} + +func (ErrorTooManyRequests) StatusCode() int { + return http.StatusTooManyRequests +} + +func (ErrorTooManyRequests) LogLevel() logging.Level { + return logging.WARN +} + // validate the errors satisfy the underlying interfaces they depend on. var ( _ StatusCodeResponder = ErrorEntityNotFound{} @@ -174,6 +189,7 @@ var ( _ StatusCodeResponder = ErrorPanicRecovery{} _ StatusCodeResponder = ErrorServiceUnavailable{} _ StatusCodeResponder = ErrorClientClosedRequest{} + _ StatusCodeResponder = ErrorTooManyRequests{} _ logging.LogLevelResponder = ErrorClientClosedRequest{} _ logging.LogLevelResponder = ErrorEntityNotFound{} @@ -184,4 +200,5 @@ var ( _ logging.LogLevelResponder = ErrorRequestTimeout{} _ logging.LogLevelResponder = ErrorPanicRecovery{} _ logging.LogLevelResponder = ErrorServiceUnavailable{} + _ logging.LogLevelResponder = ErrorTooManyRequests{} ) diff --git a/pkg/gofr/http/middleware/rate_limiter.go b/pkg/gofr/http/middleware/rate_limiter.go new file mode 100644 index 0000000000..69a446e9d5 --- /dev/null +++ b/pkg/gofr/http/middleware/rate_limiter.go @@ -0,0 +1,172 @@ +package middleware + +import ( + "context" + "errors" + "fmt" + "math" + "net" + "net/http" + "strings" + + gofrHttp "gofr.dev/pkg/gofr/http" +) + +var ( + // errInvalidRequestsPerSecond is returned when RequestsPerSecond is not positive. + errInvalidRequestsPerSecond = errors.New("requestsPerSecond must be positive") + + // errInvalidBurst is returned when Burst is not positive. + errInvalidBurst = errors.New("burst must be positive") +) + +// RateLimiterConfig holds configuration for rate limiting. +// +// Note: The default implementation uses in-memory token buckets and is suitable +// for single-pod deployments. In multi-pod deployments, each pod will enforce +// limits independently. For distributed rate limiting across multiple pods, +// a Redis-backed store can be implemented in a future update. +// +// Security: When using PerIP=true, only enable TrustedProxies if your application +// is behind a trusted reverse proxy (nginx, ALB, etc.) that sets X-Forwarded-For. +// Without trusted proxies, clients can spoof IP addresses to bypass rate limits. +// +// Cleanup: The rate limiter starts a background goroutine that runs for the +// application lifetime. This is acceptable for long-running servers but consider +// calling Store.StopCleanup() in shutdown handlers if needed. +type RateLimiterConfig struct { + RequestsPerSecond float64 + Burst int + PerIP bool + Store RateLimiterStore // Optional: defaults to in-memory store + TrustedProxies bool // If true, trust X-Forwarded-For and X-Real-IP headers +} + +// Validate checks if the configuration values are valid. +func (c RateLimiterConfig) Validate() error { + if c.RequestsPerSecond <= 0 { + return errInvalidRequestsPerSecond + } + + if c.Burst <= 0 { + return errInvalidBurst + } + + return nil +} + +// getIP extracts the client IP address from the request. +// If trustProxies is false, only RemoteAddr is used to prevent IP spoofing. +func getIP(r *http.Request, trustProxies bool) string { + if !trustProxies { + return getRemoteAddr(r) + } + + // Try X-Forwarded-For header first + if ip := getForwardedIP(r); ip != "" { + return ip + } + + // Try X-Real-IP header + if ip := getRealIP(r); ip != "" { + return ip + } + + // Fall back to RemoteAddr + return getRemoteAddr(r) +} + +// getForwardedIP extracts IP from X-Forwarded-For header. +func getForwardedIP(r *http.Request) string { + forwarded := r.Header.Get("X-Forwarded-For") + if forwarded == "" { + return "" + } + + // X-Forwarded-For can contain multiple IPs, take the first one + ips := strings.Split(forwarded, ",") + if len(ips) == 0 { + return "" + } + + return strings.TrimSpace(ips[0]) +} + +// getRealIP extracts IP from X-Real-IP header. +func getRealIP(r *http.Request) string { + realIP := r.Header.Get("X-Real-IP") + return strings.TrimSpace(realIP) +} + +// getRemoteAddr extracts IP from RemoteAddr. +func getRemoteAddr(r *http.Request) string { + ip, _, err := net.SplitHostPort(r.RemoteAddr) + if err != nil { + return r.RemoteAddr + } + + return ip +} + +// RateLimiter creates a middleware that limits requests based on the configuration. +func RateLimiter(config RateLimiterConfig, m metrics) func(http.Handler) http.Handler { + // Validate configuration + if err := config.Validate(); err != nil { + panic(fmt.Sprintf("invalid rate limiter config: %v", err)) + } + + // Use in-memory store if none provided + if config.Store == nil { + config.Store = NewMemoryRateLimiterStore(config) + } + + // Start cleanup routine with context.Background(). + // The cleanup goroutine runs for the application lifetime. + // For graceful shutdown, call config.Store.StopCleanup() in your shutdown handler. + ctx := context.Background() + config.Store.StartCleanup(ctx) + + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Skip rate limiting for health check endpoints + if isWellKnown(r.URL.Path) { + next.ServeHTTP(w, r) + return + } + + // Determine the rate limit key (IP or global) + key := "global" + if config.PerIP { + key = getIP(r, config.TrustedProxies) + } + + // Check rate limit + allowed, retryAfter, err := config.Store.Allow(r.Context(), key, config) + if err != nil { + // Fail open on errors + next.ServeHTTP(w, r) + return + } + + if !allowed { + // Set Retry-After header (RFC 6585) + // Use math.Ceil to ensure at least 1 second for sub-second delays + w.Header().Set("Retry-After", fmt.Sprintf("%.0f", math.Ceil(retryAfter.Seconds()))) + + // Increment rate limit exceeded metric + if m != nil { + m.IncrementCounter(r.Context(), "app_http_rate_limit_exceeded_total", + "path", r.URL.Path, "method", r.Method) + } + + // Return 429 Too Many Requests + responder := gofrHttp.NewResponder(w, r.Method) + responder.Respond(nil, gofrHttp.ErrorTooManyRequests{}) + + return + } + + next.ServeHTTP(w, r) + }) + } +} diff --git a/pkg/gofr/http/middleware/rate_limiter_store.go b/pkg/gofr/http/middleware/rate_limiter_store.go new file mode 100644 index 0000000000..b8cf33ed85 --- /dev/null +++ b/pkg/gofr/http/middleware/rate_limiter_store.go @@ -0,0 +1,129 @@ +package middleware + +import ( + "context" + "sync" + "sync/atomic" + "time" + + "golang.org/x/time/rate" +) + +// RateLimiterStore abstracts the storage and cleanup for rate limiter buckets. +// This interface matches the one defined in pkg/gofr/service for consistency. +// +// Note: The config parameter in Allow() is provided for interface compatibility. +// Implementations may use a stored configuration and ignore this parameter. +type RateLimiterStore interface { + Allow(ctx context.Context, key string, config RateLimiterConfig) (allowed bool, retryAfter time.Duration, err error) + StartCleanup(ctx context.Context) + StopCleanup() +} + +// memoryRateLimiterStore implements RateLimiterStore using in-memory token buckets. +type memoryRateLimiterStore struct { + limiters sync.Map // map[string]*limiterEntry + stopCh chan struct{} + cleanupOnce sync.Once + stopOnce sync.Once + config RateLimiterConfig // Store config for consistency +} + +type limiterEntry struct { + limiter *rate.Limiter + lastAccess int64 // Unix timestamp for cleanup +} + +// NewMemoryRateLimiterStore creates a new in-memory rate limiter store. +// The config is stored to ensure consistent rate limiting for all keys. +func NewMemoryRateLimiterStore(config RateLimiterConfig) RateLimiterStore { + return &memoryRateLimiterStore{config: config} +} + +// Allow checks if a request should be allowed based on the rate limit. +func (m *memoryRateLimiterStore) Allow(_ context.Context, key string, _ RateLimiterConfig) (bool, time.Duration, error) { + now := time.Now().Unix() + + // Use stored config for consistency across all keys + cfg := m.config + + // Get or create limiter for this key + val, _ := m.limiters.LoadOrStore(key, &limiterEntry{ + limiter: rate.NewLimiter(rate.Limit(cfg.RequestsPerSecond), cfg.Burst), + lastAccess: now, + }) + + entry := val.(*limiterEntry) + atomic.StoreInt64(&entry.lastAccess, now) + + // Check if request is allowed + if !entry.limiter.Allow() { + // Calculate retry-after duration + reservation := entry.limiter.Reserve() + if !reservation.OK() { + // Burst exceeded - calculate delay based on request rate + // Time to wait for one token = 1 / RequestsPerSecond + delay := time.Duration(float64(time.Second) / cfg.RequestsPerSecond) + return false, delay, nil + } + + delay := reservation.Delay() + reservation.Cancel() // Don't actually consume the token + + return false, delay, nil + } + + return true, 0, nil +} + +// StartCleanup starts a background goroutine to clean up stale limiters. +// This method is safe to call multiple times - only one cleanup goroutine will be started. +func (m *memoryRateLimiterStore) StartCleanup(ctx context.Context) { + m.cleanupOnce.Do(func() { + m.stopCh = make(chan struct{}) + + go func() { + const cleanupInterval = 5 * time.Minute + + const staleThreshold = 10 * time.Minute + + ticker := time.NewTicker(cleanupInterval) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + m.cleanup(staleThreshold) + case <-m.stopCh: + return + case <-ctx.Done(): + return + } + } + }() + }) +} + +// StopCleanup stops the cleanup goroutine. +// This method is safe to call multiple times. +func (m *memoryRateLimiterStore) StopCleanup() { + m.stopOnce.Do(func() { + if m.stopCh != nil { + close(m.stopCh) + } + }) +} + +// cleanup removes stale limiters that haven't been accessed recently. +func (m *memoryRateLimiterStore) cleanup(staleThreshold time.Duration) { + cutoff := time.Now().Unix() - int64(staleThreshold.Seconds()) + + m.limiters.Range(func(key, value any) bool { + entry := value.(*limiterEntry) + if atomic.LoadInt64(&entry.lastAccess) < cutoff { + m.limiters.Delete(key) + } + + return true + }) +} diff --git a/pkg/gofr/http/middleware/rate_limiter_test.go b/pkg/gofr/http/middleware/rate_limiter_test.go new file mode 100644 index 0000000000..31c1f135fd --- /dev/null +++ b/pkg/gofr/http/middleware/rate_limiter_test.go @@ -0,0 +1,521 @@ +package middleware + +import ( + "context" + "fmt" + "net/http" + "net/http/httptest" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type rateLimiterMockMetrics struct { + mu sync.Mutex + counters map[string]int +} + +func newRateLimiterMockMetrics() *rateLimiterMockMetrics { + return &rateLimiterMockMetrics{ + counters: make(map[string]int), + } +} + +func (m *rateLimiterMockMetrics) IncrementCounter(_ context.Context, name string, _ ...string) { + m.mu.Lock() + defer m.mu.Unlock() + + m.counters[name]++ +} + +func (*rateLimiterMockMetrics) DeltaUpDownCounter(_ context.Context, _ string, _ float64, _ ...string) { + // Not used in rate limiter tests +} + +func (*rateLimiterMockMetrics) RecordHistogram(_ context.Context, _ string, _ float64, _ ...string) { + // Not used in rate limiter tests +} + +func (*rateLimiterMockMetrics) SetGauge(_ string, _ float64, _ ...string) { + // Not used in rate limiter tests +} + +func (m *rateLimiterMockMetrics) GetCounter(name string) int { + m.mu.Lock() + defer m.mu.Unlock() + + return m.counters[name] +} + +func TestRateLimiter_GlobalLimit(t *testing.T) { + metrics := newRateLimiterMockMetrics() + config := RateLimiterConfig{ + RequestsPerSecond: 2, + Burst: 2, + PerIP: false, + } + + handler := RateLimiter(config, metrics)(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("OK")) + })) + + // First 2 requests should succeed (burst) + for i := 0; i < 2; i++ { + req := httptest.NewRequest(http.MethodGet, "/test", http.NoBody) + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + assert.Equal(t, http.StatusOK, rr.Code, "Request %d should succeed", i+1) + } + + // 3rd request should be rate limited + req := httptest.NewRequest(http.MethodGet, "/test", http.NoBody) + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + assert.Equal(t, http.StatusTooManyRequests, rr.Code, "Request should be rate limited") + + // Verify metric was incremented + assert.Equal(t, 1, metrics.GetCounter("app_http_rate_limit_exceeded_total")) +} + +func TestRateLimiter_PerIPLimit(t *testing.T) { + metrics := newRateLimiterMockMetrics() + config := RateLimiterConfig{ + RequestsPerSecond: 2, + Burst: 2, + PerIP: true, + } + + handler := RateLimiter(config, metrics)(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + // IP1: First 2 requests should succeed + for i := 0; i < 2; i++ { + req := httptest.NewRequest(http.MethodGet, "/test", http.NoBody) + req.RemoteAddr = "192.168.1.1:12345" + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + assert.Equal(t, http.StatusOK, rr.Code) + } + + // IP1: 3rd request should be rate limited + req := httptest.NewRequest(http.MethodGet, "/test", http.NoBody) + req.RemoteAddr = "192.168.1.1:12345" + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + assert.Equal(t, http.StatusTooManyRequests, rr.Code) + + // IP2: Should still be able to make requests (different limiter) + req = httptest.NewRequest(http.MethodGet, "/test", http.NoBody) + req.RemoteAddr = "192.168.1.2:54321" + rr = httptest.NewRecorder() + handler.ServeHTTP(rr, req) + assert.Equal(t, http.StatusOK, rr.Code) +} + +func TestRateLimiter_SkipHealthEndpoints(t *testing.T) { + metrics := newRateLimiterMockMetrics() + config := RateLimiterConfig{ + RequestsPerSecond: 1, + Burst: 1, + PerIP: false, + } + + handler := RateLimiter(config, metrics)(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + // Health endpoints should not be rate limited + healthPaths := []string{"/.well-known/health", "/.well-known/alive"} + + for _, path := range healthPaths { + for i := 0; i < 5; i++ { + req := httptest.NewRequest(http.MethodGet, path, http.NoBody) + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + assert.Equal(t, http.StatusOK, rr.Code, "Health endpoint %s should not be rate limited", path) + } + } +} + +func TestRateLimiter_ConcurrentRequests(t *testing.T) { + metrics := newRateLimiterMockMetrics() + config := RateLimiterConfig{ + RequestsPerSecond: 10, + Burst: 10, + PerIP: true, + } + + handler := RateLimiter(config, metrics)(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + var wg sync.WaitGroup + + successCount := 0 + rateLimitedCount := 0 + + var mu sync.Mutex + + // Send 20 concurrent requests from same IP + for i := 0; i < 20; i++ { + wg.Add(1) + + go func() { + defer wg.Done() + + req := httptest.NewRequest(http.MethodGet, "/test", http.NoBody) + req.RemoteAddr = "192.168.1.1:12345" + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + mu.Lock() + + if rr.Code == http.StatusOK { + successCount++ + } else if rr.Code == http.StatusTooManyRequests { + rateLimitedCount++ + } + + mu.Unlock() + }() + } + + wg.Wait() + + // Due to timing/race conditions in concurrent tests, we allow a small tolerance + // The important thing is that rate limiting occurred + assert.GreaterOrEqual(t, successCount, 9, "Should allow approximately burst size requests") + assert.LessOrEqual(t, successCount, 11, "Should not allow significantly more than burst size") + assert.Positive(t, rateLimitedCount, "Should have some rate limited requests") + assert.Equal(t, 20, successCount+rateLimitedCount, "Total requests should be 20") +} + +func TestRateLimiter_TokenRefill(t *testing.T) { + if testing.Short() { + t.Skip("Skipping time-based test in short mode") + } + + metrics := newRateLimiterMockMetrics() + config := RateLimiterConfig{ + RequestsPerSecond: 5, // 5 requests per second + Burst: 2, + PerIP: false, + } + + handler := RateLimiter(config, metrics)(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + // Use up burst + for i := 0; i < 2; i++ { + req := httptest.NewRequest(http.MethodGet, "/test", http.NoBody) + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + require.Equal(t, http.StatusOK, rr.Code) + } + + // Next request should be rate limited + req := httptest.NewRequest(http.MethodGet, "/test", http.NoBody) + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + assert.Equal(t, http.StatusTooManyRequests, rr.Code) + + // Wait for token refill (200ms = 1 token at 5 req/sec) + time.Sleep(220 * time.Millisecond) + + // Should succeed now + req = httptest.NewRequest(http.MethodGet, "/test", http.NoBody) + rr = httptest.NewRecorder() + handler.ServeHTTP(rr, req) + assert.Equal(t, http.StatusOK, rr.Code) +} + +func TestGetIP_XForwardedFor(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/test", http.NoBody) + req.Header.Set("X-Forwarded-For", "203.0.113.1, 198.51.100.1") + req.RemoteAddr = "192.168.1.1:12345" + + ip := getIP(req, true) + assert.Equal(t, "203.0.113.1", ip, "Should extract first IP from X-Forwarded-For when trusting proxies") + + // Without trusting proxies, should use RemoteAddr + ip = getIP(req, false) + assert.Equal(t, "192.168.1.1", ip, "Should use RemoteAddr when not trusting proxies") +} + +func TestGetIP_XRealIP(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/test", http.NoBody) + req.Header.Set("X-Real-IP", "203.0.113.5") + req.RemoteAddr = "192.168.1.1:12345" + + ip := getIP(req, true) + assert.Equal(t, "203.0.113.5", ip, "Should extract IP from X-Real-IP when trusting proxies") + + // Without trusting proxies, should use RemoteAddr + ip = getIP(req, false) + assert.Equal(t, "192.168.1.1", ip, "Should use RemoteAddr when not trusting proxies") +} + +func TestGetIP_RemoteAddr(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/test", http.NoBody) + req.RemoteAddr = "192.168.1.1:12345" + + ip := getIP(req, false) + assert.Equal(t, "192.168.1.1", ip, "Should extract IP from RemoteAddr") +} + +func TestGetIP_Priority(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/test", http.NoBody) + req.Header.Set("X-Forwarded-For", "203.0.113.1") + req.Header.Set("X-Real-IP", "203.0.113.2") + req.RemoteAddr = "192.168.1.1:12345" + + ip := getIP(req, true) + assert.Equal(t, "203.0.113.1", ip, "X-Forwarded-For should have highest priority when trusting proxies") +} + +func TestRateLimiter_RetryAfterHeader(t *testing.T) { + metrics := newRateLimiterMockMetrics() + config := RateLimiterConfig{ + RequestsPerSecond: 2, + Burst: 1, + PerIP: false, + } + + handler := RateLimiter(config, metrics)(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + // First request should succeed + req := httptest.NewRequest(http.MethodGet, "/test", http.NoBody) + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + assert.Equal(t, http.StatusOK, rr.Code) + + // Second request should be rate limited and include Retry-After header + req = httptest.NewRequest(http.MethodGet, "/test", http.NoBody) + rr = httptest.NewRecorder() + handler.ServeHTTP(rr, req) + assert.Equal(t, http.StatusTooManyRequests, rr.Code) + assert.NotEmpty(t, rr.Header().Get("Retry-After"), "Retry-After header should be set") +} + +func TestRateLimiterConfig_Validate(t *testing.T) { + tests := []struct { + name string + config RateLimiterConfig + wantErr bool + }{ + { + name: "valid config", + config: RateLimiterConfig{ + RequestsPerSecond: 10, + Burst: 20, + PerIP: true, + }, + wantErr: false, + }, + { + name: "zero RequestsPerSecond", + config: RateLimiterConfig{ + RequestsPerSecond: 0, + Burst: 20, + PerIP: true, + }, + wantErr: true, + }, + { + name: "negative RequestsPerSecond", + config: RateLimiterConfig{ + RequestsPerSecond: -5, + Burst: 20, + PerIP: true, + }, + wantErr: true, + }, + { + name: "zero Burst", + config: RateLimiterConfig{ + RequestsPerSecond: 10, + Burst: 0, + PerIP: true, + }, + wantErr: true, + }, + { + name: "negative Burst", + config: RateLimiterConfig{ + RequestsPerSecond: 10, + Burst: -5, + PerIP: true, + }, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.config.Validate() + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestMemoryRateLimiterStore_StopCleanupMultipleCalls(t *testing.T) { + t.Helper() + + config := RateLimiterConfig{ + RequestsPerSecond: 10, + Burst: 20, + PerIP: true, + } + + store := NewMemoryRateLimiterStore(config).(*memoryRateLimiterStore) + ctx := context.Background() + + // Start cleanup + store.StartCleanup(ctx) + + // Stop multiple times - should not panic + store.StopCleanup() + store.StopCleanup() + store.StopCleanup() + + // Test passes if no panic occurs +} + +func TestMemoryRateLimiterStore_Cleanup(t *testing.T) { + config := RateLimiterConfig{ + RequestsPerSecond: 10, + Burst: 20, + PerIP: true, + } + + store := NewMemoryRateLimiterStore(config).(*memoryRateLimiterStore) + ctx := context.Background() + + // Add some entries + allowed1, _, _ := store.Allow(ctx, "ip1", config) + allowed2, _, _ := store.Allow(ctx, "ip2", config) + allowed3, _, _ := store.Allow(ctx, "ip3", config) + + assert.True(t, allowed1 && allowed2 && allowed3, "All initial requests should be allowed") + + // Verify entries exist + count := 0 + + store.limiters.Range(func(_, _ any) bool { + count++ + return true + }) + + assert.Equal(t, 3, count, "Should have 3 entries") + + // Manually trigger cleanup with a threshold that marks all as stale + // Set lastAccess to past time + store.limiters.Range(func(_ any, value any) bool { + entry := value.(*limiterEntry) + atomic.StoreInt64(&entry.lastAccess, time.Now().Unix()-3600) // 1 hour ago + + return true + }) + + // Run cleanup with 10 minute threshold + store.cleanup(10 * time.Minute) + + // Verify stale entries were removed + count = 0 + + store.limiters.Range(func(_, _ any) bool { + count++ + return true + }) + + assert.Equal(t, 0, count, "Stale entries should be removed") +} + +func TestRateLimiter_TrustedProxiesEnabled(t *testing.T) { + metrics := newRateLimiterMockMetrics() + config := RateLimiterConfig{ + RequestsPerSecond: 2, + Burst: 2, + PerIP: true, + TrustedProxies: true, // Trust proxy headers + } + + handler := RateLimiter(config, metrics)(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + // Send 2 requests from same X-Forwarded-For IP + for i := 0; i < 2; i++ { + req := httptest.NewRequest(http.MethodGet, "/test", http.NoBody) + req.RemoteAddr = "127.0.0.1:12345" // Proxy IP + req.Header.Set("X-Forwarded-For", "203.0.113.1") // Client IP + + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + assert.Equal(t, http.StatusOK, rr.Code) + } + + // 3rd request from same X-Forwarded-For IP should be rate limited + req := httptest.NewRequest(http.MethodGet, "/test", http.NoBody) + req.RemoteAddr = "127.0.0.1:12345" + req.Header.Set("X-Forwarded-For", "203.0.113.1") + + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + assert.Equal(t, http.StatusTooManyRequests, rr.Code, "Should rate limit based on X-Forwarded-For IP") + + // Different X-Forwarded-For IP should have separate limit + req = httptest.NewRequest(http.MethodGet, "/test", http.NoBody) + req.RemoteAddr = "127.0.0.1:12345" // Same proxy + req.Header.Set("X-Forwarded-For", "203.0.113.2") // Different client IP + + rr = httptest.NewRecorder() + handler.ServeHTTP(rr, req) + assert.Equal(t, http.StatusOK, rr.Code, "Different client IP should have separate rate limit") +} + +func TestRateLimiter_TrustedProxiesDisabled(t *testing.T) { + metrics := newRateLimiterMockMetrics() + config := RateLimiterConfig{ + RequestsPerSecond: 2, + Burst: 2, + PerIP: true, + TrustedProxies: false, // Do not trust proxy headers + } + + handler := RateLimiter(config, metrics)(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + // Send 2 requests with same RemoteAddr but different X-Forwarded-For + for i := 0; i < 2; i++ { + req := httptest.NewRequest(http.MethodGet, "/test", http.NoBody) + req.RemoteAddr = "127.0.0.1:12345" + req.Header.Set("X-Forwarded-For", fmt.Sprintf("203.0.113.%d", i+1)) // Different spoofed IPs + + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + assert.Equal(t, http.StatusOK, rr.Code) + } + // 3rd request should be rate limited based on RemoteAddr, ignoring X-Forwarded-For + req := httptest.NewRequest(http.MethodGet, "/test", http.NoBody) + req.RemoteAddr = "127.0.0.1:12345" + req.Header.Set("X-Forwarded-For", "203.0.113.99") // Different spoofed IP + + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + assert.Equal(t, http.StatusTooManyRequests, rr.Code, "Should rate limit based on RemoteAddr, ignoring spoofed headers") +}