1import os
2from dataclasses import dataclass, field
3from struct import pack, unpack, unpack_from
4from typing import Any, List, Optional, Tuple, Union
5
6from .rtcrtpparameters import RTCRtpParameters
7
# RTP payload types reserved to avoid confusion with RTCP: with the marker
# bit set, payload types 72-76 give a second header byte of 200-204, which
# falls in the RTCP packet type range.
FORBIDDEN_PAYLOAD_TYPES = range(200 - 128, 205 - 128)
# Payload types available for dynamic assignment.
DYNAMIC_PAYLOAD_TYPES = range(96, 127 + 1)

# Size in bytes of the fixed RTP header and of the RTCP common header.
RTP_HEADER_LENGTH = 12
RTCP_HEADER_LENGTH = 4

# Bounds of the 24-bit signed "cumulative number of packets lost" field.
PACKETS_LOST_MIN = -(2**23)
PACKETS_LOST_MAX = 2**23 - 1

# RTCP packet types.
RTCP_SR, RTCP_RR, RTCP_SDES, RTCP_BYE = 200, 201, 202, 203
RTCP_RTPFB, RTCP_PSFB = 205, 206

# Feedback message type (FMT) of RTPFB packets.
RTCP_RTPFB_NACK = 1

# Feedback message types (FMT) of PSFB packets.
RTCP_PSFB_PLI, RTCP_PSFB_SLI, RTCP_PSFB_RPSI = 1, 2, 3
RTCP_PSFB_APP = 15
31
32
@dataclass
class HeaderExtensions:
    """
    Values of the RTP header extensions handled by HeaderExtensionsMap.

    A field is None when the corresponding extension is absent. When used as
    the ID table inside HeaderExtensionsMap, each field instead holds the
    negotiated extension ID.
    """

    # "abs-send-time": 24-bit value (decoded from 3 bytes).
    abs_send_time: Optional[int] = None
    # "ssrc-audio-level": (voice activity flag, 7-bit level) tuple.
    audio_level: Any = None
    # "sdes:mid": media identification, decoded as UTF-8.
    mid: Any = None
    # "sdes:repaired-rtp-stream-id": decoded as ASCII.
    repaired_rtp_stream_id: Any = None
    # "sdes:rtp-stream-id": decoded as ASCII.
    rtp_stream_id: Any = None
    # "toffset": signed 24-bit transmission time offset.
    transmission_offset: Optional[int] = None
    # Transport-wide congestion control sequence number (16-bit).
    transport_sequence_number: Optional[int] = None
42
43
class HeaderExtensionsMap:
    """
    Translate between on-the-wire RTP header extensions and HeaderExtensions.

    configure() learns the negotiated ID of each supported extension URI.
    get() decodes raw extension data into a HeaderExtensions instance, and
    set() encodes one back into raw extension data.
    """

    def __init__(self) -> None:
        # Negotiated extension ID per supported extension (None = not negotiated).
        self.__ids = HeaderExtensions()

    def configure(self, parameters: RTCRtpParameters) -> None:
        """
        Record the extension IDs negotiated in `parameters`.

        URIs that are not recognised below are ignored.
        """
        for ext in parameters.headerExtensions:
            if ext.uri == "urn:ietf:params:rtp-hdrext:sdes:mid":
                self.__ids.mid = ext.id
            elif ext.uri == "urn:ietf:params:rtp-hdrext:sdes:repaired-rtp-stream-id":
                self.__ids.repaired_rtp_stream_id = ext.id
            elif ext.uri == "urn:ietf:params:rtp-hdrext:sdes:rtp-stream-id":
                self.__ids.rtp_stream_id = ext.id
            elif (
                ext.uri == "http://www.webrtc.org/experiments/rtp-hdrext/abs-send-time"
            ):
                self.__ids.abs_send_time = ext.id
            elif ext.uri == "urn:ietf:params:rtp-hdrext:toffset":
                self.__ids.transmission_offset = ext.id
            elif ext.uri == "urn:ietf:params:rtp-hdrext:ssrc-audio-level":
                self.__ids.audio_level = ext.id
            elif (
                ext.uri
                == "http://www.ietf.org/id/draft-holmer-rmcat-transport-wide-cc-extensions-01"
            ):
                self.__ids.transport_sequence_number = ext.id

    def get(self, extension_profile: int, extension_value: bytes) -> HeaderExtensions:
        """
        Decode raw header extension data into a HeaderExtensions instance.

        Elements whose ID was not configured are ignored.

        :param extension_profile: the 16-bit profile value from the RTP header.
        :param extension_value: the raw header extension data.
        :raises ValueError: if the data is malformed, including an element
            whose value length does not match its extension's format.
        """
        import struct

        values = HeaderExtensions()
        try:
            for x_id, x_value in unpack_header_extensions(
                extension_profile, extension_value
            ):
                if x_id == self.__ids.mid:
                    values.mid = x_value.decode("utf8")
                elif x_id == self.__ids.repaired_rtp_stream_id:
                    values.repaired_rtp_stream_id = x_value.decode("ascii")
                elif x_id == self.__ids.rtp_stream_id:
                    values.rtp_stream_id = x_value.decode("ascii")
                elif x_id == self.__ids.abs_send_time:
                    values.abs_send_time = unpack("!L", b"\00" + x_value)[0]
                elif x_id == self.__ids.transmission_offset:
                    # Signed 24-bit value: decode as the top 3 bytes of a signed
                    # 32-bit integer, then shift down (arithmetic shift keeps sign).
                    values.transmission_offset = (
                        unpack("!l", x_value + b"\00")[0] >> 8
                    )
                elif x_id == self.__ids.audio_level:
                    vad_level = unpack("!B", x_value)[0]
                    values.audio_level = (vad_level & 0x80 == 0x80, vad_level & 0x7F)
                elif x_id == self.__ids.transport_sequence_number:
                    values.transport_sequence_number = unpack("!H", x_value)[0]
        except struct.error as exc:
            # struct.error is not a ValueError subclass; convert it so callers
            # can reject malformed packets with a single except clause.
            raise ValueError(
                "RTP header extension value has an invalid length"
            ) from exc
        return values

    def set(self, values: HeaderExtensions) -> Tuple[int, bytes]:
        """
        Encode `values` as raw header extension data.

        Only extensions that have both a value and a configured ID are written.

        :return: the (profile, data) tuple produced by pack_header_extensions.
        """
        extensions = []
        if values.mid is not None and self.__ids.mid:
            extensions.append((self.__ids.mid, values.mid.encode("utf8")))
        if (
            values.repaired_rtp_stream_id is not None
            and self.__ids.repaired_rtp_stream_id
        ):
            extensions.append(
                (
                    self.__ids.repaired_rtp_stream_id,
                    values.repaired_rtp_stream_id.encode("ascii"),
                )
            )
        if values.rtp_stream_id is not None and self.__ids.rtp_stream_id:
            extensions.append(
                (self.__ids.rtp_stream_id, values.rtp_stream_id.encode("ascii"))
            )
        if values.abs_send_time is not None and self.__ids.abs_send_time:
            # 24-bit value: the low 3 bytes of a big-endian 32-bit integer.
            extensions.append(
                (self.__ids.abs_send_time, pack("!L", values.abs_send_time)[1:])
            )
        if values.transmission_offset is not None and self.__ids.transmission_offset:
            # Signed 24-bit value: the top 3 bytes of (offset << 8) packed as a
            # signed 32-bit integer. This mirrors the 3-byte decoding in get().
            extensions.append(
                (
                    self.__ids.transmission_offset,
                    pack("!l", values.transmission_offset << 8)[0:3],
                )
            )
        if values.audio_level is not None and self.__ids.audio_level:
            extensions.append(
                (
                    self.__ids.audio_level,
                    pack(
                        "!B",
                        (0x80 if values.audio_level[0] else 0)
                        | (values.audio_level[1] & 0x7F),
                    ),
                )
            )
        if (
            values.transport_sequence_number is not None
            and self.__ids.transport_sequence_number
        ):
            extensions.append(
                (
                    self.__ids.transport_sequence_number,
                    pack("!H", values.transport_sequence_number),
                )
            )
        return pack_header_extensions(extensions)
143
144
def clamp_packets_lost(count: int) -> int:
    """
    Restrict `count` to what the 24-bit signed "packets lost" field can hold.
    """
    upper_bounded = min(count, PACKETS_LOST_MAX)
    return max(PACKETS_LOST_MIN, upper_bounded)
147
148
def pack_packets_lost(count: int) -> bytes:
    """
    Encode `count` as 3 big-endian bytes (24-bit two's complement).

    The value is packed as a signed 32-bit integer and its last 3 bytes kept.
    """
    as_int32 = pack("!l", count)
    return as_int32[-3:]
151
152
def unpack_packets_lost(d: bytes) -> int:
    """
    Decode a 3-byte big-endian two's-complement value into an int.

    The bytes are sign-extended to 32 bits before unpacking.
    """
    sign_byte = b"\xff" if d[0] & 0x80 else b"\x00"
    (value,) = unpack("!l", sign_byte + d)
    return value
159
160
def pack_rtcp_packet(packet_type: int, count: int, payload: bytes) -> bytes:
    """
    Prepend the 4-byte RTCP common header to `payload`.

    :param packet_type: the RTCP packet type (e.g. 200 for a sender report).
    :param count: the 5-bit count / feedback message type field.
    :param payload: the packet body; its length must be a multiple of 4.
    :raises ValueError: if the payload is not 32-bit aligned or `count` does
        not fit in 5 bits (it would otherwise overwrite the version and
        padding bits of the header).
    """
    if len(payload) % 4:
        raise ValueError("RTCP payload length must be a multiple of 4")
    if not 0 <= count <= 0x1F:
        raise ValueError("RTCP count must be in the range 0-31")
    # The length field counts 32-bit words after the header.
    return pack("!BBH", (2 << 6) | count, packet_type, len(payload) // 4) + payload
164
165
def pack_remb_fci(bitrate: int, ssrcs: List[int]) -> bytes:
    """
    Pack the FCI for a Receiver Estimated Maximum Bitrate report.

    The bitrate is stored as an 18-bit mantissa and a 6-bit exponent
    (bitrate ~= mantissa << exponent), preceded by the SSRC count and
    followed by the SSRCs.

    https://tools.ietf.org/html/draft-alvestrand-rmcat-remb-03
    """
    exp = 0
    mant = bitrate
    # Halve the mantissa until it fits in 18 bits, counting halvings in exp.
    while mant > 0x3FFFF:
        exp += 1
        mant >>= 1
    header = pack("!BBH", len(ssrcs), (exp << 2) | (mant >> 16), mant & 0xFFFF)
    return b"REMB" + header + b"".join(pack("!L", source) for source in ssrcs)
184
185
def unpack_remb_fci(data: bytes) -> Tuple[int, List[int]]:
    """
    Unpack the FCI for a Receiver Estimated Maximum Bitrate report.

    https://tools.ietf.org/html/draft-alvestrand-rmcat-remb-03

    :return: a (bitrate, ssrcs) tuple.
    :raises ValueError: if the prefix is missing or the SSRC list is truncated.
    """
    if len(data) < 8 or data[0:4] != b"REMB":
        raise ValueError("Invalid REMB prefix")

    num_ssrcs = data[4]
    # Validate up front: unpack_from would raise struct.error, which callers
    # expecting ValueError for malformed input would not catch.
    if len(data) < 8 + 4 * num_ssrcs:
        raise ValueError("REMB SSRC list is truncated")

    # 6-bit exponent followed by an 18-bit mantissa.
    exponent = (data[5] & 0xFC) >> 2
    mantissa = ((data[5] & 0x03) << 16) | (data[6] << 8) | data[7]
    bitrate = mantissa << exponent

    ssrcs = list(unpack_from(f"!{num_ssrcs}L", data, 8))
    return (bitrate, ssrcs)
206
207
def is_rtcp(msg: bytes) -> bool:
    """
    Tell whether `msg` looks like RTCP, judging by its second byte (192-208).
    """
    if len(msg) < 2:
        return False
    return 192 <= msg[1] <= 208
210
211
def padl(length: int) -> int:
    """
    Return how many bytes must follow `length` bytes to reach a multiple of 4.
    """
    return -length % 4
217
218
def unpack_header_extensions(
    extension_profile: int, extension_value: bytes
) -> List[Tuple[int, bytes]]:
    """
    Parse header extensions according to RFC 5285.

    :param extension_profile: 0xBEDE selects the one-byte header format;
        0x1000-0x100F (0x100 followed by 4 "appbits") select the two-byte
        header format. Any other profile yields no extensions.
    :param extension_value: the raw header extension data.
    :return: (id, value) tuples in the order they appear.
    :raises ValueError: if an element is truncated.
    """
    extensions = []
    pos = 0
    end = len(extension_value)

    if extension_profile == 0xBEDE:
        # One-Byte Header
        while pos < end:
            # skip padding byte
            if extension_value[pos] == 0:
                pos += 1
                continue

            x_id = (extension_value[pos] & 0xF0) >> 4
            if x_id == 15:
                # ID 15 is reserved: processing stops here, and only the
                # elements before it are considered.
                break
            x_length = (extension_value[pos] & 0x0F) + 1
            pos += 1

            if end < pos + x_length:
                raise ValueError("RTP one-byte header extension value is truncated")
            extensions.append((x_id, extension_value[pos : pos + x_length]))
            pos += x_length
    elif (extension_profile & 0xFFF0) == 0x1000:
        # Two-Byte Header; the low 4 bits of the profile are application bits.
        while pos < end:
            # skip padding byte
            if extension_value[pos] == 0:
                pos += 1
                continue

            if end < pos + 2:
                raise ValueError("RTP two-byte header extension is truncated")
            x_id, x_length = unpack_from("!BB", extension_value, pos)
            pos += 2

            if end < pos + x_length:
                raise ValueError("RTP two-byte header extension value is truncated")
            extensions.append((x_id, extension_value[pos : pos + x_length]))
            pos += x_length

    return extensions
265
266
def pack_header_extensions(extensions: List[Tuple[int, bytes]]) -> Tuple[int, bytes]:
    """
    Serialize header extensions according to RFC 5285.

    The one-byte header format (profile 0xBEDE) is used when every element
    has an ID of at most 14 and a value of 1-16 bytes. Otherwise the two-byte
    header format (profile 0x1000) is used. The result is padded with null
    bytes to a multiple of 4 bytes.

    :param extensions: (id, value) pairs with 1 <= id <= 255 and values of
        at most 255 bytes.
    :return: a (profile, data) tuple; (0, b"") when `extensions` is empty.
    :raises ValueError: if an ID or value length is out of range.
    """
    if not extensions:
        return 0, b""

    # Explicit checks rather than asserts, which vanish under `python -O`.
    for x_id, x_value in extensions:
        if not 0 < x_id < 256:
            raise ValueError(f"RTP header extension ID {x_id} is out of range")
        if len(x_value) > 255:
            raise ValueError("RTP header extension value is too long")

    one_byte = all(
        x_id <= 14 and 1 <= len(x_value) <= 16 for x_id, x_value in extensions
    )

    if one_byte:
        # One-Byte Header: 4-bit ID followed by 4-bit (length - 1).
        extension_profile = 0xBEDE
        parts = [
            pack("!B", (x_id << 4) | (len(x_value) - 1)) + x_value
            for x_id, x_value in extensions
        ]
    else:
        # Two-Byte Header: 8-bit ID followed by 8-bit length.
        extension_profile = 0x1000
        parts = [
            pack("!BB", x_id, len(x_value)) + x_value for x_id, x_value in extensions
        ]

    extension_value = b"".join(parts)
    extension_value += b"\x00" * padl(len(extension_value))
    return extension_profile, extension_value
304
305
@dataclass
class RtcpReceiverInfo:
    """
    A reception report block, as carried in RTCP sender and receiver reports.

    Wire layout (24 bytes): ssrc (4), fraction_lost (1), packets_lost
    (3, signed), highest_sequence (4), jitter (4), lsr (4), dlsr (4).
    """

    ssrc: int
    fraction_lost: int
    packets_lost: int
    highest_sequence: int
    jitter: int
    lsr: int
    dlsr: int

    def __bytes__(self) -> bytes:
        return b"".join(
            (
                pack("!LB", self.ssrc, self.fraction_lost),
                pack_packets_lost(self.packets_lost),
                pack("!LLLL", self.highest_sequence, self.jitter, self.lsr, self.dlsr),
            )
        )

    @classmethod
    def parse(cls, data: bytes):
        head_ssrc, head_fraction = unpack("!LB", data[:5])
        lost = unpack_packets_lost(data[5:8])
        tail = unpack("!LLLL", data[8:])
        return cls(head_ssrc, head_fraction, lost, *tail)
336
337
@dataclass
class RtcpSenderInfo:
    """
    The sender information section of an RTCP sender report.

    Wire layout (20 bytes): 64-bit NTP timestamp, then the 32-bit RTP
    timestamp, packet count and octet count.
    """

    ntp_timestamp: int
    rtp_timestamp: int
    packet_count: int
    octet_count: int

    def __bytes__(self) -> bytes:
        fields_in_order = (
            self.ntp_timestamp,
            self.rtp_timestamp,
            self.packet_count,
            self.octet_count,
        )
        return pack("!QLLL", *fields_in_order)

    @classmethod
    def parse(cls, data: bytes):
        return cls(*unpack("!QLLL", data))
363
364
@dataclass
class RtcpSourceInfo:
    """
    One chunk of an RTCP SDES packet: a source identifier and its items.
    """

    ssrc: int
    # (item type, item value) pairs; RtcpSdesPacket packs the type as one byte.
    items: List[Tuple[Any, bytes]]
369
370
@dataclass
class RtcpByePacket:
    """
    Goodbye (BYE) RTCP packet listing the SSRCs of departing sources.

    Any bytes after the SSRC list are ignored when parsing.
    """

    sources: List[int]

    def __bytes__(self) -> bytes:
        payload = b"".join(pack("!L", source) for source in self.sources)
        return pack_rtcp_packet(RTCP_BYE, len(self.sources), payload)

    @classmethod
    def parse(cls, data: bytes, count: int):
        if 4 * count > len(data):
            raise ValueError("RTCP bye length is invalid")
        sources = list(unpack_from(f"!{count}L", data)) if count > 0 else []
        return cls(sources=sources)
388
389
@dataclass
class RtcpPsfbPacket:
    """
    Payload-Specific Feedback Message (RFC 4585).

    The feedback control information (FCI) is kept as raw bytes.
    """

    fmt: int
    ssrc: int
    media_ssrc: int
    fci: bytes = b""

    def __bytes__(self) -> bytes:
        header = pack("!LL", self.ssrc, self.media_ssrc)
        return pack_rtcp_packet(RTCP_PSFB, self.fmt, header + self.fci)

    @classmethod
    def parse(cls, data: bytes, fmt: int):
        if len(data) < 8:
            raise ValueError("RTCP payload-specific feedback length is invalid")
        sender, media = unpack_from("!LL", data, 0)
        return cls(fmt=fmt, ssrc=sender, media_ssrc=media, fci=data[8:])
413
414
@dataclass
class RtcpRrPacket:
    """
    Receiver Report RTCP packet: the reporter's SSRC followed by 24-byte
    reception report blocks.
    """

    ssrc: int
    reports: List[RtcpReceiverInfo] = field(default_factory=list)

    def __bytes__(self) -> bytes:
        header = pack("!L", self.ssrc)
        blocks = b"".join(bytes(report) for report in self.reports)
        return pack_rtcp_packet(RTCP_RR, len(self.reports), header + blocks)

    @classmethod
    def parse(cls, data: bytes, count: int):
        if len(data) != 4 + 24 * count:
            raise ValueError("RTCP receiver report length is invalid")

        (reporter,) = unpack("!L", data[0:4])
        blocks = [
            RtcpReceiverInfo.parse(data[4 + 24 * index : 28 + 24 * index])
            for index in range(count)
        ]
        return cls(ssrc=reporter, reports=blocks)
438
439
@dataclass
class RtcpRtpfbPacket:
    """
    Generic RTP Feedback Message (RFC 4585).

    The FCI is always interpreted as Generic NACK entries (16-bit PID plus
    16-bit bitmask of following lost packets), regardless of `fmt`.
    """

    fmt: int
    ssrc: int
    media_ssrc: int

    # generic NACK: sequence numbers of the lost packets
    lost: List[int] = field(default_factory=list)

    def __bytes__(self) -> bytes:
        payload = pack("!LL", self.ssrc, self.media_ssrc)
        if self.lost:
            pid = self.lost[0]
            blp = 0
            for p in self.lost[1:]:
                # Distance from the current PID modulo 2^16, so that losses
                # straddling the sequence number wrap-around share an entry.
                # Out-of-order or duplicate entries give a large distance and
                # start a new entry instead of a negative shift.
                d = (p - pid - 1) & 0xFFFF
                if d < 16:
                    blp |= 1 << d
                else:
                    payload += pack("!HH", pid, blp)
                    pid = p
                    blp = 0
            payload += pack("!HH", pid, blp)
        return pack_rtcp_packet(RTCP_RTPFB, self.fmt, payload)

    @classmethod
    def parse(cls, data: bytes, fmt: int):
        if len(data) < 8 or len(data) % 4:
            raise ValueError("RTCP RTP feedback length is invalid")

        ssrc, media_ssrc = unpack("!LL", data[0:8])
        lost = []
        for pos in range(8, len(data), 4):
            pid, blp = unpack("!HH", data[pos : pos + 4])
            lost.append(pid)
            for d in range(16):
                if (blp >> d) & 1:
                    # Sequence numbers are 16-bit and wrap around.
                    lost.append((pid + d + 1) & 0xFFFF)
        return cls(fmt=fmt, ssrc=ssrc, media_ssrc=media_ssrc, lost=lost)
483
484
@dataclass
class RtcpSdesPacket:
    """
    Source Description (SDES) RTCP packet (RFC 3550, section 6.5).

    Each chunk is an SSRC followed by a list of (type, length, value) items.
    The list is terminated by a null octet and padded with null octets up to
    the next 32-bit boundary, so every chunk starts on a 32-bit boundary.
    """

    chunks: List[RtcpSourceInfo] = field(default_factory=list)

    def __bytes__(self) -> bytes:
        payload = b""
        for chunk in self.chunks:
            data = pack("!L", chunk.ssrc)
            for d_type, d_value in chunk.items:
                data += pack("!BB", d_type, len(d_value)) + d_value
            # Terminate the item list with at least one null octet and pad the
            # chunk to a 32-bit boundary (1 to 4 null octets in total).
            data += b"\x00" * (4 - len(data) % 4)
            payload += data
        return pack_rtcp_packet(RTCP_SDES, len(self.chunks), payload)

    @classmethod
    def parse(cls, data: bytes, count: int):
        pos = 0
        chunks = []
        for r in range(count):
            if len(data) < pos + 4:
                raise ValueError("RTCP SDES source is truncated")
            ssrc = unpack_from("!L", data, pos)[0]
            pos += 4

            items = []
            while pos < len(data):
                d_type = data[pos]
                if d_type == 0:
                    # End of the item list: the null octet has no length octet.
                    # Skip it and the padding up to the next 32-bit boundary,
                    # where the next chunk starts.
                    pos += 1
                    pos += -pos % 4
                    break

                if len(data) < pos + 2:
                    raise ValueError("RTCP SDES item is truncated")
                d_length = data[pos + 1]
                pos += 2

                if len(data) < pos + d_length:
                    raise ValueError("RTCP SDES item is truncated")
                items.append((d_type, data[pos : pos + d_length]))
                pos += d_length
            chunks.append(RtcpSourceInfo(ssrc=ssrc, items=items))
        return cls(chunks=chunks)
525
526
@dataclass
class RtcpSrPacket:
    """
    Sender Report RTCP packet: the sender's SSRC, its sender information,
    and 24-byte reception report blocks.
    """

    ssrc: int
    sender_info: RtcpSenderInfo
    reports: List[RtcpReceiverInfo] = field(default_factory=list)

    def __bytes__(self) -> bytes:
        payload = pack("!L", self.ssrc)
        payload += bytes(self.sender_info)
        for report in self.reports:
            payload += bytes(report)
        return pack_rtcp_packet(RTCP_SR, len(self.reports), payload)

    @classmethod
    def parse(cls, data: bytes, count: int):
        # 4-byte SSRC + 20-byte sender info, then `count` 24-byte report blocks.
        if len(data) != 24 + 24 * count:
            raise ValueError("RTCP sender report length is invalid")

        ssrc = unpack_from("!L", data, 0)[0]
        sender_info = RtcpSenderInfo.parse(data[4:24])
        reports = [
            RtcpReceiverInfo.parse(data[pos : pos + 24])
            for pos in range(24, 24 + 24 * count, 24)
        ]
        # Use cls (like the other packet classes) so subclasses parse to themselves.
        return cls(ssrc=ssrc, sender_info=sender_info, reports=reports)
553
554
# Union of the packet classes that RtcpPacket.parse can return.
AnyRtcpPacket = Union[
    RtcpByePacket,
    RtcpPsfbPacket,
    RtcpRrPacket,
    RtcpRtpfbPacket,
    RtcpSdesPacket,
    RtcpSrPacket,
]
563
564
class RtcpPacket:
    @classmethod
    def parse(cls, data: bytes) -> List[AnyRtcpPacket]:
        """
        Split a (possibly compound) RTCP packet into individual packets.

        Packets of an unsupported type are skipped.

        :raises ValueError: if a header, length, version or padding is invalid.
        """
        # Packet type -> parser taking (payload, count/fmt field).
        parsers = {
            RTCP_BYE: RtcpByePacket.parse,
            RTCP_SDES: RtcpSdesPacket.parse,
            RTCP_SR: RtcpSrPacket.parse,
            RTCP_RR: RtcpRrPacket.parse,
            RTCP_RTPFB: RtcpRtpfbPacket.parse,
            RTCP_PSFB: RtcpPsfbPacket.parse,
        }

        packets: List[AnyRtcpPacket] = []
        offset = 0
        total = len(data)
        while offset < total:
            if total < offset + RTCP_HEADER_LENGTH:
                raise ValueError(
                    f"RTCP packet length is less than {RTCP_HEADER_LENGTH} bytes"
                )

            first_byte, packet_type, word_count = unpack(
                "!BBH", data[offset : offset + 4]
            )
            if first_byte >> 6 != 2:
                raise ValueError("RTCP packet has invalid version")
            has_padding = bool((first_byte >> 5) & 1)
            count = first_byte & 0x1F

            body_start = offset + 4
            body_end = body_start + 4 * word_count
            if total < body_end:
                raise ValueError("RTCP packet is truncated")
            payload = data[body_start:body_end]
            offset = body_end

            if has_padding:
                # The last payload byte holds the number of padding bytes.
                pad = payload[-1] if payload else 0
                if pad == 0 or pad > len(payload):
                    raise ValueError("RTCP packet padding length is invalid")
                payload = payload[:-pad]

            parser = parsers.get(packet_type)
            if parser is not None:
                packets.append(parser(payload, count))

        return packets
610
611
class RtpPacket:
    """
    An RTP packet (RFC 3550).

    Header extensions are exposed as a HeaderExtensions instance in
    `extensions`; a HeaderExtensionsMap converts them to and from the
    negotiated on-the-wire extension IDs.
    """

    def __init__(
        self,
        payload_type: int = 0,
        marker: int = 0,
        sequence_number: int = 0,
        timestamp: int = 0,
        ssrc: int = 0,
        payload: bytes = b"",
    ) -> None:
        self.version = 2
        self.marker = marker
        self.payload_type = payload_type
        self.sequence_number = sequence_number
        self.timestamp = timestamp
        self.ssrc = ssrc
        self.csrc: List[int] = []
        self.extensions = HeaderExtensions()
        self.payload = payload
        # Total number of padding bytes, including the trailing length byte;
        # 0 means no padding.
        self.padding_size = 0

    def __repr__(self) -> str:
        return (
            f"RtpPacket(seq={self.sequence_number}, ts={self.timestamp}, "
            f"marker={self.marker}, payload={self.payload_type}, {len(self.payload)} bytes)"
        )

    @classmethod
    def parse(cls, data: bytes, extensions_map=HeaderExtensionsMap()):
        """
        Parse an RTP packet.

        :param data: the raw packet.
        :param extensions_map: translates header extension IDs into
            HeaderExtensions fields. The default map has no IDs configured,
            so no extension values are decoded.
        :raises ValueError: if the fixed header, CSRC list, header extension
            or padding is truncated or invalid.
        """
        if len(data) < RTP_HEADER_LENGTH:
            raise ValueError(
                f"RTP packet length is less than {RTP_HEADER_LENGTH} bytes"
            )

        v_p_x_cc, m_pt, sequence_number, timestamp, ssrc = unpack("!BBHLL", data[0:12])
        version = v_p_x_cc >> 6
        padding = (v_p_x_cc >> 5) & 1
        extension = (v_p_x_cc >> 4) & 1
        cc = v_p_x_cc & 0x0F
        if version != 2:
            raise ValueError("RTP packet has invalid version")
        if len(data) < RTP_HEADER_LENGTH + 4 * cc:
            raise ValueError("RTP packet has truncated CSRC")

        packet = cls(
            marker=(m_pt >> 7),
            payload_type=(m_pt & 0x7F),
            sequence_number=sequence_number,
            timestamp=timestamp,
            ssrc=ssrc,
        )

        pos = RTP_HEADER_LENGTH
        for i in range(0, cc):
            packet.csrc.append(unpack_from("!L", data, pos)[0])
            pos += 4

        if extension:
            if len(data) < pos + 4:
                raise ValueError("RTP packet has truncated extension profile / length")
            extension_profile, extension_length = unpack_from("!HH", data, pos)
            # The length field counts 32-bit words.
            extension_length *= 4
            pos += 4

            if len(data) < pos + extension_length:
                raise ValueError("RTP packet has truncated extension value")
            extension_value = data[pos : pos + extension_length]
            pos += extension_length
            packet.extensions = extensions_map.get(extension_profile, extension_value)

        if padding:
            # The last byte holds the number of padding bytes, itself included.
            padding_len = data[-1]
            if not padding_len or padding_len > len(data) - pos:
                raise ValueError("RTP packet padding length is invalid")
            packet.padding_size = padding_len
            packet.payload = data[pos:-padding_len]
        else:
            packet.payload = data[pos:]

        return packet

    def serialize(self, extensions_map=HeaderExtensionsMap()) -> bytes:
        """
        Serialize the packet.

        When `padding_size` is greater than zero, that many padding bytes are
        appended: random filler followed by a byte holding `padding_size`.

        :param extensions_map: translates HeaderExtensions fields into
            on-the-wire extension IDs.
        :raises ValueError: if there are more than 15 CSRCs or the payload
            type does not fit in 7 bits. Either would otherwise silently
            overwrite neighbouring header bits (the X bit / the marker bit).
        """
        if len(self.csrc) > 15:
            raise ValueError("RTP packet cannot have more than 15 CSRCs")
        if not 0 <= self.payload_type <= 0x7F:
            raise ValueError("RTP payload type must be in the range 0-127")

        extension_profile, extension_value = extensions_map.set(self.extensions)
        has_extension = bool(extension_value)

        padding = self.padding_size > 0
        data = pack(
            "!BBHLL",
            (self.version << 6)
            | (padding << 5)
            | (has_extension << 4)
            | len(self.csrc),
            (self.marker << 7) | self.payload_type,
            self.sequence_number,
            self.timestamp,
            self.ssrc,
        )
        for csrc in self.csrc:
            data += pack("!L", csrc)
        if has_extension:
            # Extension length is expressed in 32-bit words.
            data += pack("!HH", extension_profile, len(extension_value) >> 2)
            data += extension_value
        data += self.payload
        if padding:
            data += os.urandom(self.padding_size - 1)
            data += bytes([self.padding_size])
        return data
719
720
def unwrap_rtx(rtx: RtpPacket, payload_type: int, ssrc: int) -> RtpPacket:
    """
    Recover initial packet from a retransmission packet.

    The RTX payload starts with the original packet's 16-bit sequence number,
    followed by the original payload (RFC 4588).

    :param rtx: the retransmission packet.
    :param payload_type: payload type to give the recovered packet.
    :param ssrc: SSRC to give the recovered packet.
    :raises ValueError: if the RTX payload is too short to contain the
        original sequence number.
    """
    # Check explicitly: unpack() would otherwise raise struct.error, which
    # is not a ValueError.
    if len(rtx.payload) < 2:
        raise ValueError("RTX packet payload is truncated")

    packet = RtpPacket(
        payload_type=payload_type,
        marker=rtx.marker,
        sequence_number=unpack("!H", rtx.payload[0:2])[0],
        timestamp=rtx.timestamp,
        ssrc=ssrc,
        payload=rtx.payload[2:],
    )
    # CSRCs and header extensions are carried over by reference (not copied).
    packet.csrc = rtx.csrc
    packet.extensions = rtx.extensions
    return packet
736
737
def wrap_rtx(
    packet: RtpPacket, payload_type: int, sequence_number: int, ssrc: int
) -> RtpPacket:
    """
    Create a retransmission packet from a lost packet.

    The RTX payload is the lost packet's 16-bit sequence number followed by
    its original payload. Marker and timestamp are copied; CSRCs and header
    extensions are shared with `packet` by reference.
    """
    rtx = RtpPacket(
        payload_type,
        packet.marker,
        sequence_number,
        packet.timestamp,
        ssrc,
        pack("!H", packet.sequence_number) + packet.payload,
    )
    rtx.csrc, rtx.extensions = packet.csrc, packet.extensions
    return rtx
755