1"""Function for recording and reporting deprecations.
2
3Note
4-----
5this file is copied (with minor modifications) from the Nibabel.
6https://github.com/nipy/nibabel. See COPYING file distributed along with
7the Nibabel package for the copyright and license terms.
8
9"""
10
11import functools
12import warnings
13import re
14from inspect import signature
15from dipy import __version__
16from packaging.version import parse as version_cmp
17
18_LEADING_WHITE = re.compile(r'^(\s*)')
19
20
21class ExpiredDeprecationError(RuntimeError):
22    """Error for expired deprecation.
23
24    Error raised when a called function or method has passed out of its
25    deprecation period.
26
27    """
28
29    pass
30
31
32class ArgsDeprecationWarning(DeprecationWarning):
33    """Warning for args deprecation.
34
35    Warning raised when a function or method argument has changed or removed.
36
37    """
38
39    pass
40
41
42def _ensure_cr(text):
43    """Remove trailing whitespace and add carriage return.
44
45    Ensures that `text` always ends with a carriage return
46    """
47    return text.rstrip() + '\n'
48
49
50def _add_dep_doc(old_doc, dep_doc):
51    """Add deprecation message `dep_doc` to docstring in `old_doc`.
52
53    Parameters
54    ----------
55    old_doc : str
56        Docstring from some object.
57    dep_doc : str
58        Deprecation warning to add to top of docstring, after initial line.
59
60    Returns
61    -------
62    new_doc : str
63        `old_doc` with `dep_doc` inserted after any first lines of docstring.
64
65    """
66    dep_doc = _ensure_cr(dep_doc)
67    if not old_doc:
68        return dep_doc
69    old_doc = _ensure_cr(old_doc)
70    old_lines = old_doc.splitlines()
71    new_lines = []
72    for line_no, line in enumerate(old_lines):
73        if line.strip():
74            new_lines.append(line)
75        else:
76            break
77    next_line = line_no + 1
78    if next_line >= len(old_lines):
79        # nothing following first paragraph, just append message
80        return old_doc + '\n' + dep_doc
81    indent = _LEADING_WHITE.match(old_lines[next_line]).group()
82    dep_lines = [indent + L for L in [''] + dep_doc.splitlines() + ['']]
83    return '\n'.join(new_lines + dep_lines + old_lines[next_line:]) + '\n'
84
85
86def cmp_pkg_version(version_str, pkg_version_str=__version__):
87    """Compare `version_str` to current package version.
88
89    Parameters
90    ----------
91    version_str : str
92        Version string to compare to current package version
93    pkg_version_str : str, optional
94        Version of our package.  Optional, set fom ``__version__`` by default.
95    Returns
96    -------
97    version_cmp : int
98        1 if `version_str` is a later version than `pkg_version_str`, 0 if
99        same, -1 if earlier.
100    Examples
101    --------
102    >>> cmp_pkg_version('1.2.1', '1.2.0')
103    1
104    >>> cmp_pkg_version('1.2.0dev', '1.2.0')
105    -1
106
107    """
108    if any([re.match(r'^[a-z, A-Z]', v)for v in [version_str,
109                                                 pkg_version_str]]):
110        msg = 'Invalid version {0} or {1}'.format(version_str, pkg_version_str)
111        raise ValueError(msg)
112    elif version_cmp(version_str) > version_cmp(pkg_version_str):
113        return 1
114    elif version_cmp(version_str) == version_cmp(pkg_version_str):
115        return 0
116    else:
117        return -1
118
119
120def is_bad_version(version_str, version_comparator=cmp_pkg_version):
121    """Return True if `version_str` is too high."""
122    return version_comparator(version_str) == -1
123
124
125def deprecate_with_version(message, since='', until='',
126                           version_comparator=cmp_pkg_version,
127                           warn_class=DeprecationWarning,
128                           error_class=ExpiredDeprecationError):
129    """Return decorator function function for deprecation warning / error.
130
131    The decorated function / method will:
132
133    * Raise the given `warning_class` warning when the function / method gets
134      called, up to (and including) version `until` (if specified);
135    * Raise the given `error_class` error when the function / method gets
136      called, when the package version is greater than version `until` (if
137      specified).
138
139    Parameters
140    ----------
141    message : str
142        Message explaining deprecation, giving possible alternatives.
143    since : str, optional
144        Released version at which object was first deprecated.
145    until : str, optional
146        Last released version at which this function will still raise a
147        deprecation warning.  Versions higher than this will raise an
148        error.
149    version_comparator : callable
150        Callable accepting string as argument, and return 1 if string
151        represents a higher version than encoded in the `version_comparator`, 0
152        if the version is equal, and -1 if the version is lower.  For example,
153        the `version_comparator` may compare the input version string to the
154        current package version string.
155    warn_class : class, optional
156        Class of warning to generate for deprecation.
157    error_class : class, optional
158        Class of error to generate when `version_comparator` returns 1 for a
159        given argument of ``until``.
160
161    Returns
162    -------
163    deprecator : func
164        Function returning a decorator.
165
166    """
167    messages = [message]
168    if (since, until) != ('', ''):
169        messages.append('')
170    if since:
171        messages.append('* deprecated from version: ' + since)
172    if until:
173        messages.append('* {0} {1} as of version: {2}'.format(
174            "Raises" if is_bad_version(until) else "Will raise",
175            error_class,
176            until))
177    message = '\n'.join(messages)
178
179    def deprecator(func):
180
181        @functools.wraps(func)
182        def deprecated_func(*args, **kwargs):
183            if until and is_bad_version(until, version_comparator):
184                raise error_class(message)
185            warnings.warn(message, warn_class, stacklevel=2)
186            return func(*args, **kwargs)
187
188        deprecated_func.__doc__ = _add_dep_doc(deprecated_func.__doc__,
189                                               message)
190        return deprecated_func
191
192    return deprecator
193
194
195def deprecated_params(old_name, new_name=None, since='', until='',
196                      version_comparator=cmp_pkg_version,
197                      arg_in_kwargs=False,
198                      warn_class=ArgsDeprecationWarning,
199                      error_class=ExpiredDeprecationError,
200                      alternative=''):
201    """Deprecate a *renamed* or *removed* function argument.
202
203    The decorator assumes that the argument with the ``old_name`` was removed
204    from the function signature and the ``new_name`` replaced it at the
205    **same position** in the signature.  If the ``old_name`` argument is
206    given when calling the decorated function the decorator will catch it and
207    issue a deprecation warning and pass it on as ``new_name`` argument.
208
209    Parameters
210    ----------
211    old_name : str or list/tuple thereof
212        The old name of the argument.
213    new_name : str or list/tuple thereof or ``None``, optional
214        The new name of the argument. Set this to `None` to remove the
215        argument ``old_name`` instead of renaming it.
216    since : str or number or list/tuple thereof, optional
217        The release at which the old argument became deprecated.
218    until : str or number or list/tuple thereof, optional
219        Last released version at which this function will still raise a
220        deprecation warning.  Versions higher than this will raise an
221        error.
222    version_comparator : callable
223        Callable accepting string as argument, and return 1 if string
224        represents a higher version than encoded in the ``version_comparator``,
225        0 if the version is equal, and -1 if the version is lower. For example,
226        the ``version_comparator`` may compare the input version string to the
227        current package version string.
228    arg_in_kwargs : bool or list/tuple thereof, optional
229        If the argument is not a named argument (for example it
230        was meant to be consumed by ``**kwargs``) set this to
231        ``True``.  Otherwise the decorator will throw an Exception
232        if the ``new_name`` cannot be found in the signature of
233        the decorated function.
234        Default is ``False``.
235    warn_class : warning, optional
236        Warning to be issued.
237    error_class : Exception, optional
238        Error to be issued
239    alternative : str, optional
240        An alternative function or class name that the user may use in
241        place of the deprecated object if ``new_name`` is None. The deprecation
242        warning will tell the user about this alternative if provided.
243
244    Raises
245    ------
246    TypeError
247        If the new argument name cannot be found in the function
248        signature and arg_in_kwargs was False or if it is used to
249        deprecate the name of the ``*args``-, ``**kwargs``-like arguments.
250        At runtime such an Error is raised if both the new_name
251        and old_name were specified when calling the function and
252        "relax=False".
253
254    Notes
255    -----
256    This function is based on the Astropy (major modification).
257    https://github.com/astropy/astropy. See COPYING file distributed along with
258    the astropy package for the copyright and license terms.
259
260    Examples
261    --------
262    The deprecation warnings are not shown in the following examples.
263    To deprecate a positional or keyword argument::
264    >>> from dipy.utils.deprecator import deprecated_params
265    >>> @deprecated_params('sig', 'sigma', '0.3')
266    ... def test(sigma):
267    ...     return sigma
268    >>> test(2)
269    2
270    >>> test(sigma=2)
271    2
272    >>> test(sig=2)  # doctest: +SKIP
273    2
274
275    It is also possible to replace multiple arguments. The ``old_name``,
276    ``new_name`` and ``since`` have to be `tuple` or `list` and contain the
277    same number of entries::
278    >>> @deprecated_params(['a', 'b'], ['alpha', 'beta'],
279    ...                    ['0.2', 0.4])
280    ... def test(alpha, beta):
281    ...     return alpha, beta
282    >>> test(a=2, b=3)  # doctest: +SKIP
283    (2, 3)
284
285    """
286    if isinstance(old_name, (list, tuple)):
287        # Normalize input parameters
288        if not isinstance(arg_in_kwargs, (list, tuple)):
289            arg_in_kwargs = [arg_in_kwargs] * len(old_name)
290        if not isinstance(since, (list, tuple)):
291            since = [since] * len(old_name)
292        if not isinstance(until, (list, tuple)):
293            until = [until] * len(old_name)
294        if not isinstance(new_name, (list, tuple)):
295            new_name = [new_name] * len(old_name)
296
297        if len(set([len(old_name), len(new_name), len(since),
298                    len(until), len(arg_in_kwargs)])) != 1:
299            raise ValueError("All parameters should have the same length")
300    else:
301        # To allow a uniform approach later on, wrap all arguments in lists.
302        old_name = [old_name]
303        new_name = [new_name]
304        since = [since]
305        until = [until]
306        arg_in_kwargs = [arg_in_kwargs]
307
308    def deprecator(function):
309        # The named arguments of the function.
310        arguments = signature(function).parameters
311        positions = [None] * len(old_name)
312
313        for i, (o_name, n_name, in_keywords) in enumerate(zip(old_name,
314                                                              new_name,
315                                                              arg_in_kwargs)):
316            # Determine the position of the argument.
317            if in_keywords:
318                continue
319
320            if n_name is not None and n_name not in arguments:
321                # In case the argument is not found in the list of arguments
322                # the only remaining possibility is that it should be caught
323                # by some kind of **kwargs argument.
324                msg = '"{}" was not specified in the function '.format(n_name)
325                msg += 'signature. If it was meant to be part of '
326                msg += '"**kwargs" then set "arg_in_kwargs" to "True"'
327                raise TypeError(msg)
328
329            key = o_name if n_name is None else n_name
330            param = arguments[key]
331
332            if param.kind == param.POSITIONAL_OR_KEYWORD:
333                key = o_name if n_name is None else n_name
334                positions[i] = list(arguments.keys()).index(key)
335            elif param.kind == param.KEYWORD_ONLY:
336                # These cannot be specified by position.
337                positions[i] = None
338            else:
339                # positional-only argument, varargs, varkwargs or some
340                # unknown type:
341                msg = 'cannot replace argument "{}" '.format(n_name)
342                msg += 'of kind {}.'.format(repr(param.kind))
343                raise TypeError(msg)
344
345        @functools.wraps(function)
346        def wrapper(*args, **kwargs):
347            for i, (o_name, n_name) in enumerate(zip(old_name, new_name)):
348                messages = ['"{}" was deprecated'.format(o_name), ]
349                if (since[i], until[i]) != ('', ''):
350                    messages.append('')
351                if since[i]:
352                    messages.append('* deprecated from version: ' +
353                                    str(since[i]))
354                if until[i]:
355                    messages.append('* {0} {1} as of version: {2}'.format(
356                        "Raises" if is_bad_version(until[i]) else "Will raise",
357                        error_class,
358                        until[i]))
359                messages.append('')
360                message = '\n'.join(messages)
361
362                # The only way to have oldkeyword inside the function is
363                # that it is passed as kwarg because the oldkeyword
364                # parameter was renamed to newkeyword.
365                if o_name in kwargs:
366                    value = kwargs.pop(o_name)
367                    # Check if the newkeyword was given as well.
368                    newarg_in_args = (positions[i] is not None and
369                                      len(args) > positions[i])
370                    newarg_in_kwargs = n_name in kwargs
371
372                    if newarg_in_args or newarg_in_kwargs:
373                        msg = 'cannot specify both "{}"'.format(o_name)
374                        msg += ' (deprecated parameter) and '
375                        msg += '"{}" (new parameter name).'.format(n_name)
376                        raise TypeError(msg)
377
378                    # Pass the value of the old argument with the
379                    # name of the new argument to the function
380                    key = n_name or o_name
381                    kwargs[key] = value
382
383                    if n_name is not None:
384                        message += '* Use argument "{}" instead.' \
385                            .format(n_name)
386                    elif alternative:
387                        message += '* Use {} instead.'.format(alternative)
388
389                    if until[i] and is_bad_version(until[i],
390                                                   version_comparator):
391                        raise error_class(message)
392                    warnings.warn(message, warn_class, stacklevel=2)
393
394                # Deprecated keyword without replacement is given as
395                # positional argument.
396                elif (not n_name and positions[i] and
397                      len(args) > positions[i]):
398                    if alternative:
399                        message += '* Use {} instead.'.format(alternative)
400                    if until[i] and is_bad_version(until[i],
401                                                   version_comparator):
402                        raise error_class(message)
403
404                    warnings.warn(message, warn_class, stacklevel=2)
405
406            return function(*args, **kwargs)
407
408        return wrapper
409    return deprecator
410