1# 2# Copyright 2009 Facebook 3# 4# Licensed under the Apache License, Version 2.0 (the "License"); you may 5# not use this file except in compliance with the License. You may obtain 6# a copy of the License at 7# 8# http://www.apache.org/licenses/LICENSE-2.0 9# 10# Unless required by applicable law or agreed to in writing, software 11# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 12# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 13# License for the specific language governing permissions and limitations 14# under the License. 15 16"""A utility class to send to and recv from a non-blocking socket, 17using tornado. 18 19.. seealso:: 20 21 - :mod:`zmq.asyncio` 22 - :mod:`zmq.eventloop.future` 23 24""" 25 26from __future__ import with_statement 27import pickle 28import sys 29import warnings 30from queue import Queue 31 32import zmq 33from zmq.utils import jsonapi 34 35 36from .ioloop import IOLoop, gen_log 37 38try: 39 from tornado.stack_context import wrap as stack_context_wrap # type: ignore 40except ImportError: 41 if "zmq.eventloop.minitornado" in sys.modules: 42 from .minitornado.stack_context import wrap as stack_context_wrap # type: ignore 43 else: 44 # tornado 5 deprecates stack_context, 45 # tornado 6 removes it 46 def stack_context_wrap(callback): 47 return callback 48 49 50class ZMQStream(object): 51 """A utility class to register callbacks when a zmq socket sends and receives 52 53 For use with zmq.eventloop.ioloop 54 55 There are three main methods 56 57 Methods: 58 59 * **on_recv(callback, copy=True):** 60 register a callback to be run every time the socket has something to receive 61 * **on_send(callback):** 62 register a callback to be run every time you call send 63 * **send(self, msg, flags=0, copy=False, callback=None):** 64 perform a send that will trigger the callback 65 if callback is passed, on_send is also called. 66 67 There are also send_multipart(), send_json(), send_pyobj() 68 69 Three other methods for deactivating the callbacks: 70 71 * **stop_on_recv():** 72 turn off the recv callback 73 * **stop_on_send():** 74 turn off the send callback 75 76 which simply call ``on_<evt>(None)``. 77 78 The entire socket interface, excluding direct recv methods, is also 79 provided, primarily through direct-linking the methods. 80 e.g. 81 82 >>> stream.bind is stream.socket.bind 83 True 84 85 """ 86 87 socket = None 88 io_loop = None 89 poller = None 90 _send_queue = None 91 _recv_callback = None 92 _send_callback = None 93 _close_callback = None 94 _state = 0 95 _flushed = False 96 _recv_copy = False 97 _fd = None 98 99 def __init__(self, socket, io_loop=None): 100 self.socket = socket 101 self.io_loop = io_loop or IOLoop.current() 102 self.poller = zmq.Poller() 103 self._fd = self.socket.FD 104 105 self._send_queue = Queue() 106 self._recv_callback = None 107 self._send_callback = None 108 self._close_callback = None 109 self._recv_copy = False 110 self._flushed = False 111 112 self._state = 0 113 self._init_io_state() 114 115 # shortcircuit some socket methods 116 self.bind = self.socket.bind 117 self.bind_to_random_port = self.socket.bind_to_random_port 118 self.connect = self.socket.connect 119 self.setsockopt = self.socket.setsockopt 120 self.getsockopt = self.socket.getsockopt 121 self.setsockopt_string = self.socket.setsockopt_string 122 self.getsockopt_string = self.socket.getsockopt_string 123 self.setsockopt_unicode = self.socket.setsockopt_unicode 124 self.getsockopt_unicode = self.socket.getsockopt_unicode 125 126 def stop_on_recv(self): 127 """Disable callback and automatic receiving.""" 128 return self.on_recv(None) 129 130 def stop_on_send(self): 131 """Disable callback on sending.""" 132 return self.on_send(None) 133 134 def stop_on_err(self): 135 """DEPRECATED, does nothing""" 136 gen_log.warn("on_err does nothing, and will be removed") 137 138 def on_err(self, callback): 139 """DEPRECATED, does nothing""" 140 gen_log.warn("on_err does nothing, and will be removed") 141 142 def on_recv(self, callback, copy=True): 143 """Register a callback for when a message is ready to recv. 144 145 There can be only one callback registered at a time, so each 146 call to `on_recv` replaces previously registered callbacks. 147 148 on_recv(None) disables recv event polling. 149 150 Use on_recv_stream(callback) instead, to register a callback that will receive 151 both this ZMQStream and the message, instead of just the message. 152 153 Parameters 154 ---------- 155 156 callback : callable 157 callback must take exactly one argument, which will be a 158 list, as returned by socket.recv_multipart() 159 if callback is None, recv callbacks are disabled. 160 copy : bool 161 copy is passed directly to recv, so if copy is False, 162 callback will receive Message objects. If copy is True, 163 then callback will receive bytes/str objects. 164 165 Returns : None 166 """ 167 168 self._check_closed() 169 assert callback is None or callable(callback) 170 self._recv_callback = stack_context_wrap(callback) 171 self._recv_copy = copy 172 if callback is None: 173 self._drop_io_state(zmq.POLLIN) 174 else: 175 self._add_io_state(zmq.POLLIN) 176 177 def on_recv_stream(self, callback, copy=True): 178 """Same as on_recv, but callback will get this stream as first argument 179 180 callback must take exactly two arguments, as it will be called as:: 181 182 callback(stream, msg) 183 184 Useful when a single callback should be used with multiple streams. 185 """ 186 if callback is None: 187 self.stop_on_recv() 188 else: 189 self.on_recv(lambda msg: callback(self, msg), copy=copy) 190 191 def on_send(self, callback): 192 """Register a callback to be called on each send 193 194 There will be two arguments:: 195 196 callback(msg, status) 197 198 * `msg` will be the list of sendable objects that was just sent 199 * `status` will be the return result of socket.send_multipart(msg) - 200 MessageTracker or None. 201 202 Non-copying sends return a MessageTracker object whose 203 `done` attribute will be True when the send is complete. 204 This allows users to track when an object is safe to write to 205 again. 206 207 The second argument will always be None if copy=True 208 on the send. 209 210 Use on_send_stream(callback) to register a callback that will be passed 211 this ZMQStream as the first argument, in addition to the other two. 212 213 on_send(None) disables recv event polling. 214 215 Parameters 216 ---------- 217 218 callback : callable 219 callback must take exactly two arguments, which will be 220 the message being sent (always a list), 221 and the return result of socket.send_multipart(msg) - 222 MessageTracker or None. 223 224 if callback is None, send callbacks are disabled. 225 """ 226 227 self._check_closed() 228 assert callback is None or callable(callback) 229 self._send_callback = stack_context_wrap(callback) 230 231 def on_send_stream(self, callback): 232 """Same as on_send, but callback will get this stream as first argument 233 234 Callback will be passed three arguments:: 235 236 callback(stream, msg, status) 237 238 Useful when a single callback should be used with multiple streams. 239 """ 240 if callback is None: 241 self.stop_on_send() 242 else: 243 self.on_send(lambda msg, status: callback(self, msg, status)) 244 245 def send(self, msg, flags=0, copy=True, track=False, callback=None, **kwargs): 246 """Send a message, optionally also register a new callback for sends. 247 See zmq.socket.send for details. 248 """ 249 return self.send_multipart( 250 [msg], flags=flags, copy=copy, track=track, callback=callback, **kwargs 251 ) 252 253 def send_multipart( 254 self, msg, flags=0, copy=True, track=False, callback=None, **kwargs 255 ): 256 """Send a multipart message, optionally also register a new callback for sends. 257 See zmq.socket.send_multipart for details. 258 """ 259 kwargs.update(dict(flags=flags, copy=copy, track=track)) 260 self._send_queue.put((msg, kwargs)) 261 callback = callback or self._send_callback 262 if callback is not None: 263 self.on_send(callback) 264 else: 265 # noop callback 266 self.on_send(lambda *args: None) 267 self._add_io_state(zmq.POLLOUT) 268 269 def send_string(self, u, flags=0, encoding='utf-8', callback=None, **kwargs): 270 """Send a unicode message with an encoding. 271 See zmq.socket.send_unicode for details. 272 """ 273 if not isinstance(u, str): 274 raise TypeError("unicode/str objects only") 275 return self.send(u.encode(encoding), flags=flags, callback=callback, **kwargs) 276 277 send_unicode = send_string 278 279 def send_json(self, obj, flags=0, callback=None, **kwargs): 280 """Send json-serialized version of an object. 281 See zmq.socket.send_json for details. 282 """ 283 msg = jsonapi.dumps(obj) 284 return self.send(msg, flags=flags, callback=callback, **kwargs) 285 286 def send_pyobj(self, obj, flags=0, protocol=-1, callback=None, **kwargs): 287 """Send a Python object as a message using pickle to serialize. 288 289 See zmq.socket.send_json for details. 290 """ 291 msg = pickle.dumps(obj, protocol) 292 return self.send(msg, flags, callback=callback, **kwargs) 293 294 def _finish_flush(self): 295 """callback for unsetting _flushed flag.""" 296 self._flushed = False 297 298 def flush(self, flag=zmq.POLLIN | zmq.POLLOUT, limit=None): 299 """Flush pending messages. 300 301 This method safely handles all pending incoming and/or outgoing messages, 302 bypassing the inner loop, passing them to the registered callbacks. 303 304 A limit can be specified, to prevent blocking under high load. 305 306 flush will return the first time ANY of these conditions are met: 307 * No more events matching the flag are pending. 308 * the total number of events handled reaches the limit. 309 310 Note that if ``flag|POLLIN != 0``, recv events will be flushed even if no callback 311 is registered, unlike normal IOLoop operation. This allows flush to be 312 used to remove *and ignore* incoming messages. 313 314 Parameters 315 ---------- 316 flag : int, default=POLLIN|POLLOUT 317 0MQ poll flags. 318 If flag|POLLIN, recv events will be flushed. 319 If flag|POLLOUT, send events will be flushed. 320 Both flags can be set at once, which is the default. 321 limit : None or int, optional 322 The maximum number of messages to send or receive. 323 Both send and recv count against this limit. 324 325 Returns 326 ------- 327 int : count of events handled (both send and recv) 328 """ 329 self._check_closed() 330 # unset self._flushed, so callbacks will execute, in case flush has 331 # already been called this iteration 332 already_flushed = self._flushed 333 self._flushed = False 334 # initialize counters 335 count = 0 336 337 def update_flag(): 338 """Update the poll flag, to prevent registering POLLOUT events 339 if we don't have pending sends.""" 340 return flag & zmq.POLLIN | (self.sending() and flag & zmq.POLLOUT) 341 342 flag = update_flag() 343 if not flag: 344 # nothing to do 345 return 0 346 self.poller.register(self.socket, flag) 347 events = self.poller.poll(0) 348 while events and (not limit or count < limit): 349 s, event = events[0] 350 if event & zmq.POLLIN: # receiving 351 self._handle_recv() 352 count += 1 353 if self.socket is None: 354 # break if socket was closed during callback 355 break 356 if event & zmq.POLLOUT and self.sending(): 357 self._handle_send() 358 count += 1 359 if self.socket is None: 360 # break if socket was closed during callback 361 break 362 363 flag = update_flag() 364 if flag: 365 self.poller.register(self.socket, flag) 366 events = self.poller.poll(0) 367 else: 368 events = [] 369 if count: # only bypass loop if we actually flushed something 370 # skip send/recv callbacks this iteration 371 self._flushed = True 372 # reregister them at the end of the loop 373 if not already_flushed: # don't need to do it again 374 self.io_loop.add_callback(self._finish_flush) 375 elif already_flushed: 376 self._flushed = True 377 378 # update ioloop poll state, which may have changed 379 self._rebuild_io_state() 380 return count 381 382 def set_close_callback(self, callback): 383 """Call the given callback when the stream is closed.""" 384 self._close_callback = stack_context_wrap(callback) 385 386 def close(self, linger=None): 387 """Close this stream.""" 388 if self.socket is not None: 389 if self.socket.closed: 390 # fallback on raw fd for closed sockets 391 # hopefully this happened promptly after close, 392 # otherwise somebody else may have the FD 393 warnings.warn( 394 "Unregistering FD %s after closing socket. " 395 "This could result in unregistering handlers for the wrong socket. " 396 "Please use stream.close() instead of closing the socket directly." 397 % self._fd, 398 stacklevel=2, 399 ) 400 self.io_loop.remove_handler(self._fd) 401 else: 402 self.io_loop.remove_handler(self.socket) 403 self.socket.close(linger) 404 self.socket = None 405 if self._close_callback: 406 self._run_callback(self._close_callback) 407 408 def receiving(self): 409 """Returns True if we are currently receiving from the stream.""" 410 return self._recv_callback is not None 411 412 def sending(self): 413 """Returns True if we are currently sending to the stream.""" 414 return not self._send_queue.empty() 415 416 def closed(self): 417 if self.socket is None: 418 return True 419 if self.socket.closed: 420 # underlying socket has been closed, but not by us! 421 # trigger our cleanup 422 self.close() 423 return True 424 425 def _run_callback(self, callback, *args, **kwargs): 426 """Wrap running callbacks in try/except to allow us to 427 close our socket.""" 428 try: 429 # Use a NullContext to ensure that all StackContexts are run 430 # inside our blanket exception handler rather than outside. 431 callback(*args, **kwargs) 432 except Exception: 433 gen_log.error("Uncaught exception in ZMQStream callback", exc_info=True) 434 # Re-raise the exception so that IOLoop.handle_callback_exception 435 # can see it and log the error 436 raise 437 438 def _handle_events(self, fd, events): 439 """This method is the actual handler for IOLoop, that gets called whenever 440 an event on my socket is posted. It dispatches to _handle_recv, etc.""" 441 if not self.socket: 442 gen_log.warning("Got events for closed stream %s", self) 443 return 444 try: 445 zmq_events = self.socket.EVENTS 446 except zmq.ContextTerminated: 447 gen_log.warning("Got events for stream %s after terminating context", self) 448 return 449 try: 450 # dispatch events: 451 if zmq_events & zmq.POLLIN and self.receiving(): 452 self._handle_recv() 453 if not self.socket: 454 return 455 if zmq_events & zmq.POLLOUT and self.sending(): 456 self._handle_send() 457 if not self.socket: 458 return 459 460 # rebuild the poll state 461 self._rebuild_io_state() 462 except Exception: 463 gen_log.error("Uncaught exception in zmqstream callback", exc_info=True) 464 raise 465 466 def _handle_recv(self): 467 """Handle a recv event.""" 468 if self._flushed: 469 return 470 try: 471 msg = self.socket.recv_multipart(zmq.NOBLOCK, copy=self._recv_copy) 472 except zmq.ZMQError as e: 473 if e.errno == zmq.EAGAIN: 474 # state changed since poll event 475 pass 476 else: 477 raise 478 else: 479 if self._recv_callback: 480 callback = self._recv_callback 481 self._run_callback(callback, msg) 482 483 def _handle_send(self): 484 """Handle a send event.""" 485 if self._flushed: 486 return 487 if not self.sending(): 488 gen_log.error("Shouldn't have handled a send event") 489 return 490 491 msg, kwargs = self._send_queue.get() 492 try: 493 status = self.socket.send_multipart(msg, **kwargs) 494 except zmq.ZMQError as e: 495 gen_log.error("SEND Error: %s", e) 496 status = e 497 if self._send_callback: 498 callback = self._send_callback 499 self._run_callback(callback, msg, status) 500 501 def _check_closed(self): 502 if not self.socket: 503 raise IOError("Stream is closed") 504 505 def _rebuild_io_state(self): 506 """rebuild io state based on self.sending() and receiving()""" 507 if self.socket is None: 508 return 509 state = 0 510 if self.receiving(): 511 state |= zmq.POLLIN 512 if self.sending(): 513 state |= zmq.POLLOUT 514 515 self._state = state 516 self._update_handler(state) 517 518 def _add_io_state(self, state): 519 """Add io_state to poller.""" 520 self._state = self._state | state 521 self._update_handler(self._state) 522 523 def _drop_io_state(self, state): 524 """Stop poller from watching an io_state.""" 525 self._state = self._state & (~state) 526 self._update_handler(self._state) 527 528 def _update_handler(self, state): 529 """Update IOLoop handler with state.""" 530 if self.socket is None: 531 return 532 533 if state & self.socket.events: 534 # events still exist that haven't been processed 535 # explicitly schedule handling to avoid missing events due to edge-triggered FDs 536 self.io_loop.add_callback(lambda: self._handle_events(self.socket, 0)) 537 538 def _init_io_state(self): 539 """initialize the ioloop event handler""" 540 self.io_loop.add_handler(self.socket, self._handle_events, self.io_loop.READ) 541