1 // Part of Measurement Kit <https://measurement-kit.github.io/>.
2 // Measurement Kit is free software under the BSD license. See AUTHORS
3 // and LICENSE for more information on the copying conditions.
4 
5 #include "test/winsock.hpp"
6 
7 #include "include/private/catch.hpp"
8 
9 #include "src/libmeasurement_kit/net/socks5.hpp"
10 #include "src/libmeasurement_kit/net/error.hpp"
11 
12 using namespace mk;
13 using namespace mk::net;
14 
15 TEST_CASE("format_auth_request() works as expected") {
16     Buffer buffer = socks5_format_auth_request(Logger::make());
17     REQUIRE(buffer.length() == 3);
18     std::string message = buffer.read();
19     REQUIRE(message[0] == '\5');
20     REQUIRE(message[1] == '\1');
21     REQUIRE(message[2] == '\0');
22 }
23 
24 TEST_CASE("parse_auth_response() works as expected") {
25 
26     SECTION("When there is no input at all") {
27         Buffer input;
28         ErrorOr<bool> rc = socks5_parse_auth_response(input, Logger::make());
29         REQUIRE(rc.as_error() == NoError());
30         REQUIRE(rc.as_value() == false);
31     }
32 
33     SECTION("When there is just one byte of data") {
34         Buffer input;
35         input.write_uint8(5);
36         ErrorOr<bool> rc = socks5_parse_auth_response(input, Logger::make());
37         REQUIRE(rc.as_error() == NoError());
38         REQUIRE(rc.as_value() == false);
39     }
40 
41     SECTION("When the version is wrong") {
42         Buffer input;
43         input.write_uint8(4);
44         input.write_uint8(0);
45         ErrorOr<bool> rc = socks5_parse_auth_response(input, Logger::make());
46         REQUIRE(rc.as_error() == BadSocksVersionError());
47         REQUIRE_THROWS_AS(rc.as_value(), std::runtime_error);
48     }
49 
50     SECTION("When the preferred_auth is wrong") {
51         Buffer input;
52         input.write_uint8(5);
53         input.write_uint8(16);
54         ErrorOr<bool> rc = socks5_parse_auth_response(input, Logger::make());
55         REQUIRE(rc.as_error() == NoAvailableSocksAuthenticationError());
56         REQUIRE_THROWS_AS(rc.as_value(), std::runtime_error);
57     }
58 
59     SECTION("When the input is OK") {
60         Buffer input;
61         input.write_uint8(5);
62         input.write_uint8(0);
63         ErrorOr<bool> rc = socks5_parse_auth_response(input, Logger::make());
64         REQUIRE(rc.as_error() == NoError());
65         REQUIRE(rc.as_value() == true);
66     }
67 }
68 
69 TEST_CASE("format_connect_request() works as expected") {
70     SECTION("When the address is too long") {
71         ErrorOr<Buffer> rc = socks5_format_connect_request({
72             {"net/address", std::string(1024, 'A')},
73         }, Logger::make());
74         REQUIRE(rc.as_error() == SocksAddressTooLongError());
75         REQUIRE_THROWS_AS(rc.as_value(), std::runtime_error);
76     }
77 
78     SECTION("When the port number is negative") {
79         ErrorOr<Buffer> rc = socks5_format_connect_request({
80             {"net/address", "130.192.91.211"}, {"net/port", -1},
81         }, Logger::make());
82         REQUIRE(rc.as_error() == SocksInvalidPortError());
83         REQUIRE_THROWS_AS(rc.as_value(), std::runtime_error);
84     }
85 
86     SECTION("When the port number is too large") {
87         ErrorOr<Buffer> rc = socks5_format_connect_request({
88             {"net/address", "130.192.91.211"}, {"net/port", 65536},
89         }, Logger::make());
90         REQUIRE(rc.as_error() == SocksInvalidPortError());
91         REQUIRE_THROWS_AS(rc.as_value(), std::runtime_error);
92     }
93 
94     SECTION("When input is OK") {
95         std::string address = "130.192.91.211";
96         uint16_t orig_port = 8080;
97         ErrorOr<Buffer> rc = socks5_format_connect_request({
98             {"net/address", address}, {"net/port", orig_port},
99         }, Logger::make());
100         REQUIRE(rc.as_error() == NoError());
101         std::string msg = rc->read(5 + address.length());
102         REQUIRE(msg[0] == '\5');
103         REQUIRE(msg[1] == '\1');
104         REQUIRE(msg[2] == '\0');
105         REQUIRE(msg[3] == '\3');
106         REQUIRE((unsigned char)msg[4] == address.length());
107         REQUIRE(msg.substr(5, address.length()) == address);
108         // XXX This part of the test is currently not possible:
109         // uint16_t port = rc->read_uint16();
110         // REQUIRE(port == orig_port);
111     }
112 }
113 
114 TEST_CASE("parse_connect_response() works as expected") {
115 
116     SECTION("When there are less than five bytes of input") {
117         Buffer input("ABCD");
118         ErrorOr<bool> rc = socks5_parse_connect_response(input, Logger::make());
119         REQUIRE(rc.as_error() == NoError());
120         REQUIRE(rc.as_value() == false);
121     }
122 
123     SECTION("When the version is not OK") {
124         Buffer input;
125         input.write_uint8(4);
126         input.write_uint8(0);
127         input.write_uint8(0);
128         input.write_uint8(3);
129         input.write_uint8(0);
130         ErrorOr<bool> rc = socks5_parse_connect_response(input, Logger::make());
131         REQUIRE(rc.as_error() == BadSocksVersionError());
132         REQUIRE_THROWS_AS(rc.as_value(), std::runtime_error);
133     }
134 
135     SECTION("When there was a network error") {
136         Buffer input;
137         input.write_uint8(5);
138         input.write_uint8(1);
139         input.write_uint8(0);
140         input.write_uint8(3);
141         input.write_uint8(0);
142         ErrorOr<bool> rc = socks5_parse_connect_response(input, Logger::make());
143         REQUIRE(rc.as_error() == SocksError());
144         REQUIRE_THROWS_AS(rc.as_value(), std::runtime_error);
145     }
146 
147     SECTION("When the reserved field is invalid") {
148         Buffer input;
149         input.write_uint8(5);
150         input.write_uint8(0);
151         input.write_uint8(1);
152         input.write_uint8(3);
153         input.write_uint8(0);
154         ErrorOr<bool> rc = socks5_parse_connect_response(input, Logger::make());
155         REQUIRE(rc.as_error() == BadSocksReservedFieldError());
156         REQUIRE_THROWS_AS(rc.as_value(), std::runtime_error);
157     }
158 
159     SECTION("When the atype field is invalid") {
160         Buffer input;
161         input.write_uint8(5);
162         input.write_uint8(0);
163         input.write_uint8(0);
164         input.write_uint8(44);
165         input.write_uint8(0);
166         ErrorOr<bool> rc = socks5_parse_connect_response(input, Logger::make());
167         REQUIRE(rc.as_error() == BadSocksAtypeValueError());
168         REQUIRE_THROWS_AS(rc.as_value(), std::runtime_error);
169     }
170 
171     SECTION("When not the whole message was read") {
172         Buffer input;
173         input.write_uint8(5);
174         input.write_uint8(0);
175         input.write_uint8(0);
176         input.write_uint8(3);
177         input.write_uint8(6);
178         ErrorOr<bool> rc = socks5_parse_connect_response(input, Logger::make());
179         REQUIRE(rc.as_error() == NoError());
180         REQUIRE(rc.as_value() == false);
181     }
182 
183     SECTION("When the message contains an IPv4 address") {
184         Buffer input;
185         input.write_uint8(5);
186         input.write_uint8(0);
187         input.write_uint8(0);
188         input.write_uint8(1);
189         // <IPv4>
190         input.write_uint8(0);
191         input.write_uint8(0);
192         input.write_uint8(0);
193         input.write_uint8(0);
194         // </IPv4>
195         input.write_uint16(8080);
196         ErrorOr<bool> rc = socks5_parse_connect_response(input, Logger::make());
197         REQUIRE(rc.as_error() == NoError());
198         REQUIRE(rc.as_value() == true);
199         REQUIRE(input.length() == 0);
200     }
201 
202     SECTION("When the message contains a string address") {
203         Buffer input;
204         input.write_uint8(5);
205         input.write_uint8(0);
206         input.write_uint8(0);
207         input.write_uint8(3);
208         // <len+string>
209         input.write_uint8(5);
210         input.write_uint8('x');
211         input.write_uint8('.');
212         input.write_uint8('o');
213         input.write_uint8('r');
214         input.write_uint8('g');
215         // </len+string>
216         input.write_uint16(8080);
217         ErrorOr<bool> rc = socks5_parse_connect_response(input, Logger::make());
218         REQUIRE(rc.as_error() == NoError());
219         REQUIRE(rc.as_value() == true);
220         REQUIRE(input.length() == 0);
221     }
222 
223     SECTION("When the message contains a IPv6 address") {
224         Buffer input;
225         input.write_uint8(5);
226         input.write_uint8(0);
227         input.write_uint8(0);
228         input.write_uint8(4);
229         // <IPv6>
230         input.write_uint8(0), input.write_uint8(0), input.write_uint8(0),
231             input.write_uint8(0);
232         input.write_uint8(0), input.write_uint8(0), input.write_uint8(0),
233             input.write_uint8(0);
234         input.write_uint8(0), input.write_uint8(0), input.write_uint8(0),
235             input.write_uint8(0);
236         input.write_uint8(0), input.write_uint8(0), input.write_uint8(0),
237             input.write_uint8(0);
238         // </IPv6>
239         input.write_uint16(8080);
240         ErrorOr<bool> rc = socks5_parse_connect_response(input, Logger::make());
241         REQUIRE(rc.as_error() == NoError());
242         REQUIRE(rc.as_value() == true);
243         REQUIRE(input.length() == 0);
244     }
245 }
246