1 // Copyright 2018 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "services/network/tcp_connected_socket.h"
6 
7 #include <utility>
8 
9 #include "base/bind.h"
10 #include "base/logging.h"
11 #include "base/numerics/ranges.h"
12 #include "base/numerics/safe_conversions.h"
13 #include "base/optional.h"
14 #include "net/base/net_errors.h"
15 #include "net/log/net_log.h"
16 #include "net/socket/client_socket_factory.h"
17 #include "net/socket/client_socket_handle.h"
18 #include "services/network/tls_client_socket.h"
19 
20 namespace network {
21 
22 namespace {
23 
ClampTCPBufferSize(int requested_buffer_size)24 int ClampTCPBufferSize(int requested_buffer_size) {
25   return base::ClampToRange(requested_buffer_size, 0,
26                             TCPConnectedSocket::kMaxBufferSize);
27 }
28 
29 // Sets the initial options on a fresh socket. Assumes |socket| is currently
30 // configured using the default client socket options
31 // (TCPSocket::SetDefaultOptionsForClient()).
ConfigureSocket(net::TransportClientSocket * socket,const mojom::TCPConnectedSocketOptions & tcp_connected_socket_options)32 int ConfigureSocket(
33     net::TransportClientSocket* socket,
34     const mojom::TCPConnectedSocketOptions& tcp_connected_socket_options) {
35   int send_buffer_size =
36       ClampTCPBufferSize(tcp_connected_socket_options.send_buffer_size);
37   if (send_buffer_size > 0) {
38     int result = socket->SetSendBufferSize(send_buffer_size);
39     DCHECK_NE(net::ERR_IO_PENDING, result);
40     if (result != net::OK)
41       return result;
42   }
43 
44   int receive_buffer_size =
45       ClampTCPBufferSize(tcp_connected_socket_options.receive_buffer_size);
46   if (receive_buffer_size > 0) {
47     int result = socket->SetReceiveBufferSize(receive_buffer_size);
48     DCHECK_NE(net::ERR_IO_PENDING, result);
49     if (result != net::OK)
50       return result;
51   }
52 
53   // No delay is set by default, so only update the setting if it's false.
54   if (!tcp_connected_socket_options.no_delay) {
55     // Unlike the above calls, TcpSocket::SetNoDelay() returns a bool rather
56     // than a network error code.
57     if (!socket->SetNoDelay(false))
58       return net::ERR_FAILED;
59   }
60 
61   return net::OK;
62 }
63 
64 }  // namespace
65 
66 const int TCPConnectedSocket::kMaxBufferSize = 128 * 1024;
67 
TCPConnectedSocket(mojo::PendingRemote<mojom::SocketObserver> observer,net::NetLog * net_log,TLSSocketFactory * tls_socket_factory,net::ClientSocketFactory * client_socket_factory,const net::NetworkTrafficAnnotationTag & traffic_annotation)68 TCPConnectedSocket::TCPConnectedSocket(
69     mojo::PendingRemote<mojom::SocketObserver> observer,
70     net::NetLog* net_log,
71     TLSSocketFactory* tls_socket_factory,
72     net::ClientSocketFactory* client_socket_factory,
73     const net::NetworkTrafficAnnotationTag& traffic_annotation)
74     : observer_(std::move(observer)),
75       net_log_(net_log),
76       client_socket_factory_(client_socket_factory),
77       tls_socket_factory_(tls_socket_factory),
78       traffic_annotation_(traffic_annotation) {}
79 
TCPConnectedSocket(mojo::PendingRemote<mojom::SocketObserver> observer,std::unique_ptr<net::TransportClientSocket> socket,mojo::ScopedDataPipeProducerHandle receive_pipe_handle,mojo::ScopedDataPipeConsumerHandle send_pipe_handle,const net::NetworkTrafficAnnotationTag & traffic_annotation)80 TCPConnectedSocket::TCPConnectedSocket(
81     mojo::PendingRemote<mojom::SocketObserver> observer,
82     std::unique_ptr<net::TransportClientSocket> socket,
83     mojo::ScopedDataPipeProducerHandle receive_pipe_handle,
84     mojo::ScopedDataPipeConsumerHandle send_pipe_handle,
85     const net::NetworkTrafficAnnotationTag& traffic_annotation)
86     : observer_(std::move(observer)),
87       net_log_(nullptr),
88       client_socket_factory_(nullptr),
89       tls_socket_factory_(nullptr),
90       socket_(std::move(socket)),
91       traffic_annotation_(traffic_annotation) {
92   socket_data_pump_ = std::make_unique<SocketDataPump>(
93       socket_.get(), this /*delegate*/, std::move(receive_pipe_handle),
94       std::move(send_pipe_handle), traffic_annotation);
95 }
96 
~TCPConnectedSocket()97 TCPConnectedSocket::~TCPConnectedSocket() {
98   if (connect_callback_) {
99     // If |this| is destroyed when connect hasn't completed, tell the consumer
100     // that request has been aborted.
101     std::move(connect_callback_)
102         .Run(net::ERR_ABORTED, base::nullopt, base::nullopt,
103              mojo::ScopedDataPipeConsumerHandle(),
104              mojo::ScopedDataPipeProducerHandle());
105   }
106 }
107 
Connect(const base::Optional<net::IPEndPoint> & local_addr,const net::AddressList & remote_addr_list,mojom::TCPConnectedSocketOptionsPtr tcp_connected_socket_options,mojom::NetworkContext::CreateTCPConnectedSocketCallback callback)108 void TCPConnectedSocket::Connect(
109     const base::Optional<net::IPEndPoint>& local_addr,
110     const net::AddressList& remote_addr_list,
111     mojom::TCPConnectedSocketOptionsPtr tcp_connected_socket_options,
112     mojom::NetworkContext::CreateTCPConnectedSocketCallback callback) {
113   DCHECK(!socket_);
114   DCHECK(callback);
115 
116   std::unique_ptr<net::TransportClientSocket> socket =
117       client_socket_factory_->CreateTransportClientSocket(
118           remote_addr_list, nullptr /*socket_performance_watcher*/, net_log_,
119           net::NetLogSource());
120 
121   if (local_addr) {
122     int result = socket->Bind(local_addr.value());
123     if (result != net::OK) {
124       OnConnectCompleted(result);
125       return;
126     }
127   }
128 
129   return ConnectWithSocket(std::move(socket),
130                            std::move(tcp_connected_socket_options),
131                            std::move(callback));
132 }
133 
ConnectWithSocket(std::unique_ptr<net::TransportClientSocket> socket,mojom::TCPConnectedSocketOptionsPtr tcp_connected_socket_options,mojom::NetworkContext::CreateTCPConnectedSocketCallback callback)134 void TCPConnectedSocket::ConnectWithSocket(
135     std::unique_ptr<net::TransportClientSocket> socket,
136     mojom::TCPConnectedSocketOptionsPtr tcp_connected_socket_options,
137     mojom::NetworkContext::CreateTCPConnectedSocketCallback callback) {
138   socket_ = std::move(socket);
139   connect_callback_ = std::move(callback);
140 
141   if (tcp_connected_socket_options) {
142     socket_->SetBeforeConnectCallback(base::BindRepeating(
143         &ConfigureSocket, socket_.get(), *tcp_connected_socket_options));
144   }
145   int result = socket_->Connect(base::BindRepeating(
146       &TCPConnectedSocket::OnConnectCompleted, base::Unretained(this)));
147 
148   if (result == net::ERR_IO_PENDING)
149     return;
150 
151   OnConnectCompleted(result);
152 }
153 
UpgradeToTLS(const net::HostPortPair & host_port_pair,mojom::TLSClientSocketOptionsPtr socket_options,const net::MutableNetworkTrafficAnnotationTag & traffic_annotation,mojo::PendingReceiver<mojom::TLSClientSocket> receiver,mojo::PendingRemote<mojom::SocketObserver> observer,mojom::TCPConnectedSocket::UpgradeToTLSCallback callback)154 void TCPConnectedSocket::UpgradeToTLS(
155     const net::HostPortPair& host_port_pair,
156     mojom::TLSClientSocketOptionsPtr socket_options,
157     const net::MutableNetworkTrafficAnnotationTag& traffic_annotation,
158     mojo::PendingReceiver<mojom::TLSClientSocket> receiver,
159     mojo::PendingRemote<mojom::SocketObserver> observer,
160     mojom::TCPConnectedSocket::UpgradeToTLSCallback callback) {
161   if (!tls_socket_factory_) {
162     std::move(callback).Run(
163         net::ERR_NOT_IMPLEMENTED, mojo::ScopedDataPipeConsumerHandle(),
164         mojo::ScopedDataPipeProducerHandle(), base::nullopt /* ssl_info*/);
165     return;
166   }
167   // Wait for data pipes to be closed by the client before doing the upgrade.
168   if (socket_data_pump_) {
169     pending_upgrade_to_tls_callback_ = base::BindOnce(
170         &TCPConnectedSocket::UpgradeToTLS, base::Unretained(this),
171         host_port_pair, std::move(socket_options), traffic_annotation,
172         std::move(receiver), std::move(observer), std::move(callback));
173     return;
174   }
175   tls_socket_factory_->UpgradeToTLS(
176       this, host_port_pair, std::move(socket_options), traffic_annotation,
177       std::move(receiver), std::move(observer), std::move(callback));
178 }
179 
SetSendBufferSize(int send_buffer_size,SetSendBufferSizeCallback callback)180 void TCPConnectedSocket::SetSendBufferSize(int send_buffer_size,
181                                            SetSendBufferSizeCallback callback) {
182   if (!socket_) {
183     // Fail is this method was called after upgrading to TLS.
184     std::move(callback).Run(net::ERR_UNEXPECTED);
185     return;
186   }
187   int result = socket_->SetSendBufferSize(ClampTCPBufferSize(send_buffer_size));
188   std::move(callback).Run(result);
189 }
190 
SetReceiveBufferSize(int send_buffer_size,SetSendBufferSizeCallback callback)191 void TCPConnectedSocket::SetReceiveBufferSize(
192     int send_buffer_size,
193     SetSendBufferSizeCallback callback) {
194   if (!socket_) {
195     // Fail is this method was called after upgrading to TLS.
196     std::move(callback).Run(net::ERR_UNEXPECTED);
197     return;
198   }
199   int result =
200       socket_->SetReceiveBufferSize(ClampTCPBufferSize(send_buffer_size));
201   std::move(callback).Run(result);
202 }
203 
SetNoDelay(bool no_delay,SetNoDelayCallback callback)204 void TCPConnectedSocket::SetNoDelay(bool no_delay,
205                                     SetNoDelayCallback callback) {
206   if (!socket_) {
207     std::move(callback).Run(false);
208     return;
209   }
210   bool success = socket_->SetNoDelay(no_delay);
211   std::move(callback).Run(success);
212 }
213 
SetKeepAlive(bool enable,int32_t delay_secs,SetKeepAliveCallback callback)214 void TCPConnectedSocket::SetKeepAlive(bool enable,
215                                       int32_t delay_secs,
216                                       SetKeepAliveCallback callback) {
217   if (!socket_) {
218     std::move(callback).Run(false);
219     return;
220   }
221   bool success = socket_->SetKeepAlive(enable, delay_secs);
222   std::move(callback).Run(success);
223 }
224 
OnConnectCompleted(int result)225 void TCPConnectedSocket::OnConnectCompleted(int result) {
226   DCHECK(!connect_callback_.is_null());
227   DCHECK(!socket_data_pump_);
228 
229   net::IPEndPoint peer_addr, local_addr;
230   if (result == net::OK)
231     result = socket_->GetLocalAddress(&local_addr);
232   if (result == net::OK)
233     result = socket_->GetPeerAddress(&peer_addr);
234 
235   if (result != net::OK) {
236     std::move(connect_callback_)
237         .Run(result, base::nullopt, base::nullopt,
238              mojo::ScopedDataPipeConsumerHandle(),
239              mojo::ScopedDataPipeProducerHandle());
240     return;
241   }
242   mojo::DataPipe send_pipe;
243   mojo::DataPipe receive_pipe;
244   socket_data_pump_ = std::make_unique<SocketDataPump>(
245       socket_.get(), this /*delegate*/, std::move(receive_pipe.producer_handle),
246       std::move(send_pipe.consumer_handle), traffic_annotation_);
247   std::move(connect_callback_)
248       .Run(net::OK, local_addr, peer_addr,
249            std::move(receive_pipe.consumer_handle),
250            std::move(send_pipe.producer_handle));
251 }
252 
OnNetworkReadError(int net_error)253 void TCPConnectedSocket::OnNetworkReadError(int net_error) {
254   if (observer_)
255     observer_->OnReadError(net_error);
256 }
257 
OnNetworkWriteError(int net_error)258 void TCPConnectedSocket::OnNetworkWriteError(int net_error) {
259   if (observer_)
260     observer_->OnWriteError(net_error);
261 }
262 
OnShutdown()263 void TCPConnectedSocket::OnShutdown() {
264   socket_data_pump_ = nullptr;
265   if (!pending_upgrade_to_tls_callback_.is_null())
266     std::move(pending_upgrade_to_tls_callback_).Run();
267 }
268 
BorrowSocket()269 const net::StreamSocket* TCPConnectedSocket::BorrowSocket() {
270   return socket_.get();
271 }
272 
TakeSocket()273 std::unique_ptr<net::StreamSocket> TCPConnectedSocket::TakeSocket() {
274   return std::move(socket_);
275 }
276 
277 }  // namespace network
278