1import binascii 2import logging 3import os 4from collections import deque 5from dataclasses import dataclass 6from enum import Enum 7from typing import Any, Deque, Dict, FrozenSet, List, Optional, Sequence, Tuple 8 9from .. import tls 10from ..buffer import UINT_VAR_MAX, Buffer, BufferReadError, size_uint_var 11from . import events 12from .configuration import QuicConfiguration 13from .crypto import CryptoError, CryptoPair, KeyUnavailableError 14from .logger import QuicLoggerTrace 15from .packet import ( 16 NON_ACK_ELICITING_FRAME_TYPES, 17 PACKET_TYPE_HANDSHAKE, 18 PACKET_TYPE_INITIAL, 19 PACKET_TYPE_ONE_RTT, 20 PACKET_TYPE_RETRY, 21 PACKET_TYPE_ZERO_RTT, 22 PROBING_FRAME_TYPES, 23 RETRY_INTEGRITY_TAG_SIZE, 24 QuicErrorCode, 25 QuicFrameType, 26 QuicProtocolVersion, 27 QuicStreamFrame, 28 QuicTransportParameters, 29 get_retry_integrity_tag, 30 get_spin_bit, 31 is_long_header, 32 pull_ack_frame, 33 pull_quic_header, 34 pull_quic_transport_parameters, 35 push_ack_frame, 36 push_quic_transport_parameters, 37) 38from .packet_builder import ( 39 PACKET_MAX_SIZE, 40 QuicDeliveryState, 41 QuicPacketBuilder, 42 QuicPacketBuilderStop, 43) 44from .recovery import K_GRANULARITY, QuicPacketRecovery, QuicPacketSpace 45from .stream import QuicStream 46 47logger = logging.getLogger("quic") 48 49CRYPTO_BUFFER_SIZE = 16384 50EPOCH_SHORTCUTS = { 51 "I": tls.Epoch.INITIAL, 52 "H": tls.Epoch.HANDSHAKE, 53 "0": tls.Epoch.ZERO_RTT, 54 "1": tls.Epoch.ONE_RTT, 55} 56MAX_EARLY_DATA = 0xFFFFFFFF 57SECRETS_LABELS = [ 58 [ 59 None, 60 "QUIC_CLIENT_EARLY_TRAFFIC_SECRET", 61 "QUIC_CLIENT_HANDSHAKE_TRAFFIC_SECRET", 62 "QUIC_CLIENT_TRAFFIC_SECRET_0", 63 ], 64 [ 65 None, 66 None, 67 "QUIC_SERVER_HANDSHAKE_TRAFFIC_SECRET", 68 "QUIC_SERVER_TRAFFIC_SECRET_0", 69 ], 70] 71STREAM_FLAGS = 0x07 72 73NetworkAddress = Any 74 75# frame sizes 76ACK_FRAME_CAPACITY = 64 # FIXME: this is arbitrary! 77APPLICATION_CLOSE_FRAME_CAPACITY = 1 + 8 + 8 # + reason length 78HANDSHAKE_DONE_FRAME_CAPACITY = 1 79MAX_DATA_FRAME_CAPACITY = 1 + 8 80MAX_STREAM_DATA_FRAME_CAPACITY = 1 + 8 + 8 81NEW_CONNECTION_ID_FRAME_CAPACITY = 1 + 8 + 8 + 1 + 20 + 16 82PATH_CHALLENGE_FRAME_CAPACITY = 1 + 8 83PATH_RESPONSE_FRAME_CAPACITY = 1 + 8 84PING_FRAME_CAPACITY = 1 85RETIRE_CONNECTION_ID_CAPACITY = 1 + 8 86STREAMS_BLOCKED_CAPACITY = 1 + 8 87TRANSPORT_CLOSE_FRAME_CAPACITY = 1 + 8 + 8 + 8 # + reason length 88 89 90def EPOCHS(shortcut: str) -> FrozenSet[tls.Epoch]: 91 return frozenset(EPOCH_SHORTCUTS[i] for i in shortcut) 92 93 94def dump_cid(cid: bytes) -> str: 95 return binascii.hexlify(cid).decode("ascii") 96 97 98def get_epoch(packet_type: int) -> tls.Epoch: 99 if packet_type == PACKET_TYPE_INITIAL: 100 return tls.Epoch.INITIAL 101 elif packet_type == PACKET_TYPE_ZERO_RTT: 102 return tls.Epoch.ZERO_RTT 103 elif packet_type == PACKET_TYPE_HANDSHAKE: 104 return tls.Epoch.HANDSHAKE 105 else: 106 return tls.Epoch.ONE_RTT 107 108 109def stream_is_client_initiated(stream_id: int) -> bool: 110 """ 111 Returns True if the stream is client initiated. 112 """ 113 return not (stream_id & 1) 114 115 116def stream_is_unidirectional(stream_id: int) -> bool: 117 """ 118 Returns True if the stream is unidirectional. 119 """ 120 return bool(stream_id & 2) 121 122 123class QuicConnectionError(Exception): 124 def __init__(self, error_code: int, frame_type: int, reason_phrase: str): 125 self.error_code = error_code 126 self.frame_type = frame_type 127 self.reason_phrase = reason_phrase 128 129 def __str__(self) -> str: 130 s = "Error: %d, reason: %s" % (self.error_code, self.reason_phrase) 131 if self.frame_type is not None: 132 s += ", frame_type: %s" % self.frame_type 133 return s 134 135 136class QuicConnectionAdapter(logging.LoggerAdapter): 137 def process(self, msg: str, kwargs: Any) -> Tuple[str, Any]: 138 return "[%s] %s" % (self.extra["id"], msg), kwargs 139 140 141@dataclass 142class QuicConnectionId: 143 cid: bytes 144 sequence_number: int 145 stateless_reset_token: bytes = b"" 146 was_sent: bool = False 147 148 149class QuicConnectionState(Enum): 150 FIRSTFLIGHT = 0 151 CONNECTED = 1 152 CLOSING = 2 153 DRAINING = 3 154 TERMINATED = 4 155 156 157@dataclass 158class QuicNetworkPath: 159 addr: NetworkAddress 160 bytes_received: int = 0 161 bytes_sent: int = 0 162 is_validated: bool = False 163 local_challenge: Optional[bytes] = None 164 remote_challenge: Optional[bytes] = None 165 166 def can_send(self, size: int) -> bool: 167 return self.is_validated or (self.bytes_sent + size) <= 3 * self.bytes_received 168 169 170@dataclass 171class QuicReceiveContext: 172 epoch: tls.Epoch 173 host_cid: bytes 174 network_path: QuicNetworkPath 175 quic_logger_frames: Optional[List[Any]] 176 time: float 177 178 179END_STATES = frozenset( 180 [ 181 QuicConnectionState.CLOSING, 182 QuicConnectionState.DRAINING, 183 QuicConnectionState.TERMINATED, 184 ] 185) 186 187 188class QuicConnection: 189 """ 190 A QUIC connection. 191 192 The state machine is driven by three kinds of sources: 193 194 - the API user requesting data to be send out (see :meth:`connect`, 195 :meth:`send_ping`, :meth:`send_datagram_data` and :meth:`send_stream_data`) 196 - data being received from the network (see :meth:`receive_datagram`) 197 - a timer firing (see :meth:`handle_timer`) 198 199 :param configuration: The QUIC configuration to use. 200 """ 201 202 def __init__( 203 self, 204 *, 205 configuration: QuicConfiguration, 206 logger_connection_id: Optional[bytes] = None, 207 original_connection_id: Optional[bytes] = None, 208 session_ticket_fetcher: Optional[tls.SessionTicketFetcher] = None, 209 session_ticket_handler: Optional[tls.SessionTicketHandler] = None, 210 ) -> None: 211 if configuration.is_client: 212 assert ( 213 original_connection_id is None 214 ), "Cannot set original_connection_id for a client" 215 else: 216 assert ( 217 configuration.certificate is not None 218 ), "SSL certificate is required for a server" 219 assert ( 220 configuration.private_key is not None 221 ), "SSL private key is required for a server" 222 223 # configuration 224 self._configuration = configuration 225 self._is_client = configuration.is_client 226 227 self._ack_delay = K_GRANULARITY 228 self._close_at: Optional[float] = None 229 self._close_event: Optional[events.ConnectionTerminated] = None 230 self._connect_called = False 231 self._cryptos: Dict[tls.Epoch, CryptoPair] = {} 232 self._crypto_buffers: Dict[tls.Epoch, Buffer] = {} 233 self._crypto_streams: Dict[tls.Epoch, QuicStream] = {} 234 self._events: Deque[events.QuicEvent] = deque() 235 self._handshake_complete = False 236 self._handshake_confirmed = False 237 self._host_cids = [ 238 QuicConnectionId( 239 cid=os.urandom(configuration.connection_id_length), 240 sequence_number=0, 241 stateless_reset_token=os.urandom(16), 242 was_sent=True, 243 ) 244 ] 245 self.host_cid = self._host_cids[0].cid 246 self._host_cid_seq = 1 247 self._local_ack_delay_exponent = 3 248 self._local_active_connection_id_limit = 8 249 self._local_max_data = configuration.max_data 250 self._local_max_data_sent = configuration.max_data 251 self._local_max_data_used = 0 252 self._local_max_stream_data_bidi_local = configuration.max_stream_data 253 self._local_max_stream_data_bidi_remote = configuration.max_stream_data 254 self._local_max_stream_data_uni = configuration.max_stream_data 255 self._local_max_streams_bidi = 128 256 self._local_max_streams_uni = 128 257 self._loss_at: Optional[float] = None 258 self._network_paths: List[QuicNetworkPath] = [] 259 self._original_connection_id = original_connection_id 260 self._pacing_at: Optional[float] = None 261 self._packet_number = 0 262 self._parameters_received = False 263 self._peer_cid = os.urandom(configuration.connection_id_length) 264 self._peer_cid_seq: Optional[int] = None 265 self._peer_cid_available: List[QuicConnectionId] = [] 266 self._peer_token = b"" 267 self._quic_logger: Optional[QuicLoggerTrace] = None 268 self._remote_ack_delay_exponent = 3 269 self._remote_active_connection_id_limit = 0 270 self._remote_idle_timeout = 0.0 # seconds 271 self._remote_max_data = 0 272 self._remote_max_data_used = 0 273 self._remote_max_datagram_frame_size: Optional[int] = None 274 self._remote_max_stream_data_bidi_local = 0 275 self._remote_max_stream_data_bidi_remote = 0 276 self._remote_max_stream_data_uni = 0 277 self._remote_max_streams_bidi = 0 278 self._remote_max_streams_uni = 0 279 self._spaces: Dict[tls.Epoch, QuicPacketSpace] = {} 280 self._spin_bit = False 281 self._spin_highest_pn = 0 282 self._state = QuicConnectionState.FIRSTFLIGHT 283 self._stateless_retry_count = 0 284 self._streams: Dict[int, QuicStream] = {} 285 self._streams_blocked_bidi: List[QuicStream] = [] 286 self._streams_blocked_uni: List[QuicStream] = [] 287 self._version: Optional[int] = None 288 289 # logging 290 if logger_connection_id is None: 291 logger_connection_id = self._peer_cid 292 self._logger = QuicConnectionAdapter( 293 logger, {"id": dump_cid(logger_connection_id)} 294 ) 295 if configuration.quic_logger: 296 self._quic_logger = configuration.quic_logger.start_trace( 297 is_client=configuration.is_client, odcid=logger_connection_id 298 ) 299 300 # loss recovery 301 self._loss = QuicPacketRecovery( 302 is_client_without_1rtt=self._is_client, 303 quic_logger=self._quic_logger, 304 send_probe=self._send_probe, 305 ) 306 307 # things to send 308 self._close_pending = False 309 self._datagrams_pending: Deque[bytes] = deque() 310 self._handshake_done_pending = False 311 self._ping_pending: List[int] = [] 312 self._probe_pending = False 313 self._retire_connection_ids: List[int] = [] 314 self._streams_blocked_pending = False 315 316 # callbacks 317 self._session_ticket_fetcher = session_ticket_fetcher 318 self._session_ticket_handler = session_ticket_handler 319 320 # frame handlers 321 self.__frame_handlers = { 322 0x00: (self._handle_padding_frame, EPOCHS("IH01")), 323 0x01: (self._handle_ping_frame, EPOCHS("IH01")), 324 0x02: (self._handle_ack_frame, EPOCHS("IH1")), 325 0x03: (self._handle_ack_frame, EPOCHS("IH1")), 326 0x04: (self._handle_reset_stream_frame, EPOCHS("01")), 327 0x05: (self._handle_stop_sending_frame, EPOCHS("01")), 328 0x06: (self._handle_crypto_frame, EPOCHS("IH1")), 329 0x07: (self._handle_new_token_frame, EPOCHS("1")), 330 0x08: (self._handle_stream_frame, EPOCHS("01")), 331 0x09: (self._handle_stream_frame, EPOCHS("01")), 332 0x0A: (self._handle_stream_frame, EPOCHS("01")), 333 0x0B: (self._handle_stream_frame, EPOCHS("01")), 334 0x0C: (self._handle_stream_frame, EPOCHS("01")), 335 0x0D: (self._handle_stream_frame, EPOCHS("01")), 336 0x0E: (self._handle_stream_frame, EPOCHS("01")), 337 0x0F: (self._handle_stream_frame, EPOCHS("01")), 338 0x10: (self._handle_max_data_frame, EPOCHS("01")), 339 0x11: (self._handle_max_stream_data_frame, EPOCHS("01")), 340 0x12: (self._handle_max_streams_bidi_frame, EPOCHS("01")), 341 0x13: (self._handle_max_streams_uni_frame, EPOCHS("01")), 342 0x14: (self._handle_data_blocked_frame, EPOCHS("01")), 343 0x15: (self._handle_stream_data_blocked_frame, EPOCHS("01")), 344 0x16: (self._handle_streams_blocked_frame, EPOCHS("01")), 345 0x17: (self._handle_streams_blocked_frame, EPOCHS("01")), 346 0x18: (self._handle_new_connection_id_frame, EPOCHS("01")), 347 0x19: (self._handle_retire_connection_id_frame, EPOCHS("01")), 348 0x1A: (self._handle_path_challenge_frame, EPOCHS("01")), 349 0x1B: (self._handle_path_response_frame, EPOCHS("01")), 350 0x1C: (self._handle_connection_close_frame, EPOCHS("IH1")), 351 0x1D: (self._handle_connection_close_frame, EPOCHS("1")), 352 0x1E: (self._handle_handshake_done_frame, EPOCHS("1")), 353 0x30: (self._handle_datagram_frame, EPOCHS("01")), 354 0x31: (self._handle_datagram_frame, EPOCHS("01")), 355 } 356 357 @property 358 def configuration(self) -> QuicConfiguration: 359 return self._configuration 360 361 def change_connection_id(self) -> None: 362 """ 363 Switch to the next available connection ID and retire 364 the previous one. 365 366 After calling this method call :meth:`datagrams_to_send` to retrieve data 367 which needs to be sent. 368 """ 369 if self._peer_cid_available: 370 # retire previous CID 371 self._logger.debug( 372 "Retiring CID %s (%d)", dump_cid(self._peer_cid), self._peer_cid_seq 373 ) 374 self._retire_connection_ids.append(self._peer_cid_seq) 375 376 # assign new CID 377 connection_id = self._peer_cid_available.pop(0) 378 self._peer_cid_seq = connection_id.sequence_number 379 self._peer_cid = connection_id.cid 380 self._logger.debug( 381 "Switching to CID %s (%d)", dump_cid(self._peer_cid), self._peer_cid_seq 382 ) 383 384 def close( 385 self, 386 error_code: int = QuicErrorCode.NO_ERROR, 387 frame_type: Optional[int] = None, 388 reason_phrase: str = "", 389 ) -> None: 390 """ 391 Close the connection. 392 393 :param error_code: An error code indicating why the connection is 394 being closed. 395 :param reason_phrase: A human-readable explanation of why the 396 connection is being closed. 397 """ 398 if self._state not in END_STATES: 399 self._close_event = events.ConnectionTerminated( 400 error_code=error_code, 401 frame_type=frame_type, 402 reason_phrase=reason_phrase, 403 ) 404 self._close_pending = True 405 406 def connect(self, addr: NetworkAddress, now: float) -> None: 407 """ 408 Initiate the TLS handshake. 409 410 This method can only be called for clients and a single time. 411 412 After calling this method call :meth:`datagrams_to_send` to retrieve data 413 which needs to be sent. 414 415 :param addr: The network address of the remote peer. 416 :param now: The current time. 417 """ 418 assert ( 419 self._is_client and not self._connect_called 420 ), "connect() can only be called for clients and a single time" 421 self._connect_called = True 422 423 self._network_paths = [QuicNetworkPath(addr, is_validated=True)] 424 self._version = self._configuration.supported_versions[0] 425 self._connect(now=now) 426 427 def datagrams_to_send(self, now: float) -> List[Tuple[bytes, NetworkAddress]]: 428 """ 429 Return a list of `(data, addr)` tuples of datagrams which need to be 430 sent, and the network address to which they need to be sent. 431 432 After calling this method call :meth:`get_timer` to know when the next 433 timer needs to be set. 434 435 :param now: The current time. 436 """ 437 network_path = self._network_paths[0] 438 439 if self._state in END_STATES: 440 return [] 441 442 # build datagrams 443 builder = QuicPacketBuilder( 444 host_cid=self.host_cid, 445 is_client=self._is_client, 446 packet_number=self._packet_number, 447 peer_cid=self._peer_cid, 448 peer_token=self._peer_token, 449 quic_logger=self._quic_logger, 450 spin_bit=self._spin_bit, 451 version=self._version, 452 ) 453 if self._close_pending: 454 for epoch, packet_type in ( 455 (tls.Epoch.ONE_RTT, PACKET_TYPE_ONE_RTT), 456 (tls.Epoch.HANDSHAKE, PACKET_TYPE_HANDSHAKE), 457 (tls.Epoch.INITIAL, PACKET_TYPE_INITIAL), 458 ): 459 crypto = self._cryptos[epoch] 460 if crypto.send.is_valid(): 461 builder.start_packet(packet_type, crypto) 462 self._write_connection_close_frame( 463 builder=builder, 464 error_code=self._close_event.error_code, 465 frame_type=self._close_event.frame_type, 466 reason_phrase=self._close_event.reason_phrase, 467 ) 468 self._close_pending = False 469 break 470 self._close_begin(is_initiator=True, now=now) 471 else: 472 # congestion control 473 builder.max_flight_bytes = ( 474 self._loss.congestion_window - self._loss.bytes_in_flight 475 ) 476 if self._probe_pending and builder.max_flight_bytes < PACKET_MAX_SIZE: 477 builder.max_flight_bytes = PACKET_MAX_SIZE 478 479 # limit data on un-validated network paths 480 if not network_path.is_validated: 481 builder.max_total_bytes = ( 482 network_path.bytes_received * 3 - network_path.bytes_sent 483 ) 484 485 try: 486 if not self._handshake_confirmed: 487 for epoch in [tls.Epoch.INITIAL, tls.Epoch.HANDSHAKE]: 488 self._write_handshake(builder, epoch, now) 489 self._write_application(builder, network_path, now) 490 except QuicPacketBuilderStop: 491 pass 492 493 datagrams, packets = builder.flush() 494 495 if datagrams: 496 self._packet_number = builder.packet_number 497 498 # register packets 499 sent_handshake = False 500 for packet in packets: 501 packet.sent_time = now 502 self._loss.on_packet_sent( 503 packet=packet, space=self._spaces[packet.epoch] 504 ) 505 if packet.epoch == tls.Epoch.HANDSHAKE: 506 sent_handshake = True 507 508 # log packet 509 if self._quic_logger is not None: 510 self._quic_logger.log_event( 511 category="transport", 512 event="packet_sent", 513 data={ 514 "packet_type": self._quic_logger.packet_type( 515 packet.packet_type 516 ), 517 "header": { 518 "packet_number": str(packet.packet_number), 519 "packet_size": packet.sent_bytes, 520 "scid": dump_cid(self.host_cid) 521 if is_long_header(packet.packet_type) 522 else "", 523 "dcid": dump_cid(self._peer_cid), 524 }, 525 "frames": packet.quic_logger_frames, 526 }, 527 ) 528 529 # check if we can discard initial keys 530 if sent_handshake and self._is_client: 531 self._discard_epoch(tls.Epoch.INITIAL) 532 533 # return datagrams to send and the destination network address 534 ret = [] 535 for datagram in datagrams: 536 byte_length = len(datagram) 537 network_path.bytes_sent += byte_length 538 ret.append((datagram, network_path.addr)) 539 540 if self._quic_logger is not None: 541 self._quic_logger.log_event( 542 category="transport", 543 event="datagrams_sent", 544 data={"byte_length": byte_length, "count": 1}, 545 ) 546 return ret 547 548 def get_next_available_stream_id(self, is_unidirectional=False) -> int: 549 """ 550 Return the stream ID for the next stream created by this endpoint. 551 """ 552 stream_id = (int(is_unidirectional) << 1) | int(not self._is_client) 553 while stream_id in self._streams: 554 stream_id += 4 555 return stream_id 556 557 def get_timer(self) -> Optional[float]: 558 """ 559 Return the time at which the timer should fire or None if no timer is needed. 560 """ 561 timer_at = self._close_at 562 if self._state not in END_STATES: 563 # ack timer 564 for space in self._loss.spaces: 565 if space.ack_at is not None and space.ack_at < timer_at: 566 timer_at = space.ack_at 567 568 # loss detection timer 569 self._loss_at = self._loss.get_loss_detection_time() 570 if self._loss_at is not None and self._loss_at < timer_at: 571 timer_at = self._loss_at 572 573 # pacing timer 574 if self._pacing_at is not None and self._pacing_at < timer_at: 575 timer_at = self._pacing_at 576 577 return timer_at 578 579 def handle_timer(self, now: float) -> None: 580 """ 581 Handle the timer. 582 583 After calling this method call :meth:`datagrams_to_send` to retrieve data 584 which needs to be sent. 585 586 :param now: The current time. 587 """ 588 # end of closing period or idle timeout 589 if now >= self._close_at: 590 if self._close_event is None: 591 self._close_event = events.ConnectionTerminated( 592 error_code=QuicErrorCode.INTERNAL_ERROR, 593 frame_type=None, 594 reason_phrase="Idle timeout", 595 ) 596 self._close_end() 597 return 598 599 # loss detection timeout 600 if self._loss_at is not None and now >= self._loss_at: 601 self._logger.debug("Loss detection triggered") 602 self._loss.on_loss_detection_timeout(now=now) 603 604 def next_event(self) -> Optional[events.QuicEvent]: 605 """ 606 Retrieve the next event from the event buffer. 607 608 Returns `None` if there are no buffered events. 609 """ 610 try: 611 return self._events.popleft() 612 except IndexError: 613 return None 614 615 def receive_datagram(self, data: bytes, addr: NetworkAddress, now: float) -> None: 616 """ 617 Handle an incoming datagram. 618 619 After calling this method call :meth:`datagrams_to_send` to retrieve data 620 which needs to be sent. 621 622 :param data: The datagram which was received. 623 :param addr: The network address from which the datagram was received. 624 :param now: The current time. 625 """ 626 # stop handling packets when closing 627 if self._state in END_STATES: 628 return 629 630 if self._quic_logger is not None: 631 self._quic_logger.log_event( 632 category="transport", 633 event="datagrams_received", 634 data={"byte_length": len(data), "count": 1}, 635 ) 636 637 buf = Buffer(data=data) 638 while not buf.eof(): 639 start_off = buf.tell() 640 try: 641 header = pull_quic_header( 642 buf, host_cid_length=self._configuration.connection_id_length 643 ) 644 except ValueError: 645 return 646 647 # check destination CID matches 648 destination_cid_seq: Optional[int] = None 649 for connection_id in self._host_cids: 650 if header.destination_cid == connection_id.cid: 651 destination_cid_seq = connection_id.sequence_number 652 break 653 if self._is_client and destination_cid_seq is None: 654 if self._quic_logger is not None: 655 self._quic_logger.log_event( 656 category="transport", 657 event="packet_dropped", 658 data={"trigger": "unknown_connection_id"}, 659 ) 660 return 661 662 # check protocol version 663 if ( 664 self._is_client 665 and self._state == QuicConnectionState.FIRSTFLIGHT 666 and header.version == QuicProtocolVersion.NEGOTIATION 667 ): 668 # version negotiation 669 versions = [] 670 while not buf.eof(): 671 versions.append(buf.pull_uint32()) 672 if self._quic_logger is not None: 673 self._quic_logger.log_event( 674 category="transport", 675 event="packet_received", 676 data={ 677 "packet_type": "version_negotiation", 678 "header": { 679 "scid": dump_cid(header.source_cid), 680 "dcid": dump_cid(header.destination_cid), 681 }, 682 "frames": [], 683 }, 684 ) 685 common = set(self._configuration.supported_versions).intersection( 686 versions 687 ) 688 if not common: 689 self._logger.error("Could not find a common protocol version") 690 self._close_event = events.ConnectionTerminated( 691 error_code=QuicErrorCode.INTERNAL_ERROR, 692 frame_type=None, 693 reason_phrase="Could not find a common protocol version", 694 ) 695 self._close_end() 696 return 697 self._version = QuicProtocolVersion(max(common)) 698 self._logger.info("Retrying with %s", self._version) 699 self._connect(now=now) 700 return 701 elif ( 702 header.version is not None 703 and header.version not in self._configuration.supported_versions 704 ): 705 # unsupported version 706 if self._quic_logger is not None: 707 self._quic_logger.log_event( 708 category="transport", 709 event="packet_dropped", 710 data={"trigger": "unsupported_version"}, 711 ) 712 return 713 714 if self._is_client and header.packet_type == PACKET_TYPE_RETRY: 715 # calculate stateless retry integrity tag 716 integrity_tag = get_retry_integrity_tag( 717 buf.data_slice(start_off, buf.tell() - RETRY_INTEGRITY_TAG_SIZE), 718 self._peer_cid, 719 ) 720 721 if ( 722 header.destination_cid == self.host_cid 723 and header.integrity_tag == integrity_tag 724 and not self._stateless_retry_count 725 ): 726 if self._quic_logger is not None: 727 self._quic_logger.log_event( 728 category="transport", 729 event="packet_received", 730 data={ 731 "packet_type": "retry", 732 "header": { 733 "scid": dump_cid(header.source_cid), 734 "dcid": dump_cid(header.destination_cid), 735 }, 736 "frames": [], 737 }, 738 ) 739 740 self._original_connection_id = self._peer_cid 741 self._peer_cid = header.source_cid 742 self._peer_token = header.token 743 self._stateless_retry_count += 1 744 self._logger.info("Performing stateless retry") 745 self._connect(now=now) 746 return 747 748 network_path = self._find_network_path(addr) 749 750 # server initialization 751 if not self._is_client and self._state == QuicConnectionState.FIRSTFLIGHT: 752 assert ( 753 header.packet_type == PACKET_TYPE_INITIAL 754 ), "first packet must be INITIAL" 755 self._network_paths = [network_path] 756 self._version = QuicProtocolVersion(header.version) 757 self._initialize(header.destination_cid) 758 759 # determine crypto and packet space 760 epoch = get_epoch(header.packet_type) 761 crypto = self._cryptos[epoch] 762 if epoch == tls.Epoch.ZERO_RTT: 763 space = self._spaces[tls.Epoch.ONE_RTT] 764 else: 765 space = self._spaces[epoch] 766 767 # decrypt packet 768 encrypted_off = buf.tell() - start_off 769 end_off = buf.tell() + header.rest_length 770 buf.seek(end_off) 771 772 try: 773 plain_header, plain_payload, packet_number = crypto.decrypt_packet( 774 data[start_off:end_off], encrypted_off, space.expected_packet_number 775 ) 776 except KeyUnavailableError as exc: 777 self._logger.debug(exc) 778 if self._quic_logger is not None: 779 self._quic_logger.log_event( 780 category="transport", 781 event="packet_dropped", 782 data={"trigger": "key_unavailable"}, 783 ) 784 continue 785 except CryptoError as exc: 786 self._logger.debug(exc) 787 if self._quic_logger is not None: 788 self._quic_logger.log_event( 789 category="transport", 790 event="packet_dropped", 791 data={"trigger": "payload_decrypt_error"}, 792 ) 793 continue 794 795 # check reserved bits 796 if header.is_long_header: 797 reserved_mask = 0x0C 798 else: 799 reserved_mask = 0x18 800 if plain_header[0] & reserved_mask: 801 self.close( 802 error_code=QuicErrorCode.PROTOCOL_VIOLATION, 803 frame_type=None, 804 reason_phrase="Reserved bits must be zero", 805 ) 806 return 807 808 # raise expected packet number 809 if packet_number > space.expected_packet_number: 810 space.expected_packet_number = packet_number + 1 811 812 # log packet 813 quic_logger_frames: Optional[List[Dict]] = None 814 if self._quic_logger is not None: 815 quic_logger_frames = [] 816 self._quic_logger.log_event( 817 category="transport", 818 event="packet_received", 819 data={ 820 "packet_type": self._quic_logger.packet_type( 821 header.packet_type 822 ), 823 "header": { 824 "packet_number": str(packet_number), 825 "packet_size": end_off - start_off, 826 "dcid": dump_cid(header.destination_cid), 827 "scid": dump_cid(header.source_cid), 828 }, 829 "frames": quic_logger_frames, 830 }, 831 ) 832 833 # discard initial keys and packet space 834 if not self._is_client and epoch == tls.Epoch.HANDSHAKE: 835 self._discard_epoch(tls.Epoch.INITIAL) 836 837 # update state 838 if self._peer_cid_seq is None: 839 self._peer_cid = header.source_cid 840 self._peer_cid_seq = 0 841 842 if self._state == QuicConnectionState.FIRSTFLIGHT: 843 self._set_state(QuicConnectionState.CONNECTED) 844 845 # update spin bit 846 if not header.is_long_header and packet_number > self._spin_highest_pn: 847 spin_bit = get_spin_bit(plain_header[0]) 848 if self._is_client: 849 self._spin_bit = not spin_bit 850 else: 851 self._spin_bit = spin_bit 852 self._spin_highest_pn = packet_number 853 854 if self._quic_logger is not None: 855 self._quic_logger.log_event( 856 category="connectivity", 857 event="spin_bit_updated", 858 data={"state": self._spin_bit}, 859 ) 860 861 # handle payload 862 context = QuicReceiveContext( 863 epoch=epoch, 864 host_cid=header.destination_cid, 865 network_path=network_path, 866 quic_logger_frames=quic_logger_frames, 867 time=now, 868 ) 869 try: 870 is_ack_eliciting, is_probing = self._payload_received( 871 context, plain_payload 872 ) 873 except QuicConnectionError as exc: 874 self._logger.warning(exc) 875 self.close( 876 error_code=exc.error_code, 877 frame_type=exc.frame_type, 878 reason_phrase=exc.reason_phrase, 879 ) 880 if self._state in END_STATES or self._close_pending: 881 return 882 883 # update idle timeout 884 self._close_at = now + self._configuration.idle_timeout 885 886 # handle migration 887 if ( 888 not self._is_client 889 and context.host_cid != self.host_cid 890 and epoch == tls.Epoch.ONE_RTT 891 ): 892 self._logger.debug( 893 "Peer switching to CID %s (%d)", 894 dump_cid(context.host_cid), 895 destination_cid_seq, 896 ) 897 self.host_cid = context.host_cid 898 self.change_connection_id() 899 900 # update network path 901 if not network_path.is_validated and epoch == tls.Epoch.HANDSHAKE: 902 self._logger.debug( 903 "Network path %s validated by handshake", network_path.addr 904 ) 905 network_path.is_validated = True 906 network_path.bytes_received += end_off - start_off 907 if network_path not in self._network_paths: 908 self._network_paths.append(network_path) 909 idx = self._network_paths.index(network_path) 910 if idx and not is_probing and packet_number > space.largest_received_packet: 911 self._logger.debug("Network path %s promoted", network_path.addr) 912 self._network_paths.pop(idx) 913 self._network_paths.insert(0, network_path) 914 915 # record packet as received 916 if not space.discarded: 917 if packet_number > space.largest_received_packet: 918 space.largest_received_packet = packet_number 919 space.largest_received_time = now 920 space.ack_queue.add(packet_number) 921 if is_ack_eliciting and space.ack_at is None: 922 space.ack_at = now + self._ack_delay 923 924 def request_key_update(self) -> None: 925 """ 926 Request an update of the encryption keys. 927 """ 928 assert self._handshake_complete, "cannot change key before handshake completes" 929 self._cryptos[tls.Epoch.ONE_RTT].update_key() 930 931 def send_ping(self, uid: int) -> None: 932 """ 933 Send a PING frame to the peer. 934 935 :param uid: A unique ID for this PING. 936 """ 937 self._ping_pending.append(uid) 938 939 def send_datagram_frame(self, data: bytes) -> None: 940 """ 941 Send a DATAGRAM frame. 942 943 :param data: The data to be sent. 944 """ 945 self._datagrams_pending.append(data) 946 947 def send_stream_data( 948 self, stream_id: int, data: bytes, end_stream: bool = False 949 ) -> None: 950 """ 951 Send data on the specific stream. 952 953 :param stream_id: The stream's ID. 954 :param data: The data to be sent. 955 :param end_stream: If set to `True`, the FIN bit will be set. 956 """ 957 if stream_is_client_initiated(stream_id) != self._is_client: 958 if stream_id not in self._streams: 959 raise ValueError("Cannot send data on unknown peer-initiated stream") 960 if stream_is_unidirectional(stream_id): 961 raise ValueError( 962 "Cannot send data on peer-initiated unidirectional stream" 963 ) 964 965 try: 966 stream = self._streams[stream_id] 967 except KeyError: 968 self._create_stream(stream_id=stream_id) 969 stream = self._streams[stream_id] 970 stream.write(data, end_stream=end_stream) 971 972 # Private 973 974 def _alpn_handler(self, alpn_protocol: str) -> None: 975 """ 976 Callback which is invoked by the TLS engine when ALPN negotiation completes. 977 """ 978 self._events.append(events.ProtocolNegotiated(alpn_protocol=alpn_protocol)) 979 980 def _assert_stream_can_receive(self, frame_type: int, stream_id: int) -> None: 981 """ 982 Check the specified stream can receive data or raises a QuicConnectionError. 983 """ 984 if not self._stream_can_receive(stream_id): 985 raise QuicConnectionError( 986 error_code=QuicErrorCode.STREAM_STATE_ERROR, 987 frame_type=frame_type, 988 reason_phrase="Stream is send-only", 989 ) 990 991 def _assert_stream_can_send(self, frame_type: int, stream_id: int) -> None: 992 """ 993 Check the specified stream can send data or raises a QuicConnectionError. 994 """ 995 if not self._stream_can_send(stream_id): 996 raise QuicConnectionError( 997 error_code=QuicErrorCode.STREAM_STATE_ERROR, 998 frame_type=frame_type, 999 reason_phrase="Stream is receive-only", 1000 ) 1001 1002 def _close_begin(self, is_initiator: bool, now: float) -> None: 1003 """ 1004 Begin the close procedure. 1005 """ 1006 self._close_at = now + 3 * self._loss.get_probe_timeout() 1007 if is_initiator: 1008 self._set_state(QuicConnectionState.CLOSING) 1009 else: 1010 self._set_state(QuicConnectionState.DRAINING) 1011 1012 def _close_end(self) -> None: 1013 """ 1014 End the close procedure. 1015 """ 1016 self._close_at = None 1017 for epoch in self._spaces.keys(): 1018 self._discard_epoch(epoch) 1019 self._events.append(self._close_event) 1020 self._set_state(QuicConnectionState.TERMINATED) 1021 1022 # signal log end 1023 if self._quic_logger is not None: 1024 self._configuration.quic_logger.end_trace(self._quic_logger) 1025 self._quic_logger = None 1026 1027 def _connect(self, now: float) -> None: 1028 """ 1029 Start the client handshake. 1030 """ 1031 assert self._is_client 1032 1033 self._close_at = now + self._configuration.idle_timeout 1034 self._initialize(self._peer_cid) 1035 1036 self.tls.handle_message(b"", self._crypto_buffers) 1037 self._push_crypto_data() 1038 1039 def _create_stream(self, stream_id: int) -> QuicStream: 1040 """ 1041 Create a QUIC stream in order to send data to the peer. 1042 """ 1043 # determine limits 1044 if stream_is_unidirectional(stream_id): 1045 max_stream_data_local = 0 1046 max_stream_data_remote = self._remote_max_stream_data_uni 1047 max_streams = self._remote_max_streams_uni 1048 streams_blocked = self._streams_blocked_uni 1049 else: 1050 max_stream_data_local = self._local_max_stream_data_bidi_local 1051 max_stream_data_remote = self._remote_max_stream_data_bidi_remote 1052 max_streams = self._remote_max_streams_bidi 1053 streams_blocked = self._streams_blocked_bidi 1054 1055 # create stream 1056 stream = self._streams[stream_id] = QuicStream( 1057 stream_id=stream_id, 1058 max_stream_data_local=max_stream_data_local, 1059 max_stream_data_remote=max_stream_data_remote, 1060 ) 1061 1062 # mark stream as blocked if needed 1063 if stream_id // 4 >= max_streams: 1064 stream.is_blocked = True 1065 streams_blocked.append(stream) 1066 self._streams_blocked_pending = True 1067 1068 return stream 1069 1070 def _discard_epoch(self, epoch: tls.Epoch) -> None: 1071 self._logger.debug("Discarding epoch %s", epoch) 1072 self._cryptos[epoch].teardown() 1073 self._loss.discard_space(self._spaces[epoch]) 1074 self._spaces[epoch].discarded = True 1075 1076 def _find_network_path(self, addr: NetworkAddress) -> QuicNetworkPath: 1077 # check existing network paths 1078 for idx, network_path in enumerate(self._network_paths): 1079 if network_path.addr == addr: 1080 return network_path 1081 1082 # new network path 1083 network_path = QuicNetworkPath(addr) 1084 self._logger.debug("Network path %s discovered", network_path.addr) 1085 return network_path 1086 1087 def _get_or_create_stream(self, frame_type: int, stream_id: int) -> QuicStream: 1088 """ 1089 Get or create a stream in response to a received frame. 1090 """ 1091 stream = self._streams.get(stream_id, None) 1092 if stream is None: 1093 # check initiator 1094 if stream_is_client_initiated(stream_id) == self._is_client: 1095 raise QuicConnectionError( 1096 error_code=QuicErrorCode.STREAM_STATE_ERROR, 1097 frame_type=frame_type, 1098 reason_phrase="Wrong stream initiator", 1099 ) 1100 1101 # determine limits 1102 if stream_is_unidirectional(stream_id): 1103 max_stream_data_local = self._local_max_stream_data_uni 1104 max_stream_data_remote = 0 1105 max_streams = self._local_max_streams_uni 1106 else: 1107 max_stream_data_local = self._local_max_stream_data_bidi_remote 1108 max_stream_data_remote = self._remote_max_stream_data_bidi_local 1109 max_streams = self._local_max_streams_bidi 1110 1111 # check max streams 1112 if stream_id // 4 >= max_streams: 1113 raise QuicConnectionError( 1114 error_code=QuicErrorCode.STREAM_LIMIT_ERROR, 1115 frame_type=frame_type, 1116 reason_phrase="Too many streams open", 1117 ) 1118 1119 # create stream 1120 self._logger.debug("Stream %d created by peer" % stream_id) 1121 stream = self._streams[stream_id] = QuicStream( 1122 stream_id=stream_id, 1123 max_stream_data_local=max_stream_data_local, 1124 max_stream_data_remote=max_stream_data_remote, 1125 ) 1126 return stream 1127 1128 def _handle_session_ticket(self, session_ticket: tls.SessionTicket) -> None: 1129 if ( 1130 session_ticket.max_early_data_size is not None 1131 and session_ticket.max_early_data_size != MAX_EARLY_DATA 1132 ): 1133 raise QuicConnectionError( 1134 error_code=QuicErrorCode.PROTOCOL_VIOLATION, 1135 frame_type=QuicFrameType.CRYPTO, 1136 reason_phrase="Invalid max_early_data value %s" 1137 % session_ticket.max_early_data_size, 1138 ) 1139 self._session_ticket_handler(session_ticket) 1140 1141 def _initialize(self, peer_cid: bytes) -> None: 1142 # TLS 1143 self.tls = tls.Context( 1144 alpn_protocols=self._configuration.alpn_protocols, 1145 cadata=self._configuration.cadata, 1146 cafile=self._configuration.cafile, 1147 capath=self._configuration.capath, 1148 is_client=self._is_client, 1149 logger=self._logger, 1150 max_early_data=None if self._is_client else MAX_EARLY_DATA, 1151 server_name=self._configuration.server_name, 1152 verify_mode=self._configuration.verify_mode, 1153 ) 1154 self.tls.certificate = self._configuration.certificate 1155 self.tls.certificate_chain = self._configuration.certificate_chain 1156 self.tls.certificate_private_key = self._configuration.private_key 1157 self.tls.handshake_extensions = [ 1158 ( 1159 tls.ExtensionType.QUIC_TRANSPORT_PARAMETERS, 1160 self._serialize_transport_parameters(), 1161 ) 1162 ] 1163 1164 # TLS session resumption 1165 session_ticket = self._configuration.session_ticket 1166 if ( 1167 self._is_client 1168 and session_ticket is not None 1169 and session_ticket.is_valid 1170 and session_ticket.server_name == self._configuration.server_name 1171 ): 1172 self.tls.session_ticket = self._configuration.session_ticket 1173 1174 # parse saved QUIC transport parameters - for 0-RTT 1175 if session_ticket.max_early_data_size == MAX_EARLY_DATA: 1176 for ext_type, ext_data in session_ticket.other_extensions: 1177 if ext_type == tls.ExtensionType.QUIC_TRANSPORT_PARAMETERS: 1178 self._parse_transport_parameters( 1179 ext_data, from_session_ticket=True 1180 ) 1181 break 1182 1183 # TLS callbacks 1184 self.tls.alpn_cb = self._alpn_handler 1185 if self._session_ticket_fetcher is not None: 1186 self.tls.get_session_ticket_cb = self._session_ticket_fetcher 1187 if self._session_ticket_handler is not None: 1188 self.tls.new_session_ticket_cb = self._handle_session_ticket 1189 self.tls.update_traffic_key_cb = self._update_traffic_key 1190 1191 # packet spaces 1192 self._cryptos = { 1193 tls.Epoch.INITIAL: CryptoPair(), 1194 tls.Epoch.ZERO_RTT: CryptoPair(), 1195 tls.Epoch.HANDSHAKE: CryptoPair(), 1196 tls.Epoch.ONE_RTT: CryptoPair(), 1197 } 1198 self._crypto_buffers = { 1199 tls.Epoch.INITIAL: Buffer(capacity=CRYPTO_BUFFER_SIZE), 1200 tls.Epoch.HANDSHAKE: Buffer(capacity=CRYPTO_BUFFER_SIZE), 1201 tls.Epoch.ONE_RTT: Buffer(capacity=CRYPTO_BUFFER_SIZE), 1202 } 1203 self._crypto_streams = { 1204 tls.Epoch.INITIAL: QuicStream(), 1205 tls.Epoch.HANDSHAKE: QuicStream(), 1206 tls.Epoch.ONE_RTT: QuicStream(), 1207 } 1208 self._spaces = { 1209 tls.Epoch.INITIAL: QuicPacketSpace(), 1210 tls.Epoch.HANDSHAKE: QuicPacketSpace(), 1211 tls.Epoch.ONE_RTT: QuicPacketSpace(), 1212 } 1213 1214 self._cryptos[tls.Epoch.INITIAL].setup_initial( 1215 cid=peer_cid, is_client=self._is_client, version=self._version 1216 ) 1217 1218 self._loss.spaces = list(self._spaces.values()) 1219 self._packet_number = 0 1220 1221 def _handle_ack_frame( 1222 self, context: QuicReceiveContext, frame_type: int, buf: Buffer 1223 ) -> None: 1224 """ 1225 Handle an ACK frame. 1226 """ 1227 ack_rangeset, ack_delay_encoded = pull_ack_frame(buf) 1228 if frame_type == QuicFrameType.ACK_ECN: 1229 buf.pull_uint_var() 1230 buf.pull_uint_var() 1231 buf.pull_uint_var() 1232 ack_delay = (ack_delay_encoded << self._remote_ack_delay_exponent) / 1000000 1233 1234 # log frame 1235 if self._quic_logger is not None: 1236 context.quic_logger_frames.append( 1237 self._quic_logger.encode_ack_frame(ack_rangeset, ack_delay) 1238 ) 1239 1240 self._loss.on_ack_received( 1241 space=self._spaces[context.epoch], 1242 ack_rangeset=ack_rangeset, 1243 ack_delay=ack_delay, 1244 now=context.time, 1245 ) 1246 1247 def _handle_connection_close_frame( 1248 self, context: QuicReceiveContext, frame_type: int, buf: Buffer 1249 ) -> None: 1250 """ 1251 Handle a CONNECTION_CLOSE frame. 1252 """ 1253 error_code = buf.pull_uint_var() 1254 if frame_type == QuicFrameType.TRANSPORT_CLOSE: 1255 frame_type = buf.pull_uint_var() 1256 else: 1257 frame_type = None 1258 reason_length = buf.pull_uint_var() 1259 try: 1260 reason_phrase = buf.pull_bytes(reason_length).decode("utf8") 1261 except UnicodeDecodeError: 1262 reason_phrase = "" 1263 1264 # log frame 1265 if self._quic_logger is not None: 1266 context.quic_logger_frames.append( 1267 self._quic_logger.encode_connection_close_frame( 1268 error_code=error_code, 1269 frame_type=frame_type, 1270 reason_phrase=reason_phrase, 1271 ) 1272 ) 1273 1274 self._logger.info( 1275 "Connection close code 0x%X, reason %s", error_code, reason_phrase 1276 ) 1277 self._close_event = events.ConnectionTerminated( 1278 error_code=error_code, frame_type=frame_type, reason_phrase=reason_phrase 1279 ) 1280 self._close_begin(is_initiator=False, now=context.time) 1281 1282 def _handle_crypto_frame( 1283 self, context: QuicReceiveContext, frame_type: int, buf: Buffer 1284 ) -> None: 1285 """ 1286 Handle a CRYPTO frame. 1287 """ 1288 offset = buf.pull_uint_var() 1289 length = buf.pull_uint_var() 1290 if offset + length > UINT_VAR_MAX: 1291 raise QuicConnectionError( 1292 error_code=QuicErrorCode.FRAME_ENCODING_ERROR, 1293 frame_type=frame_type, 1294 reason_phrase="offset + length cannot exceed 2^62 - 1", 1295 ) 1296 frame = QuicStreamFrame(offset=offset, data=buf.pull_bytes(length)) 1297 1298 # log frame 1299 if self._quic_logger is not None: 1300 context.quic_logger_frames.append( 1301 self._quic_logger.encode_crypto_frame(frame) 1302 ) 1303 1304 stream = self._crypto_streams[context.epoch] 1305 event = stream.add_frame(frame) 1306 if event is not None: 1307 # pass data to TLS layer 1308 try: 1309 self.tls.handle_message(event.data, self._crypto_buffers) 1310 self._push_crypto_data() 1311 except tls.Alert as exc: 1312 raise QuicConnectionError( 1313 error_code=QuicErrorCode.CRYPTO_ERROR + int(exc.description), 1314 frame_type=frame_type, 1315 reason_phrase=str(exc), 1316 ) 1317 1318 # parse transport parameters 1319 if ( 1320 not self._parameters_received 1321 and self.tls.received_extensions is not None 1322 ): 1323 for ext_type, ext_data in self.tls.received_extensions: 1324 if ext_type == tls.ExtensionType.QUIC_TRANSPORT_PARAMETERS: 1325 self._parse_transport_parameters(ext_data) 1326 self._parameters_received = True 1327 break 1328 assert ( 1329 self._parameters_received 1330 ), "No QUIC transport parameters received" 1331 1332 # update current epoch 1333 if not self._handshake_complete and self.tls.state in [ 1334 tls.State.CLIENT_POST_HANDSHAKE, 1335 tls.State.SERVER_POST_HANDSHAKE, 1336 ]: 1337 self._handshake_complete = True 1338 1339 # for servers, the handshake is now confirmed 1340 if not self._is_client: 1341 self._discard_epoch(tls.Epoch.HANDSHAKE) 1342 self._handshake_confirmed = True 1343 self._handshake_done_pending = True 1344 1345 self._loss.is_client_without_1rtt = False 1346 self._replenish_connection_ids() 1347 self._events.append( 1348 events.HandshakeCompleted( 1349 alpn_protocol=self.tls.alpn_negotiated, 1350 early_data_accepted=self.tls.early_data_accepted, 1351 session_resumed=self.tls.session_resumed, 1352 ) 1353 ) 1354 self._unblock_streams(is_unidirectional=False) 1355 self._unblock_streams(is_unidirectional=True) 1356 self._logger.info( 1357 "ALPN negotiated protocol %s", self.tls.alpn_negotiated 1358 ) 1359 1360 def _handle_data_blocked_frame( 1361 self, context: QuicReceiveContext, frame_type: int, buf: Buffer 1362 ) -> None: 1363 """ 1364 Handle a DATA_BLOCKED frame. 1365 """ 1366 limit = buf.pull_uint_var() 1367 1368 # log frame 1369 if self._quic_logger is not None: 1370 context.quic_logger_frames.append( 1371 self._quic_logger.encode_data_blocked_frame(limit=limit) 1372 ) 1373 1374 def _handle_datagram_frame( 1375 self, context: QuicReceiveContext, frame_type: int, buf: Buffer 1376 ) -> None: 1377 """ 1378 Handle a DATAGRAM frame. 1379 """ 1380 start = buf.tell() 1381 if frame_type == QuicFrameType.DATAGRAM_WITH_LENGTH: 1382 length = buf.pull_uint_var() 1383 else: 1384 length = buf.capacity - start 1385 data = buf.pull_bytes(length) 1386 1387 # log frame 1388 if self._quic_logger is not None: 1389 context.quic_logger_frames.append( 1390 self._quic_logger.encode_datagram_frame(length=length) 1391 ) 1392 1393 # check frame is allowed 1394 if ( 1395 self._configuration.max_datagram_frame_size is None 1396 or buf.tell() - start >= self._configuration.max_datagram_frame_size 1397 ): 1398 raise QuicConnectionError( 1399 error_code=QuicErrorCode.PROTOCOL_VIOLATION, 1400 frame_type=frame_type, 1401 reason_phrase="Unexpected DATAGRAM frame", 1402 ) 1403 1404 self._events.append(events.DatagramFrameReceived(data=data)) 1405 1406 def _handle_handshake_done_frame( 1407 self, context: QuicReceiveContext, frame_type: int, buf: Buffer 1408 ) -> None: 1409 """ 1410 Handle a HANDSHAKE_DONE frame. 1411 """ 1412 # log frame 1413 if self._quic_logger is not None: 1414 context.quic_logger_frames.append( 1415 self._quic_logger.encode_handshake_done_frame() 1416 ) 1417 1418 if not self._is_client: 1419 raise QuicConnectionError( 1420 error_code=QuicErrorCode.PROTOCOL_VIOLATION, 1421 frame_type=frame_type, 1422 reason_phrase="Clients must not send HANDSHAKE_DONE frames", 1423 ) 1424 1425 # for clients, the handshake is now confirmed 1426 if not self._handshake_confirmed: 1427 self._discard_epoch(tls.Epoch.HANDSHAKE) 1428 self._handshake_confirmed = True 1429 1430 def _handle_max_data_frame( 1431 self, context: QuicReceiveContext, frame_type: int, buf: Buffer 1432 ) -> None: 1433 """ 1434 Handle a MAX_DATA frame. 1435 1436 This adjusts the total amount of we can send to the peer. 1437 """ 1438 max_data = buf.pull_uint_var() 1439 1440 # log frame 1441 if self._quic_logger is not None: 1442 context.quic_logger_frames.append( 1443 self._quic_logger.encode_max_data_frame(maximum=max_data) 1444 ) 1445 1446 if max_data > self._remote_max_data: 1447 self._logger.debug("Remote max_data raised to %d", max_data) 1448 self._remote_max_data = max_data 1449 1450 def _handle_max_stream_data_frame( 1451 self, context: QuicReceiveContext, frame_type: int, buf: Buffer 1452 ) -> None: 1453 """ 1454 Handle a MAX_STREAM_DATA frame. 1455 1456 This adjusts the amount of data we can send on a specific stream. 1457 """ 1458 stream_id = buf.pull_uint_var() 1459 max_stream_data = buf.pull_uint_var() 1460 1461 # log frame 1462 if self._quic_logger is not None: 1463 context.quic_logger_frames.append( 1464 self._quic_logger.encode_max_stream_data_frame( 1465 maximum=max_stream_data, stream_id=stream_id 1466 ) 1467 ) 1468 1469 # check stream direction 1470 self._assert_stream_can_send(frame_type, stream_id) 1471 1472 stream = self._get_or_create_stream(frame_type, stream_id) 1473 if max_stream_data > stream.max_stream_data_remote: 1474 self._logger.debug( 1475 "Stream %d remote max_stream_data raised to %d", 1476 stream_id, 1477 max_stream_data, 1478 ) 1479 stream.max_stream_data_remote = max_stream_data 1480 1481 def _handle_max_streams_bidi_frame( 1482 self, context: QuicReceiveContext, frame_type: int, buf: Buffer 1483 ) -> None: 1484 """ 1485 Handle a MAX_STREAMS_BIDI frame. 1486 1487 This raises number of bidirectional streams we can initiate to the peer. 1488 """ 1489 max_streams = buf.pull_uint_var() 1490 1491 # log frame 1492 if self._quic_logger is not None: 1493 context.quic_logger_frames.append( 1494 self._quic_logger.encode_max_streams_frame( 1495 is_unidirectional=False, maximum=max_streams 1496 ) 1497 ) 1498 1499 if max_streams > self._remote_max_streams_bidi: 1500 self._logger.debug("Remote max_streams_bidi raised to %d", max_streams) 1501 self._remote_max_streams_bidi = max_streams 1502 self._unblock_streams(is_unidirectional=False) 1503 1504 def _handle_max_streams_uni_frame( 1505 self, context: QuicReceiveContext, frame_type: int, buf: Buffer 1506 ) -> None: 1507 """ 1508 Handle a MAX_STREAMS_UNI frame. 1509 1510 This raises number of unidirectional streams we can initiate to the peer. 1511 """ 1512 max_streams = buf.pull_uint_var() 1513 1514 # log frame 1515 if self._quic_logger is not None: 1516 context.quic_logger_frames.append( 1517 self._quic_logger.encode_max_streams_frame( 1518 is_unidirectional=True, maximum=max_streams 1519 ) 1520 ) 1521 1522 if max_streams > self._remote_max_streams_uni: 1523 self._logger.debug("Remote max_streams_uni raised to %d", max_streams) 1524 self._remote_max_streams_uni = max_streams 1525 self._unblock_streams(is_unidirectional=True) 1526 1527 def _handle_new_connection_id_frame( 1528 self, context: QuicReceiveContext, frame_type: int, buf: Buffer 1529 ) -> None: 1530 """ 1531 Handle a NEW_CONNECTION_ID frame. 1532 """ 1533 sequence_number = buf.pull_uint_var() 1534 retire_prior_to = buf.pull_uint_var() 1535 length = buf.pull_uint8() 1536 connection_id = buf.pull_bytes(length) 1537 stateless_reset_token = buf.pull_bytes(16) 1538 1539 # log frame 1540 if self._quic_logger is not None: 1541 context.quic_logger_frames.append( 1542 self._quic_logger.encode_new_connection_id_frame( 1543 connection_id=connection_id, 1544 retire_prior_to=retire_prior_to, 1545 sequence_number=sequence_number, 1546 stateless_reset_token=stateless_reset_token, 1547 ) 1548 ) 1549 1550 self._peer_cid_available.append( 1551 QuicConnectionId( 1552 cid=connection_id, 1553 sequence_number=sequence_number, 1554 stateless_reset_token=stateless_reset_token, 1555 ) 1556 ) 1557 1558 def _handle_new_token_frame( 1559 self, context: QuicReceiveContext, frame_type: int, buf: Buffer 1560 ) -> None: 1561 """ 1562 Handle a NEW_TOKEN frame. 1563 """ 1564 length = buf.pull_uint_var() 1565 token = buf.pull_bytes(length) 1566 1567 # log frame 1568 if self._quic_logger is not None: 1569 context.quic_logger_frames.append( 1570 self._quic_logger.encode_new_token_frame(token=token) 1571 ) 1572 1573 if not self._is_client: 1574 raise QuicConnectionError( 1575 error_code=QuicErrorCode.PROTOCOL_VIOLATION, 1576 frame_type=frame_type, 1577 reason_phrase="Clients must not send NEW_TOKEN frames", 1578 ) 1579 1580 def _handle_padding_frame( 1581 self, context: QuicReceiveContext, frame_type: int, buf: Buffer 1582 ) -> None: 1583 """ 1584 Handle a PADDING frame. 1585 """ 1586 # consume padding 1587 pos = buf.tell() 1588 for byte in buf.data_slice(pos, buf.capacity): 1589 if byte: 1590 break 1591 pos += 1 1592 buf.seek(pos) 1593 1594 # log frame 1595 if self._quic_logger is not None: 1596 context.quic_logger_frames.append(self._quic_logger.encode_padding_frame()) 1597 1598 def _handle_path_challenge_frame( 1599 self, context: QuicReceiveContext, frame_type: int, buf: Buffer 1600 ) -> None: 1601 """ 1602 Handle a PATH_CHALLENGE frame. 1603 """ 1604 data = buf.pull_bytes(8) 1605 1606 # log frame 1607 if self._quic_logger is not None: 1608 context.quic_logger_frames.append( 1609 self._quic_logger.encode_path_challenge_frame(data=data) 1610 ) 1611 1612 context.network_path.remote_challenge = data 1613 1614 def _handle_path_response_frame( 1615 self, context: QuicReceiveContext, frame_type: int, buf: Buffer 1616 ) -> None: 1617 """ 1618 Handle a PATH_RESPONSE frame. 1619 """ 1620 data = buf.pull_bytes(8) 1621 1622 # log frame 1623 if self._quic_logger is not None: 1624 context.quic_logger_frames.append( 1625 self._quic_logger.encode_path_response_frame(data=data) 1626 ) 1627 1628 if data != context.network_path.local_challenge: 1629 raise QuicConnectionError( 1630 error_code=QuicErrorCode.PROTOCOL_VIOLATION, 1631 frame_type=frame_type, 1632 reason_phrase="Response does not match challenge", 1633 ) 1634 self._logger.debug( 1635 "Network path %s validated by challenge", context.network_path.addr 1636 ) 1637 context.network_path.is_validated = True 1638 1639 def _handle_ping_frame( 1640 self, context: QuicReceiveContext, frame_type: int, buf: Buffer 1641 ) -> None: 1642 """ 1643 Handle a PING frame. 1644 """ 1645 # log frame 1646 if self._quic_logger is not None: 1647 context.quic_logger_frames.append(self._quic_logger.encode_ping_frame()) 1648 1649 def _handle_reset_stream_frame( 1650 self, context: QuicReceiveContext, frame_type: int, buf: Buffer 1651 ) -> None: 1652 """ 1653 Handle a RESET_STREAM frame. 1654 """ 1655 stream_id = buf.pull_uint_var() 1656 error_code = buf.pull_uint_var() 1657 final_size = buf.pull_uint_var() 1658 1659 # log frame 1660 if self._quic_logger is not None: 1661 context.quic_logger_frames.append( 1662 self._quic_logger.encode_reset_stream_frame( 1663 error_code=error_code, final_size=final_size, stream_id=stream_id 1664 ) 1665 ) 1666 1667 # check stream direction 1668 self._assert_stream_can_receive(frame_type, stream_id) 1669 1670 self._logger.info( 1671 "Stream %d reset by peer (error code %d, final size %d)", 1672 stream_id, 1673 error_code, 1674 final_size, 1675 ) 1676 # stream = self._get_or_create_stream(frame_type, stream_id) 1677 self._events.append( 1678 events.StreamReset(error_code=error_code, stream_id=stream_id) 1679 ) 1680 1681 def _handle_retire_connection_id_frame( 1682 self, context: QuicReceiveContext, frame_type: int, buf: Buffer 1683 ) -> None: 1684 """ 1685 Handle a RETIRE_CONNECTION_ID frame. 1686 """ 1687 sequence_number = buf.pull_uint_var() 1688 1689 # log frame 1690 if self._quic_logger is not None: 1691 context.quic_logger_frames.append( 1692 self._quic_logger.encode_retire_connection_id_frame(sequence_number) 1693 ) 1694 1695 # find the connection ID by sequence number 1696 for index, connection_id in enumerate(self._host_cids): 1697 if connection_id.sequence_number == sequence_number: 1698 if connection_id.cid == context.host_cid: 1699 raise QuicConnectionError( 1700 error_code=QuicErrorCode.PROTOCOL_VIOLATION, 1701 frame_type=frame_type, 1702 reason_phrase="Cannot retire current connection ID", 1703 ) 1704 self._logger.debug( 1705 "Peer retiring CID %s (%d)", 1706 dump_cid(connection_id.cid), 1707 connection_id.sequence_number, 1708 ) 1709 del self._host_cids[index] 1710 self._events.append( 1711 events.ConnectionIdRetired(connection_id=connection_id.cid) 1712 ) 1713 break 1714 1715 # issue a new connection ID 1716 self._replenish_connection_ids() 1717 1718 def _handle_stop_sending_frame( 1719 self, context: QuicReceiveContext, frame_type: int, buf: Buffer 1720 ) -> None: 1721 """ 1722 Handle a STOP_SENDING frame. 1723 """ 1724 stream_id = buf.pull_uint_var() 1725 error_code = buf.pull_uint_var() # application error code 1726 1727 # log frame 1728 if self._quic_logger is not None: 1729 context.quic_logger_frames.append( 1730 self._quic_logger.encode_stop_sending_frame( 1731 error_code=error_code, stream_id=stream_id 1732 ) 1733 ) 1734 1735 # check stream direction 1736 self._assert_stream_can_send(frame_type, stream_id) 1737 1738 self._get_or_create_stream(frame_type, stream_id) 1739 1740 def _handle_stream_frame( 1741 self, context: QuicReceiveContext, frame_type: int, buf: Buffer 1742 ) -> None: 1743 """ 1744 Handle a STREAM frame. 1745 """ 1746 stream_id = buf.pull_uint_var() 1747 if frame_type & 4: 1748 offset = buf.pull_uint_var() 1749 else: 1750 offset = 0 1751 if frame_type & 2: 1752 length = buf.pull_uint_var() 1753 else: 1754 length = buf.capacity - buf.tell() 1755 if offset + length > UINT_VAR_MAX: 1756 raise QuicConnectionError( 1757 error_code=QuicErrorCode.FRAME_ENCODING_ERROR, 1758 frame_type=frame_type, 1759 reason_phrase="offset + length cannot exceed 2^62 - 1", 1760 ) 1761 frame = QuicStreamFrame( 1762 offset=offset, data=buf.pull_bytes(length), fin=bool(frame_type & 1) 1763 ) 1764 1765 # log frame 1766 if self._quic_logger is not None: 1767 context.quic_logger_frames.append( 1768 self._quic_logger.encode_stream_frame(frame, stream_id=stream_id) 1769 ) 1770 1771 # check stream direction 1772 self._assert_stream_can_receive(frame_type, stream_id) 1773 1774 # check flow-control limits 1775 stream = self._get_or_create_stream(frame_type, stream_id) 1776 if offset + length > stream.max_stream_data_local: 1777 raise QuicConnectionError( 1778 error_code=QuicErrorCode.FLOW_CONTROL_ERROR, 1779 frame_type=frame_type, 1780 reason_phrase="Over stream data limit", 1781 ) 1782 newly_received = max(0, offset + length - stream._recv_highest) 1783 if self._local_max_data_used + newly_received > self._local_max_data: 1784 raise QuicConnectionError( 1785 error_code=QuicErrorCode.FLOW_CONTROL_ERROR, 1786 frame_type=frame_type, 1787 reason_phrase="Over connection data limit", 1788 ) 1789 1790 event = stream.add_frame(frame) 1791 if event is not None: 1792 self._events.append(event) 1793 self._local_max_data_used += newly_received 1794 1795 def _handle_stream_data_blocked_frame( 1796 self, context: QuicReceiveContext, frame_type: int, buf: Buffer 1797 ) -> None: 1798 """ 1799 Handle a STREAM_DATA_BLOCKED frame. 1800 """ 1801 stream_id = buf.pull_uint_var() 1802 limit = buf.pull_uint_var() 1803 1804 # log frame 1805 if self._quic_logger is not None: 1806 context.quic_logger_frames.append( 1807 self._quic_logger.encode_stream_data_blocked_frame( 1808 limit=limit, stream_id=stream_id 1809 ) 1810 ) 1811 1812 # check stream direction 1813 self._assert_stream_can_receive(frame_type, stream_id) 1814 1815 self._get_or_create_stream(frame_type, stream_id) 1816 1817 def _handle_streams_blocked_frame( 1818 self, context: QuicReceiveContext, frame_type: int, buf: Buffer 1819 ) -> None: 1820 """ 1821 Handle a STREAMS_BLOCKED frame. 1822 """ 1823 limit = buf.pull_uint_var() 1824 1825 # log frame 1826 if self._quic_logger is not None: 1827 context.quic_logger_frames.append( 1828 self._quic_logger.encode_streams_blocked_frame( 1829 is_unidirectional=frame_type == QuicFrameType.STREAMS_BLOCKED_UNI, 1830 limit=limit, 1831 ) 1832 ) 1833 1834 def _on_ack_delivery( 1835 self, delivery: QuicDeliveryState, space: QuicPacketSpace, highest_acked: int 1836 ) -> None: 1837 """ 1838 Callback when an ACK frame is acknowledged or lost. 1839 """ 1840 if delivery == QuicDeliveryState.ACKED: 1841 space.ack_queue.subtract(0, highest_acked + 1) 1842 1843 def _on_handshake_done_delivery(self, delivery: QuicDeliveryState) -> None: 1844 """ 1845 Callback when a HANDSHAKE_DONE frame is acknowledged or lost. 1846 """ 1847 if delivery != QuicDeliveryState.ACKED: 1848 self._handshake_done_pending = True 1849 1850 def _on_max_data_delivery(self, delivery: QuicDeliveryState) -> None: 1851 """ 1852 Callback when a MAX_DATA frame is acknowledged or lost. 1853 """ 1854 if delivery != QuicDeliveryState.ACKED: 1855 self._local_max_data_sent = 0 1856 1857 def _on_max_stream_data_delivery( 1858 self, delivery: QuicDeliveryState, stream: QuicStream 1859 ) -> None: 1860 """ 1861 Callback when a MAX_STREAM_DATA frame is acknowledged or lost. 1862 """ 1863 if delivery != QuicDeliveryState.ACKED: 1864 stream.max_stream_data_local_sent = 0 1865 1866 def _on_new_connection_id_delivery( 1867 self, delivery: QuicDeliveryState, connection_id: QuicConnectionId 1868 ) -> None: 1869 """ 1870 Callback when a NEW_CONNECTION_ID frame is acknowledged or lost. 1871 """ 1872 if delivery != QuicDeliveryState.ACKED: 1873 connection_id.was_sent = False 1874 1875 def _on_ping_delivery( 1876 self, delivery: QuicDeliveryState, uids: Sequence[int] 1877 ) -> None: 1878 """ 1879 Callback when a PING frame is acknowledged or lost. 1880 """ 1881 if delivery == QuicDeliveryState.ACKED: 1882 self._logger.debug("Received PING%s response", "" if uids else " (probe)") 1883 for uid in uids: 1884 self._events.append(events.PingAcknowledged(uid=uid)) 1885 else: 1886 self._ping_pending.extend(uids) 1887 1888 def _on_retire_connection_id_delivery( 1889 self, delivery: QuicDeliveryState, sequence_number: int 1890 ) -> None: 1891 """ 1892 Callback when a RETIRE_CONNECTION_ID frame is acknowledged or lost. 1893 """ 1894 if delivery != QuicDeliveryState.ACKED: 1895 self._retire_connection_ids.append(sequence_number) 1896 1897 def _payload_received( 1898 self, context: QuicReceiveContext, plain: bytes 1899 ) -> Tuple[bool, bool]: 1900 """ 1901 Handle a QUIC packet payload. 1902 """ 1903 buf = Buffer(data=plain) 1904 1905 is_ack_eliciting = False 1906 is_probing = None 1907 while not buf.eof(): 1908 frame_type = buf.pull_uint_var() 1909 1910 # check frame type is known 1911 try: 1912 frame_handler, frame_epochs = self.__frame_handlers[frame_type] 1913 except KeyError: 1914 raise QuicConnectionError( 1915 error_code=QuicErrorCode.PROTOCOL_VIOLATION, 1916 frame_type=frame_type, 1917 reason_phrase="Unknown frame type", 1918 ) 1919 1920 # check frame is allowed for the epoch 1921 if context.epoch not in frame_epochs: 1922 raise QuicConnectionError( 1923 error_code=QuicErrorCode.PROTOCOL_VIOLATION, 1924 frame_type=frame_type, 1925 reason_phrase="Unexpected frame type", 1926 ) 1927 1928 # handle the frame 1929 try: 1930 frame_handler(context, frame_type, buf) 1931 except BufferReadError: 1932 raise QuicConnectionError( 1933 error_code=QuicErrorCode.FRAME_ENCODING_ERROR, 1934 frame_type=frame_type, 1935 reason_phrase="Failed to parse frame", 1936 ) 1937 1938 # update ACK only / probing flags 1939 if frame_type not in NON_ACK_ELICITING_FRAME_TYPES: 1940 is_ack_eliciting = True 1941 1942 if frame_type not in PROBING_FRAME_TYPES: 1943 is_probing = False 1944 elif is_probing is None: 1945 is_probing = True 1946 1947 return is_ack_eliciting, bool(is_probing) 1948 1949 def _replenish_connection_ids(self) -> None: 1950 """ 1951 Generate new connection IDs. 1952 """ 1953 while len(self._host_cids) < min(8, self._remote_active_connection_id_limit): 1954 self._host_cids.append( 1955 QuicConnectionId( 1956 cid=os.urandom(self._configuration.connection_id_length), 1957 sequence_number=self._host_cid_seq, 1958 stateless_reset_token=os.urandom(16), 1959 ) 1960 ) 1961 self._host_cid_seq += 1 1962 1963 def _push_crypto_data(self) -> None: 1964 for epoch, buf in self._crypto_buffers.items(): 1965 self._crypto_streams[epoch].write(buf.data) 1966 buf.seek(0) 1967 1968 def _send_probe(self) -> None: 1969 self._probe_pending = True 1970 1971 def _parse_transport_parameters( 1972 self, data: bytes, from_session_ticket: bool = False 1973 ) -> None: 1974 quic_transport_parameters = pull_quic_transport_parameters( 1975 Buffer(data=data), protocol_version=self._version 1976 ) 1977 1978 # log event 1979 if self._quic_logger is not None and not from_session_ticket: 1980 self._quic_logger.log_event( 1981 category="transport", 1982 event="parameters_set", 1983 data=self._quic_logger.encode_transport_parameters( 1984 owner="remote", parameters=quic_transport_parameters 1985 ), 1986 ) 1987 1988 # validate remote parameters 1989 if ( 1990 self._is_client 1991 and not from_session_ticket 1992 and ( 1993 quic_transport_parameters.original_connection_id 1994 != self._original_connection_id 1995 ) 1996 ): 1997 raise QuicConnectionError( 1998 error_code=QuicErrorCode.TRANSPORT_PARAMETER_ERROR, 1999 frame_type=QuicFrameType.CRYPTO, 2000 reason_phrase="original_connection_id does not match", 2001 ) 2002 2003 # store remote parameters 2004 if quic_transport_parameters.ack_delay_exponent is not None: 2005 self._remote_ack_delay_exponent = self._remote_ack_delay_exponent 2006 if quic_transport_parameters.active_connection_id_limit is not None: 2007 self._remote_active_connection_id_limit = ( 2008 quic_transport_parameters.active_connection_id_limit 2009 ) 2010 if quic_transport_parameters.idle_timeout is not None: 2011 self._remote_idle_timeout = quic_transport_parameters.idle_timeout / 1000.0 2012 if quic_transport_parameters.max_ack_delay is not None: 2013 self._loss.max_ack_delay = quic_transport_parameters.max_ack_delay / 1000.0 2014 self._remote_max_datagram_frame_size = ( 2015 quic_transport_parameters.max_datagram_frame_size 2016 ) 2017 for param in [ 2018 "max_data", 2019 "max_stream_data_bidi_local", 2020 "max_stream_data_bidi_remote", 2021 "max_stream_data_uni", 2022 "max_streams_bidi", 2023 "max_streams_uni", 2024 ]: 2025 value = getattr(quic_transport_parameters, "initial_" + param) 2026 if value is not None: 2027 setattr(self, "_remote_" + param, value) 2028 2029 def _serialize_transport_parameters(self) -> bytes: 2030 quic_transport_parameters = QuicTransportParameters( 2031 ack_delay_exponent=self._local_ack_delay_exponent, 2032 active_connection_id_limit=self._local_active_connection_id_limit, 2033 idle_timeout=int(self._configuration.idle_timeout * 1000), 2034 initial_max_data=self._local_max_data, 2035 initial_max_stream_data_bidi_local=self._local_max_stream_data_bidi_local, 2036 initial_max_stream_data_bidi_remote=self._local_max_stream_data_bidi_remote, 2037 initial_max_stream_data_uni=self._local_max_stream_data_uni, 2038 initial_max_streams_bidi=self._local_max_streams_bidi, 2039 initial_max_streams_uni=self._local_max_streams_uni, 2040 max_ack_delay=25, 2041 max_datagram_frame_size=self._configuration.max_datagram_frame_size, 2042 quantum_readiness=b"Q" * 1200 2043 if self._configuration.quantum_readiness_test 2044 else None, 2045 ) 2046 if not self._is_client: 2047 quic_transport_parameters.original_connection_id = ( 2048 self._original_connection_id 2049 ) 2050 2051 # log event 2052 if self._quic_logger is not None: 2053 self._quic_logger.log_event( 2054 category="transport", 2055 event="parameters_set", 2056 data=self._quic_logger.encode_transport_parameters( 2057 owner="local", parameters=quic_transport_parameters 2058 ), 2059 ) 2060 2061 buf = Buffer(capacity=3 * PACKET_MAX_SIZE) 2062 push_quic_transport_parameters( 2063 buf, quic_transport_parameters, protocol_version=self._version 2064 ) 2065 return buf.data 2066 2067 def _set_state(self, state: QuicConnectionState) -> None: 2068 self._logger.debug("%s -> %s", self._state, state) 2069 self._state = state 2070 2071 def _stream_can_receive(self, stream_id: int) -> bool: 2072 return stream_is_client_initiated( 2073 stream_id 2074 ) != self._is_client or not stream_is_unidirectional(stream_id) 2075 2076 def _stream_can_send(self, stream_id: int) -> bool: 2077 return stream_is_client_initiated( 2078 stream_id 2079 ) == self._is_client or not stream_is_unidirectional(stream_id) 2080 2081 def _unblock_streams(self, is_unidirectional: bool) -> None: 2082 if is_unidirectional: 2083 max_stream_data_remote = self._remote_max_stream_data_uni 2084 max_streams = self._remote_max_streams_uni 2085 streams_blocked = self._streams_blocked_uni 2086 else: 2087 max_stream_data_remote = self._remote_max_stream_data_bidi_remote 2088 max_streams = self._remote_max_streams_bidi 2089 streams_blocked = self._streams_blocked_bidi 2090 2091 while streams_blocked and streams_blocked[0].stream_id // 4 < max_streams: 2092 stream = streams_blocked.pop(0) 2093 stream.is_blocked = False 2094 stream.max_stream_data_remote = max_stream_data_remote 2095 2096 if not self._streams_blocked_bidi and not self._streams_blocked_uni: 2097 self._streams_blocked_pending = False 2098 2099 def _update_traffic_key( 2100 self, 2101 direction: tls.Direction, 2102 epoch: tls.Epoch, 2103 cipher_suite: tls.CipherSuite, 2104 secret: bytes, 2105 ) -> None: 2106 """ 2107 Callback which is invoked by the TLS engine when new traffic keys are 2108 available. 2109 """ 2110 secrets_log_file = self._configuration.secrets_log_file 2111 if secrets_log_file is not None: 2112 label_row = self._is_client == (direction == tls.Direction.DECRYPT) 2113 label = SECRETS_LABELS[label_row][epoch.value] 2114 secrets_log_file.write( 2115 "%s %s %s\n" % (label, self.tls.client_random.hex(), secret.hex()) 2116 ) 2117 secrets_log_file.flush() 2118 2119 crypto = self._cryptos[epoch] 2120 if direction == tls.Direction.ENCRYPT: 2121 crypto.send.setup( 2122 cipher_suite=cipher_suite, secret=secret, version=self._version 2123 ) 2124 else: 2125 crypto.recv.setup( 2126 cipher_suite=cipher_suite, secret=secret, version=self._version 2127 ) 2128 2129 def _write_application( 2130 self, builder: QuicPacketBuilder, network_path: QuicNetworkPath, now: float 2131 ) -> None: 2132 crypto_stream: Optional[QuicStream] = None 2133 if self._cryptos[tls.Epoch.ONE_RTT].send.is_valid(): 2134 crypto = self._cryptos[tls.Epoch.ONE_RTT] 2135 crypto_stream = self._crypto_streams[tls.Epoch.ONE_RTT] 2136 packet_type = PACKET_TYPE_ONE_RTT 2137 elif self._cryptos[tls.Epoch.ZERO_RTT].send.is_valid(): 2138 crypto = self._cryptos[tls.Epoch.ZERO_RTT] 2139 packet_type = PACKET_TYPE_ZERO_RTT 2140 else: 2141 return 2142 space = self._spaces[tls.Epoch.ONE_RTT] 2143 2144 while True: 2145 # apply pacing, except if we have ACKs to send 2146 if space.ack_at is None or space.ack_at >= now: 2147 self._pacing_at = self._loss._pacer.next_send_time(now=now) 2148 if self._pacing_at is not None: 2149 break 2150 builder.start_packet(packet_type, crypto) 2151 2152 if self._handshake_complete: 2153 # ACK 2154 if space.ack_at is not None and space.ack_at <= now: 2155 self._write_ack_frame(builder=builder, space=space, now=now) 2156 2157 # HANDSHAKE_DONE 2158 if self._handshake_done_pending: 2159 self._write_handshake_done_frame(builder=builder) 2160 self._handshake_done_pending = False 2161 2162 # PATH CHALLENGE 2163 if ( 2164 not network_path.is_validated 2165 and network_path.local_challenge is None 2166 ): 2167 challenge = os.urandom(8) 2168 self._write_path_challenge_frame( 2169 builder=builder, challenge=challenge 2170 ) 2171 network_path.local_challenge = challenge 2172 2173 # PATH RESPONSE 2174 if network_path.remote_challenge is not None: 2175 self._write_path_response_frame( 2176 builder=builder, challenge=network_path.remote_challenge 2177 ) 2178 network_path.remote_challenge = None 2179 2180 # NEW_CONNECTION_ID 2181 for connection_id in self._host_cids: 2182 if not connection_id.was_sent: 2183 self._write_new_connection_id_frame( 2184 builder=builder, connection_id=connection_id 2185 ) 2186 2187 # RETIRE_CONNECTION_ID 2188 while self._retire_connection_ids: 2189 sequence_number = self._retire_connection_ids.pop(0) 2190 self._write_retire_connection_id_frame( 2191 builder=builder, sequence_number=sequence_number 2192 ) 2193 2194 # STREAMS_BLOCKED 2195 if self._streams_blocked_pending: 2196 if self._streams_blocked_bidi: 2197 self._write_streams_blocked_frame( 2198 builder=builder, 2199 frame_type=QuicFrameType.STREAMS_BLOCKED_BIDI, 2200 limit=self._remote_max_streams_bidi, 2201 ) 2202 if self._streams_blocked_uni: 2203 self._write_streams_blocked_frame( 2204 builder=builder, 2205 frame_type=QuicFrameType.STREAMS_BLOCKED_UNI, 2206 limit=self._remote_max_streams_uni, 2207 ) 2208 self._streams_blocked_pending = False 2209 2210 # MAX_DATA 2211 self._write_connection_limits(builder=builder, space=space) 2212 2213 # stream-level limits 2214 for stream in self._streams.values(): 2215 self._write_stream_limits(builder=builder, space=space, stream=stream) 2216 2217 # PING (user-request) 2218 if self._ping_pending: 2219 self._write_ping_frame(builder, self._ping_pending) 2220 self._ping_pending.clear() 2221 2222 # PING (probe) 2223 if self._probe_pending: 2224 self._write_ping_frame(builder, comment="probe") 2225 self._probe_pending = False 2226 2227 # CRYPTO 2228 if crypto_stream is not None and not crypto_stream.send_buffer_is_empty: 2229 self._write_crypto_frame( 2230 builder=builder, space=space, stream=crypto_stream 2231 ) 2232 2233 # DATAGRAM 2234 while self._datagrams_pending: 2235 try: 2236 self._write_datagram_frame( 2237 builder=builder, 2238 data=self._datagrams_pending[0], 2239 frame_type=QuicFrameType.DATAGRAM_WITH_LENGTH, 2240 ) 2241 self._datagrams_pending.popleft() 2242 except QuicPacketBuilderStop: 2243 break 2244 2245 # STREAM 2246 for stream in self._streams.values(): 2247 if not stream.is_blocked and not stream.send_buffer_is_empty: 2248 self._remote_max_data_used += self._write_stream_frame( 2249 builder=builder, 2250 space=space, 2251 stream=stream, 2252 max_offset=min( 2253 stream._send_highest 2254 + self._remote_max_data 2255 - self._remote_max_data_used, 2256 stream.max_stream_data_remote, 2257 ), 2258 ) 2259 2260 if builder.packet_is_empty: 2261 break 2262 else: 2263 self._loss._pacer.update_after_send(now=now) 2264 2265 def _write_handshake( 2266 self, builder: QuicPacketBuilder, epoch: tls.Epoch, now: float 2267 ) -> None: 2268 crypto = self._cryptos[epoch] 2269 if not crypto.send.is_valid(): 2270 return 2271 2272 crypto_stream = self._crypto_streams[epoch] 2273 space = self._spaces[epoch] 2274 2275 while True: 2276 if epoch == tls.Epoch.INITIAL: 2277 packet_type = PACKET_TYPE_INITIAL 2278 else: 2279 packet_type = PACKET_TYPE_HANDSHAKE 2280 builder.start_packet(packet_type, crypto) 2281 2282 # ACK 2283 if space.ack_at is not None: 2284 self._write_ack_frame(builder=builder, space=space, now=now) 2285 2286 # CRYPTO 2287 if not crypto_stream.send_buffer_is_empty: 2288 if self._write_crypto_frame( 2289 builder=builder, space=space, stream=crypto_stream 2290 ): 2291 self._probe_pending = False 2292 2293 # PING (probe) 2294 if ( 2295 self._probe_pending 2296 and epoch == tls.Epoch.HANDSHAKE 2297 and not self._handshake_complete 2298 ): 2299 self._write_ping_frame(builder, comment="probe") 2300 self._probe_pending = False 2301 2302 if builder.packet_is_empty: 2303 break 2304 2305 def _write_ack_frame( 2306 self, builder: QuicPacketBuilder, space: QuicPacketSpace, now: float 2307 ) -> None: 2308 # calculate ACK delay 2309 ack_delay = now - space.largest_received_time 2310 ack_delay_encoded = int(ack_delay * 1000000) >> self._local_ack_delay_exponent 2311 2312 buf = builder.start_frame( 2313 QuicFrameType.ACK, 2314 capacity=ACK_FRAME_CAPACITY, 2315 handler=self._on_ack_delivery, 2316 handler_args=(space, space.largest_received_packet), 2317 ) 2318 ranges = push_ack_frame(buf, space.ack_queue, ack_delay_encoded) 2319 space.ack_at = None 2320 2321 # log frame 2322 if self._quic_logger is not None: 2323 builder.quic_logger_frames.append( 2324 self._quic_logger.encode_ack_frame( 2325 ranges=space.ack_queue, delay=ack_delay 2326 ) 2327 ) 2328 2329 # check if we need to trigger an ACK-of-ACK 2330 if ranges > 1 and builder.packet_number % 8 == 0: 2331 self._write_ping_frame(builder, comment="ACK-of-ACK trigger") 2332 2333 def _write_connection_close_frame( 2334 self, 2335 builder: QuicPacketBuilder, 2336 error_code: int, 2337 frame_type: Optional[int], 2338 reason_phrase: str, 2339 ) -> None: 2340 reason_bytes = reason_phrase.encode("utf8") 2341 reason_length = len(reason_bytes) 2342 2343 if frame_type is None: 2344 buf = builder.start_frame( 2345 QuicFrameType.APPLICATION_CLOSE, 2346 capacity=APPLICATION_CLOSE_FRAME_CAPACITY + reason_length, 2347 ) 2348 buf.push_uint_var(error_code) 2349 buf.push_uint_var(reason_length) 2350 buf.push_bytes(reason_bytes) 2351 else: 2352 buf = builder.start_frame( 2353 QuicFrameType.TRANSPORT_CLOSE, 2354 capacity=TRANSPORT_CLOSE_FRAME_CAPACITY + reason_length, 2355 ) 2356 buf.push_uint_var(error_code) 2357 buf.push_uint_var(frame_type) 2358 buf.push_uint_var(reason_length) 2359 buf.push_bytes(reason_bytes) 2360 2361 # log frame 2362 if self._quic_logger is not None: 2363 builder.quic_logger_frames.append( 2364 self._quic_logger.encode_connection_close_frame( 2365 error_code=error_code, 2366 frame_type=frame_type, 2367 reason_phrase=reason_phrase, 2368 ) 2369 ) 2370 2371 def _write_connection_limits( 2372 self, builder: QuicPacketBuilder, space: QuicPacketSpace 2373 ) -> None: 2374 """ 2375 Raise MAX_DATA if needed. 2376 """ 2377 if self._local_max_data_used * 2 > self._local_max_data: 2378 self._local_max_data *= 2 2379 self._logger.debug("Local max_data raised to %d", self._local_max_data) 2380 if self._local_max_data_sent != self._local_max_data: 2381 buf = builder.start_frame( 2382 QuicFrameType.MAX_DATA, 2383 capacity=MAX_DATA_FRAME_CAPACITY, 2384 handler=self._on_max_data_delivery, 2385 ) 2386 buf.push_uint_var(self._local_max_data) 2387 self._local_max_data_sent = self._local_max_data 2388 2389 # log frame 2390 if self._quic_logger is not None: 2391 builder.quic_logger_frames.append( 2392 self._quic_logger.encode_max_data_frame(self._local_max_data) 2393 ) 2394 2395 def _write_crypto_frame( 2396 self, builder: QuicPacketBuilder, space: QuicPacketSpace, stream: QuicStream 2397 ) -> bool: 2398 frame_overhead = 3 + size_uint_var(stream.next_send_offset) 2399 frame = stream.get_frame(builder.remaining_flight_space - frame_overhead) 2400 if frame is not None: 2401 buf = builder.start_frame( 2402 QuicFrameType.CRYPTO, 2403 capacity=frame_overhead, 2404 handler=stream.on_data_delivery, 2405 handler_args=(frame.offset, frame.offset + len(frame.data)), 2406 ) 2407 buf.push_uint_var(frame.offset) 2408 buf.push_uint16(len(frame.data) | 0x4000) 2409 buf.push_bytes(frame.data) 2410 2411 # log frame 2412 if self._quic_logger is not None: 2413 builder.quic_logger_frames.append( 2414 self._quic_logger.encode_crypto_frame(frame) 2415 ) 2416 return True 2417 2418 return False 2419 2420 def _write_datagram_frame( 2421 self, builder: QuicPacketBuilder, data: bytes, frame_type: QuicFrameType 2422 ) -> bool: 2423 """ 2424 Write a DATAGRAM frame. 2425 2426 Returns True if the frame was processed, False otherwise. 2427 """ 2428 assert frame_type == QuicFrameType.DATAGRAM_WITH_LENGTH 2429 length = len(data) 2430 frame_size = 1 + size_uint_var(length) + length 2431 2432 buf = builder.start_frame(frame_type, capacity=frame_size) 2433 buf.push_uint_var(length) 2434 buf.push_bytes(data) 2435 2436 # log frame 2437 if self._quic_logger is not None: 2438 builder.quic_logger_frames.append( 2439 self._quic_logger.encode_datagram_frame(length=length) 2440 ) 2441 2442 return True 2443 2444 def _write_handshake_done_frame(self, builder: QuicPacketBuilder) -> None: 2445 builder.start_frame( 2446 QuicFrameType.HANDSHAKE_DONE, 2447 capacity=HANDSHAKE_DONE_FRAME_CAPACITY, 2448 handler=self._on_handshake_done_delivery, 2449 ) 2450 2451 # log frame 2452 if self._quic_logger is not None: 2453 builder.quic_logger_frames.append( 2454 self._quic_logger.encode_handshake_done_frame() 2455 ) 2456 2457 def _write_new_connection_id_frame( 2458 self, builder: QuicPacketBuilder, connection_id: QuicConnectionId 2459 ) -> None: 2460 retire_prior_to = 0 # FIXME 2461 2462 buf = builder.start_frame( 2463 QuicFrameType.NEW_CONNECTION_ID, 2464 capacity=NEW_CONNECTION_ID_FRAME_CAPACITY, 2465 handler=self._on_new_connection_id_delivery, 2466 handler_args=(connection_id,), 2467 ) 2468 buf.push_uint_var(connection_id.sequence_number) 2469 buf.push_uint_var(retire_prior_to) 2470 buf.push_uint8(len(connection_id.cid)) 2471 buf.push_bytes(connection_id.cid) 2472 buf.push_bytes(connection_id.stateless_reset_token) 2473 2474 connection_id.was_sent = True 2475 self._events.append(events.ConnectionIdIssued(connection_id=connection_id.cid)) 2476 2477 # log frame 2478 if self._quic_logger is not None: 2479 builder.quic_logger_frames.append( 2480 self._quic_logger.encode_new_connection_id_frame( 2481 connection_id=connection_id.cid, 2482 retire_prior_to=retire_prior_to, 2483 sequence_number=connection_id.sequence_number, 2484 stateless_reset_token=connection_id.stateless_reset_token, 2485 ) 2486 ) 2487 2488 def _write_path_challenge_frame( 2489 self, builder: QuicPacketBuilder, challenge: bytes 2490 ) -> None: 2491 buf = builder.start_frame( 2492 QuicFrameType.PATH_CHALLENGE, capacity=PATH_CHALLENGE_FRAME_CAPACITY 2493 ) 2494 buf.push_bytes(challenge) 2495 2496 # log frame 2497 if self._quic_logger is not None: 2498 builder.quic_logger_frames.append( 2499 self._quic_logger.encode_path_challenge_frame(data=challenge) 2500 ) 2501 2502 def _write_path_response_frame( 2503 self, builder: QuicPacketBuilder, challenge: bytes 2504 ) -> None: 2505 buf = builder.start_frame( 2506 QuicFrameType.PATH_RESPONSE, capacity=PATH_RESPONSE_FRAME_CAPACITY 2507 ) 2508 buf.push_bytes(challenge) 2509 2510 # log frame 2511 if self._quic_logger is not None: 2512 builder.quic_logger_frames.append( 2513 self._quic_logger.encode_path_response_frame(data=challenge) 2514 ) 2515 2516 def _write_ping_frame( 2517 self, builder: QuicPacketBuilder, uids: List[int] = [], comment="" 2518 ): 2519 builder.start_frame( 2520 QuicFrameType.PING, 2521 capacity=PING_FRAME_CAPACITY, 2522 handler=self._on_ping_delivery, 2523 handler_args=(tuple(uids),), 2524 ) 2525 self._logger.debug( 2526 "Sending PING%s in packet %d", 2527 " (%s)" % comment if comment else "", 2528 builder.packet_number, 2529 ) 2530 2531 # log frame 2532 if self._quic_logger is not None: 2533 builder.quic_logger_frames.append(self._quic_logger.encode_ping_frame()) 2534 2535 def _write_retire_connection_id_frame( 2536 self, builder: QuicPacketBuilder, sequence_number: int 2537 ) -> None: 2538 buf = builder.start_frame( 2539 QuicFrameType.RETIRE_CONNECTION_ID, 2540 capacity=RETIRE_CONNECTION_ID_CAPACITY, 2541 handler=self._on_retire_connection_id_delivery, 2542 handler_args=(sequence_number,), 2543 ) 2544 buf.push_uint_var(sequence_number) 2545 2546 # log frame 2547 if self._quic_logger is not None: 2548 builder.quic_logger_frames.append( 2549 self._quic_logger.encode_retire_connection_id_frame(sequence_number) 2550 ) 2551 2552 def _write_stream_frame( 2553 self, 2554 builder: QuicPacketBuilder, 2555 space: QuicPacketSpace, 2556 stream: QuicStream, 2557 max_offset: int, 2558 ) -> int: 2559 # the frame data size is constrained by our peer's MAX_DATA and 2560 # the space available in the current packet 2561 frame_overhead = ( 2562 3 2563 + size_uint_var(stream.stream_id) 2564 + (size_uint_var(stream.next_send_offset) if stream.next_send_offset else 0) 2565 ) 2566 previous_send_highest = stream._send_highest 2567 frame = stream.get_frame( 2568 builder.remaining_flight_space - frame_overhead, max_offset 2569 ) 2570 2571 if frame is not None: 2572 frame_type = QuicFrameType.STREAM_BASE | 2 # length 2573 if frame.offset: 2574 frame_type |= 4 2575 if frame.fin: 2576 frame_type |= 1 2577 buf = builder.start_frame( 2578 frame_type, 2579 capacity=frame_overhead, 2580 handler=stream.on_data_delivery, 2581 handler_args=(frame.offset, frame.offset + len(frame.data)), 2582 ) 2583 buf.push_uint_var(stream.stream_id) 2584 if frame.offset: 2585 buf.push_uint_var(frame.offset) 2586 buf.push_uint16(len(frame.data) | 0x4000) 2587 buf.push_bytes(frame.data) 2588 2589 # log frame 2590 if self._quic_logger is not None: 2591 builder.quic_logger_frames.append( 2592 self._quic_logger.encode_stream_frame( 2593 frame, stream_id=stream.stream_id 2594 ) 2595 ) 2596 2597 return stream._send_highest - previous_send_highest 2598 else: 2599 return 0 2600 2601 def _write_stream_limits( 2602 self, builder: QuicPacketBuilder, space: QuicPacketSpace, stream: QuicStream 2603 ) -> None: 2604 """ 2605 Raise MAX_STREAM_DATA if needed. 2606 2607 The only case where `stream.max_stream_data_local` is zero is for 2608 locally created unidirectional streams. We skip such streams to avoid 2609 spurious logging. 2610 """ 2611 if ( 2612 stream.max_stream_data_local 2613 and stream._recv_highest * 2 > stream.max_stream_data_local 2614 ): 2615 stream.max_stream_data_local *= 2 2616 self._logger.debug( 2617 "Stream %d local max_stream_data raised to %d", 2618 stream.stream_id, 2619 stream.max_stream_data_local, 2620 ) 2621 if stream.max_stream_data_local_sent != stream.max_stream_data_local: 2622 buf = builder.start_frame( 2623 QuicFrameType.MAX_STREAM_DATA, 2624 capacity=MAX_STREAM_DATA_FRAME_CAPACITY, 2625 handler=self._on_max_stream_data_delivery, 2626 handler_args=(stream,), 2627 ) 2628 buf.push_uint_var(stream.stream_id) 2629 buf.push_uint_var(stream.max_stream_data_local) 2630 stream.max_stream_data_local_sent = stream.max_stream_data_local 2631 2632 # log frame 2633 if self._quic_logger is not None: 2634 builder.quic_logger_frames.append( 2635 self._quic_logger.encode_max_stream_data_frame( 2636 maximum=stream.max_stream_data_local, stream_id=stream.stream_id 2637 ) 2638 ) 2639 2640 def _write_streams_blocked_frame( 2641 self, builder: QuicPacketBuilder, frame_type: QuicFrameType, limit: int 2642 ) -> None: 2643 buf = builder.start_frame(frame_type, capacity=STREAMS_BLOCKED_CAPACITY) 2644 buf.push_uint_var(limit) 2645 2646 # log frame 2647 if self._quic_logger is not None: 2648 builder.quic_logger_frames.append( 2649 self._quic_logger.encode_streams_blocked_frame( 2650 is_unidirectional=frame_type == QuicFrameType.STREAMS_BLOCKED_UNI, 2651 limit=limit, 2652 ) 2653 ) 2654