1// Copyright 2012 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package bn256
6
7import (
8	"math/big"
9)
10
// curvePoint implements the elliptic curve y²=x³+3. Points are stored in
// Jacobian form (x : y : z), standing for the affine point (x/z², y/z³);
// z = 0 denotes the point at infinity. G₁ is the set of points of this curve
// on GF(p).
type curvePoint struct {
	x *big.Int
	y *big.Int
	z *big.Int
	// t is meant to cache z², but Add and Double never update it; MakeAffine
	// sets it to 1 and Negative clears it to 0. Within this file it is only
	// copied (Set) and recycled (Put), never read for arithmetic.
	t *big.Int
}
17
// curveB is the constant b in the curve equation y² = x³ + b.
var curveB = big.NewInt(3)
19
20// curveGen is the generator of G₁.
21var curveGen = &curvePoint{
22	new(big.Int).SetInt64(1),
23	new(big.Int).SetInt64(-2),
24	new(big.Int).SetInt64(1),
25	new(big.Int).SetInt64(1),
26}
27
28func newCurvePoint(pool *bnPool) *curvePoint {
29	return &curvePoint{
30		pool.Get(),
31		pool.Get(),
32		pool.Get(),
33		pool.Get(),
34	}
35}
36
37func (c *curvePoint) String() string {
38	c.MakeAffine(new(bnPool))
39	return "(" + c.x.String() + ", " + c.y.String() + ")"
40}
41
42func (c *curvePoint) Put(pool *bnPool) {
43	pool.Put(c.x)
44	pool.Put(c.y)
45	pool.Put(c.z)
46	pool.Put(c.t)
47}
48
49func (c *curvePoint) Set(a *curvePoint) {
50	c.x.Set(a.x)
51	c.y.Set(a.y)
52	c.z.Set(a.z)
53	c.t.Set(a.t)
54}
55
56// IsOnCurve returns true iff c is on the curve where c must be in affine form.
57func (c *curvePoint) IsOnCurve() bool {
58	yy := new(big.Int).Mul(c.y, c.y)
59	xxx := new(big.Int).Mul(c.x, c.x)
60	xxx.Mul(xxx, c.x)
61	yy.Sub(yy, xxx)
62	yy.Sub(yy, curveB)
63	if yy.Sign() < 0 || yy.Cmp(p) >= 0 {
64		yy.Mod(yy, p)
65	}
66	return yy.Sign() == 0
67}
68
69func (c *curvePoint) SetInfinity() {
70	c.z.SetInt64(0)
71}
72
73func (c *curvePoint) IsInfinity() bool {
74	return c.z.Sign() == 0
75}
76
77func (c *curvePoint) Add(a, b *curvePoint, pool *bnPool) {
78	if a.IsInfinity() {
79		c.Set(b)
80		return
81	}
82	if b.IsInfinity() {
83		c.Set(a)
84		return
85	}
86
87	// See http://hyperelliptic.org/EFD/g1p/auto-code/shortw/jacobian-0/addition/add-2007-bl.op3
88
89	// Normalize the points by replacing a = [x1:y1:z1] and b = [x2:y2:z2]
90	// by [u1:s1:z1·z2] and [u2:s2:z1·z2]
91	// where u1 = x1·z2², s1 = y1·z2³ and u1 = x2·z1², s2 = y2·z1³
92	z1z1 := pool.Get().Mul(a.z, a.z)
93	z1z1.Mod(z1z1, p)
94	z2z2 := pool.Get().Mul(b.z, b.z)
95	z2z2.Mod(z2z2, p)
96	u1 := pool.Get().Mul(a.x, z2z2)
97	u1.Mod(u1, p)
98	u2 := pool.Get().Mul(b.x, z1z1)
99	u2.Mod(u2, p)
100
101	t := pool.Get().Mul(b.z, z2z2)
102	t.Mod(t, p)
103	s1 := pool.Get().Mul(a.y, t)
104	s1.Mod(s1, p)
105
106	t.Mul(a.z, z1z1)
107	t.Mod(t, p)
108	s2 := pool.Get().Mul(b.y, t)
109	s2.Mod(s2, p)
110
111	// Compute x = (2h)²(s²-u1-u2)
112	// where s = (s2-s1)/(u2-u1) is the slope of the line through
113	// (u1,s1) and (u2,s2). The extra factor 2h = 2(u2-u1) comes from the value of z below.
114	// This is also:
115	// 4(s2-s1)² - 4h²(u1+u2) = 4(s2-s1)² - 4h³ - 4h²(2u1)
116	//                        = r² - j - 2v
117	// with the notations below.
118	h := pool.Get().Sub(u2, u1)
119	xEqual := h.Sign() == 0
120
121	t.Add(h, h)
122	// i = 4h²
123	i := pool.Get().Mul(t, t)
124	i.Mod(i, p)
125	// j = 4h³
126	j := pool.Get().Mul(h, i)
127	j.Mod(j, p)
128
129	t.Sub(s2, s1)
130	yEqual := t.Sign() == 0
131	if xEqual && yEqual {
132		c.Double(a, pool)
133		return
134	}
135	r := pool.Get().Add(t, t)
136
137	v := pool.Get().Mul(u1, i)
138	v.Mod(v, p)
139
140	// t4 = 4(s2-s1)²
141	t4 := pool.Get().Mul(r, r)
142	t4.Mod(t4, p)
143	t.Add(v, v)
144	t6 := pool.Get().Sub(t4, j)
145	c.x.Sub(t6, t)
146
147	// Set y = -(2h)³(s1 + s*(x/4h²-u1))
148	// This is also
149	// y = - 2·s1·j - (s2-s1)(2x - 2i·u1) = r(v-x) - 2·s1·j
150	t.Sub(v, c.x) // t7
151	t4.Mul(s1, j) // t8
152	t4.Mod(t4, p)
153	t6.Add(t4, t4) // t9
154	t4.Mul(r, t)   // t10
155	t4.Mod(t4, p)
156	c.y.Sub(t4, t6)
157
158	// Set z = 2(u2-u1)·z1·z2 = 2h·z1·z2
159	t.Add(a.z, b.z) // t11
160	t4.Mul(t, t)    // t12
161	t4.Mod(t4, p)
162	t.Sub(t4, z1z1) // t13
163	t4.Sub(t, z2z2) // t14
164	c.z.Mul(t4, h)
165	c.z.Mod(c.z, p)
166
167	pool.Put(z1z1)
168	pool.Put(z2z2)
169	pool.Put(u1)
170	pool.Put(u2)
171	pool.Put(t)
172	pool.Put(s1)
173	pool.Put(s2)
174	pool.Put(h)
175	pool.Put(i)
176	pool.Put(j)
177	pool.Put(r)
178	pool.Put(v)
179	pool.Put(t4)
180	pool.Put(t6)
181}
182
183func (c *curvePoint) Double(a *curvePoint, pool *bnPool) {
184	// See http://hyperelliptic.org/EFD/g1p/auto-code/shortw/jacobian-0/doubling/dbl-2009-l.op3
185	A := pool.Get().Mul(a.x, a.x)
186	A.Mod(A, p)
187	B := pool.Get().Mul(a.y, a.y)
188	B.Mod(B, p)
189	C := pool.Get().Mul(B, B)
190	C.Mod(C, p)
191
192	t := pool.Get().Add(a.x, B)
193	t2 := pool.Get().Mul(t, t)
194	t2.Mod(t2, p)
195	t.Sub(t2, A)
196	t2.Sub(t, C)
197	d := pool.Get().Add(t2, t2)
198	t.Add(A, A)
199	e := pool.Get().Add(t, A)
200	f := pool.Get().Mul(e, e)
201	f.Mod(f, p)
202
203	t.Add(d, d)
204	c.x.Sub(f, t)
205
206	t.Add(C, C)
207	t2.Add(t, t)
208	t.Add(t2, t2)
209	c.y.Sub(d, c.x)
210	t2.Mul(e, c.y)
211	t2.Mod(t2, p)
212	c.y.Sub(t2, t)
213
214	t.Mul(a.y, a.z)
215	t.Mod(t, p)
216	c.z.Add(t, t)
217
218	pool.Put(A)
219	pool.Put(B)
220	pool.Put(C)
221	pool.Put(t)
222	pool.Put(t2)
223	pool.Put(d)
224	pool.Put(e)
225	pool.Put(f)
226}
227
228func (c *curvePoint) Mul(a *curvePoint, scalar *big.Int, pool *bnPool) *curvePoint {
229	sum := newCurvePoint(pool)
230	sum.SetInfinity()
231	t := newCurvePoint(pool)
232
233	for i := scalar.BitLen(); i >= 0; i-- {
234		t.Double(sum, pool)
235		if scalar.Bit(i) != 0 {
236			sum.Add(t, a, pool)
237		} else {
238			sum.Set(t)
239		}
240	}
241
242	c.Set(sum)
243	sum.Put(pool)
244	t.Put(pool)
245	return c
246}
247
248// MakeAffine converts c to affine form and returns c. If c is ∞, then it sets
249// c to 0 : 1 : 0.
250func (c *curvePoint) MakeAffine(pool *bnPool) *curvePoint {
251	if words := c.z.Bits(); len(words) == 1 && words[0] == 1 {
252		return c
253	}
254	if c.IsInfinity() {
255		c.x.SetInt64(0)
256		c.y.SetInt64(1)
257		c.z.SetInt64(0)
258		c.t.SetInt64(0)
259		return c
260	}
261
262	zInv := pool.Get().ModInverse(c.z, p)
263	t := pool.Get().Mul(c.y, zInv)
264	t.Mod(t, p)
265	zInv2 := pool.Get().Mul(zInv, zInv)
266	zInv2.Mod(zInv2, p)
267	c.y.Mul(t, zInv2)
268	c.y.Mod(c.y, p)
269	t.Mul(c.x, zInv2)
270	t.Mod(t, p)
271	c.x.Set(t)
272	c.z.SetInt64(1)
273	c.t.SetInt64(1)
274
275	pool.Put(zInv)
276	pool.Put(t)
277	pool.Put(zInv2)
278
279	return c
280}
281
282func (c *curvePoint) Negative(a *curvePoint) {
283	c.x.Set(a.x)
284	c.y.Neg(a.y)
285	c.z.Set(a.z)
286	c.t.SetInt64(0)
287}
288