1#!/usr/bin/env python3
2
3import sys
4import string
5from collections import namedtuple
6
# Refuse to run under Python 2: only the major version matters here.
assert sys.version_info[0] >= 3, "This is Python 3 code"
8
class Multiprecision(object):
    """A non-negative integer stored as a list of target-sized words.

    'words' holds one string per word, least significant first; each
    string is whatever the target handed back for that word (a value
    name or a literal). 'minval' and 'maxval' bound the values the
    integer can take, and that range analysis determines how many words
    every derived result needs.
    """

    def __init__(self, target, minval, maxval, words):
        """Record the target, the value range and the word list.

        The word list must be exactly as long as target.nwords(maxval).
        """
        self.target = target
        self.minval = minval
        self.maxval = maxval
        self.words = words
        assert 0 <= self.minval
        assert self.minval <= self.maxval
        assert self.target.nwords(self.maxval) == len(words)

    def getword(self, n):
        """Return word n, or the literal "0" beyond the stored words."""
        return self.words[n] if n < len(self.words) else "0"

    def __add__(self, rhs):
        """Return self + rhs as a new Multiprecision.

        The lowest word uses the target's plain add; every higher word
        uses add-with-carry so the carry chain runs up the result.
        """
        newmin = self.minval + rhs.minval
        newmax = self.maxval + rhs.maxval
        nwords = self.target.nwords(newmax)
        words = []

        addfn = self.target.add
        for i in range(nwords):
            words.append(addfn(self.getword(i), rhs.getword(i)))
            addfn = self.target.adc

        return Multiprecision(self.target, newmin, newmax, words)

    def __mul__(self, rhs):
        """Return self * rhs as a new Multiprecision."""
        newmin = self.minval * rhs.minval
        newmax = self.maxval * rhs.maxval

        # There are basically two strategies we could take for
        # multiplying two multiprecision integers. One is to enumerate
        # the space of pairs of word indices in lexicographic order,
        # essentially computing a*b[i] for each i and adding them
        # together; the other is to enumerate in diagonal order,
        # computing everything together that belongs at a particular
        # output word index.
        #
        # For the moment, I've gone for the former.

        sprev = []
        for i, sword in enumerate(self.words):
            # rprev is the high half carried up from the previous
            # column of this row; sthis is the running sum after
            # adding in this row. Words below index i are already
            # final and are copied over unchanged.
            rprev = None
            sthis = sprev[:i]
            for j, rword in enumerate(rhs.words):
                prevwords = []
                if i + j < len(sprev):
                    prevwords.append(sprev[i + j])
                if rprev is not None:
                    prevwords.append(rprev)
                vhi, vlo = self.target.muladd(sword, rword, *prevwords)
                sthis.append(vlo)
                rprev = vhi
            sthis.append(rprev)
            sprev = sthis

        # Remove unneeded words from the top of the output, if we can
        # prove by range analysis that they'll always be zero.
        sprev = sprev[:self.target.nwords(newmax)]

        return Multiprecision(self.target, newmin, newmax, sprev)

    def extract_bits(self, start, bits=None):
        """Return bits [start, start+bits) of self as a new Multiprecision.

        If 'bits' is None, everything from 'start' up to the top of the
        value's known range is taken.
        """
        wordbits = self.target.bits
        if bits is None:
            bits = (self.maxval >> start).bit_length()

        # Overly thorough range analysis: if min and max have the same
        # *quotient* by 2^bits, then the result of reducing anything
        # in the range [min,max] mod 2^bits has to fall within the
        # range obtained by reducing both endpoints. But if they have
        # different quotients, then you can wrap round the modulus and
        # so any value mod 2^bits is possible.
        newmin = self.minval >> start
        newmax = self.maxval >> start
        fullmask = (1 << bits) - 1
        if (newmin >> bits) != (newmax >> bits):
            newmin = 0
            newmax = fullmask
        else:
            # Same quotient: subtract it off both ends. (A no-op when
            # the quotient is zero.)
            newmin &= fullmask
            newmax &= fullmask

        # If the source can have bits set at or above start+bits, the
        # top output word has to be masked even when the whole
        # extraction is narrower than one word.
        source_has_higher_bits = (self.maxval >> (start + bits)) != 0

        words = []
        for i in range(self.target.nwords(newmax)):
            srcpos = i * wordbits + start
            maxbits = min(wordbits, start + bits - srcpos)
            wordindex, shift = divmod(srcpos, wordbits)
            if shift == 0:
                # Word-aligned: reuse the source word directly.
                word = self.getword(wordindex)
            elif (wordindex + 1 >= len(self.words) or
                  shift + maxbits < wordbits):
                # Everything we need lives in a single source word.
                word = self.target.new_value(
                    "(%%s) >> %d" % shift,
                    self.getword(wordindex))
            else:
                # Straddles two source words: stitch them together.
                word = self.target.new_value(
                    "((%%s) >> %d) | ((%%s) << %d)" % (
                        shift, wordbits - shift),
                    self.getword(wordindex),
                    self.getword(wordindex + 1))
            if maxbits < wordbits and (maxbits < bits or
                                       source_has_higher_bits):
                word = self.target.new_value(
                    "(%%s) & ((((BignumInt)1) << %d)-1)" % maxbits,
                    word)
            words.append(word)

        return Multiprecision(self.target, newmin, newmax, words)
116
# A Statement records the variables it reads ('rvars') and the ones it
# writes ('wvars'). 'forms' lists the alternative C statements it can
# be emitted as, chosen according to which of its outputs are actually
# used (e.g. no point calling BignumADC if the generated carry in a
# particular case is unused, or BignumMUL if nobody needs the top
# half). The index into 'forms' is a bitmap with one bit per entry of
# wvars: wvars[0] is the most significant bit, wvars[-1] the least.
Statement = namedtuple("Statement", ["rvars", "wvars", "forms"])
125
class CodegenTarget(object):
    """Collects C statements for one BignumInt width and emits the used ones.

    Arithmetic helpers append Statements describing each operation;
    text() then works back from the statements marked as needed, drops
    everything unreachable, and picks the cheapest form of each
    remaining statement.
    """

    def __init__(self, bits):
        """Create an empty target whose words are 'bits' bits wide."""
        self.bits = bits
        self.valindex = 0         # number used for the next "v%d" name
        self.stmts = []           # list of [needed, Statement] pairs
        self.generators = {}      # variable name -> index into stmts
        # Number of words in a stored bigval: enough for 130 bits.
        self.bv_words = (130 + self.bits - 1) // self.bits
        self.carry_index = 0      # number of the most recent carry

    def nwords(self, maxval):
        """Return how many words are needed to hold values up to maxval."""
        return (maxval.bit_length() + self.bits - 1) // self.bits

    def stmt(self, stmt, needed=False):
        """Append a statement and register it as generator of its outputs."""
        index = len(self.stmts)
        self.stmts.append([needed, stmt])
        for val in stmt.wvars:
            self.generators[val] = index

    def new_value(self, formatstr=None, *deps):
        """Allocate a fresh value name and return it.

        If 'formatstr' is given, also record a statement assigning
        formatstr % deps to the new value, with 'deps' as its inputs.
        """
        name = "v%d" % self.valindex
        self.valindex += 1
        if formatstr is not None:
            self.stmt(Statement(
                    rvars=deps, wvars=[name],
                    forms=[None, name + " = " + formatstr % deps]))
        return name

    def bigval_input(self, name, bits):
        """Return a Multiprecision reading the words of C bigval 'name'.

        'bits' is the assumed width of the input for range analysis.
        """
        words = (bits + self.bits - 1) // self.bits
        # Expect not to require an entire extra word
        assert words == self.bv_words

        return Multiprecision(self, 0, (1<<bits)-1, [
                self.new_value("%s->w[%d]" % (name, i)) for i in range(words)])

    def const(self, value):
        """Return a Multiprecision for a small non-negative constant."""
        # We only support constants small enough to both fit in a
        # BignumInt (of any size supported) _and_ be expressible in C
        # with no weird integer literal syntax like a trailing LL.
        #
        # Supporting larger constants would be possible - you could
        # break 'value' up into word-sized pieces on the Python side,
        # and generate a legal C expression for each piece by
        # splitting it further into pieces within the
        # standards-guaranteed 'unsigned long' limit of 32 bits and
        # then casting those to BignumInt before combining them with
        # shifts. But it would be a lot of effort, and since the
        # application for this code doesn't even need it, there's no
        # point in bothering.
        assert 0 <= value < 2**16
        # Zero occupies no words at all (nwords(0) == 0), so supplying
        # a "0" word would trip Multiprecision's length check.
        words = ["%d" % value] if value != 0 else []
        return Multiprecision(self, value, value, words)

    def current_carry(self):
        """Return the name of the most recently produced carry."""
        return "carry%d" % self.carry_index

    def add(self, a1, a2):
        """Record a1 + a2 with no carry in; return the sum's name."""
        ret = self.new_value()
        adcform = "BignumADC(%s, carry, %s, %s, 0)" % (ret, a1, a2)
        plainform = "%s = %s + %s" % (ret, a1, a2)
        self.carry_index += 1
        carryout = self.current_carry()
        # Form index bits are (ret used, carry used): the plain C
        # addition suffices only when nobody reads the carry.
        self.stmt(Statement(
                rvars=[a1,a2], wvars=[ret,carryout],
                forms=[None, adcform, plainform, adcform]))
        return ret

    def adc(self, a1, a2):
        """Record a1 + a2 + the current carry; return the sum's name."""
        ret = self.new_value()
        adcform = "BignumADC(%s, carry, %s, %s, carry)" % (ret, a1, a2)
        plainform = "%s = %s + %s + carry" % (ret, a1, a2)
        carryin = self.current_carry()
        self.carry_index += 1
        carryout = self.current_carry()
        self.stmt(Statement(
                rvars=[a1,a2,carryin], wvars=[ret,carryout],
                forms=[None, adcform, plainform, adcform]))
        return ret

    def muladd(self, m1, m2, *addends):
        """Record m1 * m2 plus up to two addends; return (high, low) names."""
        rlo = self.new_value()
        rhi = self.new_value()
        wideform = "BignumMUL%s(%s)" % (
            { 0:"", 1:"ADD", 2:"ADD2" }[len(addends)],
            ", ".join([rhi, rlo, m1, m2] + list(addends)))
        narrowform = " + ".join(["%s = %s * %s" % (rlo, m1, m2)] +
                                list(addends))
        # Form index bits are (rhi used, rlo used): the narrow C
        # expression suffices only when the high half is unused.
        self.stmt(Statement(
                rvars=[m1,m2]+list(addends), wvars=[rhi,rlo],
                forms=[None, narrowform, wideform, wideform]))
        return rhi, rlo

    def write_bigval(self, name, val):
        """Record stores of val's words into C bigval 'name'.

        These stores are marked as needed, so they are the roots from
        which compute_needed() decides what else to keep.
        """
        for i in range(self.bv_words):
            word = val.getword(i)
            self.stmt(Statement(
                    rvars=[word], wvars=[],
                    forms=["%s->w[%d] = %s" % (name, i, word)]),
                      needed=True)

    def compute_needed(self):
        """Work out which statements and variables are actually used.

        Returns (used_carry, used_vars, forms): whether any carry is
        used, the sorted list of "v%d" names to declare, and the C
        statement texts in emission order.
        """
        used_vars = set()

        # Propagate 'needed' backwards from the root statements. The
        # visiting order doesn't affect the resulting closure, so a
        # plain stack will do.
        pending = [stmt for (needed, stmt) in self.stmts if needed]
        while pending:
            stmt = pending.pop()
            for var in stmt.rvars:
                if var[0] in string.digits:
                    continue # constant
                entry = self.stmts[self.generators[var]]
                used_vars.add(var)
                if not entry[0]:
                    entry[0] = True
                    pending.append(entry[1])

        forms = []
        for needed, stmt in self.stmts:
            if not needed:
                continue

            # Build the bitmap of used outputs, wvars[0] as the MSB.
            formindex = 0
            for var in stmt.wvars:
                formindex *= 2
                if var in used_vars:
                    formindex += 1
            chosen = stmt.forms[formindex]
            forms.append(chosen)

            # Now we must check whether this form of the statement
            # also writes some variables we _don't_ actually need
            # (e.g. if you only wanted the top half from a mul, or
            # only the carry from an adc, you'd be forced to
            # generate the other output too). Easiest way to do
            # this is to look for an identical statement form
            # later in the array.
            maxindex = max(k for k, form in enumerate(stmt.forms)
                           if form == chosen)
            extra_vars = maxindex & ~formindex
            for bitpos in range(extra_vars.bit_length()):
                if extra_vars & (1 << bitpos):
                    var = stmt.wvars[-1-bitpos]
                    used_vars.add(var)
                    # Also, write out a cast-to-void for each
                    # subsequently unused value, to prevent gcc
                    # warnings when the output code is compiled.
                    forms.append("(void)" + var)

        used_carry = any(v.startswith("carry") for v in used_vars)
        used_vars = [v for v in used_vars if v.startswith("v")]
        used_vars.sort(key=lambda v: int(v[1:]))

        return used_carry, used_vars, forms

    def text(self):
        """Return the C function body: declarations, then statements."""
        used_carry, values, forms = self.compute_needed()

        prefix, sep, suffix = "    BignumInt ", ", ", ";"
        lines = []
        pos = 0
        while pos < len(values):
            # Pack as many names as fit on a line under 79 columns.
            currline = values[pos]
            pos += 1
            while (pos < len(values) and
                   len(prefix+currline+sep+values[pos]+suffix) < 79):
                currline += sep + values[pos]
                pos += 1
            lines.append(prefix + currline + suffix + "\n")
        if used_carry:
            lines.append("    BignumCarry carry;\n")
        if lines:
            # Blank line between declarations and statements.
            lines.append("\n")
        lines.extend("    %s;\n" % stmtform for stmtform in forms)
        return "".join(lines)
298
def gen_add(target):
    """Return the C source of bigval_add, generated on 'target'."""
    # The addition is deliberately _not_ reduced mod p: the same
    # routine serves both for accumulating the polynomial and for
    # adding the encrypted nonce at the end (which works mod 2^128
    # rather than mod p).
    #
    # One operand will have come out of our not-completely-reducing
    # multiplication function, so allow for up to 3 extra bits of
    # input.

    lhs = target.bigval_input("a", 133)
    rhs = target.bigval_input("b", 133)
    target.write_bigval("r", lhs + rhs)
    body = target.text()
    return (
        "static void bigval_add(bigval *r, const bigval *a, const bigval *b)\n"
        "{\n"
        "%s}\n"
        "\n"
    ) % body
318
def gen_mul(target):
    """Return the C source of bigval_mul_mod_p, generated on 'target'."""
    # Inputs are not fully reduced mod p. The pow5==0 pass can yield a
    # full 130-bit number, and the pow5==1 pass a 130-bit number times
    # 5, plus a possible carry. That total is easily bounded above by
    # 2^130 * 8, so both operands are treated as 133-bit numbers.

    lhs = target.bigval_input("a", 133)
    rhs = target.bigval_input("b", 133)
    product = lhs * rhs

    # Split the product into 130-bit limbs. Each step up in limb
    # position is folded back in as a further factor of 5.
    limb0 = product.extract_bits(0, 130)
    limb1 = product.extract_bits(130, 130)
    limb2 = product.extract_bits(260)
    limb1_times_5 = target.const(5) * limb1
    limb2_times_25 = target.const(25) * limb2

    partial = limb0 + limb1_times_5
    target.write_bigval("r", partial + limb2_times_25)
    body = target.text()
    return (
        "static void bigval_mul_mod_p(bigval *r, const bigval *a, const bigval *b)\n"
        "{\n"
        "%s}\n"
        "\n"
    ) % body
341
def gen_final_reduce(target):
    """Return the C source of bigval_final_reduce, generated on 'target'."""
    # For an input n, n >> 130 is usually exactly the multiple of p
    # that must be subtracted to bring n strictly below p, but it can
    # be one too low (never more than one, since the input range is
    # nowhere near the square of the modulus). Adding a further 5
    # pushes a carry into the 130th bit if and only if that happened,
    # and that carry decides whether one more copy of p comes off.

    n = target.bigval_input("n", 133)
    quotient = n.extract_bits(130)
    low_part = n.extract_bits(0, 130)
    adjusted = low_part + target.const(5) * quotient

    probe = adjusted + target.const(5)
    final_subtract = probe.extract_bits(130)

    adjusted2 = adjusted + target.const(5) * final_subtract
    target.write_bigval("n", adjusted2.extract_bits(0, 130))
    body = target.text()
    return (
        "static void bigval_final_reduce(bigval *n)\n"
        "{\n"
        "%s}\n"
        "\n"
    ) % body
363
# Emit one preprocessor branch per supported BignumInt width, each
# containing all three generated functions, then a catch-all #error.
pp_keyword = "#if"
for bits in (16, 32, 64):
    sys.stdout.write("%s BIGNUM_INT_BITS == %d\n\n" % (pp_keyword, bits))
    # Each function gets a fresh target so value numbering restarts.
    for _generate in (gen_add, gen_mul, gen_final_reduce):
        sys.stdout.write(_generate(CodegenTarget(bits)))
    pp_keyword = "#elif"
sys.stdout.write(
    "#else\n"
    "#error Add another bit count to contrib/make1305.py and rerun it\n"
    "#endif\n")
375