1"""
2Normalization class for Matplotlib that can be used to produce
3colorbars.
4"""
5
6import inspect
7
8import numpy as np
9from numpy import ma
10
11from .interval import (PercentileInterval, AsymmetricPercentileInterval,
12                       ManualInterval, MinMaxInterval, BaseInterval)
13from .stretch import (LinearStretch, SqrtStretch, PowerStretch, LogStretch,
14                      AsinhStretch, BaseStretch)
15
# Import matplotlib lazily-optional: the module must stay importable even
# when matplotlib is not installed.
try:
    import matplotlib  # pylint: disable=W0611
    from matplotlib.colors import Normalize
    from matplotlib import pyplot as plt
except ImportError:
    # Placeholder base class used when matplotlib is missing.  Any attempt to
    # instantiate it (including via super().__init__ in a subclass) raises
    # ImportError.  Note that ``plt`` is not defined in this branch.
    class Normalize:
        """Stand-in for `matplotlib.colors.Normalize` without matplotlib."""

        def __init__(self, *args, **kwargs):
            raise ImportError('matplotlib is required in order to use this '
                              'class.')
25
26
# Public names exported by ``from ... import *``.
__all__ = ['ImageNormalize', 'simple_norm', 'imshow_norm']

# NOTE(review): presumably consumed by the doctest plugin so that every
# doctest in this module ('*') is skipped when matplotlib is unavailable.
__doctest_requires__ = {'*': ['matplotlib']}
30
31
class ImageNormalize(Normalize):
    """
    Normalization class to be used with Matplotlib.

    Parameters
    ----------
    data : ndarray, optional
        The image array.  This input is used only if ``interval`` is
        also input.  ``data`` and ``interval`` are used to compute the
        vmin and/or vmax values only if ``vmin`` or ``vmax`` are not
        input.
    interval : `~astropy.visualization.BaseInterval` subclass instance, optional
        The interval object to apply to the input ``data`` to determine
        the ``vmin`` and ``vmax`` values.  This input is used only if
        ``data`` is also input.  ``data`` and ``interval`` are used to
        compute the vmin and/or vmax values only if ``vmin`` or ``vmax``
        are not input.
    vmin, vmax : float, optional
        The minimum and maximum levels to show for the data.  The
        ``vmin`` and ``vmax`` inputs override any calculated values from
        the ``interval`` and ``data`` inputs.
    stretch : `~astropy.visualization.BaseStretch` subclass instance
        The stretch object to apply to the data.  The default is
        `~astropy.visualization.LinearStretch`.
    clip : bool, optional
        If `True`, data values outside the [0:1] range are clipped to
        the [0:1] range.
    invalid : None or float, optional
        Value to assign NaN values generated by this class.  NaNs in the
        input ``data`` array are not changed.  For matplotlib
        normalization, the ``invalid`` value should map to the
        matplotlib colormap "under" value (i.e., any finite value < 0).
        If `None`, then NaN values are not replaced.  This keyword has
        no effect if ``clip=True``.

    Raises
    ------
    ValueError
        If ``stretch`` is `None`.
    TypeError
        If ``stretch`` is not a `BaseStretch` instance, or ``interval`` is
        given but is not a `BaseInterval` instance.
    """

    def __init__(self, data=None, interval=None, vmin=None, vmax=None,
                 stretch=LinearStretch(), clip=False, invalid=-1.0):
        # this super call checks for matplotlib
        super().__init__(vmin=vmin, vmax=vmax, clip=clip)

        self.vmin = vmin
        self.vmax = vmax

        if stretch is None:
            raise ValueError('stretch must be input')
        if not isinstance(stretch, BaseStretch):
            raise TypeError('stretch must be an instance of a BaseStretch '
                            'subclass')
        self.stretch = stretch

        if interval is not None and not isinstance(interval, BaseInterval):
            raise TypeError('interval must be an instance of a BaseInterval '
                            'subclass')
        self.interval = interval

        self.inverse_stretch = stretch.inverse
        self.clip = clip
        self.invalid = invalid

        # Define vmin and vmax if not None and data was input
        if data is not None:
            self._set_limits(data)

    def _set_limits(self, data):
        """
        Fill in whichever of ``vmin``/``vmax`` is still `None` from ``data``.

        Limits that are already set are never overwritten.  When no
        ``interval`` was given, the minimum/maximum of the finite values
        of ``data`` are used; otherwise ``interval.get_limits(data)``.
        """
        if self.vmin is not None and self.vmax is not None:
            return

        # Define vmin and vmax from the interval class if not None
        if self.interval is None:
            if self.vmin is None:
                self.vmin = np.min(data[np.isfinite(data)])
            if self.vmax is None:
                self.vmax = np.max(data[np.isfinite(data)])
        else:
            _vmin, _vmax = self.interval.get_limits(data)
            if self.vmin is None:
                self.vmin = _vmin
            if self.vmax is None:
                self.vmax = _vmax

    def __call__(self, values, clip=None, invalid=None):
        """
        Transform values using this normalization.

        Parameters
        ----------
        values : array-like
            The input values.  If a `~numpy.ma.MaskedArray` is given, any
            still-undefined ``vmin``/``vmax`` are computed from the
            unmasked values only.  Masked entries are then filled with
            ``vmax`` before normalizing, and the mask is kept in the output
            unless ``clip`` is `True`.
        clip : bool, optional
            If `True`, values outside the [0:1] range are clipped to the
            [0:1] range.  If `None` then the ``clip`` value from the
            `ImageNormalize` instance is used (the default of which is
            `False`).
        invalid : None or float, optional
            Value to assign NaN values generated by this class.  NaNs in
            the input ``data`` array are not changed.  For matplotlib
            normalization, the ``invalid`` value should map to the
            matplotlib colormap "under" value (i.e., any finite value <
            0).  If `None`, then the `ImageNormalize` instance value is
            used.  This keyword has no effect if ``clip=True``.

        Returns
        -------
        result : `~numpy.ma.MaskedArray`
            The normalized and stretched values (scalars are returned as a
            1-element array).
        """

        if clip is None:
            clip = self.clip

        if invalid is None:
            invalid = self.invalid

        if isinstance(values, ma.MaskedArray):
            # Derive any missing limits from the unmasked values *before*
            # filling.  Otherwise, when vmax is still None, filled(None)
            # would insert the dtype's default fill value (e.g. 1e20) and
            # those masked pixels would then determine vmin/vmax.
            unmasked = np.asarray(values.compressed(), dtype=float)
            if unmasked.size > 0:
                self._set_limits(unmasked)

            mask = False if clip else values.mask
            values = values.filled(self.vmax)
        else:
            mask = False

        # Make sure scalars get broadcast to 1-d
        if np.isscalar(values):
            values = np.array([values], dtype=float)
        else:
            # copy because of in-place operations after
            values = np.array(values, copy=True, dtype=float)

        # Define vmin and vmax if not None (no-op if already set above)
        self._set_limits(values)

        # Normalize based on vmin and vmax
        np.subtract(values, self.vmin, out=values)
        np.true_divide(values, self.vmax - self.vmin, out=values)

        # Clip to the 0 to 1 range
        if clip:
            values = np.clip(values, 0., 1., out=values)

        # Stretch values; only pass ``invalid`` to stretches that accept it
        if self.stretch._supports_invalid_kw:
            values = self.stretch(values, out=values, clip=False,
                                  invalid=invalid)
        else:
            values = self.stretch(values, out=values, clip=False)

        # Convert to masked array for matplotlib
        return ma.array(values, mask=mask)

    def inverse(self, values, invalid=None):
        """
        Map normalized values back to the original data range.

        Applies the inverse stretch (without clipping) and then rescales
        from [0:1] to [vmin:vmax].  ``invalid`` is forwarded to the inverse
        stretch only if it supports that keyword; unlike `__call__`, a
        `None` value is passed through as-is rather than replaced by the
        instance's ``invalid`` setting.
        """
        # Find unstretched values in range 0 to 1
        if self.inverse_stretch._supports_invalid_kw:
            values_norm = self.inverse_stretch(values, clip=False,
                                               invalid=invalid)
        else:
            values_norm = self.inverse_stretch(values, clip=False)

        # Scale to original range
        return values_norm * (self.vmax - self.vmin) + self.vmin
188
189
def simple_norm(data, stretch='linear', power=1.0, asinh_a=0.1, min_cut=None,
                max_cut=None, min_percent=None, max_percent=None,
                percent=None, clip=False, log_a=1000, invalid=-1.0):
    """
    Return a Normalization class that can be used for displaying images
    with Matplotlib.

    This function enables only a subset of image stretching functions
    available in `~astropy.visualization.mpl_normalize.ImageNormalize`.

    This function is used by the
    ``astropy.visualization.scripts.fits2bitmap`` script.

    Parameters
    ----------
    data : ndarray
        The image array.

    stretch : {'linear', 'sqrt', 'power', 'log', 'asinh'}, optional
        The stretch function to apply to the image.  The default is
        'linear'.

    power : float, optional
        The power index for ``stretch='power'``.  The default is 1.0.

    asinh_a : float, optional
        For ``stretch='asinh'``, the value where the asinh curve
        transitions from linear to logarithmic behavior, expressed as a
        fraction of the normalized image.  Must be in the range between
        0 and 1.  The default is 0.1.

    min_cut : float, optional
        The pixel value of the minimum cut level.  Data values less than
        ``min_cut`` will be set to ``min_cut`` before stretching the
        image.  The default is the image minimum.  ``min_cut`` overrides
        ``min_percent`` and ``percent``.

    max_cut : float, optional
        The pixel value of the maximum cut level.  Data values greater
        than ``max_cut`` will be set to ``max_cut`` before stretching the
        image.  The default is the image maximum.  ``max_cut`` overrides
        ``max_percent`` and ``percent``.

    min_percent : float, optional
        The percentile value used to determine the pixel value of
        minimum cut level.  The default is 0.0.  ``min_percent``
        overrides ``percent``.

    max_percent : float, optional
        The percentile value used to determine the pixel value of
        maximum cut level.  The default is 100.0.  ``max_percent``
        overrides ``percent``.

    percent : float, optional
        The percentage of the image values used to determine the pixel
        values of the minimum and maximum cut levels.  The lower cut
        level will be set at the ``(100 - percent) / 2`` percentile, while
        the upper cut level will be set at the ``(100 + percent) / 2``
        percentile.  The default is 100.0.  ``percent`` is ignored if
        either ``min_percent`` or ``max_percent`` is input.

    clip : bool, optional
        If `True`, data values outside the [0:1] range are clipped to
        the [0:1] range.

    log_a : float, optional
        The log index for ``stretch='log'``. The default is 1000.

    invalid : None or float, optional
        Value to assign NaN values generated by the normalization.  NaNs
        in the input ``data`` array are not changed.  For matplotlib
        normalization, the ``invalid`` value should map to the
        matplotlib colormap "under" value (i.e., any finite value < 0).
        If `None`, then NaN values are not replaced.  This keyword has
        no effect if ``clip=True``.

    Returns
    -------
    result : `ImageNormalize` instance
        An `ImageNormalize` instance that can be used for displaying
        images with Matplotlib.

    Raises
    ------
    ValueError
        If ``stretch`` is not one of the supported names.
    """

    # Choose how the limits are derived from the data, honoring the
    # documented precedence: cut > min/max_percent > percent > data range.
    if min_cut is not None and max_cut is not None:
        # Both limits are explicit; the data need not be inspected.
        interval = ManualInterval(min_cut, max_cut)
    elif min_percent is not None or max_percent is not None:
        # Explicit None checks so that a percentile of 0 is not replaced
        # by the default.
        interval = AsymmetricPercentileInterval(
            0. if min_percent is None else min_percent,
            100. if max_percent is None else max_percent)
    elif percent is not None:
        interval = PercentileInterval(percent)
    elif min_cut is not None or max_cut is not None:
        interval = ManualInterval(min_cut, max_cut)
    else:
        interval = MinMaxInterval()

    if stretch == 'linear':
        stretch = LinearStretch()
    elif stretch == 'sqrt':
        stretch = SqrtStretch()
    elif stretch == 'power':
        stretch = PowerStretch(power)
    elif stretch == 'log':
        stretch = LogStretch(log_a)
    elif stretch == 'asinh':
        stretch = AsinhStretch(asinh_a)
    else:
        raise ValueError(f'Unknown stretch: {stretch}.')

    vmin, vmax = interval.get_limits(data)

    # Explicit cut levels override any percentile-derived limit.
    if min_cut is not None:
        vmin = min_cut
    if max_cut is not None:
        vmax = max_cut

    return ImageNormalize(vmin=vmin, vmax=vmax, stretch=stretch, clip=clip,
                          invalid=invalid)
300
301
# Signature of the ImageNormalize initializer, computed once at import time.
# imshow_norm uses its parameter names to decide which keyword arguments are
# routed to ImageNormalize instead of to imshow.
_norm_sig = inspect.signature(ImageNormalize)
304
305
def imshow_norm(data, ax=None, **kwargs):
    """ A convenience function to call matplotlib's `matplotlib.pyplot.imshow`
    function, using an `ImageNormalize` object as the normalization.

    Parameters
    ----------
    data : 2D or 3D array-like
        The data to show. Can be whatever `~matplotlib.pyplot.imshow` and
        `ImageNormalize` both accept. See `~matplotlib.pyplot.imshow`.
    ax : None or `~matplotlib.axes.Axes`, optional
        If None, use pyplot's imshow.  Otherwise, calls ``imshow`` method of
        the supplied axes.
    kwargs : dict, optional
        Keyword arguments whose names match a parameter of the
        `ImageNormalize` initializer are passed to it; all remaining
        keyword arguments are passed to `~matplotlib.pyplot.imshow`.

    Returns
    -------
    result : tuple
        A tuple containing the `~matplotlib.image.AxesImage` generated
        by `~matplotlib.pyplot.imshow` as well as the `ImageNormalize`
        instance.

    Raises
    ------
    ValueError
        If ``X`` or ``norm`` is passed as a keyword argument.

    Notes
    -----
    The ``norm`` matplotlib keyword is not supported.

    Examples
    --------
    .. plot::
        :include-source:

        import numpy as np
        import matplotlib.pyplot as plt
        from astropy.visualization import (imshow_norm, MinMaxInterval,
                                           SqrtStretch)

        # Generate and display a test image
        image = np.arange(65536).reshape((256, 256))
        fig = plt.figure()
        ax = fig.add_subplot(1, 1, 1)
        im, norm = imshow_norm(image, ax, origin='lower',
                               interval=MinMaxInterval(),
                               stretch=SqrtStretch())
        fig.colorbar(im)
    """
    # ``X`` is imshow's name for the image, which is already ``data`` here.
    if 'X' in kwargs:
        raise ValueError('Cannot give both ``X`` and ``data``')

    if 'norm' in kwargs:
        raise ValueError('There is no point in using imshow_norm if you give '
                         'the ``norm`` keyword - use imshow directly if you '
                         'want that.')

    # Split the keywords in one pass: names understood by ImageNormalize go
    # to the normalization, everything else is forwarded to imshow.
    norm_options = {'data': data}
    plot_options = {}
    for key, value in kwargs.items():
        target = norm_options if key in _norm_sig.parameters else plot_options
        target[key] = value

    normalization = ImageNormalize(**norm_options)
    plot_options['norm'] = normalization

    plotter = plt if ax is None else ax
    image = plotter.imshow(data, **plot_options)

    return image, normalization
376