1// Copyright (c) 2020 Shivaram Lingamneni <slingamn@cs.stanford.edu>
2// released under the MIT license
3
4package utils
5
6import (
7	"crypto/tls"
8	"encoding/binary"
9	"io"
10	"net"
11	"strings"
12	"sync"
13	"time"
14)
15
// maxProxyLineLenV1 is the longest PROXY v1 line we accept, in bytes, counting
// the terminating CRLF. Per https://www.haproxy.org/download/1.8/doc/proxy-protocol.txt,
// "a 108-byte buffer is always enough to store all the line and a trailing zero
// for string processing"; we never store the trailing zero, hence 107.
const maxProxyLineLenV1 = 107
22
// proxyLineError is the concrete type of ErrBadProxyLine. It implements
// net.Error and reports itself as Temporary(), so that returning it from a
// listener's Accept() does not make accept loops such as (*http.Server).Serve()
// exit: those loops retry on temporary errors and give up on any other error.
type proxyLineError struct{}

// Compile-time check that proxyLineError keeps satisfying net.Error.
var _ net.Error = (*proxyLineError)(nil)

func (p *proxyLineError) Error() string {
	return "invalid PROXY line"
}

// Timeout reports false: a malformed PROXY line is not a timeout.
func (p *proxyLineError) Timeout() bool {
	return false
}

// Temporary reports true: the error concerns a single connection, and the
// listener remains usable.
func (p *proxyLineError) Temporary() bool {
	return true
}
38
39var (
40	ErrBadProxyLine error = &proxyLineError{}
41)
42
// ListenerConfig is all the information about how to process
// incoming IRC connections on a listener.
type ListenerConfig struct {
	// TLSConfig, if non-nil, makes ReloadableListener.Accept wrap each accepted
	// connection in a TLS server connection using this configuration.
	TLSConfig *tls.Config

	// ProxyDeadline bounds how long Accept waits to read the PROXY header.
	ProxyDeadline time.Duration

	// RequireProxy makes Accept read and parse a PROXY (v1 or v2) header from
	// every connection before anything else; connections without a valid
	// header are closed.
	RequireProxy bool

	// The remaining fields are metadata for easier tracking;
	// they are not used by ReloadableListener.
	Tor       bool
	STSOnly   bool
	WebSocket bool
	HideSTS   bool
}
56
57// read a PROXY header (either v1 or v2), ensuring we don't read anything beyond
58// the header into a buffer (this would break the TLS handshake)
59func readRawProxyLine(conn net.Conn, deadline time.Duration) (result []byte, err error) {
60	// normally this is covered by ping timeouts, but we're doing this outside
61	// of the normal client goroutine:
62	conn.SetDeadline(time.Now().Add(deadline))
63	defer conn.SetDeadline(time.Time{})
64
65	// read the first 16 bytes of the proxy header
66	buf := make([]byte, 16, maxProxyLineLenV1)
67	_, err = io.ReadFull(conn, buf)
68	if err != nil {
69		return
70	}
71
72	switch buf[0] {
73	case 'P':
74		// PROXY v1: starts with "PROXY"
75		return readRawProxyLineV1(conn, buf)
76	case '\r':
77		// PROXY v2: starts with "\r\n\r\n"
78		return readRawProxyLineV2(conn, buf)
79	default:
80		return nil, ErrBadProxyLine
81	}
82}
83
84func readRawProxyLineV1(conn net.Conn, buf []byte) (result []byte, err error) {
85	for {
86		i := len(buf)
87		if i >= maxProxyLineLenV1 {
88			return nil, ErrBadProxyLine // did not find \r\n, fail
89		}
90		// prepare a single byte of free space, then read into it
91		buf = buf[0 : i+1]
92		_, err = io.ReadFull(conn, buf[i:])
93		if err != nil {
94			return nil, err
95		}
96		if buf[i] == '\n' {
97			return buf, nil
98		}
99	}
100}
101
// readRawProxyLineV2 finishes reading a binary (v2) PROXY header whose fixed
// 16-byte prefix has already been read into buf. Bytes 15-16 of the prefix
// give, in network byte order, the length of the data that follows. Exactly
// that many bytes are read, so nothing beyond the header is consumed from
// conn. buf's storage is reused when its capacity suffices; otherwise a larger
// buffer is allocated (e.g. for AF_UNIX addresses). The header's contents are
// not validated here.
func readRawProxyLineV2(conn net.Conn, buf []byte) (result []byte, err error) {
	bodyLen := int(binary.BigEndian.Uint16(buf[14:16]))
	if bodyLen == 0 {
		return buf[:16], nil
	}

	total := 16 + bodyLen
	if total > cap(buf) {
		grown := make([]byte, total)
		copy(grown, buf[:16])
		buf = grown
	} else {
		buf = buf[:total]
	}

	if _, err = io.ReadFull(conn, buf[16:]); err != nil {
		return nil, err
	}
	return buf, nil
}
121
122// ParseProxyLine parses a PROXY protocol (v1 or v2) line and returns the remote IP.
123func ParseProxyLine(line []byte) (ip net.IP, err error) {
124	if len(line) == 0 {
125		return nil, ErrBadProxyLine
126	}
127	switch line[0] {
128	case 'P':
129		return ParseProxyLineV1(string(line))
130	case '\r':
131		return parseProxyLineV2(line)
132	default:
133		return nil, ErrBadProxyLine
134	}
135}
136
137// ParseProxyLineV1 parses a PROXY protocol (v1) line and returns the remote IP.
138func ParseProxyLineV1(line string) (ip net.IP, err error) {
139	params := strings.Fields(line)
140	if len(params) != 6 || params[0] != "PROXY" {
141		return nil, ErrBadProxyLine
142	}
143	ip = net.ParseIP(params[2])
144	if ip == nil {
145		return nil, ErrBadProxyLine
146	}
147	return ip.To16(), nil
148}
149
150func parseProxyLineV2(line []byte) (ip net.IP, err error) {
151	if len(line) < 16 {
152		return nil, ErrBadProxyLine
153	}
154	// this doesn't allocate
155	if string(line[:12]) != "\x0d\x0a\x0d\x0a\x00\x0d\x0a\x51\x55\x49\x54\x0a" {
156		return nil, ErrBadProxyLine
157	}
158	// "The next byte (the 13th one) is the protocol version and command."
159	versionCmd := line[12]
160	// "The highest four bits contains the version [....] it must always be sent as \x2"
161	if (versionCmd >> 4) != 2 {
162		return nil, ErrBadProxyLine
163	}
164	// "The lowest four bits represents the command"
165	switch versionCmd & 0x0f {
166	case 0:
167		return nil, nil // LOCAL command
168	case 1:
169		// PROXY command, continue below
170	default:
171		// "Receivers must drop connections presenting unexpected values here"
172		return nil, ErrBadProxyLine
173	}
174
175	var addrLen int
176	// "The 14th byte contains the transport protocol and address family."
177	protoAddr := line[13]
178	// "The highest 4 bits contain the address family"
179	switch protoAddr >> 4 {
180	case 1:
181		addrLen = 4 // AF_INET
182	case 2:
183		addrLen = 16 // AF_INET6
184	default:
185		return nil, nil // AF_UNSPEC or AF_UNIX, either way there's no IP address
186	}
187
188	// header, source and destination address, two 16-bit port numbers:
189	expectedLen := 16 + 2*addrLen + 4
190	if len(line) < expectedLen {
191		return nil, ErrBadProxyLine
192	}
193
194	// "Starting from the 17th byte, addresses are presented in network byte order.
195	//  The address order is always the same :
196	//    - source layer 3 address in network byte order [...]"
197	if addrLen == 4 {
198		ip = net.IP(line[16 : 16+addrLen]).To16()
199	} else {
200		ip = make(net.IP, addrLen)
201		copy(ip, line[16:16+addrLen])
202	}
203	return ip, nil
204}
205
206/// WrappedConn is a net.Conn with some additional data stapled to it;
207// the proxied IP, if one was read via the PROXY protocol, and the listener
208// configuration.
209type WrappedConn struct {
210	net.Conn
211	ProxiedIP net.IP
212	Config    ListenerConfig
213	// Secure indicates whether we believe the connection between us and the client
214	// was secure against interception and modification (including all proxies):
215	Secure bool
216}
217
218// ReloadableListener is a wrapper for net.Listener that allows reloading
219// of config data for postprocessing connections (TLS, PROXY protocol, etc.)
220type ReloadableListener struct {
221	// TODO: make this lock-free
222	sync.Mutex
223	realListener net.Listener
224	config       ListenerConfig
225	isClosed     bool
226}
227
228func NewReloadableListener(realListener net.Listener, config ListenerConfig) *ReloadableListener {
229	return &ReloadableListener{
230		realListener: realListener,
231		config:       config,
232	}
233}
234
235func (rl *ReloadableListener) Reload(config ListenerConfig) {
236	rl.Lock()
237	rl.config = config
238	rl.Unlock()
239}
240
241func (rl *ReloadableListener) Accept() (conn net.Conn, err error) {
242	conn, err = rl.realListener.Accept()
243
244	rl.Lock()
245	config := rl.config
246	isClosed := rl.isClosed
247	rl.Unlock()
248
249	if isClosed {
250		if err == nil {
251			conn.Close()
252		}
253		err = net.ErrClosed
254	}
255	if err != nil {
256		return nil, err
257	}
258
259	var proxiedIP net.IP
260	if config.RequireProxy {
261		// this will occur synchronously on the goroutine calling Accept(),
262		// but that's OK because this listener *requires* a PROXY line,
263		// therefore it must be used with proxies that always send the line
264		// and we won't get slowloris'ed waiting for the client response
265		proxyLine, err := readRawProxyLine(conn, config.ProxyDeadline)
266		if err == nil {
267			proxiedIP, err = ParseProxyLine(proxyLine)
268		}
269		if err != nil {
270			conn.Close()
271			return nil, err
272		}
273	}
274
275	if config.TLSConfig != nil {
276		conn = tls.Server(conn, config.TLSConfig)
277	}
278
279	return &WrappedConn{
280		Conn:      conn,
281		ProxiedIP: proxiedIP,
282		Config:    config,
283	}, nil
284}
285
286func (rl *ReloadableListener) Close() error {
287	rl.Lock()
288	rl.isClosed = true
289	rl.Unlock()
290
291	return rl.realListener.Close()
292}
293
294func (rl *ReloadableListener) Addr() net.Addr {
295	return rl.realListener.Addr()
296}
297