1 // Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
2 // http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
3 // <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
4 // option. This file may not be copied, modified, or distributed
5 // except according to those terms.
6 
7 // Congestion control
8 #![deny(clippy::pedantic)]
9 
10 use std::cmp::{max, min};
11 use std::fmt::{self, Debug, Display};
12 use std::time::{Duration, Instant};
13 
14 use super::CongestionControl;
15 
16 use crate::cc::MAX_DATAGRAM_SIZE;
17 use crate::qlog::{self, QlogMetric};
18 use crate::sender::PACING_BURST_SIZE;
19 use crate::tracking::SentPacket;
20 use neqo_common::{const_max, const_min, qdebug, qinfo, qlog::NeqoQlog, qtrace};
21 
22 pub const CWND_INITIAL_PKTS: usize = 10;
23 pub const CWND_INITIAL: usize = const_min(
24     CWND_INITIAL_PKTS * MAX_DATAGRAM_SIZE,
25     const_max(2 * MAX_DATAGRAM_SIZE, 14720),
26 );
27 pub const CWND_MIN: usize = MAX_DATAGRAM_SIZE * 2;
28 const PERSISTENT_CONG_THRESH: u32 = 3;
29 
/// Phase of the congestion controller.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum State {
    /// In slow start (not in recovery).
    SlowStart,
    /// In congestion avoidance (not in recovery).
    CongestionAvoidance,
    /// In a recovery period, but no packets have been sent yet.  This is a
    /// transient state because we want to exempt the first packet sent after
    /// entering recovery from the congestion window.
    RecoveryStart,
    /// In a recovery period, after the first packet of the period was sent.
    Recovery,
    /// Start of persistent congestion, which is transient, like `RecoveryStart`.
    PersistentCongestion,
}

impl State {
    /// Whether this is either of the recovery states.
    pub fn in_recovery(self) -> bool {
        matches!(self, Self::RecoveryStart | Self::Recovery)
    }

    /// Whether this is exactly `SlowStart`.
    pub fn in_slow_start(self) -> bool {
        self == Self::SlowStart
    }

    /// These states are transient, we tell qlog on entry, but not on exit.
    pub fn transient(self) -> bool {
        matches!(self, Self::RecoveryStart | Self::PersistentCongestion)
    }

    /// Update a transient state to the true state.
    ///
    /// # Panics
    /// If called on a state that is not transient; callers check
    /// `transient()` first.
    pub fn update(&mut self) {
        let current = *self;
        *self = match current {
            Self::PersistentCongestion => Self::SlowStart,
            Self::RecoveryStart => Self::Recovery,
            // Listed explicitly (no catch-all) so that adding a variant forces
            // a decision here.
            Self::SlowStart | Self::CongestionAvoidance | Self::Recovery => {
                unreachable!("update() called on non-transient state {:?}", current)
            }
        };
    }

    /// Name of this state as reported to qlog; transient states report the
    /// state they settle into.
    pub fn to_qlog(self) -> &'static str {
        match self {
            Self::SlowStart | Self::PersistentCongestion => "slow_start",
            Self::CongestionAvoidance => "congestion_avoidance",
            Self::Recovery | Self::RecoveryStart => "recovery",
        }
    }
}
77 
/// Algorithm-specific part of a window-based congestion controller.
pub trait WindowAdjustment: Display + Debug {
    /// This is called when an ack is received in congestion avoidance.
    /// Returns the number of acknowledged bytes the congestion controller needs
    /// to collect before increasing its cwnd by `MAX_DATAGRAM_SIZE`.
    /// The returned value must be greater than zero.
    fn bytes_for_cwnd_increase(
        &mut self,
        curr_cwnd: usize,
        new_acked_bytes: usize,
        min_rtt: Duration,
        now: Instant,
    ) -> usize;
    /// This function is called when a congestion event has been detected and it
    /// returns new (decreased) values of `curr_cwnd` and `acked_bytes`.
    /// This value can be very small; the calling code is responsible for ensuring that the
    /// congestion window doesn't drop below the minimum of `CWND_MIN`.
    fn reduce_cwnd(&mut self, curr_cwnd: usize, acked_bytes: usize) -> (usize, usize);
    /// Called when acknowledgments arrive while the sender is application limited.
    /// Cubic needs this signal to reset its epoch.
    fn on_app_limited(&mut self);
    /// Test-only accessor for the algorithm's remembered maximum window.
    #[cfg(test)]
    fn last_max_cwnd(&self) -> f64;
    /// Test-only setter for the algorithm's remembered maximum window.
    #[cfg(test)]
    fn set_last_max_cwnd(&mut self, last_max_cwnd: f64);
}
101 
/// Window-based congestion controller (slow start, congestion avoidance,
/// recovery, persistent congestion). Window growth and reduction are delegated
/// to the `WindowAdjustment` implementation in `cc_algorithm`.
#[derive(Debug)]
pub struct ClassicCongestionControl<T> {
    /// Algorithm-specific window adjustment.
    cc_algorithm: T,
    /// Current congestion control phase.
    state: State,
    /// Congestion window in bytes; starts at `CWND_INITIAL` (kInitialWindow).
    congestion_window: usize,
    /// Bytes counted as in flight: added on send, removed on ack, loss or discard.
    bytes_in_flight: usize,
    /// Acknowledged bytes not yet turned into a congestion window increase.
    acked_bytes: usize,
    /// Slow start threshold in bytes; `usize::MAX` until the first congestion event.
    ssthresh: usize,
    /// Send time of the first packet sent after the latest recovery or
    /// persistent congestion began; `None` before that first happens.
    recovery_start: Option<Instant>,

    /// Destination for qlog events and metrics.
    qlog: NeqoQlog,
}
114 
115 impl<T: WindowAdjustment> Display for ClassicCongestionControl<T> {
fmt(&self, f: &mut fmt::Formatter) -> fmt::Result116     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
117         write!(
118             f,
119             "{} CongCtrl {}/{} ssthresh {}",
120             self.cc_algorithm, self.bytes_in_flight, self.congestion_window, self.ssthresh,
121         )?;
122         Ok(())
123     }
124 }
125 
126 impl<T: WindowAdjustment> CongestionControl for ClassicCongestionControl<T> {
set_qlog(&mut self, qlog: NeqoQlog)127     fn set_qlog(&mut self, qlog: NeqoQlog) {
128         self.qlog = qlog;
129     }
130 
131     #[must_use]
cwnd(&self) -> usize132     fn cwnd(&self) -> usize {
133         self.congestion_window
134     }
135 
136     #[must_use]
bytes_in_flight(&self) -> usize137     fn bytes_in_flight(&self) -> usize {
138         self.bytes_in_flight
139     }
140 
141     #[must_use]
cwnd_avail(&self) -> usize142     fn cwnd_avail(&self) -> usize {
143         // BIF can be higher than cwnd due to PTO packets, which are sent even
144         // if avail is 0, but still count towards BIF.
145         self.congestion_window.saturating_sub(self.bytes_in_flight)
146     }
147 
148     // Multi-packet version of OnPacketAckedCC
on_packets_acked(&mut self, acked_pkts: &[SentPacket], min_rtt: Duration, now: Instant)149     fn on_packets_acked(&mut self, acked_pkts: &[SentPacket], min_rtt: Duration, now: Instant) {
150         // Check whether we are app limited before acked packets are removed
151         // from bytes_in_flight.
152         let is_app_limited = self.app_limited();
153         qtrace!(
154             [self],
155             "app limited={}, bytes_in_flight:{}, cwnd: {}, state: {:?} pacing_burst_size: {}",
156             is_app_limited,
157             self.bytes_in_flight,
158             self.congestion_window,
159             self.state,
160             MAX_DATAGRAM_SIZE * PACING_BURST_SIZE,
161         );
162 
163         let mut new_acked = 0;
164         for pkt in acked_pkts.iter().filter(|pkt| pkt.cc_outstanding()) {
165             assert!(self.bytes_in_flight >= pkt.size);
166             self.bytes_in_flight -= pkt.size;
167 
168             if !self.after_recovery_start(pkt) {
169                 // Do not increase congestion window for packets sent before
170                 // recovery last started.
171                 continue;
172             }
173 
174             if self.state.in_recovery() {
175                 self.set_state(State::CongestionAvoidance);
176                 qlog::metrics_updated(&mut self.qlog, &[QlogMetric::InRecovery(false)]);
177             }
178 
179             new_acked += pkt.size;
180         }
181 
182         if is_app_limited {
183             self.cc_algorithm.on_app_limited();
184             return;
185         }
186 
187         qtrace!([self], "ACK received, acked_bytes = {}", self.acked_bytes);
188 
189         // Slow start, up to the slow start threshold.
190         if self.congestion_window < self.ssthresh {
191             self.acked_bytes += new_acked;
192             let increase = min(self.ssthresh - self.congestion_window, self.acked_bytes);
193             self.congestion_window += increase;
194             self.acked_bytes -= increase;
195             qinfo!([self], "slow start += {}", increase);
196             if self.congestion_window == self.ssthresh {
197                 // This doesn't look like it is necessary, but it can happen
198                 // after persistent congestion.
199                 self.set_state(State::CongestionAvoidance);
200             }
201         }
202         // Congestion avoidance, above the slow start threshold.
203         if self.congestion_window >= self.ssthresh {
204             // The following function return the amount acked bytes a controller needs
205             // to collect to be allowed to increase its cwnd by MAX_DATAGRAM_SIZE.
206             let bytes_for_increase = self.cc_algorithm.bytes_for_cwnd_increase(
207                 self.congestion_window,
208                 new_acked,
209                 min_rtt,
210                 now,
211             );
212             debug_assert!(bytes_for_increase > 0);
213             // If enough credit has been accumulated already, apply them gradually.
214             // If we have sudden increase in allowed rate we actually increase cwnd gently.
215             if self.acked_bytes >= bytes_for_increase {
216                 self.acked_bytes = 0;
217                 self.congestion_window += MAX_DATAGRAM_SIZE;
218             }
219             self.acked_bytes += new_acked;
220             if self.acked_bytes >= bytes_for_increase {
221                 self.acked_bytes -= bytes_for_increase;
222                 self.congestion_window += MAX_DATAGRAM_SIZE; // or is this the current MTU?
223             }
224             // The number of bytes we require can go down over time with Cubic.
225             // That might result in an excessive rate of increase, so limit the number of unused
226             // acknowledged bytes after increasing the congestion window twice.
227             self.acked_bytes = min(bytes_for_increase, self.acked_bytes);
228         }
229         qlog::metrics_updated(
230             &mut self.qlog,
231             &[
232                 QlogMetric::CongestionWindow(self.congestion_window),
233                 QlogMetric::BytesInFlight(self.bytes_in_flight),
234             ],
235         );
236     }
237 
238     /// Update congestion controller state based on lost packets.
on_packets_lost( &mut self, first_rtt_sample_time: Option<Instant>, prev_largest_acked_sent: Option<Instant>, pto: Duration, lost_packets: &[SentPacket], ) -> bool239     fn on_packets_lost(
240         &mut self,
241         first_rtt_sample_time: Option<Instant>,
242         prev_largest_acked_sent: Option<Instant>,
243         pto: Duration,
244         lost_packets: &[SentPacket],
245     ) -> bool {
246         if lost_packets.is_empty() {
247             return false;
248         }
249 
250         for pkt in lost_packets.iter().filter(|pkt| pkt.cc_in_flight()) {
251             assert!(self.bytes_in_flight >= pkt.size);
252             self.bytes_in_flight -= pkt.size;
253         }
254         qlog::metrics_updated(
255             &mut self.qlog,
256             &[QlogMetric::BytesInFlight(self.bytes_in_flight)],
257         );
258 
259         qdebug!([self], "Pkts lost {}", lost_packets.len());
260 
261         let congestion = self.on_congestion_event(lost_packets.last().unwrap());
262         let persistent_congestion = self.detect_persistent_congestion(
263             first_rtt_sample_time,
264             prev_largest_acked_sent,
265             pto,
266             lost_packets,
267         );
268         congestion || persistent_congestion
269     }
270 
discard(&mut self, pkt: &SentPacket)271     fn discard(&mut self, pkt: &SentPacket) {
272         if pkt.cc_outstanding() {
273             assert!(self.bytes_in_flight >= pkt.size);
274             self.bytes_in_flight -= pkt.size;
275             qlog::metrics_updated(
276                 &mut self.qlog,
277                 &[QlogMetric::BytesInFlight(self.bytes_in_flight)],
278             );
279             qtrace!([self], "Ignore pkt with size {}", pkt.size);
280         }
281     }
282 
discard_in_flight(&mut self)283     fn discard_in_flight(&mut self) {
284         self.bytes_in_flight = 0;
285         qlog::metrics_updated(
286             &mut self.qlog,
287             &[QlogMetric::BytesInFlight(self.bytes_in_flight)],
288         );
289     }
290 
on_packet_sent(&mut self, pkt: &SentPacket)291     fn on_packet_sent(&mut self, pkt: &SentPacket) {
292         // Record the recovery time and exit any transient state.
293         if self.state.transient() {
294             self.recovery_start = Some(pkt.time_sent);
295             self.state.update();
296         }
297 
298         if !pkt.cc_in_flight() {
299             return;
300         }
301 
302         self.bytes_in_flight += pkt.size;
303         qdebug!(
304             [self],
305             "Pkt Sent len {}, bif {}, cwnd {}",
306             pkt.size,
307             self.bytes_in_flight,
308             self.congestion_window
309         );
310         qlog::metrics_updated(
311             &mut self.qlog,
312             &[QlogMetric::BytesInFlight(self.bytes_in_flight)],
313         );
314     }
315 
316     /// Whether a packet can be sent immediately as a result of entering recovery.
recovery_packet(&self) -> bool317     fn recovery_packet(&self) -> bool {
318         self.state == State::RecoveryStart
319     }
320 }
321 
322 impl<T: WindowAdjustment> ClassicCongestionControl<T> {
new(cc_algorithm: T) -> Self323     pub fn new(cc_algorithm: T) -> Self {
324         Self {
325             cc_algorithm,
326             state: State::SlowStart,
327             congestion_window: CWND_INITIAL,
328             bytes_in_flight: 0,
329             acked_bytes: 0,
330             ssthresh: usize::MAX,
331             recovery_start: None,
332             qlog: NeqoQlog::disabled(),
333         }
334     }
335 
336     #[cfg(test)]
337     #[must_use]
ssthresh(&self) -> usize338     pub fn ssthresh(&self) -> usize {
339         self.ssthresh
340     }
341 
342     #[cfg(test)]
set_ssthresh(&mut self, v: usize)343     pub fn set_ssthresh(&mut self, v: usize) {
344         self.ssthresh = v;
345     }
346 
347     #[cfg(test)]
last_max_cwnd(&self) -> f64348     pub fn last_max_cwnd(&self) -> f64 {
349         self.cc_algorithm.last_max_cwnd()
350     }
351 
352     #[cfg(test)]
set_last_max_cwnd(&mut self, last_max_cwnd: f64)353     pub fn set_last_max_cwnd(&mut self, last_max_cwnd: f64) {
354         self.cc_algorithm.set_last_max_cwnd(last_max_cwnd);
355     }
356 
357     #[cfg(test)]
acked_bytes(&self) -> usize358     pub fn acked_bytes(&self) -> usize {
359         self.acked_bytes
360     }
361 
set_state(&mut self, state: State)362     fn set_state(&mut self, state: State) {
363         if self.state != state {
364             qdebug!([self], "state -> {:?}", state);
365             let old_state = self.state;
366             self.qlog.add_event(|| {
367                 // No need to tell qlog about exit from transient states.
368                 if old_state.transient() {
369                     None
370                 } else {
371                     Some(::qlog::event::Event::congestion_state_updated(
372                         Some(old_state.to_qlog().to_owned()),
373                         state.to_qlog().to_owned(),
374                     ))
375                 }
376             });
377             self.state = state;
378         }
379     }
380 
detect_persistent_congestion( &mut self, first_rtt_sample_time: Option<Instant>, prev_largest_acked_sent: Option<Instant>, pto: Duration, lost_packets: &[SentPacket], ) -> bool381     fn detect_persistent_congestion(
382         &mut self,
383         first_rtt_sample_time: Option<Instant>,
384         prev_largest_acked_sent: Option<Instant>,
385         pto: Duration,
386         lost_packets: &[SentPacket],
387     ) -> bool {
388         if first_rtt_sample_time.is_none() {
389             return false;
390         }
391 
392         let pc_period = pto * PERSISTENT_CONG_THRESH;
393 
394         let mut last_pn = 1 << 62; // Impossibly large, but not enough to overflow.
395         let mut start = None;
396 
397         // Look for the first lost packet after the previous largest acknowledged.
398         // Ignore packets that weren't ack-eliciting for the start of this range.
399         // Also, make sure to ignore any packets sent before we got an RTT estimate
400         // as we might not have sent PTO packets soon enough after those.
401         let cutoff = max(first_rtt_sample_time, prev_largest_acked_sent);
402         for p in lost_packets
403             .iter()
404             .skip_while(|p| Some(p.time_sent) < cutoff)
405         {
406             if p.pn != last_pn + 1 {
407                 // Not a contiguous range of lost packets, start over.
408                 start = None;
409             }
410             last_pn = p.pn;
411             if !p.cc_in_flight() {
412                 // Not interesting, keep looking.
413                 continue;
414             }
415             if let Some(t) = start {
416                 if p.time_sent.duration_since(t) > pc_period {
417                     qinfo!([self], "persistent congestion");
418                     self.congestion_window = CWND_MIN;
419                     self.acked_bytes = 0;
420                     self.set_state(State::PersistentCongestion);
421                     qlog::metrics_updated(
422                         &mut self.qlog,
423                         &[QlogMetric::CongestionWindow(self.congestion_window)],
424                     );
425                     return true;
426                 }
427             } else {
428                 start = Some(p.time_sent);
429             }
430         }
431         false
432     }
433 
434     #[must_use]
after_recovery_start(&mut self, packet: &SentPacket) -> bool435     fn after_recovery_start(&mut self, packet: &SentPacket) -> bool {
436         // At the start of the first recovery period, if the state is
437         // transient, all packets will have been sent before recovery.
438         self.recovery_start
439             .map_or(!self.state.transient(), |t| packet.time_sent >= t)
440     }
441 
442     /// Handle a congestion event.
443     /// Returns true if this was a true congestion event.
on_congestion_event(&mut self, last_packet: &SentPacket) -> bool444     fn on_congestion_event(&mut self, last_packet: &SentPacket) -> bool {
445         // Start a new congestion event if lost packet was sent after the start
446         // of the previous congestion recovery period.
447         if !self.after_recovery_start(last_packet) {
448             return false;
449         }
450 
451         let (cwnd, acked_bytes) = self
452             .cc_algorithm
453             .reduce_cwnd(self.congestion_window, self.acked_bytes);
454         self.congestion_window = max(cwnd, CWND_MIN);
455         self.acked_bytes = acked_bytes;
456         self.ssthresh = self.congestion_window;
457         qinfo!(
458             [self],
459             "Cong event -> recovery; cwnd {}, ssthresh {}",
460             self.congestion_window,
461             self.ssthresh
462         );
463         qlog::metrics_updated(
464             &mut self.qlog,
465             &[
466                 QlogMetric::CongestionWindow(self.congestion_window),
467                 QlogMetric::SsThresh(self.ssthresh),
468                 QlogMetric::InRecovery(true),
469             ],
470         );
471         self.set_state(State::RecoveryStart);
472         true
473     }
474 
475     #[allow(clippy::unused_self)]
app_limited(&self) -> bool476     fn app_limited(&self) -> bool {
477         if self.bytes_in_flight >= self.congestion_window {
478             false
479         } else if self.state.in_slow_start() {
480             // Allow for potential doubling of the congestion window during slow start.
481             // That is, the application might not have been able to send enough to respond
482             // to increases to the congestion window.
483             self.bytes_in_flight < self.congestion_window / 2
484         } else {
485             // We're not limited if the in-flight data is within a single burst of the
486             // congestion window.
487             (self.bytes_in_flight + MAX_DATAGRAM_SIZE * PACING_BURST_SIZE) < self.congestion_window
488         }
489     }
490 }
491 
492 #[cfg(test)]
493 mod tests {
494     use super::{
495         ClassicCongestionControl, WindowAdjustment, CWND_INITIAL, CWND_MIN, PERSISTENT_CONG_THRESH,
496     };
497     use crate::cc::cubic::{Cubic, CUBIC_BETA_USIZE_DIVISOR, CUBIC_BETA_USIZE_QUOTIENT};
498     use crate::cc::new_reno::NewReno;
499     use crate::cc::{
500         CongestionControl, CongestionControlAlgorithm, CWND_INITIAL_PKTS, MAX_DATAGRAM_SIZE,
501     };
502     use crate::packet::{PacketNumber, PacketType};
503     use crate::tracking::SentPacket;
504     use std::convert::TryFrom;
505     use std::time::{Duration, Instant};
506     use test_fixture::now;
507 
508     const PTO: Duration = Duration::from_millis(100);
509     const RTT: Duration = Duration::from_millis(98);
510     const ZERO: Duration = Duration::from_secs(0);
511     const EPSILON: Duration = Duration::from_nanos(1);
512     const GAP: Duration = Duration::from_secs(1);
513     /// The largest time between packets without causing persistent congestion.
514     const SUB_PC: Duration = Duration::from_millis(100 * PERSISTENT_CONG_THRESH as u64);
515     /// The minimum time between packets to cause persistent congestion.
516     /// Uses an odd expression because `Duration` arithmetic isn't `const`.
517     const PC: Duration = Duration::from_nanos(100_000_000 * (PERSISTENT_CONG_THRESH as u64) + 1);
518 
cwnd_is_default(cc: &ClassicCongestionControl<NewReno>)519     fn cwnd_is_default(cc: &ClassicCongestionControl<NewReno>) {
520         assert_eq!(cc.cwnd(), CWND_INITIAL);
521         assert_eq!(cc.ssthresh(), usize::MAX);
522     }
523 
cwnd_is_halved(cc: &ClassicCongestionControl<NewReno>)524     fn cwnd_is_halved(cc: &ClassicCongestionControl<NewReno>) {
525         assert_eq!(cc.cwnd(), CWND_INITIAL / 2);
526         assert_eq!(cc.ssthresh(), CWND_INITIAL / 2);
527     }
528 
lost(pn: PacketNumber, ack_eliciting: bool, t: Duration) -> SentPacket529     fn lost(pn: PacketNumber, ack_eliciting: bool, t: Duration) -> SentPacket {
530         SentPacket::new(
531             PacketType::Short,
532             pn,
533             now() + t,
534             ack_eliciting,
535             Vec::new(),
536             100,
537         )
538     }
539 
congestion_control(cc: CongestionControlAlgorithm) -> Box<dyn CongestionControl>540     fn congestion_control(cc: CongestionControlAlgorithm) -> Box<dyn CongestionControl> {
541         match cc {
542             CongestionControlAlgorithm::NewReno => {
543                 Box::new(ClassicCongestionControl::new(NewReno::default()))
544             }
545             CongestionControlAlgorithm::Cubic => {
546                 Box::new(ClassicCongestionControl::new(Cubic::default()))
547             }
548         }
549     }
550 
persistent_congestion_by_algorithm( cc_alg: CongestionControlAlgorithm, reduced_cwnd: usize, lost_packets: &[SentPacket], persistent_expected: bool, )551     fn persistent_congestion_by_algorithm(
552         cc_alg: CongestionControlAlgorithm,
553         reduced_cwnd: usize,
554         lost_packets: &[SentPacket],
555         persistent_expected: bool,
556     ) {
557         let mut cc = congestion_control(cc_alg);
558         for p in lost_packets {
559             cc.on_packet_sent(p);
560         }
561 
562         cc.on_packets_lost(Some(now()), None, PTO, lost_packets);
563 
564         let persistent = if cc.cwnd() == reduced_cwnd {
565             false
566         } else if cc.cwnd() == CWND_MIN {
567             true
568         } else {
569             panic!("unexpected cwnd");
570         };
571         assert_eq!(persistent, persistent_expected);
572     }
573 
persistent_congestion(lost_packets: &[SentPacket], persistent_expected: bool)574     fn persistent_congestion(lost_packets: &[SentPacket], persistent_expected: bool) {
575         persistent_congestion_by_algorithm(
576             CongestionControlAlgorithm::NewReno,
577             CWND_INITIAL / 2,
578             lost_packets,
579             persistent_expected,
580         );
581         persistent_congestion_by_algorithm(
582             CongestionControlAlgorithm::Cubic,
583             CWND_INITIAL * CUBIC_BETA_USIZE_QUOTIENT / CUBIC_BETA_USIZE_DIVISOR,
584             lost_packets,
585             persistent_expected,
586         );
587     }
588 
589     /// A span of exactly the PC threshold only reduces the window on loss.
590     #[test]
persistent_congestion_none()591     fn persistent_congestion_none() {
592         persistent_congestion(&[lost(1, true, ZERO), lost(2, true, SUB_PC)], false);
593     }
594 
595     /// A span of just more than the PC threshold causes persistent congestion.
596     #[test]
persistent_congestion_simple()597     fn persistent_congestion_simple() {
598         persistent_congestion(&[lost(1, true, ZERO), lost(2, true, PC)], true);
599     }
600 
601     /// Both packets need to be ack-eliciting.
602     #[test]
persistent_congestion_non_ack_eliciting()603     fn persistent_congestion_non_ack_eliciting() {
604         persistent_congestion(&[lost(1, false, ZERO), lost(2, true, PC)], false);
605         persistent_congestion(&[lost(1, true, ZERO), lost(2, false, PC)], false);
606     }
607 
608     /// Packets in the middle, of any type, are OK.
609     #[test]
persistent_congestion_middle()610     fn persistent_congestion_middle() {
611         persistent_congestion(
612             &[lost(1, true, ZERO), lost(2, false, RTT), lost(3, true, PC)],
613             true,
614         );
615         persistent_congestion(
616             &[lost(1, true, ZERO), lost(2, true, RTT), lost(3, true, PC)],
617             true,
618         );
619     }
620 
621     /// Leading non-ack-eliciting packets are skipped.
622     #[test]
persistent_congestion_leading_non_ack_eliciting()623     fn persistent_congestion_leading_non_ack_eliciting() {
624         persistent_congestion(
625             &[lost(1, false, ZERO), lost(2, true, RTT), lost(3, true, PC)],
626             false,
627         );
628         persistent_congestion(
629             &[
630                 lost(1, false, ZERO),
631                 lost(2, true, RTT),
632                 lost(3, true, RTT + PC),
633             ],
634             true,
635         );
636     }
637 
638     /// Trailing non-ack-eliciting packets aren't relevant.
639     #[test]
persistent_congestion_trailing_non_ack_eliciting()640     fn persistent_congestion_trailing_non_ack_eliciting() {
641         persistent_congestion(
642             &[
643                 lost(1, true, ZERO),
644                 lost(2, true, PC),
645                 lost(3, false, PC + EPSILON),
646             ],
647             true,
648         );
649         persistent_congestion(
650             &[
651                 lost(1, true, ZERO),
652                 lost(2, true, SUB_PC),
653                 lost(3, false, PC),
654             ],
655             false,
656         );
657     }
658 
659     /// Gaps in the middle, of any type, restart the count.
660     #[test]
persistent_congestion_gap_reset()661     fn persistent_congestion_gap_reset() {
662         persistent_congestion(&[lost(1, true, ZERO), lost(3, true, PC)], false);
663         persistent_congestion(
664             &[
665                 lost(1, true, ZERO),
666                 lost(2, true, RTT),
667                 lost(4, true, GAP),
668                 lost(5, true, GAP + PTO * PERSISTENT_CONG_THRESH),
669             ],
670             false,
671         );
672     }
673 
674     /// A span either side of a gap will cause persistent congestion.
675     #[test]
persistent_congestion_gap_or()676     fn persistent_congestion_gap_or() {
677         persistent_congestion(
678             &[
679                 lost(1, true, ZERO),
680                 lost(2, true, PC),
681                 lost(4, true, GAP),
682                 lost(5, true, GAP + PTO),
683             ],
684             true,
685         );
686         persistent_congestion(
687             &[
688                 lost(1, true, ZERO),
689                 lost(2, true, PTO),
690                 lost(4, true, GAP),
691                 lost(5, true, GAP + PC),
692             ],
693             true,
694         );
695     }
696 
697     /// A gap only restarts after an ack-eliciting packet.
698     #[test]
persistent_congestion_gap_non_ack_eliciting()699     fn persistent_congestion_gap_non_ack_eliciting() {
700         persistent_congestion(
701             &[
702                 lost(1, true, ZERO),
703                 lost(2, true, PTO),
704                 lost(4, false, GAP),
705                 lost(5, true, GAP + PC),
706             ],
707             false,
708         );
709         persistent_congestion(
710             &[
711                 lost(1, true, ZERO),
712                 lost(2, true, PTO),
713                 lost(4, false, GAP),
714                 lost(5, true, GAP + RTT),
715                 lost(6, true, GAP + RTT + SUB_PC),
716             ],
717             false,
718         );
719         persistent_congestion(
720             &[
721                 lost(1, true, ZERO),
722                 lost(2, true, PTO),
723                 lost(4, false, GAP),
724                 lost(5, true, GAP + RTT),
725                 lost(6, true, GAP + RTT + PC),
726             ],
727             true,
728         );
729     }
730 
731     /// Get a time, in multiples of `PTO`, relative to `now()`.
by_pto(t: u32) -> Instant732     fn by_pto(t: u32) -> Instant {
733         now() + (PTO * t)
734     }
735 
736     /// Make packets that will be made lost.
737     /// `times` is the time of sending, in multiples of `PTO`, relative to `now()`.
make_lost(times: &[u32]) -> Vec<SentPacket>738     fn make_lost(times: &[u32]) -> Vec<SentPacket> {
739         times
740             .iter()
741             .enumerate()
742             .map(|(i, &t)| {
743                 SentPacket::new(
744                     PacketType::Short,
745                     u64::try_from(i).unwrap(),
746                     by_pto(t),
747                     true,
748                     Vec::new(),
749                     1000,
750                 )
751             })
752             .collect::<Vec<_>>()
753     }
754 
755     /// Call `detect_persistent_congestion` using times relative to now and the fixed PTO time.
756     /// `last_ack` and `rtt_time` are times in multiples of `PTO`, relative to `now()`,
757     /// for the time of the largest acknowledged and the first RTT sample, respectively.
persistent_congestion_by_pto<T: WindowAdjustment>( mut cc: ClassicCongestionControl<T>, last_ack: u32, rtt_time: u32, lost: &[SentPacket], ) -> bool758     fn persistent_congestion_by_pto<T: WindowAdjustment>(
759         mut cc: ClassicCongestionControl<T>,
760         last_ack: u32,
761         rtt_time: u32,
762         lost: &[SentPacket],
763     ) -> bool {
764         assert_eq!(cc.cwnd(), CWND_INITIAL);
765 
766         let last_ack = Some(by_pto(last_ack));
767         let rtt_time = Some(by_pto(rtt_time));
768 
769         // Persistent congestion is never declared if the RTT time is `None`.
770         cc.detect_persistent_congestion(None, None, PTO, lost);
771         assert_eq!(cc.cwnd(), CWND_INITIAL);
772         cc.detect_persistent_congestion(None, last_ack, PTO, lost);
773         assert_eq!(cc.cwnd(), CWND_INITIAL);
774 
775         cc.detect_persistent_congestion(rtt_time, last_ack, PTO, lost);
776         cc.cwnd() == CWND_MIN
777     }
778 
779     /// No persistent congestion can be had if there are no lost packets.
780     #[test]
persistent_congestion_no_lost()781     fn persistent_congestion_no_lost() {
782         let lost = make_lost(&[]);
783         assert!(!persistent_congestion_by_pto(
784             ClassicCongestionControl::new(NewReno::default()),
785             0,
786             0,
787             &lost
788         ));
789         assert!(!persistent_congestion_by_pto(
790             ClassicCongestionControl::new(Cubic::default()),
791             0,
792             0,
793             &lost
794         ));
795     }
796 
797     /// No persistent congestion can be had if there is only one lost packet.
798     #[test]
persistent_congestion_one_lost()799     fn persistent_congestion_one_lost() {
800         let lost = make_lost(&[1]);
801         assert!(!persistent_congestion_by_pto(
802             ClassicCongestionControl::new(NewReno::default()),
803             0,
804             0,
805             &lost
806         ));
807         assert!(!persistent_congestion_by_pto(
808             ClassicCongestionControl::new(Cubic::default()),
809             0,
810             0,
811             &lost
812         ));
813     }
814 
815     /// Persistent congestion can't happen based on old packets.
816     #[test]
persistent_congestion_past()817     fn persistent_congestion_past() {
818         // Packets sent prior to either the last acknowledged or the first RTT
819         // sample are not considered.  So 0 is ignored.
820         let lost = make_lost(&[0, PERSISTENT_CONG_THRESH + 1, PERSISTENT_CONG_THRESH + 2]);
821         assert!(!persistent_congestion_by_pto(
822             ClassicCongestionControl::new(NewReno::default()),
823             1,
824             1,
825             &lost
826         ));
827         assert!(!persistent_congestion_by_pto(
828             ClassicCongestionControl::new(NewReno::default()),
829             0,
830             1,
831             &lost
832         ));
833         assert!(!persistent_congestion_by_pto(
834             ClassicCongestionControl::new(NewReno::default()),
835             1,
836             0,
837             &lost
838         ));
839         assert!(!persistent_congestion_by_pto(
840             ClassicCongestionControl::new(Cubic::default()),
841             1,
842             1,
843             &lost
844         ));
845         assert!(!persistent_congestion_by_pto(
846             ClassicCongestionControl::new(Cubic::default()),
847             0,
848             1,
849             &lost
850         ));
851         assert!(!persistent_congestion_by_pto(
852             ClassicCongestionControl::new(Cubic::default()),
853             1,
854             0,
855             &lost
856         ));
857     }
858 
859     /// Persistent congestion doesn't start unless the packet is ack-eliciting.
860     #[test]
persistent_congestion_ack_eliciting()861     fn persistent_congestion_ack_eliciting() {
862         let mut lost = make_lost(&[1, PERSISTENT_CONG_THRESH + 2]);
863         lost[0] = SentPacket::new(
864             lost[0].pt,
865             lost[0].pn,
866             lost[0].time_sent,
867             false,
868             Vec::new(),
869             lost[0].size,
870         );
871         assert!(!persistent_congestion_by_pto(
872             ClassicCongestionControl::new(NewReno::default()),
873             0,
874             0,
875             &lost
876         ));
877         assert!(!persistent_congestion_by_pto(
878             ClassicCongestionControl::new(Cubic::default()),
879             0,
880             0,
881             &lost
882         ));
883     }
884 
885     /// Detect persistent congestion.  Note that the first lost packet needs to have a time
886     /// greater than the previously acknowledged packet AND the first RTT sample.  And the
887     /// difference in times needs to be greater than the persistent congestion threshold.
888     #[test]
persistent_congestion_min()889     fn persistent_congestion_min() {
890         let lost = make_lost(&[1, PERSISTENT_CONG_THRESH + 2]);
891         assert!(persistent_congestion_by_pto(
892             ClassicCongestionControl::new(NewReno::default()),
893             0,
894             0,
895             &lost
896         ));
897         assert!(persistent_congestion_by_pto(
898             ClassicCongestionControl::new(Cubic::default()),
899             0,
900             0,
901             &lost
902         ));
903     }
904 
905     /// Make sure that not having a previous largest acknowledged also results
906     /// in detecting persistent congestion.  (This is not expected to happen, but
907     /// the code permits it).
908     #[test]
persistent_congestion_no_prev_ack_newreno()909     fn persistent_congestion_no_prev_ack_newreno() {
910         let lost = make_lost(&[1, PERSISTENT_CONG_THRESH + 2]);
911         let mut cc = ClassicCongestionControl::new(NewReno::default());
912         cc.detect_persistent_congestion(Some(by_pto(0)), None, PTO, &lost);
913         assert_eq!(cc.cwnd(), CWND_MIN);
914     }
915 
916     #[test]
persistent_congestion_no_prev_ack_cubic()917     fn persistent_congestion_no_prev_ack_cubic() {
918         let lost = make_lost(&[1, PERSISTENT_CONG_THRESH + 2]);
919         let mut cc = ClassicCongestionControl::new(Cubic::default());
920         cc.detect_persistent_congestion(Some(by_pto(0)), None, PTO, &lost);
921         assert_eq!(cc.cwnd(), CWND_MIN);
922     }
923 
924     /// The code asserts on ordering errors.
925     #[test]
926     #[should_panic]
persistent_congestion_unsorted_newreno()927     fn persistent_congestion_unsorted_newreno() {
928         let lost = make_lost(&[PERSISTENT_CONG_THRESH + 2, 1]);
929         assert!(!persistent_congestion_by_pto(
930             ClassicCongestionControl::new(NewReno::default()),
931             0,
932             0,
933             &lost
934         ));
935     }
936 
937     /// The code asserts on ordering errors.
938     #[test]
939     #[should_panic]
persistent_congestion_unsorted_cubic()940     fn persistent_congestion_unsorted_cubic() {
941         let lost = make_lost(&[PERSISTENT_CONG_THRESH + 2, 1]);
942         assert!(!persistent_congestion_by_pto(
943             ClassicCongestionControl::new(Cubic::default()),
944             0,
945             0,
946             &lost
947         ));
948     }
949 
950     #[test]
app_limited_slow_start()951     fn app_limited_slow_start() {
952         const LESS_THAN_CWND_PKTS: usize = 4;
953         let mut cc = ClassicCongestionControl::new(NewReno::default());
954 
955         for i in 0..CWND_INITIAL_PKTS {
956             let sent = SentPacket::new(
957                 PacketType::Short,
958                 u64::try_from(i).unwrap(), // pn
959                 now(),                     // time sent
960                 true,                      // ack eliciting
961                 Vec::new(),                // tokens
962                 MAX_DATAGRAM_SIZE,         // size
963             );
964             cc.on_packet_sent(&sent);
965         }
966         assert_eq!(cc.bytes_in_flight(), CWND_INITIAL);
967 
968         for i in 0..LESS_THAN_CWND_PKTS {
969             let acked = SentPacket::new(
970                 PacketType::Short,
971                 u64::try_from(i).unwrap(), // pn
972                 now(),                     // time sent
973                 true,                      // ack eliciting
974                 Vec::new(),                // tokens
975                 MAX_DATAGRAM_SIZE,         // size
976             );
977             cc.on_packets_acked(&[acked], RTT, now());
978 
979             assert_eq!(
980                 cc.bytes_in_flight(),
981                 (CWND_INITIAL_PKTS - i - 1) * MAX_DATAGRAM_SIZE
982             );
983             assert_eq!(cc.cwnd(), (CWND_INITIAL_PKTS + i + 1) * MAX_DATAGRAM_SIZE);
984         }
985 
986         // Now we are app limited
987         for i in 4..CWND_INITIAL_PKTS {
988             let p = [SentPacket::new(
989                 PacketType::Short,
990                 u64::try_from(i).unwrap(), // pn
991                 now(),                     // time sent
992                 true,                      // ack eliciting
993                 Vec::new(),                // tokens
994                 MAX_DATAGRAM_SIZE,         // size
995             )];
996             cc.on_packets_acked(&p, RTT, now());
997 
998             assert_eq!(
999                 cc.bytes_in_flight(),
1000                 (CWND_INITIAL_PKTS - i - 1) * MAX_DATAGRAM_SIZE
1001             );
1002             assert_eq!(cc.cwnd(), (CWND_INITIAL_PKTS + 4) * MAX_DATAGRAM_SIZE);
1003         }
1004     }
1005 
1006     #[test]
app_limited_congestion_avoidance()1007     fn app_limited_congestion_avoidance() {
1008         const CWND_PKTS_CA: usize = CWND_INITIAL_PKTS / 2;
1009 
1010         let mut cc = ClassicCongestionControl::new(NewReno::default());
1011 
1012         // Change state to congestion avoidance by introducing loss.
1013 
1014         let p_lost = SentPacket::new(
1015             PacketType::Short,
1016             1,                 // pn
1017             now(),             // time sent
1018             true,              // ack eliciting
1019             Vec::new(),        // tokens
1020             MAX_DATAGRAM_SIZE, // size
1021         );
1022         cc.on_packet_sent(&p_lost);
1023         cwnd_is_default(&cc);
1024         cc.on_packets_lost(Some(now()), None, PTO, &[p_lost]);
1025         cwnd_is_halved(&cc);
1026         let p_not_lost = SentPacket::new(
1027             PacketType::Short,
1028             1,                 // pn
1029             now(),             // time sent
1030             true,              // ack eliciting
1031             Vec::new(),        // tokens
1032             MAX_DATAGRAM_SIZE, // size
1033         );
1034         cc.on_packet_sent(&p_not_lost);
1035         cc.on_packets_acked(&[p_not_lost], RTT, now());
1036         cwnd_is_halved(&cc);
1037         // cc is app limited therefore cwnd in not increased.
1038         assert_eq!(cc.acked_bytes, 0);
1039 
1040         // Now we are in the congestion avoidance state.
1041         let mut pkts = Vec::new();
1042         for i in 0..CWND_PKTS_CA {
1043             let p = SentPacket::new(
1044                 PacketType::Short,
1045                 u64::try_from(i + 3).unwrap(), // pn
1046                 now(),                         // time sent
1047                 true,                          // ack eliciting
1048                 Vec::new(),                    // tokens
1049                 MAX_DATAGRAM_SIZE,             // size
1050             );
1051             cc.on_packet_sent(&p);
1052             pkts.push(p);
1053         }
1054         assert_eq!(cc.bytes_in_flight(), CWND_INITIAL / 2);
1055 
1056         for i in 0..CWND_PKTS_CA - 2 {
1057             cc.on_packets_acked(&pkts[i..=i], RTT, now());
1058 
1059             assert_eq!(
1060                 cc.bytes_in_flight(),
1061                 (CWND_PKTS_CA - i - 1) * MAX_DATAGRAM_SIZE
1062             );
1063             assert_eq!(cc.cwnd(), CWND_PKTS_CA * MAX_DATAGRAM_SIZE);
1064             assert_eq!(cc.acked_bytes, MAX_DATAGRAM_SIZE * (i + 1));
1065         }
1066 
1067         // Now we are app limited
1068         for i in CWND_PKTS_CA - 2..CWND_PKTS_CA {
1069             cc.on_packets_acked(&pkts[i..=i], RTT, now());
1070 
1071             assert_eq!(
1072                 cc.bytes_in_flight(),
1073                 (CWND_PKTS_CA - i - 1) * MAX_DATAGRAM_SIZE
1074             );
1075             assert_eq!(cc.cwnd(), CWND_PKTS_CA * MAX_DATAGRAM_SIZE);
1076             assert_eq!(cc.acked_bytes, MAX_DATAGRAM_SIZE * 3);
1077         }
1078     }
1079 }
1080