1 // datatest.cpp - originally written and placed in the public domain by Wei Dai
2 //                CryptoPP::Test namespace added by JW in February 2017
3 
4 #define CRYPTOPP_DEFAULT_NO_DLL
5 #define CRYPTOPP_ENABLE_NAMESPACE_WEAK 1
6 
7 #include "cryptlib.h"
8 #include "factory.h"
9 #include "integer.h"
10 #include "filters.h"
11 #include "randpool.h"
12 #include "files.h"
13 #include "trunhash.h"
14 #include "queue.h"
15 #include "smartptr.h"
16 #include "validate.h"
17 #include "stdcpp.h"
18 #include "misc.h"
19 #include "hex.h"
20 #include "trap.h"
21 
22 #include <iostream>
23 #include <sstream>
24 #include <cerrno>
25 
26 // Aggressive stack checking with VS2005 SP1 and above.
27 #if (_MSC_FULL_VER >= 140050727)
28 # pragma strict_gs_check (on)
29 #endif
30 
31 #if CRYPTOPP_MSC_VERSION
32 # pragma warning(disable: 4505 4355)
33 #endif
34 
// STRTOUL64 parses an unsigned 64-bit integer from a C string; it is used when
// a test vector parameter is requested as word64. MSVC spells the function
// _strtoui64, other toolchains provide strtoull.
#ifdef _MSC_VER
# define STRTOUL64 _strtoui64
#else
# define STRTOUL64 strtoull
#endif
40 
41 NAMESPACE_BEGIN(CryptoPP)
42 NAMESPACE_BEGIN(Test)
43 
// Name/value fields of the test vector currently being read (e.g. "Name", "Test", "Key").
typedef std::map<std::string, std::string> TestData;
// Set by RunTestDataFile. Enables tests marked "SlowTest: 1" and raises the key validation level.
static bool s_thorough = false;
46 
47 class TestFailure : public Exception
48 {
49 public:
TestFailure()50 	TestFailure() : Exception(OTHER_ERROR, "Validation test failed") {}
51 };
52 
static const TestData *s_currentTestData = NULLPTR;	// set by TestDataFile; printed by the Signal* helpers before they throw
54 
TrimSpace(std::string str)55 std::string TrimSpace(std::string str)
56 {
57 	if (str.empty()) return "";
58 
59 	const std::string whitespace(" \r\t\n");
60 	std::string::size_type beg = str.find_first_not_of(whitespace);
61 	std::string::size_type end = str.find_last_not_of(whitespace);
62 
63 	if (beg != std::string::npos && end != std::string::npos)
64 		return str.substr(beg, end+1);
65 	else if (beg != std::string::npos)
66 		return str.substr(beg);
67 	else
68 		return "";
69 }
70 
TrimComment(std::string str)71 std::string TrimComment(std::string str)
72 {
73 	if (str.empty()) return "";
74 
75 	std::string::size_type first = str.find("#");
76 
77 	if (first != std::string::npos)
78 		return TrimSpace(str.substr(0, first));
79 	else
80 		return TrimSpace(str);
81 }
82 
OutputTestData(const TestData & v)83 static void OutputTestData(const TestData &v)
84 {
85 	std::cerr << "\n";
86 	for (TestData::const_iterator i = v.begin(); i != v.end(); ++i)
87 	{
88 		std::cerr << i->first << ": " << i->second << std::endl;
89 	}
90 }
91 
SignalTestFailure()92 static void SignalTestFailure()
93 {
94 	OutputTestData(*s_currentTestData);
95 	throw TestFailure();
96 }
97 
SignalUnknownAlgorithmError(const std::string & algType)98 static void SignalUnknownAlgorithmError(const std::string& algType)
99 {
100 	OutputTestData(*s_currentTestData);
101 	throw Exception(Exception::OTHER_ERROR, "Unknown algorithm " + algType + " during validation test");
102 }
103 
SignalTestError(const char * msg=NULLPTR)104 static void SignalTestError(const char* msg = NULLPTR)
105 {
106 	OutputTestData(*s_currentTestData);
107 
108 	if (msg)
109 		throw Exception(Exception::OTHER_ERROR, msg);
110 	else
111 		throw Exception(Exception::OTHER_ERROR, "Unexpected error during validation test");
112 }
113 
DataExists(const TestData & data,const char * name)114 bool DataExists(const TestData &data, const char *name)
115 {
116 	TestData::const_iterator i = data.find(name);
117 	return (i != data.end());
118 }
119 
GetRequiredDatum(const TestData & data,const char * name)120 const std::string & GetRequiredDatum(const TestData &data, const char *name)
121 {
122 	TestData::const_iterator i = data.find(name);
123 	if (i == data.end())
124 	{
125 		std::string msg("Required datum \"" + std::string(name) + "\" missing");
126 		SignalTestError(msg.c_str());
127 	}
128 	return i->second;
129 }
130 
RandomizedTransfer(BufferedTransformation & source,BufferedTransformation & target,bool finish,const std::string & channel=DEFAULT_CHANNEL)131 void RandomizedTransfer(BufferedTransformation &source, BufferedTransformation &target, bool finish, const std::string &channel=DEFAULT_CHANNEL)
132 {
133 	while (source.MaxRetrievable() > (finish ? 0 : 4096))
134 	{
135 		byte buf[4096+64];
136 		size_t start = Test::GlobalRNG().GenerateWord32(0, 63);
137 		size_t len = Test::GlobalRNG().GenerateWord32(1, UnsignedMin(4096U, 3*source.MaxRetrievable()/2));
138 		len = source.Get(buf+start, len);
139 		target.ChannelPut(channel, buf+start, len);
140 	}
141 }
142 
PutDecodedDatumInto(const TestData & data,const char * name,BufferedTransformation & target)143 void PutDecodedDatumInto(const TestData &data, const char *name, BufferedTransformation &target)
144 {
145 	std::string s1 = GetRequiredDatum(data, name), s2;
146 	ByteQueue q;
147 
148 	while (!s1.empty())
149 	{
150 		while (s1[0] == ' ')
151 		{
152 			s1 = s1.substr(1);
153 			if (s1.empty())
154 				goto end;	// avoid invalid read if s1 is empty
155 		}
156 
157 		int repeat = 1;
158 		if (s1[0] == 'r')
159 		{
160 			s1 = s1.erase(0, 1);
161 			repeat = ::atoi(s1.c_str());
162 			s1 = s1.substr(s1.find(' ')+1);
163 		}
164 
165 		// Convert word32 or word64 to little endian order. Some algorithm test vectors are
166 		// presented in the format. We probably should have named them word32le and word64le.
167 		if (s1.length() >= 6 && (s1.substr(0,6) == "word32" || s1.substr(0,6) == "word64"))
168 		{
169 			std::istringstream iss(s1.substr(6));
170 			if (s1.substr(0,6) == "word64")
171 			{
172 				word64 value;
173 				while (iss >> std::skipws >> std::hex >> value)
174 				{
175 					value = ConditionalByteReverse(LITTLE_ENDIAN_ORDER, value);
176 					q.Put(reinterpret_cast<const byte *>(&value), 8);
177 				}
178 			}
179 			else
180 			{
181 				word32 value;
182 				while (iss >> std::skipws >> std::hex >> value)
183 				{
184 					value = ConditionalByteReverse(LITTLE_ENDIAN_ORDER, value);
185 					q.Put(reinterpret_cast<const byte *>(&value), 4);
186 				}
187 			}
188 			goto end;
189 		}
190 
191 		s2.clear();
192 		if (s1[0] == '\"')
193 		{
194 			s2 = s1.substr(1, s1.find('\"', 1)-1);
195 			s1 = s1.substr(s2.length() + 2);
196 		}
197 		else if (s1.substr(0, 2) == "0x")
198 		{
199 			std::string::size_type pos = s1.find(' ');
200 			StringSource(s1.substr(2, pos), true, new HexDecoder(new StringSink(s2)));
201 			s1 = s1.substr(STDMIN(pos, s1.length()));
202 		}
203 		else
204 		{
205 			std::string::size_type pos = s1.find(' ');
206 			StringSource(s1.substr(0, pos), true, new HexDecoder(new StringSink(s2)));
207 			s1 = s1.substr(STDMIN(pos, s1.length()));
208 		}
209 
210 		while (repeat--)
211 		{
212 			q.Put(ConstBytePtr(s2), BytePtrSize(s2));
213 			RandomizedTransfer(q, target, false);
214 		}
215 	}
216 
217 end:
218 	RandomizedTransfer(q, target, true);
219 }
220 
GetDecodedDatum(const TestData & data,const char * name)221 std::string GetDecodedDatum(const TestData &data, const char *name)
222 {
223 	std::string s;
224 	PutDecodedDatumInto(data, name, StringSink(s).Ref());
225 	return s;
226 }
227 
GetOptionalDecodedDatum(const TestData & data,const char * name)228 std::string GetOptionalDecodedDatum(const TestData &data, const char *name)
229 {
230 	std::string s;
231 	if (DataExists(data, name))
232 		PutDecodedDatumInto(data, name, StringSink(s).Ref());
233 	return s;
234 }
235 
236 class TestDataNameValuePairs : public NameValuePairs
237 {
238 public:
TestDataNameValuePairs(const TestData & data)239 	TestDataNameValuePairs(const TestData &data) : m_data(data) {}
240 
GetVoidValue(const char * name,const std::type_info & valueType,void * pValue) const241 	virtual bool GetVoidValue(const char *name, const std::type_info &valueType, void *pValue) const
242 	{
243 		TestData::const_iterator i = m_data.find(name);
244 		if (i == m_data.end())
245 		{
246 			if (std::string(name) == Name::DigestSize() && valueType == typeid(int))
247 			{
248 				i = m_data.find("MAC");
249 				if (i == m_data.end())
250 					i = m_data.find("Digest");
251 				if (i == m_data.end())
252 					return false;
253 
254 				m_temp.clear();
255 				PutDecodedDatumInto(m_data, i->first.c_str(), StringSink(m_temp).Ref());
256 				*reinterpret_cast<int *>(pValue) = (int)m_temp.size();
257 				return true;
258 			}
259 			else
260 				return false;
261 		}
262 
263 		const std::string &value = i->second;
264 
265 		if (valueType == typeid(int))
266 			*reinterpret_cast<int *>(pValue) = atoi(value.c_str());
267 		else if (valueType == typeid(word64))
268 		{
269 			std::string x(value.empty() ? "0" : value);
270 			const char* beg = &x[0];
271 			char* end = &x[0] + value.size();
272 
273 			errno = 0;
274 			*reinterpret_cast<word64*>(pValue) = STRTOUL64(beg, &end, 0);
275 			if (errno != 0)
276 				return false;
277 		}
278 		else if (valueType == typeid(Integer))
279 			*reinterpret_cast<Integer *>(pValue) = Integer((std::string(value) + "h").c_str());
280 		else if (valueType == typeid(ConstByteArrayParameter))
281 		{
282 			m_temp.clear();
283 			PutDecodedDatumInto(m_data, name, StringSink(m_temp).Ref());
284 			reinterpret_cast<ConstByteArrayParameter *>(pValue)->Assign(ConstBytePtr(m_temp), BytePtrSize(m_temp), false);
285 		}
286 		else
287 			throw ValueTypeMismatch(name, typeid(std::string), valueType);
288 
289 		return true;
290 	}
291 
292 private:
293 	const TestData &m_data;
294 	mutable std::string m_temp;
295 };
296 
TestKeyPairValidAndConsistent(CryptoMaterial & pub,const CryptoMaterial & priv)297 void TestKeyPairValidAndConsistent(CryptoMaterial &pub, const CryptoMaterial &priv)
298 {
299 	if (!pub.Validate(Test::GlobalRNG(), 2U+!!s_thorough))
300 		SignalTestFailure();
301 	if (!priv.Validate(Test::GlobalRNG(), 2U+!!s_thorough))
302 		SignalTestFailure();
303 
304 	ByteQueue bq1, bq2;
305 	pub.Save(bq1);
306 	pub.AssignFrom(priv);
307 	pub.Save(bq2);
308 	if (bq1 != bq2)
309 		SignalTestFailure();
310 }
311 
TestSignatureScheme(TestData & v)312 void TestSignatureScheme(TestData &v)
313 {
314 	std::string name = GetRequiredDatum(v, "Name");
315 	std::string test = GetRequiredDatum(v, "Test");
316 
317 	member_ptr<PK_Signer> signer(ObjectFactoryRegistry<PK_Signer>::Registry().CreateObject(name.c_str()));
318 	member_ptr<PK_Verifier> verifier(ObjectFactoryRegistry<PK_Verifier>::Registry().CreateObject(name.c_str()));
319 
320 	// Code coverage
321 	(void)signer->AlgorithmName();
322 	(void)verifier->AlgorithmName();
323 	(void)signer->AlgorithmProvider();
324 	(void)verifier->AlgorithmProvider();
325 
326 	TestDataNameValuePairs pairs(v);
327 
328 	if (test == "GenerateKey")
329 	{
330 		signer->AccessPrivateKey().GenerateRandom(Test::GlobalRNG(), pairs);
331 		verifier->AccessPublicKey().AssignFrom(signer->AccessPrivateKey());
332 	}
333 	else
334 	{
335 		std::string keyFormat = GetRequiredDatum(v, "KeyFormat");
336 
337 		if (keyFormat == "DER")
338 			verifier->AccessMaterial().Load(StringStore(GetDecodedDatum(v, "PublicKey")).Ref());
339 		else if (keyFormat == "Component")
340 			verifier->AccessMaterial().AssignFrom(pairs);
341 
342 		if (test == "Verify" || test == "NotVerify")
343 		{
344 			SignatureVerificationFilter verifierFilter(*verifier, NULLPTR, SignatureVerificationFilter::SIGNATURE_AT_BEGIN);
345 			PutDecodedDatumInto(v, "Signature", verifierFilter);
346 			PutDecodedDatumInto(v, "Message", verifierFilter);
347 			verifierFilter.MessageEnd();
348 			if (verifierFilter.GetLastResult() == (test == "NotVerify"))
349 				SignalTestFailure();
350 			return;
351 		}
352 		else if (test == "PublicKeyValid")
353 		{
354 			if (!verifier->GetMaterial().Validate(Test::GlobalRNG(), 3))
355 				SignalTestFailure();
356 			return;
357 		}
358 
359 		if (keyFormat == "DER")
360 			signer->AccessMaterial().Load(StringStore(GetDecodedDatum(v, "PrivateKey")).Ref());
361 		else if (keyFormat == "Component")
362 			signer->AccessMaterial().AssignFrom(pairs);
363 	}
364 
365 	if (test == "GenerateKey" || test == "KeyPairValidAndConsistent")
366 	{
367 		TestKeyPairValidAndConsistent(verifier->AccessMaterial(), signer->GetMaterial());
368 		SignatureVerificationFilter verifierFilter(*verifier, NULLPTR, SignatureVerificationFilter::THROW_EXCEPTION);
369 		const byte msg[3] = {'a', 'b', 'c'};
370 		verifierFilter.Put(msg, sizeof(msg));
371 		StringSource ss(msg, sizeof(msg), true, new SignerFilter(Test::GlobalRNG(), *signer, new Redirector(verifierFilter)));
372 	}
373 	else if (test == "Sign")
374 	{
375 		SignerFilter f(Test::GlobalRNG(), *signer, new HexEncoder(new FileSink(std::cout)));
376 		StringSource ss(GetDecodedDatum(v, "Message"), true, new Redirector(f));
377 		SignalTestFailure();
378 	}
379 	else if (test == "DeterministicSign")
380 	{
381 		// This test is specialized for RFC 6979. The RFC is a drop-in replacement
382 		// for DSA and ECDSA, and access to the seed or secret is not needed. If
383 		// additional deterministic signatures are added, then the test harness will
384 		// likely need to be extended.
385 		std::string signature;
386 		SignerFilter f(Test::GlobalRNG(), *signer, new StringSink(signature));
387 		StringSource ss(GetDecodedDatum(v, "Message"), true, new Redirector(f));
388 
389 		if (GetDecodedDatum(v, "Signature") != signature)
390 			SignalTestFailure();
391 
392 		return;
393 	}
394 	else
395 	{
396 		std::string msg("Unknown signature test \"" + test + "\"");
397 		SignalTestError(msg.c_str());
398 		CRYPTOPP_ASSERT(false);
399 	}
400 }
401 
TestAsymmetricCipher(TestData & v)402 void TestAsymmetricCipher(TestData &v)
403 {
404 	std::string name = GetRequiredDatum(v, "Name");
405 	std::string test = GetRequiredDatum(v, "Test");
406 
407 	member_ptr<PK_Encryptor> encryptor(ObjectFactoryRegistry<PK_Encryptor>::Registry().CreateObject(name.c_str()));
408 	member_ptr<PK_Decryptor> decryptor(ObjectFactoryRegistry<PK_Decryptor>::Registry().CreateObject(name.c_str()));
409 
410 	// Code coverage
411 	(void)encryptor->AlgorithmName();
412 	(void)decryptor->AlgorithmName();
413 	(void)encryptor->AlgorithmProvider();
414 	(void)decryptor->AlgorithmProvider();
415 
416 	std::string keyFormat = GetRequiredDatum(v, "KeyFormat");
417 
418 	if (keyFormat == "DER")
419 	{
420 		decryptor->AccessMaterial().Load(StringStore(GetDecodedDatum(v, "PrivateKey")).Ref());
421 		encryptor->AccessMaterial().Load(StringStore(GetDecodedDatum(v, "PublicKey")).Ref());
422 	}
423 	else if (keyFormat == "Component")
424 	{
425 		TestDataNameValuePairs pairs(v);
426 		decryptor->AccessMaterial().AssignFrom(pairs);
427 		encryptor->AccessMaterial().AssignFrom(pairs);
428 	}
429 
430 	if (test == "DecryptMatch")
431 	{
432 		std::string decrypted, expected = GetDecodedDatum(v, "Plaintext");
433 		StringSource ss(GetDecodedDatum(v, "Ciphertext"), true, new PK_DecryptorFilter(Test::GlobalRNG(), *decryptor, new StringSink(decrypted)));
434 		if (decrypted != expected)
435 			SignalTestFailure();
436 	}
437 	else if (test == "KeyPairValidAndConsistent")
438 	{
439 		TestKeyPairValidAndConsistent(encryptor->AccessMaterial(), decryptor->GetMaterial());
440 	}
441 	else
442 	{
443 		std::string msg("Unknown asymmetric cipher test \"" + test + "\"");
444 		SignalTestError(msg.c_str());
445 		CRYPTOPP_ASSERT(false);
446 	}
447 }
448 
TestSymmetricCipher(TestData & v,const NameValuePairs & overrideParameters)449 void TestSymmetricCipher(TestData &v, const NameValuePairs &overrideParameters)
450 {
451 	std::string name = GetRequiredDatum(v, "Name");
452 	std::string test = GetRequiredDatum(v, "Test");
453 
454 	std::string key = GetDecodedDatum(v, "Key");
455 	std::string plaintext = GetDecodedDatum(v, "Plaintext");
456 
457 	TestDataNameValuePairs testDataPairs(v);
458 	CombinedNameValuePairs pairs(overrideParameters, testDataPairs);
459 
460 	if (test == "Encrypt" || test == "EncryptXorDigest" || test == "Resync" || test == "EncryptionMCT" || test == "DecryptionMCT")
461 	{
462 		static member_ptr<SymmetricCipher> encryptor, decryptor;
463 		static std::string lastName;
464 
465 		if (name != lastName)
466 		{
467 			encryptor.reset(ObjectFactoryRegistry<SymmetricCipher, ENCRYPTION>::Registry().CreateObject(name.c_str()));
468 			decryptor.reset(ObjectFactoryRegistry<SymmetricCipher, DECRYPTION>::Registry().CreateObject(name.c_str()));
469 			lastName = name;
470 
471 			// Code coverage
472 			(void)encryptor->AlgorithmName();
473 			(void)decryptor->AlgorithmName();
474 			(void)encryptor->AlgorithmProvider();
475 			(void)decryptor->AlgorithmProvider();
476 			(void)encryptor->MinKeyLength();
477 			(void)decryptor->MinKeyLength();
478 			(void)encryptor->MaxKeyLength();
479 			(void)decryptor->MaxKeyLength();
480 			(void)encryptor->DefaultKeyLength();
481 			(void)decryptor->DefaultKeyLength();
482 		}
483 
484 		// Most block ciphers don't specify BlockPaddingScheme. Kalyna uses it in test vectors.
485 		// 0 is NoPadding, 1 is ZerosPadding, 2 is PkcsPadding, 3 is OneAndZerosPadding, etc
486 		// Note: The machinery is wired such that paddingScheme is effectively latched. An
487 		//   old paddingScheme may be unintentionally used in a subsequent test.
488 		int paddingScheme = pairs.GetIntValueWithDefault(Name::BlockPaddingScheme(), 0);
489 
490 		ConstByteArrayParameter iv;
491 		if (pairs.GetValue(Name::IV(), iv) && iv.size() != encryptor->IVSize())
492 			SignalTestFailure();
493 
494 		if (test == "Resync")
495 		{
496 			encryptor->Resynchronize(iv.begin(), (int)iv.size());
497 			decryptor->Resynchronize(iv.begin(), (int)iv.size());
498 		}
499 		else
500 		{
501 			encryptor->SetKey(ConstBytePtr(key), BytePtrSize(key), pairs);
502 			decryptor->SetKey(ConstBytePtr(key), BytePtrSize(key), pairs);
503 		}
504 
505 		word64 seek64 = pairs.GetWord64ValueWithDefault("Seek64", 0);
506 		if (seek64)
507 		{
508 			encryptor->Seek(seek64);
509 			decryptor->Seek(seek64);
510 		}
511 		else
512 		{
513 			int seek = pairs.GetIntValueWithDefault("Seek", 0);
514 			if (seek)
515 			{
516 				encryptor->Seek(seek);
517 				decryptor->Seek(seek);
518 			}
519 		}
520 
521 		// If a per-test vector parameter was set for a test, like BlockPadding,
522 		// BlockSize or Tweak, then it becomes latched in testDataPairs. The old
523 		// value is used in subsequent tests, and it could cause a self test
524 		// failure in the next test. The behavior surfaced under Kalyna and
525 		// Threefish. The Kalyna test vectors use NO_PADDING for all tests excpet
526 		// one. For Threefish, using (and not using) a Tweak caused problems as
527 		// we marched through test vectors. For BlockPadding, BlockSize or Tweak,
528 		// unlatch them now, after the key has been set and NameValuePairs have
529 		// been processed. Also note we only unlatch from testDataPairs. If
530 		// overrideParameters are specified, the caller is responsible for
531 		// managing the parameter.
532 		v.erase("Tweak"); v.erase("InitialBlock"); v.erase("BlockSize"); v.erase("BlockPaddingScheme");
533 
534 		std::string encrypted, xorDigest, ciphertext, ciphertextXorDigest;
535 		if (test == "EncryptionMCT" || test == "DecryptionMCT")
536 		{
537 			SymmetricCipher *cipher = encryptor.get();
538 			std::string buf(plaintext), keybuf(key);
539 
540 			if (test == "DecryptionMCT")
541 			{
542 				cipher = decryptor.get();
543 				ciphertext = GetDecodedDatum(v, "Ciphertext");
544 				buf.assign(ciphertext.begin(), ciphertext.end());
545 			}
546 
547 			for (int i=0; i<400; i++)
548 			{
549 				encrypted.reserve(10000 * plaintext.size());
550 				for (int j=0; j<10000; j++)
551 				{
552 					cipher->ProcessString(BytePtr(buf), BytePtrSize(buf));
553 					encrypted.append(buf.begin(), buf.end());
554 				}
555 
556 				encrypted.erase(0, encrypted.size() - keybuf.size());
557 				xorbuf(BytePtr(keybuf), BytePtr(encrypted), BytePtrSize(keybuf));
558 				cipher->SetKey(BytePtr(keybuf), BytePtrSize(keybuf));
559 			}
560 
561 			encrypted.assign(buf.begin(), buf.end());
562 			ciphertext = GetDecodedDatum(v, test == "EncryptionMCT" ? "Ciphertext" : "Plaintext");
563 			if (encrypted != ciphertext)
564 			{
565 				std::cout << "\nincorrectly encrypted: ";
566 				StringSource xx(encrypted, false, new HexEncoder(new FileSink(std::cout)));
567 				xx.Pump(256); xx.Flush(false);
568 				std::cout << "\n";
569 				SignalTestFailure();
570 			}
571 			return;
572 		}
573 
574 		StreamTransformationFilter encFilter(*encryptor, new StringSink(encrypted),
575 				static_cast<BlockPaddingSchemeDef::BlockPaddingScheme>(paddingScheme));
576 
577 		StringStore pstore(plaintext);
578 		RandomizedTransfer(pstore, encFilter, true);
579 		encFilter.MessageEnd();
580 
581 		if (test != "EncryptXorDigest")
582 		{
583 			ciphertext = GetDecodedDatum(v, "Ciphertext");
584 		}
585 		else
586 		{
587 			ciphertextXorDigest = GetDecodedDatum(v, "CiphertextXorDigest");
588 			xorDigest.append(encrypted, 0, 64);
589 			for (size_t i=64; i<encrypted.size(); i++)
590 				xorDigest[i%64] = static_cast<char>(xorDigest[i%64] ^ encrypted[i]);
591 		}
592 		if (test != "EncryptXorDigest" ? encrypted != ciphertext : xorDigest != ciphertextXorDigest)
593 		{
594 			std::cout << "\nincorrectly encrypted: ";
595 			StringSource xx(encrypted, false, new HexEncoder(new FileSink(std::cout)));
596 			xx.Pump(2048); xx.Flush(false);
597 			std::cout << "\n";
598 			SignalTestFailure();
599 		}
600 
601 		std::string decrypted;
602 		StreamTransformationFilter decFilter(*decryptor, new StringSink(decrypted),
603 				static_cast<BlockPaddingSchemeDef::BlockPaddingScheme>(paddingScheme));
604 
605 		StringStore cstore(encrypted);
606 		RandomizedTransfer(cstore, decFilter, true);
607 		decFilter.MessageEnd();
608 
609 		if (decrypted != plaintext)
610 		{
611 			std::cout << "\nincorrectly decrypted: ";
612 			StringSource xx(decrypted, false, new HexEncoder(new FileSink(std::cout)));
613 			xx.Pump(256); xx.Flush(false);
614 			std::cout << "\n";
615 			SignalTestFailure();
616 		}
617 	}
618 	else
619 	{
620 		std::string msg("Unknown symmetric cipher test \"" + test + "\"");
621 		SignalTestError(msg.c_str());
622 	}
623 }
624 
TestAuthenticatedSymmetricCipher(TestData & v,const NameValuePairs & overrideParameters)625 void TestAuthenticatedSymmetricCipher(TestData &v, const NameValuePairs &overrideParameters)
626 {
627 	std::string type = GetRequiredDatum(v, "AlgorithmType");
628 	std::string name = GetRequiredDatum(v, "Name");
629 	std::string test = GetRequiredDatum(v, "Test");
630 	std::string key = GetDecodedDatum(v, "Key");
631 
632 	std::string plaintext = GetOptionalDecodedDatum(v, "Plaintext");
633 	std::string ciphertext = GetOptionalDecodedDatum(v, "Ciphertext");
634 	std::string header = GetOptionalDecodedDatum(v, "Header");
635 	std::string footer = GetOptionalDecodedDatum(v, "Footer");
636 	std::string mac = GetOptionalDecodedDatum(v, "MAC");
637 
638 	TestDataNameValuePairs testDataPairs(v);
639 	CombinedNameValuePairs pairs(overrideParameters, testDataPairs);
640 
641 	if (test == "Encrypt" || test == "EncryptXorDigest" || test == "NotVerify")
642 	{
643 		member_ptr<AuthenticatedSymmetricCipher> encryptor, decryptor;
644 		encryptor.reset(ObjectFactoryRegistry<AuthenticatedSymmetricCipher, ENCRYPTION>::Registry().CreateObject(name.c_str()));
645 		decryptor.reset(ObjectFactoryRegistry<AuthenticatedSymmetricCipher, DECRYPTION>::Registry().CreateObject(name.c_str()));
646 		encryptor->SetKey(ConstBytePtr(key), BytePtrSize(key), pairs);
647 		decryptor->SetKey(ConstBytePtr(key), BytePtrSize(key), pairs);
648 
649 		// Code coverage
650 		(void)encryptor->AlgorithmName();
651 		(void)decryptor->AlgorithmName();
652 
653 		std::string encrypted, decrypted;
654 		AuthenticatedEncryptionFilter ef(*encryptor, new StringSink(encrypted));
655 		bool macAtBegin = !mac.empty() && !Test::GlobalRNG().GenerateBit();	// test both ways randomly
656 		AuthenticatedDecryptionFilter df(*decryptor, new StringSink(decrypted), macAtBegin ? AuthenticatedDecryptionFilter::MAC_AT_BEGIN : 0);
657 
658 		if (encryptor->NeedsPrespecifiedDataLengths())
659 		{
660 			encryptor->SpecifyDataLengths(header.size(), plaintext.size(), footer.size());
661 			decryptor->SpecifyDataLengths(header.size(), plaintext.size(), footer.size());
662 		}
663 
664 		StringStore sh(header), sp(plaintext), sc(ciphertext), sf(footer), sm(mac);
665 
666 		if (macAtBegin)
667 			RandomizedTransfer(sm, df, true);
668 		sh.CopyTo(df, LWORD_MAX, AAD_CHANNEL);
669 		RandomizedTransfer(sc, df, true);
670 		sf.CopyTo(df, LWORD_MAX, AAD_CHANNEL);
671 		if (!macAtBegin)
672 			RandomizedTransfer(sm, df, true);
673 		df.MessageEnd();
674 
675 		RandomizedTransfer(sh, ef, true, AAD_CHANNEL);
676 		RandomizedTransfer(sp, ef, true);
677 		RandomizedTransfer(sf, ef, true, AAD_CHANNEL);
678 		ef.MessageEnd();
679 
680 		if (test == "Encrypt" && encrypted != ciphertext+mac)
681 		{
682 			std::cout << "\nincorrectly encrypted: ";
683 			StringSource xx(encrypted, false, new HexEncoder(new FileSink(std::cout)));
684 			xx.Pump(2048); xx.Flush(false);
685 			std::cout << "\n";
686 			SignalTestFailure();
687 		}
688 		if (test == "Encrypt" && decrypted != plaintext)
689 		{
690 			std::cout << "\nincorrectly decrypted: ";
691 			StringSource xx(decrypted, false, new HexEncoder(new FileSink(std::cout)));
692 			xx.Pump(256); xx.Flush(false);
693 			std::cout << "\n";
694 			SignalTestFailure();
695 		}
696 
697 		if (ciphertext.size()+mac.size()-plaintext.size() != encryptor->DigestSize())
698 		{
699 			std::cout << "\nbad MAC size\n";
700 			SignalTestFailure();
701 		}
702 		if (df.GetLastResult() != (test == "Encrypt"))
703 		{
704 			std::cout << "\nMAC incorrectly verified\n";
705 			SignalTestFailure();
706 		}
707 	}
708 	else
709 	{
710 		std::string msg("Unknown authenticated symmetric cipher test \"" + test + "\"");
711 		SignalTestError(msg.c_str());
712 	}
713 }
714 
TestDigestOrMAC(TestData & v,bool testDigest)715 void TestDigestOrMAC(TestData &v, bool testDigest)
716 {
717 	std::string name = GetRequiredDatum(v, "Name");
718 	std::string test = GetRequiredDatum(v, "Test");
719 	const char *digestName = testDigest ? "Digest" : "MAC";
720 
721 	member_ptr<MessageAuthenticationCode> mac;
722 	member_ptr<HashTransformation> hash;
723 	HashTransformation *pHash = NULLPTR;
724 
725 	TestDataNameValuePairs pairs(v);
726 
727 	if (testDigest)
728 	{
729 		hash.reset(ObjectFactoryRegistry<HashTransformation>::Registry().CreateObject(name.c_str()));
730 		pHash = hash.get();
731 
732 		// Code coverage
733 		(void)hash->AlgorithmName();
734 		(void)hash->AlgorithmProvider();
735 	}
736 	else
737 	{
738 		mac.reset(ObjectFactoryRegistry<MessageAuthenticationCode>::Registry().CreateObject(name.c_str()));
739 		pHash = mac.get();
740 		std::string key = GetDecodedDatum(v, "Key");
741 		mac->SetKey(ConstBytePtr(key), BytePtrSize(key), pairs);
742 
743 		// Code coverage
744 		(void)mac->AlgorithmName();
745 		(void)mac->AlgorithmProvider();
746 	}
747 
748 	if (test == "Verify" || test == "VerifyTruncated" || test == "NotVerify")
749 	{
750 		int digestSize = -1;
751 		if (test == "VerifyTruncated")
752 			digestSize = pairs.GetIntValueWithDefault(Name::DigestSize(), digestSize);
753 		HashVerificationFilter verifierFilter(*pHash, NULLPTR, HashVerificationFilter::HASH_AT_BEGIN, digestSize);
754 		PutDecodedDatumInto(v, digestName, verifierFilter);
755 		PutDecodedDatumInto(v, "Message", verifierFilter);
756 		verifierFilter.MessageEnd();
757 		if (verifierFilter.GetLastResult() == (test == "NotVerify"))
758 			SignalTestFailure();
759 	}
760 	else
761 	{
762 		std::string msg("Unknown digest or mac test \"" + test + "\"");
763 		SignalTestError(msg.c_str());
764 	}
765 }
766 
TestKeyDerivationFunction(TestData & v)767 void TestKeyDerivationFunction(TestData &v)
768 {
769 	std::string name = GetRequiredDatum(v, "Name");
770 	std::string test = GetRequiredDatum(v, "Test");
771 
772 	if(test == "Skip") return;
773 	CRYPTOPP_ASSERT(test == "Verify");
774 
775 	std::string secret = GetDecodedDatum(v, "Secret");
776 	std::string expected = GetDecodedDatum(v, "DerivedKey");
777 
778 	TestDataNameValuePairs pairs(v);
779 
780 	member_ptr<KeyDerivationFunction> kdf;
781 	kdf.reset(ObjectFactoryRegistry<KeyDerivationFunction>::Registry().CreateObject(name.c_str()));
782 
783 	std::string calculated; calculated.resize(expected.size());
784 	kdf->DeriveKey(BytePtr(calculated), BytePtrSize(calculated), BytePtr(secret), BytePtrSize(secret), pairs);
785 
786 	if(calculated != expected)
787 	{
788 		std::cerr << "Calculated: ";
789 		StringSource(calculated, true, new HexEncoder(new FileSink(std::cerr)));
790 		std::cerr << std::endl;
791 
792 		SignalTestFailure();
793 	}
794 }
795 
FirstChar(const std::string & str)796 inline char FirstChar(const std::string& str) {
797 	if (str.empty()) return 0;
798 	return str[0];
799 }
800 
LastChar(const std::string & str)801 inline char LastChar(const std::string& str) {
802 	if (str.empty()) return 0;
803 	return str[str.length()-1];
804 }
805 
806 // GetField parses the name/value pairs. The tricky part is the insertion operator
807 // because Unix&Linux uses LF, OS X uses CR, and Windows uses CRLF. If this function
808 // is modified, then run 'cryptest.exe tv rsa_pkcs1_1_5' as a test. Its the parser
809 // file from hell. If it can be parsed without error, then things are likely OK.
810 // For istream.fail() see https://stackoverflow.com/q/34395801/608639.
GetField(std::istream & is,std::string & name,std::string & value)811 bool GetField(std::istream &is, std::string &name, std::string &value)
812 {
813 	std::string line;
814 	name.clear(); value.clear();
815 
816 	// ***** Name *****
817 	while (is >> std::ws && std::getline(is, line))
818 	{
819 		// Eat whitespace and comments gracefully
820 		if (line.empty() || line[0] == '#')
821 			continue;
822 
823 		std::string::size_type pos = line.find(':');
824 		if (pos == std::string::npos)
825 			SignalTestError("Unable to parse name/value pair");
826 
827 		name = TrimSpace(line.substr(0, pos));
828 		line = TrimSpace(line.substr(pos + 1));
829 
830 		// Empty name is bad
831 		if (name.empty())
832 			return false;
833 
834 		// Empty value is ok
835 		if (line.empty())
836 			return true;
837 
838 		break;
839 	}
840 
841 	// ***** Value *****
842 	bool continueLine = true;
843 
844 	do
845 	{
846 		// Trim leading and trailing whitespace, including OS X and Windows
847 		// new lines. Don't parse comments here because there may be a line
848 		// continuation at the end.
849 		line = TrimSpace(line);
850 
851 		continueLine = false;
852 		if (line.empty())
853 			continue;
854 
855 		// Early out for immediate line continuation
856 		if (line[0] == '\\') {
857 			continueLine = true;
858 			continue;
859 		}
860 		// Check end of line. It must be last character
861 		if (LastChar(line) == '\\') {
862 			continueLine = true;
863 			line.erase(line.end()-1);
864 			line = TrimSpace(line);
865 		}
866 
867 		// Re-trim after parsing
868 		line = TrimComment(line);
869 
870 		if (line.empty())
871 			continue;
872 
873 		// Finally... the value
874 		value += line;
875 
876 		if (continueLine)
877 			value += ' ';
878 	}
879 	while (continueLine && is >> std::ws && std::getline(is, line));
880 
881 	return true;
882 }
883 
OutputPair(const NameValuePairs & v,const char * name)884 void OutputPair(const NameValuePairs &v, const char *name)
885 {
886 	Integer x;
887 	bool b = v.GetValue(name, x);
888 	CRYPTOPP_UNUSED(b); CRYPTOPP_ASSERT(b);
889 	std::cout << name << ": \\\n    ";
890 	x.Encode(HexEncoder(new FileSink(std::cout), false, 64, "\\\n    ").Ref(), x.MinEncodedSize());
891 	std::cout << std::endl;
892 }
893 
OutputNameValuePairs(const NameValuePairs & v)894 void OutputNameValuePairs(const NameValuePairs &v)
895 {
896 	std::string names = v.GetValueNames();
897 	std::string::size_type i = 0;
898 	while (i < names.size())
899 	{
900 		std::string::size_type j = names.find_first_of (';', i);
901 
902 		if (j == std::string::npos)
903 			return;
904 		else
905 		{
906 			std::string name = names.substr(i, j-i);
907 			if (name.find(':') == std::string::npos)
908 				OutputPair(v, name.c_str());
909 		}
910 
911 		i = j + 1;
912 	}
913 }
914 
TestDataFile(std::string filename,const NameValuePairs & overrideParameters,unsigned int & totalTests,unsigned int & failedTests)915 void TestDataFile(std::string filename, const NameValuePairs &overrideParameters, unsigned int &totalTests, unsigned int &failedTests)
916 {
917 	filename = DataDir(filename);
918 	std::ifstream file(filename.c_str());
919 	if (!file.good())
920 		throw Exception(Exception::OTHER_ERROR, "Can not open file " + filename + " for reading");
921 
922 	TestData v;
923 	s_currentTestData = &v;
924 	std::string name, value, lastAlgName;
925 
926 	while (file)
927 	{
928 		if (!GetField(file, name, value))
929 			break;
930 
931 		if (name == "AlgorithmType")
932 			v.clear();
933 
934 		// Can't assert value. Plaintext is sometimes empty.
935 		// CRYPTOPP_ASSERT(!value.empty());
936 		v[name] = value;
937 
938 		if (name == "Test" && (s_thorough || v["SlowTest"] != "1"))
939 		{
940 			bool failed = true;
941 			std::string algType = GetRequiredDatum(v, "AlgorithmType");
942 
943 			if (lastAlgName != GetRequiredDatum(v, "Name"))
944 			{
945 				lastAlgName = GetRequiredDatum(v, "Name");
946 				std::cout << "\nTesting " << algType.c_str() << " algorithm " << lastAlgName.c_str() << ".\n";
947 			}
948 
949 			try
950 			{
951 				if (algType == "Signature")
952 					TestSignatureScheme(v);
953 				else if (algType == "SymmetricCipher")
954 					TestSymmetricCipher(v, overrideParameters);
955 				else if (algType == "AuthenticatedSymmetricCipher")
956 					TestAuthenticatedSymmetricCipher(v, overrideParameters);
957 				else if (algType == "AsymmetricCipher")
958 					TestAsymmetricCipher(v);
959 				else if (algType == "MessageDigest")
960 					TestDigestOrMAC(v, true);
961 				else if (algType == "MAC")
962 					TestDigestOrMAC(v, false);
963 				else if (algType == "KDF")
964 					TestKeyDerivationFunction(v);
965 				else if (algType == "FileList")
966 					TestDataFile(GetRequiredDatum(v, "Test"), g_nullNameValuePairs, totalTests, failedTests);
967 				else
968 					SignalUnknownAlgorithmError(algType);
969 				failed = false;
970 			}
971 			catch (const TestFailure &)
972 			{
973 				std::cout << "\nTest FAILED.\n";
974 			}
975 			catch (const CryptoPP::Exception &e)
976 			{
977 				std::cout << "\nCryptoPP::Exception caught: " << e.what() << std::endl;
978 			}
979 			catch (const std::exception &e)
980 			{
981 				std::cout << "\nstd::exception caught: " << e.what() << std::endl;
982 			}
983 
984 			if (failed)
985 			{
986 				std::cout << "Skipping to next test.\n";
987 				failedTests++;
988 			}
989 			else
990 				std::cout << "." << std::flush;
991 
992 			totalTests++;
993 		}
994 	}
995 }
996 
RunTestDataFile(const char * filename,const NameValuePairs & overrideParameters,bool thorough)997 bool RunTestDataFile(const char *filename, const NameValuePairs &overrideParameters, bool thorough)
998 {
999 	s_thorough = thorough;
1000 	unsigned int totalTests = 0, failedTests = 0;
1001 	TestDataFile((filename ? filename : ""), overrideParameters, totalTests, failedTests);
1002 
1003 	std::cout << std::dec << "\nTests complete. Total tests = " << totalTests << ". Failed tests = " << failedTests << "." << std::endl;
1004 	if (failedTests != 0)
1005 		std::cout << "SOME TESTS FAILED!\n";
1006 
1007 	CRYPTOPP_ASSERT(failedTests == 0);
1008 	return failedTests == 0;
1009 }
1010 
1011 NAMESPACE_END  // Test
1012 NAMESPACE_END  // CryptoPP
1013