1package congestion
2
3import (
4	"time"
5
6	"github.com/ooni/psiphon/oopsi/github.com/Psiphon-Labs/quic-go/internal/protocol"
7	"github.com/ooni/psiphon/oopsi/github.com/Psiphon-Labs/quic-go/internal/utils"
8)
9
10const (
11	maxBurstBytes                                     = 3 * protocol.DefaultTCPMSS
12	renoBeta                       float32            = 0.7 // Reno backoff factor.
13	defaultMinimumCongestionWindow protocol.ByteCount = 2 * protocol.DefaultTCPMSS
14)
15
// cubicSender is a congestion controller that uses either Reno or Cubic
// (selected by the reno field) for window growth and loss backoff. It uses
// hybrid slow start and, unless noPRR is set, PRR while in recovery. All
// window sizes below are byte counts.
type cubicSender struct {
	// hybridSlowStart decides when to leave slow start early, prr limits
	// sending while in recovery, rttStats supplies RTT samples, stats collects
	// slow start loss counters, and cubic holds the Cubic window state.
	hybridSlowStart HybridSlowStart
	prr             PrrSender
	rttStats        *RTTStats
	stats           connectionStats
	cubic           *Cubic

	// noPRR disables PRR during recovery; reno selects Reno instead of Cubic
	// for loss backoff and congestion avoidance.
	noPRR bool
	reno  bool

	// Track the largest packet that has been sent.
	largestSentPacketNumber protocol.PacketNumber

	// Track the largest packet that has been acked.
	largestAckedPacketNumber protocol.PacketNumber

	// Track the largest packet number outstanding when a CWND cutback occurs.
	// Losses of packets up to this number belong to the same loss event.
	largestSentAtLastCutback protocol.PacketNumber

	// Whether the last loss event caused us to exit slowstart.
	// Used for stats collection of slowstartPacketsLost
	lastCutbackExitedSlowstart bool

	// When true, exit slow start with large cutback of congestion window.
	slowStartLargeReduction bool

	// Congestion window in bytes.
	congestionWindow protocol.ByteCount

	// Minimum congestion window in bytes.
	minCongestionWindow protocol.ByteCount

	// Maximum congestion window in bytes.
	maxCongestionWindow protocol.ByteCount

	// Slow start congestion window in bytes, aka ssthresh.
	slowstartThreshold protocol.ByteCount

	// Number of connections to simulate (affects Reno backoff and growth, and
	// is forwarded to cubic).
	numConnections int

	// ACK counter for the Reno implementation.
	numAckedPackets uint64

	// Window values the sender starts with; restored on connection migration.
	initialCongestionWindow    protocol.ByteCount
	initialMaxCongestionWindow protocol.ByteCount

	// Lower bound for the congestion window when slow start large reduction
	// shrinks it on repeated losses within one loss event.
	minSlowStartExitWindow protocol.ByteCount
}
65
66var _ SendAlgorithm = &cubicSender{}
67var _ SendAlgorithmWithDebugInfos = &cubicSender{}
68
69// NewCubicSender makes a new cubic sender
70func NewCubicSender(clock Clock, rttStats *RTTStats, reno bool, initialCongestionWindow, initialMaxCongestionWindow protocol.ByteCount) *cubicSender {
71	return &cubicSender{
72		rttStats:                   rttStats,
73		largestSentPacketNumber:    protocol.InvalidPacketNumber,
74		largestAckedPacketNumber:   protocol.InvalidPacketNumber,
75		largestSentAtLastCutback:   protocol.InvalidPacketNumber,
76		initialCongestionWindow:    initialCongestionWindow,
77		initialMaxCongestionWindow: initialMaxCongestionWindow,
78		congestionWindow:           initialCongestionWindow,
79		minCongestionWindow:        defaultMinimumCongestionWindow,
80		slowstartThreshold:         initialMaxCongestionWindow,
81		maxCongestionWindow:        initialMaxCongestionWindow,
82		numConnections:             defaultNumConnections,
83		cubic:                      NewCubic(clock),
84		reno:                       reno,
85	}
86}
87
88// TimeUntilSend returns when the next packet should be sent.
89func (c *cubicSender) TimeUntilSend(bytesInFlight protocol.ByteCount) time.Duration {
90	if !c.noPRR && c.InRecovery() {
91		// PRR is used when in recovery.
92		if c.prr.CanSend(c.GetCongestionWindow(), bytesInFlight, c.GetSlowStartThreshold()) {
93			return 0
94		}
95	}
96	return c.rttStats.SmoothedRTT() * time.Duration(protocol.DefaultTCPMSS) / time.Duration(2*c.GetCongestionWindow())
97}
98
99func (c *cubicSender) OnPacketSent(
100	sentTime time.Time,
101	bytesInFlight protocol.ByteCount,
102	packetNumber protocol.PacketNumber,
103	bytes protocol.ByteCount,
104	isRetransmittable bool,
105) {
106	if !isRetransmittable {
107		return
108	}
109	if c.InRecovery() {
110		// PRR is used when in recovery.
111		c.prr.OnPacketSent(bytes)
112	}
113	c.largestSentPacketNumber = packetNumber
114	c.hybridSlowStart.OnPacketSent(packetNumber)
115}
116
117func (c *cubicSender) CanSend(bytesInFlight protocol.ByteCount) bool {
118	if !c.noPRR && c.InRecovery() {
119		return c.prr.CanSend(c.GetCongestionWindow(), bytesInFlight, c.GetSlowStartThreshold())
120	}
121	return bytesInFlight < c.GetCongestionWindow()
122}
123
124func (c *cubicSender) InRecovery() bool {
125	return c.largestAckedPacketNumber != protocol.InvalidPacketNumber && c.largestAckedPacketNumber <= c.largestSentAtLastCutback
126}
127
128func (c *cubicSender) InSlowStart() bool {
129	return c.GetCongestionWindow() < c.GetSlowStartThreshold()
130}
131
132func (c *cubicSender) GetCongestionWindow() protocol.ByteCount {
133	return c.congestionWindow
134}
135
136func (c *cubicSender) GetSlowStartThreshold() protocol.ByteCount {
137	return c.slowstartThreshold
138}
139
140func (c *cubicSender) ExitSlowstart() {
141	c.slowstartThreshold = c.congestionWindow
142}
143
144func (c *cubicSender) SlowstartThreshold() protocol.ByteCount {
145	return c.slowstartThreshold
146}
147
148func (c *cubicSender) MaybeExitSlowStart() {
149	if c.InSlowStart() && c.hybridSlowStart.ShouldExitSlowStart(c.rttStats.LatestRTT(), c.rttStats.MinRTT(), c.GetCongestionWindow()/protocol.DefaultTCPMSS) {
150		c.ExitSlowstart()
151	}
152}
153
154func (c *cubicSender) OnPacketAcked(
155	ackedPacketNumber protocol.PacketNumber,
156	ackedBytes protocol.ByteCount,
157	priorInFlight protocol.ByteCount,
158	eventTime time.Time,
159) {
160	c.largestAckedPacketNumber = utils.MaxPacketNumber(ackedPacketNumber, c.largestAckedPacketNumber)
161	if c.InRecovery() {
162		// PRR is used when in recovery.
163		if !c.noPRR {
164			c.prr.OnPacketAcked(ackedBytes)
165		}
166		return
167	}
168	c.maybeIncreaseCwnd(ackedPacketNumber, ackedBytes, priorInFlight, eventTime)
169	if c.InSlowStart() {
170		c.hybridSlowStart.OnPacketAcked(ackedPacketNumber)
171	}
172}
173
174func (c *cubicSender) OnPacketLost(
175	packetNumber protocol.PacketNumber,
176	lostBytes protocol.ByteCount,
177	priorInFlight protocol.ByteCount,
178) {
179	// TCP NewReno (RFC6582) says that once a loss occurs, any losses in packets
180	// already sent should be treated as a single loss event, since it's expected.
181	if packetNumber <= c.largestSentAtLastCutback {
182		if c.lastCutbackExitedSlowstart {
183			c.stats.slowstartPacketsLost++
184			c.stats.slowstartBytesLost += lostBytes
185			if c.slowStartLargeReduction {
186				// Reduce congestion window by lost_bytes for every loss.
187				c.congestionWindow = utils.MaxByteCount(c.congestionWindow-lostBytes, c.minSlowStartExitWindow)
188				c.slowstartThreshold = c.congestionWindow
189			}
190		}
191		return
192	}
193	c.lastCutbackExitedSlowstart = c.InSlowStart()
194	if c.InSlowStart() {
195		c.stats.slowstartPacketsLost++
196	}
197
198	if !c.noPRR {
199		c.prr.OnPacketLost(priorInFlight)
200	}
201
202	// TODO(chromium): Separate out all of slow start into a separate class.
203	if c.slowStartLargeReduction && c.InSlowStart() {
204		if c.congestionWindow >= 2*c.initialCongestionWindow {
205			c.minSlowStartExitWindow = c.congestionWindow / 2
206		}
207		c.congestionWindow -= protocol.DefaultTCPMSS
208	} else if c.reno {
209		c.congestionWindow = protocol.ByteCount(float32(c.congestionWindow) * c.RenoBeta())
210	} else {
211		c.congestionWindow = c.cubic.CongestionWindowAfterPacketLoss(c.congestionWindow)
212	}
213	if c.congestionWindow < c.minCongestionWindow {
214		c.congestionWindow = c.minCongestionWindow
215	}
216	c.slowstartThreshold = c.congestionWindow
217	c.largestSentAtLastCutback = c.largestSentPacketNumber
218	// reset packet count from congestion avoidance mode. We start
219	// counting again when we're out of recovery.
220	c.numAckedPackets = 0
221}
222
223func (c *cubicSender) RenoBeta() float32 {
224	// kNConnectionBeta is the backoff factor after loss for our N-connection
225	// emulation, which emulates the effective backoff of an ensemble of N
226	// TCP-Reno connections on a single loss event. The effective multiplier is
227	// computed as:
228	return (float32(c.numConnections) - 1. + renoBeta) / float32(c.numConnections)
229}
230
231// Called when we receive an ack. Normal TCP tracks how many packets one ack
232// represents, but quic has a separate ack for each packet.
233func (c *cubicSender) maybeIncreaseCwnd(
234	_ protocol.PacketNumber,
235	ackedBytes protocol.ByteCount,
236	priorInFlight protocol.ByteCount,
237	eventTime time.Time,
238) {
239	// Do not increase the congestion window unless the sender is close to using
240	// the current window.
241	if !c.isCwndLimited(priorInFlight) {
242		c.cubic.OnApplicationLimited()
243		return
244	}
245	if c.congestionWindow >= c.maxCongestionWindow {
246		return
247	}
248	if c.InSlowStart() {
249		// TCP slow start, exponential growth, increase by one for each ACK.
250		c.congestionWindow += protocol.DefaultTCPMSS
251		return
252	}
253	// Congestion avoidance
254	if c.reno {
255		// Classic Reno congestion avoidance.
256		c.numAckedPackets++
257		// Divide by num_connections to smoothly increase the CWND at a faster
258		// rate than conventional Reno.
259		if c.numAckedPackets*uint64(c.numConnections) >= uint64(c.congestionWindow)/uint64(protocol.DefaultTCPMSS) {
260			c.congestionWindow += protocol.DefaultTCPMSS
261			c.numAckedPackets = 0
262		}
263	} else {
264		c.congestionWindow = utils.MinByteCount(c.maxCongestionWindow, c.cubic.CongestionWindowAfterAck(ackedBytes, c.congestionWindow, c.rttStats.MinRTT(), eventTime))
265	}
266}
267
268func (c *cubicSender) isCwndLimited(bytesInFlight protocol.ByteCount) bool {
269	congestionWindow := c.GetCongestionWindow()
270	if bytesInFlight >= congestionWindow {
271		return true
272	}
273	availableBytes := congestionWindow - bytesInFlight
274	slowStartLimited := c.InSlowStart() && bytesInFlight > congestionWindow/2
275	return slowStartLimited || availableBytes <= maxBurstBytes
276}
277
278// BandwidthEstimate returns the current bandwidth estimate
279func (c *cubicSender) BandwidthEstimate() Bandwidth {
280	srtt := c.rttStats.SmoothedRTT()
281	if srtt == 0 {
282		// If we haven't measured an rtt, the bandwidth estimate is unknown.
283		return 0
284	}
285	return BandwidthFromDelta(c.GetCongestionWindow(), srtt)
286}
287
288// HybridSlowStart returns the hybrid slow start instance for testing
289func (c *cubicSender) HybridSlowStart() *HybridSlowStart {
290	return &c.hybridSlowStart
291}
292
293// SetNumEmulatedConnections sets the number of emulated connections
294func (c *cubicSender) SetNumEmulatedConnections(n int) {
295	c.numConnections = utils.Max(n, 1)
296	c.cubic.SetNumConnections(c.numConnections)
297}
298
299// OnRetransmissionTimeout is called on an retransmission timeout
300func (c *cubicSender) OnRetransmissionTimeout(packetsRetransmitted bool) {
301	c.largestSentAtLastCutback = protocol.InvalidPacketNumber
302	if !packetsRetransmitted {
303		return
304	}
305	c.hybridSlowStart.Restart()
306	c.cubic.Reset()
307	c.slowstartThreshold = c.congestionWindow / 2
308	c.congestionWindow = c.minCongestionWindow
309}
310
311// OnConnectionMigration is called when the connection is migrated (?)
312func (c *cubicSender) OnConnectionMigration() {
313	c.hybridSlowStart.Restart()
314	c.prr = PrrSender{}
315	c.largestSentPacketNumber = protocol.InvalidPacketNumber
316	c.largestAckedPacketNumber = protocol.InvalidPacketNumber
317	c.largestSentAtLastCutback = protocol.InvalidPacketNumber
318	c.lastCutbackExitedSlowstart = false
319	c.cubic.Reset()
320	c.numAckedPackets = 0
321	c.congestionWindow = c.initialCongestionWindow
322	c.slowstartThreshold = c.initialMaxCongestionWindow
323	c.maxCongestionWindow = c.initialMaxCongestionWindow
324}
325
326// SetSlowStartLargeReduction allows enabling the SSLR experiment
327func (c *cubicSender) SetSlowStartLargeReduction(enabled bool) {
328	c.slowStartLargeReduction = enabled
329}
330