1 /*
2  * Copyright (c) Facebook, Inc. and its affiliates.
3  *
4  * This source code is licensed under the MIT license found in the
5  * LICENSE file in the root directory of this source tree.
6  *
7  */
8 
9 #include <folly/portability/GTest.h>
10 #include <quic/common/test/TestUtils.h>
11 #include <quic/dsr/frontend/Scheduler.h>
12 #include <quic/dsr/test/Mocks.h>
13 #include <quic/fizz/server/handshake/FizzServerQuicHandshakeContext.h>
14 #include <quic/server/state/ServerStateMachine.h>
15 #include <algorithm>
16 
17 namespace quic::test {
18 
19 class DSRCommonTestFixture : public testing::Test {
20  public:
DSRCommonTestFixture()21   DSRCommonTestFixture()
22       : conn_(FizzServerQuicHandshakeContext::Builder().build()),
23         scheduler_(conn_),
24         aead_(createNoOpAead()) {
25     conn_.clientConnectionId = getTestConnectionId(0);
26     conn_.serverConnectionId = getTestConnectionId(1);
27     auto mockHeaderCipher = std::make_unique<MockPacketNumberCipher>();
28     packetProtectionKey_ = getProtectionKey();
29     EXPECT_CALL(*mockHeaderCipher, getKey())
30         .WillRepeatedly(testing::ReturnRef(packetProtectionKey_));
31     conn_.oneRttWriteHeaderCipher = std::move(mockHeaderCipher);
32     auto mockCipher = std::make_unique<MockAead>();
33     EXPECT_CALL(*mockCipher, getKey()).WillRepeatedly(testing::Invoke([] {
34       return getQuicTestKey();
35     }));
36     conn_.oneRttWriteCipher = std::move(mockCipher);
37 
38     serverHandshake_ = std::make_unique<FakeServerHandshake>(
39         conn_,
40         FizzServerQuicHandshakeContext::Builder()
41             .setFizzServerContext(createServerCtx())
42             .build());
43     serverHandshake_->setCipherSuite(fizz::CipherSuite::TLS_AES_128_GCM_SHA256);
44     conn_.serverHandshakeLayer = serverHandshake_.get();
45     conn_.handshakeLayer = std::move(serverHandshake_);
46   }
47 
48  protected:
prepareFlowControlAndStreamLimit()49   void prepareFlowControlAndStreamLimit() {
50     conn_.flowControlState.peerAdvertisedInitialMaxStreamOffsetBidiLocal =
51         kDefaultStreamWindowSize;
52     conn_.flowControlState.peerAdvertisedInitialMaxStreamOffsetBidiRemote =
53         kDefaultStreamWindowSize;
54     conn_.flowControlState.peerAdvertisedInitialMaxStreamOffsetUni =
55         kDefaultStreamWindowSize;
56     conn_.flowControlState.peerAdvertisedMaxOffset =
57         kDefaultConnectionWindowSize;
58     conn_.streamManager->setMaxLocalBidirectionalStreams(
59         kDefaultMaxStreamsBidirectional);
60     conn_.streamManager->setMaxLocalUnidirectionalStreams(
61         kDefaultMaxStreamsUnidirectional);
62   }
63 
64   StreamId prepareOneStream(size_t bufMetaLength = 1000) {
65     conn_.streamManager->setMaxLocalBidirectionalStreams(
66         kDefaultMaxStreamsBidirectional);
67     conn_.streamManager->setMaxLocalUnidirectionalStreams(
68         kDefaultMaxStreamsUnidirectional);
69     auto id = conn_.streamManager->createNextBidirectionalStream().value()->id;
70     auto stream = conn_.streamManager->findStream(id);
71 
72     auto sender = std::make_unique<MockDSRPacketizationRequestSender>();
73     ON_CALL(*sender, addSendInstruction(testing::_))
74         .WillByDefault(testing::Invoke([&](const SendInstruction& instruction) {
75           pendingInstructions_.push_back(instruction);
76           auto streamId = instruction.streamId;
77           if (instructionCounter_.count(streamId) == 0) {
78             instructionCounter_[streamId] = 1;
79           } else {
80             instructionCounter_[streamId] += 1;
81           }
82           return true;
83         }));
84     ON_CALL(*sender, flush()).WillByDefault(testing::Return(true));
85     stream->dsrSender = std::move(sender);
86     writeDataToQuicStream(
87         *stream,
88         folly::IOBuf::copyBuffer("MetroCard Customer Claims"),
89         false /* eof */);
90     BufferMeta bufMeta(bufMetaLength);
91     writeBufMetaToQuicStream(*stream, bufMeta, true /* eof */);
92     return id;
93   }
94 
countInstructions(StreamId streamId)95   size_t countInstructions(StreamId streamId) {
96     if (instructionCounter_.count(streamId) == 0) {
97       return 0;
98     }
99     return instructionCounter_[streamId];
100   }
101 
verifyAllOutstandingsAreDSR()102   bool verifyAllOutstandingsAreDSR() const {
103     return std::all_of(
104         conn_.outstandings.packets.begin(),
105         conn_.outstandings.packets.end(),
106         [](const OutstandingPacket& packet) { return packet.isDSRPacket; });
107   }
108 
109  protected:
110   QuicServerConnectionState conn_;
111   DSRStreamFrameScheduler scheduler_;
112   std::unique_ptr<Aead> aead_;
113   std::unordered_map<StreamId, size_t> instructionCounter_;
114   std::vector<SendInstruction> pendingInstructions_;
115   Buf packetProtectionKey_;
116   std::unique_ptr<FakeServerHandshake> serverHandshake_;
117 };
118 } // namespace quic::test
119