1 /* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
2 /* vim: set ts=8 sts=2 et sw=2 tw=80: */
3 /* This Source Code Form is subject to the terms of the Mozilla Public
4  * License, v. 2.0. If a copy of the MPL was not distributed with this file,
5  * You can obtain one at http://mozilla.org/MPL/2.0/. */
6 
7 #include "gtest/gtest.h"
8 
9 #include <algorithm>
10 #include <cstdint>
11 #include <cstdlib>
12 #include <new>
13 #include <numeric>
14 #include <ostream>
15 #include <string>
16 #include <type_traits>
17 #include <utility>
18 #include <vector>
19 #include "ErrorList.h"
20 #include "mozilla/AlreadyAddRefed.h"
21 #include "mozilla/Assertions.h"
22 #include "mozilla/Attributes.h"
23 #include "mozilla/NotNull.h"
24 #include "mozilla/RefPtr.h"
25 #include "mozilla/Scoped.h"
26 #include "mozilla/Span.h"
27 #include "mozilla/UniquePtr.h"
28 #include "mozilla/dom/SafeRefPtr.h"
29 #include "mozilla/dom/quota/DecryptingInputStream_impl.h"
30 #include "mozilla/dom/quota/DummyCipherStrategy.h"
31 #include "mozilla/dom/quota/EncryptedBlock.h"
32 #include "mozilla/dom/quota/EncryptingOutputStream_impl.h"
33 #include "mozilla/dom/quota/MemoryOutputStream.h"
34 #include "mozilla/dom/quota/NSSCipherStrategy.h"
35 #include "mozilla/fallible.h"
36 #include "nsCOMPtr.h"
37 #include "nsError.h"
38 #include "nsICloneableInputStream.h"
39 #include "nsIInputStream.h"
40 #include "nsIOutputStream.h"
41 #include "nsISeekableStream.h"
42 #include "nsISupports.h"
43 #include "nsITellableStream.h"
44 #include "nsStreamUtils.h"
45 #include "nsString.h"
46 #include "nsStringFwd.h"
47 #include "nsTArray.h"
48 #include "nscore.h"
49 #include "nss.h"
50 
51 namespace mozilla::dom::quota {
52 
// Similar to ArrayBufferInputStream from netwerk/base/ArrayBufferInputStream.h,
// but this is initialized from a Span on construction, rather than lazily from
// a JS ArrayBuffer.
//
// Test helper: an input stream over a private copy of the bytes handed to the
// constructor. It additionally supports Tell/Seek and Clone so that it can
// serve as the base stream of a DecryptingInputStream in the tests below.
class ArrayBufferInputStream : public nsIInputStream,
                               public nsISeekableStream,
                               public nsICloneableInputStream {
 public:
  // Copies aData into an internally owned buffer. The new stream starts at
  // position 0 and is not closed.
  explicit ArrayBufferInputStream(mozilla::Span<const uint8_t> aData);

  NS_DECL_THREADSAFE_ISUPPORTS
  NS_DECL_NSIINPUTSTREAM
  NS_DECL_NSITELLABLESTREAM
  NS_DECL_NSISEEKABLESTREAM
  NS_DECL_NSICLONEABLEINPUTSTREAM

 private:
  // Private destructor; presumably so that instances are only destroyed via
  // reference counting (Release) -- confirm against NS_IMPL_RELEASE.
  virtual ~ArrayBufferInputStream() = default;

  // Owned copy of the stream contents.
  mozilla::UniquePtr<char[]> mArrayBuffer;
  // Number of bytes stored in mArrayBuffer.
  uint32_t mBufferLength;
  // Current read offset into mArrayBuffer; advanced by ReadSegments and
  // repositioned by Seek.
  uint32_t mPos;
  // Set by Close(); Available and ReadSegments then return
  // NS_BASE_STREAM_CLOSED.
  bool mClosed;
};
76 
77 NS_IMPL_ADDREF(ArrayBufferInputStream);
78 NS_IMPL_RELEASE(ArrayBufferInputStream);
79 
80 NS_INTERFACE_MAP_BEGIN(ArrayBufferInputStream)
NS_INTERFACE_MAP_ENTRY(nsIInputStream)81   NS_INTERFACE_MAP_ENTRY(nsIInputStream)
82   NS_INTERFACE_MAP_ENTRY(nsISeekableStream)
83   NS_INTERFACE_MAP_ENTRY(nsICloneableInputStream)
84   NS_INTERFACE_MAP_ENTRY_AMBIGUOUS(nsISupports, nsIInputStream)
85 NS_INTERFACE_MAP_END
86 
// Allocates a buffer of aData.Length() bytes and copies aData into it, so the
// stream does not depend on the lifetime of the memory aData refers to.
// The stream starts at position 0 and open.
ArrayBufferInputStream::ArrayBufferInputStream(
    mozilla::Span<const uint8_t> aData)
    : mArrayBuffer(MakeUnique<char[]>(aData.Length())),
      mBufferLength(aData.Length()),
      mPos(0),
      mClosed(false) {
  // Byte-wise copy from the uint8_t span into the owned char buffer.
  std::copy(aData.cbegin(), aData.cend(), mArrayBuffer.get());
}
95 
// Marks the stream as closed. The buffer itself is kept; subsequent calls to
// Available and ReadSegments return NS_BASE_STREAM_CLOSED. Always succeeds.
NS_IMETHODIMP
ArrayBufferInputStream::Close() {
  mClosed = true;
  return NS_OK;
}
101 
102 NS_IMETHODIMP
Available(uint64_t * aCount)103 ArrayBufferInputStream::Available(uint64_t* aCount) {
104   if (mClosed) {
105     return NS_BASE_STREAM_CLOSED;
106   }
107 
108   if (mArrayBuffer) {
109     *aCount = mBufferLength ? mBufferLength - mPos : 0;
110   } else {
111     *aCount = 0;
112   }
113 
114   return NS_OK;
115 }
116 
// Reads up to aCount bytes into aBuf and stores the number of bytes copied in
// *aReadCount. Implemented on top of ReadSegments, using
// NS_CopySegmentToBuffer as the segment writer with aBuf as its closure.
NS_IMETHODIMP
ArrayBufferInputStream::Read(char* aBuf, uint32_t aCount,
                             uint32_t* aReadCount) {
  return ReadSegments(NS_CopySegmentToBuffer, aBuf, aCount, aReadCount);
}
122 
123 NS_IMETHODIMP
ReadSegments(nsWriteSegmentFun writer,void * closure,uint32_t aCount,uint32_t * result)124 ArrayBufferInputStream::ReadSegments(nsWriteSegmentFun writer, void* closure,
125                                      uint32_t aCount, uint32_t* result) {
126   MOZ_ASSERT(result, "null ptr");
127   MOZ_ASSERT(mBufferLength >= mPos, "bad stream state");
128 
129   if (mClosed) {
130     return NS_BASE_STREAM_CLOSED;
131   }
132 
133   MOZ_ASSERT(mArrayBuffer || (mPos == mBufferLength),
134              "stream inited incorrectly");
135 
136   *result = 0;
137   while (mPos < mBufferLength) {
138     uint32_t remaining = mBufferLength - mPos;
139     MOZ_ASSERT(mArrayBuffer);
140 
141     uint32_t count = std::min(aCount, remaining);
142     if (count == 0) {
143       break;
144     }
145 
146     uint32_t written;
147     nsresult rv = writer(this, closure, &mArrayBuffer[0] + mPos, *result, count,
148                          &written);
149     if (NS_FAILED(rv)) {
150       // InputStreams do not propagate errors to caller.
151       return NS_OK;
152     }
153 
154     MOZ_ASSERT(written <= count,
155                "writer should not write more than we asked it to write");
156     mPos += written;
157     *result += written;
158     aCount -= written;
159   }
160 
161   return NS_OK;
162 }
163 
// Always reports the stream as blocking (*aNonBlocking = false) and returns
// NS_OK.
NS_IMETHODIMP
ArrayBufferInputStream::IsNonBlocking(bool* aNonBlocking) {
  // Actually, the stream never blocks, but we lie about it because of the
  // assumptions in DecryptingInputStream.
  *aNonBlocking = false;
  return NS_OK;
}
171 
// Stores the current read position (byte offset from the start of the
// buffer) in *aRetval. Always returns NS_OK; note that, unlike Available and
// ReadSegments, this does not check whether the stream has been closed.
NS_IMETHODIMP ArrayBufferInputStream::Tell(int64_t* const aRetval) {
  MOZ_ASSERT(aRetval);

  *aRetval = mPos;

  return NS_OK;
}
179 
Seek(const int32_t aWhence,const int64_t aOffset)180 NS_IMETHODIMP ArrayBufferInputStream::Seek(const int32_t aWhence,
181                                            const int64_t aOffset) {
182   // XXX This is not safe. it's hard to use CheckedInt here, though. As long as
183   // the class is only used for testing purposes, that's probably fine.
184 
185   int32_t newPos = mPos;
186   switch (aWhence) {
187     case NS_SEEK_SET:
188       newPos = aOffset;
189       break;
190     case NS_SEEK_CUR:
191       newPos += aOffset;
192       break;
193     case NS_SEEK_END:
194       newPos = mBufferLength;
195       newPos += aOffset;
196       break;
197     default:
198       return NS_ERROR_ILLEGAL_VALUE;
199   }
200   if (newPos < 0 || static_cast<uint32_t>(newPos) > mBufferLength) {
201     return NS_ERROR_ILLEGAL_VALUE;
202   }
203   mPos = newPos;
204 
205   return NS_OK;
206 }
207 
// Not supported by this stream; always returns NS_ERROR_NOT_IMPLEMENTED.
NS_IMETHODIMP ArrayBufferInputStream::SetEOF() {
  // Truncating is not supported on a read-only stream.
  return NS_ERROR_NOT_IMPLEMENTED;
}
212 
// The stream can always be cloned: sets *aCloneable to true and returns
// NS_OK.
NS_IMETHODIMP ArrayBufferInputStream::GetCloneable(bool* aCloneable) {
  *aCloneable = true;
  return NS_OK;
}
217 
Clone(nsIInputStream ** _retval)218 NS_IMETHODIMP ArrayBufferInputStream::Clone(nsIInputStream** _retval) {
219   *_retval = MakeAndAddRef<ArrayBufferInputStream>(
220                  AsBytes(Span{mArrayBuffer.get(), mBufferLength}))
221                  .take();
222 
223   return NS_OK;
224 }
225 }  // namespace mozilla::dom::quota
226 
namespace mozilla {
// Declares ScopedNSSContext, a scoped holder for an NSSInitContext that is
// associated with NSS_ShutdownContext as its release function.
// NOTE(review): the exact release semantics come from the macro in
// mozilla/Scoped.h -- presumably NSS_ShutdownContext is invoked when the
// holder is destroyed or reassigned; confirm there.
MOZ_TYPE_SPECIFIC_SCOPED_POINTER_TEMPLATE(ScopedNSSContext, NSSInitContext,
                                          NSS_ShutdownContext);

}  // namespace mozilla
232 
233 using namespace mozilla;
234 using namespace mozilla::dom::quota;
235 
236 class DOM_Quota_EncryptedStream : public ::testing::Test {
237  public:
SetUpTestCase()238   static void SetUpTestCase() {
239     // Do this only once, do not tear it down per test case.
240     if (!sNssContext) {
241       sNssContext =
242           NSS_InitContext("", "", "", "", nullptr,
243                           NSS_INIT_READONLY | NSS_INIT_NOCERTDB |
244                               NSS_INIT_NOMODDB | NSS_INIT_FORCEOPEN |
245                               NSS_INIT_OPTIMIZESPACE | NSS_INIT_NOROOTINIT);
246     }
247   }
248 
TearDownTestCase()249   static void TearDownTestCase() { sNssContext = nullptr; }
250 
251  private:
252   inline static ScopedNSSContext sNssContext = ScopedNSSContext{};
253 };
254 
// Whether the writer flushes the encrypting stream after every written chunk
// or never flushes explicitly.
enum class FlushMode { AfterEachChunk, Never };

// Symbolic chunk sizes used for writing/reading; translated into a byte count
// by EffectiveChunkSize.
enum class ChunkSize { SingleByte, Unaligned, DataSize };

// Parameter tuple as produced by testing::Combine, in this order:
// data size, write chunk size, read chunk size, block size, flush mode.
using PackedTestParams =
    std::tuple<size_t, ChunkSize, ChunkSize, size_t, FlushMode>;
260 
EffectiveChunkSize(const ChunkSize aChunkSize,const size_t aDataSize)261 static size_t EffectiveChunkSize(const ChunkSize aChunkSize,
262                                  const size_t aDataSize) {
263   switch (aChunkSize) {
264     case ChunkSize::SingleByte:
265       return 1;
266     case ChunkSize::Unaligned:
267       return 17;
268     case ChunkSize::DataSize:
269       return aDataSize;
270   }
271   MOZ_CRASH("Unknown ChunkSize");
272 }
273 
274 struct TestParams {
TestParamsTestParams275   MOZ_IMPLICIT constexpr TestParams(const PackedTestParams& aPackedParams)
276       : mDataSize(std::get<0>(aPackedParams)),
277         mWriteChunkSize(std::get<1>(aPackedParams)),
278         mReadChunkSize(std::get<2>(aPackedParams)),
279         mBlockSize(std::get<3>(aPackedParams)),
280         mFlushMode(std::get<4>(aPackedParams)) {}
281 
DataSizeTestParams282   constexpr size_t DataSize() const { return mDataSize; }
283 
EffectiveWriteChunkSizeTestParams284   size_t EffectiveWriteChunkSize() const {
285     return EffectiveChunkSize(mWriteChunkSize, mDataSize);
286   }
287 
EffectiveReadChunkSizeTestParams288   size_t EffectiveReadChunkSize() const {
289     return EffectiveChunkSize(mReadChunkSize, mDataSize);
290   }
291 
BlockSizeTestParams292   constexpr size_t BlockSize() const { return mBlockSize; }
293 
FlushModeTestParams294   constexpr enum FlushMode FlushMode() const { return mFlushMode; }
295 
296  private:
297   size_t mDataSize;
298 
299   ChunkSize mWriteChunkSize;
300   ChunkSize mReadChunkSize;
301 
302   size_t mBlockSize;
303   enum FlushMode mFlushMode;
304 };
305 
TestParamToString(const testing::TestParamInfo<PackedTestParams> & aTestParams)306 std::string TestParamToString(
307     const testing::TestParamInfo<PackedTestParams>& aTestParams) {
308   const TestParams& testParams = aTestParams.param;
309 
310   static constexpr char kSeparator[] = "_";
311 
312   std::stringstream ss;
313   ss << "data" << testParams.DataSize() << kSeparator << "writechunk"
314      << testParams.EffectiveWriteChunkSize() << kSeparator << "readchunk"
315      << testParams.EffectiveReadChunkSize() << kSeparator << "block"
316      << testParams.BlockSize() << kSeparator;
317   switch (testParams.FlushMode()) {
318     case FlushMode::Never:
319       ss << "FlushNever";
320       break;
321     case FlushMode::AfterEachChunk:
322       ss << "FlushAfterEachChunk";
323       break;
324   };
325   return ss.str();
326 }
327 
// Fixture for the roundtrip tests: the DOM_Quota_EncryptedStream fixture
// (shared NSS context) combined with gtest value parametrization over
// PackedTestParams.
class ParametrizedCryptTest
    : public DOM_Quota_EncryptedStream,
      public testing::WithParamInterface<PackedTestParams> {};
331 
MakeTestData(const size_t aDataSize)332 static auto MakeTestData(const size_t aDataSize) {
333   auto data = nsTArray<uint8_t>();
334   data.SetLength(aDataSize);
335   std::iota(data.begin(), data.end(), 0);
336   return data;
337 }
338 
339 template <typename CipherStrategy>
WriteTestData(nsCOMPtr<nsIOutputStream> && aBaseOutputStream,const Span<const uint8_t> aData,const size_t aWriteChunkSize,const size_t aBlockSize,const typename CipherStrategy::KeyType & aKey,const FlushMode aFlushMode)340 static void WriteTestData(nsCOMPtr<nsIOutputStream>&& aBaseOutputStream,
341                           const Span<const uint8_t> aData,
342                           const size_t aWriteChunkSize, const size_t aBlockSize,
343                           const typename CipherStrategy::KeyType& aKey,
344                           const FlushMode aFlushMode) {
345   auto outStream = MakeSafeRefPtr<EncryptingOutputStream<CipherStrategy>>(
346       std::move(aBaseOutputStream), aBlockSize, aKey);
347 
348   for (auto remaining = aData; !remaining.IsEmpty();) {
349     auto [currentChunk, newRemaining] =
350         remaining.SplitAt(std::min(aWriteChunkSize, remaining.Length()));
351     remaining = newRemaining;
352 
353     uint32_t written;
354     EXPECT_EQ(NS_OK, outStream->Write(
355                          reinterpret_cast<const char*>(currentChunk.Elements()),
356                          currentChunk.Length(), &written));
357     EXPECT_EQ(currentChunk.Length(), written);
358 
359     if (aFlushMode == FlushMode::AfterEachChunk) {
360       outStream->Flush();
361     }
362   }
363 
364   // Close explicitly so we can check the result.
365   EXPECT_EQ(NS_OK, outStream->Close());
366 }
367 
// Default "extra checks" callback for ReadTestData/DoRoundtripTest: accepts
// the same arguments as a real check (the stream, the full expected data and
// the not-yet-read remainder) and deliberately does nothing.
template <typename CipherStrategy>
static void NoExtraChecks(DecryptingInputStream<CipherStrategy>& aInputStream,
                          Span<const uint8_t> aExpectedData,
                          Span<const uint8_t> aRemainder) {}
372 
373 template <typename CipherStrategy,
374           typename ExtraChecks = decltype(NoExtraChecks<CipherStrategy>)>
ReadTestData(DecryptingInputStream<CipherStrategy> & aDecryptingInputStream,const Span<const uint8_t> aExpectedData,const size_t aReadChunkSize,const ExtraChecks & aExtraChecks=NoExtraChecks<CipherStrategy>)375 static void ReadTestData(
376     DecryptingInputStream<CipherStrategy>& aDecryptingInputStream,
377     const Span<const uint8_t> aExpectedData, const size_t aReadChunkSize,
378     const ExtraChecks& aExtraChecks = NoExtraChecks<CipherStrategy>) {
379   auto readData = nsTArray<uint8_t>();
380   readData.SetLength(aReadChunkSize);
381   for (auto remainder = aExpectedData; !remainder.IsEmpty();) {
382     auto [currentExpected, newExpectedRemainder] =
383         remainder.SplitAt(std::min(aReadChunkSize, remainder.Length()));
384     remainder = newExpectedRemainder;
385 
386     uint32_t read;
387     EXPECT_EQ(NS_OK, aDecryptingInputStream.Read(
388                          reinterpret_cast<char*>(readData.Elements()),
389                          currentExpected.Length(), &read));
390     EXPECT_EQ(currentExpected.Length(), read);
391     EXPECT_EQ(currentExpected,
392               Span{readData}.First(currentExpected.Length()).AsConst());
393 
394     aExtraChecks(aDecryptingInputStream, aExpectedData, remainder);
395   }
396 
397   // Expect EOF.
398   uint32_t read;
399   EXPECT_EQ(NS_OK, aDecryptingInputStream.Read(
400                        reinterpret_cast<char*>(readData.Elements()),
401                        readData.Length(), &read));
402   EXPECT_EQ(0u, read);
403 }
404 
// Convenience overload: wraps aBaseInputStream in a
// DecryptingInputStream<CipherStrategy> (using aBlockSize and aKey), runs the
// stream-based ReadTestData overload on it with the given expectations, and
// returns the decrypting stream so that callers can continue to use it
// (e.g. to clone it).
template <typename CipherStrategy,
          typename ExtraChecks = decltype(NoExtraChecks<CipherStrategy>)>
static auto ReadTestData(
    MovingNotNull<nsCOMPtr<nsIInputStream>>&& aBaseInputStream,
    const Span<const uint8_t> aExpectedData, const size_t aReadChunkSize,
    const size_t aBlockSize, const typename CipherStrategy::KeyType& aKey,
    const ExtraChecks& aExtraChecks = NoExtraChecks<CipherStrategy>) {
  auto inStream = MakeSafeRefPtr<DecryptingInputStream<CipherStrategy>>(
      std::move(aBaseInputStream), aBlockSize, aKey);

  // Four arguments: resolves to the overload taking the decrypting stream.
  ReadTestData(*inStream, aExpectedData, aReadChunkSize, aExtraChecks);

  return inStream;
}
419 
420 // XXX Change to return the buffer instead.
421 template <typename CipherStrategy,
422           typename ExtraChecks = decltype(NoExtraChecks<CipherStrategy>)>
DoRoundtripTest(const size_t aDataSize,const size_t aWriteChunkSize,const size_t aReadChunkSize,const size_t aBlockSize,const typename CipherStrategy::KeyType & aKey,const FlushMode aFlushMode,const ExtraChecks & aExtraChecks=NoExtraChecks<CipherStrategy>)423 static RefPtr<dom::quota::MemoryOutputStream> DoRoundtripTest(
424     const size_t aDataSize, const size_t aWriteChunkSize,
425     const size_t aReadChunkSize, const size_t aBlockSize,
426     const typename CipherStrategy::KeyType& aKey, const FlushMode aFlushMode,
427     const ExtraChecks& aExtraChecks = NoExtraChecks<CipherStrategy>) {
428   // XXX Add deduction guide for RefPtr from already_AddRefed
429   const auto baseOutputStream =
430       WrapNotNull(RefPtr<dom::quota::MemoryOutputStream>{
431           dom::quota::MemoryOutputStream::Create(2048)});
432 
433   const auto data = MakeTestData(aDataSize);
434 
435   WriteTestData<CipherStrategy>(
436       nsCOMPtr<nsIOutputStream>{baseOutputStream.get()}, Span{data},
437       aWriteChunkSize, aBlockSize, aKey, aFlushMode);
438 
439   const auto baseInputStream =
440       MakeRefPtr<ArrayBufferInputStream>(baseOutputStream->Data());
441 
442   ReadTestData<CipherStrategy>(
443       WrapNotNull(nsCOMPtr<nsIInputStream>{baseInputStream}), Span{data},
444       aReadChunkSize, aBlockSize, aKey, aExtraChecks);
445 
446   return baseOutputStream;
447 }
448 
// Roundtrip test using the NSS-backed cipher strategy with a freshly
// generated key. Aborts the test if no key can be generated.
TEST_P(ParametrizedCryptTest, NSSCipherStrategy) {
  using CipherStrategy = NSSCipherStrategy;
  const TestParams& testParams = GetParam();

  auto keyOrErr = CipherStrategy::GenerateKey();
  ASSERT_FALSE(keyOrErr.isErr());
  const auto key = keyOrErr.unwrap();

  DoRoundtripTest<CipherStrategy>(
      testParams.DataSize(), testParams.EffectiveWriteChunkSize(),
      testParams.EffectiveReadChunkSize(), testParams.BlockSize(), key,
      testParams.FlushMode());
}
461 
// Roundtrip with the dummy cipher strategy, followed by an inspection of the
// raw encrypted output: the output is walked block by block, and for each
// block the cipher prefix and the (reverse-transformed) payload are compared
// with the expected prefix and the corresponding slice of the plain data.
TEST_P(ParametrizedCryptTest, DummyCipherStrategy_CheckOutput) {
  using CipherStrategy = DummyCipherStrategy;
  const TestParams& testParams = GetParam();

  const auto encryptedDataStream = DoRoundtripTest<CipherStrategy>(
      testParams.DataSize(), testParams.EffectiveWriteChunkSize(),
      testParams.EffectiveReadChunkSize(), testParams.BlockSize(),
      CipherStrategy::KeyType{}, testParams.FlushMode());

  // If the roundtrip itself already failed, inspecting the output would only
  // add noise.
  if (HasFailure()) {
    return;
  }

  const auto encryptedDataSpan = AsBytes(Span(encryptedDataStream->Data()));

  const auto plainTestData = MakeTestData(testParams.DataSize());
  // Scratch block that each encrypted block is copied into for parsing.
  auto encryptedBlock = EncryptedBlock<DummyCipherStrategy::BlockPrefixLength,
                                       DummyCipherStrategy::BasicBlockSize>{
      testParams.BlockSize(),
  };
  // Walk the encrypted data and the plain data in lock-step: one encrypted
  // block per iteration, together with the plain bytes it carries.
  for (auto [encryptedRemainder, plainRemainder] =
           std::pair(encryptedDataSpan, Span(plainTestData));
       !encryptedRemainder.IsEmpty();) {
    const auto [currentBlock, newEncryptedRemainder] =
        encryptedRemainder.SplitAt(testParams.BlockSize());
    encryptedRemainder = newEncryptedRemainder;

    std::copy(currentBlock.cbegin(), currentBlock.cend(),
              encryptedBlock.MutableWholeBlock().begin());

    // There must be plain data left for every encrypted block.
    ASSERT_FALSE(plainRemainder.IsEmpty());
    // The block reports how many payload bytes it actually carries; that many
    // plain bytes are expected to correspond to it.
    const auto [currentPlain, newPlainRemainder] =
        plainRemainder.SplitAt(encryptedBlock.ActualPayloadLength());
    plainRemainder = newPlainRemainder;

    const auto pseudoIV = encryptedBlock.CipherPrefix();
    const auto payload = encryptedBlock.Payload();

    EXPECT_EQ(Span(DummyCipherStrategy::MakeBlockPrefix()), pseudoIV);

    // NOTE(review): presumably DummyTransform is its own inverse, so applying
    // it to the payload recovers the plain bytes -- confirm in
    // DummyCipherStrategy.
    auto untransformedPayload = nsTArray<uint8_t>();
    untransformedPayload.SetLength(testParams.BlockSize());
    DummyCipherStrategy::DummyTransform(payload, untransformedPayload);

    EXPECT_EQ(
        currentPlain,
        Span(untransformedPayload).AsConst().First(currentPlain.Length()));
  }
}
511 
// Roundtrip with the dummy cipher strategy that additionally checks, after
// every read chunk, that Tell() on the decrypting stream reports the number
// of plain bytes read so far.
TEST_P(ParametrizedCryptTest, DummyCipherStrategy_Tell) {
  using CipherStrategy = DummyCipherStrategy;
  const TestParams& testParams = GetParam();

  DoRoundtripTest<CipherStrategy>(
      testParams.DataSize(), testParams.EffectiveWriteChunkSize(),
      testParams.EffectiveReadChunkSize(), testParams.BlockSize(),
      CipherStrategy::KeyType{}, testParams.FlushMode(),
      [](auto& inStream, Span<const uint8_t> expectedData,
         Span<const uint8_t> remainder) {
        // Check that Tell tells the right position.
        // Initialized so that a failing Tell cannot lead to a comparison
        // against an indeterminate value.
        int64_t pos = 0;
        EXPECT_EQ(NS_OK, inStream.Tell(&pos));
        EXPECT_EQ(expectedData.Length() - remainder.Length(),
                  static_cast<uint64_t>(pos));
      });
}
529 
// Roundtrip with the dummy cipher strategy that additionally checks, after
// every read chunk, that Available() on the decrypting stream reports the
// number of plain bytes that are still unread.
TEST_P(ParametrizedCryptTest, DummyCipherStrategy_Available) {
  using CipherStrategy = DummyCipherStrategy;
  const TestParams& testParams = GetParam();

  DoRoundtripTest<CipherStrategy>(
      testParams.DataSize(), testParams.EffectiveWriteChunkSize(),
      testParams.EffectiveReadChunkSize(), testParams.BlockSize(),
      CipherStrategy::KeyType{}, testParams.FlushMode(),
      [](auto& inStream, Span<const uint8_t> expectedData,
         Span<const uint8_t> remainder) {
        // Check that Available tells the right remainder.
        // Start from a sentinel that can never be a valid remainder, so that
        // a failing Available cannot lead to a comparison against an
        // indeterminate value (or to an accidental match).
        uint64_t available = UINT64_MAX;
        EXPECT_EQ(NS_OK, inStream.Available(&available));
        EXPECT_EQ(remainder.Length(), available);
      });
}
546 
// Checks cloning of a decrypting stream: test data is encrypted, read back
// completely through a first decrypting stream, and then that stream is
// cloned. The clone is expected to yield the complete plain data again from
// the beginning.
TEST_P(ParametrizedCryptTest, DummyCipherStrategy_Clone) {
  using CipherStrategy = DummyCipherStrategy;
  const TestParams& testParams = GetParam();

  // XXX Add deduction guide for RefPtr from already_AddRefed
  const auto baseOutputStream =
      WrapNotNull(RefPtr<dom::quota::MemoryOutputStream>{
          dom::quota::MemoryOutputStream::Create(2048)});

  const auto data = MakeTestData(testParams.DataSize());

  WriteTestData<CipherStrategy>(
      nsCOMPtr<nsIOutputStream>{baseOutputStream.get()}, Span{data},
      testParams.EffectiveWriteChunkSize(), testParams.BlockSize(),
      CipherStrategy::KeyType{}, testParams.FlushMode());

  const auto baseInputStream =
      MakeRefPtr<ArrayBufferInputStream>(baseOutputStream->Data());

  const auto inStream = ReadTestData<CipherStrategy>(
      WrapNotNull(nsCOMPtr<nsIInputStream>{baseInputStream}), Span{data},
      testParams.EffectiveReadChunkSize(), testParams.BlockSize(),
      CipherStrategy::KeyType{});

  // The clone is dereferenced unconditionally below, so a failed Clone must
  // abort the test here instead of crashing on a null pointer.
  nsCOMPtr<nsIInputStream> clonedInputStream;
  ASSERT_EQ(NS_OK, inStream->Clone(getter_AddRefs(clonedInputStream)));
  ASSERT_TRUE(clonedInputStream);

  ReadTestData(
      static_cast<DecryptingInputStream<CipherStrategy>&>(*clonedInputStream),
      Span{data}, testParams.EffectiveReadChunkSize());
}
578 
579 // XXX This test is actually only parametrized on the block size.
TEST_P(ParametrizedCryptTest,DummyCipherStrategy_IncompleteBlock)580 TEST_P(ParametrizedCryptTest, DummyCipherStrategy_IncompleteBlock) {
581   using CipherStrategy = DummyCipherStrategy;
582   const TestParams& testParams = GetParam();
583 
584   // Provide half a block, content doesn't matter.
585   nsTArray<uint8_t> data;
586   data.SetLength(testParams.BlockSize() / 2);
587 
588   const auto baseInputStream = MakeRefPtr<ArrayBufferInputStream>(data);
589 
590   const auto inStream = MakeSafeRefPtr<DecryptingInputStream<CipherStrategy>>(
591       WrapNotNull(nsCOMPtr<nsIInputStream>{baseInputStream}),
592       testParams.BlockSize(), CipherStrategy::KeyType{});
593 
594   nsTArray<uint8_t> readData;
595   readData.SetLength(testParams.BlockSize());
596   uint32_t read;
597   EXPECT_EQ(NS_ERROR_CORRUPTED_CONTENT,
598             inStream->Read(reinterpret_cast<char*>(readData.Elements()),
599                            readData.Length(), &read));
600 }
601 
// Symbolic seek offsets, expressed relative to the size of the test data;
// the seek test translates them into concrete byte offsets.
enum class SeekOffset {
  Zero,
  MinusHalfDataSize,
  PlusHalfDataSize,
  PlusDataSize,
  MinusDataSize
};

// A single seek operation: the `whence` value passed to Seek, paired with a
// symbolic offset.
using SeekOp = std::pair<int32_t, SeekOffset>;

// Parameter tuple as produced by testing::Combine, in this order:
// data size, block size, sequence of seek operations.
using PackedSeekTestParams = std::tuple<size_t, size_t, std::vector<SeekOp>>;
612 
// Named view onto a PackedSeekTestParams tuple. Implicitly constructible
// from the tuple so that tests can write
// `const SeekTestParams& p = GetParam();`.
struct SeekTestParams {
  // Size of the plain test data.
  size_t mDataSize;
  // Block size passed to the encrypting/decrypting streams.
  size_t mBlockSize;
  // Seek operations to apply, in order, before reading.
  std::vector<SeekOp> mSeekOps;

  // Unpacks the tuple positionally (data size, block size, seek operations);
  // the seek operations are copied.
  MOZ_IMPLICIT SeekTestParams(const PackedSeekTestParams& aPackedParams)
      : mDataSize(std::get<0>(aPackedParams)),
        mBlockSize(std::get<1>(aPackedParams)),
        mSeekOps(std::get<2>(aPackedParams)) {}
};
623 
SeekTestParamToString(const testing::TestParamInfo<PackedSeekTestParams> & aTestParams)624 std::string SeekTestParamToString(
625     const testing::TestParamInfo<PackedSeekTestParams>& aTestParams) {
626   const SeekTestParams& testParams = aTestParams.param;
627 
628   static constexpr char kSeparator[] = "_";
629 
630   std::stringstream ss;
631   ss << "data" << testParams.mDataSize << kSeparator << "writechunk"
632      << testParams.mBlockSize << kSeparator;
633   for (const auto& seekOp : testParams.mSeekOps) {
634     switch (seekOp.first) {
635       case nsISeekableStream::NS_SEEK_SET:
636         ss << "Set";
637         break;
638       case nsISeekableStream::NS_SEEK_CUR:
639         ss << "Cur";
640         break;
641       case nsISeekableStream::NS_SEEK_END:
642         ss << "End";
643         break;
644     };
645     switch (seekOp.second) {
646       case SeekOffset::Zero:
647         ss << "Zero";
648         break;
649       case SeekOffset::MinusHalfDataSize:
650         ss << "MinusHalfDataSize";
651         break;
652       case SeekOffset::PlusHalfDataSize:
653         ss << "PlusHalfDataSize";
654         break;
655       case SeekOffset::MinusDataSize:
656         ss << "MinusDataSize";
657         break;
658       case SeekOffset::PlusDataSize:
659         ss << "PlusDataSize";
660         break;
661     };
662   }
663   return ss.str();
664 }
665 
// Fixture for the seek tests: the DOM_Quota_EncryptedStream fixture (shared
// NSS context) combined with gtest value parametrization over
// PackedSeekTestParams.
class ParametrizedSeekCryptTest
    : public DOM_Quota_EncryptedStream,
      public testing::WithParamInterface<PackedSeekTestParams> {};
669 
TEST_P(ParametrizedSeekCryptTest,DummyCipherStrategy_Seek)670 TEST_P(ParametrizedSeekCryptTest, DummyCipherStrategy_Seek) {
671   using CipherStrategy = DummyCipherStrategy;
672   const SeekTestParams& testParams = GetParam();
673 
674   const auto baseOutputStream =
675       WrapNotNull(RefPtr<dom::quota::MemoryOutputStream>{
676           dom::quota::MemoryOutputStream::Create(2048)});
677 
678   const auto data = MakeTestData(testParams.mDataSize);
679 
680   WriteTestData<CipherStrategy>(
681       nsCOMPtr<nsIOutputStream>{baseOutputStream.get()}, Span{data},
682       testParams.mDataSize, testParams.mBlockSize, CipherStrategy::KeyType{},
683       FlushMode::Never);
684 
685   const auto baseInputStream =
686       MakeRefPtr<ArrayBufferInputStream>(baseOutputStream->Data());
687 
688   const auto inStream = MakeSafeRefPtr<DecryptingInputStream<CipherStrategy>>(
689       WrapNotNull(nsCOMPtr<nsIInputStream>{baseInputStream}),
690       testParams.mBlockSize, CipherStrategy::KeyType{});
691 
692   uint32_t accumulatedOffset = 0;
693   for (const auto& seekOp : testParams.mSeekOps) {
694     const auto offset = [offsetKind = seekOp.second,
695                          dataSize = testParams.mDataSize]() -> int64_t {
696       switch (offsetKind) {
697         case SeekOffset::Zero:
698           return 0;
699         case SeekOffset::MinusHalfDataSize:
700           return -static_cast<int64_t>(dataSize) / 2;
701         case SeekOffset::PlusHalfDataSize:
702           return dataSize / 2;
703         case SeekOffset::MinusDataSize:
704           return -static_cast<int64_t>(dataSize);
705         case SeekOffset::PlusDataSize:
706           return dataSize;
707       }
708       MOZ_CRASH("Unknown SeekOffset");
709     }();
710     switch (seekOp.first) {
711       case nsISeekableStream::NS_SEEK_SET:
712         accumulatedOffset = offset;
713         break;
714       case nsISeekableStream::NS_SEEK_CUR:
715         accumulatedOffset += offset;
716         break;
717       case nsISeekableStream::NS_SEEK_END:
718         accumulatedOffset = testParams.mDataSize + offset;
719         break;
720     }
721     EXPECT_EQ(NS_OK, inStream->Seek(seekOp.first, offset));
722   }
723 
724   {
725     int64_t actualOffset;
726     EXPECT_EQ(NS_OK, inStream->Tell(&actualOffset));
727 
728     EXPECT_EQ(actualOffset, accumulatedOffset);
729   }
730 
731   auto readData = nsTArray<uint8_t>();
732   readData.SetLength(data.Length());
733   uint32_t read;
734   EXPECT_EQ(NS_OK, inStream->Read(reinterpret_cast<char*>(readData.Elements()),
735                                   readData.Length(), &read));
736   // XXX Or should 'read' indicate the actual number of bytes read,
737   // including the encryption overhead?
738   EXPECT_EQ(testParams.mDataSize - accumulatedOffset, read);
739   EXPECT_EQ(Span{data}.SplitAt(accumulatedOffset).second,
740             Span{readData}.First(read).AsConst());
741 }
742 
// Instantiates every ParametrizedCryptTest for all combinations of the
// values listed below (testing::Combine); TestParamToString provides the
// per-combination test name suffix. The order of the value lists must match
// the element order of PackedTestParams.
INSTANTIATE_TEST_CASE_P(
    DOM_Quota_EncryptedStream_Parametrized, ParametrizedCryptTest,
    testing::Combine(
        /* dataSize */ testing::Values(0u, 16u, 256u, 512u, 513u),
        /* writeChunkSize */
        testing::Values(ChunkSize::SingleByte, ChunkSize::Unaligned,
                        ChunkSize::DataSize),
        /* readChunkSize */
        testing::Values(ChunkSize::SingleByte, ChunkSize::Unaligned,
                        ChunkSize::DataSize),
        /* blockSize */ testing::Values(256u, 1024u /*, 8192u*/),
        /* flushMode */
        testing::Values(FlushMode::Never, FlushMode::AfterEachChunk)),
    TestParamToString);
757 
// Instantiates every ParametrizedSeekCryptTest for all combinations of the
// values listed below (testing::Combine); SeekTestParamToString provides the
// per-combination test name suffix. The order of the value lists must match
// the element order of PackedSeekTestParams.
//
// Every listed seek sequence ends at a position inside [0, dataSize], so all
// seeks are expected to succeed.
//
// NOTE(review): the instantiation prefix says "DOM_IndexedDB" while the
// sibling instantiation and the fixture use "DOM_Quota" -- possibly a
// leftover from an earlier location of this test; confirm before renaming,
// since the prefix is part of the test names.
INSTANTIATE_TEST_CASE_P(
    DOM_IndexedDB_EncryptedStream_ParametrizedSeek, ParametrizedSeekCryptTest,
    testing::Combine(
        /* dataSize */ testing::Values(0u, 16u, 256u, 512u, 513u),
        /* blockSize */ testing::Values(256u, 1024u /*, 8192u*/),
        /* seekOperations */
        testing::Values(/* NS_SEEK_SET only, single ops */
                        std::vector<SeekOp>{{nsISeekableStream::NS_SEEK_SET,
                                             SeekOffset::PlusDataSize}},
                        std::vector<SeekOp>{{nsISeekableStream::NS_SEEK_SET,
                                             SeekOffset::PlusHalfDataSize}},
                        /* NS_SEEK_SET only, multiple ops */
                        std::vector<SeekOp>{
                            {nsISeekableStream::NS_SEEK_SET,
                             SeekOffset::PlusHalfDataSize},
                            {nsISeekableStream::NS_SEEK_SET, SeekOffset::Zero}},
                        /* NS_SEEK_CUR only, single ops */
                        std::vector<SeekOp>{
                            {nsISeekableStream::NS_SEEK_CUR, SeekOffset::Zero}},
                        std::vector<SeekOp>{{nsISeekableStream::NS_SEEK_CUR,
                                             SeekOffset::PlusDataSize}},
                        std::vector<SeekOp>{{nsISeekableStream::NS_SEEK_CUR,
                                             SeekOffset::PlusHalfDataSize}},
                        /* NS_SEEK_END only, single ops */
                        std::vector<SeekOp>{
                            {nsISeekableStream::NS_SEEK_END, SeekOffset::Zero}},
                        std::vector<SeekOp>{{nsISeekableStream::NS_SEEK_END,
                                             SeekOffset::MinusDataSize}},
                        std::vector<SeekOp>{{nsISeekableStream::NS_SEEK_END,
                                             SeekOffset::MinusHalfDataSize}})),
    SeekTestParamToString);
789