1# Licensed under a 3-clause BSD style license - see LICENSE.rst
2
3"""
4Classes that deal with computing intervals from arrays of values based on
5various criteria.
6"""
7
8import abc
9import numpy as np
10
11from .transform import BaseTransform
12
13
14__all__ = ['BaseInterval', 'ManualInterval', 'MinMaxInterval',
15           'AsymmetricPercentileInterval', 'PercentileInterval',
16           'ZScaleInterval']
17
18
class BaseInterval(BaseTransform):
    """
    Common parent of the interval classes.

    A subclass supplies `get_limits`, which derives a ``(vmin, vmax)``
    pair from an array of values.  Calling an instance applies that
    interval to the values, mapping ``vmin`` to 0 and ``vmax`` to 1.
    """

    @abc.abstractmethod
    def get_limits(self, values):
        """
        Compute the lower and upper limit of the interval for the given
        values.

        Parameters
        ----------
        values : ndarray
            The image values.

        Returns
        -------
        vmin, vmax : float
            The minimum and maximum image value in the interval.
        """

        raise NotImplementedError('Needs to be implemented in a subclass.')

    def __call__(self, values, clip=True, out=None):
        """
        Rescale values with the limits returned by `get_limits`.

        The lower limit is subtracted from the values and, unless the
        interval has zero width, the result is divided by the width of
        the interval.

        Parameters
        ----------
        values : array-like
            The input values.
        clip : bool, optional
            If `True` (default), results outside the [0:1] range are
            clipped to the [0:1] range.
        out : ndarray, optional
            If given, the result is written into this array (typically
            used for in-place calculations).  It must have a
            floating-point dtype.

        Returns
        -------
        result : ndarray
            The transformed values.

        Raises
        ------
        TypeError
            If ``out`` is given and is not a floating-point array.
        """

        lower, upper = self.get_limits(values)

        # In-place scaling only makes sense if the target can hold floats
        if out is not None and out.dtype.kind != 'f':
            raise TypeError('Can only do in-place scaling for '
                            'floating-point arrays')

        # ``out=None`` makes numpy allocate a fresh result array
        scaled = np.subtract(values, float(lower), out=out)

        # A zero-width interval is left unscaled to avoid dividing by zero
        span = upper - lower
        if span != 0:
            np.true_divide(scaled, span, out=scaled)

        if clip:
            np.clip(scaled, 0., 1., out=scaled)

        return scaled
83
84
class ManualInterval(BaseInterval):
    """
    Interval based on user-specified values.

    Parameters
    ----------
    vmin : float, optional
        The minimum value in the scaling.  Defaults to the image
        minimum (ignoring non-finite values such as NaN and inf).
    vmax : float, optional
        The maximum value in the scaling.  Defaults to the image
        maximum (ignoring non-finite values such as NaN and inf).
    """

    def __init__(self, vmin=None, vmax=None):
        self.vmin = vmin
        self.vmax = vmax

    def get_limits(self, values):
        """
        Return the interval limits, falling back to the data for any
        limit that was not specified.

        Parameters
        ----------
        values : array-like
            The image values.  Only inspected when ``vmin`` and/or
            ``vmax`` is `None`.

        Returns
        -------
        vmin, vmax : float
            The user-specified limits where given, otherwise the
            minimum / maximum of the finite values.
        """
        # When both limits were given the data is never needed, so skip
        # the cost of flattening and filtering the whole array.
        if self.vmin is not None and self.vmax is not None:
            return self.vmin, self.vmax

        # Make sure values is a Numpy array
        values = np.asarray(values).ravel()

        # Filter out invalid values (inf, nan)
        values = values[np.isfinite(values)]

        vmin = np.min(values) if self.vmin is None else self.vmin
        vmax = np.max(values) if self.vmax is None else self.vmax
        return vmin, vmax
113
114
class MinMaxInterval(BaseInterval):
    """
    Interval spanning the smallest to the largest finite value in the
    data.
    """

    def get_limits(self, values):
        """
        Return the minimum and maximum of the finite values.

        Parameters
        ----------
        values : array-like
            The image values.

        Returns
        -------
        vmin, vmax : float
            The minimum and maximum of the finite values.
        """
        # Work on a flattened Numpy array
        flat = np.asarray(values).ravel()

        # Drop inf and nan before taking the extrema
        finite = flat[np.isfinite(flat)]

        return np.min(finite), np.max(finite)
128
129
class AsymmetricPercentileInterval(BaseInterval):
    """
    Interval that keeps the pixels lying between two percentiles (which
    need not be symmetric).

    Parameters
    ----------
    lower_percentile : float
        The lower percentile below which to ignore pixels.
    upper_percentile : float
        The upper percentile above which to ignore pixels.
    n_samples : int, optional
        Maximum number of values to use. If this is specified, and there
        are more values in the dataset than this, then values are
        randomly sampled from the array (with replacement).
    """

    def __init__(self, lower_percentile, upper_percentile, n_samples=None):
        self.lower_percentile = lower_percentile
        self.upper_percentile = upper_percentile
        self.n_samples = n_samples

    def get_limits(self, values):
        """
        Return the data values found at the lower and upper percentile.

        Parameters
        ----------
        values : array-like
            The image values.

        Returns
        -------
        vmin, vmax : float
            The values at ``lower_percentile`` and ``upper_percentile``
            of the finite (possibly subsampled) data.
        """
        # Work on a flattened Numpy array
        data = np.asarray(values).ravel()

        # Subsample large inputs. Drawing with replacement is used
        # because it is much faster.
        limit = self.n_samples
        if limit is not None and data.size > limit:
            data = np.random.choice(data, limit)

        # Drop inf and nan (after sampling, so the sample may shrink)
        data = data[np.isfinite(data)]

        # Evaluate both percentiles in a single call
        bounds = (self.lower_percentile, self.upper_percentile)
        lower, upper = np.percentile(data, bounds)

        return lower, upper
169
170
class PercentileInterval(AsymmetricPercentileInterval):
    """
    Interval that keeps a specified percentage of pixels, centred on
    the middle of the distribution.

    Parameters
    ----------
    percentile : float
        The percentage of pixels to keep (a value of 100 keeps
        everything). The same percentage of pixels is eliminated from
        both ends.
    n_samples : int, optional
        Maximum number of values to use. If this is specified, and there
        are more values in the dataset than this, then values are
        randomly sampled from the array (with replacement).
    """

    def __init__(self, percentile, n_samples=None):
        # Half of the discarded percentage is trimmed from each end
        tail = (100 - percentile) * 0.5
        super().__init__(tail, 100 - tail, n_samples=n_samples)
191
192
class ZScaleInterval(BaseInterval):
    """
    Interval based on IRAF's zscale.

    https://iraf.net/forum/viewtopic.php?showtopic=134139

    Original implementation:
    https://github.com/spacetelescope/stsci.numdisplay/blob/master/lib/stsci/numdisplay/zscale.py

    Licensed under a 3-clause BSD style license (see AURA_LICENSE.rst).

    Parameters
    ----------
    nsamples : int, optional
        The number of points in the array to sample for determining
        scaling factors.  Defaults to 1000.
    contrast : float, optional
        The scaling factor (between 0 and 1) for determining the minimum
        and maximum value.  Larger values increase the difference
        between the minimum and maximum values used for display.
        Defaults to 0.25.
    max_reject : float, optional
        If more than ``max_reject * npixels`` pixels are rejected, then
        the returned values are the minimum and maximum of the data.
        Defaults to 0.5.
    min_npixels : int, optional
        If there are less than ``min_npixels`` pixels remaining after
        the pixel rejection, then the returned values are the minimum
        and maximum of the data.  Defaults to 5.
    krej : float, optional
        The number of sigma used for the rejection. Defaults to 2.5.
    max_iterations : int, optional
        The maximum number of iterations for the rejection. Defaults to
        5.  If less than 1, no line is fitted and the returned values
        are the minimum and maximum of the sampled data.
    """

    def __init__(self, nsamples=1000, contrast=0.25, max_reject=0.5,
                 min_npixels=5, krej=2.5, max_iterations=5):
        self.nsamples = nsamples
        self.contrast = contrast
        self.max_reject = max_reject
        self.min_npixels = min_npixels
        self.krej = krej
        self.max_iterations = max_iterations

    def get_limits(self, values):
        """
        Compute the zscale limits for the given values.

        A line is fitted to the sorted sample of finite values with
        iterative k-sigma rejection; the limits are derived from the
        slope of that line around the sample median and are never wider
        than the sample minimum / maximum.

        Parameters
        ----------
        values : array-like
            The image values.

        Returns
        -------
        vmin, vmax : float
            The lower and upper limit of the interval.
        """
        # Sample the image. Boolean-mask indexing returns a copy, so the
        # in-place sort below does not modify the caller's data.
        values = np.asarray(values)
        values = values[np.isfinite(values)]
        stride = int(max(1.0, values.size / self.nsamples))
        samples = values[::stride][:self.nsamples]
        samples.sort()

        # Default limits: full range of the (sorted) samples
        npix = len(samples)
        vmin = samples[0]
        vmax = samples[-1]

        # Fit a line to the sorted array of samples
        minpix = max(self.min_npixels, int(npix * self.max_reject))
        x = np.arange(npix)
        ngoodpix = npix
        last_ngoodpix = npix + 1

        # Bad pixels mask used in k-sigma clipping
        badpix = np.zeros(npix, dtype=bool)

        # Kernel used to dilate the bad pixels mask
        ngrow = max(1, int(npix * 0.01))
        kernel = np.ones(ngrow, dtype=bool)

        # Stays None when no fit is performed (e.g. max_iterations < 1),
        # in which case the default limits are returned.
        fit = None

        for _ in range(self.max_iterations):
            # Stop once an iteration rejects nothing new, or too few
            # pixels are left
            if ngoodpix >= last_ngoodpix or ngoodpix < minpix:
                break

            # Rejected pixels get zero weight in the fit
            fit = np.polyfit(x, samples, deg=1, w=(~badpix).astype(int))
            fitted = np.poly1d(fit)(x)

            # Subtract fitted line from the data array
            flat = samples - fitted

            # Compute the k-sigma rejection threshold
            threshold = self.krej * flat[~badpix].std()

            # Detect and reject pixels further than k*sigma from the
            # fitted line
            badpix[(flat < - threshold) | (flat > threshold)] = True

            # Convolve with a kernel of length ngrow
            badpix = np.convolve(badpix, kernel, mode='same')

            last_ngoodpix = ngoodpix
            ngoodpix = np.sum(~badpix)

        if fit is not None and ngoodpix >= minpix:
            # polyfit returns the highest-order coefficient first
            slope = fit[0]

            if self.contrast > 0:
                slope = slope / self.contrast
            center_pixel = (npix - 1) // 2
            median = np.median(samples)
            # Never extend the interval beyond the sampled data range
            vmin = max(vmin, median - (center_pixel - 1) * slope)
            vmax = min(vmax, median + (npix - center_pixel) * slope)

        return vmin, vmax
297