1#!/usr/bin/env python
2
3import os
4import pickle
5
6import numpy as np
7from numpy.testing import (assert_allclose, assert_, assert_raises,
8                           assert_equal)
9
10import pywt
11
12
def test_traversing_tree_2d():
    """Walk a 3-level db1 packet tree and check node data at every depth."""
    row = [1, 2, 3, 4, 5, 6, 7, 8]
    x = np.array([row] * 8, dtype=np.float64)
    wp = pywt.WaveletPacket2D(data=x, wavelet='db1', mode='symmetric')

    # Root node: holds the input and sits at level 0 of a 3-level tree.
    assert_(np.all(wp.data == x))
    assert_(wp.path == '')
    assert_(wp.level == 0)
    assert_(wp.maxlevel == 3)

    # First level: approximation plus the three detail orientations.
    approx_1 = np.array([[3., 7., 11., 15.]] * 4)
    assert_allclose(wp['a'].data, approx_1, rtol=1e-12)
    details_1 = (('h', np.zeros((4, 4))),
                 ('v', -np.ones((4, 4))),
                 ('d', np.zeros((4, 4))))
    for part, expected in details_1:
        assert_allclose(wp[part].data, expected, rtol=1e-12, atol=1e-14)

    assert_allclose(wp['aa'].data, np.array([[10., 26.]] * 2), rtol=1e-12)

    # Chained single-letter indexing yields the same data array as the
    # concatenated path.
    assert_(wp['a']['a'].data is wp['aa'].data)
    assert_allclose(wp['aaa'].data, np.array([[36.]]), rtol=1e-12)

    # Too-deep paths and unknown node names are both rejected.
    assert_raises(IndexError, lambda: wp['aaaa'])
    assert_raises(ValueError, lambda: wp['f'])
35
36
def test_accessing_node_attributes_2d():
    """Check path, name, parent, level and mode attributes of a 2D node."""
    x = np.array([[1, 2, 3, 4, 5, 6, 7, 8]] * 8, dtype=np.float64)
    wp = pywt.WaveletPacket2D(data=x, wavelet='db1', mode='symmetric')

    assert_allclose(wp['av'].data, np.zeros((2, 2)) - 4, rtol=1e-12)
    assert_(wp['av'].path == 'av')
    assert_(wp['av'].node_name == 'v')
    assert_(wp['av'].parent.path == 'a')

    assert_allclose(wp['av'].parent.data, np.array([[3., 7., 11., 15.]] * 4),
                    rtol=1e-12)
    # level/maxlevel/mode are exposed on every node, not only on the root
    assert_(wp['av'].level == 2)
    assert_(wp['av'].maxlevel == 3)
    assert_(wp['av'].mode == 'symmetric')

    # tuple-based access is also supported
    node = wp[('a', 'v')]
    # can access a node's path as either a single string or in tuple form
    assert_(node.path == 'av')
    assert_(node.path_tuple == ('a', 'v'))
58
59
def test_collecting_nodes_2d():
    """Check natural and frequency ordering of nodes returned by get_level."""
    x = np.array([[1, 2, 3, 4, 5, 6, 7, 8]] * 8, dtype=np.float64)
    wp = pywt.WaveletPacket2D(data=x, wavelet='db1', mode='symmetric')
    # Natural ordering of the four subbands produced by each 2D split.
    names = 'ahvd'

    root_level = wp.get_level(0)
    assert_(len(root_level) == 1)
    assert_(root_level[0].path == '')

    # First level
    level1 = [node.path for node in wp.get_level(1)]
    assert_(len(level1) == 4)
    assert_(level1 == ['a', 'h', 'v', 'd'])

    # Second level: natural order is the lexicographic product of `names`.
    level2 = [node.path for node in wp.get_level(2)]
    assert_(len(level2) == 16)
    assert_(level2 == [p + q for p in names for q in names])

    # Third level.
    level3 = [node.path for node in wp.get_level(3)]
    assert_(len(level3) == 64)
    assert_(level3 == [p + q + r
                       for p in names for q in names for r in names])

    # 2D frequency ordering at the first level is a 2x2 grid of nodes.
    freq_level1 = wp.get_level(1, order='freq')
    expected_grid = {(0, 0): 'a', (0, 1): 'v', (1, 0): 'h', (1, 1): 'd'}
    for (row, col), path in expected_grid.items():
        assert_(freq_level1[row][col].path == path)

    # 2D frequency ordering at the second level, checked one row at a time.
    freq_level2 = wp.get_level(2, order='freq')
    expected_rows = (['aa', 'av', 'vv', 'va'],
                     ['ah', 'ad', 'vd', 'vh'],
                     ['hh', 'hd', 'dd', 'dh'],
                     ['ha', 'hv', 'dv', 'da'])
    for idx, expected in enumerate(expected_rows):
        assert_([n.path for n in freq_level2[idx]] == expected)

    # An unrecognised ordering keyword is rejected.
    assert_raises(ValueError, wp.get_level, 2, 'invalid_order')
108
109
def test_data_reconstruction_2d():
    """Rebuild a signal from a hand-assembled, partially populated tree.

    Leaves are copied from a reference decomposition ``wp`` into an empty
    packet ``new_wp``. Without ``'va'`` the reconstruction is a smoothed
    version of ``x``; once ``'va'`` is added the input is recovered exactly.
    """
    x = np.array([[1, 2, 3, 4, 5, 6, 7, 8]] * 8, dtype=np.float64)
    wp = pywt.WaveletPacket2D(data=x, wavelet='db1', mode='symmetric')

    new_wp = pywt.WaveletPacket2D(data=None, wavelet='db1', mode='symmetric')
    new_wp['vh'] = wp['vh'].data
    # copy the matching 'vv' coefficients (previously copied from 'vh';
    # both happen to be zero for this input)
    new_wp['vv'] = wp['vv'].data
    new_wp['vd'] = np.zeros((2, 2), dtype=np.float64)
    new_wp['a'] = [[3.0, 7.0, 11.0, 15.0]] * 4
    new_wp['d'] = np.zeros((4, 4), dtype=np.float64)
    new_wp['h'] = wp['h']       # all zeros

    assert_allclose(new_wp.reconstruct(update=False),
                    np.array([[1.5, 1.5, 3.5, 3.5, 5.5, 5.5, 7.5, 7.5]] * 8),
                    rtol=1e-12)
    assert_allclose(wp['va'].data, np.zeros((2, 2)) - 2, rtol=1e-12)

    new_wp['va'] = wp['va'].data
    assert_allclose(new_wp.reconstruct(update=False), x, rtol=1e-12)
129
130
def test_data_reconstruction_delete_nodes_2d():
    """Delete and re-add a leaf, then reconstruct with ``update=True``.

    After ``del new_wp['va']`` the path can no longer be indexed (TypeError).
    Re-adding the leaf restores an exact reconstruction. ``update=True`` also
    stores the result on the root node.
    """
    x = np.array([[1, 2, 3, 4, 5, 6, 7, 8]] * 8, dtype=np.float64)
    wp = pywt.WaveletPacket2D(data=x, wavelet='db1', mode='symmetric')

    new_wp = pywt.WaveletPacket2D(data=None, wavelet='db1', mode='symmetric')
    new_wp['vh'] = wp['vh'].data
    # copy the matching 'vv' coefficients (previously copied from 'vh';
    # both happen to be zero for this input)
    new_wp['vv'] = wp['vv'].data
    new_wp['vd'] = np.zeros((2, 2), dtype=np.float64)
    new_wp['a'] = [[3.0, 7.0, 11.0, 15.0]] * 4
    new_wp['d'] = np.zeros((4, 4), dtype=np.float64)
    new_wp['h'] = wp['h']       # all zeros

    assert_allclose(new_wp.reconstruct(update=False),
                    np.array([[1.5, 1.5, 3.5, 3.5, 5.5, 5.5, 7.5, 7.5]] * 8),
                    rtol=1e-12)

    new_wp['va'] = wp['va'].data
    assert_allclose(new_wp.reconstruct(update=False), x, rtol=1e-12)

    del new_wp['va']
    # TypeError on accessing deleted node
    assert_raises(TypeError, lambda: new_wp['va'])
    new_wp['va'] = wp['va'].data
    # update=False above must not have stored anything on the root
    assert_(new_wp.data is None)

    assert_allclose(new_wp.reconstruct(update=True), x, rtol=1e-12)
    assert_allclose(new_wp.data, x, rtol=1e-12)

    # TODO: decompose=True
160
161
def test_lazy_evaluation_2D():
    """Child nodes are created only when first accessed through indexing."""
    # Note: internal implementation detail not to be relied on.  Testing for
    # now for backwards compatibility, but this test may be broken if needed.
    x = np.array([[1, 2, 3, 4, 5, 6, 7, 8]] * 8)
    wp = pywt.WaveletPacket2D(data=x, wavelet='db1', mode='symmetric')

    # nothing decomposed yet
    assert_(wp.a is None)
    # indexing triggers decomposition ...
    assert_allclose(wp['a'].data, np.array([[3., 7., 11., 15.]] * 4),
                    rtol=1e-12)
    # ... after which the child attributes are populated
    assert_allclose(wp.a.data, np.array([[3., 7., 11., 15.]] * 4), rtol=1e-12)
    assert_allclose(wp.d.data, np.zeros((4, 4)), rtol=1e-12, atol=1e-12)
173
174
def test_wavelet_packet_dtypes():
    """Real and complex float dtypes survive decomposition and reconstruction.

    Uses a seeded RandomState so any failure is reproducible from run to run.
    """
    rstate = np.random.RandomState(0)
    shape = (16, 16)
    for dtype in [np.float32, np.float64, np.complex64, np.complex128]:
        x = rstate.standard_normal(shape).astype(dtype)
        if np.iscomplexobj(x):
            # give complex inputs a non-trivial imaginary part as well
            x = x + 1j*rstate.standard_normal(shape).astype(x.real.dtype)
        wp = pywt.WaveletPacket2D(data=x, wavelet='db1', mode='symmetric')
        # no unnecessary copy made
        assert_(wp.data is x)

        # assigning to a node should not change supported dtypes
        wp['d'] = wp['d'].data
        assert_equal(wp['d'].data.dtype, x.dtype)

        # full decomposition
        wp.get_level(wp.maxlevel)

        # reconstruction from coefficients should preserve dtype
        r = wp.reconstruct(False)
        assert_equal(r.dtype, x.dtype)
        assert_allclose(r, x, atol=1e-5, rtol=1e-5)
196
197
def test_2d_roundtrip():
    """Regression for PyWavelets issue 447: db3/'smooth' perfect recon."""
    image = pywt.data.camera()
    packet = pywt.WaveletPacket2D(data=image, wavelet='db3', mode='smooth',
                                  maxlevel=3)
    assert_allclose(image, packet.reconstruct(), atol=1e-12, rtol=1e-12)
205
206
def test_wavelet_packet_axes():
    """Decompose/reconstruct over explicit axis pairs and validate ``axes``."""
    rstate = np.random.RandomState(0)
    shape = (32, 16)
    x = rstate.standard_normal(shape)
    for axes in [(0, 1), (1, 0), (-2, 1)]:
        wp = pywt.WaveletPacket2D(data=x, wavelet='db1', mode='symmetric',
                                  axes=axes)

        # partial decomposition
        nodes = wp.get_level(2)
        # normalise negative axis indices once, outside the per-axis loop
        transformed = {ax % x.ndim for ax in axes}
        for ax2 in range(x.ndim):
            if ax2 in transformed:
                # size along a transformed axis has shrunk
                assert_(nodes[0].data.shape[ax2] < x.shape[ax2])
            else:
                # axes not being transformed keep their original length
                assert_equal(nodes[0].data.shape[ax2], x.shape[ax2])

        # reconstruction from coefficients should preserve dtype
        r = wp.reconstruct(False)
        assert_equal(r.dtype, x.dtype)
        assert_allclose(r, x, atol=1e-12, rtol=1e-12)

    # must have two non-duplicate axes
    assert_raises(ValueError, pywt.WaveletPacket2D, data=x, wavelet='db1',
                  axes=(0, 0))
    assert_raises(ValueError, pywt.WaveletPacket2D, data=x, wavelet='db1',
                  axes=(0, ))
    assert_raises(ValueError, pywt.WaveletPacket2D, data=x, wavelet='db1',
                  axes=(0, 1, 2))
236
237
def test_wavelet_packet2d_pickle(tmpdir):
    """A WaveletPacket2D survives a pickle round trip through a file.

    Besides the type, the root data and tree configuration must be preserved.
    The pickle is written by this test itself, so loading it is safe.
    """
    packet = pywt.WaveletPacket2D(np.arange(256).reshape(16, 16), 'sym4')
    filename = os.path.join(tmpdir, 'wp2d.pickle')
    with open(filename, 'wb') as f:
        pickle.dump(packet, f)
    with open(filename, 'rb') as f:
        packet2 = pickle.load(f)
    assert isinstance(packet2, pywt.WaveletPacket2D)
    # the restored object must carry the same payload and configuration
    assert_equal(packet2.data, packet.data)
    assert_equal(packet2.maxlevel, packet.maxlevel)
    assert_equal(packet2.mode, packet.mode)
246