1 // Copyright 2018 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
#include <stdint.h>
#include <string.h>

#include <memory>
#include <string>
#include <utility>
#include <vector>
10 
11 #include "base/bind.h"
12 #include "base/check_op.h"
13 #include "base/macros.h"
14 #include "base/run_loop.h"
15 #include "base/test/bind.h"
16 #include "base/test/task_environment.h"
17 #include "base/threading/thread.h"
18 #include "mojo/public/cpp/bindings/pending_receiver.h"
19 #include "mojo/public/cpp/bindings/remote.h"
20 #include "net/base/completion_once_callback.h"
21 #include "net/base/net_errors.h"
22 #include "net/base/network_isolation_key.h"
23 #include "net/base/test_completion_callback.h"
24 #include "net/proxy_resolution/configured_proxy_resolution_service.h"
25 #include "net/socket/server_socket.h"
26 #include "net/socket/socket_test_util.h"
27 #include "net/test/embedded_test_server/embedded_test_server.h"
28 #include "net/test/embedded_test_server/http_request.h"
29 #include "net/test/embedded_test_server/http_response.h"
30 #include "net/traffic_annotation/network_traffic_annotation_test_helper.h"
31 #include "net/url_request/url_request_test_util.h"
32 #include "services/network/mojo_socket_test_util.h"
33 #include "services/network/proxy_resolving_socket_factory_mojo.h"
34 #include "services/network/public/mojom/network_service.mojom.h"
35 #include "services/network/socket_factory.h"
36 #include "testing/gtest/include/gtest/gtest.h"
37 
38 namespace network {
39 
40 namespace {
41 
42 // Message sent over the tcp connection.
43 const char kMsg[] = "please start tls!";
44 const size_t kMsgSize = strlen(kMsg);
45 
46 // Message sent over the tls connection.
47 const char kSecretMsg[] = "here is secret.";
48 const size_t kSecretMsgSize = strlen(kSecretMsg);
49 
50 class TLSClientSocketTestBase {
51  public:
52   enum Mode { kDirect, kProxyResolving };
53 
TLSClientSocketTestBase(Mode mode)54   explicit TLSClientSocketTestBase(Mode mode)
55       : mode_(mode),
56         task_environment_(base::test::TaskEnvironment::MainThreadType::IO),
57         url_request_context_(true) {}
~TLSClientSocketTestBase()58   virtual ~TLSClientSocketTestBase() {}
59 
mode()60   Mode mode() { return mode_; }
61 
62  protected:
63   // One of the two fields will be set, depending on the mode.
64   struct SocketHandle {
65     mojo::Remote<mojom::TCPConnectedSocket> tcp_socket;
66     mojo::Remote<mojom::ProxyResolvingSocket> proxy_socket;
67   };
68 
69   struct SocketRequest {
70     mojo::PendingReceiver<mojom::TCPConnectedSocket> tcp_socket_receiver;
71     mojo::PendingReceiver<mojom::ProxyResolvingSocket> proxy_socket_receiver;
72   };
73 
74   // Initializes the test fixture. If |use_mock_sockets|, mock client socket
75   // factory will be used.
Init(bool use_mock_sockets,bool configure_proxy)76   void Init(bool use_mock_sockets, bool configure_proxy) {
77     if (use_mock_sockets) {
78       mock_client_socket_factory_.set_enable_read_if_ready(true);
79       url_request_context_.set_client_socket_factory(
80           &mock_client_socket_factory_);
81     }
82     if (configure_proxy) {
83       proxy_resolution_service_ =
84           net::ConfiguredProxyResolutionService::CreateFixed(
85               "http://proxy:8080", TRAFFIC_ANNOTATION_FOR_TESTS);
86       url_request_context_.set_proxy_resolution_service(
87           proxy_resolution_service_.get());
88     }
89     url_request_context_.Init();
90     factory_ = std::make_unique<SocketFactory>(nullptr /*net_log*/,
91                                                &url_request_context_);
92     proxy_resolving_factory_ =
93         std::make_unique<ProxyResolvingSocketFactoryMojo>(
94             &url_request_context_);
95   }
96 
97   // Reads |num_bytes| from |handle| or reads until an error occurs. Returns the
98   // bytes read as a string.
Read(mojo::ScopedDataPipeConsumerHandle * handle,size_t num_bytes)99   std::string Read(mojo::ScopedDataPipeConsumerHandle* handle,
100                    size_t num_bytes) {
101     std::string received_contents;
102     while (received_contents.size() < num_bytes) {
103       base::RunLoop().RunUntilIdle();
104       std::vector<char> buffer(num_bytes);
105       uint32_t read_size = static_cast<uint32_t>(num_bytes);
106       MojoResult result = handle->get().ReadData(buffer.data(), &read_size,
107                                                  MOJO_READ_DATA_FLAG_NONE);
108       if (result == MOJO_RESULT_SHOULD_WAIT)
109         continue;
110       if (result != MOJO_RESULT_OK)
111         return received_contents;
112       received_contents.append(buffer.data(), read_size);
113     }
114     return received_contents;
115   }
116 
MakeRequest(SocketHandle * handle)117   SocketRequest MakeRequest(SocketHandle* handle) {
118     SocketRequest result;
119     if (mode_ == kDirect)
120       result.tcp_socket_receiver =
121           handle->tcp_socket.BindNewPipeAndPassReceiver();
122     else
123       result.proxy_socket_receiver =
124           handle->proxy_socket.BindNewPipeAndPassReceiver();
125     return result;
126   }
127 
ResetSocket(SocketHandle * handle)128   void ResetSocket(SocketHandle* handle) {
129     if (mode_ == kDirect)
130       handle->tcp_socket.reset();
131     else
132       handle->proxy_socket.reset();
133   }
134 
CreateSocketSync(SocketRequest request,const net::IPEndPoint & remote_addr)135   int CreateSocketSync(SocketRequest request,
136                        const net::IPEndPoint& remote_addr) {
137     if (mode_ == kDirect) {
138       return CreateTCPConnectedSocketSync(
139           std::move(request.tcp_socket_receiver), remote_addr);
140     } else {
141       return CreateProxyResolvingSocketSync(
142           std::move(request.proxy_socket_receiver), remote_addr);
143     }
144   }
145 
CreateTCPConnectedSocketSync(mojo::PendingReceiver<mojom::TCPConnectedSocket> receiver,const net::IPEndPoint & remote_addr)146   int CreateTCPConnectedSocketSync(
147       mojo::PendingReceiver<mojom::TCPConnectedSocket> receiver,
148       const net::IPEndPoint& remote_addr) {
149     net::AddressList remote_addr_list(remote_addr);
150     base::RunLoop run_loop;
151     int net_error = net::ERR_FAILED;
152     factory_->CreateTCPConnectedSocket(
153         base::nullopt /* local_addr */, remote_addr_list,
154         nullptr /* tcp_connected_socket_options */,
155         TRAFFIC_ANNOTATION_FOR_TESTS, std::move(receiver),
156         pre_tls_observer()->GetObserverRemote(),
157         base::BindLambdaForTesting(
158             [&](int result,
159                 const base::Optional<net::IPEndPoint>& actual_local_addr,
160                 const base::Optional<net::IPEndPoint>& peer_addr,
161                 mojo::ScopedDataPipeConsumerHandle receive_pipe_handle,
162                 mojo::ScopedDataPipeProducerHandle send_pipe_handle) {
163               net_error = result;
164               pre_tls_recv_handle_ = std::move(receive_pipe_handle);
165               pre_tls_send_handle_ = std::move(send_pipe_handle);
166               run_loop.Quit();
167             }));
168     run_loop.Run();
169     return net_error;
170   }
171 
CreateProxyResolvingSocketSync(mojo::PendingReceiver<mojom::ProxyResolvingSocket> receiver,const net::IPEndPoint & remote_addr)172   int CreateProxyResolvingSocketSync(
173       mojo::PendingReceiver<mojom::ProxyResolvingSocket> receiver,
174       const net::IPEndPoint& remote_addr) {
175     GURL url("https://" + remote_addr.ToString());
176     base::RunLoop run_loop;
177     int net_error = net::ERR_FAILED;
178     proxy_resolving_factory_->CreateProxyResolvingSocket(
179         url, net::NetworkIsolationKey(), nullptr /* options */,
180         net::MutableNetworkTrafficAnnotationTag(TRAFFIC_ANNOTATION_FOR_TESTS),
181         std::move(receiver), mojo::NullRemote() /* observer */,
182         base::BindLambdaForTesting(
183             [&](int result,
184                 const base::Optional<net::IPEndPoint>& actual_local_addr,
185                 const base::Optional<net::IPEndPoint>& peer_addr,
186                 mojo::ScopedDataPipeConsumerHandle receive_pipe_handle,
187                 mojo::ScopedDataPipeProducerHandle send_pipe_handle) {
188               net_error = result;
189               pre_tls_recv_handle_ = std::move(receive_pipe_handle);
190               pre_tls_send_handle_ = std::move(send_pipe_handle);
191               run_loop.Quit();
192             }));
193     run_loop.Run();
194     return net_error;
195   }
196 
UpgradeToTLS(SocketHandle * handle,const net::HostPortPair & host_port_pair,mojo::PendingReceiver<mojom::TLSClientSocket> receiver,net::CompletionOnceCallback callback)197   void UpgradeToTLS(SocketHandle* handle,
198                     const net::HostPortPair& host_port_pair,
199                     mojo::PendingReceiver<mojom::TLSClientSocket> receiver,
200                     net::CompletionOnceCallback callback) {
201     if (mode_ == kDirect) {
202       UpgradeTCPConnectedSocketToTLS(handle->tcp_socket.get(), host_port_pair,
203                                      nullptr /* options */, std::move(receiver),
204                                      std::move(callback));
205     } else {
206       UpgradeProxyResolvingSocketToTLS(handle->proxy_socket.get(),
207                                        host_port_pair, std::move(receiver),
208                                        std::move(callback));
209     }
210   }
211 
UpgradeTCPConnectedSocketToTLS(mojom::TCPConnectedSocket * client_socket,const net::HostPortPair & host_port_pair,mojom::TLSClientSocketOptionsPtr options,mojo::PendingReceiver<mojom::TLSClientSocket> receiver,net::CompletionOnceCallback callback)212   void UpgradeTCPConnectedSocketToTLS(
213       mojom::TCPConnectedSocket* client_socket,
214       const net::HostPortPair& host_port_pair,
215       mojom::TLSClientSocketOptionsPtr options,
216       mojo::PendingReceiver<mojom::TLSClientSocket> receiver,
217       net::CompletionOnceCallback callback) {
218     client_socket->UpgradeToTLS(
219         host_port_pair, std::move(options),
220         net::MutableNetworkTrafficAnnotationTag(TRAFFIC_ANNOTATION_FOR_TESTS),
221         std::move(receiver), post_tls_observer()->GetObserverRemote(),
222         base::BindOnce(
223             [](net::CompletionOnceCallback cb,
224                mojo::ScopedDataPipeConsumerHandle* consumer_handle_out,
225                mojo::ScopedDataPipeProducerHandle* producer_handle_out,
226                base::Optional<net::SSLInfo>* ssl_info_out, int result,
227                mojo::ScopedDataPipeConsumerHandle receive_pipe_handle,
228                mojo::ScopedDataPipeProducerHandle send_pipe_handle,
229                const base::Optional<net::SSLInfo>& ssl_info) {
230               *consumer_handle_out = std::move(receive_pipe_handle);
231               *producer_handle_out = std::move(send_pipe_handle);
232               *ssl_info_out = ssl_info;
233               std::move(cb).Run(result);
234             },
235             std::move(callback), &post_tls_recv_handle_, &post_tls_send_handle_,
236             &ssl_info_));
237   }
238 
UpgradeProxyResolvingSocketToTLS(mojom::ProxyResolvingSocket * client_socket,const net::HostPortPair & host_port_pair,mojo::PendingReceiver<mojom::TLSClientSocket> receiver,net::CompletionOnceCallback callback)239   void UpgradeProxyResolvingSocketToTLS(
240       mojom::ProxyResolvingSocket* client_socket,
241       const net::HostPortPair& host_port_pair,
242       mojo::PendingReceiver<mojom::TLSClientSocket> receiver,
243       net::CompletionOnceCallback callback) {
244     client_socket->UpgradeToTLS(
245         host_port_pair,
246         net::MutableNetworkTrafficAnnotationTag(TRAFFIC_ANNOTATION_FOR_TESTS),
247         std::move(receiver), post_tls_observer()->GetObserverRemote(),
248         base::BindOnce(
249             [](net::CompletionOnceCallback cb,
250                mojo::ScopedDataPipeConsumerHandle* consumer_handle,
251                mojo::ScopedDataPipeProducerHandle* producer_handle, int result,
252                mojo::ScopedDataPipeConsumerHandle receive_pipe_handle,
253                mojo::ScopedDataPipeProducerHandle send_pipe_handle) {
254               *consumer_handle = std::move(receive_pipe_handle);
255               *producer_handle = std::move(send_pipe_handle);
256               std::move(cb).Run(result);
257             },
258             std::move(callback), &post_tls_recv_handle_,
259             &post_tls_send_handle_));
260   }
261 
pre_tls_observer()262   TestSocketObserver* pre_tls_observer() { return &pre_tls_observer_; }
post_tls_observer()263   TestSocketObserver* post_tls_observer() { return &post_tls_observer_; }
264 
pre_tls_recv_handle()265   mojo::ScopedDataPipeConsumerHandle* pre_tls_recv_handle() {
266     return &pre_tls_recv_handle_;
267   }
268 
pre_tls_send_handle()269   mojo::ScopedDataPipeProducerHandle* pre_tls_send_handle() {
270     return &pre_tls_send_handle_;
271   }
272 
post_tls_recv_handle()273   mojo::ScopedDataPipeConsumerHandle* post_tls_recv_handle() {
274     return &post_tls_recv_handle_;
275   }
276 
post_tls_send_handle()277   mojo::ScopedDataPipeProducerHandle* post_tls_send_handle() {
278     return &post_tls_send_handle_;
279   }
280 
ssl_info()281   const base::Optional<net::SSLInfo>& ssl_info() { return ssl_info_; }
282 
mock_client_socket_factory()283   net::MockClientSocketFactory* mock_client_socket_factory() {
284     return &mock_client_socket_factory_;
285   }
286 
mode() const287   Mode mode() const { return mode_; }
288 
289  private:
290   Mode mode_;
291   base::test::TaskEnvironment task_environment_;
292 
293   // Mojo data handles obtained from CreateTCPConnectedSocket.
294   mojo::ScopedDataPipeConsumerHandle pre_tls_recv_handle_;
295   mojo::ScopedDataPipeProducerHandle pre_tls_send_handle_;
296 
297   // Mojo data handles obtained from UpgradeToTLS.
298   mojo::ScopedDataPipeConsumerHandle post_tls_recv_handle_;
299   mojo::ScopedDataPipeProducerHandle post_tls_send_handle_;
300 
301   // SSLInfo obtained from UpgradeToTLS.
302   base::Optional<net::SSLInfo> ssl_info_;
303 
304   std::unique_ptr<net::ProxyResolutionService> proxy_resolution_service_;
305   net::TestURLRequestContext url_request_context_;
306   net::MockClientSocketFactory mock_client_socket_factory_;
307   std::unique_ptr<SocketFactory> factory_;
308   std::unique_ptr<ProxyResolvingSocketFactoryMojo> proxy_resolving_factory_;
309   TestSocketObserver pre_tls_observer_;
310   TestSocketObserver post_tls_observer_;
311 
312   DISALLOW_COPY_AND_ASSIGN(TLSClientSocketTestBase);
313 };
314 
315 class TLSClientSocketTest
316     : public ::testing::TestWithParam<TLSClientSocketTestBase::Mode>,
317       public TLSClientSocketTestBase {
318  public:
TLSClientSocketTest()319   TLSClientSocketTest() : TLSClientSocketTestBase(GetParam()) {
320     Init(true /* use_mock_sockets */, false /* configure_proxy */);
321   }
322 
~TLSClientSocketTest()323   ~TLSClientSocketTest() override {}
324 
325  private:
326   DISALLOW_COPY_AND_ASSIGN(TLSClientSocketTest);
327 };
328 
329 // Basic test to call UpgradeToTLS, and then read/write after UpgradeToTLS is
330 // successful.
TEST_P(TLSClientSocketTest,UpgradeToTLS)331 TEST_P(TLSClientSocketTest, UpgradeToTLS) {
332   const net::MockRead kReads[] = {net::MockRead(net::ASYNC, kMsg, kMsgSize, 1),
333                                   net::MockRead(net::SYNCHRONOUS, net::OK, 2)};
334   const net::MockWrite kWrites[] = {
335       net::MockWrite(net::SYNCHRONOUS, kMsg, kMsgSize, 0)};
336   net::SequencedSocketData data_provider(kReads, kWrites);
337   data_provider.set_connect_data(net::MockConnect(net::SYNCHRONOUS, net::OK));
338   mock_client_socket_factory()->AddSocketDataProvider(&data_provider);
339   net::SSLSocketDataProvider ssl_socket(net::ASYNC, net::OK);
340   mock_client_socket_factory()->AddSSLSocketDataProvider(&ssl_socket);
341 
342   SocketHandle client_socket;
343   net::IPEndPoint server_addr(net::IPAddress::IPv4Localhost(), 1234);
344   EXPECT_EQ(net::OK,
345             CreateSocketSync(MakeRequest(&client_socket), server_addr));
346 
347   net::HostPortPair host_port_pair("example.org", 443);
348   pre_tls_recv_handle()->reset();
349   pre_tls_send_handle()->reset();
350   net::TestCompletionCallback callback;
351   mojo::Remote<mojom::TLSClientSocket> tls_socket;
352   UpgradeToTLS(&client_socket, host_port_pair,
353                tls_socket.BindNewPipeAndPassReceiver(), callback.callback());
354   ASSERT_EQ(net::OK, callback.WaitForResult());
355   ResetSocket(&client_socket);
356 
357   uint32_t num_bytes = strlen(kMsg);
358   EXPECT_EQ(MOJO_RESULT_OK, post_tls_send_handle()->get().WriteData(
359                                 &kMsg, &num_bytes, MOJO_WRITE_DATA_FLAG_NONE));
360   EXPECT_EQ(kMsg, Read(post_tls_recv_handle(), kMsgSize));
361   base::RunLoop().RunUntilIdle();
362   EXPECT_TRUE(ssl_socket.ConnectDataConsumed());
363   EXPECT_TRUE(data_provider.AllReadDataConsumed());
364   EXPECT_TRUE(data_provider.AllWriteDataConsumed());
365 }
366 
367 // Same as the UpgradeToTLS test above, except this test calls
368 // base::RunLoop().RunUntilIdle() after destroying the pre-tls data pipes.
TEST_P(TLSClientSocketTest,ClosePipesRunUntilIdleAndUpgradeToTLS)369 TEST_P(TLSClientSocketTest, ClosePipesRunUntilIdleAndUpgradeToTLS) {
370   const net::MockRead kReads[] = {net::MockRead(net::ASYNC, kMsg, kMsgSize, 1),
371                                   net::MockRead(net::SYNCHRONOUS, net::OK, 2)};
372   const net::MockWrite kWrites[] = {
373       net::MockWrite(net::SYNCHRONOUS, kMsg, kMsgSize, 0)};
374   net::SequencedSocketData data_provider(kReads, kWrites);
375   data_provider.set_connect_data(net::MockConnect(net::SYNCHRONOUS, net::OK));
376   mock_client_socket_factory()->AddSocketDataProvider(&data_provider);
377   net::SSLSocketDataProvider ssl_socket(net::ASYNC, net::OK);
378   mock_client_socket_factory()->AddSSLSocketDataProvider(&ssl_socket);
379 
380   SocketHandle client_socket;
381   net::IPEndPoint server_addr(net::IPAddress::IPv4Localhost(), 1234);
382   EXPECT_EQ(net::OK,
383             CreateSocketSync(MakeRequest(&client_socket), server_addr));
384 
385   net::HostPortPair host_port_pair("example.org", 443);
386 
387   // Call RunUntilIdle() to test the case that pipes are closed before
388   // UpgradeToTLS.
389   pre_tls_recv_handle()->reset();
390   pre_tls_send_handle()->reset();
391   base::RunLoop().RunUntilIdle();
392 
393   net::TestCompletionCallback callback;
394   mojo::Remote<mojom::TLSClientSocket> tls_socket;
395   UpgradeToTLS(&client_socket, host_port_pair,
396                tls_socket.BindNewPipeAndPassReceiver(), callback.callback());
397   ASSERT_EQ(net::OK, callback.WaitForResult());
398   ResetSocket(&client_socket);
399 
400   uint32_t num_bytes = strlen(kMsg);
401   EXPECT_EQ(MOJO_RESULT_OK, post_tls_send_handle()->get().WriteData(
402                                 &kMsg, &num_bytes, MOJO_WRITE_DATA_FLAG_NONE));
403   EXPECT_EQ(kMsg, Read(post_tls_recv_handle(), kMsgSize));
404   base::RunLoop().RunUntilIdle();
405   EXPECT_TRUE(ssl_socket.ConnectDataConsumed());
406   EXPECT_TRUE(data_provider.AllReadDataConsumed());
407   EXPECT_TRUE(data_provider.AllWriteDataConsumed());
408 }
409 
// Calling UpgradeToTLS a second time on the same socket remote
// (TCPConnectedSocket or ProxyResolvingSocket, depending on mode()) is illegal
// and should fail with net::ERR_SOCKET_NOT_CONNECTED.
TEST_P(TLSClientSocketTest, UpgradeToTLSTwice) {
  const net::MockRead kReads[] = {net::MockRead(net::ASYNC, net::OK, 0)};
  net::SequencedSocketData data_provider(kReads, base::span<net::MockWrite>());
  data_provider.set_connect_data(net::MockConnect(net::SYNCHRONOUS, net::OK));
  mock_client_socket_factory()->AddSocketDataProvider(&data_provider);
  net::SSLSocketDataProvider ssl_socket(net::ASYNC, net::OK);
  mock_client_socket_factory()->AddSSLSocketDataProvider(&ssl_socket);

  SocketHandle client_socket;
  net::IPEndPoint server_addr(net::IPAddress::IPv4Localhost(), 1234);
  EXPECT_EQ(net::OK,
            CreateSocketSync(MakeRequest(&client_socket), server_addr));

  net::HostPortPair host_port_pair("example.org", 443);
  pre_tls_recv_handle()->reset();
  pre_tls_send_handle()->reset();

  // First UpgradeToTLS should complete successfully.
  // Note: |client_socket| is deliberately NOT reset afterwards, so the second
  // call below goes to the same remote.
  net::TestCompletionCallback callback;
  mojo::Remote<mojom::TLSClientSocket> tls_socket;
  UpgradeToTLS(&client_socket, host_port_pair,
               tls_socket.BindNewPipeAndPassReceiver(), callback.callback());
  ASSERT_EQ(net::OK, callback.WaitForResult());

  // Second time UpgradeToTLS is called, it should fail. The interface is called
  // directly because the two socket types use different callback signatures.
  mojo::Remote<mojom::TLSClientSocket> tls_socket2;
  base::RunLoop run_loop;
  int net_error = net::ERR_FAILED;
  if (mode() == kDirect) {
    auto upgrade2_callback = base::BindLambdaForTesting(
        [&](int result, mojo::ScopedDataPipeConsumerHandle receive_pipe_handle,
            mojo::ScopedDataPipeProducerHandle send_pipe_handle,
            const base::Optional<net::SSLInfo>& ssl_info) {
          net_error = result;
          run_loop.Quit();
        });
    client_socket.tcp_socket->UpgradeToTLS(
        host_port_pair, nullptr /* options */,
        net::MutableNetworkTrafficAnnotationTag(TRAFFIC_ANNOTATION_FOR_TESTS),
        tls_socket2.BindNewPipeAndPassReceiver(),
        mojo::NullRemote() /* observer */, std::move(upgrade2_callback));
  } else {
    auto upgrade2_callback = base::BindLambdaForTesting(
        [&](int result, mojo::ScopedDataPipeConsumerHandle receive_pipe_handle,
            mojo::ScopedDataPipeProducerHandle send_pipe_handle) {
          net_error = result;
          run_loop.Quit();
        });
    client_socket.proxy_socket->UpgradeToTLS(
        host_port_pair,
        net::MutableNetworkTrafficAnnotationTag(TRAFFIC_ANNOTATION_FOR_TESTS),
        tls_socket2.BindNewPipeAndPassReceiver(),
        mojo::NullRemote() /* observer */, std::move(upgrade2_callback));
  }
  run_loop.Run();
  ASSERT_EQ(net::ERR_SOCKET_NOT_CONNECTED, net_error);

  base::RunLoop().RunUntilIdle();
  EXPECT_TRUE(ssl_socket.ConnectDataConsumed());
  EXPECT_TRUE(data_provider.AllReadDataConsumed());
  EXPECT_TRUE(data_provider.AllWriteDataConsumed());
}
474 
TEST_P(TLSClientSocketTest,UpgradeToTLSWithCustomSSLConfig)475 TEST_P(TLSClientSocketTest, UpgradeToTLSWithCustomSSLConfig) {
476   // No custom options in the proxy-resolving case.
477   if (mode() != kDirect)
478     return;
479   const net::MockRead kReads[] = {net::MockRead(net::ASYNC, net::OK, 0)};
480   net::SequencedSocketData data_provider(kReads, base::span<net::MockWrite>());
481   data_provider.set_connect_data(net::MockConnect(net::SYNCHRONOUS, net::OK));
482   mock_client_socket_factory()->AddSocketDataProvider(&data_provider);
483   net::SSLSocketDataProvider ssl_socket(net::ASYNC, net::OK);
484   ssl_socket.expected_ssl_version_min = net::SSL_PROTOCOL_VERSION_TLS1_1;
485   ssl_socket.expected_ssl_version_max = net::SSL_PROTOCOL_VERSION_TLS1_2;
486   mock_client_socket_factory()->AddSSLSocketDataProvider(&ssl_socket);
487 
488   SocketHandle client_socket;
489   net::IPEndPoint server_addr(net::IPAddress::IPv4Localhost(), 1234);
490   EXPECT_EQ(net::OK,
491             CreateSocketSync(MakeRequest(&client_socket), server_addr));
492 
493   net::HostPortPair host_port_pair("example.org", 443);
494   pre_tls_recv_handle()->reset();
495   pre_tls_send_handle()->reset();
496 
497   mojo::Remote<mojom::TLSClientSocket> tls_socket;
498   base::RunLoop run_loop;
499   mojom::TLSClientSocketOptionsPtr options =
500       mojom::TLSClientSocketOptions::New();
501   options->version_min = mojom::SSLVersion::kTLS11;
502   options->version_max = mojom::SSLVersion::kTLS12;
503   int net_error = net::ERR_FAILED;
504   auto upgrade_callback = base::BindLambdaForTesting(
505       [&](int result, mojo::ScopedDataPipeConsumerHandle receive_pipe_handle,
506           mojo::ScopedDataPipeProducerHandle send_pipe_handle,
507           const base::Optional<net::SSLInfo>& ssl_info) {
508         net_error = result;
509         run_loop.Quit();
510       });
511   client_socket.tcp_socket->UpgradeToTLS(
512       host_port_pair, std::move(options),
513       net::MutableNetworkTrafficAnnotationTag(TRAFFIC_ANNOTATION_FOR_TESTS),
514       tls_socket.BindNewPipeAndPassReceiver(), mojo::NullRemote() /*observer */,
515       std::move(upgrade_callback));
516   run_loop.Run();
517   ASSERT_EQ(net::OK, net_error);
518 
519   base::RunLoop().RunUntilIdle();
520   EXPECT_TRUE(ssl_socket.ConnectDataConsumed());
521   EXPECT_TRUE(data_provider.AllReadDataConsumed());
522   EXPECT_TRUE(data_provider.AllWriteDataConsumed());
523 }
524 
525 // Same as the UpgradeToTLS test, except this also reads and writes to the tcp
526 // connection before UpgradeToTLS is called.
TEST_P(TLSClientSocketTest,ReadWriteBeforeUpgradeToTLS)527 TEST_P(TLSClientSocketTest, ReadWriteBeforeUpgradeToTLS) {
528   const net::MockRead kReads[] = {
529       net::MockRead(net::SYNCHRONOUS, kMsg, kMsgSize, 0),
530       net::MockRead(net::ASYNC, kSecretMsg, kSecretMsgSize, 3),
531       net::MockRead(net::SYNCHRONOUS, net::OK, 4)};
532   const net::MockWrite kWrites[] = {
533       net::MockWrite(net::SYNCHRONOUS, kMsg, kMsgSize, 1),
534       net::MockWrite(net::SYNCHRONOUS, kSecretMsg, kSecretMsgSize, 2),
535   };
536   net::SequencedSocketData data_provider(kReads, kWrites);
537   data_provider.set_connect_data(net::MockConnect(net::SYNCHRONOUS, net::OK));
538   mock_client_socket_factory()->AddSocketDataProvider(&data_provider);
539   net::SSLSocketDataProvider ssl_socket(net::ASYNC, net::OK);
540   mock_client_socket_factory()->AddSSLSocketDataProvider(&ssl_socket);
541 
542   SocketHandle client_socket;
543   net::IPEndPoint server_addr(net::IPAddress::IPv4Localhost(), 1234);
544   EXPECT_EQ(net::OK,
545             CreateSocketSync(MakeRequest(&client_socket), server_addr));
546 
547   EXPECT_EQ(kMsg, Read(pre_tls_recv_handle(), kMsgSize));
548 
549   uint32_t num_bytes = kMsgSize;
550   EXPECT_EQ(MOJO_RESULT_OK, pre_tls_send_handle()->get().WriteData(
551                                 &kMsg, &num_bytes, MOJO_WRITE_DATA_FLAG_NONE));
552 
553   net::HostPortPair host_port_pair("example.org", 443);
554   pre_tls_recv_handle()->reset();
555   pre_tls_send_handle()->reset();
556   net::TestCompletionCallback callback;
557   mojo::Remote<mojom::TLSClientSocket> tls_socket;
558   UpgradeToTLS(&client_socket, host_port_pair,
559                tls_socket.BindNewPipeAndPassReceiver(), callback.callback());
560   ASSERT_EQ(net::OK, callback.WaitForResult());
561   ResetSocket(&client_socket);
562 
563   num_bytes = strlen(kSecretMsg);
564   EXPECT_EQ(MOJO_RESULT_OK,
565             post_tls_send_handle()->get().WriteData(&kSecretMsg, &num_bytes,
566                                                     MOJO_WRITE_DATA_FLAG_NONE));
567   EXPECT_EQ(kSecretMsg, Read(post_tls_recv_handle(), kSecretMsgSize));
568   base::RunLoop().RunUntilIdle();
569   EXPECT_TRUE(ssl_socket.ConnectDataConsumed());
570   EXPECT_TRUE(data_provider.AllReadDataConsumed());
571   EXPECT_TRUE(data_provider.AllWriteDataConsumed());
572 }
573 
574 // Tests that a read error is encountered after UpgradeToTLS completes
575 // successfully.
TEST_P(TLSClientSocketTest,ReadErrorAfterUpgradeToTLS)576 TEST_P(TLSClientSocketTest, ReadErrorAfterUpgradeToTLS) {
577   const net::MockRead kReads[] = {
578       net::MockRead(net::ASYNC, kSecretMsg, kSecretMsgSize, 1),
579       net::MockRead(net::SYNCHRONOUS, net::ERR_CONNECTION_CLOSED, 2)};
580   const net::MockWrite kWrites[] = {
581       net::MockWrite(net::SYNCHRONOUS, kSecretMsg, kSecretMsgSize, 0)};
582   net::SequencedSocketData data_provider(kReads, kWrites);
583   data_provider.set_connect_data(net::MockConnect(net::SYNCHRONOUS, net::OK));
584   mock_client_socket_factory()->AddSocketDataProvider(&data_provider);
585   net::SSLSocketDataProvider ssl_socket(net::ASYNC, net::OK);
586   mock_client_socket_factory()->AddSSLSocketDataProvider(&ssl_socket);
587 
588   SocketHandle client_socket;
589   net::IPEndPoint server_addr(net::IPAddress::IPv4Localhost(), 1234);
590   EXPECT_EQ(net::OK,
591             CreateSocketSync(MakeRequest(&client_socket), server_addr));
592 
593   net::HostPortPair host_port_pair("example.org", 443);
594   pre_tls_recv_handle()->reset();
595   pre_tls_send_handle()->reset();
596   net::TestCompletionCallback callback;
597   mojo::Remote<mojom::TLSClientSocket> tls_socket;
598   UpgradeToTLS(&client_socket, host_port_pair,
599                tls_socket.BindNewPipeAndPassReceiver(), callback.callback());
600   ASSERT_EQ(net::OK, callback.WaitForResult());
601   ResetSocket(&client_socket);
602 
603   uint32_t num_bytes = strlen(kSecretMsg);
604   EXPECT_EQ(MOJO_RESULT_OK,
605             post_tls_send_handle()->get().WriteData(&kSecretMsg, &num_bytes,
606                                                     MOJO_WRITE_DATA_FLAG_NONE));
607   EXPECT_EQ(kSecretMsg, Read(post_tls_recv_handle(), kSecretMsgSize));
608   EXPECT_EQ(net::ERR_CONNECTION_CLOSED,
609             post_tls_observer()->WaitForReadError());
610 
611   base::RunLoop().RunUntilIdle();
612   EXPECT_TRUE(ssl_socket.ConnectDataConsumed());
613   EXPECT_TRUE(data_provider.AllReadDataConsumed());
614   EXPECT_TRUE(data_provider.AllWriteDataConsumed());
615 }
616 
617 // Tests that a read error is encountered after UpgradeToTLS completes
618 // successfully.
TEST_P(TLSClientSocketTest,WriteErrorAfterUpgradeToTLS)619 TEST_P(TLSClientSocketTest, WriteErrorAfterUpgradeToTLS) {
620   const net::MockRead kReads[] = {net::MockRead(net::ASYNC, net::OK, 0)};
621   const net::MockWrite kWrites[] = {
622       net::MockWrite(net::SYNCHRONOUS, net::ERR_CONNECTION_CLOSED, 1)};
623   net::SequencedSocketData data_provider(kReads, kWrites);
624   data_provider.set_connect_data(net::MockConnect(net::SYNCHRONOUS, net::OK));
625   mock_client_socket_factory()->AddSocketDataProvider(&data_provider);
626   net::SSLSocketDataProvider ssl_socket(net::ASYNC, net::OK);
627   mock_client_socket_factory()->AddSSLSocketDataProvider(&ssl_socket);
628 
629   SocketHandle client_socket;
630   net::IPEndPoint server_addr(net::IPAddress::IPv4Localhost(), 1234);
631   EXPECT_EQ(net::OK,
632             CreateSocketSync(MakeRequest(&client_socket), server_addr));
633 
634   net::HostPortPair host_port_pair("example.org", 443);
635   pre_tls_recv_handle()->reset();
636   pre_tls_send_handle()->reset();
637   net::TestCompletionCallback callback;
638   mojo::Remote<mojom::TLSClientSocket> tls_socket;
639   UpgradeToTLS(&client_socket, host_port_pair,
640                tls_socket.BindNewPipeAndPassReceiver(), callback.callback());
641   ASSERT_EQ(net::OK, callback.WaitForResult());
642   ResetSocket(&client_socket);
643 
644   uint32_t num_bytes = strlen(kSecretMsg);
645   EXPECT_EQ(MOJO_RESULT_OK,
646             post_tls_send_handle()->get().WriteData(&kSecretMsg, &num_bytes,
647                                                     MOJO_WRITE_DATA_FLAG_NONE));
648   EXPECT_EQ(net::ERR_CONNECTION_CLOSED,
649             post_tls_observer()->WaitForWriteError());
650 
651   base::RunLoop().RunUntilIdle();
652   EXPECT_TRUE(ssl_socket.ConnectDataConsumed());
653   EXPECT_TRUE(data_provider.AllReadDataConsumed());
654   EXPECT_TRUE(data_provider.AllWriteDataConsumed());
655 }
656 
657 // Tests that reading from the pre-tls data pipe is okay even after UpgradeToTLS
658 // is called.
TEST_P(TLSClientSocketTest,ReadFromPreTlsDataPipeAfterUpgradeToTLS)659 TEST_P(TLSClientSocketTest, ReadFromPreTlsDataPipeAfterUpgradeToTLS) {
660   const net::MockRead kReads[] = {
661       net::MockRead(net::ASYNC, kMsg, kMsgSize, 0),
662       net::MockRead(net::ASYNC, kSecretMsg, kSecretMsgSize, 2),
663       net::MockRead(net::SYNCHRONOUS, net::OK, 3)};
664   const net::MockWrite kWrites[] = {
665       net::MockWrite(net::SYNCHRONOUS, kSecretMsg, kSecretMsgSize, 1)};
666   net::SequencedSocketData data_provider(kReads, kWrites);
667   data_provider.set_connect_data(net::MockConnect(net::SYNCHRONOUS, net::OK));
668   mock_client_socket_factory()->AddSocketDataProvider(&data_provider);
669   net::SSLSocketDataProvider ssl_socket(net::ASYNC, net::OK);
670   mock_client_socket_factory()->AddSSLSocketDataProvider(&ssl_socket);
671 
672   SocketHandle client_socket;
673   net::IPEndPoint server_addr(net::IPAddress::IPv4Localhost(), 1234);
674   EXPECT_EQ(net::OK,
675             CreateSocketSync(MakeRequest(&client_socket), server_addr));
676 
677   net::HostPortPair host_port_pair("example.org", 443);
678   pre_tls_send_handle()->reset();
679   net::TestCompletionCallback callback;
680   mojo::Remote<mojom::TLSClientSocket> tls_socket;
681   UpgradeToTLS(&client_socket, host_port_pair,
682                tls_socket.BindNewPipeAndPassReceiver(), callback.callback());
683   base::RunLoop().RunUntilIdle();
684 
685   EXPECT_EQ(kMsg, Read(pre_tls_recv_handle(), kMsgSize));
686 
687   // Reset pre-tls receive pipe now and UpgradeToTLS should complete.
688   pre_tls_recv_handle()->reset();
689   ASSERT_EQ(net::OK, callback.WaitForResult());
690   ResetSocket(&client_socket);
691 
692   uint32_t num_bytes = strlen(kSecretMsg);
693   EXPECT_EQ(MOJO_RESULT_OK,
694             post_tls_send_handle()->get().WriteData(&kSecretMsg, &num_bytes,
695                                                     MOJO_WRITE_DATA_FLAG_NONE));
696   EXPECT_EQ(kSecretMsg, Read(post_tls_recv_handle(), kSecretMsgSize));
697   base::RunLoop().RunUntilIdle();
698   EXPECT_TRUE(ssl_socket.ConnectDataConsumed());
699   EXPECT_TRUE(data_provider.AllReadDataConsumed());
700   EXPECT_TRUE(data_provider.AllWriteDataConsumed());
701 }
702 
703 // Tests that writing to the pre-tls data pipe is okay even after UpgradeToTLS
704 // is called.
TEST_P(TLSClientSocketTest,WriteToPreTlsDataPipeAfterUpgradeToTLS)705 TEST_P(TLSClientSocketTest, WriteToPreTlsDataPipeAfterUpgradeToTLS) {
706   const net::MockRead kReads[] = {
707       net::MockRead(net::ASYNC, kSecretMsg, kSecretMsgSize, 2),
708       net::MockRead(net::SYNCHRONOUS, net::OK, 3)};
709   const net::MockWrite kWrites[] = {
710       net::MockWrite(net::SYNCHRONOUS, kMsg, kMsgSize, 0),
711       net::MockWrite(net::SYNCHRONOUS, kSecretMsg, kSecretMsgSize, 1)};
712   net::SequencedSocketData data_provider(kReads, kWrites);
713   data_provider.set_connect_data(net::MockConnect(net::SYNCHRONOUS, net::OK));
714   mock_client_socket_factory()->AddSocketDataProvider(&data_provider);
715   net::SSLSocketDataProvider ssl_socket(net::ASYNC, net::OK);
716   mock_client_socket_factory()->AddSSLSocketDataProvider(&ssl_socket);
717 
718   SocketHandle client_socket;
719   net::IPEndPoint server_addr(net::IPAddress::IPv4Localhost(), 1234);
720   EXPECT_EQ(net::OK,
721             CreateSocketSync(MakeRequest(&client_socket), server_addr));
722 
723   net::HostPortPair host_port_pair("example.org", 443);
724   pre_tls_recv_handle()->reset();
725   net::TestCompletionCallback callback;
726   mojo::Remote<mojom::TLSClientSocket> tls_socket;
727   UpgradeToTLS(&client_socket, host_port_pair,
728                tls_socket.BindNewPipeAndPassReceiver(), callback.callback());
729   base::RunLoop().RunUntilIdle();
730 
731   uint32_t num_bytes = strlen(kMsg);
732   EXPECT_EQ(MOJO_RESULT_OK, pre_tls_send_handle()->get().WriteData(
733                                 &kMsg, &num_bytes, MOJO_WRITE_DATA_FLAG_NONE));
734 
735   // Reset pre-tls send pipe now and UpgradeToTLS should complete.
736   pre_tls_send_handle()->reset();
737   ASSERT_EQ(net::OK, callback.WaitForResult());
738   ResetSocket(&client_socket);
739 
740   num_bytes = strlen(kSecretMsg);
741   EXPECT_EQ(MOJO_RESULT_OK,
742             post_tls_send_handle()->get().WriteData(&kSecretMsg, &num_bytes,
743                                                     MOJO_WRITE_DATA_FLAG_NONE));
744   EXPECT_EQ(kSecretMsg, Read(post_tls_recv_handle(), kSecretMsgSize));
745   base::RunLoop().RunUntilIdle();
746   EXPECT_TRUE(ssl_socket.ConnectDataConsumed());
747   EXPECT_TRUE(data_provider.AllReadDataConsumed());
748   EXPECT_TRUE(data_provider.AllWriteDataConsumed());
749 }
750 
751 // Tests that reading from and writing to pre-tls data pipe is okay even after
752 // UpgradeToTLS is called.
TEST_P(TLSClientSocketTest,ReadAndWritePreTlsDataPipeAfterUpgradeToTLS)753 TEST_P(TLSClientSocketTest, ReadAndWritePreTlsDataPipeAfterUpgradeToTLS) {
754   const net::MockRead kReads[] = {
755       net::MockRead(net::ASYNC, kMsg, kMsgSize, 0),
756       net::MockRead(net::ASYNC, kSecretMsg, kSecretMsgSize, 3),
757       net::MockRead(net::SYNCHRONOUS, net::OK, 4)};
758   const net::MockWrite kWrites[] = {
759       net::MockWrite(net::SYNCHRONOUS, kMsg, kMsgSize, 1),
760       net::MockWrite(net::SYNCHRONOUS, kSecretMsg, kSecretMsgSize, 2)};
761   net::SequencedSocketData data_provider(kReads, kWrites);
762   data_provider.set_connect_data(net::MockConnect(net::SYNCHRONOUS, net::OK));
763   mock_client_socket_factory()->AddSocketDataProvider(&data_provider);
764   net::SSLSocketDataProvider ssl_socket(net::ASYNC, net::OK);
765   mock_client_socket_factory()->AddSSLSocketDataProvider(&ssl_socket);
766 
767   SocketHandle client_socket;
768   net::IPEndPoint server_addr(net::IPAddress::IPv4Localhost(), 1234);
769   EXPECT_EQ(net::OK,
770             CreateSocketSync(MakeRequest(&client_socket), server_addr));
771 
772   net::HostPortPair host_port_pair("example.org", 443);
773   base::RunLoop run_loop;
774   net::TestCompletionCallback callback;
775   mojo::Remote<mojom::TLSClientSocket> tls_socket;
776   UpgradeToTLS(&client_socket, host_port_pair,
777                tls_socket.BindNewPipeAndPassReceiver(), callback.callback());
778   EXPECT_EQ(kMsg, Read(pre_tls_recv_handle(), kMsgSize));
779   uint32_t num_bytes = strlen(kMsg);
780   EXPECT_EQ(MOJO_RESULT_OK, pre_tls_send_handle()->get().WriteData(
781                                 &kMsg, &num_bytes, MOJO_WRITE_DATA_FLAG_NONE));
782 
783   // Reset pre-tls pipes now and UpgradeToTLS should complete.
784   pre_tls_recv_handle()->reset();
785   pre_tls_send_handle()->reset();
786   ASSERT_EQ(net::OK, callback.WaitForResult());
787   ResetSocket(&client_socket);
788 
789   num_bytes = strlen(kSecretMsg);
790   EXPECT_EQ(MOJO_RESULT_OK,
791             post_tls_send_handle()->get().WriteData(&kSecretMsg, &num_bytes,
792                                                     MOJO_WRITE_DATA_FLAG_NONE));
793   EXPECT_EQ(kSecretMsg, Read(post_tls_recv_handle(), kSecretMsgSize));
794   base::RunLoop().RunUntilIdle();
795   EXPECT_TRUE(ssl_socket.ConnectDataConsumed());
796   EXPECT_TRUE(data_provider.AllReadDataConsumed());
797   EXPECT_TRUE(data_provider.AllWriteDataConsumed());
798 }
799 
800 // Tests that a read error is encountered before UpgradeToTLS completes.
TEST_P(TLSClientSocketTest,ReadErrorBeforeUpgradeToTLS)801 TEST_P(TLSClientSocketTest, ReadErrorBeforeUpgradeToTLS) {
802   // This requires pre_tls_observer(), which is not provided by proxy resolving
803   // sockets.
804   if (mode() != kDirect)
805     return;
806   const net::MockRead kReads[] = {
807       net::MockRead(net::ASYNC, kMsg, kMsgSize, 0),
808       net::MockRead(net::SYNCHRONOUS, net::ERR_CONNECTION_CLOSED, 1)};
809   net::SequencedSocketData data_provider(kReads, base::span<net::MockWrite>());
810   data_provider.set_connect_data(net::MockConnect(net::SYNCHRONOUS, net::OK));
811   mock_client_socket_factory()->AddSocketDataProvider(&data_provider);
812 
813   SocketHandle client_socket;
814   net::IPEndPoint server_addr(net::IPAddress::IPv4Localhost(), 1234);
815   EXPECT_EQ(net::OK,
816             CreateSocketSync(MakeRequest(&client_socket), server_addr));
817 
818   net::HostPortPair host_port_pair("example.org", 443);
819   pre_tls_send_handle()->reset();
820   net::TestCompletionCallback callback;
821   mojo::Remote<mojom::TLSClientSocket> tls_socket;
822   UpgradeToTLS(&client_socket, host_port_pair,
823                tls_socket.BindNewPipeAndPassReceiver(), callback.callback());
824 
825   EXPECT_EQ(kMsg, Read(pre_tls_recv_handle(), kMsgSize));
826   EXPECT_EQ(net::ERR_CONNECTION_CLOSED, pre_tls_observer()->WaitForReadError());
827 
828   // Reset pre-tls receive pipe now and UpgradeToTLS should complete.
829   pre_tls_recv_handle()->reset();
830   ASSERT_EQ(net::ERR_SOCKET_NOT_CONNECTED, callback.WaitForResult());
831   ResetSocket(&client_socket);
832 
833   base::RunLoop().RunUntilIdle();
834   EXPECT_TRUE(data_provider.AllReadDataConsumed());
835   EXPECT_TRUE(data_provider.AllWriteDataConsumed());
836 }
837 
838 // Tests that a write error is encountered before UpgradeToTLS completes.
TEST_P(TLSClientSocketTest,WriteErrorBeforeUpgradeToTLS)839 TEST_P(TLSClientSocketTest, WriteErrorBeforeUpgradeToTLS) {
840   // This requires pre_tls_observer(), which is not provided by proxy resolving
841   // sockets.
842   if (mode() != kDirect)
843     return;
844 
845   const net::MockRead kReads[] = {net::MockRead(net::ASYNC, net::OK, 1)};
846   const net::MockWrite kWrites[] = {
847       net::MockWrite(net::SYNCHRONOUS, net::ERR_CONNECTION_CLOSED, 0)};
848   net::SequencedSocketData data_provider(kReads, kWrites);
849   data_provider.set_connect_data(net::MockConnect(net::SYNCHRONOUS, net::OK));
850   mock_client_socket_factory()->AddSocketDataProvider(&data_provider);
851 
852   SocketHandle client_socket;
853   net::IPEndPoint server_addr(net::IPAddress::IPv4Localhost(), 1234);
854   EXPECT_EQ(net::OK,
855             CreateSocketSync(MakeRequest(&client_socket), server_addr));
856 
857   net::HostPortPair host_port_pair("example.org", 443);
858   pre_tls_recv_handle()->reset();
859   net::TestCompletionCallback callback;
860   mojo::Remote<mojom::TLSClientSocket> tls_socket;
861   UpgradeToTLS(&client_socket, host_port_pair,
862                tls_socket.BindNewPipeAndPassReceiver(), callback.callback());
863   uint32_t num_bytes = strlen(kMsg);
864   EXPECT_EQ(MOJO_RESULT_OK, pre_tls_send_handle()->get().WriteData(
865                                 &kMsg, &num_bytes, MOJO_WRITE_DATA_FLAG_NONE));
866 
867   EXPECT_EQ(net::ERR_CONNECTION_CLOSED,
868             pre_tls_observer()->WaitForWriteError());
869   // Reset pre-tls send pipe now and UpgradeToTLS should complete.
870   pre_tls_send_handle()->reset();
871   ASSERT_EQ(net::ERR_SOCKET_NOT_CONNECTED, callback.WaitForResult());
872   ResetSocket(&client_socket);
873 
874   base::RunLoop().RunUntilIdle();
875   // Write failed before the mock read can be consumed.
876   EXPECT_FALSE(data_provider.AllReadDataConsumed());
877   EXPECT_TRUE(data_provider.AllWriteDataConsumed());
878 }
879 
880 INSTANTIATE_TEST_SUITE_P(
881     All,
882     TLSClientSocketTest,
883     ::testing::Values(TLSClientSocketTestBase::kDirect,
884                       TLSClientSocketTestBase::kProxyResolving));
885 
886 // Tests with proxy resolving socket and a proxy actually configured.
887 class TLSCLientSocketProxyTest : public ::testing::Test,
888                                  public TLSClientSocketTestBase {
889  public:
TLSCLientSocketProxyTest()890   TLSCLientSocketProxyTest()
891       : TLSClientSocketTestBase(TLSClientSocketTestBase::kProxyResolving) {
892     Init(true /* use_mock_sockets*/, true /* configure_proxy */);
893   }
894 
~TLSCLientSocketProxyTest()895   ~TLSCLientSocketProxyTest() override {}
896 
897  private:
898   DISALLOW_COPY_AND_ASSIGN(TLSCLientSocketProxyTest);
899 };
900 
// Upgrades a socket that goes through the configured proxy (via an HTTP
// CONNECT exchange) to TLS, then round-trips a message over the TLS pipes.
TEST_F(TLSCLientSocketProxyTest, UpgradeToTLS) {
  const char kConnectRequest[] =
      "CONNECT 192.168.1.1:1234 HTTP/1.1\r\n"
      "Host: 192.168.1.1:1234\r\n"
      "Proxy-Connection: keep-alive\r\n\r\n";
  const char kConnectResponse[] = "HTTP/1.1 200 OK\r\n\r\n";

  // Sequence: CONNECT request (0), CONNECT response (1), write via post-TLS
  // pipe (2), read via post-TLS pipe (3), EOF (4).
  const net::MockRead kReads[] = {
      net::MockRead(net::ASYNC, kConnectResponse, strlen(kConnectResponse), 1),
      net::MockRead(net::ASYNC, kMsg, kMsgSize, 3),
      net::MockRead(net::SYNCHRONOUS, net::OK, 4)};
  const net::MockWrite kWrites[] = {
      net::MockWrite(net::ASYNC, kConnectRequest, strlen(kConnectRequest), 0),
      net::MockWrite(net::SYNCHRONOUS, kMsg, kMsgSize, 2)};
  net::SequencedSocketData socket_data(kReads, kWrites);
  socket_data.set_connect_data(net::MockConnect(net::SYNCHRONOUS, net::OK));
  mock_client_socket_factory()->AddSocketDataProvider(&socket_data);
  net::SSLSocketDataProvider ssl_data(net::ASYNC, net::OK);
  mock_client_socket_factory()->AddSSLSocketDataProvider(&ssl_data);

  SocketHandle tcp_handle;
  net::IPEndPoint peer(net::IPAddress(192, 168, 1, 1), 1234);
  EXPECT_EQ(net::OK, CreateSocketSync(MakeRequest(&tcp_handle), peer));

  net::HostPortPair tls_host("example.org", 443);
  pre_tls_recv_handle()->reset();
  pre_tls_send_handle()->reset();
  net::TestCompletionCallback upgrade_done;
  mojo::Remote<mojom::TLSClientSocket> tls_remote;
  UpgradeToTLS(&tcp_handle, tls_host, tls_remote.BindNewPipeAndPassReceiver(),
               upgrade_done.callback());
  ASSERT_EQ(net::OK, upgrade_done.WaitForResult());
  ResetSocket(&tcp_handle);

  uint32_t msg_len = kMsgSize;
  EXPECT_EQ(MOJO_RESULT_OK,
            post_tls_send_handle()->get().WriteData(&kMsg, &msg_len,
                                                    MOJO_WRITE_DATA_FLAG_NONE));
  EXPECT_EQ(kMsg, Read(post_tls_recv_handle(), kMsgSize));
  base::RunLoop().RunUntilIdle();
  EXPECT_TRUE(ssl_data.ConnectDataConsumed());
  EXPECT_TRUE(socket_data.AllReadDataConsumed());
  EXPECT_TRUE(socket_data.AllWriteDataConsumed());
}
945 
946 class TLSClientSocketIoModeTest : public TLSClientSocketTestBase,
947                                   public testing::TestWithParam<net::IoMode> {
948  public:
TLSClientSocketIoModeTest()949   TLSClientSocketIoModeTest()
950       : TLSClientSocketTestBase(TLSClientSocketTestBase::kDirect) {
951     Init(true /* use_mock_sockets*/, false /* configure_proxy */);
952   }
953 
~TLSClientSocketIoModeTest()954   ~TLSClientSocketIoModeTest() override {}
955 
956  private:
957   DISALLOW_COPY_AND_ASSIGN(TLSClientSocketIoModeTest);
958 };
959 
960 INSTANTIATE_TEST_SUITE_P(All,
961                          TLSClientSocketIoModeTest,
962                          testing::Values(net::SYNCHRONOUS, net::ASYNC));
963 
TEST_P(TLSClientSocketIoModeTest,MultipleWriteToTLSSocket)964 TEST_P(TLSClientSocketIoModeTest, MultipleWriteToTLSSocket) {
965   const int kNumIterations = 3;
966   std::vector<net::MockRead> reads;
967   std::vector<net::MockWrite> writes;
968   int sequence_number = 0;
969   net::IoMode mode = GetParam();
970   for (int j = 0; j < kNumIterations; ++j) {
971     for (size_t i = 0; i < kSecretMsgSize; ++i) {
972       writes.push_back(
973           net::MockWrite(mode, &kSecretMsg[i], 1, sequence_number++));
974     }
975     for (size_t i = 0; i < kSecretMsgSize; ++i) {
976       reads.push_back(
977           net::MockRead(net::ASYNC, &kSecretMsg[i], 1, sequence_number++));
978     }
979     if (j == kNumIterations - 1) {
980       reads.push_back(net::MockRead(mode, net::OK, sequence_number++));
981     }
982   }
983   net::SequencedSocketData data_provider(reads, writes);
984   data_provider.set_connect_data(net::MockConnect(GetParam(), net::OK));
985   mock_client_socket_factory()->AddSocketDataProvider(&data_provider);
986   net::SSLSocketDataProvider ssl_socket(net::ASYNC, net::OK);
987   mock_client_socket_factory()->AddSSLSocketDataProvider(&ssl_socket);
988 
989   mojo::Remote<mojom::TCPConnectedSocket> client_socket;
990   net::IPEndPoint server_addr(net::IPAddress::IPv4Localhost(), 1234);
991   EXPECT_EQ(net::OK,
992             CreateTCPConnectedSocketSync(
993                 client_socket.BindNewPipeAndPassReceiver(), server_addr));
994 
995   net::HostPortPair host_port_pair("example.org", 443);
996   pre_tls_recv_handle()->reset();
997   pre_tls_send_handle()->reset();
998   net::TestCompletionCallback callback;
999   mojo::Remote<mojom::TLSClientSocket> tls_socket;
1000   UpgradeTCPConnectedSocketToTLS(
1001       client_socket.get(), host_port_pair, nullptr /* options */,
1002       tls_socket.BindNewPipeAndPassReceiver(), callback.callback());
1003   ASSERT_EQ(net::OK, callback.WaitForResult());
1004   client_socket.reset();
1005   EXPECT_FALSE(ssl_info());
1006 
1007   // Loop kNumIterations times to test that writes can follow reads, and reads
1008   // can follow writes.
1009   for (int j = 0; j < kNumIterations; ++j) {
1010     // Write multiple times.
1011     for (size_t i = 0; i < kSecretMsgSize; ++i) {
1012       uint32_t num_bytes = 1;
1013       EXPECT_EQ(MOJO_RESULT_OK,
1014                 post_tls_send_handle()->get().WriteData(
1015                     &kSecretMsg[i], &num_bytes, MOJO_WRITE_DATA_FLAG_NONE));
1016       // Flush the 1 byte write.
1017       base::RunLoop().RunUntilIdle();
1018     }
1019     // Reading kSecretMsgSize should coalesce the 1-byte mock reads.
1020     EXPECT_EQ(kSecretMsg, Read(post_tls_recv_handle(), kSecretMsgSize));
1021   }
1022   EXPECT_TRUE(ssl_socket.ConnectDataConsumed());
1023   EXPECT_TRUE(data_provider.AllReadDataConsumed());
1024   EXPECT_TRUE(data_provider.AllWriteDataConsumed());
1025 }
1026 
1027 // Check SSLInfo is provided in both sync and async cases.
TEST_P(TLSClientSocketIoModeTest,SSLInfo)1028 TEST_P(TLSClientSocketIoModeTest, SSLInfo) {
1029   // End of file. Reads don't matter, only the handshake does.
1030   std::vector<net::MockRead> reads = {net::MockRead(net::SYNCHRONOUS, net::OK)};
1031   std::vector<net::MockWrite> writes;
1032   net::SequencedSocketData data_provider(reads, writes);
1033   data_provider.set_connect_data(net::MockConnect(net::SYNCHRONOUS, net::OK));
1034   mock_client_socket_factory()->AddSocketDataProvider(&data_provider);
1035   net::SSLSocketDataProvider ssl_socket(GetParam(), net::OK);
1036   // Set a value on SSLInfo to make sure it's correctly received.
1037   ssl_socket.ssl_info.is_issued_by_known_root = true;
1038   mock_client_socket_factory()->AddSSLSocketDataProvider(&ssl_socket);
1039 
1040   mojo::Remote<mojom::TCPConnectedSocket> client_socket;
1041   net::IPEndPoint server_addr(net::IPAddress::IPv4Localhost(), 1234);
1042   EXPECT_EQ(net::OK,
1043             CreateTCPConnectedSocketSync(
1044                 client_socket.BindNewPipeAndPassReceiver(), server_addr));
1045 
1046   net::HostPortPair host_port_pair("example.org", 443);
1047   pre_tls_recv_handle()->reset();
1048   pre_tls_send_handle()->reset();
1049   net::TestCompletionCallback callback;
1050   mojo::Remote<mojom::TLSClientSocket> tls_socket;
1051   mojom::TLSClientSocketOptionsPtr options =
1052       mojom::TLSClientSocketOptions::New();
1053   options->send_ssl_info = true;
1054   UpgradeTCPConnectedSocketToTLS(
1055       client_socket.get(), host_port_pair, std::move(options),
1056       tls_socket.BindNewPipeAndPassReceiver(), callback.callback());
1057   ASSERT_EQ(net::OK, callback.WaitForResult());
1058   ASSERT_TRUE(ssl_info());
1059   EXPECT_TRUE(ssl_socket.ssl_info.is_issued_by_known_root);
1060   EXPECT_FALSE(ssl_socket.ssl_info.is_fatal_cert_error);
1061 }
1062 
1063 class TLSClientSocketTestWithEmbeddedTestServerBase
1064     : public TLSClientSocketTestBase {
1065  public:
TLSClientSocketTestWithEmbeddedTestServerBase(Mode mode)1066   explicit TLSClientSocketTestWithEmbeddedTestServerBase(Mode mode)
1067       : TLSClientSocketTestBase(mode),
1068         server_(net::EmbeddedTestServer::TYPE_HTTPS) {
1069     Init(false /* use_mock_sockets */, false /* configure_proxy */);
1070   }
1071 
~TLSClientSocketTestWithEmbeddedTestServerBase()1072   ~TLSClientSocketTestWithEmbeddedTestServerBase() override {}
1073 
1074   // Starts the test server using the specified certificate.
StartTestServer(net::EmbeddedTestServer::ServerCertificate certificate)1075   bool StartTestServer(net::EmbeddedTestServer::ServerCertificate certificate)
1076       WARN_UNUSED_RESULT {
1077     server_.RegisterRequestHandler(
1078         base::BindRepeating([](const net::test_server::HttpRequest& request) {
1079           if (base::StartsWith(request.relative_url, "/secret",
1080                                base::CompareCase::INSENSITIVE_ASCII)) {
1081             return std::unique_ptr<net::test_server::HttpResponse>(
1082                 new net::test_server::RawHttpResponse("HTTP/1.1 200 OK",
1083                                                       "Hello There!"));
1084           }
1085           return std::unique_ptr<net::test_server::HttpResponse>();
1086         }));
1087     server_.SetSSLConfig(certificate);
1088     return server_.Start();
1089   }
1090 
1091   // Attempts to eastablish a TLS connection to the test server by first
1092   // establishing a TCP connection, and then upgrading it.  Returns the
1093   // resulting network error code.
CreateTLSSocket()1094   int CreateTLSSocket() WARN_UNUSED_RESULT {
1095     SocketHandle client_socket;
1096     net::IPEndPoint server_addr(net::IPAddress::IPv4Localhost(),
1097                                 server_.port());
1098     EXPECT_EQ(net::OK,
1099               CreateSocketSync(MakeRequest(&client_socket), server_addr));
1100 
1101     pre_tls_recv_handle()->reset();
1102     pre_tls_send_handle()->reset();
1103     net::TestCompletionCallback callback;
1104     UpgradeToTLS(&client_socket, server_.host_port_pair(),
1105                  tls_socket_.BindNewPipeAndPassReceiver(), callback.callback());
1106     int result = callback.WaitForResult();
1107     ResetSocket(&client_socket);
1108     return result;
1109   }
1110 
CreateTLSSocketWithOptions(mojom::TLSClientSocketOptionsPtr options)1111   int CreateTLSSocketWithOptions(mojom::TLSClientSocketOptionsPtr options)
1112       WARN_UNUSED_RESULT {
1113     // Proxy connections don't support TLSClientSocketOptions.
1114     DCHECK_EQ(kDirect, mode());
1115 
1116     mojo::Remote<mojom::TCPConnectedSocket> tcp_socket;
1117     net::IPEndPoint server_addr(net::IPAddress::IPv4Localhost(),
1118                                 server_.port());
1119     EXPECT_EQ(net::OK,
1120               CreateTCPConnectedSocketSync(
1121                   tcp_socket.BindNewPipeAndPassReceiver(), server_addr));
1122 
1123     pre_tls_recv_handle()->reset();
1124     pre_tls_send_handle()->reset();
1125     net::TestCompletionCallback callback;
1126     UpgradeTCPConnectedSocketToTLS(
1127         tcp_socket.get(), server_.host_port_pair(), std::move(options),
1128         tls_socket_.BindNewPipeAndPassReceiver(), callback.callback());
1129     int result = callback.WaitForResult();
1130     tcp_socket.reset();
1131     return result;
1132   }
1133 
TestTlsSocket()1134   void TestTlsSocket() {
1135     ASSERT_TRUE(tls_socket_.is_bound());
1136     const char kTestMsg[] = "GET /secret HTTP/1.1\r\n\r\n";
1137     uint32_t num_bytes = strlen(kTestMsg);
1138     const char kResponse[] = "HTTP/1.1 200 OK\n\n";
1139     EXPECT_EQ(MOJO_RESULT_OK,
1140               post_tls_send_handle()->get().WriteData(
1141                   &kTestMsg, &num_bytes, MOJO_WRITE_DATA_FLAG_NONE));
1142     EXPECT_EQ(kResponse, Read(post_tls_recv_handle(), strlen(kResponse)));
1143   }
1144 
server()1145   net::EmbeddedTestServer* server() { return &server_; }
1146 
1147  private:
1148   net::EmbeddedTestServer server_;
1149 
1150   mojo::Remote<mojom::TLSClientSocket> tls_socket_;
1151 
1152   DISALLOW_COPY_AND_ASSIGN(TLSClientSocketTestWithEmbeddedTestServerBase);
1153 };
1154 
1155 class TLSClientSocketTestWithEmbeddedTestServer
1156     : public TLSClientSocketTestWithEmbeddedTestServerBase,
1157       public ::testing::TestWithParam<TLSClientSocketTestBase::Mode> {
1158  public:
TLSClientSocketTestWithEmbeddedTestServer()1159   TLSClientSocketTestWithEmbeddedTestServer()
1160       : TLSClientSocketTestWithEmbeddedTestServerBase(GetParam()) {}
~TLSClientSocketTestWithEmbeddedTestServer()1161   ~TLSClientSocketTestWithEmbeddedTestServer() override {}
1162 
1163  private:
1164   DISALLOW_COPY_AND_ASSIGN(TLSClientSocketTestWithEmbeddedTestServer);
1165 };
1166 
// A valid server certificate lets the upgrade succeed, and the resulting TLS
// socket can carry an HTTP request/response.
TEST_P(TLSClientSocketTestWithEmbeddedTestServer, Basic) {
  ASSERT_TRUE(StartTestServer(net::EmbeddedTestServer::CERT_OK));
  ASSERT_EQ(net::OK, CreateTLSSocket());
  // No SSLInfo should be received by default. SSLInfo is only supported in the
  // kDirect case, but it doesn't hurt to check that it's null in the
  // kProxyResolving case too.
  EXPECT_FALSE(ssl_info());
  TestTlsSocket();
}
1176 
// A server certificate whose name doesn't match the host must make the upgrade
// fail with ERR_CERT_COMMON_NAME_INVALID.
TEST_P(TLSClientSocketTestWithEmbeddedTestServer, ServerCertError) {
  ASSERT_TRUE(StartTestServer(net::EmbeddedTestServer::CERT_MISMATCHED_NAME));
  ASSERT_EQ(net::ERR_CERT_COMMON_NAME_INVALID, CreateTLSSocket());

  // SSLInfo is not requested here (and is only supported for kDirect), so none
  // should arrive in either mode.
  EXPECT_FALSE(ssl_info());

  // After the failed upgrade, neither post-TLS handle should be valid.
  EXPECT_FALSE(post_tls_recv_handle()->is_valid());
  EXPECT_FALSE(post_tls_send_handle()->is_valid());
}
1189 
1190 INSTANTIATE_TEST_SUITE_P(
1191     All,
1192     TLSClientSocketTestWithEmbeddedTestServer,
1193     ::testing::Values(TLSClientSocketTestBase::kDirect,
1194                       TLSClientSocketTestBase::kProxyResolving));
1195 
1196 class TLSClientSocketDirectTestWithEmbeddedTestServer
1197     : public TLSClientSocketTestWithEmbeddedTestServerBase,
1198       public testing::Test {
1199  public:
TLSClientSocketDirectTestWithEmbeddedTestServer()1200   TLSClientSocketDirectTestWithEmbeddedTestServer()
1201       : TLSClientSocketTestWithEmbeddedTestServerBase(kDirect) {}
~TLSClientSocketDirectTestWithEmbeddedTestServer()1202   ~TLSClientSocketDirectTestWithEmbeddedTestServer() override {}
1203 
1204  private:
1205   DISALLOW_COPY_AND_ASSIGN(TLSClientSocketDirectTestWithEmbeddedTestServer);
1206 };
1207 
// Requesting SSLInfo on a successful handshake should deliver a valid,
// non-fatal SSLInfo and leave the TLS socket usable.
TEST_F(TLSClientSocketDirectTestWithEmbeddedTestServer, SSLInfo) {
  ASSERT_TRUE(StartTestServer(net::EmbeddedTestServer::CERT_OK));

  auto options = mojom::TLSClientSocketOptions::New();
  options->send_ssl_info = true;
  ASSERT_EQ(net::OK, CreateTLSSocketWithOptions(std::move(options)));

  ASSERT_TRUE(ssl_info());
  EXPECT_TRUE(ssl_info()->is_valid());
  EXPECT_FALSE(ssl_info()->is_fatal_cert_error);

  TestTlsSocket();
}
1221 
// Asking for SSLInfo must not bypass certificate verification: a name mismatch
// still fails the upgrade, and no SSLInfo is delivered on that error.
TEST_F(TLSClientSocketDirectTestWithEmbeddedTestServer,
       SSLInfoServerCertError) {
  ASSERT_TRUE(StartTestServer(net::EmbeddedTestServer::CERT_MISMATCHED_NAME));

  auto options = mojom::TLSClientSocketOptions::New();
  options->send_ssl_info = true;
  ASSERT_EQ(net::ERR_CERT_COMMON_NAME_INVALID,
            CreateTLSSocketWithOptions(std::move(options)));

  EXPECT_FALSE(ssl_info());

  // After the failed upgrade, neither post-TLS handle should be valid.
  EXPECT_FALSE(post_tls_recv_handle()->is_valid());
  EXPECT_FALSE(post_tls_send_handle()->is_valid());
}
1239 
1240 // Check skipping cert verification always received SSLInfo, even with valid
1241 // certs.
TEST_F(TLSClientSocketDirectTestWithEmbeddedTestServer,UnsafelySkipCertVerification)1242 TEST_F(TLSClientSocketDirectTestWithEmbeddedTestServer,
1243        UnsafelySkipCertVerification) {
1244   ASSERT_TRUE(StartTestServer(net::EmbeddedTestServer::CERT_OK));
1245   mojom::TLSClientSocketOptionsPtr options =
1246       mojom::TLSClientSocketOptions::New();
1247   options->unsafely_skip_cert_verification = true;
1248   ASSERT_EQ(net::OK, CreateTLSSocketWithOptions(std::move(options)));
1249 
1250   ASSERT_TRUE(ssl_info());
1251   EXPECT_TRUE(ssl_info()->is_valid());
1252   EXPECT_FALSE(ssl_info()->is_fatal_cert_error);
1253 
1254   TestTlsSocket();
1255 }
1256 
// With verification skipped, even a name-mismatched certificate is accepted.
// SSLInfo is still delivered, and the socket is usable.
TEST_F(TLSClientSocketDirectTestWithEmbeddedTestServer,
       UnsafelySkipCertVerificationServerCertError) {
  ASSERT_TRUE(StartTestServer(net::EmbeddedTestServer::CERT_MISMATCHED_NAME));

  auto options = mojom::TLSClientSocketOptions::New();
  options->unsafely_skip_cert_verification = true;
  ASSERT_EQ(net::OK, CreateTLSSocketWithOptions(std::move(options)));

  ASSERT_TRUE(ssl_info());
  EXPECT_TRUE(ssl_info()->is_valid());
  EXPECT_FALSE(ssl_info()->is_fatal_cert_error);

  TestTlsSocket();
}
1271 
1272 }  // namespace
1273 
1274 }  // namespace network
1275