1 /* 2 * Copyright (c) Facebook, Inc. and its affiliates. 3 * 4 * This source code is licensed under the MIT license found in the 5 * LICENSE file in the root directory of this source tree. 6 * 7 */ 8 9 #include <folly/portability/GTest.h> 10 #include <quic/common/test/TestUtils.h> 11 #include <quic/dsr/frontend/Scheduler.h> 12 #include <quic/dsr/test/Mocks.h> 13 #include <quic/fizz/server/handshake/FizzServerQuicHandshakeContext.h> 14 #include <quic/server/state/ServerStateMachine.h> 15 #include <algorithm> 16 17 namespace quic::test { 18 19 class DSRCommonTestFixture : public testing::Test { 20 public: DSRCommonTestFixture()21 DSRCommonTestFixture() 22 : conn_(FizzServerQuicHandshakeContext::Builder().build()), 23 scheduler_(conn_), 24 aead_(createNoOpAead()) { 25 conn_.clientConnectionId = getTestConnectionId(0); 26 conn_.serverConnectionId = getTestConnectionId(1); 27 auto mockHeaderCipher = std::make_unique<MockPacketNumberCipher>(); 28 packetProtectionKey_ = getProtectionKey(); 29 EXPECT_CALL(*mockHeaderCipher, getKey()) 30 .WillRepeatedly(testing::ReturnRef(packetProtectionKey_)); 31 conn_.oneRttWriteHeaderCipher = std::move(mockHeaderCipher); 32 auto mockCipher = std::make_unique<MockAead>(); 33 EXPECT_CALL(*mockCipher, getKey()).WillRepeatedly(testing::Invoke([] { 34 return getQuicTestKey(); 35 })); 36 conn_.oneRttWriteCipher = std::move(mockCipher); 37 38 serverHandshake_ = std::make_unique<FakeServerHandshake>( 39 conn_, 40 FizzServerQuicHandshakeContext::Builder() 41 .setFizzServerContext(createServerCtx()) 42 .build()); 43 serverHandshake_->setCipherSuite(fizz::CipherSuite::TLS_AES_128_GCM_SHA256); 44 conn_.serverHandshakeLayer = serverHandshake_.get(); 45 conn_.handshakeLayer = std::move(serverHandshake_); 46 } 47 48 protected: prepareFlowControlAndStreamLimit()49 void prepareFlowControlAndStreamLimit() { 50 conn_.flowControlState.peerAdvertisedInitialMaxStreamOffsetBidiLocal = 51 kDefaultStreamWindowSize; 52 conn_.flowControlState.peerAdvertisedInitialMaxStreamOffsetBidiRemote = 53 kDefaultStreamWindowSize; 54 conn_.flowControlState.peerAdvertisedInitialMaxStreamOffsetUni = 55 kDefaultStreamWindowSize; 56 conn_.flowControlState.peerAdvertisedMaxOffset = 57 kDefaultConnectionWindowSize; 58 conn_.streamManager->setMaxLocalBidirectionalStreams( 59 kDefaultMaxStreamsBidirectional); 60 conn_.streamManager->setMaxLocalUnidirectionalStreams( 61 kDefaultMaxStreamsUnidirectional); 62 } 63 64 StreamId prepareOneStream(size_t bufMetaLength = 1000) { 65 conn_.streamManager->setMaxLocalBidirectionalStreams( 66 kDefaultMaxStreamsBidirectional); 67 conn_.streamManager->setMaxLocalUnidirectionalStreams( 68 kDefaultMaxStreamsUnidirectional); 69 auto id = conn_.streamManager->createNextBidirectionalStream().value()->id; 70 auto stream = conn_.streamManager->findStream(id); 71 72 auto sender = std::make_unique<MockDSRPacketizationRequestSender>(); 73 ON_CALL(*sender, addSendInstruction(testing::_)) 74 .WillByDefault(testing::Invoke([&](const SendInstruction& instruction) { 75 pendingInstructions_.push_back(instruction); 76 auto streamId = instruction.streamId; 77 if (instructionCounter_.count(streamId) == 0) { 78 instructionCounter_[streamId] = 1; 79 } else { 80 instructionCounter_[streamId] += 1; 81 } 82 return true; 83 })); 84 ON_CALL(*sender, flush()).WillByDefault(testing::Return(true)); 85 stream->dsrSender = std::move(sender); 86 writeDataToQuicStream( 87 *stream, 88 folly::IOBuf::copyBuffer("MetroCard Customer Claims"), 89 false /* eof */); 90 BufferMeta bufMeta(bufMetaLength); 91 writeBufMetaToQuicStream(*stream, bufMeta, true /* eof */); 92 return id; 93 } 94 countInstructions(StreamId streamId)95 size_t countInstructions(StreamId streamId) { 96 if (instructionCounter_.count(streamId) == 0) { 97 return 0; 98 } 99 return instructionCounter_[streamId]; 100 } 101 verifyAllOutstandingsAreDSR()102 bool verifyAllOutstandingsAreDSR() const { 103 return std::all_of( 104 conn_.outstandings.packets.begin(), 105 conn_.outstandings.packets.end(), 106 [](const OutstandingPacket& packet) { return packet.isDSRPacket; }); 107 } 108 109 protected: 110 QuicServerConnectionState conn_; 111 DSRStreamFrameScheduler scheduler_; 112 std::unique_ptr<Aead> aead_; 113 std::unordered_map<StreamId, size_t> instructionCounter_; 114 std::vector<SendInstruction> pendingInstructions_; 115 Buf packetProtectionKey_; 116 std::unique_ptr<FakeServerHandshake> serverHandshake_; 117 }; 118 } // namespace quic::test 119