1# Copyright (C) PyZMQ Developers
2# Distributed under the terms of the Modified BSD License.
3
4import copy
5import gc
6import os
7import sys
8import time
9from threading import Thread, Event
10from queue import Queue
11from unittest import mock
12
13from pytest import mark
14
15import zmq
16from zmq.tests import (
17    BaseZMQTestCase,
18    have_gevent,
19    GreenTest,
20    skip_green,
21    PYPY,
22    SkipTest,
23)
24
25
class KwargTestSocket(zmq.Socket):
    """Socket subclass that captures an extra ``test_kwarg`` constructor keyword.

    The keyword is popped out of ``kwargs`` before the remaining arguments are
    forwarded to :class:`zmq.Socket`, so the base class never receives it.
    """

    # Declared on the class so the instance assignment in __init__ is a plain
    # attribute; holds the ``test_kwarg`` value (None when it was not passed).
    test_kwarg_value = None

    def __init__(self, *args, **kwargs):
        captured = kwargs.pop('test_kwarg', None)
        # Stored before the base initializer runs, exactly as before.
        self.test_kwarg_value = captured
        super().__init__(*args, **kwargs)
32
33
class KwargTestContext(zmq.Context):
    """Context whose sockets are :class:`KwargTestSocket` instances.

    Used by the kwargs test to observe keyword arguments that ``ctx.socket()``
    forwards to the socket class.
    """

    # Socket class this context builds in ctx.socket().
    _socket_class = KwargTestSocket
36
37
class TestContext(BaseZMQTestCase):
    """Tests for ``zmq.Context``.

    Covers construction, repr, termination/destroy, the shared ``instance()``,
    context-level socket options, copying/shadowing and fork handling.
    """

    # Class name expected inside repr(ctx); subclasses may override it.
    _repr_cls = "zmq.Context"

    def test_init(self):
        """Several contexts can be created and discarded one after another."""
        for _ in range(3):
            ctx = self.Context()
            self.assertIsInstance(ctx, self.Context)
            del ctx

    def test_repr(self):
        """repr() shows the live socket count and whether the context is closed."""
        with self.Context() as ctx:
            assert f'{self._repr_cls}()' in repr(ctx)
            assert 'closed' not in repr(ctx)
            with ctx.socket(zmq.PUSH) as push:
                assert f'{self._repr_cls}(1 socket)' in repr(ctx)
                with ctx.socket(zmq.PULL) as pull:
                    assert f'{self._repr_cls}(2 sockets)' in repr(ctx)
        assert f'{self._repr_cls}()' in repr(ctx)
        assert 'closed' in repr(ctx)

    def test_dir(self):
        """dir(ctx) lists 'socket' and, on libzmq > 3, 'IO_THREADS'."""
        # The context manager terminates ctx even if an assertion fails.
        with self.Context() as ctx:
            self.assertIn('socket', dir(ctx))
            if zmq.zmq_version_info() > (3,):
                self.assertIn('IO_THREADS', dir(ctx))

    @mark.skipif(mock is None, reason="requires unittest.mock")
    def test_mockable(self):
        """A Context can be used as the ``spec`` of a Mock."""
        mock.Mock(spec=self.context)

    def test_term(self):
        """term() marks the context as closed."""
        c = self.Context()
        c.term()
        self.assertTrue(c.closed)

    def test_context_manager(self):
        """Leaving a ``with`` block closes the context."""
        with self.Context() as c:
            pass
        self.assertTrue(c.closed)

    def test_fail_init(self):
        """Constructing a Context with -1 raises EINVAL."""
        self.assertRaisesErrno(zmq.EINVAL, self.Context, -1)

    def test_term_hang(self):
        """term() must not hang after a send on a socket closed with LINGER=0."""
        router, dealer = self.create_bound_pair(zmq.ROUTER, zmq.DEALER)
        dealer.setsockopt(zmq.LINGER, 0)
        dealer.send(b'hello', copy=False)
        dealer.close()
        router.close()
        self.context.term()

    def test_instance(self):
        """instance() returns one shared context until that context is terminated."""
        ctx = self.Context.instance()
        c2 = self.Context.instance(io_threads=2)
        self.assertTrue(c2 is ctx)
        c2.term()
        # Once terminated, instance() must hand out a fresh, open context.
        c3 = self.Context.instance()
        c4 = self.Context.instance()
        self.assertFalse(c3 is c2)
        self.assertFalse(c3.closed)
        self.assertTrue(c3 is c4)

    def test_instance_subclass_first(self):
        """A subclass instance created first does not change zmq.Context.instance()."""
        self.context.term()

        class SubContext(zmq.Context):
            pass

        sctx = SubContext.instance()
        ctx = zmq.Context.instance()
        ctx.term()
        sctx.term()
        assert type(ctx) is zmq.Context
        assert type(sctx) is SubContext

    def test_instance_subclass_second(self):
        """Subclasses share zmq.Context's instance unless they reset ``_instance``."""
        self.context.term()

        class SubContextInherit(zmq.Context):
            pass

        class SubContextNoInherit(zmq.Context):
            _instance = None

        ctx = zmq.Context.instance()
        sctx = SubContextInherit.instance()
        sctx2 = SubContextNoInherit.instance()
        ctx.term()
        sctx.term()
        sctx2.term()
        assert type(ctx) is zmq.Context
        assert type(sctx) is zmq.Context
        assert type(sctx2) is SubContextNoInherit

    def test_instance_threadsafe(self):
        """Concurrent instance() calls from many threads all get the same context."""
        self.context.term()  # clear default context

        q = Queue()

        # slow context initialization,
        # to ensure that we are both trying to create one at the same time
        class SlowContext(self.Context):
            def __init__(self, *a, **kw):
                time.sleep(1)
                super().__init__(*a, **kw)

        def f():
            q.put(SlowContext.instance())

        # call ctx.instance() in several threads at once
        N = 16
        threads = [Thread(target=f) for _ in range(N)]
        for t in threads:
            t.start()
        # also call it in the main thread (not first)
        ctx = SlowContext.instance()
        assert isinstance(ctx, SlowContext)
        # check that all the threads got the same context
        for _ in range(N):
            thread_ctx = q.get(timeout=5)
            assert thread_ctx is ctx
        # cleanup
        ctx.term()
        for t in threads:
            t.join(timeout=5)

    def test_socket_passes_kwargs(self):
        """Extra keyword arguments to ctx.socket() reach the socket class."""
        test_kwarg_value = 'testing one two three'
        with KwargTestContext() as ctx:
            with ctx.socket(zmq.DEALER, test_kwarg=test_kwarg_value) as sock:
                self.assertIs(sock.test_kwarg_value, test_kwarg_value)

    def test_many_sockets(self):
        """opening and closing many sockets shouldn't cause problems"""
        ctx = self.Context()
        for _ in range(16):
            sockets = [ctx.socket(zmq.REP) for _ in range(65)]
            for s in sockets:
                s.close()
            # give the reaper a chance
            time.sleep(1e-2)
        ctx.term()

    def test_sockopts(self):
        """setting socket options with ctx attributes"""
        ctx = self.Context()
        ctx.linger = 5
        self.assertEqual(ctx.linger, 5)
        s = ctx.socket(zmq.REQ)
        self.assertEqual(s.linger, 5)
        self.assertEqual(s.getsockopt(zmq.LINGER), 5)
        s.close()
        # check that subscribe doesn't get set on sockets that don't subscribe:
        ctx.subscribe = b''
        s = ctx.socket(zmq.REQ)
        s.close()

        ctx.term()

    @mark.skipif(sys.platform.startswith('win'), reason='Segfaults on Windows')
    def test_destroy(self):
        """Context.destroy should close sockets"""
        ctx = self.Context()
        sockets = [ctx.socket(zmq.REP) for _ in range(65)]

        # close half of the sockets
        for s in sockets[::2]:
            s.close()

        ctx.destroy()
        # reaper is not instantaneous
        time.sleep(1e-2)
        for s in sockets:
            self.assertTrue(s.closed)

    def test_destroy_linger(self):
        """Context.destroy should set linger on closing sockets"""
        req, rep = self.create_bound_pair(zmq.REQ, zmq.REP)
        req.send(b'hi')
        time.sleep(1e-2)
        self.context.destroy(linger=0)
        # reaper is not instantaneous
        time.sleep(1e-2)
        for s in (req, rep):
            self.assertTrue(s.closed)

    def test_term_noclose(self):
        """Context.term won't close sockets"""
        ctx = self.Context()
        s = ctx.socket(zmq.REQ)
        self.assertFalse(s.closed)
        t = Thread(target=ctx.term)
        t.start()
        t.join(timeout=0.1)
        self.assertTrue(t.is_alive(), "Context should be waiting")
        s.close()
        t.join(timeout=0.1)
        self.assertFalse(t.is_alive(), "Context should have closed")

    def test_gc(self):
        """test close&term by garbage collection alone"""
        if PYPY:
            raise SkipTest("GC doesn't work ")

        # test credit @dln (GH #137):
        def gcf():
            def inner():
                ctx = self.Context()
                s = ctx.socket(zmq.PUSH)

            inner()
            gc.collect()

        t = Thread(target=gcf)
        t.start()
        t.join(timeout=1)
        self.assertFalse(
            t.is_alive(), "Garbage collection should have cleaned up context"
        )

    def test_cyclic_destroy(self):
        """ctx.destroy should succeed when cyclic ref prevents gc"""

        # test credit @dln (GH #137):
        class CyclicReference:
            def __init__(self, parent=None):
                self.parent = parent

            def crash(self, sock):
                self.sock = sock
                self.child = CyclicReference(self)

        def crash_zmq():
            ctx = self.Context()
            sock = ctx.socket(zmq.PULL)
            c = CyclicReference()
            c.crash(sock)
            ctx.destroy()

        crash_zmq()

    def test_term_thread(self):
        """ctx.term should not crash active threads (#139)"""
        ctx = self.Context()
        evt = Event()
        # An assertion raised inside a worker thread is not reported as a test
        # failure, so the worker records the errno it saw (None if recv()
        # returned normally) and the main thread asserts on it.
        recv_errnos = []

        def block():
            s = ctx.socket(zmq.REP)
            s.bind_to_random_port('tcp://127.0.0.1')
            evt.set()
            try:
                s.recv()
            except zmq.ZMQError as e:
                recv_errnos.append(e.errno)
            else:
                recv_errnos.append(None)
            finally:
                # Closing the socket lets the blocked ctx.term() complete.
                s.close()

        t = Thread(target=block)
        t.start()

        evt.wait(1)
        self.assertTrue(evt.is_set(), "sync event never fired")
        time.sleep(0.01)
        ctx.term()
        t.join(timeout=1)
        self.assertFalse(t.is_alive(), "term should have interrupted s.recv()")
        self.assertEqual(
            recv_errnos, [zmq.ETERM], "recv should have been interrupted with ETERM"
        )

    def test_destroy_no_sockets(self):
        """destroy() works when every socket has already been closed."""
        ctx = self.Context()
        s = ctx.socket(zmq.PUB)
        s.bind_to_random_port('tcp://127.0.0.1')
        s.close()
        ctx.destroy()
        assert s.closed
        assert ctx.closed

    def test_ctx_opts(self):
        """Context options can be read and written via get/set and attributes."""
        if zmq.zmq_version_info() < (3,):
            raise SkipTest("context options require libzmq 3")
        # The context manager terminates ctx, which the test previously leaked.
        with self.Context() as ctx:
            ctx.set(zmq.MAX_SOCKETS, 2)
            self.assertEqual(ctx.get(zmq.MAX_SOCKETS), 2)
            ctx.max_sockets = 100
            self.assertEqual(ctx.max_sockets, 100)
            self.assertEqual(ctx.get(zmq.MAX_SOCKETS), 100)

    def test_copy(self):
        """copy/deepcopy yield shadow contexts sharing the same underlying handle."""
        c1 = self.Context()
        c2 = copy.copy(c1)
        c2b = copy.deepcopy(c1)
        c3 = copy.deepcopy(c2)
        self.assertTrue(c2._shadow)
        self.assertTrue(c3._shadow)
        self.assertEqual(c1.underlying, c2.underlying)
        self.assertEqual(c1.underlying, c3.underlying)
        self.assertEqual(c1.underlying, c2b.underlying)
        s = c3.socket(zmq.PUB)
        s.close()
        c1.term()

    def test_shadow(self):
        """A shadow shares the underlying context but does not own it.

        Dropping the shadow leaves the original open. Terminating the original
        makes the shadow fail with EFAULT.
        """
        ctx = self.Context()
        ctx2 = self.Context.shadow(ctx.underlying)
        self.assertEqual(ctx.underlying, ctx2.underlying)
        s = ctx.socket(zmq.PUB)
        s.close()
        del ctx2
        self.assertFalse(ctx.closed)
        s = ctx.socket(zmq.PUB)
        ctx2 = self.Context.shadow(ctx.underlying)
        s2 = ctx2.socket(zmq.PUB)
        s.close()
        s2.close()
        ctx.term()
        self.assertRaisesErrno(zmq.EFAULT, ctx2.socket, zmq.PUB)
        del ctx2

    def test_shadow_pyczmq(self):
        """A shadow of a pyczmq context can talk to pyczmq sockets over inproc."""
        try:
            from pyczmq import zctx, zsocket, zstr
        except Exception:
            raise SkipTest("Requires pyczmq")

        ctx = zctx.new()
        a = zsocket.new(ctx, zmq.PUSH)
        zsocket.bind(a, "inproc://a")
        ctx2 = self.Context.shadow_pyczmq(ctx)
        b = ctx2.socket(zmq.PULL)
        b.connect("inproc://a")
        zstr.send(a, b'hi')
        rcvd = self.recv(b)
        self.assertEqual(rcvd, b'hi')
        b.close()

    @mark.skipif(sys.platform.startswith('win'), reason='No fork on Windows')
    def test_fork_instance(self):
        """Context.instance() in a forked child is a different object.

        The child sends the ``id()`` of its instance back through a pipe. The
        parent checks that it differs from the id of its own instance.
        """
        ctx = self.Context.instance()
        parent_ctx_id = id(ctx)
        r_fd, w_fd = os.pipe()
        reader = os.fdopen(r_fd, 'r')
        child_pid = os.fork()
        if child_pid == 0:
            # Child: it must never return into the test runner (that would
            # re-run the rest of the suite in a second process), so os._exit()
            # runs even if something below raises. A non-zero status reports
            # the failure to the parent.
            status = 1
            try:
                ctx = self.Context.instance()
                child_ctx_id = id(ctx)
                ctx.term()
                with os.fdopen(w_fd, 'w') as writer:
                    writer.write(str(child_ctx_id) + "\n")
                status = 0
            finally:
                os._exit(status)

        # Parent only from here on: the child branch always calls os._exit().
        os.close(w_fd)
        child_id_s = reader.readline()
        reader.close()
        # Reap the child (no zombie) and check that it exited cleanly.
        _, wait_status = os.waitpid(child_pid, 0)
        assert os.WIFEXITED(wait_status), wait_status
        assert os.WEXITSTATUS(wait_status) == 0, wait_status
        assert child_id_s
        assert int(child_id_s) != parent_ctx_id
        ctx.term()
400
401
# The gevent ("green") variant of the context tests is switched off: this
# guard is never true, so the class below is never defined or collected.
if False:

    class TestContextGreen(GreenTest, TestContext):
        """gevent flavour of the TestContext suite."""

        # Class name expected inside repr(ctx) for the green context.
        _repr_cls = "zmq.green.Context"

        # Tests that use real threads are skipped in the green variant.
        test_destroy_linger = GreenTest.skip_green
        test_term_thread = GreenTest.skip_green
        test_gc = GreenTest.skip_green
412