1"""Base classes to manage a Client's interaction with a running kernel"""
2
3# Copyright (c) Jupyter Development Team.
4# Distributed under the terms of the Modified BSD License.
5
6import atexit
7import errno
8from threading import Thread, Event
9import time
10import asyncio
11
12import zmq
13# import ZMQError in top-level namespace, to avoid ugly attribute-error messages
14# during garbage collection of threads at exit:
15from zmq import ZMQError
16
17from jupyter_client import protocol_version_info
18
19from .channelsabc import HBChannelABC
20
21#-----------------------------------------------------------------------------
22# Constants and exceptions
23#-----------------------------------------------------------------------------
24
# Major protocol version: the first element of ``protocol_version_info``
# (assumed to be a (major, minor, ...) version tuple -- confirm in jupyter_client).
major_protocol_version = protocol_version_info[0]
26
class InvalidPortNumber(Exception):
    """Raised when a channel is configured with an unusable port number (e.g. port 0)."""
29
class HBChannel(Thread):
    """The heartbeat channel which monitors the kernel heartbeat.

    A daemon thread that repeatedly sends ``b'ping'`` on a ZMQ REQ socket and
    waits up to :attr:`time_to_dead` seconds for a reply.  When no reply
    arrives in time, :meth:`call_handlers` is invoked with the elapsed time
    and the socket is recreated (the REQ/REP cycle is broken at that point).

    The channel starts un-paused.  While paused, the thread only sleeps; use
    :meth:`pause` and :meth:`unpause` to control it.
    """
    # ZMQ context used to create the heartbeat socket.
    context = None
    # Session given by the owner; stored but not used by this class itself.
    session = None
    # Current REQ socket: None until run() creates it, and again after close().
    socket = None
    # zmq endpoint URL the socket connects to.
    address = None
    # Set class-wide by the atexit hook; lets _poll() give up quietly when the
    # poller fails during interpreter shutdown.
    _exiting = False

    # Seconds to wait for a heartbeat reply before declaring heart failure;
    # also the length of one ping cycle.
    time_to_dead = 1.0
    poller = None
    # True while the run() loop should keep going.
    _running = None
    # When True the run() loop sleeps instead of pinging.
    _pause = None
    # Whether the most recent ping was answered within time_to_dead.
    _beating = None

    def __init__(self, context=None, session=None, address=None, loop=None):
        """Create the heartbeat monitor thread.

        Parameters
        ----------
        context : :class:`zmq.Context`
            The ZMQ context to use.
        session : :class:`session.Session`
            The session to use.
        address : zmq url or tuple
            Either a zmq url string, or a standard (ip, port) tuple that the
            kernel is listening on; tuples are converted to ``tcp://ip:port``.
        loop : asyncio event loop, optional
            If given, installed as the current event loop of the heartbeat
            thread when it starts running.

        Raises
        ------
        InvalidPortNumber
            If ``address`` is a tuple whose port is 0.
        """
        super().__init__()
        self.daemon = True

        self.loop = loop

        self.context = context
        self.session = session
        if isinstance(address, tuple):
            if address[1] == 0:
                message = 'The port number for a channel cannot be 0.'
                raise InvalidPortNumber(message)
            address = "tcp://%s:%i" % address
        self.address = address

        # running is False until `.start()` is called
        self._running = False
        # Set by stop(): ends the run() loop and cuts any pending sleep short.
        self._exit = Event()
        # don't start paused
        self._pause = False
        self.poller = zmq.Poller()

    @staticmethod
    @atexit.register
    def _notice_exit():
        # Class definitions can be torn down during interpreter shutdown.
        # We only need to set _exiting flag if this hasn't happened.
        if HBChannel is not None:
            HBChannel._exiting = True

    def _create_socket(self):
        """(Re)create the REQ socket, connect it, and register it for polling."""
        if self.socket is not None:
            # close previous socket, before opening a new one
            self.poller.unregister(self.socket)
            self.socket.close()
        self.socket = self.context.socket(zmq.REQ)
        # Bound (in milliseconds) how long closing may block on unsent messages.
        self.socket.linger = 1000
        self.socket.connect(self.address)

        self.poller.register(self.socket, zmq.POLLIN)

    def _poll(self, start_time):
        """Poll for heartbeat replies until ``time_to_dead`` has elapsed.

        Parameters
        ----------
        start_time : float
            ``time.monotonic()`` timestamp at which the ping was sent; the
            deadline is ``start_time + time_to_dead``.

        Returns
        -------
        list
            The result of ``poller.poll()``: empty if nothing arrived before
            the deadline, otherwise the ready ``(socket, event)`` tuples.  Also
            empty if polling failed while the interpreter is exiting.

        EINTR interrupts are retried with the remaining time; any other
        error is re-raised.
        """
        events = []
        while True:
            # Recompute on every attempt so a retried (interrupted) poll never
            # extends the overall deadline; always poll at least once (>= 1 ms).
            until_dead = max(
                self.time_to_dead - (time.monotonic() - start_time), 1e-3
            )
            try:
                events = self.poller.poll(1000 * until_dead)
            except ZMQError as e:
                if e.errno != errno.EINTR:
                    raise
                # ignore interrupts during heartbeat; this may never actually happen
            except Exception:
                if self._exiting:
                    break
                raise
            else:
                break
        return events

    def run(self):
        """The thread's main activity.  Call start() instead.

        Sends one ping per ``time_to_dead`` cycle.  A timely reply marks the
        channel as beating; a missing reply marks it as not beating, calls
        :meth:`call_handlers` with the elapsed seconds, and recreates the
        socket.  The loop ends once :meth:`stop` has been called.
        """
        if self.loop is not None:
            asyncio.set_event_loop(self.loop)
        self._create_socket()
        self._running = True
        self._beating = True

        # Also check _exit: if stop() ran between start() returning and the
        # assignment above, its `_running = False` was overwritten, and relying
        # on _running alone would loop forever (and hang stop()'s join()).
        while self._running and not self._exit.is_set():
            if self._pause:
                # just sleep, and skip the rest of the loop
                self._exit.wait(self.time_to_dead)
                continue

            # no need to catch EFSM here, because the previous event was
            # either a recv or connect, which cannot be followed by EFSM
            self.socket.send(b'ping')
            # Monotonic clock: immune to wall-clock adjustments mid-cycle.
            request_time = time.monotonic()
            if self._poll(request_time):
                self._beating = True
                # the poll above guarantees we have something to recv
                self.socket.recv()
                # sleep the remainder of the cycle (woken early by stop())
                remainder = self.time_to_dead - (time.monotonic() - request_time)
                if remainder > 0:
                    self._exit.wait(remainder)
            else:
                # nothing was received within the time limit, signal heart failure
                self._beating = False
                self.call_handlers(time.monotonic() - request_time)
                # and close/reopen the socket, because the REQ/REP cycle has been broken
                self._create_socket()

    def pause(self):
        """Pause the heartbeat."""
        self._pause = True

    def unpause(self):
        """Unpause the heartbeat."""
        self._pause = False

    def is_beating(self):
        """Is the heartbeat running and responsive (and not paused)."""
        return bool(self.is_alive() and not self._pause and self._beating)

    def stop(self):
        """Stop the channel's event loop, join its thread, and close the socket.

        The thread must have been started: ``Thread.join`` raises
        ``RuntimeError`` otherwise.
        """
        self._running = False
        self._exit.set()
        self.join()
        self.close()

    def close(self):
        """Close the heartbeat socket immediately (linger=0), if one exists."""
        if self.socket is not None:
            try:
                self.socket.close(linger=0)
            except Exception:
                # Best-effort cleanup: a failing close must not break shutdown.
                pass
            self.socket = None

    def call_handlers(self, since_last_heartbeat):
        """Called from the heartbeat thread when a heartbeat reply is missed.

        Parameters
        ----------
        since_last_heartbeat : float
            Seconds elapsed since the unanswered ping was sent.

        Subclasses should override this method to react to heart failure.
        It is important to remember that this method is called in the
        heartbeat thread, so some logic must be done to ensure that the
        application level handlers are called in the application thread.
        """
        pass
211
212
# Register HBChannel as a virtual subclass of HBChannelABC so isinstance /
# issubclass checks against the ABC accept it without inheriting from it
# (assumes HBChannelABC is an abc.ABCMeta-based class -- see .channelsabc).
HBChannelABC.register(HBChannel)
214