1 /*
2  *  Copyright 2004 The WebRTC Project Authors. All rights reserved.
3  *
4  *  Use of this source code is governed by a BSD-style license
5  *  that can be found in the LICENSE file in the root of the source
6  *  tree. An additional intellectual property rights grant can be found
7  *  in the file PATENTS.  All contributing project authors may
8  *  be found in the AUTHORS file in the root of the source tree.
9  */
10 
11 #include "rtc_base/natsocketfactory.h"
12 
13 #include "rtc_base/arraysize.h"
14 #include "rtc_base/checks.h"
15 #include "rtc_base/logging.h"
16 #include "rtc_base/natserver.h"
17 #include "rtc_base/virtualsocketserver.h"
18 
19 namespace rtc {
20 
21 // Packs the given socketaddress into the buffer in buf, in the quasi-STUN
22 // format that the natserver uses.
23 // Returns 0 if an invalid address is passed.
PackAddressForNAT(char * buf,size_t buf_size,const SocketAddress & remote_addr)24 size_t PackAddressForNAT(char* buf, size_t buf_size,
25                          const SocketAddress& remote_addr) {
26   const IPAddress& ip = remote_addr.ipaddr();
27   int family = ip.family();
28   buf[0] = 0;
29   buf[1] = family;
30   // Writes the port.
31   *(reinterpret_cast<uint16_t*>(&buf[2])) = HostToNetwork16(remote_addr.port());
32   if (family == AF_INET) {
33     RTC_DCHECK(buf_size >= kNATEncodedIPv4AddressSize);
34     in_addr v4addr = ip.ipv4_address();
35     memcpy(&buf[4], &v4addr, kNATEncodedIPv4AddressSize - 4);
36     return kNATEncodedIPv4AddressSize;
37   } else if (family == AF_INET6) {
38     RTC_DCHECK(buf_size >= kNATEncodedIPv6AddressSize);
39     in6_addr v6addr = ip.ipv6_address();
40     memcpy(&buf[4], &v6addr, kNATEncodedIPv6AddressSize - 4);
41     return kNATEncodedIPv6AddressSize;
42   }
43   return 0U;
44 }
45 
46 // Decodes the remote address from a packet that has been encoded with the nat's
47 // quasi-STUN format. Returns the length of the address (i.e., the offset into
48 // data where the original packet starts).
UnpackAddressFromNAT(const char * buf,size_t buf_size,SocketAddress * remote_addr)49 size_t UnpackAddressFromNAT(const char* buf, size_t buf_size,
50                             SocketAddress* remote_addr) {
51   RTC_DCHECK(buf_size >= 8);
52   RTC_DCHECK(buf[0] == 0);
53   int family = buf[1];
54   uint16_t port =
55       NetworkToHost16(*(reinterpret_cast<const uint16_t*>(&buf[2])));
56   if (family == AF_INET) {
57     const in_addr* v4addr = reinterpret_cast<const in_addr*>(&buf[4]);
58     *remote_addr = SocketAddress(IPAddress(*v4addr), port);
59     return kNATEncodedIPv4AddressSize;
60   } else if (family == AF_INET6) {
61     RTC_DCHECK(buf_size >= 20);
62     const in6_addr* v6addr = reinterpret_cast<const in6_addr*>(&buf[4]);
63     *remote_addr = SocketAddress(IPAddress(*v6addr), port);
64     return kNATEncodedIPv6AddressSize;
65   }
66   return 0U;
67 }
68 
69 
70 // NATSocket
71 class NATSocket : public AsyncSocket, public sigslot::has_slots<> {
72  public:
NATSocket(NATInternalSocketFactory * sf,int family,int type)73   explicit NATSocket(NATInternalSocketFactory* sf, int family, int type)
74       : sf_(sf),
75         family_(family),
76         type_(type),
77         connected_(false),
78         socket_(nullptr),
79         buf_(nullptr),
80         size_(0) {}
81 
~NATSocket()82   ~NATSocket() override {
83     delete socket_;
84     delete[] buf_;
85   }
86 
GetLocalAddress() const87   SocketAddress GetLocalAddress() const override {
88     return (socket_) ? socket_->GetLocalAddress() : SocketAddress();
89   }
90 
GetRemoteAddress() const91   SocketAddress GetRemoteAddress() const override {
92     return remote_addr_;  // will be NIL if not connected
93   }
94 
Bind(const SocketAddress & addr)95   int Bind(const SocketAddress& addr) override {
96     if (socket_) {  // already bound, bubble up error
97       return -1;
98     }
99 
100     return BindInternal(addr);
101   }
102 
Connect(const SocketAddress & addr)103   int Connect(const SocketAddress& addr) override {
104     int result = 0;
105     // If we're not already bound (meaning |socket_| is null), bind to ANY
106     // address.
107     if (!socket_) {
108       result = BindInternal(SocketAddress(GetAnyIP(family_), 0));
109       if (result < 0) {
110         return result;
111       }
112     }
113 
114     if (type_ == SOCK_STREAM) {
115       result = socket_->Connect(server_addr_.IsNil() ? addr : server_addr_);
116     } else {
117       connected_ = true;
118     }
119 
120     if (result >= 0) {
121       remote_addr_ = addr;
122     }
123 
124     return result;
125   }
126 
Send(const void * data,size_t size)127   int Send(const void* data, size_t size) override {
128     RTC_DCHECK(connected_);
129     return SendTo(data, size, remote_addr_);
130   }
131 
SendTo(const void * data,size_t size,const SocketAddress & addr)132   int SendTo(const void* data,
133              size_t size,
134              const SocketAddress& addr) override {
135     RTC_DCHECK(!connected_ || addr == remote_addr_);
136     if (server_addr_.IsNil() || type_ == SOCK_STREAM) {
137       return socket_->SendTo(data, size, addr);
138     }
139     // This array will be too large for IPv4 packets, but only by 12 bytes.
140     std::unique_ptr<char[]> buf(new char[size + kNATEncodedIPv6AddressSize]);
141     size_t addrlength = PackAddressForNAT(buf.get(),
142                                           size + kNATEncodedIPv6AddressSize,
143                                           addr);
144     size_t encoded_size = size + addrlength;
145     memcpy(buf.get() + addrlength, data, size);
146     int result = socket_->SendTo(buf.get(), encoded_size, server_addr_);
147     if (result >= 0) {
148       RTC_DCHECK(result == static_cast<int>(encoded_size));
149       result = result - static_cast<int>(addrlength);
150     }
151     return result;
152   }
153 
Recv(void * data,size_t size,int64_t * timestamp)154   int Recv(void* data, size_t size, int64_t* timestamp) override {
155     SocketAddress addr;
156     return RecvFrom(data, size, &addr, timestamp);
157   }
158 
RecvFrom(void * data,size_t size,SocketAddress * out_addr,int64_t * timestamp)159   int RecvFrom(void* data,
160                size_t size,
161                SocketAddress* out_addr,
162                int64_t* timestamp) override {
163     if (server_addr_.IsNil() || type_ == SOCK_STREAM) {
164       return socket_->RecvFrom(data, size, out_addr, timestamp);
165     }
166     // Make sure we have enough room to read the requested amount plus the
167     // largest possible header address.
168     SocketAddress remote_addr;
169     Grow(size + kNATEncodedIPv6AddressSize);
170 
171     // Read the packet from the socket.
172     int result = socket_->RecvFrom(buf_, size_, &remote_addr, timestamp);
173     if (result >= 0) {
174       RTC_DCHECK(remote_addr == server_addr_);
175 
176       // TODO: we need better framing so we know how many bytes we can
177       // return before we need to read the next address. For UDP, this will be
178       // fine as long as the reader always reads everything in the packet.
179       RTC_DCHECK((size_t)result < size_);
180 
181       // Decode the wire packet into the actual results.
182       SocketAddress real_remote_addr;
183       size_t addrlength = UnpackAddressFromNAT(buf_, result, &real_remote_addr);
184       memcpy(data, buf_ + addrlength, result - addrlength);
185 
186       // Make sure this packet should be delivered before returning it.
187       if (!connected_ || (real_remote_addr == remote_addr_)) {
188         if (out_addr)
189           *out_addr = real_remote_addr;
190         result = result - static_cast<int>(addrlength);
191       } else {
192         RTC_LOG(LS_ERROR) << "Dropping packet from unknown remote address: "
193                           << real_remote_addr.ToString();
194         result = 0;  // Tell the caller we didn't read anything
195       }
196     }
197 
198     return result;
199   }
200 
Close()201   int Close() override {
202     int result = 0;
203     if (socket_) {
204       result = socket_->Close();
205       if (result >= 0) {
206         connected_ = false;
207         remote_addr_ = SocketAddress();
208         delete socket_;
209         socket_ = nullptr;
210       }
211     }
212     return result;
213   }
214 
Listen(int backlog)215   int Listen(int backlog) override { return socket_->Listen(backlog); }
Accept(SocketAddress * paddr)216   AsyncSocket* Accept(SocketAddress* paddr) override {
217     return socket_->Accept(paddr);
218   }
GetError() const219   int GetError() const override {
220     return socket_ ? socket_->GetError() : error_;
221   }
SetError(int error)222   void SetError(int error) override {
223     if (socket_) {
224       socket_->SetError(error);
225     } else {
226       error_ = error;
227     }
228   }
GetState() const229   ConnState GetState() const override {
230     return connected_ ? CS_CONNECTED : CS_CLOSED;
231   }
GetOption(Option opt,int * value)232   int GetOption(Option opt, int* value) override {
233     return socket_->GetOption(opt, value);
234   }
SetOption(Option opt,int value)235   int SetOption(Option opt, int value) override {
236     return socket_->SetOption(opt, value);
237   }
238 
OnConnectEvent(AsyncSocket * socket)239   void OnConnectEvent(AsyncSocket* socket) {
240     // If we're NATed, we need to send a message with the real addr to use.
241     RTC_DCHECK(socket == socket_);
242     if (server_addr_.IsNil()) {
243       connected_ = true;
244       SignalConnectEvent(this);
245     } else {
246       SendConnectRequest();
247     }
248   }
OnReadEvent(AsyncSocket * socket)249   void OnReadEvent(AsyncSocket* socket) {
250     // If we're NATed, we need to process the connect reply.
251     RTC_DCHECK(socket == socket_);
252     if (type_ == SOCK_STREAM && !server_addr_.IsNil() && !connected_) {
253       HandleConnectReply();
254     } else {
255       SignalReadEvent(this);
256     }
257   }
OnWriteEvent(AsyncSocket * socket)258   void OnWriteEvent(AsyncSocket* socket) {
259     RTC_DCHECK(socket == socket_);
260     SignalWriteEvent(this);
261   }
OnCloseEvent(AsyncSocket * socket,int error)262   void OnCloseEvent(AsyncSocket* socket, int error) {
263     RTC_DCHECK(socket == socket_);
264     SignalCloseEvent(this, error);
265   }
266 
267  private:
BindInternal(const SocketAddress & addr)268   int BindInternal(const SocketAddress& addr) {
269     RTC_DCHECK(!socket_);
270 
271     int result;
272     socket_ = sf_->CreateInternalSocket(family_, type_, addr, &server_addr_);
273     result = (socket_) ? socket_->Bind(addr) : -1;
274     if (result >= 0) {
275       socket_->SignalConnectEvent.connect(this, &NATSocket::OnConnectEvent);
276       socket_->SignalReadEvent.connect(this, &NATSocket::OnReadEvent);
277       socket_->SignalWriteEvent.connect(this, &NATSocket::OnWriteEvent);
278       socket_->SignalCloseEvent.connect(this, &NATSocket::OnCloseEvent);
279     } else {
280       server_addr_.Clear();
281       delete socket_;
282       socket_ = nullptr;
283     }
284 
285     return result;
286   }
287 
288   // Makes sure the buffer is at least the given size.
Grow(size_t new_size)289   void Grow(size_t new_size) {
290     if (size_ < new_size) {
291       delete[] buf_;
292       size_ = new_size;
293       buf_ = new char[size_];
294     }
295   }
296 
297   // Sends the destination address to the server to tell it to connect.
SendConnectRequest()298   void SendConnectRequest() {
299     char buf[kNATEncodedIPv6AddressSize];
300     size_t length = PackAddressForNAT(buf, arraysize(buf), remote_addr_);
301     socket_->Send(buf, length);
302   }
303 
304   // Handles the byte sent back from the server and fires the appropriate event.
HandleConnectReply()305   void HandleConnectReply() {
306     char code;
307     socket_->Recv(&code, sizeof(code), nullptr);
308     if (code == 0) {
309       connected_ = true;
310       SignalConnectEvent(this);
311     } else {
312       Close();
313       SignalCloseEvent(this, code);
314     }
315   }
316 
317   NATInternalSocketFactory* sf_;
318   int family_;
319   int type_;
320   bool connected_;
321   SocketAddress remote_addr_;
322   SocketAddress server_addr_;  // address of the NAT server
323   AsyncSocket* socket_;
324   // Need to hold error in case it occurs before the socket is created.
325   int error_ = 0;
326   char* buf_;
327   size_t size_;
328 };
329 
330 // NATSocketFactory
NATSocketFactory(SocketFactory * factory,const SocketAddress & nat_udp_addr,const SocketAddress & nat_tcp_addr)331 NATSocketFactory::NATSocketFactory(SocketFactory* factory,
332                                    const SocketAddress& nat_udp_addr,
333                                    const SocketAddress& nat_tcp_addr)
334     : factory_(factory), nat_udp_addr_(nat_udp_addr),
335       nat_tcp_addr_(nat_tcp_addr) {
336 }
337 
CreateSocket(int type)338 Socket* NATSocketFactory::CreateSocket(int type) {
339   return CreateSocket(AF_INET, type);
340 }
341 
CreateSocket(int family,int type)342 Socket* NATSocketFactory::CreateSocket(int family, int type) {
343   return new NATSocket(this, family, type);
344 }
345 
CreateAsyncSocket(int type)346 AsyncSocket* NATSocketFactory::CreateAsyncSocket(int type) {
347   return CreateAsyncSocket(AF_INET, type);
348 }
349 
CreateAsyncSocket(int family,int type)350 AsyncSocket* NATSocketFactory::CreateAsyncSocket(int family, int type) {
351   return new NATSocket(this, family, type);
352 }
353 
CreateInternalSocket(int family,int type,const SocketAddress & local_addr,SocketAddress * nat_addr)354 AsyncSocket* NATSocketFactory::CreateInternalSocket(int family, int type,
355     const SocketAddress& local_addr, SocketAddress* nat_addr) {
356   if (type == SOCK_STREAM) {
357     *nat_addr = nat_tcp_addr_;
358   } else {
359     *nat_addr = nat_udp_addr_;
360   }
361   return factory_->CreateAsyncSocket(family, type);
362 }
363 
364 // NATSocketServer
NATSocketServer(SocketServer * server)365 NATSocketServer::NATSocketServer(SocketServer* server)
366     : server_(server), msg_queue_(nullptr) {}
367 
GetTranslator(const SocketAddress & ext_ip)368 NATSocketServer::Translator* NATSocketServer::GetTranslator(
369     const SocketAddress& ext_ip) {
370   return nats_.Get(ext_ip);
371 }
372 
AddTranslator(const SocketAddress & ext_ip,const SocketAddress & int_ip,NATType type)373 NATSocketServer::Translator* NATSocketServer::AddTranslator(
374     const SocketAddress& ext_ip, const SocketAddress& int_ip, NATType type) {
375   // Fail if a translator already exists with this extternal address.
376   if (nats_.Get(ext_ip))
377     return nullptr;
378 
379   return nats_.Add(ext_ip, new Translator(this, type, int_ip, server_, ext_ip));
380 }
381 
RemoveTranslator(const SocketAddress & ext_ip)382 void NATSocketServer::RemoveTranslator(
383     const SocketAddress& ext_ip) {
384   nats_.Remove(ext_ip);
385 }
386 
CreateSocket(int type)387 Socket* NATSocketServer::CreateSocket(int type) {
388   return CreateSocket(AF_INET, type);
389 }
390 
CreateSocket(int family,int type)391 Socket* NATSocketServer::CreateSocket(int family, int type) {
392   return new NATSocket(this, family, type);
393 }
394 
CreateAsyncSocket(int type)395 AsyncSocket* NATSocketServer::CreateAsyncSocket(int type) {
396   return CreateAsyncSocket(AF_INET, type);
397 }
398 
CreateAsyncSocket(int family,int type)399 AsyncSocket* NATSocketServer::CreateAsyncSocket(int family, int type) {
400   return new NATSocket(this, family, type);
401 }
402 
SetMessageQueue(MessageQueue * queue)403 void NATSocketServer::SetMessageQueue(MessageQueue* queue) {
404   msg_queue_ = queue;
405   server_->SetMessageQueue(queue);
406 }
407 
Wait(int cms,bool process_io)408 bool NATSocketServer::Wait(int cms, bool process_io) {
409   return server_->Wait(cms, process_io);
410 }
411 
WakeUp()412 void NATSocketServer::WakeUp() {
413   server_->WakeUp();
414 }
415 
CreateInternalSocket(int family,int type,const SocketAddress & local_addr,SocketAddress * nat_addr)416 AsyncSocket* NATSocketServer::CreateInternalSocket(int family, int type,
417     const SocketAddress& local_addr, SocketAddress* nat_addr) {
418   AsyncSocket* socket = nullptr;
419   Translator* nat = nats_.FindClient(local_addr);
420   if (nat) {
421     socket = nat->internal_factory()->CreateAsyncSocket(family, type);
422     *nat_addr = (type == SOCK_STREAM) ?
423         nat->internal_tcp_address() : nat->internal_udp_address();
424   } else {
425     socket = server_->CreateAsyncSocket(family, type);
426   }
427   return socket;
428 }
429 
430 // NATSocketServer::Translator
Translator(NATSocketServer * server,NATType type,const SocketAddress & int_ip,SocketFactory * ext_factory,const SocketAddress & ext_ip)431 NATSocketServer::Translator::Translator(
432     NATSocketServer* server, NATType type, const SocketAddress& int_ip,
433     SocketFactory* ext_factory, const SocketAddress& ext_ip)
434     : server_(server) {
435   // Create a new private network, and a NATServer running on the private
436   // network that bridges to the external network. Also tell the private
437   // network to use the same message queue as us.
438   VirtualSocketServer* internal_server = new VirtualSocketServer();
439   internal_server->SetMessageQueue(server_->queue());
440   internal_factory_.reset(internal_server);
441   nat_server_.reset(new NATServer(type, internal_server, int_ip, int_ip,
442                                   ext_factory, ext_ip));
443 }
444 
445 NATSocketServer::Translator::~Translator() = default;
446 
GetTranslator(const SocketAddress & ext_ip)447 NATSocketServer::Translator* NATSocketServer::Translator::GetTranslator(
448     const SocketAddress& ext_ip) {
449   return nats_.Get(ext_ip);
450 }
451 
AddTranslator(const SocketAddress & ext_ip,const SocketAddress & int_ip,NATType type)452 NATSocketServer::Translator* NATSocketServer::Translator::AddTranslator(
453     const SocketAddress& ext_ip, const SocketAddress& int_ip, NATType type) {
454   // Fail if a translator already exists with this extternal address.
455   if (nats_.Get(ext_ip))
456     return nullptr;
457 
458   AddClient(ext_ip);
459   return nats_.Add(ext_ip,
460                    new Translator(server_, type, int_ip, server_, ext_ip));
461 }
RemoveTranslator(const SocketAddress & ext_ip)462 void NATSocketServer::Translator::RemoveTranslator(
463     const SocketAddress& ext_ip) {
464   nats_.Remove(ext_ip);
465   RemoveClient(ext_ip);
466 }
467 
AddClient(const SocketAddress & int_ip)468 bool NATSocketServer::Translator::AddClient(
469     const SocketAddress& int_ip) {
470   // Fail if a client already exists with this internal address.
471   if (clients_.find(int_ip) != clients_.end())
472     return false;
473 
474   clients_.insert(int_ip);
475   return true;
476 }
477 
RemoveClient(const SocketAddress & int_ip)478 void NATSocketServer::Translator::RemoveClient(
479     const SocketAddress& int_ip) {
480   std::set<SocketAddress>::iterator it = clients_.find(int_ip);
481   if (it != clients_.end()) {
482     clients_.erase(it);
483   }
484 }
485 
FindClient(const SocketAddress & int_ip)486 NATSocketServer::Translator* NATSocketServer::Translator::FindClient(
487     const SocketAddress& int_ip) {
488   // See if we have the requested IP, or any of our children do.
489   return (clients_.find(int_ip) != clients_.end()) ?
490       this : nats_.FindClient(int_ip);
491 }
492 
493 // NATSocketServer::TranslatorMap
~TranslatorMap()494 NATSocketServer::TranslatorMap::~TranslatorMap() {
495   for (TranslatorMap::iterator it = begin(); it != end(); ++it) {
496     delete it->second;
497   }
498 }
499 
Get(const SocketAddress & ext_ip)500 NATSocketServer::Translator* NATSocketServer::TranslatorMap::Get(
501     const SocketAddress& ext_ip) {
502   TranslatorMap::iterator it = find(ext_ip);
503   return (it != end()) ? it->second : nullptr;
504 }
505 
Add(const SocketAddress & ext_ip,Translator * nat)506 NATSocketServer::Translator* NATSocketServer::TranslatorMap::Add(
507     const SocketAddress& ext_ip, Translator* nat) {
508   (*this)[ext_ip] = nat;
509   return nat;
510 }
511 
Remove(const SocketAddress & ext_ip)512 void NATSocketServer::TranslatorMap::Remove(
513     const SocketAddress& ext_ip) {
514   TranslatorMap::iterator it = find(ext_ip);
515   if (it != end()) {
516     delete it->second;
517     erase(it);
518   }
519 }
520 
FindClient(const SocketAddress & int_ip)521 NATSocketServer::Translator* NATSocketServer::TranslatorMap::FindClient(
522     const SocketAddress& int_ip) {
523   Translator* nat = nullptr;
524   for (TranslatorMap::iterator it = begin(); it != end() && !nat; ++it) {
525     nat = it->second->FindClient(int_ip);
526   }
527   return nat;
528 }
529 
530 }  // namespace rtc
531