1 // pssr.cpp - originally written and placed in the public domain by Wei Dai 2 3 #include "pch.h" 4 #include "pssr.h" 5 #include "emsa2.h" 6 #include "ripemd.h" 7 #include "whrlpool.h" 8 #include "misc.h" 9 10 #include <functional> 11 12 NAMESPACE_BEGIN(CryptoPP) 13 14 template<> const byte EMSA2HashId<RIPEMD160>::id = 0x31; 15 template<> const byte EMSA2HashId<RIPEMD128>::id = 0x32; 16 template<> const byte EMSA2HashId<Whirlpool>::id = 0x37; 17 18 #ifndef CRYPTOPP_IMPORTS 19 20 size_t PSSR_MEM_Base::MinRepresentativeBitLength(size_t hashIdentifierLength, size_t digestLength) const 21 { 22 size_t saltLen = SaltLen(digestLength); 23 size_t minPadLen = MinPadLen(digestLength); 24 return 9 + 8*(minPadLen + saltLen + digestLength + hashIdentifierLength); 25 } 26 27 size_t PSSR_MEM_Base::MaxRecoverableLength(size_t representativeBitLength, size_t hashIdentifierLength, size_t digestLength) const 28 { 29 if (AllowRecovery()) 30 return SaturatingSubtract(representativeBitLength, MinRepresentativeBitLength(hashIdentifierLength, digestLength)) / 8; 31 return 0; 32 } 33 34 bool PSSR_MEM_Base::IsProbabilistic() const 35 { 36 return SaltLen(1) > 0; 37 } 38 39 bool PSSR_MEM_Base::AllowNonrecoverablePart() const 40 { 41 return true; 42 } 43 44 bool PSSR_MEM_Base::RecoverablePartFirst() const 45 { 46 return false; 47 } 48 49 void PSSR_MEM_Base::ComputeMessageRepresentative(RandomNumberGenerator &rng, 50 const byte *recoverableMessage, size_t recoverableMessageLength, 51 HashTransformation &hash, HashIdentifier hashIdentifier, bool messageEmpty, 52 byte *representative, size_t representativeBitLength) const 53 { 54 CRYPTOPP_UNUSED(rng), CRYPTOPP_UNUSED(recoverableMessage), CRYPTOPP_UNUSED(recoverableMessageLength); 55 CRYPTOPP_UNUSED(messageEmpty), CRYPTOPP_UNUSED(hashIdentifier); 56 CRYPTOPP_ASSERT(representativeBitLength >= MinRepresentativeBitLength(hashIdentifier.second, hash.DigestSize())); 57 58 const size_t u = hashIdentifier.second + 1; 59 const size_t representativeByteLength = BitsToBytes(representativeBitLength); 60 const size_t digestSize = hash.DigestSize(); 61 const size_t saltSize = SaltLen(digestSize); 62 byte *const h = representative + representativeByteLength - u - digestSize; 63 64 SecByteBlock digest(digestSize), salt(saltSize); 65 hash.Final(digest); 66 rng.GenerateBlock(salt, saltSize); 67 68 // compute H = hash of M' 69 byte c[8]; 70 PutWord(false, BIG_ENDIAN_ORDER, c, (word32)SafeRightShift<29>(recoverableMessageLength)); 71 PutWord(false, BIG_ENDIAN_ORDER, c+4, word32(recoverableMessageLength << 3)); 72 hash.Update(c, 8); 73 hash.Update(recoverableMessage, recoverableMessageLength); 74 hash.Update(digest, digestSize); 75 hash.Update(salt, saltSize); 76 hash.Final(h); 77 78 // compute representative 79 GetMGF().GenerateAndMask(hash, representative, representativeByteLength - u - digestSize, h, digestSize, false); 80 byte *xorStart = representative + representativeByteLength - u - digestSize - salt.size() - recoverableMessageLength - 1; 81 xorStart[0] ^= 1; 82 if (recoverableMessage && recoverableMessageLength) 83 xorbuf(xorStart + 1, recoverableMessage, recoverableMessageLength); 84 xorbuf(xorStart + 1 + recoverableMessageLength, salt, salt.size()); 85 if (hashIdentifier.first && hashIdentifier.second) 86 { 87 memcpy(representative + representativeByteLength - u, hashIdentifier.first, hashIdentifier.second); 88 representative[representativeByteLength - 1] = 0xcc; 89 } 90 else 91 { 92 representative[representativeByteLength - 1] = 0xbc; 93 } 94 if (representativeBitLength % 8 != 0) 95 representative[0] = (byte)Crop(representative[0], representativeBitLength % 8); 96 } 97 98 DecodingResult PSSR_MEM_Base::RecoverMessageFromRepresentative( 99 HashTransformation &hash, HashIdentifier hashIdentifier, bool messageEmpty, 100 byte *representative, size_t representativeBitLength, 101 byte *recoverableMessage) const 102 { 103 CRYPTOPP_UNUSED(recoverableMessage), CRYPTOPP_UNUSED(messageEmpty), CRYPTOPP_UNUSED(hashIdentifier); 104 CRYPTOPP_ASSERT(representativeBitLength >= MinRepresentativeBitLength(hashIdentifier.second, hash.DigestSize())); 105 106 const size_t u = hashIdentifier.second + 1; 107 const size_t representativeByteLength = BitsToBytes(representativeBitLength); 108 const size_t digestSize = hash.DigestSize(); 109 const size_t saltSize = SaltLen(digestSize); 110 const byte *const h = representative + representativeByteLength - u - digestSize; 111 112 SecByteBlock digest(digestSize); 113 hash.Final(digest); 114 115 DecodingResult result(0); 116 bool &valid = result.isValidCoding; 117 size_t &recoverableMessageLength = result.messageLength; 118 119 valid = (representative[representativeByteLength - 1] == (hashIdentifier.second ? 0xcc : 0xbc)) && valid; 120 121 if (hashIdentifier.first && hashIdentifier.second) 122 valid = VerifyBufsEqual(representative + representativeByteLength - u, hashIdentifier.first, hashIdentifier.second) && valid; 123 124 GetMGF().GenerateAndMask(hash, representative, representativeByteLength - u - digestSize, h, digestSize); 125 if (representativeBitLength % 8 != 0) 126 representative[0] = (byte)Crop(representative[0], representativeBitLength % 8); 127 128 // extract salt and recoverableMessage from DB = 00 ... || 01 || M || salt 129 byte *salt = representative + representativeByteLength - u - digestSize - saltSize; 130 byte *M = FindIfNot(representative, salt-1, byte(0)); 131 recoverableMessageLength = salt-M-1; 132 if (*M == 0x01 && 133 (size_t)(M - representative - (representativeBitLength % 8 != 0)) >= MinPadLen(digestSize) && 134 recoverableMessageLength <= MaxRecoverableLength(representativeBitLength, hashIdentifier.second, digestSize)) 135 { 136 if (recoverableMessage) 137 memcpy(recoverableMessage, M+1, recoverableMessageLength); 138 } 139 else 140 { 141 recoverableMessageLength = 0; 142 valid = false; 143 } 144 145 // verify H = hash of M' 146 byte c[8]; 147 PutWord(false, BIG_ENDIAN_ORDER, c, (word32)SafeRightShift<29>(recoverableMessageLength)); 148 PutWord(false, BIG_ENDIAN_ORDER, c+4, word32(recoverableMessageLength << 3)); 149 hash.Update(c, 8); 150 hash.Update(recoverableMessage, recoverableMessageLength); 151 hash.Update(digest, digestSize); 152 hash.Update(salt, saltSize); 153 valid = hash.Verify(h) && valid; 154 155 if (!AllowRecovery() && valid && recoverableMessageLength != 0) 156 {throw NotImplemented("PSSR_MEM: message recovery disabled");} 157 158 return result; 159 } 160 161 #endif 162 163 NAMESPACE_END 164