1import binascii
class InvalidEncodingException(Exception):
    """Raised when a byte string is not an acceptable encoding (bad length, out of range, negative, not on curve)."""

class NotOnCurveException(Exception):
    """Raised when point coordinates do not satisfy the curve equation."""

class SpecException(Exception):
    """Raised when an optimized routine's result disagrees with its specification."""
5
def lobit(x):
    """Return the least-significant bit (0 or 1) of int(x)."""
    value = int(x)
    return value % 2
def hibit(x):
    """Return the low bit of 2*x.

    For an element of an odd prime field this presumably serves as a
    'high half' test (2*x wraps past the modulus exactly when x lies in
    the upper half).  NOTE(review): not called anywhere in the visible
    source.
    """
    doubled = 2*x
    return int(doubled) & 1
def negative(x):
    """Sign convention used throughout: x counts as negative iff int(x) is odd."""
    return int(x) & 1
def enc_le(x,n):
    """Encode int(x) as exactly n little-endian bytes; bits beyond 8*n are dropped."""
    value = int(x)
    out = bytearray(n)
    for idx in range(n):
        out[idx] = (value >> (8*idx)) & 0xFF
    return out
def dec_le(x):
    """Decode a little-endian sequence of byte values into an integer."""
    value = 0
    # Horner's rule from the most-significant byte down; equal to
    # sum(b * 256**i) term by term.
    for byte in reversed(list(x)):
        value = (value << 8) + byte
    return value
def randombytes(n):
    """Return a bytearray of n random bytes drawn with randint(0,255).

    Intended for tests; randint is presumably the (non-cryptographic)
    Sage/Python PRNG.
    """
    out = bytearray()
    for _ in range(n):
        out.append(randint(0,255))
    return out
12
def optimized_version_of(spec):
    """Decorator: This function is an optimized version of some specification.

    On every call both the specification and the optimized body run.  The
    specification is looked up by name on the instance/class via
    getattr(self, spec, spec), falling back to ``spec`` itself.  Then:

    * exactly one of them raised   -> that exception is re-raised;
    * both returned, values differ -> SpecException is raised;
    * both raised                  -> the optimized version's exception
                                      is re-raised;
    * otherwise                    -> the optimized result is returned.
    """
    from functools import wraps

    def pr(x):
        # Show bytearrays as hex so mismatch messages are readable.
        if isinstance(x,bytearray): return binascii.hexlify(x)
        return str(x)

    def decorator(f):
        @wraps(f)
        def wrapper(self,*args,**kwargs):
            try: spec_ans = getattr(self,spec,spec)(*args,**kwargs),None
            except Exception as e: spec_ans = None,e
            try: opt_ans = f(self,*args,**kwargs),None
            except Exception as e: opt_ans = None,e

            spec_val,spec_err = spec_ans
            opt_val,opt_err = opt_ans

            # Re-raise the stored exception explicitly.  A bare `raise` here
            # would be outside any except clause: it only worked through
            # Python 2's lingering exception state, and it is a RuntimeError
            # on Python 3.  The underlying exception is raised rather than a
            # SpecException so its original message and type are preserved.
            if spec_err is None and opt_err is not None:
                raise opt_err
            if spec_err is not None and opt_err is None:
                raise spec_err
            if spec_val != opt_val:
                raise SpecException("Mismatch in %s: %s != %s"
                    % (f.__name__,pr(spec_val),pr(opt_val)))
            if opt_err is not None:
                raise opt_err
            return opt_val
        return wrapper
    return decorator
40
def xsqrt(x,exn=InvalidEncodingException("Not on curve")):
    """Return the non-negative square root of x.

    "Non-negative" follows this file's convention (negative() is the low
    bit).  Raises exn when x is not a square.
    """
    if is_square(x):
        root = sqrt(x)
        return -root if negative(root) else root
    raise exn
47
def isqrt(x,exn=InvalidEncodingException("Not on curve")):
    """Return 1/sqrt(x), defining the result for x == 0 as 0.

    The sign of the root is whatever sqrt() returns; it is deliberately
    not normalized.  Raises exn when x is a non-zero non-square.
    """
    if x == 0:
        return 0
    if not is_square(x):
        raise exn
    return 1/sqrt(x)
55
def inv0(x):
    """Return 1/x, or 0 when x == 0 (inverse-or-zero)."""
    if x == 0:
        return 0
    return 1/x
57
def isqrt_i(x):
    """Return (True, 1/sqrt(x)) if x is a square, else (False, 1/sqrt(zeta*x)).

    zeta is found by starting at -1 in x's field and repeatedly taking
    square roots until a non-square is reached.  x == 0 yields (True, 0).
    """
    if x == 0:
        return True, 0
    zeta = x.parent(-1)
    while is_square(zeta):
        zeta = sqrt(zeta)
    if not is_square(x):
        return False, 1/sqrt(x*zeta)
    return True, 1/sqrt(x)
65
class QuotientEdwardsPoint(object):
    """Abstract class for a point on a quotiented Edwards curve.

    Subclasses must define these class attributes:
      F        -- the base field (called to coerce coordinates)
      a, d     -- coefficients of the curve y^2 + a*x^2 == 1 + d*x^2*y^2
      cofactor -- 4 or 8; selects the quotient equality and torque
      encLen   -- length in bytes of a field-element encoding
      i        -- a square root of -1 (needed by torque when cofactor == 8
                  and a == -1)
    """
    def __init__(self,x=0,y=1):
        x = self.x = self.F(x)
        y = self.y = self.F(y)
        if y^2 + self.a*x^2 != 1 + self.d*x^2*y^2:
            raise NotOnCurveException(str(self))

    def __repr__(self):
        return "%s(0x%x,0x%x)" % (self.__class__.__name__, self.x, self.y)

    def __iter__(self):
        # Enables "x,y = point" unpacking.
        yield self.x
        yield self.y

    def __add__(self,other):
        """Affine Edwards addition law for y^2 + a*x^2 == 1 + d*x^2*y^2."""
        x,y = self
        X,Y = other
        a,d = self.a,self.d
        return self.__class__(
            (x*Y+y*X)/(1+d*x*y*X*Y),
            (y*Y-a*x*X)/(1-d*x*y*X*Y)
        )

    def __neg__(self): return self.__class__(-self.x,self.y)
    def __sub__(self,other): return self + (-other)
    def __rmul__(self,other): return self*other
    def __eq__(self,other):
        """Equality in the quotient group.

        NB: this is the only method that is different from the usual one.
        Points are equal when x*Y == X*y, or, on cofactor-8 curves, also
        when -a*x*X == y*Y.
        """
        x,y = self
        X,Y = other
        return x*Y == X*y or (self.cofactor==8 and -self.a*x*X == y*Y)
    def __ne__(self,other): return not (self==other)

    def __mul__(self,exp):
        """Scalar multiplication by int(exp) (right-to-left double-and-add)."""
        exp = int(exp)
        if exp < 0: exp,self = -exp,-self
        total = self.__class__()
        work  = self
        while exp != 0:
            if exp & 1: total += work
            work += work
            exp >>= 1
        return total

    def xyzt(self):
        """Return randomized extended coordinates (x*z, y*z, z, x*y*z).

        z is a random non-zero element of F, presumably so that optimized
        routines are exercised on non-normalized (Z != 1) inputs.
        """
        x,y = self
        z = self.F.random_element()
        while z == 0:
            # random_element() may return 0, which would zero out every
            # coordinate and make the projective point meaningless.
            z = self.F.random_element()
        return x*z,y*z,z,x*y*z

    def torque(self):
        """Apply cofactor group, except keeping the point even"""
        if self.cofactor == 8:
            if self.a == -1: return self.__class__(self.y*self.i, self.x*self.i)
            if self.a ==  1: return self.__class__(-self.y, self.x)
            # Previously fell through and silently returned None.
            raise NotImplementedError(
                "torque: cofactor-8 curves must have a == 1 or a == -1")
        return self.__class__(-self.x, -self.y)

    def doubleAndEncodeSpec(self):
        """Specification for doubleAndEncode: encode self+self."""
        return (self+self).encode()

    # Utility functions
    @classmethod
    def bytesToGf(cls,bytes,mustBeProper=True,mustBePositive=False,maskHiBits=False):
        """Convert little-endian bytes to a field element, sanity-checking them.

        Raises InvalidEncodingException if len(bytes) != encLen, if
        mustBeProper and the integer value is >= |F|, or if mustBePositive
        and the resulting element is negative (odd).  With maskHiBits, bits
        at and above ceil(log2(|F|)) are cleared first (after the
        mustBeProper check); the value is then reduced into F.
        """
        if len(bytes) != cls.encLen:
            raise InvalidEncodingException("wrong length %d" % len(bytes))
        s = dec_le(bytes)
        if mustBeProper and s >= cls.F.order():
            raise InvalidEncodingException("%d out of range!" % s)
        if maskHiBits:
            bitlen = int(ceil(log(cls.F.order())/log(2)))
            s &= 2^bitlen-1
        s = cls.F(s)
        if mustBePositive and negative(s):
            raise InvalidEncodingException("%d is negative!" % s)
        return s

    @classmethod
    def gfToBytes(cls,x,mustBePositive=False):
        """Convert a field element to encLen little-endian bytes.

        If mustBePositive and x is negative (odd), -x is encoded instead.
        """
        if negative(x) and mustBePositive: x = -x
        return enc_le(x,cls.encLen)
148
class RistrettoPoint(QuotientEdwardsPoint):
    """The new Ristretto group

    Concrete subclasses supply F, a, d, cofactor, encLen, mneg, qnr and
    magic (plus i for cofactor-8 curves).  encode, decode, doubleAndEncode
    and elligator are wrapped with optimized_version_of, so each call is
    cross-checked against the corresponding *Spec method.
    """
    def encodeSpec(self):
        """Unoptimized specification for encoding"""
        x,y = self
        # Cofactor 8: replace the point by its torque when x*y is negative
        # or y == 0.
        if self.cofactor==8 and (negative(x*y) or y==0): (x,y) = self.torque()
        if y == -1: y = 1 # Avoid divide by 0; doesn't affect impl

        # Normalize so that x is non-negative (negating both coordinates).
        if negative(x): x,y = -x,-y
        # s = non-negative sqrt(mneg*(1-y)/(1+y)); a non-square is reported
        # as an "odd" point.
        s = xsqrt(self.mneg*(1-y)/(1+y),exn=Exception("Unimplemented: point is odd: " + str(self)))
        return self.gfToBytes(s)

    @classmethod
    def decodeSpec(cls,s):
        """Unoptimized specification for decoding"""
        # Rejects wrong-length, non-canonical (>= |F|) and negative s.
        s = cls.bytesToGf(s,mustBePositive=True)

        a,d = cls.a,cls.d
        # xsqrt raises InvalidEncodingException when the radicand is not square.
        x = xsqrt(4*s^2 / (a*d*(1+a*s^2)^2 - (1-a*s^2)^2))
        y = (1+a*s^2) / (1-a*s^2)

        # Mirror of encodeSpec's torque condition: such encodings are invalid.
        if cls.cofactor==8 and (negative(x*y) or y==0):
            raise InvalidEncodingException("x*y has high bit")

        return cls(x,y)

    @optimized_version_of("encodeSpec")
    def encode(self):
        """Encode, optimized version"""
        a,d,mneg = self.a,self.d,self.mneg
        # Randomized projective coordinates (Z != 1 in general).
        x,y,z,t = self.xyzt()

        if self.cofactor==8:
            u1    = mneg*(z+y)*(z-y)
            u2    = x*y # = t*z
            # One inverse square root yields both i1 and i2 below.
            isr   = isqrt(u1*u2^2)
            i1    = isr*u1 # sqrt(mneg*(z+y)*(z-y))/(x*y)
            i2    = isr*u2 # 1/sqrt(a*(y+z)*(y-z))
            z_inv = i1*i2*t # 1/z

            # t*z_inv is the affine x*y: same torque test as encodeSpec.
            if negative(t*z_inv):
                if a==-1:
                    x,y = y*self.i,x*self.i
                    den_inv = self.magic * i1
                else:
                    x,y = -y,x
                    den_inv = self.i * self.magic * i1

            else:
                den_inv = i2

            # Sign normalization corresponding to encodeSpec's "x non-negative".
            if negative(x*z_inv): y = -y
            s = (z-y) * den_inv
        else:
            num   = mneg*(z+y)*(z-y)
            isr   = isqrt(num*y^2)
            if negative(isr^2*num*y*t): y = -y
            s = isr*y*(z-y)

        return self.gfToBytes(s,mustBePositive=True)

    @optimized_version_of("doubleAndEncodeSpec")
    def doubleAndEncode(self):
        """Encode 2*self without computing the doubled point explicitly.

        Cross-checked against doubleAndEncodeSpec, i.e. (self+self).encode().
        """
        X,Y,Z,T = self.xyzt()
        # NOTE(review): mneg is unpacked but not used below.
        a,d,mneg = self.a,self.d,self.mneg

        if self.cofactor==8:
            e = 2*X*Y
            f = Z^2+d*T^2
            g = Y^2-a*X^2
            h = Z^2-d*T^2

            # Single field inversion shared by z_inv and t_inv.
            inv1 = 1/(e*f*g*h)
            z_inv = inv1*e*g # 1 / (f*h)
            t_inv = inv1*f*h

            if negative(e*g*z_inv):
                if a==-1: sqrta = self.i
                else:     sqrta = -1
                e,f,g,h = g,h,-e,f*sqrta
                factor = self.i
            else:
                factor = self.magic

            if negative(h*e*z_inv): g=-g
            s = (h-g)*factor*g*t_inv

        else:
            foo = Y^2+a*X^2
            bar = X*Y
            den = 1/(foo*bar)
            if negative(2*bar^2*den): tmp = a*X^2
            else: tmp = Y^2
            s = self.magic*(Z^2-tmp)*foo*den

        return self.gfToBytes(s,mustBePositive=True)

    @classmethod
    @optimized_version_of("decodeSpec")
    def decode(cls,s):
        """Decode, optimized version"""
        s = cls.bytesToGf(s,mustBePositive=True)

        a,d = cls.a,cls.d
        yden     = 1-a*s^2
        ynum     = 1+a*s^2
        yden_sqr = yden^2
        xden_sqr = a*d*ynum^2 - yden_sqr

        # One inverse square root gives both denominators' inverses
        # (up to sign): xden_inv = 1/sqrt(xden_sqr), yden_inv = 1/yden.
        isr = isqrt(xden_sqr * yden_sqr)

        xden_inv = isr * yden
        yden_inv = xden_inv * isr * xden_sqr

        x = 2*s*xden_inv
        if negative(x): x = -x
        y = ynum * yden_inv

        if cls.cofactor==8 and (negative(x*y) or y==0):
            raise InvalidEncodingException("x*y is invalid: %d, %d" % (x,y))

        return cls(x,y)

    @classmethod
    def fromJacobiQuartic(cls,s,t,sgn=1):
        """Convert point from its Jacobi Quartic representation"""
        a,d = cls.a,cls.d
        # (s,t) must satisfy t^2 == s^4 - 2*a*(1-2*d/(d-a))*s^2 + 1.
        assert s^4 - 2*cls.a*(1-2*d/(d-a))*s^2 + 1 == t^2
        x = 2*s*cls.magic / t
        y = (1+a*s^2) / (1-a*s^2)
        return cls(sgn*x,y)

    @classmethod
    def elligatorSpec(cls,r0):
        """Unoptimized specification of the Elligator map from bytes r0 to a point."""
        a,d = cls.a,cls.d
        # High bits of r0 are masked and the value need not be < |F|.
        r = cls.qnr * cls.bytesToGf(r0,mustBeProper=False,maskHiBits=True)^2
        den = (d*r-a)*(a*r-d)
        if den == 0: return cls()
        n1 = cls.a*(r+1)*(a+d)*(d-a)/den
        n2 = r*n1
        if is_square(n1):
            sgn,s,t =  1, xsqrt(n1), -(r-1)*(a+d)^2 / den - 1
        else:
            sgn,s,t = -1,-xsqrt(n2), r*(r-1)*(a+d)^2 / den - 1

        # NOTE(review): sgn is computed but not passed on (default sgn=1).
        return cls.fromJacobiQuartic(s,t)

    @classmethod
    @optimized_version_of("elligatorSpec")
    def elligator(cls,r0):
        """Elligator map, optimized version (one isqrt_i instead of a square test plus sqrt)."""
        a,d = cls.a,cls.d
        r0 = cls.bytesToGf(r0,mustBeProper=False,maskHiBits=True)
        r = cls.qnr * r0^2
        den = (d*r-a)*(a*r-d)
        num = cls.a*(r+1)*(a+d)*(d-a)

        # iss tells whether num*den is square; otherwise isri is
        # 1/sqrt(zeta*num*den) and is corrected by the twiddle factor.
        iss,isri = isqrt_i(num*den)
        if iss: sgn,twiddle =  1,1
        else:   sgn,twiddle = -1,r0*cls.qnr
        isri *= twiddle
        s = isri*num
        t = -sgn*isri*s*(r-1)*(d+a)^2 - 1
        if negative(s) == iss: s = -s
        return cls.fromJacobiQuartic(s,t)
313
314
class Decaf_1_1_Point(QuotientEdwardsPoint):
    """Like current decaf but tweaked for simplicity

    Concrete subclasses supply F, a, d, cofactor, encLen, qnr and isoMagic
    (plus i for cofactor-8 curves).  encode, decode, doubleAndEncode and
    elligator are wrapped with optimized_version_of and cross-checked
    against their *Spec counterparts on every call.
    NOTE(review): invertElligator uses Python 2 print statements.
    """
    def encodeSpec(self):
        """Unoptimized specification for encoding"""
        a,d = self.a,self.d
        x,y = self
        # Points with x == 0 or y == 0 all encode as the zero string.
        if x==0 or y==0: return(self.gfToBytes(0))

        if self.cofactor==8 and negative(x*y*self.isoMagic):
            x,y = self.torque()

        sr = xsqrt(1-a*x^2)
        altx = x*y*self.isoMagic / sr
        # The sign of altx selects between the two candidate s values.
        if negative(altx): s = (1+sr)/x
        else:              s = (1-sr)/x

        return self.gfToBytes(s,mustBePositive=True)

    @classmethod
    def decodeSpec(cls,s):
        """Unoptimized specification for decoding"""
        a,d = cls.a,cls.d
        s = cls.bytesToGf(s,mustBePositive=True)

        # The zero encoding decodes to the identity.
        if s==0: return cls()
        t = xsqrt(s^4 + 2*(a-2*d)*s^2 + 1)
        altx = 2*s*cls.isoMagic/t
        if negative(altx): t = -t
        x = 2*s / (1+a*s^2)
        y = (1-a*s^2) / t

        if cls.cofactor==8 and (negative(x*y*cls.isoMagic) or y==0):
            raise InvalidEncodingException("x*y is invalid: %d, %d" % (x,y))

        return cls(x,y)

    def toJacobiQuartic(self,toggle_rotation=False,toggle_altx=False,toggle_s=False):
        """Return (s, m1, a*tmp, swap) for this point.

        s is the non-negative Jacobi-quartic s-coordinate (encode uses it as
        the encoding).  m1, a*tmp and swap are intermediate values consumed
        by invertElligator; swap always equals toggle_s.  The toggle_* flags
        flip the sign/rotation choices so that invertElligator can
        enumerate alternative preimages.  In the cofactor-8 branch a and d
        are locally negated, so the returned a*tmp uses the negated a.
        """
        a,d = self.a,self.d
        x,y,z,t = self.xyzt()

        if self.cofactor == 8:
            # Cofactor 8 version
            # Simulate IMAGINE_TWIST because that's how libdecaf does it
            x = self.i*x
            t = self.i*t
            a = -a
            d = -d

            # OK, the actual libdecaf code should be here
            num = (z+y)*(z-y)
            den = x*y
            isr = isqrt(num*(a-d)*den^2)

            iden = isr * den * self.isoMagic # 1/sqrt((z+y)(z-y)) = 1/sqrt(1-Y^2) / z
            inum = isr * num # sqrt(1-Y^2) * z / xysqrt(a-d) ~ 1/sqrt(1-ax^2)/z

            if negative(iden*inum*self.i*t^2*(d-a)) != toggle_rotation:
                iden,inum = inum,iden
                fac = x*sqrt(a)
                toggle=(a==-1)
            else:
                fac = y
                toggle=False

            imi = self.isoMagic * self.i
            altx = inum*t*imi
            neg_altx = negative(altx) != toggle_altx
            if neg_altx != toggle: inum =- inum

            tmp = fac*(inum*z + 1)
            s = iden*tmp*imi

            negm1 = (negative(s) != toggle_s) != neg_altx
            if negm1: m1 = a*fac + z
            else:     m1 = a*fac - z

            swap = toggle_s

        else:
            # Much simpler cofactor 4 version
            num = (x+t)*(x-t)
            isr = isqrt(num*(a-d)*x^2)
            ratio = isr*num
            altx = ratio*self.isoMagic

            neg_altx = negative(altx) != toggle_altx
            if neg_altx: ratio =- ratio

            tmp = ratio*z - t
            s = (a-d)*isr*x*tmp

            negx = (negative(s) != toggle_s) != neg_altx
            if negx: m1 = -a*t + x
            else:    m1 = -a*t - x

            swap = toggle_s

        # Return the non-negative representative of s.
        if negative(s): s = -s

        return s,m1,a*tmp,swap

    def invertElligator(self,toggle_r=False,*args,**kwargs):
        """Return a list of candidate Elligator preimages of self (possibly empty).

        Every combination of rotation (cofactor 8 only), altx, s and r
        toggles is tried.  Each candidate is checked against elligator();
        results mapping to neither self nor -self print "WRONG!", and those
        mapping to -self print "Negated!".
        NOTE(review): the toggle_r parameter is shadowed by the inner loop
        variable and *args/**kwargs are ignored, so the arguments have no
        effect.
        """
        a,d = self.a,self.d

        rets = []

        tr = [False,True] if self.cofactor == 8 else [False]
        for toggle_rotation in tr:
            for toggle_altx in [False,True]:
                for toggle_s in [False,True]:
                    for toggle_r in [False,True]:
                        s,m1,m12,swap = self.toJacobiQuartic(toggle_rotation,toggle_altx,toggle_s)

                        #print
                        #print toggle_rotation,toggle_altx,toggle_s
                        #print m1
                        #print m12


                        # The identity needs hand-picked m1/m12 values.
                        if self == self.__class__():
                            if self.cofactor == 4:
                                # Hacks for identity!
                                if toggle_altx: m12 = 1
                                elif toggle_s: m1 = 1
                                elif toggle_r: continue
                                ## BOTH???

                            else:
                                m12 = 1
                                imi = self.isoMagic * self.i
                                if toggle_rotation:
                                    if toggle_altx: m1 = -imi
                                    else:           m1 = +imi
                                else:
                                    if toggle_altx: m1 = 0
                                    else: m1 = a-d

                        rnum = (d*a*m12-m1)
                        rden = ((d*a-1)*m12+m1)
                        if swap: rnum,rden = rden,rnum

                        # A candidate exists only when rnum*rden*qnr is square.
                        ok,sr = isqrt_i(rnum*rden*self.qnr)
                        if not ok: continue
                        sr *= rnum
                        #print "Works! %d %x" % (swap,sr)

                        if negative(sr) != toggle_r: sr = -sr
                        ret = self.gfToBytes(sr)
                        # Debug self-checks (Python 2 print statements).
                        if self.elligator(ret) != self and self.elligator(ret) != -self:
                            print "WRONG!",[toggle_rotation,toggle_altx,toggle_s]
                        if self.elligator(ret) == -self and self != -self: print "Negated!",[toggle_rotation,toggle_altx,toggle_s]
                        rets.append(bytes(ret))
        return rets

    @optimized_version_of("encodeSpec")
    def encode(self):
        """Encode, optimized version"""
        # The encoding is the s-coordinate from toJacobiQuartic with all
        # toggles off.
        return self.gfToBytes(self.toJacobiQuartic()[0])

    @classmethod
    @optimized_version_of("decodeSpec")
    def decode(cls,s):
        """Decode, optimized version"""
        a,d = cls.a,cls.d
        s = cls.bytesToGf(s,mustBePositive=True)

        # No s == 0 shortcut here; decodeSpec has one.
        #if s==0: return cls()
        s2 = s^2
        den = 1+a*s2
        num = den^2 - 4*d*s2
        # One inverse square root replaces decodeSpec's sqrt and divisions.
        isr = isqrt(num*den^2)
        altx = 2*s*isr*den*cls.isoMagic
        if negative(altx): isr = -isr
        x = 2*s *isr^2*den*num
        y = (1-a*s^2) * isr*den

        if cls.cofactor==8 and (negative(x*y*cls.isoMagic) or y==0):
            raise InvalidEncodingException("x*y is invalid: %d, %d" % (x,y))

        return cls(x,y)

    @classmethod
    def fromJacobiQuartic(cls,s,t,sgn=1):
        """Convert point from its Jacobi Quartic representation"""
        a,d = cls.a,cls.d
        # s == 0 maps to the identity.
        if s==0: return cls()
        x = 2*s / (1+a*s^2)
        y = (1-a*s^2) / t
        return cls(x,sgn*y)

    @optimized_version_of("doubleAndEncodeSpec")
    def doubleAndEncode(self):
        """Encode 2*self using one field inversion.

        Cross-checked against doubleAndEncodeSpec, i.e. (self+self).encode().
        """
        X,Y,Z,T = self.xyzt()
        a,d = self.a,self.d

        if self.cofactor == 8:
            # Cofactor 8 version
            # Simulate IMAGINE_TWIST because that's how libdecaf does it
            X = self.i*X
            T = self.i*T
            a = -a
            d = -d
            # TODO: This is only being called for a=-1, so could
            # be wrong for a=1

            e = 2*X*Y
            f = Y^2+a*X^2
            g = Y^2-a*X^2
            h = Z^2-d*T^2

            eim = e*self.isoMagic
            inv = 1/(eim*g*f*h)
            fh_inv = eim*g*inv*self.i

            if negative(eim*g*fh_inv):
                idf = g*self.isoMagic*self.i
                bar = f
                foo = g
                test = eim*f
            else:
                idf = eim
                bar = h
                foo = -eim
                test = g*h

            if negative(test*fh_inv): bar =- bar
            s = idf*(foo+bar)*inv*f*h

        else:
            xy = X*Y
            h = Z^2-d*T^2
            inv = 1/(xy*h)
            if negative(inv*2*xy^2*self.isoMagic): tmp = Y
            else: tmp = X
            s = tmp^2*h*inv # = X/Y or Y/X, interestingly

        return self.gfToBytes(s,mustBePositive=True)

    @classmethod
    def elligatorSpec(cls,r0,fromR=False):
        """Unoptimized specification of the Elligator map.

        r0 is a byte string, or, when fromR is true, the field value r
        itself (the qnr*r0^2 step is skipped).
        """
        a,d = cls.a,cls.d
        if fromR: r = r0
        else: r = cls.qnr * cls.bytesToGf(r0,mustBeProper=False,maskHiBits=True)^2

        den = (d*r-(d-a))*((d-a)*r-d)
        if den == 0: return cls()
        n1 = (r+1)*(a-2*d)/den
        n2 = r*n1
        if is_square(n1):
            sgn,s,t = 1,   xsqrt(n1),  -(r-1)*(a-2*d)^2 / den - 1
        else:
            sgn,s,t = -1, -xsqrt(n2), r*(r-1)*(a-2*d)^2 / den - 1

        # NOTE(review): sgn is computed but not passed on (default sgn=1).
        return cls.fromJacobiQuartic(s,t)

    @classmethod
    @optimized_version_of("elligatorSpec")
    def elligator(cls,r0):
        """Elligator map, optimized version (one isqrt_i call)."""
        a,d = cls.a,cls.d
        r0 = cls.bytesToGf(r0,mustBeProper=False,maskHiBits=True)
        r = cls.qnr * r0^2
        den = (d*r-(d-a))*((d-a)*r-d)
        num = (r+1)*(a-2*d)

        iss,isri = isqrt_i(num*den)
        if iss: sgn,twiddle =  1,1
        else:   sgn,twiddle = -1,r0*cls.qnr
        isri *= twiddle
        s = isri*num
        t = -sgn*isri*s*(r-1)*(a-2*d)^2 - 1
        if negative(s) == iss: s = -s
        return cls.fromJacobiQuartic(s,t)

    def elligatorInverseBruteForce(self):
        """Invert Elligator using SAGE's polynomial solver

        Returns the byte strings r with elligator(r) == self.  Candidates
        are found as common roots (in r0) of the x^2 and y equations for
        self and each of its torques, and every candidate is asserted to
        map to self or -self.  Requires Sage (polynomial-ring syntax).
        """
        a,d = self.a,self.d
        R.<r0> = self.F[]
        r = self.qnr * r0^2
        den = (d*r-(d-a))*((d-a)*r-d)
        n1 = (r+1)*(a-2*d)/den
        n2 = r*n1
        ret = set()
        # Both branches of elligatorSpec, as rational functions of r0.
        for s2,t in [(n1, -(r-1)*(a-2*d)^2 / den - 1),
                     (n2,r*(r-1)*(a-2*d)^2 / den - 1)]:
            x2 = 4*s2/(1+a*s2)^2
            y = (1-a*s2) / t

            selfT = self
            for i in xrange(self.cofactor/2):
                xT,yT = selfT
                polyX = xT^2-x2
                polyY = yT-y
                sx = set(r for r,_ in polyX.numerator().roots())
                sy = set(r for r,_ in polyY.numerator().roots())
                ret = ret.union(sx.intersection(sy))

                selfT = selfT.torque()

        ret = [self.gfToBytes(r) for r in ret]

        for r in ret:
            assert self.elligator(r) in [self,-self]

        ret = [r for r in ret if self.elligator(r) == self]

        return ret
623
class Ed25519Point(RistrettoPoint):
    """Ristretto over the Ed25519 curve: a = -1, d = -121665/121666 over GF(2^255-19), cofactor 8."""
    F = GF(2^255-19)
    d = F(-121665/121666)
    a = F(-1)
    i = sqrt(F(-1))
    mneg = F(1)
    qnr = i # multiplier applied to r0^2 in elligator
    magic = isqrt(a*d-1) # 1/sqrt(a*d-1)
    cofactor = 8
    encLen = 32

    @classmethod
    def base(cls):
        """Return the base point from hard-coded affine coordinates."""
        return cls( 15112221349535400772501151409588531511454012693041857206046113283949847762202, 46316835694926478169428394003475163141307993866256225615783033603165251855960
        )
639
class NegEd25519Point(RistrettoPoint):
    """Ristretto over the a = 1, d = 121665/121666 curve over GF(2^255-19), cofactor 8."""
    F = GF(2^255-19)
    d = F(121665/121666)
    a = F(1)
    i = sqrt(F(-1))
    mneg = F(-1) # TODO checkme vs 1-ad or whatever
    qnr = i # multiplier applied to r0^2 in elligator
    magic = isqrt(a*d-1) # 1/sqrt(a*d-1)
    cofactor = 8
    encLen = 32

    @classmethod
    def base(cls):
        """Return the point with y = 4/5 and non-negative x."""
        y = cls.F(4/5)
        # Solve y^2 + a*x^2 == 1 + d*x^2*y^2 for x^2.
        x = sqrt((y^2-1)/(cls.d*y^2-cls.a))
        if negative(x): x = -x
        return cls(x,y)
657
class IsoEd448Point(RistrettoPoint):
    """Ristretto over the a = 1, d = 39082/39081 curve over GF(2^448-2^224-1), cofactor 4."""
    F = GF(2^448-2^224-1)
    d = F(39082/39081)
    a = F(1)
    mneg = F(-1)
    qnr = -1 # multiplier applied to r0^2 in elligator
    magic = isqrt(a*d-1) # 1/sqrt(a*d-1); also reused as isoMagic by the Ed448 Decaf classes
    cofactor = 4
    encLen = 56

    @classmethod
    def base(cls):
        """Return the base point from hard-coded affine coordinates."""
        return cls(  # RFC has it wrong
         345397493039729516374008604150537410266655260075183290216406970281645695073672344430481787759340633221708391583424041788924124567700732,
            -363419362147803445274661903944002267176820680343659030140745099590306164083365386343198191849338272965044442230921818680526749009182718
        )
674
class TwistedEd448GoldilocksPoint(Decaf_1_1_Point):
    """Decaf over the a = -1, d = -39082 curve over GF(2^448-2^224-1), cofactor 4."""
    F = GF(2^448-2^224-1)
    d = F(-39082)
    a = F(-1)
    qnr = -1 # multiplier applied to r0^2 in elligator
    cofactor = 4
    encLen = 56
    isoMagic = IsoEd448Point.magic

    @classmethod
    def base(cls):
        """Return the point whose encoding equals Ed448GoldilocksPoint's base-point encoding."""
        return cls.decodeSpec(Ed448GoldilocksPoint.base().encodeSpec())
687
class Ed448GoldilocksPoint(Decaf_1_1_Point):
    """Decaf over the a = 1, d = -39081 curve over GF(2^448-2^224-1), cofactor 4."""
    F = GF(2^448-2^224-1)
    d = F(-39081)
    a = F(1)
    qnr = -1 # multiplier applied to r0^2 in elligator
    cofactor = 4
    encLen = 56
    isoMagic = IsoEd448Point.magic

    @classmethod
    def base(cls):
        """Return twice the point given by the hard-coded affine coordinates."""
        return 2*cls(
 224580040295924300187604334099896036246789641632564134246125461686950415467406032909029192869357953282578032075146446173674602635247710, 298819210078481492676017930443930673437544040154080242095928241372331506189835876003536878655418784733982303233503462500531545062832660
        )
702
class IsoEd25519Point(Decaf_1_1_Point):
    """Decaf over the a = 1, d = -121665 curve over GF(2^255-19), cofactor 8."""
    # TODO: twisted iso too!
    # TODO: twisted iso might have to IMAGINE_TWIST or whatever
    F = GF(2^255-19)
    d = F(-121665)
    a = F(1)
    i = sqrt(F(-1))
    qnr = i # multiplier applied to r0^2 in elligator
    magic = isqrt(a*d-1) # 1/sqrt(a*d-1)
    cofactor = 8
    encLen = 32
    isoMagic = Ed25519Point.magic
    isoA = Ed25519Point.a # NOTE(review): not referenced elsewhere in the visible source

    @classmethod
    def base(cls):
        """Return the point whose encoding equals Ed25519Point's base-point encoding."""
        return cls.decodeSpec(Ed25519Point.base().encode())
720
class TestFailedException(Exception):
    """Raised by the self-test routines when a check fails."""
722
def test(cls,n):
    """Self-test encoding/decoding and the group law for point class cls.

    Checks:
      * special field values (1, -1, repeated square roots of -1) either
        round-trip through decode/encode or are rejected;
      * for n iterations, on a walk of points: encode/decode round trip,
        the s -> 1/s behaviour (negates the point for cofactor 4, rejected
        for cofactor 8), encoding invariance under torque, and
        addition/subtraction/scalar-multiplication consistency.
    Raises TestFailedException on the first failing check.
    """
    print("Testing curve %s" % cls.__name__)

    # 1, then -1 and its successive square roots while they are squares,
    # plus the first non-square reached.
    specials = [1]
    ii = cls.F(-1)
    while is_square(ii):
        specials.append(ii)
        ii = sqrt(ii)
    specials.append(ii)

    for special in specials:
        # Coerce into F *before* fixing the sign.  Previously the plain int
        # 1 was negated to the Python int -1, which enc_le encodes as
        # all-0xFF bytes (rejected as out of range) instead of testing p-1.
        v = cls.F(special)
        if negative(v): v = -v
        enc = enc_le(v,cls.encLen)
        try:
            Q = cls.decode(enc)
            QE = Q.encode()
            if QE != enc:
                raise TestFailedException("Round trip special %s != %s" %
                     (binascii.hexlify(QE),binascii.hexlify(enc)))
        except (NotOnCurveException, InvalidEncodingException):
            # Special values are allowed to be invalid encodings.
            pass

    P = cls.base()
    Q = cls()
    for _ in range(n):
        QE = Q.encode()
        QQ = cls.decode(QE)
        if QQ != Q: raise TestFailedException("Round trip %s != %s" % (str(QQ),str(Q)))

        # Testing s -> 1/s: encodes -point on cofactor 4, rejected on cofactor 8
        s = cls.bytesToGf(QE)
        if s != 0:
            ss = cls.gfToBytes(1/s,mustBePositive=True)
            try:
                QN = cls.decode(ss)
            except InvalidEncodingException:
                # Should be raised iff cofactor==8
                if cls.cofactor == 4:
                    raise TestFailedException("s -> 1/s should work for cofactor 4")
            else:
                if cls.cofactor == 8:
                    raise TestFailedException("1/s shouldnt work for cofactor 8")
                if QN != -Q:
                    raise TestFailedException("s -> 1/s should negate point for cofactor 4")

        # The encoding must not change under torque.
        QT = Q
        for h in range(cls.cofactor):
            QT = QT.torque()
            if QT.encode() != QE:
                raise TestFailedException("Can't torque %s,%d" % (str(Q),h+1))

        Q0 = Q + P
        if Q0 == Q: raise TestFailedException("Addition doesn't work")
        if Q0-P != Q: raise TestFailedException("Subtraction doesn't work")

        r = randint(1,1000)
        Q1 = Q0*r
        Q2 = Q0*(r+1)
        if Q1 + Q0 != Q2: raise TestFailedException("Scalarmul doesn't work")
        Q = Q1
783
def testElligator(cls,n):
    """Map n random byte strings through cls.elligator and check the inverse.

    For point classes that define invertElligator, the (masked, reduced)
    input must be among the returned preimages, the preimages must be
    distinct, and the torqued point must give the same list.  Problems are
    printed, not raised.
    """
    print("Testing elligator on %s" % cls.__name__)
    for _ in range(n):
        r = randombytes(cls.encLen)
        P = cls.elligator(r)
        if not hasattr(P,"invertElligator"):
            continue # TODO
        iv = P.invertElligator()
        modr = bytes(cls.gfToBytes(cls.bytesToGf(r,mustBeProper=False,maskHiBits=True)))
        iv2 = P.torque().invertElligator()
        if modr not in iv:
            print("Failed to invert Elligator!")
        distinct = len(set(iv))
        if len(iv) != distinct:
            print("Elligator inverses not unique! %d %d" % (distinct, len(iv)))
        if iv != iv2:
            print("Elligator is untorqueable!")
803
def gangtest(classes,n):
    """Check that several curve classes produce identical encodings.

    For i in range(n), compares encode(i*base) across classes, then
    encode(i*elligator(r0)).  r0 is a special value for the first few i and
    random bytes afterwards.  Divergences are printed.
    """
    print("Gang test %s" % [c.__name__ for c in classes])
    first = classes[0]
    specials = [1]
    ii = first.F(-1)
    while is_square(ii):
        specials.append(ii)
        ii = sqrt(ii)
    specials.append(ii)

    def report(what, i, rets):
        # Print every class's output when they do not all agree.
        if len(set(rets)) == 1:
            return
        print("Divergence in %s at %d" % (what, i))
        for c,ret in zip(classes,rets):
            print("%s %s" % (c, binascii.hexlify(ret)))
        print("")

    for i in range(n):
        report("encode", i, [bytes((c.base()*i).encode()) for c in classes])

        if i < len(specials):
            r0 = enc_le(specials[i],first.encLen)
        else:
            r0 = randombytes(first.encLen)

        report("elligator", i, [bytes((c.elligator(r0)*i).encode()) for c in classes])
830
def testDoubleAndEncode(cls,n):
    """Call cls.doubleAndEncode on n random points.

    Each point is the sum of the elligator images of two random byte
    strings.  The result is discarded: doubleAndEncode is wrapped by
    optimized_version_of, which checks it against doubleAndEncodeSpec.
    """
    print("Testing doubleAndEncode on %s" % cls.__name__)
    for _ in range(n):
        left, right = randombytes(cls.encLen), randombytes(cls.encLen)
        point = cls.elligator(left) + cls.elligator(right)
        point.doubleAndEncode()
838
if __name__ == "__main__":
    # Run the self-tests only when executed as a script, so that importing
    # this module for its definitions does not run hundreds of checks.
    testDoubleAndEncode(Ed25519Point,100)
    testDoubleAndEncode(NegEd25519Point,100)
    testDoubleAndEncode(IsoEd25519Point,100)
    testDoubleAndEncode(IsoEd448Point,100)
    testDoubleAndEncode(TwistedEd448GoldilocksPoint,100)
    # Further self-tests, disabled by default:
    #test(Ed25519Point,100)
    #test(NegEd25519Point,100)
    #test(IsoEd25519Point,100)
    #test(IsoEd448Point,100)
    #test(TwistedEd448GoldilocksPoint,100)
    #test(Ed448GoldilocksPoint,100)
    #testElligator(Ed25519Point,100)
    #testElligator(NegEd25519Point,100)
    #testElligator(IsoEd25519Point,100)
    #testElligator(IsoEd448Point,100)
    #testElligator(Ed448GoldilocksPoint,100)
    #testElligator(TwistedEd448GoldilocksPoint,100)
    #gangtest([IsoEd448Point,TwistedEd448GoldilocksPoint,Ed448GoldilocksPoint],100)
    #gangtest([Ed25519Point,IsoEd25519Point],100)
858