1 /*
2  *  Copyright 2019 The WebRTC project authors. All Rights Reserved.
3  *
4  *  Use of this source code is governed by a BSD-style license
5  *  that can be found in the LICENSE file in the root of the source
6  *  tree. An additional intellectual property rights grant can be found
7  *  in the file PATENTS.  All contributing project authors may
8  *  be found in the AUTHORS file in the root of the source tree.
9  */
10 
11 #include "pc/sctp_transport.h"
12 
#include <memory>
#include <utility>
#include <vector>
15 
16 #include "absl/memory/memory.h"
17 #include "p2p/base/fake_dtls_transport.h"
18 #include "pc/dtls_transport.h"
19 #include "rtc_base/gunit.h"
20 #include "test/gmock.h"
21 #include "test/gtest.h"
22 
// Upper bound, in milliseconds, passed to the *_WAIT assertions below.
constexpr int kDefaultTimeout{1000};
// Stream count configured on the fake SCTP transport; the tests expect it to
// be reported back as MaxChannels once the association is up.
constexpr int kTestMaxSctpStreams{1234};
25 
26 using cricket::FakeDtlsTransport;
27 using ::testing::ElementsAre;
28 
29 namespace webrtc {
30 
31 namespace {
32 
33 class FakeCricketSctpTransport : public cricket::SctpTransportInternal {
34  public:
SetDtlsTransport(rtc::PacketTransportInternal * transport)35   void SetDtlsTransport(rtc::PacketTransportInternal* transport) override {}
Start(int local_port,int remote_port,int max_message_size)36   bool Start(int local_port, int remote_port, int max_message_size) override {
37     return true;
38   }
OpenStream(int sid)39   bool OpenStream(int sid) override { return true; }
ResetStream(int sid)40   bool ResetStream(int sid) override { return true; }
SendData(const cricket::SendDataParams & params,const rtc::CopyOnWriteBuffer & payload,cricket::SendDataResult * result=nullptr)41   bool SendData(const cricket::SendDataParams& params,
42                 const rtc::CopyOnWriteBuffer& payload,
43                 cricket::SendDataResult* result = nullptr) override {
44     return true;
45   }
ReadyToSendData()46   bool ReadyToSendData() override { return true; }
set_debug_name_for_testing(const char * debug_name)47   void set_debug_name_for_testing(const char* debug_name) override {}
max_message_size() const48   int max_message_size() const override { return 0; }
max_outbound_streams() const49   absl::optional<int> max_outbound_streams() const override {
50     return max_outbound_streams_;
51   }
max_inbound_streams() const52   absl::optional<int> max_inbound_streams() const override {
53     return max_inbound_streams_;
54   }
55   // Methods exposed for testing
SendSignalReadyToSendData()56   void SendSignalReadyToSendData() { SignalReadyToSendData(); }
57 
SendSignalAssociationChangeCommunicationUp()58   void SendSignalAssociationChangeCommunicationUp() {
59     SignalAssociationChangeCommunicationUp();
60   }
61 
SendSignalClosingProcedureStartedRemotely()62   void SendSignalClosingProcedureStartedRemotely() {
63     SignalClosingProcedureStartedRemotely(1);
64   }
65 
SendSignalClosingProcedureComplete()66   void SendSignalClosingProcedureComplete() {
67     SignalClosingProcedureComplete(1);
68   }
set_max_outbound_streams(int streams)69   void set_max_outbound_streams(int streams) {
70     max_outbound_streams_ = streams;
71   }
set_max_inbound_streams(int streams)72   void set_max_inbound_streams(int streams) { max_inbound_streams_ = streams; }
73 
74  private:
75   absl::optional<int> max_outbound_streams_;
76   absl::optional<int> max_inbound_streams_;
77 };
78 
79 }  // namespace
80 
81 class TestSctpTransportObserver : public SctpTransportObserverInterface {
82  public:
TestSctpTransportObserver()83   TestSctpTransportObserver() : info_(SctpTransportState::kNew) {}
84 
OnStateChange(SctpTransportInformation info)85   void OnStateChange(SctpTransportInformation info) override {
86     info_ = info;
87     states_.push_back(info.state());
88   }
89 
State()90   SctpTransportState State() {
91     if (states_.size() > 0) {
92       return states_[states_.size() - 1];
93     } else {
94       return SctpTransportState::kNew;
95     }
96   }
97 
States()98   const std::vector<SctpTransportState>& States() { return states_; }
99 
LastReceivedInformation()100   const SctpTransportInformation LastReceivedInformation() { return info_; }
101 
102  private:
103   std::vector<SctpTransportState> states_;
104   SctpTransportInformation info_;
105 };
106 
107 class SctpTransportTest : public ::testing::Test {
108  public:
transport()109   SctpTransport* transport() { return transport_.get(); }
observer()110   SctpTransportObserverInterface* observer() { return &observer_; }
111 
CreateTransport()112   void CreateTransport() {
113     auto cricket_sctp_transport =
114         absl::WrapUnique(new FakeCricketSctpTransport());
115     transport_ = new rtc::RefCountedObject<SctpTransport>(
116         std::move(cricket_sctp_transport));
117   }
118 
AddDtlsTransport()119   void AddDtlsTransport() {
120     std::unique_ptr<cricket::DtlsTransportInternal> cricket_transport =
121         std::make_unique<FakeDtlsTransport>(
122             "audio", cricket::ICE_CANDIDATE_COMPONENT_RTP);
123     dtls_transport_ =
124         new rtc::RefCountedObject<DtlsTransport>(std::move(cricket_transport));
125     transport_->SetDtlsTransport(dtls_transport_);
126   }
127 
CompleteSctpHandshake()128   void CompleteSctpHandshake() {
129     CricketSctpTransport()->SendSignalReadyToSendData();
130     // The computed MaxChannels shall be the minimum of the outgoing
131     // and incoming # of streams.
132     CricketSctpTransport()->set_max_outbound_streams(kTestMaxSctpStreams);
133     CricketSctpTransport()->set_max_inbound_streams(kTestMaxSctpStreams + 1);
134     CricketSctpTransport()->SendSignalAssociationChangeCommunicationUp();
135   }
136 
CricketSctpTransport()137   FakeCricketSctpTransport* CricketSctpTransport() {
138     return static_cast<FakeCricketSctpTransport*>(transport_->internal());
139   }
140 
141   rtc::scoped_refptr<SctpTransport> transport_;
142   rtc::scoped_refptr<DtlsTransport> dtls_transport_;
143   TestSctpTransportObserver observer_;
144 };
145 
// A new SctpTransport holds its internal transport and reports kNew. After
// Clear(), the internal transport is gone and the state is kClosed.
TEST(SctpTransportSimpleTest, CreateClearDelete) {
  std::unique_ptr<cricket::SctpTransportInternal> fake_cricket_sctp_transport =
      std::make_unique<FakeCricketSctpTransport>();
  rtc::scoped_refptr<SctpTransport> sctp_transport =
      new rtc::RefCountedObject<SctpTransport>(
          std::move(fake_cricket_sctp_transport));
  ASSERT_TRUE(sctp_transport->internal());
  ASSERT_EQ(SctpTransportState::kNew, sctp_transport->Information().state());
  sctp_transport->Clear();
  ASSERT_FALSE(sctp_transport->internal());
  ASSERT_EQ(SctpTransportState::kClosed, sctp_transport->Information().state());
}
158 
// With the observer registered before the DTLS transport is attached and the
// handshake completes, the observer sees exactly kConnecting then kConnected.
TEST_F(SctpTransportTest, EventsObservedWhenConnecting) {
  CreateTransport();
  transport_->RegisterObserver(&observer_);
  AddDtlsTransport();
  CompleteSctpHandshake();
  ASSERT_EQ_WAIT(SctpTransportState::kConnected, observer_.State(),
                 kDefaultTimeout);

  const std::vector<SctpTransportState>& observed = observer_.States();
  EXPECT_THAT(observed, ElementsAre(SctpTransportState::kConnecting,
                                    SctpTransportState::kConnected));
}
169 
// Calling Clear() on a connected transport must notify the observer of
// kClosed.
TEST_F(SctpTransportTest, CloseWhenClearing) {
  CreateTransport();
  transport_->RegisterObserver(&observer_);
  AddDtlsTransport();
  CompleteSctpHandshake();
  ASSERT_EQ_WAIT(SctpTransportState::kConnected, observer_.State(),
                 kDefaultTimeout);

  transport_->Clear();

  ASSERT_EQ_WAIT(SctpTransportState::kClosed, observer_.State(),
                 kDefaultTimeout);
}
181 
// MaxChannels is unset before the SCTP association comes up. Afterwards it
// equals kTestMaxSctpStreams, both in the transport's own Information() and
// in the information delivered to the observer.
TEST_F(SctpTransportTest, MaxChannelsSignalled) {
  CreateTransport();
  transport()->RegisterObserver(observer());
  AddDtlsTransport();
  EXPECT_FALSE(transport()->Information().MaxChannels());
  EXPECT_FALSE(observer_.LastReceivedInformation().MaxChannels());
  CompleteSctpHandshake();
  ASSERT_EQ_WAIT(SctpTransportState::kConnected, observer_.State(),
                 kDefaultTimeout);

  // Use ASSERT (fatal), not EXPECT, before each dereference. A missing value
  // then ends the test instead of dereferencing an empty optional.
  const auto transport_max_channels = transport()->Information().MaxChannels();
  ASSERT_TRUE(transport_max_channels);
  EXPECT_EQ(kTestMaxSctpStreams, *transport_max_channels);

  const auto observed_max_channels =
      observer_.LastReceivedInformation().MaxChannels();
  ASSERT_TRUE(observed_max_channels);
  EXPECT_EQ(kTestMaxSctpStreams, *observed_max_channels);
}
197 
// Moving the underlying DTLS transport to DTLS_TRANSPORT_CLOSED after the
// association is up must notify the observer of kClosed.
TEST_F(SctpTransportTest, CloseWhenTransportCloses) {
  CreateTransport();
  transport()->RegisterObserver(observer());
  AddDtlsTransport();
  CompleteSctpHandshake();
  ASSERT_EQ_WAIT(SctpTransportState::kConnected, observer_.State(),
                 kDefaultTimeout);

  // AddDtlsTransport() wraps a cricket::FakeDtlsTransport, so this downcast
  // is valid.
  auto* fake_dtls =
      static_cast<cricket::FakeDtlsTransport*>(dtls_transport_->internal());
  fake_dtls->SetDtlsState(cricket::DTLS_TRANSPORT_CLOSED);

  ASSERT_EQ_WAIT(SctpTransportState::kClosed, observer_.State(),
                 kDefaultTimeout);
}
210 
211 }  // namespace webrtc
212