1package http 2 3import ( 4 "bytes" 5 "net/http" 6 "strings" 7 "testing" 8 9 sockaddr "github.com/hashicorp/go-sockaddr" 10 "github.com/hashicorp/vault/vault" 11) 12 13func TestHandler_XForwardedFor(t *testing.T) { 14 goodAddr, err := sockaddr.NewIPAddr("127.0.0.1") 15 if err != nil { 16 t.Fatal(err) 17 } 18 19 badAddr, err := sockaddr.NewIPAddr("1.2.3.4") 20 if err != nil { 21 t.Fatal(err) 22 } 23 24 // First: test reject not present 25 t.Run("reject_not_present", func(t *testing.T) { 26 t.Parallel() 27 testHandler := func(props *vault.HandlerProperties) http.Handler { 28 origHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 29 w.WriteHeader(http.StatusOK) 30 w.Write([]byte(r.RemoteAddr)) 31 }) 32 return WrapForwardedForHandler(origHandler, []*sockaddr.SockAddrMarshaler{ 33 &sockaddr.SockAddrMarshaler{ 34 SockAddr: goodAddr, 35 }, 36 }, true, false, 0) 37 } 38 39 cluster := vault.NewTestCluster(t, nil, &vault.TestClusterOptions{ 40 HandlerFunc: testHandler, 41 }) 42 cluster.Start() 43 defer cluster.Cleanup() 44 client := cluster.Cores[0].Client 45 46 req := client.NewRequest("GET", "/") 47 _, err := client.RawRequest(req) 48 if err == nil { 49 t.Fatal("expected error") 50 } 51 if !strings.Contains(err.Error(), "missing x-forwarded-for") { 52 t.Fatalf("bad error message: %v", err) 53 } 54 req = client.NewRequest("GET", "/") 55 req.Headers = make(http.Header) 56 req.Headers.Set("x-forwarded-for", "1.2.3.4") 57 resp, err := client.RawRequest(req) 58 if err != nil { 59 t.Fatal(err) 60 } 61 defer resp.Body.Close() 62 buf := bytes.NewBuffer(nil) 63 buf.ReadFrom(resp.Body) 64 if !strings.HasPrefix(buf.String(), "1.2.3.4:") { 65 t.Fatalf("bad body: %s", buf.String()) 66 } 67 }) 68 69 // Next: test allow unauth 70 t.Run("allow_unauth", func(t *testing.T) { 71 t.Parallel() 72 testHandler := func(props *vault.HandlerProperties) http.Handler { 73 origHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 74 w.WriteHeader(http.StatusOK) 75 w.Write([]byte(r.RemoteAddr)) 76 }) 77 return WrapForwardedForHandler(origHandler, []*sockaddr.SockAddrMarshaler{ 78 &sockaddr.SockAddrMarshaler{ 79 SockAddr: badAddr, 80 }, 81 }, true, false, 0) 82 } 83 84 cluster := vault.NewTestCluster(t, nil, &vault.TestClusterOptions{ 85 HandlerFunc: testHandler, 86 }) 87 cluster.Start() 88 defer cluster.Cleanup() 89 client := cluster.Cores[0].Client 90 91 req := client.NewRequest("GET", "/") 92 req.Headers = make(http.Header) 93 req.Headers.Set("x-forwarded-for", "5.6.7.8") 94 resp, err := client.RawRequest(req) 95 if err != nil { 96 t.Fatal(err) 97 } 98 defer resp.Body.Close() 99 buf := bytes.NewBuffer(nil) 100 buf.ReadFrom(resp.Body) 101 if !strings.HasPrefix(buf.String(), "127.0.0.1:") { 102 t.Fatalf("bad body: %s", buf.String()) 103 } 104 }) 105 106 // Next: test fail unauth 107 t.Run("fail_unauth", func(t *testing.T) { 108 t.Parallel() 109 testHandler := func(props *vault.HandlerProperties) http.Handler { 110 origHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 111 w.WriteHeader(http.StatusOK) 112 w.Write([]byte(r.RemoteAddr)) 113 }) 114 return WrapForwardedForHandler(origHandler, []*sockaddr.SockAddrMarshaler{ 115 &sockaddr.SockAddrMarshaler{ 116 SockAddr: badAddr, 117 }, 118 }, true, true, 0) 119 } 120 121 cluster := vault.NewTestCluster(t, nil, &vault.TestClusterOptions{ 122 HandlerFunc: testHandler, 123 }) 124 cluster.Start() 125 defer cluster.Cleanup() 126 client := cluster.Cores[0].Client 127 128 req := client.NewRequest("GET", "/") 129 req.Headers = make(http.Header) 130 req.Headers.Set("x-forwarded-for", "5.6.7.8") 131 _, err := client.RawRequest(req) 132 if err == nil { 133 t.Fatal("expected error") 134 } 135 if !strings.Contains(err.Error(), "not authorized for x-forwarded-for") { 136 t.Fatalf("bad error message: %v", err) 137 } 138 }) 139 140 // Next: test bad hops (too many) 141 t.Run("too_many_hops", func(t *testing.T) { 142 t.Parallel() 143 testHandler := func(props *vault.HandlerProperties) http.Handler { 144 origHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 145 w.WriteHeader(http.StatusOK) 146 w.Write([]byte(r.RemoteAddr)) 147 }) 148 return WrapForwardedForHandler(origHandler, []*sockaddr.SockAddrMarshaler{ 149 &sockaddr.SockAddrMarshaler{ 150 SockAddr: goodAddr, 151 }, 152 }, true, true, 4) 153 } 154 155 cluster := vault.NewTestCluster(t, nil, &vault.TestClusterOptions{ 156 HandlerFunc: testHandler, 157 }) 158 cluster.Start() 159 defer cluster.Cleanup() 160 client := cluster.Cores[0].Client 161 162 req := client.NewRequest("GET", "/") 163 req.Headers = make(http.Header) 164 req.Headers.Set("x-forwarded-for", "2.3.4.5,3.4.5.6") 165 _, err := client.RawRequest(req) 166 if err == nil { 167 t.Fatal("expected error") 168 } 169 if !strings.Contains(err.Error(), "would skip before earliest") { 170 t.Fatalf("bad error message: %v", err) 171 } 172 }) 173 174 // Next: test picking correct value 175 t.Run("correct_hop_skipping", func(t *testing.T) { 176 t.Parallel() 177 testHandler := func(props *vault.HandlerProperties) http.Handler { 178 origHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 179 w.WriteHeader(http.StatusOK) 180 w.Write([]byte(r.RemoteAddr)) 181 }) 182 return WrapForwardedForHandler(origHandler, []*sockaddr.SockAddrMarshaler{ 183 &sockaddr.SockAddrMarshaler{ 184 SockAddr: goodAddr, 185 }, 186 }, true, true, 1) 187 } 188 189 cluster := vault.NewTestCluster(t, nil, &vault.TestClusterOptions{ 190 HandlerFunc: testHandler, 191 }) 192 cluster.Start() 193 defer cluster.Cleanup() 194 client := cluster.Cores[0].Client 195 196 req := client.NewRequest("GET", "/") 197 req.Headers = make(http.Header) 198 req.Headers.Set("x-forwarded-for", "2.3.4.5,3.4.5.6,4.5.6.7,5.6.7.8") 199 resp, err := client.RawRequest(req) 200 if err != nil { 201 t.Fatal(err) 202 } 203 defer resp.Body.Close() 204 buf := bytes.NewBuffer(nil) 205 buf.ReadFrom(resp.Body) 206 if !strings.HasPrefix(buf.String(), "4.5.6.7:") { 207 t.Fatalf("bad body: %s", buf.String()) 208 } 209 }) 210 211 // Next: multi-header approach 212 t.Run("correct_hop_skipping_multi_header", func(t *testing.T) { 213 t.Parallel() 214 testHandler := func(props *vault.HandlerProperties) http.Handler { 215 origHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 216 w.WriteHeader(http.StatusOK) 217 w.Write([]byte(r.RemoteAddr)) 218 }) 219 return WrapForwardedForHandler(origHandler, []*sockaddr.SockAddrMarshaler{ 220 &sockaddr.SockAddrMarshaler{ 221 SockAddr: goodAddr, 222 }, 223 }, true, true, 1) 224 } 225 226 cluster := vault.NewTestCluster(t, nil, &vault.TestClusterOptions{ 227 HandlerFunc: testHandler, 228 }) 229 cluster.Start() 230 defer cluster.Cleanup() 231 client := cluster.Cores[0].Client 232 233 req := client.NewRequest("GET", "/") 234 req.Headers = make(http.Header) 235 req.Headers.Add("x-forwarded-for", "2.3.4.5") 236 req.Headers.Add("x-forwarded-for", "3.4.5.6,4.5.6.7") 237 req.Headers.Add("x-forwarded-for", "5.6.7.8") 238 resp, err := client.RawRequest(req) 239 if err != nil { 240 t.Fatal(err) 241 } 242 defer resp.Body.Close() 243 buf := bytes.NewBuffer(nil) 244 buf.ReadFrom(resp.Body) 245 if !strings.HasPrefix(buf.String(), "4.5.6.7:") { 246 t.Fatalf("bad body: %s", buf.String()) 247 } 248 }) 249} 250