1package tcpmiddleware
2
3import (
4	"context"
5	"fmt"
6	"strings"
7
8	"github.com/traefik/traefik/v2/pkg/config/runtime"
9	inflightconn "github.com/traefik/traefik/v2/pkg/middlewares/tcp/inflightconn"
10	ipwhitelist "github.com/traefik/traefik/v2/pkg/middlewares/tcp/ipwhitelist"
11	"github.com/traefik/traefik/v2/pkg/server/provider"
12	"github.com/traefik/traefik/v2/pkg/tcp"
13)
14
// middlewareStackType is the unexported type of the context key under which
// checkRecursion stores the names of the middlewares currently being built.
// Because the type is package-private, no other package can construct an
// equal key.
type middlewareStackType int

// middlewareStackKey is the context key whose value is the []string stack of
// middleware names that checkRecursion reads and extends.
const middlewareStackKey middlewareStackType = 0
20
21// Builder the middleware builder.
22type Builder struct {
23	configs map[string]*runtime.TCPMiddlewareInfo
24}
25
26// NewBuilder creates a new Builder.
27func NewBuilder(configs map[string]*runtime.TCPMiddlewareInfo) *Builder {
28	return &Builder{configs: configs}
29}
30
// BuildChain creates a middleware chain from the given middleware names.
//
// Each name is qualified against ctx with provider.GetQualifiedName and added
// to the chain as a deferred constructor. The configuration lookup, the
// recursion check and the instantiation run only when the chain is applied to
// a handler. An unknown middleware yields an error. Recursion, build and
// instantiation failures are additionally recorded on the middleware's runtime
// info through AddError before being returned.
func (b *Builder) BuildChain(ctx context.Context, middlewares []string) *tcp.Chain {
	chain := tcp.NewChain()

	for _, name := range middlewares {
		// Declared inside the loop body so each closure captures its own name.
		middlewareName := provider.GetQualifiedName(ctx, name)

		chain = chain.Append(func(next tcp.Handler) (tcp.Handler, error) {
			constructorContext := provider.AddInContext(ctx, middlewareName)

			// Look the config up once and reuse the pointer below. A nil entry
			// is treated as missing, matching buildConstructor, instead of
			// being dereferenced (which would panic).
			midInf := b.configs[middlewareName]
			if midInf == nil || midInf.TCPMiddleware == nil {
				return nil, fmt.Errorf("middleware %q does not exist", middlewareName)
			}

			var err error
			if constructorContext, err = checkRecursion(constructorContext, middlewareName); err != nil {
				midInf.AddError(err, true)
				return nil, err
			}

			constructor, err := b.buildConstructor(constructorContext, middlewareName)
			if err != nil {
				midInf.AddError(err, true)
				return nil, err
			}

			handler, err := constructor(next)
			if err != nil {
				midInf.AddError(err, true)
				return nil, err
			}

			return handler, nil
		})
	}

	return &chain
}
68
// checkRecursion pushes middlewareName onto the middleware stack carried by ctx
// (under middlewareStackKey). It fails if the name is already on that stack,
// which means building it would recurse into itself.
//
// On success it returns a derived context holding the extended stack. On error
// it returns the unchanged ctx together with an error describing the cycle.
//
// The stack stored in ctx is treated as immutable. Every extension is built in
// a freshly allocated slice, because appending to the stored slice could reuse
// its spare capacity. Sibling contexts derived from the same parent would then
// share, and overwrite, one backing array.
func checkRecursion(ctx context.Context, middlewareName string) (context.Context, error) {
	// A missing or non-[]string value leaves currentStack nil, i.e. empty.
	currentStack, _ := ctx.Value(middlewareStackKey).([]string)

	nextStack := make([]string, len(currentStack)+1)
	copy(nextStack, currentStack)
	nextStack[len(currentStack)] = middlewareName

	if inSlice(middlewareName, currentStack) {
		return ctx, fmt.Errorf("could not instantiate middleware %s: recursion detected in %s", middlewareName, strings.Join(nextStack, "->"))
	}

	return context.WithValue(ctx, middlewareStackKey, nextStack), nil
}
81
// buildConstructor returns the tcp.Constructor for the middleware registered
// under middlewareName in b.configs. The returned constructor captures ctx and
// passes it to the middleware's New function.
//
// It returns an error in three cases:
//   - the name has no TCP middleware configuration;
//   - no supported middleware type is set;
//   - more than one middleware type is set on the same configuration.
//
// The last case is an error rather than letting one type silently override
// the others.
func (b *Builder) buildConstructor(ctx context.Context, middlewareName string) (tcp.Constructor, error) {
	config := b.configs[middlewareName]
	if config == nil || config.TCPMiddleware == nil {
		return nil, fmt.Errorf("invalid middleware %q configuration", middlewareName)
	}

	var middleware tcp.Constructor

	// InFlightConn
	if config.InFlightConn != nil {
		middleware = func(next tcp.Handler) (tcp.Handler, error) {
			return inflightconn.New(ctx, next, *config.InFlightConn, middlewareName)
		}
	}

	// IPWhiteList
	if config.IPWhiteList != nil {
		// Another type was already selected: building only one would silently
		// drop the rest of the user's configuration.
		if middleware != nil {
			return nil, fmt.Errorf("invalid middleware %q configuration: multi-types middleware not supported, consider declaring two different pieces of middleware instead", middlewareName)
		}

		middleware = func(next tcp.Handler) (tcp.Handler, error) {
			return ipwhitelist.New(ctx, next, *config.IPWhiteList, middlewareName)
		}
	}

	if middleware == nil {
		return nil, fmt.Errorf("invalid middleware %q configuration: invalid middleware type or middleware does not exist", middlewareName)
	}

	return middleware, nil
}
110
// inSlice reports whether element occurs anywhere in stack. A nil or empty
// stack contains nothing.
func inSlice(element string, stack []string) bool {
	found := false
	for i := 0; i < len(stack) && !found; i++ {
		found = stack[i] == element
	}
	return found
}
119