1import functools
2from math import ceil
3import numbers
4
5import scipy.stats
6import numpy as np
7import pywt
8
9from ..util.dtype import img_as_float
10from .._shared import utils
11from .._shared.utils import _supported_float_type, warn
12from ._denoise_cy import _denoise_bilateral, _denoise_tv_bregman
13from .. import color
14from ..color.colorconv import ycbcr_from_rgb
15
16
17def _gaussian_weight(array, sigma_squared, *, dtype=float):
18    """Helping function. Define a Gaussian weighting from array and
19    sigma_square.
20
21    Parameters
22    ----------
23    array : ndarray
24        Input array.
25    sigma_squared : float
26        The squared standard deviation used in the filter.
27    dtype : data type object, optional (default : float)
28        The type and size of the data to be returned.
29
30    Returns
31    -------
32    gaussian : ndarray
33        The input array filtered by the Gaussian.
34    """
35    return np.exp(-0.5 * (array ** 2 / sigma_squared), dtype=dtype)
36
37
def _compute_color_lut(bins, sigma, max_value, *, dtype=float):
    """Build the lookup table of Gaussian weights for color distances.

    The table samples the Gaussian at ``bins`` evenly spaced color distances
    covering the half-open interval ``[0, max_value)``.

    Parameters
    ----------
    bins : int
        Number of entries in the table. A larger value results in improved
        accuracy.
    sigma : float
        Standard deviation for grayvalue/color distance (radiometric
        similarity). A larger value results in averaging of pixels with larger
        radiometric differences.
    max_value : float
        Upper end (exclusive) of the sampled distance range.
    dtype : data type object, optional (default : float)
        The type and size of the data to be returned.

    Returns
    -------
    color_lut : ndarray
        Lookup table for the color distance sigma, of length ``bins``.
    """
    sample_points = np.linspace(0, max_value, num=bins, endpoint=False)
    variance = sigma**2
    return _gaussian_weight(sample_points, variance, dtype=dtype)
66
67
def _compute_spatial_lut(win_size, sigma, *, dtype=float):
    """Build the lookup table of Gaussian weights for spatial offsets.

    Parameters
    ----------
    win_size : int
        Window size for filtering.
    sigma : float
        Standard deviation of the spatial Gaussian. A larger value results in
        averaging of pixels with larger spatial differences.
    dtype : data type object
        The type and size of the data to be returned.

    Returns
    -------
    spatial_lut : ndarray
        Flattened (row-major) table of Gaussian weights, one entry per
        (row offset, column offset) pair of the offset grid.

    Notes
    -----
    ``-win_size // 2`` floors the *negated* size, so the offset grid spans
    ``(-win_size) // 2`` up to and including ``win_size // 2``. That is
    ``win_size + 1`` points per axis, and for odd ``win_size`` the grid is
    not centred on zero (e.g. ``win_size=5`` gives offsets ``-3 .. 2``).
    NOTE(review): confirm this layout matches how the consumer of the table
    indexes it.
    """
    offsets = np.arange(-win_size // 2, win_size // 2 + 1)
    row_off, col_off = np.meshgrid(offsets, offsets, indexing='ij')
    # Euclidean distance of every grid position from the origin.
    radial = np.hypot(row_off, col_off)
    weights = _gaussian_weight(radial, sigma**2, dtype=dtype)
    return weights.ravel()
93
94
@utils.channel_as_last_axis()
@utils.deprecate_multichannel_kwarg(multichannel_position=7)
def denoise_bilateral(image, win_size=None, sigma_color=None, sigma_spatial=1,
                      bins=10000, mode='constant', cval=0, multichannel=False,
                      *, channel_axis=None):
    """Denoise image using bilateral filter.

    Parameters
    ----------
    image : ndarray, shape (M, N[, 3])
        Input image, 2D grayscale or RGB.
    win_size : int
        Window size for filtering.
        If win_size is not specified, it is calculated as
        ``max(5, 2 * ceil(3 * sigma_spatial) + 1)``.
    sigma_color : float
        Standard deviation for grayvalue/color distance (radiometric
        similarity). A larger value results in averaging of pixels with larger
        radiometric differences. If ``None``, the standard deviation of
        ``image`` (after conversion to float) will be used.
    sigma_spatial : float
        Standard deviation for range distance. A larger value results in
        averaging of pixels with larger spatial differences.
    bins : int
        Number of discrete values for Gaussian weights of color filtering.
        A larger value results in improved accuracy.
    mode : {'constant', 'edge', 'symmetric', 'reflect', 'wrap'}
        How to handle values outside the image borders. See
        `numpy.pad` for detail.
    cval : int or float
        Used in conjunction with mode 'constant', the value outside
        the image boundaries.
    multichannel : bool
        Whether the last axis of the image is to be interpreted as multiple
        channels or another spatial dimension. This argument is deprecated:
        specify `channel_axis` instead.
    channel_axis : int or None, optional
        If None, the image is assumed to be a grayscale (single channel) image.
        Otherwise, this parameter indicates which axis of the array corresponds
        to channels.

        .. versionadded:: 0.19
           ``channel_axis`` was added in 0.19.

    Returns
    -------
    denoised : ndarray
        Denoised image. If ``image`` is constant (its minimum equals its
        maximum) the input is returned unchanged.

    Raises
    ------
    ValueError
        If the dimensionality of ``image`` is inconsistent with
        ``channel_axis``.

    Notes
    -----
    This is an edge-preserving, denoising filter. It averages pixels based on
    their spatial closeness and radiometric similarity [1]_.

    Spatial closeness is measured by the Gaussian function of the Euclidean
    distance between two pixels and a certain standard deviation
    (`sigma_spatial`).

    Radiometric similarity is measured by the Gaussian function of the
    Euclidean distance between two color values and a certain standard
    deviation (`sigma_color`).

    Note that, if the image is of any `int` dtype, ``image`` will be
    converted using the `img_as_float` function and thus the standard
    deviation (`sigma_color`) will be in range ``[0, 1]``.

    For more information on scikit-image's data type conversions and how
    images are rescaled in these conversions,
    see: https://scikit-image.org/docs/stable/user_guide/data_types.html.

    References
    ----------
    .. [1] C. Tomasi and R. Manduchi. "Bilateral Filtering for Gray and Color
           Images." IEEE International Conference on Computer Vision (1998)
           839-846. :DOI:`10.1109/ICCV.1998.710815`

    Examples
    --------
    >>> from skimage import data, img_as_float
    >>> astro = img_as_float(data.astronaut())
    >>> astro = astro[220:300, 220:320]
    >>> rng = np.random.default_rng()
    >>> noisy = astro + 0.6 * astro.std() * rng.random(astro.shape)
    >>> noisy = np.clip(noisy, 0, 1)
    >>> denoised = denoise_bilateral(noisy, sigma_color=0.05, sigma_spatial=15,
    ...                              channel_axis=-1)
    """
    if channel_axis is not None:
        if image.ndim != 3:
            if image.ndim == 2:
                raise ValueError("Use ``multichannel=False`` for 2D grayscale "
                                 "images. The last axis of the input image "
                                 "must be multiple color channels not another "
                                 "spatial dimension.")
            else:
                raise ValueError(f'Bilateral filter is only implemented for '
                                 f'2D grayscale images (image.ndim == 2) and '
                                 f'2D multichannel (image.ndim == 3) images, '
                                 f'but the input image has {image.ndim} dimensions.')
        elif image.shape[2] not in (3, 4):
            # Unusual channel counts are only warned about, not rejected.
            if image.shape[2] > 4:
                msg = f'The last axis of the input image is ' \
                      f'interpreted as channels. Input image with '\
                      f'shape {image.shape} has {image.shape[2]} channels '\
                      f'in last axis. ``denoise_bilateral``is implemented ' \
                      f'for 2D grayscale and color images only.'
                warn(msg)
            else:
                msg = f'Input image must be grayscale, RGB, or RGBA; ' \
                      f'but has shape {image.shape}.'
                warn(msg)
    else:
        if image.ndim > 2:
            raise ValueError(f'Bilateral filter is not implemented for '
                             f'grayscale images of 3 or more dimensions, '
                             f'but input image has {image.shape} shape. Use '
                             f'``channel_axis=-1`` for 2D RGB images.')

    if win_size is None:
        win_size = max(5, 2 * int(ceil(3 * sigma_spatial)) + 1)

    # A constant image has nothing to smooth: hand it back untouched (and in
    # its original dtype). Past this point the data has a non-zero range.
    if image.min() == image.max():
        return image

    image = np.atleast_3d(img_as_float(image))
    image = np.ascontiguousarray(image)

    # Measure the intensity range on the *converted* image. `img_as_float`
    # rescales integer input, so taking min/max beforehand would put the
    # lookup-table range and the negative-offset shift below on a different
    # scale from the data actually filtered (and `max_value -= min_value`
    # could overflow for narrow signed integer dtypes).
    min_value = image.min()
    max_value = image.max()

    # Falls back to the image's standard deviation when no (or a zero)
    # sigma_color is given.
    sigma_color = sigma_color or image.std()

    color_lut = _compute_color_lut(bins, sigma_color, max_value,
                                   dtype=image.dtype)

    range_lut = _compute_spatial_lut(win_size, sigma_spatial,
                                     dtype=image.dtype)

    out = np.empty(image.shape, dtype=image.dtype)

    dims = image.shape[2]

    # There are a number of arrays needed in the Cython function.
    # It's easier to allocate them outside of Cython so that all
    # arrays are in the same type, then just copy the empty array
    # where needed within Cython.
    empty_dims = np.empty(dims, dtype=image.dtype)

    # Shift images with negative values so the filtered data is non-negative;
    # the shift is undone on the result below. Together with the constant-image
    # check above this guarantees a strictly positive `max_value`.
    if min_value < 0:
        image = image - min_value
        max_value -= min_value
    _denoise_bilateral(image, max_value, win_size, sigma_color, sigma_spatial,
                       bins, mode, cval, color_lut, range_lut, empty_dims, out)
    # need to drop the added channels axis for grayscale images
    out = np.squeeze(out)
    if min_value < 0:
        out += min_value
    return out
258
259
@utils.channel_as_last_axis()
@utils.deprecate_multichannel_kwarg()
@utils.deprecate_kwarg({'max_iter': 'max_num_iter'}, removed_version="1.0",
                       deprecated_version="0.19")
def denoise_tv_bregman(image, weight=5.0, max_num_iter=100, eps=1e-3,
                       isotropic=True, *, channel_axis=None,
                       multichannel=False):
    """Perform total-variation denoising using split-Bregman optimization.

    Total-variation denoising (also known as total-variation regularization)
    tries to find an image with less total-variation under the constraint
    of being similar to the input image, which is controlled by the
    regularization parameter ([1]_, [2]_, [3]_, [4]_).

    Parameters
    ----------
    image : ndarray
        Input data to be denoised (converted using `img_as_float`).
    weight : float
        Denoising weight. The smaller the `weight`, the more denoising (at
        the expense of less similarity to the `input`). The regularization
        parameter `lambda` is chosen as `2 * weight`.
    max_num_iter : int, optional
        Maximal number of iterations used for the optimization.
    eps : float, optional
        Relative difference of the value of the cost function that determines
        the stop criterion. The algorithm stops when::

            SUM((u(n) - u(n-1))**2) < eps

    isotropic : boolean, optional
        Switch between isotropic and anisotropic TV denoising.
    channel_axis : int or None, optional
        If None, the image is assumed to be a grayscale (single channel) image.
        Otherwise, this parameter indicates which axis of the array corresponds
        to channels.

        .. versionadded:: 0.19
           ``channel_axis`` was added in 0.19.
    multichannel : bool, optional
        Apply total-variation denoising separately for each channel. This
        option should be true for color images, otherwise the denoising is
        also applied in the channels dimension. This argument is deprecated:
        specify `channel_axis` instead.

    Returns
    -------
    u : ndarray
        Denoised image.

    References
    ----------
    .. [1] https://en.wikipedia.org/wiki/Total_variation_denoising
    .. [2] Tom Goldstein and Stanley Osher, "The Split Bregman Method For L1
           Regularized Problems",
           ftp://ftp.math.ucla.edu/pub/camreport/cam08-29.pdf
    .. [3] Pascal Getreuer, "Rudin–Osher–Fatemi Total Variation Denoising
           using Split Bregman" in Image Processing On Line on 2012–05–19,
           https://www.ipol.im/pub/art/2012/g-tvd/article_lr.pdf
    .. [4] https://web.math.ucsb.edu/~cgarcia/UGProjects/BregmanAlgorithms_JacquelineBush.pdf

    """
    image = np.atleast_3d(img_as_float(image))

    n_rows, n_cols, n_chan = image.shape[:3]

    # The work buffer carries a one-pixel border around the two leading axes;
    # the border is cropped off again before returning.
    padded_shape = (n_rows + 2, n_cols + 2, n_chan)
    out = np.zeros(padded_shape, image.dtype)

    if channel_axis is None:
        # Single pass over the whole (contiguous) array.
        image = np.ascontiguousarray(image)
        _denoise_tv_bregman(image, image.dtype.type(weight), max_num_iter,
                            eps, isotropic, out)
        return np.squeeze(out[1:-1, 1:-1])

    # One pass per channel, reusing a single-channel scratch buffer.
    plane_out = np.zeros(padded_shape[:2] + (1,), dtype=out.dtype)
    for c in range(image.shape[-1]):
        # The algorithm expects 3 dimensions to always be present; slicing
        # with ``c:c+1`` keeps the channel axis in place.
        plane_in = np.ascontiguousarray(image[..., c:c+1])

        _denoise_tv_bregman(plane_in, image.dtype.type(weight),
                            max_num_iter, eps, isotropic, plane_out)

        out[..., c] = plane_out[..., 0]

    return np.squeeze(out[1:-1, 1:-1])
352
353
354def _denoise_tv_chambolle_nd(image, weight=0.1, eps=2.e-4, n_iter_max=200):
355    """Perform total-variation denoising on n-dimensional images.
356
357    Parameters
358    ----------
359    image : ndarray
360        n-D input data to be denoised.
361    weight : float, optional
362        Denoising weight. The greater `weight`, the more denoising (at
363        the expense of fidelity to `input`).
364    eps : float, optional
365        Relative difference of the value of the cost function that determines
366        the stop criterion. The algorithm stops when:
367
368            (E_(n-1) - E_n) < eps * E_0
369
370    n_iter_max : int, optional
371        Maximal number of iterations used for the optimization.
372
373    Returns
374    -------
375    out : ndarray
376        Denoised array of floats.
377
378    Notes
379    -----
380    Rudin, Osher and Fatemi algorithm.
381    """
382
383    ndim = image.ndim
384    p = np.zeros((image.ndim, ) + image.shape, dtype=image.dtype)
385    g = np.zeros_like(p)
386    d = np.zeros_like(image)
387    i = 0
388    while i < n_iter_max:
389        if i > 0:
390            # d will be the (negative) divergence of p
391            d = -p.sum(0)
392            slices_d = [slice(None), ] * ndim
393            slices_p = [slice(None), ] * (ndim + 1)
394            for ax in range(ndim):
395                slices_d[ax] = slice(1, None)
396                slices_p[ax+1] = slice(0, -1)
397                slices_p[0] = ax
398                d[tuple(slices_d)] += p[tuple(slices_p)]
399                slices_d[ax] = slice(None)
400                slices_p[ax+1] = slice(None)
401            out = image + d
402        else:
403            out = image
404        E = (d ** 2).sum()
405
406        # g stores the gradients of out along each axis
407        # e.g. g[0] is the first order finite difference along axis 0
408        slices_g = [slice(None), ] * (ndim + 1)
409        for ax in range(ndim):
410            slices_g[ax+1] = slice(0, -1)
411            slices_g[0] = ax
412            g[tuple(slices_g)] = np.diff(out, axis=ax)
413            slices_g[ax+1] = slice(None)
414
415        norm = np.sqrt((g ** 2).sum(axis=0))[np.newaxis, ...]
416        E += weight * norm.sum()
417        tau = 1. / (2.*ndim)
418        norm *= tau / weight
419        norm += 1.
420        p -= tau * g
421        p /= norm
422        E /= float(image.size)
423        if i == 0:
424            E_init = E
425            E_previous = E
426        else:
427            if np.abs(E_previous - E) < eps * E_init:
428                break
429            else:
430                E_previous = E
431        i += 1
432    return out
433
434
@utils.deprecate_multichannel_kwarg(multichannel_position=4)
def denoise_tv_chambolle(image, weight=0.1, eps=2.e-4, n_iter_max=200,
                         multichannel=False, *, channel_axis=None):
    """Perform total-variation denoising on n-dimensional images.

    Parameters
    ----------
    image : ndarray of ints, uints or floats
        Input data to be denoised. `image` can be of any numeric type,
        but it is cast into an ndarray of floats for the computation
        of the denoised image.
    weight : float, optional
        Denoising weight. The greater `weight`, the more denoising (at
        the expense of fidelity to `input`).
    eps : float, optional
        Relative difference of the value of the cost function that
        determines the stop criterion. The algorithm stops when:

            (E_(n-1) - E_n) < eps * E_0

    n_iter_max : int, optional
        Maximal number of iterations used for the optimization.
    multichannel : bool, optional
        Apply total-variation denoising separately for each channel. This
        option should be true for color images, otherwise the denoising is
        also applied in the channels dimension. This argument is deprecated:
        specify `channel_axis` instead.
    channel_axis : int or None, optional
        If None, the image is assumed to be a grayscale (single channel) image.
        Otherwise, this parameter indicates which axis of the array corresponds
        to channels.

        .. versionadded:: 0.19
           ``channel_axis`` was added in 0.19.

    Returns
    -------
    out : ndarray
        Denoised image.

    Notes
    -----
    Make sure to set the `channel_axis` parameter appropriately for color
    images.

    The principle of total variation denoising is explained in
    https://en.wikipedia.org/wiki/Total_variation_denoising

    The principle of total variation denoising is to minimize the
    total variation of the image, which can be roughly described as
    the integral of the norm of the image gradient. Total variation
    denoising tends to produce "cartoon-like" images, that is,
    piecewise-constant images.

    This code is an implementation of the algorithm of Rudin, Fatemi and Osher
    that was proposed by Chambolle in [1]_.

    References
    ----------
    .. [1] A. Chambolle, An algorithm for total variation minimization and
           applications, Journal of Mathematical Imaging and Vision,
           Springer, 2004, 20, 89-97.

    Examples
    --------
    2D example on astronaut image:

    >>> from skimage import color, data
    >>> img = color.rgb2gray(data.astronaut())[:50, :50]
    >>> rng = np.random.default_rng()
    >>> img += 0.5 * img.std() * rng.standard_normal(img.shape)
    >>> denoised_img = denoise_tv_chambolle(img, weight=60)

    3D example on synthetic data:

    >>> x, y, z = np.ogrid[0:20, 0:20, 0:20]
    >>> mask = (x - 22)**2 + (y - 20)**2 + (z - 17)**2 < 8**2
    >>> mask = mask.astype(float)
    >>> rng = np.random.default_rng()
    >>> mask += 0.2 * rng.standard_normal(mask.shape)
    >>> res = denoise_tv_chambolle(mask, weight=100)

    """
    # Anything that is not already floating point goes through img_as_float.
    if image.dtype.kind != 'f':
        image = img_as_float(image)

    # enforce float16->float32 and float128->float64
    working_dtype = _supported_float_type(image.dtype)
    image = image.astype(working_dtype, copy=False)

    if channel_axis is None:
        # No channel axis: denoise across all axes at once.
        return _denoise_tv_chambolle_nd(image, weight, eps, n_iter_max)

    # Normalize a negative axis index, then denoise channel by channel.
    channel_axis = channel_axis % image.ndim
    select = functools.partial(utils.slice_at_axis, axis=channel_axis)
    out = np.zeros_like(image)
    n_channels = image.shape[channel_axis]
    for c in range(n_channels):
        out[select(c)] = _denoise_tv_chambolle_nd(image[select(c)], weight,
                                                  eps, n_iter_max)
    return out
536
537
538def _bayes_thresh(details, var):
539    """BayesShrink threshold for a zero-mean details coeff array."""
540    # Equivalent to:  dvar = np.var(details) for 0-mean details array
541    dvar = np.mean(details*details)
542    eps = np.finfo(details.dtype).eps
543    thresh = var / np.sqrt(max(dvar - var, eps))
544    return thresh
545
546
547def _universal_thresh(img, sigma):
548    """ Universal threshold used by the VisuShrink method """
549    return sigma*np.sqrt(2*np.log(img.size))
550
551
552def _sigma_est_dwt(detail_coeffs, distribution='Gaussian'):
553    """Calculate the robust median estimator of the noise standard deviation.
554
555    Parameters
556    ----------
557    detail_coeffs : ndarray
558        The detail coefficients corresponding to the discrete wavelet
559        transform of an image.
560    distribution : str
561        The underlying noise distribution.
562
563    Returns
564    -------
565    sigma : float
566        The estimated noise standard deviation (see section 4.2 of [1]_).
567
568    References
569    ----------
570    .. [1] D. L. Donoho and I. M. Johnstone. "Ideal spatial adaptation
571       by wavelet shrinkage." Biometrika 81.3 (1994): 425-455.
572       :DOI:`10.1093/biomet/81.3.425`
573    """
574    # Consider regions with detail coefficients exactly zero to be masked out
575    detail_coeffs = detail_coeffs[np.nonzero(detail_coeffs)]
576
577    if distribution.lower() == 'gaussian':
578        # 75th quantile of the underlying, symmetric noise distribution
579        denom = scipy.stats.norm.ppf(0.75)
580        sigma = np.median(np.abs(detail_coeffs)) / denom
581    else:
582        raise ValueError("Only Gaussian noise estimation is currently "
583                         "supported")
584    return sigma
585
586
def _wavelet_threshold(image, wavelet, method=None, threshold=None,
                       sigma=None, mode='soft', wavelet_levels=None):
    """Perform wavelet thresholding.

    Parameters
    ----------
    image : ndarray (2d or 3d) of ints, uints or floats
        Input data to be denoised. `image` can be of any numeric type,
        but it is cast into an ndarray of floats for the computation
        of the denoised image.
    wavelet : string
        The type of wavelet to perform. Can be any of the options
        pywt.wavelist outputs. For example, this may be any of ``{db1, db2,
        db3, db4, haar}``.
    method : {'BayesShrink', 'VisuShrink'}, optional
        Thresholding method to be used. The currently supported methods are
        "BayesShrink" [1]_ and "VisuShrink" [2]_. If it is set to None, a
        user-specified ``threshold`` must be supplied instead.
    threshold : float, optional
        The thresholding value to apply during wavelet coefficient
        thresholding. The default value (None) uses the selected ``method`` to
        estimate appropriate threshold(s) for noise removal. If both
        ``method`` and ``threshold`` are given, a warning is emitted and
        ``threshold`` is ignored in favour of ``method``.
    sigma : float, optional
        The standard deviation of the noise. The noise is estimated when sigma
        is None (the default) by the method in [2]_.
    mode : {'soft', 'hard'}, optional
        An optional argument to choose the type of denoising performed. It
        noted that choosing soft thresholding given additive noise finds the
        best approximation of the original image.
    wavelet_levels : int or None, optional
        The number of wavelet decomposition levels to use.  The default is
        three less than the maximum number of possible decomposition levels,
        but never fewer than one.

    Returns
    -------
    out : ndarray
        Denoised image.

    Raises
    ------
    ValueError
        If neither ``method`` nor ``threshold`` is given, or if ``method`` is
        not one of the supported names.

    References
    ----------
    .. [1] Chang, S. Grace, Bin Yu, and Martin Vetterli. "Adaptive wavelet
           thresholding for image denoising and compression." Image Processing,
           IEEE Transactions on 9.9 (2000): 1532-1546.
           :DOI:`10.1109/83.862633`
    .. [2] D. L. Donoho and I. M. Johnstone. "Ideal spatial adaptation
           by wavelet shrinkage." Biometrika 81.3 (1994): 425-455.
           :DOI:`10.1093/biomet/81.3.425`
    """
    wavelet = pywt.Wavelet(wavelet)
    if not wavelet.orthogonal:
        warn(f'Wavelet thresholding was designed for '
             f'use with orthogonal wavelets. For nonorthogonal '
             f'wavelets such as {wavelet.name},results are '
             f'likely to be suboptimal.')

    # original_extent is used to workaround PyWavelets issue #80
    # odd-sized input results in an image with 1 extra sample after waverecn
    original_extent = tuple(slice(s) for s in image.shape)

    # Determine the number of wavelet decomposition levels
    if wavelet_levels is None:
        # Determine the maximum number of possible levels for image
        wavelet_levels = pywt.dwtn_max_level(image.shape, wavelet)

        # Skip the three coarsest wavelet scales, keeping at least one level.
        wavelet_levels = max(wavelet_levels - 3, 1)

    coeffs = pywt.wavedecn(image, wavelet=wavelet, level=wavelet_levels)
    # Detail coefficients at each decomposition level
    dcoeffs = coeffs[1:]

    if sigma is None:
        # Estimate the noise via the method in [2]_
        detail_coeffs = dcoeffs[-1]['d' * image.ndim]
        sigma = _sigma_est_dwt(detail_coeffs, distribution='Gaussian')

    if method is not None and threshold is not None:
        warn(f'Thresholding method {method} selected. The '
             f'user-specified threshold will be ignored.')
        # Actually discard the threshold so the behavior matches the warning
        # and the documented precedence of ``method``; without this the
        # user-specified threshold was silently used and ``method`` ignored.
        threshold = None

    if threshold is None:
        var = sigma**2
        if method is None:
            raise ValueError(
                "If method is None, a threshold must be provided.")
        elif method == "BayesShrink":
            # The BayesShrink thresholds from [1]_ in docstring
            threshold = [{key: _bayes_thresh(level[key], var) for key in level}
                         for level in dcoeffs]
        elif method == "VisuShrink":
            # The VisuShrink thresholds from [2]_ in docstring
            threshold = _universal_thresh(image, sigma)
        else:
            raise ValueError(f'Unrecognized method: {method}')

    if np.isscalar(threshold):
        # A single threshold for all coefficient arrays
        denoised_detail = [{key: pywt.threshold(level[key],
                                                value=threshold,
                                                mode=mode) for key in level}
                           for level in dcoeffs]
    else:
        # Dict of unique threshold coefficients for each detail coeff. array
        denoised_detail = [{key: pywt.threshold(level[key],
                                                value=thresh[key],
                                                mode=mode) for key in level}
                           for thresh, level in zip(threshold, dcoeffs)]
    # The approximation coefficients (coeffs[0]) are passed through untouched.
    denoised_coeffs = [coeffs[0]] + denoised_detail
    return pywt.waverecn(denoised_coeffs, wavelet)[original_extent]
697
698
699def _scale_sigma_and_image_consistently(image, sigma, multichannel,
700                                        rescale_sigma):
701    """If the ``image`` is rescaled, also rescale ``sigma`` consistently.
702
703    Images that are not floating point will be rescaled via ``img_as_float``.
704    Half-precision images will be promoted to single precision.
705    """
706    if multichannel:
707        if isinstance(sigma, numbers.Number) or sigma is None:
708            sigma = [sigma] * image.shape[-1]
709        elif len(sigma) != image.shape[-1]:
710            raise ValueError(
711                "When multichannel is True, sigma must be a scalar or have "
712                "length equal to the number of channels")
713    if image.dtype.kind != 'f':
714        if rescale_sigma:
715            range_pre = image.max() - image.min()
716        image = img_as_float(image)
717        if rescale_sigma:
718            range_post = image.max() - image.min()
719            # apply the same magnitude scaling to sigma
720            scale_factor = range_post / range_pre
721            if multichannel:
722                sigma = [s * scale_factor if s is not None else s
723                         for s in sigma]
724            elif sigma is not None:
725                sigma *= scale_factor
726    elif image.dtype == np.float16:
727        image = image.astype(np.float32)
728    return image, sigma
729
730
731def _rescale_sigma_rgb2ycbcr(sigmas):
732    """Convert user-provided noise standard deviations to YCbCr space.
733
734    Notes
735    -----
736    If R, G, B are linearly independent random variables and a1, a2, a3 are
737    scalars, then random variable C:
738        C = a1 * R + a2 * G + a3 * B
739    has variance, var_C, given by:
740        var_C = a1**2 * var_R + a2**2 * var_G + a3**2 * var_B
741    """
742    if sigmas[0] is None:
743        return sigmas
744    sigmas = np.asarray(sigmas)
745    rgv_variances = sigmas * sigmas
746    for i in range(3):
747        scalars = ycbcr_from_rgb[i, :]
748        var_channel = np.sum(scalars * scalars * rgv_variances)
749        sigmas[i] = np.sqrt(var_channel)
750    return sigmas
751
752
@utils.channel_as_last_axis()
@utils.deprecate_multichannel_kwarg(multichannel_position=5)
def denoise_wavelet(image, sigma=None, wavelet='db1', mode='soft',
                    wavelet_levels=None, multichannel=False,
                    convert2ycbcr=False, method='BayesShrink',
                    rescale_sigma=True, *, channel_axis=None):
    """Perform wavelet denoising on an image.

    Parameters
    ----------
    image : ndarray ([M[, N[, ...P]][, C]) of ints, uints or floats
        Input data to be denoised. `image` can be of any numeric type,
        but it is cast into an ndarray of floats for the computation
        of the denoised image.
    sigma : float or list, optional
        The noise standard deviation used when computing the wavelet detail
        coefficient threshold(s). When None (default), the noise standard
        deviation is estimated via the method in [2]_.
    wavelet : string, optional
        The type of wavelet to use. It can be any of the options that
        ``pywt.wavelist`` outputs, for example ``'db2'``, ``'haar'`` or
        ``'sym9'``. The default is ``'db1'``.
    mode : {'soft', 'hard'}, optional
        The type of thresholding performed. It has been noted that, given
        additive noise, soft thresholding finds the best approximation of
        the original image.
    wavelet_levels : int or None, optional
        The number of wavelet decomposition levels to use.  The default is
        three less than the maximum number of possible decomposition levels.
    multichannel : bool, optional
        Apply wavelet denoising separately for each channel (where channels
        correspond to the final axis of the array). This argument is
        deprecated: specify `channel_axis` instead.
    convert2ycbcr : bool, optional
        If True, do the wavelet denoising in the YCbCr colorspace instead of
        the RGB color space. This requires a multichannel image (i.e.
        `channel_axis` must not be None) and typically results in better
        performance for RGB images.
    method : {'BayesShrink', 'VisuShrink'}, optional
        Thresholding method to be used. The currently supported methods are
        "BayesShrink" [1]_ and "VisuShrink" [2]_. Defaults to "BayesShrink".
    rescale_sigma : bool, optional
        If False, no rescaling of the user-provided ``sigma`` will be
        performed. The default of ``True`` rescales sigma appropriately if the
        image is rescaled internally.

        .. versionadded:: 0.16
           ``rescale_sigma`` was introduced in 0.16
    channel_axis : int or None, optional
        If None, the image is assumed to be a grayscale (single channel) image.
        Otherwise, this parameter indicates which axis of the array corresponds
        to channels.

        .. versionadded:: 0.19
           ``channel_axis`` was added in 0.19.

    Returns
    -------
    out : ndarray
        Denoised image.

    Raises
    ------
    ValueError
        If `method` is not one of the supported thresholding methods, or if
        `convert2ycbcr` is requested for a single channel image.

    Notes
    -----
    The wavelet domain is a sparse representation of the image, and can be
    thought of similarly to the frequency domain of the Fourier transform.
    Sparse representations have most values zero or near-zero and truly random
    noise is (usually) represented by many small values in the wavelet domain.
    Setting all values below some threshold to 0 reduces the noise in the
    image, but larger thresholds also decrease the detail present in the image.

    For multichannel input, wavelet denoising is performed on each channel
    separately.

    .. versionchanged:: 0.16
       For floating point inputs, the original input range is maintained and
       there is no clipping applied to the output. Other input types will be
       converted to a floating point value in the range [-1, 1] or [0, 1]
       depending on the input image range. Unless ``rescale_sigma = False``,
       any internal rescaling applied to the ``image`` will also be applied
       to ``sigma`` to maintain the same relative amplitude.

    Many wavelet coefficient thresholding approaches have been proposed. By
    default, ``denoise_wavelet`` applies BayesShrink, which is an adaptive
    thresholding method that computes separate thresholds for each wavelet
    sub-band as described in [1]_.

    If ``method == "VisuShrink"``, a single "universal threshold" is applied to
    all wavelet detail coefficients as described in [2]_. This threshold
    is designed to remove all Gaussian noise at a given ``sigma`` with high
    probability, but tends to produce images that appear overly smooth.

    Although any of the wavelets from ``PyWavelets`` can be selected, the
    thresholding methods assume an orthogonal wavelet transform and may not
    choose the threshold appropriately for biorthogonal wavelets. Orthogonal
    wavelets are desirable because white noise in the input remains white noise
    in the subbands. Biorthogonal wavelets lead to colored noise in the
    subbands. Additionally, the orthogonal wavelets in PyWavelets are
    orthonormal so that noise variance in the subbands remains identical to the
    noise variance of the input. Example orthogonal wavelets are the Daubechies
    (e.g. 'db2') or symmlet (e.g. 'sym2') families.

    References
    ----------
    .. [1] Chang, S. Grace, Bin Yu, and Martin Vetterli. "Adaptive wavelet
           thresholding for image denoising and compression." Image Processing,
           IEEE Transactions on 9.9 (2000): 1532-1546.
           :DOI:`10.1109/83.862633`
    .. [2] D. L. Donoho and I. M. Johnstone. "Ideal spatial adaptation
           by wavelet shrinkage." Biometrika 81.3 (1994): 425-455.
           :DOI:`10.1093/biomet/81.3.425`

    Examples
    --------
    >>> from skimage import color, data
    >>> img = img_as_float(data.astronaut())
    >>> img = color.rgb2gray(img)
    >>> rng = np.random.default_rng()
    >>> img += 0.1 * rng.standard_normal(img.shape)
    >>> img = np.clip(img, 0, 1)
    >>> denoised_img = denoise_wavelet(img, sigma=0.1, rescale_sigma=True)

    """
    # The deprecated ``multichannel`` value is superseded by ``channel_axis``.
    multichannel = channel_axis is not None
    if method not in ("BayesShrink", "VisuShrink"):
        raise ValueError(f'Invalid method: {method}. The currently supported '
                         f'methods are "BayesShrink" and "VisuShrink".')

    # Only non-float inputs are rescaled internally, so only their output
    # gets clipped back to the rescaled range.
    clip_output = image.dtype.kind != 'f'

    if convert2ycbcr and not multichannel:
        raise ValueError("convert2ycbcr requires multichannel == True")

    image, sigma = _scale_sigma_and_image_consistently(
        image, sigma, multichannel, rescale_sigma
    )

    # Settings shared by every thresholding call below.
    shared_kwargs = dict(wavelet=wavelet, method=method, mode=mode,
                         wavelet_levels=wavelet_levels)

    if not multichannel:
        out = _wavelet_threshold(image, sigma=sigma, **shared_kwargs)
    elif not convert2ycbcr:
        # Denoise every channel independently in the input color space.
        out = np.empty_like(image)
        for c in range(image.shape[-1]):
            out[..., c] = _wavelet_threshold(image[..., c], sigma=sigma[c],
                                             **shared_kwargs)
    else:
        out = color.rgb2ycbcr(image)
        # User-supplied sigmas have to follow the image into YCbCr space.
        if rescale_sigma:
            sigma = _rescale_sigma_rgb2ycbcr(sigma)
        for idx in range(3):
            plane = out[..., idx]  # view: writes below land in ``out``
            lowest, highest = plane.min(), plane.max()
            span = highest - lowest
            if span == 0:
                # A constant channel cannot be normalized; leave it as is.
                continue
            # Bring this channel into [0, 1] before denoising it ...
            normalized = (plane - lowest) / span
            sigma_channel = sigma[idx]
            if sigma_channel is not None:
                sigma_channel /= span
            plane[...] = denoise_wavelet(normalized,
                                         sigma=sigma_channel,
                                         rescale_sigma=rescale_sigma,
                                         **shared_kwargs)
            # ... and restore its original range afterwards.
            plane *= span
            plane += lowest
        out = color.ycbcr2rgb(out)

    if clip_output:
        lower_bound = -1 if image.min() < 0 else 0
        out = np.clip(out, lower_bound, 1, out=out)
    return out
934
935
@utils.deprecate_multichannel_kwarg(multichannel_position=2)
def estimate_sigma(image, average_sigmas=False, multichannel=False, *,
                   channel_axis=None):
    """
    Robust wavelet-based estimator of the (Gaussian) noise standard deviation.

    Parameters
    ----------
    image : ndarray
        Image for which to estimate the noise standard deviation.
    average_sigmas : bool, optional
        If true, average the channel estimates of `sigma`.  Otherwise return
        a list of sigmas corresponding to each channel.
    multichannel : bool
        Estimate sigma separately for each channel. This argument is
        deprecated: specify `channel_axis` instead.
    channel_axis : int or None, optional
        If None, the image is assumed to be a grayscale (single channel) image.
        Otherwise, this parameter indicates which axis of the array corresponds
        to channels.

        .. versionadded:: 0.19
           ``channel_axis`` was added in 0.19.

    Returns
    -------
    sigma : float or list
        Estimated noise standard deviation(s).  If `channel_axis` is not None
        and `average_sigmas` is False, a separate noise estimate for each
        channel is returned.  Otherwise, a single value is returned (the
        average of the individual channel estimates for multichannel input).

    Notes
    -----
    This function assumes the noise follows a Gaussian distribution. The
    estimation algorithm is based on the median absolute deviation of the
    wavelet detail coefficients as described in section 4.2 of [1]_.

    A warning is emitted when `channel_axis` is None but the last axis of
    `image` has a size of 4 or less, since such an array may be a color
    image.

    References
    ----------
    .. [1] D. L. Donoho and I. M. Johnstone. "Ideal spatial adaptation
       by wavelet shrinkage." Biometrika 81.3 (1994): 425-455.
       :DOI:`10.1093/biomet/81.3.425`

    Examples
    --------
    >>> import skimage.data
    >>> from skimage import img_as_float
    >>> img = img_as_float(skimage.data.camera())
    >>> sigma = 0.1
    >>> rng = np.random.default_rng()
    >>> img = img + sigma * rng.standard_normal(img.shape)
    >>> sigma_hat = estimate_sigma(img, channel_axis=None)
    """
    if channel_axis is None:
        last_axis_size = image.shape[-1]
        if last_axis_size <= 4:
            # A very short trailing axis hints at an undeclared color image.
            warn(
                f'image is size {last_axis_size} on the last axis, '
                f'but channel_axis is None. If this is a color image, '
                f'please set channel_axis=-1 for proper noise estimation.'
            )
        # The estimate uses the sub-band that is a detail band along
        # every axis of a single-level 'db2' decomposition.
        subbands = pywt.dwtn(image, wavelet='db2')
        diagonal_details = subbands['d' * image.ndim]
        return _sigma_est_dwt(diagonal_details, distribution='Gaussian')

    # Multichannel input: estimate each channel on its own.
    axis = channel_axis % image.ndim  # normalize negative axes
    per_channel = []
    for c in range(image.shape[axis]):
        single_channel = image[utils.slice_at_axis(c, axis=axis)]
        per_channel.append(estimate_sigma(single_channel, channel_axis=None))
    if average_sigmas:
        return np.mean(per_channel)
    return per_channel
1007