#!/usr/bin/env python
"""Tests for single-level ``pywt.dwt`` / ``pywt.idwt`` and ``pywt.pad``."""
from __future__ import division, print_function, absolute_import

import numpy as np
from numpy.testing import (assert_allclose, assert_, assert_raises,
                           assert_array_equal)
import pywt

# Expected output dtype for each input dtype (the two lists pair up by index):
#   - float32, float64, complex64 and complex128 are preserved;
#   - float16 is promoted to float32;
#   - other real types (e.g. int8) are converted to float64;
#   - complex256, when numpy provides it, is converted to complex128.
dtypes_in = [np.int8, np.float16, np.float32, np.float64, np.complex64,
             np.complex128]
dtypes_out = [np.float64, np.float32, np.float32, np.float64, np.complex64,
              np.complex128]

# test complex256 as well if it is available
if hasattr(np, 'complex256'):
    dtypes_in += [np.complex256, ]
    dtypes_out += [np.complex128, ]


def test_dwt_idwt_basic():
    """db2 dwt matches reference coefficients and idwt round-trips."""
    x = [3, 7, 1, 1, -2, 5, 4, 6]
    cA, cD = pywt.dwt(x, 'db2')
    cA_expect = [5.65685425, 7.39923721, 0.22414387, 3.33677403, 7.77817459]
    cD_expect = [-2.44948974, -1.60368225, -4.44140056, -0.41361256,
                 1.22474487]
    assert_allclose(cA, cA_expect)
    assert_allclose(cD, cD_expect)

    x_roundtrip = pywt.idwt(cA, cD, 'db2')
    assert_allclose(x_roundtrip, x, rtol=1e-10)

    # mismatched dtypes OK; the result takes the wider (float64) dtype
    x_roundtrip2 = pywt.idwt(cA.astype(np.float64), cD.astype(np.float32),
                             'db2')
    assert_allclose(x_roundtrip2, x, rtol=1e-7, atol=1e-7)
    assert_(x_roundtrip2.dtype == np.float64)


def test_idwt_mixed_complex_dtype():
    """idwt accepts complex coefficients of differing precision."""
    x = np.arange(8).astype(float)
    x = x + 1j*x[::-1]
    cA, cD = pywt.dwt(x, 'db2')

    x_roundtrip = pywt.idwt(cA, cD, 'db2')
    assert_allclose(x_roundtrip, x, rtol=1e-10)

    # mismatched dtypes OK; the result takes the wider (complex128) dtype
    x_roundtrip2 = pywt.idwt(cA.astype(np.complex128), cD.astype(np.complex64),
                             'db2')
    assert_allclose(x_roundtrip2, x, rtol=1e-7, atol=1e-7)
    assert_(x_roundtrip2.dtype == np.complex128)


def test_dwt_idwt_dtypes():
    """dwt and idwt return the expected output dtype for each input dtype."""
    wavelet = pywt.Wavelet('haar')
    for dt_in, dt_out in zip(dtypes_in, dtypes_out):
        x = np.ones(4, dtype=dt_in)
        errmsg = "wrong dtype returned for {0} input".format(dt_in)

        cA, cD = pywt.dwt(x, wavelet)
        assert_(cA.dtype == cD.dtype == dt_out, "dwt: " + errmsg)

        x_roundtrip = pywt.idwt(cA, cD, wavelet)
        assert_(x_roundtrip.dtype == dt_out, "idwt: " + errmsg)


def test_dwt_idwt_basic_complex():
    """Complex input behaves like the real transform applied linearly."""
    x = np.asarray([3, 7, 1, 1, -2, 5, 4, 6])
    x = x + 0.5j*x
    cA, cD = pywt.dwt(x, 'db2')
    cA_expect = np.asarray([5.65685425, 7.39923721, 0.22414387, 3.33677403,
                            7.77817459])
    cA_expect = cA_expect + 0.5j*cA_expect
    cD_expect = np.asarray([-2.44948974, -1.60368225, -4.44140056, -0.41361256,
                            1.22474487])
    cD_expect = cD_expect + 0.5j*cD_expect
    assert_allclose(cA, cA_expect)
    assert_allclose(cD, cD_expect)

    x_roundtrip = pywt.idwt(cA, cD, 'db2')
    assert_allclose(x_roundtrip, x, rtol=1e-10)


def test_dwt_idwt_partial_complex():
    """Reconstructing from cA or cD alone gives parts that sum to the input."""
    x = np.asarray([3, 7, 1, 1, -2, 5, 4, 6])
    x = x + 0.5j*x

    cA, cD = pywt.dwt(x, 'haar')
    cA_rec_expect = np.array([5.0+2.5j, 5.0+2.5j, 1.0+0.5j, 1.0+0.5j,
                              1.5+0.75j, 1.5+0.75j, 5.0+2.5j, 5.0+2.5j])
    cA_rec = pywt.idwt(cA, None, 'haar')
    assert_allclose(cA_rec, cA_rec_expect)

    cD_rec_expect = np.array([-2.0-1.0j, 2.0+1.0j, 0.0+0.0j, 0.0+0.0j,
                              -3.5-1.75j, 3.5+1.75j, -1.0-0.5j, 1.0+0.5j])
    cD_rec = pywt.idwt(None, cD, 'haar')
    assert_allclose(cD_rec, cD_rec_expect)

    assert_allclose(cA_rec + cD_rec, x)


def test_dwt_wavelet_kwd():
    """dwt accepts ``wavelet`` and ``mode`` as keyword arguments."""
    x = np.array([3, 7, 1, 1, -2, 5, 4, 6])
    w = pywt.Wavelet('sym3')
    cA, cD = pywt.dwt(x, wavelet=w, mode='constant')
    cA_expect = [4.38354585, 3.80302657, 7.31813271, -0.58565539, 4.09727044,
                 7.81994027]
    cD_expect = [-1.33068221, -2.78795192, -3.16825651, -0.67715519,
                 -0.09722957, -0.07045258]
    assert_allclose(cA, cA_expect)
    assert_allclose(cD, cD_expect)


def test_dwt_coeff_len():
    """dwt_coeff_len gives 6 for every mode except 'periodization' (4).

    Both a filter length and a Wavelet object are accepted as the second
    argument and must give the same answer.
    """
    x = np.array([3, 7, 1, 1, -2, 5, 4, 6])
    w = pywt.Wavelet('sym3')
    ln_modes = [pywt.dwt_coeff_len(len(x), w.dec_len, mode) for mode in
                pywt.Modes.modes]

    expected_result = [6, ] * len(pywt.Modes.modes)
    expected_result[pywt.Modes.modes.index('periodization')] = 4

    assert_allclose(ln_modes, expected_result)
    ln_modes = [pywt.dwt_coeff_len(len(x), w, mode) for mode in
                pywt.Modes.modes]
    assert_allclose(ln_modes, expected_result)


def test_idwt_none_input():
    """A None coefficient array is treated as zeros of matching length."""
    res1 = pywt.idwt([1, 2, 0, 1], None, 'db2', 'symmetric')
    res2 = pywt.idwt([1, 2, 0, 1], [0, 0, 0, 0], 'db2', 'symmetric')
    assert_allclose(res1, res2, rtol=1e-15, atol=1e-15)

    res1 = pywt.idwt(None, [1, 2, 0, 1], 'db2', 'symmetric')
    res2 = pywt.idwt([0, 0, 0, 0], [1, 2, 0, 1], 'db2', 'symmetric')
    assert_allclose(res1, res2, rtol=1e-15, atol=1e-15)

    # Only one argument at a time can be None
    assert_raises(ValueError, pywt.idwt, None, None, 'db2', 'symmetric')


def test_idwt_invalid_input():
    """idwt rejects coefficient arrays that are too short for the wavelet."""
    # Too short, min length is 4 for 'db4':
    assert_raises(ValueError, pywt.idwt, [1, 2, 4], [4, 1, 3], 'db4',
                  'symmetric')


def test_dwt_single_axis():
    """A 2D dwt along the last axis matches per-row 1D transforms."""
    x = [[3, 7, 1, 1],
         [-2, 5, 4, 6]]

    cA, cD = pywt.dwt(x, 'db2', axis=-1)

    cA0, cD0 = pywt.dwt(x[0], 'db2')
    cA1, cD1 = pywt.dwt(x[1], 'db2')

    assert_allclose(cA[0], cA0)
    assert_allclose(cA[1], cA1)

    assert_allclose(cD[0], cD0)
    assert_allclose(cD[1], cD1)


def test_idwt_single_axis():
    """Per-row 1D idwt reconstructs the rows of a 2D (complex) dwt."""
    x = [[3, 7, 1, 1],
         [-2, 5, 4, 6]]

    x = np.asarray(x)
    x = x + 1j*x  # test with complex data
    cA, cD = pywt.dwt(x, 'db2', axis=-1)

    x0 = pywt.idwt(cA[0], cD[0], 'db2', axis=-1)
    x1 = pywt.idwt(cA[1], cD[1], 'db2', axis=-1)

    assert_allclose(x[0], x0)
    assert_allclose(x[1], x1)


def test_dwt_invalid_input():
    """dwt raises ValueError on a length-1 signal for these modes."""
    x = np.arange(1)
    assert_raises(ValueError, pywt.dwt, x, 'db2', 'reflect')
    assert_raises(ValueError, pywt.dwt, x, 'haar', 'antireflect')


def test_dwt_axis_arg():
    """axis=-1 and axis=1 are equivalent for 2D dwt input."""
    x = [[3, 7, 1, 1],
         [-2, 5, 4, 6]]

    cA_, cD_ = pywt.dwt(x, 'db2', axis=-1)
    cA, cD = pywt.dwt(x, 'db2', axis=1)

    assert_allclose(cA_, cA)
    assert_allclose(cD_, cD)


def test_dwt_axis_invalid_input():
    """dwt raises ValueError when the transformed axis has length 1."""
    x = np.ones((3, 1))
    assert_raises(ValueError, pywt.dwt, x, 'db2', 'reflect')


def test_idwt_axis_arg():
    """idwt with axis=-1 and axis=1 agree and recover the 2D input."""
    x = [[3, 7, 1, 1],
         [-2, 5, 4, 6]]

    cA, cD = pywt.dwt(x, 'db2', axis=1)

    # distinct names so the original input stays available for comparison
    x_neg_axis = pywt.idwt(cA, cD, 'db2', axis=-1)
    x_pos_axis = pywt.idwt(cA, cD, 'db2', axis=1)

    assert_allclose(x_neg_axis, x_pos_axis)
    # the reconstruction should also match the original signal
    assert_allclose(x_pos_axis, x)


def test_dwt_idwt_axis_excess():
    """dwt/idwt raise ValueError for an axis beyond the input's ndim."""
    x = [[3, 7, 1, 1],
         [-2, 5, 4, 6]]
    # can't transform over axes that aren't there
    assert_raises(ValueError,
                  pywt.dwt, x, 'db2', 'symmetric', axis=2)

    assert_raises(ValueError,
                  pywt.idwt, [1, 2, 4], [4, 1, 3], 'db2', 'symmetric', axis=1)


def test_error_on_continuous_wavelet():
    """dwt and idwt reject continuous wavelets, by name or by object."""
    # A ValueError is raised if a Continuous wavelet is selected
    data = np.ones((32, ))
    # valid coefficients to feed idwt; independent of the loop variable
    cA, cD = pywt.dwt(data, 'db1')
    for cwave in ['morl', pywt.DiscreteContinuousWavelet('morl')]:
        assert_raises(ValueError, pywt.dwt, data, cwave)
        assert_raises(ValueError, pywt.idwt, cA, cD, cwave)


def test_dwt_zero_size_axes():
    """dwt raises ValueError on empty input."""
    # raise on empty input array
    assert_raises(ValueError, pywt.dwt, [], 'db2')

    # >1D case uses a different code path so check there as well
    x = np.ones((1, 4))[0:0, :]  # 2D with a size zero axis
    assert_raises(ValueError, pywt.dwt, x, 'db2', axis=0)


def test_pad_1d():
    """pywt.pad produces the expected 1D extension for every mode."""
    x = [1, 2, 3]
    assert_array_equal(pywt.pad(x, (4, 6), 'periodization'),
                       [1, 2, 3, 3, 1, 2, 3, 3, 1, 2, 3, 3, 1, 2])
    assert_array_equal(pywt.pad(x, (4, 6), 'periodic'),
                       [3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3])
    assert_array_equal(pywt.pad(x, (4, 6), 'constant'),
                       [1, 1, 1, 1, 1, 2, 3, 3, 3, 3, 3, 3, 3])
    assert_array_equal(pywt.pad(x, (4, 6), 'zero'),
                       [0, 0, 0, 0, 1, 2, 3, 0, 0, 0, 0, 0, 0])
    assert_array_equal(pywt.pad(x, (4, 6), 'smooth'),
                       [-3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
    assert_array_equal(pywt.pad(x, (4, 6), 'symmetric'),
                       [3, 3, 2, 1, 1, 2, 3, 3, 2, 1, 1, 2, 3])
    assert_array_equal(pywt.pad(x, (4, 6), 'antisymmetric'),
                       [3, -3, -2, -1, 1, 2, 3, -3, -2, -1, 1, 2, 3])
    assert_array_equal(pywt.pad(x, (4, 6), 'reflect'),
                       [1, 2, 3, 2, 1, 2, 3, 2, 1, 2, 3, 2, 1])
    assert_array_equal(pywt.pad(x, (4, 6), 'antireflect'),
                       [-3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9])

    # equivalence of various pad_width formats
    assert_array_equal(pywt.pad(x, 4, 'periodic'),
                       pywt.pad(x, (4, 4), 'periodic'))

    assert_array_equal(pywt.pad(x, (4, ), 'periodic'),
                       pywt.pad(x, (4, 4), 'periodic'))

    assert_array_equal(pywt.pad(x, [(4, 4)], 'periodic'),
                       pywt.pad(x, (4, 4), 'periodic'))


def test_pad_errors():
    """pywt.pad rejects invalid pad widths and unknown modes."""
    # negative pad width
    x = [1, 2, 3]
    assert_raises(ValueError, pywt.pad, x, -2, 'periodic')

    # wrong length pad width
    assert_raises(ValueError, pywt.pad, x, (1, 1, 1), 'periodic')

    # invalid mode name
    assert_raises(ValueError, pywt.pad, x, 2, 'bad_mode')


def test_pad_nd():
    """n-D pywt.pad equals padding each axis separately, for every mode."""
    for ndim in [2, 3]:
        x = np.arange(4**ndim).reshape((4, ) * ndim)
        if ndim == 2:
            pad_widths = [(2, 1), (2, 3)]
        else:
            pad_widths = [(2, 1), ] * ndim
        for mode in pywt.Modes.modes:
            xp = pywt.pad(x, pad_widths, mode)

            # expected result is the same as applying along axes separably
            xp_expected = x.copy()
            for ax in range(ndim):
                xp_expected = np.apply_along_axis(pywt.pad,
                                                  ax,
                                                  xp_expected,
                                                  pad_widths=[pad_widths[ax]],
                                                  mode=mode)
            assert_array_equal(xp, xp_expected)