1package wire
2
3import (
4	"bytes"
5	"errors"
6	"sort"
7	"time"
8
9	"github.com/lucas-clemente/quic-go/internal/protocol"
10	"github.com/lucas-clemente/quic-go/internal/utils"
11)
12
var (
	// errInvalidAckRanges is returned by parseAckFrame when the decoded ACK
	// ranges underflow or fail validateAckRanges.
	errInvalidAckRanges = errors.New("AckFrame: ACK frame contains invalid ACK ranges")
)
14
15// An AckFrame is an ACK frame
16type AckFrame struct {
17	AckRanges []AckRange // has to be ordered. The highest ACK range goes first, the lowest ACK range goes last
18	DelayTime time.Duration
19
20	ECT0, ECT1, ECNCE uint64
21}
22
23// parseAckFrame reads an ACK frame
24func parseAckFrame(r *bytes.Reader, ackDelayExponent uint8, _ protocol.VersionNumber) (*AckFrame, error) {
25	typeByte, err := r.ReadByte()
26	if err != nil {
27		return nil, err
28	}
29	ecn := typeByte&0x1 > 0
30
31	frame := &AckFrame{}
32
33	la, err := utils.ReadVarInt(r)
34	if err != nil {
35		return nil, err
36	}
37	largestAcked := protocol.PacketNumber(la)
38	delay, err := utils.ReadVarInt(r)
39	if err != nil {
40		return nil, err
41	}
42
43	delayTime := time.Duration(delay*1<<ackDelayExponent) * time.Microsecond
44	if delayTime < 0 {
45		// If the delay time overflows, set it to the maximum encodable value.
46		delayTime = utils.InfDuration
47	}
48	frame.DelayTime = delayTime
49
50	numBlocks, err := utils.ReadVarInt(r)
51	if err != nil {
52		return nil, err
53	}
54
55	// read the first ACK range
56	ab, err := utils.ReadVarInt(r)
57	if err != nil {
58		return nil, err
59	}
60	ackBlock := protocol.PacketNumber(ab)
61	if ackBlock > largestAcked {
62		return nil, errors.New("invalid first ACK range")
63	}
64	smallest := largestAcked - ackBlock
65
66	// read all the other ACK ranges
67	frame.AckRanges = append(frame.AckRanges, AckRange{Smallest: smallest, Largest: largestAcked})
68	for i := uint64(0); i < numBlocks; i++ {
69		g, err := utils.ReadVarInt(r)
70		if err != nil {
71			return nil, err
72		}
73		gap := protocol.PacketNumber(g)
74		if smallest < gap+2 {
75			return nil, errInvalidAckRanges
76		}
77		largest := smallest - gap - 2
78
79		ab, err := utils.ReadVarInt(r)
80		if err != nil {
81			return nil, err
82		}
83		ackBlock := protocol.PacketNumber(ab)
84
85		if ackBlock > largest {
86			return nil, errInvalidAckRanges
87		}
88		smallest = largest - ackBlock
89		frame.AckRanges = append(frame.AckRanges, AckRange{Smallest: smallest, Largest: largest})
90	}
91
92	if !frame.validateAckRanges() {
93		return nil, errInvalidAckRanges
94	}
95
96	// parse (and skip) the ECN section
97	if ecn {
98		for i := 0; i < 3; i++ {
99			if _, err := utils.ReadVarInt(r); err != nil {
100				return nil, err
101			}
102		}
103	}
104
105	return frame, nil
106}
107
108// Write writes an ACK frame.
109func (f *AckFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error {
110	hasECN := f.ECT0 > 0 || f.ECT1 > 0 || f.ECNCE > 0
111	if hasECN {
112		b.WriteByte(0x3)
113	} else {
114		b.WriteByte(0x2)
115	}
116	utils.WriteVarInt(b, uint64(f.LargestAcked()))
117	utils.WriteVarInt(b, encodeAckDelay(f.DelayTime))
118
119	numRanges := f.numEncodableAckRanges()
120	utils.WriteVarInt(b, uint64(numRanges-1))
121
122	// write the first range
123	_, firstRange := f.encodeAckRange(0)
124	utils.WriteVarInt(b, firstRange)
125
126	// write all the other range
127	for i := 1; i < numRanges; i++ {
128		gap, len := f.encodeAckRange(i)
129		utils.WriteVarInt(b, gap)
130		utils.WriteVarInt(b, len)
131	}
132
133	if hasECN {
134		utils.WriteVarInt(b, f.ECT0)
135		utils.WriteVarInt(b, f.ECT1)
136		utils.WriteVarInt(b, f.ECNCE)
137	}
138	return nil
139}
140
141// Length of a written frame
142func (f *AckFrame) Length(version protocol.VersionNumber) protocol.ByteCount {
143	largestAcked := f.AckRanges[0].Largest
144	numRanges := f.numEncodableAckRanges()
145
146	length := 1 + utils.VarIntLen(uint64(largestAcked)) + utils.VarIntLen(encodeAckDelay(f.DelayTime))
147
148	length += utils.VarIntLen(uint64(numRanges - 1))
149	lowestInFirstRange := f.AckRanges[0].Smallest
150	length += utils.VarIntLen(uint64(largestAcked - lowestInFirstRange))
151
152	for i := 1; i < numRanges; i++ {
153		gap, len := f.encodeAckRange(i)
154		length += utils.VarIntLen(gap)
155		length += utils.VarIntLen(len)
156	}
157	if f.ECT0 > 0 || f.ECT1 > 0 || f.ECNCE > 0 {
158		length += utils.VarIntLen(f.ECT0)
159		length += utils.VarIntLen(f.ECT1)
160		length += utils.VarIntLen(f.ECNCE)
161	}
162	return length
163}
164
165// gets the number of ACK ranges that can be encoded
166// such that the resulting frame is smaller than the maximum ACK frame size
167func (f *AckFrame) numEncodableAckRanges() int {
168	length := 1 + utils.VarIntLen(uint64(f.LargestAcked())) + utils.VarIntLen(encodeAckDelay(f.DelayTime))
169	length += 2 // assume that the number of ranges will consume 2 bytes
170	for i := 1; i < len(f.AckRanges); i++ {
171		gap, len := f.encodeAckRange(i)
172		rangeLen := utils.VarIntLen(gap) + utils.VarIntLen(len)
173		if length+rangeLen > protocol.MaxAckFrameSize {
174			// Writing range i would exceed the MaxAckFrameSize.
175			// So encode one range less than that.
176			return i - 1
177		}
178		length += rangeLen
179	}
180	return len(f.AckRanges)
181}
182
183func (f *AckFrame) encodeAckRange(i int) (uint64 /* gap */, uint64 /* length */) {
184	if i == 0 {
185		return 0, uint64(f.AckRanges[0].Largest - f.AckRanges[0].Smallest)
186	}
187	return uint64(f.AckRanges[i-1].Smallest - f.AckRanges[i].Largest - 2),
188		uint64(f.AckRanges[i].Largest - f.AckRanges[i].Smallest)
189}
190
191// HasMissingRanges returns if this frame reports any missing packets
192func (f *AckFrame) HasMissingRanges() bool {
193	return len(f.AckRanges) > 1
194}
195
196func (f *AckFrame) validateAckRanges() bool {
197	if len(f.AckRanges) == 0 {
198		return false
199	}
200
201	// check the validity of every single ACK range
202	for _, ackRange := range f.AckRanges {
203		if ackRange.Smallest > ackRange.Largest {
204			return false
205		}
206	}
207
208	// check the consistency for ACK with multiple NACK ranges
209	for i, ackRange := range f.AckRanges {
210		if i == 0 {
211			continue
212		}
213		lastAckRange := f.AckRanges[i-1]
214		if lastAckRange.Smallest <= ackRange.Smallest {
215			return false
216		}
217		if lastAckRange.Smallest <= ackRange.Largest+1 {
218			return false
219		}
220	}
221
222	return true
223}
224
225// LargestAcked is the largest acked packet number
226func (f *AckFrame) LargestAcked() protocol.PacketNumber {
227	return f.AckRanges[0].Largest
228}
229
230// LowestAcked is the lowest acked packet number
231func (f *AckFrame) LowestAcked() protocol.PacketNumber {
232	return f.AckRanges[len(f.AckRanges)-1].Smallest
233}
234
235// AcksPacket determines if this ACK frame acks a certain packet number
236func (f *AckFrame) AcksPacket(p protocol.PacketNumber) bool {
237	if p < f.LowestAcked() || p > f.LargestAcked() {
238		return false
239	}
240
241	i := sort.Search(len(f.AckRanges), func(i int) bool {
242		return p >= f.AckRanges[i].Smallest
243	})
244	// i will always be < len(f.AckRanges), since we checked above that p is not bigger than the largest acked
245	return p <= f.AckRanges[i].Largest
246}
247
248func encodeAckDelay(delay time.Duration) uint64 {
249	return uint64(delay.Nanoseconds() / (1000 * (1 << protocol.AckDelayExponent)))
250}
251