1 /*
2 * Copyright 2004 The WebRTC Project Authors. All rights reserved.
3 *
4 * Use of this source code is governed by a BSD-style license
5 * that can be found in the LICENSE file in the root of the source
6 * tree. An additional intellectual property rights grant can be found
7 * in the file PATENTS. All contributing project authors may
8 * be found in the AUTHORS file in the root of the source tree.
9 */
10
11 #include "webrtc/base/natsocketfactory.h"
12
13 #include "webrtc/base/arraysize.h"
14 #include "webrtc/base/checks.h"
15 #include "webrtc/base/logging.h"
16 #include "webrtc/base/natserver.h"
17 #include "webrtc/base/virtualsocketserver.h"
18
19 namespace rtc {
20
21 // Packs the given socketaddress into the buffer in buf, in the quasi-STUN
22 // format that the natserver uses.
23 // Returns 0 if an invalid address is passed.
PackAddressForNAT(char * buf,size_t buf_size,const SocketAddress & remote_addr)24 size_t PackAddressForNAT(char* buf, size_t buf_size,
25 const SocketAddress& remote_addr) {
26 const IPAddress& ip = remote_addr.ipaddr();
27 int family = ip.family();
28 buf[0] = 0;
29 buf[1] = family;
30 // Writes the port.
31 *(reinterpret_cast<uint16_t*>(&buf[2])) = HostToNetwork16(remote_addr.port());
32 if (family == AF_INET) {
33 RTC_DCHECK(buf_size >= kNATEncodedIPv4AddressSize);
34 in_addr v4addr = ip.ipv4_address();
35 memcpy(&buf[4], &v4addr, kNATEncodedIPv4AddressSize - 4);
36 return kNATEncodedIPv4AddressSize;
37 } else if (family == AF_INET6) {
38 RTC_DCHECK(buf_size >= kNATEncodedIPv6AddressSize);
39 in6_addr v6addr = ip.ipv6_address();
40 memcpy(&buf[4], &v6addr, kNATEncodedIPv6AddressSize - 4);
41 return kNATEncodedIPv6AddressSize;
42 }
43 return 0U;
44 }
45
46 // Decodes the remote address from a packet that has been encoded with the nat's
47 // quasi-STUN format. Returns the length of the address (i.e., the offset into
48 // data where the original packet starts).
UnpackAddressFromNAT(const char * buf,size_t buf_size,SocketAddress * remote_addr)49 size_t UnpackAddressFromNAT(const char* buf, size_t buf_size,
50 SocketAddress* remote_addr) {
51 RTC_DCHECK(buf_size >= 8);
52 RTC_DCHECK(buf[0] == 0);
53 int family = buf[1];
54 uint16_t port =
55 NetworkToHost16(*(reinterpret_cast<const uint16_t*>(&buf[2])));
56 if (family == AF_INET) {
57 const in_addr* v4addr = reinterpret_cast<const in_addr*>(&buf[4]);
58 *remote_addr = SocketAddress(IPAddress(*v4addr), port);
59 return kNATEncodedIPv4AddressSize;
60 } else if (family == AF_INET6) {
61 RTC_DCHECK(buf_size >= 20);
62 const in6_addr* v6addr = reinterpret_cast<const in6_addr*>(&buf[4]);
63 *remote_addr = SocketAddress(IPAddress(*v6addr), port);
64 return kNATEncodedIPv6AddressSize;
65 }
66 return 0U;
67 }
68
69
70 // NATSocket
71 class NATSocket : public AsyncSocket, public sigslot::has_slots<> {
72 public:
NATSocket(NATInternalSocketFactory * sf,int family,int type)73 explicit NATSocket(NATInternalSocketFactory* sf, int family, int type)
74 : sf_(sf), family_(family), type_(type), connected_(false),
75 socket_(NULL), buf_(NULL), size_(0) {
76 }
77
~NATSocket()78 ~NATSocket() override {
79 delete socket_;
80 delete[] buf_;
81 }
82
GetLocalAddress() const83 SocketAddress GetLocalAddress() const override {
84 return (socket_) ? socket_->GetLocalAddress() : SocketAddress();
85 }
86
GetRemoteAddress() const87 SocketAddress GetRemoteAddress() const override {
88 return remote_addr_; // will be NIL if not connected
89 }
90
Bind(const SocketAddress & addr)91 int Bind(const SocketAddress& addr) override {
92 if (socket_) { // already bound, bubble up error
93 return -1;
94 }
95
96 int result;
97 socket_ = sf_->CreateInternalSocket(family_, type_, addr, &server_addr_);
98 result = (socket_) ? socket_->Bind(addr) : -1;
99 if (result >= 0) {
100 socket_->SignalConnectEvent.connect(this, &NATSocket::OnConnectEvent);
101 socket_->SignalReadEvent.connect(this, &NATSocket::OnReadEvent);
102 socket_->SignalWriteEvent.connect(this, &NATSocket::OnWriteEvent);
103 socket_->SignalCloseEvent.connect(this, &NATSocket::OnCloseEvent);
104 } else {
105 server_addr_.Clear();
106 delete socket_;
107 socket_ = NULL;
108 }
109
110 return result;
111 }
112
Connect(const SocketAddress & addr)113 int Connect(const SocketAddress& addr) override {
114 if (!socket_) { // socket must be bound, for now
115 return -1;
116 }
117
118 int result = 0;
119 if (type_ == SOCK_STREAM) {
120 result = socket_->Connect(server_addr_.IsNil() ? addr : server_addr_);
121 } else {
122 connected_ = true;
123 }
124
125 if (result >= 0) {
126 remote_addr_ = addr;
127 }
128
129 return result;
130 }
131
Send(const void * data,size_t size)132 int Send(const void* data, size_t size) override {
133 RTC_DCHECK(connected_);
134 return SendTo(data, size, remote_addr_);
135 }
136
SendTo(const void * data,size_t size,const SocketAddress & addr)137 int SendTo(const void* data,
138 size_t size,
139 const SocketAddress& addr) override {
140 RTC_DCHECK(!connected_ || addr == remote_addr_);
141 if (server_addr_.IsNil() || type_ == SOCK_STREAM) {
142 return socket_->SendTo(data, size, addr);
143 }
144 // This array will be too large for IPv4 packets, but only by 12 bytes.
145 std::unique_ptr<char[]> buf(new char[size + kNATEncodedIPv6AddressSize]);
146 size_t addrlength = PackAddressForNAT(buf.get(),
147 size + kNATEncodedIPv6AddressSize,
148 addr);
149 size_t encoded_size = size + addrlength;
150 memcpy(buf.get() + addrlength, data, size);
151 int result = socket_->SendTo(buf.get(), encoded_size, server_addr_);
152 if (result >= 0) {
153 RTC_DCHECK(result == static_cast<int>(encoded_size));
154 result = result - static_cast<int>(addrlength);
155 }
156 return result;
157 }
158
Recv(void * data,size_t size,int64_t * timestamp)159 int Recv(void* data, size_t size, int64_t* timestamp) override {
160 SocketAddress addr;
161 return RecvFrom(data, size, &addr, timestamp);
162 }
163
RecvFrom(void * data,size_t size,SocketAddress * out_addr,int64_t * timestamp)164 int RecvFrom(void* data,
165 size_t size,
166 SocketAddress* out_addr,
167 int64_t* timestamp) override {
168 if (server_addr_.IsNil() || type_ == SOCK_STREAM) {
169 return socket_->RecvFrom(data, size, out_addr, timestamp);
170 }
171 // Make sure we have enough room to read the requested amount plus the
172 // largest possible header address.
173 SocketAddress remote_addr;
174 Grow(size + kNATEncodedIPv6AddressSize);
175
176 // Read the packet from the socket.
177 int result = socket_->RecvFrom(buf_, size_, &remote_addr, timestamp);
178 if (result >= 0) {
179 RTC_DCHECK(remote_addr == server_addr_);
180
181 // TODO: we need better framing so we know how many bytes we can
182 // return before we need to read the next address. For UDP, this will be
183 // fine as long as the reader always reads everything in the packet.
184 RTC_DCHECK((size_t)result < size_);
185
186 // Decode the wire packet into the actual results.
187 SocketAddress real_remote_addr;
188 size_t addrlength = UnpackAddressFromNAT(buf_, result, &real_remote_addr);
189 memcpy(data, buf_ + addrlength, result - addrlength);
190
191 // Make sure this packet should be delivered before returning it.
192 if (!connected_ || (real_remote_addr == remote_addr_)) {
193 if (out_addr)
194 *out_addr = real_remote_addr;
195 result = result - static_cast<int>(addrlength);
196 } else {
197 LOG(LS_ERROR) << "Dropping packet from unknown remote address: "
198 << real_remote_addr.ToString();
199 result = 0; // Tell the caller we didn't read anything
200 }
201 }
202
203 return result;
204 }
205
Close()206 int Close() override {
207 int result = 0;
208 if (socket_) {
209 result = socket_->Close();
210 if (result >= 0) {
211 connected_ = false;
212 remote_addr_ = SocketAddress();
213 delete socket_;
214 socket_ = NULL;
215 }
216 }
217 return result;
218 }
219
Listen(int backlog)220 int Listen(int backlog) override { return socket_->Listen(backlog); }
Accept(SocketAddress * paddr)221 AsyncSocket* Accept(SocketAddress* paddr) override {
222 return socket_->Accept(paddr);
223 }
GetError() const224 int GetError() const override { return socket_->GetError(); }
SetError(int error)225 void SetError(int error) override { socket_->SetError(error); }
GetState() const226 ConnState GetState() const override {
227 return connected_ ? CS_CONNECTED : CS_CLOSED;
228 }
EstimateMTU(uint16_t * mtu)229 int EstimateMTU(uint16_t* mtu) override { return socket_->EstimateMTU(mtu); }
GetOption(Option opt,int * value)230 int GetOption(Option opt, int* value) override {
231 return socket_->GetOption(opt, value);
232 }
SetOption(Option opt,int value)233 int SetOption(Option opt, int value) override {
234 return socket_->SetOption(opt, value);
235 }
236
OnConnectEvent(AsyncSocket * socket)237 void OnConnectEvent(AsyncSocket* socket) {
238 // If we're NATed, we need to send a message with the real addr to use.
239 RTC_DCHECK(socket == socket_);
240 if (server_addr_.IsNil()) {
241 connected_ = true;
242 SignalConnectEvent(this);
243 } else {
244 SendConnectRequest();
245 }
246 }
OnReadEvent(AsyncSocket * socket)247 void OnReadEvent(AsyncSocket* socket) {
248 // If we're NATed, we need to process the connect reply.
249 RTC_DCHECK(socket == socket_);
250 if (type_ == SOCK_STREAM && !server_addr_.IsNil() && !connected_) {
251 HandleConnectReply();
252 } else {
253 SignalReadEvent(this);
254 }
255 }
OnWriteEvent(AsyncSocket * socket)256 void OnWriteEvent(AsyncSocket* socket) {
257 RTC_DCHECK(socket == socket_);
258 SignalWriteEvent(this);
259 }
OnCloseEvent(AsyncSocket * socket,int error)260 void OnCloseEvent(AsyncSocket* socket, int error) {
261 RTC_DCHECK(socket == socket_);
262 SignalCloseEvent(this, error);
263 }
264
265 private:
266 // Makes sure the buffer is at least the given size.
Grow(size_t new_size)267 void Grow(size_t new_size) {
268 if (size_ < new_size) {
269 delete[] buf_;
270 size_ = new_size;
271 buf_ = new char[size_];
272 }
273 }
274
275 // Sends the destination address to the server to tell it to connect.
SendConnectRequest()276 void SendConnectRequest() {
277 char buf[kNATEncodedIPv6AddressSize];
278 size_t length = PackAddressForNAT(buf, arraysize(buf), remote_addr_);
279 socket_->Send(buf, length);
280 }
281
282 // Handles the byte sent back from the server and fires the appropriate event.
HandleConnectReply()283 void HandleConnectReply() {
284 char code;
285 socket_->Recv(&code, sizeof(code), nullptr);
286 if (code == 0) {
287 connected_ = true;
288 SignalConnectEvent(this);
289 } else {
290 Close();
291 SignalCloseEvent(this, code);
292 }
293 }
294
295 NATInternalSocketFactory* sf_;
296 int family_;
297 int type_;
298 bool connected_;
299 SocketAddress remote_addr_;
300 SocketAddress server_addr_; // address of the NAT server
301 AsyncSocket* socket_;
302 char* buf_;
303 size_t size_;
304 };
305
306 // NATSocketFactory
NATSocketFactory(SocketFactory * factory,const SocketAddress & nat_udp_addr,const SocketAddress & nat_tcp_addr)307 NATSocketFactory::NATSocketFactory(SocketFactory* factory,
308 const SocketAddress& nat_udp_addr,
309 const SocketAddress& nat_tcp_addr)
310 : factory_(factory), nat_udp_addr_(nat_udp_addr),
311 nat_tcp_addr_(nat_tcp_addr) {
312 }
313
CreateSocket(int type)314 Socket* NATSocketFactory::CreateSocket(int type) {
315 return CreateSocket(AF_INET, type);
316 }
317
CreateSocket(int family,int type)318 Socket* NATSocketFactory::CreateSocket(int family, int type) {
319 return new NATSocket(this, family, type);
320 }
321
CreateAsyncSocket(int type)322 AsyncSocket* NATSocketFactory::CreateAsyncSocket(int type) {
323 return CreateAsyncSocket(AF_INET, type);
324 }
325
CreateAsyncSocket(int family,int type)326 AsyncSocket* NATSocketFactory::CreateAsyncSocket(int family, int type) {
327 return new NATSocket(this, family, type);
328 }
329
CreateInternalSocket(int family,int type,const SocketAddress & local_addr,SocketAddress * nat_addr)330 AsyncSocket* NATSocketFactory::CreateInternalSocket(int family, int type,
331 const SocketAddress& local_addr, SocketAddress* nat_addr) {
332 if (type == SOCK_STREAM) {
333 *nat_addr = nat_tcp_addr_;
334 } else {
335 *nat_addr = nat_udp_addr_;
336 }
337 return factory_->CreateAsyncSocket(family, type);
338 }
339
340 // NATSocketServer
NATSocketServer(SocketServer * server)341 NATSocketServer::NATSocketServer(SocketServer* server)
342 : server_(server), msg_queue_(NULL) {
343 }
344
GetTranslator(const SocketAddress & ext_ip)345 NATSocketServer::Translator* NATSocketServer::GetTranslator(
346 const SocketAddress& ext_ip) {
347 return nats_.Get(ext_ip);
348 }
349
AddTranslator(const SocketAddress & ext_ip,const SocketAddress & int_ip,NATType type)350 NATSocketServer::Translator* NATSocketServer::AddTranslator(
351 const SocketAddress& ext_ip, const SocketAddress& int_ip, NATType type) {
352 // Fail if a translator already exists with this extternal address.
353 if (nats_.Get(ext_ip))
354 return NULL;
355
356 return nats_.Add(ext_ip, new Translator(this, type, int_ip, server_, ext_ip));
357 }
358
RemoveTranslator(const SocketAddress & ext_ip)359 void NATSocketServer::RemoveTranslator(
360 const SocketAddress& ext_ip) {
361 nats_.Remove(ext_ip);
362 }
363
CreateSocket(int type)364 Socket* NATSocketServer::CreateSocket(int type) {
365 return CreateSocket(AF_INET, type);
366 }
367
CreateSocket(int family,int type)368 Socket* NATSocketServer::CreateSocket(int family, int type) {
369 return new NATSocket(this, family, type);
370 }
371
CreateAsyncSocket(int type)372 AsyncSocket* NATSocketServer::CreateAsyncSocket(int type) {
373 return CreateAsyncSocket(AF_INET, type);
374 }
375
CreateAsyncSocket(int family,int type)376 AsyncSocket* NATSocketServer::CreateAsyncSocket(int family, int type) {
377 return new NATSocket(this, family, type);
378 }
379
SetMessageQueue(MessageQueue * queue)380 void NATSocketServer::SetMessageQueue(MessageQueue* queue) {
381 msg_queue_ = queue;
382 server_->SetMessageQueue(queue);
383 }
384
Wait(int cms,bool process_io)385 bool NATSocketServer::Wait(int cms, bool process_io) {
386 return server_->Wait(cms, process_io);
387 }
388
WakeUp()389 void NATSocketServer::WakeUp() {
390 server_->WakeUp();
391 }
392
CreateInternalSocket(int family,int type,const SocketAddress & local_addr,SocketAddress * nat_addr)393 AsyncSocket* NATSocketServer::CreateInternalSocket(int family, int type,
394 const SocketAddress& local_addr, SocketAddress* nat_addr) {
395 AsyncSocket* socket = NULL;
396 Translator* nat = nats_.FindClient(local_addr);
397 if (nat) {
398 socket = nat->internal_factory()->CreateAsyncSocket(family, type);
399 *nat_addr = (type == SOCK_STREAM) ?
400 nat->internal_tcp_address() : nat->internal_udp_address();
401 } else {
402 socket = server_->CreateAsyncSocket(family, type);
403 }
404 return socket;
405 }
406
407 // NATSocketServer::Translator
Translator(NATSocketServer * server,NATType type,const SocketAddress & int_ip,SocketFactory * ext_factory,const SocketAddress & ext_ip)408 NATSocketServer::Translator::Translator(
409 NATSocketServer* server, NATType type, const SocketAddress& int_ip,
410 SocketFactory* ext_factory, const SocketAddress& ext_ip)
411 : server_(server) {
412 // Create a new private network, and a NATServer running on the private
413 // network that bridges to the external network. Also tell the private
414 // network to use the same message queue as us.
415 VirtualSocketServer* internal_server = new VirtualSocketServer(server_);
416 internal_server->SetMessageQueue(server_->queue());
417 internal_factory_.reset(internal_server);
418 nat_server_.reset(new NATServer(type, internal_server, int_ip, int_ip,
419 ext_factory, ext_ip));
420 }
421
422 NATSocketServer::Translator::~Translator() = default;
423
GetTranslator(const SocketAddress & ext_ip)424 NATSocketServer::Translator* NATSocketServer::Translator::GetTranslator(
425 const SocketAddress& ext_ip) {
426 return nats_.Get(ext_ip);
427 }
428
AddTranslator(const SocketAddress & ext_ip,const SocketAddress & int_ip,NATType type)429 NATSocketServer::Translator* NATSocketServer::Translator::AddTranslator(
430 const SocketAddress& ext_ip, const SocketAddress& int_ip, NATType type) {
431 // Fail if a translator already exists with this extternal address.
432 if (nats_.Get(ext_ip))
433 return NULL;
434
435 AddClient(ext_ip);
436 return nats_.Add(ext_ip,
437 new Translator(server_, type, int_ip, server_, ext_ip));
438 }
RemoveTranslator(const SocketAddress & ext_ip)439 void NATSocketServer::Translator::RemoveTranslator(
440 const SocketAddress& ext_ip) {
441 nats_.Remove(ext_ip);
442 RemoveClient(ext_ip);
443 }
444
AddClient(const SocketAddress & int_ip)445 bool NATSocketServer::Translator::AddClient(
446 const SocketAddress& int_ip) {
447 // Fail if a client already exists with this internal address.
448 if (clients_.find(int_ip) != clients_.end())
449 return false;
450
451 clients_.insert(int_ip);
452 return true;
453 }
454
RemoveClient(const SocketAddress & int_ip)455 void NATSocketServer::Translator::RemoveClient(
456 const SocketAddress& int_ip) {
457 std::set<SocketAddress>::iterator it = clients_.find(int_ip);
458 if (it != clients_.end()) {
459 clients_.erase(it);
460 }
461 }
462
FindClient(const SocketAddress & int_ip)463 NATSocketServer::Translator* NATSocketServer::Translator::FindClient(
464 const SocketAddress& int_ip) {
465 // See if we have the requested IP, or any of our children do.
466 return (clients_.find(int_ip) != clients_.end()) ?
467 this : nats_.FindClient(int_ip);
468 }
469
470 // NATSocketServer::TranslatorMap
~TranslatorMap()471 NATSocketServer::TranslatorMap::~TranslatorMap() {
472 for (TranslatorMap::iterator it = begin(); it != end(); ++it) {
473 delete it->second;
474 }
475 }
476
Get(const SocketAddress & ext_ip)477 NATSocketServer::Translator* NATSocketServer::TranslatorMap::Get(
478 const SocketAddress& ext_ip) {
479 TranslatorMap::iterator it = find(ext_ip);
480 return (it != end()) ? it->second : NULL;
481 }
482
Add(const SocketAddress & ext_ip,Translator * nat)483 NATSocketServer::Translator* NATSocketServer::TranslatorMap::Add(
484 const SocketAddress& ext_ip, Translator* nat) {
485 (*this)[ext_ip] = nat;
486 return nat;
487 }
488
Remove(const SocketAddress & ext_ip)489 void NATSocketServer::TranslatorMap::Remove(
490 const SocketAddress& ext_ip) {
491 TranslatorMap::iterator it = find(ext_ip);
492 if (it != end()) {
493 delete it->second;
494 erase(it);
495 }
496 }
497
FindClient(const SocketAddress & int_ip)498 NATSocketServer::Translator* NATSocketServer::TranslatorMap::FindClient(
499 const SocketAddress& int_ip) {
500 Translator* nat = NULL;
501 for (TranslatorMap::iterator it = begin(); it != end() && !nat; ++it) {
502 nat = it->second->FindClient(int_ip);
503 }
504 return nat;
505 }
506
507 } // namespace rtc
508