"""Utilities used to generate various figures in the documentation."""
from itertools import product

import numpy as np
from matplotlib import pyplot as plt

from ._dwt import pad

__all__ = ['wavedec_keys', 'wavedec2_keys', 'draw_2d_wp_basis',
           'draw_2d_fswavedecn_basis', 'boundary_mode_subplot']


def _multilevel_keys(level, subbands):
    """Keys of a ``level``-deep decomposition that splits into ``subbands``.

    ``subbands`` must start with the approximation label 'a'. At every level
    the current approximation is split into ``subbands``; approximations that
    are split again at the next level are removed, so only the deepest
    approximation remains in the result.
    """
    approx = ''
    coeffs = {}  # dict used as an insertion-ordered set of keys
    for lev in range(level):
        for k in subbands:
            coeffs[approx + k] = None
        approx = 'a' * (lev + 1)
        if lev < level - 1:
            # this approximation is decomposed further at the next level
            coeffs.pop(approx)
    return list(coeffs)


def wavedec_keys(level):
    """Subband keys corresponding to a wavedec decomposition.

    Each key is a string of 'a'/'d' characters giving the path through the
    decomposition tree.

    Examples
    ========
    >>> wavedec_keys(2)
    ['d', 'aa', 'ad']
    """
    return _multilevel_keys(level, ('a', 'd'))


def wavedec2_keys(level):
    """Subband keys corresponding to a wavedec2 decomposition.

    Each key is a string of 'a'/'h'/'v'/'d' characters giving the path
    through the decomposition tree.

    Examples
    ========
    >>> wavedec2_keys(1)
    ['a', 'h', 'v', 'd']
    """
    return _multilevel_keys(level, ('a', 'h', 'v', 'd'))


def _box(bl, ur):
    """(x, y) coordinates for the 4 lines making up a rectangular box.

    Parameters
    ==========
    bl : 2-tuple of float
        The (x, y) coordinates of the bottom left corner of the box.
    ur : 2-tuple of float
        The (x, y) coordinates of the upper right corner of the box.

    Returns
    =======
    coords : 2-tuple
        The first and second elements of the tuple are the x and y coordinates
        of the box. Each is a list of 8 values; consecutive pairs are the end
        points of the bottom, right, top and left edges, in that order, so the
        outline can be drawn as a single polyline.
    """
    xl, xr = bl[0], ur[0]
    yb, yt = bl[1], ur[1]
    box_x = [xl, xr,
             xr, xr,
             xr, xl,
             xl, xl]
    box_y = [yb, yb,
             yb, yt,
             yt, yt,
             yt, yb]
    return (box_x, box_y)


def _2d_wp_basis_coords(shape, keys):
    """Box outlines and box centers used by ``draw_2d_wp_basis``.

    Parameters
    ==========
    shape : 2-tuple of int
        Extent of the full (undecomposed) image along x and y.
    keys : iterable of str
        Node paths made of 'a', 'h', 'v' and 'd' characters.

    Returns
    =======
    coords : list of 2-tuple
        ``_box`` coordinates for every key, in the order of ``keys``.
    centers : dict
        Maps each key to the (x, y) center of its box. y is negated so that
        boxes are drawn downward from the origin.
    """
    coords = []
    centers = {}  # retain center of boxes for use in labeling
    for key in keys:
        offset_x = offset_y = 0
        for n, char in enumerate(key):
            # at depth n, 'h'/'d' select the second half along x and
            # 'v'/'d' the second half along y
            if char in ('h', 'd'):
                offset_x += shape[0] // 2**(n + 1)
            if char in ('v', 'd'):
                offset_y += shape[1] // 2**(n + 1)
        # box size depends only on the node depth; an empty key is the
        # full image
        depth = len(key)
        sx = shape[0] // 2**depth
        sy = shape[1] // 2**depth
        xc, yc = _box((offset_x, -offset_y),
                      (offset_x + sx, -offset_y - sy))
        coords.append((xc, yc))
        centers[key] = (offset_x + sx // 2, -offset_y - sy // 2)
    return coords, centers


def _plot_boxes(coords, fmt, plot_kwargs, ax):
    """Draw each box outline in ``coords``; create a new figure if ax is None.

    Returns the ``(fig, ax)`` pair that was drawn into.
    """
    if ax is None:
        fig, ax = plt.subplots(1, 1)
    else:
        fig = ax.get_figure()
    if plot_kwargs is None:
        plot_kwargs = {}
    for xc, yc in coords:
        ax.plot(xc, yc, fmt, **plot_kwargs)
    ax.set_axis_off()
    ax.axis('square')
    return fig, ax


def draw_2d_wp_basis(shape, keys, fmt='k', plot_kwargs=None, ax=None,
                     label_levels=0):
    """Plot a 2D representation of a WaveletPacket2D basis.

    Parameters
    ==========
    shape : 2-tuple of int
        Extent of the full image along x and y.
    keys : iterable of str
        Node paths (strings of 'a', 'h', 'v', 'd') making up the basis.
    fmt : str, optional
        Matplotlib format string used for the box outlines.
    plot_kwargs : dict or None, optional
        Extra keyword arguments forwarded to ``ax.plot`` for every outline.
    ax : matplotlib Axes or None, optional
        Axis to draw into. A new figure and axis are created if None.
    label_levels : int, optional
        Boxes whose key has at most this many characters are labeled with
        the key at their center. 0 disables labeling.

    Returns
    =======
    fig, ax
        The figure and axis that were drawn into.
    """
    coords, centers = _2d_wp_basis_coords(shape, keys)
    fig, ax = _plot_boxes(coords, fmt, plot_kwargs, ax)
    if label_levels > 0:
        for key, c in centers.items():
            if len(key) <= label_levels:
                ax.text(c[0], c[1], key,
                        horizontalalignment='center',
                        verticalalignment='center')
    return fig, ax


def _2d_fswavedecn_coords(shape, levels):
    """Box outlines and box centers used by ``draw_2d_fswavedecn_basis``.

    The boxes are the Cartesian product of the 1D ``wavedec_keys(levels)``
    along each of the two axes.

    Returns
    =======
    coords : list of 2-tuple
        ``_box`` coordinates for every (key0, key1) pair.
    centers : dict
        Maps each (key0, key1) pair to the (x, y) center of its box, with y
        negated so that boxes are drawn downward from the origin.
    """
    coords = []
    centers = {}  # retain center of boxes for use in labeling
    for key0, key1 in product(wavedec_keys(levels), repeat=2):
        offsets = [0, 0]
        # a 'd' at depth n selects the second half of that scale's interval
        for n0, char in enumerate(key0):
            if char == 'd':
                offsets[0] += shape[0] // 2**(n0 + 1)
        for n1, char in enumerate(key1):
            if char == 'd':
                offsets[1] += shape[1] // 2**(n1 + 1)
        # box size along each axis depends only on the depth of that key
        widths = [shape[0] // 2**len(key0), shape[1] // 2**len(key1)]
        xc, yc = _box((offsets[0], -offsets[1]),
                      (offsets[0] + widths[0], -offsets[1] - widths[1]))
        coords.append((xc, yc))
        centers[(key0, key1)] = (offsets[0] + widths[0] / 2,
                                 -offsets[1] - widths[1] / 2)
    return coords, centers


def draw_2d_fswavedecn_basis(shape, levels, fmt='k', plot_kwargs=None,
                             ax=None, label_levels=0):
    """Plot a 2D representation of a fully separable (fswavedecn) basis.

    Each axis is decomposed independently with ``levels`` levels of a 1D
    wavedec; the plotted boxes are all combinations of the 1D subbands.

    Parameters
    ==========
    shape : 2-tuple of int
        Extent of the full image along x and y.
    levels : int
        Number of decomposition levels along each axis.
    fmt : str, optional
        Matplotlib format string used for the box outlines.
    plot_kwargs : dict or None, optional
        Extra keyword arguments forwarded to ``ax.plot`` for every outline.
    ax : matplotlib Axes or None, optional
        Axis to draw into. A new figure and axis are created if None.
    label_levels : int, optional
        Boxes whose longer 1D key has at most this many characters are
        labeled with their (key0, key1) pair. 0 disables labeling.

    Returns
    =======
    fig, ax
        The figure and axis that were drawn into.
    """
    coords, centers = _2d_fswavedecn_coords(shape, levels)
    fig, ax = _plot_boxes(coords, fmt, plot_kwargs, ax)
    if label_levels > 0:
        for key, c in centers.items():
            lev = max(len(k) for k in key)
            if lev <= label_levels:
                ax.text(c[0], c[1], key,
                        horizontalalignment='center',
                        verticalalignment='center')
    return fig, ax


def boundary_mode_subplot(x, mode, ax, symw=True):
    """Plot an illustration of the boundary mode in a subplot axis.

    The extended signal is drawn as black dots and the original samples as
    red dots. Vertical black lines indicate points of symmetry or boundary
    extension.

    Parameters
    ==========
    x : 1D array_like
        The signal to extend.
    mode : str
        Signal extension mode, passed to ``pad``.
    ax : matplotlib Axes
        The axis to draw into. Its title is set to ``mode``.
    symw : bool, optional
        If True, the vertical lines sit on sample positions and are spaced
        ``len(x) - 1`` apart. Otherwise they are shifted half a sample to the
        left and spaced ``len(x)`` apart.
    """
    x = np.asarray(x)
    n_orig = len(x)  # number of samples supplied by the caller

    # if odd-length, periodization replicates the last sample to make it even
    if mode == 'periodization' and n_orig % 2 == 1:
        x = np.concatenate((x, (x[-1], )))

    npad = 2 * len(x)
    t = np.arange(len(x) + 2 * npad)
    xp = pad(x, (npad, npad), mode=mode)

    ax.plot(t, xp, 'k.')
    ax.set_title(mode)

    # plot the original signal in red, excluding any sample appended above
    ax.plot(t[npad:npad + n_orig], x[:n_orig], 'r.')

    # add vertical bars indicating points of symmetry or boundary extension
    o2 = np.ones(2)
    if symw:
        left = npad
        step = len(x) - 1
    else:
        left = npad - 0.5
        step = len(x)
    if mode in ['smooth', 'constant', 'zero']:
        # only the two bars at either end of the original signal
        rng = range(0, 2)
    else:
        rng = range(-2, 4)
    ylim = [xp.min() - .5, xp.max() + .5]
    for rep in rng:
        ax.plot((left + rep * step) * o2, ylim, 'k-')