1#
2# Copyright 2009 Facebook
3#
4# Licensed under the Apache License, Version 2.0 (the "License"); you may
5# not use this file except in compliance with the License. You may obtain
6# a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13# License for the specific language governing permissions and limitations
14# under the License.
15
16"""A utility class to send to and recv from a non-blocking socket,
17using tornado.
18
19.. seealso::
20
21    - :mod:`zmq.asyncio`
22    - :mod:`zmq.eventloop.future`
23
24"""
25
26from __future__ import with_statement
27import pickle
28import sys
29import warnings
30from queue import Queue
31
32import zmq
33from zmq.utils import jsonapi
34
35
36from .ioloop import IOLoop, gen_log
37
try:
    from tornado.stack_context import wrap as stack_context_wrap  # type: ignore
except ImportError:
    if "zmq.eventloop.minitornado" not in sys.modules:
        # stack_context was deprecated in tornado 5 and removed in tornado 6;
        # without it, wrapping is simply the identity.
        def stack_context_wrap(callback):
            return callback

    else:
        # the bundled minitornado is in use, so take its implementation
        from .minitornado.stack_context import wrap as stack_context_wrap  # type: ignore
48
49
class ZMQStream(object):
    """A utility class to register callbacks when a zmq socket sends and receives

    For use with zmq.eventloop.ioloop

    There are three main methods

    Methods:

    * **on_recv(callback, copy=True):**
        register a callback to be run every time the socket has something to receive
    * **on_send(callback):**
        register a callback to be run every time you call send
    * **send(self, msg, flags=0, copy=True, track=False, callback=None):**
        perform a send that will trigger the callback
        if callback is passed, on_send is also called.

        There are also send_multipart(), send_string(), send_json(), send_pyobj()

    Two other methods deactivate the callbacks:

    * **stop_on_recv():**
        turn off the recv callback
    * **stop_on_send():**
        turn off the send callback

    which simply call ``on_<evt>(None)``.

    on_recv_stream() and on_send_stream() register callbacks that additionally
    receive this stream as their first argument.

    Part of the socket interface (bind, bind_to_random_port, connect and the
    setsockopt/getsockopt family) is also exposed by direct-linking the
    socket's methods; receiving must go through on_recv.
    e.g.

    >>> stream.bind is stream.socket.bind
    True

    """

    socket = None
    io_loop = None
    poller = None
    _send_queue = None
    _recv_callback = None
    _send_callback = None
    _close_callback = None
    _state = 0
    _flushed = False
    _recv_copy = False
    _fd = None

    def __init__(self, socket, io_loop=None):
        """Wrap *socket* and register it with an IOLoop.

        Parameters
        ----------
        socket : zmq.Socket
            The socket to wrap.
        io_loop : IOLoop, optional
            The loop to register with; defaults to ``IOLoop.current()``.
        """
        self.socket = socket
        self.io_loop = io_loop or IOLoop.current()
        # private poller, only used by flush()
        self.poller = zmq.Poller()
        # keep the raw FD so close() can still unregister the handler
        # if the socket was closed directly
        self._fd = self.socket.FD

        self._send_queue = Queue()
        self._recv_callback = None
        self._send_callback = None
        self._close_callback = None
        self._recv_copy = False
        self._flushed = False

        self._state = 0
        self._init_io_state()

        # shortcircuit some socket methods
        self.bind = self.socket.bind
        self.bind_to_random_port = self.socket.bind_to_random_port
        self.connect = self.socket.connect
        self.setsockopt = self.socket.setsockopt
        self.getsockopt = self.socket.getsockopt
        self.setsockopt_string = self.socket.setsockopt_string
        self.getsockopt_string = self.socket.getsockopt_string
        self.setsockopt_unicode = self.socket.setsockopt_unicode
        self.getsockopt_unicode = self.socket.getsockopt_unicode

    def stop_on_recv(self):
        """Disable callback and automatic receiving."""
        return self.on_recv(None)

    def stop_on_send(self):
        """Disable callback on sending."""
        return self.on_send(None)

    def stop_on_err(self):
        """DEPRECATED, does nothing"""
        gen_log.warning("on_err does nothing, and will be removed")

    def on_err(self, callback):
        """DEPRECATED, does nothing"""
        gen_log.warning("on_err does nothing, and will be removed")

    def on_recv(self, callback, copy=True):
        """Register a callback for when a message is ready to recv.

        There can be only one callback registered at a time, so each
        call to `on_recv` replaces previously registered callbacks.

        on_recv(None) disables recv event polling.

        Use on_recv_stream(callback) instead, to register a callback that will receive
        both this ZMQStream and the message, instead of just the message.

        Parameters
        ----------

        callback : callable
            callback must take exactly one argument, which will be a
            list, as returned by socket.recv_multipart()
            if callback is None, recv callbacks are disabled.
        copy : bool
            copy is passed directly to recv, so if copy is False,
            callback will receive Message objects. If copy is True,
            then callback will receive bytes/str objects.

        Raises
        ------
        IOError
            If the stream is closed.

        Returns : None
        """

        self._check_closed()
        assert callback is None or callable(callback)
        self._recv_callback = stack_context_wrap(callback)
        self._recv_copy = copy
        if callback is None:
            self._drop_io_state(zmq.POLLIN)
        else:
            self._add_io_state(zmq.POLLIN)

    def on_recv_stream(self, callback, copy=True):
        """Same as on_recv, but callback will get this stream as first argument

        callback must take exactly two arguments, as it will be called as::

            callback(stream, msg)

        Useful when a single callback should be used with multiple streams.
        """
        if callback is None:
            self.stop_on_recv()
        else:
            self.on_recv(lambda msg: callback(self, msg), copy=copy)

    def on_send(self, callback):
        """Register a callback to be called on each send

        There will be two arguments::

            callback(msg, status)

        * `msg` will be the list of sendable objects that was just sent
        * `status` will be the return result of socket.send_multipart(msg) -
          MessageTracker or None, or the ZMQError if the send failed.

        Non-copying sends return a MessageTracker object whose
        `done` attribute will be True when the send is complete.
        This allows users to track when an object is safe to write to
        again.

        The second argument will always be None if copy=True
        on the send.

        Use on_send_stream(callback) to register a callback that will be passed
        this ZMQStream as the first argument, in addition to the other two.

        on_send(None) disables send callbacks. It does not change event
        polling: queued messages are still sent.

        Parameters
        ----------

        callback : callable
            callback must take exactly two arguments, which will be
            the message being sent (always a list),
            and the return result of socket.send_multipart(msg) -
            MessageTracker or None.

            if callback is None, send callbacks are disabled.

        Raises
        ------
        IOError
            If the stream is closed.
        """

        self._check_closed()
        assert callback is None or callable(callback)
        self._send_callback = stack_context_wrap(callback)

    def on_send_stream(self, callback):
        """Same as on_send, but callback will get this stream as first argument

        Callback will be passed three arguments::

            callback(stream, msg, status)

        Useful when a single callback should be used with multiple streams.
        """
        if callback is None:
            self.stop_on_send()
        else:
            self.on_send(lambda msg, status: callback(self, msg, status))

    def send(self, msg, flags=0, copy=True, track=False, callback=None, **kwargs):
        """Send a message, optionally also register a new callback for sends.
        See zmq.socket.send for details.
        """
        return self.send_multipart(
            [msg], flags=flags, copy=copy, track=track, callback=callback, **kwargs
        )

    def send_multipart(
        self, msg, flags=0, copy=True, track=False, callback=None, **kwargs
    ):
        """Send a multipart message, optionally also register a new callback for sends.
        See zmq.socket.send_multipart for details.

        The message is only queued here; it is sent by the IOLoop once
        the socket is writable (or by flush()).
        """
        kwargs.update(dict(flags=flags, copy=copy, track=track))
        self._send_queue.put((msg, kwargs))
        callback = callback or self._send_callback
        if callback is not None:
            self.on_send(callback)
        else:
            # noop callback
            self.on_send(lambda *args: None)
        self._add_io_state(zmq.POLLOUT)

    def send_string(self, u, flags=0, encoding='utf-8', callback=None, **kwargs):
        """Send a unicode message with an encoding.
        See zmq.socket.send_unicode for details.

        Raises
        ------
        TypeError
            If *u* is not a str.
        """
        if not isinstance(u, str):
            raise TypeError("unicode/str objects only")
        return self.send(u.encode(encoding), flags=flags, callback=callback, **kwargs)

    send_unicode = send_string

    def send_json(self, obj, flags=0, callback=None, **kwargs):
        """Send json-serialized version of an object.
        See zmq.socket.send_json for details.
        """
        msg = jsonapi.dumps(obj)
        return self.send(msg, flags=flags, callback=callback, **kwargs)

    def send_pyobj(self, obj, flags=0, protocol=-1, callback=None, **kwargs):
        """Send a Python object as a message using pickle to serialize.

        See zmq.socket.send_pyobj for details.
        """
        msg = pickle.dumps(obj, protocol)
        return self.send(msg, flags, callback=callback, **kwargs)

    def _finish_flush(self):
        """callback for unsetting _flushed flag."""
        self._flushed = False

    def flush(self, flag=zmq.POLLIN | zmq.POLLOUT, limit=None):
        """Flush pending messages.

        This method safely handles all pending incoming and/or outgoing messages,
        bypassing the inner loop, passing them to the registered callbacks.

        A limit can be specified, to prevent blocking under high load.

        flush will return the first time ANY of these conditions are met:
            * No more events matching the flag are pending.
            * the total number of events handled reaches the limit.

        Note that if ``flag & POLLIN != 0``, recv events will be flushed even if no callback
        is registered, unlike normal IOLoop operation. This allows flush to be
        used to remove *and ignore* incoming messages.

        Parameters
        ----------
        flag : int, default=POLLIN|POLLOUT
                0MQ poll flags.
                If flag & POLLIN,  recv events will be flushed.
                If flag & POLLOUT, send events will be flushed.
                Both flags can be set at once, which is the default.
        limit : None or int, optional
                The maximum number of messages to send or receive.
                Both send and recv count against this limit.
                A falsy value (None or 0) means no limit.

        Returns
        -------
        int : count of events handled (both send and recv)
        """
        self._check_closed()
        # unset self._flushed, so callbacks will execute, in case flush has
        # already been called this iteration
        already_flushed = self._flushed
        self._flushed = False
        # initialize counters
        count = 0

        def update_flag():
            """Update the poll flag, to prevent registering POLLOUT events
            if we don't have pending sends."""
            return flag & zmq.POLLIN | (self.sending() and flag & zmq.POLLOUT)

        flag = update_flag()
        if not flag:
            # nothing to do
            return 0
        self.poller.register(self.socket, flag)
        events = self.poller.poll(0)
        while events and (not limit or count < limit):
            s, event = events[0]
            if event & zmq.POLLIN:  # receiving
                self._handle_recv()
                count += 1
                if self.socket is None:
                    # break if socket was closed during callback
                    break
            if event & zmq.POLLOUT and self.sending():
                self._handle_send()
                count += 1
                if self.socket is None:
                    # break if socket was closed during callback
                    break

            flag = update_flag()
            if flag:
                self.poller.register(self.socket, flag)
                events = self.poller.poll(0)
            else:
                events = []
        if count:  # only bypass loop if we actually flushed something
            # skip send/recv callbacks this iteration
            self._flushed = True
            # reregister them at the end of the loop
            if not already_flushed:  # don't need to do it again
                self.io_loop.add_callback(self._finish_flush)
        elif already_flushed:
            self._flushed = True

        # update ioloop poll state, which may have changed
        self._rebuild_io_state()
        return count

    def set_close_callback(self, callback):
        """Call the given callback when the stream is closed."""
        self._close_callback = stack_context_wrap(callback)

    def close(self, linger=None):
        """Close this stream.

        Unregisters the socket from the IOLoop, closes it (unless it is
        already closed), and then runs the close callback, if one is set.
        Calling close() on an already-closed stream is a no-op.
        """
        if self.socket is not None:
            if self.socket.closed:
                # fallback on raw fd for closed sockets
                # hopefully this happened promptly after close,
                # otherwise somebody else may have the FD
                warnings.warn(
                    "Unregistering FD %s after closing socket. "
                    "This could result in unregistering handlers for the wrong socket. "
                    "Please use stream.close() instead of closing the socket directly."
                    % self._fd,
                    stacklevel=2,
                )
                self.io_loop.remove_handler(self._fd)
            else:
                self.io_loop.remove_handler(self.socket)
                self.socket.close(linger)
            self.socket = None
            if self._close_callback:
                self._run_callback(self._close_callback)

    def receiving(self):
        """Returns True if we are currently receiving from the stream."""
        return self._recv_callback is not None

    def sending(self):
        """Returns True if we are currently sending to the stream."""
        return not self._send_queue.empty()

    def closed(self):
        """Return True if the stream is closed, False otherwise.

        If the underlying socket was closed directly (not via close()),
        the stream is cleaned up here before returning True.
        """
        if self.socket is None:
            return True
        if self.socket.closed:
            # underlying socket has been closed, but not by us!
            # trigger our cleanup
            self.close()
            return True
        return False

    def _run_callback(self, callback, *args, **kwargs):
        """Run a user callback, logging any exception before re-raising it."""
        try:
            callback(*args, **kwargs)
        except Exception:
            gen_log.error("Uncaught exception in ZMQStream callback", exc_info=True)
            # Re-raise the exception so that IOLoop.handle_callback_exception
            # can see it and log the error
            raise

    def _handle_events(self, fd, events):
        """This method is the actual handler for IOLoop, that gets called whenever
        an event on my socket is posted. It dispatches to _handle_recv, etc.

        The *fd* and *events* arguments are ignored; the pending events are
        read from the socket's EVENTS option instead.
        """
        if not self.socket:
            gen_log.warning("Got events for closed stream %s", self)
            return
        try:
            zmq_events = self.socket.EVENTS
        except zmq.ContextTerminated:
            gen_log.warning("Got events for stream %s after terminating context", self)
            return
        try:
            # dispatch events:
            if zmq_events & zmq.POLLIN and self.receiving():
                self._handle_recv()
                if not self.socket:
                    return
            if zmq_events & zmq.POLLOUT and self.sending():
                self._handle_send()
                if not self.socket:
                    return

            # rebuild the poll state
            self._rebuild_io_state()
        except Exception:
            gen_log.error("Uncaught exception in zmqstream callback", exc_info=True)
            raise

    def _handle_recv(self):
        """Handle a recv event: receive one multipart message and pass it to the recv callback."""
        if self._flushed:
            return
        try:
            msg = self.socket.recv_multipart(zmq.NOBLOCK, copy=self._recv_copy)
        except zmq.ZMQError as e:
            if e.errno == zmq.EAGAIN:
                # state changed since poll event
                pass
            else:
                raise
        else:
            if self._recv_callback:
                callback = self._recv_callback
                self._run_callback(callback, msg)

    def _handle_send(self):
        """Handle a send event: send one queued message and report it to the send callback."""
        if self._flushed:
            return
        if not self.sending():
            gen_log.error("Shouldn't have handled a send event")
            return

        msg, kwargs = self._send_queue.get()
        try:
            status = self.socket.send_multipart(msg, **kwargs)
        except zmq.ZMQError as e:
            gen_log.error("SEND Error: %s", e)
            # the error itself is reported to the send callback as the status
            status = e
        if self._send_callback:
            callback = self._send_callback
            self._run_callback(callback, msg, status)

    def _check_closed(self):
        """Raise IOError if the stream has been closed."""
        if not self.socket:
            raise IOError("Stream is closed")

    def _rebuild_io_state(self):
        """rebuild io state based on self.sending() and receiving()"""
        if self.socket is None:
            return
        state = 0
        if self.receiving():
            state |= zmq.POLLIN
        if self.sending():
            state |= zmq.POLLOUT

        self._state = state
        self._update_handler(state)

    def _add_io_state(self, state):
        """Add io_state to poller."""
        self._state = self._state | state
        self._update_handler(self._state)

    def _drop_io_state(self, state):
        """Stop poller from watching an io_state."""
        self._state = self._state & (~state)
        self._update_handler(self._state)

    def _update_handler(self, state):
        """Schedule event handling if any events in *state* are already pending."""
        if self.socket is None:
            return

        if state & self.socket.events:
            # events still exist that haven't been processed
            # explicitly schedule handling to avoid missing events due to edge-triggered FDs
            self.io_loop.add_callback(lambda: self._handle_events(self.socket, 0))

    def _init_io_state(self):
        """initialize the ioloop event handler"""
        self.io_loop.add_handler(self.socket, self._handle_events, self.io_loop.READ)
541