1# Licensed under a 3-clause BSD style license - see LICENSE.rst
2
3from functools import partial
4from collections import defaultdict
5
6import numpy as np
7
8from matplotlib import rcParams
9from matplotlib.artist import Artist
10from matplotlib.axes import Axes, subplot_class_factory
11from matplotlib.transforms import Affine2D, Bbox, Transform
12
13import astropy.units as u
14from astropy.coordinates import SkyCoord, BaseCoordinateFrame
15from astropy.wcs import WCS
16from astropy.wcs.wcsapi import BaseHighLevelWCS, BaseLowLevelWCS
17
18from .transforms import CoordinateTransform
19from .coordinates_map import CoordinatesMap
20from .utils import get_coord_meta, transform_contour_set_inplace
21from .frame import RectangularFrame, RectangularFrame1D
22from .wcsapi import IDENTITY, transform_coord_meta_from_wcs
23
24
# Public names exported by this module.
__all__ = ['WCSAxes', 'WCSAxesSubplot']

# Names of appearance properties (Matplotlib-style keyword names).
# NOTE(review): not referenced anywhere in the visible part of this module -
# presumably imported elsewhere; confirm before changing or removing.
VISUAL_PROPERTIES = ['facecolor', 'edgecolor', 'linewidth', 'alpha', 'linestyle']
28
29
class _WCSAxesArtist(Artist):
    """Placeholder artist that gives the WCSAxes decorations a z-order.

    FIXME: This is a workaround. ``Axes.draw`` sorts its artists by zorder and
    renders them one after another. On ordinary Matplotlib axes the ticks,
    tick labels and gridlines are part of that list, so they end up drawn in
    the right order automatically. ``WCSAxes`` disables those native artists
    and instead draws its own replacements by explicitly calling
    ``CoordinateHelper._draw_ticks``, ``CoordinateHelper._draw_grid``, etc.
    Registering this dummy artist lets that drawing happen at the right point
    of the zorder-sorted sequence. It would not be needed if ``WCSAxes`` drew
    ticks, tick labels and gridlines in the standard way.
    """

    def draw(self, renderer, *args, **kwargs):
        """Delegate to ``draw_wcsaxes`` of the axes this artist belongs to.

        Any extra positional or keyword arguments are accepted and ignored.
        """
        owner = self.axes
        owner.draw_wcsaxes(renderer)
46
47
48class WCSAxes(Axes):
49    """
50    The main axes class that can be used to show world coordinates from a WCS.
51
52    Parameters
53    ----------
54    fig : `~matplotlib.figure.Figure`
55        The figure to add the axes to
56    rect : list
57        The position of the axes in the figure in relative units. Should be
58        given as ``[left, bottom, width, height]``.
59    wcs : :class:`~astropy.wcs.WCS`, optional
60        The WCS for the data. If this is specified, ``transform`` cannot be
61        specified.
62    transform : `~matplotlib.transforms.Transform`, optional
63        The transform for the data. If this is specified, ``wcs`` cannot be
64        specified.
65    coord_meta : dict, optional
66        A dictionary providing additional metadata when ``transform`` is
67        specified. This should include the keys ``type``, ``wrap``, and
68        ``unit``. Each of these should be a list with as many items as the
69        dimension of the WCS. The ``type`` entries should be one of
70        ``longitude``, ``latitude``, or ``scalar``, the ``wrap`` entries should
71        give, for the longitude, the angle at which the coordinate wraps (and
72        `None` otherwise), and the ``unit`` should give the unit of the
73        coordinates as :class:`~astropy.units.Unit` instances. This can
74        optionally also include a ``format_unit`` entry giving the units to use
75        for the tick labels (if not specified, this defaults to ``unit``).
76    transData : `~matplotlib.transforms.Transform`, optional
77        Can be used to override the default data -> pixel mapping.
78    slices : tuple, optional
79        For WCS transformations with more than two dimensions, we need to
80        choose which dimensions are being shown in the 2D image. The slice
81        should contain one ``x`` entry, one ``y`` entry, and the rest of the
82        values should be integers indicating the slice through the data. The
83        order of the items in the slice should be the same as the order of the
84        dimensions in the :class:`~astropy.wcs.WCS`, and the opposite of the
85        order of the dimensions in Numpy. For example, ``(50, 'x', 'y')`` means
86        that the first WCS dimension (last Numpy dimension) will be sliced at
87        an index of 50, the second WCS and Numpy dimension will be shown on the
88        x axis, and the final WCS dimension (first Numpy dimension) will be
89        shown on the y-axis (and therefore the data will be plotted using
90        ``data[:, :, 50].transpose()``)
91    frame_class : type, optional
92        The class for the frame, which should be a subclass of
93        :class:`~astropy.visualization.wcsaxes.frame.BaseFrame`. The default is to use a
94        :class:`~astropy.visualization.wcsaxes.frame.RectangularFrame`
95    """
96
97    def __init__(self, fig, rect, wcs=None, transform=None, coord_meta=None,
98                 transData=None, slices=None, frame_class=None,
99                 **kwargs):
100        """
101        """
102
103        super().__init__(fig, rect, **kwargs)
104        self._bboxes = []
105
106        if frame_class is not None:
107            self.frame_class = frame_class
108        elif (wcs is not None and (wcs.pixel_n_dim == 1 or
109                                   (slices is not None and 'y' not in slices))):
110            self.frame_class = RectangularFrame1D
111        else:
112            self.frame_class = RectangularFrame
113
114        if not (transData is None):
115            # User wants to override the transform for the final
116            # data->pixel mapping
117            self.transData = transData
118
119        self.reset_wcs(wcs=wcs, slices=slices, transform=transform, coord_meta=coord_meta)
120        self._hide_parent_artists()
121        self.format_coord = self._display_world_coords
122        self._display_coords_index = 0
123        fig.canvas.mpl_connect('key_press_event', self._set_cursor_prefs)
124        self.patch = self.coords.frame.patch
125        self._wcsaxesartist = _WCSAxesArtist()
126        self.add_artist(self._wcsaxesartist)
127        self._drawn = False
128
129    def _display_world_coords(self, x, y):
130
131        if not self._drawn:
132            return ""
133
134        if self._display_coords_index == -1:
135            return f"{x} {y} (pixel)"
136
137        pixel = np.array([x, y])
138
139        coords = self._all_coords[self._display_coords_index]
140
141        world = coords._transform.transform(np.array([pixel]))[0]
142
143        coord_strings = []
144        for idx, coord in enumerate(coords):
145            if coord.coord_index is not None:
146                coord_strings.append(coord.format_coord(world[coord.coord_index], format='ascii'))
147
148        coord_string = ' '.join(coord_strings)
149
150        if self._display_coords_index == 0:
151            system = "world"
152        else:
153            system = f"world, overlay {self._display_coords_index}"
154
155        coord_string = f"{coord_string} ({system})"
156
157        return coord_string
158
159    def _set_cursor_prefs(self, event, **kwargs):
160        if event.key == 'w':
161            self._display_coords_index += 1
162            if self._display_coords_index + 1 > len(self._all_coords):
163                self._display_coords_index = -1
164
165    def _hide_parent_artists(self):
166        # Turn off spines and current axes
167        for s in self.spines.values():
168            s.set_visible(False)
169
170        self.xaxis.set_visible(False)
171        if self.frame_class is not RectangularFrame1D:
172            self.yaxis.set_visible(False)
173
174    # We now overload ``imshow`` because we need to make sure that origin is
175    # set to ``lower`` for all images, which means that we need to flip RGB
176    # images.
177    def imshow(self, X, *args, **kwargs):
178        """
179        Wrapper to Matplotlib's :meth:`~matplotlib.axes.Axes.imshow`.
180
181        If an RGB image is passed as a PIL object, it will be flipped
182        vertically and ``origin`` will be set to ``lower``, since WCS
183        transformations - like FITS files - assume that the origin is the lower
184        left pixel of the image (whereas RGB images have the origin in the top
185        left).
186
187        All arguments are passed to :meth:`~matplotlib.axes.Axes.imshow`.
188        """
189
190        origin = kwargs.pop('origin', 'lower')
191
192        # plt.imshow passes origin as None, which we should default to lower.
193        if origin is None:
194            origin = 'lower'
195        elif origin == 'upper':
196            raise ValueError("Cannot use images with origin='upper' in WCSAxes.")
197
198        # To check whether the image is a PIL image we can check if the data
199        # has a 'getpixel' attribute - this is what Matplotlib's AxesImage does
200
201        try:
202            from PIL.Image import Image, FLIP_TOP_BOTTOM
203        except ImportError:
204            # We don't need to worry since PIL is not installed, so user cannot
205            # have passed RGB image.
206            pass
207        else:
208            if isinstance(X, Image) or hasattr(X, 'getpixel'):
209                X = X.transpose(FLIP_TOP_BOTTOM)
210
211        return super().imshow(X, *args, origin=origin, **kwargs)
212
213    def contour(self, *args, **kwargs):
214        """
215        Plot contours.
216
217        This is a custom implementation of :meth:`~matplotlib.axes.Axes.contour`
218        which applies the transform (if specified) to all contours in one go for
219        performance rather than to each contour line individually. All
220        positional and keyword arguments are the same as for
221        :meth:`~matplotlib.axes.Axes.contour`.
222        """
223
224        # In Matplotlib, when calling contour() with a transform, each
225        # individual path in the contour map is transformed separately. However,
226        # this is much too slow for us since each call to the transforms results
227        # in an Astropy coordinate transformation, which has a non-negligible
228        # overhead - therefore a better approach is to override contour(), call
229        # the Matplotlib one with no transform, then apply the transform in one
230        # go to all the segments that make up the contour map.
231
232        transform = kwargs.pop('transform', None)
233
234        cset = super().contour(*args, **kwargs)
235
236        if transform is not None:
237            # The transform passed to self.contour will normally include
238            # a transData component at the end, but we can remove that since
239            # we are already working in data space.
240            transform = transform - self.transData
241            transform_contour_set_inplace(cset, transform)
242
243        return cset
244
245    def contourf(self, *args, **kwargs):
246        """
247        Plot filled contours.
248
249        This is a custom implementation of :meth:`~matplotlib.axes.Axes.contourf`
250        which applies the transform (if specified) to all contours in one go for
251        performance rather than to each contour line individually. All
252        positional and keyword arguments are the same as for
253        :meth:`~matplotlib.axes.Axes.contourf`.
254        """
255
256        # See notes for contour above.
257
258        transform = kwargs.pop('transform', None)
259
260        cset = super().contourf(*args, **kwargs)
261
262        if transform is not None:
263            # The transform passed to self.contour will normally include
264            # a transData component at the end, but we can remove that since
265            # we are already working in data space.
266            transform = transform - self.transData
267            transform_contour_set_inplace(cset, transform)
268
269        return cset
270
271    def plot_coord(self, *args, **kwargs):
272        """
273        Plot `~astropy.coordinates.SkyCoord` or
274        `~astropy.coordinates.BaseCoordinateFrame` objects onto the axes.
275
276        The first argument to
277        :meth:`~astropy.visualization.wcsaxes.WCSAxes.plot_coord` should be a
278        coordinate, which will then be converted to the first two parameters to
279        `matplotlib.axes.Axes.plot`. All other arguments are the same as
280        `matplotlib.axes.Axes.plot`. If not specified a ``transform`` keyword
281        argument will be created based on the coordinate.
282
283        Parameters
284        ----------
285        coordinate : `~astropy.coordinates.SkyCoord` or `~astropy.coordinates.BaseCoordinateFrame`
286            The coordinate object to plot on the axes. This is converted to the
287            first two arguments to `matplotlib.axes.Axes.plot`.
288
289        See Also
290        --------
291
292        matplotlib.axes.Axes.plot : This method is called from this function with all arguments passed to it.
293
294        """
295
296        if isinstance(args[0], (SkyCoord, BaseCoordinateFrame)):
297
298            # Extract the frame from the first argument.
299            frame0 = args[0]
300            if isinstance(frame0, SkyCoord):
301                frame0 = frame0.frame
302
303            native_frame = self._transform_pixel2world.frame_out
304            # Transform to the native frame of the plot
305            frame0 = frame0.transform_to(native_frame)
306
307            plot_data = []
308            for coord in self.coords:
309                if coord.coord_type == 'longitude':
310                    plot_data.append(frame0.spherical.lon.to_value(u.deg))
311                elif coord.coord_type == 'latitude':
312                    plot_data.append(frame0.spherical.lat.to_value(u.deg))
313                else:
314                    raise NotImplementedError("Coordinates cannot be plotted with this "
315                                              "method because the WCS does not represent longitude/latitude.")
316
317            if 'transform' in kwargs.keys():
318                raise TypeError("The 'transform' keyword argument is not allowed,"
319                                " as it is automatically determined by the input coordinate frame.")
320
321            transform = self.get_transform(native_frame)
322            kwargs.update({'transform': transform})
323
324            args = tuple(plot_data) + args[1:]
325
326        return super().plot(*args, **kwargs)
327
328    def reset_wcs(self, wcs=None, slices=None, transform=None, coord_meta=None):
329        """
330        Reset the current Axes, to use a new WCS object.
331        """
332
333        # Here determine all the coordinate axes that should be shown.
334        if wcs is None and transform is None:
335
336            self.wcs = IDENTITY
337
338        else:
339
340            # We now force call 'set', which ensures the WCS object is
341            # consistent, which will only be important if the WCS has been set
342            # by hand. For example if the user sets a celestial WCS by hand and
343            # forgets to set the units, WCS.wcs.set() will do this.
344            if wcs is not None:
345                # Check if the WCS object is an instance of `astropy.wcs.WCS`
346                # This check is necessary as only `astropy.wcs.WCS` supports
347                # wcs.set() method
348                if isinstance(wcs, WCS):
349                    wcs.wcs.set()
350
351                if isinstance(wcs, BaseHighLevelWCS):
352                    wcs = wcs.low_level_wcs
353
354            self.wcs = wcs
355
356        # If we are making a new WCS, we need to preserve the path object since
357        # it may already be used by objects that have been plotted, and we need
358        # to continue updating it. CoordinatesMap will create a new frame
359        # instance, but we can tell that instance to keep using the old path.
360        if hasattr(self, 'coords'):
361            previous_frame = {'path': self.coords.frame._path,
362                              'color': self.coords.frame.get_color(),
363                              'linewidth': self.coords.frame.get_linewidth()}
364        else:
365            previous_frame = {'path': None}
366
367        if self.wcs is not None:
368
369            transform, coord_meta = transform_coord_meta_from_wcs(self.wcs, self.frame_class, slices=slices)
370
371        self.coords = CoordinatesMap(self,
372                                     transform=transform,
373                                     coord_meta=coord_meta,
374                                     frame_class=self.frame_class,
375                                     previous_frame_path=previous_frame['path'])
376
377        self._transform_pixel2world = transform
378
379        if previous_frame['path'] is not None:
380            self.coords.frame.set_color(previous_frame['color'])
381            self.coords.frame.set_linewidth(previous_frame['linewidth'])
382
383        self._all_coords = [self.coords]
384
385        # Common default settings for Rectangular Frame
386        for ind, pos in enumerate(coord_meta.get('default_axislabel_position', ['b', 'l'])):
387            self.coords[ind].set_axislabel_position(pos)
388
389        for ind, pos in enumerate(coord_meta.get('default_ticklabel_position', ['b', 'l'])):
390            self.coords[ind].set_ticklabel_position(pos)
391
392        for ind, pos in enumerate(coord_meta.get('default_ticks_position', ['bltr', 'bltr'])):
393            self.coords[ind].set_ticks_position(pos)
394
395        if rcParams['axes.grid']:
396            self.grid()
397
398    def draw_wcsaxes(self, renderer):
399        if not self.axison:
400            return
401        # Here need to find out range of all coordinates, and update range for
402        # each coordinate axis. For now, just assume it covers the whole sky.
403
404        self._bboxes = []
405        # This generates a structure like [coords][axis] = [...]
406        ticklabels_bbox = defaultdict(partial(defaultdict, list))
407
408        visible_ticks = []
409
410        for coords in self._all_coords:
411
412            coords.frame.update()
413            for coord in coords:
414                coord._draw_grid(renderer)
415
416        for coords in self._all_coords:
417
418            for coord in coords:
419                coord._draw_ticks(renderer, bboxes=self._bboxes,
420                                  ticklabels_bbox=ticklabels_bbox[coord])
421                visible_ticks.extend(coord.ticklabels.get_visible_axes())
422
423        for coords in self._all_coords:
424
425            for coord in coords:
426                coord._draw_axislabels(renderer, bboxes=self._bboxes,
427                                       ticklabels_bbox=ticklabels_bbox,
428                                       visible_ticks=visible_ticks)
429
430        self.coords.frame.draw(renderer)
431
432    def draw(self, renderer, **kwargs):
433        """Draw the axes."""
434
435        # Before we do any drawing, we need to remove any existing grid lines
436        # drawn with contours, otherwise if we try and remove the contours
437        # part way through drawing, we end up with the issue mentioned in
438        # https://github.com/astropy/astropy/issues/12446
439        for coords in self._all_coords:
440            for coord in coords:
441                coord._clear_grid_contour()
442
443        # In Axes.draw, the following code can result in the xlim and ylim
444        # values changing, so we need to force call this here to make sure that
445        # the limits are correct before we update the patch.
446        locator = self.get_axes_locator()
447        if locator:
448            pos = locator(self, renderer)
449            self.apply_aspect(pos)
450        else:
451            self.apply_aspect()
452
453        if self._axisbelow is True:
454            self._wcsaxesartist.set_zorder(0.5)
455        elif self._axisbelow is False:
456            self._wcsaxesartist.set_zorder(2.5)
457        else:
458            # 'line': above patches, below lines
459            self._wcsaxesartist.set_zorder(1.5)
460
461        # We need to make sure that that frame path is up to date
462        self.coords.frame._update_patch_path()
463
464        super().draw(renderer, **kwargs)
465
466        self._drawn = True
467
468    # Matplotlib internally sometimes calls set_xlabel(label=...).
469    def set_xlabel(self, xlabel=None, labelpad=1, loc=None, **kwargs):
470        """Set x-label."""
471        if xlabel is None:
472            xlabel = kwargs.pop('label', None)
473            if xlabel is None:
474                raise TypeError("set_xlabel() missing 1 required positional argument: 'xlabel'")
475        for coord in self.coords:
476            if ('b' in coord.axislabels.get_visible_axes() or
477                'h' in coord.axislabels.get_visible_axes()):
478                coord.set_axislabel(xlabel, minpad=labelpad, **kwargs)
479                break
480
481    def set_ylabel(self, ylabel=None, labelpad=1, loc=None, **kwargs):
482        """Set y-label"""
483        if ylabel is None:
484            ylabel = kwargs.pop('label', None)
485            if ylabel is None:
486                raise TypeError("set_ylabel() missing 1 required positional argument: 'ylabel'")
487
488        if self.frame_class is RectangularFrame1D:
489            return super().set_ylabel(ylabel, labelpad=labelpad, **kwargs)
490
491        for coord in self.coords:
492            if ('l' in coord.axislabels.get_visible_axes() or
493                'c' in coord.axislabels.get_visible_axes()):
494                coord.set_axislabel(ylabel, minpad=labelpad, **kwargs)
495                break
496
497    def get_xlabel(self):
498        for coord in self.coords:
499            if ('b' in coord.axislabels.get_visible_axes() or
500                'h' in coord.axislabels.get_visible_axes()):
501                return coord.get_axislabel()
502
503    def get_ylabel(self):
504        if self.frame_class is RectangularFrame1D:
505            return super().get_ylabel()
506
507        for coord in self.coords:
508            if ('l' in coord.axislabels.get_visible_axes() or
509                'c' in coord.axislabels.get_visible_axes()):
510                return coord.get_axislabel()
511
512    def get_coords_overlay(self, frame, coord_meta=None):
513
514        # Here we can't use get_transform because that deals with
515        # pixel-to-pixel transformations when passing a WCS object.
516        if isinstance(frame, WCS):
517            transform, coord_meta = transform_coord_meta_from_wcs(frame, self.frame_class)
518        else:
519            transform = self._get_transform_no_transdata(frame)
520
521        if coord_meta is None:
522            coord_meta = get_coord_meta(frame)
523
524        coords = CoordinatesMap(self, transform=transform,
525                                coord_meta=coord_meta,
526                                frame_class=self.frame_class)
527
528        self._all_coords.append(coords)
529
530        # Common settings for overlay
531        coords[0].set_axislabel_position('t')
532        coords[1].set_axislabel_position('r')
533        coords[0].set_ticklabel_position('t')
534        coords[1].set_ticklabel_position('r')
535
536        self.overlay_coords = coords
537
538        return coords
539
540    def get_transform(self, frame):
541        """
542        Return a transform from the specified frame to display coordinates.
543
544        This does not include the transData transformation
545
546        Parameters
547        ----------
548        frame : :class:`~astropy.wcs.WCS` or :class:`~matplotlib.transforms.Transform` or str
549            The ``frame`` parameter can have several possible types:
550                * :class:`~astropy.wcs.WCS` instance: assumed to be a
551                  transformation from pixel to world coordinates, where the
552                  world coordinates are the same as those in the WCS
553                  transformation used for this ``WCSAxes`` instance. This is
554                  used for example to show contours, since this involves
555                  plotting an array in pixel coordinates that are not the
556                  final data coordinate and have to be transformed to the
557                  common world coordinate system first.
558                * :class:`~matplotlib.transforms.Transform` instance: it is
559                  assumed to be a transform to the world coordinates that are
560                  part of the WCS used to instantiate this ``WCSAxes``
561                  instance.
562                * ``'pixel'`` or ``'world'``: return a transformation that
563                  allows users to plot in pixel/data coordinates (essentially
564                  an identity transform) and ``world`` (the default
565                  world-to-pixel transformation used to instantiate the
566                  ``WCSAxes`` instance).
567                * ``'fk5'`` or ``'galactic'``: return a transformation from
568                  the specified frame to the pixel/data coordinates.
569                * :class:`~astropy.coordinates.BaseCoordinateFrame` instance.
570        """
571        return self._get_transform_no_transdata(frame).inverted() + self.transData
572
573    def _get_transform_no_transdata(self, frame):
574        """
575        Return a transform from data to the specified frame
576        """
577
578        if isinstance(frame, (BaseLowLevelWCS, BaseHighLevelWCS)):
579            if isinstance(frame, BaseHighLevelWCS):
580                frame = frame.low_level_wcs
581
582            transform, coord_meta = transform_coord_meta_from_wcs(frame, self.frame_class)
583            transform_world2pixel = transform.inverted()
584
585            if self._transform_pixel2world.frame_out == transform_world2pixel.frame_in:
586
587                return self._transform_pixel2world + transform_world2pixel
588
589            else:
590
591                return (self._transform_pixel2world +
592                        CoordinateTransform(self._transform_pixel2world.frame_out,
593                                            transform_world2pixel.frame_in) +
594                        transform_world2pixel)
595
596        elif isinstance(frame, str) and frame == 'pixel':
597
598            return Affine2D()
599
600        elif isinstance(frame, Transform):
601
602            return self._transform_pixel2world + frame
603
604        else:
605
606            if isinstance(frame, str) and frame == 'world':
607
608                return self._transform_pixel2world
609
610            else:
611
612                coordinate_transform = CoordinateTransform(self._transform_pixel2world.frame_out, frame)
613
614                if coordinate_transform.same_frames:
615                    return self._transform_pixel2world
616                else:
617                    return self._transform_pixel2world + coordinate_transform
618
619    def get_tightbbox(self, renderer, *args, **kwargs):
620
621        # FIXME: we should determine what to do with the extra arguments here.
622        # Note that the expected signature of this method is different in
623        # Matplotlib 3.x compared to 2.x, but we only support 3.x now.
624
625        if not self.get_visible():
626            return
627
628        bb = [b for b in self._bboxes if b and (b.width != 0 or b.height != 0)]
629        bb.append(super().get_tightbbox(renderer, *args, **kwargs))
630
631        if bb:
632            _bbox = Bbox.union(bb)
633            return _bbox
634        else:
635            return self.get_window_extent(renderer)
636
637    def grid(self, b=None, axis='both', *, which='major', **kwargs):
638        """
639        Plot gridlines for both coordinates.
640
641        Standard matplotlib appearance options (color, alpha, etc.) can be
642        passed as keyword arguments. This behaves like `matplotlib.axes.Axes`
643        except that if no arguments are specified, the grid is shown rather
644        than toggled.
645
646        Parameters
647        ----------
648        b : bool
649            Whether to show the gridlines.
650        axis : 'both', 'x', 'y'
651            Which axis to turn the gridlines on/off for.
652        which : str
653            Currently only ``'major'`` is supported.
654        """
655
656        if not hasattr(self, 'coords'):
657            return
658
659        if which != 'major':
660            raise NotImplementedError('Plotting the grid for the minor ticks is '
661                                      'not supported.')
662
663        if axis == 'both':
664            self.coords.grid(draw_grid=b, **kwargs)
665        elif axis == 'x':
666            self.coords[0].grid(draw_grid=b, **kwargs)
667        elif axis == 'y':
668            self.coords[1].grid(draw_grid=b, **kwargs)
669        else:
670            raise ValueError('axis should be one of x/y/both')
671
672    def tick_params(self, axis='both', **kwargs):
673        """
674        Method to set the tick and tick label parameters in the same way as the
675        :meth:`~matplotlib.axes.Axes.tick_params` method in Matplotlib.
676
677        This is provided for convenience, but the recommended API is to use
678        :meth:`~astropy.visualization.wcsaxes.CoordinateHelper.set_ticks`,
679        :meth:`~astropy.visualization.wcsaxes.CoordinateHelper.set_ticklabel`,
680        :meth:`~astropy.visualization.wcsaxes.CoordinateHelper.set_ticks_position`,
681        :meth:`~astropy.visualization.wcsaxes.CoordinateHelper.set_ticklabel_position`,
682        and :meth:`~astropy.visualization.wcsaxes.CoordinateHelper.grid`.
683
684        Parameters
685        ----------
686        axis : int or str, optional
687            Which axis to apply the parameters to. This defaults to 'both'
688            but this can also be set to an `int` or `str` that refers to the
689            axis to apply it to, following the valid values that can index
690            ``ax.coords``. Note that ``'x'`` and ``'y``' are also accepted in
691            the case of rectangular axes.
692        which : {'both', 'major', 'minor'}, optional
693            Which ticks to apply the settings to. By default, setting are
694            applied to both major and minor ticks. Note that if ``'minor'`` is
695            specified, only the length of the ticks can be set currently.
696        direction : {'in', 'out'}, optional
697            Puts ticks inside the axes, or outside the axes.
698        length : float, optional
699            Tick length in points.
700        width : float, optional
701            Tick width in points.
702        color : color, optional
703            Tick color (accepts any valid Matplotlib color)
704        pad : float, optional
705            Distance in points between tick and label.
706        labelsize : float or str, optional
707            Tick label font size in points or as a string (e.g., 'large').
708        labelcolor : color, optional
709            Tick label color (accepts any valid Matplotlib color)
710        colors : color, optional
711            Changes the tick color and the label color to the same value
712             (accepts any valid Matplotlib color).
713        bottom, top, left, right : bool, optional
714            Where to draw the ticks. Note that this can only be given if a
715            specific coordinate is specified via the ``axis`` argument, and it
716            will not work correctly if the frame is not rectangular.
717        labelbottom, labeltop, labelleft, labelright : bool, optional
718            Where to draw the tick labels. Note that this can only be given if a
719            specific coordinate is specified via the ``axis`` argument, and it
720            will not work correctly if the frame is not rectangular.
721        grid_color : color, optional
722            The color of the grid lines (accepts any valid Matplotlib color).
723        grid_alpha : float, optional
724            Transparency of grid lines: 0 (transparent) to 1 (opaque).
725        grid_linewidth : float, optional
726            Width of grid lines in points.
727        grid_linestyle : str, optional
728            The style of the grid lines (accepts any valid Matplotlib line
729            style).
730        """
731
732        if not hasattr(self, 'coords'):
733            # Axes haven't been fully initialized yet, so just ignore, as
734            # Axes.__init__ calls this method
735            return
736
737        if axis == 'both':
738
739            for pos in ('bottom', 'left', 'top', 'right'):
740                if pos in kwargs:
741                    raise ValueError(f"Cannot specify {pos}= when axis='both'")
742                if 'label' + pos in kwargs:
743                    raise ValueError(f"Cannot specify label{pos}= when axis='both'")
744
745            for coord in self.coords:
746                coord.tick_params(**kwargs)
747
748        elif axis in self.coords:
749
750            self.coords[axis].tick_params(**kwargs)
751
752        elif axis in ('x', 'y') and self.frame_class is RectangularFrame:
753
754            spine = 'b' if axis == 'x' else 'l'
755
756            for coord in self.coords:
757                if spine in coord.axislabels.get_visible_axes():
758                    coord.tick_params(**kwargs)
759
760
761# In the following, we put the generated subplot class in a temporary class and
762# we then inherit it - if we don't do this, the generated class appears to
763# belong in matplotlib, not in WCSAxes, from the API's point of view.
764
765
class WCSAxesSubplot(subplot_class_factory(WCSAxes)):
    """
    Subplot version of `WCSAxes`.

    The base class is generated by Matplotlib's ``subplot_class_factory``;
    subclassing it here makes the resulting class belong to this module.
    """
771