1 // Copyright 2018 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "services/network/proxy_resolving_socket_mojo.h"
6 
7 #include <utility>
8 
9 #include "base/bind.h"
10 #include "base/logging.h"
11 #include "base/optional.h"
12 #include "net/base/net_errors.h"
13 #include "services/network/socket_data_pump.h"
14 
15 namespace network {
16 
ProxyResolvingSocketMojo(std::unique_ptr<net::StreamSocket> socket,const net::NetworkTrafficAnnotationTag & traffic_annotation,mojo::PendingRemote<mojom::SocketObserver> observer,TLSSocketFactory * tls_socket_factory)17 ProxyResolvingSocketMojo::ProxyResolvingSocketMojo(
18     std::unique_ptr<net::StreamSocket> socket,
19     const net::NetworkTrafficAnnotationTag& traffic_annotation,
20     mojo::PendingRemote<mojom::SocketObserver> observer,
21     TLSSocketFactory* tls_socket_factory)
22     : observer_(std::move(observer)),
23       tls_socket_factory_(tls_socket_factory),
24       socket_(std::move(socket)),
25       traffic_annotation_(traffic_annotation) {}
26 
~ProxyResolvingSocketMojo()27 ProxyResolvingSocketMojo::~ProxyResolvingSocketMojo() {
28   if (connect_callback_) {
29     // If |this| is destroyed when connect hasn't completed, tell the consumer
30     // that request has been aborted.
31     std::move(connect_callback_)
32         .Run(net::ERR_ABORTED, base::nullopt, base::nullopt,
33              mojo::ScopedDataPipeConsumerHandle(),
34              mojo::ScopedDataPipeProducerHandle());
35   }
36 }
37 
Connect(mojom::ProxyResolvingSocketFactory::CreateProxyResolvingSocketCallback callback)38 void ProxyResolvingSocketMojo::Connect(
39     mojom::ProxyResolvingSocketFactory::CreateProxyResolvingSocketCallback
40         callback) {
41   DCHECK(socket_);
42   DCHECK(callback);
43   DCHECK(!connect_callback_);
44 
45   connect_callback_ = std::move(callback);
46   int result = socket_->Connect(base::BindOnce(
47       &ProxyResolvingSocketMojo::OnConnectCompleted, base::Unretained(this)));
48   if (result == net::ERR_IO_PENDING)
49     return;
50   OnConnectCompleted(result);
51 }
52 
UpgradeToTLS(const net::HostPortPair & host_port_pair,const net::MutableNetworkTrafficAnnotationTag & traffic_annotation,mojo::PendingReceiver<mojom::TLSClientSocket> receiver,mojo::PendingRemote<mojom::SocketObserver> observer,mojom::ProxyResolvingSocket::UpgradeToTLSCallback callback)53 void ProxyResolvingSocketMojo::UpgradeToTLS(
54     const net::HostPortPair& host_port_pair,
55     const net::MutableNetworkTrafficAnnotationTag& traffic_annotation,
56     mojo::PendingReceiver<mojom::TLSClientSocket> receiver,
57     mojo::PendingRemote<mojom::SocketObserver> observer,
58     mojom::ProxyResolvingSocket::UpgradeToTLSCallback callback) {
59   // Wait for data pipes to be closed by the client before doing the upgrade.
60   if (socket_data_pump_) {
61     pending_upgrade_to_tls_callback_ = base::BindOnce(
62         &ProxyResolvingSocketMojo::UpgradeToTLS, base::Unretained(this),
63         host_port_pair, traffic_annotation, std::move(receiver),
64         std::move(observer), std::move(callback));
65     return;
66   }
67   tls_socket_factory_->UpgradeToTLS(
68       this, host_port_pair, nullptr /* sockt_options */, traffic_annotation,
69       std::move(receiver), std::move(observer),
70       base::BindOnce(
71           [](mojom::ProxyResolvingSocket::UpgradeToTLSCallback callback,
72              int32_t net_error,
73              mojo::ScopedDataPipeConsumerHandle receive_stream,
74              mojo::ScopedDataPipeProducerHandle send_stream,
75              const base::Optional<net::SSLInfo>& ssl_info) {
76             DCHECK(!ssl_info);
77             std::move(callback).Run(net_error, std::move(receive_stream),
78                                     std::move(send_stream));
79           },
80           std::move(callback)));
81 }
82 
OnConnectCompleted(int result)83 void ProxyResolvingSocketMojo::OnConnectCompleted(int result) {
84   DCHECK(!connect_callback_.is_null());
85   DCHECK(!socket_data_pump_);
86 
87   net::IPEndPoint local_addr;
88   if (result == net::OK)
89     result = socket_->GetLocalAddress(&local_addr);
90 
91   net::IPEndPoint peer_addr;
92   // If |socket_| is connected through a proxy, GetPeerAddress returns
93   // net::ERR_NAME_NOT_RESOLVED.
94   bool get_peer_address_success =
95       result == net::OK && (socket_->GetPeerAddress(&peer_addr) == net::OK);
96 
97   if (result != net::OK) {
98     std::move(connect_callback_)
99         .Run(result, base::nullopt, base::nullopt,
100              mojo::ScopedDataPipeConsumerHandle(),
101              mojo::ScopedDataPipeProducerHandle());
102     return;
103   }
104   mojo::DataPipe send_pipe;
105   mojo::DataPipe receive_pipe;
106   socket_data_pump_ = std::make_unique<SocketDataPump>(
107       socket_.get(), this /*delegate*/, std::move(receive_pipe.producer_handle),
108       std::move(send_pipe.consumer_handle), traffic_annotation_);
109   std::move(connect_callback_)
110       .Run(net::OK, local_addr,
111            get_peer_address_success
112                ? base::make_optional<net::IPEndPoint>(peer_addr)
113                : base::nullopt,
114            std::move(receive_pipe.consumer_handle),
115            std::move(send_pipe.producer_handle));
116 }
117 
OnNetworkReadError(int net_error)118 void ProxyResolvingSocketMojo::OnNetworkReadError(int net_error) {
119   if (observer_)
120     observer_->OnReadError(net_error);
121 }
122 
OnNetworkWriteError(int net_error)123 void ProxyResolvingSocketMojo::OnNetworkWriteError(int net_error) {
124   if (observer_)
125     observer_->OnWriteError(net_error);
126 }
127 
OnShutdown()128 void ProxyResolvingSocketMojo::OnShutdown() {
129   socket_data_pump_ = nullptr;
130   if (!pending_upgrade_to_tls_callback_.is_null())
131     std::move(pending_upgrade_to_tls_callback_).Run();
132 }
133 
BorrowSocket()134 const net::StreamSocket* ProxyResolvingSocketMojo::BorrowSocket() {
135   return socket_.get();
136 }
137 
TakeSocket()138 std::unique_ptr<net::StreamSocket> ProxyResolvingSocketMojo::TakeSocket() {
139   return std::move(socket_);
140 }
141 
142 }  // namespace network
143