1package quic 2 3import ( 4 "crypto/hmac" 5 "crypto/rand" 6 "crypto/sha256" 7 "errors" 8 "hash" 9 "net" 10 "sync" 11 "time" 12 13 "github.com/ooni/psiphon/oopsi/github.com/Psiphon-Labs/quic-go/internal/protocol" 14 "github.com/ooni/psiphon/oopsi/github.com/Psiphon-Labs/quic-go/internal/utils" 15 "github.com/ooni/psiphon/oopsi/github.com/Psiphon-Labs/quic-go/internal/wire" 16) 17 18// The packetHandlerMap stores packetHandlers, identified by connection ID. 19// It is used: 20// * by the server to store sessions 21// * when multiplexing outgoing connections to store clients 22type packetHandlerMap struct { 23 mutex sync.RWMutex 24 25 conn net.PacketConn 26 connIDLen int 27 28 handlers map[string] /* string(ConnectionID)*/ packetHandler 29 resetTokens map[[16]byte] /* stateless reset token */ packetHandler 30 server unknownPacketHandler 31 32 listening chan struct{} // is closed when listen returns 33 closed bool 34 35 deleteRetiredSessionsAfter time.Duration 36 37 statelessResetEnabled bool 38 statelessResetMutex sync.Mutex 39 statelessResetHasher hash.Hash 40 41 logger utils.Logger 42} 43 44var _ packetHandlerManager = &packetHandlerMap{} 45 46func newPacketHandlerMap( 47 conn net.PacketConn, 48 connIDLen int, 49 statelessResetKey []byte, 50 logger utils.Logger, 51) packetHandlerManager { 52 m := &packetHandlerMap{ 53 conn: conn, 54 connIDLen: connIDLen, 55 listening: make(chan struct{}), 56 handlers: make(map[string]packetHandler), 57 resetTokens: make(map[[16]byte]packetHandler), 58 deleteRetiredSessionsAfter: protocol.RetiredConnectionIDDeleteTimeout, 59 statelessResetEnabled: len(statelessResetKey) > 0, 60 statelessResetHasher: hmac.New(sha256.New, statelessResetKey), 61 logger: logger, 62 } 63 go m.listen() 64 65 if logger.Debug() { 66 go m.logUsage() 67 } 68 69 return m 70} 71 72func (h *packetHandlerMap) logUsage() { 73 ticker := time.NewTicker(2 * time.Second) 74 var printedZero bool 75 for { 76 select { 77 case <-h.listening: 78 return 79 case <-ticker.C: 80 } 81 82 h.mutex.Lock() 83 numHandlers := len(h.handlers) 84 numTokens := len(h.resetTokens) 85 h.mutex.Unlock() 86 // If the number tracked handlers and tokens is zero, only print it a single time. 87 hasZero := numHandlers == 0 && numTokens == 0 88 if !hasZero || (hasZero && !printedZero) { 89 h.logger.Debugf("Tracking %d connection IDs and %d reset tokens.\n", numHandlers, numTokens) 90 printedZero = false 91 if hasZero { 92 printedZero = true 93 } 94 } 95 } 96} 97 98func (h *packetHandlerMap) Add(id protocol.ConnectionID, handler packetHandler) [16]byte { 99 h.mutex.Lock() 100 h.handlers[string(id)] = handler 101 h.mutex.Unlock() 102 return h.GetStatelessResetToken(id) 103} 104 105func (h *packetHandlerMap) AddIfNotTaken(id protocol.ConnectionID, handler packetHandler) bool /* was added */ { 106 sid := string(id) 107 h.mutex.Lock() 108 defer h.mutex.Unlock() 109 110 if _, ok := h.handlers[sid]; !ok { 111 h.handlers[sid] = handler 112 return true 113 } 114 return false 115} 116 117func (h *packetHandlerMap) Remove(id protocol.ConnectionID) { 118 h.mutex.Lock() 119 delete(h.handlers, string(id)) 120 h.mutex.Unlock() 121} 122 123func (h *packetHandlerMap) Retire(id protocol.ConnectionID) { 124 time.AfterFunc(h.deleteRetiredSessionsAfter, func() { 125 h.mutex.Lock() 126 delete(h.handlers, string(id)) 127 h.mutex.Unlock() 128 }) 129} 130 131func (h *packetHandlerMap) ReplaceWithClosed(id protocol.ConnectionID, handler packetHandler) { 132 h.mutex.Lock() 133 h.handlers[string(id)] = handler 134 h.mutex.Unlock() 135 136 time.AfterFunc(h.deleteRetiredSessionsAfter, func() { 137 h.mutex.Lock() 138 handler.Close() 139 delete(h.handlers, string(id)) 140 h.mutex.Unlock() 141 }) 142} 143 144func (h *packetHandlerMap) AddResetToken(token [16]byte, handler packetHandler) { 145 h.mutex.Lock() 146 h.resetTokens[token] = handler 147 h.mutex.Unlock() 148} 149 150func (h *packetHandlerMap) RemoveResetToken(token [16]byte) { 151 h.mutex.Lock() 152 delete(h.resetTokens, token) 153 h.mutex.Unlock() 154} 155 156func (h *packetHandlerMap) RetireResetToken(token [16]byte) { 157 time.AfterFunc(h.deleteRetiredSessionsAfter, func() { 158 h.mutex.Lock() 159 delete(h.resetTokens, token) 160 h.mutex.Unlock() 161 }) 162} 163 164func (h *packetHandlerMap) SetServer(s unknownPacketHandler) { 165 h.mutex.Lock() 166 h.server = s 167 h.mutex.Unlock() 168} 169 170func (h *packetHandlerMap) CloseServer() { 171 h.mutex.Lock() 172 h.server = nil 173 var wg sync.WaitGroup 174 for _, handler := range h.handlers { 175 if handler.getPerspective() == protocol.PerspectiveServer { 176 wg.Add(1) 177 go func(handler packetHandler) { 178 // session.Close() blocks until the CONNECTION_CLOSE has been sent and the run-loop has stopped 179 _ = handler.Close() 180 wg.Done() 181 }(handler) 182 } 183 } 184 h.mutex.Unlock() 185 wg.Wait() 186} 187 188// Close the underlying connection and wait until listen() has returned. 189func (h *packetHandlerMap) Close() error { 190 if err := h.conn.Close(); err != nil { 191 return err 192 } 193 <-h.listening // wait until listening returns 194 return nil 195} 196 197func (h *packetHandlerMap) close(e error) error { 198 h.mutex.Lock() 199 if h.closed { 200 h.mutex.Unlock() 201 return nil 202 } 203 204 var wg sync.WaitGroup 205 for _, handler := range h.handlers { 206 wg.Add(1) 207 go func(handler packetHandler) { 208 handler.destroy(e) 209 wg.Done() 210 }(handler) 211 } 212 213 // [Psiphon] 214 // Call h.server.setCloseError(e) outside of mutex to prevent deadlock 215 // See comment in psiphon/common/quic/gquic-go/packetHandlerMap.close 216 217 var server unknownPacketHandler 218 if h.server != nil { 219 server = h.server 220 } 221 222 h.mutex.Unlock() 223 224 if server != nil { 225 server.setCloseError(e) 226 } 227 228 h.mutex.Lock() 229 h.closed = true 230 h.mutex.Unlock() 231 232 wg.Wait() 233 return getMultiplexer().RemoveConn(h.conn) 234} 235 236func (h *packetHandlerMap) listen() { 237 defer close(h.listening) 238 for { 239 buffer := getPacketBuffer() 240 data := buffer.Slice 241 // The packet size should not exceed protocol.MaxReceivePacketSize bytes 242 // If it does, we only read a truncated packet, which will then end up undecryptable 243 n, addr, err := h.conn.ReadFrom(data) 244 if err != nil { 245 // [Psiphon] 246 // Do not unconditionally shutdown 247 if netErr, ok := err.(net.Error); !ok || !netErr.Temporary() { 248 h.close(err) 249 return 250 } 251 } 252 h.handlePacket(addr, buffer, data[:n]) 253 } 254} 255 256func (h *packetHandlerMap) handlePacket( 257 addr net.Addr, 258 buffer *packetBuffer, 259 data []byte, 260) { 261 connID, err := wire.ParseConnectionID(data, h.connIDLen) 262 if err != nil { 263 h.logger.Debugf("error parsing connection ID on packet from %s: %s", addr, err) 264 return 265 } 266 rcvTime := time.Now() 267 268 h.mutex.RLock() 269 defer h.mutex.RUnlock() 270 271 if isStatelessReset := h.maybeHandleStatelessReset(data); isStatelessReset { 272 return 273 } 274 275 handler, handlerFound := h.handlers[string(connID)] 276 277 p := &receivedPacket{ 278 remoteAddr: addr, 279 rcvTime: rcvTime, 280 buffer: buffer, 281 data: data, 282 } 283 if handlerFound { // existing session 284 handler.handlePacket(p) 285 return 286 } 287 if data[0]&0x80 == 0 { 288 go h.maybeSendStatelessReset(p, connID) 289 return 290 } 291 if h.server == nil { // no server set 292 h.logger.Debugf("received a packet with an unexpected connection ID %s", connID) 293 return 294 } 295 h.server.handlePacket(p) 296} 297 298func (h *packetHandlerMap) maybeHandleStatelessReset(data []byte) bool { 299 // stateless resets are always short header packets 300 if data[0]&0x80 != 0 { 301 return false 302 } 303 if len(data) < 17 /* type byte + 16 bytes for the reset token */ { 304 return false 305 } 306 307 var token [16]byte 308 copy(token[:], data[len(data)-16:]) 309 if sess, ok := h.resetTokens[token]; ok { 310 h.logger.Debugf("Received a stateless retry with token %#x. Closing session.", token) 311 go sess.destroy(errors.New("received a stateless reset")) 312 return true 313 } 314 return false 315} 316 317func (h *packetHandlerMap) GetStatelessResetToken(connID protocol.ConnectionID) [16]byte { 318 var token [16]byte 319 if !h.statelessResetEnabled { 320 // Return a random stateless reset token. 321 // This token will be sent in the server's transport parameters. 322 // By using a random token, an off-path attacker won't be able to disrupt the connection. 323 rand.Read(token[:]) 324 return token 325 } 326 h.statelessResetMutex.Lock() 327 h.statelessResetHasher.Write(connID.Bytes()) 328 copy(token[:], h.statelessResetHasher.Sum(nil)) 329 h.statelessResetHasher.Reset() 330 h.statelessResetMutex.Unlock() 331 return token 332} 333 334func (h *packetHandlerMap) maybeSendStatelessReset(p *receivedPacket, connID protocol.ConnectionID) { 335 defer p.buffer.Release() 336 if !h.statelessResetEnabled { 337 return 338 } 339 // Don't send a stateless reset in response to very small packets. 340 // This includes packets that could be stateless resets. 341 if len(p.data) <= protocol.MinStatelessResetSize { 342 return 343 } 344 token := h.GetStatelessResetToken(connID) 345 h.logger.Debugf("Sending stateless reset to %s (connection ID: %s). Token: %#x", p.remoteAddr, connID, token) 346 data := make([]byte, protocol.MinStatelessResetSize-16, protocol.MinStatelessResetSize) 347 rand.Read(data) 348 data[0] = (data[0] & 0x7f) | 0x40 349 data = append(data, token[:]...) 350 if _, err := h.conn.WriteTo(data, p.remoteAddr); err != nil { 351 h.logger.Debugf("Error sending Stateless Reset: %s", err) 352 } 353} 354