1import asyncio
2import logging
3import random
4import time
5import traceback
6import uuid
7from typing import Dict, List, Optional, Union
8
9from . import clock, rtp
10from .codecs import get_capabilities, get_encoder, is_rtx
11from .codecs.base import Encoder
12from .exceptions import InvalidStateError
13from .mediastreams import MediaStreamError, MediaStreamTrack
14from .rtcrtpparameters import RTCRtpCodecParameters, RTCRtpSendParameters
15from .rtp import (
16    RTCP_PSFB_APP,
17    RTCP_PSFB_PLI,
18    RTCP_RTPFB_NACK,
19    AnyRtcpPacket,
20    RtcpByePacket,
21    RtcpPsfbPacket,
22    RtcpRrPacket,
23    RtcpRtpfbPacket,
24    RtcpSdesPacket,
25    RtcpSenderInfo,
26    RtcpSourceInfo,
27    RtcpSrPacket,
28    RtpPacket,
29    unpack_remb_fci,
30    wrap_rtx,
31)
32from .stats import (
33    RTCOutboundRtpStreamStats,
34    RTCRemoteInboundRtpStreamStats,
35    RTCStatsReport,
36)
37from .utils import random16, random32, uint16_add, uint32_add
38
logger = logging.getLogger(name=__name__)

# How many sent RTP packets are remembered for NACK-driven retransmission;
# the history is a dict keyed by ``sequence_number % RTP_HISTORY_SIZE``.
RTP_HISTORY_SIZE = 128

# Weight given to the previous round-trip-time estimate when folding in a new
# sample: rtt = RTT_ALPHA * rtt + (1 - RTT_ALPHA) * sample.
RTT_ALPHA = 0.85
43
44
class RTCRtpSender:
    """
    The :class:`RTCRtpSender` interface provides the ability to control and
    obtain details about how a particular :class:`MediaStreamTrack` is encoded
    and sent to a remote peer.

    :param trackOrKind: Either a :class:`MediaStreamTrack` instance or a
                         media kind (`'audio'` or `'video'`).
    :param transport: An :class:`RTCDtlsTransport`.
    """

    def __init__(self, trackOrKind: Union[MediaStreamTrack, str], transport) -> None:
        if transport.state == "closed":
            raise InvalidStateError

        if isinstance(trackOrKind, MediaStreamTrack):
            self.__kind = trackOrKind.kind
            self.replaceTrack(trackOrKind)
        else:
            self.__kind = trackOrKind
            self.replaceTrack(None)
        self.__cname: Optional[str] = None
        self._ssrc = random32()
        self._rtx_ssrc = random32()
        # FIXME: how should this be initialised?
        self._stream_id = str(uuid.uuid4())
        self.__encoder: Optional[Encoder] = None
        self.__force_keyframe = False
        self.__loop = asyncio.get_event_loop()
        self.__mid: Optional[str] = None
        self.__rtp_header_extensions_map = rtp.HeaderExtensionsMap()
        self.__rtp_history: Dict[int, RtpPacket] = {}
        self.__rtx_payload_type: Optional[int] = None
        self.__rtx_sequence_number = random16()
        self.__started = False
        self.__stats = RTCStatsReport()
        self.__transport = transport

        # Background tasks and their lifecycle events. The "started" events
        # are set as soon as each coroutine begins executing; stop() waits on
        # them before cancelling so the cancellation is delivered inside the
        # coroutine's try/except and the "exited" events are always set.
        self.__rtp_task: Optional[asyncio.Future[None]] = None
        self.__rtp_started = asyncio.Event()
        self.__rtp_exited = asyncio.Event()
        self.__rtcp_task: Optional[asyncio.Future[None]] = None
        self.__rtcp_started = asyncio.Event()
        self.__rtcp_exited = asyncio.Event()

        # stats
        self.__lsr: Optional[int] = None
        self.__lsr_time: Optional[float] = None
        self.__ntp_timestamp = 0
        self.__rtp_timestamp = 0
        self.__octet_count = 0
        self.__packet_count = 0
        self.__rtt: Optional[float] = None

    @property
    def kind(self):
        """
        The media kind handled by this sender, taken from the track or from
        the `trackOrKind` string passed to the constructor.
        """
        return self.__kind

    @property
    def track(self) -> Optional[MediaStreamTrack]:
        """
        The :class:`MediaStreamTrack` which is being handled by the sender,
        or `None` if there is no track.
        """
        return self.__track

    @property
    def transport(self):
        """
        The :class:`RTCDtlsTransport` over which media data for the track is
        transmitted.
        """
        return self.__transport

    @classmethod
    def getCapabilities(cls, kind):
        """
        Returns the most optimistic view of the system's capabilities for
        sending media of the given `kind`.

        :rtype: :class:`RTCRtpCapabilities`
        """
        return get_capabilities(kind)

    async def getStats(self) -> RTCStatsReport:
        """
        Returns statistics about the RTP sender.

        :rtype: :class:`RTCStatsReport`
        """
        self.__stats.add(
            RTCOutboundRtpStreamStats(
                # RTCStats
                timestamp=clock.current_datetime(),
                type="outbound-rtp",
                id="outbound-rtp_" + str(id(self)),
                # RTCStreamStats
                ssrc=self._ssrc,
                kind=self.__kind,
                transportId=self.transport._stats_id,
                # RTCSentRtpStreamStats
                packetsSent=self.__packet_count,
                bytesSent=self.__octet_count,
                # RTCOutboundRtpStreamStats
                trackId=str(id(self.track)),
            )
        )
        self.__stats.update(self.transport._get_stats())

        return self.__stats

    def replaceTrack(self, track: Optional[MediaStreamTrack]) -> None:
        """
        Replace the track being sent.

        When `track` is `None`, a random track id is generated instead of
        reusing the previous track's id.
        """
        self.__track = track
        if track is not None:
            self._track_id = track.id
        else:
            self._track_id = str(uuid.uuid4())

    def setTransport(self, transport) -> None:
        """
        Replace the transport used to send RTP and RTCP.
        """
        self.__transport = transport

    async def send(self, parameters: RTCRtpSendParameters) -> None:
        """
        Attempt to set the parameters controlling the sending of media.

        Only the first call has an effect: it registers the sender with the
        transport and starts the RTP and RTCP background tasks using the
        first codec in `parameters.codecs`.

        :param parameters: The :class:`RTCRtpSendParameters` for the sender.
        """
        if not self.__started:
            self.__cname = parameters.rtcp.cname
            self.__mid = parameters.muxId

            # make note of the RTP header extension IDs
            self.__transport._register_rtp_sender(self, parameters)
            self.__rtp_header_extensions_map.configure(parameters)

            # make note of RTX payload type: the RTX codec whose "apt" points
            # at the primary (first) codec
            for codec in parameters.codecs:
                if (
                    is_rtx(codec)
                    and codec.parameters["apt"] == parameters.codecs[0].payloadType
                ):
                    self.__rtx_payload_type = codec.payloadType
                    break

            self.__rtp_task = asyncio.ensure_future(self._run_rtp(parameters.codecs[0]))
            self.__rtcp_task = asyncio.ensure_future(self._run_rtcp())
            self.__started = True

    async def stop(self) -> None:
        """
        Irreversibly stop the sender.
        """
        if self.__started:
            self.__transport._unregister_rtp_sender(self)

            # A task cancelled before its first step never executes its body,
            # so its "exited" event would never be set and the wait below
            # would hang. Make sure both coroutines are running first.
            await asyncio.gather(self.__rtp_started.wait(), self.__rtcp_started.wait())
            self.__rtp_task.cancel()
            self.__rtcp_task.cancel()
            await asyncio.gather(self.__rtp_exited.wait(), self.__rtcp_exited.wait())

    async def _handle_rtcp_packet(self, packet):
        """
        React to an incoming RTCP packet addressed to this sender.

        - RR / SR: update the round-trip-time estimate and the
          remote-inbound-rtp stats for reports about our SSRC.
        - NACK: retransmit each reported lost sequence number.
        - PLI: request a keyframe for the next encoded frame.
        - PSFB APP: if it parses as REMB and lists our SSRC, pass the bitrate
          to the encoder when it exposes `target_bitrate`.
        """
        if isinstance(packet, (RtcpRrPacket, RtcpSrPacket)):
            our_reports = (r for r in packet.reports if r.ssrc == self._ssrc)
            for report in our_reports:
                # estimate round-trip time; the report echoes the LSR of our
                # last SR, and DLSR is expressed in units of 1/65536 seconds
                if self.__lsr == report.lsr and report.dlsr:
                    rtt = time.time() - self.__lsr_time - (report.dlsr / 65536)
                    if self.__rtt is None:
                        self.__rtt = rtt
                    else:
                        self.__rtt = RTT_ALPHA * self.__rtt + (1 - RTT_ALPHA) * rtt

                self.__stats.add(
                    RTCRemoteInboundRtpStreamStats(
                        # RTCStats
                        timestamp=clock.current_datetime(),
                        type="remote-inbound-rtp",
                        id="remote-inbound-rtp_" + str(id(self)),
                        # RTCStreamStats
                        ssrc=packet.ssrc,
                        kind=self.__kind,
                        transportId=self.transport._stats_id,
                        # RTCReceivedRtpStreamStats
                        packetsReceived=self.__packet_count - report.packets_lost,
                        packetsLost=report.packets_lost,
                        jitter=report.jitter,
                        # RTCRemoteInboundRtpStreamStats
                        roundTripTime=self.__rtt,
                        fractionLost=report.fraction_lost,
                    )
                )
        elif isinstance(packet, RtcpRtpfbPacket) and packet.fmt == RTCP_RTPFB_NACK:
            for seq in packet.lost:
                await self._retransmit(seq)
        elif isinstance(packet, RtcpPsfbPacket) and packet.fmt == RTCP_PSFB_PLI:
            self._send_keyframe()
        elif isinstance(packet, RtcpPsfbPacket) and packet.fmt == RTCP_PSFB_APP:
            try:
                bitrate, ssrcs = unpack_remb_fci(packet.fci)
                if self._ssrc in ssrcs:
                    self.__log_debug(
                        "- receiver estimated maximum bitrate %d bps", bitrate
                    )
                    if self.__encoder and hasattr(self.__encoder, "target_bitrate"):
                        self.__encoder.target_bitrate = bitrate
            except ValueError:
                # not a REMB payload (or malformed): ignore it
                pass

    async def _next_encoded_frame(self, codec: RTCRtpCodecParameters):
        """
        Receive the next frame from the track and encode it.

        The encoder is created lazily from `codec` on first use, and encoding
        runs in the default executor. A pending keyframe request is consumed
        by this call.
        """
        # get frame
        frame = await self.__track.recv()

        # encode frame
        if self.__encoder is None:
            self.__encoder = get_encoder(codec)
        force_keyframe = self.__force_keyframe
        self.__force_keyframe = False
        return await self.__loop.run_in_executor(
            None, self.__encoder.encode, frame, force_keyframe
        )

    async def _retransmit(self, sequence_number: int) -> None:
        """
        Retransmit an RTP packet which was reported as lost.

        The packet is only resent if it is still in the history buffer. When
        an RTX payload type was negotiated it is wrapped in RTX first.
        """
        packet = self.__rtp_history.get(sequence_number % RTP_HISTORY_SIZE)
        if packet and packet.sequence_number == sequence_number:
            if self.__rtx_payload_type is not None:
                packet = wrap_rtx(
                    packet,
                    payload_type=self.__rtx_payload_type,
                    sequence_number=self.__rtx_sequence_number,
                    ssrc=self._rtx_ssrc,
                )
                self.__rtx_sequence_number = uint16_add(self.__rtx_sequence_number, 1)

            self.__log_debug("> %s", packet)
            packet_bytes = packet.serialize(self.__rtp_header_extensions_map)
            await self.transport._send_rtp(packet_bytes)

    def _send_keyframe(self) -> None:
        """
        Request the next frame to be a keyframe.
        """
        self.__force_keyframe = True

    async def _run_rtp(self, codec: RTCRtpCodecParameters) -> None:
        """
        Encode frames from the track and send them as RTP until cancelled,
        the connection fails or the track ends. The track is then stopped.
        """
        self.__log_debug("- RTP started")
        self.__rtp_started.set()

        sequence_number = random16()
        timestamp_origin = random32()
        try:
            while True:
                if not self.__track:
                    await asyncio.sleep(0.02)
                    continue

                payloads, timestamp = await self._next_encoded_frame(codec)
                timestamp = uint32_add(timestamp_origin, timestamp)

                last_index = len(payloads) - 1
                for i, payload in enumerate(payloads):
                    packet = RtpPacket(
                        payload_type=codec.payloadType,
                        sequence_number=sequence_number,
                        timestamp=timestamp,
                    )
                    packet.ssrc = self._ssrc
                    packet.payload = payload
                    # the marker bit flags the last packet of a frame
                    packet.marker = 1 if i == last_index else 0

                    # set header extensions
                    packet.extensions.abs_send_time = (
                        clock.current_ntp_time() >> 14
                    ) & 0x00FFFFFF
                    packet.extensions.mid = self.__mid

                    # send packet, remembering it for possible retransmission
                    self.__log_debug("> %s", packet)
                    self.__rtp_history[
                        packet.sequence_number % RTP_HISTORY_SIZE
                    ] = packet
                    packet_bytes = packet.serialize(self.__rtp_header_extensions_map)
                    await self.transport._send_rtp(packet_bytes)

                    self.__ntp_timestamp = clock.current_ntp_time()
                    self.__rtp_timestamp = packet.timestamp
                    self.__octet_count += len(payload)
                    self.__packet_count += 1
                    sequence_number = uint16_add(sequence_number, 1)
        except (asyncio.CancelledError, ConnectionError, MediaStreamError):
            pass
        except Exception:
            # we *need* to set __rtp_exited, otherwise RTCRtpSender.stop() will hang,
            # so issue a warning if we hit an unexpected exception. The traceback
            # is passed as an argument so '%' characters in it are not treated
            # as format directives.
            self.__log_warning("%s", traceback.format_exc())

        # stop track
        if self.__track:
            self.__track.stop()
            self.__track = None

        self.__log_debug("- RTP finished")
        self.__rtp_exited.set()

    async def _run_rtcp(self) -> None:
        """
        Periodically send an RTCP sender report (plus SDES when a CNAME is
        known) until cancelled, then send an RTCP BYE.
        """
        self.__log_debug("- RTCP started")
        self.__rtcp_started.set()

        try:
            while True:
                # The interval between RTCP packets is varied randomly over the
                # range [0.5, 1.5] times the calculated interval.
                await asyncio.sleep(0.5 + random.random())

                # RTCP SR
                packets: List[AnyRtcpPacket] = [
                    RtcpSrPacket(
                        ssrc=self._ssrc,
                        sender_info=RtcpSenderInfo(
                            ntp_timestamp=self.__ntp_timestamp,
                            rtp_timestamp=self.__rtp_timestamp,
                            packet_count=self.__packet_count,
                            octet_count=self.__octet_count,
                        ),
                    )
                ]
                # remember the middle 32 bits of the NTP timestamp we sent and
                # when we sent it, to compute RTT from incoming reports
                self.__lsr = (self.__ntp_timestamp >> 16) & 0xFFFFFFFF
                self.__lsr_time = time.time()

                # RTCP SDES
                if self.__cname is not None:
                    packets.append(
                        RtcpSdesPacket(
                            chunks=[
                                RtcpSourceInfo(
                                    ssrc=self._ssrc,
                                    items=[(1, self.__cname.encode("utf8"))],
                                )
                            ]
                        )
                    )

                await self._send_rtcp(packets)
        except asyncio.CancelledError:
            pass
        except Exception:
            # as in _run_rtp: __rtcp_exited *must* be set or stop() will hang
            self.__log_warning("%s", traceback.format_exc())

        # RTCP BYE
        packet = RtcpByePacket(sources=[self._ssrc])
        await self._send_rtcp([packet])

        self.__log_debug("- RTCP finished")
        self.__rtcp_exited.set()

    async def _send_rtcp(self, packets: List[AnyRtcpPacket]) -> None:
        """
        Serialize `packets` into one compound payload and send it through the
        transport; connection errors are ignored.
        """
        chunks = []
        for packet in packets:
            self.__log_debug("> %s", packet)
            chunks.append(bytes(packet))
        payload = b"".join(chunks)

        try:
            await self.transport._send_rtp(payload)
        except ConnectionError:
            pass

    def __log_debug(self, msg: str, *args) -> None:
        """Log at DEBUG level; `msg` is a %-style format applied to `args`."""
        logger.debug(f"RTCRtpSender(%s) {msg}", self.__kind, *args)

    def __log_warning(self, msg: str, *args) -> None:
        """Log at WARNING level; `msg` is a %-style format applied to `args`."""
        logger.warning(f"RTCRtpSender(%s) {msg}", self.__kind, *args)
406