1package quic
2
3import (
4	"bytes"
5	"context"
6	"crypto/rand"
7	"crypto/tls"
8	"errors"
9	"net"
10	"reflect"
11	"runtime/pprof"
12	"strings"
13	"sync"
14	"sync/atomic"
15	"time"
16
17	"github.com/lucas-clemente/quic-go/internal/handshake"
18	mocklogging "github.com/lucas-clemente/quic-go/internal/mocks/logging"
19	"github.com/lucas-clemente/quic-go/internal/protocol"
20	"github.com/lucas-clemente/quic-go/internal/qerr"
21	"github.com/lucas-clemente/quic-go/internal/testdata"
22	"github.com/lucas-clemente/quic-go/internal/utils"
23	"github.com/lucas-clemente/quic-go/internal/wire"
24	"github.com/lucas-clemente/quic-go/logging"
25
26	"github.com/golang/mock/gomock"
27
28	. "github.com/onsi/ginkgo"
29	. "github.com/onsi/gomega"
30)
31
// areServersRunning reports whether the current goroutine dump contains a
// frame for (*baseServer).run, i.e. whether a server run loop is still alive.
func areServersRunning() bool {
	profile := pprof.Lookup("goroutine")
	dump := new(bytes.Buffer)
	// debug=1 requests the human-readable text form of the profile, which is
	// what makes the substring search below possible.
	profile.WriteTo(dump, 1)
	return strings.Contains(dump.String(), "quic-go.(*baseServer).run")
}
37
38var _ = Describe("Server", func() {
39	var (
40		conn    *MockPacketConn
41		tlsConf *tls.Config
42	)
43
44	getPacket := func(hdr *wire.Header, p []byte) *receivedPacket {
45		buffer := getPacketBuffer()
46		buf := bytes.NewBuffer(buffer.Data)
47		if hdr.IsLongHeader {
48			hdr.Length = 4 + protocol.ByteCount(len(p)) + 16
49		}
50		Expect((&wire.ExtendedHeader{
51			Header:          *hdr,
52			PacketNumber:    0x42,
53			PacketNumberLen: protocol.PacketNumberLen4,
54		}).Write(buf, protocol.VersionTLS)).To(Succeed())
55		n := buf.Len()
56		buf.Write(p)
57		data := buffer.Data[:buf.Len()]
58		sealer, _ := handshake.NewInitialAEAD(hdr.DestConnectionID, protocol.PerspectiveClient, hdr.Version)
59		_ = sealer.Seal(data[n:n], data[n:], 0x42, data[:n])
60		data = data[:len(data)+16]
61		sealer.EncryptHeader(data[n:n+16], &data[0], data[n-4:n])
62		return &receivedPacket{
63			remoteAddr: &net.UDPAddr{IP: net.IPv4(4, 5, 6, 7), Port: 456},
64			data:       data,
65			buffer:     buffer,
66		}
67	}
68
69	getInitial := func(destConnID protocol.ConnectionID) *receivedPacket {
70		senderAddr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 42}
71		hdr := &wire.Header{
72			IsLongHeader:     true,
73			Type:             protocol.PacketTypeInitial,
74			SrcConnectionID:  protocol.ConnectionID{5, 4, 3, 2, 1},
75			DestConnectionID: destConnID,
76			Version:          protocol.VersionTLS,
77		}
78		p := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
79		p.buffer = getPacketBuffer()
80		p.remoteAddr = senderAddr
81		return p
82	}
83
84	getInitialWithRandomDestConnID := func() *receivedPacket {
85		destConnID := make([]byte, 10)
86		_, err := rand.Read(destConnID)
87		Expect(err).ToNot(HaveOccurred())
88
89		return getInitial(destConnID)
90	}
91
92	parseHeader := func(data []byte) *wire.Header {
93		hdr, _, _, err := wire.ParsePacket(data, 0)
94		Expect(err).ToNot(HaveOccurred())
95		return hdr
96	}
97
98	BeforeEach(func() {
99		conn = NewMockPacketConn(mockCtrl)
100		conn.EXPECT().LocalAddr().Return(&net.UDPAddr{}).AnyTimes()
101		conn.EXPECT().ReadFrom(gomock.Any()).Do(func(_ []byte) { <-(make(chan struct{})) }).MaxTimes(1)
102		tlsConf = testdata.GetTLSConfig()
103		tlsConf.NextProtos = []string{"proto1"}
104	})
105
106	AfterEach(func() {
107		Eventually(areServersRunning).Should(BeFalse())
108	})
109
110	It("errors when no tls.Config is given", func() {
111		_, err := ListenAddr("localhost:0", nil, nil)
112		Expect(err).To(HaveOccurred())
113		Expect(err.Error()).To(ContainSubstring("quic: tls.Config not set"))
114	})
115
116	It("errors when the Config contains an invalid version", func() {
117		version := protocol.VersionNumber(0x1234)
118		_, err := Listen(nil, tlsConf, &Config{Versions: []protocol.VersionNumber{version}})
119		Expect(err).To(MatchError("0x1234 is not a valid QUIC version"))
120	})
121
122	It("fills in default values if options are not set in the Config", func() {
123		ln, err := Listen(conn, tlsConf, &Config{})
124		Expect(err).ToNot(HaveOccurred())
125		server := ln.(*baseServer)
126		Expect(server.config.Versions).To(Equal(protocol.SupportedVersions))
127		Expect(server.config.HandshakeIdleTimeout).To(Equal(protocol.DefaultHandshakeIdleTimeout))
128		Expect(server.config.MaxIdleTimeout).To(Equal(protocol.DefaultIdleTimeout))
129		Expect(reflect.ValueOf(server.config.AcceptToken)).To(Equal(reflect.ValueOf(defaultAcceptToken)))
130		Expect(server.config.KeepAlive).To(BeFalse())
131		// stop the listener
132		Expect(ln.Close()).To(Succeed())
133	})
134
135	It("setups with the right values", func() {
136		supportedVersions := []protocol.VersionNumber{protocol.VersionTLS}
137		acceptToken := func(_ net.Addr, _ *Token) bool { return true }
138		config := Config{
139			Versions:             supportedVersions,
140			AcceptToken:          acceptToken,
141			HandshakeIdleTimeout: 1337 * time.Hour,
142			MaxIdleTimeout:       42 * time.Minute,
143			KeepAlive:            true,
144			StatelessResetKey:    []byte("foobar"),
145		}
146		ln, err := Listen(conn, tlsConf, &config)
147		Expect(err).ToNot(HaveOccurred())
148		server := ln.(*baseServer)
149		Expect(server.sessionHandler).ToNot(BeNil())
150		Expect(server.config.Versions).To(Equal(supportedVersions))
151		Expect(server.config.HandshakeIdleTimeout).To(Equal(1337 * time.Hour))
152		Expect(server.config.MaxIdleTimeout).To(Equal(42 * time.Minute))
153		Expect(reflect.ValueOf(server.config.AcceptToken)).To(Equal(reflect.ValueOf(acceptToken)))
154		Expect(server.config.KeepAlive).To(BeTrue())
155		Expect(server.config.StatelessResetKey).To(Equal([]byte("foobar")))
156		// stop the listener
157		Expect(ln.Close()).To(Succeed())
158	})
159
160	It("listens on a given address", func() {
161		addr := "127.0.0.1:13579"
162		ln, err := ListenAddr(addr, tlsConf, &Config{})
163		Expect(err).ToNot(HaveOccurred())
164		Expect(ln.Addr().String()).To(Equal(addr))
165		// stop the listener
166		Expect(ln.Close()).To(Succeed())
167	})
168
169	It("errors if given an invalid address", func() {
170		addr := "127.0.0.1"
171		_, err := ListenAddr(addr, tlsConf, &Config{})
172		Expect(err).To(BeAssignableToTypeOf(&net.AddrError{}))
173	})
174
175	It("errors if given an invalid address", func() {
176		addr := "1.1.1.1:1111"
177		_, err := ListenAddr(addr, tlsConf, &Config{})
178		Expect(err).To(BeAssignableToTypeOf(&net.OpError{}))
179	})
180
181	Context("server accepting sessions that completed the handshake", func() {
182		var (
183			serv   *baseServer
184			phm    *MockPacketHandlerManager
185			tracer *mocklogging.MockTracer
186		)
187
188		BeforeEach(func() {
189			tracer = mocklogging.NewMockTracer(mockCtrl)
190			ln, err := Listen(conn, tlsConf, &Config{Tracer: tracer})
191			Expect(err).ToNot(HaveOccurred())
192			serv = ln.(*baseServer)
193			phm = NewMockPacketHandlerManager(mockCtrl)
194			serv.sessionHandler = phm
195		})
196
197		AfterEach(func() {
198			phm.EXPECT().CloseServer().MaxTimes(1)
199			serv.Close()
200		})
201
202		Context("handling packets", func() {
203			It("drops Initial packets with a too short connection ID", func() {
204				p := getPacket(&wire.Header{
205					IsLongHeader:     true,
206					Type:             protocol.PacketTypeInitial,
207					DestConnectionID: protocol.ConnectionID{1, 2, 3, 4},
208					Version:          serv.config.Versions[0],
209				}, nil)
210				tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropUnexpectedPacket)
211				serv.handlePacket(p)
212				// make sure there are no Write calls on the packet conn
213				time.Sleep(50 * time.Millisecond)
214			})
215
216			It("drops too small Initial", func() {
217				p := getPacket(&wire.Header{
218					IsLongHeader:     true,
219					Type:             protocol.PacketTypeInitial,
220					DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8},
221					Version:          serv.config.Versions[0],
222				}, make([]byte, protocol.MinInitialPacketSize-100),
223				)
224				tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropUnexpectedPacket)
225				serv.handlePacket(p)
226				// make sure there are no Write calls on the packet conn
227				time.Sleep(50 * time.Millisecond)
228			})
229
230			It("drops non-Initial packets", func() {
231				p := getPacket(&wire.Header{
232					IsLongHeader: true,
233					Type:         protocol.PacketTypeHandshake,
234					Version:      serv.config.Versions[0],
235				}, []byte("invalid"))
236				tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketTypeHandshake, p.Size(), logging.PacketDropUnexpectedPacket)
237				serv.handlePacket(p)
238				// make sure there are no Write calls on the packet conn
239				time.Sleep(50 * time.Millisecond)
240			})
241
242			It("decodes the token from the Token field", func() {
243				raddr := &net.UDPAddr{
244					IP:   net.IPv4(192, 168, 13, 37),
245					Port: 1337,
246				}
247				done := make(chan struct{})
248				serv.config.AcceptToken = func(addr net.Addr, token *Token) bool {
249					Expect(addr).To(Equal(raddr))
250					Expect(token).ToNot(BeNil())
251					close(done)
252					return false
253				}
254				token, err := serv.tokenGenerator.NewRetryToken(raddr, nil, nil)
255				Expect(err).ToNot(HaveOccurred())
256				packet := getPacket(&wire.Header{
257					IsLongHeader: true,
258					Type:         protocol.PacketTypeInitial,
259					Token:        token,
260					Version:      serv.config.Versions[0],
261				}, make([]byte, protocol.MinInitialPacketSize))
262				packet.remoteAddr = raddr
263				conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).MaxTimes(1)
264				tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).MaxTimes(1)
265				serv.handlePacket(packet)
266				Eventually(done).Should(BeClosed())
267			})
268
269			It("passes an empty token to the callback, if decoding fails", func() {
270				raddr := &net.UDPAddr{
271					IP:   net.IPv4(192, 168, 13, 37),
272					Port: 1337,
273				}
274				done := make(chan struct{})
275				serv.config.AcceptToken = func(addr net.Addr, token *Token) bool {
276					Expect(addr).To(Equal(raddr))
277					Expect(token).To(BeNil())
278					close(done)
279					return false
280				}
281				packet := getPacket(&wire.Header{
282					IsLongHeader: true,
283					Type:         protocol.PacketTypeInitial,
284					Token:        []byte("foobar"),
285					Version:      serv.config.Versions[0],
286				}, make([]byte, protocol.MinInitialPacketSize))
287				packet.remoteAddr = raddr
288				conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).MaxTimes(1)
289				tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).MaxTimes(1)
290				serv.handlePacket(packet)
291				Eventually(done).Should(BeClosed())
292			})
293
294			It("creates a session when the token is accepted", func() {
295				serv.config.AcceptToken = func(_ net.Addr, token *Token) bool { return true }
296				retryToken, err := serv.tokenGenerator.NewRetryToken(
297					&net.UDPAddr{},
298					protocol.ConnectionID{0xde, 0xad, 0xc0, 0xde},
299					protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad},
300				)
301				Expect(err).ToNot(HaveOccurred())
302				hdr := &wire.Header{
303					IsLongHeader:     true,
304					Type:             protocol.PacketTypeInitial,
305					SrcConnectionID:  protocol.ConnectionID{5, 4, 3, 2, 1},
306					DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
307					Version:          protocol.VersionTLS,
308					Token:            retryToken,
309				}
310				p := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
311				run := make(chan struct{})
312				var token protocol.StatelessResetToken
313				rand.Read(token[:])
314
315				var newConnID protocol.ConnectionID
316				phm.EXPECT().AddWithConnID(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, gomock.Any(), gomock.Any()).DoAndReturn(func(_, c protocol.ConnectionID, fn func() packetHandler) bool {
317					newConnID = c
318					phm.EXPECT().GetStatelessResetToken(gomock.Any()).DoAndReturn(func(c protocol.ConnectionID) protocol.StatelessResetToken {
319						newConnID = c
320						return token
321					})
322					fn()
323					return true
324				})
325				tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, protocol.ConnectionID{0xde, 0xad, 0xc0, 0xde})
326				sess := NewMockQuicSession(mockCtrl)
327				serv.newSession = func(
328					_ sendConn,
329					_ sessionRunner,
330					origDestConnID protocol.ConnectionID,
331					retrySrcConnID *protocol.ConnectionID,
332					clientDestConnID protocol.ConnectionID,
333					destConnID protocol.ConnectionID,
334					srcConnID protocol.ConnectionID,
335					tokenP protocol.StatelessResetToken,
336					_ *Config,
337					_ *tls.Config,
338					_ *handshake.TokenGenerator,
339					enable0RTT bool,
340					_ logging.ConnectionTracer,
341					_ uint64,
342					_ utils.Logger,
343					_ protocol.VersionNumber,
344				) quicSession {
345					Expect(enable0RTT).To(BeFalse())
346					Expect(origDestConnID).To(Equal(protocol.ConnectionID{0xde, 0xad, 0xc0, 0xde}))
347					Expect(retrySrcConnID).To(Equal(&protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad}))
348					Expect(clientDestConnID).To(Equal(hdr.DestConnectionID))
349					Expect(destConnID).To(Equal(hdr.SrcConnectionID))
350					// make sure we're using a server-generated connection ID
351					Expect(srcConnID).ToNot(Equal(hdr.DestConnectionID))
352					Expect(srcConnID).ToNot(Equal(hdr.SrcConnectionID))
353					Expect(srcConnID).To(Equal(newConnID))
354					Expect(tokenP).To(Equal(token))
355					sess.EXPECT().handlePacket(p)
356					sess.EXPECT().run().Do(func() { close(run) })
357					sess.EXPECT().Context().Return(context.Background())
358					sess.EXPECT().HandshakeComplete().Return(context.Background())
359					return sess
360				}
361
362				done := make(chan struct{})
363				go func() {
364					defer GinkgoRecover()
365					serv.handlePacket(p)
366					// the Handshake packet is written by the session.
367					// Make sure there are no Write calls on the packet conn.
368					time.Sleep(50 * time.Millisecond)
369					close(done)
370				}()
371				// make sure we're using a server-generated connection ID
372				Eventually(run).Should(BeClosed())
373				Eventually(done).Should(BeClosed())
374			})
375
376			It("sends a Version Negotiation Packet for unsupported versions", func() {
377				srcConnID := protocol.ConnectionID{1, 2, 3, 4, 5}
378				destConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6}
379				packet := getPacket(&wire.Header{
380					IsLongHeader:     true,
381					Type:             protocol.PacketTypeHandshake,
382					SrcConnectionID:  srcConnID,
383					DestConnectionID: destConnID,
384					Version:          0x42,
385				}, make([]byte, protocol.MinUnknownVersionPacketSize))
386				raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
387				packet.remoteAddr = raddr
388				tracer.EXPECT().SentPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), nil).Do(func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, _ []logging.Frame) {
389					Expect(replyHdr.IsLongHeader).To(BeTrue())
390					Expect(replyHdr.Version).To(BeZero())
391					Expect(replyHdr.SrcConnectionID).To(Equal(destConnID))
392					Expect(replyHdr.DestConnectionID).To(Equal(srcConnID))
393				})
394				done := make(chan struct{})
395				conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) {
396					defer close(done)
397					Expect(wire.IsVersionNegotiationPacket(b)).To(BeTrue())
398					hdr, versions, err := wire.ParseVersionNegotiationPacket(bytes.NewReader(b))
399					Expect(err).ToNot(HaveOccurred())
400					Expect(hdr.DestConnectionID).To(Equal(srcConnID))
401					Expect(hdr.SrcConnectionID).To(Equal(destConnID))
402					Expect(versions).ToNot(ContainElement(protocol.VersionNumber(0x42)))
403					return len(b), nil
404				})
405				serv.handlePacket(packet)
406				Eventually(done).Should(BeClosed())
407			})
408
409			It("doesn't send a Version Negotiation packets if sending them is disabled", func() {
410				serv.config.DisableVersionNegotiationPackets = true
411				srcConnID := protocol.ConnectionID{1, 2, 3, 4, 5}
412				destConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6}
413				packet := getPacket(&wire.Header{
414					IsLongHeader:     true,
415					Type:             protocol.PacketTypeHandshake,
416					SrcConnectionID:  srcConnID,
417					DestConnectionID: destConnID,
418					Version:          0x42,
419				}, make([]byte, protocol.MinUnknownVersionPacketSize))
420				raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
421				packet.remoteAddr = raddr
422				done := make(chan struct{})
423				conn.EXPECT().WriteTo(gomock.Any(), raddr).Do(func() { close(done) }).Times(0)
424				serv.handlePacket(packet)
425				Consistently(done, 50*time.Millisecond).ShouldNot(BeClosed())
426			})
427
428			It("ignores Version Negotiation packets", func() {
429				data, err := wire.ComposeVersionNegotiation(
430					protocol.ConnectionID{1, 2, 3, 4},
431					protocol.ConnectionID{4, 3, 2, 1},
432					[]protocol.VersionNumber{1, 2, 3},
433				)
434				Expect(err).ToNot(HaveOccurred())
435				raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
436				done := make(chan struct{})
437				tracer.EXPECT().DroppedPacket(raddr, logging.PacketTypeVersionNegotiation, protocol.ByteCount(len(data)), logging.PacketDropUnexpectedPacket).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) {
438					close(done)
439				})
440				serv.handlePacket(&receivedPacket{
441					remoteAddr: raddr,
442					data:       data,
443					buffer:     getPacketBuffer(),
444				})
445				Eventually(done).Should(BeClosed())
446				// make sure no other packet is sent
447				time.Sleep(scaleDuration(20 * time.Millisecond))
448			})
449
450			It("doesn't send a Version Negotiation Packet for unsupported versions, if the packet is too small", func() {
451				srcConnID := protocol.ConnectionID{1, 2, 3, 4, 5}
452				destConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6}
453				p := getPacket(&wire.Header{
454					IsLongHeader:     true,
455					Type:             protocol.PacketTypeHandshake,
456					SrcConnectionID:  srcConnID,
457					DestConnectionID: destConnID,
458					Version:          0x42,
459				}, make([]byte, protocol.MinUnknownVersionPacketSize-50))
460				Expect(p.Size()).To(BeNumerically("<", protocol.MinUnknownVersionPacketSize))
461				raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
462				p.remoteAddr = raddr
463				done := make(chan struct{})
464				tracer.EXPECT().DroppedPacket(raddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropUnexpectedPacket).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) {
465					close(done)
466				})
467				serv.handlePacket(p)
468				Eventually(done).Should(BeClosed())
469				// make sure no other packet is sent
470				time.Sleep(scaleDuration(20 * time.Millisecond))
471			})
472
473			It("replies with a Retry packet, if a Token is required", func() {
474				serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return false }
475				hdr := &wire.Header{
476					IsLongHeader:     true,
477					Type:             protocol.PacketTypeInitial,
478					SrcConnectionID:  protocol.ConnectionID{5, 4, 3, 2, 1},
479					DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
480					Version:          protocol.VersionTLS,
481				}
482				packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
483				raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
484				packet.remoteAddr = raddr
485				tracer.EXPECT().SentPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), nil).Do(func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, _ []logging.Frame) {
486					Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry))
487					Expect(replyHdr.SrcConnectionID).ToNot(Equal(hdr.DestConnectionID))
488					Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID))
489					Expect(replyHdr.Token).ToNot(BeEmpty())
490				})
491				done := make(chan struct{})
492				conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) {
493					defer close(done)
494					replyHdr := parseHeader(b)
495					Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry))
496					Expect(replyHdr.SrcConnectionID).ToNot(Equal(hdr.DestConnectionID))
497					Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID))
498					Expect(replyHdr.Token).ToNot(BeEmpty())
499					Expect(b[len(b)-16:]).To(Equal(handshake.GetRetryIntegrityTag(b[:len(b)-16], hdr.DestConnectionID, hdr.Version)[:]))
500					return len(b), nil
501				})
502				serv.handlePacket(packet)
503				Eventually(done).Should(BeClosed())
504			})
505
506			It("sends an INVALID_TOKEN error, if an invalid retry token is received", func() {
507				serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return false }
508				token, err := serv.tokenGenerator.NewRetryToken(&net.UDPAddr{}, nil, nil)
509				Expect(err).ToNot(HaveOccurred())
510				hdr := &wire.Header{
511					IsLongHeader:     true,
512					Type:             protocol.PacketTypeInitial,
513					SrcConnectionID:  protocol.ConnectionID{5, 4, 3, 2, 1},
514					DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
515					Token:            token,
516					Version:          protocol.VersionTLS,
517				}
518				packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
519				packet.data = append(packet.data, []byte("coalesced packet")...) // add some garbage to simulate a coalesced packet
520				raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
521				packet.remoteAddr = raddr
522				tracer.EXPECT().SentPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, frames []logging.Frame) {
523					Expect(replyHdr.Type).To(Equal(protocol.PacketTypeInitial))
524					Expect(replyHdr.SrcConnectionID).To(Equal(hdr.DestConnectionID))
525					Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID))
526					Expect(frames).To(HaveLen(1))
527					Expect(frames[0]).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{}))
528					ccf := frames[0].(*logging.ConnectionCloseFrame)
529					Expect(ccf.IsApplicationError).To(BeFalse())
530					Expect(ccf.ErrorCode).To(BeEquivalentTo(qerr.InvalidToken))
531				})
532				done := make(chan struct{})
533				conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) {
534					defer close(done)
535					replyHdr := parseHeader(b)
536					Expect(replyHdr.Type).To(Equal(protocol.PacketTypeInitial))
537					Expect(replyHdr.SrcConnectionID).To(Equal(hdr.DestConnectionID))
538					Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID))
539					_, opener := handshake.NewInitialAEAD(hdr.DestConnectionID, protocol.PerspectiveClient, replyHdr.Version)
540					extHdr, err := unpackHeader(opener, replyHdr, b, hdr.Version)
541					Expect(err).ToNot(HaveOccurred())
542					data, err := opener.Open(nil, b[extHdr.ParsedLen():], extHdr.PacketNumber, b[:extHdr.ParsedLen()])
543					Expect(err).ToNot(HaveOccurred())
544					f, err := wire.NewFrameParser(false, hdr.Version).ParseNext(bytes.NewReader(data), protocol.EncryptionInitial)
545					Expect(err).ToNot(HaveOccurred())
546					Expect(f).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{}))
547					ccf := f.(*wire.ConnectionCloseFrame)
548					Expect(ccf.ErrorCode).To(BeEquivalentTo(qerr.InvalidToken))
549					Expect(ccf.ReasonPhrase).To(BeEmpty())
550					return len(b), nil
551				})
552				serv.handlePacket(packet)
553				Eventually(done).Should(BeClosed())
554			})
555
556			It("doesn't send an INVALID_TOKEN error, if the packet is corrupted", func() {
557				serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return false }
558				token, err := serv.tokenGenerator.NewRetryToken(&net.UDPAddr{}, nil, nil)
559				Expect(err).ToNot(HaveOccurred())
560				hdr := &wire.Header{
561					IsLongHeader:     true,
562					Type:             protocol.PacketTypeInitial,
563					SrcConnectionID:  protocol.ConnectionID{5, 4, 3, 2, 1},
564					DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
565					Token:            token,
566					Version:          protocol.VersionTLS,
567				}
568				packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
569				packet.data[len(packet.data)-10] ^= 0xff // corrupt the packet
570				packet.remoteAddr = &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
571				done := make(chan struct{})
572				tracer.EXPECT().DroppedPacket(packet.remoteAddr, logging.PacketTypeInitial, packet.Size(), logging.PacketDropPayloadDecryptError).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) { close(done) })
573				serv.handlePacket(packet)
574				// make sure there are no Write calls on the packet conn
575				time.Sleep(50 * time.Millisecond)
576				Eventually(done).Should(BeClosed())
577			})
578
579			It("creates a session, if no Token is required", func() {
580				serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true }
581				hdr := &wire.Header{
582					IsLongHeader:     true,
583					Type:             protocol.PacketTypeInitial,
584					SrcConnectionID:  protocol.ConnectionID{5, 4, 3, 2, 1},
585					DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
586					Version:          protocol.VersionTLS,
587				}
588				p := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
589				run := make(chan struct{})
590				var token protocol.StatelessResetToken
591				rand.Read(token[:])
592
593				var newConnID protocol.ConnectionID
594				phm.EXPECT().AddWithConnID(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, gomock.Any(), gomock.Any()).DoAndReturn(func(_, c protocol.ConnectionID, fn func() packetHandler) bool {
595					newConnID = c
596					phm.EXPECT().GetStatelessResetToken(gomock.Any()).DoAndReturn(func(c protocol.ConnectionID) protocol.StatelessResetToken {
597						newConnID = c
598						return token
599					})
600					fn()
601					return true
602				})
603				tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})
604
605				sess := NewMockQuicSession(mockCtrl)
606				serv.newSession = func(
607					_ sendConn,
608					_ sessionRunner,
609					origDestConnID protocol.ConnectionID,
610					retrySrcConnID *protocol.ConnectionID,
611					clientDestConnID protocol.ConnectionID,
612					destConnID protocol.ConnectionID,
613					srcConnID protocol.ConnectionID,
614					tokenP protocol.StatelessResetToken,
615					_ *Config,
616					_ *tls.Config,
617					_ *handshake.TokenGenerator,
618					enable0RTT bool,
619					_ logging.ConnectionTracer,
620					_ uint64,
621					_ utils.Logger,
622					_ protocol.VersionNumber,
623				) quicSession {
624					Expect(enable0RTT).To(BeFalse())
625					Expect(origDestConnID).To(Equal(hdr.DestConnectionID))
626					Expect(retrySrcConnID).To(BeNil())
627					Expect(clientDestConnID).To(Equal(hdr.DestConnectionID))
628					Expect(destConnID).To(Equal(hdr.SrcConnectionID))
629					// make sure we're using a server-generated connection ID
630					Expect(srcConnID).ToNot(Equal(hdr.DestConnectionID))
631					Expect(srcConnID).ToNot(Equal(hdr.SrcConnectionID))
632					Expect(srcConnID).To(Equal(newConnID))
633					Expect(tokenP).To(Equal(token))
634					sess.EXPECT().handlePacket(p)
635					sess.EXPECT().run().Do(func() { close(run) })
636					sess.EXPECT().Context().Return(context.Background())
637					sess.EXPECT().HandshakeComplete().Return(context.Background())
638					return sess
639				}
640
641				done := make(chan struct{})
642				go func() {
643					defer GinkgoRecover()
644					serv.handlePacket(p)
645					// the Handshake packet is written by the session
646					// make sure there are no Write calls on the packet conn
647					time.Sleep(50 * time.Millisecond)
648					close(done)
649				}()
650				// make sure we're using a server-generated connection ID
651				Eventually(run).Should(BeClosed())
652				Eventually(done).Should(BeClosed())
653			})
654
655			It("drops packets if the receive queue is full", func() {
656				phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool {
657					phm.EXPECT().GetStatelessResetToken(gomock.Any())
658					fn()
659					return true
660				}).AnyTimes()
661				tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, gomock.Any()).AnyTimes()
662
663				serv.config.AcceptToken = func(net.Addr, *Token) bool { return true }
664				acceptSession := make(chan struct{})
665				var counter uint32 // to be used as an atomic, so we query it in Eventually
666				serv.newSession = func(
667					_ sendConn,
668					runner sessionRunner,
669					_ protocol.ConnectionID,
670					_ *protocol.ConnectionID,
671					_ protocol.ConnectionID,
672					_ protocol.ConnectionID,
673					_ protocol.ConnectionID,
674					_ protocol.StatelessResetToken,
675					_ *Config,
676					_ *tls.Config,
677					_ *handshake.TokenGenerator,
678					_ bool,
679					_ logging.ConnectionTracer,
680					_ uint64,
681					_ utils.Logger,
682					_ protocol.VersionNumber,
683				) quicSession {
684					<-acceptSession
685					atomic.AddUint32(&counter, 1)
686					sess := NewMockQuicSession(mockCtrl)
687					sess.EXPECT().handlePacket(gomock.Any()).MaxTimes(1)
688					sess.EXPECT().run().MaxTimes(1)
689					sess.EXPECT().Context().Return(context.Background()).MaxTimes(1)
690					sess.EXPECT().HandshakeComplete().Return(context.Background()).MaxTimes(1)
691					return sess
692				}
693
694				p := getInitial(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8})
695				serv.handlePacket(p)
696				tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropDOSPrevention).MinTimes(1)
697				var wg sync.WaitGroup
698				for i := 0; i < 3*protocol.MaxServerUnprocessedPackets; i++ {
699					wg.Add(1)
700					go func() {
701						defer GinkgoRecover()
702						defer wg.Done()
703						serv.handlePacket(getInitial(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}))
704					}()
705				}
706				wg.Wait()
707
708				close(acceptSession)
709				Eventually(
710					func() uint32 { return atomic.LoadUint32(&counter) },
711					scaleDuration(100*time.Millisecond),
712				).Should(BeEquivalentTo(protocol.MaxServerUnprocessedPackets + 1))
713				Consistently(func() uint32 { return atomic.LoadUint32(&counter) }).Should(BeEquivalentTo(protocol.MaxServerUnprocessedPackets + 1))
714			})
715
716			It("only creates a single session for a duplicate Initial", func() {
717				serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true }
718				var createdSession bool
719				sess := NewMockQuicSession(mockCtrl)
720				serv.newSession = func(
721					_ sendConn,
722					runner sessionRunner,
723					_ protocol.ConnectionID,
724					_ *protocol.ConnectionID,
725					_ protocol.ConnectionID,
726					_ protocol.ConnectionID,
727					_ protocol.ConnectionID,
728					_ protocol.StatelessResetToken,
729					_ *Config,
730					_ *tls.Config,
731					_ *handshake.TokenGenerator,
732					_ bool,
733					_ logging.ConnectionTracer,
734					_ uint64,
735					_ utils.Logger,
736					_ protocol.VersionNumber,
737				) quicSession {
738					createdSession = true
739					return sess
740				}
741
742				p := getInitial(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9})
743				phm.EXPECT().AddWithConnID(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9}, gomock.Any(), gomock.Any()).Return(false)
744				Expect(serv.handlePacketImpl(p)).To(BeTrue())
745				Expect(createdSession).To(BeFalse())
746			})
747
748			It("rejects new connection attempts if the accept queue is full", func() {
749				serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true }
750
751				serv.newSession = func(
752					_ sendConn,
753					runner sessionRunner,
754					_ protocol.ConnectionID,
755					_ *protocol.ConnectionID,
756					_ protocol.ConnectionID,
757					_ protocol.ConnectionID,
758					_ protocol.ConnectionID,
759					_ protocol.StatelessResetToken,
760					_ *Config,
761					_ *tls.Config,
762					_ *handshake.TokenGenerator,
763					_ bool,
764					_ logging.ConnectionTracer,
765					_ uint64,
766					_ utils.Logger,
767					_ protocol.VersionNumber,
768				) quicSession {
769					sess := NewMockQuicSession(mockCtrl)
770					sess.EXPECT().handlePacket(gomock.Any())
771					sess.EXPECT().run()
772					sess.EXPECT().Context().Return(context.Background())
773					ctx, cancel := context.WithCancel(context.Background())
774					cancel()
775					sess.EXPECT().HandshakeComplete().Return(ctx)
776					return sess
777				}
778
779				phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool {
780					phm.EXPECT().GetStatelessResetToken(gomock.Any())
781					fn()
782					return true
783				}).Times(protocol.MaxAcceptQueueSize)
784				tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, gomock.Any()).Times(protocol.MaxAcceptQueueSize)
785
786				var wg sync.WaitGroup
787				wg.Add(protocol.MaxAcceptQueueSize)
788				for i := 0; i < protocol.MaxAcceptQueueSize; i++ {
789					go func() {
790						defer GinkgoRecover()
791						defer wg.Done()
792						serv.handlePacket(getInitialWithRandomDestConnID())
793						// make sure there are no Write calls on the packet conn
794						time.Sleep(50 * time.Millisecond)
795					}()
796				}
797				wg.Wait()
798				p := getInitialWithRandomDestConnID()
799				hdr, _, _, err := wire.ParsePacket(p.data, 0)
800				Expect(err).ToNot(HaveOccurred())
801				tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
802				done := make(chan struct{})
803				conn.EXPECT().WriteTo(gomock.Any(), p.remoteAddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) {
804					defer close(done)
805					rejectHdr := parseHeader(b)
806					Expect(rejectHdr.Type).To(Equal(protocol.PacketTypeInitial))
807					Expect(rejectHdr.Version).To(Equal(hdr.Version))
808					Expect(rejectHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID))
809					Expect(rejectHdr.SrcConnectionID).To(Equal(hdr.DestConnectionID))
810					return len(b), nil
811				})
812				serv.handlePacket(p)
813				Eventually(done).Should(BeClosed())
814			})
815
816			It("doesn't accept new sessions if they were closed in the mean time", func() {
817				serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true }
818
819				p := getInitial(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})
820				ctx, cancel := context.WithCancel(context.Background())
821				sessionCreated := make(chan struct{})
822				sess := NewMockQuicSession(mockCtrl)
823				serv.newSession = func(
824					_ sendConn,
825					runner sessionRunner,
826					_ protocol.ConnectionID,
827					_ *protocol.ConnectionID,
828					_ protocol.ConnectionID,
829					_ protocol.ConnectionID,
830					_ protocol.ConnectionID,
831					_ protocol.StatelessResetToken,
832					_ *Config,
833					_ *tls.Config,
834					_ *handshake.TokenGenerator,
835					_ bool,
836					_ logging.ConnectionTracer,
837					_ uint64,
838					_ utils.Logger,
839					_ protocol.VersionNumber,
840				) quicSession {
841					sess.EXPECT().handlePacket(p)
842					sess.EXPECT().run()
843					sess.EXPECT().Context().Return(ctx)
844					ctx, cancel := context.WithCancel(context.Background())
845					cancel()
846					sess.EXPECT().HandshakeComplete().Return(ctx)
847					close(sessionCreated)
848					return sess
849				}
850
851				phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool {
852					phm.EXPECT().GetStatelessResetToken(gomock.Any())
853					fn()
854					return true
855				})
856				tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, gomock.Any())
857
858				serv.handlePacket(p)
859				// make sure there are no Write calls on the packet conn
860				time.Sleep(50 * time.Millisecond)
861				Eventually(sessionCreated).Should(BeClosed())
862				cancel()
863				time.Sleep(scaleDuration(200 * time.Millisecond))
864
865				done := make(chan struct{})
866				go func() {
867					defer GinkgoRecover()
868					serv.Accept(context.Background())
869					close(done)
870				}()
871				Consistently(done).ShouldNot(BeClosed())
872
873				// make the go routine return
874				phm.EXPECT().CloseServer()
875				sess.EXPECT().getPerspective().MaxTimes(2) // once for every conn ID
876				Expect(serv.Close()).To(Succeed())
877				Eventually(done).Should(BeClosed())
878			})
879		})
880
881		Context("accepting sessions", func() {
882			It("returns Accept when an error occurs", func() {
883				testErr := errors.New("test err")
884
885				done := make(chan struct{})
886				go func() {
887					defer GinkgoRecover()
888					_, err := serv.Accept(context.Background())
889					Expect(err).To(MatchError(testErr))
890					close(done)
891				}()
892
893				serv.setCloseError(testErr)
894				Eventually(done).Should(BeClosed())
895			})
896
897			It("returns immediately, if an error occurred before", func() {
898				testErr := errors.New("test err")
899				serv.setCloseError(testErr)
900				for i := 0; i < 3; i++ {
901					_, err := serv.Accept(context.Background())
902					Expect(err).To(MatchError(testErr))
903				}
904			})
905
906			It("returns when the context is canceled", func() {
907				ctx, cancel := context.WithCancel(context.Background())
908				done := make(chan struct{})
909				go func() {
910					defer GinkgoRecover()
911					_, err := serv.Accept(ctx)
912					Expect(err).To(MatchError("context canceled"))
913					close(done)
914				}()
915
916				Consistently(done).ShouldNot(BeClosed())
917				cancel()
918				Eventually(done).Should(BeClosed())
919			})
920
921			It("accepts new sessions when the handshake completes", func() {
922				sess := NewMockQuicSession(mockCtrl)
923
924				done := make(chan struct{})
925				go func() {
926					defer GinkgoRecover()
927					s, err := serv.Accept(context.Background())
928					Expect(err).ToNot(HaveOccurred())
929					Expect(s).To(Equal(sess))
930					close(done)
931				}()
932
933				ctx, cancel := context.WithCancel(context.Background()) // handshake context
934				serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true }
935				serv.newSession = func(
936					_ sendConn,
937					runner sessionRunner,
938					_ protocol.ConnectionID,
939					_ *protocol.ConnectionID,
940					_ protocol.ConnectionID,
941					_ protocol.ConnectionID,
942					_ protocol.ConnectionID,
943					_ protocol.StatelessResetToken,
944					_ *Config,
945					_ *tls.Config,
946					_ *handshake.TokenGenerator,
947					_ bool,
948					_ logging.ConnectionTracer,
949					_ uint64,
950					_ utils.Logger,
951					_ protocol.VersionNumber,
952				) quicSession {
953					sess.EXPECT().handlePacket(gomock.Any())
954					sess.EXPECT().HandshakeComplete().Return(ctx)
955					sess.EXPECT().run().Do(func() {})
956					sess.EXPECT().Context().Return(context.Background())
957					return sess
958				}
959				phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool {
960					phm.EXPECT().GetStatelessResetToken(gomock.Any())
961					fn()
962					return true
963				})
964				tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, gomock.Any())
965				serv.handleInitialImpl(
966					&receivedPacket{buffer: getPacketBuffer()},
967					&wire.Header{DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}},
968				)
969				Consistently(done).ShouldNot(BeClosed())
970				cancel() // complete the handshake
971				Eventually(done).Should(BeClosed())
972			})
973		})
974	})
975
976	Context("server accepting sessions that haven't completed the handshake", func() {
		// Fixtures shared by the specs of this Context; both are (re)assigned in
		// BeforeEach.
		var (
			// serv is the early listener under test.
			serv *earlyServer
			// phm is the mock installed as serv.sessionHandler.
			phm  *MockPacketHandlerManager
		)
981
982		BeforeEach(func() {
983			ln, err := ListenEarly(conn, tlsConf, nil)
984			Expect(err).ToNot(HaveOccurred())
985			serv = ln.(*earlyServer)
986			phm = NewMockPacketHandlerManager(mockCtrl)
987			serv.sessionHandler = phm
988		})
989
990		AfterEach(func() {
991			phm.EXPECT().CloseServer().MaxTimes(1)
992			serv.Close()
993		})
994
995		It("accepts new sessions when they become ready", func() {
996			sess := NewMockQuicSession(mockCtrl)
997
998			done := make(chan struct{})
999			go func() {
1000				defer GinkgoRecover()
1001				s, err := serv.Accept(context.Background())
1002				Expect(err).ToNot(HaveOccurred())
1003				Expect(s).To(Equal(sess))
1004				close(done)
1005			}()
1006
1007			ready := make(chan struct{})
1008			serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true }
1009			serv.newSession = func(
1010				_ sendConn,
1011				runner sessionRunner,
1012				_ protocol.ConnectionID,
1013				_ *protocol.ConnectionID,
1014				_ protocol.ConnectionID,
1015				_ protocol.ConnectionID,
1016				_ protocol.ConnectionID,
1017				_ protocol.StatelessResetToken,
1018				_ *Config,
1019				_ *tls.Config,
1020				_ *handshake.TokenGenerator,
1021				enable0RTT bool,
1022				_ logging.ConnectionTracer,
1023				_ uint64,
1024				_ utils.Logger,
1025				_ protocol.VersionNumber,
1026			) quicSession {
1027				Expect(enable0RTT).To(BeTrue())
1028				sess.EXPECT().handlePacket(gomock.Any())
1029				sess.EXPECT().run().Do(func() {})
1030				sess.EXPECT().earlySessionReady().Return(ready)
1031				sess.EXPECT().Context().Return(context.Background())
1032				return sess
1033			}
1034			phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool {
1035				phm.EXPECT().GetStatelessResetToken(gomock.Any())
1036				fn()
1037				return true
1038			})
1039			serv.handleInitialImpl(
1040				&receivedPacket{buffer: getPacketBuffer()},
1041				&wire.Header{DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}},
1042			)
1043			Consistently(done).ShouldNot(BeClosed())
1044			close(ready)
1045			Eventually(done).Should(BeClosed())
1046		})
1047
1048		It("rejects new connection attempts if the accept queue is full", func() {
1049			serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true }
1050			senderAddr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 42}
1051
1052			serv.newSession = func(
1053				_ sendConn,
1054				runner sessionRunner,
1055				_ protocol.ConnectionID,
1056				_ *protocol.ConnectionID,
1057				_ protocol.ConnectionID,
1058				_ protocol.ConnectionID,
1059				_ protocol.ConnectionID,
1060				_ protocol.StatelessResetToken,
1061				_ *Config,
1062				_ *tls.Config,
1063				_ *handshake.TokenGenerator,
1064				_ bool,
1065				_ logging.ConnectionTracer,
1066				_ uint64,
1067				_ utils.Logger,
1068				_ protocol.VersionNumber,
1069			) quicSession {
1070				ready := make(chan struct{})
1071				close(ready)
1072				sess := NewMockQuicSession(mockCtrl)
1073				sess.EXPECT().handlePacket(gomock.Any())
1074				sess.EXPECT().run()
1075				sess.EXPECT().earlySessionReady().Return(ready)
1076				sess.EXPECT().Context().Return(context.Background())
1077				return sess
1078			}
1079
1080			phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool {
1081				phm.EXPECT().GetStatelessResetToken(gomock.Any())
1082				fn()
1083				return true
1084			}).Times(protocol.MaxAcceptQueueSize)
1085			for i := 0; i < protocol.MaxAcceptQueueSize; i++ {
1086				serv.handlePacket(getInitialWithRandomDestConnID())
1087			}
1088
1089			Eventually(func() int32 { return atomic.LoadInt32(&serv.sessionQueueLen) }).Should(BeEquivalentTo(protocol.MaxAcceptQueueSize))
1090			// make sure there are no Write calls on the packet conn
1091			time.Sleep(50 * time.Millisecond)
1092
1093			p := getInitialWithRandomDestConnID()
1094			hdr := parseHeader(p.data)
1095			done := make(chan struct{})
1096			conn.EXPECT().WriteTo(gomock.Any(), senderAddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) {
1097				defer close(done)
1098				rejectHdr := parseHeader(b)
1099				Expect(rejectHdr.Type).To(Equal(protocol.PacketTypeInitial))
1100				Expect(rejectHdr.Version).To(Equal(hdr.Version))
1101				Expect(rejectHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID))
1102				Expect(rejectHdr.SrcConnectionID).To(Equal(hdr.DestConnectionID))
1103				return len(b), nil
1104			})
1105			serv.handlePacket(p)
1106			Eventually(done).Should(BeClosed())
1107		})
1108
1109		It("doesn't accept new sessions if they were closed in the mean time", func() {
1110			serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true }
1111
1112			p := getInitial(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})
1113			ctx, cancel := context.WithCancel(context.Background())
1114			sessionCreated := make(chan struct{})
1115			sess := NewMockQuicSession(mockCtrl)
1116			serv.newSession = func(
1117				_ sendConn,
1118				runner sessionRunner,
1119				_ protocol.ConnectionID,
1120				_ *protocol.ConnectionID,
1121				_ protocol.ConnectionID,
1122				_ protocol.ConnectionID,
1123				_ protocol.ConnectionID,
1124				_ protocol.StatelessResetToken,
1125				_ *Config,
1126				_ *tls.Config,
1127				_ *handshake.TokenGenerator,
1128				_ bool,
1129				_ logging.ConnectionTracer,
1130				_ uint64,
1131				_ utils.Logger,
1132				_ protocol.VersionNumber,
1133			) quicSession {
1134				sess.EXPECT().handlePacket(p)
1135				sess.EXPECT().run()
1136				sess.EXPECT().earlySessionReady()
1137				sess.EXPECT().Context().Return(ctx)
1138				close(sessionCreated)
1139				return sess
1140			}
1141
1142			phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool {
1143				phm.EXPECT().GetStatelessResetToken(gomock.Any())
1144				fn()
1145				return true
1146			})
1147			serv.handlePacket(p)
1148			// make sure there are no Write calls on the packet conn
1149			time.Sleep(50 * time.Millisecond)
1150			Eventually(sessionCreated).Should(BeClosed())
1151			cancel()
1152			time.Sleep(scaleDuration(200 * time.Millisecond))
1153
1154			done := make(chan struct{})
1155			go func() {
1156				defer GinkgoRecover()
1157				serv.Accept(context.Background())
1158				close(done)
1159			}()
1160			Consistently(done).ShouldNot(BeClosed())
1161
1162			// make the go routine return
1163			phm.EXPECT().CloseServer()
1164			sess.EXPECT().getPerspective().MaxTimes(2) // once for every conn ID
1165			Expect(serv.Close()).To(Succeed())
1166			Eventually(done).Should(BeClosed())
1167		})
1168	})
1169})
1170
1171var _ = Describe("default source address verification", func() {
1172	It("accepts a token", func() {
1173		remoteAddr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1)}
1174		token := &Token{
1175			IsRetryToken: true,
1176			RemoteAddr:   "192.168.0.1",
1177			SentTime:     time.Now().Add(-protocol.RetryTokenValidity).Add(time.Second), // will expire in 1 second
1178		}
1179		Expect(defaultAcceptToken(remoteAddr, token)).To(BeTrue())
1180	})
1181
1182	It("requests verification if no token is provided", func() {
1183		remoteAddr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1)}
1184		Expect(defaultAcceptToken(remoteAddr, nil)).To(BeFalse())
1185	})
1186
1187	It("rejects a token if the address doesn't match", func() {
1188		remoteAddr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1)}
1189		token := &Token{
1190			IsRetryToken: true,
1191			RemoteAddr:   "127.0.0.1",
1192			SentTime:     time.Now(),
1193		}
1194		Expect(defaultAcceptToken(remoteAddr, token)).To(BeFalse())
1195	})
1196
1197	It("accepts a token for a remote address is not a UDP address", func() {
1198		remoteAddr := &net.TCPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337}
1199		token := &Token{
1200			IsRetryToken: true,
1201			RemoteAddr:   "192.168.0.1:1337",
1202			SentTime:     time.Now(),
1203		}
1204		Expect(defaultAcceptToken(remoteAddr, token)).To(BeTrue())
1205	})
1206
1207	It("rejects an invalid token for a remote address is not a UDP address", func() {
1208		remoteAddr := &net.TCPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337}
1209		token := &Token{
1210			IsRetryToken: true,
1211			RemoteAddr:   "192.168.0.1:7331", // mismatching port
1212			SentTime:     time.Now(),
1213		}
1214		Expect(defaultAcceptToken(remoteAddr, token)).To(BeFalse())
1215	})
1216
1217	It("rejects an expired token", func() {
1218		remoteAddr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1)}
1219		token := &Token{
1220			IsRetryToken: true,
1221			RemoteAddr:   "192.168.0.1",
1222			SentTime:     time.Now().Add(-protocol.RetryTokenValidity).Add(-time.Second), // expired 1 second ago
1223		}
1224		Expect(defaultAcceptToken(remoteAddr, token)).To(BeFalse())
1225	})
1226
1227	It("accepts a non-retry token", func() {
1228		Expect(protocol.RetryTokenValidity).To(BeNumerically("<", protocol.TokenValidity))
1229		remoteAddr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1)}
1230		token := &Token{
1231			IsRetryToken: false,
1232			RemoteAddr:   "192.168.0.1",
1233			// if this was a retry token, it would have expired one second ago
1234			SentTime: time.Now().Add(-protocol.RetryTokenValidity).Add(-time.Second),
1235		}
1236		Expect(defaultAcceptToken(remoteAddr, token)).To(BeTrue())
1237	})
1238})
1239