Skip to content

Commit 2d552fc

Browse files
committed
feat: implement rate limiter for the grid-proxy
Signed-off-by: nabil salah <nabil.salah203@gmail.com>
1 parent eed49b6 commit 2d552fc

File tree

9 files changed

+617
-22
lines changed

9 files changed

+617
-22
lines changed

grid-proxy/Makefile

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ server-start: ## Start the proxy server (Args: `m=<MNEMONICS>`)
5858
--postgres-db tfgrid-graphql \
5959
--postgres-password postgres \
6060
--postgres-user postgres \
61+
--rate-limit-rps 1000 \
6162
--mnemonics "$(m)" ;
6263

6364
all-start: db-stop db-start sleep db-fill server-start ## Full start of the database and the server (Args: `m=<MNEMONICS>`)

grid-proxy/cmds/proxy_server/main.go

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ type flags struct {
5555
relayURL string
5656
mnemonics string
5757
maxPoolOpenConnections int
58+
rateLimitRPS int // Rate limit requests per second per IP
5859

5960
noIndexer bool // true to stop the indexer, useful on running for testing
6061
indexerUpserterBatchSize uint
@@ -98,6 +99,7 @@ func main() {
9899
flag.StringVar(&f.relayURL, "relay-url", DefaultRelayURL, "RMB relay url")
99100
flag.StringVar(&f.mnemonics, "mnemonics", "", "Dummy user mnemonics for relay calls")
100101
flag.IntVar(&f.maxPoolOpenConnections, "max-open-conns", 80, "max number of db connection pool open connections")
102+
flag.IntVar(&f.rateLimitRPS, "rate-limit-rps", 20, "rate limit requests per second per IP address (0 to disable)")
101103

102104
flag.BoolVar(&f.noIndexer, "no-indexer", false, "do not start the indexer")
103105
flag.UintVar(&f.indexerUpserterBatchSize, "indexer-upserter-batch-size", 20, "results batch size which collected before upserting")
@@ -182,6 +184,13 @@ func main() {
182184
log.Fatal().Err(err).Msg("failed to create mux server")
183185
}
184186

187+
// Log rate limiting configuration
188+
if f.rateLimitRPS > 0 {
189+
log.Info().Int("rate_limit_rps", f.rateLimitRPS).Msg("HTTP rate limiting enabled")
190+
} else {
191+
log.Info().Msg("HTTP rate limiting disabled")
192+
}
193+
185194
if err := app(s, f); err != nil {
186195
log.Fatal().Msg(err.Error())
187196
}
@@ -331,7 +340,7 @@ func createServer(f flags, dbClient explorer.DBClient, gitCommit string, relayCl
331340
router := mux.NewRouter().StrictSlash(true)
332341

333342
// setup explorer
334-
if err := explorer.Setup(router, gitCommit, dbClient, relayClient, idxIntervals); err != nil {
343+
if err := explorer.Setup(router, gitCommit, dbClient, relayClient, idxIntervals, f.rateLimitRPS); err != nil {
335344
return nil, err
336345
}
337346

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
package mw
2+
3+
import (
4+
"fmt"
5+
"net/http"
6+
"strconv"
7+
"time"
8+
9+
"github.com/rs/zerolog/log"
10+
"github.com/threefoldtech/tfgrid-sdk-go/grid-proxy/tools/ratelimiter"
11+
)
12+
13+
// RateLimiterMiddleware wraps the rate limiter to work with the existing middleware pattern
14+
type RateLimiterMiddleware struct {
15+
limiter *ratelimiter.SlidingWindowRateLimiter
16+
}
17+
18+
// NewRateLimiterMiddleware creates a new rate limiter middleware
19+
func NewRateLimiterMiddleware(ratePerSecond int) *RateLimiterMiddleware {
20+
return &RateLimiterMiddleware{
21+
limiter: ratelimiter.NewSlidingWindowRateLimiter(ratePerSecond),
22+
}
23+
}
24+
25+
// RateLimitAction wraps an Action with rate limiting
26+
func (rlm *RateLimiterMiddleware) RateLimitAction(action Action) Action {
27+
return func(r *http.Request) (interface{}, Response) {
28+
clientIP := ratelimiter.GetClientIP(r)
29+
30+
if !rlm.limiter.Allow(clientIP) {
31+
log.Warn().
32+
Str("ip", clientIP).
33+
Str("method", r.Method).
34+
Str("path", r.URL.Path).
35+
Msg("Rate limit exceeded")
36+
37+
return nil, rlm.TooManyRequests(fmt.Errorf("rate limit exceeded for IP: %s", clientIP), clientIP)
38+
}
39+
40+
return action(r)
41+
}
42+
}
43+
44+
// RateLimitProxyAction wraps a ProxyAction with rate limiting
45+
func (rlm *RateLimiterMiddleware) RateLimitProxyAction(action ProxyAction) ProxyAction {
46+
return func(r *http.Request) (*http.Response, Response) {
47+
// Get client IP
48+
clientIP := ratelimiter.GetClientIP(r)
49+
50+
// Check rate limit
51+
if !rlm.limiter.Allow(clientIP) {
52+
log.Warn().
53+
Str("ip", clientIP).
54+
Str("method", r.Method).
55+
Str("path", r.URL.Path).
56+
Msg("Rate limit exceeded")
57+
58+
return nil, rlm.TooManyRequests(fmt.Errorf("rate limit exceeded for IP: %s", clientIP), clientIP)
59+
}
60+
61+
// Rate limit passed, execute the original action
62+
return action(r)
63+
}
64+
}
65+
66+
// AsRateLimitedHandlerFunc wraps AsHandlerFunc with rate limiting
67+
func (rlm *RateLimiterMiddleware) AsRateLimitedHandlerFunc(action Action) http.HandlerFunc {
68+
rateLimitedAction := rlm.RateLimitAction(action)
69+
return AsHandlerFunc(rateLimitedAction)
70+
}
71+
72+
// AsRateLimitedProxyHandlerFunc wraps AsProxyHandlerFunc with rate limiting
73+
func (rlm *RateLimiterMiddleware) AsRateLimitedProxyHandlerFunc(action ProxyAction) http.HandlerFunc {
74+
rateLimitedAction := rlm.RateLimitProxyAction(action)
75+
return AsProxyHandlerFunc(rateLimitedAction)
76+
}
77+
78+
// GetStats returns rate limiter statistics
79+
func (rlm *RateLimiterMiddleware) GetStats() map[string]interface{} {
80+
return rlm.limiter.GetStats()
81+
}
82+
83+
// TooManyRequests returns a 429 Too Many Requests response with accurate rate limit headers
84+
func (rlm *RateLimiterMiddleware) TooManyRequests(err error, clientIP string) Response {
85+
rateLimit := rlm.limiter.GetRateLimit()
86+
currentRequests := rlm.limiter.GetCurrentRequestCount(clientIP)
87+
remaining := max(0, rateLimit-currentRequests)
88+
resetTime := time.Now().Add(time.Second)
89+
90+
return Error(err, http.StatusTooManyRequests).
91+
WithHeader("Retry-After", "1").
92+
WithHeader("X-RateLimit-Limit", strconv.Itoa(rateLimit)).
93+
WithHeader("X-RateLimit-Remaining", strconv.Itoa(remaining)).
94+
WithHeader("X-RateLimit-Reset", strconv.FormatInt(resetTime.Unix(), 10)).
95+
WithHeader("X-Client-IP", clientIP)
96+
}
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
package mw
2+
3+
import "net/http"
4+
5+
// WithRateLimit wraps an Action with rate limiting if a rate limiter is provided.
6+
// If rateLimiter is nil, it falls back to the standard AsHandlerFunc wrapper.
7+
// This provides a clean way to conditionally apply rate limiting to endpoints.
8+
func WithRateLimit(rateLimiter *RateLimiterMiddleware, action Action) http.HandlerFunc {
9+
if rateLimiter != nil {
10+
return rateLimiter.AsRateLimitedHandlerFunc(action)
11+
}
12+
return AsHandlerFunc(action)
13+
}
14+
15+
// WithRateLimitProxy wraps a ProxyAction with rate limiting if a rate limiter is provided.
16+
// If rateLimiter is nil, it falls back to the standard AsProxyHandlerFunc wrapper.
17+
// This provides a clean way to conditionally apply rate limiting to proxy endpoints.
18+
func WithRateLimitProxy(rateLimiter *RateLimiterMiddleware, action ProxyAction) http.HandlerFunc {
19+
if rateLimiter != nil {
20+
return rateLimiter.AsRateLimitedProxyHandlerFunc(action)
21+
}
22+
return AsProxyHandlerFunc(action)
23+
}

grid-proxy/internal/explorer/server.go

Lines changed: 30 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -629,7 +629,7 @@ func (a *App) getContractBills(r *http.Request) (interface{}, mw.Response) {
629629
// @license.name Apache 2.0
630630
// @license.url http://www.apache.org/licenses/LICENSE-2.0.html
631631
// @BasePath /
632-
func Setup(router *mux.Router, gitCommit string, cl DBClient, relayClient rmb.Client, idxIntervals map[string]uint) error {
632+
func Setup(router *mux.Router, gitCommit string, cl DBClient, relayClient rmb.Client, idxIntervals map[string]uint, rateLimitRPS int) error {
633633

634634
a := App{
635635
cl: cl,
@@ -638,32 +638,41 @@ func Setup(router *mux.Router, gitCommit string, cl DBClient, relayClient rmb.Cl
638638
idxIntervals: idxIntervals,
639639
}
640640

641-
router.HandleFunc("/farms", mw.AsHandlerFunc(a.listFarms))
642-
router.HandleFunc("/stats", mw.AsHandlerFunc(a.getStats))
641+
// Create rate limiter middleware if rate limiting is enabled
642+
var rateLimiter *mw.RateLimiterMiddleware
643+
if rateLimitRPS > 0 {
644+
rateLimiter = mw.NewRateLimiterMiddleware(rateLimitRPS)
645+
log.Info().Int("rate_limit_rps", rateLimitRPS).Msg("Rate limiting enabled")
646+
} else {
647+
log.Info().Msg("Rate limiting disabled")
648+
}
649+
650+
router.HandleFunc("/farms", mw.WithRateLimit(rateLimiter, a.listFarms))
651+
router.HandleFunc("/stats", mw.WithRateLimit(rateLimiter, a.getStats))
643652

644-
router.HandleFunc("/twins", mw.AsHandlerFunc(a.listTwins))
645-
router.HandleFunc("/twins/{twin_id:[0-9]+}/consumption", mw.AsHandlerFunc(a.getTwinConsumption))
653+
router.HandleFunc("/twins", mw.WithRateLimit(rateLimiter, a.listTwins))
654+
router.HandleFunc("/twins/{twin_id:[0-9]+}/consumption", mw.WithRateLimit(rateLimiter, a.getTwinConsumption))
646655

647-
router.HandleFunc("/nodes", mw.AsHandlerFunc(a.getNodes))
648-
router.HandleFunc("/nodes/{node_id:[0-9]+}", mw.AsHandlerFunc(a.getNode))
649-
router.HandleFunc("/nodes/{node_id:[0-9]+}/status", mw.AsHandlerFunc(a.getNodeStatus))
650-
router.HandleFunc("/nodes/{node_id:[0-9]+}/statistics", mw.AsHandlerFunc(a.getNodeStatistics))
651-
router.HandleFunc("/nodes/{node_id:[0-9]+}/gpu", mw.AsHandlerFunc(a.getNodeGpus))
656+
router.HandleFunc("/nodes", mw.WithRateLimit(rateLimiter, a.getNodes))
657+
router.HandleFunc("/nodes/{node_id:[0-9]+}", mw.WithRateLimit(rateLimiter, a.getNode))
658+
router.HandleFunc("/nodes/{node_id:[0-9]+}/status", mw.WithRateLimit(rateLimiter, a.getNodeStatus))
659+
router.HandleFunc("/nodes/{node_id:[0-9]+}/statistics", mw.WithRateLimit(rateLimiter, a.getNodeStatistics))
660+
router.HandleFunc("/nodes/{node_id:[0-9]+}/gpu", mw.WithRateLimit(rateLimiter, a.getNodeGpus))
652661

653-
router.HandleFunc("/gateways", mw.AsHandlerFunc(a.getGateways))
654-
router.HandleFunc("/gateways/{node_id:[0-9]+}", mw.AsHandlerFunc(a.getGateway))
655-
router.HandleFunc("/gateways/{node_id:[0-9]+}/status", mw.AsHandlerFunc(a.getNodeStatus))
662+
router.HandleFunc("/gateways", mw.WithRateLimit(rateLimiter, a.getGateways))
663+
router.HandleFunc("/gateways/{node_id:[0-9]+}", mw.WithRateLimit(rateLimiter, a.getGateway))
664+
router.HandleFunc("/gateways/{node_id:[0-9]+}/status", mw.WithRateLimit(rateLimiter, a.getNodeStatus))
656665

657-
router.HandleFunc("/contracts", mw.AsHandlerFunc(a.listContracts))
658-
router.HandleFunc("/contracts/{contract_id:[0-9]+}", mw.AsHandlerFunc(a.getContract))
659-
router.HandleFunc("/contracts/{contract_id:[0-9]+}/bills", mw.AsHandlerFunc(a.getContractBills))
666+
router.HandleFunc("/contracts", mw.WithRateLimit(rateLimiter, a.listContracts))
667+
router.HandleFunc("/contracts/{contract_id:[0-9]+}", mw.WithRateLimit(rateLimiter, a.getContract))
668+
router.HandleFunc("/contracts/{contract_id:[0-9]+}/bills", mw.WithRateLimit(rateLimiter, a.getContractBills))
660669

661-
router.HandleFunc("/public_ips", mw.AsHandlerFunc(a.GetPublicIps))
670+
router.HandleFunc("/public_ips", mw.WithRateLimit(rateLimiter, a.GetPublicIps))
662671

663-
router.HandleFunc("/", mw.AsHandlerFunc(a.indexPage(router)))
664-
router.HandleFunc("/ping", mw.AsHandlerFunc(a.ping))
665-
router.HandleFunc("/version", mw.AsHandlerFunc(a.version))
666-
router.HandleFunc("/health", mw.AsHandlerFunc(a.health))
672+
router.HandleFunc("/", mw.WithRateLimit(rateLimiter, a.indexPage(router)))
673+
router.HandleFunc("/ping", mw.WithRateLimit(rateLimiter, a.ping))
674+
router.HandleFunc("/version", mw.WithRateLimit(rateLimiter, a.version))
675+
router.HandleFunc("/health", mw.WithRateLimit(rateLimiter, a.health))
667676
router.PathPrefix("/swagger/").Handler(httpSwagger.WrapHandler)
668677

669678
return nil
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
# Rate Limiter
2+
3+
This package implements a sliding window rate limiter for the TFGrid Proxy server.
4+
5+
## Features
6+
7+
- **IP-based Rate Limiting**: Tracks requests per IP address
8+
- **Sliding Window Algorithm**: Uses a sliding window approach for smooth rate limiting
9+
- **Thread-Safe**: Safe for concurrent use across multiple goroutines
10+
- **Memory Efficient**: Automatic cleanup of old entries
11+
- **Configurable**: Rate limit can be set via command-line flag
12+
13+
## Usage
14+
15+
### Command Line Flag
16+
17+
Use the `--rate-limit-rps` flag to set the rate limit:
18+
19+
```bash
20+
# Enable rate limiting at 20 requests per second per IP (default)
21+
./proxy_server --rate-limit-rps 20
22+
23+
# Set custom rate limit of 100 requests per second per IP
24+
./proxy_server --rate-limit-rps 100
25+
26+
# Disable rate limiting
27+
./proxy_server --rate-limit-rps 0
28+
```
29+
30+
### IP Address Detection
31+
32+
The rate limiter automatically extracts the client IP address using the following priority:
33+
34+
1. `X-Real-IP` header
35+
2. `X-Forwarded-For` header (first IP if multiple)
36+
3. `RemoteAddr` from the connection
37+
38+
This ensures proper rate limiting even when the proxy is behind load balancers or CDNs.
39+
40+
### HTTP Response
41+
42+
When rate limit is exceeded, the server returns:
43+
44+
- **Status Code**: 429 (Too Many Requests)
45+
- **Headers**:
46+
- `Retry-After: 1` - Suggests retrying after 1 second
47+
- `X-RateLimit-Limit: 20` - Current rate limit
48+
- `X-RateLimit-Remaining: 0` - Remaining requests (0 when exceeded)
49+
50+
### Algorithm Details
51+
52+
The sliding window algorithm works as follows:
53+
54+
1. **Time Window**: Uses a 1-second sliding window
55+
2. **Request Tracking**: Stores timestamps of requests within the window
56+
3. **Cleanup**: Automatically removes requests older than the window
57+
4. **Memory Management**: Periodically cleans up inactive IP entries
58+
59+
### Performance
60+
61+
- **Memory Usage**: Minimal overhead, only stores active IP addresses
62+
- **CPU Usage**: O(n) where n is the number of requests in the current window
63+
- **Concurrency**: Thread-safe with read-write mutexes for optimal performance
64+
65+
## Configuration
66+
67+
| Flag | Default | Description |
68+
|------|---------|-------------|
69+
| `--rate-limit-rps` | 20 | Requests per second per IP address (0 to disable) |
70+
71+
## Logging
72+
73+
The rate limiter provides debug logging for:
74+
75+
- Rate limit violations (WARN level)
76+
- New IP tracking (DEBUG level)
77+
- Request allowances (DEBUG level)
78+
- Cleanup operations (DEBUG level)
79+
80+
Example log output:
81+
```
82+
{"level":"warn","ip":"192.168.1.100","method":"GET","path":"/nodes","time":"2025-07-07T14:40:43Z","message":"Rate limit exceeded"}
83+
```
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
package ratelimiter
2+
3+
import (
4+
"net"
5+
"net/http"
6+
"strings"
7+
)
8+
9+
// GetClientIP extracts the real client IP from the HTTP request
10+
// It checks X-Forwarded-For, X-Real-IP headers, and falls back to RemoteAddr
11+
func GetClientIP(r *http.Request) string {
12+
if realIP := r.Header.Get("X-Real-IP"); realIP != "" {
13+
return realIP
14+
}
15+
16+
if fwd := r.Header.Get("X-Forwarded-For"); fwd != "" {
17+
parts := strings.Split(fwd, ",")
18+
return strings.TrimSpace(parts[0])
19+
}
20+
21+
host, _, err := net.SplitHostPort(r.RemoteAddr)
22+
if err != nil {
23+
return r.RemoteAddr
24+
}
25+
return host
26+
}

0 commit comments

Comments
 (0)