1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 // This file contains some tests for TCPClientSocket.
6 // transport_client_socket_unittest.cc contans some other tests that
7 // are common for TCP and other types of sockets.
8
9 #include "net/socket/tcp_client_socket.h"
10
11 #include <stddef.h>
12
13 #include "base/power_monitor/power_monitor.h"
14 #include "base/power_monitor/power_monitor_source.h"
15 #include "base/strings/string_number_conversions.h"
16 #include "base/test/bind.h"
17 #include "base/test/scoped_feature_list.h"
18 #include "base/test/task_environment.h"
19 #include "build/build_config.h"
20 #include "net/base/features.h"
21 #include "net/base/ip_address.h"
22 #include "net/base/ip_endpoint.h"
23 #include "net/base/net_errors.h"
24 #include "net/base/test_completion_callback.h"
25 #include "net/log/net_log_source.h"
26 #include "net/nqe/network_quality_estimator_test_util.h"
27 #include "net/socket/socket_performance_watcher.h"
28 #include "net/socket/socket_test_util.h"
29 #include "net/socket/tcp_server_socket.h"
30 #include "net/test/embedded_test_server/embedded_test_server.h"
31 #include "net/test/gtest_util.h"
32 #include "net/traffic_annotation/network_traffic_annotation_test_helper.h"
33 #include "testing/gmock/include/gmock/gmock.h"
34 #include "testing/gtest/include/gtest/gtest.h"
35
// Mirrors the platform check in tcp_client_socket.cc. It is consulted only once
// below, but keeping it as a named macro (rather than repeating the OS checks
// where it's used) makes the dependency easy to grep for.
#if !(defined(OS_ANDROID) || defined(OS_NACL))
#define TCP_CLIENT_SOCKET_OBSERVES_SUSPEND
#endif
42
43 using net::test::IsError;
44 using net::test::IsOk;
45 using testing::Not;
46
47 namespace base {
48 class TimeDelta;
49 }
50
51 namespace net {
52
53 namespace {
54
55 // Test power monitor source that can simulate entering suspend mode. Can't use
56 // the one in base/ because it insists on bringing its own MessageLoop.
57 class TestPowerMonitorSource : public base::PowerMonitorSource {
58 public:
59 TestPowerMonitorSource() = default;
60 ~TestPowerMonitorSource() override = default;
61
Suspend()62 void Suspend() { ProcessPowerEvent(SUSPEND_EVENT); }
63
Resume()64 void Resume() { ProcessPowerEvent(RESUME_EVENT); }
65
IsOnBatteryPowerImpl()66 bool IsOnBatteryPowerImpl() override { return false; }
67
68 private:
69 DISALLOW_COPY_AND_ASSIGN(TestPowerMonitorSource);
70 };
71
72 class TCPClientSocketTest : public testing::Test {
73 public:
TCPClientSocketTest()74 TCPClientSocketTest()
75 : task_environment_(base::test::TaskEnvironment::MainThreadType::IO) {
76 std::unique_ptr<TestPowerMonitorSource> power_monitor_source =
77 std::make_unique<TestPowerMonitorSource>();
78 power_monitor_source_ = power_monitor_source.get();
79 base::PowerMonitor::Initialize(std::move(power_monitor_source));
80 }
81
~TCPClientSocketTest()82 ~TCPClientSocketTest() override { base::PowerMonitor::ShutdownForTesting(); }
83
Suspend()84 void Suspend() { power_monitor_source_->Suspend(); }
Resume()85 void Resume() { power_monitor_source_->Resume(); }
86
CreateConnectedSockets(std::unique_ptr<StreamSocket> * accepted_socket,std::unique_ptr<TCPClientSocket> * client_socket,std::unique_ptr<ServerSocket> * server_socket_opt=nullptr)87 void CreateConnectedSockets(
88 std::unique_ptr<StreamSocket>* accepted_socket,
89 std::unique_ptr<TCPClientSocket>* client_socket,
90 std::unique_ptr<ServerSocket>* server_socket_opt = nullptr) {
91 IPAddress local_address = IPAddress::IPv4Localhost();
92
93 std::unique_ptr<TCPServerSocket> server_socket =
94 std::make_unique<TCPServerSocket>(nullptr, NetLogSource());
95 ASSERT_THAT(server_socket->Listen(IPEndPoint(local_address, 0), 1), IsOk());
96 IPEndPoint server_address;
97 ASSERT_THAT(server_socket->GetLocalAddress(&server_address), IsOk());
98
99 *client_socket = std::make_unique<TCPClientSocket>(
100 AddressList(server_address), nullptr, nullptr, nullptr, NetLogSource());
101
102 EXPECT_THAT((*client_socket)->Bind(IPEndPoint(local_address, 0)), IsOk());
103
104 IPEndPoint local_address_result;
105 EXPECT_THAT((*client_socket)->GetLocalAddress(&local_address_result),
106 IsOk());
107 EXPECT_EQ(local_address, local_address_result.address());
108
109 TestCompletionCallback connect_callback;
110 int connect_result = (*client_socket)->Connect(connect_callback.callback());
111
112 TestCompletionCallback accept_callback;
113 int result =
114 server_socket->Accept(accepted_socket, accept_callback.callback());
115 result = accept_callback.GetResult(result);
116 ASSERT_THAT(result, IsOk());
117
118 ASSERT_THAT(connect_callback.GetResult(connect_result), IsOk());
119
120 EXPECT_TRUE((*client_socket)->IsConnected());
121 EXPECT_TRUE((*accepted_socket)->IsConnected());
122 if (server_socket_opt)
123 *server_socket_opt = std::move(server_socket);
124 }
125
126 private:
127 base::test::TaskEnvironment task_environment_;
128
129 TestPowerMonitorSource* power_monitor_source_;
130 };
131
132 // Try binding a socket to loopback interface and verify that we can
133 // still connect to a server on the same interface.
TEST_F(TCPClientSocketTest,BindLoopbackToLoopback)134 TEST_F(TCPClientSocketTest, BindLoopbackToLoopback) {
135 IPAddress lo_address = IPAddress::IPv4Localhost();
136
137 TCPServerSocket server(nullptr, NetLogSource());
138 ASSERT_THAT(server.Listen(IPEndPoint(lo_address, 0), 1), IsOk());
139 IPEndPoint server_address;
140 ASSERT_THAT(server.GetLocalAddress(&server_address), IsOk());
141
142 TCPClientSocket socket(AddressList(server_address), nullptr, nullptr, nullptr,
143 NetLogSource());
144
145 EXPECT_THAT(socket.Bind(IPEndPoint(lo_address, 0)), IsOk());
146
147 IPEndPoint local_address_result;
148 EXPECT_THAT(socket.GetLocalAddress(&local_address_result), IsOk());
149 EXPECT_EQ(lo_address, local_address_result.address());
150
151 TestCompletionCallback connect_callback;
152 int connect_result = socket.Connect(connect_callback.callback());
153
154 TestCompletionCallback accept_callback;
155 std::unique_ptr<StreamSocket> accepted_socket;
156 int result = server.Accept(&accepted_socket, accept_callback.callback());
157 result = accept_callback.GetResult(result);
158 ASSERT_THAT(result, IsOk());
159
160 EXPECT_THAT(connect_callback.GetResult(connect_result), IsOk());
161
162 EXPECT_TRUE(socket.IsConnected());
163 socket.Disconnect();
164 EXPECT_FALSE(socket.IsConnected());
165 EXPECT_EQ(ERR_SOCKET_NOT_CONNECTED,
166 socket.GetLocalAddress(&local_address_result));
167 }
168
169 // Try to bind socket to the loopback interface and connect to an
170 // external address, verify that connection fails.
TEST_F(TCPClientSocketTest,BindLoopbackToExternal)171 TEST_F(TCPClientSocketTest, BindLoopbackToExternal) {
172 IPAddress external_ip(72, 14, 213, 105);
173 TCPClientSocket socket(AddressList::CreateFromIPAddress(external_ip, 80),
174 nullptr, nullptr, nullptr, NetLogSource());
175
176 EXPECT_THAT(socket.Bind(IPEndPoint(IPAddress::IPv4Localhost(), 0)), IsOk());
177
178 TestCompletionCallback connect_callback;
179 int result = socket.Connect(connect_callback.callback());
180
181 // We may get different errors here on different system, but
182 // connect() is not expected to succeed.
183 EXPECT_THAT(connect_callback.GetResult(result), Not(IsOk()));
184 }
185
186 // Bind a socket to the IPv4 loopback interface and try to connect to
187 // the IPv6 loopback interface, verify that connection fails.
TEST_F(TCPClientSocketTest,BindLoopbackToIPv6)188 TEST_F(TCPClientSocketTest, BindLoopbackToIPv6) {
189 TCPServerSocket server(nullptr, NetLogSource());
190 int listen_result =
191 server.Listen(IPEndPoint(IPAddress::IPv6Localhost(), 0), 1);
192 if (listen_result != OK) {
193 LOG(ERROR) << "Failed to listen on ::1 - probably because IPv6 is disabled."
194 " Skipping the test";
195 return;
196 }
197
198 IPEndPoint server_address;
199 ASSERT_THAT(server.GetLocalAddress(&server_address), IsOk());
200 TCPClientSocket socket(AddressList(server_address), nullptr, nullptr, nullptr,
201 NetLogSource());
202
203 EXPECT_THAT(socket.Bind(IPEndPoint(IPAddress::IPv4Localhost(), 0)), IsOk());
204
205 TestCompletionCallback connect_callback;
206 int result = socket.Connect(connect_callback.callback());
207
208 EXPECT_THAT(connect_callback.GetResult(result), Not(IsOk()));
209 }
210
// Checks which operations set / reset TCPClientSocket::WasEverUsed().
TEST_F(TCPClientSocketTest, WasEverUsed) {
  const IPAddress loopback = IPAddress::IPv4Localhost();
  TCPServerSocket listener(nullptr, NetLogSource());
  ASSERT_THAT(listener.Listen(IPEndPoint(loopback, 0), 1), IsOk());
  IPEndPoint listen_endpoint;
  ASSERT_THAT(listener.GetLocalAddress(&listen_endpoint), IsOk());

  TCPClientSocket client(AddressList(listen_endpoint), nullptr, nullptr, nullptr,
                         NetLogSource());

  // A freshly created socket has never been used.
  EXPECT_FALSE(client.WasEverUsed());

  EXPECT_THAT(client.Bind(IPEndPoint(loopback, 0)), IsOk());

  // Merely starting a connection must not mark the socket as used.
  TestCompletionCallback on_connect;
  int connect_rv = client.Connect(on_connect.callback());
  EXPECT_FALSE(client.WasEverUsed());

  TestCompletionCallback on_accept;
  std::unique_ptr<StreamSocket> server_side;
  const int accept_rv = listener.Accept(&server_side, on_accept.callback());
  ASSERT_THAT(on_accept.GetResult(accept_rv), IsOk());
  EXPECT_THAT(on_connect.GetResult(connect_rv), IsOk());

  // Finishing the connection still doesn't count as use.
  EXPECT_FALSE(client.WasEverUsed());
  EXPECT_TRUE(client.IsConnected());

  // Writing data, however, _should_ mark the socket as used...
  const char kRequest[] = "GET / HTTP/1.0";
  auto request_buffer = base::MakeRefCounted<StringIOBuffer>(kRequest);
  TestCompletionCallback on_write;
  client.Write(request_buffer.get(), request_buffer->size(), on_write.callback(),
               TRAFFIC_ANNOTATION_FOR_TESTS);
  EXPECT_TRUE(client.WasEverUsed());

  // ...and disconnecting must not clear that flag.
  client.Disconnect();
  EXPECT_FALSE(client.IsConnected());
  EXPECT_TRUE(client.WasEverUsed());

  // Reusing the socket for a new connection resets WasEverUsed to false.
  EXPECT_THAT(client.Bind(IPEndPoint(loopback, 0)), IsOk());
  TestCompletionCallback on_reconnect;
  connect_rv = client.Connect(on_reconnect.callback());
  EXPECT_FALSE(client.WasEverUsed());
}
257
258 class TestSocketPerformanceWatcher : public SocketPerformanceWatcher {
259 public:
TestSocketPerformanceWatcher()260 TestSocketPerformanceWatcher() : connection_changed_count_(0u) {}
261 ~TestSocketPerformanceWatcher() override = default;
262
ShouldNotifyUpdatedRTT() const263 bool ShouldNotifyUpdatedRTT() const override { return true; }
264
OnUpdatedRTTAvailable(const base::TimeDelta & rtt)265 void OnUpdatedRTTAvailable(const base::TimeDelta& rtt) override {}
266
OnConnectionChanged()267 void OnConnectionChanged() override { connection_changed_count_++; }
268
connection_changed_count() const269 size_t connection_changed_count() const { return connection_changed_count_; }
270
271 private:
272 size_t connection_changed_count_;
273
274 DISALLOW_COPY_AND_ASSIGN(TestSocketPerformanceWatcher);
275 };
276
// TestSocketPerformanceWatcher was once meant to run only on platforms with
// kernel support for the tcp_info struct. However, the platform conditional
// defined the same enabled test name in both branches, so the test runs
// everywhere. The MAYBE_ alias is kept so the test can be disabled
// per-platform again if that ever becomes necessary.
#define MAYBE_TestSocketPerformanceWatcher TestSocketPerformanceWatcher
284 // Tests if the socket performance watcher is notified if the same socket is
285 // used for a different connection.
TEST_F(TCPClientSocketTest,MAYBE_TestSocketPerformanceWatcher)286 TEST_F(TCPClientSocketTest, MAYBE_TestSocketPerformanceWatcher) {
287 const size_t kNumIPs = 2;
288 IPAddressList ip_list;
289 for (size_t i = 0; i < kNumIPs; ++i)
290 ip_list.push_back(IPAddress(72, 14, 213, i));
291
292 std::unique_ptr<TestSocketPerformanceWatcher> watcher(
293 new TestSocketPerformanceWatcher());
294 TestSocketPerformanceWatcher* watcher_ptr = watcher.get();
295
296 TCPClientSocket socket(
297 AddressList::CreateFromIPAddressList(ip_list, "example.com"),
298 std::move(watcher), nullptr, nullptr, NetLogSource());
299
300 EXPECT_THAT(socket.Bind(IPEndPoint(IPAddress::IPv4Localhost(), 0)), IsOk());
301
302 TestCompletionCallback connect_callback;
303
304 ASSERT_NE(OK, connect_callback.GetResult(
305 socket.Connect(connect_callback.callback())));
306
307 EXPECT_EQ(kNumIPs - 1, watcher_ptr->connection_changed_count());
308 }
309
// On Android, where socket tagging is supported, verify that
// TCPClientSocket::ApplySocketTag works as expected.
312 #if defined(OS_ANDROID)
// Verifies that traffic is counted against whichever tag was last applied via
// ApplySocketTag(), starting with a tag applied before the socket connects and
// then across repeated retagging.
TEST_F(TCPClientSocketTest, Tag) {
  if (!CanGetTaggedBytes()) {
    DVLOG(0) << "Skipping test - GetTaggedBytes unsupported.";
    return;
  }

  // Start test server.
  EmbeddedTestServer test_server;
  test_server.AddDefaultHandlers(base::FilePath());
  ASSERT_TRUE(test_server.Start());

  AddressList addr_list;
  ASSERT_TRUE(test_server.GetAddressList(&addr_list));
  TCPClientSocket s(addr_list, nullptr, nullptr, nullptr, NetLogSource());

  // Verify TCP connect packets are tagged and counted properly.
  int32_t tag_val1 = 0x12345678;
  uint64_t old_traffic = GetTaggedBytes(tag_val1);
  SocketTag tag1(SocketTag::UNSET_UID, tag_val1);
  s.ApplySocketTag(tag1);
  TestCompletionCallback connect_callback;
  int connect_result = s.Connect(connect_callback.callback());
  EXPECT_THAT(connect_callback.GetResult(connect_result), IsOk());
  EXPECT_GT(GetTaggedBytes(tag_val1), old_traffic);

  // Verify socket can be retagged with a new value and the current process's
  // UID.
  int32_t tag_val2 = 0x87654321;
  old_traffic = GetTaggedBytes(tag_val2);
  SocketTag tag2(getuid(), tag_val2);
  s.ApplySocketTag(tag2);
  const char kRequest1[] = "GET / HTTP/1.0";
  scoped_refptr<IOBuffer> write_buffer1 =
      base::MakeRefCounted<StringIOBuffer>(kRequest1);
  TestCompletionCallback write_callback1;
  EXPECT_EQ(s.Write(write_buffer1.get(), strlen(kRequest1),
                    write_callback1.callback(), TRAFFIC_ANNOTATION_FOR_TESTS),
            static_cast<int>(strlen(kRequest1)));
  EXPECT_GT(GetTaggedBytes(tag_val2), old_traffic);

  // Verify socket can be retagged back to the original value (|tag_val1|)
  // with an unset UID.
  old_traffic = GetTaggedBytes(tag_val1);
  s.ApplySocketTag(tag1);
  const char kRequest2[] = "\n\n";
  scoped_refptr<IOBuffer> write_buffer2 =
      base::MakeRefCounted<StringIOBuffer>(kRequest2);
  TestCompletionCallback write_callback2;
  EXPECT_EQ(s.Write(write_buffer2.get(), strlen(kRequest2),
                    write_callback2.callback(), TRAFFIC_ANNOTATION_FOR_TESTS),
            static_cast<int>(strlen(kRequest2)));
  EXPECT_GT(GetTaggedBytes(tag_val1), old_traffic);

  s.Disconnect();
}
369
// Verifies that a socket can be tagged only after it has connected, and then
// retagged, with traffic counted against the most recently applied tag.
TEST_F(TCPClientSocketTest, TagAfterConnect) {
  if (!CanGetTaggedBytes()) {
    DVLOG(0) << "Skipping test - GetTaggedBytes unsupported.";
    return;
  }

  // Start test server.
  EmbeddedTestServer test_server;
  test_server.AddDefaultHandlers(base::FilePath());
  ASSERT_TRUE(test_server.Start());

  AddressList addr_list;
  ASSERT_TRUE(test_server.GetAddressList(&addr_list));
  TCPClientSocket s(addr_list, nullptr, nullptr, nullptr, NetLogSource());

  // Connect socket.
  TestCompletionCallback connect_callback;
  int connect_result = s.Connect(connect_callback.callback());
  EXPECT_THAT(connect_callback.GetResult(connect_result), IsOk());

  // Verify socket can be tagged with a new value and the current process's
  // UID.
  int32_t tag_val2 = 0x87654321;
  uint64_t old_traffic = GetTaggedBytes(tag_val2);
  SocketTag tag2(getuid(), tag_val2);
  s.ApplySocketTag(tag2);
  const char kRequest1[] = "GET / HTTP/1.0";
  scoped_refptr<IOBuffer> write_buffer1 =
      base::MakeRefCounted<StringIOBuffer>(kRequest1);
  TestCompletionCallback write_callback1;
  EXPECT_EQ(s.Write(write_buffer1.get(), strlen(kRequest1),
                    write_callback1.callback(), TRAFFIC_ANNOTATION_FOR_TESTS),
            static_cast<int>(strlen(kRequest1)));
  EXPECT_GT(GetTaggedBytes(tag_val2), old_traffic);

  // Verify socket can be retagged with a new value and an unset UID.
  int32_t tag_val1 = 0x12345678;
  old_traffic = GetTaggedBytes(tag_val1);
  SocketTag tag1(SocketTag::UNSET_UID, tag_val1);
  s.ApplySocketTag(tag1);
  const char kRequest2[] = "\n\n";
  scoped_refptr<IOBuffer> write_buffer2 =
      base::MakeRefCounted<StringIOBuffer>(kRequest2);
  TestCompletionCallback write_callback2;
  EXPECT_EQ(s.Write(write_buffer2.get(), strlen(kRequest2),
                    write_callback2.callback(), TRAFFIC_ANNOTATION_FOR_TESTS),
            static_cast<int>(strlen(kRequest2)));
  EXPECT_GT(GetTaggedBytes(tag_val1), old_traffic);

  s.Disconnect();
}
422 #endif // defined(OS_ANDROID)
423
424 // TCP socket that hangs indefinitely when establishing a connection.
425 class NeverConnectingTCPClientSocket : public TCPClientSocket {
426 public:
NeverConnectingTCPClientSocket(const AddressList & addresses,std::unique_ptr<SocketPerformanceWatcher> socket_performance_watcher,NetworkQualityEstimator * network_quality_estimator,net::NetLog * net_log,const net::NetLogSource & source)427 NeverConnectingTCPClientSocket(
428 const AddressList& addresses,
429 std::unique_ptr<SocketPerformanceWatcher> socket_performance_watcher,
430 NetworkQualityEstimator* network_quality_estimator,
431 net::NetLog* net_log,
432 const net::NetLogSource& source)
433 : TCPClientSocket(addresses,
434 std::move(socket_performance_watcher),
435 network_quality_estimator,
436 net_log,
437 source) {}
438
439 // Returns the number of times that ConnectInternal() was called.
connect_internal_counter() const440 int connect_internal_counter() const { return connect_internal_counter_; }
441
442 private:
ConnectInternal(const IPEndPoint & endpoint)443 int ConnectInternal(const IPEndPoint& endpoint) override {
444 connect_internal_counter_++;
445 return ERR_IO_PENDING;
446 }
447
448 int connect_internal_counter_ = 0;
449 };
450
451 // Tests for closing sockets on suspend mode.
452 #if defined(TCP_CLIENT_SOCKET_OBSERVES_SUSPEND)
453
454 // Entering suspend mode shouldn't affect sockets that haven't connected yet, or
455 // listening server sockets.
TEST_F(TCPClientSocketTest,SuspendBeforeConnect)456 TEST_F(TCPClientSocketTest, SuspendBeforeConnect) {
457 IPAddress lo_address = IPAddress::IPv4Localhost();
458
459 TCPServerSocket server(nullptr, NetLogSource());
460 ASSERT_THAT(server.Listen(IPEndPoint(lo_address, 0), 1), IsOk());
461 IPEndPoint server_address;
462 ASSERT_THAT(server.GetLocalAddress(&server_address), IsOk());
463
464 TCPClientSocket socket(AddressList(server_address), nullptr, nullptr, nullptr,
465 NetLogSource());
466
467 EXPECT_THAT(socket.Bind(IPEndPoint(lo_address, 0)), IsOk());
468
469 IPEndPoint local_address_result;
470 EXPECT_THAT(socket.GetLocalAddress(&local_address_result), IsOk());
471 EXPECT_EQ(lo_address, local_address_result.address());
472
473 TestCompletionCallback accept_callback;
474 std::unique_ptr<StreamSocket> accepted_socket;
475 ASSERT_THAT(server.Accept(&accepted_socket, accept_callback.callback()),
476 IsError(ERR_IO_PENDING));
477
478 Suspend();
479 // Power notifications happen asynchronously, so have to wait for the socket
480 // to be notified of the suspend event.
481 base::RunLoop().RunUntilIdle();
482
483 TestCompletionCallback connect_callback;
484 int connect_result = socket.Connect(connect_callback.callback());
485
486 ASSERT_THAT(accept_callback.WaitForResult(), IsOk());
487
488 ASSERT_THAT(connect_callback.GetResult(connect_result), IsOk());
489
490 EXPECT_TRUE(socket.IsConnected());
491 EXPECT_TRUE(accepted_socket->IsConnected());
492 }
493
// Suspending while a connection attempt is pending fails that Connect() with
// ERR_NETWORK_IO_SUSPENDED.
TEST_F(TCPClientSocketTest, SuspendDuringConnect) {
  const IPAddress loopback = IPAddress::IPv4Localhost();

  TCPServerSocket listener(nullptr, NetLogSource());
  ASSERT_THAT(listener.Listen(IPEndPoint(loopback, 0), 1), IsOk());
  IPEndPoint listen_endpoint;
  ASSERT_THAT(listener.GetLocalAddress(&listen_endpoint), IsOk());

  NeverConnectingTCPClientSocket client(AddressList(listen_endpoint), nullptr,
                                        nullptr, nullptr, NetLogSource());

  EXPECT_THAT(client.Bind(IPEndPoint(loopback, 0)), IsOk());

  IPEndPoint bound_endpoint;
  EXPECT_THAT(client.GetLocalAddress(&bound_endpoint), IsOk());
  EXPECT_EQ(loopback, bound_endpoint.address());

  TestCompletionCallback on_connect;
  ASSERT_THAT(client.Connect(on_connect.callback()), IsError(ERR_IO_PENDING));
  Suspend();
  EXPECT_THAT(on_connect.WaitForResult(), IsError(ERR_NETWORK_IO_SUSPENDED));
}
518
// With several candidate addresses, a suspend during the first (never-ending)
// attempt fails the whole Connect() with ERR_NETWORK_IO_SUSPENDED.
TEST_F(TCPClientSocketTest, SuspendDuringConnectMultipleAddresses) {
  const IPAddress loopback = IPAddress::IPv4Localhost();

  TCPServerSocket listener(nullptr, NetLogSource());
  ASSERT_THAT(listener.Listen(IPEndPoint(IPAddress(0, 0, 0, 0), 0), 1),
              IsOk());
  IPEndPoint listen_endpoint;
  ASSERT_THAT(listener.GetLocalAddress(&listen_endpoint), IsOk());

  // Two loopback candidates that share the listener's port.
  AddressList candidates;
  for (const IPAddress& candidate :
       {IPAddress(127, 0, 0, 1), IPAddress(127, 0, 0, 2)}) {
    candidates.push_back(IPEndPoint(candidate, listen_endpoint.port()));
  }
  NeverConnectingTCPClientSocket client(candidates, nullptr, nullptr, nullptr,
                                        NetLogSource());

  EXPECT_THAT(client.Bind(IPEndPoint(loopback, 0)), IsOk());

  IPEndPoint bound_endpoint;
  EXPECT_THAT(client.GetLocalAddress(&bound_endpoint), IsOk());
  EXPECT_EQ(loopback, bound_endpoint.address());

  TestCompletionCallback on_connect;
  ASSERT_THAT(client.Connect(on_connect.callback()), IsError(ERR_IO_PENDING));
  Suspend();
  EXPECT_THAT(on_connect.WaitForResult(), IsError(ERR_NETWORK_IO_SUSPENDED));
}
548
// A real suspend event while a connection is idle disconnects both the client
// and the accepted socket; afterwards the client can reconnect.
TEST_F(TCPClientSocketTest, SuspendWhileIdle) {
  std::unique_ptr<StreamSocket> accepted_socket;
  std::unique_ptr<TCPClientSocket> client_socket;
  std::unique_ptr<ServerSocket> server_socket;
  // CreateConnectedSockets() uses ASSERT_*; stop here rather than dereference
  // null sockets if it failed.
  ASSERT_NO_FATAL_FAILURE(
      CreateConnectedSockets(&accepted_socket, &client_socket, &server_socket));

  Suspend();
  // Power notifications happen asynchronously.
  base::RunLoop().RunUntilIdle();

  scoped_refptr<IOBuffer> buffer = base::MakeRefCounted<IOBuffer>(1);
  buffer->data()[0] = '1';
  TestCompletionCallback callback;
  // Check that the client socket is disconnected, and actions fail with
  // ERR_NETWORK_IO_SUSPENDED.
  EXPECT_FALSE(client_socket->IsConnected());
  EXPECT_THAT(client_socket->Read(buffer.get(), 1, callback.callback()),
              IsError(ERR_NETWORK_IO_SUSPENDED));
  EXPECT_THAT(client_socket->Write(buffer.get(), 1, callback.callback(),
                                   TRAFFIC_ANNOTATION_FOR_TESTS),
              IsError(ERR_NETWORK_IO_SUSPENDED));

  // Check that the accepted socket is disconnected, and actions fail with
  // ERR_NETWORK_IO_SUSPENDED.
  EXPECT_FALSE(accepted_socket->IsConnected());
  EXPECT_THAT(accepted_socket->Read(buffer.get(), 1, callback.callback()),
              IsError(ERR_NETWORK_IO_SUSPENDED));
  EXPECT_THAT(accepted_socket->Write(buffer.get(), 1, callback.callback(),
                                     TRAFFIC_ANNOTATION_FOR_TESTS),
              IsError(ERR_NETWORK_IO_SUSPENDED));

  // Reconnecting the socket should work.
  TestCompletionCallback connect_callback;
  int connect_result = client_socket->Connect(connect_callback.callback());
  accepted_socket.reset();
  TestCompletionCallback accept_callback;
  int accept_result =
      server_socket->Accept(&accepted_socket, accept_callback.callback());
  ASSERT_THAT(accept_callback.GetResult(accept_result), IsOk());
  EXPECT_THAT(connect_callback.GetResult(connect_result), IsOk());
}
590
// A suspend while a read is pending fails the read with
// ERR_NETWORK_IO_SUSPENDED and leaves the client socket disconnected.
TEST_F(TCPClientSocketTest, SuspendDuringRead) {
  std::unique_ptr<StreamSocket> accepted_socket;
  std::unique_ptr<TCPClientSocket> client_socket;
  // CreateConnectedSockets() uses ASSERT_*; stop here rather than dereference
  // null sockets if it failed.
  ASSERT_NO_FATAL_FAILURE(
      CreateConnectedSockets(&accepted_socket, &client_socket));

  // Start a read. This shouldn't complete, since the other end of the pipe
  // writes no data.
  scoped_refptr<IOBuffer> read_buffer = base::MakeRefCounted<IOBuffer>(1);
  read_buffer->data()[0] = '1';
  TestCompletionCallback callback;
  ASSERT_THAT(client_socket->Read(read_buffer.get(), 1, callback.callback()),
              IsError(ERR_IO_PENDING));

  // Simulate a suspend event. Can't use a real power event, as it would affect
  // |accepted_socket| as well.
  client_socket->OnSuspend();
  EXPECT_THAT(callback.WaitForResult(), IsError(ERR_NETWORK_IO_SUSPENDED));

  // Check that the client socket really is disconnected.
  EXPECT_FALSE(client_socket->IsConnected());
  EXPECT_THAT(client_socket->Read(read_buffer.get(), 1, callback.callback()),
              IsError(ERR_NETWORK_IO_SUSPENDED));
  EXPECT_THAT(client_socket->Write(read_buffer.get(), 1, callback.callback(),
                                   TRAFFIC_ANNOTATION_FOR_TESTS),
              IsError(ERR_NETWORK_IO_SUSPENDED));
}
617
// A suspend while a write is pending fails the write with
// ERR_NETWORK_IO_SUSPENDED and leaves the client socket disconnected.
TEST_F(TCPClientSocketTest, SuspendDuringWrite) {
  std::unique_ptr<StreamSocket> accepted_socket;
  std::unique_ptr<TCPClientSocket> client_socket;
  // CreateConnectedSockets() uses ASSERT_*; stop here rather than dereference
  // null sockets if it failed.
  ASSERT_NO_FATAL_FAILURE(
      CreateConnectedSockets(&accepted_socket, &client_socket));

  // Write to the socket until a write doesn't complete synchronously.
  const int kBufferSize = 4096;
  scoped_refptr<IOBuffer> write_buffer =
      base::MakeRefCounted<IOBuffer>(kBufferSize);
  memset(write_buffer->data(), '1', kBufferSize);
  TestCompletionCallback callback;
  while (true) {
    int rv =
        client_socket->Write(write_buffer.get(), kBufferSize,
                             callback.callback(), TRAFFIC_ANNOTATION_FOR_TESTS);
    if (rv == ERR_IO_PENDING)
      break;
    ASSERT_GT(rv, 0);
  }

  // Simulate a suspend event. Can't use a real power event, as it would affect
  // |accepted_socket| as well.
  client_socket->OnSuspend();
  EXPECT_THAT(callback.WaitForResult(), IsError(ERR_NETWORK_IO_SUSPENDED));

  // Check that the client socket really is disconnected.
  EXPECT_FALSE(client_socket->IsConnected());
  EXPECT_THAT(client_socket->Read(write_buffer.get(), 1, callback.callback()),
              IsError(ERR_NETWORK_IO_SUSPENDED));
  EXPECT_THAT(client_socket->Write(write_buffer.get(), 1, callback.callback(),
                                   TRAFFIC_ANNOTATION_FOR_TESTS),
              IsError(ERR_NETWORK_IO_SUSPENDED));
}
651
// Suspends while both a read and a write are pending. The read callback, which
// receives ERR_NETWORK_IO_SUSPENDED, does nothing, destroys the socket,
// disconnects it, or reconnects it, depending on the loop iteration.
TEST_F(TCPClientSocketTest, SuspendDuringReadAndWrite) {
  enum class ReadCallbackAction {
    kNone,
    kDestroySocket,
    kDisconnectSocket,
    kReconnectSocket,
  };

  for (ReadCallbackAction read_callback_action : {
           ReadCallbackAction::kNone,
           ReadCallbackAction::kDestroySocket,
           ReadCallbackAction::kDisconnectSocket,
           ReadCallbackAction::kReconnectSocket,
       }) {
    // Attribute any failure below to the iteration that produced it.
    SCOPED_TRACE(testing::Message() << "read_callback_action="
                                    << static_cast<int>(read_callback_action));

    std::unique_ptr<StreamSocket> accepted_socket;
    std::unique_ptr<TCPClientSocket> client_socket;
    std::unique_ptr<ServerSocket> server_socket;
    // CreateConnectedSockets() uses ASSERT_*; stop here rather than
    // dereference null sockets if it failed.
    ASSERT_NO_FATAL_FAILURE(CreateConnectedSockets(
        &accepted_socket, &client_socket, &server_socket));

    // Start a read. This shouldn't complete, since the other end of the pipe
    // writes no data.
    scoped_refptr<IOBuffer> read_buffer = base::MakeRefCounted<IOBuffer>(1);
    read_buffer->data()[0] = '1';
    TestCompletionCallback read_callback;

    // Used in the ReadCallbackAction::kReconnectSocket case, since a nested
    // message loop can't be run from inside the read callback.
    TestCompletionCallback nested_connect_callback;
    int nested_connect_result = ERR_IO_PENDING;

    CompletionOnceCallback read_completion_once_callback =
        base::BindLambdaForTesting([&](int result) {
          EXPECT_FALSE(client_socket->IsConnected());
          switch (read_callback_action) {
            case ReadCallbackAction::kNone:
              break;
            case ReadCallbackAction::kDestroySocket:
              client_socket.reset();
              break;
            case ReadCallbackAction::kDisconnectSocket:
              client_socket->Disconnect();
              break;
            case ReadCallbackAction::kReconnectSocket:
              nested_connect_result =
                  client_socket->Connect(nested_connect_callback.callback());
              break;
          }
          read_callback.callback().Run(result);
        });
    ASSERT_THAT(client_socket->Read(read_buffer.get(), 1,
                                    std::move(read_completion_once_callback)),
                IsError(ERR_IO_PENDING));

    // Write to the socket until a write doesn't complete synchronously.
    const int kBufferSize = 4096;
    scoped_refptr<IOBuffer> write_buffer =
        base::MakeRefCounted<IOBuffer>(kBufferSize);
    memset(write_buffer->data(), '1', kBufferSize);
    TestCompletionCallback write_callback;
    while (true) {
      int rv = client_socket->Write(write_buffer.get(), kBufferSize,
                                    write_callback.callback(),
                                    TRAFFIC_ANNOTATION_FOR_TESTS);
      if (rv == ERR_IO_PENDING)
        break;
      ASSERT_GT(rv, 0);
    }

    // Simulate a suspend event. Can't use a real power event, as it would
    // affect |accepted_socket| as well.
    client_socket->OnSuspend();
    EXPECT_THAT(read_callback.WaitForResult(),
                IsError(ERR_NETWORK_IO_SUSPENDED));
    if (read_callback_action == ReadCallbackAction::kNone) {
      EXPECT_THAT(write_callback.WaitForResult(),
                  IsError(ERR_NETWORK_IO_SUSPENDED));

      // Check that the client socket really is disconnected.
      EXPECT_FALSE(client_socket->IsConnected());
      EXPECT_THAT(
          client_socket->Read(read_buffer.get(), 1, read_callback.callback()),
          IsError(ERR_NETWORK_IO_SUSPENDED));
      EXPECT_THAT(
          client_socket->Write(write_buffer.get(), 1, write_callback.callback(),
                               TRAFFIC_ANNOTATION_FOR_TESTS),
          IsError(ERR_NETWORK_IO_SUSPENDED));
    } else {
      // Each of the actions taken in the read callback will cancel the pending
      // write callback.
      EXPECT_FALSE(write_callback.have_result());
    }

    if (read_callback_action == ReadCallbackAction::kReconnectSocket) {
      // Finish establishing a connection, just to make sure the reconnect case
      // completely works.
      accepted_socket.reset();
      TestCompletionCallback accept_callback;
      int accept_result =
          server_socket->Accept(&accepted_socket, accept_callback.callback());
      ASSERT_THAT(accept_callback.GetResult(accept_result), IsOk());
      EXPECT_THAT(nested_connect_callback.GetResult(nested_connect_result),
                  IsOk());
    }
  }
}
759
760 #endif // defined(TCP_CLIENT_SOCKET_OBSERVES_SUSPEND)
761
762 // Scoped helper to override the TCP connect attempt policy.
763 class OverrideTcpConnectAttemptTimeout {
764 public:
OverrideTcpConnectAttemptTimeout(double rtt_multipilier,base::TimeDelta min_timeout,base::TimeDelta max_timeout)765 OverrideTcpConnectAttemptTimeout(double rtt_multipilier,
766 base::TimeDelta min_timeout,
767 base::TimeDelta max_timeout) {
768 base::FieldTrialParams params;
769 params[features::kTimeoutTcpConnectAttemptRTTMultiplier.name] =
770 base::NumberToString(rtt_multipilier);
771 params[features::kTimeoutTcpConnectAttemptMin.name] =
772 base::NumberToString(min_timeout.InMilliseconds()) + "ms";
773 params[features::kTimeoutTcpConnectAttemptMax.name] =
774 base::NumberToString(max_timeout.InMilliseconds()) + "ms";
775
776 scoped_feature_list_.InitAndEnableFeatureWithParameters(
777 features::kTimeoutTcpConnectAttempt, params);
778 }
779
780 private:
781 base::test::ScopedFeatureList scoped_feature_list_;
782 };
783
784 // Test fixture that uses a MOCK_TIME test environment, so time can
785 // be advanced programmatically.
786 class TCPClientSocketMockTimeTest : public testing::Test {
787 public:
TCPClientSocketMockTimeTest()788 TCPClientSocketMockTimeTest()
789 : task_environment_(base::test::TaskEnvironment::MainThreadType::IO,
790 base::test::TaskEnvironment::TimeSource::MOCK_TIME) {}
791
792 protected:
793 base::test::TaskEnvironment task_environment_;
794 };
795
796 // Tests that no TCP connect timeout is enforced by default (i.e.
797 // when the feature is disabled).
TEST_F(TCPClientSocketMockTimeTest,NoConnectAttemptTimeoutByDefault)798 TEST_F(TCPClientSocketMockTimeTest, NoConnectAttemptTimeoutByDefault) {
799 IPEndPoint server_address(IPAddress::IPv4Localhost(), 80);
800 NeverConnectingTCPClientSocket socket(AddressList(server_address), nullptr,
801 nullptr, nullptr, NetLogSource());
802
803 TestCompletionCallback connect_callback;
804 int rv = socket.Connect(connect_callback.callback());
805 ASSERT_THAT(rv, IsError(ERR_IO_PENDING));
806
807 // After 4 minutes, the socket should still be connecting.
808 task_environment_.FastForwardBy(base::TimeDelta::FromMinutes(4));
809 EXPECT_FALSE(connect_callback.have_result());
810 EXPECT_FALSE(socket.IsConnected());
811
812 // 1 attempt was made.
813 EXPECT_EQ(1, socket.connect_internal_counter());
814 }
815
816 // Tests that the maximum timeout is used when there is no estimated
817 // RTT.
TEST_F(TCPClientSocketMockTimeTest,ConnectAttemptTimeoutUsesMaxWhenNoRTT)818 TEST_F(TCPClientSocketMockTimeTest, ConnectAttemptTimeoutUsesMaxWhenNoRTT) {
819 OverrideTcpConnectAttemptTimeout override_timeout(
820 1, base::TimeDelta::FromSeconds(4), base::TimeDelta::FromSeconds(10));
821
822 IPEndPoint server_address(IPAddress::IPv4Localhost(), 80);
823
824 // Pass a null NetworkQualityEstimator, so the TCPClientSocket is unable to
825 // estimate the RTT.
826 NeverConnectingTCPClientSocket socket(AddressList(server_address), nullptr,
827 nullptr, nullptr, NetLogSource());
828
829 // Start connecting.
830 TestCompletionCallback connect_callback;
831 int rv = socket.Connect(connect_callback.callback());
832 ASSERT_THAT(rv, IsError(ERR_IO_PENDING));
833
834 // Advance to t=3.1s
835 // Should still be pending, as this is before the minimum timeout.
836 task_environment_.FastForwardBy(base::TimeDelta::FromMilliseconds(3100));
837 EXPECT_FALSE(connect_callback.have_result());
838 EXPECT_FALSE(socket.IsConnected());
839
840 // Advance to t=4.1s
841 // Should still be pending. This is after the minimum timeout, but before the
842 // maximum.
843 task_environment_.FastForwardBy(base::TimeDelta::FromSeconds(1));
844 EXPECT_FALSE(connect_callback.have_result());
845 EXPECT_FALSE(socket.IsConnected());
846
847 // Advance to t=10.1s
848 // Should now be timed out, as this is after the maximum timeout.
849 task_environment_.FastForwardBy(base::TimeDelta::FromSeconds(6));
850 rv = connect_callback.GetResult(rv);
851 ASSERT_THAT(rv, IsError(ERR_TIMED_OUT));
852
853 // 1 attempt was made.
854 EXPECT_EQ(1, socket.connect_internal_counter());
855 }
856
857 // Tests that the minimum timeout is used when the adaptive timeout using RTT
858 // ends up being too low.
TEST_F(TCPClientSocketMockTimeTest,ConnectAttemptTimeoutUsesMinWhenRTTLow)859 TEST_F(TCPClientSocketMockTimeTest, ConnectAttemptTimeoutUsesMinWhenRTTLow) {
860 OverrideTcpConnectAttemptTimeout override_timeout(
861 5, base::TimeDelta::FromSeconds(4), base::TimeDelta::FromSeconds(10));
862
863 // Set the estimated RTT to 1 millisecond.
864 TestNetworkQualityEstimator network_quality_estimator;
865 network_quality_estimator.SetStartTimeNullTransportRtt(
866 base::TimeDelta::FromMilliseconds(1));
867
868 IPEndPoint server_address(IPAddress::IPv4Localhost(), 80);
869
870 NeverConnectingTCPClientSocket socket(AddressList(server_address), nullptr,
871 &network_quality_estimator, nullptr,
872 NetLogSource());
873
874 // Start connecting.
875 TestCompletionCallback connect_callback;
876 int rv = socket.Connect(connect_callback.callback());
877 ASSERT_THAT(rv, IsError(ERR_IO_PENDING));
878
879 // Advance to t=1.1s
880 // Should be pending, since although the adaptive timeout has been reached, it
881 // is lower than the minimum timeout.
882 task_environment_.FastForwardBy(base::TimeDelta::FromMilliseconds(1100));
883 EXPECT_FALSE(connect_callback.have_result());
884 EXPECT_FALSE(socket.IsConnected());
885
886 // Advance to t=4.1s
887 // Should have timed out due to hitting the minimum timeout.
888 task_environment_.FastForwardBy(base::TimeDelta::FromSeconds(3));
889 rv = connect_callback.GetResult(rv);
890 ASSERT_THAT(rv, IsError(ERR_TIMED_OUT));
891
892 // 1 attempt was made.
893 EXPECT_EQ(1, socket.connect_internal_counter());
894 }
895
896 // Tests that the maximum timeout is used when the adaptive timeout from RTT is
897 // too high.
TEST_F(TCPClientSocketMockTimeTest,ConnectAttemptTimeoutUsesMinWhenRTTHigh)898 TEST_F(TCPClientSocketMockTimeTest, ConnectAttemptTimeoutUsesMinWhenRTTHigh) {
899 OverrideTcpConnectAttemptTimeout override_timeout(
900 5, base::TimeDelta::FromSeconds(4), base::TimeDelta::FromSeconds(10));
901
902 // Set the estimated RTT to 5 seconds.
903 TestNetworkQualityEstimator network_quality_estimator;
904 network_quality_estimator.SetStartTimeNullTransportRtt(
905 base::TimeDelta::FromSeconds(5));
906
907 IPEndPoint server_address(IPAddress::IPv4Localhost(), 80);
908
909 NeverConnectingTCPClientSocket socket(AddressList(server_address), nullptr,
910 &network_quality_estimator, nullptr,
911 NetLogSource());
912
913 // Start connecting.
914 TestCompletionCallback connect_callback;
915 int rv = socket.Connect(connect_callback.callback());
916 ASSERT_THAT(rv, IsError(ERR_IO_PENDING));
917
918 // Advance to t=10.1s
919 // The socket should have timed out due to hitting the maximum timeout. Had
920 // the adaptive timeout been used, the socket would instead be timing out at
921 // t=25s.
922 task_environment_.FastForwardBy(base::TimeDelta::FromMilliseconds(10100));
923 rv = connect_callback.GetResult(rv);
924 ASSERT_THAT(rv, IsError(ERR_TIMED_OUT));
925
926 // 1 attempt was made.
927 EXPECT_EQ(1, socket.connect_internal_counter());
928 }
929
930 // Tests that an adaptive timeout is used for TCP connection attempts based on
931 // the estimated RTT.
TEST_F(TCPClientSocketMockTimeTest,ConnectAttemptTimeoutUsesRTT)932 TEST_F(TCPClientSocketMockTimeTest, ConnectAttemptTimeoutUsesRTT) {
933 OverrideTcpConnectAttemptTimeout override_timeout(
934 5, base::TimeDelta::FromSeconds(4), base::TimeDelta::FromSeconds(10));
935
936 // Set the estimated RTT to 1 second. Since the multiplier is set to 5, the
937 // total adaptive timeout will be 5 seconds.
938 TestNetworkQualityEstimator network_quality_estimator;
939 network_quality_estimator.SetStartTimeNullTransportRtt(
940 base::TimeDelta::FromSeconds(1));
941
942 IPEndPoint server_address(IPAddress::IPv4Localhost(), 80);
943
944 NeverConnectingTCPClientSocket socket(AddressList(server_address), nullptr,
945 &network_quality_estimator, nullptr,
946 NetLogSource());
947
948 // Start connecting.
949 TestCompletionCallback connect_callback;
950 int rv = socket.Connect(connect_callback.callback());
951 ASSERT_THAT(rv, IsError(ERR_IO_PENDING));
952
953 // Advance to t=4.1s
954 // The socket should still be pending. Had the minimum timeout been enforced,
955 // it would instead have timed out now.
956 task_environment_.FastForwardBy(base::TimeDelta::FromMilliseconds(4100));
957 EXPECT_FALSE(connect_callback.have_result());
958 EXPECT_FALSE(socket.IsConnected());
959
960 // Advance to t=5.1s
961 // The adaptive timeout was at t=5s, so it should now be timed out.
962 task_environment_.FastForwardBy(base::TimeDelta::FromSeconds(1));
963 rv = connect_callback.GetResult(rv);
964 ASSERT_THAT(rv, IsError(ERR_TIMED_OUT));
965
966 // 1 attempt was made.
967 EXPECT_EQ(1, socket.connect_internal_counter());
968 }
969
970 // Tests that when multiple TCP connect attempts are made, the timeout for each
971 // one is applied independently.
TEST_F(TCPClientSocketMockTimeTest,ConnectAttemptTimeoutIndependent)972 TEST_F(TCPClientSocketMockTimeTest, ConnectAttemptTimeoutIndependent) {
973 OverrideTcpConnectAttemptTimeout override_timeout(
974 5, base::TimeDelta::FromSeconds(4), base::TimeDelta::FromSeconds(10));
975
976 // This test will attempt connecting to 5 endpoints.
977 const size_t kNumIps = 5;
978
979 AddressList addresses;
980 for (size_t i = 0; i < kNumIps; ++i)
981 addresses.push_back(IPEndPoint(IPAddress::IPv4Localhost(), 80 + i));
982
983 NeverConnectingTCPClientSocket socket(addresses, nullptr, nullptr, nullptr,
984 NetLogSource());
985
986 // Start connecting.
987 TestCompletionCallback connect_callback;
988 int rv = socket.Connect(connect_callback.callback());
989 ASSERT_THAT(rv, IsError(ERR_IO_PENDING));
990
991 // Advance to t=49s
992 // Should still be pending.
993 task_environment_.FastForwardBy(base::TimeDelta::FromSeconds(49));
994 EXPECT_FALSE(connect_callback.have_result());
995 EXPECT_FALSE(socket.IsConnected());
996
997 // Advance to t=50.1s
998 // All attempts should take 50 seconds to complete (5 attempts, 10 seconds
999 // each). So by this point the overall connect attempt will have timed out.
1000 task_environment_.FastForwardBy(base::TimeDelta::FromMilliseconds(1100));
1001 rv = connect_callback.GetResult(rv);
1002 ASSERT_THAT(rv, IsError(ERR_TIMED_OUT));
1003
1004 // 5 attempts were made.
1005 EXPECT_EQ(5, socket.connect_internal_counter());
1006 }
1007
1008 } // namespace
1009
1010 } // namespace net
1011