1# Copyright (c) 2006-2012 Filip Wasilewski <http://en.ig.ma/>
2# Copyright (c) 2012-2018 The PyWavelets Developers
3#                         <https://github.com/PyWavelets/pywt>
4# See COPYING for license details.
5
__doc__ = """Cython wrapper for low-level C wavelet transform implementation."""
# Public names re-exported by ``from pywt._extensions._pywt import *``.
__all__ = [
    'MODES', 'Modes',
    'DiscreteContinuousWavelet', 'Wavelet', 'ContinuousWavelet',
    'wavelist', 'families',
]
9
10
11import warnings
12import re
13
14from . cimport c_wt
15from . cimport common
16from ._dwt cimport upcoef
17from ._cwt cimport cwt_psi_single
18
19from libc.math cimport pow, sqrt
20
21import numpy as np
22
23
# Deprecated short mode names.  Caution: entry i of this list must correspond
# to entry i of _Modes.modes below ('zpd' -> 'zero', ..., 'per' ->
# 'periodization'); _Modes.__getattr__ relies on that positional mapping.
_old_modes = ['zpd', 'cpd', 'sym', 'ppd', 'sp1', 'per']

# Template for the warning issued when a deprecated name is accessed.
_attr_deprecation_msg = ('{old} has been renamed to {new} and will be '
                         'unavailable in a future version of pywt.')

# Pulls the numeric parameters out of a parameterised continuous wavelet
# name, for example:
#    re.findall(cwt_pattern, 'fbsp1-1.5-1') ->  ['1', '1.5', '1']
#    re.findall(cwt_pattern, 'cmor1.5-1.0') ->  ['1.5', '1.0']
cwt_pattern = re.compile(r'\D+(\d+\.*\d*)+')
40
41
# Return nonzero if the wavelet family `name` is discrete and 0 otherwise.
# Raises ValueError when the C library reports the family as unknown (-1).
# The explicit `except -1` lets that ValueError propagate to callers; without
# it (Cython < 3 semantics) an exception raised in a cdef function returning
# a C int is only printed as "ignored" and 0 is returned instead.
cdef int is_discrete_wav(WAVELET_NAME name) except -1:
    cdef int discrete
    discrete = wavelet.is_discrete_wavelet(name)
    if discrete == -1:
        raise ValueError("unrecognized wavelet family name")
    return discrete
49
50
class _Modes(object):
    """
    Because the most common and practical way of representing digital signals
    in computer science is with finite arrays of values, some extrapolation of
    the input data has to be performed in order to extend the signal before
    computing the :ref:`Discrete Wavelet Transform <ref-dwt>` using the
    cascading filter banks algorithm.

    Depending on the extrapolation method, significant artifacts at the
    signal's borders can be introduced during that process, which in turn may
    lead to inaccurate computations of the :ref:`DWT <ref-dwt>` at the signal's
    ends.

    PyWavelets provides several methods of signal extrapolation that can be
    used to minimize this negative effect:

    zero - zero-padding                   0  0 | x1 x2 ... xn | 0  0
    constant - constant-padding          x1 x1 | x1 x2 ... xn | xn xn
    symmetric - symmetric-padding        x2 x1 | x1 x2 ... xn | xn xn-1
    reflect - reflect-padding            x3 x2 | x1 x2 ... xn | xn-1 xn-2
    periodic - periodic-padding        xn-1 xn | x1 x2 ... xn | x1 x2
    smooth - smooth-padding             (1st derivative interpolation)
    antisymmetric -                    -x2 -x1 | x1 x2 ... xn | -xn -xn-1
    antireflect -                      -x3 -x2 | x1 x2 ... xn | -xn-1 -xn-2

    DWT performed for these extension modes is slightly redundant, but ensures
    perfect reconstruction for IDWT. To receive the smallest possible number of
    coefficients, computations can be performed with the periodization mode:

    periodization - like periodic-padding but gives the smallest possible
                    number of decomposition coefficients. IDWT must be
                    performed with the same mode.

    Examples
    --------
    >>> import pywt
    >>> pywt.Modes.modes
    ['zero', 'constant', 'symmetric', 'periodic', 'smooth', 'periodization', 'reflect', 'antisymmetric', 'antireflect']
    >>> # The different ways of passing wavelet and mode parameters
    >>> (a, d) = pywt.dwt([1,2,3,4,5,6], 'db2', 'smooth')
    >>> (a, d) = pywt.dwt([1,2,3,4,5,6], pywt.Wavelet('db2'), pywt.Modes.smooth)

    Notes
    -----
    Extending data in context of PyWavelets does not mean reallocation of the
    data in computer's physical memory and copying values, but rather computing
    the extra values only when they are needed. This feature saves extra
    memory and CPU resources and helps to avoid page swapping when handling
    relatively big data arrays on computers with low physical memory.

    """
    zero = common.MODE_ZEROPAD
    constant = common.MODE_CONSTANT_EDGE
    symmetric = common.MODE_SYMMETRIC
    reflect = common.MODE_REFLECT
    periodic = common.MODE_PERIODIC
    smooth = common.MODE_SMOOTH
    periodization = common.MODE_PERIODIZATION
    antisymmetric = common.MODE_ANTISYMMETRIC
    antireflect = common.MODE_ANTIREFLECT

    # Caution: order in modes list below must match _old_modes above
    modes = ["zero", "constant", "symmetric", "periodic", "smooth",
             "periodization", "reflect", "antisymmetric", "antireflect"]

    def from_object(self, mode):
        """Return the integer code of an extension mode.

        Parameters
        ----------
        mode : int or str
            Either an integer mode code (Python or NumPy integer), which is
            range-checked and returned unchanged, or one of the names in
            ``Modes.modes``. Deprecated short names (``'zpd'``, ``'per'``,
            ...) are accepted too and emit a DeprecationWarning.

        Raises
        ------
        ValueError
            If an integer code is out of range or the name is not a mode.
        """
        if isinstance(mode, (int, np.integer)):
            if mode <= common.MODE_INVALID or mode >= common.MODE_MAX:
                raise ValueError("Invalid mode.")
            return mode
        # Only genuine mode names are accepted.  A bare getattr would also
        # resolve unrelated attributes ('modes', 'from_object', ...) and
        # raise TypeError rather than ValueError for non-string input.
        if mode not in Modes.modes and mode not in _old_modes:
            raise ValueError("Unknown mode name '%s'." % (mode,))
        return getattr(Modes, mode)

    def __getattr__(self, mode):
        # Only reached when normal lookup fails: translate a deprecated short
        # name to its current equivalent (same index in the two lists).
        if mode in _old_modes:
            new_mode = Modes.modes[_old_modes.index(mode)]
            warnings.warn(_attr_deprecation_msg.format(old=mode, new=new_mode),
                          DeprecationWarning)
            mode = new_mode
        # Unknown names still raise AttributeError from __getattribute__.
        return Modes.__getattribute__(mode)


Modes = _Modes()
140
141
class _DeprecatedMODES(_Modes):
    """Backwards-compatible alias of ``Modes`` that warns on every use."""
    msg = ("MODES has been renamed to Modes and will be "
           "removed in a future version of pywt.")

    def __getattribute__(self, attr):
        """Warn on each public attribute access, then delegate to _Modes.

        Overriding only __getattr__ would miss names that resolve normally
        (e.g. `MODES.symmetric`), hence this __getattribute__ override.
        Private/dunder names are looked up silently.
        """
        is_private = attr.startswith('_')
        if not is_private:
            warnings.warn(_DeprecatedMODES.msg, DeprecationWarning)
        return _Modes.__getattribute__(self, attr)

    def __getattr__(self, attr):
        """Fallback lookup (e.g. deprecated short names); warns, then
        defers to _Modes.__getattr__.
        """
        warnings.warn(_DeprecatedMODES.msg, DeprecationWarning)
        return _Modes.__getattr__(self, attr)


MODES = _DeprecatedMODES()
166
167###############################################################################
168# Wavelet
169
170include "wavelets_list.pxi"  # __wname_to_code
171
cdef object wname_to_code(name):
    # Look `name` up in the __wname_to_code table; callers unpack the result
    # as (family_code, family_number).  Parameterised continuous names such
    # as 'cmor1.5-1.0' are looked up by their 4-letter family prefix only.
    # A missing entry is reported as ValueError (naming the key looked up).
    cdef object code_number
    cdef object lookup_key
    try:
        lookup_key = name
        if len(lookup_key) > 4 and lookup_key[:4] in ('cmor', 'shan', 'fbsp'):
            lookup_key = lookup_key[:4]
        code_number = __wname_to_code[lookup_key]
    except KeyError:
        raise ValueError("Unknown wavelet name '%s', check wavelist() for the "
                         "list of available builtin wavelets." % lookup_key)
    return code_number
182
183
def wavelist(family=None, kind='all'):
    """
    wavelist(family=None, kind='all')

    Returns list of available wavelet names for the given family name.

    Parameters
    ----------
    family : str, optional
        Short family name. If the family name is None (default) then names
        of all the built-in wavelets are returned. Otherwise the function
        returns names of wavelets that belong to the given family.
        Valid names are::

            'haar', 'db', 'sym', 'coif', 'bior', 'rbio', 'dmey', 'gaus',
            'mexh', 'morl', 'cgau', 'shan', 'fbsp', 'cmor'

    kind : {'all', 'continuous', 'discrete'}, optional
        Whether to return only wavelet names of discrete or continuous
        wavelets, or all wavelets.  Default is ``'all'``.
        Ignored if ``family`` is specified.

    Returns
    -------
    wavelist : list of str
        List of available wavelet names, in natural order (grouped by the
        first two letters, shorter names first, e.g. 'db2' before 'db10').

    Raises
    ------
    ValueError
        If `kind` or `family` is not one of the accepted values.

    Examples
    --------
    >>> import pywt
    >>> pywt.wavelist('coif')
    ['coif1', 'coif2', 'coif3', 'coif4', 'coif5', 'coif6', 'coif7', ...
    >>> pywt.wavelist(kind='continuous')
    ['cgau1', 'cgau2', 'cgau3', 'cgau4', 'cgau5', 'cgau6', 'cgau7', ...

    """
    cdef object wavelets, sorting_list

    if kind not in ('all', 'continuous', 'discrete'):
        raise ValueError("Unrecognized value for `kind`: %s" % kind)

    def _check_kind(wname, wkind):
        # Every name passes for 'all'; otherwise classify the name's family
        # via the C library (is_discrete_wav raises on unknown families).
        if wkind == 'all':
            return True
        code, _ = wname_to_code(wname)
        discrete = is_discrete_wav(code)
        return discrete if wkind == 'discrete' else not discrete

    if family is None:
        selected = [wname for wname in __wname_to_code
                    if _check_kind(wname, kind)]
    elif family in __wfamily_list_short:
        # `kind` is deliberately ignored once a family has been chosen
        selected = [wname for wname in __wname_to_code
                    if wname.startswith(family)]
    else:
        raise ValueError("Invalid short family name '%s'." % family)

    # Sort key: (2-letter prefix, name length, name) gives natural ordering.
    sorting_list = sorted([(wname[:2], len(wname), wname)
                           for wname in selected])
    wavelets = [entry[2] for entry in sorting_list]
    return wavelets
254
255
def families(int short=True):
    """
    families(short=True)

    Returns a list of available built-in wavelet families.

    Currently the built-in families are:

    * Haar (``haar``)
    * Daubechies (``db``)
    * Symlets (``sym``)
    * Coiflets (``coif``)
    * Biorthogonal (``bior``)
    * Reverse biorthogonal (``rbio``)
    * `"Discrete"` FIR approximation of Meyer wavelet (``dmey``)
    * Gaussian wavelets (``gaus``)
    * Mexican hat wavelet (``mexh``)
    * Morlet wavelet (``morl``)
    * Complex Gaussian wavelets (``cgau``)
    * Shannon wavelets (``shan``)
    * Frequency B-Spline wavelets (``fbsp``)
    * Complex Morlet wavelets (``cmor``)

    Parameters
    ----------
    short : bool, optional
        Use short names (default: True).

    Returns
    -------
    families : list
        List of available wavelet families. A copy is returned, so callers
        may modify it freely.

    Examples
    --------
    >>> import pywt
    >>> pywt.families()
    ['haar', 'db', 'sym', 'coif', 'bior', 'rbio', 'dmey', 'gaus', 'mexh', 'morl', 'cgau', 'shan', 'fbsp', 'cmor']
    >>> pywt.families(short=False)
    ['Haar', 'Daubechies', 'Symlets', 'Coiflets', 'Biorthogonal', 'Reverse biorthogonal', 'Discrete Meyer (FIR Approximation)', 'Gaussian', 'Mexican hat wavelet', 'Morlet wavelet', 'Complex Gaussian wavelets', 'Shannon wavelets', 'Frequency B-Spline wavelets', 'Complex Morlet wavelets']

    """
    chosen = __wfamily_list_short if short else __wfamily_list_long
    # slice-copy so the module-level tables cannot be mutated by callers
    return chosen[:]
301
302
def DiscreteContinuousWavelet(name=u"", object filter_bank=None):
    """
    DiscreteContinuousWavelet(name, filter_bank=None) returns a
    Wavelet or a ContinuousWavelet object depending on the given name.

    In order to use a built-in wavelet the parameter name must be
    a valid name from the wavelist() list.
    To create a custom wavelet object, filter_bank parameter must
    be specified. It can be either a list of four filters or an object
    that has a `filter_bank` attribute which returns a list of four
    filters - just like the Wavelet instance itself.

    For a ContinuousWavelet, filter_bank cannot be used and must remain unset.

    Raises
    ------
    TypeError
        If neither a name nor a filter bank is given.
    ValueError
        If the name is not a known wavelet.

    """
    if not name and filter_bank is None:
        raise TypeError("Wavelet name or filter bank must be specified.")
    if filter_bank is not None:
        # a user-supplied filter bank always describes a discrete wavelet
        return Wavelet(name, filter_bank)

    # builtin wavelet: the family code decides which class to construct
    name = name.lower()
    family_code, _ = wname_to_code(name)
    if not is_discrete_wav(family_code):
        return ContinuousWavelet(name)
    return Wavelet(name, filter_bank)
329
330
cdef public class Wavelet [type WaveletType, object WaveletObject]:
    """
    Wavelet(name, filter_bank=None) object describes properties of
    a wavelet identified by name.

    In order to use a built-in wavelet the parameter name must be
    a valid name from the wavelist() list.
    To create a custom wavelet object, filter_bank parameter must
    be specified. It can be either a list of four filters or an object
    that has a `filter_bank` attribute which returns a list of four
    filters - just like the Wavelet instance itself.

    """
    #cdef readonly properties
    def __cinit__(self, name=u"", object filter_bank=None):
        # Builtin wavelets are looked up by (lower-cased) name.  Custom
        # wavelets are built from four 1D numeric filters
        # (dec_lo, dec_hi, rec_lo, rec_hi) of equal, non-zero length, which are
        # copied into the C struct in both float32 and float64 precision.
        cdef object family_code, family_number
        cdef object filters
        cdef pywt_index_t filter_length
        cdef object dec_lo, dec_hi, rec_lo, rec_hi

        if not name and filter_bank is None:
            raise TypeError("Wavelet name or filter bank must be specified.")

        if filter_bank is None:
            # builtin wavelet
            self.name = name.lower()
            family_code, family_number = wname_to_code(self.name)
            if is_discrete_wav(family_code):
                self.w = <wavelet.DiscreteWavelet*> wavelet.discrete_wavelet(family_code, family_number)
            # self.w is still NULL for continuous families (and, presumably,
            # when the C library does not know this family number).
            if self.w is NULL:
                if self.name in wavelist(kind='continuous'):
                    raise ValueError("The `Wavelet` class is for discrete "
                          "wavelets, %s is a continuous wavelet.  Use "
                          "pywt.ContinuousWavelet instead" % self.name)
                else:
                    raise ValueError("Invalid wavelet name '%s'." % self.name)
            # only builtin wavelets get a family number (see __reduce__)
            self.number = family_number
        else:
            if hasattr(filter_bank, "filter_bank"):
                filters = filter_bank.filter_bank
                if len(filters) != 4:
                    raise ValueError("Expected filter bank with 4 filters, "
                    "got filter bank with %d filters." % len(filters))
            elif hasattr(filter_bank, "get_filters_coeffs"):
                msg = ("Creating custom Wavelets using objects that define "
                       "`get_filters_coeffs` method is deprecated. "
                       "The `filter_bank` parameter should define a "
                       "`filter_bank` attribute instead of "
                       "`get_filters_coeffs` method.")
                warnings.warn(msg, DeprecationWarning)
                filters = filter_bank.get_filters_coeffs()
                if len(filters) != 4:
                    msg = ("Expected filter bank with 4 filters, got filter "
                           "bank with %d filters." % len(filters))
                    raise ValueError(msg)
            else:
                filters = filter_bank
                if len(filters) != 4:
                    msg = ("Expected list of 4 filters coefficients, "
                           "got %d filters." % len(filters))
                    raise ValueError(msg)
            try:
                dec_lo = np.asarray(filters[0], dtype=np.float64)
                dec_hi = np.asarray(filters[1], dtype=np.float64)
                rec_lo = np.asarray(filters[2], dtype=np.float64)
                rec_hi = np.asarray(filters[3], dtype=np.float64)
            except TypeError:
                raise ValueError("Filter bank with numeric values required.")

            if not (1 == dec_lo.ndim == dec_hi.ndim ==
                         rec_lo.ndim == rec_hi.ndim):
                raise ValueError("All filters in filter bank must be 1D.")

            filter_length = len(dec_lo)
            if filter_length == 0:
                raise ValueError("All filters in filter bank must have "
                                 "length greater than 0.")
            # The C struct stores a single length for all four filters, so
            # they must agree; report this separately from the empty case.
            if not (filter_length == len(dec_hi) == len(rec_lo) ==
                    len(rec_hi)):
                raise ValueError("All filters in filter bank must have "
                                 "the same length.")

            self.w = <wavelet.DiscreteWavelet*> wavelet.blank_discrete_wavelet(filter_length)
            if self.w is NULL:
                raise MemoryError("Could not allocate memory for given "
                                  "filter bank.")

            # copy values to struct
            copy_object_to_float32_array(dec_lo, self.w.dec_lo_float)
            copy_object_to_float32_array(dec_hi, self.w.dec_hi_float)
            copy_object_to_float32_array(rec_lo, self.w.rec_lo_float)
            copy_object_to_float32_array(rec_hi, self.w.rec_hi_float)

            copy_object_to_float64_array(dec_lo, self.w.dec_lo_double)
            copy_object_to_float64_array(dec_hi, self.w.dec_hi_double)
            copy_object_to_float64_array(rec_lo, self.w.rec_lo_double)
            copy_object_to_float64_array(rec_hi, self.w.rec_hi_double)

            self.name = name

    def __dealloc__(self):
        if self.w is not NULL:
            wavelet.free_discrete_wavelet(self.w)
            self.w = NULL

    def __reduce__(self):
        # Builtin wavelets (the only ones for which __cinit__ assigns
        # self.number) are re-created from their name alone, so the C-side
        # family metadata (orthogonal/biorthogonal flags, symmetry, family
        # names, vanishing moments) and the family number are restored.
        # Passing the filter bank as well would send them down the custom
        # blank_discrete_wavelet path, which does not look that metadata up.
        if self.number is not None:
            return (Wavelet, (self.name,))
        return (Wavelet, (self.name, self.filter_bank))

    def __len__(self):
        return self.w.dec_len

    property dec_lo:
        "Lowpass decomposition filter"
        def __get__(self):
            return float64_array_to_list(self.w.dec_lo_double, self.w.dec_len)

    property dec_hi:
        "Highpass decomposition filter"
        def __get__(self):
            return float64_array_to_list(self.w.dec_hi_double, self.w.dec_len)

    property rec_lo:
        "Lowpass reconstruction filter"
        def __get__(self):
            return float64_array_to_list(self.w.rec_lo_double, self.w.rec_len)

    property rec_hi:
        "Highpass reconstruction filter"
        def __get__(self):
            return float64_array_to_list(self.w.rec_hi_double, self.w.rec_len)

    property rec_len:
        "Reconstruction filters length"
        def __get__(self):
            return self.w.rec_len

    property dec_len:
        "Decomposition filters length"
        def __get__(self):
            return self.w.dec_len

    property family_number:
        "Wavelet family number"
        def __get__(self):
            return self.number

    property family_name:
        "Wavelet family name"
        def __get__(self):
            return self.w.base.family_name.decode('latin-1')

    property short_family_name:
        "Short wavelet family name"
        def __get__(self):
            return self.w.base.short_name.decode('latin-1')

    property orthogonal:
        "Is orthogonal"
        def __get__(self):
            return bool(self.w.base.orthogonal)
        def __set__(self, int value):
            self.w.base.orthogonal = (value != 0)

    property biorthogonal:
        "Is biorthogonal"
        def __get__(self):
            return bool(self.w.base.biorthogonal)
        def __set__(self, int value):
            self.w.base.biorthogonal = (value != 0)

    property symmetry:
        "Wavelet symmetry"
        def __get__(self):
            if self.w.base.symmetry == wavelet.ASYMMETRIC:
                return "asymmetric"
            elif self.w.base.symmetry == wavelet.NEAR_SYMMETRIC:
                return "near symmetric"
            elif self.w.base.symmetry == wavelet.SYMMETRIC:
                return "symmetric"
            elif self.w.base.symmetry == wavelet.ANTI_SYMMETRIC:
                return "anti-symmetric"
            else:
                return "unknown"

    property vanishing_moments_psi:
        "Number of vanishing moments for wavelet function"
        def __get__(self):
            # negative values are reported as None (implicit return)
            if self.w.vanishing_moments_psi >= 0:
                return self.w.vanishing_moments_psi

    property vanishing_moments_phi:
        "Number of vanishing moments for scaling function"
        def __get__(self):
            # negative values are reported as None (implicit return)
            if self.w.vanishing_moments_phi >= 0:
                return self.w.vanishing_moments_phi

    property filter_bank:
        """Returns tuple of wavelet filters coefficients
        (dec_lo, dec_hi, rec_lo, rec_hi)
        """
        def __get__(self):
            return (self.dec_lo, self.dec_hi, self.rec_lo, self.rec_hi)

    def get_filters_coeffs(self):
        warnings.warn("The `get_filters_coeffs` method is deprecated. "
                      "Use `filter_bank` attribute instead.", DeprecationWarning)
        return self.filter_bank

    property inverse_filter_bank:
        """Tuple of inverse wavelet filters coefficients
        (rec_lo[::-1], rec_hi[::-1], dec_lo[::-1], dec_hi[::-1])
        """
        def __get__(self):
            return (self.rec_lo[::-1], self.rec_hi[::-1], self.dec_lo[::-1],
                    self.dec_hi[::-1])

    def get_reverse_filters_coeffs(self):
        warnings.warn("The `get_reverse_filters_coeffs` method is deprecated. "
                      "Use `inverse_filter_bank` attribute instead.",
                      DeprecationWarning)
        return self.inverse_filter_bank

    def wavefun(self, int level=8):
        """
        wavefun(self, level=8)

        Calculates approximations of scaling function (`phi`) and wavelet
        function (`psi`) on xgrid (`x`) at a given level of refinement.

        Parameters
        ----------
        level : int, optional
            Level of refinement (default: 8).

        Returns
        -------
        [phi, psi, x] : array_like
            For orthogonal wavelets returns scaling function, wavelet function
            and xgrid - [phi, psi, x].

        [phi_d, psi_d, phi_r, psi_r, x] : array_like
            For biorthogonal wavelets returns scaling and wavelet function both
            for decomposition and reconstruction and xgrid

        Examples
        --------
        >>> import pywt
        >>> # Orthogonal
        >>> wavelet = pywt.Wavelet('db2')
        >>> phi, psi, x = wavelet.wavefun(level=5)
        >>> # Biorthogonal
        >>> wavelet = pywt.Wavelet('bior3.5')
        >>> phi_d, psi_d, phi_r, psi_r, x = wavelet.wavefun(level=5)

        """
        cdef pywt_index_t filter_length "filter_length"
        cdef pywt_index_t right_extent_length "right_extent_length"
        cdef pywt_index_t output_length "output_length"
        cdef pywt_index_t keep_length "keep_length"
        cdef np.float64_t n, n_mul
        # one-element views over the scalars above, passed to upcoef as the
        # single input coefficient
        cdef np.float64_t[::1] n_arr = <np.float64_t[:1]> &n
        cdef np.float64_t[::1] n_mul_arr = <np.float64_t[:1]> &n_mul
        cdef double p "p"
        cdef Wavelet other "other"
        cdef phi_d, psi_d, phi_r, psi_r

        # coefficient amplitude sqrt(2)**level and grid refinement 2**level
        n = pow(sqrt(2.), <double>level)
        p = (pow(2., <double>level))

        if self.w.base.orthogonal:
            filter_length = self.w.dec_len
            output_length = <pywt_index_t> ((filter_length-1) * p + 1)
            keep_length = get_keep_length(output_length, level, filter_length)
            output_length = fix_output_length(output_length, keep_length)

            right_extent_length = get_right_extent_length(output_length,
                                                          keep_length)

            # phi, psi, x
            return [np.concatenate(([0.],
                                    keep(upcoef(True, n_arr, self, level, 0), keep_length),
                                    np.zeros(right_extent_length))),
                    np.concatenate(([0.],
                                    keep(upcoef(False, n_arr, self, level, 0), keep_length),
                                    np.zeros(right_extent_length))),
                    np.linspace(0.0, (output_length-1)/p, output_length)]
        else:
            if self.w.base.biorthogonal and (self.w.vanishing_moments_psi % 4) != 1:
                # FIXME: I don't think this branch is well tested
                n_mul = -n
            else:
                n_mul = n

            # decomposition functions are computed from the wavelet whose
            # filter bank is this one's inverse_filter_bank
            other = Wavelet(filter_bank=self.inverse_filter_bank)

            filter_length  = other.w.dec_len
            output_length = <pywt_index_t> ((filter_length-1) * p)
            keep_length = get_keep_length(output_length, level, filter_length)
            output_length = fix_output_length(output_length, keep_length)
            right_extent_length = get_right_extent_length(output_length, keep_length)

            phi_d  = np.concatenate(([0.],
                                     keep(upcoef(True, n_arr, other, level, 0), keep_length),
                                     np.zeros(right_extent_length)))
            psi_d  = np.concatenate(([0.],
                                     keep(upcoef(False, n_mul_arr, other, level, 0),
                                          keep_length),
                                     np.zeros(right_extent_length)))

            filter_length = self.w.dec_len
            output_length = <pywt_index_t> ((filter_length-1) * p)
            keep_length = get_keep_length(output_length, level, filter_length)
            output_length = fix_output_length(output_length, keep_length)
            right_extent_length = get_right_extent_length(output_length, keep_length)

            phi_r  = np.concatenate(([0.],
                                     keep(upcoef(True, n_arr, self, level, 0), keep_length),
                                     np.zeros(right_extent_length)))
            psi_r  = np.concatenate(([0.],
                                     keep(upcoef(False, n_mul_arr, self, level, 0),
                                          keep_length),
                                     np.zeros(right_extent_length)))

            return [phi_d, psi_d, phi_r, psi_r,
                    np.linspace(0.0, (output_length - 1) / p, output_length)]

    def __str__(self):
        s = []
        for x in [
            u"Wavelet %s"           % self.name,
            u"  Family name:    %s" % self.family_name,
            u"  Short name:     %s" % self.short_family_name,
            u"  Filters length: %d" % self.dec_len,
            u"  Orthogonal:     %s" % self.orthogonal,
            u"  Biorthogonal:   %s" % self.biorthogonal,
            u"  Symmetry:       %s" % self.symmetry,
            u"  DWT:            True",
            u"  CWT:            False"
            ]:
            s.append(x.rstrip())
        return u'\n'.join(s)

    def __repr__(self):
        template = "{module}.{classname}(name='{name}', filter_bank={filter_bank})"
        return template.format(module=type(self).__module__,
                               classname=type(self).__name__,
                               name=self.name,
                               filter_bank=self.filter_bank)
679
680
681cdef public class ContinuousWavelet [type ContinuousWaveletType, object ContinuousWaveletObject]:
682    """
683    ContinuousWavelet(name, dtype) object describe properties of
684    a continuous wavelet identified by name.
685
686    In order to use a built-in wavelet the parameter name must be
687    a valid name from the wavelist() list.
688
689    """
690    #cdef readonly properties
691    def __cinit__(self, name=u"", dtype=np.float64):
692        cdef object family_code, family_number
693
694        # builtin wavelet
695        self.name = name.lower()
696        self.dt = dtype
697        if np.dtype(self.dt) not in [np.float32, np.float64]:
698            raise ValueError(
699                "Only np.float32 and np.float64 dtype are supported for "
700                "ContinuousWavelet objects.")
701        if len(self.name) >= 4 and self.name[:4] in ['cmor', 'shan', 'fbsp']:
702            base_name = self.name[:4]
703            if base_name == self.name:
704                if base_name == 'fbsp':
705                    msg = (
706                        "Wavelets of family {0}, without parameters "
707                        "specified in the name are deprecated.  The name "
708                        "should take the form {0}M-B-C where M is the spline "
709                        "order and B, C are floats representing the bandwidth "
710                        "frequency and center frequency, respectively "
711                        "(example: {0}1-1.5-1.0).").format(base_name)
712                else:
713                    msg = (
714                        "Wavelets from the family {0}, without parameters "
715                        "specified in the name are deprecated. The name "
716                        "should takethe form {0}B-C where B and C are floats "
717                        "representing the bandwidth frequency and center "
718                        "frequency, respectively (example: {0}1.5-1.0)."
719                        ).format(base_name)
720                warnings.warn(msg, FutureWarning)
721        else:
722            base_name = self.name
723        family_code, family_number = wname_to_code(base_name)
724        self.w = <wavelet.ContinuousWavelet*> wavelet.continuous_wavelet(
725            family_code, family_number)
726
727        if self.w is NULL:
728            raise ValueError("Invalid wavelet name '%s'." % self.name)
729        self.number = family_number
730
731        # set wavelet attributes based on frequencies extracted from the name
732        if base_name != self.name:
733            freqs = re.findall(cwt_pattern, self.name)
734            if base_name in ['shan', 'cmor']:
735                if len(freqs) != 2:
736                    raise ValueError(
737                        ("For wavelets of family {0}, the name should take "
738                         "the form {0}B-C where B and C are floats "
739                         "representing the bandwidth frequency and center "
740                         "frequency, respectively. (example: {0}1.5-1.0)"
741                        ).format(base_name))
742                self.w.bandwidth_frequency = float(freqs[0])
743                self.w.center_frequency = float(freqs[1])
744            elif base_name in ['fbsp', ]:
745                if len(freqs) != 3:
746                    raise ValueError(
747                        ("For wavelets of family {0}, the name should take "
748                         "the form {0}M-B-C where M is the spline order and B"
749                         ", C are floats representing the bandwidth frequency "
750                         "and center frequency, respectively "
751                         "(example: {0}1-1.5-1.0).").format(base_name))
752                M = float(freqs[0])
753                self.w.bandwidth_frequency = float(freqs[1])
754                self.w.center_frequency = float(freqs[2])
755                if M < 1 or M % 1 != 0:
756                    raise ValueError(
757                        "Wavelet spline order must be an integer >= 1.")
758                self.w.fbsp_order = int(M)
759            else:
760                raise ValueError(
761                    "Invalid continuous wavelet name '%s'." % self.name)
762
763
764    def __dealloc__(self):
765        if self.w is not NULL:
766            wavelet.free_continuous_wavelet(self.w)
767            self.w = NULL
768
769    def __reduce__(self):
770        return (ContinuousWavelet, (self.name, self.dt))
771
772    property family_number:
773        "Wavelet family number"
774        def __get__(self):
775            return self.number
776
777    property family_name:
778        "Wavelet family name"
779        def __get__(self):
780            return self.w.base.family_name.decode('latin-1')
781
782    property short_family_name:
783        "Short wavelet family name"
784        def __get__(self):
785            return self.w.base.short_name.decode('latin-1')
786
787    property orthogonal:
788        "Is orthogonal"
789        def __get__(self):
790            return bool(self.w.base.orthogonal)
791        def __set__(self, int value):
792            self.w.base.orthogonal = (value != 0)
793
794    property biorthogonal:
795        "Is biorthogonal"
796        def __get__(self):
797            return bool(self.w.base.biorthogonal)
798        def __set__(self, int value):
799            self.w.base.biorthogonal = (value != 0)
800
801    property complex_cwt:
802        "CWT is complex"
803        def __get__(self):
804            return bool(self.w.complex_cwt)
805        def __set__(self, int value):
806            self.w.complex_cwt = (value != 0)
807
808    property lower_bound:
809        "Lower Bound"
810        def __get__(self):
811            if self.w.lower_bound != self.w.upper_bound:
812                return self.w.lower_bound
813        def __set__(self, float value):
814            self.w.lower_bound = value
815
816    property upper_bound:
817        "Upper Bound"
818        def __get__(self):
819            if self.w.upper_bound != self.w.lower_bound:
820                return self.w.upper_bound
821        def __set__(self, float value):
822            self.w.upper_bound = value
823
824    property center_frequency:
825        "Center frequency (shan, fbsp, cmor)"
826        def __get__(self):
827            if self.w.center_frequency > 0:
828                return self.w.center_frequency
829        def __set__(self, float value):
830            self.w.center_frequency = value
831
832    property bandwidth_frequency:
833        "Bandwidth frequency (shan, fbsp, cmor)"
834        def __get__(self):
835            if self.w.bandwidth_frequency > 0:
836                return self.w.bandwidth_frequency
837        def __set__(self, float value):
838            self.w.bandwidth_frequency = value
839
840    property fbsp_order:
841        "order parameter for fbsp"
842        def __get__(self):
843            if self.w.fbsp_order != 0:
844                return self.w.fbsp_order
845        def __set__(self, unsigned int value):
846            self.w.fbsp_order = value
847
848    property symmetry:
849        "Wavelet symmetry"
850        def __get__(self):
851            if self.w.base.symmetry == wavelet.ASYMMETRIC:
852                return "asymmetric"
853            elif self.w.base.symmetry == wavelet.NEAR_SYMMETRIC:
854                return "near symmetric"
855            elif self.w.base.symmetry == wavelet.SYMMETRIC:
856                return "symmetric"
857            elif self.w.base.symmetry == wavelet.ANTI_SYMMETRIC:
858                return "anti-symmetric"
859            else:
860                return "unknown"
861
862    def wavefun(self, int level=8, length=None):
863        """
864        wavefun(self, level=8, length=None)
865
866        Calculates approximations of wavelet function (``psi``) on xgrid
867        (``x``) at a given level of refinement or length itself.
868
869        Parameters
870        ----------
871        level : int, optional
872            Level of refinement (default: 8). Defines the length by
873            ``2**level`` if length is not set.
874        length : int, optional
875            Number of samples. If set to None, the length is set to
876            ``2**level`` instead.
877
878        Returns
879        -------
880        psi : array_like
881            Wavelet function computed for grid xval
882        xval : array_like
883            grid going from lower_bound to upper_bound
884
885        Notes
886        -----
887        The effective support are set with ``lower_bound`` and ``upper_bound``.
888        The wavelet function is complex for ``'cmor'``, ``'shan'``, ``'fbsp'``
889        and ``'cgau'``.
890
891        The complex frequency B-spline wavelet (``'fbsp'``) has
892        ``bandwidth_frequency``, ``center_frequency`` and ``fbsp_order`` as
893        additional parameters.
894
895        The complex Shannon wavelet (``'shan'``) has ``bandwidth_frequency``
896        and ``center_frequency`` as additional parameters.
897
898        The complex Morlet wavelet (``'cmor'``) has ``bandwidth_frequency``
899        and ``center_frequency`` as additional parameters.
900
901        Examples
902        --------
903        >>> import pywt
904        >>> import matplotlib.pyplot as plt
905        >>> lb = -5
906        >>> ub = 5
907        >>> n = 1000
908        >>> wavelet = pywt.ContinuousWavelet("gaus8")
909        >>> wavelet.upper_bound = ub
910        >>> wavelet.lower_bound = lb
911        >>> [psi,xval] = wavelet.wavefun(length=n)
912        >>> plt.plot(xval,psi) # doctest: +ELLIPSIS
913        [<matplotlib.lines.Line2D object at ...>]
914        >>> plt.title("Gaussian Wavelet of order 8") # doctest: +ELLIPSIS
915        <matplotlib.text.Text object at ...>
916        >>> plt.show() # doctest: +SKIP
917
918        >>> import pywt
919        >>> import matplotlib.pyplot as plt
920        >>> lb = -5
921        >>> ub = 5
922        >>> n = 1000
923        >>> wavelet = pywt.ContinuousWavelet("cgau4")
924        >>> wavelet.upper_bound = ub
925        >>> wavelet.lower_bound = lb
926        >>> [psi,xval] = wavelet.wavefun(length=n)
927        >>> plt.subplot(211) # doctest: +ELLIPSIS
928        <matplotlib.axes._subplots.AxesSubplot object at ...>
929        >>> plt.plot(xval,np.real(psi)) # doctest: +ELLIPSIS
930        [<matplotlib.lines.Line2D object at ...>]
931        >>> plt.title("Real part") # doctest: +ELLIPSIS
932        <matplotlib.text.Text object at ...>
933        >>> plt.subplot(212) # doctest: +ELLIPSIS
934        <matplotlib.axes._subplots.AxesSubplot object at ...>
935        >>> plt.plot(xval,np.imag(psi)) # doctest: +ELLIPSIS
936        [<matplotlib.lines.Line2D object at ...>]
937        >>> plt.title("Imaginary part") # doctest: +ELLIPSIS
938        <matplotlib.text.Text object at ...>
939        >>> plt.show() # doctest: +SKIP
940
941        """
942        cdef pywt_index_t output_length "output_length"
943        cdef psi_i, psi_r, psi
944        cdef np.float64_t[::1] x64, psi64
945        cdef np.float32_t[::1] x32, psi32
946
947        p = (pow(2., <double>level))
948
949        if self.w is not NULL:
950            if length is None:
951                output_length = <pywt_index_t>p
952            else:
953                output_length = <pywt_index_t>length
954            if (self.dt == np.float64):
955                x64 = np.linspace(self.w.lower_bound, self.w.upper_bound, output_length, dtype=self.dt)
956            else:
957                x32 = np.linspace(self.w.lower_bound, self.w.upper_bound, output_length, dtype=self.dt)
958            if self.w.complex_cwt:
959                if (self.dt == np.float64):
960                    psi_r, psi_i = cwt_psi_single(x64, self, output_length)
961                    return [np.asarray(psi_r, dtype=self.dt) + 1j * np.asarray(psi_i, dtype=self.dt),
962                        np.asarray(x64, dtype=self.dt)]
963                else:
964                    psi_r, psi_i = cwt_psi_single(x32, self, output_length)
965                    return [np.asarray(psi_r, dtype=self.dt) + 1j * np.asarray(psi_i, dtype=self.dt),
966                            np.asarray(x32, dtype=self.dt)]
967            else:
968                if (self.dt == np.float64):
969                    psi = cwt_psi_single(x64, self, output_length)
970                    return [np.asarray(psi, dtype=self.dt),
971                            np.asarray(x64, dtype=self.dt)]
972
973                else:
974                    psi = cwt_psi_single(x32, self, output_length)
975                    return [np.asarray(psi, dtype=self.dt),
976                            np.asarray(x32, dtype=self.dt)]
977
978    def __str__(self):
979        s = []
980        for x in [
981            u"ContinuousWavelet %s" % self.name,
982            u"  Family name:    %s" % self.family_name,
983            u"  Short name:     %s" % self.short_family_name,
984            u"  Symmetry:       %s" % self.symmetry,
985            u"  DWT:            False",
986            u"  CWT:            True",
987            u"  Complex CWT:    %s" % self.complex_cwt
988            ]:
989            s.append(x.rstrip())
990        return u'\n'.join(s)
991
992    def __repr__(self):
993        repr = "{module}.{classname}(name='{name}')"
994        return repr.format(module=type(self).__module__,
995                           classname=type(self).__name__,
996                           name=self.name)
997
998
cdef pywt_index_t get_keep_length(pywt_index_t output_length,
                                  int level, pywt_index_t filter_length):
    # Start from 1 and apply n -> 2*n + (filter_length - 2) `level` times
    # (no iterations when level <= 0). `output_length` is accepted for
    # interface compatibility but not used.
    cdef pywt_index_t growth = filter_length - 2
    cdef pywt_index_t span = 1
    cdef int step
    for step in range(level):
        span = span * 2 + growth
    return span
1009
cdef pywt_index_t fix_output_length(pywt_index_t output_length, pywt_index_t keep_length):
    # Clamp output_length from below to keep_length + 2.
    cdef pywt_index_t minimum = keep_length + 2
    return minimum if output_length < minimum else output_length
1014
cdef pywt_index_t get_right_extent_length(pywt_index_t output_length, pywt_index_t keep_length):
    # Number of samples remaining after keep_length plus one more.
    return output_length - (keep_length + 1)
1017
1018
def wavelet_from_object(wavelet):
    """Python-callable wrapper around ``c_wavelet_from_object``.

    ``Wavelet`` and ``ContinuousWavelet`` instances are returned unchanged;
    any other value is passed to ``Wavelet(...)``.
    """
    converted = c_wavelet_from_object(wavelet)
    return converted
1021
1022
cdef c_wavelet_from_object(wavelet):
    # Existing wavelet objects pass through untouched; anything else
    # (presumably a wavelet name) is handed to the Wavelet constructor.
    if not isinstance(wavelet, (Wavelet, ContinuousWavelet)):
        return Wavelet(wavelet)
    return wavelet
1028
1029
cpdef np.dtype _check_dtype(data):
    """Check for cA/cD input what (if any) the dtype is.

    float32, float64, complex64 and complex128 inputs keep their dtype.
    Half precision is promoted to float32. Any other complex dtype
    (extended precision, or non-native byte order) runs as complex128.
    Every remaining dtype (integers, bools, extended-precision floats, ...)
    maps to float64, as do objects without a ``dtype`` attribute
    (lists, scalars, ...).
    """
    cdef np.dtype dt
    # Only the attribute lookup is guarded: an AttributeError raised by
    # anything below must not be mistaken for "input has no dtype".
    try:
        dt = data.dtype
    except AttributeError:
        return np.dtype('float64')

    if dt in (np.float64, np.float32, np.complex64, np.complex128):
        return dt
    if dt == np.half:
        # half-precision input converted to single precision
        return np.dtype('float32')
    if np.issubdtype(dt, np.complexfloating):
        # Extended-precision complex is not supported; run at reduced
        # precision. Testing the dtype kind (instead of comparing against
        # np.complex256) also covers platforms without that alias (e.g.
        # complex192) and byte-swapped complex input, which would otherwise
        # be mapped to float64 and lose the imaginary part.
        return np.dtype('complex128')
    # integer input was always accepted; convert to float64
    return np.dtype('float64')
1048
1049
1050# TODO: Can this be replaced by the take parameter of upcoef? Or vice-versa?
def keep(arr, keep_length):
    """Return the centered slice of ``arr`` with ``keep_length`` items.

    When ``arr`` has no more than ``keep_length`` items it is returned
    unchanged (the same object). Otherwise the surplus is trimmed from both
    ends; with an odd surplus, the extra item is trimmed from the right.
    """
    surplus = len(arr) - keep_length
    if surplus <= 0:
        return arr
    start = surplus // 2
    return arr[start:start + keep_length]
1057
1058
1059# Some utility functions
1060
cdef object float64_array_to_list(double* data, pywt_index_t n):
    # Copy the first n doubles of `data` into a new Python list of floats
    # (an empty list when n <= 0). No bounds checking is done.
    cdef pywt_index_t idx
    return [data[idx] for idx in range(n)]
1070
1071
cdef void copy_object_to_float64_array(source, double* dest) except *:
    # Write each item of the iterable `source`, converted to double, into
    # consecutive slots of `dest`. The caller must make sure `dest` is large
    # enough; no bounds check is done. A non-convertible item raises (the
    # items before it have already been written).
    cdef pywt_index_t pos
    cdef double value
    for pos, value in enumerate(source):
        dest[pos] = value
1079
1080
cdef void copy_object_to_float32_array(source, float* dest) except *:
    # Write each item of the iterable `source`, converted to float, into
    # consecutive slots of `dest`. The caller must make sure `dest` is large
    # enough; no bounds check is done. A non-convertible item raises (the
    # items before it have already been written).
    cdef pywt_index_t pos
    cdef float value
    for pos, value in enumerate(source):
        dest[pos] = value
1088