#!/usr/bin/env python3

import sys
import string
from collections import namedtuple

assert sys.version_info[:2] >= (3,0), "This is Python 3 code"

class Multiprecision(object):
    """A multi-word integer whose words live in generated C values.

    'target' is the code-generation object: it knows the word size
    ('bits'), counts words ('nwords') and emits C statements.
    'words' lists the C expressions (value names or decimal constants)
    holding each word, least significant first. 'minval' and 'maxval'
    are inclusive bounds on the integer's value. Arithmetic propagates
    them so that each result is given only as many words as its maximum
    can occupy.
    """

    def __init__(self, target, minval, maxval, words):
        self.target = target
        self.minval = minval
        self.maxval = maxval
        self.words = words
        assert 0 <= self.minval
        assert self.minval <= self.maxval
        assert self.target.nwords(self.maxval) == len(words)

    def getword(self, n):
        """Return the C expression for word n, or "0" above the top word."""
        return self.words[n] if n < len(self.words) else "0"

    def __add__(self, rhs):
        """Return self + rhs as a ripple-carry chain of target add/adc calls.

        The result has nwords(self.maxval + rhs.maxval) words. Positions
        beyond either operand's top word are read as "0" via getword().
        """
        newmin = self.minval + rhs.minval
        newmax = self.maxval + rhs.maxval
        nwords = self.target.nwords(newmax)
        words = []

        # The lowest word has no carry-in; every later word consumes the
        # carry produced by the word below it.
        addfn = self.target.add
        for i in range(nwords):
            words.append(addfn(self.getword(i), rhs.getword(i)))
            addfn = self.target.adc

        return Multiprecision(self.target, newmin, newmax, words)

    def __mul__(self, rhs):
        """Return self * rhs, built from target.muladd operations."""
        newmin = self.minval * rhs.minval
        newmax = self.maxval * rhs.maxval
        nwords = self.target.nwords(newmax)

        # There are basically two strategies we could take for
        # multiplying two multiprecision integers. One is to enumerate
        # the space of pairs of word indices in lexicographic order,
        # essentially computing a*b[i] for each i and adding them
        # together; the other is to enumerate in diagonal order,
        # computing everything together that belongs at a particular
        # output word index.
        #
        # For the moment, I've gone for the former.

        sprev = []
        for i, sword in enumerate(self.words):
            rprev = None
            # Words below position i are already final from earlier rows.
            sthis = sprev[:i]
            for j, rword in enumerate(rhs.words):
                # Fold in the previous row's word at this position and the
                # high half carried over from the previous column.
                prevwords = []
                if i+j < len(sprev):
                    prevwords.append(sprev[i+j])
                if rprev is not None:
                    prevwords.append(rprev)
                vhi, vlo = self.target.muladd(sword, rword, *prevwords)
                sthis.append(vlo)
                rprev = vhi
            sthis.append(rprev)
            sprev = sthis

        # Remove unneeded words from the top of the output, if we can
        # prove by range analysis that they'll always be zero.
        sprev = sprev[:nwords]

        return Multiprecision(self.target, newmin, newmax, sprev)

    def extract_bits(self, start, bits=None):
        """Return (self >> start) mod 2**bits as a new Multiprecision.

        If 'bits' is omitted, take every bit from 'start' upwards that
        maxval permits, so no reduction is needed.
        """
        wordbits = self.target.bits
        if bits is None:
            bits = (self.maxval >> start).bit_length()
        mask = (1 << bits) - 1

        # Overly thorough range analysis: if min and max have the same
        # *quotient* by 2^bits, then the result of reducing anything
        # in the range [min,max] mod 2^bits has to fall within the
        # obvious range [min mod 2^bits, max mod 2^bits]. But if they
        # have different quotients, then you can wrap round the
        # modulus and so any value mod 2^bits is possible.
        newmin = self.minval >> start
        newmax = self.maxval >> start
        if (newmin >> bits) != (newmax >> bits):
            newmin = 0
            newmax = mask
        else:
            newmin &= mask
            newmax &= mask

        # Can the source hold set bits at or above position start+bits?
        # If so, a partial top word must be masked, even when it is the
        # only output word.
        high_bits_possible = (self.maxval >> (start + bits)) != 0

        nwords = self.target.nwords(newmax)
        words = []
        for i in range(nwords):
            srcpos = i * wordbits + start
            maxbits = min(wordbits, start + bits - srcpos)
            wordindex, shift = divmod(srcpos, wordbits)
            if shift == 0:
                word = self.getword(wordindex)
            elif (wordindex+1 >= len(self.words) or
                  shift + maxbits < wordbits):
                # Everything wanted comes from a single source word.
                word = self.target.new_value(
                    "(%%s) >> %d" % shift,
                    self.getword(wordindex))
            else:
                # Splice the top of one source word onto the bottom of
                # the next.
                word = self.target.new_value(
                    "((%%s) >> %d) | ((%%s) << %d)" % (
                        shift, wordbits - shift),
                    self.getword(wordindex),
                    self.getword(wordindex + 1))
            if maxbits < wordbits and (i > 0 or high_bits_possible):
                word = self.target.new_value(
                    "(%%s) & ((((BignumInt)1) << %d)-1)" % maxbits,
                    word)
            words.append(word)

        return Multiprecision(self.target, newmin, newmax, words)

# Each Statement has a list of variables it reads, and a list of ones
# it writes. 'forms' is a list of multiple actual C statements it
# could be generated as, depending on which of its output variables is
# actually used (e.g. no point calling BignumADC if the generated
# carry in a particular case is unused, or BignumMUL if nobody needs
# the top half). It is indexed by a bitmap whose bits correspond to
# the entries in wvars, with wvars[0] the MSB and wvars[-1] the LSB.
Statement = namedtuple("Statement", "rvars wvars forms")

class CodegenTarget(object):
    """Collects the C statements for one generated function.

    'bits' is the width of a BignumInt. Values are named v0, v1, ... in
    allocation order. Carry outputs are tracked as carry1, carry2, ...
    for dependency purposes only; the generated C reuses a single
    variable called 'carry'. text() emits only the statements that
    contribute to some output recorded by write_bigval().
    """

    def __init__(self, bits):
        self.bits = bits
        self.valindex = 0       # N for the next fresh value name "vN"
        self.stmts = []         # list of [needed, Statement] pairs
        self.generators = {}    # variable name -> index into self.stmts
        # Words per bigval: enough for 130 bits (presumably matching the
        # C 'bigval' definition -- confirm against the C side).
        self.bv_words = (130 + self.bits - 1) // self.bits
        self.carry_index = 0    # N of the most recently allocated carryN

    def nwords(self, maxval):
        """Return how many words are needed for values up to maxval."""
        return (maxval.bit_length() + self.bits - 1) // self.bits

    def stmt(self, stmt, needed=False):
        """Record a Statement, making it the generator of its wvars.

        needed=True marks it as a root of the liveness analysis in
        compute_needed(), so it is always emitted.
        """
        index = len(self.stmts)
        self.stmts.append([needed, stmt])
        for val in stmt.wvars:
            self.generators[val] = index

    def new_value(self, formatstr=None, *deps):
        """Allocate a fresh value name "vN" and return it.

        If formatstr is given, also record the statement
        "vN = formatstr % deps", which reads the variables in deps.
        """
        name = "v%d" % self.valindex
        self.valindex += 1
        if formatstr is not None:
            self.stmt(Statement(
                rvars=deps, wvars=[name],
                forms=[None, name + " = " + formatstr % deps]))
        return name

    def bigval_input(self, name, bits):
        """Return a Multiprecision loaded from name->w[0..], up to 'bits' bits."""
        words = (bits + self.bits - 1) // self.bits
        # Expect not to require an entire extra word
        assert words == self.bv_words

        return Multiprecision(self, 0, (1<<bits)-1, [
            self.new_value("%s->w[%d]" % (name, i)) for i in range(words)])

    def const(self, value):
        """Return a Multiprecision for the small non-negative constant 'value'."""
        # We only support constants small enough to both fit in a
        # BignumInt (of any size supported) _and_ be expressible in C
        # with no weird integer literal syntax like a trailing LL.
        #
        # Supporting larger constants would be possible - you could
        # break 'value' up into word-sized pieces on the Python side,
        # and generate a legal C expression for each piece by
        # splitting it further into pieces within the
        # standards-guaranteed 'unsigned long' limit of 32 bits and
        # then casting those to BignumInt before combining them with
        # shifts. But it would be a lot of effort, and since the
        # application for this code doesn't even need it, there's no
        # point in bothering.
        assert 0 <= value < 2**16
        # Zero occupies no words (nwords(0) == 0), which Multiprecision
        # asserts; getword() still reads it as "0".
        words = ["%d" % value] if value else []
        return Multiprecision(self, value, value, words)

    def current_carry(self):
        """Return the name of the most recently allocated carry variable."""
        return "carry%d" % self.carry_index

    def add(self, a1, a2):
        """Record ret = a1 + a2 with a carry-out; return ret's name."""
        ret = self.new_value()
        adcform = "BignumADC(%s, carry, %s, %s, 0)" % (ret, a1, a2)
        plainform = "%s = %s + %s" % (ret, a1, a2)
        self.carry_index += 1
        carryout = self.current_carry()
        # Form index bits: ret used = 2, carry-out used = 1.
        self.stmt(Statement(
            rvars=[a1,a2], wvars=[ret,carryout],
            forms=[None, adcform, plainform, adcform]))
        return ret

    def adc(self, a1, a2):
        """Record ret = a1 + a2 + carry-in, with a carry-out; return ret."""
        ret = self.new_value()
        adcform = "BignumADC(%s, carry, %s, %s, carry)" % (ret, a1, a2)
        plainform = "%s = %s + %s + carry" % (ret, a1, a2)
        carryin = self.current_carry()
        self.carry_index += 1
        carryout = self.current_carry()
        self.stmt(Statement(
            rvars=[a1,a2,carryin], wvars=[ret,carryout],
            forms=[None, adcform, plainform, adcform]))
        return ret

    def muladd(self, m1, m2, *addends):
        """Record (hi, lo) = m1 * m2 + addends (at most two); return both names.

        When only the low half is used, a plain single-word C
        expression is emitted instead of a BignumMUL* call.
        """
        rlo = self.new_value()
        rhi = self.new_value()
        wideform = "BignumMUL%s(%s)" % (
            { 0:"", 1:"ADD", 2:"ADD2" }[len(addends)],
            ", ".join([rhi, rlo, m1, m2] + list(addends)))
        narrowform = " + ".join(["%s = %s * %s" % (rlo, m1, m2)] +
                                list(addends))
        self.stmt(Statement(
            rvars=[m1,m2]+list(addends), wvars=[rhi,rlo],
            forms=[None, narrowform, wideform, wideform]))
        return rhi, rlo

    def write_bigval(self, name, val):
        """Record always-emitted stores of val's words into name->w[]."""
        # Any word above bv_words would otherwise be silently dropped.
        assert len(val.words) <= self.bv_words, "value too wide for a bigval"
        for i in range(self.bv_words):
            word = val.getword(i)
            self.stmt(Statement(
                rvars=[word], wvars=[],
                forms=["%s->w[%d] = %s" % (name, i, word)]),
                      needed=True)

    def compute_needed(self):
        """Decide which statements and variables the output needs.

        Starting from the statements marked needed, mark every statement
        that generates a variable they transitively read. Then pick each
        needed statement's form from which of its outputs are read.

        Returns (used_carry, used_vars, forms):
          used_carry -- True if any carry variable is read;
          used_vars  -- the "vN" names to declare, sorted by N;
          forms      -- the C statement strings, in recording order.
        """
        from collections import deque

        used_vars = set()

        queue = deque(stmt for (needed, stmt) in self.stmts if needed)
        while queue:
            stmt = queue.popleft()
            deps = []
            for var in stmt.rvars:
                if var[0] in string.digits:
                    continue # constant
                deps.append(self.generators[var])
                used_vars.add(var)
            for index in deps:
                if not self.stmts[index][0]:
                    self.stmts[index][0] = True
                    queue.append(self.stmts[index][1])

        forms = []
        for needed, stmt in self.stmts:
            if not needed:
                continue
            formindex = 0
            for var in stmt.wvars:
                formindex *= 2
                if var in used_vars:
                    formindex += 1
            form = stmt.forms[formindex]
            forms.append(form)

            # Now we must check whether this form of the statement
            # also writes some variables we _don't_ actually need
            # (e.g. if you only wanted the top half from a mul, or
            # only the carry from an adc, you'd be forced to
            # generate the other output too). Easiest way to do
            # this is to look for an identical statement form
            # later in the array.
            maxindex = max(k for k, other in enumerate(stmt.forms)
                           if other == form)
            extra_vars = maxindex & ~formindex
            bitpos = 0
            while extra_vars != 0:
                if extra_vars & (1 << bitpos):
                    extra_vars &= ~(1 << bitpos)
                    var = stmt.wvars[-1-bitpos]
                    used_vars.add(var)
                    # Also, write out a cast-to-void for each
                    # subsequently unused value, to prevent gcc
                    # warnings when the output code is compiled.
                    forms.append("(void)" + var)
                bitpos += 1

        used_carry = any(v.startswith("carry") for v in used_vars)
        used_vars = [v for v in used_vars if v.startswith("v")]
        used_vars.sort(key=lambda v: int(v[1:]))

        return used_carry, used_vars, forms

    def text(self):
        """Return the C function body: declarations, a blank line, statements."""
        used_carry, values, forms = self.compute_needed()

        ret = ""
        while len(values) > 0:
            prefix, sep, suffix = " BignumInt ", ", ", ";"
            currline = values.pop(0)
            # Pack as many names onto each declaration as fit in 79 columns.
            while (len(values) > 0 and
                   len(prefix+currline+sep+values[0]+suffix) < 79):
                currline += sep + values.pop(0)
            ret += prefix + currline + suffix + "\n"
        if used_carry:
            ret += " BignumCarry carry;\n"
        if ret != "":
            ret += "\n"
        for stmtform in forms:
            ret += " %s;\n" % stmtform
        return ret

def gen_add(target):
    """Return the C source of bigval_add() for the given target."""
    # This is an addition _without_ reduction mod p, so that it can be
    # used both during accumulation of the polynomial and for adding
    # on the encrypted nonce at the end (which is mod 2^128, not mod
    # p).
    #
    # Because one of the inputs will have come from our
    # not-completely-reducing multiplication function, we expect up to
    # 3 extra bits of input.

    a = target.bigval_input("a", 133)
    b = target.bigval_input("b", 133)
    ret = a + b
    target.write_bigval("r", ret)
    return """\
static void bigval_add(bigval *r, const bigval *a, const bigval *b)
{
%s}
\n""" % target.text()

def gen_mul(target):
    """Return the C source of bigval_mul_mod_p() for the given target."""
    # The inputs are not 100% reduced mod p. Specifically, we can get
    # a full 130-bit number from the pow5==0 pass, and then a 130-bit
    # number times 5 from the pow5==1 pass, plus a possible carry. The
    # total of that can be easily bounded above by 2^130 * 8, so we
    # need to assume we're multiplying two 133-bit numbers.

    a = target.bigval_input("a", 133)
    b = target.bigval_input("b", 133)
    ab = a * b
    ab0 = ab.extract_bits(0, 130)
    ab1 = ab.extract_bits(130, 130)
    ab2 = ab.extract_bits(260)
    ab1_5 = target.const(5) * ab1
    ab2_25 = target.const(25) * ab2
    ret = ab0 + ab1_5 + ab2_25
    target.write_bigval("r", ret)
    return """\
static void bigval_mul_mod_p(bigval *r, const bigval *a, const bigval *b)
{
%s}
\n""" % target.text()

def gen_final_reduce(target):
    """Return the C source of bigval_final_reduce() for the given target."""
    # Given our input number n, n >> 130 is usually precisely the
    # multiple of p that needs to be subtracted from n to reduce it to
    # strictly less than p, but it might be too low by 1 (but not more
    # than 1, given the range of our input is nowhere near the square
    # of the modulus). So we add another 5, which will push a carry
    # into the 130th bit if and only if that has happened, and then
    # use that to decide whether to subtract one more copy of p.

    a = target.bigval_input("n", 133)
    q = a.extract_bits(130)
    adjusted = a.extract_bits(0, 130) + target.const(5) * q
    final_subtract = (adjusted + target.const(5)).extract_bits(130)
    adjusted2 = adjusted + target.const(5) * final_subtract
    ret = adjusted2.extract_bits(0, 130)
    target.write_bigval("n", ret)
    return """\
static void bigval_final_reduce(bigval *n)
{
%s}
\n""" % target.text()

def main():
    """Write the generated C for every supported BIGNUM_INT_BITS to stdout."""
    pp_keyword = "#if"
    for bits in [16, 32, 64]:
        sys.stdout.write("%s BIGNUM_INT_BITS == %d\n\n" % (pp_keyword, bits))
        pp_keyword = "#elif"
        sys.stdout.write(gen_add(CodegenTarget(bits)))
        sys.stdout.write(gen_mul(CodegenTarget(bits)))
        sys.stdout.write(gen_final_reduce(CodegenTarget(bits)))
    sys.stdout.write("""#else
#error Add another bit count to contrib/make1305.py and rerun it
#endif
""")

if __name__ == "__main__":
    main()