1package quic
2
3import (
4	"bytes"
5	"errors"
6	"fmt"
7	"net"
8	"time"
9
10	"github.com/lucas-clemente/quic-go/internal/ackhandler"
11	"github.com/lucas-clemente/quic-go/internal/handshake"
12	"github.com/lucas-clemente/quic-go/internal/protocol"
13	"github.com/lucas-clemente/quic-go/internal/qerr"
14	"github.com/lucas-clemente/quic-go/internal/utils"
15	"github.com/lucas-clemente/quic-go/internal/wire"
16)
17
// packer is the interface used to assemble outgoing packets.
// packetPacker in this file implements it.
type packer interface {
	// PackCoalescedPacket packs Initial, Handshake and 0-RTT / 1-RTT packets
	// into a single buffer. It is meant to be used before the handshake is confirmed.
	PackCoalescedPacket() (*coalescedPacket, error)
	// PackPacket packs a single application data (0-RTT or 1-RTT) packet.
	PackPacket() (*packedPacket, error)
	// MaybePackProbePacket packs a probe packet at the given encryption level,
	// or returns nil if there is nothing to send.
	MaybePackProbePacket(protocol.EncryptionLevel) (*packedPacket, error)
	// MaybePackAckPacket packs an ACK-only packet, or returns nil if there is no ACK to send.
	MaybePackAckPacket(handshakeConfirmed bool) (*packedPacket, error)
	// PackConnectionClose packs CONNECTION_CLOSE frames for a transport error
	// at every encryption level that currently has keys.
	PackConnectionClose(*qerr.TransportError) (*coalescedPacket, error)
	// PackApplicationClose does the same for an application error.
	PackApplicationClose(*qerr.ApplicationError) (*coalescedPacket, error)

	// SetMaxPacketSize updates the maximum packet size (e.g. after a higher MTU was discovered).
	SetMaxPacketSize(protocol.ByteCount)
	// PackMTUProbePacket packs a 1-RTT packet containing ping, padded to size bytes.
	PackMTUProbePacket(ping ackhandler.Frame, size protocol.ByteCount) (*packedPacket, error)

	// HandleTransportParameters applies the peer's transport parameters.
	HandleTransportParameters(*wire.TransportParameters)
	// SetToken sets the token sent in the header of Initial packets.
	SetToken([]byte)
}
32
// sealer is the sealing interface used for packets at every encryption level.
// It only embeds handshake.LongHeaderSealer; 1-RTT (short header) sealers are
// stored in variables of this type as well throughout this file.
type sealer interface {
	handshake.LongHeaderSealer
}
36
// payload collects the frames of one packet before it is serialized.
// ack, if set, is written before all other frames. length is the serialized
// length of ack and frames, excluding any padding added by appendPacket.
type payload struct {
	frames []ackhandler.Frame
	ack    *wire.AckFrame
	length protocol.ByteCount
}

// packedPacket is a single serialized packet together with the buffer holding it.
type packedPacket struct {
	buffer *packetBuffer
	*packetContents
}

// packetContents describes one serialized packet. length is the size of the
// packet in the buffer: header, payload, padding and AEAD overhead.
// isMTUProbePacket is set for packets produced by PackMTUProbePacket.
type packetContents struct {
	header *wire.ExtendedHeader
	ack    *wire.AckFrame
	frames []ackhandler.Frame

	length protocol.ByteCount

	isMTUProbePacket bool
}

// coalescedPacket is one or more packets written back-to-back into a single buffer.
type coalescedPacket struct {
	buffer  *packetBuffer
	packets []*packetContents
}
62
63func (p *packetContents) EncryptionLevel() protocol.EncryptionLevel {
64	if !p.header.IsLongHeader {
65		return protocol.Encryption1RTT
66	}
67	//nolint:exhaustive // Will never be called for Retry packets (and they don't have encrypted data).
68	switch p.header.Type {
69	case protocol.PacketTypeInitial:
70		return protocol.EncryptionInitial
71	case protocol.PacketTypeHandshake:
72		return protocol.EncryptionHandshake
73	case protocol.PacketType0RTT:
74		return protocol.Encryption0RTT
75	default:
76		panic("can't determine encryption level")
77	}
78}
79
80func (p *packetContents) IsAckEliciting() bool {
81	return ackhandler.HasAckElicitingFrames(p.frames)
82}
83
84func (p *packetContents) ToAckHandlerPacket(now time.Time, q *retransmissionQueue) *ackhandler.Packet {
85	largestAcked := protocol.InvalidPacketNumber
86	if p.ack != nil {
87		largestAcked = p.ack.LargestAcked()
88	}
89	encLevel := p.EncryptionLevel()
90	for i := range p.frames {
91		if p.frames[i].OnLost != nil {
92			continue
93		}
94		switch encLevel {
95		case protocol.EncryptionInitial:
96			p.frames[i].OnLost = q.AddInitial
97		case protocol.EncryptionHandshake:
98			p.frames[i].OnLost = q.AddHandshake
99		case protocol.Encryption0RTT, protocol.Encryption1RTT:
100			p.frames[i].OnLost = q.AddAppData
101		}
102	}
103	return &ackhandler.Packet{
104		PacketNumber:         p.header.PacketNumber,
105		LargestAcked:         largestAcked,
106		Frames:               p.frames,
107		Length:               p.length,
108		EncryptionLevel:      encLevel,
109		SendTime:             now,
110		IsPathMTUProbePacket: p.isMTUProbePacket,
111	}
112}
113
114func getMaxPacketSize(addr net.Addr) protocol.ByteCount {
115	maxSize := protocol.ByteCount(protocol.MinInitialPacketSize)
116	// If this is not a UDP address, we don't know anything about the MTU.
117	// Use the minimum size of an Initial packet as the max packet size.
118	if udpAddr, ok := addr.(*net.UDPAddr); ok {
119		if utils.IsIPv4(udpAddr.IP) {
120			maxSize = protocol.InitialPacketSizeIPv4
121		} else {
122			maxSize = protocol.InitialPacketSizeIPv6
123		}
124	}
125	return maxSize
126}
127
// packetNumberManager hands out packet numbers per encryption level.
// The packer peeks the next packet number when building a header and pops it
// once the packet has been serialized (appendPacket checks that both match).
type packetNumberManager interface {
	PeekPacketNumber(protocol.EncryptionLevel) (protocol.PacketNumber, protocol.PacketNumberLen)
	PopPacketNumber(protocol.EncryptionLevel) protocol.PacketNumber
}

// sealingManager provides the sealer for each encryption level.
// Callers in this file treat handshake.ErrKeysNotYetAvailable and
// handshake.ErrKeysDropped as "no packet at this level" in some places.
type sealingManager interface {
	GetInitialSealer() (handshake.LongHeaderSealer, error)
	GetHandshakeSealer() (handshake.LongHeaderSealer, error)
	Get0RTTSealer() (handshake.LongHeaderSealer, error)
	Get1RTTSealer() (handshake.ShortHeaderSealer, error)
}

// frameSource supplies new application data frames.
// The Append* methods append frames that fit into the given number of bytes
// and return the extended slice together with the number of bytes added.
type frameSource interface {
	HasData() bool
	AppendStreamFrames([]ackhandler.Frame, protocol.ByteCount) ([]ackhandler.Frame, protocol.ByteCount)
	AppendControlFrames([]ackhandler.Frame, protocol.ByteCount) ([]ackhandler.Frame, protocol.ByteCount)
}

// ackFrameSource returns the ACK frame to send for an encryption level, or nil.
// NOTE(review): onlyIfQueued presumably restricts the result to ACKs that are
// already due to be sent — confirm against the implementation.
type ackFrameSource interface {
	GetAckFrame(encLevel protocol.EncryptionLevel, onlyIfQueued bool) *wire.AckFrame
}
149
// packetPacker assembles, seals and serializes outgoing packets.
//
// Notes on some fields:
//   - getDestConnID is called every time a header is built.
//   - token is put into the header of Initial packets.
//   - datagramQueue may be nil.
//   - maxPacketSize starts out as getMaxPacketSize(remoteAddr), can be lowered by
//     the peer's transport parameters and is overwritten by SetMaxPacketSize.
//   - numNonAckElicitingAcks counts consecutive application data packets that
//     only contain an ACK; a PING is added once it reaches MaxNonAckElicitingAcks.
type packetPacker struct {
	srcConnID     protocol.ConnectionID
	getDestConnID func() protocol.ConnectionID

	perspective protocol.Perspective
	version     protocol.VersionNumber
	cryptoSetup sealingManager

	initialStream   cryptoStream
	handshakeStream cryptoStream

	token []byte

	pnManager           packetNumberManager
	framer              frameSource
	acks                ackFrameSource
	datagramQueue       *datagramQueue
	retransmissionQueue *retransmissionQueue

	maxPacketSize          protocol.ByteCount
	numNonAckElicitingAcks int
}

// Compile-time check that packetPacker implements packer.
var _ packer = &packetPacker{}
174
175func newPacketPacker(
176	srcConnID protocol.ConnectionID,
177	getDestConnID func() protocol.ConnectionID,
178	initialStream cryptoStream,
179	handshakeStream cryptoStream,
180	packetNumberManager packetNumberManager,
181	retransmissionQueue *retransmissionQueue,
182	remoteAddr net.Addr, // only used for determining the max packet size
183	cryptoSetup sealingManager,
184	framer frameSource,
185	acks ackFrameSource,
186	datagramQueue *datagramQueue,
187	perspective protocol.Perspective,
188	version protocol.VersionNumber,
189) *packetPacker {
190	return &packetPacker{
191		cryptoSetup:         cryptoSetup,
192		getDestConnID:       getDestConnID,
193		srcConnID:           srcConnID,
194		initialStream:       initialStream,
195		handshakeStream:     handshakeStream,
196		retransmissionQueue: retransmissionQueue,
197		datagramQueue:       datagramQueue,
198		perspective:         perspective,
199		version:             version,
200		framer:              framer,
201		acks:                acks,
202		pnManager:           packetNumberManager,
203		maxPacketSize:       getMaxPacketSize(remoteAddr),
204	}
205}
206
207// PackConnectionClose packs a packet that closes the connection with a transport error.
208func (p *packetPacker) PackConnectionClose(e *qerr.TransportError) (*coalescedPacket, error) {
209	var reason string
210	// don't send details of crypto errors
211	if !e.ErrorCode.IsCryptoError() {
212		reason = e.ErrorMessage
213	}
214	return p.packConnectionClose(false, uint64(e.ErrorCode), e.FrameType, reason)
215}
216
217// PackApplicationClose packs a packet that closes the connection with an application error.
218func (p *packetPacker) PackApplicationClose(e *qerr.ApplicationError) (*coalescedPacket, error) {
219	return p.packConnectionClose(true, uint64(e.ErrorCode), 0, e.ErrorMessage)
220}
221
// packConnectionClose packs a CONNECTION_CLOSE frame into one packet per
// encryption level for which a sealer is currently available, and coalesces
// all of these packets into a single buffer. The server never sends 0-RTT.
// Levels whose keys are not yet available or have already been dropped are
// skipped; any other error from the crypto setup is returned.
//
// Application errors are not sent in Initial or Handshake packets: at those
// levels the frame is converted into a transport error with error code
// qerr.ApplicationErrorErrorCode and an empty reason phrase.
func (p *packetPacker) packConnectionClose(
	isApplicationError bool,
	errorCode uint64,
	frameType uint64,
	reason string,
) (*coalescedPacket, error) {
	// Index i of sealers, hdrs and payloads corresponds to encLevels[i].
	// A nil entry in sealers marks a level that was skipped.
	var sealers [4]sealer
	var hdrs [4]*wire.ExtendedHeader
	var payloads [4]*payload
	var size protocol.ByteCount
	var numPackets uint8
	encLevels := [4]protocol.EncryptionLevel{protocol.EncryptionInitial, protocol.EncryptionHandshake, protocol.Encryption0RTT, protocol.Encryption1RTT}
	for i, encLevel := range encLevels {
		if p.perspective == protocol.PerspectiveServer && encLevel == protocol.Encryption0RTT {
			continue
		}
		ccf := &wire.ConnectionCloseFrame{
			IsApplicationError: isApplicationError,
			ErrorCode:          errorCode,
			FrameType:          frameType,
			ReasonPhrase:       reason,
		}
		// don't send application errors in Initial or Handshake packets
		if isApplicationError && (encLevel == protocol.EncryptionInitial || encLevel == protocol.EncryptionHandshake) {
			ccf.IsApplicationError = false
			ccf.ErrorCode = uint64(qerr.ApplicationErrorErrorCode)
			ccf.ReasonPhrase = ""
		}
		payload := &payload{
			frames: []ackhandler.Frame{{Frame: ccf}},
			length: ccf.Length(p.version),
		}

		var sealer sealer
		var err error
		var keyPhase protocol.KeyPhaseBit // only set for 1-RTT
		switch encLevel {
		case protocol.EncryptionInitial:
			sealer, err = p.cryptoSetup.GetInitialSealer()
		case protocol.EncryptionHandshake:
			sealer, err = p.cryptoSetup.GetHandshakeSealer()
		case protocol.Encryption0RTT:
			sealer, err = p.cryptoSetup.Get0RTTSealer()
		case protocol.Encryption1RTT:
			var s handshake.ShortHeaderSealer
			s, err = p.cryptoSetup.Get1RTTSealer()
			if err == nil {
				keyPhase = s.KeyPhase()
			}
			sealer = s
		}
		// Missing or discarded keys just mean there's no packet at this level.
		if err == handshake.ErrKeysNotYetAvailable || err == handshake.ErrKeysDropped {
			continue
		}
		if err != nil {
			return nil, err
		}
		sealers[i] = sealer
		var hdr *wire.ExtendedHeader
		if encLevel == protocol.Encryption1RTT {
			hdr = p.getShortHeader(keyPhase)
		} else {
			hdr = p.getLongHeader(encLevel)
		}
		hdrs[i] = hdr
		payloads[i] = payload
		// Total size of all packets, needed for the Initial padding below.
		size += p.packetLength(hdr, payload) + protocol.ByteCount(sealer.Overhead())
		numPackets++
	}
	// Second pass: serialize the packets. This can only happen once all sizes
	// are known, because the Initial packet's padding depends on the total size.
	contents := make([]*packetContents, 0, numPackets)
	buffer := getPacketBuffer()
	for i, encLevel := range encLevels {
		if sealers[i] == nil {
			continue
		}
		var paddingLen protocol.ByteCount
		if encLevel == protocol.EncryptionInitial {
			paddingLen = p.initialPaddingLen(payloads[i].frames, size)
		}
		c, err := p.appendPacket(buffer, hdrs[i], payloads[i], paddingLen, encLevel, sealers[i], false)
		if err != nil {
			return nil, err
		}
		contents = append(contents, c)
	}
	return &coalescedPacket{buffer: buffer, packets: contents}, nil
}
309
310// packetLength calculates the length of the serialized packet.
311// It takes into account that packets that have a tiny payload need to be padded,
312// such that len(payload) + packet number len >= 4 + AEAD overhead
313func (p *packetPacker) packetLength(hdr *wire.ExtendedHeader, payload *payload) protocol.ByteCount {
314	var paddingLen protocol.ByteCount
315	pnLen := protocol.ByteCount(hdr.PacketNumberLen)
316	if payload.length < 4-pnLen {
317		paddingLen = 4 - pnLen - payload.length
318	}
319	return hdr.GetLength(p.version) + payload.length + paddingLen
320}
321
322func (p *packetPacker) MaybePackAckPacket(handshakeConfirmed bool) (*packedPacket, error) {
323	var encLevel protocol.EncryptionLevel
324	var ack *wire.AckFrame
325	if !handshakeConfirmed {
326		ack = p.acks.GetAckFrame(protocol.EncryptionInitial, true)
327		if ack != nil {
328			encLevel = protocol.EncryptionInitial
329		} else {
330			ack = p.acks.GetAckFrame(protocol.EncryptionHandshake, true)
331			if ack != nil {
332				encLevel = protocol.EncryptionHandshake
333			}
334		}
335	}
336	if ack == nil {
337		ack = p.acks.GetAckFrame(protocol.Encryption1RTT, true)
338		if ack == nil {
339			return nil, nil
340		}
341		encLevel = protocol.Encryption1RTT
342	}
343	payload := &payload{
344		ack:    ack,
345		length: ack.Length(p.version),
346	}
347
348	sealer, hdr, err := p.getSealerAndHeader(encLevel)
349	if err != nil {
350		return nil, err
351	}
352	return p.writeSinglePacket(hdr, payload, encLevel, sealer)
353}
354
355// size is the expected size of the packet, if no padding was applied.
356func (p *packetPacker) initialPaddingLen(frames []ackhandler.Frame, size protocol.ByteCount) protocol.ByteCount {
357	// For the server, only ack-eliciting Initial packets need to be padded.
358	if p.perspective == protocol.PerspectiveServer && !ackhandler.HasAckElicitingFrames(frames) {
359		return 0
360	}
361	if size >= p.maxPacketSize {
362		return 0
363	}
364	return p.maxPacketSize - size
365}
366
367// PackCoalescedPacket packs a new packet.
368// It packs an Initial / Handshake if there is data to send in these packet number spaces.
369// It should only be called before the handshake is confirmed.
370func (p *packetPacker) PackCoalescedPacket() (*coalescedPacket, error) {
371	maxPacketSize := p.maxPacketSize
372	if p.perspective == protocol.PerspectiveClient {
373		maxPacketSize = protocol.MinInitialPacketSize
374	}
375	var initialHdr, handshakeHdr, appDataHdr *wire.ExtendedHeader
376	var initialPayload, handshakePayload, appDataPayload *payload
377	var numPackets int
378	// Try packing an Initial packet.
379	initialSealer, err := p.cryptoSetup.GetInitialSealer()
380	if err != nil && err != handshake.ErrKeysDropped {
381		return nil, err
382	}
383	var size protocol.ByteCount
384	if initialSealer != nil {
385		initialHdr, initialPayload = p.maybeGetCryptoPacket(maxPacketSize-protocol.ByteCount(initialSealer.Overhead()), size, protocol.EncryptionInitial)
386		if initialPayload != nil {
387			size += p.packetLength(initialHdr, initialPayload) + protocol.ByteCount(initialSealer.Overhead())
388			numPackets++
389		}
390	}
391
392	// Add a Handshake packet.
393	var handshakeSealer sealer
394	if size < maxPacketSize-protocol.MinCoalescedPacketSize {
395		var err error
396		handshakeSealer, err = p.cryptoSetup.GetHandshakeSealer()
397		if err != nil && err != handshake.ErrKeysDropped && err != handshake.ErrKeysNotYetAvailable {
398			return nil, err
399		}
400		if handshakeSealer != nil {
401			handshakeHdr, handshakePayload = p.maybeGetCryptoPacket(maxPacketSize-size-protocol.ByteCount(handshakeSealer.Overhead()), size, protocol.EncryptionHandshake)
402			if handshakePayload != nil {
403				s := p.packetLength(handshakeHdr, handshakePayload) + protocol.ByteCount(handshakeSealer.Overhead())
404				size += s
405				numPackets++
406			}
407		}
408	}
409
410	// Add a 0-RTT / 1-RTT packet.
411	var appDataSealer sealer
412	appDataEncLevel := protocol.Encryption1RTT
413	if size < maxPacketSize-protocol.MinCoalescedPacketSize {
414		var err error
415		appDataSealer, appDataHdr, appDataPayload = p.maybeGetAppDataPacket(maxPacketSize-size, size)
416		if err != nil {
417			return nil, err
418		}
419		if appDataHdr != nil {
420			if appDataHdr.IsLongHeader {
421				appDataEncLevel = protocol.Encryption0RTT
422			}
423			if appDataPayload != nil {
424				size += p.packetLength(appDataHdr, appDataPayload) + protocol.ByteCount(appDataSealer.Overhead())
425				numPackets++
426			}
427		}
428	}
429
430	if numPackets == 0 {
431		return nil, nil
432	}
433
434	buffer := getPacketBuffer()
435	packet := &coalescedPacket{
436		buffer:  buffer,
437		packets: make([]*packetContents, 0, numPackets),
438	}
439	if initialPayload != nil {
440		padding := p.initialPaddingLen(initialPayload.frames, size)
441		cont, err := p.appendPacket(buffer, initialHdr, initialPayload, padding, protocol.EncryptionInitial, initialSealer, false)
442		if err != nil {
443			return nil, err
444		}
445		packet.packets = append(packet.packets, cont)
446	}
447	if handshakePayload != nil {
448		cont, err := p.appendPacket(buffer, handshakeHdr, handshakePayload, 0, protocol.EncryptionHandshake, handshakeSealer, false)
449		if err != nil {
450			return nil, err
451		}
452		packet.packets = append(packet.packets, cont)
453	}
454	if appDataPayload != nil {
455		cont, err := p.appendPacket(buffer, appDataHdr, appDataPayload, 0, appDataEncLevel, appDataSealer, false)
456		if err != nil {
457			return nil, err
458		}
459		packet.packets = append(packet.packets, cont)
460	}
461	return packet, nil
462}
463
464// PackPacket packs a packet in the application data packet number space.
465// It should be called after the handshake is confirmed.
466func (p *packetPacker) PackPacket() (*packedPacket, error) {
467	sealer, hdr, payload := p.maybeGetAppDataPacket(p.maxPacketSize, 0)
468	if payload == nil {
469		return nil, nil
470	}
471	buffer := getPacketBuffer()
472	encLevel := protocol.Encryption1RTT
473	if hdr.IsLongHeader {
474		encLevel = protocol.Encryption0RTT
475	}
476	cont, err := p.appendPacket(buffer, hdr, payload, 0, encLevel, sealer, false)
477	if err != nil {
478		return nil, err
479	}
480	return &packedPacket{
481		buffer:         buffer,
482		packetContents: cont,
483	}, nil
484}
485
// maybeGetCryptoPacket composes the payload of an Initial or Handshake packet
// and returns it together with a newly built long header.
//
// maxPacketSize is the space available for header and payload (the callers in
// this file have already subtracted the sealer's overhead). currentSize is the
// size of the packets already placed in front of this one in the same buffer:
// Initial packets may always carry an ACK, Handshake packets only if they are
// the first packet.
//
// The payload consists of an optional ACK, followed either by the queued
// retransmissions of this packet number space or, if there are none, by a
// single new CRYPTO frame. It returns (nil, nil) if there is nothing to send.
func (p *packetPacker) maybeGetCryptoPacket(maxPacketSize, currentSize protocol.ByteCount, encLevel protocol.EncryptionLevel) (*wire.ExtendedHeader, *payload) {
	var s cryptoStream
	var hasRetransmission bool
	//nolint:exhaustive // Initial and Handshake are the only two encryption levels here.
	switch encLevel {
	case protocol.EncryptionInitial:
		s = p.initialStream
		hasRetransmission = p.retransmissionQueue.HasInitialData()
	case protocol.EncryptionHandshake:
		s = p.handshakeStream
		hasRetransmission = p.retransmissionQueue.HasHandshakeData()
	}

	hasData := s.HasData()
	var ack *wire.AckFrame
	if encLevel == protocol.EncryptionInitial || currentSize == 0 {
		// onlyIfQueued is set when there's nothing else to send.
		// NOTE(review): presumably this means an ACK is only returned when one is due.
		ack = p.acks.GetAckFrame(encLevel, !hasRetransmission && !hasData)
	}
	if !hasData && !hasRetransmission && ack == nil {
		// nothing to send
		return nil, nil
	}

	var payload payload
	if ack != nil {
		payload.ack = ack
		payload.length = ack.Length(p.version)
		maxPacketSize -= payload.length
	}
	hdr := p.getLongHeader(encLevel)
	maxPacketSize -= hdr.GetLength(p.version)
	// Retransmissions take precedence; they are never mixed with new CRYPTO data.
	if hasRetransmission {
		for {
			var f wire.Frame
			//nolint:exhaustive // 0-RTT packets can't contain any retransmissions.
			switch encLevel {
			case protocol.EncryptionInitial:
				f = p.retransmissionQueue.GetInitialFrame(maxPacketSize)
			case protocol.EncryptionHandshake:
				f = p.retransmissionQueue.GetHandshakeFrame(maxPacketSize)
			}
			if f == nil {
				break // leaves the for loop
			}
			payload.frames = append(payload.frames, ackhandler.Frame{Frame: f})
			frameLen := f.Length(p.version)
			payload.length += frameLen
			maxPacketSize -= frameLen
		}
	} else if s.HasData() {
		cf := s.PopCryptoFrame(maxPacketSize)
		payload.frames = []ackhandler.Frame{{Frame: cf}}
		payload.length += cf.Length(p.version)
	}
	return hdr, &payload
}
542
543func (p *packetPacker) maybeGetAppDataPacket(maxPacketSize, currentSize protocol.ByteCount) (sealer, *wire.ExtendedHeader, *payload) {
544	var sealer sealer
545	var encLevel protocol.EncryptionLevel
546	var hdr *wire.ExtendedHeader
547	oneRTTSealer, err := p.cryptoSetup.Get1RTTSealer()
548	if err == nil {
549		encLevel = protocol.Encryption1RTT
550		sealer = oneRTTSealer
551		hdr = p.getShortHeader(oneRTTSealer.KeyPhase())
552	} else {
553		// 1-RTT sealer not yet available
554		if p.perspective != protocol.PerspectiveClient {
555			return nil, nil, nil
556		}
557		sealer, err = p.cryptoSetup.Get0RTTSealer()
558		if sealer == nil || err != nil {
559			return nil, nil, nil
560		}
561		encLevel = protocol.Encryption0RTT
562		hdr = p.getLongHeader(protocol.Encryption0RTT)
563	}
564
565	maxPayloadSize := maxPacketSize - hdr.GetLength(p.version) - protocol.ByteCount(sealer.Overhead())
566	payload := p.maybeGetAppDataPacketWithEncLevel(maxPayloadSize, encLevel == protocol.Encryption1RTT && currentSize == 0)
567	return sealer, hdr, payload
568}
569
570func (p *packetPacker) maybeGetAppDataPacketWithEncLevel(maxPayloadSize protocol.ByteCount, ackAllowed bool) *payload {
571	payload := p.composeNextPacket(maxPayloadSize, ackAllowed)
572
573	// check if we have anything to send
574	if len(payload.frames) == 0 {
575		if payload.ack == nil {
576			return nil
577		}
578		// the packet only contains an ACK
579		if p.numNonAckElicitingAcks >= protocol.MaxNonAckElicitingAcks {
580			ping := &wire.PingFrame{}
581			// don't retransmit the PING frame when it is lost
582			payload.frames = append(payload.frames, ackhandler.Frame{Frame: ping, OnLost: func(wire.Frame) {}})
583			payload.length += ping.Length(p.version)
584			p.numNonAckElicitingAcks = 0
585		} else {
586			p.numNonAckElicitingAcks++
587		}
588	} else {
589		p.numNonAckElicitingAcks = 0
590	}
591	return payload
592}
593
594func (p *packetPacker) composeNextPacket(maxFrameSize protocol.ByteCount, ackAllowed bool) *payload {
595	payload := &payload{frames: make([]ackhandler.Frame, 0, 1)}
596
597	var hasDatagram bool
598	if p.datagramQueue != nil {
599		if datagram := p.datagramQueue.Get(); datagram != nil {
600			payload.frames = append(payload.frames, ackhandler.Frame{
601				Frame: datagram,
602				// set it to a no-op. Then we won't set the default callback, which would retransmit the frame.
603				OnLost: func(wire.Frame) {},
604			})
605			payload.length += datagram.Length(p.version)
606			hasDatagram = true
607		}
608	}
609
610	var ack *wire.AckFrame
611	hasData := p.framer.HasData()
612	hasRetransmission := p.retransmissionQueue.HasAppData()
613	// TODO: make sure ACKs are sent when a lot of DATAGRAMs are queued
614	if !hasDatagram && ackAllowed {
615		ack = p.acks.GetAckFrame(protocol.Encryption1RTT, !hasRetransmission && !hasData)
616		if ack != nil {
617			payload.ack = ack
618			payload.length += ack.Length(p.version)
619		}
620	}
621
622	if ack == nil && !hasData && !hasRetransmission {
623		return payload
624	}
625
626	if hasRetransmission {
627		for {
628			remainingLen := maxFrameSize - payload.length
629			if remainingLen < protocol.MinStreamFrameSize {
630				break
631			}
632			f := p.retransmissionQueue.GetAppDataFrame(remainingLen)
633			if f == nil {
634				break
635			}
636			payload.frames = append(payload.frames, ackhandler.Frame{Frame: f})
637			payload.length += f.Length(p.version)
638		}
639	}
640
641	if hasData {
642		var lengthAdded protocol.ByteCount
643		payload.frames, lengthAdded = p.framer.AppendControlFrames(payload.frames, maxFrameSize-payload.length)
644		payload.length += lengthAdded
645
646		payload.frames, lengthAdded = p.framer.AppendStreamFrames(payload.frames, maxFrameSize-payload.length)
647		payload.length += lengthAdded
648	}
649	return payload
650}
651
652func (p *packetPacker) MaybePackProbePacket(encLevel protocol.EncryptionLevel) (*packedPacket, error) {
653	var hdr *wire.ExtendedHeader
654	var payload *payload
655	var sealer sealer
656	//nolint:exhaustive // Probe packets are never sent for 0-RTT.
657	switch encLevel {
658	case protocol.EncryptionInitial:
659		var err error
660		sealer, err = p.cryptoSetup.GetInitialSealer()
661		if err != nil {
662			return nil, err
663		}
664		hdr, payload = p.maybeGetCryptoPacket(p.maxPacketSize-protocol.ByteCount(sealer.Overhead()), 0, protocol.EncryptionInitial)
665	case protocol.EncryptionHandshake:
666		var err error
667		sealer, err = p.cryptoSetup.GetHandshakeSealer()
668		if err != nil {
669			return nil, err
670		}
671		hdr, payload = p.maybeGetCryptoPacket(p.maxPacketSize-protocol.ByteCount(sealer.Overhead()), 0, protocol.EncryptionHandshake)
672	case protocol.Encryption1RTT:
673		oneRTTSealer, err := p.cryptoSetup.Get1RTTSealer()
674		if err != nil {
675			return nil, err
676		}
677		sealer = oneRTTSealer
678		hdr = p.getShortHeader(oneRTTSealer.KeyPhase())
679		payload = p.maybeGetAppDataPacketWithEncLevel(p.maxPacketSize-protocol.ByteCount(sealer.Overhead())-hdr.GetLength(p.version), true)
680	default:
681		panic("unknown encryption level")
682	}
683	if payload == nil {
684		return nil, nil
685	}
686	size := p.packetLength(hdr, payload) + protocol.ByteCount(sealer.Overhead())
687	var padding protocol.ByteCount
688	if encLevel == protocol.EncryptionInitial {
689		padding = p.initialPaddingLen(payload.frames, size)
690	}
691	buffer := getPacketBuffer()
692	cont, err := p.appendPacket(buffer, hdr, payload, padding, encLevel, sealer, false)
693	if err != nil {
694		return nil, err
695	}
696	return &packedPacket{
697		buffer:         buffer,
698		packetContents: cont,
699	}, nil
700}
701
702func (p *packetPacker) PackMTUProbePacket(ping ackhandler.Frame, size protocol.ByteCount) (*packedPacket, error) {
703	payload := &payload{
704		frames: []ackhandler.Frame{ping},
705		length: ping.Length(p.version),
706	}
707	buffer := getPacketBuffer()
708	sealer, err := p.cryptoSetup.Get1RTTSealer()
709	if err != nil {
710		return nil, err
711	}
712	hdr := p.getShortHeader(sealer.KeyPhase())
713	padding := size - p.packetLength(hdr, payload) - protocol.ByteCount(sealer.Overhead())
714	contents, err := p.appendPacket(buffer, hdr, payload, padding, protocol.Encryption1RTT, sealer, true)
715	if err != nil {
716		return nil, err
717	}
718	contents.isMTUProbePacket = true
719	return &packedPacket{
720		buffer:         buffer,
721		packetContents: contents,
722	}, nil
723}
724
725func (p *packetPacker) getSealerAndHeader(encLevel protocol.EncryptionLevel) (sealer, *wire.ExtendedHeader, error) {
726	switch encLevel {
727	case protocol.EncryptionInitial:
728		sealer, err := p.cryptoSetup.GetInitialSealer()
729		if err != nil {
730			return nil, nil, err
731		}
732		hdr := p.getLongHeader(protocol.EncryptionInitial)
733		return sealer, hdr, nil
734	case protocol.Encryption0RTT:
735		sealer, err := p.cryptoSetup.Get0RTTSealer()
736		if err != nil {
737			return nil, nil, err
738		}
739		hdr := p.getLongHeader(protocol.Encryption0RTT)
740		return sealer, hdr, nil
741	case protocol.EncryptionHandshake:
742		sealer, err := p.cryptoSetup.GetHandshakeSealer()
743		if err != nil {
744			return nil, nil, err
745		}
746		hdr := p.getLongHeader(protocol.EncryptionHandshake)
747		return sealer, hdr, nil
748	case protocol.Encryption1RTT:
749		sealer, err := p.cryptoSetup.Get1RTTSealer()
750		if err != nil {
751			return nil, nil, err
752		}
753		hdr := p.getShortHeader(sealer.KeyPhase())
754		return sealer, hdr, nil
755	default:
756		return nil, nil, fmt.Errorf("unexpected encryption level: %s", encLevel)
757	}
758}
759
760func (p *packetPacker) getShortHeader(kp protocol.KeyPhaseBit) *wire.ExtendedHeader {
761	pn, pnLen := p.pnManager.PeekPacketNumber(protocol.Encryption1RTT)
762	hdr := &wire.ExtendedHeader{}
763	hdr.PacketNumber = pn
764	hdr.PacketNumberLen = pnLen
765	hdr.DestConnectionID = p.getDestConnID()
766	hdr.KeyPhase = kp
767	return hdr
768}
769
770func (p *packetPacker) getLongHeader(encLevel protocol.EncryptionLevel) *wire.ExtendedHeader {
771	pn, pnLen := p.pnManager.PeekPacketNumber(encLevel)
772	hdr := &wire.ExtendedHeader{
773		PacketNumber:    pn,
774		PacketNumberLen: pnLen,
775	}
776	hdr.IsLongHeader = true
777	hdr.Version = p.version
778	hdr.SrcConnectionID = p.srcConnID
779	hdr.DestConnectionID = p.getDestConnID()
780
781	//nolint:exhaustive // 1-RTT packets are not long header packets.
782	switch encLevel {
783	case protocol.EncryptionInitial:
784		hdr.Type = protocol.PacketTypeInitial
785		hdr.Token = p.token
786	case protocol.EncryptionHandshake:
787		hdr.Type = protocol.PacketTypeHandshake
788	case protocol.Encryption0RTT:
789		hdr.Type = protocol.PacketType0RTT
790	}
791	return hdr
792}
793
794// writeSinglePacket packs a single packet.
795func (p *packetPacker) writeSinglePacket(
796	hdr *wire.ExtendedHeader,
797	payload *payload,
798	encLevel protocol.EncryptionLevel,
799	sealer sealer,
800) (*packedPacket, error) {
801	buffer := getPacketBuffer()
802	var paddingLen protocol.ByteCount
803	if encLevel == protocol.EncryptionInitial {
804		paddingLen = p.initialPaddingLen(payload.frames, hdr.GetLength(p.version)+payload.length+protocol.ByteCount(sealer.Overhead()))
805	}
806	contents, err := p.appendPacket(buffer, hdr, payload, paddingLen, encLevel, sealer, false)
807	if err != nil {
808		return nil, err
809	}
810	return &packedPacket{
811		buffer:         buffer,
812		packetContents: contents,
813	}, nil
814}
815
// appendPacket serializes a packet at the end of buffer (after any packets
// already coalesced into it). It then seals the payload and applies header
// protection in place, and finally pops the packet number from pnManager.
// The popped number must match the one in the header.
//
// padding is the padding requested by the caller. On top of it, tiny payloads
// are padded so that packet number length + payload length is at least 4 bytes.
// NOTE(review): presumably this ensures the 16-byte header protection sample
// (taken 4 bytes after the start of the packet number) lies within the packet,
// assuming an AEAD overhead of at least 16 bytes.
// For long header packets, header.Length is set here.
//
// Unless isMTUProbePacket is set, an error is returned if the buffer contents
// (including previously coalesced packets) plus the AEAD overhead exceed
// p.maxPacketSize. The returned length is the size of this packet in the
// buffer, including the AEAD overhead.
func (p *packetPacker) appendPacket(buffer *packetBuffer, header *wire.ExtendedHeader, payload *payload, padding protocol.ByteCount, encLevel protocol.EncryptionLevel, sealer sealer, isMTUProbePacket bool) (*packetContents, error) {
	var paddingLen protocol.ByteCount
	pnLen := protocol.ByteCount(header.PacketNumberLen)
	if payload.length < 4-pnLen {
		paddingLen = 4 - pnLen - payload.length
	}
	paddingLen += padding
	if header.IsLongHeader {
		header.Length = pnLen + protocol.ByteCount(sealer.Overhead()) + payload.length + paddingLen
	}

	// Start writing after the data already in the buffer.
	hdrOffset := buffer.Len()
	buf := bytes.NewBuffer(buffer.Data)
	if err := header.Write(buf, p.version); err != nil {
		return nil, err
	}
	payloadOffset := buf.Len()

	if payload.ack != nil {
		if err := payload.ack.Write(buf, p.version); err != nil {
			return nil, err
		}
	}
	// Padding is written as zero bytes (PADDING frames) between the ACK and the other frames.
	if paddingLen > 0 {
		buf.Write(make([]byte, paddingLen))
	}
	for _, frame := range payload.frames {
		if err := frame.Write(buf, p.version); err != nil {
			return nil, err
		}
	}

	// Sanity checks: the serialized payload must match the precomputed length,
	// and the packet must fit into the maximum packet size.
	if payloadSize := protocol.ByteCount(buf.Len()-payloadOffset) - paddingLen; payloadSize != payload.length {
		return nil, fmt.Errorf("PacketPacker BUG: payload size inconsistent (expected %d, got %d bytes)", payload.length, payloadSize)
	}
	if !isMTUProbePacket {
		if size := protocol.ByteCount(buf.Len() + sealer.Overhead()); size > p.maxPacketSize {
			return nil, fmt.Errorf("PacketPacker BUG: packet too large (%d bytes, allowed %d bytes)", size, p.maxPacketSize)
		}
	}

	raw := buffer.Data
	// encrypt the packet
	raw = raw[:buf.Len()]
	// Seal in place; the header bytes serve as associated data.
	_ = sealer.Seal(raw[payloadOffset:payloadOffset], raw[payloadOffset:], header.PacketNumber, raw[hdrOffset:payloadOffset])
	raw = raw[0 : buf.Len()+sealer.Overhead()]
	// apply header protection
	pnOffset := payloadOffset - int(header.PacketNumberLen)
	sealer.EncryptHeader(raw[pnOffset+4:pnOffset+4+16], &raw[hdrOffset], raw[pnOffset:payloadOffset])
	buffer.Data = raw

	num := p.pnManager.PopPacketNumber(encLevel)
	if num != header.PacketNumber {
		return nil, errors.New("packetPacker BUG: Peeked and Popped packet numbers do not match")
	}
	return &packetContents{
		header: header,
		ack:    payload.ack,
		frames: payload.frames,
		length: buffer.Len() - hdrOffset,
	}, nil
}
878
879func (p *packetPacker) SetToken(token []byte) {
880	p.token = token
881}
882
883// When a higher MTU is discovered, use it.
884func (p *packetPacker) SetMaxPacketSize(s protocol.ByteCount) {
885	p.maxPacketSize = s
886}
887
888// If the peer sets a max_packet_size that's smaller than the size we're currently using,
889// we need to reduce the size of packets we send.
890func (p *packetPacker) HandleTransportParameters(params *wire.TransportParameters) {
891	if params.MaxUDPPayloadSize != 0 {
892		p.maxPacketSize = utils.MinByteCount(p.maxPacketSize, params.MaxUDPPayloadSize)
893	}
894}
895