import asyncio
from unittest import TestCase

import aioice.stun
from aioice import ConnectionClosed

from aiortc.exceptions import InvalidStateError
from aiortc.rtcconfiguration import RTCIceServer
from aiortc.rtcicetransport import (
    RTCIceCandidate,
    RTCIceGatherer,
    RTCIceParameters,
    RTCIceTransport,
    connection_kwargs,
    parse_stun_turn_uri,
)

from .utils import run


async def mock_connect():
    """Stand-in for the ICE connection's ``connect`` that returns immediately."""
    pass


async def mock_get_event():
    """Stand-in for the ICE connection's ``get_event`` reporting a closed connection.

    The delay gives a test time to observe the transport's "completed" state
    before the :class:`aioice.ConnectionClosed` event is delivered.
    """
    await asyncio.sleep(0.5)
    return ConnectionClosed()


class ConnectionKwargsTest(TestCase):
    """Tests for mapping ``RTCIceServer`` lists to ICE connection kwargs."""

    def test_empty(self):
        """No ICE servers yields no connection kwargs."""
        self.assertEqual(connection_kwargs([]), {})

    def test_stun(self):
        """A STUN URL maps to ``stun_server`` as a (host, port) tuple."""
        self.assertEqual(
            connection_kwargs([RTCIceServer("stun:stun.l.google.com:19302")]),
            {"stun_server": ("stun.l.google.com", 19302)},
        )

    def test_stun_multiple_servers(self):
        """Only the first STUN server is used."""
        self.assertEqual(
            connection_kwargs(
                [
                    RTCIceServer("stun:stun.l.google.com:19302"),
                    RTCIceServer("stun:stun.example.com"),
                ]
            ),
            {"stun_server": ("stun.l.google.com", 19302)},
        )

    def test_stun_multiple_urls(self):
        """Only the first URL of a STUN server is used."""
        self.assertEqual(
            connection_kwargs(
                [
                    RTCIceServer(
                        [
                            "stun:stun1.l.google.com:19302",
                            "stun:stun2.l.google.com:19302",
                        ]
                    )
                ]
            ),
            {"stun_server": ("stun1.l.google.com", 19302)},
        )

    def test_turn(self):
        """A TURN URL without a port defaults to 3478 over UDP, no credentials."""
        self.assertEqual(
            connection_kwargs([RTCIceServer("turn:turn.example.com")]),
            {
                "turn_password": None,
                "turn_server": ("turn.example.com", 3478),
                "turn_ssl": False,
                "turn_transport": "udp",
                "turn_username": None,
            },
        )

    def test_turn_multiple_servers(self):
        """Only the first TURN server is used."""
        self.assertEqual(
            connection_kwargs(
                [
                    RTCIceServer("turn:turn.example.com"),
                    RTCIceServer("turn:turn.example.net"),
                ]
            ),
            {
                "turn_password": None,
                "turn_server": ("turn.example.com", 3478),
                "turn_ssl": False,
                "turn_transport": "udp",
                "turn_username": None,
            },
        )

    def test_turn_multiple_urls(self):
        """Only the first URL of a TURN server is used."""
        self.assertEqual(
            connection_kwargs(
                [RTCIceServer(["turn:turn1.example.com", "turn:turn2.example.com"])]
            ),
            {
                "turn_password": None,
                "turn_server": ("turn1.example.com", 3478),
                "turn_ssl": False,
                "turn_transport": "udp",
                "turn_username": None,
            },
        )

    def test_turn_over_bogus(self):
        """A TURN URL with an unknown transport is ignored."""
        self.assertEqual(
            connection_kwargs([RTCIceServer("turn:turn.example.com?transport=bogus")]),
            {},
        )

    def test_turn_over_tcp(self):
        """``?transport=tcp`` selects TCP for TURN."""
        self.assertEqual(
            connection_kwargs([RTCIceServer("turn:turn.example.com?transport=tcp")]),
            {
                "turn_password": None,
                "turn_server": ("turn.example.com", 3478),
                "turn_ssl": False,
                "turn_transport": "tcp",
                "turn_username": None,
            },
        )

    def test_turn_with_password(self):
        """Username and credential are passed through as TURN credentials."""
        self.assertEqual(
            connection_kwargs(
                [
                    RTCIceServer(
                        urls="turn:turn.example.com", username="foo", credential="bar"
                    )
                ]
            ),
            {
                "turn_password": "bar",
                "turn_server": ("turn.example.com", 3478),
                "turn_ssl": False,
                "turn_transport": "udp",
                "turn_username": "foo",
            },
        )

    def test_turn_with_token(self):
        """A TURN server using the "token" credential type is ignored."""
        self.assertEqual(
            connection_kwargs(
                [
                    RTCIceServer(
                        urls="turn:turn.example.com",
                        username="foo",
                        credential="bar",
                        credentialType="token",
                    )
                ]
            ),
            {},
        )

    def test_turns(self):
        """A TURNS URL defaults to port 5349 with SSL over TCP."""
        self.assertEqual(
            connection_kwargs([RTCIceServer("turns:turn.example.com")]),
            {
                "turn_password": None,
                "turn_server": ("turn.example.com", 5349),
                "turn_ssl": True,
                "turn_transport": "tcp",
                "turn_username": None,
            },
        )

    def test_turns_over_udp(self):
        """TURNS over UDP is ignored."""
        self.assertEqual(
            connection_kwargs([RTCIceServer("turns:turn.example.com?transport=udp")]),
            {},
        )


class ParseStunTurnUriTest(TestCase):
    """Tests for parsing ``stun:``/``stuns:``/``turn:``/``turns:`` URIs."""

    def test_invalid_scheme(self):
        """An unknown scheme raises ValueError."""
        with self.assertRaises(ValueError) as cm:
            parse_stun_turn_uri("foo")
        self.assertEqual(str(cm.exception), "malformed uri: invalid scheme")

    def test_invalid_uri(self):
        """A bare scheme with no host raises ValueError."""
        with self.assertRaises(ValueError) as cm:
            parse_stun_turn_uri("stun")
        self.assertEqual(str(cm.exception), "malformed uri")

    def test_stun(self):
        """STUN defaults to port 3478."""
        uri = parse_stun_turn_uri("stun:stun.services.mozilla.com")
        self.assertEqual(
            uri, {"host": "stun.services.mozilla.com", "port": 3478, "scheme": "stun"}
        )

    def test_stuns(self):
        """STUNS defaults to port 5349."""
        uri = parse_stun_turn_uri("stuns:stun.services.mozilla.com")
        self.assertEqual(
            uri, {"host": "stun.services.mozilla.com", "port": 5349, "scheme": "stuns"}
        )

    def test_stun_with_port(self):
        """An explicit port overrides the default."""
        uri = parse_stun_turn_uri("stun:stun.l.google.com:19302")
        self.assertEqual(
            uri, {"host": "stun.l.google.com", "port": 19302, "scheme": "stun"}
        )

    def test_turn(self):
        """TURN defaults to port 3478 over UDP."""
        uri = parse_stun_turn_uri("turn:1.2.3.4")
        self.assertEqual(
            uri, {"host": "1.2.3.4", "port": 3478, "scheme": "turn", "transport": "udp"}
        )

    def test_turn_with_port_and_transport(self):
        """An explicit port and transport are parsed for TURN."""
        uri = parse_stun_turn_uri("turn:1.2.3.4:3478?transport=tcp")
        self.assertEqual(
            uri, {"host": "1.2.3.4", "port": 3478, "scheme": "turn", "transport": "tcp"}
        )

    def test_turns(self):
        """TURNS defaults to port 5349 over TCP."""
        uri = parse_stun_turn_uri("turns:1.2.3.4")
        self.assertEqual(
            uri,
            {"host": "1.2.3.4", "port": 5349, "scheme": "turns", "transport": "tcp"},
        )

    def test_turns_with_port_and_transport(self):
        """An explicit port and transport are parsed for TURNS."""
        uri = parse_stun_turn_uri("turns:1.2.3.4:1234?transport=tcp")
        self.assertEqual(
            uri,
            {"host": "1.2.3.4", "port": 1234, "scheme": "turns", "transport": "tcp"},
        )


class RTCIceGathererTest(TestCase):
    """Tests for :class:`RTCIceGatherer`."""

    def test_gather(self):
        """Gathering moves from "new" to "completed" and yields candidates."""
        gatherer = RTCIceGatherer()
        self.assertEqual(gatherer.state, "new")
        self.assertEqual(gatherer.getLocalCandidates(), [])
        run(gatherer.gather())
        self.assertEqual(gatherer.state, "completed")
        # assertGreater reports the actual count on failure
        self.assertGreater(len(gatherer.getLocalCandidates()), 0)

        # close
        run(gatherer._connection.close())

    def test_default_ice_servers(self):
        """The default ICE server list is Google's public STUN server."""
        self.assertEqual(
            RTCIceGatherer.getDefaultIceServers(),
            [RTCIceServer(urls="stun:stun.l.google.com:19302")],
        )


class RTCIceTransportTest(TestCase):
    """Tests for :class:`RTCIceTransport` state transitions."""

    def setUp(self):
        # save the STUN retransmission timers so tearDown can restore them
        self.retry_max = aioice.stun.RETRY_MAX
        self.retry_rto = aioice.stun.RETRY_RTO

        # shorten timers to run tests faster
        aioice.stun.RETRY_MAX = 1
        aioice.stun.RETRY_RTO = 0.1

    def tearDown(self):
        # restore timers
        aioice.stun.RETRY_MAX = self.retry_max
        aioice.stun.RETRY_RTO = self.retry_rto

    def _create_transport_pair(self):
        """Create two transports that each know the other side's candidates.

        Both gatherers gather concurrently, each transport is given the peer's
        local candidates, and both transports are checked to still be "new".

        Returns:
            tuple: ``(gatherer_1, transport_1, gatherer_2, transport_2)``.
        """
        gatherer_1 = RTCIceGatherer()
        transport_1 = RTCIceTransport(gatherer_1)

        gatherer_2 = RTCIceGatherer()
        transport_2 = RTCIceTransport(gatherer_2)

        # gather candidates and hand each side the peer's candidates
        run(asyncio.gather(gatherer_1.gather(), gatherer_2.gather()))
        for candidate in gatherer_2.getLocalCandidates():
            run(transport_1.addRemoteCandidate(candidate))
        for candidate in gatherer_1.getLocalCandidates():
            run(transport_2.addRemoteCandidate(candidate))
        self.assertEqual(transport_1.state, "new")
        self.assertEqual(transport_2.state, "new")

        return gatherer_1, transport_1, gatherer_2, transport_2

    def test_construct(self):
        """Remote candidates are recorded; the end-of-candidates marker is not."""
        gatherer = RTCIceGatherer()
        connection = RTCIceTransport(gatherer)
        self.assertEqual(connection.state, "new")
        self.assertEqual(connection.getRemoteCandidates(), [])

        candidate = RTCIceCandidate(
            component=1,
            foundation="0",
            ip="192.168.99.7",
            port=33543,
            priority=2122252543,
            protocol="UDP",
            type="host",
        )

        # add candidate
        run(connection.addRemoteCandidate(candidate))
        self.assertEqual(connection.getRemoteCandidates(), [candidate])

        # end-of-candidates (None) does not add a candidate
        run(connection.addRemoteCandidate(None))
        self.assertEqual(connection.getRemoteCandidates(), [candidate])

        # cleanup, as in the other transport tests
        run(connection.stop())
        self.assertEqual(connection.state, "closed")

    def test_connect(self):
        """Two transports connect to "completed", then stop to "closed"."""
        gatherer_1, transport_1, gatherer_2, transport_2 = (
            self._create_transport_pair()
        )

        # connect
        run(
            asyncio.gather(
                transport_1.start(gatherer_2.getLocalParameters()),
                transport_2.start(gatherer_1.getLocalParameters()),
            )
        )
        self.assertEqual(transport_1.state, "completed")
        self.assertEqual(transport_2.state, "completed")

        # cleanup
        run(asyncio.gather(transport_1.stop(), transport_2.stop()))
        self.assertEqual(transport_1.state, "closed")
        self.assertEqual(transport_2.state, "closed")

    def test_connect_fail(self):
        """Starting against a peer that has already stopped ends in "failed"."""
        _, transport_1, gatherer_2, transport_2 = self._create_transport_pair()

        # stop the peer first so that connectivity checks cannot succeed
        run(transport_2.stop())
        run(transport_1.start(gatherer_2.getLocalParameters()))
        self.assertEqual(transport_1.state, "failed")
        self.assertEqual(transport_2.state, "closed")

        # cleanup (stopping transport_2 again must also be accepted)
        run(asyncio.gather(transport_1.stop(), transport_2.stop()))
        self.assertEqual(transport_1.state, "closed")
        self.assertEqual(transport_2.state, "closed")

    def test_connect_when_closed(self):
        """Starting a closed transport raises InvalidStateError."""
        gatherer = RTCIceGatherer()
        transport = RTCIceTransport(gatherer)

        # stop transport
        run(transport.stop())
        self.assertEqual(transport.state, "closed")

        # try to start it
        with self.assertRaises(InvalidStateError) as cm:
            run(
                transport.start(
                    RTCIceParameters(usernameFragment="foo", password="bar")
                )
            )
        self.assertEqual(str(cm.exception), "RTCIceTransport is closed")

    def test_connection_closed(self):
        """A ConnectionClosed event after connecting moves the transport to "failed"."""
        gatherer = RTCIceGatherer()

        # mock out methods: connect succeeds at once, then the connection
        # reports ConnectionClosed after a short delay
        gatherer._connection.connect = mock_connect
        gatherer._connection.get_event = mock_get_event

        transport = RTCIceTransport(gatherer)
        self.assertEqual(transport.state, "new")

        run(transport.start(RTCIceParameters(usernameFragment="foo", password="bar")))
        self.assertEqual(transport.state, "completed")

        # wait longer than mock_get_event's delay so the event is processed
        run(asyncio.sleep(1))
        self.assertEqual(transport.state, "failed")

        run(transport.stop())
        self.assertEqual(transport.state, "closed")