1"""Test asyncio support"""
2# Copyright (c) PyZMQ Developers
3# Distributed under the terms of the Modified BSD License.
4
5import json
6from multiprocessing import Process
7import os
8import sys
9
10import pytest
11from pytest import mark
12
13import zmq
14from zmq.utils.strtypes import u
15
16import asyncio
17import zmq.asyncio as zaio
18from zmq.auth.asyncio import AsyncioAuthenticator
19
20from concurrent.futures import CancelledError
21from zmq.tests import BaseZMQTestCase
22from zmq.tests.test_auth import TestThreadAuthentication
23
24
class ProcessForTeardownTest(Process):
    """Child process that leaves its asyncio zmq resources open on exit.

    ``run`` creates an asyncio zmq context, a bound PAIR socket and an event
    loop, runs a receive that times out after one second, and then returns
    without closing any of them. The parent test checks that the process
    still exits with code 0 instead of hanging at interpreter teardown.

    Args:
        event_loop_policy_class: asyncio event loop policy class. It is
            instantiated and installed in the child before anything else.
    """

    def __init__(self, event_loop_policy_class):
        super().__init__()
        self.event_loop_policy_class = event_loop_policy_class

    def run(self):
        """Leave context, socket and event loop upon implicit disposal"""
        asyncio.set_event_loop_policy(self.event_loop_policy_class())
        # Create and install the loop explicitly. Letting get_event_loop()
        # create one implicitly is deprecated since Python 3.12 and raises
        # RuntimeError since 3.14. The loop is still never closed.
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)

        actx = zaio.Context.instance()
        socket = actx.socket(zmq.PAIR)
        socket.bind_to_random_port("tcp://127.0.0.1")

        async def never_ending_task(socket):
            await socket.recv()  # never ever receive anything

        coro = asyncio.wait_for(never_ending_task(socket), timeout=1)
        try:
            loop.run_until_complete(coro)
        except asyncio.TimeoutError:
            pass  # expected timeout
        else:
            # explicit raise (not `assert False`) so it survives `python -O`
            raise AssertionError("never_ending_task was completed unexpectedly")
        # Deliberately no socket.close() / actx.term() / loop.close():
        # exercising implicit disposal at process exit is the point.
49
50
class TestAsyncIOSocket(BaseZMQTestCase):
    """Tests for zmq.asyncio sockets, pollers and contexts.

    Each test runs with a fresh event loop that setUp installs as the
    current loop and tearDown closes again.
    """

    Context = zaio.Context

    def setUp(self):
        """Create a fresh event loop, make it current, then run base setUp."""
        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)
        super().setUp()

    def tearDown(self):
        """Run base teardown, close the loop, and check selector cleanup."""
        try:
            super().tearDown()
        finally:
            # release the loop even if the base teardown fails
            self.loop.close()
        # verify cleanup of references to selectors
        assert zaio._selectors == {}
        if 'zmq._asyncio_selector' in sys.modules:
            assert zmq._asyncio_selector._selector_loops == set()

    def test_socket_class(self):
        """Sockets created from an asyncio Context are zaio.Socket instances."""
        s = self.context.socket(zmq.PUSH)
        assert isinstance(s, zaio.Socket)
        s.close()

    def test_instance_subclass_first(self):
        """instance() returns the class it is called on (asyncio first)."""
        actx = zmq.asyncio.Context.instance()
        ctx = zmq.Context.instance()
        ctx.term()
        actx.term()
        assert type(ctx) is zmq.Context
        assert type(actx) is zmq.asyncio.Context

    def test_instance_subclass_second(self):
        """instance() returns the class it is called on (sync first)."""
        ctx = zmq.Context.instance()
        actx = zmq.asyncio.Context.instance()
        ctx.term()
        actx.term()
        assert type(ctx) is zmq.Context
        assert type(actx) is zmq.asyncio.Context

    def test_recv_multipart(self):
        """recv_multipart returns a pending future resolved by a later send."""
        async def test():
            a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
            f = b.recv_multipart()
            assert not f.done()
            await a.send(b"hi")
            recvd = await f
            self.assertEqual(recvd, [b"hi"])

        self.loop.run_until_complete(test())

    def test_recv(self):
        """Pending recv futures receive successive frames in call order."""
        async def test():
            a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
            f1 = b.recv()
            f2 = b.recv()
            assert not f1.done()
            assert not f2.done()
            await a.send_multipart([b"hi", b"there"])
            recvd = await f2
            assert f1.done()
            self.assertEqual(f1.result(), b"hi")
            self.assertEqual(recvd, b"there")

        self.loop.run_until_complete(test())

    @mark.skipif(not hasattr(zmq, "RCVTIMEO"), reason="requires RCVTIMEO")
    def test_recv_timeout(self):
        """Each recv uses the RCVTIMEO in effect when it was called."""
        async def test():
            a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
            b.rcvtimeo = 100
            f1 = b.recv()
            b.rcvtimeo = 1000
            f2 = b.recv_multipart()
            with self.assertRaises(zmq.Again):
                await f1
            await a.send_multipart([b"hi", b"there"])
            recvd = await f2
            assert f2.done()
            self.assertEqual(recvd, [b"hi", b"there"])

        self.loop.run_until_complete(test())

    @mark.skipif(not hasattr(zmq, "SNDTIMEO"), reason="requires SNDTIMEO")
    def test_send_timeout(self):
        """A send with SNDTIMEO and no peer raises zmq.Again."""
        async def test():
            s = self.socket(zmq.PUSH)
            s.sndtimeo = 100
            with self.assertRaises(zmq.Again):
                await s.send(b"not going anywhere")

        self.loop.run_until_complete(test())

    def test_recv_string(self):
        """send_string / recv_string round-trip non-ASCII text."""
        async def test():
            a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
            f = b.recv_string()
            assert not f.done()
            msg = u("πøøπ")
            await a.send_string(msg)
            recvd = await f
            assert f.done()
            self.assertEqual(f.result(), msg)
            self.assertEqual(recvd, msg)

        self.loop.run_until_complete(test())

    def test_recv_json(self):
        """send_json / recv_json round-trip a dict."""
        async def test():
            a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
            f = b.recv_json()
            assert not f.done()
            obj = dict(a=5)
            await a.send_json(obj)
            recvd = await f
            assert f.done()
            self.assertEqual(f.result(), obj)
            self.assertEqual(recvd, obj)

        self.loop.run_until_complete(test())

    def test_recv_json_cancelled(self):
        """A cancelled recv_json must not consume the message it awaited."""
        async def test():
            a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
            f = b.recv_json()
            assert not f.done()
            f.cancel()
            # cycle eventloop to allow cancel events to fire
            await asyncio.sleep(0)
            obj = dict(a=5)
            await a.send_json(obj)
            # asyncio.CancelledError is correct on every version: it aliases
            # concurrent.futures.CancelledError before 3.8 and is
            # asyncio.exceptions.CancelledError from 3.8
            # (https://bugs.python.org/issue32528)
            with pytest.raises(asyncio.CancelledError):
                await f
            assert f.done()
            # give it a chance to incorrectly consume the event
            events = await b.poll(timeout=5)
            assert events
            await asyncio.sleep(0)
            # make sure cancelled recv didn't eat up event
            f = b.recv_json()
            recvd = await asyncio.wait_for(f, timeout=5)
            assert recvd == obj

        self.loop.run_until_complete(test())

    def test_recv_pyobj(self):
        """send_pyobj / recv_pyobj round-trip a dict."""
        async def test():
            a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
            f = b.recv_pyobj()
            assert not f.done()
            obj = dict(a=5)
            await a.send_pyobj(obj)
            recvd = await f
            assert f.done()
            self.assertEqual(f.result(), obj)
            self.assertEqual(recvd, obj)

        self.loop.run_until_complete(test())

    def test_custom_serialize(self):
        """send_serialized / recv_serialized with custom (de)serializers.

        A DEALER -> ROUTER -> DEALER bounce checks that ROUTER identity
        frames survive the round trip.
        """
        def serialize(msg):
            frames = []
            frames.extend(msg.get("identities", []))
            content = json.dumps(msg["content"]).encode("utf8")
            frames.append(content)
            return frames

        def deserialize(frames):
            identities = frames[:-1]
            content = json.loads(frames[-1].decode("utf8"))
            return {
                "identities": identities,
                "content": content,
            }

        async def test():
            a, b = self.create_bound_pair(zmq.DEALER, zmq.ROUTER)

            msg = {
                "content": {
                    "a": 5,
                    "b": "bee",
                }
            }
            await a.send_serialized(msg, serialize)
            recvd = await b.recv_serialized(deserialize)
            assert recvd["content"] == msg["content"]
            assert recvd["identities"]
            # bounce back, tests identities
            await b.send_serialized(recvd, serialize)
            r2 = await a.recv_serialized(deserialize)
            assert r2["content"] == msg["content"]
            assert not r2["identities"]

        self.loop.run_until_complete(test())

    def test_custom_serialize_error(self):
        """Bad (de)serializer input or output surfaces as TypeError."""
        async def test():
            a, b = self.create_bound_pair(zmq.DEALER, zmq.ROUTER)

            msg = {
                "content": {
                    "a": 5,
                    "b": "bee",
                }
            }
            # a serializer that itself raises propagates its error
            with pytest.raises(TypeError):
                await a.send_serialized(object(), json.dumps)
            # json.dumps returns a str, not a list of frames: rejected on send
            with pytest.raises(TypeError):
                await a.send_serialized(msg, json.dumps)

            await a.send(b"not json")
            # the deserializer receives a list of frames; json.loads rejects it
            with pytest.raises(TypeError):
                await b.recv_serialized(json.loads)

        self.loop.run_until_complete(test())

    def test_recv_dontwait(self):
        """recv(DONTWAIT) raises Again when empty and resolves at once otherwise."""
        async def test():
            push, pull = self.create_bound_pair(zmq.PUSH, zmq.PULL)
            f = pull.recv(zmq.DONTWAIT)
            with self.assertRaises(zmq.Again):
                await f
            await push.send(b"ping")
            await pull.poll()  # ensure message will be waiting
            f = pull.recv(zmq.DONTWAIT)
            assert f.done()
            msg = await f
            self.assertEqual(msg, b"ping")

        self.loop.run_until_complete(test())

    def test_recv_cancel(self):
        """Cancelling one pending recv leaves a later pending recv intact."""
        async def test():
            a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
            f1 = b.recv()
            f2 = b.recv_multipart()
            assert f1.cancel()
            assert f1.done()
            assert not f2.done()
            await a.send_multipart([b"hi", b"there"])
            recvd = await f2
            assert f1.cancelled()
            assert f2.done()
            self.assertEqual(recvd, [b"hi", b"there"])

        self.loop.run_until_complete(test())

    def test_poll(self):
        """Socket.poll resolves to 0 on timeout and POLLIN when data arrives."""
        async def test():
            a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
            f = b.poll(timeout=0)
            await asyncio.sleep(0)
            self.assertEqual(f.result(), 0)

            f = b.poll(timeout=1)
            assert not f.done()
            evt = await f

            self.assertEqual(evt, 0)

            f = b.poll(timeout=1000)
            assert not f.done()
            await a.send_multipart([b"hi", b"there"])
            evt = await f
            self.assertEqual(evt, zmq.POLLIN)
            recvd = await b.recv_multipart()
            self.assertEqual(recvd, [b"hi", b"there"])

        self.loop.run_until_complete(test())

    def test_poll_base_socket(self):
        """zaio.Poller can poll plain (non-asyncio) zmq sockets."""
        async def test():
            ctx = zmq.Context()
            url = "inproc://test"
            a = ctx.socket(zmq.PUSH)
            b = ctx.socket(zmq.PULL)
            self.sockets.extend([a, b])
            a.bind(url)
            b.connect(url)

            poller = zaio.Poller()
            poller.register(b, zmq.POLLIN)

            f = poller.poll(timeout=1000)
            assert not f.done()
            a.send_multipart([b"hi", b"there"])
            evt = await f
            self.assertEqual(evt, [(b, zmq.POLLIN)])
            recvd = b.recv_multipart()
            self.assertEqual(recvd, [b"hi", b"there"])

        self.loop.run_until_complete(test())

    def test_poll_on_closed_socket(self):
        """Closing a socket cancels its pending poll future."""
        async def test():
            a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)

            f = b.poll(timeout=1)
            b.close()

            # The test might stall if we try to await f directly so instead just make a few
            # passes through the event loop to schedule and execute all callbacks
            for _ in range(5):
                await asyncio.sleep(0)
                if f.cancelled():
                    break
            assert f.cancelled()

        self.loop.run_until_complete(test())

    @pytest.mark.skipif(
        sys.platform.startswith("win"),
        reason="Windows does not support polling on files",
    )
    def test_poll_raw(self):
        """zaio.Poller reports POLLOUT/POLLIN on raw file objects (a pipe)."""
        async def test():
            p = zaio.Poller()
            # make a pipe
            r, w = os.pipe()
            r = os.fdopen(r, "rb")
            w = os.fdopen(w, "wb")
            try:
                # POLLOUT
                p.register(r, zmq.POLLIN)
                p.register(w, zmq.POLLOUT)
                evts = await p.poll(timeout=1)
                evts = dict(evts)
                assert r.fileno() not in evts
                assert w.fileno() in evts
                assert evts[w.fileno()] == zmq.POLLOUT

                # POLLIN
                p.unregister(w)
                w.write(b"x")
                w.flush()
                evts = await p.poll(timeout=1000)
                evts = dict(evts)
                assert r.fileno() in evts
                assert evts[r.fileno()] == zmq.POLLIN
                assert r.read(1) == b"x"
            finally:
                # close the pipe even when an assertion above fails
                r.close()
                w.close()

        # use this test's own loop, like every other test in the class
        self.loop.run_until_complete(test())

    def test_multiple_loops(self):
        """One socket pair can be used from several successive event loops."""
        a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)

        async def test():
            await a.send(b'buf')
            msg = await b.recv()
            assert msg == b'buf'

        for _ in range(3):
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            try:
                loop.run_until_complete(asyncio.wait_for(test(), timeout=10))
            finally:
                # don't leak the loop if the round trip fails or times out
                loop.close()

    def test_shadow(self):
        """zaio.Socket can wrap (shadow) an existing sync socket."""
        async def test():
            ctx = zmq.Context()
            s = ctx.socket(zmq.PULL)
            try:
                async_s = zaio.Socket(s)
                assert isinstance(async_s, zaio.Socket)
                # the wrapper shares the underlying libzmq socket
                assert async_s.underlying == s.underlying
            finally:
                # Close only the owning sync socket. Closing the shadow as
                # well would close the same libzmq handle a second time.
                s.close(linger=0)
                ctx.term()

        # previously the coroutine was defined but never run
        self.loop.run_until_complete(test())

    def test_process_teardown(self):
        """A child leaving asyncio zmq resources open still exits cleanly."""
        event_loop_policy_class = type(asyncio.get_event_loop_policy())
        proc = ProcessForTeardownTest(event_loop_policy_class)
        proc.start()
        try:
            proc.join(10)  # starting new Python process may cost a lot
            self.assertEqual(
                proc.exitcode,
                0,
                "Python process died with code %d" % proc.exitcode
                if proc.exitcode
                else "process teardown hangs",
            )
        finally:
            proc.terminate()
434
435
436class TestAsyncioAuthentication(TestThreadAuthentication):
437    """Test authentication running in a asyncio task"""
438
439    Context = zaio.Context
440
441    def shortDescription(self):
442        """Rewrite doc strings from TestThreadAuthentication from
443        'threaded' to 'asyncio'.
444        """
445        doc = self._testMethodDoc
446        if doc:
447            doc = doc.split("\n")[0].strip()
448            if doc.startswith("threaded auth"):
449                doc = doc.replace("threaded auth", "asyncio auth")
450        return doc
451
452    def setUp(self):
453        self.loop = asyncio.new_event_loop()
454        asyncio.set_event_loop(self.loop)
455        super().setUp()
456
457    def tearDown(self):
458        super().tearDown()
459        self.loop.close()
460
461    def make_auth(self):
462        return AsyncioAuthenticator(self.context)
463
464    def can_connect(self, server, client):
465        """Check if client can connect to server using tcp transport"""
466
467        async def go():
468            result = False
469            iface = "tcp://127.0.0.1"
470            port = server.bind_to_random_port(iface)
471            client.connect("%s:%i" % (iface, port))
472            msg = [b"Hello World"]
473
474            # set timeouts
475            server.SNDTIMEO = client.RCVTIMEO = 1000
476            try:
477                await server.send_multipart(msg)
478            except zmq.Again:
479                return False
480            try:
481                rcvd_msg = await client.recv_multipart()
482            except zmq.Again:
483                return False
484            else:
485                assert rcvd_msg == msg
486                result = True
487            return result
488
489        return self.loop.run_until_complete(go())
490
491    def _select_recv(self, multipart, socket, **kwargs):
492        recv = socket.recv_multipart if multipart else socket.recv
493
494        async def coro():
495            if not await socket.poll(5000):
496                raise TimeoutError("Should have received a message")
497            return await recv(**kwargs)
498
499        return self.loop.run_until_complete(coro())
500