1 /* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
2 /* vim: set ts=2 et sw=2 tw=80: */
3 /* This Source Code Form is subject to the terms of the Mozilla Public
4 * License, v. 2.0. If a copy of the MPL was not distributed with this file,
5 * You can obtain one at http://mozilla.org/MPL/2.0/. */
6
#include <algorithm>
#include <atomic>
#include <cstring>
#include <mutex>
#include <string>
9
10 #include "mozilla/net/WebrtcTCPSocket.h"
11 #include "mozilla/net/WebrtcTCPSocketCallback.h"
12
13 #include "nsISocketTransport.h"
14
15 #define GTEST_HAS_RTTI 0
16 #include "gtest/gtest.h"
17 #include "gtest/Helpers.h"
18 #include "gtest_utils.h"
19
// Timeout handed to ASSERT_TRUE_WAIT in every test (presumably milliseconds).
static const uint32_t kDefaultTestTimeout = 2000;
// Small payload used by the read/write tests, without its NUL terminator.
static const char kReadData[] = "Hello, World!";
static const size_t kReadDataLength = sizeof(kReadData) - 1;
static const std::string kReadDataString =
    std::string(kReadData, kReadDataLength);
// Repetition counts for the large payload built by GetDataLarge(); WriteLarge
// also slices that payload into kDataLargeOuterLoopCount chunks. These are
// never modified, so they are const like the other k-prefixed constants.
static const int kDataLargeOuterLoopCount = 128;
static const int kDataLargeInnerLoopCount = 1024;
27
28 namespace mozilla {
29
30 using namespace net;
31 using namespace testing;
32
33 class WebrtcTCPSocketTestCallback;
34
35 class FakeSocketTransportProvider : public nsISocketTransport {
36 public:
37 NS_DECL_THREADSAFE_ISUPPORTS
38
39 // nsISocketTransport
GetHost(nsACString & aHost)40 NS_IMETHOD GetHost(nsACString& aHost) override {
41 MOZ_ASSERT(false);
42 return NS_OK;
43 }
GetPort(int32_t * aPort)44 NS_IMETHOD GetPort(int32_t* aPort) override {
45 MOZ_ASSERT(false);
46 return NS_OK;
47 }
GetScriptableOriginAttributes(JSContext * cx,JS::MutableHandleValue aOriginAttributes)48 NS_IMETHOD GetScriptableOriginAttributes(
49 JSContext* cx, JS::MutableHandleValue aOriginAttributes) override {
50 MOZ_ASSERT(false);
51 return NS_OK;
52 }
SetScriptableOriginAttributes(JSContext * cx,JS::HandleValue aOriginAttributes)53 NS_IMETHOD SetScriptableOriginAttributes(
54 JSContext* cx, JS::HandleValue aOriginAttributes) override {
55 MOZ_ASSERT(false);
56 return NS_OK;
57 }
GetOriginAttributes(mozilla::OriginAttributes * _retval)58 virtual nsresult GetOriginAttributes(
59 mozilla::OriginAttributes* _retval) override {
60 MOZ_ASSERT(false);
61 return NS_OK;
62 }
SetOriginAttributes(const mozilla::OriginAttributes & aOriginAttrs)63 virtual nsresult SetOriginAttributes(
64 const mozilla::OriginAttributes& aOriginAttrs) override {
65 MOZ_ASSERT(false);
66 return NS_OK;
67 }
GetPeerAddr(mozilla::net::NetAddr * _retval)68 NS_IMETHOD GetPeerAddr(mozilla::net::NetAddr* _retval) override {
69 MOZ_ASSERT(false);
70 return NS_OK;
71 }
GetSelfAddr(mozilla::net::NetAddr * _retval)72 NS_IMETHOD GetSelfAddr(mozilla::net::NetAddr* _retval) override {
73 MOZ_ASSERT(false);
74 return NS_OK;
75 }
Bind(mozilla::net::NetAddr * aLocalAddr)76 NS_IMETHOD Bind(mozilla::net::NetAddr* aLocalAddr) override {
77 MOZ_ASSERT(false);
78 return NS_OK;
79 }
GetScriptablePeerAddr(nsINetAddr ** _retval)80 NS_IMETHOD GetScriptablePeerAddr(nsINetAddr** _retval) override {
81 MOZ_ASSERT(false);
82 return NS_OK;
83 }
GetScriptableSelfAddr(nsINetAddr ** _retval)84 NS_IMETHOD GetScriptableSelfAddr(nsINetAddr** _retval) override {
85 MOZ_ASSERT(false);
86 return NS_OK;
87 }
GetSecurityInfo(nsISupports ** aSecurityInfo)88 NS_IMETHOD GetSecurityInfo(nsISupports** aSecurityInfo) override {
89 MOZ_ASSERT(false);
90 return NS_OK;
91 }
GetSecurityCallbacks(nsIInterfaceRequestor ** aSecurityCallbacks)92 NS_IMETHOD GetSecurityCallbacks(
93 nsIInterfaceRequestor** aSecurityCallbacks) override {
94 MOZ_ASSERT(false);
95 return NS_OK;
96 }
SetSecurityCallbacks(nsIInterfaceRequestor * aSecurityCallbacks)97 NS_IMETHOD SetSecurityCallbacks(
98 nsIInterfaceRequestor* aSecurityCallbacks) override {
99 MOZ_ASSERT(false);
100 return NS_OK;
101 }
IsAlive(bool * _retval)102 NS_IMETHOD IsAlive(bool* _retval) override {
103 MOZ_ASSERT(false);
104 return NS_OK;
105 }
GetTimeout(uint32_t aType,uint32_t * _retval)106 NS_IMETHOD GetTimeout(uint32_t aType, uint32_t* _retval) override {
107 MOZ_ASSERT(false);
108 return NS_OK;
109 }
SetTimeout(uint32_t aType,uint32_t aValue)110 NS_IMETHOD SetTimeout(uint32_t aType, uint32_t aValue) override {
111 MOZ_ASSERT(false);
112 return NS_OK;
113 }
SetLinger(bool aPolarity,int16_t aTimeout)114 NS_IMETHOD SetLinger(bool aPolarity, int16_t aTimeout) override {
115 MOZ_ASSERT(false);
116 return NS_OK;
117 }
SetReuseAddrPort(bool reuseAddrPort)118 NS_IMETHOD SetReuseAddrPort(bool reuseAddrPort) override {
119 MOZ_ASSERT(false);
120 return NS_OK;
121 }
GetConnectionFlags(uint32_t * aConnectionFlags)122 NS_IMETHOD GetConnectionFlags(uint32_t* aConnectionFlags) override {
123 MOZ_ASSERT(false);
124 return NS_OK;
125 }
SetConnectionFlags(uint32_t aConnectionFlags)126 NS_IMETHOD SetConnectionFlags(uint32_t aConnectionFlags) override {
127 MOZ_ASSERT(false);
128 return NS_OK;
129 }
SetIsPrivate(bool)130 NS_IMETHOD SetIsPrivate(bool) override {
131 MOZ_ASSERT(false);
132 return NS_OK;
133 }
GetTlsFlags(uint32_t * aTlsFlags)134 NS_IMETHOD GetTlsFlags(uint32_t* aTlsFlags) override {
135 MOZ_ASSERT(false);
136 return NS_OK;
137 }
SetTlsFlags(uint32_t aTlsFlags)138 NS_IMETHOD SetTlsFlags(uint32_t aTlsFlags) override {
139 MOZ_ASSERT(false);
140 return NS_OK;
141 }
GetQoSBits(uint8_t * aQoSBits)142 NS_IMETHOD GetQoSBits(uint8_t* aQoSBits) override {
143 MOZ_ASSERT(false);
144 return NS_OK;
145 }
SetQoSBits(uint8_t aQoSBits)146 NS_IMETHOD SetQoSBits(uint8_t aQoSBits) override {
147 MOZ_ASSERT(false);
148 return NS_OK;
149 }
GetRecvBufferSize(uint32_t * aRecvBufferSize)150 NS_IMETHOD GetRecvBufferSize(uint32_t* aRecvBufferSize) override {
151 MOZ_ASSERT(false);
152 return NS_OK;
153 }
GetSendBufferSize(uint32_t * aSendBufferSize)154 NS_IMETHOD GetSendBufferSize(uint32_t* aSendBufferSize) override {
155 MOZ_ASSERT(false);
156 return NS_OK;
157 }
GetKeepaliveEnabled(bool * aKeepaliveEnabled)158 NS_IMETHOD GetKeepaliveEnabled(bool* aKeepaliveEnabled) override {
159 MOZ_ASSERT(false);
160 return NS_OK;
161 }
SetKeepaliveEnabled(bool aKeepaliveEnabled)162 NS_IMETHOD SetKeepaliveEnabled(bool aKeepaliveEnabled) override {
163 MOZ_ASSERT(false);
164 return NS_OK;
165 }
SetKeepaliveVals(int32_t keepaliveIdleTime,int32_t keepaliveRetryInterval)166 NS_IMETHOD SetKeepaliveVals(int32_t keepaliveIdleTime,
167 int32_t keepaliveRetryInterval) override {
168 MOZ_ASSERT(false);
169 return NS_OK;
170 }
GetResetIPFamilyPreference(bool * aResetIPFamilyPreference)171 NS_IMETHOD GetResetIPFamilyPreference(
172 bool* aResetIPFamilyPreference) override {
173 MOZ_ASSERT(false);
174 return NS_OK;
175 }
GetEchConfigUsed(bool * aEchConfigUsed)176 NS_IMETHOD GetEchConfigUsed(bool* aEchConfigUsed) override {
177 MOZ_ASSERT(false);
178 return NS_OK;
179 }
SetEchConfig(const nsACString & aEchConfig)180 NS_IMETHOD SetEchConfig(const nsACString& aEchConfig) override {
181 MOZ_ASSERT(false);
182 return NS_OK;
183 }
ResolvedByTRR(bool * _retval)184 NS_IMETHOD ResolvedByTRR(bool* _retval) override {
185 MOZ_ASSERT(false);
186 return NS_OK;
187 }
GetRetryDnsIfPossible(bool * aRetryDns)188 NS_IMETHOD GetRetryDnsIfPossible(bool* aRetryDns) override {
189 MOZ_ASSERT(false);
190 return NS_OK;
191 }
GetStatus(nsresult * aStatus)192 NS_IMETHOD GetStatus(nsresult* aStatus) override {
193 MOZ_ASSERT(false);
194 return NS_OK;
195 }
196
197 // nsITransport
OpenInputStream(uint32_t aFlags,uint32_t aSegmentSize,uint32_t aSegmentCount,nsIInputStream ** _retval)198 NS_IMETHOD OpenInputStream(uint32_t aFlags, uint32_t aSegmentSize,
199 uint32_t aSegmentCount,
200 nsIInputStream** _retval) override {
201 MOZ_ASSERT(false);
202 return NS_OK;
203 }
OpenOutputStream(uint32_t aFlags,uint32_t aSegmentSize,uint32_t aSegmentCount,nsIOutputStream ** _retval)204 NS_IMETHOD OpenOutputStream(uint32_t aFlags, uint32_t aSegmentSize,
205 uint32_t aSegmentCount,
206 nsIOutputStream** _retval) override {
207 MOZ_ASSERT(false);
208 return NS_OK;
209 }
SetEventSink(nsITransportEventSink * aSink,nsIEventTarget * aEventTarget)210 NS_IMETHOD SetEventSink(nsITransportEventSink* aSink,
211 nsIEventTarget* aEventTarget) override {
212 MOZ_ASSERT(false);
213 return NS_OK;
214 }
215
216 // fake except for these methods which are OK to call
217 // nsISocketTransport
SetRecvBufferSize(uint32_t aRecvBufferSize)218 NS_IMETHOD SetRecvBufferSize(uint32_t aRecvBufferSize) override {
219 return NS_OK;
220 }
SetSendBufferSize(uint32_t aSendBufferSize)221 NS_IMETHOD SetSendBufferSize(uint32_t aSendBufferSize) override {
222 return NS_OK;
223 }
224 // nsITransport
Close(nsresult aReason)225 NS_IMETHOD Close(nsresult aReason) override { return NS_OK; }
226
227 protected:
228 virtual ~FakeSocketTransportProvider() = default;
229 };
230
231 NS_IMPL_ISUPPORTS(FakeSocketTransportProvider, nsISocketTransport, nsITransport)
232
233 // Implements some common elements to WebrtcTCPSocketTestOutputStream and
234 // WebrtcTCPSocketTestInputStream.
235 class WebrtcTCPSocketTestStream {
236 public:
237 WebrtcTCPSocketTestStream();
238
Fail()239 void Fail() { mMustFail = true; }
240
241 size_t DataLength();
242 template <typename T>
243 void AppendElements(const T* aBuffer, size_t aLength);
244
245 protected:
246 virtual ~WebrtcTCPSocketTestStream() = default;
247
248 nsTArray<uint8_t> mData;
249 std::mutex mDataMutex;
250
251 bool mMustFail;
252 };
253
WebrtcTCPSocketTestStream()254 WebrtcTCPSocketTestStream::WebrtcTCPSocketTestStream() : mMustFail(false) {}
255
256 template <typename T>
AppendElements(const T * aBuffer,size_t aLength)257 void WebrtcTCPSocketTestStream::AppendElements(const T* aBuffer,
258 size_t aLength) {
259 std::lock_guard<std::mutex> guard(mDataMutex);
260 mData.AppendElements(aBuffer, aLength);
261 }
262
DataLength()263 size_t WebrtcTCPSocketTestStream::DataLength() {
264 std::lock_guard<std::mutex> guard(mDataMutex);
265 return mData.Length();
266 }
267
268 class WebrtcTCPSocketTestInputStream : public nsIAsyncInputStream,
269 public WebrtcTCPSocketTestStream {
270 public:
271 NS_DECL_THREADSAFE_ISUPPORTS
272 NS_DECL_NSIASYNCINPUTSTREAM
273 NS_DECL_NSIINPUTSTREAM
274
WebrtcTCPSocketTestInputStream()275 WebrtcTCPSocketTestInputStream()
276 : mMaxReadSize(1024 * 1024), mAllowCallbacks(false) {}
277
278 void DoCallback();
279 void CallCallback(const nsCOMPtr<nsIInputStreamCallback>& aCallback);
AllowCallbacks()280 void AllowCallbacks() { mAllowCallbacks = true; }
281
282 size_t mMaxReadSize;
283
284 protected:
285 virtual ~WebrtcTCPSocketTestInputStream() = default;
286
287 private:
288 nsCOMPtr<nsIInputStreamCallback> mCallback;
289 nsCOMPtr<nsIEventTarget> mCallbackTarget;
290
291 bool mAllowCallbacks;
292 };
293
NS_IMPL_ISUPPORTS(WebrtcTCPSocketTestInputStream,nsIAsyncInputStream,nsIInputStream)294 NS_IMPL_ISUPPORTS(WebrtcTCPSocketTestInputStream, nsIAsyncInputStream,
295 nsIInputStream)
296
297 nsresult WebrtcTCPSocketTestInputStream::AsyncWait(
298 nsIInputStreamCallback* aCallback, uint32_t aFlags,
299 uint32_t aRequestedCount, nsIEventTarget* aEventTarget) {
300 MOZ_ASSERT(!aEventTarget, "no event target should be set");
301
302 mCallback = aCallback;
303 mCallbackTarget = NS_GetCurrentThread();
304
305 if (mAllowCallbacks && DataLength() > 0) {
306 DoCallback();
307 }
308
309 return NS_OK;
310 }
311
CloseWithStatus(nsresult aStatus)312 nsresult WebrtcTCPSocketTestInputStream::CloseWithStatus(nsresult aStatus) {
313 return Close();
314 }
315
Close()316 nsresult WebrtcTCPSocketTestInputStream::Close() { return NS_OK; }
317
Available(uint64_t * aAvailable)318 nsresult WebrtcTCPSocketTestInputStream::Available(uint64_t* aAvailable) {
319 *aAvailable = DataLength();
320 return NS_OK;
321 }
322
Read(char * aBuffer,uint32_t aCount,uint32_t * aRead)323 nsresult WebrtcTCPSocketTestInputStream::Read(char* aBuffer, uint32_t aCount,
324 uint32_t* aRead) {
325 std::lock_guard<std::mutex> guard(mDataMutex);
326 if (mMustFail) {
327 return NS_ERROR_FAILURE;
328 }
329 *aRead = std::min({(size_t)aCount, mData.Length(), mMaxReadSize});
330 memcpy(aBuffer, mData.Elements(), *aRead);
331 mData.RemoveElementsAt(0, *aRead);
332 return *aRead > 0 ? NS_OK : NS_BASE_STREAM_WOULD_BLOCK;
333 }
334
ReadSegments(nsWriteSegmentFun aWriter,void * aClosure,uint32_t aCount,uint32_t * _retval)335 nsresult WebrtcTCPSocketTestInputStream::ReadSegments(nsWriteSegmentFun aWriter,
336 void* aClosure,
337 uint32_t aCount,
338 uint32_t* _retval) {
339 MOZ_ASSERT(false);
340 return NS_OK;
341 }
342
IsNonBlocking(bool * aIsNonBlocking)343 nsresult WebrtcTCPSocketTestInputStream::IsNonBlocking(bool* aIsNonBlocking) {
344 *aIsNonBlocking = true;
345 return NS_OK;
346 }
347
CallCallback(const nsCOMPtr<nsIInputStreamCallback> & aCallback)348 void WebrtcTCPSocketTestInputStream::CallCallback(
349 const nsCOMPtr<nsIInputStreamCallback>& aCallback) {
350 aCallback->OnInputStreamReady(this);
351 }
352
DoCallback()353 void WebrtcTCPSocketTestInputStream::DoCallback() {
354 if (mCallback) {
355 mCallbackTarget->Dispatch(
356 NewRunnableMethod<const nsCOMPtr<nsIInputStreamCallback>&>(
357 "WebrtcTCPSocketTestInputStream::DoCallback", this,
358 &WebrtcTCPSocketTestInputStream::CallCallback,
359 std::move(mCallback)));
360
361 mCallbackTarget = nullptr;
362 }
363 }
364
365 class WebrtcTCPSocketTestOutputStream : public nsIAsyncOutputStream,
366 public WebrtcTCPSocketTestStream {
367 public:
368 NS_DECL_THREADSAFE_ISUPPORTS
369 NS_DECL_NSIASYNCOUTPUTSTREAM
370 NS_DECL_NSIOUTPUTSTREAM
371
WebrtcTCPSocketTestOutputStream()372 WebrtcTCPSocketTestOutputStream() : mMaxWriteSize(1024 * 1024) {}
373
374 void DoCallback();
375 void CallCallback(const nsCOMPtr<nsIOutputStreamCallback>& aCallback);
376
377 std::string DataString();
378
379 uint32_t mMaxWriteSize;
380
381 protected:
382 virtual ~WebrtcTCPSocketTestOutputStream() = default;
383
384 private:
385 nsCOMPtr<nsIOutputStreamCallback> mCallback;
386 nsCOMPtr<nsIEventTarget> mCallbackTarget;
387 };
388
NS_IMPL_ISUPPORTS(WebrtcTCPSocketTestOutputStream,nsIAsyncOutputStream,nsIOutputStream)389 NS_IMPL_ISUPPORTS(WebrtcTCPSocketTestOutputStream, nsIAsyncOutputStream,
390 nsIOutputStream)
391
392 nsresult WebrtcTCPSocketTestOutputStream::AsyncWait(
393 nsIOutputStreamCallback* aCallback, uint32_t aFlags,
394 uint32_t aRequestedCount, nsIEventTarget* aEventTarget) {
395 MOZ_ASSERT(!aEventTarget, "no event target should be set");
396
397 mCallback = aCallback;
398 mCallbackTarget = NS_GetCurrentThread();
399
400 return NS_OK;
401 }
402
CloseWithStatus(nsresult aStatus)403 nsresult WebrtcTCPSocketTestOutputStream::CloseWithStatus(nsresult aStatus) {
404 return Close();
405 }
406
Close()407 nsresult WebrtcTCPSocketTestOutputStream::Close() { return NS_OK; }
408
Flush()409 nsresult WebrtcTCPSocketTestOutputStream::Flush() { return NS_OK; }
410
Write(const char * aBuffer,uint32_t aCount,uint32_t * aWrote)411 nsresult WebrtcTCPSocketTestOutputStream::Write(const char* aBuffer,
412 uint32_t aCount,
413 uint32_t* aWrote) {
414 if (mMustFail) {
415 return NS_ERROR_FAILURE;
416 }
417 *aWrote = std::min(aCount, mMaxWriteSize);
418 AppendElements(aBuffer, *aWrote);
419 return NS_OK;
420 }
421
WriteSegments(nsReadSegmentFun aReader,void * aClosure,uint32_t aCount,uint32_t * _retval)422 nsresult WebrtcTCPSocketTestOutputStream::WriteSegments(
423 nsReadSegmentFun aReader, void* aClosure, uint32_t aCount,
424 uint32_t* _retval) {
425 MOZ_ASSERT(false);
426 return NS_OK;
427 }
428
WriteFrom(nsIInputStream * aFromStream,uint32_t aCount,uint32_t * _retval)429 nsresult WebrtcTCPSocketTestOutputStream::WriteFrom(nsIInputStream* aFromStream,
430 uint32_t aCount,
431 uint32_t* _retval) {
432 MOZ_ASSERT(false);
433 return NS_OK;
434 }
435
IsNonBlocking(bool * aIsNonBlocking)436 nsresult WebrtcTCPSocketTestOutputStream::IsNonBlocking(bool* aIsNonBlocking) {
437 *aIsNonBlocking = true;
438 return NS_OK;
439 }
440
CallCallback(const nsCOMPtr<nsIOutputStreamCallback> & aCallback)441 void WebrtcTCPSocketTestOutputStream::CallCallback(
442 const nsCOMPtr<nsIOutputStreamCallback>& aCallback) {
443 aCallback->OnOutputStreamReady(this);
444 }
445
DoCallback()446 void WebrtcTCPSocketTestOutputStream::DoCallback() {
447 if (mCallback) {
448 mCallbackTarget->Dispatch(
449 NewRunnableMethod<const nsCOMPtr<nsIOutputStreamCallback>&>(
450 "WebrtcTCPSocketTestOutputStream::CallCallback", this,
451 &WebrtcTCPSocketTestOutputStream::CallCallback,
452 std::move(mCallback)));
453
454 mCallbackTarget = nullptr;
455 }
456 }
457
DataString()458 std::string WebrtcTCPSocketTestOutputStream::DataString() {
459 std::lock_guard<std::mutex> guard(mDataMutex);
460 return std::string((char*)mData.Elements(), mData.Length());
461 }
462
// Fake as in not the real WebrtcTCPSocket but real enough: only the three
// Invoke* notification hooks are overridden, and each one forwards straight
// to mProxyCallbacks on the calling thread (see the definitions below);
// everything else is the inherited WebrtcTCPSocket behavior under test.
class FakeWebrtcTCPSocket : public WebrtcTCPSocket {
 public:
  explicit FakeWebrtcTCPSocket(WebrtcTCPSocketCallback* aCallback)
      : WebrtcTCPSocket(aCallback) {}

 protected:
  virtual ~FakeWebrtcTCPSocket() = default;

  // WebrtcTCPSocket notification hooks.
  void InvokeOnClose(nsresult aReason) override;
  void InvokeOnConnected() override;
  void InvokeOnRead(nsTArray<uint8_t>&& aReadData) override;
};
476
InvokeOnClose(nsresult aReason)477 void FakeWebrtcTCPSocket::InvokeOnClose(nsresult aReason) {
478 mProxyCallbacks->OnClose(aReason);
479 }
480
InvokeOnConnected()481 void FakeWebrtcTCPSocket::InvokeOnConnected() {
482 mProxyCallbacks->OnConnected("http"_ns);
483 }
484
InvokeOnRead(nsTArray<uint8_t> && aReadData)485 void FakeWebrtcTCPSocket::InvokeOnRead(nsTArray<uint8_t>&& aReadData) {
486 mProxyCallbacks->OnRead(std::move(aReadData));
487 }
488
489 class WebrtcTCPSocketTest : public MtransportTest {
490 public:
WebrtcTCPSocketTest()491 WebrtcTCPSocketTest()
492 : MtransportTest(),
493 mSocketThread(nullptr),
494 mSocketTransport(nullptr),
495 mInputStream(nullptr),
496 mOutputStream(nullptr),
497 mChannel(nullptr),
498 mCallback(nullptr),
499 mOnCloseCalled(false),
500 mOnConnectedCalled(false) {}
501
502 // WebrtcTCPSocketCallback forwards from mCallback
503 void OnClose(nsresult aReason);
504 void OnConnected(const nsCString& aProxyType);
505 void OnRead(nsTArray<uint8_t>&& aReadData);
506
507 void SetUp() override;
508 void TearDown() override;
509
510 void DoTransportAvailable();
511
512 std::string ReadDataAsString();
513 std::string GetDataLarge();
514
515 nsCOMPtr<nsIEventTarget> mSocketThread;
516
517 nsCOMPtr<nsISocketTransport> mSocketTransport;
518 RefPtr<WebrtcTCPSocketTestInputStream> mInputStream;
519 RefPtr<WebrtcTCPSocketTestOutputStream> mOutputStream;
520 RefPtr<FakeWebrtcTCPSocket> mChannel;
521 RefPtr<WebrtcTCPSocketTestCallback> mCallback;
522
523 bool mOnCloseCalled;
524 bool mOnConnectedCalled;
525
526 size_t ReadDataLength();
527 template <typename T>
528 void AppendReadData(const T* aBuffer, size_t aLength);
529
530 private:
531 nsTArray<uint8_t> mReadData;
532 std::mutex mReadDataMutex;
533 };
534
// WebrtcTCPSocketCallback implementation that relays every notification to
// the corresponding method on the owning WebrtcTCPSocketTest fixture.
class WebrtcTCPSocketTestCallback : public WebrtcTCPSocketCallback {
 public:
  NS_INLINE_DECL_THREADSAFE_REFCOUNTING(WebrtcTCPSocketTestCallback, override)

  explicit WebrtcTCPSocketTestCallback(WebrtcTCPSocketTest* aTest)
      : mTest(aTest) {}

  // WebrtcTCPSocketCallback
  void OnClose(nsresult aReason) override;
  void OnConnected(const nsCString& aProxyType) override;
  void OnRead(nsTArray<uint8_t>&& aReadData) override;

 protected:
  virtual ~WebrtcTCPSocketTestCallback() = default;

 private:
  // Non-owning back-pointer to the fixture. NOTE(review): nothing here keeps
  // the fixture alive, so a notification delivered after the test finishes
  // would dangle; confirm the channel is quiescent before fixture teardown.
  WebrtcTCPSocketTest* mTest;
};
553
SetUp()554 void WebrtcTCPSocketTest::SetUp() {
555 nsresult rv;
556 // WebrtcTCPSocket's threading model is the same as mtransport
557 // all socket operations are done on the socket thread
558 // callbacks are invoked on the main thread
559 mSocketThread = do_GetService(NS_SOCKETTRANSPORTSERVICE_CONTRACTID, &rv);
560 ASSERT_TRUE(NS_SUCCEEDED(rv));
561
562 mSocketTransport = new FakeSocketTransportProvider();
563 mInputStream = new WebrtcTCPSocketTestInputStream();
564 mOutputStream = new WebrtcTCPSocketTestOutputStream();
565 mCallback = new WebrtcTCPSocketTestCallback(this);
566 mChannel = new FakeWebrtcTCPSocket(mCallback.get());
567 }
568
TearDown()569 void WebrtcTCPSocketTest::TearDown() {}
570
571 // WebrtcTCPSocketCallback
OnRead(nsTArray<uint8_t> && aReadData)572 void WebrtcTCPSocketTest::OnRead(nsTArray<uint8_t>&& aReadData) {
573 AppendReadData(aReadData.Elements(), aReadData.Length());
574 }
575
OnConnected(const nsCString & aProxyType)576 void WebrtcTCPSocketTest::OnConnected(const nsCString& aProxyType) {
577 mOnConnectedCalled = true;
578 }
579
OnClose(nsresult aReason)580 void WebrtcTCPSocketTest::OnClose(nsresult aReason) { mOnCloseCalled = true; }
581
DoTransportAvailable()582 void WebrtcTCPSocketTest::DoTransportAvailable() {
583 if (!mSocketThread->IsOnCurrentThread()) {
584 mSocketThread->Dispatch(
585 NS_NewRunnableFunction("DoTransportAvailable", [this]() -> void {
586 nsresult rv;
587 rv = mChannel->OnTransportAvailable(mSocketTransport, mInputStream,
588 mOutputStream);
589 ASSERT_EQ(NS_OK, rv);
590 }));
591 } else {
592 // should always be called on the main thread
593 MOZ_ASSERT(0);
594 }
595 }
596
ReadDataAsString()597 std::string WebrtcTCPSocketTest::ReadDataAsString() {
598 std::lock_guard<std::mutex> guard(mReadDataMutex);
599 return std::string((char*)mReadData.Elements(), mReadData.Length());
600 }
601
GetDataLarge()602 std::string WebrtcTCPSocketTest::GetDataLarge() {
603 std::string data;
604 for (int i = 0; i < kDataLargeOuterLoopCount * kDataLargeInnerLoopCount;
605 ++i) {
606 data += kReadData;
607 }
608 return data;
609 }
610
611 template <typename T>
AppendReadData(const T * aBuffer,size_t aLength)612 void WebrtcTCPSocketTest::AppendReadData(const T* aBuffer, size_t aLength) {
613 std::lock_guard<std::mutex> guard(mReadDataMutex);
614 mReadData.AppendElements(aBuffer, aLength);
615 }
616
ReadDataLength()617 size_t WebrtcTCPSocketTest::ReadDataLength() {
618 std::lock_guard<std::mutex> guard(mReadDataMutex);
619 return mReadData.Length();
620 }
621
OnClose(nsresult aReason)622 void WebrtcTCPSocketTestCallback::OnClose(nsresult aReason) {
623 mTest->OnClose(aReason);
624 }
625
OnConnected(const nsCString & aProxyType)626 void WebrtcTCPSocketTestCallback::OnConnected(const nsCString& aProxyType) {
627 mTest->OnConnected(aProxyType);
628 }
629
OnRead(nsTArray<uint8_t> && aReadData)630 void WebrtcTCPSocketTestCallback::OnRead(nsTArray<uint8_t>&& aReadData) {
631 mTest->OnRead(std::move(aReadData));
632 }
633
634 } // namespace mozilla
635
636 typedef mozilla::WebrtcTCPSocketTest WebrtcTCPSocketTest;
637
TEST_F(WebrtcTCPSocketTest,SetUp)638 TEST_F(WebrtcTCPSocketTest, SetUp) {}
639
TEST_F(WebrtcTCPSocketTest,TransportAvailable)640 TEST_F(WebrtcTCPSocketTest, TransportAvailable) {
641 DoTransportAvailable();
642 ASSERT_TRUE_WAIT(mOnConnectedCalled, kDefaultTestTimeout);
643 }
644
TEST_F(WebrtcTCPSocketTest,Read)645 TEST_F(WebrtcTCPSocketTest, Read) {
646 DoTransportAvailable();
647 ASSERT_TRUE_WAIT(mOnConnectedCalled, kDefaultTestTimeout);
648
649 mInputStream->AppendElements(kReadData, kReadDataLength);
650 mInputStream->DoCallback();
651
652 ASSERT_TRUE_WAIT(ReadDataAsString() == kReadDataString, kDefaultTestTimeout);
653 }
654
TEST_F(WebrtcTCPSocketTest,Write)655 TEST_F(WebrtcTCPSocketTest, Write) {
656 DoTransportAvailable();
657 ASSERT_TRUE_WAIT(mOnConnectedCalled, kDefaultTestTimeout);
658
659 nsTArray<uint8_t> data;
660 data.AppendElements(kReadData, kReadDataLength);
661 mChannel->Write(std::move(data));
662
663 ASSERT_TRUE_WAIT(mChannel->CountUnwrittenBytes() == kReadDataLength,
664 kDefaultTestTimeout);
665
666 mOutputStream->DoCallback();
667
668 ASSERT_TRUE_WAIT(mOutputStream->DataString() == kReadDataString,
669 kDefaultTestTimeout);
670 }
671
TEST_F(WebrtcTCPSocketTest,ReadFail)672 TEST_F(WebrtcTCPSocketTest, ReadFail) {
673 DoTransportAvailable();
674 ASSERT_TRUE_WAIT(mOnConnectedCalled, kDefaultTestTimeout);
675
676 mInputStream->AppendElements(kReadData, kReadDataLength);
677 mInputStream->Fail();
678 mInputStream->DoCallback();
679
680 ASSERT_TRUE_WAIT(mOnCloseCalled, kDefaultTestTimeout);
681 ASSERT_EQ(0U, ReadDataLength());
682 }
683
TEST_F(WebrtcTCPSocketTest,WriteFail)684 TEST_F(WebrtcTCPSocketTest, WriteFail) {
685 DoTransportAvailable();
686 ASSERT_TRUE_WAIT(mOnConnectedCalled, kDefaultTestTimeout);
687
688 nsTArray<uint8_t> array;
689 array.AppendElements(kReadData, kReadDataLength);
690 mChannel->Write(std::move(array));
691
692 ASSERT_TRUE_WAIT(mChannel->CountUnwrittenBytes() == kReadDataLength,
693 kDefaultTestTimeout);
694
695 mOutputStream->Fail();
696 mOutputStream->DoCallback();
697
698 ASSERT_TRUE_WAIT(mOnCloseCalled, kDefaultTestTimeout);
699 ASSERT_EQ(0U, mOutputStream->DataLength());
700 }
701
TEST_F(WebrtcTCPSocketTest,ReadLarge)702 TEST_F(WebrtcTCPSocketTest, ReadLarge) {
703 DoTransportAvailable();
704 ASSERT_TRUE_WAIT(mOnConnectedCalled, kDefaultTestTimeout);
705
706 const std::string data = GetDataLarge();
707
708 mInputStream->AppendElements(data.c_str(), data.length());
709 // make sure reading loops more than once
710 mInputStream->mMaxReadSize = 3072;
711 mInputStream->AllowCallbacks();
712 mInputStream->DoCallback();
713
714 ASSERT_TRUE_WAIT(ReadDataAsString() == data, kDefaultTestTimeout);
715 }
716
TEST_F(WebrtcTCPSocketTest,WriteLarge)717 TEST_F(WebrtcTCPSocketTest, WriteLarge) {
718 DoTransportAvailable();
719 ASSERT_TRUE_WAIT(mOnConnectedCalled, kDefaultTestTimeout);
720
721 const std::string data = GetDataLarge();
722
723 for (int i = 0; i < kDataLargeOuterLoopCount; ++i) {
724 nsTArray<uint8_t> array;
725 int chunkSize = kReadDataString.length() * kDataLargeInnerLoopCount;
726 int offset = i * chunkSize;
727 array.AppendElements(data.c_str() + offset, chunkSize);
728 mChannel->Write(std::move(array));
729 }
730
731 ASSERT_TRUE_WAIT(mChannel->CountUnwrittenBytes() == data.length(),
732 kDefaultTestTimeout);
733
734 // make sure writing loops more than once per write request
735 mOutputStream->mMaxWriteSize = 1024;
736 mOutputStream->DoCallback();
737
738 ASSERT_TRUE_WAIT(mOutputStream->DataString() == data, kDefaultTestTimeout);
739 }
740