1 // This Source Code Form is subject to the terms of the Mozilla Public
2 // License, v. 2.0. If a copy of the MPL was not distributed with this file,
3 // You can obtain one at http://mozilla.org/MPL/2.0/.
4 
#include "gtest/gtest.h"

#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>

#include <iomanip>
#include <iostream>
#include <memory>
#include <string>
#include <vector>

#ifdef __MACH__
#include <mach/clock.h>
#include <mach/mach.h>
#endif

#include "mplogic.h"
#include "mpi.h"
18 namespace nss_test {
19 
gettime(struct timespec * tp)20 void gettime(struct timespec* tp) {
21 #ifdef __MACH__
22   clock_serv_t cclock;
23   mach_timespec_t mts;
24 
25   host_get_clock_service(mach_host_self(), SYSTEM_CLOCK, &cclock);
26   clock_get_time(cclock, &mts);
27   mach_port_deallocate(mach_task_self(), cclock);
28 
29   tp->tv_sec = mts.tv_sec;
30   tp->tv_nsec = mts.tv_nsec;
31 #else
32   ASSERT_NE(0, timespec_get(tp, TIME_UTC));
33 #endif
34 }
35 
36 class MPITest : public ::testing::Test {
37  protected:
TestCmp(const std::string a_string,const std::string b_string,int result)38   void TestCmp(const std::string a_string, const std::string b_string,
39                int result) {
40     mp_int a, b;
41     MP_DIGITS(&a) = 0;
42     MP_DIGITS(&b) = 0;
43     ASSERT_EQ(MP_OKAY, mp_init(&a));
44     ASSERT_EQ(MP_OKAY, mp_init(&b));
45 
46     mp_read_radix(&a, a_string.c_str(), 16);
47     mp_read_radix(&b, b_string.c_str(), 16);
48     EXPECT_EQ(result, mp_cmp(&a, &b));
49 
50     mp_clear(&a);
51     mp_clear(&b);
52   }
53 
TestDiv(const std::string a_string,const std::string b_string,const std::string result)54   void TestDiv(const std::string a_string, const std::string b_string,
55                const std::string result) {
56     mp_int a, b, c;
57     MP_DIGITS(&a) = 0;
58     MP_DIGITS(&b) = 0;
59     MP_DIGITS(&c) = 0;
60     ASSERT_EQ(MP_OKAY, mp_init(&a));
61     ASSERT_EQ(MP_OKAY, mp_init(&b));
62     ASSERT_EQ(MP_OKAY, mp_init(&c));
63 
64     mp_read_radix(&a, a_string.c_str(), 16);
65     mp_read_radix(&b, b_string.c_str(), 16);
66     mp_read_radix(&c, result.c_str(), 16);
67     EXPECT_EQ(MP_OKAY, mp_div(&a, &b, &a, &b));
68     EXPECT_EQ(0, mp_cmp(&a, &c));
69 
70     mp_clear(&a);
71     mp_clear(&b);
72     mp_clear(&c);
73   }
74 
dump(const std::string & prefix,const uint8_t * buf,size_t len)75   void dump(const std::string& prefix, const uint8_t* buf, size_t len) {
76     auto flags = std::cerr.flags();
77     std::cerr << prefix << ": [" << std::dec << len << "] ";
78     for (size_t i = 0; i < len; ++i) {
79       std::cerr << std::hex << std::setw(2) << std::setfill('0')
80                 << static_cast<int>(buf[i]);
81     }
82     std::cerr << std::endl << std::resetiosflags(flags);
83   }
84 
TestToFixedOctets(const std::vector<uint8_t> & ref,size_t len)85   void TestToFixedOctets(const std::vector<uint8_t>& ref, size_t len) {
86     mp_int a;
87     ASSERT_EQ(MP_OKAY, mp_init(&a));
88     ASSERT_EQ(MP_OKAY, mp_read_unsigned_octets(&a, ref.data(), ref.size()));
89     std::unique_ptr<uint8_t[]> buf(new uint8_t[len]);
90     ASSERT_NE(buf, nullptr);
91     ASSERT_EQ(MP_OKAY, mp_to_fixlen_octets(&a, buf.get(), len));
92     size_t compare;
93     if (len > ref.size()) {
94       for (size_t i = 0; i < len - ref.size(); ++i) {
95         ASSERT_EQ(0U, buf[i]) << "index " << i << " should be zero";
96       }
97       compare = ref.size();
98     } else {
99       compare = len;
100     }
101     dump("value", ref.data(), ref.size());
102     dump("output", buf.get(), len);
103     ASSERT_EQ(0, memcmp(buf.get() + len - compare,
104                         ref.data() + ref.size() - compare, compare))
105         << "comparing " << compare << " octets";
106     mp_clear(&a);
107   }
108 };
109 
// Single-digit comparisons: mp_cmp(a, b) must report -1, 1 or 0.
TEST_F(MPITest, MpiCmp01Test) {
  TestCmp("0", "1", -1);
}
TEST_F(MPITest, MpiCmp10Test) {
  TestCmp("1", "0", 1);
}
TEST_F(MPITest, MpiCmp00Test) {
  TestCmp("0", "0", 0);
}
TEST_F(MPITest, MpiCmp11Test) {
  TestCmp("1", "1", 0);
}

// Divides FFFF00FFFFFFFF000000000000 by FFFF00FFFFFFFFFF and expects the
// quotient FFFFFFFFFF.
TEST_F(MPITest, MpiDiv32ErrorTest) {
  TestDiv("FFFF00FFFFFFFF000000000000", "FFFF00FFFFFFFFFF", "FFFFFFFFFF");
}
117 
#ifdef NSS_X64
// This test assumes 64-bit mp_digits: it edits the two digits of |b| directly.
TEST_F(MPITest, MpiCmpUnalignedTest) {
  mp_int a, b, c, expect;
  MP_DIGITS(&a) = 0;
  MP_DIGITS(&b) = 0;
  MP_DIGITS(&c) = 0;
  MP_DIGITS(&expect) = 0;
  ASSERT_EQ(MP_OKAY, mp_init(&a));
  ASSERT_EQ(MP_OKAY, mp_init(&b));
  ASSERT_EQ(MP_OKAY, mp_init(&c));
  ASSERT_EQ(MP_OKAY, mp_init(&expect));

  ASSERT_EQ(MP_OKAY, mp_read_radix(&a, "ffffffffffffffff3b4e802b4e1478", 16));
  ASSERT_EQ(MP_OKAY, mp_read_radix(&b, "ffffffffffffffff3b4e802b4e1478", 16));
  EXPECT_EQ(0, mp_cmp(&a, &b));

  // Overwrite b's digits directly (bypassing the parser). Its low digit keeps
  // a's low digit with the top byte cleared and its high digit becomes all
  // ones, so the digit boundary ('|') falls at a different place in the hex
  // strings and b is now the larger value:
  //   a = ffffffffffffff|ff3b4e802b4e1478
  //   b = ffffffffffffffff|003b4e802b4e1478
  MP_DIGITS(&b)[0] &= 0x00ffffffffffffff;
  MP_DIGITS(&b)[1] = 0xffffffffffffffff;
  EXPECT_EQ(-1, mp_cmp(&a, &b));

  // c = a - b = -(b - a) = -feffffffffffffff|0100000000000000.
  // (The previous check, ASSERT_TRUE(strncmp(...)), passed whenever the
  // strings differed and so verified nothing.)
  ASSERT_EQ(MP_OKAY, mp_sub(&a, &b, &c));
  ASSERT_EQ(MP_OKAY, mp_read_radix(&expect,
                                   "-feffffffffffffff"
                                   "0100000000000000",
                                   16));
  EXPECT_EQ(0, mp_cmp(&c, &expect));

  // The hex conversion must round-trip to the same value. Parsing the string
  // back keeps this independent of the letter case mp_toradix emits.
  char c_tmp[40];
  ASSERT_LE(mp_radix_size(&c, 16), static_cast<int>(sizeof(c_tmp)));
  ASSERT_EQ(MP_OKAY, mp_toradix(&c, c_tmp, 16));
  ASSERT_EQ(MP_OKAY, mp_read_radix(&expect, c_tmp, 16));
  EXPECT_EQ(0, mp_cmp(&c, &expect));

  mp_clear(&a);
  mp_clear(&b);
  mp_clear(&c);
  mp_clear(&expect);
}
#endif
151 
// The following two tests check that the closely related mp_set_* functions
// behave correctly.
// mp_set_ulong must accept 1, 0 and the largest unsigned long, and store them.
TEST_F(MPITest, MpiSetUlong) {
  mp_int a, b, c;
  MP_DIGITS(&a) = 0;
  MP_DIGITS(&b) = 0;
  MP_DIGITS(&c) = 0;
  ASSERT_EQ(MP_OKAY, mp_init(&a));
  ASSERT_EQ(MP_OKAY, mp_init(&b));
  ASSERT_EQ(MP_OKAY, mp_init(&c));
  EXPECT_EQ(MP_OKAY, mp_set_ulong(&a, 1));
  EXPECT_EQ(MP_OKAY, mp_set_ulong(&b, 0));
  // -1 converts to the maximum unsigned long value.
  EXPECT_EQ(MP_OKAY, mp_set_ulong(&c, -1));

  // Verify the stored values, not just the return codes.
  EXPECT_EQ(MP_EQ, mp_cmp_d(&a, 1));
  EXPECT_EQ(MP_EQ, mp_cmp_z(&b));
  EXPECT_EQ(MP_GT, mp_cmp_d(&c, 1));

  mp_clear(&a);
  mp_clear(&b);
  mp_clear(&c);
}
169 
// mp_set_int must accept 1, 0 and -1, and store them with the right sign.
TEST_F(MPITest, MpiSetInt) {
  mp_int a, b, c;
  MP_DIGITS(&a) = 0;
  MP_DIGITS(&b) = 0;
  MP_DIGITS(&c) = 0;
  ASSERT_EQ(MP_OKAY, mp_init(&a));
  ASSERT_EQ(MP_OKAY, mp_init(&b));
  ASSERT_EQ(MP_OKAY, mp_init(&c));
  EXPECT_EQ(MP_OKAY, mp_set_int(&a, 1));
  EXPECT_EQ(MP_OKAY, mp_set_int(&b, 0));
  EXPECT_EQ(MP_OKAY, mp_set_int(&c, -1));

  // Verify the stored values, not just the return codes.
  EXPECT_EQ(MP_EQ, mp_cmp_d(&a, 1));
  EXPECT_EQ(MP_EQ, mp_cmp_z(&b));
  // c must be negative with the same magnitude as a (i.e. c == -1).
  EXPECT_EQ(MP_LT, mp_cmp_z(&c));
  EXPECT_EQ(MP_EQ, mp_cmp_mag(&c, &a));

  mp_clear(&a);
  mp_clear(&b);
  mp_clear(&c);
}
186 
// A single zero octet must serialize to an all-zero buffer at every output
// length, including lengths on either side of one mp_digit.
TEST_F(MPITest, MpiFixlenOctetsZero) {
  const std::vector<uint8_t> zero{0};
  const size_t lengths[] = {1, 2, sizeof(mp_digit), sizeof(mp_digit) + 1};
  for (size_t out_len : lengths) {
    TestToFixedOctets(zero, out_len);
  }
}
194 
// mp_radix_size must leave room for the '-' sign and the terminating NUL of a
// large negative decimal value, so mp_toradix can print it into a buffer of
// exactly that size.
TEST_F(MPITest, MpiRadixSizeNeg) {
  mp_int a;
  mp_err rv;
  const char* negative_edge =
      "-5400000000000000003000000002200020090919017007777777777870000090"
      "00000000007500443416610000000000000000000000000000000000000000000"
      "00000000000000000000000000000000000000000000000000000000075049054"
      "18610000800555594485440016000031555550000000000000000220030200909"
      "19017007777777700000000000000000000000000000000000000000000000000"
      "00000000000500000000000000000000000000004668129841661000071000000"
      "00000000000000000000000000000000000000000000000007504434166100000"
      "00000000000000000000000000000000000000000000000000000000000000000"
      "00000000075049054186100008005555944854400184572169555500000000000"
      "0000022003020090919017007777777700000000000000000000";

  MP_DIGITS(&a) = 0;
  rv = mp_init(&a);
  ASSERT_EQ(MP_OKAY, rv);
  rv = mp_read_variable_radix(&a, negative_edge, 10);
  ASSERT_EQ(MP_OKAY, rv);

  // A size of 0 signals an error. Checking that the estimate covers the known
  // string (plus NUL) catches an undersized result deterministically, instead
  // of relying on a sanitizer to flag the overflow inside mp_toradix.
  const int radixSize = mp_radix_size(&a, 10);
  ASSERT_GT(radixSize, 0);
  ASSERT_LE(strlen(negative_edge) + 1, static_cast<size_t>(radixSize));

  std::vector<char> str(static_cast<size_t>(radixSize));
  rv = mp_toradix(&a, str.data(), 10);
  mp_clear(&a);
  ASSERT_EQ(MP_OKAY, rv);
  EXPECT_STREQ(negative_edge, str.data());
}
227 
// Grows a value one octet at a time, up to two mp_digits' worth, and checks
// exact-size, one-octet-larger and one-digit-larger output buffers.
TEST_F(MPITest, MpiFixlenOctetsVarlen) {
  const size_t max_octets = sizeof(mp_digit) * 2;
  std::vector<uint8_t> packed;
  while (packed.size() < max_octets) {
    packed.push_back(0xa4);  // Any non-zero value will do.
    const size_t current = packed.size();
    TestToFixedOctets(packed, current);
    TestToFixedOctets(packed, current + 1);
    TestToFixedOctets(packed, current + sizeof(mp_digit));
  }
}
237 
// For values of 1 to 2*sizeof(mp_digit) non-zero octets, serializing into a
// buffer one octet too short must be rejected, while the exact size must
// succeed and reproduce the input.
TEST_F(MPITest, MpiFixlenOctetsTooSmall) {
  uint8_t buf[sizeof(mp_digit) * 3];
  std::vector<uint8_t> ref;
  while (ref.size() < sizeof(mp_digit) * 2) {
    ref.push_back(3);  // Any non-zero value will do.
    const size_t exact = ref.size();
    dump("ref", ref.data(), exact);

    mp_int a;
    ASSERT_EQ(MP_OKAY, mp_init(&a));
    ASSERT_EQ(MP_OKAY, mp_read_unsigned_octets(&a, ref.data(), exact));
#ifdef DEBUG
    // ARGCHK maps to assert() in a debug build.
    EXPECT_DEATH(mp_to_fixlen_octets(&a, buf, exact - 1), "");
#else
    EXPECT_EQ(MP_BADARG, mp_to_fixlen_octets(&a, buf, exact - 1));
#endif
    ASSERT_EQ(MP_OKAY, mp_to_fixlen_octets(&a, buf, exact));
    ASSERT_EQ(0, memcmp(buf, ref.data(), exact));

    mp_clear(&a);
  }
}
260 
// Squares a = 2^(n_digits * bits_per_digit - 1), a value with only its top bit
// set, via both mp_sqr and mp_mul, for n_digits = 32, 16, 8, 4, 3, 2, 1.
// Each product must equal the hex value "4" followed by zeros, and MP_USED of
// the product must match MP_USED of that expected value (i.e. the result
// carries no extra leading zero digits, which is what "Clamp" refers to).
TEST_F(MPITest, MpiSqrMulClamp) {
  mp_int a, r, expect;
  MP_DIGITS(&a) = 0;
  MP_DIGITS(&r) = 0;
  MP_DIGITS(&expect) = 0;

  // Comba32 result is 64 mp_digits. x2 because each octet is two hex
  // characters in this textual representation.
  std::string expect_str((64 * sizeof(mp_digit)) * 2, '0');

  // Set second-highest bit (0x80...^2 == 0x4000...)
  expect_str.replace(0, 1, "4", 1);

  // Test 32, 16, 8, and 4-1 mp_digit values. 32-4 (powers of two) use the comba
  // assembly implementation, if enabled and supported. 3-1 use non-comba.
  int n_digits = 32;
  while (n_digits > 0) {
    ASSERT_EQ(MP_OKAY, mp_init(&r));
    ASSERT_EQ(MP_OKAY, mp_init(&a));
    ASSERT_EQ(MP_OKAY, mp_init(&expect));
    ASSERT_EQ(MP_OKAY, mp_read_radix(&expect, expect_str.c_str(), 16));

    // a = 1 << (n_digits * bits_per_digit - 1): only the top bit is set.
    ASSERT_EQ(MP_OKAY, mp_set_int(&a, 1));
    ASSERT_EQ(MP_OKAY, mpl_lsh(&a, &a, (n_digits * sizeof(mp_digit) * 8) - 1));

    ASSERT_EQ(MP_OKAY, mp_sqr(&a, &r));
    EXPECT_EQ(MP_USED(&expect), MP_USED(&r));
    EXPECT_EQ(0, mp_cmp(&r, &expect));
    mp_clear(&r);

    // Take the mul path...
    ASSERT_EQ(MP_OKAY, mp_init(&r));
    ASSERT_EQ(MP_OKAY, mp_mul(&a, &a, &r));
    EXPECT_EQ(MP_USED(&expect), MP_USED(&r));
    EXPECT_EQ(0, mp_cmp(&r, &expect));

    mp_clear(&a);
    mp_clear(&r);
    mp_clear(&expect);

    // Once we're down to 4, check non-powers of two.
    int sub = n_digits > 4 ? n_digits / 2 : 1;
    n_digits -= sub;

    // "Shift right" the string (to avoid mutating |expect_str| with MPI).
    // The square of an n-digit value spans 2n digits, i.e.
    // 2 * n * 2 * sizeof(mp_digit) hex characters, so dropping |sub| digits
    // from the operand drops 2 * 2 * sizeof(mp_digit) * sub trailing zeros.
    expect_str.resize(expect_str.size() - 2 * 2 * sizeof(mp_digit) * sub);
  }
}
308 
// Checks mp_invmod(a, m) against a precomputed inverse for one fixed input.
TEST_F(MPITest, MpiInvModLoop) {
  mp_int a;
  mp_int m;
  mp_int c_actual;
  mp_int c_expect;
  MP_DIGITS(&a) = 0;
  MP_DIGITS(&m) = 0;
  MP_DIGITS(&c_actual) = 0;
  MP_DIGITS(&c_expect) = 0;
  ASSERT_EQ(MP_OKAY, mp_init(&a));
  ASSERT_EQ(MP_OKAY, mp_init(&m));
  ASSERT_EQ(MP_OKAY, mp_init(&c_actual));
  ASSERT_EQ(MP_OKAY, mp_init(&c_expect));

  // Check every parse: a rejected literal would make the comparison below
  // test the wrong values.
  ASSERT_EQ(MP_OKAY,
            mp_read_radix(
                &a,
                "3e10b9f4859fb9e8150cc0d94e83ef428d655702a0b6fb1e684f4755eb6be6"
                "5ac6048cdfc533f73a9bad76125801051f",
                16));
  ASSERT_EQ(MP_OKAY,
            mp_read_radix(
                &m,
                "ffffffffffffffffffffffffffffffffffffffffffffffffc7634d81f4372d"
                "df581a0db248b0a77aecec196accc52973",
                16));
  ASSERT_EQ(MP_OKAY,
            mp_read_radix(
                &c_expect,
                "12302214814361c15ab6c0f2131150af186099f8c22f6c9d6e77ad496b551c"
                "7c8039e61098bfe2af66474420659435c6",
                16));

  mp_err rv = mp_invmod(&a, &m, &c_actual);
  ASSERT_EQ(MP_OKAY, rv);

  EXPECT_EQ(0, mp_cmp(&c_actual, &c_expect));

  mp_clear(&a);
  mp_clear(&m);
  mp_clear(&c_actual);
  mp_clear(&c_expect);
}
346 
347 // This test is slow. Disable it by default so we can run these tests on CI.
348 class DISABLED_MPITest : public ::testing::Test {};
349 
TEST_F(DISABLED_MPITest,MpiCmpConstTest)350 TEST_F(DISABLED_MPITest, MpiCmpConstTest) {
351   mp_int a, b, c;
352   MP_DIGITS(&a) = 0;
353   MP_DIGITS(&b) = 0;
354   MP_DIGITS(&c) = 0;
355   ASSERT_EQ(MP_OKAY, mp_init(&a));
356   ASSERT_EQ(MP_OKAY, mp_init(&b));
357   ASSERT_EQ(MP_OKAY, mp_init(&c));
358 
359   mp_read_radix(
360       &a,
361       const_cast<char*>(
362           "FFFFFFFF00000000FFFFFFFFFFFFFFFFBCE6FAADA7179E84F3B9CAC2FC632551"),
363       16);
364   mp_read_radix(
365       &b,
366       const_cast<char*>(
367           "FF0FFFFF00000000FFFFFFFFFFFFFFFFBCE6FAADA7179E84F3B9CAC2FC632551"),
368       16);
369   mp_read_radix(
370       &c,
371       const_cast<char*>(
372           "FFFFFFFF00000000FFFFFFFFFFFFFFFFBCE6FAADA7179E84F3B9CAC2FC632550"),
373       16);
374 
375 #ifdef CT_VERIF
376   mp_taint(&b);
377   mp_taint(&c);
378 #endif
379 
380   uint32_t runs = 5000000;
381   uint32_t time_b = 0, time_c = 0;
382   for (uint32_t i = 0; i < runs; ++i) {
383     struct timespec start, end;
384     gettime(&start);
385     int r = mp_cmp(&a, &b);
386     gettime(&end);
387     unsigned long long used = end.tv_sec * 1000000000L + end.tv_nsec;
388     used -= static_cast<unsigned long long>(start.tv_sec * 1000000000L +
389                                             start.tv_nsec);
390     time_b += used;
391     ASSERT_EQ(1, r);
392   }
393   printf("time b: %u\n", time_b / runs);
394 
395   for (uint32_t i = 0; i < runs; ++i) {
396     struct timespec start, end;
397     gettime(&start);
398     int r = mp_cmp(&a, &c);
399     gettime(&end);
400     unsigned long long used = end.tv_sec * 1000000000L + end.tv_nsec;
401     used -= static_cast<unsigned long long>(start.tv_sec * 1000000000L +
402                                             start.tv_nsec);
403     time_c += used;
404     ASSERT_EQ(1, r);
405   }
406   printf("time c: %u\n", time_c / runs);
407 
408   mp_clear(&a);
409   mp_clear(&b);
410   mp_clear(&c);
411 }
412 
413 }  // namespace nss_test
414