1# -*- coding: utf-8 -*-
2"""
3test_state_machines
4~~~~~~~~~~~~~~~~~~~
5
6These tests validate the state machines directly. Writing meaningful tests for
7this case can be tricky, so the majority of these tests use Hypothesis to try
8to talk about general behaviours rather than specific cases.
9"""
10import pytest
11
12import h2.connection
13import h2.exceptions
14import h2.stream
15
16from hypothesis import given
17from hypothesis.strategies import sampled_from
18
19
class TestConnectionStateMachine(object):
    """
    Tests of the connection state machine.

    Each test builds a fresh ``H2ConnectionStateMachine``, forces it into the
    state under test by assigning ``state`` directly, and then feeds it a
    single input through ``process_input``.
    """
    @given(state=sampled_from(h2.connection.ConnectionState),
           input_=sampled_from(h2.connection.ConnectionInputs))
    def test_state_transitions(self, state, input_):
        """
        For every (state, input) pair, processing the input either leaves the
        machine in a valid ConnectionState or raises a ProtocolError with the
        machine left in the CLOSED state.
        """
        c = h2.connection.H2ConnectionStateMachine()
        c.state = state

        try:
            c.process_input(input_)
        except h2.exceptions.ProtocolError:
            assert c.state == h2.connection.ConnectionState.CLOSED
        else:
            assert c.state in h2.connection.ConnectionState

    def test_state_machine_only_allows_connection_states(self):
        """
        The Connection state machine only accepts ConnectionInputs members as
        inputs: a bare integer is rejected with a ValueError.
        """
        c = h2.connection.H2ConnectionStateMachine()

        with pytest.raises(ValueError):
            c.process_input(1)

    # The state list is materialised eagerly rather than passed as a
    # generator expression: a generator can only be consumed once, so a list
    # keeps the parameter set intact however often the mark is evaluated.
    @pytest.mark.parametrize(
        "state",
        [
            s for s in h2.connection.ConnectionState
            if s != h2.connection.ConnectionState.CLOSED
        ],
    )
    @pytest.mark.parametrize(
        "input_",
        [
            h2.connection.ConnectionInputs.RECV_PRIORITY,
            h2.connection.ConnectionInputs.SEND_PRIORITY,
        ]
    )
    def test_priority_frames_allowed_in_all_states(self, state, input_):
        """
        Priority frames can be sent/received in all connection states except
        closed.
        """
        c = h2.connection.H2ConnectionStateMachine()
        c.state = state

        # No explicit assertion: the test passes as long as processing the
        # input does not raise.
        c.process_input(input_)
69
70
class TestStreamStateMachine(object):
    """
    Tests of the stream state machine.
    """
    @given(state=sampled_from(h2.stream.StreamState),
           input_=sampled_from(h2.stream.StreamInputs))
    def test_state_transitions(self, state, input_):
        """
        For every (state, input) pair, processing the input either leaves the
        stream in a valid StreamState or raises a protocol-level error.
        """
        stream_states = h2.stream.StreamState
        stream_inputs = h2.stream.StreamInputs

        machine = h2.stream.H2StreamStateMachine(stream_id=1)
        machine.state = state

        try:
            machine.process_input(input_)
        except h2.exceptions.StreamClosedError:
            # StreamClosedError is expected in three situations only:
            #   1. the stream started out in the CLOSED state;
            #   2. the input was RECV_DATA and the stream was in neither
            #      OPEN nor HALF_CLOSED_LOCAL;
            #   3. the stream was HALF_CLOSED_REMOTE and a frame arrived.
            if state == stream_states.CLOSED:
                assert machine.state == stream_states.CLOSED
            elif input_ == stream_inputs.RECV_DATA:
                assert machine.state == stream_states.CLOSED
                assert state != stream_states.OPEN
                assert state != stream_states.HALF_CLOSED_LOCAL
            elif state == stream_states.HALF_CLOSED_REMOTE:
                # RECV_DATA is already covered by the branch above; it stays
                # in this collection so the set of received frames is
                # complete.
                received_frames = (
                    stream_inputs.RECV_HEADERS,
                    stream_inputs.RECV_PUSH_PROMISE,
                    stream_inputs.RECV_DATA,
                    stream_inputs.RECV_CONTINUATION,
                )
                assert input_ in received_frames
        except h2.exceptions.ProtocolError:
            assert machine.state == stream_states.CLOSED
        else:
            assert machine.state in stream_states

    def test_state_machine_only_allows_stream_states(self):
        """
        The Stream state machine only accepts StreamInputs members as inputs:
        a bare integer is rejected with a ValueError.
        """
        machine = h2.stream.H2StreamStateMachine(stream_id=1)

        with pytest.raises(ValueError):
            machine.process_input(1)

    def test_stream_state_machine_forbids_pushes_on_server_streams(self):
        """
        Once headers have been received on a stream (i.e. this peer is acting
        as the server), receiving a pushed frame is a protocol error.
        """
        stream_inputs = h2.stream.StreamInputs

        machine = h2.stream.H2StreamStateMachine(stream_id=1)
        machine.process_input(stream_inputs.RECV_HEADERS)

        with pytest.raises(h2.exceptions.ProtocolError):
            machine.process_input(stream_inputs.RECV_PUSH_PROMISE)

    def test_stream_state_machine_forbids_sending_pushes_from_clients(self):
        """
        Once headers have been sent on a stream (i.e. this peer is acting as
        the client), sending a pushed frame is a protocol error.
        """
        stream_inputs = h2.stream.StreamInputs

        machine = h2.stream.H2StreamStateMachine(stream_id=1)
        machine.process_input(stream_inputs.SEND_HEADERS)

        with pytest.raises(h2.exceptions.ProtocolError):
            machine.process_input(stream_inputs.SEND_PUSH_PROMISE)

    @pytest.mark.parametrize(
        "input_",
        [
            h2.stream.StreamInputs.SEND_HEADERS,
            h2.stream.StreamInputs.SEND_PUSH_PROMISE,
            h2.stream.StreamInputs.SEND_RST_STREAM,
            h2.stream.StreamInputs.SEND_DATA,
            h2.stream.StreamInputs.SEND_WINDOW_UPDATE,
            h2.stream.StreamInputs.SEND_END_STREAM,
        ]
    )
    def test_cannot_send_on_closed_streams(self, input_):
        """
        Sending anything but a PRIORITY frame is forbidden on closed streams.
        """
        machine = h2.stream.H2StreamStateMachine(stream_id=1)
        machine.state = h2.stream.StreamState.CLOSED

        # A push promise on a closed stream is expected to surface as a plain
        # ProtocolError; every other send is expected to raise
        # StreamClosedError.
        if input_ == h2.stream.StreamInputs.SEND_PUSH_PROMISE:
            expected_error = h2.exceptions.ProtocolError
        else:
            expected_error = h2.exceptions.StreamClosedError

        with pytest.raises(expected_error):
            machine.process_input(input_)
164