1import (
2	"math/big"
3	"math/bits"
4	"runtime"
5	"sync"
6	"io"
7
8	{{ template "import_fr" . }}
9	{{ template "import_curve" . }}
10)
11
12// Domain with a power of 2 cardinality
13// compute a field element of order 2x and store it in FinerGenerator
14// all other values can be derived from x, GeneratorSqrt
15type Domain struct {
16	Cardinality       uint64
17	Depth          uint64
18	CardinalityInv    fr.Element
19	Generator         fr.Element
20	GeneratorInv      fr.Element
21	FinerGenerator    fr.Element
22	FinerGeneratorInv fr.Element
23
24	// the following slices are not serialized and are (re)computed through domain.preComputeTwiddles()
25
26	// Twiddles factor for the FFT using Generator for each stage of the recursive FFT
27	Twiddles [][]fr.Element
28
29	// Twiddles factor for the FFT using GeneratorInv for each stage of the recursive FFT
30	TwiddlesInv [][]fr.Element
31
32	// we precompute these mostly to avoid the memory intensive bit reverse permutation in the groth16.Prover
33
34	// CosetTable[i][j] = domain.Generator(i-th)Sqrt ^ j
35	// CosetTable = fft.BitReverse(CosetTable)
36	CosetTable [][]fr.Element
37
38	// CosetTable[i][j] = domain.Generator(i-th)SqrtInv ^ j
39	// CosetTableInv = fft.BitReverse(CosetTableInv)
40	CosetTableInv [][]fr.Element
41}
42
43// NewDomain returns a subgroup with a power of 2 cardinality
44// cardinality >= m
45// If depth>0, the Domain will also store a primitive (2**depth)*m root
46// of 1, with associated precomputed data. This allows to perform shifted
47// FFT/FFTInv.
48//
49// example:
50// --------
51//
52// NewDomain(m, 2) outputs a new domain to perform fft on Z/mZ, plus a primitive
53// 2**2*m=4m-th root of 1 and associated data to compute fft/fftinv on the cosets of
54// (Z/4mZ)/(Z/mZ).
55func NewDomain(m, depth uint64) *Domain {
56
57	// generator of the largest 2-adic subgroup
58	var rootOfUnity fr.Element
59	{{if eq .Name "bls12-377"}}
60		rootOfUnity.SetString("8065159656716812877374967518403273466521432693661810619979959746626482506078")
61		const maxOrderRoot uint64 = 47
62	{{else if eq .Name "bls12-381"}}
63		rootOfUnity.SetString("10238227357739495823651030575849232062558860180284477541189508159991286009131")
64		const maxOrderRoot uint64 = 32
65	{{else if eq .Name "bn254"}}
66		rootOfUnity.SetString("19103219067921713944291392827692070036145651957329286315305642004821462161904")
67		const maxOrderRoot uint64 = 28
68	{{else if eq .Name "bw6-761"}}
69		rootOfUnity.SetString("32863578547254505029601261939868325669770508939375122462904745766352256812585773382134936404344547323199885654433")
70		const maxOrderRoot uint64 = 46
71	{{end}}
72
73		subGroup := &Domain{}
74	x := nextPowerOfTwo(m)
75	subGroup.Cardinality = uint64(x)
76	subGroup.Depth = depth
77
78	// find generator for Z/2^(log(m))Z  and Z/2^(log(m)+cosets)Z
79	logx := uint64(bits.TrailingZeros64(x))
80	if logx > maxOrderRoot {
81		panic("m is too big: the required root of unity does not exist")
82	}
83	logGen := logx + depth
84	if logGen > maxOrderRoot {
85		panic("log(m) + cosets is too big: the required root of unity does not exist")
86	}
87
88	expo := uint64(1 << (maxOrderRoot - logGen))
89	bExpo := new(big.Int).SetUint64(expo)
90	subGroup.FinerGenerator.Exp(rootOfUnity, bExpo)
91	subGroup.FinerGeneratorInv.Inverse(&subGroup.FinerGenerator)
92
93	// Generator = FinerGenerator^2 has order x
94	expo = uint64(1 << (maxOrderRoot - logx))
95	bExpo.SetUint64(expo)
96	subGroup.Generator.Exp(rootOfUnity, bExpo) // order x
97	subGroup.GeneratorInv.Inverse(&subGroup.Generator)
98	subGroup.CardinalityInv.SetUint64(uint64(x)).Inverse(&subGroup.CardinalityInv)
99
100	// twiddle factors
101	subGroup.preComputeTwiddles()
102
103	return subGroup
104}
105
106func (d *Domain) preComputeTwiddles() {
107
108	// nb fft stages
109	nbStages := uint64(bits.TrailingZeros64(d.Cardinality))
110	nbCosets := (1 << d.Depth) - 1
111
112	d.Twiddles = make([][]fr.Element, nbStages)
113	d.TwiddlesInv = make([][]fr.Element, nbStages)
114	d.CosetTable = make([][]fr.Element, nbCosets)
115	d.CosetTableInv = make([][]fr.Element, nbCosets)
116	for i := 0; i < nbCosets; i++ {
117		d.CosetTable[i] = make([]fr.Element, d.Cardinality)
118		d.CosetTableInv[i] = make([]fr.Element, d.Cardinality)
119	}
120
121	var wg sync.WaitGroup
122
123	// for each fft stage, we pre compute the twiddle factors
124	twiddles := func(t [][]fr.Element, omega fr.Element) {
125		for i := uint64(0); i < nbStages; i++ {
126			t[i] = make([]fr.Element, 1+(1<<(nbStages-i-1)))
127			var w fr.Element
128			if i == 0 {
129				w = omega
130			} else {
131				w = t[i-1][2]
132			}
133			t[i][0] = fr.One()
134			t[i][1] = w
135			for j := 2; j < len(t[i]); j++ {
136				t[i][j].Mul(&t[i][j-1], &w)
137			}
138		}
139		wg.Done()
140	}
141
142	expTable := func(sqrt fr.Element, t []fr.Element) {
143		t[0] = fr.One()
144		precomputeExpTable(sqrt, t)
145		wg.Done()
146	}
147
148	if nbCosets > 0 {
149		cosetGens := make([]fr.Element, nbCosets)
150		cosetGensInv := make([]fr.Element, nbCosets)
151		cosetGens[0].Set(&d.FinerGenerator)
152		cosetGensInv[0].Set(&d.FinerGeneratorInv)
153		for i := 1; i < nbCosets; i++ {
154			cosetGens[i].Mul(&cosetGens[i-1], &d.FinerGenerator)
155			cosetGensInv[i].Mul(&cosetGensInv[1], &d.FinerGeneratorInv)
156		}
157		wg.Add(2 + 2*nbCosets)
158		go twiddles(d.Twiddles, d.Generator)
159		go twiddles(d.TwiddlesInv, d.GeneratorInv)
160		for i := 0; i < nbCosets-1; i++ {
161			go expTable(cosetGens[i], d.CosetTable[i])
162			go expTable(cosetGensInv[i], d.CosetTableInv[i])
163		}
164		go expTable(cosetGens[nbCosets-1], d.CosetTable[nbCosets-1])
165		expTable(cosetGensInv[nbCosets-1], d.CosetTableInv[nbCosets-1])
166
167		wg.Wait()
168
169	} else {
170		wg.Add(2)
171		go twiddles(d.Twiddles, d.Generator)
172		twiddles(d.TwiddlesInv, d.GeneratorInv)
173		wg.Wait()
174	}
175
176}
177
178func precomputeExpTable(w fr.Element, table []fr.Element) {
179	n := len(table)
180
181	// see if it makes sense to parallelize exp tables pre-computation
182	interval := 0
183	if runtime.NumCPU() >= 4 {
184		interval = (n - 1) / (runtime.NumCPU() / 4)
185	}
186
187	// this ratio roughly correspond to the number of multiplication one can do in place of a Exp operation
188	const ratioExpMul = 6000 / 17
189
190	if interval < ratioExpMul {
191		precomputeExpTableChunk(w, 1, table[1:])
192		return
193	}
194
195	// we parallelize
196	var wg sync.WaitGroup
197	for i := 1; i < n; i += interval {
198		start := i
199		end := i + interval
200		if end > n {
201			end = n
202		}
203		wg.Add(1)
204		go func() {
205			precomputeExpTableChunk(w, uint64(start), table[start:end])
206			wg.Done()
207		}()
208	}
209	wg.Wait()
210}
211
212func precomputeExpTableChunk(w fr.Element, power uint64, table []fr.Element) {
213	table[0].Exp(w, new(big.Int).SetUint64(power))
214	for i := 1; i < len(table); i++ {
215		table[i].Mul(&table[i-1], &w)
216	}
217}
218
// nextPowerOfTwo returns the smallest power of 2 that is >= n.
//
// n is returned unchanged when it is already a power of 2, and also when it is
// 0. When the result does not fit in a uint64 (n > 2^63), 0 is returned.
func nextPowerOfTwo(n uint64) uint64 {
	// zero, or a single bit set: nothing to round up
	if n&(n-1) == 0 {
		return n
	}

	// n is not a power of 2, so the next one is 2^(bit length of n)
	shift := bits.Len64(n)
	if shift == 64 {
		// 2^64 is not representable; report the overflow as 0 instead of
		// looping forever on a wrapped-around accumulator
		return 0
	}
	return uint64(1) << uint(shift)
}
229
230// WriteTo writes a binary representation of the domain (without the precomputed twiddle factors)
231// to the provided writer
232func (d *Domain) WriteTo(w io.Writer) (int64, error) {
233
234	enc := curve.NewEncoder(w)
235
236	toEncode := []interface{}{d.Cardinality, d.Depth, &d.CardinalityInv, &d.Generator, &d.GeneratorInv, &d.FinerGenerator, &d.FinerGeneratorInv}
237
238	for _, v := range toEncode {
239		if err := enc.Encode(v); err != nil {
240			return enc.BytesWritten(), err
241		}
242	}
243
244	return enc.BytesWritten(), nil
245}
246
247// ReadFrom attempts to decode a domain from Reader
248func (d *Domain) ReadFrom(r io.Reader) (int64, error) {
249
250	dec := curve.NewDecoder(r)
251
252	toDecode := []interface{}{&d.Cardinality, &d.Depth, &d.CardinalityInv, &d.Generator, &d.GeneratorInv, &d.FinerGenerator, &d.FinerGeneratorInv}
253
254	for _, v := range toDecode {
255		if err := dec.Decode(v); err != nil {
256			return dec.BytesRead(), err
257		}
258	}
259
260	d.preComputeTwiddles()
261	return dec.BytesRead(), nil
262}
263