1 // Copyright 2014 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "net/socket/unix_domain_client_socket_posix.h"
6 
7 #include <unistd.h>
8 
9 #include <memory>
10 #include <utility>
11 
12 #include "base/bind.h"
13 #include "base/files/file_path.h"
14 #include "base/files/scoped_temp_dir.h"
15 #include "base/posix/eintr_wrapper.h"
16 #include "net/base/io_buffer.h"
17 #include "net/base/net_errors.h"
18 #include "net/base/sockaddr_storage.h"
19 #include "net/base/test_completion_callback.h"
20 #include "net/socket/socket_posix.h"
21 #include "net/socket/unix_domain_server_socket_posix.h"
22 #include "net/test/gtest_util.h"
23 #include "net/test/test_with_task_environment.h"
24 #include "net/traffic_annotation/network_traffic_annotation_test_helper.h"
25 #include "testing/gmock/include/gmock/gmock.h"
26 #include "testing/gtest/include/gtest/gtest.h"
27 
28 using net::test::IsError;
29 using net::test::IsOk;
30 
31 namespace net {
32 namespace {
33 
34 const char kSocketFilename[] = "socket_for_testing";
35 
UserCanConnectCallback(bool allow_user,const UnixDomainServerSocket::Credentials & credentials)36 bool UserCanConnectCallback(
37     bool allow_user, const UnixDomainServerSocket::Credentials& credentials) {
38   // Here peers are running in same process.
39 #if defined(OS_LINUX) || defined(OS_ANDROID)
40   EXPECT_EQ(getpid(), credentials.process_id);
41 #endif
42   EXPECT_EQ(getuid(), credentials.user_id);
43   EXPECT_EQ(getgid(), credentials.group_id);
44   return allow_user;
45 }
46 
CreateAuthCallback(bool allow_user)47 UnixDomainServerSocket::AuthCallback CreateAuthCallback(bool allow_user) {
48   return base::Bind(&UserCanConnectCallback, allow_user);
49 }
50 
51 // Connects socket synchronously.
ConnectSynchronously(StreamSocket * socket)52 int ConnectSynchronously(StreamSocket* socket) {
53   TestCompletionCallback connect_callback;
54   int rv = socket->Connect(connect_callback.callback());
55   if (rv == ERR_IO_PENDING)
56     rv = connect_callback.WaitForResult();
57   return rv;
58 }
59 
60 // Reads data from |socket| until it fills |buf| at least up to |min_data_len|.
61 // Returns length of data read, or a net error.
ReadSynchronously(StreamSocket * socket,IOBuffer * buf,int buf_len,int min_data_len)62 int ReadSynchronously(StreamSocket* socket,
63                       IOBuffer* buf,
64                       int buf_len,
65                       int min_data_len) {
66   DCHECK_LE(min_data_len, buf_len);
67   scoped_refptr<DrainableIOBuffer> read_buf =
68       base::MakeRefCounted<DrainableIOBuffer>(buf, buf_len);
69   TestCompletionCallback read_callback;
70   // Iterate reading several times (but not infinite) until it reads at least
71   // |min_data_len| bytes into |buf|.
72   for (int retry_count = 10;
73        retry_count > 0 && (read_buf->BytesConsumed() < min_data_len ||
74                            // Try at least once when min_data_len == 0.
75                            min_data_len == 0);
76        --retry_count) {
77     int rv = socket->Read(
78         read_buf.get(), read_buf->BytesRemaining(), read_callback.callback());
79     EXPECT_GE(read_buf->BytesRemaining(), rv);
80     if (rv == ERR_IO_PENDING) {
81       // If |min_data_len| is 0, returns ERR_IO_PENDING to distinguish the case
82       // when some data has been read.
83       if (min_data_len == 0) {
84         // No data has been read because of for-loop condition.
85         DCHECK_EQ(0, read_buf->BytesConsumed());
86         return ERR_IO_PENDING;
87       }
88       rv = read_callback.WaitForResult();
89     }
90     EXPECT_NE(ERR_IO_PENDING, rv);
91     if (rv < 0)
92       return rv;
93     read_buf->DidConsume(rv);
94   }
95   EXPECT_LE(0, read_buf->BytesRemaining());
96   return read_buf->BytesConsumed();
97 }
98 
99 // Writes data to |socket| until it completes writing |buf| up to |buf_len|.
100 // Returns length of data written, or a net error.
WriteSynchronously(StreamSocket * socket,IOBuffer * buf,int buf_len)101 int WriteSynchronously(StreamSocket* socket,
102                        IOBuffer* buf,
103                        int buf_len) {
104   scoped_refptr<DrainableIOBuffer> write_buf =
105       base::MakeRefCounted<DrainableIOBuffer>(buf, buf_len);
106   TestCompletionCallback write_callback;
107   // Iterate writing several times (but not infinite) until it writes buf fully.
108   for (int retry_count = 10;
109        retry_count > 0 && write_buf->BytesRemaining() > 0;
110        --retry_count) {
111     int rv =
112         socket->Write(write_buf.get(), write_buf->BytesRemaining(),
113                       write_callback.callback(), TRAFFIC_ANNOTATION_FOR_TESTS);
114     EXPECT_GE(write_buf->BytesRemaining(), rv);
115     if (rv == ERR_IO_PENDING)
116       rv = write_callback.WaitForResult();
117     EXPECT_NE(ERR_IO_PENDING, rv);
118     if (rv < 0)
119       return rv;
120     write_buf->DidConsume(rv);
121   }
122   EXPECT_LE(0, write_buf->BytesRemaining());
123   return write_buf->BytesConsumed();
124 }
125 
126 class UnixDomainClientSocketTest : public TestWithTaskEnvironment {
127  protected:
UnixDomainClientSocketTest()128   UnixDomainClientSocketTest() {
129     EXPECT_TRUE(temp_dir_.CreateUniqueTempDir());
130     socket_path_ = temp_dir_.GetPath().Append(kSocketFilename).value();
131   }
132 
133   base::ScopedTempDir temp_dir_;
134   std::string socket_path_;
135 };
136 
// Connects a client to a server socket bound to a filesystem path and checks
// that the server's pending Accept() completes with a connected socket.
TEST_F(UnixDomainClientSocketTest, Connect) {
  const bool kUseAbstractNamespace = false;

  UnixDomainServerSocket server_socket(CreateAuthCallback(true),
                                       kUseAbstractNamespace);
  EXPECT_THAT(server_socket.BindAndListen(socket_path_, /*backlog=*/1), IsOk());

  std::unique_ptr<StreamSocket> accepted_socket;
  TestCompletionCallback accept_callback;
  EXPECT_EQ(ERR_IO_PENDING,
            server_socket.Accept(&accepted_socket, accept_callback.callback()));
  EXPECT_FALSE(accepted_socket);

  UnixDomainClientSocket client_socket(socket_path_, kUseAbstractNamespace);
  EXPECT_FALSE(client_socket.IsConnected());

  EXPECT_THAT(ConnectSynchronously(&client_socket), IsOk());
  EXPECT_TRUE(client_socket.IsConnected());
  // Server has not yet been notified of the connection.
  EXPECT_FALSE(accepted_socket);

  EXPECT_THAT(accept_callback.WaitForResult(), IsOk());
  // Fatal assertion: |accepted_socket| is dereferenced on the next line, so
  // the test must stop here instead of crashing if no socket was delivered.
  ASSERT_TRUE(accepted_socket);
  EXPECT_TRUE(accepted_socket->IsConnected());
}
162 
// Accepts the connection as a raw socket descriptor on the server side, then
// releases the client's descriptor, re-wraps it in a new
// UnixDomainClientSocket and checks that the connection is still usable.
TEST_F(UnixDomainClientSocketTest, ConnectWithSocketDescriptor) {
  const bool kUseAbstractNamespace = false;

  UnixDomainServerSocket server_socket(CreateAuthCallback(true),
                                       kUseAbstractNamespace);
  EXPECT_THAT(server_socket.BindAndListen(socket_path_, /*backlog=*/1), IsOk());

  SocketDescriptor accepted_socket_fd = kInvalidSocket;
  TestCompletionCallback accept_callback;
  EXPECT_EQ(ERR_IO_PENDING,
            server_socket.AcceptSocketDescriptor(&accepted_socket_fd,
                                                 accept_callback.callback()));
  EXPECT_EQ(kInvalidSocket, accepted_socket_fd);

  UnixDomainClientSocket client_socket(socket_path_, kUseAbstractNamespace);
  EXPECT_FALSE(client_socket.IsConnected());

  EXPECT_THAT(ConnectSynchronously(&client_socket), IsOk());
  EXPECT_TRUE(client_socket.IsConnected());
  // Server has not yet been notified of the connection.
  EXPECT_EQ(kInvalidSocket, accepted_socket_fd);

  EXPECT_THAT(accept_callback.WaitForResult(), IsOk());
  EXPECT_NE(kInvalidSocket, accepted_socket_fd);

  SocketDescriptor client_socket_fd = client_socket.ReleaseConnectedSocket();
  EXPECT_NE(kInvalidSocket, client_socket_fd);

  // Now, re-wrap client_socket_fd in a UnixDomainClientSocket and try a read
  // to be sure it hasn't gotten accidentally closed.
  SockaddrStorage addr;
  ASSERT_TRUE(UnixDomainClientSocket::FillAddress(
      socket_path_, kUseAbstractNamespace, &addr));
  auto adopter = std::make_unique<SocketPosix>();
  adopter->AdoptConnectedSocket(client_socket_fd, addr);
  UnixDomainClientSocket rewrapped_socket(std::move(adopter));
  EXPECT_TRUE(rewrapped_socket.IsConnected());

  // Try to read data. Nothing has been written by the peer, so the read is
  // expected to stay pending.
  const int kReadDataSize = 10;
  scoped_refptr<IOBuffer> read_buffer =
      base::MakeRefCounted<IOBuffer>(kReadDataSize);
  TestCompletionCallback read_callback;
  EXPECT_EQ(ERR_IO_PENDING,
            rewrapped_socket.Read(
                read_buffer.get(), kReadDataSize, read_callback.callback()));

  // The accepted descriptor is a raw fd owned by this test; close it here.
  EXPECT_EQ(0, IGNORE_EINTR(close(accepted_socket_fd)));
}
211 
// Connects through an abstract-namespace socket address. On Android and Linux
// the connection is expected to succeed; on other platforms the connect is
// expected to fail with ERR_ADDRESS_INVALID.
TEST_F(UnixDomainClientSocketTest, ConnectWithAbstractNamespace) {
  const bool kUseAbstractNamespace = true;

  UnixDomainClientSocket client_socket(socket_path_, kUseAbstractNamespace);
  EXPECT_FALSE(client_socket.IsConnected());

#if defined(OS_ANDROID) || defined(OS_LINUX)
  UnixDomainServerSocket server_socket(CreateAuthCallback(true),
                                       kUseAbstractNamespace);
  EXPECT_THAT(server_socket.BindAndListen(socket_path_, /*backlog=*/1), IsOk());

  std::unique_ptr<StreamSocket> accepted_socket;
  TestCompletionCallback accept_callback;
  EXPECT_EQ(ERR_IO_PENDING,
            server_socket.Accept(&accepted_socket, accept_callback.callback()));
  EXPECT_FALSE(accepted_socket);

  EXPECT_THAT(ConnectSynchronously(&client_socket), IsOk());
  EXPECT_TRUE(client_socket.IsConnected());
  // Server has not yet been notified of the connection.
  EXPECT_FALSE(accepted_socket);

  EXPECT_THAT(accept_callback.WaitForResult(), IsOk());
  // Fatal assertion: |accepted_socket| is dereferenced on the next line, so
  // the test must stop here instead of crashing if no socket was delivered.
  ASSERT_TRUE(accepted_socket);
  EXPECT_TRUE(accepted_socket->IsConnected());
#else
  EXPECT_THAT(ConnectSynchronously(&client_socket),
              IsError(ERR_ADDRESS_INVALID));
#endif
}
242 
// With no server bound to the socket path, connecting must fail with
// ERR_FILE_NOT_FOUND.
TEST_F(UnixDomainClientSocketTest, ConnectToNonExistentSocket) {
  const bool kUseAbstractNamespace = false;

  UnixDomainClientSocket client_socket(socket_path_, kUseAbstractNamespace);
  EXPECT_FALSE(client_socket.IsConnected());

  const int connect_result = ConnectSynchronously(&client_socket);
  EXPECT_THAT(connect_result, IsError(ERR_FILE_NOT_FOUND));
}
251 
// With no server listening on the abstract-namespace address, connecting is
// expected to fail: ERR_CONNECTION_REFUSED on Android and Linux, and
// ERR_ADDRESS_INVALID on other platforms.
TEST_F(UnixDomainClientSocketTest,
       ConnectToNonExistentSocketWithAbstractNamespace) {
  const bool kUseAbstractNamespace = true;

  UnixDomainClientSocket client_socket(socket_path_, kUseAbstractNamespace);
  EXPECT_FALSE(client_socket.IsConnected());

#if defined(OS_ANDROID) || defined(OS_LINUX)
  EXPECT_THAT(ConnectSynchronously(&client_socket),
              IsError(ERR_CONNECTION_REFUSED));
#else
  EXPECT_THAT(ConnectSynchronously(&client_socket),
              IsError(ERR_ADDRESS_INVALID));
#endif
}
268 
// Disconnects from the client side while the server-side socket has a read
// pending, and checks that the pending read completes with EOF.
TEST_F(UnixDomainClientSocketTest, DisconnectFromClient) {
  UnixDomainServerSocket server_socket(CreateAuthCallback(true), false);
  EXPECT_THAT(server_socket.BindAndListen(socket_path_, /*backlog=*/1), IsOk());
  std::unique_ptr<StreamSocket> accepted_socket;
  TestCompletionCallback accept_callback;
  EXPECT_EQ(ERR_IO_PENDING,
            server_socket.Accept(&accepted_socket, accept_callback.callback()));
  UnixDomainClientSocket client_socket(socket_path_, false);
  EXPECT_THAT(ConnectSynchronously(&client_socket), IsOk());

  EXPECT_THAT(accept_callback.WaitForResult(), IsOk());
  // Fatal assertion: |accepted_socket| is dereferenced throughout the rest of
  // the test, so stop here instead of crashing if no socket was delivered.
  ASSERT_TRUE(accepted_socket);
  EXPECT_TRUE(accepted_socket->IsConnected());
  EXPECT_TRUE(client_socket.IsConnected());

  // Try to read data.
  const int kReadDataSize = 10;
  scoped_refptr<IOBuffer> read_buffer =
      base::MakeRefCounted<IOBuffer>(kReadDataSize);
  TestCompletionCallback read_callback;
  EXPECT_EQ(ERR_IO_PENDING,
            accepted_socket->Read(
                read_buffer.get(), kReadDataSize, read_callback.callback()));

  // Disconnect from client side.
  client_socket.Disconnect();
  EXPECT_FALSE(client_socket.IsConnected());
  EXPECT_FALSE(accepted_socket->IsConnected());

  // Connection closed by peer.
  EXPECT_EQ(0 /* EOF */, read_callback.WaitForResult());
  // Note that read callback won't be called when the connection is closed
  // locally before the peer closes it. SocketPosix just clears callbacks.
}
302 
// Disconnects from the server side while the client socket has a read
// pending, and checks that the pending read completes with EOF.
TEST_F(UnixDomainClientSocketTest, DisconnectFromServer) {
  UnixDomainServerSocket server_socket(CreateAuthCallback(true), false);
  EXPECT_THAT(server_socket.BindAndListen(socket_path_, /*backlog=*/1), IsOk());
  std::unique_ptr<StreamSocket> accepted_socket;
  TestCompletionCallback accept_callback;
  EXPECT_EQ(ERR_IO_PENDING,
            server_socket.Accept(&accepted_socket, accept_callback.callback()));
  UnixDomainClientSocket client_socket(socket_path_, false);
  EXPECT_THAT(ConnectSynchronously(&client_socket), IsOk());

  EXPECT_THAT(accept_callback.WaitForResult(), IsOk());
  // Fatal assertion: |accepted_socket| is dereferenced throughout the rest of
  // the test, so stop here instead of crashing if no socket was delivered.
  ASSERT_TRUE(accepted_socket);
  EXPECT_TRUE(accepted_socket->IsConnected());
  EXPECT_TRUE(client_socket.IsConnected());

  // Try to read data.
  const int kReadDataSize = 10;
  scoped_refptr<IOBuffer> read_buffer =
      base::MakeRefCounted<IOBuffer>(kReadDataSize);
  TestCompletionCallback read_callback;
  EXPECT_EQ(ERR_IO_PENDING,
            client_socket.Read(
                read_buffer.get(), kReadDataSize, read_callback.callback()));

  // Disconnect from server side.
  accepted_socket->Disconnect();
  EXPECT_FALSE(accepted_socket->IsConnected());
  EXPECT_FALSE(client_socket.IsConnected());

  // Connection closed by peer.
  EXPECT_EQ(0 /* EOF */, read_callback.WaitForResult());
  // Note that read callback won't be called when the connection is closed
  // locally before the peer closes it. SocketPosix just clears callbacks.
}
336 
TEST_F(UnixDomainClientSocketTest,ReadAfterWrite)337 TEST_F(UnixDomainClientSocketTest, ReadAfterWrite) {
338   UnixDomainServerSocket server_socket(CreateAuthCallback(true), false);
339   EXPECT_THAT(server_socket.BindAndListen(socket_path_, /*backlog=*/1), IsOk());
340   std::unique_ptr<StreamSocket> accepted_socket;
341   TestCompletionCallback accept_callback;
342   EXPECT_EQ(ERR_IO_PENDING,
343             server_socket.Accept(&accepted_socket, accept_callback.callback()));
344   UnixDomainClientSocket client_socket(socket_path_, false);
345   EXPECT_THAT(ConnectSynchronously(&client_socket), IsOk());
346 
347   EXPECT_THAT(accept_callback.WaitForResult(), IsOk());
348   EXPECT_TRUE(accepted_socket->IsConnected());
349   EXPECT_TRUE(client_socket.IsConnected());
350 
351   // Send data from client to server.
352   const int kWriteDataSize = 10;
353   scoped_refptr<IOBuffer> write_buffer =
354       base::MakeRefCounted<StringIOBuffer>(std::string(kWriteDataSize, 'd'));
355   EXPECT_EQ(
356       kWriteDataSize,
357       WriteSynchronously(&client_socket, write_buffer.get(), kWriteDataSize));
358 
359   // The buffer is bigger than write data size.
360   const int kReadBufferSize = kWriteDataSize * 2;
361   scoped_refptr<IOBuffer> read_buffer =
362       base::MakeRefCounted<IOBuffer>(kReadBufferSize);
363   EXPECT_EQ(kWriteDataSize,
364             ReadSynchronously(accepted_socket.get(),
365                               read_buffer.get(),
366                               kReadBufferSize,
367                               kWriteDataSize));
368   EXPECT_EQ(std::string(write_buffer->data(), kWriteDataSize),
369             std::string(read_buffer->data(), kWriteDataSize));
370 
371   // Send data from server and client.
372   EXPECT_EQ(kWriteDataSize,
373             WriteSynchronously(
374                 accepted_socket.get(), write_buffer.get(), kWriteDataSize));
375 
376   // Read multiple times.
377   const int kSmallReadBufferSize = kWriteDataSize / 3;
378   EXPECT_EQ(kSmallReadBufferSize,
379             ReadSynchronously(&client_socket,
380                               read_buffer.get(),
381                               kSmallReadBufferSize,
382                               kSmallReadBufferSize));
383   EXPECT_EQ(std::string(write_buffer->data(), kSmallReadBufferSize),
384             std::string(read_buffer->data(), kSmallReadBufferSize));
385 
386   EXPECT_EQ(kWriteDataSize - kSmallReadBufferSize,
387             ReadSynchronously(&client_socket,
388                               read_buffer.get(),
389                               kReadBufferSize,
390                               kWriteDataSize - kSmallReadBufferSize));
391   EXPECT_EQ(std::string(write_buffer->data() + kSmallReadBufferSize,
392                         kWriteDataSize - kSmallReadBufferSize),
393             std::string(read_buffer->data(),
394                         kWriteDataSize - kSmallReadBufferSize));
395 
396   // No more data.
397   EXPECT_EQ(
398       ERR_IO_PENDING,
399       ReadSynchronously(&client_socket, read_buffer.get(), kReadBufferSize, 0));
400 
401   // Disconnect from server side after read-write.
402   accepted_socket->Disconnect();
403   EXPECT_FALSE(accepted_socket->IsConnected());
404   EXPECT_FALSE(client_socket.IsConnected());
405 }
406 
// Starts a read on the server side before the client writes anything, then
// writes from the client and checks that the pending read completes and the
// remaining data can be read afterwards.
TEST_F(UnixDomainClientSocketTest, ReadBeforeWrite) {
  UnixDomainServerSocket server_socket(CreateAuthCallback(true), false);
  EXPECT_THAT(server_socket.BindAndListen(socket_path_, /*backlog=*/1), IsOk());
  std::unique_ptr<StreamSocket> accepted_socket;
  TestCompletionCallback accept_callback;
  EXPECT_EQ(ERR_IO_PENDING,
            server_socket.Accept(&accepted_socket, accept_callback.callback()));
  UnixDomainClientSocket client_socket(socket_path_, false);
  EXPECT_THAT(ConnectSynchronously(&client_socket), IsOk());

  EXPECT_THAT(accept_callback.WaitForResult(), IsOk());
  // Fatal assertion: |accepted_socket| is dereferenced throughout the rest of
  // the test, so stop here instead of crashing if no socket was delivered.
  ASSERT_TRUE(accepted_socket);
  EXPECT_TRUE(accepted_socket->IsConnected());
  EXPECT_TRUE(client_socket.IsConnected());

  // Wait for data from client.
  const int kWriteDataSize = 10;
  const int kReadBufferSize = kWriteDataSize * 2;
  const int kSmallReadBufferSize = kWriteDataSize / 3;
  // Read smaller than write data size first.
  scoped_refptr<IOBuffer> read_buffer =
      base::MakeRefCounted<IOBuffer>(kReadBufferSize);
  TestCompletionCallback read_callback;
  EXPECT_EQ(
      ERR_IO_PENDING,
      accepted_socket->Read(
          read_buffer.get(), kSmallReadBufferSize, read_callback.callback()));

  scoped_refptr<IOBuffer> write_buffer =
      base::MakeRefCounted<StringIOBuffer>(std::string(kWriteDataSize, 'd'));
  EXPECT_EQ(
      kWriteDataSize,
      WriteSynchronously(&client_socket, write_buffer.get(), kWriteDataSize));

  // First read completed. The result feeds the size arithmetic below, so a
  // failed or empty read must abort the test rather than continue.
  int rv = read_callback.WaitForResult();
  ASSERT_LT(0, rv);
  EXPECT_LE(rv, kSmallReadBufferSize);

  // Read remaining data.
  const int expected_remaining_data_size = kWriteDataSize - rv;
  EXPECT_LE(0, expected_remaining_data_size);
  EXPECT_EQ(expected_remaining_data_size,
            ReadSynchronously(accepted_socket.get(),
                              read_buffer.get(),
                              kReadBufferSize,
                              expected_remaining_data_size));
  // No more data.
  EXPECT_EQ(ERR_IO_PENDING,
            ReadSynchronously(
                accepted_socket.get(), read_buffer.get(), kReadBufferSize, 0));

  // Disconnect from server side after read-write.
  accepted_socket->Disconnect();
  EXPECT_FALSE(accepted_socket->IsConnected());
  EXPECT_FALSE(client_socket.IsConnected());
}
463 
464 }  // namespace
465 }  // namespace net
466