1package http
2
3import (
4	"bytes"
5	"net/http"
6	"strings"
7	"testing"
8
9	sockaddr "github.com/hashicorp/go-sockaddr"
10	"github.com/hashicorp/vault/vault"
11)
12
13func TestHandler_XForwardedFor(t *testing.T) {
14	goodAddr, err := sockaddr.NewIPAddr("127.0.0.1")
15	if err != nil {
16		t.Fatal(err)
17	}
18
19	badAddr, err := sockaddr.NewIPAddr("1.2.3.4")
20	if err != nil {
21		t.Fatal(err)
22	}
23
24	// First: test reject not present
25	t.Run("reject_not_present", func(t *testing.T) {
26		t.Parallel()
27		testHandler := func(props *vault.HandlerProperties) http.Handler {
28			origHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
29				w.WriteHeader(http.StatusOK)
30				w.Write([]byte(r.RemoteAddr))
31			})
32			return WrapForwardedForHandler(origHandler, []*sockaddr.SockAddrMarshaler{
33				&sockaddr.SockAddrMarshaler{
34					SockAddr: goodAddr,
35				},
36			}, true, false, 0)
37		}
38
39		cluster := vault.NewTestCluster(t, nil, &vault.TestClusterOptions{
40			HandlerFunc: testHandler,
41		})
42		cluster.Start()
43		defer cluster.Cleanup()
44		client := cluster.Cores[0].Client
45
46		req := client.NewRequest("GET", "/")
47		_, err := client.RawRequest(req)
48		if err == nil {
49			t.Fatal("expected error")
50		}
51		if !strings.Contains(err.Error(), "missing x-forwarded-for") {
52			t.Fatalf("bad error message: %v", err)
53		}
54		req = client.NewRequest("GET", "/")
55		req.Headers = make(http.Header)
56		req.Headers.Set("x-forwarded-for", "1.2.3.4")
57		resp, err := client.RawRequest(req)
58		if err != nil {
59			t.Fatal(err)
60		}
61		defer resp.Body.Close()
62		buf := bytes.NewBuffer(nil)
63		buf.ReadFrom(resp.Body)
64		if !strings.HasPrefix(buf.String(), "1.2.3.4:") {
65			t.Fatalf("bad body: %s", buf.String())
66		}
67	})
68
69	// Next: test allow unauth
70	t.Run("allow_unauth", func(t *testing.T) {
71		t.Parallel()
72		testHandler := func(props *vault.HandlerProperties) http.Handler {
73			origHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
74				w.WriteHeader(http.StatusOK)
75				w.Write([]byte(r.RemoteAddr))
76			})
77			return WrapForwardedForHandler(origHandler, []*sockaddr.SockAddrMarshaler{
78				&sockaddr.SockAddrMarshaler{
79					SockAddr: badAddr,
80				},
81			}, true, false, 0)
82		}
83
84		cluster := vault.NewTestCluster(t, nil, &vault.TestClusterOptions{
85			HandlerFunc: testHandler,
86		})
87		cluster.Start()
88		defer cluster.Cleanup()
89		client := cluster.Cores[0].Client
90
91		req := client.NewRequest("GET", "/")
92		req.Headers = make(http.Header)
93		req.Headers.Set("x-forwarded-for", "5.6.7.8")
94		resp, err := client.RawRequest(req)
95		if err != nil {
96			t.Fatal(err)
97		}
98		defer resp.Body.Close()
99		buf := bytes.NewBuffer(nil)
100		buf.ReadFrom(resp.Body)
101		if !strings.HasPrefix(buf.String(), "127.0.0.1:") {
102			t.Fatalf("bad body: %s", buf.String())
103		}
104	})
105
106	// Next: test fail unauth
107	t.Run("fail_unauth", func(t *testing.T) {
108		t.Parallel()
109		testHandler := func(props *vault.HandlerProperties) http.Handler {
110			origHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
111				w.WriteHeader(http.StatusOK)
112				w.Write([]byte(r.RemoteAddr))
113			})
114			return WrapForwardedForHandler(origHandler, []*sockaddr.SockAddrMarshaler{
115				&sockaddr.SockAddrMarshaler{
116					SockAddr: badAddr,
117				},
118			}, true, true, 0)
119		}
120
121		cluster := vault.NewTestCluster(t, nil, &vault.TestClusterOptions{
122			HandlerFunc: testHandler,
123		})
124		cluster.Start()
125		defer cluster.Cleanup()
126		client := cluster.Cores[0].Client
127
128		req := client.NewRequest("GET", "/")
129		req.Headers = make(http.Header)
130		req.Headers.Set("x-forwarded-for", "5.6.7.8")
131		_, err := client.RawRequest(req)
132		if err == nil {
133			t.Fatal("expected error")
134		}
135		if !strings.Contains(err.Error(), "not authorized for x-forwarded-for") {
136			t.Fatalf("bad error message: %v", err)
137		}
138	})
139
140	// Next: test bad hops (too many)
141	t.Run("too_many_hops", func(t *testing.T) {
142		t.Parallel()
143		testHandler := func(props *vault.HandlerProperties) http.Handler {
144			origHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
145				w.WriteHeader(http.StatusOK)
146				w.Write([]byte(r.RemoteAddr))
147			})
148			return WrapForwardedForHandler(origHandler, []*sockaddr.SockAddrMarshaler{
149				&sockaddr.SockAddrMarshaler{
150					SockAddr: goodAddr,
151				},
152			}, true, true, 4)
153		}
154
155		cluster := vault.NewTestCluster(t, nil, &vault.TestClusterOptions{
156			HandlerFunc: testHandler,
157		})
158		cluster.Start()
159		defer cluster.Cleanup()
160		client := cluster.Cores[0].Client
161
162		req := client.NewRequest("GET", "/")
163		req.Headers = make(http.Header)
164		req.Headers.Set("x-forwarded-for", "2.3.4.5,3.4.5.6")
165		_, err := client.RawRequest(req)
166		if err == nil {
167			t.Fatal("expected error")
168		}
169		if !strings.Contains(err.Error(), "would skip before earliest") {
170			t.Fatalf("bad error message: %v", err)
171		}
172	})
173
174	// Next: test picking correct value
175	t.Run("correct_hop_skipping", func(t *testing.T) {
176		t.Parallel()
177		testHandler := func(props *vault.HandlerProperties) http.Handler {
178			origHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
179				w.WriteHeader(http.StatusOK)
180				w.Write([]byte(r.RemoteAddr))
181			})
182			return WrapForwardedForHandler(origHandler, []*sockaddr.SockAddrMarshaler{
183				&sockaddr.SockAddrMarshaler{
184					SockAddr: goodAddr,
185				},
186			}, true, true, 1)
187		}
188
189		cluster := vault.NewTestCluster(t, nil, &vault.TestClusterOptions{
190			HandlerFunc: testHandler,
191		})
192		cluster.Start()
193		defer cluster.Cleanup()
194		client := cluster.Cores[0].Client
195
196		req := client.NewRequest("GET", "/")
197		req.Headers = make(http.Header)
198		req.Headers.Set("x-forwarded-for", "2.3.4.5,3.4.5.6,4.5.6.7,5.6.7.8")
199		resp, err := client.RawRequest(req)
200		if err != nil {
201			t.Fatal(err)
202		}
203		defer resp.Body.Close()
204		buf := bytes.NewBuffer(nil)
205		buf.ReadFrom(resp.Body)
206		if !strings.HasPrefix(buf.String(), "4.5.6.7:") {
207			t.Fatalf("bad body: %s", buf.String())
208		}
209	})
210
211	// Next: multi-header approach
212	t.Run("correct_hop_skipping_multi_header", func(t *testing.T) {
213		t.Parallel()
214		testHandler := func(props *vault.HandlerProperties) http.Handler {
215			origHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
216				w.WriteHeader(http.StatusOK)
217				w.Write([]byte(r.RemoteAddr))
218			})
219			return WrapForwardedForHandler(origHandler, []*sockaddr.SockAddrMarshaler{
220				&sockaddr.SockAddrMarshaler{
221					SockAddr: goodAddr,
222				},
223			}, true, true, 1)
224		}
225
226		cluster := vault.NewTestCluster(t, nil, &vault.TestClusterOptions{
227			HandlerFunc: testHandler,
228		})
229		cluster.Start()
230		defer cluster.Cleanup()
231		client := cluster.Cores[0].Client
232
233		req := client.NewRequest("GET", "/")
234		req.Headers = make(http.Header)
235		req.Headers.Add("x-forwarded-for", "2.3.4.5")
236		req.Headers.Add("x-forwarded-for", "3.4.5.6,4.5.6.7")
237		req.Headers.Add("x-forwarded-for", "5.6.7.8")
238		resp, err := client.RawRequest(req)
239		if err != nil {
240			t.Fatal(err)
241		}
242		defer resp.Body.Close()
243		buf := bytes.NewBuffer(nil)
244		buf.ReadFrom(resp.Body)
245		if !strings.HasPrefix(buf.String(), "4.5.6.7:") {
246			t.Fatalf("bad body: %s", buf.String())
247		}
248	})
249}
250