1 // Copyright 2018 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include <stdint.h>
6 
7 #include <string>
8 #include <utility>
9 #include <vector>
10 
11 #include "base/bind.h"
12 #include "base/logging.h"
13 #include "base/macros.h"
14 #include "base/run_loop.h"
15 #include "base/test/bind_test_util.h"
16 #include "base/test/task_environment.h"
17 #include "base/threading/thread.h"
18 #include "mojo/public/cpp/bindings/pending_receiver.h"
19 #include "mojo/public/cpp/bindings/remote.h"
20 #include "net/base/completion_once_callback.h"
21 #include "net/base/net_errors.h"
22 #include "net/base/test_completion_callback.h"
23 #include "net/proxy_resolution/configured_proxy_resolution_service.h"
24 #include "net/socket/server_socket.h"
25 #include "net/socket/socket_test_util.h"
26 #include "net/test/embedded_test_server/embedded_test_server.h"
27 #include "net/test/embedded_test_server/http_request.h"
28 #include "net/test/embedded_test_server/http_response.h"
29 #include "net/traffic_annotation/network_traffic_annotation_test_helper.h"
30 #include "net/url_request/url_request_test_util.h"
31 #include "services/network/mojo_socket_test_util.h"
32 #include "services/network/proxy_resolving_socket_factory_mojo.h"
33 #include "services/network/public/mojom/network_service.mojom.h"
34 #include "services/network/socket_factory.h"
35 #include "testing/gtest/include/gtest/gtest.h"
36 
37 namespace network {
38 
39 namespace {
40 
// Message sent over the tcp connection.
constexpr char kMsg[] = "please start tls!";
// Length without the trailing NUL; computed at compile time so no static
// initializer is needed.
constexpr size_t kMsgSize = sizeof(kMsg) - 1;

// Message sent over the tls connection.
constexpr char kSecretMsg[] = "here is secret.";
// Length without the trailing NUL.
constexpr size_t kSecretMsgSize = sizeof(kSecretMsg) - 1;
48 
49 class TLSClientSocketTestBase {
50  public:
51   enum Mode { kDirect, kProxyResolving };
52 
TLSClientSocketTestBase(Mode mode)53   explicit TLSClientSocketTestBase(Mode mode)
54       : mode_(mode),
55         task_environment_(base::test::TaskEnvironment::MainThreadType::IO),
56         url_request_context_(true) {}
~TLSClientSocketTestBase()57   virtual ~TLSClientSocketTestBase() {}
58 
mode()59   Mode mode() { return mode_; }
60 
61  protected:
62   // One of the two fields will be set, depending on the mode.
63   struct SocketHandle {
64     mojo::Remote<mojom::TCPConnectedSocket> tcp_socket;
65     mojo::Remote<mojom::ProxyResolvingSocket> proxy_socket;
66   };
67 
68   struct SocketRequest {
69     mojo::PendingReceiver<mojom::TCPConnectedSocket> tcp_socket_receiver;
70     mojo::PendingReceiver<mojom::ProxyResolvingSocket> proxy_socket_receiver;
71   };
72 
73   // Initializes the test fixture. If |use_mock_sockets|, mock client socket
74   // factory will be used.
Init(bool use_mock_sockets,bool configure_proxy)75   void Init(bool use_mock_sockets, bool configure_proxy) {
76     if (use_mock_sockets) {
77       mock_client_socket_factory_.set_enable_read_if_ready(true);
78       url_request_context_.set_client_socket_factory(
79           &mock_client_socket_factory_);
80     }
81     if (configure_proxy) {
82       proxy_resolution_service_ =
83           net::ConfiguredProxyResolutionService::CreateFixed(
84               "http://proxy:8080", TRAFFIC_ANNOTATION_FOR_TESTS);
85       url_request_context_.set_proxy_resolution_service(
86           proxy_resolution_service_.get());
87     }
88     url_request_context_.Init();
89     factory_ = std::make_unique<SocketFactory>(nullptr /*net_log*/,
90                                                &url_request_context_);
91     proxy_resolving_factory_ =
92         std::make_unique<ProxyResolvingSocketFactoryMojo>(
93             &url_request_context_);
94   }
95 
96   // Reads |num_bytes| from |handle| or reads until an error occurs. Returns the
97   // bytes read as a string.
Read(mojo::ScopedDataPipeConsumerHandle * handle,size_t num_bytes)98   std::string Read(mojo::ScopedDataPipeConsumerHandle* handle,
99                    size_t num_bytes) {
100     std::string received_contents;
101     while (received_contents.size() < num_bytes) {
102       base::RunLoop().RunUntilIdle();
103       std::vector<char> buffer(num_bytes);
104       uint32_t read_size = static_cast<uint32_t>(num_bytes);
105       MojoResult result = handle->get().ReadData(buffer.data(), &read_size,
106                                                  MOJO_READ_DATA_FLAG_NONE);
107       if (result == MOJO_RESULT_SHOULD_WAIT)
108         continue;
109       if (result != MOJO_RESULT_OK)
110         return received_contents;
111       received_contents.append(buffer.data(), read_size);
112     }
113     return received_contents;
114   }
115 
MakeRequest(SocketHandle * handle)116   SocketRequest MakeRequest(SocketHandle* handle) {
117     SocketRequest result;
118     if (mode_ == kDirect)
119       result.tcp_socket_receiver =
120           handle->tcp_socket.BindNewPipeAndPassReceiver();
121     else
122       result.proxy_socket_receiver =
123           handle->proxy_socket.BindNewPipeAndPassReceiver();
124     return result;
125   }
126 
ResetSocket(SocketHandle * handle)127   void ResetSocket(SocketHandle* handle) {
128     if (mode_ == kDirect)
129       handle->tcp_socket.reset();
130     else
131       handle->proxy_socket.reset();
132   }
133 
CreateSocketSync(SocketRequest request,const net::IPEndPoint & remote_addr)134   int CreateSocketSync(SocketRequest request,
135                        const net::IPEndPoint& remote_addr) {
136     if (mode_ == kDirect) {
137       return CreateTCPConnectedSocketSync(
138           std::move(request.tcp_socket_receiver), remote_addr);
139     } else {
140       return CreateProxyResolvingSocketSync(
141           std::move(request.proxy_socket_receiver), remote_addr);
142     }
143   }
144 
CreateTCPConnectedSocketSync(mojo::PendingReceiver<mojom::TCPConnectedSocket> receiver,const net::IPEndPoint & remote_addr)145   int CreateTCPConnectedSocketSync(
146       mojo::PendingReceiver<mojom::TCPConnectedSocket> receiver,
147       const net::IPEndPoint& remote_addr) {
148     net::AddressList remote_addr_list(remote_addr);
149     base::RunLoop run_loop;
150     int net_error = net::ERR_FAILED;
151     factory_->CreateTCPConnectedSocket(
152         base::nullopt /* local_addr */, remote_addr_list,
153         nullptr /* tcp_connected_socket_options */,
154         TRAFFIC_ANNOTATION_FOR_TESTS, std::move(receiver),
155         pre_tls_observer()->GetObserverRemote(),
156         base::BindLambdaForTesting(
157             [&](int result,
158                 const base::Optional<net::IPEndPoint>& actual_local_addr,
159                 const base::Optional<net::IPEndPoint>& peer_addr,
160                 mojo::ScopedDataPipeConsumerHandle receive_pipe_handle,
161                 mojo::ScopedDataPipeProducerHandle send_pipe_handle) {
162               net_error = result;
163               pre_tls_recv_handle_ = std::move(receive_pipe_handle);
164               pre_tls_send_handle_ = std::move(send_pipe_handle);
165               run_loop.Quit();
166             }));
167     run_loop.Run();
168     return net_error;
169   }
170 
CreateProxyResolvingSocketSync(mojo::PendingReceiver<mojom::ProxyResolvingSocket> receiver,const net::IPEndPoint & remote_addr)171   int CreateProxyResolvingSocketSync(
172       mojo::PendingReceiver<mojom::ProxyResolvingSocket> receiver,
173       const net::IPEndPoint& remote_addr) {
174     GURL url("https://" + remote_addr.ToString());
175     base::RunLoop run_loop;
176     int net_error = net::ERR_FAILED;
177     proxy_resolving_factory_->CreateProxyResolvingSocket(
178         url, nullptr /* options */,
179         net::MutableNetworkTrafficAnnotationTag(TRAFFIC_ANNOTATION_FOR_TESTS),
180         std::move(receiver), mojo::NullRemote() /* observer */,
181         base::BindLambdaForTesting(
182             [&](int result,
183                 const base::Optional<net::IPEndPoint>& actual_local_addr,
184                 const base::Optional<net::IPEndPoint>& peer_addr,
185                 mojo::ScopedDataPipeConsumerHandle receive_pipe_handle,
186                 mojo::ScopedDataPipeProducerHandle send_pipe_handle) {
187               net_error = result;
188               pre_tls_recv_handle_ = std::move(receive_pipe_handle);
189               pre_tls_send_handle_ = std::move(send_pipe_handle);
190               run_loop.Quit();
191             }));
192     run_loop.Run();
193     return net_error;
194   }
195 
UpgradeToTLS(SocketHandle * handle,const net::HostPortPair & host_port_pair,mojo::PendingReceiver<mojom::TLSClientSocket> receiver,net::CompletionOnceCallback callback)196   void UpgradeToTLS(SocketHandle* handle,
197                     const net::HostPortPair& host_port_pair,
198                     mojo::PendingReceiver<mojom::TLSClientSocket> receiver,
199                     net::CompletionOnceCallback callback) {
200     if (mode_ == kDirect) {
201       UpgradeTCPConnectedSocketToTLS(handle->tcp_socket.get(), host_port_pair,
202                                      nullptr /* options */, std::move(receiver),
203                                      std::move(callback));
204     } else {
205       UpgradeProxyResolvingSocketToTLS(handle->proxy_socket.get(),
206                                        host_port_pair, std::move(receiver),
207                                        std::move(callback));
208     }
209   }
210 
UpgradeTCPConnectedSocketToTLS(mojom::TCPConnectedSocket * client_socket,const net::HostPortPair & host_port_pair,mojom::TLSClientSocketOptionsPtr options,mojo::PendingReceiver<mojom::TLSClientSocket> receiver,net::CompletionOnceCallback callback)211   void UpgradeTCPConnectedSocketToTLS(
212       mojom::TCPConnectedSocket* client_socket,
213       const net::HostPortPair& host_port_pair,
214       mojom::TLSClientSocketOptionsPtr options,
215       mojo::PendingReceiver<mojom::TLSClientSocket> receiver,
216       net::CompletionOnceCallback callback) {
217     client_socket->UpgradeToTLS(
218         host_port_pair, std::move(options),
219         net::MutableNetworkTrafficAnnotationTag(TRAFFIC_ANNOTATION_FOR_TESTS),
220         std::move(receiver), post_tls_observer()->GetObserverRemote(),
221         base::BindOnce(
222             [](net::CompletionOnceCallback cb,
223                mojo::ScopedDataPipeConsumerHandle* consumer_handle_out,
224                mojo::ScopedDataPipeProducerHandle* producer_handle_out,
225                base::Optional<net::SSLInfo>* ssl_info_out, int result,
226                mojo::ScopedDataPipeConsumerHandle receive_pipe_handle,
227                mojo::ScopedDataPipeProducerHandle send_pipe_handle,
228                const base::Optional<net::SSLInfo>& ssl_info) {
229               *consumer_handle_out = std::move(receive_pipe_handle);
230               *producer_handle_out = std::move(send_pipe_handle);
231               *ssl_info_out = ssl_info;
232               std::move(cb).Run(result);
233             },
234             std::move(callback), &post_tls_recv_handle_, &post_tls_send_handle_,
235             &ssl_info_));
236   }
237 
UpgradeProxyResolvingSocketToTLS(mojom::ProxyResolvingSocket * client_socket,const net::HostPortPair & host_port_pair,mojo::PendingReceiver<mojom::TLSClientSocket> receiver,net::CompletionOnceCallback callback)238   void UpgradeProxyResolvingSocketToTLS(
239       mojom::ProxyResolvingSocket* client_socket,
240       const net::HostPortPair& host_port_pair,
241       mojo::PendingReceiver<mojom::TLSClientSocket> receiver,
242       net::CompletionOnceCallback callback) {
243     client_socket->UpgradeToTLS(
244         host_port_pair,
245         net::MutableNetworkTrafficAnnotationTag(TRAFFIC_ANNOTATION_FOR_TESTS),
246         std::move(receiver), post_tls_observer()->GetObserverRemote(),
247         base::BindOnce(
248             [](net::CompletionOnceCallback cb,
249                mojo::ScopedDataPipeConsumerHandle* consumer_handle,
250                mojo::ScopedDataPipeProducerHandle* producer_handle, int result,
251                mojo::ScopedDataPipeConsumerHandle receive_pipe_handle,
252                mojo::ScopedDataPipeProducerHandle send_pipe_handle) {
253               *consumer_handle = std::move(receive_pipe_handle);
254               *producer_handle = std::move(send_pipe_handle);
255               std::move(cb).Run(result);
256             },
257             std::move(callback), &post_tls_recv_handle_,
258             &post_tls_send_handle_));
259   }
260 
pre_tls_observer()261   TestSocketObserver* pre_tls_observer() { return &pre_tls_observer_; }
post_tls_observer()262   TestSocketObserver* post_tls_observer() { return &post_tls_observer_; }
263 
pre_tls_recv_handle()264   mojo::ScopedDataPipeConsumerHandle* pre_tls_recv_handle() {
265     return &pre_tls_recv_handle_;
266   }
267 
pre_tls_send_handle()268   mojo::ScopedDataPipeProducerHandle* pre_tls_send_handle() {
269     return &pre_tls_send_handle_;
270   }
271 
post_tls_recv_handle()272   mojo::ScopedDataPipeConsumerHandle* post_tls_recv_handle() {
273     return &post_tls_recv_handle_;
274   }
275 
post_tls_send_handle()276   mojo::ScopedDataPipeProducerHandle* post_tls_send_handle() {
277     return &post_tls_send_handle_;
278   }
279 
ssl_info()280   const base::Optional<net::SSLInfo>& ssl_info() { return ssl_info_; }
281 
mock_client_socket_factory()282   net::MockClientSocketFactory* mock_client_socket_factory() {
283     return &mock_client_socket_factory_;
284   }
285 
mode() const286   Mode mode() const { return mode_; }
287 
288  private:
289   Mode mode_;
290   base::test::TaskEnvironment task_environment_;
291 
292   // Mojo data handles obtained from CreateTCPConnectedSocket.
293   mojo::ScopedDataPipeConsumerHandle pre_tls_recv_handle_;
294   mojo::ScopedDataPipeProducerHandle pre_tls_send_handle_;
295 
296   // Mojo data handles obtained from UpgradeToTLS.
297   mojo::ScopedDataPipeConsumerHandle post_tls_recv_handle_;
298   mojo::ScopedDataPipeProducerHandle post_tls_send_handle_;
299 
300   // SSLInfo obtained from UpgradeToTLS.
301   base::Optional<net::SSLInfo> ssl_info_;
302 
303   std::unique_ptr<net::ProxyResolutionService> proxy_resolution_service_;
304   net::TestURLRequestContext url_request_context_;
305   net::MockClientSocketFactory mock_client_socket_factory_;
306   std::unique_ptr<SocketFactory> factory_;
307   std::unique_ptr<ProxyResolvingSocketFactoryMojo> proxy_resolving_factory_;
308   TestSocketObserver pre_tls_observer_;
309   TestSocketObserver post_tls_observer_;
310 
311   DISALLOW_COPY_AND_ASSIGN(TLSClientSocketTestBase);
312 };
313 
314 class TLSClientSocketTest
315     : public ::testing::TestWithParam<TLSClientSocketTestBase::Mode>,
316       public TLSClientSocketTestBase {
317  public:
TLSClientSocketTest()318   TLSClientSocketTest() : TLSClientSocketTestBase(GetParam()) {
319     Init(true /* use_mock_sockets */, false /* configure_proxy */);
320   }
321 
~TLSClientSocketTest()322   ~TLSClientSocketTest() override {}
323 
324  private:
325   DISALLOW_COPY_AND_ASSIGN(TLSClientSocketTest);
326 };
327 
328 // Basic test to call UpgradeToTLS, and then read/write after UpgradeToTLS is
329 // successful.
TEST_P(TLSClientSocketTest, UpgradeToTLS) {
  const net::MockRead kReads[] = {net::MockRead(net::ASYNC, kMsg, kMsgSize, 1),
                                  net::MockRead(net::SYNCHRONOUS, net::OK, 2)};
  const net::MockWrite kWrites[] = {
      net::MockWrite(net::SYNCHRONOUS, kMsg, kMsgSize, 0)};
  net::SequencedSocketData data_provider(kReads, kWrites);
  data_provider.set_connect_data(net::MockConnect(net::SYNCHRONOUS, net::OK));
  mock_client_socket_factory()->AddSocketDataProvider(&data_provider);
  net::SSLSocketDataProvider ssl_socket(net::ASYNC, net::OK);
  mock_client_socket_factory()->AddSSLSocketDataProvider(&ssl_socket);

  SocketHandle client_socket;
  net::IPEndPoint server_addr(net::IPAddress::IPv4Localhost(), 1234);
  // Everything below needs a connected socket, so stop the test on failure
  // instead of continuing with unusable handles.
  ASSERT_EQ(net::OK,
            CreateSocketSync(MakeRequest(&client_socket), server_addr));

  net::HostPortPair host_port_pair("example.org", 443);
  // Release the pre-TLS data pipes before upgrading.
  pre_tls_recv_handle()->reset();
  pre_tls_send_handle()->reset();
  net::TestCompletionCallback callback;
  mojo::Remote<mojom::TLSClientSocket> tls_socket;
  UpgradeToTLS(&client_socket, host_port_pair,
               tls_socket.BindNewPipeAndPassReceiver(), callback.callback());
  ASSERT_EQ(net::OK, callback.WaitForResult());
  ResetSocket(&client_socket);

  uint32_t num_bytes = static_cast<uint32_t>(kMsgSize);
  EXPECT_EQ(MOJO_RESULT_OK, post_tls_send_handle()->get().WriteData(
                                kMsg, &num_bytes, MOJO_WRITE_DATA_FLAG_NONE));
  // WriteData() may accept fewer bytes than requested; the whole message is
  // expected to fit.
  EXPECT_EQ(kMsgSize, num_bytes);
  EXPECT_EQ(kMsg, Read(post_tls_recv_handle(), kMsgSize));
  base::RunLoop().RunUntilIdle();
  EXPECT_TRUE(ssl_socket.ConnectDataConsumed());
  EXPECT_TRUE(data_provider.AllReadDataConsumed());
  EXPECT_TRUE(data_provider.AllWriteDataConsumed());
}
365 
366 // Same as the UpgradeToTLS test above, except this test calls
367 // base::RunLoop().RunUntilIdle() after destroying the pre-tls data pipes.
TEST_P(TLSClientSocketTest, ClosePipesRunUntilIdleAndUpgradeToTLS) {
  const net::MockRead kReads[] = {net::MockRead(net::ASYNC, kMsg, kMsgSize, 1),
                                  net::MockRead(net::SYNCHRONOUS, net::OK, 2)};
  const net::MockWrite kWrites[] = {
      net::MockWrite(net::SYNCHRONOUS, kMsg, kMsgSize, 0)};
  net::SequencedSocketData data_provider(kReads, kWrites);
  data_provider.set_connect_data(net::MockConnect(net::SYNCHRONOUS, net::OK));
  mock_client_socket_factory()->AddSocketDataProvider(&data_provider);
  net::SSLSocketDataProvider ssl_socket(net::ASYNC, net::OK);
  mock_client_socket_factory()->AddSSLSocketDataProvider(&ssl_socket);

  SocketHandle client_socket;
  net::IPEndPoint server_addr(net::IPAddress::IPv4Localhost(), 1234);
  // Everything below needs a connected socket, so stop the test on failure
  // instead of continuing with unusable handles.
  ASSERT_EQ(net::OK,
            CreateSocketSync(MakeRequest(&client_socket), server_addr));

  net::HostPortPair host_port_pair("example.org", 443);

  // Call RunUntilIdle() to test the case that pipes are closed before
  // UpgradeToTLS.
  pre_tls_recv_handle()->reset();
  pre_tls_send_handle()->reset();
  base::RunLoop().RunUntilIdle();

  net::TestCompletionCallback callback;
  mojo::Remote<mojom::TLSClientSocket> tls_socket;
  UpgradeToTLS(&client_socket, host_port_pair,
               tls_socket.BindNewPipeAndPassReceiver(), callback.callback());
  ASSERT_EQ(net::OK, callback.WaitForResult());
  ResetSocket(&client_socket);

  uint32_t num_bytes = static_cast<uint32_t>(kMsgSize);
  EXPECT_EQ(MOJO_RESULT_OK, post_tls_send_handle()->get().WriteData(
                                kMsg, &num_bytes, MOJO_WRITE_DATA_FLAG_NONE));
  // WriteData() may accept fewer bytes than requested; the whole message is
  // expected to fit.
  EXPECT_EQ(kMsgSize, num_bytes);
  EXPECT_EQ(kMsg, Read(post_tls_recv_handle(), kMsgSize));
  base::RunLoop().RunUntilIdle();
  EXPECT_TRUE(ssl_socket.ConnectDataConsumed());
  EXPECT_TRUE(data_provider.AllReadDataConsumed());
  EXPECT_TRUE(data_provider.AllWriteDataConsumed());
}
408 
// Calling UpgradeToTLS twice on the same socket remote (TCPConnectedSocket or
// ProxyResolvingSocket, depending on the mode) is illegal and should fail.
TEST_P(TLSClientSocketTest,UpgradeToTLSTwice)411 TEST_P(TLSClientSocketTest, UpgradeToTLSTwice) {
412   const net::MockRead kReads[] = {net::MockRead(net::ASYNC, net::OK, 0)};
413   net::SequencedSocketData data_provider(kReads, base::span<net::MockWrite>());
414   data_provider.set_connect_data(net::MockConnect(net::SYNCHRONOUS, net::OK));
415   mock_client_socket_factory()->AddSocketDataProvider(&data_provider);
416   net::SSLSocketDataProvider ssl_socket(net::ASYNC, net::OK);
417   mock_client_socket_factory()->AddSSLSocketDataProvider(&ssl_socket);
418 
419   SocketHandle client_socket;
420   net::IPEndPoint server_addr(net::IPAddress::IPv4Localhost(), 1234);
421   EXPECT_EQ(net::OK,
422             CreateSocketSync(MakeRequest(&client_socket), server_addr));
423 
424   net::HostPortPair host_port_pair("example.org", 443);
425   pre_tls_recv_handle()->reset();
426   pre_tls_send_handle()->reset();
427 
428   // First UpgradeToTLS should complete successfully.
429   net::TestCompletionCallback callback;
430   mojo::Remote<mojom::TLSClientSocket> tls_socket;
431   UpgradeToTLS(&client_socket, host_port_pair,
432                tls_socket.BindNewPipeAndPassReceiver(), callback.callback());
433   ASSERT_EQ(net::OK, callback.WaitForResult());
434 
435   // Second time UpgradeToTLS is called, it should fail.
436   mojo::Remote<mojom::TLSClientSocket> tls_socket2;
437   base::RunLoop run_loop;
438   int net_error = net::ERR_FAILED;
439   if (mode() == kDirect) {
440     auto upgrade2_callback = base::BindLambdaForTesting(
441         [&](int result, mojo::ScopedDataPipeConsumerHandle receive_pipe_handle,
442             mojo::ScopedDataPipeProducerHandle send_pipe_handle,
443             const base::Optional<net::SSLInfo>& ssl_info) {
444           net_error = result;
445           run_loop.Quit();
446         });
447     client_socket.tcp_socket->UpgradeToTLS(
448         host_port_pair, nullptr /* ssl_config_ptr */,
449         net::MutableNetworkTrafficAnnotationTag(TRAFFIC_ANNOTATION_FOR_TESTS),
450         tls_socket2.BindNewPipeAndPassReceiver(),
451         mojo::NullRemote() /*observer */, std::move(upgrade2_callback));
452   } else {
453     auto upgrade2_callback = base::BindLambdaForTesting(
454         [&](int result, mojo::ScopedDataPipeConsumerHandle receive_pipe_handle,
455             mojo::ScopedDataPipeProducerHandle send_pipe_handle) {
456           net_error = result;
457           run_loop.Quit();
458         });
459     client_socket.proxy_socket->UpgradeToTLS(
460         host_port_pair,
461         net::MutableNetworkTrafficAnnotationTag(TRAFFIC_ANNOTATION_FOR_TESTS),
462         tls_socket2.BindNewPipeAndPassReceiver(),
463         mojo::NullRemote() /*observer */, std::move(upgrade2_callback));
464   }
465   run_loop.Run();
466   ASSERT_EQ(net::ERR_SOCKET_NOT_CONNECTED, net_error);
467 
468   base::RunLoop().RunUntilIdle();
469   EXPECT_TRUE(ssl_socket.ConnectDataConsumed());
470   EXPECT_TRUE(data_provider.AllReadDataConsumed());
471   EXPECT_TRUE(data_provider.AllWriteDataConsumed());
472 }
473 
TEST_P(TLSClientSocketTest, UpgradeToTLSWithCustomSSLConfig) {
  // No custom options in the proxy-resolving case.
  if (mode() != kDirect)
    return;
  const net::MockRead kReads[] = {net::MockRead(net::ASYNC, net::OK, 0)};
  net::SequencedSocketData data_provider(kReads, base::span<net::MockWrite>());
  data_provider.set_connect_data(net::MockConnect(net::SYNCHRONOUS, net::OK));
  mock_client_socket_factory()->AddSocketDataProvider(&data_provider);
  net::SSLSocketDataProvider ssl_socket(net::ASYNC, net::OK);
  // Matches the version range requested through |options| below.
  ssl_socket.expected_ssl_version_min = net::SSL_PROTOCOL_VERSION_TLS1_1;
  ssl_socket.expected_ssl_version_max = net::SSL_PROTOCOL_VERSION_TLS1_2;
  mock_client_socket_factory()->AddSSLSocketDataProvider(&ssl_socket);

  SocketHandle client_socket;
  net::IPEndPoint server_addr(net::IPAddress::IPv4Localhost(), 1234);
  // The upgrade below needs a connected socket, so stop the test on failure.
  ASSERT_EQ(net::OK,
            CreateSocketSync(MakeRequest(&client_socket), server_addr));

  net::HostPortPair host_port_pair("example.org", 443);
  pre_tls_recv_handle()->reset();
  pre_tls_send_handle()->reset();

  mojo::Remote<mojom::TLSClientSocket> tls_socket;
  base::RunLoop run_loop;
  mojom::TLSClientSocketOptionsPtr options =
      mojom::TLSClientSocketOptions::New();
  options->version_min = mojom::SSLVersion::kTLS11;
  options->version_max = mojom::SSLVersion::kTLS12;
  int net_error = net::ERR_FAILED;
  auto upgrade_callback = base::BindLambdaForTesting(
      [&](int result, mojo::ScopedDataPipeConsumerHandle receive_pipe_handle,
          mojo::ScopedDataPipeProducerHandle send_pipe_handle,
          const base::Optional<net::SSLInfo>& ssl_info) {
        net_error = result;
        run_loop.Quit();
      });
  client_socket.tcp_socket->UpgradeToTLS(
      host_port_pair, std::move(options),
      net::MutableNetworkTrafficAnnotationTag(TRAFFIC_ANNOTATION_FOR_TESTS),
      tls_socket.BindNewPipeAndPassReceiver(), mojo::NullRemote() /*observer */,
      std::move(upgrade_callback));
  run_loop.Run();
  ASSERT_EQ(net::OK, net_error);

  base::RunLoop().RunUntilIdle();
  EXPECT_TRUE(ssl_socket.ConnectDataConsumed());
  EXPECT_TRUE(data_provider.AllReadDataConsumed());
  EXPECT_TRUE(data_provider.AllWriteDataConsumed());
}
523 
524 // Same as the UpgradeToTLS test, except this also reads and writes to the tcp
525 // connection before UpgradeToTLS is called.
TEST_P(TLSClientSocketTest, ReadWriteBeforeUpgradeToTLS) {
  const net::MockRead kReads[] = {
      net::MockRead(net::SYNCHRONOUS, kMsg, kMsgSize, 0),
      net::MockRead(net::ASYNC, kSecretMsg, kSecretMsgSize, 3),
      net::MockRead(net::SYNCHRONOUS, net::OK, 4)};
  const net::MockWrite kWrites[] = {
      net::MockWrite(net::SYNCHRONOUS, kMsg, kMsgSize, 1),
      net::MockWrite(net::SYNCHRONOUS, kSecretMsg, kSecretMsgSize, 2),
  };
  net::SequencedSocketData data_provider(kReads, kWrites);
  data_provider.set_connect_data(net::MockConnect(net::SYNCHRONOUS, net::OK));
  mock_client_socket_factory()->AddSocketDataProvider(&data_provider);
  net::SSLSocketDataProvider ssl_socket(net::ASYNC, net::OK);
  mock_client_socket_factory()->AddSSLSocketDataProvider(&ssl_socket);

  SocketHandle client_socket;
  net::IPEndPoint server_addr(net::IPAddress::IPv4Localhost(), 1234);
  // Everything below needs a connected socket, so stop the test on failure
  // instead of continuing with unusable handles.
  ASSERT_EQ(net::OK,
            CreateSocketSync(MakeRequest(&client_socket), server_addr));

  // Plain-text exchange over the pre-TLS pipes.
  EXPECT_EQ(kMsg, Read(pre_tls_recv_handle(), kMsgSize));

  uint32_t num_bytes = static_cast<uint32_t>(kMsgSize);
  EXPECT_EQ(MOJO_RESULT_OK, pre_tls_send_handle()->get().WriteData(
                                kMsg, &num_bytes, MOJO_WRITE_DATA_FLAG_NONE));
  EXPECT_EQ(kMsgSize, num_bytes);

  net::HostPortPair host_port_pair("example.org", 443);
  pre_tls_recv_handle()->reset();
  pre_tls_send_handle()->reset();
  net::TestCompletionCallback callback;
  mojo::Remote<mojom::TLSClientSocket> tls_socket;
  UpgradeToTLS(&client_socket, host_port_pair,
               tls_socket.BindNewPipeAndPassReceiver(), callback.callback());
  ASSERT_EQ(net::OK, callback.WaitForResult());
  ResetSocket(&client_socket);

  // Exchange over the post-TLS pipes.
  num_bytes = static_cast<uint32_t>(kSecretMsgSize);
  EXPECT_EQ(MOJO_RESULT_OK,
            post_tls_send_handle()->get().WriteData(kSecretMsg, &num_bytes,
                                                    MOJO_WRITE_DATA_FLAG_NONE));
  EXPECT_EQ(kSecretMsgSize, num_bytes);
  EXPECT_EQ(kSecretMsg, Read(post_tls_recv_handle(), kSecretMsgSize));
  base::RunLoop().RunUntilIdle();
  EXPECT_TRUE(ssl_socket.ConnectDataConsumed());
  EXPECT_TRUE(data_provider.AllReadDataConsumed());
  EXPECT_TRUE(data_provider.AllWriteDataConsumed());
}
572 
573 // Tests that a read error is encountered after UpgradeToTLS completes
574 // successfully.
TEST_P(TLSClientSocketTest, ReadErrorAfterUpgradeToTLS) {
  const net::MockRead kReads[] = {
      net::MockRead(net::ASYNC, kSecretMsg, kSecretMsgSize, 1),
      net::MockRead(net::SYNCHRONOUS, net::ERR_CONNECTION_CLOSED, 2)};
  const net::MockWrite kWrites[] = {
      net::MockWrite(net::SYNCHRONOUS, kSecretMsg, kSecretMsgSize, 0)};
  net::SequencedSocketData data_provider(kReads, kWrites);
  data_provider.set_connect_data(net::MockConnect(net::SYNCHRONOUS, net::OK));
  mock_client_socket_factory()->AddSocketDataProvider(&data_provider);
  net::SSLSocketDataProvider ssl_socket(net::ASYNC, net::OK);
  mock_client_socket_factory()->AddSSLSocketDataProvider(&ssl_socket);

  SocketHandle client_socket;
  net::IPEndPoint server_addr(net::IPAddress::IPv4Localhost(), 1234);
  // Everything below needs a connected socket, so stop the test on failure
  // instead of continuing with unusable handles.
  ASSERT_EQ(net::OK,
            CreateSocketSync(MakeRequest(&client_socket), server_addr));

  net::HostPortPair host_port_pair("example.org", 443);
  pre_tls_recv_handle()->reset();
  pre_tls_send_handle()->reset();
  net::TestCompletionCallback callback;
  mojo::Remote<mojom::TLSClientSocket> tls_socket;
  UpgradeToTLS(&client_socket, host_port_pair,
               tls_socket.BindNewPipeAndPassReceiver(), callback.callback());
  ASSERT_EQ(net::OK, callback.WaitForResult());
  ResetSocket(&client_socket);

  uint32_t num_bytes = static_cast<uint32_t>(kSecretMsgSize);
  EXPECT_EQ(MOJO_RESULT_OK,
            post_tls_send_handle()->get().WriteData(kSecretMsg, &num_bytes,
                                                    MOJO_WRITE_DATA_FLAG_NONE));
  EXPECT_EQ(kSecretMsgSize, num_bytes);
  EXPECT_EQ(kSecretMsg, Read(post_tls_recv_handle(), kSecretMsgSize));
  // The mock read after the message returns ERR_CONNECTION_CLOSED, which the
  // post-TLS observer should report.
  EXPECT_EQ(net::ERR_CONNECTION_CLOSED,
            post_tls_observer()->WaitForReadError());

  base::RunLoop().RunUntilIdle();
  EXPECT_TRUE(ssl_socket.ConnectDataConsumed());
  EXPECT_TRUE(data_provider.AllReadDataConsumed());
  EXPECT_TRUE(data_provider.AllWriteDataConsumed());
}
615 
// Tests that a write error is encountered after UpgradeToTLS completes
// successfully.
TEST_P(TLSClientSocketTest, WriteErrorAfterUpgradeToTLS) {
  const net::MockRead kReads[] = {net::MockRead(net::ASYNC, net::OK, 0)};
  const net::MockWrite kWrites[] = {
      net::MockWrite(net::SYNCHRONOUS, net::ERR_CONNECTION_CLOSED, 1)};
  net::SequencedSocketData data_provider(kReads, kWrites);
  data_provider.set_connect_data(net::MockConnect(net::SYNCHRONOUS, net::OK));
  mock_client_socket_factory()->AddSocketDataProvider(&data_provider);
  net::SSLSocketDataProvider ssl_socket(net::ASYNC, net::OK);
  mock_client_socket_factory()->AddSSLSocketDataProvider(&ssl_socket);

  SocketHandle client_socket;
  net::IPEndPoint server_addr(net::IPAddress::IPv4Localhost(), 1234);
  // Everything below needs a connected socket, so stop the test on failure
  // instead of continuing with unusable handles.
  ASSERT_EQ(net::OK,
            CreateSocketSync(MakeRequest(&client_socket), server_addr));

  net::HostPortPair host_port_pair("example.org", 443);
  pre_tls_recv_handle()->reset();
  pre_tls_send_handle()->reset();
  net::TestCompletionCallback callback;
  mojo::Remote<mojom::TLSClientSocket> tls_socket;
  UpgradeToTLS(&client_socket, host_port_pair,
               tls_socket.BindNewPipeAndPassReceiver(), callback.callback());
  ASSERT_EQ(net::OK, callback.WaitForResult());
  ResetSocket(&client_socket);

  // The data pipe accepts the message; the socket write that follows fails
  // with the mocked ERR_CONNECTION_CLOSED, reported via the observer.
  uint32_t num_bytes = static_cast<uint32_t>(kSecretMsgSize);
  EXPECT_EQ(MOJO_RESULT_OK,
            post_tls_send_handle()->get().WriteData(kSecretMsg, &num_bytes,
                                                    MOJO_WRITE_DATA_FLAG_NONE));
  EXPECT_EQ(kSecretMsgSize, num_bytes);
  EXPECT_EQ(net::ERR_CONNECTION_CLOSED,
            post_tls_observer()->WaitForWriteError());

  base::RunLoop().RunUntilIdle();
  EXPECT_TRUE(ssl_socket.ConnectDataConsumed());
  EXPECT_TRUE(data_provider.AllReadDataConsumed());
  EXPECT_TRUE(data_provider.AllWriteDataConsumed());
}
655 
656 // Tests that reading from the pre-tls data pipe is okay even after UpgradeToTLS
657 // is called.
TEST_P(TLSClientSocketTest,ReadFromPreTlsDataPipeAfterUpgradeToTLS)658 TEST_P(TLSClientSocketTest, ReadFromPreTlsDataPipeAfterUpgradeToTLS) {
659   const net::MockRead kReads[] = {
660       net::MockRead(net::ASYNC, kMsg, kMsgSize, 0),
661       net::MockRead(net::ASYNC, kSecretMsg, kSecretMsgSize, 2),
662       net::MockRead(net::SYNCHRONOUS, net::OK, 3)};
663   const net::MockWrite kWrites[] = {
664       net::MockWrite(net::SYNCHRONOUS, kSecretMsg, kSecretMsgSize, 1)};
665   net::SequencedSocketData data_provider(kReads, kWrites);
666   data_provider.set_connect_data(net::MockConnect(net::SYNCHRONOUS, net::OK));
667   mock_client_socket_factory()->AddSocketDataProvider(&data_provider);
668   net::SSLSocketDataProvider ssl_socket(net::ASYNC, net::OK);
669   mock_client_socket_factory()->AddSSLSocketDataProvider(&ssl_socket);
670 
671   SocketHandle client_socket;
672   net::IPEndPoint server_addr(net::IPAddress::IPv4Localhost(), 1234);
673   EXPECT_EQ(net::OK,
674             CreateSocketSync(MakeRequest(&client_socket), server_addr));
675 
676   net::HostPortPair host_port_pair("example.org", 443);
677   pre_tls_send_handle()->reset();
678   net::TestCompletionCallback callback;
679   mojo::Remote<mojom::TLSClientSocket> tls_socket;
680   UpgradeToTLS(&client_socket, host_port_pair,
681                tls_socket.BindNewPipeAndPassReceiver(), callback.callback());
682   base::RunLoop().RunUntilIdle();
683 
684   EXPECT_EQ(kMsg, Read(pre_tls_recv_handle(), kMsgSize));
685 
686   // Reset pre-tls receive pipe now and UpgradeToTLS should complete.
687   pre_tls_recv_handle()->reset();
688   ASSERT_EQ(net::OK, callback.WaitForResult());
689   ResetSocket(&client_socket);
690 
691   uint32_t num_bytes = strlen(kSecretMsg);
692   EXPECT_EQ(MOJO_RESULT_OK,
693             post_tls_send_handle()->get().WriteData(&kSecretMsg, &num_bytes,
694                                                     MOJO_WRITE_DATA_FLAG_NONE));
695   EXPECT_EQ(kSecretMsg, Read(post_tls_recv_handle(), kSecretMsgSize));
696   base::RunLoop().RunUntilIdle();
697   EXPECT_TRUE(ssl_socket.ConnectDataConsumed());
698   EXPECT_TRUE(data_provider.AllReadDataConsumed());
699   EXPECT_TRUE(data_provider.AllWriteDataConsumed());
700 }
701 
702 // Tests that writing to the pre-tls data pipe is okay even after UpgradeToTLS
703 // is called.
TEST_P(TLSClientSocketTest,WriteToPreTlsDataPipeAfterUpgradeToTLS)704 TEST_P(TLSClientSocketTest, WriteToPreTlsDataPipeAfterUpgradeToTLS) {
705   const net::MockRead kReads[] = {
706       net::MockRead(net::ASYNC, kSecretMsg, kSecretMsgSize, 2),
707       net::MockRead(net::SYNCHRONOUS, net::OK, 3)};
708   const net::MockWrite kWrites[] = {
709       net::MockWrite(net::SYNCHRONOUS, kMsg, kMsgSize, 0),
710       net::MockWrite(net::SYNCHRONOUS, kSecretMsg, kSecretMsgSize, 1)};
711   net::SequencedSocketData data_provider(kReads, kWrites);
712   data_provider.set_connect_data(net::MockConnect(net::SYNCHRONOUS, net::OK));
713   mock_client_socket_factory()->AddSocketDataProvider(&data_provider);
714   net::SSLSocketDataProvider ssl_socket(net::ASYNC, net::OK);
715   mock_client_socket_factory()->AddSSLSocketDataProvider(&ssl_socket);
716 
717   SocketHandle client_socket;
718   net::IPEndPoint server_addr(net::IPAddress::IPv4Localhost(), 1234);
719   EXPECT_EQ(net::OK,
720             CreateSocketSync(MakeRequest(&client_socket), server_addr));
721 
722   net::HostPortPair host_port_pair("example.org", 443);
723   pre_tls_recv_handle()->reset();
724   net::TestCompletionCallback callback;
725   mojo::Remote<mojom::TLSClientSocket> tls_socket;
726   UpgradeToTLS(&client_socket, host_port_pair,
727                tls_socket.BindNewPipeAndPassReceiver(), callback.callback());
728   base::RunLoop().RunUntilIdle();
729 
730   uint32_t num_bytes = strlen(kMsg);
731   EXPECT_EQ(MOJO_RESULT_OK, pre_tls_send_handle()->get().WriteData(
732                                 &kMsg, &num_bytes, MOJO_WRITE_DATA_FLAG_NONE));
733 
734   // Reset pre-tls send pipe now and UpgradeToTLS should complete.
735   pre_tls_send_handle()->reset();
736   ASSERT_EQ(net::OK, callback.WaitForResult());
737   ResetSocket(&client_socket);
738 
739   num_bytes = strlen(kSecretMsg);
740   EXPECT_EQ(MOJO_RESULT_OK,
741             post_tls_send_handle()->get().WriteData(&kSecretMsg, &num_bytes,
742                                                     MOJO_WRITE_DATA_FLAG_NONE));
743   EXPECT_EQ(kSecretMsg, Read(post_tls_recv_handle(), kSecretMsgSize));
744   base::RunLoop().RunUntilIdle();
745   EXPECT_TRUE(ssl_socket.ConnectDataConsumed());
746   EXPECT_TRUE(data_provider.AllReadDataConsumed());
747   EXPECT_TRUE(data_provider.AllWriteDataConsumed());
748 }
749 
750 // Tests that reading from and writing to pre-tls data pipe is okay even after
751 // UpgradeToTLS is called.
TEST_P(TLSClientSocketTest,ReadAndWritePreTlsDataPipeAfterUpgradeToTLS)752 TEST_P(TLSClientSocketTest, ReadAndWritePreTlsDataPipeAfterUpgradeToTLS) {
753   const net::MockRead kReads[] = {
754       net::MockRead(net::ASYNC, kMsg, kMsgSize, 0),
755       net::MockRead(net::ASYNC, kSecretMsg, kSecretMsgSize, 3),
756       net::MockRead(net::SYNCHRONOUS, net::OK, 4)};
757   const net::MockWrite kWrites[] = {
758       net::MockWrite(net::SYNCHRONOUS, kMsg, kMsgSize, 1),
759       net::MockWrite(net::SYNCHRONOUS, kSecretMsg, kSecretMsgSize, 2)};
760   net::SequencedSocketData data_provider(kReads, kWrites);
761   data_provider.set_connect_data(net::MockConnect(net::SYNCHRONOUS, net::OK));
762   mock_client_socket_factory()->AddSocketDataProvider(&data_provider);
763   net::SSLSocketDataProvider ssl_socket(net::ASYNC, net::OK);
764   mock_client_socket_factory()->AddSSLSocketDataProvider(&ssl_socket);
765 
766   SocketHandle client_socket;
767   net::IPEndPoint server_addr(net::IPAddress::IPv4Localhost(), 1234);
768   EXPECT_EQ(net::OK,
769             CreateSocketSync(MakeRequest(&client_socket), server_addr));
770 
771   net::HostPortPair host_port_pair("example.org", 443);
772   base::RunLoop run_loop;
773   net::TestCompletionCallback callback;
774   mojo::Remote<mojom::TLSClientSocket> tls_socket;
775   UpgradeToTLS(&client_socket, host_port_pair,
776                tls_socket.BindNewPipeAndPassReceiver(), callback.callback());
777   EXPECT_EQ(kMsg, Read(pre_tls_recv_handle(), kMsgSize));
778   uint32_t num_bytes = strlen(kMsg);
779   EXPECT_EQ(MOJO_RESULT_OK, pre_tls_send_handle()->get().WriteData(
780                                 &kMsg, &num_bytes, MOJO_WRITE_DATA_FLAG_NONE));
781 
782   // Reset pre-tls pipes now and UpgradeToTLS should complete.
783   pre_tls_recv_handle()->reset();
784   pre_tls_send_handle()->reset();
785   ASSERT_EQ(net::OK, callback.WaitForResult());
786   ResetSocket(&client_socket);
787 
788   num_bytes = strlen(kSecretMsg);
789   EXPECT_EQ(MOJO_RESULT_OK,
790             post_tls_send_handle()->get().WriteData(&kSecretMsg, &num_bytes,
791                                                     MOJO_WRITE_DATA_FLAG_NONE));
792   EXPECT_EQ(kSecretMsg, Read(post_tls_recv_handle(), kSecretMsgSize));
793   base::RunLoop().RunUntilIdle();
794   EXPECT_TRUE(ssl_socket.ConnectDataConsumed());
795   EXPECT_TRUE(data_provider.AllReadDataConsumed());
796   EXPECT_TRUE(data_provider.AllWriteDataConsumed());
797 }
798 
799 // Tests that a read error is encountered before UpgradeToTLS completes.
TEST_P(TLSClientSocketTest,ReadErrorBeforeUpgradeToTLS)800 TEST_P(TLSClientSocketTest, ReadErrorBeforeUpgradeToTLS) {
801   // This requires pre_tls_observer(), which is not provided by proxy resolving
802   // sockets.
803   if (mode() != kDirect)
804     return;
805   const net::MockRead kReads[] = {
806       net::MockRead(net::ASYNC, kMsg, kMsgSize, 0),
807       net::MockRead(net::SYNCHRONOUS, net::ERR_CONNECTION_CLOSED, 1)};
808   net::SequencedSocketData data_provider(kReads, base::span<net::MockWrite>());
809   data_provider.set_connect_data(net::MockConnect(net::SYNCHRONOUS, net::OK));
810   mock_client_socket_factory()->AddSocketDataProvider(&data_provider);
811 
812   SocketHandle client_socket;
813   net::IPEndPoint server_addr(net::IPAddress::IPv4Localhost(), 1234);
814   EXPECT_EQ(net::OK,
815             CreateSocketSync(MakeRequest(&client_socket), server_addr));
816 
817   net::HostPortPair host_port_pair("example.org", 443);
818   pre_tls_send_handle()->reset();
819   net::TestCompletionCallback callback;
820   mojo::Remote<mojom::TLSClientSocket> tls_socket;
821   UpgradeToTLS(&client_socket, host_port_pair,
822                tls_socket.BindNewPipeAndPassReceiver(), callback.callback());
823 
824   EXPECT_EQ(kMsg, Read(pre_tls_recv_handle(), kMsgSize));
825   EXPECT_EQ(net::ERR_CONNECTION_CLOSED, pre_tls_observer()->WaitForReadError());
826 
827   // Reset pre-tls receive pipe now and UpgradeToTLS should complete.
828   pre_tls_recv_handle()->reset();
829   ASSERT_EQ(net::ERR_SOCKET_NOT_CONNECTED, callback.WaitForResult());
830   ResetSocket(&client_socket);
831 
832   base::RunLoop().RunUntilIdle();
833   EXPECT_TRUE(data_provider.AllReadDataConsumed());
834   EXPECT_TRUE(data_provider.AllWriteDataConsumed());
835 }
836 
837 // Tests that a write error is encountered before UpgradeToTLS completes.
TEST_P(TLSClientSocketTest,WriteErrorBeforeUpgradeToTLS)838 TEST_P(TLSClientSocketTest, WriteErrorBeforeUpgradeToTLS) {
839   // This requires pre_tls_observer(), which is not provided by proxy resolving
840   // sockets.
841   if (mode() != kDirect)
842     return;
843 
844   const net::MockRead kReads[] = {net::MockRead(net::ASYNC, net::OK, 1)};
845   const net::MockWrite kWrites[] = {
846       net::MockWrite(net::SYNCHRONOUS, net::ERR_CONNECTION_CLOSED, 0)};
847   net::SequencedSocketData data_provider(kReads, kWrites);
848   data_provider.set_connect_data(net::MockConnect(net::SYNCHRONOUS, net::OK));
849   mock_client_socket_factory()->AddSocketDataProvider(&data_provider);
850 
851   SocketHandle client_socket;
852   net::IPEndPoint server_addr(net::IPAddress::IPv4Localhost(), 1234);
853   EXPECT_EQ(net::OK,
854             CreateSocketSync(MakeRequest(&client_socket), server_addr));
855 
856   net::HostPortPair host_port_pair("example.org", 443);
857   pre_tls_recv_handle()->reset();
858   net::TestCompletionCallback callback;
859   mojo::Remote<mojom::TLSClientSocket> tls_socket;
860   UpgradeToTLS(&client_socket, host_port_pair,
861                tls_socket.BindNewPipeAndPassReceiver(), callback.callback());
862   uint32_t num_bytes = strlen(kMsg);
863   EXPECT_EQ(MOJO_RESULT_OK, pre_tls_send_handle()->get().WriteData(
864                                 &kMsg, &num_bytes, MOJO_WRITE_DATA_FLAG_NONE));
865 
866   EXPECT_EQ(net::ERR_CONNECTION_CLOSED,
867             pre_tls_observer()->WaitForWriteError());
868   // Reset pre-tls send pipe now and UpgradeToTLS should complete.
869   pre_tls_send_handle()->reset();
870   ASSERT_EQ(net::ERR_SOCKET_NOT_CONNECTED, callback.WaitForResult());
871   ResetSocket(&client_socket);
872 
873   base::RunLoop().RunUntilIdle();
874   // Write failed before the mock read can be consumed.
875   EXPECT_FALSE(data_provider.AllReadDataConsumed());
876   EXPECT_TRUE(data_provider.AllWriteDataConsumed());
877 }
878 
// Runs every TLSClientSocketTest once in kDirect mode and once in
// kProxyResolving mode.
INSTANTIATE_TEST_SUITE_P(
    All,
    TLSClientSocketTest,
    ::testing::Values(TLSClientSocketTestBase::kDirect,
                      TLSClientSocketTestBase::kProxyResolving));
884 
885 // Tests with proxy resolving socket and a proxy actually configured.
886 class TLSCLientSocketProxyTest : public ::testing::Test,
887                                  public TLSClientSocketTestBase {
888  public:
TLSCLientSocketProxyTest()889   TLSCLientSocketProxyTest()
890       : TLSClientSocketTestBase(TLSClientSocketTestBase::kProxyResolving) {
891     Init(true /* use_mock_sockets*/, true /* configure_proxy */);
892   }
893 
~TLSCLientSocketProxyTest()894   ~TLSCLientSocketProxyTest() override {}
895 
896  private:
897   DISALLOW_COPY_AND_ASSIGN(TLSCLientSocketProxyTest);
898 };
899 
// Connects through the configured proxy (an HTTP CONNECT exchange is expected
// on the mock socket), upgrades to TLS, then round-trips a message over the
// post-TLS pipes.
TEST_F(TLSCLientSocketProxyTest, UpgradeToTLS) {
  const char kConnectRequest[] =
      "CONNECT 192.168.1.1:1234 HTTP/1.1\r\n"
      "Host: 192.168.1.1:1234\r\n"
      "Proxy-Connection: keep-alive\r\n\r\n";
  const char kConnectResponse[] = "HTTP/1.1 200 OK\r\n\r\n";

  // Sequence: CONNECT request (0), CONNECT response (1), post-upgrade write
  // (2), post-upgrade read (3), end of file (4).
  const net::MockRead kReads[] = {
      net::MockRead(net::ASYNC, kConnectResponse, strlen(kConnectResponse), 1),
      net::MockRead(net::ASYNC, kMsg, kMsgSize, 3),
      net::MockRead(net::SYNCHRONOUS, net::OK, 4)};
  const net::MockWrite kWrites[] = {
      net::MockWrite(net::ASYNC, kConnectRequest, strlen(kConnectRequest), 0),
      net::MockWrite(net::SYNCHRONOUS, kMsg, kMsgSize, 2)};
  net::SequencedSocketData socket_data(kReads, kWrites);
  socket_data.set_connect_data(net::MockConnect(net::SYNCHRONOUS, net::OK));
  mock_client_socket_factory()->AddSocketDataProvider(&socket_data);
  net::SSLSocketDataProvider ssl_data(net::ASYNC, net::OK);
  mock_client_socket_factory()->AddSSLSocketDataProvider(&ssl_data);

  SocketHandle tcp_socket;
  net::IPEndPoint target_addr(net::IPAddress(192, 168, 1, 1), 1234);
  EXPECT_EQ(net::OK, CreateSocketSync(MakeRequest(&tcp_socket), target_addr));

  net::HostPortPair destination("example.org", 443);
  // Neither pre-TLS pipe is needed, so close both before upgrading.
  pre_tls_recv_handle()->reset();
  pre_tls_send_handle()->reset();
  net::TestCompletionCallback upgrade_done;
  mojo::Remote<mojom::TLSClientSocket> tls_socket;
  UpgradeToTLS(&tcp_socket, destination,
               tls_socket.BindNewPipeAndPassReceiver(), upgrade_done.callback());
  ASSERT_EQ(net::OK, upgrade_done.WaitForResult());
  ResetSocket(&tcp_socket);

  uint32_t msg_size = static_cast<uint32_t>(kMsgSize);
  EXPECT_EQ(MOJO_RESULT_OK,
            post_tls_send_handle()->get().WriteData(kMsg, &msg_size,
                                                    MOJO_WRITE_DATA_FLAG_NONE));
  EXPECT_EQ(kMsg, Read(post_tls_recv_handle(), kMsgSize));
  base::RunLoop().RunUntilIdle();
  EXPECT_TRUE(ssl_data.ConnectDataConsumed());
  EXPECT_TRUE(socket_data.AllReadDataConsumed());
  EXPECT_TRUE(socket_data.AllWriteDataConsumed());
}
944 
945 class TLSClientSocketIoModeTest : public TLSClientSocketTestBase,
946                                   public testing::TestWithParam<net::IoMode> {
947  public:
TLSClientSocketIoModeTest()948   TLSClientSocketIoModeTest()
949       : TLSClientSocketTestBase(TLSClientSocketTestBase::kDirect) {
950     Init(true /* use_mock_sockets*/, false /* configure_proxy */);
951   }
952 
~TLSClientSocketIoModeTest()953   ~TLSClientSocketIoModeTest() override {}
954 
955  private:
956   DISALLOW_COPY_AND_ASSIGN(TLSClientSocketIoModeTest);
957 };
958 
959 INSTANTIATE_TEST_SUITE_P(All,
960                          TLSClientSocketIoModeTest,
961                          testing::Values(net::SYNCHRONOUS, net::ASYNC));
962 
TEST_P(TLSClientSocketIoModeTest,MultipleWriteToTLSSocket)963 TEST_P(TLSClientSocketIoModeTest, MultipleWriteToTLSSocket) {
964   const int kNumIterations = 3;
965   std::vector<net::MockRead> reads;
966   std::vector<net::MockWrite> writes;
967   int sequence_number = 0;
968   net::IoMode mode = GetParam();
969   for (int j = 0; j < kNumIterations; ++j) {
970     for (size_t i = 0; i < kSecretMsgSize; ++i) {
971       writes.push_back(
972           net::MockWrite(mode, &kSecretMsg[i], 1, sequence_number++));
973     }
974     for (size_t i = 0; i < kSecretMsgSize; ++i) {
975       reads.push_back(
976           net::MockRead(net::ASYNC, &kSecretMsg[i], 1, sequence_number++));
977     }
978     if (j == kNumIterations - 1) {
979       reads.push_back(net::MockRead(mode, net::OK, sequence_number++));
980     }
981   }
982   net::SequencedSocketData data_provider(reads, writes);
983   data_provider.set_connect_data(net::MockConnect(GetParam(), net::OK));
984   mock_client_socket_factory()->AddSocketDataProvider(&data_provider);
985   net::SSLSocketDataProvider ssl_socket(net::ASYNC, net::OK);
986   mock_client_socket_factory()->AddSSLSocketDataProvider(&ssl_socket);
987 
988   mojo::Remote<mojom::TCPConnectedSocket> client_socket;
989   net::IPEndPoint server_addr(net::IPAddress::IPv4Localhost(), 1234);
990   EXPECT_EQ(net::OK,
991             CreateTCPConnectedSocketSync(
992                 client_socket.BindNewPipeAndPassReceiver(), server_addr));
993 
994   net::HostPortPair host_port_pair("example.org", 443);
995   pre_tls_recv_handle()->reset();
996   pre_tls_send_handle()->reset();
997   net::TestCompletionCallback callback;
998   mojo::Remote<mojom::TLSClientSocket> tls_socket;
999   UpgradeTCPConnectedSocketToTLS(
1000       client_socket.get(), host_port_pair, nullptr /* options */,
1001       tls_socket.BindNewPipeAndPassReceiver(), callback.callback());
1002   ASSERT_EQ(net::OK, callback.WaitForResult());
1003   client_socket.reset();
1004   EXPECT_FALSE(ssl_info());
1005 
1006   // Loop kNumIterations times to test that writes can follow reads, and reads
1007   // can follow writes.
1008   for (int j = 0; j < kNumIterations; ++j) {
1009     // Write multiple times.
1010     for (size_t i = 0; i < kSecretMsgSize; ++i) {
1011       uint32_t num_bytes = 1;
1012       EXPECT_EQ(MOJO_RESULT_OK,
1013                 post_tls_send_handle()->get().WriteData(
1014                     &kSecretMsg[i], &num_bytes, MOJO_WRITE_DATA_FLAG_NONE));
1015       // Flush the 1 byte write.
1016       base::RunLoop().RunUntilIdle();
1017     }
1018     // Reading kSecretMsgSize should coalesce the 1-byte mock reads.
1019     EXPECT_EQ(kSecretMsg, Read(post_tls_recv_handle(), kSecretMsgSize));
1020   }
1021   EXPECT_TRUE(ssl_socket.ConnectDataConsumed());
1022   EXPECT_TRUE(data_provider.AllReadDataConsumed());
1023   EXPECT_TRUE(data_provider.AllWriteDataConsumed());
1024 }
1025 
1026 // Check SSLInfo is provided in both sync and async cases.
TEST_P(TLSClientSocketIoModeTest,SSLInfo)1027 TEST_P(TLSClientSocketIoModeTest, SSLInfo) {
1028   // End of file. Reads don't matter, only the handshake does.
1029   std::vector<net::MockRead> reads = {net::MockRead(net::SYNCHRONOUS, net::OK)};
1030   std::vector<net::MockWrite> writes;
1031   net::SequencedSocketData data_provider(reads, writes);
1032   data_provider.set_connect_data(net::MockConnect(net::SYNCHRONOUS, net::OK));
1033   mock_client_socket_factory()->AddSocketDataProvider(&data_provider);
1034   net::SSLSocketDataProvider ssl_socket(GetParam(), net::OK);
1035   // Set a value on SSLInfo to make sure it's correctly received.
1036   ssl_socket.ssl_info.is_issued_by_known_root = true;
1037   mock_client_socket_factory()->AddSSLSocketDataProvider(&ssl_socket);
1038 
1039   mojo::Remote<mojom::TCPConnectedSocket> client_socket;
1040   net::IPEndPoint server_addr(net::IPAddress::IPv4Localhost(), 1234);
1041   EXPECT_EQ(net::OK,
1042             CreateTCPConnectedSocketSync(
1043                 client_socket.BindNewPipeAndPassReceiver(), server_addr));
1044 
1045   net::HostPortPair host_port_pair("example.org", 443);
1046   pre_tls_recv_handle()->reset();
1047   pre_tls_send_handle()->reset();
1048   net::TestCompletionCallback callback;
1049   mojo::Remote<mojom::TLSClientSocket> tls_socket;
1050   mojom::TLSClientSocketOptionsPtr options =
1051       mojom::TLSClientSocketOptions::New();
1052   options->send_ssl_info = true;
1053   UpgradeTCPConnectedSocketToTLS(
1054       client_socket.get(), host_port_pair, std::move(options),
1055       tls_socket.BindNewPipeAndPassReceiver(), callback.callback());
1056   ASSERT_EQ(net::OK, callback.WaitForResult());
1057   ASSERT_TRUE(ssl_info());
1058   EXPECT_TRUE(ssl_socket.ssl_info.is_issued_by_known_root);
1059   EXPECT_FALSE(ssl_socket.ssl_info.is_fatal_cert_error);
1060 }
1061 
1062 class TLSClientSocketTestWithEmbeddedTestServerBase
1063     : public TLSClientSocketTestBase {
1064  public:
TLSClientSocketTestWithEmbeddedTestServerBase(Mode mode)1065   explicit TLSClientSocketTestWithEmbeddedTestServerBase(Mode mode)
1066       : TLSClientSocketTestBase(mode),
1067         server_(net::EmbeddedTestServer::TYPE_HTTPS) {
1068     Init(false /* use_mock_sockets */, false /* configure_proxy */);
1069   }
1070 
~TLSClientSocketTestWithEmbeddedTestServerBase()1071   ~TLSClientSocketTestWithEmbeddedTestServerBase() override {}
1072 
1073   // Starts the test server using the specified certificate.
StartTestServer(net::EmbeddedTestServer::ServerCertificate certificate)1074   bool StartTestServer(net::EmbeddedTestServer::ServerCertificate certificate)
1075       WARN_UNUSED_RESULT {
1076     server_.RegisterRequestHandler(
1077         base::BindRepeating([](const net::test_server::HttpRequest& request) {
1078           if (base::StartsWith(request.relative_url, "/secret",
1079                                base::CompareCase::INSENSITIVE_ASCII)) {
1080             return std::unique_ptr<net::test_server::HttpResponse>(
1081                 new net::test_server::RawHttpResponse("HTTP/1.1 200 OK",
1082                                                       "Hello There!"));
1083           }
1084           return std::unique_ptr<net::test_server::HttpResponse>();
1085         }));
1086     server_.SetSSLConfig(certificate);
1087     return server_.Start();
1088   }
1089 
1090   // Attempts to eastablish a TLS connection to the test server by first
1091   // establishing a TCP connection, and then upgrading it.  Returns the
1092   // resulting network error code.
CreateTLSSocket()1093   int CreateTLSSocket() WARN_UNUSED_RESULT {
1094     SocketHandle client_socket;
1095     net::IPEndPoint server_addr(net::IPAddress::IPv4Localhost(),
1096                                 server_.port());
1097     EXPECT_EQ(net::OK,
1098               CreateSocketSync(MakeRequest(&client_socket), server_addr));
1099 
1100     pre_tls_recv_handle()->reset();
1101     pre_tls_send_handle()->reset();
1102     net::TestCompletionCallback callback;
1103     UpgradeToTLS(&client_socket, server_.host_port_pair(),
1104                  tls_socket_.BindNewPipeAndPassReceiver(), callback.callback());
1105     int result = callback.WaitForResult();
1106     ResetSocket(&client_socket);
1107     return result;
1108   }
1109 
CreateTLSSocketWithOptions(mojom::TLSClientSocketOptionsPtr options)1110   int CreateTLSSocketWithOptions(mojom::TLSClientSocketOptionsPtr options)
1111       WARN_UNUSED_RESULT {
1112     // Proxy connections don't support TLSClientSocketOptions.
1113     DCHECK_EQ(kDirect, mode());
1114 
1115     mojo::Remote<mojom::TCPConnectedSocket> tcp_socket;
1116     net::IPEndPoint server_addr(net::IPAddress::IPv4Localhost(),
1117                                 server_.port());
1118     EXPECT_EQ(net::OK,
1119               CreateTCPConnectedSocketSync(
1120                   tcp_socket.BindNewPipeAndPassReceiver(), server_addr));
1121 
1122     pre_tls_recv_handle()->reset();
1123     pre_tls_send_handle()->reset();
1124     net::TestCompletionCallback callback;
1125     UpgradeTCPConnectedSocketToTLS(
1126         tcp_socket.get(), server_.host_port_pair(), std::move(options),
1127         tls_socket_.BindNewPipeAndPassReceiver(), callback.callback());
1128     int result = callback.WaitForResult();
1129     tcp_socket.reset();
1130     return result;
1131   }
1132 
TestTlsSocket()1133   void TestTlsSocket() {
1134     ASSERT_TRUE(tls_socket_.is_bound());
1135     const char kTestMsg[] = "GET /secret HTTP/1.1\r\n\r\n";
1136     uint32_t num_bytes = strlen(kTestMsg);
1137     const char kResponse[] = "HTTP/1.1 200 OK\n\n";
1138     EXPECT_EQ(MOJO_RESULT_OK,
1139               post_tls_send_handle()->get().WriteData(
1140                   &kTestMsg, &num_bytes, MOJO_WRITE_DATA_FLAG_NONE));
1141     EXPECT_EQ(kResponse, Read(post_tls_recv_handle(), strlen(kResponse)));
1142   }
1143 
server()1144   net::EmbeddedTestServer* server() { return &server_; }
1145 
1146  private:
1147   net::EmbeddedTestServer server_;
1148 
1149   mojo::Remote<mojom::TLSClientSocket> tls_socket_;
1150 
1151   DISALLOW_COPY_AND_ASSIGN(TLSClientSocketTestWithEmbeddedTestServerBase);
1152 };
1153 
1154 class TLSClientSocketTestWithEmbeddedTestServer
1155     : public TLSClientSocketTestWithEmbeddedTestServerBase,
1156       public ::testing::TestWithParam<TLSClientSocketTestBase::Mode> {
1157  public:
TLSClientSocketTestWithEmbeddedTestServer()1158   TLSClientSocketTestWithEmbeddedTestServer()
1159       : TLSClientSocketTestWithEmbeddedTestServerBase(GetParam()) {}
~TLSClientSocketTestWithEmbeddedTestServer()1160   ~TLSClientSocketTestWithEmbeddedTestServer() override {}
1161 
1162  private:
1163   DISALLOW_COPY_AND_ASSIGN(TLSClientSocketTestWithEmbeddedTestServer);
1164 };
1165 
// Upgrades a connection to a server with a valid certificate and exchanges a
// request/response over the post-TLS pipes.
TEST_P(TLSClientSocketTestWithEmbeddedTestServer, Basic) {
  ASSERT_TRUE(StartTestServer(net::EmbeddedTestServer::CERT_OK));
  ASSERT_EQ(net::OK, CreateTLSSocket());
  // No SSLInfo should be received by default. SSLInfo is only supported in the
  // kDirect case, but it doesn't hurt to check that it's null in the
  // kProxyResolving case.
  EXPECT_FALSE(ssl_info());
  TestTlsSocket();
}
1175 
// A server certificate whose name doesn't match must fail the upgrade and
// must leave neither SSLInfo nor usable post-TLS pipes behind.
TEST_P(TLSClientSocketTestWithEmbeddedTestServer, ServerCertError) {
  ASSERT_TRUE(StartTestServer(net::EmbeddedTestServer::CERT_MISMATCHED_NAME));
  ASSERT_EQ(net::ERR_CERT_COMMON_NAME_INVALID, CreateTLSSocket());

  // SSLInfo was not requested (and is only supported for kDirect anyway), so
  // it has to be null in either mode.
  EXPECT_FALSE(ssl_info());

  // Neither post-TLS data pipe may have been handed out.
  EXPECT_FALSE(post_tls_recv_handle()->is_valid());
  EXPECT_FALSE(post_tls_send_handle()->is_valid());
}
1188 
// Runs every TLSClientSocketTestWithEmbeddedTestServer test once in kDirect
// mode and once in kProxyResolving mode.
INSTANTIATE_TEST_SUITE_P(
    All,
    TLSClientSocketTestWithEmbeddedTestServer,
    ::testing::Values(TLSClientSocketTestBase::kDirect,
                      TLSClientSocketTestBase::kProxyResolving));
1194 
1195 class TLSClientSocketDirectTestWithEmbeddedTestServer
1196     : public TLSClientSocketTestWithEmbeddedTestServerBase,
1197       public testing::Test {
1198  public:
TLSClientSocketDirectTestWithEmbeddedTestServer()1199   TLSClientSocketDirectTestWithEmbeddedTestServer()
1200       : TLSClientSocketTestWithEmbeddedTestServerBase(kDirect) {}
~TLSClientSocketDirectTestWithEmbeddedTestServer()1201   ~TLSClientSocketDirectTestWithEmbeddedTestServer() override {}
1202 
1203  private:
1204   DISALLOW_COPY_AND_ASSIGN(TLSClientSocketDirectTestWithEmbeddedTestServer);
1205 };
1206 
// Requesting SSLInfo from a server with a valid certificate yields a valid,
// non-fatal SSLInfo and a working connection.
TEST_F(TLSClientSocketDirectTestWithEmbeddedTestServer, SSLInfo) {
  ASSERT_TRUE(StartTestServer(net::EmbeddedTestServer::CERT_OK));

  auto options = mojom::TLSClientSocketOptions::New();
  options->send_ssl_info = true;
  ASSERT_EQ(net::OK, CreateTLSSocketWithOptions(std::move(options)));

  ASSERT_TRUE(ssl_info());
  EXPECT_TRUE(ssl_info()->is_valid());
  EXPECT_FALSE(ssl_info()->is_fatal_cert_error);

  TestTlsSocket();
}
1220 
// Asking for SSLInfo must not bypass cert verification: a mismatched-name
// certificate still fails the upgrade, and neither SSLInfo nor post-TLS pipes
// are provided.
TEST_F(TLSClientSocketDirectTestWithEmbeddedTestServer,
       SSLInfoServerCertError) {
  ASSERT_TRUE(StartTestServer(net::EmbeddedTestServer::CERT_MISMATCHED_NAME));

  auto options = mojom::TLSClientSocketOptions::New();
  options->send_ssl_info = true;
  ASSERT_EQ(net::ERR_CERT_COMMON_NAME_INVALID,
            CreateTLSSocketWithOptions(std::move(options)));

  // On failure, no SSLInfo is delivered...
  EXPECT_FALSE(ssl_info());

  // ...and no post-TLS data pipes are handed out.
  EXPECT_FALSE(post_tls_recv_handle()->is_valid());
  EXPECT_FALSE(post_tls_send_handle()->is_valid());
}
1238 
1239 // Check skipping cert verification always received SSLInfo, even with valid
1240 // certs.
TEST_F(TLSClientSocketDirectTestWithEmbeddedTestServer,UnsafelySkipCertVerification)1241 TEST_F(TLSClientSocketDirectTestWithEmbeddedTestServer,
1242        UnsafelySkipCertVerification) {
1243   ASSERT_TRUE(StartTestServer(net::EmbeddedTestServer::CERT_OK));
1244   mojom::TLSClientSocketOptionsPtr options =
1245       mojom::TLSClientSocketOptions::New();
1246   options->unsafely_skip_cert_verification = true;
1247   ASSERT_EQ(net::OK, CreateTLSSocketWithOptions(std::move(options)));
1248 
1249   ASSERT_TRUE(ssl_info());
1250   EXPECT_TRUE(ssl_info()->is_valid());
1251   EXPECT_FALSE(ssl_info()->is_fatal_cert_error);
1252 
1253   TestTlsSocket();
1254 }
1255 
// With cert verification skipped, a mismatched-name certificate no longer
// fails the upgrade: SSLInfo is still delivered, reports no fatal cert error,
// and the connection is usable.
TEST_F(TLSClientSocketDirectTestWithEmbeddedTestServer,
       UnsafelySkipCertVerificationServerCertError) {
  ASSERT_TRUE(StartTestServer(net::EmbeddedTestServer::CERT_MISMATCHED_NAME));

  auto options = mojom::TLSClientSocketOptions::New();
  options->unsafely_skip_cert_verification = true;
  ASSERT_EQ(net::OK, CreateTLSSocketWithOptions(std::move(options)));

  ASSERT_TRUE(ssl_info());
  EXPECT_TRUE(ssl_info()->is_valid());
  EXPECT_FALSE(ssl_info()->is_fatal_cert_error);

  TestTlsSocket();
}
1270 
1271 }  // namespace
1272 
1273 }  // namespace network
1274