1#!/usr/bin/env python
2from __future__ import division, print_function, absolute_import
3
4import numpy as np
5from numpy.testing import (assert_allclose, assert_, assert_raises,
6                           assert_array_equal)
7import pywt
8
# Expected dtype of the dwt/idwt output for each input dtype:
#   * float32, float64, complex64 and complex128 are preserved;
#   * float16 is expected to come back as float32;
#   * other real types (int8 here) get converted to float64;
#   * complex256 gets converted to complex128.
# The two lists are paired element by element.
dtypes_in = [np.int8, np.float16, np.float32, np.float64, np.complex64,
             np.complex128]
dtypes_out = [np.float64, np.float32, np.float32, np.float64, np.complex64,
              np.complex128]

# np.complex256 is not defined by every numpy build: test it as well, but
# only if it is available.  Guarding on the attribute itself keeps both
# tables in step and avoids wrapping unrelated statements in a try block.
if hasattr(np, 'complex256'):
    dtypes_in.append(np.complex256)
    dtypes_out.append(np.complex128)
23
24
def test_dwt_idwt_basic():
    """Single-level 'db2' dwt of a short real signal, then its inverse."""
    signal = [3, 7, 1, 1, -2, 5, 4, 6]
    approx_ref = [5.65685425, 7.39923721, 0.22414387, 3.33677403, 7.77817459]
    detail_ref = [-2.44948974, -1.60368225, -4.44140056, -0.41361256,
                  1.22474487]

    # forward transform against the reference coefficients
    approx, detail = pywt.dwt(signal, 'db2')
    assert_allclose(approx, approx_ref)
    assert_allclose(detail, detail_ref)

    # inverse transform recovers the input
    assert_allclose(pywt.idwt(approx, detail, 'db2'), signal, rtol=1e-10)

    # mismatched dtypes OK: float64 approximation with float32 detail is
    # accepted and the reconstruction is expected to be float64
    approx64 = approx.astype(np.float64)
    detail32 = detail.astype(np.float32)
    mixed = pywt.idwt(approx64, detail32, 'db2')
    assert_allclose(mixed, signal, rtol=1e-7, atol=1e-7)
    assert_(mixed.dtype == np.float64)
42
43
def test_idwt_mixed_complex_dtype():
    """idwt round trip for complex data, including mixed precisions."""
    ramp = np.arange(8).astype(float)
    signal = ramp + 1j*ramp[::-1]
    approx, detail = pywt.dwt(signal, 'db2')

    # plain round trip with the coefficients exactly as returned
    assert_allclose(pywt.idwt(approx, detail, 'db2'), signal, rtol=1e-10)

    # mismatched dtypes OK: complex128 approximation with complex64 detail
    # is accepted and the reconstruction is expected to be complex128
    approx128 = approx.astype(np.complex128)
    detail64 = detail.astype(np.complex64)
    mixed = pywt.idwt(approx128, detail64, 'db2')
    assert_allclose(mixed, signal, rtol=1e-7, atol=1e-7)
    assert_(mixed.dtype == np.complex128)
57
58
def test_dwt_idwt_dtypes():
    """Check the output dtype of dwt and idwt for every tabulated input dtype.

    Input and expected dtypes come from the module-level ``dtypes_in`` and
    ``dtypes_out`` lists, which are paired element by element.
    """
    haar = pywt.Wavelet('haar')
    for in_dtype, expected_dtype in zip(dtypes_in, dtypes_out):
        errmsg = "wrong dtype returned for {0} input".format(in_dtype)
        ones = np.ones(4, dtype=in_dtype)

        # both coefficient arrays must carry the expected dtype
        approx, detail = pywt.dwt(ones, haar)
        assert_(approx.dtype == detail.dtype == expected_dtype,
                "dwt: " + errmsg)

        # ... and so must the reconstruction built from them
        reconstructed = pywt.idwt(approx, detail, haar)
        assert_(reconstructed.dtype == expected_dtype, "idwt: " + errmsg)
70
71
def test_dwt_idwt_basic_complex():
    """'db2' dwt/idwt of a complex signal whose imaginary part is 0.5*real.

    The reference coefficients are the real-valued reference values with the
    same 0.5 imaginary scaling applied.
    """
    def with_half_imag(values):
        # attach an imaginary part equal to half of the real part
        return values + 0.5j*values

    signal = with_half_imag(np.asarray([3, 7, 1, 1, -2, 5, 4, 6]))
    approx_ref = with_half_imag(
        np.asarray([5.65685425, 7.39923721, 0.22414387, 3.33677403,
                    7.77817459]))
    detail_ref = with_half_imag(
        np.asarray([-2.44948974, -1.60368225, -4.44140056, -0.41361256,
                    1.22474487]))

    approx, detail = pywt.dwt(signal, 'db2')
    assert_allclose(approx, approx_ref)
    assert_allclose(detail, detail_ref)

    # inverse transform recovers the complex input
    assert_allclose(pywt.idwt(approx, detail, 'db2'), signal, rtol=1e-10)
87
88
def test_dwt_idwt_partial_complex():
    """Reconstruct a complex signal from one 'haar' coefficient set at a time."""
    base = np.asarray([3, 7, 1, 1, -2, 5, 4, 6])
    signal = base + 0.5j*base
    approx, detail = pywt.dwt(signal, 'haar')

    # approximation only: the detail coefficients are passed as None
    from_approx = pywt.idwt(approx, None, 'haar')
    from_approx_ref = np.array([5.0+2.5j, 5.0+2.5j, 1.0+0.5j, 1.0+0.5j,
                                1.5+0.75j, 1.5+0.75j, 5.0+2.5j, 5.0+2.5j])
    assert_allclose(from_approx, from_approx_ref)

    # detail only: the approximation coefficients are passed as None
    from_detail = pywt.idwt(None, detail, 'haar')
    from_detail_ref = np.array([-2.0-1.0j, 2.0+1.0j, 0.0+0.0j, 0.0+0.0j,
                                -3.5-1.75j, 3.5+1.75j, -1.0-0.5j, 1.0+0.5j])
    assert_allclose(from_detail, from_detail_ref)

    # the two partial reconstructions add up to the original signal
    assert_allclose(from_approx + from_detail, signal)
105
106
def test_dwt_wavelet_kwd():
    """dwt called with a Wavelet object passed via the ``wavelet`` keyword."""
    signal = np.array([3, 7, 1, 1, -2, 5, 4, 6])
    approx_ref = [4.38354585, 3.80302657, 7.31813271, -0.58565539, 4.09727044,
                  7.81994027]
    detail_ref = [-1.33068221, -2.78795192, -3.16825651, -0.67715519,
                  -0.09722957, -0.07045258]

    sym3 = pywt.Wavelet('sym3')
    approx, detail = pywt.dwt(signal, wavelet=sym3, mode='constant')

    assert_allclose(approx, approx_ref)
    assert_allclose(detail, detail_ref)
117
118
def test_dwt_coeff_len():
    """dwt_coeff_len for every mode, given a filter length or a Wavelet."""
    signal = np.array([3, 7, 1, 1, -2, 5, 4, 6])
    sym3 = pywt.Wavelet('sym3')
    all_modes = pywt.Modes.modes

    # every mode is expected to give 6 coefficients, except periodization
    # which is expected to give 4
    expected_result = [6, ] * len(all_modes)
    expected_result[all_modes.index('periodization')] = 4

    # the filter may be described by its integer length or by the Wavelet
    for filter_spec in (sym3.dec_len, sym3):
        lengths = [pywt.dwt_coeff_len(len(signal), filter_spec, mode)
                   for mode in all_modes]
        assert_allclose(lengths, expected_result)
132
133
def test_idwt_none_input():
    """A None coefficient argument to idwt acts like zeros of matching length."""
    coeffs = [1, 2, 0, 1]
    zeros = [0, 0, 0, 0]

    # None in place of the detail coefficients
    with_none = pywt.idwt(coeffs, None, 'db2', 'symmetric')
    with_zeros = pywt.idwt(coeffs, zeros, 'db2', 'symmetric')
    assert_allclose(with_none, with_zeros, rtol=1e-15, atol=1e-15)

    # None in place of the approximation coefficients
    with_none = pywt.idwt(None, coeffs, 'db2', 'symmetric')
    with_zeros = pywt.idwt(zeros, coeffs, 'db2', 'symmetric')
    assert_allclose(with_none, with_zeros, rtol=1e-15, atol=1e-15)

    # Only one argument at a time can be None
    with assert_raises(ValueError):
        pywt.idwt(None, None, 'db2', 'symmetric')
146
147
def test_idwt_invalid_input():
    """idwt raises ValueError when the coefficient arrays are too short."""
    # Too short, min length is 4 for 'db4':
    short_approx = [1, 2, 4]
    short_detail = [4, 1, 3]
    with assert_raises(ValueError):
        pywt.idwt(short_approx, short_detail, 'db4', 'symmetric')
151
152
def test_dwt_single_axis():
    """dwt along the last axis of 2D data matches a separate dwt of each row."""
    rows = [[3, 7, 1, 1],
            [-2, 5, 4, 6]]

    approx, detail = pywt.dwt(rows, 'db2', axis=-1)

    # transform each row on its own and compare with the matching slice
    for idx, row in enumerate(rows):
        row_approx, row_detail = pywt.dwt(row, 'db2')
        assert_allclose(approx[idx], row_approx)
        assert_allclose(detail[idx], row_detail)
167
168
def test_idwt_single_axis():
    """Row-wise idwt recovers each row of complex 2D data transformed by dwt."""
    data = np.asarray([[3, 7, 1, 1],
                       [-2, 5, 4, 6]])
    data = data + 1j*data   # test with complex data
    approx, detail = pywt.dwt(data, 'db2', axis=-1)

    # invert one row of coefficients at a time
    for row, row_approx, row_detail in zip(data, approx, detail):
        recovered = pywt.idwt(row_approx, row_detail, 'db2', axis=-1)
        assert_allclose(row, recovered)
182
def test_dwt_invalid_input():
    """dwt raises ValueError for a length-1 input with these wavelet/mode pairs."""
    single_sample = np.arange(1)
    rejected = [('db2', 'reflect'), ('haar', 'antireflect')]
    for wavelet_name, mode in rejected:
        assert_raises(ValueError, pywt.dwt, single_sample, wavelet_name, mode)
187
188
def test_dwt_axis_arg():
    """For 2D input, dwt gives the same coefficients for axis=-1 and axis=1."""
    data = [[3, 7, 1, 1],
            [-2, 5, 4, 6]]

    approx_neg, detail_neg = pywt.dwt(data, 'db2', axis=-1)
    approx_pos, detail_pos = pywt.dwt(data, 'db2', axis=1)

    assert_allclose(approx_neg, approx_pos)
    assert_allclose(detail_neg, detail_pos)
198
def test_dwt_axis_invalid_input():
    """dwt raises ValueError for a (3, 1) array in 'reflect' mode.

    No axis is passed, so the transform runs over dwt's default axis
    (presumably the last, length-1 one -- confirm against the pywt docs).
    """
    column = np.ones((3, 1))
    with assert_raises(ValueError):
        pywt.dwt(column, 'db2', 'reflect')
202
def test_idwt_axis_arg():
    """For 2D coefficients, idwt gives the same result for axis=-1 and axis=1."""
    data = [[3, 7, 1, 1],
            [-2, 5, 4, 6]]
    approx, detail = pywt.dwt(data, 'db2', axis=1)

    # the same axis, addressed with a negative and a positive index
    rec_negative = pywt.idwt(approx, detail, 'db2', axis=-1)
    rec_positive = pywt.idwt(approx, detail, 'db2', axis=1)

    assert_allclose(rec_negative, rec_positive)
213
214
def test_dwt_idwt_axis_excess():
    """dwt and idwt raise ValueError for an axis the input does not have."""
    data_2d = [[3, 7, 1, 1],
               [-2, 5, 4, 6]]

    # can't transform over axes that aren't there
    with assert_raises(ValueError):
        pywt.dwt(data_2d, 'db2', 'symmetric', axis=2)

    # 1D coefficient lists have no axis 1 either
    with assert_raises(ValueError):
        pywt.idwt([1, 2, 4], [4, 1, 3], 'db2', 'symmetric', axis=1)
224
225
def test_error_on_continuous_wavelet():
    """dwt and idwt raise ValueError when a continuous wavelet is selected.

    The wavelet is supplied both as the name 'morl' and as the object
    returned by ``pywt.DiscreteContinuousWavelet('morl')``.
    """
    data = np.ones((32, ))

    # Coefficients fed to idwt below.  They come from a 'db1' transform and
    # do not depend on the wavelet under test, so compute them once instead
    # of on every loop iteration.
    cA, cD = pywt.dwt(data, 'db1')

    for cwave in ['morl', pywt.DiscreteContinuousWavelet('morl')]:
        assert_raises(ValueError, pywt.dwt, data, cwave)
        assert_raises(ValueError, pywt.idwt, cA, cD, cwave)
234
235
def test_dwt_zero_size_axes():
    """dwt raises ValueError when the input has a size-zero axis."""
    # raise on empty input array
    empty_1d = []
    assert_raises(ValueError, pywt.dwt, empty_1d, 'db2')

    # >1D case uses a different code path so check there as well
    empty_2d = np.ones((1, 4))[0:0, :]  # 2D with a size zero axis
    assert_raises(ValueError, pywt.dwt, empty_2d, 'db2', axis=0)
243
244
def test_pad_1d():
    """pywt.pad on a 1D signal: every mode, then the pad_width spellings."""
    signal = [1, 2, 3]

    # expected output of padding by (4, 6) in each mode
    expected_by_mode = [
        ('periodization', [1, 2, 3, 3, 1, 2, 3, 3, 1, 2, 3, 3, 1, 2]),
        ('periodic', [3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3]),
        ('constant', [1, 1, 1, 1, 1, 2, 3, 3, 3, 3, 3, 3, 3]),
        ('zero', [0, 0, 0, 0, 1, 2, 3, 0, 0, 0, 0, 0, 0]),
        ('smooth', [-3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9]),
        ('symmetric', [3, 3, 2, 1, 1, 2, 3, 3, 2, 1, 1, 2, 3]),
        ('antisymmetric', [3, -3, -2, -1, 1, 2, 3, -3, -2, -1, 1, 2, 3]),
        ('reflect', [1, 2, 3, 2, 1, 2, 3, 2, 1, 2, 3, 2, 1]),
        ('antireflect', [-3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9]),
    ]
    for mode, expected in expected_by_mode:
        assert_array_equal(pywt.pad(signal, (4, 6), mode), expected)

    # equivalence of various pad_width formats: a scalar, a 1-tuple and a
    # list holding one pair must all match the explicit (4, 4) pair
    reference = pywt.pad(signal, (4, 4), 'periodic')
    for pad_width in (4, (4, ), [(4, 4)]):
        assert_array_equal(pywt.pad(signal, pad_width, 'periodic'), reference)
275
276
def test_pad_errors():
    """pywt.pad raises ValueError for invalid pad widths and mode names."""
    signal = [1, 2, 3]
    bad_arguments = [
        (-2, 'periodic'),          # negative pad width
        ((1, 1, 1), 'periodic'),   # wrong length pad width
        (2, 'bad_mode'),           # invalid mode name
    ]
    for pad_width, mode in bad_arguments:
        assert_raises(ValueError, pywt.pad, signal, pad_width, mode)
287
288
def test_pad_nd():
    """n-D pywt.pad must equal 1D padding applied along each axis in turn."""
    for ndim in [2, 3]:
        data = np.arange(4**ndim).reshape((4, ) * ndim)
        # the 2D case uses a different width pair on each axis
        pad_widths = [(2, 1), (2, 3)] if ndim == 2 else [(2, 1), ] * ndim

        for mode in pywt.Modes.modes:
            padded = pywt.pad(data, pad_widths, mode)

            # expected result is the same as applying along axes separably
            expected = data.copy()
            for axis, axis_width in enumerate(pad_widths):
                expected = np.apply_along_axis(pywt.pad, axis, expected,
                                               pad_widths=[axis_width],
                                               mode=mode)
            assert_array_equal(padded, expected)
308