1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3
4import sys
5
6import matplotlib.pyplot as plt
7import numpy as np
8import pywt
9
10
# Help text printed when the mandatory wavelet name is missing or the
# optional level argument cannot be parsed.
usage = """
Usage:
    python waveinfo.py waveletname [level]

    level is optional and defaults to 10.

    Example: python waveinfo.py 'sym5'
             python waveinfo.py 'sym5' 8
"""

# --- Command-line parsing -------------------------------------------------
# Each argument is validated on its own so that the message shown to the
# user matches the argument that is actually wrong, and every failure path
# exits with a non-zero status.

# The wavelet name is mandatory.
if len(sys.argv) < 2:
    print(usage)
    raise SystemExit(1)

try:
    # DiscreteContinuousWavelet accepts both discrete and continuous wavelet
    # names; pywt.Wavelet raises ValueError for continuous ones, which would
    # make the continuous-wavelet plotting path further down unreachable.
    wavelet = pywt.DiscreteContinuousWavelet(sys.argv[1])
except ValueError:
    print("Unknown wavelet")
    raise SystemExit(1)

# The level passed to wavelet.wavefun() is optional and defaults to 10.
try:
    level = int(sys.argv[2]) if len(sys.argv) > 2 else 10
except ValueError:
    # Previously a non-integer level was misreported as "Unknown wavelet".
    print("Invalid level: %r (expected an integer)" % sys.argv[2])
    print(usage)
    raise SystemExit(1)
30
31
# Evaluate the wavelet at the requested level.  The last element of the
# result is used as the x grid; everything before it is a function to plot.
data = wavelet.wavefun(level)


def _padded_limits(values):
    """Return (low, high) axis limits: the range of *values* widened by 5% of its span on each side."""
    lowest, highest = values.min(), values.max()
    pad = (highest - lowest) * 0.05
    return lowest - pad, highest + pad


if len(data) == 2:
    # Two-element result: a single function psi followed by its x grid.
    psi, x = data[0], data[1]
    fig = plt.figure()
    if wavelet.complex_cwt:
        # Complex-valued psi: real part in the top subplot, imaginary part
        # in the bottom one.
        parts = ((211, ' real part', np.real),
                 (212, ' imag part', np.imag))
        for position, suffix, component in parts:
            values = component(psi)
            plt.subplot(position)
            plt.title(wavelet.name + suffix)
            plt.plot(x, values)
            plt.ylim(*_padded_limits(values))
            plt.xlim(x[0], x[-1])
    else:
        plt.plot(x, psi)
        plt.title(wavelet.name)
        plt.ylim(*_padded_limits(psi))
        plt.xlim(x[0], x[-1])
else:
    # Longer result: several functions followed by the shared x grid.
    x = data[-1]
    funcs = data[:-1]
    # zip() stops at the shortest input, so when fewer than four functions
    # are returned only the leading labels/colors are used.
    # NOTE(review): "r." presumably stands for "reconstruction" -- confirm.
    labels = ["scaling function (phi)", "wavelet function (psi)",
              "r. scaling function (phi)", "r. wavelet function (psi)"]
    colors = ("r", "g", "r", "g")
    # Two plots per row.
    n_rows = (len(data) - 1) // 2
    fig = plt.figure()
    for position, (d, label, color) in enumerate(
            zip(funcs, labels, colors), 1):
        ax = fig.add_subplot(n_rows, 2, position)
        ax.plot(x, d, color)
        ax.set_title(label)
        ax.set_ylim(*_padded_limits(d))
        ax.set_xlim(x[0], x[-1])

plt.show()
76