Skip to content

Commit d77e8c0

Browse files
authored
Added ErrorHandler and ErrorHandlerWithContext in CSRF middleware (#2257)
* feat: add error handler to csrf middleware Co-authored-by: Mojtaba Arezoomand <mojtaba.arezoomand@snapp.cab>
1 parent 534bbb8 commit d77e8c0

File tree

2 files changed

+38
-2
lines changed

2 files changed

+38
-2
lines changed

middleware/csrf.go

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,13 @@ type (
6161
// Indicates SameSite mode of the CSRF cookie.
6262
// Optional. Default value SameSiteDefaultMode.
6363
CookieSameSite http.SameSite `yaml:"cookie_same_site"`
64+
65+
// ErrorHandler defines a function which is executed for returning custom errors.
66+
ErrorHandler CSRFErrorHandler
6467
}
68+
69+
// CSRFErrorHandler is a function which is executed for creating custom errors.
70+
CSRFErrorHandler func(err error, c echo.Context) error
6571
)
6672

6773
// ErrCSRFInvalid is returned when CSRF check fails
@@ -154,8 +160,9 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
154160
lastTokenErr = ErrCSRFInvalid
155161
}
156162
}
163+
var finalErr error
157164
if lastTokenErr != nil {
158-
return lastTokenErr
165+
finalErr = lastTokenErr
159166
} else if lastExtractorErr != nil {
160167
// ugly part to preserve backwards compatible errors. someone could rely on them
161168
if lastExtractorErr == errQueryExtractorValueMissing {
@@ -167,7 +174,14 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
167174
} else {
168175
lastExtractorErr = echo.NewHTTPError(http.StatusBadRequest, lastExtractorErr.Error())
169176
}
170-
return lastExtractorErr
177+
finalErr = lastExtractorErr
178+
}
179+
180+
if finalErr != nil {
181+
if config.ErrorHandler != nil {
182+
return config.ErrorHandler(finalErr, c)
183+
}
184+
return finalErr
171185
}
172186
}
173187

middleware/csrf_test.go

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -358,3 +358,25 @@ func TestCSRFConfig_skipper(t *testing.T) {
358358
})
359359
}
360360
}
361+
362+
func TestCSRFErrorHandling(t *testing.T) {
363+
cfg := CSRFConfig{
364+
ErrorHandler: func(err error, c echo.Context) error {
365+
return echo.NewHTTPError(http.StatusTeapot, "error_handler_executed")
366+
},
367+
}
368+
369+
e := echo.New()
370+
e.POST("/", func(c echo.Context) error {
371+
return c.String(http.StatusNotImplemented, "should not end up here")
372+
})
373+
374+
e.Use(CSRFWithConfig(cfg))
375+
376+
req := httptest.NewRequest(http.MethodPost, "/", nil)
377+
res := httptest.NewRecorder()
378+
e.ServeHTTP(res, req)
379+
380+
assert.Equal(t, http.StatusTeapot, res.Code)
381+
assert.Equal(t, "{\"message\":\"error_handler_executed\"}\n", res.Body.String())
382+
}

0 commit comments

Comments
 (0)