1"""Utilities used to generate various figures in the documentation."""
2from itertools import product
3
4import numpy as np
5from matplotlib import pyplot as plt
6
7from ._dwt import pad
8
# Public helpers used by the documentation build to render figures.
__all__ = [
    'wavedec_keys',
    'wavedec2_keys',
    'draw_2d_wp_basis',
    'draw_2d_fswavedecn_basis',
    'boundary_mode_subplot',
]
11
12
def wavedec_keys(level):
    """Subband keys corresponding to a wavedec decomposition.

    Parameters
    ----------
    level : int
        Number of decomposition levels.

    Returns
    -------
    keys : list of str
        The detail keys ``'d', 'ad', 'aad', ...`` for every level except the
        deepest, followed by the deepest approximation and detail keys
        (``'a' * level`` and ``'a' * (level - 1) + 'd'``). Empty when
        ``level`` is not positive.
    """
    # prefixes[n] is the path of approximations leading to level n + 1
    prefixes = ['a' * n for n in range(level)]
    if not prefixes:
        return []
    # only the deepest level keeps its approximation subband
    keys = [prefix + 'd' for prefix in prefixes[:-1]]
    keys.extend(prefixes[-1] + sub for sub in 'ad')
    return keys
24
25
def wavedec2_keys(level):
    """Subband keys corresponding to a wavedec2 decomposition.

    Parameters
    ----------
    level : int
        Number of decomposition levels.

    Returns
    -------
    keys : list of str
        The ``'h'``, ``'v'``, ``'d'`` detail keys of every level except the
        deepest, followed by the deepest ``'a'``, ``'h'``, ``'v'``, ``'d'``
        keys. Empty when ``level`` is not positive.
    """
    # prefixes[n] is the path of approximations leading to level n + 1
    prefixes = ['a' * n for n in range(level)]
    if not prefixes:
        return []
    # intermediate approximations are decomposed further, so are omitted
    keys = [prefix + sub for prefix in prefixes[:-1] for sub in 'hvd']
    keys.extend(prefixes[-1] + sub for sub in 'ahvd')
    return keys
37
38
39def _box(bl, ur):
40    """(x, y) coordinates for the 4 lines making up a rectangular box.
41
42    Parameters
43    ==========
44    bl : float
45        The bottom left corner of the box
46    ur : float
47        The upper right corner of the box
48
49    Returns
50    =======
51    coords : 2-tuple
52        The first and second elements of the tuple are the x and y coordinates
53        of the box.
54    """
55    xl, xr = bl[0], ur[0]
56    yb, yt = bl[1], ur[1]
57    box_x = [xl, xr,
58             xr, xr,
59             xr, xl,
60             xl, xl]
61    box_y = [yb, yb,
62             yb, yt,
63             yt, yt,
64             yt, yb]
65    return (box_x, box_y)
66
67
def _2d_wp_basis_coords(shape, keys):
    """Box outlines and box centers for the nodes of a 2D wavelet packet basis.

    Parameters
    ----------
    shape : sequence of int
        Size of the undecomposed image. ``shape[0]`` sets the horizontal
        extent of the drawing and ``shape[1]`` the vertical extent.
    keys : iterable of str
        Node paths made of the characters ``'a'``, ``'h'``, ``'v'`` and
        ``'d'``. Each character halves the box size. ``'h'``/``'d'`` shift
        the box right and ``'v'``/``'d'`` shift it down. The empty path
        ``''`` is drawn as the full-size box.

    Returns
    -------
    coords : list of 2-tuple
        ``(x, y)`` line coordinates (as returned by ``_box``), one entry per
        key, in the order of ``keys``.
    centers : dict
        Maps each key to the ``(x, y)`` center of its box, for labeling.
    """
    coords = []
    centers = {}  # retain center of boxes for use in labeling
    for key in keys:
        offset_x = offset_y = 0
        for n, char in enumerate(key):
            # a subband at depth n + 1 spans 1 / 2**(n + 1) of each dimension
            if char in ['h', 'd']:
                offset_x += shape[0] // 2**(n + 1)
            if char in ['v', 'd']:
                offset_y += shape[1] // 2**(n + 1)
        # The box size depends only on the node depth. Using len(key) rather
        # than the inner loop index avoids a NameError for the root key '' and
        # a stale size carried over from the previous key.
        depth = len(key)
        sx = shape[0] // 2**depth
        sy = shape[1] // 2**depth
        # y is negated so that 'v'/'d' subbands are drawn below 'a'/'h'
        xc, yc = _box((offset_x, -offset_y),
                      (offset_x + sx, -offset_y - sy))
        coords.append((xc, yc))
        centers[key] = (offset_x + sx // 2, -offset_y - sy // 2)
    return coords, centers
86
87
def draw_2d_wp_basis(shape, keys, fmt='k', plot_kwargs=None, ax=None,
                     label_levels=0):
    """Plot a 2D representation of a WaveletPacket2D basis.

    Parameters
    ----------
    shape : sequence of int
        Size of the image being decomposed.
    keys : iterable of str
        Node paths (e.g. ``'ad'``) making up the basis to draw.
    fmt : str, optional
        Matplotlib format string used for the box outlines.
    plot_kwargs : dict, optional
        Extra keyword arguments forwarded to every ``ax.plot`` call.
    ax : matplotlib Axes, optional
        Axis to draw into. A new figure and axis are created when None.
    label_levels : int, optional
        Boxes whose key has at most this many characters are labeled with
        their key. No labels are drawn when 0.

    Returns
    -------
    fig, ax
        The figure and axis containing the plot.
    """
    # None instead of a shared mutable {} default
    if plot_kwargs is None:
        plot_kwargs = {}
    coords, centers = _2d_wp_basis_coords(shape, keys)
    if ax is None:
        fig, ax = plt.subplots(1, 1)
    else:
        fig = ax.get_figure()
    for xs, ys in coords:
        # plot_kwargs was previously accepted but silently ignored
        ax.plot(xs, ys, fmt, **plot_kwargs)
    ax.set_axis_off()
    ax.axis('square')
    if label_levels > 0:
        for key, (cx, cy) in centers.items():
            if len(key) <= label_levels:
                ax.text(cx, cy, key,
                        horizontalalignment='center',
                        verticalalignment='center')
    return fig, ax
107
108
def _2d_fswavedecn_coords(shape, levels):
    """Box outlines and centers for a 2D fswavedecn subband layout.

    Every subband is a pair ``(key0, key1)`` of 1D ``wavedec_keys(levels)``
    keys, one for each axis.

    Parameters
    ----------
    shape : sequence of int
        Size of the image. ``shape[0]`` sets the horizontal extent and
        ``shape[1]`` the vertical extent.
    levels : int
        Number of decomposition levels applied along each axis.

    Returns
    -------
    coords : list of 2-tuple
        ``(x, y)`` line coordinates (as returned by ``_box``) per subband.
    centers : dict
        Maps each ``(key0, key1)`` pair to the center of its box.
    """
    coords = []
    centers = {}  # retain center of boxes for use in labeling
    for key0, key1 in product(wavedec_keys(levels), repeat=2):
        # a 'd' at depth n moves past the approximation of size 1/2**(n + 1)
        off_x = sum(shape[0] // 2**(n + 1)
                    for n, char in enumerate(key0) if char == 'd')
        off_y = sum(shape[1] // 2**(n + 1)
                    for n, char in enumerate(key1) if char == 'd')
        width_x = shape[0] // 2**len(key0)
        width_y = shape[1] // 2**len(key1)
        coords.append(_box((off_x, -off_y),
                           (off_x + width_x, -off_y - width_y)))
        centers[(key0, key1)] = (off_x + width_x / 2,
                                 -off_y - width_y / 2)
    return coords, centers
130
131
def draw_2d_fswavedecn_basis(shape, levels, fmt='k', plot_kwargs=None,
                             ax=None, label_levels=0):
    """Plot a 2D representation of a fswavedecn basis.

    Each axis is decomposed independently into ``levels`` levels, so every
    subband is labeled by a pair of 1D ``wavedec`` keys (one per axis).

    Parameters
    ----------
    shape : sequence of int
        Size of the image being decomposed.
    levels : int
        Number of decomposition levels along each axis.
    fmt : str, optional
        Matplotlib format string used for the box outlines.
    plot_kwargs : dict, optional
        Extra keyword arguments forwarded to every ``ax.plot`` call.
    ax : matplotlib Axes, optional
        Axis to draw into. A new figure and axis are created when None.
    label_levels : int, optional
        A subband is labeled when the longer of its two keys has at most
        this many characters. No labels are drawn when 0.

    Returns
    -------
    fig, ax
        The figure and axis containing the plot.
    """
    # None instead of a shared mutable {} default
    if plot_kwargs is None:
        plot_kwargs = {}
    coords, centers = _2d_fswavedecn_coords(shape, levels)
    if ax is None:
        fig, ax = plt.subplots(1, 1)
    else:
        fig = ax.get_figure()
    for xs, ys in coords:
        # plot_kwargs was previously accepted but silently ignored
        ax.plot(xs, ys, fmt, **plot_kwargs)
    ax.set_axis_off()
    ax.axis('square')
    if label_levels > 0:
        for key, (cx, cy) in centers.items():
            # key is a (key0, key1) pair; its level is the deeper of the two
            if max(len(k) for k in key) <= label_levels:
                ax.text(cx, cy, key,
                        horizontalalignment='center',
                        verticalalignment='center')
    return fig, ax
152
153
def boundary_mode_subplot(x, mode, ax, symw=True):
    """Plot an illustration of the boundary mode in a subplot axis.

    The signal is extended by ``pad`` on both sides. The extended signal is
    drawn in black, the original samples are overlaid in red, and vertical
    bars mark the points of symmetry or boundary extension.

    Parameters
    ----------
    x : array_like
        1D signal to extend.
    mode : str
        Signal extension mode passed to ``pad``. Also used as the title.
    ax : matplotlib Axes
        Axis to draw into.
    symw : bool, optional
        If True, bars are placed on samples, every ``len(x) - 1`` samples.
        Otherwise they are placed half-way between samples, every
        ``len(x)`` samples.
    """
    # length of the signal as supplied, before any replication below
    n_orig = len(x)

    # if odd-length, periodization replicates the last sample to make it even
    if mode == 'periodization' and n_orig % 2 == 1:
        x = np.concatenate((x, (x[-1], )))

    npad = 2 * len(x)
    t = np.arange(len(x) + 2 * npad)
    xp = pad(x, (npad, npad), mode=mode)

    ax.plot(t, xp, 'k.')
    ax.set_title(mode)

    # Plot the original signal in red. Only a replicated sample is excluded;
    # previously x[-1] was always dropped for periodization, which hid a real
    # sample for even-length input.
    ax.plot(t[npad:npad + n_orig], x[:n_orig], 'r.')

    # add vertical bars indicating points of symmetry or boundary extension
    o2 = np.ones(2)
    if symw:
        left, step = npad, len(x) - 1
    else:
        left, step = npad - 0.5, len(x)
    if mode in ['smooth', 'constant', 'zero']:
        # no repetition to mark: only bracket the original signal
        rng = range(0, 2)
    else:
        rng = range(-2, 4)
    ymin, ymax = xp.min() - .5, xp.max() + .5
    for rep in rng:
        ax.plot((left + rep * step) * o2, [ymin, ymax], 'k-')
188