1 /* 2 * Copyright (c) 2018-present, Facebook, Inc. 3 * All rights reserved. 4 * 5 * This source code is licensed under the BSD-style license found in the 6 * LICENSE file in the root directory of this source tree. 7 */ 8 9 #pragma once 10 11 #include <fizz/crypto/aead/test/Mocks.h> 12 #include <fizz/crypto/exchange/test/Mocks.h> 13 #include <fizz/crypto/test/Mocks.h> 14 #include <fizz/protocol/AsyncFizzBase.h> 15 #include <fizz/protocol/Certificate.h> 16 #include <fizz/protocol/CertificateCompressor.h> 17 #include <fizz/protocol/CertificateVerifier.h> 18 #include <fizz/protocol/HandshakeContext.h> 19 #include <fizz/protocol/KeyScheduler.h> 20 #include <fizz/protocol/OpenSSLFactory.h> 21 #include <fizz/protocol/Types.h> 22 #include <fizz/record/test/Mocks.h> 23 24 #include <folly/io/async/test/MockAsyncTransport.h> 25 26 namespace fizz { 27 namespace test { 28 29 /* using override */ 30 using namespace testing; 31 32 class MockKeyScheduler : public KeyScheduler { 33 public: MockKeyScheduler()34 MockKeyScheduler() : KeyScheduler(std::make_unique<MockKeyDerivation>()) {} 35 36 MOCK_METHOD1(deriveEarlySecret, void(folly::ByteRange psk)); 37 MOCK_METHOD0(deriveHandshakeSecret, void()); 38 MOCK_METHOD1(deriveHandshakeSecret, void(folly::ByteRange ecdhe)); 39 MOCK_METHOD0(deriveMasterSecret, void()); 40 MOCK_METHOD1(deriveAppTrafficSecrets, void(folly::ByteRange transcript)); 41 MOCK_METHOD0(clearMasterSecret, void()); 42 MOCK_METHOD0(clientKeyUpdate, uint32_t()); 43 MOCK_METHOD0(serverKeyUpdate, uint32_t()); 44 MOCK_CONST_METHOD2( 45 getSecret, 46 DerivedSecret(EarlySecrets s, folly::ByteRange transcript)); 47 MOCK_CONST_METHOD2( 48 getSecret, 49 DerivedSecret(HandshakeSecrets s, folly::ByteRange transcript)); 50 MOCK_CONST_METHOD2( 51 getSecret, 52 DerivedSecret(MasterSecrets s, folly::ByteRange transcript)); 53 MOCK_CONST_METHOD1(getSecret, DerivedSecret(AppTrafficSecrets s)); 54 MOCK_CONST_METHOD3( 55 getTrafficKey, 56 TrafficKey( 57 folly::ByteRange trafficSecret, 58 size_t keyLength, 59 size_t ivLength)); 60 MOCK_CONST_METHOD2( 61 getResumptionSecret, 62 Buf(folly::ByteRange, folly::ByteRange)); 63 setDefaults()64 void setDefaults() { 65 ON_CALL(*this, getTrafficKey(_, _, _)) 66 .WillByDefault(InvokeWithoutArgs([]() { 67 return TrafficKey{ 68 folly::IOBuf::copyBuffer("key"), folly::IOBuf::copyBuffer("iv")}; 69 })); 70 ON_CALL(*this, getResumptionSecret(_, _)) 71 .WillByDefault(InvokeWithoutArgs( 72 []() { return folly::IOBuf::copyBuffer("resumesecret"); })); 73 ON_CALL(*this, getSecret(An<EarlySecrets>(), _)) 74 .WillByDefault(Invoke([](EarlySecrets type, folly::ByteRange) { 75 return DerivedSecret(std::vector<uint8_t>(), type); 76 })); 77 ON_CALL(*this, getSecret(An<HandshakeSecrets>(), _)) 78 .WillByDefault(Invoke([](HandshakeSecrets type, folly::ByteRange) { 79 return DerivedSecret(std::vector<uint8_t>(), type); 80 })); 81 ON_CALL(*this, getSecret(An<MasterSecrets>(), _)) 82 .WillByDefault(Invoke([](MasterSecrets type, folly::ByteRange) { 83 return DerivedSecret(std::vector<uint8_t>(), type); 84 })); 85 ON_CALL(*this, getSecret(_)) 86 .WillByDefault(Invoke([](AppTrafficSecrets type) { 87 return DerivedSecret(std::vector<uint8_t>(), type); 88 })); 89 } 90 }; 91 92 class MockHandshakeContext : public HandshakeContext { 93 public: 94 MOCK_METHOD1(appendToTranscript, void(const Buf& transcript)); 95 MOCK_CONST_METHOD0(getHandshakeContext, Buf()); 96 MOCK_CONST_METHOD1(getFinishedData, Buf(folly::ByteRange baseKey)); 97 MOCK_CONST_METHOD0(getBlankContext, folly::ByteRange()); 98 MOCK_CONST_METHOD0(clone, std::unique_ptr<HandshakeContext>()); 99 setDefaults()100 void setDefaults() { 101 ON_CALL(*this, getHandshakeContext()).WillByDefault(InvokeWithoutArgs([]() { 102 return folly::IOBuf::copyBuffer("context"); 103 })); 104 105 ON_CALL(*this, getFinishedData(_)).WillByDefault(InvokeWithoutArgs([]() { 106 return folly::IOBuf::copyBuffer("verifydata"); 107 })); 108 109 ON_CALL(*this, clone()).WillByDefault(InvokeWithoutArgs([]() { 110 return std::make_unique<MockHandshakeContext>(); 111 })); 112 } 113 }; 114 115 class MockCert : public Cert { 116 public: 117 MOCK_CONST_METHOD0(getIdentity, std::string()); 118 MOCK_CONST_METHOD0(getX509, folly::ssl::X509UniquePtr()); 119 }; 120 121 class MockSelfCert : public SelfCert { 122 public: 123 MOCK_CONST_METHOD0(getIdentity, std::string()); 124 MOCK_CONST_METHOD0(getAltIdentities, std::vector<std::string>()); 125 MOCK_CONST_METHOD0(getSigSchemes, std::vector<SignatureScheme>()); 126 127 MOCK_CONST_METHOD1(_getCertMessage, CertificateMsg(Buf&)); getCertMessage(Buf buf)128 CertificateMsg getCertMessage(Buf buf) const override { 129 return _getCertMessage(buf); 130 } 131 MOCK_CONST_METHOD1( 132 getCompressedCert, 133 CompressedCertificate(CertificateCompressionAlgorithm)); 134 135 MOCK_CONST_METHOD3( 136 sign, 137 Buf(SignatureScheme scheme, 138 CertificateVerifyContext context, 139 folly::ByteRange toBeSigned)); 140 MOCK_CONST_METHOD0(getX509, folly::ssl::X509UniquePtr()); 141 }; 142 143 class MockPeerCert : public PeerCert { 144 public: 145 MOCK_CONST_METHOD0(getIdentity, std::string()); 146 MOCK_CONST_METHOD4( 147 verify, 148 void( 149 SignatureScheme scheme, 150 CertificateVerifyContext context, 151 folly::ByteRange toBeSigned, 152 folly::ByteRange signature)); 153 MOCK_CONST_METHOD0(getX509, folly::ssl::X509UniquePtr()); 154 }; 155 156 class MockCertificateVerifier : public CertificateVerifier { 157 public: 158 MOCK_CONST_METHOD1( 159 verify, 160 void(const std::vector<std::shared_ptr<const PeerCert>>&)); 161 162 MOCK_CONST_METHOD0(getCertificateRequestExtensions, std::vector<Extension>()); 163 }; 164 165 class MockFactory : public OpenSSLFactory { 166 public: 167 MOCK_CONST_METHOD0( 168 makePlaintextReadRecordLayer, 169 std::unique_ptr<PlaintextReadRecordLayer>()); 170 MOCK_CONST_METHOD0( 171 makePlaintextWriteRecordLayer, 172 std::unique_ptr<PlaintextWriteRecordLayer>()); 173 MOCK_CONST_METHOD1( 174 makeEncryptedReadRecordLayer, 175 std::unique_ptr<EncryptedReadRecordLayer>( 176 EncryptionLevel encryptionLevel)); 177 MOCK_CONST_METHOD1( 178 makeEncryptedWriteRecordLayer, 179 std::unique_ptr<EncryptedWriteRecordLayer>( 180 EncryptionLevel encryptionLevel)); 181 MOCK_CONST_METHOD1( 182 makeKeyScheduler, 183 std::unique_ptr<KeyScheduler>(CipherSuite cipher)); 184 MOCK_CONST_METHOD1( 185 makeHandshakeContext, 186 std::unique_ptr<HandshakeContext>(CipherSuite cipher)); 187 MOCK_CONST_METHOD1( 188 makeKeyExchange, 189 std::unique_ptr<KeyExchange>(NamedGroup group)); 190 MOCK_CONST_METHOD1(makeAead, std::unique_ptr<Aead>(CipherSuite cipher)); 191 MOCK_CONST_METHOD0(makeRandom, Random()); 192 MOCK_CONST_METHOD0(makeTicketAgeAdd, uint32_t()); 193 194 MOCK_CONST_METHOD2( 195 _makePeerCert, 196 std::shared_ptr<PeerCert>(CertificateEntry& entry, bool leaf)); makePeerCert(CertificateEntry entry,bool leaf)197 std::shared_ptr<PeerCert> makePeerCert(CertificateEntry entry, bool leaf) 198 const override { 199 return _makePeerCert(entry, leaf); 200 } 201 setDefaults()202 void setDefaults() { 203 ON_CALL(*this, makePlaintextReadRecordLayer()) 204 .WillByDefault(InvokeWithoutArgs([]() { 205 return std::make_unique<NiceMock<MockPlaintextReadRecordLayer>>(); 206 })); 207 208 ON_CALL(*this, makePlaintextWriteRecordLayer()) 209 .WillByDefault(InvokeWithoutArgs([]() { 210 auto ret = 211 std::make_unique<NiceMock<MockPlaintextWriteRecordLayer>>(); 212 ret->setDefaults(); 213 return ret; 214 })); 215 ON_CALL(*this, makeEncryptedReadRecordLayer(_)) 216 .WillByDefault(Invoke([](EncryptionLevel encryptionLevel) { 217 return std::make_unique<NiceMock<MockEncryptedReadRecordLayer>>( 218 encryptionLevel); 219 })); 220 221 ON_CALL(*this, makeEncryptedWriteRecordLayer(_)) 222 .WillByDefault(Invoke([](EncryptionLevel encryptionLevel) { 223 auto ret = std::make_unique<NiceMock<MockEncryptedWriteRecordLayer>>( 224 encryptionLevel); 225 ret->setDefaults(); 226 return ret; 227 })); 228 229 ON_CALL(*this, makeKeyScheduler(_)).WillByDefault(InvokeWithoutArgs([]() { 230 auto ret = std::make_unique<NiceMock<MockKeyScheduler>>(); 231 ret->setDefaults(); 232 return ret; 233 })); 234 ON_CALL(*this, makeHandshakeContext(_)) 235 .WillByDefault(InvokeWithoutArgs([]() { 236 auto ret = std::make_unique<NiceMock<MockHandshakeContext>>(); 237 ret->setDefaults(); 238 return ret; 239 })); 240 ON_CALL(*this, makeKeyExchange(_)).WillByDefault(InvokeWithoutArgs([]() { 241 auto ret = std::make_unique<NiceMock<MockKeyExchange>>(); 242 ret->setDefaults(); 243 return ret; 244 })); 245 ON_CALL(*this, makeAead(_)).WillByDefault(InvokeWithoutArgs([]() { 246 auto ret = std::make_unique<NiceMock<MockAead>>(); 247 ret->setDefaults(); 248 return ret; 249 })); 250 ON_CALL(*this, makeRandom()).WillByDefault(InvokeWithoutArgs([]() { 251 Random random; 252 random.fill(0x44); 253 return random; 254 })); 255 ON_CALL(*this, makeTicketAgeAdd()).WillByDefault(InvokeWithoutArgs([]() { 256 return 0x44444444; 257 })); 258 ON_CALL(*this, _makePeerCert(_, _)).WillByDefault(InvokeWithoutArgs([]() { 259 return std::make_unique<NiceMock<MockPeerCert>>(); 260 })); 261 } 262 }; 263 264 class MockCertificateDecompressor : public CertificateDecompressor { 265 public: 266 MOCK_CONST_METHOD0(getAlgorithm, CertificateCompressionAlgorithm()); 267 MOCK_METHOD1(decompress, CertificateMsg(const CompressedCertificate&)); setDefaults()268 void setDefaults() { 269 ON_CALL(*this, getAlgorithm()).WillByDefault(InvokeWithoutArgs([]() { 270 return CertificateCompressionAlgorithm::zlib; 271 })); 272 } 273 }; 274 275 class MockCertificateCompressor : public CertificateCompressor { 276 public: 277 MOCK_CONST_METHOD0(getAlgorithm, CertificateCompressionAlgorithm()); 278 MOCK_METHOD1(compress, CompressedCertificate(const CertificateMsg&)); setDefaults()279 void setDefaults() { 280 ON_CALL(*this, getAlgorithm()).WillByDefault(InvokeWithoutArgs([]() { 281 return CertificateCompressionAlgorithm::zlib; 282 })); 283 } 284 }; 285 286 class MockAsyncFizzBase : public AsyncFizzBase { 287 public: MockAsyncFizzBase()288 MockAsyncFizzBase() 289 : AsyncFizzBase( 290 folly::AsyncTransport::UniquePtr( 291 new folly::test::MockAsyncTransport()), 292 AsyncFizzBase::TransportOptions()) {} 293 MOCK_CONST_METHOD0(good, bool()); 294 MOCK_CONST_METHOD0(readable, bool()); 295 MOCK_CONST_METHOD0(connecting, bool()); 296 MOCK_CONST_METHOD0(error, bool()); 297 MOCK_CONST_METHOD0(getPeerCert, folly::ssl::X509UniquePtr()); 298 MOCK_CONST_METHOD0(getSelfCert, const X509*()); 299 MOCK_CONST_METHOD0(isReplaySafe, bool()); 300 MOCK_METHOD1( 301 setReplaySafetyCallback, 302 void(folly::AsyncTransport::ReplaySafetyCallback* callback)); 303 MOCK_CONST_METHOD0(getSelfCertificate, const Cert*()); 304 MOCK_CONST_METHOD0(getPeerCertificate, const Cert*()); 305 MOCK_CONST_METHOD0(getApplicationProtocol_, std::string()); 306 307 MOCK_METHOD1(setReadCB, void(ReadCallback*)); 308 MOCK_METHOD1(setEndOfTLSCallback, void(EndOfTLSCallback*)); 309 getApplicationProtocol()310 std::string getApplicationProtocol() const noexcept override { 311 return getApplicationProtocol_(); 312 } 313 314 MOCK_CONST_METHOD0(getCipher, folly::Optional<CipherSuite>()); 315 MOCK_CONST_METHOD0(getSupportedSigSchemes, std::vector<SignatureScheme>()); 316 MOCK_CONST_METHOD3(getEkm, Buf(folly::StringPiece, const Buf&, uint16_t)); 317 MOCK_CONST_METHOD0(getClientRandom, folly::Optional<Random>()); 318 MOCK_METHOD0(tlsShutdown, void()); 319 320 MOCK_METHOD3( 321 writeAppDataInternal, 322 void( 323 folly::AsyncTransport::WriteCallback*, 324 std::shared_ptr<folly::IOBuf>, 325 folly::WriteFlags)); 326 327 void writeAppData( 328 folly::AsyncTransport::WriteCallback* callback, 329 std::unique_ptr<folly::IOBuf>&& buf, 330 folly::WriteFlags flags = folly::WriteFlags::NONE) override { 331 writeAppDataInternal( 332 callback, std::shared_ptr<folly::IOBuf>(buf.release()), flags); 333 } 334 335 MOCK_METHOD1(transportError, void(const folly::AsyncSocketException&)); 336 337 MOCK_METHOD0(transportDataAvailable, void()); 338 MOCK_METHOD0(pauseEvents, void()); 339 MOCK_METHOD0(resumeEvents, void()); 340 }; 341 342 } // namespace test 343 } // namespace fizz 344