1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "perfetto/ext/base/unix_socket.h"
18 
19 #include <signal.h>
20 #include <sys/mman.h>
21 #include <sys/socket.h>
22 #include <sys/types.h>
23 #include <sys/un.h>
24 #include <list>
25 #include <thread>
26 
27 #include "perfetto/base/build_config.h"
28 #include "perfetto/base/logging.h"
29 #include "perfetto/ext/base/file_utils.h"
30 #include "perfetto/ext/base/pipe.h"
31 #include "perfetto/ext/base/temp_file.h"
32 #include "perfetto/ext/base/utils.h"
33 #include "src/base/test/test_task_runner.h"
34 #include "src/ipc/test/test_socket.h"
35 #include "test/gtest_and_gmock.h"
36 
37 namespace perfetto {
38 namespace base {
39 namespace {
40 
41 using ::testing::_;
42 using ::testing::AtLeast;
43 using ::testing::Invoke;
44 using ::testing::InvokeWithoutArgs;
45 using ::testing::Mock;
46 
// Socket name shared by the listen/connect tests below; the fixture's
// SetUp()/TearDown() also pass it to DESTROY_TEST_SOCK(). TEST_SOCK_NAME()
// comes from src/ipc/test/test_socket.h -- NOTE(review): presumably it builds
// a platform-appropriate socket address from the given name; confirm there.
constexpr char kSocketName[] = TEST_SOCK_NAME("unix_socket_unittest");
48 
49 class MockEventListener : public UnixSocket::EventListener {
50  public:
51   MOCK_METHOD2(OnNewIncomingConnection, void(UnixSocket*, UnixSocket*));
52   MOCK_METHOD2(OnConnect, void(UnixSocket*, bool));
53   MOCK_METHOD1(OnDisconnect, void(UnixSocket*));
54   MOCK_METHOD1(OnDataAvailable, void(UnixSocket*));
55 
56   // GMock doesn't support mocking methods with non-copiable args.
OnNewIncomingConnection(UnixSocket * self,std::unique_ptr<UnixSocket> new_connection)57   void OnNewIncomingConnection(
58       UnixSocket* self,
59       std::unique_ptr<UnixSocket> new_connection) override {
60     incoming_connections_.emplace_back(std::move(new_connection));
61     OnNewIncomingConnection(self, incoming_connections_.back().get());
62   }
63 
GetIncomingConnection()64   std::unique_ptr<UnixSocket> GetIncomingConnection() {
65     if (incoming_connections_.empty())
66       return nullptr;
67     std::unique_ptr<UnixSocket> sock = std::move(incoming_connections_.front());
68     incoming_connections_.pop_front();
69     return sock;
70   }
71 
72  private:
73   std::list<std::unique_ptr<UnixSocket>> incoming_connections_;
74 };
75 
76 class UnixSocketTest : public ::testing::Test {
77  protected:
SetUp()78   void SetUp() override { DESTROY_TEST_SOCK(kSocketName); }
TearDown()79   void TearDown() override { DESTROY_TEST_SOCK(kSocketName); }
80 
81   TestTaskRunner task_runner_;
82   MockEventListener event_listener_;
83 };
84 
// Connecting to a name nobody listens on must report OnConnect(false).
TEST_F(UnixSocketTest, ConnectionFailureIfUnreachable) {
  auto client = UnixSocket::Connect(kSocketName, &event_listener_,
                                    &task_runner_, SockFamily::kUnix,
                                    SockType::kStream);
  ASSERT_FALSE(client->is_connected());

  auto on_failure = task_runner_.CreateCheckpoint("failure");
  EXPECT_CALL(event_listener_, OnConnect(client.get(), false))
      .WillOnce(InvokeWithoutArgs(on_failure));
  task_runner_.RunUntilCheckpoint("failure");
}
94 
95 // Both server and client should see an OnDisconnect() if the server drops
96 // incoming connections immediately as they are created.
TEST_F(UnixSocketTest, ConnectionImmediatelyDroppedByServer) {
  auto server = UnixSocket::Listen(kSocketName, &event_listener_, &task_runner_,
                                   SockFamily::kUnix, SockType::kStream);
  ASSERT_TRUE(server->is_listening());

  // As soon as the server accepts, it shuts the new connection down with
  // notification, so the accepted socket must see OnDisconnect().
  auto on_srv_shutdown = task_runner_.CreateCheckpoint("srv_did_shutdown");
  EXPECT_CALL(event_listener_, OnNewIncomingConnection(server.get(), _))
      .WillOnce(Invoke(
          [this, on_srv_shutdown](UnixSocket*, UnixSocket* accepted) {
            EXPECT_CALL(event_listener_, OnDisconnect(accepted));
            accepted->Shutdown(true);
            on_srv_shutdown();
          }));

  auto on_cli_connected = task_runner_.CreateCheckpoint("cli_connected");
  auto client = UnixSocket::Connect(kSocketName, &event_listener_,
                                    &task_runner_, SockFamily::kUnix,
                                    SockType::kStream);
  EXPECT_CALL(event_listener_, OnConnect(client.get(), true))
      .WillOnce(InvokeWithoutArgs(on_cli_connected));
  task_runner_.RunUntilCheckpoint("cli_connected");
  task_runner_.RunUntilCheckpoint("srv_did_shutdown");

  // The client only notices when it next uses the socket: Send() fails and
  // the disconnection notification follows.
  auto on_cli_disconnected =
      task_runner_.CreateCheckpoint("cli_disconnected");
  EXPECT_CALL(event_listener_, OnDisconnect(client.get()))
      .WillOnce(InvokeWithoutArgs(on_cli_disconnected));
  EXPECT_FALSE(client->Send("whatever"));
  task_runner_.RunUntilCheckpoint("cli_disconnected");
}
128 
TEST_F(UnixSocketTest,ClientAndServerExchangeData)129 TEST_F(UnixSocketTest, ClientAndServerExchangeData) {
130   auto srv = UnixSocket::Listen(kSocketName, &event_listener_, &task_runner_,
131                                 SockFamily::kUnix, SockType::kStream);
132   ASSERT_TRUE(srv->is_listening());
133 
134   auto cli = UnixSocket::Connect(kSocketName, &event_listener_, &task_runner_,
135                                  SockFamily::kUnix, SockType::kStream);
136   EXPECT_CALL(event_listener_, OnConnect(cli.get(), true));
137   auto cli_connected = task_runner_.CreateCheckpoint("cli_connected");
138   auto srv_disconnected = task_runner_.CreateCheckpoint("srv_disconnected");
139   EXPECT_CALL(event_listener_, OnNewIncomingConnection(srv.get(), _))
140       .WillOnce(Invoke([this, cli_connected, srv_disconnected](
141                            UnixSocket*, UnixSocket* srv_conn) {
142         EXPECT_CALL(event_listener_, OnDisconnect(srv_conn))
143             .WillOnce(InvokeWithoutArgs(srv_disconnected));
144         cli_connected();
145       }));
146   task_runner_.RunUntilCheckpoint("cli_connected");
147 
148   auto srv_conn = event_listener_.GetIncomingConnection();
149   ASSERT_TRUE(srv_conn);
150   ASSERT_TRUE(cli->is_connected());
151 
152   auto cli_did_recv = task_runner_.CreateCheckpoint("cli_did_recv");
153   EXPECT_CALL(event_listener_, OnDataAvailable(cli.get()))
154       .WillOnce(Invoke([cli_did_recv](UnixSocket* s) {
155         ASSERT_EQ("srv>cli", s->ReceiveString());
156         cli_did_recv();
157       }));
158 
159   auto srv_did_recv = task_runner_.CreateCheckpoint("srv_did_recv");
160   EXPECT_CALL(event_listener_, OnDataAvailable(srv_conn.get()))
161       .WillOnce(Invoke([srv_did_recv](UnixSocket* s) {
162         ASSERT_EQ("cli>srv", s->ReceiveString());
163         srv_did_recv();
164       }));
165   ASSERT_TRUE(cli->Send("cli>srv"));
166   ASSERT_TRUE(srv_conn->Send("srv>cli"));
167   task_runner_.RunUntilCheckpoint("cli_did_recv");
168   task_runner_.RunUntilCheckpoint("srv_did_recv");
169 
170   // Check that Send/Receive() fails gracefully once the socket is closed.
171   auto cli_disconnected = task_runner_.CreateCheckpoint("cli_disconnected");
172   EXPECT_CALL(event_listener_, OnDisconnect(cli.get()))
173       .WillOnce(InvokeWithoutArgs(cli_disconnected));
174   cli->Shutdown(true);
175   char msg[4];
176   ASSERT_EQ(0u, cli->Receive(&msg, sizeof(msg)));
177   ASSERT_EQ("", cli->ReceiveString());
178   ASSERT_EQ(0u, srv_conn->Receive(&msg, sizeof(msg)));
179   ASSERT_EQ("", srv_conn->ReceiveString());
180   ASSERT_FALSE(cli->Send("foo"));
181   ASSERT_FALSE(srv_conn->Send("bar"));
182   srv->Shutdown(true);
183   task_runner_.RunUntilCheckpoint("cli_disconnected");
184   task_runner_.RunUntilCheckpoint("srv_disconnected");
185 }
186 
187 constexpr char cli_str[] = "cli>srv";
188 constexpr char srv_str[] = "srv>cli";
189 
TEST_F(UnixSocketTest,ClientAndServerExchangeFDs)190 TEST_F(UnixSocketTest, ClientAndServerExchangeFDs) {
191   auto srv = UnixSocket::Listen(kSocketName, &event_listener_, &task_runner_,
192                                 SockFamily::kUnix, SockType::kStream);
193   ASSERT_TRUE(srv->is_listening());
194 
195   auto cli = UnixSocket::Connect(kSocketName, &event_listener_, &task_runner_,
196                                  SockFamily::kUnix, SockType::kStream);
197   EXPECT_CALL(event_listener_, OnConnect(cli.get(), true));
198   auto cli_connected = task_runner_.CreateCheckpoint("cli_connected");
199   auto srv_disconnected = task_runner_.CreateCheckpoint("srv_disconnected");
200   EXPECT_CALL(event_listener_, OnNewIncomingConnection(srv.get(), _))
201       .WillOnce(Invoke([this, cli_connected, srv_disconnected](
202                            UnixSocket*, UnixSocket* srv_conn) {
203         EXPECT_CALL(event_listener_, OnDisconnect(srv_conn))
204             .WillOnce(InvokeWithoutArgs(srv_disconnected));
205         cli_connected();
206       }));
207   task_runner_.RunUntilCheckpoint("cli_connected");
208 
209   auto srv_conn = event_listener_.GetIncomingConnection();
210   ASSERT_TRUE(srv_conn);
211   ASSERT_TRUE(cli->is_connected());
212 
213   ScopedFile null_fd(base::OpenFile("/dev/null", O_RDONLY));
214   ScopedFile zero_fd(base::OpenFile("/dev/zero", O_RDONLY));
215 
216   auto cli_did_recv = task_runner_.CreateCheckpoint("cli_did_recv");
217   EXPECT_CALL(event_listener_, OnDataAvailable(cli.get()))
218       .WillRepeatedly(Invoke([cli_did_recv](UnixSocket* s) {
219         ScopedFile fd_buf[3];
220         char buf[sizeof(cli_str)];
221         if (!s->Receive(buf, sizeof(buf), fd_buf, ArraySize(fd_buf)))
222           return;
223         ASSERT_STREQ(srv_str, buf);
224         ASSERT_NE(*fd_buf[0], -1);
225         ASSERT_NE(*fd_buf[1], -1);
226         ASSERT_EQ(*fd_buf[2], -1);
227 
228         char rd_buf[1];
229         // /dev/null
230         ASSERT_EQ(read(*fd_buf[0], rd_buf, sizeof(rd_buf)), 0);
231         // /dev/zero
232         ASSERT_EQ(read(*fd_buf[1], rd_buf, sizeof(rd_buf)), 1);
233         cli_did_recv();
234       }));
235 
236   auto srv_did_recv = task_runner_.CreateCheckpoint("srv_did_recv");
237   EXPECT_CALL(event_listener_, OnDataAvailable(srv_conn.get()))
238       .WillRepeatedly(Invoke([srv_did_recv](UnixSocket* s) {
239         ScopedFile fd_buf[3];
240         char buf[sizeof(srv_str)];
241         if (!s->Receive(buf, sizeof(buf), fd_buf, ArraySize(fd_buf)))
242           return;
243         ASSERT_STREQ(cli_str, buf);
244         ASSERT_NE(*fd_buf[0], -1);
245         ASSERT_NE(*fd_buf[1], -1);
246         ASSERT_EQ(*fd_buf[2], -1);
247 
248         char rd_buf[1];
249         // /dev/null
250         ASSERT_EQ(read(*fd_buf[0], rd_buf, sizeof(rd_buf)), 0);
251         // /dev/zero
252         ASSERT_EQ(read(*fd_buf[1], rd_buf, sizeof(rd_buf)), 1);
253         srv_did_recv();
254       }));
255 
256   int buf_fd[2] = {null_fd.get(), zero_fd.get()};
257 
258   ASSERT_TRUE(
259       cli->Send(cli_str, sizeof(cli_str), buf_fd, base::ArraySize(buf_fd)));
260   ASSERT_TRUE(srv_conn->Send(srv_str, sizeof(srv_str), buf_fd,
261                              base::ArraySize(buf_fd)));
262   task_runner_.RunUntilCheckpoint("srv_did_recv");
263   task_runner_.RunUntilCheckpoint("cli_did_recv");
264 
265   auto cli_disconnected = task_runner_.CreateCheckpoint("cli_disconnected");
266   EXPECT_CALL(event_listener_, OnDisconnect(cli.get()))
267       .WillOnce(InvokeWithoutArgs(cli_disconnected));
268   cli->Shutdown(true);
269   srv->Shutdown(true);
270   task_runner_.RunUntilCheckpoint("srv_disconnected");
271   task_runner_.RunUntilCheckpoint("cli_disconnected");
272 }
273 
// The server socket is created and bound by hand, and its fd is handed to
// UnixSocket::Listen() instead of letting Listen() create one.
TEST_F(UnixSocketTest, ListenWithPassedFileDescriptor) {
  auto sock_raw =
      UnixSocketRaw::CreateMayFail(SockFamily::kUnix, SockType::kStream);
  // CreateMayFail() can return an invalid socket; report that directly
  // instead of handing an invalid socket to Bind().
  ASSERT_TRUE(sock_raw);
  ASSERT_TRUE(sock_raw.Bind(kSocketName));
  auto fd = sock_raw.ReleaseFd();
  auto srv = UnixSocket::Listen(std::move(fd), &event_listener_, &task_runner_,
                                SockFamily::kUnix, SockType::kStream);
  ASSERT_TRUE(srv->is_listening());

  auto cli = UnixSocket::Connect(kSocketName, &event_listener_, &task_runner_,
                                 SockFamily::kUnix, SockType::kStream);
  EXPECT_CALL(event_listener_, OnConnect(cli.get(), true));
  auto cli_connected = task_runner_.CreateCheckpoint("cli_connected");
  auto srv_disconnected = task_runner_.CreateCheckpoint("srv_disconnected");
  EXPECT_CALL(event_listener_, OnNewIncomingConnection(srv.get(), _))
      .WillOnce(Invoke([this, cli_connected, srv_disconnected](
                           UnixSocket*, UnixSocket* srv_conn) {
        // Read the EOF state.
        EXPECT_CALL(event_listener_, OnDataAvailable(srv_conn))
            .WillOnce(
                InvokeWithoutArgs([srv_conn] { srv_conn->ReceiveString(); }));
        EXPECT_CALL(event_listener_, OnDisconnect(srv_conn))
            .WillOnce(InvokeWithoutArgs(srv_disconnected));
        cli_connected();
      }));
  task_runner_.RunUntilCheckpoint("cli_connected");
  ASSERT_TRUE(cli->is_connected());
  // Dropping the client must be observed by the accepted connection.
  cli.reset();
  task_runner_.RunUntilCheckpoint("srv_disconnected");
}
304 
// Mostly a stress test. Connects kNumClients clients to the same server and
// tests that all of them can exchange data and see the expected sequence of
// events.
TEST_F(UnixSocketTest, SeveralClients) {
  auto srv = UnixSocket::Listen(kSocketName, &event_listener_, &task_runner_,
                                SockFamily::kUnix, SockType::kStream);
  ASSERT_TRUE(srv->is_listening());
  constexpr size_t kNumClients = 32;
  std::unique_ptr<UnixSocket> cli[kNumClients];

  // Server side: every accepted connection expects one "PING" and answers
  // with "PONG".
  EXPECT_CALL(event_listener_, OnNewIncomingConnection(srv.get(), _))
      .Times(kNumClients)
      .WillRepeatedly(Invoke([this](UnixSocket*, UnixSocket* s) {
        EXPECT_CALL(event_listener_, OnDataAvailable(s))
            .WillOnce(Invoke([](UnixSocket* t) {
              ASSERT_EQ("PING", t->ReceiveString());
              ASSERT_TRUE(t->Send("PONG"));
            }));
      }));

  // Client side: each client sends "PING" once connected and reaches the
  // checkpoint named after its index when the "PONG" comes back.
  for (size_t i = 0; i < kNumClients; i++) {
    cli[i] = UnixSocket::Connect(kSocketName, &event_listener_, &task_runner_,
                                 SockFamily::kUnix, SockType::kStream);
    EXPECT_CALL(event_listener_, OnConnect(cli[i].get(), true))
        .WillOnce(Invoke([](UnixSocket* s, bool success) {
          ASSERT_TRUE(success);
          ASSERT_TRUE(s->Send("PING"));
        }));

    auto checkpoint = task_runner_.CreateCheckpoint(std::to_string(i));
    EXPECT_CALL(event_listener_, OnDataAvailable(cli[i].get()))
        .WillOnce(Invoke([checkpoint](UnixSocket* s) {
          ASSERT_EQ("PONG", s->ReceiveString());
          checkpoint();
        }));
  }

  for (size_t i = 0; i < kNumClients; i++)
    task_runner_.RunUntilCheckpoint(std::to_string(i));

  // Every expectation above is registered on |event_listener_|, so that is
  // the mock to verify. |cli[i]| is a plain UnixSocket with no gMock
  // expectations of its own, so verifying it would check nothing.
  ASSERT_TRUE(Mock::VerifyAndClearExpectations(&event_listener_));
}
346 
347 // Creates two processes. The server process creates a file and passes it over
348 // the socket to the client. Both processes mmap the file in shared mode and
349 // check that they see the same contents.
TEST_F(UnixSocketTest,SharedMemory)350 TEST_F(UnixSocketTest, SharedMemory) {
351   Pipe pipe = Pipe::Create();
352   pid_t pid = fork();
353   ASSERT_GE(pid, 0);
354   constexpr size_t kTmpSize = 4096;
355 
356   if (pid == 0) {
357     // Child process.
358     TempFile scoped_tmp = TempFile::CreateUnlinked();
359     int tmp_fd = scoped_tmp.fd();
360     ASSERT_FALSE(ftruncate(tmp_fd, kTmpSize));
361     char* mem = reinterpret_cast<char*>(
362         mmap(nullptr, kTmpSize, PROT_READ | PROT_WRITE, MAP_SHARED, tmp_fd, 0));
363     ASSERT_NE(nullptr, mem);
364     memcpy(mem, "shm rocks", 10);
365 
366     auto srv = UnixSocket::Listen(kSocketName, &event_listener_, &task_runner_,
367                                   SockFamily::kUnix, SockType::kStream);
368     ASSERT_TRUE(srv->is_listening());
369     // Signal the other process that it can connect.
370     ASSERT_EQ(1, base::WriteAll(*pipe.wr, ".", 1));
371     auto checkpoint = task_runner_.CreateCheckpoint("change_seen_by_server");
372     EXPECT_CALL(event_listener_, OnNewIncomingConnection(srv.get(), _))
373         .WillOnce(Invoke(
374             [this, tmp_fd, checkpoint, mem](UnixSocket*, UnixSocket* new_conn) {
375               ASSERT_EQ(geteuid(), static_cast<uint32_t>(new_conn->peer_uid()));
376               ASSERT_TRUE(new_conn->Send("txfd", 5, tmp_fd));
377               // Wait for the client to change this again.
378               EXPECT_CALL(event_listener_, OnDataAvailable(new_conn))
379                   .WillOnce(Invoke([checkpoint, mem](UnixSocket* s) {
380                     ASSERT_EQ("change notify", s->ReceiveString());
381                     ASSERT_STREQ("rock more", mem);
382                     checkpoint();
383                   }));
384             }));
385     task_runner_.RunUntilCheckpoint("change_seen_by_server");
386     ASSERT_TRUE(Mock::VerifyAndClearExpectations(&event_listener_));
387     _exit(0);
388   } else {
389     char sync_cmd = '\0';
390     ASSERT_EQ(1, PERFETTO_EINTR(read(*pipe.rd, &sync_cmd, 1)));
391     ASSERT_EQ('.', sync_cmd);
392     auto cli = UnixSocket::Connect(kSocketName, &event_listener_, &task_runner_,
393                                    SockFamily::kUnix, SockType::kStream);
394     EXPECT_CALL(event_listener_, OnConnect(cli.get(), true));
395     auto checkpoint = task_runner_.CreateCheckpoint("change_seen_by_client");
396     EXPECT_CALL(event_listener_, OnDataAvailable(cli.get()))
397         .WillOnce(Invoke([checkpoint](UnixSocket* s) {
398           char msg[32];
399           ScopedFile fd;
400           ASSERT_EQ(5u, s->Receive(msg, sizeof(msg), &fd));
401           ASSERT_STREQ("txfd", msg);
402           ASSERT_TRUE(fd);
403           char* mem = reinterpret_cast<char*>(mmap(
404               nullptr, kTmpSize, PROT_READ | PROT_WRITE, MAP_SHARED, *fd, 0));
405           ASSERT_NE(nullptr, mem);
406           mem[9] = '\0';  // Just to get a clean error in case of test failure.
407           ASSERT_STREQ("shm rocks", mem);
408 
409           // Now change the shared memory and ping the other process.
410           memcpy(mem, "rock more", 10);
411           ASSERT_TRUE(s->Send("change notify"));
412           checkpoint();
413         }));
414     task_runner_.RunUntilCheckpoint("change_seen_by_client");
415     int st = 0;
416     PERFETTO_EINTR(waitpid(pid, &st, 0));
417     ASSERT_FALSE(WIFSIGNALED(st)) << "Server died with signal " << WTERMSIG(st);
418     EXPECT_TRUE(WIFEXITED(st));
419     ASSERT_EQ(0, WEXITSTATUS(st));
420   }
421 }
422 
423 // Checks that the peer_uid() is retained after the client disconnects. The IPC
424 // layer needs to rely on this to validate messages received immediately before
425 // a client disconnects.
TEST_F(UnixSocketTest,PeerCredentialsRetainedAfterDisconnect)426 TEST_F(UnixSocketTest, PeerCredentialsRetainedAfterDisconnect) {
427   auto srv = UnixSocket::Listen(kSocketName, &event_listener_, &task_runner_,
428                                 SockFamily::kUnix, SockType::kStream);
429   ASSERT_TRUE(srv->is_listening());
430   UnixSocket* srv_client_conn = nullptr;
431   auto srv_connected = task_runner_.CreateCheckpoint("srv_connected");
432   EXPECT_CALL(event_listener_, OnNewIncomingConnection(srv.get(), _))
433       .WillOnce(Invoke(
434           [&srv_client_conn, srv_connected](UnixSocket*, UnixSocket* srv_conn) {
435             srv_client_conn = srv_conn;
436             EXPECT_EQ(geteuid(), static_cast<uint32_t>(srv_conn->peer_uid()));
437 #if PERFETTO_BUILDFLAG(PERFETTO_OS_LINUX) || \
438     PERFETTO_BUILDFLAG(PERFETTO_OS_ANDROID)
439             EXPECT_EQ(getpid(), static_cast<pid_t>(srv_conn->peer_pid()));
440 #endif
441             srv_connected();
442           }));
443   auto cli_connected = task_runner_.CreateCheckpoint("cli_connected");
444   auto cli = UnixSocket::Connect(kSocketName, &event_listener_, &task_runner_,
445                                  SockFamily::kUnix, SockType::kStream);
446   EXPECT_CALL(event_listener_, OnConnect(cli.get(), true))
447       .WillOnce(InvokeWithoutArgs(cli_connected));
448 
449   task_runner_.RunUntilCheckpoint("cli_connected");
450   task_runner_.RunUntilCheckpoint("srv_connected");
451   ASSERT_NE(nullptr, srv_client_conn);
452   ASSERT_TRUE(srv_client_conn->is_connected());
453 
454   auto cli_disconnected = task_runner_.CreateCheckpoint("cli_disconnected");
455   EXPECT_CALL(event_listener_, OnDisconnect(srv_client_conn))
456       .WillOnce(InvokeWithoutArgs(cli_disconnected));
457 
458   // TODO(primiano): when the a peer disconnects, the other end receives a
459   // spurious OnDataAvailable() that needs to be acked with a Receive() to read
460   // the EOF. See b/69536434.
461   EXPECT_CALL(event_listener_, OnDataAvailable(srv_client_conn))
462       .WillOnce(Invoke([](UnixSocket* sock) { sock->ReceiveString(); }));
463 
464   cli.reset();
465   task_runner_.RunUntilCheckpoint("cli_disconnected");
466   ASSERT_FALSE(srv_client_conn->is_connected());
467   EXPECT_EQ(geteuid(), static_cast<uint32_t>(srv_client_conn->peer_uid()));
468 #if PERFETTO_BUILDFLAG(PERFETTO_OS_LINUX) || \
469     PERFETTO_BUILDFLAG(PERFETTO_OS_ANDROID)
470   EXPECT_EQ(getpid(), static_cast<pid_t>(srv_client_conn->peer_pid()));
471 #endif
472 }
473 
// A client on another thread pushes kTotalBytes through blocking Send()
// calls; the server must eventually receive every byte.
TEST_F(UnixSocketTest, BlockingSend) {
  auto srv = UnixSocket::Listen(kSocketName, &event_listener_, &task_runner_,
                                SockFamily::kUnix, SockType::kStream);
  ASSERT_TRUE(srv->is_listening());

  // The server counts received bytes and fires |all_frames_done| once exactly
  // kTotalBytes have arrived.
  auto all_frames_done = task_runner_.CreateCheckpoint("all_frames_done");
  size_t total_bytes_received = 0;
  constexpr size_t kTotalBytes = 1024 * 1024 * 4;
  EXPECT_CALL(event_listener_, OnNewIncomingConnection(srv.get(), _))
      .WillOnce(Invoke([this, &total_bytes_received, all_frames_done](
                           UnixSocket*, UnixSocket* srv_conn) {
        EXPECT_CALL(event_listener_, OnDataAvailable(srv_conn))
            .WillRepeatedly(
                Invoke([&total_bytes_received, all_frames_done](UnixSocket* s) {
                  char buf[1024];
                  size_t res = s->Receive(buf, sizeof(buf));
                  total_bytes_received += res;
                  if (total_bytes_received == kTotalBytes)
                    all_frames_done();
                }));
      }));

  // Override default timeout as this test can take time on the emulator.
  const int kTimeoutMs = 60000 * 3;

  // Perform the blocking send from another thread, with its own task runner.
  std::thread tx_thread([] {
    TestTaskRunner tx_task_runner;
    MockEventListener tx_events;
    auto cli = UnixSocket::Connect(kSocketName, &tx_events, &tx_task_runner,
                                   SockFamily::kUnix, SockType::kStream);

    auto cli_connected = tx_task_runner.CreateCheckpoint("cli_connected");
    EXPECT_CALL(tx_events, OnConnect(cli.get(), true))
        .WillOnce(InvokeWithoutArgs(cli_connected));
    tx_task_runner.RunUntilCheckpoint("cli_connected");

    auto all_sent = tx_task_runner.CreateCheckpoint("all_sent");
    char buf[1024 * 32] = {};
    tx_task_runner.PostTask([&cli, &buf, all_sent] {
      for (size_t i = 0; i < kTotalBytes / sizeof(buf); i++) {
        // Report a failed Send() right here. Otherwise it only shows up as
        // the receiver timing out while waiting for the missing bytes.
        if (!cli->Send(buf, sizeof(buf))) {
          ADD_FAILURE() << "Send() failed at chunk " << i;
          break;
        }
      }
      all_sent();
    });
    tx_task_runner.RunUntilCheckpoint("all_sent", kTimeoutMs);
  });

  task_runner_.RunUntilCheckpoint("all_frames_done", kTimeoutMs);
  tx_thread.join();
}
524 
525 // Regression test for b/76155349 . If the receiver end disconnects while the
526 // sender is in the middle of a large send(), the socket should gracefully give
527 // up (i.e. Shutdown()) but not crash.
TEST_F(UnixSocketTest, ReceiverDisconnectsDuringSend) {
  auto server = UnixSocket::Listen(kSocketName, &event_listener_, &task_runner_,
                                   SockFamily::kUnix, SockType::kStream);
  ASSERT_TRUE(server->is_listening());
  const int kTimeoutMs = 30000;

  // The receiver reads a single 1 KiB chunk, then shuts its end down without
  // notification while the sender is still inside Send().
  auto on_receive_done = task_runner_.CreateCheckpoint("receive_done");
  EXPECT_CALL(event_listener_, OnNewIncomingConnection(server.get(), _))
      .WillOnce(Invoke([this, on_receive_done](UnixSocket*,
                                               UnixSocket* accepted) {
        EXPECT_CALL(event_listener_, OnDataAvailable(accepted))
            .WillOnce(Invoke([on_receive_done](UnixSocket* sock) {
              char chunk[1024];
              const size_t got = sock->Receive(chunk, sizeof(chunk));
              ASSERT_EQ(1024u, got);
              sock->Shutdown(false /*notify*/);
              on_receive_done();
            }));
      }));

  // Perform the blocking send from another thread, with its own task runner.
  std::thread sender([] {
    TestTaskRunner tx_task_runner;
    MockEventListener tx_events;
    auto client = UnixSocket::Connect(kSocketName, &tx_events, &tx_task_runner,
                                      SockFamily::kUnix, SockType::kStream);

    auto on_cli_connected = tx_task_runner.CreateCheckpoint("cli_connected");
    EXPECT_CALL(tx_events, OnConnect(client.get(), true))
        .WillOnce(InvokeWithoutArgs(on_cli_connected));
    tx_task_runner.RunUntilCheckpoint("cli_connected");

    auto on_send_done = tx_task_runner.CreateCheckpoint("send_done");
    // The payload must be large enough that Send() is still in progress when
    // the receiver shuts down its end of the connection.
    static constexpr size_t kBufSize = 32 * 1024 * 1024;
    std::unique_ptr<char[]> payload(new char[kBufSize]());
    tx_task_runner.PostTask([&client, &payload, on_send_done] {
      // Once the peer is gone the send has to fail, not crash.
      const bool sent = client->Send(payload.get(), kBufSize);
      ASSERT_FALSE(sent);
      on_send_done();
    });

    tx_task_runner.RunUntilCheckpoint("send_done", kTimeoutMs);
  });
  task_runner_.RunUntilCheckpoint("receive_done", kTimeoutMs);
  sender.join();
}
574 
TEST_F(UnixSocketTest, ShiftMsgHdrSendPartialFirst) {
  // Shift by one byte (partially consuming the first iovec), then by the rest
  // of the first iovec, then by the whole second iovec.
  char hello[] = "hello";
  char world[] = "world";
  const size_t hello_len = base::ArraySize(hello);
  const size_t world_len = base::ArraySize(world);

  struct iovec iov[2] = {};
  iov[0].iov_base = hello;
  iov[0].iov_len = hello_len;
  iov[1].iov_base = world;
  iov[1].iov_len = world_len;

  struct msghdr hdr = {};
  hdr.msg_iov = iov;
  hdr.msg_iovlen = base::ArraySize(iov);

  // One byte of "hello" consumed: both iovecs remain, the first one trimmed.
  UnixSocketRaw::ShiftMsgHdr(1, &hdr);
  EXPECT_NE(hdr.msg_iov, nullptr);
  EXPECT_EQ(hdr.msg_iov[0].iov_base, &hello[1]);
  EXPECT_EQ(hdr.msg_iov[1].iov_base, &world[0]);
  EXPECT_EQ(static_cast<int>(hdr.msg_iovlen), 2);
  EXPECT_STREQ(static_cast<char*>(hdr.msg_iov[0].iov_base), "ello");
  EXPECT_EQ(iov[0].iov_len, hello_len - 1);

  // The rest of "hello" consumed: only the second iovec remains.
  UnixSocketRaw::ShiftMsgHdr(hello_len - 1, &hdr);
  EXPECT_EQ(hdr.msg_iov, &iov[1]);
  EXPECT_EQ(static_cast<int>(hdr.msg_iovlen), 1);
  EXPECT_STREQ(static_cast<char*>(hdr.msg_iov[0].iov_base), world);
  EXPECT_EQ(hdr.msg_iov[0].iov_len, world_len);

  // "world" consumed: nothing left.
  UnixSocketRaw::ShiftMsgHdr(world_len, &hdr);
  EXPECT_EQ(hdr.msg_iov, nullptr);
  EXPECT_EQ(static_cast<int>(hdr.msg_iovlen), 0);
}
608 
TEST_F(UnixSocketTest, ShiftMsgHdrSendFirstAndPartial) {
  // Shift past the whole first iovec plus one byte of the second one, then
  // shift by what is left of the second iovec.
  char hello[] = "hello";
  char world[] = "world";
  const size_t hello_len = base::ArraySize(hello);
  const size_t world_len = base::ArraySize(world);

  struct iovec iov[2] = {};
  iov[0].iov_base = hello;
  iov[0].iov_len = hello_len;
  iov[1].iov_base = world;
  iov[1].iov_len = world_len;

  struct msghdr hdr = {};
  hdr.msg_iov = iov;
  hdr.msg_iovlen = base::ArraySize(iov);

  // "hello" plus the "w" consumed: only the trimmed second iovec remains.
  UnixSocketRaw::ShiftMsgHdr(hello_len + 1, &hdr);
  EXPECT_NE(hdr.msg_iov, nullptr);
  EXPECT_EQ(static_cast<int>(hdr.msg_iovlen), 1);
  EXPECT_STREQ(static_cast<char*>(hdr.msg_iov[0].iov_base), "orld");
  EXPECT_EQ(hdr.msg_iov[0].iov_len, world_len - 1);

  // The remainder consumed: nothing left.
  UnixSocketRaw::ShiftMsgHdr(world_len - 1, &hdr);
  EXPECT_EQ(hdr.msg_iov, nullptr);
  EXPECT_EQ(static_cast<int>(hdr.msg_iovlen), 0);
}
634 
TEST_F(UnixSocketTest, ShiftMsgHdrSendEverything) {
  // A single shift by the combined size of all iovecs must consume them all.
  char hello[] = "hello";
  char world[] = "world";
  const size_t hello_len = base::ArraySize(hello);
  const size_t world_len = base::ArraySize(world);

  struct iovec iov[2] = {};
  iov[0].iov_base = hello;
  iov[0].iov_len = hello_len;
  iov[1].iov_base = world;
  iov[1].iov_len = world_len;

  struct msghdr hdr = {};
  hdr.msg_iov = iov;
  hdr.msg_iovlen = base::ArraySize(iov);

  const size_t total_len = hello_len + world_len;
  UnixSocketRaw::ShiftMsgHdr(total_len, &hdr);
  EXPECT_EQ(hdr.msg_iov, nullptr);
  EXPECT_EQ(static_cast<int>(hdr.msg_iovlen), 0);
}
655 
// Do-nothing signal handler. PartialSendMsgAll installs it for SIGWINCH so
// that pthread_kill() delivers a handled signal instead of the default
// disposition.
void Handler(int /*signal_number*/) {
  // Intentionally empty.
}
657 
// Reinstates |act| as the SIGWINCH disposition (the previous one is not
// saved). Used as the ScopedResource release function in PartialSendMsgAll.
// Returns the result of sigaction(): 0 on success, -1 on error.
int RollbackSigaction(const struct sigaction* act) {
  const int result = sigaction(SIGWINCH, act, /*oldact=*/nullptr);
  return result;
}
661 
// Verifies that UnixSocketRaw::SendMsgAll() completes a multi-iovec message
// whose sendmsg() gets interrupted by a signal part-way through: the call must
// return the full payload size and the receiver must get every byte intact.
TEST_F(UnixSocketTest, PartialSendMsgAll) {
  UnixSocketRaw send_sock;
  UnixSocketRaw recv_sock;
  std::tie(send_sock, recv_sock) =
      UnixSocketRaw::CreatePair(SockFamily::kUnix, SockType::kStream);
  ASSERT_TRUE(send_sock);
  ASSERT_TRUE(recv_sock);

  // Set bufsize to minimum.
  // NOTE(review): the kernel may adjust the requested size; the test only
  // needs both buffers to be much smaller than send_buf below.
  int bufsize = 1024;
  ASSERT_EQ(setsockopt(send_sock.fd(), SOL_SOCKET, SO_SNDBUF, &bufsize,
                       sizeof(bufsize)),
            0);
  ASSERT_EQ(setsockopt(recv_sock.fd(), SOL_SOCKET, SO_RCVBUF, &bufsize,
                       sizeof(bufsize)),
            0);

  // Send something larger than send + recv kernel buffers combined to make
  // sendmsg block.
  char send_buf[8192];
  // Make MSAN happy.
  for (size_t i = 0; i < sizeof(send_buf); ++i)
    send_buf[i] = static_cast<char>(i % 256);
  char recv_buf[sizeof(send_buf)];

  // Need to install signal handler to cause the interrupt to happen.
  // man 3 pthread_kill:
  //   Signal dispositions are process-wide: if a signal handler is
  //   installed, the handler will be invoked in the thread thread, but if
  //   the disposition of the signal is "stop", "continue", or "terminate",
  //   this action will affect the whole process.
  // newact is zero-initialized, so sa_flags has no SA_RESTART: the blocked
  // sendmsg() is expected to return early rather than be restarted.
  struct sigaction oldact;
  struct sigaction newact = {};
  newact.sa_handler = Handler;
  ASSERT_EQ(sigaction(SIGWINCH, &newact, &oldact), 0);
  // Restores the original SIGWINCH disposition when the test scope exits.
  base::ScopedResource<const struct sigaction*, RollbackSigaction, nullptr>
      rollback(&oldact);

  // Reader thread: waits for the first byte (proof that the main thread is
  // inside SendMsgAll), interrupts the main thread with SIGWINCH, then drains
  // the remaining bytes so the send can finish.
  auto blocked_thread = pthread_self();
  std::thread th([blocked_thread, &recv_sock, &recv_buf] {
    ssize_t rd = PERFETTO_EINTR(read(recv_sock.fd(), recv_buf, 1));
    ASSERT_EQ(rd, 1);
    // We are now sure the other thread is in sendmsg, interrupt send.
    ASSERT_EQ(pthread_kill(blocked_thread, SIGWINCH), 0);
    // Drain the socket to allow SendMsgAll to succeed.
    size_t offset = 1;
    while (offset < sizeof(recv_buf)) {
      rd = PERFETTO_EINTR(
          read(recv_sock.fd(), recv_buf + offset, sizeof(recv_buf) - offset));
      // NOTE(review): rd == 0 (EOF) passes this check but does not advance
      // offset, so the loop would spin; not expected here because send_sock
      // is only shut down after SendMsgAll() has returned.
      ASSERT_GE(rd, 0);
      offset += static_cast<size_t>(rd);
    }
  });

  // Test sending the send_buf in several chunks as an iov to exercise the
  // more complicated code-paths of SendMsgAll.
  struct msghdr hdr = {};
  struct iovec iov[4];
  static_assert(sizeof(send_buf) % base::ArraySize(iov) == 0,
                "Cannot split buffer into even pieces.");
  constexpr size_t kChunkSize = sizeof(send_buf) / base::ArraySize(iov);
  for (size_t i = 0; i < base::ArraySize(iov); ++i) {
    iov[i].iov_base = send_buf + i * kChunkSize;
    iov[i].iov_len = kChunkSize;
  }
  hdr.msg_iov = iov;
  hdr.msg_iovlen = base::ArraySize(iov);

  // NOTE(review): if this fatal assertion fails, the test returns while |th|
  // is still joinable, and std::thread's destructor then calls
  // std::terminate(), aborting the whole test binary.
  ASSERT_EQ(send_sock.SendMsgAll(&hdr), static_cast<ssize_t>(sizeof(send_buf)));
  send_sock.Shutdown();
  th.join();
  // Make sure the re-entry logic was actually triggered.
  // Presumably SendMsgAll only advances hdr.msg_iov (ending at nullptr) when it
  // had to resume after a partial write. TODO confirm against SendMsgAll.
  ASSERT_EQ(hdr.msg_iov, nullptr);
  ASSERT_EQ(memcmp(send_buf, recv_buf, sizeof(send_buf)), 0);
}
737 
// Checks that UnixSocket::ReleaseSocket() hands the connected fd over to the
// caller. After the release, no OnDataAvailable() callback may fire for it,
// even with data pending. That data must be readable directly from the
// returned UnixSocketRaw.
TEST_F(UnixSocketTest, ReleaseSocket) {
  auto srv = UnixSocket::Listen(kSocketName, &event_listener_, &task_runner_,
                                SockFamily::kUnix, SockType::kStream);
  // Listen() may return nullptr on failure. Fail the test cleanly instead of
  // dereferencing a null pointer on the next line.
  ASSERT_TRUE(srv);
  ASSERT_TRUE(srv->is_listening());
  auto connected = task_runner_.CreateCheckpoint("connected");
  UnixSocket* peer = nullptr;
  // Capture the server-side end of the accepted connection.
  EXPECT_CALL(event_listener_, OnNewIncomingConnection(srv.get(), _))
      .WillOnce(Invoke([connected, &peer](UnixSocket*, UnixSocket* new_conn) {
        peer = new_conn;
        connected();
      }));

  auto cli = UnixSocket::Connect(kSocketName, &event_listener_, &task_runner_,
                                 SockFamily::kUnix, SockType::kStream);
  EXPECT_CALL(event_listener_, OnConnect(cli.get(), true));
  task_runner_.RunUntilCheckpoint("connected");
  // Shut down the listening socket; only the accepted connection is used below.
  srv->Shutdown(true);

  // Queue data on the connection before its fd is released.
  cli->Send("test");

  ASSERT_NE(peer, nullptr);
  auto raw_sock = peer->ReleaseSocket();

  // Once released, pending data must not be dispatched through the listener.
  EXPECT_CALL(event_listener_, OnDataAvailable(_)).Times(0);
  task_runner_.RunUntilIdle();

  // NOTE(review): the sizeof("test") buffer and ASSERT_STREQ assume Send()
  // also transmits the terminating NUL. Confirm against UnixSocket::Send().
  char buf[sizeof("test")];
  ASSERT_TRUE(raw_sock);
  ASSERT_EQ(raw_sock.Receive(buf, sizeof(buf)),
            static_cast<ssize_t>(sizeof(buf)));
  ASSERT_STREQ(buf, "test");
}
770 
// Exercises UnixSocket over TCP (SockFamily::kInet). A server on 127.0.0.1
// accepts kNumClients connections and sends "welcome" to each of them. Every
// client must then receive exactly that greeting.
TEST_F(UnixSocketTest, TcpStream) {
  char host_and_port[32];
  int attempt = 0;
  std::unique_ptr<UnixSocket> srv;

  // Try listening on a random port. Some ports might be taken by other system
  // services. Do a bunch of attempts on different ports before giving up.
  do {
    // Bounded formatting: "127.0.0.1:" plus at most 5 digits fits in the
    // buffer, and snprintf guarantees no overflow regardless.
    snprintf(host_and_port, sizeof(host_and_port), "127.0.0.1:%d",
             10000 + (rand() % 10000));
    srv = UnixSocket::Listen(host_and_port, &event_listener_, &task_runner_,
                             SockFamily::kInet, SockType::kStream);
  } while ((!srv || !srv->is_listening()) && attempt++ < 10);
  // Every attempt may have produced a null socket. Check before dereferencing
  // so the test fails instead of crashing.
  ASSERT_TRUE(srv);
  ASSERT_TRUE(srv->is_listening());

  constexpr size_t kNumClients = 3;
  std::unique_ptr<UnixSocket> cli[kNumClients];
  // Server side: greet each accepted connection, then just consume the EOF
  // that arrives when the client goes away.
  EXPECT_CALL(event_listener_, OnNewIncomingConnection(srv.get(), _))
      .Times(kNumClients)
      .WillRepeatedly(Invoke([&](UnixSocket*, UnixSocket* s) {
        // OnDisconnect() might spuriously happen depending on the dtor order.
        EXPECT_CALL(event_listener_, OnDisconnect(s)).Times(AtLeast(0));
        EXPECT_CALL(event_listener_, OnDataAvailable(s))
            .WillRepeatedly(Invoke([](UnixSocket* cli_sock) {
              cli_sock->ReceiveString();  // Read connection EOF;
            }));
        ASSERT_TRUE(s->Send("welcome"));
      }));

  // Client side: each client hits its own checkpoint once "welcome" arrives.
  for (size_t i = 0; i < kNumClients; i++) {
    cli[i] = UnixSocket::Connect(host_and_port, &event_listener_, &task_runner_,
                                 SockFamily::kInet, SockType::kStream);
    // PERFETTO_ILOG("cli : %p", reinterpret_cast<void*>(cli[i].get()));
    auto checkpoint = task_runner_.CreateCheckpoint(std::to_string(i));
    EXPECT_CALL(event_listener_, OnDisconnect(cli[i].get())).Times(AtLeast(0));
    EXPECT_CALL(event_listener_, OnConnect(cli[i].get(), true));
    EXPECT_CALL(event_listener_, OnDataAvailable(cli[i].get()))
        .WillRepeatedly(Invoke([checkpoint](UnixSocket* s) {
          auto str = s->ReceiveString();
          if (str == "")
            return;  // Connection EOF.
          ASSERT_EQ("welcome", str);
          checkpoint();
        }));
  }

  for (size_t i = 0; i < kNumClients; i++) {
    task_runner_.RunUntilCheckpoint(std::to_string(i));
    // NOTE(review): cli[i] is not a mock object, so this presumably verifies
    // nothing. Confirm whether a check on &event_listener_ was intended.
    // Clearing it here would drop the remaining clients' expectations.
    ASSERT_TRUE(Mock::VerifyAndClearExpectations(cli[i].get()));
  }
}
821 
822 // TODO(primiano): add a test to check that in the case of a peer sending a fd
823 // and the other end just doing a recv (without taking it), the fd is closed and
824 // not left around.
825 
826 // TODO(primiano); add a test to check that a socket can be reused after
827 // Shutdown(),
828 
829 // TODO(primiano): add a test to check that OnDisconnect() is called in all
830 // possible cases.
831 
832 // TODO(primiano): add tests that destroy the socket in all possible stages and
833 // verify that no spurious EventListener callback is received.
834 
835 }  // namespace
836 }  // namespace base
837 }  // namespace perfetto
838