1package proxyproto
2
3import (
4	"bufio"
5	"bytes"
6	"io"
7	"net"
8	"strconv"
9	"strings"
10	"testing"
11	"time"
12)
13
14var (
15	IPv4AddressesAndPorts        = strings.Join([]string{IP4_ADDR, IP4_ADDR, strconv.Itoa(PORT), strconv.Itoa(PORT)}, separator)
16	IPv4AddressesAndInvalidPorts = strings.Join([]string{IP4_ADDR, IP4_ADDR, strconv.Itoa(INVALID_PORT), strconv.Itoa(INVALID_PORT)}, separator)
17	IPv6AddressesAndPorts        = strings.Join([]string{IP6_ADDR, IP6_ADDR, strconv.Itoa(PORT), strconv.Itoa(PORT)}, separator)
18	IPv6LongAddressesAndPorts    = strings.Join([]string{IP6_LONG_ADDR, IP6_LONG_ADDR, strconv.Itoa(PORT), strconv.Itoa(PORT)}, separator)
19
20	fixtureTCP4V1 = "PROXY TCP4 " + IPv4AddressesAndPorts + crlf + "GET /"
21	fixtureTCP6V1 = "PROXY TCP6 " + IPv6AddressesAndPorts + crlf + "GET /"
22
23	fixtureTCP6V1Overflow = "PROXY TCP6 " + IPv6LongAddressesAndPorts
24
25	fixtureUnknown              = "PROXY UNKNOWN" + crlf
26	fixtureUnknownWithAddresses = "PROXY UNKNOWN " + IPv4AddressesAndInvalidPorts + crlf
27)
28
29var invalidParseV1Tests = []struct {
30	desc          string
31	reader        *bufio.Reader
32	expectedError error
33}{
34	{
35		desc:          "no signature",
36		reader:        newBufioReader([]byte(NO_PROTOCOL)),
37		expectedError: ErrNoProxyProtocol,
38	},
39	{
40		desc:          "prox",
41		reader:        newBufioReader([]byte("PROX")),
42		expectedError: ErrNoProxyProtocol,
43	},
44	{
45		desc:          "proxy lf",
46		reader:        newBufioReader([]byte("PROXY \n")),
47		expectedError: ErrLineMustEndWithCrlf,
48	},
49	{
50		desc:          "proxy crlf",
51		reader:        newBufioReader([]byte("PROXY " + crlf)),
52		expectedError: ErrCantReadAddressFamilyAndProtocol,
53	},
54	{
55		desc:          "proxy no space crlf",
56		reader:        newBufioReader([]byte("PROXY" + crlf)),
57		expectedError: ErrCantReadAddressFamilyAndProtocol,
58	},
59	{
60		desc:          "proxy something crlf",
61		reader:        newBufioReader([]byte("PROXY SOMETHING" + crlf)),
62		expectedError: ErrCantReadAddressFamilyAndProtocol,
63	},
64	{
65		desc:          "incomplete signature TCP4",
66		reader:        newBufioReader([]byte("PROXY TCP4 " + IPv4AddressesAndPorts)),
67		expectedError: ErrCantReadVersion1Header,
68	},
69	{
70		desc:          "TCP6 with IPv4 addresses",
71		reader:        newBufioReader([]byte("PROXY TCP6 " + IPv4AddressesAndPorts + crlf)),
72		expectedError: ErrInvalidAddress,
73	},
74	{
75		desc:          "TCP4 with IPv6 addresses",
76		reader:        newBufioReader([]byte("PROXY TCP4 " + IPv6AddressesAndPorts + crlf)),
77		expectedError: ErrInvalidAddress,
78	},
79	{
80		desc:          "TCP4 with invalid port",
81		reader:        newBufioReader([]byte("PROXY TCP4 " + IPv4AddressesAndInvalidPorts + crlf)),
82		expectedError: ErrInvalidPortNumber,
83	},
84	{
85		desc:          "header too long",
86		reader:        newBufioReader([]byte("PROXY UNKNOWN " + IPv6LongAddressesAndPorts + " " + crlf)),
87		expectedError: ErrVersion1HeaderTooLong,
88	},
89}
90
91func TestReadV1Invalid(t *testing.T) {
92	for _, tt := range invalidParseV1Tests {
93		t.Run(tt.desc, func(t *testing.T) {
94			if _, err := Read(tt.reader); err != tt.expectedError {
95				t.Fatalf("expected %s, actual %v", tt.expectedError, err)
96			}
97		})
98	}
99}
100
101var validParseAndWriteV1Tests = []struct {
102	desc           string
103	reader         *bufio.Reader
104	expectedHeader *Header
105}{
106	{
107		desc:   "TCP4",
108		reader: bufio.NewReader(strings.NewReader(fixtureTCP4V1)),
109		expectedHeader: &Header{
110			Version:           1,
111			Command:           PROXY,
112			TransportProtocol: TCPv4,
113			SourceAddr:        v4addr,
114			DestinationAddr:   v4addr,
115		},
116	},
117	{
118		desc:   "TCP6",
119		reader: bufio.NewReader(strings.NewReader(fixtureTCP6V1)),
120		expectedHeader: &Header{
121			Version:           1,
122			Command:           PROXY,
123			TransportProtocol: TCPv6,
124			SourceAddr:        v6addr,
125			DestinationAddr:   v6addr,
126		},
127	},
128	{
129		desc:   "unknown",
130		reader: bufio.NewReader(strings.NewReader(fixtureUnknown)),
131		expectedHeader: &Header{
132			Version:           1,
133			Command:           LOCAL,
134			TransportProtocol: UNSPEC,
135			SourceAddr:        nil,
136			DestinationAddr:   nil,
137		},
138	},
139	{
140		desc:   "unknown with addresses and ports",
141		reader: bufio.NewReader(strings.NewReader(fixtureUnknownWithAddresses)),
142		expectedHeader: &Header{
143			Version:           1,
144			Command:           LOCAL,
145			TransportProtocol: UNSPEC,
146			SourceAddr:        nil,
147			DestinationAddr:   nil,
148		},
149	},
150}
151
152func TestParseV1Valid(t *testing.T) {
153	for _, tt := range validParseAndWriteV1Tests {
154		t.Run(tt.desc, func(t *testing.T) {
155			header, err := Read(tt.reader)
156			if err != nil {
157				t.Fatal("unexpected error", err.Error())
158			}
159			if !header.EqualsTo(tt.expectedHeader) {
160				t.Fatalf("expected %#v, actual %#v", tt.expectedHeader, header)
161			}
162		})
163	}
164}
165
166func TestWriteV1Valid(t *testing.T) {
167	for _, tt := range validParseAndWriteV1Tests {
168		t.Run(tt.desc, func(t *testing.T) {
169			var b bytes.Buffer
170			w := bufio.NewWriter(&b)
171			if _, err := tt.expectedHeader.WriteTo(w); err != nil {
172				t.Fatal("unexpected error ", err)
173			}
174			w.Flush()
175
176			// Read written bytes to validate written header
177			r := bufio.NewReader(&b)
178			newHeader, err := Read(r)
179			if err != nil {
180				t.Fatal("unexpected error ", err)
181			}
182
183			if !newHeader.EqualsTo(tt.expectedHeader) {
184				t.Fatalf("expected %#v, actual %#v", tt.expectedHeader, newHeader)
185			}
186		})
187	}
188}
189
190// Tests for parseVersion1 overflow - issue #69.
191
// dataSource is an io.Reader that produces NBytes space characters (0x20) in
// total and reports io.EOF once they have all been handed out.
type dataSource struct {
	NBytes int // total number of bytes the source will produce
	NRead  int // number of bytes produced so far
}

// Read fills b with 0x20 bytes, limited by both len(b) and the bytes still
// left in the NBytes budget, and adds the count to NRead. It returns
// (0, io.EOF) once the budget is exhausted.
func (ds *dataSource) Read(b []byte) (int, error) {
	remaining := ds.NBytes - ds.NRead
	if remaining <= 0 {
		return 0, io.EOF
	}

	// Never write past the remaining budget, even if the caller's buffer is larger.
	chunk := b
	if remaining < len(chunk) {
		chunk = chunk[:remaining]
	}
	for i := range chunk {
		chunk[i] = 0x20
	}

	ds.NRead += len(chunk)
	return len(chunk), nil
}
211
212func TestParseVersion1Overflow(t *testing.T) {
213	ds := &dataSource{}
214	reader := bufio.NewReader(ds)
215	bufSize := reader.Size()
216	ds.NBytes = bufSize * 16
217	parseVersion1(reader)
218	if ds.NRead > bufSize {
219		t.Fatalf("read: expected max %d bytes, actual %d\n", bufSize, ds.NRead)
220	}
221}
222
223func listen(t *testing.T) *Listener {
224	l, err := net.Listen("tcp", "127.0.0.1:0")
225	if err != nil {
226		t.Fatalf("listen: %v", err)
227	}
228	return &Listener{Listener: l}
229}
230
231func client(t *testing.T, addr, header string, length int, terminate bool, wait time.Duration, done chan struct{}) {
232	c, err := net.Dial("tcp", addr)
233	if err != nil {
234		t.Fatalf("dial: %v", err)
235	}
236	defer c.Close()
237
238	if terminate && length < 2 {
239		length = 2
240	}
241
242	buf := make([]byte, len(header)+length)
243	copy(buf, []byte(header))
244	for i := 0; i < length-2; i++ {
245		buf[i+len(header)] = 0x20
246	}
247	if terminate {
248		copy(buf[len(header)+length-2:], []byte(crlf))
249	}
250
251	n, err := c.Write(buf)
252	if err != nil {
253		t.Fatalf("write: %v", err)
254	}
255	if n != len(buf) {
256		t.Fatalf("write; short write")
257	}
258
259	time.Sleep(wait)
260	close(done)
261}
262
263func TestVersion1Overflow(t *testing.T) {
264	done := make(chan struct{})
265
266	l := listen(t)
267	go client(t, l.Addr().String(), fixtureTCP6V1Overflow, 10240, true, 10*time.Second, done)
268
269	c, err := l.Accept()
270	if err != nil {
271		t.Fatalf("accept: %v", err)
272	}
273
274	b := []byte{}
275	_, err = c.Read(b)
276	if err == nil {
277		t.Fatalf("net.Conn: no error reported for oversized header")
278	}
279}
280
281func TestVersion1SlowLoris(t *testing.T) {
282	done := make(chan struct{})
283	timeout := make(chan error)
284
285	l := listen(t)
286	go client(t, l.Addr().String(), fixtureTCP6V1Overflow, 0, false, 10*time.Second, done)
287
288	c, err := l.Accept()
289	if err != nil {
290		t.Fatalf("accept: %v", err)
291	}
292
293	go func() {
294		b := []byte{}
295		_, err = c.Read(b)
296		timeout <- err
297	}()
298
299	select {
300	case <-done:
301		t.Fatalf("net.Conn: reader still blocked after 10 seconds")
302	case err := <-timeout:
303		if err == nil {
304			t.Fatalf("net.Conn: no error reported for incomplete header")
305		}
306	}
307}
308
309func TestVersion1SlowLorisOverflow(t *testing.T) {
310	done := make(chan struct{})
311	timeout := make(chan error)
312
313	l := listen(t)
314	go client(t, l.Addr().String(), fixtureTCP6V1Overflow, 10240, false, 10*time.Second, done)
315
316	c, err := l.Accept()
317	if err != nil {
318		t.Fatalf("accept: %v", err)
319	}
320
321	go func() {
322		b := []byte{}
323		_, err = c.Read(b)
324		timeout <- err
325	}()
326
327	select {
328	case <-done:
329		t.Fatalf("net.Conn: reader still blocked after 10 seconds")
330	case err := <-timeout:
331		if err == nil {
332			t.Fatalf("net.Conn: no error reported for incomplete and overflowed header")
333		}
334	}
335}
336