1 /*
2 * Copyright 2019 The WebRTC project authors. All Rights Reserved.
3 *
4 * Use of this source code is governed by a BSD-style license
5 * that can be found in the LICENSE file in the root of the source
6 * tree. An additional intellectual property rights grant can be found
7 * in the file PATENTS. All contributing project authors may
8 * be found in the AUTHORS file in the root of the source tree.
9 */
10
#include "pc/sctp_transport.h"

#include <memory>
#include <utility>
#include <vector>

#include "absl/memory/memory.h"
#include "p2p/base/fake_dtls_transport.h"
#include "pc/dtls_transport.h"
#include "rtc_base/gunit.h"
#include "test/gmock.h"
#include "test/gtest.h"
22
// Timeout, in milliseconds, handed to the ASSERT_EQ_WAIT checks in this file.
constexpr int kDefaultTimeout{1000};
// Stream count the tests configure on the fake SCTP transport.
constexpr int kTestMaxSctpStreams{1234};
25
26 using cricket::FakeDtlsTransport;
27 using ::testing::ElementsAre;
28
29 namespace webrtc {
30
31 namespace {
32
33 class FakeCricketSctpTransport : public cricket::SctpTransportInternal {
34 public:
SetDtlsTransport(rtc::PacketTransportInternal * transport)35 void SetDtlsTransport(rtc::PacketTransportInternal* transport) override {}
Start(int local_port,int remote_port,int max_message_size)36 bool Start(int local_port, int remote_port, int max_message_size) override {
37 return true;
38 }
OpenStream(int sid)39 bool OpenStream(int sid) override { return true; }
ResetStream(int sid)40 bool ResetStream(int sid) override { return true; }
SendData(const cricket::SendDataParams & params,const rtc::CopyOnWriteBuffer & payload,cricket::SendDataResult * result=nullptr)41 bool SendData(const cricket::SendDataParams& params,
42 const rtc::CopyOnWriteBuffer& payload,
43 cricket::SendDataResult* result = nullptr) override {
44 return true;
45 }
ReadyToSendData()46 bool ReadyToSendData() override { return true; }
set_debug_name_for_testing(const char * debug_name)47 void set_debug_name_for_testing(const char* debug_name) override {}
max_message_size() const48 int max_message_size() const override { return 0; }
max_outbound_streams() const49 absl::optional<int> max_outbound_streams() const override {
50 return max_outbound_streams_;
51 }
max_inbound_streams() const52 absl::optional<int> max_inbound_streams() const override {
53 return max_inbound_streams_;
54 }
55 // Methods exposed for testing
SendSignalReadyToSendData()56 void SendSignalReadyToSendData() { SignalReadyToSendData(); }
57
SendSignalAssociationChangeCommunicationUp()58 void SendSignalAssociationChangeCommunicationUp() {
59 SignalAssociationChangeCommunicationUp();
60 }
61
SendSignalClosingProcedureStartedRemotely()62 void SendSignalClosingProcedureStartedRemotely() {
63 SignalClosingProcedureStartedRemotely(1);
64 }
65
SendSignalClosingProcedureComplete()66 void SendSignalClosingProcedureComplete() {
67 SignalClosingProcedureComplete(1);
68 }
set_max_outbound_streams(int streams)69 void set_max_outbound_streams(int streams) {
70 max_outbound_streams_ = streams;
71 }
set_max_inbound_streams(int streams)72 void set_max_inbound_streams(int streams) { max_inbound_streams_ = streams; }
73
74 private:
75 absl::optional<int> max_outbound_streams_;
76 absl::optional<int> max_inbound_streams_;
77 };
78
79 } // namespace
80
81 class TestSctpTransportObserver : public SctpTransportObserverInterface {
82 public:
TestSctpTransportObserver()83 TestSctpTransportObserver() : info_(SctpTransportState::kNew) {}
84
OnStateChange(SctpTransportInformation info)85 void OnStateChange(SctpTransportInformation info) override {
86 info_ = info;
87 states_.push_back(info.state());
88 }
89
State()90 SctpTransportState State() {
91 if (states_.size() > 0) {
92 return states_[states_.size() - 1];
93 } else {
94 return SctpTransportState::kNew;
95 }
96 }
97
States()98 const std::vector<SctpTransportState>& States() { return states_; }
99
LastReceivedInformation()100 const SctpTransportInformation LastReceivedInformation() { return info_; }
101
102 private:
103 std::vector<SctpTransportState> states_;
104 SctpTransportInformation info_;
105 };
106
107 class SctpTransportTest : public ::testing::Test {
108 public:
transport()109 SctpTransport* transport() { return transport_.get(); }
observer()110 SctpTransportObserverInterface* observer() { return &observer_; }
111
CreateTransport()112 void CreateTransport() {
113 auto cricket_sctp_transport =
114 absl::WrapUnique(new FakeCricketSctpTransport());
115 transport_ = new rtc::RefCountedObject<SctpTransport>(
116 std::move(cricket_sctp_transport));
117 }
118
AddDtlsTransport()119 void AddDtlsTransport() {
120 std::unique_ptr<cricket::DtlsTransportInternal> cricket_transport =
121 std::make_unique<FakeDtlsTransport>(
122 "audio", cricket::ICE_CANDIDATE_COMPONENT_RTP);
123 dtls_transport_ =
124 new rtc::RefCountedObject<DtlsTransport>(std::move(cricket_transport));
125 transport_->SetDtlsTransport(dtls_transport_);
126 }
127
CompleteSctpHandshake()128 void CompleteSctpHandshake() {
129 CricketSctpTransport()->SendSignalReadyToSendData();
130 // The computed MaxChannels shall be the minimum of the outgoing
131 // and incoming # of streams.
132 CricketSctpTransport()->set_max_outbound_streams(kTestMaxSctpStreams);
133 CricketSctpTransport()->set_max_inbound_streams(kTestMaxSctpStreams + 1);
134 CricketSctpTransport()->SendSignalAssociationChangeCommunicationUp();
135 }
136
CricketSctpTransport()137 FakeCricketSctpTransport* CricketSctpTransport() {
138 return static_cast<FakeCricketSctpTransport*>(transport_->internal());
139 }
140
141 rtc::scoped_refptr<SctpTransport> transport_;
142 rtc::scoped_refptr<DtlsTransport> dtls_transport_;
143 TestSctpTransportObserver observer_;
144 };
145
// Lifecycle check without the fixture: a freshly created SctpTransport exposes
// its internal transport and reports kNew; after Clear() the internal
// transport is gone and the reported state is kClosed.
TEST(SctpTransportSimpleTest, CreateClearDelete) {
  // The fake has a public constructor, so std::make_unique is used rather
  // than wrapping a raw `new` expression.
  std::unique_ptr<cricket::SctpTransportInternal> fake_cricket_sctp_transport =
      std::make_unique<FakeCricketSctpTransport>();
  rtc::scoped_refptr<SctpTransport> sctp_transport =
      new rtc::RefCountedObject<SctpTransport>(
          std::move(fake_cricket_sctp_transport));
  ASSERT_TRUE(sctp_transport->internal());
  ASSERT_EQ(SctpTransportState::kNew, sctp_transport->Information().state());
  sctp_transport->Clear();
  ASSERT_FALSE(sctp_transport->internal());
  ASSERT_EQ(SctpTransportState::kClosed, sctp_transport->Information().state());
}
158
// After the fake handshake completes, the observer must have been told about
// exactly two states, in order: kConnecting, then kConnected.
TEST_F(SctpTransportTest, EventsObservedWhenConnecting) {
  // Arrange: transport with a registered observer and a DTLS transport.
  CreateTransport();
  transport_->RegisterObserver(&observer_);
  AddDtlsTransport();

  // Act.
  CompleteSctpHandshake();

  // Assert: wait for the connected state, then inspect the whole history.
  ASSERT_EQ_WAIT(SctpTransportState::kConnected, observer_.State(),
                 kDefaultTimeout);
  const std::vector<SctpTransportState>& seen_states = observer_.States();
  EXPECT_THAT(seen_states, ElementsAre(SctpTransportState::kConnecting,
                                       SctpTransportState::kConnected));
}
169
// Calling Clear() on a connected transport must be reported to the observer
// as a transition to kClosed.
TEST_F(SctpTransportTest, CloseWhenClearing) {
  // Bring the transport to the connected state first.
  CreateTransport();
  transport_->RegisterObserver(&observer_);
  AddDtlsTransport();
  CompleteSctpHandshake();
  ASSERT_EQ_WAIT(SctpTransportState::kConnected, observer_.State(),
                 kDefaultTimeout);

  // Clearing is the action under test.
  transport_->Clear();

  ASSERT_EQ_WAIT(SctpTransportState::kClosed, observer_.State(),
                 kDefaultTimeout);
}
181
// MaxChannels must be unset before the handshake and, once connected, must
// equal kTestMaxSctpStreams both when queried from the transport and in the
// last information object delivered to the observer.
TEST_F(SctpTransportTest, MaxChannelsSignalled) {
  CreateTransport();
  transport()->RegisterObserver(observer());
  AddDtlsTransport();
  EXPECT_FALSE(transport()->Information().MaxChannels());
  EXPECT_FALSE(observer_.LastReceivedInformation().MaxChannels());
  CompleteSctpHandshake();
  ASSERT_EQ_WAIT(SctpTransportState::kConnected, observer_.State(),
                 kDefaultTimeout);
  // ASSERT rather than EXPECT: MaxChannels() is dereferenced on the following
  // line, which is only valid when a value is present. A fatal assertion stops
  // the test here instead of dereferencing an unset value.
  ASSERT_TRUE(transport()->Information().MaxChannels());
  EXPECT_EQ(kTestMaxSctpStreams, *(transport()->Information().MaxChannels()));
  // Same guard for the value delivered to the observer.
  ASSERT_TRUE(observer_.LastReceivedInformation().MaxChannels());
  EXPECT_EQ(kTestMaxSctpStreams,
            *(observer_.LastReceivedInformation().MaxChannels()));
}
197
// Setting the underlying fake DTLS transport to DTLS_TRANSPORT_CLOSED must be
// reported to the observer as a transition to kClosed.
TEST_F(SctpTransportTest, CloseWhenTransportCloses) {
  // Bring the transport to the connected state first.
  CreateTransport();
  transport_->RegisterObserver(&observer_);
  AddDtlsTransport();
  CompleteSctpHandshake();
  ASSERT_EQ_WAIT(SctpTransportState::kConnected, observer_.State(),
                 kDefaultTimeout);

  // AddDtlsTransport() installed a FakeDtlsTransport, so the downcast gives
  // access to its state setter.
  auto* fake_dtls =
      static_cast<cricket::FakeDtlsTransport*>(dtls_transport_->internal());
  fake_dtls->SetDtlsState(cricket::DTLS_TRANSPORT_CLOSED);

  ASSERT_EQ_WAIT(SctpTransportState::kClosed, observer_.State(),
                 kDefaultTimeout);
}
210
211 } // namespace webrtc
212