1 // Copyright 2014 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "net/socket/websocket_transport_client_socket_pool.h"
6
7 #include <algorithm>
8
9 #include "base/bind.h"
10 #include "base/callback_helpers.h"
11 #include "base/check_op.h"
12 #include "base/compiler_specific.h"
13 #include "base/location.h"
14 #include "base/notreached.h"
15 #include "base/numerics/safe_conversions.h"
16 #include "base/strings/string_util.h"
17 #include "base/threading/thread_task_runner_handle.h"
18 #include "base/values.h"
19 #include "net/base/net_errors.h"
20 #include "net/log/net_log_event_type.h"
21 #include "net/log/net_log_source.h"
22 #include "net/log/net_log_source_type.h"
23 #include "net/socket/client_socket_handle.h"
24 #include "net/socket/connect_job.h"
25 #include "net/socket/websocket_endpoint_lock_manager.h"
26 #include "net/socket/websocket_transport_connect_job.h"
27 #include "net/traffic_annotation/network_traffic_annotation.h"
28
29 namespace net {
30
WebSocketTransportClientSocketPool(int max_sockets,int max_sockets_per_group,const ProxyServer & proxy_server,const CommonConnectJobParams * common_connect_job_params)31 WebSocketTransportClientSocketPool::WebSocketTransportClientSocketPool(
32 int max_sockets,
33 int max_sockets_per_group,
34 const ProxyServer& proxy_server,
35 const CommonConnectJobParams* common_connect_job_params)
36 : proxy_server_(proxy_server),
37 common_connect_job_params_(common_connect_job_params),
38 max_sockets_(max_sockets),
39 handed_out_socket_count_(0),
40 flushing_(false) {
41 DCHECK(common_connect_job_params_->websocket_endpoint_lock_manager);
42 }
43
~WebSocketTransportClientSocketPool()44 WebSocketTransportClientSocketPool::~WebSocketTransportClientSocketPool() {
45 // Clean up any pending connect jobs.
46 FlushWithError(ERR_ABORTED, "");
47 DCHECK(pending_connects_.empty());
48 DCHECK_EQ(0, handed_out_socket_count_);
49 DCHECK(stalled_request_queue_.empty());
50 DCHECK(stalled_request_map_.empty());
51 }
52
53 // static
UnlockEndpoint(ClientSocketHandle * handle,WebSocketEndpointLockManager * websocket_endpoint_lock_manager)54 void WebSocketTransportClientSocketPool::UnlockEndpoint(
55 ClientSocketHandle* handle,
56 WebSocketEndpointLockManager* websocket_endpoint_lock_manager) {
57 DCHECK(handle->is_initialized());
58 DCHECK(handle->socket());
59 IPEndPoint address;
60 if (handle->socket()->GetPeerAddress(&address) == OK)
61 websocket_endpoint_lock_manager->UnlockEndpoint(address);
62 }
63
RequestSocket(const GroupId & group_id,scoped_refptr<SocketParams> params,const base::Optional<NetworkTrafficAnnotationTag> & proxy_annotation_tag,RequestPriority priority,const SocketTag & socket_tag,RespectLimits respect_limits,ClientSocketHandle * handle,CompletionOnceCallback callback,const ProxyAuthCallback & proxy_auth_callback,const NetLogWithSource & request_net_log)64 int WebSocketTransportClientSocketPool::RequestSocket(
65 const GroupId& group_id,
66 scoped_refptr<SocketParams> params,
67 const base::Optional<NetworkTrafficAnnotationTag>& proxy_annotation_tag,
68 RequestPriority priority,
69 const SocketTag& socket_tag,
70 RespectLimits respect_limits,
71 ClientSocketHandle* handle,
72 CompletionOnceCallback callback,
73 const ProxyAuthCallback& proxy_auth_callback,
74 const NetLogWithSource& request_net_log) {
75 DCHECK(params);
76 CHECK(!callback.is_null());
77 CHECK(handle);
78 DCHECK(socket_tag == SocketTag());
79
80 NetLogTcpClientSocketPoolRequestedSocket(request_net_log, group_id);
81 request_net_log.BeginEvent(NetLogEventType::SOCKET_POOL);
82
83 if (ReachedMaxSocketsLimit() &&
84 respect_limits == ClientSocketPool::RespectLimits::ENABLED) {
85 request_net_log.AddEvent(NetLogEventType::SOCKET_POOL_STALLED_MAX_SOCKETS);
86 stalled_request_queue_.emplace_back(group_id, params, proxy_annotation_tag,
87 priority, handle, std::move(callback),
88 proxy_auth_callback, request_net_log);
89 auto iterator = stalled_request_queue_.end();
90 --iterator;
91 DCHECK_EQ(handle, iterator->handle);
92 // Because StalledRequestQueue is a std::list, its iterators are guaranteed
93 // to remain valid as long as the elements are not removed. As long as
94 // stalled_request_queue_ and stalled_request_map_ are updated in sync, it
95 // is safe to dereference an iterator in stalled_request_map_ to find the
96 // corresponding list element.
97 stalled_request_map_.insert(
98 StalledRequestMap::value_type(handle, iterator));
99 return ERR_IO_PENDING;
100 }
101
102 std::unique_ptr<ConnectJobDelegate> connect_job_delegate =
103 std::make_unique<ConnectJobDelegate>(this, std::move(callback), handle,
104 request_net_log);
105
106 std::unique_ptr<ConnectJob> connect_job =
107 CreateConnectJob(group_id, params, proxy_server_, proxy_annotation_tag,
108 true /* is_for_websockets */, common_connect_job_params_,
109 priority, SocketTag(), connect_job_delegate.get());
110
111 int result = connect_job_delegate->Connect(std::move(connect_job));
112
113 // Regardless of the outcome of |connect_job|, it will always be bound to
114 // |handle|, since this pool uses early-binding. So the binding is logged
115 // here, without waiting for the result.
116 request_net_log.AddEventReferencingSource(
117 NetLogEventType::SOCKET_POOL_BOUND_TO_CONNECT_JOB,
118 connect_job_delegate->connect_job_net_log().source());
119
120 if (result == ERR_IO_PENDING) {
121 // TODO(ricea): Implement backup job timer?
122 AddJob(handle, std::move(connect_job_delegate));
123 } else {
124 TryHandOutSocket(result, connect_job_delegate.get());
125 }
126
127 return result;
128 }
129
RequestSockets(const GroupId & group_id,scoped_refptr<SocketParams> params,const base::Optional<NetworkTrafficAnnotationTag> & proxy_annotation_tag,int num_sockets,const NetLogWithSource & net_log)130 void WebSocketTransportClientSocketPool::RequestSockets(
131 const GroupId& group_id,
132 scoped_refptr<SocketParams> params,
133 const base::Optional<NetworkTrafficAnnotationTag>& proxy_annotation_tag,
134 int num_sockets,
135 const NetLogWithSource& net_log) {
136 NOTIMPLEMENTED();
137 }
138
SetPriority(const GroupId & group_id,ClientSocketHandle * handle,RequestPriority priority)139 void WebSocketTransportClientSocketPool::SetPriority(const GroupId& group_id,
140 ClientSocketHandle* handle,
141 RequestPriority priority) {
142 // Since sockets requested by RequestSocket are bound early and
143 // stalled_request_{queue,map} don't take priorities into account, there's
144 // nothing to do within the pool to change priority of the request.
145 // TODO(rdsmith, ricea): Make stalled_request_{queue,map} take priorities
146 // into account.
147 // TODO(rdsmith, chlily): Investigate plumbing the reprioritization request to
148 // the connect job.
149 }
150
CancelRequest(const GroupId & group_id,ClientSocketHandle * handle,bool cancel_connect_job)151 void WebSocketTransportClientSocketPool::CancelRequest(
152 const GroupId& group_id,
153 ClientSocketHandle* handle,
154 bool cancel_connect_job) {
155 DCHECK(!handle->is_initialized());
156 if (DeleteStalledRequest(handle))
157 return;
158 std::unique_ptr<StreamSocket> socket = handle->PassSocket();
159 if (socket)
160 ReleaseSocket(handle->group_id(), std::move(socket),
161 handle->group_generation());
162 if (!DeleteJob(handle))
163 pending_callbacks_.erase(handle);
164
165 ActivateStalledRequest();
166 }
167
ReleaseSocket(const GroupId & group_id,std::unique_ptr<StreamSocket> socket,int64_t generation)168 void WebSocketTransportClientSocketPool::ReleaseSocket(
169 const GroupId& group_id,
170 std::unique_ptr<StreamSocket> socket,
171 int64_t generation) {
172 CHECK_GT(handed_out_socket_count_, 0);
173 --handed_out_socket_count_;
174
175 ActivateStalledRequest();
176 }
177
FlushWithError(int error,const char * net_log_reason_utf8)178 void WebSocketTransportClientSocketPool::FlushWithError(
179 int error,
180 const char* net_log_reason_utf8) {
181 DCHECK_NE(error, OK);
182
183 // Sockets which are in LOAD_STATE_CONNECTING are in danger of unlocking
184 // sockets waiting for the endpoint lock. If they connected synchronously,
185 // then OnConnectJobComplete(). The |flushing_| flag tells this object to
186 // ignore spurious calls to OnConnectJobComplete(). It is safe to ignore those
187 // calls because this method will delete the jobs and call their callbacks
188 // anyway.
189 flushing_ = true;
190 for (auto it = pending_connects_.begin(); it != pending_connects_.end();) {
191 InvokeUserCallbackLater(it->second->socket_handle(),
192 it->second->release_callback(), error);
193 it->second->connect_job_net_log().AddEventWithStringParams(
194 NetLogEventType::SOCKET_POOL_CLOSING_SOCKET, "reason",
195 net_log_reason_utf8);
196 it = pending_connects_.erase(it);
197 }
198 for (auto it = stalled_request_queue_.begin();
199 it != stalled_request_queue_.end(); ++it) {
200 InvokeUserCallbackLater(it->handle, std::move(it->callback), error);
201 }
202 stalled_request_map_.clear();
203 stalled_request_queue_.clear();
204 flushing_ = false;
205 }
206
CloseIdleSockets(const char * net_log_reason_utf8)207 void WebSocketTransportClientSocketPool::CloseIdleSockets(
208 const char* net_log_reason_utf8) {
209 // We have no idle sockets.
210 }
211
CloseIdleSocketsInGroup(const GroupId & group_id,const char * net_log_reason_utf8)212 void WebSocketTransportClientSocketPool::CloseIdleSocketsInGroup(
213 const GroupId& group_id,
214 const char* net_log_reason_utf8) {
215 // We have no idle sockets.
216 }
217
IdleSocketCount() const218 int WebSocketTransportClientSocketPool::IdleSocketCount() const {
219 return 0;
220 }
221
IdleSocketCountInGroup(const GroupId & group_id) const222 size_t WebSocketTransportClientSocketPool::IdleSocketCountInGroup(
223 const GroupId& group_id) const {
224 return 0;
225 }
226
GetLoadState(const GroupId & group_id,const ClientSocketHandle * handle) const227 LoadState WebSocketTransportClientSocketPool::GetLoadState(
228 const GroupId& group_id,
229 const ClientSocketHandle* handle) const {
230 if (stalled_request_map_.find(handle) != stalled_request_map_.end())
231 return LOAD_STATE_WAITING_FOR_AVAILABLE_SOCKET;
232 if (pending_callbacks_.count(handle))
233 return LOAD_STATE_CONNECTING;
234 return LookupConnectJob(handle)->GetLoadState();
235 }
236
GetInfoAsValue(const std::string & name,const std::string & type) const237 base::Value WebSocketTransportClientSocketPool::GetInfoAsValue(
238 const std::string& name,
239 const std::string& type) const {
240 base::Value dict(base::Value::Type::DICTIONARY);
241 dict.SetStringKey("name", name);
242 dict.SetStringKey("type", type);
243 dict.SetIntKey("handed_out_socket_count", handed_out_socket_count_);
244 dict.SetIntKey("connecting_socket_count", pending_connects_.size());
245 dict.SetIntKey("idle_socket_count", 0);
246 dict.SetIntKey("max_socket_count", max_sockets_);
247 dict.SetIntKey("max_sockets_per_group", max_sockets_);
248 return dict;
249 }
250
DumpMemoryStats(base::trace_event::ProcessMemoryDump * pmd,const std::string & parent_dump_absolute_name) const251 void WebSocketTransportClientSocketPool::DumpMemoryStats(
252 base::trace_event::ProcessMemoryDump* pmd,
253 const std::string& parent_dump_absolute_name) const {
254 // Not supported.
255 }
256
IsStalled() const257 bool WebSocketTransportClientSocketPool::IsStalled() const {
258 return !stalled_request_queue_.empty();
259 }
260
AddHigherLayeredPool(HigherLayeredPool * higher_pool)261 void WebSocketTransportClientSocketPool::AddHigherLayeredPool(
262 HigherLayeredPool* higher_pool) {
263 // This class doesn't use connection limits like the pools for HTTP do, so no
264 // need to track higher layered pools.
265 }
266
RemoveHigherLayeredPool(HigherLayeredPool * higher_pool)267 void WebSocketTransportClientSocketPool::RemoveHigherLayeredPool(
268 HigherLayeredPool* higher_pool) {
269 // This class doesn't use connection limits like the pools for HTTP do, so no
270 // need to track higher layered pools.
271 }
272
TryHandOutSocket(int result,ConnectJobDelegate * connect_job_delegate)273 bool WebSocketTransportClientSocketPool::TryHandOutSocket(
274 int result,
275 ConnectJobDelegate* connect_job_delegate) {
276 DCHECK_NE(result, ERR_IO_PENDING);
277
278 std::unique_ptr<StreamSocket> socket =
279 connect_job_delegate->connect_job()->PassSocket();
280 LoadTimingInfo::ConnectTiming connect_timing =
281 connect_job_delegate->connect_job()->connect_timing();
282 ClientSocketHandle* const handle = connect_job_delegate->socket_handle();
283 NetLogWithSource request_net_log = connect_job_delegate->request_net_log();
284
285 if (result == OK) {
286 DCHECK(socket);
287
288 HandOutSocket(std::move(socket), connect_timing, handle, request_net_log);
289
290 request_net_log.EndEvent(NetLogEventType::SOCKET_POOL);
291
292 return true;
293 }
294
295 bool handed_out_socket = false;
296
297 // If we got a socket, it must contain error information so pass that
298 // up so that the caller can retrieve it.
299 handle->SetAdditionalErrorState(connect_job_delegate->connect_job());
300 if (socket) {
301 HandOutSocket(std::move(socket), connect_timing, handle, request_net_log);
302 handed_out_socket = true;
303 }
304
305 request_net_log.EndEventWithNetErrorCode(NetLogEventType::SOCKET_POOL,
306 result);
307
308 return handed_out_socket;
309 }
310
OnConnectJobComplete(int result,ConnectJobDelegate * connect_job_delegate)311 void WebSocketTransportClientSocketPool::OnConnectJobComplete(
312 int result,
313 ConnectJobDelegate* connect_job_delegate) {
314 DCHECK_NE(ERR_IO_PENDING, result);
315
316 // See comment in FlushWithError.
317 if (flushing_) {
318 // Just delete the socket.
319 std::unique_ptr<StreamSocket> socket =
320 connect_job_delegate->connect_job()->PassSocket();
321 return;
322 }
323
324 bool handed_out_socket = TryHandOutSocket(result, connect_job_delegate);
325
326 CompletionOnceCallback callback = connect_job_delegate->release_callback();
327
328 ClientSocketHandle* const handle = connect_job_delegate->socket_handle();
329
330 bool delete_succeeded = DeleteJob(handle);
331 DCHECK(delete_succeeded);
332
333 connect_job_delegate = nullptr;
334
335 if (!handed_out_socket)
336 ActivateStalledRequest();
337
338 InvokeUserCallbackLater(handle, std::move(callback), result);
339 }
340
InvokeUserCallbackLater(ClientSocketHandle * handle,CompletionOnceCallback callback,int rv)341 void WebSocketTransportClientSocketPool::InvokeUserCallbackLater(
342 ClientSocketHandle* handle,
343 CompletionOnceCallback callback,
344 int rv) {
345 DCHECK(!pending_callbacks_.count(handle));
346 pending_callbacks_.insert(handle);
347 base::ThreadTaskRunnerHandle::Get()->PostTask(
348 FROM_HERE,
349 base::BindOnce(&WebSocketTransportClientSocketPool::InvokeUserCallback,
350 weak_factory_.GetWeakPtr(), handle, std::move(callback),
351 rv));
352 }
353
InvokeUserCallback(ClientSocketHandle * handle,CompletionOnceCallback callback,int rv)354 void WebSocketTransportClientSocketPool::InvokeUserCallback(
355 ClientSocketHandle* handle,
356 CompletionOnceCallback callback,
357 int rv) {
358 if (pending_callbacks_.erase(handle))
359 std::move(callback).Run(rv);
360 }
361
ReachedMaxSocketsLimit() const362 bool WebSocketTransportClientSocketPool::ReachedMaxSocketsLimit() const {
363 return handed_out_socket_count_ >= max_sockets_ ||
364 base::checked_cast<int>(pending_connects_.size()) >=
365 max_sockets_ - handed_out_socket_count_;
366 }
367
HandOutSocket(std::unique_ptr<StreamSocket> socket,const LoadTimingInfo::ConnectTiming & connect_timing,ClientSocketHandle * handle,const NetLogWithSource & net_log)368 void WebSocketTransportClientSocketPool::HandOutSocket(
369 std::unique_ptr<StreamSocket> socket,
370 const LoadTimingInfo::ConnectTiming& connect_timing,
371 ClientSocketHandle* handle,
372 const NetLogWithSource& net_log) {
373 DCHECK(socket);
374 DCHECK_EQ(ClientSocketHandle::UNUSED, handle->reuse_type());
375 DCHECK_EQ(0, handle->idle_time().InMicroseconds());
376
377 handle->SetSocket(std::move(socket));
378 handle->set_group_generation(0);
379 handle->set_connect_timing(connect_timing);
380
381 net_log.AddEventReferencingSource(
382 NetLogEventType::SOCKET_POOL_BOUND_TO_SOCKET,
383 handle->socket()->NetLog().source());
384
385 ++handed_out_socket_count_;
386 }
387
AddJob(ClientSocketHandle * handle,std::unique_ptr<ConnectJobDelegate> delegate)388 void WebSocketTransportClientSocketPool::AddJob(
389 ClientSocketHandle* handle,
390 std::unique_ptr<ConnectJobDelegate> delegate) {
391 bool inserted =
392 pending_connects_
393 .insert(PendingConnectsMap::value_type(handle, std::move(delegate)))
394 .second;
395 DCHECK(inserted);
396 }
397
DeleteJob(ClientSocketHandle * handle)398 bool WebSocketTransportClientSocketPool::DeleteJob(ClientSocketHandle* handle) {
399 auto it = pending_connects_.find(handle);
400 if (it == pending_connects_.end())
401 return false;
402 // Deleting a ConnectJob which holds an endpoint lock can lead to a different
403 // ConnectJob proceeding to connect. If the connect proceeds synchronously
404 // (usually because of a failure) then it can trigger that job to be
405 // deleted.
406 pending_connects_.erase(it);
407 return true;
408 }
409
LookupConnectJob(const ClientSocketHandle * handle) const410 const ConnectJob* WebSocketTransportClientSocketPool::LookupConnectJob(
411 const ClientSocketHandle* handle) const {
412 auto it = pending_connects_.find(handle);
413 CHECK(it != pending_connects_.end());
414 return it->second->connect_job();
415 }
416
ActivateStalledRequest()417 void WebSocketTransportClientSocketPool::ActivateStalledRequest() {
418 // Usually we will only be able to activate one stalled request at a time,
419 // however if all the connects fail synchronously for some reason, we may be
420 // able to clear the whole queue at once.
421 while (!stalled_request_queue_.empty() && !ReachedMaxSocketsLimit()) {
422 StalledRequest request = std::move(stalled_request_queue_.front());
423 stalled_request_queue_.pop_front();
424 stalled_request_map_.erase(request.handle);
425
426 // Wrap request.callback into a copyable (repeating) callback so that it can
427 // be passed to RequestSocket() and yet called if RequestSocket() returns
428 // synchronously.
429 auto copyable_callback =
430 base::AdaptCallbackForRepeating(std::move(request.callback));
431
432 int rv = RequestSocket(
433 request.group_id, request.params, request.proxy_annotation_tag,
434 request.priority, SocketTag(),
435 // Stalled requests can't have |respect_limits|
436 // DISABLED.
437 RespectLimits::ENABLED, request.handle, copyable_callback,
438 request.proxy_auth_callback, request.net_log);
439
440 // ActivateStalledRequest() never returns synchronously, so it is never
441 // called re-entrantly.
442 if (rv != ERR_IO_PENDING)
443 InvokeUserCallbackLater(request.handle, copyable_callback, rv);
444 }
445 }
446
DeleteStalledRequest(ClientSocketHandle * handle)447 bool WebSocketTransportClientSocketPool::DeleteStalledRequest(
448 ClientSocketHandle* handle) {
449 auto it = stalled_request_map_.find(handle);
450 if (it == stalled_request_map_.end())
451 return false;
452 stalled_request_queue_.erase(it->second);
453 stalled_request_map_.erase(it);
454 return true;
455 }
456
ConnectJobDelegate(WebSocketTransportClientSocketPool * owner,CompletionOnceCallback callback,ClientSocketHandle * socket_handle,const NetLogWithSource & request_net_log)457 WebSocketTransportClientSocketPool::ConnectJobDelegate::ConnectJobDelegate(
458 WebSocketTransportClientSocketPool* owner,
459 CompletionOnceCallback callback,
460 ClientSocketHandle* socket_handle,
461 const NetLogWithSource& request_net_log)
462 : owner_(owner),
463 callback_(std::move(callback)),
464 socket_handle_(socket_handle),
465 request_net_log_(request_net_log) {}
466
// Defaulted out of line: members are destroyed by the compiler-generated
// destructor.
WebSocketTransportClientSocketPool::ConnectJobDelegate::~ConnectJobDelegate() =
    default;
469
470 void
OnConnectJobComplete(int result,ConnectJob * job)471 WebSocketTransportClientSocketPool::ConnectJobDelegate::OnConnectJobComplete(
472 int result,
473 ConnectJob* job) {
474 DCHECK_EQ(job, connect_job_.get());
475 owner_->OnConnectJobComplete(result, this);
476 }
477
OnNeedsProxyAuth(const HttpResponseInfo & response,HttpAuthController * auth_controller,base::OnceClosure restart_with_auth_callback,ConnectJob * job)478 void WebSocketTransportClientSocketPool::ConnectJobDelegate::OnNeedsProxyAuth(
479 const HttpResponseInfo& response,
480 HttpAuthController* auth_controller,
481 base::OnceClosure restart_with_auth_callback,
482 ConnectJob* job) {
483 // This class isn't used for proxies.
484 NOTREACHED();
485 }
486
Connect(std::unique_ptr<ConnectJob> connect_job)487 int WebSocketTransportClientSocketPool::ConnectJobDelegate::Connect(
488 std::unique_ptr<ConnectJob> connect_job) {
489 connect_job_ = std::move(connect_job);
490 return connect_job_->Connect();
491 }
492
493 const NetLogWithSource&
connect_job_net_log()494 WebSocketTransportClientSocketPool::ConnectJobDelegate::connect_job_net_log() {
495 return connect_job_->net_log();
496 }
497
StalledRequest(const GroupId & group_id,const scoped_refptr<SocketParams> & params,const base::Optional<NetworkTrafficAnnotationTag> & proxy_annotation_tag,RequestPriority priority,ClientSocketHandle * handle,CompletionOnceCallback callback,const ProxyAuthCallback & proxy_auth_callback,const NetLogWithSource & net_log)498 WebSocketTransportClientSocketPool::StalledRequest::StalledRequest(
499 const GroupId& group_id,
500 const scoped_refptr<SocketParams>& params,
501 const base::Optional<NetworkTrafficAnnotationTag>& proxy_annotation_tag,
502 RequestPriority priority,
503 ClientSocketHandle* handle,
504 CompletionOnceCallback callback,
505 const ProxyAuthCallback& proxy_auth_callback,
506 const NetLogWithSource& net_log)
507 : group_id(group_id),
508 params(params),
509 proxy_annotation_tag(proxy_annotation_tag),
510 priority(priority),
511 handle(handle),
512 callback(std::move(callback)),
513 proxy_auth_callback(proxy_auth_callback),
514 net_log(net_log) {}
515
// Defaulted, member-wise move constructor. ActivateStalledRequest() uses it
// to move a request out of |stalled_request_queue_|.
WebSocketTransportClientSocketPool::StalledRequest::StalledRequest(
    StalledRequest&& other) = default;
518
// Defaulted out of line: members are destroyed by the compiler-generated
// destructor.
WebSocketTransportClientSocketPool::StalledRequest::~StalledRequest() = default;
520
521 } // namespace net
522