1import asyncio
2from unittest import TestCase
3
4import aioice.stun
5from aioice import ConnectionClosed
6
7from aiortc.exceptions import InvalidStateError
8from aiortc.rtcconfiguration import RTCIceServer
9from aiortc.rtcicetransport import (
10    RTCIceCandidate,
11    RTCIceGatherer,
12    RTCIceParameters,
13    RTCIceTransport,
14    connection_kwargs,
15    parse_stun_turn_uri,
16)
17
18from .utils import run
19
20
async def mock_connect():
    """Coroutine stub used in place of a connection's ``connect``; does nothing."""
    return None
23
24
async def mock_get_event():
    """Coroutine stub for ``get_event``: wait briefly, then return a close event."""
    delay = 0.5  # seconds to wait before handing back the ConnectionClosed event
    await asyncio.sleep(delay)
    return ConnectionClosed()
28
29
class ConnectionKwargsTest(TestCase):
    """Check the keyword arguments ``connection_kwargs`` builds from ICE servers."""

    @staticmethod
    def _turn_kwargs(
        host,
        port=3478,
        ssl=False,
        transport="udp",
        username=None,
        password=None,
    ):
        """Build the TURN keyword dictionary a test expects to get back."""
        return {
            "turn_password": password,
            "turn_server": (host, port),
            "turn_ssl": ssl,
            "turn_transport": transport,
            "turn_username": username,
        }

    def test_empty(self):
        # no servers -> no keyword arguments
        kwargs = connection_kwargs([])
        self.assertEqual(kwargs, {})

    def test_stun(self):
        servers = [RTCIceServer("stun:stun.l.google.com:19302")]
        expected = {"stun_server": ("stun.l.google.com", 19302)}
        self.assertEqual(connection_kwargs(servers), expected)

    def test_stun_multiple_servers(self):
        # the expectation only contains the first of the two STUN servers
        servers = [
            RTCIceServer("stun:stun.l.google.com:19302"),
            RTCIceServer("stun:stun.example.com"),
        ]
        expected = {"stun_server": ("stun.l.google.com", 19302)}
        self.assertEqual(connection_kwargs(servers), expected)

    def test_stun_multiple_urls(self):
        # the expectation only contains the first URL of the single server
        urls = [
            "stun:stun1.l.google.com:19302",
            "stun:stun2.l.google.com:19302",
        ]
        expected = {"stun_server": ("stun1.l.google.com", 19302)}
        self.assertEqual(connection_kwargs([RTCIceServer(urls)]), expected)

    def test_turn(self):
        servers = [RTCIceServer("turn:turn.example.com")]
        self.assertEqual(
            connection_kwargs(servers), self._turn_kwargs("turn.example.com")
        )

    def test_turn_multiple_servers(self):
        # the expectation only contains the first of the two TURN servers
        servers = [
            RTCIceServer("turn:turn.example.com"),
            RTCIceServer("turn:turn.example.net"),
        ]
        self.assertEqual(
            connection_kwargs(servers), self._turn_kwargs("turn.example.com")
        )

    def test_turn_multiple_urls(self):
        # the expectation only contains the first URL of the single server
        urls = ["turn:turn1.example.com", "turn:turn2.example.com"]
        self.assertEqual(
            connection_kwargs([RTCIceServer(urls)]),
            self._turn_kwargs("turn1.example.com"),
        )

    def test_turn_over_bogus(self):
        # an unknown transport yields no keyword arguments
        servers = [RTCIceServer("turn:turn.example.com?transport=bogus")]
        self.assertEqual(connection_kwargs(servers), {})

    def test_turn_over_tcp(self):
        servers = [RTCIceServer("turn:turn.example.com?transport=tcp")]
        self.assertEqual(
            connection_kwargs(servers),
            self._turn_kwargs("turn.example.com", transport="tcp"),
        )

    def test_turn_with_password(self):
        server = RTCIceServer(
            urls="turn:turn.example.com", username="foo", credential="bar"
        )
        self.assertEqual(
            connection_kwargs([server]),
            self._turn_kwargs("turn.example.com", username="foo", password="bar"),
        )

    def test_turn_with_token(self):
        # a "token" credential type yields no keyword arguments
        server = RTCIceServer(
            urls="turn:turn.example.com",
            username="foo",
            credential="bar",
            credentialType="token",
        )
        self.assertEqual(connection_kwargs([server]), {})

    def test_turns(self):
        servers = [RTCIceServer("turns:turn.example.com")]
        self.assertEqual(
            connection_kwargs(servers),
            self._turn_kwargs("turn.example.com", port=5349, ssl=True, transport="tcp"),
        )

    def test_turns_over_udp(self):
        # "turns" combined with the UDP transport yields no keyword arguments
        servers = [RTCIceServer("turns:turn.example.com?transport=udp")]
        self.assertEqual(connection_kwargs(servers), {})
177
178
class ParseStunTurnUriTest(TestCase):
    """Check the dictionaries ``parse_stun_turn_uri`` returns and its errors."""

    def test_invalid_scheme(self):
        with self.assertRaises(ValueError) as raised:
            parse_stun_turn_uri("foo")
        message = str(raised.exception)
        self.assertEqual(message, "malformed uri: invalid scheme")

    def test_invalid_uri(self):
        with self.assertRaises(ValueError) as raised:
            parse_stun_turn_uri("stun")
        message = str(raised.exception)
        self.assertEqual(message, "malformed uri")

    def test_stun(self):
        expected = dict(host="stun.services.mozilla.com", port=3478, scheme="stun")
        parsed = parse_stun_turn_uri("stun:stun.services.mozilla.com")
        self.assertEqual(parsed, expected)

    def test_stuns(self):
        expected = dict(host="stun.services.mozilla.com", port=5349, scheme="stuns")
        parsed = parse_stun_turn_uri("stuns:stun.services.mozilla.com")
        self.assertEqual(parsed, expected)

    def test_stun_with_port(self):
        expected = dict(host="stun.l.google.com", port=19302, scheme="stun")
        parsed = parse_stun_turn_uri("stun:stun.l.google.com:19302")
        self.assertEqual(parsed, expected)

    def test_turn(self):
        expected = dict(host="1.2.3.4", port=3478, scheme="turn", transport="udp")
        parsed = parse_stun_turn_uri("turn:1.2.3.4")
        self.assertEqual(parsed, expected)

    def test_turn_with_port_and_transport(self):
        expected = dict(host="1.2.3.4", port=3478, scheme="turn", transport="tcp")
        parsed = parse_stun_turn_uri("turn:1.2.3.4:3478?transport=tcp")
        self.assertEqual(parsed, expected)

    def test_turns(self):
        expected = dict(host="1.2.3.4", port=5349, scheme="turns", transport="tcp")
        parsed = parse_stun_turn_uri("turns:1.2.3.4")
        self.assertEqual(parsed, expected)

    def test_turns_with_port_and_transport(self):
        expected = dict(host="1.2.3.4", port=1234, scheme="turns", transport="tcp")
        parsed = parse_stun_turn_uri("turns:1.2.3.4:1234?transport=tcp")
        self.assertEqual(parsed, expected)
233
234
class RTCIceGathererTest(TestCase):
    """Tests for ``RTCIceGatherer``: candidate gathering and default ICE servers."""

    def test_gather(self):
        """Gathering moves the state from "new" to "completed" and yields candidates."""
        gatherer = RTCIceGatherer()
        self.assertEqual(gatherer.state, "new")
        self.assertEqual(gatherer.getLocalCandidates(), [])

        try:
            run(gatherer.gather())
            self.assertEqual(gatherer.state, "completed")
            # assertGreater reports the actual count if the check fails
            self.assertGreater(len(gatherer.getLocalCandidates()), 0)
        finally:
            # close the underlying connection even when gathering or an
            # assertion above fails, so it is never left open
            run(gatherer._connection.close())

    def test_default_ice_servers(self):
        """The default ICE server list holds the single Google STUN server."""
        self.assertEqual(
            RTCIceGatherer.getDefaultIceServers(),
            [RTCIceServer(urls="stun:stun.l.google.com:19302")],
        )
252
253
class RTCIceTransportTest(TestCase):
    """Tests for ``RTCIceTransport`` state changes, run with shortened STUN timers."""

    def setUp(self):
        # remember the current timers so tearDown can put them back
        self.retry_max, self.retry_rto = (
            aioice.stun.RETRY_MAX,
            aioice.stun.RETRY_RTO,
        )

        # shorten timers to run tests faster
        aioice.stun.RETRY_MAX, aioice.stun.RETRY_RTO = 1, 0.1

    def tearDown(self):
        # put the saved timers back
        aioice.stun.RETRY_MAX, aioice.stun.RETRY_RTO = (
            self.retry_max,
            self.retry_rto,
        )

    def _prepare_pair(self):
        """Create two gatherer/transport pairs and cross-feed their candidates.

        Returns ``(gatherer_a, transport_a, gatherer_b, transport_b)`` after
        checking that both transports are still in the "new" state.
        """
        gatherer_a = RTCIceGatherer()
        transport_a = RTCIceTransport(gatherer_a)

        gatherer_b = RTCIceGatherer()
        transport_b = RTCIceTransport(gatherer_b)

        # gather candidates on both sides
        run(asyncio.gather(gatherer_a.gather(), gatherer_b.gather()))

        # each transport receives the candidates gathered by its peer
        for source, sink in ((gatherer_b, transport_a), (gatherer_a, transport_b)):
            for candidate in source.getLocalCandidates():
                run(sink.addRemoteCandidate(candidate))

        for transport in (transport_a, transport_b):
            self.assertEqual(transport.state, "new")

        return gatherer_a, transport_a, gatherer_b, transport_b

    def _stop_and_check(self, *transports):
        """Stop the given transports together and check each ends up "closed"."""
        run(asyncio.gather(*[transport.stop() for transport in transports]))
        for transport in transports:
            self.assertEqual(transport.state, "closed")

    def test_construct(self):
        gatherer = RTCIceGatherer()
        transport = RTCIceTransport(gatherer)
        self.assertEqual(transport.state, "new")
        self.assertEqual(transport.getRemoteCandidates(), [])

        candidate_fields = {
            "component": 1,
            "foundation": "0",
            "ip": "192.168.99.7",
            "port": 33543,
            "priority": 2122252543,
            "protocol": "UDP",
            "type": "host",
        }
        candidate = RTCIceCandidate(**candidate_fields)

        # adding a candidate makes it show up in the remote list
        run(transport.addRemoteCandidate(candidate))
        self.assertEqual(transport.getRemoteCandidates(), [candidate])

        # the end-of-candidates marker (None) leaves the list unchanged
        run(transport.addRemoteCandidate(None))
        self.assertEqual(transport.getRemoteCandidates(), [candidate])

    def test_connect(self):
        gatherer_a, transport_a, gatherer_b, transport_b = self._prepare_pair()

        # connect: each side starts with the parameters of its peer
        start_a = transport_a.start(gatherer_b.getLocalParameters())
        start_b = transport_b.start(gatherer_a.getLocalParameters())
        run(asyncio.gather(start_a, start_b))
        for transport in (transport_a, transport_b):
            self.assertEqual(transport.state, "completed")

        # cleanup
        self._stop_and_check(transport_a, transport_b)

    def test_connect_fail(self):
        _, transport_a, gatherer_b, transport_b = self._prepare_pair()

        # stop the peer first, then try to connect to it
        run(transport_b.stop())
        run(transport_a.start(gatherer_b.getLocalParameters()))
        self.assertEqual(transport_a.state, "failed")
        self.assertEqual(transport_b.state, "closed")

        # cleanup (the second transport is stopped a second time here)
        self._stop_and_check(transport_a, transport_b)

    def test_connect_when_closed(self):
        transport = RTCIceTransport(RTCIceGatherer())

        # stop the transport before it was ever started
        run(transport.stop())
        self.assertEqual(transport.state, "closed")

        # starting it afterwards must be rejected
        with self.assertRaises(InvalidStateError) as raised:
            parameters = RTCIceParameters(usernameFragment="foo", password="bar")
            run(transport.start(parameters))
        self.assertEqual(str(raised.exception), "RTCIceTransport is closed")

    def test_connection_closed(self):
        gatherer = RTCIceGatherer()

        # swap in the module-level stubs for the connection's coroutines
        connection = gatherer._connection
        connection.connect = mock_connect
        connection.get_event = mock_get_event

        transport = RTCIceTransport(gatherer)
        self.assertEqual(transport.state, "new")

        parameters = RTCIceParameters(usernameFragment="foo", password="bar")
        run(transport.start(parameters))
        self.assertEqual(transport.state, "completed")

        # wait longer than the stubbed get_event takes to return its close event
        run(asyncio.sleep(1))
        self.assertEqual(transport.state, "failed")

        run(transport.stop())
        self.assertEqual(transport.state, "closed")
386