from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple

from ..buffer import Buffer, size_uint_var
from ..tls import Epoch
from .crypto import CryptoPair
from .logger import QuicLoggerTrace
from .packet import (
    NON_ACK_ELICITING_FRAME_TYPES,
    NON_IN_FLIGHT_FRAME_TYPES,
    PACKET_NUMBER_MAX_SIZE,
    PACKET_TYPE_HANDSHAKE,
    PACKET_TYPE_INITIAL,
    PACKET_TYPE_MASK,
    QuicFrameType,
    is_long_header,
)

# Size of the reusable datagram buffer.
PACKET_MAX_SIZE = 1280
# Number of bytes used to encode the long header "Length" field.
PACKET_LENGTH_SEND_SIZE = 2
# Number of bytes used to encode the packet number on the wire.
PACKET_NUMBER_SEND_SIZE = 2


# Callback recorded on a sent packet; it is stored together with its
# arguments in QuicSentPacket.delivery_handlers.
QuicDeliveryHandler = Callable[..., None]


class QuicDeliveryState(Enum):
    """
    Possible delivery outcomes for a sent packet.

    NOTE(review): presumably passed to QuicDeliveryHandler callbacks by the
    caller; confirm against the connection code.
    """

    ACKED = 0
    LOST = 1
    EXPIRED = 2


@dataclass
class QuicSentPacket:
    """
    Bookkeeping for one packet assembled by QuicPacketBuilder.
    """

    epoch: Epoch
    in_flight: bool
    is_ack_eliciting: bool
    is_crypto_packet: bool
    packet_number: int
    packet_type: int
    # not set by the builder; left for the caller to fill in
    sent_time: Optional[float] = None
    # size of the packet after encryption, set when the packet is ended
    sent_bytes: int = 0

    # (handler, handler_args) pairs registered via start_frame()
    delivery_handlers: List[Tuple[QuicDeliveryHandler, Any]] = field(
        default_factory=list
    )
    quic_logger_frames: List[Dict] = field(default_factory=list)


class QuicPacketBuilderStop(Exception):
    """
    Raised when there is not enough room for the requested packet header
    or frame.
    """

    pass


class QuicPacketBuilder:
    """
    Helper for building QUIC packets.

    Packets are written into a single reusable buffer. Long header packets
    may be coalesced into the same datagram, whereas a short header packet
    always closes the datagram it was written into. Call `flush()` to
    retrieve the assembled datagrams and the sent-packet metadata.
    """

    def __init__(
        self,
        *,
        host_cid: bytes,
        peer_cid: bytes,
        version: int,
        is_client: bool,
        packet_number: int = 0,
        peer_token: bytes = b"",
        quic_logger: Optional[QuicLoggerTrace] = None,
        spin_bit: bool = False,
    ):
        # Optional byte budgets the caller may set; None means unlimited.
        # max_flight_bytes bounds bytes of in-flight packets and
        # max_total_bytes bounds all datagram bytes produced by this builder.
        # NOTE(review): presumably congestion window / anti-amplification
        # limits; confirm against callers.
        self.max_flight_bytes: Optional[int] = None
        self.max_total_bytes: Optional[int] = None
        # Points at the current packet's log frame list while a packet is open.
        self.quic_logger_frames: Optional[List[Dict]] = None

        self._host_cid = host_cid
        self._is_client = is_client
        self._peer_cid = peer_cid
        self._peer_token = peer_token
        self._quic_logger = quic_logger
        self._spin_bit = spin_bit
        self._version = version

        # assembled datagrams and packets
        self._datagrams: List[bytes] = []
        self._datagram_flight_bytes = 0
        self._datagram_init = True
        self._packets: List[QuicSentPacket] = []
        self._flight_bytes = 0
        self._total_bytes = 0

        # current packet
        self._header_size = 0
        self._packet: Optional[QuicSentPacket] = None
        self._packet_crypto: Optional[CryptoPair] = None
        self._packet_long_header = False
        self._packet_number = packet_number
        self._packet_start = 0
        self._packet_type = 0

        self._buffer = Buffer(PACKET_MAX_SIZE)
        self._buffer_capacity = PACKET_MAX_SIZE
        self._flight_capacity = PACKET_MAX_SIZE

    @property
    def packet_is_empty(self) -> bool:
        """
        Returns `True` if the current packet is empty.

        A packet is empty when nothing has been written after its header.
        Requires a packet to have been started.
        """
        assert self._packet is not None
        packet_size = self._buffer.tell() - self._packet_start
        return packet_size <= self._header_size

    @property
    def packet_number(self) -> int:
        """
        Returns the packet number for the next packet.
        """
        return self._packet_number

    @property
    def remaining_buffer_space(self) -> int:
        """
        Returns the remaining number of bytes which can be used in
        the current packet.

        This is bounded by the datagram buffer capacity and reserves room for
        the AEAD tag of the current packet's crypto.
        """
        return (
            self._buffer_capacity
            - self._buffer.tell()
            - self._packet_crypto.aead_tag_size
        )

    @property
    def remaining_flight_space(self) -> int:
        """
        Returns the remaining number of bytes which in-flight frames can use
        in the current packet.

        This is bounded by the flight capacity, which is derived from
        `max_flight_bytes`, and reserves room for the AEAD tag. The result
        may be negative, because frames which are not in flight are not
        checked against this limit.
        """
        return (
            self._flight_capacity
            - self._buffer.tell()
            - self._packet_crypto.aead_tag_size
        )

    def flush(self) -> Tuple[List[bytes], List[QuicSentPacket]]:
        """
        Ends the current packet (if any) and closes the current datagram.

        Returns the assembled datagrams together with the metadata of the
        packets they contain, then resets both internal lists.
        """
        if self._packet is not None:
            self._end_packet()
        self._flush_current_datagram()

        datagrams = self._datagrams
        packets = self._packets
        self._datagrams = []
        self._packets = []
        return datagrams, packets

    def start_frame(
        self,
        frame_type: int,
        capacity: int = 1,
        handler: Optional[QuicDeliveryHandler] = None,
        handler_args: Sequence[Any] = (),
    ) -> Buffer:
        """
        Starts a new frame.

        Writes `frame_type` into the current packet and returns the buffer so
        the caller can write the frame body. The packet's ack-eliciting,
        in-flight and crypto flags are updated according to the frame type.

        :param frame_type: the QUIC frame type.
        :param capacity: number of bytes the caller needs. It is checked
            against the remaining space before the frame type is written.
        :param handler: optional callback stored on the packet, together
            with `handler_args`.
        :param handler_args: arguments stored alongside `handler`.
        :raises QuicPacketBuilderStop: if the packet lacks `capacity` bytes
            of buffer space, or of flight space for an in-flight frame.
        """
        if self.remaining_buffer_space < capacity or (
            frame_type not in NON_IN_FLIGHT_FRAME_TYPES
            and self.remaining_flight_space < capacity
        ):
            raise QuicPacketBuilderStop

        self._buffer.push_uint_var(frame_type)
        if frame_type not in NON_ACK_ELICITING_FRAME_TYPES:
            self._packet.is_ack_eliciting = True
        if frame_type not in NON_IN_FLIGHT_FRAME_TYPES:
            self._packet.in_flight = True
        if frame_type == QuicFrameType.CRYPTO:
            self._packet.is_crypto_packet = True
        if handler is not None:
            self._packet.delivery_handlers.append((handler, handler_args))
        return self._buffer

    def start_packet(self, packet_type: int, crypto: CryptoPair) -> None:
        """
        Starts a new packet of type `packet_type`, protected with `crypto`.

        Any packet in progress is ended first. If fewer than 128 bytes remain
        in the current datagram, a new datagram is started. Space for the
        header is reserved; the header itself is written when the packet ends.

        :raises QuicPacketBuilderStop: if the header does not fit within the
            buffer capacity.
        """
        buf = self._buffer

        # finish previous packet
        if self._packet is not None:
            self._end_packet()

        # if there is too little space remaining, start a new datagram
        # FIXME: the limit is arbitrary!
        packet_start = buf.tell()
        if self._buffer_capacity - packet_start < 128:
            self._flush_current_datagram()
            packet_start = 0

        # initialize datagram if needed
        if self._datagram_init:
            if self.max_total_bytes is not None:
                remaining_total_bytes = self.max_total_bytes - self._total_bytes
                if remaining_total_bytes < self._buffer_capacity:
                    self._buffer_capacity = remaining_total_bytes

            self._flight_capacity = self._buffer_capacity
            if self.max_flight_bytes is not None:
                remaining_flight_bytes = self.max_flight_bytes - self._flight_bytes
                if remaining_flight_bytes < self._flight_capacity:
                    self._flight_capacity = remaining_flight_bytes
            self._datagram_flight_bytes = 0
            self._datagram_init = False

        # calculate header size
        packet_long_header = is_long_header(packet_type)
        if packet_long_header:
            # first byte (1) + version (4) + DCID length (1) + SCID length (1)
            # + length (2) + packet number (2), plus both connection IDs;
            # this matches what _end_packet() writes
            header_size = 11 + len(self._peer_cid) + len(self._host_cid)
            if (packet_type & PACKET_TYPE_MASK) == PACKET_TYPE_INITIAL:
                token_length = len(self._peer_token)
                header_size += size_uint_var(token_length) + token_length
        else:
            # first byte (1) + packet number (2), plus the destination CID
            header_size = 3 + len(self._peer_cid)

        # check we have enough space
        if packet_start + header_size >= self._buffer_capacity:
            raise QuicPacketBuilderStop

        # determine ack epoch
        if packet_type == PACKET_TYPE_INITIAL:
            epoch = Epoch.INITIAL
        elif packet_type == PACKET_TYPE_HANDSHAKE:
            epoch = Epoch.HANDSHAKE
        else:
            epoch = Epoch.ONE_RTT

        self._header_size = header_size
        self._packet = QuicSentPacket(
            epoch=epoch,
            in_flight=False,
            is_ack_eliciting=False,
            is_crypto_packet=False,
            packet_number=self._packet_number,
            packet_type=packet_type,
        )
        self._packet_crypto = crypto
        self._packet_long_header = packet_long_header
        self._packet_start = packet_start
        self._packet_type = packet_type
        self.quic_logger_frames = self._packet.quic_logger_frames

        # leave room for the header, which is written in _end_packet()
        buf.seek(self._packet_start + self._header_size)

    def _end_packet(self) -> None:
        """
        Ends the current packet.

        For a packet containing frames, this:

        - pads a client Initial crypto packet up to the flight space;
        - writes the header;
        - adds padding when the payload is shorter than the minimum computed
          below;
        - encrypts the packet in place and records it.

        A packet with no frames is "cancelled": the buffer is rewound to its
        start and its packet number is not consumed.
        """
        buf = self._buffer
        packet_size = buf.tell() - self._packet_start
        if packet_size > self._header_size:
            # pad initial datagram
            if (
                self._is_client
                and self._packet_type == PACKET_TYPE_INITIAL
                and self._packet.is_crypto_packet
            ):
                # The flight space may be negative: frames which are not in
                # flight are not checked against it, and the flight capacity
                # itself may be below the buffer position. bytes() rejects a
                # negative count, so only pad when there is room.
                initial_padding = self.remaining_flight_space
                if initial_padding > 0:
                    buf.push_bytes(bytes(initial_padding))
                    packet_size = buf.tell() - self._packet_start
                    self._packet.in_flight = True

                    # log frame
                    if self._quic_logger is not None:
                        self._packet.quic_logger_frames.append(
                            self._quic_logger.encode_padding_frame()
                        )

            # write header
            if self._packet_long_header:
                # the Length field covers packet number + payload + AEAD tag
                length = (
                    packet_size
                    - self._header_size
                    + PACKET_NUMBER_SEND_SIZE
                    + self._packet_crypto.aead_tag_size
                )

                buf.seek(self._packet_start)
                buf.push_uint8(self._packet_type | (PACKET_NUMBER_SEND_SIZE - 1))
                buf.push_uint32(self._version)
                buf.push_uint8(len(self._peer_cid))
                buf.push_bytes(self._peer_cid)
                buf.push_uint8(len(self._host_cid))
                buf.push_bytes(self._host_cid)
                if (self._packet_type & PACKET_TYPE_MASK) == PACKET_TYPE_INITIAL:
                    buf.push_uint_var(len(self._peer_token))
                    buf.push_bytes(self._peer_token)
                # length as a 2-byte variable-length integer (0b01 prefix)
                buf.push_uint16(length | 0x4000)
                buf.push_uint16(self._packet_number & 0xFFFF)
            else:
                buf.seek(self._packet_start)
                buf.push_uint8(
                    self._packet_type
                    | (self._spin_bit << 5)
                    | (self._packet_crypto.key_phase << 2)
                    | (PACKET_NUMBER_SEND_SIZE - 1)
                )
                buf.push_bytes(self._peer_cid)
                buf.push_uint16(self._packet_number & 0xFFFF)

            # check whether we need padding
            # NOTE(review): presumably this guarantees enough bytes after the
            # packet number for header protection sampling (RFC 9001 5.4.2);
            # confirm.
            padding_size = (
                PACKET_NUMBER_MAX_SIZE
                - PACKET_NUMBER_SEND_SIZE
                + self._header_size
                - packet_size
            )
            if padding_size > 0:
                buf.seek(self._packet_start + packet_size)
                buf.push_bytes(bytes(padding_size))
                packet_size += padding_size
                self._packet.in_flight = True

                # log frame
                if self._quic_logger is not None:
                    self._packet.quic_logger_frames.append(
                        self._quic_logger.encode_padding_frame()
                    )

            # encrypt in place
            plain = buf.data_slice(self._packet_start, self._packet_start + packet_size)
            buf.seek(self._packet_start)
            buf.push_bytes(
                self._packet_crypto.encrypt_packet(
                    plain[0 : self._header_size],
                    plain[self._header_size : packet_size],
                    self._packet_number,
                )
            )
            self._packet.sent_bytes = buf.tell() - self._packet_start
            self._packets.append(self._packet)
            if self._packet.in_flight:
                self._datagram_flight_bytes += self._packet.sent_bytes

            # short header packets cannot be coalesced, we need a new datagram
            if not self._packet_long_header:
                self._flush_current_datagram()

            self._packet_number += 1
        else:
            # "cancel" the packet
            buf.seek(self._packet_start)

        self._packet = None
        self.quic_logger_frames = None

    def _flush_current_datagram(self) -> None:
        """
        Closes the current datagram if anything has been written to it.

        Appends the buffer's data to the datagram list, updates the flight
        and total byte counters, marks the next datagram for initialization
        and rewinds the buffer. Does nothing when the buffer is empty.
        """
        datagram_bytes = self._buffer.tell()
        if datagram_bytes:
            self._datagrams.append(self._buffer.data)
            self._flight_bytes += self._datagram_flight_bytes
            self._total_bytes += datagram_bytes
            self._datagram_init = True
            self._buffer.seek(0)