1 // Copyright 2018 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include <stdint.h>
6 
7 #include <utility>
8 #include <vector>
9 
10 #include "base/callback_helpers.h"
11 #include "base/location.h"
12 #include "base/logging.h"
13 #include "base/macros.h"
14 #include "base/run_loop.h"
15 #include "base/test/task_environment.h"
16 #include "mojo/public/cpp/system/simple_watcher.h"
17 #include "net/base/net_errors.h"
18 #include "net/base/test_completion_callback.h"
19 #include "net/socket/socket_test_util.h"
20 #include "net/traffic_annotation/network_traffic_annotation_test_helper.h"
21 #include "net/url_request/url_request_test_util.h"
22 #include "services/network/mojo_socket_test_util.h"
23 #include "services/network/public/mojom/network_service.mojom.h"
24 #include "services/network/public/mojom/udp_socket.mojom.h"
25 #include "services/network/socket_factory.h"
26 #include "services/network/tcp_connected_socket.h"
27 #include "services/network/tcp_server_socket.h"
28 #include "testing/gtest/include/gtest/gtest.h"
29 
30 namespace network {
31 
32 // Test delegate to wait on network read/write errors.
33 class TestSocketDataPumpDelegate : public SocketDataPump::Delegate {
34  public:
TestSocketDataPumpDelegate()35   TestSocketDataPumpDelegate() {}
~TestSocketDataPumpDelegate()36   ~TestSocketDataPumpDelegate() {}
37 
38   // Waits for read error. Returns the error observed.
WaitForReadError()39   int WaitForReadError() {
40     read_loop_.Run();
41     int error = read_error_;
42     read_error_ = net::OK;
43     return error;
44   }
45 
46   // Waits for write error. Returns the error observed.
WaitForWriteError()47   int WaitForWriteError() {
48     write_loop_.Run();
49     int error = write_error_;
50     write_error_ = net::OK;
51     return error;
52   }
53 
54   // Waits for shutdown.
WaitForShutdown()55   void WaitForShutdown() { shutdown_loop_.Run(); }
56 
57  private:
OnNetworkReadError(int error)58   void OnNetworkReadError(int error) override {
59     read_error_ = error;
60     read_loop_.Quit();
61   }
OnNetworkWriteError(int error)62   void OnNetworkWriteError(int error) override {
63     write_error_ = error;
64     write_loop_.Quit();
65   }
OnShutdown()66   void OnShutdown() override { shutdown_loop_.Quit(); }
67 
68   int read_error_ = net::OK;
69   int write_error_ = net::OK;
70   base::RunLoop read_loop_;
71   base::RunLoop write_loop_;
72   base::RunLoop shutdown_loop_;
73 
74   DISALLOW_COPY_AND_ASSIGN(TestSocketDataPumpDelegate);
75 };
76 
77 class SocketDataPumpTest : public testing::Test,
78                            public ::testing::WithParamInterface<net::IoMode> {
79  public:
SocketDataPumpTest()80   SocketDataPumpTest()
81       : task_environment_(base::test::TaskEnvironment::MainThreadType::IO) {}
~SocketDataPumpTest()82   ~SocketDataPumpTest() override {}
83 
84   // Initializes the test case with a socket data provider, which will be used
85   // to populate the read/write data of the mock socket.
Init(net::StaticSocketDataProvider * data_provider)86   void Init(net::StaticSocketDataProvider* data_provider) {
87     mock_client_socket_factory_.AddSocketDataProvider(data_provider);
88     mock_client_socket_factory_.set_enable_read_if_ready(true);
89     mojo::DataPipe send_pipe;
90     mojo::DataPipe receive_pipe;
91     receive_handle_ = std::move(receive_pipe.consumer_handle);
92     send_handle_ = std::move(send_pipe.producer_handle);
93     socket_ = mock_client_socket_factory_.CreateTransportClientSocket(
94         net::AddressList(), nullptr /*socket_performance_watcher*/,
95         nullptr /*netlog*/, net::NetLogSource());
96     net::TestCompletionCallback callback;
97     int result = socket_->Connect(callback.callback());
98     if (result == net::ERR_IO_PENDING)
99       result = callback.WaitForResult();
100     EXPECT_EQ(net::OK, result);
101     data_pump_ = std::make_unique<SocketDataPump>(
102         socket_.get(), delegate(), std::move(receive_pipe.producer_handle),
103         std::move(send_pipe.consumer_handle), TRAFFIC_ANNOTATION_FOR_TESTS);
104   }
105 
106   // Reads |num_bytes| from |handle| or reads until an error occurs. Returns the
107   // bytes read as a string.
Read(mojo::ScopedDataPipeConsumerHandle * handle,size_t num_bytes)108   std::string Read(mojo::ScopedDataPipeConsumerHandle* handle,
109                    size_t num_bytes) {
110     std::string received_contents;
111     while (received_contents.size() < num_bytes) {
112       base::RunLoop().RunUntilIdle();
113       std::vector<char> buffer(num_bytes);
114       uint32_t read_size = static_cast<uint32_t>(num_bytes);
115       MojoResult result = handle->get().ReadData(buffer.data(), &read_size,
116                                                  MOJO_READ_DATA_FLAG_NONE);
117       if (result == MOJO_RESULT_SHOULD_WAIT)
118         continue;
119       if (result != MOJO_RESULT_OK)
120         return received_contents;
121       received_contents.append(buffer.data(), read_size);
122     }
123     return received_contents;
124   }
125 
delegate()126   TestSocketDataPumpDelegate* delegate() { return &test_delegate_; }
127 
128   mojo::ScopedDataPipeConsumerHandle receive_handle_;
129   mojo::ScopedDataPipeProducerHandle send_handle_;
130 
131  private:
132   base::test::TaskEnvironment task_environment_;
133   net::MockClientSocketFactory mock_client_socket_factory_;
134   TestSocketDataPumpDelegate test_delegate_;
135   std::unique_ptr<net::StreamSocket> socket_;
136   std::unique_ptr<SocketDataPump> data_pump_;
137 
138   DISALLOW_COPY_AND_ASSIGN(SocketDataPumpTest);
139 };
140 
141 INSTANTIATE_TEST_SUITE_P(All,
142                          SocketDataPumpTest,
143                          testing::Values(net::SYNCHRONOUS, net::ASYNC));
144 
TEST_P(SocketDataPumpTest,ReadAndWriteMultiple)145 TEST_P(SocketDataPumpTest, ReadAndWriteMultiple) {
146   const char kTestMsg[] = "abcdefghij";
147   const size_t kMsgSize = strlen(kTestMsg);
148   const int kNumIterations = 3;
149   std::vector<net::MockRead> reads;
150   std::vector<net::MockWrite> writes;
151   int sequence_number = 0;
152   net::IoMode mode = GetParam();
153   for (int j = 0; j < kNumIterations; ++j) {
154     for (size_t i = 0; i < kMsgSize; ++i) {
155       reads.push_back(net::MockRead(mode, &kTestMsg[i], 1, sequence_number++));
156     }
157     if (j == kNumIterations - 1) {
158       reads.push_back(net::MockRead(mode, net::OK, sequence_number++));
159     }
160     for (size_t i = 0; i < kMsgSize; ++i) {
161       writes.push_back(
162           net::MockWrite(mode, &kTestMsg[i], 1, sequence_number++));
163     }
164   }
165   net::StaticSocketDataProvider data_provider(reads, writes);
166   Init(&data_provider);
167   // Loop kNumIterations times to test that writes can follow reads, and reads
168   // can follow writes.
169   for (int j = 0; j < kNumIterations; ++j) {
170     // Reading kMsgSize should coalesce the 1-byte mock reads.
171     EXPECT_EQ(kTestMsg, Read(&receive_handle_, kMsgSize));
172     // Write multiple times.
173     for (size_t i = 0; i < kMsgSize; ++i) {
174       uint32_t num_bytes = 1;
175       EXPECT_EQ(MOJO_RESULT_OK,
176                 send_handle_->WriteData(&kTestMsg[i], &num_bytes,
177                                         MOJO_WRITE_DATA_FLAG_NONE));
178       // Flush the 1 byte write.
179       base::RunLoop().RunUntilIdle();
180     }
181   }
182   EXPECT_TRUE(data_provider.AllReadDataConsumed());
183   EXPECT_TRUE(data_provider.AllWriteDataConsumed());
184 }
185 
TEST_P(SocketDataPumpTest,PartialStreamSocketWrite)186 TEST_P(SocketDataPumpTest, PartialStreamSocketWrite) {
187   const char kTestMsg[] = "abcdefghij";
188   const size_t kMsgSize = strlen(kTestMsg);
189   const int kNumIterations = 3;
190   std::vector<net::MockRead> reads;
191   std::vector<net::MockWrite> writes;
192   int sequence_number = 0;
193   net::IoMode mode = GetParam();
194   for (int j = 0; j < kNumIterations; ++j) {
195     for (size_t i = 0; i < kMsgSize; ++i) {
196       reads.push_back(net::MockRead(mode, &kTestMsg[i], 1, sequence_number++));
197     }
198     if (j == kNumIterations - 1) {
199       reads.push_back(net::MockRead(mode, net::OK, sequence_number++));
200     }
201     for (size_t i = 0; i < kMsgSize; ++i) {
202       writes.push_back(
203           net::MockWrite(mode, &kTestMsg[i], 1, sequence_number++));
204     }
205   }
206   net::StaticSocketDataProvider data_provider(reads, writes);
207   Init(&data_provider);
208   // Loop kNumIterations times to test that writes can follow reads, and reads
209   // can follow writes.
210   for (int j = 0; j < kNumIterations; ++j) {
211     // Reading kMsgSize should coalesce the 1-byte mock reads.
212     EXPECT_EQ(kTestMsg, Read(&receive_handle_, kMsgSize));
213     // Write twice, each with kMsgSize/2 bytes which is bigger than the 1-byte
214     // MockWrite(). This is to exercise that StreamSocket::Write() can do
215     // partial write.
216     uint32_t first_write_size = kMsgSize / 2;
217     EXPECT_EQ(MOJO_RESULT_OK,
218               send_handle_->WriteData(&kTestMsg[0], &first_write_size,
219                                       MOJO_WRITE_DATA_FLAG_NONE));
220     EXPECT_EQ(kMsgSize / 2, first_write_size);
221     // Flush the kMsgSize/2 byte write.
222     base::RunLoop().RunUntilIdle();
223     uint32_t second_write_size = kMsgSize - first_write_size;
224     EXPECT_EQ(
225         MOJO_RESULT_OK,
226         send_handle_->WriteData(&kTestMsg[first_write_size], &second_write_size,
227                                 MOJO_WRITE_DATA_FLAG_NONE));
228     EXPECT_EQ(kMsgSize - first_write_size, second_write_size);
229     // Flush the kMsgSize/2 byte write.
230     base::RunLoop().RunUntilIdle();
231   }
232   EXPECT_TRUE(data_provider.AllReadDataConsumed());
233   EXPECT_TRUE(data_provider.AllWriteDataConsumed());
234 }
235 
TEST_P(SocketDataPumpTest,ReadEof)236 TEST_P(SocketDataPumpTest, ReadEof) {
237   net::IoMode mode = GetParam();
238   net::MockRead reads[] = {net::MockRead(mode, net::OK)};
239   const char kTestMsg[] = "hello!";
240   net::MockWrite writes[] = {
241       net::MockWrite(mode, kTestMsg, strlen(kTestMsg), 0)};
242   net::StaticSocketDataProvider data_provider(reads, writes);
243   Init(&data_provider);
244   EXPECT_EQ("", Read(&receive_handle_, 1));
245   EXPECT_EQ(net::OK, delegate()->WaitForReadError());
246   // Writes can proceed even though there is a read error.
247   uint32_t num_bytes = strlen(kTestMsg);
248   EXPECT_EQ(MOJO_RESULT_OK, send_handle_->WriteData(&kTestMsg, &num_bytes,
249                                                     MOJO_WRITE_DATA_FLAG_NONE));
250   EXPECT_EQ(strlen(kTestMsg), num_bytes);
251 
252   base::RunLoop().RunUntilIdle();
253   EXPECT_TRUE(data_provider.AllReadDataConsumed());
254   EXPECT_TRUE(data_provider.AllWriteDataConsumed());
255 }
256 
TEST_P(SocketDataPumpTest,ReadError)257 TEST_P(SocketDataPumpTest, ReadError) {
258   net::IoMode mode = GetParam();
259   net::MockRead reads[] = {net::MockRead(mode, net::ERR_FAILED)};
260   const char kTestMsg[] = "hello!";
261   net::MockWrite writes[] = {
262       net::MockWrite(mode, kTestMsg, strlen(kTestMsg), 0)};
263   net::StaticSocketDataProvider data_provider(reads, writes);
264   Init(&data_provider);
265   EXPECT_EQ("", Read(&receive_handle_, 1));
266   EXPECT_EQ(net::ERR_FAILED, delegate()->WaitForReadError());
267   // Writes can proceed even though there is a read error.
268   uint32_t num_bytes = strlen(kTestMsg);
269   EXPECT_EQ(MOJO_RESULT_OK, send_handle_->WriteData(&kTestMsg, &num_bytes,
270                                                     MOJO_WRITE_DATA_FLAG_NONE));
271   EXPECT_EQ(strlen(kTestMsg), num_bytes);
272 
273   base::RunLoop().RunUntilIdle();
274   EXPECT_TRUE(data_provider.AllReadDataConsumed());
275   EXPECT_TRUE(data_provider.AllWriteDataConsumed());
276 }
277 
TEST_P(SocketDataPumpTest,WriteEof)278 TEST_P(SocketDataPumpTest, WriteEof) {
279   net::IoMode mode = GetParam();
280   const char kTestMsg[] = "hello!";
281   net::MockRead reads[] = {net::MockRead(mode, kTestMsg, strlen(kTestMsg), 0),
282                            net::MockRead(mode, net::OK)};
283   net::MockWrite writes[] = {net::MockWrite(mode, net::OK)};
284   net::StaticSocketDataProvider data_provider(reads, writes);
285   Init(&data_provider);
286   uint32_t num_bytes = strlen(kTestMsg);
287   EXPECT_EQ(MOJO_RESULT_OK, send_handle_->WriteData(&kTestMsg, &num_bytes,
288                                                     MOJO_WRITE_DATA_FLAG_NONE));
289   EXPECT_EQ(strlen(kTestMsg), num_bytes);
290   EXPECT_EQ(net::OK, delegate()->WaitForWriteError());
291   // Reads can proceed even though there is a read error.
292   EXPECT_EQ(kTestMsg, Read(&receive_handle_, strlen(kTestMsg)));
293 
294   base::RunLoop().RunUntilIdle();
295   EXPECT_TRUE(data_provider.AllReadDataConsumed());
296   EXPECT_TRUE(data_provider.AllWriteDataConsumed());
297 }
298 
TEST_P(SocketDataPumpTest,WriteError)299 TEST_P(SocketDataPumpTest, WriteError) {
300   net::IoMode mode = GetParam();
301   const char kTestMsg[] = "hello!";
302   net::MockRead reads[] = {net::MockRead(mode, kTestMsg, strlen(kTestMsg), 0),
303                            net::MockRead(mode, net::OK)};
304   net::MockWrite writes[] = {net::MockWrite(mode, net::ERR_FAILED)};
305   net::StaticSocketDataProvider data_provider(reads, writes);
306   Init(&data_provider);
307   uint32_t num_bytes = strlen(kTestMsg);
308   EXPECT_EQ(MOJO_RESULT_OK, send_handle_->WriteData(&kTestMsg, &num_bytes,
309                                                     MOJO_WRITE_DATA_FLAG_NONE));
310   EXPECT_EQ(strlen(kTestMsg), num_bytes);
311   EXPECT_EQ(net::ERR_FAILED, delegate()->WaitForWriteError());
312   // Reads can proceed even though there is a read error.
313   EXPECT_EQ(kTestMsg, Read(&receive_handle_, strlen(kTestMsg)));
314 
315   base::RunLoop().RunUntilIdle();
316   EXPECT_TRUE(data_provider.AllReadDataConsumed());
317   EXPECT_TRUE(data_provider.AllWriteDataConsumed());
318 }
319 
TEST_P(SocketDataPumpTest,PipesShutdown)320 TEST_P(SocketDataPumpTest, PipesShutdown) {
321   net::IoMode mode = GetParam();
322   net::MockRead reads[] = {net::MockRead(mode, net::OK)};
323   net::StaticSocketDataProvider data_provider(reads,
324                                               base::span<net::MockWrite>());
325   Init(&data_provider);
326   send_handle_.reset();
327   receive_handle_.reset();
328   delegate()->WaitForShutdown();
329 }
330 
331 }  // namespace network
332