1package quic
2
3import (
4	"bytes"
5	"crypto/tls"
6	"errors"
7	"fmt"
8	"io"
9	"net"
10	"sync"
11	"sync/atomic"
12	"time"
13
14	"github.com/lucas-clemente/quic-go/internal/handshake"
15	"github.com/lucas-clemente/quic-go/internal/protocol"
16	"github.com/lucas-clemente/quic-go/internal/qerr"
17	"github.com/lucas-clemente/quic-go/internal/utils"
18	"github.com/lucas-clemente/quic-go/internal/wire"
19)
20
21// packetHandler handles packets
22type packetHandler interface {
23	handlePacket(*receivedPacket)
24	io.Closer
25	destroy(error)
26	getPerspective() protocol.Perspective
27}
28
29type unknownPacketHandler interface {
30	handlePacket(*receivedPacket)
31	closeWithError(error) error
32}
33
34type packetHandlerManager interface {
35	io.Closer
36	Add(protocol.ConnectionID, packetHandler)
37	Retire(protocol.ConnectionID)
38	Remove(protocol.ConnectionID)
39	AddResetToken([16]byte, packetHandler)
40	RemoveResetToken([16]byte)
41	GetStatelessResetToken(protocol.ConnectionID) [16]byte
42	SetServer(unknownPacketHandler)
43	CloseServer()
44}
45
46type quicSession interface {
47	Session
48	handlePacket(*receivedPacket)
49	GetVersion() protocol.VersionNumber
50	getPerspective() protocol.Perspective
51	run() error
52	destroy(error)
53	closeForRecreating() protocol.PacketNumber
54	closeRemote(error)
55}
56
57type sessionRunner interface {
58	OnHandshakeComplete(Session)
59	Retire(protocol.ConnectionID)
60	Remove(protocol.ConnectionID)
61	AddResetToken([16]byte, packetHandler)
62	RemoveResetToken([16]byte)
63}
64
65type runner struct {
66	packetHandlerManager
67
68	onHandshakeCompleteImpl func(Session)
69}
70
71func (r *runner) OnHandshakeComplete(s Session) { r.onHandshakeCompleteImpl(s) }
72
73var _ sessionRunner = &runner{}
74
// server is the QUIC Listener returned by Listen and ListenAddr.
//
// It is installed as the packetHandlerManager's unknownPacketHandler and
// creates sessions for acceptable Initial packets. Sessions whose handshake
// completed are queued for Accept.
type server struct {
	// mutex serializes Close and closeWithError (and thus the closed flag).
	mutex sync.Mutex

	// TLS and QUIC configuration handed to every new session; config has been
	// filled with defaults by populateServerConfig.
	tlsConf *tls.Config
	config  *Config

	conn net.PacketConn
	// If the server is started with ListenAddr, we create a packet conn.
	// If it is started with Listen, we take a packet conn as a parameter.
	createdPacketConn bool

	// cookieGenerator creates the tokens sent in Retry packets and decodes the
	// tokens clients echo back in their Initial packets.
	cookieGenerator *handshake.CookieGenerator

	// sessionHandler is the demultiplexer obtained from getMultiplexer().AddConn.
	sessionHandler packetHandlerManager

	// set as a member, so they can be set in the tests
	newSession func(connection, sessionRunner, protocol.ConnectionID /* original connection ID */, protocol.ConnectionID /* destination connection ID */, protocol.ConnectionID /* source connection ID */, *Config, *tls.Config, *handshake.TransportParameters, utils.Logger, protocol.VersionNumber) (quicSession, error)

	// Shutdown state. closeWithMutex sets serverError before closing errorChan,
	// so Accept can read serverError once errorChan is closed. closed guards
	// against closing errorChan twice.
	serverError error
	errorChan   chan struct{}
	closed      bool

	// sessionQueue is unbuffered (see listen): each handshake-complete session
	// waits in its own goroutine until Accept takes it. sessionQueueLen counts
	// those waiting sessions; it is compared with protocol.MaxAcceptQueueSize
	// to reject new connections when the server is busy.
	sessionQueue    chan Session
	sessionQueueLen int32 // to be used as an atomic

	// sessionRunner is passed to every new session.
	sessionRunner sessionRunner

	logger utils.Logger
}
105
106var _ Listener = &server{}
107var _ unknownPacketHandler = &server{}
108
109// ListenAddr creates a QUIC server listening on a given address.
110// The tls.Config must not be nil and must contain a certificate configuration.
111// The quic.Config may be nil, in that case the default values will be used.
112func ListenAddr(addr string, tlsConf *tls.Config, config *Config) (Listener, error) {
113	udpAddr, err := net.ResolveUDPAddr("udp", addr)
114	if err != nil {
115		return nil, err
116	}
117	conn, err := net.ListenUDP("udp", udpAddr)
118	if err != nil {
119		return nil, err
120	}
121	serv, err := listen(conn, tlsConf, config)
122	if err != nil {
123		return nil, err
124	}
125	serv.createdPacketConn = true
126	return serv, nil
127}
128
129// Listen listens for QUIC connections on a given net.PacketConn.
130// A single PacketConn only be used for a single call to Listen.
131// The PacketConn can be used for simultaneous calls to Dial.
132// QUIC connection IDs are used for demultiplexing the different connections.
133// The tls.Config must not be nil and must contain a certificate configuration.
134// The quic.Config may be nil, in that case the default values will be used.
135func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (Listener, error) {
136	return listen(conn, tlsConf, config)
137}
138
139func listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (*server, error) {
140	// TODO(#1655): only require that tls.Config.Certificates or tls.Config.GetCertificate is set
141	if tlsConf == nil || len(tlsConf.Certificates) == 0 {
142		return nil, errors.New("quic: Certificates not set in tls.Config")
143	}
144	config = populateServerConfig(config)
145	for _, v := range config.Versions {
146		if !protocol.IsValidVersion(v) {
147			return nil, fmt.Errorf("%s is not a valid QUIC version", v)
148		}
149	}
150
151	sessionHandler, err := getMultiplexer().AddConn(conn, config.ConnectionIDLength, config.StatelessResetKey)
152	if err != nil {
153		return nil, err
154	}
155	s := &server{
156		conn:           conn,
157		tlsConf:        tlsConf,
158		config:         config,
159		sessionHandler: sessionHandler,
160		sessionQueue:   make(chan Session),
161		errorChan:      make(chan struct{}),
162		newSession:     newSession,
163		logger:         utils.DefaultLogger.WithPrefix("server"),
164	}
165	if err := s.setup(); err != nil {
166		return nil, err
167	}
168	sessionHandler.SetServer(s)
169	s.logger.Debugf("Listening for %s connections on %s", conn.LocalAddr().Network(), conn.LocalAddr().String())
170	return s, nil
171}
172
173func (s *server) setup() error {
174	s.sessionRunner = &runner{
175		packetHandlerManager: s.sessionHandler,
176		onHandshakeCompleteImpl: func(sess Session) {
177			go func() {
178				atomic.AddInt32(&s.sessionQueueLen, 1)
179				defer atomic.AddInt32(&s.sessionQueueLen, -1)
180				select {
181				case s.sessionQueue <- sess:
182					// blocks until the session is accepted
183				case <-sess.Context().Done():
184					// don't pass sessions that were already closed to Accept()
185				}
186			}()
187		},
188	}
189	cookieGenerator, err := handshake.NewCookieGenerator()
190	if err != nil {
191		return err
192	}
193	s.cookieGenerator = cookieGenerator
194	return nil
195}
196
197var defaultAcceptCookie = func(clientAddr net.Addr, cookie *Cookie) bool {
198	if cookie == nil {
199		return false
200	}
201	if time.Now().After(cookie.SentTime.Add(protocol.CookieExpiryTime)) {
202		return false
203	}
204	var sourceAddr string
205	if udpAddr, ok := clientAddr.(*net.UDPAddr); ok {
206		sourceAddr = udpAddr.IP.String()
207	} else {
208		sourceAddr = clientAddr.String()
209	}
210	return sourceAddr == cookie.RemoteAddr
211}
212
213// populateServerConfig populates fields in the quic.Config with their default values, if none are set
214// it may be called with nil
215func populateServerConfig(config *Config) *Config {
216	if config == nil {
217		config = &Config{}
218	}
219	versions := config.Versions
220	if len(versions) == 0 {
221		versions = protocol.SupportedVersions
222	}
223
224	vsa := defaultAcceptCookie
225	if config.AcceptCookie != nil {
226		vsa = config.AcceptCookie
227	}
228
229	handshakeTimeout := protocol.DefaultHandshakeTimeout
230	if config.HandshakeTimeout != 0 {
231		handshakeTimeout = config.HandshakeTimeout
232	}
233	idleTimeout := protocol.DefaultIdleTimeout
234	if config.IdleTimeout != 0 {
235		idleTimeout = config.IdleTimeout
236	}
237
238	maxReceiveStreamFlowControlWindow := config.MaxReceiveStreamFlowControlWindow
239	if maxReceiveStreamFlowControlWindow == 0 {
240		maxReceiveStreamFlowControlWindow = protocol.DefaultMaxReceiveStreamFlowControlWindow
241	}
242	maxReceiveConnectionFlowControlWindow := config.MaxReceiveConnectionFlowControlWindow
243	if maxReceiveConnectionFlowControlWindow == 0 {
244		maxReceiveConnectionFlowControlWindow = protocol.DefaultMaxReceiveConnectionFlowControlWindow
245	}
246	maxIncomingStreams := config.MaxIncomingStreams
247	if maxIncomingStreams == 0 {
248		maxIncomingStreams = protocol.DefaultMaxIncomingStreams
249	} else if maxIncomingStreams < 0 {
250		maxIncomingStreams = 0
251	}
252	maxIncomingUniStreams := config.MaxIncomingUniStreams
253	if maxIncomingUniStreams == 0 {
254		maxIncomingUniStreams = protocol.DefaultMaxIncomingUniStreams
255	} else if maxIncomingUniStreams < 0 {
256		maxIncomingUniStreams = 0
257	}
258	connIDLen := config.ConnectionIDLength
259	if connIDLen == 0 {
260		connIDLen = protocol.DefaultConnectionIDLength
261	}
262
263	return &Config{
264		Versions:                              versions,
265		HandshakeTimeout:                      handshakeTimeout,
266		IdleTimeout:                           idleTimeout,
267		AcceptCookie:                          vsa,
268		KeepAlive:                             config.KeepAlive,
269		MaxReceiveStreamFlowControlWindow:     maxReceiveStreamFlowControlWindow,
270		MaxReceiveConnectionFlowControlWindow: maxReceiveConnectionFlowControlWindow,
271		MaxIncomingStreams:                    maxIncomingStreams,
272		MaxIncomingUniStreams:                 maxIncomingUniStreams,
273		ConnectionIDLength:                    connIDLen,
274		StatelessResetKey:                     config.StatelessResetKey,
275	}
276}
277
278// Accept returns newly openend sessions
279func (s *server) Accept() (Session, error) {
280	var sess Session
281	select {
282	case sess = <-s.sessionQueue:
283		return sess, nil
284	case <-s.errorChan:
285		return nil, s.serverError
286	}
287}
288
289// Close the server
290func (s *server) Close() error {
291	s.mutex.Lock()
292	defer s.mutex.Unlock()
293	if s.closed {
294		return nil
295	}
296	return s.closeWithMutex()
297}
298
299func (s *server) closeWithMutex() error {
300	s.sessionHandler.CloseServer()
301	if s.serverError == nil {
302		s.serverError = errors.New("server closed")
303	}
304	var err error
305	// If the server was started with ListenAddr, we created the packet conn.
306	// We need to close it in order to make the go routine reading from that conn return.
307	if s.createdPacketConn {
308		err = s.sessionHandler.Close()
309	}
310	s.closed = true
311	close(s.errorChan)
312	return err
313}
314
315func (s *server) closeWithError(e error) error {
316	s.mutex.Lock()
317	defer s.mutex.Unlock()
318	if s.closed {
319		return nil
320	}
321	s.serverError = e
322	return s.closeWithMutex()
323}
324
325// Addr returns the server's network address
326func (s *server) Addr() net.Addr {
327	return s.conn.LocalAddr()
328}
329
330func (s *server) handlePacket(p *receivedPacket) {
331	go func() {
332		if shouldReleaseBuffer := s.handlePacketImpl(p); !shouldReleaseBuffer {
333			p.buffer.Release()
334		}
335	}()
336}
337
338func (s *server) handlePacketImpl(p *receivedPacket) bool /* was the packet passed on to a session */ {
339	if len(p.data) < protocol.MinInitialPacketSize {
340		s.logger.Debugf("Dropping a packet that is too small to be a valid Initial (%d bytes)", len(p.data))
341		return false
342	}
343	// If we're creating a new session, the packet will be passed to the session.
344	// The header will then be parsed again.
345	hdr, _, _, err := wire.ParsePacket(p.data, s.config.ConnectionIDLength)
346	if err != nil {
347		s.logger.Debugf("Error parsing packet: %s", err)
348		return false
349	}
350	// Short header packets should never end up here in the first place
351	if !hdr.IsLongHeader {
352		return false
353	}
354	// send a Version Negotiation Packet if the client is speaking a different protocol version
355	if !protocol.IsSupportedVersion(s.config.Versions, hdr.Version) {
356		s.sendVersionNegotiationPacket(p, hdr)
357		return false
358	}
359	if hdr.IsLongHeader && hdr.Type != protocol.PacketTypeInitial {
360		// Drop long header packets.
361		// There's litte point in sending a Stateless Reset, since the client
362		// might not have received the token yet.
363		return false
364	}
365
366	s.logger.Debugf("<- Received Initial packet.")
367
368	sess, connID, err := s.handleInitialImpl(p, hdr)
369	if err != nil {
370		s.logger.Errorf("Error occurred handling initial packet: %s", err)
371		return false
372	}
373	if sess == nil { // a retry was done, or the connection attempt was rejected
374		return false
375	}
376	// Don't put the packet buffer back if a new session was created.
377	// The session will handle the packet and take of that.
378	s.sessionHandler.Add(connID, sess)
379	return true
380}
381
382func (s *server) handleInitialImpl(p *receivedPacket, hdr *wire.Header) (quicSession, protocol.ConnectionID, error) {
383	if len(hdr.Token) == 0 && hdr.DestConnectionID.Len() < protocol.MinConnectionIDLenInitial {
384		return nil, nil, errors.New("too short connection ID")
385	}
386
387	var cookie *Cookie
388	var origDestConnectionID protocol.ConnectionID
389	if len(hdr.Token) > 0 {
390		c, err := s.cookieGenerator.DecodeToken(hdr.Token)
391		if err == nil {
392			cookie = &Cookie{
393				RemoteAddr: c.RemoteAddr,
394				SentTime:   c.SentTime,
395			}
396			origDestConnectionID = c.OriginalDestConnectionID
397		}
398	}
399	if !s.config.AcceptCookie(p.remoteAddr, cookie) {
400		// Log the Initial packet now.
401		// If no Retry is sent, the packet will be logged by the session.
402		(&wire.ExtendedHeader{Header: *hdr}).Log(s.logger)
403		return nil, nil, s.sendRetry(p.remoteAddr, hdr)
404	}
405
406	if queueLen := atomic.LoadInt32(&s.sessionQueueLen); queueLen >= protocol.MaxAcceptQueueSize {
407		s.logger.Debugf("Rejecting new connection. Server currently busy. Accept queue length: %d (max %d)", queueLen, protocol.MaxAcceptQueueSize)
408		return nil, nil, s.sendServerBusy(p.remoteAddr, hdr)
409	}
410
411	connID, err := protocol.GenerateConnectionID(s.config.ConnectionIDLength)
412	if err != nil {
413		return nil, nil, err
414	}
415	s.logger.Debugf("Changing connection ID to %s.", connID)
416	sess, err := s.createNewSession(
417		p.remoteAddr,
418		origDestConnectionID,
419		hdr.DestConnectionID,
420		hdr.SrcConnectionID,
421		connID,
422		hdr.Version,
423	)
424	if err != nil {
425		return nil, nil, err
426	}
427	sess.handlePacket(p)
428	return sess, connID, nil
429}
430
431func (s *server) createNewSession(
432	remoteAddr net.Addr,
433	origDestConnID protocol.ConnectionID,
434	clientDestConnID protocol.ConnectionID,
435	destConnID protocol.ConnectionID,
436	srcConnID protocol.ConnectionID,
437	version protocol.VersionNumber,
438) (quicSession, error) {
439	token := s.sessionHandler.GetStatelessResetToken(srcConnID)
440	params := &handshake.TransportParameters{
441		InitialMaxStreamDataBidiLocal:  protocol.InitialMaxStreamData,
442		InitialMaxStreamDataBidiRemote: protocol.InitialMaxStreamData,
443		InitialMaxStreamDataUni:        protocol.InitialMaxStreamData,
444		InitialMaxData:                 protocol.InitialMaxData,
445		IdleTimeout:                    s.config.IdleTimeout,
446		MaxBidiStreams:                 uint64(s.config.MaxIncomingStreams),
447		MaxUniStreams:                  uint64(s.config.MaxIncomingUniStreams),
448		AckDelayExponent:               protocol.AckDelayExponent,
449		DisableMigration:               true,
450		StatelessResetToken:            &token,
451		OriginalConnectionID:           origDestConnID,
452	}
453	sess, err := s.newSession(
454		&conn{pconn: s.conn, currentAddr: remoteAddr},
455		s.sessionRunner,
456		clientDestConnID,
457		destConnID,
458		srcConnID,
459		s.config,
460		s.tlsConf,
461		params,
462		s.logger,
463		version,
464	)
465	if err != nil {
466		return nil, err
467	}
468	go sess.run()
469	return sess, nil
470}
471
472func (s *server) sendRetry(remoteAddr net.Addr, hdr *wire.Header) error {
473	token, err := s.cookieGenerator.NewToken(remoteAddr, hdr.DestConnectionID)
474	if err != nil {
475		return err
476	}
477	connID, err := protocol.GenerateConnectionID(s.config.ConnectionIDLength)
478	if err != nil {
479		return err
480	}
481	replyHdr := &wire.ExtendedHeader{}
482	replyHdr.IsLongHeader = true
483	replyHdr.Type = protocol.PacketTypeRetry
484	replyHdr.Version = hdr.Version
485	replyHdr.SrcConnectionID = connID
486	replyHdr.DestConnectionID = hdr.SrcConnectionID
487	replyHdr.OrigDestConnectionID = hdr.DestConnectionID
488	replyHdr.Token = token
489	s.logger.Debugf("Changing connection ID to %s.\n-> Sending Retry", connID)
490	replyHdr.Log(s.logger)
491	buf := &bytes.Buffer{}
492	if err := replyHdr.Write(buf, hdr.Version); err != nil {
493		return err
494	}
495	if _, err := s.conn.WriteTo(buf.Bytes(), remoteAddr); err != nil {
496		s.logger.Debugf("Error sending Retry: %s", err)
497	}
498	return nil
499}
500
// sendServerBusy refuses a connection attempt. It answers the client's Initial
// packet with an Initial packet containing a CONNECTION_CLOSE frame with error
// code qerr.ServerBusy.
//
// Packet construction:
//   - protection uses the Initial keys from handshake.NewInitialAEAD, derived
//     from the client's destination connection ID with PerspectiveServer;
//   - the packet number is 4 bytes long, with the zero value of
//     replyHdr.PacketNumber;
//   - header protection is applied before sending.
//
// Errors deriving keys or serializing the header or frame are returned. A
// failure to write the packet to the network is only logged.
func (s *server) sendServerBusy(remoteAddr net.Addr, hdr *wire.Header) error {
	sealer, _, err := handshake.NewInitialAEAD(hdr.DestConnectionID, protocol.PerspectiveServer)
	if err != nil {
		return err
	}
	// Serialize into a pooled packet buffer, returned to the pool on exit.
	packetBuffer := getPacketBuffer()
	defer packetBuffer.Release()
	buf := bytes.NewBuffer(packetBuffer.Slice[:0])

	ccf := &wire.ConnectionCloseFrame{ErrorCode: qerr.ServerBusy}

	// Mirror the client's connection IDs: reply from its destination ID to its
	// source ID.
	replyHdr := &wire.ExtendedHeader{}
	replyHdr.IsLongHeader = true
	replyHdr.Type = protocol.PacketTypeInitial
	replyHdr.Version = hdr.Version
	replyHdr.SrcConnectionID = hdr.DestConnectionID
	replyHdr.DestConnectionID = hdr.SrcConnectionID
	replyHdr.PacketNumberLen = protocol.PacketNumberLen4
	// Length covers the packet number, the frame and the AEAD overhead. The
	// literal 4 must stay in sync with PacketNumberLen4 above.
	replyHdr.Length = 4 /* packet number len */ + ccf.Length(hdr.Version) + protocol.ByteCount(sealer.Overhead())
	if err := replyHdr.Write(buf, hdr.Version); err != nil {
		return err
	}
	// Bytes before payloadOffset are the header (ending with the packet
	// number) and are authenticated as associated data. Bytes after it are
	// the payload to encrypt.
	payloadOffset := buf.Len()

	if err := ccf.Write(buf, hdr.Version); err != nil {
		return err
	}

	raw := buf.Bytes()
	// Seal in place, assuming Seal appends to dst like cipher.AEAD: the
	// ciphertext and tag overwrite the plaintext starting at payloadOffset.
	// The returned slice is discarded; instead raw is re-sliced to include the
	// tag. That relies on the underlying buffer having at least Overhead()
	// bytes of spare capacity past buf.Len().
	_ = sealer.Seal(raw[payloadOffset:payloadOffset], raw[payloadOffset:], replyHdr.PacketNumber, raw[:payloadOffset])
	raw = raw[0 : buf.Len()+sealer.Overhead()]

	// Header protection. Presumably the arguments are (sample, first byte,
	// packet number bytes). The 16-byte sample starts 4 bytes after the
	// packet number's first byte.
	pnOffset := payloadOffset - int(replyHdr.PacketNumberLen)
	sealer.EncryptHeader(
		raw[pnOffset+4:pnOffset+4+16],
		&raw[0],
		raw[pnOffset:payloadOffset],
	)

	replyHdr.Log(s.logger)
	wire.LogFrame(s.logger, ccf, true)
	if _, err := s.conn.WriteTo(raw, remoteAddr); err != nil {
		s.logger.Debugf("Error rejecting connection: %s", err)
	}
	return nil
}
547
548func (s *server) sendVersionNegotiationPacket(p *receivedPacket, hdr *wire.Header) {
549	s.logger.Debugf("Client offered version %s, sending Version Negotiation", hdr.Version)
550	data, err := wire.ComposeVersionNegotiation(hdr.SrcConnectionID, hdr.DestConnectionID, s.config.Versions)
551	if err != nil {
552		s.logger.Debugf("Error composing Version Negotiation: %s", err)
553		return
554	}
555	if _, err := s.conn.WriteTo(data, p.remoteAddr); err != nil {
556		s.logger.Debugf("Error sending Version Negotiation: %s", err)
557	}
558}
559