1 /*
2 * Copyright 2004 The WebRTC Project Authors. All rights reserved.
3 *
4 * Use of this source code is governed by a BSD-style license
5 * that can be found in the LICENSE file in the root of the source
6 * tree. An additional intellectual property rights grant can be found
7 * in the file PATENTS. All contributing project authors may
8 * be found in the AUTHORS file in the root of the source tree.
9 */
10
11 #include <string.h>
12
13 #include <algorithm>
14 #include <memory>
15 #include <string>
16 #include <vector>
17
18 #include "absl/memory/memory.h"
19 #include "rtc_base/async_packet_socket.h"
20 #include "rtc_base/async_socket.h"
21 #include "rtc_base/async_tcp_socket.h"
22 #include "rtc_base/async_udp_socket.h"
23 #include "rtc_base/gunit.h"
24 #include "rtc_base/ip_address.h"
25 #include "rtc_base/logging.h"
26 #include "rtc_base/nat_server.h"
27 #include "rtc_base/nat_socket_factory.h"
28 #include "rtc_base/nat_types.h"
29 #include "rtc_base/net_helpers.h"
30 #include "rtc_base/network.h"
31 #include "rtc_base/physical_socket_server.h"
32 #include "rtc_base/socket_address.h"
33 #include "rtc_base/socket_factory.h"
34 #include "rtc_base/socket_server.h"
35 #include "rtc_base/test_client.h"
36 #include "rtc_base/third_party/sigslot/sigslot.h"
37 #include "rtc_base/thread.h"
38 #include "rtc_base/virtual_socket_server.h"
39 #include "test/gtest.h"
40
41 namespace rtc {
42 namespace {
43
CheckReceive(TestClient * client,bool should_receive,const char * buf,size_t size)44 bool CheckReceive(TestClient* client,
45 bool should_receive,
46 const char* buf,
47 size_t size) {
48 return (should_receive) ? client->CheckNextPacket(buf, size, 0)
49 : client->CheckNoPacket();
50 }
51
CreateTestClient(SocketFactory * factory,const SocketAddress & local_addr)52 TestClient* CreateTestClient(SocketFactory* factory,
53 const SocketAddress& local_addr) {
54 return new TestClient(
55 absl::WrapUnique(AsyncUDPSocket::Create(factory, local_addr)));
56 }
57
CreateTCPTestClient(AsyncSocket * socket)58 TestClient* CreateTCPTestClient(AsyncSocket* socket) {
59 return new TestClient(std::make_unique<AsyncTCPSocket>(socket, false));
60 }
61
62 // Tests that when sending from internal_addr to external_addrs through the
63 // NAT type specified by nat_type, all external addrs receive the sent packet
64 // and, if exp_same is true, all use the same mapped-address on the NAT.
TestSend(SocketServer * internal,const SocketAddress & internal_addr,SocketServer * external,const SocketAddress external_addrs[4],NATType nat_type,bool exp_same)65 void TestSend(SocketServer* internal,
66 const SocketAddress& internal_addr,
67 SocketServer* external,
68 const SocketAddress external_addrs[4],
69 NATType nat_type,
70 bool exp_same) {
71 Thread th_int(internal);
72 Thread th_ext(external);
73
74 SocketAddress server_addr = internal_addr;
75 server_addr.SetPort(0); // Auto-select a port
76 NATServer* nat = new NATServer(nat_type, internal, server_addr, server_addr,
77 external, external_addrs[0]);
78 NATSocketFactory* natsf = new NATSocketFactory(
79 internal, nat->internal_udp_address(), nat->internal_tcp_address());
80
81 TestClient* in = CreateTestClient(natsf, internal_addr);
82 TestClient* out[4];
83 for (int i = 0; i < 4; i++)
84 out[i] = CreateTestClient(external, external_addrs[i]);
85
86 th_int.Start();
87 th_ext.Start();
88
89 const char* buf = "filter_test";
90 size_t len = strlen(buf);
91
92 in->SendTo(buf, len, out[0]->address());
93 SocketAddress trans_addr;
94 EXPECT_TRUE(out[0]->CheckNextPacket(buf, len, &trans_addr));
95
96 for (int i = 1; i < 4; i++) {
97 in->SendTo(buf, len, out[i]->address());
98 SocketAddress trans_addr2;
99 EXPECT_TRUE(out[i]->CheckNextPacket(buf, len, &trans_addr2));
100 bool are_same = (trans_addr == trans_addr2);
101 ASSERT_EQ(are_same, exp_same) << "same translated address";
102 ASSERT_NE(AF_UNSPEC, trans_addr.family());
103 ASSERT_NE(AF_UNSPEC, trans_addr2.family());
104 }
105
106 th_int.Stop();
107 th_ext.Stop();
108
109 delete nat;
110 delete natsf;
111 delete in;
112 for (int i = 0; i < 4; i++)
113 delete out[i];
114 }
115
116 // Tests that when sending from external_addrs to internal_addr, the packet
117 // is delivered according to the specified filter_ip and filter_port rules.
TestRecv(SocketServer * internal,const SocketAddress & internal_addr,SocketServer * external,const SocketAddress external_addrs[4],NATType nat_type,bool filter_ip,bool filter_port)118 void TestRecv(SocketServer* internal,
119 const SocketAddress& internal_addr,
120 SocketServer* external,
121 const SocketAddress external_addrs[4],
122 NATType nat_type,
123 bool filter_ip,
124 bool filter_port) {
125 Thread th_int(internal);
126 Thread th_ext(external);
127
128 SocketAddress server_addr = internal_addr;
129 server_addr.SetPort(0); // Auto-select a port
130 NATServer* nat = new NATServer(nat_type, internal, server_addr, server_addr,
131 external, external_addrs[0]);
132 NATSocketFactory* natsf = new NATSocketFactory(
133 internal, nat->internal_udp_address(), nat->internal_tcp_address());
134
135 TestClient* in = CreateTestClient(natsf, internal_addr);
136 TestClient* out[4];
137 for (int i = 0; i < 4; i++)
138 out[i] = CreateTestClient(external, external_addrs[i]);
139
140 th_int.Start();
141 th_ext.Start();
142
143 const char* buf = "filter_test";
144 size_t len = strlen(buf);
145
146 in->SendTo(buf, len, out[0]->address());
147 SocketAddress trans_addr;
148 EXPECT_TRUE(out[0]->CheckNextPacket(buf, len, &trans_addr));
149
150 out[1]->SendTo(buf, len, trans_addr);
151 EXPECT_TRUE(CheckReceive(in, !filter_ip, buf, len));
152
153 out[2]->SendTo(buf, len, trans_addr);
154 EXPECT_TRUE(CheckReceive(in, !filter_port, buf, len));
155
156 out[3]->SendTo(buf, len, trans_addr);
157 EXPECT_TRUE(CheckReceive(in, !filter_ip && !filter_port, buf, len));
158
159 th_int.Stop();
160 th_ext.Stop();
161
162 delete nat;
163 delete natsf;
164 delete in;
165 for (int i = 0; i < 4; i++)
166 delete out[i];
167 }
168
169 // Tests that NATServer allocates bindings properly.
TestBindings(SocketServer * internal,const SocketAddress & internal_addr,SocketServer * external,const SocketAddress external_addrs[4])170 void TestBindings(SocketServer* internal,
171 const SocketAddress& internal_addr,
172 SocketServer* external,
173 const SocketAddress external_addrs[4]) {
174 TestSend(internal, internal_addr, external, external_addrs, NAT_OPEN_CONE,
175 true);
176 TestSend(internal, internal_addr, external, external_addrs,
177 NAT_ADDR_RESTRICTED, true);
178 TestSend(internal, internal_addr, external, external_addrs,
179 NAT_PORT_RESTRICTED, true);
180 TestSend(internal, internal_addr, external, external_addrs, NAT_SYMMETRIC,
181 false);
182 }
183
184 // Tests that NATServer filters packets properly.
TestFilters(SocketServer * internal,const SocketAddress & internal_addr,SocketServer * external,const SocketAddress external_addrs[4])185 void TestFilters(SocketServer* internal,
186 const SocketAddress& internal_addr,
187 SocketServer* external,
188 const SocketAddress external_addrs[4]) {
189 TestRecv(internal, internal_addr, external, external_addrs, NAT_OPEN_CONE,
190 false, false);
191 TestRecv(internal, internal_addr, external, external_addrs,
192 NAT_ADDR_RESTRICTED, true, false);
193 TestRecv(internal, internal_addr, external, external_addrs,
194 NAT_PORT_RESTRICTED, true, true);
195 TestRecv(internal, internal_addr, external, external_addrs, NAT_SYMMETRIC,
196 true, true);
197 }
198
TestConnectivity(const SocketAddress & src,const IPAddress & dst)199 bool TestConnectivity(const SocketAddress& src, const IPAddress& dst) {
200 // The physical NAT tests require connectivity to the selected ip from the
201 // internal address used for the NAT. Things like firewalls can break that, so
202 // check to see if it's worth even trying with this ip.
203 std::unique_ptr<PhysicalSocketServer> pss(new PhysicalSocketServer());
204 std::unique_ptr<AsyncSocket> client(
205 pss->CreateAsyncSocket(src.family(), SOCK_DGRAM));
206 std::unique_ptr<AsyncSocket> server(
207 pss->CreateAsyncSocket(src.family(), SOCK_DGRAM));
208 if (client->Bind(SocketAddress(src.ipaddr(), 0)) != 0 ||
209 server->Bind(SocketAddress(dst, 0)) != 0) {
210 return false;
211 }
212 const char* buf = "hello other socket";
213 size_t len = strlen(buf);
214 int sent = client->SendTo(buf, len, server->GetLocalAddress());
215 SocketAddress addr;
216 const size_t kRecvBufSize = 64;
217 char recvbuf[kRecvBufSize];
218 Thread::Current()->SleepMs(100);
219 int received = server->RecvFrom(recvbuf, kRecvBufSize, &addr, nullptr);
220 return received == sent && ::memcmp(buf, recvbuf, len) == 0;
221 }
222
TestPhysicalInternal(const SocketAddress & int_addr)223 void TestPhysicalInternal(const SocketAddress& int_addr) {
224 BasicNetworkManager network_manager;
225 network_manager.StartUpdating();
226 // Process pending messages so the network list is updated.
227 Thread::Current()->ProcessMessages(0);
228
229 std::vector<Network*> networks;
230 network_manager.GetNetworks(&networks);
231 networks.erase(std::remove_if(networks.begin(), networks.end(),
232 [](rtc::Network* network) {
233 return rtc::kDefaultNetworkIgnoreMask &
234 network->type();
235 }),
236 networks.end());
237 if (networks.empty()) {
238 RTC_LOG(LS_WARNING) << "Not enough network adapters for test.";
239 return;
240 }
241
242 SocketAddress ext_addr1(int_addr);
243 SocketAddress ext_addr2;
244 // Find an available IP with matching family. The test breaks if int_addr
245 // can't talk to ip, so check for connectivity as well.
246 for (std::vector<Network*>::iterator it = networks.begin();
247 it != networks.end(); ++it) {
248 const IPAddress& ip = (*it)->GetBestIP();
249 if (ip.family() == int_addr.family() && TestConnectivity(int_addr, ip)) {
250 ext_addr2.SetIP(ip);
251 break;
252 }
253 }
254 if (ext_addr2.IsNil()) {
255 RTC_LOG(LS_WARNING) << "No available IP of same family as "
256 << int_addr.ToString();
257 return;
258 }
259
260 RTC_LOG(LS_INFO) << "selected ip " << ext_addr2.ipaddr().ToString();
261
262 SocketAddress ext_addrs[4] = {
263 SocketAddress(ext_addr1), SocketAddress(ext_addr2),
264 SocketAddress(ext_addr1), SocketAddress(ext_addr2)};
265
266 std::unique_ptr<PhysicalSocketServer> int_pss(new PhysicalSocketServer());
267 std::unique_ptr<PhysicalSocketServer> ext_pss(new PhysicalSocketServer());
268
269 TestBindings(int_pss.get(), int_addr, ext_pss.get(), ext_addrs);
270 TestFilters(int_pss.get(), int_addr, ext_pss.get(), ext_addrs);
271 }
272
// Physical NAT tests using the IPv4 loopback address as the internal side.
TEST(NatTest, TestPhysicalIPv4) {
  const SocketAddress loopback_v4("127.0.0.1", 0);
  TestPhysicalInternal(loopback_v4);
}
276
// Physical NAT tests using the IPv6 loopback address as the internal side.
// Skipped, with a warning, when HasIPv6Enabled() is false.
TEST(NatTest, TestPhysicalIPv6) {
  if (!HasIPv6Enabled()) {
    RTC_LOG(LS_WARNING) << "No IPv6, skipping";
    return;
  }
  TestPhysicalInternal(SocketAddress("::1", 0));
}
284
285 namespace {
286
287 class TestVirtualSocketServer : public VirtualSocketServer {
288 public:
289 // Expose this publicly
GetNextIP(int af)290 IPAddress GetNextIP(int af) { return VirtualSocketServer::GetNextIP(af); }
291 };
292
293 } // namespace
294
TestVirtualInternal(int family)295 void TestVirtualInternal(int family) {
296 std::unique_ptr<TestVirtualSocketServer> int_vss(
297 new TestVirtualSocketServer());
298 std::unique_ptr<TestVirtualSocketServer> ext_vss(
299 new TestVirtualSocketServer());
300
301 SocketAddress int_addr;
302 SocketAddress ext_addrs[4];
303 int_addr.SetIP(int_vss->GetNextIP(family));
304 ext_addrs[0].SetIP(ext_vss->GetNextIP(int_addr.family()));
305 ext_addrs[1].SetIP(ext_vss->GetNextIP(int_addr.family()));
306 ext_addrs[2].SetIP(ext_addrs[0].ipaddr());
307 ext_addrs[3].SetIP(ext_addrs[1].ipaddr());
308
309 TestBindings(int_vss.get(), int_addr, ext_vss.get(), ext_addrs);
310 TestFilters(int_vss.get(), int_addr, ext_vss.get(), ext_addrs);
311 }
312
// Virtual-network NAT tests over IPv4.
TEST(NatTest, TestVirtualIPv4) {
  const int kFamily = AF_INET;
  TestVirtualInternal(kFamily);
}
316
// Virtual-network NAT tests over IPv6. Skipped, with a warning, when
// HasIPv6Enabled() is false.
TEST(NatTest, TestVirtualIPv6) {
  if (!HasIPv6Enabled()) {
    RTC_LOG(LS_WARNING) << "No IPv6, skipping";
    return;
  }
  TestVirtualInternal(AF_INET6);
}
324
325 class NatTcpTest : public ::testing::Test, public sigslot::has_slots<> {
326 public:
NatTcpTest()327 NatTcpTest()
328 : int_addr_("192.168.0.1", 0),
329 ext_addr_("10.0.0.1", 0),
330 connected_(false),
331 int_vss_(new TestVirtualSocketServer()),
332 ext_vss_(new TestVirtualSocketServer()),
333 int_thread_(new Thread(int_vss_.get())),
334 ext_thread_(new Thread(ext_vss_.get())),
335 nat_(new NATServer(NAT_OPEN_CONE,
336 int_vss_.get(),
337 int_addr_,
338 int_addr_,
339 ext_vss_.get(),
340 ext_addr_)),
341 natsf_(new NATSocketFactory(int_vss_.get(),
342 nat_->internal_udp_address(),
343 nat_->internal_tcp_address())) {
344 int_thread_->Start();
345 ext_thread_->Start();
346 }
347
OnConnectEvent(AsyncSocket * socket)348 void OnConnectEvent(AsyncSocket* socket) { connected_ = true; }
349
OnAcceptEvent(AsyncSocket * socket)350 void OnAcceptEvent(AsyncSocket* socket) {
351 accepted_.reset(server_->Accept(nullptr));
352 }
353
OnCloseEvent(AsyncSocket * socket,int error)354 void OnCloseEvent(AsyncSocket* socket, int error) {}
355
ConnectEvents()356 void ConnectEvents() {
357 server_->SignalReadEvent.connect(this, &NatTcpTest::OnAcceptEvent);
358 client_->SignalConnectEvent.connect(this, &NatTcpTest::OnConnectEvent);
359 }
360
361 SocketAddress int_addr_;
362 SocketAddress ext_addr_;
363 bool connected_;
364 std::unique_ptr<TestVirtualSocketServer> int_vss_;
365 std::unique_ptr<TestVirtualSocketServer> ext_vss_;
366 std::unique_ptr<Thread> int_thread_;
367 std::unique_ptr<Thread> ext_thread_;
368 std::unique_ptr<NATServer> nat_;
369 std::unique_ptr<NATSocketFactory> natsf_;
370 std::unique_ptr<AsyncSocket> client_;
371 std::unique_ptr<AsyncSocket> server_;
372 std::unique_ptr<AsyncSocket> accepted_;
373 };
374
// Connects an internal (NAT'd) TCP client to an external listening server
// through nat_. It then checks the remote addresses seen on both ends and
// exchanges one packet in each direction. Currently disabled.
TEST_F(NatTcpTest, DISABLED_TestConnectOut) {
  server_.reset(ext_vss_->CreateAsyncSocket(AF_INET, SOCK_STREAM));
  server_->Bind(ext_addr_);
  server_->Listen(5);

  client_.reset(natsf_->CreateAsyncSocket(AF_INET, SOCK_STREAM));
  // Wire the signal handlers before connecting. Both threads are already
  // running, so the connect/accept events could otherwise fire before any
  // slot is attached.
  ConnectEvents();

  // A non-negative result means success (errors are presumably negative).
  // Note the operand order: EXPECT_GE(a, b) checks a >= b.
  EXPECT_GE(client_->Bind(int_addr_), 0);
  EXPECT_GE(client_->Connect(server_->GetLocalAddress()), 0);

  EXPECT_TRUE_WAIT(connected_, 1000);
  EXPECT_EQ(client_->GetRemoteAddress(), server_->GetLocalAddress());
  // accepted_ is only set by OnAcceptEvent; stop here rather than dereference
  // null if nothing was accepted.
  ASSERT_TRUE(accepted_ != nullptr);
  EXPECT_EQ(accepted_->GetRemoteAddress().ipaddr(), ext_addr_.ipaddr());

  std::unique_ptr<rtc::TestClient> in(CreateTCPTestClient(client_.release()));
  std::unique_ptr<rtc::TestClient> out(
      CreateTCPTestClient(accepted_.release()));

  const char* buf = "test_packet";
  size_t len = strlen(buf);

  in->Send(buf, len);
  SocketAddress trans_addr;
  EXPECT_TRUE(out->CheckNextPacket(buf, len, &trans_addr));

  out->Send(buf, len);
  EXPECT_TRUE(in->CheckNextPacket(buf, len, &trans_addr));
}
404
405 } // namespace
406 } // namespace rtc
407