"""Generate C unit tests for mont_mult() and mont_mult_generic() in mont.c."""
2
3from common import counter, make_main, split64, inverse, bin2int
4from hashlib import sha256
5import struct
6
def _pad_words(value, nw):
    """Split *value* into 64-bit word literals with ``split64`` and append
    ``"0"`` entries until the list holds at least *nw* words."""
    words = split64(value)
    return words + ["0"] * (nw - len(words))


def _be_bytes(value):
    """Return the big-endian bytes of the non-negative *value* as ``hex()``
    strings, most significant byte first (an empty list for 0)."""
    digits = []
    while value > 0:
        digits.append(hex(value % 256))
        value >>= 8
    # Collected least-significant first; appending then reversing avoids the
    # quadratic cost of repeated insert(0, ...).
    digits.reverse()
    return digits


def _emit_generic_test(a_s, b_s, n_s, expected_s, m0, nw):
    """Print a C test that calls mont_mult_generic() directly with the
    precomputed Montgomery constant *m0*."""
    test_nr = next(counter)
    print("\n".join([
        "",
        "void test_%d() {" % test_nr,
        "    const uint64_t a[] = {" + ", ".join(a_s) + "};",
        "    const uint64_t b[] = {" + ", ".join(b_s) + "};",
        "    const uint64_t n[] = {" + ", ".join(n_s) + "};",
        "    const uint64_t expected[] = {" + ", ".join(expected_s) + "};",
        "    uint64_t out[%d];" % (nw + 1),
        "    uint64_t scratch[%d];" % (7 * nw),
        "",
        "    memset(out, 0xAA, sizeof out);",
        "    mont_mult_generic(out, a, b, n, %dUL, scratch, %d);" % (m0, nw),
        "    assert(memcmp(out, expected, 8*%d) == 0);" % nw,
        "    assert(out[%d] == 0xAAAAAAAAAAAAAAAAUL);" % nw,
        "}",
        "",
    ]))


def _emit_context_test(a_s, b_s, modulus_b, expected_s, nw):
    """Print a C test that builds a MontContext from the big-endian modulus
    bytes and checks mont_mult() against the expected words."""
    test_nr = next(counter)
    print("\n".join([
        "",
        "void test_%d() {" % test_nr,
        "    const uint64_t a[] = {" + ", ".join(a_s) + "};",
        "    const uint64_t b[] = {" + ", ".join(b_s) + "};",
        "    const uint8_t modulus[] = {" + ", ".join(modulus_b) + "};",
        "    const uint64_t expected[] = {" + ", ".join(expected_s) + "};",
        "    uint64_t out[%d];" % (nw + 1),
        "    MontContext *ctx;",
        "    int res;",
        "    uint64_t scratch[%d];" % (7 * nw),
        "",
        "    res = mont_context_init(&ctx, modulus, sizeof modulus);",
        "    assert(res == 0);",
        "    memset(out, 0xAA, sizeof out);",
        "    res = mont_mult(out, a, b, scratch, ctx);",
        "    assert(res == 0);",
        "    assert(out[%d] == 0xAAAAAAAAAAAAAAAAUL);" % nw,
        "    assert(memcmp(out, expected, 8*%d) == 0);" % nw,
        "    mont_context_free(ctx);",
        "}",
        "",
    ]))


def make_test(a, b, modulus, use_mont=True):
    """Print C unit tests for the Montgomery product of *a* and *b*.

    Let nw be the smallest number of 64-bit words with 2**(64*nw) > modulus,
    and R = 2**(64*nw). The operands are emitted in Montgomery form
    (a*R mod n, b*R mod n) and the expected output is a*b*R mod n.

    Args:
        a, b: operands, each in the range [0, modulus).
        modulus: odd modulus n.
        use_mont: when True, emit one test calling mont_mult_generic()
            directly and one going through a MontContext. When False, R is
            forced to 1, so operands and expected value are plain residues,
            and only the MontContext test is emitted.

    Each generated test allocates nw+1 output words, pre-filled with 0xAA,
    and asserts the extra word is untouched to catch writes past nw words.
    """
    assert 0 <= a < modulus
    assert 0 <= b < modulus
    assert modulus & 1

    B = 1 << 64
    R, nw = 1, 0
    while modulus >= R:
        R <<= 64
        nw += 1

    if not use_mont:
        R = 1

    # m0 = -inverse(n mod 2^64, 2^64) mod 2^64: the per-word Montgomery
    # constant (assuming common.inverse is the modular inverse). Only the
    # lowest 64-bit word of the modulus affects it.
    m0 = -inverse(modulus & (B - 1), B) % B
    assert 0 < m0 < B

    # Everything is emitted as nw 64-bit words; the modulus is additionally
    # byte encoded, big endian, for mont_context_init().
    a_m_s = _pad_words((a * R) % modulus, nw)
    b_m_s = _pad_words((b * R) % modulus, nw)
    modulus_s = _pad_words(modulus, nw)
    result_m_s = _pad_words((a * b * R) % modulus, nw)
    modulus_b = _be_bytes(modulus)

    if use_mont:
        _emit_generic_test(a_m_s, b_m_s, modulus_s, result_m_s, m0, nw)
    _emit_context_test(a_m_s, b_m_s, modulus_b, result_m_s, nw)
90
91
92
# Preamble of the generated C file. mont_mult_generic() is declared here so
# the generated tests can call it directly, in addition to the mont.h API.
for header_line in (
    "#include <assert.h>",
    "#include <string.h>",
    "#include <stdint.h>",
    "#include <stdio.h>",
    '#include "mont.h"',
    "",
    "void mont_mult_generic(uint64_t *out, const uint64_t *a, const uint64_t *b, const uint64_t *n, uint64_t m0, uint64_t *t, size_t nw);",
):
    print(header_line)

# Field primes of the NIST curves P-256, P-384 and P-521.
p256 = 115792089210356248762697446949407573530086143415290314195533631308867097853951
p384 = 39402006196394479212279040100143613805079739270465446667948293404245721771496870329047266088258938001861606973112319
p521 = 0x000001ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff

# Hand-picked vectors covering 1-, 2- and 4-word moduli.
make_test(2, 3, 255)
make_test(2, 240, 255)
make_test(189, 240, 255)
make_test(189, 240, 32984723984723984723847)
make_test(189000000, 7878787878, 32984723984723984723847)
make_test(1890000003439483948394839843434, 78787878780003984834673498384734, 3298472398472398472384798743287438734875384758435834539400000033988787)


def _digest(label, x):
    """Return SHA-256 of *label* followed by *x* packed as a 4-byte big-endian int."""
    return sha256(label + struct.pack(">I", x)).digest()


# Pseudo-random odd moduli. Between 5 and 14 trailing bytes are trimmed off
# the 32-byte digest, leaving moduli of at most 216 down to 144 bits
# (roughly 3-4 64-bit words); "| 1" forces them odd.
for x in range(100):
    trimmed_bytes = x // 10 + 5
    modulus = bin2int(_digest(b"modulus", x)[:-trimmed_bytes]) | 1
    a = bin2int(_digest(b"a", x)) % modulus
    b = bin2int(_digest(b"b", x)) % modulus
    make_test(a, b, modulus)

# Random operands modulo each NIST prime. P-521 is exercised only through the
# MontContext path with R = 1 (use_mont=False); presumably mont_context_init
# special-cases this modulus -- confirm against mont.c.
for prime, use_mont in ((p256, True), (p384, True), (p521, False)):
    for x in range(100):
        a = bin2int(_digest(b"a", x)) % prime
        b = bin2int(_digest(b"b", x)) % prime
        make_test(a, b, prime, use_mont=use_mont)

make_main()
135