// Package headers implements the Headers middleware, based on https://github.com/unrolled/secure.
2package headers
3
4import (
5	"context"
6	"errors"
7	"net/http"
8
9	"github.com/opentracing/opentracing-go/ext"
10	"github.com/traefik/traefik/v2/pkg/config/dynamic"
11	"github.com/traefik/traefik/v2/pkg/log"
12	"github.com/traefik/traefik/v2/pkg/middlewares"
13	"github.com/traefik/traefik/v2/pkg/middlewares/connectionheader"
14	"github.com/traefik/traefik/v2/pkg/tracing"
15)
16
// typeName is the middleware type label passed to middlewares.GetLoggerCtx
// when New builds the logging context.
const typeName = "Headers"
20
21func handleDeprecation(ctx context.Context, cfg *dynamic.Headers) {
22	if cfg.SSLRedirect {
23		log.FromContext(ctx).Warn("SSLRedirect is deprecated, please use entrypoint redirection instead.")
24	}
25	if cfg.SSLTemporaryRedirect {
26		log.FromContext(ctx).Warn("SSLTemporaryRedirect is deprecated, please use entrypoint redirection instead.")
27	}
28	if cfg.SSLHost != "" {
29		log.FromContext(ctx).Warn("SSLHost is deprecated, please use RedirectRegex middleware instead.")
30	}
31	if cfg.SSLForceHost {
32		log.FromContext(ctx).Warn("SSLForceHost is deprecated, please use RedirectScheme middleware instead.")
33	}
34	if cfg.FeaturePolicy != "" {
35		log.FromContext(ctx).Warn("FeaturePolicy is deprecated, please use PermissionsPolicy header instead.")
36	}
37}
38
// headers is the http.Handler returned by New. It holds the assembled
// handler chain together with the middleware's configured name.
type headers struct {
	// name is the middleware name; GetTracingInformation reports it.
	name string
	// handler is the outermost handler of the chain New built (custom/CORS
	// headers and/or secure headers), which ends at the next handler.
	handler http.Handler
}
43
44// New creates a Headers middleware.
45func New(ctx context.Context, next http.Handler, cfg dynamic.Headers, name string) (http.Handler, error) {
46	// HeaderMiddleware -> SecureMiddleWare -> next
47	mCtx := middlewares.GetLoggerCtx(ctx, name, typeName)
48	logger := log.FromContext(mCtx)
49	logger.Debug("Creating middleware")
50
51	handleDeprecation(mCtx, &cfg)
52
53	hasSecureHeaders := cfg.HasSecureHeadersDefined()
54	hasCustomHeaders := cfg.HasCustomHeadersDefined()
55	hasCorsHeaders := cfg.HasCorsHeadersDefined()
56
57	if !hasSecureHeaders && !hasCustomHeaders && !hasCorsHeaders {
58		return nil, errors.New("headers configuration not valid")
59	}
60
61	var handler http.Handler
62	nextHandler := next
63
64	if hasSecureHeaders {
65		logger.Debugf("Setting up secureHeaders from %v", cfg)
66		handler = newSecure(next, cfg, name)
67		nextHandler = handler
68	}
69
70	if hasCustomHeaders || hasCorsHeaders {
71		logger.Debugf("Setting up customHeaders/Cors from %v", cfg)
72		h, err := NewHeader(nextHandler, cfg)
73		if err != nil {
74			return nil, err
75		}
76
77		handler = connectionheader.Remover(h)
78	}
79
80	return &headers{
81		handler: handler,
82		name:    name,
83	}, nil
84}
85
86func (h *headers) GetTracingInformation() (string, ext.SpanKindEnum) {
87	return h.name, tracing.SpanKindNoneEnum
88}
89
90func (h *headers) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
91	h.handler.ServeHTTP(rw, req)
92}
93