1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef INCLUDE_PERFETTO_EXT_BASE_UNIX_SOCKET_H_
18 #define INCLUDE_PERFETTO_EXT_BASE_UNIX_SOCKET_H_
19 
#include <stdint.h>
#include <sys/types.h>

#include <memory>
#include <string>
#include <utility>

#include "perfetto/base/logging.h"
#include "perfetto/ext/base/scoped_file.h"
#include "perfetto/ext/base/utils.h"
#include "perfetto/ext/base/weak_ptr.h"
30 
31 struct msghdr;
32 
33 namespace perfetto {
34 namespace base {
35 
class TaskRunner;

// The enumerators below intentionally start at arbitrary high values. This
// prevents any code from accidentally relying on them being equal to the
// sysroot's SOCK_xxx defines; the mapping to native constants must always go
// through GetSockType() / GetSockFamily().
enum class SockType : int {
  kStream = 100,
  kDgram = 101,
  kSeqPacket = 102,
};

enum class SockFamily : int {
  kUnix = 200,
  kInet = 201,
  kInet6 = 202,
};
43 
44 // UnixSocketRaw is a basic wrapper around UNIX sockets. It exposes wrapper
45 // methods that take care of most common pitfalls (e.g., marking fd as
46 // O_CLOEXEC, avoiding SIGPIPE, properly handling partial writes). It is used as
47 // a building block for the more sophisticated UnixSocket class.
48 class UnixSocketRaw {
49  public:
50   // Creates a new unconnected unix socket.
51   static UnixSocketRaw CreateMayFail(SockFamily family, SockType type);
52 
53   // Crates a pair of connected sockets.
54   static std::pair<UnixSocketRaw, UnixSocketRaw> CreatePair(SockFamily,
55                                                             SockType);
56 
57   // Creates an uninitialized unix socket.
58   UnixSocketRaw();
59 
60   // Creates a unix socket adopting an existing file descriptor. This is
61   // typically used to inherit fds from init via environment variables.
62   UnixSocketRaw(ScopedFile, SockFamily, SockType);
63 
64   ~UnixSocketRaw() = default;
65   UnixSocketRaw(UnixSocketRaw&&) noexcept = default;
66   UnixSocketRaw& operator=(UnixSocketRaw&&) = default;
67 
68   bool Bind(const std::string& socket_name);
69   bool Listen();
70   bool Connect(const std::string& socket_name);
71   bool SetTxTimeout(uint32_t timeout_ms);
72   bool SetRxTimeout(uint32_t timeout_ms);
73   void Shutdown();
74   void SetBlocking(bool);
75   bool IsBlocking() const;
76   void RetainOnExec();
type()77   SockType type() const { return type_; }
family()78   SockFamily family() const { return family_; }
fd()79   int fd() const { return *fd_; }
80   explicit operator bool() const { return !!fd_; }
81 
ReleaseFd()82   ScopedFile ReleaseFd() { return std::move(fd_); }
83 
84   ssize_t Send(const void* msg,
85                size_t len,
86                const int* send_fds = nullptr,
87                size_t num_fds = 0);
88 
89   // Re-enter sendmsg until all the data has been sent or an error occurs.
90   // TODO(fmayer): Figure out how to do timeouts here for heapprofd.
91   ssize_t SendMsgAll(struct msghdr* msg);
92 
93   ssize_t Receive(void* msg,
94                   size_t len,
95                   ScopedFile* fd_vec = nullptr,
96                   size_t max_files = 0);
97 
98   // Exposed for testing only.
99   // Update msghdr so subsequent sendmsg will send data that remains after n
100   // bytes have already been sent.
101   static void ShiftMsgHdr(size_t n, struct msghdr* msg);
102 
103  private:
104   UnixSocketRaw(SockFamily, SockType);
105 
106   UnixSocketRaw(const UnixSocketRaw&) = delete;
107   UnixSocketRaw& operator=(const UnixSocketRaw&) = delete;
108 
109   ScopedFile fd_;
110   SockFamily family_ = SockFamily::kUnix;
111   SockType type_ = SockType::kStream;
112 };
113 
// A non-blocking UNIX domain socket. It also supports transferring file
// descriptors. None of the methods in this class are blocking.
// The main design goal is making strong guarantees on the EventListener
// callbacks, in order to avoid ending in some undefined state.
// In case of any error it will aggressively just shut down the socket and
// notify the failure with OnConnect(false) or OnDisconnect() depending on the
// state of the socket (see below).
// EventListener callbacks stop happening as soon as the instance is destroyed.
//
// Lifecycle of a client socket:
//
//                           Connect()
//                               |
//            +------------------+------------------+
//            | (success)                           | (failure or Shutdown())
//            V                                     V
//     OnConnect(true)                         OnConnect(false)
//            |
//            V
//    OnDataAvailable()
//            |
//            V
//     OnDisconnect()  (failure or shutdown)
//
//
// Lifecycle of a server socket:
//
//                          Listen()  --> returns false in case of errors.
//                             |
//                             V
//              OnNewIncomingConnection(new_socket)
//
//          (|new_socket| inherits the same EventListener)
//                             |
//                             V
//                     OnDataAvailable()
//                             | (failure or Shutdown())
//                             V
//                       OnDisconnect()
class UnixSocket {
 public:
  // Receives the socket's state-change and data notifications. The callbacks
  // are non-pure virtual, so a listener overrides only the ones it needs.
  // |self| is the UnixSocket the notification refers to.
  class EventListener {
   public:
    virtual ~EventListener();

    // After Listen(). Ownership of the accepted socket is transferred to the
    // listener through |new_connection|.
    virtual void OnNewIncomingConnection(
        UnixSocket* self,
        std::unique_ptr<UnixSocket> new_connection);

    // After Connect(), whether successful or not.
    virtual void OnConnect(UnixSocket* self, bool connected);

    // After a successful Connect() or OnNewIncomingConnection(). Either the
    // other endpoint did disconnect or some other error happened.
    virtual void OnDisconnect(UnixSocket* self);

    // Whenever there is data available to Receive(). Note that spurious FD
    // watch events are possible, so it is possible that Receive() soon after
    // OnDataAvailable() returns 0 (just ignore those).
    virtual void OnDataAvailable(UnixSocket* self);
  };

  enum class State {
    kDisconnected = 0,  // Failed connection, peer disconnection or Shutdown().
    kConnecting,  // Soon after Connect(), before it either succeeds or fails.
    kConnected,   // After a successful Connect().
    kListening    // After Listen(), until Shutdown().
  };

  // Creates a socket and starts listening. If SockFamily::kUnix and
  // |socket_name| starts with a '@', an abstract UNIX domain socket will be
  // created instead of a filesystem-linked UNIX socket (Linux/Android only).
  // If SockFamily::kInet, |socket_name| is host:port (e.g., "1.2.3.4:8000").
  // If SockFamily::kInet6, |socket_name| is [host]:port (e.g., "[::1]:8000").
  // Returns nullptr if the socket creation or bind fails. If listening fails,
  // (e.g. if another socket with the same name is already listening) the
  // returned socket will have is_listening() == false and last_error() will
  // contain the failure reason.
  static std::unique_ptr<UnixSocket> Listen(const std::string& socket_name,
                                            EventListener*,
                                            TaskRunner*,
                                            SockFamily,
                                            SockType);

  // Attaches to a pre-existing socket. The socket must have been created in
  // SOCK_STREAM mode and the caller must have called bind() on it.
  // NOTE(review): this overload also takes a SockType argument, so the
  // SOCK_STREAM requirement may be stale — confirm against the implementation.
  static std::unique_ptr<UnixSocket> Listen(ScopedFile,
                                            EventListener*,
                                            TaskRunner*,
                                            SockFamily,
                                            SockType);

  // Creates a Unix domain socket and connects to the listening endpoint.
  // Always returns an instance. EventListener::OnConnect(bool success) will
  // always be called, whether the connection succeeded or not.
  static std::unique_ptr<UnixSocket> Connect(const std::string& socket_name,
                                             EventListener*,
                                             TaskRunner*,
                                             SockFamily,
                                             SockType);

  // Constructs a UnixSocket using the given connected socket.
  static std::unique_ptr<UnixSocket> AdoptConnected(ScopedFile,
                                                    EventListener*,
                                                    TaskRunner*,
                                                    SockFamily,
                                                    SockType);

  UnixSocket(const UnixSocket&) = delete;
  UnixSocket& operator=(const UnixSocket&) = delete;
  // Cannot be easily moved because of tasks from the FileDescriptorWatch.
  UnixSocket(UnixSocket&&) = delete;
  UnixSocket& operator=(UnixSocket&&) = delete;

  // This class gives the hard guarantee that no callback is called on the
  // passed EventListener immediately after the object has been destroyed.
  // Any queued callback will be silently dropped.
  ~UnixSocket();

  // Shuts down the current connection, if any. If the socket was Listen()-ing,
  // stops listening. The socket goes back to the kDisconnected state, so it
  // can be reused with Listen() or Connect().
  // NOTE(review): |notify| presumably controls whether the EventListener is
  // told about the shutdown (OnDisconnect()/OnConnect(false)) — confirm.
  void Shutdown(bool notify);

  // Returns true if the message was queued, false if there was no space in the
  // output buffer, in which case the client should retry or give up.
  // If any other error happens the socket will be shutdown and
  // EventListener::OnDisconnect() will be called.
  // If the socket is not connected, Send() will just return false.
  // Does not append a null string terminator to msg in any case.
  bool Send(const void* msg, size_t len, const int* send_fds, size_t num_fds);

  // Convenience overload passing at most one fd; -1 means "no fd".
  inline bool Send(const void* msg, size_t len, int send_fd = -1) {
    if (send_fd != -1)
      return Send(msg, len, &send_fd, 1);
    return Send(msg, len, nullptr, 0);
  }

  // Sends the contents of |msg| INCLUDING its trailing NUL: it transmits
  // msg.size() + 1 bytes starting at msg.c_str(). This differs from the raw
  // Send() overloads above, which send exactly |len| bytes.
  inline bool Send(const std::string& msg) {
    return Send(msg.c_str(), msg.size() + 1, -1);
  }

  // Returns the number of bytes (<= |len|) written in |msg| or 0 if there
  // is no data in the buffer to read or an error occurs (in which case a
  // EventListener::OnDisconnect() will follow).
  // If the ScopedFile pointer is not null and a FD is received, it moves the
  // received FD into that. If a FD is received but the ScopedFile pointer is
  // null, the FD will be automatically closed.
  size_t Receive(void* msg, size_t len, ScopedFile*, size_t max_files = 1);

  // Receives data only, accepting no fds (max_files = 0).
  inline size_t Receive(void* msg, size_t len) {
    return Receive(msg, len, nullptr, 0);
  }

  // Only for tests. This is slower than Receive() as it requires a heap
  // allocation and a copy for the std::string. Guarantees that the returned
  // string is null terminated even if the underlying message sent by the peer
  // is not.
  std::string ReceiveString(size_t max_length = 1024);

  bool is_connected() const { return state_ == State::kConnected; }
  bool is_listening() const { return state_ == State::kListening; }
  // Underlying fd; ownership stays with this object.
  int fd() const { return sock_raw_.fd(); }
  // Last error recorded by this socket (e.g. when Listen() fails, see above).
  int last_error() const { return last_error_; }

  // User ID of the peer, as returned by the kernel. If the client disconnects
  // and the socket goes into the kDisconnected state, it retains the uid of
  // the last peer.
  // Must not be called on a listening socket or before the uid is known
  // (DCHECKed).
  uid_t peer_uid() const {
    PERFETTO_DCHECK(!is_listening() && peer_uid_ != kInvalidUid);
    ignore_result(kInvalidPid);  // Silence warnings in amalgamated builds.
    return peer_uid_;
  }

#if PERFETTO_BUILDFLAG(PERFETTO_OS_LINUX) || \
    PERFETTO_BUILDFLAG(PERFETTO_OS_ANDROID)
  // Process ID of the peer, as returned by the kernel. If the client
  // disconnects and the socket goes into the kDisconnected state, it
  // retains the pid of the last peer.
  //
  // This is only available on Linux / Android.
  pid_t peer_pid() const {
    PERFETTO_DCHECK(!is_listening() && peer_pid_ != kInvalidPid);
    return peer_pid_;
  }
#endif

  // Hands the underlying raw socket over to the caller. This makes the
  // UnixSocket unusable.
  UnixSocketRaw ReleaseSocket();

 private:
  UnixSocket(EventListener*, TaskRunner*, SockFamily, SockType);
  UnixSocket(EventListener*,
             TaskRunner*,
             ScopedFile,
             State,
             SockFamily,
             SockType);

  // Called once by the corresponding public static factory methods.
  void DoConnect(const std::string& socket_name);
  // Presumably fills peer_uid_ (and peer_pid_ where available) from the
  // kernel — confirm in the implementation.
  void ReadPeerCredentials();

  // NOTE(review): presumably the fd-watch callback driving the state machine.
  void OnEvent();
  void NotifyConnectionState(bool success);

  UnixSocketRaw sock_raw_;
  State state_ = State::kDisconnected;
  int last_error_ = 0;
  uid_t peer_uid_ = kInvalidUid;
#if PERFETTO_BUILDFLAG(PERFETTO_OS_LINUX) || \
    PERFETTO_BUILDFLAG(PERFETTO_OS_ANDROID)
  pid_t peer_pid_ = kInvalidPid;
#endif
  EventListener* const event_listener_;
  TaskRunner* const task_runner_;
  // Declared last so it is destroyed first (members are destroyed in reverse
  // declaration order), invalidating outstanding WeakPtrs before the other
  // members go away.
  WeakPtrFactory<UnixSocket> weak_ptr_factory_;  // Keep last.
};
333 
334 }  // namespace base
335 }  // namespace perfetto
336 
337 #endif  // INCLUDE_PERFETTO_EXT_BASE_UNIX_SOCKET_H_
338