# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.

import copy
import gc
import os
import sys
import time
from threading import Thread, Event
from queue import Queue
from unittest import mock

from pytest import mark

import zmq
from zmq.tests import (
    BaseZMQTestCase,
    have_gevent,
    GreenTest,
    skip_green,
    PYPY,
    SkipTest,
)


class KwargTestSocket(zmq.Socket):
    """Socket that captures an extra ``test_kwarg`` passed via ``ctx.socket()``."""

    test_kwarg_value = None

    def __init__(self, *args, **kwargs):
        # remove our private kwarg before zmq.Socket.__init__ sees it
        self.test_kwarg_value = kwargs.pop('test_kwarg', None)
        super().__init__(*args, **kwargs)


class KwargTestContext(zmq.Context):
    """Context whose ``socket()`` builds :class:`KwargTestSocket` instances."""

    _socket_class = KwargTestSocket


class TestContext(BaseZMQTestCase):
    """Tests for ``self.Context`` (``zmq.Context`` unless a subclass overrides it)."""

    # class name expected inside repr(ctx); subclasses may override it
    _repr_cls = "zmq.Context"

    def test_init(self):
        for _ in range(3):
            ctx = self.Context()
            self.assertIsInstance(ctx, self.Context)
            del ctx

    def test_repr(self):
        with self.Context() as ctx:
            assert f'{self._repr_cls}()' in repr(ctx)
            assert 'closed' not in repr(ctx)
            with ctx.socket(zmq.PUSH):
                assert f'{self._repr_cls}(1 socket)' in repr(ctx)
                with ctx.socket(zmq.PULL):
                    assert f'{self._repr_cls}(2 sockets)' in repr(ctx)
        assert f'{self._repr_cls}()' in repr(ctx)
        assert 'closed' in repr(ctx)

    def test_dir(self):
        ctx = self.Context()
        try:
            self.assertIn('socket', dir(ctx))
            if zmq.zmq_version_info() > (3,):
                self.assertIn('IO_THREADS', dir(ctx))
        finally:
            # terminate even if an assertion fails, so the context is not leaked
            ctx.term()

    @mark.skipif(mock is None, reason="requires unittest.mock")
    def test_mockable(self):
        # creating a spec'd Mock from a live context must not raise
        mock.Mock(spec=self.context)

    def test_term(self):
        c = self.Context()
        c.term()
        self.assertTrue(c.closed)

    def test_context_manager(self):
        with self.Context() as c:
            pass
        self.assertTrue(c.closed)

    def test_fail_init(self):
        self.assertRaisesErrno(zmq.EINVAL, self.Context, -1)

    def test_term_hang(self):
        rep, req = self.create_bound_pair(zmq.ROUTER, zmq.DEALER)
        req.setsockopt(zmq.LINGER, 0)
        req.send(b'hello', copy=False)
        req.close()
        rep.close()
        self.context.term()

    def test_instance(self):
        ctx = self.Context.instance()
        c2 = self.Context.instance(io_threads=2)
        self.assertIs(c2, ctx)
        c2.term()
        # once the singleton is terminated, instance() must hand out a fresh one
        c3 = self.Context.instance()
        c4 = self.Context.instance()
        self.assertIsNot(c3, c2)
        self.assertFalse(c3.closed)
        self.assertIs(c3, c4)
        # don't leave the freshly created global instance open
        c3.term()

    def test_instance_subclass_first(self):
        self.context.term()

        class SubContext(zmq.Context):
            pass

        sctx = SubContext.instance()
        ctx = zmq.Context.instance()
        ctx.term()
        sctx.term()
        assert type(ctx) is zmq.Context
        assert type(sctx) is SubContext

    def test_instance_subclass_second(self):
        self.context.term()

        class SubContextInherit(zmq.Context):
            pass

        class SubContextNoInherit(zmq.Context):
            # own _instance slot: do not share zmq.Context's singleton
            _instance = None

        ctx = zmq.Context.instance()
        sctx = SubContextInherit.instance()
        sctx2 = SubContextNoInherit.instance()
        ctx.term()
        sctx.term()
        sctx2.term()
        assert type(ctx) is zmq.Context
        assert type(sctx) is zmq.Context
        assert type(sctx2) is SubContextNoInherit

    def test_instance_threadsafe(self):
        self.context.term()  # clear default context

        q = Queue()

        # slow context initialization,
        # to ensure that we are both trying to create one at the same time
        class SlowContext(self.Context):
            def __init__(self, *a, **kw):
                time.sleep(1)
                super().__init__(*a, **kw)

        def f():
            q.put(SlowContext.instance())

        # call ctx.instance() in several threads at once
        N = 16
        threads = [Thread(target=f) for _ in range(N)]
        for t in threads:
            t.start()
        # also call it in the main thread (not first)
        ctx = SlowContext.instance()
        try:
            assert isinstance(ctx, SlowContext)
            # check that all the threads got the same context
            for _ in range(N):
                thread_ctx = q.get(timeout=5)
                assert thread_ctx is ctx
        finally:
            # cleanup runs even on failure, so no context/threads are leaked
            ctx.term()
            for t in threads:
                t.join(timeout=5)

    def test_socket_passes_kwargs(self):
        test_kwarg_value = 'testing one two three'
        with KwargTestContext() as ctx:
            with ctx.socket(zmq.DEALER, test_kwarg=test_kwarg_value) as sock:
                self.assertIs(sock.test_kwarg_value, test_kwarg_value)

    def test_many_sockets(self):
        """opening and closing many sockets shouldn't cause problems"""
        ctx = self.Context()
        for _ in range(16):
            sockets = [ctx.socket(zmq.REP) for _ in range(65)]
            for s in sockets:
                s.close()
            # give the reaper a chance
            time.sleep(1e-2)
        ctx.term()

    def test_sockopts(self):
        """setting socket options with ctx attributes"""
        ctx = self.Context()
        ctx.linger = 5
        self.assertEqual(ctx.linger, 5)
        s = ctx.socket(zmq.REQ)
        self.assertEqual(s.linger, 5)
        self.assertEqual(s.getsockopt(zmq.LINGER), 5)
        s.close()
        # check that subscribe doesn't get set on sockets that don't subscribe:
        ctx.subscribe = b''
        s = ctx.socket(zmq.REQ)
        s.close()

        ctx.term()

    @mark.skipif(sys.platform.startswith('win'), reason='Segfaults on Windows')
    def test_destroy(self):
        """Context.destroy should close sockets"""
        ctx = self.Context()
        sockets = [ctx.socket(zmq.REP) for _ in range(65)]

        # close half of the sockets
        for s in sockets[::2]:
            s.close()

        ctx.destroy()
        # reaper is not instantaneous
        time.sleep(1e-2)
        for s in sockets:
            self.assertTrue(s.closed)

    def test_destroy_linger(self):
        """Context.destroy should set linger on closing sockets"""
        req, rep = self.create_bound_pair(zmq.REQ, zmq.REP)
        req.send(b'hi')
        time.sleep(1e-2)
        self.context.destroy(linger=0)
        # reaper is not instantaneous
        time.sleep(1e-2)
        for s in (req, rep):
            self.assertTrue(s.closed)

    def test_term_noclose(self):
        """Context.term won't close sockets"""
        ctx = self.Context()
        s = ctx.socket(zmq.REQ)
        self.assertFalse(s.closed)
        t = Thread(target=ctx.term)
        t.start()
        t.join(timeout=0.1)
        # sample liveness first, then close the socket before asserting:
        # a failed check must not leave ctx.term() blocked forever
        was_waiting = t.is_alive()
        s.close()
        t.join(timeout=0.1)
        self.assertTrue(was_waiting, "Context should be waiting")
        self.assertFalse(t.is_alive(), "Context should have closed")

    def test_gc(self):
        """test close&term by garbage collection alone"""
        if PYPY:
            raise SkipTest("GC doesn't work ")

        # test credit @dln (GH #137):
        def gcf():
            def inner():
                ctx = self.Context()
                # deliberately left open: gc alone must clean it up
                s = ctx.socket(zmq.PUSH)

            inner()
            gc.collect()

        t = Thread(target=gcf)
        t.start()
        t.join(timeout=1)
        self.assertFalse(
            t.is_alive(), "Garbage collection should have cleaned up context"
        )

    def test_cyclic_destroy(self):
        """ctx.destroy should succeed when cyclic ref prevents gc"""

        # test credit @dln (GH #137):
        class CyclicReference:
            def __init__(self, parent=None):
                self.parent = parent

            def crash(self, sock):
                self.sock = sock
                self.child = CyclicReference(self)

        def crash_zmq():
            ctx = self.Context()
            sock = ctx.socket(zmq.PULL)
            c = CyclicReference()
            c.crash(sock)
            ctx.destroy()

        crash_zmq()

    def test_term_thread(self):
        """ctx.term should not crash active threads (#139)"""
        ctx = self.Context()
        evt = Event()
        evt.clear()
        # assertion errors raised inside the worker thread would not fail the
        # test, so the worker records problems here and the main thread checks
        errors = []

        def block():
            s = ctx.socket(zmq.REP)
            s.bind_to_random_port('tcp://127.0.0.1')
            evt.set()
            try:
                s.recv()
            except zmq.ZMQError as e:
                if e.errno != zmq.ETERM:
                    errors.append(f"expected ETERM, got errno {e.errno}: {e}")
                return
            finally:
                s.close()
            errors.append("recv should have been interrupted with ETERM")

        t = Thread(target=block)
        t.start()

        evt.wait(1)
        self.assertTrue(evt.is_set(), "sync event never fired")
        time.sleep(0.01)
        ctx.term()
        t.join(timeout=1)
        self.assertFalse(t.is_alive(), "term should have interrupted s.recv()")
        self.assertEqual(errors, [])

    def test_destroy_no_sockets(self):
        ctx = self.Context()
        s = ctx.socket(zmq.PUB)
        s.bind_to_random_port('tcp://127.0.0.1')
        s.close()
        ctx.destroy()
        assert s.closed
        assert ctx.closed

    def test_ctx_opts(self):
        if zmq.zmq_version_info() < (3,):
            raise SkipTest("context options require libzmq 3")
        ctx = self.Context()
        try:
            ctx.set(zmq.MAX_SOCKETS, 2)
            self.assertEqual(ctx.get(zmq.MAX_SOCKETS), 2)
            ctx.max_sockets = 100
            self.assertEqual(ctx.max_sockets, 100)
            self.assertEqual(ctx.get(zmq.MAX_SOCKETS), 100)
        finally:
            # previously the context was never terminated
            ctx.term()

    def test_copy(self):
        c1 = self.Context()
        c2 = copy.copy(c1)
        c2b = copy.deepcopy(c1)
        c3 = copy.deepcopy(c2)
        self.assertTrue(c2._shadow)
        self.assertTrue(c3._shadow)
        self.assertEqual(c1.underlying, c2.underlying)
        self.assertEqual(c1.underlying, c3.underlying)
        self.assertEqual(c1.underlying, c2b.underlying)
        s = c3.socket(zmq.PUB)
        s.close()
        c1.term()

    def test_shadow(self):
        ctx = self.Context()
        ctx2 = self.Context.shadow(ctx.underlying)
        self.assertEqual(ctx.underlying, ctx2.underlying)
        s = ctx.socket(zmq.PUB)
        s.close()
        # dropping a shadow must not close the context it shadows
        del ctx2
        self.assertFalse(ctx.closed)
        s = ctx.socket(zmq.PUB)
        ctx2 = self.Context.shadow(ctx.underlying)
        s2 = ctx2.socket(zmq.PUB)
        s.close()
        s2.close()
        ctx.term()
        self.assertRaisesErrno(zmq.EFAULT, ctx2.socket, zmq.PUB)
        del ctx2

    def test_shadow_pyczmq(self):
        try:
            from pyczmq import zctx, zsocket, zstr
        except Exception:
            raise SkipTest("Requires pyczmq")

        ctx = zctx.new()
        a = zsocket.new(ctx, zmq.PUSH)
        zsocket.bind(a, "inproc://a")
        ctx2 = self.Context.shadow_pyczmq(ctx)
        b = ctx2.socket(zmq.PULL)
        b.connect("inproc://a")
        zstr.send(a, b'hi')
        rcvd = self.recv(b)
        self.assertEqual(rcvd, b'hi')
        b.close()

    @mark.skipif(sys.platform.startswith('win'), reason='No fork on Windows')
    def test_fork_instance(self):
        """Context.instance() in a forked child must not reuse the parent's object."""
        ctx = self.Context.instance()
        parent_ctx_id = id(ctx)
        r_fd, w_fd = os.pipe()
        child_pid = os.fork()
        if child_pid == 0:
            # child: must never return into the test runner, even on error,
            # otherwise it would go on to run the rest of the test suite
            exit_code = 1
            try:
                os.close(r_fd)
                child_ctx = self.Context.instance()
                child_ctx_id = id(child_ctx)
                child_ctx.term()
                with os.fdopen(w_fd, 'w') as writer:
                    writer.write(str(child_ctx_id) + "\n")
                exit_code = 0
            finally:
                os._exit(exit_code)

        # parent
        os.close(w_fd)
        with os.fdopen(r_fd, 'r') as reader:
            child_id_s = reader.readline()
        # reap the child so it does not linger as a zombie
        _, status = os.waitpid(child_pid, 0)
        ctx.term()
        assert os.WIFEXITED(status) and os.WEXITSTATUS(status) == 0
        assert child_id_s
        assert int(child_id_s) != parent_ctx_id


if False:  # disable green context tests

    class TestContextGreen(GreenTest, TestContext):
        """gevent subclass of context tests"""

        # skip tests that use real threads:
        test_gc = GreenTest.skip_green
        test_term_thread = GreenTest.skip_green
        test_destroy_linger = GreenTest.skip_green
        _repr_cls = "zmq.green.Context"