1import struct
2from collections import deque
3from logging import getLogger
4
5import asyncio
6import priority
7from h2 import events
8from h2 import settings
9from h2.config import H2Configuration
10from h2.connection import H2Connection
11from h2.exceptions import NoSuchStreamError, StreamClosedError, ProtocolError
12
13from . import exceptions
14
15__all__ = ['H2Protocol']
16logger = getLogger(__package__)
17
18
19@asyncio.coroutine
20def _wait_for_events(*events_):
21    while not all([event.is_set() for event in events_]):
22        yield from asyncio.wait([event.wait() for event in events_])
23
24
25class _StreamEndedException(Exception):
26    def __init__(self, bufs=None):
27        if bufs is None:
28            bufs = []
29        self.bufs = bufs
30
31
class CallableEvent(asyncio.Event):
    """An event whose set/clear state mirrors a zero-argument predicate.

    ``wait`` blocks until the predicate evaluates truthy; ``sync`` must be
    called whenever something that could change the predicate's outcome
    happens, so waiters get re-woken and re-check it.
    """

    def __init__(self, func, *, loop=None):
        super().__init__(loop=loop)
        self._func = func

    @asyncio.coroutine
    def wait(self):
        """Block until the predicate returns a truthy value."""
        while True:
            if self._func():
                return
            self.clear()
            yield from super().wait()

    def sync(self):
        """Re-evaluate the predicate and update the flag to match it."""
        update = self.set if self._func() else self.clear
        update()

    def is_set(self):
        # Always reflect the predicate's current value, not a stale flag.
        self.sync()
        return super().is_set()
52
53
class H2Stream:
    """
    Per-stream bookkeeping for one HTTP/2 stream.

    Holds the inbound data buffer, the response/trailer futures, and the
    synchronization primitives (read lock, write lock, flow-control
    window event) used by `H2Protocol`.

    `window_getter` is a callable taking a stream ID and returning the
    local flow-control window for that stream; it may raise
    `NoSuchStreamError` once the stream no longer exists.
    """

    def __init__(self, stream_id, window_getter, loop=None):
        if loop is None:
            loop = asyncio.get_event_loop()
        self._stream_id = stream_id
        self._window_getter = window_getter

        # Write side: serialize concurrent senders, and gate sending on an
        # open (non-zero) flow-control window.
        self._wlock = asyncio.Lock(loop=loop)
        self._window_open = CallableEvent(self._is_window_open, loop=loop)

        # Read side: serialize concurrent readers over the frame buffer.
        self._rlock = asyncio.Lock(loop=loop)
        self._buffers = deque()          # inbound data chunks, FIFO
        self._buffer_size = 0            # total bytes currently buffered
        self._buffer_ready = asyncio.Event(loop=loop)
        self._response = asyncio.Future(loop=loop)
        self._trailers = asyncio.Future(loop=loop)
        self._eof_received = False
        self._closed = False

    @property
    def id(self):
        """The HTTP/2 stream ID this object tracks."""
        return self._stream_id

    @property
    def window_open(self):
        """Event that is set while the stream's send window is open."""
        return self._window_open

    @property
    def rlock(self):
        """Per-stream read lock."""
        return self._rlock

    @property
    def wlock(self):
        """Per-stream write lock."""
        return self._wlock

    @property
    def buffer_size(self):
        """Number of bytes currently buffered for reading."""
        return self._buffer_size

    @property
    def response(self):
        """Future resolved with the response headers."""
        return self._response

    @property
    def trailers(self):
        """Future resolved with the trailers (or `{}` on plain EOF)."""
        return self._trailers

    def _is_window_open(self):
        """Predicate for `window_open`.

        Reports True when the stream has vanished, so blocked writers
        wake up and fail fast instead of hanging forever.
        """
        try:
            window = self._window_getter(self._stream_id)
        except NoSuchStreamError:
            self._closed = True
            return True
        else:
            return window > 0

    def feed_data(self, data):
        """Append an inbound data chunk and wake pending readers."""
        if data:
            self._buffers.append(data)
            self._buffer_size += len(data)
            self._buffer_ready.set()

    def feed_eof(self):
        """Mark remote end-of-stream; wakes readers and settles trailers."""
        self._eof_received = True
        self._buffer_ready.set()
        # Ensure recv_trailers() never hangs when no trailers were sent.
        self.feed_trailers({})

    def feed_response(self, headers):
        """Resolve the response future with the received headers."""
        self._response.set_result(headers)

    def feed_trailers(self, headers):
        """Resolve the trailers future once; later calls are ignored."""
        if not self._trailers.done():
            self._trailers.set_result(headers)

    @asyncio.coroutine
    def read_frame(self):
        """Return the next buffered frame as bytes.

        Blocks until data or EOF arrives.  If the buffer is drained and
        EOF was already received, raises `_StreamEndedException` carrying
        the final chunk so the caller can still flush it.
        """
        yield from self._buffer_ready.wait()
        rv = b''
        if self._buffers:
            rv = self._buffers.popleft()
            self._buffer_size -= len(rv)
        if not self._buffers:
            if self._eof_received:
                raise _StreamEndedException([rv])
            else:
                self._buffer_ready.clear()
        return rv

    @asyncio.coroutine
    def read_all(self):
        """Return (or raise with) every chunk currently buffered.

        Raises `_StreamEndedException` carrying the chunks when EOF was
        received, which is how callers detect the stream ended.
        """
        yield from self._buffer_ready.wait()
        rv = []
        rv.extend(self._buffers)
        self._buffers.clear()
        self._buffer_size = 0
        if self._eof_received:
            raise _StreamEndedException(rv)
        else:
            self._buffer_ready.clear()
            return rv

    @asyncio.coroutine
    def read(self, n):
        """Read up to `n` bytes; returns `(chunks, byte_count)`.

        A chunk that overshoots `n` is split: the overshoot is pushed
        back onto the buffer for the next read.  Raises
        `_StreamEndedException` with the chunks read so far if the buffer
        drains after EOF.
        """
        yield from self._buffer_ready.wait()
        rv = []
        count = 0
        while n > count and self._buffers:
            buf = self._buffers.popleft()
            count += len(buf)
            if n < count:
                # n - count is negative: keep all but the overshoot,
                # push the overshooting tail back for the next reader.
                rv.append(buf[:n - count])
                self._buffers.appendleft(buf[n - count:])
                count = n
            else:
                rv.append(buf)
        self._buffer_size -= count
        if not self._buffers:
            if self._eof_received:
                raise _StreamEndedException(rv)
            else:
                self._buffer_ready.clear()
        return rv, count
176
177
class H2Protocol(asyncio.Protocol):
    """
    HTTP/2 protocol implemented on asyncio transports over the hyper-h2
    state machine.

    Works as either the client or the server side depending on
    `client_side`.  Offers coroutine APIs for starting requests and
    responses, sending/receiving stream data with flow control and
    priority-aware scheduling, and probing connection liveness with
    PING frames (`wait_functional`).
    """

    def __init__(self, client_side: bool, *, loop=None,
                 concurrency=8, functional_timeout=2):
        if loop is None:
            loop = asyncio.get_event_loop()
        self._loop = loop
        config = H2Configuration(client_side=client_side,
                                 header_encoding='utf-8')
        self._conn = H2Connection(config=config)
        self._transport = None
        self._streams = {}
        # Bounded queue of (priority, stream_id, headers); its maxsize is
        # also advertised to the peer as our MAX_CONCURRENT_STREAMS.
        self._inbound_requests = asyncio.Queue(concurrency, loop=loop)
        self._priority = priority.PriorityTree()
        # stream_id -> Future, resolved when that stream wins a priority
        # step and may send its next data chunk.
        self._priority_events = {}
        self._handler = None

        # Locks

        self._is_resumed = False
        self._resumed = CallableEvent(lambda: self._is_resumed, loop=loop)
        self._stream_creatable = CallableEvent(self._is_stream_creatable,
                                               loop=loop)
        self._last_active = 0
        self._ping_index = -1
        self._ping_time = 0
        self._rtt = None
        self._functional_timeout = functional_timeout
        self._functional = CallableEvent(self._is_functional, loop=loop)

        # Dispatch table

        self._event_handlers = {
            events.RequestReceived: self._request_received,
            events.ResponseReceived: self._response_received,
            events.TrailersReceived: self._trailers_received,
            events.DataReceived: self._data_received,
            events.WindowUpdated: self._window_updated,
            events.RemoteSettingsChanged: self._remote_settings_changed,
            events.PingAcknowledged: self._ping_acknowledged,
            events.StreamEnded: self._stream_ended,
            events.StreamReset: self._stream_reset,
            events.PushedStreamReceived: self._pushed_stream_received,
            events.SettingsAcknowledged: self._settings_acknowledged,
            events.PriorityUpdated: self._priority_updated,
            events.ConnectionTerminated: self._connection_terminated,
        }

    # asyncio protocol

    def connection_made(self, transport):
        """Send the HTTP/2 preamble and our settings; mark us writable."""
        self._transport = transport
        self._conn.initiate_connection()
        self._conn.update_settings({
            settings.SettingCodes.MAX_CONCURRENT_STREAMS: self._inbound_requests.maxsize})
        self._flush()
        self._stream_creatable.sync()
        self.resume_writing()
        self._last_active = self._loop.time()
        self._functional.sync()

    def connection_lost(self, exc):
        """Drop connection state and cancel the attached handler task."""
        self._conn = None
        self._transport = None
        self.pause_writing()
        if self._handler:
            self._handler.cancel()

    def pause_writing(self):
        # Transport buffer is full: block coroutines waiting on _resumed.
        self._is_resumed = False
        self._resumed.sync()

    def resume_writing(self):
        # Transport drained: unblock senders.
        self._is_resumed = True
        self._resumed.sync()

    def data_received(self, data):
        """Feed inbound bytes to h2 and dispatch the resulting events."""
        self._last_active = self._loop.time()
        self._functional.sync()
        events_ = self._conn.receive_data(data)
        self._flush()
        for event in events_:
            self._event_received(event)

    def eof_received(self):
        # Peer half-closed; send our GOAWAY and let the transport close.
        self._conn.close_connection()
        self._flush()

    # hyper-h2 event handlers

    def _event_received(self, event):
        self._event_handlers[type(event)](event)

    def _request_received(self, event: events.RequestReceived):
        self._inbound_requests.put_nowait((0, event.stream_id, event.headers))
        # Register the new stream in the priority tree, blocked until it
        # actually has data to send.
        self._priority.insert_stream(event.stream_id)
        self._priority.block(event.stream_id)

    def _response_received(self, event: events.ResponseReceived):
        self._get_stream(event.stream_id).feed_response(event.headers)

    def _trailers_received(self, event: events.TrailersReceived):
        self._get_stream(event.stream_id).feed_trailers(event.headers)

    def _data_received(self, event: events.DataReceived):
        self._get_stream(event.stream_id).feed_data(event.data)
        # Keep the connection-level inbound window pinned near the
        # protocol maximum (2**31 - 1); per-stream windows are managed
        # separately in _flow_control().
        if self._conn.inbound_flow_control_window < 1073741823:
            self._conn.increment_flow_control_window(
                2 ** 31 - 1 - self._conn.inbound_flow_control_window)
            self._flush()

    def _window_updated(self, event: events.WindowUpdated):
        if event.stream_id:
            self._get_stream(event.stream_id).window_open.sync()
        else:
            # Connection-level update affects every stream's send window.
            for stream in list(self._streams.values()):
                stream.window_open.sync()

    def _remote_settings_changed(self, event: events.RemoteSettingsChanged):
        if settings.SettingCodes.INITIAL_WINDOW_SIZE in event.changed_settings:
            for stream in list(self._streams.values()):
                stream.window_open.sync()
        if settings.SettingCodes.MAX_CONCURRENT_STREAMS in event.changed_settings:
            self._stream_creatable.sync()

    def _ping_acknowledged(self, event: events.PingAcknowledged):
        # Only measure RTT for the PING we sent most recently.
        if struct.unpack('Q', event.ping_data) == (self._ping_index,):
            self._rtt = self._loop.time() - self._ping_time

    def _stream_ended(self, event: events.StreamEnded):
        self._get_stream(event.stream_id).feed_eof()
        self._stream_creatable.sync()

    def _stream_reset(self, event: events.StreamReset):
        # Force-open the window event so blocked senders wake and fail.
        self._get_stream(event.stream_id).window_open.set()
        self._stream_creatable.sync()

    def _pushed_stream_received(self, event: events.PushedStreamReceived):
        # Server push is not supported.
        pass

    def _settings_acknowledged(self, event: events.SettingsAcknowledged):
        pass

    def _priority_updated(self, event: events.PriorityUpdated):
        self._priority.reprioritize(
            event.stream_id, event.depends_on, event.weight, event.exclusive)

    def _connection_terminated(self, event: events.ConnectionTerminated):
        logger.warning('Remote peer sent GOAWAY [ERR: %s], disconnect now.',
                       event.error_code)
        self._transport.close()

    # Internals

    def _get_stream(self, stream_id):
        """Return the H2Stream for `stream_id`, creating it lazily."""
        stream = self._streams.get(stream_id)
        if stream is None:
            stream = self._streams[stream_id] = H2Stream(
                stream_id, self._conn.local_flow_control_window,
                loop=self._loop)
        return stream

    def _flush(self):
        """Write any bytes h2 has queued to the transport."""
        self._transport.write(self._conn.data_to_send())

    def _is_stream_creatable(self):
        # Predicate for _stream_creatable: respect the peer's limit.
        return (self._conn.open_outbound_streams <
                self._conn.remote_settings.max_concurrent_streams)

    def _flow_control(self, stream_id):
        """Replenish the stream's inbound window as its buffer drains."""
        delta = (self._conn.local_settings.initial_window_size -
                 self._get_stream(stream_id).buffer_size -
                 self._conn.remote_flow_control_window(stream_id))
        if delta > 0:
            self._conn.increment_flow_control_window(delta, stream_id)
            self._flush()

    def _is_functional(self):
        # Functional == some activity within the last functional_timeout.
        return self._last_active + self._functional_timeout > self._loop.time()

    def _priority_step(self):
        """Wake the highest-priority stream that has a sender waiting."""
        # noinspection PyBroadException
        try:
            for stream_id in self._priority:
                fut = self._priority_events.pop(stream_id, None)
                if fut is not None:
                    fut.set_result(None)
                    break
        except Exception:
            # Priority tree in an odd state: best-effort wake of any
            # waiter rather than deadlocking all senders.
            if self._priority_events:
                self._priority_events.popitem()[1].set_result(None)

    # APIs

    def set_handler(self, handler):
        """
        Connect with a coroutine, which is scheduled when connection is made.

        This function will create a task, and when connection is closed,
        the task will be canceled.
        :param handler:
        :return: None
        """
        if self._handler:
            raise Exception('Handler was already set')
        if handler:
            # NOTE: `asyncio.async` is a SyntaxError on Python 3.7+
            # (`async` became a keyword); ensure_future is the
            # backward-compatible replacement (available since 3.4.4).
            self._handler = asyncio.ensure_future(handler, loop=self._loop)

    def close_connection(self):
        """Close the underlying transport, terminating the connection."""
        self._transport.close()

    @asyncio.coroutine
    def start_request(self, headers, *, end_stream=False):
        """
        Start a request by sending given headers on a new stream, and return
        the ID of the new stream.

        This may block until the underlying transport becomes writable, and
        the number of concurrent outbound requests (open outbound streams) is
        less than the value of peer config MAX_CONCURRENT_STREAMS.

        The completion of the call to this method does not mean the request is
        successfully delivered - data is only correctly stored in a buffer to
        be sent. There's no guarantee it is truly delivered.

        :param headers: A list of key-value tuples as headers.
        :param end_stream: To send a request without body, set `end_stream` to
                           `True` (default `False`).
        :return: Stream ID as a integer, used for further communication.
        """
        yield from _wait_for_events(self._resumed, self._stream_creatable)
        stream_id = self._conn.get_next_available_stream_id()
        self._priority.insert_stream(stream_id)
        self._priority.block(stream_id)
        self._conn.send_headers(stream_id, headers, end_stream=end_stream)
        self._flush()
        return stream_id

    @asyncio.coroutine
    def start_response(self, stream_id, headers, *, end_stream=False):
        """
        Start a response by sending given headers on the given stream.

        This may block until the underlying transport becomes writable.

        :param stream_id: Which stream to send response on.
        :param headers: A list of key-value tuples as headers.
        :param end_stream: To send a response without body, set `end_stream` to
                           `True` (default `False`).
        """
        yield from self._resumed.wait()
        self._conn.send_headers(stream_id, headers, end_stream=end_stream)
        self._flush()

    @asyncio.coroutine
    def send_data(self, stream_id, data, *, end_stream=False):
        """
        Send request or response body on the given stream.

        This will block until either whole data is sent, or the stream gets
        closed. Meanwhile, a paused underlying transport or a closed flow
        control window will also help waiting. If the peer increase the flow
        control window, this method will start sending automatically.

        This can be called multiple times, but it must be called after a
        `start_request` or `start_response` with the returning stream ID, and
        before any `end_stream` instructions; Otherwise it will fail.

        The given data may be automatically split into smaller frames in order
        to fit in the configured frame size or flow control window.

        Each stream can only have one `send_data` running, others calling this
        will be blocked on a per-stream lock (wlock), so that coroutines
        sending data concurrently won't mess up with each other.

        Similarly, the completion of the call to this method does not mean the
        data is delivered.

        :param stream_id: Which stream to send data on
        :param data: Bytes to send
        :param end_stream: To finish sending a request or response, set this to
                           `True` to close the given stream locally after data
                           is sent (default `False`).
        :raise: `SendException` if there is an error sending data. Data left
                unsent can be found in `data` of the exception.
        """
        try:
            with (yield from self._get_stream(stream_id).wlock):
                while True:
                    yield from _wait_for_events(
                        self._resumed, self._get_stream(stream_id).window_open)
                    self._priority.unblock(stream_id)
                    # Bind the waiter to our loop explicitly, consistent
                    # with every other primitive in this class (the bare
                    # Future() would pick up the default loop instead).
                    waiter = asyncio.Future(loop=self._loop)
                    if not self._priority_events:
                        self._loop.call_soon(self._priority_step)
                    self._priority_events[stream_id] = waiter
                    try:
                        # Wait until this stream wins a priority step.
                        yield from waiter
                        data_size = len(data)
                        size = min(
                            data_size,
                            self._conn.local_flow_control_window(stream_id),
                            self._conn.max_outbound_frame_size)
                        if data_size == 0 or size == data_size:
                            self._conn.send_data(stream_id, data,
                                                 end_stream=end_stream)
                            self._flush()
                            break
                        elif size > 0:
                            # Partial send; loop to wait for more window.
                            self._conn.send_data(stream_id, data[:size])
                            data = data[size:]
                            self._flush()
                    finally:
                        self._priority_events.pop(stream_id, None)
                        self._priority.block(stream_id)
                        if self._priority_events:
                            self._loop.call_soon(self._priority_step)
        except ProtocolError:
            raise exceptions.SendException(data)

    @asyncio.coroutine
    def send_trailers(self, stream_id, headers):
        """
        Send trailers on the given stream, closing the stream locally.

        This may block until the underlying transport becomes writable, or
        other coroutines release the wlock on this stream.

        :param stream_id: Which stream to send trailers on.
        :param headers: A list of key-value tuples as trailers.
        """
        with (yield from self._get_stream(stream_id).wlock):
            yield from self._resumed.wait()
            self._conn.send_headers(stream_id, headers, end_stream=True)
            self._flush()

    @asyncio.coroutine
    def end_stream(self, stream_id):
        """
        Close the given stream locally.

        This may block until the underlying transport becomes writable, or
        other coroutines release the wlock on this stream.

        :param stream_id: Which stream to close.
        """
        with (yield from self._get_stream(stream_id).wlock):
            yield from self._resumed.wait()
            self._conn.end_stream(stream_id)
            self._flush()

    @asyncio.coroutine
    def recv_request(self):
        """
        Retrieve next inbound request in queue.

        This will block until a request is available.

        :return: A tuple `(stream_id, headers)`.
        """
        rv = yield from self._inbound_requests.get()
        # Drop the leading priority slot from the queued tuple.
        return rv[1:]

    @asyncio.coroutine
    def recv_response(self, stream_id):
        """
        Wait until a response is ready on the given stream.

        :param stream_id: Stream to wait on.
        :return: A list of key-value tuples as response headers.
        """
        return (yield from self._get_stream(stream_id).response)

    @asyncio.coroutine
    def recv_trailers(self, stream_id):
        """
        Wait until trailers are ready on the given stream.

        :param stream_id: Stream to wait on.
        :return: A list of key-value tuples as trailers.
        """
        return (yield from self._get_stream(stream_id).trailers)

    @asyncio.coroutine
    def read_stream(self, stream_id, size=None):
        """
        Read data from the given stream.

        By default (`size=None`), this returns all data left in current HTTP/2
        frame. In other words, default behavior is to receive frame by frame.

        If size is given a number above zero, method will try to return as much
        bytes as possible up to the given size, block until enough bytes are
        ready or stream is remotely closed.

        If below zero, it will read until the stream is remotely closed and
        return everything at hand.

        `size=0` is a special case that does nothing but returns `b''`. The
        same result `b''` is also returned under other conditions if there is
        no more data on the stream to receive, even under `size=None` and peer
        sends an empty frame - you can use `b''` to safely identify the end of
        the given stream.

        Flow control frames will be automatically sent while reading clears the
        buffer, allowing more data to come in.

        :param stream_id: Stream to read
        :param size: Expected size to read, `-1` for all, default frame.
        :return: Bytes read or empty if there is no more to expect.
        """
        rv = []
        try:
            with (yield from self._get_stream(stream_id).rlock):
                if size is None:
                    rv.append((
                        yield from self._get_stream(stream_id).read_frame()))
                    self._flow_control(stream_id)
                elif size < 0:
                    # Loop exits via _StreamEndedException on EOF.
                    while True:
                        rv.extend((
                            yield from self._get_stream(stream_id).read_all()))
                        self._flow_control(stream_id)
                else:
                    while size > 0:
                        bufs, count = yield from self._get_stream(
                            stream_id).read(size)
                        rv.extend(bufs)
                        size -= count
                        self._flow_control(stream_id)
        except StreamClosedError:
            pass
        except _StreamEndedException as e:
            try:
                self._flow_control(stream_id)
            except StreamClosedError:
                pass
            # The exception carries the final chunks read before EOF.
            rv.extend(e.bufs)
        return b''.join(rv)

    def update_settings(self, new_settings):
        """Send a SETTINGS frame with the given changes to the peer."""
        self._conn.update_settings(new_settings)
        self._flush()

    @asyncio.coroutine
    def wait_functional(self):
        """
        Wait until the connection becomes functional.

        The connection is count functional if it was active within last few
        seconds (defined by `functional_timeout`), where a newly-made
        connection and received data indicate activeness.

        :return: Most recently calculated round-trip time if any.
        """
        while not self._is_functional():
            self._rtt = None
            self._ping_index += 1
            self._ping_time = self._loop.time()
            self._conn.ping(struct.pack('Q', self._ping_index))
            self._flush()
            try:
                yield from asyncio.wait_for(self._functional.wait(),
                                            self._functional_timeout)
            except asyncio.TimeoutError:
                pass
        return self._rtt

    def reprioritize(self, stream_id,
                     depends_on=None, weight=16, exclusive=False):
        """
        Update the priority status of an existing stream.

        :param stream_id: The stream ID of the stream being updated.
        :param depends_on: (optional) The ID of the stream that the stream now
            depends on. If ``None``, will be moved to depend on stream 0.
        :param weight: (optional) The new weight to give the stream. Defaults
            to 16.
        :param exclusive: (optional) Whether this stream should now be an
            exclusive dependency of the new parent.
        """
        self._priority.reprioritize(stream_id, depends_on, weight, exclusive)

    @property
    def functional_timeout(self):
        """
        A timeout value in seconds used by `wait_functional`, beyond which will
        self send a PING frame since last activeness and block the call to
        `wait_functional` until acknowledgement is received.

        Setting this to a larger value may cause pending `wait_functional`
        calls to unblock immediately.
        """
        return self._functional_timeout

    @functional_timeout.setter
    def functional_timeout(self, val):
        self._functional_timeout = val
        self._functional.sync()

    @property
    def initial_window_size(self):
        """
        Self initial window size (in octets) for stream-level flow control.

        Setting a larger value may cause the inbound buffer increase, and allow
        more data to be received. Setting with a smaller value does not
        decrease the buffer immediately, but may prevent the peer from sending
        more data to overflow the buffer for a while. However, it is still up
        to the peer whether to respect this setting or not.
        """
        return self._conn.local_settings.initial_window_size

    @initial_window_size.setter
    def initial_window_size(self, val):
        self._conn.update_settings({settings.SettingCodes.INITIAL_WINDOW_SIZE: val})
        self._flush()

    @property
    def max_frame_size(self):
        """
        The size of the largest frame payload that self is willing to receive,
        in octets.

        Smaller value indicates finer data slices the peer should send, vice
        versa; the peer however may not agree.
        """
        return self._conn.local_settings.max_frame_size

    @max_frame_size.setter
    def max_frame_size(self, val):
        self._conn.update_settings({settings.SettingCodes.MAX_FRAME_SIZE: val})
        self._flush()

    @property
    def max_concurrent_streams(self):
        """
        The maximum number of concurrent streams that self is willing to allow.
        """
        return self._conn.local_settings.max_concurrent_streams

    @max_concurrent_streams.setter
    def max_concurrent_streams(self, val):
        self._conn.update_settings({settings.SettingCodes.MAX_CONCURRENT_STREAMS: val})
        self._flush()
722