1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "perfetto/ext/base/unix_socket.h"
18 
19 #include <errno.h>
20 #include <fcntl.h>
21 #include <netdb.h>
22 #include <netinet/in.h>
23 #include <netinet/tcp.h>
24 #include <stdlib.h>
25 #include <string.h>
26 #include <sys/socket.h>
27 #include <sys/stat.h>
28 #include <sys/types.h>
29 #include <sys/un.h>
30 #include <unistd.h>
31 
32 #include <algorithm>
33 #include <memory>
34 
35 #include "perfetto/base/build_config.h"
36 #include "perfetto/base/logging.h"
37 #include "perfetto/base/task_runner.h"
38 #include "perfetto/ext/base/string_utils.h"
39 #include "perfetto/ext/base/utils.h"
40 
41 #if PERFETTO_BUILDFLAG(PERFETTO_OS_APPLE) || PERFETTO_BUILDFLAG(PERFETTO_OS_FREEBSD)
42 #include <sys/ucred.h>
43 #endif
44 
45 namespace perfetto {
46 namespace base {
47 
48 // The CMSG_* macros use NULL instead of nullptr.
49 #pragma GCC diagnostic push
50 #if !PERFETTO_BUILDFLAG(PERFETTO_OS_APPLE)
51 #pragma GCC diagnostic ignored "-Wzero-as-null-pointer-constant"
52 #endif
53 
54 namespace {
55 
// Flags passed to sendmsg() so that writing to a socket whose peer has gone
// away does not raise SIGPIPE. MSG_NOSIGNAL is not supported on Mac OS X;
// there the socket is instead configured with SO_NOSIGPIPE (see the
// UnixSocketRaw(ScopedFile, SockFamily, SockType) constructor).
#if PERFETTO_BUILDFLAG(PERFETTO_OS_APPLE)
constexpr int kNoSigPipe = 0;
#else
constexpr int kNoSigPipe = MSG_NOSIGNAL;
#endif

// Type used to cast the values stored into msghdr::msg_controllen and
// cmsghdr::cmsg_len. Android's bionic declares these fields as size_t rather
// than socklen_t, so the cast target differs per platform.
#if PERFETTO_BUILDFLAG(PERFETTO_OS_ANDROID)
using CBufLenType = size_t;
#else
using CBufLenType = socklen_t;
#endif
70 
// Owning, type-erased container for a socket address of any family.
// bind() and connect() need the concrete struct for the socket's family
// (sockaddr_un for AF_UNIX, sockaddr_in for AF_INET, ...). Those structs have
// different sizes and are larger than the generic struct sockaddr, so this
// keeps a heap copy of the raw bytes together with their length.
// A default-constructed instance (size == 0, null data) means "no address".
struct SockaddrAny {
  SockaddrAny() : data(nullptr), size(0) {}

  // Copies |sz| bytes starting at |addr|.
  SockaddrAny(const void* addr, socklen_t sz)
      : data(new char[sz]), size(sz) {
    memcpy(data.get(), addr, static_cast<size_t>(sz));
  }

  // The stored bytes viewed as a generic sockaddr, for bind()/connect().
  const struct sockaddr* addr() const {
    return reinterpret_cast<const struct sockaddr*>(data.get());
  }

  std::unique_ptr<char[]> data;  // Raw address bytes; null when empty.
  socklen_t size;                // Number of valid bytes in |data|.
};
89 
GetSockFamily(SockFamily family)90 inline int GetSockFamily(SockFamily family) {
91   switch (family) {
92     case SockFamily::kUnix:
93       return AF_UNIX;
94     case SockFamily::kInet:
95       return AF_INET;
96     case SockFamily::kInet6:
97       return AF_INET6;
98   }
99   PERFETTO_CHECK(false);  // For GCC.
100 }
101 
GetSockType(SockType type)102 inline int GetSockType(SockType type) {
103 #ifdef SOCK_CLOEXEC
104   constexpr int kSockCloExec = SOCK_CLOEXEC;
105 #else
106   constexpr int kSockCloExec = 0;
107 #endif
108   switch (type) {
109     case SockType::kStream:
110       return SOCK_STREAM | kSockCloExec;
111     case SockType::kDgram:
112       return SOCK_DGRAM | kSockCloExec;
113     case SockType::kSeqPacket:
114       return SOCK_SEQPACKET | kSockCloExec;
115   }
116   PERFETTO_CHECK(false);  // For GCC.
117 }
118 
// Converts |socket_name| into a socket address for bind()/connect() on a
// socket of the given |family|:
//  - kUnix:  a filesystem path, or an abstract-namespace name when it starts
//            with '@' (the '@' is replaced by a leading NUL byte).
//  - kInet:  "host:port", resolved with getaddrinfo() restricted to AF_INET.
//  - kInet6: "[host]:port", resolved with getaddrinfo() restricted to AF_INET6.
// Returns an empty SockaddrAny (size == 0) with errno = ENAMETOOLONG when a
// kUnix name does not fit in sun_path. Malformed inet names and resolution
// failures abort via PERFETTO_CHECK.
// NOTE(review): a bad host/port (e.g. from configuration) therefore crashes the
// process instead of making Bind()/Connect() return false.
SockaddrAny MakeSockAddr(SockFamily family, const std::string& socket_name) {
  switch (family) {
    case SockFamily::kUnix: {
      struct sockaddr_un saddr {};
      const size_t name_len = socket_name.size();
      // Reject names that would leave no room for the trailing NUL byte.
      if (name_len >= sizeof(saddr.sun_path)) {
        errno = ENAMETOOLONG;
        return SockaddrAny();
      }
      memcpy(saddr.sun_path, socket_name.data(), name_len);
      // A '@' prefix selects an abstract socket (Linux), which is encoded as a
      // leading NUL byte in sun_path.
      if (saddr.sun_path[0] == '@')
        saddr.sun_path[0] = '\0';
      saddr.sun_family = AF_UNIX;
      // The address length covers the name plus one trailing NUL byte (already
      // '\0' because |saddr| is zero-initialized). For abstract names that
      // trailing NUL is therefore part of the name, so peers must build the
      // address the same way.
      auto size = static_cast<socklen_t>(
          __builtin_offsetof(sockaddr_un, sun_path) + name_len + 1);
      PERFETTO_CHECK(static_cast<size_t>(size) <= sizeof(saddr));
      return SockaddrAny(&saddr, size);
    }
    case SockFamily::kInet: {
      // Expected format: "host:port".
      auto parts = SplitString(socket_name, ":");
      PERFETTO_CHECK(parts.size() == 2);
      struct addrinfo* addr_info = nullptr;
      struct addrinfo hints {};
      hints.ai_family = AF_INET;
      PERFETTO_CHECK(getaddrinfo(parts[0].c_str(), parts[1].c_str(), &hints,
                                 &addr_info) == 0);
      PERFETTO_CHECK(addr_info->ai_family == AF_INET);
      // Only the first resolved address is used.
      SockaddrAny res(addr_info->ai_addr, addr_info->ai_addrlen);
      freeaddrinfo(addr_info);
      return res;
    }
    case SockFamily::kInet6: {
      // Expected format: "[host]:port". Splitting on ']' yields "[host" and
      // ":port"; the '[' and ':' separators are then split off (assumes
      // SplitString() drops empty tokens -- TODO confirm).
      auto parts = SplitString(socket_name, "]");
      PERFETTO_CHECK(parts.size() == 2);
      auto address = SplitString(parts[0], "[");
      PERFETTO_CHECK(address.size() == 1);
      auto port = SplitString(parts[1], ":");
      PERFETTO_CHECK(port.size() == 1);
      struct addrinfo* addr_info = nullptr;
      struct addrinfo hints {};
      hints.ai_family = AF_INET6;
      PERFETTO_CHECK(getaddrinfo(address[0].c_str(), port[0].c_str(), &hints,
                                 &addr_info) == 0);
      PERFETTO_CHECK(addr_info->ai_family == AF_INET6);
      // Only the first resolved address is used.
      SockaddrAny res(addr_info->ai_addr, addr_info->ai_addrlen);
      freeaddrinfo(addr_info);
      return res;
    }
  }
  // Only reachable for out-of-range SockFamily values.
  PERFETTO_CHECK(false);  // For GCC.
}
170 
171 }  // namespace
172 
173 // +-----------------------+
174 // | UnixSocketRaw methods |
175 // +-----------------------+
176 
177 // static
ShiftMsgHdr(size_t n,struct msghdr * msg)178 void UnixSocketRaw::ShiftMsgHdr(size_t n, struct msghdr* msg) {
179   using LenType = decltype(msg->msg_iovlen);  // Mac and Linux don't agree.
180   for (LenType i = 0; i < msg->msg_iovlen; ++i) {
181     struct iovec* vec = &msg->msg_iov[i];
182     if (n < vec->iov_len) {
183       // We sent a part of this iovec.
184       vec->iov_base = reinterpret_cast<char*>(vec->iov_base) + n;
185       vec->iov_len -= n;
186       msg->msg_iov = vec;
187       msg->msg_iovlen -= i;
188       return;
189     }
190     // We sent the whole iovec.
191     n -= vec->iov_len;
192   }
193   // We sent all the iovecs.
194   PERFETTO_CHECK(n == 0);
195   msg->msg_iovlen = 0;
196   msg->msg_iov = nullptr;
197 }
198 
199 // static
CreateMayFail(SockFamily family,SockType type)200 UnixSocketRaw UnixSocketRaw::CreateMayFail(SockFamily family, SockType type) {
201   auto fd = ScopedFile(socket(GetSockFamily(family), GetSockType(type), 0));
202   if (!fd) {
203     return UnixSocketRaw();
204   }
205   return UnixSocketRaw(std::move(fd), family, type);
206 }
207 
208 // static
CreatePair(SockFamily family,SockType type)209 std::pair<UnixSocketRaw, UnixSocketRaw> UnixSocketRaw::CreatePair(
210     SockFamily family,
211     SockType type) {
212   int fds[2];
213   if (socketpair(GetSockFamily(family), GetSockType(type), 0, fds) != 0)
214     return std::make_pair(UnixSocketRaw(), UnixSocketRaw());
215 
216   return std::make_pair(UnixSocketRaw(ScopedFile(fds[0]), family, type),
217                         UnixSocketRaw(ScopedFile(fds[1]), family, type));
218 }
219 
220 UnixSocketRaw::UnixSocketRaw() = default;
221 
UnixSocketRaw(SockFamily family,SockType type)222 UnixSocketRaw::UnixSocketRaw(SockFamily family, SockType type)
223     : UnixSocketRaw(
224           ScopedFile(socket(GetSockFamily(family), GetSockType(type), 0)),
225           family,
226           type) {}
227 
UnixSocketRaw(ScopedFile fd,SockFamily family,SockType type)228 UnixSocketRaw::UnixSocketRaw(ScopedFile fd, SockFamily family, SockType type)
229     : fd_(std::move(fd)), family_(family), type_(type) {
230   PERFETTO_CHECK(fd_);
231 #if PERFETTO_BUILDFLAG(PERFETTO_OS_APPLE)
232   const int no_sigpipe = 1;
233   setsockopt(*fd_, SOL_SOCKET, SO_NOSIGPIPE, &no_sigpipe, sizeof(no_sigpipe));
234 #endif
235 
236   if (family == SockFamily::kInet || family == SockFamily::kInet6) {
237     int flag = 1;
238     PERFETTO_CHECK(
239         !setsockopt(*fd_, SOL_SOCKET, SO_REUSEADDR, &flag, sizeof(flag)));
240     flag = 1;
241     // Disable Nagle's algorithm, optimize for low-latency.
242     // See https://github.com/google/perfetto/issues/70.
243     setsockopt(*fd_, IPPROTO_TCP, TCP_NODELAY, &flag, sizeof(flag));
244   }
245 
246   // There is no reason why a socket should outlive the process in case of
247   // exec() by default, this is just working around a broken unix design.
248   int fcntl_res = fcntl(*fd_, F_SETFD, FD_CLOEXEC);
249   PERFETTO_CHECK(fcntl_res == 0);
250 }
251 
SetBlocking(bool is_blocking)252 void UnixSocketRaw::SetBlocking(bool is_blocking) {
253   PERFETTO_DCHECK(fd_);
254   int flags = fcntl(*fd_, F_GETFL, 0);
255   if (!is_blocking) {
256     flags |= O_NONBLOCK;
257   } else {
258     flags &= ~static_cast<int>(O_NONBLOCK);
259   }
260   bool fcntl_res = fcntl(*fd_, F_SETFL, flags);
261   PERFETTO_CHECK(fcntl_res == 0);
262 }
263 
RetainOnExec()264 void UnixSocketRaw::RetainOnExec() {
265   PERFETTO_DCHECK(fd_);
266   int flags = fcntl(*fd_, F_GETFD, 0);
267   flags &= ~static_cast<int>(FD_CLOEXEC);
268   bool fcntl_res = fcntl(*fd_, F_SETFD, flags);
269   PERFETTO_CHECK(fcntl_res == 0);
270 }
271 
IsBlocking() const272 bool UnixSocketRaw::IsBlocking() const {
273   PERFETTO_DCHECK(fd_);
274   return (fcntl(*fd_, F_GETFL, 0) & O_NONBLOCK) == 0;
275 }
276 
Bind(const std::string & socket_name)277 bool UnixSocketRaw::Bind(const std::string& socket_name) {
278   PERFETTO_DCHECK(fd_);
279   SockaddrAny addr = MakeSockAddr(family_, socket_name);
280   if (addr.size == 0)
281     return false;
282 
283   if (bind(*fd_, addr.addr(), addr.size)) {
284     PERFETTO_DPLOG("bind(%s)", socket_name.c_str());
285     return false;
286   }
287 
288   return true;
289 }
290 
Listen()291 bool UnixSocketRaw::Listen() {
292   PERFETTO_DCHECK(fd_);
293   PERFETTO_DCHECK(type_ == SockType::kStream || type_ == SockType::kSeqPacket);
294   return listen(*fd_, SOMAXCONN) == 0;
295 }
296 
Connect(const std::string & socket_name)297 bool UnixSocketRaw::Connect(const std::string& socket_name) {
298   PERFETTO_DCHECK(fd_);
299   SockaddrAny addr = MakeSockAddr(family_, socket_name);
300   if (addr.size == 0)
301     return false;
302 
303   int res = PERFETTO_EINTR(connect(*fd_, addr.addr(), addr.size));
304   if (res && errno != EINPROGRESS)
305     return false;
306 
307   return true;
308 }
309 
Shutdown()310 void UnixSocketRaw::Shutdown() {
311   shutdown(*fd_, SHUT_RDWR);
312   fd_.reset();
313 }
314 
315 // For the interested reader, Linux kernel dive to verify this is not only a
316 // theoretical possibility: sock_stream_sendmsg, if sock_alloc_send_pskb returns
317 // NULL [1] (which it does when it gets interrupted [2]), returns early with the
318 // amount of bytes already sent.
319 //
320 // [1]:
321 // https://elixir.bootlin.com/linux/v4.18.10/source/net/unix/af_unix.c#L1872
322 // [2]: https://elixir.bootlin.com/linux/v4.18.10/source/net/core/sock.c#L2101
SendMsgAll(struct msghdr * msg)323 ssize_t UnixSocketRaw::SendMsgAll(struct msghdr* msg) {
324   // This does not make sense on non-blocking sockets.
325   PERFETTO_DCHECK(fd_);
326 
327   ssize_t total_sent = 0;
328   while (msg->msg_iov) {
329     ssize_t sent = PERFETTO_EINTR(sendmsg(*fd_, msg, kNoSigPipe));
330     if (sent <= 0) {
331       if (sent == -1 && IsAgain(errno))
332         return total_sent;
333       return sent;
334     }
335     total_sent += sent;
336     ShiftMsgHdr(static_cast<size_t>(sent), msg);
337     // Only send the ancillary data with the first sendmsg call.
338     msg->msg_control = nullptr;
339     msg->msg_controllen = 0;
340   }
341   return total_sent;
342 }
343 
Send(const void * msg,size_t len,const int * send_fds,size_t num_fds)344 ssize_t UnixSocketRaw::Send(const void* msg,
345                             size_t len,
346                             const int* send_fds,
347                             size_t num_fds) {
348   PERFETTO_DCHECK(fd_);
349   msghdr msg_hdr = {};
350   iovec iov = {const_cast<void*>(msg), len};
351   msg_hdr.msg_iov = &iov;
352   msg_hdr.msg_iovlen = 1;
353   alignas(cmsghdr) char control_buf[256];
354 
355   if (num_fds > 0) {
356     const auto raw_ctl_data_sz = num_fds * sizeof(int);
357     const CBufLenType control_buf_len =
358         static_cast<CBufLenType>(CMSG_SPACE(raw_ctl_data_sz));
359     PERFETTO_CHECK(control_buf_len <= sizeof(control_buf));
360     memset(control_buf, 0, sizeof(control_buf));
361     msg_hdr.msg_control = control_buf;
362     msg_hdr.msg_controllen = control_buf_len;  // used by CMSG_FIRSTHDR
363     struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg_hdr);
364     cmsg->cmsg_level = SOL_SOCKET;
365     cmsg->cmsg_type = SCM_RIGHTS;
366     cmsg->cmsg_len = static_cast<CBufLenType>(CMSG_LEN(raw_ctl_data_sz));
367     memcpy(CMSG_DATA(cmsg), send_fds, num_fds * sizeof(int));
368     // note: if we were to send multiple cmsghdr structures, then
369     // msg_hdr.msg_controllen would need to be adjusted, see "man 3 cmsg".
370   }
371 
372   return SendMsgAll(&msg_hdr);
373 }
374 
// Receives one message of up to |len| bytes into |msg|. When |max_files| > 0,
// it also accepts file descriptors passed via SCM_RIGHTS and stores up to
// |max_files| of them into |fd_vec| (starting at index 0). Slots beyond the
// number of fds actually received are left untouched.
// Returns the byte count on success, or recvmsg()'s own result when it is <= 0
// (0, or -1 with errno). If the payload or control data was truncated
// (MSG_TRUNC / MSG_CTRUNC), every received fd is closed and -1 is returned
// with errno = EMSGSIZE.
ssize_t UnixSocketRaw::Receive(void* msg,
                               size_t len,
                               ScopedFile* fd_vec,
                               size_t max_files) {
  PERFETTO_DCHECK(fd_);
  msghdr msg_hdr = {};
  iovec iov = {msg, len};
  msg_hdr.msg_iov = &iov;
  msg_hdr.msg_iovlen = 1;
  alignas(cmsghdr) char control_buf[256];

  // Offer a control buffer only if the caller can take fds, sized for
  // |max_files| ints.
  if (max_files > 0) {
    msg_hdr.msg_control = control_buf;
    msg_hdr.msg_controllen =
        static_cast<CBufLenType>(CMSG_SPACE(max_files * sizeof(int)));
    PERFETTO_CHECK(msg_hdr.msg_controllen <= sizeof(control_buf));
  }
  const ssize_t sz = PERFETTO_EINTR(recvmsg(*fd_, &msg_hdr, 0));
  if (sz <= 0) {
    return sz;
  }
  PERFETTO_CHECK(static_cast<size_t>(sz) <= len);

  // Received fds live inside |control_buf|; locate the SCM_RIGHTS payload.
  int* fds = nullptr;
  uint32_t fds_len = 0;

  if (max_files > 0) {
    for (cmsghdr* cmsg = CMSG_FIRSTHDR(&msg_hdr); cmsg;
         cmsg = CMSG_NXTHDR(&msg_hdr, cmsg)) {
      const size_t payload_len = cmsg->cmsg_len - CMSG_LEN(0);
      if (cmsg->cmsg_level == SOL_SOCKET && cmsg->cmsg_type == SCM_RIGHTS) {
        PERFETTO_DCHECK(payload_len % sizeof(int) == 0u);
        // NOTE(review): aborts if a single recvmsg() yields more than one
        // SCM_RIGHTS control message.
        PERFETTO_CHECK(fds == nullptr);
        fds = reinterpret_cast<int*>(CMSG_DATA(cmsg));
        fds_len = static_cast<uint32_t>(payload_len / sizeof(int));
      }
    }
  }

  // On truncation, close any fds that did arrive so they don't leak, and
  // report the message as too large.
  if (msg_hdr.msg_flags & MSG_TRUNC || msg_hdr.msg_flags & MSG_CTRUNC) {
    for (size_t i = 0; fds && i < fds_len; ++i)
      close(fds[i]);
    errno = EMSGSIZE;
    return -1;
  }

  // Hand ownership of the fds to the caller. CMSG_SPACE() alignment padding
  // may leave room for more than |max_files| fds; close any extras.
  for (size_t i = 0; fds && i < fds_len; ++i) {
    if (i < max_files)
      fd_vec[i].reset(fds[i]);
    else
      close(fds[i]);
  }

  return sz;
}
430 
SetTxTimeout(uint32_t timeout_ms)431 bool UnixSocketRaw::SetTxTimeout(uint32_t timeout_ms) {
432   PERFETTO_DCHECK(fd_);
433   struct timeval timeout {};
434   uint32_t timeout_sec = timeout_ms / 1000;
435   timeout.tv_sec = static_cast<decltype(timeout.tv_sec)>(timeout_sec);
436   timeout.tv_usec = static_cast<decltype(timeout.tv_usec)>(
437       (timeout_ms - (timeout_sec * 1000)) * 1000);
438 
439   return setsockopt(*fd_, SOL_SOCKET, SO_SNDTIMEO,
440                     reinterpret_cast<const char*>(&timeout),
441                     sizeof(timeout)) == 0;
442 }
443 
SetRxTimeout(uint32_t timeout_ms)444 bool UnixSocketRaw::SetRxTimeout(uint32_t timeout_ms) {
445   PERFETTO_DCHECK(fd_);
446   struct timeval timeout {};
447   uint32_t timeout_sec = timeout_ms / 1000;
448   timeout.tv_sec = static_cast<decltype(timeout.tv_sec)>(timeout_sec);
449   timeout.tv_usec = static_cast<decltype(timeout.tv_usec)>(
450       (timeout_ms - (timeout_sec * 1000)) * 1000);
451 
452   return setsockopt(*fd_, SOL_SOCKET, SO_RCVTIMEO,
453                     reinterpret_cast<const char*>(&timeout),
454                     sizeof(timeout)) == 0;
455 }
456 
457 #pragma GCC diagnostic pop
458 
459 // +--------------------+
460 // | UnixSocket methods |
461 // +--------------------+
462 
463 // TODO(primiano): Add ThreadChecker to methods of this class.
464 
465 // static
Listen(const std::string & socket_name,EventListener * event_listener,TaskRunner * task_runner,SockFamily sock_family,SockType sock_type)466 std::unique_ptr<UnixSocket> UnixSocket::Listen(const std::string& socket_name,
467                                                EventListener* event_listener,
468                                                TaskRunner* task_runner,
469                                                SockFamily sock_family,
470                                                SockType sock_type) {
471   auto sock_raw = UnixSocketRaw::CreateMayFail(sock_family, sock_type);
472   if (!sock_raw || !sock_raw.Bind(socket_name))
473     return nullptr;
474 
475   // Forward the call to the Listen() overload below.
476   return Listen(sock_raw.ReleaseFd(), event_listener, task_runner, sock_family,
477                 sock_type);
478 }
479 
480 // static
Listen(ScopedFile fd,EventListener * event_listener,TaskRunner * task_runner,SockFamily sock_family,SockType sock_type)481 std::unique_ptr<UnixSocket> UnixSocket::Listen(ScopedFile fd,
482                                                EventListener* event_listener,
483                                                TaskRunner* task_runner,
484                                                SockFamily sock_family,
485                                                SockType sock_type) {
486   return std::unique_ptr<UnixSocket>(
487       new UnixSocket(event_listener, task_runner, std::move(fd),
488                      State::kListening, sock_family, sock_type));
489 }
490 
491 // static
Connect(const std::string & socket_name,EventListener * event_listener,TaskRunner * task_runner,SockFamily sock_family,SockType sock_type)492 std::unique_ptr<UnixSocket> UnixSocket::Connect(const std::string& socket_name,
493                                                 EventListener* event_listener,
494                                                 TaskRunner* task_runner,
495                                                 SockFamily sock_family,
496                                                 SockType sock_type) {
497   std::unique_ptr<UnixSocket> sock(
498       new UnixSocket(event_listener, task_runner, sock_family, sock_type));
499   sock->DoConnect(socket_name);
500   return sock;
501 }
502 
503 // static
AdoptConnected(ScopedFile fd,EventListener * event_listener,TaskRunner * task_runner,SockFamily sock_family,SockType sock_type)504 std::unique_ptr<UnixSocket> UnixSocket::AdoptConnected(
505     ScopedFile fd,
506     EventListener* event_listener,
507     TaskRunner* task_runner,
508     SockFamily sock_family,
509     SockType sock_type) {
510   return std::unique_ptr<UnixSocket>(
511       new UnixSocket(event_listener, task_runner, std::move(fd),
512                      State::kConnected, sock_family, sock_type));
513 }
514 
UnixSocket(EventListener * event_listener,TaskRunner * task_runner,SockFamily sock_family,SockType sock_type)515 UnixSocket::UnixSocket(EventListener* event_listener,
516                        TaskRunner* task_runner,
517                        SockFamily sock_family,
518                        SockType sock_type)
519     : UnixSocket(event_listener,
520                  task_runner,
521                  ScopedFile(),
522                  State::kDisconnected,
523                  sock_family,
524                  sock_type) {}
525 
// Main constructor. Behavior depends on |adopt_state|:
//  - kDisconnected: |adopt_fd| must be empty. A new socket is created with
//    CreateMayFail(); on failure the object stays kDisconnected with
//    last_error_ = errno (DoConnect() then reports a failed connection).
//  - kConnected: adopts the already-connected |adopt_fd| and reads the peer
//    credentials.
//  - kListening: adopts |adopt_fd| and calls listen() on it. An invalid fd or a
//    listen() failure leaves the object kDisconnected with last_error_ = errno.
// On success the socket is made non-blocking and registered with
// |task_runner|; readiness notifications are routed to OnEvent() while this
// object is alive.
UnixSocket::UnixSocket(EventListener* event_listener,
                       TaskRunner* task_runner,
                       ScopedFile adopt_fd,
                       State adopt_state,
                       SockFamily sock_family,
                       SockType sock_type)
    : event_listener_(event_listener),
      task_runner_(task_runner),
      weak_ptr_factory_(this) {
  // Pessimistic default; promoted below once setup succeeds.
  state_ = State::kDisconnected;
  if (adopt_state == State::kDisconnected) {
    PERFETTO_DCHECK(!adopt_fd);
    sock_raw_ = UnixSocketRaw::CreateMayFail(sock_family, sock_type);
    if (!sock_raw_) {
      last_error_ = errno;
      return;
    }
  } else if (adopt_state == State::kConnected) {
    PERFETTO_DCHECK(adopt_fd);
    sock_raw_ = UnixSocketRaw(std::move(adopt_fd), sock_family, sock_type);
    state_ = State::kConnected;
    ReadPeerCredentials();
  } else if (adopt_state == State::kListening) {
    // We get here from Listen().

    // |adopt_fd| might genuinely be invalid if the bind() failed.
    if (!adopt_fd) {
      last_error_ = errno;
      return;
    }

    sock_raw_ = UnixSocketRaw(std::move(adopt_fd), sock_family, sock_type);
    if (!sock_raw_.Listen()) {
      last_error_ = errno;
      PERFETTO_DPLOG("listen()");
      return;
    }
    state_ = State::kListening;
  } else {
    PERFETTO_FATAL("Unexpected adopt_state");  // Unfeasible.
  }

  PERFETTO_CHECK(sock_raw_);
  last_error_ = 0;

  // All further I/O is event-driven; Send() temporarily flips back to
  // blocking mode when needed.
  sock_raw_.SetBlocking(false);

  WeakPtr<UnixSocket> weak_ptr = weak_ptr_factory_.GetWeakPtr();

  // The watch callback holds only a weak pointer, so it becomes a no-op once
  // this object has been destroyed.
  task_runner_->AddFileDescriptorWatch(sock_raw_.fd(), [weak_ptr] {
    if (weak_ptr)
      weak_ptr->OnEvent();
  });
}
580 
~UnixSocket()581 UnixSocket::~UnixSocket() {
582   // The implicit dtor of |weak_ptr_factory_| will no-op pending callbacks.
583   Shutdown(true);
584 }
585 
ReleaseSocket()586 UnixSocketRaw UnixSocket::ReleaseSocket() {
587   // This will invalidate any pending calls to OnEvent.
588   state_ = State::kDisconnected;
589   if (sock_raw_)
590     task_runner_->RemoveFileDescriptorWatch(sock_raw_.fd());
591 
592   return std::move(sock_raw_);
593 }
594 
595 // Called only by the Connect() static constructor.
DoConnect(const std::string & socket_name)596 void UnixSocket::DoConnect(const std::string& socket_name) {
597   PERFETTO_DCHECK(state_ == State::kDisconnected);
598 
599   // This is the only thing that can gracefully fail in the ctor.
600   if (!sock_raw_)
601     return NotifyConnectionState(false);
602 
603   if (!sock_raw_.Connect(socket_name)) {
604     last_error_ = errno;
605     return NotifyConnectionState(false);
606   }
607 
608   // At this point either connect() succeeded or started asynchronously
609   // (errno = EINPROGRESS).
610   last_error_ = 0;
611   state_ = State::kConnecting;
612 
613   // Even if the socket is non-blocking, connecting to a UNIX socket can be
614   // acknowledged straight away rather than returning EINPROGRESS.
615   // The decision here is to deal with the two cases uniformly, at the cost of
616   // delaying the straight-away-connect() case by one task, to avoid depending
617   // on implementation details of UNIX socket on the various OSes.
618   // Posting the OnEvent() below emulates a wakeup of the FD watch. OnEvent(),
619   // which knows how to deal with spurious wakeups, will poll the SO_ERROR and
620   // evolve, if necessary, the state into either kConnected or kDisconnected.
621   WeakPtr<UnixSocket> weak_ptr = weak_ptr_factory_.GetWeakPtr();
622   task_runner_->PostTask([weak_ptr] {
623     if (weak_ptr)
624       weak_ptr->OnEvent();
625   });
626 }
627 
// Caches the peer's uid (and, where available, pid) of a connected AF_UNIX
// socket into peer_uid_ / peer_pid_. Does nothing for other families. Aborts
// (PERFETTO_CHECK) if the credentials cannot be read.
void UnixSocket::ReadPeerCredentials() {
  // Peer credentials are supported only on AF_UNIX sockets.
  if (sock_raw_.family() != SockFamily::kUnix)
    return;

  // Linux / Android: SO_PEERCRED fills a struct ucred with uid and pid.
  // Other platforms (e.g. Apple, FreeBSD): LOCAL_PEERCRED fills a struct xucred.
  // Since && binds tighter than ||, the condition below reads
  // (LINUX && !FREEBSD && !DRAGONFLY) || ANDROID.
  // NOTE(review): <sys/ucred.h> (which declares xucred) is included at the top
  // of this file only for Apple and FreeBSD; confirm other platforms reaching
  // the #else branch (e.g. DragonFly) get it transitively.
#if (PERFETTO_BUILDFLAG(PERFETTO_OS_LINUX) && \
     (!PERFETTO_BUILDFLAG(PERFETTO_OS_FREEBSD) && !PERFETTO_BUILDFLAG(PERFETTO_OS_DRAGONFLY)) || \
     PERFETTO_BUILDFLAG(PERFETTO_OS_ANDROID))
  struct ucred user_cred;
  socklen_t len = sizeof(user_cred);
  int fd = sock_raw_.fd();
  int res = getsockopt(fd, SOL_SOCKET, SO_PEERCRED, &user_cred, &len);
  PERFETTO_CHECK(res == 0);
  peer_uid_ = user_cred.uid;
  peer_pid_ = user_cred.pid;
#else
  struct xucred user_cred;
  socklen_t len = sizeof(user_cred);
  // Level 0 is used for LOCAL_PEERCRED (presumably SOL_LOCAL on these
  // platforms -- confirm).
  int res = getsockopt(sock_raw_.fd(), 0, LOCAL_PEERCRED, &user_cred, &len);
  PERFETTO_CHECK(res == 0 && user_cred.cr_version == XUCRED_VERSION);
  peer_uid_ = static_cast<uid_t>(user_cred.cr_uid);
// The xucred returned by LOCAL_PEERCRED on MacOS / FreeBSD carries no pid, so
// peer_pid_ is not set on this path.
#endif
}
652 
// FD-watch (and posted-task) callback. Dispatches on the current state:
//  - kDisconnected: ignored (spurious event).
//  - kConnected: the socket became readable -> OnDataAvailable().
//  - kConnecting: polls SO_ERROR to complete the asynchronous connect() and
//    reports the outcome via OnConnect().
//  - kListening: accepts every pending connection and hands each one to
//    OnNewIncomingConnection() as a new kConnected UnixSocket.
void UnixSocket::OnEvent() {
  if (state_ == State::kDisconnected)
    return;  // Some spurious event, typically queued just before Shutdown().

  if (state_ == State::kConnected)
    return event_listener_->OnDataAvailable(this);

  if (state_ == State::kConnecting) {
    PERFETTO_DCHECK(sock_raw_);
    // If getsockopt() itself fails, |sock_err| keeps its EINVAL default and
    // the connection is treated as failed below.
    int sock_err = EINVAL;
    socklen_t err_len = sizeof(sock_err);
    int res =
        getsockopt(sock_raw_.fd(), SOL_SOCKET, SO_ERROR, &sock_err, &err_len);

    if (res == 0 && sock_err == EINPROGRESS)
      return;  // Not connected yet, just a spurious FD watch wakeup.
    if (res == 0 && sock_err == 0) {
      ReadPeerCredentials();
      state_ = State::kConnected;
      return event_listener_->OnConnect(this, true /* connected */);
    }
    PERFETTO_DLOG("Connection error: %s", strerror(sock_err));
    last_error_ = sock_err;
    // Tear down first (without Shutdown()'s own notification), then report the
    // failure synchronously.
    Shutdown(false);
    return event_listener_->OnConnect(this, false /* connected */);
  }

  // New incoming connection.
  if (state_ == State::kListening) {
    // There could be more than one incoming connection behind each FD watch
    // notification. Drain'em all.
    for (;;) {
      // NOTE(review): |cli_addr| is a sockaddr_in even for AF_UNIX / AF_INET6
      // listeners, so accept() truncates the peer address to fit. This is
      // harmless because |cli_addr| is never read afterwards.
      struct sockaddr_in cli_addr {};
      socklen_t size = sizeof(cli_addr);
      ScopedFile new_fd(PERFETTO_EINTR(accept(
          sock_raw_.fd(), reinterpret_cast<sockaddr*>(&cli_addr), &size)));
      // The listening socket is non-blocking, so accept() fails once the
      // backlog is drained. Any other accept() error also ends the loop
      // silently.
      if (!new_fd)
        return;
      std::unique_ptr<UnixSocket> new_sock(new UnixSocket(
          event_listener_, task_runner_, std::move(new_fd), State::kConnected,
          sock_raw_.family(), sock_raw_.type()));
      // NOTE(review): if the listener destroys this listening socket inside
      // this callback, the next iteration would touch freed memory -- confirm
      // listeners never do that.
      event_listener_->OnNewIncomingConnection(this, std::move(new_sock));
    }
  }
}
698 
Send(const void * msg,size_t len,const int * send_fds,size_t num_fds)699 bool UnixSocket::Send(const void* msg,
700                       size_t len,
701                       const int* send_fds,
702                       size_t num_fds) {
703   if (state_ != State::kConnected) {
704     errno = last_error_ = ENOTCONN;
705     return false;
706   }
707 
708   sock_raw_.SetBlocking(true);
709   const ssize_t sz = sock_raw_.Send(msg, len, send_fds, num_fds);
710   int saved_errno = errno;
711   sock_raw_.SetBlocking(false);
712 
713   if (sz == static_cast<ssize_t>(len)) {
714     last_error_ = 0;
715     return true;
716   }
717 
718   // If sendmsg() succeeds but the returned size is < |len| it means that the
719   // endpoint disconnected in the middle of the read, and we managed to send
720   // only a portion of the buffer. In this case we should just give up.
721 
722   if (sz < 0 && (saved_errno == EAGAIN || saved_errno == EWOULDBLOCK)) {
723     // A genuine out-of-buffer. The client should retry or give up.
724     // Man pages specify that EAGAIN and EWOULDBLOCK have the same semantic here
725     // and clients should check for both.
726     last_error_ = EAGAIN;
727     return false;
728   }
729 
730   // Either the other endpoint disconnected (ECONNRESET) or some other error
731   // happened.
732   last_error_ = saved_errno;
733   PERFETTO_DPLOG("sendmsg() failed");
734   Shutdown(true);
735   return false;
736 }
737 
Shutdown(bool notify)738 void UnixSocket::Shutdown(bool notify) {
739   WeakPtr<UnixSocket> weak_ptr = weak_ptr_factory_.GetWeakPtr();
740   if (notify) {
741     if (state_ == State::kConnected) {
742       task_runner_->PostTask([weak_ptr] {
743         if (weak_ptr)
744           weak_ptr->event_listener_->OnDisconnect(weak_ptr.get());
745       });
746     } else if (state_ == State::kConnecting) {
747       task_runner_->PostTask([weak_ptr] {
748         if (weak_ptr)
749           weak_ptr->event_listener_->OnConnect(weak_ptr.get(), false);
750       });
751     }
752   }
753 
754   if (sock_raw_) {
755     task_runner_->RemoveFileDescriptorWatch(sock_raw_.fd());
756     sock_raw_.Shutdown();
757   }
758   state_ = State::kDisconnected;
759 }
760 
Receive(void * msg,size_t len,ScopedFile * fd_vec,size_t max_files)761 size_t UnixSocket::Receive(void* msg,
762                            size_t len,
763                            ScopedFile* fd_vec,
764                            size_t max_files) {
765   if (state_ != State::kConnected) {
766     last_error_ = ENOTCONN;
767     return 0;
768   }
769 
770   const ssize_t sz = sock_raw_.Receive(msg, len, fd_vec, max_files);
771   if (sz < 0 && (errno == EAGAIN || errno == EWOULDBLOCK)) {
772     last_error_ = EAGAIN;
773     return 0;
774   }
775   if (sz <= 0) {
776     last_error_ = errno;
777     Shutdown(true);
778     return 0;
779   }
780   PERFETTO_CHECK(static_cast<size_t>(sz) <= len);
781   return static_cast<size_t>(sz);
782 }
783 
ReceiveString(size_t max_length)784 std::string UnixSocket::ReceiveString(size_t max_length) {
785   std::unique_ptr<char[]> buf(new char[max_length + 1]);
786   size_t rsize = Receive(buf.get(), max_length);
787   PERFETTO_CHECK(static_cast<size_t>(rsize) <= max_length);
788   buf[static_cast<size_t>(rsize)] = '\0';
789   return std::string(buf.get());
790 }
791 
NotifyConnectionState(bool success)792 void UnixSocket::NotifyConnectionState(bool success) {
793   if (!success)
794     Shutdown(false);
795 
796   WeakPtr<UnixSocket> weak_ptr = weak_ptr_factory_.GetWeakPtr();
797   task_runner_->PostTask([weak_ptr, success] {
798     if (weak_ptr)
799       weak_ptr->event_listener_->OnConnect(weak_ptr.get(), success);
800   });
801 }
802 
~EventListener()803 UnixSocket::EventListener::~EventListener() {}
OnNewIncomingConnection(UnixSocket *,std::unique_ptr<UnixSocket>)804 void UnixSocket::EventListener::OnNewIncomingConnection(
805     UnixSocket*,
806     std::unique_ptr<UnixSocket>) {}
OnConnect(UnixSocket *,bool)807 void UnixSocket::EventListener::OnConnect(UnixSocket*, bool) {}
OnDisconnect(UnixSocket *)808 void UnixSocket::EventListener::OnDisconnect(UnixSocket*) {}
OnDataAvailable(UnixSocket *)809 void UnixSocket::EventListener::OnDataAvailable(UnixSocket*) {}
810 
811 }  // namespace base
812 }  // namespace perfetto
813