1# -*- coding: utf-8 -*-
2"""
3test_utility_functions
4~~~~~~~~~~~~~~~~~~~~~~
5
6Tests for the various utility functions provided by hyper-h2.
7"""
8import pytest
9
10import h2.config
11import h2.connection
12import h2.errors
13import h2.events
14import h2.exceptions
15from h2.utilities import SizeLimitDict, extract_method_header
16
# The tests below need a ``range`` that does not return a list. Python 2
# provides that as ``xrange``; where that name does not exist, fall back to
# the built-in ``range``.
try:
    _lazy_range = xrange
except NameError:
    _lazy_range = range
range = _lazy_range
22
23
class TestGetNextAvailableStreamID(object):
    """
    Tests for the ``H2Connection.get_next_available_stream_id`` method.
    """
    # Header blocks and server configuration shared by the tests below.
    example_request_headers = [
        (':authority', 'example.com'),
        (':path', '/'),
        (':scheme', 'https'),
        (':method', 'GET'),
    ]
    example_response_headers = [
        (':status', '200'),
        ('server', 'fake-serv/0.1.0'),
    ]
    server_config = h2.config.H2Configuration(client_side=False)

    def test_returns_correct_sequence_for_clients(self, frame_factory):
        """
        A client connection hands out the expected sequence of stream IDs.
        """
        # Walking every available stream ID would take far too long, so only
        # the start of the sequence is checked one ID at a time; afterwards
        # the test jumps to the end and checks that the sequence terminates.
        conn = h2.connection.H2Connection()
        conn.initiate_connection()

        for expected_id in range(1, 2**13, 2):
            next_id = conn.get_next_available_stream_id()
            assert next_id == expected_id

            # Finish each stream in both directions (END_STREAM on the
            # request and on the response) so that thousands of streams are
            # not left lying around: this is not meant to be a memory test.
            conn.send_headers(
                stream_id=next_id,
                headers=self.example_request_headers,
                end_stream=True,
            )
            response = frame_factory.build_headers_frame(
                headers=self.example_response_headers,
                stream_id=next_id,
                flags=['END_STREAM'],
            )
            conn.receive_data(response.serialize())
            conn.clear_outbound_data_buffer()

        # Use the last available client stream ID directly. This single
        # stream is deliberately left without cleanup.
        final_client_id = 2**31 - 1
        conn.send_headers(
            stream_id=final_client_id,
            headers=self.example_request_headers,
            end_stream=True,
        )

        with pytest.raises(h2.exceptions.NoAvailableStreamIDError):
            conn.get_next_available_stream_id()

    def test_returns_correct_sequence_for_servers(self, frame_factory):
        """
        A server connection hands out the expected sequence of stream IDs.
        """
        # Walking every available stream ID would take far too long, so only
        # the start of the sequence is checked one ID at a time; afterwards
        # the test jumps to the end and checks that the sequence terminates.
        conn = h2.connection.H2Connection(config=self.server_config)
        conn.initiate_connection()
        conn.receive_data(frame_factory.preamble())

        # Receive a request first; every push below names stream 1 as its
        # parent stream.
        request = frame_factory.build_headers_frame(
            headers=self.example_request_headers
        )
        conn.receive_data(request.serialize())

        for expected_id in range(2, 2**13, 2):
            next_id = conn.get_next_available_stream_id()
            assert next_id == expected_id

            # Promise the stream and immediately send its response with
            # ``end_stream=True`` so that thousands of pushed streams are
            # not left lying around: this is not meant to be a memory test.
            conn.push_stream(
                stream_id=1,
                promised_stream_id=next_id,
                request_headers=self.example_request_headers,
            )
            conn.send_headers(
                stream_id=next_id,
                headers=self.example_response_headers,
                end_stream=True,
            )
            conn.clear_outbound_data_buffer()

        # Promise the last available server stream ID directly. This single
        # stream is deliberately left without cleanup.
        final_server_id = 2**31 - 2
        conn.push_stream(
            stream_id=1,
            promised_stream_id=final_server_id,
            request_headers=self.example_request_headers,
        )

        with pytest.raises(h2.exceptions.NoAvailableStreamIDError):
            conn.get_next_available_stream_id()

    def test_does_not_increment_without_stream_send(self):
        """
        Asking for the next stream ID repeatedly gives the same answer until
        a stream is actually opened with it.
        """
        conn = h2.connection.H2Connection()
        conn.initiate_connection()

        # Two back-to-back queries with nothing sent in between.
        first_id = conn.get_next_available_stream_id()
        second_id = conn.get_next_available_stream_id()
        assert first_id == second_id

        # Once headers have been sent on that ID, the next one offered is
        # two higher.
        conn.send_headers(
            stream_id=first_id,
            headers=self.example_request_headers,
        )
        third_id = conn.get_next_available_stream_id()
        assert third_id == (first_id + 2)
157
158
class TestExtractHeader(object):
    """
    Tests for the ``extract_method_header`` utility function.
    """
    # The same request header block, once as unicode strings and once as
    # bytestrings.
    example_request_headers = [
        (u':authority', u'example.com'),
        (u':path', u'/'),
        (u':scheme', u'https'),
        (u':method', u'GET'),
    ]
    example_headers_with_bytes = [
        (b':authority', b'example.com'),
        (b':path', b'/'),
        (b':scheme', b'https'),
        (b':method', b'GET'),
    ]

    @pytest.mark.parametrize(
        'headers',
        [example_request_headers, example_headers_with_bytes],
    )
    def test_extract_header_method(self, headers):
        """
        The extracted method is the bytestring ``b'GET'`` for both the
        unicode and the bytes header blocks.
        """
        method = extract_method_header(headers)
        assert method == b'GET'
179
180
def test_size_limit_dict_limit():
    """
    Inserting into a full ``SizeLimitDict`` keeps its length at the limit,
    and the first-inserted key is no longer present afterwards.
    """
    limited = SizeLimitDict(size_limit=2)

    # Fill the dictionary exactly up to its limit.
    for key in (1, 2):
        limited[key] = key

    assert len(limited) == 2
    for key in (1, 2):
        assert limited[key] == key

    # One more insertion goes past the limit of two entries.
    limited[3] = 3

    assert len(limited) == 2
    for key in (2, 3):
        assert limited[key] == key
    assert 1 not in limited
197
198
def test_size_limit_dict_limit_init():
    """
    A ``SizeLimitDict`` constructed from a mapping with more entries than
    ``size_limit`` holds exactly ``size_limit`` entries.
    """
    # Three entries, one more than the limit passed below.
    oversized = {key: key for key in (1, 2, 3)}

    limited = SizeLimitDict(oversized, size_limit=2)

    assert len(limited) == 2
209
210
def test_size_limit_dict_no_limit():
    """
    With ``size_limit=None`` a ``SizeLimitDict`` keeps every inserted key.
    """
    unlimited = SizeLimitDict(size_limit=None)

    for key in (1, 2):
        unlimited[key] = key

    assert len(unlimited) == 2
    for key in (1, 2):
        assert unlimited[key] == key

    # A third insertion is kept alongside the first two.
    unlimited[3] = 3

    assert len(unlimited) == 3
    for key in (1, 2, 3):
        assert unlimited[key] == key
227