1// Copyright 2020 ConsenSys Software Inc. 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); 4// you may not use this file except in compliance with the License. 5// You may obtain a copy of the License at 6// 7// http://www.apache.org/licenses/LICENSE-2.0 8// 9// Unless required by applicable law or agreed to in writing, software 10// distributed under the License is distributed on an "AS IS" BASIS, 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12// See the License for the specific language governing permissions and 13// limitations under the License. 14 15// Code generated by consensys/gnark-crypto DO NOT EDIT 16 17package fft 18 19import ( 20 "math/bits" 21 "runtime" 22 23 "github.com/consensys/gnark-crypto/internal/parallel" 24 25 "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" 26) 27 28// Decimation is used in the FFT call to select decimation in time or in frequency 29type Decimation uint8 30 31const ( 32 DIT Decimation = iota 33 DIF 34) 35 36// parallelize threshold for a single butterfly op, if the fft stage is not parallelized already 37const butterflyThreshold = 16 38 39// FFT computes (recursively) the discrete Fourier transform of a and stores the result in a 40// if decimation == DIT (decimation in time), the input must be in bit-reversed order 41// if decimation == DIF (decimation in frequency), the output will be in bit-reversed order 42// coset sets the shift of the fft (0 = no shift, standard fft) 43// len(a) must be a power of 2, and w must be a len(a)th root of unity in field F. 44// 45// example: 46// ------- 47// domain := NewDomain(m, 2) --> contains precomputed data for Z/mZ, and Z/4mZ 48// FFT(pol, DIT, 1) --> evaluates pol on the coset 1 in (Z/4mZ)/(Z/mZ) 49func (domain *Domain) FFT(a []fr.Element, decimation Decimation, coset uint64) { 50 51 numCPU := uint64(runtime.NumCPU()) 52 53 if coset != 0 { 54 if decimation == DIT { 55 BitReverse(domain.CosetTable[coset-1]) 56 } 57 parallel.Execute(len(a), func(start, end int) { 58 for i := start; i < end; i++ { 59 a[i].Mul(&a[i], &domain.CosetTable[coset-1][i]) 60 } 61 }) 62 // put it back as we found it 63 if decimation == DIT { 64 BitReverse(domain.CosetTable[coset-1]) 65 } 66 } 67 68 // find the stage where we should stop spawning go routines in our recursive calls 69 // (ie when we have as many go routines running as we have available CPUs) 70 maxSplits := bits.TrailingZeros64(nextPowerOfTwo(numCPU)) 71 if numCPU <= 1 { 72 maxSplits = -1 73 } 74 75 switch decimation { 76 case DIF: 77 difFFT(a, domain.Twiddles, 0, maxSplits, nil) 78 case DIT: 79 ditFFT(a, domain.Twiddles, 0, maxSplits, nil) 80 default: 81 panic("not implemented") 82 } 83} 84 85// FFTInverse computes (recursively) the inverse discrete Fourier transform of a and stores the result in a 86// if decimation == DIT (decimation in time), the input must be in bit-reversed order 87// if decimation == DIF (decimation in frequency), the output will be in bit-reversed order 88// coset sets the shift of the fft (0 = no shift, standard fft) 89// len(a) must be a power of 2, and w must be a len(a)th root of unity in field F. 90func (domain *Domain) FFTInverse(a []fr.Element, decimation Decimation, coset uint64) { 91 92 numCPU := uint64(runtime.NumCPU()) 93 94 // find the stage where we should stop spawning go routines in our recursive calls 95 // (ie when we have as many go routines running as we have available CPUs) 96 maxSplits := bits.TrailingZeros64(nextPowerOfTwo(numCPU)) 97 if numCPU <= 1 { 98 maxSplits = -1 99 } 100 switch decimation { 101 case DIF: 102 difFFT(a, domain.TwiddlesInv, 0, maxSplits, nil) 103 case DIT: 104 ditFFT(a, domain.TwiddlesInv, 0, maxSplits, nil) 105 default: 106 panic("not implemented") 107 } 108 109 // scale by CardinalityInv (+ cosetTableInv is coset!=0) 110 if coset != 0 { 111 if decimation == DIF { 112 BitReverse(domain.CosetTableInv[coset-1]) 113 } 114 parallel.Execute(len(a), func(start, end int) { 115 for i := start; i < end; i++ { 116 a[i].Mul(&a[i], &domain.CosetTableInv[coset-1][i]). 117 Mul(&a[i], &domain.CardinalityInv) 118 } 119 }) 120 // put it back as we found it 121 if decimation == DIF { 122 BitReverse(domain.CosetTableInv[coset-1]) 123 } 124 } else { 125 parallel.Execute(len(a), func(start, end int) { 126 for i := start; i < end; i++ { 127 a[i].Mul(&a[i], &domain.CardinalityInv) 128 } 129 }) 130 } 131} 132 133func difFFT(a []fr.Element, twiddles [][]fr.Element, stage, maxSplits int, chDone chan struct{}) { 134 if chDone != nil { 135 defer func() { 136 chDone <- struct{}{} 137 }() 138 } 139 n := len(a) 140 if n == 1 { 141 return 142 } 143 m := n >> 1 144 145 // if stage < maxSplits, we parallelize this butterfly 146 // but we have only numCPU / stage cpus available 147 if (m > butterflyThreshold) && (stage < maxSplits) { 148 // 1 << stage == estimated used CPUs 149 numCPU := runtime.NumCPU() / (1 << (stage)) 150 parallel.Execute(m, func(start, end int) { 151 var t fr.Element 152 for i := start; i < end; i++ { 153 t = a[i] 154 a[i].Add(&a[i], &a[i+m]) 155 156 a[i+m]. 157 Sub(&t, &a[i+m]). 158 Mul(&a[i+m], &twiddles[stage][i]) 159 } 160 }, numCPU) 161 } else { 162 var t fr.Element 163 164 // i == 0 165 t = a[0] 166 a[0].Add(&a[0], &a[m]) 167 a[m].Sub(&t, &a[m]) 168 169 for i := 1; i < m; i++ { 170 t = a[i] 171 a[i].Add(&a[i], &a[i+m]) 172 173 a[i+m]. 174 Sub(&t, &a[i+m]). 175 Mul(&a[i+m], &twiddles[stage][i]) 176 } 177 } 178 179 if m == 1 { 180 return 181 } 182 183 nextStage := stage + 1 184 if stage < maxSplits { 185 chDone := make(chan struct{}, 1) 186 go difFFT(a[m:n], twiddles, nextStage, maxSplits, chDone) 187 difFFT(a[0:m], twiddles, nextStage, maxSplits, nil) 188 <-chDone 189 } else { 190 difFFT(a[0:m], twiddles, nextStage, maxSplits, nil) 191 difFFT(a[m:n], twiddles, nextStage, maxSplits, nil) 192 } 193} 194 195func ditFFT(a []fr.Element, twiddles [][]fr.Element, stage, maxSplits int, chDone chan struct{}) { 196 if chDone != nil { 197 defer func() { 198 chDone <- struct{}{} 199 }() 200 } 201 n := len(a) 202 if n == 1 { 203 return 204 } 205 m := n >> 1 206 207 nextStage := stage + 1 208 209 if stage < maxSplits { 210 // that's the only time we fire go routines 211 chDone := make(chan struct{}, 1) 212 go ditFFT(a[m:], twiddles, nextStage, maxSplits, chDone) 213 ditFFT(a[0:m], twiddles, nextStage, maxSplits, nil) 214 <-chDone 215 } else { 216 ditFFT(a[0:m], twiddles, nextStage, maxSplits, nil) 217 ditFFT(a[m:n], twiddles, nextStage, maxSplits, nil) 218 219 } 220 221 // if stage < maxSplits, we parallelize this butterfly 222 // but we have only numCPU / stage cpus available 223 if (m > butterflyThreshold) && (stage < maxSplits) { 224 // 1 << stage == estimated used CPUs 225 numCPU := runtime.NumCPU() / (1 << (stage)) 226 parallel.Execute(m, func(start, end int) { 227 var t, tm fr.Element 228 for k := start; k < end; k++ { 229 t = a[k] 230 tm.Mul(&a[k+m], &twiddles[stage][k]) 231 a[k].Add(&a[k], &tm) 232 a[k+m].Sub(&t, &tm) 233 } 234 }, numCPU) 235 236 } else { 237 var t, tm fr.Element 238 // k == 0 239 // wPow == 1 240 t = a[0] 241 a[0].Add(&a[0], &a[m]) 242 a[m].Sub(&t, &a[m]) 243 244 for k := 1; k < m; k++ { 245 t = a[k] 246 tm.Mul(&a[k+m], &twiddles[stage][k]) 247 a[k].Add(&a[k], &tm) 248 a[k+m].Sub(&t, &tm) 249 } 250 } 251} 252 253// BitReverse applies the bit-reversal permutation to a. 254// len(a) must be a power of 2 (as in every single function in this file) 255func BitReverse(a []fr.Element) { 256 n := uint64(len(a)) 257 nn := uint64(64 - bits.TrailingZeros64(n)) 258 259 for i := uint64(0); i < n; i++ { 260 irev := bits.Reverse64(i) >> nn 261 if irev > i { 262 a[i], a[irev] = a[irev], a[i] 263 } 264 } 265} 266