1package quic
2
3import (
4	"context"
5	"crypto/tls"
6	"errors"
7	"net"
8	"os"
9	"time"
10
11	mocklogging "github.com/lucas-clemente/quic-go/internal/mocks/logging"
12	"github.com/lucas-clemente/quic-go/internal/protocol"
13	"github.com/lucas-clemente/quic-go/internal/utils"
14	"github.com/lucas-clemente/quic-go/logging"
15
16	"github.com/golang/mock/gomock"
17
18	. "github.com/onsi/ginkgo"
19	. "github.com/onsi/gomega"
20)
21
22var _ = Describe("Client", func() {
	// Fixtures shared by every spec in this Describe. The plain values are
	// (re)assigned by the top-level BeforeEach.
	var (
		// cl is a client assembled by hand in BeforeEach; the AfterEach hook
		// shuts down its session if it holds a real *session.
		cl              *client
		// packetConn is the mock connection passed to Dial / DialContext / DialEarly.
		packetConn      *MockPacketConn
		// addr is the remote address used together with packetConn.
		addr            net.Addr
		// connID is the fixed connection ID used for both source and destination.
		connID          protocol.ConnectionID
		// mockMultiplexer is installed as connMuxer while a spec runs;
		// origMultiplexer holds the value it replaced so AfterEach can restore it.
		mockMultiplexer *MockMultiplexer
		origMultiplexer multiplexer
		tlsConf         *tls.Config
		// tracer is the mock connection tracer returned by the mock Tracer in config.
		tracer          *mocklogging.MockConnectionTracer
		config          *Config

		// originalClientSessConstructor keeps the value of the package-level
		// newClientSession taken in BeforeEach. Specs overwrite newClientSession
		// with stubs of this exact signature, and AfterEach puts the saved value back.
		originalClientSessConstructor func(
			conn sendConn,
			runner sessionRunner,
			destConnID protocol.ConnectionID,
			srcConnID protocol.ConnectionID,
			conf *Config,
			tlsConf *tls.Config,
			initialPacketNumber protocol.PacketNumber,
			enable0RTT bool,
			hasNegotiatedVersion bool,
			tracer logging.ConnectionTracer,
			tracingID uint64,
			logger utils.Logger,
			v protocol.VersionNumber,
		) quicSession
	)
50
	// Builds fresh fixtures for each spec and swaps the package-level
	// multiplexer for a mock. The matching AfterEach undoes the global changes.
	BeforeEach(func() {
		tlsConf = &tls.Config{NextProtos: []string{"proto1"}}
		connID = protocol.ConnectionID{0, 0, 0, 0, 0, 0, 0x13, 0x37}
		// Remember the current constructor so specs can freely replace newClientSession.
		originalClientSessConstructor = newClientSession
		tracer = mocklogging.NewMockConnectionTracer(mockCtrl)
		// The mock Tracer hands out the mock connection tracer at most once per spec;
		// specs that never use config simply do not trigger the call.
		tr := mocklogging.NewMockTracer(mockCtrl)
		tr.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveClient, gomock.Any()).Return(tracer).MaxTimes(1)
		config = &Config{Tracer: tr, Versions: []protocol.VersionNumber{protocol.VersionTLS}}
		// Wait until areSessionsRunning reports false before starting the spec
		// (presumably it detects sessions left over from an earlier spec — confirm).
		Eventually(areSessionsRunning).Should(BeFalse())
		addr = &net.UDPAddr{IP: net.IPv4(192, 168, 100, 200), Port: 1337}
		packetConn = NewMockPacketConn(mockCtrl)
		// LocalAddr may be queried any number of times and always yields an empty UDP address.
		packetConn.EXPECT().LocalAddr().Return(&net.UDPAddr{}).AnyTimes()
		cl = &client{
			srcConnID:  connID,
			destConnID: connID,
			version:    protocol.VersionTLS,
			conn:       newSendPconn(packetConn, addr),
			tracer:     tracer,
			logger:     utils.DefaultLogger,
		}
		getMultiplexer() // make the sync.Once execute
		// Replace the package-level connMuxer with a mock, keeping the original
		// so that AfterEach can restore it.
		mockMultiplexer = NewMockMultiplexer(mockCtrl)
		origMultiplexer = connMuxer
		connMuxer = mockMultiplexer
	})
78
79	AfterEach(func() {
80		connMuxer = origMultiplexer
81		newClientSession = originalClientSessConstructor
82	})
83
84	AfterEach(func() {
85		if s, ok := cl.session.(*session); ok {
86			s.shutdown()
87		}
88		Eventually(areSessionsRunning).Should(BeFalse())
89	})
90
91	Context("Dialing", func() {
92		var origGenerateConnectionID func(int) (protocol.ConnectionID, error)
93		var origGenerateConnectionIDForInitial func() (protocol.ConnectionID, error)
94
95		BeforeEach(func() {
96			origGenerateConnectionID = generateConnectionID
97			origGenerateConnectionIDForInitial = generateConnectionIDForInitial
98			generateConnectionID = func(int) (protocol.ConnectionID, error) {
99				return connID, nil
100			}
101			generateConnectionIDForInitial = func() (protocol.ConnectionID, error) {
102				return connID, nil
103			}
104		})
105
106		AfterEach(func() {
107			generateConnectionID = origGenerateConnectionID
108			generateConnectionIDForInitial = origGenerateConnectionIDForInitial
109		})
110
111		It("resolves the address", func() {
112			if os.Getenv("APPVEYOR") == "True" {
113				Skip("This test is flaky on AppVeyor.")
114			}
115
116			manager := NewMockPacketHandlerManager(mockCtrl)
117			manager.EXPECT().Add(gomock.Any(), gomock.Any())
118			manager.EXPECT().Destroy()
119			mockMultiplexer.EXPECT().AddConn(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil)
120
121			remoteAddrChan := make(chan string, 1)
122			newClientSession = func(
123				conn sendConn,
124				_ sessionRunner,
125				_ protocol.ConnectionID,
126				_ protocol.ConnectionID,
127				_ *Config,
128				_ *tls.Config,
129				_ protocol.PacketNumber,
130				_ bool,
131				_ bool,
132				_ logging.ConnectionTracer,
133				_ uint64,
134				_ utils.Logger,
135				_ protocol.VersionNumber,
136			) quicSession {
137				remoteAddrChan <- conn.RemoteAddr().String()
138				sess := NewMockQuicSession(mockCtrl)
139				sess.EXPECT().run()
140				sess.EXPECT().HandshakeComplete().Return(context.Background())
141				return sess
142			}
143			_, err := DialAddr("localhost:17890", tlsConf, &Config{HandshakeIdleTimeout: time.Millisecond})
144			Expect(err).ToNot(HaveOccurred())
145			Eventually(remoteAddrChan).Should(Receive(Equal("127.0.0.1:17890")))
146		})
147
148		It("uses the tls.Config.ServerName as the hostname, if present", func() {
149			manager := NewMockPacketHandlerManager(mockCtrl)
150			manager.EXPECT().Add(gomock.Any(), gomock.Any())
151			manager.EXPECT().Destroy()
152			mockMultiplexer.EXPECT().AddConn(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil)
153
154			hostnameChan := make(chan string, 1)
155			newClientSession = func(
156				_ sendConn,
157				_ sessionRunner,
158				_ protocol.ConnectionID,
159				_ protocol.ConnectionID,
160				_ *Config,
161				tlsConf *tls.Config,
162				_ protocol.PacketNumber,
163				_ bool,
164				_ bool,
165				_ logging.ConnectionTracer,
166				_ uint64,
167				_ utils.Logger,
168				_ protocol.VersionNumber,
169			) quicSession {
170				hostnameChan <- tlsConf.ServerName
171				sess := NewMockQuicSession(mockCtrl)
172				sess.EXPECT().run()
173				sess.EXPECT().HandshakeComplete().Return(context.Background())
174				return sess
175			}
176			tlsConf.ServerName = "foobar"
177			_, err := DialAddr("localhost:17890", tlsConf, nil)
178			Expect(err).ToNot(HaveOccurred())
179			Eventually(hostnameChan).Should(Receive(Equal("foobar")))
180		})
181
182		It("allows passing host without port as server name", func() {
183			manager := NewMockPacketHandlerManager(mockCtrl)
184			manager.EXPECT().Add(gomock.Any(), gomock.Any())
185			mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil)
186
187			hostnameChan := make(chan string, 1)
188			newClientSession = func(
189				_ sendConn,
190				_ sessionRunner,
191				_ protocol.ConnectionID,
192				_ protocol.ConnectionID,
193				_ *Config,
194				tlsConf *tls.Config,
195				_ protocol.PacketNumber,
196				_ bool,
197				_ bool,
198				_ logging.ConnectionTracer,
199				_ uint64,
200				_ utils.Logger,
201				_ protocol.VersionNumber,
202			) quicSession {
203				hostnameChan <- tlsConf.ServerName
204				sess := NewMockQuicSession(mockCtrl)
205				sess.EXPECT().HandshakeComplete().Return(context.Background())
206				sess.EXPECT().run()
207				return sess
208			}
209			tracer.EXPECT().StartedConnection(packetConn.LocalAddr(), addr, gomock.Any(), gomock.Any())
210			_, err := Dial(
211				packetConn,
212				addr,
213				"test.com",
214				tlsConf,
215				config,
216			)
217			Expect(err).ToNot(HaveOccurred())
218			Eventually(hostnameChan).Should(Receive(Equal("test.com")))
219		})
220
221		It("returns after the handshake is complete", func() {
222			manager := NewMockPacketHandlerManager(mockCtrl)
223			manager.EXPECT().Add(gomock.Any(), gomock.Any())
224			mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil)
225
226			run := make(chan struct{})
227			newClientSession = func(
228				_ sendConn,
229				runner sessionRunner,
230				_ protocol.ConnectionID,
231				_ protocol.ConnectionID,
232				_ *Config,
233				_ *tls.Config,
234				_ protocol.PacketNumber,
235				enable0RTT bool,
236				_ bool,
237				_ logging.ConnectionTracer,
238				_ uint64,
239				_ utils.Logger,
240				_ protocol.VersionNumber,
241			) quicSession {
242				Expect(enable0RTT).To(BeFalse())
243				sess := NewMockQuicSession(mockCtrl)
244				sess.EXPECT().run().Do(func() { close(run) })
245				ctx, cancel := context.WithCancel(context.Background())
246				cancel()
247				sess.EXPECT().HandshakeComplete().Return(ctx)
248				return sess
249			}
250			tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
251			s, err := Dial(
252				packetConn,
253				addr,
254				"localhost:1337",
255				tlsConf,
256				config,
257			)
258			Expect(err).ToNot(HaveOccurred())
259			Expect(s).ToNot(BeNil())
260			Eventually(run).Should(BeClosed())
261		})
262
263		It("returns early sessions", func() {
264			manager := NewMockPacketHandlerManager(mockCtrl)
265			manager.EXPECT().Add(gomock.Any(), gomock.Any())
266			mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil)
267
268			readyChan := make(chan struct{})
269			done := make(chan struct{})
270			newClientSession = func(
271				_ sendConn,
272				runner sessionRunner,
273				_ protocol.ConnectionID,
274				_ protocol.ConnectionID,
275				_ *Config,
276				_ *tls.Config,
277				_ protocol.PacketNumber,
278				enable0RTT bool,
279				_ bool,
280				_ logging.ConnectionTracer,
281				_ uint64,
282				_ utils.Logger,
283				_ protocol.VersionNumber,
284			) quicSession {
285				Expect(enable0RTT).To(BeTrue())
286				sess := NewMockQuicSession(mockCtrl)
287				sess.EXPECT().run().Do(func() { <-done })
288				sess.EXPECT().HandshakeComplete().Return(context.Background())
289				sess.EXPECT().earlySessionReady().Return(readyChan)
290				return sess
291			}
292
293			go func() {
294				defer GinkgoRecover()
295				defer close(done)
296				tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
297				s, err := DialEarly(
298					packetConn,
299					addr,
300					"localhost:1337",
301					tlsConf,
302					config,
303				)
304				Expect(err).ToNot(HaveOccurred())
305				Expect(s).ToNot(BeNil())
306			}()
307			Consistently(done).ShouldNot(BeClosed())
308			close(readyChan)
309			Eventually(done).Should(BeClosed())
310		})
311
312		It("returns an error that occurs while waiting for the handshake to complete", func() {
313			manager := NewMockPacketHandlerManager(mockCtrl)
314			manager.EXPECT().Add(gomock.Any(), gomock.Any())
315			mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil)
316
317			testErr := errors.New("early handshake error")
318			newClientSession = func(
319				_ sendConn,
320				_ sessionRunner,
321				_ protocol.ConnectionID,
322				_ protocol.ConnectionID,
323				_ *Config,
324				_ *tls.Config,
325				_ protocol.PacketNumber,
326				_ bool,
327				_ bool,
328				_ logging.ConnectionTracer,
329				_ uint64,
330				_ utils.Logger,
331				_ protocol.VersionNumber,
332			) quicSession {
333				sess := NewMockQuicSession(mockCtrl)
334				sess.EXPECT().run().Return(testErr)
335				sess.EXPECT().HandshakeComplete().Return(context.Background())
336				return sess
337			}
338			tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
339			_, err := Dial(
340				packetConn,
341				addr,
342				"localhost:1337",
343				tlsConf,
344				config,
345			)
346			Expect(err).To(MatchError(testErr))
347		})
348
		// Dials with a cancellable context while the mock session's run() is
		// blocked, then cancels the context and checks that DialContext returns
		// context.Canceled and that shutdown() is called on the session.
		It("closes the session when the context is canceled", func() {
			manager := NewMockPacketHandlerManager(mockCtrl)
			manager.EXPECT().Add(gomock.Any(), gomock.Any())
			mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil)

			// run() blocks on this channel; it is only closed when the spec function returns.
			sessionRunning := make(chan struct{})
			defer close(sessionRunning)
			sess := NewMockQuicSession(mockCtrl)
			sess.EXPECT().run().Do(func() {
				<-sessionRunning
			})
			// context.Background() is never done, so the handshake never reports completion.
			sess.EXPECT().HandshakeComplete().Return(context.Background())
			newClientSession = func(
				_ sendConn,
				_ sessionRunner,
				_ protocol.ConnectionID,
				_ protocol.ConnectionID,
				_ *Config,
				_ *tls.Config,
				_ protocol.PacketNumber,
				_ bool,
				_ bool,
				_ logging.ConnectionTracer,
				_ uint64,
				_ utils.Logger,
				_ protocol.VersionNumber,
			) quicSession {
				return sess
			}
			ctx, cancel := context.WithCancel(context.Background())
			// Closed by the goroutine below once DialContext has returned the expected error.
			dialed := make(chan struct{})
			go func() {
				defer GinkgoRecover()
				tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
				_, err := DialContext(
					ctx,
					packetConn,
					addr,
					"localhost:1337",
					tlsConf,
					config,
				)
				Expect(err).To(MatchError(context.Canceled))
				close(dialed)
			}()
			// DialContext must still be blocked before the context is cancelled.
			Consistently(dialed).ShouldNot(BeClosed())
			// The shutdown expectation is registered only now, immediately before cancelling.
			sess.EXPECT().shutdown()
			cancel()
			Eventually(dialed).Should(BeClosed())
		})
399
		// Uses DialAddr with a mock session whose run() blocks. While it blocks,
		// the conn handed to the session constructor must still accept writes;
		// Destroy() on the handler manager is expected only after run() is released.
		It("closes the connection when it was created by DialAddr", func() {
			if os.Getenv("APPVEYOR") == "True" {
				Skip("This test is flaky on AppVeyor.")
			}

			manager := NewMockPacketHandlerManager(mockCtrl)
			mockMultiplexer.EXPECT().AddConn(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil)
			manager.EXPECT().Add(gomock.Any(), gomock.Any())

			// conn is written by the stub constructor and read below only after
			// sessionCreated has been observed closed.
			var conn sendConn
			run := make(chan struct{})
			sessionCreated := make(chan struct{})
			sess := NewMockQuicSession(mockCtrl)
			newClientSession = func(
				connP sendConn,
				_ sessionRunner,
				_ protocol.ConnectionID,
				_ protocol.ConnectionID,
				_ *Config,
				_ *tls.Config,
				_ protocol.PacketNumber,
				_ bool,
				_ bool,
				_ logging.ConnectionTracer,
				_ uint64,
				_ utils.Logger,
				_ protocol.VersionNumber,
			) quicSession {
				conn = connP
				close(sessionCreated)
				return sess
			}
			// run() blocks until the spec closes the run channel.
			sess.EXPECT().run().Do(func() {
				<-run
			})
			sess.EXPECT().HandshakeComplete().Return(context.Background())

			// Closed by the goroutine below once DialAddr has returned without error.
			done := make(chan struct{})
			go func() {
				defer GinkgoRecover()
				_, err := DialAddr("localhost:1337", tlsConf, nil)
				Expect(err).ToNot(HaveOccurred())
				close(done)
			}()

			Eventually(sessionCreated).Should(BeClosed())

			// check that the connection is not closed
			Expect(conn.Write([]byte("foobar"))).To(Succeed())

			// The Destroy expectation is registered only now, just before run() is released.
			manager.EXPECT().Destroy()
			close(run)
			// NOTE(review): fixed pause before polling; its purpose is not evident
			// from this file — confirm whether Eventually alone would suffice.
			time.Sleep(50 * time.Millisecond)

			Eventually(done).Should(BeClosed())
		})
456
457		Context("quic.Config", func() {
458			It("setups with the right values", func() {
459				tokenStore := NewLRUTokenStore(10, 4)
460				config := &Config{
461					HandshakeIdleTimeout:  1337 * time.Minute,
462					MaxIdleTimeout:        42 * time.Hour,
463					MaxIncomingStreams:    1234,
464					MaxIncomingUniStreams: 4321,
465					ConnectionIDLength:    13,
466					StatelessResetKey:     []byte("foobar"),
467					TokenStore:            tokenStore,
468					EnableDatagrams:       true,
469				}
470				c := populateClientConfig(config, false)
471				Expect(c.HandshakeIdleTimeout).To(Equal(1337 * time.Minute))
472				Expect(c.MaxIdleTimeout).To(Equal(42 * time.Hour))
473				Expect(c.MaxIncomingStreams).To(BeEquivalentTo(1234))
474				Expect(c.MaxIncomingUniStreams).To(BeEquivalentTo(4321))
475				Expect(c.ConnectionIDLength).To(Equal(13))
476				Expect(c.StatelessResetKey).To(Equal([]byte("foobar")))
477				Expect(c.TokenStore).To(Equal(tokenStore))
478				Expect(c.EnableDatagrams).To(BeTrue())
479			})
480
481			It("errors when the Config contains an invalid version", func() {
482				manager := NewMockPacketHandlerManager(mockCtrl)
483				mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil)
484
485				version := protocol.VersionNumber(0x1234)
486				_, err := Dial(packetConn, nil, "localhost:1234", tlsConf, &Config{Versions: []protocol.VersionNumber{version}})
487				Expect(err).To(MatchError("0x1234 is not a valid QUIC version"))
488			})
489
490			It("disables bidirectional streams", func() {
491				config := &Config{
492					MaxIncomingStreams:    -1,
493					MaxIncomingUniStreams: 4321,
494				}
495				c := populateClientConfig(config, false)
496				Expect(c.MaxIncomingStreams).To(BeZero())
497				Expect(c.MaxIncomingUniStreams).To(BeEquivalentTo(4321))
498			})
499
500			It("disables unidirectional streams", func() {
501				config := &Config{
502					MaxIncomingStreams:    1234,
503					MaxIncomingUniStreams: -1,
504				}
505				c := populateClientConfig(config, false)
506				Expect(c.MaxIncomingStreams).To(BeEquivalentTo(1234))
507				Expect(c.MaxIncomingUniStreams).To(BeZero())
508			})
509
510			It("uses 0-byte connection IDs when dialing an address", func() {
511				c := populateClientConfig(&Config{}, true)
512				Expect(c.ConnectionIDLength).To(BeZero())
513			})
514
515			It("fills in default values if options are not set in the Config", func() {
516				c := populateClientConfig(&Config{}, false)
517				Expect(c.Versions).To(Equal(protocol.SupportedVersions))
518				Expect(c.HandshakeIdleTimeout).To(Equal(protocol.DefaultHandshakeIdleTimeout))
519				Expect(c.MaxIdleTimeout).To(Equal(protocol.DefaultIdleTimeout))
520			})
521		})
522
523		It("creates new sessions with the right parameters", func() {
524			manager := NewMockPacketHandlerManager(mockCtrl)
525			manager.EXPECT().Add(connID, gomock.Any())
526			mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil)
527
528			config := &Config{Versions: []protocol.VersionNumber{protocol.VersionTLS}}
529			c := make(chan struct{})
530			var cconn sendConn
531			var version protocol.VersionNumber
532			var conf *Config
533			newClientSession = func(
534				connP sendConn,
535				_ sessionRunner,
536				_ protocol.ConnectionID,
537				_ protocol.ConnectionID,
538				configP *Config,
539				_ *tls.Config,
540				_ protocol.PacketNumber,
541				_ bool,
542				_ bool,
543				_ logging.ConnectionTracer,
544				_ uint64,
545				_ utils.Logger,
546				versionP protocol.VersionNumber,
547			) quicSession {
548				cconn = connP
549				version = versionP
550				conf = configP
551				close(c)
552				// TODO: check connection IDs?
553				sess := NewMockQuicSession(mockCtrl)
554				sess.EXPECT().run()
555				sess.EXPECT().HandshakeComplete().Return(context.Background())
556				return sess
557			}
558			_, err := Dial(packetConn, addr, "localhost:1337", tlsConf, config)
559			Expect(err).ToNot(HaveOccurred())
560			Eventually(c).Should(BeClosed())
561			Expect(cconn.(*spconn).PacketConn).To(Equal(packetConn))
562			Expect(version).To(Equal(config.Versions[0]))
563			Expect(conf.Versions).To(Equal(config.Versions))
564		})
565
566		It("creates a new session after version negotiation", func() {
567			manager := NewMockPacketHandlerManager(mockCtrl)
568			manager.EXPECT().Add(connID, gomock.Any()).Times(2)
569			manager.EXPECT().Destroy()
570			mockMultiplexer.EXPECT().AddConn(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil)
571
572			var counter int
573			newClientSession = func(
574				_ sendConn,
575				_ sessionRunner,
576				_ protocol.ConnectionID,
577				_ protocol.ConnectionID,
578				configP *Config,
579				_ *tls.Config,
580				pn protocol.PacketNumber,
581				_ bool,
582				hasNegotiatedVersion bool,
583				_ logging.ConnectionTracer,
584				_ uint64,
585				_ utils.Logger,
586				versionP protocol.VersionNumber,
587			) quicSession {
588				sess := NewMockQuicSession(mockCtrl)
589				sess.EXPECT().HandshakeComplete().Return(context.Background())
590				if counter == 0 {
591					Expect(pn).To(BeZero())
592					Expect(hasNegotiatedVersion).To(BeFalse())
593					sess.EXPECT().run().Return(&errCloseForRecreating{
594						nextPacketNumber: 109,
595						nextVersion:      789,
596					})
597				} else {
598					Expect(pn).To(Equal(protocol.PacketNumber(109)))
599					Expect(hasNegotiatedVersion).To(BeTrue())
600					sess.EXPECT().run()
601				}
602				counter++
603				return sess
604			}
605
606			tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
607			_, err := DialAddr("localhost:7890", tlsConf, config)
608			Expect(err).ToNot(HaveOccurred())
609			Expect(counter).To(Equal(2))
610		})
611	})
612})
613