1 #include <assert.h>
2 #include "common.h"
3 #include "mont.h"
4
5 int ge(const uint64_t *x, const uint64_t *y, size_t nw);
6 unsigned sub(uint64_t *out, const uint64_t *a, const uint64_t *b, size_t nw);
7 void rsquare(uint64_t *r2, uint64_t *n, size_t nw);
8 int mont_select(uint64_t *out, const uint64_t *a, const uint64_t *b, unsigned cond, unsigned words);
9
/*
 * ge(x, y, nw): multi-word ">=" comparison on 2-word operands.
 * Words are least-significant first, so larger = 2*2^64 + 1 and
 * smaller = 1*2^64 + 2; the most significant word decides.
 */
void test_ge(void)
{
    const uint64_t larger[2] = { 1, 2 };
    const uint64_t smaller[2] = { 2, 1 };
    int cmp;

    cmp = ge(larger, smaller, 2);   /* top words: 2 > 1 */
    assert(cmp == 1);

    cmp = ge(larger, larger, 2);    /* equal operands count as ">=" */
    assert(cmp == 1);

    cmp = ge(smaller, larger, 2);   /* top words: 1 < 2 */
    assert(cmp == 0);
}
23
/*
 * sub(out, a, b, nw): nw-word subtraction out = a - b, returning the borrow.
 * Operands are little-endian word arrays: x = 2*2^64 + 1, y = 1*2^64 + 2.
 * out is filled with 0xFF bytes before every call so stale contents cannot
 * make an assertion pass by accident.
 */
void test_sub(void)
{
    unsigned borrow;   /* sub() returns unsigned, not uint64_t */
    uint64_t x[2] = { 1, 2 };
    uint64_t y[2] = { 2, 1 };
    uint64_t out[2];

    /* x - x = 0, no borrow */
    memset(out, 0xFF, sizeof out);
    borrow = sub(out, x, x, 2);
    assert(borrow == 0);
    assert(out[0] == 0 && out[1] == 0);

    /* x - y = 2^64 - 1, no borrow (x and y are const inputs, no reset needed) */
    memset(out, 0xFF, sizeof out);
    borrow = sub(out, x, y, 2);
    assert(borrow == 0);
    assert(out[0] == 0xFFFFFFFFFFFFFFFFUL);
    assert(out[1] == 0);

    /* y - x = -(2^64 - 1), wraps to 2^128 - 2^64 + 1 with the borrow set */
    memset(out, 0xFF, sizeof out);
    borrow = sub(out, y, x, 2);
    assert(borrow == 1);
    assert(out[0] == 1);
    assert(out[1] == 0xFFFFFFFFFFFFFFFFUL);
}
50
/*
 * rsquare(r2, n, nw) on the 2-word modulus n = 0x89*2^64 + 1 (words are
 * least-significant first).  Presumably r2 = R^2 mod n with R = 2^128;
 * the expected words are hard-coded reference values.
 */
void test_rsquare(void)
{
    uint64_t modulus[2];
    uint64_t r_squared[2];

    modulus[0] = 1;
    modulus[1] = 0x89;

    rsquare(r_squared, modulus, 2);

    assert(r_squared[0] == 0x44169db8eb2b48d8L);
    assert(r_squared[1] == 0x18);
}
60
test_mont_context_init(void)61 void test_mont_context_init(void)
62 {
63 int res;
64 MontContext *ctx;
65 uint8_t modulus[] = { 1, 0, 0, 1 };
66 uint8_t modulus_even[] = { 1, 0, 0, 2 };
67
68 res = mont_context_init(NULL, modulus, 4);
69 assert(res == ERR_NULL);
70
71 res = mont_context_init(&ctx, 0, 4);
72 assert(res == ERR_NULL);
73
74 res = mont_context_init(&ctx, modulus, 0);
75 assert(res == ERR_MODULUS);
76
77 res = mont_context_init(&ctx, modulus_even, 4);
78 assert(res == ERR_MODULUS);
79
80 res = mont_context_init(&ctx, modulus, 4);
81 assert(res == 0);
82
83 mont_context_free(ctx);
84 }
85
test_mont_from_bytes(void)86 void test_mont_from_bytes(void)
87 {
88 int res;
89 MontContext *ctx;
90 uint8_t modulus[16] = { 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1 };
91 uint8_t number[] = { 2, 2 };
92 uint64_t *output;
93
94 res = mont_context_init(&ctx, modulus, 16);
95 assert(res == 0);
96
97 res = mont_from_bytes(NULL, number, 2, ctx);
98 assert(res == ERR_NULL);
99
100 res = mont_from_bytes(&output, NULL, 2, ctx);
101 assert(res == ERR_NULL);
102
103 res = mont_from_bytes(&output, number, 2, NULL);
104 assert(res == ERR_NULL);
105
106 res = mont_from_bytes(&output, number, 0, ctx);
107 assert(res == ERR_NOT_ENOUGH_DATA);
108
109 res = mont_from_bytes(&output, number, 2, ctx);
110 assert(res == 0);
111 assert(output != NULL);
112 assert(output[0] == 18446744073709420033UL);
113 assert(output[1] == 71492449356218367L);
114 free(output);
115
116 number[0] = 0;
117 number[1] = 0;
118 res = mont_from_bytes(&output, number, 2, ctx);
119 assert(res == 0);
120 assert(output != NULL);
121 assert(output[0] == 0);
122 assert(output[1] == 0);
123 free(output);
124
125 mont_context_free(ctx);
126 }
127
test_mont_to_bytes(void)128 void test_mont_to_bytes(void)
129 {
130 int res;
131 MontContext *ctx;
132 uint8_t modulus[16] = { 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1 }; // 0x01000001000000000000000000000001
133 uint64_t number_mont[2] = { 18446744073709420033UL, 71492449356218367L };
134 uint8_t number[16];
135
136 memset(number, 0xAA, 16);
137
138 res = mont_context_init(&ctx, modulus, 16);
139 assert(res == 0);
140 assert(mont_bytes(ctx) == 16);
141
142 res = mont_to_bytes(NULL, 16, number_mont, ctx);
143 assert(res == ERR_NULL);
144 res = mont_to_bytes(number, 16, NULL, ctx);
145 assert(res == ERR_NULL);
146 res = mont_to_bytes(number, 16, number_mont, NULL);
147 assert(res == ERR_NULL);
148
149 res = mont_to_bytes(number, 15, number_mont, ctx);
150 assert(res == ERR_NOT_ENOUGH_DATA);
151
152 res = mont_to_bytes(number, 16, number_mont, ctx);
153 assert(res == 0);
154 assert(0 == memcmp(number, "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02\x02", 16));
155
156 mont_context_free(ctx);
157 }
158
test_mont_add(void)159 void test_mont_add(void)
160 {
161 int res;
162 MontContext *ctx;
163 uint64_t *tmp;
164 uint8_t modulus[16] = { 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1 }; // 0x01000001000000000000000000000001
165 uint8_t modulus2[16];
166 uint64_t a[2] = { -1, -1 };
167 uint64_t b[2] = { 1, 0 };
168 uint64_t out[2];
169
170 mont_context_init(&ctx, modulus, 16);
171 mont_number(&tmp, 2, ctx);
172
173 res = mont_add(NULL, a, b, tmp, ctx);
174 assert(res == ERR_NULL);
175 res = mont_add(out, NULL, b, tmp, ctx);
176 assert(res == ERR_NULL);
177 res = mont_add(out, a, NULL, tmp, ctx);
178 assert(res == ERR_NULL);
179 res = mont_add(out, a, b, NULL, ctx);
180 assert(res == ERR_NULL);
181 res = mont_add(out, a, b, tmp, NULL);
182 assert(res == ERR_NULL);
183
184 // 0x100000200000100000000000000000L + 0x100000200000100000000000000000L
185 a[0] = 0x10;
186 a[1] = 0;
187 b[0] = 0x100;
188 b[1] = 0;
189 res = mont_add(out, a, b, tmp, ctx);
190 assert(res == 0);
191 assert(out[0] == 0x110);
192 assert(out[1] == 0);
193
194 // 0xff0000fdffffff0000000000000001 + 0x100
195 a[0] = 0x0;
196 a[1] = 0x100000100000000UL;
197 b[0] = 0xffffffffffff0001L;
198 b[1] = 0xff0000ffffffffL;
199 res = mont_add(out, a, b, tmp, ctx);
200 assert(res == 0);
201 assert(out[0] == 0xffffffffffff0000L);
202 assert(out[1] == 0xff0000ffffffffL);
203
204 // 0xff0000fdffffff0000000000000001L * 2
205 a[0] = 0;
206 a[1] = 0x100000100000000L;
207 b[0] = 0;
208 b[1] = 0x100000100000000L;
209 res = mont_add(out, a, b, tmp, ctx);
210 assert(res == 0);
211 assert(out[0] == 0xffffffffffffffffL);
212 assert(out[1] == 0x1000000ffffffffL);
213
214 // Use modulus2, to trigger overflow over R
215 mont_context_free(ctx);
216 memset(modulus2, 0xFF, 16);
217 mont_context_init(&ctx, modulus2, 16);
218
219 // 0xfffffffffffffffffffffffffffffffe * 2
220 // (same encoding in Montgomery domain)
221 a[0] = 0xfffffffffffffffeL;
222 a[1] = 0xffffffffffffffffL;
223 b[0] = a[0];
224 b[1] = a[1];
225 res = mont_add(out, a, b, tmp, ctx);
226 assert(res == 0);
227 assert(out[0] == 0xfffffffffffffffdL);
228 assert(out[1] == 0xffffffffffffffffL);
229
230 free(tmp);
231 mont_context_free(ctx);
232 }
233
test_mont_sub(void)234 void test_mont_sub(void)
235 {
236 int res;
237 MontContext *ctx;
238 uint8_t modulus[16] = { 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1 }; // 0x01000001000000000000000000000001
239 uint64_t a[2] = { 0, 0 };
240 uint64_t b[2] = { 1, 0 };
241 uint64_t out[3];
242 uint64_t *tmp;
243
244 mont_context_init(&ctx, modulus, 16);
245 mont_number(&tmp, 2, ctx);
246
247 res = mont_sub(NULL, a, b, tmp, ctx);
248 assert(res == ERR_NULL);
249 res = mont_sub(out, NULL, b, tmp, ctx);
250 assert(res == ERR_NULL);
251 res = mont_sub(out, a, NULL, tmp, ctx);
252 assert(res == ERR_NULL);
253 res = mont_sub(out, a, b, NULL, ctx);
254 assert(res == ERR_NULL);
255 res = mont_sub(out, a, b, tmp, NULL);
256 assert(res == ERR_NULL);
257
258 out[2] = 0xA;
259 res = mont_sub(out, a, b, tmp, ctx);
260 assert(res == 0);
261 assert(out[0] == 0);
262 assert(out[1] == 0x100000100000000);
263 assert(out[2] == 0xA);
264
265 res = mont_sub(out, b, a, tmp, ctx);
266 assert(res == 0);
267 assert(out[0] == 1);
268 assert(out[1] == 0);
269
270 free(tmp);
271 mont_context_free(ctx);
272 }
273
test_mont_inv_prime(void)274 void test_mont_inv_prime(void)
275 {
276 int res;
277 MontContext *ctx;
278 uint8_t modulus_f6[9] = { 1, 0, 0, 0, 0, 0, 0, 0, 1 }; // F6 = 2^64 + 1
279 uint64_t a[2] = { 1, 0 };
280 uint64_t out[2];
281 uint64_t *p;
282 uint8_t buf[16];
283
284 uint8_t modulus_p521[66] = "\x01\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF";
285 uint64_t out_p521[9];
286 uint8_t buf_p521[66];
287
288 res = mont_context_init(&ctx, modulus_f6, sizeof modulus_f6);
289 assert(res == 0);
290
291 res = mont_inv_prime(NULL, a, ctx);
292 assert(res == ERR_NULL);
293 res = mont_inv_prime(out, NULL, ctx);
294 assert(res == ERR_NULL);
295 res = mont_inv_prime(out, a, NULL);
296 assert(res == ERR_NULL);
297
298 /* 1 == R mod N when N = F6 */
299 a[0] = 1; a[1] = 0;
300 out[0] = 1; out[1] = 0;
301 res = mont_inv_prime(out, a, ctx);
302 assert(res == 0);
303 assert(out[0] == 1);
304 assert(out[1] == 0);
305
306 assert(sizeof buf == mont_bytes(ctx));
307
308 /* 2^{-1} mod N = 0x8000000000000001 */
309 res = mont_from_bytes(&p, (uint8_t*)"\x00\x02", 2, ctx);
310 assert(res == 0);
311 res = mont_inv_prime(out, p, ctx);
312 assert(res == 0);
313 res = mont_to_bytes(buf, 16, out, ctx);
314 assert(res == 0);
315 assert(0 == memcmp(buf, (uint8_t*)"\x00\x00\x00\x00\x00\x00\x00\x00\x80\x00\x00\x00\x00\x00\x00\x01", 16));
316 free(p);
317
318 /* 3^{-1} mod N = 0x287cbedc41008ca6 */
319 res = mont_from_bytes(&p, (uint8_t*)"\x00\x03", 2, ctx);
320 assert(res == 0);
321 res = mont_inv_prime(out, p, ctx);
322 assert(res == 0);
323 res = mont_to_bytes(buf, 16, out, ctx);
324 assert(res == 0);
325 assert(0 == memcmp(buf, (uint8_t*)"\x00\x00\x00\x00\x00\x00\x00\x00\x28\x7c\xbe\xdc\x41\x00\x8c\xa6", 16));
326 free(p);
327
328 mont_context_free(ctx);
329
330 /* --- */
331 mont_context_init(&ctx, modulus_p521, sizeof modulus_p521);
332 res = mont_from_bytes(&p, (uint8_t*)"\x01\xE9\xF3\x4F\x60\xAD\x5C\x4B\x98\x8A\xB4\x3A\x0C\x1C\x40\xFB\x5C\xB0\xFD\x1A\x1A\x6F\x4E\x81\xEB\x33\xDE\x7D\x95\x2E\xE2\x62\x0D\x76\x08\x3B\xA2\x28\xCC\x56\xA4\xFE\xD2\xF6\x08\xF3\x17\x1E\x59\x41\xB7\xE1\x6D\x20\x05\xEB\x9F\x55\x6B\x6B\xA1\x36\x0E\xC2\x35\x8C", 66, ctx);
333 assert(res == 0);
334
335 res = mont_inv_prime(out_p521, p, ctx);
336 assert(res == 0);
337
338 res = mont_to_bytes(buf_p521, 66, out_p521, ctx);
339 assert(res == 0);
340 assert(0 == memcmp(buf_p521, (uint8_t*)"\x01\xF5\xDD\xE7\xED\xB2\xAD\x9D\x06\x2F\x2C\xAE\x1B\x66\x95\xC0\x9B\xE6\x16\xDA\xEA\x07\x2A\xC8\x2A\xFB\x44\xF4\x21\x79\xE1\x38\x8B\x1C\xEF\x91\xBA\xD3\xEB\x1D\x81\xE5\x45\xEF\x54\x63\xD7\x0A\xED\x39\x70\xFC\xD5\x95\xFF\x1B\xA7\x52\x11\xD3\xC3\x3C\x2C\x14\x42\x51", 66));
341
342 free(p);
343 mont_context_free(ctx);
344 }
345
346
test_mont_set(void)347 void test_mont_set(void)
348 {
349 int res;
350 MontContext *ctx;
351 uint8_t modulus[16] = { 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1 }; // 0x01000001000000000000000000000001
352 uint64_t out[2];
353
354 mont_context_init(&ctx, modulus, 16);
355
356 res = mont_set(NULL, 0x1000, ctx);
357 assert(res == ERR_NULL);
358 res = mont_set(out, 0x1000, NULL);
359 assert(res == ERR_NULL);
360
361 res = mont_set(out, 0, ctx);
362 assert(res == 0);
363 assert(out[0] == 0);
364 assert(out[1] == 0);
365
366 res = mont_set(out, 0x1000, ctx);
367 assert(res == 0);
368 assert(out[0] == 0xfffffffffff00001UL);
369 assert(out[1] == 0xf00000ffffffffUL);
370
371 mont_context_free(ctx);
372 }
373
test_mont_select()374 void test_mont_select()
375 {
376 int res;
377 MontContext *ctx;
378 uint8_t modulusA[16] = { 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1 }; // 0x01000001000000000000000000000001
379 uint8_t modulusB[17] = { 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 3 }; // 0x0301000001000000000000000000000001
380 uint64_t a[2] = { 0xFFFFFFFFFFFFFFFFU, 0xFFFFFFFFFFFFFFFFU };
381 uint64_t b[2] = { 1, 1 };
382 uint64_t c[2];
383 uint64_t d[3] = { 0xFFFFFFFFFFFFFFFFU, 0xFFFFFFFFFFFFFFFFU, 3 };
384 uint64_t e[3] = { 1, 1, 3 };
385 uint64_t f[3];
386
387 mont_context_init(&ctx, modulusA, 16);
388
389 memset(c, 0, sizeof c);
390 res = mont_select(c, a, b, 1, ctx->words);
391 assert(res == 0);
392 assert(memcmp(a, c, sizeof c) == 0);
393
394 memset(c, 0, sizeof c);
395 res = mont_select(c, a, b, 10, ctx->words);
396 assert(res == 0);
397 assert(memcmp(a, c, sizeof c) == 0);
398
399 memset(c, 0, sizeof c);
400 res = mont_select(c, a, b, 0, ctx->words);
401 assert(res == 0);
402 assert(memcmp(b, c, sizeof c) == 0);
403
404 mont_context_free(ctx);
405
406 /* --- */
407
408 mont_context_init(&ctx, modulusB, 17);
409
410 memset(f, 0, sizeof f);
411 res = mont_select(f, d, e, 1, ctx->words);
412 assert(res == 0);
413 assert(memcmp(d, f, sizeof f) == 0);
414
415 memset(f, 0, sizeof f);
416 res = mont_select(f, d, e, 0, ctx->words);
417 assert(res == 0);
418 assert(memcmp(e, f, sizeof f) == 0);
419 }
420
main(void)421 int main(void) {
422 test_ge();
423 test_sub();
424 test_rsquare();
425 test_mont_context_init();
426 test_mont_from_bytes();
427 test_mont_to_bytes();
428 test_mont_add();
429 test_mont_sub();
430 test_mont_inv_prime();
431 test_mont_set();
432 test_mont_select();
433 return 0;
434 }
435