1 #include <assert.h>
2 #include "common.h"
3 #include "mont.h"
4 
/*
 * Prototypes for helpers exercised directly by the tests in this file.
 * NOTE(review): presumably these are internal to the Montgomery module and
 * therefore not exported by mont.h - confirm.
 */

/* Multi-word comparison over nw words; the tests expect 1 when x >= y and 0
 * otherwise, with the word at index 0 being the least significant one. */
int ge(const uint64_t *x, const uint64_t *y, size_t nw);
/* Multi-word subtraction out = a - b over nw words; the tests expect the
 * return value to be the final borrow (1 when a < b, 0 otherwise). */
unsigned sub(uint64_t *out, const uint64_t *a, const uint64_t *b, size_t nw);
/* Fills r2 (nw words) from the modulus n.
 * NOTE(review): presumably R^2 mod n with R = 2^(64*nw) - confirm. */
void rsquare(uint64_t *r2, uint64_t *n, size_t nw);
/* Copies a into out when cond is non-zero, b otherwise; the tests expect 0 as
 * return value. */
int mont_select(uint64_t *out, const uint64_t *a, const uint64_t *b, unsigned cond, unsigned words);
9 
/*
 * Exercises ge() on two-word operands whose ordering is decided by the
 * upper word (index 1), and on the equality case.
 */
void test_ge(void)
{
    const uint64_t larger[2]  = { 1, 2 };
    const uint64_t smaller[2] = { 2, 1 };
    int verdict;

    /* The word at index 1 dominates: {1, 2} is >= {2, 1} */
    verdict = ge(larger, smaller, 2);
    assert(verdict == 1);

    /* A value compared against itself is "greater or equal" */
    verdict = ge(larger, larger, 2);
    assert(verdict == 1);

    /* Swapped operands: the comparison must be false */
    verdict = ge(smaller, larger, 2);
    assert(verdict == 0);
}
23 
/*
 * Exercises sub() on two-word operands: equal operands, a difference that
 * borrows between words without underflowing, and a difference that
 * underflows (final borrow reported through the return value).
 * The output is poisoned with 0xFF bytes before every call so that stale
 * words cannot satisfy the checks.
 */
void test_sub(void)
{
    const uint64_t all_ones = ~(uint64_t)0;
    const uint64_t p[2] = { 1, 2 };
    const uint64_t q[2] = { 2, 1 };
    uint64_t diff[2];
    uint64_t borrow;

    /* p - p == 0, no borrow */
    memset(diff, 0xFF, sizeof diff);
    borrow = sub(diff, p, p, 2);
    assert(borrow == 0);
    assert(diff[0] == 0 && diff[1] == 0);

    /* p - q: the low word wraps and consumes the upper word's surplus */
    memset(diff, 0xFF, sizeof diff);
    borrow = sub(diff, p, q, 2);
    assert(borrow == 0);
    assert(diff[0] == all_ones);
    assert(diff[1] == 0);

    /* q - p: result wraps around, so a final borrow is reported */
    memset(diff, 0xFF, sizeof diff);
    borrow = sub(diff, q, p, 2);
    assert(borrow == 1);
    assert(diff[0] == 1);
    assert(diff[1] == all_ones);
}
50 
/*
 * Checks rsquare() against a known answer for a two-word modulus.
 */
void test_rsquare(void)
{
    uint64_t result[2];
    uint64_t mod_words[2] = { 1, 0x89 };

    rsquare(result, mod_words, 2);

    /* Known-answer vector, compared word by word (index 0 first) */
    assert(result[0] == 0x44169db8eb2b48d8ULL);
    assert(result[1] == 0x18);
}
60 
test_mont_context_init(void)61 void test_mont_context_init(void)
62 {
63     int res;
64     MontContext *ctx;
65     uint8_t modulus[] = { 1, 0, 0, 1 };
66     uint8_t modulus_even[] = { 1, 0, 0, 2 };
67 
68     res = mont_context_init(NULL, modulus, 4);
69     assert(res == ERR_NULL);
70 
71     res = mont_context_init(&ctx, 0, 4);
72     assert(res == ERR_NULL);
73 
74     res = mont_context_init(&ctx, modulus, 0);
75     assert(res == ERR_MODULUS);
76 
77     res = mont_context_init(&ctx, modulus_even, 4);
78     assert(res == ERR_MODULUS);
79 
80     res = mont_context_init(&ctx, modulus, 4);
81     assert(res == 0);
82 
83     mont_context_free(ctx);
84 }
85 
test_mont_from_bytes(void)86 void test_mont_from_bytes(void)
87 {
88     int res;
89     MontContext *ctx;
90     uint8_t modulus[16] = { 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1 };
91     uint8_t number[] = { 2, 2 };
92     uint64_t *output;
93 
94     res = mont_context_init(&ctx, modulus, 16);
95     assert(res == 0);
96 
97     res = mont_from_bytes(NULL, number, 2, ctx);
98     assert(res == ERR_NULL);
99 
100     res = mont_from_bytes(&output, NULL, 2, ctx);
101     assert(res == ERR_NULL);
102 
103     res = mont_from_bytes(&output, number, 2, NULL);
104     assert(res == ERR_NULL);
105 
106     res = mont_from_bytes(&output, number, 0, ctx);
107     assert(res == ERR_NOT_ENOUGH_DATA);
108 
109     res = mont_from_bytes(&output, number, 2, ctx);
110     assert(res == 0);
111     assert(output != NULL);
112     assert(output[0] == 18446744073709420033UL);
113     assert(output[1] == 71492449356218367L);
114     free(output);
115 
116     number[0] = 0;
117     number[1] = 0;
118     res = mont_from_bytes(&output, number, 2, ctx);
119     assert(res == 0);
120     assert(output != NULL);
121     assert(output[0] == 0);
122     assert(output[1] == 0);
123     free(output);
124 
125     mont_context_free(ctx);
126 }
127 
test_mont_to_bytes(void)128 void test_mont_to_bytes(void)
129 {
130     int res;
131     MontContext *ctx;
132     uint8_t modulus[16] = { 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1 };   // 0x01000001000000000000000000000001
133     uint64_t number_mont[2] = { 18446744073709420033UL, 71492449356218367L };
134     uint8_t number[16];
135 
136     memset(number, 0xAA, 16);
137 
138     res = mont_context_init(&ctx, modulus, 16);
139     assert(res == 0);
140     assert(mont_bytes(ctx) == 16);
141 
142     res = mont_to_bytes(NULL, 16, number_mont, ctx);
143     assert(res == ERR_NULL);
144     res = mont_to_bytes(number, 16, NULL, ctx);
145     assert(res == ERR_NULL);
146     res = mont_to_bytes(number, 16, number_mont, NULL);
147     assert(res == ERR_NULL);
148 
149     res = mont_to_bytes(number, 15, number_mont, ctx);
150     assert(res == ERR_NOT_ENOUGH_DATA);
151 
152     res = mont_to_bytes(number, 16, number_mont, ctx);
153     assert(res == 0);
154     assert(0 == memcmp(number, "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02\x02", 16));
155 
156     mont_context_free(ctx);
157 }
158 
test_mont_add(void)159 void test_mont_add(void)
160 {
161     int res;
162     MontContext *ctx;
163     uint64_t *tmp;
164     uint8_t modulus[16] = { 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1 };   // 0x01000001000000000000000000000001
165     uint8_t modulus2[16];
166     uint64_t a[2] = { -1, -1 };
167     uint64_t b[2] = { 1, 0 };
168     uint64_t out[2];
169 
170     mont_context_init(&ctx, modulus, 16);
171     mont_number(&tmp, 2, ctx);
172 
173     res = mont_add(NULL, a, b, tmp, ctx);
174     assert(res == ERR_NULL);
175     res = mont_add(out, NULL, b, tmp, ctx);
176     assert(res == ERR_NULL);
177     res = mont_add(out, a, NULL, tmp, ctx);
178     assert(res == ERR_NULL);
179     res = mont_add(out, a, b, NULL, ctx);
180     assert(res == ERR_NULL);
181     res = mont_add(out, a, b, tmp, NULL);
182     assert(res == ERR_NULL);
183 
184     // 0x100000200000100000000000000000L + 0x100000200000100000000000000000L
185     a[0] = 0x10;
186     a[1] = 0;
187     b[0] = 0x100;
188     b[1] = 0;
189     res = mont_add(out, a, b, tmp, ctx);
190     assert(res == 0);
191     assert(out[0] == 0x110);
192     assert(out[1] == 0);
193 
194     // 0xff0000fdffffff0000000000000001 + 0x100
195     a[0] = 0x0;
196     a[1] = 0x100000100000000UL;
197     b[0] = 0xffffffffffff0001L;
198     b[1] = 0xff0000ffffffffL;
199     res = mont_add(out, a, b, tmp, ctx);
200     assert(res == 0);
201     assert(out[0] == 0xffffffffffff0000L);
202     assert(out[1] == 0xff0000ffffffffL);
203 
204     // 0xff0000fdffffff0000000000000001L * 2
205     a[0] = 0;
206     a[1] = 0x100000100000000L;
207     b[0] = 0;
208     b[1] = 0x100000100000000L;
209     res = mont_add(out, a, b, tmp, ctx);
210     assert(res == 0);
211     assert(out[0] == 0xffffffffffffffffL);
212     assert(out[1] == 0x1000000ffffffffL);
213 
214     // Use modulus2, to trigger overflow over R
215     mont_context_free(ctx);
216     memset(modulus2, 0xFF, 16);
217     mont_context_init(&ctx, modulus2, 16);
218 
219     // 0xfffffffffffffffffffffffffffffffe * 2
220     // (same encoding in Montgomery domain)
221     a[0] = 0xfffffffffffffffeL;
222     a[1] = 0xffffffffffffffffL;
223     b[0] = a[0];
224     b[1] = a[1];
225     res = mont_add(out, a, b, tmp, ctx);
226     assert(res == 0);
227     assert(out[0] == 0xfffffffffffffffdL);
228     assert(out[1] == 0xffffffffffffffffL);
229 
230     free(tmp);
231     mont_context_free(ctx);
232 }
233 
test_mont_sub(void)234 void test_mont_sub(void)
235 {
236     int res;
237     MontContext *ctx;
238     uint8_t modulus[16] = { 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1 };   // 0x01000001000000000000000000000001
239     uint64_t a[2] = { 0, 0 };
240     uint64_t b[2] = { 1, 0 };
241     uint64_t out[3];
242     uint64_t *tmp;
243 
244     mont_context_init(&ctx, modulus, 16);
245     mont_number(&tmp, 2, ctx);
246 
247     res = mont_sub(NULL, a, b, tmp, ctx);
248     assert(res == ERR_NULL);
249     res = mont_sub(out, NULL, b, tmp, ctx);
250     assert(res == ERR_NULL);
251     res = mont_sub(out, a, NULL, tmp, ctx);
252     assert(res == ERR_NULL);
253     res = mont_sub(out, a, b, NULL, ctx);
254     assert(res == ERR_NULL);
255     res = mont_sub(out, a, b, tmp, NULL);
256     assert(res == ERR_NULL);
257 
258     out[2] = 0xA;
259     res = mont_sub(out, a, b, tmp, ctx);
260     assert(res == 0);
261     assert(out[0] == 0);
262     assert(out[1] == 0x100000100000000);
263     assert(out[2] == 0xA);
264 
265     res = mont_sub(out, b, a, tmp, ctx);
266     assert(res == 0);
267     assert(out[0] == 1);
268     assert(out[1] == 0);
269 
270     free(tmp);
271     mont_context_free(ctx);
272 }
273 
test_mont_inv_prime(void)274 void test_mont_inv_prime(void)
275 {
276     int res;
277     MontContext *ctx;
278     uint8_t modulus_f6[9] = { 1, 0, 0, 0, 0, 0, 0, 0, 1 }; // F6 = 2^64 + 1
279     uint64_t a[2] = { 1, 0 };
280     uint64_t out[2];
281     uint64_t *p;
282     uint8_t buf[16];
283 
284     uint8_t modulus_p521[66] = "\x01\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF";
285     uint64_t out_p521[9];
286     uint8_t buf_p521[66];
287 
288     res = mont_context_init(&ctx, modulus_f6, sizeof modulus_f6);
289     assert(res == 0);
290 
291     res = mont_inv_prime(NULL, a, ctx);
292     assert(res == ERR_NULL);
293     res = mont_inv_prime(out, NULL, ctx);
294     assert(res == ERR_NULL);
295     res = mont_inv_prime(out, a, NULL);
296     assert(res == ERR_NULL);
297 
298     /* 1 == R mod N when N = F6 */
299     a[0] = 1;   a[1] = 0;
300     out[0] = 1; out[1] = 0;
301     res = mont_inv_prime(out, a, ctx);
302     assert(res == 0);
303     assert(out[0] == 1);
304     assert(out[1] == 0);
305 
306     assert(sizeof buf == mont_bytes(ctx));
307 
308     /* 2^{-1} mod N = 0x8000000000000001 */
309     res = mont_from_bytes(&p, (uint8_t*)"\x00\x02", 2, ctx);
310     assert(res == 0);
311     res = mont_inv_prime(out, p, ctx);
312     assert(res == 0);
313     res = mont_to_bytes(buf, 16, out, ctx);
314     assert(res == 0);
315     assert(0 == memcmp(buf, (uint8_t*)"\x00\x00\x00\x00\x00\x00\x00\x00\x80\x00\x00\x00\x00\x00\x00\x01", 16));
316     free(p);
317 
318     /* 3^{-1} mod N = 0x287cbedc41008ca6 */
319     res = mont_from_bytes(&p, (uint8_t*)"\x00\x03", 2, ctx);
320     assert(res == 0);
321     res = mont_inv_prime(out, p, ctx);
322     assert(res == 0);
323     res = mont_to_bytes(buf, 16, out, ctx);
324     assert(res == 0);
325     assert(0 == memcmp(buf, (uint8_t*)"\x00\x00\x00\x00\x00\x00\x00\x00\x28\x7c\xbe\xdc\x41\x00\x8c\xa6", 16));
326     free(p);
327 
328     mont_context_free(ctx);
329 
330     /* --- */
331     mont_context_init(&ctx, modulus_p521, sizeof modulus_p521);
332     res = mont_from_bytes(&p, (uint8_t*)"\x01\xE9\xF3\x4F\x60\xAD\x5C\x4B\x98\x8A\xB4\x3A\x0C\x1C\x40\xFB\x5C\xB0\xFD\x1A\x1A\x6F\x4E\x81\xEB\x33\xDE\x7D\x95\x2E\xE2\x62\x0D\x76\x08\x3B\xA2\x28\xCC\x56\xA4\xFE\xD2\xF6\x08\xF3\x17\x1E\x59\x41\xB7\xE1\x6D\x20\x05\xEB\x9F\x55\x6B\x6B\xA1\x36\x0E\xC2\x35\x8C", 66, ctx);
333     assert(res == 0);
334 
335     res = mont_inv_prime(out_p521, p, ctx);
336     assert(res == 0);
337 
338     res = mont_to_bytes(buf_p521, 66, out_p521, ctx);
339     assert(res == 0);
340     assert(0 == memcmp(buf_p521, (uint8_t*)"\x01\xF5\xDD\xE7\xED\xB2\xAD\x9D\x06\x2F\x2C\xAE\x1B\x66\x95\xC0\x9B\xE6\x16\xDA\xEA\x07\x2A\xC8\x2A\xFB\x44\xF4\x21\x79\xE1\x38\x8B\x1C\xEF\x91\xBA\xD3\xEB\x1D\x81\xE5\x45\xEF\x54\x63\xD7\x0A\xED\x39\x70\xFC\xD5\x95\xFF\x1B\xA7\x52\x11\xD3\xC3\x3C\x2C\x14\x42\x51", 66));
341 
342     free(p);
343     mont_context_free(ctx);
344 }
345 
346 
test_mont_set(void)347 void test_mont_set(void)
348 {
349     int res;
350     MontContext *ctx;
351     uint8_t modulus[16] = { 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1 };   // 0x01000001000000000000000000000001
352     uint64_t out[2];
353 
354     mont_context_init(&ctx, modulus, 16);
355 
356     res = mont_set(NULL, 0x1000, ctx);
357     assert(res == ERR_NULL);
358     res = mont_set(out, 0x1000, NULL);
359     assert(res == ERR_NULL);
360 
361     res = mont_set(out, 0, ctx);
362     assert(res == 0);
363     assert(out[0] == 0);
364     assert(out[1] == 0);
365 
366     res = mont_set(out, 0x1000, ctx);
367     assert(res == 0);
368     assert(out[0] == 0xfffffffffff00001UL);
369     assert(out[1] == 0xf00000ffffffffUL);
370 
371     mont_context_free(ctx);
372 }
373 
test_mont_select()374 void test_mont_select()
375 {
376     int res;
377     MontContext *ctx;
378     uint8_t modulusA[16] = { 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1 };      // 0x01000001000000000000000000000001
379     uint8_t modulusB[17] = { 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 3 };   // 0x0301000001000000000000000000000001
380     uint64_t a[2] = { 0xFFFFFFFFFFFFFFFFU, 0xFFFFFFFFFFFFFFFFU };
381     uint64_t b[2] = { 1, 1 };
382     uint64_t c[2];
383     uint64_t d[3] = { 0xFFFFFFFFFFFFFFFFU, 0xFFFFFFFFFFFFFFFFU, 3 };
384     uint64_t e[3] = { 1, 1, 3 };
385     uint64_t f[3];
386 
387     mont_context_init(&ctx, modulusA, 16);
388 
389     memset(c, 0, sizeof c);
390     res = mont_select(c, a, b, 1, ctx->words);
391     assert(res == 0);
392     assert(memcmp(a, c, sizeof c) == 0);
393 
394     memset(c, 0, sizeof c);
395     res = mont_select(c, a, b, 10, ctx->words);
396     assert(res == 0);
397     assert(memcmp(a, c, sizeof c) == 0);
398 
399     memset(c, 0, sizeof c);
400     res = mont_select(c, a, b, 0, ctx->words);
401     assert(res == 0);
402     assert(memcmp(b, c, sizeof c) == 0);
403 
404     mont_context_free(ctx);
405 
406     /* --- */
407 
408     mont_context_init(&ctx, modulusB, 17);
409 
410     memset(f, 0, sizeof f);
411     res = mont_select(f, d, e, 1, ctx->words);
412     assert(res == 0);
413     assert(memcmp(d, f, sizeof f) == 0);
414 
415     memset(f, 0, sizeof f);
416     res = mont_select(f, d, e, 0, ctx->words);
417     assert(res == 0);
418     assert(memcmp(e, f, sizeof f) == 0);
419 }
420 
main(void)421 int main(void) {
422     test_ge();
423     test_sub();
424     test_rsquare();
425     test_mont_context_init();
426     test_mont_from_bytes();
427     test_mont_to_bytes();
428     test_mont_add();
429     test_mont_sub();
430     test_mont_inv_prime();
431     test_mont_set();
432     test_mont_select();
433     return 0;
434 }
435