1# -*- coding: utf-8 -*-
2# Licensed under a 3-clause BSD style license - see LICENSE.rst
3
4import pytest
5import numpy as np
6from numpy import testing as npt
7
8from astropy.tests.helper import assert_quantity_allclose as assert_allclose
9
10from astropy import units as u
11
12from astropy.coordinates import matching
13
14from astropy.utils.compat.optional_deps import HAS_SCIPY  # noqa
15
16"""
17These are the tests for coordinate matching.
18
19Note that this requires scipy.
20"""
21
22
@pytest.mark.skipif(not HAS_SCIPY, reason="Requires scipy.")
def test_matching_function():
    """Check first- and second-nearest-neighbour matching without distances.

    Only ``match_coordinates_3d`` is exercised because it holds the actual
    matching implementation.
    """
    from astropy.coordinates.matching import match_coordinates_3d
    from astropy.coordinates import ICRS

    cmatch = ICRS([4, 2.1]*u.degree, [0, 0]*u.degree)
    ccatalog = ICRS([1, 2, 3, 4]*u.degree, [0, 0, 0, 0]*u.degree)

    # Nearest neighbour: the point at ra=4 coincides with catalog entry 3 and
    # the point at ra=2.1 sits 0.1 deg from catalog entry 1.
    nearest_idx, sep2d, sep3d = match_coordinates_3d(cmatch, ccatalog)
    npt.assert_array_equal(nearest_idx, [3, 1])
    npt.assert_array_almost_equal(sep2d.degree, [0, 0.1])
    assert sep3d.value[0] == 0

    # Second-nearest neighbour: both points pick catalog entry 2 (ra=3).
    second_idx, sep2d, sep3d = match_coordinates_3d(cmatch, ccatalog, nthneighbor=2)
    assert np.all(second_idx == 2)
    npt.assert_array_almost_equal(sep2d.degree, [1, 0.9])
    # No distances, so 3D separations are chords on the unit sphere (~1 deg).
    npt.assert_array_less(sep3d.value, 0.02)
41
42
@pytest.mark.skipif(not HAS_SCIPY, reason="Requires scipy.")
def test_matching_function_3d_and_sky():
    """Contrast 3D and on-sky matching for coordinates that carry distances.

    Because of the distances, the 3D nearest neighbours differ from the
    on-sky nearest neighbours for the same inputs.
    """
    from astropy.coordinates.matching import match_coordinates_3d, match_coordinates_sky
    from astropy.coordinates import ICRS

    cmatch = ICRS([4, 2.1]*u.degree, [0, 0]*u.degree, distance=[1, 5] * u.kpc)
    ccatalog = ICRS([1, 2, 3, 4]*u.degree, [0, 0, 0, 0]*u.degree,
                    distance=[1, 1, 1, 5] * u.kpc)

    # --- 3D matching ---
    idx_3d, sep2d_3d, sep3d_3d = match_coordinates_3d(cmatch, ccatalog)
    npt.assert_array_equal(idx_3d, [2, 3])
    assert_allclose(sep2d_3d, [1, 1.9] * u.deg)

    # Both matched pairs share a distance, so the chord length is close to
    # r * theta (theta in radians) for these small angles.
    sep3d_kpc_first = sep3d_3d[0].to_value(u.kpc)
    sep3d_kpc_second = sep3d_3d[1].to_value(u.kpc)
    assert np.abs(sep3d_kpc_first - np.radians(1)) < 1e-6
    assert np.abs(sep3d_kpc_second - 5*np.radians(1.9)) < 1e-5

    # --- on-sky matching ---
    idx_sky, sep2d_sky, sep3d_sky = match_coordinates_sky(cmatch, ccatalog)
    npt.assert_array_equal(idx_sky, [3, 1])
    assert_allclose(sep2d_sky, [0, 0.1] * u.deg)
    assert_allclose(sep3d_sky, [4, 4.0000019] * u.kpc)
63
64
@pytest.mark.parametrize('functocheck, args, defaultkdtname, bothsaved',
                         [(matching.match_coordinates_3d, [], 'kdtree_3d', False),
                          (matching.match_coordinates_sky, [], 'kdtree_sky', False),
                          (matching.search_around_3d, [1*u.kpc], 'kdtree_3d', True),
                          (matching.search_around_sky, [1*u.deg], 'kdtree_sky', False)
                         ])
@pytest.mark.skipif(not HAS_SCIPY, reason="Requires scipy.")
def test_kdtree_storage(functocheck, args, defaultkdtname, bothsaved):
    """Check how the ``storekdtree`` argument controls KD-tree caching.

    ``storekdtree=False`` caches nothing. Omitting it caches under
    ``defaultkdtname``. ``True`` caches under ``'kdtree'``, and a string
    caches under that string. The tree is checked on the catalog (second
    argument). When ``bothsaved`` is set, the first argument is also expected
    to receive a tree under a custom key.
    """
    from astropy.coordinates import ICRS

    def fresh_pair():
        # New objects each time so no cache carries over between checks.
        match_coords = ICRS([4, 2.1]*u.degree, [0, 0]*u.degree, distance=[1, 2]*u.kpc)
        catalog = ICRS([1, 2, 3, 4]*u.degree, [0, 0, 0, 0]*u.degree,
                       distance=[1, 2, 3, 4]*u.kpc)
        return match_coords, catalog

    # (extra keyword arguments, key expected in the catalog cache or None,
    #  keys that must be absent from the catalog cache)
    simple_cases = [
        ({'storekdtree': False}, None, ('kdtree', defaultkdtname)),
        ({}, defaultkdtname, ('kdtree',)),
        ({'storekdtree': True}, 'kdtree', (defaultkdtname,)),
    ]
    for extra_kwargs, present_key, absent_keys in simple_cases:
        cmatch, ccatalog = fresh_pair()
        functocheck(cmatch, ccatalog, *args, **extra_kwargs)
        if present_key is not None:
            assert present_key in ccatalog.cache
        for key in absent_keys:
            assert key not in ccatalog.cache

    # A custom string key replaces both the default and the 'kdtree' keys.
    custom_key = 'tislit_cheese'
    cmatch, ccatalog = fresh_pair()
    assert custom_key not in ccatalog.cache
    functocheck(cmatch, ccatalog, *args, storekdtree=custom_key)
    assert custom_key in ccatalog.cache
    assert defaultkdtname not in ccatalog.cache
    assert 'kdtree' not in ccatalog.cache
    if not bothsaved:
        assert custom_key not in cmatch.cache
    else:
        assert custom_key in cmatch.cache
        assert defaultkdtname not in cmatch.cache
        assert 'kdtree' not in cmatch.cache

    # Plant a bogus object under the custom key. The function should try to
    # *use* it as a KD-tree and fail, which shows the cache is really read.
    ccatalog.cache[custom_key] = 1
    cmatch.cache[custom_key] = 1
    with pytest.raises(TypeError) as excinfo:
        functocheck(cmatch, ccatalog, *args, storekdtree=custom_key)
    assert 'KD' in excinfo.value.args[0]
114
115
@pytest.mark.skipif(not HAS_SCIPY, reason="Requires scipy.")
def test_python_kdtree(monkeypatch):
    """Sky matching warns when ``scipy.spatial.cKDTree`` is unavailable.

    The warning refers to the C-based KD tree. Presumably matching then falls
    back to a non-C implementation; this test only checks that the call
    completes and emits the warning.
    """
    from astropy.coordinates import ICRS

    match_coords = ICRS([4, 2.1]*u.degree, [0, 0]*u.degree, distance=[1, 2]*u.kpc)
    catalog = ICRS([1, 2, 3, 4]*u.degree, [0, 0, 0, 0]*u.degree,
                   distance=[1, 2, 3, 4]*u.kpc)

    # Hide the C implementation; the monkeypatch fixture restores it afterwards.
    monkeypatch.delattr("scipy.spatial.cKDTree")
    expect_warning = pytest.warns(UserWarning, match=r'C-based KD tree not found')
    with expect_warning:
        matching.match_coordinates_sky(match_coords, catalog)
126
127
@pytest.mark.skipif(not HAS_SCIPY, reason="Requires scipy.")
def test_matching_method():
    """``SkyCoord.match_to_catalog_*`` must agree with ``match_coordinates_*``."""
    from astropy.coordinates.matching import match_coordinates_3d, match_coordinates_sky
    from astropy.coordinates import ICRS, SkyCoord
    from astropy.utils import NumpyRNGContext

    # Fixed seed so the random positions are reproducible.
    with NumpyRNGContext(987654321):
        cmatch = ICRS(np.random.rand(20) * 360.*u.degree,
                      (np.random.rand(20) * 180. - 90.)*u.degree)
        ccatalog = ICRS(np.random.rand(100) * 360. * u.degree,
                        (np.random.rand(100) * 180. - 90.)*u.degree)

    def assert_same_match(from_method, from_function):
        idx_m, sep2d_m, sep3d_m = from_method
        idx_f, sep2d_f, sep3d_f = from_function
        npt.assert_array_equal(idx_m, idx_f)
        assert_allclose(sep2d_m, sep2d_f)
        assert_allclose(sep3d_m, sep3d_f)

    assert_same_match(SkyCoord(cmatch).match_to_catalog_3d(ccatalog),
                      match_coordinates_3d(cmatch, ccatalog))

    # Without distances this is equivalent to the 3D case; it mainly checks
    # that the sky method works too.
    sky_via_method = SkyCoord(cmatch).match_to_catalog_sky(ccatalog)
    assert_same_match(sky_via_method, match_coordinates_sky(cmatch, ccatalog))

    sky_idx, sky_sep2d, sky_sep3d = sky_via_method
    assert len(sky_idx) == len(sky_sep2d) == len(sky_sep3d) == 20
156
157
def _assert_empty_search_result(result, distunit):
    """Assert that a ``search_around_*`` result contains no matches.

    Parameters
    ----------
    result : tuple
        The ``(idx1, idx2, d2d, d3d)`` tuple returned by ``search_around_sky``
        or ``search_around_3d``.
    distunit : `~astropy.units.UnitBase`
        The unit expected on the 3D separations ``d3d``.
    """
    idx1, idx2, d2d, d3d = result
    assert idx1.size == idx2.size == d2d.size == d3d.size == 0
    # Even empty index arrays must be integer-typed so callers can index with them.
    assert idx1.dtype == idx2.dtype == int
    assert d2d.unit == u.deg
    assert d3d.unit == distunit


@pytest.mark.skipif(not HAS_SCIPY, reason="Requires scipy")
def test_search_around():
    """Check ``search_around_sky`` and ``search_around_3d``.

    Covers ordinary matches, searches that find nothing (#4877), empty input
    arrays (#4875) and inputs without distance units.
    """
    from astropy.coordinates import ICRS, SkyCoord
    from astropy.coordinates.matching import search_around_sky, search_around_3d

    coo1 = ICRS([4, 2.1]*u.degree, [0, 0]*u.degree, distance=[1, 5] * u.kpc)
    coo2 = ICRS([1, 2, 3, 4]*u.degree, [0, 0, 0, 0]*u.degree, distance=[1, 1, 1, 5] * u.kpc)

    # On-sky searches with a generous and a tight separation limit.
    idx1_1deg, idx2_1deg, d2d_1deg, _ = search_around_sky(coo1, coo2, 1.01*u.deg)
    idx1_0p05deg, idx2_0p05deg, _, _ = search_around_sky(coo1, coo2, 0.05*u.deg)

    assert list(zip(idx1_1deg, idx2_1deg)) == [(0, 2), (0, 3), (1, 1), (1, 2)]
    assert d2d_1deg[0] == 1.0*u.deg
    assert_allclose(d2d_1deg, [1, 0, .1, .9]*u.deg)

    assert list(zip(idx1_0p05deg, idx2_0p05deg)) == [(0, 3)]

    # 3D searches with a generous and a tight distance limit.
    idx1_1kpc, idx2_1kpc, _, _ = search_around_3d(coo1, coo2, 1*u.kpc)
    idx1_sm, idx2_sm, d2d_sm, _ = search_around_3d(coo1, coo2, 0.05*u.kpc)

    assert list(zip(idx1_1kpc, idx2_1kpc)) == [(0, 0), (0, 1), (0, 2), (1, 3)]
    assert list(zip(idx1_sm, idx2_sm)) == [(0, 1), (0, 2)]
    assert_allclose(d2d_sm, [2, 1]*u.deg)

    # Test for the non-matches, #4877
    coo1 = ICRS([4.1, 2.1]*u.degree, [0, 0]*u.degree, distance=[1, 5] * u.kpc)
    _assert_empty_search_result(search_around_sky(coo1, coo2, 1*u.arcsec), u.kpc)
    _assert_empty_search_result(search_around_3d(coo1, coo2, 1*u.m), u.kpc)

    # Test when one or both of the coordinate arrays is empty, #4875
    empty = ICRS(ra=[] * u.degree, dec=[] * u.degree, distance=[] * u.kpc)
    for first, second in [(empty, coo2), (coo1, empty), (empty, empty[:])]:
        _assert_empty_search_result(search_around_sky(first, second, 1*u.arcsec), u.kpc)
    for first, second in [(empty, coo2), (coo1, empty), (empty, empty[:])]:
        _assert_empty_search_result(search_around_3d(first, second, 1*u.m), u.kpc)

    # Test that input without distance units results in a
    # 'dimensionless_unscaled' unit
    cempty = SkyCoord(ra=[], dec=[], unit=u.deg)
    _assert_empty_search_result(search_around_3d(cempty, cempty[:], 1*u.m),
                                u.dimensionless_unscaled)
    # search_around_sky takes an *angular* limit; a length here (as previously
    # written) only went unnoticed because empty input returns early.
    _assert_empty_search_result(search_around_sky(cempty, cempty[:], 1*u.arcsec),
                                u.dimensionless_unscaled)
238
239
@pytest.mark.skipif(not HAS_SCIPY, reason="Requires scipy")
def test_search_around_scalar():
    """Searching around a *scalar* target must raise ValueError.

    The error message must name the specific search method, rather than
    being generic as reported in #3359.
    """
    from astropy.coordinates import SkyCoord, Angle

    cat = SkyCoord([1, 2, 3], [-30, 45, 8], unit="deg")
    target = SkyCoord('1.1 -30.1', unit="deg")

    for method_name in ('search_around_sky', 'search_around_3d'):
        search = getattr(cat, method_name)
        with pytest.raises(ValueError) as excinfo:
            search(target, Angle('2d'))
        assert method_name in str(excinfo.value)
257
258
@pytest.mark.skipif(not HAS_SCIPY, reason="Requires scipy")
def test_match_catalog_empty():
    """Matching accepts catalogs with >= 1 entry but rejects scalar or empty ones."""
    from astropy.coordinates import SkyCoord

    sc1 = SkyCoord(1, 2, unit="deg")
    cat0 = SkyCoord([], [], unit="deg")
    cat1 = SkyCoord([1.1], [2.1], unit="deg")
    cat2 = SkyCoord([1.1, 3], [2.1, 5], unit="deg")

    # Multi-entry and single-entry catalogs both work.
    for good_catalog in (cat2, cat1):
        sc1.match_to_catalog_sky(good_catalog)
        sc1.match_to_catalog_3d(good_catalog)

    # A scalar catalog and an empty catalog both raise, mentioning 'catalog'.
    for bad_catalog in (cat1[0], cat0):
        for match_method in (sc1.match_to_catalog_sky, sc1.match_to_catalog_3d):
            with pytest.raises(ValueError) as excinfo:
                match_method(bad_catalog)
            assert 'catalog' in str(excinfo.value)
287
288
@pytest.mark.skipif(not HAS_SCIPY, reason="Requires scipy")
@pytest.mark.filterwarnings(
    r'ignore:invalid value encountered in.*:RuntimeWarning')
def test_match_catalog_nan():
    """NaNs in either the catalog or the coordinates to match raise ValueError.

    The catalog case is checked for both a SkyCoord and a bare Galactic
    frame catalog.
    """
    from astropy.coordinates import SkyCoord, Galactic

    sc1 = SkyCoord(1, 2, unit="deg")
    sc_with_nans = SkyCoord(1, np.nan, unit="deg")

    cat = SkyCoord([1.1, 3], [2.1, 5], unit="deg")
    cat_with_nans = SkyCoord([1.1, np.nan], [2.1, 5], unit="deg")
    galcat_with_nans = Galactic([1.2, np.nan]*u.deg, [5.6, 7.8]*u.deg)

    # NaNs in the catalog.
    for nan_catalog in (cat_with_nans, galcat_with_nans):
        for match_method in (sc1.match_to_catalog_sky, sc1.match_to_catalog_3d):
            with pytest.raises(ValueError) as excinfo:
                match_method(nan_catalog)
            assert 'Catalog coordinates cannot contain' in str(excinfo.value)

    # NaNs in the coordinates being matched.
    for match_method in (sc_with_nans.match_to_catalog_sky,
                         sc_with_nans.match_to_catalog_3d):
        with pytest.raises(ValueError) as excinfo:
            match_method(cat)
        assert 'Matching coordinates cannot contain' in str(excinfo.value)
322
323
@pytest.mark.skipif(not HAS_SCIPY, reason="Requires scipy")
def test_match_catalog_nounit():
    """On-sky matching of unitless Cartesian coordinates gives dimensionless 3D separations."""
    from astropy.coordinates.matching import match_coordinates_sky
    from astropy.coordinates import ICRS, CartesianRepresentation

    # Components are given as x, y, z rows. The catalog's z row has two
    # entries, presumably broadcast into the points (1, 2, 4) and (1, 2, 5)
    # (verify against CartesianRepresentation).
    single = ICRS([[1], [2], [3]], representation_type=CartesianRepresentation)
    catalog = ICRS([[1], [2], [4, 5]], representation_type=CartesianRepresentation)

    _, _, sep3d = match_coordinates_sky(single, catalog)
    assert_allclose(sep3d, [1]*u.dimensionless_unscaled)
333