# SageMath script (Python 2 dialect) comparing unoptimized "spec" and
# optimized implementations of Ristretto / Decaf point encoding, decoding,
# Elligator and double-and-encode on several Edwards curves.
import binascii
import functools

class InvalidEncodingException(Exception): pass
class NotOnCurveException(Exception): pass
class SpecException(Exception): pass

def lobit(x): return int(x) & 1
def hibit(x): return lobit(2*x)
def negative(x): return lobit(x)
def enc_le(x,n): return bytearray([int(x)>>(8*i) & 0xFF for i in xrange(n)])
def dec_le(x): return sum(b<<(8*i) for i,b in enumerate(x))
def randombytes(n): return bytearray([randint(0,255) for _ in range(n)])

def optimized_version_of(spec):
    """Decorator: mark a method as an optimized version of the method `spec`.

    Every call runs the specification first (looked up by name on the
    receiver; if the receiver has no such attribute, `spec` itself is
    called), then the optimized function, with the same arguments, and
    cross-checks them:

    * if exactly one of them raised, that exception is re-raised;
    * if both returned, the results must compare equal, otherwise a
      SpecException is raised;
    * if both raised, the optimized version's exception is re-raised.
    """
    def decorator(f):
        @functools.wraps(f)
        def wrapper(self,*args,**kwargs):
            def pr(x):
                if isinstance(x,bytearray): return binascii.hexlify(x)
                return str(x)

            try: spec_ans = getattr(self,spec,spec)(*args,**kwargs),None
            except Exception as e: spec_ans = None,e
            try: opt_ans = f(self,*args,**kwargs),None
            except Exception as e: opt_ans = None,e

            spec_ret,spec_exn = spec_ans
            opt_ret,opt_exn = opt_ans

            # Re-raise the stored exception explicitly: a bare `raise` here
            # (outside any `except` block) only works via Python 2's leaked
            # exception state and fails on Python 3.
            if spec_exn is None and opt_exn is not None:
                raise opt_exn
            if spec_exn is not None and opt_exn is None:
                raise spec_exn
            if spec_ret != opt_ret:
                raise SpecException("Mismatch in %s: %s != %s"
                    % (f.__name__,pr(spec_ret),pr(opt_ret)))
            if opt_exn is not None: raise opt_exn
            return opt_ret
        return wrapper
    return decorator

def xsqrt(x,exn=None):
    """Return the square root s of x with negative(s) false.

    Raises `exn` if x is not a square; by default a fresh
    InvalidEncodingException("Not on curve") is raised.
    """
    if not is_square(x):
        # Build the default lazily so one exception object is not shared
        # (and re-raised) across calls.
        raise exn if exn is not None else InvalidEncodingException("Not on curve")
    s = sqrt(x)
    if negative(s): s=-s
    return s

def isqrt(x,exn=None):
    """Return 1/sqrt(x), or 0 when x == 0.

    Raises `exn` if x is not a square; by default a fresh
    InvalidEncodingException("Not on curve") is raised. The sign of the
    root is whatever Sage's sqrt() returns (no normalization).
    """
    if x==0: return 0
    if not is_square(x):
        raise exn if exn is not None else InvalidEncodingException("Not on curve")
    s = sqrt(x)
    return 1/s

def inv0(x): return 1/x if x != 0 else 0

def isqrt_i(x):
    """Return (True, 1/sqrt(x)) if x is square, else (False, 1/sqrt(gen*x)).

    gen is the first non-square reached by repeatedly taking square roots
    of -1 in x's parent field. For x == 0 the result is (True, 0).
    """
    if x==0: return True,0
    gen = x.parent(-1)
    while is_square(gen): gen = sqrt(gen)
    if is_square(x): return True,1/sqrt(x)
    else: return False,1/sqrt(x*gen)

class QuotientEdwardsPoint(object):
    """Abstract class for a point on a quotiented Edwards curve.

    Subclasses must define F, a, d, cofactor and encLen (plus the
    encoding-specific constants their methods use).
    """
    def __init__(self,x=0,y=1):
        x = self.x = self.F(x)
        y = self.y = self.F(y)
        if y^2 + self.a*x^2 != 1 + self.d*x^2*y^2:
            raise NotOnCurveException(str(self))

    def __repr__(self):
        return "%s(0x%x,0x%x)" % (self.__class__.__name__, self.x, self.y)

    def __iter__(self):
        yield self.x
        yield self.y

    def __add__(self,other):
        x,y = self
        X,Y = other
        a,d = self.a,self.d
        return self.__class__(
            (x*Y+y*X)/(1+d*x*y*X*Y),
            (y*Y-a*x*X)/(1-d*x*y*X*Y)
        )

    def __neg__(self): return self.__class__(-self.x,self.y)
    def __sub__(self,other): return self + (-other)
    def __rmul__(self,other): return self*other

    def __eq__(self,other):
        """Equality in the quotient group.

        NB: this is the only method that differs from the usual one.
        (x,y) is identified with (-x,-y); for cofactor 8 it is additionally
        identified with its torque().
        """
        x,y = self
        X,Y = other
        return x*Y == X*y or (self.cofactor==8 and -self.a*x*X == y*Y)

    def __ne__(self,other): return not (self==other)

    def __mul__(self,exp):
        """Scalar multiplication by int(exp), via double-and-add."""
        exp = int(exp)
        if exp < 0: exp,self = -exp,-self
        total = self.__class__()
        work = self
        while exp != 0:
            if exp & 1: total += work
            work += work
            exp >>= 1
        return total

    def xyzt(self):
        """Return extended coordinates (X,Y,Z,T) with a random nonzero Z."""
        x,y = self
        z = self.F.random_element()
        # Z = 0 would give invalid projective coordinates.
        while z == 0: z = self.F.random_element()
        return x*z,y*z,z,x*y*z

    def torque(self):
        """Apply cofactor group, except keeping the point even."""
        if self.cofactor == 8:
            if self.a == -1: return self.__class__(self.y*self.i, self.x*self.i)
            if self.a == 1: return self.__class__(-self.y, self.x)
            raise NotImplementedError("torque: unsupported a for cofactor 8")
        return self.__class__(-self.x, -self.y)

    def doubleAndEncodeSpec(self):
        return (self+self).encode()

    # Utility functions
    @classmethod
    def bytesToGf(cls,bytes,mustBeProper=True,mustBePositive=False,maskHiBits=False):
        """Convert little-endian bytes to a field element.

        Raises InvalidEncodingException if len(bytes) != cls.encLen, if
        mustBeProper and the integer is >= the field order, or if
        mustBePositive and the resulting element is negative. If maskHiBits,
        the integer is first truncated to ceil(log2(order)) bits (after the
        mustBeProper check) and then reduced into the field.
        """
        if len(bytes) != cls.encLen:
            raise InvalidEncodingException("wrong length %d" % len(bytes))
        s = dec_le(bytes)
        if mustBeProper and s >= cls.F.order():
            raise InvalidEncodingException("%d out of range!" % s)
        if maskHiBits:
            bitlen = int(ceil(log(cls.F.order())/log(2)))
            s &= 2^bitlen-1
        s = cls.F(s)
        if mustBePositive and negative(s):
            raise InvalidEncodingException("%d is negative!" % s)
        return s

    @classmethod
    def gfToBytes(cls,x,mustBePositive=False):
        """Convert a field element to cls.encLen little-endian bytes.

        If mustBePositive and x is negative, -x is encoded instead.
        """
        if negative(x) and mustBePositive: x = -x
        return enc_le(x,cls.encLen)

class RistrettoPoint(QuotientEdwardsPoint):
    """The new Ristretto group"""
    def encodeSpec(self):
        """Unoptimized specification for encoding"""
        x,y = self
        if self.cofactor==8 and (negative(x*y) or y==0): (x,y) = self.torque()
        if y == -1: y = 1 # Avoid divide by 0; doesn't affect impl

        if negative(x): x,y = -x,-y
        s = xsqrt(self.mneg*(1-y)/(1+y),exn=Exception("Unimplemented: point is odd: " + str(self)))
        return self.gfToBytes(s)

    @classmethod
    def decodeSpec(cls,s):
        """Unoptimized specification for decoding"""
        s = cls.bytesToGf(s,mustBePositive=True)

        a,d = cls.a,cls.d
        x = xsqrt(4*s^2 / (a*d*(1+a*s^2)^2 - (1-a*s^2)^2))
        y = (1+a*s^2) / (1-a*s^2)

        if cls.cofactor==8 and (negative(x*y) or y==0):
            raise InvalidEncodingException("x*y has high bit")

        return cls(x,y)

    @optimized_version_of("encodeSpec")
    def encode(self):
        """Encode, optimized version"""
        a,d,mneg = self.a,self.d,self.mneg
        x,y,z,t = self.xyzt()

        if self.cofactor==8:
            u1 = mneg*(z+y)*(z-y)
            u2 = x*y # = t*z
            isr = isqrt(u1*u2^2)
            i1 = isr*u1 # sqrt(mneg*(z+y)*(z-y))/(x*y)
            i2 = isr*u2 # 1/sqrt(a*(y+z)*(y-z))
            z_inv = i1*i2*t # 1/z

            if negative(t*z_inv):
                if a==-1:
                    x,y = y*self.i,x*self.i
                    den_inv = self.magic * i1
                else:
                    x,y = -y,x
                    den_inv = self.i * self.magic * i1
            else:
                den_inv = i2

            if negative(x*z_inv): y = -y
            s = (z-y) * den_inv
        else:
            num = mneg*(z+y)*(z-y)
            isr = isqrt(num*y^2)
            if negative(isr^2*num*y*t): y = -y
            s = isr*y*(z-y)

        return self.gfToBytes(s,mustBePositive=True)

    @optimized_version_of("doubleAndEncodeSpec")
    def doubleAndEncode(self):
        """Encode 2*self, optimized version"""
        X,Y,Z,T = self.xyzt()
        a,d,mneg = self.a,self.d,self.mneg

        if self.cofactor==8:
            e = 2*X*Y
            f = Z^2+d*T^2
            g = Y^2-a*X^2
            h = Z^2-d*T^2

            inv1 = 1/(e*f*g*h)
            z_inv = inv1*e*g # 1 / (f*h)
            t_inv = inv1*f*h

            if negative(e*g*z_inv):
                if a==-1: sqrta = self.i
                else: sqrta = -1
                e,f,g,h = g,h,-e,f*sqrta
                factor = self.i
            else:
                factor = self.magic

            if negative(h*e*z_inv): g=-g
            s = (h-g)*factor*g*t_inv

        else:
            foo = Y^2+a*X^2
            bar = X*Y
            den = 1/(foo*bar)
            if negative(2*bar^2*den): tmp = a*X^2
            else: tmp = Y^2
            s = self.magic*(Z^2-tmp)*foo*den

        return self.gfToBytes(s,mustBePositive=True)

    @classmethod
    @optimized_version_of("decodeSpec")
    def decode(cls,s):
        """Decode, optimized version"""
        s = cls.bytesToGf(s,mustBePositive=True)

        a,d = cls.a,cls.d
        yden = 1-a*s^2
        ynum = 1+a*s^2
        yden_sqr = yden^2
        xden_sqr = a*d*ynum^2 - yden_sqr

        isr = isqrt(xden_sqr * yden_sqr)

        xden_inv = isr * yden
        yden_inv = xden_inv * isr * xden_sqr

        x = 2*s*xden_inv
        if negative(x): x = -x
        y = ynum * yden_inv

        if cls.cofactor==8 and (negative(x*y) or y==0):
            raise InvalidEncodingException("x*y is invalid: %d, %d" % (x,y))

        return cls(x,y)

    @classmethod
    def fromJacobiQuartic(cls,s,t,sgn=1):
        """Convert point from its Jacobi Quartic representation"""
        a,d = cls.a,cls.d
        assert s^4 - 2*cls.a*(1-2*d/(d-a))*s^2 + 1 == t^2
        x = 2*s*cls.magic / t
        y = (1+a*s^2) / (1-a*s^2)
        return cls(sgn*x,y)

    @classmethod
    def elligatorSpec(cls,r0):
        """Unoptimized specification for Elligator on the bytes r0"""
        a,d = cls.a,cls.d
        r = cls.qnr * cls.bytesToGf(r0,mustBeProper=False,maskHiBits=True)^2
        den = (d*r-a)*(a*r-d)
        if den == 0: return cls()
        n1 = cls.a*(r+1)*(a+d)*(d-a)/den
        n2 = r*n1
        if is_square(n1):
            sgn,s,t = 1, xsqrt(n1), -(r-1)*(a+d)^2 / den - 1
        else:
            sgn,s,t = -1,-xsqrt(n2), r*(r-1)*(a+d)^2 / den - 1

        return cls.fromJacobiQuartic(s,t)

    @classmethod
    @optimized_version_of("elligatorSpec")
    def elligator(cls,r0):
        """Elligator on the bytes r0, optimized version"""
        a,d = cls.a,cls.d
        r0 = cls.bytesToGf(r0,mustBeProper=False,maskHiBits=True)
        r = cls.qnr * r0^2
        den = (d*r-a)*(a*r-d)
        num = cls.a*(r+1)*(a+d)*(d-a)

        iss,isri = isqrt_i(num*den)
        if iss: sgn,twiddle = 1,1
        else: sgn,twiddle = -1,r0*cls.qnr
        isri *= twiddle
        s = isri*num
        t = -sgn*isri*s*(r-1)*(d+a)^2 - 1
        if negative(s) == iss: s = -s
        return cls.fromJacobiQuartic(s,t)


class Decaf_1_1_Point(QuotientEdwardsPoint):
    """Like current decaf but tweaked for simplicity"""
    def encodeSpec(self):
        """Unoptimized specification for encoding"""
        a,d = self.a,self.d
        x,y = self
        if x==0 or y==0: return(self.gfToBytes(0))

        if self.cofactor==8 and negative(x*y*self.isoMagic):
            x,y = self.torque()

        sr = xsqrt(1-a*x^2)
        altx = x*y*self.isoMagic / sr
        if negative(altx): s = (1+sr)/x
        else: s = (1-sr)/x

        return self.gfToBytes(s,mustBePositive=True)

    @classmethod
    def decodeSpec(cls,s):
        """Unoptimized specification for decoding"""
        a,d = cls.a,cls.d
        s = cls.bytesToGf(s,mustBePositive=True)

        if s==0: return cls()
        t = xsqrt(s^4 + 2*(a-2*d)*s^2 + 1)
        altx = 2*s*cls.isoMagic/t
        if negative(altx): t = -t
        x = 2*s / (1+a*s^2)
        y = (1-a*s^2) / t

        if cls.cofactor==8 and (negative(x*y*cls.isoMagic) or y==0):
            raise InvalidEncodingException("x*y is invalid: %d, %d" % (x,y))

        return cls(x,y)

    def toJacobiQuartic(self,toggle_rotation=False,toggle_altx=False,toggle_s=False):
        """Return (s, m1, m12, swap) describing self on the Jacobi quartic.

        s is made nonnegative (negative(s) false) and is what encode() uses.
        m1, m12 (= a*tmp) and swap are consumed by invertElligator. The
        toggle_* flags flip the corresponding sign/rotation choices.
        """
        a,d = self.a,self.d
        x,y,z,t = self.xyzt()

        if self.cofactor == 8:
            # Cofactor 8 version
            # Simulate IMAGINE_TWIST because that's how libdecaf does it
            x = self.i*x
            t = self.i*t
            a = -a
            d = -d

            # OK, the actual libdecaf code should be here
            num = (z+y)*(z-y)
            den = x*y
            isr = isqrt(num*(a-d)*den^2)

            iden = isr * den * self.isoMagic # 1/sqrt((z+y)(z-y)) = 1/sqrt(1-Y^2) / z
            inum = isr * num # sqrt(1-Y^2) * z / xysqrt(a-d) ~ 1/sqrt(1-ax^2)/z

            if negative(iden*inum*self.i*t^2*(d-a)) != toggle_rotation:
                iden,inum = inum,iden
                fac = x*sqrt(a)
                toggle=(a==-1)
            else:
                fac = y
                toggle=False

            imi = self.isoMagic * self.i
            altx = inum*t*imi
            neg_altx = negative(altx) != toggle_altx
            if neg_altx != toggle: inum = -inum

            tmp = fac*(inum*z + 1)
            s = iden*tmp*imi

            negm1 = (negative(s) != toggle_s) != neg_altx
            if negm1: m1 = a*fac + z
            else: m1 = a*fac - z

            swap = toggle_s

        else:
            # Much simpler cofactor 4 version
            num = (x+t)*(x-t)
            isr = isqrt(num*(a-d)*x^2)
            ratio = isr*num
            altx = ratio*self.isoMagic

            neg_altx = negative(altx) != toggle_altx
            if neg_altx: ratio = -ratio

            tmp = ratio*z - t
            s = (a-d)*isr*x*tmp

            negx = (negative(s) != toggle_s) != neg_altx
            if negx: m1 = -a*t + x
            else: m1 = -a*t - x

            swap = toggle_s

        if negative(s): s = -s

        return s,m1,a*tmp,swap

    def invertElligator(self,toggle_r=False,*args,**kwargs):
        """Return a list (possibly empty) of candidate Elligator preimages.

        Each entry is a `bytes` string. Every candidate is fed back through
        elligator(): one mapping to neither self nor -self is reported with
        a "WRONG!" message, one mapping to -self (when -self != self) with
        "Negated!"; both kinds are still returned.

        NOTE: the `toggle_r` argument and any extra arguments are accepted
        for backward compatibility but ignored; both values of toggle_r are
        always tried.
        """
        a,d = self.a,self.d
        is_identity = (self == self.__class__())
        imi = self.isoMagic * self.i if (is_identity and self.cofactor != 4) else None

        rets = []

        tr = [False,True] if self.cofactor == 8 else [False]
        for toggle_rotation in tr:
            for toggle_altx in [False,True]:
                for toggle_s in [False,True]:
                    for tog_r in [False,True]:
                        s,m1,m12,swap = self.toJacobiQuartic(toggle_rotation,toggle_altx,toggle_s)

                        if is_identity:
                            if self.cofactor == 4:
                                # Hacks for identity!
                                if toggle_altx: m12 = 1
                                elif toggle_s: m1 = 1
                                elif tog_r: continue
                                ## BOTH???
                            else:
                                m12 = 1
                                if toggle_rotation:
                                    if toggle_altx: m1 = -imi
                                    else: m1 = +imi
                                else:
                                    if toggle_altx: m1 = 0
                                    else: m1 = a-d

                        rnum = (d*a*m12-m1)
                        rden = ((d*a-1)*m12+m1)
                        if swap: rnum,rden = rden,rnum

                        ok,sr = isqrt_i(rnum*rden*self.qnr)
                        if not ok: continue
                        sr *= rnum

                        if negative(sr) != tog_r: sr = -sr
                        ret = self.gfToBytes(sr)
                        # elligator() is deterministic: evaluate it once.
                        img = self.elligator(ret)
                        if img != self and img != -self:
                            print "WRONG!",[toggle_rotation,toggle_altx,toggle_s]
                        if img == -self and self != -self:
                            print "Negated!",[toggle_rotation,toggle_altx,toggle_s]
                        rets.append(bytes(ret))
        return rets

    @optimized_version_of("encodeSpec")
    def encode(self):
        """Encode, optimized version"""
        return self.gfToBytes(self.toJacobiQuartic()[0])

    @classmethod
    @optimized_version_of("decodeSpec")
    def decode(cls,s):
        """Decode, optimized version"""
        a,d = cls.a,cls.d
        s = cls.bytesToGf(s,mustBePositive=True)

        s2 = s^2
        den = 1+a*s2
        num = den^2 - 4*d*s2
        isr = isqrt(num*den^2)
        altx = 2*s*isr*den*cls.isoMagic
        if negative(altx): isr = -isr
        x = 2*s *isr^2*den*num
        y = (1-a*s^2) * isr*den

        if cls.cofactor==8 and (negative(x*y*cls.isoMagic) or y==0):
            raise InvalidEncodingException("x*y is invalid: %d, %d" % (x,y))

        return cls(x,y)

    @classmethod
    def fromJacobiQuartic(cls,s,t,sgn=1):
        """Convert point from its Jacobi Quartic representation"""
        a,d = cls.a,cls.d
        if s==0: return cls()
        x = 2*s / (1+a*s^2)
        y = (1-a*s^2) / t
        return cls(x,sgn*y)

    @optimized_version_of("doubleAndEncodeSpec")
    def doubleAndEncode(self):
        """Encode 2*self, optimized version"""
        X,Y,Z,T = self.xyzt()
        a,d = self.a,self.d

        if self.cofactor == 8:
            # Cofactor 8 version
            # Simulate IMAGINE_TWIST because that's how libdecaf does it
            X = self.i*X
            T = self.i*T
            a = -a
            d = -d
            # TODO: This is only being called for a=-1, so could
            # be wrong for a=1

            e = 2*X*Y
            f = Y^2+a*X^2
            g = Y^2-a*X^2
            h = Z^2-d*T^2

            eim = e*self.isoMagic
            inv = 1/(eim*g*f*h)
            fh_inv = eim*g*inv*self.i

            if negative(eim*g*fh_inv):
                idf = g*self.isoMagic*self.i
                bar = f
                foo = g
                sign_chk = eim*f
            else:
                idf = eim
                bar = h
                foo = -eim
                sign_chk = g*h

            if negative(sign_chk*fh_inv): bar = -bar
            s = idf*(foo+bar)*inv*f*h

        else:
            xy = X*Y
            h = Z^2-d*T^2
            inv = 1/(xy*h)
            if negative(inv*2*xy^2*self.isoMagic): tmp = Y
            else: tmp = X
            s = tmp^2*h*inv # = X/Y or Y/X, interestingly

        return self.gfToBytes(s,mustBePositive=True)

    @classmethod
    def elligatorSpec(cls,r0,fromR=False):
        """Unoptimized specification for Elligator (r0 is used as r if fromR)"""
        a,d = cls.a,cls.d
        if fromR: r = r0
        else: r = cls.qnr * cls.bytesToGf(r0,mustBeProper=False,maskHiBits=True)^2

        den = (d*r-(d-a))*((d-a)*r-d)
        if den == 0: return cls()
        n1 = (r+1)*(a-2*d)/den
        n2 = r*n1
        if is_square(n1):
            sgn,s,t = 1, xsqrt(n1), -(r-1)*(a-2*d)^2 / den - 1
        else:
            sgn,s,t = -1, -xsqrt(n2), r*(r-1)*(a-2*d)^2 / den - 1

        return cls.fromJacobiQuartic(s,t)

    @classmethod
    @optimized_version_of("elligatorSpec")
    def elligator(cls,r0):
        """Elligator on the bytes r0, optimized version"""
        a,d = cls.a,cls.d
        r0 = cls.bytesToGf(r0,mustBeProper=False,maskHiBits=True)
        r = cls.qnr * r0^2
        den = (d*r-(d-a))*((d-a)*r-d)
        num = (r+1)*(a-2*d)

        iss,isri = isqrt_i(num*den)
        if iss: sgn,twiddle = 1,1
        else: sgn,twiddle = -1,r0*cls.qnr
        isri *= twiddle
        s = isri*num
        t = -sgn*isri*s*(r-1)*(a-2*d)^2 - 1
        if negative(s) == iss: s = -s
        return cls.fromJacobiQuartic(s,t)

    def elligatorInverseBruteForce(self):
        """Invert Elligator using SAGE's polynomial solver.

        Returns the encodings r with self.elligator(r) == self.
        """
        a,d = self.a,self.d
        R.<r0> = self.F[]
        r = self.qnr * r0^2
        den = (d*r-(d-a))*((d-a)*r-d)
        n1 = (r+1)*(a-2*d)/den
        n2 = r*n1
        roots = set()
        for s2,t in [(n1, -(r-1)*(a-2*d)^2 / den - 1),
                     (n2, r*(r-1)*(a-2*d)^2 / den - 1)]:
            x2 = 4*s2/(1+a*s2)^2
            y = (1-a*s2) / t

            selfT = self
            # Integer division: cofactor/2 would be a Sage Rational.
            for _ in xrange(self.cofactor//2):
                xT,yT = selfT
                polyX = xT^2-x2
                polyY = yT-y
                sx = set(root for root,_mult in polyX.numerator().roots())
                sy = set(root for root,_mult in polyY.numerator().roots())
                roots = roots.union(sx.intersection(sy))

                selfT = selfT.torque()

        ret = [self.gfToBytes(root) for root in roots]

        for enc in ret:
            assert self.elligator(enc) in [self,-self]

        return [enc for enc in ret if self.elligator(enc) == self]

class Ed25519Point(RistrettoPoint):
    F = GF(2^255-19)
    d = F(-121665/121666)
    a = F(-1)
    i = sqrt(F(-1))
    mneg = F(1)
    qnr = i
    magic = isqrt(a*d-1)
    cofactor = 8
    encLen = 32

    @classmethod
    def base(cls):
        return cls(
            15112221349535400772501151409588531511454012693041857206046113283949847762202,
            46316835694926478169428394003475163141307993866256225615783033603165251855960
        )

class NegEd25519Point(RistrettoPoint):
    F = GF(2^255-19)
    d = F(121665/121666)
    a = F(1)
    i = sqrt(F(-1))
    mneg = F(-1) # TODO checkme vs 1-ad or whatever
    qnr = i
    magic = isqrt(a*d-1)
    cofactor = 8
    encLen = 32

    @classmethod
    def base(cls):
        y = cls.F(4/5)
        x = sqrt((y^2-1)/(cls.d*y^2-cls.a))
        if negative(x): x = -x
        return cls(x,y)

class IsoEd448Point(RistrettoPoint):
    F = GF(2^448-2^224-1)
    d = F(39082/39081)
    a = F(1)
    mneg = F(-1)
    qnr = -1
    magic = isqrt(a*d-1)
    cofactor = 4
    encLen = 56

    @classmethod
    def base(cls):
        return cls( # RFC has it wrong
            345397493039729516374008604150537410266655260075183290216406970281645695073672344430481787759340633221708391583424041788924124567700732,
            -363419362147803445274661903944002267176820680343659030140745099590306164083365386343198191849338272965044442230921818680526749009182718
        )

class TwistedEd448GoldilocksPoint(Decaf_1_1_Point):
    F = GF(2^448-2^224-1)
    d = F(-39082)
    a = F(-1)
    qnr = -1
    cofactor = 4
    encLen = 56
    isoMagic = IsoEd448Point.magic

    @classmethod
    def base(cls):
        return cls.decodeSpec(Ed448GoldilocksPoint.base().encodeSpec())

class Ed448GoldilocksPoint(Decaf_1_1_Point):
    F = GF(2^448-2^224-1)
    d = F(-39081)
    a = F(1)
    qnr = -1
    cofactor = 4
    encLen = 56
    isoMagic = IsoEd448Point.magic

    @classmethod
    def base(cls):
        return 2*cls(
            224580040295924300187604334099896036246789641632564134246125461686950415467406032909029192869357953282578032075146446173674602635247710,
            298819210078481492676017930443930673437544040154080242095928241372331506189835876003536878655418784733982303233503462500531545062832660
        )

class IsoEd25519Point(Decaf_1_1_Point):
    # TODO: twisted iso too!
    # TODO: twisted iso might have to IMAGINE_TWIST or whatever
    F = GF(2^255-19)
    d = F(-121665)
    a = F(1)
    i = sqrt(F(-1))
    qnr = i
    magic = isqrt(a*d-1)
    cofactor = 8
    encLen = 32
    isoMagic = Ed25519Point.magic
    isoA = Ed25519Point.a

    @classmethod
    def base(cls):
        return cls.decodeSpec(Ed25519Point.base().encode())

class TestFailedException(Exception): pass

def _special_values(F):
    """Return [1, -1, sqrt(-1), ...]: 1, then iterated square roots of -1
    in F, ending with the first non-square reached."""
    specials = [1]
    ii = F(-1)
    while is_square(ii):
        specials.append(ii)
        ii = sqrt(ii)
    specials.append(ii)
    return specials

def test(cls,n):
    """Round-trip, torque, arithmetic and s -> 1/s checks on n points of cls."""
    print "Testing curve %s" % cls.__name__

    for v in _special_values(cls.F):
        # Convert to a field element *before* negating: negating the plain
        # integer 1 gave -1, which encodes as all-0xFF bytes and was always
        # rejected as out of range, so that special was never tested.
        v = cls.F(v)
        if negative(v): v = -v
        enc = enc_le(v,cls.encLen)
        try:
            Q = cls.decode(enc)
            QE = Q.encode()
            if QE != enc:
                raise TestFailedException("Round trip special %s != %s" %
                     (binascii.hexlify(QE),binascii.hexlify(enc)))
        except NotOnCurveException: pass
        except InvalidEncodingException: pass

    P = cls.base()
    Q = cls()
    for i in xrange(n):
        QE = Q.encode()
        QQ = cls.decode(QE)
        if QQ != Q: raise TestFailedException("Round trip %s != %s" % (str(QQ),str(Q)))

        # Testing s -> 1/s: encodes -point on cofactor
        s = cls.bytesToGf(QE)
        if s != 0:
            ss = cls.gfToBytes(1/s,mustBePositive=True)
            try:
                QN = cls.decode(ss)
                if cls.cofactor == 8:
                    raise TestFailedException("1/s shouldnt work for cofactor 8")
                if QN != -Q:
                    raise TestFailedException("s -> 1/s should negate point for cofactor 4")
            except InvalidEncodingException:
                # Should be raised iff cofactor==8
                if cls.cofactor == 4:
                    raise TestFailedException("s -> 1/s should work for cofactor 4")

        QT = Q
        for h in xrange(cls.cofactor):
            QT = QT.torque()
            if QT.encode() != QE:
                raise TestFailedException("Can't torque %s,%d" % (str(Q),h+1))

        Q0 = Q + P
        if Q0 == Q: raise TestFailedException("Addition doesn't work")
        if Q0-P != Q: raise TestFailedException("Subtraction doesn't work")

        r = randint(1,1000)
        Q1 = Q0*r
        Q2 = Q0*(r+1)
        if Q1 + Q0 != Q2: raise TestFailedException("Scalarmul doesn't work")
        Q = Q1

def testElligator(cls,n):
    """Check Elligator inversion on n random inputs (prints problems)."""
    print "Testing elligator on %s" % cls.__name__
    for i in xrange(n):
        r = randombytes(cls.encLen)
        P = cls.elligator(r)
        if hasattr(P,"invertElligator"):
            iv = P.invertElligator()
            modr = bytes(cls.gfToBytes(cls.bytesToGf(r,mustBeProper=False,maskHiBits=True)))
            iv2 = P.torque().invertElligator()
            if modr not in iv: print "Failed to invert Elligator!"
            if len(iv) != len(set(iv)):
                print "Elligator inverses not unique!", len(set(iv)), len(iv)
            if iv != iv2:
                print "Elligator is untorqueable!"
        else:
            pass # TODO

def gangtest(classes,n):
    """Check that several classes agree on encode() and elligator()."""
    print "Gang test",[cls.__name__ for cls in classes]
    specials = _special_values(classes[0].F)

    for i in xrange(n):
        rets = [bytes((cls.base()*i).encode()) for cls in classes]
        if len(set(rets)) != 1:
            print "Divergence in encode at %d" % i
            for c,ret in zip(classes,rets):
                print c,binascii.hexlify(ret)
            print

        if i < len(specials): r0 = enc_le(specials[i],classes[0].encLen)
        else: r0 = randombytes(classes[0].encLen)

        rets = [bytes((cls.elligator(r0)*i).encode()) for cls in classes]
        if len(set(rets)) != 1:
            print "Divergence in elligator at %d" % i
            for c,ret in zip(classes,rets):
                print c,binascii.hexlify(ret)
            print

def testDoubleAndEncode(cls,n):
    """Cross-check doubleAndEncode against its spec on n random points."""
    print "Testing doubleAndEncode on %s" % cls.__name__
    for i in xrange(n):
        r1 = randombytes(cls.encLen)
        r2 = randombytes(cls.encLen)
        u = cls.elligator(r1) + cls.elligator(r2)
        u.doubleAndEncode()

# Run the checks only when executed as a script (`sage file.sage` or
# load()), not when the preparsed module is imported.
if __name__ == "__main__":
    testDoubleAndEncode(Ed25519Point,100)
    testDoubleAndEncode(NegEd25519Point,100)
    testDoubleAndEncode(IsoEd25519Point,100)
    testDoubleAndEncode(IsoEd448Point,100)
    testDoubleAndEncode(TwistedEd448GoldilocksPoint,100)
    #test(Ed25519Point,100)
    #test(NegEd25519Point,100)
    #test(IsoEd25519Point,100)
    #test(IsoEd448Point,100)
    #test(TwistedEd448GoldilocksPoint,100)
    #test(Ed448GoldilocksPoint,100)
    #testElligator(Ed25519Point,100)
    #testElligator(NegEd25519Point,100)
    #testElligator(IsoEd25519Point,100)
    #testElligator(IsoEd448Point,100)
    #testElligator(Ed448GoldilocksPoint,100)
    #testElligator(TwistedEd448GoldilocksPoint,100)
    #gangtest([IsoEd448Point,TwistedEd448GoldilocksPoint,Ed448GoldilocksPoint],100)
    #gangtest([Ed25519Point,IsoEd25519Point],100)