1"""Make unit test for mont_mult() and mont_mult_generic() in mont.c""" 2 3from common import counter, make_main, split64, inverse, bin2int 4from hashlib import sha256 5import struct 6 7def make_test(a, b, modulus, use_mont=True): 8 9 assert(0 <= a < modulus) 10 assert(0 <= b < modulus) 11 assert(modulus & 1) 12 13 R = 1 14 nw = 0 15 B = 1<<64 16 while modulus >= R: 17 R <<= 64 18 nw += 1 19 20 if not use_mont: 21 R = 1 22 23 n0 = modulus & (B-1) 24 m0 = -inverse(n0, B) % B 25 assert(0 < m0 < B) 26 27 a_m = (a*R) % modulus 28 b_m = (b*R) % modulus 29 30 # What we expect the function to compute 31 result_m = (a*b*R) % modulus 32 33 # Turn data into arrays of 64-bit words 34 a_m_s = split64(a_m) 35 b_m_s = split64(b_m) 36 modulus_s = split64(modulus) 37 result_m_s = split64(result_m) 38 39 # Everything must have nw words 40 for ds in (a_m_s, b_m_s, modulus_s, result_m_s): 41 ds += ["0"] * (nw - len(ds)) 42 43 # Modulus also byte encoded, big endian 44 modulus_b = [] 45 while modulus > 0: 46 modulus_b.insert(0, hex(modulus % 256)) 47 modulus >>= 8 48 49 if use_mont: 50 test_nr = counter.next() 51 print "" 52 print "void test_%d() {" % test_nr 53 print " const uint64_t a[] = {" + ", ".join(a_m_s) + "};" 54 print " const uint64_t b[] = {" + ", ".join(b_m_s) + "};" 55 print " const uint64_t n[] = {" + ", ".join(modulus_s) + "};" 56 print " const uint64_t expected[] = {" + ", ".join(result_m_s) + "};" 57 print " uint64_t out[%d];" % (nw+1) 58 print " uint64_t scratch[%d];" % (7*nw) 59 print "" 60 print " memset(out, 0xAA, sizeof out);" 61 print " mont_mult_generic(out, a, b, n, %dUL, scratch, %d);" % (m0, nw) 62 print " assert(memcmp(out, expected, 8*%d) == 0);" % nw 63 print " assert(out[%d] == 0xAAAAAAAAAAAAAAAAUL);" % nw 64 print "}" 65 print "" 66 67 test_nr = counter.next() 68 print "" 69 print "void test_%d() {" % test_nr 70 print " const uint64_t a[] = {" + ", ".join(a_m_s) + "};" 71 print " const uint64_t b[] = {" + ", ".join(b_m_s) + "};" 72 print " const uint8_t modulus[] = {" + ", ".join(modulus_b) + "};" 73 print " const uint64_t expected[] = {" + ", ".join(result_m_s) + "};" 74 print " uint64_t out[%d];" % (nw+1) 75 print " MontContext *ctx;" 76 print " int res;" 77 print " uint64_t scratch[%d];" % (7*nw) 78 print "" 79 print 80 print " res = mont_context_init(&ctx, modulus, sizeof modulus);" 81 print " assert(res == 0);" 82 print " memset(out, 0xAA, sizeof out);" 83 print " res = mont_mult(out, a, b, scratch, ctx);" 84 print " assert(res == 0);" 85 print " assert(out[%d] == 0xAAAAAAAAAAAAAAAAUL);" % nw 86 print " assert(memcmp(out, expected, 8*%d) == 0);" % nw 87 print " mont_context_free(ctx);" 88 print "}" 89 print "" 90 91 92 93print "#include <assert.h>" 94print "#include <string.h>" 95print "#include <stdint.h>" 96print "#include <stdio.h>" 97print '#include "mont.h"' 98print "" 99print "void mont_mult_generic(uint64_t *out, const uint64_t *a, const uint64_t *b, const uint64_t *n, uint64_t m0, uint64_t *t, size_t nw);" 100 101p256 = 115792089210356248762697446949407573530086143415290314195533631308867097853951 102p384 = 39402006196394479212279040100143613805079739270465446667948293404245721771496870329047266088258938001861606973112319 103p521 = 0x000001ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff 104 105make_test(2, 3, 255) 106make_test(2, 240, 255) 107make_test(189, 240, 255) 108make_test(189, 240, 32984723984723984723847) 109make_test(189000000, 7878787878, 32984723984723984723847) 110make_test(1890000003439483948394839843434, 78787878780003984834673498384734, 3298472398472398472384798743287438734875384758435834539400000033988787) 111 112for x in range(100): 113 modulus_len = x//10 + 5 # 40 bit .. 112 bits 114 modulus = bin2int(sha256(b"modulus" + struct.pack(">I", x)).digest()[:-modulus_len]) | 1 115 a = bin2int(sha256(b"a" + struct.pack(">I", x)).digest()) % modulus 116 b = bin2int(sha256(b"b" + struct.pack(">I", x)).digest()) % modulus 117 make_test(a, b, modulus) 118 119for x in range(100): 120 a = bin2int(sha256(b"a" + struct.pack(">I", x)).digest()) % p256 121 b = bin2int(sha256(b"b" + struct.pack(">I", x)).digest()) % p256 122 make_test(a, b, p256) 123 124for x in range(100): 125 a = bin2int(sha256(b"a" + struct.pack(">I", x)).digest()) % p384 126 b = bin2int(sha256(b"b" + struct.pack(">I", x)).digest()) % p384 127 make_test(a, b, p384) 128 129for x in range(100): 130 a = bin2int(sha256(b"a" + struct.pack(">I", x)).digest()) % p521 131 b = bin2int(sha256(b"b" + struct.pack(">I", x)).digest()) % p521 132 make_test(a, b, p521, use_mont=False) 133 134make_main() 135