1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include <memory>
6 #include <string>
7
8 #include "base/bind.h"
9 #include "base/memory/ref_counted.h"
10 #include "base/run_loop.h"
11 #include "net/base/address_list.h"
12 #include "net/base/io_buffer.h"
13 #include "net/base/ip_address.h"
14 #include "net/base/net_errors.h"
15 #include "net/base/test_completion_callback.h"
16 #include "net/log/net_log_event_type.h"
17 #include "net/log/net_log_source.h"
18 #include "net/log/net_log_with_source.h"
19 #include "net/log/test_net_log.h"
20 #include "net/log/test_net_log_util.h"
21 #include "net/socket/client_socket_factory.h"
22 #include "net/socket/tcp_client_socket.h"
23 #include "net/socket/tcp_server_socket.h"
24 #include "net/test/gtest_util.h"
25 #include "net/test/test_with_task_environment.h"
26 #include "net/traffic_annotation/network_traffic_annotation_test_helper.h"
27 #include "testing/gmock/include/gmock/gmock.h"
28 #include "testing/gtest/include/gtest/gtest.h"
29 #include "testing/platform_test.h"
30
31 using net::test::IsError;
32 using net::test::IsOk;
33
34 namespace net {
35
namespace {

// Canned reply that the server end writes back to the client. Tests compare
// the bytes the client receives against this string.
constexpr char kServerReply[] = "HTTP/1.1 404 Not Found";

// Socket flavors the parameterized fixture can be instantiated with.
enum ClientSocketTestTypes {
  TCP,
  SCTP,
};

}  // namespace
43
44 class TransportClientSocketTest
45 : public ::testing::TestWithParam<ClientSocketTestTypes>,
46 public WithTaskEnvironment {
47 public:
TransportClientSocketTest()48 TransportClientSocketTest()
49 : listen_port_(0),
50 socket_factory_(ClientSocketFactory::GetDefaultFactory()),
51 close_server_socket_on_next_send_(false) {}
52
53 virtual ~TransportClientSocketTest() = default;
54
55 // Testcase hooks
56 void SetUp() override;
57
CloseServerSocket()58 void CloseServerSocket() {
59 // delete the connected_sock_, which will close it.
60 connected_sock_.reset();
61 }
62
AcceptCallback(int res)63 void AcceptCallback(int res) {
64 ASSERT_THAT(res, IsOk());
65 connect_loop_.Quit();
66 }
67
68 int DrainClientSocket(IOBuffer* buf,
69 uint32_t buf_len,
70 uint32_t bytes_to_read,
71 TestCompletionCallback* callback);
72
73 // Establishes a connection to the server.
74 void EstablishConnection(TestCompletionCallback* callback);
75
76 // Sends a request from the client to the server socket. Makes the server read
77 // the request and send a response.
78 void SendRequestAndResponse();
79
80 // Makes |connected_sock_| to read |expected_bytes_read| bytes. Returns the
81 // the data read as a string.
82 std::string ReadServerData(int expected_bytes_read);
83
84 // Sends server response.
85 void SendServerResponse();
86
set_close_server_socket_on_next_send(bool close)87 void set_close_server_socket_on_next_send(bool close) {
88 close_server_socket_on_next_send_ = close;
89 }
90
91 protected:
92 base::RunLoop connect_loop_;
93 uint16_t listen_port_;
94 RecordingTestNetLog net_log_;
95 ClientSocketFactory* const socket_factory_;
96 std::unique_ptr<StreamSocket> sock_;
97 std::unique_ptr<StreamSocket> connected_sock_;
98
99 private:
100 std::unique_ptr<TCPServerSocket> listen_sock_;
101 bool close_server_socket_on_next_send_;
102 };
103
SetUp()104 void TransportClientSocketTest::SetUp() {
105 ::testing::TestWithParam<ClientSocketTestTypes>::SetUp();
106
107 // Open a server socket on an ephemeral port.
108 listen_sock_.reset(new TCPServerSocket(nullptr, NetLogSource()));
109 IPEndPoint local_address(IPAddress::IPv4Localhost(), 0);
110 ASSERT_THAT(listen_sock_->Listen(local_address, 1), IsOk());
111 // Get the server's address (including the actual port number).
112 ASSERT_THAT(listen_sock_->GetLocalAddress(&local_address), IsOk());
113 listen_port_ = local_address.port();
114 listen_sock_->Accept(
115 &connected_sock_,
116 base::BindOnce(&TransportClientSocketTest::AcceptCallback,
117 base::Unretained(this)));
118
119 AddressList addr = AddressList::CreateFromIPAddress(
120 IPAddress::IPv4Localhost(), listen_port_);
121 sock_ = socket_factory_->CreateTransportClientSocket(
122 addr, nullptr, nullptr, &net_log_, NetLogSource());
123 }
124
DrainClientSocket(IOBuffer * buf,uint32_t buf_len,uint32_t bytes_to_read,TestCompletionCallback * callback)125 int TransportClientSocketTest::DrainClientSocket(
126 IOBuffer* buf,
127 uint32_t buf_len,
128 uint32_t bytes_to_read,
129 TestCompletionCallback* callback) {
130 int rv = OK;
131 uint32_t bytes_read = 0;
132
133 while (bytes_read < bytes_to_read) {
134 rv = sock_->Read(buf, buf_len, callback->callback());
135 EXPECT_TRUE(rv >= 0 || rv == ERR_IO_PENDING);
136 rv = callback->GetResult(rv);
137 EXPECT_GT(rv, 0);
138 bytes_read += rv;
139 }
140
141 return static_cast<int>(bytes_read);
142 }
143
EstablishConnection(TestCompletionCallback * callback)144 void TransportClientSocketTest::EstablishConnection(
145 TestCompletionCallback* callback) {
146 int rv = sock_->Connect(callback->callback());
147 // Wait for |listen_sock_| to accept a connection.
148 connect_loop_.Run();
149 // Now wait for the client socket to accept the connection.
150 EXPECT_THAT(callback->GetResult(rv), IsOk());
151 }
152
SendRequestAndResponse()153 void TransportClientSocketTest::SendRequestAndResponse() {
154 // Send client request.
155 const char request_text[] = "GET / HTTP/1.0\r\n\r\n";
156 int request_len = strlen(request_text);
157 scoped_refptr<DrainableIOBuffer> request_buffer =
158 base::MakeRefCounted<DrainableIOBuffer>(
159 base::MakeRefCounted<IOBuffer>(request_len), request_len);
160 memcpy(request_buffer->data(), request_text, request_len);
161
162 int bytes_written = 0;
163 while (request_buffer->BytesRemaining() > 0) {
164 TestCompletionCallback write_callback;
165 int write_result =
166 sock_->Write(request_buffer.get(), request_buffer->BytesRemaining(),
167 write_callback.callback(), TRAFFIC_ANNOTATION_FOR_TESTS);
168 write_result = write_callback.GetResult(write_result);
169 ASSERT_GT(write_result, 0);
170 ASSERT_LE(bytes_written + write_result, request_len);
171 request_buffer->DidConsume(write_result);
172 bytes_written += write_result;
173 }
174 ASSERT_EQ(request_len, bytes_written);
175
176 // Confirm that the server receives what client sent.
177 std::string data_received = ReadServerData(bytes_written);
178 ASSERT_TRUE(connected_sock_->IsConnectedAndIdle());
179 ASSERT_EQ(request_text, data_received);
180
181 // Write server response.
182 SendServerResponse();
183 }
184
SendServerResponse()185 void TransportClientSocketTest::SendServerResponse() {
186 // TODO(dkegel): this might not be long enough to tickle some bugs.
187 int reply_len = strlen(kServerReply);
188 scoped_refptr<DrainableIOBuffer> write_buffer =
189 base::MakeRefCounted<DrainableIOBuffer>(
190 base::MakeRefCounted<IOBuffer>(reply_len), reply_len);
191 memcpy(write_buffer->data(), kServerReply, reply_len);
192 int bytes_written = 0;
193 while (write_buffer->BytesRemaining() > 0) {
194 TestCompletionCallback write_callback;
195 int write_result = connected_sock_->Write(
196 write_buffer.get(), write_buffer->BytesRemaining(),
197 write_callback.callback(), TRAFFIC_ANNOTATION_FOR_TESTS);
198 write_result = write_callback.GetResult(write_result);
199 ASSERT_GE(write_result, 0);
200 ASSERT_LE(bytes_written + write_result, reply_len);
201 write_buffer->DidConsume(write_result);
202 bytes_written += write_result;
203 }
204 if (close_server_socket_on_next_send_)
205 CloseServerSocket();
206 }
207
ReadServerData(int expected_bytes_read)208 std::string TransportClientSocketTest::ReadServerData(int expected_bytes_read) {
209 int bytes_read = 0;
210 scoped_refptr<IOBufferWithSize> read_buffer =
211 base::MakeRefCounted<IOBufferWithSize>(expected_bytes_read);
212 while (bytes_read < expected_bytes_read) {
213 TestCompletionCallback read_callback;
214 int rv = connected_sock_->Read(read_buffer.get(),
215 expected_bytes_read - bytes_read,
216 read_callback.callback());
217 EXPECT_TRUE(rv >= 0 || rv == ERR_IO_PENDING);
218 rv = read_callback.GetResult(rv);
219 EXPECT_GE(rv, 0);
220 bytes_read += rv;
221 }
222 EXPECT_EQ(expected_bytes_read, bytes_read);
223 return std::string(read_buffer->data(), bytes_read);
224 }
225
// Instantiates every TEST_P for TransportClientSocketTest once per value in
// the list below; the generated test names are prefixed with "StreamSocket/".
// NOTE(review): nothing visible in this file calls GetParam(), so each
// instantiation currently exercises the same default-factory TCP client.
// TODO(leighton): Add SCTP to this list when it is ready.
INSTANTIATE_TEST_SUITE_P(StreamSocket,
                         TransportClientSocketTest,
                         ::testing::Values(TCP));
230
// Verifies that Connect() succeeds against the loopback server, that the
// expected net-log events bracket the connect, and that Disconnect() takes the
// socket back to the unconnected state.
TEST_P(TransportClientSocketTest, Connect) {
  TestCompletionCallback callback;
  EXPECT_FALSE(sock_->IsConnected());

  int rv = sock_->Connect(callback.callback());
  // Block until |listen_sock_| has accepted the incoming connection.
  connect_loop_.Run();

  // By now the log must begin with the socket-lifetime and connect events.
  auto begin_entries = net_log_.GetEntries();
  EXPECT_TRUE(
      LogContainsBeginEvent(begin_entries, 0, NetLogEventType::SOCKET_ALIVE));
  EXPECT_TRUE(
      LogContainsBeginEvent(begin_entries, 1, NetLogEventType::TCP_CONNECT));

  // The client side may still be finishing its connect; wait for it if so.
  if (rv != OK) {
    ASSERT_THAT(rv, IsError(ERR_IO_PENDING));
    EXPECT_THAT(callback.WaitForResult(), IsOk());
  }

  EXPECT_TRUE(sock_->IsConnected());
  auto end_entries = net_log_.GetEntries();
  EXPECT_TRUE(
      LogContainsEndEvent(end_entries, -1, NetLogEventType::TCP_CONNECT));

  sock_->Disconnect();
  EXPECT_FALSE(sock_->IsConnected());
}
259
// Checks IsConnected()/IsConnectedAndIdle() across the connection lifecycle:
// before connecting, with unread data pending, after draining, and after the
// server closes its end.
TEST_P(TransportClientSocketTest, IsConnected) {
  scoped_refptr<IOBuffer> buf = base::MakeRefCounted<IOBuffer>(4096);
  TestCompletionCallback callback;

  EXPECT_FALSE(sock_->IsConnected());
  EXPECT_FALSE(sock_->IsConnectedAndIdle());

  EstablishConnection(&callback);

  EXPECT_TRUE(sock_->IsConnected());
  EXPECT_TRUE(sock_->IsConnectedAndIdle());

  // Send the request and wait for the server to respond. A fatal failure
  // inside the helper only returns from the helper, so propagate it here.
  // Otherwise the reads below would block on data that never arrives.
  ASSERT_NO_FATAL_FAILURE(SendRequestAndResponse());

  // Drain a single byte so we know we've received some data.
  uint32_t bytes_read = DrainClientSocket(buf.get(), 1, 1, &callback);
  ASSERT_EQ(bytes_read, 1u);

  // Socket should be considered connected, but not idle, due to
  // pending data.
  EXPECT_TRUE(sock_->IsConnected());
  EXPECT_FALSE(sock_->IsConnectedAndIdle());

  bytes_read =
      DrainClientSocket(buf.get(), 4096, strlen(kServerReply) - 1, &callback);
  ASSERT_EQ(bytes_read, strlen(kServerReply) - 1);

  // After draining the data, the socket should be back to connected
  // and idle.
  EXPECT_TRUE(sock_->IsConnected());
  EXPECT_TRUE(sock_->IsConnectedAndIdle());

  // This time close the server socket immediately after the server response.
  set_close_server_socket_on_next_send(true);
  ASSERT_NO_FATAL_FAILURE(SendRequestAndResponse());

  bytes_read = DrainClientSocket(buf.get(), 1, 1, &callback);
  ASSERT_EQ(bytes_read, 1u);

  // As above because of data.
  EXPECT_TRUE(sock_->IsConnected());
  EXPECT_FALSE(sock_->IsConnectedAndIdle());

  bytes_read =
      DrainClientSocket(buf.get(), 4096, strlen(kServerReply) - 1, &callback);
  ASSERT_EQ(bytes_read, strlen(kServerReply) - 1);

  // Once the data is drained, the socket should now be seen as not
  // connected.
  if (sock_->IsConnected()) {
    // In the unlikely event that the server's connection closure is not
    // processed in time, wait for the connection to be closed.
    int rv = sock_->Read(buf.get(), 4096, callback.callback());
    EXPECT_EQ(0, callback.GetResult(rv));
    EXPECT_FALSE(sock_->IsConnected());
  }
  EXPECT_FALSE(sock_->IsConnectedAndIdle());
}
320
// Reads the full server reply, then checks that a further read goes pending
// and completes with 0 (EOF) once the server closes its end.
TEST_P(TransportClientSocketTest, Read) {
  TestCompletionCallback callback;
  EstablishConnection(&callback);

  // Propagate fatal failures from the helper. Otherwise the drain below could
  // block waiting for a reply that was never sent.
  ASSERT_NO_FATAL_FAILURE(SendRequestAndResponse());

  scoped_refptr<IOBuffer> buf = base::MakeRefCounted<IOBuffer>(4096);
  uint32_t bytes_read =
      DrainClientSocket(buf.get(), 4096, strlen(kServerReply), &callback);
  ASSERT_EQ(bytes_read, strlen(kServerReply));
  // NOTE: DrainClientSocket() reuses the start of |buf| for every Read(), so
  // this comparison assumes the reply arrived in a single read.
  ASSERT_EQ(std::string(kServerReply), std::string(buf->data(), bytes_read));

  // All data has been read now. Read once more to force an ERR_IO_PENDING, and
  // then close the server socket, and note the close.

  int rv = sock_->Read(buf.get(), 4096, callback.callback());
  ASSERT_THAT(rv, IsError(ERR_IO_PENDING));
  CloseServerSocket();
  EXPECT_EQ(0, callback.WaitForResult());
}
341
// Reads the server reply one byte per Read(), then checks that a further read
// goes pending and completes with 0 (EOF) once the server closes its end.
TEST_P(TransportClientSocketTest, Read_SmallChunks) {
  TestCompletionCallback callback;
  EstablishConnection(&callback);

  // Propagate fatal failures from the helper so the reads below don't block on
  // a reply that was never sent.
  ASSERT_NO_FATAL_FAILURE(SendRequestAndResponse());

  scoped_refptr<IOBuffer> buf = base::MakeRefCounted<IOBuffer>(1);
  uint32_t bytes_read = 0;
  while (bytes_read < strlen(kServerReply)) {
    int rv = sock_->Read(buf.get(), 1, callback.callback());
    EXPECT_TRUE(rv >= 0 || rv == ERR_IO_PENDING);

    rv = callback.GetResult(rv);

    ASSERT_EQ(1, rv);
    bytes_read += rv;
  }

  // All data has been read now. Read once more to force an ERR_IO_PENDING, and
  // then close the server socket, and note the close.

  int rv = sock_->Read(buf.get(), 1, callback.callback());
  ASSERT_THAT(rv, IsError(ERR_IO_PENDING));
  CloseServerSocket();
  EXPECT_EQ(0, callback.WaitForResult());
}
368
// Does a single partial read of the server reply and then ends the test with
// data still unread; tear-down must cope with that.
TEST_P(TransportClientSocketTest, Read_Interrupted) {
  TestCompletionCallback callback;
  EstablishConnection(&callback);

  // Propagate fatal failures from the helper so the read below doesn't block
  // on a reply that was never sent.
  ASSERT_NO_FATAL_FAILURE(SendRequestAndResponse());

  // Do a partial read and then exit. This test should not crash!
  scoped_refptr<IOBuffer> buf = base::MakeRefCounted<IOBuffer>(16);
  int rv = sock_->Read(buf.get(), 16, callback.callback());
  EXPECT_TRUE(rv >= 0 || rv == ERR_IO_PENDING);

  rv = callback.GetResult(rv);

  // The server sent a reply and is still connected, so the read must return
  // data; a negative net error must not count as success.
  EXPECT_GT(rv, 0);
}
384
// Issues a read that goes pending, then writes until a write also goes
// pending. Both must then complete, showing the socket handles a read and a
// write in flight at the same time.
TEST_P(TransportClientSocketTest, FullDuplex_ReadFirst) {
  TestCompletionCallback callback;
  EstablishConnection(&callback);

  // Read first. There's no data, so it should return ERR_IO_PENDING.
  const int kBufLen = 4096;
  scoped_refptr<IOBuffer> buf = base::MakeRefCounted<IOBuffer>(kBufLen);
  int rv = sock_->Read(buf.get(), kBufLen, callback.callback());
  EXPECT_THAT(rv, IsError(ERR_IO_PENDING));

  const int kWriteBufLen = 64 * 1024;
  scoped_refptr<IOBuffer> request_buffer =
      base::MakeRefCounted<IOBuffer>(kWriteBufLen);
  char* request_data = request_buffer->data();
  memset(request_data, 'A', kWriteBufLen);
  TestCompletionCallback write_callback;

  // Keep writing until Write() goes pending.
  int bytes_written = 0;
  while (true) {
    rv = sock_->Write(request_buffer.get(), kWriteBufLen,
                      write_callback.callback(), TRAFFIC_ANNOTATION_FOR_TESTS);
    // A zero-byte write would make no progress and spin this loop forever.
    ASSERT_TRUE(rv > 0 || rv == ERR_IO_PENDING);
    if (rv == ERR_IO_PENDING) {
      // Have the server consume everything accepted so far, which should let
      // the pending write proceed. Then give the pending read a reply.
      ReadServerData(bytes_written);
      ASSERT_NO_FATAL_FAILURE(SendServerResponse());
      rv = write_callback.WaitForResult();
      // The pending write must complete successfully (as in WriteFirst).
      EXPECT_GE(rv, 0);
      break;
    }
    bytes_written += rv;
  }

  // At this point, both read and write have returned ERR_IO_PENDING, and the
  // write callback has executed. We wait for the read callback to run now to
  // make sure that the socket can handle full duplex communications.

  rv = callback.WaitForResult();
  // The server sent a non-empty reply and has not closed, so the read must
  // return data rather than EOF.
  EXPECT_GT(rv, 0);
}
423
// Writes until a write goes pending, then issues a read that also goes
// pending. Both must then complete, showing the socket handles a write and a
// read in flight at the same time.
TEST_P(TransportClientSocketTest, FullDuplex_WriteFirst) {
  TestCompletionCallback callback;
  EstablishConnection(&callback);

  const int kWriteBufLen = 64 * 1024;
  scoped_refptr<IOBuffer> request_buffer =
      base::MakeRefCounted<IOBuffer>(kWriteBufLen);
  char* request_data = request_buffer->data();
  memset(request_data, 'A', kWriteBufLen);
  TestCompletionCallback write_callback;

  int bytes_written = 0;
  while (true) {
    int rv =
        sock_->Write(request_buffer.get(), kWriteBufLen,
                     write_callback.callback(), TRAFFIC_ANNOTATION_FOR_TESTS);
    // A zero-byte write would make no progress and spin this loop forever.
    ASSERT_TRUE(rv > 0 || rv == ERR_IO_PENDING);

    if (rv == ERR_IO_PENDING)
      break;
    bytes_written += rv;
  }

  // Now we have the Write() blocked on ERR_IO_PENDING. It's time to force the
  // Read() to block on ERR_IO_PENDING too.

  const int kBufLen = 4096;
  scoped_refptr<IOBuffer> buf = base::MakeRefCounted<IOBuffer>(kBufLen);
  while (true) {
    int rv = sock_->Read(buf.get(), kBufLen, callback.callback());
    // The server has sent nothing yet. A 0 (EOF) would repeat forever here,
    // so treat it as a failure.
    ASSERT_TRUE(rv > 0 || rv == ERR_IO_PENDING);
    if (rv == ERR_IO_PENDING)
      break;
  }

  // At this point, both read and write have returned ERR_IO_PENDING. Now we
  // run the write and read callbacks to make sure they can handle full duplex
  // communications.

  ReadServerData(bytes_written);
  ASSERT_NO_FATAL_FAILURE(SendServerResponse());
  int rv = write_callback.WaitForResult();
  EXPECT_GE(rv, 0);

  rv = callback.WaitForResult();
  EXPECT_GT(rv, 0);
}
471
472 } // namespace net
473