1package headers
2
3import (
4	"context"
5	"fmt"
6	"net/http"
7	"regexp"
8	"strconv"
9	"strings"
10
11	"github.com/traefik/traefik/v2/pkg/config/dynamic"
12	"github.com/traefik/traefik/v2/pkg/log"
13)
14
// Header is a middleware that rewrites request and response headers and
// answers CORS requests according to a dynamic.Headers configuration.
// It is built with NewHeader; custom request headers are applied before the
// request is forwarded to next, and custom/CORS response headers are applied
// to the response produced by next.
type Header struct {
	// next is the wrapped handler; ServeHTTP does not forward when it is nil.
	next http.Handler

	// hasCustomHeaders caches cfg.HasCustomHeadersDefined() from construction
	// and gates the request-header rewriting in ServeHTTP.
	hasCustomHeaders bool

	// hasCorsHeaders caches cfg.HasCorsHeadersDefined() from construction;
	// when false, CORS preflight detection is skipped entirely.
	hasCorsHeaders bool

	// headers is the middleware configuration (a private copy owned by this
	// instance, as created in NewHeader).
	headers *dynamic.Headers

	// allowOriginRegexes holds the compiled AccessControlAllowOriginListRegex
	// patterns, in configuration order.
	allowOriginRegexes []*regexp.Regexp
}
25
26// NewHeader constructs a new header instance from supplied frontend header struct.
27func NewHeader(next http.Handler, cfg dynamic.Headers) (*Header, error) {
28	hasCustomHeaders := cfg.HasCustomHeadersDefined()
29	hasCorsHeaders := cfg.HasCorsHeadersDefined()
30
31	ctx := log.With(context.Background(), log.Str(log.MiddlewareType, typeName))
32	handleDeprecation(ctx, &cfg)
33
34	regexes := make([]*regexp.Regexp, len(cfg.AccessControlAllowOriginListRegex))
35	for i, str := range cfg.AccessControlAllowOriginListRegex {
36		reg, err := regexp.Compile(str)
37		if err != nil {
38			return nil, fmt.Errorf("error occurred during origin parsing: %w", err)
39		}
40		regexes[i] = reg
41	}
42
43	return &Header{
44		next:               next,
45		headers:            &cfg,
46		hasCustomHeaders:   hasCustomHeaders,
47		hasCorsHeaders:     hasCorsHeaders,
48		allowOriginRegexes: regexes,
49	}, nil
50}
51
52func (s *Header) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
53	// Handle Cors headers and preflight if configured.
54	if isPreflight := s.processCorsHeaders(rw, req); isPreflight {
55		return
56	}
57
58	if s.hasCustomHeaders {
59		s.modifyCustomRequestHeaders(req)
60	}
61
62	// If there is a next, call it.
63	if s.next != nil {
64		s.next.ServeHTTP(newResponseModifier(rw, req, s.PostRequestModifyResponseHeaders), req)
65	}
66}
67
68// modifyCustomRequestHeaders sets or deletes custom request headers.
69func (s *Header) modifyCustomRequestHeaders(req *http.Request) {
70	// Loop through Custom request headers
71	for header, value := range s.headers.CustomRequestHeaders {
72		switch {
73		case value == "":
74			req.Header.Del(header)
75
76		case strings.EqualFold(header, "Host"):
77			req.Host = value
78
79		default:
80			req.Header.Set(header, value)
81		}
82	}
83}
84
85// PostRequestModifyResponseHeaders set or delete response headers.
86// This method is called AFTER the response is generated from the backend
87// and can merge/override headers from the backend response.
88func (s *Header) PostRequestModifyResponseHeaders(res *http.Response) error {
89	// Loop through Custom response headers
90	for header, value := range s.headers.CustomResponseHeaders {
91		if value == "" {
92			res.Header.Del(header)
93		} else {
94			res.Header.Set(header, value)
95		}
96	}
97
98	if res != nil && res.Request != nil {
99		originHeader := res.Request.Header.Get("Origin")
100		allowed, match := s.isOriginAllowed(originHeader)
101
102		if allowed {
103			res.Header.Set("Access-Control-Allow-Origin", match)
104		}
105	}
106
107	if s.headers.AccessControlAllowCredentials {
108		res.Header.Set("Access-Control-Allow-Credentials", "true")
109	}
110
111	if len(s.headers.AccessControlExposeHeaders) > 0 {
112		exposeHeaders := strings.Join(s.headers.AccessControlExposeHeaders, ",")
113		res.Header.Set("Access-Control-Expose-Headers", exposeHeaders)
114	}
115
116	if !s.headers.AddVaryHeader {
117		return nil
118	}
119
120	varyHeader := res.Header.Get("Vary")
121	if varyHeader == "Origin" {
122		return nil
123	}
124
125	if varyHeader != "" {
126		varyHeader += ","
127	}
128	varyHeader += "Origin"
129
130	res.Header.Set("Vary", varyHeader)
131	return nil
132}
133
134// processCorsHeaders processes the incoming request,
135// and returns if it is a preflight request.
136// If not a preflight, it handles the preRequestModifyCorsResponseHeaders.
137func (s *Header) processCorsHeaders(rw http.ResponseWriter, req *http.Request) bool {
138	if !s.hasCorsHeaders {
139		return false
140	}
141
142	reqAcMethod := req.Header.Get("Access-Control-Request-Method")
143	originHeader := req.Header.Get("Origin")
144
145	if reqAcMethod != "" && originHeader != "" && req.Method == http.MethodOptions {
146		// If the request is an OPTIONS request with an Access-Control-Request-Method header,
147		// and Origin headers, then it is a CORS preflight request,
148		// and we need to build a custom response: https://www.w3.org/TR/cors/#preflight-request
149		if s.headers.AccessControlAllowCredentials {
150			rw.Header().Set("Access-Control-Allow-Credentials", "true")
151		}
152
153		allowHeaders := strings.Join(s.headers.AccessControlAllowHeaders, ",")
154		if allowHeaders != "" {
155			rw.Header().Set("Access-Control-Allow-Headers", allowHeaders)
156		}
157
158		allowMethods := strings.Join(s.headers.AccessControlAllowMethods, ",")
159		if allowMethods != "" {
160			rw.Header().Set("Access-Control-Allow-Methods", allowMethods)
161		}
162
163		allowed, match := s.isOriginAllowed(originHeader)
164		if allowed {
165			rw.Header().Set("Access-Control-Allow-Origin", match)
166		}
167
168		rw.Header().Set("Access-Control-Max-Age", strconv.Itoa(int(s.headers.AccessControlMaxAge)))
169		return true
170	}
171
172	return false
173}
174
175func (s *Header) isOriginAllowed(origin string) (bool, string) {
176	for _, item := range s.headers.AccessControlAllowOriginList {
177		if item == "*" || item == origin {
178			return true, item
179		}
180	}
181
182	for _, regex := range s.allowOriginRegexes {
183		if regex.MatchString(origin) {
184			return true, origin
185		}
186	}
187
188	return false, ""
189}
190