1# Licensed under a 3-clause BSD style license - see LICENSE.rst
2
3"""
4Classes that deal with stretching, i.e. mapping a range of [0:1] values onto
5another set of [0:1] values with a transformation
6"""
7
8import numpy as np
9
10from .transform import BaseTransform
11from .transform import CompositeTransform
12
13
14__all__ = ["BaseStretch", "LinearStretch", "SqrtStretch", "PowerStretch",
15           "PowerDistStretch", "SquaredStretch", "LogStretch", "AsinhStretch",
16           "SinhStretch", "HistEqStretch", "ContrastBiasStretch",
17           "CompositeStretch"]
18
19
def _logn(n, x, out=None):
    """
    Return the base-``n`` logarithm of ``x``.

    ``numpy.lib.scimath.logn`` has no ``out=`` argument, hence this small
    helper.  When ``out`` is given, the result is written into it and
    ``out`` itself is returned; otherwise a newly computed result is
    returned.
    """
    if out is not None:
        # In-place path: natural log first, then rescale to base n.
        np.log(x, out=out)
        np.true_divide(out, np.log(n), out=out)
        return out

    natural_log = np.log(x)
    return natural_log / np.log(n)
29
30
def _prepare(values, clip=True, out=None):
    """
    Prepare the data by optionally clipping and copying, and return the
    array that should be subsequently used for in-place calculations.

    Parameters
    ----------
    values : array-like
        The input values.
    clip : bool, optional
        If `True` (default), the values are clipped to the [0:1] range.
    out : ndarray, optional
        If specified, the prepared values are written into this array,
        which is then returned.  If `None`, a newly created result is
        returned and ``values`` is left untouched.

    Returns
    -------
    result : ndarray
        ``out`` if it was given, otherwise a newly created result.
    """
    if clip:
        return np.clip(values, 0., 1., out=out)

    if out is None:
        return np.array(values, copy=True)

    # ``out[...]`` rather than ``out[:]``: a slice cannot index a
    # zero-dimensional array, whereas Ellipsis works for any dimension.
    out[...] = np.asarray(values)
    return out
45
46
class BaseStretch(BaseTransform):
    """
    Base class for the stretch classes, which, when called with an array
    of values in the range [0:1], return a transformed array of values,
    also in the range [0:1].

    The ``__call__`` method and ``inverse`` property defined here have no
    body beyond their docstrings (they return `None`); concrete stretch
    classes override them.
    """

    @property
    def _supports_invalid_kw(self):
        # Overridden to return True by stretches whose ``__call__``
        # accepts the ``invalid`` keyword.
        return False

    def __add__(self, other):
        # Note the argument order: ``other`` is passed as the first
        # stretch of the composite and ``self`` as the second.
        return CompositeStretch(other, self)

    def __call__(self, values, clip=True, out=None):
        """
        Transform values using this stretch.

        Parameters
        ----------
        values : array-like
            The input values, which should already be normalized to the
            [0:1] range.
        clip : bool, optional
            If `True` (default), values outside the [0:1] range are
            clipped to the [0:1] range.
        out : ndarray, optional
            If specified, the output values will be placed in this array
            (typically used for in-place calculations).

        Returns
        -------
        result : ndarray
            The transformed values.
        """

    @property
    def inverse(self):
        """A stretch object that performs the inverse operation."""
86
87
class LinearStretch(BaseStretch):
    """
    A linear stretch with a slope and offset.

    The stretch is given by:

    .. math::
        y = slope x + intercept

    Parameters
    ----------
    slope : float, optional
        The ``slope`` parameter used in the above formula.  Default is 1.
    intercept : float, optional
        The ``intercept`` parameter used in the above formula.  Default is 0.
    """

    def __init__(self, slope=1, intercept=0):
        super().__init__()
        self.slope = slope
        self.intercept = intercept

    def __call__(self, values, clip=True, out=None):
        """
        Transform values using this stretch.

        Parameters
        ----------
        values : array-like
            The input values, which should already be normalized to the
            [0:1] range.
        clip : bool, optional
            If `True` (default), the input values are clipped to the
            [0:1] range before the stretch is applied.
        out : ndarray, optional
            If specified, the output values will be placed in this array
            (typically used for in-place calculations).

        Returns
        -------
        result : ndarray
            The transformed values.
        """
        result = _prepare(values, clip=clip, out=out)
        slope, intercept = self.slope, self.intercept

        # Skip the arithmetic entirely for the identity slope/intercept.
        if slope != 1:
            np.multiply(result, slope, out=result)
        if intercept != 0:
            np.add(result, intercept, out=result)

        return result

    @property
    def inverse(self):
        """A stretch object that performs the inverse operation."""
        # Solving y = slope * x + intercept for x gives
        # x = y / slope - intercept / slope.
        slope = self.slope
        return LinearStretch(1. / slope, -self.intercept / slope)
122
123
class SqrtStretch(BaseStretch):
    r"""
    A square root stretch.

    The stretch is given by:

    .. math::
        y = \sqrt{x}
    """

    @property
    def _supports_invalid_kw(self):
        return True

    def __call__(self, values, clip=True, out=None, invalid=None):
        """
        Transform values using this stretch.

        Parameters
        ----------
        values : array-like
            The input values, which should already be normalized to the
            [0:1] range.
        clip : bool, optional
            If `True` (default), values outside the [0:1] range are
            clipped to the [0:1] range.
        out : ndarray, optional
            If specified, the output values will be placed in this array
            (typically used for in-place calculations).
        invalid : None or float, optional
            Value to assign NaN values generated by this class.  NaNs in
            the input ``values`` array are not changed.  This option is
            generally used with matplotlib normalization classes, where
            the ``invalid`` value should map to the matplotlib colormap
            "under" value (i.e., any finite value < 0).  If `None`, then
            NaN values are not replaced.  This keyword has no effect if
            ``clip=True``.

        Returns
        -------
        result : ndarray
            The transformed values.
        """
        result = _prepare(values, clip=clip, out=out)

        # Replacement only makes sense when negatives can survive the
        # preparation step (i.e. no clipping) and a fill value is given.
        fill_negative = invalid is not None and not clip

        with np.errstate(invalid='ignore'):
            # The mask must be taken before the in-place sqrt overwrites
            # the negative inputs with NaN.
            negative = (result < 0) if fill_negative else None
            np.sqrt(result, out=result)

        if fill_negative:
            # Only NaNs produced here (from negative inputs) are
            # replaced; NaNs already present in the input are kept.
            result[negative] = invalid

        return result

    @property
    def inverse(self):
        """A stretch object that performs the inverse operation."""
        return PowerStretch(2)
186
187
class PowerStretch(BaseStretch):
    r"""
    A power stretch.

    The stretch is given by:

    .. math::
        y = x^a

    Parameters
    ----------
    a : float
        The power index (see the above formula).  ``a`` must be greater
        than 0.

    Raises
    ------
    ValueError
        If ``a`` is not greater than 0.
    """

    @property
    def _supports_invalid_kw(self):
        return True

    def __init__(self, a):
        super().__init__()
        if a <= 0:
            raise ValueError("a must be > 0")
        self.power = a

    def __call__(self, values, clip=True, out=None, invalid=None):
        """
        Transform values using this stretch.

        Parameters
        ----------
        values : array-like
            The input values, which should already be normalized to the
            [0:1] range.
        clip : bool, optional
            If `True` (default), values outside the [0:1] range are
            clipped to the [0:1] range.
        out : ndarray, optional
            If specified, the output values will be placed in this array
            (typically used for in-place calculations).
        invalid : None or float, optional
            Value to assign NaN values generated by this class.  NaNs in
            the input ``values`` array are not changed.  This option is
            generally used with matplotlib normalization classes, where
            the ``invalid`` value should map to the matplotlib colormap
            "under" value (i.e., any finite value < 0).  If `None`, then
            NaN values are not replaced.  This keyword has no effect if
            ``clip=True``.

        Returns
        -------
        result : ndarray
            The transformed values.
        """

        values = _prepare(values, clip=clip, out=out)
        replace_invalid = not clip and invalid is not None
        with np.errstate(invalid='ignore'):
            if replace_invalid:
                # Must be recorded before the in-place power overwrites
                # the inputs.
                negative = (values < 0)
            np.power(values, self.power, out=values)

        if replace_invalid:
            # A negative input raised to any non-integral power yields
            # NaN (not only for powers below 1), while integral powers
            # give finite results.  Replace exactly the NaNs generated
            # here: negative inputs whose result is NaN.  NaNs already
            # present in the input are not negative and so are kept.
            values[negative & np.isnan(values)] = invalid

        return values

    @property
    def inverse(self):
        """A stretch object that performs the inverse operation."""
        return PowerStretch(1. / self.power)
264
265
class PowerDistStretch(BaseStretch):
    r"""
    An alternative power stretch.

    The stretch is given by:

    .. math::
        y = \frac{a^x - 1}{a - 1}

    Parameters
    ----------
    a : float, optional
        The ``a`` parameter used in the above formula.  ``a`` must be
        greater than or equal to 0, but cannot be set to 1.  Default is
        1000.

    Raises
    ------
    ValueError
        If ``a`` is negative or equal to 1.
    """

    def __init__(self, a=1000.0):
        # a == 1 would make the denominator (a - 1) vanish.
        if a < 0 or a == 1:
            raise ValueError("a must be >= 0, but cannot be set to 1")
        super().__init__()
        self.exp = a

    def __call__(self, values, clip=True, out=None):
        """
        Transform values using this stretch.

        Parameters
        ----------
        values : array-like
            The input values, which should already be normalized to the
            [0:1] range.
        clip : bool, optional
            If `True` (default), values outside the [0:1] range are
            clipped to the [0:1] range.
        out : ndarray, optional
            If specified, the output values will be placed in this array
            (typically used for in-place calculations).

        Returns
        -------
        result : ndarray
            The transformed values.
        """
        result = _prepare(values, clip=clip, out=out)
        base = self.exp

        # (a**x - 1) / (a - 1), evaluated step by step in place.
        np.power(base, result, out=result)
        np.subtract(result, 1, out=result)
        np.true_divide(result, base - 1.0, out=result)
        return result

    @property
    def inverse(self):
        """A stretch object that performs the inverse operation."""
        return InvertedPowerDistStretch(a=self.exp)
300
301
class InvertedPowerDistStretch(BaseStretch):
    r"""
    Inverse transformation for
    `~astropy.visualization.PowerDistStretch`.

    The stretch is given by:

    .. math::
        y = \frac{\log(x (a-1) + 1)}{\log a}

    Parameters
    ----------
    a : float, optional
        The ``a`` parameter used in the above formula.  ``a`` must be
        greater than or equal to 0, but cannot be set to 1.  Default is
        1000.

    Raises
    ------
    ValueError
        If ``a`` is negative or equal to 1.
    """

    def __init__(self, a=1000.0):
        # a == 1 is excluded: log(a) in the denominator would be zero.
        if a < 0 or a == 1:
            raise ValueError("a must be >= 0, but cannot be set to 1")
        super().__init__()
        self.exp = a

    def __call__(self, values, clip=True, out=None):
        """
        Transform values using this stretch.

        Parameters
        ----------
        values : array-like
            The input values, which should already be normalized to the
            [0:1] range.
        clip : bool, optional
            If `True` (default), values outside the [0:1] range are
            clipped to the [0:1] range.
        out : ndarray, optional
            If specified, the output values will be placed in this array
            (typically used for in-place calculations).

        Returns
        -------
        result : ndarray
            The transformed values.
        """
        result = _prepare(values, clip=clip, out=out)
        base = self.exp

        np.multiply(result, base - 1.0, out=result)
        np.add(result, 1, out=result)

        # Base-``a`` logarithm, computed in place as ln(.) / ln(a).
        np.log(result, out=result)
        np.true_divide(result, np.log(base), out=result)
        return result

    @property
    def inverse(self):
        """A stretch object that performs the inverse operation."""
        return PowerDistStretch(a=self.exp)
337
338
class SquaredStretch(PowerStretch):
    r"""
    A convenience class for a power stretch of 2.

    The stretch is given by:

    .. math::
        y = x^2
    """

    def __init__(self):
        # Fix the power index of the parent class at 2.
        super().__init__(a=2)

    @property
    def inverse(self):
        """A stretch object that performs the inverse operation."""
        # Overrides the parent's inverse to return the dedicated
        # square-root class.
        return SqrtStretch()
356
357
class LogStretch(BaseStretch):
    r"""
    A log stretch.

    The stretch is given by:

    .. math::
        y = \frac{\log{(a x + 1)}}{\log{(a + 1)}}

    Parameters
    ----------
    a : float
        The ``a`` parameter used in the above formula.  ``a`` must be
        greater than 0.  Default is 1000.

    Raises
    ------
    ValueError
        If ``a`` is not greater than 0.
    """

    @property
    def _supports_invalid_kw(self):
        return True

    def __init__(self, a=1000.0):
        super().__init__()
        # a == 0 would make the denominator log(a + 1) vanish.
        if a <= 0:
            raise ValueError("a must be > 0")
        self.exp = a

    def __call__(self, values, clip=True, out=None, invalid=None):
        """
        Transform values using this stretch.

        Parameters
        ----------
        values : array-like
            The input values, which should already be normalized to the
            [0:1] range.
        clip : bool, optional
            If `True` (default), values outside the [0:1] range are
            clipped to the [0:1] range.
        out : ndarray, optional
            If specified, the output values will be placed in this array
            (typically used for in-place calculations).
        invalid : None or float, optional
            Value to assign NaN values generated by this class.  NaNs in
            the input ``values`` array are not changed.  This option is
            generally used with matplotlib normalization classes, where
            the ``invalid`` value should map to the matplotlib colormap
            "under" value (i.e., any finite value < 0).  If `None`, then
            NaN values are not replaced.  This keyword has no effect if
            ``clip=True``.

        Returns
        -------
        result : ndarray
            The transformed values.
        """
        result = _prepare(values, clip=clip, out=out)
        a = self.exp
        fill_negative = invalid is not None and not clip

        with np.errstate(invalid='ignore'):
            # Take the mask before the in-place operations overwrite
            # the inputs.
            negative = (result < 0) if fill_negative else None

            # log(a * x + 1) / log(a + 1), evaluated step by step.
            np.multiply(result, a, out=result)
            np.add(result, 1., out=result)
            np.log(result, out=result)
            np.true_divide(result, np.log(a + 1.), out=result)

        if fill_negative:
            # Every negative input is set to ``invalid`` here; NaNs that
            # were already present in the input are left as they are.
            result[negative] = invalid

        return result

    @property
    def inverse(self):
        """A stretch object that performs the inverse operation."""
        return InvertedLogStretch(self.exp)
435
436
class InvertedLogStretch(BaseStretch):
    r"""
    Inverse transformation for `~astropy.visualization.LogStretch`.

    The stretch is given by:

    .. math::
        y = \frac{e^{x \log{(a + 1)}} - 1}{a} \\
        y = \frac{(a + 1)^{x} - 1}{a}

    Parameters
    ----------
    a : float, optional
        The ``a`` parameter used in the above formula.  ``a`` must be
        greater than 0.  Default is 1000.

    Raises
    ------
    ValueError
        If ``a`` is not greater than 0.
    """

    # ``a`` defaults to 1000, as documented above and as in LogStretch;
    # passing it explicitly (positionally or by keyword) still works.
    def __init__(self, a=1000.0):
        super().__init__()
        if a <= 0:  # singularity
            raise ValueError("a must be > 0")
        self.exp = a

    def __call__(self, values, clip=True, out=None):
        """
        Transform values using this stretch.

        Parameters
        ----------
        values : array-like
            The input values, which should already be normalized to the
            [0:1] range.
        clip : bool, optional
            If `True` (default), values outside the [0:1] range are
            clipped to the [0:1] range.
        out : ndarray, optional
            If specified, the output values will be placed in this array
            (typically used for in-place calculations).

        Returns
        -------
        result : ndarray
            The transformed values.
        """
        values = _prepare(values, clip=clip, out=out)
        # (exp(x * log(a + 1)) - 1) / a, evaluated step by step in place.
        np.multiply(values, np.log(self.exp + 1.), out=values)
        np.exp(values, out=values)
        np.subtract(values, 1., out=values)
        np.true_divide(values, self.exp, out=values)
        return values

    @property
    def inverse(self):
        """A stretch object that performs the inverse operation."""
        return LogStretch(self.exp)
472
473
class AsinhStretch(BaseStretch):
    r"""
    An asinh stretch.

    The stretch is given by:

    .. math::
        y = \frac{{\rm asinh}(x / a)}{{\rm asinh}(1 / a)}.

    Parameters
    ----------
    a : float, optional
        The ``a`` parameter used in the above formula.  The value of
        this parameter is where the asinh curve transitions from linear
        to logarithmic behavior, expressed as a fraction of the
        normalized image.  ``a`` must be greater than 0.  Default is
        0.1.

    Raises
    ------
    ValueError
        If ``a`` is not greater than 0.
    """

    def __init__(self, a=0.1):
        super().__init__()
        # Only a > 0 is required by the formula.  No upper bound is
        # imposed: the inverse of a SinhStretch is an AsinhStretch with
        # a = 1 / sinh(1 / a'), which can exceed 1.
        if a <= 0:
            raise ValueError("a must be > 0")
        self.a = a

    def __call__(self, values, clip=True, out=None):
        """
        Transform values using this stretch.

        Parameters
        ----------
        values : array-like
            The input values, which should already be normalized to the
            [0:1] range.
        clip : bool, optional
            If `True` (default), values outside the [0:1] range are
            clipped to the [0:1] range.
        out : ndarray, optional
            If specified, the output values will be placed in this array
            (typically used for in-place calculations).

        Returns
        -------
        result : ndarray
            The transformed values.
        """
        values = _prepare(values, clip=clip, out=out)
        np.true_divide(values, self.a, out=values)
        np.arcsinh(values, out=values)
        np.true_divide(values, np.arcsinh(1. / self.a), out=values)
        return values

    @property
    def inverse(self):
        """A stretch object that performs the inverse operation."""
        # With a' = 1 / asinh(1 / a), sinh(1 / a') equals 1 / a, so
        # sinh(y / a') / sinh(1 / a') recovers x.  Note that a' is
        # larger than 1 whenever a > 1 / sinh(1) (about 0.851).
        return SinhStretch(a=1. / np.arcsinh(1. / self.a))


class SinhStretch(BaseStretch):
    r"""
    A sinh stretch.

    The stretch is given by:

    .. math::
        y = \frac{{\rm sinh}(x / a)}{{\rm sinh}(1 / a)}

    Parameters
    ----------
    a : float, optional
        The ``a`` parameter used in the above formula.  ``a`` must be
        greater than 0.  Default is 1/3.

    Raises
    ------
    ValueError
        If ``a`` is not greater than 0.
    """

    def __init__(self, a=1./3.):
        super().__init__()
        # Only a > 0 is required by the formula.  No upper bound is
        # imposed: the inverse of an AsinhStretch is a SinhStretch with
        # a = 1 / asinh(1 / a'), which exceeds 1 for a' above ~0.851.
        if a <= 0:
            raise ValueError("a must be > 0")
        self.a = a

    def __call__(self, values, clip=True, out=None):
        """
        Transform values using this stretch.

        Parameters
        ----------
        values : array-like
            The input values, which should already be normalized to the
            [0:1] range.
        clip : bool, optional
            If `True` (default), values outside the [0:1] range are
            clipped to the [0:1] range.
        out : ndarray, optional
            If specified, the output values will be placed in this array
            (typically used for in-place calculations).

        Returns
        -------
        result : ndarray
            The transformed values.
        """
        values = _prepare(values, clip=clip, out=out)
        np.true_divide(values, self.a, out=values)
        np.sinh(values, out=values)
        np.true_divide(values, np.sinh(1. / self.a), out=values)
        return values

    @property
    def inverse(self):
        """A stretch object that performs the inverse operation."""
        # With a' = 1 / sinh(1 / a), asinh(1 / a') equals 1 / a, so
        # asinh(y / a') / asinh(1 / a') recovers x.
        return AsinhStretch(a=1. / np.sinh(1. / self.a))
546
547
class HistEqStretch(BaseStretch):
    """
    A histogram equalization stretch.

    Parameters
    ----------
    data : array-like
        The data defining the equalization.  It is flattened, sorted,
        stripped of non-finite entries and rescaled to the [0:1] range.
    values : array-like, optional
        The output value associated with each sorted, finite data point
        (used as the ``fp`` argument of `numpy.interp`).  If `None`
        (default), evenly spaced values from 0 to 1 are used.
    """

    def __init__(self, data, values=None):
        super().__init__()

        # Assume data is not necessarily normalized at this point.
        # ``asanyarray`` lets plain sequences through while leaving
        # ndarray subclasses as they are.
        data = np.sort(np.asanyarray(data).ravel())
        data = data[np.isfinite(data)]
        vmin = data.min()
        vmax = data.max()
        self.data = (data - vmin) / (vmax - vmin)

        # Compute relative position of each pixel
        if values is None:
            self.values = np.linspace(0., 1., len(self.data))
        else:
            self.values = values

    def __call__(self, values, clip=True, out=None):
        """
        Transform values using this stretch.

        Parameters
        ----------
        values : array-like
            The input values, which should already be normalized to the
            [0:1] range.
        clip : bool, optional
            If `True` (default), values outside the [0:1] range are
            clipped to the [0:1] range.
        out : ndarray, optional
            If specified, the output values will be placed in this array
            (typically used for in-place calculations).

        Returns
        -------
        result : ndarray
            The transformed values.
        """
        values = _prepare(values, clip=clip, out=out)
        # Ellipsis (not a slice) so that zero-dimensional arrays can be
        # assigned to as well.
        values[...] = np.interp(values, self.data, self.values)
        return values

    @property
    def inverse(self):
        """A stretch object that performs the inverse operation."""
        return InvertedHistEqStretch(self.data, values=self.values)
584
585
class InvertedHistEqStretch(BaseStretch):
    """
    Inverse transformation for `~astropy.visualization.HistEqStretch`.

    Parameters
    ----------
    data : array-like
        The data defining the equalization.  Non-finite entries are
        dropped; no sorting or rescaling is done here.
    values : array-like, optional
        The value associated with each finite data point (used as the
        ``xp`` argument of `numpy.interp`).  If `None` (default), evenly
        spaced values from 0 to 1 are used.
    """

    def __init__(self, data, values=None):
        super().__init__()

        # ``asanyarray`` lets plain sequences through (boolean-mask
        # indexing needs an array) while leaving ndarray subclasses as
        # they are.
        data = np.asanyarray(data)
        self.data = data[np.isfinite(data)]
        if values is None:
            self.values = np.linspace(0., 1., len(self.data))
        else:
            self.values = values

    def __call__(self, values, clip=True, out=None):
        """
        Transform values using this stretch.

        Parameters
        ----------
        values : array-like
            The input values, which should already be normalized to the
            [0:1] range.
        clip : bool, optional
            If `True` (default), values outside the [0:1] range are
            clipped to the [0:1] range.
        out : ndarray, optional
            If specified, the output values will be placed in this array
            (typically used for in-place calculations).

        Returns
        -------
        result : ndarray
            The transformed values.
        """
        values = _prepare(values, clip=clip, out=out)
        # Ellipsis (not a slice) so that zero-dimensional arrays can be
        # assigned to as well.
        values[...] = np.interp(values, self.values, self.data)
        return values

    @property
    def inverse(self):
        """A stretch object that performs the inverse operation."""
        return HistEqStretch(self.data, values=self.values)
615
616
class ContrastBiasStretch(BaseStretch):
    r"""
    A stretch that takes into account contrast and bias.

    The stretch is given by:

    .. math::
        y = (x - {\rm bias}) * {\rm contrast} + 0.5

    and the output values are clipped to the [0:1] range.

    Parameters
    ----------
    contrast : float
        The contrast parameter (see the above formula).

    bias : float
        The bias parameter (see the above formula).
    """

    def __init__(self, contrast, bias):
        super().__init__()
        self.contrast = contrast
        self.bias = bias

    def __call__(self, values, clip=True, out=None):
        """
        Transform values using this stretch.

        Parameters
        ----------
        values : array-like
            The input values, which should already be normalized to the
            [0:1] range.
        clip : bool, optional
            If `True` (default), the *output* values are clipped to the
            [0:1] range; the input is never clipped beforehand.
        out : ndarray, optional
            If specified, the output values will be placed in this array
            (typically used for in-place calculations).

        Returns
        -------
        result : ndarray
            The transformed values.
        """
        # This transformation does not map [0:1] onto [0:1], so the
        # input is taken unclipped and any clipping happens at the end.
        result = _prepare(values, clip=False, out=out)

        np.subtract(result, self.bias, out=result)
        np.multiply(result, self.contrast, out=result)
        np.add(result, 0.5, out=result)

        if not clip:
            return result

        np.clip(result, 0, 1, out=result)
        return result

    @property
    def inverse(self):
        """A stretch object that performs the inverse operation."""
        return InvertedContrastBiasStretch(self.contrast, self.bias)
660
661
class InvertedContrastBiasStretch(BaseStretch):
    r"""
    Inverse transformation for ContrastBiasStretch.

    The stretch is given by:

    .. math::
        y = (x - 0.5) / {\rm contrast} + {\rm bias}

    Parameters
    ----------
    contrast : float
        The contrast parameter (see
        `~astropy.visualization.ContrastBiasStretch`).

    bias : float
        The bias parameter (see
        `~astropy.visualization.ContrastBiasStretch`).
    """

    def __init__(self, contrast, bias):
        super().__init__()
        self.contrast = contrast
        self.bias = bias

    def __call__(self, values, clip=True, out=None):
        """
        Transform values using this stretch.

        Parameters
        ----------
        values : array-like
            The input values.
        clip : bool, optional
            If `True` (default), the *output* values are clipped to the
            [0:1] range; the input is never clipped beforehand.
        out : ndarray, optional
            If specified, the output values will be placed in this array
            (typically used for in-place calculations).

        Returns
        -------
        result : ndarray
            The transformed values.
        """
        # This transformation does not map [0:1] onto [0:1], so the
        # input is taken unclipped and any clipping happens at the end.
        result = _prepare(values, clip=False, out=out)

        np.subtract(result, 0.5, out=result)
        np.true_divide(result, self.contrast, out=result)
        np.add(result, self.bias, out=result)

        if not clip:
            return result

        np.clip(result, 0, 1, out=result)
        return result

    @property
    def inverse(self):
        """A stretch object that performs the inverse operation."""
        return ContrastBiasStretch(self.contrast, self.bias)
699
700
class CompositeStretch(CompositeTransform, BaseStretch):
    """
    A combination of two stretches.

    Parameters
    ----------
    stretch_1 : :class:`astropy.visualization.BaseStretch`
        The first stretch to apply.
    stretch_2 : :class:`astropy.visualization.BaseStretch`
        The second stretch to apply.
    """

    def __call__(self, values, clip=True, out=None):
        """
        Apply the first stretch, then the second one to its result.

        The same ``clip`` and ``out`` arguments are forwarded to both
        stretches.
        """
        first_pass = self.transform_1(values, clip=clip, out=out)
        return self.transform_2(first_pass, clip=clip, out=out)
716