"""Test asyncio support"""
# Copyright (c) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.

import json
from multiprocessing import Process
import os
import sys

import pytest
from pytest import mark

import zmq
from zmq.utils.strtypes import u

import asyncio
import zmq.asyncio as zaio
from zmq.auth.asyncio import AsyncioAuthenticator

from concurrent.futures import CancelledError
from zmq.tests import BaseZMQTestCase
from zmq.tests.test_auth import TestThreadAuthentication


class ProcessForTeardownTest(Process):
    """Child process that exits while an asyncio socket and loop are still open.

    Used by ``TestAsyncIOSocket.test_process_teardown`` to check that the
    interpreter can shut down without explicit cleanup of the context,
    the socket or the event loop.
    """

    def __init__(self, event_loop_policy_class):
        Process.__init__(self)
        # Class (not instance) of the parent's event loop policy; it is
        # instantiated inside the child in run().
        self.event_loop_policy_class = event_loop_policy_class

    def run(self):
        """Leave context, socket and event loop upon implicit disposal"""
        asyncio.set_event_loop_policy(self.event_loop_policy_class())

        actx = zaio.Context.instance()
        socket = actx.socket(zmq.PAIR)
        socket.bind_to_random_port("tcp://127.0.0.1")

        async def never_ending_task(socket):
            await socket.recv()  # never ever receive anything

        loop = asyncio.get_event_loop()
        coro = asyncio.wait_for(never_ending_task(socket), timeout=1)
        try:
            loop.run_until_complete(coro)
        except asyncio.TimeoutError:
            pass  # expected timeout
        else:
            # Raise explicitly rather than `assert False` so the child still
            # exits with a non-zero code when run under `python -O`.
            raise AssertionError("never_ending_task was completed unexpectedly")


class TestAsyncIOSocket(BaseZMQTestCase):
    # Tests for zmq.asyncio sockets and pollers.  Each test runs its
    # coroutine on a fresh event loop created in setUp.
    #
    # NOTE: test methods are documented with comments rather than docstrings
    # so that unittest's shortDescription() output is not altered.

    Context = zaio.Context

    def setUp(self):
        # Install a new loop as the current one *before* the base-class
        # setUp runs.
        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)
        super().setUp()

    def tearDown(self):
        super().tearDown()
        self.loop.close()
        # verify cleanup of references to selectors
        assert zaio._selectors == {}
        if 'zmq._asyncio_selector' in sys.modules:
            assert zmq._asyncio_selector._selector_loops == set()

    def test_socket_class(self):
        # Sockets created from the asyncio Context are asyncio Sockets.
        s = self.context.socket(zmq.PUSH)
        assert isinstance(s, zaio.Socket)
        s.close()

    def test_instance_subclass_first(self):
        # Context.instance() yields per-class singletons: asyncio one first.
        actx = zmq.asyncio.Context.instance()
        ctx = zmq.Context.instance()
        ctx.term()
        actx.term()
        assert type(ctx) is zmq.Context
        assert type(actx) is zmq.asyncio.Context

    def test_instance_subclass_second(self):
        # Same as above with the base Context requested first.
        ctx = zmq.Context.instance()
        actx = zmq.asyncio.Context.instance()
        ctx.term()
        actx.term()
        assert type(ctx) is zmq.Context
        assert type(actx) is zmq.asyncio.Context

    def test_recv_multipart(self):
        async def test():
            a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
            f = b.recv_multipart()
            assert not f.done()
            await a.send(b"hi")
            recvd = await f
            self.assertEqual(recvd, [b"hi"])

        self.loop.run_until_complete(test())

    def test_recv(self):
        # Two pending recv() calls each receive one frame of a two-frame
        # message, in order.
        async def test():
            a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
            f1 = b.recv()
            f2 = b.recv()
            assert not f1.done()
            assert not f2.done()
            await a.send_multipart([b"hi", b"there"])
            recvd = await f2
            assert f1.done()
            self.assertEqual(f1.result(), b"hi")
            self.assertEqual(recvd, b"there")

        self.loop.run_until_complete(test())

    @mark.skipif(not hasattr(zmq, "RCVTIMEO"), reason="requires RCVTIMEO")
    def test_recv_timeout(self):
        # The first recv is issued with rcvtimeo=100 and is expected to fail
        # with zmq.Again; the second, issued with rcvtimeo=1000, gets the
        # message sent afterwards.
        async def test():
            a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
            b.rcvtimeo = 100
            f1 = b.recv()
            b.rcvtimeo = 1000
            f2 = b.recv_multipart()
            with self.assertRaises(zmq.Again):
                await f1
            await a.send_multipart([b"hi", b"there"])
            recvd = await f2
            assert f2.done()
            self.assertEqual(recvd, [b"hi", b"there"])

        self.loop.run_until_complete(test())

    @mark.skipif(not hasattr(zmq, "SNDTIMEO"), reason="requires SNDTIMEO")
    def test_send_timeout(self):
        # Sending on an unconnected PUSH socket with sndtimeo set must raise
        # zmq.Again.
        async def test():
            s = self.socket(zmq.PUSH)
            s.sndtimeo = 100
            with self.assertRaises(zmq.Again):
                await s.send(b"not going anywhere")

        self.loop.run_until_complete(test())

    def test_recv_string(self):
        async def test():
            a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
            f = b.recv_string()
            assert not f.done()
            msg = u("πøøπ")
            await a.send_string(msg)
            recvd = await f
            assert f.done()
            self.assertEqual(f.result(), msg)
            self.assertEqual(recvd, msg)

        self.loop.run_until_complete(test())

    def test_recv_json(self):
        async def test():
            a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
            f = b.recv_json()
            assert not f.done()
            obj = dict(a=5)
            await a.send_json(obj)
            recvd = await f
            assert f.done()
            self.assertEqual(f.result(), obj)
            self.assertEqual(recvd, obj)

        self.loop.run_until_complete(test())

    def test_recv_json_cancelled(self):
        # A cancelled recv_json() must not consume the message that arrives
        # after the cancellation.
        async def test():
            a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
            f = b.recv_json()
            assert not f.done()
            f.cancel()
            # cycle eventloop to allow cancel events to fire
            await asyncio.sleep(0)
            obj = dict(a=5)
            await a.send_json(obj)
            # CancelledError change in 3.8 https://bugs.python.org/issue32528
            if sys.version_info < (3, 8):
                cancelled_error = CancelledError
            else:
                cancelled_error = asyncio.exceptions.CancelledError
            with pytest.raises(cancelled_error):
                await f
            assert f.done()
            # give it a chance to incorrectly consume the event
            events = await b.poll(timeout=5)
            assert events
            await asyncio.sleep(0)
            # make sure cancelled recv didn't eat up event
            f = b.recv_json()
            recvd = await asyncio.wait_for(f, timeout=5)
            assert recvd == obj

        self.loop.run_until_complete(test())

    def test_recv_pyobj(self):
        async def test():
            a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
            f = b.recv_pyobj()
            assert not f.done()
            obj = dict(a=5)
            await a.send_pyobj(obj)
            recvd = await f
            assert f.done()
            self.assertEqual(f.result(), obj)
            self.assertEqual(recvd, obj)

        self.loop.run_until_complete(test())

    def test_custom_serialize(self):
        # Round-trip a message through send_serialized/recv_serialized with
        # user-supplied (de)serializers.

        def serialize(msg):
            # identity frames (if any) first, JSON-encoded content last
            frames = []
            frames.extend(msg.get("identities", []))
            content = json.dumps(msg["content"]).encode("utf8")
            frames.append(content)
            return frames

        def deserialize(frames):
            # inverse of serialize(): last frame is the content
            identities = frames[:-1]
            content = json.loads(frames[-1].decode("utf8"))
            return {
                "identities": identities,
                "content": content,
            }

        async def test():
            a, b = self.create_bound_pair(zmq.DEALER, zmq.ROUTER)

            msg = {
                "content": {
                    "a": 5,
                    "b": "bee",
                }
            }
            await a.send_serialized(msg, serialize)
            recvd = await b.recv_serialized(deserialize)
            assert recvd["content"] == msg["content"]
            assert recvd["identities"]
            # bounce back, tests identities
            await b.send_serialized(recvd, serialize)
            r2 = await a.recv_serialized(deserialize)
            assert r2["content"] == msg["content"]
            assert not r2["identities"]

        self.loop.run_until_complete(test())

    def test_custom_serialize_error(self):
        # Errors raised by the (de)serializer must surface to the caller.
        async def test():
            a, b = self.create_bound_pair(zmq.DEALER, zmq.ROUTER)

            # json.dumps cannot serialize a module object -> TypeError
            with pytest.raises(TypeError):
                await a.send_serialized(json, json.dumps)

            await a.send(b"not json")
            # json.loads is handed the received frames rather than a single
            # str/bytes document -> TypeError
            with pytest.raises(TypeError):
                await b.recv_serialized(json.loads)

        self.loop.run_until_complete(test())

    def test_recv_dontwait(self):
        async def test():
            push, pull = self.create_bound_pair(zmq.PUSH, zmq.PULL)
            f = pull.recv(zmq.DONTWAIT)
            with self.assertRaises(zmq.Again):
                await f
            await push.send(b"ping")
            await pull.poll()  # ensure message will be waiting
            f = pull.recv(zmq.DONTWAIT)
            assert f.done()
            msg = await f
            self.assertEqual(msg, b"ping")

        self.loop.run_until_complete(test())

    def test_recv_cancel(self):
        # Cancelling the first pending recv leaves the whole message for the
        # second one.
        async def test():
            a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
            f1 = b.recv()
            f2 = b.recv_multipart()
            assert f1.cancel()
            assert f1.done()
            assert not f2.done()
            await a.send_multipart([b"hi", b"there"])
            recvd = await f2
            assert f1.cancelled()
            assert f2.done()
            self.assertEqual(recvd, [b"hi", b"there"])

        self.loop.run_until_complete(test())

    def test_poll(self):
        async def test():
            a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
            # timeout=0: result is available after one loop iteration
            f = b.poll(timeout=0)
            await asyncio.sleep(0)
            self.assertEqual(f.result(), 0)

            # short timeout with nothing to read: no events
            f = b.poll(timeout=1)
            assert not f.done()
            evt = await f

            self.assertEqual(evt, 0)

            # message arrives while polling: POLLIN
            f = b.poll(timeout=1000)
            assert not f.done()
            await a.send_multipart([b"hi", b"there"])
            evt = await f
            self.assertEqual(evt, zmq.POLLIN)
            recvd = await b.recv_multipart()
            self.assertEqual(recvd, [b"hi", b"there"])

        self.loop.run_until_complete(test())

    def test_poll_base_socket(self):
        # The asyncio Poller also accepts plain (non-asyncio) zmq sockets.
        async def test():
            ctx = zmq.Context()
            url = "inproc://test"
            a = ctx.socket(zmq.PUSH)
            b = ctx.socket(zmq.PULL)
            self.sockets.extend([a, b])
            a.bind(url)
            b.connect(url)

            poller = zaio.Poller()
            poller.register(b, zmq.POLLIN)

            f = poller.poll(timeout=1000)
            assert not f.done()
            a.send_multipart([b"hi", b"there"])
            evt = await f
            self.assertEqual(evt, [(b, zmq.POLLIN)])
            recvd = b.recv_multipart()
            self.assertEqual(recvd, [b"hi", b"there"])

        self.loop.run_until_complete(test())

    def test_poll_on_closed_socket(self):
        # Closing a socket cancels its pending poll future.
        async def test():
            a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)

            f = b.poll(timeout=1)
            b.close()

            # The test might stall if we try to await f directly so instead just make a few
            # passes through the event loop to schedule and execute all callbacks
            for _ in range(5):
                await asyncio.sleep(0)
                if f.cancelled():
                    break
            assert f.cancelled()

        self.loop.run_until_complete(test())

    @pytest.mark.skipif(
        sys.platform.startswith("win"),
        reason="Windows does not support polling on files",
    )
    def test_poll_raw(self):
        # The asyncio Poller can poll raw file objects (both ends of a pipe).
        async def test():
            p = zaio.Poller()
            # make a pipe
            r, w = os.pipe()
            r = os.fdopen(r, "rb")
            w = os.fdopen(w, "wb")
            try:
                # POLLOUT
                p.register(r, zmq.POLLIN)
                p.register(w, zmq.POLLOUT)
                evts = await p.poll(timeout=1)
                evts = dict(evts)
                assert r.fileno() not in evts
                assert w.fileno() in evts
                assert evts[w.fileno()] == zmq.POLLOUT

                # POLLIN
                p.unregister(w)
                w.write(b"x")
                w.flush()
                evts = await p.poll(timeout=1000)
                evts = dict(evts)
                assert r.fileno() in evts
                assert evts[r.fileno()] == zmq.POLLIN
                assert r.read(1) == b"x"
            finally:
                # close both pipe ends even when an assertion above fails
                r.close()
                w.close()

        self.loop.run_until_complete(test())

    def test_multiple_loops(self):
        # The same socket pair can be driven from several successive loops.
        a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)

        async def test():
            await a.send(b'buf')
            msg = await b.recv()
            assert msg == b'buf'

        for i in range(3):
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            loop.run_until_complete(asyncio.wait_for(test(), timeout=10))
            loop.close()

    def test_shadow(self):
        # Wrapping an existing plain socket yields an asyncio Socket.
        async def test():
            ctx = zmq.Context()
            s = ctx.socket(zmq.PULL)
            # Track the raw socket the same way test_poll_base_socket does;
            # presumably the base-class teardown closes everything listed in
            # self.sockets -- verify against BaseZMQTestCase.
            self.sockets.append(s)
            async_s = zaio.Socket(s)
            assert isinstance(async_s, zaio.Socket)

        # The coroutine was previously defined but never executed, so the
        # test passed without checking anything.
        self.loop.run_until_complete(test())

    def test_process_teardown(self):
        # A child process that leaves its context/socket/loop open must
        # still exit cleanly (exit code 0) within the join timeout.
        event_loop_policy_class = type(asyncio.get_event_loop_policy())
        proc = ProcessForTeardownTest(event_loop_policy_class)
        proc.start()
        try:
            proc.join(10)  # starting new Python process may cost a lot
            # exitcode is None when the child is still running after join()
            if proc.exitcode:
                failure = "Python process died with code %d" % proc.exitcode
            else:
                failure = "process teardown hangs"
            self.assertEqual(proc.exitcode, 0, failure)
        finally:
            proc.terminate()


class TestAsyncioAuthentication(TestThreadAuthentication):
    """Test authentication running in a asyncio task"""

    Context = zaio.Context

    def shortDescription(self):
        """Rewrite doc strings from TestThreadAuthentication from
        'threaded' to 'asyncio'.
        """
        doc = self._testMethodDoc
        if doc:
            doc = doc.split("\n")[0].strip()
            if doc.startswith("threaded auth"):
                doc = doc.replace("threaded auth", "asyncio auth")
        return doc

    def setUp(self):
        # A fresh loop must be current before the base-class setUp runs.
        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)
        super().setUp()

    def tearDown(self):
        super().tearDown()
        self.loop.close()

    def make_auth(self):
        # Use the asyncio authenticator instead of the one the base class
        # provides.
        return AsyncioAuthenticator(self.context)

    def can_connect(self, server, client):
        """Check if client can connect to server using tcp transport"""

        async def go():
            iface = "tcp://127.0.0.1"
            port = server.bind_to_random_port(iface)
            client.connect("%s:%i" % (iface, port))
            msg = [b"Hello World"]

            # set timeouts
            server.SNDTIMEO = client.RCVTIMEO = 1000
            try:
                await server.send_multipart(msg)
            except zmq.Again:
                return False
            try:
                rcvd_msg = await client.recv_multipart()
            except zmq.Again:
                return False
            assert rcvd_msg == msg
            return True

        return self.loop.run_until_complete(go())

    def _select_recv(self, multipart, socket, **kwargs):
        # Receive one message (multipart or single frame) from an asyncio
        # socket, raising TimeoutError if nothing is readable within 5000 ms.
        recv = socket.recv_multipart if multipart else socket.recv

        async def coro():
            if not await socket.poll(5000):
                raise TimeoutError("Should have received a message")
            return await recv(**kwargs)

        return self.loop.run_until_complete(coro())