1"""Tornado handlers for kernels.
2
3Preliminary documentation at https://github.com/ipython/ipython/wiki/IPEP-16%3A-Notebook-multi-directory-dashboard-and-URL-mapping#kernels-api
4"""
5
6# Copyright (c) Jupyter Development Team.
7# Distributed under the terms of the Modified BSD License.
8
9import json
10import logging
11from textwrap import dedent
12
13from tornado import gen, web
14from tornado.concurrent import Future
15from tornado.ioloop import IOLoop
16
17from jupyter_client import protocol_version as client_protocol_version
18try:
19    from jupyter_client.jsonutil import json_default
20except ImportError:
21    from jupyter_client.jsonutil import (
22        date_default as json_default
23    )
24from ipython_genutils.py3compat import cast_unicode
25from notebook.utils import maybe_future, url_path_join, url_escape
26
27from ...base.handlers import APIHandler
28from ...base.zmqhandlers import AuthenticatedZMQStreamHandler, deserialize_binary_message
29
class MainKernelHandler(APIHandler):
    """Collection endpoint for kernels: list them or start a new one."""

    @web.authenticated
    @gen.coroutine
    def get(self):
        """Reply with the kernel manager's ``list_kernels()`` result as JSON."""
        manager = self.kernel_manager
        running = yield maybe_future(manager.list_kernels())
        payload = json.dumps(running, default=json_default)
        self.finish(payload)

    @web.authenticated
    @gen.coroutine
    def post(self):
        """Start a kernel and reply 201 with its model and a Location header.

        The kernel name comes from the ``name`` key of the JSON request
        body.  When there is no body, or the body has no ``name``, the
        kernel manager's ``default_kernel_name`` is used instead.
        """
        manager = self.kernel_manager
        requested = self.get_json_body()
        if requested is not None:
            requested.setdefault('name', manager.default_kernel_name)
        else:
            requested = {'name': manager.default_kernel_name}

        new_id = yield maybe_future(
            manager.start_kernel(kernel_name=requested['name'])
        )
        kernel_model = yield maybe_future(manager.kernel_model(new_id))

        # Point the client at the resource that was just created.
        self.set_header(
            'Location',
            url_path_join(self.base_url, 'api', 'kernels', url_escape(new_id)),
        )
        self.set_status(201)
        self.finish(json.dumps(kernel_model, default=json_default))
57
58
class KernelHandler(APIHandler):
    """Endpoint for a single kernel addressed by its id."""

    @web.authenticated
    @gen.coroutine
    def get(self, kernel_id):
        """Reply with the JSON-serialized model of ``kernel_id``."""
        manager = self.kernel_manager
        kernel_model = yield maybe_future(manager.kernel_model(kernel_id))
        body = json.dumps(kernel_model, default=json_default)
        self.finish(body)

    @web.authenticated
    @gen.coroutine
    def delete(self, kernel_id):
        """Shut down ``kernel_id`` and reply 204 with an empty body."""
        manager = self.kernel_manager
        shutdown = manager.shutdown_kernel(kernel_id)
        yield maybe_future(shutdown)
        self.set_status(204)
        self.finish()
75
76
class KernelActionHandler(APIHandler):
    """Endpoint that interrupts or restarts a kernel."""

    @web.authenticated
    @gen.coroutine
    def post(self, kernel_id, action):
        """Run ``action`` on ``kernel_id``.

        ``interrupt`` replies 204.  ``restart`` replies with the kernel
        model as JSON on success; if the restart raises, the error is
        logged and a 500 status with no body is sent.
        """
        manager = self.kernel_manager
        if action == 'interrupt':
            yield maybe_future(manager.interrupt_kernel(kernel_id))
            self.set_status(204)
        elif action == 'restart':
            try:
                yield maybe_future(manager.restart_kernel(kernel_id))
            except Exception:
                self.log.error("Exception restarting kernel", exc_info=True)
                self.set_status(500)
            else:
                restarted = yield maybe_future(manager.kernel_model(kernel_id))
                self.write(json.dumps(restarted, default=json_default))
        self.finish()
97
98
99class ZMQChannelsHandler(AuthenticatedZMQStreamHandler):
100    '''There is one ZMQChannelsHandler per running kernel and it oversees all
101    the sessions.
102    '''
103
104    # class-level registry of open sessions
105    # allows checking for conflict on session-id,
106    # which is used as a zmq identity and must be unique.
107    _open_sessions = {}
108
109    @property
110    def kernel_info_timeout(self):
111        km_default = self.kernel_manager.kernel_info_timeout
112        return self.settings.get('kernel_info_timeout', km_default)
113
114    @property
115    def iopub_msg_rate_limit(self):
116        return self.settings.get('iopub_msg_rate_limit', 0)
117
118    @property
119    def iopub_data_rate_limit(self):
120        return self.settings.get('iopub_data_rate_limit', 0)
121
122    @property
123    def rate_limit_window(self):
124        return self.settings.get('rate_limit_window', 1.0)
125
126    def __repr__(self):
127        return "%s(%s)" % (self.__class__.__name__, getattr(self, 'kernel_id', 'uninitialized'))
128
129    def create_stream(self):
130        km = self.kernel_manager
131        identity = self.session.bsession
132        for channel in ("iopub", "shell", "control", "stdin"):
133            meth = getattr(km, "connect_" + channel)
134            self.channels[channel] = stream = meth(self.kernel_id, identity=identity)
135            stream.channel = channel
136
137    def nudge(self):
138        """Nudge the zmq connections with kernel_info_requests
139
140        Returns a Future that will resolve when we have received
141        a shell reply and at least one iopub message,
142        ensuring that zmq subscriptions are established,
143        sockets are fully connected, and kernel is responsive.
144
145        Keeps retrying kernel_info_request until these are both received.
146        """
147        kernel = self.kernel_manager.get_kernel(self.kernel_id)
148
149        # Do not nudge busy kernels as kernel info requests sent to shell are
150        # queued behind execution requests.
151        # nudging in this case would cause a potentially very long wait
152        # before connections are opened,
153        # plus it is *very* unlikely that a busy kernel will not finish
154        # establishing its zmq subscriptions before processing the next request.
155        if getattr(kernel, "execution_state") == "busy":
156            self.log.debug("Nudge: not nudging busy kernel %s", self.kernel_id)
157            f = Future()
158            f.set_result(None)
159            return f
160
161        # Use a transient shell channel to prevent leaking
162        # shell responses to the front-end.
163        shell_channel = kernel.connect_shell()
164        # The IOPub used by the client, whose subscriptions we are verifying.
165        iopub_channel = self.channels["iopub"]
166
167        info_future = Future()
168        iopub_future = Future()
169        both_done = gen.multi([info_future, iopub_future])
170
171        def finish(f=None):
172            """Ensure all futures are resolved
173
174            which in turn triggers cleanup
175            """
176            for f in (info_future, iopub_future):
177                if not f.done():
178                    f.set_result(None)
179
180        def cleanup(f=None):
181            """Common cleanup"""
182            loop.remove_timeout(nudge_handle)
183            iopub_channel.stop_on_recv()
184            if not shell_channel.closed():
185                shell_channel.close()
186
187        # trigger cleanup when both message futures are resolved
188        both_done.add_done_callback(cleanup)
189
190        def on_shell_reply(msg):
191            self.log.debug("Nudge: shell info reply received: %s", self.kernel_id)
192            if not info_future.done():
193                self.log.debug("Nudge: resolving shell future: %s", self.kernel_id)
194                info_future.set_result(None)
195
196        def on_iopub(msg):
197            self.log.debug("Nudge: IOPub received: %s", self.kernel_id)
198            if not iopub_future.done():
199                iopub_channel.stop_on_recv()
200                self.log.debug("Nudge: resolving iopub future: %s", self.kernel_id)
201                iopub_future.set_result(None)
202
203        iopub_channel.on_recv(on_iopub)
204        shell_channel.on_recv(on_shell_reply)
205        loop = IOLoop.current()
206
207        # Nudge the kernel with kernel info requests until we get an IOPub message
208        def nudge(count):
209            count += 1
210
211            # NOTE: this close check appears to never be True during on_open,
212            # even when the peer has closed the connection
213            if self.ws_connection is None or self.ws_connection.is_closing():
214                self.log.debug(
215                    "Nudge: cancelling on closed websocket: %s", self.kernel_id
216                )
217                finish()
218                return
219
220            # check for stopped kernel
221            if self.kernel_id not in self.kernel_manager:
222                self.log.debug(
223                    "Nudge: cancelling on stopped kernel: %s", self.kernel_id
224                )
225                finish()
226                return
227
228            # check for closed zmq socket
229            if shell_channel.closed():
230                self.log.debug(
231                    "Nudge: cancelling on closed zmq socket: %s", self.kernel_id
232                )
233                finish()
234                return
235
236            if not both_done.done():
237                log = self.log.warning if count % 10 == 0 else self.log.debug
238                log("Nudge: attempt %s on kernel %s" % (count, self.kernel_id))
239                self.session.send(shell_channel, "kernel_info_request")
240                nonlocal nudge_handle
241                nudge_handle = loop.call_later(0.5, nudge, count)
242
243        nudge_handle = loop.call_later(0, nudge, count=0)
244
245        # resolve with a timeout if we get no response
246        future = gen.with_timeout(loop.time() + self.kernel_info_timeout, both_done)
247        # ensure we have no dangling resources or unresolved Futures in case of timeout
248        future.add_done_callback(finish)
249        return future
250
251    def request_kernel_info(self):
252        """send a request for kernel_info"""
253        km = self.kernel_manager
254        kernel = km.get_kernel(self.kernel_id)
255        try:
256            # check for previous request
257            future = kernel._kernel_info_future
258        except AttributeError:
259            self.log.debug("Requesting kernel info from %s", self.kernel_id)
260            # Create a kernel_info channel to query the kernel protocol version.
261            # This channel will be closed after the kernel_info reply is received.
262            if self.kernel_info_channel is None:
263                self.kernel_info_channel = km.connect_shell(self.kernel_id)
264            self.kernel_info_channel.on_recv(self._handle_kernel_info_reply)
265            self.session.send(self.kernel_info_channel, "kernel_info_request")
266            # store the future on the kernel, so only one request is sent
267            kernel._kernel_info_future = self._kernel_info_future
268        else:
269            if not future.done():
270                self.log.debug("Waiting for pending kernel_info request")
271            future.add_done_callback(lambda f: self._finish_kernel_info(f.result()))
272        return self._kernel_info_future
273
274    def _handle_kernel_info_reply(self, msg):
275        """process the kernel_info_reply
276
277        enabling msg spec adaptation, if necessary
278        """
279        idents,msg = self.session.feed_identities(msg)
280        try:
281            msg = self.session.deserialize(msg)
282        except:
283            self.log.error("Bad kernel_info reply", exc_info=True)
284            self._kernel_info_future.set_result({})
285            return
286        else:
287            info = msg['content']
288            self.log.debug("Received kernel info: %s", info)
289            if msg['msg_type'] != 'kernel_info_reply' or 'protocol_version' not in info:
290                self.log.error("Kernel info request failed, assuming current %s", info)
291                info = {}
292            self._finish_kernel_info(info)
293
294        # close the kernel_info channel, we don't need it anymore
295        if self.kernel_info_channel:
296            self.kernel_info_channel.close()
297        self.kernel_info_channel = None
298
299    def _finish_kernel_info(self, info):
300        """Finish handling kernel_info reply
301
302        Set up protocol adaptation, if needed,
303        and signal that connection can continue.
304        """
305        protocol_version = info.get('protocol_version', client_protocol_version)
306        if protocol_version != client_protocol_version:
307            self.session.adapt_version = int(protocol_version.split('.')[0])
308            self.log.info("Adapting from protocol version {protocol_version} (kernel {kernel_id}) to {client_protocol_version} (client).".format(protocol_version=protocol_version, kernel_id=self.kernel_id, client_protocol_version=client_protocol_version))
309        if not self._kernel_info_future.done():
310            self._kernel_info_future.set_result(info)
311
312    def initialize(self):
313        super().initialize()
314        self.zmq_stream = None
315        self.channels = {}
316        self.kernel_id = None
317        self.kernel_info_channel = None
318        self._kernel_info_future = Future()
319        self._close_future = Future()
320        self.session_key = ''
321
322        # Rate limiting code
323        self._iopub_window_msg_count = 0
324        self._iopub_window_byte_count = 0
325        self._iopub_msgs_exceeded = False
326        self._iopub_data_exceeded = False
327        # Queue of (time stamp, byte count)
328        # Allows you to specify that the byte count should be lowered
329        # by a delta amount at some point in the future.
330        self._iopub_window_byte_queue = []
331
332    @gen.coroutine
333    def pre_get(self):
334        # authenticate first
335        super().pre_get()
336        # check session collision:
337        yield self._register_session()
338        # then request kernel info, waiting up to a certain time before giving up.
339        # We don't want to wait forever, because browsers don't take it well when
340        # servers never respond to websocket connection requests.
341        kernel = self.kernel_manager.get_kernel(self.kernel_id)
342        self.session.key = kernel.session.key
343        future = self.request_kernel_info()
344
345        def give_up():
346            """Don't wait forever for the kernel to reply"""
347            if future.done():
348                return
349            self.log.warning("Timeout waiting for kernel_info reply from %s", self.kernel_id)
350            future.set_result({})
351        loop = IOLoop.current()
352        loop.add_timeout(loop.time() + self.kernel_info_timeout, give_up)
353        # actually wait for it
354        yield future
355
356    @gen.coroutine
357    def get(self, kernel_id):
358        self.kernel_id = cast_unicode(kernel_id, 'ascii')
359        yield super().get(kernel_id=kernel_id)
360
361    @gen.coroutine
362    def _register_session(self):
363        """Ensure we aren't creating a duplicate session.
364
365        If a previous identical session is still open, close it to avoid collisions.
366        This is likely due to a client reconnecting from a lost network connection,
367        where the socket on our side has not been cleaned up yet.
368        """
369        self.session_key = '%s:%s' % (self.kernel_id, self.session.session)
370        stale_handler = self._open_sessions.get(self.session_key)
371        if stale_handler:
372            self.log.warning("Replacing stale connection: %s", self.session_key)
373            yield stale_handler.close()
374        self._open_sessions[self.session_key] = self
375
376    def open(self, kernel_id):
377        super().open()
378        km = self.kernel_manager
379        km.notify_connect(kernel_id)
380
381        # on new connections, flush the message buffer
382        buffer_info = km.get_buffer(kernel_id, self.session_key)
383        if buffer_info and buffer_info['session_key'] == self.session_key:
384            self.log.info("Restoring connection for %s", self.session_key)
385            self.channels = buffer_info['channels']
386            connected = self.nudge()
387
388            def replay(value):
389                replay_buffer = buffer_info['buffer']
390                if replay_buffer:
391                    self.log.info("Replaying %s buffered messages", len(replay_buffer))
392                    for channel, msg_list in replay_buffer:
393                        stream = self.channels[channel]
394                        self._on_zmq_reply(stream, msg_list)
395
396            connected.add_done_callback(replay)
397        else:
398            try:
399                self.create_stream()
400                connected = self.nudge()
401            except web.HTTPError as e:
402                self.log.error("Error opening stream: %s", e)
403                # WebSockets don't response to traditional error codes so we
404                # close the connection.
405                for channel, stream in self.channels.items():
406                    if not stream.closed():
407                        stream.close()
408                self.close()
409                return
410
411        km.add_restart_callback(self.kernel_id, self.on_kernel_restarted)
412        km.add_restart_callback(self.kernel_id, self.on_restart_failed, 'dead')
413
414        def subscribe(value):
415            for channel, stream in self.channels.items():
416                stream.on_recv_stream(self._on_zmq_reply)
417
418        connected.add_done_callback(subscribe)
419
420        return connected
421
422    def on_message(self, msg):
423        if not self.channels:
424            # already closed, ignore the message
425            self.log.debug("Received message on closed websocket %r", msg)
426            return
427        if isinstance(msg, bytes):
428            msg = deserialize_binary_message(msg)
429        else:
430            msg = json.loads(msg)
431        channel = msg.pop('channel', None)
432        if channel is None:
433            self.log.warning("No channel specified, assuming shell: %s", msg)
434            channel = 'shell'
435        if channel not in self.channels:
436            self.log.warning("No such channel: %r", channel)
437            return
438        am = self.kernel_manager.allowed_message_types
439        mt = msg['header']['msg_type']
440        if am and mt not in am:
441            self.log.warning('Received message of type "%s", which is not allowed. Ignoring.' % mt)
442        else:
443            stream = self.channels[channel]
444            self.session.send(stream, msg)
445
    def _on_zmq_reply(self, stream, msg_list):
        """Relay a kernel message to the parent handler, rate-limiting IOPub output.

        ``msg_list`` is the raw multipart message received on ``stream``.
        It is deserialized here so its type and content can be inspected.

        IOPub messages other than ``status``, ``comm_open`` and
        ``execute_input`` are counted over a window of
        ``rate_limit_window``.  Only ``stream`` messages add to the byte
        count.  When a configured limit (> 0) is exceeded, a single stderr
        notice is written to the client and further counted messages are
        dropped until the rate falls below 80% of that limit.  An ``idle``
        status message resets all counters and flags.
        """
        idents, fed_msg_list = self.session.feed_identities(msg_list)
        msg = self.session.deserialize(fed_msg_list)
        parent = msg['parent_header']
        def write_stderr(error_message):
            # Log the notice and send it to the client as a stderr stream
            # message that shares the parent header of the current message.
            self.log.warning(error_message)
            msg = self.session.msg("stream",
                content={"text": error_message + '\n', "name": "stderr"},
                parent=parent
            )
            msg['channel'] = 'iopub'
            self.write_message(json.dumps(msg, default=json_default))
        channel = getattr(stream, 'channel', None)
        msg_type = msg['header']['msg_type']

        if channel == 'iopub' and msg_type == 'status' and msg['content'].get('execution_state') == 'idle':
            # reset rate limit counter on status=idle,
            # to avoid 'Run All' hitting limits prematurely.
            self._iopub_window_byte_queue = []
            self._iopub_window_msg_count = 0
            self._iopub_window_byte_count = 0
            self._iopub_msgs_exceeded = False
            self._iopub_data_exceeded = False

        if channel == 'iopub' and msg_type not in {'status', 'comm_open', 'execute_input'}:

            # Remove the counts queued for removal.
            now = IOLoop.current().time()
            while len(self._iopub_window_byte_queue) > 0:
                queued = self._iopub_window_byte_queue[0]
                if (now >= queued[0]):
                    self._iopub_window_byte_count -= queued[1]
                    self._iopub_window_msg_count -= 1
                    del self._iopub_window_byte_queue[0]
                else:
                    # This part of the queue hasn't been reached yet, so we
                    # can abort the loop.
                    break

            # Increment the bytes and message count
            self._iopub_window_msg_count += 1
            if msg_type == 'stream':
                byte_count = sum([len(x) for x in msg_list])
            else:
                byte_count = 0
            self._iopub_window_byte_count += byte_count

            # Queue a removal of the byte and message count for a time in the
            # future, when we are no longer interested in it.
            self._iopub_window_byte_queue.append((now + self.rate_limit_window, byte_count))

            # Compute the current message and data rates over the window;
            # the limit flags are updated from them below.
            msg_rate = float(self._iopub_window_msg_count) / self.rate_limit_window
            data_rate = float(self._iopub_window_byte_count) / self.rate_limit_window

            # Check the msg rate
            if self.iopub_msg_rate_limit > 0 and msg_rate > self.iopub_msg_rate_limit:
                # Only notify the client on the transition into the exceeded state.
                if not self._iopub_msgs_exceeded:
                    self._iopub_msgs_exceeded = True
                    write_stderr(dedent("""\
                    IOPub message rate exceeded.
                    The notebook server will temporarily stop sending output
                    to the client in order to avoid crashing it.
                    To change this limit, set the config variable
                    `--NotebookApp.iopub_msg_rate_limit`.

                    Current values:
                    NotebookApp.iopub_msg_rate_limit={} (msgs/sec)
                    NotebookApp.rate_limit_window={} (secs)
                    """.format(self.iopub_msg_rate_limit, self.rate_limit_window)))
            else:
                # resume once we've got some headroom below the limit
                if self._iopub_msgs_exceeded and msg_rate < (0.8 * self.iopub_msg_rate_limit):
                    self._iopub_msgs_exceeded = False
                    if not self._iopub_data_exceeded:
                        self.log.warning("iopub messages resumed")

            # Check the data rate
            if self.iopub_data_rate_limit > 0 and data_rate > self.iopub_data_rate_limit:
                # Only notify the client on the transition into the exceeded state.
                if not self._iopub_data_exceeded:
                    self._iopub_data_exceeded = True
                    write_stderr(dedent("""\
                    IOPub data rate exceeded.
                    The notebook server will temporarily stop sending output
                    to the client in order to avoid crashing it.
                    To change this limit, set the config variable
                    `--NotebookApp.iopub_data_rate_limit`.

                    Current values:
                    NotebookApp.iopub_data_rate_limit={} (bytes/sec)
                    NotebookApp.rate_limit_window={} (secs)
                    """.format(self.iopub_data_rate_limit, self.rate_limit_window)))
            else:
                # resume once we've got some headroom below the limit
                if self._iopub_data_exceeded and data_rate < (0.8 * self.iopub_data_rate_limit):
                    self._iopub_data_exceeded = False
                    if not self._iopub_msgs_exceeded:
                        self.log.warning("iopub messages resumed")

            # If either of the limit flags are set, do not send the message.
            if self._iopub_msgs_exceeded or self._iopub_data_exceeded:
                # we didn't send it, remove the current message from the calculus
                self._iopub_window_msg_count -= 1
                self._iopub_window_byte_count -= byte_count
                self._iopub_window_byte_queue.pop(-1)
                return
        # Note: the already-deserialized message, not the raw frames, is
        # handed to the parent implementation.
        super()._on_zmq_reply(stream, msg)
554
555    def close(self):
556        super().close()
557        return self._close_future
558
559    def on_close(self):
560        self.log.debug("Websocket closed %s", self.session_key)
561        # unregister myself as an open session (only if it's really me)
562        if self._open_sessions.get(self.session_key) is self:
563            self._open_sessions.pop(self.session_key)
564
565        km = self.kernel_manager
566        if self.kernel_id in km:
567            km.notify_disconnect(self.kernel_id)
568            km.remove_restart_callback(
569                self.kernel_id, self.on_kernel_restarted,
570            )
571            km.remove_restart_callback(
572                self.kernel_id, self.on_restart_failed, 'dead',
573            )
574
575            # start buffering instead of closing if this was the last connection
576            if km._kernel_connections[self.kernel_id] == 0:
577                km.start_buffering(self.kernel_id, self.session_key, self.channels)
578                self._close_future.set_result(None)
579                return
580
581        # This method can be called twice, once by self.kernel_died and once
582        # from the WebSocket close event. If the WebSocket connection is
583        # closed before the ZMQ streams are setup, they could be None.
584        for channel, stream in self.channels.items():
585            if stream is not None and not stream.closed():
586                stream.on_recv(None)
587                stream.close()
588
589        self.channels = {}
590        self._close_future.set_result(None)
591
592    def _send_status_message(self, status):
593        iopub = self.channels.get('iopub', None)
594        if iopub and not iopub.closed():
595            # flush IOPub before sending a restarting/dead status message
596            # ensures proper ordering on the IOPub channel
597            # that all messages from the stopped kernel have been delivered
598            iopub.flush()
599        msg = self.session.msg("status",
600            {'execution_state': status}
601        )
602        msg['channel'] = 'iopub'
603        self.write_message(json.dumps(msg, default=json_default))
604
605    def on_kernel_restarted(self):
606        logging.warn("kernel %s restarted", self.kernel_id)
607        self._send_status_message('restarting')
608
609    def on_restart_failed(self):
610        logging.error("kernel %s restarted failed!", self.kernel_id)
611        self._send_status_message('dead')
612
613
614#-----------------------------------------------------------------------------
615# URL to handler mappings
616#-----------------------------------------------------------------------------
617
618
# Path component capturing ``kernel_id``: five runs of word characters
# joined by hyphens.
_kernel_id_regex = r"(?P<kernel_id>\w+-\w+-\w+-\w+-\w+)"
# Path component capturing ``action``; only these two actions are routed.
_kernel_action_regex = r"(?P<action>restart|interrupt)"

# (URL pattern, handler class) pairs for the kernels API.
default_handlers = [
    (r"/api/kernels", MainKernelHandler),
    (r"/api/kernels/%s" % _kernel_id_regex, KernelHandler),
    (r"/api/kernels/%s/%s" % (_kernel_id_regex, _kernel_action_regex), KernelActionHandler),
    (r"/api/kernels/%s/channels" % _kernel_id_regex, ZMQChannelsHandler),
]
628