1 // Copyright 2014 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "net/socket/transport_client_socket_pool_test_util.h"
6 
#include <stdint.h>

#include <memory>
#include <string>
#include <utility>
10 
11 #include "base/bind.h"
12 #include "base/location.h"
13 #include "base/logging.h"
14 #include "base/macros.h"
15 #include "base/memory/weak_ptr.h"
16 #include "base/run_loop.h"
17 #include "base/single_thread_task_runner.h"
18 #include "base/threading/thread_task_runner_handle.h"
19 #include "net/base/ip_address.h"
20 #include "net/base/ip_endpoint.h"
21 #include "net/base/load_timing_info.h"
22 #include "net/base/load_timing_info_test_util.h"
23 #include "net/log/net_log_source.h"
24 #include "net/log/net_log_source_type.h"
25 #include "net/log/net_log_with_source.h"
26 #include "net/socket/client_socket_handle.h"
27 #include "net/socket/datagram_client_socket.h"
28 #include "net/socket/ssl_client_socket.h"
29 #include "net/socket/transport_client_socket.h"
30 #include "net/traffic_annotation/network_traffic_annotation_test_helper.h"
31 #include "testing/gtest/include/gtest/gtest.h"
32 
33 namespace net {
34 
35 namespace {
36 
ParseIP(const std::string & ip)37 IPAddress ParseIP(const std::string& ip) {
38   IPAddress address;
39   CHECK(address.AssignFromIPLiteral(ip));
40   return address;
41 }
42 
43 // A StreamSocket which connects synchronously and successfully.
44 class MockConnectClientSocket : public TransportClientSocket {
45  public:
MockConnectClientSocket(const AddressList & addrlist,net::NetLog * net_log)46   MockConnectClientSocket(const AddressList& addrlist, net::NetLog* net_log)
47       : connected_(false),
48         addrlist_(addrlist),
49         net_log_(NetLogWithSource::Make(net_log, NetLogSourceType::SOCKET)) {}
50 
51   // TransportClientSocket implementation.
Bind(const net::IPEndPoint & local_addr)52   int Bind(const net::IPEndPoint& local_addr) override {
53     NOTREACHED();
54     return ERR_FAILED;
55   }
56   // StreamSocket implementation.
Connect(CompletionOnceCallback callback)57   int Connect(CompletionOnceCallback callback) override {
58     connected_ = true;
59     return OK;
60   }
Disconnect()61   void Disconnect() override { connected_ = false; }
IsConnected() const62   bool IsConnected() const override { return connected_; }
IsConnectedAndIdle() const63   bool IsConnectedAndIdle() const override { return connected_; }
64 
GetPeerAddress(IPEndPoint * address) const65   int GetPeerAddress(IPEndPoint* address) const override {
66     *address = addrlist_.front();
67     return OK;
68   }
GetLocalAddress(IPEndPoint * address) const69   int GetLocalAddress(IPEndPoint* address) const override {
70     if (!connected_)
71       return ERR_SOCKET_NOT_CONNECTED;
72     if (addrlist_.front().GetFamily() == ADDRESS_FAMILY_IPV4)
73       SetIPv4Address(address);
74     else
75       SetIPv6Address(address);
76     return OK;
77   }
NetLog() const78   const NetLogWithSource& NetLog() const override { return net_log_; }
79 
WasEverUsed() const80   bool WasEverUsed() const override { return false; }
WasAlpnNegotiated() const81   bool WasAlpnNegotiated() const override { return false; }
GetNegotiatedProtocol() const82   NextProto GetNegotiatedProtocol() const override { return kProtoUnknown; }
GetSSLInfo(SSLInfo * ssl_info)83   bool GetSSLInfo(SSLInfo* ssl_info) override { return false; }
GetConnectionAttempts(ConnectionAttempts * out) const84   void GetConnectionAttempts(ConnectionAttempts* out) const override {
85     out->clear();
86   }
ClearConnectionAttempts()87   void ClearConnectionAttempts() override {}
AddConnectionAttempts(const ConnectionAttempts & attempts)88   void AddConnectionAttempts(const ConnectionAttempts& attempts) override {}
GetTotalReceivedBytes() const89   int64_t GetTotalReceivedBytes() const override {
90     NOTIMPLEMENTED();
91     return 0;
92   }
ApplySocketTag(const SocketTag & tag)93   void ApplySocketTag(const SocketTag& tag) override {}
94 
95   // Socket implementation.
Read(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)96   int Read(IOBuffer* buf,
97            int buf_len,
98            CompletionOnceCallback callback) override {
99     return ERR_FAILED;
100   }
Write(IOBuffer * buf,int buf_len,CompletionOnceCallback callback,const NetworkTrafficAnnotationTag & traffic_annotation)101   int Write(IOBuffer* buf,
102             int buf_len,
103             CompletionOnceCallback callback,
104             const NetworkTrafficAnnotationTag& traffic_annotation) override {
105     return ERR_FAILED;
106   }
SetReceiveBufferSize(int32_t size)107   int SetReceiveBufferSize(int32_t size) override { return OK; }
SetSendBufferSize(int32_t size)108   int SetSendBufferSize(int32_t size) override { return OK; }
109 
110  private:
111   bool connected_;
112   const AddressList addrlist_;
113   NetLogWithSource net_log_;
114 
115   DISALLOW_COPY_AND_ASSIGN(MockConnectClientSocket);
116 };
117 
118 class MockFailingClientSocket : public TransportClientSocket {
119  public:
MockFailingClientSocket(const AddressList & addrlist,net::NetLog * net_log)120   MockFailingClientSocket(const AddressList& addrlist, net::NetLog* net_log)
121       : addrlist_(addrlist),
122         net_log_(NetLogWithSource::Make(net_log, NetLogSourceType::SOCKET)) {}
123 
124   // TransportClientSocket implementation.
Bind(const net::IPEndPoint & local_addr)125   int Bind(const net::IPEndPoint& local_addr) override {
126     NOTREACHED();
127     return ERR_FAILED;
128   }
129 
130   // StreamSocket implementation.
Connect(CompletionOnceCallback callback)131   int Connect(CompletionOnceCallback callback) override {
132     return ERR_CONNECTION_FAILED;
133   }
134 
Disconnect()135   void Disconnect() override {}
136 
IsConnected() const137   bool IsConnected() const override { return false; }
IsConnectedAndIdle() const138   bool IsConnectedAndIdle() const override { return false; }
GetPeerAddress(IPEndPoint * address) const139   int GetPeerAddress(IPEndPoint* address) const override {
140     return ERR_UNEXPECTED;
141   }
GetLocalAddress(IPEndPoint * address) const142   int GetLocalAddress(IPEndPoint* address) const override {
143     return ERR_UNEXPECTED;
144   }
NetLog() const145   const NetLogWithSource& NetLog() const override { return net_log_; }
146 
WasEverUsed() const147   bool WasEverUsed() const override { return false; }
WasAlpnNegotiated() const148   bool WasAlpnNegotiated() const override { return false; }
GetNegotiatedProtocol() const149   NextProto GetNegotiatedProtocol() const override { return kProtoUnknown; }
GetSSLInfo(SSLInfo * ssl_info)150   bool GetSSLInfo(SSLInfo* ssl_info) override { return false; }
GetConnectionAttempts(ConnectionAttempts * out) const151   void GetConnectionAttempts(ConnectionAttempts* out) const override {
152     out->clear();
153     for (const auto& addr : addrlist_)
154       out->push_back(ConnectionAttempt(addr, ERR_CONNECTION_FAILED));
155   }
ClearConnectionAttempts()156   void ClearConnectionAttempts() override {}
AddConnectionAttempts(const ConnectionAttempts & attempts)157   void AddConnectionAttempts(const ConnectionAttempts& attempts) override {}
GetTotalReceivedBytes() const158   int64_t GetTotalReceivedBytes() const override {
159     NOTIMPLEMENTED();
160     return 0;
161   }
ApplySocketTag(const SocketTag & tag)162   void ApplySocketTag(const SocketTag& tag) override {}
163 
164   // Socket implementation.
Read(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)165   int Read(IOBuffer* buf,
166            int buf_len,
167            CompletionOnceCallback callback) override {
168     return ERR_FAILED;
169   }
170 
Write(IOBuffer * buf,int buf_len,CompletionOnceCallback callback,const NetworkTrafficAnnotationTag & traffic_annotation)171   int Write(IOBuffer* buf,
172             int buf_len,
173             CompletionOnceCallback callback,
174             const NetworkTrafficAnnotationTag& traffic_annotation) override {
175     return ERR_FAILED;
176   }
SetReceiveBufferSize(int32_t size)177   int SetReceiveBufferSize(int32_t size) override { return OK; }
SetSendBufferSize(int32_t size)178   int SetSendBufferSize(int32_t size) override { return OK; }
179 
180  private:
181   const AddressList addrlist_;
182   NetLogWithSource net_log_;
183 
184   DISALLOW_COPY_AND_ASSIGN(MockFailingClientSocket);
185 };
186 
187 class MockTriggerableClientSocket : public TransportClientSocket {
188  public:
189   // |should_connect| indicates whether the socket should successfully complete
190   // or fail.
MockTriggerableClientSocket(const AddressList & addrlist,bool should_connect,net::NetLog * net_log)191   MockTriggerableClientSocket(const AddressList& addrlist,
192                               bool should_connect,
193                               net::NetLog* net_log)
194       : should_connect_(should_connect),
195         is_connected_(false),
196         addrlist_(addrlist),
197         net_log_(NetLogWithSource::Make(net_log, NetLogSourceType::SOCKET)) {}
198 
199   // Call this method to get a closure which will trigger the connect callback
200   // when called. The closure can be called even after the socket is deleted; it
201   // will safely do nothing.
GetConnectCallback()202   base::OnceClosure GetConnectCallback() {
203     return base::BindOnce(&MockTriggerableClientSocket::DoCallback,
204                           weak_factory_.GetWeakPtr());
205   }
206 
MakeMockPendingClientSocket(const AddressList & addrlist,bool should_connect,net::NetLog * net_log)207   static std::unique_ptr<TransportClientSocket> MakeMockPendingClientSocket(
208       const AddressList& addrlist,
209       bool should_connect,
210       net::NetLog* net_log) {
211     std::unique_ptr<MockTriggerableClientSocket> socket(
212         new MockTriggerableClientSocket(addrlist, should_connect, net_log));
213     base::ThreadTaskRunnerHandle::Get()->PostTask(FROM_HERE,
214                                                   socket->GetConnectCallback());
215     return std::move(socket);
216   }
217 
MakeMockDelayedClientSocket(const AddressList & addrlist,bool should_connect,const base::TimeDelta & delay,net::NetLog * net_log)218   static std::unique_ptr<TransportClientSocket> MakeMockDelayedClientSocket(
219       const AddressList& addrlist,
220       bool should_connect,
221       const base::TimeDelta& delay,
222       net::NetLog* net_log) {
223     std::unique_ptr<MockTriggerableClientSocket> socket(
224         new MockTriggerableClientSocket(addrlist, should_connect, net_log));
225     base::ThreadTaskRunnerHandle::Get()->PostDelayedTask(
226         FROM_HERE, socket->GetConnectCallback(), delay);
227     return std::move(socket);
228   }
229 
MakeMockStalledClientSocket(const AddressList & addrlist,net::NetLog * net_log,bool failing)230   static std::unique_ptr<TransportClientSocket> MakeMockStalledClientSocket(
231       const AddressList& addrlist,
232       net::NetLog* net_log,
233       bool failing) {
234     std::unique_ptr<MockTriggerableClientSocket> socket(
235         new MockTriggerableClientSocket(addrlist, true, net_log));
236     if (failing) {
237       DCHECK_LE(1u, addrlist.size());
238       ConnectionAttempts attempts;
239       attempts.push_back(ConnectionAttempt(addrlist[0], ERR_CONNECTION_FAILED));
240       socket->AddConnectionAttempts(attempts);
241     }
242     return std::move(socket);
243   }
244 
245   // TransportClientSocket implementation.
Bind(const net::IPEndPoint & local_addr)246   int Bind(const net::IPEndPoint& local_addr) override {
247     NOTREACHED();
248     return ERR_FAILED;
249   }
250 
251   // StreamSocket implementation.
Connect(CompletionOnceCallback callback)252   int Connect(CompletionOnceCallback callback) override {
253     DCHECK(callback_.is_null());
254     callback_ = std::move(callback);
255     return ERR_IO_PENDING;
256   }
257 
Disconnect()258   void Disconnect() override {}
259 
IsConnected() const260   bool IsConnected() const override { return is_connected_; }
IsConnectedAndIdle() const261   bool IsConnectedAndIdle() const override { return is_connected_; }
GetPeerAddress(IPEndPoint * address) const262   int GetPeerAddress(IPEndPoint* address) const override {
263     *address = addrlist_.front();
264     return OK;
265   }
GetLocalAddress(IPEndPoint * address) const266   int GetLocalAddress(IPEndPoint* address) const override {
267     if (!is_connected_)
268       return ERR_SOCKET_NOT_CONNECTED;
269     if (addrlist_.front().GetFamily() == ADDRESS_FAMILY_IPV4)
270       SetIPv4Address(address);
271     else
272       SetIPv6Address(address);
273     return OK;
274   }
NetLog() const275   const NetLogWithSource& NetLog() const override { return net_log_; }
276 
WasEverUsed() const277   bool WasEverUsed() const override { return false; }
WasAlpnNegotiated() const278   bool WasAlpnNegotiated() const override { return false; }
GetNegotiatedProtocol() const279   NextProto GetNegotiatedProtocol() const override { return kProtoUnknown; }
GetSSLInfo(SSLInfo * ssl_info)280   bool GetSSLInfo(SSLInfo* ssl_info) override { return false; }
GetConnectionAttempts(ConnectionAttempts * out) const281   void GetConnectionAttempts(ConnectionAttempts* out) const override {
282     *out = connection_attempts_;
283   }
ClearConnectionAttempts()284   void ClearConnectionAttempts() override { connection_attempts_.clear(); }
AddConnectionAttempts(const ConnectionAttempts & attempts)285   void AddConnectionAttempts(const ConnectionAttempts& attempts) override {
286     connection_attempts_.insert(connection_attempts_.begin(), attempts.begin(),
287                                 attempts.end());
288   }
GetTotalReceivedBytes() const289   int64_t GetTotalReceivedBytes() const override {
290     NOTIMPLEMENTED();
291     return 0;
292   }
ApplySocketTag(const SocketTag & tag)293   void ApplySocketTag(const SocketTag& tag) override {}
294 
295   // Socket implementation.
Read(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)296   int Read(IOBuffer* buf,
297            int buf_len,
298            CompletionOnceCallback callback) override {
299     return ERR_FAILED;
300   }
301 
Write(IOBuffer * buf,int buf_len,CompletionOnceCallback callback,const NetworkTrafficAnnotationTag & traffic_annotation)302   int Write(IOBuffer* buf,
303             int buf_len,
304             CompletionOnceCallback callback,
305             const NetworkTrafficAnnotationTag& traffic_annotation) override {
306     return ERR_FAILED;
307   }
SetReceiveBufferSize(int32_t size)308   int SetReceiveBufferSize(int32_t size) override { return OK; }
SetSendBufferSize(int32_t size)309   int SetSendBufferSize(int32_t size) override { return OK; }
310 
311  private:
DoCallback()312   void DoCallback() {
313     is_connected_ = should_connect_;
314     std::move(callback_).Run(is_connected_ ? OK : ERR_CONNECTION_FAILED);
315   }
316 
317   bool should_connect_;
318   bool is_connected_;
319   const AddressList addrlist_;
320   NetLogWithSource net_log_;
321   CompletionOnceCallback callback_;
322   ConnectionAttempts connection_attempts_;
323 
324   base::WeakPtrFactory<MockTriggerableClientSocket> weak_factory_{this};
325 
326   DISALLOW_COPY_AND_ASSIGN(MockTriggerableClientSocket);
327 };
328 
329 }  // namespace
330 
TestLoadTimingInfoConnectedReused(const ClientSocketHandle & handle)331 void TestLoadTimingInfoConnectedReused(const ClientSocketHandle& handle) {
332   LoadTimingInfo load_timing_info;
333   // Only pass true in as |is_reused|, as in general, HttpStream types should
334   // have stricter concepts of reuse than socket pools.
335   EXPECT_TRUE(handle.GetLoadTimingInfo(true, &load_timing_info));
336 
337   EXPECT_TRUE(load_timing_info.socket_reused);
338   EXPECT_NE(NetLogSource::kInvalidId, load_timing_info.socket_log_id);
339 
340   ExpectConnectTimingHasNoTimes(load_timing_info.connect_timing);
341   ExpectLoadTimingHasOnlyConnectionTimes(load_timing_info);
342 }
343 
TestLoadTimingInfoConnectedNotReused(const ClientSocketHandle & handle)344 void TestLoadTimingInfoConnectedNotReused(const ClientSocketHandle& handle) {
345   EXPECT_FALSE(handle.is_reused());
346 
347   LoadTimingInfo load_timing_info;
348   EXPECT_TRUE(handle.GetLoadTimingInfo(false, &load_timing_info));
349 
350   EXPECT_FALSE(load_timing_info.socket_reused);
351   EXPECT_NE(NetLogSource::kInvalidId, load_timing_info.socket_log_id);
352 
353   ExpectConnectTimingHasTimes(load_timing_info.connect_timing,
354                               CONNECT_TIMING_HAS_DNS_TIMES);
355   ExpectLoadTimingHasOnlyConnectionTimes(load_timing_info);
356 
357   TestLoadTimingInfoConnectedReused(handle);
358 }
359 
SetIPv4Address(IPEndPoint * address)360 void SetIPv4Address(IPEndPoint* address) {
361   *address = IPEndPoint(ParseIP("1.1.1.1"), 80);
362 }
363 
SetIPv6Address(IPEndPoint * address)364 void SetIPv6Address(IPEndPoint* address) {
365   *address = IPEndPoint(ParseIP("1:abcd::3:4:ff"), 80);
366 }
367 
MockTransportClientSocketFactory(NetLog * net_log)368 MockTransportClientSocketFactory::MockTransportClientSocketFactory(
369     NetLog* net_log)
370     : net_log_(net_log),
371       allocation_count_(0),
372       client_socket_type_(MOCK_CLIENT_SOCKET),
373       client_socket_types_(nullptr),
374       client_socket_index_(0),
375       client_socket_index_max_(0),
376       delay_(base::TimeDelta::FromMilliseconds(
377           ClientSocketPool::kMaxConnectRetryIntervalMs)) {}
378 
379 MockTransportClientSocketFactory::~MockTransportClientSocketFactory() = default;
380 
381 std::unique_ptr<DatagramClientSocket>
CreateDatagramClientSocket(DatagramSocket::BindType bind_type,NetLog * net_log,const NetLogSource & source)382 MockTransportClientSocketFactory::CreateDatagramClientSocket(
383     DatagramSocket::BindType bind_type,
384     NetLog* net_log,
385     const NetLogSource& source) {
386   NOTREACHED();
387   return std::unique_ptr<DatagramClientSocket>();
388 }
389 
390 std::unique_ptr<TransportClientSocket>
CreateTransportClientSocket(const AddressList & addresses,std::unique_ptr<SocketPerformanceWatcher>,NetLog *,const NetLogSource &)391 MockTransportClientSocketFactory::CreateTransportClientSocket(
392     const AddressList& addresses,
393     std::unique_ptr<SocketPerformanceWatcher> /* socket_performance_watcher */,
394     NetLog* /* net_log */,
395     const NetLogSource& /* source */) {
396   allocation_count_++;
397 
398   ClientSocketType type = client_socket_type_;
399   if (client_socket_types_ && client_socket_index_ < client_socket_index_max_) {
400     type = client_socket_types_[client_socket_index_++];
401   }
402 
403   switch (type) {
404     case MOCK_CLIENT_SOCKET:
405       return std::make_unique<MockConnectClientSocket>(addresses, net_log_);
406     case MOCK_FAILING_CLIENT_SOCKET:
407       return std::make_unique<MockFailingClientSocket>(addresses, net_log_);
408     case MOCK_PENDING_CLIENT_SOCKET:
409       return MockTriggerableClientSocket::MakeMockPendingClientSocket(
410           addresses, true, net_log_);
411     case MOCK_PENDING_FAILING_CLIENT_SOCKET:
412       return MockTriggerableClientSocket::MakeMockPendingClientSocket(
413           addresses, false, net_log_);
414     case MOCK_DELAYED_CLIENT_SOCKET:
415       return MockTriggerableClientSocket::MakeMockDelayedClientSocket(
416           addresses, true, delay_, net_log_);
417     case MOCK_DELAYED_FAILING_CLIENT_SOCKET:
418       return MockTriggerableClientSocket::MakeMockDelayedClientSocket(
419           addresses, false, delay_, net_log_);
420     case MOCK_STALLED_CLIENT_SOCKET:
421       return MockTriggerableClientSocket::MakeMockStalledClientSocket(
422           addresses, net_log_, false);
423     case MOCK_STALLED_FAILING_CLIENT_SOCKET:
424       return MockTriggerableClientSocket::MakeMockStalledClientSocket(
425           addresses, net_log_, true);
426     case MOCK_TRIGGERABLE_CLIENT_SOCKET: {
427       std::unique_ptr<MockTriggerableClientSocket> rv(
428           new MockTriggerableClientSocket(addresses, true, net_log_));
429       triggerable_sockets_.push(rv->GetConnectCallback());
430       // run_loop_quit_closure_ behaves like a condition variable. It will
431       // wake up WaitForTriggerableSocketCreation() if it is sleeping. We
432       // don't need to worry about atomicity because this code is
433       // single-threaded.
434       if (!run_loop_quit_closure_.is_null())
435         std::move(run_loop_quit_closure_).Run();
436       return std::move(rv);
437     }
438     default:
439       NOTREACHED();
440       return std::make_unique<MockConnectClientSocket>(addresses, net_log_);
441   }
442 }
443 
444 std::unique_ptr<SSLClientSocket>
CreateSSLClientSocket(SSLClientContext * context,std::unique_ptr<StreamSocket> stream_socket,const HostPortPair & host_and_port,const SSLConfig & ssl_config)445 MockTransportClientSocketFactory::CreateSSLClientSocket(
446     SSLClientContext* context,
447     std::unique_ptr<StreamSocket> stream_socket,
448     const HostPortPair& host_and_port,
449     const SSLConfig& ssl_config) {
450   NOTIMPLEMENTED();
451   return nullptr;
452 }
453 
454 std::unique_ptr<ProxyClientSocket>
CreateProxyClientSocket(std::unique_ptr<StreamSocket> stream_socket,const std::string & user_agent,const HostPortPair & endpoint,const ProxyServer & proxy_server,HttpAuthController * http_auth_controller,bool tunnel,bool using_spdy,NextProto negotiated_protocol,ProxyDelegate * proxy_delegate,const NetworkTrafficAnnotationTag & traffic_annotation)455 MockTransportClientSocketFactory::CreateProxyClientSocket(
456     std::unique_ptr<StreamSocket> stream_socket,
457     const std::string& user_agent,
458     const HostPortPair& endpoint,
459     const ProxyServer& proxy_server,
460     HttpAuthController* http_auth_controller,
461     bool tunnel,
462     bool using_spdy,
463     NextProto negotiated_protocol,
464     ProxyDelegate* proxy_delegate,
465     const NetworkTrafficAnnotationTag& traffic_annotation) {
466   NOTIMPLEMENTED();
467   return nullptr;
468 }
469 
set_client_socket_types(ClientSocketType * type_list,int num_types)470 void MockTransportClientSocketFactory::set_client_socket_types(
471     ClientSocketType* type_list,
472     int num_types) {
473   DCHECK_GT(num_types, 0);
474   client_socket_types_ = type_list;
475   client_socket_index_ = 0;
476   client_socket_index_max_ = num_types;
477 }
478 
479 base::OnceClosure
WaitForTriggerableSocketCreation()480 MockTransportClientSocketFactory::WaitForTriggerableSocketCreation() {
481   while (triggerable_sockets_.empty()) {
482     base::RunLoop run_loop;
483     run_loop_quit_closure_ = run_loop.QuitClosure();
484     run_loop.Run();
485     run_loop_quit_closure_.Reset();
486   }
487   base::OnceClosure trigger = std::move(triggerable_sockets_.front());
488   triggerable_sockets_.pop();
489   return trigger;
490 }
491 
492 }  // namespace net
493