1package middleware
2
3import (
4	"crypto/subtle"
5	"errors"
6	"net/http"
7	"strings"
8	"time"
9
10	"github.com/labstack/echo"
11	"github.com/labstack/gommon/random"
12)
13
14type (
15	// CSRFConfig defines the config for CSRF middleware.
16	CSRFConfig struct {
17		// Skipper defines a function to skip middleware.
18		Skipper Skipper
19
20		// TokenLength is the length of the generated token.
21		TokenLength uint8 `yaml:"token_length"`
22		// Optional. Default value 32.
23
24		// TokenLookup is a string in the form of "<source>:<key>" that is used
25		// to extract token from the request.
26		// Optional. Default value "header:X-CSRF-Token".
27		// Possible values:
28		// - "header:<name>"
29		// - "form:<name>"
30		// - "query:<name>"
31		TokenLookup string `yaml:"token_lookup"`
32
33		// Context key to store generated CSRF token into context.
34		// Optional. Default value "csrf".
35		ContextKey string `yaml:"context_key"`
36
37		// Name of the CSRF cookie. This cookie will store CSRF token.
38		// Optional. Default value "csrf".
39		CookieName string `yaml:"cookie_name"`
40
41		// Domain of the CSRF cookie.
42		// Optional. Default value none.
43		CookieDomain string `yaml:"cookie_domain"`
44
45		// Path of the CSRF cookie.
46		// Optional. Default value none.
47		CookiePath string `yaml:"cookie_path"`
48
49		// Max age (in seconds) of the CSRF cookie.
50		// Optional. Default value 86400 (24hr).
51		CookieMaxAge int `yaml:"cookie_max_age"`
52
53		// Indicates if CSRF cookie is secure.
54		// Optional. Default value false.
55		CookieSecure bool `yaml:"cookie_secure"`
56
57		// Indicates if CSRF cookie is HTTP only.
58		// Optional. Default value false.
59		CookieHTTPOnly bool `yaml:"cookie_http_only"`
60	}
61
62	// csrfTokenExtractor defines a function that takes `echo.Context` and returns
63	// either a token or an error.
64	csrfTokenExtractor func(echo.Context) (string, error)
65)
66
67var (
68	// DefaultCSRFConfig is the default CSRF middleware config.
69	DefaultCSRFConfig = CSRFConfig{
70		Skipper:      DefaultSkipper,
71		TokenLength:  32,
72		TokenLookup:  "header:" + echo.HeaderXCSRFToken,
73		ContextKey:   "csrf",
74		CookieName:   "_csrf",
75		CookieMaxAge: 86400,
76	}
77)
78
79// CSRF returns a Cross-Site Request Forgery (CSRF) middleware.
80// See: https://en.wikipedia.org/wiki/Cross-site_request_forgery
81func CSRF() echo.MiddlewareFunc {
82	c := DefaultCSRFConfig
83	return CSRFWithConfig(c)
84}
85
86// CSRFWithConfig returns a CSRF middleware with config.
87// See `CSRF()`.
88func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
89	// Defaults
90	if config.Skipper == nil {
91		config.Skipper = DefaultCSRFConfig.Skipper
92	}
93	if config.TokenLength == 0 {
94		config.TokenLength = DefaultCSRFConfig.TokenLength
95	}
96	if config.TokenLookup == "" {
97		config.TokenLookup = DefaultCSRFConfig.TokenLookup
98	}
99	if config.ContextKey == "" {
100		config.ContextKey = DefaultCSRFConfig.ContextKey
101	}
102	if config.CookieName == "" {
103		config.CookieName = DefaultCSRFConfig.CookieName
104	}
105	if config.CookieMaxAge == 0 {
106		config.CookieMaxAge = DefaultCSRFConfig.CookieMaxAge
107	}
108
109	// Initialize
110	parts := strings.Split(config.TokenLookup, ":")
111	extractor := csrfTokenFromHeader(parts[1])
112	switch parts[0] {
113	case "form":
114		extractor = csrfTokenFromForm(parts[1])
115	case "query":
116		extractor = csrfTokenFromQuery(parts[1])
117	}
118
119	return func(next echo.HandlerFunc) echo.HandlerFunc {
120		return func(c echo.Context) error {
121			if config.Skipper(c) {
122				return next(c)
123			}
124
125			req := c.Request()
126			k, err := c.Cookie(config.CookieName)
127			token := ""
128
129			// Generate token
130			if err != nil {
131				token = random.String(config.TokenLength)
132			} else {
133				// Reuse token
134				token = k.Value
135			}
136
137			switch req.Method {
138			case http.MethodGet, http.MethodHead, http.MethodOptions, http.MethodTrace:
139			default:
140				// Validate token only for requests which are not defined as 'safe' by RFC7231
141				clientToken, err := extractor(c)
142				if err != nil {
143					return echo.NewHTTPError(http.StatusBadRequest, err.Error())
144				}
145				if !validateCSRFToken(token, clientToken) {
146					return echo.NewHTTPError(http.StatusForbidden, "invalid csrf token")
147				}
148			}
149
150			// Set CSRF cookie
151			cookie := new(http.Cookie)
152			cookie.Name = config.CookieName
153			cookie.Value = token
154			if config.CookiePath != "" {
155				cookie.Path = config.CookiePath
156			}
157			if config.CookieDomain != "" {
158				cookie.Domain = config.CookieDomain
159			}
160			cookie.Expires = time.Now().Add(time.Duration(config.CookieMaxAge) * time.Second)
161			cookie.Secure = config.CookieSecure
162			cookie.HttpOnly = config.CookieHTTPOnly
163			c.SetCookie(cookie)
164
165			// Store token in the context
166			c.Set(config.ContextKey, token)
167
168			// Protect clients from caching the response
169			c.Response().Header().Add(echo.HeaderVary, echo.HeaderCookie)
170
171			return next(c)
172		}
173	}
174}
175
176// csrfTokenFromForm returns a `csrfTokenExtractor` that extracts token from the
177// provided request header.
178func csrfTokenFromHeader(header string) csrfTokenExtractor {
179	return func(c echo.Context) (string, error) {
180		return c.Request().Header.Get(header), nil
181	}
182}
183
184// csrfTokenFromForm returns a `csrfTokenExtractor` that extracts token from the
185// provided form parameter.
186func csrfTokenFromForm(param string) csrfTokenExtractor {
187	return func(c echo.Context) (string, error) {
188		token := c.FormValue(param)
189		if token == "" {
190			return "", errors.New("missing csrf token in the form parameter")
191		}
192		return token, nil
193	}
194}
195
196// csrfTokenFromQuery returns a `csrfTokenExtractor` that extracts token from the
197// provided query parameter.
198func csrfTokenFromQuery(param string) csrfTokenExtractor {
199	return func(c echo.Context) (string, error) {
200		token := c.QueryParam(param)
201		if token == "" {
202			return "", errors.New("missing csrf token in the query string")
203		}
204		return token, nil
205	}
206}
207
// validateCSRFToken reports whether clientToken matches the server-side token.
// The comparison uses crypto/subtle.ConstantTimeCompare, whose running time for
// equal-length inputs does not depend on their contents. Empty tokens never
// match; otherwise an empty CSRF cookie combined with an empty client token
// would pass validation.
func validateCSRFToken(token, clientToken string) bool {
	if token == "" || clientToken == "" {
		return false
	}
	return subtle.ConstantTimeCompare([]byte(token), []byte(clientToken)) == 1
}
211