1 /*
2  *  Copyright 2004 The WebRTC Project Authors. All rights reserved.
3  *
4  *  Use of this source code is governed by a BSD-style license
5  *  that can be found in the LICENSE file in the root of the source
6  *  tree. An additional intellectual property rights grant can be found
7  *  in the file PATENTS.  All contributing project authors may
8  *  be found in the AUTHORS file in the root of the source tree.
9  */
10 
11 #include <string.h>
12 
13 #include <algorithm>
14 #include <memory>
15 #include <string>
16 #include <vector>
17 
18 #include "absl/memory/memory.h"
19 #include "rtc_base/async_packet_socket.h"
20 #include "rtc_base/async_socket.h"
21 #include "rtc_base/async_tcp_socket.h"
22 #include "rtc_base/async_udp_socket.h"
23 #include "rtc_base/gunit.h"
24 #include "rtc_base/ip_address.h"
25 #include "rtc_base/logging.h"
26 #include "rtc_base/nat_server.h"
27 #include "rtc_base/nat_socket_factory.h"
28 #include "rtc_base/nat_types.h"
29 #include "rtc_base/net_helpers.h"
30 #include "rtc_base/network.h"
31 #include "rtc_base/physical_socket_server.h"
32 #include "rtc_base/socket_address.h"
33 #include "rtc_base/socket_factory.h"
34 #include "rtc_base/socket_server.h"
35 #include "rtc_base/test_client.h"
36 #include "rtc_base/third_party/sigslot/sigslot.h"
37 #include "rtc_base/thread.h"
38 #include "rtc_base/virtual_socket_server.h"
39 #include "test/gtest.h"
40 
41 namespace rtc {
42 namespace {
43 
CheckReceive(TestClient * client,bool should_receive,const char * buf,size_t size)44 bool CheckReceive(TestClient* client,
45                   bool should_receive,
46                   const char* buf,
47                   size_t size) {
48   return (should_receive) ? client->CheckNextPacket(buf, size, 0)
49                           : client->CheckNoPacket();
50 }
51 
CreateTestClient(SocketFactory * factory,const SocketAddress & local_addr)52 TestClient* CreateTestClient(SocketFactory* factory,
53                              const SocketAddress& local_addr) {
54   return new TestClient(
55       absl::WrapUnique(AsyncUDPSocket::Create(factory, local_addr)));
56 }
57 
CreateTCPTestClient(AsyncSocket * socket)58 TestClient* CreateTCPTestClient(AsyncSocket* socket) {
59   return new TestClient(std::make_unique<AsyncTCPSocket>(socket, false));
60 }
61 
62 // Tests that when sending from internal_addr to external_addrs through the
63 // NAT type specified by nat_type, all external addrs receive the sent packet
64 // and, if exp_same is true, all use the same mapped-address on the NAT.
TestSend(SocketServer * internal,const SocketAddress & internal_addr,SocketServer * external,const SocketAddress external_addrs[4],NATType nat_type,bool exp_same)65 void TestSend(SocketServer* internal,
66               const SocketAddress& internal_addr,
67               SocketServer* external,
68               const SocketAddress external_addrs[4],
69               NATType nat_type,
70               bool exp_same) {
71   Thread th_int(internal);
72   Thread th_ext(external);
73 
74   SocketAddress server_addr = internal_addr;
75   server_addr.SetPort(0);  // Auto-select a port
76   NATServer* nat = new NATServer(nat_type, internal, server_addr, server_addr,
77                                  external, external_addrs[0]);
78   NATSocketFactory* natsf = new NATSocketFactory(
79       internal, nat->internal_udp_address(), nat->internal_tcp_address());
80 
81   TestClient* in = CreateTestClient(natsf, internal_addr);
82   TestClient* out[4];
83   for (int i = 0; i < 4; i++)
84     out[i] = CreateTestClient(external, external_addrs[i]);
85 
86   th_int.Start();
87   th_ext.Start();
88 
89   const char* buf = "filter_test";
90   size_t len = strlen(buf);
91 
92   in->SendTo(buf, len, out[0]->address());
93   SocketAddress trans_addr;
94   EXPECT_TRUE(out[0]->CheckNextPacket(buf, len, &trans_addr));
95 
96   for (int i = 1; i < 4; i++) {
97     in->SendTo(buf, len, out[i]->address());
98     SocketAddress trans_addr2;
99     EXPECT_TRUE(out[i]->CheckNextPacket(buf, len, &trans_addr2));
100     bool are_same = (trans_addr == trans_addr2);
101     ASSERT_EQ(are_same, exp_same) << "same translated address";
102     ASSERT_NE(AF_UNSPEC, trans_addr.family());
103     ASSERT_NE(AF_UNSPEC, trans_addr2.family());
104   }
105 
106   th_int.Stop();
107   th_ext.Stop();
108 
109   delete nat;
110   delete natsf;
111   delete in;
112   for (int i = 0; i < 4; i++)
113     delete out[i];
114 }
115 
116 // Tests that when sending from external_addrs to internal_addr, the packet
117 // is delivered according to the specified filter_ip and filter_port rules.
TestRecv(SocketServer * internal,const SocketAddress & internal_addr,SocketServer * external,const SocketAddress external_addrs[4],NATType nat_type,bool filter_ip,bool filter_port)118 void TestRecv(SocketServer* internal,
119               const SocketAddress& internal_addr,
120               SocketServer* external,
121               const SocketAddress external_addrs[4],
122               NATType nat_type,
123               bool filter_ip,
124               bool filter_port) {
125   Thread th_int(internal);
126   Thread th_ext(external);
127 
128   SocketAddress server_addr = internal_addr;
129   server_addr.SetPort(0);  // Auto-select a port
130   NATServer* nat = new NATServer(nat_type, internal, server_addr, server_addr,
131                                  external, external_addrs[0]);
132   NATSocketFactory* natsf = new NATSocketFactory(
133       internal, nat->internal_udp_address(), nat->internal_tcp_address());
134 
135   TestClient* in = CreateTestClient(natsf, internal_addr);
136   TestClient* out[4];
137   for (int i = 0; i < 4; i++)
138     out[i] = CreateTestClient(external, external_addrs[i]);
139 
140   th_int.Start();
141   th_ext.Start();
142 
143   const char* buf = "filter_test";
144   size_t len = strlen(buf);
145 
146   in->SendTo(buf, len, out[0]->address());
147   SocketAddress trans_addr;
148   EXPECT_TRUE(out[0]->CheckNextPacket(buf, len, &trans_addr));
149 
150   out[1]->SendTo(buf, len, trans_addr);
151   EXPECT_TRUE(CheckReceive(in, !filter_ip, buf, len));
152 
153   out[2]->SendTo(buf, len, trans_addr);
154   EXPECT_TRUE(CheckReceive(in, !filter_port, buf, len));
155 
156   out[3]->SendTo(buf, len, trans_addr);
157   EXPECT_TRUE(CheckReceive(in, !filter_ip && !filter_port, buf, len));
158 
159   th_int.Stop();
160   th_ext.Stop();
161 
162   delete nat;
163   delete natsf;
164   delete in;
165   for (int i = 0; i < 4; i++)
166     delete out[i];
167 }
168 
169 // Tests that NATServer allocates bindings properly.
TestBindings(SocketServer * internal,const SocketAddress & internal_addr,SocketServer * external,const SocketAddress external_addrs[4])170 void TestBindings(SocketServer* internal,
171                   const SocketAddress& internal_addr,
172                   SocketServer* external,
173                   const SocketAddress external_addrs[4]) {
174   TestSend(internal, internal_addr, external, external_addrs, NAT_OPEN_CONE,
175            true);
176   TestSend(internal, internal_addr, external, external_addrs,
177            NAT_ADDR_RESTRICTED, true);
178   TestSend(internal, internal_addr, external, external_addrs,
179            NAT_PORT_RESTRICTED, true);
180   TestSend(internal, internal_addr, external, external_addrs, NAT_SYMMETRIC,
181            false);
182 }
183 
184 // Tests that NATServer filters packets properly.
TestFilters(SocketServer * internal,const SocketAddress & internal_addr,SocketServer * external,const SocketAddress external_addrs[4])185 void TestFilters(SocketServer* internal,
186                  const SocketAddress& internal_addr,
187                  SocketServer* external,
188                  const SocketAddress external_addrs[4]) {
189   TestRecv(internal, internal_addr, external, external_addrs, NAT_OPEN_CONE,
190            false, false);
191   TestRecv(internal, internal_addr, external, external_addrs,
192            NAT_ADDR_RESTRICTED, true, false);
193   TestRecv(internal, internal_addr, external, external_addrs,
194            NAT_PORT_RESTRICTED, true, true);
195   TestRecv(internal, internal_addr, external, external_addrs, NAT_SYMMETRIC,
196            true, true);
197 }
198 
TestConnectivity(const SocketAddress & src,const IPAddress & dst)199 bool TestConnectivity(const SocketAddress& src, const IPAddress& dst) {
200   // The physical NAT tests require connectivity to the selected ip from the
201   // internal address used for the NAT. Things like firewalls can break that, so
202   // check to see if it's worth even trying with this ip.
203   std::unique_ptr<PhysicalSocketServer> pss(new PhysicalSocketServer());
204   std::unique_ptr<AsyncSocket> client(
205       pss->CreateAsyncSocket(src.family(), SOCK_DGRAM));
206   std::unique_ptr<AsyncSocket> server(
207       pss->CreateAsyncSocket(src.family(), SOCK_DGRAM));
208   if (client->Bind(SocketAddress(src.ipaddr(), 0)) != 0 ||
209       server->Bind(SocketAddress(dst, 0)) != 0) {
210     return false;
211   }
212   const char* buf = "hello other socket";
213   size_t len = strlen(buf);
214   int sent = client->SendTo(buf, len, server->GetLocalAddress());
215   SocketAddress addr;
216   const size_t kRecvBufSize = 64;
217   char recvbuf[kRecvBufSize];
218   Thread::Current()->SleepMs(100);
219   int received = server->RecvFrom(recvbuf, kRecvBufSize, &addr, nullptr);
220   return received == sent && ::memcmp(buf, recvbuf, len) == 0;
221 }
222 
TestPhysicalInternal(const SocketAddress & int_addr)223 void TestPhysicalInternal(const SocketAddress& int_addr) {
224   BasicNetworkManager network_manager;
225   network_manager.StartUpdating();
226   // Process pending messages so the network list is updated.
227   Thread::Current()->ProcessMessages(0);
228 
229   std::vector<Network*> networks;
230   network_manager.GetNetworks(&networks);
231   networks.erase(std::remove_if(networks.begin(), networks.end(),
232                                 [](rtc::Network* network) {
233                                   return rtc::kDefaultNetworkIgnoreMask &
234                                          network->type();
235                                 }),
236                  networks.end());
237   if (networks.empty()) {
238     RTC_LOG(LS_WARNING) << "Not enough network adapters for test.";
239     return;
240   }
241 
242   SocketAddress ext_addr1(int_addr);
243   SocketAddress ext_addr2;
244   // Find an available IP with matching family. The test breaks if int_addr
245   // can't talk to ip, so check for connectivity as well.
246   for (std::vector<Network*>::iterator it = networks.begin();
247        it != networks.end(); ++it) {
248     const IPAddress& ip = (*it)->GetBestIP();
249     if (ip.family() == int_addr.family() && TestConnectivity(int_addr, ip)) {
250       ext_addr2.SetIP(ip);
251       break;
252     }
253   }
254   if (ext_addr2.IsNil()) {
255     RTC_LOG(LS_WARNING) << "No available IP of same family as "
256                         << int_addr.ToString();
257     return;
258   }
259 
260   RTC_LOG(LS_INFO) << "selected ip " << ext_addr2.ipaddr().ToString();
261 
262   SocketAddress ext_addrs[4] = {
263       SocketAddress(ext_addr1), SocketAddress(ext_addr2),
264       SocketAddress(ext_addr1), SocketAddress(ext_addr2)};
265 
266   std::unique_ptr<PhysicalSocketServer> int_pss(new PhysicalSocketServer());
267   std::unique_ptr<PhysicalSocketServer> ext_pss(new PhysicalSocketServer());
268 
269   TestBindings(int_pss.get(), int_addr, ext_pss.get(), ext_addrs);
270   TestFilters(int_pss.get(), int_addr, ext_pss.get(), ext_addrs);
271 }
272 
// Runs the physical-socket NAT tests with the IPv4 loopback address as the
// internal address.
TEST(NatTest, TestPhysicalIPv4) {
  const SocketAddress ipv4_loopback("127.0.0.1", 0);
  TestPhysicalInternal(ipv4_loopback);
}
276 
// Runs the physical-socket NAT tests with the IPv6 loopback address as the
// internal address. Only logs a warning when IPv6 is not enabled.
TEST(NatTest, TestPhysicalIPv6) {
  if (!HasIPv6Enabled()) {
    RTC_LOG(LS_WARNING) << "No IPv6, skipping";
    return;
  }
  const SocketAddress ipv6_loopback("::1", 0);
  TestPhysicalInternal(ipv6_loopback);
}
284 
285 namespace {
286 
287 class TestVirtualSocketServer : public VirtualSocketServer {
288  public:
289   // Expose this publicly
GetNextIP(int af)290   IPAddress GetNextIP(int af) { return VirtualSocketServer::GetNextIP(af); }
291 };
292 
293 }  // namespace
294 
TestVirtualInternal(int family)295 void TestVirtualInternal(int family) {
296   std::unique_ptr<TestVirtualSocketServer> int_vss(
297       new TestVirtualSocketServer());
298   std::unique_ptr<TestVirtualSocketServer> ext_vss(
299       new TestVirtualSocketServer());
300 
301   SocketAddress int_addr;
302   SocketAddress ext_addrs[4];
303   int_addr.SetIP(int_vss->GetNextIP(family));
304   ext_addrs[0].SetIP(ext_vss->GetNextIP(int_addr.family()));
305   ext_addrs[1].SetIP(ext_vss->GetNextIP(int_addr.family()));
306   ext_addrs[2].SetIP(ext_addrs[0].ipaddr());
307   ext_addrs[3].SetIP(ext_addrs[1].ipaddr());
308 
309   TestBindings(int_vss.get(), int_addr, ext_vss.get(), ext_addrs);
310   TestFilters(int_vss.get(), int_addr, ext_vss.get(), ext_addrs);
311 }
312 
// Runs the virtual-socket NAT tests with IPv4 addresses.
TEST(NatTest, TestVirtualIPv4) {
  const int address_family = AF_INET;
  TestVirtualInternal(address_family);
}
316 
// Runs the virtual-socket NAT tests with IPv6 addresses. Only logs a warning
// when IPv6 is not enabled.
TEST(NatTest, TestVirtualIPv6) {
  if (!HasIPv6Enabled()) {
    RTC_LOG(LS_WARNING) << "No IPv6, skipping";
    return;
  }
  TestVirtualInternal(AF_INET6);
}
324 
325 class NatTcpTest : public ::testing::Test, public sigslot::has_slots<> {
326  public:
NatTcpTest()327   NatTcpTest()
328       : int_addr_("192.168.0.1", 0),
329         ext_addr_("10.0.0.1", 0),
330         connected_(false),
331         int_vss_(new TestVirtualSocketServer()),
332         ext_vss_(new TestVirtualSocketServer()),
333         int_thread_(new Thread(int_vss_.get())),
334         ext_thread_(new Thread(ext_vss_.get())),
335         nat_(new NATServer(NAT_OPEN_CONE,
336                            int_vss_.get(),
337                            int_addr_,
338                            int_addr_,
339                            ext_vss_.get(),
340                            ext_addr_)),
341         natsf_(new NATSocketFactory(int_vss_.get(),
342                                     nat_->internal_udp_address(),
343                                     nat_->internal_tcp_address())) {
344     int_thread_->Start();
345     ext_thread_->Start();
346   }
347 
OnConnectEvent(AsyncSocket * socket)348   void OnConnectEvent(AsyncSocket* socket) { connected_ = true; }
349 
OnAcceptEvent(AsyncSocket * socket)350   void OnAcceptEvent(AsyncSocket* socket) {
351     accepted_.reset(server_->Accept(nullptr));
352   }
353 
OnCloseEvent(AsyncSocket * socket,int error)354   void OnCloseEvent(AsyncSocket* socket, int error) {}
355 
ConnectEvents()356   void ConnectEvents() {
357     server_->SignalReadEvent.connect(this, &NatTcpTest::OnAcceptEvent);
358     client_->SignalConnectEvent.connect(this, &NatTcpTest::OnConnectEvent);
359   }
360 
361   SocketAddress int_addr_;
362   SocketAddress ext_addr_;
363   bool connected_;
364   std::unique_ptr<TestVirtualSocketServer> int_vss_;
365   std::unique_ptr<TestVirtualSocketServer> ext_vss_;
366   std::unique_ptr<Thread> int_thread_;
367   std::unique_ptr<Thread> ext_thread_;
368   std::unique_ptr<NATServer> nat_;
369   std::unique_ptr<NATSocketFactory> natsf_;
370   std::unique_ptr<AsyncSocket> client_;
371   std::unique_ptr<AsyncSocket> server_;
372   std::unique_ptr<AsyncSocket> accepted_;
373 };
374 
// Connects a TCP client created by the NAT socket factory to a server socket
// listening on the external virtual socket server, checks the addresses seen
// on both ends, and then sends one packet in each direction.
TEST_F(NatTcpTest, DISABLED_TestConnectOut) {
  server_.reset(ext_vss_->CreateAsyncSocket(AF_INET, SOCK_STREAM));
  EXPECT_EQ(0, server_->Bind(ext_addr_));
  EXPECT_EQ(0, server_->Listen(5));

  client_.reset(natsf_->CreateAsyncSocket(AF_INET, SOCK_STREAM));
  // Expect a non-negative result. EXPECT_GE(0, x) would assert x <= 0 and so
  // could never flag a negative error return.
  EXPECT_LE(0, client_->Bind(int_addr_));
  EXPECT_LE(0, client_->Connect(server_->GetLocalAddress()));

  ConnectEvents();

  EXPECT_TRUE_WAIT(connected_, 1000);
  // accepted_ is only set once the server's read event has fired. Everything
  // below dereferences or releases it, so stop here if no connection was
  // accepted in time.
  ASSERT_TRUE_WAIT(accepted_ != nullptr, 1000);
  EXPECT_EQ(client_->GetRemoteAddress(), server_->GetLocalAddress());
  EXPECT_EQ(accepted_->GetRemoteAddress().ipaddr(), ext_addr_.ipaddr());

  // The test clients take over ownership of the connected sockets.
  std::unique_ptr<rtc::TestClient> in(CreateTCPTestClient(client_.release()));
  std::unique_ptr<rtc::TestClient> out(
      CreateTCPTestClient(accepted_.release()));

  const char* buf = "test_packet";
  size_t len = strlen(buf);

  in->Send(buf, len);
  SocketAddress trans_addr;
  EXPECT_TRUE(out->CheckNextPacket(buf, len, &trans_addr));

  out->Send(buf, len);
  EXPECT_TRUE(in->CheckNextPacket(buf, len, &trans_addr));
}
404 
405 }  // namespace
406 }  // namespace rtc
407