1// Copyright (c) 2020 Shivaram Lingamneni <slingamn@cs.stanford.edu> 2// released under the MIT license 3 4package utils 5 6import ( 7 "crypto/tls" 8 "encoding/binary" 9 "io" 10 "net" 11 "strings" 12 "sync" 13 "time" 14) 15 16const ( 17 // https://www.haproxy.org/download/1.8/doc/proxy-protocol.txt 18 // "a 108-byte buffer is always enough to store all the line and a trailing zero 19 // for string processing." 20 maxProxyLineLenV1 = 107 21) 22 23// XXX implement net.Error with a Temporary() method that returns true; 24// otherwise, ErrBadProxyLine will cause (*http.Server).Serve() to exit 25type proxyLineError struct{} 26 27func (p *proxyLineError) Error() string { 28 return "invalid PROXY line" 29} 30 31func (p *proxyLineError) Timeout() bool { 32 return false 33} 34 35func (p *proxyLineError) Temporary() bool { 36 return true 37} 38 39var ( 40 ErrBadProxyLine error = &proxyLineError{} 41) 42 43// ListenerConfig is all the information about how to process 44// incoming IRC connections on a listener. 45type ListenerConfig struct { 46 TLSConfig *tls.Config 47 ProxyDeadline time.Duration 48 RequireProxy bool 49 // these are just metadata for easier tracking, 50 // they are not used by ReloadableListener: 51 Tor bool 52 STSOnly bool 53 WebSocket bool 54 HideSTS bool 55} 56 57// read a PROXY header (either v1 or v2), ensuring we don't read anything beyond 58// the header into a buffer (this would break the TLS handshake) 59func readRawProxyLine(conn net.Conn, deadline time.Duration) (result []byte, err error) { 60 // normally this is covered by ping timeouts, but we're doing this outside 61 // of the normal client goroutine: 62 conn.SetDeadline(time.Now().Add(deadline)) 63 defer conn.SetDeadline(time.Time{}) 64 65 // read the first 16 bytes of the proxy header 66 buf := make([]byte, 16, maxProxyLineLenV1) 67 _, err = io.ReadFull(conn, buf) 68 if err != nil { 69 return 70 } 71 72 switch buf[0] { 73 case 'P': 74 // PROXY v1: starts with "PROXY" 75 return readRawProxyLineV1(conn, buf) 76 case '\r': 77 // PROXY v2: starts with "\r\n\r\n" 78 return readRawProxyLineV2(conn, buf) 79 default: 80 return nil, ErrBadProxyLine 81 } 82} 83 84func readRawProxyLineV1(conn net.Conn, buf []byte) (result []byte, err error) { 85 for { 86 i := len(buf) 87 if i >= maxProxyLineLenV1 { 88 return nil, ErrBadProxyLine // did not find \r\n, fail 89 } 90 // prepare a single byte of free space, then read into it 91 buf = buf[0 : i+1] 92 _, err = io.ReadFull(conn, buf[i:]) 93 if err != nil { 94 return nil, err 95 } 96 if buf[i] == '\n' { 97 return buf, nil 98 } 99 } 100} 101 102func readRawProxyLineV2(conn net.Conn, buf []byte) (result []byte, err error) { 103 // "The 15th and 16th bytes is the address length in bytes in network endian order." 104 addrLen := int(binary.BigEndian.Uint16(buf[14:16])) 105 if addrLen == 0 { 106 return buf[0:16], nil 107 } else if addrLen <= cap(buf)-16 { 108 buf = buf[0 : 16+addrLen] 109 } else { 110 // proxy source is unix domain, we don't really handle this 111 buf2 := make([]byte, 16+addrLen) 112 copy(buf2[0:16], buf[0:16]) 113 buf = buf2 114 } 115 _, err = io.ReadFull(conn, buf[16:16+addrLen]) 116 if err != nil { 117 return 118 } 119 return buf[0 : 16+addrLen], nil 120} 121 122// ParseProxyLine parses a PROXY protocol (v1 or v2) line and returns the remote IP. 123func ParseProxyLine(line []byte) (ip net.IP, err error) { 124 if len(line) == 0 { 125 return nil, ErrBadProxyLine 126 } 127 switch line[0] { 128 case 'P': 129 return ParseProxyLineV1(string(line)) 130 case '\r': 131 return parseProxyLineV2(line) 132 default: 133 return nil, ErrBadProxyLine 134 } 135} 136 137// ParseProxyLineV1 parses a PROXY protocol (v1) line and returns the remote IP. 138func ParseProxyLineV1(line string) (ip net.IP, err error) { 139 params := strings.Fields(line) 140 if len(params) != 6 || params[0] != "PROXY" { 141 return nil, ErrBadProxyLine 142 } 143 ip = net.ParseIP(params[2]) 144 if ip == nil { 145 return nil, ErrBadProxyLine 146 } 147 return ip.To16(), nil 148} 149 150func parseProxyLineV2(line []byte) (ip net.IP, err error) { 151 if len(line) < 16 { 152 return nil, ErrBadProxyLine 153 } 154 // this doesn't allocate 155 if string(line[:12]) != "\x0d\x0a\x0d\x0a\x00\x0d\x0a\x51\x55\x49\x54\x0a" { 156 return nil, ErrBadProxyLine 157 } 158 // "The next byte (the 13th one) is the protocol version and command." 159 versionCmd := line[12] 160 // "The highest four bits contains the version [....] it must always be sent as \x2" 161 if (versionCmd >> 4) != 2 { 162 return nil, ErrBadProxyLine 163 } 164 // "The lowest four bits represents the command" 165 switch versionCmd & 0x0f { 166 case 0: 167 return nil, nil // LOCAL command 168 case 1: 169 // PROXY command, continue below 170 default: 171 // "Receivers must drop connections presenting unexpected values here" 172 return nil, ErrBadProxyLine 173 } 174 175 var addrLen int 176 // "The 14th byte contains the transport protocol and address family." 177 protoAddr := line[13] 178 // "The highest 4 bits contain the address family" 179 switch protoAddr >> 4 { 180 case 1: 181 addrLen = 4 // AF_INET 182 case 2: 183 addrLen = 16 // AF_INET6 184 default: 185 return nil, nil // AF_UNSPEC or AF_UNIX, either way there's no IP address 186 } 187 188 // header, source and destination address, two 16-bit port numbers: 189 expectedLen := 16 + 2*addrLen + 4 190 if len(line) < expectedLen { 191 return nil, ErrBadProxyLine 192 } 193 194 // "Starting from the 17th byte, addresses are presented in network byte order. 195 // The address order is always the same : 196 // - source layer 3 address in network byte order [...]" 197 if addrLen == 4 { 198 ip = net.IP(line[16 : 16+addrLen]).To16() 199 } else { 200 ip = make(net.IP, addrLen) 201 copy(ip, line[16:16+addrLen]) 202 } 203 return ip, nil 204} 205 206/// WrappedConn is a net.Conn with some additional data stapled to it; 207// the proxied IP, if one was read via the PROXY protocol, and the listener 208// configuration. 209type WrappedConn struct { 210 net.Conn 211 ProxiedIP net.IP 212 Config ListenerConfig 213 // Secure indicates whether we believe the connection between us and the client 214 // was secure against interception and modification (including all proxies): 215 Secure bool 216} 217 218// ReloadableListener is a wrapper for net.Listener that allows reloading 219// of config data for postprocessing connections (TLS, PROXY protocol, etc.) 220type ReloadableListener struct { 221 // TODO: make this lock-free 222 sync.Mutex 223 realListener net.Listener 224 config ListenerConfig 225 isClosed bool 226} 227 228func NewReloadableListener(realListener net.Listener, config ListenerConfig) *ReloadableListener { 229 return &ReloadableListener{ 230 realListener: realListener, 231 config: config, 232 } 233} 234 235func (rl *ReloadableListener) Reload(config ListenerConfig) { 236 rl.Lock() 237 rl.config = config 238 rl.Unlock() 239} 240 241func (rl *ReloadableListener) Accept() (conn net.Conn, err error) { 242 conn, err = rl.realListener.Accept() 243 244 rl.Lock() 245 config := rl.config 246 isClosed := rl.isClosed 247 rl.Unlock() 248 249 if isClosed { 250 if err == nil { 251 conn.Close() 252 } 253 err = net.ErrClosed 254 } 255 if err != nil { 256 return nil, err 257 } 258 259 var proxiedIP net.IP 260 if config.RequireProxy { 261 // this will occur synchronously on the goroutine calling Accept(), 262 // but that's OK because this listener *requires* a PROXY line, 263 // therefore it must be used with proxies that always send the line 264 // and we won't get slowloris'ed waiting for the client response 265 proxyLine, err := readRawProxyLine(conn, config.ProxyDeadline) 266 if err == nil { 267 proxiedIP, err = ParseProxyLine(proxyLine) 268 } 269 if err != nil { 270 conn.Close() 271 return nil, err 272 } 273 } 274 275 if config.TLSConfig != nil { 276 conn = tls.Server(conn, config.TLSConfig) 277 } 278 279 return &WrappedConn{ 280 Conn: conn, 281 ProxiedIP: proxiedIP, 282 Config: config, 283 }, nil 284} 285 286func (rl *ReloadableListener) Close() error { 287 rl.Lock() 288 rl.isClosed = true 289 rl.Unlock() 290 291 return rl.realListener.Close() 292} 293 294func (rl *ReloadableListener) Addr() net.Addr { 295 return rl.realListener.Addr() 296} 297