1 // Licensed to the .NET Foundation under one or more agreements.
2 // The .NET Foundation licenses this file to you under the MIT license.
3 // See the LICENSE file in the project root for more information.
4 
5 #include "pal_rsa.h"
6 #include "pal_utilities.h"
7 
CryptoNative_RsaCreate()8 extern "C" RSA* CryptoNative_RsaCreate()
9 {
10     return RSA_new();
11 }
12 
CryptoNative_RsaUpRef(RSA * rsa)13 extern "C" int32_t CryptoNative_RsaUpRef(RSA* rsa)
14 {
15     return RSA_up_ref(rsa);
16 }
17 
CryptoNative_RsaDestroy(RSA * rsa)18 extern "C" void CryptoNative_RsaDestroy(RSA* rsa)
19 {
20     if (rsa != nullptr)
21     {
22         RSA_free(rsa);
23     }
24 }
25 
CryptoNative_DecodeRsaPublicKey(const uint8_t * buf,int32_t len)26 extern "C" RSA* CryptoNative_DecodeRsaPublicKey(const uint8_t* buf, int32_t len)
27 {
28     if (!buf || !len)
29     {
30         return nullptr;
31     }
32 
33     return d2i_RSAPublicKey(nullptr, &buf, len);
34 }
35 
GetOpenSslPadding(RsaPadding padding)36 static int GetOpenSslPadding(RsaPadding padding)
37 {
38     assert(padding == Pkcs1 || padding == OaepSHA1);
39 
40     return padding == Pkcs1 ? RSA_PKCS1_PADDING : RSA_PKCS1_OAEP_PADDING;
41 }
42 
HasNoPrivateKey(RSA * rsa)43 static int HasNoPrivateKey(RSA* rsa)
44 {
45     if (rsa == nullptr)
46         return 1;
47 
48     // Shared pointer, don't free.
49     const RSA_METHOD* meth = RSA_get_method(rsa);
50 
51     // The method has descibed itself as having the private key external to the structure.
52     // That doesn't mean it's actually present, but we can't tell.
53     if (meth->flags & RSA_FLAG_EXT_PKEY)
54        return 0;
55 
56     // In the event that there's a middle-ground where we report failure when success is expected,
57     // one could do something like check if the RSA_METHOD intercepts all private key operations:
58     //
59     // * meth->rsa_priv_enc
60     // * meth->rsa_priv_dec
61     // * meth->rsa_sign (in 1.0.x this is only respected if the RSA_FLAG_SIGN_VER flag is asserted)
62     //
63     // But, for now, leave it at the EXT_PKEY flag test.
64 
65     // The module is documented as accepting either d or the full set of CRT parameters (p, q, dp, dq, qInv)
66     // So if we see d, we're good. Otherwise, if any of the rest are missing, we're public-only.
67     if (rsa->d != nullptr)
68         return 0;
69 
70     if (rsa->p == nullptr || rsa->q == nullptr || rsa->dmp1 == nullptr || rsa->dmq1 == nullptr || rsa->iqmp == nullptr)
71         return 1;
72 
73     return 0;
74 }
75 
76 extern "C" int32_t
CryptoNative_RsaPublicEncrypt(int32_t flen,const uint8_t * from,uint8_t * to,RSA * rsa,RsaPadding padding)77 CryptoNative_RsaPublicEncrypt(int32_t flen, const uint8_t* from, uint8_t* to, RSA* rsa, RsaPadding padding)
78 {
79     int openSslPadding = GetOpenSslPadding(padding);
80     return RSA_public_encrypt(flen, from, to, rsa, openSslPadding);
81 }
82 
83 extern "C" int32_t
CryptoNative_RsaPrivateDecrypt(int32_t flen,const uint8_t * from,uint8_t * to,RSA * rsa,RsaPadding padding)84 CryptoNative_RsaPrivateDecrypt(int32_t flen, const uint8_t* from, uint8_t* to, RSA* rsa, RsaPadding padding)
85 {
86     if (HasNoPrivateKey(rsa))
87     {
88         ERR_PUT_error(ERR_LIB_RSA, RSA_F_RSA_PRIVATE_DECRYPT, RSA_R_VALUE_MISSING, __FILE__, __LINE__);
89         return -1;
90     }
91 
92     int openSslPadding = GetOpenSslPadding(padding);
93     return RSA_private_decrypt(flen, from, to, rsa, openSslPadding);
94 }
95 
CryptoNative_RsaSize(RSA * rsa)96 extern "C" int32_t CryptoNative_RsaSize(RSA* rsa)
97 {
98     return RSA_size(rsa);
99 }
100 
CryptoNative_RsaGenerateKeyEx(RSA * rsa,int32_t bits,BIGNUM * e)101 extern "C" int32_t CryptoNative_RsaGenerateKeyEx(RSA* rsa, int32_t bits, BIGNUM* e)
102 {
103     return RSA_generate_key_ex(rsa, bits, e, nullptr);
104 }
105 
106 extern "C" int32_t
CryptoNative_RsaSign(int32_t type,const uint8_t * m,int32_t mlen,uint8_t * sigret,int32_t * siglen,RSA * rsa)107 CryptoNative_RsaSign(int32_t type, const uint8_t* m, int32_t mlen, uint8_t* sigret, int32_t* siglen, RSA* rsa)
108 {
109     if (siglen == nullptr)
110     {
111         assert(false);
112         return 0;
113     }
114 
115     *siglen = 0;
116 
117     if (HasNoPrivateKey(rsa))
118     {
119         ERR_PUT_error(ERR_LIB_RSA, RSA_F_RSA_SIGN, RSA_R_VALUE_MISSING, __FILE__, __LINE__);
120         return 0;
121     }
122 
123     // Shared pointer to the metadata about the message digest algorithm
124     const EVP_MD* digest = EVP_get_digestbynid(type);
125 
126     // If the digest itself isn't known then RSA_R_UNKNOWN_ALGORITHM_TYPE will get reported, but
127     // we have to check that the digest size matches what we expect.
128     if (digest != nullptr && mlen != EVP_MD_size(digest))
129     {
130         ERR_PUT_error(ERR_LIB_RSA, RSA_F_RSA_SIGN, RSA_R_INVALID_MESSAGE_LENGTH, __FILE__, __LINE__);
131         return 0;
132     }
133 
134     unsigned int unsignedSigLen = 0;
135     int32_t ret = RSA_sign(type, m, UnsignedCast(mlen), sigret, &unsignedSigLen, rsa);
136     assert(unsignedSigLen <= INT32_MAX);
137     *siglen = static_cast<int32_t>(unsignedSigLen);
138     return ret;
139 }
140 
141 extern "C" int32_t
CryptoNative_RsaVerify(int32_t type,const uint8_t * m,int32_t mlen,uint8_t * sigbuf,int32_t siglen,RSA * rsa)142 CryptoNative_RsaVerify(int32_t type, const uint8_t* m, int32_t mlen, uint8_t* sigbuf, int32_t siglen, RSA* rsa)
143 {
144     return RSA_verify(type, m, UnsignedCast(mlen), sigbuf, UnsignedCast(siglen), rsa);
145 }
146 
CryptoNative_GetRsaParameters(const RSA * rsa,BIGNUM ** n,BIGNUM ** e,BIGNUM ** d,BIGNUM ** p,BIGNUM ** dmp1,BIGNUM ** q,BIGNUM ** dmq1,BIGNUM ** iqmp)147 extern "C" int32_t CryptoNative_GetRsaParameters(const RSA* rsa,
148                                                  BIGNUM** n,
149                                                  BIGNUM** e,
150                                                  BIGNUM** d,
151                                                  BIGNUM** p,
152                                                  BIGNUM** dmp1,
153                                                  BIGNUM** q,
154                                                  BIGNUM** dmq1,
155                                                  BIGNUM** iqmp)
156 {
157     if (!rsa || !n || !e || !d || !p || !dmp1 || !q || !dmq1 || !iqmp)
158     {
159         assert(false);
160 
161         // since these parameters are 'out' parameters in managed code, ensure they are initialized
162         if (n)
163             *n = nullptr;
164         if (e)
165             *e = nullptr;
166         if (d)
167             *d = nullptr;
168         if (p)
169             *p = nullptr;
170         if (dmp1)
171             *dmp1 = nullptr;
172         if (q)
173             *q = nullptr;
174         if (dmq1)
175             *dmq1 = nullptr;
176         if (iqmp)
177             *iqmp = nullptr;
178 
179         return 0;
180     }
181 
182     *n = rsa->n;
183     *e = rsa->e;
184     *d = rsa->d;
185     *p = rsa->p;
186     *dmp1 = rsa->dmp1;
187     *q = rsa->q;
188     *dmq1 = rsa->dmq1;
189     *iqmp = rsa->iqmp;
190 
191     return 1;
192 }
193 
SetRsaParameter(BIGNUM ** rsaFieldAddress,uint8_t * buffer,int32_t bufferLength)194 static void SetRsaParameter(BIGNUM** rsaFieldAddress, uint8_t* buffer, int32_t bufferLength)
195 {
196     assert(rsaFieldAddress != nullptr);
197     if (rsaFieldAddress)
198     {
199         if (!buffer || !bufferLength)
200         {
201             *rsaFieldAddress = nullptr;
202         }
203         else
204         {
205             BIGNUM* bigNum = BN_bin2bn(buffer, bufferLength, nullptr);
206             *rsaFieldAddress = bigNum;
207         }
208     }
209 }
210 
CryptoNative_SetRsaParameters(RSA * rsa,uint8_t * n,int32_t nLength,uint8_t * e,int32_t eLength,uint8_t * d,int32_t dLength,uint8_t * p,int32_t pLength,uint8_t * dmp1,int32_t dmp1Length,uint8_t * q,int32_t qLength,uint8_t * dmq1,int32_t dmq1Length,uint8_t * iqmp,int32_t iqmpLength)211 extern "C" void CryptoNative_SetRsaParameters(RSA* rsa,
212                                               uint8_t* n,
213                                               int32_t nLength,
214                                               uint8_t* e,
215                                               int32_t eLength,
216                                               uint8_t* d,
217                                               int32_t dLength,
218                                               uint8_t* p,
219                                               int32_t pLength,
220                                               uint8_t* dmp1,
221                                               int32_t dmp1Length,
222                                               uint8_t* q,
223                                               int32_t qLength,
224                                               uint8_t* dmq1,
225                                               int32_t dmq1Length,
226                                               uint8_t* iqmp,
227                                               int32_t iqmpLength)
228 {
229     if (!rsa)
230     {
231         assert(false);
232         return;
233     }
234 
235     SetRsaParameter(&rsa->n, n, nLength);
236     SetRsaParameter(&rsa->e, e, eLength);
237     SetRsaParameter(&rsa->d, d, dLength);
238     SetRsaParameter(&rsa->p, p, pLength);
239     SetRsaParameter(&rsa->dmp1, dmp1, dmp1Length);
240     SetRsaParameter(&rsa->q, q, qLength);
241     SetRsaParameter(&rsa->dmq1, dmq1, dmq1Length);
242     SetRsaParameter(&rsa->iqmp, iqmp, iqmpLength);
243 }
244