Skip to content

Commit ec96d16

Browse files
authored
Fix csrf middleware behavior with header key lookup (#2063)
* 🐛 [Bug]: Strange CSRF middleware behavior with header KeyLookup configuration #2045
1 parent 6026560 commit ec96d16

File tree

3 files changed

+120
-7
lines changed

3 files changed

+120
-7
lines changed

middleware/csrf/config.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,15 +102,17 @@ type Config struct {
102102
Extractor func(c *fiber.Ctx) (string, error)
103103
}
104104

105+
const HeaderName = "X-Csrf-Token"
106+
105107
// ConfigDefault is the default config
106108
var ConfigDefault = Config{
107-
KeyLookup: "header:X-Csrf-Token",
109+
KeyLookup: "header:" + HeaderName,
108110
CookieName: "csrf_",
109111
CookieSameSite: "Lax",
110112
Expiration: 1 * time.Hour,
111113
KeyGenerator: utils.UUID,
112114
ErrorHandler: defaultErrorHandler,
113-
Extractor: CsrfFromHeader("X-Csrf-Token"),
115+
Extractor: CsrfFromHeader(HeaderName),
114116
}
115117

116118
// default ErrorHandler that process return error from fiber.Handler

middleware/csrf/csrf_test.go

Lines changed: 111 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ func Test_CSRF(t *testing.T) {
4040
ctx.Request.Reset()
4141
ctx.Response.Reset()
4242
ctx.Request.Header.SetMethod("POST")
43-
ctx.Request.Header.Set("X-CSRF-Token", "johndoe")
43+
ctx.Request.Header.Set(HeaderName, "johndoe")
4444
h(ctx)
4545
utils.AssertEqual(t, 403, ctx.Response.StatusCode())
4646

@@ -55,7 +55,7 @@ func Test_CSRF(t *testing.T) {
5555
ctx.Request.Reset()
5656
ctx.Response.Reset()
5757
ctx.Request.Header.SetMethod("POST")
58-
ctx.Request.Header.Set("X-CSRF-Token", token)
58+
ctx.Request.Header.Set(HeaderName, token)
5959
h(ctx)
6060
utils.AssertEqual(t, 200, ctx.Response.StatusCode())
6161
}
@@ -305,7 +305,7 @@ func Test_CSRF_ErrorHandler_InvalidToken(t *testing.T) {
305305
ctx.Request.Reset()
306306
ctx.Response.Reset()
307307
ctx.Request.Header.SetMethod("POST")
308-
ctx.Request.Header.Set("X-CSRF-Token", "johndoe")
308+
ctx.Request.Header.Set(HeaderName, "johndoe")
309309
h(ctx)
310310
utils.AssertEqual(t, 419, ctx.Response.StatusCode())
311311
utils.AssertEqual(t, "invalid CSRF token", string(ctx.Response.Body()))
@@ -340,3 +340,111 @@ func Test_CSRF_ErrorHandler_EmptyToken(t *testing.T) {
340340
utils.AssertEqual(t, 419, ctx.Response.StatusCode())
341341
utils.AssertEqual(t, "empty CSRF token", string(ctx.Response.Body()))
342342
}
343+
344+
// TODO: use this test case and make the unsafe header value bug from https://github.com/gofiber/fiber/issues/2045 reproducible and permanently fixed/tested by this testcase
345+
//func Test_CSRF_UnsafeHeaderValue(t *testing.T) {
346+
// app := fiber.New()
347+
//
348+
// app.Use(New())
349+
// app.Get("/", func(c *fiber.Ctx) error {
350+
// return c.SendStatus(fiber.StatusOK)
351+
// })
352+
// app.Get("/test", func(c *fiber.Ctx) error {
353+
// return c.SendStatus(fiber.StatusOK)
354+
// })
355+
// app.Post("/", func(c *fiber.Ctx) error {
356+
// return c.SendStatus(fiber.StatusOK)
357+
// })
358+
//
359+
// resp, err := app.Test(httptest.NewRequest(http.MethodGet, "/", nil))
360+
// utils.AssertEqual(t, nil, err)
361+
// utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
362+
//
363+
// var token string
364+
// for _, c := range resp.Cookies() {
365+
// if c.Name != ConfigDefault.CookieName {
366+
// continue
367+
// }
368+
// token = c.Value
369+
// break
370+
// }
371+
//
372+
// fmt.Println("token", token)
373+
//
374+
// getReq := httptest.NewRequest(http.MethodGet, "/", nil)
375+
// getReq.Header.Set(HeaderName, token)
376+
// resp, err = app.Test(getReq)
377+
//
378+
// getReq = httptest.NewRequest(http.MethodGet, "/test", nil)
379+
// getReq.Header.Set("X-Requested-With", "XMLHttpRequest")
380+
// getReq.Header.Set(fiber.HeaderCacheControl, "no")
381+
// getReq.Header.Set(HeaderName, token)
382+
//
383+
// resp, err = app.Test(getReq)
384+
//
385+
// getReq.Header.Set(fiber.HeaderAccept, "*/*")
386+
// getReq.Header.Del(HeaderName)
387+
// resp, err = app.Test(getReq)
388+
//
389+
// postReq := httptest.NewRequest(http.MethodPost, "/", nil)
390+
// postReq.Header.Set("X-Requested-With", "XMLHttpRequest")
391+
// postReq.Header.Set(HeaderName, token)
392+
// resp, err = app.Test(postReq)
393+
//}
394+
395+
// go test -v -run=^$ -bench=Benchmark_Middleware_CSRF_Check -benchmem -count=4
396+
func Benchmark_Middleware_CSRF_Check(b *testing.B) {
397+
app := fiber.New()
398+
399+
app.Use(New())
400+
app.Get("/", func(c *fiber.Ctx) error {
401+
return c.SendStatus(fiber.StatusTeapot)
402+
})
403+
404+
fctx := &fasthttp.RequestCtx{}
405+
h := app.Handler()
406+
ctx := &fasthttp.RequestCtx{}
407+
408+
// Generate CSRF token
409+
ctx.Request.Header.SetMethod("GET")
410+
h(ctx)
411+
token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
412+
token = strings.Split(strings.Split(token, ";")[0], "=")[1]
413+
414+
ctx.Request.Header.SetMethod("POST")
415+
ctx.Request.Header.Set(HeaderName, token)
416+
417+
b.ReportAllocs()
418+
b.ResetTimer()
419+
420+
for n := 0; n < b.N; n++ {
421+
h(fctx)
422+
}
423+
424+
utils.AssertEqual(b, fiber.StatusTeapot, fctx.Response.Header.StatusCode())
425+
}
426+
427+
// go test -v -run=^$ -bench=Benchmark_Middleware_CSRF_GenerateToken -benchmem -count=4
428+
func Benchmark_Middleware_CSRF_GenerateToken(b *testing.B) {
429+
app := fiber.New()
430+
431+
app.Use(New())
432+
app.Get("/", func(c *fiber.Ctx) error {
433+
return c.SendStatus(fiber.StatusTeapot)
434+
})
435+
436+
fctx := &fasthttp.RequestCtx{}
437+
h := app.Handler()
438+
ctx := &fasthttp.RequestCtx{}
439+
440+
// Generate CSRF token
441+
ctx.Request.Header.SetMethod("GET")
442+
b.ReportAllocs()
443+
b.ResetTimer()
444+
445+
for n := 0; n < b.N; n++ {
446+
h(fctx)
447+
}
448+
449+
utils.AssertEqual(b, fiber.StatusTeapot, fctx.Response.Header.StatusCode())
450+
}

middleware/csrf/manager.go

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66

77
"github.com/gofiber/fiber/v2"
88
"github.com/gofiber/fiber/v2/internal/memory"
9+
"github.com/gofiber/fiber/v2/utils"
910
)
1011

1112
// go:generate msgp
@@ -88,7 +89,8 @@ func (m *manager) set(key string, it *item, exp time.Duration) {
8889
_ = m.storage.Set(key, raw, exp)
8990
}
9091
} else {
91-
m.memory.Set(key, it, exp)
92+
// the key is crucial in crsf and sometimes a reference to another value which can be reused later(pool/unsafe values concept), so a copy is made here
93+
m.memory.Set(utils.CopyString(key), it, exp)
9294
}
9395
}
9496

@@ -97,7 +99,8 @@ func (m *manager) setRaw(key string, raw []byte, exp time.Duration) {
9799
if m.storage != nil {
98100
_ = m.storage.Set(key, raw, exp)
99101
} else {
100-
m.memory.Set(key, raw, exp)
102+
// the key is crucial in crsf and sometimes a reference to another value which can be reused later(pool/unsafe values concept), so a copy is made here
103+
m.memory.Set(utils.CopyString(key), raw, exp)
101104
}
102105
}
103106

0 commit comments

Comments
 (0)