1package ackhandler
2
3import (
4	"fmt"
5	"time"
6
7	"github.com/lucas-clemente/quic-go/internal/protocol"
8	"github.com/lucas-clemente/quic-go/internal/utils"
9	"github.com/lucas-clemente/quic-go/internal/wire"
10)
11
12type receivedPacketHandler struct {
13	sentPackets sentPacketTracker
14
15	initialPackets   *receivedPacketTracker
16	handshakePackets *receivedPacketTracker
17	appDataPackets   *receivedPacketTracker
18
19	lowest1RTTPacket protocol.PacketNumber
20}
21
22var _ ReceivedPacketHandler = &receivedPacketHandler{}
23
24func newReceivedPacketHandler(
25	sentPackets sentPacketTracker,
26	rttStats *utils.RTTStats,
27	logger utils.Logger,
28	version protocol.VersionNumber,
29) ReceivedPacketHandler {
30	return &receivedPacketHandler{
31		sentPackets:      sentPackets,
32		initialPackets:   newReceivedPacketTracker(rttStats, logger, version),
33		handshakePackets: newReceivedPacketTracker(rttStats, logger, version),
34		appDataPackets:   newReceivedPacketTracker(rttStats, logger, version),
35		lowest1RTTPacket: protocol.InvalidPacketNumber,
36	}
37}
38
39func (h *receivedPacketHandler) ReceivedPacket(
40	pn protocol.PacketNumber,
41	ecn protocol.ECN,
42	encLevel protocol.EncryptionLevel,
43	rcvTime time.Time,
44	shouldInstigateAck bool,
45) error {
46	h.sentPackets.ReceivedPacket(encLevel)
47	switch encLevel {
48	case protocol.EncryptionInitial:
49		h.initialPackets.ReceivedPacket(pn, ecn, rcvTime, shouldInstigateAck)
50	case protocol.EncryptionHandshake:
51		h.handshakePackets.ReceivedPacket(pn, ecn, rcvTime, shouldInstigateAck)
52	case protocol.Encryption0RTT:
53		if h.lowest1RTTPacket != protocol.InvalidPacketNumber && pn > h.lowest1RTTPacket {
54			return fmt.Errorf("received packet number %d on a 0-RTT packet after receiving %d on a 1-RTT packet", pn, h.lowest1RTTPacket)
55		}
56		h.appDataPackets.ReceivedPacket(pn, ecn, rcvTime, shouldInstigateAck)
57	case protocol.Encryption1RTT:
58		if h.lowest1RTTPacket == protocol.InvalidPacketNumber || pn < h.lowest1RTTPacket {
59			h.lowest1RTTPacket = pn
60		}
61		h.appDataPackets.IgnoreBelow(h.sentPackets.GetLowestPacketNotConfirmedAcked())
62		h.appDataPackets.ReceivedPacket(pn, ecn, rcvTime, shouldInstigateAck)
63	default:
64		panic(fmt.Sprintf("received packet with unknown encryption level: %s", encLevel))
65	}
66	return nil
67}
68
69func (h *receivedPacketHandler) DropPackets(encLevel protocol.EncryptionLevel) {
70	//nolint:exhaustive // 1-RTT packet number space is never dropped.
71	switch encLevel {
72	case protocol.EncryptionInitial:
73		h.initialPackets = nil
74	case protocol.EncryptionHandshake:
75		h.handshakePackets = nil
76	case protocol.Encryption0RTT:
77		// Nothing to do here.
78		// If we are rejecting 0-RTT, no 0-RTT packets will have been decrypted.
79	default:
80		panic(fmt.Sprintf("Cannot drop keys for encryption level %s", encLevel))
81	}
82}
83
84func (h *receivedPacketHandler) GetAlarmTimeout() time.Time {
85	var initialAlarm, handshakeAlarm time.Time
86	if h.initialPackets != nil {
87		initialAlarm = h.initialPackets.GetAlarmTimeout()
88	}
89	if h.handshakePackets != nil {
90		handshakeAlarm = h.handshakePackets.GetAlarmTimeout()
91	}
92	oneRTTAlarm := h.appDataPackets.GetAlarmTimeout()
93	return utils.MinNonZeroTime(utils.MinNonZeroTime(initialAlarm, handshakeAlarm), oneRTTAlarm)
94}
95
96func (h *receivedPacketHandler) GetAckFrame(encLevel protocol.EncryptionLevel, onlyIfQueued bool) *wire.AckFrame {
97	var ack *wire.AckFrame
98	//nolint:exhaustive // 0-RTT packets can't contain ACK frames.
99	switch encLevel {
100	case protocol.EncryptionInitial:
101		if h.initialPackets != nil {
102			ack = h.initialPackets.GetAckFrame(onlyIfQueued)
103		}
104	case protocol.EncryptionHandshake:
105		if h.handshakePackets != nil {
106			ack = h.handshakePackets.GetAckFrame(onlyIfQueued)
107		}
108	case protocol.Encryption1RTT:
109		// 0-RTT packets can't contain ACK frames
110		return h.appDataPackets.GetAckFrame(onlyIfQueued)
111	default:
112		return nil
113	}
114	// For Initial and Handshake ACKs, the delay time is ignored by the receiver.
115	// Set it to 0 in order to save bytes.
116	if ack != nil {
117		ack.DelayTime = 0
118	}
119	return ack
120}
121
122func (h *receivedPacketHandler) IsPotentiallyDuplicate(pn protocol.PacketNumber, encLevel protocol.EncryptionLevel) bool {
123	switch encLevel {
124	case protocol.EncryptionInitial:
125		if h.initialPackets != nil {
126			return h.initialPackets.IsPotentiallyDuplicate(pn)
127		}
128	case protocol.EncryptionHandshake:
129		if h.handshakePackets != nil {
130			return h.handshakePackets.IsPotentiallyDuplicate(pn)
131		}
132	case protocol.Encryption0RTT, protocol.Encryption1RTT:
133		return h.appDataPackets.IsPotentiallyDuplicate(pn)
134	}
135	panic("unexpected encryption level")
136}
137