1 // oaep.cpp - originally written and placed in the public domain by Wei Dai
2
3 #include "pch.h"
4
5 #ifndef CRYPTOPP_IMPORTS
6
7 #include "oaep.h"
8 #include "stdcpp.h"
9 #include "smartptr.h"
10
NAMESPACE_BEGIN(CryptoPP)11 NAMESPACE_BEGIN(CryptoPP)
12
13 // ********************************************************
14
15 size_t OAEP_Base::MaxUnpaddedLength(size_t paddedLength) const
16 {
17 return SaturatingSubtract(paddedLength/8, 1+2*DigestSize());
18 }
19
Pad(RandomNumberGenerator & rng,const byte * input,size_t inputLength,byte * oaepBlock,size_t oaepBlockLen,const NameValuePairs & parameters) const20 void OAEP_Base::Pad(RandomNumberGenerator &rng, const byte *input, size_t inputLength, byte *oaepBlock, size_t oaepBlockLen, const NameValuePairs ¶meters) const
21 {
22 CRYPTOPP_ASSERT (inputLength <= MaxUnpaddedLength(oaepBlockLen));
23
24 // convert from bit length to byte length
25 if (oaepBlockLen % 8 != 0)
26 {
27 oaepBlock[0] = 0;
28 oaepBlock++;
29 }
30 oaepBlockLen /= 8;
31
32 member_ptr<HashTransformation> pHash(NewHash());
33 const size_t hLen = pHash->DigestSize();
34 const size_t seedLen = hLen, dbLen = oaepBlockLen-seedLen;
35 byte *const maskedSeed = oaepBlock;
36 byte *const maskedDB = oaepBlock+seedLen;
37
38 ConstByteArrayParameter encodingParameters;
39 parameters.GetValue(Name::EncodingParameters(), encodingParameters);
40
41 // DB = pHash || 00 ... || 01 || M
42 pHash->CalculateDigest(maskedDB, encodingParameters.begin(), encodingParameters.size());
43 memset(maskedDB+hLen, 0, dbLen-hLen-inputLength-1);
44 maskedDB[dbLen-inputLength-1] = 0x01;
45 memcpy(maskedDB+dbLen-inputLength, input, inputLength);
46
47 rng.GenerateBlock(maskedSeed, seedLen);
48 member_ptr<MaskGeneratingFunction> pMGF(NewMGF());
49 pMGF->GenerateAndMask(*pHash, maskedDB, dbLen, maskedSeed, seedLen);
50 pMGF->GenerateAndMask(*pHash, maskedSeed, seedLen, maskedDB, dbLen);
51 }
52
Unpad(const byte * oaepBlock,size_t oaepBlockLen,byte * output,const NameValuePairs & parameters) const53 DecodingResult OAEP_Base::Unpad(const byte *oaepBlock, size_t oaepBlockLen, byte *output, const NameValuePairs ¶meters) const
54 {
55 bool invalid = false;
56
57 // convert from bit length to byte length
58 if (oaepBlockLen % 8 != 0)
59 {
60 invalid = (oaepBlock[0] != 0) || invalid;
61 oaepBlock++;
62 }
63 oaepBlockLen /= 8;
64
65 member_ptr<HashTransformation> pHash(NewHash());
66 const size_t hLen = pHash->DigestSize();
67 const size_t seedLen = hLen, dbLen = oaepBlockLen-seedLen;
68
69 invalid = (oaepBlockLen < 2*hLen+1) || invalid;
70
71 SecByteBlock t(oaepBlock, oaepBlockLen);
72 byte *const maskedSeed = t;
73 byte *const maskedDB = t+seedLen;
74
75 member_ptr<MaskGeneratingFunction> pMGF(NewMGF());
76 pMGF->GenerateAndMask(*pHash, maskedSeed, seedLen, maskedDB, dbLen);
77 pMGF->GenerateAndMask(*pHash, maskedDB, dbLen, maskedSeed, seedLen);
78
79 ConstByteArrayParameter encodingParameters;
80 parameters.GetValue(Name::EncodingParameters(), encodingParameters);
81
82 // DB = pHash' || 00 ... || 01 || M
83 byte *M = std::find(maskedDB+hLen, maskedDB+dbLen, 0x01);
84 invalid = (M == maskedDB+dbLen) || invalid;
85 invalid = (FindIfNot(maskedDB+hLen, M, byte(0)) != M) || invalid;
86 invalid = !pHash->VerifyDigest(maskedDB, encodingParameters.begin(), encodingParameters.size()) || invalid;
87
88 if (invalid)
89 return DecodingResult();
90
91 M++;
92 memcpy(output, M, maskedDB+dbLen-M);
93 return DecodingResult(maskedDB+dbLen-M);
94 }
95
96 NAMESPACE_END
97
98 #endif
99