1 // oaep.cpp - originally written and placed in the public domain by Wei Dai
2 
3 #include "pch.h"
4 
5 #ifndef CRYPTOPP_IMPORTS
6 
7 #include "oaep.h"
8 #include "stdcpp.h"
9 #include "smartptr.h"
10 
NAMESPACE_BEGIN(CryptoPP)11 NAMESPACE_BEGIN(CryptoPP)
12 
13 // ********************************************************
14 
15 size_t OAEP_Base::MaxUnpaddedLength(size_t paddedLength) const
16 {
17 	return SaturatingSubtract(paddedLength/8, 1+2*DigestSize());
18 }
19 
Pad(RandomNumberGenerator & rng,const byte * input,size_t inputLength,byte * oaepBlock,size_t oaepBlockLen,const NameValuePairs & parameters) const20 void OAEP_Base::Pad(RandomNumberGenerator &rng, const byte *input, size_t inputLength, byte *oaepBlock, size_t oaepBlockLen, const NameValuePairs &parameters) const
21 {
22 	CRYPTOPP_ASSERT (inputLength <= MaxUnpaddedLength(oaepBlockLen));
23 
24 	// convert from bit length to byte length
25 	if (oaepBlockLen % 8 != 0)
26 	{
27 		oaepBlock[0] = 0;
28 		oaepBlock++;
29 	}
30 	oaepBlockLen /= 8;
31 
32 	member_ptr<HashTransformation> pHash(NewHash());
33 	const size_t hLen = pHash->DigestSize();
34 	const size_t seedLen = hLen, dbLen = oaepBlockLen-seedLen;
35 	byte *const maskedSeed = oaepBlock;
36 	byte *const maskedDB = oaepBlock+seedLen;
37 
38 	ConstByteArrayParameter encodingParameters;
39 	parameters.GetValue(Name::EncodingParameters(), encodingParameters);
40 
41 	// DB = pHash || 00 ... || 01 || M
42 	pHash->CalculateDigest(maskedDB, encodingParameters.begin(), encodingParameters.size());
43 	memset(maskedDB+hLen, 0, dbLen-hLen-inputLength-1);
44 	maskedDB[dbLen-inputLength-1] = 0x01;
45 	memcpy(maskedDB+dbLen-inputLength, input, inputLength);
46 
47 	rng.GenerateBlock(maskedSeed, seedLen);
48 	member_ptr<MaskGeneratingFunction> pMGF(NewMGF());
49 	pMGF->GenerateAndMask(*pHash, maskedDB, dbLen, maskedSeed, seedLen);
50 	pMGF->GenerateAndMask(*pHash, maskedSeed, seedLen, maskedDB, dbLen);
51 }
52 
Unpad(const byte * oaepBlock,size_t oaepBlockLen,byte * output,const NameValuePairs & parameters) const53 DecodingResult OAEP_Base::Unpad(const byte *oaepBlock, size_t oaepBlockLen, byte *output, const NameValuePairs &parameters) const
54 {
55 	bool invalid = false;
56 
57 	// convert from bit length to byte length
58 	if (oaepBlockLen % 8 != 0)
59 	{
60 		invalid = (oaepBlock[0] != 0) || invalid;
61 		oaepBlock++;
62 	}
63 	oaepBlockLen /= 8;
64 
65 	member_ptr<HashTransformation> pHash(NewHash());
66 	const size_t hLen = pHash->DigestSize();
67 	const size_t seedLen = hLen, dbLen = oaepBlockLen-seedLen;
68 
69 	invalid = (oaepBlockLen < 2*hLen+1) || invalid;
70 
71 	SecByteBlock t(oaepBlock, oaepBlockLen);
72 	byte *const maskedSeed = t;
73 	byte *const maskedDB = t+seedLen;
74 
75 	member_ptr<MaskGeneratingFunction> pMGF(NewMGF());
76 	pMGF->GenerateAndMask(*pHash, maskedSeed, seedLen, maskedDB, dbLen);
77 	pMGF->GenerateAndMask(*pHash, maskedDB, dbLen, maskedSeed, seedLen);
78 
79 	ConstByteArrayParameter encodingParameters;
80 	parameters.GetValue(Name::EncodingParameters(), encodingParameters);
81 
82 	// DB = pHash' || 00 ... || 01 || M
83 	byte *M = std::find(maskedDB+hLen, maskedDB+dbLen, 0x01);
84 	invalid = (M == maskedDB+dbLen) || invalid;
85 	invalid = (FindIfNot(maskedDB+hLen, M, byte(0)) != M) || invalid;
86 	invalid = !pHash->VerifyDigest(maskedDB, encodingParameters.begin(), encodingParameters.size()) || invalid;
87 
88 	if (invalid)
89 		return DecodingResult();
90 
91 	M++;
92 	memcpy(output, M, maskedDB+dbLen-M);
93 	return DecodingResult(maskedDB+dbLen-M);
94 }
95 
96 NAMESPACE_END
97 
98 #endif
99