1# Licensed under a 3-clause BSD style license - see LICENSE.rst
2
3import copy
4
5import numpy as np
6
7import astropy.units as u
8from astropy.coordinates import CartesianRepresentation, SphericalRepresentation, ITRS
9from astropy.utils import unbroadcast
10
11from .wcs import WCS, WCSSUB_LATITUDE, WCSSUB_LONGITUDE
12
# Names whose docstring examples should not be executed as doctests
# (presumably read by the project's doctest runner -- TODO confirm).
__doctest_skip__ = ['wcs_to_celestial_frame', 'celestial_frame_to_wcs']

# Public names exported by this module.
__all__ = ['obsgeo_to_frame', 'add_stokes_axis_to_wcs',
           'celestial_frame_to_wcs', 'wcs_to_celestial_frame',
           'proj_plane_pixel_scales', 'proj_plane_pixel_area',
           'is_proj_plane_distorted', 'non_celestial_pixel_scales',
           'skycoord_to_pixel', 'pixel_to_skycoord',
           'custom_wcs_to_frame_mappings', 'custom_frame_to_wcs_mappings',
           'pixel_to_pixel', 'local_partial_pixel_derivatives',
           'fit_wcs_from_points']
23
24
def add_stokes_axis_to_wcs(wcs, add_before_ind):
    """
    Build a WCS with one extra Stokes axis that is independent of all others.

    Parameters
    ----------
    wcs : `~astropy.wcs.WCS`
        The WCS the Stokes axis is added to. It is not modified; the result
        comes from ``wcs.sub(...)``.
    add_before_ind : int
        Zero-based position in front of which the new axis is inserted.
        Use ``0`` to put it first and ``wcs.wcs.naxis`` to put it last.

    Returns
    -------
    `~astropy.wcs.WCS`
        A new `~astropy.wcs.WCS` instance with the additional axis, whose
        ``ctype`` and ``cname`` are both set to ``'STOKES'``.
    """
    # Existing axes are referred to by their 1-based number; the 0 entry is
    # the placeholder that ends up at the position of the new axis.
    axis_order = list(range(1, wcs.wcs.naxis + 1))
    axis_order.insert(add_before_ind, 0)

    extended = wcs.sub(axis_order)
    for attribute in ('ctype', 'cname'):
        getattr(extended.wcs, attribute)[add_before_ind] = 'STOKES'
    return extended
50
51
def _wcs_to_celestial_frame_builtin(wcs):
    """
    Translate the celestial axes of ``wcs`` into a built-in coordinate frame.

    Parameters
    ----------
    wcs : `~astropy.wcs.WCS`
        The WCS to inspect.

    Returns
    -------
    frame or None
        An instance of ``FK4``, ``FK4NoETerms``, ``FK5``, ``ICRS``,
        ``Galactic`` or ``ITRS``; `None` if the WCS has no longitude/latitude
        pair or the pair is not recognised.
    """
    # Import astropy.coordinates here to avoid circular imports
    from astropy.coordinates import (FK4, FK5, ICRS, ITRS, FK4NoETerms,
                                     Galactic, SphericalRepresentation)
    # Import astropy.time here otherwise setup.py fails before extensions are compiled
    from astropy.time import Time

    params = wcs.wcs

    # Without both a longitude and a latitude axis there is nothing to map.
    if params.lng == -1 or params.lat == -1:
        return None

    radesys = params.radesys
    equinox = None if np.isnan(params.equinox) else params.equinox

    lon_code = params.ctype[params.lng][:4]
    lat_code = params.ctype[params.lat][:4]

    # Apply logic from FITS standard to determine the default radesys
    if radesys == '' and lon_code == 'RA--' and lat_code == 'DEC-':
        if equinox is None:
            radesys = "ICRS"
        else:
            radesys = "FK4" if equinox < 1984. else "FK5"

    if radesys == 'ICRS':
        return ICRS()

    # Equinox-dependent systems: frame class and the year format used to
    # interpret the numeric equinox (Besselian for FK4, Julian for FK5).
    equatorial = {
        'FK4': (FK4, 'byear'),
        'FK4-NO-E': (FK4NoETerms, 'byear'),
        'FK5': (FK5, 'jyear'),
    }
    if radesys in equatorial:
        frame_cls, year_format = equatorial[radesys]
        if equinox is not None:
            equinox = Time(equinox, format=year_format)
        return frame_cls(equinox=equinox)

    if lon_code == 'GLON' and lat_code == 'GLAT':
        return Galactic()

    if lon_code == 'TLON' and lat_code == 'TLAT':
        # The default representation for ITRS is cartesian, but for WCS
        # purposes, we need the spherical representation.
        return ITRS(representation_type=SphericalRepresentation,
                    obstime=params.dateobs or None)

    return None
108
109
def _celestial_frame_to_wcs_builtin(frame, projection='TAN'):
    """
    Create a two-dimensional WCS describing a built-in coordinate frame.

    Parameters
    ----------
    frame : coordinate frame instance
        The frame to describe.
    projection : str
        Projection code appended to both ``ctype`` entries.

    Returns
    -------
    `~astropy.wcs.WCS` or None
        A WCS with ``ctype`` (and, where applicable, ``radesys``, ``equinox``
        or ``dateobs``) set; `None` if ``frame`` is not one of the supported
        frame classes.
    """
    # Import astropy.coordinates here to avoid circular imports
    from astropy.coordinates import FK4, FK5, ICRS, ITRS, BaseRADecFrame, FK4NoETerms, Galactic

    # Create a 2-dimensional WCS
    wcs = WCS(naxis=2)

    if isinstance(frame, BaseRADecFrame):
        lon_code, lat_code = 'RA--', 'DEC-'
        if isinstance(frame, ICRS):
            wcs.wcs.radesys = 'ICRS'
        else:
            # (frame class, RADESYS value, attribute of ``frame.equinox``
            # giving the equinox as a year), checked in this order.
            equinox_frames = (
                (FK4NoETerms, 'FK4-NO-E', 'byear'),
                (FK4, 'FK4', 'byear'),
                (FK5, 'FK5', 'jyear'),
            )
            for frame_cls, radesys, year_attr in equinox_frames:
                if isinstance(frame, frame_cls):
                    wcs.wcs.radesys = radesys
                    wcs.wcs.equinox = getattr(frame.equinox, year_attr)
                    break
            else:
                # An RA/Dec frame this function does not know about
                return None
    elif isinstance(frame, Galactic):
        lon_code, lat_code = 'GLON', 'GLAT'
    elif isinstance(frame, ITRS):
        lon_code, lat_code = 'TLON', 'TLAT'
        wcs.wcs.radesys = 'ITRS'
        wcs.wcs.dateobs = frame.obstime.utc.isot
    else:
        return None

    wcs.wcs.ctype = [code + '-' + projection for code in (lon_code, lat_code)]

    return wcs
149
150
# Registries of conversion functions, each a list of "mapping sets" (lists of
# callables). The first set holds the built-in converter; the
# ``custom_wcs_to_frame_mappings`` / ``custom_frame_to_wcs_mappings`` context
# managers append a further set and pop it again on exit.
WCS_FRAME_MAPPINGS = [[_wcs_to_celestial_frame_builtin]]
FRAME_WCS_MAPPINGS = [[_celestial_frame_to_wcs_builtin]]
153
154
class custom_wcs_to_frame_mappings:
    """
    Context manager that temporarily registers extra WCS -> frame functions.

    The mapping set is appended to ``WCS_FRAME_MAPPINGS`` when the object is
    constructed and the most recent set is popped when the ``with`` block
    exits, whether or not an exception occurred.

    Parameters
    ----------
    mappings : callable or list of callable, optional
        A single function, or a collection of functions, each taking a WCS
        and returning a frame instance or `None`. If omitted (or `None`), a
        new empty set is registered.
    """

    def __init__(self, mappings=None):
        if mappings is None:
            # A fresh list per instance: a shared default list would end up
            # inside the global registry and could be modified through it.
            mappings = []
        elif callable(mappings):
            mappings = [mappings]
        WCS_FRAME_MAPPINGS.append(mappings)

    def __enter__(self):
        pass

    def __exit__(self, exc_type, exc_value, traceback):
        # Returns None, so any exception from the ``with`` body propagates.
        WCS_FRAME_MAPPINGS.pop()


# Backward-compatibility
custom_frame_mappings = custom_wcs_to_frame_mappings
170
171
class custom_frame_to_wcs_mappings:
    """
    Context manager that temporarily registers extra frame -> WCS functions.

    The mapping set is appended to ``FRAME_WCS_MAPPINGS`` when the object is
    constructed and the most recent set is popped when the ``with`` block
    exits, whether or not an exception occurred.

    Parameters
    ----------
    mappings : callable or list of callable, optional
        A single function, or a collection of functions, each taking a frame
        and a ``projection`` keyword and returning a WCS or `None`. If
        omitted (or `None`), a new empty set is registered.
    """

    def __init__(self, mappings=None):
        if mappings is None:
            # A fresh list per instance: a shared default list would end up
            # inside the global registry and could be modified through it.
            mappings = []
        elif callable(mappings):
            mappings = [mappings]
        FRAME_WCS_MAPPINGS.append(mappings)

    def __enter__(self):
        pass

    def __exit__(self, exc_type, exc_value, traceback):
        # Returns None, so any exception from the ``with`` body propagates.
        FRAME_WCS_MAPPINGS.pop()
183
184
def wcs_to_celestial_frame(wcs):
    """
    Return the coordinate frame matching the celestial part of a WCS.

    Every function registered in ``WCS_FRAME_MAPPINGS`` is tried in
    registration order; the first result that is not `None` is returned.

    Parameters
    ----------
    wcs : :class:`~astropy.wcs.WCS` instance
        The WCS to find the frame for

    Returns
    -------
    frame : :class:`~astropy.coordinates.baseframe.BaseCoordinateFrame` subclass instance
        An instance of a :class:`~astropy.coordinates.baseframe.BaseCoordinateFrame`
        subclass instance that best matches the specified WCS.

    Raises
    ------
    ValueError
        If none of the registered functions recognises the WCS.

    Notes
    -----

    To extend this function to frames not defined in astropy.coordinates, you
    can write your own function which should take a :class:`~astropy.wcs.WCS`
    instance and should return either an instance of a frame, or `None` if no
    matching frame was found. You can register this function temporarily with::

        >>> from astropy.wcs.utils import wcs_to_celestial_frame, custom_wcs_to_frame_mappings
        >>> with custom_wcs_to_frame_mappings(my_function):
        ...     wcs_to_celestial_frame(...)

    """
    # Lazy on purpose: functions after the first match are never called.
    candidates = (convert(wcs)
                  for mapping_set in WCS_FRAME_MAPPINGS
                  for convert in mapping_set)
    for candidate in candidates:
        if candidate is not None:
            return candidate

    raise ValueError("Could not determine celestial frame corresponding to "
                     "the specified WCS object")
221
222
def celestial_frame_to_wcs(frame, projection='TAN'):
    """
    Return the WCS object corresponding to a coordinate frame.

    Every function registered in ``FRAME_WCS_MAPPINGS`` is tried in
    registration order; the first result that is not `None` is returned.

    Note that the returned WCS object has only the elements corresponding to
    coordinate frames set (e.g. ctype, equinox, radesys).

    Parameters
    ----------
    frame : :class:`~astropy.coordinates.baseframe.BaseCoordinateFrame` subclass instance
        An instance of a :class:`~astropy.coordinates.baseframe.BaseCoordinateFrame`
        subclass instance for which to find the WCS
    projection : str
        Projection code to use in ctype, if applicable

    Returns
    -------
    wcs : :class:`~astropy.wcs.WCS` instance
        The corresponding WCS object

    Raises
    ------
    ValueError
        If none of the registered functions can build a WCS for ``frame``.

    Examples
    --------

    ::

        >>> from astropy.wcs.utils import celestial_frame_to_wcs
        >>> from astropy.coordinates import FK5
        >>> frame = FK5(equinox='J2010')
        >>> wcs = celestial_frame_to_wcs(frame)
        >>> wcs.to_header()
        WCSAXES =                    2 / Number of coordinate axes
        CRPIX1  =                  0.0 / Pixel coordinate of reference point
        CRPIX2  =                  0.0 / Pixel coordinate of reference point
        CDELT1  =                  1.0 / [deg] Coordinate increment at reference point
        CDELT2  =                  1.0 / [deg] Coordinate increment at reference point
        CUNIT1  = 'deg'                / Units of coordinate increment and value
        CUNIT2  = 'deg'                / Units of coordinate increment and value
        CTYPE1  = 'RA---TAN'           / Right ascension, gnomonic projection
        CTYPE2  = 'DEC--TAN'           / Declination, gnomonic projection
        CRVAL1  =                  0.0 / [deg] Coordinate value at reference point
        CRVAL2  =                  0.0 / [deg] Coordinate value at reference point
        LONPOLE =                180.0 / [deg] Native longitude of celestial pole
        LATPOLE =                  0.0 / [deg] Native latitude of celestial pole
        RADESYS = 'FK5'                / Equatorial coordinate system
        EQUINOX =               2010.0 / [yr] Equinox of equatorial coordinates


    Notes
    -----

    To extend this function to frames not defined in astropy.coordinates, you
    can write your own function which should take a
    :class:`~astropy.coordinates.baseframe.BaseCoordinateFrame` subclass
    instance and a projection (given as a string) and should return either a WCS
    instance, or `None` if the WCS could not be determined. You can register
    this function temporarily with::

        >>> from astropy.wcs.utils import celestial_frame_to_wcs, custom_frame_to_wcs_mappings
        >>> with custom_frame_to_wcs_mappings(my_function):
        ...     celestial_frame_to_wcs(...)

    """
    # Lazy on purpose: functions after the first match are never called.
    candidates = (build(frame, projection=projection)
                  for mapping_set in FRAME_WCS_MAPPINGS
                  for build in mapping_set)
    for candidate in candidates:
        if candidate is not None:
            return candidate

    raise ValueError("Could not determine WCS corresponding to the specified "
                     "coordinate frame.")
292
293
def proj_plane_pixel_scales(wcs):
    """
    Compute the pixel scale along each image axis in the projection plane.

    The scales are those of the image pixel at the ``CRPIX`` location once
    it is projected onto the "plane of intermediate world coordinates" as
    defined in
    `Greisen & Calabretta 2002, A&A, 395, 1061 <https://ui.adsabs.harvard.edu/abs/2002A%26A...395.1061G>`_.

    .. note::
        This function is concerned **only** about the transformation
        "image plane"->"projection plane" and **not** about the
        transformation "celestial sphere"->"projection plane"->"image plane".
        Therefore, this function ignores distortions arising due to
        non-linear nature of most projections.

    .. note::
        In order to compute the scales corresponding to celestial axes only,
        make sure that the input `~astropy.wcs.WCS` object contains
        celestial axes only, e.g., by passing in the
        `~astropy.wcs.WCS.celestial` WCS object.

    Parameters
    ----------
    wcs : `~astropy.wcs.WCS`
        A world coordinate system object.

    Returns
    -------
    scale : ndarray
        A vector (`~numpy.ndarray`) of projection plane increments
        corresponding to each pixel side (axis), computed as the Euclidean
        norm of each column of ``wcs.pixel_scale_matrix``. The units of the
        returned results are the same as the units of
        `~astropy.wcs.Wcsprm.cdelt`, `~astropy.wcs.Wcsprm.crval`, and
        `~astropy.wcs.Wcsprm.cd` for the celestial WCS and can be obtained by
        inquiring the value of `~astropy.wcs.Wcsprm.cunit` property of the
        input `~astropy.wcs.WCS` WCS object.

    See Also
    --------
    astropy.wcs.utils.proj_plane_pixel_area

    """
    scale_matrix = wcs.pixel_scale_matrix
    # Sum of squares down each column, accumulated in floating point
    column_sq_norms = np.sum(scale_matrix**2, axis=0, dtype=float)
    return np.sqrt(column_sq_norms)
336
337
def proj_plane_pixel_area(wcs):
    """
    Compute the projection-plane area of the pixel at ``CRPIX``.

    For a **celestial** WCS (see `astropy.wcs.WCS.celestial`) this is the
    area of the image pixel at the ``CRPIX`` location once it is projected
    onto the "plane of intermediate world coordinates" as defined in
    `Greisen & Calabretta 2002, A&A, 395, 1061 <https://ui.adsabs.harvard.edu/abs/2002A%26A...395.1061G>`_.

    .. note::
        This function is concerned **only** about the transformation
        "image plane"->"projection plane" and **not** about the
        transformation "celestial sphere"->"projection plane"->"image plane".
        Therefore, this function ignores distortions arising due to
        non-linear nature of most projections.

    .. note::
        In order to compute the area of pixels corresponding to celestial
        axes only, this function uses the `~astropy.wcs.WCS.celestial` WCS
        object of the input ``wcs``.  This is different from the
        `~astropy.wcs.utils.proj_plane_pixel_scales` function
        that computes the scales for the axes of the input WCS itself.

    Parameters
    ----------
    wcs : `~astropy.wcs.WCS`
        A world coordinate system object.

    Returns
    -------
    area : float
        Area (in the projection plane) of the pixel at ``CRPIX`` location,
        i.e. the absolute value of the determinant of the celestial pixel
        scale matrix. The units of the returned result are the same as the
        units of the `~astropy.wcs.Wcsprm.cdelt`,
        `~astropy.wcs.Wcsprm.crval`, and `~astropy.wcs.Wcsprm.cd` for the
        celestial WCS and can be obtained by inquiring the value of
        `~astropy.wcs.Wcsprm.cunit` property of the
        `~astropy.wcs.WCS.celestial` WCS object.

    Raises
    ------
    ValueError
        Pixel area is defined only for 2D pixels. Most likely the
        `~astropy.wcs.Wcsprm.cd` matrix of the `~astropy.wcs.WCS.celestial`
        WCS is not a square matrix of second order.

    Notes
    -----

    Depending on the application, square root of the pixel area can be used to
    represent a single pixel scale of an equivalent square pixel
    whose area is equal to the area of a generally non-square pixel.

    See Also
    --------
    astropy.wcs.utils.proj_plane_pixel_scales

    """
    celestial_scale_matrix = wcs.celestial.pixel_scale_matrix
    if celestial_scale_matrix.shape == (2, 2):
        determinant = np.linalg.det(celestial_scale_matrix)
        return np.abs(determinant)
    raise ValueError("Pixel area is defined only for 2D pixels.")
397
398
def is_proj_plane_distorted(wcs, maxerr=1.0e-5):
    r"""
    Tell whether square image pixels stay square in the projection plane.

    For a WCS returns `False` if square image (detector) pixels stay square
    when projected onto the "plane of intermediate world coordinates"
    as defined in
    `Greisen & Calabretta 2002, A&A, 395, 1061 <https://ui.adsabs.harvard.edu/abs/2002A%26A...395.1061G>`_.
    It will return `True` if transformation from image (detector) coordinates
    to the focal plane coordinates is non-orthogonal or if WCS contains
    non-linear (e.g., SIP) distortions.

    .. note::
        Since this function is concerned **only** about the transformation
        "image plane"->"focal plane" and **not** about the transformation
        "celestial sphere"->"focal plane"->"image plane",
        this function ignores distortions arising due to non-linear nature
        of most projections.

    Let's denote by *C* either the original or the reconstructed
    (from ``PC`` and ``CDELT``) CD matrix. `is_proj_plane_distorted`
    verifies that the transformation from image (detector) coordinates
    to the focal plane coordinates is orthogonal using the following
    check:

    .. math::
        \left \| \frac{C \cdot C^{\mathrm{T}}}
        {| det(C)|} - I \right \|_{\mathrm{max}} < \epsilon .

    Only the celestial part of ``wcs`` (``wcs.celestial``) is examined.

    Parameters
    ----------
    wcs : `~astropy.wcs.WCS`
        World coordinate system object

    maxerr : float, optional
        Accuracy to which the CD matrix, **normalized** such
        that :math:`|det(CD)|=1`, should be close to being an
        orthogonal matrix as described in the above equation
        (see :math:`\epsilon`).

    Returns
    -------
    distorted : bool
        Returns `True` if focal (projection) plane is distorted and `False`
        otherwise.

    """
    celestial = wcs.celestial

    # A non-orthogonal linear transformation settles the question by itself;
    # distortion components are only inspected otherwise.
    if not _is_cd_orthogonal(celestial.pixel_scale_matrix, maxerr):
        return True
    return _has_distortion(celestial)
447
448
def _is_cd_orthogonal(cd, maxerr):
    """
    Check whether a square matrix is a uniformly scaled orthogonal matrix.

    The matrix is normalised so that its determinant has unit magnitude and
    ``C . C^T`` is then compared with the identity matrix.

    Parameters
    ----------
    cd : ndarray
        The CD (or PC) matrix; must be two-dimensional, square and non-empty.
    maxerr : float
        Largest allowed absolute deviation of any element of the normalised
        ``C . C^T`` from the identity matrix.

    Returns
    -------
    bool
        `True` if the deviation is below ``maxerr``.

    Raises
    ------
    ValueError
        If ``cd`` is not a non-empty 2D square matrix or is singular.
    """
    shape = cd.shape
    if not (len(shape) == 2 and shape[0] == shape[1]):
        raise ValueError("CD (or PC) matrix must be a 2D square matrix.")

    size = shape[0]
    if size == 0:
        raise ValueError("CD (or PC) matrix must not be empty.")

    pixarea = np.abs(np.linalg.det(cd))
    if pixarea == 0.0:
        raise ValueError("CD (or PC) matrix is singular.")

    # For C = s * Q with Q orthogonal, C . C^T = s**2 * I and
    # |det(C)| = |s|**size, so the normalisation that maps C . C^T onto the
    # identity is |det(C)|**(2 / size). For a 2x2 matrix this is |det(C)|.
    #
    # NOTE: Technically, below we should use np.dot(cd, np.conjugate(cd.T))
    # However, I am not aware of complex CD/PC matrices...
    normalized = np.dot(cd, cd.T) / pixarea ** (2.0 / size)
    cd_unitary_err = np.amax(np.abs(normalized - np.eye(size)))

    return (cd_unitary_err < maxerr)
464
465
def non_celestial_pixel_scales(inwcs):
    """
    Calculate the pixel scale along each axis of a non-celestial WCS,
    for example one with mixed spectral and spatial axes.

    Parameters
    ----------
    inwcs : `~astropy.wcs.WCS`
        The world coordinate system object.

    Returns
    -------
    scale : `~astropy.units.Quantity`
        The absolute value of each diagonal element of the pixel scale
        matrix. The values are always multiplied by ``u.deg``.
        NOTE(review): degrees are attached whatever the axis type is --
        confirm against the ``cunit`` of the WCS before relying on the unit.

    Raises
    ------
    ValueError
        If the WCS is purely celestial, or if the pixel scale matrix has
        non-zero off-diagonal elements (rotated or otherwise coupled axes).
    """

    if inwcs.is_celestial:
        raise ValueError("WCS is celestial, use proj_plane_pixel_scales instead")

    pccd = inwcs.pixel_scale_matrix

    # A per-axis scale is only defined when the axes are uncoupled, i.e.
    # when every off-diagonal element is (numerically) zero.
    off_diagonal = pccd[~np.eye(*pccd.shape, dtype=bool)]
    if not np.allclose(off_diagonal, 0):
        raise ValueError("WCS is rotated, cannot determine consistent pixel scales")

    return np.abs(np.diagonal(pccd))*u.deg
491
492
def _has_distortion(wcs):
    """
    `True` if contains any SIP or image distortion components.

    The attributes ``cpdis1``, ``cpdis2``, ``det2im1``, ``det2im2`` and
    ``sip`` are checked in that order; any value other than `None` counts.
    """
    for component in ('cpdis1', 'cpdis2', 'det2im1', 'det2im2', 'sip'):
        if getattr(wcs, component) is not None:
            return True
    return False
499
500
501# TODO: in future, we should think about how the following two functions can be
502# integrated better into the WCS class.
503
def skycoord_to_pixel(coords, wcs, origin=0, mode='all'):
    """
    Convert a set of SkyCoord coordinates into pixels.

    Parameters
    ----------
    coords : `~astropy.coordinates.SkyCoord`
        The coordinates to convert.
    wcs : `~astropy.wcs.WCS`
        The WCS transformation to use.
    origin : int
        Whether to return 0 or 1-based pixel coordinates.
    mode : 'all' or 'wcs'
        Whether to do the transformation including distortions (``'all'``) or
        including only the core WCS transformation (``'wcs'``).

    Returns
    -------
    xp, yp : `numpy.ndarray`
        The pixel coordinates

    Raises
    ------
    ValueError
        If the WCS has distortions but is not 2-dimensional, if it has no
        celestial component, or if ``mode`` is not ``'all'`` or ``'wcs'``.

    See Also
    --------
    astropy.coordinates.SkyCoord.from_pixel
    """

    if _has_distortion(wcs):
        if wcs.naxis != 2:
            raise ValueError("Can only handle WCS with distortions for 2-dimensional WCS")

    # Keep only the celestial part of the axes, also re-orders lon/lat
    celestial = wcs.sub([WCSSUB_LONGITUDE, WCSSUB_LATITUDE])

    if celestial.naxis != 2:
        raise ValueError("WCS should contain celestial component")

    # Frame and units in which the WCS expects its world coordinates
    target_frame = wcs_to_celestial_frame(celestial)
    lon_unit = u.Unit(celestial.wcs.cunit[0])
    lat_unit = u.Unit(celestial.wcs.cunit[1])

    in_frame = coords.transform_to(target_frame)

    # Extract longitude and latitude. We first try and use lon/lat directly,
    # but if the representation is not spherical or unit spherical this will
    # fail. We should then force the use of the unit spherical
    # representation. We don't do that directly to make sure that we preserve
    # custom lon/lat representations if available.
    try:
        lon = in_frame.data.lon.to(lon_unit)
        lat = in_frame.data.lat.to(lat_unit)
    except AttributeError:
        spherical = in_frame.spherical
        lon = spherical.lon.to(lon_unit)
        lat = spherical.lat.to(lat_unit)

    # Pick the world -> pixel routine requested by ``mode``
    if mode == 'all':
        world_to_pixel = celestial.all_world2pix
    elif mode == 'wcs':
        world_to_pixel = celestial.wcs_world2pix
    else:
        raise ValueError("mode should be either 'all' or 'wcs'")

    xp, yp = world_to_pixel(lon.value, lat.value, origin)

    return xp, yp
570
571
def pixel_to_skycoord(xp, yp, wcs, origin=0, mode='all', cls=None):
    """
    Convert a set of pixel coordinates into a `~astropy.coordinates.SkyCoord`
    coordinate.

    Parameters
    ----------
    xp, yp : float or ndarray
        The coordinates to convert.
    wcs : `~astropy.wcs.WCS`
        The WCS transformation to use.
    origin : int
        Whether the pixel coordinates are 0 or 1-based.
    mode : 'all' or 'wcs'
        Whether to do the transformation including distortions (``'all'``) or
        including only the core WCS transformation (``'wcs'``).
    cls : class or None
        The class of object to create.  Should be a
        `~astropy.coordinates.SkyCoord` subclass.  If None, defaults to
        `~astropy.coordinates.SkyCoord`.

    Returns
    -------
    coords : `~astropy.coordinates.SkyCoord` subclass
        The celestial coordinates. Whatever ``cls`` type is.

    Raises
    ------
    ValueError
        If the WCS has distortions but is not 2-dimensional, if it has no
        celestial component, or if ``mode`` is not ``'all'`` or ``'wcs'``.

    See Also
    --------
    astropy.coordinates.SkyCoord.from_pixel
    """

    # Import astropy.coordinates here to avoid circular imports
    from astropy.coordinates import SkyCoord, UnitSphericalRepresentation

    # we have to do this instead of actually setting the default to SkyCoord
    # because importing SkyCoord at the module-level leads to circular
    # dependencies.
    coord_cls = SkyCoord if cls is None else cls

    if _has_distortion(wcs):
        if wcs.naxis != 2:
            raise ValueError("Can only handle WCS with distortions for 2-dimensional WCS")

    # Keep only the celestial part of the axes, also re-orders lon/lat
    celestial = wcs.sub([WCSSUB_LONGITUDE, WCSSUB_LATITUDE])

    if celestial.naxis != 2:
        raise ValueError("WCS should contain celestial component")

    # Frame and units in which the WCS reports its world coordinates
    frame = wcs_to_celestial_frame(celestial)
    lon_unit = u.Unit(celestial.wcs.cunit[0])
    lat_unit = u.Unit(celestial.wcs.cunit[1])

    # Pick the pixel -> world routine requested by ``mode``
    if mode == 'all':
        pixel_to_world = celestial.all_pix2world
    elif mode == 'wcs':
        pixel_to_world = celestial.wcs_pix2world
    else:
        raise ValueError("mode should be either 'all' or 'wcs'")

    lon_values, lat_values = pixel_to_world(xp, yp, origin)

    # Attach the units and wrap the result in a SkyCoord-like object
    representation = UnitSphericalRepresentation(lon=lon_values * lon_unit,
                                                 lat=lat_values * lat_unit)
    return coord_cls(frame.realize_frame(representation))
645
646
def _unique_with_order_preserved(items):
    """
    Return the distinct elements of ``items`` as a list, keeping each one at
    the position of its first occurrence.

    Elements are compared with ``==`` (list membership), so they do not need
    to be hashable.
    """
    distinct = []
    for candidate in items:
        if candidate in distinct:
            continue
        distinct.append(candidate)
    return distinct
657
658
def _pixel_to_world_correlation_matrix(wcs):
    """
    Return a correlation matrix between the pixel coordinates and the
    high level world coordinates, along with the list of high level world
    coordinate classes.

    World axes that belong to the same high-level object (same first entry
    in ``world_axis_object_components``) are merged into one row by OR-ing
    their rows of the axis correlation matrix.

    The shape of the matrix is ``(n_world, n_pix)``, where ``n_world`` is the
    number of high level world coordinates.
    """

    # Get the following in advance as getting these properties can be expensive
    low_level = wcs.low_level_wcs
    object_components = low_level.world_axis_object_components
    object_classes = low_level.world_axis_object_classes
    axis_correlation = low_level.axis_correlation_matrix

    # Name of the high-level object each world axis contributes to
    object_names = [component[0] for component in object_components]
    unique_names = _unique_with_order_preserved(object_names)

    collapsed = np.zeros((len(unique_names), wcs.pixel_n_dim), dtype=bool)

    for iworld in range(wcs.world_n_dim):
        row = unique_names.index(object_names[iworld])
        collapsed[row] |= axis_correlation[iworld]

    classes = [object_classes[name][0] for name in unique_names]

    return collapsed, classes
688
689
def _pixel_to_pixel_correlation_matrix(wcs_in, wcs_out):
    """
    Correlation matrix between the input and output pixel coordinates for a
    pixel -> world -> pixel transformation specified by two WCS instances.

    The first WCS specified is the one used for the pixel -> world
    transformation and the second WCS specified is the one used for the world ->
    pixel transformation. The shape of the matrix is
    ``(n_pixel_out, n_pixel_in)``.

    Raises
    ------
    ValueError
        If the two WCS produce a different number of world objects, if their
        world object classes cannot be paired one-to-one, or if the pairing
        is ambiguous and the classes are not already in the same order.
    """

    matrix1, classes1 = _pixel_to_world_correlation_matrix(wcs_in)
    matrix2, classes2 = _pixel_to_world_correlation_matrix(wcs_out)

    if len(classes1) != len(classes2):
        raise ValueError("The two WCS return a different number of world coordinates")

    # Check if classes match uniquely
    unique_match = True
    mapping = []
    for class1 in classes1:
        matches = classes2.count(class1)
        if matches == 0:
            raise ValueError("The world coordinate types of the two WCS do not match")
        elif matches > 1:
            unique_match = False
            break
        else:
            mapping.append(classes2.index(class1))

    if unique_match:

        # A class repeated in ``classes1`` but present only once in
        # ``classes2`` maps two inputs onto the same output row, which means
        # some class of ``classes2`` has no counterpart at all.
        if len(set(mapping)) != len(mapping):
            raise ValueError("The world coordinate types of the two WCS do not match")

        # Classes are unique, so we need to re-order matrix2 along the world
        # axis using the mapping we found above.
        matrix2 = matrix2[mapping]

    elif classes1 != classes2:

        raise ValueError("World coordinate order doesn't match and automatic matching is ambiguous")

    # (n_pixel_out, n_world) x (n_world, n_pixel_in)
    matrix = np.matmul(matrix2.T, matrix1)

    return matrix
733
734
def _split_matrix(matrix):
    """
    Given an axis correlation matrix from a WCS object, return information about
    the individual WCS that can be split out.

    The output is a list of tuples, where each tuple contains a list of
    pixel dimensions and a list of world dimensions that can be extracted to
    form a new WCS. For example, in the case of a spectral cube with the first
    two world coordinates being the celestial coordinates and the third
    coordinate being an uncorrelated spectral axis, the matrix would look like::

        array([[ True,  True, False],
               [ True,  True, False],
               [False, False,  True]])

    and this function will return ``[([0, 1], [0, 1]), ([2], [2])]``.

    Columns of ``matrix`` are pixel dimensions and rows are world dimensions.
    """

    n_pixel = matrix.shape[1]

    # Pixel dimensions that already belong to an earlier group
    assigned = np.zeros(n_pixel, dtype=bool)

    groups = []

    for seed in range(n_pixel):
        if assigned[seed]:
            continue

        pixel_mask = np.zeros(n_pixel, dtype=bool)
        pixel_mask[seed] = True

        # Alternate between "world dimensions touched by these pixels" and
        # "pixels touched by those world dimensions" until the pixel set
        # stops growing.
        previous_count, current_count = 0, 1
        while current_count > previous_count:
            world_mask = matrix[:, pixel_mask].any(axis=1)
            pixel_mask = matrix[world_mask, :].any(axis=0)
            previous_count, current_count = current_count, np.sum(pixel_mask)

        pixel_group = list(np.nonzero(pixel_mask)[0])
        world_group = list(np.nonzero(world_mask)[0])

        assigned |= pixel_mask
        groups.append((pixel_group, world_group))

    return groups
773
774
def pixel_to_pixel(wcs_in, wcs_out, *inputs):
    """
    Transform pixel coordinates in a dataset with a WCS to pixel coordinates
    in another dataset with a different WCS.

    This function is designed to efficiently deal with input pixel arrays that
    are broadcasted views of smaller arrays, and is compatible with any
    APE14-compliant WCS.

    Parameters
    ----------
    wcs_in : `~astropy.wcs.wcsapi.BaseHighLevelWCS`
        A WCS object for the original dataset which complies with the
        high-level shared APE 14 WCS API.
    wcs_out : `~astropy.wcs.wcsapi.BaseHighLevelWCS`
        A WCS object for the target dataset which complies with the
        high-level shared APE 14 WCS API.
    *inputs :
        Scalars or arrays giving the pixel coordinates to transform.

    Returns
    -------
    outputs
        For scalar input, whatever ``wcs_out.world_to_pixel`` returns. For
        array input, a single array if ``wcs_out`` has one pixel dimension,
        otherwise a list with one array per output pixel dimension, each
        broadcast to the shape of the first input.
    """

    # Shortcut for scalars
    if np.isscalar(inputs[0]):
        world = wcs_in.pixel_to_world(*inputs)
        world = world if isinstance(world, (tuple, list)) else (world,)
        return wcs_out.world_to_pixel(*world)

    # Remember original shape
    target_shape = inputs[0].shape

    correlation = _pixel_to_pixel_correlation_matrix(wcs_in, wcs_out)
    groups = _split_matrix(correlation)

    n_pixel_out = wcs_out.pixel_n_dim
    results = [None] * n_pixel_out

    # Each group is a set of input pixel axes together with the output pixel
    # axes that depend on them; it is transformed on its own.
    for input_axes, output_axes in groups:

        # Axes of the group keep their un-broadcast values; every other axis
        # is represented by its first element only.
        reduced = [unbroadcast(inputs[idim]) if idim in input_axes
                   else inputs[idim].flat[0]
                   for idim in range(wcs_in.pixel_n_dim)]
        reduced = np.broadcast_arrays(*reduced)

        world = wcs_in.pixel_to_world(*reduced)
        world = world if isinstance(world, (tuple, list)) else (world,)

        pixels = wcs_out.world_to_pixel(*world)
        if n_pixel_out == 1:
            pixels = (pixels,)

        for idim in output_axes:
            results[idim] = np.broadcast_to(pixels[idim], target_shape)

    return results[0] if n_pixel_out == 1 else results
837
838
def local_partial_pixel_derivatives(wcs, *pixel, normalize_by_world=False):
    """
    Return a matrix of shape ``(world_n_dim, pixel_n_dim)`` where each entry
    ``[i, j]`` is the partial derivative d(world_i)/d(pixel_j) at the requested
    pixel position.

    The derivatives are estimated with a forward finite difference using a
    step of one pixel along each pixel axis.

    Parameters
    ----------
    wcs : `~astropy.wcs.WCS`
        The WCS transformation to evaluate the derivatives for.
    *pixel : float
        The scalar pixel coordinates at which to evaluate the derivatives.
    normalize_by_world : bool
        If `True`, the matrix is normalized so that for each world entry
        the derivatives add up to 1, i.e. every row is divided by its sum.

    Returns
    -------
    derivatives : ndarray
        The ``(world_n_dim, pixel_n_dim)`` derivative matrix.
    """

    # Find the world coordinates at the requested pixel
    pixel_ref = np.array(pixel)
    world_ref = np.array(wcs.pixel_to_world_values(*pixel_ref))

    # Set up the derivative matrix
    derivatives = np.zeros((wcs.world_n_dim, wcs.pixel_n_dim))

    for i in range(wcs.pixel_n_dim):
        # Move one pixel along axis ``i`` only and difference the world values
        pixel_off = pixel_ref.copy()
        pixel_off[i] += 1
        world_off = np.array(wcs.pixel_to_world_values(*pixel_off))
        derivatives[:, i] = world_off - world_ref

    if normalize_by_world:
        # Rows are world entries, so the per-world total is the sum over the
        # pixel axis (axis=1); kept as a column so it divides row by row.
        derivatives /= derivatives.sum(axis=1)[:, np.newaxis]

    return derivatives
873
874
def _linear_wcs_fit(params, lon, lat, x, y, w_obj):
    """
    Objective function for fitting linear terms.

    The CD matrix and CRPIX held in ``params`` are written onto ``w_obj``,
    the pixel positions are converted to world coordinates with
    ``w_obj.wcs_pix2world`` and the differences to the supplied sky
    coordinates are returned.

    Parameters
    ----------
    params : array
        6 element array. First 4 elements are the CD matrix (row by row),
        last 2 are CRPIX.
    lon, lat: array
        Sky coordinates.
    x, y: array
        Pixel coordinates
    w_obj: `~astropy.wcs.WCS`
        WCS object. Its ``wcs.cd`` and ``wcs.crpix`` are overwritten.

    Returns
    -------
    resids : array
        The longitude residuals (wrapped and scaled by ``cos(lat)``)
        followed by the latitude residuals.
    """
    cd_flat = params[0:4]

    # Push the trial parameters onto the WCS before evaluating it
    w_obj.wcs.cd = ((cd_flat[0], cd_flat[1]), (cd_flat[2], cd_flat[3]))
    w_obj.wcs.crpix = params[4:6]

    model_lon, model_lat = w_obj.wcs_pix2world(x, y, 0)

    # Wrap the longitude difference, in case the longitude has wrapped around
    delta_lon = np.mod((lon - model_lon) - 180.0, 360.0) - 180.0
    delta_lat = lat - model_lat

    return np.concatenate((delta_lon * np.cos(np.radians(lat)), delta_lat))
905
906
def _sip_fit(params, lon, lat, u, v, w_obj, order, coeff_names):
    """
    Objective function for fitting SIP.

    Parameters
    ----------
    params : array
        Fittable parameters. Elements 0-1 are CRPIX, elements 2-5 are the
        CD matrix (reshaped to 2x2), followed by ``len(coeff_names)`` A
        coefficients and then the B coefficients.
    lon, lat: array
        Sky coordinates.
    u, v: array
        Pixel coordinates
    w_obj: `~astropy.wcs.WCS`
        WCS object. Its ``wcs.cd`` and ``wcs.crpix`` are overwritten.
    order : int
        Passed to the SIP model as both ``a_order`` and ``b_order``.
    coeff_names : list of str
        Coefficient name suffixes; ``'A_'`` and ``'B_'`` are prepended to
        build the keys of the SIP coefficient dictionaries.

    Returns
    -------
    resids : array
        The x residuals followed by the y residuals.
    """

    from ..modeling.models import SIP  # here to avoid circular import

    n_coeff = len(coeff_names)

    # Split the flat parameter vector into its pieces
    crpix = params[0:2]
    cdx = params[2:6].reshape((2, 2))
    a_params = params[6:6 + n_coeff]
    b_params = params[6 + n_coeff:]

    # Assign to the wcs, which is used for the transformations below
    w_obj.wcs.cd = cdx
    w_obj.wcs.crpix = crpix

    a_coeff = {'A_' + name: a_params[k] for k, name in enumerate(coeff_names)}
    b_coeff = {'B_' + name: b_params[k] for k, name in enumerate(coeff_names)}

    sip = SIP(crpix=crpix, a_order=order, b_order=order,
              a_coeff=a_coeff, b_coeff=b_coeff)
    fuv, guv = sip(u, v)

    # Input pixels plus the SIP model output, relative to CRPIX, through CD
    shifted = np.array([u + fuv - crpix[0], v + guv - crpix[1]])
    xo, yo = np.dot(cdx, shifted)

    # use all_world2pix in case `projection` contains distortion table
    x, y = w_obj.all_world2pix(lon, lat, 0)
    wcs_cd = w_obj.wcs.cd
    wcs_crpix = w_obj.wcs.crpix
    x, y = np.dot(wcs_cd, (x - wcs_crpix[0], y - wcs_crpix[1]))

    return np.concatenate((x - xo, y - yo))
953
954
def fit_wcs_from_points(xy, world_coords, proj_point='center',
                        projection='TAN', sip_degree=None):
    """
    Given two matching sets of coordinates on detector and sky,
    compute the WCS.

    Fits a WCS object to matched set of input detector and sky coordinates.
    Optionally, a SIP can be fit to account for geometric
    distortion. Returns an `~astropy.wcs.WCS` object with the best fit
    parameters for mapping between input pixel and sky coordinates.

    The projection type (default 'TAN') can be passed in as a string, one of
    the valid three-letter projection codes - or as a WCS object with
    projection keywords already set. Note that if an input WCS has any
    non-polynomial distortion, this will be applied and reflected in the
    fit terms and coefficients. Passing in a WCS object in this way essentially
    allows it to be refit based on the matched input coordinates and projection
    point, but take care when using this option as non-projection related
    keywords in the input might cause unexpected behavior.

    Notes
    -----
    - The fiducial point for the spherical projection can be set to 'center'
      to use the mean position of input sky coordinates, or as an
      `~astropy.coordinates.SkyCoord` object.
    - Units in all output WCS objects will always be in degrees.
    - If the coordinate frame differs between `~astropy.coordinates.SkyCoord`
      objects passed in for ``world_coords`` and ``proj_point``, the frame for
      ``world_coords``  will override as the frame for the output WCS.
    - If a WCS object is passed in to ``projection`` the CD/PC matrix will
      be used as an initial guess for the fit. If this is known to be
      significantly off and may throw off the fit, set to the identity matrix
      (for example, by doing wcs.wcs.pc = [(1., 0.,), (0., 1.)])

    Parameters
    ----------
    xy : (`numpy.ndarray`, `numpy.ndarray`) tuple
        x & y pixel coordinates.
    world_coords : `~astropy.coordinates.SkyCoord`
        Skycoord object with world coordinates.
    proj_point : 'center' or `~astropy.coordinates.SkyCoord`
        Defaults to 'center', in which the geometric center of input world
        coordinates will be used as the projection point. To specify an exact
        point for the projection, a Skycoord object with a coordinate pair can
        be passed in. For consistency, the units and frame of these coordinates
        will be transformed to match ``world_coords`` if they don't.
    projection : str or `~astropy.wcs.WCS`
        Three letter projection code, of any of standard projections defined
        in the FITS WCS standard. Optionally, a WCS object with projection
        keywords set may be passed in.
    sip_degree : None or int
        If set to a non-zero integer value, will fit SIP of degree
        ``sip_degree`` to model geometric distortion. Defaults to None, meaning
        no distortion corrections will be fit.

    Returns
    -------
    wcs : `~astropy.wcs.WCS`
        The best-fit WCS to the points given.

    Raises
    ------
    ValueError
        If ``proj_point`` is neither ``'center'`` nor a single coordinate of
        the same type as ``world_coords``, if ``projection`` is a string that
        is not a supported projection code, or if ``sip_degree`` is neither
        `None` nor an `int`.
    """

    from scipy.optimize import least_squares

    import astropy.units as u
    from astropy.coordinates import SkyCoord  # here to avoid circular import

    from .wcs import Sip

    xp, yp = xy
    try:
        lon, lat = world_coords.data.lon.deg, world_coords.data.lat.deg
    except AttributeError:
        unit_sph = world_coords.unit_spherical
        lon, lat = unit_sph.lon.deg, unit_sph.lat.deg

    # verify input
    use_center_as_proj_point = (isinstance(proj_point, str)
                                and proj_point == 'center')

    if not use_center_as_proj_point:
        if type(proj_point) is not type(world_coords):
            raise ValueError("proj_point must be set to 'center', or an "
                             "`~astropy.coordinates.SkyCoord` object with "
                             "a pair of points.")
        # Raised explicitly (instead of ``assert``) so the check is not
        # stripped when Python runs with optimizations enabled.
        if proj_point.size != 1:
            raise ValueError("proj_point must contain exactly one "
                             "coordinate.")

    proj_codes = [
        'AZP', 'SZP', 'TAN', 'STG', 'SIN', 'ARC', 'ZEA', 'AIR', 'CYP',
        'CEA', 'CAR', 'MER', 'SFL', 'PAR', 'MOL', 'AIT', 'COP', 'COE',
        'COD', 'COO', 'BON', 'PCO', 'TSC', 'CSC', 'QSC', 'HPX', 'XPH'
    ]
    if isinstance(projection, str):
        if projection not in proj_codes:
            raise ValueError("Must specify valid projection code from list of "
                             "supported types: " + ', '.join(proj_codes))
        # empty wcs to fill in with fit values
        wcs = celestial_frame_to_wcs(frame=world_coords.frame,
                                     projection=projection)
    else:  # if projection is not string, should be wcs object. use as template.
        wcs = copy.deepcopy(projection)
        # NOTE(review): this assigns an attribute on the WCS wrapper object,
        # not ``wcs.wcs.cdelt``; kept exactly as before -- confirm whether
        # ``wcs.wcs.cdelt`` was intended.
        wcs.cdelt = (1., 1.)  # make sure cdelt is 1
        wcs.sip = None

    # Change PC to CD, since cdelt will be set to 1
    if wcs.wcs.has_pc():
        wcs.wcs.cd = wcs.wcs.pc
        del wcs.wcs.pc

    # Exact type check (unchanged behavior): anything that is not None or a
    # plain ``int`` is rejected.
    if sip_degree is not None and type(sip_degree) is not int:
        raise ValueError("sip_degree must be None, or integer.")

    # compute bounding box for sources in image coordinates:
    xpmin, xpmax, ypmin, ypmax = xp.min(), xp.max(), yp.min(), yp.max()

    # set pixel_shape to span of input points
    wcs.pixel_shape = (1 if xpmax <= 0.0 else int(np.ceil(xpmax)),
                       1 if ypmax <= 0.0 else int(np.ceil(ypmax)))

    # determine CRVAL from input
    if use_center_as_proj_point:  # use center of input points
        sc1 = SkyCoord(lon.min()*u.deg, lat.max()*u.deg)
        sc2 = SkyCoord(lon.max()*u.deg, lat.min()*u.deg)
        pa = sc1.position_angle(sc2)
        sep = sc1.separation(sc2)
        midpoint_sc = sc1.directional_offset_by(pa, sep/2)
        wcs.wcs.crval = (midpoint_sc.data.lon.deg, midpoint_sc.data.lat.deg)
        wcs.wcs.crpix = ((xpmax + xpmin) / 2., (ypmax + ypmin) / 2.)
    else:  # convert units, initial guess for crpix
        # ``transform_to`` returns a new object, so the result has to be
        # kept for the conversion to take effect.
        proj_point = proj_point.transform_to(world_coords)
        crval_lon = proj_point.data.lon.deg
        crval_lat = proj_point.data.lat.deg
        wcs.wcs.crval = (crval_lon, crval_lat)
        # Initial CRPIX guess: the (1-based) x of the point closest to CRVAL
        # in longitude and the y of the point closest in latitude.
        wcs.wcs.crpix = (xp[np.argmin(np.abs(lon - crval_lon))] + 1,
                         yp[np.argmin(np.abs(lat - crval_lat))] + 1)

    # fit linear terms, assign to wcs
    # The current CD matrix and CRPIX of ``wcs`` are the initial guess.
    # Use bounds to require that the fit center pixel is on the input image
    if xpmin == xpmax:
        xpmin, xpmax = xpmin - 0.5, xpmax + 0.5
    if ypmin == ypmax:
        ypmin, ypmax = ypmin - 0.5, ypmax + 0.5

    p0 = np.concatenate([wcs.wcs.cd.flatten(), wcs.wcs.crpix.flatten()])
    fit = least_squares(
        _linear_wcs_fit, p0,
        args=(lon, lat, xp, yp, wcs),
        bounds=[[-np.inf]*4 + [xpmin + 1, ypmin + 1],
                [np.inf]*4 + [xpmax + 1, ypmax + 1]]
    )
    wcs.wcs.crpix = np.array(fit.x[4:6])
    wcs.wcs.cd = np.array(fit.x[0:4].reshape((2, 2)))

    # fit SIP, if specified. Only fit forward coefficients
    if sip_degree:
        degree = sip_degree
        if '-SIP' not in wcs.wcs.ctype[0]:
            wcs.wcs.ctype = [x + '-SIP' for x in wcs.wcs.ctype]

        # Keep the (i, j) index pairs next to the names so the fitted values
        # can be placed without parsing the names back into integers.
        coef_indices = [(i, j) for i in range(degree+1)
                        for j in range(degree+1)
                        if (i+j) < (degree+1) and (i+j) > 1]
        coef_names = [f'{i}_{j}' for i, j in coef_indices]
        n_coef = len(coef_names)

        p0 = np.concatenate((np.array(wcs.wcs.crpix), wcs.wcs.cd.flatten(),
                             np.zeros(2*n_coef)))

        fit = least_squares(
            _sip_fit, p0,
            args=(lon, lat, xp, yp, wcs, degree, coef_names),
            bounds=[[xpmin + 1, ypmin + 1] + [-np.inf]*(4 + 2*n_coef),
                    [xpmax + 1, ypmax + 1] + [np.inf]*(4 + 2*n_coef)]
        )
        a_fit = fit.x[6:6+n_coef]
        b_fit = fit.x[6+n_coef:]

        # put fit values in wcs
        wcs.wcs.cd = fit.x[2:6].reshape((2, 2))
        wcs.wcs.crpix = fit.x[0:2]

        a_vals = np.zeros((degree+1, degree+1))
        b_vals = np.zeros((degree+1, degree+1))

        for (i, j), a_val, b_val in zip(coef_indices, a_fit, b_fit):
            a_vals[i, j] = a_val
            b_vals[i, j] = b_val

        # Only the forward (A, B) coefficients are fit; the remaining two
        # coefficient arrays passed to Sip are left as zeros.
        wcs.sip = Sip(a_vals, b_vals, np.zeros((degree+1, degree+1)),
                      np.zeros((degree+1, degree+1)), wcs.wcs.crpix)

    return wcs
1144
1145
def obsgeo_to_frame(obsgeo, obstime):
    """
    Convert a WCS obsgeo property into an `~.builtin_frames.ITRS` coordinate frame.

    Parameters
    ----------
    obsgeo : array-like
        A shape ``(6, )`` array representing ``OBSGEO-[XYZ], OBSGEO-[BLH]`` as
        returned by ``WCS.wcs.obsgeo``.

    obstime : time-like
        The time associated with the coordinate, will be passed to
        `~.builtin_frames.ITRS` as the obstime keyword.

    Returns
    -------
    `~.builtin_frames.ITRS`
        An `~.builtin_frames.ITRS` coordinate frame
        representing the coordinates.

    Raises
    ------
    ValueError
        If ``obsgeo`` is `None`, does not have length 6, is all zero or is
        entirely non-finite.

    Notes
    -----

    The obsgeo array as accessed on a `.WCS` object is a length 6 numpy array
    where the first three elements are the coordinate in a cartesian
    representation and the second 3 are the coordinate in a spherical
    representation.

    This function prioritises reading the cartesian coordinates, and will only
    read the spherical coordinates if the cartesian coordinates are either all
    zero or any of the cartesian coordinates are non-finite.

    In the case where both the spherical and cartesian coordinates have some
    non-finite values the spherical coordinates will be returned with the
    non-finite values included.

    """
    # Convert once so that lists and tuples behave like numpy arrays in the
    # element-wise comparisons and the unit multiplications below.
    values = None if obsgeo is None else np.asarray(obsgeo, dtype=float)

    if (values is None
        or len(values) != 6
        or np.all(values == 0)
        or np.all(~np.isfinite(values))
    ):
        raise ValueError(f"Can not parse the 'obsgeo' location ({obsgeo}). "
                         "obsgeo should be a length 6 non-zero, finite numpy array")

    cartesian = values[:3]

    # If the cartesian coords are zero or have NaNs in them use the spherical ones
    if np.all(cartesian == 0) or np.any(~np.isfinite(cartesian)):
        first, second, third = values[3:]
        # Same units as before: degrees, degrees, metres (in that order)
        data = SphericalRepresentation(first * u.deg, second * u.deg,
                                       third * u.m)

    # Otherwise we assume the cartesian ones are valid
    else:
        data = CartesianRepresentation(*(cartesian * u.m))

    return ITRS(data, obstime=obstime)
1200