1 /* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
2 /* vim: set ts=2 et sw=2 tw=80: */
3 /* This Source Code Form is subject to the terms of the Mozilla Public
4  * License, v. 2.0. If a copy of the MPL was not distributed with this file,
5  * You can obtain one at http://mozilla.org/MPL/2.0/. */
6 
#include <algorithm>
#include <atomic>
#include <mutex>
9 
10 #include "mozilla/net/WebrtcTCPSocket.h"
11 #include "mozilla/net/WebrtcTCPSocketCallback.h"
12 
13 #include "nsISocketTransport.h"
14 
15 #define GTEST_HAS_RTTI 0
16 #include "gtest/gtest.h"
17 #include "gtest/Helpers.h"
18 #include "gtest_utils.h"
19 
// Upper bound, in milliseconds, for every ASSERT_TRUE_WAIT in this file.
static constexpr uint32_t kDefaultTestTimeout = 2000;
// Payload used by the small read/write tests (without the trailing NUL).
static constexpr char kReadData[] = "Hello, World!";
static constexpr size_t kReadDataLength = sizeof(kReadData) - 1;
static const std::string kReadDataString(kReadData, kReadDataLength);
// The "large" tests move kDataLargeOuterLoopCount * kDataLargeInnerLoopCount
// copies of kReadData. These are fixed parameters, so they are const rather
// than mutable globals.
static const int kDataLargeOuterLoopCount = 128;
static const int kDataLargeInnerLoopCount = 1024;
27 
28 namespace mozilla {
29 
30 using namespace net;
31 using namespace testing;
32 
33 class WebrtcTCPSocketTestCallback;
34 
35 class FakeSocketTransportProvider : public nsISocketTransport {
36  public:
37   NS_DECL_THREADSAFE_ISUPPORTS
38 
39   // nsISocketTransport
GetHost(nsACString & aHost)40   NS_IMETHOD GetHost(nsACString& aHost) override {
41     MOZ_ASSERT(false);
42     return NS_OK;
43   }
GetPort(int32_t * aPort)44   NS_IMETHOD GetPort(int32_t* aPort) override {
45     MOZ_ASSERT(false);
46     return NS_OK;
47   }
GetScriptableOriginAttributes(JSContext * cx,JS::MutableHandleValue aOriginAttributes)48   NS_IMETHOD GetScriptableOriginAttributes(
49       JSContext* cx, JS::MutableHandleValue aOriginAttributes) override {
50     MOZ_ASSERT(false);
51     return NS_OK;
52   }
SetScriptableOriginAttributes(JSContext * cx,JS::HandleValue aOriginAttributes)53   NS_IMETHOD SetScriptableOriginAttributes(
54       JSContext* cx, JS::HandleValue aOriginAttributes) override {
55     MOZ_ASSERT(false);
56     return NS_OK;
57   }
GetOriginAttributes(mozilla::OriginAttributes * _retval)58   virtual nsresult GetOriginAttributes(
59       mozilla::OriginAttributes* _retval) override {
60     MOZ_ASSERT(false);
61     return NS_OK;
62   }
SetOriginAttributes(const mozilla::OriginAttributes & aOriginAttrs)63   virtual nsresult SetOriginAttributes(
64       const mozilla::OriginAttributes& aOriginAttrs) override {
65     MOZ_ASSERT(false);
66     return NS_OK;
67   }
GetPeerAddr(mozilla::net::NetAddr * _retval)68   NS_IMETHOD GetPeerAddr(mozilla::net::NetAddr* _retval) override {
69     MOZ_ASSERT(false);
70     return NS_OK;
71   }
GetSelfAddr(mozilla::net::NetAddr * _retval)72   NS_IMETHOD GetSelfAddr(mozilla::net::NetAddr* _retval) override {
73     MOZ_ASSERT(false);
74     return NS_OK;
75   }
Bind(mozilla::net::NetAddr * aLocalAddr)76   NS_IMETHOD Bind(mozilla::net::NetAddr* aLocalAddr) override {
77     MOZ_ASSERT(false);
78     return NS_OK;
79   }
GetScriptablePeerAddr(nsINetAddr ** _retval)80   NS_IMETHOD GetScriptablePeerAddr(nsINetAddr** _retval) override {
81     MOZ_ASSERT(false);
82     return NS_OK;
83   }
GetScriptableSelfAddr(nsINetAddr ** _retval)84   NS_IMETHOD GetScriptableSelfAddr(nsINetAddr** _retval) override {
85     MOZ_ASSERT(false);
86     return NS_OK;
87   }
GetSecurityInfo(nsISupports ** aSecurityInfo)88   NS_IMETHOD GetSecurityInfo(nsISupports** aSecurityInfo) override {
89     MOZ_ASSERT(false);
90     return NS_OK;
91   }
GetSecurityCallbacks(nsIInterfaceRequestor ** aSecurityCallbacks)92   NS_IMETHOD GetSecurityCallbacks(
93       nsIInterfaceRequestor** aSecurityCallbacks) override {
94     MOZ_ASSERT(false);
95     return NS_OK;
96   }
SetSecurityCallbacks(nsIInterfaceRequestor * aSecurityCallbacks)97   NS_IMETHOD SetSecurityCallbacks(
98       nsIInterfaceRequestor* aSecurityCallbacks) override {
99     MOZ_ASSERT(false);
100     return NS_OK;
101   }
IsAlive(bool * _retval)102   NS_IMETHOD IsAlive(bool* _retval) override {
103     MOZ_ASSERT(false);
104     return NS_OK;
105   }
GetTimeout(uint32_t aType,uint32_t * _retval)106   NS_IMETHOD GetTimeout(uint32_t aType, uint32_t* _retval) override {
107     MOZ_ASSERT(false);
108     return NS_OK;
109   }
SetTimeout(uint32_t aType,uint32_t aValue)110   NS_IMETHOD SetTimeout(uint32_t aType, uint32_t aValue) override {
111     MOZ_ASSERT(false);
112     return NS_OK;
113   }
SetLinger(bool aPolarity,int16_t aTimeout)114   NS_IMETHOD SetLinger(bool aPolarity, int16_t aTimeout) override {
115     MOZ_ASSERT(false);
116     return NS_OK;
117   }
SetReuseAddrPort(bool reuseAddrPort)118   NS_IMETHOD SetReuseAddrPort(bool reuseAddrPort) override {
119     MOZ_ASSERT(false);
120     return NS_OK;
121   }
GetConnectionFlags(uint32_t * aConnectionFlags)122   NS_IMETHOD GetConnectionFlags(uint32_t* aConnectionFlags) override {
123     MOZ_ASSERT(false);
124     return NS_OK;
125   }
SetConnectionFlags(uint32_t aConnectionFlags)126   NS_IMETHOD SetConnectionFlags(uint32_t aConnectionFlags) override {
127     MOZ_ASSERT(false);
128     return NS_OK;
129   }
SetIsPrivate(bool)130   NS_IMETHOD SetIsPrivate(bool) override {
131     MOZ_ASSERT(false);
132     return NS_OK;
133   }
GetTlsFlags(uint32_t * aTlsFlags)134   NS_IMETHOD GetTlsFlags(uint32_t* aTlsFlags) override {
135     MOZ_ASSERT(false);
136     return NS_OK;
137   }
SetTlsFlags(uint32_t aTlsFlags)138   NS_IMETHOD SetTlsFlags(uint32_t aTlsFlags) override {
139     MOZ_ASSERT(false);
140     return NS_OK;
141   }
GetQoSBits(uint8_t * aQoSBits)142   NS_IMETHOD GetQoSBits(uint8_t* aQoSBits) override {
143     MOZ_ASSERT(false);
144     return NS_OK;
145   }
SetQoSBits(uint8_t aQoSBits)146   NS_IMETHOD SetQoSBits(uint8_t aQoSBits) override {
147     MOZ_ASSERT(false);
148     return NS_OK;
149   }
GetRecvBufferSize(uint32_t * aRecvBufferSize)150   NS_IMETHOD GetRecvBufferSize(uint32_t* aRecvBufferSize) override {
151     MOZ_ASSERT(false);
152     return NS_OK;
153   }
GetSendBufferSize(uint32_t * aSendBufferSize)154   NS_IMETHOD GetSendBufferSize(uint32_t* aSendBufferSize) override {
155     MOZ_ASSERT(false);
156     return NS_OK;
157   }
GetKeepaliveEnabled(bool * aKeepaliveEnabled)158   NS_IMETHOD GetKeepaliveEnabled(bool* aKeepaliveEnabled) override {
159     MOZ_ASSERT(false);
160     return NS_OK;
161   }
SetKeepaliveEnabled(bool aKeepaliveEnabled)162   NS_IMETHOD SetKeepaliveEnabled(bool aKeepaliveEnabled) override {
163     MOZ_ASSERT(false);
164     return NS_OK;
165   }
SetKeepaliveVals(int32_t keepaliveIdleTime,int32_t keepaliveRetryInterval)166   NS_IMETHOD SetKeepaliveVals(int32_t keepaliveIdleTime,
167                               int32_t keepaliveRetryInterval) override {
168     MOZ_ASSERT(false);
169     return NS_OK;
170   }
GetResetIPFamilyPreference(bool * aResetIPFamilyPreference)171   NS_IMETHOD GetResetIPFamilyPreference(
172       bool* aResetIPFamilyPreference) override {
173     MOZ_ASSERT(false);
174     return NS_OK;
175   }
GetEchConfigUsed(bool * aEchConfigUsed)176   NS_IMETHOD GetEchConfigUsed(bool* aEchConfigUsed) override {
177     MOZ_ASSERT(false);
178     return NS_OK;
179   }
SetEchConfig(const nsACString & aEchConfig)180   NS_IMETHOD SetEchConfig(const nsACString& aEchConfig) override {
181     MOZ_ASSERT(false);
182     return NS_OK;
183   }
ResolvedByTRR(bool * _retval)184   NS_IMETHOD ResolvedByTRR(bool* _retval) override {
185     MOZ_ASSERT(false);
186     return NS_OK;
187   }
GetRetryDnsIfPossible(bool * aRetryDns)188   NS_IMETHOD GetRetryDnsIfPossible(bool* aRetryDns) override {
189     MOZ_ASSERT(false);
190     return NS_OK;
191   }
GetStatus(nsresult * aStatus)192   NS_IMETHOD GetStatus(nsresult* aStatus) override {
193     MOZ_ASSERT(false);
194     return NS_OK;
195   }
196 
197   // nsITransport
OpenInputStream(uint32_t aFlags,uint32_t aSegmentSize,uint32_t aSegmentCount,nsIInputStream ** _retval)198   NS_IMETHOD OpenInputStream(uint32_t aFlags, uint32_t aSegmentSize,
199                              uint32_t aSegmentCount,
200                              nsIInputStream** _retval) override {
201     MOZ_ASSERT(false);
202     return NS_OK;
203   }
OpenOutputStream(uint32_t aFlags,uint32_t aSegmentSize,uint32_t aSegmentCount,nsIOutputStream ** _retval)204   NS_IMETHOD OpenOutputStream(uint32_t aFlags, uint32_t aSegmentSize,
205                               uint32_t aSegmentCount,
206                               nsIOutputStream** _retval) override {
207     MOZ_ASSERT(false);
208     return NS_OK;
209   }
SetEventSink(nsITransportEventSink * aSink,nsIEventTarget * aEventTarget)210   NS_IMETHOD SetEventSink(nsITransportEventSink* aSink,
211                           nsIEventTarget* aEventTarget) override {
212     MOZ_ASSERT(false);
213     return NS_OK;
214   }
215 
216   // fake except for these methods which are OK to call
217   // nsISocketTransport
SetRecvBufferSize(uint32_t aRecvBufferSize)218   NS_IMETHOD SetRecvBufferSize(uint32_t aRecvBufferSize) override {
219     return NS_OK;
220   }
SetSendBufferSize(uint32_t aSendBufferSize)221   NS_IMETHOD SetSendBufferSize(uint32_t aSendBufferSize) override {
222     return NS_OK;
223   }
224   // nsITransport
Close(nsresult aReason)225   NS_IMETHOD Close(nsresult aReason) override { return NS_OK; }
226 
227  protected:
228   virtual ~FakeSocketTransportProvider() = default;
229 };
230 
231 NS_IMPL_ISUPPORTS(FakeSocketTransportProvider, nsISocketTransport, nsITransport)
232 
233 // Implements some common elements to WebrtcTCPSocketTestOutputStream and
234 // WebrtcTCPSocketTestInputStream.
235 class WebrtcTCPSocketTestStream {
236  public:
237   WebrtcTCPSocketTestStream();
238 
Fail()239   void Fail() { mMustFail = true; }
240 
241   size_t DataLength();
242   template <typename T>
243   void AppendElements(const T* aBuffer, size_t aLength);
244 
245  protected:
246   virtual ~WebrtcTCPSocketTestStream() = default;
247 
248   nsTArray<uint8_t> mData;
249   std::mutex mDataMutex;
250 
251   bool mMustFail;
252 };
253 
WebrtcTCPSocketTestStream()254 WebrtcTCPSocketTestStream::WebrtcTCPSocketTestStream() : mMustFail(false) {}
255 
256 template <typename T>
AppendElements(const T * aBuffer,size_t aLength)257 void WebrtcTCPSocketTestStream::AppendElements(const T* aBuffer,
258                                                size_t aLength) {
259   std::lock_guard<std::mutex> guard(mDataMutex);
260   mData.AppendElements(aBuffer, aLength);
261 }
262 
DataLength()263 size_t WebrtcTCPSocketTestStream::DataLength() {
264   std::lock_guard<std::mutex> guard(mDataMutex);
265   return mData.Length();
266 }
267 
268 class WebrtcTCPSocketTestInputStream : public nsIAsyncInputStream,
269                                        public WebrtcTCPSocketTestStream {
270  public:
271   NS_DECL_THREADSAFE_ISUPPORTS
272   NS_DECL_NSIASYNCINPUTSTREAM
273   NS_DECL_NSIINPUTSTREAM
274 
WebrtcTCPSocketTestInputStream()275   WebrtcTCPSocketTestInputStream()
276       : mMaxReadSize(1024 * 1024), mAllowCallbacks(false) {}
277 
278   void DoCallback();
279   void CallCallback(const nsCOMPtr<nsIInputStreamCallback>& aCallback);
AllowCallbacks()280   void AllowCallbacks() { mAllowCallbacks = true; }
281 
282   size_t mMaxReadSize;
283 
284  protected:
285   virtual ~WebrtcTCPSocketTestInputStream() = default;
286 
287  private:
288   nsCOMPtr<nsIInputStreamCallback> mCallback;
289   nsCOMPtr<nsIEventTarget> mCallbackTarget;
290 
291   bool mAllowCallbacks;
292 };
293 
NS_IMPL_ISUPPORTS(WebrtcTCPSocketTestInputStream,nsIAsyncInputStream,nsIInputStream)294 NS_IMPL_ISUPPORTS(WebrtcTCPSocketTestInputStream, nsIAsyncInputStream,
295                   nsIInputStream)
296 
297 nsresult WebrtcTCPSocketTestInputStream::AsyncWait(
298     nsIInputStreamCallback* aCallback, uint32_t aFlags,
299     uint32_t aRequestedCount, nsIEventTarget* aEventTarget) {
300   MOZ_ASSERT(!aEventTarget, "no event target should be set");
301 
302   mCallback = aCallback;
303   mCallbackTarget = NS_GetCurrentThread();
304 
305   if (mAllowCallbacks && DataLength() > 0) {
306     DoCallback();
307   }
308 
309   return NS_OK;
310 }
311 
CloseWithStatus(nsresult aStatus)312 nsresult WebrtcTCPSocketTestInputStream::CloseWithStatus(nsresult aStatus) {
313   return Close();
314 }
315 
Close()316 nsresult WebrtcTCPSocketTestInputStream::Close() { return NS_OK; }
317 
Available(uint64_t * aAvailable)318 nsresult WebrtcTCPSocketTestInputStream::Available(uint64_t* aAvailable) {
319   *aAvailable = DataLength();
320   return NS_OK;
321 }
322 
Read(char * aBuffer,uint32_t aCount,uint32_t * aRead)323 nsresult WebrtcTCPSocketTestInputStream::Read(char* aBuffer, uint32_t aCount,
324                                               uint32_t* aRead) {
325   std::lock_guard<std::mutex> guard(mDataMutex);
326   if (mMustFail) {
327     return NS_ERROR_FAILURE;
328   }
329   *aRead = std::min({(size_t)aCount, mData.Length(), mMaxReadSize});
330   memcpy(aBuffer, mData.Elements(), *aRead);
331   mData.RemoveElementsAt(0, *aRead);
332   return *aRead > 0 ? NS_OK : NS_BASE_STREAM_WOULD_BLOCK;
333 }
334 
ReadSegments(nsWriteSegmentFun aWriter,void * aClosure,uint32_t aCount,uint32_t * _retval)335 nsresult WebrtcTCPSocketTestInputStream::ReadSegments(nsWriteSegmentFun aWriter,
336                                                       void* aClosure,
337                                                       uint32_t aCount,
338                                                       uint32_t* _retval) {
339   MOZ_ASSERT(false);
340   return NS_OK;
341 }
342 
IsNonBlocking(bool * aIsNonBlocking)343 nsresult WebrtcTCPSocketTestInputStream::IsNonBlocking(bool* aIsNonBlocking) {
344   *aIsNonBlocking = true;
345   return NS_OK;
346 }
347 
CallCallback(const nsCOMPtr<nsIInputStreamCallback> & aCallback)348 void WebrtcTCPSocketTestInputStream::CallCallback(
349     const nsCOMPtr<nsIInputStreamCallback>& aCallback) {
350   aCallback->OnInputStreamReady(this);
351 }
352 
DoCallback()353 void WebrtcTCPSocketTestInputStream::DoCallback() {
354   if (mCallback) {
355     mCallbackTarget->Dispatch(
356         NewRunnableMethod<const nsCOMPtr<nsIInputStreamCallback>&>(
357             "WebrtcTCPSocketTestInputStream::DoCallback", this,
358             &WebrtcTCPSocketTestInputStream::CallCallback,
359             std::move(mCallback)));
360 
361     mCallbackTarget = nullptr;
362   }
363 }
364 
365 class WebrtcTCPSocketTestOutputStream : public nsIAsyncOutputStream,
366                                         public WebrtcTCPSocketTestStream {
367  public:
368   NS_DECL_THREADSAFE_ISUPPORTS
369   NS_DECL_NSIASYNCOUTPUTSTREAM
370   NS_DECL_NSIOUTPUTSTREAM
371 
WebrtcTCPSocketTestOutputStream()372   WebrtcTCPSocketTestOutputStream() : mMaxWriteSize(1024 * 1024) {}
373 
374   void DoCallback();
375   void CallCallback(const nsCOMPtr<nsIOutputStreamCallback>& aCallback);
376 
377   std::string DataString();
378 
379   uint32_t mMaxWriteSize;
380 
381  protected:
382   virtual ~WebrtcTCPSocketTestOutputStream() = default;
383 
384  private:
385   nsCOMPtr<nsIOutputStreamCallback> mCallback;
386   nsCOMPtr<nsIEventTarget> mCallbackTarget;
387 };
388 
NS_IMPL_ISUPPORTS(WebrtcTCPSocketTestOutputStream,nsIAsyncOutputStream,nsIOutputStream)389 NS_IMPL_ISUPPORTS(WebrtcTCPSocketTestOutputStream, nsIAsyncOutputStream,
390                   nsIOutputStream)
391 
392 nsresult WebrtcTCPSocketTestOutputStream::AsyncWait(
393     nsIOutputStreamCallback* aCallback, uint32_t aFlags,
394     uint32_t aRequestedCount, nsIEventTarget* aEventTarget) {
395   MOZ_ASSERT(!aEventTarget, "no event target should be set");
396 
397   mCallback = aCallback;
398   mCallbackTarget = NS_GetCurrentThread();
399 
400   return NS_OK;
401 }
402 
CloseWithStatus(nsresult aStatus)403 nsresult WebrtcTCPSocketTestOutputStream::CloseWithStatus(nsresult aStatus) {
404   return Close();
405 }
406 
Close()407 nsresult WebrtcTCPSocketTestOutputStream::Close() { return NS_OK; }
408 
Flush()409 nsresult WebrtcTCPSocketTestOutputStream::Flush() { return NS_OK; }
410 
Write(const char * aBuffer,uint32_t aCount,uint32_t * aWrote)411 nsresult WebrtcTCPSocketTestOutputStream::Write(const char* aBuffer,
412                                                 uint32_t aCount,
413                                                 uint32_t* aWrote) {
414   if (mMustFail) {
415     return NS_ERROR_FAILURE;
416   }
417   *aWrote = std::min(aCount, mMaxWriteSize);
418   AppendElements(aBuffer, *aWrote);
419   return NS_OK;
420 }
421 
WriteSegments(nsReadSegmentFun aReader,void * aClosure,uint32_t aCount,uint32_t * _retval)422 nsresult WebrtcTCPSocketTestOutputStream::WriteSegments(
423     nsReadSegmentFun aReader, void* aClosure, uint32_t aCount,
424     uint32_t* _retval) {
425   MOZ_ASSERT(false);
426   return NS_OK;
427 }
428 
WriteFrom(nsIInputStream * aFromStream,uint32_t aCount,uint32_t * _retval)429 nsresult WebrtcTCPSocketTestOutputStream::WriteFrom(nsIInputStream* aFromStream,
430                                                     uint32_t aCount,
431                                                     uint32_t* _retval) {
432   MOZ_ASSERT(false);
433   return NS_OK;
434 }
435 
IsNonBlocking(bool * aIsNonBlocking)436 nsresult WebrtcTCPSocketTestOutputStream::IsNonBlocking(bool* aIsNonBlocking) {
437   *aIsNonBlocking = true;
438   return NS_OK;
439 }
440 
CallCallback(const nsCOMPtr<nsIOutputStreamCallback> & aCallback)441 void WebrtcTCPSocketTestOutputStream::CallCallback(
442     const nsCOMPtr<nsIOutputStreamCallback>& aCallback) {
443   aCallback->OnOutputStreamReady(this);
444 }
445 
DoCallback()446 void WebrtcTCPSocketTestOutputStream::DoCallback() {
447   if (mCallback) {
448     mCallbackTarget->Dispatch(
449         NewRunnableMethod<const nsCOMPtr<nsIOutputStreamCallback>&>(
450             "WebrtcTCPSocketTestOutputStream::CallCallback", this,
451             &WebrtcTCPSocketTestOutputStream::CallCallback,
452             std::move(mCallback)));
453 
454     mCallbackTarget = nullptr;
455   }
456 }
457 
DataString()458 std::string WebrtcTCPSocketTestOutputStream::DataString() {
459   std::lock_guard<std::mutex> guard(mDataMutex);
460   return std::string((char*)mData.Elements(), mData.Length());
461 }
462 
463 // Fake as in not the real WebrtcTCPSocket but real enough
464 class FakeWebrtcTCPSocket : public WebrtcTCPSocket {
465  public:
FakeWebrtcTCPSocket(WebrtcTCPSocketCallback * aCallback)466   explicit FakeWebrtcTCPSocket(WebrtcTCPSocketCallback* aCallback)
467       : WebrtcTCPSocket(aCallback) {}
468 
469  protected:
470   virtual ~FakeWebrtcTCPSocket() = default;
471 
472   void InvokeOnClose(nsresult aReason) override;
473   void InvokeOnConnected() override;
474   void InvokeOnRead(nsTArray<uint8_t>&& aReadData) override;
475 };
476 
InvokeOnClose(nsresult aReason)477 void FakeWebrtcTCPSocket::InvokeOnClose(nsresult aReason) {
478   mProxyCallbacks->OnClose(aReason);
479 }
480 
InvokeOnConnected()481 void FakeWebrtcTCPSocket::InvokeOnConnected() {
482   mProxyCallbacks->OnConnected("http"_ns);
483 }
484 
InvokeOnRead(nsTArray<uint8_t> && aReadData)485 void FakeWebrtcTCPSocket::InvokeOnRead(nsTArray<uint8_t>&& aReadData) {
486   mProxyCallbacks->OnRead(std::move(aReadData));
487 }
488 
489 class WebrtcTCPSocketTest : public MtransportTest {
490  public:
WebrtcTCPSocketTest()491   WebrtcTCPSocketTest()
492       : MtransportTest(),
493         mSocketThread(nullptr),
494         mSocketTransport(nullptr),
495         mInputStream(nullptr),
496         mOutputStream(nullptr),
497         mChannel(nullptr),
498         mCallback(nullptr),
499         mOnCloseCalled(false),
500         mOnConnectedCalled(false) {}
501 
502   // WebrtcTCPSocketCallback forwards from mCallback
503   void OnClose(nsresult aReason);
504   void OnConnected(const nsCString& aProxyType);
505   void OnRead(nsTArray<uint8_t>&& aReadData);
506 
507   void SetUp() override;
508   void TearDown() override;
509 
510   void DoTransportAvailable();
511 
512   std::string ReadDataAsString();
513   std::string GetDataLarge();
514 
515   nsCOMPtr<nsIEventTarget> mSocketThread;
516 
517   nsCOMPtr<nsISocketTransport> mSocketTransport;
518   RefPtr<WebrtcTCPSocketTestInputStream> mInputStream;
519   RefPtr<WebrtcTCPSocketTestOutputStream> mOutputStream;
520   RefPtr<FakeWebrtcTCPSocket> mChannel;
521   RefPtr<WebrtcTCPSocketTestCallback> mCallback;
522 
523   bool mOnCloseCalled;
524   bool mOnConnectedCalled;
525 
526   size_t ReadDataLength();
527   template <typename T>
528   void AppendReadData(const T* aBuffer, size_t aLength);
529 
530  private:
531   nsTArray<uint8_t> mReadData;
532   std::mutex mReadDataMutex;
533 };
534 
535 class WebrtcTCPSocketTestCallback : public WebrtcTCPSocketCallback {
536  public:
NS_INLINE_DECL_THREADSAFE_REFCOUNTING(WebrtcTCPSocketTestCallback,override)537   NS_INLINE_DECL_THREADSAFE_REFCOUNTING(WebrtcTCPSocketTestCallback, override)
538 
539   explicit WebrtcTCPSocketTestCallback(WebrtcTCPSocketTest* aTest)
540       : mTest(aTest) {}
541 
542   // WebrtcTCPSocketCallback
543   void OnClose(nsresult aReason) override;
544   void OnConnected(const nsCString& aProxyType) override;
545   void OnRead(nsTArray<uint8_t>&& aReadData) override;
546 
547  protected:
548   virtual ~WebrtcTCPSocketTestCallback() = default;
549 
550  private:
551   WebrtcTCPSocketTest* mTest;
552 };
553 
SetUp()554 void WebrtcTCPSocketTest::SetUp() {
555   nsresult rv;
556   // WebrtcTCPSocket's threading model is the same as mtransport
557   // all socket operations are done on the socket thread
558   // callbacks are invoked on the main thread
559   mSocketThread = do_GetService(NS_SOCKETTRANSPORTSERVICE_CONTRACTID, &rv);
560   ASSERT_TRUE(NS_SUCCEEDED(rv));
561 
562   mSocketTransport = new FakeSocketTransportProvider();
563   mInputStream = new WebrtcTCPSocketTestInputStream();
564   mOutputStream = new WebrtcTCPSocketTestOutputStream();
565   mCallback = new WebrtcTCPSocketTestCallback(this);
566   mChannel = new FakeWebrtcTCPSocket(mCallback.get());
567 }
568 
TearDown()569 void WebrtcTCPSocketTest::TearDown() {}
570 
571 // WebrtcTCPSocketCallback
OnRead(nsTArray<uint8_t> && aReadData)572 void WebrtcTCPSocketTest::OnRead(nsTArray<uint8_t>&& aReadData) {
573   AppendReadData(aReadData.Elements(), aReadData.Length());
574 }
575 
OnConnected(const nsCString & aProxyType)576 void WebrtcTCPSocketTest::OnConnected(const nsCString& aProxyType) {
577   mOnConnectedCalled = true;
578 }
579 
OnClose(nsresult aReason)580 void WebrtcTCPSocketTest::OnClose(nsresult aReason) { mOnCloseCalled = true; }
581 
DoTransportAvailable()582 void WebrtcTCPSocketTest::DoTransportAvailable() {
583   if (!mSocketThread->IsOnCurrentThread()) {
584     mSocketThread->Dispatch(
585         NS_NewRunnableFunction("DoTransportAvailable", [this]() -> void {
586           nsresult rv;
587           rv = mChannel->OnTransportAvailable(mSocketTransport, mInputStream,
588                                               mOutputStream);
589           ASSERT_EQ(NS_OK, rv);
590         }));
591   } else {
592     // should always be called on the main thread
593     MOZ_ASSERT(0);
594   }
595 }
596 
ReadDataAsString()597 std::string WebrtcTCPSocketTest::ReadDataAsString() {
598   std::lock_guard<std::mutex> guard(mReadDataMutex);
599   return std::string((char*)mReadData.Elements(), mReadData.Length());
600 }
601 
GetDataLarge()602 std::string WebrtcTCPSocketTest::GetDataLarge() {
603   std::string data;
604   for (int i = 0; i < kDataLargeOuterLoopCount * kDataLargeInnerLoopCount;
605        ++i) {
606     data += kReadData;
607   }
608   return data;
609 }
610 
611 template <typename T>
AppendReadData(const T * aBuffer,size_t aLength)612 void WebrtcTCPSocketTest::AppendReadData(const T* aBuffer, size_t aLength) {
613   std::lock_guard<std::mutex> guard(mReadDataMutex);
614   mReadData.AppendElements(aBuffer, aLength);
615 }
616 
ReadDataLength()617 size_t WebrtcTCPSocketTest::ReadDataLength() {
618   std::lock_guard<std::mutex> guard(mReadDataMutex);
619   return mReadData.Length();
620 }
621 
OnClose(nsresult aReason)622 void WebrtcTCPSocketTestCallback::OnClose(nsresult aReason) {
623   mTest->OnClose(aReason);
624 }
625 
OnConnected(const nsCString & aProxyType)626 void WebrtcTCPSocketTestCallback::OnConnected(const nsCString& aProxyType) {
627   mTest->OnConnected(aProxyType);
628 }
629 
OnRead(nsTArray<uint8_t> && aReadData)630 void WebrtcTCPSocketTestCallback::OnRead(nsTArray<uint8_t>&& aReadData) {
631   mTest->OnRead(std::move(aReadData));
632 }
633 
634 }  // namespace mozilla
635 
636 typedef mozilla::WebrtcTCPSocketTest WebrtcTCPSocketTest;
637 
// Constructing and tearing down the fixture must succeed on its own.
TEST_F(WebrtcTCPSocketTest, SetUp) {
  // Intentionally empty: SetUp()/TearDown() are what is being tested.
}
639 
// Handing the fake transport to the channel reports a connection.
TEST_F(WebrtcTCPSocketTest, TransportAvailable) {
  DoTransportAvailable();

  // OnConnected arrives asynchronously, so poll for it.
  ASSERT_TRUE_WAIT(mOnConnectedCalled, kDefaultTestTimeout);
}
644 
// Bytes that appear on the input stream are delivered through OnRead.
TEST_F(WebrtcTCPSocketTest, Read) {
  DoTransportAvailable();
  ASSERT_TRUE_WAIT(mOnConnectedCalled, kDefaultTestTimeout);

  // Make data available, then signal readiness to the channel.
  auto* const input = mInputStream.get();
  input->AppendElements(kReadData, kReadDataLength);
  input->DoCallback();

  ASSERT_TRUE_WAIT(ReadDataAsString() == kReadDataString, kDefaultTestTimeout);
}
654 
// Bytes passed to Write() stay queued until the output stream signals
// readiness, then end up on the stream.
TEST_F(WebrtcTCPSocketTest, Write) {
  DoTransportAvailable();
  ASSERT_TRUE_WAIT(mOnConnectedCalled, kDefaultTestTimeout);

  nsTArray<uint8_t> payload;
  payload.AppendElements(kReadData, kReadDataLength);
  mChannel->Write(std::move(payload));

  ASSERT_TRUE_WAIT(mChannel->CountUnwrittenBytes() == kReadDataLength,
                   kDefaultTestTimeout);

  // Now let the channel drain its queue into the stream.
  mOutputStream->DoCallback();

  ASSERT_TRUE_WAIT(mOutputStream->DataString() == kReadDataString,
                   kDefaultTestTimeout);
}
671 
// A failing Read() closes the channel without delivering any data.
TEST_F(WebrtcTCPSocketTest, ReadFail) {
  DoTransportAvailable();
  ASSERT_TRUE_WAIT(mOnConnectedCalled, kDefaultTestTimeout);

  auto* const input = mInputStream.get();
  input->AppendElements(kReadData, kReadDataLength);
  // Data is pending, but every Read() now returns NS_ERROR_FAILURE.
  input->Fail();
  input->DoCallback();

  ASSERT_TRUE_WAIT(mOnCloseCalled, kDefaultTestTimeout);
  ASSERT_EQ(0U, ReadDataLength());
}
683 
// A failing Write() on the output stream closes the channel, and nothing
// reaches the stream.
TEST_F(WebrtcTCPSocketTest, WriteFail) {
  DoTransportAvailable();
  ASSERT_TRUE_WAIT(mOnConnectedCalled, kDefaultTestTimeout);

  nsTArray<uint8_t> payload;
  payload.AppendElements(kReadData, kReadDataLength);
  mChannel->Write(std::move(payload));

  ASSERT_TRUE_WAIT(mChannel->CountUnwrittenBytes() == kReadDataLength,
                   kDefaultTestTimeout);

  // Every Write() on the stream now returns NS_ERROR_FAILURE.
  auto* const output = mOutputStream.get();
  output->Fail();
  output->DoCallback();

  ASSERT_TRUE_WAIT(mOnCloseCalled, kDefaultTestTimeout);
  ASSERT_EQ(0U, output->DataLength());
}
701 
// A payload far larger than a single Read() returns is reassembled in full.
TEST_F(WebrtcTCPSocketTest, ReadLarge) {
  DoTransportAvailable();
  ASSERT_TRUE_WAIT(mOnConnectedCalled, kDefaultTestTimeout);

  const std::string expected = GetDataLarge();

  auto* const input = mInputStream.get();
  input->AppendElements(expected.c_str(), expected.length());
  // make sure reading loops more than once
  input->mMaxReadSize = 3072;
  // Let AsyncWait() schedule the callback again while data remains.
  input->AllowCallbacks();
  input->DoCallback();

  ASSERT_TRUE_WAIT(ReadDataAsString() == expected, kDefaultTestTimeout);
}
716 
// A large payload queued as many separate Write() calls reaches the output
// stream intact, even though the stream accepts only 1024 bytes per Write().
TEST_F(WebrtcTCPSocketTest, WriteLarge) {
  DoTransportAvailable();
  ASSERT_TRUE_WAIT(mOnConnectedCalled, kDefaultTestTimeout);

  const std::string data = GetDataLarge();

  // Queue the payload as kDataLargeOuterLoopCount equal chunks. Sizes and
  // offsets are size_t so nothing is narrowed through int.
  const size_t chunkSize =
      kReadDataLength * static_cast<size_t>(kDataLargeInnerLoopCount);
  for (int i = 0; i < kDataLargeOuterLoopCount; ++i) {
    const size_t offset = static_cast<size_t>(i) * chunkSize;
    nsTArray<uint8_t> chunk;
    chunk.AppendElements(data.data() + offset, chunkSize);
    mChannel->Write(std::move(chunk));
  }

  ASSERT_TRUE_WAIT(mChannel->CountUnwrittenBytes() == data.length(),
                   kDefaultTestTimeout);

  // make sure writing loops more than once per write request
  mOutputStream->mMaxWriteSize = 1024;
  mOutputStream->DoCallback();

  ASSERT_TRUE_WAIT(mOutputStream->DataString() == data, kDefaultTestTimeout);
}
740