1package tcpmiddleware 2 3import ( 4 "context" 5 "fmt" 6 "strings" 7 8 "github.com/traefik/traefik/v2/pkg/config/runtime" 9 inflightconn "github.com/traefik/traefik/v2/pkg/middlewares/tcp/inflightconn" 10 ipwhitelist "github.com/traefik/traefik/v2/pkg/middlewares/tcp/ipwhitelist" 11 "github.com/traefik/traefik/v2/pkg/server/provider" 12 "github.com/traefik/traefik/v2/pkg/tcp" 13) 14 15type middlewareStackType int 16 17const ( 18 middlewareStackKey middlewareStackType = iota 19) 20 21// Builder the middleware builder. 22type Builder struct { 23 configs map[string]*runtime.TCPMiddlewareInfo 24} 25 26// NewBuilder creates a new Builder. 27func NewBuilder(configs map[string]*runtime.TCPMiddlewareInfo) *Builder { 28 return &Builder{configs: configs} 29} 30 31// BuildChain creates a middleware chain. 32func (b *Builder) BuildChain(ctx context.Context, middlewares []string) *tcp.Chain { 33 chain := tcp.NewChain() 34 35 for _, name := range middlewares { 36 middlewareName := provider.GetQualifiedName(ctx, name) 37 38 chain = chain.Append(func(next tcp.Handler) (tcp.Handler, error) { 39 constructorContext := provider.AddInContext(ctx, middlewareName) 40 if midInf, ok := b.configs[middlewareName]; !ok || midInf.TCPMiddleware == nil { 41 return nil, fmt.Errorf("middleware %q does not exist", middlewareName) 42 } 43 44 var err error 45 if constructorContext, err = checkRecursion(constructorContext, middlewareName); err != nil { 46 b.configs[middlewareName].AddError(err, true) 47 return nil, err 48 } 49 50 constructor, err := b.buildConstructor(constructorContext, middlewareName) 51 if err != nil { 52 b.configs[middlewareName].AddError(err, true) 53 return nil, err 54 } 55 56 handler, err := constructor(next) 57 if err != nil { 58 b.configs[middlewareName].AddError(err, true) 59 return nil, err 60 } 61 62 return handler, nil 63 }) 64 } 65 66 return &chain 67} 68 69func checkRecursion(ctx context.Context, middlewareName string) (context.Context, error) { 70 currentStack, ok := ctx.Value(middlewareStackKey).([]string) 71 if !ok { 72 currentStack = []string{} 73 } 74 75 if inSlice(middlewareName, currentStack) { 76 return ctx, fmt.Errorf("could not instantiate middleware %s: recursion detected in %s", middlewareName, strings.Join(append(currentStack, middlewareName), "->")) 77 } 78 79 return context.WithValue(ctx, middlewareStackKey, append(currentStack, middlewareName)), nil 80} 81 82func (b *Builder) buildConstructor(ctx context.Context, middlewareName string) (tcp.Constructor, error) { 83 config := b.configs[middlewareName] 84 if config == nil || config.TCPMiddleware == nil { 85 return nil, fmt.Errorf("invalid middleware %q configuration", middlewareName) 86 } 87 88 var middleware tcp.Constructor 89 90 // InFlightConn 91 if config.InFlightConn != nil { 92 middleware = func(next tcp.Handler) (tcp.Handler, error) { 93 return inflightconn.New(ctx, next, *config.InFlightConn, middlewareName) 94 } 95 } 96 97 // IPWhiteList 98 if config.IPWhiteList != nil { 99 middleware = func(next tcp.Handler) (tcp.Handler, error) { 100 return ipwhitelist.New(ctx, next, *config.IPWhiteList, middlewareName) 101 } 102 } 103 104 if middleware == nil { 105 return nil, fmt.Errorf("invalid middleware %q configuration: invalid middleware type or middleware does not exist", middlewareName) 106 } 107 108 return middleware, nil 109} 110 111func inSlice(element string, stack []string) bool { 112 for _, value := range stack { 113 if value == element { 114 return true 115 } 116 } 117 return false 118} 119