1/*
2Copyright 2016 The Kubernetes Authors.
3
4Licensed under the Apache License, Version 2.0 (the "License");
5you may not use this file except in compliance with the License.
6You may obtain a copy of the License at
7
8    http://www.apache.org/licenses/LICENSE-2.0
9
10Unless required by applicable law or agreed to in writing, software
11distributed under the License is distributed on an "AS IS" BASIS,
12WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13See the License for the specific language governing permissions and
14limitations under the License.
15*/
16
17package net
18
19import (
20	"bufio"
21	"bytes"
22	"context"
23	"crypto/tls"
24	"errors"
25	"fmt"
26	"io"
27	"mime"
28	"net"
29	"net/http"
30	"net/url"
31	"os"
32	"path"
33	"regexp"
34	"strconv"
35	"strings"
36	"time"
37	"unicode"
38	"unicode/utf8"
39
40	"golang.org/x/net/http2"
41	"k8s.io/klog/v2"
42)
43
// JoinPreservingTrailingSlash joins elem with path.Join and, when the last
// non-empty element ends in "/", makes sure the joined result ends in "/" as
// well (path.Join on its own would strip that trailing slash).
func JoinPreservingTrailingSlash(elem ...string) string {
	joined := path.Join(elem...)

	// Walk backwards to the last element that is not the empty string.
	last := ""
	for i := len(elem) - 1; i >= 0 && last == ""; i-- {
		last = elem[i]
	}

	if strings.HasSuffix(last, "/") && !strings.HasSuffix(joined, "/") {
		joined += "/"
	}
	return joined
}
63
// IsTimeout reports whether err, or any error in its wrap chain, is a
// net.Error whose Timeout method returns true. A nil err yields false.
func IsTimeout(err error) bool {
	var netErr net.Error
	if !errors.As(err, &netErr) {
		return false
	}
	return netErr != nil && netErr.Timeout()
}
72
// IsProbableEOF returns true if the given error resembles a connection termination
// scenario that would justify assuming that the watch is empty.
// These errors are what the Go http stack returns back to us which are general
// connection closure errors (strongly correlated) and callers that need to
// differentiate probable errors in connection behavior between normal "this is
// disconnected" should use the method.
//
// io.EOF and io.ErrUnexpectedEOF are recognised anywhere in err's wrap chain
// (including behind a *url.Error, which implements Unwrap). The remaining checks
// match on the error message; when err contains a *url.Error with a non-nil Err,
// the message of that wrapped error is used instead of the URL-decorated one.
// A nil err yields false.
func IsProbableEOF(err error) bool {
	if err == nil {
		return false
	}
	// errors.Is also catches EOFs wrapped with fmt.Errorf("...: %w", io.EOF),
	// which a plain == comparison would miss.
	if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) {
		return true
	}
	var uerr *url.Error
	// Guard against a *url.Error with a nil Err: assigning it would make err
	// nil and the err.Error() call below would panic.
	if errors.As(err, &uerr) && uerr.Err != nil {
		err = uerr.Err
	}
	msg := err.Error()
	switch {
	case msg == "http: can't write HTTP request on broken connection":
		return true
	case strings.Contains(msg, "http2: server sent GOAWAY and closed the connection"):
		return true
	case strings.Contains(msg, "connection reset by peer"):
		return true
	case strings.Contains(strings.ToLower(msg), "use of closed network connection"):
		return true
	}
	return false
}
104
105var defaultTransport = http.DefaultTransport.(*http.Transport)
106
107// SetOldTransportDefaults applies the defaults from http.DefaultTransport
108// for the Proxy, Dial, and TLSHandshakeTimeout fields if unset
109func SetOldTransportDefaults(t *http.Transport) *http.Transport {
110	if t.Proxy == nil || isDefault(t.Proxy) {
111		// http.ProxyFromEnvironment doesn't respect CIDRs and that makes it impossible to exclude things like pod and service IPs from proxy settings
112		// ProxierWithNoProxyCIDR allows CIDR rules in NO_PROXY
113		t.Proxy = NewProxierWithNoProxyCIDR(http.ProxyFromEnvironment)
114	}
115	// If no custom dialer is set, use the default context dialer
116	if t.DialContext == nil && t.Dial == nil {
117		t.DialContext = defaultTransport.DialContext
118	}
119	if t.TLSHandshakeTimeout == 0 {
120		t.TLSHandshakeTimeout = defaultTransport.TLSHandshakeTimeout
121	}
122	if t.IdleConnTimeout == 0 {
123		t.IdleConnTimeout = defaultTransport.IdleConnTimeout
124	}
125	return t
126}
127
128// SetTransportDefaults applies the defaults from http.DefaultTransport
129// for the Proxy, Dial, and TLSHandshakeTimeout fields if unset
130func SetTransportDefaults(t *http.Transport) *http.Transport {
131	t = SetOldTransportDefaults(t)
132	// Allow clients to disable http2 if needed.
133	if s := os.Getenv("DISABLE_HTTP2"); len(s) > 0 {
134		klog.Info("HTTP2 has been explicitly disabled")
135	} else if allowsHTTP2(t) {
136		if err := configureHTTP2Transport(t); err != nil {
137			klog.Warningf("Transport failed http2 configuration: %v", err)
138		}
139	}
140	return t
141}
142
143func readIdleTimeoutSeconds() int {
144	ret := 30
145	// User can set the readIdleTimeout to 0 to disable the HTTP/2
146	// connection health check.
147	if s := os.Getenv("HTTP2_READ_IDLE_TIMEOUT_SECONDS"); len(s) > 0 {
148		i, err := strconv.Atoi(s)
149		if err != nil {
150			klog.Warningf("Illegal HTTP2_READ_IDLE_TIMEOUT_SECONDS(%q): %v."+
151				" Default value %d is used", s, err, ret)
152			return ret
153		}
154		ret = i
155	}
156	return ret
157}
158
159func pingTimeoutSeconds() int {
160	ret := 15
161	if s := os.Getenv("HTTP2_PING_TIMEOUT_SECONDS"); len(s) > 0 {
162		i, err := strconv.Atoi(s)
163		if err != nil {
164			klog.Warningf("Illegal HTTP2_PING_TIMEOUT_SECONDS(%q): %v."+
165				" Default value %d is used", s, err, ret)
166			return ret
167		}
168		ret = i
169	}
170	return ret
171}
172
173func configureHTTP2Transport(t *http.Transport) error {
174	t2, err := http2.ConfigureTransports(t)
175	if err != nil {
176		return err
177	}
178	// The following enables the HTTP/2 connection health check added in
179	// https://github.com/golang/net/pull/55. The health check detects and
180	// closes broken transport layer connections. Without the health check,
181	// a broken connection can linger too long, e.g., a broken TCP
182	// connection will be closed by the Linux kernel after 13 to 30 minutes
183	// by default, which caused
184	// https://github.com/kubernetes/client-go/issues/374 and
185	// https://github.com/kubernetes/kubernetes/issues/87615.
186	t2.ReadIdleTimeout = time.Duration(readIdleTimeoutSeconds()) * time.Second
187	t2.PingTimeout = time.Duration(pingTimeoutSeconds()) * time.Second
188	return nil
189}
190
191func allowsHTTP2(t *http.Transport) bool {
192	if t.TLSClientConfig == nil || len(t.TLSClientConfig.NextProtos) == 0 {
193		// the transport expressed no NextProto preference, allow
194		return true
195	}
196	for _, p := range t.TLSClientConfig.NextProtos {
197		if p == http2.NextProtoTLS {
198			// the transport explicitly allowed http/2
199			return true
200		}
201	}
202	// the transport explicitly set NextProtos and excluded http/2
203	return false
204}
205
// RoundTripperWrapper is an http.RoundTripper that decorates another
// RoundTripper and exposes it through WrappedRoundTripper, so helpers such as
// DialerFor can look through layers of wrapping.
type RoundTripperWrapper interface {
	http.RoundTripper
	WrappedRoundTripper() http.RoundTripper
}

// DialFunc is a context-aware dial function with the same shape as
// http.Transport.DialContext.
type DialFunc func(ctx context.Context, network, addr string) (net.Conn, error)

// DialerFor returns the dial function that transport would use.
//
// For an *http.Transport, DialContext is returned when set. Otherwise a non-nil
// Dial is adapted to DialFunc: the context is ignored, and transport.Dial is
// read each time the returned func is called. If neither is set, (nil, nil) is
// returned. A RoundTripperWrapper is unwrapped recursively. A nil transport
// yields (nil, nil), and any other transport type yields an error.
func DialerFor(transport http.RoundTripper) (DialFunc, error) {
	if transport == nil {
		return nil, nil
	}

	if t, ok := transport.(*http.Transport); ok {
		switch {
		case t.DialContext != nil:
			// DialContext wins over the legacy Dial when both are set.
			return t.DialContext, nil
		case t.Dial != nil:
			return func(_ context.Context, network, addr string) (net.Conn, error) {
				return t.Dial(network, addr)
			}, nil
		default:
			return nil, nil
		}
	}

	if wrapper, ok := transport.(RoundTripperWrapper); ok {
		return DialerFor(wrapper.WrappedRoundTripper())
	}

	return nil, fmt.Errorf("unknown transport type: %T", transport)
}
238
239type TLSClientConfigHolder interface {
240	TLSClientConfig() *tls.Config
241}
242
243func TLSClientConfig(transport http.RoundTripper) (*tls.Config, error) {
244	if transport == nil {
245		return nil, nil
246	}
247
248	switch transport := transport.(type) {
249	case *http.Transport:
250		return transport.TLSClientConfig, nil
251	case TLSClientConfigHolder:
252		return transport.TLSClientConfig(), nil
253	case RoundTripperWrapper:
254		return TLSClientConfig(transport.WrappedRoundTripper())
255	default:
256		return nil, fmt.Errorf("unknown transport type: %T", transport)
257	}
258}
259
// FormatURL builds a *url.URL from its parts. host and port are combined with
// net.JoinHostPort, so IPv6 literals are bracketed correctly.
func FormatURL(scheme string, host string, port int, path string) *url.URL {
	u := new(url.URL)
	u.Scheme = scheme
	u.Host = net.JoinHostPort(host, strconv.Itoa(port))
	u.Path = path
	return u
}
267
// GetHTTPClient identifies the client that sent req by its User-Agent header,
// returning "unknown" when that header is empty or absent.
func GetHTTPClient(req *http.Request) string {
	ua := req.UserAgent()
	if ua == "" {
		return "unknown"
	}
	return ua
}
274
// SourceIPs splits the comma separated X-Forwarded-For header and joins it with
// the X-Real-Ip header and/or req.RemoteAddr, ignoring invalid IPs.
// When a request carries several X-Forwarded-For header lines, all of them are
// used, in order, as if they had been folded into a single comma-separated value
// (RFC 7230, section 3.2.2).
// The X-Real-Ip is omitted if it's already present in the X-Forwarded-For chain.
// The req.RemoteAddr is always the last IP in the returned list; it is not
// appended a second time if it equals the last IP already collected.
// It returns nil if all of these are empty or invalid.
func SourceIPs(req *http.Request) []net.IP {
	var srcIPs []net.IP

	hdr := req.Header
	// First check the X-Forwarded-For header(s) for requests via proxy.
	// Header.Values returns every line for the key; Header.Get would return only
	// the first one and silently drop hops recorded on later lines.
	for _, hdrForwardedFor := range hdr.Values("X-Forwarded-For") {
		// Each value can be a csv of IPs in case of multiple proxies.
		// Keep every valid IP, in order; skip entries that do not parse.
		for _, part := range strings.Split(hdrForwardedFor, ",") {
			if ip := net.ParseIP(strings.TrimSpace(part)); ip != nil {
				srcIPs = append(srcIPs, ip)
			}
		}
	}

	// Try the X-Real-Ip header.
	if hdrRealIP := hdr.Get("X-Real-Ip"); hdrRealIP != "" {
		ip := net.ParseIP(hdrRealIP)
		// Only append the X-Real-Ip if it's not already contained in the X-Forwarded-For chain.
		if ip != nil && !containsIP(srcIPs, ip) {
			srcIPs = append(srcIPs, ip)
		}
	}

	// Always include the request Remote Address as it cannot be easily spoofed.
	var remoteIP net.IP
	// Remote Address in Go's HTTP server is in the form host:port so we need to split that first.
	if host, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
		remoteIP = net.ParseIP(host)
	}
	// Fallback if Remote Address was just IP.
	if remoteIP == nil {
		remoteIP = net.ParseIP(req.RemoteAddr)
	}

	// Don't duplicate remote IP if it's already the last address in the chain.
	if remoteIP != nil && (len(srcIPs) == 0 || !remoteIP.Equal(srcIPs[len(srcIPs)-1])) {
		srcIPs = append(srcIPs, remoteIP)
	}

	return srcIPs
}

// containsIP reports whether ip is present in ips, comparing with net.IP.Equal
// so that the 4-byte and 16-byte forms of the same IPv4 address match.
func containsIP(ips []net.IP, ip net.IP) bool {
	for _, v := range ips {
		if v.Equal(ip) {
			return true
		}
	}
	return false
}
337
338// Extracts and returns the clients IP from the given request.
339// Looks at X-Forwarded-For header, X-Real-Ip header and request.RemoteAddr in that order.
340// Returns nil if none of them are set or is set to an invalid value.
341func GetClientIP(req *http.Request) net.IP {
342	ips := SourceIPs(req)
343	if len(ips) == 0 {
344		return nil
345	}
346	return ips[0]
347}
348
// AppendForwardedForHeader prepares req for another forwarding hop by appending
// the host part of req.RemoteAddr to its X-Forwarded-For header. Any existing
// X-Forwarded-For lines are folded, together with the new address, into a single
// ", "-separated value. If RemoteAddr is not in host:port form, req is left
// unchanged. Adapted from net/http/httputil/reverseproxy.go.
func AppendForwardedForHeader(req *http.Request) {
	clientIP, _, err := net.SplitHostPort(req.RemoteAddr)
	if err != nil {
		return
	}

	value := clientIP
	// When earlier proxies already recorded hops, keep them ahead of ours.
	if prior, ok := req.Header["X-Forwarded-For"]; ok {
		value = strings.Join(prior, ", ") + ", " + clientIP
	}
	req.Header.Set("X-Forwarded-For", value)
}
363
// defaultProxyFuncPointer is the %p rendering of http.ProxyFromEnvironment,
// computed once so isDefault can compare against it.
var defaultProxyFuncPointer = fmt.Sprintf("%p", http.ProxyFromEnvironment)

// isDefault reports whether transportProxier is http.ProxyFromEnvironment.
// Go func values cannot be compared with ==, so the comparison is made on their
// %p (function pointer) formatting instead. A nil func yields false.
func isDefault(transportProxier func(*http.Request) (*url.URL, error)) bool {
	return fmt.Sprintf("%p", transportProxier) == defaultProxyFuncPointer
}
371
// NewProxierWithNoProxyCIDR constructs a Proxier function that respects CIDRs in NO_PROXY and delegates if
// no matching CIDRs are found.
//
// NO_PROXY (or no_proxy, when NO_PROXY is unset or empty) is read once, when this
// function is called. It is split on commas, and every entry that parses as a
// CIDR once surrounding whitespace is trimmed is collected. If no entry is a CIDR,
// delegate itself is returned. Otherwise the returned func returns (nil, nil),
// meaning "no proxy", for requests whose URL host is an IP literal inside one of
// the CIDRs, and calls delegate for every other request.
func NewProxierWithNoProxyCIDR(delegate func(req *http.Request) (*url.URL, error)) func(req *http.Request) (*url.URL, error) {
	// we wrap the default method, so we only need to perform our check if the NO_PROXY (or no_proxy) envvar has a CIDR in it
	noProxyEnv := os.Getenv("NO_PROXY")
	if noProxyEnv == "" {
		noProxyEnv = os.Getenv("no_proxy")
	}

	var cidrs []*net.IPNet
	for _, noProxyRule := range strings.Split(noProxyEnv, ",") {
		// NO_PROXY is commonly written as "a, b"; without trimming, " 10.0.0.0/8"
		// fails to parse and the CIDR is silently ignored. http.ProxyFromEnvironment
		// trims its entries the same way.
		_, cidr, err := net.ParseCIDR(strings.TrimSpace(noProxyRule))
		if err != nil {
			// Not a CIDR (e.g. a hostname or a bare IP); delegate handles it.
			continue
		}
		cidrs = append(cidrs, cidr)
	}

	if len(cidrs) == 0 {
		return delegate
	}

	return func(req *http.Request) (*url.URL, error) {
		ip := net.ParseIP(req.URL.Hostname())
		if ip == nil {
			// Hostnames are left to the delegate's own NO_PROXY matching.
			return delegate(req)
		}

		for _, cidr := range cidrs {
			if cidr.Contains(ip) {
				return nil, nil
			}
		}

		return delegate(req)
	}
}
409
// Dialer dials a host and writes a request to it.
type Dialer interface {
	// Dial connects to the host specified by req's URL, writes the request to the connection, and
	// returns the opened net.Conn.
	Dial(req *http.Request) (net.Conn, error)
}

// DialerFunc adapts an ordinary function to the Dialer interface.
type DialerFunc func(req *http.Request) (net.Conn, error)

// Compile-time check that DialerFunc satisfies Dialer.
var _ Dialer = DialerFunc(nil)

// Dial calls fn(req) and returns its results unchanged.
func (fn DialerFunc) Dial(req *http.Request) (net.Conn, error) {
	return fn(req)
}
423
// ConnectWithRedirects uses dialer to send a request built from originalMethod,
// originalLocation, header and originalBody, following HTTP 302 (Found)
// redirects. Each redirect's Location is resolved relative to the URL that
// produced it. At most 9 redirects are followed; receiving a 10th redirect
// response returns a "too many redirects" error.
//
// Redirected requests are sent with method GET and no body, and every request
// in the chain shares the same header map. Only status 302 is treated as a
// redirect. Any other status, or a response that cannot be parsed at all, ends
// the loop, and the connection that produced it is returned to the caller.
//
// On success it returns the still-open connection (the caller owns it and must
// close it) together with the raw bytes read from it while peeking at the
// response, capped at 16384 bytes. Because the response is parsed through a
// bufio.Reader, those bytes may extend past the response headers.
// If requireSameHostRedirects is true, a redirect to a different hostname
// (compared with url.URL.Hostname, so ports are ignored) fails with an error.
// Whenever an error is returned, any connection opened so far is closed and the
// conn and byte results are nil.
func ConnectWithRedirects(originalMethod string, originalLocation *url.URL, header http.Header, originalBody io.Reader, dialer Dialer, requireSameHostRedirects bool) (net.Conn, []byte, error) {
	const (
		maxRedirects    = 9     // Fail on the 10th redirect
		maxResponseSize = 16384 // play it safe to allow the potential for lots of / large headers
	)

	var (
		location         = originalLocation
		method           = originalMethod
		intermediateConn net.Conn
		rawResponse      = bytes.NewBuffer(make([]byte, 0, 256))
		body             = originalBody
	)

	// Close whatever connection is still owned on every return path. The
	// success path clears intermediateConn first, so the returned conn stays open.
	defer func() {
		if intermediateConn != nil {
			intermediateConn.Close()
		}
	}()

redirectLoop:
	for redirects := 0; ; redirects++ {
		// Iteration 0 sends the original request; iteration N follows redirect N.
		if redirects > maxRedirects {
			return nil, nil, fmt.Errorf("too many redirects (%d)", redirects)
		}

		req, err := http.NewRequest(method, location.String(), body)
		if err != nil {
			return nil, nil, err
		}

		// The caller's header map is attached as-is (not copied) to each request.
		req.Header = header

		intermediateConn, err = dialer.Dial(req)
		if err != nil {
			return nil, nil, err
		}

		// Peek at the backend response.
		// rawResponse is reset per attempt so it only ever holds the bytes read
		// from the current connection.
		rawResponse.Reset()
		respReader := bufio.NewReader(io.TeeReader(
			io.LimitReader(intermediateConn, maxResponseSize), // Don't read more than maxResponseSize bytes.
			rawResponse)) // Save the raw response.
		resp, err := http.ReadResponse(respReader, nil)
		if err != nil {
			// Unable to read the backend response; let the client handle it.
			klog.Warningf("Error reading backend response: %v", err)
			break redirectLoop
		}

		switch resp.StatusCode {
		case http.StatusFound:
			// Redirect, continue.
		default:
			// Don't redirect.
			break redirectLoop
		}

		// Redirected requests switch to "GET" according to the HTTP spec:
		// https://www.w3.org/Protocols/rfc2616/rfc2616-sec10.html#sec10.3
		method = "GET"
		// don't send a body when following redirects
		body = nil

		resp.Body.Close() // not used

		// Prepare to follow the redirect.
		redirectStr := resp.Header.Get("Location")
		if redirectStr == "" {
			return nil, nil, fmt.Errorf("%d response missing Location header", resp.StatusCode)
		}
		// We have to parse relative to the current location, NOT originalLocation. For example,
		// if we request http://foo.com/a and get back "http://bar.com/b", the result should be
		// http://bar.com/b. If we then make that request and get back a redirect to "/c", the result
		// should be http://bar.com/c, not http://foo.com/c.
		location, err = location.Parse(redirectStr)
		if err != nil {
			return nil, nil, fmt.Errorf("malformed Location header: %v", err)
		}

		// Only follow redirects to the same host when requireSameHostRedirects is set.
		// A cross-host redirect is then reported as an error; the redirect response
		// itself is not propagated back to the caller.
		if requireSameHostRedirects && location.Hostname() != originalLocation.Hostname() {
			return nil, nil, fmt.Errorf("hostname mismatch: expected %s, found %s", originalLocation.Hostname(), location.Hostname())
		}

		// Reset the connection.
		intermediateConn.Close()
		intermediateConn = nil
	}

	connToReturn := intermediateConn
	intermediateConn = nil // Don't close the connection when we return it.
	return connToReturn, rawResponse.Bytes(), nil
}
521
// CloneRequest returns a shallow copy of req whose Header is a deep copy (made
// with CloneHeader), so header edits on the clone do not affect req. Every other
// field, including pointer fields such as URL and Body, is shared with req.
func CloneRequest(req *http.Request) *http.Request {
	clone := *req
	clone.Header = CloneHeader(req.Header)
	return &clone
}

// CloneHeader returns a deep copy of in: a new map with freshly allocated value
// slices. A nil or empty in yields an empty, non-nil http.Header, and every key
// maps to a non-nil slice even when its original slice was nil.
func CloneHeader(in http.Header) http.Header {
	out := make(http.Header, len(in))
	for name, values := range in {
		out[name] = append(make([]string, 0, len(values)), values...)
	}
	return out
}
545
// WarningHeader holds a single warning from an RFC2616 14.46 Warning header
// (format: warn-code warn-agent "warn-text").
type WarningHeader struct {
	// Code indicates the type of warning. 299 is a miscellaneous persistent warning
	Code int
	// Agent contains the name or pseudonym of the server adding the Warning header.
	// A single "-" is recommended when agent is unknown.
	Agent string
	// Text is the warning text, without its surrounding quotes or backslash escapes.
	Text string
}
556
557// ParseWarningHeaders extract RFC2616 14.46 warnings headers from the specified set of header values.
558// Multiple comma-separated warnings per header are supported.
559// If errors are encountered on a header, the remainder of that header are skipped and subsequent headers are parsed.
560// Returns successfully parsed warnings and any errors encountered.
561func ParseWarningHeaders(headers []string) ([]WarningHeader, []error) {
562	var (
563		results []WarningHeader
564		errs    []error
565	)
566	for _, header := range headers {
567		for len(header) > 0 {
568			result, remainder, err := ParseWarningHeader(header)
569			if err != nil {
570				errs = append(errs, err)
571				break
572			}
573			results = append(results, result)
574			header = remainder
575		}
576	}
577	return results, errs
578}
579
// codeMatcher accepts exactly three ASCII digits, the shape of a warn-code.
var codeMatcher = regexp.MustCompile(`^[0-9]{3}$`)

// wordDecoder decodes RFC 2047 encoded-words, which may appear in warn-text
// produced according to RFC 2616.
var wordDecoder = &mime.WordDecoder{}
584
// ParseWarningHeader extracts one RFC2616 14.46 warning from the specified header,
// returning an error if the header does not contain a correctly formatted warning.
// Any remaining content in the header is returned.
//
// After surrounding whitespace is trimmed, the header must start with
// `code agent "text"`, separated by single spaces. code must be exactly three
// ASCII digits. agent must be non-empty, valid UTF-8 and free of control
// characters. text must be a quoted string (see parseQuotedString). If the text
// decodes as RFC 2047 encoded-words it is replaced by the decoded form, and the
// final text must be valid UTF-8 without control characters.
//
// An optional quoted warn-date may follow; it is skipped, but not validated or
// returned. Any further warning must be introduced by a comma. remainder is that
// following content with the comma and surrounding whitespace removed, or empty
// when nothing follows. On error, the zero WarningHeader and an empty remainder
// are returned.
func ParseWarningHeader(header string) (result WarningHeader, remainder string, err error) {
	// https://tools.ietf.org/html/rfc2616#section-14.46
	//   updated by
	// https://tools.ietf.org/html/rfc7234#section-5.5
	//   https://tools.ietf.org/html/rfc7234#appendix-A
	//     Some requirements regarding production and processing of the Warning
	//     header fields have been relaxed, as it is not widely implemented.
	//     Furthermore, the Warning header field no longer uses RFC 2047
	//     encoding, nor does it allow multiple languages, as these aspects were
	//     not implemented.
	//
	// Format is one of:
	// warn-code warn-agent "warn-text"
	// warn-code warn-agent "warn-text" "warn-date"
	//
	// warn-code is a three digit number
	// warn-agent is unquoted and contains no spaces
	// warn-text is quoted with backslash escaping (RFC2047-encoded according to RFC2616, not encoded according to RFC7234)
	// warn-date is optional, quoted, and in HTTP-date format (no embedded or escaped quotes)
	//
	// additional warnings can optionally be included in the same header by comma-separating them:
	// warn-code warn-agent "warn-text" "warn-date"[, warn-code warn-agent "warn-text" "warn-date", ...]

	// tolerate leading (and trailing) whitespace
	header = strings.TrimSpace(header)

	// Split on single spaces into code, agent and everything else. Consecutive
	// spaces produce an empty agent, which is rejected below.
	parts := strings.SplitN(header, " ", 3)
	if len(parts) != 3 {
		return WarningHeader{}, "", errors.New("invalid warning header: fewer than 3 segments")
	}
	code, agent, textDateRemainder := parts[0], parts[1], parts[2]

	// verify code format
	// NOTE(review): codeMatcher accepts any three digits (000-999), so the
	// "between 100-299" wording of this error is stricter than what is enforced.
	if !codeMatcher.Match([]byte(code)) {
		return WarningHeader{}, "", errors.New("invalid warning header: code segment is not 3 digits between 100-299")
	}
	// The error can be ignored: codeMatcher guarantees exactly three ASCII digits.
	codeInt, _ := strconv.ParseInt(code, 10, 64)

	// verify agent presence
	if len(agent) == 0 {
		return WarningHeader{}, "", errors.New("invalid warning header: empty agent segment")
	}
	if !utf8.ValidString(agent) || hasAnyRunes(agent, unicode.IsControl) {
		return WarningHeader{}, "", errors.New("invalid warning header: invalid agent")
	}

	// verify textDateRemainder presence
	if len(textDateRemainder) == 0 {
		return WarningHeader{}, "", errors.New("invalid warning header: empty text segment")
	}

	// extract text
	text, dateAndRemainder, err := parseQuotedString(textDateRemainder)
	if err != nil {
		return WarningHeader{}, "", fmt.Errorf("invalid warning header: %v", err)
	}
	// tolerate RFC2047-encoded text from warnings produced according to RFC2616
	// (if decoding fails, the text is kept exactly as it was unquoted)
	if decodedText, err := wordDecoder.DecodeHeader(text); err == nil {
		text = decodedText
	}
	if !utf8.ValidString(text) || hasAnyRunes(text, unicode.IsControl) {
		return WarningHeader{}, "", errors.New("invalid warning header: invalid text")
	}
	result = WarningHeader{Code: int(codeInt), Agent: agent, Text: text}

	if len(dateAndRemainder) > 0 {
		if dateAndRemainder[0] == '"' {
			// consume date
			// Skip to the next quote; dates contain no escaped quotes, so no
			// escape handling is needed. The date itself is not validated.
			foundEndQuote := false
			for i := 1; i < len(dateAndRemainder); i++ {
				if dateAndRemainder[i] == '"' {
					foundEndQuote = true
					remainder = strings.TrimSpace(dateAndRemainder[i+1:])
					break
				}
			}
			if !foundEndQuote {
				return WarningHeader{}, "", errors.New("invalid warning header: unterminated date segment")
			}
		} else {
			// No date: whatever follows the text is the remainder.
			remainder = dateAndRemainder
		}
	}
	// Anything left over must start with the comma that separates warnings.
	if len(remainder) > 0 {
		if remainder[0] == ',' {
			// consume comma if present
			remainder = strings.TrimSpace(remainder[1:])
		} else {
			return WarningHeader{}, "", errors.New("invalid warning header: unexpected token after warn-date")
		}
	}

	return result, remainder, nil
}
682
// parseQuotedString parses the double-quoted string at the start of
// quotedString. Inside the quotes a backslash makes the following byte literal,
// so \" and \\ yield " and \, and a backslash before any other byte is simply
// dropped. It returns the unescaped contents, plus whatever follows the closing
// quote with surrounding whitespace trimmed. Scanning is byte-wise, so
// multi-byte UTF-8 sequences are copied through unchanged. An error is returned
// for an empty input, a missing opening quote, or a missing closing quote.
func parseQuotedString(quotedString string) (string, string, error) {
	switch {
	case quotedString == "":
		return "", "", errors.New("invalid quoted string: 0-length")
	case quotedString[0] != '"':
		return "", "", errors.New("invalid quoted string: missing initial quote")
	}

	inner := quotedString[1:]
	var sb strings.Builder
	for i := 0; i < len(inner); i++ {
		c := inner[i]
		if c == '\\' {
			// Escape: take the next byte literally, whatever it is. A trailing
			// backslash leaves the string unterminated.
			i++
			if i == len(inner) {
				break
			}
			sb.WriteByte(inner[i])
			continue
		}
		if c == '"' {
			return sb.String(), strings.TrimSpace(inner[i+1:]), nil
		}
		sb.WriteByte(c)
	}
	return "", "", errors.New("invalid quoted string: missing closing quote")
}
728
729func NewWarningHeader(code int, agent, text string) (string, error) {
730	if code < 0 || code > 999 {
731		return "", errors.New("code must be between 0 and 999")
732	}
733	if len(agent) == 0 {
734		agent = "-"
735	} else if !utf8.ValidString(agent) || strings.ContainsAny(agent, `\"`) || hasAnyRunes(agent, unicode.IsSpace, unicode.IsControl) {
736		return "", errors.New("agent must be valid UTF-8 and must not contain spaces, quotes, backslashes, or control characters")
737	}
738	if !utf8.ValidString(text) || hasAnyRunes(text, unicode.IsControl) {
739		return "", errors.New("text must be valid UTF-8 and must not contain control characters")
740	}
741	return fmt.Sprintf("%03d %s %s", code, agent, makeQuotedString(text)), nil
742}
743
// hasAnyRunes reports whether some rune of s satisfies at least one of
// runeCheckers. Runes are visited in order and checkers in the order given,
// stopping at the first match. An empty s, or no checkers, yields false.
func hasAnyRunes(s string, runeCheckers ...func(rune) bool) bool {
	matchesAny := func(r rune) bool {
		for _, check := range runeCheckers {
			if check(r) {
				return true
			}
		}
		return false
	}
	for _, r := range s {
		if matchesAny(r) {
			return true
		}
	}
	return false
}
754
// makeQuotedString wraps s in double quotes, escaping each " and \ with a
// preceding backslash; every other rune is copied unchanged (invalid UTF-8
// bytes are written as U+FFFD, as ranging over a string yields them).
func makeQuotedString(s string) string {
	var buf bytes.Buffer
	buf.Grow(len(s) + 2)
	buf.WriteByte('"')
	for _, r := range s {
		if r == '"' || r == '\\' {
			buf.WriteByte('\\')
		}
		buf.WriteRune(r)
	}
	buf.WriteByte('"')
	return buf.String()
}
774