1# Created by Pearu Peterson, September 2002
2
# Informational text on building fftpack and running its tests.  It is kept as
# a plain module attribute (not a docstring); nothing in the visible part of
# this file reads it.
__usage__ = """
Build fftpack:
  python setup_fftpack.py build
Run tests if scipy is installed:
  python -c 'import scipy;scipy.fftpack.test(<level>)'
Run tests if fftpack is not installed:
  python tests/test_pseudo_diffs.py [<level>]
"""
11
12from numpy.testing import (assert_equal, assert_almost_equal,
13                           assert_array_almost_equal)
14from scipy.fftpack import (diff, fft, ifft, tilbert, itilbert, hilbert,
15                           ihilbert, shift, fftfreq, cs_diff, sc_diff,
16                           ss_diff, cc_diff)
17
18import numpy as np
19from numpy import arange, sin, cos, pi, exp, tanh, sum, sign
20from numpy.random import random
21
22
def direct_diff(x,k=1,period=None):
    """Reference FFT implementation of the periodic derivative of order ``k``.

    Every Fourier coefficient of ``x`` is multiplied by ``w**k`` with
    ``w = fftfreq(n)*2j*pi/period*n``.  A negative ``k`` therefore divides
    by ``w**abs(k)`` (integration); the zero mode is set to zero in that
    case because ``w`` vanishes there.

    Parameters
    ----------
    x : array_like
        Samples of the periodic sequence.
    k : int, optional
        Order of differentiation; negative values integrate.  Default is 1.
    period : float, optional
        Period of the sequence.  ``None`` selects ``2*pi``.

    Returns
    -------
    ndarray
        Real part of the inverse FFT of the scaled spectrum.
    """
    fx = fft(x)
    n = len(fx)
    if period is None:
        period = 2*pi
    w = fftfreq(n)*2j*pi/period*n
    if k < 0:
        # k is already negative, so w**k is the reciprocal power we need.
        # Park a 1 in the zero mode first so the negative power does not
        # divide by zero, then drop that mode from the result.
        w[0] = 1.0
        w = w**k
        w[0] = 0.0
    else:
        w = w**k
    if n > 2000:
        # For long inputs only the first and last 250 entries of the
        # multiplier are kept; everything in between is zeroed.
        w[250:n-250] = 0.0
    return ifft(w*fx).real
37
38
def direct_tilbert(x,h=1,period=None):
    """Reference FFT implementation of the Tilbert transform.

    The spectrum of ``x`` is multiplied by ``1j/tanh(w)`` where
    ``w = fftfreq(n)*h*2*pi/period*n``; the zero mode is multiplied by zero.
    ``period=None`` selects ``2*pi``.  The complex inverse FFT is returned
    as is (no real part is taken).
    """
    spectrum = fft(x)
    n = len(spectrum)
    length = 2*pi if period is None else period
    arg = fftfreq(n)*h*2*pi/length*n
    # tanh(0) is 0, so give the zero mode a harmless argument before dividing;
    # its multiplier is overwritten with zero right after.
    arg[0] = 1
    multiplier = 1j/tanh(arg)
    multiplier[0] = 0j
    return ifft(multiplier*spectrum)
49
50
def direct_itilbert(x,h=1,period=None):
    """Reference FFT implementation of the inverse Tilbert transform.

    The spectrum of ``x`` is multiplied by ``-1j*tanh(w)`` where
    ``w = fftfreq(n)*h*2*pi/period*n``.  ``period=None`` selects ``2*pi``.
    The complex inverse FFT is returned as is.
    """
    spectrum = fft(x)
    n = len(spectrum)
    length = 2*pi if period is None else period
    arg = fftfreq(n)*h*2*pi/length*n
    multiplier = -1j*tanh(arg)
    return ifft(multiplier*spectrum)
59
60
def direct_hilbert(x):
    """Reference FFT implementation of the Hilbert transform.

    The spectrum of ``x`` is multiplied by ``1j*sign(m)`` where
    ``m = fftfreq(n)*n``; the zero mode is thereby multiplied by zero.
    The complex inverse FFT is returned as is.
    """
    spectrum = fft(x)
    n = len(spectrum)
    modes = fftfreq(n)*n
    multiplier = 1j*sign(modes)
    return ifft(multiplier*spectrum)
67
68
def direct_ihilbert(x):
    """Reference inverse Hilbert transform: the negated ``direct_hilbert(x)``."""
    forward = direct_hilbert(x)
    return -forward
71
72
def direct_shift(x,a,period=None):
    """Reference FFT implementation of shifting a periodic sequence by ``a``.

    The spectrum of ``x`` is multiplied by ``exp(k*a)``.  With
    ``period=None`` the factor is ``k = fftfreq(n)*1j*n``; otherwise it is
    ``k = fftfreq(n)*2j*pi/period*n``.  The real part of the inverse FFT is
    returned.
    """
    n = len(x)
    if period is not None:
        k = fftfreq(n)*2j*pi/period*n
    else:
        k = fftfreq(n)*1j*n
    phase = exp(k*a)
    shifted_spectrum = fft(x)*phase
    return ifft(shifted_spectrum).real
80
81
class TestDiff:
    """Compare ``diff`` with ``direct_diff`` and with hand-derived results."""

    @staticmethod
    def _centered_noise(n):
        """Return ``n`` random samples with their average subtracted."""
        samples = random((n,))
        return samples - sum(samples,axis=0)/n

    @staticmethod
    def _assert_order_roundtrip(f, k):
        """Assert ``f`` sums to zero and survives orders ``k``/``-k`` in both sequences."""
        assert_almost_equal(sum(f,axis=0),0.0)
        assert_array_almost_equal(diff(diff(f,k),-k),f)
        assert_array_almost_equal(diff(diff(f,-k),k),f)

    @staticmethod
    def _smooth_sample(n):
        """Return a test function and its first two derivatives on ``n`` grid points."""
        x = arange(n)*2*pi/n
        f = sin(x)*cos(4*x)+exp(sin(3*x))
        df = cos(x)*cos(4*x)-4*sin(x)*sin(4*x)+3*cos(3*x)*exp(sin(3*x))
        ddf = (-17*sin(x)*cos(4*x)-8*cos(x)*sin(4*x)
               - 9*sin(3*x)*exp(sin(3*x))+9*cos(3*x)**2*exp(sin(3*x)))
        return f, df, ddf

    def test_definition(self):
        """``diff`` agrees with ``direct_diff`` for a range of inputs and orders."""
        for n in [16,17,64,127,32]:
            x = arange(n)*2*pi/n
            # (input, order) pairs, listed in the order they are checked.
            cases = [
                (sin(x), 1), (sin(x), 2), (sin(x), 3), (sin(x), 4), (sin(x), 5),
                (sin(2*x), 3), (sin(2*x), 4),
                (cos(x), 1), (cos(x), 2), (cos(x), 3), (cos(x), 4),
                (cos(2*x), 1),
                (sin(x*n/8), 1), (cos(x*n/8), 1),
            ]
            for order in range(5):
                cases.append((sin(4*x), order))
                cases.append((cos(4*x), order))
            for f, order in cases:
                assert_array_almost_equal(diff(f,order),direct_diff(f,order))

    def test_period(self):
        """``diff`` honours an explicit ``period`` argument."""
        for n in [17,64]:
            x = arange(n)/float(n)
            signal = sin(2*pi*x)
            carrier = cos(2*pi*x)
            assert_array_almost_equal(diff(signal,period=1),
                                      2*pi*carrier)
            assert_array_almost_equal(diff(signal,3,period=1),
                                      -(2*pi)**3*carrier)

    def test_sin(self):
        """Derivatives of sines and cosines match their closed forms."""
        for n in [32,64,77]:
            x = arange(n)*2*pi/n
            s = sin(x)
            c = cos(x)
            assert_array_almost_equal(diff(s),c)
            assert_array_almost_equal(diff(c),-s)
            assert_array_almost_equal(diff(s,2),-s)
            assert_array_almost_equal(diff(s,4),s)
            assert_array_almost_equal(diff(sin(4*x)),4*cos(4*x))
            assert_array_almost_equal(diff(sin(s)),c*cos(s))

    def test_expr(self):
        """First/second derivative and first integral of a smooth expression."""
        for n in [64,77,100,128,256]:
            f, df, ddf = self._smooth_sample(n)
            assert_array_almost_equal(diff(f),df)
            assert_array_almost_equal(diff(df),ddf)
            assert_array_almost_equal(diff(f,2),ddf)
            assert_array_almost_equal(diff(ddf,-1),df)

    def test_expr_large(self):
        """Same smooth expression on the larger grids 2048 and 4096."""
        for n in [2048,4096]:
            f, df, ddf = self._smooth_sample(n)
            assert_array_almost_equal(diff(f),df)
            assert_array_almost_equal(diff(df),ddf)
            assert_array_almost_equal(diff(ddf,-1),df)
            assert_array_almost_equal(diff(f,2),ddf)

    def test_int(self):
        """Negative orders integrate sines and cosines."""
        n = 64
        x = arange(n)*2*pi/n
        cases = [
            (sin(x), -1, -cos(x)),
            (sin(x), -2, -sin(x)),
            (sin(x), -4, sin(x)),
            (2*cos(2*x), -1, sin(2*x)),
        ]
        for f, order, expected in cases:
            assert_array_almost_equal(diff(f,order),expected)

    def test_random_even(self):
        """Even orders round-trip on zero-mean noise with the Nyquist mode removed."""
        for k in [0,2,4,6]:
            for n in [60,32,64,56,55]:
                f = self._centered_noise(n)
                # zeroing Nyquist mode:
                f = diff(diff(f,1),-1)
                self._assert_order_roundtrip(f, k)

    def test_random_odd(self):
        """All orders 0..6 round-trip on odd-length zero-mean noise."""
        for k in [0,1,2,3,4,5,6]:
            for n in [33,65,55]:
                f = self._centered_noise(n)
                self._assert_order_roundtrip(f, k)

    def test_zero_nyquist(self):
        """All orders 0..6 round-trip once the Nyquist mode has been removed."""
        for k in [0,1,2,3,4,5,6]:
            for n in [32,33,64,56,55]:
                f = self._centered_noise(n)
                # zeroing Nyquist mode:
                f = diff(diff(f,1),-1)
                self._assert_order_roundtrip(f, k)
189
190
class TestTilbert:
    """Checks of ``tilbert`` and of the Tilbert/inverse-Tilbert round trip."""

    def test_definition(self):
        """``tilbert`` agrees with ``direct_tilbert`` on sine inputs."""
        for h in [0.1,0.5,1,5.5,10]:
            for n in [16,17,64,127]:
                x = arange(n)*2*pi/n
                # sin(x) is compared twice, then sin(2*x) once.
                for f in (sin(x), sin(x), sin(2*x)):
                    assert_array_almost_equal(tilbert(f,h),
                                              direct_tilbert(f,h))

    def test_random_even(self):
        """The reference transforms invert each other on even-length zero-mean noise."""
        for h in [0.1,0.5,1,5.5,10]:
            for n in [32,64,56]:
                noise = random((n,))
                f = noise - sum(noise,axis=0)/n
                assert_almost_equal(sum(f,axis=0),0.0)
                roundtrip = direct_tilbert(direct_itilbert(f,h),h)
                assert_array_almost_equal(roundtrip,f)

    def test_random_odd(self):
        """``tilbert`` and ``itilbert`` invert each other on odd-length zero-mean noise."""
        for h in [0.1,0.5,1,5.5,10]:
            for n in [33,65,55]:
                noise = random((n,))
                f = noise - sum(noise,axis=0)/n
                assert_almost_equal(sum(f,axis=0),0.0)
                assert_array_almost_equal(itilbert(tilbert(f,h),h),f)
                assert_array_almost_equal(tilbert(itilbert(f,h),h),f)
223
224
class TestITilbert:
    """Checks of ``itilbert`` against the reference ``direct_itilbert``."""

    def test_definition(self):
        """``itilbert`` agrees with ``direct_itilbert`` on sine inputs."""
        for h in [0.1,0.5,1,5.5,10]:
            for n in [16,17,64,127]:
                x = arange(n)*2*pi/n
                # sin(x) is compared twice, then sin(2*x) once.
                for f in (sin(x), sin(x), sin(2*x)):
                    assert_array_almost_equal(itilbert(f,h),
                                              direct_itilbert(f,h))
238
239
class TestHilbert:
    """Checks of ``hilbert`` and of the Hilbert/inverse-Hilbert round trip."""

    def test_definition(self):
        """``hilbert`` agrees with ``direct_hilbert`` on sine inputs."""
        for n in [16,17,64,127]:
            x = arange(n)*2*pi/n
            for f in (sin(x), sin(2*x)):
                assert_array_almost_equal(hilbert(f),direct_hilbert(f))

    def test_tilbert_relation(self):
        """``hilbert`` matches both the reference and ``tilbert`` with ``h=10``."""
        for n in [16,17,64,127]:
            x = arange(n)*2*pi/n
            f = sin(x)+cos(2*x)*sin(x)
            transformed = hilbert(f)
            assert_array_almost_equal(transformed,direct_hilbert(f))
            assert_array_almost_equal(transformed,tilbert(f,h=10))

    def test_random_odd(self):
        """``hilbert`` and ``ihilbert`` invert each other on odd-length zero-mean noise."""
        for n in [33,65,55]:
            noise = random((n,))
            f = noise - sum(noise,axis=0)/n
            assert_almost_equal(sum(f,axis=0),0.0)
            assert_array_almost_equal(ihilbert(hilbert(f)),f)
            assert_array_almost_equal(hilbert(ihilbert(f)),f)

    def test_random_even(self):
        """Round trips on even-length zero-mean noise with the Nyquist mode removed."""
        for n in [32,64,56]:
            noise = random((n,))
            f = noise - sum(noise,axis=0)/n
            # zeroing Nyquist mode:
            f = diff(diff(f,1),-1)
            assert_almost_equal(sum(f,axis=0),0.0)
            reference_roundtrip = direct_hilbert(direct_ihilbert(f))
            assert_array_almost_equal(reference_roundtrip,f)
            assert_array_almost_equal(hilbert(ihilbert(f)),f)
280
281
class TestIHilbert:
    """Checks of ``ihilbert`` against the reference and against ``itilbert``."""

    def test_definition(self):
        """``ihilbert`` agrees with ``direct_ihilbert`` on sine inputs."""
        for n in [16,17,64,127]:
            x = arange(n)*2*pi/n
            for f in (sin(x), sin(2*x)):
                assert_array_almost_equal(ihilbert(f),direct_ihilbert(f))

    def test_itilbert_relation(self):
        """``ihilbert`` matches both the reference and ``itilbert`` with ``h=10``."""
        for n in [16,17,64,127]:
            x = arange(n)*2*pi/n
            f = sin(x)+cos(2*x)*sin(x)
            transformed = ihilbert(f)
            assert_array_almost_equal(transformed,direct_ihilbert(f))
            assert_array_almost_equal(transformed,itilbert(f,h=10))
302
303
class TestShift:
    """Checks of ``shift`` against ``direct_shift`` and analytically shifted inputs."""

    def test_definition(self):
        """Shifting samples of a function equals sampling the shifted function."""
        for n in [18,17,64,127,32,2048,256]:
            x = arange(n)*2*pi/n
            for a in [0.1,3]:
                moved = x+a
                assert_array_almost_equal(shift(sin(x),a),direct_shift(sin(x),a))
                # (input, expected) pairs, checked in this order.
                analytic = [
                    (sin(x), sin(moved)),
                    (cos(x), cos(moved)),
                    (cos(2*x)+sin(x), cos(2*moved)+sin(moved)),
                    (exp(sin(x)), exp(sin(moved))),
                ]
                for f, expected in analytic:
                    assert_array_almost_equal(shift(f,a),expected)
            # Shifts by a full, half and quarter period.
            for a, expected in [(2*pi, sin(x)), (pi, -sin(x)), (pi/2, cos(x))]:
                assert_array_almost_equal(shift(sin(x),a),expected)
319
320
class TestOverwrite:
    """Check that the transform routines do not overwrite their input.

    Every public test feeds a 16-element random array of each dtype in
    ``dtypes`` to one routine and asserts the array is unchanged afterwards.
    """

    # Real floating-point dtypes.
    real_dtypes = (np.float32, np.float64)
    # Every dtype exercised by the tests: the real ones plus two complex ones.
    dtypes = real_dtypes + (np.complex64, np.complex128)

    def _check(self, x, routine, *args, **kwargs):
        """Call ``routine`` on a copy of ``x`` and assert the copy still equals ``x``.

        Extra positional and keyword arguments are forwarded to ``routine``
        and are appended to the routine name in the failure message.
        """
        x2 = x.copy()
        # The return value is irrelevant here; only the effect on x2 matters.
        routine(x2, *args, **kwargs)
        sig = routine.__name__
        if args:
            sig += repr(args)
        if kwargs:
            sig += repr(kwargs)
        assert_equal(x2, x, err_msg="spurious overwrite in %s" % sig)

    def _check_1d(self, routine, dtype, shape, *args, **kwargs):
        """Run ``_check`` on normally distributed data of the given dtype and shape."""
        # Fixed seed so every routine and dtype sees reproducible data.
        np.random.seed(1234)
        if np.issubdtype(dtype, np.complexfloating):
            data = np.random.randn(*shape) + 1j*np.random.randn(*shape)
        else:
            data = np.random.randn(*shape)
        data = data.astype(dtype)
        self._check(data, routine, *args, **kwargs)

    def test_diff(self):
        for dtype in self.dtypes:
            self._check_1d(diff, dtype, (16,))

    def test_tilbert(self):
        for dtype in self.dtypes:
            self._check_1d(tilbert, dtype, (16,), 1.6)

    def test_itilbert(self):
        for dtype in self.dtypes:
            self._check_1d(itilbert, dtype, (16,), 1.6)

    def test_hilbert(self):
        for dtype in self.dtypes:
            self._check_1d(hilbert, dtype, (16,))

    def test_ihilbert(self):
        # ihilbert is imported alongside hilbert but previously had no
        # overwrite check of its own.
        for dtype in self.dtypes:
            self._check_1d(ihilbert, dtype, (16,))

    def test_cs_diff(self):
        for dtype in self.dtypes:
            self._check_1d(cs_diff, dtype, (16,), 1.0, 4.0)

    def test_sc_diff(self):
        for dtype in self.dtypes:
            self._check_1d(sc_diff, dtype, (16,), 1.0, 4.0)

    def test_ss_diff(self):
        for dtype in self.dtypes:
            self._check_1d(ss_diff, dtype, (16,), 1.0, 4.0)

    def test_cc_diff(self):
        for dtype in self.dtypes:
            self._check_1d(cc_diff, dtype, (16,), 1.0, 4.0)

    def test_shift(self):
        for dtype in self.dtypes:
            self._check_1d(shift, dtype, (16,), 1.0)
381