#!/usr/bin/env python
"""Tests for the 2D wavelet packet transform (``pywt.WaveletPacket2D``)."""

import os
import pickle

import numpy as np
from numpy.testing import (assert_allclose, assert_, assert_raises,
                           assert_equal)

import pywt


def test_traversing_tree_2d():
    """Node access by path string, node values and invalid paths."""
    x = np.array([[1, 2, 3, 4, 5, 6, 7, 8]] * 8, dtype=np.float64)
    wp = pywt.WaveletPacket2D(data=x, wavelet='db1', mode='symmetric')

    assert_(np.all(wp.data == x))
    assert_(wp.path == '')
    assert_(wp.level == 0)
    assert_(wp.maxlevel == 3)

    assert_allclose(wp['a'].data, np.array([[3., 7., 11., 15.]] * 4),
                    rtol=1e-12)
    assert_allclose(wp['h'].data, np.zeros((4, 4)), rtol=1e-12, atol=1e-14)
    assert_allclose(wp['v'].data, -np.ones((4, 4)), rtol=1e-12, atol=1e-14)
    assert_allclose(wp['d'].data, np.zeros((4, 4)), rtol=1e-12, atol=1e-14)

    assert_allclose(wp['aa'].data, np.array([[10., 26.]] * 2), rtol=1e-12)

    # chained single-letter indexing reaches the same node as a path string
    assert_(wp['a']['a'].data is wp['aa'].data)
    assert_allclose(wp['aaa'].data, np.array([[36.]]), rtol=1e-12)

    # a path deeper than maxlevel (3) is rejected
    assert_raises(IndexError, lambda: wp['aaaa'])
    # 'f' is not one of the 2D node names ('a', 'h', 'v', 'd')
    assert_raises(ValueError, lambda: wp['f'])


def test_accessing_node_attributes_2d():
    """Node attributes and tuple-based node access."""
    x = np.array([[1, 2, 3, 4, 5, 6, 7, 8]] * 8, dtype=np.float64)
    wp = pywt.WaveletPacket2D(data=x, wavelet='db1', mode='symmetric')

    assert_allclose(wp['av'].data, np.zeros((2, 2)) - 4, rtol=1e-12)
    assert_(wp['av'].path == 'av')
    assert_(wp['av'].node_name == 'v')
    assert_(wp['av'].parent.path == 'a')

    assert_allclose(wp['av'].parent.data, np.array([[3., 7., 11., 15.]] * 4),
                    rtol=1e-12)
    assert_(wp['av'].level == 2)
    assert_(wp['av'].maxlevel == 3)
    assert_(wp['av'].mode == 'symmetric')

    # can also index via a tuple instead of a concatenated string
    node = wp[('a', 'v')]
    # a node's path is available both as a single string and in tuple form
    assert_(node.path == 'av')
    assert_(node.path_tuple == ('a', 'v'))


def test_collecting_nodes_2d():
    """get_level returns nodes in natural and in frequency order."""
    x = np.array([[1, 2, 3, 4, 5, 6, 7, 8]] * 8, dtype=np.float64)
    wp = pywt.WaveletPacket2D(data=x, wavelet='db1', mode='symmetric')

    assert_(len(wp.get_level(0)) == 1)
    assert_(wp.get_level(0)[0].path == '')

    # First level
    assert_(len(wp.get_level(1)) == 4)
    assert_([node.path for node in wp.get_level(1)] == ['a', 'h', 'v', 'd'])

    # Second level
    assert_(len(wp.get_level(2)) == 16)
    paths = [node.path for node in wp.get_level(2)]
    expected_paths = ['aa', 'ah', 'av', 'ad', 'ha', 'hh', 'hv', 'hd', 'va',
                      'vh', 'vv', 'vd', 'da', 'dh', 'dv', 'dd']
    assert_(paths == expected_paths)

    # Third level.
    assert_(len(wp.get_level(3)) == 64)
    paths = [node.path for node in wp.get_level(3)]
    expected_paths = ['aaa', 'aah', 'aav', 'aad', 'aha', 'ahh', 'ahv', 'ahd',
                      'ava', 'avh', 'avv', 'avd', 'ada', 'adh', 'adv', 'add',
                      'haa', 'hah', 'hav', 'had', 'hha', 'hhh', 'hhv', 'hhd',
                      'hva', 'hvh', 'hvv', 'hvd', 'hda', 'hdh', 'hdv', 'hdd',
                      'vaa', 'vah', 'vav', 'vad', 'vha', 'vhh', 'vhv', 'vhd',
                      'vva', 'vvh', 'vvv', 'vvd', 'vda', 'vdh', 'vdv', 'vdd',
                      'daa', 'dah', 'dav', 'dad', 'dha', 'dhh', 'dhv', 'dhd',
                      'dva', 'dvh', 'dvv', 'dvd', 'dda', 'ddh', 'ddv', 'ddd']

    assert_(paths == expected_paths)

    # test 2D frequency ordering at the first level
    fnodes = wp.get_level(1, order='freq')
    assert_(fnodes[0][0].path == 'a')
    assert_(fnodes[0][1].path == 'v')
    assert_(fnodes[1][0].path == 'h')
    assert_(fnodes[1][1].path == 'd')

    # test 2D frequency ordering at the second level
    fnodes = wp.get_level(2, order='freq')
    assert_([n.path for n in fnodes[0]] == ['aa', 'av', 'vv', 'va'])
    assert_([n.path for n in fnodes[1]] == ['ah', 'ad', 'vd', 'vh'])
    assert_([n.path for n in fnodes[2]] == ['hh', 'hd', 'dd', 'dh'])
    assert_([n.path for n in fnodes[3]] == ['ha', 'hv', 'dv', 'da'])

    # invalid node collection order
    assert_raises(ValueError, wp.get_level, 2, 'invalid_order')


def test_data_reconstruction_2d():
    """Reconstruct a signal from a packet tree built by node assignment."""
    x = np.array([[1, 2, 3, 4, 5, 6, 7, 8]] * 8, dtype=np.float64)
    wp = pywt.WaveletPacket2D(data=x, wavelet='db1', mode='symmetric')

    new_wp = pywt.WaveletPacket2D(data=None, wavelet='db1', mode='symmetric')
    new_wp['vh'] = wp['vh'].data
    new_wp['vv'] = wp['vv'].data
    new_wp['vd'] = np.zeros((2, 2), dtype=np.float64)
    new_wp['a'] = [[3.0, 7.0, 11.0, 15.0]] * 4
    new_wp['d'] = np.zeros((4, 4), dtype=np.float64)
    new_wp['h'] = wp['h']  # all zeros

    # without the 'va' node the reconstruction is only approximate
    assert_allclose(new_wp.reconstruct(update=False),
                    np.array([[1.5, 1.5, 3.5, 3.5, 5.5, 5.5, 7.5, 7.5]] * 8),
                    rtol=1e-12)
    assert_allclose(wp['va'].data, np.zeros((2, 2)) - 2, rtol=1e-12)

    new_wp['va'] = wp['va'].data
    assert_allclose(new_wp.reconstruct(update=False), x, rtol=1e-12)


def test_data_reconstruction_delete_nodes_2d():
    """Deleted nodes can be reassigned and used for reconstruction."""
    x = np.array([[1, 2, 3, 4, 5, 6, 7, 8]] * 8, dtype=np.float64)
    wp = pywt.WaveletPacket2D(data=x, wavelet='db1', mode='symmetric')

    new_wp = pywt.WaveletPacket2D(data=None, wavelet='db1', mode='symmetric')
    new_wp['vh'] = wp['vh'].data
    new_wp['vv'] = wp['vv'].data
    new_wp['vd'] = np.zeros((2, 2), dtype=np.float64)
    new_wp['a'] = [[3.0, 7.0, 11.0, 15.0]] * 4
    new_wp['d'] = np.zeros((4, 4), dtype=np.float64)
    new_wp['h'] = wp['h']  # all zeros

    assert_allclose(new_wp.reconstruct(update=False),
                    np.array([[1.5, 1.5, 3.5, 3.5, 5.5, 5.5, 7.5, 7.5]] * 8),
                    rtol=1e-12)

    new_wp['va'] = wp['va'].data
    assert_allclose(new_wp.reconstruct(update=False), x, rtol=1e-12)

    del new_wp['va']
    # TypeError on accessing deleted node
    assert_raises(TypeError, lambda: new_wp['va'])
    new_wp['va'] = wp['va'].data
    assert_(new_wp.data is None)

    # update=True stores the reconstruction on the root node
    assert_allclose(new_wp.reconstruct(update=True), x, rtol=1e-12)
    assert_allclose(new_wp.data, x, rtol=1e-12)

    # TODO: decompose=True


def test_lazy_evaluation_2D():
    """Child nodes are only created when first accessed."""
    # Note: internal implementation detail not to be relied on. Tested for
    # now for backwards compatibility, but this test may be broken if needed.
    x = np.array([[1, 2, 3, 4, 5, 6, 7, 8]] * 8)
    wp = pywt.WaveletPacket2D(data=x, wavelet='db1', mode='symmetric')

    assert_(wp.a is None)
    assert_allclose(wp['a'].data, np.array([[3., 7., 11., 15.]] * 4),
                    rtol=1e-12)
    assert_allclose(wp.a.data, np.array([[3., 7., 11., 15.]] * 4), rtol=1e-12)
    assert_allclose(wp.d.data, np.zeros((4, 4)), rtol=1e-12, atol=1e-12)


def test_wavelet_packet_dtypes():
    """Node assignment and reconstruction preserve the input dtype."""
    # seeded so the test is deterministic
    rstate = np.random.RandomState(0)
    shape = (16, 16)
    for dtype in [np.float32, np.float64, np.complex64, np.complex128]:
        x = rstate.standard_normal(shape).astype(dtype)
        if np.iscomplexobj(x):
            x = x + 1j*rstate.standard_normal(shape).astype(x.real.dtype)
        wp = pywt.WaveletPacket2D(data=x, wavelet='db1', mode='symmetric')
        # no unnecessary copy made
        assert_(wp.data is x)

        # assigning to a node should not change supported dtypes
        wp['d'] = wp['d'].data
        assert_equal(wp['d'].data.dtype, x.dtype)

        # full decomposition
        wp.get_level(wp.maxlevel)

        # reconstruction from coefficients should preserve dtype
        r = wp.reconstruct(False)
        assert_equal(r.dtype, x.dtype)
        assert_allclose(r, x, atol=1e-5, rtol=1e-5)


def test_2d_roundtrip():
    """Decomposition followed by reconstruction recovers the input."""
    # test case corresponding to PyWavelets issue 447
    original = pywt.data.camera()
    wp = pywt.WaveletPacket2D(data=original, wavelet='db3', mode='smooth',
                              maxlevel=3)
    r = wp.reconstruct()
    assert_allclose(original, r, atol=1e-12, rtol=1e-12)


def test_wavelet_packet_axes():
    """The ``axes`` argument selects the two axes that are transformed."""
    rstate = np.random.RandomState(0)
    shape = (32, 16)
    x = rstate.standard_normal(shape)
    for axes in [(0, 1), (1, 0), (-2, 1)]:
        wp = pywt.WaveletPacket2D(data=x, wavelet='db1', mode='symmetric',
                                  axes=axes)

        # partial decomposition
        nodes = wp.get_level(2)
        # size along the transformed axes has shrunk; other axes unchanged
        transformed = tuple(np.asarray(axes) % x.ndim)
        for ax2 in range(x.ndim):
            if ax2 in transformed:
                assert_(nodes[0].data.shape[ax2] < x.shape[ax2])
            else:
                assert_(nodes[0].data.shape[ax2] == x.shape[ax2])

        # reconstruction from coefficients should preserve dtype
        r = wp.reconstruct(False)
        assert_equal(r.dtype, x.dtype)
        assert_allclose(r, x, atol=1e-12, rtol=1e-12)

    # must have two non-duplicate axes
    assert_raises(ValueError, pywt.WaveletPacket2D, data=x, wavelet='db1',
                  axes=(0, 0))
    assert_raises(ValueError, pywt.WaveletPacket2D, data=x, wavelet='db1',
                  axes=(0, ))
    assert_raises(ValueError, pywt.WaveletPacket2D, data=x, wavelet='db1',
                  axes=(0, 1, 2))


def test_wavelet_packet2d_pickle(tmpdir):
    """A WaveletPacket2D survives a pickle round trip through a file."""
    packet = pywt.WaveletPacket2D(np.arange(256).reshape(16, 16), 'sym4')
    filename = os.path.join(tmpdir, 'wp2d.pickle')
    with open(filename, 'wb') as f:
        pickle.dump(packet, f)
    with open(filename, 'rb') as f:
        packet2 = pickle.load(f)
    assert isinstance(packet2, pywt.WaveletPacket2D)
    # the unpickled packet carries the same data and settings
    assert_equal(packet2.data, packet.data)
    assert_(packet2.mode == packet.mode)
    assert_(packet2.maxlevel == packet.maxlevel)