1package wire
2
3import (
4	"bytes"
5	"errors"
6	"sort"
7	"time"
8
9	"github.com/lucas-clemente/quic-go/internal/protocol"
10	"github.com/lucas-clemente/quic-go/internal/utils"
11)
12
var (
	// errInvalidAckRanges is returned by parseAckFrame when the encoded ACK
	// ranges underflow or the decoded ranges fail validateAckRanges.
	errInvalidAckRanges = errors.New("AckFrame: ACK frame contains invalid ACK ranges")
)
14
15// An AckFrame is an ACK frame
16type AckFrame struct {
17	AckRanges []AckRange // has to be ordered. The highest ACK range goes first, the lowest ACK range goes last
18	DelayTime time.Duration
19}
20
21// parseAckFrame reads an ACK frame
22func parseAckFrame(r *bytes.Reader, ackDelayExponent uint8, version protocol.VersionNumber) (*AckFrame, error) {
23	typeByte, err := r.ReadByte()
24	if err != nil {
25		return nil, err
26	}
27	ecn := typeByte&0x1 > 0
28
29	frame := &AckFrame{}
30
31	la, err := utils.ReadVarInt(r)
32	if err != nil {
33		return nil, err
34	}
35	largestAcked := protocol.PacketNumber(la)
36	delay, err := utils.ReadVarInt(r)
37	if err != nil {
38		return nil, err
39	}
40	frame.DelayTime = time.Duration(delay*1<<ackDelayExponent) * time.Microsecond
41
42	numBlocks, err := utils.ReadVarInt(r)
43	if err != nil {
44		return nil, err
45	}
46
47	// read the first ACK range
48	ab, err := utils.ReadVarInt(r)
49	if err != nil {
50		return nil, err
51	}
52	ackBlock := protocol.PacketNumber(ab)
53	if ackBlock > largestAcked {
54		return nil, errors.New("invalid first ACK range")
55	}
56	smallest := largestAcked - ackBlock
57
58	// read all the other ACK ranges
59	frame.AckRanges = append(frame.AckRanges, AckRange{Smallest: smallest, Largest: largestAcked})
60	for i := uint64(0); i < numBlocks; i++ {
61		g, err := utils.ReadVarInt(r)
62		if err != nil {
63			return nil, err
64		}
65		gap := protocol.PacketNumber(g)
66		if smallest < gap+2 {
67			return nil, errInvalidAckRanges
68		}
69		largest := smallest - gap - 2
70
71		ab, err := utils.ReadVarInt(r)
72		if err != nil {
73			return nil, err
74		}
75		ackBlock := protocol.PacketNumber(ab)
76
77		if ackBlock > largest {
78			return nil, errInvalidAckRanges
79		}
80		smallest = largest - ackBlock
81		frame.AckRanges = append(frame.AckRanges, AckRange{Smallest: smallest, Largest: largest})
82	}
83
84	if !frame.validateAckRanges() {
85		return nil, errInvalidAckRanges
86	}
87
88	// parse (and skip) the ECN section
89	if ecn {
90		for i := 0; i < 3; i++ {
91			if _, err := utils.ReadVarInt(r); err != nil {
92				return nil, err
93			}
94		}
95	}
96
97	return frame, nil
98}
99
100// Write writes an ACK frame.
101func (f *AckFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error {
102	b.WriteByte(0x2)
103	utils.WriteVarInt(b, uint64(f.LargestAcked()))
104	utils.WriteVarInt(b, encodeAckDelay(f.DelayTime))
105
106	numRanges := f.numEncodableAckRanges()
107	utils.WriteVarInt(b, uint64(numRanges-1))
108
109	// write the first range
110	_, firstRange := f.encodeAckRange(0)
111	utils.WriteVarInt(b, firstRange)
112
113	// write all the other range
114	for i := 1; i < numRanges; i++ {
115		gap, len := f.encodeAckRange(i)
116		utils.WriteVarInt(b, gap)
117		utils.WriteVarInt(b, len)
118	}
119	return nil
120}
121
122// Length of a written frame
123func (f *AckFrame) Length(version protocol.VersionNumber) protocol.ByteCount {
124	largestAcked := f.AckRanges[0].Largest
125	numRanges := f.numEncodableAckRanges()
126
127	length := 1 + utils.VarIntLen(uint64(largestAcked)) + utils.VarIntLen(encodeAckDelay(f.DelayTime))
128
129	length += utils.VarIntLen(uint64(numRanges - 1))
130	lowestInFirstRange := f.AckRanges[0].Smallest
131	length += utils.VarIntLen(uint64(largestAcked - lowestInFirstRange))
132
133	for i := 1; i < numRanges; i++ {
134		gap, len := f.encodeAckRange(i)
135		length += utils.VarIntLen(gap)
136		length += utils.VarIntLen(len)
137	}
138	return length
139}
140
141// gets the number of ACK ranges that can be encoded
142// such that the resulting frame is smaller than the maximum ACK frame size
143func (f *AckFrame) numEncodableAckRanges() int {
144	length := 1 + utils.VarIntLen(uint64(f.LargestAcked())) + utils.VarIntLen(encodeAckDelay(f.DelayTime))
145	length += 2 // assume that the number of ranges will consume 2 bytes
146	for i := 1; i < len(f.AckRanges); i++ {
147		gap, len := f.encodeAckRange(i)
148		rangeLen := utils.VarIntLen(gap) + utils.VarIntLen(len)
149		if length+rangeLen > protocol.MaxAckFrameSize {
150			// Writing range i would exceed the MaxAckFrameSize.
151			// So encode one range less than that.
152			return i - 1
153		}
154		length += rangeLen
155	}
156	return len(f.AckRanges)
157}
158
159func (f *AckFrame) encodeAckRange(i int) (uint64 /* gap */, uint64 /* length */) {
160	if i == 0 {
161		return 0, uint64(f.AckRanges[0].Largest - f.AckRanges[0].Smallest)
162	}
163	return uint64(f.AckRanges[i-1].Smallest - f.AckRanges[i].Largest - 2),
164		uint64(f.AckRanges[i].Largest - f.AckRanges[i].Smallest)
165}
166
167// HasMissingRanges returns if this frame reports any missing packets
168func (f *AckFrame) HasMissingRanges() bool {
169	return len(f.AckRanges) > 1
170}
171
172func (f *AckFrame) validateAckRanges() bool {
173	if len(f.AckRanges) == 0 {
174		return false
175	}
176
177	// check the validity of every single ACK range
178	for _, ackRange := range f.AckRanges {
179		if ackRange.Smallest > ackRange.Largest {
180			return false
181		}
182	}
183
184	// check the consistency for ACK with multiple NACK ranges
185	for i, ackRange := range f.AckRanges {
186		if i == 0 {
187			continue
188		}
189		lastAckRange := f.AckRanges[i-1]
190		if lastAckRange.Smallest <= ackRange.Smallest {
191			return false
192		}
193		if lastAckRange.Smallest <= ackRange.Largest+1 {
194			return false
195		}
196	}
197
198	return true
199}
200
201// LargestAcked is the largest acked packet number
202func (f *AckFrame) LargestAcked() protocol.PacketNumber {
203	return f.AckRanges[0].Largest
204}
205
206// LowestAcked is the lowest acked packet number
207func (f *AckFrame) LowestAcked() protocol.PacketNumber {
208	return f.AckRanges[len(f.AckRanges)-1].Smallest
209}
210
211// AcksPacket determines if this ACK frame acks a certain packet number
212func (f *AckFrame) AcksPacket(p protocol.PacketNumber) bool {
213	if p < f.LowestAcked() || p > f.LargestAcked() {
214		return false
215	}
216
217	i := sort.Search(len(f.AckRanges), func(i int) bool {
218		return p >= f.AckRanges[i].Smallest
219	})
220	// i will always be < len(f.AckRanges), since we checked above that p is not bigger than the largest acked
221	return p <= f.AckRanges[i].Largest
222}
223
224func encodeAckDelay(delay time.Duration) uint64 {
225	return uint64(delay.Nanoseconds() / (1000 * (1 << protocol.AckDelayExponent)))
226}
227