# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
"""Tests for zmq.Context: construction, singletons, shadowing and teardown."""

import copy
import gc
import os
import sys
import time
from queue import Queue
from threading import Event, Thread
from unittest import mock

from pytest import mark

import zmq
from zmq.tests import (
    PYPY,
    BaseZMQTestCase,
    GreenTest,
    SkipTest,
    have_gevent,
    skip_green,
)


class KwargTestSocket(zmq.Socket):
    """Socket that captures an extra ``test_kwarg`` constructor argument.

    Used to check that ``Context.socket`` forwards keyword arguments to the
    socket class it instantiates.
    """

    test_kwarg_value = None

    def __init__(self, *args, **kwargs):
        # Remove the custom kwarg before zmq.Socket.__init__ sees it.
        self.test_kwarg_value = kwargs.pop('test_kwarg', None)
        super().__init__(*args, **kwargs)


class KwargTestContext(zmq.Context):
    """Context whose ``socket()`` constructs ``KwargTestSocket`` instances."""

    _socket_class = KwargTestSocket


class TestContext(BaseZMQTestCase):
    """Behavioural tests for ``self.Context`` (``zmq.Context`` by default)."""

    def test_init(self):
        """Context() returns an instance of the context class under test."""
        for _ in range(3):
            ctx = self.Context()
            self.assertTrue(isinstance(ctx, self.Context))
            del ctx

    # Class name expected in repr(); overridden by the green subclass below.
    _repr_cls = "zmq.Context"

    def test_repr(self):
        """repr() reports the number of open sockets and the closed state."""
        with self.Context() as ctx:
            assert f'{self._repr_cls}()' in repr(ctx)
            assert 'closed' not in repr(ctx)
            with ctx.socket(zmq.PUSH):
                assert f'{self._repr_cls}(1 socket)' in repr(ctx)
                with ctx.socket(zmq.PULL):
                    assert f'{self._repr_cls}(2 sockets)' in repr(ctx)
        # Leaving the outer ``with`` closes the context.
        assert f'{self._repr_cls}()' in repr(ctx)
        assert 'closed' in repr(ctx)

    def test_dir(self):
        """dir(ctx) lists methods and, on libzmq >= 3, context option names."""
        ctx = self.Context()
        self.assertTrue('socket' in dir(ctx))
        if zmq.zmq_version_info() > (3,):
            self.assertTrue('IO_THREADS' in dir(ctx))
        ctx.term()

    def test_mockable(self):
        """A Mock can be spec'd from a live context without raising."""
        m = mock.Mock(spec=self.context)
        # The spec must expose real context attributes on the mock.
        assert hasattr(m, 'socket')

    def test_term(self):
        """term() marks the context as closed."""
        c = self.Context()
        c.term()
        self.assertTrue(c.closed)

    def test_context_manager(self):
        """Leaving a ``with`` block closes the context."""
        with self.Context() as c:
            pass
        self.assertTrue(c.closed)

    def test_fail_init(self):
        """Context(-1) is rejected with EINVAL."""
        self.assertRaisesErrno(zmq.EINVAL, self.Context, -1)

    def test_term_hang(self):
        """term() must not hang after closing a LINGER=0 socket with queued data."""
        rep, req = self.create_bound_pair(zmq.ROUTER, zmq.DEALER)
        req.setsockopt(zmq.LINGER, 0)
        req.send(b'hello', copy=False)
        req.close()
        rep.close()
        self.context.term()

    def test_instance(self):
        """instance() is a singleton until terminated, then a new one is made."""
        ctx = self.Context.instance()
        c2 = self.Context.instance(io_threads=2)
        self.assertTrue(c2 is ctx)
        c2.term()
        c3 = self.Context.instance()
        c4 = self.Context.instance()
        self.assertFalse(c3 is c2)
        self.assertFalse(c3.closed)
        self.assertTrue(c3 is c4)

    def test_instance_subclass_first(self):
        """A subclass's instance() created first does not leak to zmq.Context."""
        self.context.term()

        class SubContext(zmq.Context):
            pass

        sctx = SubContext.instance()
        ctx = zmq.Context.instance()
        ctx.term()
        sctx.term()
        assert type(ctx) is zmq.Context
        assert type(sctx) is SubContext

    def test_instance_subclass_second(self):
        """Subclasses share the parent singleton unless they reset ``_instance``."""
        self.context.term()

        class SubContextInherit(zmq.Context):
            pass

        class SubContextNoInherit(zmq.Context):
            _instance = None

        ctx = zmq.Context.instance()
        sctx = SubContextInherit.instance()
        sctx2 = SubContextNoInherit.instance()
        ctx.term()
        sctx.term()
        sctx2.term()
        assert type(ctx) is zmq.Context
        assert type(sctx) is zmq.Context
        assert type(sctx2) is SubContextNoInherit

    def test_instance_threadsafe(self):
        """Concurrent instance() calls from many threads yield one context."""
        self.context.term()  # clear default context

        q = Queue()

        # slow context initialization,
        # to ensure that we are both trying to create one at the same time
        class SlowContext(self.Context):
            def __init__(self, *a, **kw):
                time.sleep(1)
                super().__init__(*a, **kw)

        def f():
            q.put(SlowContext.instance())

        # call ctx.instance() in several threads at once
        N = 16
        threads = [Thread(target=f) for _ in range(N)]
        for t in threads:
            t.start()
        # also call it in the main thread (not first)
        ctx = SlowContext.instance()
        assert isinstance(ctx, SlowContext)
        # check that all the threads got the same context
        for _ in range(N):
            thread_ctx = q.get(timeout=5)
            assert thread_ctx is ctx
        # cleanup
        ctx.term()
        for t in threads:
            t.join(timeout=5)

    def test_socket_passes_kwargs(self):
        """Context.socket forwards extra kwargs to the socket class."""
        test_kwarg_value = 'testing one two three'
        with KwargTestContext() as ctx:
            with ctx.socket(zmq.DEALER, test_kwarg=test_kwarg_value) as socket:
                self.assertTrue(socket.test_kwarg_value is test_kwarg_value)

    def test_many_sockets(self):
        """opening and closing many sockets shouldn't cause problems"""
        ctx = self.Context()
        for _ in range(16):
            sockets = [ctx.socket(zmq.REP) for _ in range(65)]
            for s in sockets:
                s.close()
            # give the reaper a chance
            time.sleep(1e-2)
        ctx.term()

    def test_sockopts(self):
        """setting socket options with ctx attributes"""
        ctx = self.Context()
        ctx.linger = 5
        self.assertEqual(ctx.linger, 5)
        s = ctx.socket(zmq.REQ)
        self.assertEqual(s.linger, 5)
        self.assertEqual(s.getsockopt(zmq.LINGER), 5)
        s.close()
        # check that subscribe doesn't get set on sockets that don't subscribe:
        ctx.subscribe = b''
        s = ctx.socket(zmq.REQ)
        s.close()

        ctx.term()

    @mark.skipif(sys.platform.startswith('win'), reason='Segfaults on Windows')
    def test_destroy(self):
        """Context.destroy should close sockets"""
        ctx = self.Context()
        sockets = [ctx.socket(zmq.REP) for _ in range(65)]

        # close half of the sockets
        for s in sockets[::2]:
            s.close()

        ctx.destroy()
        # reaper is not instantaneous
        time.sleep(1e-2)
        for s in sockets:
            self.assertTrue(s.closed)

    def test_destroy_linger(self):
        """Context.destroy should set linger on closing sockets"""
        req, rep = self.create_bound_pair(zmq.REQ, zmq.REP)
        req.send(b'hi')
        time.sleep(1e-2)
        self.context.destroy(linger=0)
        # reaper is not instantaneous
        time.sleep(1e-2)
        for s in (req, rep):
            self.assertTrue(s.closed)

    def test_term_noclose(self):
        """Context.term won't close sockets"""
        ctx = self.Context()
        s = ctx.socket(zmq.REQ)
        self.assertFalse(s.closed)
        t = Thread(target=ctx.term)
        t.start()
        t.join(timeout=0.1)
        self.assertTrue(t.is_alive(), "Context should be waiting")
        s.close()
        t.join(timeout=0.1)
        self.assertFalse(t.is_alive(), "Context should have closed")

    def test_gc(self):
        """test close&term by garbage collection alone"""
        if PYPY:
            raise SkipTest("GC doesn't work ")

        # test credit @dln (GH #137):
        def gcf():
            def inner():
                # Deliberately unused: both become garbage when inner returns.
                ctx = self.Context()
                s = ctx.socket(zmq.PUSH)

            inner()
            gc.collect()

        t = Thread(target=gcf)
        t.start()
        t.join(timeout=1)
        self.assertFalse(
            t.is_alive(), "Garbage collection should have cleaned up context"
        )

    def test_cyclic_destroy(self):
        """ctx.destroy should succeed when cyclic ref prevents gc"""

        # test credit @dln (GH #137):
        class CyclicReference(object):
            def __init__(self, parent=None):
                self.parent = parent

            def crash(self, sock):
                self.sock = sock
                # child -> parent -> child forms the reference cycle.
                self.child = CyclicReference(self)

        def crash_zmq():
            ctx = self.Context()
            sock = ctx.socket(zmq.PULL)
            c = CyclicReference()
            c.crash(sock)
            ctx.destroy()

        crash_zmq()

    def test_term_thread(self):
        """ctx.term should not crash active threads (#139)"""
        ctx = self.Context()
        evt = Event()
        evt.clear()

        def block():
            s = ctx.socket(zmq.REP)
            s.bind_to_random_port('tcp://127.0.0.1')
            evt.set()
            try:
                s.recv()
            except zmq.ZMQError as e:
                self.assertEqual(e.errno, zmq.ETERM)
                return
            finally:
                s.close()
            self.fail("recv should have been interrupted with ETERM")

        t = Thread(target=block)
        t.start()

        evt.wait(1)
        self.assertTrue(evt.is_set(), "sync event never fired")
        time.sleep(0.01)
        ctx.term()
        t.join(timeout=1)
        self.assertFalse(t.is_alive(), "term should have interrupted s.recv()")

    def test_destroy_no_sockets(self):
        """destroy() works when every socket was already closed."""
        ctx = self.Context()
        s = ctx.socket(zmq.PUB)
        s.bind_to_random_port('tcp://127.0.0.1')
        s.close()
        ctx.destroy()
        assert s.closed
        assert ctx.closed

    def test_ctx_opts(self):
        """Context options are settable via set()/get() and attributes."""
        if zmq.zmq_version_info() < (3,):
            raise SkipTest("context options require libzmq 3")
        ctx = self.Context()
        ctx.set(zmq.MAX_SOCKETS, 2)
        self.assertEqual(ctx.get(zmq.MAX_SOCKETS), 2)
        ctx.max_sockets = 100
        self.assertEqual(ctx.max_sockets, 100)
        self.assertEqual(ctx.get(zmq.MAX_SOCKETS), 100)
        # Previously leaked: terminate the context this test created.
        ctx.term()

    def test_copy(self):
        """copy/deepcopy produce shadow contexts over the same underlying ctx."""
        c1 = self.Context()
        c2 = copy.copy(c1)
        c2b = copy.deepcopy(c1)
        c3 = copy.deepcopy(c2)
        self.assertTrue(c2._shadow)
        self.assertTrue(c3._shadow)
        self.assertEqual(c1.underlying, c2.underlying)
        self.assertEqual(c1.underlying, c3.underlying)
        self.assertEqual(c1.underlying, c2b.underlying)
        s = c3.socket(zmq.PUB)
        s.close()
        c1.term()

    def test_shadow(self):
        """A shadow shares the underlying context and does not close it on del."""
        ctx = self.Context()
        ctx2 = self.Context.shadow(ctx.underlying)
        self.assertEqual(ctx.underlying, ctx2.underlying)
        s = ctx.socket(zmq.PUB)
        s.close()
        del ctx2
        self.assertFalse(ctx.closed)
        s = ctx.socket(zmq.PUB)
        ctx2 = self.Context.shadow(ctx.underlying)
        s2 = ctx2.socket(zmq.PUB)
        s.close()
        s2.close()
        ctx.term()
        # Once the owner is terminated, the shadow can no longer make sockets.
        self.assertRaisesErrno(zmq.EFAULT, ctx2.socket, zmq.PUB)
        del ctx2

    def test_shadow_pyczmq(self):
        """A context shadowing a pyczmq zctx can talk to pyczmq sockets."""
        try:
            from pyczmq import zctx, zsocket, zstr
        except Exception:
            raise SkipTest("Requires pyczmq")

        ctx = zctx.new()
        a = zsocket.new(ctx, zmq.PUSH)
        zsocket.bind(a, "inproc://a")
        ctx2 = self.Context.shadow_pyczmq(ctx)
        b = ctx2.socket(zmq.PULL)
        b.connect("inproc://a")
        zstr.send(a, b'hi')
        rcvd = self.recv(b)
        self.assertEqual(rcvd, b'hi')
        b.close()

    @mark.skipif(sys.platform.startswith('win'), reason='No fork on Windows')
    def test_fork_instance(self):
        """Context.instance() in a forked child is a new object, not the parent's."""
        ctx = self.Context.instance()
        parent_ctx_id = id(ctx)

        r_fd, w_fd = os.pipe()
        reader = os.fdopen(r_fd, 'r')
        child_pid = os.fork()
        if child_pid == 0:
            ctx = self.Context.instance()
            writer = os.fdopen(w_fd, 'w')
            child_ctx_id = id(ctx)
            ctx.term()
            writer.write(str(child_ctx_id) + "\n")
            writer.flush()
            writer.close()
            os._exit(0)
        else:
            os.close(w_fd)

        child_id_s = reader.readline()
        reader.close()
        # Reap the child so it does not linger as a zombie process.
        os.waitpid(child_pid, 0)
        assert child_id_s
        assert int(child_id_s) != parent_ctx_id
        ctx.term()


if False:  # disable green context tests

    class TestContextGreen(GreenTest, TestContext):
        """gevent subclass of context tests"""

        # skip tests that use real threads:
        test_gc = GreenTest.skip_green
        test_term_thread = GreenTest.skip_green
        test_destroy_linger = GreenTest.skip_green
        _repr_cls = "zmq.green.Context"