1 /*
2  *  Copyright 2004 The WebRTC Project Authors. All rights reserved.
3  *
4  *  Use of this source code is governed by a BSD-style license
5  *  that can be found in the LICENSE file in the root of the source
6  *  tree. An additional intellectual property rights grant can be found
7  *  in the file PATENTS.  All contributing project authors may
8  *  be found in the AUTHORS file in the root of the source tree.
9  */
10 
11 #include "webrtc/base/natsocketfactory.h"
12 
13 #include "webrtc/base/arraysize.h"
14 #include "webrtc/base/checks.h"
15 #include "webrtc/base/logging.h"
16 #include "webrtc/base/natserver.h"
17 #include "webrtc/base/virtualsocketserver.h"
18 
19 namespace rtc {
20 
21 // Packs the given socketaddress into the buffer in buf, in the quasi-STUN
22 // format that the natserver uses.
23 // Returns 0 if an invalid address is passed.
PackAddressForNAT(char * buf,size_t buf_size,const SocketAddress & remote_addr)24 size_t PackAddressForNAT(char* buf, size_t buf_size,
25                          const SocketAddress& remote_addr) {
26   const IPAddress& ip = remote_addr.ipaddr();
27   int family = ip.family();
28   buf[0] = 0;
29   buf[1] = family;
30   // Writes the port.
31   *(reinterpret_cast<uint16_t*>(&buf[2])) = HostToNetwork16(remote_addr.port());
32   if (family == AF_INET) {
33     RTC_DCHECK(buf_size >= kNATEncodedIPv4AddressSize);
34     in_addr v4addr = ip.ipv4_address();
35     memcpy(&buf[4], &v4addr, kNATEncodedIPv4AddressSize - 4);
36     return kNATEncodedIPv4AddressSize;
37   } else if (family == AF_INET6) {
38     RTC_DCHECK(buf_size >= kNATEncodedIPv6AddressSize);
39     in6_addr v6addr = ip.ipv6_address();
40     memcpy(&buf[4], &v6addr, kNATEncodedIPv6AddressSize - 4);
41     return kNATEncodedIPv6AddressSize;
42   }
43   return 0U;
44 }
45 
46 // Decodes the remote address from a packet that has been encoded with the nat's
47 // quasi-STUN format. Returns the length of the address (i.e., the offset into
48 // data where the original packet starts).
UnpackAddressFromNAT(const char * buf,size_t buf_size,SocketAddress * remote_addr)49 size_t UnpackAddressFromNAT(const char* buf, size_t buf_size,
50                             SocketAddress* remote_addr) {
51   RTC_DCHECK(buf_size >= 8);
52   RTC_DCHECK(buf[0] == 0);
53   int family = buf[1];
54   uint16_t port =
55       NetworkToHost16(*(reinterpret_cast<const uint16_t*>(&buf[2])));
56   if (family == AF_INET) {
57     const in_addr* v4addr = reinterpret_cast<const in_addr*>(&buf[4]);
58     *remote_addr = SocketAddress(IPAddress(*v4addr), port);
59     return kNATEncodedIPv4AddressSize;
60   } else if (family == AF_INET6) {
61     RTC_DCHECK(buf_size >= 20);
62     const in6_addr* v6addr = reinterpret_cast<const in6_addr*>(&buf[4]);
63     *remote_addr = SocketAddress(IPAddress(*v6addr), port);
64     return kNATEncodedIPv6AddressSize;
65   }
66   return 0U;
67 }
68 
69 
70 // NATSocket
71 class NATSocket : public AsyncSocket, public sigslot::has_slots<> {
72  public:
NATSocket(NATInternalSocketFactory * sf,int family,int type)73   explicit NATSocket(NATInternalSocketFactory* sf, int family, int type)
74       : sf_(sf), family_(family), type_(type), connected_(false),
75         socket_(NULL), buf_(NULL), size_(0) {
76   }
77 
~NATSocket()78   ~NATSocket() override {
79     delete socket_;
80     delete[] buf_;
81   }
82 
GetLocalAddress() const83   SocketAddress GetLocalAddress() const override {
84     return (socket_) ? socket_->GetLocalAddress() : SocketAddress();
85   }
86 
GetRemoteAddress() const87   SocketAddress GetRemoteAddress() const override {
88     return remote_addr_;  // will be NIL if not connected
89   }
90 
Bind(const SocketAddress & addr)91   int Bind(const SocketAddress& addr) override {
92     if (socket_) {  // already bound, bubble up error
93       return -1;
94     }
95 
96     int result;
97     socket_ = sf_->CreateInternalSocket(family_, type_, addr, &server_addr_);
98     result = (socket_) ? socket_->Bind(addr) : -1;
99     if (result >= 0) {
100       socket_->SignalConnectEvent.connect(this, &NATSocket::OnConnectEvent);
101       socket_->SignalReadEvent.connect(this, &NATSocket::OnReadEvent);
102       socket_->SignalWriteEvent.connect(this, &NATSocket::OnWriteEvent);
103       socket_->SignalCloseEvent.connect(this, &NATSocket::OnCloseEvent);
104     } else {
105       server_addr_.Clear();
106       delete socket_;
107       socket_ = NULL;
108     }
109 
110     return result;
111   }
112 
Connect(const SocketAddress & addr)113   int Connect(const SocketAddress& addr) override {
114     if (!socket_) {  // socket must be bound, for now
115       return -1;
116     }
117 
118     int result = 0;
119     if (type_ == SOCK_STREAM) {
120       result = socket_->Connect(server_addr_.IsNil() ? addr : server_addr_);
121     } else {
122       connected_ = true;
123     }
124 
125     if (result >= 0) {
126       remote_addr_ = addr;
127     }
128 
129     return result;
130   }
131 
Send(const void * data,size_t size)132   int Send(const void* data, size_t size) override {
133     RTC_DCHECK(connected_);
134     return SendTo(data, size, remote_addr_);
135   }
136 
SendTo(const void * data,size_t size,const SocketAddress & addr)137   int SendTo(const void* data,
138              size_t size,
139              const SocketAddress& addr) override {
140     RTC_DCHECK(!connected_ || addr == remote_addr_);
141     if (server_addr_.IsNil() || type_ == SOCK_STREAM) {
142       return socket_->SendTo(data, size, addr);
143     }
144     // This array will be too large for IPv4 packets, but only by 12 bytes.
145     std::unique_ptr<char[]> buf(new char[size + kNATEncodedIPv6AddressSize]);
146     size_t addrlength = PackAddressForNAT(buf.get(),
147                                           size + kNATEncodedIPv6AddressSize,
148                                           addr);
149     size_t encoded_size = size + addrlength;
150     memcpy(buf.get() + addrlength, data, size);
151     int result = socket_->SendTo(buf.get(), encoded_size, server_addr_);
152     if (result >= 0) {
153       RTC_DCHECK(result == static_cast<int>(encoded_size));
154       result = result - static_cast<int>(addrlength);
155     }
156     return result;
157   }
158 
Recv(void * data,size_t size,int64_t * timestamp)159   int Recv(void* data, size_t size, int64_t* timestamp) override {
160     SocketAddress addr;
161     return RecvFrom(data, size, &addr, timestamp);
162   }
163 
RecvFrom(void * data,size_t size,SocketAddress * out_addr,int64_t * timestamp)164   int RecvFrom(void* data,
165                size_t size,
166                SocketAddress* out_addr,
167                int64_t* timestamp) override {
168     if (server_addr_.IsNil() || type_ == SOCK_STREAM) {
169       return socket_->RecvFrom(data, size, out_addr, timestamp);
170     }
171     // Make sure we have enough room to read the requested amount plus the
172     // largest possible header address.
173     SocketAddress remote_addr;
174     Grow(size + kNATEncodedIPv6AddressSize);
175 
176     // Read the packet from the socket.
177     int result = socket_->RecvFrom(buf_, size_, &remote_addr, timestamp);
178     if (result >= 0) {
179       RTC_DCHECK(remote_addr == server_addr_);
180 
181       // TODO: we need better framing so we know how many bytes we can
182       // return before we need to read the next address. For UDP, this will be
183       // fine as long as the reader always reads everything in the packet.
184       RTC_DCHECK((size_t)result < size_);
185 
186       // Decode the wire packet into the actual results.
187       SocketAddress real_remote_addr;
188       size_t addrlength = UnpackAddressFromNAT(buf_, result, &real_remote_addr);
189       memcpy(data, buf_ + addrlength, result - addrlength);
190 
191       // Make sure this packet should be delivered before returning it.
192       if (!connected_ || (real_remote_addr == remote_addr_)) {
193         if (out_addr)
194           *out_addr = real_remote_addr;
195         result = result - static_cast<int>(addrlength);
196       } else {
197         LOG(LS_ERROR) << "Dropping packet from unknown remote address: "
198                       << real_remote_addr.ToString();
199         result = 0;  // Tell the caller we didn't read anything
200       }
201     }
202 
203     return result;
204   }
205 
Close()206   int Close() override {
207     int result = 0;
208     if (socket_) {
209       result = socket_->Close();
210       if (result >= 0) {
211         connected_ = false;
212         remote_addr_ = SocketAddress();
213         delete socket_;
214         socket_ = NULL;
215       }
216     }
217     return result;
218   }
219 
Listen(int backlog)220   int Listen(int backlog) override { return socket_->Listen(backlog); }
Accept(SocketAddress * paddr)221   AsyncSocket* Accept(SocketAddress* paddr) override {
222     return socket_->Accept(paddr);
223   }
GetError() const224   int GetError() const override { return socket_->GetError(); }
SetError(int error)225   void SetError(int error) override { socket_->SetError(error); }
GetState() const226   ConnState GetState() const override {
227     return connected_ ? CS_CONNECTED : CS_CLOSED;
228   }
EstimateMTU(uint16_t * mtu)229   int EstimateMTU(uint16_t* mtu) override { return socket_->EstimateMTU(mtu); }
GetOption(Option opt,int * value)230   int GetOption(Option opt, int* value) override {
231     return socket_->GetOption(opt, value);
232   }
SetOption(Option opt,int value)233   int SetOption(Option opt, int value) override {
234     return socket_->SetOption(opt, value);
235   }
236 
OnConnectEvent(AsyncSocket * socket)237   void OnConnectEvent(AsyncSocket* socket) {
238     // If we're NATed, we need to send a message with the real addr to use.
239     RTC_DCHECK(socket == socket_);
240     if (server_addr_.IsNil()) {
241       connected_ = true;
242       SignalConnectEvent(this);
243     } else {
244       SendConnectRequest();
245     }
246   }
OnReadEvent(AsyncSocket * socket)247   void OnReadEvent(AsyncSocket* socket) {
248     // If we're NATed, we need to process the connect reply.
249     RTC_DCHECK(socket == socket_);
250     if (type_ == SOCK_STREAM && !server_addr_.IsNil() && !connected_) {
251       HandleConnectReply();
252     } else {
253       SignalReadEvent(this);
254     }
255   }
OnWriteEvent(AsyncSocket * socket)256   void OnWriteEvent(AsyncSocket* socket) {
257     RTC_DCHECK(socket == socket_);
258     SignalWriteEvent(this);
259   }
OnCloseEvent(AsyncSocket * socket,int error)260   void OnCloseEvent(AsyncSocket* socket, int error) {
261     RTC_DCHECK(socket == socket_);
262     SignalCloseEvent(this, error);
263   }
264 
265  private:
266   // Makes sure the buffer is at least the given size.
Grow(size_t new_size)267   void Grow(size_t new_size) {
268     if (size_ < new_size) {
269       delete[] buf_;
270       size_ = new_size;
271       buf_ = new char[size_];
272     }
273   }
274 
275   // Sends the destination address to the server to tell it to connect.
SendConnectRequest()276   void SendConnectRequest() {
277     char buf[kNATEncodedIPv6AddressSize];
278     size_t length = PackAddressForNAT(buf, arraysize(buf), remote_addr_);
279     socket_->Send(buf, length);
280   }
281 
282   // Handles the byte sent back from the server and fires the appropriate event.
HandleConnectReply()283   void HandleConnectReply() {
284     char code;
285     socket_->Recv(&code, sizeof(code), nullptr);
286     if (code == 0) {
287       connected_ = true;
288       SignalConnectEvent(this);
289     } else {
290       Close();
291       SignalCloseEvent(this, code);
292     }
293   }
294 
295   NATInternalSocketFactory* sf_;
296   int family_;
297   int type_;
298   bool connected_;
299   SocketAddress remote_addr_;
300   SocketAddress server_addr_;  // address of the NAT server
301   AsyncSocket* socket_;
302   char* buf_;
303   size_t size_;
304 };
305 
306 // NATSocketFactory
NATSocketFactory(SocketFactory * factory,const SocketAddress & nat_udp_addr,const SocketAddress & nat_tcp_addr)307 NATSocketFactory::NATSocketFactory(SocketFactory* factory,
308                                    const SocketAddress& nat_udp_addr,
309                                    const SocketAddress& nat_tcp_addr)
310     : factory_(factory), nat_udp_addr_(nat_udp_addr),
311       nat_tcp_addr_(nat_tcp_addr) {
312 }
313 
CreateSocket(int type)314 Socket* NATSocketFactory::CreateSocket(int type) {
315   return CreateSocket(AF_INET, type);
316 }
317 
CreateSocket(int family,int type)318 Socket* NATSocketFactory::CreateSocket(int family, int type) {
319   return new NATSocket(this, family, type);
320 }
321 
CreateAsyncSocket(int type)322 AsyncSocket* NATSocketFactory::CreateAsyncSocket(int type) {
323   return CreateAsyncSocket(AF_INET, type);
324 }
325 
CreateAsyncSocket(int family,int type)326 AsyncSocket* NATSocketFactory::CreateAsyncSocket(int family, int type) {
327   return new NATSocket(this, family, type);
328 }
329 
CreateInternalSocket(int family,int type,const SocketAddress & local_addr,SocketAddress * nat_addr)330 AsyncSocket* NATSocketFactory::CreateInternalSocket(int family, int type,
331     const SocketAddress& local_addr, SocketAddress* nat_addr) {
332   if (type == SOCK_STREAM) {
333     *nat_addr = nat_tcp_addr_;
334   } else {
335     *nat_addr = nat_udp_addr_;
336   }
337   return factory_->CreateAsyncSocket(family, type);
338 }
339 
340 // NATSocketServer
NATSocketServer(SocketServer * server)341 NATSocketServer::NATSocketServer(SocketServer* server)
342     : server_(server), msg_queue_(NULL) {
343 }
344 
GetTranslator(const SocketAddress & ext_ip)345 NATSocketServer::Translator* NATSocketServer::GetTranslator(
346     const SocketAddress& ext_ip) {
347   return nats_.Get(ext_ip);
348 }
349 
AddTranslator(const SocketAddress & ext_ip,const SocketAddress & int_ip,NATType type)350 NATSocketServer::Translator* NATSocketServer::AddTranslator(
351     const SocketAddress& ext_ip, const SocketAddress& int_ip, NATType type) {
352   // Fail if a translator already exists with this extternal address.
353   if (nats_.Get(ext_ip))
354     return NULL;
355 
356   return nats_.Add(ext_ip, new Translator(this, type, int_ip, server_, ext_ip));
357 }
358 
RemoveTranslator(const SocketAddress & ext_ip)359 void NATSocketServer::RemoveTranslator(
360     const SocketAddress& ext_ip) {
361   nats_.Remove(ext_ip);
362 }
363 
CreateSocket(int type)364 Socket* NATSocketServer::CreateSocket(int type) {
365   return CreateSocket(AF_INET, type);
366 }
367 
CreateSocket(int family,int type)368 Socket* NATSocketServer::CreateSocket(int family, int type) {
369   return new NATSocket(this, family, type);
370 }
371 
CreateAsyncSocket(int type)372 AsyncSocket* NATSocketServer::CreateAsyncSocket(int type) {
373   return CreateAsyncSocket(AF_INET, type);
374 }
375 
CreateAsyncSocket(int family,int type)376 AsyncSocket* NATSocketServer::CreateAsyncSocket(int family, int type) {
377   return new NATSocket(this, family, type);
378 }
379 
SetMessageQueue(MessageQueue * queue)380 void NATSocketServer::SetMessageQueue(MessageQueue* queue) {
381   msg_queue_ = queue;
382   server_->SetMessageQueue(queue);
383 }
384 
Wait(int cms,bool process_io)385 bool NATSocketServer::Wait(int cms, bool process_io) {
386   return server_->Wait(cms, process_io);
387 }
388 
WakeUp()389 void NATSocketServer::WakeUp() {
390   server_->WakeUp();
391 }
392 
CreateInternalSocket(int family,int type,const SocketAddress & local_addr,SocketAddress * nat_addr)393 AsyncSocket* NATSocketServer::CreateInternalSocket(int family, int type,
394     const SocketAddress& local_addr, SocketAddress* nat_addr) {
395   AsyncSocket* socket = NULL;
396   Translator* nat = nats_.FindClient(local_addr);
397   if (nat) {
398     socket = nat->internal_factory()->CreateAsyncSocket(family, type);
399     *nat_addr = (type == SOCK_STREAM) ?
400         nat->internal_tcp_address() : nat->internal_udp_address();
401   } else {
402     socket = server_->CreateAsyncSocket(family, type);
403   }
404   return socket;
405 }
406 
407 // NATSocketServer::Translator
Translator(NATSocketServer * server,NATType type,const SocketAddress & int_ip,SocketFactory * ext_factory,const SocketAddress & ext_ip)408 NATSocketServer::Translator::Translator(
409     NATSocketServer* server, NATType type, const SocketAddress& int_ip,
410     SocketFactory* ext_factory, const SocketAddress& ext_ip)
411     : server_(server) {
412   // Create a new private network, and a NATServer running on the private
413   // network that bridges to the external network. Also tell the private
414   // network to use the same message queue as us.
415   VirtualSocketServer* internal_server = new VirtualSocketServer(server_);
416   internal_server->SetMessageQueue(server_->queue());
417   internal_factory_.reset(internal_server);
418   nat_server_.reset(new NATServer(type, internal_server, int_ip, int_ip,
419                                   ext_factory, ext_ip));
420 }
421 
422 NATSocketServer::Translator::~Translator() = default;
423 
GetTranslator(const SocketAddress & ext_ip)424 NATSocketServer::Translator* NATSocketServer::Translator::GetTranslator(
425     const SocketAddress& ext_ip) {
426   return nats_.Get(ext_ip);
427 }
428 
AddTranslator(const SocketAddress & ext_ip,const SocketAddress & int_ip,NATType type)429 NATSocketServer::Translator* NATSocketServer::Translator::AddTranslator(
430     const SocketAddress& ext_ip, const SocketAddress& int_ip, NATType type) {
431   // Fail if a translator already exists with this extternal address.
432   if (nats_.Get(ext_ip))
433     return NULL;
434 
435   AddClient(ext_ip);
436   return nats_.Add(ext_ip,
437                    new Translator(server_, type, int_ip, server_, ext_ip));
438 }
RemoveTranslator(const SocketAddress & ext_ip)439 void NATSocketServer::Translator::RemoveTranslator(
440     const SocketAddress& ext_ip) {
441   nats_.Remove(ext_ip);
442   RemoveClient(ext_ip);
443 }
444 
AddClient(const SocketAddress & int_ip)445 bool NATSocketServer::Translator::AddClient(
446     const SocketAddress& int_ip) {
447   // Fail if a client already exists with this internal address.
448   if (clients_.find(int_ip) != clients_.end())
449     return false;
450 
451   clients_.insert(int_ip);
452   return true;
453 }
454 
RemoveClient(const SocketAddress & int_ip)455 void NATSocketServer::Translator::RemoveClient(
456     const SocketAddress& int_ip) {
457   std::set<SocketAddress>::iterator it = clients_.find(int_ip);
458   if (it != clients_.end()) {
459     clients_.erase(it);
460   }
461 }
462 
FindClient(const SocketAddress & int_ip)463 NATSocketServer::Translator* NATSocketServer::Translator::FindClient(
464     const SocketAddress& int_ip) {
465   // See if we have the requested IP, or any of our children do.
466   return (clients_.find(int_ip) != clients_.end()) ?
467       this : nats_.FindClient(int_ip);
468 }
469 
470 // NATSocketServer::TranslatorMap
~TranslatorMap()471 NATSocketServer::TranslatorMap::~TranslatorMap() {
472   for (TranslatorMap::iterator it = begin(); it != end(); ++it) {
473     delete it->second;
474   }
475 }
476 
Get(const SocketAddress & ext_ip)477 NATSocketServer::Translator* NATSocketServer::TranslatorMap::Get(
478     const SocketAddress& ext_ip) {
479   TranslatorMap::iterator it = find(ext_ip);
480   return (it != end()) ? it->second : NULL;
481 }
482 
Add(const SocketAddress & ext_ip,Translator * nat)483 NATSocketServer::Translator* NATSocketServer::TranslatorMap::Add(
484     const SocketAddress& ext_ip, Translator* nat) {
485   (*this)[ext_ip] = nat;
486   return nat;
487 }
488 
Remove(const SocketAddress & ext_ip)489 void NATSocketServer::TranslatorMap::Remove(
490     const SocketAddress& ext_ip) {
491   TranslatorMap::iterator it = find(ext_ip);
492   if (it != end()) {
493     delete it->second;
494     erase(it);
495   }
496 }
497 
FindClient(const SocketAddress & int_ip)498 NATSocketServer::Translator* NATSocketServer::TranslatorMap::FindClient(
499     const SocketAddress& int_ip) {
500   Translator* nat = NULL;
501   for (TranslatorMap::iterator it = begin(); it != end() && !nat; ++it) {
502     nat = it->second->FindClient(int_ip);
503   }
504   return nat;
505 }
506 
507 }  // namespace rtc
508