1from dataclasses import dataclass, field
2from enum import Enum
3from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple
4
5from ..buffer import Buffer, size_uint_var
6from ..tls import Epoch
7from .crypto import CryptoPair
8from .logger import QuicLoggerTrace
9from .packet import (
10    NON_ACK_ELICITING_FRAME_TYPES,
11    NON_IN_FLIGHT_FRAME_TYPES,
12    PACKET_NUMBER_MAX_SIZE,
13    PACKET_TYPE_HANDSHAKE,
14    PACKET_TYPE_INITIAL,
15    PACKET_TYPE_MASK,
16    QuicFrameType,
17    is_long_header,
18)
19
# Capacity of the buffer each datagram is assembled in, and therefore the
# largest datagram the builder emits (it can only shrink via byte budgets).
PACKET_MAX_SIZE: int = 1280
# Presumably the encoded width of the long-header Length field; note the
# builder writes that field as a fixed 2-byte varint (`length | 0x4000`).
# NOTE(review): not referenced by the code in this chunk — confirm usage.
PACKET_LENGTH_SEND_SIZE: int = 2
# Number of bytes used to encode every packet number the builder writes.
PACKET_NUMBER_SEND_SIZE: int = 2


# Callback type stored on a sent packet alongside its arguments; the exact
# arguments it receives are decided by whoever invokes it (not shown here).
QuicDeliveryHandler = Callable[..., None]
26
27
class QuicDeliveryState(Enum):
    """
    Delivery outcome values used for sent data.

    Member meanings follow their names (acknowledged, lost, expired); the
    code that assigns them lives outside this chunk.
    """

    # Numbered 0, 1, 2 in declaration order.
    ACKED, LOST, EXPIRED = range(3)
32
33
@dataclass
class QuicSentPacket:
    """
    Bookkeeping record for one packet emitted by `QuicPacketBuilder`.
    """

    # encryption epoch, derived by the builder from the packet type
    epoch: Epoch
    # set once a frame outside NON_IN_FLIGHT_FRAME_TYPES (or padding) is added
    in_flight: bool
    # set once a frame outside NON_ACK_ELICITING_FRAME_TYPES is added
    is_ack_eliciting: bool
    # set once a CRYPTO frame is added
    is_crypto_packet: bool
    packet_number: int
    packet_type: int
    # not filled in by the builder; presumably set by the sender — confirm
    sent_time: Optional[float] = None
    # size of the encrypted packet, filled in when the packet is finished
    sent_bytes: int = 0

    # (handler, handler_args) pairs registered while writing frames
    delivery_handlers: List[Tuple[QuicDeliveryHandler, Any]] = field(default_factory=list)
    # qlog frame descriptions for this packet
    quic_logger_frames: List[Dict] = field(default_factory=list)
49
50
class QuicPacketBuilderStop(Exception):
    """
    Raised by `QuicPacketBuilder` when there is not enough room left to
    start the requested packet or frame.
    """
53
54
class QuicPacketBuilder:
    """
    Helper for building QUIC packets.

    Packets are opened with :meth:`start_packet`, filled via
    :meth:`start_frame` and coalesced into datagrams of at most
    ``PACKET_MAX_SIZE`` bytes; :meth:`flush` hands back the datagrams along
    with a :class:`QuicSentPacket` record for every packet actually emitted.

    The public attributes ``max_total_bytes`` and ``max_flight_bytes``, when
    not ``None``, cap the total bytes emitted and the bytes counted as in
    flight respectively; they are applied each time a new datagram starts.
    """

    def __init__(
        self,
        *,
        host_cid: bytes,
        peer_cid: bytes,
        version: int,
        is_client: bool,
        packet_number: int = 0,
        peer_token: bytes = b"",
        quic_logger: Optional[QuicLoggerTrace] = None,
        spin_bit: bool = False,
    ):
        # optional byte budgets, which the caller may set after construction
        self.max_flight_bytes: Optional[int] = None
        self.max_total_bytes: Optional[int] = None
        # qlog frame list of the packet being built (None between packets)
        self.quic_logger_frames: Optional[List[Dict]] = None

        self._host_cid = host_cid
        self._is_client = is_client
        self._peer_cid = peer_cid
        self._peer_token = peer_token
        self._quic_logger = quic_logger
        self._spin_bit = spin_bit
        self._version = version

        # assembled datagrams and packets
        self._datagrams: List[bytes] = []
        self._datagram_flight_bytes = 0
        self._datagram_init = True
        self._packets: List[QuicSentPacket] = []
        self._flight_bytes = 0
        self._total_bytes = 0

        # current packet
        self._header_size = 0
        self._packet: Optional[QuicSentPacket] = None
        self._packet_crypto: Optional[CryptoPair] = None
        self._packet_long_header = False
        self._packet_number = packet_number
        self._packet_start = 0
        self._packet_type = 0

        self._buffer = Buffer(PACKET_MAX_SIZE)
        self._buffer_capacity = PACKET_MAX_SIZE
        self._flight_capacity = PACKET_MAX_SIZE

    @property
    def packet_is_empty(self) -> bool:
        """
        Returns `True` if the current packet contains no frames yet.
        """
        assert self._packet is not None
        packet_size = self._buffer.tell() - self._packet_start
        return packet_size <= self._header_size

    @property
    def packet_number(self) -> int:
        """
        Returns the packet number for the next packet.
        """
        return self._packet_number

    @property
    def remaining_buffer_space(self) -> int:
        """
        Returns the remaining number of bytes which can be used in
        the current packet, leaving room for the AEAD tag.
        """
        return (
            self._buffer_capacity
            - self._buffer.tell()
            - self._packet_crypto.aead_tag_size
        )

    @property
    def remaining_flight_space(self) -> int:
        """
        Returns the remaining number of bytes which can be used in the
        current packet without exceeding the in-flight byte budget, leaving
        room for the AEAD tag. May be negative once the budget is exceeded.
        """
        return (
            self._flight_capacity
            - self._buffer.tell()
            - self._packet_crypto.aead_tag_size
        )

    def flush(self) -> Tuple[List[bytes], List[QuicSentPacket]]:
        """
        Ends the current packet (if any), closes the current datagram and
        returns ``(datagrams, packets)`` accumulated since the last flush.
        Both internal lists are reset.
        """
        if self._packet is not None:
            self._end_packet()
        self._flush_current_datagram()

        datagrams = self._datagrams
        packets = self._packets
        self._datagrams = []
        self._packets = []
        return datagrams, packets

    def start_frame(
        self,
        frame_type: int,
        capacity: int = 1,
        handler: Optional[QuicDeliveryHandler] = None,
        handler_args: Sequence[Any] = (),
    ) -> Buffer:
        """
        Starts a new frame in the current packet.

        :param frame_type: the frame type, written as a varint.
        :param capacity: minimum number of bytes of room required to start
            the frame; checked against the remaining buffer space and, for
            frames that count as in flight, the remaining flight space.
        :param handler: optional callback recorded on the current packet.
        :param handler_args: arguments stored alongside `handler`. Defaults
            to an immutable empty tuple so no state is shared between calls.
        :returns: the buffer into which the frame body should be written.
        :raises QuicPacketBuilderStop: if there is not enough room.
        """
        if self.remaining_buffer_space < capacity or (
            frame_type not in NON_IN_FLIGHT_FRAME_TYPES
            and self.remaining_flight_space < capacity
        ):
            raise QuicPacketBuilderStop

        self._buffer.push_uint_var(frame_type)
        if frame_type not in NON_ACK_ELICITING_FRAME_TYPES:
            self._packet.is_ack_eliciting = True
        if frame_type not in NON_IN_FLIGHT_FRAME_TYPES:
            self._packet.in_flight = True
        if frame_type == QuicFrameType.CRYPTO:
            self._packet.is_crypto_packet = True
        if handler is not None:
            self._packet.delivery_handlers.append((handler, handler_args))
        return self._buffer

    def start_packet(self, packet_type: int, crypto: CryptoPair) -> None:
        """
        Starts a new packet, ending the previous one if needed.

        :param packet_type: first-byte type bits of the packet.
        :param crypto: keys used to encrypt the packet when it ends.
        :raises QuicPacketBuilderStop: if the header does not fit.
        """
        buf = self._buffer

        # finish previous packet
        if self._packet is not None:
            self._end_packet()

        # if there is too little space remaining, start a new datagram
        # FIXME: the limit is arbitrary!
        packet_start = buf.tell()
        if self._buffer_capacity - packet_start < 128:
            self._flush_current_datagram()
            packet_start = 0

        # initialize datagram if needed: apply the byte budgets
        if self._datagram_init:
            if self.max_total_bytes is not None:
                remaining_total_bytes = self.max_total_bytes - self._total_bytes
                if remaining_total_bytes < self._buffer_capacity:
                    self._buffer_capacity = remaining_total_bytes

            self._flight_capacity = self._buffer_capacity
            if self.max_flight_bytes is not None:
                remaining_flight_bytes = self.max_flight_bytes - self._flight_bytes
                if remaining_flight_bytes < self._flight_capacity:
                    self._flight_capacity = remaining_flight_bytes
            self._datagram_flight_bytes = 0
            self._datagram_init = False

        # calculate header size (including the packet number)
        packet_long_header = is_long_header(packet_type)
        if packet_long_header:
            # 1 (first byte) + 4 (version) + 1 + DCID + 1 + SCID
            # + 2 (length) + 2 (packet number)
            header_size = 11 + len(self._peer_cid) + len(self._host_cid)
            if (packet_type & PACKET_TYPE_MASK) == PACKET_TYPE_INITIAL:
                token_length = len(self._peer_token)
                header_size += size_uint_var(token_length) + token_length
        else:
            # 1 (first byte) + DCID + 2 (packet number)
            header_size = 3 + len(self._peer_cid)

        # check we have enough space
        if packet_start + header_size >= self._buffer_capacity:
            raise QuicPacketBuilderStop

        # determine ack epoch
        if packet_type == PACKET_TYPE_INITIAL:
            epoch = Epoch.INITIAL
        elif packet_type == PACKET_TYPE_HANDSHAKE:
            epoch = Epoch.HANDSHAKE
        else:
            epoch = Epoch.ONE_RTT

        self._header_size = header_size
        self._packet = QuicSentPacket(
            epoch=epoch,
            in_flight=False,
            is_ack_eliciting=False,
            is_crypto_packet=False,
            packet_number=self._packet_number,
            packet_type=packet_type,
        )
        self._packet_crypto = crypto
        self._packet_long_header = packet_long_header
        self._packet_start = packet_start
        self._packet_type = packet_type
        self.quic_logger_frames = self._packet.quic_logger_frames

        # leave room for the header, which is written when the packet ends
        buf.seek(self._packet_start + self._header_size)

    def _push_padding(self, size: int) -> None:
        """
        Appends `size` zero bytes (PADDING frames) to the current packet at
        the current buffer position and logs them. `size` must be positive.
        """
        self._buffer.push_bytes(bytes(size))
        # padding counts towards bytes in flight
        self._packet.in_flight = True

        # log frame
        if self._quic_logger is not None:
            self._packet.quic_logger_frames.append(
                self._quic_logger.encode_padding_frame()
            )

    def _end_packet(self) -> None:
        """
        Ends the current packet.

        A packet without frames is discarded. Otherwise it is padded as
        required, its header is written, it is encrypted in place and
        recorded among the sent packets.
        """
        buf = self._buffer
        packet_size = buf.tell() - self._packet_start
        if packet_size > self._header_size:
            # pad initial datagram; only when there is room left, since the
            # flight space can be negative and `bytes()` rejects that
            if (
                self._is_client
                and self._packet_type == PACKET_TYPE_INITIAL
                and self._packet.is_crypto_packet
            ):
                padding_size = self.remaining_flight_space
                if padding_size > 0:
                    self._push_padding(padding_size)
                    packet_size += padding_size

            # ensure enough bytes follow the packet number for header
            # protection sampling. This applies to long AND short headers and
            # must happen before the header is written, so that the long
            # header's Length field accounts for the padding.
            padding_size = (
                PACKET_NUMBER_MAX_SIZE
                - PACKET_NUMBER_SEND_SIZE
                + self._header_size
                - packet_size
            )
            if padding_size > 0:
                self._push_padding(padding_size)
                packet_size += padding_size

            # write header
            buf.seek(self._packet_start)
            if self._packet_long_header:
                # Length covers packet number + payload + AEAD tag
                length = (
                    packet_size
                    - self._header_size
                    + PACKET_NUMBER_SEND_SIZE
                    + self._packet_crypto.aead_tag_size
                )

                buf.push_uint8(self._packet_type | (PACKET_NUMBER_SEND_SIZE - 1))
                buf.push_uint32(self._version)
                buf.push_uint8(len(self._peer_cid))
                buf.push_bytes(self._peer_cid)
                buf.push_uint8(len(self._host_cid))
                buf.push_bytes(self._host_cid)
                if (self._packet_type & PACKET_TYPE_MASK) == PACKET_TYPE_INITIAL:
                    buf.push_uint_var(len(self._peer_token))
                    buf.push_bytes(self._peer_token)
                # 0x4000 marks a 2-byte varint
                buf.push_uint16(length | 0x4000)
                buf.push_uint16(self._packet_number & 0xFFFF)
            else:
                buf.push_uint8(
                    self._packet_type
                    | (self._spin_bit << 5)
                    | (self._packet_crypto.key_phase << 2)
                    | (PACKET_NUMBER_SEND_SIZE - 1)
                )
                buf.push_bytes(self._peer_cid)
                buf.push_uint16(self._packet_number & 0xFFFF)

            # encrypt in place
            plain = buf.data_slice(self._packet_start, self._packet_start + packet_size)
            buf.seek(self._packet_start)
            buf.push_bytes(
                self._packet_crypto.encrypt_packet(
                    plain[0 : self._header_size],
                    plain[self._header_size : packet_size],
                    self._packet_number,
                )
            )
            self._packet.sent_bytes = buf.tell() - self._packet_start
            self._packets.append(self._packet)
            if self._packet.in_flight:
                self._datagram_flight_bytes += self._packet.sent_bytes

            # a short header packet has no Length field, so nothing can be
            # coalesced after it: start a new datagram
            if not self._packet_long_header:
                self._flush_current_datagram()

            self._packet_number += 1
        else:
            # "cancel" the packet
            buf.seek(self._packet_start)

        self._packet = None
        self.quic_logger_frames = None

    def _flush_current_datagram(self) -> None:
        """
        Moves the bytes written so far into the datagram list and updates
        the flight/total byte counters. Does nothing if the buffer is empty.
        """
        datagram_bytes = self._buffer.tell()
        if datagram_bytes:
            self._datagrams.append(self._buffer.data)
            self._flight_bytes += self._datagram_flight_bytes
            self._total_bytes += datagram_bytes
            self._datagram_init = True
            self._buffer.seek(0)
367