1 /*
2  *  Copyright (c) 2018-present, Facebook, Inc.
3  *  All rights reserved.
4  *
5  *  This source code is licensed under the BSD-style license found in the
6  *  LICENSE file in the root directory of this source tree.
7  */
8 
9 #pragma once
10 
11 #include <fizz/crypto/aead/test/Mocks.h>
12 #include <fizz/crypto/exchange/test/Mocks.h>
13 #include <fizz/crypto/test/Mocks.h>
14 #include <fizz/protocol/AsyncFizzBase.h>
15 #include <fizz/protocol/Certificate.h>
16 #include <fizz/protocol/CertificateCompressor.h>
17 #include <fizz/protocol/CertificateVerifier.h>
18 #include <fizz/protocol/HandshakeContext.h>
19 #include <fizz/protocol/KeyScheduler.h>
20 #include <fizz/protocol/OpenSSLFactory.h>
21 #include <fizz/protocol/Types.h>
22 #include <fizz/record/test/Mocks.h>
23 
24 #include <folly/io/async/test/MockAsyncTransport.h>
25 
26 namespace fizz {
27 namespace test {
28 
29 /* using override */
30 using namespace testing;
31 
32 class MockKeyScheduler : public KeyScheduler {
33  public:
MockKeyScheduler()34   MockKeyScheduler() : KeyScheduler(std::make_unique<MockKeyDerivation>()) {}
35 
36   MOCK_METHOD1(deriveEarlySecret, void(folly::ByteRange psk));
37   MOCK_METHOD0(deriveHandshakeSecret, void());
38   MOCK_METHOD1(deriveHandshakeSecret, void(folly::ByteRange ecdhe));
39   MOCK_METHOD0(deriveMasterSecret, void());
40   MOCK_METHOD1(deriveAppTrafficSecrets, void(folly::ByteRange transcript));
41   MOCK_METHOD0(clearMasterSecret, void());
42   MOCK_METHOD0(clientKeyUpdate, uint32_t());
43   MOCK_METHOD0(serverKeyUpdate, uint32_t());
44   MOCK_CONST_METHOD2(
45       getSecret,
46       DerivedSecret(EarlySecrets s, folly::ByteRange transcript));
47   MOCK_CONST_METHOD2(
48       getSecret,
49       DerivedSecret(HandshakeSecrets s, folly::ByteRange transcript));
50   MOCK_CONST_METHOD2(
51       getSecret,
52       DerivedSecret(MasterSecrets s, folly::ByteRange transcript));
53   MOCK_CONST_METHOD1(getSecret, DerivedSecret(AppTrafficSecrets s));
54   MOCK_CONST_METHOD3(
55       getTrafficKey,
56       TrafficKey(
57           folly::ByteRange trafficSecret,
58           size_t keyLength,
59           size_t ivLength));
60   MOCK_CONST_METHOD2(
61       getResumptionSecret,
62       Buf(folly::ByteRange, folly::ByteRange));
63 
setDefaults()64   void setDefaults() {
65     ON_CALL(*this, getTrafficKey(_, _, _))
66         .WillByDefault(InvokeWithoutArgs([]() {
67           return TrafficKey{
68               folly::IOBuf::copyBuffer("key"), folly::IOBuf::copyBuffer("iv")};
69         }));
70     ON_CALL(*this, getResumptionSecret(_, _))
71         .WillByDefault(InvokeWithoutArgs(
72             []() { return folly::IOBuf::copyBuffer("resumesecret"); }));
73     ON_CALL(*this, getSecret(An<EarlySecrets>(), _))
74         .WillByDefault(Invoke([](EarlySecrets type, folly::ByteRange) {
75           return DerivedSecret(std::vector<uint8_t>(), type);
76         }));
77     ON_CALL(*this, getSecret(An<HandshakeSecrets>(), _))
78         .WillByDefault(Invoke([](HandshakeSecrets type, folly::ByteRange) {
79           return DerivedSecret(std::vector<uint8_t>(), type);
80         }));
81     ON_CALL(*this, getSecret(An<MasterSecrets>(), _))
82         .WillByDefault(Invoke([](MasterSecrets type, folly::ByteRange) {
83           return DerivedSecret(std::vector<uint8_t>(), type);
84         }));
85     ON_CALL(*this, getSecret(_))
86         .WillByDefault(Invoke([](AppTrafficSecrets type) {
87           return DerivedSecret(std::vector<uint8_t>(), type);
88         }));
89   }
90 };
91 
92 class MockHandshakeContext : public HandshakeContext {
93  public:
94   MOCK_METHOD1(appendToTranscript, void(const Buf& transcript));
95   MOCK_CONST_METHOD0(getHandshakeContext, Buf());
96   MOCK_CONST_METHOD1(getFinishedData, Buf(folly::ByteRange baseKey));
97   MOCK_CONST_METHOD0(getBlankContext, folly::ByteRange());
98   MOCK_CONST_METHOD0(clone, std::unique_ptr<HandshakeContext>());
99 
setDefaults()100   void setDefaults() {
101     ON_CALL(*this, getHandshakeContext()).WillByDefault(InvokeWithoutArgs([]() {
102       return folly::IOBuf::copyBuffer("context");
103     }));
104 
105     ON_CALL(*this, getFinishedData(_)).WillByDefault(InvokeWithoutArgs([]() {
106       return folly::IOBuf::copyBuffer("verifydata");
107     }));
108 
109     ON_CALL(*this, clone()).WillByDefault(InvokeWithoutArgs([]() {
110       return std::make_unique<MockHandshakeContext>();
111     }));
112   }
113 };
114 
115 class MockCert : public Cert {
116  public:
117   MOCK_CONST_METHOD0(getIdentity, std::string());
118   MOCK_CONST_METHOD0(getX509, folly::ssl::X509UniquePtr());
119 };
120 
121 class MockSelfCert : public SelfCert {
122  public:
123   MOCK_CONST_METHOD0(getIdentity, std::string());
124   MOCK_CONST_METHOD0(getAltIdentities, std::vector<std::string>());
125   MOCK_CONST_METHOD0(getSigSchemes, std::vector<SignatureScheme>());
126 
127   MOCK_CONST_METHOD1(_getCertMessage, CertificateMsg(Buf&));
getCertMessage(Buf buf)128   CertificateMsg getCertMessage(Buf buf) const override {
129     return _getCertMessage(buf);
130   }
131   MOCK_CONST_METHOD1(
132       getCompressedCert,
133       CompressedCertificate(CertificateCompressionAlgorithm));
134 
135   MOCK_CONST_METHOD3(
136       sign,
137       Buf(SignatureScheme scheme,
138           CertificateVerifyContext context,
139           folly::ByteRange toBeSigned));
140   MOCK_CONST_METHOD0(getX509, folly::ssl::X509UniquePtr());
141 };
142 
143 class MockPeerCert : public PeerCert {
144  public:
145   MOCK_CONST_METHOD0(getIdentity, std::string());
146   MOCK_CONST_METHOD4(
147       verify,
148       void(
149           SignatureScheme scheme,
150           CertificateVerifyContext context,
151           folly::ByteRange toBeSigned,
152           folly::ByteRange signature));
153   MOCK_CONST_METHOD0(getX509, folly::ssl::X509UniquePtr());
154 };
155 
156 class MockCertificateVerifier : public CertificateVerifier {
157  public:
158   MOCK_CONST_METHOD1(
159       verify,
160       void(const std::vector<std::shared_ptr<const PeerCert>>&));
161 
162   MOCK_CONST_METHOD0(getCertificateRequestExtensions, std::vector<Extension>());
163 };
164 
165 class MockFactory : public OpenSSLFactory {
166  public:
167   MOCK_CONST_METHOD0(
168       makePlaintextReadRecordLayer,
169       std::unique_ptr<PlaintextReadRecordLayer>());
170   MOCK_CONST_METHOD0(
171       makePlaintextWriteRecordLayer,
172       std::unique_ptr<PlaintextWriteRecordLayer>());
173   MOCK_CONST_METHOD1(
174       makeEncryptedReadRecordLayer,
175       std::unique_ptr<EncryptedReadRecordLayer>(
176           EncryptionLevel encryptionLevel));
177   MOCK_CONST_METHOD1(
178       makeEncryptedWriteRecordLayer,
179       std::unique_ptr<EncryptedWriteRecordLayer>(
180           EncryptionLevel encryptionLevel));
181   MOCK_CONST_METHOD1(
182       makeKeyScheduler,
183       std::unique_ptr<KeyScheduler>(CipherSuite cipher));
184   MOCK_CONST_METHOD1(
185       makeHandshakeContext,
186       std::unique_ptr<HandshakeContext>(CipherSuite cipher));
187   MOCK_CONST_METHOD1(
188       makeKeyExchange,
189       std::unique_ptr<KeyExchange>(NamedGroup group));
190   MOCK_CONST_METHOD1(makeAead, std::unique_ptr<Aead>(CipherSuite cipher));
191   MOCK_CONST_METHOD0(makeRandom, Random());
192   MOCK_CONST_METHOD0(makeTicketAgeAdd, uint32_t());
193 
194   MOCK_CONST_METHOD2(
195       _makePeerCert,
196       std::shared_ptr<PeerCert>(CertificateEntry& entry, bool leaf));
makePeerCert(CertificateEntry entry,bool leaf)197   std::shared_ptr<PeerCert> makePeerCert(CertificateEntry entry, bool leaf)
198       const override {
199     return _makePeerCert(entry, leaf);
200   }
201 
setDefaults()202   void setDefaults() {
203     ON_CALL(*this, makePlaintextReadRecordLayer())
204         .WillByDefault(InvokeWithoutArgs([]() {
205           return std::make_unique<NiceMock<MockPlaintextReadRecordLayer>>();
206         }));
207 
208     ON_CALL(*this, makePlaintextWriteRecordLayer())
209         .WillByDefault(InvokeWithoutArgs([]() {
210           auto ret =
211               std::make_unique<NiceMock<MockPlaintextWriteRecordLayer>>();
212           ret->setDefaults();
213           return ret;
214         }));
215     ON_CALL(*this, makeEncryptedReadRecordLayer(_))
216         .WillByDefault(Invoke([](EncryptionLevel encryptionLevel) {
217           return std::make_unique<NiceMock<MockEncryptedReadRecordLayer>>(
218               encryptionLevel);
219         }));
220 
221     ON_CALL(*this, makeEncryptedWriteRecordLayer(_))
222         .WillByDefault(Invoke([](EncryptionLevel encryptionLevel) {
223           auto ret = std::make_unique<NiceMock<MockEncryptedWriteRecordLayer>>(
224               encryptionLevel);
225           ret->setDefaults();
226           return ret;
227         }));
228 
229     ON_CALL(*this, makeKeyScheduler(_)).WillByDefault(InvokeWithoutArgs([]() {
230       auto ret = std::make_unique<NiceMock<MockKeyScheduler>>();
231       ret->setDefaults();
232       return ret;
233     }));
234     ON_CALL(*this, makeHandshakeContext(_))
235         .WillByDefault(InvokeWithoutArgs([]() {
236           auto ret = std::make_unique<NiceMock<MockHandshakeContext>>();
237           ret->setDefaults();
238           return ret;
239         }));
240     ON_CALL(*this, makeKeyExchange(_)).WillByDefault(InvokeWithoutArgs([]() {
241       auto ret = std::make_unique<NiceMock<MockKeyExchange>>();
242       ret->setDefaults();
243       return ret;
244     }));
245     ON_CALL(*this, makeAead(_)).WillByDefault(InvokeWithoutArgs([]() {
246       auto ret = std::make_unique<NiceMock<MockAead>>();
247       ret->setDefaults();
248       return ret;
249     }));
250     ON_CALL(*this, makeRandom()).WillByDefault(InvokeWithoutArgs([]() {
251       Random random;
252       random.fill(0x44);
253       return random;
254     }));
255     ON_CALL(*this, makeTicketAgeAdd()).WillByDefault(InvokeWithoutArgs([]() {
256       return 0x44444444;
257     }));
258     ON_CALL(*this, _makePeerCert(_, _)).WillByDefault(InvokeWithoutArgs([]() {
259       return std::make_unique<NiceMock<MockPeerCert>>();
260     }));
261   }
262 };
263 
264 class MockCertificateDecompressor : public CertificateDecompressor {
265  public:
266   MOCK_CONST_METHOD0(getAlgorithm, CertificateCompressionAlgorithm());
267   MOCK_METHOD1(decompress, CertificateMsg(const CompressedCertificate&));
setDefaults()268   void setDefaults() {
269     ON_CALL(*this, getAlgorithm()).WillByDefault(InvokeWithoutArgs([]() {
270       return CertificateCompressionAlgorithm::zlib;
271     }));
272   }
273 };
274 
275 class MockCertificateCompressor : public CertificateCompressor {
276  public:
277   MOCK_CONST_METHOD0(getAlgorithm, CertificateCompressionAlgorithm());
278   MOCK_METHOD1(compress, CompressedCertificate(const CertificateMsg&));
setDefaults()279   void setDefaults() {
280     ON_CALL(*this, getAlgorithm()).WillByDefault(InvokeWithoutArgs([]() {
281       return CertificateCompressionAlgorithm::zlib;
282     }));
283   }
284 };
285 
286 class MockAsyncFizzBase : public AsyncFizzBase {
287  public:
MockAsyncFizzBase()288   MockAsyncFizzBase()
289       : AsyncFizzBase(
290             folly::AsyncTransport::UniquePtr(
291                 new folly::test::MockAsyncTransport()),
292             AsyncFizzBase::TransportOptions()) {}
293   MOCK_CONST_METHOD0(good, bool());
294   MOCK_CONST_METHOD0(readable, bool());
295   MOCK_CONST_METHOD0(connecting, bool());
296   MOCK_CONST_METHOD0(error, bool());
297   MOCK_CONST_METHOD0(getPeerCert, folly::ssl::X509UniquePtr());
298   MOCK_CONST_METHOD0(getSelfCert, const X509*());
299   MOCK_CONST_METHOD0(isReplaySafe, bool());
300   MOCK_METHOD1(
301       setReplaySafetyCallback,
302       void(folly::AsyncTransport::ReplaySafetyCallback* callback));
303   MOCK_CONST_METHOD0(getSelfCertificate, const Cert*());
304   MOCK_CONST_METHOD0(getPeerCertificate, const Cert*());
305   MOCK_CONST_METHOD0(getApplicationProtocol_, std::string());
306 
307   MOCK_METHOD1(setReadCB, void(ReadCallback*));
308   MOCK_METHOD1(setEndOfTLSCallback, void(EndOfTLSCallback*));
309 
getApplicationProtocol()310   std::string getApplicationProtocol() const noexcept override {
311     return getApplicationProtocol_();
312   }
313 
314   MOCK_CONST_METHOD0(getCipher, folly::Optional<CipherSuite>());
315   MOCK_CONST_METHOD0(getSupportedSigSchemes, std::vector<SignatureScheme>());
316   MOCK_CONST_METHOD3(getEkm, Buf(folly::StringPiece, const Buf&, uint16_t));
317   MOCK_CONST_METHOD0(getClientRandom, folly::Optional<Random>());
318   MOCK_METHOD0(tlsShutdown, void());
319 
320   MOCK_METHOD3(
321       writeAppDataInternal,
322       void(
323           folly::AsyncTransport::WriteCallback*,
324           std::shared_ptr<folly::IOBuf>,
325           folly::WriteFlags));
326 
327   void writeAppData(
328       folly::AsyncTransport::WriteCallback* callback,
329       std::unique_ptr<folly::IOBuf>&& buf,
330       folly::WriteFlags flags = folly::WriteFlags::NONE) override {
331     writeAppDataInternal(
332         callback, std::shared_ptr<folly::IOBuf>(buf.release()), flags);
333   }
334 
335   MOCK_METHOD1(transportError, void(const folly::AsyncSocketException&));
336 
337   MOCK_METHOD0(transportDataAvailable, void());
338   MOCK_METHOD0(pauseEvents, void());
339   MOCK_METHOD0(resumeEvents, void());
340 };
341 
342 } // namespace test
343 } // namespace fizz
344