1 // Copyright 2018 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "services/network/tcp_connected_socket.h"
6
7 #include <utility>
8
9 #include "base/bind.h"
10 #include "base/logging.h"
11 #include "base/numerics/ranges.h"
12 #include "base/numerics/safe_conversions.h"
13 #include "base/optional.h"
14 #include "net/base/net_errors.h"
15 #include "net/log/net_log.h"
16 #include "net/socket/client_socket_factory.h"
17 #include "net/socket/client_socket_handle.h"
18 #include "services/network/tls_client_socket.h"
19
20 namespace network {
21
22 namespace {
23
ClampTCPBufferSize(int requested_buffer_size)24 int ClampTCPBufferSize(int requested_buffer_size) {
25 return base::ClampToRange(requested_buffer_size, 0,
26 TCPConnectedSocket::kMaxBufferSize);
27 }
28
29 // Sets the initial options on a fresh socket. Assumes |socket| is currently
30 // configured using the default client socket options
31 // (TCPSocket::SetDefaultOptionsForClient()).
ConfigureSocket(net::TransportClientSocket * socket,const mojom::TCPConnectedSocketOptions & tcp_connected_socket_options)32 int ConfigureSocket(
33 net::TransportClientSocket* socket,
34 const mojom::TCPConnectedSocketOptions& tcp_connected_socket_options) {
35 int send_buffer_size =
36 ClampTCPBufferSize(tcp_connected_socket_options.send_buffer_size);
37 if (send_buffer_size > 0) {
38 int result = socket->SetSendBufferSize(send_buffer_size);
39 DCHECK_NE(net::ERR_IO_PENDING, result);
40 if (result != net::OK)
41 return result;
42 }
43
44 int receive_buffer_size =
45 ClampTCPBufferSize(tcp_connected_socket_options.receive_buffer_size);
46 if (receive_buffer_size > 0) {
47 int result = socket->SetReceiveBufferSize(receive_buffer_size);
48 DCHECK_NE(net::ERR_IO_PENDING, result);
49 if (result != net::OK)
50 return result;
51 }
52
53 // No delay is set by default, so only update the setting if it's false.
54 if (!tcp_connected_socket_options.no_delay) {
55 // Unlike the above calls, TcpSocket::SetNoDelay() returns a bool rather
56 // than a network error code.
57 if (!socket->SetNoDelay(false))
58 return net::ERR_FAILED;
59 }
60
61 return net::OK;
62 }
63
64 } // namespace
65
// Upper limit (128 * 1024) that ClampTCPBufferSize() applies to requested
// send/receive buffer sizes.
const int TCPConnectedSocket::kMaxBufferSize = 128 * 1024;
67
// Creates an unconnected socket wrapper. The transport socket itself is created
// later by Connect() through |client_socket_factory|, and |tls_socket_factory|
// is kept so the connection can later be upgraded via UpgradeToTLS().
// NOTE(review): |net_log|, |tls_socket_factory| and |client_socket_factory| are
// stored as raw pointers; presumably they must outlive |this| — confirm with
// the owner.
TCPConnectedSocket::TCPConnectedSocket(
    mojo::PendingRemote<mojom::SocketObserver> observer,
    net::NetLog* net_log,
    TLSSocketFactory* tls_socket_factory,
    net::ClientSocketFactory* client_socket_factory,
    const net::NetworkTrafficAnnotationTag& traffic_annotation)
    : observer_(std::move(observer)),
      net_log_(net_log),
      client_socket_factory_(client_socket_factory),
      tls_socket_factory_(tls_socket_factory),
      traffic_annotation_(traffic_annotation) {}
79
// Wraps an already-connected |socket| and immediately starts pumping data
// between it and the supplied data pipes. No factories are kept, so
// Connect() cannot be used on this instance and UpgradeToTLS() reports
// net::ERR_NOT_IMPLEMENTED.
TCPConnectedSocket::TCPConnectedSocket(
    mojo::PendingRemote<mojom::SocketObserver> observer,
    std::unique_ptr<net::TransportClientSocket> socket,
    mojo::ScopedDataPipeProducerHandle receive_pipe_handle,
    mojo::ScopedDataPipeConsumerHandle send_pipe_handle,
    const net::NetworkTrafficAnnotationTag& traffic_annotation)
    : observer_(std::move(observer)),
      net_log_(nullptr),
      client_socket_factory_(nullptr),
      tls_socket_factory_(nullptr),
      socket_(std::move(socket)),
      traffic_annotation_(traffic_annotation) {
  // |this| is the pump's delegate: read/write errors and shutdown are reported
  // back through OnNetworkReadError(), OnNetworkWriteError() and OnShutdown().
  socket_data_pump_ = std::make_unique<SocketDataPump>(
      socket_.get(), this /*delegate*/, std::move(receive_pipe_handle),
      std::move(send_pipe_handle), traffic_annotation);
}
96
~TCPConnectedSocket()97 TCPConnectedSocket::~TCPConnectedSocket() {
98 if (connect_callback_) {
99 // If |this| is destroyed when connect hasn't completed, tell the consumer
100 // that request has been aborted.
101 std::move(connect_callback_)
102 .Run(net::ERR_ABORTED, base::nullopt, base::nullopt,
103 mojo::ScopedDataPipeConsumerHandle(),
104 mojo::ScopedDataPipeProducerHandle());
105 }
106 }
107
Connect(const base::Optional<net::IPEndPoint> & local_addr,const net::AddressList & remote_addr_list,mojom::TCPConnectedSocketOptionsPtr tcp_connected_socket_options,mojom::NetworkContext::CreateTCPConnectedSocketCallback callback)108 void TCPConnectedSocket::Connect(
109 const base::Optional<net::IPEndPoint>& local_addr,
110 const net::AddressList& remote_addr_list,
111 mojom::TCPConnectedSocketOptionsPtr tcp_connected_socket_options,
112 mojom::NetworkContext::CreateTCPConnectedSocketCallback callback) {
113 DCHECK(!socket_);
114 DCHECK(callback);
115
116 std::unique_ptr<net::TransportClientSocket> socket =
117 client_socket_factory_->CreateTransportClientSocket(
118 remote_addr_list, nullptr /*socket_performance_watcher*/, net_log_,
119 net::NetLogSource());
120
121 if (local_addr) {
122 int result = socket->Bind(local_addr.value());
123 if (result != net::OK) {
124 OnConnectCompleted(result);
125 return;
126 }
127 }
128
129 return ConnectWithSocket(std::move(socket),
130 std::move(tcp_connected_socket_options),
131 std::move(callback));
132 }
133
ConnectWithSocket(std::unique_ptr<net::TransportClientSocket> socket,mojom::TCPConnectedSocketOptionsPtr tcp_connected_socket_options,mojom::NetworkContext::CreateTCPConnectedSocketCallback callback)134 void TCPConnectedSocket::ConnectWithSocket(
135 std::unique_ptr<net::TransportClientSocket> socket,
136 mojom::TCPConnectedSocketOptionsPtr tcp_connected_socket_options,
137 mojom::NetworkContext::CreateTCPConnectedSocketCallback callback) {
138 socket_ = std::move(socket);
139 connect_callback_ = std::move(callback);
140
141 if (tcp_connected_socket_options) {
142 socket_->SetBeforeConnectCallback(base::BindRepeating(
143 &ConfigureSocket, socket_.get(), *tcp_connected_socket_options));
144 }
145 int result = socket_->Connect(base::BindRepeating(
146 &TCPConnectedSocket::OnConnectCompleted, base::Unretained(this)));
147
148 if (result == net::ERR_IO_PENDING)
149 return;
150
151 OnConnectCompleted(result);
152 }
153
UpgradeToTLS(const net::HostPortPair & host_port_pair,mojom::TLSClientSocketOptionsPtr socket_options,const net::MutableNetworkTrafficAnnotationTag & traffic_annotation,mojo::PendingReceiver<mojom::TLSClientSocket> receiver,mojo::PendingRemote<mojom::SocketObserver> observer,mojom::TCPConnectedSocket::UpgradeToTLSCallback callback)154 void TCPConnectedSocket::UpgradeToTLS(
155 const net::HostPortPair& host_port_pair,
156 mojom::TLSClientSocketOptionsPtr socket_options,
157 const net::MutableNetworkTrafficAnnotationTag& traffic_annotation,
158 mojo::PendingReceiver<mojom::TLSClientSocket> receiver,
159 mojo::PendingRemote<mojom::SocketObserver> observer,
160 mojom::TCPConnectedSocket::UpgradeToTLSCallback callback) {
161 if (!tls_socket_factory_) {
162 std::move(callback).Run(
163 net::ERR_NOT_IMPLEMENTED, mojo::ScopedDataPipeConsumerHandle(),
164 mojo::ScopedDataPipeProducerHandle(), base::nullopt /* ssl_info*/);
165 return;
166 }
167 // Wait for data pipes to be closed by the client before doing the upgrade.
168 if (socket_data_pump_) {
169 pending_upgrade_to_tls_callback_ = base::BindOnce(
170 &TCPConnectedSocket::UpgradeToTLS, base::Unretained(this),
171 host_port_pair, std::move(socket_options), traffic_annotation,
172 std::move(receiver), std::move(observer), std::move(callback));
173 return;
174 }
175 tls_socket_factory_->UpgradeToTLS(
176 this, host_port_pair, std::move(socket_options), traffic_annotation,
177 std::move(receiver), std::move(observer), std::move(callback));
178 }
179
SetSendBufferSize(int send_buffer_size,SetSendBufferSizeCallback callback)180 void TCPConnectedSocket::SetSendBufferSize(int send_buffer_size,
181 SetSendBufferSizeCallback callback) {
182 if (!socket_) {
183 // Fail is this method was called after upgrading to TLS.
184 std::move(callback).Run(net::ERR_UNEXPECTED);
185 return;
186 }
187 int result = socket_->SetSendBufferSize(ClampTCPBufferSize(send_buffer_size));
188 std::move(callback).Run(result);
189 }
190
SetReceiveBufferSize(int send_buffer_size,SetSendBufferSizeCallback callback)191 void TCPConnectedSocket::SetReceiveBufferSize(
192 int send_buffer_size,
193 SetSendBufferSizeCallback callback) {
194 if (!socket_) {
195 // Fail is this method was called after upgrading to TLS.
196 std::move(callback).Run(net::ERR_UNEXPECTED);
197 return;
198 }
199 int result =
200 socket_->SetReceiveBufferSize(ClampTCPBufferSize(send_buffer_size));
201 std::move(callback).Run(result);
202 }
203
SetNoDelay(bool no_delay,SetNoDelayCallback callback)204 void TCPConnectedSocket::SetNoDelay(bool no_delay,
205 SetNoDelayCallback callback) {
206 if (!socket_) {
207 std::move(callback).Run(false);
208 return;
209 }
210 bool success = socket_->SetNoDelay(no_delay);
211 std::move(callback).Run(success);
212 }
213
SetKeepAlive(bool enable,int32_t delay_secs,SetKeepAliveCallback callback)214 void TCPConnectedSocket::SetKeepAlive(bool enable,
215 int32_t delay_secs,
216 SetKeepAliveCallback callback) {
217 if (!socket_) {
218 std::move(callback).Run(false);
219 return;
220 }
221 bool success = socket_->SetKeepAlive(enable, delay_secs);
222 std::move(callback).Run(success);
223 }
224
OnConnectCompleted(int result)225 void TCPConnectedSocket::OnConnectCompleted(int result) {
226 DCHECK(!connect_callback_.is_null());
227 DCHECK(!socket_data_pump_);
228
229 net::IPEndPoint peer_addr, local_addr;
230 if (result == net::OK)
231 result = socket_->GetLocalAddress(&local_addr);
232 if (result == net::OK)
233 result = socket_->GetPeerAddress(&peer_addr);
234
235 if (result != net::OK) {
236 std::move(connect_callback_)
237 .Run(result, base::nullopt, base::nullopt,
238 mojo::ScopedDataPipeConsumerHandle(),
239 mojo::ScopedDataPipeProducerHandle());
240 return;
241 }
242 mojo::DataPipe send_pipe;
243 mojo::DataPipe receive_pipe;
244 socket_data_pump_ = std::make_unique<SocketDataPump>(
245 socket_.get(), this /*delegate*/, std::move(receive_pipe.producer_handle),
246 std::move(send_pipe.consumer_handle), traffic_annotation_);
247 std::move(connect_callback_)
248 .Run(net::OK, local_addr, peer_addr,
249 std::move(receive_pipe.consumer_handle),
250 std::move(send_pipe.producer_handle));
251 }
252
OnNetworkReadError(int net_error)253 void TCPConnectedSocket::OnNetworkReadError(int net_error) {
254 if (observer_)
255 observer_->OnReadError(net_error);
256 }
257
OnNetworkWriteError(int net_error)258 void TCPConnectedSocket::OnNetworkWriteError(int net_error) {
259 if (observer_)
260 observer_->OnWriteError(net_error);
261 }
262
OnShutdown()263 void TCPConnectedSocket::OnShutdown() {
264 socket_data_pump_ = nullptr;
265 if (!pending_upgrade_to_tls_callback_.is_null())
266 std::move(pending_upgrade_to_tls_callback_).Run();
267 }
268
BorrowSocket()269 const net::StreamSocket* TCPConnectedSocket::BorrowSocket() {
270 return socket_.get();
271 }
272
TakeSocket()273 std::unique_ptr<net::StreamSocket> TCPConnectedSocket::TakeSocket() {
274 return std::move(socket_);
275 }
276
277 } // namespace network
278