1package quic 2 3import ( 4 "bytes" 5 "context" 6 "crypto/rand" 7 "crypto/tls" 8 "errors" 9 "net" 10 "reflect" 11 "runtime/pprof" 12 "strings" 13 "sync" 14 "sync/atomic" 15 "time" 16 17 "github.com/lucas-clemente/quic-go/internal/handshake" 18 mocklogging "github.com/lucas-clemente/quic-go/internal/mocks/logging" 19 "github.com/lucas-clemente/quic-go/internal/protocol" 20 "github.com/lucas-clemente/quic-go/internal/qerr" 21 "github.com/lucas-clemente/quic-go/internal/testdata" 22 "github.com/lucas-clemente/quic-go/internal/utils" 23 "github.com/lucas-clemente/quic-go/internal/wire" 24 "github.com/lucas-clemente/quic-go/logging" 25 26 "github.com/golang/mock/gomock" 27 28 . "github.com/onsi/ginkgo" 29 . "github.com/onsi/gomega" 30) 31 32func areServersRunning() bool { 33 var b bytes.Buffer 34 pprof.Lookup("goroutine").WriteTo(&b, 1) 35 return strings.Contains(b.String(), "quic-go.(*baseServer).run") 36} 37 38var _ = Describe("Server", func() { 39 var ( 40 conn *MockPacketConn 41 tlsConf *tls.Config 42 ) 43 44 getPacket := func(hdr *wire.Header, p []byte) *receivedPacket { 45 buffer := getPacketBuffer() 46 buf := bytes.NewBuffer(buffer.Data) 47 if hdr.IsLongHeader { 48 hdr.Length = 4 + protocol.ByteCount(len(p)) + 16 49 } 50 Expect((&wire.ExtendedHeader{ 51 Header: *hdr, 52 PacketNumber: 0x42, 53 PacketNumberLen: protocol.PacketNumberLen4, 54 }).Write(buf, protocol.VersionTLS)).To(Succeed()) 55 n := buf.Len() 56 buf.Write(p) 57 data := buffer.Data[:buf.Len()] 58 sealer, _ := handshake.NewInitialAEAD(hdr.DestConnectionID, protocol.PerspectiveClient, hdr.Version) 59 _ = sealer.Seal(data[n:n], data[n:], 0x42, data[:n]) 60 data = data[:len(data)+16] 61 sealer.EncryptHeader(data[n:n+16], &data[0], data[n-4:n]) 62 return &receivedPacket{ 63 remoteAddr: &net.UDPAddr{IP: net.IPv4(4, 5, 6, 7), Port: 456}, 64 data: data, 65 buffer: buffer, 66 } 67 } 68 69 getInitial := func(destConnID protocol.ConnectionID) *receivedPacket { 70 senderAddr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 42} 71 hdr := &wire.Header{ 72 IsLongHeader: true, 73 Type: protocol.PacketTypeInitial, 74 SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1}, 75 DestConnectionID: destConnID, 76 Version: protocol.VersionTLS, 77 } 78 p := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize)) 79 p.buffer = getPacketBuffer() 80 p.remoteAddr = senderAddr 81 return p 82 } 83 84 getInitialWithRandomDestConnID := func() *receivedPacket { 85 destConnID := make([]byte, 10) 86 _, err := rand.Read(destConnID) 87 Expect(err).ToNot(HaveOccurred()) 88 89 return getInitial(destConnID) 90 } 91 92 parseHeader := func(data []byte) *wire.Header { 93 hdr, _, _, err := wire.ParsePacket(data, 0) 94 Expect(err).ToNot(HaveOccurred()) 95 return hdr 96 } 97 98 BeforeEach(func() { 99 conn = NewMockPacketConn(mockCtrl) 100 conn.EXPECT().LocalAddr().Return(&net.UDPAddr{}).AnyTimes() 101 conn.EXPECT().ReadFrom(gomock.Any()).Do(func(_ []byte) { <-(make(chan struct{})) }).MaxTimes(1) 102 tlsConf = testdata.GetTLSConfig() 103 tlsConf.NextProtos = []string{"proto1"} 104 }) 105 106 AfterEach(func() { 107 Eventually(areServersRunning).Should(BeFalse()) 108 }) 109 110 It("errors when no tls.Config is given", func() { 111 _, err := ListenAddr("localhost:0", nil, nil) 112 Expect(err).To(HaveOccurred()) 113 Expect(err.Error()).To(ContainSubstring("quic: tls.Config not set")) 114 }) 115 116 It("errors when the Config contains an invalid version", func() { 117 version := protocol.VersionNumber(0x1234) 118 _, err := Listen(nil, tlsConf, &Config{Versions: []protocol.VersionNumber{version}}) 119 Expect(err).To(MatchError("0x1234 is not a valid QUIC version")) 120 }) 121 122 It("fills in default values if options are not set in the Config", func() { 123 ln, err := Listen(conn, tlsConf, &Config{}) 124 Expect(err).ToNot(HaveOccurred()) 125 server := ln.(*baseServer) 126 Expect(server.config.Versions).To(Equal(protocol.SupportedVersions)) 127 Expect(server.config.HandshakeIdleTimeout).To(Equal(protocol.DefaultHandshakeIdleTimeout)) 128 Expect(server.config.MaxIdleTimeout).To(Equal(protocol.DefaultIdleTimeout)) 129 Expect(reflect.ValueOf(server.config.AcceptToken)).To(Equal(reflect.ValueOf(defaultAcceptToken))) 130 Expect(server.config.KeepAlive).To(BeFalse()) 131 // stop the listener 132 Expect(ln.Close()).To(Succeed()) 133 }) 134 135 It("setups with the right values", func() { 136 supportedVersions := []protocol.VersionNumber{protocol.VersionTLS} 137 acceptToken := func(_ net.Addr, _ *Token) bool { return true } 138 config := Config{ 139 Versions: supportedVersions, 140 AcceptToken: acceptToken, 141 HandshakeIdleTimeout: 1337 * time.Hour, 142 MaxIdleTimeout: 42 * time.Minute, 143 KeepAlive: true, 144 StatelessResetKey: []byte("foobar"), 145 } 146 ln, err := Listen(conn, tlsConf, &config) 147 Expect(err).ToNot(HaveOccurred()) 148 server := ln.(*baseServer) 149 Expect(server.sessionHandler).ToNot(BeNil()) 150 Expect(server.config.Versions).To(Equal(supportedVersions)) 151 Expect(server.config.HandshakeIdleTimeout).To(Equal(1337 * time.Hour)) 152 Expect(server.config.MaxIdleTimeout).To(Equal(42 * time.Minute)) 153 Expect(reflect.ValueOf(server.config.AcceptToken)).To(Equal(reflect.ValueOf(acceptToken))) 154 Expect(server.config.KeepAlive).To(BeTrue()) 155 Expect(server.config.StatelessResetKey).To(Equal([]byte("foobar"))) 156 // stop the listener 157 Expect(ln.Close()).To(Succeed()) 158 }) 159 160 It("listens on a given address", func() { 161 addr := "127.0.0.1:13579" 162 ln, err := ListenAddr(addr, tlsConf, &Config{}) 163 Expect(err).ToNot(HaveOccurred()) 164 Expect(ln.Addr().String()).To(Equal(addr)) 165 // stop the listener 166 Expect(ln.Close()).To(Succeed()) 167 }) 168 169 It("errors if given an invalid address", func() { 170 addr := "127.0.0.1" 171 _, err := ListenAddr(addr, tlsConf, &Config{}) 172 Expect(err).To(BeAssignableToTypeOf(&net.AddrError{})) 173 }) 174 175 It("errors if given an invalid address", func() { 176 addr := "1.1.1.1:1111" 177 _, err := ListenAddr(addr, tlsConf, &Config{}) 178 Expect(err).To(BeAssignableToTypeOf(&net.OpError{})) 179 }) 180 181 Context("server accepting sessions that completed the handshake", func() { 182 var ( 183 serv *baseServer 184 phm *MockPacketHandlerManager 185 tracer *mocklogging.MockTracer 186 ) 187 188 BeforeEach(func() { 189 tracer = mocklogging.NewMockTracer(mockCtrl) 190 ln, err := Listen(conn, tlsConf, &Config{Tracer: tracer}) 191 Expect(err).ToNot(HaveOccurred()) 192 serv = ln.(*baseServer) 193 phm = NewMockPacketHandlerManager(mockCtrl) 194 serv.sessionHandler = phm 195 }) 196 197 AfterEach(func() { 198 phm.EXPECT().CloseServer().MaxTimes(1) 199 serv.Close() 200 }) 201 202 Context("handling packets", func() { 203 It("drops Initial packets with a too short connection ID", func() { 204 p := getPacket(&wire.Header{ 205 IsLongHeader: true, 206 Type: protocol.PacketTypeInitial, 207 DestConnectionID: protocol.ConnectionID{1, 2, 3, 4}, 208 Version: serv.config.Versions[0], 209 }, nil) 210 tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropUnexpectedPacket) 211 serv.handlePacket(p) 212 // make sure there are no Write calls on the packet conn 213 time.Sleep(50 * time.Millisecond) 214 }) 215 216 It("drops too small Initial", func() { 217 p := getPacket(&wire.Header{ 218 IsLongHeader: true, 219 Type: protocol.PacketTypeInitial, 220 DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, 221 Version: serv.config.Versions[0], 222 }, make([]byte, protocol.MinInitialPacketSize-100), 223 ) 224 tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropUnexpectedPacket) 225 serv.handlePacket(p) 226 // make sure there are no Write calls on the packet conn 227 time.Sleep(50 * time.Millisecond) 228 }) 229 230 It("drops non-Initial packets", func() { 231 p := getPacket(&wire.Header{ 232 IsLongHeader: true, 233 Type: protocol.PacketTypeHandshake, 234 Version: serv.config.Versions[0], 235 }, []byte("invalid")) 236 tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketTypeHandshake, p.Size(), logging.PacketDropUnexpectedPacket) 237 serv.handlePacket(p) 238 // make sure there are no Write calls on the packet conn 239 time.Sleep(50 * time.Millisecond) 240 }) 241 242 It("decodes the token from the Token field", func() { 243 raddr := &net.UDPAddr{ 244 IP: net.IPv4(192, 168, 13, 37), 245 Port: 1337, 246 } 247 done := make(chan struct{}) 248 serv.config.AcceptToken = func(addr net.Addr, token *Token) bool { 249 Expect(addr).To(Equal(raddr)) 250 Expect(token).ToNot(BeNil()) 251 close(done) 252 return false 253 } 254 token, err := serv.tokenGenerator.NewRetryToken(raddr, nil, nil) 255 Expect(err).ToNot(HaveOccurred()) 256 packet := getPacket(&wire.Header{ 257 IsLongHeader: true, 258 Type: protocol.PacketTypeInitial, 259 Token: token, 260 Version: serv.config.Versions[0], 261 }, make([]byte, protocol.MinInitialPacketSize)) 262 packet.remoteAddr = raddr 263 conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).MaxTimes(1) 264 tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).MaxTimes(1) 265 serv.handlePacket(packet) 266 Eventually(done).Should(BeClosed()) 267 }) 268 269 It("passes an empty token to the callback, if decoding fails", func() { 270 raddr := &net.UDPAddr{ 271 IP: net.IPv4(192, 168, 13, 37), 272 Port: 1337, 273 } 274 done := make(chan struct{}) 275 serv.config.AcceptToken = func(addr net.Addr, token *Token) bool { 276 Expect(addr).To(Equal(raddr)) 277 Expect(token).To(BeNil()) 278 close(done) 279 return false 280 } 281 packet := getPacket(&wire.Header{ 282 IsLongHeader: true, 283 Type: protocol.PacketTypeInitial, 284 Token: []byte("foobar"), 285 Version: serv.config.Versions[0], 286 }, make([]byte, protocol.MinInitialPacketSize)) 287 packet.remoteAddr = raddr 288 conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).MaxTimes(1) 289 tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).MaxTimes(1) 290 serv.handlePacket(packet) 291 Eventually(done).Should(BeClosed()) 292 }) 293 294 It("creates a session when the token is accepted", func() { 295 serv.config.AcceptToken = func(_ net.Addr, token *Token) bool { return true } 296 retryToken, err := serv.tokenGenerator.NewRetryToken( 297 &net.UDPAddr{}, 298 protocol.ConnectionID{0xde, 0xad, 0xc0, 0xde}, 299 protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad}, 300 ) 301 Expect(err).ToNot(HaveOccurred()) 302 hdr := &wire.Header{ 303 IsLongHeader: true, 304 Type: protocol.PacketTypeInitial, 305 SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1}, 306 DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, 307 Version: protocol.VersionTLS, 308 Token: retryToken, 309 } 310 p := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize)) 311 run := make(chan struct{}) 312 var token protocol.StatelessResetToken 313 rand.Read(token[:]) 314 315 var newConnID protocol.ConnectionID 316 phm.EXPECT().AddWithConnID(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, gomock.Any(), gomock.Any()).DoAndReturn(func(_, c protocol.ConnectionID, fn func() packetHandler) bool { 317 newConnID = c 318 phm.EXPECT().GetStatelessResetToken(gomock.Any()).DoAndReturn(func(c protocol.ConnectionID) protocol.StatelessResetToken { 319 newConnID = c 320 return token 321 }) 322 fn() 323 return true 324 }) 325 tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, protocol.ConnectionID{0xde, 0xad, 0xc0, 0xde}) 326 sess := NewMockQuicSession(mockCtrl) 327 serv.newSession = func( 328 _ sendConn, 329 _ sessionRunner, 330 origDestConnID protocol.ConnectionID, 331 retrySrcConnID *protocol.ConnectionID, 332 clientDestConnID protocol.ConnectionID, 333 destConnID protocol.ConnectionID, 334 srcConnID protocol.ConnectionID, 335 tokenP protocol.StatelessResetToken, 336 _ *Config, 337 _ *tls.Config, 338 _ *handshake.TokenGenerator, 339 enable0RTT bool, 340 _ logging.ConnectionTracer, 341 _ uint64, 342 _ utils.Logger, 343 _ protocol.VersionNumber, 344 ) quicSession { 345 Expect(enable0RTT).To(BeFalse()) 346 Expect(origDestConnID).To(Equal(protocol.ConnectionID{0xde, 0xad, 0xc0, 0xde})) 347 Expect(retrySrcConnID).To(Equal(&protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad})) 348 Expect(clientDestConnID).To(Equal(hdr.DestConnectionID)) 349 Expect(destConnID).To(Equal(hdr.SrcConnectionID)) 350 // make sure we're using a server-generated connection ID 351 Expect(srcConnID).ToNot(Equal(hdr.DestConnectionID)) 352 Expect(srcConnID).ToNot(Equal(hdr.SrcConnectionID)) 353 Expect(srcConnID).To(Equal(newConnID)) 354 Expect(tokenP).To(Equal(token)) 355 sess.EXPECT().handlePacket(p) 356 sess.EXPECT().run().Do(func() { close(run) }) 357 sess.EXPECT().Context().Return(context.Background()) 358 sess.EXPECT().HandshakeComplete().Return(context.Background()) 359 return sess 360 } 361 362 done := make(chan struct{}) 363 go func() { 364 defer GinkgoRecover() 365 serv.handlePacket(p) 366 // the Handshake packet is written by the session. 367 // Make sure there are no Write calls on the packet conn. 368 time.Sleep(50 * time.Millisecond) 369 close(done) 370 }() 371 // make sure we're using a server-generated connection ID 372 Eventually(run).Should(BeClosed()) 373 Eventually(done).Should(BeClosed()) 374 }) 375 376 It("sends a Version Negotiation Packet for unsupported versions", func() { 377 srcConnID := protocol.ConnectionID{1, 2, 3, 4, 5} 378 destConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6} 379 packet := getPacket(&wire.Header{ 380 IsLongHeader: true, 381 Type: protocol.PacketTypeHandshake, 382 SrcConnectionID: srcConnID, 383 DestConnectionID: destConnID, 384 Version: 0x42, 385 }, make([]byte, protocol.MinUnknownVersionPacketSize)) 386 raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} 387 packet.remoteAddr = raddr 388 tracer.EXPECT().SentPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), nil).Do(func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, _ []logging.Frame) { 389 Expect(replyHdr.IsLongHeader).To(BeTrue()) 390 Expect(replyHdr.Version).To(BeZero()) 391 Expect(replyHdr.SrcConnectionID).To(Equal(destConnID)) 392 Expect(replyHdr.DestConnectionID).To(Equal(srcConnID)) 393 }) 394 done := make(chan struct{}) 395 conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) { 396 defer close(done) 397 Expect(wire.IsVersionNegotiationPacket(b)).To(BeTrue()) 398 hdr, versions, err := wire.ParseVersionNegotiationPacket(bytes.NewReader(b)) 399 Expect(err).ToNot(HaveOccurred()) 400 Expect(hdr.DestConnectionID).To(Equal(srcConnID)) 401 Expect(hdr.SrcConnectionID).To(Equal(destConnID)) 402 Expect(versions).ToNot(ContainElement(protocol.VersionNumber(0x42))) 403 return len(b), nil 404 }) 405 serv.handlePacket(packet) 406 Eventually(done).Should(BeClosed()) 407 }) 408 409 It("doesn't send a Version Negotiation packets if sending them is disabled", func() { 410 serv.config.DisableVersionNegotiationPackets = true 411 srcConnID := protocol.ConnectionID{1, 2, 3, 4, 5} 412 destConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6} 413 packet := getPacket(&wire.Header{ 414 IsLongHeader: true, 415 Type: protocol.PacketTypeHandshake, 416 SrcConnectionID: srcConnID, 417 DestConnectionID: destConnID, 418 Version: 0x42, 419 }, make([]byte, protocol.MinUnknownVersionPacketSize)) 420 raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} 421 packet.remoteAddr = raddr 422 done := make(chan struct{}) 423 conn.EXPECT().WriteTo(gomock.Any(), raddr).Do(func() { close(done) }).Times(0) 424 serv.handlePacket(packet) 425 Consistently(done, 50*time.Millisecond).ShouldNot(BeClosed()) 426 }) 427 428 It("ignores Version Negotiation packets", func() { 429 data, err := wire.ComposeVersionNegotiation( 430 protocol.ConnectionID{1, 2, 3, 4}, 431 protocol.ConnectionID{4, 3, 2, 1}, 432 []protocol.VersionNumber{1, 2, 3}, 433 ) 434 Expect(err).ToNot(HaveOccurred()) 435 raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} 436 done := make(chan struct{}) 437 tracer.EXPECT().DroppedPacket(raddr, logging.PacketTypeVersionNegotiation, protocol.ByteCount(len(data)), logging.PacketDropUnexpectedPacket).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) { 438 close(done) 439 }) 440 serv.handlePacket(&receivedPacket{ 441 remoteAddr: raddr, 442 data: data, 443 buffer: getPacketBuffer(), 444 }) 445 Eventually(done).Should(BeClosed()) 446 // make sure no other packet is sent 447 time.Sleep(scaleDuration(20 * time.Millisecond)) 448 }) 449 450 It("doesn't send a Version Negotiation Packet for unsupported versions, if the packet is too small", func() { 451 srcConnID := protocol.ConnectionID{1, 2, 3, 4, 5} 452 destConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6} 453 p := getPacket(&wire.Header{ 454 IsLongHeader: true, 455 Type: protocol.PacketTypeHandshake, 456 SrcConnectionID: srcConnID, 457 DestConnectionID: destConnID, 458 Version: 0x42, 459 }, make([]byte, protocol.MinUnknownVersionPacketSize-50)) 460 Expect(p.Size()).To(BeNumerically("<", protocol.MinUnknownVersionPacketSize)) 461 raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} 462 p.remoteAddr = raddr 463 done := make(chan struct{}) 464 tracer.EXPECT().DroppedPacket(raddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropUnexpectedPacket).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) { 465 close(done) 466 }) 467 serv.handlePacket(p) 468 Eventually(done).Should(BeClosed()) 469 // make sure no other packet is sent 470 time.Sleep(scaleDuration(20 * time.Millisecond)) 471 }) 472 473 It("replies with a Retry packet, if a Token is required", func() { 474 serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return false } 475 hdr := &wire.Header{ 476 IsLongHeader: true, 477 Type: protocol.PacketTypeInitial, 478 SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1}, 479 DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, 480 Version: protocol.VersionTLS, 481 } 482 packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize)) 483 raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} 484 packet.remoteAddr = raddr 485 tracer.EXPECT().SentPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), nil).Do(func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, _ []logging.Frame) { 486 Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry)) 487 Expect(replyHdr.SrcConnectionID).ToNot(Equal(hdr.DestConnectionID)) 488 Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID)) 489 Expect(replyHdr.Token).ToNot(BeEmpty()) 490 }) 491 done := make(chan struct{}) 492 conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) { 493 defer close(done) 494 replyHdr := parseHeader(b) 495 Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry)) 496 Expect(replyHdr.SrcConnectionID).ToNot(Equal(hdr.DestConnectionID)) 497 Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID)) 498 Expect(replyHdr.Token).ToNot(BeEmpty()) 499 Expect(b[len(b)-16:]).To(Equal(handshake.GetRetryIntegrityTag(b[:len(b)-16], hdr.DestConnectionID, hdr.Version)[:])) 500 return len(b), nil 501 }) 502 serv.handlePacket(packet) 503 Eventually(done).Should(BeClosed()) 504 }) 505 506 It("sends an INVALID_TOKEN error, if an invalid retry token is received", func() { 507 serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return false } 508 token, err := serv.tokenGenerator.NewRetryToken(&net.UDPAddr{}, nil, nil) 509 Expect(err).ToNot(HaveOccurred()) 510 hdr := &wire.Header{ 511 IsLongHeader: true, 512 Type: protocol.PacketTypeInitial, 513 SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1}, 514 DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, 515 Token: token, 516 Version: protocol.VersionTLS, 517 } 518 packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize)) 519 packet.data = append(packet.data, []byte("coalesced packet")...) // add some garbage to simulate a coalesced packet 520 raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} 521 packet.remoteAddr = raddr 522 tracer.EXPECT().SentPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, frames []logging.Frame) { 523 Expect(replyHdr.Type).To(Equal(protocol.PacketTypeInitial)) 524 Expect(replyHdr.SrcConnectionID).To(Equal(hdr.DestConnectionID)) 525 Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID)) 526 Expect(frames).To(HaveLen(1)) 527 Expect(frames[0]).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{})) 528 ccf := frames[0].(*logging.ConnectionCloseFrame) 529 Expect(ccf.IsApplicationError).To(BeFalse()) 530 Expect(ccf.ErrorCode).To(BeEquivalentTo(qerr.InvalidToken)) 531 }) 532 done := make(chan struct{}) 533 conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) { 534 defer close(done) 535 replyHdr := parseHeader(b) 536 Expect(replyHdr.Type).To(Equal(protocol.PacketTypeInitial)) 537 Expect(replyHdr.SrcConnectionID).To(Equal(hdr.DestConnectionID)) 538 Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID)) 539 _, opener := handshake.NewInitialAEAD(hdr.DestConnectionID, protocol.PerspectiveClient, replyHdr.Version) 540 extHdr, err := unpackHeader(opener, replyHdr, b, hdr.Version) 541 Expect(err).ToNot(HaveOccurred()) 542 data, err := opener.Open(nil, b[extHdr.ParsedLen():], extHdr.PacketNumber, b[:extHdr.ParsedLen()]) 543 Expect(err).ToNot(HaveOccurred()) 544 f, err := wire.NewFrameParser(false, hdr.Version).ParseNext(bytes.NewReader(data), protocol.EncryptionInitial) 545 Expect(err).ToNot(HaveOccurred()) 546 Expect(f).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{})) 547 ccf := f.(*wire.ConnectionCloseFrame) 548 Expect(ccf.ErrorCode).To(BeEquivalentTo(qerr.InvalidToken)) 549 Expect(ccf.ReasonPhrase).To(BeEmpty()) 550 return len(b), nil 551 }) 552 serv.handlePacket(packet) 553 Eventually(done).Should(BeClosed()) 554 }) 555 556 It("doesn't send an INVALID_TOKEN error, if the packet is corrupted", func() { 557 serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return false } 558 token, err := serv.tokenGenerator.NewRetryToken(&net.UDPAddr{}, nil, nil) 559 Expect(err).ToNot(HaveOccurred()) 560 hdr := &wire.Header{ 561 IsLongHeader: true, 562 Type: protocol.PacketTypeInitial, 563 SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1}, 564 DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, 565 Token: token, 566 Version: protocol.VersionTLS, 567 } 568 packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize)) 569 packet.data[len(packet.data)-10] ^= 0xff // corrupt the packet 570 packet.remoteAddr = &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} 571 done := make(chan struct{}) 572 tracer.EXPECT().DroppedPacket(packet.remoteAddr, logging.PacketTypeInitial, packet.Size(), logging.PacketDropPayloadDecryptError).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) { close(done) }) 573 serv.handlePacket(packet) 574 // make sure there are no Write calls on the packet conn 575 time.Sleep(50 * time.Millisecond) 576 Eventually(done).Should(BeClosed()) 577 }) 578 579 It("creates a session, if no Token is required", func() { 580 serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true } 581 hdr := &wire.Header{ 582 IsLongHeader: true, 583 Type: protocol.PacketTypeInitial, 584 SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1}, 585 DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, 586 Version: protocol.VersionTLS, 587 } 588 p := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize)) 589 run := make(chan struct{}) 590 var token protocol.StatelessResetToken 591 rand.Read(token[:]) 592 593 var newConnID protocol.ConnectionID 594 phm.EXPECT().AddWithConnID(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, gomock.Any(), gomock.Any()).DoAndReturn(func(_, c protocol.ConnectionID, fn func() packetHandler) bool { 595 newConnID = c 596 phm.EXPECT().GetStatelessResetToken(gomock.Any()).DoAndReturn(func(c protocol.ConnectionID) protocol.StatelessResetToken { 597 newConnID = c 598 return token 599 }) 600 fn() 601 return true 602 }) 603 tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) 604 605 sess := NewMockQuicSession(mockCtrl) 606 serv.newSession = func( 607 _ sendConn, 608 _ sessionRunner, 609 origDestConnID protocol.ConnectionID, 610 retrySrcConnID *protocol.ConnectionID, 611 clientDestConnID protocol.ConnectionID, 612 destConnID protocol.ConnectionID, 613 srcConnID protocol.ConnectionID, 614 tokenP protocol.StatelessResetToken, 615 _ *Config, 616 _ *tls.Config, 617 _ *handshake.TokenGenerator, 618 enable0RTT bool, 619 _ logging.ConnectionTracer, 620 _ uint64, 621 _ utils.Logger, 622 _ protocol.VersionNumber, 623 ) quicSession { 624 Expect(enable0RTT).To(BeFalse()) 625 Expect(origDestConnID).To(Equal(hdr.DestConnectionID)) 626 Expect(retrySrcConnID).To(BeNil()) 627 Expect(clientDestConnID).To(Equal(hdr.DestConnectionID)) 628 Expect(destConnID).To(Equal(hdr.SrcConnectionID)) 629 // make sure we're using a server-generated connection ID 630 Expect(srcConnID).ToNot(Equal(hdr.DestConnectionID)) 631 Expect(srcConnID).ToNot(Equal(hdr.SrcConnectionID)) 632 Expect(srcConnID).To(Equal(newConnID)) 633 Expect(tokenP).To(Equal(token)) 634 sess.EXPECT().handlePacket(p) 635 sess.EXPECT().run().Do(func() { close(run) }) 636 sess.EXPECT().Context().Return(context.Background()) 637 sess.EXPECT().HandshakeComplete().Return(context.Background()) 638 return sess 639 } 640 641 done := make(chan struct{}) 642 go func() { 643 defer GinkgoRecover() 644 serv.handlePacket(p) 645 // the Handshake packet is written by the session 646 // make sure there are no Write calls on the packet conn 647 time.Sleep(50 * time.Millisecond) 648 close(done) 649 }() 650 // make sure we're using a server-generated connection ID 651 Eventually(run).Should(BeClosed()) 652 Eventually(done).Should(BeClosed()) 653 }) 654 655 It("drops packets if the receive queue is full", func() { 656 phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool { 657 phm.EXPECT().GetStatelessResetToken(gomock.Any()) 658 fn() 659 return true 660 }).AnyTimes() 661 tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, gomock.Any()).AnyTimes() 662 663 serv.config.AcceptToken = func(net.Addr, *Token) bool { return true } 664 acceptSession := make(chan struct{}) 665 var counter uint32 // to be used as an atomic, so we query it in Eventually 666 serv.newSession = func( 667 _ sendConn, 668 runner sessionRunner, 669 _ protocol.ConnectionID, 670 _ *protocol.ConnectionID, 671 _ protocol.ConnectionID, 672 _ protocol.ConnectionID, 673 _ protocol.ConnectionID, 674 _ protocol.StatelessResetToken, 675 _ *Config, 676 _ *tls.Config, 677 _ *handshake.TokenGenerator, 678 _ bool, 679 _ logging.ConnectionTracer, 680 _ uint64, 681 _ utils.Logger, 682 _ protocol.VersionNumber, 683 ) quicSession { 684 <-acceptSession 685 atomic.AddUint32(&counter, 1) 686 sess := NewMockQuicSession(mockCtrl) 687 sess.EXPECT().handlePacket(gomock.Any()).MaxTimes(1) 688 sess.EXPECT().run().MaxTimes(1) 689 sess.EXPECT().Context().Return(context.Background()).MaxTimes(1) 690 sess.EXPECT().HandshakeComplete().Return(context.Background()).MaxTimes(1) 691 return sess 692 } 693 694 p := getInitial(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}) 695 serv.handlePacket(p) 696 tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropDOSPrevention).MinTimes(1) 697 var wg sync.WaitGroup 698 for i := 0; i < 3*protocol.MaxServerUnprocessedPackets; i++ { 699 wg.Add(1) 700 go func() { 701 defer GinkgoRecover() 702 defer wg.Done() 703 serv.handlePacket(getInitial(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8})) 704 }() 705 } 706 wg.Wait() 707 708 close(acceptSession) 709 Eventually( 710 func() uint32 { return atomic.LoadUint32(&counter) }, 711 scaleDuration(100*time.Millisecond), 712 ).Should(BeEquivalentTo(protocol.MaxServerUnprocessedPackets + 1)) 713 Consistently(func() uint32 { return atomic.LoadUint32(&counter) }).Should(BeEquivalentTo(protocol.MaxServerUnprocessedPackets + 1)) 714 }) 715 716 It("only creates a single session for a duplicate Initial", func() { 717 serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true } 718 var createdSession bool 719 sess := NewMockQuicSession(mockCtrl) 720 serv.newSession = func( 721 _ sendConn, 722 runner sessionRunner, 723 _ protocol.ConnectionID, 724 _ *protocol.ConnectionID, 725 _ protocol.ConnectionID, 726 _ protocol.ConnectionID, 727 _ protocol.ConnectionID, 728 _ protocol.StatelessResetToken, 729 _ *Config, 730 _ *tls.Config, 731 _ *handshake.TokenGenerator, 732 _ bool, 733 _ logging.ConnectionTracer, 734 _ uint64, 735 _ utils.Logger, 736 _ protocol.VersionNumber, 737 ) quicSession { 738 createdSession = true 739 return sess 740 } 741 742 p := getInitial(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9}) 743 phm.EXPECT().AddWithConnID(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9}, gomock.Any(), gomock.Any()).Return(false) 744 Expect(serv.handlePacketImpl(p)).To(BeTrue()) 745 Expect(createdSession).To(BeFalse()) 746 }) 747 748 It("rejects new connection attempts if the accept queue is full", func() { 749 serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true } 750 751 serv.newSession = func( 752 _ sendConn, 753 runner sessionRunner, 754 _ protocol.ConnectionID, 755 _ *protocol.ConnectionID, 756 _ protocol.ConnectionID, 757 _ protocol.ConnectionID, 758 _ protocol.ConnectionID, 759 _ protocol.StatelessResetToken, 760 _ *Config, 761 _ *tls.Config, 762 _ *handshake.TokenGenerator, 763 _ bool, 764 _ logging.ConnectionTracer, 765 _ uint64, 766 _ utils.Logger, 767 _ protocol.VersionNumber, 768 ) quicSession { 769 sess := NewMockQuicSession(mockCtrl) 770 sess.EXPECT().handlePacket(gomock.Any()) 771 sess.EXPECT().run() 772 sess.EXPECT().Context().Return(context.Background()) 773 ctx, cancel := context.WithCancel(context.Background()) 774 cancel() 775 sess.EXPECT().HandshakeComplete().Return(ctx) 776 return sess 777 } 778 779 phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool { 780 phm.EXPECT().GetStatelessResetToken(gomock.Any()) 781 fn() 782 return true 783 }).Times(protocol.MaxAcceptQueueSize) 784 tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, gomock.Any()).Times(protocol.MaxAcceptQueueSize) 785 786 var wg sync.WaitGroup 787 wg.Add(protocol.MaxAcceptQueueSize) 788 for i := 0; i < protocol.MaxAcceptQueueSize; i++ { 789 go func() { 790 defer GinkgoRecover() 791 defer wg.Done() 792 serv.handlePacket(getInitialWithRandomDestConnID()) 793 // make sure there are no Write calls on the packet conn 794 time.Sleep(50 * time.Millisecond) 795 }() 796 } 797 wg.Wait() 798 p := getInitialWithRandomDestConnID() 799 hdr, _, _, err := wire.ParsePacket(p.data, 0) 800 Expect(err).ToNot(HaveOccurred()) 801 tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) 802 done := make(chan struct{}) 803 conn.EXPECT().WriteTo(gomock.Any(), p.remoteAddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) { 804 defer close(done) 805 rejectHdr := parseHeader(b) 806 Expect(rejectHdr.Type).To(Equal(protocol.PacketTypeInitial)) 807 Expect(rejectHdr.Version).To(Equal(hdr.Version)) 808 Expect(rejectHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID)) 809 Expect(rejectHdr.SrcConnectionID).To(Equal(hdr.DestConnectionID)) 810 return len(b), nil 811 }) 812 serv.handlePacket(p) 813 Eventually(done).Should(BeClosed()) 814 }) 815 816 It("doesn't accept new sessions if they were closed in the mean time", func() { 817 serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true } 818 819 p := getInitial(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) 820 ctx, cancel := context.WithCancel(context.Background()) 821 sessionCreated := make(chan struct{}) 822 sess := NewMockQuicSession(mockCtrl) 823 serv.newSession = func( 824 _ sendConn, 825 runner sessionRunner, 826 _ protocol.ConnectionID, 827 _ *protocol.ConnectionID, 828 _ protocol.ConnectionID, 829 _ protocol.ConnectionID, 830 _ protocol.ConnectionID, 831 _ protocol.StatelessResetToken, 832 _ *Config, 833 _ *tls.Config, 834 _ *handshake.TokenGenerator, 835 _ bool, 836 _ logging.ConnectionTracer, 837 _ uint64, 838 _ utils.Logger, 839 _ protocol.VersionNumber, 840 ) quicSession { 841 sess.EXPECT().handlePacket(p) 842 sess.EXPECT().run() 843 sess.EXPECT().Context().Return(ctx) 844 ctx, cancel := context.WithCancel(context.Background()) 845 cancel() 846 sess.EXPECT().HandshakeComplete().Return(ctx) 847 close(sessionCreated) 848 return sess 849 } 850 851 phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool { 852 phm.EXPECT().GetStatelessResetToken(gomock.Any()) 853 fn() 854 return true 855 }) 856 tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, gomock.Any()) 857 858 serv.handlePacket(p) 859 // make sure there are no Write calls on the packet conn 860 time.Sleep(50 * time.Millisecond) 861 Eventually(sessionCreated).Should(BeClosed()) 862 cancel() 863 time.Sleep(scaleDuration(200 * time.Millisecond)) 864 865 done := make(chan struct{}) 866 go func() { 867 defer GinkgoRecover() 868 serv.Accept(context.Background()) 869 close(done) 870 }() 871 Consistently(done).ShouldNot(BeClosed()) 872 873 // make the go routine return 874 phm.EXPECT().CloseServer() 875 sess.EXPECT().getPerspective().MaxTimes(2) // once for every conn ID 876 Expect(serv.Close()).To(Succeed()) 877 Eventually(done).Should(BeClosed()) 878 }) 879 }) 880 881 Context("accepting sessions", func() { 882 It("returns Accept when an error occurs", func() { 883 testErr := errors.New("test err") 884 885 done := make(chan struct{}) 886 go func() { 887 defer GinkgoRecover() 888 _, err := serv.Accept(context.Background()) 889 Expect(err).To(MatchError(testErr)) 890 close(done) 891 }() 892 893 serv.setCloseError(testErr) 894 Eventually(done).Should(BeClosed()) 895 }) 896 897 It("returns immediately, if an error occurred before", func() { 898 testErr := errors.New("test err") 899 serv.setCloseError(testErr) 900 for i := 0; i < 3; i++ { 901 _, err := serv.Accept(context.Background()) 902 Expect(err).To(MatchError(testErr)) 903 } 904 }) 905 906 It("returns when the context is canceled", func() { 907 ctx, cancel := context.WithCancel(context.Background()) 908 done := make(chan struct{}) 909 go func() { 910 defer GinkgoRecover() 911 _, err := serv.Accept(ctx) 912 Expect(err).To(MatchError("context canceled")) 913 close(done) 914 }() 915 916 Consistently(done).ShouldNot(BeClosed()) 917 cancel() 918 Eventually(done).Should(BeClosed()) 919 }) 920 921 It("accepts new sessions when the handshake completes", func() { 922 sess := NewMockQuicSession(mockCtrl) 923 924 done := make(chan struct{}) 925 go func() { 926 defer GinkgoRecover() 927 s, err := serv.Accept(context.Background()) 928 Expect(err).ToNot(HaveOccurred()) 929 Expect(s).To(Equal(sess)) 930 close(done) 931 }() 932 933 ctx, cancel := context.WithCancel(context.Background()) // handshake context 934 serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true } 935 serv.newSession = func( 936 _ sendConn, 937 runner sessionRunner, 938 _ protocol.ConnectionID, 939 _ *protocol.ConnectionID, 940 _ protocol.ConnectionID, 941 _ protocol.ConnectionID, 942 _ protocol.ConnectionID, 943 _ protocol.StatelessResetToken, 944 _ *Config, 945 _ *tls.Config, 946 _ *handshake.TokenGenerator, 947 _ bool, 948 _ logging.ConnectionTracer, 949 _ uint64, 950 _ utils.Logger, 951 _ protocol.VersionNumber, 952 ) quicSession { 953 sess.EXPECT().handlePacket(gomock.Any()) 954 sess.EXPECT().HandshakeComplete().Return(ctx) 955 sess.EXPECT().run().Do(func() {}) 956 sess.EXPECT().Context().Return(context.Background()) 957 return sess 958 } 959 phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool { 960 phm.EXPECT().GetStatelessResetToken(gomock.Any()) 961 fn() 962 return true 963 }) 964 tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, gomock.Any()) 965 serv.handleInitialImpl( 966 &receivedPacket{buffer: getPacketBuffer()}, 967 &wire.Header{DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}}, 968 ) 969 Consistently(done).ShouldNot(BeClosed()) 970 cancel() // complete the handshake 971 Eventually(done).Should(BeClosed()) 972 }) 973 }) 974 }) 975 976 Context("server accepting sessions that haven't completed the handshake", func() { 977 var ( 978 serv *earlyServer 979 phm *MockPacketHandlerManager 980 ) 981 982 BeforeEach(func() { 983 ln, err := ListenEarly(conn, tlsConf, nil) 984 Expect(err).ToNot(HaveOccurred()) 985 serv = ln.(*earlyServer) 986 phm = NewMockPacketHandlerManager(mockCtrl) 987 serv.sessionHandler = phm 988 }) 989 990 AfterEach(func() { 991 phm.EXPECT().CloseServer().MaxTimes(1) 992 serv.Close() 993 }) 994 995 It("accepts new sessions when they become ready", func() { 996 sess := NewMockQuicSession(mockCtrl) 997 998 done := make(chan struct{}) 999 go func() { 1000 defer GinkgoRecover() 1001 s, err := serv.Accept(context.Background()) 1002 Expect(err).ToNot(HaveOccurred()) 1003 Expect(s).To(Equal(sess)) 1004 close(done) 1005 }() 1006 1007 ready := make(chan struct{}) 1008 serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true } 1009 serv.newSession = func( 1010 _ sendConn, 1011 runner sessionRunner, 1012 _ protocol.ConnectionID, 1013 _ *protocol.ConnectionID, 1014 _ protocol.ConnectionID, 1015 _ protocol.ConnectionID, 1016 _ protocol.ConnectionID, 1017 _ protocol.StatelessResetToken, 1018 _ *Config, 1019 _ *tls.Config, 1020 _ *handshake.TokenGenerator, 1021 enable0RTT bool, 1022 _ logging.ConnectionTracer, 1023 _ uint64, 1024 _ utils.Logger, 1025 _ protocol.VersionNumber, 1026 ) quicSession { 1027 Expect(enable0RTT).To(BeTrue()) 1028 sess.EXPECT().handlePacket(gomock.Any()) 1029 sess.EXPECT().run().Do(func() {}) 1030 sess.EXPECT().earlySessionReady().Return(ready) 1031 sess.EXPECT().Context().Return(context.Background()) 1032 return sess 1033 } 1034 phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool { 1035 phm.EXPECT().GetStatelessResetToken(gomock.Any()) 1036 fn() 1037 return true 1038 }) 1039 serv.handleInitialImpl( 1040 &receivedPacket{buffer: getPacketBuffer()}, 1041 &wire.Header{DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}}, 1042 ) 1043 Consistently(done).ShouldNot(BeClosed()) 1044 close(ready) 1045 Eventually(done).Should(BeClosed()) 1046 }) 1047 1048 It("rejects new connection attempts if the accept queue is full", func() { 1049 serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true } 1050 senderAddr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 42} 1051 1052 serv.newSession = func( 1053 _ sendConn, 1054 runner sessionRunner, 1055 _ protocol.ConnectionID, 1056 _ *protocol.ConnectionID, 1057 _ protocol.ConnectionID, 1058 _ protocol.ConnectionID, 1059 _ protocol.ConnectionID, 1060 _ protocol.StatelessResetToken, 1061 _ *Config, 1062 _ *tls.Config, 1063 _ *handshake.TokenGenerator, 1064 _ bool, 1065 _ logging.ConnectionTracer, 1066 _ uint64, 1067 _ utils.Logger, 1068 _ protocol.VersionNumber, 1069 ) quicSession { 1070 ready := make(chan struct{}) 1071 close(ready) 1072 sess := NewMockQuicSession(mockCtrl) 1073 sess.EXPECT().handlePacket(gomock.Any()) 1074 sess.EXPECT().run() 1075 sess.EXPECT().earlySessionReady().Return(ready) 1076 sess.EXPECT().Context().Return(context.Background()) 1077 return sess 1078 } 1079 1080 phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool { 1081 phm.EXPECT().GetStatelessResetToken(gomock.Any()) 1082 fn() 1083 return true 1084 }).Times(protocol.MaxAcceptQueueSize) 1085 for i := 0; i < protocol.MaxAcceptQueueSize; i++ { 1086 serv.handlePacket(getInitialWithRandomDestConnID()) 1087 } 1088 1089 Eventually(func() int32 { return atomic.LoadInt32(&serv.sessionQueueLen) }).Should(BeEquivalentTo(protocol.MaxAcceptQueueSize)) 1090 // make sure there are no Write calls on the packet conn 1091 time.Sleep(50 * time.Millisecond) 1092 1093 p := getInitialWithRandomDestConnID() 1094 hdr := parseHeader(p.data) 1095 done := make(chan struct{}) 1096 conn.EXPECT().WriteTo(gomock.Any(), senderAddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) { 1097 defer close(done) 1098 rejectHdr := parseHeader(b) 1099 Expect(rejectHdr.Type).To(Equal(protocol.PacketTypeInitial)) 1100 Expect(rejectHdr.Version).To(Equal(hdr.Version)) 1101 Expect(rejectHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID)) 1102 Expect(rejectHdr.SrcConnectionID).To(Equal(hdr.DestConnectionID)) 1103 return len(b), nil 1104 }) 1105 serv.handlePacket(p) 1106 Eventually(done).Should(BeClosed()) 1107 }) 1108 1109 It("doesn't accept new sessions if they were closed in the mean time", func() { 1110 serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true } 1111 1112 p := getInitial(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) 1113 ctx, cancel := context.WithCancel(context.Background()) 1114 sessionCreated := make(chan struct{}) 1115 sess := NewMockQuicSession(mockCtrl) 1116 serv.newSession = func( 1117 _ sendConn, 1118 runner sessionRunner, 1119 _ protocol.ConnectionID, 1120 _ *protocol.ConnectionID, 1121 _ protocol.ConnectionID, 1122 _ protocol.ConnectionID, 1123 _ protocol.ConnectionID, 1124 _ protocol.StatelessResetToken, 1125 _ *Config, 1126 _ *tls.Config, 1127 _ *handshake.TokenGenerator, 1128 _ bool, 1129 _ logging.ConnectionTracer, 1130 _ uint64, 1131 _ utils.Logger, 1132 _ protocol.VersionNumber, 1133 ) quicSession { 1134 sess.EXPECT().handlePacket(p) 1135 sess.EXPECT().run() 1136 sess.EXPECT().earlySessionReady() 1137 sess.EXPECT().Context().Return(ctx) 1138 close(sessionCreated) 1139 return sess 1140 } 1141 1142 phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool { 1143 phm.EXPECT().GetStatelessResetToken(gomock.Any()) 1144 fn() 1145 return true 1146 }) 1147 serv.handlePacket(p) 1148 // make sure there are no Write calls on the packet conn 1149 time.Sleep(50 * time.Millisecond) 1150 Eventually(sessionCreated).Should(BeClosed()) 1151 cancel() 1152 time.Sleep(scaleDuration(200 * time.Millisecond)) 1153 1154 done := make(chan struct{}) 1155 go func() { 1156 defer GinkgoRecover() 1157 serv.Accept(context.Background()) 1158 close(done) 1159 }() 1160 Consistently(done).ShouldNot(BeClosed()) 1161 1162 // make the go routine return 1163 phm.EXPECT().CloseServer() 1164 sess.EXPECT().getPerspective().MaxTimes(2) // once for every conn ID 1165 Expect(serv.Close()).To(Succeed()) 1166 Eventually(done).Should(BeClosed()) 1167 }) 1168 }) 1169}) 1170 1171var _ = Describe("default source address verification", func() { 1172 It("accepts a token", func() { 1173 remoteAddr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1)} 1174 token := &Token{ 1175 IsRetryToken: true, 1176 RemoteAddr: "192.168.0.1", 1177 SentTime: time.Now().Add(-protocol.RetryTokenValidity).Add(time.Second), // will expire in 1 second 1178 } 1179 Expect(defaultAcceptToken(remoteAddr, token)).To(BeTrue()) 1180 }) 1181 1182 It("requests verification if no token is provided", func() { 1183 remoteAddr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1)} 1184 Expect(defaultAcceptToken(remoteAddr, nil)).To(BeFalse()) 1185 }) 1186 1187 It("rejects a token if the address doesn't match", func() { 1188 remoteAddr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1)} 1189 token := &Token{ 1190 IsRetryToken: true, 1191 RemoteAddr: "127.0.0.1", 1192 SentTime: time.Now(), 1193 } 1194 Expect(defaultAcceptToken(remoteAddr, token)).To(BeFalse()) 1195 }) 1196 1197 It("accepts a token for a remote address is not a UDP address", func() { 1198 remoteAddr := &net.TCPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337} 1199 token := &Token{ 1200 IsRetryToken: true, 1201 RemoteAddr: "192.168.0.1:1337", 1202 SentTime: time.Now(), 1203 } 1204 Expect(defaultAcceptToken(remoteAddr, token)).To(BeTrue()) 1205 }) 1206 1207 It("rejects an invalid token for a remote address is not a UDP address", func() { 1208 remoteAddr := &net.TCPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337} 1209 token := &Token{ 1210 IsRetryToken: true, 1211 RemoteAddr: "192.168.0.1:7331", // mismatching port 1212 SentTime: time.Now(), 1213 } 1214 Expect(defaultAcceptToken(remoteAddr, token)).To(BeFalse()) 1215 }) 1216 1217 It("rejects an expired token", func() { 1218 remoteAddr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1)} 1219 token := &Token{ 1220 IsRetryToken: true, 1221 RemoteAddr: "192.168.0.1", 1222 SentTime: time.Now().Add(-protocol.RetryTokenValidity).Add(-time.Second), // expired 1 second ago 1223 } 1224 Expect(defaultAcceptToken(remoteAddr, token)).To(BeFalse()) 1225 }) 1226 1227 It("accepts a non-retry token", func() { 1228 Expect(protocol.RetryTokenValidity).To(BeNumerically("<", protocol.TokenValidity)) 1229 remoteAddr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1)} 1230 token := &Token{ 1231 IsRetryToken: false, 1232 RemoteAddr: "192.168.0.1", 1233 // if this was a retry token, it would have expired one second ago 1234 SentTime: time.Now().Add(-protocol.RetryTokenValidity).Add(-time.Second), 1235 } 1236 Expect(defaultAcceptToken(remoteAddr, token)).To(BeTrue()) 1237 }) 1238}) 1239