import inspect
import functools
import sys
import warnings
from collections.abc import Iterable

import numpy as np
import scipy
from numpy.lib import NumpyVersion

from ._warnings import all_warnings, warn


__all__ = ['deprecated', 'get_bound_method_class', 'all_warnings',
           'safe_as_int', 'check_shape_equality', 'check_nD', 'warn',
           'reshape_nd', 'identity', 'slice_at_axis'],


class skimage_deprecation(Warning):
    """Warning category emitted by the :class:`deprecated` decorator.

    A dedicated subclass of ``Warning`` is used instead of
    ``DeprecationWarning`` because Python (>= 2.7) silences deprecation
    warnings by default.

    """
    pass


class change_default_value:
    """Decorator for changing the default value of an argument.

    A ``FutureWarning`` is emitted whenever the decorated function is called
    without explicitly providing ``arg_name`` (neither positionally nor as a
    keyword), i.e. whenever the current default value would be used.

    Parameters
    ----------
    arg_name : str
        The name of the argument to be updated.
    new_value : any
        The argument new value.
    changed_version : str
        The package version in which the change will be introduced.
    warning_msg : str, optional
        Optional warning message. If None, a generic warning message
        is used.

    """

    def __init__(self, arg_name, *, new_value, changed_version,
                 warning_msg=None):
        self.arg_name = arg_name
        self.new_value = new_value
        self.warning_msg = warning_msg
        self.changed_version = changed_version

    def __call__(self, func):
        parameters = inspect.signature(func).parameters
        # Position of arg_name in the signature; raises ValueError if the
        # decorated function has no such parameter.
        arg_idx = list(parameters.keys()).index(self.arg_name)
        old_value = parameters[self.arg_name].default

        # Build the message in a local variable so the decorator instance is
        # not mutated by decorating a function.
        warning_msg = self.warning_msg
        if warning_msg is None:
            warning_msg = (
                f'The new recommended value for {self.arg_name} is '
                f'{self.new_value}. Until version {self.changed_version}, '
                f'the default {self.arg_name} value is {old_value}. '
                f'From version {self.changed_version}, the {self.arg_name} '
                f'default value will be {self.new_value}. To avoid '
                f'this warning, please explicitly set {self.arg_name} value.')

        @functools.wraps(func)
        def fixed_func(*args, **kwargs):
            # Fewer positional args than needed to reach arg_name and no
            # keyword either: the (soon to change) default is being used.
            if len(args) <= arg_idx and self.arg_name not in kwargs:
                warnings.warn(warning_msg, FutureWarning, stacklevel=2)
            return func(*args, **kwargs)

        return fixed_func


class remove_arg:
    """Decorator to remove an argument from function's signature.

    A ``FutureWarning`` is emitted whenever the decorated function is called
    with ``arg_name`` provided, either positionally or as a keyword.

    Parameters
    ----------
    arg_name : str
        The name of the argument to be removed.
    changed_version : str
        The package version in which the warning will be replaced by
        an error.
    help_msg : str, optional
        Optional message appended to the generic warning message.

    """

    def __init__(self, arg_name, *, changed_version, help_msg=None):
        self.arg_name = arg_name
        self.help_msg = help_msg
        self.changed_version = changed_version

    def __call__(self, func):
        parameters = inspect.signature(func).parameters
        # Position of arg_name in the signature; raises ValueError if the
        # decorated function has no such parameter.
        arg_idx = list(parameters.keys()).index(self.arg_name)
        warning_msg = (
            f'{self.arg_name} argument is deprecated and will be removed '
            f'in version {self.changed_version}. To avoid this warning, '
            f'please do not use the {self.arg_name} argument. Please '
            f'see {func.__name__} documentation for more details.')

        if self.help_msg is not None:
            warning_msg += f' {self.help_msg}'

        @functools.wraps(func)
        def fixed_func(*args, **kwargs):
            # The argument was supplied positionally or by keyword.
            if len(args) > arg_idx or self.arg_name in kwargs:
                warnings.warn(warning_msg, FutureWarning, stacklevel=2)
            return func(*args, **kwargs)

        return fixed_func


def docstring_add_deprecated(func, kwarg_mapping, deprecated_version):
    """Add deprecated kwarg(s) to the "Other Params" section of a docstring.

    Parameters
    ----------
    func : function
        The function whose docstring we wish to update.
    kwarg_mapping : dict
        A dict containing {old_arg: new_arg} key/value pairs as used by
        `deprecate_kwarg`.
    deprecated_version : str
        A major.minor version string specifying when old_arg was
        deprecated.

    Returns
    -------
    new_doc : str or None
        The updated docstring. Returns None if ``func`` has no docstring,
        and the original docstring if numpydoc is not available.
    """
    if func.__doc__ is None:
        return None
    try:
        from numpydoc.docscrape import FunctionDoc, Parameter
    except ImportError:
        # Return an unmodified docstring if numpydoc is not available.
        return func.__doc__

    doc = FunctionDoc(func)
    for old_arg, new_arg in kwarg_mapping.items():
        desc = [f'Deprecated in favor of `{new_arg}`.',
                '',
                f'.. deprecated:: {deprecated_version}']
        doc['Other Parameters'].append(
            Parameter(name=old_arg,
                      type='DEPRECATED',
                      desc=desc)
        )
    new_docstring = str(doc)

    # new_docstring will have a header starting with:
    #
    # .. function:: func.__name__
    #
    # and some additional blank lines. We strip these off below.
    split = new_docstring.split('\n')
    no_header = split[1:]
    while not no_header[0].strip():
        no_header.pop(0)

    # Store the initial description before any of the Parameters fields.
    # Usually this is a single line, but the while loop covers any case
    # where it is not.
    descr = no_header.pop(0)
    while no_header[0].strip():
        descr += '\n    ' + no_header.pop(0)
    descr += '\n\n'
    # '\n    ' rather than '\n' here to restore the original indentation.
    final_docstring = descr + '\n    '.join(no_header)
    # strip any extra spaces from ends of lines
    final_docstring = '\n'.join(
        [line.rstrip() for line in final_docstring.split('\n')]
    )
    return final_docstring


class deprecate_kwarg:
    """Decorator ensuring backward compatibility when argument names are
    modified in a function definition.

    When the decorated function is called with an old keyword name, a
    ``FutureWarning`` is emitted and the value is forwarded under the new
    name. The function's docstring is also updated to list the old names
    under "Other Parameters" (see `docstring_add_deprecated`).

    Parameters
    ----------
    kwarg_mapping : dict
        Mapping between the function's old argument names and the new
        ones.
    deprecated_version : str
        The package version in which the argument was first deprecated.
    warning_msg : str, optional
        Optional warning message. If None, a generic warning message
        is used. The message is formatted with ``str.format`` using the
        ``old_arg``, ``func_name`` and ``new_arg`` fields.
    removed_version : str, optional
        The package version in which the deprecated argument will be
        removed.

    """

    def __init__(self, kwarg_mapping, deprecated_version, warning_msg=None,
                 removed_version=None):
        self.kwarg_mapping = kwarg_mapping
        if warning_msg is None:
            self.warning_msg = ("`{old_arg}` is a deprecated argument name "
                                "for `{func_name}`. ")
            if removed_version is not None:
                # Trailing space separates this sentence from the next one.
                self.warning_msg += (f'It will be removed in '
                                     f'version {removed_version}. ')
            self.warning_msg += "Please use `{new_arg}` instead."
        else:
            self.warning_msg = warning_msg

        self.deprecated_version = deprecated_version

    def __call__(self, func):

        @functools.wraps(func)
        def fixed_func(*args, **kwargs):
            for old_arg, new_arg in self.kwarg_mapping.items():
                if old_arg in kwargs:
                    # warn that the function interface has changed:
                    warnings.warn(self.warning_msg.format(
                        old_arg=old_arg, func_name=func.__name__,
                        new_arg=new_arg), FutureWarning, stacklevel=2)
                    # Forward the value under the new argument name.
                    kwargs[new_arg] = kwargs.pop(old_arg)

            # Call the function with the fixed arguments
            return func(*args, **kwargs)

        if func.__doc__ is not None:
            newdoc = docstring_add_deprecated(func, self.kwarg_mapping,
                                              self.deprecated_version)
            fixed_func.__doc__ = newdoc
        return fixed_func


class deprecate_multichannel_kwarg(deprecate_kwarg):
    """Decorator for deprecating multichannel keyword in favor of channel_axis.

    A ``multichannel`` keyword is translated to ``channel_axis`` (True -> -1,
    False -> None) with a ``FutureWarning``.

    Parameters
    ----------
    removed_version : str, optional
        The package version in which the deprecated argument will be
        removed.
    multichannel_position : int, optional
        If not None, the index of the positional argument corresponding to
        ``multichannel``. When at least ``multichannel_position + 1``
        positional arguments are given, a ``FutureWarning`` is emitted and
        ``channel_axis`` is set to -1 if that value is truthy, else None.
        The positional value itself is still forwarded to the function.

    """

    def __init__(self, removed_version='1.0', multichannel_position=None):
        super().__init__(
            kwarg_mapping={'multichannel': 'channel_axis'},
            deprecated_version='0.19',
            warning_msg=None,
            removed_version=removed_version)
        self.position = multichannel_position

    def __call__(self, func):
        @functools.wraps(func)
        def fixed_func(*args, **kwargs):

            if self.position is not None and len(args) > self.position:
                warning_msg = (
                    "Providing the `multichannel` argument positionally to "
                    "{func_name} is deprecated. Use the `channel_axis` kwarg "
                    "instead."
                )
                warnings.warn(warning_msg.format(func_name=func.__name__),
                              FutureWarning,
                              stacklevel=2)
                if 'channel_axis' in kwargs:
                    raise ValueError(
                        "Cannot provide both a `channel_axis` kwarg and a "
                        "positional `multichannel` value."
                    )
                kwargs['channel_axis'] = -1 if args[self.position] else None

            if 'multichannel' in kwargs:
                # warn that the function interface has changed:
                warnings.warn(self.warning_msg.format(
                    old_arg='multichannel', func_name=func.__name__,
                    new_arg='channel_axis'), FutureWarning, stacklevel=2)

                # multichannel = True -> last axis corresponds to channels
                convert = {True: -1, False: None}
                kwargs['channel_axis'] = convert[kwargs.pop('multichannel')]

            # Call the function with the fixed arguments
            return func(*args, **kwargs)

        if func.__doc__ is not None:
            newdoc = docstring_add_deprecated(
                func, {'multichannel': 'channel_axis'}, '0.19')
            fixed_func.__doc__ = newdoc
        return fixed_func


class channel_as_last_axis:
    """Decorator for automatically making channels axis last for all arrays.

    This decorator reorders axes for compatibility with functions that only
    support channels along the last axis. After the function call is complete
    the channels axis is restored back to its original position.

    Reordering only happens when ``channel_axis`` is passed as a keyword
    argument with a value other than None or -1; otherwise the function is
    called unchanged.

    Parameters
    ----------
    channel_arg_positions : tuple of int, optional
        Positional arguments at the positions specified in this tuple are
        assumed to be multichannel arrays. The default is to assume only the
        first argument to the function is a multichannel array.
    channel_kwarg_names : tuple of str, optional
        A tuple containing the names of any keyword arguments corresponding to
        multichannel arrays. Names that are not passed (or are passed as
        None) are left untouched.
    multichannel_output : bool, optional
        A boolean that should be True if the output of the function is a
        multichannel array and False otherwise. This decorator does not
        currently support the general case of functions with multiple outputs
        where some or all are multichannel.

    """
    def __init__(self, channel_arg_positions=(0,), channel_kwarg_names=(),
                 multichannel_output=True):
        self.arg_positions = set(channel_arg_positions)
        self.kwarg_names = set(channel_kwarg_names)
        self.multichannel_output = multichannel_output

    def __call__(self, func):
        @functools.wraps(func)
        def fixed_func(*args, **kwargs):

            channel_axis = kwargs.get('channel_axis', None)

            if channel_axis is None:
                return func(*args, **kwargs)

            # TODO: convert scalars to a tuple in anticipation of eventually
            #       supporting a tuple of channel axes. Right now, only an
            #       integer or a single-element tuple is supported, though.
            if np.isscalar(channel_axis):
                channel_axis = (channel_axis,)
            # Normalize any sequence (e.g. a list) to a tuple so the
            # comparison below works regardless of the container type.
            channel_axis = tuple(channel_axis)
            if len(channel_axis) > 1:
                raise ValueError(
                    "only a single channel axis is currently suported")

            if channel_axis == (-1,):
                # Channels are already last: nothing to reorder.
                return func(*args, **kwargs)

            axis = channel_axis[0]
            new_args = tuple(
                np.moveaxis(arg, axis, -1) if pos in self.arg_positions
                else arg
                for pos, arg in enumerate(args)
            )

            for name in self.kwarg_names:
                # Optional array kwargs may be omitted or explicitly None.
                if kwargs.get(name) is not None:
                    kwargs[name] = np.moveaxis(kwargs[name], axis, -1)

            # now that we have moved the channels axis to the last position,
            # change the channel_axis argument to -1
            kwargs["channel_axis"] = -1

            # Call the function with the fixed arguments
            out = func(*new_args, **kwargs)
            if self.multichannel_output:
                # Restore the channels axis to its original position.
                out = np.moveaxis(out, -1, axis)
            return out

        return fixed_func


class deprecated(object):
    """Decorator to mark deprecated functions with warning.

    Adapted from <http://wiki.python.org/moin/PythonDecoratorLibrary>.

    Parameters
    ----------
    alt_func : str
        If given, tell user what function to use instead.
    behavior : {'warn', 'raise'}
        Behavior during call to deprecated function: 'warn' = warn user that
        function is deprecated; 'raise' = raise error.
    removed_version : str
        The package version in which the deprecated function will be removed.

    Notes
    -----
    With ``behavior='warn'``, every call runs
    ``warnings.simplefilter('always', skimage_deprecation)``, which changes
    the process-wide warning filters.
    """

    def __init__(self, alt_func=None, behavior='warn', removed_version=None):
        self.alt_func = alt_func
        self.behavior = behavior
        self.removed_version = removed_version

    def __call__(self, func):

        alt_msg = ''
        if self.alt_func is not None:
            alt_msg = f' Use ``{self.alt_func}`` instead.'
        rmv_msg = ''
        if self.removed_version is not None:
            rmv_msg = f' and will be removed in version {self.removed_version}'

        msg = f'Function ``{func.__name__}`` is deprecated{rmv_msg}.{alt_msg}'

        @functools.wraps(func)
        def wrapped(*args, **kwargs):
            if self.behavior == 'warn':
                func_code = func.__code__
                warnings.simplefilter('always', skimage_deprecation)
                # Attribute the warning to the deprecated function's own
                # source location rather than to this wrapper.
                warnings.warn_explicit(msg,
                                       category=skimage_deprecation,
                                       filename=func_code.co_filename,
                                       lineno=func_code.co_firstlineno + 1)
            elif self.behavior == 'raise':
                raise skimage_deprecation(msg)
            return func(*args, **kwargs)

        # modify doc string to display deprecation warning
        doc = '**Deprecated function**.' + alt_msg
        if wrapped.__doc__ is None:
            wrapped.__doc__ = doc
        else:
            wrapped.__doc__ = doc + '\n\n    ' + wrapped.__doc__

        return wrapped


def get_bound_method_class(m):
    """Return the class of the object that the bound method `m` is bound to.

    """
    return m.__self__.__class__


def safe_as_int(val, atol=1e-3):
    """
    Attempt to safely cast values to integer format.

    Parameters
    ----------
    val : scalar or iterable of scalars
        Number or container of numbers which are intended to be interpreted as
        integers, e.g., for indexing purposes, but which may not carry integer
        type.
    atol : float
        Absolute tolerance away from nearest integer to consider values in
        ``val`` functionally integers.

    Returns
    -------
    val_int : NumPy scalar or ndarray of dtype `np.int64`
        Returns the input value(s) coerced to dtype `np.int64` assuming all
        were within ``atol`` of the nearest integer.

    Raises
    ------
    ValueError
        If any value is farther than ``atol`` from the nearest integer.

    Notes
    -----
    This operation calculates ``val`` modulo 1, which returns the mantissa of
    all values. Then all mantissas greater than 0.5 are subtracted from one.
    Finally, the absolute tolerance from zero is calculated. If it is less
    than ``atol`` for all value(s) in ``val``, they are rounded and returned
    in an integer array. Or, if ``val`` was a scalar, a NumPy scalar type is
    returned.

    If any value(s) are outside the specified tolerance, an informative error
    is raised.

    Examples
    --------
    >>> safe_as_int(7.0)
    7

    >>> safe_as_int([9, 4, 2.9999999999])
    array([9, 4, 3])

    >>> safe_as_int(53.1)
    Traceback (most recent call last):
        ...
    ValueError: Integer argument required but received 53.1, check inputs.

    >>> safe_as_int(53.01, atol=0.01)
    53

    """
    mod = np.asarray(val) % 1  # Extract mantissa

    # Distance to the nearest integer: values above 0.5 are closer to the
    # next integer up, so measure them from 1 instead of 0.
    mod = np.minimum(mod, 1 - mod)

    try:
        np.testing.assert_allclose(mod, 0, atol=atol)
    except AssertionError:
        raise ValueError(f'Integer argument required but received '
                         f'{val}, check inputs.') from None

    return np.round(val).astype(np.int64)


def check_shape_equality(im1, im2):
    """Raise a ValueError if the shapes of `im1` and `im2` do not match."""
    if im1.shape != im2.shape:
        raise ValueError('Input images must have the same dimensions.')
    return


def slice_at_axis(sl, axis):
    """
    Construct tuple of slices to slice an array in the given dimension.

    Parameters
    ----------
    sl : slice
        The slice for the given dimension.
    axis : int
        The axis to which `sl` is applied. All other dimensions are left
        "unsliced". Negative values count from the last axis.

    Returns
    -------
    sl : tuple of slices
        A tuple usable as an index that applies `sl` along `axis` and the
        full slice along every other axis (an ``Ellipsis`` stands for the
        remaining axes).

    Examples
    --------
    >>> slice_at_axis(slice(None, 3, -1), 1)
    (slice(None, None, None), slice(None, 3, -1), Ellipsis)
    >>> slice_at_axis(slice(None, 3, -1), -1)
    (Ellipsis, slice(None, 3, -1))
    """
    if axis < 0:
        # Anchor at the end: the Ellipsis absorbs the leading axes.
        return (...,) + (sl,) + (slice(None),) * (-axis - 1)
    return (slice(None),) * axis + (sl,) + (...,)


def reshape_nd(arr, ndim, dim):
    """Reshape a 1D array to have n dimensions, all singletons but one.

    Parameters
    ----------
    arr : array, shape (N,)
        Input array
    ndim : int
        Number of desired dimensions of reshaped array.
    dim : int
        Which dimension/axis will not be singleton-sized.

    Returns
    -------
    arr_reshaped : array, shape ([1, ...], N, [1,...])
        View of `arr` reshaped to the desired shape.

    Raises
    ------
    ValueError
        If `arr` is not one-dimensional.

    Examples
    --------
    >>> rng = np.random.default_rng()
    >>> arr = rng.random(7)
    >>> reshape_nd(arr, 2, 0).shape
    (7, 1)
    >>> reshape_nd(arr, 3, 1).shape
    (1, 7, 1)
    >>> reshape_nd(arr, 4, -1).shape
    (1, 1, 1, 7)
    """
    if arr.ndim != 1:
        raise ValueError("arr must be a 1D array")
    new_shape = [1] * ndim
    new_shape[dim] = -1
    return np.reshape(arr, new_shape)


def check_nD(array, ndim, arg_name='image'):
    """
    Verify an array meets the desired ndims and array isn't empty.

    Parameters
    ----------
    array : array-like
        Input array to be validated
    ndim : int or iterable of ints
        Allowable ndim or ndims for the array. Python and NumPy integers
        are both accepted.
    arg_name : str, optional
        The name of the array in the original function.

    Raises
    ------
    ValueError
        If `array` is empty or its number of dimensions is not allowed.

    """
    array = np.asanyarray(array)
    msg_incorrect_dim = "The parameter `%s` must be a %s-dimensional array"
    msg_empty_array = "The parameter `%s` cannot be an empty array"
    if isinstance(ndim, (int, np.integer)):
        ndim = [ndim]
    if array.size == 0:
        raise ValueError(msg_empty_array % (arg_name))
    if array.ndim not in ndim:
        raise ValueError(
            msg_incorrect_dim % (arg_name, '-or-'.join(str(n) for n in ndim))
        )


def convert_to_float(image, preserve_range):
    """Convert input image to float image with the appropriate range.

    Parameters
    ----------
    image : ndarray
        Input image.
    preserve_range : bool
        Determines if the range of the image should be kept or transformed
        using img_as_float. Also see
        https://scikit-image.org/docs/dev/user_guide/data_types.html

    Returns
    -------
    image : ndarray
        Transformed version of the input.

    Notes
    -----
    * Input images with `float32` data type are not upcast.
    * Input images with `float16` data type are cast to `float32` without
      any rescaling, regardless of `preserve_range`.

    """
    if image.dtype == np.float16:
        return image.astype(np.float32)
    if preserve_range:
        # Convert image to double only if it is not single or double
        # precision float
        if image.dtype.char not in 'df':
            image = image.astype(float)
    else:
        from ..util.dtype import img_as_float
        image = img_as_float(image)
    return image


def _validate_interpolation_order(image_dtype, order):
    """Validate and return spline interpolation's order.

    Parameters
    ----------
    image_dtype : dtype
        Image dtype.
    order : int, optional
        The order of the spline interpolation. The order has to be in
        the range 0-5. See `skimage.transform.warp` for detail.

    Returns
    -------
    order : int
        if input order is None, returns 0 if image_dtype is bool and 1
        otherwise. Otherwise, image_dtype is checked and input order
        is validated accordingly (order > 0 is not supported for bool
        image dtype)

    Raises
    ------
    ValueError
        If `order` is outside 0-5, or nonzero for a bool `image_dtype`.

    """

    if order is None:
        return 0 if image_dtype == bool else 1

    if order < 0 or order > 5:
        raise ValueError("Spline interpolation order has to be in the "
                         "range 0-5.")

    if image_dtype == bool and order != 0:
        raise ValueError(
            "Input image dtype is bool. Interpolation is not defined "
             "with bool data type. Please set order to 0 or explicitely "
             "cast input image to another data type.")

    return order


def _to_np_mode(mode):
    """Convert padding modes from `ndi.correlate` to `np.pad`.

    Modes without a listed translation are returned unchanged.
    """
    mode_translation_dict = dict(nearest='edge', reflect='symmetric',
                                 mirror='reflect')
    return mode_translation_dict.get(mode, mode)


def _to_ndimage_mode(mode):
    """Convert from `numpy.pad` mode name to the corresponding ndimage mode.

    Raises
    ------
    ValueError
        If `mode` has no ndimage equivalent in the translation table.
    """
    mode_translation_dict = dict(constant='constant', edge='nearest',
                                 symmetric='reflect', reflect='mirror',
                                 wrap='wrap')
    if mode not in mode_translation_dict:
        raise ValueError(
            f"Unknown mode: '{mode}', or cannot translate mode. The "
            "mode should be one of 'constant', 'edge', 'symmetric', "
            "'reflect', or 'wrap'. See the documentation of numpy.pad for "
            "more info.")
    return _fix_ndimage_mode(mode_translation_dict[mode])


def _fix_ndimage_mode(mode):
    """Return the grid variant of 'constant'/'wrap' on SciPy >= 1.6.0.

    Other modes, and all modes on older SciPy, are returned unchanged.
    """
    # SciPy 1.6.0 introduced grid variants of constant and wrap which
    # have less surprising behavior for images. Use these when available
    grid_modes = {'constant': 'grid-constant', 'wrap': 'grid-wrap'}
    if NumpyVersion(scipy.__version__) >= '1.6.0':
        mode = grid_modes.get(mode, mode)
    return mode


# Maps a dtype character code to the floating-point type used for it.
new_float_type = {
    # preserved types
    np.float32().dtype.char: np.float32,
    np.float64().dtype.char: np.float64,
    np.complex64().dtype.char: np.complex64,
    np.complex128().dtype.char: np.complex128,
    # altered types
    np.float16().dtype.char: np.float32,
    'g': np.float64,      # np.float128 ; doesn't exist on windows
    'G': np.complex128,   # np.complex256 ; doesn't exist on windows
}


def _supported_float_type(input_dtype, allow_complex=False):
    """Return an appropriate floating-point dtype for a given dtype.

    float32, float64, complex64, complex128 are preserved.
    float16 is promoted to float32.
    complex256 is demoted to complex128.
    Other types are cast to float64.

    Parameters
    ----------
    input_dtype : np.dtype or Iterable of np.dtype
        The input dtype. If a sequence of multiple dtypes is provided, each
        dtype is first converted to a supported floating point type and the
        final dtype is then determined by applying `np.result_type` on the
        sequence of supported floating point types.
    allow_complex : bool, optional
        If False, raise a ValueError on complex-valued inputs. Applies to
        every dtype when a sequence is given.

    Returns
    -------
    float_type : dtype
        Floating-point dtype for the image.

    Raises
    ------
    ValueError
        If a complex dtype is given and `allow_complex` is False.
    """
    if isinstance(input_dtype, Iterable) and not isinstance(input_dtype, str):
        # Forward allow_complex so complex members of the sequence are
        # accepted when the caller allows them.
        return np.result_type(
            *(_supported_float_type(d, allow_complex) for d in input_dtype)
        )
    input_dtype = np.dtype(input_dtype)
    if not allow_complex and input_dtype.kind == 'c':
        raise ValueError("complex valued input is not supported")
    return new_float_type.get(input_dtype.char, np.float64)


def identity(image, *args, **kwargs):
    """Returns the first argument unmodified."""
    return image