1import inspect
2import functools
3import sys
4import warnings
5from collections.abc import Iterable
6
7import numpy as np
8import scipy
9from numpy.lib import NumpyVersion
10
11from ._warnings import all_warnings, warn
12
13
# Names exported by ``from <this module> import *``.
__all__ = [
    'deprecated',
    'get_bound_method_class',
    'all_warnings',
    'safe_as_int',
    'check_shape_equality',
    'check_nD',
    'warn',
    'reshape_nd',
    'identity',
    'slice_at_axis',
]
17
18
class skimage_deprecation(Warning):
    """Warning category used for scikit-image deprecations.

    A dedicated ``Warning`` subclass is used instead of the built-in
    ``DeprecationWarning`` because Python >= 2.7 silences the latter by
    default.
    """
25
26
class change_default_value:
    """Decorator for changing the default value of an argument.

    When the decorated function is called without explicitly providing
    ``arg_name`` (neither positionally nor as a keyword), a
    ``FutureWarning`` announcing the upcoming default change is emitted
    before the call is forwarded unchanged.

    Parameters
    ----------
    arg_name : str
        The name of the argument to be updated.
    new_value : any
        The argument new value.
    changed_version : str
        The package version in which the change will be introduced.
    warning_msg : str
        Optional warning message. If None, a generic warning message
        is used.

    Raises
    ------
    ValueError
        At decoration time, if ``arg_name`` is not a parameter of the
        decorated function.

    """

    def __init__(self, arg_name, *, new_value, changed_version,
                 warning_msg=None):
        self.arg_name = arg_name
        self.new_value = new_value
        self.warning_msg = warning_msg
        self.changed_version = changed_version

    def __call__(self, func):
        parameters = inspect.signature(func).parameters
        if self.arg_name not in parameters:
            raise ValueError(
                f'{self.arg_name!r} is not a parameter of {func.__name__}.')
        arg_idx = list(parameters).index(self.arg_name)
        old_value = parameters[self.arg_name].default

        # Build the message in a local rather than storing it on ``self``:
        # otherwise applying the same decorator instance to a second
        # function would reuse the first function's old default value.
        warning_msg = self.warning_msg
        if warning_msg is None:
            warning_msg = (
                f'The new recommended value for {self.arg_name} is '
                f'{self.new_value}. Until version {self.changed_version}, '
                f'the default {self.arg_name} value is {old_value}. '
                f'From version {self.changed_version}, the {self.arg_name} '
                f'default value will be {self.new_value}. To avoid '
                f'this warning, please explicitly set {self.arg_name} value.')

        @functools.wraps(func)
        def fixed_func(*args, **kwargs):
            # The argument relies on its default when it is neither covered
            # by the positional arguments nor passed by keyword.
            if len(args) <= arg_idx and self.arg_name not in kwargs:
                warnings.warn(warning_msg, FutureWarning, stacklevel=2)
            return func(*args, **kwargs)

        return fixed_func
73
74
class remove_arg:
    """Decorator to remove an argument from function's signature.

    Calling the decorated function with ``arg_name`` supplied, either
    positionally or by keyword, emits a ``FutureWarning``. The call is then
    forwarded to the original function unchanged.

    Parameters
    ----------
    arg_name : str
        The name of the argument to be removed.
    changed_version : str
        The package version in which the warning will be replaced by
        an error.
    help_msg : str
        Optional message appended to the generic warning message.

    """

    def __init__(self, arg_name, *, changed_version, help_msg=None):
        self.arg_name = arg_name
        self.help_msg = help_msg
        self.changed_version = changed_version

    def __call__(self, func):
        name = self.arg_name
        position = [*inspect.signature(func).parameters].index(name)

        base = (
            f'{name} argument is deprecated and will be removed '
            f'in version {self.changed_version}. To avoid this warning, '
            f'please do not use the {name} argument. Please '
            f'see {func.__name__} documentation for more details.')
        if self.help_msg is None:
            message = base
        else:
            message = f'{base} {self.help_msg}'

        @functools.wraps(func)
        def fixed_func(*args, **kwargs):
            given_by_keyword = self.arg_name in kwargs
            given_by_position = len(args) > position
            if given_by_position or given_by_keyword:
                warnings.warn(message, FutureWarning, stacklevel=2)
            return func(*args, **kwargs)

        return fixed_func
115
116
def docstring_add_deprecated(func, kwarg_mapping, deprecated_version):
    """Add deprecated kwarg(s) to the "Other Params" section of a docstring.

    Parameters
    ----------
    func : function
        The function whose docstring we wish to update.
    kwarg_mapping : dict
        A dict containing {old_arg: new_arg} key/value pairs as used by
        `deprecate_kwarg`.
    deprecated_version : str
        A major.minor version string specifying when old_arg was
        deprecated.

    Returns
    -------
    new_doc : str or None
        The updated docstring. Returns None if `func` has no docstring, and
        the original docstring if numpydoc is not available or its rendered
        output contains nothing after the header.
    """
    if func.__doc__ is None:
        return None
    try:
        from numpydoc.docscrape import FunctionDoc, Parameter
    except ImportError:
        # Return an unmodified docstring if numpydoc is not available.
        return func.__doc__

    doc = FunctionDoc(func)
    for old_arg, new_arg in kwarg_mapping.items():
        desc = [f'Deprecated in favor of `{new_arg}`.',
                '',
                f'.. deprecated:: {deprecated_version}']
        doc['Other Parameters'].append(
            Parameter(name=old_arg, type='DEPRECATED', desc=desc)
        )
    new_docstring = str(doc)

    # new_docstring will have a header starting with:
    #
    # .. function:: func.__name__
    #
    # and some additional blank lines. We strip these off below.
    no_header = new_docstring.split('\n')[1:]
    while no_header and not no_header[0].strip():
        no_header.pop(0)
    if not no_header:
        # Nothing was rendered beyond the header; keep the original text.
        return func.__doc__

    # Store the initial description before any of the Parameters fields.
    # Usually this is a single line, but the while loop covers any case
    # where it is not. The emptiness guard avoids an IndexError when the
    # description runs to the end of the rendered text.
    descr = no_header.pop(0)
    while no_header and no_header[0].strip():
        descr += '\n    ' + no_header.pop(0)
    descr += '\n\n'
    # '\n    ' rather than '\n' here to restore the original indentation.
    final_docstring = descr + '\n    '.join(no_header)
    # strip any extra spaces from ends of lines
    return '\n'.join(line.rstrip() for line in final_docstring.split('\n'))
181
182
class deprecate_kwarg:
    """Decorator ensuring backward compatibility when argument names are
    modified in a function definition.

    Keyword arguments passed under an old name are renamed to the new name,
    with a ``FutureWarning``, before the wrapped function is called.

    Parameters
    ----------
    kwarg_mapping : dict
        Mapping between the function's old argument names and the new
        ones.
    deprecated_version : str
        The package version in which the argument was first deprecated.
    warning_msg : str
        Optional warning message. If None, a generic warning message
        is used. The message is formatted with ``str.format`` using the
        ``old_arg``, ``func_name`` and ``new_arg`` fields.
    removed_version : str
        The package version in which the deprecated argument will be
        removed.

    """

    def __init__(self, kwarg_mapping, deprecated_version, warning_msg=None,
                 removed_version=None):
        self.kwarg_mapping = kwarg_mapping
        if warning_msg is None:
            self.warning_msg = ("`{old_arg}` is a deprecated argument name "
                                "for `{func_name}`. ")
            if removed_version is not None:
                # The trailing space separates this sentence from the
                # "Please use ..." sentence appended below.
                self.warning_msg += (f'It will be removed in '
                                     f'version {removed_version}. ')
            self.warning_msg += "Please use `{new_arg}` instead."
        else:
            self.warning_msg = warning_msg

        self.deprecated_version = deprecated_version

    def __call__(self, func):

        @functools.wraps(func)
        def fixed_func(*args, **kwargs):
            for old_arg, new_arg in self.kwarg_mapping.items():
                if old_arg in kwargs:
                    #  warn that the function interface has changed:
                    warnings.warn(self.warning_msg.format(
                        old_arg=old_arg, func_name=func.__name__,
                        new_arg=new_arg), FutureWarning, stacklevel=2)
                    # Substitute new_arg to old_arg. NOTE(review): if the
                    # caller also passed new_arg, it is overwritten here.
                    kwargs[new_arg] = kwargs.pop(old_arg)

            # Call the function with the fixed arguments
            return func(*args, **kwargs)

        if func.__doc__ is not None:
            newdoc = docstring_add_deprecated(func, self.kwarg_mapping,
                                              self.deprecated_version)
            fixed_func.__doc__ = newdoc
        return fixed_func
239
240
class deprecate_multichannel_kwarg(deprecate_kwarg):
    """Decorator for deprecating multichannel keyword in favor of channel_axis.

    A truthy ``multichannel`` is translated to ``channel_axis=-1`` and a
    falsy one to ``channel_axis=None``, with a ``FutureWarning``.

    Parameters
    ----------
    removed_version : str
        The package version in which the deprecated argument will be
        removed.
    multichannel_position : int, optional
        Index of ``multichannel`` among the decorated function's positional
        arguments. When given and the caller passes more than that many
        positional arguments, the positional value is used to set the
        ``channel_axis`` keyword (the positional argument itself is still
        forwarded unchanged).

    """

    def __init__(self, removed_version='1.0', multichannel_position=None):
        super().__init__(
            kwarg_mapping={'multichannel': 'channel_axis'},
            deprecated_version='0.19',
            warning_msg=None,
            removed_version=removed_version)
        self.position = multichannel_position

    def __call__(self, func):
        @functools.wraps(func)
        def fixed_func(*args, **kwargs):

            if self.position is not None and len(args) > self.position:
                warning_msg = (
                    "Providing the `multichannel` argument positionally to "
                    "{func_name} is deprecated. Use the `channel_axis` kwarg "
                    "instead."
                )
                warnings.warn(warning_msg.format(func_name=func.__name__),
                              FutureWarning,
                              stacklevel=2)
                if 'channel_axis' in kwargs:
                    raise ValueError(
                        "Cannot provide both a `channel_axis` kwarg and a "
                        "positional `multichannel` value."
                    )
                kwargs['channel_axis'] = -1 if args[self.position] else None

            if 'multichannel' in kwargs:
                #  warn that the function interface has changed:
                warnings.warn(self.warning_msg.format(
                    old_arg='multichannel', func_name=func.__name__,
                    new_arg='channel_axis'), FutureWarning, stacklevel=2)

                # multichannel truthy -> last axis corresponds to channels.
                # Use the same truthiness rule as the positional branch, so
                # values other than True/False (e.g. None) do not KeyError.
                multichannel = kwargs.pop('multichannel')
                kwargs['channel_axis'] = -1 if multichannel else None

            # Call the function with the fixed arguments
            return func(*args, **kwargs)

        if func.__doc__ is not None:
            newdoc = docstring_add_deprecated(
                func, {'multichannel': 'channel_axis'}, '0.19')
            fixed_func.__doc__ = newdoc
        return fixed_func
300
301
class channel_as_last_axis:
    """Decorator for automatically making channels axis last for all arrays.

    This decorator reorders axes for compatibility with functions that only
    support channels along the last axis. After the function call is complete
    the channels axis is restored back to its original position.

    The channel axis is read from the call's ``channel_axis`` keyword
    argument. If it is absent, None, or already -1, the wrapped function is
    called unchanged. Otherwise the wrapped function receives
    ``channel_axis=-1``.

    Parameters
    ----------
    channel_arg_positions : tuple of int, optional
        Positional arguments at the positions specified in this tuple are
        assumed to be multichannel arrays. The default is to assume only the
        first argument to the function is a multichannel array.
    channel_kwarg_names : tuple of str, optional
        A tuple containing the names of any keyword arguments corresponding to
        multichannel arrays. Names that are not passed in a given call, or
        are passed as None, are left untouched.
    multichannel_output : bool, optional
        A boolean that should be True if the output of the function is a
        multichannel array and False otherwise. This decorator does not
        currently support the general case of functions with multiple outputs
        where some or all are multichannel.

    Raises
    ------
    ValueError
        If more than one channel axis is requested.

    """
    def __init__(self, channel_arg_positions=(0,), channel_kwarg_names=(),
                 multichannel_output=True):
        self.arg_positions = set(channel_arg_positions)
        self.kwarg_names = set(channel_kwarg_names)
        self.multichannel_output = multichannel_output

    def __call__(self, func):
        @functools.wraps(func)
        def fixed_func(*args, **kwargs):

            channel_axis = kwargs.get('channel_axis', None)

            if channel_axis is None:
                return func(*args, **kwargs)

            # TODO: convert scalars to a tuple in anticipation of eventually
            #       supporting a tuple of channel axes. Right now, only an
            #       integer or a single-element tuple is supported, though.
            if np.isscalar(channel_axis):
                channel_axis = (channel_axis,)
            if len(channel_axis) > 1:
                raise ValueError(
                    "only a single channel axis is currently supported")

            # After the scalar normalization above channel_axis is a
            # sequence, so only the tuple form can match here.
            if channel_axis == (-1,):
                return func(*args, **kwargs)

            axis = channel_axis[0]
            new_args = tuple(
                np.moveaxis(arg, axis, -1) if pos in self.arg_positions
                else arg
                for pos, arg in enumerate(args)
            )

            for name in self.kwarg_names:
                # Optional array keywords may be omitted or left as None;
                # only move axes of arrays that were actually supplied.
                if kwargs.get(name) is not None:
                    kwargs[name] = np.moveaxis(kwargs[name], axis, -1)

            # now that we have moved the channels axis to the last position,
            # change the channel_axis argument to -1
            kwargs["channel_axis"] = -1

            # Call the function with the fixed arguments
            out = func(*new_args, **kwargs)
            if self.multichannel_output:
                out = np.moveaxis(out, -1, axis)
            return out

        return fixed_func
377
378
class deprecated:
    """Decorator to mark deprecated functions with warning.

    Adapted from <http://wiki.python.org/moin/PythonDecoratorLibrary>.

    The wrapped function's docstring is prefixed with a
    ``**Deprecated function**.`` banner. With ``behavior='warn'`` every call
    emits a ``skimage_deprecation`` warning located at the deprecated
    function's definition; with ``behavior='raise'`` the warning class is
    raised instead of calling the function.

    Parameters
    ----------
    alt_func : str
        If given, tell user what function to use instead.
    behavior : {'warn', 'raise'}
        Behavior during call to deprecated function: 'warn' = warn user that
        function is deprecated; 'raise' = raise error.
    removed_version : str
        The package version in which the deprecated function will be removed.
    """

    def __init__(self, alt_func=None, behavior='warn', removed_version=None):
        self.alt_func = alt_func
        self.behavior = behavior
        self.removed_version = removed_version

    def __call__(self, func):
        alt_msg = ('' if self.alt_func is None
                   else f' Use ``{self.alt_func}`` instead.')
        rmv_msg = ('' if self.removed_version is None
                   else f' and will be removed in version '
                        f'{self.removed_version}')
        msg = f'Function ``{func.__name__}`` is deprecated{rmv_msg}.{alt_msg}'

        @functools.wraps(func)
        def wrapped(*args, **kwargs):
            if self.behavior == 'raise':
                raise skimage_deprecation(msg)
            if self.behavior == 'warn':
                code = func.__code__
                # Ensure this category is always shown, not just once.
                warnings.simplefilter('always', skimage_deprecation)
                warnings.warn_explicit(msg,
                                       category=skimage_deprecation,
                                       filename=code.co_filename,
                                       lineno=code.co_firstlineno + 1)
            return func(*args, **kwargs)

        # Prepend a deprecation banner to the docstring.
        banner = '**Deprecated function**.' + alt_msg
        original_doc = wrapped.__doc__
        if original_doc is None:
            wrapped.__doc__ = banner
        else:
            wrapped.__doc__ = f'{banner}\n\n    {original_doc}'

        return wrapped
434
435
def get_bound_method_class(m):
    """Return the class of the object a bound method is bound to.

    Parameters
    ----------
    m : bound method
        A method accessed through an object, e.g. ``obj.method``.

    Returns
    -------
    cls : type
        ``m.__self__.__class__``.
    """
    # Bound methods expose the bound object as ``__self__``. The former
    # Python 2 ``im_class`` branch, chosen by a lexicographic comparison of
    # ``sys.version`` strings, was unreachable on Python 3 and fragile.
    return m.__self__.__class__
441
442
def safe_as_int(val, atol=1e-3):
    """
    Attempt to safely cast values to integer format.

    Parameters
    ----------
    val : scalar or iterable of scalars
        Number or container of numbers which are intended to be interpreted as
        integers, e.g., for indexing purposes, but which may not carry integer
        type.
    atol : float
        Absolute tolerance away from nearest integer to consider values in
        ``val`` functionally integers.

    Returns
    -------
    val_int : NumPy scalar or ndarray of dtype `np.int64`
        Returns the input value(s) coerced to dtype `np.int64` assuming all
        were within ``atol`` of the nearest integer.

    Raises
    ------
    ValueError
        If any value is farther than ``atol`` from its nearest integer.

    Notes
    -----
    This operation calculates ``val`` modulo 1, which returns the fractional
    part of all values. Fractional parts greater than 0.5 are replaced by one
    minus themselves, giving the distance to the nearest integer. If that
    distance is at most ``atol`` for all value(s) in ``val``, they are
    rounded and returned in an integer array. Or, if ``val`` was a scalar, a
    NumPy scalar type is returned.

    Examples
    --------
    >>> safe_as_int(7.0)
    7

    >>> safe_as_int([9, 4, 2.9999999999])
    array([9, 4, 3])

    >>> safe_as_int(53.1)
    Traceback (most recent call last):
        ...
    ValueError: Integer argument required but received 53.1, check inputs.

    >>> safe_as_int(53.01, atol=0.01)
    53

    """
    frac = np.asarray(val) % 1  # fractional part of each value

    # Distance to the nearest integer: fractions above 0.5 are closer to the
    # next integer up. np.where handles scalar (0-d) and array input alike.
    dist = np.where(frac > 0.5, 1 - frac, frac)

    # Same criterion as assert_allclose(dist, 0, atol=atol), but without
    # raising and catching an AssertionError; NaN distances fail the test.
    if not np.all(np.abs(dist) <= atol):
        raise ValueError(f'Integer argument required but received '
                         f'{val}, check inputs.')

    return np.round(val).astype(np.int64)
508
509
def check_shape_equality(im1, im2):
    """Raise a ValueError unless ``im1`` and ``im2`` have identical shapes."""
    if im1.shape != im2.shape:
        raise ValueError('Input images must have the same dimensions.')
515
516
def slice_at_axis(sl, axis):
    """
    Construct tuple of slices to slice an array in the given dimension.

    Parameters
    ----------
    sl : slice
        The slice for the given dimension.
    axis : int
        The axis to which `sl` is applied. All other dimensions are left
        "unsliced".

    Returns
    -------
    sl : tuple of slices
        ``axis`` full slices, followed by `sl` and an Ellipsis, so that the
        index tuple leaves every other dimension untouched.

    Examples
    --------
    >>> slice_at_axis(slice(None, 3, -1), 1)
    (slice(None, None, None), slice(None, 3, -1), Ellipsis)
    """
    everything = slice(None)
    leading = (everything,) * axis
    return leading + (sl, Ellipsis)
540
541
def reshape_nd(arr, ndim, dim):
    """Reshape a 1D array to have n dimensions, all singletons but one.

    Parameters
    ----------
    arr : array, shape (N,)
        Input array
    ndim : int
        Number of desired dimensions of reshaped array.
    dim : int
        Which dimension/axis will not be singleton-sized. Negative values
        count from the end.

    Returns
    -------
    arr_reshaped : array, shape ([1, ...], N, [1,...])
        View of `arr` reshaped to the desired shape.

    Raises
    ------
    ValueError
        If `arr` is not one-dimensional.

    Examples
    --------
    >>> rng = np.random.default_rng()
    >>> arr = rng.random(7)
    >>> reshape_nd(arr, 2, 0).shape
    (7, 1)
    >>> reshape_nd(arr, 3, 1).shape
    (1, 7, 1)
    >>> reshape_nd(arr, 4, -1).shape
    (1, 1, 1, 7)
    """
    if arr.ndim == 1:
        target_shape = ndim * [1]
        target_shape[dim] = -1  # the only axis that keeps arr's length
        return np.reshape(arr, target_shape)
    raise ValueError("arr must be a 1D array")
575
576
def check_nD(array, ndim, arg_name='image'):
    """
    Verify an array meets the desired ndims and array isn't empty.

    Parameters
    ----------
    array : array-like
        Input array to be validated
    ndim : int or iterable of ints
        Allowable ndim or ndims for the array. A single value may be a
        Python ``int`` or a NumPy integer scalar.
    arg_name : str, optional
        The name of the array in the original function.

    Raises
    ------
    ValueError
        If ``array`` is empty or its number of dimensions is not one of the
        allowed values.

    """
    array = np.asanyarray(array)
    msg_incorrect_dim = "The parameter `%s` must be a %s-dimensional array"
    msg_empty_array = "The parameter `%s` cannot be an empty array"
    # NumPy integer scalars are not ``int`` instances, so accept both. Any
    # other input is materialized once so that one-shot iterables (e.g.
    # generators) survive both the membership test and the error message.
    if isinstance(ndim, (int, np.integer)):
        allowed = [ndim]
    else:
        allowed = list(ndim)
    if array.size == 0:
        raise ValueError(msg_empty_array % (arg_name))
    if array.ndim not in allowed:
        raise ValueError(
            msg_incorrect_dim % (arg_name,
                                 '-or-'.join(str(n) for n in allowed))
        )
602
603
def convert_to_float(image, preserve_range):
    """Convert input image to float image with the appropriate range.

    Parameters
    ----------
    image : ndarray
        Input image.
    preserve_range : bool
        Determines if the range of the image should be kept or transformed
        using img_as_float. Also see
        https://scikit-image.org/docs/dev/user_guide/data_types.html

    Returns
    -------
    image : ndarray
        Transformed version of the input.

    Notes
    -----
    * Input images with `float32` data type are not upcast.
    * `float16` input is always cast to `float32`, whatever the value of
      ``preserve_range``.

    """
    if image.dtype == np.float16:
        return image.astype(np.float32)

    if not preserve_range:
        from ..util.dtype import img_as_float
        return img_as_float(image)

    # Single ('f') and double ('d') precision floats are kept as-is; any
    # other dtype is cast to double precision without rescaling.
    already_float = image.dtype.char in 'df'
    return image if already_float else image.astype(float)
637
638
def _validate_interpolation_order(image_dtype, order):
    """Validate and return spline interpolation's order.

    Parameters
    ----------
    image_dtype : dtype
        Image dtype.
    order : int, optional
        The order of the spline interpolation. The order has to be in
        the range 0-5. See `skimage.transform.warp` for detail.

    Returns
    -------
    order : int
        if input order is None, returns 0 if image_dtype is bool and 1
        otherwise. Otherwise, image_dtype is checked and input order
        is validated accordingly (order > 0 is not supported for bool
        image dtype)

    Raises
    ------
    ValueError
        If ``order`` is outside the range 0-5, or is non-zero for a bool
        image dtype.

    """

    if order is None:
        return 0 if image_dtype == bool else 1

    if order < 0 or order > 5:
        raise ValueError("Spline interpolation order has to be in the "
                         "range 0-5.")

    if image_dtype == bool and order != 0:
        raise ValueError(
            "Input image dtype is bool. Interpolation is not defined "
            "with bool data type. Please set order to 0 or explicitly "
            "cast input image to another data type.")

    return order
674
675
def _to_np_mode(mode):
    """Convert padding modes from `ndi.correlate` to `np.pad`.

    Modes whose name is the same in both libraries are returned unchanged.
    """
    ndimage_to_numpy = {
        'nearest': 'edge',
        'reflect': 'symmetric',
        'mirror': 'reflect',
    }
    return ndimage_to_numpy.get(mode, mode)
683
684
def _to_ndimage_mode(mode):
    """Convert from `numpy.pad` mode name to the corresponding ndimage mode.

    The translated name is passed through `_fix_ndimage_mode` before being
    returned.

    Raises
    ------
    ValueError
        If ``mode`` is not one of the supported `numpy.pad` modes.
    """
    numpy_to_ndimage = {
        'constant': 'constant',
        'edge': 'nearest',
        'symmetric': 'reflect',
        'reflect': 'mirror',
        'wrap': 'wrap',
    }
    if mode in numpy_to_ndimage:
        return _fix_ndimage_mode(numpy_to_ndimage[mode])
    raise ValueError(
        f"Unknown mode: '{mode}', or cannot translate mode. The mode "
        "should be one of 'constant', 'edge', 'symmetric', 'reflect', "
        "or 'wrap'. See the documentation of numpy.pad for more info.")
697
698
def _fix_ndimage_mode(mode):
    """Substitute SciPy's grid-aware boundary modes when available.

    SciPy 1.6.0 introduced ``'grid-constant'`` and ``'grid-wrap'`` variants
    of ``'constant'`` and ``'wrap'``, which behave less surprisingly for
    images. On older SciPy, or for any other mode, ``mode`` is returned
    unchanged.
    """
    if NumpyVersion(scipy.__version__) < '1.6.0':
        return mode
    return {'constant': 'grid-constant', 'wrap': 'grid-wrap'}.get(mode, mode)
706
707
# Map from a dtype's single-character code (``dtype.char``) to the floating
# point type used for that input by ``_supported_float_type``.
new_float_type = {
    # preserved types
    np.float32().dtype.char: np.float32,
    np.float64().dtype.char: np.float64,
    np.complex64().dtype.char: np.complex64,
    np.complex128().dtype.char: np.complex128,
    # altered types
    np.float16().dtype.char: np.float32,
    'g': np.float64,      # np.float128 ; doesn't exist on windows
    'G': np.complex128,   # np.complex256 ; doesn't exist on windows
}


def _supported_float_type(input_dtype, allow_complex=False):
    """Return an appropriate floating-point dtype for a given dtype.

    float32, float64, complex64, complex128 are preserved.
    float16 is promoted to float32.
    longdouble ('g') is demoted to float64 and complex256 ('G') to
    complex128.
    Other types are cast to float64.

    Parameters
    ----------
    input_dtype : np.dtype or Iterable of np.dtype
        The input dtype. If a sequence of multiple dtypes is provided, each
        dtype is first converted to a supported floating point type and the
        final dtype is then determined by applying `np.result_type` on the
        sequence of supported floating point types.
    allow_complex : bool, optional
        If False, raise a ValueError on complex-valued inputs. For a
        sequence of dtypes this applies to every element.

    Returns
    -------
    float_type : dtype
        Floating-point dtype for the image.

    Raises
    ------
    ValueError
        If a complex dtype is given while ``allow_complex`` is False.
    """
    if isinstance(input_dtype, Iterable) and not isinstance(input_dtype, str):
        # Forward allow_complex so that a sequence containing complex dtypes
        # is treated the same way as a single complex dtype.
        return np.result_type(
            *(_supported_float_type(d, allow_complex) for d in input_dtype))
    input_dtype = np.dtype(input_dtype)
    if not allow_complex and input_dtype.kind == 'c':
        raise ValueError("complex valued input is not supported")
    return new_float_type.get(input_dtype.char, np.float64)
750
751
def identity(image, *args, **kwargs):
    """No-op callable: hand back the first argument untouched.

    Any extra positional or keyword arguments are accepted and ignored, so
    this can stand in for functions with arbitrary signatures.
    """
    unchanged = image
    return unchanged
755