1 /*
2  *  Copyright 2018 The WebRTC project authors. All Rights Reserved.
3  *
4  *  Use of this source code is governed by a BSD-style license
5  *  that can be found in the LICENSE file in the root of the source
6  *  tree. An additional intellectual property rights grant can be found
7  *  in the file PATENTS.  All contributing project authors may
8  *  be found in the AUTHORS file in the root of the source tree.
9  */
10 
11 #include "pc/dtls_transport.h"
12 
13 #include <utility>
14 #include <vector>
15 
16 #include "absl/memory/memory.h"
17 #include "p2p/base/fake_dtls_transport.h"
18 #include "rtc_base/gunit.h"
19 #include "test/gmock.h"
20 #include "test/gtest.h"
21 
// Upper bound, in milliseconds, that the *_WAIT assertions below wait for.
constexpr int kDefaultTimeout{1000};
// Arbitrary cipher suite id installed on the fake transport; the tests check
// that it is reported through DtlsTransportInformation.
constexpr int kNonsenseCipherSuite{1234};
24 
25 using cricket::FakeDtlsTransport;
26 using ::testing::ElementsAre;
27 
28 namespace webrtc {
29 
30 class TestDtlsTransportObserver : public DtlsTransportObserverInterface {
31  public:
OnStateChange(DtlsTransportInformation info)32   void OnStateChange(DtlsTransportInformation info) override {
33     state_change_called_ = true;
34     states_.push_back(info.state());
35     info_ = info;
36   }
37 
OnError(RTCError error)38   void OnError(RTCError error) override {}
39 
state()40   DtlsTransportState state() {
41     if (states_.size() > 0) {
42       return states_[states_.size() - 1];
43     } else {
44       return DtlsTransportState::kNew;
45     }
46   }
47 
48   bool state_change_called_ = false;
49   DtlsTransportInformation info_;
50   std::vector<DtlsTransportState> states_;
51 };
52 
53 class DtlsTransportTest : public ::testing::Test {
54  public:
transport()55   DtlsTransport* transport() { return transport_.get(); }
observer()56   DtlsTransportObserverInterface* observer() { return &observer_; }
57 
CreateTransport(rtc::FakeSSLCertificate * certificate=nullptr)58   void CreateTransport(rtc::FakeSSLCertificate* certificate = nullptr) {
59     auto cricket_transport = std::make_unique<FakeDtlsTransport>(
60         "audio", cricket::ICE_CANDIDATE_COMPONENT_RTP);
61     if (certificate) {
62       cricket_transport->SetRemoteSSLCertificate(certificate);
63     }
64     cricket_transport->SetSslCipherSuite(kNonsenseCipherSuite);
65     transport_ =
66         new rtc::RefCountedObject<DtlsTransport>(std::move(cricket_transport));
67   }
68 
CompleteDtlsHandshake()69   void CompleteDtlsHandshake() {
70     auto fake_dtls1 = static_cast<FakeDtlsTransport*>(transport_->internal());
71     auto fake_dtls2 = std::make_unique<FakeDtlsTransport>(
72         "audio", cricket::ICE_CANDIDATE_COMPONENT_RTP);
73     auto cert1 = rtc::RTCCertificate::Create(
74         rtc::SSLIdentity::Create("session1", rtc::KT_DEFAULT));
75     fake_dtls1->SetLocalCertificate(cert1);
76     auto cert2 = rtc::RTCCertificate::Create(
77         rtc::SSLIdentity::Create("session1", rtc::KT_DEFAULT));
78     fake_dtls2->SetLocalCertificate(cert2);
79     fake_dtls1->SetDestination(fake_dtls2.get());
80   }
81 
82   rtc::scoped_refptr<DtlsTransport> transport_;
83   TestDtlsTransportObserver observer_;
84 };
85 
// A freshly wrapped transport starts in kNew; Clear() releases the internal
// transport and moves the wrapper to kClosed.
TEST_F(DtlsTransportTest, CreateClearDelete) {
  rtc::scoped_refptr<DtlsTransport> wrapper(
      new rtc::RefCountedObject<DtlsTransport>(
          std::make_unique<FakeDtlsTransport>(
              "audio", cricket::ICE_CANDIDATE_COMPONENT_RTP)));
  ASSERT_TRUE(wrapper->internal());
  ASSERT_EQ(DtlsTransportState::kNew, wrapper->Information().state());

  wrapper->Clear();

  ASSERT_FALSE(wrapper->internal());
  ASSERT_EQ(DtlsTransportState::kClosed, wrapper->Information().state());
}
98 
// The registered observer is told about the transition to kConnected.
TEST_F(DtlsTransportTest, EventsObservedWhenConnecting) {
  CreateTransport();
  transport_->RegisterObserver(&observer_);
  CompleteDtlsHandshake();
  ASSERT_TRUE_WAIT(observer_.state_change_called_, kDefaultTimeout);
  // Only kConnected is expected, because FakeDtlsTransport doesn't signal the
  // "connecting" state.
  // TODO(hta): fix FakeDtlsTransport or file bug on it, then expect
  // DtlsTransportState::kConnecting before kConnected.
  EXPECT_THAT(observer_.states_, ElementsAre(DtlsTransportState::kConnected));
}
111 
// Clearing a connected transport reports kClosed to the observer.
TEST_F(DtlsTransportTest, CloseWhenClearing) {
  CreateTransport();
  transport()->RegisterObserver(observer());
  CompleteDtlsHandshake();

  // Reach kConnected first so the kClosed check below observes a transition
  // caused by Clear().
  ASSERT_TRUE_WAIT(DtlsTransportState::kConnected == observer_.state(),
                   kDefaultTimeout);

  transport()->Clear();
  ASSERT_TRUE_WAIT(DtlsTransportState::kClosed == observer_.state(),
                   kDefaultTimeout);
}
122 
// Once connected, the reported information carries the remote certificates.
TEST_F(DtlsTransportTest, CertificateAppearsOnConnect) {
  rtc::FakeSSLCertificate remote_cert("fake data");
  CreateTransport(&remote_cert);
  transport_->RegisterObserver(&observer_);
  CompleteDtlsHandshake();
  ASSERT_TRUE_WAIT(DtlsTransportState::kConnected == observer_.state(),
                   kDefaultTimeout);
  EXPECT_TRUE(observer_.info_.remote_ssl_certificates());
}
132 
// Remote certificates are reported while connected and dropped after Clear().
TEST_F(DtlsTransportTest, CertificateDisappearsOnClose) {
  rtc::FakeSSLCertificate remote_cert("fake data");
  CreateTransport(&remote_cert);
  transport_->RegisterObserver(&observer_);
  CompleteDtlsHandshake();

  // Connected: certificates present.
  ASSERT_TRUE_WAIT(DtlsTransportState::kConnected == observer_.state(),
                   kDefaultTimeout);
  EXPECT_TRUE(observer_.info_.remote_ssl_certificates());

  // Closed: certificates gone.
  transport_->Clear();
  ASSERT_TRUE_WAIT(DtlsTransportState::kClosed == observer_.state(),
                   kDefaultTimeout);
  EXPECT_FALSE(observer_.info_.remote_ssl_certificates());
}
146 
// The cipher suite configured by CreateTransport() is reported while
// connected and cleared once the transport is closed.
TEST_F(DtlsTransportTest, CipherSuiteVisibleWhenConnected) {
  CreateTransport();
  transport()->RegisterObserver(observer());
  CompleteDtlsHandshake();

  ASSERT_TRUE_WAIT(DtlsTransportState::kConnected == observer_.state(),
                   kDefaultTimeout);
  const auto connected_suite = observer_.info_.ssl_cipher_suite();
  ASSERT_TRUE(connected_suite);
  EXPECT_EQ(kNonsenseCipherSuite, *connected_suite);

  transport()->Clear();
  ASSERT_TRUE_WAIT(DtlsTransportState::kClosed == observer_.state(),
                   kDefaultTimeout);
  EXPECT_FALSE(observer_.info_.ssl_cipher_suite());
}
160 
161 }  // namespace webrtc
162