@@ -6,6 +6,8 @@ package middleware
66import (
77 "crypto/subtle"
88 "net/http"
9+ "slices"
10+ "strings"
911 "time"
1012
1113 "github.com/labstack/echo/v4"
@@ -16,6 +18,22 @@ type CSRFConfig struct {
1618 // Skipper defines a function to skip middleware.
1719 Skipper Skipper
1820
21+ // TrustedOrigin permits any request with `Sec-Fetch-Site` header whose `Origin` header
22+ // exactly matches the specified value.
23+ // Values should be formated as Origin header "scheme://host[:port]".
24+ //
25+ // See [Origin]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Origin
26+ // See [Sec-Fetch-Site]: https://cheatsheetseries.owasp.org/cheatsheets/Cross-Site_Request_Forgery_Prevention_Cheat_Sheet.html#fetch-metadata-headers
27+ TrustedOrigins []string
28+
29+ // AllowSecFetchSameSite allows custom behaviour for `Sec-Fetch-Site` requests that are about to
30+ // fail with CRSF error, to be allowed or replaced with custom error.
31+ // This function applies to `Sec-Fetch-Site` values:
32+ // - `same-site` same registrable domain (subdomain and/or different port)
33+ // - `cross-site` request originates from different site
34+ // See [Sec-Fetch-Site]: https://cheatsheetseries.owasp.org/cheatsheets/Cross-Site_Request_Forgery_Prevention_Cheat_Sheet.html#fetch-metadata-headers
35+ AllowSecFetchSiteFunc func (c echo.Context ) (bool , error )
36+
1937 // TokenLength is the length of the generated token.
2038 TokenLength uint8 `yaml:"token_length"`
2139 // Optional. Default value 32.
@@ -94,7 +112,11 @@ func CSRF() echo.MiddlewareFunc {
94112// CSRFWithConfig returns a CSRF middleware with config.
95113// See `CSRF()`.
96114func CSRFWithConfig (config CSRFConfig ) echo.MiddlewareFunc {
97- // Defaults
115+ return toMiddlewareOrPanic (config )
116+ }
117+
118+ // ToMiddleware converts CSRFConfig to middleware or returns an error for invalid configuration
119+ func (config CSRFConfig ) ToMiddleware () (echo.MiddlewareFunc , error ) {
98120 if config .Skipper == nil {
99121 config .Skipper = DefaultCSRFConfig .Skipper
100122 }
@@ -117,10 +139,16 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
117139 if config .CookieSameSite == http .SameSiteNoneMode {
118140 config .CookieSecure = true
119141 }
142+ if len (config .TrustedOrigins ) > 0 {
143+ if vErr := validateOrigins (config .TrustedOrigins , "trusted origin" ); vErr != nil {
144+ return nil , vErr
145+ }
146+ config .TrustedOrigins = append ([]string (nil ), config .TrustedOrigins ... )
147+ }
120148
121149 extractors , cErr := CreateExtractors (config .TokenLookup )
122150 if cErr != nil {
123- panic ( cErr )
151+ return nil , cErr
124152 }
125153
126154 return func (next echo.HandlerFunc ) echo.HandlerFunc {
@@ -129,6 +157,17 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
129157 return next (c )
130158 }
131159
160+ // use the `Sec-Fetch-Site` header as part of a modern approach to CSRF protection
161+ allow , err := config .checkSecFetchSiteRequest (c )
162+ if err != nil {
163+ return err
164+ }
165+ if allow {
166+ return next (c )
167+ }
168+
169+ // Fallback to legacy token based CSRF protection
170+
132171 token := ""
133172 if k , err := c .Cookie (config .CookieName ); err != nil {
134173 token = randomString (config .TokenLength )
@@ -210,9 +249,55 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
210249
211250 return next (c )
212251 }
213- }
252+ }, nil
214253}
215254
216255func validateCSRFToken (token , clientToken string ) bool {
217256 return subtle .ConstantTimeCompare ([]byte (token ), []byte (clientToken )) == 1
218257}
258+
259+ var safeMethods = []string {http .MethodGet , http .MethodHead , http .MethodOptions , http .MethodTrace }
260+
261+ func (config CSRFConfig ) checkSecFetchSiteRequest (c echo.Context ) (bool , error ) {
262+ // https://cheatsheetseries.owasp.org/cheatsheets/Cross-Site_Request_Forgery_Prevention_Cheat_Sheet.html#fetch-metadata-headers
263+ // Sec-Fetch-Site values are:
264+ // - `same-origin` exact origin match - allow always
265+ // - `same-site` same registrable domain (subdomain and/or different port) - block, unless explicitly trusted
266+ // - `cross-site` request originates from different site - block, unless explicitly trusted
267+ // - `none` direct navigation (URL bar, bookmark) - allow always
268+ secFetchSite := c .Request ().Header .Get (echo .HeaderSecFetchSite )
269+ if secFetchSite == "" {
270+ return false , nil
271+ }
272+
273+ if len (config .TrustedOrigins ) > 0 {
274+ // trusted sites ala OAuth callbacks etc. should be let through
275+ origin := c .Request ().Header .Get (echo .HeaderOrigin )
276+ if origin != "" {
277+ for _ , trustedOrigin := range config .TrustedOrigins {
278+ if strings .EqualFold (origin , trustedOrigin ) {
279+ return true , nil
280+ }
281+ }
282+ }
283+ }
284+ isSafe := slices .Contains (safeMethods , c .Request ().Method )
285+ if ! isSafe { // for state-changing request check SecFetchSite value
286+ isSafe = secFetchSite == "same-origin" || secFetchSite == "none"
287+ }
288+
289+ if isSafe {
290+ return true , nil
291+ }
292+ // we are here when request is state-changing and `cross-site` or `same-site`
293+
294+ // Note: if you want to allow `same-site` use config.TrustedOrigins or `config.AllowSecFetchSiteFunc`
295+ if config .AllowSecFetchSiteFunc != nil {
296+ return config .AllowSecFetchSiteFunc (c )
297+ }
298+
299+ if secFetchSite == "same-site" {
300+ return false , echo .NewHTTPError (http .StatusForbidden , "same-site request blocked by CSRF" )
301+ }
302+ return false , echo .NewHTTPError (http .StatusForbidden , "cross-site request blocked by CSRF" )
303+ }
0 commit comments