1 /*
2  * SRT - Secure, Reliable, Transport
3  * Copyright (c) 2018 Haivision Systems Inc.
4  *
5  * This Source Code Form is subject to the terms of the Mozilla Public
6  * License, v. 2.0. If a copy of the MPL was not distributed with this
7  * file, You can obtain one at http://mozilla.org/MPL/2.0/.
8  *
9  * Written by:
10  *             Haivision Systems Inc.
11  */
12 
#include <gtest/gtest.h>

#include <atomic>
#include <chrono>
#include <condition_variable>
#include <cstring>
#include <iostream>
#include <mutex>
#include <string>
#include <thread>
#include <type_traits>
#include <utility>

#include "srt.h"
#include "sync.h"
20 
21 
22 
// Identifies one of the two endpoints of a test connection.
// The enumerators are also used as indices into per-peer arrays
// (e.g. TestCase::enforcedenc and TestCase::password).
enum PEER_TYPE
{
    PEER_CALLER,    // 0: the connecting side
    PEER_LISTENER,  // 1: the listening side
    PEER_COUNT      // 2: number of peers (array size)
};
29 
30 
// Identifies which socket an expected state/KM-state value refers to.
// The enumerators are used as indices into the socket_state[] and km_state[]
// arrays of the expected-result structures.
enum CHECK_SOCKET_TYPE
{
    CHECK_SOCKET_CALLER,    // 0: the caller socket
    CHECK_SOCKET_ACCEPTED,  // 1: the socket returned by srt_accept()
    CHECK_SOCKET_COUNT      // 2: number of checked sockets (array size)
};
37 
38 
// Row index into the test matrices (see GetTestMatrix()).
// The letter selects the ENFORCEDENC combination, the digit the password
// combination; enumerator order must match the row order of the matrices.
enum TEST_CASE
{
    // A: enforced encryption ON on both peers
    TEST_CASE_A_1,
    TEST_CASE_A_2,
    TEST_CASE_A_3,
    TEST_CASE_A_4,
    TEST_CASE_A_5,

    // B: enforced encryption ON on caller, OFF on listener
    TEST_CASE_B_1,
    TEST_CASE_B_2,
    TEST_CASE_B_3,
    TEST_CASE_B_4,
    TEST_CASE_B_5,

    // C: enforced encryption OFF on caller, ON on listener
    TEST_CASE_C_1,
    TEST_CASE_C_2,
    TEST_CASE_C_3,
    TEST_CASE_C_4,
    TEST_CASE_C_5,

    // D: enforced encryption OFF on both peers
    TEST_CASE_D_1,
    TEST_CASE_D_2,
    TEST_CASE_D_3,
    TEST_CASE_D_4,
    TEST_CASE_D_5
};
62 
63 
// Expected outcome of one non-blocking test case.
// NOTE: this is a plain aggregate; the member order is relied upon by the
// positional initializers in g_test_matrix_non_blocking.
struct TestResultNonBlocking
{
    int     connect_ret;                        // expected return value of srt_connect()
    int     accept_ret;                         // SRT_INVALID_SOCK: accept must fail; -2: result not checked; otherwise a valid socket is expected
    int     epoll_wait_ret;                     // expected result of the epoll wait, or IGNORE_EPOLL to skip the check
    int     epoll_event;                        // expected epoll event (not read by the code visible in this file)
    int     socket_state[CHECK_SOCKET_COUNT];   // expected socket states, indexed by CHECK_SOCKET_TYPE
    int     km_state    [CHECK_SOCKET_COUNT];   // expected KM states, indexed by CHECK_SOCKET_TYPE
};
73 
74 
// Expected outcome of one blocking test case.
// NOTE: this is a plain aggregate; the member order is relied upon by the
// positional initializers in g_test_matrix_blocking.
struct TestResultBlocking
{
    int     connect_ret;                        // expected return value of srt_connect()
    int     accept_ret;                         // SRT_INVALID_SOCK: accept must fail; -2: result not checked; otherwise a valid socket is expected
    int     socket_state[CHECK_SOCKET_COUNT];   // expected socket states, indexed by CHECK_SOCKET_TYPE
    int     km_state[CHECK_SOCKET_COUNT];       // expected KM states, indexed by CHECK_SOCKET_TYPE
};
82 
83 
// One row of a test matrix: the input settings of both peers and the
// expected result (TResult is TestResultBlocking or TestResultNonBlocking).
// The per-peer arrays are indexed by PEER_TYPE.
template<typename TResult>
struct TestCase
{
    bool                enforcedenc [PEER_COUNT];   // SRTO_ENFORCEDENCRYPTION value for each peer
    const std::string  (&password)[PEER_COUNT];     // passphrase for each peer; a reference to an array, bound by the row initializer
    TResult             expected_result;            // what the test is expected to observe
};
91 
92 typedef TestCase<TestResultNonBlocking>  TestCaseNonBlocking;
93 typedef TestCase<TestResultBlocking>     TestCaseBlocking;
94 
95 
96 
// Passphrases used by the test matrices:
// two different valid passphrases and an empty one (no passphrase set).
static const std::string s_pwd_a  = "s!t@r#i$c^t";
static const std::string s_pwd_b  = "s!t@r#i$c^tu";
static const std::string s_pwd_no = "";
100 
101 
102 
103 /*
104  * TESTING SCENARIO
105  * Both peers exchange HandShake v5.
106  * Listener is sender   in a non-blocking mode
107  * Caller   is receiver in a non-blocking mode
 *
109  * Cases B.2-B.4 are specific. Here we have incompatible password settings, but
110  * listener accepts it, while caller rejects it. In this case we have a short-living
111  * confusion state: The connection is accepted on the listener side, and the listener
112  * sends back the conclusion handshake, but caller will reject it.
113  *
114  * Because of that, we should ignore what will happen in the listener as this is
115  * just a matter of luck: if the listener thread is lucky, it will report the socket
116  * to accept, so epoll will signal it and accept will report it, and moreover, further
117  * good luck on this socket would make the state check return SRTS_CONNECTED. Without
118  * this good luck, the caller might be quick enough to reject the handshake and send
119  * the UMSG_SHUTDOWN packet to the peer. If it gets with it before acceptance, it will
120  * withdraw the socket before it could be reported by accept.
121  *
122  * Still, we check predictable things here, so we accept two possibilities:
123  * - The accepted socket wasn't reported at all
124  * - The accepted socket was reported, and after `srt_connect` is done, it should turn to SRTS_BROKEN.
125  *
126  * This embraces both cases when the accepted socket was broken in the beginning, and when it was CONNECTED
127  * in the beginning, but broke soon thereafter.
128  *
129  * This behavior is predicted and accepted - it's also the reason that setting ENFORCEDENC to false is
130  * NOT RECOMMENDED on a listener socket that isn't intended to accept only connections from known callers
131  * that are known to have set this flag also to false.
132  *
133  * In the cases C.2-C.4 it is the listener who rejects the connection, so we don't have an accepted socket
134  * and the situation is always the same and clear in the beginning. The caller cannot continue with the
135  * connection after listener accepted it, even if it tolerates incompatible password settings.
136  */
137 
// Sentinels used in the test matrices in place of a real expected value.
static const int IGNORE_EPOLL = -2;  // do not check the result of the epoll wait
static const int IGNORE_SRTS  = -1;  // do not check the state of the accepted socket
140 
// Inputs and expected results for the non-blocking scenario.
// Rows are looked up by TEST_CASE value (see GetTestMatrix<TestResultNonBlocking>()),
// so the row order must match the order of the TEST_CASE enumerators.
// IGNORE_EPOLL / IGNORE_SRTS mark values that are not checked.
const TestCaseNonBlocking g_test_matrix_non_blocking[] =
{
        // ENFORCEDENC       |  Password           |                                | EPoll wait                       | socket_state                            |  KM State
        // caller | listener |  caller  | listener |  connect_ret   accept_ret      |  ret         | event             | caller              accepted |  caller              listener
/*A.1 */ { {true,     true  }, {s_pwd_a,   s_pwd_a}, { SRT_SUCCESS,                0,             1,  SRT_EPOLL_IN,  {SRTS_CONNECTED, SRTS_CONNECTED}, {SRT_KM_S_SECURED,     SRT_KM_S_SECURED}}},
/*A.2 */ { {true,     true  }, {s_pwd_a,   s_pwd_b}, { SRT_SUCCESS, SRT_INVALID_SOCK,             0,  0,             {SRTS_BROKEN,       IGNORE_SRTS}, {SRT_KM_S_UNSECURED,        IGNORE_SRTS}}},
/*A.3 */ { {true,     true  }, {s_pwd_a,  s_pwd_no}, { SRT_SUCCESS, SRT_INVALID_SOCK,             0,  0,             {SRTS_BROKEN,       IGNORE_SRTS}, {SRT_KM_S_UNSECURED,        IGNORE_SRTS}}},
/*A.4 */ { {true,     true  }, {s_pwd_no,  s_pwd_b}, { SRT_SUCCESS, SRT_INVALID_SOCK,             0,  0,             {SRTS_BROKEN,       IGNORE_SRTS}, {SRT_KM_S_UNSECURED,        IGNORE_SRTS}}},
/*A.5 */ { {true,     true  }, {s_pwd_no, s_pwd_no}, { SRT_SUCCESS,                0,             1,  SRT_EPOLL_IN,  {SRTS_CONNECTED, SRTS_CONNECTED}, {SRT_KM_S_UNSECURED, SRT_KM_S_UNSECURED}}},

/*B.1 */ { {true,    false  }, {s_pwd_a,   s_pwd_a}, { SRT_SUCCESS,                0,             1,  SRT_EPOLL_IN,  {SRTS_CONNECTED, SRTS_CONNECTED}, {SRT_KM_S_SECURED,     SRT_KM_S_SECURED}}},
/*B.2 */ { {true,    false  }, {s_pwd_a,   s_pwd_b}, { SRT_SUCCESS,                0,  IGNORE_EPOLL,  0,             {SRTS_CONNECTING,   SRTS_BROKEN}, {SRT_KM_S_BADSECRET, SRT_KM_S_BADSECRET}}},
/*B.3 */ { {true,    false  }, {s_pwd_a,  s_pwd_no}, { SRT_SUCCESS,                0,  IGNORE_EPOLL,  0,             {SRTS_CONNECTING,   SRTS_BROKEN}, {SRT_KM_S_UNSECURED, SRT_KM_S_UNSECURED}}},
/*B.4 */ { {true,    false  }, {s_pwd_no,  s_pwd_b}, { SRT_SUCCESS,                0,  IGNORE_EPOLL,  0,             {SRTS_CONNECTING,   SRTS_BROKEN}, {SRT_KM_S_UNSECURED,  SRT_KM_S_NOSECRET}}},
/*B.5 */ { {true,    false  }, {s_pwd_no, s_pwd_no}, { SRT_SUCCESS,                0,             1,  SRT_EPOLL_IN,  {SRTS_CONNECTED, SRTS_CONNECTED}, {SRT_KM_S_UNSECURED, SRT_KM_S_UNSECURED}}},

/*C.1 */ { {false,    true  }, {s_pwd_a,   s_pwd_a}, { SRT_SUCCESS,                0,             1,  SRT_EPOLL_IN,  {SRTS_CONNECTED, SRTS_CONNECTED}, {SRT_KM_S_SECURED,     SRT_KM_S_SECURED}}},
/*C.2 */ { {false,    true  }, {s_pwd_a,   s_pwd_b}, { SRT_SUCCESS, SRT_INVALID_SOCK,             0,  0,             {SRTS_BROKEN,       IGNORE_SRTS}, {SRT_KM_S_UNSECURED,        IGNORE_SRTS}}},
/*C.3 */ { {false,    true  }, {s_pwd_a,  s_pwd_no}, { SRT_SUCCESS, SRT_INVALID_SOCK,             0,  0,             {SRTS_BROKEN,       IGNORE_SRTS}, {SRT_KM_S_UNSECURED,        IGNORE_SRTS}}},
/*C.4 */ { {false,    true  }, {s_pwd_no,  s_pwd_b}, { SRT_SUCCESS, SRT_INVALID_SOCK,             0,  0,             {SRTS_BROKEN,       IGNORE_SRTS}, {SRT_KM_S_UNSECURED,        IGNORE_SRTS}}},
/*C.5 */ { {false,    true  }, {s_pwd_no, s_pwd_no}, { SRT_SUCCESS,                0,             1,  SRT_EPOLL_IN,  {SRTS_CONNECTED, SRTS_CONNECTED}, {SRT_KM_S_UNSECURED, SRT_KM_S_UNSECURED}}},

/*D.1 */ { {false,   false  }, {s_pwd_a,   s_pwd_a}, { SRT_SUCCESS,                0,             1,  SRT_EPOLL_IN,  {SRTS_CONNECTED, SRTS_CONNECTED}, {SRT_KM_S_SECURED,     SRT_KM_S_SECURED}}},
/*D.2 */ { {false,   false  }, {s_pwd_a,   s_pwd_b}, { SRT_SUCCESS,                0,             1,  SRT_EPOLL_IN,  {SRTS_CONNECTED, SRTS_CONNECTED}, {SRT_KM_S_BADSECRET, SRT_KM_S_BADSECRET}}},
/*D.3 */ { {false,   false  }, {s_pwd_a,  s_pwd_no}, { SRT_SUCCESS,                0,             1,  SRT_EPOLL_IN,  {SRTS_CONNECTED, SRTS_CONNECTED}, {SRT_KM_S_UNSECURED, SRT_KM_S_UNSECURED}}},
/*D.4 */ { {false,   false  }, {s_pwd_no,  s_pwd_b}, { SRT_SUCCESS,                0,             1,  SRT_EPOLL_IN,  {SRTS_CONNECTED, SRTS_CONNECTED}, {SRT_KM_S_NOSECRET,   SRT_KM_S_NOSECRET}}},
/*D.5 */ { {false,   false  }, {s_pwd_no, s_pwd_no}, { SRT_SUCCESS,                0,             1,  SRT_EPOLL_IN,  {SRTS_CONNECTED, SRTS_CONNECTED}, {SRT_KM_S_UNSECURED, SRT_KM_S_UNSECURED}}},
};
169 
170 
171 /*
172  * TESTING SCENARIO
173  * Both peers exchange HandShake v5.
174  * Listener is sender   in a blocking mode
175  * Caller   is receiver in a blocking mode
176  *
177  * In the cases B.2-B.4 the caller will reject the connection due to the enforced encryption check
178  * of the HS response from the listener on the stage of the KM response check.
 * Meanwhile the listener accepts the connection and reports it as connected, so the caller sends UMSG_SHUTDOWN
 * to notify the listener that it has closed the connection. The accepted socket then gets the SRTS_BROKEN state.
181  * For these cases a special accept_ret = -2 is used, that allows the accepted socket to be broken or already closed.
182  *
183  * In the cases C.2-C.4 it is the listener who rejects the connection, so we don't have an accepted socket.
184  */
// Inputs and expected results for the blocking scenario.
// Rows are looked up by TEST_CASE value (see GetTestMatrix<TestResultBlocking>()),
// so the row order must match the order of the TEST_CASE enumerators.
// In the "accepted"/"listener" columns -1 equals IGNORE_SRTS (state not checked);
// accept_ret == -2 means the result of srt_accept() is not checked.
const TestCaseBlocking g_test_matrix_blocking[] =
{
        // ENFORCEDENC       |  Password           |                                      | socket_state                   |  KM State
        // caller | listener |  caller  | listener |  connect_ret         accept_ret      | caller                accepted |  caller              listener
/*A.1 */ { {true,     true  }, {s_pwd_a,   s_pwd_a}, { SRT_SUCCESS,                     0, {SRTS_CONNECTED, SRTS_CONNECTED}, {SRT_KM_S_SECURED,     SRT_KM_S_SECURED}}},
/*A.2 */ { {true,     true  }, {s_pwd_a,   s_pwd_b}, { SRT_INVALID_SOCK, SRT_INVALID_SOCK, {SRTS_OPENED,                -1}, {SRT_KM_S_UNSECURED,                 -1}}},
/*A.3 */ { {true,     true  }, {s_pwd_a,  s_pwd_no}, { SRT_INVALID_SOCK, SRT_INVALID_SOCK, {SRTS_OPENED,                -1}, {SRT_KM_S_UNSECURED,                 -1}}},
/*A.4 */ { {true,     true  }, {s_pwd_no,  s_pwd_b}, { SRT_INVALID_SOCK, SRT_INVALID_SOCK, {SRTS_OPENED,                -1}, {SRT_KM_S_UNSECURED,                 -1}}},
/*A.5 */ { {true,     true  }, {s_pwd_no, s_pwd_no}, { SRT_SUCCESS,                     0, {SRTS_CONNECTED, SRTS_CONNECTED}, {SRT_KM_S_UNSECURED, SRT_KM_S_UNSECURED}}},

/*B.1 */ { {true,    false  }, {s_pwd_a,   s_pwd_a}, { SRT_SUCCESS,                     0, {SRTS_CONNECTED, SRTS_CONNECTED}, {SRT_KM_S_SECURED,     SRT_KM_S_SECURED}}},
/*B.2 */ { {true,    false  }, {s_pwd_a,   s_pwd_b}, { SRT_INVALID_SOCK,               -2, {SRTS_OPENED,       SRTS_BROKEN}, {SRT_KM_S_BADSECRET, SRT_KM_S_BADSECRET}}},
/*B.3 */ { {true,    false  }, {s_pwd_a,  s_pwd_no}, { SRT_INVALID_SOCK,               -2, {SRTS_OPENED,       SRTS_BROKEN}, {SRT_KM_S_UNSECURED, SRT_KM_S_UNSECURED}}},
/*B.4 */ { {true,    false  }, {s_pwd_no,  s_pwd_b}, { SRT_INVALID_SOCK,               -2, {SRTS_OPENED,       SRTS_BROKEN}, {SRT_KM_S_UNSECURED,  SRT_KM_S_NOSECRET}}},
/*B.5 */ { {true,    false  }, {s_pwd_no, s_pwd_no}, { SRT_SUCCESS,                     0, {SRTS_CONNECTED, SRTS_CONNECTED}, {SRT_KM_S_UNSECURED, SRT_KM_S_UNSECURED}}},

/*C.1 */ { {false,    true  }, {s_pwd_a,   s_pwd_a}, { SRT_SUCCESS,                     0, {SRTS_CONNECTED, SRTS_CONNECTED}, {SRT_KM_S_SECURED,     SRT_KM_S_SECURED}}},
/*C.2 */ { {false,    true  }, {s_pwd_a,   s_pwd_b}, { SRT_INVALID_SOCK, SRT_INVALID_SOCK, {SRTS_OPENED,                -1}, {SRT_KM_S_UNSECURED,                 -1}}},
/*C.3 */ { {false,    true  }, {s_pwd_a,  s_pwd_no}, { SRT_INVALID_SOCK, SRT_INVALID_SOCK, {SRTS_OPENED,                -1}, {SRT_KM_S_UNSECURED,                 -1}}},
/*C.4 */ { {false,    true  }, {s_pwd_no,  s_pwd_b}, { SRT_INVALID_SOCK, SRT_INVALID_SOCK, {SRTS_OPENED,                -1}, {SRT_KM_S_UNSECURED,                 -1}}},
/*C.5 */ { {false,    true  }, {s_pwd_no, s_pwd_no}, { SRT_SUCCESS,                     0, {SRTS_CONNECTED, SRTS_CONNECTED}, {SRT_KM_S_UNSECURED, SRT_KM_S_UNSECURED}}},

/*D.1 */ { {false,   false  }, {s_pwd_a,   s_pwd_a}, { SRT_SUCCESS,                     0, {SRTS_CONNECTED, SRTS_CONNECTED}, {SRT_KM_S_SECURED,     SRT_KM_S_SECURED}}},
/*D.2 */ { {false,   false  }, {s_pwd_a,   s_pwd_b}, { SRT_SUCCESS,                     0, {SRTS_CONNECTED, SRTS_CONNECTED}, {SRT_KM_S_BADSECRET, SRT_KM_S_BADSECRET}}},
/*D.3 */ { {false,   false  }, {s_pwd_a,  s_pwd_no}, { SRT_SUCCESS,                     0, {SRTS_CONNECTED, SRTS_CONNECTED}, {SRT_KM_S_UNSECURED, SRT_KM_S_UNSECURED}}},
/*D.4 */ { {false,   false  }, {s_pwd_no,  s_pwd_b}, { SRT_SUCCESS,                     0, {SRTS_CONNECTED, SRTS_CONNECTED}, {SRT_KM_S_NOSECRET,   SRT_KM_S_NOSECRET}}},
/*D.5 */ { {false,   false  }, {s_pwd_no, s_pwd_no}, { SRT_SUCCESS,                     0, {SRTS_CONNECTED, SRTS_CONNECTED}, {SRT_KM_S_UNSECURED, SRT_KM_S_UNSECURED}}},
};
213 
214 
215 
216 class TestEnforcedEncryption
217     : public ::testing::Test
218 {
219 protected:
TestEnforcedEncryption()220     TestEnforcedEncryption()
221     {
222         // initialization code here
223     }
224 
~TestEnforcedEncryption()225     ~TestEnforcedEncryption()
226     {
227         // cleanup any pending stuff, but no exceptions allowed
228     }
229 
230 protected:
231 
232     // SetUp() is run immediately before a test starts.
SetUp()233     void SetUp()
234     {
235         ASSERT_EQ(srt_startup(), 0);
236 
237         m_pollid = srt_epoll_create();
238         ASSERT_GE(m_pollid, 0);
239 
240         m_caller_socket = srt_create_socket();
241         ASSERT_NE(m_caller_socket, SRT_INVALID_SOCK);
242 
243         ASSERT_NE(srt_setsockflag(m_caller_socket,    SRTO_SENDER,    &s_yes, sizeof s_yes), SRT_ERROR);
244         ASSERT_NE(srt_setsockopt (m_caller_socket, 0, SRTO_TSBPDMODE, &s_yes, sizeof s_yes), SRT_ERROR);
245 
246         m_listener_socket = srt_create_socket();
247         ASSERT_NE(m_listener_socket, SRT_INVALID_SOCK);
248 
249         ASSERT_NE(srt_setsockflag(m_listener_socket,    SRTO_SENDER,    &s_no,  sizeof s_no),  SRT_ERROR);
250         ASSERT_NE(srt_setsockopt (m_listener_socket, 0, SRTO_TSBPDMODE, &s_yes, sizeof s_yes), SRT_ERROR);
251 
252         // Will use this epoll to wait for srt_accept(...)
253         const int epoll_out = SRT_EPOLL_IN | SRT_EPOLL_ERR;
254         ASSERT_NE(srt_epoll_add_usock(m_pollid, m_listener_socket, &epoll_out), SRT_ERROR);
255     }
256 
TearDown()257     void TearDown()
258     {
259         // Code here will be called just after the test completes.
260         // OK to throw exceptions from here if needed.
261         ASSERT_NE(srt_close(m_caller_socket),   SRT_ERROR);
262         ASSERT_NE(srt_close(m_listener_socket), SRT_ERROR);
263         srt_cleanup();
264     }
265 
266 
267 public:
268 
269 
SetEnforcedEncryption(PEER_TYPE peer,bool value)270     int SetEnforcedEncryption(PEER_TYPE peer, bool value)
271     {
272         const SRTSOCKET &socket = peer == PEER_CALLER ? m_caller_socket : m_listener_socket;
273         return srt_setsockopt(socket, 0, SRTO_ENFORCEDENCRYPTION, value ? &s_yes : &s_no, sizeof s_yes);
274     }
275 
276 
GetEnforcedEncryption(PEER_TYPE peer_type)277     bool GetEnforcedEncryption(PEER_TYPE peer_type)
278     {
279         const SRTSOCKET socket = peer_type == PEER_CALLER ? m_caller_socket : m_listener_socket;
280         bool optval;
281         int  optlen = sizeof optval;
282         EXPECT_EQ(srt_getsockopt(socket, 0, SRTO_ENFORCEDENCRYPTION, (void*)&optval, &optlen), SRT_SUCCESS);
283         return optval ? true : false;
284     }
285 
286 
SetPassword(PEER_TYPE peer_type,const std::basic_string<char> & pwd)287     int SetPassword(PEER_TYPE peer_type, const std::basic_string<char> &pwd)
288     {
289         const SRTSOCKET socket = peer_type == PEER_CALLER ? m_caller_socket : m_listener_socket;
290         return srt_setsockopt(socket, 0, SRTO_PASSPHRASE, pwd.c_str(), (int) pwd.size());
291     }
292 
293 
GetKMState(SRTSOCKET socket)294     int GetKMState(SRTSOCKET socket)
295     {
296         int km_state = 0;
297         int opt_size = sizeof km_state;
298         EXPECT_EQ(srt_getsockopt(socket, 0, SRTO_KMSTATE, reinterpret_cast<void*>(&km_state), &opt_size), SRT_SUCCESS);
299 
300         return km_state;
301     }
302 
303 
GetSocetkOption(SRTSOCKET socket,SRT_SOCKOPT opt)304     int GetSocetkOption(SRTSOCKET socket, SRT_SOCKOPT opt)
305     {
306         int val = 0;
307         int size = sizeof val;
308         EXPECT_EQ(srt_getsockopt(socket, 0, opt, reinterpret_cast<void*>(&val), &size), SRT_SUCCESS);
309 
310         return val;
311     }
312 
313 
314     template<typename TResult>
315     int WaitOnEpoll(const TResult &expect);
316 
317 
318     template<typename TResult>
319     const TestCase<TResult>& GetTestMatrix(TEST_CASE test_case) const;
320 
321     template<typename TResult>
TestConnect(TEST_CASE test_case)322     void TestConnect(TEST_CASE test_case/*, bool is_blocking*/)
323     {
324         const bool is_blocking = std::is_same<TResult, TestResultBlocking>::value;
325         if (is_blocking)
326         {
327             ASSERT_NE(srt_setsockopt(  m_caller_socket, 0, SRTO_RCVSYN, &s_yes, sizeof s_yes), SRT_ERROR);
328             ASSERT_NE(srt_setsockopt(  m_caller_socket, 0, SRTO_SNDSYN, &s_yes, sizeof s_yes), SRT_ERROR);
329             ASSERT_NE(srt_setsockopt(m_listener_socket, 0, SRTO_RCVSYN, &s_yes, sizeof s_yes), SRT_ERROR);
330             ASSERT_NE(srt_setsockopt(m_listener_socket, 0, SRTO_SNDSYN, &s_yes, sizeof s_yes), SRT_ERROR);
331         }
332         else
333         {
334             ASSERT_NE(srt_setsockopt(  m_caller_socket, 0, SRTO_RCVSYN, &s_no, sizeof s_no), SRT_ERROR); // non-blocking mode
335             ASSERT_NE(srt_setsockopt(  m_caller_socket, 0, SRTO_SNDSYN, &s_no, sizeof s_no), SRT_ERROR); // non-blocking mode
336             ASSERT_NE(srt_setsockopt(m_listener_socket, 0, SRTO_RCVSYN, &s_no, sizeof s_no), SRT_ERROR); // non-blocking mode
337             ASSERT_NE(srt_setsockopt(m_listener_socket, 0, SRTO_SNDSYN, &s_no, sizeof s_no), SRT_ERROR); // non-blocking mode
338         }
339 
340         // Prepare input state
341         const TestCase<TResult> &test = GetTestMatrix<TResult>(test_case);
342         ASSERT_EQ(SetEnforcedEncryption(PEER_CALLER, test.enforcedenc[PEER_CALLER]), SRT_SUCCESS);
343         ASSERT_EQ(SetEnforcedEncryption(PEER_LISTENER, test.enforcedenc[PEER_LISTENER]), SRT_SUCCESS);
344 
345         ASSERT_EQ(SetPassword(PEER_CALLER, test.password[PEER_CALLER]), SRT_SUCCESS);
346         ASSERT_EQ(SetPassword(PEER_LISTENER, test.password[PEER_LISTENER]), SRT_SUCCESS);
347 
348         const TResult &expect = test.expected_result;
349 
350         // Start testing
351         volatile bool caller_done = false;
352         sockaddr_in sa;
353         memset(&sa, 0, sizeof sa);
354         sa.sin_family = AF_INET;
355         sa.sin_port = htons(5200);
356         ASSERT_EQ(inet_pton(AF_INET, "127.0.0.1", &sa.sin_addr), 1);
357         sockaddr* psa = (sockaddr*)&sa;
358         ASSERT_NE(srt_bind(m_listener_socket, psa, sizeof sa), SRT_ERROR);
359         ASSERT_NE(srt_listen(m_listener_socket, 4), SRT_ERROR);
360 
361         auto accepting_thread = std::thread([&] {
362             const int epoll_event = WaitOnEpoll(expect);
363 
364             // In a blocking mode we expect a socket returned from srt_accept() if the srt_connect succeeded.
365             // In a non-blocking mode we expect a socket returned from srt_accept() if the srt_connect succeeded,
366             // otherwise SRT_INVALID_SOCKET after the listening socket is closed.
367             sockaddr_in client_address;
368             int length = sizeof(sockaddr_in);
369             SRTSOCKET accepted_socket = -1;
370             if (epoll_event == SRT_EPOLL_IN)
371             {
372                 accepted_socket = srt_accept(m_listener_socket, (sockaddr*)&client_address, &length);
373                 std::cout << "ACCEPT: done, result=" << accepted_socket << std::endl;
374             }
375             else
376             {
377                 std::cout << "ACCEPT: NOT done\n";
378             }
379 
380             if (accepted_socket == SRT_INVALID_SOCK)
381             {
382                 std::cerr << "[T] ACCEPT ERROR: " << srt_getlasterror_str() << std::endl;
383             }
384             else
385             {
386                 std::cerr << "[T] ACCEPT SUCCEEDED: @" << accepted_socket << "\n";
387             }
388 
389             EXPECT_NE(accepted_socket, 0);
390             if (expect.accept_ret == SRT_INVALID_SOCK)
391             {
392                 EXPECT_EQ(accepted_socket, SRT_INVALID_SOCK);
393             }
394             else if (expect.accept_ret != -2)
395             {
396                 EXPECT_NE(accepted_socket, SRT_INVALID_SOCK);
397             }
398 
399             if (accepted_socket != SRT_INVALID_SOCK && expect.socket_state[CHECK_SOCKET_ACCEPTED] != IGNORE_SRTS)
400             {
401                 if (m_is_tracing)
402                 {
403                     std::cerr << "EARLY Socket state accepted: " << m_socket_state[srt_getsockstate(accepted_socket)]
404                         << " (expected: " << m_socket_state[expect.socket_state[CHECK_SOCKET_ACCEPTED]] << ")\n";
405                     std::cerr << "KM State accepted:     " << m_km_state[GetKMState(accepted_socket)] << '\n';
406                     std::cerr << "RCV KM State accepted:     " << m_km_state[GetSocetkOption(accepted_socket, SRTO_RCVKMSTATE)] << '\n';
407                     std::cerr << "SND KM State accepted:     " << m_km_state[GetSocetkOption(accepted_socket, SRTO_SNDKMSTATE)] << '\n';
408                 }
409 
410                 // We have to wait some time for the socket to be able to process the HS responce from the caller.
411                 // In test cases B2 - B4 the socket is expected to change its state from CONNECTED to BROKEN
412                 // due to KM mismatches
413                 do
414                 {
415                     std::this_thread::sleep_for(std::chrono::milliseconds(50));
416                 } while (!caller_done);
417 
418                 // Special case when the expected state is "broken": if so, tolerate every possible
419                 // socket state, just NOT LESS than SRTS_BROKEN, and also don't read any flags on that socket.
420 
421                 if (expect.socket_state[CHECK_SOCKET_ACCEPTED] == SRTS_BROKEN)
422                 {
423                     EXPECT_GE(srt_getsockstate(accepted_socket), SRTS_BROKEN);
424                 }
425                 else
426                 {
427                     EXPECT_EQ(srt_getsockstate(accepted_socket), expect.socket_state[CHECK_SOCKET_ACCEPTED]);
428                     EXPECT_EQ(GetSocetkOption(accepted_socket, SRTO_SNDKMSTATE), expect.km_state[CHECK_SOCKET_ACCEPTED]);
429                 }
430 
431                 if (m_is_tracing)
432                 {
433                     const SRT_SOCKSTATUS status = srt_getsockstate(accepted_socket);
434                     std::cerr << "LATE Socket state accepted: " << m_socket_state[status]
435                         << " (expected: " << m_socket_state[expect.socket_state[CHECK_SOCKET_ACCEPTED]] << ")\n";
436                 }
437             }
438         });
439 
440         const int connect_ret = srt_connect(m_caller_socket, psa, sizeof sa);
441         EXPECT_EQ(connect_ret, expect.connect_ret);
442 
443         if (connect_ret == SRT_ERROR && connect_ret != expect.connect_ret)
444         {
445             std::cerr << "UNEXPECTED! srt_connect returned error: "
446                 << srt_getlasterror_str() << " (code " << srt_getlasterror(NULL) << ")\n";
447         }
448 
449         caller_done = true;
450 
451         if (is_blocking == false)
452             accepting_thread.join();
453 
454         if (m_is_tracing)
455         {
456             std::cerr << "Socket state caller:   " << m_socket_state[srt_getsockstate(m_caller_socket)] << "\n";
457             std::cerr << "Socket state listener: " << m_socket_state[srt_getsockstate(m_listener_socket)] << "\n";
458             std::cerr << "KM State caller:       " << m_km_state[GetKMState(m_caller_socket)] << '\n';
459             std::cerr << "RCV KM State caller:   " << m_km_state[GetSocetkOption(m_caller_socket, SRTO_RCVKMSTATE)] << '\n';
460             std::cerr << "SND KM State caller:   " << m_km_state[GetSocetkOption(m_caller_socket, SRTO_SNDKMSTATE)] << '\n';
461             std::cerr << "KM State listener:     " << m_km_state[GetKMState(m_listener_socket)] << '\n';
462         }
463 
464         // If a blocking call to srt_connect() returned error, then the state is not valid,
465         // but we still check it because we know what it should be. This way we may see potential changes in the core behavior.
466         if (is_blocking)
467         {
468             EXPECT_EQ(srt_getsockstate(m_caller_socket), expect.socket_state[CHECK_SOCKET_CALLER]);
469         }
470         // A caller socket, regardless of the mode, if it's not expected to be connected, check negatively.
471         if (expect.socket_state[CHECK_SOCKET_CALLER] == SRTS_CONNECTED)
472         {
473             EXPECT_EQ(srt_getsockstate(m_caller_socket), SRTS_CONNECTED);
474         }
475         else
476         {
477             // If the socket is not expected to be connected (might be CONNECTING),
478             // then it is ok if it's CONNECTING or BROKEN.
479             EXPECT_NE(srt_getsockstate(m_caller_socket), SRTS_CONNECTED);
480         }
481 
482         EXPECT_EQ(GetSocetkOption(m_caller_socket, SRTO_RCVKMSTATE), expect.km_state[CHECK_SOCKET_CALLER]);
483 
484         EXPECT_EQ(srt_getsockstate(m_listener_socket), SRTS_LISTENING);
485         EXPECT_EQ(GetKMState(m_listener_socket), SRT_KM_S_UNSECURED);
486 
487         if (is_blocking)
488         {
489             // srt_accept() has no timeout, so we have to close the socket and wait for the thread to exit.
490             // Just give it some time and close the socket.
491             std::this_thread::sleep_for(std::chrono::milliseconds(50));
492             ASSERT_NE(srt_close(m_listener_socket), SRT_ERROR);
493             accepting_thread.join();
494         }
495     }
496 
497 
498 private:
499     // put in any custom data members that you need
500 
501     SRTSOCKET m_caller_socket   = SRT_INVALID_SOCK;
502     SRTSOCKET m_listener_socket = SRT_INVALID_SOCK;
503 
504     int       m_pollid          = 0;
505 
506     const bool s_yes = true;
507     const bool s_no  = false;
508 
509     const bool          m_is_tracing = false;
510     static const char*  m_km_state[];
511     static const char* const* m_socket_state;
512 };
513 
514 
515 
516 template<>
WaitOnEpoll(const TestResultBlocking &)517 int TestEnforcedEncryption::WaitOnEpoll<TestResultBlocking>(const TestResultBlocking &)
518 {
519     return SRT_EPOLL_IN;
520 }
521 
PrintEpollEvent(std::ostream & os,int events,int et_events)522 static std::ostream& PrintEpollEvent(std::ostream& os, int events, int et_events)
523 {
524     using namespace std;
525 
526     static pair<int, const char*> const namemap [] = {
527         make_pair(SRT_EPOLL_IN, "R"),
528         make_pair(SRT_EPOLL_OUT, "W"),
529         make_pair(SRT_EPOLL_ERR, "E"),
530         make_pair(SRT_EPOLL_UPDATE, "U")
531     };
532 
533     int N = Size(namemap);
534 
535     for (int i = 0; i < N; ++i)
536     {
537         if (events & namemap[i].first)
538         {
539             os << "[";
540             if (et_events & namemap[i].first)
541                 os << "^";
542             os << namemap[i].second << "]";
543         }
544     }
545 
546     return os;
547 }
548 
549 template<>
WaitOnEpoll(const TestResultNonBlocking & expect)550 int TestEnforcedEncryption::WaitOnEpoll<TestResultNonBlocking>(const TestResultNonBlocking &expect)
551 {
552     const int default_len = 3;
553     SRT_EPOLL_EVENT ready[default_len];
554     const int epoll_res = srt_epoll_uwait(m_pollid, ready, default_len, 500);
555     std::cerr << "Epoll wait result: " << epoll_res;
556     if (epoll_res > 0)
557     {
558         std::cerr << " FOUND: @" << ready[0].fd << " in ";
559         PrintEpollEvent(std::cerr, ready[0].events, 0);
560     }
561     else
562     {
563         std::cerr << " NOTHING READY";
564     }
565     std::cerr << std::endl;
566 
567     // Expect: -2 means that
568     if (expect.epoll_wait_ret != IGNORE_EPOLL)
569     {
570         EXPECT_EQ(epoll_res, expect.epoll_wait_ret);
571     }
572 
573     if (epoll_res == SRT_ERROR)
574     {
575         std::cerr << "Epoll returned error: " << srt_getlasterror_str() << " (code " << srt_getlasterror(NULL) << ")\n";
576         return 0;
577     }
578 
579     // We have exactly one socket here and we expect to return
580     // only this one, or nothing.
581     if (epoll_res != 0)
582     {
583         EXPECT_EQ(epoll_res, 1);
584         EXPECT_EQ(ready[0].fd, m_listener_socket);
585     }
586 
587     return epoll_res == 0 ? 0 : int(ready[0].events);
588 }
589 
590 
591 template<>
GetTestMatrix(TEST_CASE test_case) const592 const TestCase<TestResultBlocking>& TestEnforcedEncryption::GetTestMatrix<TestResultBlocking>(TEST_CASE test_case) const
593 {
594     return g_test_matrix_blocking[test_case];
595 }
596 
597 template<>
GetTestMatrix(TEST_CASE test_case) const598 const TestCase<TestResultNonBlocking>& TestEnforcedEncryption::GetTestMatrix<TestResultNonBlocking>(TEST_CASE test_case) const
599 {
600     return g_test_matrix_non_blocking[test_case];
601 }
602 
603 
604 
// Printable names of the KM states, indexed by the KM state value
// (the value is repeated in parentheses in each name). Used for tracing only.
const char* TestEnforcedEncryption::m_km_state[] = {
    "SRT_KM_S_UNSECURED (0)",      //No encryption
    "SRT_KM_S_SECURING  (1)",      //Stream encrypted, exchanging Keying Material
    "SRT_KM_S_SECURED   (2)",      //Stream encrypted, keying Material exchanged, decrypting ok.
    "SRT_KM_S_NOSECRET  (3)",      //Stream encrypted and no secret to decrypt Keying Material
    "SRT_KM_S_BADSECRET (4)"       //Stream encrypted and wrong secret, cannot decrypt Keying Material
};
612 
613 
// Printable names of the socket states. The first entry corresponds to the
// IGNORE_SRTS sentinel (-1); it is reached through m_socket_state, which
// points one element past the beginning of this array.
static const char* const socket_state_array[] = {
    "IGNORE_SRTS",
    "SRTS_INVALID",
    "SRTS_INIT",
    "SRTS_OPENED",
    "SRTS_LISTENING",
    "SRTS_CONNECTING",
    "SRTS_CONNECTED",
    "SRTS_BROKEN",
    "SRTS_CLOSING",
    "SRTS_CLOSED",
    "SRTS_NONEXIST"
};

// A trick that allows the array to be indexed by -1:
// m_socket_state[-1] is "IGNORE_SRTS", m_socket_state[0] is "SRTS_INVALID", and so on.
const char* const* TestEnforcedEncryption::m_socket_state = socket_state_array+1;
630 
631 /**
632  * @fn TestEnforcedEncryption.PasswordLength
633  * @brief The password length should belong to the interval of [10; 80]
634  */
TEST_F(TestEnforcedEncryption,PasswordLength)635 TEST_F(TestEnforcedEncryption, PasswordLength)
636 {
637 #ifdef SRT_ENABLE_ENCRYPTION
638     // Empty string sets password to none
639     EXPECT_EQ(SetPassword(PEER_CALLER,   std::string("")), SRT_SUCCESS);
640     EXPECT_EQ(SetPassword(PEER_LISTENER, std::string("")), SRT_SUCCESS);
641 
642     EXPECT_EQ(SetPassword(PEER_CALLER,   std::string("too_short")), SRT_ERROR);
643     EXPECT_EQ(SetPassword(PEER_LISTENER, std::string("too_short")), SRT_ERROR);
644 
645     std::string long_pwd;
646     const int pwd_len = 81;     // 80 is the maximum password length accepted
647     long_pwd.reserve(pwd_len);
648     const char start_char = '!';
649 
650     // Please ensure to be within the valid ASCII symbols!
651     ASSERT_LT(pwd_len + start_char, 126);
652     for (int i = 0; i < pwd_len; ++i)
653         long_pwd.push_back(static_cast<char>(start_char + i));
654 
655     EXPECT_EQ(SetPassword(PEER_CALLER,   long_pwd), SRT_ERROR);
656     EXPECT_EQ(SetPassword(PEER_LISTENER, long_pwd), SRT_ERROR);
657 
658     EXPECT_EQ(SetPassword(PEER_CALLER,   std::string("proper_len")),     SRT_SUCCESS);
659     EXPECT_EQ(SetPassword(PEER_LISTENER, std::string("proper_length")),  SRT_SUCCESS);
660 #else
661     EXPECT_EQ(SetPassword(PEER_CALLER, "whateverpassword"), SRT_ERROR);
662 #endif
663 }
664 
665 
666 /**
667  * @fn TestEnforcedEncryption.SetGetDefault
668  * @brief The default value for the enforced encryption should be ON
669  */
TEST_F(TestEnforcedEncryption,SetGetDefault)670 TEST_F(TestEnforcedEncryption, SetGetDefault)
671 {
672     EXPECT_EQ(GetEnforcedEncryption(PEER_CALLER),   true);
673     EXPECT_EQ(GetEnforcedEncryption(PEER_LISTENER), true);
674 
675     EXPECT_EQ(SetEnforcedEncryption(PEER_CALLER,    false), SRT_SUCCESS);
676     EXPECT_EQ(SetEnforcedEncryption(PEER_LISTENER,  false), SRT_SUCCESS);
677 
678     EXPECT_EQ(GetEnforcedEncryption(PEER_CALLER),   false);
679     EXPECT_EQ(GetEnforcedEncryption(PEER_LISTENER), false);
680 }
681 
682 
// Defines the test TestEnforcedEncryption.<CASE_NUMBER>_Blocking_<DESC>,
// which runs TestConnect() in the blocking mode for TEST_<CASE_NUMBER>.
// CASE_NUMBER must be given without the TEST_ prefix (e.g. CASE_A_1).
#define CREATE_TEST_CASE_BLOCKING(CASE_NUMBER, DESC) TEST_F(TestEnforcedEncryption, CASE_NUMBER##_Blocking_##DESC)\
{\
    TestConnect<TestResultBlocking>(TEST_##CASE_NUMBER);\
}

// Defines the test TestEnforcedEncryption.<CASE_NUMBER>_NonBlocking_<DESC>,
// which runs TestConnect() in the non-blocking mode for TEST_<CASE_NUMBER>.
#define CREATE_TEST_CASE_NONBLOCKING(CASE_NUMBER, DESC) TEST_F(TestEnforcedEncryption, CASE_NUMBER##_NonBlocking_##DESC)\
{\
    TestConnect<TestResultNonBlocking>(TEST_##CASE_NUMBER);\
}


// Defines both the non-blocking and the blocking test for one test case.
#define CREATE_TEST_CASES(CASE_NUMBER, DESC) \
    CREATE_TEST_CASE_NONBLOCKING(CASE_NUMBER, DESC) \
    CREATE_TEST_CASE_BLOCKING(CASE_NUMBER, DESC)
697 
// Test instantiation. Every case is created in a non-blocking and a blocking
// flavour. The cases that set a password on at least one peer (x.1 - x.4) are
// only built with SRT_ENABLE_ENCRYPTION; the "no password" cases (x.5) are
// always built.

// A: enforced encryption ON on caller, ON on listener
#ifdef SRT_ENABLE_ENCRYPTION
CREATE_TEST_CASES(CASE_A_1, Enforced_On_On_Pwd_Set_Set_Match)
CREATE_TEST_CASES(CASE_A_2, Enforced_On_On_Pwd_Set_Set_Mismatch)
CREATE_TEST_CASES(CASE_A_3, Enforced_On_On_Pwd_Set_None)
CREATE_TEST_CASES(CASE_A_4, Enforced_On_On_Pwd_None_Set)
#endif
CREATE_TEST_CASES(CASE_A_5, Enforced_On_On_Pwd_None_None)

// B: enforced encryption ON on caller, OFF on listener
#ifdef SRT_ENABLE_ENCRYPTION
CREATE_TEST_CASES(CASE_B_1, Enforced_On_Off_Pwd_Set_Set_Match)
CREATE_TEST_CASES(CASE_B_2, Enforced_On_Off_Pwd_Set_Set_Mismatch)
CREATE_TEST_CASES(CASE_B_3, Enforced_On_Off_Pwd_Set_None)
CREATE_TEST_CASES(CASE_B_4, Enforced_On_Off_Pwd_None_Set)
#endif
CREATE_TEST_CASES(CASE_B_5, Enforced_On_Off_Pwd_None_None)

// C: enforced encryption OFF on caller, ON on listener
#ifdef SRT_ENABLE_ENCRYPTION
CREATE_TEST_CASES(CASE_C_1, Enforced_Off_On_Pwd_Set_Set_Match)
CREATE_TEST_CASES(CASE_C_2, Enforced_Off_On_Pwd_Set_Set_Mismatch)
CREATE_TEST_CASES(CASE_C_3, Enforced_Off_On_Pwd_Set_None)
CREATE_TEST_CASES(CASE_C_4, Enforced_Off_On_Pwd_None_Set)
#endif
CREATE_TEST_CASES(CASE_C_5, Enforced_Off_On_Pwd_None_None)

// D: enforced encryption OFF on caller, OFF on listener
#ifdef SRT_ENABLE_ENCRYPTION
CREATE_TEST_CASES(CASE_D_1, Enforced_Off_Off_Pwd_Set_Set_Match)
CREATE_TEST_CASES(CASE_D_2, Enforced_Off_Off_Pwd_Set_Set_Mismatch)
CREATE_TEST_CASES(CASE_D_3, Enforced_Off_Off_Pwd_Set_None)
CREATE_TEST_CASES(CASE_D_4, Enforced_Off_Off_Pwd_None_Set)
#endif
CREATE_TEST_CASES(CASE_D_5, Enforced_Off_Off_Pwd_None_None)
729 
730