# Licensed under a 3-clause BSD style license - see LICENSE.rst

"""
Classes that deal with computing intervals from arrays of values based on
various criteria.
"""

import abc
import numpy as np

from .transform import BaseTransform


__all__ = ['BaseInterval', 'ManualInterval', 'MinMaxInterval',
           'AsymmetricPercentileInterval', 'PercentileInterval',
           'ZScaleInterval']


class BaseInterval(BaseTransform):
    """
    Base class for the interval classes, which, when called with an
    array of values, return an interval computed following different
    algorithms.
    """

    @abc.abstractmethod
    def get_limits(self, values):
        """
        Return the minimum and maximum value in the interval based on
        the values provided.

        Parameters
        ----------
        values : ndarray
            The image values.

        Returns
        -------
        vmin, vmax : float
            The minimum and maximum image value in the interval.
        """

        raise NotImplementedError('Needs to be implemented in a subclass.')

    def __call__(self, values, clip=True, out=None):
        """
        Transform values using this interval.

        The limits returned by `get_limits` are mapped linearly onto
        0 and 1 respectively.

        Parameters
        ----------
        values : array-like
            The input values.
        clip : bool, optional
            If `True` (default), values outside the [0:1] range are
            clipped to the [0:1] range.
        out : ndarray, optional
            If specified, the output values will be placed in this array
            (typically used for in-place calculations).

        Returns
        -------
        result : ndarray
            The transformed values.

        Raises
        ------
        TypeError
            If ``out`` is given and is not a floating-point array.
        """

        vmin, vmax = self.get_limits(values)

        if out is None:
            # Subtracting a Python float yields a new floating-point
            # result, so the in-place operations below never touch the
            # caller's input.
            values = np.subtract(values, float(vmin))
        else:
            if out.dtype.kind != 'f':
                raise TypeError('Can only do in-place scaling for '
                                'floating-point arrays')
            values = np.subtract(values, float(vmin), out=out)

        # A zero-width interval cannot be normalized; the shifted values
        # are left undivided in that case.
        if (vmax - vmin) != 0:
            np.true_divide(values, vmax - vmin, out=values)

        if clip:
            np.clip(values, 0., 1., out=values)

        return values


class ManualInterval(BaseInterval):
    """
    Interval based on user-specified values.

    Parameters
    ----------
    vmin : float, optional
        The minimum value in the scaling.  Defaults to the image
        minimum (ignoring NaNs)
    vmax : float, optional
        The maximum value in the scaling.  Defaults to the image
        maximum (ignoring NaNs)
    """

    def __init__(self, vmin=None, vmax=None):
        self.vmin = vmin
        self.vmax = vmax

    def get_limits(self, values):
        """
        Return the user-specified limits, filling in any limit left as
        `None` with the minimum/maximum of the finite input values.

        Parameters
        ----------
        values : array-like
            The image values.  Only inspected when at least one of the
            limits was not specified.

        Returns
        -------
        vmin, vmax : float
            The lower and upper limit of the interval.
        """

        # Both limits were given explicitly: the data are irrelevant, so
        # skip the (potentially expensive) flatten/filter/copy below.
        if self.vmin is not None and self.vmax is not None:
            return self.vmin, self.vmax

        # Make sure values is a Numpy array
        values = np.asarray(values).ravel()

        # Filter out invalid values (inf, nan)
        values = values[np.isfinite(values)]

        vmin = np.min(values) if self.vmin is None else self.vmin
        vmax = np.max(values) if self.vmax is None else self.vmax
        return vmin, vmax


class MinMaxInterval(BaseInterval):
    """
    Interval based on the minimum and maximum values in the data.
    """

    def get_limits(self, values):
        """
        Return the minimum and maximum of the finite input values.

        Parameters
        ----------
        values : array-like
            The image values.

        Returns
        -------
        vmin, vmax : float
            The minimum and maximum finite value.
        """

        # Make sure values is a Numpy array
        values = np.asarray(values).ravel()

        # Filter out invalid values (inf, nan)
        values = values[np.isfinite(values)]

        return np.min(values), np.max(values)


class AsymmetricPercentileInterval(BaseInterval):
    """
    Interval based on keeping a specified fraction of pixels (can be
    asymmetric).

    Parameters
    ----------
    lower_percentile : float
        The lower percentile below which to ignore pixels.
    upper_percentile : float
        The upper percentile above which to ignore pixels.
    n_samples : int, optional
        Maximum number of values to use. If this is specified, and there
        are more values in the dataset than this, then values are
        randomly sampled from the array (with replacement).
    """

    def __init__(self, lower_percentile, upper_percentile, n_samples=None):
        self.lower_percentile = lower_percentile
        self.upper_percentile = upper_percentile
        self.n_samples = n_samples

    def get_limits(self, values):
        """
        Return the values at the lower and upper percentiles of the
        finite input values.

        Parameters
        ----------
        values : array-like
            The image values.

        Returns
        -------
        vmin, vmax : float
            The values at ``lower_percentile`` and ``upper_percentile``.
        """

        # Make sure values is a Numpy array
        values = np.asarray(values).ravel()

        # If needed, limit the number of samples. We sample with replacement
        # since this is much faster.  Sampling happens before the
        # finite-value filter, so fewer than ``n_samples`` values may
        # remain afterwards.
        if self.n_samples is not None and values.size > self.n_samples:
            values = np.random.choice(values, self.n_samples)

        # Filter out invalid values (inf, nan)
        values = values[np.isfinite(values)]

        # Determine values at percentiles
        vmin, vmax = np.percentile(values, (self.lower_percentile,
                                            self.upper_percentile))

        return vmin, vmax


class PercentileInterval(AsymmetricPercentileInterval):
    """
    Interval based on keeping a specified fraction of pixels.

    Parameters
    ----------
    percentile : float
        The percentage of pixels to keep (on a 0 to 100 scale). The
        same fraction of pixels is eliminated from both ends.
    n_samples : int, optional
        Maximum number of values to use. If this is specified, and there
        are more values in the dataset than this, then values are
        randomly sampled from the array (with replacement).
    """

    def __init__(self, percentile, n_samples=None):
        # Split the discarded percentage evenly between the two tails.
        lower_percentile = (100 - percentile) * 0.5
        upper_percentile = 100 - lower_percentile
        super().__init__(
            lower_percentile, upper_percentile, n_samples=n_samples)


class ZScaleInterval(BaseInterval):
    """
    Interval based on IRAF's zscale.

    https://iraf.net/forum/viewtopic.php?showtopic=134139

    Original implementation:
    https://github.com/spacetelescope/stsci.numdisplay/blob/master/lib/stsci/numdisplay/zscale.py

    Licensed under a 3-clause BSD style license (see AURA_LICENSE.rst).

    Parameters
    ----------
    nsamples : int, optional
        The number of points in the array to sample for determining
        scaling factors.  Defaults to 1000.
    contrast : float, optional
        The scaling factor (between 0 and 1) for determining the minimum
        and maximum value.  Larger values decrease the difference
        between the minimum and maximum values used for display.
        Defaults to 0.25.
    max_reject : float, optional
        If more than ``max_reject * npixels`` pixels are rejected, then
        the returned values are the minimum and maximum of the data.
        Defaults to 0.5.
    min_npixels : int, optional
        If there are fewer than ``min_npixels`` pixels remaining after
        the pixel rejection, then the returned values are the minimum
        and maximum of the data.  Defaults to 5.
    krej : float, optional
        The number of sigma used for the rejection. Defaults to 2.5.
    max_iterations : int, optional
        The maximum number of iterations for the rejection. Defaults to
        5.
    """

    def __init__(self, nsamples=1000, contrast=0.25, max_reject=0.5,
                 min_npixels=5, krej=2.5, max_iterations=5):
        self.nsamples = nsamples
        self.contrast = contrast
        self.max_reject = max_reject
        self.min_npixels = min_npixels
        self.krej = krej
        self.max_iterations = max_iterations

    def get_limits(self, values):
        """
        Return the zscale limits for the finite input values.

        A straight line is fitted to the sorted samples with iterative
        k-sigma rejection; the limits are derived from the fitted slope
        and the median, and never extend beyond the sample minimum and
        maximum.  If no fit was performed, or too few samples survive
        the rejection, the sample minimum and maximum are returned.

        Parameters
        ----------
        values : array-like
            The image values.

        Returns
        -------
        vmin, vmax : float
            The lower and upper limit of the interval.
        """

        # Sample the image.  The boolean mask produces a flattened copy,
        # so the in-place sort below does not modify the caller's data.
        values = np.asarray(values)
        values = values[np.isfinite(values)]
        stride = int(max(1.0, values.size / self.nsamples))
        samples = values[::stride][:self.nsamples]
        samples.sort()

        npix = len(samples)
        vmin = samples[0]
        vmax = samples[-1]

        # Fit a line to the sorted array of samples
        minpix = max(self.min_npixels, int(npix * self.max_reject))
        x = np.arange(npix)
        ngoodpix = npix
        last_ngoodpix = npix + 1

        # Bad pixels mask used in k-sigma clipping
        badpix = np.zeros(npix, dtype=bool)

        # Kernel used to dilate the bad pixels mask
        ngrow = max(1, int(npix * 0.01))
        kernel = np.ones(ngrow, dtype=bool)

        # Stays None when the loop never performs a fit (for example
        # ``max_iterations <= 0``), so the code below can fall back to
        # the sample minimum/maximum instead of hitting an unbound name.
        fit = None

        for _ in range(self.max_iterations):
            # Stop once rejection has converged or too few pixels remain
            if ngoodpix >= last_ngoodpix or ngoodpix < minpix:
                break

            # Rejected pixels get zero weight in the fit
            fit = np.polyfit(x, samples, deg=1, w=(~badpix).astype(int))
            fitted = np.poly1d(fit)(x)

            # Subtract fitted line from the data array
            flat = samples - fitted

            # Compute the k-sigma rejection threshold
            threshold = self.krej * flat[~badpix].std()

            # Detect and reject pixels further than k*sigma from the
            # fitted line
            badpix[(flat < - threshold) | (flat > threshold)] = True

            # Convolve with a kernel of length ngrow
            badpix = np.convolve(badpix, kernel, mode='same')

            last_ngoodpix = ngoodpix
            ngoodpix = np.sum(~badpix)

        if fit is not None and ngoodpix >= minpix:
            slope, _ = fit

            if self.contrast > 0:
                slope = slope / self.contrast
            center_pixel = (npix - 1) // 2
            median = np.median(samples)
            vmin = max(vmin, median - (center_pixel - 1) * slope)
            vmax = min(vmax, median + (npix - center_pixel) * slope)

        return vmin, vmax