1package proxyproto
2
3import (
4	"bufio"
5	"bytes"
6	"encoding/binary"
7	"math/rand"
8	"reflect"
9	"testing"
10)
11
12var (
13	invalidRune = byte('\x99')
14
15	// Lengths to use in tests
16	lengthPadded = uint16(84)
17
18	lengthEmptyBytes = func() []byte {
19		a := make([]byte, 2)
20		binary.BigEndian.PutUint16(a, 0)
21		return a
22	}()
23	lengthPaddedBytes = func() []byte {
24		a := make([]byte, 2)
25		binary.BigEndian.PutUint16(a, lengthPadded)
26		return a
27	}()
28
29	// If life gives you lemons, make mojitos
30	portBytes = func() []byte {
31		a := make([]byte, 2)
32		binary.BigEndian.PutUint16(a, PORT)
33		return a
34	}()
35
36	unixBytes = pad([]byte("socket"), 108)
37
38	// Tests don't care if source and destination addresses and ports are the same
39	addressesIPv4 = append(v4ip.To4(), v4ip.To4()...)
40	addressesIPv6 = append(v6ip.To16(), v6ip.To16()...)
41	ports         = append(portBytes, portBytes...)
42
43	// Fixtures to use in tests
44	fixtureIPv4Address  = append(addressesIPv4, ports...)
45	fixtureIPv4V2       = append(lengthV4Bytes, fixtureIPv4Address...)
46	fixtureIPv4V2Padded = append(append(lengthPaddedBytes, fixtureIPv4Address...), make([]byte, lengthPadded-lengthV4)...)
47	fixtureIPv6Address  = append(addressesIPv6, ports...)
48	fixtureIPv6V2       = append(lengthV6Bytes, fixtureIPv6Address...)
49	fixtureIPv6V2Padded = append(append(lengthPaddedBytes, fixtureIPv6Address...), make([]byte, lengthPadded-lengthV6)...)
50	fixtureUnixAddress  = append(unixBytes, unixBytes...)
51	fixtureUnixV2       = append(lengthUnixBytes, fixtureUnixAddress...)
52	fixtureTLV          = func() []byte {
53		tlv := make([]byte, 2+rand.Intn(1<<12)) // Not enough to overflow, at least size two
54		rand.Read(tlv)
55		return tlv
56	}()
57	fixtureIPv4V2TLV = fixtureWithTLV(lengthV4Bytes, fixtureIPv4Address, fixtureTLV)
58	fixtureIPv6V2TLV = fixtureWithTLV(lengthV6Bytes, fixtureIPv6Address, fixtureTLV)
59	fixtureUnspecTLV = fixtureWithTLV(lengthUnspecBytes, []byte{}, fixtureTLV)
60
61	// Arbitrary bytes following proxy bytes
62	arbitraryTailBytes = []byte{'\x99', '\x97', '\x98'}
63)
64
// pad returns a new slice of exactly n bytes holding b followed by zero bytes,
// e.g. to build fixed-width address fields.
//
// The result never shares memory with b. Appending the padding to b directly
// could write zeros into b's spare capacity and clobber data another slice
// keeps there. pad panics if b is longer than n, because the input cannot fit
// into the requested width.
func pad(b []byte, n int) []byte {
	if len(b) > n {
		panic("pad: input is longer than the requested width")
	}
	out := make([]byte, n)
	copy(out, b)
	return out
}
69
// invalidParseV2Tests lists malformed or truncated v2 inputs together with the
// error Read is expected to return for each one.
//
// NOTE(review): each reader is created once at package initialization and is
// consumed by the first Read, so this table supports a single pass only (a
// repeated run such as `go test -count=2` would read from exhausted readers).
var invalidParseV2Tests = []struct {
	desc          string
	reader        *bufio.Reader
	expectedError error
}{
	{
		desc:          "no signature",
		reader:        newBufioReader([]byte(NO_PROTOCOL)),
		expectedError: ErrNoProxyProtocol,
	},
	{
		// SIGV2[2:] drops the first two signature bytes, so no signature matches.
		desc:          "truncated v2 signature",
		reader:        newBufioReader(SIGV2[2:]),
		expectedError: ErrNoProxyProtocol,
	},
	{
		desc:          "v2 signature and nothing else",
		reader:        newBufioReader(SIGV2),
		expectedError: ErrCantReadProtocolVersionAndCommand,
	},
	{
		desc:          "v2 signature with invalid command",
		reader:        newBufioReader(append(SIGV2, invalidRune)),
		expectedError: ErrUnsupportedProtocolVersionAndCommand,
	},
	{
		desc:          "v2 signature with command but nothing else",
		reader:        newBufioReader(append(SIGV2, byte(PROXY))),
		expectedError: ErrCantReadAddressFamilyAndProtocol,
	},
	{
		desc:          "command proxy but inet family unspec",
		reader:        newBufioReader(append(SIGV2, byte(PROXY), byte(UNSPEC))),
		expectedError: ErrUnsupportedAddressFamilyAndProtocol,
	},
	{
		desc:          "v2 signature with command and invalid inet family", // translated to UNSPEC
		reader:        newBufioReader(append(SIGV2, byte(PROXY), invalidRune)),
		expectedError: ErrCantReadLength,
	},
	{
		desc:          "TCPv4 but no length",
		reader:        newBufioReader(append(SIGV2, byte(PROXY), byte(TCPv4))),
		expectedError: ErrCantReadLength,
	},
	{
		// Only one of the two length bytes is present.
		desc:          "TCPv4 but invalid length",
		reader:        newBufioReader(append(SIGV2, byte(PROXY), byte(TCPv4), invalidRune)),
		expectedError: ErrCantReadLength,
	},
	{
		desc:          "unspec but no length",
		reader:        newBufioReader(append(SIGV2, byte(LOCAL), byte(UNSPEC))),
		expectedError: ErrCantReadLength,
	},
	{
		// The length field announces an address block, but no address bytes follow.
		desc:          "TCPv4 with mismatching length",
		reader:        newBufioReader(append(append(SIGV2, byte(PROXY), byte(TCPv4)), lengthV4Bytes...)),
		expectedError: ErrInvalidLength,
	},
	{
		// The length field announces an address block, but no address bytes follow.
		desc:          "TCPv6 with mismatching length",
		reader:        newBufioReader(append(append(SIGV2, byte(PROXY), byte(TCPv6)), lengthV6Bytes...)),
		expectedError: ErrInvalidLength,
	},
	{
		// Any trailing address bytes will do here; the zero length is what must be rejected.
		desc:          "TCPv4 length zero but with address and ports",
		reader:        newBufioReader(append(append(append(SIGV2, byte(PROXY), byte(TCPv4)), lengthEmptyBytes...), fixtureIPv6Address...)),
		expectedError: ErrInvalidLength,
	},
	{
		// Fewer address bytes follow than the IPv6 length field announces.
		desc:          "TCPv6 with IPv6 length but IPv4 address and ports",
		reader:        newBufioReader(append(append(append(SIGV2, byte(PROXY), byte(TCPv6)), lengthV6Bytes...), fixtureIPv4Address...)),
		expectedError: ErrInvalidLength,
	},
	{
		// fixtureUnspecTLV[:2] is only its leading 2-byte length field, which is
		// non-zero because fixtureTLV holds at least two bytes; no TLV bytes follow.
		desc:          "unspec length greater than zero but no TLVs",
		reader:        newBufioReader(append(append(SIGV2, byte(LOCAL), byte(UNSPEC)), fixtureUnspecTLV[:2]...)),
		expectedError: ErrInvalidLength,
	},
}
151
152func TestParseV2Invalid(t *testing.T) {
153	for _, tt := range invalidParseV2Tests {
154		t.Run(tt.desc, func(t *testing.T) {
155			if _, err := Read(tt.reader); err != tt.expectedError {
156				t.Fatalf("expected %s, actual %s", tt.expectedError, err.Error())
157			}
158		})
159	}
160}
161
// validParseAndWriteV2Tests pairs well-formed v2 inputs with the header Read is
// expected to produce from them. The reader field drives the parse tests. The
// expectedHeader field is also usable as input for write round-trip tests.
//
// NOTE(review): each reader is created once at package initialization and is
// consumed by the first Read, so parsing from this table supports a single
// pass only (a repeated run such as `go test -count=2` would read from
// exhausted readers).
var validParseAndWriteV2Tests = []struct {
	desc           string
	reader         *bufio.Reader
	expectedHeader *Header
}{
	{
		// A LOCAL command here still carries a full TCPv4 address block.
		desc:   "local",
		reader: newBufioReader(append(append(SIGV2, byte(LOCAL), byte(TCPv4)), fixtureIPv4V2...)),
		expectedHeader: &Header{
			Version:           2,
			Command:           LOCAL,
			TransportProtocol: TCPv4,
			SourceAddr:        v4addr,
			DestinationAddr:   v4addr,
		},
	},
	{
		desc:   "local unspec",
		reader: newBufioReader(append(append(SIGV2, byte(LOCAL), byte(UNSPEC)), lengthUnspecBytes...)),
		expectedHeader: &Header{
			Version:           2,
			Command:           LOCAL,
			TransportProtocol: UNSPEC,
			SourceAddr:        nil,
			DestinationAddr:   nil,
		},
	},
	{
		desc:   "proxy TCPv4",
		reader: newBufioReader(append(append(SIGV2, byte(PROXY), byte(TCPv4)), fixtureIPv4V2...)),
		expectedHeader: &Header{
			Version:           2,
			Command:           PROXY,
			TransportProtocol: TCPv4,
			SourceAddr:        v4addr,
			DestinationAddr:   v4addr,
		},
	},
	{
		desc:   "proxy TCPv6",
		reader: newBufioReader(append(append(SIGV2, byte(PROXY), byte(TCPv6)), fixtureIPv6V2...)),
		expectedHeader: &Header{
			Version:           2,
			Command:           PROXY,
			TransportProtocol: TCPv6,
			SourceAddr:        v6addr,
			DestinationAddr:   v6addr,
		},
	},
	{
		// The fixture's length field covers the addresses plus fixtureTLV,
		// whose raw bytes must surface unchanged in rawTLVs.
		desc:   "proxy TCPv4 with TLV",
		reader: newBufioReader(append(append(SIGV2, byte(PROXY), byte(TCPv4)), fixtureIPv4V2TLV...)),
		expectedHeader: &Header{
			Version:           2,
			Command:           PROXY,
			TransportProtocol: TCPv4,
			SourceAddr:        v4addr,
			DestinationAddr:   v4addr,
			rawTLVs:           fixtureTLV,
		},
	},
	{
		desc:   "proxy TCPv6 with TLV",
		reader: newBufioReader(append(append(SIGV2, byte(PROXY), byte(TCPv6)), fixtureIPv6V2TLV...)),
		expectedHeader: &Header{
			Version:           2,
			Command:           PROXY,
			TransportProtocol: TCPv6,
			SourceAddr:        v6addr,
			DestinationAddr:   v6addr,
			rawTLVs:           fixtureTLV,
		},
	},
	{
		// No address block at all: the length field covers only the TLV bytes.
		desc:   "local unspec with TLV",
		reader: newBufioReader(append(append(SIGV2, byte(LOCAL), byte(UNSPEC)), fixtureUnspecTLV...)),
		expectedHeader: &Header{
			Version:           2,
			Command:           LOCAL,
			TransportProtocol: UNSPEC,
			SourceAddr:        nil,
			DestinationAddr:   nil,
			rawTLVs:           fixtureTLV,
		},
	},
	{
		desc:   "proxy UDPv4",
		reader: newBufioReader(append(append(SIGV2, byte(PROXY), byte(UDPv4)), fixtureIPv4V2...)),
		expectedHeader: &Header{
			Version:           2,
			Command:           PROXY,
			TransportProtocol: UDPv4,
			SourceAddr:        v4UDPAddr,
			DestinationAddr:   v4UDPAddr,
		},
	},
	{
		desc:   "proxy UDPv6",
		reader: newBufioReader(append(append(SIGV2, byte(PROXY), byte(UDPv6)), fixtureIPv6V2...)),
		expectedHeader: &Header{
			Version:           2,
			Command:           PROXY,
			TransportProtocol: UDPv6,
			SourceAddr:        v6UDPAddr,
			DestinationAddr:   v6UDPAddr,
		},
	},
	{
		desc:   "proxy unix stream",
		reader: newBufioReader(append(append(SIGV2, byte(PROXY), byte(UnixStream)), fixtureUnixV2...)),
		expectedHeader: &Header{
			Version:           2,
			Command:           PROXY,
			TransportProtocol: UnixStream,
			SourceAddr:        unixStreamAddr,
			DestinationAddr:   unixStreamAddr,
		},
	},
	{
		desc:   "proxy unix datagram",
		reader: newBufioReader(append(append(SIGV2, byte(PROXY), byte(UnixDatagram)), fixtureUnixV2...)),
		expectedHeader: &Header{
			Version:           2,
			Command:           PROXY,
			TransportProtocol: UnixDatagram,
			SourceAddr:        unixDatagramAddr,
			DestinationAddr:   unixDatagramAddr,
		},
	},
}
292
293func TestParseV2Valid(t *testing.T) {
294	for _, tt := range validParseAndWriteV2Tests {
295		t.Run(tt.desc, func(t *testing.T) {
296			header, err := Read(tt.reader)
297			if err != nil {
298				t.Fatal("unexpected error", err.Error())
299			}
300			if !header.EqualsTo(tt.expectedHeader) {
301				t.Fatalf("expected %#v, actual %#v", tt.expectedHeader, header)
302			}
303		})
304	}
305}
306
307func TestWriteV2Valid(t *testing.T) {
308	for _, tt := range validParseAndWriteV2Tests {
309		t.Run(tt.desc, func(t *testing.T) {
310			var b bytes.Buffer
311			w := bufio.NewWriter(&b)
312			if _, err := tt.expectedHeader.WriteTo(w); err != nil {
313				t.Fatal("unexpected error ", err)
314			}
315			w.Flush()
316
317			// Read written bytes to validate written header
318			r := bufio.NewReader(&b)
319			newHeader, err := Read(r)
320			if err != nil {
321				t.Fatal("unexpected error ", err)
322			}
323
324			if !newHeader.EqualsTo(tt.expectedHeader) {
325				t.Fatalf("expected %#v, actual %#v", tt.expectedHeader, newHeader)
326			}
327		})
328	}
329}
330
331var validParseV2PaddedTests = []struct {
332	desc           string
333	value          []byte
334	expectedHeader *Header
335}{
336	{
337		desc:  "proxy TCPv4",
338		value: append(append(SIGV2, byte(PROXY), byte(TCPv4)), fixtureIPv4V2Padded...),
339		expectedHeader: &Header{
340			Version:           2,
341			Command:           PROXY,
342			TransportProtocol: TCPv4,
343			SourceAddr:        v4addr,
344			DestinationAddr:   v4addr,
345			rawTLVs:           make([]byte, lengthPadded-lengthV4),
346		},
347	},
348	{
349		desc:  "proxy TCPv6",
350		value: append(append(SIGV2, byte(PROXY), byte(TCPv6)), fixtureIPv6V2Padded...),
351		expectedHeader: &Header{
352			Version:           2,
353			Command:           PROXY,
354			TransportProtocol: TCPv6,
355			SourceAddr:        v6addr,
356			DestinationAddr:   v6addr,
357			rawTLVs:           make([]byte, lengthPadded-lengthV6),
358		},
359	},
360	{
361		desc:  "proxy UDPv4",
362		value: append(append(SIGV2, byte(PROXY), byte(UDPv4)), fixtureIPv4V2Padded...),
363		expectedHeader: &Header{
364			Version:           2,
365			Command:           PROXY,
366			TransportProtocol: UDPv4,
367			SourceAddr:        v4addr,
368			DestinationAddr:   v4addr,
369			rawTLVs:           make([]byte, lengthPadded-lengthV4),
370		},
371	},
372	{
373		desc:  "proxy UDPv6",
374		value: append(append(SIGV2, byte(PROXY), byte(UDPv6)), fixtureIPv6V2Padded...),
375		expectedHeader: &Header{
376			Version:           2,
377			Command:           PROXY,
378			TransportProtocol: UDPv6,
379			SourceAddr:        v6addr,
380			DestinationAddr:   v6addr,
381			rawTLVs:           make([]byte, lengthPadded-lengthV6),
382		},
383	},
384}
385
386func TestParseV2Padded(t *testing.T) {
387	for _, tt := range validParseV2PaddedTests {
388		t.Run(tt.desc, func(t *testing.T) {
389			reader := newBufioReader(append(tt.value, arbitraryTailBytes...))
390
391			newHeader, err := Read(reader)
392			if err != nil {
393				t.Fatal("unexpected error ", err)
394			}
395			if !newHeader.EqualsTo(tt.expectedHeader) {
396				t.Fatalf("expected %#v, actual %#v", tt.expectedHeader, newHeader)
397			}
398
399			// Check that remaining padding bytes have been flushed
400			nextBytes, err := reader.Peek(len(arbitraryTailBytes))
401			if err != nil {
402				t.Fatal("unexpected error ", err)
403			}
404			if !reflect.DeepEqual(nextBytes, arbitraryTailBytes) {
405				t.Fatalf("expected %#v, actual %#v", arbitraryTailBytes, nextBytes)
406			}
407		})
408	}
409}
410
411func TestV2EqualsToTLV(t *testing.T) {
412	eHdr := &Header{
413		Version:           2,
414		Command:           PROXY,
415		TransportProtocol: TCPv4,
416		SourceAddr:        v4addr,
417		DestinationAddr:   v4addr,
418	}
419	hdr, err := Read(newBufioReader(append(append(SIGV2, byte(PROXY), byte(TCPv4)), fixtureIPv4V2TLV...)))
420	if err != nil {
421		t.Fatal("unexpected error ", err)
422	}
423	if eHdr.EqualsTo(hdr) {
424		t.Fatalf("unexpectedly equal created: %#v, parsed: %#v", eHdr, hdr)
425	}
426	eHdr.rawTLVs = fixtureTLV[:]
427
428	if !eHdr.EqualsTo(hdr) {
429		t.Fatalf("unexpectedly unequal after tlv copy created: %#v, parsed: %#v", eHdr, hdr)
430	}
431
432	eHdr.rawTLVs[0] = eHdr.rawTLVs[0] + 1
433	if eHdr.EqualsTo(hdr) {
434		t.Fatalf("unexpectedly equal after changing tlv created: %#v, parsed: %#v", eHdr, hdr)
435	}
436}
437
438var tlvFormatTests = []struct {
439	desc   string
440	header *Header
441}{
442	{
443		desc: "proxy TCPv4",
444		header: &Header{
445			Version:           2,
446			Command:           PROXY,
447			TransportProtocol: TCPv4,
448			SourceAddr:        v4addr,
449			DestinationAddr:   v4addr,
450			rawTLVs:           make([]byte, 1<<16),
451		},
452	},
453	{
454		desc: "proxy TCPv6",
455		header: &Header{
456			Version:           2,
457			Command:           PROXY,
458			TransportProtocol: TCPv6,
459			SourceAddr:        v6addr,
460			DestinationAddr:   v6addr,
461			rawTLVs:           make([]byte, 1<<16),
462		},
463	},
464	{
465		desc: "proxy UDPv4",
466		header: &Header{
467			Version:           2,
468			Command:           PROXY,
469			TransportProtocol: UDPv4,
470			SourceAddr:        v4addr,
471			DestinationAddr:   v4addr,
472			rawTLVs:           make([]byte, 1<<16),
473		},
474	},
475	{
476		desc: "proxy UDPv6",
477		header: &Header{
478			Version:           2,
479			Command:           PROXY,
480			TransportProtocol: UDPv6,
481			SourceAddr:        v6addr,
482			DestinationAddr:   v6addr,
483			rawTLVs:           make([]byte, 1<<16),
484		},
485	},
486	{
487		desc: "local unspec",
488		header: &Header{
489			Version:           2,
490			Command:           LOCAL,
491			TransportProtocol: UNSPEC,
492			SourceAddr:        nil,
493			DestinationAddr:   nil,
494			rawTLVs:           make([]byte, 1<<16),
495		},
496	},
497}
498
499func TestV2TLVFormatTooLargeTLV(t *testing.T) {
500	for _, tt := range tlvFormatTests {
501		t.Run(tt.desc, func(t *testing.T) {
502			if _, err := tt.header.Format(); err != errUint16Overflow {
503				t.Fatalf("missing or expected error when formatting too-large TLV %#v", err)
504			}
505		})
506
507	}
508}
509
// newBufioReader wraps the fixture bytes b in a buffered reader, the form the
// tests in this file hand to Read.
func newBufioReader(b []byte) *bufio.Reader {
	src := bytes.NewReader(b)
	return bufio.NewReader(src)
}
513
514func fixtureWithTLV(cur []byte, addr []byte, tlv []byte) []byte {
515	tlen, err := addTLVLen(cur, len(tlv))
516	if err != nil {
517		panic(err)
518	}
519
520	return append(append(tlen, addr...), tlv...)
521}
522