1 // Copyright (c) 2018 IIS (The Internet Foundation in Sweden)
2 // Written by Göran Andersson <initgoran@gmail.com>
3 
4 #ifdef _WIN32
5 #include <ws2tcpip.h>
6 #include <time.h>
7 #define NOMINMAX
8 #include <windows.h>
9 #include <winsock2.h>
10 #pragma comment(lib, "ws2_32.lib")
11 #else
12 #include <arpa/inet.h>
13 #include <unistd.h>
14 #include <netinet/in.h>
15 #include <netinet/ip.h>
16 #include <sys/socket.h>
17 #include <sys/time.h>
18 #include <netdb.h>
19 #include <net/if.h>
20 #endif
21 #ifdef __linux
22 #include <netinet/tcp.h>
23 #endif
24 
25 #include <fcntl.h>
26 #include <sys/types.h>
27 
28 #include <cstdlib>
29 #include <cstdio>
30 #include <cstring>
31 #include <cctype>
32 #include <cerrno>
33 
34 #include <string>
35 #include <iomanip>
36 #include <sstream>
37 #include <map>
38 
39 #include "socket.h"
40 #include "task.h"
41 
Socket(const std::string & label,Task * owner,const std::string & hostname,uint16_t port)42 Socket::Socket(const std::string &label, Task *owner,
43                const std::string &hostname, uint16_t port) :
44     Logger(label),
45     _owner(owner),
46     _hostname(hostname),
47     _port(port),
48     _state(PollState::NONE)
49 {
50 #ifndef _WIN32
51     if (!port && hostname == "UnixDomain") {
52         int pair_sd[2];
53         if (socketpair(AF_UNIX, SOCK_STREAM, 0, pair_sd) < 0) {
54             errno_log() << "cannot create socket pair";
55             _socket = 0;
56         } else {
57             _socket = pair_sd[0];
58             unix_domain_peer = pair_sd[1];
59             fcntl(pair_sd[0], F_SETFL, O_NONBLOCK|O_CLOEXEC);
60             fcntl(pair_sd[1], F_SETFL, O_NONBLOCK);
61         }
62         return;
63     }
64 #endif
65     _socket = -1;
66     _peer_label = _hostname + ":" + std::to_string(_port);
67 }
68 
69 // TODO: take initial state as a parameter, default PollState::READ.
Socket(const std::string & label,Task * owner,int fd)70 Socket::Socket(const std::string &label, Task *owner, int fd) :
71     Logger(label),
72     _owner(owner),
73     _socket(fd),
74     _hostname(""),
75     _port(0),
76     _state(PollState::READ)
77 {
78 }
79 
~Socket()80 Socket::~Socket() {
81     if (_socket >= 0) {
82         closeSocket(_socket);
83         log() << "closed socket " << _socket;
84     }
85 }
86 
namespace {

// Resolved address lists, keyed by the peer label ("host:port"). The cache
// owns the addrinfo lists; Socket::clearCache() frees them. When USE_THREADS
// is defined, every thread has its own cache.
using DnsCache = std::map<std::string, struct addrinfo *>;

#ifdef USE_THREADS
thread_local
#endif
DnsCache dns_cache;

}  // namespace
93 
clearCache()94 void Socket::clearCache() {
95     for (auto p : dns_cache)
96         freeaddrinfo(p.second);
97     dns_cache.clear();
98 }
99 
// True when `host` can only be a numeric IP literal. The first case is the
// original test: only digits, dots and colons. The second case covers IPv6
// literals such as "fe80::1", which contain hex letters. It requires a colon,
// because host names never contain ':', so a name like "cafe" is still
// looked up in DNS.
static bool looksLikeNumericHost(const std::string &host) {
    if (host.find_first_not_of("1234567890.:") == std::string::npos)
        return true;
    return host.find(':') != std::string::npos &&
        host.find_first_not_of("1234567890abcdefABCDEF.:") ==
            std::string::npos;
}

// Formats the address of the addrinfo entry `ai` into `buf` for logging.
// Returns "unknown" when the family is neither IPv4 nor IPv6, or when
// inet_ntop fails, so that callers never print an uninitialised buffer.
static const char *formatResolvedAddress(const struct addrinfo *ai,
                                         char (&buf)[INET6_ADDRSTRLEN]) {
    void *raw = nullptr;
    if (ai->ai_family == AF_INET)
        raw = &reinterpret_cast<sockaddr_in *>(ai->ai_addr)->sin_addr;
    else if (ai->ai_family == AF_INET6)
        raw = &reinterpret_cast<sockaddr_in6 *>(ai->ai_addr)->sin6_addr;
    if (!raw || !inet_ntop(ai->ai_family, raw, buf, sizeof buf))
        return "unknown";
    return buf;
}

// Resolves _hostname:_port into an addrinfo list for stream sockets, caching
// the result in dns_cache under _peer_label. Failed lookups are not cached.
//  - Empty _hostname: passive IPv6 wildcard address (for listening).
//  - Numeric literal: resolved with AI_NUMERICHOST (no DNS query).
//  - Otherwise: resolved normally (may query DNS).
// iptype: 0 returns the head of the list. 6 prefers the first AF_INET6
// entry; any other non-zero value prefers the first AF_INET entry. If no
// entry of the requested family exists, the head of the list is returned.
// Returns nullptr on lookup failure. The returned memory is owned by the
// cache; do not free it.
struct addrinfo *Socket::getAddressInfo(uint16_t iptype) {
    auto it = dns_cache.find(_peer_label);
    if (it == dns_cache.end()) {
        struct addrinfo hints;
        memset(&hints, 0, sizeof hints);
        hints.ai_family = AF_UNSPEC;
        hints.ai_socktype = SOCK_STREAM;
        hints.ai_flags = AI_ADDRCONFIG;

        const char *hostaddr = nullptr;
        if (_hostname.empty()) {
            hints.ai_family = AF_INET6;
            hints.ai_flags |= AI_PASSIVE;
            log() << "wildcard address *:" << _port;
        } else if (looksLikeNumericHost(_hostname)) {
            hints.ai_flags |= AI_NUMERICHOST;
            log() << "numeric address " << _hostname;
            hostaddr = _hostname.c_str();
        } else {
            log() << "dns lookup " << _hostname;
            hostaddr = _hostname.c_str();
        }

        struct addrinfo *addressInfo = nullptr;
        int res = getaddrinfo(hostaddr, std::to_string(_port).c_str(),
                              &hints, &addressInfo);
        if (res != 0) {
            err_log() << "lookup failed: " << gai_strerror(res);
            return nullptr;
        } else if (!addressInfo) {
            err_log() << "no valid address found";
            return nullptr;
        }
        if (!(hints.ai_flags & AI_NUMERICHOST)) {
            char ip[INET6_ADDRSTRLEN];
            log() << "lookup done: " << formatResolvedAddress(addressInfo, ip);
        }

        it = dns_cache.insert(std::make_pair(_peer_label, addressInfo)).first;
    }
    if (iptype) {
        int fam = (iptype == 6) ? AF_INET6 : AF_INET;
        for (struct addrinfo *ai = it->second; ai; ai = ai->ai_next)
            if (ai->ai_family == fam)
                return ai;
    }
    return it->second;
}
161 
createNonBlockingSocket(struct addrinfo * addressEntry,struct addrinfo * localAddr)162 void Socket::createNonBlockingSocket(struct addrinfo *addressEntry,
163                                      struct addrinfo *localAddr) {
164     if (_socket >= 0) {
165         err_log() << "socket already exists";
166         return;
167     }
168     int fd = ::socket(addressEntry->ai_family, addressEntry->ai_socktype,
169                       addressEntry->ai_protocol);
170     if (fd == -1) {
171         errno_log() << "cannot create socket";
172         return;
173     }
174     if (!setNonBlocking(fd)) {
175         closeSocket(fd);
176         return;
177     }
178 
179     if (localAddr && bind(fd, localAddr->ai_addr, localAddr->ai_addrlen) != 0) {
180         errno_log() << "cannot bind to local address";
181         return;
182     }
183 
184     int res = connect(fd, addressEntry->ai_addr, addressEntry->ai_addrlen);
185     if (res == -1 && !isTempError()) {
186         errno_log() << "connect error";
187         closeSocket(fd);
188         return;
189     }
190 
191     // All good, let's keep the socket:
192     _socket = fd;
193     _state = PollState::CONNECTING;
194 }
195 
closeSocket(int fd)196 int Socket::closeSocket(int fd) {
197 #ifdef _WIN32
198      return closesocket(fd);
199 #else
200      return close(fd);
201 #endif
202 }
203 
socketInError(int fd)204 bool Socket::socketInError(int fd) {
205     int res;
206     socklen_t res_len = sizeof(res);
207 #ifdef _WIN32
208     if (getsockopt(fd, SOL_SOCKET, SO_ERROR, (char *)&res, &res_len) < 0)
209 #else
210     if (getsockopt(fd, SOL_SOCKET, SO_ERROR, &res, &res_len) < 0)
211 #endif
212         return true;
213     if (!res)
214         return false;
215     errno = res;
216     return !isTempError();
217 }
218 
setNonBlocking(int fd)219 bool Socket::setNonBlocking(int fd) {
220 #ifdef __APPLE__
221     // SO_NOSIGPIPE only for OS X
222     int value = 1;
223     int status = setsockopt(fd, SOL_SOCKET, SO_NOSIGPIPE,
224                             &value, sizeof(value));
225     if (status != 0) {
226         errno_log() << "cannot set SO_NOSIGPIPE";
227     }
228 #endif
229 
230 #ifdef _WIN32
231     u_long enabledParameter = 1;
232     int nonBlockingResult = ioctlsocket(fd, FIONBIO, &enabledParameter);
233 
234     if (nonBlockingResult == -1) {
235         errno_log() << "cannot set socket non-blocking";
236         closesocket(fd);
237         return false;
238     }
239 #else
240     int nonBlockingResult = fcntl(fd, F_SETFL, O_NONBLOCK);
241     if (nonBlockingResult == -1) {
242         errno_log() << "cannot set socket non-blocking";
243         close(fd);
244         return false;
245     }
246 #endif
247 #ifdef __linux
248     int flag = 1;
249     int result = setsockopt(fd, IPPROTO_TCP, TCP_NODELAY,
250                             reinterpret_cast<char *>(&flag), sizeof(int));
251     if (result < 0)
252         errno_log() << "cannot set TCP_NODELAY";
253 #endif
254     return true;
255 }
256 
getIp(struct sockaddr * address,uint16_t * port)257 const char *Socket::getIp(struct sockaddr *address, uint16_t *port) {
258 #ifdef USE_THREADS
259     thread_local
260 #endif
261     static char client_ip[INET6_ADDRSTRLEN];
262     if (address->sa_family == AF_INET) {
263         struct sockaddr_in *s = reinterpret_cast<sockaddr_in *>(address);
264         inet_ntop(AF_INET, &s->sin_addr, client_ip, INET6_ADDRSTRLEN);
265         if (port)
266             *port = ntohs(s->sin_port);
267     } else {
268         struct sockaddr_in6 *s = reinterpret_cast<sockaddr_in6 *>(address);
269         inet_ntop(AF_INET6, &s->sin6_addr, client_ip, INET6_ADDRSTRLEN);
270         if (port)
271             *port = ntohs(s->sin6_port);
272     }
273     if (strncmp(client_ip, "::ffff:", 7) == 0)
274         return client_ip+7;
275     else
276         return client_ip;
277 }
278 
getIp(struct addrinfo * address,uint16_t * port)279 const char *Socket::getIp(struct addrinfo *address, uint16_t *port) {
280     return getIp(address->ai_addr, port);
281 }
282 
getIp(int fd,uint16_t * port,bool peer)283 const char *Socket::getIp(int fd, uint16_t *port, bool peer) {
284 #ifdef USE_THREADS
285     thread_local
286 #endif
287     static char client_ip[INET6_ADDRSTRLEN];
288     static const char *no_ip = "unknown IP";
289 
290     struct sockaddr_storage address;
291     memset(&address, 0, sizeof address);
292     socklen_t addrlen = sizeof(address);
293 
294     int ret = peer ?
295         getpeername(fd, reinterpret_cast<sockaddr *>(&address), &addrlen) :
296         getsockname(fd, reinterpret_cast<sockaddr *>(&address), &addrlen);
297 
298     if (ret < 0) {
299         return no_ip;
300     } else {
301         if (address.ss_family == AF_INET) {
302             struct sockaddr_in *s = reinterpret_cast<sockaddr_in *>(&address);
303             inet_ntop(AF_INET, &s->sin_addr, client_ip, INET6_ADDRSTRLEN);
304             if (port)
305                 *port = ntohs(s->sin_port);
306         } else {
307             struct sockaddr_in6 *s = reinterpret_cast<sockaddr_in6 *>(&address);
308             inet_ntop(AF_INET6, &s->sin6_addr, client_ip, INET6_ADDRSTRLEN);
309             if (port)
310                 *port = ntohs(s->sin6_port);
311         }
312         if (strncmp(client_ip, "::ffff:", 7) == 0)
313             return client_ip+7;
314         else
315             return client_ip;
316     }
317 }
318 
// Creates a listening socket on the address from getAddressInfo(). An empty
// _hostname means the IPv6 wildcard.
// On success: _socket holds the listening descriptor, _state becomes
// PollState::READ, _port is updated to the port actually bound (e.g. the
// ephemeral port when 0 was requested), and true is returned.
// Returns false if a socket already exists, if address lookup fails, or if
// socket/setsockopt/bind/listen fails. In the last cases the descriptor is
// closed instead of being leaked.
bool Socket::createServerSocket() {
    if (_socket >= 0)
        return false; // Already in use!!

    log() << "Listen on " << port() << " ip " << _hostname;

    struct addrinfo *addr = getAddressInfo();
    if (!addr)
        return false;

    int fd = ::socket(addr->ai_family, addr->ai_socktype, addr->ai_protocol);
    if (fd < 0) {
        errno_log() << "cannot create listen socket";
        return false;
    }
#ifndef _WIN32
    int reuse = 1;
    if (setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &reuse,
                   sizeof(reuse)) < 0) {
        errno_log() << "cannot reuse listen socket";
        closeSocket(fd);
        return false;
    }
#endif

    if (bind(fd, addr->ai_addr, addr->ai_addrlen) != 0) {
        errno_log() << "cannot bind listen socket";
        closeSocket(fd);
        return false;
    }

    if (listen(fd, 20) != 0) {
        errno_log() << "cannot listen";
        closeSocket(fd);
        return false;
    }

    // _socket and _state are left unchanged unless we get here:
    _socket = fd;
    _state = PollState::READ;

    // Read back the bound port. Use sockaddr_storage and check the family,
    // since the listening socket may be IPv4 or IPv6.
    struct sockaddr_storage address;
    memset(&address, 0, sizeof address);
    socklen_t len = sizeof(address);
    if (getsockname(fd, reinterpret_cast<sockaddr *>(&address), &len) == -1) {
        errno_log() << "getsockname failed";
    } else {
        if (address.ss_family == AF_INET)
            _port = ntohs(reinterpret_cast<sockaddr_in *>(&address)->sin_port);
        else
            _port = ntohs(reinterpret_cast<sockaddr_in6 *>(&address)->sin6_port);
        log() << "server socket " << fd << " listening on port " << _port;
    }

    return true;
}
371