1 // Copyright 2015 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "third_party/blink/renderer/modules/peerconnection/rtc_data_channel.h"
6 
7 #include <memory>
8 #include <string>
9 #include <utility>
10 
11 #include "base/memory/ptr_util.h"
12 #include "base/memory/scoped_refptr.h"
13 #include "base/run_loop.h"
14 #include "base/test/test_simple_task_runner.h"
15 #include "testing/gtest/include/gtest/gtest.h"
16 #include "third_party/blink/renderer/bindings/core/v8/v8_binding_for_testing.h"
17 #include "third_party/blink/renderer/core/dom/events/event.h"
18 #include "third_party/blink/renderer/core/frame/local_frame.h"
19 #include "third_party/blink/renderer/core/testing/null_execution_context.h"
20 #include "third_party/blink/renderer/modules/peerconnection/mock_rtc_peer_connection_handler_platform.h"
21 #include "third_party/blink/renderer/platform/scheduler/public/frame_scheduler.h"
22 #include "third_party/blink/renderer/platform/scheduler/public/page_scheduler.h"
23 #include "third_party/blink/renderer/platform/scheduler/public/post_cross_thread_task.h"
24 #include "third_party/blink/renderer/platform/wtf/cross_thread_functional.h"
25 #include "third_party/blink/renderer/platform/wtf/text/wtf_string.h"
26 
27 namespace blink {
28 namespace {
29 
RunSynchronous(base::TestSimpleTaskRunner * thread,CrossThreadOnceClosure closure)30 void RunSynchronous(base::TestSimpleTaskRunner* thread,
31                     CrossThreadOnceClosure closure) {
32   if (thread->BelongsToCurrentThread()) {
33     std::move(closure).Run();
34     return;
35   }
36 
37   base::WaitableEvent waitable_event(
38       base::WaitableEvent::ResetPolicy::MANUAL,
39       base::WaitableEvent::InitialState::NOT_SIGNALED);
40   PostCrossThreadTask(
41       *thread, FROM_HERE,
42       CrossThreadBindOnce(
43           [](CrossThreadOnceClosure closure, base::WaitableEvent* event) {
44             std::move(closure).Run();
45             event->Signal();
46           },
47           WTF::Passed(std::move(closure)),
48           CrossThreadUnretained(&waitable_event)));
49   waitable_event.Wait();
50 }
51 
52 class MockPeerConnectionHandler : public MockRTCPeerConnectionHandlerPlatform {
53  public:
MockPeerConnectionHandler(scoped_refptr<base::TestSimpleTaskRunner> signaling_thread)54   MockPeerConnectionHandler(
55       scoped_refptr<base::TestSimpleTaskRunner> signaling_thread)
56       : signaling_thread_(signaling_thread) {}
57 
signaling_thread() const58   scoped_refptr<base::SingleThreadTaskRunner> signaling_thread()
59       const override {
60     return signaling_thread_;
61   }
62 
RunSynchronousOnceClosureOnSignalingThread(CrossThreadOnceClosure closure,const char * trace_event_name)63   void RunSynchronousOnceClosureOnSignalingThread(
64       CrossThreadOnceClosure closure,
65       const char* trace_event_name) override {
66     closure_ = std::move(closure);
67     RunSynchronous(
68         signaling_thread_.get(),
69         CrossThreadBindOnce(&MockPeerConnectionHandler::RunOnceClosure,
70                             CrossThreadUnretained(this)));
71   }
72 
73  private:
RunOnceClosure()74   void RunOnceClosure() {
75     DCHECK(signaling_thread_->BelongsToCurrentThread());
76     std::move(closure_).Run();
77   }
78 
79   scoped_refptr<base::TestSimpleTaskRunner> signaling_thread_;
80   CrossThreadOnceClosure closure_;
81 
82   DISALLOW_COPY_AND_ASSIGN(MockPeerConnectionHandler);
83 };
84 
85 class MockDataChannel : public webrtc::DataChannelInterface {
86  public:
MockDataChannel(scoped_refptr<base::TestSimpleTaskRunner> signaling_thread)87   explicit MockDataChannel(
88       scoped_refptr<base::TestSimpleTaskRunner> signaling_thread)
89       : signaling_thread_(signaling_thread),
90         buffered_amount_(0),
91         observer_(nullptr),
92         state_(webrtc::DataChannelInterface::kConnecting) {}
93 
label() const94   std::string label() const override { return std::string(); }
reliable() const95   bool reliable() const override { return false; }
ordered() const96   bool ordered() const override { return false; }
maxPacketLifeTime() const97   absl::optional<int> maxPacketLifeTime() const override {
98     return absl::nullopt;
99   }
maxRetransmitsOpt() const100   absl::optional<int> maxRetransmitsOpt() const override {
101     return absl::nullopt;
102   }
protocol() const103   std::string protocol() const override { return std::string(); }
negotiated() const104   bool negotiated() const override { return false; }
id() const105   int id() const override { return 0; }
messages_sent() const106   uint32_t messages_sent() const override { return 0; }
bytes_sent() const107   uint64_t bytes_sent() const override { return 0; }
messages_received() const108   uint32_t messages_received() const override { return 0; }
bytes_received() const109   uint64_t bytes_received() const override { return 0; }
Close()110   void Close() override {}
111 
RegisterObserver(webrtc::DataChannelObserver * observer)112   void RegisterObserver(webrtc::DataChannelObserver* observer) override {
113     RunSynchronous(
114         signaling_thread_.get(),
115         CrossThreadBindOnce(&MockDataChannel::RegisterObserverOnSignalingThread,
116                             CrossThreadUnretained(this),
117                             CrossThreadUnretained(observer)));
118   }
119 
UnregisterObserver()120   void UnregisterObserver() override {
121     RunSynchronous(signaling_thread_.get(),
122                    CrossThreadBindOnce(
123                        &MockDataChannel::UnregisterObserverOnSignalingThread,
124                        CrossThreadUnretained(this)));
125   }
126 
buffered_amount() const127   uint64_t buffered_amount() const override {
128     uint64_t buffered_amount;
129     RunSynchronous(signaling_thread_.get(),
130                    CrossThreadBindOnce(
131                        &MockDataChannel::GetBufferedAmountOnSignalingThread,
132                        CrossThreadUnretained(this),
133                        CrossThreadUnretained(&buffered_amount)));
134     return buffered_amount;
135   }
136 
state() const137   DataState state() const override {
138     DataState state;
139     RunSynchronous(
140         signaling_thread_.get(),
141         CrossThreadBindOnce(&MockDataChannel::GetStateOnSignalingThread,
142                             CrossThreadUnretained(this),
143                             CrossThreadUnretained(&state)));
144     return state;
145   }
146 
Send(const webrtc::DataBuffer & buffer)147   bool Send(const webrtc::DataBuffer& buffer) override {
148     RunSynchronous(
149         signaling_thread_.get(),
150         CrossThreadBindOnce(&MockDataChannel::SendOnSignalingThread,
151                             CrossThreadUnretained(this), buffer.size()));
152     return true;
153   }
154 
155   // For testing.
ChangeState(DataState state)156   void ChangeState(DataState state) {
157     RunSynchronous(
158         signaling_thread_.get(),
159         CrossThreadBindOnce(&MockDataChannel::ChangeStateOnSignalingThread,
160                             CrossThreadUnretained(this), state));
161     // The observer posts the state change from the signaling thread to the main
162     // thread. Wait for the posted task to be executed.
163     base::RunLoop().RunUntilIdle();
164   }
165 
166  protected:
167   ~MockDataChannel() override = default;
168 
169  private:
RegisterObserverOnSignalingThread(webrtc::DataChannelObserver * observer)170   void RegisterObserverOnSignalingThread(
171       webrtc::DataChannelObserver* observer) {
172     DCHECK(signaling_thread_->BelongsToCurrentThread());
173     observer_ = observer;
174   }
175 
UnregisterObserverOnSignalingThread()176   void UnregisterObserverOnSignalingThread() {
177     DCHECK(signaling_thread_->BelongsToCurrentThread());
178     observer_ = nullptr;
179   }
180 
GetBufferedAmountOnSignalingThread(uint64_t * buffered_amount) const181   void GetBufferedAmountOnSignalingThread(uint64_t* buffered_amount) const {
182     DCHECK(signaling_thread_->BelongsToCurrentThread());
183     *buffered_amount = buffered_amount_;
184   }
185 
GetStateOnSignalingThread(DataState * state) const186   void GetStateOnSignalingThread(DataState* state) const {
187     DCHECK(signaling_thread_->BelongsToCurrentThread());
188     *state = state_;
189   }
190 
SendOnSignalingThread(uint64_t buffer_size)191   void SendOnSignalingThread(uint64_t buffer_size) {
192     DCHECK(signaling_thread_->BelongsToCurrentThread());
193     buffered_amount_ += buffer_size;
194   }
195 
ChangeStateOnSignalingThread(DataState state)196   void ChangeStateOnSignalingThread(DataState state) {
197     DCHECK(signaling_thread_->BelongsToCurrentThread());
198     state_ = state;
199     if (observer_) {
200       observer_->OnStateChange();
201     }
202   }
203 
204   scoped_refptr<base::TestSimpleTaskRunner> signaling_thread_;
205 
206   // Accessed on signaling thread.
207   uint64_t buffered_amount_;
208   webrtc::DataChannelObserver* observer_;
209   webrtc::DataChannelInterface::DataState state_;
210 
211   DISALLOW_COPY_AND_ASSIGN(MockDataChannel);
212 };
213 
214 class RTCDataChannelTest : public ::testing::Test {
215  public:
RTCDataChannelTest()216   RTCDataChannelTest() : signaling_thread_(new base::TestSimpleTaskRunner()) {}
217 
signaling_thread()218   scoped_refptr<base::TestSimpleTaskRunner> signaling_thread() {
219     return signaling_thread_;
220   }
221 
222  private:
223   scoped_refptr<base::TestSimpleTaskRunner> signaling_thread_;
224 
225   DISALLOW_COPY_AND_ASSIGN(RTCDataChannelTest);
226 };
227 
228 }  // namespace
229 
TEST_F(RTCDataChannelTest, ChangeStateEarly) {
  scoped_refptr<MockDataChannel> webrtc_channel(
      new rtc::RefCountedObject<MockDataChannel>(signaling_thread()));

  // Move the webrtc channel to 'open' before any blink channel wraps it.
  webrtc_channel->ChangeState(webrtc::DataChannelInterface::kOpen);

  auto pc = std::make_unique<MockPeerConnectionHandler>(signaling_thread());
  auto* channel = MakeGarbageCollected<RTCDataChannel>(
      MakeGarbageCollected<NullExecutionContext>(), webrtc_channel.get(),
      pc.get());

  // Creating the RTCDataChannel posts the state update from the signaling
  // thread to the main thread. Run the main-thread loop until the posted task
  // has executed.
  base::RunLoop().RunUntilIdle();

  // The state change that happened before creation must not be lost.
  EXPECT_EQ("open", channel->readyState());
}
251 
TEST_F(RTCDataChannelTest, BufferedAmount) {
  scoped_refptr<MockDataChannel> webrtc_channel(
      new rtc::RefCountedObject<MockDataChannel>(signaling_thread()));
  auto pc = std::make_unique<MockPeerConnectionHandler>(signaling_thread());
  auto* channel = MakeGarbageCollected<RTCDataChannel>(
      MakeGarbageCollected<NullExecutionContext>(), webrtc_channel.get(),
      pc.get());
  webrtc_channel->ChangeState(webrtc::DataChannelInterface::kOpen);

  // Sending 100 bytes must be reflected in bufferedAmount right away.
  String payload(std::string(100, 'A').c_str());
  channel->send(payload, IGNORE_EXCEPTION_FOR_TESTING);
  EXPECT_EQ(100U, channel->bufferedAmount());

  // The real send is posted to the signaling thread. Drain it so the posted
  // task does not leak.
  signaling_thread()->RunUntilIdle();
}
269 
TEST_F(RTCDataChannelTest, BufferedAmountLow) {
  scoped_refptr<MockDataChannel> webrtc_channel(
      new rtc::RefCountedObject<MockDataChannel>(signaling_thread()));
  auto pc = std::make_unique<MockPeerConnectionHandler>(signaling_thread());
  auto* channel = MakeGarbageCollected<RTCDataChannel>(
      MakeGarbageCollected<NullExecutionContext>(), webrtc_channel.get(),
      pc.get());
  webrtc_channel->ChangeState(webrtc::DataChannelInterface::kOpen);

  // With a threshold of 1 byte, the 4 buffered bytes draining below it must
  // schedule exactly one 'bufferedamountlow' event.
  channel->setBufferedAmountLowThreshold(1);
  channel->send("TEST", IGNORE_EXCEPTION_FOR_TESTING);
  EXPECT_EQ(4U, channel->bufferedAmount());
  channel->OnBufferedAmountChange(4);
  ASSERT_EQ(1U, channel->scheduled_events_.size());
  EXPECT_EQ("bufferedamountlow",
            channel->scheduled_events_.back()->type().Utf8());

  // The real send is posted to the signaling thread. Drain it so the posted
  // task does not leak.
  signaling_thread()->RunUntilIdle();
}
291 
TEST_F(RTCDataChannelTest, Open) {
  scoped_refptr<MockDataChannel> webrtc_channel(
      new rtc::RefCountedObject<MockDataChannel>(signaling_thread()));
  auto pc = std::make_unique<MockPeerConnectionHandler>(signaling_thread());
  auto* channel = MakeGarbageCollected<RTCDataChannel>(
      MakeGarbageCollected<NullExecutionContext>(), webrtc_channel.get(),
      pc.get());

  // Deliver an 'open' notification directly to the blink channel.
  channel->OnStateChange(webrtc::DataChannelInterface::kOpen);
  EXPECT_EQ("open", channel->readyState());
}
303 
TEST_F(RTCDataChannelTest, Close) {
  scoped_refptr<MockDataChannel> webrtc_channel(
      new rtc::RefCountedObject<MockDataChannel>(signaling_thread()));
  auto pc = std::make_unique<MockPeerConnectionHandler>(signaling_thread());
  auto* channel = MakeGarbageCollected<RTCDataChannel>(
      MakeGarbageCollected<NullExecutionContext>(), webrtc_channel.get(),
      pc.get());

  // Deliver a 'closed' notification directly to the blink channel.
  channel->OnStateChange(webrtc::DataChannelInterface::kClosed);
  EXPECT_EQ("closed", channel->readyState());
}
315 
TEST_F(RTCDataChannelTest, Message) {
  scoped_refptr<MockDataChannel> webrtc_channel(
      new rtc::RefCountedObject<MockDataChannel>(signaling_thread()));
  auto pc = std::make_unique<MockPeerConnectionHandler>(signaling_thread());
  auto* channel = MakeGarbageCollected<RTCDataChannel>(
      MakeGarbageCollected<NullExecutionContext>(), webrtc_channel.get(),
      pc.get());

  // An incoming buffer must schedule a single 'message' event.
  auto incoming = std::make_unique<webrtc::DataBuffer>("A");
  channel->OnMessage(std::move(incoming));
  ASSERT_EQ(1U, channel->scheduled_events_.size());
  EXPECT_EQ("message", channel->scheduled_events_.back()->type().Utf8());
}
330 
TEST_F(RTCDataChannelTest, SendAfterContextDestroyed) {
  scoped_refptr<MockDataChannel> webrtc_channel(
      new rtc::RefCountedObject<MockDataChannel>(signaling_thread()));
  auto pc = std::make_unique<MockPeerConnectionHandler>(signaling_thread());
  auto* channel = MakeGarbageCollected<RTCDataChannel>(
      MakeGarbageCollected<NullExecutionContext>(), webrtc_channel.get(),
      pc.get());
  webrtc_channel->ChangeState(webrtc::DataChannelInterface::kOpen);

  channel->ContextDestroyed();

  // Once the context is gone, send() must raise an exception.
  String payload(std::string(100, 'A').c_str());
  DummyExceptionStateForTesting exception_state;
  channel->send(payload, exception_state);

  EXPECT_TRUE(exception_state.HadException());
}
349 
TEST_F(RTCDataChannelTest, CloseAfterContextDestroyed) {
  scoped_refptr<MockDataChannel> webrtc_channel(
      new rtc::RefCountedObject<MockDataChannel>(signaling_thread()));
  auto pc = std::make_unique<MockPeerConnectionHandler>(signaling_thread());
  auto* channel = MakeGarbageCollected<RTCDataChannel>(
      MakeGarbageCollected<NullExecutionContext>(), webrtc_channel.get(),
      pc.get());
  webrtc_channel->ChangeState(webrtc::DataChannelInterface::kOpen);

  // close() after context destruction must still leave the channel 'closed'.
  channel->ContextDestroyed();
  channel->close();
  EXPECT_EQ(String::FromUTF8("closed"), channel->readyState());
}
364 
TEST_F(RTCDataChannelTest, StopsThrottling) {
  V8TestingScope scope;

  auto* page_scheduler =
      scope.GetFrame().GetFrameScheduler()->GetPageScheduler();
  EXPECT_FALSE(page_scheduler->OptedOutFromAggressiveThrottlingForTest());

  // Merely creating a (still 'connecting') channel must not opt out.
  scoped_refptr<MockDataChannel> webrtc_channel(
      new rtc::RefCountedObject<MockDataChannel>(signaling_thread()));
  auto pc = std::make_unique<MockPeerConnectionHandler>(signaling_thread());
  auto* channel = MakeGarbageCollected<RTCDataChannel>(
      scope.GetExecutionContext(), webrtc_channel.get(), pc.get());
  EXPECT_EQ("connecting", channel->readyState());
  EXPECT_FALSE(page_scheduler->OptedOutFromAggressiveThrottlingForTest());

  // Each step drives the webrtc channel to a new state, then drains the main
  // thread. ChangeState() already drains once, so the extra RunUntilIdle() is
  // redundant but harmless. It then checks both readyState and the opt-out.

  // 'open' turns the opt-out on.
  webrtc_channel->ChangeState(webrtc::DataChannelInterface::kOpen);
  base::RunLoop().RunUntilIdle();
  EXPECT_EQ("open", channel->readyState());
  EXPECT_TRUE(page_scheduler->OptedOutFromAggressiveThrottlingForTest());

  // 'closing' keeps it on.
  webrtc_channel->ChangeState(webrtc::DataChannelInterface::kClosing);
  base::RunLoop().RunUntilIdle();
  EXPECT_EQ("closing", channel->readyState());
  EXPECT_TRUE(page_scheduler->OptedOutFromAggressiveThrottlingForTest());

  // 'closed' turns it off again.
  webrtc_channel->ChangeState(webrtc::DataChannelInterface::kClosed);
  base::RunLoop().RunUntilIdle();
  EXPECT_EQ("closed", channel->readyState());
  EXPECT_FALSE(page_scheduler->OptedOutFromAggressiveThrottlingForTest());
}
399 
400 }  // namespace blink
401