"""
Normalization class for Matplotlib that can be used to produce
colorbars.
"""

import inspect

import numpy as np
from numpy import ma

from .interval import (PercentileInterval, AsymmetricPercentileInterval,
                       ManualInterval, MinMaxInterval, BaseInterval)
from .stretch import (LinearStretch, SqrtStretch, PowerStretch, LogStretch,
                      AsinhStretch, BaseStretch)

try:
    import matplotlib  # pylint: disable=W0611
    from matplotlib.colors import Normalize
    from matplotlib import pyplot as plt
except ImportError:
    class Normalize:
        """
        Placeholder base class used when matplotlib cannot be imported.

        Instantiating it (or any subclass that calls its initializer)
        raises `ImportError`.
        """

        def __init__(self, *args, **kwargs):
            raise ImportError('matplotlib is required in order to use this '
                              'class.')


__all__ = ['ImageNormalize', 'simple_norm', 'imshow_norm']

__doctest_requires__ = {'*': ['matplotlib']}


class ImageNormalize(Normalize):
    """
    Normalization class to be used with Matplotlib.

    Parameters
    ----------
    data : array-like, optional
        The image array.  It is used to compute the ``vmin`` and/or
        ``vmax`` values only if ``vmin`` or ``vmax`` are not input.  If
        ``interval`` is input, the limits come from applying
        ``interval`` to ``data``; otherwise the minimum and maximum of
        the finite values in ``data`` are used.
    interval : `~astropy.visualization.BaseInterval` subclass instance, optional
        The interval object to apply to the input ``data`` to determine
        the ``vmin`` and ``vmax`` values.  ``data`` and ``interval`` are
        used to compute the vmin and/or vmax values only if ``vmin`` or
        ``vmax`` are not input.  If ``data`` is not input here, the
        interval is applied to the first values passed when the
        instance is called.
    vmin, vmax : float, optional
        The minimum and maximum levels to show for the data.  The
        ``vmin`` and ``vmax`` inputs override any calculated values from
        the ``interval`` and ``data`` inputs.
    stretch : `~astropy.visualization.BaseStretch` subclass instance
        The stretch object to apply to the data.  The default is
        `~astropy.visualization.LinearStretch`.
    clip : bool, optional
        If `True`, data values outside the [0:1] range are clipped to
        the [0:1] range.
    invalid : None or float, optional
        Value to assign NaN values generated by this class.  NaNs in the
        input ``data`` array are not changed.  For matplotlib
        normalization, the ``invalid`` value should map to the
        matplotlib colormap "under" value (i.e., any finite value < 0).
        If `None`, then NaN values are not replaced.  This keyword has
        no effect if ``clip=True``.

    Raises
    ------
    ValueError
        If ``stretch`` is `None`.
    TypeError
        If ``stretch`` is not a `~astropy.visualization.BaseStretch`
        instance, or if ``interval`` is neither `None` nor a
        `~astropy.visualization.BaseInterval` instance.
    """

    def __init__(self, data=None, interval=None, vmin=None, vmax=None,
                 stretch=LinearStretch(), clip=False, invalid=-1.0):
        # this super call checks for matplotlib
        super().__init__(vmin=vmin, vmax=vmax, clip=clip)

        self.vmin = vmin
        self.vmax = vmax

        if stretch is None:
            raise ValueError('stretch must be input')
        if not isinstance(stretch, BaseStretch):
            raise TypeError('stretch must be an instance of a BaseStretch '
                            'subclass')
        self.stretch = stretch

        if interval is not None and not isinstance(interval, BaseInterval):
            raise TypeError('interval must be an instance of a BaseInterval '
                            'subclass')
        self.interval = interval

        self.inverse_stretch = stretch.inverse
        self.clip = clip
        self.invalid = invalid

        # Fill in whichever of vmin/vmax is still None, if data was input.
        # Without data this is deferred to the first __call__.
        if data is not None:
            self._set_limits(data)

    def _set_limits(self, data):
        """
        Set ``self.vmin`` and/or ``self.vmax`` from ``data`` if unset.

        Limits that are already defined are never overwritten.  If an
        interval was supplied, the missing limits are taken from
        ``self.interval.get_limits(data)``; otherwise they are the
        minimum and maximum of the finite values in ``data``.

        Parameters
        ----------
        data : array-like
            The values from which to derive the missing limits.
        """
        if self.vmin is not None and self.vmax is not None:
            return

        if self.interval is None:
            # Boolean-mask indexing needs an array; ``asanyarray`` lets
            # lists through while leaving ndarray subclasses (e.g.
            # masked arrays) untouched.
            data = np.asanyarray(data)
            finite_values = data[np.isfinite(data)]
            if self.vmin is None:
                self.vmin = np.min(finite_values)
            if self.vmax is None:
                self.vmax = np.max(finite_values)
        else:
            _vmin, _vmax = self.interval.get_limits(data)
            if self.vmin is None:
                self.vmin = _vmin
            if self.vmax is None:
                self.vmax = _vmax

    def __call__(self, values, clip=None, invalid=None):
        """
        Transform values using this normalization.

        Parameters
        ----------
        values : array-like
            The input values.
        clip : bool, optional
            If `True`, values outside the [0:1] range are clipped to the
            [0:1] range.  If `None` then the ``clip`` value from the
            `ImageNormalize` instance is used (the default of which is
            `False`).
        invalid : None or float, optional
            Value to assign NaN values generated by this class.  NaNs in
            the input ``data`` array are not changed.  For matplotlib
            normalization, the ``invalid`` value should map to the
            matplotlib colormap "under" value (i.e., any finite value <
            0).  If `None`, then the `ImageNormalize` instance value is
            used.  This keyword has no effect if ``clip=True``.

        Returns
        -------
        result : `~numpy.ma.MaskedArray`
            The normalized and stretched values.  Scalar input yields a
            1-element array.  The mask of a masked-array input is kept
            only when clipping is off.
        """

        if clip is None:
            clip = self.clip

        if invalid is None:
            invalid = self.invalid

        if isinstance(values, ma.MaskedArray):
            if clip:
                mask = False
            else:
                mask = values.mask
            # Masked entries are replaced by vmax before normalizing.
            values = values.filled(self.vmax)
        else:
            mask = False

        # Make sure scalars get broadcast to 1-d
        if np.isscalar(values):
            values = np.array([values], dtype=float)
        else:
            # copy because of in-place operations after
            values = np.array(values, copy=True, dtype=float)

        # Define vmin and vmax if either is still None
        self._set_limits(values)

        # Normalize based on vmin and vmax
        np.subtract(values, self.vmin, out=values)
        np.true_divide(values, self.vmax - self.vmin, out=values)

        # Clip to the 0 to 1 range
        if clip:
            values = np.clip(values, 0., 1., out=values)

        # Stretch values; ``invalid`` is forwarded only to stretches
        # that declare support for it.
        if self.stretch._supports_invalid_kw:
            values = self.stretch(values, out=values, clip=False,
                                  invalid=invalid)
        else:
            values = self.stretch(values, out=values, clip=False)

        # Convert to masked array for matplotlib
        return ma.array(values, mask=mask)

    def inverse(self, values, invalid=None):
        """
        Map normalized values back to the original data range.

        Parameters
        ----------
        values : array-like
            The normalized (stretched) values.
        invalid : None or float, optional
            Passed through to the inverse stretch if that stretch
            supports the ``invalid`` keyword.  Unlike ``__call__``, a
            value of `None` is passed as-is and is not replaced by the
            instance's ``invalid`` value.

        Returns
        -------
        result
            The inverse-stretched values rescaled by
            ``values_norm * (vmax - vmin) + vmin``.
        """
        # Find unstretched values in range 0 to 1
        if self.inverse_stretch._supports_invalid_kw:
            values_norm = self.inverse_stretch(values, clip=False,
                                               invalid=invalid)
        else:
            values_norm = self.inverse_stretch(values, clip=False)

        # Scale to original range
        return values_norm * (self.vmax - self.vmin) + self.vmin


def simple_norm(data, stretch='linear', power=1.0, asinh_a=0.1, min_cut=None,
                max_cut=None, min_percent=None, max_percent=None,
                percent=None, clip=False, log_a=1000, invalid=-1.0):
    """
    Return a Normalization class that can be used for displaying images
    with Matplotlib.

    This function enables only a subset of image stretching functions
    available in `~astropy.visualization.mpl_normalize.ImageNormalize`.

    This function is used by the
    ``astropy.visualization.scripts.fits2bitmap`` script.

    The cut levels are chosen from the first of the following that is
    input: ``percent``; ``min_percent`` and/or ``max_percent``;
    ``min_cut`` and/or ``max_cut``.  If none are input, the image
    minimum and maximum are used.

    Parameters
    ----------
    data : ndarray
        The image array.

    stretch : {'linear', 'sqrt', 'power', 'log', 'asinh'}, optional
        The stretch function to apply to the image.  The default is
        'linear'.

    power : float, optional
        The power index for ``stretch='power'``.  The default is 1.0.

    asinh_a : float, optional
        For ``stretch='asinh'``, the value where the asinh curve
        transitions from linear to logarithmic behavior, expressed as a
        fraction of the normalized image.  Must be in the range between
        0 and 1.  The default is 0.1.

    min_cut : float, optional
        The pixel value of the minimum cut level.  Data values less than
        ``min_cut`` will be set to ``min_cut`` before stretching the
        image.  The default is the image minimum.  ``min_cut`` is
        ignored if ``percent``, ``min_percent``, or ``max_percent`` is
        input.

    max_cut : float, optional
        The pixel value of the maximum cut level.  Data values greater
        than ``max_cut`` will be set to ``max_cut`` before stretching
        the image.  The default is the image maximum.  ``max_cut`` is
        ignored if ``percent``, ``min_percent``, or ``max_percent`` is
        input.

    min_percent : float, optional
        The percentile value used to determine the pixel value of
        minimum cut level.  The default is 0.0.  ``min_percent`` is
        ignored if ``percent`` is input, and overrides ``min_cut`` and
        ``max_cut``.

    max_percent : float, optional
        The percentile value used to determine the pixel value of
        maximum cut level.  The default is 100.0.  ``max_percent`` is
        ignored if ``percent`` is input, and overrides ``min_cut`` and
        ``max_cut``.

    percent : float, optional
        The percentage of the image values used to determine the pixel
        values of the minimum and maximum cut levels.  The lower cut
        level will be set at the ``(100 - percent) / 2`` percentile,
        while the upper cut level will be set at the
        ``(100 + percent) / 2`` percentile.  If input, ``percent``
        overrides ``min_percent``, ``max_percent``, ``min_cut`` and
        ``max_cut``.

    clip : bool, optional
        If `True`, data values outside the [0:1] range are clipped to
        the [0:1] range.

    log_a : float, optional
        The log index for ``stretch='log'``.  The default is 1000.

    invalid : None or float, optional
        Value to assign NaN values generated by the normalization.  NaNs
        in the input ``data`` array are not changed.  For matplotlib
        normalization, the ``invalid`` value should map to the
        matplotlib colormap "under" value (i.e., any finite value < 0).
        If `None`, then NaN values are not replaced.  This keyword has
        no effect if ``clip=True``.

    Returns
    -------
    result : `ImageNormalize` instance
        An `ImageNormalize` instance that can be used for displaying
        images with Matplotlib.

    Raises
    ------
    ValueError
        If ``stretch`` is not one of the supported names.
    """

    # Precedence: percent, then min/max_percent, then min/max_cut.
    if percent is not None:
        interval = PercentileInterval(percent)
    elif min_percent is not None or max_percent is not None:
        interval = AsymmetricPercentileInterval(min_percent or 0.,
                                                max_percent or 100.)
    elif min_cut is not None or max_cut is not None:
        interval = ManualInterval(min_cut, max_cut)
    else:
        interval = MinMaxInterval()

    if stretch == 'linear':
        stretch = LinearStretch()
    elif stretch == 'sqrt':
        stretch = SqrtStretch()
    elif stretch == 'power':
        stretch = PowerStretch(power)
    elif stretch == 'log':
        stretch = LogStretch(log_a)
    elif stretch == 'asinh':
        stretch = AsinhStretch(asinh_a)
    else:
        raise ValueError(f'Unknown stretch: {stretch}.')

    # The limits are resolved here, so the returned norm carries fixed
    # vmin/vmax rather than the interval object.
    vmin, vmax = interval.get_limits(data)

    return ImageNormalize(vmin=vmin, vmax=vmax, stretch=stretch, clip=clip,
                          invalid=invalid)


# used in imshow_norm
_norm_sig = inspect.signature(ImageNormalize)


def imshow_norm(data, ax=None, **kwargs):
    """ A convenience function to call matplotlib's `matplotlib.pyplot.imshow`
    function, using an `ImageNormalize` object as the normalization.

    Parameters
    ----------
    data : 2D or 3D array-like
        The data to show. Can be whatever `~matplotlib.pyplot.imshow` and
        `ImageNormalize` both accept. See `~matplotlib.pyplot.imshow`.
    ax : None or `~matplotlib.axes.Axes`, optional
        If None, use pyplot's imshow.  Otherwise, calls ``imshow`` method of
        the supplied axes.
    kwargs : dict, optional
        All other keyword arguments are parsed first by the
        `ImageNormalize` initializer, then to
        `~matplotlib.pyplot.imshow`.

    Returns
    -------
    result : tuple
        A tuple containing the `~matplotlib.image.AxesImage` generated
        by `~matplotlib.pyplot.imshow` as well as the `ImageNormalize`
        instance.

    Raises
    ------
    ValueError
        If the ``X`` or ``norm`` keyword is given.

    Notes
    -----
    The ``norm`` matplotlib keyword is not supported.

    Examples
    --------
    .. plot::
        :include-source:

        import numpy as np
        import matplotlib.pyplot as plt
        from astropy.visualization import (imshow_norm, MinMaxInterval,
                                           SqrtStretch)

        # Generate and display a test image
        image = np.arange(65536).reshape((256, 256))
        fig = plt.figure()
        ax = fig.add_subplot(1, 1, 1)
        im, norm = imshow_norm(image, ax, origin='lower',
                               interval=MinMaxInterval(),
                               stretch=SqrtStretch())
        fig.colorbar(im)
    """
    if 'X' in kwargs:
        raise ValueError('Cannot give both ``X`` and ``data``')

    if 'norm' in kwargs:
        raise ValueError('There is no point in using imshow_norm if you give '
                         'the ``norm`` keyword - use imshow directly if you '
                         'want that.')

    imshow_kwargs = dict(kwargs)

    # Keywords named in the ImageNormalize signature go to the norm;
    # everything left over is forwarded to imshow.
    norm_kwargs = {'data': data}
    for pname in _norm_sig.parameters:
        if pname in kwargs:
            norm_kwargs[pname] = imshow_kwargs.pop(pname)

    imshow_kwargs['norm'] = ImageNormalize(**norm_kwargs)

    if ax is None:
        imshow_result = plt.imshow(data, **imshow_kwargs)
    else:
        imshow_result = ax.imshow(data, **imshow_kwargs)

    return imshow_result, imshow_kwargs['norm']