1 /* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
2 /* vim: set ts=2 et sw=2 tw=80: */
3 /* This Source Code Form is subject to the terms of the Mozilla Public
4  * License, v. 2.0. If a copy of the MPL was not distributed with this file,
5  * You can obtain one at http://mozilla.org/MPL/2.0/. */
6 
7 #include <string>
8 
9 #include "gtest/gtest.h"
10 
11 #include "scoped_ptrs.h"
12 #include "cryptohi.h"
13 #include "secitem.h"
14 #include "secerr.h"
15 
16 namespace nss_test {
17 
18 class SignParamsTestF : public ::testing::Test {
19  protected:
20   ScopedPLArenaPool arena_;
21   ScopedSECKEYPrivateKey privk_;
22   ScopedSECKEYPublicKey pubk_;
23   ScopedSECKEYPrivateKey ecPrivk_;
24   ScopedSECKEYPublicKey ecPubk_;
25 
SetUp()26   void SetUp() {
27     arena_.reset(PORT_NewArena(2048));
28 
29     SECKEYPublicKey *pubk;
30     SECKEYPrivateKey *privk = SECKEY_CreateRSAPrivateKey(1024, &pubk, NULL);
31     ASSERT_NE(nullptr, pubk);
32     pubk_.reset(pubk);
33     ASSERT_NE(nullptr, privk);
34     privk_.reset(privk);
35 
36     SECKEYECParams ecParams = {siBuffer, NULL, 0};
37     SECOidData *oidData;
38     oidData = SECOID_FindOIDByTag(SEC_OID_CURVE25519);
39     ASSERT_NE(nullptr, oidData);
40     ASSERT_NE(nullptr,
41               SECITEM_AllocItem(NULL, &ecParams, (2 + oidData->oid.len)))
42         << "Couldn't allocate memory for OID.";
43     ecParams.data[0] = SEC_ASN1_OBJECT_ID; /* we have to prepend 0x06 */
44     ecParams.data[1] = oidData->oid.len;
45     memcpy(ecParams.data + 2, oidData->oid.data, oidData->oid.len);
46     SECKEYPublicKey *ecPubk;
47     SECKEYPrivateKey *ecPrivk =
48         SECKEY_CreateECPrivateKey(&ecParams, &ecPubk, NULL);
49     ASSERT_NE(nullptr, ecPubk);
50     ecPubk_.reset(ecPubk);
51     ASSERT_NE(nullptr, ecPrivk);
52     ecPrivk_.reset(ecPrivk);
53   }
54 
CreatePssParams(SECKEYRSAPSSParams * params,SECOidTag hashAlgTag)55   void CreatePssParams(SECKEYRSAPSSParams *params, SECOidTag hashAlgTag) {
56     PORT_Memset(params, 0, sizeof(SECKEYRSAPSSParams));
57 
58     params->hashAlg = (SECAlgorithmID *)PORT_ArenaZAlloc(
59         arena_.get(), sizeof(SECAlgorithmID));
60     ASSERT_NE(nullptr, params->hashAlg);
61     SECStatus rv =
62         SECOID_SetAlgorithmID(arena_.get(), params->hashAlg, hashAlgTag, NULL);
63     ASSERT_EQ(SECSuccess, rv);
64   }
65 
CreatePssParams(SECKEYRSAPSSParams * params,SECOidTag hashAlgTag,SECOidTag maskHashAlgTag)66   void CreatePssParams(SECKEYRSAPSSParams *params, SECOidTag hashAlgTag,
67                        SECOidTag maskHashAlgTag) {
68     CreatePssParams(params, hashAlgTag);
69 
70     SECAlgorithmID maskHashAlg;
71     PORT_Memset(&maskHashAlg, 0, sizeof(maskHashAlg));
72     SECStatus rv =
73         SECOID_SetAlgorithmID(arena_.get(), &maskHashAlg, maskHashAlgTag, NULL);
74     ASSERT_EQ(SECSuccess, rv);
75 
76     SECItem *maskHashAlgItem =
77         SEC_ASN1EncodeItem(arena_.get(), NULL, &maskHashAlg,
78                            SEC_ASN1_GET(SECOID_AlgorithmIDTemplate));
79 
80     params->maskAlg = (SECAlgorithmID *)PORT_ArenaZAlloc(
81         arena_.get(), sizeof(SECAlgorithmID));
82     ASSERT_NE(nullptr, params->maskAlg);
83 
84     rv = SECOID_SetAlgorithmID(arena_.get(), params->maskAlg,
85                                SEC_OID_PKCS1_MGF1, maskHashAlgItem);
86     ASSERT_EQ(SECSuccess, rv);
87   }
88 
CreatePssParams(SECKEYRSAPSSParams * params,SECOidTag hashAlgTag,SECOidTag maskHashAlgTag,unsigned long saltLength)89   void CreatePssParams(SECKEYRSAPSSParams *params, SECOidTag hashAlgTag,
90                        SECOidTag maskHashAlgTag, unsigned long saltLength) {
91     CreatePssParams(params, hashAlgTag, maskHashAlgTag);
92 
93     SECItem *saltLengthItem =
94         SEC_ASN1EncodeInteger(arena_.get(), &params->saltLength, saltLength);
95     ASSERT_EQ(&params->saltLength, saltLengthItem);
96   }
97 
CheckHashAlg(SECKEYRSAPSSParams * params,SECOidTag hashAlgTag)98   void CheckHashAlg(SECKEYRSAPSSParams *params, SECOidTag hashAlgTag) {
99     // If hash algorithm is SHA-1, it must be omitted in the parameters
100     if (hashAlgTag == SEC_OID_SHA1) {
101       EXPECT_EQ(nullptr, params->hashAlg);
102     } else {
103       EXPECT_NE(nullptr, params->hashAlg);
104       EXPECT_EQ(hashAlgTag, SECOID_GetAlgorithmTag(params->hashAlg));
105     }
106   }
107 
CheckMaskAlg(SECKEYRSAPSSParams * params,SECOidTag hashAlgTag)108   void CheckMaskAlg(SECKEYRSAPSSParams *params, SECOidTag hashAlgTag) {
109     SECStatus rv;
110 
111     // If hash algorithm is SHA-1, it must be omitted in the parameters
112     if (hashAlgTag == SEC_OID_SHA1)
113       EXPECT_EQ(nullptr, params->hashAlg);
114     else {
115       EXPECT_NE(nullptr, params->maskAlg);
116       EXPECT_EQ(SEC_OID_PKCS1_MGF1, SECOID_GetAlgorithmTag(params->maskAlg));
117 
118       SECAlgorithmID hashAlg;
119       rv = SEC_QuickDERDecodeItem(arena_.get(), &hashAlg,
120                                   SEC_ASN1_GET(SECOID_AlgorithmIDTemplate),
121                                   &params->maskAlg->parameters);
122       ASSERT_EQ(SECSuccess, rv);
123 
124       EXPECT_EQ(hashAlgTag, SECOID_GetAlgorithmTag(&hashAlg));
125     }
126   }
127 
CheckSaltLength(SECKEYRSAPSSParams * params,SECOidTag hashAlg)128   void CheckSaltLength(SECKEYRSAPSSParams *params, SECOidTag hashAlg) {
129     // If the salt length parameter is missing, that means it is 20 (default)
130     if (!params->saltLength.data) {
131       return;
132     }
133 
134     unsigned long value;
135     SECStatus rv = SEC_ASN1DecodeInteger(&params->saltLength, &value);
136     ASSERT_EQ(SECSuccess, rv);
137 
138     // The salt length are usually the same as the hash length,
139     // except for the case where the hash length exceeds the limit
140     // set by the key length
141     switch (hashAlg) {
142       case SEC_OID_SHA1:
143         EXPECT_EQ(20UL, value);
144         break;
145       case SEC_OID_SHA224:
146         EXPECT_EQ(28UL, value);
147         break;
148       case SEC_OID_SHA256:
149         EXPECT_EQ(32UL, value);
150         break;
151       case SEC_OID_SHA384:
152         EXPECT_EQ(48UL, value);
153         break;
154       case SEC_OID_SHA512:
155         // Truncated from 64, because our private key is 1024-bit
156         EXPECT_EQ(62UL, value);
157         break;
158       default:
159         FAIL();
160     }
161   }
162 };
163 
164 class SignParamsTest
165     : public SignParamsTestF,
166       public ::testing::WithParamInterface<std::tuple<SECOidTag, SECOidTag>> {};
167 
168 class SignParamsSourceTest : public SignParamsTestF,
169                              public ::testing::WithParamInterface<SECOidTag> {};
170 
// PKCS#1 RSA signatures take no parameters: with no source parameters the
// result is null, otherwise it is a byte-for-byte copy of the source.
TEST_P(SignParamsTest, CreateRsa) {
  SECOidTag hashAlg = std::get<0>(GetParam());
  SECOidTag srcHashAlg = std::get<1>(GetParam());

  SECItem *srcParams = nullptr;
  if (srcHashAlg != SEC_OID_UNKNOWN) {
    SECKEYRSAPSSParams pssParams;
    ASSERT_NO_FATAL_FAILURE(
        CreatePssParams(&pssParams, srcHashAlg, srcHashAlg));
    srcParams = SEC_ASN1EncodeItem(arena_.get(), nullptr, &pssParams,
                                   SEC_ASN1_GET(SECKEY_RSAPSSParamsTemplate));
    ASSERT_NE(nullptr, srcParams);
  }

  SECItem *params = SEC_CreateSignatureAlgorithmParameters(
      arena_.get(), nullptr, SEC_OID_PKCS1_RSA_ENCRYPTION, hashAlg, srcParams,
      privk_.get());

  // PKCS#1 RSA actually doesn't take any parameters, but if it is
  // given, return a copy of it
  if (srcHashAlg != SEC_OID_UNKNOWN) {
    // params is dereferenced and memcmp reads srcParams->len bytes from it,
    // so both its presence and its length must hold before comparing.
    ASSERT_NE(nullptr, params);
    ASSERT_EQ(srcParams->len, params->len);
    EXPECT_EQ(0, memcmp(params->data, srcParams->data, srcParams->len));
  } else {
    EXPECT_EQ(nullptr, params);
  }
}
200 
// Creates RSA-PSS parameters for every (hashAlg, srcHashAlg) combination and
// checks the encoded hash, mask and salt-length fields.
TEST_P(SignParamsTest, CreateRsaPss) {
  SECOidTag hashAlg = std::get<0>(GetParam());
  SECOidTag srcHashAlg = std::get<1>(GetParam());

  SECItem *srcParams = nullptr;
  if (srcHashAlg != SEC_OID_UNKNOWN) {
    SECKEYRSAPSSParams pssParams;
    ASSERT_NO_FATAL_FAILURE(
        CreatePssParams(&pssParams, srcHashAlg, srcHashAlg));
    srcParams = SEC_ASN1EncodeItem(arena_.get(), nullptr, &pssParams,
                                   SEC_ASN1_GET(SECKEY_RSAPSSParamsTemplate));
    ASSERT_NE(nullptr, srcParams);
  }

  SECItem *params = SEC_CreateSignatureAlgorithmParameters(
      arena_.get(), nullptr, SEC_OID_PKCS1_RSA_PSS_SIGNATURE, hashAlg,
      srcParams, privk_.get());

  // A requested hash algorithm that contradicts the source parameters must be
  // rejected.
  if (hashAlg != SEC_OID_UNKNOWN && srcHashAlg != SEC_OID_UNKNOWN &&
      hashAlg != srcHashAlg) {
    EXPECT_EQ(nullptr, params);
    return;
  }

  // Everything below decodes params, so a null result must end the test.
  ASSERT_NE(nullptr, params);

  SECKEYRSAPSSParams pssParams;
  PORT_Memset(&pssParams, 0, sizeof(pssParams));
  SECStatus rv =
      SEC_QuickDERDecodeItem(arena_.get(), &pssParams,
                             SEC_ASN1_GET(SECKEY_RSAPSSParamsTemplate), params);
  ASSERT_EQ(SECSuccess, rv);

  if (hashAlg == SEC_OID_UNKNOWN) {
    // No hash was requested: recover the one that was chosen from the
    // encoded parameters, where an omitted hashAlg means SHA-1.
    if (!pssParams.hashAlg) {
      hashAlg = SEC_OID_SHA1;
    } else {
      hashAlg = SECOID_GetAlgorithmTag(pssParams.hashAlg);
    }

    if (srcHashAlg == SEC_OID_UNKNOWN) {
      // If both hashAlg and srcHashAlg is unset, NSS will decide the hash
      // algorithm based on the key length; in this case it's SHA256
      EXPECT_EQ(SEC_OID_SHA256, hashAlg);
    } else {
      EXPECT_EQ(srcHashAlg, hashAlg);
    }
  }

  ASSERT_NO_FATAL_FAILURE(CheckHashAlg(&pssParams, hashAlg));
  ASSERT_NO_FATAL_FAILURE(CheckMaskAlg(&pssParams, hashAlg));
  ASSERT_NO_FATAL_FAILURE(CheckSaltLength(&pssParams, hashAlg));

  // The default trailer field (1) must be omitted
  EXPECT_EQ(nullptr, pssParams.trailerField.data);
}
259 
// RSA-PSS parameters must never be produced for an EC private key, whatever
// hash algorithm or source parameters are supplied.
TEST_P(SignParamsTest, CreateRsaPssWithECPrivateKey) {
  const SECOidTag requestedHash = std::get<0>(GetParam());
  const SECOidTag sourceHash = std::get<1>(GetParam());

  SECItem *sourceParams = NULL;
  if (sourceHash != SEC_OID_UNKNOWN) {
    SECKEYRSAPSSParams sourcePss;
    ASSERT_NO_FATAL_FAILURE(
        CreatePssParams(&sourcePss, sourceHash, sourceHash));
    sourceParams =
        SEC_ASN1EncodeItem(arena_.get(), nullptr, &sourcePss,
                           SEC_ASN1_GET(SECKEY_RSAPSSParamsTemplate));
    ASSERT_NE(nullptr, sourceParams);
  }

  EXPECT_EQ(nullptr, SEC_CreateSignatureAlgorithmParameters(
                         arena_.get(), nullptr,
                         SEC_OID_PKCS1_RSA_PSS_SIGNATURE, requestedHash,
                         sourceParams, ecPrivk_.get()));
}
282 
// Requesting MD5 as the RSA-PSS hash must fail, with or without source
// parameters. Only the source-hash half of the test parameter is used.
TEST_P(SignParamsTest, CreateRsaPssWithInvalidHashAlg) {
  const SECOidTag sourceHash = std::get<1>(GetParam());

  SECItem *sourceParams = NULL;
  if (sourceHash != SEC_OID_UNKNOWN) {
    SECKEYRSAPSSParams sourcePss;
    ASSERT_NO_FATAL_FAILURE(
        CreatePssParams(&sourcePss, sourceHash, sourceHash));
    sourceParams =
        SEC_ASN1EncodeItem(arena_.get(), nullptr, &sourcePss,
                           SEC_ASN1_GET(SECKEY_RSAPSSParamsTemplate));
    ASSERT_NE(nullptr, sourceParams);
  }

  SECItem *const result = SEC_CreateSignatureAlgorithmParameters(
      arena_.get(), nullptr, SEC_OID_PKCS1_RSA_PSS_SIGNATURE, SEC_OID_MD5,
      sourceParams, privk_.get());
  EXPECT_EQ(nullptr, result);
}
304 
// Source parameters naming MD5 as both hash and MGF1 hash must be rejected,
// whatever hash algorithm is requested.
TEST_P(SignParamsSourceTest, CreateRsaPssWithInvalidHashAlg) {
  SECKEYRSAPSSParams md5Pss;
  ASSERT_NO_FATAL_FAILURE(CreatePssParams(&md5Pss, SEC_OID_MD5, SEC_OID_MD5));

  SECItem *const sourceParams =
      SEC_ASN1EncodeItem(arena_.get(), nullptr, &md5Pss,
                         SEC_ASN1_GET(SECKEY_RSAPSSParamsTemplate));
  ASSERT_NE(nullptr, sourceParams);

  const SECOidTag requestedHash = GetParam();
  EXPECT_EQ(nullptr, SEC_CreateSignatureAlgorithmParameters(
                         arena_.get(), nullptr,
                         SEC_OID_PKCS1_RSA_PSS_SIGNATURE, requestedHash,
                         sourceParams, privk_.get()));
}
322 
// Source parameters with a 100-byte salt for SHA-512 must be rejected. That is
// more than the fixture's 1024-bit RSA key allows (62 bytes for SHA-512).
TEST_P(SignParamsSourceTest, CreateRsaPssWithInvalidSaltLength) {
  SECKEYRSAPSSParams oversizedSaltPss;
  ASSERT_NO_FATAL_FAILURE(CreatePssParams(&oversizedSaltPss, SEC_OID_SHA512,
                                          SEC_OID_SHA512, 100));

  SECItem *const sourceParams =
      SEC_ASN1EncodeItem(arena_.get(), nullptr, &oversizedSaltPss,
                         SEC_ASN1_GET(SECKEY_RSAPSSParamsTemplate));
  ASSERT_NE(nullptr, sourceParams);

  const SECOidTag requestedHash = GetParam();
  EXPECT_EQ(nullptr, SEC_CreateSignatureAlgorithmParameters(
                         arena_.get(), nullptr,
                         SEC_OID_PKCS1_RSA_PSS_SIGNATURE, requestedHash,
                         sourceParams, privk_.get()));
}
340 
// Source parameters whose hash (SHA-256) differs from their MGF1 hash
// (SHA-512) must be rejected, whatever hash algorithm is requested.
TEST_P(SignParamsSourceTest, CreateRsaPssWithHashMismatch) {
  SECKEYRSAPSSParams mismatchedPss;
  ASSERT_NO_FATAL_FAILURE(
      CreatePssParams(&mismatchedPss, SEC_OID_SHA256, SEC_OID_SHA512));

  SECItem *const sourceParams =
      SEC_ASN1EncodeItem(arena_.get(), nullptr, &mismatchedPss,
                         SEC_ASN1_GET(SECKEY_RSAPSSParamsTemplate));
  ASSERT_NE(nullptr, sourceParams);

  const SECOidTag requestedHash = GetParam();
  EXPECT_EQ(nullptr, SEC_CreateSignatureAlgorithmParameters(
                         arena_.get(), nullptr,
                         SEC_OID_PKCS1_RSA_PSS_SIGNATURE, requestedHash,
                         sourceParams, privk_.get()));
}
358 
359 INSTANTIATE_TEST_CASE_P(
360     SignParamsTestCases, SignParamsTest,
361     ::testing::Combine(::testing::Values(SEC_OID_UNKNOWN, SEC_OID_SHA1,
362                                          SEC_OID_SHA224, SEC_OID_SHA256,
363                                          SEC_OID_SHA384, SEC_OID_SHA512),
364                        ::testing::Values(SEC_OID_UNKNOWN, SEC_OID_SHA1,
365                                          SEC_OID_SHA224, SEC_OID_SHA256,
366                                          SEC_OID_SHA384, SEC_OID_SHA512)));
367 
368 INSTANTIATE_TEST_CASE_P(SignParamsSourceTestCases, SignParamsSourceTest,
369                         ::testing::Values(SEC_OID_UNKNOWN, SEC_OID_SHA1,
370                                           SEC_OID_SHA224, SEC_OID_SHA256,
371                                           SEC_OID_SHA384, SEC_OID_SHA512));
372 
373 }  // namespace nss_test
374