1from numbers import Number
2import functools
3
4import numpy as np
5
6import matplotlib as mpl
7from matplotlib import _api
8from matplotlib.gridspec import SubplotSpec
9
10from .axes_divider import Size, SubplotDivider, Divider
11from .mpl_axes import Axes
12
13
14def _tick_only(ax, bottom_on, left_on):
15    bottom_off = not bottom_on
16    left_off = not left_on
17    ax.axis["bottom"].toggle(ticklabels=bottom_off, label=bottom_off)
18    ax.axis["left"].toggle(ticklabels=left_off, label=left_off)
19
20
class CbarAxesBase:
    """
    Mixin giving an axes the ability to host a colorbar inside an axes grid.

    *orientation* names the side of the axes ("left", "right", "bottom" or
    "top") whose axis artist carries the colorbar ticks and label; "top" and
    "bottom" produce a horizontal colorbar, anything else a vertical one.
    """

    def __init__(self, *args, orientation, **kwargs):
        # NOTE(review): these are set before super().__init__() presumably
        # because the base initializer may call cla() (overridden below),
        # whose _config_axes() reads them.
        self.orientation = orientation
        self._default_label_on = True
        self._locator = None  # deprecated.
        super().__init__(*args, **kwargs)

    def colorbar(self, mappable, *, ticks=None, **kwargs):
        """
        Draw a colorbar for *mappable* into this axes and return it.

        *ticks* and any extra keyword arguments are forwarded to
        `matplotlib.colorbar.Colorbar`.
        """
        horizontal = self.orientation in ("top", "bottom")
        cb = mpl.colorbar.Colorbar(
            self, mappable,
            orientation="horizontal" if horizontal else "vertical",
            ticks=ticks, **kwargs)
        # Keep the private mirrors behind the 3.3-deprecated attributes
        # ``cbid`` and ``locator`` up to date.
        self._cbid = mappable.colorbar_cid
        self._locator = cb.locator

        self._config_axes()
        return cb

    cbid = _api.deprecate_privatize_attribute(
        "3.3", alternative="mappable.colorbar_cid")
    locator = _api.deprecate_privatize_attribute(
        "3.3", alternative=".colorbar().locator")

    def _config_axes(self):
        """
        Disable navigation and hide every axis artist except the one named
        by *orientation*, which is toggled according to the current label
        state (see `toggle_label`).
        """
        self.set_navigate(False)
        self.axis[:].toggle(all=False)
        self.axis[self.orientation].toggle(all=self._default_label_on)

    def toggle_label(self, b):
        """
        Show (*b* true) or hide (*b* false) the tick labels and the label of
        the colorbar axis, and remember the choice for later `cla` calls.
        """
        self._default_label_on = b
        self.axis[self.orientation].toggle(ticklabels=b, label=b)

    def cla(self):
        # Clearing resets the axis artists, so re-apply the colorbar layout.
        super().cla()
        self._config_axes()
64
65
class CbarAxes(CbarAxesBase, Axes):
    """Colorbar axes: `CbarAxesBase` behavior on top of the toolkit `Axes`."""
68
69
class Grid:
    """
    A grid of Axes.

    In Matplotlib, the axes location (and size) is specified in normalized
    figure coordinates. This may not be ideal for images that need to be
    displayed with a given aspect ratio; for example, it is difficult to
    display multiple images of the same size with some fixed padding between
    them.  AxesGrid can be used in such a case.

    After construction, the axes are available as:

    - ``axes_all``: flat list of the *ngrids* created axes, in the order set
      by *direction* (this is also the order used by ``grid[index]``);
    - ``axes_row`` / ``axes_column``: nested lists indexed by row / column,
      containing None for cells left empty because of *ngrids*;
    - ``axes_llc``: the bottom-most axes of the first column.
    """

    _defaultAxesClass = Axes

    @_api.delete_parameter("3.3", "add_all")
    def __init__(self, fig,
                 rect,
                 nrows_ncols,
                 ngrids=None,
                 direction="row",
                 axes_pad=0.02,
                 add_all=True,
                 share_all=False,
                 share_x=True,
                 share_y=True,
                 label_mode="L",
                 axes_class=None,
                 *,
                 aspect=False,
                 ):
        """
        Parameters
        ----------
        fig : `.Figure`
            The parent figure.
        rect : (float, float, float, float) or int
            The axes position, as a ``(left, bottom, width, height)`` tuple or
            as a three-digit subplot position code (e.g., "121").  A string,
            a `.SubplotSpec` or a 3-tuple of subplot arguments is also
            accepted and passed on to `.SubplotDivider`.
        nrows_ncols : (int, int)
            Number of rows and columns in the grid.
        ngrids : int or None, default: None
            If not None, only the first *ngrids* axes in the grid (in the
            order given by *direction*) are created.  Must satisfy
            ``0 < ngrids <= nrows * ncols``.
        direction : {"row", "column"}, default: "row"
            Whether axes are created in row-major ("row by row") or
            column-major order ("column by column").  This also determines
            the order of ``axes_all`` and hence of indexing
            (``grid[index]``).
        axes_pad : float or (float, float), default: 0.02
            Padding or (horizontal padding, vertical padding) between axes, in
            inches.
        add_all : bool, default: True
            Whether to add the axes to the figure using `.Figure.add_axes`.
            This parameter is deprecated.
        share_all : bool, default: False
            Whether all axes share their x- and y-axis.  Overrides *share_x*
            and *share_y*.
        share_x : bool, default: True
            Whether all axes of a column share their x-axis.
        share_y : bool, default: True
            Whether all axes of a row share their y-axis.
        label_mode : {"L", "1", "all"}, default: "L"
            Determines which axes will get tick labels:

            - "L": All axes on the left column get vertical tick labels;
              all axes on the bottom row get horizontal tick labels.
            - "1": Only the bottom left axes is labelled.
            - "all": all axes are labelled.

        axes_class : subclass of `matplotlib.axes.Axes` or (class, dict), \
default: None
            Class used to create the axes; None means ``_defaultAxesClass``.
            If a ``(class, dict)`` pair is given, the dict is passed as extra
            keyword arguments to the class.
        aspect : bool, default: False
            Whether the axes aspect ratio follows the aspect ratio of the data
            limits.

        Raises
        ------
        ValueError
            If *ngrids* is out of range, *direction* is invalid, or *rect* is
            a sequence of a length other than 3 or 4.
        """
        self._nrows, self._ncols = nrows_ncols

        n_cells = self._nrows * self._ncols
        if ngrids is None:
            ngrids = n_cells
        elif not 0 < ngrids <= n_cells:
            raise ValueError(
                f"ngrids must be in the range 1..{n_cells} (nrows * ncols), "
                f"but got {ngrids!r}")

        self.ngrids = ngrids

        self._horiz_pad_size, self._vert_pad_size = map(
            Size.Fixed, np.broadcast_to(axes_pad, 2))

        _api.check_in_list(["column", "row"], direction=direction)
        self._direction = direction

        if axes_class is None:
            axes_class = self._defaultAxesClass
        elif isinstance(axes_class, (list, tuple)):
            cls, kwargs = axes_class
            axes_class = functools.partial(cls, **kwargs)

        kw = dict(horizontal=[], vertical=[], aspect=aspect)
        if isinstance(rect, (str, Number, SubplotSpec)):
            self._divider = SubplotDivider(fig, rect, **kw)
        elif len(rect) == 3:
            self._divider = SubplotDivider(fig, *rect, **kw)
        elif len(rect) == 4:
            self._divider = Divider(fig, rect, **kw)
        else:
            raise ValueError(
                "rect must be a subplot position (int, str or SubplotSpec), "
                "a 3-tuple of subplot arguments or a (left, bottom, width, "
                f"height) 4-tuple, but got {rect!r}")

        rect = self._divider.get_position()

        axes_array = np.full((self._nrows, self._ncols), None, dtype=object)
        for i in range(self.ngrids):
            col, row = self._get_col_row(i)
            if share_all:
                sharex = sharey = axes_array[0, 0]
            else:
                sharex = axes_array[0, col] if share_x else None
                sharey = axes_array[row, 0] if share_y else None
            axes_array[row, col] = axes_class(
                fig, rect, sharex=sharex, sharey=sharey)
        # Flatten in creation order (column-major for direction="column") so
        # that axes_all[i] is the axes stored at cell _get_col_row(i); the
        # cells left empty because of ngrids come last in that order and are
        # dropped.
        order = "F" if self._direction == "column" else "C"
        self.axes_all = axes_array.ravel(order=order).tolist()[:self.ngrids]
        self.axes_column = axes_array.T.tolist()
        self.axes_row = axes_array.tolist()
        # The lower-left cell may be empty when ngrids < nrows * ncols; use
        # the bottom-most existing axes of the first column instead (cell
        # (0, 0) always exists since ngrids >= 1).
        self.axes_llc = [ax for ax in self.axes_column[0]
                         if ax is not None][-1]

        self._init_locators()

        if add_all:
            for ax in self.axes_all:
                fig.add_axes(ax)

        self.set_label_mode(label_mode)

    def _init_locators(self):
        # One Scaled(1) cell per column (and per row), separated by the
        # fixed padding sizes; h/v_ax_pos record the cell index of each
        # column/row in the divider's horizontal/vertical size lists.
        h = []
        h_ax_pos = []
        for _ in range(self._ncols):
            if h:
                h.append(self._horiz_pad_size)
            h_ax_pos.append(len(h))
            h.append(Size.Scaled(1))

        v = []
        v_ax_pos = []
        for _ in range(self._nrows):
            if v:
                v.append(self._vert_pad_size)
            v_ax_pos.append(len(v))
            v.append(Size.Scaled(1))

        for i in range(self.ngrids):
            col, row = self._get_col_row(i)
            # Grid row 0 is mapped to the last vertical cell (rows are
            # counted from the top, the vertical list is flipped).
            locator = self._divider.new_locator(
                nx=h_ax_pos[col], ny=v_ax_pos[self._nrows - 1 - row])
            self.axes_all[i].set_axes_locator(locator)

        self._divider.set_horizontal(h)
        self._divider.set_vertical(v)

    def _get_col_row(self, n):
        """Return the (col, row) cell of the *n*-th axes for *direction*."""
        if self._direction == "column":
            col, row = divmod(n, self._nrows)
        else:
            row, col = divmod(n, self._ncols)

        return col, row

    # Good to propagate __len__ if we have __getitem__
    def __len__(self):
        return len(self.axes_all)

    def __getitem__(self, i):
        return self.axes_all[i]

    def get_geometry(self):
        """
        Return the number of rows and columns of the grid as (nrows, ncols).
        """
        return self._nrows, self._ncols

    def set_axes_pad(self, axes_pad):
        """
        Set the padding between the axes.

        Parameters
        ----------
        axes_pad : float or (float, float)
            The padding (horizontal pad, vertical pad) in inches.  A single
            float is used for both directions, as in the constructor.
        """
        if np.ndim(axes_pad) == 0:
            axes_pad = (axes_pad, axes_pad)
        self._horiz_pad_size.fixed_size = axes_pad[0]
        self._vert_pad_size.fixed_size = axes_pad[1]

    def get_axes_pad(self):
        """
        Return the axes padding.

        Returns
        -------
        hpad, vpad
            Padding (horizontal pad, vertical pad) in inches.
        """
        return (self._horiz_pad_size.fixed_size,
                self._vert_pad_size.fixed_size)

    def set_aspect(self, aspect):
        """Set the aspect of the SubplotDivider."""
        self._divider.set_aspect(aspect)

    def get_aspect(self):
        """Return the aspect of the SubplotDivider."""
        return self._divider.get_aspect()

    def set_label_mode(self, mode):
        """
        Define which axes have tick labels.

        Parameters
        ----------
        mode : {"L", "1", "all"}
            The label mode:

            - "L": All axes on the left column get vertical tick labels;
              all axes on the bottom row get horizontal tick labels.
            - "1": Only the bottom left axes is labelled.
            - "all": all axes are labelled.

            Any other value leaves the labels unchanged.  Cells left empty
            because of *ngrids* are skipped, and the bottom-most existing
            axes of each column is treated as that column's bottom axes.
        """
        # Reminder: _tick_only's *_on flags *hide* labels when True.
        if mode == "all":
            for ax in self.axes_all:
                _tick_only(ax, False, False)
        elif mode == "L":
            for col_idx, column in enumerate(self.axes_column):
                present = [ax for ax in column if ax is not None]
                if not present:
                    continue
                # Only the left-most column keeps its y tick labels.
                hide_left = col_idx != 0
                for ax in present[:-1]:
                    _tick_only(ax, bottom_on=True, left_on=hide_left)
                # The bottom axes of each column keeps its x tick labels.
                _tick_only(present[-1], bottom_on=False, left_on=hide_left)
        elif mode == "1":
            for ax in self.axes_all:
                _tick_only(ax, bottom_on=True, left_on=True)

            _tick_only(self.axes_llc, bottom_on=False, left_on=False)

    def get_divider(self):
        """Return the divider that lays out the grid."""
        return self._divider

    def set_axes_locator(self, locator):
        """Set the locator of the underlying divider."""
        self._divider.set_locator(locator)

    def get_axes_locator(self):
        """Return the locator of the underlying divider."""
        return self._divider.get_locator()

    def get_vsize_hsize(self):
        """Return ``get_vsize_hsize()`` of the underlying divider."""
        return self._divider.get_vsize_hsize()
331
332
class ImageGrid(Grid):
    # docstring inherited

    _defaultCbarAxesClass = CbarAxes

    @_api.delete_parameter("3.3", "add_all")
    def __init__(self, fig,
                 rect,
                 nrows_ncols,
                 ngrids=None,
                 direction="row",
                 axes_pad=0.02,
                 add_all=True,
                 share_all=False,
                 aspect=True,
                 label_mode="L",
                 cbar_mode=None,
                 cbar_location="right",
                 cbar_pad=None,
                 cbar_size="5%",
                 cbar_set_cax=True,
                 axes_class=None,
                 ):
        """
        Parameters
        ----------
        fig : `.Figure`
            The parent figure.
        rect : (float, float, float, float) or int
            The axes position, as a ``(left, bottom, width, height)`` tuple or
            as a three-digit subplot position code (e.g., "121").
        nrows_ncols : (int, int)
            Number of rows and columns in the grid.
        ngrids : int or None, default: None
            If not None, only the first *ngrids* axes in the grid are created.
        direction : {"row", "column"}, default: "row"
            Whether axes are created in row-major ("row by row") or
            column-major order ("column by column").  This also affects the
            order in which axes are accessed using indexing (``grid[index]``).
        axes_pad : float or (float, float), default: 0.02in
            Padding or (horizontal padding, vertical padding) between axes, in
            inches.
        add_all : bool, default: True
            Whether to add the axes to the figure using `.Figure.add_axes`.
            This parameter is deprecated.
        share_all : bool, default: False
            Whether all axes share their x- and y-axis.
        aspect : bool, default: True
            Whether the axes aspect ratio follows the aspect ratio of the data
            limits.
        label_mode : {"L", "1", "all"}, default: "L"
            Determines which axes will get tick labels:

            - "L": All axes on the left column get vertical tick labels;
              all axes on the bottom row get horizontal tick labels.
            - "1": Only the bottom left axes is labelled.
            - "all": all axes are labelled.

        cbar_mode : {"each", "single", "edge", None}, default: None
            Whether to create a colorbar for "each" axes, a "single" colorbar
            for the entire grid, colorbars only for axes on the "edge"
            determined by *cbar_location*, or no colorbars.  The colorbars are
            stored in the :attr:`cbar_axes` attribute (one colorbar axes per
            grid cell is created; unused ones are hidden).
        cbar_location : {"left", "right", "bottom", "top"}, default: "right"
            Side of the axes on which the colorbars are placed.
        cbar_pad : float, default: None
            Padding between the image axes and the colorbar axes.  If None,
            the horizontal (for "left"/"right") or vertical (for
            "bottom"/"top") axes padding is used.
        cbar_size : size specification (see `.Size.from_any`), default: "5%"
            Colorbar size.
        cbar_set_cax : bool, default: True
            If True, each axes in the grid has a *cax* attribute that is bound
            to associated *cbar_axes*.
        axes_class : subclass of `matplotlib.axes.Axes`, default: None

        Raises
        ------
        ValueError
            If *cbar_mode* or *cbar_location* is not one of the supported
            values.
        """
        # Validate early: an unknown mode would otherwise silently produce no
        # colorbars, and an unknown location would fail later with an
        # obscure lookup error.
        _api.check_in_list(["each", "single", "edge", None],
                           cbar_mode=cbar_mode)
        _api.check_in_list(["left", "right", "bottom", "top"],
                           cbar_location=cbar_location)
        self._colorbar_mode = cbar_mode
        self._colorbar_location = cbar_location
        self._colorbar_pad = cbar_pad
        self._colorbar_size = cbar_size
        # The colorbar axes are created in _init_locators().

        grid_kwargs = dict(
            direction=direction, axes_pad=axes_pad,
            share_all=share_all, share_x=True, share_y=True, aspect=aspect,
            label_mode=label_mode, axes_class=axes_class)
        if not add_all:
            # Only forward the deprecated add_all when it was changed, so the
            # deprecation warning is shown only in that case.
            grid_kwargs["add_all"] = add_all
        super().__init__(fig, rect, nrows_ncols, ngrids, **grid_kwargs)

        if add_all:
            for ax in self.cbar_axes:
                fig.add_axes(ax)

        if cbar_set_cax:
            if self._colorbar_mode == "single":
                for ax in self.axes_all:
                    ax.cax = self.cbar_axes[0]
            elif self._colorbar_mode == "edge":
                for index, ax in enumerate(self.axes_all):
                    col, row = self._get_col_row(index)
                    if self._colorbar_location in ("left", "right"):
                        ax.cax = self.cbar_axes[row]
                    else:
                        ax.cax = self.cbar_axes[col]
            else:
                for ax, cax in zip(self.axes_all, self.cbar_axes):
                    ax.cax = cax

    def _init_locators(self):
        # Slightly abusing this method to inject colorbar creation into init.

        if self._colorbar_pad is None:
            # horizontal or vertical arrangement?
            if self._colorbar_location in ("left", "right"):
                self._colorbar_pad = self._horiz_pad_size.fixed_size
            else:
                self._colorbar_pad = self._vert_pad_size.fixed_size
        self.cbar_axes = [
            self._defaultCbarAxesClass(
                self.axes_all[0].figure, self._divider.get_position(),
                orientation=self._colorbar_location)
            for _ in range(self.ngrids)]

        cb_mode = self._colorbar_mode
        cb_location = self._colorbar_location

        h = []
        v = []

        h_ax_pos = []
        h_cb_pos = []
        # A single colorbar on the left/bottom must come first in the size
        # lists, before any axes cell.
        if cb_mode == "single" and cb_location in ("left", "bottom"):
            if cb_location == "left":
                sz = self._nrows * Size.AxesX(self.axes_llc)
                h.append(Size.from_any(self._colorbar_size, sz))
                h.append(Size.from_any(self._colorbar_pad, sz))
                locator = self._divider.new_locator(nx=0, ny=0, ny1=-1)
            elif cb_location == "bottom":
                sz = self._ncols * Size.AxesY(self.axes_llc)
                v.append(Size.from_any(self._colorbar_size, sz))
                v.append(Size.from_any(self._colorbar_pad, sz))
                locator = self._divider.new_locator(nx=0, nx1=-1, ny=0)
            for i in range(self.ngrids):
                self.cbar_axes[i].set_visible(False)
            self.cbar_axes[0].set_axes_locator(locator)
            self.cbar_axes[0].set_visible(True)

        for col, ax in enumerate(self.axes_row[0]):
            if h:
                h.append(self._horiz_pad_size)

            if ax is not None:
                sz = Size.AxesX(ax, aspect="axes", ref_ax=self.axes_all[0])
            else:
                sz = Size.AxesX(self.axes_all[0],
                                aspect="axes", ref_ax=self.axes_all[0])

            if (cb_location == "left"
                    and (cb_mode == "each"
                         or (cb_mode == "edge" and col == 0))):
                h_cb_pos.append(len(h))
                h.append(Size.from_any(self._colorbar_size, sz))
                h.append(Size.from_any(self._colorbar_pad, sz))

            h_ax_pos.append(len(h))
            h.append(sz)

            if (cb_location == "right"
                    and (cb_mode == "each"
                         or (cb_mode == "edge" and col == self._ncols - 1))):
                h.append(Size.from_any(self._colorbar_pad, sz))
                h_cb_pos.append(len(h))
                h.append(Size.from_any(self._colorbar_size, sz))

        v_ax_pos = []
        v_cb_pos = []
        for row, ax in enumerate(self.axes_column[0][::-1]):
            if v:
                v.append(self._vert_pad_size)

            if ax is not None:
                sz = Size.AxesY(ax, aspect="axes", ref_ax=self.axes_all[0])
            else:
                sz = Size.AxesY(self.axes_all[0],
                                aspect="axes", ref_ax=self.axes_all[0])

            if (cb_location == "bottom"
                    and (cb_mode == "each"
                         or (cb_mode == "edge" and row == 0))):
                v_cb_pos.append(len(v))
                v.append(Size.from_any(self._colorbar_size, sz))
                v.append(Size.from_any(self._colorbar_pad, sz))

            v_ax_pos.append(len(v))
            v.append(sz)

            if (cb_location == "top"
                    and (cb_mode == "each"
                         or (cb_mode == "edge" and row == self._nrows - 1))):
                v.append(Size.from_any(self._colorbar_pad, sz))
                v_cb_pos.append(len(v))
                v.append(Size.from_any(self._colorbar_size, sz))

        for i in range(self.ngrids):
            col, row = self._get_col_row(i)
            locator = self._divider.new_locator(nx=h_ax_pos[col],
                                                ny=v_ax_pos[self._nrows-1-row])
            self.axes_all[i].set_axes_locator(locator)

            if cb_mode == "each":
                if cb_location in ("right", "left"):
                    locator = self._divider.new_locator(
                        nx=h_cb_pos[col], ny=v_ax_pos[self._nrows - 1 - row])

                elif cb_location in ("top", "bottom"):
                    locator = self._divider.new_locator(
                        nx=h_ax_pos[col], ny=v_cb_pos[self._nrows - 1 - row])

                self.cbar_axes[i].set_axes_locator(locator)
            elif cb_mode == "edge":
                if (cb_location == "left" and col == 0
                        or cb_location == "right" and col == self._ncols - 1):
                    locator = self._divider.new_locator(
                        nx=h_cb_pos[0], ny=v_ax_pos[self._nrows - 1 - row])
                    self.cbar_axes[row].set_axes_locator(locator)
                elif (cb_location == "bottom" and row == self._nrows - 1
                      or cb_location == "top" and row == 0):
                    locator = self._divider.new_locator(nx=h_ax_pos[col],
                                                        ny=v_cb_pos[0])
                    self.cbar_axes[col].set_axes_locator(locator)

        if cb_mode == "single":
            # A single colorbar on the right/top goes after all axes cells.
            if cb_location == "right":
                sz = self._nrows * Size.AxesX(self.axes_llc)
                h.append(Size.from_any(self._colorbar_pad, sz))
                h.append(Size.from_any(self._colorbar_size, sz))
                locator = self._divider.new_locator(nx=-2, ny=0, ny1=-1)
            elif cb_location == "top":
                sz = self._ncols * Size.AxesY(self.axes_llc)
                v.append(Size.from_any(self._colorbar_pad, sz))
                v.append(Size.from_any(self._colorbar_size, sz))
                locator = self._divider.new_locator(nx=0, nx1=-1, ny=-2)
            if cb_location in ("right", "top"):
                for i in range(self.ngrids):
                    self.cbar_axes[i].set_visible(False)
                self.cbar_axes[0].set_axes_locator(locator)
                self.cbar_axes[0].set_visible(True)
        elif cb_mode == "each":
            for i in range(self.ngrids):
                self.cbar_axes[i].set_visible(True)
        elif cb_mode == "edge":
            if cb_location in ("right", "left"):
                count = self._nrows
            else:
                count = self._ncols
            # Only the first *count* colorbar axes are used.  There may be
            # fewer colorbar axes than *count* (when ngrids < count), so
            # iterate over the ones that actually exist.
            for i, cax in enumerate(self.cbar_axes):
                cax.set_visible(i < count)
        else:
            for i in range(self.ngrids):
                self.cbar_axes[i].set_visible(False)
                self.cbar_axes[i].set_position([1., 1., 0.001, 0.001],
                                               which="active")

        self._divider.set_horizontal(h)
        self._divider.set_vertical(v)
603
604
# ``AxesGrid`` is an alias of ``ImageGrid``: both names refer to the very
# same class object.
AxesGrid = ImageGrid
606