1 /*
2  * Copyright (c) Facebook, Inc. and its affiliates.
3  *
4  * This source code is licensed under the MIT license found in the
5  * LICENSE file in the root directory of this source tree.
6  *
7  */
8 
9 #include <quic/common/test/TestUtils.h>
10 
11 #include <fizz/crypto/test/TestUtil.h>
12 #include <fizz/protocol/clock/test/Mocks.h>
13 #include <fizz/protocol/test/Mocks.h>
14 #include <quic/api/QuicTransportFunctions.h>
15 #include <quic/codec/DefaultConnectionIdAlgo.h>
16 #include <quic/fizz/handshake/QuicFizzFactory.h>
17 #include <quic/fizz/server/handshake/AppToken.h>
18 #include <quic/handshake/test/Mocks.h>
19 #include <quic/server/handshake/StatelessResetGenerator.h>
20 #include <quic/state/stream/StreamSendHandlers.h>
21 #include "quic/codec/QuicConnectionId.h"
22 
23 using namespace testing;
24 
25 namespace quic {
26 namespace test {
27 
28 std::function<MockClock::time_point()> MockClock::mockNow;
29 
writeQuicPacket(QuicServerConnectionState & conn,ConnectionId srcConnId,ConnectionId dstConnId,folly::test::MockAsyncUDPSocket & sock,QuicStreamState & stream,const folly::IOBuf & data,bool eof)30 const RegularQuicWritePacket& writeQuicPacket(
31     QuicServerConnectionState& conn,
32     ConnectionId srcConnId,
33     ConnectionId dstConnId,
34     folly::test::MockAsyncUDPSocket& sock,
35     QuicStreamState& stream,
36     const folly::IOBuf& data,
37     bool eof) {
38   auto version = conn.version.value_or(*conn.originalVersion);
39   auto aead = createNoOpAead();
40   auto headerCipher = createNoOpHeaderCipher();
41   writeDataToQuicStream(stream, data.clone(), eof);
42   writeQuicDataToSocket(
43       sock,
44       conn,
45       srcConnId,
46       dstConnId,
47       *aead,
48       *headerCipher,
49       version,
50       conn.transportSettings.writeConnectionDataPacketsLimit);
51   CHECK(
52       conn.outstandings.packets.rend() !=
53       getLastOutstandingPacket(conn, PacketNumberSpace::AppData));
54   return getLastOutstandingPacket(conn, PacketNumberSpace::AppData)->packet;
55 }
56 
rstStreamAndSendPacket(QuicServerConnectionState & conn,folly::AsyncUDPSocket & sock,QuicStreamState & stream,ApplicationErrorCode errorCode)57 PacketNum rstStreamAndSendPacket(
58     QuicServerConnectionState& conn,
59     folly::AsyncUDPSocket& sock,
60     QuicStreamState& stream,
61     ApplicationErrorCode errorCode) {
62   auto aead = createNoOpAead();
63   auto headerCipher = createNoOpHeaderCipher();
64   auto version = conn.version.value_or(*conn.originalVersion);
65   sendRstSMHandler(stream, errorCode);
66   writeQuicDataToSocket(
67       sock,
68       conn,
69       *conn.clientConnectionId,
70       *conn.serverConnectionId,
71       *aead,
72       *headerCipher,
73       version,
74       conn.transportSettings.writeConnectionDataPacketsLimit);
75 
76   for (const auto& packet : conn.outstandings.packets) {
77     for (const auto& frame : packet.packet.frames) {
78       auto rstFrame = frame.asRstStreamFrame();
79       if (!rstFrame) {
80         continue;
81       }
82       if (rstFrame->streamId == stream.id) {
83         return packet.packet.header.getPacketSequenceNum();
84       }
85     }
86   }
87   CHECK(false) << "no packet with reset stream";
88   // some compilers are weird.
89   return 0;
90 }
91 
createAckPacket(QuicConnectionStateBase & dstConn,PacketNum pn,AckBlocks & acks,PacketNumberSpace pnSpace,const Aead * aead)92 RegularQuicPacketBuilder::Packet createAckPacket(
93     QuicConnectionStateBase& dstConn,
94     PacketNum pn,
95     AckBlocks& acks,
96     PacketNumberSpace pnSpace,
97     const Aead* aead) {
98   // This function sends ACK to dstConn
99   auto srcConnId =
100       (dstConn.nodeType == QuicNodeType::Client ? *dstConn.serverConnectionId
101                                                 : *dstConn.clientConnectionId);
102   auto dstConnId =
103       (dstConn.nodeType == QuicNodeType::Client ? *dstConn.clientConnectionId
104                                                 : *dstConn.serverConnectionId);
105   folly::Optional<PacketHeader> header;
106   if (pnSpace == PacketNumberSpace::Initial) {
107     header = LongHeader(
108         LongHeader::Types::Initial,
109         srcConnId,
110         dstConnId,
111         pn,
112         QuicVersion::MVFST);
113   } else if (pnSpace == PacketNumberSpace::Handshake) {
114     header = LongHeader(
115         LongHeader::Types::Handshake,
116         srcConnId,
117         dstConnId,
118         pn,
119         QuicVersion::MVFST);
120   } else {
121     header = ShortHeader(ProtectionType::KeyPhaseZero, dstConnId, pn);
122   }
123   RegularQuicPacketBuilder builder(
124       dstConn.udpSendPacketLen,
125       std::move(*header),
126       getAckState(dstConn, pnSpace).largestAckScheduled.value_or(0));
127   builder.encodePacketHeader();
128   if (aead) {
129     builder.accountForCipherOverhead(aead->getCipherOverhead());
130   }
131   DCHECK(builder.canBuildPacket());
132   AckFrameMetaData ackData(
133       acks, 0us, dstConn.transportSettings.ackDelayExponent);
134   writeAckFrame(ackData, builder);
135   return std::move(builder).buildPacket();
136 }
137 
readCert()138 std::shared_ptr<fizz::SelfCert> readCert() {
139   auto certificate = fizz::test::getCert(fizz::test::kP256Certificate);
140   auto privKey = fizz::test::getPrivateKey(fizz::test::kP256Key);
141   std::vector<folly::ssl::X509UniquePtr> certs;
142   certs.emplace_back(std::move(certificate));
143   return std::make_shared<fizz::SelfCertImpl<fizz::KeyType::P256>>(
144       std::move(privKey), std::move(certs));
145 }
146 
createServerCtx()147 std::shared_ptr<fizz::server::FizzServerContext> createServerCtx() {
148   auto cert = readCert();
149   auto certManager = std::make_unique<fizz::server::CertManager>();
150   certManager->addCert(std::move(cert), true);
151   auto serverCtx = std::make_shared<fizz::server::FizzServerContext>();
152   serverCtx->setFactory(std::make_shared<QuicFizzFactory>());
153   serverCtx->setCertManager(std::move(certManager));
154   serverCtx->setOmitEarlyRecordLayer(true);
155   serverCtx->setClock(std::make_shared<NiceMock<fizz::test::MockClock>>());
156   return serverCtx;
157 }
158 
159 class AcceptingTicketCipher : public fizz::server::TicketCipher {
160  public:
161   ~AcceptingTicketCipher() override = default;
162 
163   folly::Future<folly::Optional<
164       std::pair<std::unique_ptr<folly::IOBuf>, std::chrono::seconds>>>
encrypt(fizz::server::ResumptionState) const165   encrypt(fizz::server::ResumptionState) const override {
166     // Fake handshake, no need todo anything here.
167     return std::make_pair(folly::IOBuf::create(0), 2s);
168   }
169 
setPsk(const QuicCachedPsk & cachedPsk)170   void setPsk(const QuicCachedPsk& cachedPsk) {
171     cachedPsk_ = cachedPsk;
172   }
173 
createResumptionState() const174   fizz::server::ResumptionState createResumptionState() const {
175     fizz::server::ResumptionState resState;
176     resState.version = cachedPsk_.cachedPsk.version;
177     resState.cipher = cachedPsk_.cachedPsk.cipher;
178     resState.resumptionSecret =
179         folly::IOBuf::copyBuffer(cachedPsk_.cachedPsk.secret);
180     resState.serverCert = cachedPsk_.cachedPsk.serverCert;
181     resState.alpn = cachedPsk_.cachedPsk.alpn;
182     resState.ticketAgeAdd = 0;
183     resState.ticketIssueTime = std::chrono::system_clock::time_point();
184     resState.handshakeTime = std::chrono::system_clock::time_point();
185     AppToken appToken;
186     appToken.transportParams = createTicketTransportParameters(
187         kDefaultIdleTimeout.count(),
188         kDefaultUDPReadBufferSize,
189         kDefaultConnectionWindowSize,
190         kDefaultStreamWindowSize,
191         kDefaultStreamWindowSize,
192         kDefaultStreamWindowSize,
193         kDefaultMaxStreamsBidirectional,
194         kDefaultMaxStreamsUnidirectional);
195     appToken.version = QuicVersion::MVFST;
196     resState.appToken = encodeAppToken(appToken);
197     return resState;
198   }
199 
200   folly::Future<
201       std::pair<fizz::PskType, folly::Optional<fizz::server::ResumptionState>>>
decrypt(std::unique_ptr<folly::IOBuf>) const202   decrypt(std::unique_ptr<folly::IOBuf>) const override {
203     return std::make_pair(fizz::PskType::Resumption, createResumptionState());
204   }
205 
206  private:
207   QuicCachedPsk cachedPsk_;
208 };
209 
setupZeroRttOnServerCtx(fizz::server::FizzServerContext & serverCtx,const QuicCachedPsk & cachedPsk)210 void setupZeroRttOnServerCtx(
211     fizz::server::FizzServerContext& serverCtx,
212     const QuicCachedPsk& cachedPsk) {
213   serverCtx.setEarlyDataSettings(
214       true,
215       fizz::server::ClockSkewTolerance{-100000ms, 100000ms},
216       std::make_shared<fizz::server::AllowAllReplayReplayCache>());
217   auto ticketCipher = std::make_shared<AcceptingTicketCipher>();
218   ticketCipher->setPsk(cachedPsk);
219   serverCtx.setTicketCipher(ticketCipher);
220 }
221 
setupZeroRttOnClientCtx(fizz::client::FizzClientContext & clientCtx,std::string hostname)222 QuicCachedPsk setupZeroRttOnClientCtx(
223     fizz::client::FizzClientContext& clientCtx,
224     std::string hostname) {
225   clientCtx.setSendEarlyData(true);
226 
227   QuicCachedPsk quicCachedPsk;
228   auto& psk = quicCachedPsk.cachedPsk;
229   psk.psk = std::string("psk");
230   psk.secret = std::string("secret");
231   psk.type = fizz::PskType::Resumption;
232   psk.version = clientCtx.getSupportedVersions()[0];
233   psk.cipher = clientCtx.getSupportedCiphers()[0];
234   psk.group = clientCtx.getSupportedGroups()[0];
235   auto mockCert = std::make_shared<NiceMock<fizz::test::MockCert>>();
236   ON_CALL(*mockCert, getIdentity()).WillByDefault(Return(hostname));
237   psk.serverCert = mockCert;
238   psk.alpn = clientCtx.getSupportedAlpns()[0];
239   psk.ticketAgeAdd = 1;
240   psk.ticketIssueTime = std::chrono::system_clock::time_point();
241   psk.ticketExpirationTime =
242       std::chrono::system_clock::time_point(std::chrono::minutes(100));
243   psk.ticketHandshakeTime = std::chrono::system_clock::time_point();
244   psk.maxEarlyDataSize = 2;
245 
246   quicCachedPsk.transportParams.idleTimeout = kDefaultIdleTimeout.count();
247   quicCachedPsk.transportParams.maxRecvPacketSize = kDefaultUDPReadBufferSize;
248   quicCachedPsk.transportParams.initialMaxData = kDefaultConnectionWindowSize;
249   quicCachedPsk.transportParams.initialMaxStreamDataBidiLocal =
250       kDefaultStreamWindowSize;
251   quicCachedPsk.transportParams.initialMaxStreamDataBidiRemote =
252       kDefaultStreamWindowSize;
253   quicCachedPsk.transportParams.initialMaxStreamDataUni =
254       kDefaultStreamWindowSize;
255   quicCachedPsk.transportParams.initialMaxStreamsBidi =
256       kDefaultMaxStreamsBidirectional;
257   quicCachedPsk.transportParams.initialMaxStreamsUni =
258       kDefaultMaxStreamsUnidirectional;
259   return quicCachedPsk;
260 }
261 
setupCtxWithTestCert(fizz::server::FizzServerContext & ctx)262 void setupCtxWithTestCert(fizz::server::FizzServerContext& ctx) {
263   auto cert = readCert();
264   auto certManager = std::make_unique<fizz::server::CertManager>();
265   certManager->addCert(std::move(cert), true);
266   ctx.setCertManager(std::move(certManager));
267 }
268 
createNoOpAead()269 std::unique_ptr<MockAead> createNoOpAead() {
270   return createNoOpAeadImpl<MockAead>();
271 }
272 
createNoOpHeaderCipher()273 std::unique_ptr<MockPacketNumberCipher> createNoOpHeaderCipher() {
274   auto headerCipher = std::make_unique<NiceMock<MockPacketNumberCipher>>();
275   ON_CALL(*headerCipher, mask(_)).WillByDefault(Return(HeaderProtectionMask{}));
276   ON_CALL(*headerCipher, keyLength()).WillByDefault(Return(16));
277   return headerCipher;
278 }
279 
createStreamPacket(ConnectionId srcConnId,ConnectionId dstConnId,PacketNum packetNum,StreamId streamId,folly::IOBuf & data,uint8_t cipherOverhead,PacketNum largestAcked,folly::Optional<std::pair<LongHeader::Types,QuicVersion>> longHeaderOverride,bool eof,folly::Optional<ProtectionType> shortHeaderOverride,uint64_t offset,uint64_t packetSizeLimit)280 RegularQuicPacketBuilder::Packet createStreamPacket(
281     ConnectionId srcConnId,
282     ConnectionId dstConnId,
283     PacketNum packetNum,
284     StreamId streamId,
285     folly::IOBuf& data,
286     uint8_t cipherOverhead,
287     PacketNum largestAcked,
288     folly::Optional<std::pair<LongHeader::Types, QuicVersion>>
289         longHeaderOverride,
290     bool eof,
291     folly::Optional<ProtectionType> shortHeaderOverride,
292     uint64_t offset,
293     uint64_t packetSizeLimit) {
294   std::unique_ptr<RegularQuicPacketBuilder> builder;
295   if (longHeaderOverride) {
296     LongHeader header(
297         longHeaderOverride->first,
298         srcConnId,
299         dstConnId,
300         packetNum,
301         longHeaderOverride->second);
302     builder.reset(new RegularQuicPacketBuilder(
303         packetSizeLimit, std::move(header), largestAcked));
304   } else {
305     ProtectionType protectionType = ProtectionType::KeyPhaseZero;
306     if (shortHeaderOverride) {
307       protectionType = *shortHeaderOverride;
308     }
309     ShortHeader header(protectionType, dstConnId, packetNum);
310     builder.reset(new RegularQuicPacketBuilder(
311         packetSizeLimit, std::move(header), largestAcked));
312   }
313   builder->encodePacketHeader();
314   builder->accountForCipherOverhead(cipherOverhead);
315   auto dataLen = *writeStreamFrameHeader(
316       *builder,
317       streamId,
318       offset,
319       data.computeChainDataLength(),
320       data.computeChainDataLength(),
321       eof,
322       folly::none /* skipLenHint */);
323   writeStreamFrameData(
324       *builder,
325       data.clone(),
326       std::min(folly::to<size_t>(dataLen), data.computeChainDataLength()));
327   return std::move(*builder).buildPacket();
328 }
329 
createInitialCryptoPacket(ConnectionId srcConnId,ConnectionId dstConnId,PacketNum packetNum,QuicVersion version,folly::IOBuf & data,const Aead & aead,PacketNum largestAcked,uint64_t offset,const BuilderProvider & builderProvider)330 RegularQuicPacketBuilder::Packet createInitialCryptoPacket(
331     ConnectionId srcConnId,
332     ConnectionId dstConnId,
333     PacketNum packetNum,
334     QuicVersion version,
335     folly::IOBuf& data,
336     const Aead& aead,
337     PacketNum largestAcked,
338     uint64_t offset,
339     const BuilderProvider& builderProvider) {
340   LongHeader header(
341       LongHeader::Types::Initial, srcConnId, dstConnId, packetNum, version);
342   LongHeader copyHeader(header);
343   PacketBuilderInterface* builder = nullptr;
344   if (builderProvider) {
345     builder = builderProvider(std::move(header), largestAcked);
346   }
347   RegularQuicPacketBuilder fallbackBuilder(
348       kDefaultUDPSendPacketLen, std::move(copyHeader), largestAcked);
349   if (!builder) {
350     builder = &fallbackBuilder;
351   }
352   builder->encodePacketHeader();
353   builder->accountForCipherOverhead(aead.getCipherOverhead());
354   writeCryptoFrame(offset, data.clone(), *builder);
355   return std::move(*builder).buildPacket();
356 }
357 
createCryptoPacket(ConnectionId srcConnId,ConnectionId dstConnId,PacketNum packetNum,QuicVersion version,ProtectionType protectionType,folly::IOBuf & data,const Aead & aead,PacketNum largestAcked,uint64_t offset,uint64_t packetSizeLimit)358 RegularQuicPacketBuilder::Packet createCryptoPacket(
359     ConnectionId srcConnId,
360     ConnectionId dstConnId,
361     PacketNum packetNum,
362     QuicVersion version,
363     ProtectionType protectionType,
364     folly::IOBuf& data,
365     const Aead& aead,
366     PacketNum largestAcked,
367     uint64_t offset,
368     uint64_t packetSizeLimit) {
369   folly::Optional<PacketHeader> header;
370   switch (protectionType) {
371     case ProtectionType::Initial:
372       header = LongHeader(
373           LongHeader::Types::Initial, srcConnId, dstConnId, packetNum, version);
374       break;
375     case ProtectionType::Handshake:
376       header = LongHeader(
377           LongHeader::Types::Handshake,
378           srcConnId,
379           dstConnId,
380           packetNum,
381           version);
382       break;
383     case ProtectionType::ZeroRtt:
384       header = LongHeader(
385           LongHeader::Types::ZeroRtt, srcConnId, dstConnId, packetNum, version);
386       break;
387     case ProtectionType::KeyPhaseOne:
388     case ProtectionType::KeyPhaseZero:
389       header = ShortHeader(protectionType, dstConnId, packetNum);
390       break;
391   }
392   RegularQuicPacketBuilder builder(
393       packetSizeLimit, std::move(*header), largestAcked);
394   builder.encodePacketHeader();
395   builder.accountForCipherOverhead(aead.getCipherOverhead());
396   writeCryptoFrame(offset, data.clone(), builder);
397   return std::move(builder).buildPacket();
398 }
399 
packetToBuf(const RegularQuicPacketBuilder::Packet & packet)400 Buf packetToBuf(const RegularQuicPacketBuilder::Packet& packet) {
401   auto packetBuf = packet.header->clone();
402   if (packet.body) {
403     packetBuf->prependChain(packet.body->clone());
404   }
405   return packetBuf;
406 }
407 
packetToBufCleartext(const RegularQuicPacketBuilder::Packet & packet,const Aead & cleartextCipher,const PacketNumberCipher & headerCipher,PacketNum packetNum)408 Buf packetToBufCleartext(
409     const RegularQuicPacketBuilder::Packet& packet,
410     const Aead& cleartextCipher,
411     const PacketNumberCipher& headerCipher,
412     PacketNum packetNum) {
413   VLOG(10) << __func__ << " packet header: "
414            << folly::hexlify(packet.header->clone()->moveToFbString());
415   auto packetBuf = packet.header->clone();
416   Buf body;
417   if (packet.body) {
418     packet.body->coalesce();
419     body = packet.body->clone();
420   } else {
421     body = folly::IOBuf::create(0);
422   }
423   auto headerForm = packet.packet.header.getHeaderForm();
424   packet.header->coalesce();
425   auto tagLen = cleartextCipher.getCipherOverhead();
426   if (body->tailroom() < tagLen) {
427     body->prependChain(folly::IOBuf::create(tagLen));
428   }
429   body->coalesce();
430   auto encryptedBody = cleartextCipher.inplaceEncrypt(
431       std::move(body), packet.header.get(), packetNum);
432   encryptedBody->coalesce();
433   encryptPacketHeader(
434       headerForm,
435       packet.header->writableData(),
436       packet.header->length(),
437       encryptedBody->data(),
438       encryptedBody->length(),
439       headerCipher);
440   packetBuf->prependChain(std::move(encryptedBody));
441   return packetBuf;
442 }
443 
// Rounds `ackDelay` (in microseconds) down to a multiple of
// 2^ackDelayExponent, i.e. the value that survives being shifted right and
// then left again by the exponent.
//
// Exponents of 64 or more would make the 64-bit shifts undefined behaviour;
// for those the mathematically consistent result, 0, is returned.
uint64_t computeExpectedDelay(
    std::chrono::microseconds ackDelay,
    uint8_t ackDelayExponent) {
  constexpr uint8_t kDelayBits = 64;
  if (ackDelayExponent >= kDelayBits) {
    return 0;
  }
  const auto rawDelay = static_cast<uint64_t>(ackDelay.count());
  const uint64_t scaled = rawDelay >> ackDelayExponent;
  return scaled << ackDelayExponent;
}
450 
getTestConnectionId(uint32_t hostId,ConnectionIdVersion version)451 ConnectionId getTestConnectionId(uint32_t hostId, ConnectionIdVersion version) {
452   ServerConnectionIdParams params(version, hostId, 0, 0);
453   DefaultConnectionIdAlgo connIdAlgo;
454   auto connId = *connIdAlgo.encodeConnectionId(params);
455   // Clear random part of CID, some existing tests expect same CID value
456   // when repeatedly calling with the same hostId.
457   if (version == ConnectionIdVersion::V1) {
458     connId.data()[3] = 3;
459     connId.data()[4] = 4;
460     connId.data()[5] = 5;
461     connId.data()[6] = 6;
462     connId.data()[7] = 7;
463   } else if (version == ConnectionIdVersion::V2) {
464     connId.data()[0] &= 0xC0;
465     connId.data()[5] = 5;
466     connId.data()[6] = 6;
467     connId.data()[7] = 7;
468   } else {
469     CHECK(false) << "Unsupported CID version";
470   }
471 
472   return connId;
473 }
474 
encryptionLevelToProtectionType(fizz::EncryptionLevel encryptionLevel)475 ProtectionType encryptionLevelToProtectionType(
476     fizz::EncryptionLevel encryptionLevel) {
477   switch (encryptionLevel) {
478     case fizz::EncryptionLevel::Plaintext:
479       return ProtectionType::Initial;
480     case fizz::EncryptionLevel::Handshake:
481       // TODO: change this in draft-14
482       return ProtectionType::Initial;
483     case fizz::EncryptionLevel::EarlyData:
484       return ProtectionType::ZeroRtt;
485     case fizz::EncryptionLevel::AppTraffic:
486       return ProtectionType::KeyPhaseZero;
487   }
488   folly::assume_unreachable();
489 }
490 
updateAckState(QuicConnectionStateBase & conn,PacketNumberSpace pnSpace,PacketNum packetNum,bool pkHasRetransmittableData,bool pkHasCryptoData,TimePoint receivedTime)491 void updateAckState(
492     QuicConnectionStateBase& conn,
493     PacketNumberSpace pnSpace,
494     PacketNum packetNum,
495     bool pkHasRetransmittableData,
496     bool pkHasCryptoData,
497     TimePoint receivedTime) {
498   bool outOfOrder = updateLargestReceivedPacketNum(
499       getAckState(conn, pnSpace), packetNum, receivedTime);
500   updateAckSendStateOnRecvPacket(
501       conn,
502       getAckState(conn, pnSpace),
503       outOfOrder,
504       pkHasRetransmittableData,
505       pkHasCryptoData);
506 }
507 
buildRandomInputData(size_t length)508 std::unique_ptr<folly::IOBuf> buildRandomInputData(size_t length) {
509   auto buf = folly::IOBuf::create(length);
510   buf->append(length);
511   folly::Random::secureRandom(buf->writableData(), buf->length());
512   return buf;
513 }
514 
addAckStatesWithCurrentTimestamps(AckState & ackState,PacketNum start,PacketNum end)515 void addAckStatesWithCurrentTimestamps(
516     AckState& ackState,
517     PacketNum start,
518     PacketNum end) {
519   ackState.acks.insert(start, end);
520   ackState.largestRecvdPacketTime = Clock::now();
521 }
522 
makeTestingWritePacket(PacketNum desiredPacketSeqNum,size_t desiredSize,uint64_t totalBytesSent,TimePoint sentTime,uint64_t inflightBytes,uint64_t writeCount)523 OutstandingPacket makeTestingWritePacket(
524     PacketNum desiredPacketSeqNum,
525     size_t desiredSize,
526     uint64_t totalBytesSent,
527     TimePoint sentTime /* = Clock::now() */,
528     uint64_t inflightBytes /* = 0 */,
529     uint64_t writeCount /* = 0 */) {
530   LongHeader longHeader(
531       LongHeader::Types::ZeroRtt,
532       getTestConnectionId(1),
533       getTestConnectionId(),
534       desiredPacketSeqNum,
535       QuicVersion::MVFST);
536   RegularQuicWritePacket packet(std::move(longHeader));
537   return OutstandingPacket(
538       packet,
539       sentTime,
540       desiredSize,
541       0,
542       false,
543       totalBytesSent,
544       0,
545       inflightBytes,
546       0,
547       LossState(),
548       writeCount);
549 }
550 
makeAck(PacketNum seq,uint64_t ackedSize,TimePoint ackedTime,TimePoint sentTime)551 CongestionController::AckEvent makeAck(
552     PacketNum seq,
553     uint64_t ackedSize,
554     TimePoint ackedTime,
555     TimePoint sentTime) {
556   CHECK(sentTime < ackedTime);
557   RegularQuicWritePacket packet(
558       ShortHeader(ProtectionType::KeyPhaseZero, getTestConnectionId(), seq));
559   CongestionController::AckEvent ack;
560   ack.ackedBytes = ackedSize;
561   ack.ackTime = ackedTime;
562   ack.largestAckedPacket = seq;
563   ack.ackedPackets.emplace_back(
564       CongestionController::AckEvent::AckPacket::Builder()
565           .setSentTime(sentTime)
566           .setEncodedSize(ackedSize)
567           .build());
568   ack.largestAckedPacketSentTime = sentTime;
569   return ack;
570 }
571 
bufToQueue(Buf buf)572 BufQueue bufToQueue(Buf buf) {
573   BufQueue queue;
574   buf->coalesce();
575   queue.append(std::move(buf));
576   return queue;
577 }
578 
generateStatelessResetToken()579 StatelessResetToken generateStatelessResetToken() {
580   StatelessResetSecret secret;
581   folly::Random::secureRandom(secret.data(), secret.size());
582   folly::SocketAddress address("1.2.3.4", 8080);
583   StatelessResetGenerator generator(secret, address.getFullyQualified());
584 
585   return generator.generateToken(ConnectionId({0x14, 0x35, 0x22, 0x11}));
586 }
587 
getRandSecret()588 std::array<uint8_t, kStatelessResetTokenSecretLength> getRandSecret() {
589   std::array<uint8_t, kStatelessResetTokenSecretLength> secret;
590   folly::Random::secureRandom(secret.data(), secret.size());
591   return secret;
592 }
593 
createNewPacket(PacketNum packetNum,PacketNumberSpace pnSpace)594 RegularQuicWritePacket createNewPacket(
595     PacketNum packetNum,
596     PacketNumberSpace pnSpace) {
597   switch (pnSpace) {
598     case PacketNumberSpace::Initial:
599       return RegularQuicWritePacket(LongHeader(
600           LongHeader::Types::Initial,
601           getTestConnectionId(1),
602           getTestConnectionId(2),
603           packetNum,
604           QuicVersion::QUIC_DRAFT));
605     case PacketNumberSpace::Handshake:
606       return RegularQuicWritePacket(LongHeader(
607           LongHeader::Types::Handshake,
608           getTestConnectionId(0),
609           getTestConnectionId(4),
610           packetNum,
611           QuicVersion::QUIC_DRAFT));
612     case PacketNumberSpace::AppData:
613       return RegularQuicWritePacket(ShortHeader(
614           ProtectionType::KeyPhaseOne, getTestConnectionId(), packetNum));
615   }
616 
617   folly::assume_unreachable();
618 }
619 
versionList(std::initializer_list<QuicVersionType> types)620 std::vector<QuicVersion> versionList(
621     std::initializer_list<QuicVersionType> types) {
622   std::vector<QuicVersion> versions;
623   for (auto type : types) {
624     versions.push_back(static_cast<QuicVersion>(type));
625   }
626   return versions;
627 }
628 
createRegularQuicWritePacket(StreamId streamId,uint64_t offset,uint64_t len,bool fin)629 RegularQuicWritePacket createRegularQuicWritePacket(
630     StreamId streamId,
631     uint64_t offset,
632     uint64_t len,
633     bool fin) {
634   auto regularWritePacket = createNewPacket(10, PacketNumberSpace::Initial);
635   WriteStreamFrame frame(streamId, offset, len, fin);
636   regularWritePacket.frames.emplace_back(frame);
637   return regularWritePacket;
638 }
639 
createVersionNegotiationPacket()640 VersionNegotiationPacket createVersionNegotiationPacket() {
641   auto versions = {QuicVersion::VERSION_NEGOTIATION, QuicVersion::MVFST};
642   auto packet = VersionNegotiationPacketBuilder(
643                     getTestConnectionId(0), getTestConnectionId(1), versions)
644                     .buildPacket()
645                     .first;
646   return packet;
647 }
648 
createPacketWithAckFrames()649 RegularQuicWritePacket createPacketWithAckFrames() {
650   RegularQuicWritePacket packet =
651       createNewPacket(100, PacketNumberSpace::Initial);
652   WriteAckFrame ackFrame;
653   ackFrame.ackDelay = 111us;
654   ackFrame.ackBlocks.emplace_back(900, 1000);
655   ackFrame.ackBlocks.emplace_back(500, 700);
656 
657   packet.frames.emplace_back(std::move(ackFrame));
658   return packet;
659 }
660 
createPacketWithPaddingFrames()661 RegularQuicWritePacket createPacketWithPaddingFrames() {
662   RegularQuicWritePacket packet =
663       createNewPacket(100, PacketNumberSpace::Initial);
664   for (int i = 0; i < 20; ++i) {
665     PaddingFrame paddingFrame;
666     packet.frames.emplace_back(paddingFrame);
667   }
668   return packet;
669 }
670 
getQLogEventIndices(QLogEventType type,const std::shared_ptr<FileQLogger> & q)671 std::vector<int> getQLogEventIndices(
672     QLogEventType type,
673     const std::shared_ptr<FileQLogger>& q) {
674   std::vector<int> indices;
675   for (uint64_t i = 0; i < q->logs.size(); ++i) {
676     if (q->logs[i]->eventType == type) {
677       indices.push_back(i);
678     }
679   }
680   return indices;
681 }
682 
matchError(std::pair<QuicErrorCode,folly::Optional<folly::StringPiece>> errorCode,LocalErrorCode error)683 bool matchError(
684     std::pair<QuicErrorCode, folly::Optional<folly::StringPiece>> errorCode,
685     LocalErrorCode error) {
686   return errorCode.first.type() == QuicErrorCode::Type::LocalErrorCode &&
687       *errorCode.first.asLocalErrorCode() == error;
688 }
689 
matchError(std::pair<QuicErrorCode,folly::Optional<folly::StringPiece>> errorCode,TransportErrorCode error)690 bool matchError(
691     std::pair<QuicErrorCode, folly::Optional<folly::StringPiece>> errorCode,
692     TransportErrorCode error) {
693   return errorCode.first.type() == QuicErrorCode::Type::TransportErrorCode &&
694       *errorCode.first.asTransportErrorCode() == error;
695 }
696 
matchError(std::pair<QuicErrorCode,folly::Optional<folly::StringPiece>> errorCode,ApplicationErrorCode error)697 bool matchError(
698     std::pair<QuicErrorCode, folly::Optional<folly::StringPiece>> errorCode,
699     ApplicationErrorCode error) {
700   return errorCode.first.type() == QuicErrorCode::Type::ApplicationErrorCode &&
701       *errorCode.first.asApplicationErrorCode() == error;
702 }
703 
matchError(std::pair<QuicErrorCode,std::string> errorCode,ApplicationErrorCode error)704 bool matchError(
705     std::pair<QuicErrorCode, std::string> errorCode,
706     ApplicationErrorCode error) {
707   return errorCode.first.type() == QuicErrorCode::Type::ApplicationErrorCode &&
708       *errorCode.first.asApplicationErrorCode() == error;
709 }
710 
matchError(std::pair<QuicErrorCode,std::string> errorCode,TransportErrorCode error)711 bool matchError(
712     std::pair<QuicErrorCode, std::string> errorCode,
713     TransportErrorCode error) {
714   return errorCode.first.type() == QuicErrorCode::Type::TransportErrorCode &&
715       *errorCode.first.asTransportErrorCode() == error;
716 }
717 
makeAckPacketFromOutstandingPacket(OutstandingPacket outstandingPacket)718 CongestionController::AckEvent::AckPacket makeAckPacketFromOutstandingPacket(
719     OutstandingPacket outstandingPacket) {
720   return CongestionController::AckEvent::AckPacket::Builder()
721       .setSentTime(outstandingPacket.metadata.time)
722       .setEncodedSize(outstandingPacket.metadata.encodedSize)
723       .setLastAckedPacketInfo(std::move(outstandingPacket.lastAckedPacketInfo))
724       .setTotalBytesSentThen(outstandingPacket.metadata.totalBytesSent)
725       .setAppLimited(outstandingPacket.isAppLimited)
726       .build();
727 }
728 
729 folly::Optional<WriteCryptoFrame>
writeCryptoFrame(uint64_t offsetIn,Buf data,PacketBuilderInterface & builder)730 writeCryptoFrame(uint64_t offsetIn, Buf data, PacketBuilderInterface& builder) {
731   BufQueue bufQueue(std::move(data));
732   return writeCryptoFrame(offsetIn, bufQueue, builder);
733 }
734 
overridePacketWithToken(PacketBuilderInterface::Packet & packet,const StatelessResetToken & token)735 void overridePacketWithToken(
736     PacketBuilderInterface::Packet& packet,
737     const StatelessResetToken& token) {
738   overridePacketWithToken(*packet.body, token);
739 }
740 
overridePacketWithToken(folly::IOBuf & bodyBuf,const StatelessResetToken & token)741 void overridePacketWithToken(
742     folly::IOBuf& bodyBuf,
743     const StatelessResetToken& token) {
744   bodyBuf.coalesce();
745   CHECK(bodyBuf.length() > sizeof(StatelessResetToken));
746   memcpy(
747       bodyBuf.writableData() + bodyBuf.length() - sizeof(StatelessResetToken),
748       token.data(),
749       token.size());
750 }
751 
writableContains(QuicStreamManager & streamManager,StreamId streamId)752 bool writableContains(QuicStreamManager& streamManager, StreamId streamId) {
753   return streamManager.writableStreams().count(streamId) > 0 ||
754       streamManager.writableControlStreams().count(streamId) > 0;
755 }
756 
757 std::unique_ptr<PacketNumberCipher>
makePacketNumberCipher(fizz::CipherSuite) const758 FizzCryptoTestFactory::makePacketNumberCipher(fizz::CipherSuite) const {
759   return std::move(packetNumberCipher_);
760 }
761 
762 std::unique_ptr<PacketNumberCipher>
makePacketNumberCipher(folly::ByteRange secret) const763 FizzCryptoTestFactory::makePacketNumberCipher(folly::ByteRange secret) const {
764   return _makePacketNumberCipher(secret);
765 }
766 
setMockPacketNumberCipher(std::unique_ptr<PacketNumberCipher> packetNumberCipher)767 void FizzCryptoTestFactory::setMockPacketNumberCipher(
768     std::unique_ptr<PacketNumberCipher> packetNumberCipher) {
769   packetNumberCipher_ = std::move(packetNumberCipher);
770 }
771 
setDefault()772 void FizzCryptoTestFactory::setDefault() {
773   ON_CALL(*this, _makePacketNumberCipher(_))
774       .WillByDefault(Invoke([&](folly::ByteRange secret) {
775         return FizzCryptoFactory::makePacketNumberCipher(secret);
776       }));
777 }
778 
reset()779 void TestPacketBatchWriter::reset() {
780   bufNum_ = 0;
781   bufSize_ = 0;
782 }
783 
append(std::unique_ptr<folly::IOBuf> &&,size_t size,const folly::SocketAddress &,folly::AsyncUDPSocket *)784 bool TestPacketBatchWriter::append(
785     std::unique_ptr<folly::IOBuf>&& /*unused*/,
786     size_t size,
787     const folly::SocketAddress& /*unused*/,
788     folly::AsyncUDPSocket* /*unused*/) {
789   bufNum_++;
790   bufSize_ += size;
791   return ((maxBufs_ < 0) || (bufNum_ >= maxBufs_));
792 }
793 
write(folly::AsyncUDPSocket &,const folly::SocketAddress &)794 ssize_t TestPacketBatchWriter::write(
795     folly::AsyncUDPSocket& /*unused*/,
796     const folly::SocketAddress& /*unused*/) {
797   return bufSize_;
798 }
799 
getQuicTestKey()800 TrafficKey getQuicTestKey() {
801   TrafficKey testKey;
802   testKey.key = folly::IOBuf::copyBuffer(
803       folly::unhexlify("000102030405060708090A0B0C0D0E0F"));
804   testKey.iv =
805       folly::IOBuf::copyBuffer(folly::unhexlify("000102030405060708090A0B"));
806   return testKey;
807 }
808 
getProtectionKey()809 std::unique_ptr<folly::IOBuf> getProtectionKey() {
810   FizzCryptoFactory factory;
811   auto secret = folly::range(getRandSecret());
812   auto pnCipher =
813       factory.makePacketNumberCipher(fizz::CipherSuite::TLS_AES_128_GCM_SHA256);
814   auto deriver = factory.getFizzFactory()->makeKeyDeriver(
815       fizz::CipherSuite::TLS_AES_128_GCM_SHA256);
816   return deriver->expandLabel(
817       secret, kQuicPNLabel, folly::IOBuf::create(0), pnCipher->keyLength());
818 }
819 } // namespace test
820 } // namespace quic
821