1 // This file is public domain, in case it's useful to anyone. -comex
2 
3 // The central server implementation.
4 #include <arpa/inet.h>
5 #include <cerrno>
6 #include <chrono>
7 #include <cstdio>
8 #include <cstdlib>
9 #include <cstring>
10 #include <fcntl.h>
11 #include <netinet/in.h>
12 #include <sys/socket.h>
13 #include <sys/types.h>
14 #include <unistd.h>
15 #include <unordered_map>
16 #include <utility>
17 #include <vector>
18 
19 #ifdef HAVE_LIBSYSTEMD
20 #include <systemd/sd-daemon.h>
21 #endif
22 
23 #include "Common/Random.h"
24 #include "Common/TraversalProto.h"
25 
// Set to 1 to enable the packet / lookup tracing guarded by `#if DEBUG` in this file.
#define DEBUG 0
// Number of transmissions of an outgoing packet before ResendPackets() drops it.
#define NUMBER_OF_TRIES 5
// UDP port the server binds to in main().
#define PORT 6262

// Current time in microseconds, refreshed by main() once per loop iteration.
// It is only ever compared against other currentTime values (send times, client
// refresh times), never interpreted as an absolute time.
static u64 currentTime;
31 
// A server-originated packet kept in outgoingPackets and retransmitted by
// ResendPackets() until the destination acks it or it runs out of tries.
struct OutgoingPacketInfo
{
  TraversalPacket packet;
  // Extra id stored with the packet (0 unless the caller of AllocPacket passes one).
  // For PleaseSendPacket it is the requestId of the ConnectPlease that triggered it,
  // later echoed back to the requester in ConnectReady / ConnectFailed.
  TraversalRequestId misc;
  sockaddr_in6 dest;
  // How many times SendPacket() has transmitted this packet so far.
  int tries;
  // currentTime of the last transmission (or of allocation, before the first send).
  u64 sendTime;
};
40 
// Map value wrapper used by EvictFind/EvictSet: `updateTime` is the currentTime at
// which the entry was last set or refreshed; entries idle for too long are evicted.
template <typename T>
struct EvictEntry
{
  u64 updateTime;
  T value;
};
47 
// Result of EvictFind(): `found` says whether the key was present and unexpired;
// `value` points at the value stored in the map and is only meaningful when found.
template <typename V>
struct EvictFindResult
{
  bool found;
  V* value;
};
54 
55 template <typename K, typename V>
EvictFind(std::unordered_map<K,EvictEntry<V>> & map,const K & key,bool refresh=false)56 EvictFindResult<V> EvictFind(std::unordered_map<K, EvictEntry<V>>& map, const K& key,
57                              bool refresh = false)
58 {
59 retry:
60   const u64 expiryTime = 30 * 1000000;  // 30s
61   EvictFindResult<V> result;
62   if (map.bucket_count())
63   {
64     auto bucket = map.bucket(key);
65     auto it = map.begin(bucket);
66     for (; it != map.end(bucket); ++it)
67     {
68       if (currentTime - it->second.updateTime > expiryTime)
69       {
70         map.erase(it->first);
71         goto retry;
72       }
73       if (it->first == key)
74       {
75         if (refresh)
76           it->second.updateTime = currentTime;
77         result.found = true;
78         result.value = &it->second.value;
79         return result;
80       }
81     }
82   }
83 #if DEBUG
84   printf("failed to find key '");
85   for (size_t i = 0; i < sizeof(key); i++)
86   {
87     printf("%02x", ((u8*)&key)[i]);
88   }
89   printf("'\n");
90 #endif
91   result.found = false;
92   return result;
93 }
94 
95 template <typename K, typename V>
EvictSet(std::unordered_map<K,EvictEntry<V>> & map,const K & key)96 V* EvictSet(std::unordered_map<K, EvictEntry<V>>& map, const K& key)
97 {
98   // can't use a local_iterator to emplace...
99   auto& result = map[key];
100   result.updateTime = currentTime;
101   return &result.value;
102 }
103 
104 namespace std
105 {
106 template <>
107 struct hash<TraversalHostId>
108 {
operator ()std::hash109   size_t operator()(const TraversalHostId& id) const
110   {
111     auto p = (u32*)id.data();
112     return p[0] ^ ((p[1] << 13) | (p[1] >> 19));
113   }
114 };
115 }  // namespace std
116 
// The server's single UDP socket, created and bound in main().
static int sock;
// Packets awaiting an ack, keyed by the requestId AllocPacket() assigned to them.
static std::unordered_map<TraversalRequestId, OutgoingPacketInfo> outgoingPackets;
// Registered clients keyed by host id, holding the address their HelloFromClient came
// from. Entries are evicted by EvictFind() once stale; a Ping refreshes them.
static std::unordered_map<TraversalHostId, EvictEntry<TraversalInetAddress>> connectedClients;
120 
MakeInetAddress(const sockaddr_in6 & addr)121 static TraversalInetAddress MakeInetAddress(const sockaddr_in6& addr)
122 {
123   if (addr.sin6_family != AF_INET6)
124   {
125     fprintf(stderr, "bad sockaddr_in6\n");
126     exit(1);
127   }
128   u32* words = (u32*)addr.sin6_addr.s6_addr;
129   TraversalInetAddress result = {0};
130   if (words[0] == 0 && words[1] == 0 && words[2] == 0xffff0000)
131   {
132     result.isIPV6 = false;
133     result.address[0] = words[3];
134   }
135   else
136   {
137     result.isIPV6 = true;
138     memcpy(result.address, words, sizeof(result.address));
139   }
140   result.port = addr.sin6_port;
141   return result;
142 }
143 
MakeSinAddr(const TraversalInetAddress & addr)144 static sockaddr_in6 MakeSinAddr(const TraversalInetAddress& addr)
145 {
146   sockaddr_in6 result;
147 #ifdef SIN6_LEN
148   result.sin6_len = sizeof(result);
149 #endif
150   result.sin6_family = AF_INET6;
151   result.sin6_port = addr.port;
152   result.sin6_flowinfo = 0;
153   if (addr.isIPV6)
154   {
155     memcpy(&result.sin6_addr, addr.address, 16);
156   }
157   else
158   {
159     u32* words = (u32*)result.sin6_addr.s6_addr;
160     words[0] = 0;
161     words[1] = 0;
162     words[2] = 0xffff0000;
163     words[3] = addr.address[0];
164   }
165   result.sin6_scope_id = 0;
166   return result;
167 }
168 
GetRandomHostId(TraversalHostId * hostId)169 static void GetRandomHostId(TraversalHostId* hostId)
170 {
171   char buf[9];
172   const u32 num = Common::Random::GenerateValue<u32>();
173   sprintf(buf, "%08x", num);
174   memcpy(hostId->data(), buf, 8);
175 }
176 
// Formats `addr` as "<ipv6 text>:<port>" for log messages.
// Not reentrant: each call overwrites and returns the same static buffer.
static const char* SenderName(sockaddr_in6* addr)
{
  static char text[INET6_ADDRSTRLEN + 10];
  inet_ntop(PF_INET6, &addr->sin6_addr, text, sizeof(text));
  const size_t used = strlen(text);
  snprintf(text + used, sizeof(text) - used, ":%d", ntohs(addr->sin6_port));
  return text;
}
184 
TrySend(const void * buffer,size_t size,sockaddr_in6 * addr)185 static void TrySend(const void* buffer, size_t size, sockaddr_in6* addr)
186 {
187 #if DEBUG
188   printf("-> %d %llu %s\n", ((TraversalPacket*)buffer)->type,
189          (long long)((TraversalPacket*)buffer)->requestId, SenderName(addr));
190 #endif
191   if ((size_t)sendto(sock, buffer, size, 0, (sockaddr*)addr, sizeof(*addr)) != size)
192   {
193     perror("sendto");
194   }
195 }
196 
AllocPacket(const sockaddr_in6 & dest,TraversalRequestId misc=0)197 static TraversalPacket* AllocPacket(const sockaddr_in6& dest, TraversalRequestId misc = 0)
198 {
199   TraversalRequestId requestId;
200   Common::Random::Generate(&requestId, sizeof(requestId));
201   OutgoingPacketInfo* info = &outgoingPackets[requestId];
202   info->dest = dest;
203   info->misc = misc;
204   info->tries = 0;
205   info->sendTime = currentTime;
206   TraversalPacket* result = &info->packet;
207   memset(result, 0, sizeof(*result));
208   result->requestId = requestId;
209   return result;
210 }
211 
SendPacket(OutgoingPacketInfo * info)212 static void SendPacket(OutgoingPacketInfo* info)
213 {
214   info->tries++;
215   info->sendTime = currentTime;
216   TrySend(&info->packet, sizeof(info->packet), &info->dest);
217 }
218 
ResendPackets()219 static void ResendPackets()
220 {
221   std::vector<std::pair<TraversalInetAddress, TraversalRequestId>> todoFailures;
222   todoFailures.clear();
223   for (auto it = outgoingPackets.begin(); it != outgoingPackets.end();)
224   {
225     OutgoingPacketInfo* info = &it->second;
226     if (currentTime - info->sendTime >= (u64)(300000 * info->tries))
227     {
228       if (info->tries >= NUMBER_OF_TRIES)
229       {
230         if (info->packet.type == TraversalPacketPleaseSendPacket)
231         {
232           todoFailures.push_back(std::make_pair(info->packet.pleaseSendPacket.address, info->misc));
233         }
234         it = outgoingPackets.erase(it);
235         continue;
236       }
237       else
238       {
239         SendPacket(info);
240       }
241     }
242     ++it;
243   }
244 
245   for (const auto& p : todoFailures)
246   {
247     TraversalPacket* fail = AllocPacket(MakeSinAddr(p.first));
248     fail->type = TraversalPacketConnectFailed;
249     fail->connectFailed.requestId = p.second;
250     fail->connectFailed.reason = TraversalConnectFailedClientDidntRespond;
251   }
252 }
253 
HandlePacket(TraversalPacket * packet,sockaddr_in6 * addr)254 static void HandlePacket(TraversalPacket* packet, sockaddr_in6* addr)
255 {
256 #if DEBUG
257   printf("<- %d %llu %s\n", packet->type, (long long)packet->requestId, SenderName(addr));
258 #endif
259   bool packetOk = true;
260   switch (packet->type)
261   {
262   case TraversalPacketAck:
263   {
264     auto it = outgoingPackets.find(packet->requestId);
265     if (it == outgoingPackets.end())
266       break;
267 
268     OutgoingPacketInfo* info = &it->second;
269 
270     if (info->packet.type == TraversalPacketPleaseSendPacket)
271     {
272       TraversalPacket* ready = AllocPacket(MakeSinAddr(info->packet.pleaseSendPacket.address));
273       if (packet->ack.ok)
274       {
275         ready->type = TraversalPacketConnectReady;
276         ready->connectReady.requestId = info->misc;
277         ready->connectReady.address = MakeInetAddress(info->dest);
278       }
279       else
280       {
281         ready->type = TraversalPacketConnectFailed;
282         ready->connectFailed.requestId = info->misc;
283         ready->connectFailed.reason = TraversalConnectFailedClientFailure;
284       }
285     }
286 
287     outgoingPackets.erase(it);
288     break;
289   }
290   case TraversalPacketPing:
291   {
292     auto r = EvictFind(connectedClients, packet->ping.hostId, true);
293     packetOk = r.found;
294     break;
295   }
296   case TraversalPacketHelloFromClient:
297   {
298     u8 ok = packet->helloFromClient.protoVersion <= TraversalProtoVersion;
299     TraversalPacket* reply = AllocPacket(*addr);
300     reply->type = TraversalPacketHelloFromServer;
301     reply->helloFromServer.ok = ok;
302     if (ok)
303     {
304       TraversalHostId hostId;
305       TraversalInetAddress* iaddr;
306       // not that there is any significant change of
307       // duplication, but...
308       GetRandomHostId(&hostId);
309       while (true)
310       {
311         auto r = EvictFind(connectedClients, hostId);
312         if (!r.found)
313         {
314           iaddr = EvictSet(connectedClients, hostId);
315           break;
316         }
317       }
318 
319       *iaddr = MakeInetAddress(*addr);
320 
321       reply->helloFromServer.yourAddress = *iaddr;
322       reply->helloFromServer.yourHostId = hostId;
323     }
324     break;
325   }
326   case TraversalPacketConnectPlease:
327   {
328     TraversalHostId& hostId = packet->connectPlease.hostId;
329     auto r = EvictFind(connectedClients, hostId);
330     if (!r.found)
331     {
332       TraversalPacket* reply = AllocPacket(*addr);
333       reply->type = TraversalPacketConnectFailed;
334       reply->connectFailed.requestId = packet->requestId;
335       reply->connectFailed.reason = TraversalConnectFailedNoSuchClient;
336     }
337     else
338     {
339       TraversalPacket* please = AllocPacket(MakeSinAddr(*r.value), packet->requestId);
340       please->type = TraversalPacketPleaseSendPacket;
341       please->pleaseSendPacket.address = MakeInetAddress(*addr);
342     }
343     break;
344   }
345   default:
346     fprintf(stderr, "received unknown packet type %d from %s\n", packet->type, SenderName(addr));
347   }
348   if (packet->type != TraversalPacketAck)
349   {
350     TraversalPacket ack = {};
351     ack.type = TraversalPacketAck;
352     ack.requestId = packet->requestId;
353     ack.ack.ok = packetOk;
354     TrySend(&ack, sizeof(ack), addr);
355   }
356 }
357 
main()358 int main()
359 {
360   int rv;
361   sock = socket(PF_INET6, SOCK_DGRAM, 0);
362   if (sock == -1)
363   {
364     perror("socket");
365     return 1;
366   }
367   int no = 0;
368   rv = setsockopt(sock, IPPROTO_IPV6, IPV6_V6ONLY, &no, sizeof(no));
369   if (rv < 0)
370   {
371     perror("setsockopt IPV6_V6ONLY");
372     return 1;
373   }
374   in6_addr any = IN6ADDR_ANY_INIT;
375   sockaddr_in6 addr;
376 #ifdef SIN6_LEN
377   addr.sin6_len = sizeof(addr);
378 #endif
379   addr.sin6_family = AF_INET6;
380   addr.sin6_port = htons(PORT);
381   addr.sin6_flowinfo = 0;
382   addr.sin6_addr = any;
383   addr.sin6_scope_id = 0;
384 
385   rv = bind(sock, (sockaddr*)&addr, sizeof(addr));
386   if (rv < 0)
387   {
388     perror("bind");
389     return 1;
390   }
391 
392   timeval tv;
393   tv.tv_sec = 0;
394   tv.tv_usec = 300000;
395   rv = setsockopt(sock, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv));
396   if (rv < 0)
397   {
398     perror("setsockopt SO_RCVTIMEO");
399     return 1;
400   }
401 
402 #ifdef HAVE_LIBSYSTEMD
403   sd_notifyf(0, "READY=1\nSTATUS=Listening on port %d", PORT);
404 #endif
405 
406   while (true)
407   {
408     sockaddr_in6 raddr;
409     socklen_t addrLen = sizeof(raddr);
410     TraversalPacket packet;
411     // note: switch to recvmmsg (yes, mmsg) if this becomes
412     // expensive
413     rv = recvfrom(sock, &packet, sizeof(packet), 0, (sockaddr*)&raddr, &addrLen);
414     currentTime = std::chrono::duration_cast<std::chrono::microseconds>(
415                       std::chrono::system_clock::now().time_since_epoch())
416                       .count();
417     if (rv < 0)
418     {
419       if (errno != EINTR && errno != EAGAIN)
420       {
421         perror("recvfrom");
422         return 1;
423       }
424     }
425     else if ((size_t)rv < sizeof(packet))
426     {
427       fprintf(stderr, "received short packet from %s\n", SenderName(&raddr));
428     }
429     else
430     {
431       HandlePacket(&packet, &raddr);
432     }
433     ResendPackets();
434 #ifdef HAVE_LIBSYSTEMD
435     sd_notify(0, "WATCHDOG=1");
436 #endif
437   }
438 }
439