# Copyright (c) 2006-2012 Filip Wasilewski <http://en.ig.ma/>
# Copyright (c) 2012-2018 The PyWavelets Developers
# <https://github.com/PyWavelets/pywt>
# See COPYING for license details.

__doc__ = """Cython wrapper for low-level C wavelet transform implementation."""
__all__ = ['MODES', 'Modes', 'DiscreteContinuousWavelet', 'Wavelet',
           'ContinuousWavelet', 'wavelist', 'families']


import warnings
import re

from . cimport c_wt
from . cimport common
from ._dwt cimport upcoef
from ._cwt cimport cwt_psi_single

from libc.math cimport pow, sqrt

import numpy as np


# Caution: order of _old_modes entries must match _Modes.modes below
_old_modes = ['zpd',
              'cpd',
              'sym',
              'ppd',
              'sp1',
              'per',
              ]

_attr_deprecation_msg = ('{old} has been renamed to {new} and will '
                         'be unavailable in a future version '
                         'of pywt.')

# Extract float/int parameters from a wavelet name. Examples:
#    re.findall(cwt_pattern, 'fbsp1-1.5-1') -> ['1', '1.5', '1']
cwt_pattern = re.compile(r'\D+(\d+\.*\d*)+')


# Return the C library's discrete/continuous flag for a wavelet family code:
# nonzero for a discrete family, 0 for a continuous one.
# Raises ValueError if the C library reports the family as unknown (-1).
# ``except -1`` makes Cython propagate that exception to the caller instead of
# printing and discarding it (the default for ``cdef int`` in Cython < 3).
cdef int is_discrete_wav(WAVELET_NAME name) except -1:
    cdef int is_discrete = wavelet.is_discrete_wavelet(name)
    if is_discrete == -1:
        raise ValueError("unrecognized wavelet family name")
    return is_discrete


class _Modes(object):
    """
    Because the most common and practical way of representing digital signals
    in computer science is with finite arrays of values, some extrapolation of
    the input data has to be performed in order to extend the signal before
    computing the :ref:`Discrete Wavelet Transform <ref-dwt>` using the
    cascading filter banks algorithm.

    Depending on the extrapolation method, significant artifacts at the
    signal's borders can be introduced during that process, which in turn may
    lead to inaccurate computations of the :ref:`DWT <ref-dwt>` at the signal's
    ends.

    PyWavelets provides several methods of signal extrapolation that can be
    used to minimize this negative effect:

      zero - zero-padding                   0  0 | x1 x2 ... xn | 0  0
      constant - constant-padding          x1 x1 | x1 x2 ... xn | xn xn
      symmetric - symmetric-padding        x2 x1 | x1 x2 ... xn | xn xn-1
      reflect - reflect-padding            x3 x2 | x1 x2 ... xn | xn-1 xn-2
      periodic - periodic-padding        xn-1 xn | x1 x2 ... xn | x1 x2
      smooth - smooth-padding             (1st derivative interpolation)
      antisymmetric - anti-symmetric     -x2 -x1 | x1 x2 ... xn | -xn -xn-1
      antireflect - anti-reflect         -x3 -x2 | x1 x2 ... xn | -xn-1 -xn-2

    DWT performed for these extension modes is slightly redundant, but ensures
    a perfect reconstruction for IDWT. To receive the smallest possible number
    of coefficients, computations can be performed with the periodization mode:

      periodization - like periodic-padding but gives the smallest possible
                      number of decomposition coefficients. IDWT must be
                      performed with the same mode.

    Examples
    --------
    >>> import pywt
    >>> pywt.Modes.modes
    ['zero', 'constant', 'symmetric', 'periodic', 'smooth', 'periodization', 'reflect', 'antisymmetric', 'antireflect']
    >>> # The different ways of passing wavelet and mode parameters
    >>> (a, d) = pywt.dwt([1,2,3,4,5,6], 'db2', 'smooth')
    >>> (a, d) = pywt.dwt([1,2,3,4,5,6], pywt.Wavelet('db2'), pywt.Modes.smooth)

    Notes
    -----
    Extending data in context of PyWavelets does not mean reallocation of the
    data in computer's physical memory and copying values, but rather computing
    the extra values only when they are needed. This feature saves extra
    memory and CPU resources and helps to avoid page swapping when handling
    relatively big data arrays on computers with low physical memory.

    """
    zero = common.MODE_ZEROPAD
    constant = common.MODE_CONSTANT_EDGE
    symmetric = common.MODE_SYMMETRIC
    reflect = common.MODE_REFLECT
    periodic = common.MODE_PERIODIC
    smooth = common.MODE_SMOOTH
    periodization = common.MODE_PERIODIZATION
    antisymmetric = common.MODE_ANTISYMMETRIC
    antireflect = common.MODE_ANTIREFLECT

    # Caution: order in modes list below must match _old_modes above
    modes = ["zero", "constant", "symmetric", "periodic", "smooth",
             "periodization", "reflect", "antisymmetric", "antireflect"]

    def from_object(self, mode):
        """Convert a mode name or integer mode code to the integer mode code.

        Integers must lie strictly between ``common.MODE_INVALID`` and
        ``common.MODE_MAX``. Strings are looked up as attributes of
        ``Modes``, which also resolves deprecated names with a warning.

        Raises
        ------
        ValueError
            If the integer is out of range or the name is not a mode.
        """
        if isinstance(mode, int):
            if mode <= common.MODE_INVALID or mode >= common.MODE_MAX:
                raise ValueError("Invalid mode.")
            m = mode
        else:
            try:
                m = getattr(Modes, mode)
            except AttributeError:
                raise ValueError("Unknown mode name '%s'." % mode)
            # ``Modes`` also exposes non-mode attributes (``modes``,
            # ``from_object``, ...); only the integer mode codes are valid.
            if not isinstance(m, int):
                raise ValueError("Unknown mode name '%s'." % mode)

        return m

    def __getattr__(self, mode):
        # Only reached when normal lookup fails: translate deprecated short
        # mode names to their new names (with a warning), then retry.
        if mode in _old_modes:
            new_mode = Modes.modes[_old_modes.index(mode)]
            warnings.warn(_attr_deprecation_msg.format(old=mode, new=new_mode),
                          DeprecationWarning)
            mode = new_mode
        return Modes.__getattribute__(mode)


Modes = _Modes()


class _DeprecatedMODES(_Modes):
    msg = ("MODES has been renamed to Modes and will be "
           "removed in a future version of pywt.")

    def __getattribute__(self, attr):
        """Override so that deprecation warning is shown
        every time MODES is used.

        N.B. have to use __getattribute__ as well as __getattr__
        to ensure warning on e.g. `MODES.symmetric`.
        """
        if not attr.startswith('_'):
            warnings.warn(_DeprecatedMODES.msg, DeprecationWarning)
        return _Modes.__getattribute__(self, attr)

    def __getattr__(self, attr):
        """Override so that deprecation warning is shown
        every time MODES is used.
        """
        warnings.warn(_DeprecatedMODES.msg, DeprecationWarning)
        return _Modes.__getattr__(self, attr)


MODES = _DeprecatedMODES()

###############################################################################
# Wavelet

include "wavelets_list.pxi"  # __wname_to_code

# Map a wavelet name to its ``(family_code, family_number)`` entry in
# ``__wname_to_code``. Parameterized continuous names ('cmorB-C', 'shanB-C',
# 'fbspM-B-C') are looked up by their 4-letter family prefix.
# Raises ValueError for unknown names.
cdef object wname_to_code(name):
    cdef object code_number
    try:
        if len(name) > 4 and name[:4] in ['cmor', 'shan', 'fbsp']:
            name = name[:4]
        code_number = __wname_to_code[name]
        return code_number
    except KeyError:
        raise ValueError("Unknown wavelet name '%s', check wavelist() for the "
                         "list of available builtin wavelets." % name)


def wavelist(family=None, kind='all'):
    """
    wavelist(family=None, kind='all')

    Returns list of available wavelet names for the given family name.

    Parameters
    ----------
    family : str, optional
        Short family name. If the family name is None (default) then names
        of all the built-in wavelets are returned. Otherwise the function
        returns names of wavelets that belong to the given family.
        Valid names are::

            'haar', 'db', 'sym', 'coif', 'bior', 'rbio', 'dmey', 'gaus',
            'mexh', 'morl', 'cgau', 'shan', 'fbsp', 'cmor'

    kind : {'all', 'continuous', 'discrete'}, optional
        Whether to return only wavelet names of discrete or continuous
        wavelets, or all wavelets. Default is ``'all'``.
        Ignored if ``family`` is specified.

    Returns
    -------
    wavelist : list of str
        List of available wavelet names.

    Raises
    ------
    ValueError
        If ``kind`` is not one of the accepted values or ``family`` is not a
        known short family name.

    Examples
    --------
    >>> import pywt
    >>> pywt.wavelist('coif')
    ['coif1', 'coif2', 'coif3', 'coif4', 'coif5', 'coif6', 'coif7', ...
    >>> pywt.wavelist(kind='continuous')
    ['cgau1', 'cgau2', 'cgau3', 'cgau4', 'cgau5', 'cgau6', 'cgau7', ...

    """
    cdef object wavelets, sorting_list, name

    if kind not in ('all', 'continuous', 'discrete'):
        raise ValueError("Unrecognized value for `kind`: %s" % kind)

    def _check_kind(name, kind):
        # True when the named wavelet matches the requested kind.
        if kind == 'all':
            return True

        family_code, family_number = wname_to_code(name)
        is_discrete = is_discrete_wav(family_code)
        if kind == 'discrete':
            return is_discrete
        else:
            return not is_discrete

    # Sort keys (2-letter prefix, name length, name) give a natural order,
    # e.g. 'db2' before 'db10'.
    sorting_list = []
    wavelets = []
    if family is None:
        for name in __wname_to_code:
            if _check_kind(name, kind):
                sorting_list.append((name[:2], len(name), name))
    elif family in __wfamily_list_short:
        for name in __wname_to_code:
            if name.startswith(family):
                sorting_list.append((name[:2], len(name), name))
    else:
        raise ValueError("Invalid short family name '%s'." % family)

    sorting_list.sort()
    for _, _, name in sorting_list:
        wavelets.append(name)
    return wavelets


def families(int short=True):
    """
    families(short=True)

    Returns a list of available built-in wavelet families.

    Currently the built-in families are:

    * Haar (``haar``)
    * Daubechies (``db``)
    * Symlets (``sym``)
    * Coiflets (``coif``)
    * Biorthogonal (``bior``)
    * Reverse biorthogonal (``rbio``)
    * `"Discrete"` FIR approximation of Meyer wavelet (``dmey``)
    * Gaussian wavelets (``gaus``)
    * Mexican hat wavelet (``mexh``)
    * Morlet wavelet (``morl``)
    * Complex Gaussian wavelets (``cgau``)
    * Shannon wavelets (``shan``)
    * Frequency B-Spline wavelets (``fbsp``)
    * Complex Morlet wavelets (``cmor``)

    Parameters
    ----------
    short : bool, optional
        Use short names (default: True).

    Returns
    -------
    families : list
        List of available wavelet families (a copy; safe to modify).

    Examples
    --------
    >>> import pywt
    >>> pywt.families()
    ['haar', 'db', 'sym', 'coif', 'bior', 'rbio', 'dmey', 'gaus', 'mexh', 'morl', 'cgau', 'shan', 'fbsp', 'cmor']
    >>> pywt.families(short=False)
    ['Haar', 'Daubechies', 'Symlets', 'Coiflets', 'Biorthogonal', 'Reverse biorthogonal', 'Discrete Meyer (FIR Approximation)', 'Gaussian', 'Mexican hat wavelet', 'Morlet wavelet', 'Complex Gaussian wavelets', 'Shannon wavelets', 'Frequency B-Spline wavelets', 'Complex Morlet wavelets']

    """
    if short:
        return __wfamily_list_short[:]
    return __wfamily_list_long[:]


def DiscreteContinuousWavelet(name=u"", object filter_bank=None):
    """
    DiscreteContinuousWavelet(name, filter_bank=None) returns a
    Wavelet or a ContinuousWavelet object depending on the given name.

    In order to use a built-in wavelet the parameter name must be
    a valid name from the wavelist() list.
    To create a custom wavelet object, filter_bank parameter must
    be specified. It can be either a list of four filters or an object
    that has a `filter_bank` attribute which returns a list of four
    filters - just like the Wavelet instance itself.

    For a ContinuousWavelet, filter_bank cannot be used and must remain unset.

    """
    if not name and filter_bank is None:
        raise TypeError("Wavelet name or filter bank must be specified.")
    if filter_bank is None:
        name = name.lower()
        family_code, family_number = wname_to_code(name)
        if is_discrete_wav(family_code):
            return Wavelet(name, filter_bank)
        else:
            return ContinuousWavelet(name)
    else:
        return Wavelet(name, filter_bank)


cdef public class Wavelet [type WaveletType, object WaveletObject]:
    """
    Wavelet(name, filter_bank=None) object describes properties of
    a wavelet identified by name.

    In order to use a built-in wavelet the parameter name must be
    a valid name from the wavelist() list.
    To create a custom wavelet object, filter_bank parameter must
    be specified. It can be either a list of four filters or an object
    that has a `filter_bank` attribute which returns a list of four
    filters - just like the Wavelet instance itself.

    """
    #cdef readonly properties
    def __cinit__(self, name=u"", object filter_bank=None):
        cdef object family_code, family_number
        cdef object filters
        cdef pywt_index_t filter_length
        cdef object dec_lo, dec_hi, rec_lo, rec_hi

        if not name and filter_bank is None:
            raise TypeError("Wavelet name or filter bank must be specified.")

        if filter_bank is None:
            # builtin wavelet
            self.name = name.lower()
            family_code, family_number = wname_to_code(self.name)
            if is_discrete_wav(family_code):
                self.w = <wavelet.DiscreteWavelet*> wavelet.discrete_wavelet(family_code, family_number)
            if self.w is NULL:
                if self.name in wavelist(kind='continuous'):
                    raise ValueError("The `Wavelet` class is for discrete "
                                     "wavelets, %s is a continuous wavelet. Use "
                                     "pywt.ContinuousWavelet instead" % self.name)
                else:
                    raise ValueError("Invalid wavelet name '%s'." % self.name)
            # ``number`` is only set for built-in wavelets; __reduce__ relies
            # on it staying None for custom filter banks.
            self.number = family_number
        else:
            if hasattr(filter_bank, "filter_bank"):
                filters = filter_bank.filter_bank
                if len(filters) != 4:
                    raise ValueError("Expected filter bank with 4 filters, "
                                     "got filter bank with %d filters." % len(filters))
            elif hasattr(filter_bank, "get_filters_coeffs"):
                msg = ("Creating custom Wavelets using objects that define "
                       "`get_filters_coeffs` method is deprecated. "
                       "The `filter_bank` parameter should define a "
                       "`filter_bank` attribute instead of "
                       "`get_filters_coeffs` method.")
                warnings.warn(msg, DeprecationWarning)
                filters = filter_bank.get_filters_coeffs()
                if len(filters) != 4:
                    msg = ("Expected filter bank with 4 filters, got filter "
                           "bank with %d filters." % len(filters))
                    raise ValueError(msg)
            else:
                filters = filter_bank
                if len(filters) != 4:
                    msg = ("Expected list of 4 filters coefficients, "
                           "got %d filters." % len(filters))
                    raise ValueError(msg)
            try:
                dec_lo = np.asarray(filters[0], dtype=np.float64)
                dec_hi = np.asarray(filters[1], dtype=np.float64)
                rec_lo = np.asarray(filters[2], dtype=np.float64)
                rec_hi = np.asarray(filters[3], dtype=np.float64)
            except TypeError:
                raise ValueError("Filter bank with numeric values required.")

            if not (1 == dec_lo.ndim == dec_hi.ndim ==
                    rec_lo.ndim == rec_hi.ndim):
                raise ValueError("All filters in filter bank must be 1D.")

            filter_length = len(dec_lo)
            if filter_length <= 0:
                raise ValueError("All filters in filter bank must have "
                                 "length greater than 0.")
            if not (filter_length == len(dec_hi) == len(rec_lo) ==
                    len(rec_hi)):
                raise ValueError("All filters in filter bank must have "
                                 "the same length.")

            self.w = <wavelet.DiscreteWavelet*> wavelet.blank_discrete_wavelet(filter_length)
            if self.w is NULL:
                raise MemoryError("Could not allocate memory for given "
                                  "filter bank.")

            # copy values to struct (single and double precision copies)
            copy_object_to_float32_array(dec_lo, self.w.dec_lo_float)
            copy_object_to_float32_array(dec_hi, self.w.dec_hi_float)
            copy_object_to_float32_array(rec_lo, self.w.rec_lo_float)
            copy_object_to_float32_array(rec_hi, self.w.rec_hi_float)

            copy_object_to_float64_array(dec_lo, self.w.dec_lo_double)
            copy_object_to_float64_array(dec_hi, self.w.dec_hi_double)
            copy_object_to_float64_array(rec_lo, self.w.rec_lo_double)
            copy_object_to_float64_array(rec_hi, self.w.rec_hi_double)

            self.name = name

    def __dealloc__(self):
        if self.w is not NULL:
            wavelet.free_discrete_wavelet(self.w)
            self.w = NULL

    def __reduce__(self):
        # Built-in wavelets are re-created by name so that family metadata
        # (orthogonality, symmetry, vanishing moments) survives pickling;
        # custom wavelets have no such metadata and are rebuilt from filters.
        if self.number is not None:
            return (Wavelet, (self.name, ))
        return (Wavelet, (self.name, self.filter_bank))

    def __len__(self):
        return self.w.dec_len

    property dec_lo:
        "Lowpass decomposition filter"
        def __get__(self):
            return float64_array_to_list(self.w.dec_lo_double, self.w.dec_len)

    property dec_hi:
        "Highpass decomposition filter"
        def __get__(self):
            return float64_array_to_list(self.w.dec_hi_double, self.w.dec_len)

    property rec_lo:
        "Lowpass reconstruction filter"
        def __get__(self):
            return float64_array_to_list(self.w.rec_lo_double, self.w.rec_len)

    property rec_hi:
        "Highpass reconstruction filter"
        def __get__(self):
            return float64_array_to_list(self.w.rec_hi_double, self.w.rec_len)

    property rec_len:
        "Reconstruction filters length"
        def __get__(self):
            return self.w.rec_len

    property dec_len:
        "Decomposition filters length"
        def __get__(self):
            return self.w.dec_len

    property family_number:
        "Wavelet family number (None for custom wavelets)"
        def __get__(self):
            return self.number

    property family_name:
        "Wavelet family name"
        def __get__(self):
            return self.w.base.family_name.decode('latin-1')

    property short_family_name:
        "Short wavelet family name"
        def __get__(self):
            return self.w.base.short_name.decode('latin-1')

    property orthogonal:
        "Is orthogonal"
        def __get__(self):
            return bool(self.w.base.orthogonal)
        def __set__(self, int value):
            self.w.base.orthogonal = (value != 0)

    property biorthogonal:
        "Is biorthogonal"
        def __get__(self):
            return bool(self.w.base.biorthogonal)
        def __set__(self, int value):
            self.w.base.biorthogonal = (value != 0)

    property symmetry:
        "Wavelet symmetry"
        def __get__(self):
            if self.w.base.symmetry == wavelet.ASYMMETRIC:
                return "asymmetric"
            elif self.w.base.symmetry == wavelet.NEAR_SYMMETRIC:
                return "near symmetric"
            elif self.w.base.symmetry == wavelet.SYMMETRIC:
                return "symmetric"
            elif self.w.base.symmetry == wavelet.ANTI_SYMMETRIC:
                return "anti-symmetric"
            else:
                return "unknown"

    property vanishing_moments_psi:
        "Number of vanishing moments for wavelet function (None if unknown)"
        def __get__(self):
            # a negative value in the C struct means "not known"
            if self.w.vanishing_moments_psi >= 0:
                return self.w.vanishing_moments_psi

    property vanishing_moments_phi:
        "Number of vanishing moments for scaling function (None if unknown)"
        def __get__(self):
            # a negative value in the C struct means "not known"
            if self.w.vanishing_moments_phi >= 0:
                return self.w.vanishing_moments_phi

    property filter_bank:
        """Returns tuple of wavelet filters coefficients
        (dec_lo, dec_hi, rec_lo, rec_hi)
        """
        def __get__(self):
            return (self.dec_lo, self.dec_hi, self.rec_lo, self.rec_hi)

    def get_filters_coeffs(self):
        warnings.warn("The `get_filters_coeffs` method is deprecated. "
                      "Use `filter_bank` attribute instead.", DeprecationWarning)
        return self.filter_bank

    property inverse_filter_bank:
        """Tuple of inverse wavelet filters coefficients
        (rec_lo[::-1], rec_hi[::-1], dec_lo[::-1], dec_hi[::-1])
        """
        def __get__(self):
            return (self.rec_lo[::-1], self.rec_hi[::-1], self.dec_lo[::-1],
                    self.dec_hi[::-1])

    def get_reverse_filters_coeffs(self):
        warnings.warn("The `get_reverse_filters_coeffs` method is deprecated. "
                      "Use `inverse_filter_bank` attribute instead.",
                      DeprecationWarning)
        return self.inverse_filter_bank

    def wavefun(self, int level=8):
        """
        wavefun(self, level=8)

        Calculates approximations of scaling function (`phi`) and wavelet
        function (`psi`) on xgrid (`x`) at a given level of refinement.

        Parameters
        ----------
        level : int, optional
            Level of refinement (default: 8).

        Returns
        -------
        [phi, psi, x] : array_like
            For orthogonal wavelets returns scaling function, wavelet function
            and xgrid - [phi, psi, x].

        [phi_d, psi_d, phi_r, psi_r, x] : array_like
            For biorthogonal wavelets returns scaling and wavelet function both
            for decomposition and reconstruction and xgrid

        Examples
        --------
        >>> import pywt
        >>> # Orthogonal
        >>> wavelet = pywt.Wavelet('db2')
        >>> phi, psi, x = wavelet.wavefun(level=5)
        >>> # Biorthogonal
        >>> wavelet = pywt.Wavelet('bior3.5')
        >>> phi_d, psi_d, phi_r, psi_r, x = wavelet.wavefun(level=5)

        """
        cdef pywt_index_t filter_length
        cdef pywt_index_t right_extent_length
        cdef pywt_index_t output_length
        cdef pywt_index_t keep_length
        cdef np.float64_t n, n_mul
        # One-element views that alias ``n`` / ``n_mul`` (no copy), so the
        # values assigned below are what ``upcoef`` reads.
        cdef np.float64_t[::1] n_arr = <np.float64_t[:1]> &n
        cdef np.float64_t[::1] n_mul_arr = <np.float64_t[:1]> &n_mul
        cdef double p
        cdef Wavelet other
        cdef object phi_d, psi_d, phi_r, psi_r

        n = pow(sqrt(2.), <double>level)
        p = (pow(2., <double>level))

        if self.w.base.orthogonal:
            filter_length = self.w.dec_len
            output_length = <pywt_index_t> ((filter_length-1) * p + 1)
            keep_length = get_keep_length(output_length, level, filter_length)
            output_length = fix_output_length(output_length, keep_length)

            right_extent_length = get_right_extent_length(output_length,
                                                          keep_length)

            # phi, psi, x
            return [np.concatenate(([0.],
                                    keep(upcoef(True, n_arr, self, level, 0), keep_length),
                                    np.zeros(right_extent_length))),
                    np.concatenate(([0.],
                                    keep(upcoef(False, n_arr, self, level, 0), keep_length),
                                    np.zeros(right_extent_length))),
                    np.linspace(0.0, (output_length-1)/p, output_length)]
        else:
            if self.w.base.biorthogonal and (self.w.vanishing_moments_psi % 4) != 1:
                # FIXME: I don't think this branch is well tested
                n_mul = -n
            else:
                n_mul = n

            other = Wavelet(filter_bank=self.inverse_filter_bank)

            # decomposition functions, computed from the inverse filter bank
            filter_length = other.w.dec_len
            output_length = <pywt_index_t> ((filter_length-1) * p)
            keep_length = get_keep_length(output_length, level, filter_length)
            output_length = fix_output_length(output_length, keep_length)
            right_extent_length = get_right_extent_length(output_length, keep_length)

            phi_d = np.concatenate(([0.],
                                    keep(upcoef(True, n_arr, other, level, 0), keep_length),
                                    np.zeros(right_extent_length)))
            psi_d = np.concatenate(([0.],
                                    keep(upcoef(False, n_mul_arr, other, level, 0),
                                         keep_length),
                                    np.zeros(right_extent_length)))

            # reconstruction functions, computed from this wavelet's filters
            filter_length = self.w.dec_len
            output_length = <pywt_index_t> ((filter_length-1) * p)
            keep_length = get_keep_length(output_length, level, filter_length)
            output_length = fix_output_length(output_length, keep_length)
            right_extent_length = get_right_extent_length(output_length, keep_length)

            phi_r = np.concatenate(([0.],
                                    keep(upcoef(True, n_arr, self, level, 0), keep_length),
                                    np.zeros(right_extent_length)))
            psi_r = np.concatenate(([0.],
                                    keep(upcoef(False, n_mul_arr, self, level, 0),
                                         keep_length),
                                    np.zeros(right_extent_length)))

            return [phi_d, psi_d, phi_r, psi_r,
                    np.linspace(0.0, (output_length - 1) / p, output_length)]

    def __str__(self):
        s = []
        for x in [
            u"Wavelet %s" % self.name,
            u"  Family name:    %s" % self.family_name,
            u"  Short name:     %s" % self.short_family_name,
            u"  Filters length: %d" % self.dec_len,
            u"  Orthogonal:     %s" % self.orthogonal,
            u"  Biorthogonal:   %s" % self.biorthogonal,
            u"  Symmetry:       %s" % self.symmetry,
            u"  DWT:            True",
            u"  CWT:            False"
            ]:
            s.append(x.rstrip())
        return u'\n'.join(s)

    def __repr__(self):
        repr = "{module}.{classname}(name='{name}', filter_bank={filter_bank})"
        return repr.format(module=type(self).__module__,
                           classname=type(self).__name__,
                           name=self.name,
                           filter_bank=self.filter_bank)


cdef public class ContinuousWavelet [type ContinuousWaveletType, object ContinuousWaveletObject]:
    """
    ContinuousWavelet(name, dtype) object describes properties of
    a continuous wavelet identified by name.

    In order to use a built-in wavelet the parameter name must be
    a valid name from the wavelist() list.

    """
    #cdef readonly properties
    def __cinit__(self, name=u"", dtype=np.float64):
        cdef object family_code, family_number

        # builtin wavelet
        self.name = name.lower()
        self.dt = dtype
        if np.dtype(self.dt) not in [np.float32, np.float64]:
            raise ValueError(
                "Only np.float32 and np.float64 dtype are supported for "
                "ContinuousWavelet objects.")
        if len(self.name) >= 4 and self.name[:4] in ['cmor', 'shan', 'fbsp']:
            base_name = self.name[:4]
            if base_name == self.name:
                if base_name == 'fbsp':
                    msg = (
                        "Wavelets of family {0}, without parameters "
                        "specified in the name are deprecated. The name "
                        "should take the form {0}M-B-C where M is the spline "
                        "order and B, C are floats representing the bandwidth "
                        "frequency and center frequency, respectively "
                        "(example: {0}1-1.5-1.0).").format(base_name)
                else:
                    msg = (
                        "Wavelets from the family {0}, without parameters "
                        "specified in the name are deprecated. The name "
                        "should take the form {0}B-C where B and C are floats "
                        "representing the bandwidth frequency and center "
                        "frequency, respectively (example: {0}1.5-1.0)."
                        ).format(base_name)
                warnings.warn(msg, FutureWarning)
        else:
            base_name = self.name
        family_code, family_number = wname_to_code(base_name)
        self.w = <wavelet.ContinuousWavelet*> wavelet.continuous_wavelet(
            family_code, family_number)

        if self.w is NULL:
            raise ValueError("Invalid wavelet name '%s'." % self.name)
        self.number = family_number

        # set wavelet attributes based on frequencies extracted from the name
        if base_name != self.name:
            freqs = re.findall(cwt_pattern, self.name)
            if base_name in ['shan', 'cmor']:
                if len(freqs) != 2:
                    raise ValueError(
                        ("For wavelets of family {0}, the name should take "
                         "the form {0}B-C where B and C are floats "
                         "representing the bandwidth frequency and center "
                         "frequency, respectively. (example: {0}1.5-1.0)"
                         ).format(base_name))
                self.w.bandwidth_frequency = float(freqs[0])
                self.w.center_frequency = float(freqs[1])
            elif base_name in ['fbsp', ]:
                if len(freqs) != 3:
                    raise ValueError(
                        ("For wavelets of family {0}, the name should take "
                         "the form {0}M-B-C where M is the spline order and B"
                         ", C are floats representing the bandwidth frequency "
                         "and center frequency, respectively "
                         "(example: {0}1-1.5-1.0).").format(base_name))
                M = float(freqs[0])
                self.w.bandwidth_frequency = float(freqs[1])
                self.w.center_frequency = float(freqs[2])
                if M < 1 or M % 1 != 0:
                    raise ValueError(
                        "Wavelet spline order must be an integer >= 1.")
                self.w.fbsp_order = int(M)
            else:
                raise ValueError(
                    "Invalid continuous wavelet name '%s'." % self.name)


    def __dealloc__(self):
        if self.w is not NULL:
            wavelet.free_continuous_wavelet(self.w)
            self.w = NULL

    def __reduce__(self):
        return (ContinuousWavelet, (self.name, self.dt))

    property family_number:
        "Wavelet family number"
        def __get__(self):
            return self.number

    property family_name:
        "Wavelet family name"
        def __get__(self):
            return self.w.base.family_name.decode('latin-1')

    property short_family_name:
        "Short wavelet family name"
        def __get__(self):
            return self.w.base.short_name.decode('latin-1')

    property orthogonal:
        "Is orthogonal"
        def __get__(self):
            return bool(self.w.base.orthogonal)
        def __set__(self, int value):
            self.w.base.orthogonal = (value != 0)

    property biorthogonal:
        "Is biorthogonal"
        def __get__(self):
            return bool(self.w.base.biorthogonal)
        def __set__(self, int value):
            self.w.base.biorthogonal = (value != 0)

    property complex_cwt:
        "CWT is complex"
        def __get__(self):
            return bool(self.w.complex_cwt)
        def __set__(self, int value):
            self.w.complex_cwt = (value != 0)

    property lower_bound:
        "Lower Bound (None while lower and upper bounds are equal)"
        def __get__(self):
            if self.w.lower_bound != self.w.upper_bound:
                return self.w.lower_bound
        def __set__(self, float value):
            self.w.lower_bound = value

    property upper_bound:
        "Upper Bound (None while lower and upper bounds are equal)"
        def __get__(self):
            if self.w.upper_bound != self.w.lower_bound:
                return self.w.upper_bound
        def __set__(self, float value):
            self.w.upper_bound = value

    property center_frequency:
        "Center frequency (shan, fbsp, cmor); None if not positive"
        def __get__(self):
            if self.w.center_frequency > 0:
                return self.w.center_frequency
        def __set__(self, float value):
            self.w.center_frequency = value

    property bandwidth_frequency:
        "Bandwidth frequency (shan, fbsp, cmor); None if not positive"
        def __get__(self):
            if self.w.bandwidth_frequency > 0:
                return self.w.bandwidth_frequency
        def __set__(self, float value):
            self.w.bandwidth_frequency = value

    property fbsp_order:
        "order parameter for fbsp; None if unset (0)"
        def __get__(self):
            if self.w.fbsp_order != 0:
                return self.w.fbsp_order
        def __set__(self, unsigned int value):
            self.w.fbsp_order = value

    property symmetry:
        "Wavelet symmetry"
        def __get__(self):
            if self.w.base.symmetry == wavelet.ASYMMETRIC:
                return "asymmetric"
            elif self.w.base.symmetry == wavelet.NEAR_SYMMETRIC:
                return "near symmetric"
            elif self.w.base.symmetry == wavelet.SYMMETRIC:
                return "symmetric"
            elif self.w.base.symmetry == wavelet.ANTI_SYMMETRIC:
                return "anti-symmetric"
            else:
                return "unknown"

    def wavefun(self, int level=8, length=None):
        """
        wavefun(self, level=8, length=None)

        Calculates approximations of wavelet function (``psi``) on xgrid
        (``x``) at a given level of refinement or length itself.

        Parameters
        ----------
        level : int, optional
            Level of refinement (default: 8). Defines the length by
            ``2**level`` if length is not set.
        length : int, optional
            Number of samples. If set to None, the length is set to
            ``2**level`` instead.

        Returns
        -------
        psi : array_like
            Wavelet function computed for grid xval
        xval : array_like
            grid going from lower_bound to upper_bound

        Notes
        -----
        The effective support is set with ``lower_bound`` and ``upper_bound``.
        The wavelet function is complex for ``'cmor'``, ``'shan'``, ``'fbsp'``
        and ``'cgau'``.

        The complex frequency B-spline wavelet (``'fbsp'``) has
        ``bandwidth_frequency``, ``center_frequency`` and ``fbsp_order`` as
        additional parameters.

        The complex Shannon wavelet (``'shan'``) has ``bandwidth_frequency``
        and ``center_frequency`` as additional parameters.

        The complex Morlet wavelet (``'cmor'``) has ``bandwidth_frequency``
        and ``center_frequency`` as additional parameters.

        Examples
        --------
        >>> import pywt
        >>> import matplotlib.pyplot as plt
        >>> lb = -5
        >>> ub = 5
        >>> n = 1000
        >>> wavelet = pywt.ContinuousWavelet("gaus8")
        >>> wavelet.upper_bound = ub
        >>> wavelet.lower_bound = lb
        >>> [psi,xval] = wavelet.wavefun(length=n)
        >>> plt.plot(xval,psi) # doctest: +ELLIPSIS
        [<matplotlib.lines.Line2D object at ...>]
        >>> plt.title("Gaussian Wavelet of order 8") # doctest: +ELLIPSIS
        <matplotlib.text.Text object at ...>
        >>> plt.show() # doctest: +SKIP

        >>> import pywt
        >>> import matplotlib.pyplot as plt
        >>> lb = -5
        >>> ub = 5
        >>> n = 1000
        >>> wavelet = pywt.ContinuousWavelet("cgau4")
        >>> wavelet.upper_bound = ub
        >>> wavelet.lower_bound = lb
        >>> [psi,xval] = wavelet.wavefun(length=n)
        >>> plt.subplot(211) # doctest: +ELLIPSIS
        <matplotlib.axes._subplots.AxesSubplot object at ...>
        >>> plt.plot(xval,np.real(psi)) # doctest: +ELLIPSIS
        [<matplotlib.lines.Line2D object at ...>]
        >>> plt.title("Real part") # doctest: +ELLIPSIS
        <matplotlib.text.Text object at ...>
        >>> plt.subplot(212) # doctest: +ELLIPSIS
        <matplotlib.axes._subplots.AxesSubplot object at ...>
        >>> plt.plot(xval,np.imag(psi)) # doctest: +ELLIPSIS
        [<matplotlib.lines.Line2D object at ...>]
        >>> plt.title("Imaginary part") # doctest: +ELLIPSIS
        <matplotlib.text.Text object at ...>
        >>> plt.show() # doctest: +SKIP

        """
        cdef pywt_index_t output_length
        cdef psi_i, psi_r, psi
        cdef np.float64_t[::1] x64
        cdef np.float32_t[::1] x32
        cdef double p
        cdef bint use_float64

        p = (pow(2., <double>level))

        if self.w is not NULL:
            if length is None:
                output_length = <pywt_index_t>p
            else:
                output_length = <pywt_index_t>length
            # ``self.dt`` may be a scalar type, an np.dtype or a string such as
            # 'float64' (all accepted by __cinit__); normalize it so the
            # memoryview matching the generated array's dtype is used.
            use_float64 = np.dtype(self.dt) == np.float64
            if use_float64:
                x64 = np.linspace(self.w.lower_bound, self.w.upper_bound, output_length, dtype=self.dt)
            else:
                x32 = np.linspace(self.w.lower_bound, self.w.upper_bound, output_length, dtype=self.dt)
            if self.w.complex_cwt:
                if use_float64:
                    psi_r, psi_i = cwt_psi_single(x64, self, output_length)
                    return [np.asarray(psi_r, dtype=self.dt) + 1j * np.asarray(psi_i, dtype=self.dt),
                            np.asarray(x64, dtype=self.dt)]
                else:
                    psi_r, psi_i = cwt_psi_single(x32, self, output_length)
                    return [np.asarray(psi_r, dtype=self.dt) + 1j * np.asarray(psi_i, dtype=self.dt),
                            np.asarray(x32, dtype=self.dt)]
            else:
                if use_float64:
                    psi = cwt_psi_single(x64, self, output_length)
                    return [np.asarray(psi, dtype=self.dt),
                            np.asarray(x64, dtype=self.dt)]
                else:
                    psi = cwt_psi_single(x32, self, output_length)
                    return [np.asarray(psi, dtype=self.dt),
                            np.asarray(x32, dtype=self.dt)]

    def __str__(self):
        s = []
        for x in [
            u"ContinuousWavelet %s" % self.name,
            u"  Family name:    %s" % self.family_name,
            u"  Short name:     %s" % self.short_family_name,
            u"  Symmetry:       %s" % self.symmetry,
            u"  DWT:            False",
            u"  CWT:            True",
            u"  Complex CWT:    %s" % self.complex_cwt
            ]:
            s.append(x.rstrip())
        return u'\n'.join(s)

    def __repr__(self):
        repr = "{module}.{classname}(name='{name}')"
        return repr.format(module=type(self).__module__,
                           classname=type(self).__name__,
                           name=self.name)


# Length of the central part kept from an upcoef result after ``level``
# refinements: starts at 1 and applies keep = 2*keep + (filter_length - 2)
# per level. ``output_length`` is accepted but not used.
cdef pywt_index_t get_keep_length(pywt_index_t output_length,
                                  int level, pywt_index_t filter_length):
    cdef pywt_index_t lplus = filter_length - 2
    cdef pywt_index_t keep_length = 1
    cdef int i
    for i in range(level):
        keep_length = 2*keep_length+lplus
    return keep_length

# Grow output_length to at least keep_length + 2 (one leading zero sample plus
# a non-negative right extent).
cdef pywt_index_t fix_output_length(pywt_index_t output_length, pywt_index_t keep_length):
    if output_length-keep_length-2 < 0:
        output_length = keep_length+2
    return output_length

# Number of zero samples appended after the kept part (output minus the kept
# samples minus the single leading zero).
cdef pywt_index_t get_right_extent_length(pywt_index_t output_length, pywt_index_t keep_length):
    return output_length - keep_length - 1


def wavelet_from_object(wavelet):
    return c_wavelet_from_object(wavelet)


# Return ``wavelet`` unchanged if it already is a (Continuous)Wavelet,
# otherwise construct a discrete Wavelet from it (e.g. from a name).
cdef c_wavelet_from_object(wavelet):
    if isinstance(wavelet, (Wavelet, ContinuousWavelet)):
        return wavelet
    else:
        return Wavelet(wavelet)


cpdef np.dtype _check_dtype(data):
    """Check for cA/cD input what (if any) the dtype is.

    Returns the dtype the transform should run in: float32/float64 and
    complex64/complex128 are kept, half precision is promoted to float32,
    other complex types (e.g. complex256) are reduced to complex128, and
    everything else (integers, objects without ``dtype``, ...) maps to
    float64.
    """
    cdef np.dtype dt
    try:
        dt = data.dtype
    except AttributeError:
        # non-array input (e.g. a list) is computed in double precision
        return np.dtype('float64')
    if dt not in (np.float64, np.float32, np.complex64, np.complex128):
        if dt == np.half:
            # half-precision input converted to single precision
            dt = np.dtype('float32')
        elif np.issubdtype(dt, np.complexfloating):
            # extended-precision complex (e.g. complex256) is not supported;
            # run at reduced precision. Checked via issubdtype because
            # np.complex256 does not exist on every platform.
            dt = np.dtype('complex128')
        else:
            # integer input was always accepted; convert to float64
            dt = np.dtype('float64')
    return dt


# TODO: Can this be replaced by the take parameter of upcoef? Or vice-versa?
def keep(arr, keep_length):
    """Return the centred ``keep_length`` samples of ``arr``.

    Parameters
    ----------
    arr : sequence
        Any object supporting ``len()`` and slicing (e.g. a list or ndarray).
    keep_length : int
        Number of samples to retain.

    Returns
    -------
    kept : slice of ``arr`` or ``arr``
        ``arr[left:left + keep_length]`` with
        ``left = (len(arr) - keep_length) // 2`` when ``keep_length`` is
        smaller than ``len(arr)``; otherwise ``arr`` itself (not a copy).
    """
    length = len(arr)
    if keep_length < length:
        # Floor division means that, for an odd surplus, the extra trimmed
        # sample comes off the right-hand end.
        left_bound = (length - keep_length) // 2
        return arr[left_bound:left_bound + keep_length]
    return arr


# Some utility functions

cdef object float64_array_to_list(double* data, pywt_index_t n):
    """Copy the first ``n`` values of ``data`` into a new Python list.

    ``data`` must point to at least ``n`` readable doubles; no bounds
    checking is done here. A non-positive ``n`` yields an empty list.
    """
    cdef pywt_index_t i
    # Typing the result as ``list`` lets Cython call the C-level list append
    # directly, so caching the bound ``append`` method is unnecessary.
    cdef list ret = []
    for i in range(n):
        ret.append(data[i])
    return ret


cdef void copy_object_to_float64_array(source, double* dest) except *:
    """Write each element of the iterable ``source`` into ``dest``, in order.

    ``dest`` must have room for every element of ``source``; nothing here
    checks that. An element that cannot be coerced to a C ``double`` raises,
    and ``except *`` propagates that exception to the caller.
    """
    cdef pywt_index_t i
    cdef double x
    # With ``i`` C-typed, Cython compiles ``enumerate`` to a plain counter.
    for i, x in enumerate(source):
        dest[i] = x


cdef void copy_object_to_float32_array(source, float* dest) except *:
    """Write each element of the iterable ``source`` into ``dest``, in order.

    Single-precision counterpart of ``copy_object_to_float64_array``:
    ``dest`` must have room for every element of ``source`` (unchecked), and
    an element that cannot be coerced to a C ``float`` raises, which
    ``except *`` propagates to the caller.
    """
    cdef pywt_index_t i
    cdef float x
    for i, x in enumerate(source):
        dest[i] = x