1 /*
2 * libjingle
3 * Copyright 2004--2005, Google Inc.
4 *
5 * Redistribution and use in source and binary forms, with or without
6 * modification, are permitted provided that the following conditions are met:
7 *
8 * 1. Redistributions of source code must retain the above copyright notice,
9 * this list of conditions and the following disclaimer.
10 * 2. Redistributions in binary form must reproduce the above copyright notice,
11 * this list of conditions and the following disclaimer in the documentation
12 * and/or other materials provided with the distribution.
13 * 3. The name of the author may not be used to endorse or promote products
14 * derived from this software without specific prior written permission.
15 *
16 * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR IMPLIED
17 * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
18 * MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO
19 * EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
20 * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
21 * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
22 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
23 * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
24 * OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
25 * ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26 */
27
28 #include "talk/base/natsocketfactory.h"
29
30 #include "talk/base/logging.h"
31 #include "talk/base/natserver.h"
32 #include "talk/base/virtualsocketserver.h"
33
34 namespace talk_base {
35
36 // Packs the given socketaddress into the buffer in buf, in the quasi-STUN
37 // format that the natserver uses.
38 // Returns 0 if an invalid address is passed.
PackAddressForNAT(char * buf,size_t buf_size,const SocketAddress & remote_addr)39 size_t PackAddressForNAT(char* buf, size_t buf_size,
40 const SocketAddress& remote_addr) {
41 const IPAddress& ip = remote_addr.ipaddr();
42 int family = ip.family();
43 buf[0] = 0;
44 buf[1] = family;
45 // Writes the port.
46 *(reinterpret_cast<uint16*>(&buf[2])) = HostToNetwork16(remote_addr.port());
47 if (family == AF_INET) {
48 ASSERT(buf_size >= kNATEncodedIPv4AddressSize);
49 in_addr v4addr = ip.ipv4_address();
50 std::memcpy(&buf[4], &v4addr, kNATEncodedIPv4AddressSize - 4);
51 return kNATEncodedIPv4AddressSize;
52 } else if (family == AF_INET6) {
53 ASSERT(buf_size >= kNATEncodedIPv6AddressSize);
54 in6_addr v6addr = ip.ipv6_address();
55 std::memcpy(&buf[4], &v6addr, kNATEncodedIPv6AddressSize - 4);
56 return kNATEncodedIPv6AddressSize;
57 }
58 return 0U;
59 }
60
61 // Decodes the remote address from a packet that has been encoded with the nat's
62 // quasi-STUN format. Returns the length of the address (i.e., the offset into
63 // data where the original packet starts).
UnpackAddressFromNAT(const char * buf,size_t buf_size,SocketAddress * remote_addr)64 size_t UnpackAddressFromNAT(const char* buf, size_t buf_size,
65 SocketAddress* remote_addr) {
66 ASSERT(buf_size >= 8);
67 ASSERT(buf[0] == 0);
68 int family = buf[1];
69 uint16 port = NetworkToHost16(*(reinterpret_cast<const uint16*>(&buf[2])));
70 if (family == AF_INET) {
71 const in_addr* v4addr = reinterpret_cast<const in_addr*>(&buf[4]);
72 *remote_addr = SocketAddress(IPAddress(*v4addr), port);
73 return kNATEncodedIPv4AddressSize;
74 } else if (family == AF_INET6) {
75 ASSERT(buf_size >= 20);
76 const in6_addr* v6addr = reinterpret_cast<const in6_addr*>(&buf[4]);
77 *remote_addr = SocketAddress(IPAddress(*v6addr), port);
78 return kNATEncodedIPv6AddressSize;
79 }
80 return 0U;
81 }
82
83 // NATSocket
84 class NATSocket : public AsyncSocket, public sigslot::has_slots<> {
85 public:
NATSocket(NATInternalSocketFactory * sf,int family,int type)86 explicit NATSocket(NATInternalSocketFactory* sf, int family, int type)
87 : sf_(sf), family_(family), type_(type), async_(true), connected_(false),
88 socket_(NULL), buf_(NULL), size_(0) {
89 }
90
~NATSocket()91 virtual ~NATSocket() {
92 delete socket_;
93 delete[] buf_;
94 }
95
GetLocalAddress() const96 virtual SocketAddress GetLocalAddress() const {
97 return (socket_) ? socket_->GetLocalAddress() : SocketAddress();
98 }
99
GetRemoteAddress() const100 virtual SocketAddress GetRemoteAddress() const {
101 return remote_addr_; // will be ANY if not connected
102 }
103
Bind(const SocketAddress & addr)104 virtual int Bind(const SocketAddress& addr) {
105 if (socket_) { // already bound, bubble up error
106 return -1;
107 }
108
109 int result;
110 socket_ = sf_->CreateInternalSocket(type_, family_, addr, &server_addr_);
111 result = (socket_) ? socket_->Bind(addr) : -1;
112 if (result >= 0) {
113 socket_->SignalConnectEvent.connect(this, &NATSocket::OnConnectEvent);
114 socket_->SignalReadEvent.connect(this, &NATSocket::OnReadEvent);
115 socket_->SignalWriteEvent.connect(this, &NATSocket::OnWriteEvent);
116 socket_->SignalCloseEvent.connect(this, &NATSocket::OnCloseEvent);
117 } else {
118 server_addr_.Clear();
119 delete socket_;
120 socket_ = NULL;
121 }
122
123 return result;
124 }
125
Connect(const SocketAddress & addr)126 virtual int Connect(const SocketAddress& addr) {
127 if (!socket_) { // socket must be bound, for now
128 return -1;
129 }
130
131 int result = 0;
132 if (type_ == SOCK_STREAM) {
133 result = socket_->Connect(server_addr_.IsAny() ? addr : server_addr_);
134 } else {
135 connected_ = true;
136 }
137
138 if (result >= 0) {
139 remote_addr_ = addr;
140 }
141
142 return result;
143 }
144
Send(const void * data,size_t size)145 virtual int Send(const void* data, size_t size) {
146 ASSERT(connected_);
147 return SendTo(data, size, remote_addr_);
148 }
149
SendTo(const void * data,size_t size,const SocketAddress & addr)150 virtual int SendTo(const void* data, size_t size, const SocketAddress& addr) {
151 ASSERT(!connected_ || addr == remote_addr_);
152 if (server_addr_.IsAny() || type_ == SOCK_STREAM) {
153 return socket_->SendTo(data, size, addr);
154 }
155 // This array will be too large for IPv4 packets, but only by 12 bytes.
156 scoped_array<char> buf(new char[size + kNATEncodedIPv6AddressSize]);
157 size_t addrlength = PackAddressForNAT(buf.get(),
158 size + kNATEncodedIPv6AddressSize,
159 addr);
160 size_t encoded_size = size + addrlength;
161 std::memcpy(buf.get() + addrlength, data, size);
162 int result = socket_->SendTo(buf.get(), encoded_size, server_addr_);
163 if (result >= 0) {
164 ASSERT(result == static_cast<int>(encoded_size));
165 result = result - static_cast<int>(addrlength);
166 }
167 return result;
168 }
169
Recv(void * data,size_t size)170 virtual int Recv(void* data, size_t size) {
171 SocketAddress addr;
172 return RecvFrom(data, size, &addr);
173 }
174
RecvFrom(void * data,size_t size,SocketAddress * out_addr)175 virtual int RecvFrom(void* data, size_t size, SocketAddress *out_addr) {
176 if (server_addr_.IsAny() || type_ == SOCK_STREAM) {
177 return socket_->RecvFrom(data, size, out_addr);
178 }
179 // Make sure we have enough room to read the requested amount plus the
180 // largest possible header address.
181 SocketAddress remote_addr;
182 Grow(size + kNATEncodedIPv6AddressSize);
183
184 // Read the packet from the socket.
185 int result = socket_->RecvFrom(buf_, size_, &remote_addr);
186 if (result >= 0) {
187 ASSERT(remote_addr == server_addr_);
188
189 // TODO: we need better framing so we know how many bytes we can
190 // return before we need to read the next address. For UDP, this will be
191 // fine as long as the reader always reads everything in the packet.
192 ASSERT((size_t)result < size_);
193
194 // Decode the wire packet into the actual results.
195 SocketAddress real_remote_addr;
196 size_t addrlength =
197 UnpackAddressFromNAT(buf_, result, &real_remote_addr);
198 std::memcpy(data, buf_ + addrlength, result - addrlength);
199
200 // Make sure this packet should be delivered before returning it.
201 if (!connected_ || (real_remote_addr == remote_addr_)) {
202 if (out_addr)
203 *out_addr = real_remote_addr;
204 result = result - static_cast<int>(addrlength);
205 } else {
206 LOG(LS_ERROR) << "Dropping packet from unknown remote address: "
207 << real_remote_addr.ToString();
208 result = 0; // Tell the caller we didn't read anything
209 }
210 }
211
212 return result;
213 }
214
Close()215 virtual int Close() {
216 int result = 0;
217 if (socket_) {
218 result = socket_->Close();
219 if (result >= 0) {
220 connected_ = false;
221 remote_addr_ = SocketAddress();
222 delete socket_;
223 socket_ = NULL;
224 }
225 }
226 return result;
227 }
228
Listen(int backlog)229 virtual int Listen(int backlog) {
230 return socket_->Listen(backlog);
231 }
Accept(SocketAddress * paddr)232 virtual AsyncSocket* Accept(SocketAddress *paddr) {
233 return socket_->Accept(paddr);
234 }
GetError() const235 virtual int GetError() const {
236 return socket_->GetError();
237 }
SetError(int error)238 virtual void SetError(int error) {
239 socket_->SetError(error);
240 }
GetState() const241 virtual ConnState GetState() const {
242 return connected_ ? CS_CONNECTED : CS_CLOSED;
243 }
EstimateMTU(uint16 * mtu)244 virtual int EstimateMTU(uint16* mtu) {
245 return socket_->EstimateMTU(mtu);
246 }
GetOption(Option opt,int * value)247 virtual int GetOption(Option opt, int* value) {
248 return socket_->GetOption(opt, value);
249 }
SetOption(Option opt,int value)250 virtual int SetOption(Option opt, int value) {
251 return socket_->SetOption(opt, value);
252 }
253
OnConnectEvent(AsyncSocket * socket)254 void OnConnectEvent(AsyncSocket* socket) {
255 // If we're NATed, we need to send a request with the real addr to use.
256 ASSERT(socket == socket_);
257 if (server_addr_.IsAny()) {
258 connected_ = true;
259 SignalConnectEvent(this);
260 } else {
261 SendConnectRequest();
262 }
263 }
OnReadEvent(AsyncSocket * socket)264 void OnReadEvent(AsyncSocket* socket) {
265 // If we're NATed, we need to process the connect reply.
266 ASSERT(socket == socket_);
267 if (type_ == SOCK_STREAM && !server_addr_.IsAny() && !connected_) {
268 HandleConnectReply();
269 } else {
270 SignalReadEvent(this);
271 }
272 }
OnWriteEvent(AsyncSocket * socket)273 void OnWriteEvent(AsyncSocket* socket) {
274 ASSERT(socket == socket_);
275 SignalWriteEvent(this);
276 }
OnCloseEvent(AsyncSocket * socket,int error)277 void OnCloseEvent(AsyncSocket* socket, int error) {
278 ASSERT(socket == socket_);
279 SignalCloseEvent(this, error);
280 }
281
282 private:
283 // Makes sure the buffer is at least the given size.
Grow(size_t new_size)284 void Grow(size_t new_size) {
285 if (size_ < new_size) {
286 delete[] buf_;
287 size_ = new_size;
288 buf_ = new char[size_];
289 }
290 }
291
292 // Sends the destination address to the server to tell it to connect.
SendConnectRequest()293 void SendConnectRequest() {
294 char buf[256];
295 size_t length = PackAddressForNAT(buf, ARRAY_SIZE(buf), remote_addr_);
296 socket_->Send(buf, length);
297 }
298
299 // Handles the byte sent back from the server and fires the appropriate event.
HandleConnectReply()300 void HandleConnectReply() {
301 char code;
302 socket_->Recv(&code, sizeof(code));
303 if (code == 0) {
304 SignalConnectEvent(this);
305 } else {
306 Close();
307 SignalCloseEvent(this, code);
308 }
309 }
310
311 NATInternalSocketFactory* sf_;
312 int family_;
313 int type_;
314 bool async_;
315 bool connected_;
316 SocketAddress remote_addr_;
317 SocketAddress server_addr_; // address of the NAT server
318 AsyncSocket* socket_;
319 char* buf_;
320 size_t size_;
321 };
322
323 // NATSocketFactory
NATSocketFactory(SocketFactory * factory,const SocketAddress & nat_addr)324 NATSocketFactory::NATSocketFactory(SocketFactory* factory,
325 const SocketAddress& nat_addr)
326 : factory_(factory), nat_addr_(nat_addr) {
327 }
328
CreateSocket(int type)329 Socket* NATSocketFactory::CreateSocket(int type) {
330 return CreateSocket(AF_INET, type);
331 }
332
CreateSocket(int family,int type)333 Socket* NATSocketFactory::CreateSocket(int family, int type) {
334 return new NATSocket(this, family, type);
335 }
336
CreateAsyncSocket(int type)337 AsyncSocket* NATSocketFactory::CreateAsyncSocket(int type) {
338 return CreateAsyncSocket(AF_INET, type);
339 }
340
CreateAsyncSocket(int family,int type)341 AsyncSocket* NATSocketFactory::CreateAsyncSocket(int family, int type) {
342 return new NATSocket(this, family, type);
343 }
344
CreateInternalSocket(int family,int type,const SocketAddress & local_addr,SocketAddress * nat_addr)345 AsyncSocket* NATSocketFactory::CreateInternalSocket(int family, int type,
346 const SocketAddress& local_addr, SocketAddress* nat_addr) {
347 *nat_addr = nat_addr_;
348 return factory_->CreateAsyncSocket(family, type);
349 }
350
351 // NATSocketServer
NATSocketServer(SocketServer * server)352 NATSocketServer::NATSocketServer(SocketServer* server)
353 : server_(server), msg_queue_(NULL) {
354 }
355
GetTranslator(const SocketAddress & ext_ip)356 NATSocketServer::Translator* NATSocketServer::GetTranslator(
357 const SocketAddress& ext_ip) {
358 return nats_.Get(ext_ip);
359 }
360
AddTranslator(const SocketAddress & ext_ip,const SocketAddress & int_ip,NATType type)361 NATSocketServer::Translator* NATSocketServer::AddTranslator(
362 const SocketAddress& ext_ip, const SocketAddress& int_ip, NATType type) {
363 // Fail if a translator already exists with this extternal address.
364 if (nats_.Get(ext_ip))
365 return NULL;
366
367 return nats_.Add(ext_ip, new Translator(this, type, int_ip, server_, ext_ip));
368 }
369
RemoveTranslator(const SocketAddress & ext_ip)370 void NATSocketServer::RemoveTranslator(
371 const SocketAddress& ext_ip) {
372 nats_.Remove(ext_ip);
373 }
374
CreateSocket(int type)375 Socket* NATSocketServer::CreateSocket(int type) {
376 return CreateSocket(AF_INET, type);
377 }
378
CreateSocket(int family,int type)379 Socket* NATSocketServer::CreateSocket(int family, int type) {
380 return new NATSocket(this, family, type);
381 }
382
CreateAsyncSocket(int type)383 AsyncSocket* NATSocketServer::CreateAsyncSocket(int type) {
384 return CreateAsyncSocket(AF_INET, type);
385 }
386
CreateAsyncSocket(int family,int type)387 AsyncSocket* NATSocketServer::CreateAsyncSocket(int family, int type) {
388 return new NATSocket(this, family, type);
389 }
390
CreateInternalSocket(int type,int family,const SocketAddress & local_addr,SocketAddress * nat_addr)391 AsyncSocket* NATSocketServer::CreateInternalSocket(int type, int family,
392 const SocketAddress& local_addr, SocketAddress* nat_addr) {
393 AsyncSocket* socket = NULL;
394 Translator* nat = nats_.FindClient(local_addr);
395 if (nat) {
396 socket = nat->internal_factory()->CreateAsyncSocket(family, type);
397 *nat_addr = (type == SOCK_STREAM) ?
398 nat->internal_tcp_address() : nat->internal_address();
399 } else {
400 socket = server_->CreateAsyncSocket(family, type);
401 }
402 return socket;
403 }
404
405 // NATSocketServer::Translator
Translator(NATSocketServer * server,NATType type,const SocketAddress & int_ip,SocketFactory * ext_factory,const SocketAddress & ext_ip)406 NATSocketServer::Translator::Translator(
407 NATSocketServer* server, NATType type, const SocketAddress& int_ip,
408 SocketFactory* ext_factory, const SocketAddress& ext_ip)
409 : server_(server) {
410 // Create a new private network, and a NATServer running on the private
411 // network that bridges to the external network. Also tell the private
412 // network to use the same message queue as us.
413 VirtualSocketServer* internal_server = new VirtualSocketServer(server_);
414 internal_server->SetMessageQueue(server_->queue());
415 internal_factory_.reset(internal_server);
416 nat_server_.reset(new NATServer(type, internal_server, int_ip,
417 ext_factory, ext_ip));
418 }
419
GetTranslator(const SocketAddress & ext_ip)420 NATSocketServer::Translator* NATSocketServer::Translator::GetTranslator(
421 const SocketAddress& ext_ip) {
422 return nats_.Get(ext_ip);
423 }
424
AddTranslator(const SocketAddress & ext_ip,const SocketAddress & int_ip,NATType type)425 NATSocketServer::Translator* NATSocketServer::Translator::AddTranslator(
426 const SocketAddress& ext_ip, const SocketAddress& int_ip, NATType type) {
427 // Fail if a translator already exists with this extternal address.
428 if (nats_.Get(ext_ip))
429 return NULL;
430
431 AddClient(ext_ip);
432 return nats_.Add(ext_ip,
433 new Translator(server_, type, int_ip, server_, ext_ip));
434 }
RemoveTranslator(const SocketAddress & ext_ip)435 void NATSocketServer::Translator::RemoveTranslator(
436 const SocketAddress& ext_ip) {
437 nats_.Remove(ext_ip);
438 RemoveClient(ext_ip);
439 }
440
AddClient(const SocketAddress & int_ip)441 bool NATSocketServer::Translator::AddClient(
442 const SocketAddress& int_ip) {
443 // Fail if a client already exists with this internal address.
444 if (clients_.find(int_ip) != clients_.end())
445 return false;
446
447 clients_.insert(int_ip);
448 return true;
449 }
450
RemoveClient(const SocketAddress & int_ip)451 void NATSocketServer::Translator::RemoveClient(
452 const SocketAddress& int_ip) {
453 std::set<SocketAddress>::iterator it = clients_.find(int_ip);
454 if (it != clients_.end()) {
455 clients_.erase(it);
456 }
457 }
458
FindClient(const SocketAddress & int_ip)459 NATSocketServer::Translator* NATSocketServer::Translator::FindClient(
460 const SocketAddress& int_ip) {
461 // See if we have the requested IP, or any of our children do.
462 return (clients_.find(int_ip) != clients_.end()) ?
463 this : nats_.FindClient(int_ip);
464 }
465
466 // NATSocketServer::TranslatorMap
~TranslatorMap()467 NATSocketServer::TranslatorMap::~TranslatorMap() {
468 for (TranslatorMap::iterator it = begin(); it != end(); ++it) {
469 delete it->second;
470 }
471 }
472
Get(const SocketAddress & ext_ip)473 NATSocketServer::Translator* NATSocketServer::TranslatorMap::Get(
474 const SocketAddress& ext_ip) {
475 TranslatorMap::iterator it = find(ext_ip);
476 return (it != end()) ? it->second : NULL;
477 }
478
Add(const SocketAddress & ext_ip,Translator * nat)479 NATSocketServer::Translator* NATSocketServer::TranslatorMap::Add(
480 const SocketAddress& ext_ip, Translator* nat) {
481 (*this)[ext_ip] = nat;
482 return nat;
483 }
484
Remove(const SocketAddress & ext_ip)485 void NATSocketServer::TranslatorMap::Remove(
486 const SocketAddress& ext_ip) {
487 TranslatorMap::iterator it = find(ext_ip);
488 if (it != end()) {
489 delete it->second;
490 erase(it);
491 }
492 }
493
FindClient(const SocketAddress & int_ip)494 NATSocketServer::Translator* NATSocketServer::TranslatorMap::FindClient(
495 const SocketAddress& int_ip) {
496 Translator* nat = NULL;
497 for (TranslatorMap::iterator it = begin(); it != end() && !nat; ++it) {
498 nat = it->second->FindClient(int_ip);
499 }
500 return nat;
501 }
502
503 } // namespace talk_base
504