1package proxyproto
2
3import (
4	"bufio"
5	"bytes"
6	"errors"
7	"net"
8	"reflect"
9	"testing"
10	"time"
11)
12
// Shared fixtures used by both the v1 and v2 protocol tests.
14
// Textual inputs for the tests.
// NOTE(review): these ALL_CAPS names are not idiomatic Go (MixedCaps), but
// they are kept as-is because other files in the package may reference them.
const (
	NO_PROTOCOL   = "There is no spoon"
	IP4_ADDR      = "127.0.0.1"
	IP6_ADDR      = "::1"
	IP6_LONG_ADDR = "1234:5678:9abc:def0:cafe:babe:dead:2bad"
)

// Port numbers: PORT fits in 16 bits, INVALID_PORT deliberately does not
// (the largest 16-bit port is 65535).
const (
	PORT         = 65533
	INVALID_PORT = 99999
)

// Parsed forms of the address literals: v4ip is the 4-byte form and v6ip
// the 16-byte form.
var (
	v4ip = net.ParseIP(IP4_ADDR).To4()
	v6ip = net.ParseIP(IP6_ADDR).To16()
)

// Ready-made socket addresses, one per transport/family combination.
var (
	v4addr    net.Addr = &net.TCPAddr{IP: v4ip, Port: PORT}
	v6addr    net.Addr = &net.TCPAddr{IP: v6ip, Port: PORT}
	v4UDPAddr net.Addr = &net.UDPAddr{IP: v4ip, Port: PORT}
	v6UDPAddr net.Addr = &net.UDPAddr{IP: v6ip, Port: PORT}

	unixStreamAddr   net.Addr = &net.UnixAddr{Net: "unix", Name: "socket"}
	unixDatagramAddr net.Addr = &net.UnixAddr{Net: "unixgram", Name: "socket"}
)

// errReadIntentionallyBroken is the error errorReader.Read always returns.
var errReadIntentionallyBroken = errors.New("read is intentionally broken")
39
// timeoutReader is a reader that never delivers data. Each Read stalls for
// 500ms and then reports zero bytes with a nil error. TestReadTimeoutV1Invalid
// relies on that stall outlasting its 50ms timeout.
type timeoutReader []byte

func (r *timeoutReader) Read(_ []byte) (int, error) {
	const stall = 500 * time.Millisecond
	time.Sleep(stall)
	// Neither progress nor failure.
	return 0, nil
}
46
47type errorReader []byte
48
49func (e *errorReader) Read([]byte) (int, error) {
50	return 0, errReadIntentionallyBroken
51}
52
53func TestReadTimeoutV1Invalid(t *testing.T) {
54	var b timeoutReader
55	reader := bufio.NewReader(&b)
56	_, err := ReadTimeout(reader, 50*time.Millisecond)
57	if err == nil {
58		t.Fatalf("expected error %s", ErrNoProxyProtocol)
59	} else if err != ErrNoProxyProtocol {
60		t.Fatalf("expected %s, actual %s", ErrNoProxyProtocol, err)
61	}
62}
63
64func TestReadTimeoutPropagatesReadError(t *testing.T) {
65	var e errorReader
66	reader := bufio.NewReader(&e)
67	_, err := ReadTimeout(reader, 50*time.Millisecond)
68
69	if err == nil {
70		t.Fatalf("expected error %s", errReadIntentionallyBroken)
71	} else if err != errReadIntentionallyBroken {
72		t.Fatalf("expected error %s, actual %s", errReadIntentionallyBroken, err)
73	}
74}
75
76func TestEqualsTo(t *testing.T) {
77	var headersEqual = []struct {
78		this, that *Header
79		expected   bool
80	}{
81		{
82			&Header{
83				Version:           1,
84				Command:           PROXY,
85				TransportProtocol: TCPv4,
86				SourceAddr: &net.TCPAddr{
87					IP:   net.ParseIP("10.1.1.1"),
88					Port: 1000,
89				},
90				DestinationAddr: &net.TCPAddr{
91					IP:   net.ParseIP("20.2.2.2"),
92					Port: 2000,
93				},
94			},
95			nil,
96			false,
97		},
98		{
99			&Header{
100				Version:           1,
101				Command:           PROXY,
102				TransportProtocol: TCPv4,
103				SourceAddr: &net.TCPAddr{
104					IP:   net.ParseIP("10.1.1.1"),
105					Port: 1000,
106				},
107				DestinationAddr: &net.TCPAddr{
108					IP:   net.ParseIP("20.2.2.2"),
109					Port: 2000,
110				},
111			},
112			&Header{
113				Version:           2,
114				Command:           PROXY,
115				TransportProtocol: TCPv4,
116				SourceAddr: &net.TCPAddr{
117					IP:   net.ParseIP("10.1.1.1"),
118					Port: 1000,
119				},
120				DestinationAddr: &net.TCPAddr{
121					IP:   net.ParseIP("20.2.2.2"),
122					Port: 2000,
123				},
124			},
125			false,
126		},
127		{
128			&Header{
129				Version:           1,
130				Command:           PROXY,
131				TransportProtocol: TCPv4,
132				SourceAddr: &net.TCPAddr{
133					IP:   net.ParseIP("10.1.1.1"),
134					Port: 1000,
135				},
136				DestinationAddr: &net.TCPAddr{
137					IP:   net.ParseIP("20.2.2.2"),
138					Port: 2000,
139				},
140			},
141			&Header{
142				Version:           1,
143				Command:           PROXY,
144				TransportProtocol: TCPv4,
145				SourceAddr: &net.TCPAddr{
146					IP:   net.ParseIP("10.1.1.1"),
147					Port: 1000,
148				},
149				DestinationAddr: &net.TCPAddr{
150					IP:   net.ParseIP("20.2.2.2"),
151					Port: 2000,
152				},
153			},
154			true,
155		},
156	}
157
158	for _, tt := range headersEqual {
159		if actual := tt.this.EqualsTo(tt.that); actual != tt.expected {
160			t.Fatalf("expected %t, actual %t", tt.expected, actual)
161		}
162	}
163}
164
// TestEqualTo re-runs TestEqualsTo under a second test name. It adds no
// coverage of its own and is kept only for coveralls reporting.
func TestEqualTo(t *testing.T) {
	TestEqualsTo(t)
}
169
170func TestGetters(t *testing.T) {
171	var tests = []struct {
172		name                         string
173		header                       *Header
174		tcpSourceAddr, tcpDestAddr   *net.TCPAddr
175		udpSourceAddr, udpDestAddr   *net.UDPAddr
176		unixSourceAddr, unixDestAddr *net.UnixAddr
177		ipSource, ipDest             net.IP
178		portSource, portDest         int
179	}{
180		{
181			name: "TCPv4",
182			header: &Header{
183				Version:           1,
184				Command:           PROXY,
185				TransportProtocol: TCPv4,
186				SourceAddr: &net.TCPAddr{
187					IP:   net.ParseIP("10.1.1.1"),
188					Port: 1000,
189				},
190				DestinationAddr: &net.TCPAddr{
191					IP:   net.ParseIP("20.2.2.2"),
192					Port: 2000,
193				},
194			},
195			tcpSourceAddr: &net.TCPAddr{
196				IP:   net.ParseIP("10.1.1.1"),
197				Port: 1000,
198			},
199			tcpDestAddr: &net.TCPAddr{
200				IP:   net.ParseIP("20.2.2.2"),
201				Port: 2000,
202			},
203			ipSource:   net.ParseIP("10.1.1.1"),
204			ipDest:     net.ParseIP("20.2.2.2"),
205			portSource: 1000,
206			portDest:   2000,
207		},
208		{
209			name: "UDPv4",
210			header: &Header{
211				Version:           2,
212				Command:           PROXY,
213				TransportProtocol: UDPv6,
214				SourceAddr: &net.UDPAddr{
215					IP:   net.ParseIP("10.1.1.1"),
216					Port: 1000,
217				},
218				DestinationAddr: &net.UDPAddr{
219					IP:   net.ParseIP("20.2.2.2"),
220					Port: 2000,
221				},
222			},
223			udpSourceAddr: &net.UDPAddr{
224				IP:   net.ParseIP("10.1.1.1"),
225				Port: 1000,
226			},
227			udpDestAddr: &net.UDPAddr{
228				IP:   net.ParseIP("20.2.2.2"),
229				Port: 2000,
230			},
231			ipSource:   net.ParseIP("10.1.1.1"),
232			ipDest:     net.ParseIP("20.2.2.2"),
233			portSource: 1000,
234			portDest:   2000,
235		},
236		{
237			name: "UnixStream",
238			header: &Header{
239				Version:           2,
240				Command:           PROXY,
241				TransportProtocol: UnixStream,
242				SourceAddr: &net.UnixAddr{
243					Net:  "unix",
244					Name: "src",
245				},
246				DestinationAddr: &net.UnixAddr{
247					Net:  "unix",
248					Name: "dst",
249				},
250			},
251			unixSourceAddr: &net.UnixAddr{
252				Net:  "unix",
253				Name: "src",
254			},
255			unixDestAddr: &net.UnixAddr{
256				Net:  "unix",
257				Name: "dst",
258			},
259		},
260		{
261			name: "UnixDatagram",
262			header: &Header{
263				Version:           2,
264				Command:           PROXY,
265				TransportProtocol: UnixDatagram,
266				SourceAddr: &net.UnixAddr{
267					Net:  "unix",
268					Name: "src",
269				},
270				DestinationAddr: &net.UnixAddr{
271					Net:  "unix",
272					Name: "dst",
273				},
274			},
275			unixSourceAddr: &net.UnixAddr{
276				Net:  "unix",
277				Name: "src",
278			},
279			unixDestAddr: &net.UnixAddr{
280				Net:  "unix",
281				Name: "dst",
282			},
283		},
284		{
285			name: "Unspec",
286			header: &Header{
287				Version:           1,
288				Command:           PROXY,
289				TransportProtocol: UNSPEC,
290			},
291		},
292	}
293
294	for _, test := range tests {
295		t.Run(test.name, func(t *testing.T) {
296			tcpSourceAddr, tcpDestAddr, _ := test.header.TCPAddrs()
297			if test.tcpSourceAddr != nil && !reflect.DeepEqual(tcpSourceAddr, test.tcpSourceAddr) {
298				t.Errorf("TCPAddrs() source = %v, want %v", tcpSourceAddr, test.tcpSourceAddr)
299			}
300			if test.tcpDestAddr != nil && !reflect.DeepEqual(tcpDestAddr, test.tcpDestAddr) {
301				t.Errorf("TCPAddrs() dest = %v, want %v", tcpDestAddr, test.tcpDestAddr)
302			}
303
304			udpSourceAddr, udpDestAddr, _ := test.header.UDPAddrs()
305			if test.udpSourceAddr != nil && !reflect.DeepEqual(udpSourceAddr, test.udpSourceAddr) {
306				t.Errorf("TCPAddrs() source = %v, want %v", udpSourceAddr, test.udpSourceAddr)
307			}
308			if test.udpDestAddr != nil && !reflect.DeepEqual(udpDestAddr, test.udpDestAddr) {
309				t.Errorf("TCPAddrs() dest = %v, want %v", udpDestAddr, test.udpDestAddr)
310			}
311
312			unixSourceAddr, unixDestAddr, _ := test.header.UnixAddrs()
313			if test.unixSourceAddr != nil && !reflect.DeepEqual(unixSourceAddr, test.unixSourceAddr) {
314				t.Errorf("UnixAddrs() source = %v, want %v", unixSourceAddr, test.unixSourceAddr)
315			}
316			if test.unixDestAddr != nil && !reflect.DeepEqual(unixDestAddr, test.unixDestAddr) {
317				t.Errorf("UnixAddrs() dest = %v, want %v", unixDestAddr, test.unixDestAddr)
318			}
319
320			ipSource, ipDest, _ := test.header.IPs()
321			if test.ipSource != nil && !ipSource.Equal(test.ipSource) {
322				t.Errorf("IPs() source = %v, want %v", ipSource, test.ipSource)
323			}
324			if test.ipDest != nil && !ipDest.Equal(test.ipDest) {
325				t.Errorf("IPs() dest = %v, want %v", ipDest, test.ipDest)
326			}
327
328			portSource, portDest, _ := test.header.Ports()
329			if test.portSource != 0 && portSource != test.portSource {
330				t.Errorf("Ports() source = %v, want %v", portSource, test.portSource)
331			}
332			if test.portDest != 0 && portDest != test.portDest {
333				t.Errorf("Ports() dest = %v, want %v", portDest, test.portDest)
334			}
335		})
336	}
337}
338
339func TestSetTLVs(t *testing.T) {
340	tests := []struct {
341		header    *Header
342		name      string
343		tlvs      []TLV
344		expectErr bool
345	}{
346		{
347			name: "add authority TLV",
348			header: &Header{
349				Version:           1,
350				Command:           PROXY,
351				TransportProtocol: TCPv4,
352				SourceAddr: &net.TCPAddr{
353					IP:   net.ParseIP("10.1.1.1"),
354					Port: 1000,
355				},
356				DestinationAddr: &net.TCPAddr{
357					IP:   net.ParseIP("20.2.2.2"),
358					Port: 2000,
359				},
360			},
361			tlvs: []TLV{{
362				Type:  PP2_TYPE_AUTHORITY,
363				Value: []byte("example.org"),
364			}},
365		},
366		{
367			name: "add too long TLV",
368			header: &Header{
369				Version:           1,
370				Command:           PROXY,
371				TransportProtocol: TCPv4,
372				SourceAddr: &net.TCPAddr{
373					IP:   net.ParseIP("10.1.1.1"),
374					Port: 1000,
375				},
376				DestinationAddr: &net.TCPAddr{
377					IP:   net.ParseIP("20.2.2.2"),
378					Port: 2000,
379				},
380			},
381			tlvs: []TLV{{
382				Type:  PP2_TYPE_AUTHORITY,
383				Value: append(bytes.Repeat([]byte("a"), 0xFFFF), []byte(".example.org")...),
384			}},
385			expectErr: true,
386		},
387	}
388	for _, tt := range tests {
389		err := tt.header.SetTLVs(tt.tlvs)
390		if err != nil && !tt.expectErr {
391			t.Fatalf("shouldn't have thrown error %q", err.Error())
392		}
393	}
394}
395
396func TestWriteTo(t *testing.T) {
397	var buf bytes.Buffer
398
399	validHeader := &Header{
400		Version:           1,
401		Command:           PROXY,
402		TransportProtocol: TCPv4,
403		SourceAddr: &net.TCPAddr{
404			IP:   net.ParseIP("10.1.1.1"),
405			Port: 1000,
406		},
407		DestinationAddr: &net.TCPAddr{
408			IP:   net.ParseIP("20.2.2.2"),
409			Port: 2000,
410		},
411	}
412
413	if _, err := validHeader.WriteTo(&buf); err != nil {
414		t.Fatalf("shouldn't have thrown error %q", err.Error())
415	}
416
417	invalidHeader := &Header{
418		SourceAddr: &net.TCPAddr{
419			IP:   net.ParseIP("10.1.1.1"),
420			Port: 1000,
421		},
422		DestinationAddr: &net.TCPAddr{
423			IP:   net.ParseIP("20.2.2.2"),
424			Port: 2000,
425		},
426	}
427
428	if _, err := invalidHeader.WriteTo(&buf); err == nil {
429		t.Fatalf("should have thrown error %q", err.Error())
430	}
431}
432
433func TestFormat(t *testing.T) {
434	validHeader := &Header{
435		Version:           1,
436		Command:           PROXY,
437		TransportProtocol: TCPv4,
438		SourceAddr: &net.TCPAddr{
439			IP:   net.ParseIP("10.1.1.1"),
440			Port: 1000,
441		},
442		DestinationAddr: &net.TCPAddr{
443			IP:   net.ParseIP("20.2.2.2"),
444			Port: 2000,
445		},
446	}
447
448	if _, err := validHeader.Format(); err != nil {
449		t.Fatalf("shouldn't have thrown error %q", err.Error())
450	}
451}
452
453func TestFormatInvalid(t *testing.T) {
454	tests := []struct {
455		name   string
456		header *Header
457		err    error
458	}{
459		{
460			name: "invalidVersion",
461			header: &Header{
462				Version:           3,
463				Command:           PROXY,
464				TransportProtocol: TCPv4,
465				SourceAddr:        v4addr,
466				DestinationAddr:   v4addr,
467			},
468			err: ErrUnknownProxyProtocolVersion,
469		},
470		{
471			name: "v2MismatchTCPv4_UDPv4",
472			header: &Header{
473				Version:           2,
474				Command:           PROXY,
475				TransportProtocol: TCPv4,
476				SourceAddr:        v4UDPAddr,
477				DestinationAddr:   v4addr,
478			},
479			err: ErrInvalidAddress,
480		},
481		{
482			name: "v2MismatchTCPv4_TCPv6",
483			header: &Header{
484				Version:           2,
485				Command:           PROXY,
486				TransportProtocol: TCPv4,
487				SourceAddr:        v4addr,
488				DestinationAddr:   v6addr,
489			},
490			err: ErrInvalidAddress,
491		},
492		{
493			name: "v2MismatchUnixStream_TCPv4",
494			header: &Header{
495				Version:           2,
496				Command:           PROXY,
497				TransportProtocol: UnixStream,
498				SourceAddr:        v4addr,
499				DestinationAddr:   unixStreamAddr,
500			},
501			err: ErrInvalidAddress,
502		},
503		{
504			name: "v1MismatchTCPv4_TCPv6",
505			header: &Header{
506				Version:           1,
507				Command:           PROXY,
508				TransportProtocol: TCPv4,
509				SourceAddr:        v6addr,
510				DestinationAddr:   v4addr,
511			},
512			err: ErrInvalidAddress,
513		},
514		{
515			name: "v1MismatchTCPv4_UDPv4",
516			header: &Header{
517				Version:           1,
518				Command:           PROXY,
519				TransportProtocol: TCPv4,
520				SourceAddr:        v4UDPAddr,
521				DestinationAddr:   v4addr,
522			},
523			err: ErrInvalidAddress,
524		},
525	}
526
527	for _, test := range tests {
528		t.Run(test.name, func(t *testing.T) {
529			if _, err := test.header.Format(); err == nil {
530				t.Errorf("Header.Format() succeeded, want an error")
531			} else if err != test.err {
532				t.Errorf("Header.Format() = %q, want %q", err, test.err)
533			}
534		})
535	}
536}
537
538func TestHeaderProxyFromAddrs(t *testing.T) {
539	unspec := &Header{
540		Version:           2,
541		Command:           LOCAL,
542		TransportProtocol: UNSPEC,
543	}
544
545	tests := []struct {
546		name                 string
547		version              byte
548		sourceAddr, destAddr net.Addr
549		expected             *Header
550	}{
551		{
552			name: "TCPv4",
553			sourceAddr: &net.TCPAddr{
554				IP:   net.ParseIP("10.1.1.1"),
555				Port: 1000,
556			},
557			destAddr: &net.TCPAddr{
558				IP:   net.ParseIP("20.2.2.2"),
559				Port: 2000,
560			},
561			expected: &Header{
562				Version:           2,
563				Command:           PROXY,
564				TransportProtocol: TCPv4,
565				SourceAddr: &net.TCPAddr{
566					IP:   net.ParseIP("10.1.1.1"),
567					Port: 1000,
568				},
569				DestinationAddr: &net.TCPAddr{
570					IP:   net.ParseIP("20.2.2.2"),
571					Port: 2000,
572				},
573			},
574		},
575		{
576			name: "TCPv6",
577			sourceAddr: &net.TCPAddr{
578				IP:   net.ParseIP("fde7::372"),
579				Port: 1000,
580			},
581			destAddr: &net.TCPAddr{
582				IP:   net.ParseIP("fde7::1"),
583				Port: 2000,
584			},
585			expected: &Header{
586				Version:           2,
587				Command:           PROXY,
588				TransportProtocol: TCPv6,
589				SourceAddr: &net.TCPAddr{
590					IP:   net.ParseIP("fde7::372"),
591					Port: 1000,
592				},
593				DestinationAddr: &net.TCPAddr{
594					IP:   net.ParseIP("fde7::1"),
595					Port: 2000,
596				},
597			},
598		},
599		{
600			name: "UDPv4",
601			sourceAddr: &net.UDPAddr{
602				IP:   net.ParseIP("10.1.1.1"),
603				Port: 1000,
604			},
605			destAddr: &net.UDPAddr{
606				IP:   net.ParseIP("20.2.2.2"),
607				Port: 2000,
608			},
609			expected: &Header{
610				Version:           2,
611				Command:           PROXY,
612				TransportProtocol: UDPv4,
613				SourceAddr: &net.TCPAddr{
614					IP:   net.ParseIP("10.1.1.1"),
615					Port: 1000,
616				},
617				DestinationAddr: &net.TCPAddr{
618					IP:   net.ParseIP("20.2.2.2"),
619					Port: 2000,
620				},
621			},
622		},
623		{
624			name: "UDPv6",
625			sourceAddr: &net.UDPAddr{
626				IP:   net.ParseIP("fde7::372"),
627				Port: 1000,
628			},
629			destAddr: &net.UDPAddr{
630				IP:   net.ParseIP("fde7::1"),
631				Port: 2000,
632			},
633			expected: &Header{
634				Version:           2,
635				Command:           PROXY,
636				TransportProtocol: UDPv6,
637				SourceAddr: &net.TCPAddr{
638					IP:   net.ParseIP("fde7::372"),
639					Port: 1000,
640				},
641				DestinationAddr: &net.TCPAddr{
642					IP:   net.ParseIP("fde7::1"),
643					Port: 2000,
644				},
645			},
646		},
647		{
648			name: "UnixStream",
649			sourceAddr: &net.UnixAddr{
650				Net:  "unix",
651				Name: "src",
652			},
653			destAddr: &net.UnixAddr{
654				Net:  "unix",
655				Name: "dst",
656			},
657			expected: &Header{
658				Version:           2,
659				Command:           PROXY,
660				TransportProtocol: UnixStream,
661				SourceAddr: &net.UnixAddr{
662					Net:  "unix",
663					Name: "src",
664				},
665				DestinationAddr: &net.UnixAddr{
666					Net:  "unix",
667					Name: "dst",
668				},
669			},
670		},
671		{
672			name: "UnixDatagram",
673			sourceAddr: &net.UnixAddr{
674				Net:  "unixgram",
675				Name: "src",
676			},
677			destAddr: &net.UnixAddr{
678				Net:  "unixgram",
679				Name: "dst",
680			},
681			expected: &Header{
682				Version:           2,
683				Command:           PROXY,
684				TransportProtocol: UnixDatagram,
685				SourceAddr: &net.UnixAddr{
686					Net:  "unixgram",
687					Name: "src",
688				},
689				DestinationAddr: &net.UnixAddr{
690					Net:  "unixgram",
691					Name: "dst",
692				},
693			},
694		},
695		{
696			name:    "Version1",
697			version: 1,
698			sourceAddr: &net.TCPAddr{
699				IP:   net.ParseIP("10.1.1.1"),
700				Port: 1000,
701			},
702			destAddr: &net.TCPAddr{
703				IP:   net.ParseIP("20.2.2.2"),
704				Port: 2000,
705			},
706			expected: &Header{
707				Version:           1,
708				Command:           PROXY,
709				TransportProtocol: TCPv4,
710				SourceAddr: &net.TCPAddr{
711					IP:   net.ParseIP("10.1.1.1"),
712					Port: 1000,
713				},
714				DestinationAddr: &net.TCPAddr{
715					IP:   net.ParseIP("20.2.2.2"),
716					Port: 2000,
717				},
718			},
719		},
720		{
721			name: "TCPInvalidIP",
722			sourceAddr: &net.TCPAddr{
723				IP:   nil,
724				Port: 1000,
725			},
726			destAddr: &net.TCPAddr{
727				IP:   nil,
728				Port: 2000,
729			},
730			expected: unspec,
731		},
732		{
733			name: "UDPInvalidIP",
734			sourceAddr: &net.UDPAddr{
735				IP:   nil,
736				Port: 1000,
737			},
738			destAddr: &net.UDPAddr{
739				IP:   nil,
740				Port: 2000,
741			},
742			expected: unspec,
743		},
744		{
745			name: "TCPAddrTypeMismatch",
746			sourceAddr: &net.TCPAddr{
747				IP:   net.ParseIP("10.1.1.1"),
748				Port: 1000,
749			},
750			destAddr: &net.UDPAddr{
751				IP:   net.ParseIP("20.2.2.2"),
752				Port: 2000,
753			},
754			expected: unspec,
755		},
756		{
757			name: "UDPAddrTypeMismatch",
758			sourceAddr: &net.UDPAddr{
759				IP:   net.ParseIP("10.1.1.1"),
760				Port: 1000,
761			},
762			destAddr: &net.TCPAddr{
763				IP:   net.ParseIP("20.2.2.2"),
764				Port: 2000,
765			},
766			expected: unspec,
767		},
768		{
769			name: "UnixAddrTypeMismatch",
770			sourceAddr: &net.UnixAddr{
771				Net: "unix",
772			},
773			destAddr: &net.TCPAddr{
774				IP:   net.ParseIP("20.2.2.2"),
775				Port: 2000,
776			},
777			expected: unspec,
778		},
779	}
780
781	for _, tt := range tests {
782		t.Run(tt.name, func(t *testing.T) {
783			h := HeaderProxyFromAddrs(tt.version, tt.sourceAddr, tt.destAddr)
784
785			if !h.EqualsTo(tt.expected) {
786				t.Errorf("expected %+v, actual %+v for source %+v and destination %+v", tt.expected, h, tt.sourceAddr, tt.destAddr)
787			}
788		})
789	}
790}
791