1package quic 2 3import ( 4 "bytes" 5 "crypto/rand" 6 "errors" 7 "net" 8 "time" 9 10 mocklogging "github.com/lucas-clemente/quic-go/internal/mocks/logging" 11 "github.com/lucas-clemente/quic-go/internal/protocol" 12 "github.com/lucas-clemente/quic-go/internal/utils" 13 "github.com/lucas-clemente/quic-go/internal/wire" 14 "github.com/lucas-clemente/quic-go/logging" 15 16 "github.com/golang/mock/gomock" 17 18 . "github.com/onsi/ginkgo" 19 . "github.com/onsi/gomega" 20) 21 22var _ = Describe("Packet Handler Map", func() { 23 type packetToRead struct { 24 addr net.Addr 25 data []byte 26 err error 27 } 28 29 var ( 30 handler *packetHandlerMap 31 conn *MockPacketConn 32 tracer *mocklogging.MockTracer 33 packetChan chan packetToRead 34 35 connIDLen int 36 statelessResetKey []byte 37 ) 38 39 getPacketWithPacketType := func(connID protocol.ConnectionID, t protocol.PacketType, length protocol.ByteCount) []byte { 40 buf := &bytes.Buffer{} 41 Expect((&wire.ExtendedHeader{ 42 Header: wire.Header{ 43 IsLongHeader: true, 44 Type: t, 45 DestConnectionID: connID, 46 Length: length, 47 Version: protocol.VersionTLS, 48 }, 49 PacketNumberLen: protocol.PacketNumberLen2, 50 }).Write(buf, protocol.VersionWhatever)).To(Succeed()) 51 return buf.Bytes() 52 } 53 54 getPacket := func(connID protocol.ConnectionID) []byte { 55 return getPacketWithPacketType(connID, protocol.PacketTypeHandshake, 2) 56 } 57 58 BeforeEach(func() { 59 statelessResetKey = nil 60 connIDLen = 0 61 tracer = mocklogging.NewMockTracer(mockCtrl) 62 packetChan = make(chan packetToRead, 10) 63 }) 64 65 JustBeforeEach(func() { 66 conn = NewMockPacketConn(mockCtrl) 67 conn.EXPECT().LocalAddr().Return(&net.UDPAddr{}).AnyTimes() 68 conn.EXPECT().ReadFrom(gomock.Any()).DoAndReturn(func(b []byte) (int, net.Addr, error) { 69 p, ok := <-packetChan 70 if !ok { 71 return 0, nil, errors.New("closed") 72 } 73 return copy(b, p.data), p.addr, p.err 74 }).AnyTimes() 75 phm, err := newPacketHandlerMap(conn, connIDLen, statelessResetKey, tracer, utils.DefaultLogger) 76 Expect(err).ToNot(HaveOccurred()) 77 handler = phm.(*packetHandlerMap) 78 }) 79 80 It("closes", func() { 81 getMultiplexer() // make the sync.Once execute 82 // replace the clientMuxer. getClientMultiplexer will now return the MockMultiplexer 83 mockMultiplexer := NewMockMultiplexer(mockCtrl) 84 origMultiplexer := connMuxer 85 connMuxer = mockMultiplexer 86 87 defer func() { 88 connMuxer = origMultiplexer 89 }() 90 91 testErr := errors.New("test error ") 92 sess1 := NewMockPacketHandler(mockCtrl) 93 sess1.EXPECT().destroy(testErr) 94 sess2 := NewMockPacketHandler(mockCtrl) 95 sess2.EXPECT().destroy(testErr) 96 handler.Add(protocol.ConnectionID{1, 1, 1, 1}, sess1) 97 handler.Add(protocol.ConnectionID{2, 2, 2, 2}, sess2) 98 mockMultiplexer.EXPECT().RemoveConn(gomock.Any()) 99 handler.close(testErr) 100 close(packetChan) 101 Eventually(handler.listening).Should(BeClosed()) 102 }) 103 104 Context("other operations", func() { 105 AfterEach(func() { 106 // delete sessions and the server before closing 107 // They might be mock implementations, and we'd have to register the expected calls before otherwise. 108 handler.mutex.Lock() 109 for connID := range handler.handlers { 110 delete(handler.handlers, connID) 111 } 112 handler.server = nil 113 handler.mutex.Unlock() 114 conn.EXPECT().Close().MaxTimes(1) 115 close(packetChan) 116 handler.Destroy() 117 Eventually(handler.listening).Should(BeClosed()) 118 }) 119 120 Context("handling packets", func() { 121 BeforeEach(func() { 122 connIDLen = 5 123 }) 124 125 It("handles packets for different packet handlers on the same packet conn", func() { 126 connID1 := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} 127 connID2 := protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1} 128 packetHandler1 := NewMockPacketHandler(mockCtrl) 129 packetHandler2 := NewMockPacketHandler(mockCtrl) 130 handledPacket1 := make(chan struct{}) 131 handledPacket2 := make(chan struct{}) 132 packetHandler1.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) { 133 connID, err := wire.ParseConnectionID(p.data, 0) 134 Expect(err).ToNot(HaveOccurred()) 135 Expect(connID).To(Equal(connID1)) 136 close(handledPacket1) 137 }) 138 packetHandler2.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) { 139 connID, err := wire.ParseConnectionID(p.data, 0) 140 Expect(err).ToNot(HaveOccurred()) 141 Expect(connID).To(Equal(connID2)) 142 close(handledPacket2) 143 }) 144 handler.Add(connID1, packetHandler1) 145 handler.Add(connID2, packetHandler2) 146 packetChan <- packetToRead{data: getPacket(connID1)} 147 packetChan <- packetToRead{data: getPacket(connID2)} 148 149 Eventually(handledPacket1).Should(BeClosed()) 150 Eventually(handledPacket2).Should(BeClosed()) 151 }) 152 153 It("drops unparseable packets", func() { 154 addr := &net.UDPAddr{IP: net.IPv4(9, 8, 7, 6), Port: 1234} 155 tracer.EXPECT().DroppedPacket(addr, logging.PacketTypeNotDetermined, protocol.ByteCount(4), logging.PacketDropHeaderParseError) 156 handler.handlePacket(&receivedPacket{ 157 buffer: getPacketBuffer(), 158 remoteAddr: addr, 159 data: []byte{0, 1, 2, 3}, 160 }) 161 }) 162 163 It("deletes removed sessions immediately", func() { 164 handler.deleteRetiredSessionsAfter = time.Hour 165 connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} 166 handler.Add(connID, NewMockPacketHandler(mockCtrl)) 167 handler.Remove(connID) 168 handler.handlePacket(&receivedPacket{data: getPacket(connID)}) 169 // don't EXPECT any calls to handlePacket of the MockPacketHandler 170 }) 171 172 It("deletes retired session entries after a wait time", func() { 173 handler.deleteRetiredSessionsAfter = scaleDuration(10 * time.Millisecond) 174 connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} 175 sess := NewMockPacketHandler(mockCtrl) 176 handler.Add(connID, sess) 177 handler.Retire(connID) 178 time.Sleep(scaleDuration(30 * time.Millisecond)) 179 handler.handlePacket(&receivedPacket{data: getPacket(connID)}) 180 // don't EXPECT any calls to handlePacket of the MockPacketHandler 181 }) 182 183 It("passes packets arriving late for closed sessions to that session", func() { 184 handler.deleteRetiredSessionsAfter = time.Hour 185 connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} 186 packetHandler := NewMockPacketHandler(mockCtrl) 187 handled := make(chan struct{}) 188 packetHandler.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) { 189 close(handled) 190 }) 191 handler.Add(connID, packetHandler) 192 handler.Retire(connID) 193 handler.handlePacket(&receivedPacket{data: getPacket(connID)}) 194 Eventually(handled).Should(BeClosed()) 195 }) 196 197 It("drops packets for unknown receivers", func() { 198 connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} 199 handler.handlePacket(&receivedPacket{data: getPacket(connID)}) 200 }) 201 202 It("closes the packet handlers when reading from the conn fails", func() { 203 done := make(chan struct{}) 204 packetHandler := NewMockPacketHandler(mockCtrl) 205 packetHandler.EXPECT().destroy(gomock.Any()).Do(func(e error) { 206 Expect(e).To(HaveOccurred()) 207 close(done) 208 }) 209 handler.Add(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, packetHandler) 210 packetChan <- packetToRead{err: errors.New("read failed")} 211 Eventually(done).Should(BeClosed()) 212 }) 213 214 It("continues listening for temporary errors", func() { 215 packetHandler := NewMockPacketHandler(mockCtrl) 216 handler.Add(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, packetHandler) 217 err := deadlineError{} 218 Expect(err.Temporary()).To(BeTrue()) 219 packetChan <- packetToRead{err: err} 220 // don't EXPECT any calls to packetHandler.destroy 221 time.Sleep(50 * time.Millisecond) 222 }) 223 224 It("says if a connection ID is already taken", func() { 225 connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} 226 Expect(handler.Add(connID, NewMockPacketHandler(mockCtrl))).To(BeTrue()) 227 Expect(handler.Add(connID, NewMockPacketHandler(mockCtrl))).To(BeFalse()) 228 }) 229 230 It("says if a connection ID is already taken, for AddWithConnID", func() { 231 clientDestConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} 232 newConnID1 := protocol.ConnectionID{1, 2, 3, 4} 233 newConnID2 := protocol.ConnectionID{4, 3, 2, 1} 234 Expect(handler.AddWithConnID(clientDestConnID, newConnID1, func() packetHandler { return NewMockPacketHandler(mockCtrl) })).To(BeTrue()) 235 Expect(handler.AddWithConnID(clientDestConnID, newConnID2, func() packetHandler { return NewMockPacketHandler(mockCtrl) })).To(BeFalse()) 236 }) 237 }) 238 239 Context("running a server", func() { 240 It("adds a server", func() { 241 connID := protocol.ConnectionID{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88} 242 p := getPacket(connID) 243 server := NewMockUnknownPacketHandler(mockCtrl) 244 server.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) { 245 cid, err := wire.ParseConnectionID(p.data, 0) 246 Expect(err).ToNot(HaveOccurred()) 247 Expect(cid).To(Equal(connID)) 248 }) 249 handler.SetServer(server) 250 handler.handlePacket(&receivedPacket{data: p}) 251 }) 252 253 It("closes all server sessions", func() { 254 handler.SetServer(NewMockUnknownPacketHandler(mockCtrl)) 255 clientSess := NewMockPacketHandler(mockCtrl) 256 clientSess.EXPECT().getPerspective().Return(protocol.PerspectiveClient) 257 serverSess := NewMockPacketHandler(mockCtrl) 258 serverSess.EXPECT().getPerspective().Return(protocol.PerspectiveServer) 259 serverSess.EXPECT().shutdown() 260 261 handler.Add(protocol.ConnectionID{1, 1, 1, 1}, clientSess) 262 handler.Add(protocol.ConnectionID{2, 2, 2, 2}, serverSess) 263 handler.CloseServer() 264 }) 265 266 It("stops handling packets with unknown connection IDs after the server is closed", func() { 267 connID := protocol.ConnectionID{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88} 268 p := getPacket(connID) 269 server := NewMockUnknownPacketHandler(mockCtrl) 270 // don't EXPECT any calls to server.handlePacket 271 handler.SetServer(server) 272 handler.CloseServer() 273 handler.handlePacket(&receivedPacket{data: p}) 274 }) 275 }) 276 277 Context("0-RTT", func() { 278 JustBeforeEach(func() { 279 handler.zeroRTTQueueDuration = time.Hour 280 server := NewMockUnknownPacketHandler(mockCtrl) 281 // we don't expect any calls to server.handlePacket 282 handler.SetServer(server) 283 }) 284 285 It("queues 0-RTT packets", func() { 286 server := NewMockUnknownPacketHandler(mockCtrl) 287 // don't EXPECT any calls to server.handlePacket 288 handler.SetServer(server) 289 connID := protocol.ConnectionID{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88} 290 p1 := &receivedPacket{data: getPacketWithPacketType(connID, protocol.PacketType0RTT, 1)} 291 p2 := &receivedPacket{data: getPacketWithPacketType(connID, protocol.PacketType0RTT, 2)} 292 p3 := &receivedPacket{data: getPacketWithPacketType(connID, protocol.PacketType0RTT, 3)} 293 handler.handlePacket(p1) 294 handler.handlePacket(p2) 295 handler.handlePacket(p3) 296 sess := NewMockPacketHandler(mockCtrl) 297 done := make(chan struct{}) 298 gomock.InOrder( 299 sess.EXPECT().handlePacket(p1), 300 sess.EXPECT().handlePacket(p2), 301 sess.EXPECT().handlePacket(p3).Do(func(packet *receivedPacket) { close(done) }), 302 ) 303 handler.AddWithConnID(connID, protocol.ConnectionID{1, 2, 3, 4}, func() packetHandler { return sess }) 304 Eventually(done).Should(BeClosed()) 305 }) 306 307 It("directs 0-RTT packets to existing sessions", func() { 308 connID := protocol.ConnectionID{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88} 309 sess := NewMockPacketHandler(mockCtrl) 310 handler.AddWithConnID(connID, protocol.ConnectionID{1, 2, 3, 4}, func() packetHandler { return sess }) 311 p1 := &receivedPacket{data: getPacketWithPacketType(connID, protocol.PacketType0RTT, 1)} 312 sess.EXPECT().handlePacket(p1) 313 handler.handlePacket(p1) 314 }) 315 316 It("limits the number of 0-RTT queues", func() { 317 for i := 0; i < protocol.Max0RTTQueues; i++ { 318 connID := make(protocol.ConnectionID, 8) 319 rand.Read(connID) 320 p := &receivedPacket{data: getPacketWithPacketType(connID, protocol.PacketType0RTT, 1)} 321 handler.handlePacket(p) 322 } 323 // We're already storing the maximum number of queues. This packet will be dropped. 324 connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9} 325 handler.handlePacket(&receivedPacket{data: getPacketWithPacketType(connID, protocol.PacketType0RTT, 1)}) 326 // Don't EXPECT any handlePacket() calls. 327 sess := NewMockPacketHandler(mockCtrl) 328 handler.AddWithConnID(connID, protocol.ConnectionID{1, 2, 3, 4}, func() packetHandler { return sess }) 329 time.Sleep(20 * time.Millisecond) 330 }) 331 332 It("deletes queues if no session is created for this connection ID", func() { 333 queueDuration := scaleDuration(10 * time.Millisecond) 334 handler.zeroRTTQueueDuration = queueDuration 335 336 server := NewMockUnknownPacketHandler(mockCtrl) 337 // don't EXPECT any calls to server.handlePacket 338 handler.SetServer(server) 339 connID := protocol.ConnectionID{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88} 340 p1 := &receivedPacket{ 341 data: getPacketWithPacketType(connID, protocol.PacketType0RTT, 1), 342 buffer: getPacketBuffer(), 343 } 344 p2 := &receivedPacket{ 345 data: getPacketWithPacketType(connID, protocol.PacketType0RTT, 2), 346 buffer: getPacketBuffer(), 347 } 348 handler.handlePacket(p1) 349 handler.handlePacket(p2) 350 // wait a bit. The queue should now already be deleted. 351 time.Sleep(queueDuration * 3) 352 // Don't EXPECT any handlePacket() calls. 353 sess := NewMockPacketHandler(mockCtrl) 354 handler.AddWithConnID(connID, protocol.ConnectionID{1, 2, 3, 4}, func() packetHandler { return sess }) 355 time.Sleep(20 * time.Millisecond) 356 }) 357 }) 358 359 Context("stateless resets", func() { 360 BeforeEach(func() { 361 connIDLen = 5 362 }) 363 364 Context("handling", func() { 365 It("handles stateless resets", func() { 366 packetHandler := NewMockPacketHandler(mockCtrl) 367 token := protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} 368 handler.AddResetToken(token, packetHandler) 369 destroyed := make(chan struct{}) 370 packet := append([]byte{0x40} /* short header packet */, make([]byte, 50)...) 371 packet = append(packet, token[:]...) 372 packetHandler.EXPECT().destroy(gomock.Any()).Do(func(err error) { 373 defer GinkgoRecover() 374 defer close(destroyed) 375 Expect(err).To(HaveOccurred()) 376 var resetErr *StatelessResetError 377 Expect(errors.As(err, &resetErr)).To(BeTrue()) 378 Expect(err.Error()).To(ContainSubstring("received a stateless reset")) 379 Expect(resetErr.Token).To(Equal(token)) 380 }) 381 packetChan <- packetToRead{data: packet} 382 Eventually(destroyed).Should(BeClosed()) 383 }) 384 385 It("handles stateless resets for 0-length connection IDs", func() { 386 handler.connIDLen = 0 387 packetHandler := NewMockPacketHandler(mockCtrl) 388 token := protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} 389 handler.AddResetToken(token, packetHandler) 390 destroyed := make(chan struct{}) 391 packet := append([]byte{0x40} /* short header packet */, make([]byte, 50)...) 392 packet = append(packet, token[:]...) 393 packetHandler.EXPECT().destroy(gomock.Any()).Do(func(err error) { 394 defer GinkgoRecover() 395 Expect(err).To(HaveOccurred()) 396 var resetErr *StatelessResetError 397 Expect(errors.As(err, &resetErr)).To(BeTrue()) 398 Expect(err.Error()).To(ContainSubstring("received a stateless reset")) 399 Expect(resetErr.Token).To(Equal(token)) 400 close(destroyed) 401 }) 402 packetChan <- packetToRead{data: packet} 403 Eventually(destroyed).Should(BeClosed()) 404 }) 405 406 It("removes reset tokens", func() { 407 connID := protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0x42} 408 packetHandler := NewMockPacketHandler(mockCtrl) 409 handler.Add(connID, packetHandler) 410 token := protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} 411 handler.AddResetToken(token, NewMockPacketHandler(mockCtrl)) 412 handler.RemoveResetToken(token) 413 // don't EXPECT any call to packetHandler.destroy() 414 packetHandler.EXPECT().handlePacket(gomock.Any()) 415 p := append([]byte{0x40} /* short header packet */, connID.Bytes()...) 416 p = append(p, make([]byte, 50)...) 417 p = append(p, token[:]...) 418 419 handler.handlePacket(&receivedPacket{data: p}) 420 }) 421 422 It("ignores packets too small to contain a stateless reset", func() { 423 handler.connIDLen = 0 424 packetHandler := NewMockPacketHandler(mockCtrl) 425 token := protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} 426 handler.AddResetToken(token, packetHandler) 427 done := make(chan struct{}) 428 // don't EXPECT any calls here, but register the closing of the done channel 429 packetHandler.EXPECT().destroy(gomock.Any()).Do(func(error) { 430 close(done) 431 }).AnyTimes() 432 packetChan <- packetToRead{data: append([]byte{0x40} /* short header packet */, token[:15]...)} 433 Consistently(done).ShouldNot(BeClosed()) 434 }) 435 }) 436 437 Context("generating", func() { 438 BeforeEach(func() { 439 key := make([]byte, 32) 440 rand.Read(key) 441 statelessResetKey = key 442 }) 443 444 It("generates stateless reset tokens", func() { 445 connID1 := []byte{0xde, 0xad, 0xbe, 0xef} 446 connID2 := []byte{0xde, 0xca, 0xfb, 0xad} 447 Expect(handler.GetStatelessResetToken(connID1)).ToNot(Equal(handler.GetStatelessResetToken(connID2))) 448 }) 449 450 It("sends stateless resets", func() { 451 addr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337} 452 p := append([]byte{40}, make([]byte, 100)...) 453 done := make(chan struct{}) 454 conn.EXPECT().WriteTo(gomock.Any(), addr).Do(func(b []byte, _ net.Addr) { 455 defer close(done) 456 Expect(b[0] & 0x80).To(BeZero()) // short header packet 457 Expect(b).To(HaveLen(protocol.MinStatelessResetSize)) 458 }) 459 handler.handlePacket(&receivedPacket{ 460 buffer: getPacketBuffer(), 461 remoteAddr: addr, 462 data: p, 463 }) 464 Eventually(done).Should(BeClosed()) 465 }) 466 467 It("doesn't send stateless resets for small packets", func() { 468 addr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337} 469 p := append([]byte{40}, make([]byte, protocol.MinStatelessResetSize-2)...) 470 handler.handlePacket(&receivedPacket{ 471 buffer: getPacketBuffer(), 472 remoteAddr: addr, 473 data: p, 474 }) 475 // make sure there are no Write calls on the packet conn 476 time.Sleep(50 * time.Millisecond) 477 }) 478 }) 479 480 Context("if no key is configured", func() { 481 It("doesn't send stateless resets", func() { 482 addr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337} 483 p := append([]byte{40}, make([]byte, 100)...) 484 handler.handlePacket(&receivedPacket{ 485 buffer: getPacketBuffer(), 486 remoteAddr: addr, 487 data: p, 488 }) 489 // make sure there are no Write calls on the packet conn 490 time.Sleep(50 * time.Millisecond) 491 }) 492 }) 493 }) 494 }) 495}) 496