1// Copyright 2020 ConsenSys Software Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15// Code generated by consensys/gnark-crypto DO NOT EDIT
16
17package fft
18
19import (
20	"math/bits"
21	"runtime"
22
23	"github.com/consensys/gnark-crypto/internal/parallel"
24
25	"github.com/consensys/gnark-crypto/ecc/bls12-381/fr"
26)
27
// Decimation tells FFT and FFTInverse which butterfly ordering to use:
// decimation in time or decimation in frequency.
type Decimation uint8

// Supported decimation modes.
const (
	// DIT is decimation in time: the input is expected in bit-reversed order.
	DIT = Decimation(iota)
	// DIF is decimation in frequency: the output is produced in bit-reversed order.
	DIF
)
35
// butterflyThreshold is the half-length (len(a)/2) that a stage must exceed
// before difFFT / ditFFT spread its butterflies across several workers; this
// only applies to stages that are still below maxSplits.
const butterflyThreshold = 1 << 4
38
39// FFT computes (recursively) the discrete Fourier transform of a and stores the result in a
40// if decimation == DIT (decimation in time), the input must be in bit-reversed order
41// if decimation == DIF (decimation in frequency), the output will be in bit-reversed order
42// coset sets the shift of the fft (0 = no shift, standard fft)
43// len(a) must be a power of 2, and w must be a len(a)th root of unity in field F.
44//
45// example:
46// -------
47// domain := NewDomain(m, 2) -->  contains precomputed data for Z/mZ, and Z/4mZ
48// FFT(pol, DIT, 1) --> evaluates pol on the coset 1 in (Z/4mZ)/(Z/mZ)
49func (domain *Domain) FFT(a []fr.Element, decimation Decimation, coset uint64) {
50
51	numCPU := uint64(runtime.NumCPU())
52
53	if coset != 0 {
54		if decimation == DIT {
55			BitReverse(domain.CosetTable[coset-1])
56		}
57		parallel.Execute(len(a), func(start, end int) {
58			for i := start; i < end; i++ {
59				a[i].Mul(&a[i], &domain.CosetTable[coset-1][i])
60			}
61		})
62		// put it back as we found it
63		if decimation == DIT {
64			BitReverse(domain.CosetTable[coset-1])
65		}
66	}
67
68	// find the stage where we should stop spawning go routines in our recursive calls
69	// (ie when we have as many go routines running as we have available CPUs)
70	maxSplits := bits.TrailingZeros64(nextPowerOfTwo(numCPU))
71	if numCPU <= 1 {
72		maxSplits = -1
73	}
74
75	switch decimation {
76	case DIF:
77		difFFT(a, domain.Twiddles, 0, maxSplits, nil)
78	case DIT:
79		ditFFT(a, domain.Twiddles, 0, maxSplits, nil)
80	default:
81		panic("not implemented")
82	}
83}
84
85// FFTInverse computes (recursively) the inverse discrete Fourier transform of a and stores the result in a
86// if decimation == DIT (decimation in time), the input must be in bit-reversed order
87// if decimation == DIF (decimation in frequency), the output will be in bit-reversed order
88// coset sets the shift of the fft (0 = no shift, standard fft)
89// len(a) must be a power of 2, and w must be a len(a)th root of unity in field F.
90func (domain *Domain) FFTInverse(a []fr.Element, decimation Decimation, coset uint64) {
91
92	numCPU := uint64(runtime.NumCPU())
93
94	// find the stage where we should stop spawning go routines in our recursive calls
95	// (ie when we have as many go routines running as we have available CPUs)
96	maxSplits := bits.TrailingZeros64(nextPowerOfTwo(numCPU))
97	if numCPU <= 1 {
98		maxSplits = -1
99	}
100	switch decimation {
101	case DIF:
102		difFFT(a, domain.TwiddlesInv, 0, maxSplits, nil)
103	case DIT:
104		ditFFT(a, domain.TwiddlesInv, 0, maxSplits, nil)
105	default:
106		panic("not implemented")
107	}
108
109	// scale by CardinalityInv (+ cosetTableInv is coset!=0)
110	if coset != 0 {
111		if decimation == DIF {
112			BitReverse(domain.CosetTableInv[coset-1])
113		}
114		parallel.Execute(len(a), func(start, end int) {
115			for i := start; i < end; i++ {
116				a[i].Mul(&a[i], &domain.CosetTableInv[coset-1][i]).
117					Mul(&a[i], &domain.CardinalityInv)
118			}
119		})
120		// put it back as we found it
121		if decimation == DIF {
122			BitReverse(domain.CosetTableInv[coset-1])
123		}
124	} else {
125		parallel.Execute(len(a), func(start, end int) {
126			for i := start; i < end; i++ {
127				a[i].Mul(&a[i], &domain.CardinalityInv)
128			}
129		})
130	}
131}
132
133func difFFT(a []fr.Element, twiddles [][]fr.Element, stage, maxSplits int, chDone chan struct{}) {
134	if chDone != nil {
135		defer func() {
136			chDone <- struct{}{}
137		}()
138	}
139	n := len(a)
140	if n == 1 {
141		return
142	}
143	m := n >> 1
144
145	// if stage < maxSplits, we parallelize this butterfly
146	// but we have only numCPU / stage cpus available
147	if (m > butterflyThreshold) && (stage < maxSplits) {
148		// 1 << stage == estimated used CPUs
149		numCPU := runtime.NumCPU() / (1 << (stage))
150		parallel.Execute(m, func(start, end int) {
151			var t fr.Element
152			for i := start; i < end; i++ {
153				t = a[i]
154				a[i].Add(&a[i], &a[i+m])
155
156				a[i+m].
157					Sub(&t, &a[i+m]).
158					Mul(&a[i+m], &twiddles[stage][i])
159			}
160		}, numCPU)
161	} else {
162		var t fr.Element
163
164		// i == 0
165		t = a[0]
166		a[0].Add(&a[0], &a[m])
167		a[m].Sub(&t, &a[m])
168
169		for i := 1; i < m; i++ {
170			t = a[i]
171			a[i].Add(&a[i], &a[i+m])
172
173			a[i+m].
174				Sub(&t, &a[i+m]).
175				Mul(&a[i+m], &twiddles[stage][i])
176		}
177	}
178
179	if m == 1 {
180		return
181	}
182
183	nextStage := stage + 1
184	if stage < maxSplits {
185		chDone := make(chan struct{}, 1)
186		go difFFT(a[m:n], twiddles, nextStage, maxSplits, chDone)
187		difFFT(a[0:m], twiddles, nextStage, maxSplits, nil)
188		<-chDone
189	} else {
190		difFFT(a[0:m], twiddles, nextStage, maxSplits, nil)
191		difFFT(a[m:n], twiddles, nextStage, maxSplits, nil)
192	}
193}
194
195func ditFFT(a []fr.Element, twiddles [][]fr.Element, stage, maxSplits int, chDone chan struct{}) {
196	if chDone != nil {
197		defer func() {
198			chDone <- struct{}{}
199		}()
200	}
201	n := len(a)
202	if n == 1 {
203		return
204	}
205	m := n >> 1
206
207	nextStage := stage + 1
208
209	if stage < maxSplits {
210		// that's the only time we fire go routines
211		chDone := make(chan struct{}, 1)
212		go ditFFT(a[m:], twiddles, nextStage, maxSplits, chDone)
213		ditFFT(a[0:m], twiddles, nextStage, maxSplits, nil)
214		<-chDone
215	} else {
216		ditFFT(a[0:m], twiddles, nextStage, maxSplits, nil)
217		ditFFT(a[m:n], twiddles, nextStage, maxSplits, nil)
218
219	}
220
221	// if stage < maxSplits, we parallelize this butterfly
222	// but we have only numCPU / stage cpus available
223	if (m > butterflyThreshold) && (stage < maxSplits) {
224		// 1 << stage == estimated used CPUs
225		numCPU := runtime.NumCPU() / (1 << (stage))
226		parallel.Execute(m, func(start, end int) {
227			var t, tm fr.Element
228			for k := start; k < end; k++ {
229				t = a[k]
230				tm.Mul(&a[k+m], &twiddles[stage][k])
231				a[k].Add(&a[k], &tm)
232				a[k+m].Sub(&t, &tm)
233			}
234		}, numCPU)
235
236	} else {
237		var t, tm fr.Element
238		// k == 0
239		// wPow == 1
240		t = a[0]
241		a[0].Add(&a[0], &a[m])
242		a[m].Sub(&t, &a[m])
243
244		for k := 1; k < m; k++ {
245			t = a[k]
246			tm.Mul(&a[k+m], &twiddles[stage][k])
247			a[k].Add(&a[k], &tm)
248			a[k+m].Sub(&t, &tm)
249		}
250	}
251}
252
253// BitReverse applies the bit-reversal permutation to a.
254// len(a) must be a power of 2 (as in every single function in this file)
255func BitReverse(a []fr.Element) {
256	n := uint64(len(a))
257	nn := uint64(64 - bits.TrailingZeros64(n))
258
259	for i := uint64(0); i < n; i++ {
260		irev := bits.Reverse64(i) >> nn
261		if irev > i {
262			a[i], a[irev] = a[irev], a[i]
263		}
264	}
265}
266