1 /*
2 * SRT - Secure, Reliable, Transport
3 * Copyright (c) 2018 Haivision Systems Inc.
4 *
5 * This Source Code Form is subject to the terms of the Mozilla Public
6 * License, v. 2.0. If a copy of the MPL was not distributed with this
7 * file, You can obtain one at http://mozilla.org/MPL/2.0/.
8 *
9 * Written by:
10 * Haivision Systems Inc.
11 */
12
#include <gtest/gtest.h>
#include <atomic>
#include <chrono>
#include <condition_variable>
#include <cstring>
#include <iostream>
#include <mutex>
#include <string>
#include <thread>
#include <type_traits>

#include "srt.h"
#include "sync.h"
20
21
22
// Identifies which of the two test peers an operation applies to.
// Also used as the index into the per-peer arrays of a TestCase.
enum PEER_TYPE
{
    PEER_CALLER   = 0,  // the peer that calls srt_connect()
    PEER_LISTENER = 1,  // the peer that calls srt_listen() / srt_accept()
    PEER_COUNT    = 2   // number of peers (array dimension)
};
29
30
// Index into the expected socket_state / km_state arrays of a test result:
// selects the socket whose state is being checked.
enum CHECK_SOCKET_TYPE
{
    CHECK_SOCKET_CALLER   = 0,  // the caller's socket
    CHECK_SOCKET_ACCEPTED = 1,  // the socket returned by srt_accept() on the listener
    CHECK_SOCKET_COUNT    = 2   // number of checked sockets (array dimension)
};
37
38
// Test case identifiers; each value is the row index into the test matrices.
// Letter = ENFORCEDENC combination {caller, listener}; digit = password combination:
//   _1 both set and equal, _2 both set but different, _3 caller only,
//   _4 listener only, _5 none.
enum TEST_CASE
{
    // A: ENFORCEDENC {on, on}
    TEST_CASE_A_1 = 0, TEST_CASE_A_2, TEST_CASE_A_3, TEST_CASE_A_4, TEST_CASE_A_5,
    // B: ENFORCEDENC {on, off}
    TEST_CASE_B_1, TEST_CASE_B_2, TEST_CASE_B_3, TEST_CASE_B_4, TEST_CASE_B_5,
    // C: ENFORCEDENC {off, on}
    TEST_CASE_C_1, TEST_CASE_C_2, TEST_CASE_C_3, TEST_CASE_C_4, TEST_CASE_C_5,
    // D: ENFORCEDENC {off, off}
    TEST_CASE_D_1, TEST_CASE_D_2, TEST_CASE_D_3, TEST_CASE_D_4, TEST_CASE_D_5
};
62
63
64 struct TestResultNonBlocking
65 {
66 int connect_ret;
67 int accept_ret;
68 int epoll_wait_ret;
69 int epoll_event;
70 int socket_state[CHECK_SOCKET_COUNT];
71 int km_state [CHECK_SOCKET_COUNT];
72 };
73
74
75 struct TestResultBlocking
76 {
77 int connect_ret;
78 int accept_ret;
79 int socket_state[CHECK_SOCKET_COUNT];
80 int km_state[CHECK_SOCKET_COUNT];
81 };
82
83
84 template<typename TResult>
85 struct TestCase
86 {
87 bool enforcedenc [PEER_COUNT];
88 const std::string (&password)[PEER_COUNT];
89 TResult expected_result;
90 };
91
92 typedef TestCase<TestResultNonBlocking> TestCaseNonBlocking;
93 typedef TestCase<TestResultBlocking> TestCaseBlocking;
94
95
96
// Passphrases used by the test matrices:
static const std::string s_pwd_a  = "s!t@r#i$c^t";   // reference password
static const std::string s_pwd_b  = "s!t@r#i$c^tu";  // differs from s_pwd_a (mismatch cases)
static const std::string s_pwd_no = "";              // empty: no password set
100
101
102
103 /*
104 * TESTING SCENARIO
 * Both peers exchange HandShake v5.
 * Caller is sender in a non-blocking mode
 * Listener is receiver in a non-blocking mode
 *
 * Cases B.2-B.4 are special. Here we have incompatible password settings, but
 * the listener accepts the connection, while the caller rejects it. In this case we have a
 * short-lived confusion state: the connection is accepted on the listener side, and the
 * listener sends back the conclusion handshake, but the caller will reject it.
113 *
114 * Because of that, we should ignore what will happen in the listener as this is
115 * just a matter of luck: if the listener thread is lucky, it will report the socket
116 * to accept, so epoll will signal it and accept will report it, and moreover, further
117 * good luck on this socket would make the state check return SRTS_CONNECTED. Without
 * this good luck, the caller might be quick enough to reject the handshake and send
 * the UMSG_SHUTDOWN packet to the peer. If that packet arrives before acceptance, the listener
 * will withdraw the socket before it could be reported by accept.
121 *
122 * Still, we check predictable things here, so we accept two possibilities:
123 * - The accepted socket wasn't reported at all
124 * - The accepted socket was reported, and after `srt_connect` is done, it should turn to SRTS_BROKEN.
125 *
126 * This embraces both cases when the accepted socket was broken in the beginning, and when it was CONNECTED
127 * in the beginning, but broke soon thereafter.
128 *
129 * This behavior is predicted and accepted - it's also the reason that setting ENFORCEDENC to false is
130 * NOT RECOMMENDED on a listener socket that isn't intended to accept only connections from known callers
131 * that are known to have set this flag also to false.
132 *
 * In the cases C.2-C.4 it is the listener who rejects the connection, so we don't have an accepted socket
 * and the situation is always the same and clear in the beginning. The caller cannot continue with the
 * connection after the listener rejected it, even if it tolerates incompatible password settings.
136 */
137
// Sentinels used in the expectation tables to mark a value that must not be checked.
constexpr int IGNORE_EPOLL = -2;  // epoll_wait_ret: result is timing-dependent
constexpr int IGNORE_SRTS  = -1;  // socket_state / km_state entry
140
141 const TestCaseNonBlocking g_test_matrix_non_blocking[] =
142 {
143 // ENFORCEDENC | Password | | EPoll wait | socket_state | KM State
144 // caller | listener | caller | listener | connect_ret accept_ret | ret | event | caller accepted | caller listener
145 /*A.1 */ { {true, true }, {s_pwd_a, s_pwd_a}, { SRT_SUCCESS, 0, 1, SRT_EPOLL_IN, {SRTS_CONNECTED, SRTS_CONNECTED}, {SRT_KM_S_SECURED, SRT_KM_S_SECURED}}},
146 /*A.2 */ { {true, true }, {s_pwd_a, s_pwd_b}, { SRT_SUCCESS, SRT_INVALID_SOCK, 0, 0, {SRTS_BROKEN, IGNORE_SRTS}, {SRT_KM_S_UNSECURED, IGNORE_SRTS}}},
147 /*A.3 */ { {true, true }, {s_pwd_a, s_pwd_no}, { SRT_SUCCESS, SRT_INVALID_SOCK, 0, 0, {SRTS_BROKEN, IGNORE_SRTS}, {SRT_KM_S_UNSECURED, IGNORE_SRTS}}},
148 /*A.4 */ { {true, true }, {s_pwd_no, s_pwd_b}, { SRT_SUCCESS, SRT_INVALID_SOCK, 0, 0, {SRTS_BROKEN, IGNORE_SRTS}, {SRT_KM_S_UNSECURED, IGNORE_SRTS}}},
149 /*A.5 */ { {true, true }, {s_pwd_no, s_pwd_no}, { SRT_SUCCESS, 0, 1, SRT_EPOLL_IN, {SRTS_CONNECTED, SRTS_CONNECTED}, {SRT_KM_S_UNSECURED, SRT_KM_S_UNSECURED}}},
150
151 /*B.1 */ { {true, false }, {s_pwd_a, s_pwd_a}, { SRT_SUCCESS, 0, 1, SRT_EPOLL_IN, {SRTS_CONNECTED, SRTS_CONNECTED}, {SRT_KM_S_SECURED, SRT_KM_S_SECURED}}},
152 /*B.2 */ { {true, false }, {s_pwd_a, s_pwd_b}, { SRT_SUCCESS, 0, IGNORE_EPOLL, 0, {SRTS_CONNECTING, SRTS_BROKEN}, {SRT_KM_S_BADSECRET, SRT_KM_S_BADSECRET}}},
153 /*B.3 */ { {true, false }, {s_pwd_a, s_pwd_no}, { SRT_SUCCESS, 0, IGNORE_EPOLL, 0, {SRTS_CONNECTING, SRTS_BROKEN}, {SRT_KM_S_UNSECURED, SRT_KM_S_UNSECURED}}},
154 /*B.4 */ { {true, false }, {s_pwd_no, s_pwd_b}, { SRT_SUCCESS, 0, IGNORE_EPOLL, 0, {SRTS_CONNECTING, SRTS_BROKEN}, {SRT_KM_S_UNSECURED, SRT_KM_S_NOSECRET}}},
155 /*B.5 */ { {true, false }, {s_pwd_no, s_pwd_no}, { SRT_SUCCESS, 0, 1, SRT_EPOLL_IN, {SRTS_CONNECTED, SRTS_CONNECTED}, {SRT_KM_S_UNSECURED, SRT_KM_S_UNSECURED}}},
156
157 /*C.1 */ { {false, true }, {s_pwd_a, s_pwd_a}, { SRT_SUCCESS, 0, 1, SRT_EPOLL_IN, {SRTS_CONNECTED, SRTS_CONNECTED}, {SRT_KM_S_SECURED, SRT_KM_S_SECURED}}},
158 /*C.2 */ { {false, true }, {s_pwd_a, s_pwd_b}, { SRT_SUCCESS, SRT_INVALID_SOCK, 0, 0, {SRTS_BROKEN, IGNORE_SRTS}, {SRT_KM_S_UNSECURED, IGNORE_SRTS}}},
159 /*C.3 */ { {false, true }, {s_pwd_a, s_pwd_no}, { SRT_SUCCESS, SRT_INVALID_SOCK, 0, 0, {SRTS_BROKEN, IGNORE_SRTS}, {SRT_KM_S_UNSECURED, IGNORE_SRTS}}},
160 /*C.4 */ { {false, true }, {s_pwd_no, s_pwd_b}, { SRT_SUCCESS, SRT_INVALID_SOCK, 0, 0, {SRTS_BROKEN, IGNORE_SRTS}, {SRT_KM_S_UNSECURED, IGNORE_SRTS}}},
161 /*C.5 */ { {false, true }, {s_pwd_no, s_pwd_no}, { SRT_SUCCESS, 0, 1, SRT_EPOLL_IN, {SRTS_CONNECTED, SRTS_CONNECTED}, {SRT_KM_S_UNSECURED, SRT_KM_S_UNSECURED}}},
162
163 /*D.1 */ { {false, false }, {s_pwd_a, s_pwd_a}, { SRT_SUCCESS, 0, 1, SRT_EPOLL_IN, {SRTS_CONNECTED, SRTS_CONNECTED}, {SRT_KM_S_SECURED, SRT_KM_S_SECURED}}},
164 /*D.2 */ { {false, false }, {s_pwd_a, s_pwd_b}, { SRT_SUCCESS, 0, 1, SRT_EPOLL_IN, {SRTS_CONNECTED, SRTS_CONNECTED}, {SRT_KM_S_BADSECRET, SRT_KM_S_BADSECRET}}},
165 /*D.3 */ { {false, false }, {s_pwd_a, s_pwd_no}, { SRT_SUCCESS, 0, 1, SRT_EPOLL_IN, {SRTS_CONNECTED, SRTS_CONNECTED}, {SRT_KM_S_UNSECURED, SRT_KM_S_UNSECURED}}},
166 /*D.4 */ { {false, false }, {s_pwd_no, s_pwd_b}, { SRT_SUCCESS, 0, 1, SRT_EPOLL_IN, {SRTS_CONNECTED, SRTS_CONNECTED}, {SRT_KM_S_NOSECRET, SRT_KM_S_NOSECRET}}},
167 /*D.5 */ { {false, false }, {s_pwd_no, s_pwd_no}, { SRT_SUCCESS, 0, 1, SRT_EPOLL_IN, {SRTS_CONNECTED, SRTS_CONNECTED}, {SRT_KM_S_UNSECURED, SRT_KM_S_UNSECURED}}},
168 };
169
170
171 /*
172 * TESTING SCENARIO
 * Both peers exchange HandShake v5.
 * Caller is sender in a blocking mode
 * Listener is receiver in a blocking mode
176 *
 * In the cases B.2-B.4 the caller will reject the connection due to the enforced encryption check
 * of the HS response from the listener on the stage of the KM response check.
 * Meanwhile the listener has already accepted the connection, so the accepted socket is in the
 * connected state. The caller then sends UMSG_SHUTDOWN to notify the listener that it has closed
 * the connection, and the accepted socket switches to the SRTS_BROKEN state.
 * For these cases a special accept_ret = -2 is used, which allows the accepted socket to be broken or already closed.
182 *
183 * In the cases C.2-C.4 it is the listener who rejects the connection, so we don't have an accepted socket.
184 */
185 const TestCaseBlocking g_test_matrix_blocking[] =
186 {
187 // ENFORCEDENC | Password | | socket_state | KM State
188 // caller | listener | caller | listener | connect_ret accept_ret | caller accepted | caller listener
189 /*A.1 */ { {true, true }, {s_pwd_a, s_pwd_a}, { SRT_SUCCESS, 0, {SRTS_CONNECTED, SRTS_CONNECTED}, {SRT_KM_S_SECURED, SRT_KM_S_SECURED}}},
190 /*A.2 */ { {true, true }, {s_pwd_a, s_pwd_b}, { SRT_INVALID_SOCK, SRT_INVALID_SOCK, {SRTS_OPENED, -1}, {SRT_KM_S_UNSECURED, -1}}},
191 /*A.3 */ { {true, true }, {s_pwd_a, s_pwd_no}, { SRT_INVALID_SOCK, SRT_INVALID_SOCK, {SRTS_OPENED, -1}, {SRT_KM_S_UNSECURED, -1}}},
192 /*A.4 */ { {true, true }, {s_pwd_no, s_pwd_b}, { SRT_INVALID_SOCK, SRT_INVALID_SOCK, {SRTS_OPENED, -1}, {SRT_KM_S_UNSECURED, -1}}},
193 /*A.5 */ { {true, true }, {s_pwd_no, s_pwd_no}, { SRT_SUCCESS, 0, {SRTS_CONNECTED, SRTS_CONNECTED}, {SRT_KM_S_UNSECURED, SRT_KM_S_UNSECURED}}},
194
195 /*B.1 */ { {true, false }, {s_pwd_a, s_pwd_a}, { SRT_SUCCESS, 0, {SRTS_CONNECTED, SRTS_CONNECTED}, {SRT_KM_S_SECURED, SRT_KM_S_SECURED}}},
196 /*B.2 */ { {true, false }, {s_pwd_a, s_pwd_b}, { SRT_INVALID_SOCK, -2, {SRTS_OPENED, SRTS_BROKEN}, {SRT_KM_S_BADSECRET, SRT_KM_S_BADSECRET}}},
197 /*B.3 */ { {true, false }, {s_pwd_a, s_pwd_no}, { SRT_INVALID_SOCK, -2, {SRTS_OPENED, SRTS_BROKEN}, {SRT_KM_S_UNSECURED, SRT_KM_S_UNSECURED}}},
198 /*B.4 */ { {true, false }, {s_pwd_no, s_pwd_b}, { SRT_INVALID_SOCK, -2, {SRTS_OPENED, SRTS_BROKEN}, {SRT_KM_S_UNSECURED, SRT_KM_S_NOSECRET}}},
199 /*B.5 */ { {true, false }, {s_pwd_no, s_pwd_no}, { SRT_SUCCESS, 0, {SRTS_CONNECTED, SRTS_CONNECTED}, {SRT_KM_S_UNSECURED, SRT_KM_S_UNSECURED}}},
200
201 /*C.1 */ { {false, true }, {s_pwd_a, s_pwd_a}, { SRT_SUCCESS, 0, {SRTS_CONNECTED, SRTS_CONNECTED}, {SRT_KM_S_SECURED, SRT_KM_S_SECURED}}},
202 /*C.2 */ { {false, true }, {s_pwd_a, s_pwd_b}, { SRT_INVALID_SOCK, SRT_INVALID_SOCK, {SRTS_OPENED, -1}, {SRT_KM_S_UNSECURED, -1}}},
203 /*C.3 */ { {false, true }, {s_pwd_a, s_pwd_no}, { SRT_INVALID_SOCK, SRT_INVALID_SOCK, {SRTS_OPENED, -1}, {SRT_KM_S_UNSECURED, -1}}},
204 /*C.4 */ { {false, true }, {s_pwd_no, s_pwd_b}, { SRT_INVALID_SOCK, SRT_INVALID_SOCK, {SRTS_OPENED, -1}, {SRT_KM_S_UNSECURED, -1}}},
205 /*C.5 */ { {false, true }, {s_pwd_no, s_pwd_no}, { SRT_SUCCESS, 0, {SRTS_CONNECTED, SRTS_CONNECTED}, {SRT_KM_S_UNSECURED, SRT_KM_S_UNSECURED}}},
206
207 /*D.1 */ { {false, false }, {s_pwd_a, s_pwd_a}, { SRT_SUCCESS, 0, {SRTS_CONNECTED, SRTS_CONNECTED}, {SRT_KM_S_SECURED, SRT_KM_S_SECURED}}},
208 /*D.2 */ { {false, false }, {s_pwd_a, s_pwd_b}, { SRT_SUCCESS, 0, {SRTS_CONNECTED, SRTS_CONNECTED}, {SRT_KM_S_BADSECRET, SRT_KM_S_BADSECRET}}},
209 /*D.3 */ { {false, false }, {s_pwd_a, s_pwd_no}, { SRT_SUCCESS, 0, {SRTS_CONNECTED, SRTS_CONNECTED}, {SRT_KM_S_UNSECURED, SRT_KM_S_UNSECURED}}},
210 /*D.4 */ { {false, false }, {s_pwd_no, s_pwd_b}, { SRT_SUCCESS, 0, {SRTS_CONNECTED, SRTS_CONNECTED}, {SRT_KM_S_NOSECRET, SRT_KM_S_NOSECRET}}},
211 /*D.5 */ { {false, false }, {s_pwd_no, s_pwd_no}, { SRT_SUCCESS, 0, {SRTS_CONNECTED, SRTS_CONNECTED}, {SRT_KM_S_UNSECURED, SRT_KM_S_UNSECURED}}},
212 };
213
214
215
216 class TestEnforcedEncryption
217 : public ::testing::Test
218 {
219 protected:
TestEnforcedEncryption()220 TestEnforcedEncryption()
221 {
222 // initialization code here
223 }
224
~TestEnforcedEncryption()225 ~TestEnforcedEncryption()
226 {
227 // cleanup any pending stuff, but no exceptions allowed
228 }
229
230 protected:
231
232 // SetUp() is run immediately before a test starts.
SetUp()233 void SetUp()
234 {
235 ASSERT_EQ(srt_startup(), 0);
236
237 m_pollid = srt_epoll_create();
238 ASSERT_GE(m_pollid, 0);
239
240 m_caller_socket = srt_create_socket();
241 ASSERT_NE(m_caller_socket, SRT_INVALID_SOCK);
242
243 ASSERT_NE(srt_setsockflag(m_caller_socket, SRTO_SENDER, &s_yes, sizeof s_yes), SRT_ERROR);
244 ASSERT_NE(srt_setsockopt (m_caller_socket, 0, SRTO_TSBPDMODE, &s_yes, sizeof s_yes), SRT_ERROR);
245
246 m_listener_socket = srt_create_socket();
247 ASSERT_NE(m_listener_socket, SRT_INVALID_SOCK);
248
249 ASSERT_NE(srt_setsockflag(m_listener_socket, SRTO_SENDER, &s_no, sizeof s_no), SRT_ERROR);
250 ASSERT_NE(srt_setsockopt (m_listener_socket, 0, SRTO_TSBPDMODE, &s_yes, sizeof s_yes), SRT_ERROR);
251
252 // Will use this epoll to wait for srt_accept(...)
253 const int epoll_out = SRT_EPOLL_IN | SRT_EPOLL_ERR;
254 ASSERT_NE(srt_epoll_add_usock(m_pollid, m_listener_socket, &epoll_out), SRT_ERROR);
255 }
256
TearDown()257 void TearDown()
258 {
259 // Code here will be called just after the test completes.
260 // OK to throw exceptions from here if needed.
261 ASSERT_NE(srt_close(m_caller_socket), SRT_ERROR);
262 ASSERT_NE(srt_close(m_listener_socket), SRT_ERROR);
263 srt_cleanup();
264 }
265
266
267 public:
268
269
SetEnforcedEncryption(PEER_TYPE peer,bool value)270 int SetEnforcedEncryption(PEER_TYPE peer, bool value)
271 {
272 const SRTSOCKET &socket = peer == PEER_CALLER ? m_caller_socket : m_listener_socket;
273 return srt_setsockopt(socket, 0, SRTO_ENFORCEDENCRYPTION, value ? &s_yes : &s_no, sizeof s_yes);
274 }
275
276
GetEnforcedEncryption(PEER_TYPE peer_type)277 bool GetEnforcedEncryption(PEER_TYPE peer_type)
278 {
279 const SRTSOCKET socket = peer_type == PEER_CALLER ? m_caller_socket : m_listener_socket;
280 bool optval;
281 int optlen = sizeof optval;
282 EXPECT_EQ(srt_getsockopt(socket, 0, SRTO_ENFORCEDENCRYPTION, (void*)&optval, &optlen), SRT_SUCCESS);
283 return optval ? true : false;
284 }
285
286
SetPassword(PEER_TYPE peer_type,const std::basic_string<char> & pwd)287 int SetPassword(PEER_TYPE peer_type, const std::basic_string<char> &pwd)
288 {
289 const SRTSOCKET socket = peer_type == PEER_CALLER ? m_caller_socket : m_listener_socket;
290 return srt_setsockopt(socket, 0, SRTO_PASSPHRASE, pwd.c_str(), (int) pwd.size());
291 }
292
293
GetKMState(SRTSOCKET socket)294 int GetKMState(SRTSOCKET socket)
295 {
296 int km_state = 0;
297 int opt_size = sizeof km_state;
298 EXPECT_EQ(srt_getsockopt(socket, 0, SRTO_KMSTATE, reinterpret_cast<void*>(&km_state), &opt_size), SRT_SUCCESS);
299
300 return km_state;
301 }
302
303
GetSocetkOption(SRTSOCKET socket,SRT_SOCKOPT opt)304 int GetSocetkOption(SRTSOCKET socket, SRT_SOCKOPT opt)
305 {
306 int val = 0;
307 int size = sizeof val;
308 EXPECT_EQ(srt_getsockopt(socket, 0, opt, reinterpret_cast<void*>(&val), &size), SRT_SUCCESS);
309
310 return val;
311 }
312
313
314 template<typename TResult>
315 int WaitOnEpoll(const TResult &expect);
316
317
318 template<typename TResult>
319 const TestCase<TResult>& GetTestMatrix(TEST_CASE test_case) const;
320
321 template<typename TResult>
TestConnect(TEST_CASE test_case)322 void TestConnect(TEST_CASE test_case/*, bool is_blocking*/)
323 {
324 const bool is_blocking = std::is_same<TResult, TestResultBlocking>::value;
325 if (is_blocking)
326 {
327 ASSERT_NE(srt_setsockopt( m_caller_socket, 0, SRTO_RCVSYN, &s_yes, sizeof s_yes), SRT_ERROR);
328 ASSERT_NE(srt_setsockopt( m_caller_socket, 0, SRTO_SNDSYN, &s_yes, sizeof s_yes), SRT_ERROR);
329 ASSERT_NE(srt_setsockopt(m_listener_socket, 0, SRTO_RCVSYN, &s_yes, sizeof s_yes), SRT_ERROR);
330 ASSERT_NE(srt_setsockopt(m_listener_socket, 0, SRTO_SNDSYN, &s_yes, sizeof s_yes), SRT_ERROR);
331 }
332 else
333 {
334 ASSERT_NE(srt_setsockopt( m_caller_socket, 0, SRTO_RCVSYN, &s_no, sizeof s_no), SRT_ERROR); // non-blocking mode
335 ASSERT_NE(srt_setsockopt( m_caller_socket, 0, SRTO_SNDSYN, &s_no, sizeof s_no), SRT_ERROR); // non-blocking mode
336 ASSERT_NE(srt_setsockopt(m_listener_socket, 0, SRTO_RCVSYN, &s_no, sizeof s_no), SRT_ERROR); // non-blocking mode
337 ASSERT_NE(srt_setsockopt(m_listener_socket, 0, SRTO_SNDSYN, &s_no, sizeof s_no), SRT_ERROR); // non-blocking mode
338 }
339
340 // Prepare input state
341 const TestCase<TResult> &test = GetTestMatrix<TResult>(test_case);
342 ASSERT_EQ(SetEnforcedEncryption(PEER_CALLER, test.enforcedenc[PEER_CALLER]), SRT_SUCCESS);
343 ASSERT_EQ(SetEnforcedEncryption(PEER_LISTENER, test.enforcedenc[PEER_LISTENER]), SRT_SUCCESS);
344
345 ASSERT_EQ(SetPassword(PEER_CALLER, test.password[PEER_CALLER]), SRT_SUCCESS);
346 ASSERT_EQ(SetPassword(PEER_LISTENER, test.password[PEER_LISTENER]), SRT_SUCCESS);
347
348 const TResult &expect = test.expected_result;
349
350 // Start testing
351 volatile bool caller_done = false;
352 sockaddr_in sa;
353 memset(&sa, 0, sizeof sa);
354 sa.sin_family = AF_INET;
355 sa.sin_port = htons(5200);
356 ASSERT_EQ(inet_pton(AF_INET, "127.0.0.1", &sa.sin_addr), 1);
357 sockaddr* psa = (sockaddr*)&sa;
358 ASSERT_NE(srt_bind(m_listener_socket, psa, sizeof sa), SRT_ERROR);
359 ASSERT_NE(srt_listen(m_listener_socket, 4), SRT_ERROR);
360
361 auto accepting_thread = std::thread([&] {
362 const int epoll_event = WaitOnEpoll(expect);
363
364 // In a blocking mode we expect a socket returned from srt_accept() if the srt_connect succeeded.
365 // In a non-blocking mode we expect a socket returned from srt_accept() if the srt_connect succeeded,
366 // otherwise SRT_INVALID_SOCKET after the listening socket is closed.
367 sockaddr_in client_address;
368 int length = sizeof(sockaddr_in);
369 SRTSOCKET accepted_socket = -1;
370 if (epoll_event == SRT_EPOLL_IN)
371 {
372 accepted_socket = srt_accept(m_listener_socket, (sockaddr*)&client_address, &length);
373 std::cout << "ACCEPT: done, result=" << accepted_socket << std::endl;
374 }
375 else
376 {
377 std::cout << "ACCEPT: NOT done\n";
378 }
379
380 if (accepted_socket == SRT_INVALID_SOCK)
381 {
382 std::cerr << "[T] ACCEPT ERROR: " << srt_getlasterror_str() << std::endl;
383 }
384 else
385 {
386 std::cerr << "[T] ACCEPT SUCCEEDED: @" << accepted_socket << "\n";
387 }
388
389 EXPECT_NE(accepted_socket, 0);
390 if (expect.accept_ret == SRT_INVALID_SOCK)
391 {
392 EXPECT_EQ(accepted_socket, SRT_INVALID_SOCK);
393 }
394 else if (expect.accept_ret != -2)
395 {
396 EXPECT_NE(accepted_socket, SRT_INVALID_SOCK);
397 }
398
399 if (accepted_socket != SRT_INVALID_SOCK && expect.socket_state[CHECK_SOCKET_ACCEPTED] != IGNORE_SRTS)
400 {
401 if (m_is_tracing)
402 {
403 std::cerr << "EARLY Socket state accepted: " << m_socket_state[srt_getsockstate(accepted_socket)]
404 << " (expected: " << m_socket_state[expect.socket_state[CHECK_SOCKET_ACCEPTED]] << ")\n";
405 std::cerr << "KM State accepted: " << m_km_state[GetKMState(accepted_socket)] << '\n';
406 std::cerr << "RCV KM State accepted: " << m_km_state[GetSocetkOption(accepted_socket, SRTO_RCVKMSTATE)] << '\n';
407 std::cerr << "SND KM State accepted: " << m_km_state[GetSocetkOption(accepted_socket, SRTO_SNDKMSTATE)] << '\n';
408 }
409
410 // We have to wait some time for the socket to be able to process the HS responce from the caller.
411 // In test cases B2 - B4 the socket is expected to change its state from CONNECTED to BROKEN
412 // due to KM mismatches
413 do
414 {
415 std::this_thread::sleep_for(std::chrono::milliseconds(50));
416 } while (!caller_done);
417
418 // Special case when the expected state is "broken": if so, tolerate every possible
419 // socket state, just NOT LESS than SRTS_BROKEN, and also don't read any flags on that socket.
420
421 if (expect.socket_state[CHECK_SOCKET_ACCEPTED] == SRTS_BROKEN)
422 {
423 EXPECT_GE(srt_getsockstate(accepted_socket), SRTS_BROKEN);
424 }
425 else
426 {
427 EXPECT_EQ(srt_getsockstate(accepted_socket), expect.socket_state[CHECK_SOCKET_ACCEPTED]);
428 EXPECT_EQ(GetSocetkOption(accepted_socket, SRTO_SNDKMSTATE), expect.km_state[CHECK_SOCKET_ACCEPTED]);
429 }
430
431 if (m_is_tracing)
432 {
433 const SRT_SOCKSTATUS status = srt_getsockstate(accepted_socket);
434 std::cerr << "LATE Socket state accepted: " << m_socket_state[status]
435 << " (expected: " << m_socket_state[expect.socket_state[CHECK_SOCKET_ACCEPTED]] << ")\n";
436 }
437 }
438 });
439
440 const int connect_ret = srt_connect(m_caller_socket, psa, sizeof sa);
441 EXPECT_EQ(connect_ret, expect.connect_ret);
442
443 if (connect_ret == SRT_ERROR && connect_ret != expect.connect_ret)
444 {
445 std::cerr << "UNEXPECTED! srt_connect returned error: "
446 << srt_getlasterror_str() << " (code " << srt_getlasterror(NULL) << ")\n";
447 }
448
449 caller_done = true;
450
451 if (is_blocking == false)
452 accepting_thread.join();
453
454 if (m_is_tracing)
455 {
456 std::cerr << "Socket state caller: " << m_socket_state[srt_getsockstate(m_caller_socket)] << "\n";
457 std::cerr << "Socket state listener: " << m_socket_state[srt_getsockstate(m_listener_socket)] << "\n";
458 std::cerr << "KM State caller: " << m_km_state[GetKMState(m_caller_socket)] << '\n';
459 std::cerr << "RCV KM State caller: " << m_km_state[GetSocetkOption(m_caller_socket, SRTO_RCVKMSTATE)] << '\n';
460 std::cerr << "SND KM State caller: " << m_km_state[GetSocetkOption(m_caller_socket, SRTO_SNDKMSTATE)] << '\n';
461 std::cerr << "KM State listener: " << m_km_state[GetKMState(m_listener_socket)] << '\n';
462 }
463
464 // If a blocking call to srt_connect() returned error, then the state is not valid,
465 // but we still check it because we know what it should be. This way we may see potential changes in the core behavior.
466 if (is_blocking)
467 {
468 EXPECT_EQ(srt_getsockstate(m_caller_socket), expect.socket_state[CHECK_SOCKET_CALLER]);
469 }
470 // A caller socket, regardless of the mode, if it's not expected to be connected, check negatively.
471 if (expect.socket_state[CHECK_SOCKET_CALLER] == SRTS_CONNECTED)
472 {
473 EXPECT_EQ(srt_getsockstate(m_caller_socket), SRTS_CONNECTED);
474 }
475 else
476 {
477 // If the socket is not expected to be connected (might be CONNECTING),
478 // then it is ok if it's CONNECTING or BROKEN.
479 EXPECT_NE(srt_getsockstate(m_caller_socket), SRTS_CONNECTED);
480 }
481
482 EXPECT_EQ(GetSocetkOption(m_caller_socket, SRTO_RCVKMSTATE), expect.km_state[CHECK_SOCKET_CALLER]);
483
484 EXPECT_EQ(srt_getsockstate(m_listener_socket), SRTS_LISTENING);
485 EXPECT_EQ(GetKMState(m_listener_socket), SRT_KM_S_UNSECURED);
486
487 if (is_blocking)
488 {
489 // srt_accept() has no timeout, so we have to close the socket and wait for the thread to exit.
490 // Just give it some time and close the socket.
491 std::this_thread::sleep_for(std::chrono::milliseconds(50));
492 ASSERT_NE(srt_close(m_listener_socket), SRT_ERROR);
493 accepting_thread.join();
494 }
495 }
496
497
498 private:
499 // put in any custom data members that you need
500
501 SRTSOCKET m_caller_socket = SRT_INVALID_SOCK;
502 SRTSOCKET m_listener_socket = SRT_INVALID_SOCK;
503
504 int m_pollid = 0;
505
506 const bool s_yes = true;
507 const bool s_no = false;
508
509 const bool m_is_tracing = false;
510 static const char* m_km_state[];
511 static const char* const* m_socket_state;
512 };
513
514
515
516 template<>
WaitOnEpoll(const TestResultBlocking &)517 int TestEnforcedEncryption::WaitOnEpoll<TestResultBlocking>(const TestResultBlocking &)
518 {
519 return SRT_EPOLL_IN;
520 }
521
PrintEpollEvent(std::ostream & os,int events,int et_events)522 static std::ostream& PrintEpollEvent(std::ostream& os, int events, int et_events)
523 {
524 using namespace std;
525
526 static pair<int, const char*> const namemap [] = {
527 make_pair(SRT_EPOLL_IN, "R"),
528 make_pair(SRT_EPOLL_OUT, "W"),
529 make_pair(SRT_EPOLL_ERR, "E"),
530 make_pair(SRT_EPOLL_UPDATE, "U")
531 };
532
533 int N = Size(namemap);
534
535 for (int i = 0; i < N; ++i)
536 {
537 if (events & namemap[i].first)
538 {
539 os << "[";
540 if (et_events & namemap[i].first)
541 os << "^";
542 os << namemap[i].second << "]";
543 }
544 }
545
546 return os;
547 }
548
549 template<>
WaitOnEpoll(const TestResultNonBlocking & expect)550 int TestEnforcedEncryption::WaitOnEpoll<TestResultNonBlocking>(const TestResultNonBlocking &expect)
551 {
552 const int default_len = 3;
553 SRT_EPOLL_EVENT ready[default_len];
554 const int epoll_res = srt_epoll_uwait(m_pollid, ready, default_len, 500);
555 std::cerr << "Epoll wait result: " << epoll_res;
556 if (epoll_res > 0)
557 {
558 std::cerr << " FOUND: @" << ready[0].fd << " in ";
559 PrintEpollEvent(std::cerr, ready[0].events, 0);
560 }
561 else
562 {
563 std::cerr << " NOTHING READY";
564 }
565 std::cerr << std::endl;
566
567 // Expect: -2 means that
568 if (expect.epoll_wait_ret != IGNORE_EPOLL)
569 {
570 EXPECT_EQ(epoll_res, expect.epoll_wait_ret);
571 }
572
573 if (epoll_res == SRT_ERROR)
574 {
575 std::cerr << "Epoll returned error: " << srt_getlasterror_str() << " (code " << srt_getlasterror(NULL) << ")\n";
576 return 0;
577 }
578
579 // We have exactly one socket here and we expect to return
580 // only this one, or nothing.
581 if (epoll_res != 0)
582 {
583 EXPECT_EQ(epoll_res, 1);
584 EXPECT_EQ(ready[0].fd, m_listener_socket);
585 }
586
587 return epoll_res == 0 ? 0 : int(ready[0].events);
588 }
589
590
591 template<>
GetTestMatrix(TEST_CASE test_case) const592 const TestCase<TestResultBlocking>& TestEnforcedEncryption::GetTestMatrix<TestResultBlocking>(TEST_CASE test_case) const
593 {
594 return g_test_matrix_blocking[test_case];
595 }
596
597 template<>
GetTestMatrix(TEST_CASE test_case) const598 const TestCase<TestResultNonBlocking>& TestEnforcedEncryption::GetTestMatrix<TestResultNonBlocking>(TEST_CASE test_case) const
599 {
600 return g_test_matrix_non_blocking[test_case];
601 }
602
603
604
605 const char* TestEnforcedEncryption::m_km_state[] = {
606 "SRT_KM_S_UNSECURED (0)", //No encryption
607 "SRT_KM_S_SECURING (1)", //Stream encrypted, exchanging Keying Material
608 "SRT_KM_S_SECURED (2)", //Stream encrypted, keying Material exchanged, decrypting ok.
609 "SRT_KM_S_NOSECRET (3)", //Stream encrypted and no secret to decrypt Keying Material
610 "SRT_KM_S_BADSECRET (4)" //Stream encrypted and wrong secret, cannot decrypt Keying Material
611 };
612
613
614 static const char* const socket_state_array[] = {
615 "IGNORE_SRTS",
616 "SRTS_INVALID",
617 "SRTS_INIT",
618 "SRTS_OPENED",
619 "SRTS_LISTENING",
620 "SRTS_CONNECTING",
621 "SRTS_CONNECTED",
622 "SRTS_BROKEN",
623 "SRTS_CLOSING",
624 "SRTS_CLOSED",
625 "SRTS_NONEXIST"
626 };
627
628 // A trick that allows the array to be indexed by -1
629 const char* const* TestEnforcedEncryption::m_socket_state = socket_state_array+1;
630
631 /**
632 * @fn TestEnforcedEncryption.PasswordLength
633 * @brief The password length should belong to the interval of [10; 80]
634 */
TEST_F(TestEnforcedEncryption,PasswordLength)635 TEST_F(TestEnforcedEncryption, PasswordLength)
636 {
637 #ifdef SRT_ENABLE_ENCRYPTION
638 // Empty string sets password to none
639 EXPECT_EQ(SetPassword(PEER_CALLER, std::string("")), SRT_SUCCESS);
640 EXPECT_EQ(SetPassword(PEER_LISTENER, std::string("")), SRT_SUCCESS);
641
642 EXPECT_EQ(SetPassword(PEER_CALLER, std::string("too_short")), SRT_ERROR);
643 EXPECT_EQ(SetPassword(PEER_LISTENER, std::string("too_short")), SRT_ERROR);
644
645 std::string long_pwd;
646 const int pwd_len = 81; // 80 is the maximum password length accepted
647 long_pwd.reserve(pwd_len);
648 const char start_char = '!';
649
650 // Please ensure to be within the valid ASCII symbols!
651 ASSERT_LT(pwd_len + start_char, 126);
652 for (int i = 0; i < pwd_len; ++i)
653 long_pwd.push_back(static_cast<char>(start_char + i));
654
655 EXPECT_EQ(SetPassword(PEER_CALLER, long_pwd), SRT_ERROR);
656 EXPECT_EQ(SetPassword(PEER_LISTENER, long_pwd), SRT_ERROR);
657
658 EXPECT_EQ(SetPassword(PEER_CALLER, std::string("proper_len")), SRT_SUCCESS);
659 EXPECT_EQ(SetPassword(PEER_LISTENER, std::string("proper_length")), SRT_SUCCESS);
660 #else
661 EXPECT_EQ(SetPassword(PEER_CALLER, "whateverpassword"), SRT_ERROR);
662 #endif
663 }
664
665
666 /**
667 * @fn TestEnforcedEncryption.SetGetDefault
668 * @brief The default value for the enforced encryption should be ON
669 */
TEST_F(TestEnforcedEncryption,SetGetDefault)670 TEST_F(TestEnforcedEncryption, SetGetDefault)
671 {
672 EXPECT_EQ(GetEnforcedEncryption(PEER_CALLER), true);
673 EXPECT_EQ(GetEnforcedEncryption(PEER_LISTENER), true);
674
675 EXPECT_EQ(SetEnforcedEncryption(PEER_CALLER, false), SRT_SUCCESS);
676 EXPECT_EQ(SetEnforcedEncryption(PEER_LISTENER, false), SRT_SUCCESS);
677
678 EXPECT_EQ(GetEnforcedEncryption(PEER_CALLER), false);
679 EXPECT_EQ(GetEnforcedEncryption(PEER_LISTENER), false);
680 }
681
682
// Defines TestEnforcedEncryption.<CASE_NUMBER>_Blocking_<DESC>, which runs the
// TEST_<CASE_NUMBER> scenario with blocking sockets.
#define CREATE_TEST_CASE_BLOCKING(CASE_NUMBER, DESC)                     \
    TEST_F(TestEnforcedEncryption, CASE_NUMBER##_Blocking_##DESC)         \
    {                                                                     \
        TestConnect<TestResultBlocking>(TEST_##CASE_NUMBER);              \
    }

// Defines TestEnforcedEncryption.<CASE_NUMBER>_NonBlocking_<DESC>, which runs the
// TEST_<CASE_NUMBER> scenario with non-blocking sockets.
#define CREATE_TEST_CASE_NONBLOCKING(CASE_NUMBER, DESC)                  \
    TEST_F(TestEnforcedEncryption, CASE_NUMBER##_NonBlocking_##DESC)      \
    {                                                                     \
        TestConnect<TestResultNonBlocking>(TEST_##CASE_NUMBER);           \
    }


// Defines both variants of a case: non-blocking first, then blocking.
#define CREATE_TEST_CASES(CASE_NUMBER, DESC)                             \
    CREATE_TEST_CASE_NONBLOCKING(CASE_NUMBER, DESC)                       \
    CREATE_TEST_CASE_BLOCKING(CASE_NUMBER, DESC)
697
// Test instantiations: each CREATE_TEST_CASES line defines a non-blocking and a
// blocking test for one matrix row. Cases _1.._4 set a password on at least one
// peer, so they are built only with SRT_ENABLE_ENCRYPTION; case _5 uses no
// password on either side and is always built.

// Group A: ENFORCEDENC on for caller and listener.
#ifdef SRT_ENABLE_ENCRYPTION
CREATE_TEST_CASES(CASE_A_1, Enforced_On_On_Pwd_Set_Set_Match)
CREATE_TEST_CASES(CASE_A_2, Enforced_On_On_Pwd_Set_Set_Mismatch)
CREATE_TEST_CASES(CASE_A_3, Enforced_On_On_Pwd_Set_None)
CREATE_TEST_CASES(CASE_A_4, Enforced_On_On_Pwd_None_Set)
#endif
CREATE_TEST_CASES(CASE_A_5, Enforced_On_On_Pwd_None_None)

// Group B: ENFORCEDENC on for caller, off for listener.
#ifdef SRT_ENABLE_ENCRYPTION
CREATE_TEST_CASES(CASE_B_1, Enforced_On_Off_Pwd_Set_Set_Match)
CREATE_TEST_CASES(CASE_B_2, Enforced_On_Off_Pwd_Set_Set_Mismatch)
CREATE_TEST_CASES(CASE_B_3, Enforced_On_Off_Pwd_Set_None)
CREATE_TEST_CASES(CASE_B_4, Enforced_On_Off_Pwd_None_Set)
#endif
CREATE_TEST_CASES(CASE_B_5, Enforced_On_Off_Pwd_None_None)

// Group C: ENFORCEDENC off for caller, on for listener.
#ifdef SRT_ENABLE_ENCRYPTION
CREATE_TEST_CASES(CASE_C_1, Enforced_Off_On_Pwd_Set_Set_Match)
CREATE_TEST_CASES(CASE_C_2, Enforced_Off_On_Pwd_Set_Set_Mismatch)
CREATE_TEST_CASES(CASE_C_3, Enforced_Off_On_Pwd_Set_None)
CREATE_TEST_CASES(CASE_C_4, Enforced_Off_On_Pwd_None_Set)
#endif
CREATE_TEST_CASES(CASE_C_5, Enforced_Off_On_Pwd_None_None)

// Group D: ENFORCEDENC off for caller and listener.
#ifdef SRT_ENABLE_ENCRYPTION
CREATE_TEST_CASES(CASE_D_1, Enforced_Off_Off_Pwd_Set_Set_Match)
CREATE_TEST_CASES(CASE_D_2, Enforced_Off_Off_Pwd_Set_Set_Mismatch)
CREATE_TEST_CASES(CASE_D_3, Enforced_Off_Off_Pwd_Set_None)
CREATE_TEST_CASES(CASE_D_4, Enforced_Off_Off_Pwd_None_Set)
#endif
CREATE_TEST_CASES(CASE_D_5, Enforced_Off_Off_Pwd_None_None)
729
730