1 // Copyright 2014 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "net/socket/websocket_transport_connect_sub_job.h"
6
7 #include "base/bind.h"
8 #include "base/check_op.h"
9 #include "base/notreached.h"
10 #include "net/base/ip_endpoint.h"
11 #include "net/base/net_errors.h"
12 #include "net/log/net_log_with_source.h"
13 #include "net/socket/client_socket_factory.h"
14 #include "net/socket/websocket_endpoint_lock_manager.h"
15
16 namespace net {
17
18 namespace {
19
20 // StreamSocket wrapper that registers/unregisters the wrapped StreamSocket with
21 // a WebSocketEndpointLockManager on creation/destruction.
22 class WebSocketStreamSocket final : public StreamSocket {
23 public:
WebSocketStreamSocket(std::unique_ptr<StreamSocket> wrapped_socket,WebSocketEndpointLockManager * websocket_endpoint_lock_manager,const IPEndPoint & address)24 WebSocketStreamSocket(
25 std::unique_ptr<StreamSocket> wrapped_socket,
26 WebSocketEndpointLockManager* websocket_endpoint_lock_manager,
27 const IPEndPoint& address)
28 : wrapped_socket_(std::move(wrapped_socket)),
29 lock_releaser_(websocket_endpoint_lock_manager, address) {}
30
31 ~WebSocketStreamSocket() override = default;
32
33 // Socket implementation:
Read(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)34 int Read(IOBuffer* buf,
35 int buf_len,
36 CompletionOnceCallback callback) override {
37 return wrapped_socket_->Read(buf, buf_len, std::move(callback));
38 }
ReadIfReady(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)39 int ReadIfReady(IOBuffer* buf,
40 int buf_len,
41 CompletionOnceCallback callback) override {
42 return wrapped_socket_->ReadIfReady(buf, buf_len, std::move(callback));
43 }
CancelReadIfReady()44 int CancelReadIfReady() override {
45 return wrapped_socket_->CancelReadIfReady();
46 }
Write(IOBuffer * buf,int buf_len,CompletionOnceCallback callback,const NetworkTrafficAnnotationTag & traffic_annotation)47 int Write(IOBuffer* buf,
48 int buf_len,
49 CompletionOnceCallback callback,
50 const NetworkTrafficAnnotationTag& traffic_annotation) override {
51 return wrapped_socket_->Write(buf, buf_len, std::move(callback),
52 traffic_annotation);
53 }
SetReceiveBufferSize(int32_t size)54 int SetReceiveBufferSize(int32_t size) override {
55 return wrapped_socket_->SetReceiveBufferSize(size);
56 }
SetSendBufferSize(int32_t size)57 int SetSendBufferSize(int32_t size) override {
58 return wrapped_socket_->SetSendBufferSize(size);
59 }
60
61 // StreamSocket implementation:
Connect(CompletionOnceCallback callback)62 int Connect(CompletionOnceCallback callback) override {
63 return wrapped_socket_->Connect(std::move(callback));
64 }
Disconnect()65 void Disconnect() override { wrapped_socket_->Disconnect(); }
IsConnected() const66 bool IsConnected() const override { return wrapped_socket_->IsConnected(); }
IsConnectedAndIdle() const67 bool IsConnectedAndIdle() const override {
68 return wrapped_socket_->IsConnectedAndIdle();
69 }
GetPeerAddress(IPEndPoint * address) const70 int GetPeerAddress(IPEndPoint* address) const override {
71 return wrapped_socket_->GetPeerAddress(address);
72 }
GetLocalAddress(IPEndPoint * address) const73 int GetLocalAddress(IPEndPoint* address) const override {
74 return wrapped_socket_->GetLocalAddress(address);
75 }
NetLog() const76 const NetLogWithSource& NetLog() const override {
77 return wrapped_socket_->NetLog();
78 }
WasEverUsed() const79 bool WasEverUsed() const override { return wrapped_socket_->WasEverUsed(); }
WasAlpnNegotiated() const80 bool WasAlpnNegotiated() const override {
81 return wrapped_socket_->WasAlpnNegotiated();
82 }
GetNegotiatedProtocol() const83 NextProto GetNegotiatedProtocol() const override {
84 return wrapped_socket_->GetNegotiatedProtocol();
85 }
GetSSLInfo(SSLInfo * ssl_info)86 bool GetSSLInfo(SSLInfo* ssl_info) override {
87 return wrapped_socket_->GetSSLInfo(ssl_info);
88 }
GetConnectionAttempts(ConnectionAttempts * out) const89 void GetConnectionAttempts(ConnectionAttempts* out) const override {
90 wrapped_socket_->GetConnectionAttempts(out);
91 }
ClearConnectionAttempts()92 void ClearConnectionAttempts() override {
93 wrapped_socket_->ClearConnectionAttempts();
94 }
AddConnectionAttempts(const ConnectionAttempts & attempts)95 void AddConnectionAttempts(const ConnectionAttempts& attempts) override {
96 wrapped_socket_->AddConnectionAttempts(attempts);
97 }
GetTotalReceivedBytes() const98 int64_t GetTotalReceivedBytes() const override {
99 return wrapped_socket_->GetTotalReceivedBytes();
100 }
DumpMemoryStats(SocketMemoryStats * stats) const101 void DumpMemoryStats(SocketMemoryStats* stats) const override {
102 wrapped_socket_->DumpMemoryStats(stats);
103 }
ApplySocketTag(const SocketTag & tag)104 void ApplySocketTag(const SocketTag& tag) override {
105 wrapped_socket_->ApplySocketTag(tag);
106 }
107
108 private:
109 std::unique_ptr<StreamSocket> wrapped_socket_;
110 WebSocketEndpointLockManager::LockReleaser lock_releaser_;
111
112 DISALLOW_COPY_AND_ASSIGN(WebSocketStreamSocket);
113 };
114
115 } // namespace
116
WebSocketTransportConnectSubJob(const AddressList & addresses,WebSocketTransportConnectJob * parent_job,SubJobType type,WebSocketEndpointLockManager * websocket_endpoint_lock_manager)117 WebSocketTransportConnectSubJob::WebSocketTransportConnectSubJob(
118 const AddressList& addresses,
119 WebSocketTransportConnectJob* parent_job,
120 SubJobType type,
121 WebSocketEndpointLockManager* websocket_endpoint_lock_manager)
122 : parent_job_(parent_job),
123 addresses_(addresses),
124 current_address_index_(0),
125 next_state_(STATE_NONE),
126 type_(type),
127 websocket_endpoint_lock_manager_(websocket_endpoint_lock_manager) {}
128
~WebSocketTransportConnectSubJob()129 WebSocketTransportConnectSubJob::~WebSocketTransportConnectSubJob() {
130 // We don't worry about cancelling the TCP connect, since ~StreamSocket will
131 // take care of it.
132 if (next()) {
133 DCHECK_EQ(STATE_OBTAIN_LOCK_COMPLETE, next_state_);
134 // The ~Waiter destructor will remove this object from the waiting list.
135 } else if (next_state_ == STATE_TRANSPORT_CONNECT_COMPLETE) {
136 websocket_endpoint_lock_manager_->UnlockEndpoint(CurrentAddress());
137 }
138 }
139
140 // Start connecting.
Start()141 int WebSocketTransportConnectSubJob::Start() {
142 DCHECK_EQ(STATE_NONE, next_state_);
143 next_state_ = STATE_OBTAIN_LOCK;
144 return DoLoop(OK);
145 }
146
147 // Called by WebSocketEndpointLockManager when the lock becomes available.
GotEndpointLock()148 void WebSocketTransportConnectSubJob::GotEndpointLock() {
149 DCHECK_EQ(STATE_OBTAIN_LOCK_COMPLETE, next_state_);
150 OnIOComplete(OK);
151 }
152
GetLoadState() const153 LoadState WebSocketTransportConnectSubJob::GetLoadState() const {
154 switch (next_state_) {
155 case STATE_OBTAIN_LOCK:
156 case STATE_OBTAIN_LOCK_COMPLETE:
157 // TODO(ricea): Add a WebSocket-specific LOAD_STATE ?
158 return LOAD_STATE_WAITING_FOR_AVAILABLE_SOCKET;
159 case STATE_TRANSPORT_CONNECT:
160 case STATE_TRANSPORT_CONNECT_COMPLETE:
161 case STATE_DONE:
162 return LOAD_STATE_CONNECTING;
163 case STATE_NONE:
164 return LOAD_STATE_IDLE;
165 }
166 NOTREACHED();
167 return LOAD_STATE_IDLE;
168 }
169
client_socket_factory() const170 ClientSocketFactory* WebSocketTransportConnectSubJob::client_socket_factory()
171 const {
172 return parent_job_->client_socket_factory();
173 }
174
net_log() const175 const NetLogWithSource& WebSocketTransportConnectSubJob::net_log() const {
176 return parent_job_->net_log();
177 }
178
CurrentAddress() const179 const IPEndPoint& WebSocketTransportConnectSubJob::CurrentAddress() const {
180 DCHECK_LT(current_address_index_, addresses_.size());
181 return addresses_[current_address_index_];
182 }
183
OnIOComplete(int result)184 void WebSocketTransportConnectSubJob::OnIOComplete(int result) {
185 int rv = DoLoop(result);
186 if (rv != ERR_IO_PENDING)
187 parent_job_->OnSubJobComplete(rv, this); // |this| deleted
188 }
189
DoLoop(int result)190 int WebSocketTransportConnectSubJob::DoLoop(int result) {
191 DCHECK_NE(next_state_, STATE_NONE);
192
193 int rv = result;
194 do {
195 State state = next_state_;
196 next_state_ = STATE_NONE;
197 switch (state) {
198 case STATE_OBTAIN_LOCK:
199 DCHECK_EQ(OK, rv);
200 rv = DoEndpointLock();
201 break;
202 case STATE_OBTAIN_LOCK_COMPLETE:
203 DCHECK_EQ(OK, rv);
204 rv = DoEndpointLockComplete();
205 break;
206 case STATE_TRANSPORT_CONNECT:
207 DCHECK_EQ(OK, rv);
208 rv = DoTransportConnect();
209 break;
210 case STATE_TRANSPORT_CONNECT_COMPLETE:
211 rv = DoTransportConnectComplete(rv);
212 break;
213 default:
214 NOTREACHED();
215 rv = ERR_FAILED;
216 break;
217 }
218 } while (rv != ERR_IO_PENDING && next_state_ != STATE_NONE &&
219 next_state_ != STATE_DONE);
220
221 return rv;
222 }
223
DoEndpointLock()224 int WebSocketTransportConnectSubJob::DoEndpointLock() {
225 int rv =
226 websocket_endpoint_lock_manager_->LockEndpoint(CurrentAddress(), this);
227 next_state_ = STATE_OBTAIN_LOCK_COMPLETE;
228 return rv;
229 }
230
DoEndpointLockComplete()231 int WebSocketTransportConnectSubJob::DoEndpointLockComplete() {
232 next_state_ = STATE_TRANSPORT_CONNECT;
233 return OK;
234 }
235
DoTransportConnect()236 int WebSocketTransportConnectSubJob::DoTransportConnect() {
237 // TODO(ricea): Update global g_last_connect_time and report
238 // ConnectInterval.
239 next_state_ = STATE_TRANSPORT_CONNECT_COMPLETE;
240 AddressList one_address(CurrentAddress());
241 // TODO(https://crbug.com/1123197): Pass a non-null NetworkQualityEstimator.
242 NetworkQualityEstimator* network_quality_estimator = nullptr;
243
244 transport_socket_ = client_socket_factory()->CreateTransportClientSocket(
245 one_address, nullptr, network_quality_estimator, net_log().net_log(),
246 net_log().source());
247 // This use of base::Unretained() is safe because transport_socket_ is
248 // destroyed in the destructor.
249 return transport_socket_->Connect(base::BindOnce(
250 &WebSocketTransportConnectSubJob::OnIOComplete, base::Unretained(this)));
251 }
252
DoTransportConnectComplete(int result)253 int WebSocketTransportConnectSubJob::DoTransportConnectComplete(int result) {
254 next_state_ = STATE_DONE;
255 if (result != OK) {
256 websocket_endpoint_lock_manager_->UnlockEndpoint(CurrentAddress());
257
258 if (current_address_index_ + 1 < addresses_.size()) {
259 // Try falling back to the next address in the list.
260 next_state_ = STATE_OBTAIN_LOCK;
261 ++current_address_index_;
262 result = OK;
263 }
264
265 return result;
266 }
267
268 // On success, need to register the socket with the
269 // WebSocketEndpointLockManager.
270 transport_socket_ = std::make_unique<WebSocketStreamSocket>(
271 std::move(transport_socket_), websocket_endpoint_lock_manager_,
272 CurrentAddress());
273
274 return result;
275 }
276
277 } // namespace net
278