1 //
2 // Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2021
3 //
4 // Distributed under the Boost Software License, Version 1.0. (See accompanying
5 // file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
6 //
7 #include "td/utils/port/UdpSocketFd.h"
8
9 #include "td/utils/common.h"
10 #include "td/utils/format.h"
11 #include "td/utils/logging.h"
12 #include "td/utils/misc.h"
13 #include "td/utils/port/detail/skip_eintr.h"
14 #include "td/utils/port/PollFlags.h"
15 #include "td/utils/port/SocketFd.h"
16 #include "td/utils/SliceBuilder.h"
17 #include "td/utils/VectorQueue.h"
18
19 #if TD_PORT_WINDOWS
20 #include "td/utils/port/detail/Iocp.h"
21 #include "td/utils/SpinLock.h"
22 #endif
23
24 #if TD_PORT_POSIX
25 #include <cerrno>
26
27 #include <arpa/inet.h>
28 #include <fcntl.h>
29 #include <netinet/in.h>
30 #include <netinet/tcp.h>
31 #include <sys/socket.h>
32 #include <sys/types.h>
33 #include <unistd.h>
34
35 #if TD_LINUX
36 #include <linux/errqueue.h>
37 #endif
38 #endif
39
40 #include <array>
41 #include <atomic>
42 #include <cstring>
43
44 namespace td {
45 namespace detail {
46 #if TD_PORT_WINDOWS
47 class UdpSocketReceiveHelper {
48 public:
to_native(const UdpMessage & message,WSAMSG & message_header)49 void to_native(const UdpMessage &message, WSAMSG &message_header) {
50 socklen_t addr_len{narrow_cast<socklen_t>(sizeof(addr_))};
51 message_header.name = reinterpret_cast<sockaddr *>(&addr_);
52 message_header.namelen = addr_len;
53 buf_.buf = const_cast<char *>(message.data.as_slice().begin());
54 buf_.len = narrow_cast<DWORD>(message.data.size());
55 message_header.lpBuffers = &buf_;
56 message_header.dwBufferCount = 1;
57 message_header.Control.buf = nullptr; // control_buf_.data();
58 message_header.Control.len = 0; // narrow_cast<decltype(message_header.Control.len)>(control_buf_.size());
59 message_header.dwFlags = 0;
60 }
61
from_native(WSAMSG & message_header,size_t message_size,UdpMessage & message)62 static void from_native(WSAMSG &message_header, size_t message_size, UdpMessage &message) {
63 message.address.init_sockaddr(reinterpret_cast<sockaddr *>(message_header.name), message_header.namelen).ignore();
64 message.error = Status::OK();
65
66 if ((message_header.dwFlags & (MSG_TRUNC | MSG_CTRUNC)) != 0) {
67 message.error = Status::Error(501, "Message too long");
68 message.data = BufferSlice();
69 return;
70 }
71
72 CHECK(message_size <= message.data.size());
73 message.data.truncate(message_size);
74 CHECK(message_size == message.data.size());
75 }
76
77 private:
78 std::array<char, 1024> control_buf_;
79 sockaddr_storage addr_;
80 WSABUF buf_;
81 };
82
83 class UdpSocketSendHelper {
84 public:
to_native(const UdpMessage & message,WSAMSG & message_header)85 void to_native(const UdpMessage &message, WSAMSG &message_header) {
86 message_header.name = const_cast<sockaddr *>(message.address.get_sockaddr());
87 message_header.namelen = narrow_cast<socklen_t>(message.address.get_sockaddr_len());
88 buf_.buf = const_cast<char *>(message.data.as_slice().begin());
89 buf_.len = narrow_cast<DWORD>(message.data.size());
90 message_header.lpBuffers = &buf_;
91 message_header.dwBufferCount = 1;
92
93 message_header.Control.buf = nullptr;
94 message_header.Control.len = 0;
95 message_header.dwFlags = 0;
96 }
97
98 private:
99 WSABUF buf_;
100 };
101
// Windows implementation of the UDP socket, built on overlapped I/O.
// Completions and posted notifications arrive through Iocp::Callback::on_iocp(), which (per the
// comment there) is called from another thread. State shared with the public methods
// (send_queue_, receive_queue_, pending_errors_) is accessed under lock_.
// The object controls its own lifetime through refcnt_: dec_refcnt() deletes it at zero.
class UdpSocketFdImpl final : private Iocp::Callback {
 public:
  // Takes ownership of fd, subscribes it to the IOCP and posts a synthetic "connected" event
  // (tagged with &receive_overlapped_) so that on_connected() starts the receive and send loops.
  explicit UdpSocketFdImpl(NativeFd fd) : info_(std::move(fd)) {
    get_poll_info().add_flags(PollFlags::Write());
    Iocp::get()->subscribe(get_native_fd(), this);
    // Marked active so that on_connected() passes its CHECK(is_receive_active_).
    is_receive_active_ = true;
    notify_iocp_connected();
  }
  PollableFdInfo &get_poll_info() {
    return info_;
  }
  const PollableFdInfo &get_poll_info() const {
    return info_;
  }

  const NativeFd &get_native_fd() const {
    return info_.native_fd();
  }

  // Requests an asynchronous close: only posts a notification, the actual work is in on_close().
  void close() {
    notify_iocp_close();
  }

  // Returns the next queued item, if any:
  //  - a pending error that is not critical is returned as a UdpMessage carrying that error;
  //  - a critical pending error (see UdpSocketFd::is_critical_read_error) is returned as the Result's error;
  //  - otherwise the next received message, or an empty optional when nothing is queued.
  Result<optional<UdpMessage>> receive() {
    auto lock = lock_.lock();
    if (!pending_errors_.empty()) {
      auto status = pending_errors_.pop();
      if (!UdpSocketFd::is_critical_read_error(status)) {
        return UdpMessage{{}, {}, std::move(status)};
      }
      return std::move(status);
    }
    if (!receive_queue_.empty()) {
      return receive_queue_.pop();
    }

    return optional<UdpMessage>{};
  }

  // Queues a message for sending. Nothing is sent here; see flush_send() and loop_send().
  void send(UdpMessage message) {
    auto lock = lock_.lock();
    send_queue_.push(std::move(message));
  }

  // Wakes the send loop if it stopped because the queue was empty. Always returns OK.
  // NOTE(review): is_send_waiting_ is read here before lock_ is taken — confirm this is intended.
  Status flush_send() {
    if (is_send_waiting_) {
      auto lock = lock_.lock();
      is_send_waiting_ = false;
      notify_iocp_send();
    }
    return Status::OK();
  }

 private:
  PollableFdInfo info_;
  SpinLock lock_;

  // Starts at 1; incremented for every started overlapped operation and for the send/connected
  // notifications, decremented once per on_iocp() call.
  std::atomic<int> refcnt_{1};
  bool is_connected_{false};  // set by on_connected()
  bool close_flag_{false};    // set by on_close(); later completions are ignored

  bool is_send_active_{false};   // an overlapped send was started and has not completed yet
  bool is_send_waiting_{false};  // loop_send() found the queue empty and waits for flush_send()
  VectorQueue<UdpMessage> send_queue_;
  WSAOVERLAPPED send_overlapped_;

  bool is_receive_active_{false};  // an overlapped receive was started and has not completed yet
  VectorQueue<UdpMessage> receive_queue_;
  VectorQueue<Status> pending_errors_;
  UdpMessage to_receive_;   // message currently being filled by the outstanding receive
  WSAMSG receive_message_;  // native header of the outstanding receive
  UdpSocketReceiveHelper receive_helper_;
  static constexpr size_t MAX_PACKET_SIZE = 2048;
  static constexpr size_t RESERVED_SIZE = MAX_PACKET_SIZE * 8;
  // Shared receive buffer: loop_receive() hands out a clone of it and on_receive() advances it
  // past the bytes consumed by each message.
  BufferSlice receive_buffer_;

  UdpMessage to_send_;  // message currently being sent; must stay alive until the send completes
  WSAOVERLAPPED receive_overlapped_;

  // Never used as a real OVERLAPPED: only its address is used as a tag for the close notification.
  char close_overlapped_;

  // Call after a Winsock function reported failure. Returns true if the operation is merely
  // pending (ERROR_IO_PENDING); otherwise records the error via on_error() and returns false.
  bool check_status(Slice message) {
    auto last_error = WSAGetLastError();
    if (last_error == ERROR_IO_PENDING) {
      return true;
    }
    on_error(OS_SOCKET_ERROR(message));
    return false;
  }

  // Starts the next overlapped receive, unless the socket is closed.
  void loop_receive() {
    CHECK(!is_receive_active_);
    if (close_flag_) {
      return;
    }
    std::memset(&receive_overlapped_, 0, sizeof(receive_overlapped_));
    // Allocate a fresh buffer once less than MAX_PACKET_SIZE bytes remain in the current one.
    if (receive_buffer_.size() < MAX_PACKET_SIZE) {
      receive_buffer_ = BufferSlice(RESERVED_SIZE);
    }
    to_receive_.data = receive_buffer_.clone();
    receive_helper_.to_native(to_receive_, receive_message_);

    // WSARecvMsg is an extension function: its pointer is obtained through WSAIoctl
    // (this lookup is repeated on every call).
    LPFN_WSARECVMSG WSARecvMsgPtr = nullptr;
    GUID guid = WSAID_WSARECVMSG;
    DWORD numBytes;
    auto error = ::WSAIoctl(get_native_fd().socket(), SIO_GET_EXTENSION_FUNCTION_POINTER, static_cast<void *>(&guid),
                            sizeof(guid), static_cast<void *>(&WSARecvMsgPtr), sizeof(WSARecvMsgPtr), &numBytes,
                            nullptr, nullptr);
    if (error) {
      on_error(OS_SOCKET_ERROR("WSAIoctl failed"));
      return;
    }

    auto status = WSARecvMsgPtr(get_native_fd().socket(), &receive_message_, nullptr, &receive_overlapped_, nullptr);
    // Either immediate success or ERROR_IO_PENDING: a completion will be delivered to on_iocp(),
    // so take a reference for it.
    if (status == 0 || check_status("WSARecvMsg failed")) {
      inc_refcnt();
      is_receive_active_ = true;
    }
  }

  // Starts the next overlapped send, or parks the send loop when the queue is empty.
  void loop_send() {
    CHECK(!is_send_active_);

    {
      auto lock = lock_.lock();
      if (send_queue_.empty()) {
        // Nothing to send: flush_send() will restart the loop.
        is_send_waiting_ = true;
        return;
      }
      to_send_ = send_queue_.pop();
    }
    std::memset(&send_overlapped_, 0, sizeof(send_overlapped_));
    WSAMSG message;
    UdpSocketSendHelper send_helper;
    send_helper.to_native(to_send_, message);
    auto status = WSASendMsg(get_native_fd().socket(), &message, 0, nullptr, &send_overlapped_, nullptr);
    // Either immediate success or ERROR_IO_PENDING: a completion will be delivered to on_iocp(),
    // so take a reference for it.
    if (status == 0 || check_status("WSASendMsg failed")) {
      inc_refcnt();
      is_send_active_ = true;
    }
  }

  // Entry point for every completion and posted notification. The overlapped pointer tells
  // which event this is: send, receive/connected, wake-up (nullptr) or close.
  void on_iocp(Result<size_t> r_size, WSAOVERLAPPED *overlapped) final {
    // called from other thread
    // Drop the reference associated with this event; if it was the last one the object is
    // already deleted and must not be touched.
    if (dec_refcnt() || close_flag_) {
      VLOG(fd) << "Ignore IOCP (UDP socket is closing)";
      return;
    }
    if (r_size.is_error()) {
      return on_error(get_socket_pending_error(get_native_fd(), overlapped, r_size.move_as_error()));
    }

    // The first event tagged with &receive_overlapped_ is the synthetic one posted by the constructor.
    if (!is_connected_ && overlapped == &receive_overlapped_) {
      return on_connected();
    }

    auto size = r_size.move_as_ok();
    if (overlapped == &send_overlapped_) {
      return on_send(size);
    }
    if (overlapped == nullptr) {
      // Wake-up posted by notify_iocp_send().
      CHECK(size == 0);
      return on_send(size);
    }

    if (overlapped == &receive_overlapped_) {
      return on_receive(size);
    }
    if (overlapped == reinterpret_cast<WSAOVERLAPPED *>(&close_overlapped_)) {
      return on_close();
    }
    UNREACHABLE();
  }

  // Queues the error for receive() and raises the Error poll flag.
  void on_error(Status status) {
    VLOG(fd) << get_native_fd() << " on error " << status;
    {
      auto lock = lock_.lock();
      pending_errors_.push(std::move(status));
    }
    get_poll_info().add_flags_from_poll(PollFlags::Error());
  }

  // Handles the synthetic event posted by the constructor: starts both I/O loops.
  void on_connected() {
    VLOG(fd) << get_native_fd() << " on connected";
    CHECK(!is_connected_);
    CHECK(is_receive_active_);
    is_connected_ = true;
    is_receive_active_ = false;
    loop_receive();
    loop_send();
  }

  // Handles a finished receive of `size` bytes: publishes the message and starts the next receive.
  void on_receive(size_t size) {
    VLOG(fd) << get_native_fd() << " on receive " << size;
    CHECK(is_receive_active_);
    is_receive_active_ = false;
    UdpSocketReceiveHelper::from_native(receive_message_, size, to_receive_);
    // Advance the shared buffer past the message, with its size rounded up to a multiple of 8.
    receive_buffer_.confirm_read((to_receive_.data.size() + 7) & ~7);
    {
      auto lock = lock_.lock();
      // LOG(ERROR) << format::escaped(to_receive_.data.as_slice());
      receive_queue_.push(std::move(to_receive_));
    }
    get_poll_info().add_flags_from_poll(PollFlags::Read());
    loop_receive();
  }

  // Handles a finished send, or a wake-up from flush_send() when size == 0.
  // NOTE(review): a real completion of a zero-length datagram also arrives with size == 0 and
  // would be treated as a wake-up here (returning with is_send_active_ still set) — confirm that
  // empty datagrams are never sent through this path.
  void on_send(size_t size) {
    VLOG(fd) << get_native_fd() << " on send " << size;
    if (size == 0) {
      if (is_send_active_) {
        // A send is already in flight; its completion will continue the loop.
        return;
      }
      // Pretend a send has just finished so that the common path below applies.
      is_send_active_ = true;
    }
    CHECK(is_send_active_);
    is_send_active_ = false;
    loop_send();
  }

  // Handles the close notification: marks the socket closed and replaces the native fd with an empty one.
  void on_close() {
    VLOG(fd) << get_native_fd() << " on close";
    close_flag_ = true;
    info_.set_native_fd({});
  }

  // Drops one reference. Returns true if the object has been deleted as a result.
  bool dec_refcnt() {
    if (--refcnt_ == 0) {
      delete this;
      return true;
    }
    return false;
  }

  void inc_refcnt() {
    CHECK(refcnt_ != 0);
    refcnt_++;
  }

  // Posts a zero-size event with a null overlapped pointer; on_iocp() routes it to on_send(0).
  void notify_iocp_send() {
    inc_refcnt();
    Iocp::get()->post(0, this, nullptr);
  }
  // Posts the close event. No reference is taken here, so the dec_refcnt() performed by
  // on_iocp() for this event is not paired with an inc_refcnt().
  void notify_iocp_close() {
    Iocp::get()->post(0, this, reinterpret_cast<WSAOVERLAPPED *>(&close_overlapped_));
  }
  // Posts the synthetic "connected" event, tagged with &receive_overlapped_.
  void notify_iocp_connected() {
    inc_refcnt();
    Iocp::get()->post(0, this, reinterpret_cast<WSAOVERLAPPED *>(&receive_overlapped_));
  }
};
354
// Windows deleter: does not delete the object directly, it only asks the implementation
// to close itself via UdpSocketFdImpl::close().
void UdpSocketFdImplDeleter::operator()(UdpSocketFdImpl *impl) {
  impl->close();
}
358
359 #elif TD_PORT_POSIX
360 //struct iovec { [> Scatter/gather array items <]
361 // void *iov_base; [> Starting address <]
362 // size_t iov_len; [> Number of bytes to transfer <]
363 //};
364
365 //struct msghdr {
366 // void *msg_name; [> optional address <]
367 // socklen_t msg_namelen; [> size of address <]
368 // struct iovec *msg_iov; [> scatter/gather array <]
369 // size_t msg_iovlen; [> # elements in msg_iov <]
370 // void *msg_control; [> ancillary data, see below <]
371 // size_t msg_controllen; [> ancillary data buffer len <]
372 // int msg_flags; [> flags on received message <]
373 //};
374
375 class UdpSocketReceiveHelper {
376 public:
377 void to_native(const UdpSocketFd::InboundMessage &message, msghdr &message_header) {
378 socklen_t addr_len{narrow_cast<socklen_t>(sizeof(addr_))};
379
380 message_header.msg_name = &addr_;
381 message_header.msg_namelen = addr_len;
382 io_vec_.iov_base = message.data.begin();
383 io_vec_.iov_len = message.data.size();
384 message_header.msg_iov = &io_vec_;
385 message_header.msg_iovlen = 1;
386 message_header.msg_control = control_buf_.data();
387 message_header.msg_controllen = narrow_cast<decltype(message_header.msg_controllen)>(control_buf_.size());
388 message_header.msg_flags = 0;
389 }
390
391 static void from_native(msghdr &message_header, size_t message_size, UdpSocketFd::InboundMessage &message) {
392 #if TD_LINUX
393 cmsghdr *cmsg;
394 sock_extended_err *ee = nullptr;
395 for (cmsg = CMSG_FIRSTHDR(&message_header); cmsg != nullptr; cmsg = CMSG_NXTHDR(&message_header, cmsg)) {
396 if (cmsg->cmsg_type == IP_PKTINFO && cmsg->cmsg_level == IPPROTO_IP) {
397 //auto *pi = reinterpret_cast<in_pktinfo *>(CMSG_DATA(cmsg));
398 } else if (cmsg->cmsg_type == IPV6_PKTINFO && cmsg->cmsg_level == IPPROTO_IPV6) {
399 //auto *pi = reinterpret_cast<in6_pktinfo *>(CMSG_DATA(cmsg));
400 } else if ((cmsg->cmsg_type == IP_RECVERR && cmsg->cmsg_level == IPPROTO_IP) ||
401 (cmsg->cmsg_type == IPV6_RECVERR && cmsg->cmsg_level == IPPROTO_IPV6)) {
402 ee = reinterpret_cast<sock_extended_err *>(CMSG_DATA(cmsg));
403 }
404 }
405 if (ee != nullptr) {
406 auto *addr = reinterpret_cast<sockaddr *>(SO_EE_OFFENDER(ee));
407 IPAddress address;
408 address.init_sockaddr(addr).ignore();
409 if (message.from != nullptr) {
410 *message.from = address;
411 }
412 if (message.error) {
413 *message.error = Status::PosixError(ee->ee_errno, "");
414 }
415 //message.data = MutableSlice();
416 message.data.truncate(0);
417 return;
418 }
419 #endif
420 if (message.from != nullptr) {
421 message.from->init_sockaddr(reinterpret_cast<sockaddr *>(message_header.msg_name), message_header.msg_namelen)
422 .ignore();
423 }
424 if (message.error) {
425 *message.error = Status::OK();
426 }
427 if (message_header.msg_flags & MSG_TRUNC) {
428 if (message.error) {
429 *message.error = Status::Error(501, "Message too long");
430 }
431 message.data.truncate(0);
432 return;
433 }
434 CHECK(message_size <= message.data.size());
435 message.data.truncate(message_size);
436 CHECK(message_size == message.data.size());
437 }
438
439 private:
440 std::array<char, 1024> control_buf_;
441 sockaddr_storage addr_;
442 iovec io_vec_;
443 };
444
445 class UdpSocketSendHelper {
446 public:
447 void to_native(const UdpSocketFd::OutboundMessage &message, msghdr &message_header) {
448 CHECK(message.to != nullptr && message.to->is_valid());
449 message_header.msg_name = const_cast<sockaddr *>(message.to->get_sockaddr());
450 message_header.msg_namelen = narrow_cast<socklen_t>(message.to->get_sockaddr_len());
451 io_vec_.iov_base = const_cast<char *>(message.data.begin());
452 io_vec_.iov_len = message.data.size();
453 message_header.msg_iov = &io_vec_;
454 message_header.msg_iovlen = 1;
455 //TODO
456 message_header.msg_control = nullptr;
457 message_header.msg_controllen = 0;
458 message_header.msg_flags = 0;
459 }
460
461 private:
462 iovec io_vec_;
463 };
464
465 class UdpSocketFdImpl {
466 public:
467 explicit UdpSocketFdImpl(NativeFd fd) : info_(std::move(fd)) {
468 }
469 PollableFdInfo &get_poll_info() {
470 return info_;
471 }
472 const PollableFdInfo &get_poll_info() const {
473 return info_;
474 }
475
476 const NativeFd &get_native_fd() const {
477 return info_.native_fd();
478 }
479 Status get_pending_error() {
480 if (!get_poll_info().get_flags_local().has_pending_error()) {
481 return Status::OK();
482 }
483 TRY_STATUS(detail::get_socket_pending_error(get_native_fd()));
484 get_poll_info().clear_flags(PollFlags::Error());
485 return Status::OK();
486 }
487 Status receive_message(UdpSocketFd::InboundMessage &message, bool &is_received) {
488 is_received = false;
489 int flags = 0;
490 if (get_poll_info().get_flags_local().has_pending_error()) {
491 #ifdef MSG_ERRQUEUE
492 flags = MSG_ERRQUEUE;
493 #else
494 return get_pending_error();
495 #endif
496 }
497
498 msghdr message_header;
499 detail::UdpSocketReceiveHelper helper;
500 helper.to_native(message, message_header);
501
502 auto native_fd = get_native_fd().socket();
503 auto recvmsg_res = detail::skip_eintr([&] { return recvmsg(native_fd, &message_header, flags); });
504 auto recvmsg_errno = errno;
505 if (recvmsg_res >= 0) {
506 UdpSocketReceiveHelper::from_native(message_header, recvmsg_res, message);
507 is_received = true;
508 return Status::OK();
509 }
510 return process_recvmsg_error(recvmsg_errno, is_received);
511 }
512
513 Status process_recvmsg_error(int recvmsg_errno, bool &is_received) {
514 is_received = false;
515 if (recvmsg_errno == EAGAIN
516 #if EAGAIN != EWOULDBLOCK
517 || recvmsg_errno == EWOULDBLOCK
518 #endif
519 ) {
520 if (get_poll_info().get_flags_local().has_pending_error()) {
521 get_poll_info().clear_flags(PollFlags::Error());
522 } else {
523 get_poll_info().clear_flags(PollFlags::Read());
524 }
525 return Status::OK();
526 }
527
528 auto error = Status::PosixError(recvmsg_errno, PSLICE() << "Receive from " << get_native_fd() << " has failed");
529 switch (recvmsg_errno) {
530 case EBADF:
531 case EFAULT:
532 case EINVAL:
533 case ENOTCONN:
534 case ECONNRESET:
535 case ETIMEDOUT:
536 LOG(FATAL) << error;
537 UNREACHABLE();
538 default:
539 LOG(WARNING) << "Unknown error: " << error;
540 // fallthrough
541 case ENOBUFS:
542 case ENOMEM:
543 #ifdef MSG_ERRQUEUE
544 get_poll_info().add_flags(PollFlags::Error());
545 #endif
546 return error;
547 }
548 }
549
550 Status send_message(const UdpSocketFd::OutboundMessage &message, bool &is_sent) {
551 is_sent = false;
552 msghdr message_header;
553 detail::UdpSocketSendHelper helper;
554 helper.to_native(message, message_header);
555
556 auto native_fd = get_native_fd().socket();
557 auto sendmsg_res = detail::skip_eintr([&] { return sendmsg(native_fd, &message_header, 0); });
558 auto sendmsg_errno = errno;
559 if (sendmsg_res >= 0) {
560 is_sent = true;
561 return Status::OK();
562 }
563 return process_sendmsg_error(sendmsg_errno, is_sent);
564 }
565 Status process_sendmsg_error(int sendmsg_errno, bool &is_sent) {
566 if (sendmsg_errno == EAGAIN
567 #if EAGAIN != EWOULDBLOCK
568 || sendmsg_errno == EWOULDBLOCK
569 #endif
570 ) {
571 get_poll_info().clear_flags(PollFlags::Write());
572 return Status::OK();
573 }
574
575 auto error = Status::PosixError(sendmsg_errno, PSLICE() << "Send from " << get_native_fd() << " has failed");
576 switch (sendmsg_errno) {
577 // Still may send some other packets, but there is no point to resend this particular message
578 case EACCES:
579 case EMSGSIZE:
580 case EPERM:
581 LOG(WARNING) << "Silently drop packet :( " << error;
582 //TODO: get errors from MSG_ERRQUEUE is possible
583 is_sent = true;
584 return error;
585
586 // Some general problems, which may be fixed in future
587 case ENOMEM:
588 case EDQUOT:
589 case EFBIG:
590 case ENETDOWN:
591 case ENETUNREACH:
592 case ENOSPC:
593 case EHOSTUNREACH:
594 case ENOBUFS:
595 default:
596 #ifdef MSG_ERRQUEUE
597 get_poll_info().add_flags(PollFlags::Error());
598 #endif
599 return error;
600
601 case EBADF: // impossible
602 case ENOTSOCK: // impossible
603 case EPIPE: // impossible for udp
604 case ECONNRESET: // impossible for udp
605 case EDESTADDRREQ: // we checked that address is valid
606 case ENOTCONN: // we checked that address is valid
607 case EINTR: // we already skipped all EINTR
608 case EISCONN: // impossible for udp socket
609 case EOPNOTSUPP:
610 case ENOTDIR:
611 case EFAULT:
612 case EINVAL:
613 case EAFNOSUPPORT:
614 LOG(FATAL) << error;
615 UNREACHABLE();
616 return error;
617 }
618 }
619
620 Status send_messages(Span<UdpSocketFd::OutboundMessage> messages, size_t &cnt) {
621 #if TD_HAS_MMSG
622 return send_messages_fast(messages, cnt);
623 #else
624 return send_messages_slow(messages, cnt);
625 #endif
626 }
627
628 Status receive_messages(MutableSpan<UdpSocketFd::InboundMessage> messages, size_t &cnt) {
629 #if TD_HAS_MMSG
630 return receive_messages_fast(messages, cnt);
631 #else
632 return receive_messages_slow(messages, cnt);
633 #endif
634 }
635
636 private:
637 PollableFdInfo info_;
638
639 Status send_messages_slow(Span<UdpSocketFd::OutboundMessage> messages, size_t &cnt) {
640 cnt = 0;
641 for (auto &message : messages) {
642 CHECK(!message.data.empty());
643 bool is_sent;
644 auto error = send_message(message, is_sent);
645 cnt += is_sent;
646 TRY_STATUS(std::move(error));
647 }
648 return Status::OK();
649 }
650
651 #if TD_HAS_MMSG
652 Status send_messages_fast(Span<UdpSocketFd::OutboundMessage> messages, size_t &cnt) {
653 //struct mmsghdr {
654 // msghdr msg_hdr; [> Message header <]
655 // unsigned int msg_len; [> Number of bytes transmitted <]
656 //};
657 std::array<detail::UdpSocketSendHelper, 16> helpers;
658 std::array<mmsghdr, 16> headers;
659 size_t to_send = min(messages.size(), headers.size());
660 for (size_t i = 0; i < to_send; i++) {
661 helpers[i].to_native(messages[i], headers[i].msg_hdr);
662 headers[i].msg_len = 0;
663 }
664
665 auto native_fd = get_native_fd().socket();
666 auto sendmmsg_res =
667 detail::skip_eintr([&] { return sendmmsg(native_fd, headers.data(), narrow_cast<unsigned int>(to_send), 0); });
668 auto sendmmsg_errno = errno;
669 if (sendmmsg_res >= 0) {
670 cnt = sendmmsg_res;
671 return Status::OK();
672 }
673
674 bool is_sent = false;
675 auto status = process_sendmsg_error(sendmmsg_errno, is_sent);
676 cnt = is_sent;
677 return status;
678 }
679 #endif
680 Status receive_messages_slow(MutableSpan<UdpSocketFd::InboundMessage> messages, size_t &cnt) {
681 cnt = 0;
682 while (cnt < messages.size() && get_poll_info().get_flags_local().can_read()) {
683 auto &message = messages[cnt];
684 CHECK(!message.data.empty());
685 bool is_received;
686 auto error = receive_message(message, is_received);
687 cnt += is_received;
688 TRY_STATUS(std::move(error));
689 }
690 return Status::OK();
691 }
692
693 #if TD_HAS_MMSG
694 Status receive_messages_fast(MutableSpan<UdpSocketFd::InboundMessage> messages, size_t &cnt) {
695 int flags = 0;
696 cnt = 0;
697 if (get_poll_info().get_flags_local().has_pending_error()) {
698 #ifdef MSG_ERRQUEUE
699 flags = MSG_ERRQUEUE;
700 #else
701 return get_pending_error();
702 #endif
703 }
704 //struct mmsghdr {
705 // msghdr msg_hdr; [> Message header <]
706 // unsigned int msg_len; [> Number of bytes transmitted <]
707 //};
708 std::array<detail::UdpSocketReceiveHelper, 16> helpers;
709 std::array<mmsghdr, 16> headers;
710 size_t to_receive = min(messages.size(), headers.size());
711 for (size_t i = 0; i < to_receive; i++) {
712 helpers[i].to_native(messages[i], headers[i].msg_hdr);
713 headers[i].msg_len = 0;
714 }
715
716 auto native_fd = get_native_fd().socket();
717 auto recvmmsg_res = detail::skip_eintr(
718 [&] { return recvmmsg(native_fd, headers.data(), narrow_cast<unsigned int>(to_receive), flags, nullptr); });
719 auto recvmmsg_errno = errno;
720 if (recvmmsg_res >= 0) {
721 cnt = narrow_cast<size_t>(recvmmsg_res);
722 for (size_t i = 0; i < cnt; i++) {
723 UdpSocketReceiveHelper::from_native(headers[i].msg_hdr, headers[i].msg_len, messages[i]);
724 }
725 return Status::OK();
726 }
727
728 bool is_received;
729 auto status = process_recvmsg_error(recvmmsg_errno, is_received);
730 cnt = is_received;
731 return status;
732 }
733 #endif
734 };
// POSIX deleter: destroys the implementation object directly.
void UdpSocketFdImplDeleter::operator()(UdpSocketFdImpl *impl) {
  delete impl;
}
738 #endif
739 } // namespace detail
740
// Special members are defaulted here, in the source file, where detail::UdpSocketFdImpl
// is a complete type.
UdpSocketFd::UdpSocketFd() = default;
UdpSocketFd::UdpSocketFd(UdpSocketFd &&) noexcept = default;
UdpSocketFd &UdpSocketFd::operator=(UdpSocketFd &&) noexcept = default;
UdpSocketFd::~UdpSocketFd() = default;
// Returns the poll info of the underlying implementation.
// impl_ is dereferenced without a check, so the socket must not be empty().
PollableFdInfo &UdpSocketFd::get_poll_info() {
  return impl_->get_poll_info();
}
// Const overload: returns the poll info of the underlying implementation.
// impl_ is dereferenced without a check, so the socket must not be empty().
const PollableFdInfo &UdpSocketFd::get_poll_info() const {
  return impl_->get_poll_info();
}
751
open(const IPAddress & address)752 Result<UdpSocketFd> UdpSocketFd::open(const IPAddress &address) {
753 NativeFd native_fd{socket(address.get_address_family(), SOCK_DGRAM, IPPROTO_UDP)};
754 if (!native_fd) {
755 return OS_SOCKET_ERROR("Failed to create a socket");
756 }
757 TRY_STATUS(native_fd.set_is_blocking_unsafe(false));
758
759 auto sock = native_fd.socket();
760 #if TD_PORT_POSIX
761 int flags = 1;
762 #elif TD_PORT_WINDOWS
763 BOOL flags = TRUE;
764 #endif
765 setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, reinterpret_cast<const char *>(&flags), sizeof(flags));
766 // TODO: SO_REUSEADDR, SO_KEEPALIVE, TCP_NODELAY, SO_SNDBUF, SO_RCVBUF, TCP_QUICKACK, SO_LINGER
767
768 auto bind_addr = address.get_any_addr();
769 bind_addr.set_port(address.get_port());
770 auto e_bind = bind(sock, bind_addr.get_sockaddr(), narrow_cast<int>(bind_addr.get_sockaddr_len()));
771 if (e_bind != 0) {
772 return OS_SOCKET_ERROR("Failed to bind a socket");
773 }
774 return UdpSocketFd(make_unique<detail::UdpSocketFdImpl>(std::move(native_fd)));
775 }
776
// Takes over the implementation object: ownership is released from `impl` and handed to impl_.
UdpSocketFd::UdpSocketFd(unique_ptr<detail::UdpSocketFdImpl> impl) : impl_(impl.release()) {
}
779
// Releases the implementation object; afterwards empty() returns true.
void UdpSocketFd::close() {
  impl_.reset();
}
783
// Returns true if this object holds no implementation (default-constructed, moved-from or closed).
bool UdpSocketFd::empty() const {
  return !impl_;
}
787
// Returns the native descriptor stored in the poll info. The socket must not be empty().
const NativeFd &UdpSocketFd::get_native_fd() const {
  return get_poll_info().native_fd();
}
791
792 #if TD_PORT_POSIX
maximize_buffer(int socket_fd,int optname,uint32 max)793 static Result<uint32> maximize_buffer(int socket_fd, int optname, uint32 max) {
794 if (setsockopt(socket_fd, SOL_SOCKET, optname, &max, sizeof(max)) == 0) {
795 // fast path
796 return max;
797 }
798
799 /* Start with the default size. */
800 uint32 old_size = 0;
801 socklen_t intsize = sizeof(old_size);
802 if (getsockopt(socket_fd, SOL_SOCKET, optname, &old_size, &intsize)) {
803 return OS_ERROR("getsockopt() failed");
804 }
805 #if TD_LINUX
806 old_size /= 2;
807 #endif
808
809 /* Binary-search for the real maximum. */
810 uint32 last_good = old_size;
811 uint32 min = old_size;
812 while (min <= max) {
813 uint32 avg = min + (max - min) / 2;
814 if (setsockopt(socket_fd, SOL_SOCKET, optname, &avg, sizeof(avg)) == 0) {
815 last_good = avg;
816 min = avg + 1;
817 } else {
818 max = avg - 1;
819 }
820 }
821 return last_good;
822 }
823
maximize_snd_buffer(uint32 max)824 Result<uint32> UdpSocketFd::maximize_snd_buffer(uint32 max) {
825 return maximize_buffer(get_native_fd().fd(), SO_SNDBUF, max == 0 ? DEFAULT_UDP_MAX_SND_BUFFER_SIZE : max);
826 }
827
maximize_rcv_buffer(uint32 max)828 Result<uint32> UdpSocketFd::maximize_rcv_buffer(uint32 max) {
829 return maximize_buffer(get_native_fd().fd(), SO_RCVBUF, max == 0 ? DEFAULT_UDP_MAX_RCV_BUFFER_SIZE : max);
830 }
831 #else
// Non-POSIX stub: the send buffer is left untouched, `max` is ignored and 0 is returned.
Result<uint32> UdpSocketFd::maximize_snd_buffer(uint32 max) {
  return 0;
}
// Non-POSIX stub: the receive buffer is left untouched, `max` is ignored and 0 is returned.
Result<uint32> UdpSocketFd::maximize_rcv_buffer(uint32 max) {
  return 0;
}
838 #endif
839
840 #if TD_PORT_POSIX
// Forwards to the implementation's send_message(). The socket must not be empty().
Status UdpSocketFd::send_message(const OutboundMessage &message, bool &is_sent) {
  return impl_->send_message(message, is_sent);
}
// Forwards to the implementation's receive_message(). The socket must not be empty().
Status UdpSocketFd::receive_message(InboundMessage &message, bool &is_received) {
  return impl_->receive_message(message, is_received);
}
847
// Forwards to the implementation's send_messages(); `count` is filled in by it.
// The socket must not be empty().
Status UdpSocketFd::send_messages(Span<OutboundMessage> messages, size_t &count) {
  return impl_->send_messages(messages, count);
}
// Forwards to the implementation's receive_messages(); `count` is filled in by it.
// The socket must not be empty().
Status UdpSocketFd::receive_messages(MutableSpan<InboundMessage> messages, size_t &count) {
  return impl_->receive_messages(messages, count);
}
854 #endif
855 #if TD_PORT_WINDOWS
// Forwards to the implementation's receive(). The socket must not be empty().
Result<optional<UdpMessage>> UdpSocketFd::receive() {
  return impl_->receive();
}
859
// Forwards the message to the implementation's send(). The socket must not be empty().
void UdpSocketFd::send(UdpMessage message) {
  return impl_->send(std::move(message));
}
863
// Forwards to the implementation's flush_send(). The socket must not be empty().
Status UdpSocketFd::flush_send() {
  return impl_->flush_send();
}
867 #endif
868
is_critical_read_error(const Status & status)869 bool UdpSocketFd::is_critical_read_error(const Status &status) {
870 return status.code() == ENOMEM || status.code() == ENOBUFS;
871 }
872
873 } // namespace td
874