1package listenerutil
2
3import (
4	"errors"
5	"fmt"
6	"net/textproto"
7	"strings"
8	"time"
9
10	"github.com/hashicorp/go-multierror"
11	"github.com/hashicorp/go-secure-stdlib/parseutil"
12	"github.com/hashicorp/go-secure-stdlib/strutil"
13	"github.com/hashicorp/go-secure-stdlib/tlsutil"
14	"github.com/hashicorp/go-sockaddr"
15	"github.com/hashicorp/hcl"
16	"github.com/hashicorp/hcl/hcl/ast"
17)
18
// ListenerTelemetry holds the settings decoded from the "telemetry" block of
// a listener stanza.
type ListenerTelemetry struct {
	// UnauthenticatedMetricsAccess is not decoded directly (hcl:"-");
	// ParseListeners fills it in by parsing UnauthenticatedMetricsAccessRaw
	// as a boolean.
	UnauthenticatedMetricsAccess    bool        `hcl:"-"`
	// UnauthenticatedMetricsAccessRaw receives the undecoded
	// "unauthenticated_metrics_access" value. ParseListeners resets it to nil
	// once the value has been parsed successfully.
	UnauthenticatedMetricsAccessRaw interface{} `hcl:"unauthenticated_metrics_access"`
}
23
// ListenerConfig is the listener configuration for the server.
//
// Many settings come as a pair: a typed field tagged `hcl:"-"` and a sibling
// with the "Raw" suffix that carries the HCL key. HCL decoding fills the Raw
// field; ParseListeners then converts it into the typed field. For the
// interface{}-typed Raw fields, ParseListeners resets the Raw field to nil
// after a successful conversion.
type ListenerConfig struct {
	// RawConfig is the same listener stanza decoded into a generic map by
	// ParseListeners.
	RawConfig map[string]interface{}

	// Type is lower-cased by ParseListeners, which accepts only "tcp" and
	// "unix". When it is empty after decoding, ParseListeners takes it from
	// the stanza's single key.
	Type       string
	// Purpose is parsed from PurposeRaw as a comma-separated string slice and
	// each entry is lower-cased.
	Purpose    []string    `hcl:"-"`
	PurposeRaw interface{} `hcl:"purpose"`

	// Request parameters. ParseListeners rejects negative values for
	// MaxRequestSize and MaxRequestDuration.
	Address                 string        `hcl:"address"`
	ClusterAddress          string        `hcl:"cluster_address"`
	MaxRequestSize          int64         `hcl:"-"`
	MaxRequestSizeRaw       interface{}   `hcl:"max_request_size"`
	MaxRequestDuration      time.Duration `hcl:"-"`
	MaxRequestDurationRaw   interface{}   `hcl:"max_request_duration"`
	RequireRequestHeader    bool          `hcl:"-"`
	RequireRequestHeaderRaw interface{}   `hcl:"require_request_header"`

	// TLS parameters. TLSCipherSuites is produced from TLSCipherSuitesRaw via
	// tlsutil.ParseCiphers; unlike the interface{}-typed Raw fields, the raw
	// cipher string is left in place after parsing.
	TLSDisable                       bool        `hcl:"-"`
	TLSDisableRaw                    interface{} `hcl:"tls_disable"`
	TLSCertFile                      string      `hcl:"tls_cert_file"`
	TLSKeyFile                       string      `hcl:"tls_key_file"`
	TLSMinVersion                    string      `hcl:"tls_min_version"`
	TLSMaxVersion                    string      `hcl:"tls_max_version"`
	TLSCipherSuites                  []uint16    `hcl:"-"`
	TLSCipherSuitesRaw               string      `hcl:"tls_cipher_suites"`
	TLSPreferServerCipherSuites      bool        `hcl:"-"`
	TLSPreferServerCipherSuitesRaw   interface{} `hcl:"tls_prefer_server_cipher_suites"`
	TLSRequireAndVerifyClientCert    bool        `hcl:"-"`
	TLSRequireAndVerifyClientCertRaw interface{} `hcl:"tls_require_and_verify_client_cert"`
	TLSClientCAFile                  string      `hcl:"tls_client_ca_file"`
	TLSDisableClientCerts            bool        `hcl:"-"`
	TLSDisableClientCertsRaw         interface{} `hcl:"tls_disable_client_certs"`

	// HTTP timeouts, each parsed from its Raw sibling with
	// parseutil.ParseDurationSecond.
	HTTPReadTimeout          time.Duration `hcl:"-"`
	HTTPReadTimeoutRaw       interface{}   `hcl:"http_read_timeout"`
	HTTPReadHeaderTimeout    time.Duration `hcl:"-"`
	HTTPReadHeaderTimeoutRaw interface{}   `hcl:"http_read_header_timeout"`
	HTTPWriteTimeout         time.Duration `hcl:"-"`
	HTTPWriteTimeoutRaw      interface{}   `hcl:"http_write_timeout"`
	HTTPIdleTimeout          time.Duration `hcl:"-"`
	HTTPIdleTimeoutRaw       interface{}   `hcl:"http_idle_timeout"`

	// PROXY protocol settings. ParseListeners treats the behaviors
	// "allow_authorized" and "deny_authorized" as requiring a non-empty
	// ProxyProtocolAuthorizedAddrs.
	ProxyProtocolBehavior           string                        `hcl:"proxy_protocol_behavior"`
	ProxyProtocolAuthorizedAddrs    []*sockaddr.SockAddrMarshaler `hcl:"-"`
	ProxyProtocolAuthorizedAddrsRaw interface{}                   `hcl:"proxy_protocol_authorized_addrs"`

	// X-Forwarded-For settings. ParseListeners rejects a negative
	// XForwardedForHopSkips.
	XForwardedForAuthorizedAddrs        []*sockaddr.SockAddrMarshaler `hcl:"-"`
	XForwardedForAuthorizedAddrsRaw     interface{}                   `hcl:"x_forwarded_for_authorized_addrs"`
	XForwardedForHopSkips               int64                         `hcl:"-"`
	XForwardedForHopSkipsRaw            interface{}                   `hcl:"x_forwarded_for_hop_skips"`
	XForwardedForRejectNotPresent       bool                          `hcl:"-"`
	XForwardedForRejectNotPresentRaw    interface{}                   `hcl:"x_forwarded_for_reject_not_present"`
	XForwardedForRejectNotAuthorized    bool                          `hcl:"-"`
	XForwardedForRejectNotAuthorizedRaw interface{}                   `hcl:"x_forwarded_for_reject_not_authorized"`

	// Socket ownership/mode settings, decoded as plain strings; ParseListeners
	// does not interpret them.
	SocketMode  string `hcl:"socket_mode"`
	SocketUser  string `hcl:"socket_user"`
	SocketGroup string `hcl:"socket_group"`

	Telemetry ListenerTelemetry `hcl:"telemetry"`

	// RandomPort is used only for some testing purposes
	RandomPort bool `hcl:"-"`

	// CORS settings. CorsEnabled and CorsDisableDefaultAllowedOriginValues are
	// pointers and stay nil when the corresponding key is absent from the
	// stanza. ParseListeners rejects a CorsAllowedOrigins list that contains
	// "*" together with other entries, and fills CorsAllowedHeaders with the
	// canonical MIME header form of each CorsAllowedHeadersRaw entry.
	CorsEnabledRaw                           interface{} `hcl:"cors_enabled"`
	CorsEnabled                              *bool       `hcl:"-"`
	CorsDisableDefaultAllowedOriginValuesRaw interface{} `hcl:"cors_disable_default_allowed_origin_values"`
	CorsDisableDefaultAllowedOriginValues    *bool       `hcl:"-"`
	CorsAllowedOrigins                       []string    `hcl:"cors_allowed_origins"`
	CorsAllowedHeaders                       []string    `hcl:"-"`
	CorsAllowedHeadersRaw                    []string    `hcl:"cors_allowed_headers"`
}
96
97func (l *ListenerConfig) GoString() string {
98	return fmt.Sprintf("*%#v", *l)
99}
100
101func ParseListeners(list *ast.ObjectList) ([]*ListenerConfig, error) {
102	var err error
103	result := make([]*ListenerConfig, 0, len(list.Items))
104	for i, item := range list.Items {
105		var l ListenerConfig
106		if err := hcl.DecodeObject(&l, item.Val); err != nil {
107			return nil, multierror.Prefix(err, fmt.Sprintf("listeners.%d:", i))
108		}
109
110		// Hacky way, for now, to get the values we want for sanitizing
111		var m map[string]interface{}
112		if err := hcl.DecodeObject(&m, item.Val); err != nil {
113			return nil, multierror.Prefix(err, fmt.Sprintf("listeners.%d:", i))
114		}
115		l.RawConfig = m
116
117		// Base values
118		{
119			switch {
120			case l.Type != "":
121			case len(item.Keys) == 1:
122				l.Type = strings.ToLower(item.Keys[0].Token.Value().(string))
123			default:
124				return nil, multierror.Prefix(errors.New("listener type must be specified"), fmt.Sprintf("listeners.%d:", i))
125			}
126
127			l.Type = strings.ToLower(l.Type)
128			switch l.Type {
129			case "tcp", "unix":
130			default:
131				return nil, multierror.Prefix(fmt.Errorf("unsupported listener type %q", l.Type), fmt.Sprintf("listeners.%d:", i))
132			}
133
134			if l.PurposeRaw != nil {
135				if l.Purpose, err = parseutil.ParseCommaStringSlice(l.PurposeRaw); err != nil {
136					return nil, multierror.Prefix(fmt.Errorf("unable to parse 'purpose' in listener type %q: %w", l.Type, err), fmt.Sprintf("listeners.%d:", i))
137				}
138				for i, v := range l.Purpose {
139					l.Purpose[i] = strings.ToLower(v)
140				}
141
142				l.PurposeRaw = nil
143			}
144		}
145
146		// Request Parameters
147		{
148			if l.MaxRequestSizeRaw != nil {
149				if l.MaxRequestSize, err = parseutil.ParseInt(l.MaxRequestSizeRaw); err != nil {
150					return nil, multierror.Prefix(fmt.Errorf("error parsing max_request_size: %w", err), fmt.Sprintf("listeners.%d", i))
151				}
152
153				if l.MaxRequestSize < 0 {
154					return nil, multierror.Prefix(errors.New("max_request_size cannot be negative"), fmt.Sprintf("listeners.%d", i))
155				}
156
157				l.MaxRequestSizeRaw = nil
158			}
159
160			if l.MaxRequestDurationRaw != nil {
161				if l.MaxRequestDuration, err = parseutil.ParseDurationSecond(l.MaxRequestDurationRaw); err != nil {
162					return nil, multierror.Prefix(fmt.Errorf("error parsing max_request_duration: %w", err), fmt.Sprintf("listeners.%d", i))
163				}
164				if l.MaxRequestDuration < 0 {
165					return nil, multierror.Prefix(errors.New("max_request_duration cannot be negative"), fmt.Sprintf("listeners.%d", i))
166				}
167
168				l.MaxRequestDurationRaw = nil
169			}
170
171			if l.RequireRequestHeaderRaw != nil {
172				if l.RequireRequestHeader, err = parseutil.ParseBool(l.RequireRequestHeaderRaw); err != nil {
173					return nil, multierror.Prefix(fmt.Errorf("invalid value for require_request_header: %w", err), fmt.Sprintf("listeners.%d", i))
174				}
175
176				l.RequireRequestHeaderRaw = nil
177			}
178		}
179
180		// TLS Parameters
181		{
182			if l.TLSDisableRaw != nil {
183				if l.TLSDisable, err = parseutil.ParseBool(l.TLSDisableRaw); err != nil {
184					return nil, multierror.Prefix(fmt.Errorf("invalid value for tls_disable: %w", err), fmt.Sprintf("listeners.%d", i))
185				}
186
187				l.TLSDisableRaw = nil
188			}
189
190			if l.TLSCipherSuitesRaw != "" {
191				if l.TLSCipherSuites, err = tlsutil.ParseCiphers(l.TLSCipherSuitesRaw); err != nil {
192					return nil, multierror.Prefix(fmt.Errorf("invalid value for tls_cipher_suites: %w", err), fmt.Sprintf("listeners.%d", i))
193				}
194			}
195
196			if l.TLSPreferServerCipherSuitesRaw != nil {
197				if l.TLSPreferServerCipherSuites, err = parseutil.ParseBool(l.TLSPreferServerCipherSuitesRaw); err != nil {
198					return nil, multierror.Prefix(fmt.Errorf("invalid value for tls_prefer_server_cipher_suites: %w", err), fmt.Sprintf("listeners.%d", i))
199				}
200
201				l.TLSPreferServerCipherSuitesRaw = nil
202			}
203
204			if l.TLSRequireAndVerifyClientCertRaw != nil {
205				if l.TLSRequireAndVerifyClientCert, err = parseutil.ParseBool(l.TLSRequireAndVerifyClientCertRaw); err != nil {
206					return nil, multierror.Prefix(fmt.Errorf("invalid value for tls_require_and_verify_client_cert: %w", err), fmt.Sprintf("listeners.%d", i))
207				}
208
209				l.TLSRequireAndVerifyClientCertRaw = nil
210			}
211
212			if l.TLSDisableClientCertsRaw != nil {
213				if l.TLSDisableClientCerts, err = parseutil.ParseBool(l.TLSDisableClientCertsRaw); err != nil {
214					return nil, multierror.Prefix(fmt.Errorf("invalid value for tls_disable_client_certs: %w", err), fmt.Sprintf("listeners.%d", i))
215				}
216
217				l.TLSDisableClientCertsRaw = nil
218			}
219		}
220
221		// HTTP timeouts
222		{
223			if l.HTTPReadTimeoutRaw != nil {
224				if l.HTTPReadTimeout, err = parseutil.ParseDurationSecond(l.HTTPReadTimeoutRaw); err != nil {
225					return nil, multierror.Prefix(fmt.Errorf("error parsing http_read_timeout: %w", err), fmt.Sprintf("listeners.%d", i))
226				}
227
228				l.HTTPReadTimeoutRaw = nil
229			}
230
231			if l.HTTPReadHeaderTimeoutRaw != nil {
232				if l.HTTPReadHeaderTimeout, err = parseutil.ParseDurationSecond(l.HTTPReadHeaderTimeoutRaw); err != nil {
233					return nil, multierror.Prefix(fmt.Errorf("error parsing http_read_header_timeout: %w", err), fmt.Sprintf("listeners.%d", i))
234				}
235
236				l.HTTPReadHeaderTimeoutRaw = nil
237			}
238
239			if l.HTTPWriteTimeoutRaw != nil {
240				if l.HTTPWriteTimeout, err = parseutil.ParseDurationSecond(l.HTTPWriteTimeoutRaw); err != nil {
241					return nil, multierror.Prefix(fmt.Errorf("error parsing http_write_timeout: %w", err), fmt.Sprintf("listeners.%d", i))
242				}
243
244				l.HTTPWriteTimeoutRaw = nil
245			}
246
247			if l.HTTPIdleTimeoutRaw != nil {
248				if l.HTTPIdleTimeout, err = parseutil.ParseDurationSecond(l.HTTPIdleTimeoutRaw); err != nil {
249					return nil, multierror.Prefix(fmt.Errorf("error parsing http_idle_timeout: %w", err), fmt.Sprintf("listeners.%d", i))
250				}
251
252				l.HTTPIdleTimeoutRaw = nil
253			}
254		}
255
256		// Proxy Protocol config
257		{
258			if l.ProxyProtocolAuthorizedAddrsRaw != nil {
259				if l.ProxyProtocolAuthorizedAddrs, err = parseutil.ParseAddrs(l.ProxyProtocolAuthorizedAddrsRaw); err != nil {
260					return nil, multierror.Prefix(fmt.Errorf("error parsing proxy_protocol_authorized_addrs: %w", err), fmt.Sprintf("listeners.%d", i))
261				}
262
263				switch l.ProxyProtocolBehavior {
264				case "allow_authorized", "deny_authorized":
265					if len(l.ProxyProtocolAuthorizedAddrs) == 0 {
266						return nil, multierror.Prefix(errors.New("proxy_protocol_behavior set to allow or deny only authorized addresses but no proxy_protocol_authorized_addrs value"), fmt.Sprintf("listeners.%d", i))
267					}
268				}
269
270				l.ProxyProtocolAuthorizedAddrsRaw = nil
271			}
272		}
273
274		// X-Forwarded-For config
275		{
276			if l.XForwardedForAuthorizedAddrsRaw != nil {
277				if l.XForwardedForAuthorizedAddrs, err = parseutil.ParseAddrs(l.XForwardedForAuthorizedAddrsRaw); err != nil {
278					return nil, multierror.Prefix(fmt.Errorf("error parsing x_forwarded_for_authorized_addrs: %w", err), fmt.Sprintf("listeners.%d", i))
279				}
280
281				l.XForwardedForAuthorizedAddrsRaw = nil
282			}
283
284			if l.XForwardedForHopSkipsRaw != nil {
285				if l.XForwardedForHopSkips, err = parseutil.ParseInt(l.XForwardedForHopSkipsRaw); err != nil {
286					return nil, multierror.Prefix(fmt.Errorf("error parsing x_forwarded_for_hop_skips: %w", err), fmt.Sprintf("listeners.%d", i))
287				}
288
289				if l.XForwardedForHopSkips < 0 {
290					return nil, multierror.Prefix(fmt.Errorf("x_forwarded_for_hop_skips cannot be negative but set to %d", l.XForwardedForHopSkips), fmt.Sprintf("listeners.%d", i))
291				}
292
293				l.XForwardedForHopSkipsRaw = nil
294			}
295
296			if l.XForwardedForRejectNotAuthorizedRaw != nil {
297				if l.XForwardedForRejectNotAuthorized, err = parseutil.ParseBool(l.XForwardedForRejectNotAuthorizedRaw); err != nil {
298					return nil, multierror.Prefix(fmt.Errorf("invalid value for x_forwarded_for_reject_not_authorized: %w", err), fmt.Sprintf("listeners.%d", i))
299				}
300
301				l.XForwardedForRejectNotAuthorizedRaw = nil
302			}
303
304			if l.XForwardedForRejectNotPresentRaw != nil {
305				if l.XForwardedForRejectNotPresent, err = parseutil.ParseBool(l.XForwardedForRejectNotPresentRaw); err != nil {
306					return nil, multierror.Prefix(fmt.Errorf("invalid value for x_forwarded_for_reject_not_present: %w", err), fmt.Sprintf("listeners.%d", i))
307				}
308
309				l.XForwardedForRejectNotPresentRaw = nil
310			}
311		}
312
313		// Telemetry
314		{
315			if l.Telemetry.UnauthenticatedMetricsAccessRaw != nil {
316				if l.Telemetry.UnauthenticatedMetricsAccess, err = parseutil.ParseBool(l.Telemetry.UnauthenticatedMetricsAccessRaw); err != nil {
317					return nil, multierror.Prefix(fmt.Errorf("invalid value for telemetry.unauthenticated_metrics_access: %w", err), fmt.Sprintf("listeners.%d", i))
318				}
319
320				l.Telemetry.UnauthenticatedMetricsAccessRaw = nil
321			}
322		}
323
324		// CORS
325		{
326			if l.CorsEnabledRaw != nil {
327				corsEnabled, err := parseutil.ParseBool(l.CorsEnabledRaw)
328				if err != nil {
329					return nil, multierror.Prefix(fmt.Errorf("invalid value for cors_enabled: %w", err), fmt.Sprintf("listeners.%d", i))
330				}
331				l.CorsEnabled = &corsEnabled
332				l.CorsEnabledRaw = nil
333			}
334
335			if l.CorsDisableDefaultAllowedOriginValuesRaw != nil {
336				disabled, err := parseutil.ParseBool(l.CorsDisableDefaultAllowedOriginValuesRaw)
337				if err != nil {
338					return nil, multierror.Prefix(fmt.Errorf("invalid value for cors_disable_default_allowed_origin_values: %w", err), fmt.Sprintf("listeners.%d", i))
339				}
340				l.CorsDisableDefaultAllowedOriginValues = &disabled
341				l.CorsDisableDefaultAllowedOriginValuesRaw = nil
342			}
343
344			if strutil.StrListContains(l.CorsAllowedOrigins, "*") && len(l.CorsAllowedOrigins) > 1 {
345				return nil, multierror.Prefix(errors.New("cors_allowed_origins must only contain a wildcard or only non-wildcard values"), fmt.Sprintf("listeners.%d", i))
346			}
347
348			if len(l.CorsAllowedHeadersRaw) > 0 {
349				for _, header := range l.CorsAllowedHeadersRaw {
350					l.CorsAllowedHeaders = append(l.CorsAllowedHeaders, textproto.CanonicalMIMEHeaderKey(header))
351				}
352			}
353		}
354
355		result = append(result, &l)
356	}
357
358	return result, nil
359}
360