1# This program is free software: you can redistribute it and/or modify
2# it under the terms of the GNU General Public License as published by
3# the Free Software Foundation, either version 3 of the License, or
4# (at your option) any later version.
5#
6# This program is distributed in the hope that it will be useful,
7# but WITHOUT ANY WARRANTY; without even the implied warranty of
8# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
9# GNU General Public License for more details.
10#
11# You should have received a copy of the GNU General Public License
12# along with this program.  If not, see <http://www.gnu.org/licenses/>.
13#
14# Copyright(C) 2019-2021 Max-Planck-Society
15
16
17import numpy as np
18import ducc0
19from time import time
20import matplotlib.pyplot as plt
21
22
# Module-wide random generator with a fixed seed, so that the random shapes
# and input data drawn by bench_nd are reproducible from run to run.
rng = np.random.default_rng(seed=42)
24
25
def measure_fftw(a, nrepeat, nthr, flags=('FFTW_MEASURE',)):
    """Time a pre-planned pyfftw forward FFT of `a` over all of its axes.

    The FFTW plan is built once, outside the timed region, using the given
    planner `flags` and `nthr` threads. It is then executed `nrepeat` times.

    Parameters
    ----------
    a : ndarray
        Input data; its shape and dtype define the aligned work arrays.
    nrepeat : int
        Number of timed plan executions.
    nthr : int
        Thread count handed to ``pyfftw.FFTW``.
    flags : tuple of str
        FFTW planner flags.

    Returns
    -------
    (list of float, ndarray)
        Wall-clock duration of every execution, and the plan's output array
        (which holds the result of the last execution).
    """
    import pyfftw
    src = pyfftw.empty_aligned(a.shape, dtype=a.dtype)
    dst = pyfftw.empty_aligned(a.shape, dtype=a.dtype)
    plan = pyfftw.FFTW(src, dst, flags=flags, axes=range(a.ndim), threads=nthr)
    # The input is copied in only after planning, because measuring planners
    # may scribble over the arrays they are given (see the FFTW docs).
    src[()] = a
    durations = []
    for _ in range(nrepeat):
        start = time()
        plan()
        durations.append(time() - start)
    return durations, dst
39
40
def measure_fftw_est(a, nrepeat, nthr):
    """Same as `measure_fftw`, but plan with FFTW_ESTIMATE instead of FFTW_MEASURE.

    Returns whatever `measure_fftw` returns: the per-execution durations and
    the output array of the plan.
    """
    estimate_only = ('FFTW_ESTIMATE',)
    return measure_fftw(a, nrepeat, nthr, flags=estimate_only)
43
44
def measure_fftw_np_interface(a, nrepeat, nthr):
    """Time pyfftw's numpy-compatible ``fftn`` interface on `a`.

    Each of the `nrepeat` timed calls goes through
    ``pyfftw.interfaces.numpy_fft.fftn``. The pyfftw interface cache is
    enabled beforehand, so objects created by earlier calls can be reused
    by later ones.

    Parameters
    ----------
    a : ndarray
        Input array, transformed with the interface's default axes.
    nrepeat : int
        Number of timed calls.
    nthr : int
        Number of threads forwarded to pyfftw via ``threads=``.

    Returns
    -------
    (list of float, ndarray or None)
        Duration of every call, and the result of the last call (None when
        `nrepeat` is 0).
    """
    import pyfftw
    pyfftw.interfaces.cache.enable()
    times = []
    b = None
    for i in range(nrepeat):
        # Drop the previous result before the next call so its memory can be
        # released. The old code `del`-ed `b` twice per iteration, which made
        # the second iteration and the final `return` raise
        # UnboundLocalError.
        b = None
        t0 = time()
        b = pyfftw.interfaces.numpy_fft.fftn(a, threads=nthr)
        t1 = time()
        times.append(t1-t0)
    return times, b
58
59
def measure_duccfft(a, nrepeat, nthr):
    """Time ducc0's out-of-place complex forward FFT (``ducc0.fft.c2c``) of `a`.

    A single output buffer (initially a copy of `a`) is allocated up front.
    It is passed as ``out=`` to every one of the `nrepeat` timed calls.

    Returns the per-call durations and the output of the last call (the
    untouched copy of `a` when `nrepeat` is 0).
    """
    out = a.copy()
    durations = []
    for _ in range(nrepeat):
        start = time()
        out = ducc0.fft.c2c(a, out=out, forward=True, nthreads=nthr)
        durations.append(time() - start)
    return durations, out
69
70
# ducc0.fft with critical array strides avoided and the transform done in
# place. This is probably the fastest way to use ducc0 for multi-D transforms.
def measure_duccfft_noncrit_inplace(a, nrepeat, nthr):
    """Time an in-place ducc0 forward c2c FFT on a buffer with non-critical strides.

    The work buffer is produced once by ``ducc0.misc.make_noncritical`` from
    a copy of `a`. Before every timed call it is refilled with `a` (outside
    the timed region), so each repetition transforms the same input.

    Returns the per-call durations and the work buffer, which holds the
    result of the final transform.
    """
    buf = ducc0.misc.make_noncritical(a.copy())
    durations = []
    for _ in range(nrepeat):
        buf[()] = a  # restore the input: the previous in-place call overwrote it
        start = time()
        buf = ducc0.fft.c2c(buf, out=buf, forward=True, nthreads=nthr)
        durations.append(time() - start)
    return durations, buf
83
84
def measure_scipy_fftpack(a, nrepeat, nthr):
    """Time the legacy ``scipy.fftpack.fftn`` on `a` (single-threaded only).

    Returns
    -------
    (list of float, ndarray or None)
        Per-call durations and the result of the final call (None when
        `nrepeat` is 0).

    Raises
    ------
    NotImplementedError
        If `nthr` is not 1, because scipy.fftpack has no threading option.
    """
    import scipy.fftpack
    durations = []
    if nthr != 1:
        raise NotImplementedError("scipy.fftpack does not support multiple threads")
    result = None
    for _ in range(nrepeat):
        result = None  # release the previous transform before the next one
        start = time()
        result = scipy.fftpack.fftn(a)
        durations.append(time() - start)
    return durations, result
98
99
def measure_scipy_fft(a, nrepeat, nthr):
    """Time ``scipy.fft.fftn`` on `a`, using `nthr` worker threads.

    Returns
    -------
    (list of float, ndarray or None)
        Per-call durations and the result of the final call (None when
        `nrepeat` is 0).
    """
    import scipy.fft
    durations = []
    result = None
    for _ in range(nrepeat):
        result = None  # release the previous transform before the next one
        start = time()
        result = scipy.fft.fftn(a, workers=nthr)
        durations.append(time() - start)
    return durations, result
111
112
def measure_numpy_fft(a, nrepeat, nthr):
    """Time ``numpy.fft.fftn`` on `a`. Only single-threaded runs are allowed.

    Returns
    -------
    (list of float, ndarray or None)
        Per-call durations in seconds and the result of the final call
        (None when `nrepeat` is 0).

    Raises
    ------
    NotImplementedError
        If `nthr` differs from 1.
    """
    if nthr != 1:
        raise NotImplementedError("numpy.fft does not support multiple threads")
    durations = []
    result = None
    for _ in range(nrepeat):
        result = None  # free the previous transform before computing the next
        start = time()
        result = np.fft.fftn(a)
        durations.append(time() - start)
    return durations, result
125
126
def measure_mkl_fft(a, nrepeat, nthr):
    """Time ``mkl_fft.fftn`` on `a`.

    The requested thread count is passed via the OMP_NUM_THREADS environment
    variable. It is set process-wide (and left set) before ``mkl_fft`` is
    imported.

    NOTE(review): OpenMP runtimes usually read OMP_NUM_THREADS only when they
    are first loaded. If MKL/OpenMP was already initialised earlier in this
    process (e.g. by a previous call with a different `nthr`), the new value
    is probably ignored. Confirm this before trusting multi-threaded timings.

    Returns the per-call durations and the result of the final call (None
    when `nrepeat` is 0).
    """
    import os
    os.environ['OMP_NUM_THREADS'] = str(nthr)
    import mkl_fft
    durations = []
    result = None
    for _ in range(nrepeat):
        result = None  # release the previous output before the next transform
        start = time()
        result = mkl_fft.fftn(a)
        durations.append(time() - start)
    return durations, result
140
141
def bench_nd(ndim, nmax, nthr, ntry, tp, funcs, nrepeat, ttl="", filename="",
             nice_sizes=True):
    """Compare FFT implementations on random complex arrays and plot the
    distribution of the funcs[0]/funcs[1] run-time ratio.

    For each of `ntry` trials:

    - draw a random shape of `ndim` extents in [nmax//3, nmax], optionally
      replaced by ``ducc0.fft.good_size`` of each extent;
    - fill a random complex array of dtype `tp`;
    - call every entry of `funcs` as ``func(a, nrepeat, nthr)`` and record
      the mean of the timings it returns.

    After each trial, one line is printed with the average times of funcs[0]
    and funcs[1], their ratio, and the ``ducc0.misc.l2error`` between their
    outputs. Finally a histogram of the ratios is drawn, saved to `filename`
    if that is not the empty string, and shown.

    `funcs` must contain at least two entries.
    """
    print("{}D, type {}, max extent is {}:".format(ndim, tp, nmax))
    timings = [[] for _ in range(len(funcs))]
    for trial in range(ntry):
        shape = rng.integers(nmax//3, nmax+1, ndim)
        if nice_sizes:
            shape = np.array([ducc0.fft.good_size(extent) for extent in shape])
        print("  {0:4d}/{1}: shape={2} ...".format(trial, ntry, shape), end=" ", flush=True)
        # Real part is drawn before imaginary part, each uniform in [-0.5, 0.5).
        real_part = rng.random(shape) - 0.5
        imag_part = rng.random(shape) - 0.5
        a = (real_part + 1j*imag_part).astype(tp)
        outputs = []
        for func, record in zip(funcs, timings):
            measured = func(a, nrepeat, nthr)
            record.append(np.average(measured[0]))
            outputs.append(measured[1])
        t_first = timings[0][trial]
        t_second = timings[1][trial]
        print("{0:5.2e}/{1:5.2e} = {2:5.2f}  L2 error={3}".format(
            t_first, t_second, t_first/t_second,
            ducc0.misc.l2error(outputs[0], outputs[1])))
    table = np.array(timings)
    plt.title("{}: {}D, {}, max_extent={}".format(ttl, ndim, str(tp), nmax))
    plt.xlabel("time ratio")
    plt.ylabel("counts")
    plt.hist(table[0, :]/table[1, :], bins="auto")
    if filename != "":
        plt.savefig(filename)
    plt.show()
    plt.close()
168
169
# Benchmark configuration. The first entry of `funcs` is compared against the
# second, and the reported ratios are time(funcs[0]) / time(funcs[1]).
funcs = (measure_duccfft_noncrit_inplace, measure_fftw)
ttl = "duccfft/FFTW"
ntry = 10
nthr = 1
nice_sizes = True
# Maximum extents for the 1D, 2D and 3D runs. A much heavier alternative
# setting is [524288, 8192, 512].
limits = [8192, 2048, 256]

# (dimensionality, dtype, plot file) for each run, in execution order:
# all complex128 ("c16") runs first, then the complex64 ("c8") ones.
runs = (
    (1, "c16", "1d.png"),
    (2, "c16", "2d.png"),
    (3, "c16", "3d.png"),
    (1, "c8", "1d_single.png"),
    (2, "c8", "2d_single.png"),
    (3, "c8", "3d_single.png"),
)
for run_ndim, run_tp, run_file in runs:
    bench_nd(run_ndim, limits[run_ndim-1], nthr, ntry, run_tp, funcs, 10, ttl,
             run_file, nice_sizes)
183