1package proxyproto 2 3import ( 4 "bufio" 5 "bytes" 6 "io" 7 "net" 8 "strconv" 9 "strings" 10 "testing" 11 "time" 12) 13 14var ( 15 IPv4AddressesAndPorts = strings.Join([]string{IP4_ADDR, IP4_ADDR, strconv.Itoa(PORT), strconv.Itoa(PORT)}, separator) 16 IPv4AddressesAndInvalidPorts = strings.Join([]string{IP4_ADDR, IP4_ADDR, strconv.Itoa(INVALID_PORT), strconv.Itoa(INVALID_PORT)}, separator) 17 IPv6AddressesAndPorts = strings.Join([]string{IP6_ADDR, IP6_ADDR, strconv.Itoa(PORT), strconv.Itoa(PORT)}, separator) 18 IPv6LongAddressesAndPorts = strings.Join([]string{IP6_LONG_ADDR, IP6_LONG_ADDR, strconv.Itoa(PORT), strconv.Itoa(PORT)}, separator) 19 20 fixtureTCP4V1 = "PROXY TCP4 " + IPv4AddressesAndPorts + crlf + "GET /" 21 fixtureTCP6V1 = "PROXY TCP6 " + IPv6AddressesAndPorts + crlf + "GET /" 22 23 fixtureTCP6V1Overflow = "PROXY TCP6 " + IPv6LongAddressesAndPorts 24 25 fixtureUnknown = "PROXY UNKNOWN" + crlf 26 fixtureUnknownWithAddresses = "PROXY UNKNOWN " + IPv4AddressesAndInvalidPorts + crlf 27) 28 29var invalidParseV1Tests = []struct { 30 desc string 31 reader *bufio.Reader 32 expectedError error 33}{ 34 { 35 desc: "no signature", 36 reader: newBufioReader([]byte(NO_PROTOCOL)), 37 expectedError: ErrNoProxyProtocol, 38 }, 39 { 40 desc: "prox", 41 reader: newBufioReader([]byte("PROX")), 42 expectedError: ErrNoProxyProtocol, 43 }, 44 { 45 desc: "proxy lf", 46 reader: newBufioReader([]byte("PROXY \n")), 47 expectedError: ErrLineMustEndWithCrlf, 48 }, 49 { 50 desc: "proxy crlf", 51 reader: newBufioReader([]byte("PROXY " + crlf)), 52 expectedError: ErrCantReadAddressFamilyAndProtocol, 53 }, 54 { 55 desc: "proxy no space crlf", 56 reader: newBufioReader([]byte("PROXY" + crlf)), 57 expectedError: ErrCantReadAddressFamilyAndProtocol, 58 }, 59 { 60 desc: "proxy something crlf", 61 reader: newBufioReader([]byte("PROXY SOMETHING" + crlf)), 62 expectedError: ErrCantReadAddressFamilyAndProtocol, 63 }, 64 { 65 desc: "incomplete signature TCP4", 66 reader: newBufioReader([]byte("PROXY TCP4 " + IPv4AddressesAndPorts)), 67 expectedError: ErrCantReadVersion1Header, 68 }, 69 { 70 desc: "TCP6 with IPv4 addresses", 71 reader: newBufioReader([]byte("PROXY TCP6 " + IPv4AddressesAndPorts + crlf)), 72 expectedError: ErrInvalidAddress, 73 }, 74 { 75 desc: "TCP4 with IPv6 addresses", 76 reader: newBufioReader([]byte("PROXY TCP4 " + IPv6AddressesAndPorts + crlf)), 77 expectedError: ErrInvalidAddress, 78 }, 79 { 80 desc: "TCP4 with invalid port", 81 reader: newBufioReader([]byte("PROXY TCP4 " + IPv4AddressesAndInvalidPorts + crlf)), 82 expectedError: ErrInvalidPortNumber, 83 }, 84 { 85 desc: "header too long", 86 reader: newBufioReader([]byte("PROXY UNKNOWN " + IPv6LongAddressesAndPorts + " " + crlf)), 87 expectedError: ErrVersion1HeaderTooLong, 88 }, 89} 90 91func TestReadV1Invalid(t *testing.T) { 92 for _, tt := range invalidParseV1Tests { 93 t.Run(tt.desc, func(t *testing.T) { 94 if _, err := Read(tt.reader); err != tt.expectedError { 95 t.Fatalf("expected %s, actual %v", tt.expectedError, err) 96 } 97 }) 98 } 99} 100 101var validParseAndWriteV1Tests = []struct { 102 desc string 103 reader *bufio.Reader 104 expectedHeader *Header 105}{ 106 { 107 desc: "TCP4", 108 reader: bufio.NewReader(strings.NewReader(fixtureTCP4V1)), 109 expectedHeader: &Header{ 110 Version: 1, 111 Command: PROXY, 112 TransportProtocol: TCPv4, 113 SourceAddr: v4addr, 114 DestinationAddr: v4addr, 115 }, 116 }, 117 { 118 desc: "TCP6", 119 reader: bufio.NewReader(strings.NewReader(fixtureTCP6V1)), 120 expectedHeader: &Header{ 121 Version: 1, 122 Command: PROXY, 123 TransportProtocol: TCPv6, 124 SourceAddr: v6addr, 125 DestinationAddr: v6addr, 126 }, 127 }, 128 { 129 desc: "unknown", 130 reader: bufio.NewReader(strings.NewReader(fixtureUnknown)), 131 expectedHeader: &Header{ 132 Version: 1, 133 Command: LOCAL, 134 TransportProtocol: UNSPEC, 135 SourceAddr: nil, 136 DestinationAddr: nil, 137 }, 138 }, 139 { 140 desc: "unknown with addresses and ports", 141 reader: bufio.NewReader(strings.NewReader(fixtureUnknownWithAddresses)), 142 expectedHeader: &Header{ 143 Version: 1, 144 Command: LOCAL, 145 TransportProtocol: UNSPEC, 146 SourceAddr: nil, 147 DestinationAddr: nil, 148 }, 149 }, 150} 151 152func TestParseV1Valid(t *testing.T) { 153 for _, tt := range validParseAndWriteV1Tests { 154 t.Run(tt.desc, func(t *testing.T) { 155 header, err := Read(tt.reader) 156 if err != nil { 157 t.Fatal("unexpected error", err.Error()) 158 } 159 if !header.EqualsTo(tt.expectedHeader) { 160 t.Fatalf("expected %#v, actual %#v", tt.expectedHeader, header) 161 } 162 }) 163 } 164} 165 166func TestWriteV1Valid(t *testing.T) { 167 for _, tt := range validParseAndWriteV1Tests { 168 t.Run(tt.desc, func(t *testing.T) { 169 var b bytes.Buffer 170 w := bufio.NewWriter(&b) 171 if _, err := tt.expectedHeader.WriteTo(w); err != nil { 172 t.Fatal("unexpected error ", err) 173 } 174 w.Flush() 175 176 // Read written bytes to validate written header 177 r := bufio.NewReader(&b) 178 newHeader, err := Read(r) 179 if err != nil { 180 t.Fatal("unexpected error ", err) 181 } 182 183 if !newHeader.EqualsTo(tt.expectedHeader) { 184 t.Fatalf("expected %#v, actual %#v", tt.expectedHeader, newHeader) 185 } 186 }) 187 } 188} 189 190// Tests for parseVersion1 overflow - issue #69. 191 192type dataSource struct { 193 NBytes int 194 NRead int 195} 196 197func (ds *dataSource) Read(b []byte) (int, error) { 198 if ds.NRead >= ds.NBytes { 199 return 0, io.EOF 200 } 201 avail := ds.NBytes - ds.NRead 202 if len(b) < avail { 203 avail = len(b) 204 } 205 for i := 0; i < avail; i++ { 206 b[i] = 0x20 207 } 208 ds.NRead += avail 209 return avail, nil 210} 211 212func TestParseVersion1Overflow(t *testing.T) { 213 ds := &dataSource{} 214 reader := bufio.NewReader(ds) 215 bufSize := reader.Size() 216 ds.NBytes = bufSize * 16 217 parseVersion1(reader) 218 if ds.NRead > bufSize { 219 t.Fatalf("read: expected max %d bytes, actual %d\n", bufSize, ds.NRead) 220 } 221} 222 223func listen(t *testing.T) *Listener { 224 l, err := net.Listen("tcp", "127.0.0.1:0") 225 if err != nil { 226 t.Fatalf("listen: %v", err) 227 } 228 return &Listener{Listener: l} 229} 230 231func client(t *testing.T, addr, header string, length int, terminate bool, wait time.Duration, done chan struct{}) { 232 c, err := net.Dial("tcp", addr) 233 if err != nil { 234 t.Fatalf("dial: %v", err) 235 } 236 defer c.Close() 237 238 if terminate && length < 2 { 239 length = 2 240 } 241 242 buf := make([]byte, len(header)+length) 243 copy(buf, []byte(header)) 244 for i := 0; i < length-2; i++ { 245 buf[i+len(header)] = 0x20 246 } 247 if terminate { 248 copy(buf[len(header)+length-2:], []byte(crlf)) 249 } 250 251 n, err := c.Write(buf) 252 if err != nil { 253 t.Fatalf("write: %v", err) 254 } 255 if n != len(buf) { 256 t.Fatalf("write; short write") 257 } 258 259 time.Sleep(wait) 260 close(done) 261} 262 263func TestVersion1Overflow(t *testing.T) { 264 done := make(chan struct{}) 265 266 l := listen(t) 267 go client(t, l.Addr().String(), fixtureTCP6V1Overflow, 10240, true, 10*time.Second, done) 268 269 c, err := l.Accept() 270 if err != nil { 271 t.Fatalf("accept: %v", err) 272 } 273 274 b := []byte{} 275 _, err = c.Read(b) 276 if err == nil { 277 t.Fatalf("net.Conn: no error reported for oversized header") 278 } 279} 280 281func TestVersion1SlowLoris(t *testing.T) { 282 done := make(chan struct{}) 283 timeout := make(chan error) 284 285 l := listen(t) 286 go client(t, l.Addr().String(), fixtureTCP6V1Overflow, 0, false, 10*time.Second, done) 287 288 c, err := l.Accept() 289 if err != nil { 290 t.Fatalf("accept: %v", err) 291 } 292 293 go func() { 294 b := []byte{} 295 _, err = c.Read(b) 296 timeout <- err 297 }() 298 299 select { 300 case <-done: 301 t.Fatalf("net.Conn: reader still blocked after 10 seconds") 302 case err := <-timeout: 303 if err == nil { 304 t.Fatalf("net.Conn: no error reported for incomplete header") 305 } 306 } 307} 308 309func TestVersion1SlowLorisOverflow(t *testing.T) { 310 done := make(chan struct{}) 311 timeout := make(chan error) 312 313 l := listen(t) 314 go client(t, l.Addr().String(), fixtureTCP6V1Overflow, 10240, false, 10*time.Second, done) 315 316 c, err := l.Accept() 317 if err != nil { 318 t.Fatalf("accept: %v", err) 319 } 320 321 go func() { 322 b := []byte{} 323 _, err = c.Read(b) 324 timeout <- err 325 }() 326 327 select { 328 case <-done: 329 t.Fatalf("net.Conn: reader still blocked after 10 seconds") 330 case err := <-timeout: 331 if err == nil { 332 t.Fatalf("net.Conn: no error reported for incomplete and overflowed header") 333 } 334 } 335} 336