1 // Copyright 2014 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "net/socket/websocket_transport_connect_sub_job.h"
6 
7 #include "base/bind.h"
8 #include "base/check_op.h"
9 #include "base/notreached.h"
10 #include "net/base/ip_endpoint.h"
11 #include "net/base/net_errors.h"
12 #include "net/log/net_log_with_source.h"
13 #include "net/socket/client_socket_factory.h"
14 #include "net/socket/websocket_endpoint_lock_manager.h"
15 
16 namespace net {
17 
18 namespace {
19 
20 // StreamSocket wrapper that registers/unregisters the wrapped StreamSocket with
21 // a WebSocketEndpointLockManager on creation/destruction.
22 class WebSocketStreamSocket final : public StreamSocket {
23  public:
WebSocketStreamSocket(std::unique_ptr<StreamSocket> wrapped_socket,WebSocketEndpointLockManager * websocket_endpoint_lock_manager,const IPEndPoint & address)24   WebSocketStreamSocket(
25       std::unique_ptr<StreamSocket> wrapped_socket,
26       WebSocketEndpointLockManager* websocket_endpoint_lock_manager,
27       const IPEndPoint& address)
28       : wrapped_socket_(std::move(wrapped_socket)),
29         lock_releaser_(websocket_endpoint_lock_manager, address) {}
30 
31   ~WebSocketStreamSocket() override = default;
32 
33   // Socket implementation:
Read(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)34   int Read(IOBuffer* buf,
35            int buf_len,
36            CompletionOnceCallback callback) override {
37     return wrapped_socket_->Read(buf, buf_len, std::move(callback));
38   }
ReadIfReady(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)39   int ReadIfReady(IOBuffer* buf,
40                   int buf_len,
41                   CompletionOnceCallback callback) override {
42     return wrapped_socket_->ReadIfReady(buf, buf_len, std::move(callback));
43   }
CancelReadIfReady()44   int CancelReadIfReady() override {
45     return wrapped_socket_->CancelReadIfReady();
46   }
Write(IOBuffer * buf,int buf_len,CompletionOnceCallback callback,const NetworkTrafficAnnotationTag & traffic_annotation)47   int Write(IOBuffer* buf,
48             int buf_len,
49             CompletionOnceCallback callback,
50             const NetworkTrafficAnnotationTag& traffic_annotation) override {
51     return wrapped_socket_->Write(buf, buf_len, std::move(callback),
52                                   traffic_annotation);
53   }
SetReceiveBufferSize(int32_t size)54   int SetReceiveBufferSize(int32_t size) override {
55     return wrapped_socket_->SetReceiveBufferSize(size);
56   }
SetSendBufferSize(int32_t size)57   int SetSendBufferSize(int32_t size) override {
58     return wrapped_socket_->SetSendBufferSize(size);
59   }
60 
61   // StreamSocket implementation:
Connect(CompletionOnceCallback callback)62   int Connect(CompletionOnceCallback callback) override {
63     return wrapped_socket_->Connect(std::move(callback));
64   }
Disconnect()65   void Disconnect() override { wrapped_socket_->Disconnect(); }
IsConnected() const66   bool IsConnected() const override { return wrapped_socket_->IsConnected(); }
IsConnectedAndIdle() const67   bool IsConnectedAndIdle() const override {
68     return wrapped_socket_->IsConnectedAndIdle();
69   }
GetPeerAddress(IPEndPoint * address) const70   int GetPeerAddress(IPEndPoint* address) const override {
71     return wrapped_socket_->GetPeerAddress(address);
72   }
GetLocalAddress(IPEndPoint * address) const73   int GetLocalAddress(IPEndPoint* address) const override {
74     return wrapped_socket_->GetLocalAddress(address);
75   }
NetLog() const76   const NetLogWithSource& NetLog() const override {
77     return wrapped_socket_->NetLog();
78   }
WasEverUsed() const79   bool WasEverUsed() const override { return wrapped_socket_->WasEverUsed(); }
WasAlpnNegotiated() const80   bool WasAlpnNegotiated() const override {
81     return wrapped_socket_->WasAlpnNegotiated();
82   }
GetNegotiatedProtocol() const83   NextProto GetNegotiatedProtocol() const override {
84     return wrapped_socket_->GetNegotiatedProtocol();
85   }
GetSSLInfo(SSLInfo * ssl_info)86   bool GetSSLInfo(SSLInfo* ssl_info) override {
87     return wrapped_socket_->GetSSLInfo(ssl_info);
88   }
GetConnectionAttempts(ConnectionAttempts * out) const89   void GetConnectionAttempts(ConnectionAttempts* out) const override {
90     wrapped_socket_->GetConnectionAttempts(out);
91   }
ClearConnectionAttempts()92   void ClearConnectionAttempts() override {
93     wrapped_socket_->ClearConnectionAttempts();
94   }
AddConnectionAttempts(const ConnectionAttempts & attempts)95   void AddConnectionAttempts(const ConnectionAttempts& attempts) override {
96     wrapped_socket_->AddConnectionAttempts(attempts);
97   }
GetTotalReceivedBytes() const98   int64_t GetTotalReceivedBytes() const override {
99     return wrapped_socket_->GetTotalReceivedBytes();
100   }
DumpMemoryStats(SocketMemoryStats * stats) const101   void DumpMemoryStats(SocketMemoryStats* stats) const override {
102     wrapped_socket_->DumpMemoryStats(stats);
103   }
ApplySocketTag(const SocketTag & tag)104   void ApplySocketTag(const SocketTag& tag) override {
105     wrapped_socket_->ApplySocketTag(tag);
106   }
107 
108  private:
109   std::unique_ptr<StreamSocket> wrapped_socket_;
110   WebSocketEndpointLockManager::LockReleaser lock_releaser_;
111 
112   DISALLOW_COPY_AND_ASSIGN(WebSocketStreamSocket);
113 };
114 
115 }  // namespace
116 
WebSocketTransportConnectSubJob(const AddressList & addresses,WebSocketTransportConnectJob * parent_job,SubJobType type,WebSocketEndpointLockManager * websocket_endpoint_lock_manager)117 WebSocketTransportConnectSubJob::WebSocketTransportConnectSubJob(
118     const AddressList& addresses,
119     WebSocketTransportConnectJob* parent_job,
120     SubJobType type,
121     WebSocketEndpointLockManager* websocket_endpoint_lock_manager)
122     : parent_job_(parent_job),
123       addresses_(addresses),
124       current_address_index_(0),
125       next_state_(STATE_NONE),
126       type_(type),
127       websocket_endpoint_lock_manager_(websocket_endpoint_lock_manager) {}
128 
~WebSocketTransportConnectSubJob()129 WebSocketTransportConnectSubJob::~WebSocketTransportConnectSubJob() {
130   // We don't worry about cancelling the TCP connect, since ~StreamSocket will
131   // take care of it.
132   if (next()) {
133     DCHECK_EQ(STATE_OBTAIN_LOCK_COMPLETE, next_state_);
134     // The ~Waiter destructor will remove this object from the waiting list.
135   } else if (next_state_ == STATE_TRANSPORT_CONNECT_COMPLETE) {
136     websocket_endpoint_lock_manager_->UnlockEndpoint(CurrentAddress());
137   }
138 }
139 
140 // Start connecting.
Start()141 int WebSocketTransportConnectSubJob::Start() {
142   DCHECK_EQ(STATE_NONE, next_state_);
143   next_state_ = STATE_OBTAIN_LOCK;
144   return DoLoop(OK);
145 }
146 
147 // Called by WebSocketEndpointLockManager when the lock becomes available.
GotEndpointLock()148 void WebSocketTransportConnectSubJob::GotEndpointLock() {
149   DCHECK_EQ(STATE_OBTAIN_LOCK_COMPLETE, next_state_);
150   OnIOComplete(OK);
151 }
152 
GetLoadState() const153 LoadState WebSocketTransportConnectSubJob::GetLoadState() const {
154   switch (next_state_) {
155     case STATE_OBTAIN_LOCK:
156     case STATE_OBTAIN_LOCK_COMPLETE:
157       // TODO(ricea): Add a WebSocket-specific LOAD_STATE ?
158       return LOAD_STATE_WAITING_FOR_AVAILABLE_SOCKET;
159     case STATE_TRANSPORT_CONNECT:
160     case STATE_TRANSPORT_CONNECT_COMPLETE:
161     case STATE_DONE:
162       return LOAD_STATE_CONNECTING;
163     case STATE_NONE:
164       return LOAD_STATE_IDLE;
165   }
166   NOTREACHED();
167   return LOAD_STATE_IDLE;
168 }
169 
client_socket_factory() const170 ClientSocketFactory* WebSocketTransportConnectSubJob::client_socket_factory()
171     const {
172   return parent_job_->client_socket_factory();
173 }
174 
net_log() const175 const NetLogWithSource& WebSocketTransportConnectSubJob::net_log() const {
176   return parent_job_->net_log();
177 }
178 
CurrentAddress() const179 const IPEndPoint& WebSocketTransportConnectSubJob::CurrentAddress() const {
180   DCHECK_LT(current_address_index_, addresses_.size());
181   return addresses_[current_address_index_];
182 }
183 
OnIOComplete(int result)184 void WebSocketTransportConnectSubJob::OnIOComplete(int result) {
185   int rv = DoLoop(result);
186   if (rv != ERR_IO_PENDING)
187     parent_job_->OnSubJobComplete(rv, this);  // |this| deleted
188 }
189 
DoLoop(int result)190 int WebSocketTransportConnectSubJob::DoLoop(int result) {
191   DCHECK_NE(next_state_, STATE_NONE);
192 
193   int rv = result;
194   do {
195     State state = next_state_;
196     next_state_ = STATE_NONE;
197     switch (state) {
198       case STATE_OBTAIN_LOCK:
199         DCHECK_EQ(OK, rv);
200         rv = DoEndpointLock();
201         break;
202       case STATE_OBTAIN_LOCK_COMPLETE:
203         DCHECK_EQ(OK, rv);
204         rv = DoEndpointLockComplete();
205         break;
206       case STATE_TRANSPORT_CONNECT:
207         DCHECK_EQ(OK, rv);
208         rv = DoTransportConnect();
209         break;
210       case STATE_TRANSPORT_CONNECT_COMPLETE:
211         rv = DoTransportConnectComplete(rv);
212         break;
213       default:
214         NOTREACHED();
215         rv = ERR_FAILED;
216         break;
217     }
218   } while (rv != ERR_IO_PENDING && next_state_ != STATE_NONE &&
219            next_state_ != STATE_DONE);
220 
221   return rv;
222 }
223 
DoEndpointLock()224 int WebSocketTransportConnectSubJob::DoEndpointLock() {
225   int rv =
226       websocket_endpoint_lock_manager_->LockEndpoint(CurrentAddress(), this);
227   next_state_ = STATE_OBTAIN_LOCK_COMPLETE;
228   return rv;
229 }
230 
DoEndpointLockComplete()231 int WebSocketTransportConnectSubJob::DoEndpointLockComplete() {
232   next_state_ = STATE_TRANSPORT_CONNECT;
233   return OK;
234 }
235 
DoTransportConnect()236 int WebSocketTransportConnectSubJob::DoTransportConnect() {
237   // TODO(ricea): Update global g_last_connect_time and report
238   // ConnectInterval.
239   next_state_ = STATE_TRANSPORT_CONNECT_COMPLETE;
240   AddressList one_address(CurrentAddress());
241   // TODO(https://crbug.com/1123197): Pass a non-null NetworkQualityEstimator.
242   NetworkQualityEstimator* network_quality_estimator = nullptr;
243 
244   transport_socket_ = client_socket_factory()->CreateTransportClientSocket(
245       one_address, nullptr, network_quality_estimator, net_log().net_log(),
246       net_log().source());
247   // This use of base::Unretained() is safe because transport_socket_ is
248   // destroyed in the destructor.
249   return transport_socket_->Connect(base::BindOnce(
250       &WebSocketTransportConnectSubJob::OnIOComplete, base::Unretained(this)));
251 }
252 
DoTransportConnectComplete(int result)253 int WebSocketTransportConnectSubJob::DoTransportConnectComplete(int result) {
254   next_state_ = STATE_DONE;
255   if (result != OK) {
256     websocket_endpoint_lock_manager_->UnlockEndpoint(CurrentAddress());
257 
258     if (current_address_index_ + 1 < addresses_.size()) {
259       // Try falling back to the next address in the list.
260       next_state_ = STATE_OBTAIN_LOCK;
261       ++current_address_index_;
262       result = OK;
263     }
264 
265     return result;
266   }
267 
268   // On success, need to register the socket with the
269   // WebSocketEndpointLockManager.
270   transport_socket_ = std::make_unique<WebSocketStreamSocket>(
271       std::move(transport_socket_), websocket_endpoint_lock_manager_,
272       CurrentAddress());
273 
274   return result;
275 }
276 
277 }  // namespace net
278