1# Licensed under a 3-clause BSD style license - see LICENSE.rst
2"""
3Combine 3 images to produce a properly-scaled RGB image following Lupton et al. (2004).
4
5The three images must be aligned and have the same pixel scale and size.
6
For details, see: https://ui.adsabs.harvard.edu/abs/2004PASP..116..133L
8"""
9
10import numpy as np
11from . import ZScaleInterval
12
13
# Public API exported by ``from ... import *``: only the convenience
# function; the mapping classes below are still importable by name.
__all__ = ['make_lupton_rgb']
15
16
def compute_intensity(image_r, image_g=None, image_b=None):
    """
    Return a naive total intensity from the red, blue, and green intensities.

    The intensity is the per-pixel mean of the three images, returned with
    the dtype of ``image_r``.

    Parameters
    ----------
    image_r : ndarray
        Intensity of image to be mapped to red; or total intensity if ``image_g``
        and ``image_b`` are None.
    image_g : ndarray, optional
        Intensity of image to be mapped to green.
    image_b : ndarray, optional
        Intensity of image to be mapped to blue.

    Returns
    -------
    intensity : ndarray
        Mean of the red, green and blue intensities cast to the dtype of
        ``image_r``, or ``image_r`` itself (unchanged) if green and blue
        images are not provided.

    Raises
    ------
    ValueError
        If exactly one of ``image_g`` and ``image_b`` is provided.
    """
    if image_g is None or image_b is None:
        if not (image_g is None and image_b is None):
            raise ValueError("please specify either a single image "
                             "or red, green, and blue images.")
        return image_r

    # Accept any array-like (lists included), not only ndarrays.
    image_r = np.asarray(image_r)
    image_g = np.asarray(image_g)
    image_b = np.asarray(image_b)
    out_dtype = image_r.dtype

    # Summing integer (or boolean) images in their own dtype can overflow and
    # wrap around (e.g. three uint8 images of 200 sum to 88), so accumulate
    # those in float64.  Floating-point inputs keep their original arithmetic.
    if np.result_type(image_r, image_g, image_b).kind in 'biu':
        image_r = image_r.astype(np.float64)

    intensity = (image_r + image_g + image_b)/3.0

    # Repack into whatever type was passed to us
    return np.asarray(intensity, dtype=out_dtype)
47
48
class Mapping:
    """
    Baseclass to map red, blue, green intensities into uint8 values.

    Subclasses customise the stretch by overriding `map_intensity_to_uint8`
    (and optionally `intensity`).

    Parameters
    ----------
    minimum : float or sequence(3)
        Intensity that should be mapped to black (a scalar or array for R, G, B).
    image : ndarray, optional
        An image used to calculate some parameters of some mappings.

    Raises
    ------
    ValueError
        If ``minimum`` is a sequence whose length is not 3.
    """

    def __init__(self, minimum=None, image=None):
        # 255.0: the largest value of a uint8 channel, kept as a float.
        self._uint8Max = float(np.iinfo(np.uint8).max)

        # Broadcast a scalar minimum to one value per band.
        try:
            len(minimum)
        except TypeError:
            minimum = 3*[minimum]
        if len(minimum) != 3:
            raise ValueError("please provide 1 or 3 values for minimum.")

        self.minimum = minimum
        self._image = np.asarray(image)

    def make_rgb_image(self, image_r, image_g, image_b):
        """
        Convert 3 arrays, image_r, image_g, and image_b into an 8-bit RGB image.

        Parameters
        ----------
        image_r : ndarray
            Image to map to red.
        image_g : ndarray
            Image to map to green.
        image_b : ndarray
            Image to map to blue.

        Returns
        -------
        RGBimage : ndarray
            RGB (integer, 8-bits per channel) color image; for 2-D inputs of
            shape (N, M) this is an (N, M, 3) array.

        Raises
        ------
        ValueError
            If the three images do not all have the same shape.
        """
        image_r = np.asarray(image_r)
        image_g = np.asarray(image_g)
        image_b = np.asarray(image_b)

        if (image_r.shape != image_g.shape) or (image_g.shape != image_b.shape):
            msg = "The image shapes must match. r: {}, g: {} b: {}"
            raise ValueError(msg.format(image_r.shape, image_g.shape, image_b.shape))

        return np.dstack(self._convert_images_to_uint8(image_r, image_g, image_b)).astype(np.uint8)

    def intensity(self, image_r, image_g, image_b):
        """
        Return the total intensity from the red, blue, and green intensities.
        This is a naive computation, and may be overridden by subclasses.

        Parameters
        ----------
        image_r : ndarray
            Intensity of image to be mapped to red; or total intensity if
            ``image_g`` and ``image_b`` are None.
        image_g : ndarray, optional
            Intensity of image to be mapped to green.
        image_b : ndarray, optional
            Intensity of image to be mapped to blue.

        Returns
        -------
        intensity : ndarray
            Total intensity from the red, blue and green intensities, or
            ``image_r`` if green and blue images are not provided.
        """
        return compute_intensity(image_r, image_g, image_b)

    def map_intensity_to_uint8(self, I):
        """
        Return an array which, when multiplied by an image, returns that image
        mapped to the range of a uint8, [0, 255] (but not converted to uint8).

        The intensity is assumed to have had minimum subtracted (as that can be
        done per-band).

        Parameters
        ----------
        I : ndarray
            Intensity to be mapped.

        Returns
        -------
        mapped_I : ndarray
            ``I`` mapped to uint8

        Notes
        -----
        This base implementation simply returns ``I`` clipped to [0, 255]
        rather than a scale factor; subclasses override it with a real stretch.
        """
        with np.errstate(invalid='ignore', divide='ignore'):
            return np.clip(I, 0, self._uint8Max)

    @staticmethod
    def _subtract_minimum(image, level):
        """Return a new floating-point array equal to ``image - level``."""
        image = np.asarray(image)
        # Integer/boolean data would wrap on subtraction (unsigned types) and
        # could not be scaled in place by a float factor later on.
        if image.dtype.kind in 'biu':
            image = image.astype(np.float64)
        return image - level  # n.b. always a new array

    def _convert_images_to_uint8(self, image_r, image_g, image_b):
        """
        Use the mapping to convert images image_r, image_g, and image_b to a
        triplet of uint8 images.
        """
        image_r = self._subtract_minimum(image_r, self.minimum[0])
        image_g = self._subtract_minimum(image_g, self.minimum[1])
        image_b = self._subtract_minimum(image_b, self.minimum[2])

        fac = self.map_intensity_to_uint8(self.intensity(image_r, image_g, image_b))

        image_rgb = [image_r, image_g, image_b]
        for c in image_rgb:
            c *= fac
            with np.errstate(invalid='ignore'):
                c[c < 0] = 0                # individual bands can still be < 0, even if fac isn't

        pixmax = self._uint8Max
        # References (not copies) to the scaled float bands: the loop below
        # rebinds entries of ``image_rgb`` and never modifies these arrays.
        r0, g0, b0 = image_rgb

        with np.errstate(invalid='ignore', divide='ignore'):  # n.b. np.where can't and doesn't short-circuit
            for i, c in enumerate(image_rgb):
                # Where a pixel's brightest band reaches pixmax, scale all its
                # bands by pixmax/brightest so the colour is preserved.
                c = np.where(r0 > g0,
                             np.where(r0 > b0,
                                      np.where(r0 >= pixmax, c*pixmax/r0, c),
                                      np.where(b0 >= pixmax, c*pixmax/b0, c)),
                             np.where(g0 > b0,
                                      np.where(g0 >= pixmax, c*pixmax/g0, c),
                                      np.where(b0 >= pixmax, c*pixmax/b0, c)))
                # Clamp while still floating point: after a uint8 cast any
                # out-of-range value has already wrapped, so a later clamp
                # could never trigger.
                image_rgb[i] = np.minimum(c, pixmax).astype(np.uint8)

        return image_rgb
177
178
class LinearMapping(Mapping):
    """
    A linear map of red, blue, green intensities into uint8 values.

    A linear stretch from [minimum, maximum].
    If one or both are omitted use image min and/or max to set them.

    Parameters
    ----------
    minimum : float
        Intensity that should be mapped to black. Must be a scalar here,
        since the stretch range is computed as ``maximum - minimum``.
    maximum : float
        Intensity that should be mapped to white (a scalar).
    image : array_like, optional
        Image used to set ``minimum`` and/or ``maximum`` when they are not
        given; NaN pixels are ignored.

    Raises
    ------
    ValueError
        If ``minimum`` or ``maximum`` is omitted and no ``image`` is given,
        or if ``minimum`` equals ``maximum``.
    """

    def __init__(self, minimum=None, maximum=None, image=None):
        if minimum is None or maximum is None:
            if image is None:
                raise ValueError("you must provide an image if you don't "
                                 "set both minimum and maximum")
            # np.nanmin/np.nanmax accept any array-like and skip NaN pixels,
            # which would otherwise turn the whole stretch into NaN.
            if minimum is None:
                minimum = np.nanmin(image)
            if maximum is None:
                maximum = np.nanmax(image)

        Mapping.__init__(self, minimum=minimum, image=image)
        self.maximum = maximum

        # maximum is never None here: it was either given or taken from image.
        if maximum == minimum:
            raise ValueError("minimum and maximum values must not be equal")
        self._range = float(maximum - minimum)

    def map_intensity_to_uint8(self, I):
        """
        Return the per-pixel factor that maps intensity ``I`` onto [0, 255].

        Pixels with ``I <= 0`` get 0; pixels at or above the range get
        ``255 / I`` (so ``I * factor == 255``); all others get ``255 / range``.
        """
        with np.errstate(invalid='ignore', divide='ignore'):  # n.b. np.where can't and doesn't short-circuit
            return np.where(I <= 0, 0,
                            np.where(I >= self._range, self._uint8Max/I, self._uint8Max/self._range))
218
219
class AsinhMapping(Mapping):
    """
    A mapping for an asinh stretch (preserving colours independent of brightness)

    x = asinh(Q (I - minimum)/stretch)/Q

    This reduces to a linear stretch if Q == 0

    See https://ui.adsabs.harvard.edu/abs/2004PASP..116..133L

    Parameters
    ----------
    minimum : float
        Intensity that should be mapped to black (a scalar or array for R, G, B).
    stretch : float
        The linear stretch of the image.
    Q : float
        The asinh softening parameter. Values whose magnitude is below the
        float32 machine epsilon are replaced by 0.1 (a nearly linear stretch),
        and values above 1e10 are clamped to 1e10.
    """

    def __init__(self, minimum, stretch, Q=8):
        Mapping.__init__(self, minimum)

        # float32 machine epsilon; sys.float_info.epsilon is the float64 one.
        float32_eps = 1.0/2**23
        q_limit = 1e10
        if abs(Q) < float32_eps:
            # Q == 0 would make arcsinh(frac*Q) zero and divide by zero below.
            Q = 0.1
        elif Q > q_limit:
            Q = q_limit

        # The gradient of the mapping is estimated at frac*stretch.
        frac = 0.1
        self._slope = frac*self._uint8Max/np.arcsinh(frac*Q)
        self._soften = Q/float(stretch)

    def map_intensity_to_uint8(self, I):
        """Return the per-pixel asinh scale factor (0 where ``I <= 0``)."""
        # np.where evaluates both branches, so silence the I == 0 division.
        with np.errstate(invalid='ignore', divide='ignore'):
            scaled = np.arcsinh(I*self._soften)*self._slope/I
            return np.where(I <= 0, 0, scaled)
260
261
class AsinhZScaleMapping(AsinhMapping):
    """
    A mapping for an asinh stretch, estimating the linear stretch by zscale.

    x = asinh(Q (I - z1)/(z2 - z1))/Q

    Parameters
    ----------
    image1 : ndarray
        The image to analyse. If ``image2`` and ``image3`` are omitted this
        image is used directly as the intensity image (it is not averaged,
        even if it is a list of arrays); otherwise it is the first of the
        three images whose mean intensity is analysed.
    image2 : ndarray, optional
        The second image to analyse (must be specified with image3).
    image3 : ndarray, optional
        The third image to analyse (must be specified with image2).
    Q : float, optional
        The asinh softening parameter. Default is 8.
    pedestal : float or sequence(3), optional
        The value, or sequence of 3 values, to subtract from the images; or None.

    Raises
    ------
    ValueError
        If exactly one of ``image2`` and ``image3`` is given, or if
        ``pedestal`` does not provide 1 or 3 values.

    Notes
    -----
    pedestal, if not None, is removed from the images when calculating the
    zscale stretch, and added back into Mapping.minimum[]
    """

    def __init__(self, image1, image2=None, image3=None, Q=8, pedestal=None):
        """Estimate the zscale stretch of the image(s) and build the asinh mapping."""
        if (image2 is None) != (image3 is None):
            raise ValueError("please specify either a single image "
                             "or three images.")
        images = [image1] if image2 is None else [image1, image2, image3]

        if pedestal is None:
            pedestal = len(images)*[0.0]
        else:
            try:
                len(pedestal)
            except TypeError:
                pedestal = 3*[pedestal]

            if len(pedestal) != 3:
                raise ValueError("please provide 1 or 3 pedestals.")

            # Subtract (as a copy) only the non-zero pedestals.
            images = [im - level if level != 0.0 else im
                      for im, level in zip(images, pedestal)]

        intensity = compute_intensity(*images)

        zscale = LinearMapping(*ZScaleInterval().get_limits(intensity), image=intensity)
        # zscale.minimum is always a triple (a scalar limit is broadcast).
        stretch = zscale.maximum - zscale.minimum[0]

        # Add the pedestals back so the black level refers to the raw images.
        minimum = list(zscale.minimum)
        for band, level in enumerate(pedestal):
            minimum[band] += level

        AsinhMapping.__init__(self, minimum, stretch, Q)
        self._image = intensity
328
329
def make_lupton_rgb(image_r, image_g, image_b, minimum=0, stretch=5, Q=8,
                    filename=None):
    """
    Return a Red/Green/Blue color image from 3 images using an asinh stretch.
    The input images can be int or float, and in any range or bit-depth.

    For a more detailed look at the use of this method, see the document
    :ref:`astropy:astropy-visualization-rgb`.

    Parameters
    ----------
    image_r : ndarray
        Image to map to red.
    image_g : ndarray
        Image to map to green.
    image_b : ndarray
        Image to map to blue.
    minimum : float, optional
        Intensity that should be mapped to black (a scalar or array for R, G, B).
        Default is 0.
    stretch : float, optional
        The linear stretch of the image. Default is 5.
    Q : float, optional
        The asinh softening parameter. Default is 8.
    filename : str, optional
        Write the resulting RGB image to a file (file type determined
        from extension). Nothing is written if this is empty or None.

    Returns
    -------
    rgb : ndarray
        RGB (integer, 8-bits per channel) color image as an NxNx3 numpy array.
    """
    mapping = AsinhMapping(minimum, stretch, Q)
    rgb = mapping.make_rgb_image(image_r, image_g, image_b)

    if not filename:
        return rgb

    # matplotlib is only needed when saving, so it is imported lazily here.
    import matplotlib.image
    matplotlib.image.imsave(filename, rgb, origin='lower')
    return rgb
371