1 /*
2  * Copyright 2012-2019 Max Kellermann <max.kellermann@gmail.com>
3  *
4  * Redistribution and use in source and binary forms, with or without
5  * modification, are permitted provided that the following conditions
6  * are met:
7  *
8  * - Redistributions of source code must retain the above copyright
9  * notice, this list of conditions and the following disclaimer.
10  *
11  * - Redistributions in binary form must reproduce the above copyright
12  * notice, this list of conditions and the following disclaimer in the
13  * documentation and/or other materials provided with the
14  * distribution.
15  *
16  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
17  * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
18  * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
19  * FOR A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE
20  * FOUNDATION OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
21  * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
22  * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
23  * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
24  * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
25  * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
26  * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED
27  * OF THE POSSIBILITY OF SUCH DAMAGE.
28  */
29 
30 #include "SocketDescriptor.hxx"
31 #include "SocketAddress.hxx"
32 #include "StaticSocketAddress.hxx"
33 #include "IPv4Address.hxx"
34 #include "IPv6Address.hxx"
35 
36 #ifdef _WIN32
37 #include <winsock2.h>
38 #include <ws2tcpip.h>
39 #else
40 #include <sys/socket.h>
41 #include <netinet/in.h>
42 #include <netinet/tcp.h>
43 #endif
44 
45 #include <cassert>
46 #include <cerrno>
47 
48 #include <string.h>
49 
50 int
GetType() const51 SocketDescriptor::GetType() const noexcept
52 {
53 	assert(IsDefined());
54 
55 	int type;
56 	socklen_t size = sizeof(type);
57 	return getsockopt(fd, SOL_SOCKET, SO_TYPE,
58 			  (char *)&type, &size) == 0
59 		? type
60 		: -1;
61 }
62 
63 bool
IsStream() const64 SocketDescriptor::IsStream() const noexcept
65 {
66 	return GetType() == SOCK_STREAM;
67 }
68 
69 #ifdef _WIN32
70 
71 void
Close()72 SocketDescriptor::Close() noexcept
73 {
74 	if (IsDefined())
75 		::closesocket(Steal());
76 }
77 
78 #endif
79 
80 SocketDescriptor
Accept()81 SocketDescriptor::Accept() noexcept
82 {
83 #ifdef __linux__
84 	int connection_fd = ::accept4(Get(), nullptr, nullptr, SOCK_CLOEXEC);
85 #else
86 	int connection_fd = ::accept(Get(), nullptr, nullptr);
87 #endif
88 	return connection_fd >= 0
89 		? SocketDescriptor(connection_fd)
90 		: Undefined();
91 }
92 
93 SocketDescriptor
AcceptNonBlock() const94 SocketDescriptor::AcceptNonBlock() const noexcept
95 {
96 #ifdef __linux__
97 	int connection_fd = ::accept4(Get(), nullptr, nullptr,
98 				      SOCK_CLOEXEC|SOCK_NONBLOCK);
99 #else
100 	int connection_fd = ::accept(Get(), nullptr, nullptr);
101 	if (connection_fd >= 0)
102 		SocketDescriptor(connection_fd).SetNonBlocking();
103 #endif
104 	return SocketDescriptor(connection_fd);
105 }
106 
107 SocketDescriptor
AcceptNonBlock(StaticSocketAddress & address) const108 SocketDescriptor::AcceptNonBlock(StaticSocketAddress &address) const noexcept
109 {
110 	address.SetMaxSize();
111 #ifdef __linux__
112 	int connection_fd = ::accept4(Get(), address, &address.size,
113 				      SOCK_CLOEXEC|SOCK_NONBLOCK);
114 #else
115 	int connection_fd = ::accept(Get(), address, &address.size);
116 	if (connection_fd >= 0)
117 		SocketDescriptor(connection_fd).SetNonBlocking();
118 #endif
119 	return SocketDescriptor(connection_fd);
120 }
121 
122 bool
Connect(SocketAddress address)123 SocketDescriptor::Connect(SocketAddress address) noexcept
124 {
125 	assert(address.IsDefined());
126 
127 	return ::connect(Get(), address.GetAddress(), address.GetSize()) >= 0;
128 }
129 
130 bool
Create(int domain,int type,int protocol)131 SocketDescriptor::Create(int domain, int type, int protocol) noexcept
132 {
133 #ifdef _WIN32
134 	static bool initialised = false;
135 	if (!initialised) {
136 		WSADATA data;
137 		WSAStartup(MAKEWORD(2,2), &data);
138 		initialised = true;
139 	}
140 #endif
141 
142 #ifdef SOCK_CLOEXEC
143 	/* implemented since Linux 2.6.27 */
144 	type |= SOCK_CLOEXEC;
145 #endif
146 
147 	int new_fd = socket(domain, type, protocol);
148 	if (new_fd < 0)
149 		return false;
150 
151 	Set(new_fd);
152 	return true;
153 }
154 
155 bool
CreateNonBlock(int domain,int type,int protocol)156 SocketDescriptor::CreateNonBlock(int domain, int type, int protocol) noexcept
157 {
158 #ifdef SOCK_NONBLOCK
159 	type |= SOCK_NONBLOCK;
160 #endif
161 
162 	if (!Create(domain, type, protocol))
163 		return false;
164 
165 #ifndef SOCK_NONBLOCK
166 	SetNonBlocking();
167 #endif
168 
169 	return true;
170 }
171 
172 #ifndef _WIN32
173 
174 bool
CreateSocketPair(int domain,int type,int protocol,SocketDescriptor & a,SocketDescriptor & b)175 SocketDescriptor::CreateSocketPair(int domain, int type, int protocol,
176 				   SocketDescriptor &a,
177 				   SocketDescriptor &b) noexcept
178 {
179 #ifdef SOCK_CLOEXEC
180 	/* implemented since Linux 2.6.27 */
181 	type |= SOCK_CLOEXEC;
182 #endif
183 
184 	int fds[2];
185 	if (socketpair(domain, type, protocol, fds) < 0)
186 		return false;
187 
188 	a = SocketDescriptor(fds[0]);
189 	b = SocketDescriptor(fds[1]);
190 	return true;
191 }
192 
193 bool
CreateSocketPairNonBlock(int domain,int type,int protocol,SocketDescriptor & a,SocketDescriptor & b)194 SocketDescriptor::CreateSocketPairNonBlock(int domain, int type, int protocol,
195 					   SocketDescriptor &a,
196 					   SocketDescriptor &b) noexcept
197 {
198 #ifdef SOCK_NONBLOCK
199 	type |= SOCK_NONBLOCK;
200 #endif
201 
202 	if (!CreateSocketPair(domain, type, protocol, a, b))
203 		return false;
204 
205 #ifndef SOCK_NONBLOCK
206 	a.SetNonBlocking();
207 	b.SetNonBlocking();
208 #endif
209 
210 	return true;
211 }
212 
213 #endif
214 
215 int
GetError()216 SocketDescriptor::GetError() noexcept
217 {
218 	assert(IsDefined());
219 
220 	int s_err = 0;
221 	socklen_t s_err_size = sizeof(s_err);
222 	return getsockopt(fd, SOL_SOCKET, SO_ERROR,
223 			  (char *)&s_err, &s_err_size) == 0
224 		? s_err
225 		: errno;
226 }
227 
228 size_t
GetOption(int level,int name,void * value,size_t size) const229 SocketDescriptor::GetOption(int level, int name,
230 			    void *value, size_t size) const noexcept
231 {
232 	assert(IsDefined());
233 
234 	socklen_t size2 = size;
235 	return getsockopt(fd, level, name, (char *)value, &size2) == 0
236 		? size2
237 		: 0;
238 }
239 
240 #ifdef HAVE_STRUCT_UCRED
241 
242 struct ucred
GetPeerCredentials() const243 SocketDescriptor::GetPeerCredentials() const noexcept
244 {
245 	struct ucred cred;
246 	if (GetOption(SOL_SOCKET, SO_PEERCRED,
247 		      &cred, sizeof(cred)) < sizeof(cred))
248 		cred.pid = -1;
249 	return cred;
250 }
251 
252 #endif
253 
254 #ifdef _WIN32
255 
256 bool
SetNonBlocking()257 SocketDescriptor::SetNonBlocking() noexcept
258 {
259 	u_long val = 1;
260 	return ioctlsocket(fd, FIONBIO, &val) == 0;
261 }
262 
263 #endif
264 
265 bool
SetOption(int level,int name,const void * value,size_t size)266 SocketDescriptor::SetOption(int level, int name,
267 			    const void *value, size_t size) noexcept
268 {
269 	assert(IsDefined());
270 
271 	/* on Windows, setsockopt() wants "const char *" */
272 	return setsockopt(fd, level, name, (const char *)value, size) == 0;
273 }
274 
275 bool
SetKeepAlive(bool value)276 SocketDescriptor::SetKeepAlive(bool value) noexcept
277 {
278 	return SetBoolOption(SOL_SOCKET, SO_KEEPALIVE, value);
279 }
280 
281 bool
SetReuseAddress(bool value)282 SocketDescriptor::SetReuseAddress(bool value) noexcept
283 {
284 	return SetBoolOption(SOL_SOCKET, SO_REUSEADDR, value);
285 }
286 
287 #ifdef __linux__
288 
289 bool
SetReusePort(bool value)290 SocketDescriptor::SetReusePort(bool value) noexcept
291 {
292 	return SetBoolOption(SOL_SOCKET, SO_REUSEPORT, value);
293 }
294 
295 bool
SetFreeBind(bool value)296 SocketDescriptor::SetFreeBind(bool value) noexcept
297 {
298 	return SetBoolOption(IPPROTO_IP, IP_FREEBIND, value);
299 }
300 
301 bool
SetNoDelay(bool value)302 SocketDescriptor::SetNoDelay(bool value) noexcept
303 {
304 	return SetBoolOption(IPPROTO_TCP, TCP_NODELAY, value);
305 }
306 
307 bool
SetCork(bool value)308 SocketDescriptor::SetCork(bool value) noexcept
309 {
310 	return SetBoolOption(IPPROTO_TCP, TCP_CORK, value);
311 }
312 
313 bool
SetTcpDeferAccept(const int & seconds)314 SocketDescriptor::SetTcpDeferAccept(const int &seconds) noexcept
315 {
316 	return SetOption(IPPROTO_TCP, TCP_DEFER_ACCEPT, &seconds, sizeof(seconds));
317 }
318 
319 bool
SetTcpUserTimeout(const unsigned & milliseconds)320 SocketDescriptor::SetTcpUserTimeout(const unsigned &milliseconds) noexcept
321 {
322 	return SetOption(IPPROTO_TCP, TCP_USER_TIMEOUT,
323 			 &milliseconds, sizeof(milliseconds));
324 }
325 
326 bool
SetV6Only(bool value)327 SocketDescriptor::SetV6Only(bool value) noexcept
328 {
329 	return SetBoolOption(IPPROTO_IPV6, IPV6_V6ONLY, value);
330 }
331 
332 bool
SetBindToDevice(const char * name)333 SocketDescriptor::SetBindToDevice(const char *name) noexcept
334 {
335 	return SetOption(SOL_SOCKET, SO_BINDTODEVICE, name, strlen(name));
336 }
337 
338 #ifdef TCP_FASTOPEN
339 
340 bool
SetTcpFastOpen(int qlen)341 SocketDescriptor::SetTcpFastOpen(int qlen) noexcept
342 {
343 	return SetOption(SOL_TCP, TCP_FASTOPEN, &qlen, sizeof(qlen));
344 }
345 
346 #endif
347 
348 bool
AddMembership(const IPv4Address & address)349 SocketDescriptor::AddMembership(const IPv4Address &address) noexcept
350 {
351 	struct ip_mreq r{address.GetAddress(), IPv4Address(0).GetAddress()};
352 	return setsockopt(fd, IPPROTO_IP, IP_ADD_MEMBERSHIP,
353 			  &r, sizeof(r)) == 0;
354 }
355 
356 bool
AddMembership(const IPv6Address & address)357 SocketDescriptor::AddMembership(const IPv6Address &address) noexcept
358 {
359 	struct ipv6_mreq r{address.GetAddress(), 0};
360 	r.ipv6mr_interface = address.GetScopeId();
361 	return setsockopt(fd, IPPROTO_IPV6, IPV6_ADD_MEMBERSHIP,
362 			  &r, sizeof(r)) == 0;
363 }
364 
365 bool
AddMembership(SocketAddress address)366 SocketDescriptor::AddMembership(SocketAddress address) noexcept
367 {
368 	switch (address.GetFamily()) {
369 	case AF_INET:
370 		return AddMembership(IPv4Address(address));
371 
372 	case AF_INET6:
373 		return AddMembership(IPv6Address(address));
374 
375 	default:
376 		errno = EINVAL;
377 		return false;
378 	}
379 }
380 
381 #endif
382 
383 bool
Bind(SocketAddress address)384 SocketDescriptor::Bind(SocketAddress address) noexcept
385 {
386 	return bind(Get(), address.GetAddress(), address.GetSize()) == 0;
387 }
388 
389 #ifdef __linux__
390 
391 bool
AutoBind()392 SocketDescriptor::AutoBind() noexcept
393 {
394 	static constexpr sa_family_t family = AF_LOCAL;
395 	return Bind(SocketAddress((const struct sockaddr *)&family,
396 				  sizeof(family)));
397 }
398 
399 #endif
400 
401 bool
Listen(int backlog)402 SocketDescriptor::Listen(int backlog) noexcept
403 {
404 	return listen(Get(), backlog) == 0;
405 }
406 
407 StaticSocketAddress
GetLocalAddress() const408 SocketDescriptor::GetLocalAddress() const noexcept
409 {
410 	assert(IsDefined());
411 
412 	StaticSocketAddress result;
413 	result.size = result.GetCapacity();
414 	if (getsockname(fd, result, &result.size) < 0)
415 		result.Clear();
416 
417 	return result;
418 }
419 
420 StaticSocketAddress
GetPeerAddress() const421 SocketDescriptor::GetPeerAddress() const noexcept
422 {
423 	assert(IsDefined());
424 
425 	StaticSocketAddress result;
426 	result.size = result.GetCapacity();
427 	if (getpeername(fd, result, &result.size) < 0)
428 		result.Clear();
429 
430 	return result;
431 }
432 
433 ssize_t
Read(void * buffer,size_t length)434 SocketDescriptor::Read(void *buffer, size_t length) noexcept
435 {
436 	int flags = 0;
437 #ifndef _WIN32
438 	flags |= MSG_DONTWAIT;
439 #endif
440 
441 	return ::recv(Get(), (char *)buffer, length, flags);
442 }
443 
444 ssize_t
Write(const void * buffer,size_t length)445 SocketDescriptor::Write(const void *buffer, size_t length) noexcept
446 {
447 	int flags = 0;
448 #ifdef __linux__
449 	flags |= MSG_NOSIGNAL;
450 #endif
451 
452 	return ::send(Get(), (const char *)buffer, length, flags);
453 }
454 
455 #ifdef _WIN32
456 
457 int
WaitReadable(int timeout_ms) const458 SocketDescriptor::WaitReadable(int timeout_ms) const noexcept
459 {
460 	assert(IsDefined());
461 
462 	fd_set rfds;
463 	FD_ZERO(&rfds);
464 	FD_SET(Get(), &rfds);
465 
466 	struct timeval timeout, *timeout_p = nullptr;
467 	if (timeout_ms >= 0) {
468 		timeout.tv_sec = unsigned(timeout_ms) / 1000;
469 		timeout.tv_usec = (unsigned(timeout_ms) % 1000) * 1000;
470 		timeout_p = &timeout;
471 	}
472 
473 	return select(Get() + 1, &rfds, nullptr, nullptr, timeout_p);
474 }
475 
476 int
WaitWritable(int timeout_ms) const477 SocketDescriptor::WaitWritable(int timeout_ms) const noexcept
478 {
479 	assert(IsDefined());
480 
481 	fd_set wfds;
482 	FD_ZERO(&wfds);
483 	FD_SET(Get(), &wfds);
484 
485 	struct timeval timeout, *timeout_p = nullptr;
486 	if (timeout_ms >= 0) {
487 		timeout.tv_sec = unsigned(timeout_ms) / 1000;
488 		timeout.tv_usec = (unsigned(timeout_ms) % 1000) * 1000;
489 		timeout_p = &timeout;
490 	}
491 
492 	return select(Get() + 1, nullptr, &wfds, nullptr, timeout_p);
493 }
494 
495 #endif
496 
497 ssize_t
Read(void * buffer,size_t length,StaticSocketAddress & address)498 SocketDescriptor::Read(void *buffer, size_t length,
499 		       StaticSocketAddress &address) noexcept
500 {
501 	int flags = 0;
502 #ifndef _WIN32
503 	flags |= MSG_DONTWAIT;
504 #endif
505 
506 	socklen_t addrlen = address.GetCapacity();
507 	ssize_t nbytes = ::recvfrom(Get(), (char *)buffer, length, flags,
508 				    address, &addrlen);
509 	if (nbytes > 0)
510 		address.SetSize(addrlen);
511 
512 	return nbytes;
513 }
514 
515 ssize_t
Write(const void * buffer,size_t length,SocketAddress address)516 SocketDescriptor::Write(const void *buffer, size_t length,
517 			SocketAddress address) noexcept
518 {
519 	int flags = 0;
520 #ifndef _WIN32
521 	flags |= MSG_DONTWAIT;
522 #endif
523 #ifdef __linux__
524 	flags |= MSG_NOSIGNAL;
525 #endif
526 
527 	return ::sendto(Get(), (const char *)buffer, length, flags,
528 			address.GetAddress(), address.GetSize());
529 }
530 
531 #ifndef _WIN32
532 
533 void
Shutdown()534 SocketDescriptor::Shutdown() noexcept
535 {
536     shutdown(Get(), SHUT_RDWR);
537 }
538 
539 void
ShutdownRead()540 SocketDescriptor::ShutdownRead() noexcept
541 {
542     shutdown(Get(), SHUT_RD);
543 }
544 
545 void
ShutdownWrite()546 SocketDescriptor::ShutdownWrite() noexcept
547 {
548     shutdown(Get(), SHUT_WR);
549 }
550 
551 #endif
552