1 // Copyright 2014 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "components/cast_channel/cast_socket.h"
6
7 #include <stdint.h>
8
9 #include <utility>
10 #include <vector>
11
12 #include "base/bind.h"
13 #include "base/callback_helpers.h"
14 #include "base/files/file_util.h"
15 #include "base/location.h"
16 #include "base/macros.h"
17 #include "base/memory/ptr_util.h"
18 #include "base/memory/weak_ptr.h"
19 #include "base/path_service.h"
20 #include "base/run_loop.h"
21 #include "base/single_thread_task_runner.h"
22 #include "base/strings/string_number_conversions.h"
23 #include "base/sys_byteorder.h"
24 #include "base/test/bind.h"
25 #include "base/threading/thread_task_runner_handle.h"
26 #include "base/timer/mock_timer.h"
27 #include "build/build_config.h"
28 #include "components/cast_channel/cast_auth_util.h"
29 #include "components/cast_channel/cast_framer.h"
30 #include "components/cast_channel/cast_message_util.h"
31 #include "components/cast_channel/cast_test_util.h"
32 #include "components/cast_channel/cast_transport.h"
33 #include "components/cast_channel/logger.h"
34 #include "content/public/test/browser_task_environment.h"
35 #include "crypto/rsa_private_key.h"
36 #include "mojo/public/cpp/bindings/remote.h"
37 #include "net/base/address_list.h"
38 #include "net/base/net_errors.h"
39 #include "net/cert/pem.h"
40 #include "net/socket/client_socket_factory.h"
41 #include "net/socket/socket_test_util.h"
42 #include "net/socket/ssl_client_socket.h"
43 #include "net/socket/ssl_server_socket.h"
44 #include "net/socket/tcp_client_socket.h"
45 #include "net/socket/tcp_server_socket.h"
46 #include "net/ssl/ssl_info.h"
47 #include "net/ssl/ssl_server_config.h"
48 #include "net/test/cert_test_util.h"
49 #include "net/test/test_data_directory.h"
50 #include "net/traffic_annotation/network_traffic_annotation_test_helper.h"
51 #include "net/url_request/url_request_test_util.h"
52 #include "services/network/network_context.h"
53 #include "testing/gmock/include/gmock/gmock.h"
54 #include "testing/gtest/include/gtest/gtest.h"
55 #include "third_party/openscreen/src/cast/common/channel/proto/cast_channel.pb.h"
56
// Connect timeout handed to the fixtures' CastSocketOpenParams: 100 seconds,
// far longer than any test runs, so it is never hit.
constexpr int64_t kDistantTimeoutMillis = 100000;
58
59 using ::testing::A;
60 using ::testing::DoAll;
61 using ::testing::Invoke;
62 using ::testing::InvokeArgument;
63 using ::testing::NotNull;
64 using ::testing::Return;
65 using ::testing::SaveArg;
66 using ::testing::_;
67
68 using ::cast::channel::CastMessage;
69
70 namespace cast_channel {
71 namespace {
// Message namespace used by the canned device-auth reply below.
constexpr char kAuthNamespace[] = "urn:x-cast:com.google.cast.tp.deviceauth";
73
74 // Returns an auth challenge message inline.
CreateAuthChallenge()75 CastMessage CreateAuthChallenge() {
76 CastMessage output;
77 CreateAuthChallengeMessage(&output, AuthContext::Create());
78 return output;
79 }
80
81 // Returns an auth challenge response message inline.
CreateAuthReply()82 CastMessage CreateAuthReply() {
83 CastMessage output;
84 output.set_protocol_version(CastMessage::CASTV2_1_0);
85 output.set_source_id("sender-0");
86 output.set_destination_id("receiver-0");
87 output.set_payload_type(CastMessage::BINARY);
88 output.set_payload_binary("abcd");
89 output.set_namespace_(kAuthNamespace);
90 return output;
91 }
92
CreateTestMessage()93 CastMessage CreateTestMessage() {
94 CastMessage test_message;
95 test_message.set_protocol_version(CastMessage::CASTV2_1_0);
96 test_message.set_namespace_("ns");
97 test_message.set_source_id("source");
98 test_message.set_destination_id("dest");
99 test_message.set_payload_type(CastMessage::STRING);
100 test_message.set_payload_utf8("payload");
101 return test_message;
102 }
103
GetTestCertsDirectory()104 base::FilePath GetTestCertsDirectory() {
105 base::FilePath path;
106 base::PathService::Get(base::DIR_SOURCE_ROOT, &path);
107 path = path.Append(FILE_PATH_LITERAL("components"));
108 path = path.Append(FILE_PATH_LITERAL("test"));
109 path = path.Append(FILE_PATH_LITERAL("data"));
110 path = path.Append(FILE_PATH_LITERAL("cast_channel"));
111 return path;
112 }
113
114 class MockTCPSocket : public net::MockTCPClientSocket {
115 public:
MockTCPSocket(bool do_nothing,net::SocketDataProvider * socket_provider)116 MockTCPSocket(bool do_nothing, net::SocketDataProvider* socket_provider)
117 : net::MockTCPClientSocket(net::AddressList(), nullptr, socket_provider) {
118 do_nothing_ = do_nothing;
119 set_enable_read_if_ready(true);
120 }
121
Connect(net::CompletionOnceCallback callback)122 int Connect(net::CompletionOnceCallback callback) override {
123 if (do_nothing_) {
124 // Stall the I/O event loop.
125 return net::ERR_IO_PENDING;
126 }
127 return net::MockTCPClientSocket::Connect(std::move(callback));
128 }
129
130 private:
131 bool do_nothing_;
132
133 DISALLOW_COPY_AND_ASSIGN(MockTCPSocket);
134 };
135
136 class CompleteHandler {
137 public:
CompleteHandler()138 CompleteHandler() {}
139 MOCK_METHOD1(OnCloseComplete, void(int result));
140 MOCK_METHOD1(OnConnectComplete, void(CastSocket* socket));
141 MOCK_METHOD1(OnWriteComplete, void(int result));
142 MOCK_METHOD1(OnReadComplete, void(int result));
143
144 private:
145 DISALLOW_COPY_AND_ASSIGN(CompleteHandler);
146 };
147
148 class TestCastSocketBase : public CastSocketImpl {
149 public:
TestCastSocketBase(network::mojom::NetworkContext * network_context,const CastSocketOpenParams & open_params,Logger * logger)150 TestCastSocketBase(network::mojom::NetworkContext* network_context,
151 const CastSocketOpenParams& open_params,
152 Logger* logger)
153 : CastSocketImpl(base::BindRepeating(
154 [](network::mojom::NetworkContext* network_context) {
155 return network_context;
156 },
157 network_context),
158 open_params,
159 logger,
160 AuthContext::Create()),
161 verify_challenge_result_(true),
162 verify_challenge_disallow_(false),
163 mock_timer_(new base::MockOneShotTimer()) {
164 SetPeerCertForTesting(
165 net::ImportCertFromFile(GetTestCertsDirectory(), "self_signed.pem"));
166 }
~TestCastSocketBase()167 ~TestCastSocketBase() override {}
168
SetVerifyChallengeResult(bool value)169 void SetVerifyChallengeResult(bool value) {
170 verify_challenge_result_ = value;
171 }
172
TriggerTimeout()173 void TriggerTimeout() { mock_timer_->Fire(); }
174
TestVerifyChannelPolicyNone()175 bool TestVerifyChannelPolicyNone() {
176 AuthResult authResult;
177 return VerifyChannelPolicy(authResult);
178 }
179
DisallowVerifyChallengeResult()180 void DisallowVerifyChallengeResult() { verify_challenge_disallow_ = true; }
181
182 protected:
VerifyChallengeReply()183 bool VerifyChallengeReply() override {
184 EXPECT_FALSE(verify_challenge_disallow_);
185 return verify_challenge_result_;
186 }
187
GetTimer()188 base::OneShotTimer* GetTimer() override { return mock_timer_.get(); }
189
190 // Simulated result of verifying challenge reply.
191 bool verify_challenge_result_;
192 bool verify_challenge_disallow_;
193 std::unique_ptr<base::MockOneShotTimer> mock_timer_;
194
195 private:
196 DISALLOW_COPY_AND_ASSIGN(TestCastSocketBase);
197 };
198
199 class MockTestCastSocket : public TestCastSocketBase {
200 public:
CreateSecure(network::mojom::NetworkContext * network_context,const CastSocketOpenParams & open_params,Logger * logger)201 static std::unique_ptr<MockTestCastSocket> CreateSecure(
202 network::mojom::NetworkContext* network_context,
203 const CastSocketOpenParams& open_params,
204 Logger* logger) {
205 return std::unique_ptr<MockTestCastSocket>(
206 new MockTestCastSocket(network_context, open_params, logger));
207 }
208
209 using TestCastSocketBase::TestCastSocketBase;
210
MockTestCastSocket(network::mojom::NetworkContext * network_context,const CastSocketOpenParams & open_params,Logger * logger)211 MockTestCastSocket(network::mojom::NetworkContext* network_context,
212 const CastSocketOpenParams& open_params,
213 Logger* logger)
214 : TestCastSocketBase(network_context, open_params, logger) {}
215
~MockTestCastSocket()216 ~MockTestCastSocket() override {}
217
SetupMockTransport()218 void SetupMockTransport() {
219 mock_transport_ = new MockCastTransport;
220 SetTransportForTesting(base::WrapUnique(mock_transport_));
221 }
222
TestVerifyChannelPolicyAudioOnly()223 bool TestVerifyChannelPolicyAudioOnly() {
224 AuthResult authResult;
225 authResult.channel_policies |= AuthResult::POLICY_AUDIO_ONLY;
226 return VerifyChannelPolicy(authResult);
227 }
228
GetMockTransport()229 MockCastTransport* GetMockTransport() {
230 CHECK(mock_transport_);
231 return mock_transport_;
232 }
233
234 private:
235 MockCastTransport* mock_transport_ = nullptr;
236
237 DISALLOW_COPY_AND_ASSIGN(MockTestCastSocket);
238 };
239
240 // TODO(https://crbug.com/928467): Remove this class.
241 class TestSocketFactory : public net::ClientSocketFactory {
242 public:
TestSocketFactory(net::IPEndPoint ip)243 explicit TestSocketFactory(net::IPEndPoint ip) : ip_(ip) {}
244 ~TestSocketFactory() override = default;
245
246 // Socket connection helpers.
SetupTcpConnect(net::IoMode mode,int result)247 void SetupTcpConnect(net::IoMode mode, int result) {
248 tcp_connect_data_.reset(new net::MockConnect(mode, result, ip_));
249 }
SetupSslConnect(net::IoMode mode,int result)250 void SetupSslConnect(net::IoMode mode, int result) {
251 ssl_connect_data_.reset(new net::MockConnect(mode, result, ip_));
252 }
253
254 // Socket I/O helpers.
AddWriteResult(const net::MockWrite & write)255 void AddWriteResult(const net::MockWrite& write) { writes_.push_back(write); }
AddWriteResult(net::IoMode mode,int result)256 void AddWriteResult(net::IoMode mode, int result) {
257 AddWriteResult(net::MockWrite(mode, result));
258 }
AddWriteResultForData(net::IoMode mode,const std::string & msg)259 void AddWriteResultForData(net::IoMode mode, const std::string& msg) {
260 AddWriteResult(mode, msg.size());
261 }
AddReadResult(const net::MockRead & read)262 void AddReadResult(const net::MockRead& read) { reads_.push_back(read); }
AddReadResult(net::IoMode mode,int result)263 void AddReadResult(net::IoMode mode, int result) {
264 AddReadResult(net::MockRead(mode, result));
265 }
AddReadResultForData(net::IoMode mode,const std::string & data)266 void AddReadResultForData(net::IoMode mode, const std::string& data) {
267 AddReadResult(net::MockRead(mode, data.c_str(), data.size()));
268 }
269
270 // Helpers for modifying other connection-related behaviors.
SetupTcpConnectUnresponsive()271 void SetupTcpConnectUnresponsive() { tcp_unresponsive_ = true; }
272
SetTcpSocket(std::unique_ptr<net::TransportClientSocket> tcp_client_socket)273 void SetTcpSocket(
274 std::unique_ptr<net::TransportClientSocket> tcp_client_socket) {
275 tcp_client_socket_ = std::move(tcp_client_socket);
276 }
277
SetTLSSocketCreatedClosure(base::OnceClosure closure)278 void SetTLSSocketCreatedClosure(base::OnceClosure closure) {
279 tls_socket_created_ = std::move(closure);
280 }
281
Pause()282 void Pause() {
283 if (socket_data_provider_)
284 socket_data_provider_->Pause();
285 else
286 socket_data_provider_paused_ = true;
287 }
288
Resume()289 void Resume() { socket_data_provider_->Resume(); }
290
291 private:
CreateDatagramClientSocket(net::DatagramSocket::BindType,net::NetLog *,const net::NetLogSource &)292 std::unique_ptr<net::DatagramClientSocket> CreateDatagramClientSocket(
293 net::DatagramSocket::BindType,
294 net::NetLog*,
295 const net::NetLogSource&) override {
296 NOTIMPLEMENTED();
297 return nullptr;
298 }
CreateTransportClientSocket(const net::AddressList &,std::unique_ptr<net::SocketPerformanceWatcher>,net::NetworkQualityEstimator *,net::NetLog *,const net::NetLogSource &)299 std::unique_ptr<net::TransportClientSocket> CreateTransportClientSocket(
300 const net::AddressList&,
301 std::unique_ptr<net::SocketPerformanceWatcher>,
302 net::NetworkQualityEstimator*,
303 net::NetLog*,
304 const net::NetLogSource&) override {
305 if (tcp_client_socket_)
306 return std::move(tcp_client_socket_);
307
308 if (tcp_unresponsive_) {
309 socket_data_provider_ = std::make_unique<net::StaticSocketDataProvider>();
310 return std::unique_ptr<net::TransportClientSocket>(
311 new MockTCPSocket(true, socket_data_provider_.get()));
312 } else {
313 socket_data_provider_ =
314 std::make_unique<net::StaticSocketDataProvider>(reads_, writes_);
315 socket_data_provider_->set_connect_data(*tcp_connect_data_);
316 if (socket_data_provider_paused_)
317 socket_data_provider_->Pause();
318 return std::unique_ptr<net::TransportClientSocket>(
319 new MockTCPSocket(false, socket_data_provider_.get()));
320 }
321 }
CreateSSLClientSocket(net::SSLClientContext * context,std::unique_ptr<net::StreamSocket> nested_socket,const net::HostPortPair & host_and_port,const net::SSLConfig & ssl_config)322 std::unique_ptr<net::SSLClientSocket> CreateSSLClientSocket(
323 net::SSLClientContext* context,
324 std::unique_ptr<net::StreamSocket> nested_socket,
325 const net::HostPortPair& host_and_port,
326 const net::SSLConfig& ssl_config) override {
327 if (!ssl_connect_data_) {
328 // Test isn't overriding SSL socket creation.
329 return net::ClientSocketFactory::GetDefaultFactory()
330 ->CreateSSLClientSocket(context, std::move(nested_socket),
331 host_and_port, ssl_config);
332 }
333 ssl_socket_data_provider_ = std::make_unique<net::SSLSocketDataProvider>(
334 ssl_connect_data_->mode, ssl_connect_data_->result);
335
336 if (tls_socket_created_)
337 std::move(tls_socket_created_).Run();
338
339 return std::make_unique<net::MockSSLClientSocket>(
340 std::move(nested_socket), net::HostPortPair(), net::SSLConfig(),
341 ssl_socket_data_provider_.get());
342 }
CreateProxyClientSocket(std::unique_ptr<net::StreamSocket> stream_socket,const std::string & user_agent,const net::HostPortPair & endpoint,const net::ProxyServer & proxy_server,net::HttpAuthController * http_auth_controller,bool tunnel,bool using_spdy,net::NextProto negotiated_protocol,net::ProxyDelegate * proxy_delegate,const net::NetworkTrafficAnnotationTag & traffic_annotation)343 std::unique_ptr<net::ProxyClientSocket> CreateProxyClientSocket(
344 std::unique_ptr<net::StreamSocket> stream_socket,
345 const std::string& user_agent,
346 const net::HostPortPair& endpoint,
347 const net::ProxyServer& proxy_server,
348 net::HttpAuthController* http_auth_controller,
349 bool tunnel,
350 bool using_spdy,
351 net::NextProto negotiated_protocol,
352 net::ProxyDelegate* proxy_delegate,
353 const net::NetworkTrafficAnnotationTag& traffic_annotation) override {
354 NOTIMPLEMENTED();
355 return nullptr;
356 }
357
358 net::IPEndPoint ip_;
359 // Simulated connect data
360 std::unique_ptr<net::MockConnect> tcp_connect_data_;
361 std::unique_ptr<net::MockConnect> ssl_connect_data_;
362 // Simulated read / write data
363 std::vector<net::MockWrite> writes_;
364 std::vector<net::MockRead> reads_;
365 std::unique_ptr<net::StaticSocketDataProvider> socket_data_provider_;
366 std::unique_ptr<net::SSLSocketDataProvider> ssl_socket_data_provider_;
367 bool socket_data_provider_paused_ = false;
368 // If true, makes TCP connection process stall. For timeout testing.
369 bool tcp_unresponsive_ = false;
370 std::unique_ptr<net::TransportClientSocket> tcp_client_socket_;
371 base::OnceClosure tls_socket_created_;
372
373 DISALLOW_COPY_AND_ASSIGN(TestSocketFactory);
374 };
375
376 class CastSocketTestBase : public testing::Test {
377 protected:
CastSocketTestBase()378 CastSocketTestBase()
379 : task_environment_(content::BrowserTaskEnvironment::IO_MAINLOOP),
380 url_request_context_(true),
381 logger_(new Logger()),
382 observer_(new MockCastSocketObserver()),
383 socket_open_params_(
384 CreateIPEndPointForTest(),
385 base::TimeDelta::FromMilliseconds(kDistantTimeoutMillis)),
386 client_socket_factory_(socket_open_params_.ip_endpoint) {}
~CastSocketTestBase()387 ~CastSocketTestBase() override {}
388
SetUp()389 void SetUp() override {
390 EXPECT_CALL(*observer_, OnMessage(_, _)).Times(0);
391
392 url_request_context_.set_client_socket_factory(&client_socket_factory_);
393 url_request_context_.Init();
394 network_context_ = std::make_unique<network::NetworkContext>(
395 nullptr, network_context_remote_.BindNewPipeAndPassReceiver(),
396 &url_request_context_,
397 /*cors_exempt_header_list=*/std::vector<std::string>());
398 }
399
400 // Runs all pending tasks in the message loop.
RunPendingTasks()401 void RunPendingTasks() {
402 base::RunLoop run_loop;
403 run_loop.RunUntilIdle();
404 }
405
client_socket_factory()406 TestSocketFactory* client_socket_factory() { return &client_socket_factory_; }
407
408 content::BrowserTaskEnvironment task_environment_;
409 net::TestURLRequestContext url_request_context_;
410 std::unique_ptr<network::NetworkContext> network_context_;
411 mojo::Remote<network::mojom::NetworkContext> network_context_remote_;
412 Logger* logger_;
413 CompleteHandler handler_;
414 std::unique_ptr<MockCastSocketObserver> observer_;
415 CastSocketOpenParams socket_open_params_;
416 TestSocketFactory client_socket_factory_;
417
418 private:
419 DISALLOW_COPY_AND_ASSIGN(CastSocketTestBase);
420 };
421
422 class MockCastSocketTest : public CastSocketTestBase {
423 protected:
MockCastSocketTest()424 MockCastSocketTest() {}
425
TearDown()426 void TearDown() override {
427 if (socket_) {
428 EXPECT_CALL(handler_, OnCloseComplete(net::OK));
429 socket_->Close(base::BindOnce(&CompleteHandler::OnCloseComplete,
430 base::Unretained(&handler_)));
431 }
432 }
433
CreateCastSocketSecure()434 void CreateCastSocketSecure() {
435 socket_ = MockTestCastSocket::CreateSecure(network_context_.get(),
436 socket_open_params_, logger_);
437 }
438
HandleAuthHandshake()439 void HandleAuthHandshake() {
440 socket_->SetupMockTransport();
441 CastMessage challenge_proto = CreateAuthChallenge();
442 EXPECT_CALL(*socket_->GetMockTransport(),
443 SendMessage(EqualsProto(challenge_proto), _))
444 .WillOnce(PostCompletionCallbackTask<1>(net::OK));
445 EXPECT_CALL(*socket_->GetMockTransport(), Start());
446 EXPECT_CALL(handler_, OnConnectComplete(socket_.get()));
447 socket_->AddObserver(observer_.get());
448 socket_->Connect(base::BindOnce(&CompleteHandler::OnConnectComplete,
449 base::Unretained(&handler_)));
450 RunPendingTasks();
451 socket_->GetMockTransport()->current_delegate()->OnMessage(
452 CreateAuthReply());
453 RunPendingTasks();
454 }
455
456 std::unique_ptr<MockTestCastSocket> socket_;
457
458 private:
459 DISALLOW_COPY_AND_ASSIGN(MockCastSocketTest);
460 };
461
462 class SslCastSocketTest : public CastSocketTestBase {
463 protected:
SslCastSocketTest()464 SslCastSocketTest() {}
465
TearDown()466 void TearDown() override {
467 if (socket_) {
468 EXPECT_CALL(handler_, OnCloseComplete(net::OK));
469 socket_->Close(base::BindOnce(&CompleteHandler::OnCloseComplete,
470 base::Unretained(&handler_)));
471 }
472 }
473
CreateSockets()474 void CreateSockets() {
475 socket_ = std::make_unique<TestCastSocketBase>(
476 network_context_.get(), socket_open_params_, logger_);
477
478 server_cert_ =
479 net::ImportCertFromFile(GetTestCertsDirectory(), "self_signed.pem");
480 ASSERT_TRUE(server_cert_);
481 server_private_key_ = ReadTestKeyFromPEM("self_signed.pem");
482 ASSERT_TRUE(server_private_key_);
483 server_context_ = CreateSSLServerContext(
484 server_cert_.get(), *server_private_key_, server_ssl_config_);
485
486 tcp_server_socket_.reset(
487 new net::TCPServerSocket(nullptr, net::NetLogSource()));
488 ASSERT_EQ(net::OK,
489 tcp_server_socket_->ListenWithAddressAndPort("127.0.0.1", 0, 1));
490 net::IPEndPoint server_address;
491 ASSERT_EQ(net::OK, tcp_server_socket_->GetLocalAddress(&server_address));
492 tcp_client_socket_.reset(
493 new net::TCPClientSocket(net::AddressList(server_address), nullptr,
494 nullptr, nullptr, net::NetLogSource()));
495
496 std::unique_ptr<net::StreamSocket> accepted_socket;
497 accept_result_ = tcp_server_socket_->Accept(
498 &accepted_socket, base::BindOnce(&SslCastSocketTest::TcpAcceptCallback,
499 base::Unretained(this)));
500 connect_result_ = tcp_client_socket_->Connect(base::BindOnce(
501 &SslCastSocketTest::TcpConnectCallback, base::Unretained(this)));
502 while (accept_result_ == net::ERR_IO_PENDING ||
503 connect_result_ == net::ERR_IO_PENDING) {
504 RunPendingTasks();
505 }
506 ASSERT_EQ(net::OK, accept_result_);
507 ASSERT_EQ(net::OK, connect_result_);
508 ASSERT_TRUE(accepted_socket);
509 ASSERT_TRUE(tcp_client_socket_->IsConnected());
510
511 server_socket_ =
512 server_context_->CreateSSLServerSocket(std::move(accepted_socket));
513 ASSERT_TRUE(server_socket_);
514
515 client_socket_factory()->SetTcpSocket(std::move(tcp_client_socket_));
516 }
517
ConnectSockets()518 void ConnectSockets() {
519 socket_->AddObserver(observer_.get());
520 socket_->Connect(base::BindOnce(&CompleteHandler::OnConnectComplete,
521 base::Unretained(&handler_)));
522
523 net::TestCompletionCallback handshake_callback;
524 int server_ret = handshake_callback.GetResult(
525 server_socket_->Handshake(handshake_callback.callback()));
526
527 ASSERT_EQ(net::OK, server_ret);
528 }
529
TcpAcceptCallback(int result)530 void TcpAcceptCallback(int result) { accept_result_ = result; }
531
TcpConnectCallback(int result)532 void TcpConnectCallback(int result) { connect_result_ = result; }
533
ReadTestKeyFromPEM(const base::StringPiece & name)534 std::unique_ptr<crypto::RSAPrivateKey> ReadTestKeyFromPEM(
535 const base::StringPiece& name) {
536 base::FilePath key_path = GetTestCertsDirectory().AppendASCII(name);
537 std::string pem_data;
538 if (!base::ReadFileToString(key_path, &pem_data)) {
539 return nullptr;
540 }
541
542 const std::vector<std::string> headers({"PRIVATE KEY"});
543 net::PEMTokenizer pem_tokenizer(pem_data, headers);
544 if (!pem_tokenizer.GetNext()) {
545 return nullptr;
546 }
547 std::vector<uint8_t> key_vector(pem_tokenizer.data().begin(),
548 pem_tokenizer.data().end());
549 std::unique_ptr<crypto::RSAPrivateKey> key(
550 crypto::RSAPrivateKey::CreateFromPrivateKeyInfo(key_vector));
551 return key;
552 }
553
ReadExactLength(net::IOBuffer * buffer,int buffer_length,net::Socket * socket)554 int ReadExactLength(net::IOBuffer* buffer,
555 int buffer_length,
556 net::Socket* socket) {
557 scoped_refptr<net::DrainableIOBuffer> draining_buffer =
558 base::MakeRefCounted<net::DrainableIOBuffer>(buffer, buffer_length);
559 while (draining_buffer->BytesRemaining() > 0) {
560 net::TestCompletionCallback read_callback;
561 int read_result = read_callback.GetResult(server_socket_->Read(
562 draining_buffer.get(), draining_buffer->BytesRemaining(),
563 read_callback.callback()));
564 EXPECT_GT(read_result, 0);
565 draining_buffer->DidConsume(read_result);
566 }
567 return buffer_length;
568 }
569
WriteExactLength(net::IOBuffer * buffer,int buffer_length,net::Socket * socket)570 int WriteExactLength(net::IOBuffer* buffer,
571 int buffer_length,
572 net::Socket* socket) {
573 scoped_refptr<net::DrainableIOBuffer> draining_buffer =
574 base::MakeRefCounted<net::DrainableIOBuffer>(buffer, buffer_length);
575 while (draining_buffer->BytesRemaining() > 0) {
576 net::TestCompletionCallback write_callback;
577 int write_result = write_callback.GetResult(server_socket_->Write(
578 draining_buffer.get(), draining_buffer->BytesRemaining(),
579 write_callback.callback(), TRAFFIC_ANNOTATION_FOR_TESTS));
580 EXPECT_GT(write_result, 0);
581 draining_buffer->DidConsume(write_result);
582 }
583 return buffer_length;
584 }
585
586 // Result values used for TCP socket setup. These should contain values from
587 // net::Error.
588 int accept_result_;
589 int connect_result_;
590
591 // Underlying TCP sockets for |socket_| to communicate with |server_socket_|
592 // when testing with the real SSL implementation.
593 std::unique_ptr<net::TransportClientSocket> tcp_client_socket_;
594 std::unique_ptr<net::TCPServerSocket> tcp_server_socket_;
595
596 std::unique_ptr<TestCastSocketBase> socket_;
597
598 // |server_socket_| is used for the *RealSSL tests in order to test the
599 // CastSocket over a real SSL socket. The other members below are used to
600 // initialize |server_socket_|.
601 std::unique_ptr<net::SSLServerSocket> server_socket_;
602 std::unique_ptr<net::SSLServerContext> server_context_;
603 std::unique_ptr<crypto::RSAPrivateKey> server_private_key_;
604 scoped_refptr<net::X509Certificate> server_cert_;
605 net::SSLServerConfig server_ssl_config_;
606
607 private:
608 DISALLOW_COPY_AND_ASSIGN(SslCastSocketTest);
609 };
610
611 } // namespace
612
613 // Tests that the following connection flow works:
614 // - TCP connection succeeds (async)
615 // - SSL connection succeeds (async)
616 // - Cert is extracted successfully
617 // - Challenge request is sent (async)
618 // - Challenge response is received (async)
619 // - Credentials are verified successfuly
TEST_F(MockCastSocketTest,TestConnectFullSecureFlowAsync)620 TEST_F(MockCastSocketTest, TestConnectFullSecureFlowAsync) {
621 CreateCastSocketSecure();
622 client_socket_factory()->SetupTcpConnect(net::ASYNC, net::OK);
623 client_socket_factory()->SetupSslConnect(net::ASYNC, net::OK);
624
625 HandleAuthHandshake();
626
627 EXPECT_EQ(ReadyState::OPEN, socket_->ready_state());
628 EXPECT_EQ(ChannelError::NONE, socket_->error_state());
629 }
630
631 // Tests that the following connection flow works:
632 // - TCP connection succeeds (sync)
633 // - SSL connection succeeds (sync)
634 // - Cert is extracted successfully
635 // - Challenge request is sent (sync)
636 // - Challenge response is received (sync)
637 // - Credentials are verified successfuly
TEST_F(MockCastSocketTest,TestConnectFullSecureFlowSync)638 TEST_F(MockCastSocketTest, TestConnectFullSecureFlowSync) {
639 client_socket_factory()->SetupTcpConnect(net::SYNCHRONOUS, net::OK);
640 client_socket_factory()->SetupSslConnect(net::SYNCHRONOUS, net::OK);
641
642 CreateCastSocketSecure();
643 HandleAuthHandshake();
644
645 EXPECT_EQ(ReadyState::OPEN, socket_->ready_state());
646 EXPECT_EQ(ChannelError::NONE, socket_->error_state());
647 }
648
649 // Test that an AuthMessage with a mangled namespace triggers cancelation
650 // of the connection event loop.
TEST_F(MockCastSocketTest,TestConnectAuthMessageCorrupted)651 TEST_F(MockCastSocketTest, TestConnectAuthMessageCorrupted) {
652 CreateCastSocketSecure();
653 socket_->SetupMockTransport();
654
655 client_socket_factory()->SetupTcpConnect(net::ASYNC, net::OK);
656 client_socket_factory()->SetupSslConnect(net::ASYNC, net::OK);
657
658 CastMessage challenge_proto = CreateAuthChallenge();
659 EXPECT_CALL(*socket_->GetMockTransport(),
660 SendMessage(EqualsProto(challenge_proto), _))
661 .WillOnce(PostCompletionCallbackTask<1>(net::OK));
662 EXPECT_CALL(*socket_->GetMockTransport(), Start());
663 EXPECT_CALL(handler_, OnConnectComplete(socket_.get()));
664 socket_->Connect(base::BindOnce(&CompleteHandler::OnConnectComplete,
665 base::Unretained(&handler_)));
666 RunPendingTasks();
667 CastMessage mangled_auth_reply = CreateAuthReply();
668 mangled_auth_reply.set_namespace_("BOGUS_NAMESPACE");
669
670 socket_->GetMockTransport()->current_delegate()->OnMessage(
671 mangled_auth_reply);
672 RunPendingTasks();
673
674 EXPECT_EQ(ReadyState::CLOSED, socket_->ready_state());
675 EXPECT_EQ(ChannelError::TRANSPORT_ERROR, socket_->error_state());
676
677 // Verifies that the CastSocket's resources were torn down during channel
678 // close. (see http://crbug.com/504078)
679 EXPECT_EQ(nullptr, socket_->transport());
680 }
681
682 // Test connection error - TCP connect fails (async)
TEST_F(MockCastSocketTest,TestConnectTcpConnectErrorAsync)683 TEST_F(MockCastSocketTest, TestConnectTcpConnectErrorAsync) {
684 CreateCastSocketSecure();
685
686 client_socket_factory()->SetupTcpConnect(net::ASYNC, net::ERR_FAILED);
687
688 EXPECT_CALL(handler_, OnConnectComplete(socket_.get()));
689 socket_->Connect(base::BindOnce(&CompleteHandler::OnConnectComplete,
690 base::Unretained(&handler_)));
691 RunPendingTasks();
692
693 EXPECT_EQ(ReadyState::CLOSED, socket_->ready_state());
694 EXPECT_EQ(ChannelError::CONNECT_ERROR, socket_->error_state());
695 }
696
697 // Test connection error - TCP connect fails (sync)
TEST_F(MockCastSocketTest,TestConnectTcpConnectErrorSync)698 TEST_F(MockCastSocketTest, TestConnectTcpConnectErrorSync) {
699 CreateCastSocketSecure();
700
701 client_socket_factory()->SetupTcpConnect(net::SYNCHRONOUS, net::ERR_FAILED);
702
703 EXPECT_CALL(handler_, OnConnectComplete(socket_.get()));
704 socket_->Connect(base::BindOnce(&CompleteHandler::OnConnectComplete,
705 base::Unretained(&handler_)));
706 RunPendingTasks();
707
708 EXPECT_EQ(ReadyState::CLOSED, socket_->ready_state());
709 EXPECT_EQ(ChannelError::CONNECT_ERROR, socket_->error_state());
710 }
711
712 // Test connection error - timeout
TEST_F(MockCastSocketTest,TestConnectTcpTimeoutError)713 TEST_F(MockCastSocketTest, TestConnectTcpTimeoutError) {
714 CreateCastSocketSecure();
715 client_socket_factory()->SetupTcpConnectUnresponsive();
716 EXPECT_CALL(handler_, OnConnectComplete(socket_.get()));
717 EXPECT_CALL(*observer_, OnError(_, ChannelError::CONNECT_TIMEOUT));
718 socket_->AddObserver(observer_.get());
719 socket_->Connect(base::BindOnce(&CompleteHandler::OnConnectComplete,
720 base::Unretained(&handler_)));
721 RunPendingTasks();
722
723 EXPECT_EQ(ReadyState::CONNECTING, socket_->ready_state());
724 EXPECT_EQ(ChannelError::NONE, socket_->error_state());
725 socket_->TriggerTimeout();
726 RunPendingTasks();
727
728 EXPECT_EQ(ReadyState::CLOSED, socket_->ready_state());
729 EXPECT_EQ(ChannelError::CONNECT_TIMEOUT, socket_->error_state());
730 }
731
732 // Test connection error - TCP socket returns timeout
TEST_F(MockCastSocketTest,TestConnectTcpSocketTimeoutError)733 TEST_F(MockCastSocketTest, TestConnectTcpSocketTimeoutError) {
734 CreateCastSocketSecure();
735 client_socket_factory()->SetupTcpConnect(net::SYNCHRONOUS,
736 net::ERR_CONNECTION_TIMED_OUT);
737 EXPECT_CALL(handler_, OnConnectComplete(socket_.get()));
738 EXPECT_CALL(*observer_, OnError(_, ChannelError::CONNECT_TIMEOUT));
739 socket_->AddObserver(observer_.get());
740 socket_->Connect(base::BindOnce(&CompleteHandler::OnConnectComplete,
741 base::Unretained(&handler_)));
742 RunPendingTasks();
743
744 EXPECT_EQ(ReadyState::CLOSED, socket_->ready_state());
745 EXPECT_EQ(ChannelError::CONNECT_TIMEOUT, socket_->error_state());
746 EXPECT_EQ(net::ERR_CONNECTION_TIMED_OUT,
747 logger_->GetLastError(socket_->id()).net_return_value);
748 }
749
750 // Test connection error - SSL connect fails (async)
TEST_F(MockCastSocketTest,TestConnectSslConnectErrorAsync)751 TEST_F(MockCastSocketTest, TestConnectSslConnectErrorAsync) {
752 CreateCastSocketSecure();
753
754 client_socket_factory()->SetupTcpConnect(net::SYNCHRONOUS, net::OK);
755 client_socket_factory()->SetupSslConnect(net::SYNCHRONOUS, net::ERR_FAILED);
756
757 EXPECT_CALL(handler_, OnConnectComplete(socket_.get()));
758 socket_->Connect(base::BindOnce(&CompleteHandler::OnConnectComplete,
759 base::Unretained(&handler_)));
760 RunPendingTasks();
761
762 EXPECT_EQ(ReadyState::CLOSED, socket_->ready_state());
763 EXPECT_EQ(ChannelError::AUTHENTICATION_ERROR, socket_->error_state());
764 }
765
766 // Test connection error - SSL connect fails (sync)
TEST_F(MockCastSocketTest,TestConnectSslConnectErrorSync)767 TEST_F(MockCastSocketTest, TestConnectSslConnectErrorSync) {
768 CreateCastSocketSecure();
769
770 client_socket_factory()->SetupTcpConnect(net::SYNCHRONOUS, net::OK);
771 client_socket_factory()->SetupSslConnect(net::SYNCHRONOUS, net::ERR_FAILED);
772
773 EXPECT_CALL(handler_, OnConnectComplete(socket_.get()));
774 socket_->Connect(base::BindOnce(&CompleteHandler::OnConnectComplete,
775 base::Unretained(&handler_)));
776 RunPendingTasks();
777
778 EXPECT_EQ(ReadyState::CLOSED, socket_->ready_state());
779 EXPECT_EQ(ChannelError::AUTHENTICATION_ERROR, socket_->error_state());
780 EXPECT_EQ(net::ERR_FAILED,
781 logger_->GetLastError(socket_->id()).net_return_value);
782 }
783
784 // Test connection error - SSL connect times out (sync)
TEST_F(MockCastSocketTest, TestConnectSslConnectTimeoutSync) {
  // Create the socket first and then script the mock factory, using the same
  // ordering as the neighbouring connect-error tests.
  CreateCastSocketSecure();

  // TCP succeeds; the SSL handshake synchronously reports a timeout.
  client_socket_factory()->SetupTcpConnect(net::SYNCHRONOUS, net::OK);
  client_socket_factory()->SetupSslConnect(net::SYNCHRONOUS,
                                           net::ERR_CONNECTION_TIMED_OUT);

  EXPECT_CALL(handler_, OnConnectComplete(socket_.get()));
  socket_->Connect(base::BindOnce(&CompleteHandler::OnConnectComplete,
                                  base::Unretained(&handler_)));
  RunPendingTasks();

  // A timed-out handshake maps to CONNECT_TIMEOUT, and the logger keeps the
  // underlying net error.
  EXPECT_EQ(ReadyState::CLOSED, socket_->ready_state());
  EXPECT_EQ(ChannelError::CONNECT_TIMEOUT, socket_->error_state());
  EXPECT_EQ(net::ERR_CONNECTION_TIMED_OUT,
            logger_->GetLastError(socket_->id()).net_return_value);
}
802
803 // Test connection error - SSL connect times out (async)
TEST_F(MockCastSocketTest, TestConnectSslConnectTimeoutAsync) {
  CreateCastSocketSecure();

  // Every connection step completes asynchronously; the SSL step reports a
  // timeout.
  client_socket_factory()->SetupTcpConnect(net::ASYNC, net::OK);
  client_socket_factory()->SetupSslConnect(net::ASYNC,
                                           net::ERR_CONNECTION_TIMED_OUT);

  auto on_connect = base::BindOnce(&CompleteHandler::OnConnectComplete,
                                   base::Unretained(&handler_));
  EXPECT_CALL(handler_, OnConnectComplete(socket_.get()));
  socket_->Connect(std::move(on_connect));
  RunPendingTasks();

  // The timeout is surfaced as CONNECT_TIMEOUT rather than as an
  // authentication error.
  EXPECT_EQ(ReadyState::CLOSED, socket_->ready_state());
  EXPECT_EQ(ChannelError::CONNECT_TIMEOUT, socket_->error_state());
}
819
820 // Test connection error - challenge send fails
TEST_F(MockCastSocketTest, TestConnectChallengeSendError) {
  CreateCastSocketSecure();
  socket_->SetupMockTransport();

  // TCP and SSL both succeed immediately...
  client_socket_factory()->SetupTcpConnect(net::SYNCHRONOUS, net::OK);
  client_socket_factory()->SetupSslConnect(net::SYNCHRONOUS, net::OK);

  // ...but the mocked send of the auth challenge completes with
  // ERR_CONNECTION_RESET.
  CastMessage expected_challenge = CreateAuthChallenge();
  EXPECT_CALL(*socket_->GetMockTransport(),
              SendMessage(EqualsProto(expected_challenge), _))
      .WillOnce(PostCompletionCallbackTask<1>(net::ERR_CONNECTION_RESET));

  EXPECT_CALL(handler_, OnConnectComplete(socket_.get()));
  socket_->Connect(base::BindOnce(&CompleteHandler::OnConnectComplete,
                                  base::Unretained(&handler_)));
  RunPendingTasks();

  // A failed challenge write closes the socket with CAST_SOCKET_ERROR.
  EXPECT_EQ(ReadyState::CLOSED, socket_->ready_state());
  EXPECT_EQ(ChannelError::CAST_SOCKET_ERROR, socket_->error_state());
}
839
840 // Test connection error - connection is destroyed after the challenge is
841 // sent, with the async result still lurking in the task queue.
TEST_F(MockCastSocketTest, TestConnectDestroyedAfterChallengeSent) {
  CreateCastSocketSecure();
  socket_->SetupMockTransport();
  client_socket_factory()->SetupTcpConnect(net::SYNCHRONOUS, net::OK);
  client_socket_factory()->SetupSslConnect(net::SYNCHRONOUS, net::OK);

  // The challenge send is expected once and completes with a reset error.
  CastMessage expected_challenge = CreateAuthChallenge();
  EXPECT_CALL(*socket_->GetMockTransport(),
              SendMessage(EqualsProto(expected_challenge), _))
      .WillOnce(PostCompletionCallbackTask<1>(net::ERR_CONNECTION_RESET));

  socket_->Connect(base::BindOnce(&CompleteHandler::OnConnectComplete,
                                  base::Unretained(&handler_)));
  RunPendingTasks();

  // Destroy the socket, then drain the task queue once more. No further
  // expectations are set; this step checks that running the remaining tasks
  // after destruction is safe.
  socket_.reset();
  RunPendingTasks();
}
856
857 // Test connection error - challenge reply receive fails
TEST_F(MockCastSocketTest, TestConnectChallengeReplyReceiveError) {
  CreateCastSocketSecure();
  socket_->SetupMockTransport();

  client_socket_factory()->SetupTcpConnect(net::SYNCHRONOUS, net::OK);
  client_socket_factory()->SetupSslConnect(net::SYNCHRONOUS, net::OK);

  // The auth challenge is sent successfully...
  CastMessage expected_challenge = CreateAuthChallenge();
  EXPECT_CALL(*socket_->GetMockTransport(),
              SendMessage(EqualsProto(expected_challenge), _))
      .WillOnce(PostCompletionCallbackTask<1>(net::OK));
  // ...but reading the reply fails.
  client_socket_factory()->AddReadResult(net::SYNCHRONOUS, net::ERR_FAILED);

  // The observer sees the error, the connect callback still fires, and the
  // transport is started.
  EXPECT_CALL(*observer_, OnError(_, ChannelError::CAST_SOCKET_ERROR));
  EXPECT_CALL(handler_, OnConnectComplete(socket_.get()));
  EXPECT_CALL(*socket_->GetMockTransport(), Start());
  socket_->AddObserver(observer_.get());

  socket_->Connect(base::BindOnce(&CompleteHandler::OnConnectComplete,
                                  base::Unretained(&handler_)));
  RunPendingTasks();

  // Report CAST_SOCKET_ERROR to the transport's delegate, standing in for the
  // failed reply read.
  socket_->GetMockTransport()->current_delegate()->OnError(
      ChannelError::CAST_SOCKET_ERROR);
  RunPendingTasks();

  EXPECT_EQ(ReadyState::CLOSED, socket_->ready_state());
  EXPECT_EQ(ChannelError::CAST_SOCKET_ERROR, socket_->error_state());
}
882
TEST_F(MockCastSocketTest, TestConnectChallengeVerificationFails) {
  // The connection and challenge exchange succeed, but challenge verification
  // is forced to fail, so the handshake must end in AUTHENTICATION_ERROR.
  CreateCastSocketSecure();
  socket_->SetupMockTransport();
  client_socket_factory()->SetupTcpConnect(net::ASYNC, net::OK);
  client_socket_factory()->SetupSslConnect(net::ASYNC, net::OK);
  socket_->SetVerifyChallengeResult(false);

  EXPECT_CALL(*observer_, OnError(_, ChannelError::AUTHENTICATION_ERROR));
  CastMessage expected_challenge = CreateAuthChallenge();
  EXPECT_CALL(*socket_->GetMockTransport(),
              SendMessage(EqualsProto(expected_challenge), _))
      .WillOnce(PostCompletionCallbackTask<1>(net::OK));
  EXPECT_CALL(handler_, OnConnectComplete(socket_.get()));
  EXPECT_CALL(*socket_->GetMockTransport(), Start());
  socket_->AddObserver(observer_.get());

  socket_->Connect(base::BindOnce(&CompleteHandler::OnConnectComplete,
                                  base::Unretained(&handler_)));
  RunPendingTasks();

  // Hand the auth reply to the transport's delegate; verification of it is
  // rejected because of SetVerifyChallengeResult(false) above.
  socket_->GetMockTransport()->current_delegate()->OnMessage(CreateAuthReply());
  RunPendingTasks();

  EXPECT_EQ(ReadyState::CLOSED, socket_->ready_state());
  EXPECT_EQ(ChannelError::AUTHENTICATION_ERROR, socket_->error_state());
}
907
908 // Sends message data through an actual non-mocked CastTransport object,
909 // testing the two components in integration.
TEST_F(MockCastSocketTest, TestConnectEndToEndWithRealTransportAsync) {
  CreateCastSocketSecure();
  client_socket_factory()->SetupTcpConnect(net::ASYNC, net::OK);
  client_socket_factory()->SetupSslConnect(net::ASYNC, net::OK);

  // Returns |message| framed the way MessageFramer puts it on the wire.
  auto frame = [](const CastMessage& message) {
    std::string framed;
    EXPECT_TRUE(MessageFramer::Serialize(message, &framed));
    return framed;
  };

  // The framed strings are kept in named locals that live until the end of
  // the test, in case the mock factory refers to their storage rather than
  // copying it.

  // Handshake script: the socket writes the auth challenge...
  std::string challenge_str = frame(CreateAuthChallenge());
  client_socket_factory()->AddWriteResultForData(net::ASYNC, challenge_str);

  // ...then reads the auth reply, after which reads stay pending.
  std::string reply_str = frame(CreateAuthReply());
  client_socket_factory()->AddReadResultForData(net::ASYNC, reply_str);
  client_socket_factory()->AddReadResult(net::ASYNC, net::ERR_IO_PENDING);

  // Hold the scripted data until the TLS socket is created, so that it is
  // consumed by the TLS socket and not by the underlying TCP socket.
  client_socket_factory()->Pause();
  client_socket_factory()->SetTLSSocketCreatedClosure(
      base::BindLambdaForTesting([&] { client_socket_factory()->Resume(); }));

  // After the handshake, one application message is written.
  CastMessage test_message = CreateTestMessage();
  std::string test_message_str = frame(test_message);
  client_socket_factory()->AddWriteResultForData(net::ASYNC, test_message_str);

  EXPECT_CALL(handler_, OnConnectComplete(socket_.get()));
  socket_->Connect(base::BindOnce(&CompleteHandler::OnConnectComplete,
                                  base::Unretained(&handler_)));
  RunPendingTasks();
  EXPECT_EQ(ReadyState::OPEN, socket_->ready_state());
  EXPECT_EQ(ChannelError::NONE, socket_->error_state());

  // Push the test message through the real (non-mocked) transport.
  EXPECT_CALL(handler_, OnWriteComplete(net::OK));
  socket_->transport()->SendMessage(
      test_message, base::BindOnce(&CompleteHandler::OnWriteComplete,
                                   base::Unretained(&handler_)));
  RunPendingTasks();

  EXPECT_EQ(ReadyState::OPEN, socket_->ready_state());
  EXPECT_EQ(ChannelError::NONE, socket_->error_state());
}
954
955 // Same as TestConnectEndToEndWithRealTransportAsync, except synchronous.
TEST_F(MockCastSocketTest, TestConnectEndToEndWithRealTransportSync) {
  CreateCastSocketSecure();
  client_socket_factory()->SetupTcpConnect(net::SYNCHRONOUS, net::OK);
  client_socket_factory()->SetupSslConnect(net::SYNCHRONOUS, net::OK);

  // Returns |message| framed the way MessageFramer puts it on the wire.
  auto frame = [](const CastMessage& message) {
    std::string framed;
    EXPECT_TRUE(MessageFramer::Serialize(message, &framed));
    return framed;
  };

  // The framed strings are kept in named locals that live until the end of
  // the test, in case the mock factory refers to their storage rather than
  // copying it.

  // Handshake script: the socket writes the auth challenge...
  std::string challenge_str = frame(CreateAuthChallenge());
  client_socket_factory()->AddWriteResultForData(net::SYNCHRONOUS,
                                                 challenge_str);

  // ...then reads the auth reply, after which reads stay pending.
  std::string reply_str = frame(CreateAuthReply());
  client_socket_factory()->AddReadResultForData(net::SYNCHRONOUS, reply_str);
  client_socket_factory()->AddReadResult(net::ASYNC, net::ERR_IO_PENDING);

  // Hold the scripted data until the TLS socket is created, so that it is
  // consumed by the TLS socket and not by the underlying TCP socket.
  client_socket_factory()->Pause();
  client_socket_factory()->SetTLSSocketCreatedClosure(
      base::BindLambdaForTesting([&] { client_socket_factory()->Resume(); }));

  // After the handshake, one application message is written.
  CastMessage test_message = CreateTestMessage();
  std::string test_message_str = frame(test_message);
  client_socket_factory()->AddWriteResultForData(net::SYNCHRONOUS,
                                                 test_message_str);

  EXPECT_CALL(handler_, OnConnectComplete(socket_.get()));
  socket_->Connect(base::BindOnce(&CompleteHandler::OnConnectComplete,
                                  base::Unretained(&handler_)));
  RunPendingTasks();
  EXPECT_EQ(ReadyState::OPEN, socket_->ready_state());
  EXPECT_EQ(ChannelError::NONE, socket_->error_state());

  // Push the test message through the real (non-mocked) transport.
  EXPECT_CALL(handler_, OnWriteComplete(net::OK));
  socket_->transport()->SendMessage(
      test_message, base::BindOnce(&CompleteHandler::OnWriteComplete,
                                   base::Unretained(&handler_)));
  RunPendingTasks();

  EXPECT_EQ(ReadyState::OPEN, socket_->ready_state());
  EXPECT_EQ(ChannelError::NONE, socket_->error_state());
}
1002
TEST_F(MockCastSocketTest, TestObservers) {
  CreateCastSocketSecure();

  // Test AddObserver(): register each observer twice. Each OnError expectation
  // below has gMock's default cardinality of exactly one call, so a duplicate
  // registration must not produce a duplicate notification.
  MockCastSocketObserver observer1;
  MockCastSocketObserver observer2;
  for (MockCastSocketObserver* observer :
       {&observer1, &observer1, &observer2, &observer2}) {
    socket_->AddObserver(observer);
  }

  // Test notification: an error reported through the socket's message
  // delegate reaches every registered observer.
  constexpr auto kError = cast_channel::ChannelError::CONNECT_ERROR;
  EXPECT_CALL(observer1, OnError(_, kError));
  EXPECT_CALL(observer2, OnError(_, kError));
  CastSocketImpl::CastSocketMessageDelegate delegate(socket_.get());
  delegate.OnError(kError);
}
1019
TEST_F(MockCastSocketTest, TestOpenChannelConnectingSocket) {
  auto make_connect_callback = [this] {
    return base::BindOnce(&CompleteHandler::OnConnectComplete,
                          base::Unretained(&handler_));
  };

  CreateCastSocketSecure();
  // The TCP connect is left unresponsive, so the first Connect() does not
  // finish on its own.
  client_socket_factory()->SetupTcpConnectUnresponsive();
  socket_->Connect(make_connect_callback());
  RunPendingTasks();

  // Call Connect() again while the first attempt is outstanding. Once the
  // connect timeout fires, both callbacks must have been invoked.
  EXPECT_CALL(handler_, OnConnectComplete(socket_.get())).Times(2);
  socket_->Connect(make_connect_callback());
  socket_->TriggerTimeout();
  RunPendingTasks();
}
1033
TEST_F(MockCastSocketTest, TestOpenChannelConnectedSocket) {
  CreateCastSocketSecure();
  client_socket_factory()->SetupTcpConnect(net::ASYNC, net::OK);
  client_socket_factory()->SetupSslConnect(net::ASYNC, net::OK);
  HandleAuthHandshake();

  // Calling Connect() after HandleAuthHandshake() (presumably leaving the
  // socket open) must still report completion to the new callback.
  auto on_connect = base::BindOnce(&CompleteHandler::OnConnectComplete,
                                   base::Unretained(&handler_));
  EXPECT_CALL(handler_, OnConnectComplete(socket_.get()));
  socket_->Connect(std::move(on_connect));
}
1044
TEST_F(MockCastSocketTest, TestOpenChannelClosedSocket) {
  auto make_connect_callback = [this] {
    return base::BindOnce(&CompleteHandler::OnConnectComplete,
                          base::Unretained(&handler_));
  };

  CreateCastSocketSecure();
  // The TCP connect fails, so the first connect attempt fails.
  client_socket_factory()->SetupTcpConnect(net::ASYNC, net::ERR_FAILED);

  EXPECT_CALL(handler_, OnConnectComplete(socket_.get()));
  socket_->Connect(make_connect_callback());
  RunPendingTasks();

  // Connecting again after the failure still invokes the callback. This
  // expectation is newer, so gMock matches the second call against it.
  EXPECT_CALL(handler_, OnConnectComplete(socket_.get()));
  socket_->Connect(make_connect_callback());
}
1058
1059 // https://crbug.com/874491, flaky on Win and Mac
1060 #if defined(OS_WIN) || defined(OS_APPLE)
1061 #define MAYBE_TestConnectEndToEndWithRealSSL \
1062 DISABLED_TestConnectEndToEndWithRealSSL
1063 #else
1064 #define MAYBE_TestConnectEndToEndWithRealSSL TestConnectEndToEndWithRealSSL
1065 #endif
1066 // Tests connecting through an actual non-mocked CastTransport object and
1067 // non-mocked SSLClientSocket, testing the components in integration.
TEST_F(SslCastSocketTest, MAYBE_TestConnectEndToEndWithRealSSL) {
  CreateSockets();
  ConnectSockets();

  // Returns |message| framed the way MessageFramer puts it on the wire.
  auto frame = [](const CastMessage& message) {
    std::string framed;
    EXPECT_TRUE(MessageFramer::Serialize(message, &framed));
    return framed;
  };

  // Server side: read the client's auth challenge and compare it
  // byte-for-byte with the expected framing.
  std::string challenge_str = frame(CreateAuthChallenge());
  const int challenge_length = challenge_str.size();
  scoped_refptr<net::IOBuffer> challenge_buffer =
      base::MakeRefCounted<net::IOBuffer>(challenge_length);
  const int challenge_bytes_read = ReadExactLength(
      challenge_buffer.get(), challenge_length, server_socket_.get());

  EXPECT_EQ(challenge_length, challenge_bytes_read);
  EXPECT_EQ(challenge_str,
            std::string(challenge_buffer->data(), challenge_length));

  // Server side: answer with the framed auth reply.
  std::string reply_str = frame(CreateAuthReply());
  scoped_refptr<net::StringIOBuffer> reply_buffer =
      base::MakeRefCounted<net::StringIOBuffer>(reply_str);
  const int reply_bytes_written = WriteExactLength(
      reply_buffer.get(), reply_buffer->size(), server_socket_.get());

  EXPECT_EQ(reply_buffer->size(), reply_bytes_written);

  // Processing the reply must complete the client's connect successfully.
  EXPECT_CALL(handler_, OnConnectComplete(socket_.get()));
  RunPendingTasks();

  EXPECT_EQ(ReadyState::OPEN, socket_->ready_state());
  EXPECT_EQ(ChannelError::NONE, socket_->error_state());
}
1104
1105 // Sends message data through an actual non-mocked CastTransport object and
1106 // non-mocked SSLClientSocket, testing the components in integration.
TEST_F(SslCastSocketTest, DISABLED_TestMessageEndToEndWithRealSSL) {
  CreateSockets();
  ConnectSockets();

  // Returns |message| framed the way MessageFramer puts it on the wire.
  auto frame = [](const CastMessage& message) {
    std::string framed;
    EXPECT_TRUE(MessageFramer::Serialize(message, &framed));
    return framed;
  };

  // Server side: read the client's auth challenge and compare it
  // byte-for-byte with the expected framing.
  std::string challenge_str = frame(CreateAuthChallenge());
  const int challenge_length = challenge_str.size();
  scoped_refptr<net::IOBuffer> challenge_buffer =
      base::MakeRefCounted<net::IOBuffer>(challenge_length);

  const int challenge_bytes_read = ReadExactLength(
      challenge_buffer.get(), challenge_length, server_socket_.get());

  EXPECT_EQ(challenge_length, challenge_bytes_read);
  EXPECT_EQ(challenge_str,
            std::string(challenge_buffer->data(), challenge_length));

  // Server side: answer with the framed auth reply.
  std::string reply_str = frame(CreateAuthReply());
  scoped_refptr<net::StringIOBuffer> reply_buffer =
      base::MakeRefCounted<net::StringIOBuffer>(reply_str);
  const int reply_bytes_written = WriteExactLength(
      reply_buffer.get(), reply_buffer->size(), server_socket_.get());

  EXPECT_EQ(reply_buffer->size(), reply_bytes_written);
  EXPECT_CALL(handler_, OnConnectComplete(socket_.get()));
  RunPendingTasks();

  EXPECT_EQ(ReadyState::OPEN, socket_->ready_state());
  EXPECT_EQ(ChannelError::NONE, socket_->error_state());

  // Client side: send a test message over the established SSL connection...
  CastMessage test_message = CreateTestMessage();
  std::string test_message_str = frame(test_message);
  const int test_message_length = test_message_str.size();
  scoped_refptr<net::IOBuffer> test_message_buffer =
      base::MakeRefCounted<net::IOBuffer>(test_message_length);

  EXPECT_CALL(handler_, OnWriteComplete(net::OK));
  socket_->transport()->SendMessage(
      test_message, base::BindOnce(&CompleteHandler::OnWriteComplete,
                                   base::Unretained(&handler_)));
  RunPendingTasks();

  // ...and check that the server receives exactly its framed bytes.
  const int message_bytes_read = ReadExactLength(
      test_message_buffer.get(), test_message_length, server_socket_.get());

  EXPECT_EQ(test_message_length, message_bytes_read);
  EXPECT_EQ(test_message_str,
            std::string(test_message_buffer->data(), test_message_length));

  EXPECT_EQ(ReadyState::OPEN, socket_->ready_state());
  EXPECT_EQ(ChannelError::NONE, socket_->error_state());
}
1169
1170 } // namespace cast_channel
1171