1 //
2 // Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2021
3 //
4 // Distributed under the Boost Software License, Version 1.0. (See accompanying
5 // file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
6 //
7 #include "td/utils/port/UdpSocketFd.h"
8 
9 #include "td/utils/common.h"
10 #include "td/utils/format.h"
11 #include "td/utils/logging.h"
12 #include "td/utils/misc.h"
13 #include "td/utils/port/detail/skip_eintr.h"
14 #include "td/utils/port/PollFlags.h"
15 #include "td/utils/port/SocketFd.h"
16 #include "td/utils/SliceBuilder.h"
17 #include "td/utils/VectorQueue.h"
18 
19 #if TD_PORT_WINDOWS
20 #include "td/utils/port/detail/Iocp.h"
21 #include "td/utils/SpinLock.h"
22 #endif
23 
24 #if TD_PORT_POSIX
25 #include <cerrno>
26 
27 #include <arpa/inet.h>
28 #include <fcntl.h>
29 #include <netinet/in.h>
30 #include <netinet/tcp.h>
31 #include <sys/socket.h>
32 #include <sys/types.h>
33 #include <unistd.h>
34 
35 #if TD_LINUX
36 #include <linux/errqueue.h>
37 #endif
38 #endif
39 
40 #include <array>
41 #include <atomic>
42 #include <cstring>
43 
44 namespace td {
45 namespace detail {
46 #if TD_PORT_WINDOWS
47 class UdpSocketReceiveHelper {
48  public:
to_native(const UdpMessage & message,WSAMSG & message_header)49   void to_native(const UdpMessage &message, WSAMSG &message_header) {
50     socklen_t addr_len{narrow_cast<socklen_t>(sizeof(addr_))};
51     message_header.name = reinterpret_cast<sockaddr *>(&addr_);
52     message_header.namelen = addr_len;
53     buf_.buf = const_cast<char *>(message.data.as_slice().begin());
54     buf_.len = narrow_cast<DWORD>(message.data.size());
55     message_header.lpBuffers = &buf_;
56     message_header.dwBufferCount = 1;
57     message_header.Control.buf = nullptr;  // control_buf_.data();
58     message_header.Control.len = 0;        // narrow_cast<decltype(message_header.Control.len)>(control_buf_.size());
59     message_header.dwFlags = 0;
60   }
61 
from_native(WSAMSG & message_header,size_t message_size,UdpMessage & message)62   static void from_native(WSAMSG &message_header, size_t message_size, UdpMessage &message) {
63     message.address.init_sockaddr(reinterpret_cast<sockaddr *>(message_header.name), message_header.namelen).ignore();
64     message.error = Status::OK();
65 
66     if ((message_header.dwFlags & (MSG_TRUNC | MSG_CTRUNC)) != 0) {
67       message.error = Status::Error(501, "Message too long");
68       message.data = BufferSlice();
69       return;
70     }
71 
72     CHECK(message_size <= message.data.size());
73     message.data.truncate(message_size);
74     CHECK(message_size == message.data.size());
75   }
76 
77  private:
78   std::array<char, 1024> control_buf_;
79   sockaddr_storage addr_;
80   WSABUF buf_;
81 };
82 
83 class UdpSocketSendHelper {
84  public:
to_native(const UdpMessage & message,WSAMSG & message_header)85   void to_native(const UdpMessage &message, WSAMSG &message_header) {
86     message_header.name = const_cast<sockaddr *>(message.address.get_sockaddr());
87     message_header.namelen = narrow_cast<socklen_t>(message.address.get_sockaddr_len());
88     buf_.buf = const_cast<char *>(message.data.as_slice().begin());
89     buf_.len = narrow_cast<DWORD>(message.data.size());
90     message_header.lpBuffers = &buf_;
91     message_header.dwBufferCount = 1;
92 
93     message_header.Control.buf = nullptr;
94     message_header.Control.len = 0;
95     message_header.dwFlags = 0;
96   }
97 
98  private:
99   WSABUF buf_;
100 };
101 
102 class UdpSocketFdImpl final : private Iocp::Callback {
103  public:
UdpSocketFdImpl(NativeFd fd)104   explicit UdpSocketFdImpl(NativeFd fd) : info_(std::move(fd)) {
105     get_poll_info().add_flags(PollFlags::Write());
106     Iocp::get()->subscribe(get_native_fd(), this);
107     is_receive_active_ = true;
108     notify_iocp_connected();
109   }
get_poll_info()110   PollableFdInfo &get_poll_info() {
111     return info_;
112   }
get_poll_info() const113   const PollableFdInfo &get_poll_info() const {
114     return info_;
115   }
116 
get_native_fd() const117   const NativeFd &get_native_fd() const {
118     return info_.native_fd();
119   }
120 
close()121   void close() {
122     notify_iocp_close();
123   }
124 
receive()125   Result<optional<UdpMessage>> receive() {
126     auto lock = lock_.lock();
127     if (!pending_errors_.empty()) {
128       auto status = pending_errors_.pop();
129       if (!UdpSocketFd::is_critical_read_error(status)) {
130         return UdpMessage{{}, {}, std::move(status)};
131       }
132       return std::move(status);
133     }
134     if (!receive_queue_.empty()) {
135       return receive_queue_.pop();
136     }
137 
138     return optional<UdpMessage>{};
139   }
140 
send(UdpMessage message)141   void send(UdpMessage message) {
142     auto lock = lock_.lock();
143     send_queue_.push(std::move(message));
144   }
145 
flush_send()146   Status flush_send() {
147     if (is_send_waiting_) {
148       auto lock = lock_.lock();
149       is_send_waiting_ = false;
150       notify_iocp_send();
151     }
152     return Status::OK();
153   }
154 
155  private:
156   PollableFdInfo info_;
157   SpinLock lock_;
158 
159   std::atomic<int> refcnt_{1};
160   bool is_connected_{false};
161   bool close_flag_{false};
162 
163   bool is_send_active_{false};
164   bool is_send_waiting_{false};
165   VectorQueue<UdpMessage> send_queue_;
166   WSAOVERLAPPED send_overlapped_;
167 
168   bool is_receive_active_{false};
169   VectorQueue<UdpMessage> receive_queue_;
170   VectorQueue<Status> pending_errors_;
171   UdpMessage to_receive_;
172   WSAMSG receive_message_;
173   UdpSocketReceiveHelper receive_helper_;
174   static constexpr size_t MAX_PACKET_SIZE = 2048;
175   static constexpr size_t RESERVED_SIZE = MAX_PACKET_SIZE * 8;
176   BufferSlice receive_buffer_;
177 
178   UdpMessage to_send_;
179   WSAOVERLAPPED receive_overlapped_;
180 
181   char close_overlapped_;
182 
check_status(Slice message)183   bool check_status(Slice message) {
184     auto last_error = WSAGetLastError();
185     if (last_error == ERROR_IO_PENDING) {
186       return true;
187     }
188     on_error(OS_SOCKET_ERROR(message));
189     return false;
190   }
191 
loop_receive()192   void loop_receive() {
193     CHECK(!is_receive_active_);
194     if (close_flag_) {
195       return;
196     }
197     std::memset(&receive_overlapped_, 0, sizeof(receive_overlapped_));
198     if (receive_buffer_.size() < MAX_PACKET_SIZE) {
199       receive_buffer_ = BufferSlice(RESERVED_SIZE);
200     }
201     to_receive_.data = receive_buffer_.clone();
202     receive_helper_.to_native(to_receive_, receive_message_);
203 
204     LPFN_WSARECVMSG WSARecvMsgPtr = nullptr;
205     GUID guid = WSAID_WSARECVMSG;
206     DWORD numBytes;
207     auto error = ::WSAIoctl(get_native_fd().socket(), SIO_GET_EXTENSION_FUNCTION_POINTER, static_cast<void *>(&guid),
208                             sizeof(guid), static_cast<void *>(&WSARecvMsgPtr), sizeof(WSARecvMsgPtr), &numBytes,
209                             nullptr, nullptr);
210     if (error) {
211       on_error(OS_SOCKET_ERROR("WSAIoctl failed"));
212       return;
213     }
214 
215     auto status = WSARecvMsgPtr(get_native_fd().socket(), &receive_message_, nullptr, &receive_overlapped_, nullptr);
216     if (status == 0 || check_status("WSARecvMsg failed")) {
217       inc_refcnt();
218       is_receive_active_ = true;
219     }
220   }
221 
loop_send()222   void loop_send() {
223     CHECK(!is_send_active_);
224 
225     {
226       auto lock = lock_.lock();
227       if (send_queue_.empty()) {
228         is_send_waiting_ = true;
229         return;
230       }
231       to_send_ = send_queue_.pop();
232     }
233     std::memset(&send_overlapped_, 0, sizeof(send_overlapped_));
234     WSAMSG message;
235     UdpSocketSendHelper send_helper;
236     send_helper.to_native(to_send_, message);
237     auto status = WSASendMsg(get_native_fd().socket(), &message, 0, nullptr, &send_overlapped_, nullptr);
238     if (status == 0 || check_status("WSASendMsg failed")) {
239       inc_refcnt();
240       is_send_active_ = true;
241     }
242   }
243 
on_iocp(Result<size_t> r_size,WSAOVERLAPPED * overlapped)244   void on_iocp(Result<size_t> r_size, WSAOVERLAPPED *overlapped) final {
245     // called from other thread
246     if (dec_refcnt() || close_flag_) {
247       VLOG(fd) << "Ignore IOCP (UDP socket is closing)";
248       return;
249     }
250     if (r_size.is_error()) {
251       return on_error(get_socket_pending_error(get_native_fd(), overlapped, r_size.move_as_error()));
252     }
253 
254     if (!is_connected_ && overlapped == &receive_overlapped_) {
255       return on_connected();
256     }
257 
258     auto size = r_size.move_as_ok();
259     if (overlapped == &send_overlapped_) {
260       return on_send(size);
261     }
262     if (overlapped == nullptr) {
263       CHECK(size == 0);
264       return on_send(size);
265     }
266 
267     if (overlapped == &receive_overlapped_) {
268       return on_receive(size);
269     }
270     if (overlapped == reinterpret_cast<WSAOVERLAPPED *>(&close_overlapped_)) {
271       return on_close();
272     }
273     UNREACHABLE();
274   }
275 
on_error(Status status)276   void on_error(Status status) {
277     VLOG(fd) << get_native_fd() << " on error " << status;
278     {
279       auto lock = lock_.lock();
280       pending_errors_.push(std::move(status));
281     }
282     get_poll_info().add_flags_from_poll(PollFlags::Error());
283   }
284 
on_connected()285   void on_connected() {
286     VLOG(fd) << get_native_fd() << " on connected";
287     CHECK(!is_connected_);
288     CHECK(is_receive_active_);
289     is_connected_ = true;
290     is_receive_active_ = false;
291     loop_receive();
292     loop_send();
293   }
294 
on_receive(size_t size)295   void on_receive(size_t size) {
296     VLOG(fd) << get_native_fd() << " on receive " << size;
297     CHECK(is_receive_active_);
298     is_receive_active_ = false;
299     UdpSocketReceiveHelper::from_native(receive_message_, size, to_receive_);
300     receive_buffer_.confirm_read((to_receive_.data.size() + 7) & ~7);
301     {
302       auto lock = lock_.lock();
303       // LOG(ERROR) << format::escaped(to_receive_.data.as_slice());
304       receive_queue_.push(std::move(to_receive_));
305     }
306     get_poll_info().add_flags_from_poll(PollFlags::Read());
307     loop_receive();
308   }
309 
on_send(size_t size)310   void on_send(size_t size) {
311     VLOG(fd) << get_native_fd() << " on send " << size;
312     if (size == 0) {
313       if (is_send_active_) {
314         return;
315       }
316       is_send_active_ = true;
317     }
318     CHECK(is_send_active_);
319     is_send_active_ = false;
320     loop_send();
321   }
322 
on_close()323   void on_close() {
324     VLOG(fd) << get_native_fd() << " on close";
325     close_flag_ = true;
326     info_.set_native_fd({});
327   }
328 
dec_refcnt()329   bool dec_refcnt() {
330     if (--refcnt_ == 0) {
331       delete this;
332       return true;
333     }
334     return false;
335   }
336 
inc_refcnt()337   void inc_refcnt() {
338     CHECK(refcnt_ != 0);
339     refcnt_++;
340   }
341 
notify_iocp_send()342   void notify_iocp_send() {
343     inc_refcnt();
344     Iocp::get()->post(0, this, nullptr);
345   }
notify_iocp_close()346   void notify_iocp_close() {
347     Iocp::get()->post(0, this, reinterpret_cast<WSAOVERLAPPED *>(&close_overlapped_));
348   }
notify_iocp_connected()349   void notify_iocp_connected() {
350     inc_refcnt();
351     Iocp::get()->post(0, this, reinterpret_cast<WSAOVERLAPPED *>(&receive_overlapped_));
352   }
353 };
354 
// Windows deleter: the object is not deleted directly, it is asked to close itself.
// UdpSocketFdImpl deletes itself when its reference counter drops to zero.
void UdpSocketFdImplDeleter::operator()(UdpSocketFdImpl *impl) {
  impl->close();
}
358 
359 #elif TD_PORT_POSIX
360 //struct iovec {                  [> Scatter/gather array items <]
361 //  void  *iov_base;              [> Starting address <]
362 //  size_t iov_len;               [> Number of bytes to transfer <]
363 //};
364 
365 //struct msghdr {
366 //  void         *msg_name;       [> optional address <]
367 //  socklen_t     msg_namelen;    [> size of address <]
368 //  struct iovec *msg_iov;        [> scatter/gather array <]
369 //  size_t        msg_iovlen;     [> # elements in msg_iov <]
370 //  void         *msg_control;    [> ancillary data, see below <]
371 //  size_t        msg_controllen; [> ancillary data buffer len <]
372 //  int           msg_flags;      [> flags on received message <]
373 //};
374 
375 class UdpSocketReceiveHelper {
376  public:
377   void to_native(const UdpSocketFd::InboundMessage &message, msghdr &message_header) {
378     socklen_t addr_len{narrow_cast<socklen_t>(sizeof(addr_))};
379 
380     message_header.msg_name = &addr_;
381     message_header.msg_namelen = addr_len;
382     io_vec_.iov_base = message.data.begin();
383     io_vec_.iov_len = message.data.size();
384     message_header.msg_iov = &io_vec_;
385     message_header.msg_iovlen = 1;
386     message_header.msg_control = control_buf_.data();
387     message_header.msg_controllen = narrow_cast<decltype(message_header.msg_controllen)>(control_buf_.size());
388     message_header.msg_flags = 0;
389   }
390 
391   static void from_native(msghdr &message_header, size_t message_size, UdpSocketFd::InboundMessage &message) {
392 #if TD_LINUX
393     cmsghdr *cmsg;
394     sock_extended_err *ee = nullptr;
395     for (cmsg = CMSG_FIRSTHDR(&message_header); cmsg != nullptr; cmsg = CMSG_NXTHDR(&message_header, cmsg)) {
396       if (cmsg->cmsg_type == IP_PKTINFO && cmsg->cmsg_level == IPPROTO_IP) {
397         //auto *pi = reinterpret_cast<in_pktinfo *>(CMSG_DATA(cmsg));
398       } else if (cmsg->cmsg_type == IPV6_PKTINFO && cmsg->cmsg_level == IPPROTO_IPV6) {
399         //auto *pi = reinterpret_cast<in6_pktinfo *>(CMSG_DATA(cmsg));
400       } else if ((cmsg->cmsg_type == IP_RECVERR && cmsg->cmsg_level == IPPROTO_IP) ||
401                  (cmsg->cmsg_type == IPV6_RECVERR && cmsg->cmsg_level == IPPROTO_IPV6)) {
402         ee = reinterpret_cast<sock_extended_err *>(CMSG_DATA(cmsg));
403       }
404     }
405     if (ee != nullptr) {
406       auto *addr = reinterpret_cast<sockaddr *>(SO_EE_OFFENDER(ee));
407       IPAddress address;
408       address.init_sockaddr(addr).ignore();
409       if (message.from != nullptr) {
410         *message.from = address;
411       }
412       if (message.error) {
413         *message.error = Status::PosixError(ee->ee_errno, "");
414       }
415       //message.data = MutableSlice();
416       message.data.truncate(0);
417       return;
418     }
419 #endif
420     if (message.from != nullptr) {
421       message.from->init_sockaddr(reinterpret_cast<sockaddr *>(message_header.msg_name), message_header.msg_namelen)
422           .ignore();
423     }
424     if (message.error) {
425       *message.error = Status::OK();
426     }
427     if (message_header.msg_flags & MSG_TRUNC) {
428       if (message.error) {
429         *message.error = Status::Error(501, "Message too long");
430       }
431       message.data.truncate(0);
432       return;
433     }
434     CHECK(message_size <= message.data.size());
435     message.data.truncate(message_size);
436     CHECK(message_size == message.data.size());
437   }
438 
439  private:
440   std::array<char, 1024> control_buf_;
441   sockaddr_storage addr_;
442   iovec io_vec_;
443 };
444 
445 class UdpSocketSendHelper {
446  public:
447   void to_native(const UdpSocketFd::OutboundMessage &message, msghdr &message_header) {
448     CHECK(message.to != nullptr && message.to->is_valid());
449     message_header.msg_name = const_cast<sockaddr *>(message.to->get_sockaddr());
450     message_header.msg_namelen = narrow_cast<socklen_t>(message.to->get_sockaddr_len());
451     io_vec_.iov_base = const_cast<char *>(message.data.begin());
452     io_vec_.iov_len = message.data.size();
453     message_header.msg_iov = &io_vec_;
454     message_header.msg_iovlen = 1;
455     //TODO
456     message_header.msg_control = nullptr;
457     message_header.msg_controllen = 0;
458     message_header.msg_flags = 0;
459   }
460 
461  private:
462   iovec io_vec_;
463 };
464 
465 class UdpSocketFdImpl {
466  public:
467   explicit UdpSocketFdImpl(NativeFd fd) : info_(std::move(fd)) {
468   }
469   PollableFdInfo &get_poll_info() {
470     return info_;
471   }
472   const PollableFdInfo &get_poll_info() const {
473     return info_;
474   }
475 
476   const NativeFd &get_native_fd() const {
477     return info_.native_fd();
478   }
479   Status get_pending_error() {
480     if (!get_poll_info().get_flags_local().has_pending_error()) {
481       return Status::OK();
482     }
483     TRY_STATUS(detail::get_socket_pending_error(get_native_fd()));
484     get_poll_info().clear_flags(PollFlags::Error());
485     return Status::OK();
486   }
487   Status receive_message(UdpSocketFd::InboundMessage &message, bool &is_received) {
488     is_received = false;
489     int flags = 0;
490     if (get_poll_info().get_flags_local().has_pending_error()) {
491 #ifdef MSG_ERRQUEUE
492       flags = MSG_ERRQUEUE;
493 #else
494       return get_pending_error();
495 #endif
496     }
497 
498     msghdr message_header;
499     detail::UdpSocketReceiveHelper helper;
500     helper.to_native(message, message_header);
501 
502     auto native_fd = get_native_fd().socket();
503     auto recvmsg_res = detail::skip_eintr([&] { return recvmsg(native_fd, &message_header, flags); });
504     auto recvmsg_errno = errno;
505     if (recvmsg_res >= 0) {
506       UdpSocketReceiveHelper::from_native(message_header, recvmsg_res, message);
507       is_received = true;
508       return Status::OK();
509     }
510     return process_recvmsg_error(recvmsg_errno, is_received);
511   }
512 
513   Status process_recvmsg_error(int recvmsg_errno, bool &is_received) {
514     is_received = false;
515     if (recvmsg_errno == EAGAIN
516 #if EAGAIN != EWOULDBLOCK
517         || recvmsg_errno == EWOULDBLOCK
518 #endif
519     ) {
520       if (get_poll_info().get_flags_local().has_pending_error()) {
521         get_poll_info().clear_flags(PollFlags::Error());
522       } else {
523         get_poll_info().clear_flags(PollFlags::Read());
524       }
525       return Status::OK();
526     }
527 
528     auto error = Status::PosixError(recvmsg_errno, PSLICE() << "Receive from " << get_native_fd() << " has failed");
529     switch (recvmsg_errno) {
530       case EBADF:
531       case EFAULT:
532       case EINVAL:
533       case ENOTCONN:
534       case ECONNRESET:
535       case ETIMEDOUT:
536         LOG(FATAL) << error;
537         UNREACHABLE();
538       default:
539         LOG(WARNING) << "Unknown error: " << error;
540       // fallthrough
541       case ENOBUFS:
542       case ENOMEM:
543 #ifdef MSG_ERRQUEUE
544         get_poll_info().add_flags(PollFlags::Error());
545 #endif
546         return error;
547     }
548   }
549 
550   Status send_message(const UdpSocketFd::OutboundMessage &message, bool &is_sent) {
551     is_sent = false;
552     msghdr message_header;
553     detail::UdpSocketSendHelper helper;
554     helper.to_native(message, message_header);
555 
556     auto native_fd = get_native_fd().socket();
557     auto sendmsg_res = detail::skip_eintr([&] { return sendmsg(native_fd, &message_header, 0); });
558     auto sendmsg_errno = errno;
559     if (sendmsg_res >= 0) {
560       is_sent = true;
561       return Status::OK();
562     }
563     return process_sendmsg_error(sendmsg_errno, is_sent);
564   }
565   Status process_sendmsg_error(int sendmsg_errno, bool &is_sent) {
566     if (sendmsg_errno == EAGAIN
567 #if EAGAIN != EWOULDBLOCK
568         || sendmsg_errno == EWOULDBLOCK
569 #endif
570     ) {
571       get_poll_info().clear_flags(PollFlags::Write());
572       return Status::OK();
573     }
574 
575     auto error = Status::PosixError(sendmsg_errno, PSLICE() << "Send from " << get_native_fd() << " has failed");
576     switch (sendmsg_errno) {
577       // Still may send some other packets, but there is no point to resend this particular message
578       case EACCES:
579       case EMSGSIZE:
580       case EPERM:
581         LOG(WARNING) << "Silently drop packet :( " << error;
582         //TODO: get errors from MSG_ERRQUEUE is possible
583         is_sent = true;
584         return error;
585 
586       // Some general problems, which may be fixed in future
587       case ENOMEM:
588       case EDQUOT:
589       case EFBIG:
590       case ENETDOWN:
591       case ENETUNREACH:
592       case ENOSPC:
593       case EHOSTUNREACH:
594       case ENOBUFS:
595       default:
596 #ifdef MSG_ERRQUEUE
597         get_poll_info().add_flags(PollFlags::Error());
598 #endif
599         return error;
600 
601       case EBADF:         // impossible
602       case ENOTSOCK:      // impossible
603       case EPIPE:         // impossible for udp
604       case ECONNRESET:    // impossible for udp
605       case EDESTADDRREQ:  // we checked that address is valid
606       case ENOTCONN:      // we checked that address is valid
607       case EINTR:         // we already skipped all EINTR
608       case EISCONN:       // impossible for udp socket
609       case EOPNOTSUPP:
610       case ENOTDIR:
611       case EFAULT:
612       case EINVAL:
613       case EAFNOSUPPORT:
614         LOG(FATAL) << error;
615         UNREACHABLE();
616         return error;
617     }
618   }
619 
620   Status send_messages(Span<UdpSocketFd::OutboundMessage> messages, size_t &cnt) {
621 #if TD_HAS_MMSG
622     return send_messages_fast(messages, cnt);
623 #else
624     return send_messages_slow(messages, cnt);
625 #endif
626   }
627 
628   Status receive_messages(MutableSpan<UdpSocketFd::InboundMessage> messages, size_t &cnt) {
629 #if TD_HAS_MMSG
630     return receive_messages_fast(messages, cnt);
631 #else
632     return receive_messages_slow(messages, cnt);
633 #endif
634   }
635 
636  private:
637   PollableFdInfo info_;
638 
639   Status send_messages_slow(Span<UdpSocketFd::OutboundMessage> messages, size_t &cnt) {
640     cnt = 0;
641     for (auto &message : messages) {
642       CHECK(!message.data.empty());
643       bool is_sent;
644       auto error = send_message(message, is_sent);
645       cnt += is_sent;
646       TRY_STATUS(std::move(error));
647     }
648     return Status::OK();
649   }
650 
651 #if TD_HAS_MMSG
652   Status send_messages_fast(Span<UdpSocketFd::OutboundMessage> messages, size_t &cnt) {
653     //struct mmsghdr {
654     //  msghdr msg_hdr;        [> Message header <]
655     //  unsigned int msg_len;  [> Number of bytes transmitted <]
656     //};
657     std::array<detail::UdpSocketSendHelper, 16> helpers;
658     std::array<mmsghdr, 16> headers;
659     size_t to_send = min(messages.size(), headers.size());
660     for (size_t i = 0; i < to_send; i++) {
661       helpers[i].to_native(messages[i], headers[i].msg_hdr);
662       headers[i].msg_len = 0;
663     }
664 
665     auto native_fd = get_native_fd().socket();
666     auto sendmmsg_res =
667         detail::skip_eintr([&] { return sendmmsg(native_fd, headers.data(), narrow_cast<unsigned int>(to_send), 0); });
668     auto sendmmsg_errno = errno;
669     if (sendmmsg_res >= 0) {
670       cnt = sendmmsg_res;
671       return Status::OK();
672     }
673 
674     bool is_sent = false;
675     auto status = process_sendmsg_error(sendmmsg_errno, is_sent);
676     cnt = is_sent;
677     return status;
678   }
679 #endif
680   Status receive_messages_slow(MutableSpan<UdpSocketFd::InboundMessage> messages, size_t &cnt) {
681     cnt = 0;
682     while (cnt < messages.size() && get_poll_info().get_flags_local().can_read()) {
683       auto &message = messages[cnt];
684       CHECK(!message.data.empty());
685       bool is_received;
686       auto error = receive_message(message, is_received);
687       cnt += is_received;
688       TRY_STATUS(std::move(error));
689     }
690     return Status::OK();
691   }
692 
693 #if TD_HAS_MMSG
694   Status receive_messages_fast(MutableSpan<UdpSocketFd::InboundMessage> messages, size_t &cnt) {
695     int flags = 0;
696     cnt = 0;
697     if (get_poll_info().get_flags_local().has_pending_error()) {
698 #ifdef MSG_ERRQUEUE
699       flags = MSG_ERRQUEUE;
700 #else
701       return get_pending_error();
702 #endif
703     }
704     //struct mmsghdr {
705     //  msghdr msg_hdr;        [> Message header <]
706     //  unsigned int msg_len;  [> Number of bytes transmitted <]
707     //};
708     std::array<detail::UdpSocketReceiveHelper, 16> helpers;
709     std::array<mmsghdr, 16> headers;
710     size_t to_receive = min(messages.size(), headers.size());
711     for (size_t i = 0; i < to_receive; i++) {
712       helpers[i].to_native(messages[i], headers[i].msg_hdr);
713       headers[i].msg_len = 0;
714     }
715 
716     auto native_fd = get_native_fd().socket();
717     auto recvmmsg_res = detail::skip_eintr(
718         [&] { return recvmmsg(native_fd, headers.data(), narrow_cast<unsigned int>(to_receive), flags, nullptr); });
719     auto recvmmsg_errno = errno;
720     if (recvmmsg_res >= 0) {
721       cnt = narrow_cast<size_t>(recvmmsg_res);
722       for (size_t i = 0; i < cnt; i++) {
723         UdpSocketReceiveHelper::from_native(headers[i].msg_hdr, headers[i].msg_len, messages[i]);
724       }
725       return Status::OK();
726     }
727 
728     bool is_received;
729     auto status = process_recvmsg_error(recvmmsg_errno, is_received);
730     cnt = is_received;
731     return status;
732   }
733 #endif
734 };
// POSIX deleter: the implementation object is destroyed immediately.
void UdpSocketFdImplDeleter::operator()(UdpSocketFdImpl *impl) {
  delete impl;
}
738 #endif
739 }  // namespace detail
740 
// The special member functions are defaulted here, in the source file, where detail::UdpSocketFdImpl is defined.
UdpSocketFd::UdpSocketFd() = default;
UdpSocketFd::UdpSocketFd(UdpSocketFd &&) noexcept = default;
UdpSocketFd &UdpSocketFd::operator=(UdpSocketFd &&) noexcept = default;
UdpSocketFd::~UdpSocketFd() = default;
// Returns the poll info of the implementation object. impl_ is dereferenced without a check.
PollableFdInfo &UdpSocketFd::get_poll_info() {
  return impl_->get_poll_info();
}
// Const overload: returns the poll info of the implementation object. impl_ is dereferenced without a check.
const PollableFdInfo &UdpSocketFd::get_poll_info() const {
  return impl_->get_poll_info();
}
751 
open(const IPAddress & address)752 Result<UdpSocketFd> UdpSocketFd::open(const IPAddress &address) {
753   NativeFd native_fd{socket(address.get_address_family(), SOCK_DGRAM, IPPROTO_UDP)};
754   if (!native_fd) {
755     return OS_SOCKET_ERROR("Failed to create a socket");
756   }
757   TRY_STATUS(native_fd.set_is_blocking_unsafe(false));
758 
759   auto sock = native_fd.socket();
760 #if TD_PORT_POSIX
761   int flags = 1;
762 #elif TD_PORT_WINDOWS
763   BOOL flags = TRUE;
764 #endif
765   setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, reinterpret_cast<const char *>(&flags), sizeof(flags));
766   // TODO: SO_REUSEADDR, SO_KEEPALIVE, TCP_NODELAY, SO_SNDBUF, SO_RCVBUF, TCP_QUICKACK, SO_LINGER
767 
768   auto bind_addr = address.get_any_addr();
769   bind_addr.set_port(address.get_port());
770   auto e_bind = bind(sock, bind_addr.get_sockaddr(), narrow_cast<int>(bind_addr.get_sockaddr_len()));
771   if (e_bind != 0) {
772     return OS_SOCKET_ERROR("Failed to bind a socket");
773   }
774   return UdpSocketFd(make_unique<detail::UdpSocketFdImpl>(std::move(native_fd)));
775 }
776 
// Takes over the implementation object: the pointer is released from the given unique_ptr and stored in impl_.
UdpSocketFd::UdpSocketFd(unique_ptr<detail::UdpSocketFdImpl> impl) : impl_(impl.release()) {
}
779 
// Drops the implementation object by resetting impl_; empty() returns true afterwards.
void UdpSocketFd::close() {
  impl_.reset();
}
783 
// Returns true if there is no implementation object.
bool UdpSocketFd::empty() const {
  return !impl_;
}
787 
// Returns the native fd stored in the poll info; like get_poll_info(), it dereferences impl_ without a check.
const NativeFd &UdpSocketFd::get_native_fd() const {
  return get_poll_info().native_fd();
}
791 
792 #if TD_PORT_POSIX
maximize_buffer(int socket_fd,int optname,uint32 max)793 static Result<uint32> maximize_buffer(int socket_fd, int optname, uint32 max) {
794   if (setsockopt(socket_fd, SOL_SOCKET, optname, &max, sizeof(max)) == 0) {
795     // fast path
796     return max;
797   }
798 
799   /* Start with the default size. */
800   uint32 old_size = 0;
801   socklen_t intsize = sizeof(old_size);
802   if (getsockopt(socket_fd, SOL_SOCKET, optname, &old_size, &intsize)) {
803     return OS_ERROR("getsockopt() failed");
804   }
805 #if TD_LINUX
806   old_size /= 2;
807 #endif
808 
809   /* Binary-search for the real maximum. */
810   uint32 last_good = old_size;
811   uint32 min = old_size;
812   while (min <= max) {
813     uint32 avg = min + (max - min) / 2;
814     if (setsockopt(socket_fd, SOL_SOCKET, optname, &avg, sizeof(avg)) == 0) {
815       last_good = avg;
816       min = avg + 1;
817     } else {
818       max = avg - 1;
819     }
820   }
821   return last_good;
822 }
823 
maximize_snd_buffer(uint32 max)824 Result<uint32> UdpSocketFd::maximize_snd_buffer(uint32 max) {
825   return maximize_buffer(get_native_fd().fd(), SO_SNDBUF, max == 0 ? DEFAULT_UDP_MAX_SND_BUFFER_SIZE : max);
826 }
827 
maximize_rcv_buffer(uint32 max)828 Result<uint32> UdpSocketFd::maximize_rcv_buffer(uint32 max) {
829   return maximize_buffer(get_native_fd().fd(), SO_RCVBUF, max == 0 ? DEFAULT_UDP_MAX_RCV_BUFFER_SIZE : max);
830 }
831 #else
// Non-POSIX stub: the socket is not touched, max is ignored and 0 is always returned.
Result<uint32> UdpSocketFd::maximize_snd_buffer(uint32 max) {
  return 0;
}
// Non-POSIX stub: the socket is not touched, max is ignored and 0 is always returned.
Result<uint32> UdpSocketFd::maximize_rcv_buffer(uint32 max) {
  return 0;
}
838 #endif
839 
840 #if TD_PORT_POSIX
// Forwards to detail::UdpSocketFdImpl::send_message. impl_ is dereferenced without a check.
Status UdpSocketFd::send_message(const OutboundMessage &message, bool &is_sent) {
  return impl_->send_message(message, is_sent);
}
// Forwards to detail::UdpSocketFdImpl::receive_message. impl_ is dereferenced without a check.
Status UdpSocketFd::receive_message(InboundMessage &message, bool &is_received) {
  return impl_->receive_message(message, is_received);
}
847 
// Forwards to detail::UdpSocketFdImpl::send_messages. impl_ is dereferenced without a check.
Status UdpSocketFd::send_messages(Span<OutboundMessage> messages, size_t &count) {
  return impl_->send_messages(messages, count);
}
// Forwards to detail::UdpSocketFdImpl::receive_messages. impl_ is dereferenced without a check.
Status UdpSocketFd::receive_messages(MutableSpan<InboundMessage> messages, size_t &count) {
  return impl_->receive_messages(messages, count);
}
854 #endif
855 #if TD_PORT_WINDOWS
// Forwards to detail::UdpSocketFdImpl::receive. impl_ is dereferenced without a check.
Result<optional<UdpMessage>> UdpSocketFd::receive() {
  return impl_->receive();
}
859 
// Forwards the message to detail::UdpSocketFdImpl::send. impl_ is dereferenced without a check.
void UdpSocketFd::send(UdpMessage message) {
  return impl_->send(std::move(message));
}
863 
// Forwards to detail::UdpSocketFdImpl::flush_send. impl_ is dereferenced without a check.
Status UdpSocketFd::flush_send() {
  return impl_->flush_send();
}
867 #endif
868 
is_critical_read_error(const Status & status)869 bool UdpSocketFd::is_critical_read_error(const Status &status) {
870   return status.code() == ENOMEM || status.code() == ENOBUFS;
871 }
872 
873 }  // namespace td
874