1import numpy as np
2import pywt
3from matplotlib import pyplot as plt
4from pywt._doc_utils import wavedec2_keys, draw_2d_wp_basis
5
# Visualize the 2D discrete wavelet transform (DWT) of a test image.
#
# Top row: subband boundaries of a standard 2D DWT basis for 1..max_lev levels.
# Bottom row: the original image, followed by the 'db2' coefficients at each
# decomposition depth. Each coefficient array is rescaled independently so its
# largest magnitude is 1; otherwise the large approximation band would visually
# swamp the much smaller detail coefficients.

x = pywt.data.camera().astype(np.float32)
shape = x.shape

max_lev = 3       # how many levels of decomposition to draw
label_levels = 3  # how many levels to explicitly label on the plots


def _normalize(band):
    """Return *band* scaled so that its largest absolute value is 1.

    An all-zero band is returned unchanged rather than divided by zero,
    which would fill it with NaN and break the image colour scaling.
    """
    peak = np.abs(band).max()
    return band / peak if peak > 0 else band


# One column for the original image plus one per decomposition level, so the
# grid always matches max_lev.  squeeze=False keeps ``axes`` 2-D even when
# there is only a single column (max_lev == 0).
ncols = max_lev + 1
fig, axes = plt.subplots(2, ncols, figsize=[3.5 * ncols, 8], squeeze=False)

# Column 0: the original image before decomposition (no basis diagram above).
axes[0, 0].set_axis_off()
axes[1, 0].imshow(x, cmap=plt.cm.gray)
axes[1, 0].set_title('Image')
axes[1, 0].set_axis_off()

for level in range(1, max_lev + 1):
    # plot subband boundaries of a standard DWT basis
    draw_2d_wp_basis(shape, wavedec2_keys(level), ax=axes[0, level],
                     label_levels=label_levels)
    axes[0, level].set_title('{} level\ndecomposition'.format(level))

    # compute the 2D DWT; c[0] is the approximation array and each later
    # entry is a group of detail arrays for one level
    c = pywt.wavedec2(x, 'db2', mode='periodization', level=level)
    # normalize each coefficient array independently for better visibility,
    # keeping the detail groups as tuples (the format wavedec2 produces)
    c = [_normalize(c[0])] + [tuple(_normalize(d) for d in details)
                              for details in c[1:]]
    # pack the normalized coefficients into one 2D array and show it
    arr, _ = pywt.coeffs_to_array(c)
    axes[1, level].imshow(arr, cmap=plt.cm.gray)
    axes[1, level].set_title('Coefficients\n({} level)'.format(level))
    axes[1, level].set_axis_off()

plt.tight_layout()
plt.show()
41