1package middleware 2 3import ( 4 "crypto/subtle" 5 "errors" 6 "net/http" 7 "strings" 8 "time" 9 10 "github.com/labstack/echo" 11 "github.com/labstack/gommon/random" 12) 13 14type ( 15 // CSRFConfig defines the config for CSRF middleware. 16 CSRFConfig struct { 17 // Skipper defines a function to skip middleware. 18 Skipper Skipper 19 20 // TokenLength is the length of the generated token. 21 TokenLength uint8 `yaml:"token_length"` 22 // Optional. Default value 32. 23 24 // TokenLookup is a string in the form of "<source>:<key>" that is used 25 // to extract token from the request. 26 // Optional. Default value "header:X-CSRF-Token". 27 // Possible values: 28 // - "header:<name>" 29 // - "form:<name>" 30 // - "query:<name>" 31 TokenLookup string `yaml:"token_lookup"` 32 33 // Context key to store generated CSRF token into context. 34 // Optional. Default value "csrf". 35 ContextKey string `yaml:"context_key"` 36 37 // Name of the CSRF cookie. This cookie will store CSRF token. 38 // Optional. Default value "csrf". 39 CookieName string `yaml:"cookie_name"` 40 41 // Domain of the CSRF cookie. 42 // Optional. Default value none. 43 CookieDomain string `yaml:"cookie_domain"` 44 45 // Path of the CSRF cookie. 46 // Optional. Default value none. 47 CookiePath string `yaml:"cookie_path"` 48 49 // Max age (in seconds) of the CSRF cookie. 50 // Optional. Default value 86400 (24hr). 51 CookieMaxAge int `yaml:"cookie_max_age"` 52 53 // Indicates if CSRF cookie is secure. 54 // Optional. Default value false. 55 CookieSecure bool `yaml:"cookie_secure"` 56 57 // Indicates if CSRF cookie is HTTP only. 58 // Optional. Default value false. 59 CookieHTTPOnly bool `yaml:"cookie_http_only"` 60 } 61 62 // csrfTokenExtractor defines a function that takes `echo.Context` and returns 63 // either a token or an error. 64 csrfTokenExtractor func(echo.Context) (string, error) 65) 66 67var ( 68 // DefaultCSRFConfig is the default CSRF middleware config. 69 DefaultCSRFConfig = CSRFConfig{ 70 Skipper: DefaultSkipper, 71 TokenLength: 32, 72 TokenLookup: "header:" + echo.HeaderXCSRFToken, 73 ContextKey: "csrf", 74 CookieName: "_csrf", 75 CookieMaxAge: 86400, 76 } 77) 78 79// CSRF returns a Cross-Site Request Forgery (CSRF) middleware. 80// See: https://en.wikipedia.org/wiki/Cross-site_request_forgery 81func CSRF() echo.MiddlewareFunc { 82 c := DefaultCSRFConfig 83 return CSRFWithConfig(c) 84} 85 86// CSRFWithConfig returns a CSRF middleware with config. 87// See `CSRF()`. 88func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc { 89 // Defaults 90 if config.Skipper == nil { 91 config.Skipper = DefaultCSRFConfig.Skipper 92 } 93 if config.TokenLength == 0 { 94 config.TokenLength = DefaultCSRFConfig.TokenLength 95 } 96 if config.TokenLookup == "" { 97 config.TokenLookup = DefaultCSRFConfig.TokenLookup 98 } 99 if config.ContextKey == "" { 100 config.ContextKey = DefaultCSRFConfig.ContextKey 101 } 102 if config.CookieName == "" { 103 config.CookieName = DefaultCSRFConfig.CookieName 104 } 105 if config.CookieMaxAge == 0 { 106 config.CookieMaxAge = DefaultCSRFConfig.CookieMaxAge 107 } 108 109 // Initialize 110 parts := strings.Split(config.TokenLookup, ":") 111 extractor := csrfTokenFromHeader(parts[1]) 112 switch parts[0] { 113 case "form": 114 extractor = csrfTokenFromForm(parts[1]) 115 case "query": 116 extractor = csrfTokenFromQuery(parts[1]) 117 } 118 119 return func(next echo.HandlerFunc) echo.HandlerFunc { 120 return func(c echo.Context) error { 121 if config.Skipper(c) { 122 return next(c) 123 } 124 125 req := c.Request() 126 k, err := c.Cookie(config.CookieName) 127 token := "" 128 129 // Generate token 130 if err != nil { 131 token = random.String(config.TokenLength) 132 } else { 133 // Reuse token 134 token = k.Value 135 } 136 137 switch req.Method { 138 case http.MethodGet, http.MethodHead, http.MethodOptions, http.MethodTrace: 139 default: 140 // Validate token only for requests which are not defined as 'safe' by RFC7231 141 clientToken, err := extractor(c) 142 if err != nil { 143 return echo.NewHTTPError(http.StatusBadRequest, err.Error()) 144 } 145 if !validateCSRFToken(token, clientToken) { 146 return echo.NewHTTPError(http.StatusForbidden, "invalid csrf token") 147 } 148 } 149 150 // Set CSRF cookie 151 cookie := new(http.Cookie) 152 cookie.Name = config.CookieName 153 cookie.Value = token 154 if config.CookiePath != "" { 155 cookie.Path = config.CookiePath 156 } 157 if config.CookieDomain != "" { 158 cookie.Domain = config.CookieDomain 159 } 160 cookie.Expires = time.Now().Add(time.Duration(config.CookieMaxAge) * time.Second) 161 cookie.Secure = config.CookieSecure 162 cookie.HttpOnly = config.CookieHTTPOnly 163 c.SetCookie(cookie) 164 165 // Store token in the context 166 c.Set(config.ContextKey, token) 167 168 // Protect clients from caching the response 169 c.Response().Header().Add(echo.HeaderVary, echo.HeaderCookie) 170 171 return next(c) 172 } 173 } 174} 175 176// csrfTokenFromForm returns a `csrfTokenExtractor` that extracts token from the 177// provided request header. 178func csrfTokenFromHeader(header string) csrfTokenExtractor { 179 return func(c echo.Context) (string, error) { 180 return c.Request().Header.Get(header), nil 181 } 182} 183 184// csrfTokenFromForm returns a `csrfTokenExtractor` that extracts token from the 185// provided form parameter. 186func csrfTokenFromForm(param string) csrfTokenExtractor { 187 return func(c echo.Context) (string, error) { 188 token := c.FormValue(param) 189 if token == "" { 190 return "", errors.New("missing csrf token in the form parameter") 191 } 192 return token, nil 193 } 194} 195 196// csrfTokenFromQuery returns a `csrfTokenExtractor` that extracts token from the 197// provided query parameter. 198func csrfTokenFromQuery(param string) csrfTokenExtractor { 199 return func(c echo.Context) (string, error) { 200 token := c.QueryParam(param) 201 if token == "" { 202 return "", errors.New("missing csrf token in the query string") 203 } 204 return token, nil 205 } 206} 207 208func validateCSRFToken(token, clientToken string) bool { 209 return subtle.ConstantTimeCompare([]byte(token), []byte(clientToken)) == 1 210} 211