1import pytest
2
3from http.client import HTTPConnection
4
5
def put_required_headers(conn):
    """Emit the mandatory WebSocket upgrade headers on *conn*.

    *conn* is expected to be an ``http.client.HTTPConnection`` whose request
    line was already started with ``putrequest``. The headers are sent in a
    fixed order. The key is the sample nonce from RFC 6455 (base64 of
    "the sample nonce"), so the handshake is deterministic.
    """
    handshake_headers = (
        ("Connection", "upgrade"),
        ("Upgrade", "websocket"),
        ("Sec-WebSocket-Key", "dGhlIHNhbXBsZSBub25jZQ=="),
        ("Sec-WebSocket-Version", "13"),
    )
    for header_name, header_value in handshake_headers:
        conn.putheader(header_name, header_value)
11
12
@pytest.mark.parametrize(
    "hostname, port_type, status",
    [
        # Valid hosts
        ("localhost", "remote_agent_port", 101),
        ("localhost", "default_port", 101),
        ("127.0.0.1", "remote_agent_port", 101),
        ("127.0.0.1", "default_port", 101),
        ("[::1]", "remote_agent_port", 101),
        ("[::1]", "default_port", 101),
        ("192.168.8.1", "remote_agent_port", 101),
        ("192.168.8.1", "default_port", 101),
        ("[fdf8:f535:82e4::53]", "remote_agent_port", 101),
        ("[fdf8:f535:82e4::53]", "default_port", 101),
        # Invalid hosts
        ("mozilla.org", "remote_agent_port", 400),
        ("mozilla.org", "wrong_port", 400),
        ("mozilla.org", "default_port", 400),
        ("localhost", "wrong_port", 400),
        ("127.0.0.1", "wrong_port", 400),
        ("[::1]", "wrong_port", 400),
        ("192.168.8.1", "wrong_port", 400),
        ("[fdf8:f535:82e4::53]", "wrong_port", 400),
    ],
    ids=[
        # Valid hosts
        "localhost with same port as RemoteAgent",
        "localhost with default port",
        "127.0.0.1 (loopback) with same port as RemoteAgent",
        "127.0.0.1 (loopback) with default port",
        "[::1] (ipv6 loopback) with same port as RemoteAgent",
        "[::1] (ipv6 loopback) with default port",
        "ipv4 address with same port as RemoteAgent",
        "ipv4 address with default port",
        "ipv6 address with same port as RemoteAgent",
        "ipv6 address with default port",
        # Invalid hosts
        "random hostname with the same port as RemoteAgent",
        "random hostname with a different port than RemoteAgent",
        "random hostname with default port",
        "localhost with a different port than RemoteAgent",
        "127.0.0.1 (loopback) with a different port than RemoteAgent",
        "[::1] (ipv6 loopback) with a different port than RemoteAgent",
        "ipv4 address with a different port than RemoteAgent",
        "ipv6 address with a different port than RemoteAgent",
    ],
)
@pytest.mark.capabilities({"webSocketUrl": True})
def test_host_header(session, hostname, port_type, status):
    """Check how the WebSocket endpoint answers a given ``Host`` header.

    A raw HTTP upgrade request is sent to the session's ``webSocketUrl``.
    The ``Host`` header is replaced by one built from *hostname* and
    *port_type*. The response status must equal *status*.

    :param session: WebDriver session fixture whose capabilities contain
        ``webSocketUrl``.
    :param hostname: Host name or address placed in the ``Host`` header.
    :param port_type: ``"default_port"`` (no port), ``"remote_agent_port"``
        (the port of ``webSocketUrl``) or ``"wrong_port"`` (that port + 1).
    :param status: Expected HTTP status code of the response.
    :raises ValueError: if *port_type* is not one of the values above.
    """
    websocket_url = session.capabilities["webSocketUrl"]
    url = websocket_url.replace("ws:", "http:")
    _, _, real_host, _ = url.split("/", 3)
    # rsplit keeps the port as the last component even when the authority is
    # a bracketed IPv6 literal such as "[::1]:9222".
    _, remote_agent_port = real_host.rsplit(":", 1)

    def get_host():
        if port_type == "default_port":
            return hostname
        if port_type == "remote_agent_port":
            return hostname + ":" + remote_agent_port
        if port_type == "wrong_port":
            wrong_port = str(int(remote_agent_port) + 1)
            return hostname + ":" + wrong_port
        # Fail loudly instead of sending Host=None, which http.client would
        # reject with an unrelated TypeError.
        raise ValueError(f"Unknown port_type: {port_type!r}")

    conn = HTTPConnection(real_host)
    try:
        # skip_host=True stops http.client from adding its own Host header,
        # so the one under test is the only one sent.
        conn.putrequest("GET", url, skip_host=True)

        conn.putheader("Host", get_host())
        put_required_headers(conn)
        conn.endheaders()

        response = conn.getresponse()

        assert response.status == status
    finally:
        # Release the socket even when the request or the assertion fails.
        conn.close()
87
88
@pytest.mark.parametrize(
    "origin, status",
    [
        (None, 101),
        ("", 400),
        ("sometext", 400),
        ("http://localhost:1234", 400),
    ],
)
@pytest.mark.capabilities({"webSocketUrl": True})
def test_origin_header(session, origin, status):
    """Check how the WebSocket endpoint answers a given ``Origin`` header.

    A raw HTTP upgrade request is sent to the session's ``webSocketUrl``,
    optionally with an ``Origin`` header. The response status must equal
    *status*.

    :param session: WebDriver session fixture whose capabilities contain
        ``webSocketUrl``.
    :param origin: Value of the ``Origin`` header, or ``None`` to omit it.
    :param status: Expected HTTP status code of the response.
    """
    websocket_url = session.capabilities["webSocketUrl"]
    url = websocket_url.replace("ws:", "http:")
    _, _, real_host, _ = url.split("/", 3)

    conn = HTTPConnection(real_host)
    try:
        conn.putrequest("GET", url)

        # None means "send no Origin header at all". An empty string still
        # sends the header, just with an empty value.
        if origin is not None:
            conn.putheader("Origin", origin)

        put_required_headers(conn)
        conn.endheaders()

        response = conn.getresponse()

        assert response.status == status
    finally:
        # Release the socket even when the request or the assertion fails.
        conn.close()
116