1# This file is part of the Astrometry.net suite.
2# Licensed under a 3-clause BSD style license - see LICENSE
3from __future__ import print_function
4from __future__ import absolute_import
5import numpy as np
6from scipy.ndimage.filters import gaussian_filter
7
8from astrometry.util.resample import resample_with_wcs, ResampleError
9
10from .fields import radec_to_sdss_rcf
11from .common import band_name, band_index, AsTransWrapper
12from functools import reduce
13
def _sdss_fields_near(ra, dec, radius_deg, sdss):
    """Return (run, camcol, field) tuples whose windows lie near (ra, dec).

    Queries the window_flist table and drops fields whose rerun is '157'.
    """
    wlistfn = sdss.filenames.get('window_flist', 'window_flist.fits')
    # Convert the target radius (degrees) to arcmin and pad by half the
    # diagonal of a 14 x 10 arcmin box (presumably a generous SDSS field size).
    search_radius = radius_deg * 60. + np.hypot(14., 10.) / 2.
    candidates = radec_to_sdss_rcf(ra, dec, tablefn=wlistfn,
                                   radius=search_radius)
    return [(run, camcol, field)
            for run, camcol, field, _r, _d in candidates
            if sdss.get_rerun(run, field) != '157']


def _coadd_band(targetwcs, sdss, fields, band, bandnum, ra, dec, halfsize,
                H, W):
    """Resample each field's frame in *band* onto *targetwcs* and average.

    Returns an (H, W) float32 array; pixels covered by no frame are NaN.
    """
    total = np.zeros((H, W), np.float32)
    # uint16 (not uint8) so pixels covered by many overlapping fields cannot
    # wrap past 255; float32 / uint16 still yields float32.
    count = np.zeros((H, W), np.uint16)
    for run, camcol, field in fields:
        # Called for its side effect (presumably making the frame file
        # available locally); the returned filename is not needed here.
        sdss.retrieve('frame', run, camcol, field, band)
        frame = sdss.readFrame(run, camcol, field, bandnum)

        h, w = frame.getImageShape()
        x, y = frame.astrans.radec_to_pixel(ra, dec)
        x, y = int(x), int(y)
        xlo = np.clip(x - halfsize, 0, w)
        xhi = np.clip(x + halfsize + 1, 0, w)
        ylo = np.clip(y - halfsize, 0, h)
        yhi = np.clip(y + halfsize + 1, 0, h)
        if xlo == xhi or ylo == yhi:
            # Target center is too far off this frame; nothing to cut out.
            continue
        stamp = frame.getImageSlice((slice(ylo, yhi), slice(xlo, xhi)))
        sh, sw = stamp.shape
        wcs = AsTransWrapper(frame.astrans, sw, sh, x0=xlo, y0=ylo)
        # FIXME -- allow nn resampling too
        try:
            Yo, Xo, _, _, [rim] = resample_with_wcs(targetwcs, wcs, [stamp], 3)
        except ResampleError:
            continue
        total[Yo, Xo] += rim
        count[Yo, Xo] += 1

    # Uncovered pixels are 0/0 -> NaN by design; silence the expected warning.
    with np.errstate(invalid='ignore', divide='ignore'):
        return total / count


def _asinh_rgb(r, g, b):
    """Turn three band-scaled images into an (H, W, 3) display image in [0, 1].

    Applies an arcsinh stretch to the mean intensity and rescales pixels
    whose brightest channel exceeds 1. Non-finite (uncovered) pixels are set
    to 0, and a light Gaussian blur is applied. The inputs are not modified.
    """
    # NaN marks uncovered pixels; comparisons on them are expected.
    with np.errstate(invalid='ignore'):
        m = -0.02
        r = np.maximum(0, r - m)
        g = np.maximum(0, g - m)
        b = np.maximum(0, b - m)
        I = (r + g + b) / 3.
        alpha = 1.5
        Q = 20
        m2 = 0.
        fI = np.arcsinh(alpha * Q * (I - m2)) / np.sqrt(Q)
        # Avoid dividing by zero where all three channels are exactly zero.
        I += (I == 0.) * 1e-6
        R = fI * r / I
        G = fI * g / I
        B = fI * b / I
        # Where the brightest channel exceeds 1, divide all three channels by
        # it so their ratios are kept.
        maxrgb = reduce(np.maximum, [R, G, B])
        J = (maxrgb > 1.)
        for chan in (R, G, B):
            chan[J] = chan[J] / maxrgb[J]
    # Zero uncovered pixels first: otherwise gaussian_filter would spread each
    # NaN over its whole kernel footprint, blanking valid neighbours too.
    for chan in (R, G, B):
        chan[~np.isfinite(chan)] = 0.
    ss = 0.5
    return np.clip(np.dstack([gaussian_filter(chan, ss)
                              for chan in (R, G, B)]), 0., 1.)


def get_sdss_cutout(targetwcs, sdss, get_rawvals=False, bands='irg',
                    get_rawvals_only=False,
                    bandscales=dict(z=1.0, i=1.0, r=1.3, g=2.5)):
    """Build an SDSS color cutout resampled onto *targetwcs*.

    For each band, every nearby SDSS frame is resampled onto the target pixel
    grid and overlapping contributions are averaged. The three bands are then
    multiplied by *bandscales* and arcsinh-stretched into R, G, B.

    Parameters
    ----------
    targetwcs : WCS-like
        Target grid. It must provide radec_center(), radius() (degrees),
        get_height(), get_width() and pixel_scale() (assumed arcsec/pixel).
    sdss : SDSS data-access object
        Used for the window_flist lookup, rerun lookup, and frame retrieval
        and reading (e.g. the DR10 object in this module's demo).
    get_rawvals : bool
        Also return the unscaled per-band averaged images.
    bands : str
        Band letters. For a color image, exactly three, mapped to R, G, B in
        order.
    get_rawvals_only : bool
        Return only the per-band averaged images and skip the color stretch.
    bandscales : dict
        Per-band multiplier applied before the stretch (never modified).

    Returns
    -------
    If *get_rawvals_only* is set: a list of (H, W) float32 arrays, one per
    band, with NaN wherever no frame contributed. Otherwise an (H, W, 3)
    array clipped to [0, 1], or ``(rgb, rawvals)`` when *get_rawvals* is set.

    Raises
    ------
    ValueError
        If a color image is requested and *bands* does not have exactly
        three entries.
    """
    if not get_rawvals_only and len(bands) != 3:
        raise ValueError('need exactly 3 bands for an RGB cutout, got %r'
                         % (bands,))

    ra, dec = targetwcs.radec_center()
    H, W = targetwcs.get_height(), targetwcs.get_width()

    fields = _sdss_fields_near(ra, dec, targetwcs.radius(), sdss)
    print(len(fields), 'run/camcol/fields in range')

    # Half-diagonal of the target image in SDSS pixels (0.396 arcsec/pixel).
    sz = np.hypot(H, W) / 2. * targetwcs.pixel_scale() / 0.396
    print('SDSS sz:', sz)
    # Add some margin for resampling.
    halfsize = int(sz) + 5

    # Resolve band indices up front so a bad band fails before frames are read.
    bandnums = [band_index(b) for b in bands]
    rgbims = [_coadd_band(targetwcs, sdss, fields, band, bandnum,
                          ra, dec, halfsize, H, W)
              for bandnum, band in zip(bandnums, bands)]

    if get_rawvals_only:
        return rgbims

    # Scaled copies; rgbims itself stays unscaled and doubles as rawvals.
    r, g, b = [im * bandscales[band] for im, band in zip(rgbims, bands)]
    rgb = _asinh_rgb(r, g, b)

    if get_rawvals:
        return rgb, rgbims
    return rgb
128
129
130
if __name__ == '__main__':
    # Demo: render two cutouts around RA=120, Dec=10 and save them as PNGs.
    import tempfile
    import matplotlib
    # Select the non-interactive backend before pylab is imported.
    matplotlib.use('Agg')
    import pylab as plt

    from .dr10 import DR10
    from astrometry.util.util import Tan

    scratch = tempfile.gettempdir()
    sdss = DR10(basedir=scratch)
    sdss.saveUnzippedFiles(scratch)

    # (width, height, pixel scale [arcsec/pixel; divided by 3600 for the CD
    # matrix], output filename) for each demo cutout.
    demos = [
        (100, 100, 1., 'cutout1.png'),
        (3000, 3000, 0.5, 'cutout2.png'),
    ]
    for W, H, pixscale, outfn in demos:
        cd = pixscale / 3600.
        targetwcs = Tan(120., 10., W / 2., H / 2., -cd, 0., 0., cd,
                        float(W), float(H))
        rgb = get_sdss_cutout(targetwcs, sdss)

        plt.clf()
        plt.imshow(rgb, interpolation='nearest', origin='lower')
        plt.savefig(outfn)
165
166