1 // Copyright 2014 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "net/socket/transport_client_socket_pool_test_util.h"
6
7 #include <stdint.h>
8 #include <string>
9 #include <utility>
10
11 #include "base/bind.h"
12 #include "base/location.h"
13 #include "base/logging.h"
14 #include "base/macros.h"
15 #include "base/memory/weak_ptr.h"
16 #include "base/run_loop.h"
17 #include "base/single_thread_task_runner.h"
18 #include "base/threading/thread_task_runner_handle.h"
19 #include "net/base/ip_address.h"
20 #include "net/base/ip_endpoint.h"
21 #include "net/base/load_timing_info.h"
22 #include "net/base/load_timing_info_test_util.h"
23 #include "net/log/net_log_source.h"
24 #include "net/log/net_log_source_type.h"
25 #include "net/log/net_log_with_source.h"
26 #include "net/socket/client_socket_handle.h"
27 #include "net/socket/datagram_client_socket.h"
28 #include "net/socket/ssl_client_socket.h"
29 #include "net/socket/transport_client_socket.h"
30 #include "net/traffic_annotation/network_traffic_annotation_test_helper.h"
31 #include "testing/gtest/include/gtest/gtest.h"
32
33 namespace net {
34
35 namespace {
36
ParseIP(const std::string & ip)37 IPAddress ParseIP(const std::string& ip) {
38 IPAddress address;
39 CHECK(address.AssignFromIPLiteral(ip));
40 return address;
41 }
42
43 // A StreamSocket which connects synchronously and successfully.
44 class MockConnectClientSocket : public TransportClientSocket {
45 public:
MockConnectClientSocket(const AddressList & addrlist,net::NetLog * net_log)46 MockConnectClientSocket(const AddressList& addrlist, net::NetLog* net_log)
47 : connected_(false),
48 addrlist_(addrlist),
49 net_log_(NetLogWithSource::Make(net_log, NetLogSourceType::SOCKET)) {}
50
51 // TransportClientSocket implementation.
Bind(const net::IPEndPoint & local_addr)52 int Bind(const net::IPEndPoint& local_addr) override {
53 NOTREACHED();
54 return ERR_FAILED;
55 }
56 // StreamSocket implementation.
Connect(CompletionOnceCallback callback)57 int Connect(CompletionOnceCallback callback) override {
58 connected_ = true;
59 return OK;
60 }
Disconnect()61 void Disconnect() override { connected_ = false; }
IsConnected() const62 bool IsConnected() const override { return connected_; }
IsConnectedAndIdle() const63 bool IsConnectedAndIdle() const override { return connected_; }
64
GetPeerAddress(IPEndPoint * address) const65 int GetPeerAddress(IPEndPoint* address) const override {
66 *address = addrlist_.front();
67 return OK;
68 }
GetLocalAddress(IPEndPoint * address) const69 int GetLocalAddress(IPEndPoint* address) const override {
70 if (!connected_)
71 return ERR_SOCKET_NOT_CONNECTED;
72 if (addrlist_.front().GetFamily() == ADDRESS_FAMILY_IPV4)
73 SetIPv4Address(address);
74 else
75 SetIPv6Address(address);
76 return OK;
77 }
NetLog() const78 const NetLogWithSource& NetLog() const override { return net_log_; }
79
WasEverUsed() const80 bool WasEverUsed() const override { return false; }
WasAlpnNegotiated() const81 bool WasAlpnNegotiated() const override { return false; }
GetNegotiatedProtocol() const82 NextProto GetNegotiatedProtocol() const override { return kProtoUnknown; }
GetSSLInfo(SSLInfo * ssl_info)83 bool GetSSLInfo(SSLInfo* ssl_info) override { return false; }
GetConnectionAttempts(ConnectionAttempts * out) const84 void GetConnectionAttempts(ConnectionAttempts* out) const override {
85 out->clear();
86 }
ClearConnectionAttempts()87 void ClearConnectionAttempts() override {}
AddConnectionAttempts(const ConnectionAttempts & attempts)88 void AddConnectionAttempts(const ConnectionAttempts& attempts) override {}
GetTotalReceivedBytes() const89 int64_t GetTotalReceivedBytes() const override {
90 NOTIMPLEMENTED();
91 return 0;
92 }
ApplySocketTag(const SocketTag & tag)93 void ApplySocketTag(const SocketTag& tag) override {}
94
95 // Socket implementation.
Read(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)96 int Read(IOBuffer* buf,
97 int buf_len,
98 CompletionOnceCallback callback) override {
99 return ERR_FAILED;
100 }
Write(IOBuffer * buf,int buf_len,CompletionOnceCallback callback,const NetworkTrafficAnnotationTag & traffic_annotation)101 int Write(IOBuffer* buf,
102 int buf_len,
103 CompletionOnceCallback callback,
104 const NetworkTrafficAnnotationTag& traffic_annotation) override {
105 return ERR_FAILED;
106 }
SetReceiveBufferSize(int32_t size)107 int SetReceiveBufferSize(int32_t size) override { return OK; }
SetSendBufferSize(int32_t size)108 int SetSendBufferSize(int32_t size) override { return OK; }
109
110 private:
111 bool connected_;
112 const AddressList addrlist_;
113 NetLogWithSource net_log_;
114
115 DISALLOW_COPY_AND_ASSIGN(MockConnectClientSocket);
116 };
117
118 class MockFailingClientSocket : public TransportClientSocket {
119 public:
MockFailingClientSocket(const AddressList & addrlist,net::NetLog * net_log)120 MockFailingClientSocket(const AddressList& addrlist, net::NetLog* net_log)
121 : addrlist_(addrlist),
122 net_log_(NetLogWithSource::Make(net_log, NetLogSourceType::SOCKET)) {}
123
124 // TransportClientSocket implementation.
Bind(const net::IPEndPoint & local_addr)125 int Bind(const net::IPEndPoint& local_addr) override {
126 NOTREACHED();
127 return ERR_FAILED;
128 }
129
130 // StreamSocket implementation.
Connect(CompletionOnceCallback callback)131 int Connect(CompletionOnceCallback callback) override {
132 return ERR_CONNECTION_FAILED;
133 }
134
Disconnect()135 void Disconnect() override {}
136
IsConnected() const137 bool IsConnected() const override { return false; }
IsConnectedAndIdle() const138 bool IsConnectedAndIdle() const override { return false; }
GetPeerAddress(IPEndPoint * address) const139 int GetPeerAddress(IPEndPoint* address) const override {
140 return ERR_UNEXPECTED;
141 }
GetLocalAddress(IPEndPoint * address) const142 int GetLocalAddress(IPEndPoint* address) const override {
143 return ERR_UNEXPECTED;
144 }
NetLog() const145 const NetLogWithSource& NetLog() const override { return net_log_; }
146
WasEverUsed() const147 bool WasEverUsed() const override { return false; }
WasAlpnNegotiated() const148 bool WasAlpnNegotiated() const override { return false; }
GetNegotiatedProtocol() const149 NextProto GetNegotiatedProtocol() const override { return kProtoUnknown; }
GetSSLInfo(SSLInfo * ssl_info)150 bool GetSSLInfo(SSLInfo* ssl_info) override { return false; }
GetConnectionAttempts(ConnectionAttempts * out) const151 void GetConnectionAttempts(ConnectionAttempts* out) const override {
152 out->clear();
153 for (const auto& addr : addrlist_)
154 out->push_back(ConnectionAttempt(addr, ERR_CONNECTION_FAILED));
155 }
ClearConnectionAttempts()156 void ClearConnectionAttempts() override {}
AddConnectionAttempts(const ConnectionAttempts & attempts)157 void AddConnectionAttempts(const ConnectionAttempts& attempts) override {}
GetTotalReceivedBytes() const158 int64_t GetTotalReceivedBytes() const override {
159 NOTIMPLEMENTED();
160 return 0;
161 }
ApplySocketTag(const SocketTag & tag)162 void ApplySocketTag(const SocketTag& tag) override {}
163
164 // Socket implementation.
Read(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)165 int Read(IOBuffer* buf,
166 int buf_len,
167 CompletionOnceCallback callback) override {
168 return ERR_FAILED;
169 }
170
Write(IOBuffer * buf,int buf_len,CompletionOnceCallback callback,const NetworkTrafficAnnotationTag & traffic_annotation)171 int Write(IOBuffer* buf,
172 int buf_len,
173 CompletionOnceCallback callback,
174 const NetworkTrafficAnnotationTag& traffic_annotation) override {
175 return ERR_FAILED;
176 }
SetReceiveBufferSize(int32_t size)177 int SetReceiveBufferSize(int32_t size) override { return OK; }
SetSendBufferSize(int32_t size)178 int SetSendBufferSize(int32_t size) override { return OK; }
179
180 private:
181 const AddressList addrlist_;
182 NetLogWithSource net_log_;
183
184 DISALLOW_COPY_AND_ASSIGN(MockFailingClientSocket);
185 };
186
187 class MockTriggerableClientSocket : public TransportClientSocket {
188 public:
189 // |should_connect| indicates whether the socket should successfully complete
190 // or fail.
MockTriggerableClientSocket(const AddressList & addrlist,bool should_connect,net::NetLog * net_log)191 MockTriggerableClientSocket(const AddressList& addrlist,
192 bool should_connect,
193 net::NetLog* net_log)
194 : should_connect_(should_connect),
195 is_connected_(false),
196 addrlist_(addrlist),
197 net_log_(NetLogWithSource::Make(net_log, NetLogSourceType::SOCKET)) {}
198
199 // Call this method to get a closure which will trigger the connect callback
200 // when called. The closure can be called even after the socket is deleted; it
201 // will safely do nothing.
GetConnectCallback()202 base::OnceClosure GetConnectCallback() {
203 return base::BindOnce(&MockTriggerableClientSocket::DoCallback,
204 weak_factory_.GetWeakPtr());
205 }
206
MakeMockPendingClientSocket(const AddressList & addrlist,bool should_connect,net::NetLog * net_log)207 static std::unique_ptr<TransportClientSocket> MakeMockPendingClientSocket(
208 const AddressList& addrlist,
209 bool should_connect,
210 net::NetLog* net_log) {
211 std::unique_ptr<MockTriggerableClientSocket> socket(
212 new MockTriggerableClientSocket(addrlist, should_connect, net_log));
213 base::ThreadTaskRunnerHandle::Get()->PostTask(FROM_HERE,
214 socket->GetConnectCallback());
215 return std::move(socket);
216 }
217
MakeMockDelayedClientSocket(const AddressList & addrlist,bool should_connect,const base::TimeDelta & delay,net::NetLog * net_log)218 static std::unique_ptr<TransportClientSocket> MakeMockDelayedClientSocket(
219 const AddressList& addrlist,
220 bool should_connect,
221 const base::TimeDelta& delay,
222 net::NetLog* net_log) {
223 std::unique_ptr<MockTriggerableClientSocket> socket(
224 new MockTriggerableClientSocket(addrlist, should_connect, net_log));
225 base::ThreadTaskRunnerHandle::Get()->PostDelayedTask(
226 FROM_HERE, socket->GetConnectCallback(), delay);
227 return std::move(socket);
228 }
229
MakeMockStalledClientSocket(const AddressList & addrlist,net::NetLog * net_log,bool failing)230 static std::unique_ptr<TransportClientSocket> MakeMockStalledClientSocket(
231 const AddressList& addrlist,
232 net::NetLog* net_log,
233 bool failing) {
234 std::unique_ptr<MockTriggerableClientSocket> socket(
235 new MockTriggerableClientSocket(addrlist, true, net_log));
236 if (failing) {
237 DCHECK_LE(1u, addrlist.size());
238 ConnectionAttempts attempts;
239 attempts.push_back(ConnectionAttempt(addrlist[0], ERR_CONNECTION_FAILED));
240 socket->AddConnectionAttempts(attempts);
241 }
242 return std::move(socket);
243 }
244
245 // TransportClientSocket implementation.
Bind(const net::IPEndPoint & local_addr)246 int Bind(const net::IPEndPoint& local_addr) override {
247 NOTREACHED();
248 return ERR_FAILED;
249 }
250
251 // StreamSocket implementation.
Connect(CompletionOnceCallback callback)252 int Connect(CompletionOnceCallback callback) override {
253 DCHECK(callback_.is_null());
254 callback_ = std::move(callback);
255 return ERR_IO_PENDING;
256 }
257
Disconnect()258 void Disconnect() override {}
259
IsConnected() const260 bool IsConnected() const override { return is_connected_; }
IsConnectedAndIdle() const261 bool IsConnectedAndIdle() const override { return is_connected_; }
GetPeerAddress(IPEndPoint * address) const262 int GetPeerAddress(IPEndPoint* address) const override {
263 *address = addrlist_.front();
264 return OK;
265 }
GetLocalAddress(IPEndPoint * address) const266 int GetLocalAddress(IPEndPoint* address) const override {
267 if (!is_connected_)
268 return ERR_SOCKET_NOT_CONNECTED;
269 if (addrlist_.front().GetFamily() == ADDRESS_FAMILY_IPV4)
270 SetIPv4Address(address);
271 else
272 SetIPv6Address(address);
273 return OK;
274 }
NetLog() const275 const NetLogWithSource& NetLog() const override { return net_log_; }
276
WasEverUsed() const277 bool WasEverUsed() const override { return false; }
WasAlpnNegotiated() const278 bool WasAlpnNegotiated() const override { return false; }
GetNegotiatedProtocol() const279 NextProto GetNegotiatedProtocol() const override { return kProtoUnknown; }
GetSSLInfo(SSLInfo * ssl_info)280 bool GetSSLInfo(SSLInfo* ssl_info) override { return false; }
GetConnectionAttempts(ConnectionAttempts * out) const281 void GetConnectionAttempts(ConnectionAttempts* out) const override {
282 *out = connection_attempts_;
283 }
ClearConnectionAttempts()284 void ClearConnectionAttempts() override { connection_attempts_.clear(); }
AddConnectionAttempts(const ConnectionAttempts & attempts)285 void AddConnectionAttempts(const ConnectionAttempts& attempts) override {
286 connection_attempts_.insert(connection_attempts_.begin(), attempts.begin(),
287 attempts.end());
288 }
GetTotalReceivedBytes() const289 int64_t GetTotalReceivedBytes() const override {
290 NOTIMPLEMENTED();
291 return 0;
292 }
ApplySocketTag(const SocketTag & tag)293 void ApplySocketTag(const SocketTag& tag) override {}
294
295 // Socket implementation.
Read(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)296 int Read(IOBuffer* buf,
297 int buf_len,
298 CompletionOnceCallback callback) override {
299 return ERR_FAILED;
300 }
301
Write(IOBuffer * buf,int buf_len,CompletionOnceCallback callback,const NetworkTrafficAnnotationTag & traffic_annotation)302 int Write(IOBuffer* buf,
303 int buf_len,
304 CompletionOnceCallback callback,
305 const NetworkTrafficAnnotationTag& traffic_annotation) override {
306 return ERR_FAILED;
307 }
SetReceiveBufferSize(int32_t size)308 int SetReceiveBufferSize(int32_t size) override { return OK; }
SetSendBufferSize(int32_t size)309 int SetSendBufferSize(int32_t size) override { return OK; }
310
311 private:
DoCallback()312 void DoCallback() {
313 is_connected_ = should_connect_;
314 std::move(callback_).Run(is_connected_ ? OK : ERR_CONNECTION_FAILED);
315 }
316
317 bool should_connect_;
318 bool is_connected_;
319 const AddressList addrlist_;
320 NetLogWithSource net_log_;
321 CompletionOnceCallback callback_;
322 ConnectionAttempts connection_attempts_;
323
324 base::WeakPtrFactory<MockTriggerableClientSocket> weak_factory_{this};
325
326 DISALLOW_COPY_AND_ASSIGN(MockTriggerableClientSocket);
327 };
328
329 } // namespace
330
// Checks |handle|'s LoadTimingInfo as a reused socket. It expects the following:
// - GetLoadTimingInfo() succeeds.
// - |socket_reused| is set.
// - The socket NetLog id is valid.
// - No connect timing is present.
// - Only connection-related load times are set.
void TestLoadTimingInfoConnectedReused(const ClientSocketHandle& handle) {
  LoadTimingInfo load_timing_info;
  // Only pass true in as |is_reused|, as in general, HttpStream types should
  // have stricter concepts of reuse than socket pools.
  EXPECT_TRUE(handle.GetLoadTimingInfo(true, &load_timing_info));

  EXPECT_TRUE(load_timing_info.socket_reused);
  EXPECT_NE(NetLogSource::kInvalidId, load_timing_info.socket_log_id);

  ExpectConnectTimingHasNoTimes(load_timing_info.connect_timing);
  ExpectLoadTimingHasOnlyConnectionTimes(load_timing_info);
}
343
// Checks |handle|'s LoadTimingInfo as a freshly connected, not-reused socket.
// It expects the following:
// - The handle is not marked reused.
// - GetLoadTimingInfo(false, ...) succeeds with |socket_reused| unset.
// - The socket NetLog id is valid.
// - Connect timing includes DNS times.
// - Only connection-related load times are set.
// Finally it runs TestLoadTimingInfoConnectedReused() on the same handle.
void TestLoadTimingInfoConnectedNotReused(const ClientSocketHandle& handle) {
  EXPECT_FALSE(handle.is_reused());

  LoadTimingInfo load_timing_info;
  EXPECT_TRUE(handle.GetLoadTimingInfo(false, &load_timing_info));

  EXPECT_FALSE(load_timing_info.socket_reused);
  EXPECT_NE(NetLogSource::kInvalidId, load_timing_info.socket_log_id);

  ExpectConnectTimingHasTimes(load_timing_info.connect_timing,
                              CONNECT_TIMING_HAS_DNS_TIMES);
  ExpectLoadTimingHasOnlyConnectionTimes(load_timing_info);

  // Querying the same handle with |is_reused| == true must report it as reused.
  TestLoadTimingInfoConnectedReused(handle);
}
359
SetIPv4Address(IPEndPoint * address)360 void SetIPv4Address(IPEndPoint* address) {
361 *address = IPEndPoint(ParseIP("1.1.1.1"), 80);
362 }
363
SetIPv6Address(IPEndPoint * address)364 void SetIPv6Address(IPEndPoint* address) {
365 *address = IPEndPoint(ParseIP("1:abcd::3:4:ff"), 80);
366 }
367
MockTransportClientSocketFactory(NetLog * net_log)368 MockTransportClientSocketFactory::MockTransportClientSocketFactory(
369 NetLog* net_log)
370 : net_log_(net_log),
371 allocation_count_(0),
372 client_socket_type_(MOCK_CLIENT_SOCKET),
373 client_socket_types_(nullptr),
374 client_socket_index_(0),
375 client_socket_index_max_(0),
376 delay_(base::TimeDelta::FromMilliseconds(
377 ClientSocketPool::kMaxConnectRetryIntervalMs)) {}
378
379 MockTransportClientSocketFactory::~MockTransportClientSocketFactory() = default;
380
381 std::unique_ptr<DatagramClientSocket>
CreateDatagramClientSocket(DatagramSocket::BindType bind_type,NetLog * net_log,const NetLogSource & source)382 MockTransportClientSocketFactory::CreateDatagramClientSocket(
383 DatagramSocket::BindType bind_type,
384 NetLog* net_log,
385 const NetLogSource& source) {
386 NOTREACHED();
387 return std::unique_ptr<DatagramClientSocket>();
388 }
389
390 std::unique_ptr<TransportClientSocket>
CreateTransportClientSocket(const AddressList & addresses,std::unique_ptr<SocketPerformanceWatcher>,NetLog *,const NetLogSource &)391 MockTransportClientSocketFactory::CreateTransportClientSocket(
392 const AddressList& addresses,
393 std::unique_ptr<SocketPerformanceWatcher> /* socket_performance_watcher */,
394 NetLog* /* net_log */,
395 const NetLogSource& /* source */) {
396 allocation_count_++;
397
398 ClientSocketType type = client_socket_type_;
399 if (client_socket_types_ && client_socket_index_ < client_socket_index_max_) {
400 type = client_socket_types_[client_socket_index_++];
401 }
402
403 switch (type) {
404 case MOCK_CLIENT_SOCKET:
405 return std::make_unique<MockConnectClientSocket>(addresses, net_log_);
406 case MOCK_FAILING_CLIENT_SOCKET:
407 return std::make_unique<MockFailingClientSocket>(addresses, net_log_);
408 case MOCK_PENDING_CLIENT_SOCKET:
409 return MockTriggerableClientSocket::MakeMockPendingClientSocket(
410 addresses, true, net_log_);
411 case MOCK_PENDING_FAILING_CLIENT_SOCKET:
412 return MockTriggerableClientSocket::MakeMockPendingClientSocket(
413 addresses, false, net_log_);
414 case MOCK_DELAYED_CLIENT_SOCKET:
415 return MockTriggerableClientSocket::MakeMockDelayedClientSocket(
416 addresses, true, delay_, net_log_);
417 case MOCK_DELAYED_FAILING_CLIENT_SOCKET:
418 return MockTriggerableClientSocket::MakeMockDelayedClientSocket(
419 addresses, false, delay_, net_log_);
420 case MOCK_STALLED_CLIENT_SOCKET:
421 return MockTriggerableClientSocket::MakeMockStalledClientSocket(
422 addresses, net_log_, false);
423 case MOCK_STALLED_FAILING_CLIENT_SOCKET:
424 return MockTriggerableClientSocket::MakeMockStalledClientSocket(
425 addresses, net_log_, true);
426 case MOCK_TRIGGERABLE_CLIENT_SOCKET: {
427 std::unique_ptr<MockTriggerableClientSocket> rv(
428 new MockTriggerableClientSocket(addresses, true, net_log_));
429 triggerable_sockets_.push(rv->GetConnectCallback());
430 // run_loop_quit_closure_ behaves like a condition variable. It will
431 // wake up WaitForTriggerableSocketCreation() if it is sleeping. We
432 // don't need to worry about atomicity because this code is
433 // single-threaded.
434 if (!run_loop_quit_closure_.is_null())
435 std::move(run_loop_quit_closure_).Run();
436 return std::move(rv);
437 }
438 default:
439 NOTREACHED();
440 return std::make_unique<MockConnectClientSocket>(addresses, net_log_);
441 }
442 }
443
444 std::unique_ptr<SSLClientSocket>
CreateSSLClientSocket(SSLClientContext * context,std::unique_ptr<StreamSocket> stream_socket,const HostPortPair & host_and_port,const SSLConfig & ssl_config)445 MockTransportClientSocketFactory::CreateSSLClientSocket(
446 SSLClientContext* context,
447 std::unique_ptr<StreamSocket> stream_socket,
448 const HostPortPair& host_and_port,
449 const SSLConfig& ssl_config) {
450 NOTIMPLEMENTED();
451 return nullptr;
452 }
453
454 std::unique_ptr<ProxyClientSocket>
CreateProxyClientSocket(std::unique_ptr<StreamSocket> stream_socket,const std::string & user_agent,const HostPortPair & endpoint,const ProxyServer & proxy_server,HttpAuthController * http_auth_controller,bool tunnel,bool using_spdy,NextProto negotiated_protocol,ProxyDelegate * proxy_delegate,const NetworkTrafficAnnotationTag & traffic_annotation)455 MockTransportClientSocketFactory::CreateProxyClientSocket(
456 std::unique_ptr<StreamSocket> stream_socket,
457 const std::string& user_agent,
458 const HostPortPair& endpoint,
459 const ProxyServer& proxy_server,
460 HttpAuthController* http_auth_controller,
461 bool tunnel,
462 bool using_spdy,
463 NextProto negotiated_protocol,
464 ProxyDelegate* proxy_delegate,
465 const NetworkTrafficAnnotationTag& traffic_annotation) {
466 NOTIMPLEMENTED();
467 return nullptr;
468 }
469
set_client_socket_types(ClientSocketType * type_list,int num_types)470 void MockTransportClientSocketFactory::set_client_socket_types(
471 ClientSocketType* type_list,
472 int num_types) {
473 DCHECK_GT(num_types, 0);
474 client_socket_types_ = type_list;
475 client_socket_index_ = 0;
476 client_socket_index_max_ = num_types;
477 }
478
479 base::OnceClosure
WaitForTriggerableSocketCreation()480 MockTransportClientSocketFactory::WaitForTriggerableSocketCreation() {
481 while (triggerable_sockets_.empty()) {
482 base::RunLoop run_loop;
483 run_loop_quit_closure_ = run_loop.QuitClosure();
484 run_loop.Run();
485 run_loop_quit_closure_.Reset();
486 }
487 base::OnceClosure trigger = std::move(triggerable_sockets_.front());
488 triggerable_sockets_.pop();
489 return trigger;
490 }
491
492 } // namespace net
493