1# This file was *autogenerated* from the file ristretto.sage
2from sage.all_cmdline import *   # import sage library
3_sage_const_363419362147803445274661903944002267176820680343659030140745099590306164083365386343198191849338272965044442230921818680526749009182718 = Integer(363419362147803445274661903944002267176820680343659030140745099590306164083365386343198191849338272965044442230921818680526749009182718); _sage_const_2 = Integer(2); _sage_const_1 = Integer(1); _sage_const_0 = Integer(0); _sage_const_121666 = Integer(121666); _sage_const_5 = Integer(5); _sage_const_121665 = Integer(121665); _sage_const_8 = Integer(8); _sage_const_298819210078481492676017930443930673437544040154080242095928241372331506189835876003536878655418784733982303233503462500531545062832660 = Integer(298819210078481492676017930443930673437544040154080242095928241372331506189835876003536878655418784733982303233503462500531545062832660); _sage_const_15112221349535400772501151409588531511454012693041857206046113283949847762202 = Integer(15112221349535400772501151409588531511454012693041857206046113283949847762202); _sage_const_46316835694926478169428394003475163141307993866256225615783033603165251855960 = Integer(46316835694926478169428394003475163141307993866256225615783033603165251855960); _sage_const_345397493039729516374008604150537410266655260075183290216406970281645695073672344430481787759340633221708391583424041788924124567700732 = Integer(345397493039729516374008604150537410266655260075183290216406970281645695073672344430481787759340633221708391583424041788924124567700732); _sage_const_4 = Integer(4); _sage_const_448 = Integer(448); _sage_const_0xFF = Integer(0xFF); _sage_const_255 = Integer(255); _sage_const_224580040295924300187604334099896036246789641632564134246125461686950415467406032909029192869357953282578032075146446173674602635247710 = Integer(224580040295924300187604334099896036246789641632564134246125461686950415467406032909029192869357953282578032075146446173674602635247710); _sage_const_1000 = Integer(1000); _sage_const_56 = Integer(56); _sage_const_19 = Integer(19); _sage_const_32 = 
Integer(32); _sage_const_100 = Integer(100); _sage_const_39082 = Integer(39082); _sage_const_39081 = Integer(39081); _sage_const_224 = Integer(224)
4import binascii
class InvalidEncodingException(Exception):
    """Raised when bytes are not an acceptable encoding, or a value is not on the curve."""


class NotOnCurveException(Exception):
    """Raised when coordinates do not satisfy the curve equation."""


class SpecException(Exception):
    """Raised when an optimized implementation disagrees with its specification."""
8
def lobit(x):
    """Return the least significant bit (0 or 1) of int(x)."""
    as_int = int(x)
    return as_int & _sage_const_1
def hibit(x):
    """Return lobit(2*x).

    For an element x of a prime field of odd order p, this is 1 exactly when
    the canonical representative of x exceeds (p-1)/2. For a plain integer
    it is always 0.
    """
    doubled = _sage_const_2 *x
    return lobit(doubled)
def negative(x):
    """Sign convention used throughout this file: x counts as "negative"
    exactly when lobit(x) is 1, i.e. when int(x) is odd."""
    return lobit(x)
def enc_le(x,n):
    """Encode int(x) as exactly n little-endian bytes; higher bits are dropped."""
    value = int(x)
    out = bytearray()
    for k in xrange(n):
        # Byte k holds bits 8k .. 8k+7.
        out.append((value >> (_sage_const_8 *k)) & _sage_const_0xFF )
    return out
def dec_le(x):
    """Decode a little-endian sequence of byte values into an integer (0 if empty)."""
    value = 0
    for position, byte in enumerate(x):
        value += byte << (_sage_const_8 *position)
    return value
def randombytes(n):
    """Return a bytearray of n bytes, each drawn with randint(0, 255).

    NOTE(review): randint is presumably Sage's seeded general-purpose PRNG,
    so this is only suitable for test inputs, not secrets.
    """
    out = bytearray()
    for _ in range(n):
        out.append(randint(_sage_const_0 ,_sage_const_255 ))
    return out
15
def optimized_version_of(spec):
    """Decorator: the decorated method is an optimized version of the method named *spec*.

    On every call, both the specification (looked up with getattr on the
    same object or class) and the decorated function are run, and the
    outcomes are cross-checked:

    * if exactly one of them raises, that exception is re-raised;
    * if both return different results, SpecException is raised;
    * if both return equal results, the optimized result is returned;
    * if both raise, the optimized version's exception is re-raised.
    """
    import functools

    def decorator(f):
        @functools.wraps(f)
        def wrapper(self,*args,**kwargs):
            def pr(x):
                # Show byte strings as hex in mismatch reports.
                if isinstance(x,bytearray): return binascii.hexlify(x)
                else: return str(x)

            try: spec_ret,spec_err = getattr(self,spec,spec)(*args,**kwargs),None
            except Exception as e: spec_ret,spec_err = None,e
            try: opt_ret,opt_err = f(self,*args,**kwargs),None
            except Exception as e: opt_ret,opt_err = None,e

            # Re-raise the captured exception objects explicitly. A bare `raise`
            # here sits outside any except block: it only worked through
            # Python 2's leaked sys.exc_info() and is a RuntimeError on
            # Python 3. The original exception types are kept (not wrapped in
            # SpecException), so callers that catch e.g. InvalidEncodingException
            # keep working.
            if spec_err is None and opt_err is not None:
                raise opt_err
            if spec_err is not None and opt_err is None:
                raise spec_err
            if spec_ret != opt_ret:
                raise SpecException("Mismatch in %s: %s != %s"
                    % (f.__name__,pr(spec_ret),pr(opt_ret)))
            if opt_err is not None: raise opt_err
            return opt_ret
        return wrapper
    return decorator
43
def xsqrt(x,exn=None):
    """Return the square root of x that is not "negative" (negative(s) is false).

    Raises exn if x is not a square. When exn is None (the default), a fresh
    InvalidEncodingException("Not on curve") is raised.
    """
    if not is_square(x):
        # Build the default exception per call instead of sharing a single
        # instance created once at function-definition time.
        raise exn if exn is not None else InvalidEncodingException("Not on curve")
    s = sqrt(x)
    if negative(s): s=-s
    return s
50
def isqrt(x,exn=None):
    """Return 1/sqrt(x), or 0 when x == 0.

    Unlike xsqrt, the sign of the root is not normalized. Raises exn if x is
    a nonzero non-square. When exn is None (the default), a fresh
    InvalidEncodingException("Not on curve") is raised.
    """
    if x==_sage_const_0 : return _sage_const_0
    if not is_square(x):
        # Fresh default exception per call (no shared instance across raises).
        raise exn if exn is not None else InvalidEncodingException("Not on curve")
    s = sqrt(x)
    #if negative(s): s=-s   (sign normalization left disabled)
    return _sage_const_1 /s
58
def inv0(x):
    """Return 1/x, mapping 0 to 0 instead of dividing by zero."""
    if x == _sage_const_0 :
        return _sage_const_0
    return _sage_const_1 /x
60
def isqrt_i(x):
    """Return (True, 1/sqrt(x)) if x is a square, else (False, 1/sqrt(gen*x)).

    gen is found by starting from -1 in x's field and repeatedly taking square
    roots until a non-square is reached. Since the product of two non-squares
    in a finite field is a square, gen*x is a square whenever x is not.
    x == 0 gives (True, 0).
    """
    if x==_sage_const_0 : return True,_sage_const_0
    if is_square(x): return True,_sage_const_1 /sqrt(x)
    # Only the non-square branch needs the non-residue.
    gen = x.parent(-_sage_const_1 )
    while is_square(gen): gen = sqrt(gen)
    return False,_sage_const_1 /sqrt(x*gen)
68
class QuotientEdwardsPoint(object):
    """Abstract class for a point on a quotiented Edwards curve.

    Subclasses must define the class attributes used here:
    F (the field); a and d (the curve is a*x^2 + y^2 = 1 + d*x^2*y^2);
    cofactor (4 or 8); and encLen (encoding length in bytes).
    Cofactor-8 subclasses with a == -1 also need i, a square root of -1,
    for torque().
    """
    def __init__(self,x=_sage_const_0 ,y=_sage_const_1 ):
        """Create the affine point (x, y); the default (0, 1) is the identity.

        Raises NotOnCurveException if (x, y) does not satisfy the curve equation.
        """
        x = self.x = self.F(x)
        y = self.y = self.F(y)
        if y**_sage_const_2  + self.a*x**_sage_const_2  != _sage_const_1  + self.d*x**_sage_const_2 *y**_sage_const_2 :
            raise NotOnCurveException(str(self))

    def __repr__(self):
        return "%s(0x%x,0x%x)" % (self.__class__.__name__, self.x, self.y)

    def __iter__(self):
        """Allow unpacking as `x, y = point`."""
        yield self.x
        yield self.y

    def __add__(self,other):
        """Edwards addition law for a*x^2 + y^2 = 1 + d*x^2*y^2."""
        x,y = self
        X,Y = other
        a,d = self.a,self.d
        return self.__class__(
            (x*Y+y*X)/(_sage_const_1 +d*x*y*X*Y),
            (y*Y-a*x*X)/(_sage_const_1 -d*x*y*X*Y)
        )

    def __neg__(self): return self.__class__(-self.x,self.y)
    def __sub__(self,other): return self + (-other)
    def __rmul__(self,other): return self*other  # lets n*P work as well as P*n

    def __eq__(self,other):
        """Equality in the quotient group, not coordinate equality.

        NB: this is the only method that is different from the usual one.
        Two points match when x*Y == X*y, or (cofactor 8 only) when
        -a*x*X == y*Y.
        """
        x,y = self
        X,Y = other
        return x*Y == X*y or (self.cofactor==_sage_const_8  and -self.a*x*X == y*Y)
    def __ne__(self,other): return not (self==other)
    # NOTE(review): __eq__ is overridden without a matching __hash__, so points
    # should not be used as set members or dict keys.

    def __mul__(self,exp):
        """Scalar multiplication by int(exp) via double-and-add; a negative exp negates."""
        exp = int(exp)
        if exp < _sage_const_0 : exp,self = -exp,-self
        total = self.__class__()
        work  = self
        while exp != _sage_const_0 :
            if exp & _sage_const_1 : total += work
            work += work
            exp >>= _sage_const_1
        return total

    def xyzt(self):
        """Return a randomized extended projective representation (x*z, y*z, z, x*y*z)."""
        x,y = self
        z = self.F.random_element()
        # A zero Z would not represent any point (and would poison later
        # inversions), so redraw in that (astronomically unlikely) case.
        while z == _sage_const_0 :
            z = self.F.random_element()
        return x*z,y*z,z,x*y*z

    def torque(self):
        """Apply cofactor group, except keeping the point even.

        Cofactor 8: (x, y) -> (i*y, i*x) if a == -1, or (-y, x) if a == 1.
        Otherwise: (x, y) -> (-x, -y).
        """
        if self.cofactor == _sage_const_8 :
            if self.a == -_sage_const_1 : return self.__class__(self.y*self.i, self.x*self.i)
            if self.a ==  _sage_const_1 : return self.__class__(-self.y, self.x)
            # Previously fell through and returned None silently.
            raise NotImplementedError("torque: cofactor-8 curves need a == 1 or a == -1")
        else:
            return self.__class__(-self.x, -self.y)


    # Utility functions
    @classmethod
    def bytesToGf(cls,bytes,mustBeProper=True,mustBePositive=False,maskHiBits=False):
        """Convert little-endian bytes to a field element, with sanity checks.

        Raises InvalidEncodingException in any of these cases:
        - len(bytes) != encLen;
        - mustBeProper is set and the integer value is >= the field order;
        - mustBePositive is set and the resulting element is negative.
        With maskHiBits, bits at or above ceil(log2(order)) are cleared
        (after the mustBeProper check) before reducing into the field.
        """
        if len(bytes) != cls.encLen:
            raise InvalidEncodingException("wrong length %d" % len(bytes))
        s = dec_le(bytes)
        if mustBeProper and s >= cls.F.order():
            raise InvalidEncodingException("%d out of range!" % s)
        if maskHiBits:
            # (order-1).nbits() == ceil(log2(order)) exactly for order >= 2,
            # without going through symbolic/high-precision logarithms.
            bitlen = (cls.F.order()-_sage_const_1 ).nbits()
            s &= _sage_const_2 **bitlen-_sage_const_1
        s = cls.F(s)
        if mustBePositive and negative(s):
            raise InvalidEncodingException("%d is negative!" % s)
        return s

    @classmethod
    def gfToBytes(cls,x,mustBePositive=False):
        """Convert field element x to encLen little-endian bytes.

        If mustBePositive is set and x is negative, -x is encoded instead.
        """
        if negative(x) and mustBePositive: x = -x
        return enc_le(x,cls.encLen)
149
class RistrettoPoint(QuotientEdwardsPoint):
    """The new Ristretto group.

    Besides the QuotientEdwardsPoint attributes, concrete subclasses supply
    mneg, qnr and magic (the subclasses in this file set magic =
    isqrt(a*d-1)); cofactor-8 subclasses also supply i.
    Encoding, decoding and elligator each come as an unoptimized "...Spec"
    method plus an optimized method that optimized_version_of cross-checks
    against the spec on every call.
    """
    def encodeSpec(self):
        """Unoptimized specification for encoding.

        Canonicalizes the representative (torque for cofactor 8, then the sign
        of x) and encodes s = sqrt(mneg*(1-y)/(1+y)) as chosen by xsqrt.
        Raises a plain Exception if that quantity is not a square.
        """
        x,y = self
        # Cofactor 8: switch to the torqued representative if x*y is negative or y == 0.
        if self.cofactor==_sage_const_8  and (negative(x*y) or y==_sage_const_0 ): (x,y) = self.torque()
        if y == -_sage_const_1 : y = _sage_const_1  # Avoid divide by 0; doesn't affect impl

        # Normalize so that x is non-negative (negating both coordinates).
        if negative(x): x,y = -x,-y
        s = xsqrt(self.mneg*(_sage_const_1 -y)/(_sage_const_1 +y),exn=Exception("Unimplemented: point is odd: " + str(self)))
        return self.gfToBytes(s)

    @classmethod
    def decodeSpec(cls,s):
        """Unoptimized specification for decoding.

        Parses s as a proper, non-negative field element, then computes
        x = sqrt(4*s^2 / (a*d*(1+a*s^2)^2 - (1-a*s^2)^2)) (non-negative, via
        xsqrt) and y = (1+a*s^2)/(1-a*s^2).
        Raises InvalidEncodingException when:
        - s is malformed or negative;
        - the radicand is not a square;
        - (cofactor 8) x*y is negative or y == 0.
        """
        s = cls.bytesToGf(s,mustBePositive=True)

        a,d = cls.a,cls.d
        x = xsqrt(_sage_const_4 *s**_sage_const_2  / (a*d*(_sage_const_1 +a*s**_sage_const_2 )**_sage_const_2  - (_sage_const_1 -a*s**_sage_const_2 )**_sage_const_2 ))
        y = (_sage_const_1 +a*s**_sage_const_2 ) / (_sage_const_1 -a*s**_sage_const_2 )

        # Reject encodings of the non-canonical (torqued) representative.
        if cls.cofactor==_sage_const_8  and (negative(x*y) or y==_sage_const_0 ):
            raise InvalidEncodingException("x*y has high bit")

        return cls(x,y)

    @optimized_version_of("encodeSpec")
    def encode(self):
        """Encode, optimized version.

        Works on a random projective representative from xyzt() and needs a
        single isqrt. optimized_version_of checks the result against
        encodeSpec.
        """
        a,d,mneg = self.a,self.d,self.mneg
        x,y,z,t = self.xyzt()

        if self.cofactor==_sage_const_8 :
            u1    = mneg*(z+y)*(z-y)
            u2    = x*y # = t*z
            isr   = isqrt(u1*u2**_sage_const_2 )
            i1    = isr*u1 # sqrt(mneg*(z+y)*(z-y))/(x*y)
            i2    = isr*u2 # 1/sqrt(a*(y+z)*(y-z))
            z_inv = i1*i2*t # 1/z

            # When t/z is negative, rotate to the torqued representative
            # (the same coordinate maps as torque()) and adjust the denominator.
            if negative(t*z_inv):
                if a==-_sage_const_1 :
                    x,y = y*self.i,x*self.i
                    den_inv = self.magic * i1
                else:
                    x,y = -y,x
                    den_inv = self.i * self.magic * i1

            else:
                den_inv = i2

            # Sign normalization corresponding to making x non-negative in the spec.
            if negative(x*z_inv): y = -y
            s = (z-y) * den_inv
        else:
            # Cofactor 4: a single isqrt of num*y^2, with the sign fixed via the check below.
            num   = mneg*(z+y)*(z-y)
            isr   = isqrt(num*y**_sage_const_2 )
            if negative(isr**_sage_const_2 *num*y*t): y = -y
            s = isr*y*(z-y)


        return self.gfToBytes(s,mustBePositive=True)

    @classmethod
    @optimized_version_of("decodeSpec")
    def decode(cls,s):
        """Decode, optimized version: one isqrt yields both denominators' inverses."""
        s = cls.bytesToGf(s,mustBePositive=True)

        a,d = cls.a,cls.d
        yden     = _sage_const_1 -a*s**_sage_const_2
        ynum     = _sage_const_1 +a*s**_sage_const_2
        yden_sqr = yden**_sage_const_2
        xden_sqr = a*d*ynum**_sage_const_2  - yden_sqr

        isr = isqrt(xden_sqr * yden_sqr)  # 1/sqrt(xden_sqr*yden_sqr)

        xden_inv = isr * yden                 # = +-1/sqrt(xden_sqr)
        yden_inv = xden_inv * isr * xden_sqr  # = 1/yden

        x = _sage_const_2 *s*xden_inv
        if negative(x): x = -x
        y = ynum * yden_inv

        if cls.cofactor==_sage_const_8  and (negative(x*y) or y==_sage_const_0 ):
            raise InvalidEncodingException("x*y is invalid: %d, %d" % (x,y))

        return cls(x,y)

    @classmethod
    def fromJacobiQuartic(cls,s,t,sgn=_sage_const_1 ):
        """Convert point from its Jacobi Quartic representation.

        (s, t) must satisfy t^2 = s^4 - 2*a*(1 - 2*d/(d-a))*s^2 + 1; the result is
        (sgn * 2*s*magic/t, (1+a*s^2)/(1-a*s^2)).
        """
        a,d = cls.a,cls.d
        # NOTE: this is an `assert`, so the check disappears under python -O.
        assert s**_sage_const_4  - _sage_const_2 *cls.a*(_sage_const_1 -_sage_const_2 *d/(d-a))*s**_sage_const_2  + _sage_const_1  == t**_sage_const_2
        x = _sage_const_2 *s*cls.magic / t
        y = (_sage_const_1 +a*s**_sage_const_2 ) / (_sage_const_1 -a*s**_sage_const_2 )
        return cls(sgn*x,y)

    @classmethod
    def elligatorSpec(cls,r0):
        """Unoptimized elligator: map bytes r0 to a point.

        r0 is read with high bits masked and not required to be reduced, and
        r = qnr * r0^2. Returns the identity when the denominator vanishes.
        """
        a,d = cls.a,cls.d
        r = cls.qnr * cls.bytesToGf(r0,mustBeProper=False,maskHiBits=True)**_sage_const_2
        den = (d*r-a)*(a*r-d)
        if den == _sage_const_0 : return cls()
        n1 = cls.a*(r+_sage_const_1 )*(a+d)*(d-a)/den
        n2 = r*n1
        # Exactly one of n1 or n2 = r*n1 is used: n1 if it is a square, else n2.
        if is_square(n1):
            sgn,s,t =  _sage_const_1 , xsqrt(n1), -(r-_sage_const_1 )*(a+d)**_sage_const_2  / den - _sage_const_1
        else:
            sgn,s,t = -_sage_const_1 ,-xsqrt(n2), r*(r-_sage_const_1 )*(a+d)**_sage_const_2  / den - _sage_const_1

        # NOTE(review): sgn is computed but not passed on; fromJacobiQuartic
        # uses its default sgn=1 (the optimized elligator below does the same).
        return cls.fromJacobiQuartic(s,t)

    @classmethod
    @optimized_version_of("elligatorSpec")
    def elligator(cls,r0):
        """Elligator, optimized version: a single isqrt_i replaces the is_square branch."""
        a,d = cls.a,cls.d
        r0 = cls.bytesToGf(r0,mustBeProper=False,maskHiBits=True)
        r = cls.qnr * r0**_sage_const_2
        den = (d*r-a)*(a*r-d)
        num = cls.a*(r+_sage_const_1 )*(a+d)*(d-a)

        iss,isri = isqrt_i(num*den)
        # Non-square case: multiply by r0*qnr to obtain the root for r*num/den.
        if iss: sgn,twiddle =  _sage_const_1 ,_sage_const_1
        else:   sgn,twiddle = -_sage_const_1 ,r0*cls.qnr
        isri *= twiddle
        s = isri*num
        t = -sgn*isri*s*(r-_sage_const_1 )*(d+a)**_sage_const_2  - _sage_const_1
        if negative(s) == iss: s = -s
        return cls.fromJacobiQuartic(s,t)
279
280
281class Decaf_1_1_Point(QuotientEdwardsPoint):
282    """Like current decaf but tweaked for simplicity"""
283    def encodeSpec(self):
284        """Unoptimized specification for encoding"""
285        a,d = self.a,self.d
286        x,y = self
287        if x==_sage_const_0  or y==_sage_const_0 : return(self.gfToBytes(_sage_const_0 ))
288
289        if self.cofactor==_sage_const_8  and negative(x*y*self.isoMagic):
290            x,y = self.torque()
291
292        sr = xsqrt(_sage_const_1 -a*x**_sage_const_2 )
293        altx = x*y*self.isoMagic / sr
294        if negative(altx): s = (_sage_const_1 +sr)/x
295        else:              s = (_sage_const_1 -sr)/x
296
297        return self.gfToBytes(s,mustBePositive=True)
298
299    @classmethod
300    def decodeSpec(cls,s):
301        """Unoptimized specification for decoding"""
302        a,d = cls.a,cls.d
303        s = cls.bytesToGf(s,mustBePositive=True)
304
305        if s==_sage_const_0 : return cls()
306        t = xsqrt(s**_sage_const_4  + _sage_const_2 *(a-_sage_const_2 *d)*s**_sage_const_2  + _sage_const_1 )
307        altx = _sage_const_2 *s*cls.isoMagic/t
308        if negative(altx): t = -t
309        x = _sage_const_2 *s / (_sage_const_1 +a*s**_sage_const_2 )
310        y = (_sage_const_1 -a*s**_sage_const_2 ) / t
311
312        if cls.cofactor==_sage_const_8  and (negative(x*y*cls.isoMagic) or y==_sage_const_0 ):
313            raise InvalidEncodingException("x*y is invalid: %d, %d" % (x,y))
314
315        return cls(x,y)
316
317    def toJacobiQuartic(self,toggle_rotation=False,toggle_altx=False,toggle_s=False):
318        "Return s,t on jacobi curve"
319        a,d = self.a,self.d
320        x,y,z,t = self.xyzt()
321
322        if self.cofactor == _sage_const_8 :
323            # Cofactor 8 version
324            # Simulate IMAGINE_TWIST because that's how libdecaf does it
325            x = self.i*x
326            t = self.i*t
327            a = -a
328            d = -d
329
330            # OK, the actual libdecaf code should be here
331            num = (z+y)*(z-y)
332            den = x*y
333            isr = isqrt(num*(a-d)*den**_sage_const_2 )
334
335            iden = isr * den * self.isoMagic # 1/sqrt((z+y)(z-y)) = 1/sqrt(1-Y^2) / z
336            inum = isr * num # sqrt(1-Y^2) * z / xysqrt(a-d) ~ 1/sqrt(1-ax^2)/z
337
338            if negative(iden*inum*self.i*t**_sage_const_2 *(d-a)) != toggle_rotation:
339                iden,inum = inum,iden
340                fac = x*sqrt(a)
341                toggle=(a==-_sage_const_1 )
342            else:
343                fac = y
344                toggle=False
345
346            imi = self.isoMagic * self.i
347            altx = inum*t*imi
348            neg_altx = negative(altx) != toggle_altx
349            if neg_altx != toggle: inum =- inum
350
351            tmp = fac*(inum*z + _sage_const_1 )
352            s = iden*tmp*imi
353
354            negm1 = (negative(s) != toggle_s) != neg_altx
355            if negm1: m1 = a*fac + z
356            else:     m1 = a*fac - z
357
358            swap = toggle_s
359
360        else:
361            # Much simpler cofactor 4 version
362            num = (x+t)*(x-t)
363            isr = isqrt(num*(a-d)*x**_sage_const_2 )
364            ratio = isr*num
365            altx = ratio*self.isoMagic
366
367            neg_altx = negative(altx) != toggle_altx
368            if neg_altx: ratio =- ratio
369
370            tmp = ratio*z - t
371            s = (a-d)*isr*x*tmp
372
373            negx = (negative(s) != toggle_s) != neg_altx
374            if negx: m1 = -a*t + x
375            else:    m1 = -a*t - x
376
377            swap = toggle_s
378
379        if negative(s): s = -s
380
381        return s,m1,a*tmp,swap
382
383    def invertElligator(self,toggle_r=False,*args,**kwargs):
384        "Produce preimage of self under elligator, or None"
385        a,d = self.a,self.d
386
387        rets = []
388
389        tr = [False,True] if self.cofactor == _sage_const_8  else [False]
390        for toggle_rotation in tr:
391            for toggle_altx in [False,True]:
392                for toggle_s in [False,True]:
393                    for toggle_r in [False,True]:
394                        s,m1,m12,swap = self.toJacobiQuartic(toggle_rotation,toggle_altx,toggle_s)
395
396                        #print
397                        #print toggle_rotation,toggle_altx,toggle_s
398                        #print m1
399                        #print m12
400
401
402                        if self == self.__class__():
403                            if self.cofactor == _sage_const_4 :
404                                # Hacks for identity!
405                                if toggle_altx: m12 = _sage_const_1
406                                elif toggle_s: m1 = _sage_const_1
407                                elif toggle_r: continue
408                                ## BOTH???
409
410                            else:
411                                m12 = _sage_const_1
412                                imi = self.isoMagic * self.i
413                                if toggle_rotation:
414                                    if toggle_altx: m1 = -imi
415                                    else:           m1 = +imi
416                                else:
417                                    if toggle_altx: m1 = _sage_const_0
418                                    else: m1 = a-d
419
420                        rnum = (d*a*m12-m1)
421                        rden = ((d*a-_sage_const_1 )*m12+m1)
422                        if swap: rnum,rden = rden,rnum
423
424                        ok,sr = isqrt_i(rnum*rden*self.qnr)
425                        if not ok: continue
426                        sr *= rnum
427                        #print "Works! %d %x" % (swap,sr)
428
429                        if negative(sr) != toggle_r: sr = -sr
430                        ret = self.gfToBytes(sr)
431                        if self.elligator(ret) != self and self.elligator(ret) != -self:
432                            print "WRONG!",[toggle_rotation,toggle_altx,toggle_s]
433                        if self.elligator(ret) == -self and self != -self: print "Negated!",[toggle_rotation,toggle_altx,toggle_s]
434                        rets.append(bytes(ret))
435        return rets
436
437    @optimized_version_of("encodeSpec")
438    def encode(self):
439        """Encode, optimized version"""
440        return self.gfToBytes(self.toJacobiQuartic()[_sage_const_0 ])
441
442    @classmethod
443    @optimized_version_of("decodeSpec")
444    def decode(cls,s):
445        """Decode, optimized version"""
446        a,d = cls.a,cls.d
447        s = cls.bytesToGf(s,mustBePositive=True)
448
449        #if s==0: return cls()
450        s2 = s**_sage_const_2
451        den = _sage_const_1 +a*s2
452        num = den**_sage_const_2  - _sage_const_4 *d*s2
453        isr = isqrt(num*den**_sage_const_2 )
454        altx = _sage_const_2 *s*isr*den*cls.isoMagic
455        if negative(altx): isr = -isr
456        x = _sage_const_2 *s *isr**_sage_const_2 *den*num
457        y = (_sage_const_1 -a*s**_sage_const_2 ) * isr*den
458
459        if cls.cofactor==_sage_const_8  and (negative(x*y*cls.isoMagic) or y==_sage_const_0 ):
460            raise InvalidEncodingException("x*y is invalid: %d, %d" % (x,y))
461
462        return cls(x,y)
463
464    @classmethod
465    def fromJacobiQuartic(cls,s,t,sgn=_sage_const_1 ):
466        """Convert point from its Jacobi Quartic representation"""
467        a,d = cls.a,cls.d
468        if s==_sage_const_0 : return cls()
469        x = _sage_const_2 *s / (_sage_const_1 +a*s**_sage_const_2 )
470        y = (_sage_const_1 -a*s**_sage_const_2 ) / t
471        return cls(x,sgn*y)
472
473    @classmethod
474    def elligatorSpec(cls,r0,fromR=False):
475        a,d = cls.a,cls.d
476        if fromR: r = r0
477        else: r = cls.qnr * cls.bytesToGf(r0,mustBeProper=False,maskHiBits=True)**_sage_const_2
478
479        den = (d*r-(d-a))*((d-a)*r-d)
480        if den == _sage_const_0 : return cls()
481        n1 = (r+_sage_const_1 )*(a-_sage_const_2 *d)/den
482        n2 = r*n1
483        if is_square(n1):
484            sgn,s,t = _sage_const_1 ,   xsqrt(n1),  -(r-_sage_const_1 )*(a-_sage_const_2 *d)**_sage_const_2  / den - _sage_const_1
485        else:
486            sgn,s,t = -_sage_const_1 , -xsqrt(n2), r*(r-_sage_const_1 )*(a-_sage_const_2 *d)**_sage_const_2  / den - _sage_const_1
487
488        return cls.fromJacobiQuartic(s,t)
489
490    @classmethod
491    @optimized_version_of("elligatorSpec")
492    def elligator(cls,r0):
493        a,d = cls.a,cls.d
494        r0 = cls.bytesToGf(r0,mustBeProper=False,maskHiBits=True)
495        r = cls.qnr * r0**_sage_const_2
496        den = (d*r-(d-a))*((d-a)*r-d)
497        num = (r+_sage_const_1 )*(a-_sage_const_2 *d)
498
499        iss,isri = isqrt_i(num*den)
500        if iss: sgn,twiddle =  _sage_const_1 ,_sage_const_1
501        else:   sgn,twiddle = -_sage_const_1 ,r0*cls.qnr
502        isri *= twiddle
503        s = isri*num
504        t = -sgn*isri*s*(r-_sage_const_1 )*(a-_sage_const_2 *d)**_sage_const_2  - _sage_const_1
505        if negative(s) == iss: s = -s
506        return cls.fromJacobiQuartic(s,t)
507
508    def elligatorInverseBruteForce(self):
509        """Invert Elligator using SAGE's polynomial solver"""
510        a,d = self.a,self.d
511        R = self.F['r0']; (r0,) = R._first_ngens(1)
512        r = self.qnr * r0**_sage_const_2
513        den = (d*r-(d-a))*((d-a)*r-d)
514        n1 = (r+_sage_const_1 )*(a-_sage_const_2 *d)/den
515        n2 = r*n1
516        ret = set()
517        for s2,t in [(n1, -(r-_sage_const_1 )*(a-_sage_const_2 *d)**_sage_const_2  / den - _sage_const_1 ),
518                     (n2,r*(r-_sage_const_1 )*(a-_sage_const_2 *d)**_sage_const_2  / den - _sage_const_1 )]:
519            x2 = _sage_const_4 *s2/(_sage_const_1 +a*s2)**_sage_const_2
520            y = (_sage_const_1 -a*s2) / t
521
522            selfT = self
523            for i in xrange(self.cofactor/_sage_const_2 ):
524                xT,yT = selfT
525                polyX = xT**_sage_const_2 -x2
526                polyY = yT-y
527                sx = set(r for r,_ in polyX.numerator().roots())
528                sy = set(r for r,_ in polyY.numerator().roots())
529                ret = ret.union(sx.intersection(sy))
530
531                selfT = selfT.torque()
532
533        ret = [self.gfToBytes(r) for r in ret]
534
535        for r in ret:
536            assert self.elligator(r) in [self,-self]
537
538        ret = [r for r in ret if self.elligator(r) == self]
539
540        return ret
541
class Ed25519Point(RistrettoPoint):
    """Ristretto over -x^2 + y^2 = 1 + d*x^2*y^2 on GF(2^255-19), d = -121665/121666 (cofactor 8)."""
    F = GF(_sage_const_2 **_sage_const_255 -_sage_const_19 )
    d = F(-_sage_const_121665 /_sage_const_121666 )
    a = F(-_sage_const_1 )
    i = sqrt(F(-_sage_const_1 ))  # a square root of -1 (exists since p = 1 mod 4)
    mneg = F(_sage_const_1 )
    qnr = i  # quadratic non-residue used by elligator (r = qnr * r0^2)
    magic = isqrt(a*d-_sage_const_1 )  # 1/sqrt(a*d - 1)
    cofactor = _sage_const_8
    encLen = _sage_const_32  # encodings are 32 bytes

    @classmethod
    def base(cls):
        """Return the point with the hard-coded coordinates below
        (presumably the standard Ed25519 generator from RFC 8032 -- confirm)."""
        return cls( _sage_const_15112221349535400772501151409588531511454012693041857206046113283949847762202 , _sage_const_46316835694926478169428394003475163141307993866256225615783033603165251855960
        )
557
class NegEd25519Point(RistrettoPoint):
    """Ristretto over x^2 + y^2 = 1 + d*x^2*y^2 on GF(2^255-19), d = 121665/121666 (cofactor 8)."""
    F = GF(_sage_const_2 **_sage_const_255 -_sage_const_19 )
    d = F(_sage_const_121665 /_sage_const_121666 )
    a = F(_sage_const_1 )
    i = sqrt(F(-_sage_const_1 ))  # a square root of -1
    mneg = F(-_sage_const_1 ) # TODO checkme vs 1-ad or whatever
    qnr = i  # quadratic non-residue used by elligator
    magic = isqrt(a*d-_sage_const_1 )  # 1/sqrt(a*d - 1)
    cofactor = _sage_const_8
    encLen = _sage_const_32

    @classmethod
    def base(cls):
        """Return the point with y = 4/5 and the non-negative x on this curve."""
        y = cls.F(_sage_const_4 /_sage_const_5 )
        # From a*x^2 + y^2 = 1 + d*x^2*y^2:  x^2 = (y^2 - 1)/(d*y^2 - a).
        x = sqrt((y**_sage_const_2 -_sage_const_1 )/(cls.d*y**_sage_const_2 -cls.a))
        if negative(x): x = -x
        return cls(x,y)
575
class IsoEd448Point(RistrettoPoint):
    """Ristretto over x^2 + y^2 = 1 + d*x^2*y^2 on GF(2^448-2^224-1), d = 39082/39081 (cofactor 4)."""
    F = GF(_sage_const_2 **_sage_const_448 -_sage_const_2 **_sage_const_224 -_sage_const_1 )
    d = F(_sage_const_39082 /_sage_const_39081 )
    a = F(_sage_const_1 )
    mneg = F(-_sage_const_1 )
    qnr = -_sage_const_1  # -1 is a non-square here since p = 3 mod 4
    magic = isqrt(a*d-_sage_const_1 )  # 1/sqrt(a*d - 1); also reused as isoMagic by the Ed448 Decaf classes
    cofactor = _sage_const_4
    encLen = _sage_const_56  # encodings are 56 bytes

    @classmethod
    def base(cls):
        """Return the point with the hard-coded coordinates below."""
        return cls(  # RFC has it wrong
         _sage_const_345397493039729516374008604150537410266655260075183290216406970281645695073672344430481787759340633221708391583424041788924124567700732 ,
            -_sage_const_363419362147803445274661903944002267176820680343659030140745099590306164083365386343198191849338272965044442230921818680526749009182718
        )
592
class TwistedEd448GoldilocksPoint(Decaf_1_1_Point):
    """Decaf over -x^2 + y^2 = 1 + d*x^2*y^2 on GF(2^448-2^224-1), d = -39082 (cofactor 4)."""
    F = GF(_sage_const_2 **_sage_const_448 -_sage_const_2 **_sage_const_224 -_sage_const_1 )
    d = F(-_sage_const_39082 )
    a = F(-_sage_const_1 )
    qnr = -_sage_const_1  # -1 is a non-square here since p = 3 mod 4
    cofactor = _sage_const_4
    encLen = _sage_const_56
    isoMagic = IsoEd448Point.magic

    @classmethod
    def base(cls):
        """Base point: decode, on this curve, the spec encoding of Ed448GoldilocksPoint.base()."""
        return cls.decodeSpec(Ed448GoldilocksPoint.base().encodeSpec())
605
class Ed448GoldilocksPoint(Decaf_1_1_Point):
    """Decaf over x^2 + y^2 = 1 + d*x^2*y^2 on GF(2^448-2^224-1), d = -39081 (cofactor 4)."""
    F = GF(_sage_const_2 **_sage_const_448 -_sage_const_2 **_sage_const_224 -_sage_const_1 )
    d = F(-_sage_const_39081 )
    a = F(_sage_const_1 )
    qnr = -_sage_const_1  # -1 is a non-square here since p = 3 mod 4
    cofactor = _sage_const_4
    encLen = _sage_const_56
    isoMagic = IsoEd448Point.magic

    @classmethod
    def base(cls):
        """Return twice the point with the hard-coded coordinates below."""
        return _sage_const_2 *cls(
 _sage_const_224580040295924300187604334099896036246789641632564134246125461686950415467406032909029192869357953282578032075146446173674602635247710 , _sage_const_298819210078481492676017930443930673437544040154080242095928241372331506189835876003536878655418784733982303233503462500531545062832660
        )
620
class IsoEd25519Point(Decaf_1_1_Point):
    """Decaf over x^2 + y^2 = 1 + d*x^2*y^2 on GF(2^255-19), d = -121665 (cofactor 8)."""
    # TODO: twisted iso too!
    # TODO: twisted iso might have to IMAGINE_TWIST or whatever
    F = GF(_sage_const_2 **_sage_const_255 -_sage_const_19 )
    d = F(-_sage_const_121665 )
    a = F(_sage_const_1 )
    i = sqrt(F(-_sage_const_1 ))  # a square root of -1
    qnr = i  # quadratic non-residue used by elligator
    magic = isqrt(a*d-_sage_const_1 )  # 1/sqrt(a*d - 1)
    cofactor = _sage_const_8
    encLen = _sage_const_32
    isoMagic = Ed25519Point.magic
    isoA = Ed25519Point.a  # not referenced by the code in this file

    @classmethod
    def base(cls):
        """Base point: decode, on this curve, the Ristretto encoding of Ed25519Point.base()."""
        return cls.decodeSpec(Ed25519Point.base().encode())
638
class TestFailedException(Exception):
    """Raised by the self-tests in this file when a check fails."""
640
641def test(cls,n):
642    print "Testing curve %s" % cls.__name__
643
644    specials = [_sage_const_1 ]
645    ii = cls.F(-_sage_const_1 )
646    while is_square(ii):
647        specials.append(ii)
648        ii = sqrt(ii)
649    specials.append(ii)
650    for i in specials:
651        if negative(cls.F(i)): i = -i
652        i = enc_le(i,cls.encLen)
653        try:
654            Q = cls.decode(i)
655            QE = Q.encode()
656            if QE != i:
657                raise TestFailedException("Round trip special %s != %s" %
658                     (binascii.hexlify(QE),binascii.hexlify(i)))
659        except NotOnCurveException: pass
660        except InvalidEncodingException: pass
661
662
663    P = cls.base()
664    Q = cls()
665    for i in xrange(n):
666        #print binascii.hexlify(Q.encode())
667        QE = Q.encode()
668        QQ = cls.decode(QE)
669        if QQ != Q: raise TestFailedException("Round trip %s != %s" % (str(QQ),str(Q)))
670
671        # Testing s -> 1/s: encodes -point on cofactor
672        s = cls.bytesToGf(QE)
673        if s != _sage_const_0 :
674            ss = cls.gfToBytes(_sage_const_1 /s,mustBePositive=True)
675            try:
676                QN = cls.decode(ss)
677                if cls.cofactor == _sage_const_8 :
678                    raise TestFailedException("1/s shouldnt work for cofactor 8")
679                if QN != -Q:
680                    raise TestFailedException("s -> 1/s should negate point for cofactor 4")
681            except InvalidEncodingException as e:
682                # Should be raised iff cofactor==8
683                if cls.cofactor == _sage_const_4 :
684                    raise TestFailedException("s -> 1/s should work for cofactor 4")
685
686        QT = Q
687        for h in xrange(cls.cofactor):
688            QT = QT.torque()
689            if QT.encode() != QE:
690                raise TestFailedException("Can't torque %s,%d" % (str(Q),h+_sage_const_1 ))
691
692        Q0 = Q + P
693        if Q0 == Q: raise TestFailedException("Addition doesn't work")
694        if Q0-P != Q: raise TestFailedException("Subtraction doesn't work")
695
696        r = randint(_sage_const_1 ,_sage_const_1000 )
697        Q1 = Q0*r
698        Q2 = Q0*(r+_sage_const_1 )
699        if Q1 + Q0 != Q2: raise TestFailedException("Scalarmul doesn't work")
700        Q = Q1
701
702
703def testElligator(cls,n):
704    print "Testing elligator on %s" % cls.__name__
705    for i in xrange(n):
706        r = randombytes(cls.encLen)
707        P = cls.elligator(r)
708        if hasattr(P,"invertElligator"):
709            iv = P.invertElligator()
710            modr = bytes(cls.gfToBytes(cls.bytesToGf(r,mustBeProper=False)))
711            iv2 = P.torque().invertElligator()
712            if modr not in iv: print "Failed to invert Elligator!"
713            if len(iv) != len(set(iv)):
714                print "Elligator inverses not unique!", len(set(iv)), len(iv)
715            if iv != iv2:
716                print "Elligator is untorqueable!"
717                #print [binascii.hexlify(j) for j in iv]
718                #print [binascii.hexlify(j) for j in iv2]
719                #break
720        else:
721            pass # TODO
722
723
724
725
726def gangtest(classes,n):
727    print "Gang test",[cls.__name__ for cls in classes]
728    specials = [_sage_const_1 ]
729    ii = classes[_sage_const_0 ].F(-_sage_const_1 )
730    while is_square(ii):
731        specials.append(ii)
732        ii = sqrt(ii)
733    specials.append(ii)
734
735    for i in xrange(n):
736        rets = [bytes((cls.base()*i).encode()) for cls in classes]
737        if len(set(rets)) != _sage_const_1 :
738            print "Divergence in encode at %d" % i
739            for c,ret in zip(classes,rets):
740                print c,binascii.hexlify(ret)
741            print
742
743        if i < len(specials): r0 = enc_le(specials[i],classes[_sage_const_0 ].encLen)
744        else: r0 = randombytes(classes[_sage_const_0 ].encLen)
745
746        rets = [bytes((cls.elligator(r0)*i).encode()) for cls in classes]
747        if len(set(rets)) != _sage_const_1 :
748            print "Divergence in elligator at %d" % i
749            for c,ret in zip(classes,rets):
750                print c,binascii.hexlify(ret)
751            print
752
753
# Script entry: runs on import/execution. Encode/decode and group-law self-tests for each curve model.
test(Ed25519Point,_sage_const_100 )
test(NegEd25519Point,_sage_const_100 )
test(IsoEd25519Point,_sage_const_100 )
test(IsoEd448Point,_sage_const_100 )
test(TwistedEd448GoldilocksPoint,_sage_const_100 )
test(Ed448GoldilocksPoint,_sage_const_100 )
# Elligator checks (the inversion part only runs for classes defining invertElligator).
testElligator(Ed25519Point,_sage_const_100 )
testElligator(NegEd25519Point,_sage_const_100 )
testElligator(IsoEd25519Point,_sage_const_100 )
testElligator(IsoEd448Point,_sage_const_100 )
testElligator(Ed448GoldilocksPoint,_sage_const_100 )
testElligator(TwistedEd448GoldilocksPoint,_sage_const_100 )
# Cross-check that each group of classes produces identical encodings.
gangtest([IsoEd448Point,TwistedEd448GoldilocksPoint,Ed448GoldilocksPoint],_sage_const_100 )
gangtest([Ed25519Point,IsoEd25519Point],_sage_const_100 )
768