1 /*
2 * Copyright 2012-2019 Max Kellermann <max.kellermann@gmail.com>
3 *
4 * Redistribution and use in source and binary forms, with or without
5 * modification, are permitted provided that the following conditions
6 * are met:
7 *
8 * - Redistributions of source code must retain the above copyright
9 * notice, this list of conditions and the following disclaimer.
10 *
11 * - Redistributions in binary form must reproduce the above copyright
12 * notice, this list of conditions and the following disclaimer in the
13 * documentation and/or other materials provided with the
14 * distribution.
15 *
16 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
17 * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
18 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
19 * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
20 * FOUNDATION OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
21 * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
22 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
23 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
24 * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
25 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
26 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED
27 * OF THE POSSIBILITY OF SUCH DAMAGE.
28 */
29
30 #include "SocketDescriptor.hxx"
31 #include "SocketAddress.hxx"
32 #include "StaticSocketAddress.hxx"
33 #include "IPv4Address.hxx"
34 #include "IPv6Address.hxx"
35
36 #ifdef _WIN32
37 #include <winsock2.h>
38 #include <ws2tcpip.h>
39 #else
40 #include <sys/socket.h>
41 #include <netinet/in.h>
42 #include <netinet/tcp.h>
43 #endif
44
45 #include <cassert>
46 #include <cerrno>
47
48 #include <string.h>
49
50 int
GetType() const51 SocketDescriptor::GetType() const noexcept
52 {
53 assert(IsDefined());
54
55 int type;
56 socklen_t size = sizeof(type);
57 return getsockopt(fd, SOL_SOCKET, SO_TYPE,
58 (char *)&type, &size) == 0
59 ? type
60 : -1;
61 }
62
63 bool
IsStream() const64 SocketDescriptor::IsStream() const noexcept
65 {
66 return GetType() == SOCK_STREAM;
67 }
68
#ifdef _WIN32

/**
 * Windows implementation of Close(): sockets must be released with
 * closesocket() there.  Does nothing if no socket is set; otherwise
 * the descriptor is taken out of this object with Steal() and closed.
 */
void
SocketDescriptor::Close() noexcept
{
	if (!IsDefined())
		return;

	::closesocket(Steal());
}

#endif
79
80 SocketDescriptor
Accept()81 SocketDescriptor::Accept() noexcept
82 {
83 #ifdef __linux__
84 int connection_fd = ::accept4(Get(), nullptr, nullptr, SOCK_CLOEXEC);
85 #else
86 int connection_fd = ::accept(Get(), nullptr, nullptr);
87 #endif
88 return connection_fd >= 0
89 ? SocketDescriptor(connection_fd)
90 : Undefined();
91 }
92
93 SocketDescriptor
AcceptNonBlock() const94 SocketDescriptor::AcceptNonBlock() const noexcept
95 {
96 #ifdef __linux__
97 int connection_fd = ::accept4(Get(), nullptr, nullptr,
98 SOCK_CLOEXEC|SOCK_NONBLOCK);
99 #else
100 int connection_fd = ::accept(Get(), nullptr, nullptr);
101 if (connection_fd >= 0)
102 SocketDescriptor(connection_fd).SetNonBlocking();
103 #endif
104 return SocketDescriptor(connection_fd);
105 }
106
107 SocketDescriptor
AcceptNonBlock(StaticSocketAddress & address) const108 SocketDescriptor::AcceptNonBlock(StaticSocketAddress &address) const noexcept
109 {
110 address.SetMaxSize();
111 #ifdef __linux__
112 int connection_fd = ::accept4(Get(), address, &address.size,
113 SOCK_CLOEXEC|SOCK_NONBLOCK);
114 #else
115 int connection_fd = ::accept(Get(), address, &address.size);
116 if (connection_fd >= 0)
117 SocketDescriptor(connection_fd).SetNonBlocking();
118 #endif
119 return SocketDescriptor(connection_fd);
120 }
121
122 bool
Connect(SocketAddress address)123 SocketDescriptor::Connect(SocketAddress address) noexcept
124 {
125 assert(address.IsDefined());
126
127 return ::connect(Get(), address.GetAddress(), address.GetSize()) >= 0;
128 }
129
130 bool
Create(int domain,int type,int protocol)131 SocketDescriptor::Create(int domain, int type, int protocol) noexcept
132 {
133 #ifdef _WIN32
134 static bool initialised = false;
135 if (!initialised) {
136 WSADATA data;
137 WSAStartup(MAKEWORD(2,2), &data);
138 initialised = true;
139 }
140 #endif
141
142 #ifdef SOCK_CLOEXEC
143 /* implemented since Linux 2.6.27 */
144 type |= SOCK_CLOEXEC;
145 #endif
146
147 int new_fd = socket(domain, type, protocol);
148 if (new_fd < 0)
149 return false;
150
151 Set(new_fd);
152 return true;
153 }
154
155 bool
CreateNonBlock(int domain,int type,int protocol)156 SocketDescriptor::CreateNonBlock(int domain, int type, int protocol) noexcept
157 {
158 #ifdef SOCK_NONBLOCK
159 type |= SOCK_NONBLOCK;
160 #endif
161
162 if (!Create(domain, type, protocol))
163 return false;
164
165 #ifndef SOCK_NONBLOCK
166 SetNonBlocking();
167 #endif
168
169 return true;
170 }
171
172 #ifndef _WIN32
173
174 bool
CreateSocketPair(int domain,int type,int protocol,SocketDescriptor & a,SocketDescriptor & b)175 SocketDescriptor::CreateSocketPair(int domain, int type, int protocol,
176 SocketDescriptor &a,
177 SocketDescriptor &b) noexcept
178 {
179 #ifdef SOCK_CLOEXEC
180 /* implemented since Linux 2.6.27 */
181 type |= SOCK_CLOEXEC;
182 #endif
183
184 int fds[2];
185 if (socketpair(domain, type, protocol, fds) < 0)
186 return false;
187
188 a = SocketDescriptor(fds[0]);
189 b = SocketDescriptor(fds[1]);
190 return true;
191 }
192
193 bool
CreateSocketPairNonBlock(int domain,int type,int protocol,SocketDescriptor & a,SocketDescriptor & b)194 SocketDescriptor::CreateSocketPairNonBlock(int domain, int type, int protocol,
195 SocketDescriptor &a,
196 SocketDescriptor &b) noexcept
197 {
198 #ifdef SOCK_NONBLOCK
199 type |= SOCK_NONBLOCK;
200 #endif
201
202 if (!CreateSocketPair(domain, type, protocol, a, b))
203 return false;
204
205 #ifndef SOCK_NONBLOCK
206 a.SetNonBlocking();
207 b.SetNonBlocking();
208 #endif
209
210 return true;
211 }
212
213 #endif
214
215 int
GetError()216 SocketDescriptor::GetError() noexcept
217 {
218 assert(IsDefined());
219
220 int s_err = 0;
221 socklen_t s_err_size = sizeof(s_err);
222 return getsockopt(fd, SOL_SOCKET, SO_ERROR,
223 (char *)&s_err, &s_err_size) == 0
224 ? s_err
225 : errno;
226 }
227
228 std::size_t
GetOption(int level,int name,void * value,std::size_t size) const229 SocketDescriptor::GetOption(int level, int name,
230 void *value, std::size_t size) const noexcept
231 {
232 assert(IsDefined());
233
234 socklen_t size2 = size;
235 return getsockopt(fd, level, name, (char *)value, &size2) == 0
236 ? size2
237 : 0;
238 }
239
#ifdef HAVE_STRUCT_UCRED

/**
 * Query the credentials of the peer via getsockopt(SO_PEERCRED).
 *
 * @return the peer credentials; if the query fails (or returns a
 * short result), pid is -1 and uid/gid are (uid_t)-1 / (gid_t)-1
 */
struct ucred
SocketDescriptor::GetPeerCredentials() const noexcept
{
	struct ucred cred;
	if (GetOption(SOL_SOCKET, SO_PEERCRED,
		      &cred, sizeof(cred)) < sizeof(cred)) {
		/* previously only pid was set here, so uid/gid were
		   returned uninitialised; a caller inspecting them
		   without checking pid would see arbitrary values
		   (possibly 0) */
		cred.pid = -1;
		cred.uid = static_cast<uid_t>(-1);
		cred.gid = static_cast<gid_t>(-1);
	}

	return cred;
}

#endif
253
#ifdef _WIN32

/**
 * Windows implementation: enable non-blocking mode with
 * ioctlsocket(FIONBIO).
 *
 * @return true if ioctlsocket() succeeded
 */
bool
SocketDescriptor::SetNonBlocking() noexcept
{
	u_long non_blocking = 1;
	const int result = ioctlsocket(fd, FIONBIO, &non_blocking);
	return result == 0;
}

#endif
264
265 bool
SetOption(int level,int name,const void * value,std::size_t size)266 SocketDescriptor::SetOption(int level, int name,
267 const void *value, std::size_t size) noexcept
268 {
269 assert(IsDefined());
270
271 /* on Windows, setsockopt() wants "const char *" */
272 return setsockopt(fd, level, name, (const char *)value, size) == 0;
273 }
274
275 bool
SetKeepAlive(bool value)276 SocketDescriptor::SetKeepAlive(bool value) noexcept
277 {
278 return SetBoolOption(SOL_SOCKET, SO_KEEPALIVE, value);
279 }
280
281 bool
SetReuseAddress(bool value)282 SocketDescriptor::SetReuseAddress(bool value) noexcept
283 {
284 return SetBoolOption(SOL_SOCKET, SO_REUSEADDR, value);
285 }
286
287 #ifdef __linux__
288
289 bool
SetReusePort(bool value)290 SocketDescriptor::SetReusePort(bool value) noexcept
291 {
292 return SetBoolOption(SOL_SOCKET, SO_REUSEPORT, value);
293 }
294
295 bool
SetFreeBind(bool value)296 SocketDescriptor::SetFreeBind(bool value) noexcept
297 {
298 return SetBoolOption(IPPROTO_IP, IP_FREEBIND, value);
299 }
300
301 bool
SetNoDelay(bool value)302 SocketDescriptor::SetNoDelay(bool value) noexcept
303 {
304 return SetBoolOption(IPPROTO_TCP, TCP_NODELAY, value);
305 }
306
307 bool
SetCork(bool value)308 SocketDescriptor::SetCork(bool value) noexcept
309 {
310 return SetBoolOption(IPPROTO_TCP, TCP_CORK, value);
311 }
312
313 bool
SetTcpDeferAccept(const int & seconds)314 SocketDescriptor::SetTcpDeferAccept(const int &seconds) noexcept
315 {
316 return SetOption(IPPROTO_TCP, TCP_DEFER_ACCEPT, &seconds, sizeof(seconds));
317 }
318
319 bool
SetTcpUserTimeout(const unsigned & milliseconds)320 SocketDescriptor::SetTcpUserTimeout(const unsigned &milliseconds) noexcept
321 {
322 return SetOption(IPPROTO_TCP, TCP_USER_TIMEOUT,
323 &milliseconds, sizeof(milliseconds));
324 }
325
326 bool
SetV6Only(bool value)327 SocketDescriptor::SetV6Only(bool value) noexcept
328 {
329 return SetBoolOption(IPPROTO_IPV6, IPV6_V6ONLY, value);
330 }
331
332 bool
SetBindToDevice(const char * name)333 SocketDescriptor::SetBindToDevice(const char *name) noexcept
334 {
335 return SetOption(SOL_SOCKET, SO_BINDTODEVICE, name, strlen(name));
336 }
337
338 #ifdef TCP_FASTOPEN
339
340 bool
SetTcpFastOpen(int qlen)341 SocketDescriptor::SetTcpFastOpen(int qlen) noexcept
342 {
343 return SetOption(SOL_TCP, TCP_FASTOPEN, &qlen, sizeof(qlen));
344 }
345
346 #endif
347
348 bool
AddMembership(const IPv4Address & address)349 SocketDescriptor::AddMembership(const IPv4Address &address) noexcept
350 {
351 struct ip_mreq r{address.GetAddress(), IPv4Address(0).GetAddress()};
352 return setsockopt(fd, IPPROTO_IP, IP_ADD_MEMBERSHIP,
353 &r, sizeof(r)) == 0;
354 }
355
356 bool
AddMembership(const IPv6Address & address)357 SocketDescriptor::AddMembership(const IPv6Address &address) noexcept
358 {
359 struct ipv6_mreq r{address.GetAddress(), 0};
360 r.ipv6mr_interface = address.GetScopeId();
361 return setsockopt(fd, IPPROTO_IPV6, IPV6_ADD_MEMBERSHIP,
362 &r, sizeof(r)) == 0;
363 }
364
365 bool
AddMembership(SocketAddress address)366 SocketDescriptor::AddMembership(SocketAddress address) noexcept
367 {
368 switch (address.GetFamily()) {
369 case AF_INET:
370 return AddMembership(IPv4Address(address));
371
372 case AF_INET6:
373 return AddMembership(IPv6Address(address));
374
375 default:
376 errno = EINVAL;
377 return false;
378 }
379 }
380
381 #endif
382
383 bool
Bind(SocketAddress address)384 SocketDescriptor::Bind(SocketAddress address) noexcept
385 {
386 return bind(Get(), address.GetAddress(), address.GetSize()) == 0;
387 }
388
389 #ifdef __linux__
390
391 bool
AutoBind()392 SocketDescriptor::AutoBind() noexcept
393 {
394 static constexpr sa_family_t family = AF_LOCAL;
395 return Bind(SocketAddress((const struct sockaddr *)&family,
396 sizeof(family)));
397 }
398
399 #endif
400
401 bool
Listen(int backlog)402 SocketDescriptor::Listen(int backlog) noexcept
403 {
404 return listen(Get(), backlog) == 0;
405 }
406
407 StaticSocketAddress
GetLocalAddress() const408 SocketDescriptor::GetLocalAddress() const noexcept
409 {
410 assert(IsDefined());
411
412 StaticSocketAddress result;
413 result.size = result.GetCapacity();
414 if (getsockname(fd, result, &result.size) < 0)
415 result.Clear();
416
417 return result;
418 }
419
420 StaticSocketAddress
GetPeerAddress() const421 SocketDescriptor::GetPeerAddress() const noexcept
422 {
423 assert(IsDefined());
424
425 StaticSocketAddress result;
426 result.size = result.GetCapacity();
427 if (getpeername(fd, result, &result.size) < 0)
428 result.Clear();
429
430 return result;
431 }
432
433 ssize_t
Read(void * buffer,std::size_t length)434 SocketDescriptor::Read(void *buffer, std::size_t length) noexcept
435 {
436 int flags = 0;
437 #ifndef _WIN32
438 flags |= MSG_DONTWAIT;
439 #endif
440
441 return ::recv(Get(), (char *)buffer, length, flags);
442 }
443
444 ssize_t
Write(const void * buffer,std::size_t length)445 SocketDescriptor::Write(const void *buffer, std::size_t length) noexcept
446 {
447 int flags = 0;
448 #ifdef __linux__
449 flags |= MSG_NOSIGNAL;
450 #endif
451
452 return ::send(Get(), (const char *)buffer, length, flags);
453 }
454
#ifdef _WIN32

/**
 * Fill @p tv from a millisecond timeout for select().
 *
 * @return &tv, or nullptr for a negative timeout (wait indefinitely)
 */
static struct timeval *
MillisecondsToTimeval(int timeout_ms, struct timeval &tv) noexcept
{
	if (timeout_ms < 0)
		return nullptr;

	const unsigned ms = unsigned(timeout_ms);
	tv.tv_sec = ms / 1000;
	tv.tv_usec = (ms % 1000) * 1000;
	return &tv;
}

/**
 * Wait with select() until the socket is readable.
 *
 * @param timeout_ms timeout in milliseconds; negative waits forever
 * @return select()'s result
 */
int
SocketDescriptor::WaitReadable(int timeout_ms) const noexcept
{
	assert(IsDefined());

	fd_set read_set;
	FD_ZERO(&read_set);
	FD_SET(Get(), &read_set);

	struct timeval tv;
	return select(Get() + 1, &read_set, nullptr, nullptr,
		      MillisecondsToTimeval(timeout_ms, tv));
}

/**
 * Wait with select() until the socket is writable.
 *
 * @param timeout_ms timeout in milliseconds; negative waits forever
 * @return select()'s result
 */
int
SocketDescriptor::WaitWritable(int timeout_ms) const noexcept
{
	assert(IsDefined());

	fd_set write_set;
	FD_ZERO(&write_set);
	FD_SET(Get(), &write_set);

	struct timeval tv;
	return select(Get() + 1, nullptr, &write_set, nullptr,
		      MillisecondsToTimeval(timeout_ms, tv));
}

#endif
496
497 ssize_t
Read(void * buffer,std::size_t length,StaticSocketAddress & address)498 SocketDescriptor::Read(void *buffer, std::size_t length,
499 StaticSocketAddress &address) noexcept
500 {
501 int flags = 0;
502 #ifndef _WIN32
503 flags |= MSG_DONTWAIT;
504 #endif
505
506 socklen_t addrlen = address.GetCapacity();
507 ssize_t nbytes = ::recvfrom(Get(), (char *)buffer, length, flags,
508 address, &addrlen);
509 if (nbytes > 0)
510 address.SetSize(addrlen);
511
512 return nbytes;
513 }
514
515 ssize_t
Write(const void * buffer,std::size_t length,SocketAddress address)516 SocketDescriptor::Write(const void *buffer, std::size_t length,
517 SocketAddress address) noexcept
518 {
519 int flags = 0;
520 #ifndef _WIN32
521 flags |= MSG_DONTWAIT;
522 #endif
523 #ifdef __linux__
524 flags |= MSG_NOSIGNAL;
525 #endif
526
527 return ::sendto(Get(), (const char *)buffer, length, flags,
528 address.GetAddress(), address.GetSize());
529 }
530
531 #ifndef _WIN32
532
533 void
Shutdown()534 SocketDescriptor::Shutdown() noexcept
535 {
536 shutdown(Get(), SHUT_RDWR);
537 }
538
539 void
ShutdownRead()540 SocketDescriptor::ShutdownRead() noexcept
541 {
542 shutdown(Get(), SHUT_RD);
543 }
544
545 void
ShutdownWrite()546 SocketDescriptor::ShutdownWrite() noexcept
547 {
548 shutdown(Get(), SHUT_WR);
549 }
550
551 #endif
552