1 // pssr.cpp - originally written and placed in the public domain by Wei Dai
2 
3 #include "pch.h"
4 #include "pssr.h"
5 #include "emsa2.h"
6 #include "ripemd.h"
7 #include "whrlpool.h"
8 #include "misc.h"
9 
10 #include <functional>
11 
12 NAMESPACE_BEGIN(CryptoPP)
13 
// One-byte hash identifiers used by EMSA2HashId for hashes whose headers are
// included by this file. NOTE(review): the values appear to follow the
// ISO/IEC 10118-3 dedicated hash-function identifiers (0x31 RIPEMD-160,
// 0x32 RIPEMD-128, 0x37 Whirlpool) -- confirm against emsa2.h / the standard.
template<> const byte EMSA2HashId<RIPEMD160>::id = 0x31;
template<> const byte EMSA2HashId<RIPEMD128>::id = 0x32;
template<> const byte EMSA2HashId<Whirlpool>::id = 0x37;
17 
18 #ifndef CRYPTOPP_IMPORTS
19 
MinRepresentativeBitLength(size_t hashIdentifierLength,size_t digestLength) const20 size_t PSSR_MEM_Base::MinRepresentativeBitLength(size_t hashIdentifierLength, size_t digestLength) const
21 {
22 	size_t saltLen = SaltLen(digestLength);
23 	size_t minPadLen = MinPadLen(digestLength);
24 	return 9 + 8*(minPadLen + saltLen + digestLength + hashIdentifierLength);
25 }
26 
MaxRecoverableLength(size_t representativeBitLength,size_t hashIdentifierLength,size_t digestLength) const27 size_t PSSR_MEM_Base::MaxRecoverableLength(size_t representativeBitLength, size_t hashIdentifierLength, size_t digestLength) const
28 {
29 	if (AllowRecovery())
30 		return SaturatingSubtract(representativeBitLength, MinRepresentativeBitLength(hashIdentifierLength, digestLength)) / 8;
31 	return 0;
32 }
33 
IsProbabilistic() const34 bool PSSR_MEM_Base::IsProbabilistic() const
35 {
36 	return SaltLen(1) > 0;
37 }
38 
AllowNonrecoverablePart() const39 bool PSSR_MEM_Base::AllowNonrecoverablePart() const
40 {
41 	return true;
42 }
43 
RecoverablePartFirst() const44 bool PSSR_MEM_Base::RecoverablePartFirst() const
45 {
46 	return false;
47 }
48 
// Encodes the recoverable message part into a PSSR message representative,
// written in place into `representative`.
//
// Byte layout produced, with n = BitsToBytes(representativeBitLength) and
// u = hashIdentifier.second + 1:
//   [0, n-u-digestSize)          DB = (00 .. 00 || 01 || recoverableMessage || salt) XOR MGF(H)
//   [n-u-digestSize, n-u)        H  = Hash(bitlen(M) as 8 big-endian bytes || M || mHash || salt)
//   [n-u, n-1)                   hash identifier bytes (only when one is supplied)
//   [n-1]                        trailer: 0xcc with a hash identifier, 0xbc without
// mHash is the digest finalized from whatever the caller already fed into `hash`
// (presumably the non-recoverable message part -- confirm at the call site).
// Finally the unused high-order bits of byte 0 are cleared.
//
// rng                       supplies the SaltLen(digestSize)-byte salt
// recoverableMessage        message bytes hashed into H and XORed into DB
// recoverableMessageLength  number of bytes at recoverableMessage
// hash                      finalized twice here, then passed to the MGF
// hashIdentifier            (pointer, length) of optional identifier bytes
// messageEmpty              not used by this encoding
// representative            output buffer; assumed to hold at least n bytes
// representativeBitLength   representative size in bits (asserted >= MinRepresentativeBitLength)
//
// NOTE(review): if hashIdentifier.first is NULL while hashIdentifier.second > 0,
// the u-1 bytes before the trailer are left untouched and the trailer is 0xbc,
// whereas RecoverMessageFromRepresentative expects 0xcc whenever .second != 0.
void PSSR_MEM_Base::ComputeMessageRepresentative(RandomNumberGenerator &rng,
	const byte *recoverableMessage, size_t recoverableMessageLength,
	HashTransformation &hash, HashIdentifier hashIdentifier, bool messageEmpty,
	byte *representative, size_t representativeBitLength) const
{
	CRYPTOPP_UNUSED(rng), CRYPTOPP_UNUSED(recoverableMessage), CRYPTOPP_UNUSED(recoverableMessageLength);
	CRYPTOPP_UNUSED(messageEmpty), CRYPTOPP_UNUSED(hashIdentifier);
	CRYPTOPP_ASSERT(representativeBitLength >= MinRepresentativeBitLength(hashIdentifier.second, hash.DigestSize()));

	const size_t u = hashIdentifier.second + 1;	// hash identifier bytes + 1 trailer byte
	const size_t representativeByteLength = BitsToBytes(representativeBitLength);
	const size_t digestSize = hash.DigestSize();
	const size_t saltSize = SaltLen(digestSize);
	// H is written directly into its final slot, just before the identifier/trailer.
	byte *const h = representative + representativeByteLength - u - digestSize;

	// mHash: finalize whatever has already been accumulated in `hash`.
	SecByteBlock digest(digestSize), salt(saltSize);
	hash.Final(digest);
	rng.GenerateBlock(salt, saltSize);

	// compute H = hash of M'
	// c = 64-bit big-endian bit length of the recoverable message (high word = len >> 29, low word = len << 3).
	byte c[8];
	PutWord(false, BIG_ENDIAN_ORDER, c, (word32)SafeRightShift<29>(recoverableMessageLength));
	PutWord(false, BIG_ENDIAN_ORDER, c+4, word32(recoverableMessageLength << 3));
	hash.Update(c, 8);
	hash.Update(recoverableMessage, recoverableMessageLength);
	hash.Update(digest, digestSize);
	hash.Update(salt, saltSize);
	hash.Final(h);

	// compute representative
	// Fill the whole DB region with MGF(H); the trailing `false` appears to select "write the mask"
	// rather than "XOR into existing contents", so the XORs below build DB XOR mask in place.
	GetMGF().GenerateAndMask(hash, representative, representativeByteLength - u - digestSize, h, digestSize, false);
	// xorStart is the 0x01 separator position: DB ends with 01 || M || salt.
	byte *xorStart = representative + representativeByteLength - u - digestSize - salt.size() - recoverableMessageLength - 1;
	xorStart[0] ^= 1;
	if (recoverableMessage && recoverableMessageLength)
		xorbuf(xorStart + 1, recoverableMessage, recoverableMessageLength);
	xorbuf(xorStart + 1 + recoverableMessageLength, salt, salt.size());
	if (hashIdentifier.first && hashIdentifier.second)
	{
		memcpy(representative + representativeByteLength - u, hashIdentifier.first, hashIdentifier.second);
		representative[representativeByteLength - 1] = 0xcc;
	}
	else
	{
		representative[representativeByteLength - 1] = 0xbc;
	}
	// Keep only the low (representativeBitLength % 8) bits of the leading byte.
	if (representativeBitLength % 8 != 0)
		representative[0] = (byte)Crop(representative[0], representativeBitLength % 8);
}
97 
RecoverMessageFromRepresentative(HashTransformation & hash,HashIdentifier hashIdentifier,bool messageEmpty,byte * representative,size_t representativeBitLength,byte * recoverableMessage) const98 DecodingResult PSSR_MEM_Base::RecoverMessageFromRepresentative(
99 	HashTransformation &hash, HashIdentifier hashIdentifier, bool messageEmpty,
100 	byte *representative, size_t representativeBitLength,
101 	byte *recoverableMessage) const
102 {
103 	CRYPTOPP_UNUSED(recoverableMessage), CRYPTOPP_UNUSED(messageEmpty), CRYPTOPP_UNUSED(hashIdentifier);
104 	CRYPTOPP_ASSERT(representativeBitLength >= MinRepresentativeBitLength(hashIdentifier.second, hash.DigestSize()));
105 
106 	const size_t u = hashIdentifier.second + 1;
107 	const size_t representativeByteLength = BitsToBytes(representativeBitLength);
108 	const size_t digestSize = hash.DigestSize();
109 	const size_t saltSize = SaltLen(digestSize);
110 	const byte *const h = representative + representativeByteLength - u - digestSize;
111 
112 	SecByteBlock digest(digestSize);
113 	hash.Final(digest);
114 
115 	DecodingResult result(0);
116 	bool &valid = result.isValidCoding;
117 	size_t &recoverableMessageLength = result.messageLength;
118 
119 	valid = (representative[representativeByteLength - 1] == (hashIdentifier.second ? 0xcc : 0xbc)) && valid;
120 
121 	if (hashIdentifier.first && hashIdentifier.second)
122 		valid = VerifyBufsEqual(representative + representativeByteLength - u, hashIdentifier.first, hashIdentifier.second) && valid;
123 
124 	GetMGF().GenerateAndMask(hash, representative, representativeByteLength - u - digestSize, h, digestSize);
125 	if (representativeBitLength % 8 != 0)
126 		representative[0] = (byte)Crop(representative[0], representativeBitLength % 8);
127 
128 	// extract salt and recoverableMessage from DB = 00 ... || 01 || M || salt
129 	byte *salt = representative + representativeByteLength - u - digestSize - saltSize;
130 	byte *M = FindIfNot(representative, salt-1, byte(0));
131 	recoverableMessageLength = salt-M-1;
132 	if (*M == 0x01 &&
133 	   (size_t)(M - representative - (representativeBitLength % 8 != 0)) >= MinPadLen(digestSize) &&
134 	   recoverableMessageLength <= MaxRecoverableLength(representativeBitLength, hashIdentifier.second, digestSize))
135 	{
136 		if (recoverableMessage)
137 			memcpy(recoverableMessage, M+1, recoverableMessageLength);
138 	}
139 	else
140 	{
141 		recoverableMessageLength = 0;
142 		valid = false;
143 	}
144 
145 	// verify H = hash of M'
146 	byte c[8];
147 	PutWord(false, BIG_ENDIAN_ORDER, c, (word32)SafeRightShift<29>(recoverableMessageLength));
148 	PutWord(false, BIG_ENDIAN_ORDER, c+4, word32(recoverableMessageLength << 3));
149 	hash.Update(c, 8);
150 	hash.Update(recoverableMessage, recoverableMessageLength);
151 	hash.Update(digest, digestSize);
152 	hash.Update(salt, saltSize);
153 	valid = hash.Verify(h) && valid;
154 
155 	if (!AllowRecovery() && valid && recoverableMessageLength != 0)
156 		{throw NotImplemented("PSSR_MEM: message recovery disabled");}
157 
158 	return result;
159 }
160 
161 #endif
162 
163 NAMESPACE_END
164