1# This file is part of libkd.
2# Licensed under a 3-clause BSD style license - see LICENSE
3from __future__ import print_function
4from astrometry.libkd import spherematch_c
5
6from astrometry.util.starutil_numpy import radectoxyz, deg2dist, dist2deg, distsq2deg
7
8import numpy as np
9
def match_xy(x1,y1, x2,y2, R, **kwargs):
    '''
    Like match_radec, except for plain old 2-D points.

    Stacks (x1,y1) and (x2,y2) into N x 2 point arrays and matches them
    within Euclidean distance *R* using match(); any extra keyword
    arguments are forwarded to match().

    Returns (I, J, d): indices into the first point set, indices into the
    second point set, and the distances returned by match().
    '''
    pts1 = np.vstack((x1, y1)).T
    pts2 = np.vstack((x2, y2)).T
    pairs, dists = match(pts1, pts2, R, **kwargs)
    return pairs[:, 0], pairs[:, 1], dists
16
17# Copied from "celestial.py" by Sjoert van Velzen.
def match_radec(ra1, dec1, ra2, dec2, radius_in_deg, notself=False,
                nearest=False, indexlist=False, count=False):
    '''
    Cross-matches numpy arrays of RA,Dec points.

    Behaves like spherematch.pro of IDL.

    Parameters
    ----------
    ra1, dec1, ra2, dec2 : numpy arrays, or scalars.
        RA,Dec in degrees of points to match.

    radius_in_deg : float
        Search radius in degrees.

    notself : boolean
        If True, avoids returning 'identity' matches;
        ASSUMES that ra1,dec1 == ra2,dec2.

    nearest : boolean
        If True, returns only the nearest match in *(ra2,dec2)*
        for each point in *(ra1,dec1)*.

    indexlist : boolean
        If True, returns a list of length *len(ra1)*, containing *None*
        or a list of ints of matched points in *ra2,dec2*.
        Ignored when *nearest* is True.

    count : boolean
        Only used when *nearest* is True: additionally returns the number
        of *(ra2,dec2)* points within the search radius of each matched
        point.

    Returns
    -------
    m1 : numpy array of integers
        Indices into the *ra1,dec1* arrays of matching points.
    m2 : numpy array of integers
        Same, but for *ra2,dec2*.
    d12 : numpy array, float
        Distance, in degrees, between the matching points.
    counts : numpy array of integers
        Only returned when both *nearest* and *count* are True.

    If *indexlist* is True (and *nearest* is False), the list described
    above is returned instead.
    '''
    # Convert to coordinates on the unit sphere
    xyz1 = radectoxyz(ra1, dec1)
    # Identity (not element-wise equality) test: reuse xyz1 only when the
    # caller passed the very same arrays, e.g. when self-matching.
    if ra1 is ra2 and dec1 is dec2:
        xyz2 = xyz1
    else:
        xyz2 = radectoxyz(ra2, dec2)
    # Angular radius -> Euclidean distance between points on the unit sphere.
    r = deg2dist(radius_in_deg)

    extra = ()
    if nearest:
        # Arguments are passed as (xyz2, xyz1): nearest() finds, for each
        # point of its second argument, the nearest point of its first.
        X = _nearest_func(xyz2, xyz1, r, notself=notself, count=count)
        if not count:
            (inds, dists2) = X
            # Negative indices mark points with no neighbour within r.
            I = np.flatnonzero(inds >= 0)
            J = inds[I]
            d = distsq2deg(dists2[I])
        else:
            J, I, dists2, counts = X
            # Third element is a squared distance (trees_match also takes
            # its sqrt); convert to degrees like the non-count branch.
            d = distsq2deg(dists2)
            extra = (counts,)
    else:
        X = match(xyz1, xyz2, r, notself=notself, indexlist=indexlist)
        if indexlist:
            return X
        (inds, dists) = X
        dist_in_deg = dist2deg(dists)
        I, J = inds[:, 0], inds[:, 1]
        d = dist_in_deg[:, 0]

    return (I, J, d) + extra
90
91
def cluster_radec(ra, dec, R, singles=False):
    '''
    Finds connected groups of objects in RA,Dec space.

    Two objects are linked if match_radec() pairs them using radius *R*;
    a group is a connected component of those links.

    Parameters
    ----------
    ra, dec : numpy arrays
        RA,Dec in degrees.
    R : float
        Linking radius in degrees (passed to match_radec).
    singles : boolean
        If True, also returns the indices of singletons.

    Returns
    -------
    groups : list of lists of ints
        Indices of the objects in each connected group,
        EXCLUDING singletons.
    S : numpy array of ints
        Only when *singles* is True: indices of objects that are in no group.
    '''
    I, J, _ = match_radec(ra, dec, ra, dec, R, notself=True)

    # 'mgroups' maps each index in a group to the (shared) list object
    # holding that group's members.
    mgroups = {}
    # 'ugroups' is a list of the unique groups
    ugroups = []

    for i, j in zip(I, J):
        gi = mgroups.get(i)
        gj = mgroups.get(j)
        if gi is not None and gj is not None:
            # Already in the same group?  All members of a group share one
            # list object, so an identity test suffices; this avoids the
            # element-by-element comparison that '==' performs on lists.
            if gi is gj:
                continue
            # Different (hence disjoint) groups: merge them.
            merged = gi + gj
            for k in merged:
                mgroups[k] = merged
            ugroups.remove(gi)
            ugroups.remove(gj)
            ugroups.append(merged)
        elif gi is not None:
            # Add j to i's group
            gi.append(j)
            mgroups[j] = gi
        elif gj is not None:
            # Add i to j's group
            gj.append(i)
            mgroups[i] = gj
        else:
            # Create a new group
            group = [i, j]
            mgroups[i] = group
            mgroups[j] = group
            ugroups.append(group)

    if singles:
        S = np.ones(len(ra), bool)
        for g in ugroups:
            S[np.array(g)] = False
        S = np.flatnonzero(S)
        return ugroups, S

    return ugroups
153
154
155
156
157
158def _cleaninputs(x1, x2):
159    fx1 = x1.astype(np.float64)
160    if x2 is x1:
161        fx2 = fx1
162    else:
163        fx2 = x2.astype(np.float64)
164    (N1,D1) = fx1.shape
165    (N2,D2) = fx2.shape
166    if D1 != D2:
167        raise ValueError('Arrays must have the same dimensionality')
168    return (fx1,fx2)
169
def _buildtrees(x1, x2):
    '''
    Builds kd-trees over two point sets after cleaning them with
    _cleaninputs().  When the cleaned arrays are the same object (i.e. the
    caller passed the same array twice), one tree is built and returned
    for both.
    '''
    clean1, clean2 = _cleaninputs(x1, x2)
    tree1 = spherematch_c.KdTree(clean1)
    tree2 = tree1 if clean2 is clean1 else spherematch_c.KdTree(clean2)
    return (tree1, tree2)
178
def match(x1, x2, radius, notself=False, permuted=True, indexlist=False):
    '''
    ::

        (indices,dists) = match(x1, x2, radius):

    Or::

        inds = match(x1, x2, radius, indexlist=True):

    Returns the indices (Mx2 int array) and distances (Mx1 float
    array) between points in *x1* and *x2* that are within *radius*
    Euclidean distance of each other, where M is the number of matches.

    *x1* is N1xD and *x2* is N2xD.  *x1* and *x2* can be the same
    array.  Dimensions D above 5-10 will probably not run faster than
    naive.

    Despite the name of this package, the arrays x1 and x2 need not be
    celestial positions; in particular, there is no RA wrapping at 0,
    and no special handling at the poles.  If you want to match
    celestial coordinates like RA,Dec, see the match_radec function.

    If *indexlist* is True, the return value is a python list with one
    element per data point in the first tree; that element is a python
    list containing the indices of points matched in the second tree.

    The *indices* return value has a row for each match; the matched
    points are:
    x1[indices[:,0],:]
    and
    x2[indices[:,1],:]

    This function doesn't know about spherical coordinates -- it just
    searches for matches in n-dimensional space.

    >>> from astrometry.util.starutil_numpy import *
    >>> from astrometry.libkd import spherematch
    >>> # RA,Dec in degrees
    >>> ra1  = array([  0,  1, 2, 3, 4, 359,360])
    >>> dec1 = array([-90,-89,-1, 0, 1,  89, 90])
    >>> # xyz: N x 3 array: unit vectors
    >>> xyz1 = radectoxyz(ra1, dec1)
    >>> ra2  = array([ 45,   1,  4, 4, 4,  0,  1])
    >>> dec2 = array([-89, -88, -1, 0, 2, 89, 89])
    >>> xyz2 = radectoxyz(ra2, dec2)
    >>> # The 'radius' is now distance between points on the unit sphere --
    >>> # for small angles, this is ~ angular distance in radians.  You can use
    >>> # the function:
    >>> radius_in_deg = 2.
    >>> r = sqrt(deg2distsq(radius_in_deg))
    >>> (inds,dists) = spherematch.match(xyz1, xyz2, r)
    >>> # Now *inds* is an Mx2 array of the matching indices,
    >>> # and *dists* the distances between them:
    >>> #  eg,  sqrt(sum((xyz1[inds[:,0],:] - xyz2[inds[:,1],:])**2, axis=1)) = dists
    >>> print(inds)
    [[0 0]
     [1 0]
     [1 1]
     [2 2]
     [3 2]
     [3 3]
     [4 3]
     [4 4]
     [5 5]
     [6 5]
     [5 6]
     [6 6]]
    >>> print(sqrt(sum((xyz1[inds[:,0],:] - xyz2[inds[:,1],:])**2, axis=1)))
    [ 0.01745307  0.01307557  0.01745307  0.0348995   0.02468143  0.01745307
      0.01745307  0.01745307  0.0003046   0.01745307  0.00060917  0.01745307]
    >>> print(dists[:,0])
    [ 0.01745307  0.01307557  0.01745307  0.0348995   0.02468143  0.01745307
      0.01745307  0.01745307  0.0003046   0.01745307  0.00060917  0.01745307]
    >>> print(vstack((ra1[inds[:,0]], dec1[inds[:,0]], ra2[inds[:,1]], dec2[inds[:,1]])).T)
    [[  0 -90  45 -89]
     [  1 -89  45 -89]
     [  1 -89   1 -88]
     [  2  -1   4  -1]
     [  3   0   4  -1]
     [  3   0   4   0]
     [  4   1   4   0]
     [  4   1   4   2]
     [359  89   0  89]
     [360  90   0  89]
     [359  89   1  89]
     [360  90   1  89]]

    Parameters
    ----------
    x1 : numpy array, float, shape N1 x D
        First array of points to match

    x2 : numpy array, float, shape N2 x D
        Second array of points to match

    radius : float
        Scalar Euclidean distance to match

    notself : boolean
        If True, avoids returning 'identity' matches (meant for the case
        where *x1* and *x2* are the same array).

    permuted : boolean
        Passed through to spherematch_c.match / match2.
        NOTE(review): semantics are defined by the C extension; confirm
        there before relying on a non-default value.

    indexlist : boolean
        If True, return per-point index lists instead (see below).

    Returns
    -------
    indices : numpy array, integers, shape M x 2, for M matches
        The array of matching indices; *indices[:,0]* are indices in *x1*,
        *indices[:,1]* are indices in *x2*.

    dists : numpy array, floats, shape M x 1, for M matches
        The distances between matched points.

    If *indexlist* is *True*:

    indices : list of lists of integers
        The list of matching indices.  One list element per *x1* element,
        containing a list of matching indices in *x2*.

    '''
    (kd1, kd2) = _buildtrees(x1, x2)
    if indexlist:
        # match2 produces the per-point index lists directly.
        return spherematch_c.match2(kd1, kd2, radius, notself, permuted)
    (inds, dists) = spherematch_c.match(kd1, kd2, radius, notself, permuted)
    return (inds, dists)
302
def match_naive(x1, x2, radius, notself=False):
    ''' Does the same thing as match(), but the straight-forward slow
    way: every point of *x1* is compared against every point of *x2*.
    Not very fair as a speed comparison, but useful to convince
    yourself that match() does the right thing.

    Parameters
    ----------
    x1 : array-like, shape N1 x D
        First array of points.
    x2 : array-like, shape N2 x D
        Second array of points.
    radius : float
        Pairs strictly closer than *radius* (Euclidean) are reported.
    notself : boolean
        If True, pairs with equal indices (i1 == i2) are skipped; meant
        for the case where *x1* and *x2* are the same array.

    Returns
    -------
    inds : numpy array of ints, shape M x 2
        Matched index pairs, ordered by x1 index, then by x2 index.
    dists : numpy array of floats, shape M x 1
        Euclidean distances of the pairs, laid out like match()'s dists.

    Raises
    ------
    ValueError
        If the inputs are not compatible N x D arrays (via _cleaninputs).
    '''
    (fx1, fx2) = _cleaninputs(x1, x2)
    r2 = radius ** 2
    inds = []
    dists = []
    for i1 in range(len(fx1)):
        # Squared distances from point i1 to every point of x2.
        d2 = np.sum((fx2 - fx1[i1]) ** 2, axis=1)
        hits = np.flatnonzero(d2 < r2)
        if notself:
            hits = hits[hits != i1]
        inds.extend((i1, i2) for i2 in hits)
        dists.extend(np.sqrt(d2[hits]))
    # reshape keeps well-defined shapes even when there are no matches.
    inds = np.array(inds, dtype=int).reshape(-1, 2)
    dists = np.array(dists, dtype=float).reshape(-1, 1)
    return (inds, dists)
325
def nearest(x1, x2, maxradius, notself=False, count=False):
    '''
    For each point in x2, returns the index of the nearest point in x1,
    if there is a point within 'maxradius'.

    (Note, this may be backward from what you want/expect!)

    With *count* set, the search is delegated to spherematch_c.nearest2
    (which also receives *count*); otherwise to spherematch_c.nearest.
    The C routine's result is returned unchanged.
    '''
    tree1, tree2 = _buildtrees(x1, x2)
    if not count:
        return spherematch_c.nearest(tree1, tree2, maxradius, notself)
    return spherematch_c.nearest2(tree1, tree2, maxradius, notself, count)

# Module-level alias so functions that take a parameter named 'nearest'
# (e.g. match_radec) can still call this function.
_nearest_func = nearest
340
def tree_build_radec(ra=None, dec=None, xyz=None):
    '''
    Builds a kd-tree given *RA,Dec* or unit-sphere *xyz* coordinates.

    Parameters
    ----------
    ra, dec : array-like, degrees
        *ra* must be one-dimensional; *dec* must broadcast against it.
        If *ra* is given, the points are converted to unit vectors and
        *xyz* is ignored.
    xyz : numpy array, N x 3
        Unit-sphere coordinates, used when *ra* is None.

    Returns
    -------
    kd : spherematch_c.KdTree

    Raises
    ------
    ValueError
        If *ra* is given without *dec*, if neither *ra* nor *xyz* is
        given, or if *ra* is not one-dimensional.
    '''
    if ra is not None:
        if dec is None:
            raise ValueError('tree_build_radec: dec is required when ra is given')
        # asarray lets callers pass lists as well as numpy arrays.
        ra_rad = np.deg2rad(np.asarray(ra, dtype=float))
        dec_rad = np.deg2rad(np.asarray(dec, dtype=float))
        (N,) = ra_rad.shape
        cosd = np.cos(dec_rad)
        xyz = np.zeros((N, 3))
        xyz[:, 0] = cosd * np.cos(ra_rad)
        xyz[:, 1] = cosd * np.sin(ra_rad)
        xyz[:, 2] = np.sin(dec_rad)
    elif xyz is None:
        raise ValueError('tree_build_radec: either ra,dec or xyz must be given')
    kd = spherematch_c.KdTree(xyz)
    return kd
354
def tree_build(X, nleaf=16, bbox=True, split=False):
    '''
    Builds a kd-tree given a numpy array of Euclidean points.

    Parameters
    ----------
    X: numpy array of shape (N,D)
        The points to index.
    nleaf, bbox, split:
        Tree-construction options, passed straight through to
        spherematch_c.KdTree (see the extension for their exact meaning).

    Returns
    -------
    kd: spherematch_c.KdTree
        The kd-tree object.  Its methods (search, write, permute, ...)
        can be called directly.
    '''
    return spherematch_c.KdTree(X, nleaf=nleaf, bbox=bbox, split=split)
370
def tree_free(kd):
    '''
    Backward-compatibility stub for freeing a tree made by *tree_build*:
    it only prints a notice, since explicit freeing is no longer needed.
    '''
    print('No need for tree_free')
377
def tree_save(kd, fn):
    '''
    Writes a kd-tree to the given filename.

    Deprecated: prints a notice, then delegates to ``kd.write(fn)`` and
    returns its result.
    '''
    print('Deprecated tree_save()')
    result = kd.write(fn)
    return result
386
def tree_open(fn, treename=None):
    '''
    Reads a kd-tree from the given filename.

    When *treename* is given it is passed to spherematch_c.KdTree as a
    second argument (presumably selecting a named tree inside the file).
    '''
    ctor_args = (fn,) if treename is None else (fn, treename)
    return spherematch_c.KdTree(*ctor_args)
395
def tree_close(kd):
    '''
    Backward-compatibility stub for closing a tree opened with
    *tree_open*: it only prints a notice, as closing is no longer needed.
    '''
    print('No need for tree_close')
402
def tree_search(kd, pos, radius, getdists=False, sortdists=False):
    '''
    Searches the given kd-tree for points within *radius* of the given
    position *pos*.

    Thin wrapper around ``kd.search``; the two boolean flags are passed
    on as ints (0 or 1).
    '''
    flags = (int(getdists), int(sortdists))
    return kd.search(pos, radius, *flags)
410
def tree_search_radec(kd, ra, dec, radius, getdists=False, sortdists=False):
    '''
    Searches *kd* (a tree of unit-sphere coordinates) for points within
    *radius* of the position (*ra*, *dec*), via tree_search().

    ra,dec in degrees
    radius in degrees
    '''
    ra_rad = np.deg2rad(ra)
    dec_rad = np.deg2rad(dec)
    cos_dec = np.cos(dec_rad)
    unit_vec = np.array([cos_dec * np.cos(ra_rad),
                         cos_dec * np.sin(ra_rad),
                         np.sin(dec_rad)])
    # Angular radius -> Euclidean distance on the unit sphere.
    chord = deg2dist(radius)
    return tree_search(kd, unit_vec, chord,
                       getdists=getdists, sortdists=sortdists)
422
def trees_match(kd1, kd2, radius, nearest=False, notself=False,
                permuted=True, count=False):
    '''
    Runs rangesearch or nearest-neighbour matching on given kdtrees.

    'radius' is Euclidean distance.

    If 'nearest'=True, returns the nearest neighbour of each point in "kd1";
    ie, "I" will NOT contain duplicates, but "J" may.

    If 'count'=True (only used together with 'nearest'), also counts the
    number of objects within range as well as returning the nearest
    neighbor of each point in "kd1"; the return value becomes
    I,J,d,counts , counts a numpy array of ints.

    'notself' is passed to the underlying spherematch_c routine in both
    modes; 'permuted' is only used by the range search (nearest=False).

    Returns (I, J, d), where
      I are indices into kd1
      J are indices into kd2
      d are Euclidean distances (same units as 'radius')
      [counts is number of sources in range]

    >>> import numpy as np
    >>> X = np.array([[1, 2, 3, 6]]).T.astype(float)
    >>> Y = np.array([[1, 4, 4]]).T.astype(float)
    >>> kd1 = tree_build(X)
    >>> kd2 = tree_build(Y)
    >>> I,J,d = trees_match(kd1, kd2, 1.1, nearest=True)
    >>> print(I)
    [0 1 2]
    >>> print(J)
    [0 0 2]
    >>> print(d)
    [0. 1. 1.]
    >>> I,J,d,count = trees_match(kd1, kd2, 1.1, nearest=True, count=True)
    >>> print(I)
    [0 1 2]
    >>> print(J)
    [0 0 2]
    >>> print(d)
    [0. 1. 1.]
    >>> print(count)
    [1 1 2]
    '''
    if nearest:
        # Trees are passed swapped, so nearest2 yields
        # (indices into kd2, indices into kd1, squared distances[, counts]);
        # reorder to (I, J, ...) and convert squared distances to distances.
        rtn = spherematch_c.nearest2(kd2, kd1, radius, notself, count)
        return (rtn[1], rtn[0], np.sqrt(rtn[2])) + rtn[3:]
    (inds, dists) = spherematch_c.match(kd1, kd2, radius, notself, permuted)
    # match() returns an M x 2 index array and an M x 1 distance array.
    return (inds[:, 0], inds[:, 1], dists[:, 0])
478
def tree_permute(kd, I):
    '''Deprecated wrapper: prints a notice, then returns ``kd.permute(I)``.'''
    print('Unnecessary call to tree_permute(kd, I): use kd.permute(I)')
    permuted = kd.permute(I)
    return permuted
482
def tree_bbox(kd):
    '''Deprecated wrapper: prints a notice, then returns ``kd.bbox``.'''
    print('Unnecessary call to tree_bbox(kd): use kd.bbox')
    box = kd.bbox
    return box
486
def tree_n(kd):
    '''Deprecated wrapper: prints a notice, then returns ``kd.n``.'''
    print('Unnecessary call to tree_n(kd): use kd.n')
    npoints = kd.n
    return npoints
490
def tree_print(kd):
    '''Deprecated wrapper: prints a notice, then calls ``kd.print()``.'''
    print('Unnecessary call to tree_print(kd): use kd.print()')
    show = kd.print
    show()
494
def tree_data(kd, I):
    '''Deprecated wrapper: prints a notice, then returns ``kd.get_data(I)``.'''
    print('Unnecessary call to tree_data(kd, I): use kd.get_data(I)')
    data = kd.get_data(I)
    return data
498
if __name__ == '__main__':
    # Run the examples embedded in this module's docstrings.
    from doctest import testmod
    testmod()
502