1 // Copyright 2014 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "components/cast_channel/cast_socket.h"
6 
7 #include <stdint.h>
8 
9 #include <utility>
10 #include <vector>
11 
12 #include "base/bind.h"
13 #include "base/callback_helpers.h"
14 #include "base/files/file_util.h"
15 #include "base/location.h"
16 #include "base/macros.h"
17 #include "base/memory/ptr_util.h"
18 #include "base/memory/weak_ptr.h"
19 #include "base/path_service.h"
20 #include "base/run_loop.h"
21 #include "base/single_thread_task_runner.h"
22 #include "base/strings/string_number_conversions.h"
23 #include "base/sys_byteorder.h"
24 #include "base/test/bind.h"
25 #include "base/threading/thread_task_runner_handle.h"
26 #include "base/timer/mock_timer.h"
27 #include "build/build_config.h"
28 #include "components/cast_channel/cast_auth_util.h"
29 #include "components/cast_channel/cast_framer.h"
30 #include "components/cast_channel/cast_message_util.h"
31 #include "components/cast_channel/cast_test_util.h"
32 #include "components/cast_channel/cast_transport.h"
33 #include "components/cast_channel/logger.h"
34 #include "content/public/test/browser_task_environment.h"
35 #include "crypto/rsa_private_key.h"
36 #include "mojo/public/cpp/bindings/remote.h"
37 #include "net/base/address_list.h"
38 #include "net/base/net_errors.h"
39 #include "net/cert/pem.h"
40 #include "net/socket/client_socket_factory.h"
41 #include "net/socket/socket_test_util.h"
42 #include "net/socket/ssl_client_socket.h"
43 #include "net/socket/ssl_server_socket.h"
44 #include "net/socket/tcp_client_socket.h"
45 #include "net/socket/tcp_server_socket.h"
46 #include "net/ssl/ssl_info.h"
47 #include "net/ssl/ssl_server_config.h"
48 #include "net/test/cert_test_util.h"
49 #include "net/test/test_data_directory.h"
50 #include "net/traffic_annotation/network_traffic_annotation_test_helper.h"
51 #include "net/url_request/url_request_test_util.h"
52 #include "services/network/network_context.h"
53 #include "testing/gmock/include/gmock/gmock.h"
54 #include "testing/gtest/include/gtest/gtest.h"
55 #include "third_party/openscreen/src/cast/common/channel/proto/cast_channel.pb.h"
56 
// Timeout, in milliseconds, that the test fixture passes to
// CastSocketOpenParams. 100 seconds, so it is never hit.
constexpr int64_t kDistantTimeoutMillis = 100000;
58 
59 using ::testing::A;
60 using ::testing::DoAll;
61 using ::testing::Invoke;
62 using ::testing::InvokeArgument;
63 using ::testing::NotNull;
64 using ::testing::Return;
65 using ::testing::SaveArg;
66 using ::testing::_;
67 
68 using ::cast::channel::CastMessage;
69 
70 namespace cast_channel {
71 namespace {
// Namespace that CreateAuthReply() sets on the message it builds.
constexpr char kAuthNamespace[] = "urn:x-cast:com.google.cast.tp.deviceauth";
73 
74 // Returns an auth challenge message inline.
CreateAuthChallenge()75 CastMessage CreateAuthChallenge() {
76   CastMessage output;
77   CreateAuthChallengeMessage(&output, AuthContext::Create());
78   return output;
79 }
80 
81 // Returns an auth challenge response message inline.
CreateAuthReply()82 CastMessage CreateAuthReply() {
83   CastMessage output;
84   output.set_protocol_version(CastMessage::CASTV2_1_0);
85   output.set_source_id("sender-0");
86   output.set_destination_id("receiver-0");
87   output.set_payload_type(CastMessage::BINARY);
88   output.set_payload_binary("abcd");
89   output.set_namespace_(kAuthNamespace);
90   return output;
91 }
92 
CreateTestMessage()93 CastMessage CreateTestMessage() {
94   CastMessage test_message;
95   test_message.set_protocol_version(CastMessage::CASTV2_1_0);
96   test_message.set_namespace_("ns");
97   test_message.set_source_id("source");
98   test_message.set_destination_id("dest");
99   test_message.set_payload_type(CastMessage::STRING);
100   test_message.set_payload_utf8("payload");
101   return test_message;
102 }
103 
GetTestCertsDirectory()104 base::FilePath GetTestCertsDirectory() {
105   base::FilePath path;
106   base::PathService::Get(base::DIR_SOURCE_ROOT, &path);
107   path = path.Append(FILE_PATH_LITERAL("components"));
108   path = path.Append(FILE_PATH_LITERAL("test"));
109   path = path.Append(FILE_PATH_LITERAL("data"));
110   path = path.Append(FILE_PATH_LITERAL("cast_channel"));
111   return path;
112 }
113 
114 class MockTCPSocket : public net::MockTCPClientSocket {
115  public:
MockTCPSocket(bool do_nothing,net::SocketDataProvider * socket_provider)116   MockTCPSocket(bool do_nothing, net::SocketDataProvider* socket_provider)
117       : net::MockTCPClientSocket(net::AddressList(), nullptr, socket_provider) {
118     do_nothing_ = do_nothing;
119     set_enable_read_if_ready(true);
120   }
121 
Connect(net::CompletionOnceCallback callback)122   int Connect(net::CompletionOnceCallback callback) override {
123     if (do_nothing_) {
124       // Stall the I/O event loop.
125       return net::ERR_IO_PENDING;
126     }
127     return net::MockTCPClientSocket::Connect(std::move(callback));
128   }
129 
130  private:
131   bool do_nothing_;
132 
133   DISALLOW_COPY_AND_ASSIGN(MockTCPSocket);
134 };
135 
136 class CompleteHandler {
137  public:
CompleteHandler()138   CompleteHandler() {}
139   MOCK_METHOD1(OnCloseComplete, void(int result));
140   MOCK_METHOD1(OnConnectComplete, void(CastSocket* socket));
141   MOCK_METHOD1(OnWriteComplete, void(int result));
142   MOCK_METHOD1(OnReadComplete, void(int result));
143 
144  private:
145   DISALLOW_COPY_AND_ASSIGN(CompleteHandler);
146 };
147 
148 class TestCastSocketBase : public CastSocketImpl {
149  public:
TestCastSocketBase(network::mojom::NetworkContext * network_context,const CastSocketOpenParams & open_params,Logger * logger)150   TestCastSocketBase(network::mojom::NetworkContext* network_context,
151                      const CastSocketOpenParams& open_params,
152                      Logger* logger)
153       : CastSocketImpl(base::BindRepeating(
154                            [](network::mojom::NetworkContext* network_context) {
155                              return network_context;
156                            },
157                            network_context),
158                        open_params,
159                        logger,
160                        AuthContext::Create()),
161         verify_challenge_result_(true),
162         verify_challenge_disallow_(false),
163         mock_timer_(new base::MockOneShotTimer()) {
164     SetPeerCertForTesting(
165         net::ImportCertFromFile(GetTestCertsDirectory(), "self_signed.pem"));
166   }
~TestCastSocketBase()167   ~TestCastSocketBase() override {}
168 
SetVerifyChallengeResult(bool value)169   void SetVerifyChallengeResult(bool value) {
170     verify_challenge_result_ = value;
171   }
172 
TriggerTimeout()173   void TriggerTimeout() { mock_timer_->Fire(); }
174 
TestVerifyChannelPolicyNone()175   bool TestVerifyChannelPolicyNone() {
176     AuthResult authResult;
177     return VerifyChannelPolicy(authResult);
178   }
179 
DisallowVerifyChallengeResult()180   void DisallowVerifyChallengeResult() { verify_challenge_disallow_ = true; }
181 
182  protected:
VerifyChallengeReply()183   bool VerifyChallengeReply() override {
184     EXPECT_FALSE(verify_challenge_disallow_);
185     return verify_challenge_result_;
186   }
187 
GetTimer()188   base::OneShotTimer* GetTimer() override { return mock_timer_.get(); }
189 
190   // Simulated result of verifying challenge reply.
191   bool verify_challenge_result_;
192   bool verify_challenge_disallow_;
193   std::unique_ptr<base::MockOneShotTimer> mock_timer_;
194 
195  private:
196   DISALLOW_COPY_AND_ASSIGN(TestCastSocketBase);
197 };
198 
199 class MockTestCastSocket : public TestCastSocketBase {
200  public:
CreateSecure(network::mojom::NetworkContext * network_context,const CastSocketOpenParams & open_params,Logger * logger)201   static std::unique_ptr<MockTestCastSocket> CreateSecure(
202       network::mojom::NetworkContext* network_context,
203       const CastSocketOpenParams& open_params,
204       Logger* logger) {
205     return std::unique_ptr<MockTestCastSocket>(
206         new MockTestCastSocket(network_context, open_params, logger));
207   }
208 
209   using TestCastSocketBase::TestCastSocketBase;
210 
MockTestCastSocket(network::mojom::NetworkContext * network_context,const CastSocketOpenParams & open_params,Logger * logger)211   MockTestCastSocket(network::mojom::NetworkContext* network_context,
212                      const CastSocketOpenParams& open_params,
213                      Logger* logger)
214       : TestCastSocketBase(network_context, open_params, logger) {}
215 
~MockTestCastSocket()216   ~MockTestCastSocket() override {}
217 
SetupMockTransport()218   void SetupMockTransport() {
219     mock_transport_ = new MockCastTransport;
220     SetTransportForTesting(base::WrapUnique(mock_transport_));
221   }
222 
TestVerifyChannelPolicyAudioOnly()223   bool TestVerifyChannelPolicyAudioOnly() {
224     AuthResult authResult;
225     authResult.channel_policies |= AuthResult::POLICY_AUDIO_ONLY;
226     return VerifyChannelPolicy(authResult);
227   }
228 
GetMockTransport()229   MockCastTransport* GetMockTransport() {
230     CHECK(mock_transport_);
231     return mock_transport_;
232   }
233 
234  private:
235   MockCastTransport* mock_transport_ = nullptr;
236 
237   DISALLOW_COPY_AND_ASSIGN(MockTestCastSocket);
238 };
239 
240 // TODO(https://crbug.com/928467):  Remove this class.
241 class TestSocketFactory : public net::ClientSocketFactory {
242  public:
TestSocketFactory(net::IPEndPoint ip)243   explicit TestSocketFactory(net::IPEndPoint ip) : ip_(ip) {}
244   ~TestSocketFactory() override = default;
245 
246   // Socket connection helpers.
SetupTcpConnect(net::IoMode mode,int result)247   void SetupTcpConnect(net::IoMode mode, int result) {
248     tcp_connect_data_.reset(new net::MockConnect(mode, result, ip_));
249   }
SetupSslConnect(net::IoMode mode,int result)250   void SetupSslConnect(net::IoMode mode, int result) {
251     ssl_connect_data_.reset(new net::MockConnect(mode, result, ip_));
252   }
253 
254   // Socket I/O helpers.
AddWriteResult(const net::MockWrite & write)255   void AddWriteResult(const net::MockWrite& write) { writes_.push_back(write); }
AddWriteResult(net::IoMode mode,int result)256   void AddWriteResult(net::IoMode mode, int result) {
257     AddWriteResult(net::MockWrite(mode, result));
258   }
AddWriteResultForData(net::IoMode mode,const std::string & msg)259   void AddWriteResultForData(net::IoMode mode, const std::string& msg) {
260     AddWriteResult(mode, msg.size());
261   }
AddReadResult(const net::MockRead & read)262   void AddReadResult(const net::MockRead& read) { reads_.push_back(read); }
AddReadResult(net::IoMode mode,int result)263   void AddReadResult(net::IoMode mode, int result) {
264     AddReadResult(net::MockRead(mode, result));
265   }
AddReadResultForData(net::IoMode mode,const std::string & data)266   void AddReadResultForData(net::IoMode mode, const std::string& data) {
267     AddReadResult(net::MockRead(mode, data.c_str(), data.size()));
268   }
269 
270   // Helpers for modifying other connection-related behaviors.
SetupTcpConnectUnresponsive()271   void SetupTcpConnectUnresponsive() { tcp_unresponsive_ = true; }
272 
SetTcpSocket(std::unique_ptr<net::TransportClientSocket> tcp_client_socket)273   void SetTcpSocket(
274       std::unique_ptr<net::TransportClientSocket> tcp_client_socket) {
275     tcp_client_socket_ = std::move(tcp_client_socket);
276   }
277 
SetTLSSocketCreatedClosure(base::OnceClosure closure)278   void SetTLSSocketCreatedClosure(base::OnceClosure closure) {
279     tls_socket_created_ = std::move(closure);
280   }
281 
Pause()282   void Pause() {
283     if (socket_data_provider_)
284       socket_data_provider_->Pause();
285     else
286       socket_data_provider_paused_ = true;
287   }
288 
Resume()289   void Resume() { socket_data_provider_->Resume(); }
290 
291  private:
CreateDatagramClientSocket(net::DatagramSocket::BindType,net::NetLog *,const net::NetLogSource &)292   std::unique_ptr<net::DatagramClientSocket> CreateDatagramClientSocket(
293       net::DatagramSocket::BindType,
294       net::NetLog*,
295       const net::NetLogSource&) override {
296     NOTIMPLEMENTED();
297     return nullptr;
298   }
CreateTransportClientSocket(const net::AddressList &,std::unique_ptr<net::SocketPerformanceWatcher>,net::NetworkQualityEstimator *,net::NetLog *,const net::NetLogSource &)299   std::unique_ptr<net::TransportClientSocket> CreateTransportClientSocket(
300       const net::AddressList&,
301       std::unique_ptr<net::SocketPerformanceWatcher>,
302       net::NetworkQualityEstimator*,
303       net::NetLog*,
304       const net::NetLogSource&) override {
305     if (tcp_client_socket_)
306       return std::move(tcp_client_socket_);
307 
308     if (tcp_unresponsive_) {
309       socket_data_provider_ = std::make_unique<net::StaticSocketDataProvider>();
310       return std::unique_ptr<net::TransportClientSocket>(
311           new MockTCPSocket(true, socket_data_provider_.get()));
312     } else {
313       socket_data_provider_ =
314           std::make_unique<net::StaticSocketDataProvider>(reads_, writes_);
315       socket_data_provider_->set_connect_data(*tcp_connect_data_);
316       if (socket_data_provider_paused_)
317         socket_data_provider_->Pause();
318       return std::unique_ptr<net::TransportClientSocket>(
319           new MockTCPSocket(false, socket_data_provider_.get()));
320     }
321   }
CreateSSLClientSocket(net::SSLClientContext * context,std::unique_ptr<net::StreamSocket> nested_socket,const net::HostPortPair & host_and_port,const net::SSLConfig & ssl_config)322   std::unique_ptr<net::SSLClientSocket> CreateSSLClientSocket(
323       net::SSLClientContext* context,
324       std::unique_ptr<net::StreamSocket> nested_socket,
325       const net::HostPortPair& host_and_port,
326       const net::SSLConfig& ssl_config) override {
327     if (!ssl_connect_data_) {
328       // Test isn't overriding SSL socket creation.
329       return net::ClientSocketFactory::GetDefaultFactory()
330           ->CreateSSLClientSocket(context, std::move(nested_socket),
331                                   host_and_port, ssl_config);
332     }
333     ssl_socket_data_provider_ = std::make_unique<net::SSLSocketDataProvider>(
334         ssl_connect_data_->mode, ssl_connect_data_->result);
335 
336     if (tls_socket_created_)
337       std::move(tls_socket_created_).Run();
338 
339     return std::make_unique<net::MockSSLClientSocket>(
340         std::move(nested_socket), net::HostPortPair(), net::SSLConfig(),
341         ssl_socket_data_provider_.get());
342   }
CreateProxyClientSocket(std::unique_ptr<net::StreamSocket> stream_socket,const std::string & user_agent,const net::HostPortPair & endpoint,const net::ProxyServer & proxy_server,net::HttpAuthController * http_auth_controller,bool tunnel,bool using_spdy,net::NextProto negotiated_protocol,net::ProxyDelegate * proxy_delegate,const net::NetworkTrafficAnnotationTag & traffic_annotation)343   std::unique_ptr<net::ProxyClientSocket> CreateProxyClientSocket(
344       std::unique_ptr<net::StreamSocket> stream_socket,
345       const std::string& user_agent,
346       const net::HostPortPair& endpoint,
347       const net::ProxyServer& proxy_server,
348       net::HttpAuthController* http_auth_controller,
349       bool tunnel,
350       bool using_spdy,
351       net::NextProto negotiated_protocol,
352       net::ProxyDelegate* proxy_delegate,
353       const net::NetworkTrafficAnnotationTag& traffic_annotation) override {
354     NOTIMPLEMENTED();
355     return nullptr;
356   }
357 
358   net::IPEndPoint ip_;
359   // Simulated connect data
360   std::unique_ptr<net::MockConnect> tcp_connect_data_;
361   std::unique_ptr<net::MockConnect> ssl_connect_data_;
362   // Simulated read / write data
363   std::vector<net::MockWrite> writes_;
364   std::vector<net::MockRead> reads_;
365   std::unique_ptr<net::StaticSocketDataProvider> socket_data_provider_;
366   std::unique_ptr<net::SSLSocketDataProvider> ssl_socket_data_provider_;
367   bool socket_data_provider_paused_ = false;
368   // If true, makes TCP connection process stall. For timeout testing.
369   bool tcp_unresponsive_ = false;
370   std::unique_ptr<net::TransportClientSocket> tcp_client_socket_;
371   base::OnceClosure tls_socket_created_;
372 
373   DISALLOW_COPY_AND_ASSIGN(TestSocketFactory);
374 };
375 
376 class CastSocketTestBase : public testing::Test {
377  protected:
CastSocketTestBase()378   CastSocketTestBase()
379       : task_environment_(content::BrowserTaskEnvironment::IO_MAINLOOP),
380         url_request_context_(true),
381         logger_(new Logger()),
382         observer_(new MockCastSocketObserver()),
383         socket_open_params_(
384             CreateIPEndPointForTest(),
385             base::TimeDelta::FromMilliseconds(kDistantTimeoutMillis)),
386         client_socket_factory_(socket_open_params_.ip_endpoint) {}
~CastSocketTestBase()387   ~CastSocketTestBase() override {}
388 
SetUp()389   void SetUp() override {
390     EXPECT_CALL(*observer_, OnMessage(_, _)).Times(0);
391 
392     url_request_context_.set_client_socket_factory(&client_socket_factory_);
393     url_request_context_.Init();
394     network_context_ = std::make_unique<network::NetworkContext>(
395         nullptr, network_context_remote_.BindNewPipeAndPassReceiver(),
396         &url_request_context_,
397         /*cors_exempt_header_list=*/std::vector<std::string>());
398   }
399 
400   // Runs all pending tasks in the message loop.
RunPendingTasks()401   void RunPendingTasks() {
402     base::RunLoop run_loop;
403     run_loop.RunUntilIdle();
404   }
405 
client_socket_factory()406   TestSocketFactory* client_socket_factory() { return &client_socket_factory_; }
407 
408   content::BrowserTaskEnvironment task_environment_;
409   net::TestURLRequestContext url_request_context_;
410   std::unique_ptr<network::NetworkContext> network_context_;
411   mojo::Remote<network::mojom::NetworkContext> network_context_remote_;
412   Logger* logger_;
413   CompleteHandler handler_;
414   std::unique_ptr<MockCastSocketObserver> observer_;
415   CastSocketOpenParams socket_open_params_;
416   TestSocketFactory client_socket_factory_;
417 
418  private:
419   DISALLOW_COPY_AND_ASSIGN(CastSocketTestBase);
420 };
421 
422 class MockCastSocketTest : public CastSocketTestBase {
423  protected:
MockCastSocketTest()424   MockCastSocketTest() {}
425 
TearDown()426   void TearDown() override {
427     if (socket_) {
428       EXPECT_CALL(handler_, OnCloseComplete(net::OK));
429       socket_->Close(base::BindOnce(&CompleteHandler::OnCloseComplete,
430                                     base::Unretained(&handler_)));
431     }
432   }
433 
CreateCastSocketSecure()434   void CreateCastSocketSecure() {
435     socket_ = MockTestCastSocket::CreateSecure(network_context_.get(),
436                                                socket_open_params_, logger_);
437   }
438 
HandleAuthHandshake()439   void HandleAuthHandshake() {
440     socket_->SetupMockTransport();
441     CastMessage challenge_proto = CreateAuthChallenge();
442     EXPECT_CALL(*socket_->GetMockTransport(),
443                 SendMessage(EqualsProto(challenge_proto), _))
444         .WillOnce(PostCompletionCallbackTask<1>(net::OK));
445     EXPECT_CALL(*socket_->GetMockTransport(), Start());
446     EXPECT_CALL(handler_, OnConnectComplete(socket_.get()));
447     socket_->AddObserver(observer_.get());
448     socket_->Connect(base::BindOnce(&CompleteHandler::OnConnectComplete,
449                                     base::Unretained(&handler_)));
450     RunPendingTasks();
451     socket_->GetMockTransport()->current_delegate()->OnMessage(
452         CreateAuthReply());
453     RunPendingTasks();
454   }
455 
456   std::unique_ptr<MockTestCastSocket> socket_;
457 
458  private:
459   DISALLOW_COPY_AND_ASSIGN(MockCastSocketTest);
460 };
461 
462 class SslCastSocketTest : public CastSocketTestBase {
463  protected:
SslCastSocketTest()464   SslCastSocketTest() {}
465 
TearDown()466   void TearDown() override {
467     if (socket_) {
468       EXPECT_CALL(handler_, OnCloseComplete(net::OK));
469       socket_->Close(base::BindOnce(&CompleteHandler::OnCloseComplete,
470                                     base::Unretained(&handler_)));
471     }
472   }
473 
CreateSockets()474   void CreateSockets() {
475     socket_ = std::make_unique<TestCastSocketBase>(
476         network_context_.get(), socket_open_params_, logger_);
477 
478     server_cert_ =
479         net::ImportCertFromFile(GetTestCertsDirectory(), "self_signed.pem");
480     ASSERT_TRUE(server_cert_);
481     server_private_key_ = ReadTestKeyFromPEM("self_signed.pem");
482     ASSERT_TRUE(server_private_key_);
483     server_context_ = CreateSSLServerContext(
484         server_cert_.get(), *server_private_key_, server_ssl_config_);
485 
486     tcp_server_socket_.reset(
487         new net::TCPServerSocket(nullptr, net::NetLogSource()));
488     ASSERT_EQ(net::OK,
489               tcp_server_socket_->ListenWithAddressAndPort("127.0.0.1", 0, 1));
490     net::IPEndPoint server_address;
491     ASSERT_EQ(net::OK, tcp_server_socket_->GetLocalAddress(&server_address));
492     tcp_client_socket_.reset(
493         new net::TCPClientSocket(net::AddressList(server_address), nullptr,
494                                  nullptr, nullptr, net::NetLogSource()));
495 
496     std::unique_ptr<net::StreamSocket> accepted_socket;
497     accept_result_ = tcp_server_socket_->Accept(
498         &accepted_socket, base::BindOnce(&SslCastSocketTest::TcpAcceptCallback,
499                                          base::Unretained(this)));
500     connect_result_ = tcp_client_socket_->Connect(base::BindOnce(
501         &SslCastSocketTest::TcpConnectCallback, base::Unretained(this)));
502     while (accept_result_ == net::ERR_IO_PENDING ||
503            connect_result_ == net::ERR_IO_PENDING) {
504       RunPendingTasks();
505     }
506     ASSERT_EQ(net::OK, accept_result_);
507     ASSERT_EQ(net::OK, connect_result_);
508     ASSERT_TRUE(accepted_socket);
509     ASSERT_TRUE(tcp_client_socket_->IsConnected());
510 
511     server_socket_ =
512         server_context_->CreateSSLServerSocket(std::move(accepted_socket));
513     ASSERT_TRUE(server_socket_);
514 
515     client_socket_factory()->SetTcpSocket(std::move(tcp_client_socket_));
516   }
517 
ConnectSockets()518   void ConnectSockets() {
519     socket_->AddObserver(observer_.get());
520     socket_->Connect(base::BindOnce(&CompleteHandler::OnConnectComplete,
521                                     base::Unretained(&handler_)));
522 
523     net::TestCompletionCallback handshake_callback;
524     int server_ret = handshake_callback.GetResult(
525         server_socket_->Handshake(handshake_callback.callback()));
526 
527     ASSERT_EQ(net::OK, server_ret);
528   }
529 
TcpAcceptCallback(int result)530   void TcpAcceptCallback(int result) { accept_result_ = result; }
531 
TcpConnectCallback(int result)532   void TcpConnectCallback(int result) { connect_result_ = result; }
533 
ReadTestKeyFromPEM(const base::StringPiece & name)534   std::unique_ptr<crypto::RSAPrivateKey> ReadTestKeyFromPEM(
535       const base::StringPiece& name) {
536     base::FilePath key_path = GetTestCertsDirectory().AppendASCII(name);
537     std::string pem_data;
538     if (!base::ReadFileToString(key_path, &pem_data)) {
539       return nullptr;
540     }
541 
542     const std::vector<std::string> headers({"PRIVATE KEY"});
543     net::PEMTokenizer pem_tokenizer(pem_data, headers);
544     if (!pem_tokenizer.GetNext()) {
545       return nullptr;
546     }
547     std::vector<uint8_t> key_vector(pem_tokenizer.data().begin(),
548                                     pem_tokenizer.data().end());
549     std::unique_ptr<crypto::RSAPrivateKey> key(
550         crypto::RSAPrivateKey::CreateFromPrivateKeyInfo(key_vector));
551     return key;
552   }
553 
ReadExactLength(net::IOBuffer * buffer,int buffer_length,net::Socket * socket)554   int ReadExactLength(net::IOBuffer* buffer,
555                       int buffer_length,
556                       net::Socket* socket) {
557     scoped_refptr<net::DrainableIOBuffer> draining_buffer =
558         base::MakeRefCounted<net::DrainableIOBuffer>(buffer, buffer_length);
559     while (draining_buffer->BytesRemaining() > 0) {
560       net::TestCompletionCallback read_callback;
561       int read_result = read_callback.GetResult(server_socket_->Read(
562           draining_buffer.get(), draining_buffer->BytesRemaining(),
563           read_callback.callback()));
564       EXPECT_GT(read_result, 0);
565       draining_buffer->DidConsume(read_result);
566     }
567     return buffer_length;
568   }
569 
WriteExactLength(net::IOBuffer * buffer,int buffer_length,net::Socket * socket)570   int WriteExactLength(net::IOBuffer* buffer,
571                        int buffer_length,
572                        net::Socket* socket) {
573     scoped_refptr<net::DrainableIOBuffer> draining_buffer =
574         base::MakeRefCounted<net::DrainableIOBuffer>(buffer, buffer_length);
575     while (draining_buffer->BytesRemaining() > 0) {
576       net::TestCompletionCallback write_callback;
577       int write_result = write_callback.GetResult(server_socket_->Write(
578           draining_buffer.get(), draining_buffer->BytesRemaining(),
579           write_callback.callback(), TRAFFIC_ANNOTATION_FOR_TESTS));
580       EXPECT_GT(write_result, 0);
581       draining_buffer->DidConsume(write_result);
582     }
583     return buffer_length;
584   }
585 
586   // Result values used for TCP socket setup.  These should contain values from
587   // net::Error.
588   int accept_result_;
589   int connect_result_;
590 
591   // Underlying TCP sockets for |socket_| to communicate with |server_socket_|
592   // when testing with the real SSL implementation.
593   std::unique_ptr<net::TransportClientSocket> tcp_client_socket_;
594   std::unique_ptr<net::TCPServerSocket> tcp_server_socket_;
595 
596   std::unique_ptr<TestCastSocketBase> socket_;
597 
598   // |server_socket_| is used for the *RealSSL tests in order to test the
599   // CastSocket over a real SSL socket.  The other members below are used to
600   // initialize |server_socket_|.
601   std::unique_ptr<net::SSLServerSocket> server_socket_;
602   std::unique_ptr<net::SSLServerContext> server_context_;
603   std::unique_ptr<crypto::RSAPrivateKey> server_private_key_;
604   scoped_refptr<net::X509Certificate> server_cert_;
605   net::SSLServerConfig server_ssl_config_;
606 
607  private:
608   DISALLOW_COPY_AND_ASSIGN(SslCastSocketTest);
609 };
610 
611 }  // namespace
612 
613 // Tests that the following connection flow works:
614 // - TCP connection succeeds (async)
615 // - SSL connection succeeds (async)
616 // - Cert is extracted successfully
617 // - Challenge request is sent (async)
618 // - Challenge response is received (async)
619 // - Credentials are verified successfuly
TEST_F(MockCastSocketTest,TestConnectFullSecureFlowAsync)620 TEST_F(MockCastSocketTest, TestConnectFullSecureFlowAsync) {
621   CreateCastSocketSecure();
622   client_socket_factory()->SetupTcpConnect(net::ASYNC, net::OK);
623   client_socket_factory()->SetupSslConnect(net::ASYNC, net::OK);
624 
625   HandleAuthHandshake();
626 
627   EXPECT_EQ(ReadyState::OPEN, socket_->ready_state());
628   EXPECT_EQ(ChannelError::NONE, socket_->error_state());
629 }
630 
631 // Tests that the following connection flow works:
632 // - TCP connection succeeds (sync)
633 // - SSL connection succeeds (sync)
634 // - Cert is extracted successfully
635 // - Challenge request is sent (sync)
636 // - Challenge response is received (sync)
637 // - Credentials are verified successfuly
TEST_F(MockCastSocketTest,TestConnectFullSecureFlowSync)638 TEST_F(MockCastSocketTest, TestConnectFullSecureFlowSync) {
639   client_socket_factory()->SetupTcpConnect(net::SYNCHRONOUS, net::OK);
640   client_socket_factory()->SetupSslConnect(net::SYNCHRONOUS, net::OK);
641 
642   CreateCastSocketSecure();
643   HandleAuthHandshake();
644 
645   EXPECT_EQ(ReadyState::OPEN, socket_->ready_state());
646   EXPECT_EQ(ChannelError::NONE, socket_->error_state());
647 }
648 
649 // Test that an AuthMessage with a mangled namespace triggers cancelation
650 // of the connection event loop.
TEST_F(MockCastSocketTest,TestConnectAuthMessageCorrupted)651 TEST_F(MockCastSocketTest, TestConnectAuthMessageCorrupted) {
652   CreateCastSocketSecure();
653   socket_->SetupMockTransport();
654 
655   client_socket_factory()->SetupTcpConnect(net::ASYNC, net::OK);
656   client_socket_factory()->SetupSslConnect(net::ASYNC, net::OK);
657 
658   CastMessage challenge_proto = CreateAuthChallenge();
659   EXPECT_CALL(*socket_->GetMockTransport(),
660               SendMessage(EqualsProto(challenge_proto), _))
661       .WillOnce(PostCompletionCallbackTask<1>(net::OK));
662   EXPECT_CALL(*socket_->GetMockTransport(), Start());
663   EXPECT_CALL(handler_, OnConnectComplete(socket_.get()));
664   socket_->Connect(base::BindOnce(&CompleteHandler::OnConnectComplete,
665                                   base::Unretained(&handler_)));
666   RunPendingTasks();
667   CastMessage mangled_auth_reply = CreateAuthReply();
668   mangled_auth_reply.set_namespace_("BOGUS_NAMESPACE");
669 
670   socket_->GetMockTransport()->current_delegate()->OnMessage(
671       mangled_auth_reply);
672   RunPendingTasks();
673 
674   EXPECT_EQ(ReadyState::CLOSED, socket_->ready_state());
675   EXPECT_EQ(ChannelError::TRANSPORT_ERROR, socket_->error_state());
676 
677   // Verifies that the CastSocket's resources were torn down during channel
678   // close. (see http://crbug.com/504078)
679   EXPECT_EQ(nullptr, socket_->transport());
680 }
681 
682 // Test connection error - TCP connect fails (async)
TEST_F(MockCastSocketTest,TestConnectTcpConnectErrorAsync)683 TEST_F(MockCastSocketTest, TestConnectTcpConnectErrorAsync) {
684   CreateCastSocketSecure();
685 
686   client_socket_factory()->SetupTcpConnect(net::ASYNC, net::ERR_FAILED);
687 
688   EXPECT_CALL(handler_, OnConnectComplete(socket_.get()));
689   socket_->Connect(base::BindOnce(&CompleteHandler::OnConnectComplete,
690                                   base::Unretained(&handler_)));
691   RunPendingTasks();
692 
693   EXPECT_EQ(ReadyState::CLOSED, socket_->ready_state());
694   EXPECT_EQ(ChannelError::CONNECT_ERROR, socket_->error_state());
695 }
696 
697 // Test connection error - TCP connect fails (sync)
TEST_F(MockCastSocketTest,TestConnectTcpConnectErrorSync)698 TEST_F(MockCastSocketTest, TestConnectTcpConnectErrorSync) {
699   CreateCastSocketSecure();
700 
701   client_socket_factory()->SetupTcpConnect(net::SYNCHRONOUS, net::ERR_FAILED);
702 
703   EXPECT_CALL(handler_, OnConnectComplete(socket_.get()));
704   socket_->Connect(base::BindOnce(&CompleteHandler::OnConnectComplete,
705                                   base::Unretained(&handler_)));
706   RunPendingTasks();
707 
708   EXPECT_EQ(ReadyState::CLOSED, socket_->ready_state());
709   EXPECT_EQ(ChannelError::CONNECT_ERROR, socket_->error_state());
710 }
711 
712 // Test connection error - timeout
TEST_F(MockCastSocketTest, TestConnectTcpTimeoutError) {
  CreateCastSocketSecure();
  client_socket_factory()->SetupTcpConnectUnresponsive();

  // Both the connect callback and the observer's timeout notification are
  // expected exactly once over the course of the test.
  EXPECT_CALL(handler_, OnConnectComplete(socket_.get()));
  EXPECT_CALL(*observer_, OnError(_, ChannelError::CONNECT_TIMEOUT));
  socket_->AddObserver(observer_.get());

  auto on_connect_done = base::BindOnce(&CompleteHandler::OnConnectComplete,
                                        base::Unretained(&handler_));
  socket_->Connect(std::move(on_connect_done));
  RunPendingTasks();

  // Before the timeout is triggered the socket is still connecting and
  // reports no error.
  EXPECT_EQ(ReadyState::CONNECTING, socket_->ready_state());
  EXPECT_EQ(ChannelError::NONE, socket_->error_state());

  // Trigger the connect timeout and let the resulting tasks run.
  socket_->TriggerTimeout();
  RunPendingTasks();

  EXPECT_EQ(ReadyState::CLOSED, socket_->ready_state());
  EXPECT_EQ(ChannelError::CONNECT_TIMEOUT, socket_->error_state());
}
731 
732 // Test connection error - TCP socket returns timeout
TEST_F(MockCastSocketTest, TestConnectTcpSocketTimeoutError) {
  CreateCastSocketSecure();

  // The TCP layer itself reports a timeout, synchronously.
  client_socket_factory()->SetupTcpConnect(net::SYNCHRONOUS,
                                           net::ERR_CONNECTION_TIMED_OUT);

  EXPECT_CALL(handler_, OnConnectComplete(socket_.get()));
  EXPECT_CALL(*observer_, OnError(_, ChannelError::CONNECT_TIMEOUT));
  socket_->AddObserver(observer_.get());

  auto on_connect_done = base::BindOnce(&CompleteHandler::OnConnectComplete,
                                        base::Unretained(&handler_));
  socket_->Connect(std::move(on_connect_done));
  RunPendingTasks();

  // Expect a closed socket reporting CONNECT_TIMEOUT, and the logger's last
  // recorded net return value for this socket to be the timeout code.
  EXPECT_EQ(ReadyState::CLOSED, socket_->ready_state());
  EXPECT_EQ(ChannelError::CONNECT_TIMEOUT, socket_->error_state());
  const auto last_error = logger_->GetLastError(socket_->id());
  EXPECT_EQ(net::ERR_CONNECTION_TIMED_OUT, last_error.net_return_value);
}
749 
750 // Test connection error - SSL connect fails (async)
TEST_F(MockCastSocketTest, TestConnectSslConnectErrorAsync) {
  CreateCastSocketSecure();

  // Both connection stages complete asynchronously so that this test really
  // exercises the async SSL failure path its name describes; the synchronous
  // failure is covered by TestConnectSslConnectErrorSync. This mirrors the
  // setup used by TestConnectSslConnectTimeoutAsync.
  client_socket_factory()->SetupTcpConnect(net::ASYNC, net::OK);
  client_socket_factory()->SetupSslConnect(net::ASYNC, net::ERR_FAILED);

  // The completion handler is expected to be invoked once for this socket.
  EXPECT_CALL(handler_, OnConnectComplete(socket_.get()));
  socket_->Connect(base::BindOnce(&CompleteHandler::OnConnectComplete,
                                  base::Unretained(&handler_)));
  RunPendingTasks();

  // Expect a closed socket reporting an authentication error.
  EXPECT_EQ(ReadyState::CLOSED, socket_->ready_state());
  EXPECT_EQ(ChannelError::AUTHENTICATION_ERROR, socket_->error_state());
}
765 
766 // Test connection error - SSL connect fails (sync)
TEST_F(MockCastSocketTest, TestConnectSslConnectErrorSync) {
  CreateCastSocketSecure();

  // TCP connects successfully; the SSL connect then fails synchronously.
  client_socket_factory()->SetupTcpConnect(net::SYNCHRONOUS, net::OK);
  client_socket_factory()->SetupSslConnect(net::SYNCHRONOUS, net::ERR_FAILED);

  EXPECT_CALL(handler_, OnConnectComplete(socket_.get()));
  auto on_connect_done = base::BindOnce(&CompleteHandler::OnConnectComplete,
                                        base::Unretained(&handler_));
  socket_->Connect(std::move(on_connect_done));
  RunPendingTasks();

  // Expect a closed socket reporting an authentication error, with the
  // logger's last recorded net return value being net::ERR_FAILED.
  EXPECT_EQ(ReadyState::CLOSED, socket_->ready_state());
  EXPECT_EQ(ChannelError::AUTHENTICATION_ERROR, socket_->error_state());
  const auto last_error = logger_->GetLastError(socket_->id());
  EXPECT_EQ(net::ERR_FAILED, last_error.net_return_value);
}
783 
784 // Test connection error - SSL connect times out (sync)
TEST_F(MockCastSocketTest, TestConnectSslConnectTimeoutSync) {
  // The factory is configured before the socket is created: TCP connects
  // successfully, then the SSL connect synchronously reports a timeout.
  client_socket_factory()->SetupTcpConnect(net::SYNCHRONOUS, net::OK);
  client_socket_factory()->SetupSslConnect(net::SYNCHRONOUS,
                                           net::ERR_CONNECTION_TIMED_OUT);

  CreateCastSocketSecure();

  EXPECT_CALL(handler_, OnConnectComplete(socket_.get()));
  auto on_connect_done = base::BindOnce(&CompleteHandler::OnConnectComplete,
                                        base::Unretained(&handler_));
  socket_->Connect(std::move(on_connect_done));
  RunPendingTasks();

  // Expect a closed socket reporting CONNECT_TIMEOUT, with the logger's last
  // recorded net return value being the timeout code.
  EXPECT_EQ(ReadyState::CLOSED, socket_->ready_state());
  EXPECT_EQ(ChannelError::CONNECT_TIMEOUT, socket_->error_state());
  const auto last_error = logger_->GetLastError(socket_->id());
  EXPECT_EQ(net::ERR_CONNECTION_TIMED_OUT, last_error.net_return_value);
}
802 
803 // Test connection error - SSL connect times out (async)
TEST_F(MockCastSocketTest, TestConnectSslConnectTimeoutAsync) {
  CreateCastSocketSecure();

  // TCP connects successfully and the SSL connect reports a timeout; both
  // results are delivered asynchronously.
  client_socket_factory()->SetupTcpConnect(net::ASYNC, net::OK);
  client_socket_factory()->SetupSslConnect(net::ASYNC,
                                           net::ERR_CONNECTION_TIMED_OUT);

  EXPECT_CALL(handler_, OnConnectComplete(socket_.get()));
  auto on_connect_done = base::BindOnce(&CompleteHandler::OnConnectComplete,
                                        base::Unretained(&handler_));
  socket_->Connect(std::move(on_connect_done));
  RunPendingTasks();

  // Expect a closed socket reporting CONNECT_TIMEOUT.
  EXPECT_EQ(ReadyState::CLOSED, socket_->ready_state());
  EXPECT_EQ(ChannelError::CONNECT_TIMEOUT, socket_->error_state());
}
819 
820 // Test connection error - challenge send fails
TEST_F(MockCastSocketTest, TestConnectChallengeSendError) {
  CreateCastSocketSecure();
  socket_->SetupMockTransport();

  // TCP and SSL both connect successfully and synchronously.
  client_socket_factory()->SetupTcpConnect(net::SYNCHRONOUS, net::OK);
  client_socket_factory()->SetupSslConnect(net::SYNCHRONOUS, net::OK);

  // Sending the auth challenge through the mock transport completes with
  // net::ERR_CONNECTION_RESET.
  auto& mock_transport = *socket_->GetMockTransport();
  EXPECT_CALL(mock_transport,
              SendMessage(EqualsProto(CreateAuthChallenge()), _))
      .WillOnce(PostCompletionCallbackTask<1>(net::ERR_CONNECTION_RESET));

  EXPECT_CALL(handler_, OnConnectComplete(socket_.get()));
  auto on_connect_done = base::BindOnce(&CompleteHandler::OnConnectComplete,
                                        base::Unretained(&handler_));
  socket_->Connect(std::move(on_connect_done));
  RunPendingTasks();

  // Expect a closed socket reporting CAST_SOCKET_ERROR.
  EXPECT_EQ(ReadyState::CLOSED, socket_->ready_state());
  EXPECT_EQ(ChannelError::CAST_SOCKET_ERROR, socket_->error_state());
}
839 
840 // Test connection error - connection is destroyed after the challenge is
841 // sent, with the async result still lurking in the task queue.
TEST_F(MockCastSocketTest, TestConnectDestroyedAfterChallengeSent) {
  CreateCastSocketSecure();
  socket_->SetupMockTransport();

  client_socket_factory()->SetupTcpConnect(net::SYNCHRONOUS, net::OK);
  client_socket_factory()->SetupSslConnect(net::SYNCHRONOUS, net::OK);

  // Sending the auth challenge through the mock transport completes with
  // net::ERR_CONNECTION_RESET.
  auto& mock_transport = *socket_->GetMockTransport();
  EXPECT_CALL(mock_transport,
              SendMessage(EqualsProto(CreateAuthChallenge()), _))
      .WillOnce(PostCompletionCallbackTask<1>(net::ERR_CONNECTION_RESET));

  auto on_connect_done = base::BindOnce(&CompleteHandler::OnConnectComplete,
                                        base::Unretained(&handler_));
  socket_->Connect(std::move(on_connect_done));
  RunPendingTasks();

  // Destroy the socket, then run the task queue again; the test makes no
  // further assertions after destruction.
  socket_.reset();
  RunPendingTasks();
}
856 
857 // Test connection error - challenge reply receive fails
TEST_F(MockCastSocketTest, TestConnectChallengeReplyReceiveError) {
  CreateCastSocketSecure();
  socket_->SetupMockTransport();

  client_socket_factory()->SetupTcpConnect(net::SYNCHRONOUS, net::OK);
  client_socket_factory()->SetupSslConnect(net::SYNCHRONOUS, net::OK);

  // The challenge is sent successfully through the mock transport.
  EXPECT_CALL(*socket_->GetMockTransport(),
              SendMessage(EqualsProto(CreateAuthChallenge()), _))
      .WillOnce(PostCompletionCallbackTask<1>(net::OK));

  // A synchronous read failure is queued on the socket factory.
  client_socket_factory()->AddReadResult(net::SYNCHRONOUS, net::ERR_FAILED);

  EXPECT_CALL(*observer_, OnError(_, ChannelError::CAST_SOCKET_ERROR));
  EXPECT_CALL(handler_, OnConnectComplete(socket_.get()));
  EXPECT_CALL(*socket_->GetMockTransport(), Start());
  socket_->AddObserver(observer_.get());

  auto on_connect_done = base::BindOnce(&CompleteHandler::OnConnectComplete,
                                        base::Unretained(&handler_));
  socket_->Connect(std::move(on_connect_done));
  RunPendingTasks();

  // Report a socket error through the transport's current delegate.
  auto* const delegate = socket_->GetMockTransport()->current_delegate();
  delegate->OnError(ChannelError::CAST_SOCKET_ERROR);
  RunPendingTasks();

  // Expect a closed socket reporting CAST_SOCKET_ERROR.
  EXPECT_EQ(ReadyState::CLOSED, socket_->ready_state());
  EXPECT_EQ(ChannelError::CAST_SOCKET_ERROR, socket_->error_state());
}
882 
// Test connection error - the auth challenge reply fails verification.
TEST_F(MockCastSocketTest, TestConnectChallengeVerificationFails) {
  CreateCastSocketSecure();
  socket_->SetupMockTransport();
  client_socket_factory()->SetupTcpConnect(net::ASYNC, net::OK);
  client_socket_factory()->SetupSslConnect(net::ASYNC, net::OK);

  // Configure the test socket so that challenge verification reports failure.
  socket_->SetVerifyChallengeResult(false);

  EXPECT_CALL(*observer_, OnError(_, ChannelError::AUTHENTICATION_ERROR));

  // The challenge itself is sent successfully through the mock transport.
  const CastMessage expected_challenge = CreateAuthChallenge();
  EXPECT_CALL(*socket_->GetMockTransport(),
              SendMessage(EqualsProto(expected_challenge), _))
      .WillOnce(PostCompletionCallbackTask<1>(net::OK));
  EXPECT_CALL(handler_, OnConnectComplete(socket_.get()));
  EXPECT_CALL(*socket_->GetMockTransport(), Start());
  socket_->AddObserver(observer_.get());

  auto on_connect_done = base::BindOnce(&CompleteHandler::OnConnectComplete,
                                        base::Unretained(&handler_));
  socket_->Connect(std::move(on_connect_done));
  RunPendingTasks();

  // Deliver the auth reply through the transport's current delegate.
  auto* const delegate = socket_->GetMockTransport()->current_delegate();
  delegate->OnMessage(CreateAuthReply());
  RunPendingTasks();

  // Expect a closed socket reporting an authentication error.
  EXPECT_EQ(ReadyState::CLOSED, socket_->ready_state());
  EXPECT_EQ(ChannelError::AUTHENTICATION_ERROR, socket_->error_state());
}
907 
908 // Sends message data through an actual non-mocked CastTransport object,
909 // testing the two components in integration.
TEST_F(MockCastSocketTest, TestConnectEndToEndWithRealTransportAsync) {
  CreateCastSocketSecure();
  client_socket_factory()->SetupTcpConnect(net::ASYNC, net::OK);
  client_socket_factory()->SetupSslConnect(net::ASYNC, net::OK);

  // The serialized auth challenge is the first expected low-level write.
  const CastMessage auth_challenge = CreateAuthChallenge();
  std::string serialized_challenge;
  EXPECT_TRUE(MessageFramer::Serialize(auth_challenge, &serialized_challenge));
  client_socket_factory()->AddWriteResultForData(net::ASYNC,
                                                 serialized_challenge);

  // The serialized auth reply is queued as read data, followed by a read
  // that stays pending.
  const CastMessage auth_reply = CreateAuthReply();
  std::string serialized_reply;
  EXPECT_TRUE(MessageFramer::Serialize(auth_reply, &serialized_reply));
  client_socket_factory()->AddReadResultForData(net::ASYNC, serialized_reply);
  client_socket_factory()->AddReadResult(net::ASYNC, net::ERR_IO_PENDING);

  // Make sure the data is read by the TLS socket and not the TCP socket: the
  // factory is paused and only resumed once the TLS socket has been created.
  client_socket_factory()->Pause();
  client_socket_factory()->SetTLSSocketCreatedClosure(
      base::BindLambdaForTesting([&] { client_socket_factory()->Resume(); }));

  // The serialized test message is the second expected low-level write.
  const CastMessage outgoing_message = CreateTestMessage();
  std::string serialized_outgoing;
  EXPECT_TRUE(MessageFramer::Serialize(outgoing_message, &serialized_outgoing));
  client_socket_factory()->AddWriteResultForData(net::ASYNC,
                                                 serialized_outgoing);

  EXPECT_CALL(handler_, OnConnectComplete(socket_.get()));
  auto on_connect_done = base::BindOnce(&CompleteHandler::OnConnectComplete,
                                        base::Unretained(&handler_));
  socket_->Connect(std::move(on_connect_done));
  RunPendingTasks();
  EXPECT_EQ(ReadyState::OPEN, socket_->ready_state());
  EXPECT_EQ(ChannelError::NONE, socket_->error_state());

  // Send the test message through a real transport object.
  EXPECT_CALL(handler_, OnWriteComplete(net::OK));
  auto on_write_done = base::BindOnce(&CompleteHandler::OnWriteComplete,
                                      base::Unretained(&handler_));
  socket_->transport()->SendMessage(outgoing_message,
                                    std::move(on_write_done));
  RunPendingTasks();

  // The socket is expected to stay open and error-free after the write.
  EXPECT_EQ(ReadyState::OPEN, socket_->ready_state());
  EXPECT_EQ(ChannelError::NONE, socket_->error_state());
}
954 
955 // Same as TestConnectEndToEndWithRealTransportAsync, except synchronous.
TEST_F(MockCastSocketTest, TestConnectEndToEndWithRealTransportSync) {
  CreateCastSocketSecure();
  client_socket_factory()->SetupTcpConnect(net::SYNCHRONOUS, net::OK);
  client_socket_factory()->SetupSslConnect(net::SYNCHRONOUS, net::OK);

  // The serialized auth challenge is the first expected low-level write.
  const CastMessage auth_challenge = CreateAuthChallenge();
  std::string serialized_challenge;
  EXPECT_TRUE(MessageFramer::Serialize(auth_challenge, &serialized_challenge));
  client_socket_factory()->AddWriteResultForData(net::SYNCHRONOUS,
                                                 serialized_challenge);

  // The serialized auth reply is queued as synchronous read data, followed
  // by a read that stays pending.
  const CastMessage auth_reply = CreateAuthReply();
  std::string serialized_reply;
  EXPECT_TRUE(MessageFramer::Serialize(auth_reply, &serialized_reply));
  client_socket_factory()->AddReadResultForData(net::SYNCHRONOUS,
                                                serialized_reply);
  client_socket_factory()->AddReadResult(net::ASYNC, net::ERR_IO_PENDING);

  // Make sure the data is read by the TLS socket and not the TCP socket: the
  // factory is paused and only resumed once the TLS socket has been created.
  client_socket_factory()->Pause();
  client_socket_factory()->SetTLSSocketCreatedClosure(
      base::BindLambdaForTesting([&] { client_socket_factory()->Resume(); }));

  // The serialized test message is the second expected low-level write.
  const CastMessage outgoing_message = CreateTestMessage();
  std::string serialized_outgoing;
  EXPECT_TRUE(MessageFramer::Serialize(outgoing_message, &serialized_outgoing));
  client_socket_factory()->AddWriteResultForData(net::SYNCHRONOUS,
                                                 serialized_outgoing);

  EXPECT_CALL(handler_, OnConnectComplete(socket_.get()));
  auto on_connect_done = base::BindOnce(&CompleteHandler::OnConnectComplete,
                                        base::Unretained(&handler_));
  socket_->Connect(std::move(on_connect_done));
  RunPendingTasks();
  EXPECT_EQ(ReadyState::OPEN, socket_->ready_state());
  EXPECT_EQ(ChannelError::NONE, socket_->error_state());

  // Send the test message through a real transport object.
  EXPECT_CALL(handler_, OnWriteComplete(net::OK));
  auto on_write_done = base::BindOnce(&CompleteHandler::OnWriteComplete,
                                      base::Unretained(&handler_));
  socket_->transport()->SendMessage(outgoing_message,
                                    std::move(on_write_done));
  RunPendingTasks();

  // The socket is expected to stay open and error-free after the write.
  EXPECT_EQ(ReadyState::OPEN, socket_->ready_state());
  EXPECT_EQ(ChannelError::NONE, socket_->error_state());
}
1002 
// Registers two observers (each one twice) and checks that an error reported
// through the message delegate reaches each observer exactly once.
TEST_F(MockCastSocketTest, TestObservers) {
  CreateCastSocketSecure();

  // Test AddObserver: every observer is registered twice.
  MockCastSocketObserver first_observer;
  MockCastSocketObserver second_observer;
  socket_->AddObserver(&first_observer);
  socket_->AddObserver(&first_observer);
  socket_->AddObserver(&second_observer);
  socket_->AddObserver(&second_observer);

  // Test observer notification: one OnError call is expected per observer.
  EXPECT_CALL(first_observer,
              OnError(_, cast_channel::ChannelError::CONNECT_ERROR));
  EXPECT_CALL(second_observer,
              OnError(_, cast_channel::ChannelError::CONNECT_ERROR));
  CastSocketImpl::CastSocketMessageDelegate message_delegate(socket_.get());
  message_delegate.OnError(cast_channel::ChannelError::CONNECT_ERROR);
}
1019 
// Calls Connect() a second time while the first attempt is still pending and
// expects both callbacks to run once the timeout is triggered.
TEST_F(MockCastSocketTest, TestOpenChannelConnectingSocket) {
  CreateCastSocketSecure();
  client_socket_factory()->SetupTcpConnectUnresponsive();

  // First attempt: the TCP connect is set up as unresponsive.
  auto first_callback = base::BindOnce(&CompleteHandler::OnConnectComplete,
                                       base::Unretained(&handler_));
  socket_->Connect(std::move(first_callback));
  RunPendingTasks();

  // Second attempt on the same socket; after the timeout is triggered the
  // handler is expected to have been called twice.
  EXPECT_CALL(handler_, OnConnectComplete(socket_.get())).Times(2);
  auto second_callback = base::BindOnce(&CompleteHandler::OnConnectComplete,
                                        base::Unretained(&handler_));
  socket_->Connect(std::move(second_callback));
  socket_->TriggerTimeout();
  RunPendingTasks();
}
1033 
// Calls Connect() again after the auth handshake helper has run and expects
// the completion callback to be invoked for that call.
TEST_F(MockCastSocketTest, TestOpenChannelConnectedSocket) {
  CreateCastSocketSecure();
  client_socket_factory()->SetupTcpConnect(net::ASYNC, net::OK);
  client_socket_factory()->SetupSslConnect(net::ASYNC, net::OK);
  HandleAuthHandshake();

  // No pending tasks are run after this Connect() call, so the expectation
  // has to be satisfied before the test body ends.
  EXPECT_CALL(handler_, OnConnectComplete(socket_.get()));
  auto on_connect_done = base::BindOnce(&CompleteHandler::OnConnectComplete,
                                        base::Unretained(&handler_));
  socket_->Connect(std::move(on_connect_done));
}
1044 
// Calls Connect() again after a first attempt that was set up to fail at the
// TCP stage, and expects the completion callback for the second call too.
TEST_F(MockCastSocketTest, TestOpenChannelClosedSocket) {
  CreateCastSocketSecure();
  client_socket_factory()->SetupTcpConnect(net::ASYNC, net::ERR_FAILED);

  // First attempt: the TCP connect fails asynchronously.
  EXPECT_CALL(handler_, OnConnectComplete(socket_.get()));
  auto first_callback = base::BindOnce(&CompleteHandler::OnConnectComplete,
                                       base::Unretained(&handler_));
  socket_->Connect(std::move(first_callback));
  RunPendingTasks();

  // Second attempt: no pending tasks are run afterwards, so the expectation
  // has to be satisfied before the test body ends.
  EXPECT_CALL(handler_, OnConnectComplete(socket_.get()));
  auto second_callback = base::BindOnce(&CompleteHandler::OnConnectComplete,
                                        base::Unretained(&handler_));
  socket_->Connect(std::move(second_callback));
}
1058 
1059 // https://crbug.com/874491, flaky on Win and Mac
1060 #if defined(OS_WIN) || defined(OS_APPLE)
1061 #define MAYBE_TestConnectEndToEndWithRealSSL \
1062   DISABLED_TestConnectEndToEndWithRealSSL
1063 #else
1064 #define MAYBE_TestConnectEndToEndWithRealSSL TestConnectEndToEndWithRealSSL
1065 #endif
1066 // Tests connecting through an actual non-mocked CastTransport object and
1067 // non-mocked SSLClientSocket, testing the components in integration.
TEST_F(SslCastSocketTest,MAYBE_TestConnectEndToEndWithRealSSL)1068 TEST_F(SslCastSocketTest, MAYBE_TestConnectEndToEndWithRealSSL) {
1069   CreateSockets();
1070   ConnectSockets();
1071 
1072   // Set low-level auth challenge expectations.
1073   CastMessage challenge = CreateAuthChallenge();
1074   std::string challenge_str;
1075   EXPECT_TRUE(MessageFramer::Serialize(challenge, &challenge_str));
1076 
1077   int challenge_buffer_length = challenge_str.size();
1078   scoped_refptr<net::IOBuffer> challenge_buffer =
1079       base::MakeRefCounted<net::IOBuffer>(challenge_buffer_length);
1080   int read = ReadExactLength(challenge_buffer.get(), challenge_buffer_length,
1081                              server_socket_.get());
1082 
1083   EXPECT_EQ(challenge_buffer_length, read);
1084   EXPECT_EQ(challenge_str,
1085             std::string(challenge_buffer->data(), challenge_buffer_length));
1086 
1087   // Set low-level auth reply expectations.
1088   CastMessage reply = CreateAuthReply();
1089   std::string reply_str;
1090   EXPECT_TRUE(MessageFramer::Serialize(reply, &reply_str));
1091 
1092   scoped_refptr<net::StringIOBuffer> reply_buffer =
1093       base::MakeRefCounted<net::StringIOBuffer>(reply_str);
1094   int written = WriteExactLength(reply_buffer.get(), reply_buffer->size(),
1095                                  server_socket_.get());
1096 
1097   EXPECT_EQ(reply_buffer->size(), written);
1098   EXPECT_CALL(handler_, OnConnectComplete(socket_.get()));
1099   RunPendingTasks();
1100 
1101   EXPECT_EQ(ReadyState::OPEN, socket_->ready_state());
1102   EXPECT_EQ(ChannelError::NONE, socket_->error_state());
1103 }
1104 
1105 // Sends message data through an actual non-mocked CastTransport object and
1106 // non-mocked SSLClientSocket, testing the components in integration.
TEST_F(SslCastSocketTest,DISABLED_TestMessageEndToEndWithRealSSL)1107 TEST_F(SslCastSocketTest, DISABLED_TestMessageEndToEndWithRealSSL) {
1108   CreateSockets();
1109   ConnectSockets();
1110 
1111   // Set low-level auth challenge expectations.
1112   CastMessage challenge = CreateAuthChallenge();
1113   std::string challenge_str;
1114   EXPECT_TRUE(MessageFramer::Serialize(challenge, &challenge_str));
1115 
1116   int challenge_buffer_length = challenge_str.size();
1117   scoped_refptr<net::IOBuffer> challenge_buffer =
1118       base::MakeRefCounted<net::IOBuffer>(challenge_buffer_length);
1119 
1120   int read = ReadExactLength(challenge_buffer.get(), challenge_buffer_length,
1121                              server_socket_.get());
1122 
1123   EXPECT_EQ(challenge_buffer_length, read);
1124   EXPECT_EQ(challenge_str,
1125             std::string(challenge_buffer->data(), challenge_buffer_length));
1126 
1127   // Set low-level auth reply expectations.
1128   CastMessage reply = CreateAuthReply();
1129   std::string reply_str;
1130   EXPECT_TRUE(MessageFramer::Serialize(reply, &reply_str));
1131 
1132   scoped_refptr<net::StringIOBuffer> reply_buffer =
1133       base::MakeRefCounted<net::StringIOBuffer>(reply_str);
1134   int written = WriteExactLength(reply_buffer.get(), reply_buffer->size(),
1135                                  server_socket_.get());
1136 
1137   EXPECT_EQ(reply_buffer->size(), written);
1138   EXPECT_CALL(handler_, OnConnectComplete(socket_.get()));
1139   RunPendingTasks();
1140 
1141   EXPECT_EQ(ReadyState::OPEN, socket_->ready_state());
1142   EXPECT_EQ(ChannelError::NONE, socket_->error_state());
1143 
1144   // Send a test message through the ssl socket.
1145   CastMessage test_message = CreateTestMessage();
1146   std::string test_message_str;
1147   EXPECT_TRUE(MessageFramer::Serialize(test_message, &test_message_str));
1148 
1149   int test_message_length = test_message_str.size();
1150   scoped_refptr<net::IOBuffer> test_message_buffer =
1151       base::MakeRefCounted<net::IOBuffer>(test_message_length);
1152 
1153   EXPECT_CALL(handler_, OnWriteComplete(net::OK));
1154   socket_->transport()->SendMessage(
1155       test_message, base::BindOnce(&CompleteHandler::OnWriteComplete,
1156                                    base::Unretained(&handler_)));
1157   RunPendingTasks();
1158 
1159   read = ReadExactLength(test_message_buffer.get(), test_message_length,
1160                          server_socket_.get());
1161 
1162   EXPECT_EQ(test_message_length, read);
1163   EXPECT_EQ(test_message_str,
1164             std::string(test_message_buffer->data(), test_message_length));
1165 
1166   EXPECT_EQ(ReadyState::OPEN, socket_->ready_state());
1167   EXPECT_EQ(ChannelError::NONE, socket_->error_state());
1168 }
1169 
1170 }  // namespace cast_channel
1171