1# Licensed under a 3-clause BSD style license - see LICENSE.rst
2
3from textwrap import indent
4from collections import OrderedDict
5
6from .coordinate_helpers import CoordinateHelper
7from .frame import RectangularFrame, RectangularFrame1D
8from .coordinate_range import find_coordinate_range
9
10
class CoordinatesMap:
    """
    A container for coordinate helpers that represents a coordinate system.

    This object can be used to access coordinate helpers by index (like a list)
    or by name (like a dictionary).

    Parameters
    ----------
    axes : :class:`~astropy.visualization.wcsaxes.WCSAxes`
        The axes the coordinate map belongs to.
    transform : `~matplotlib.transforms.Transform`, optional
        The transform for the data.
    coord_meta : dict
        A dictionary providing additional metadata. Although the signature
        defaults to `None`, this is required in practice. It must include the
        keys ``type``, ``wrap``, ``unit``, and ``name``. Each of these should
        be a list with as many items as the dimension of the coordinate system.
        The ``type`` entries should be one of ``longitude``, ``latitude``, or
        ``scalar``. The ``wrap`` entries should give, for the longitude, the
        angle at which the coordinate wraps (and `None` otherwise). The
        ``unit`` should give the unit of the coordinates as
        :class:`~astropy.units.Unit` instances. Each ``name`` entry is either
        a single string or a tuple/list of strings; all of them become
        case-insensitive aliases for that coordinate. This can optionally
        also include:

        * a ``format_unit`` entry giving the units to use for the tick labels
          (if not specified, this defaults to ``unit``);
        * a ``visible`` entry saying whether each coordinate is shown
          (defaults to `True`);
        * a ``default_axis_label`` entry giving the default axis label for
          each coordinate (defaults to the name, or the first name if several
          are given).
    frame_class : type, optional
        The class for the frame, which should be a subclass of
        :class:`~astropy.visualization.wcsaxes.frame.BaseFrame`. The default is to use a
        :class:`~astropy.visualization.wcsaxes.frame.RectangularFrame`.
    previous_frame_path : `~matplotlib.path.Path`, optional
        When changing the WCS of the axes, the frame instance will change but
        we might want to keep re-using the same underlying matplotlib
        `~matplotlib.path.Path` - in that case, this can be passed to this
        keyword argument.

    Raises
    ------
    TypeError
        If ``coord_meta`` is `None`.
    """

    def __init__(self, axes, transform=None, coord_meta=None,
                 frame_class=RectangularFrame, previous_frame_path=None):

        if coord_meta is None:
            # The signature keeps its historical default for backward
            # compatibility, but there is no sensible fallback: every
            # coordinate needs its type/wrap/unit/name metadata. Fail early
            # (before building the frame) with a readable message.
            raise TypeError("coord_meta is required to set up the coordinates")

        self._axes = axes
        self._transform = transform

        self.frame = frame_class(axes, self._transform, path=previous_frame_path)

        # Set up coordinates
        self._coords = []
        self._aliases = {}

        visible_count = 0

        for index, coord_type in enumerate(coord_meta['type']):

            # Extract coordinate metadata
            coord_wrap = coord_meta['wrap'][index]
            coord_unit = coord_meta['unit'][index]
            name = coord_meta['name'][index]

            visible = True
            if 'visible' in coord_meta:
                visible = coord_meta['visible'][index]

            format_unit = None
            if 'format_unit' in coord_meta:
                format_unit = coord_meta['format_unit'][index]

            has_multiple_names = isinstance(name, (tuple, list))

            default_label = name[0] if has_multiple_names else name
            if 'default_axis_label' in coord_meta:
                default_label = coord_meta['default_axis_label'][index]

            # Visible coordinates are numbered consecutively; hidden ones get
            # no index, which is how the rest of this class filters them out.
            coord_index = None
            if visible:
                coord_index = visible_count
                visible_count += 1

            self._coords.append(CoordinateHelper(parent_axes=axes,
                                                 parent_map=self,
                                                 transform=self._transform,
                                                 coord_index=coord_index,
                                                 coord_type=coord_type,
                                                 coord_wrap=coord_wrap,
                                                 coord_unit=coord_unit,
                                                 format_unit=format_unit,
                                                 frame=self.frame,
                                                 default_label=default_label))

            # Set up aliases for coordinates. Accept the same name container
            # types as the default label logic above (tuple or list); a list
            # previously fell through to ``name.lower()`` and crashed.
            if has_multiple_names:
                for nm in name:
                    # Do not replace an alias already in the map if we have
                    # more than one alias for this axis.
                    self._aliases.setdefault(nm.lower(), index)
            else:
                self._aliases[name.lower()] = index

    def __getitem__(self, item):
        # Strings are looked up case-insensitively through the alias map;
        # anything else is used as a list index into the coordinate helpers.
        if isinstance(item, str):
            return self._coords[self._aliases[item.lower()]]
        else:
            return self._coords[item]

    def __contains__(self, item):
        # Strings test for a known (case-insensitive) alias; other values are
        # treated as non-negative positional indices.
        if isinstance(item, str):
            return item.lower() in self._aliases
        else:
            return 0 <= item < len(self._coords)

    def set_visible(self, visibility):
        # Not supported on the map as a whole.
        raise NotImplementedError()

    def __iter__(self):
        # Iterate over all coordinate helpers, visible or not, in index order.
        yield from self._coords

    def grid(self, draw_grid=True, grid_type=None, **kwargs):
        """
        Plot gridlines for all coordinates.

        Standard matplotlib appearance options (color, alpha, etc.) can be
        passed as keyword arguments.

        Parameters
        ----------
        draw_grid : bool
            Whether to show the gridlines
        grid_type : { 'lines' | 'contours' }
            Whether to plot the contours by determining the grid lines in
            world coordinates and then plotting them in world coordinates
            (``'lines'``) or by determining the world coordinates at many
            positions in the image and then drawing contours
            (``'contours'``). The first is recommended for 2-d images, while
            for 3-d (or higher dimensional) cubes, the ``'contours'`` option
            is recommended. By default, 'lines' is used if the transform has
            an inverse, otherwise 'contours' is used.
        """
        for coord in self:
            coord.grid(draw_grid=draw_grid, grid_type=grid_type, **kwargs)

    def get_coord_range(self):
        """
        Return the world coordinate range spanned by the current axes limits.

        Only visible coordinates (those with a ``coord_index``) are passed to
        ``find_coordinate_range``. For a
        :class:`~astropy.visualization.wcsaxes.frame.RectangularFrame1D` only
        the x limits are used as the extent.
        """
        xmin, xmax = self._axes.get_xlim()

        if isinstance(self.frame, RectangularFrame1D):
            extent = [xmin, xmax]
        else:
            ymin, ymax = self._axes.get_ylim()
            extent = [xmin, xmax, ymin, ymax]

        # Filter hidden coordinates once rather than once per metadata list.
        visible_coords = [coord for coord in self if coord.coord_index is not None]

        return find_coordinate_range(self._transform,
                                     extent,
                                     [coord.coord_type for coord in visible_coords],
                                     [coord.coord_unit for coord in visible_coords],
                                     [coord.coord_wrap for coord in visible_coords])

    def _as_table(self):
        """
        Summarize the coordinates as an `~astropy.table.Table`.

        Each row has the index, space-separated aliases, type, unit, wrap,
        format unit and visibility ('yes'/'no') of one coordinate.
        """

        # Import Table here to avoid importing the astropy.table package
        # every time astropy.visualization.wcsaxes is imported.
        from astropy.table import Table  # noqa

        # Group aliases by coordinate index in a single pass; alias insertion
        # order is preserved within each group.
        aliases_by_index = {}
        for alias, alias_index in self._aliases.items():
            aliases_by_index.setdefault(alias_index, []).append(alias)

        rows = []
        for icoord, coord in enumerate(self._coords):
            aliases = aliases_by_index.get(icoord, [])
            row = OrderedDict([('index', icoord), ('aliases', ' '.join(aliases)),
                               ('type', coord.coord_type), ('unit', coord.coord_unit),
                               ('wrap', coord.coord_wrap), ('format_unit', coord.get_format_unit()),
                               ('visible', 'no' if coord.coord_index is None else 'yes')])
            rows.append(row)
        return Table(rows=rows)

    def __repr__(self):
        s = f'<CoordinatesMap with {len(self._coords)} world coordinates:\n\n'
        table = indent(str(self._as_table()), '  ')
        return s + table + '\n\n>'
185