1from __future__ import (absolute_import, division, print_function,
2                        unicode_literals)
3
4import six
5
6import matplotlib.axes as maxes
7import matplotlib.cbook as cbook
8import matplotlib.ticker as ticker
9from matplotlib.gridspec import SubplotSpec
10
11from .axes_divider import Size, SubplotDivider, LocatableAxes, Divider
12from .colorbar import Colorbar
13
14
15def _extend_axes_pad(value):
16    # Check whether a list/tuple/array or scalar has been passed
17    ret = value
18    if not hasattr(ret, "__getitem__"):
19        ret = (value, value)
20    return ret
21
22
def _tick_only(ax, bottom_on, left_on):
    """
    Toggle the tick labels and axis label of the bottom and left sides of
    *ax*.

    The flags have an inverted sense.  ``bottom_on=True`` *hides* the bottom
    tick labels and label, and ``bottom_on=False`` shows them.  *left_on*
    works the same way for the left side.
    """
    for side, hide in (("bottom", bottom_on), ("left", left_on)):
        visible = not hide
        ax.axis[side].toggle(ticklabels=visible, label=visible)
32
33
class CbarAxesBase(object):
    """
    Mixin that turns an axes into a host for a single colorbar.

    Concrete classes must set ``orientation``, the side the colorbar sits on
    (ImageGrid passes its *cbar_location*).  They must also set
    ``_default_label_on`` and provide the ``axis`` mapping used below.
    """

    def colorbar(self, mappable, **kwargs):
        """
        Draw a colorbar for *mappable* in this axes and return it.

        *locator* and *ticks* are alternative ways to set tick placement.
        Passing both raises ValueError.  With neither, a
        ``ticker.MaxNLocator(5)`` is used.  The colorbar is horizontal for a
        "top"/"bottom" orientation and vertical otherwise.  Any remaining
        keyword arguments are forwarded to :class:`Colorbar`.
        """
        locator = kwargs.pop("locator", None)
        if locator is None:
            if "ticks" not in kwargs:
                kwargs["ticks"] = ticker.MaxNLocator(5)
        elif "ticks" in kwargs:
            raise ValueError("Either *locator* or *ticks* need"
                             " to be given, not both")
        else:
            kwargs["ticks"] = locator

        self._hold = True
        is_horizontal = self.orientation in ("top", "bottom")
        cb = Colorbar(self, mappable,
                      orientation="horizontal" if is_horizontal else "vertical",
                      **kwargs)
        self._config_axes()

        def on_mappable_changed(m):
            # Keep the colorbar in step with the mappable's cmap and clim.
            cb.set_cmap(m.get_cmap())
            cb.set_clim(m.get_clim())
            cb.update_bruteforce(m)

        self.cbid = mappable.callbacksSM.connect('changed',
                                                 on_mappable_changed)
        mappable.colorbar = cb
        self.locator = cb.cbar_axis.get_major_locator()
        return cb

    def _config_axes(self):
        """
        Configure this axes for colorbar use.

        Stop the axes responding to navigation and switch off every axis
        side.  Then toggle the side named by ``orientation`` according to
        ``_default_label_on``.
        """
        self.set_navigate(False)
        self.axis[:].toggle(all=False)
        self.axis[self.orientation].toggle(all=self._default_label_on)

    def toggle_label(self, b):
        """
        Show (*b* true) or hide the tick labels and label on the colorbar
        side.  *b* is remembered for later calls to ``_config_axes``.
        """
        self._default_label_on = b
        self.axis[self.orientation].toggle(ticklabels=b, label=b)
107
108
class CbarAxes(CbarAxesBase, LocatableAxes):
    """
    Axes that hosts a colorbar (created by :class:`ImageGrid`).

    The ``orientation`` keyword is required; ImageGrid passes its
    *cbar_location*.  All other arguments are forwarded to the base axes
    constructor.

    Raises
    ------
    ValueError
        If ``orientation`` is not given.
    """
    def __init__(self, *kl, **kwargs):
        orientation = kwargs.pop("orientation", None)
        if orientation is None:
            raise ValueError("orientation must be specified")
        # These are set before the base __init__ runs because cla() below
        # calls _config_axes, which reads orientation/_default_label_on.
        # This presumably matters if the base __init__ calls cla().
        self.orientation = orientation
        self._default_label_on = True
        self.locator = None

        # NOTE(review): super(LocatableAxes, self) starts the method lookup
        # *after* LocatableAxes in the MRO.  Any __init__ defined directly on
        # LocatableAxes is therefore skipped -- confirm this is intentional.
        super(LocatableAxes, self).__init__(*kl, **kwargs)

    def cla(self):
        # Same MRO skip as in __init__, then re-apply the colorbar axis
        # configuration after clearing.
        super(LocatableAxes, self).cla()
        self._config_axes()
123
124
class Grid(object):
    """
    A class that creates a grid of Axes. In matplotlib, the axes
    location (and size) is specified in normalized figure
    coordinates. This may not be ideal for images that need to be
    displayed with a given aspect ratio.  For example, displaying
    images of the same size with some fixed padding between them cannot
    be easily done in matplotlib. This class is used in such cases.
    """

    _defaultLocatableAxesClass = LocatableAxes

    def __init__(self, fig,
                 rect,
                 nrows_ncols,
                 ngrids=None,
                 direction="row",
                 axes_pad=0.02,
                 add_all=True,
                 share_all=False,
                 share_x=True,
                 share_y=True,
                 label_mode="L",
                 axes_class=None,
                 ):
        """
        Build an :class:`Grid` instance with a grid nrows*ncols
        :class:`~matplotlib.axes.Axes` in
        :class:`~matplotlib.figure.Figure` *fig* with
        *rect=[left, bottom, width, height]* (in
        :class:`~matplotlib.figure.Figure` coordinates) or
        the subplot position code (e.g., "121").

        Optional keyword arguments:

          ================  ========  =========================================
          Keyword           Default   Description
          ================  ========  =========================================
          direction         "row"     [ "row" | "column" ]
          axes_pad          0.02      float| pad between axes given in inches
                                      or tuple-like of floats,
                                      (horizontal padding, vertical padding)
          add_all           True      bool
          share_all         False     bool
          share_x           True      bool
          share_y           True      bool
          label_mode        "L"       [ "L" | "1" | "all" ]
          axes_class        None      a subclass of the default locatable
                                      axes class, or a (class, kwargs) pair
          ================  ========  =========================================

        Raises
        ------
        ValueError
            If *ngrids* is not in ``1..nrows*ncols``, if *direction* is
            neither "row" nor "column", or if *rect* has an unsupported
            form.
        """
        self._nrows, self._ncols = nrows_ncols

        if ngrids is None:
            ngrids = self._nrows * self._ncols
        elif not 0 < ngrids <= self._nrows * self._ncols:
            raise ValueError(
                "ngrids must satisfy 0 < ngrids <= nrows * ncols "
                "(got {!r} for a {}x{} grid)".format(
                    ngrids, self._nrows, self._ncols))

        self.ngrids = ngrids

        self._init_axes_pad(axes_pad)

        if direction not in ["column", "row"]:
            raise ValueError('direction must be "row" or "column", '
                             'got {!r}'.format(direction))

        self._direction = direction

        if axes_class is None:
            axes_class = self._defaultLocatableAxesClass
            axes_class_args = {}
        elif (isinstance(axes_class, type) and
              issubclass(axes_class, self._defaultLocatableAxesClass.Axes)):
            axes_class_args = {}
        else:
            # Otherwise expect a (class, constructor-kwargs) pair.
            axes_class, axes_class_args = axes_class

        self.axes_all = []
        self.axes_column = [[] for _ in range(self._ncols)]
        self.axes_row = [[] for _ in range(self._nrows)]

        # The size lists are filled in later by _update_locators.
        h = []
        v = []
        if (isinstance(rect, six.string_types + (SubplotSpec,)) or
                cbook.is_numlike(rect)):
            self._divider = SubplotDivider(fig, rect, horizontal=h, vertical=v,
                                           aspect=False)
        elif len(rect) == 3:
            kw = dict(horizontal=h, vertical=v, aspect=False)
            self._divider = SubplotDivider(fig, *rect, **kw)
        elif len(rect) == 4:
            self._divider = Divider(fig, rect, horizontal=h, vertical=v,
                                    aspect=False)
        else:
            raise ValueError(
                "rect must be a subplot position code, a SubplotSpec, "
                "3 subplot arguments or [left, bottom, width, height]; "
                "got {!r}".format(rect))

        rect = self._divider.get_position()

        # Reference axes that later axes in the same column/row (or the
        # whole grid, with share_all) share their x/y axis with.
        self._column_refax = [None for _ in range(self._ncols)]
        self._row_refax = [None for _ in range(self._nrows)]
        self._refax = None

        for i in range(self.ngrids):

            col, row = self._get_col_row(i)

            if share_all:
                sharex = self._refax
                sharey = self._refax
            else:
                sharex = self._column_refax[col] if share_x else None
                sharey = self._row_refax[row] if share_y else None

            ax = axes_class(fig, rect, sharex=sharex, sharey=sharey,
                            **axes_class_args)

            if share_all:
                if self._refax is None:
                    self._refax = ax
            else:
                if sharex is None:
                    self._column_refax[col] = ax
                if sharey is None:
                    self._row_refax[row] = ax

            self.axes_all.append(ax)
            self.axes_column[col].append(ax)
            self.axes_row[row].append(ax)

        # Last axes of the first column: the lower-left one.
        self.axes_llc = self.axes_column[0][-1]

        self._update_locators()

        if add_all:
            for ax in self.axes_all:
                fig.add_axes(ax)

        self.set_label_mode(label_mode)

    def _init_axes_pad(self, axes_pad):
        """Store *axes_pad* and create the fixed-size pad objects."""
        axes_pad = _extend_axes_pad(axes_pad)
        self._axes_pad = axes_pad

        self._horiz_pad_size = Size.Fixed(axes_pad[0])
        self._vert_pad_size = Size.Fixed(axes_pad[1])

    def _update_locators(self):
        """
        Lay out equally scaled cells separated by the pad sizes.  Give each
        axes a locator pointing at its cell.
        """
        h = []
        h_ax_pos = []
        for _ in self._column_refax:
            if h:
                h.append(self._horiz_pad_size)
            h_ax_pos.append(len(h))
            h.append(Size.Scaled(1))

        v = []
        v_ax_pos = []
        # Rows are laid out bottom-up, hence the reversed order.
        for _ in self._row_refax[::-1]:
            if v:
                v.append(self._vert_pad_size)
            v_ax_pos.append(len(v))
            v.append(Size.Scaled(1))

        for i in range(self.ngrids):
            col, row = self._get_col_row(i)
            locator = self._divider.new_locator(
                nx=h_ax_pos[col], ny=v_ax_pos[self._nrows - 1 - row])
            self.axes_all[i].set_axes_locator(locator)

        self._divider.set_horizontal(h)
        self._divider.set_vertical(v)

    def _get_col_row(self, n):
        """Return the (col, row) of the *n*-th axes for the fill direction."""
        if self._direction == "column":
            col, row = divmod(n, self._nrows)
        else:
            row, col = divmod(n, self._ncols)

        return col, row

    # Good to propagate __len__ if we have __getitem__
    def __len__(self):
        return len(self.axes_all)

    def __getitem__(self, i):
        return self.axes_all[i]

    def get_geometry(self):
        """
        get geometry of the grid. Returns a tuple of two integer,
        representing number of rows and number of columns.
        """
        return self._nrows, self._ncols

    def set_axes_pad(self, axes_pad):
        """
        Set the padding between axes, in inches.

        *axes_pad* may be a single number, used for both directions, or a
        (horizontal pad, vertical pad) pair.  This matches the constructor's
        *axes_pad* argument.
        """
        axes_pad = _extend_axes_pad(axes_pad)
        self._axes_pad = axes_pad

        # Update the existing Size.Fixed objects in place rather than
        # creating new ones (as _init_axes_pad does).  The divider's size
        # lists built in _update_locators already reference them.
        self._horiz_pad_size.fixed_size = axes_pad[0]
        self._vert_pad_size.fixed_size = axes_pad[1]

    def get_axes_pad(self):
        """
        get axes_pad

        Returns
        -------
        tuple
            Padding in inches, (horizontal pad, vertical pad)
        """
        return self._axes_pad

    def set_aspect(self, aspect):
        "set aspect"
        self._divider.set_aspect(aspect)

    def get_aspect(self):
        "get aspect"
        return self._divider.get_aspect()

    def set_label_mode(self, mode):
        """
        Set which axes show tick labels and axis labels.

        - "all": every axes shows bottom and left labels.
        - "L": left labels only on the first column.  Bottom labels only on
          the last axes of each column.
        - "1": only the lower-left axes shows labels.

        Any other value leaves the labels unchanged.
        """
        if mode == "all":
            for ax in self.axes_all:
                _tick_only(ax, False, False)
        elif mode == "L":
            # left-most axes
            for ax in self.axes_column[0][:-1]:
                _tick_only(ax, bottom_on=True, left_on=False)
            # lower-left axes
            ax = self.axes_column[0][-1]
            _tick_only(ax, bottom_on=False, left_on=False)

            for col in self.axes_column[1:]:
                # axes with no labels
                for ax in col[:-1]:
                    _tick_only(ax, bottom_on=True, left_on=True)

                # bottom
                ax = col[-1]
                _tick_only(ax, bottom_on=False, left_on=True)

        elif mode == "1":
            for ax in self.axes_all:
                _tick_only(ax, bottom_on=True, left_on=True)

            ax = self.axes_llc
            _tick_only(ax, bottom_on=False, left_on=False)

    def get_divider(self):
        """Return the divider that lays out the grid."""
        return self._divider

    def set_axes_locator(self, locator):
        """Set the locator of the divider (i.e. of the whole grid)."""
        self._divider.set_locator(locator)

    def get_axes_locator(self):
        """Return the locator of the divider (i.e. of the whole grid)."""
        return self._divider.get_locator()

    def get_vsize_hsize(self):
        """Return the divider's (vertical, horizontal) sizes."""
        return self._divider.get_vsize_hsize()
417
418
class ImageGrid(Grid):
    """
    A class that creates a grid of Axes. In matplotlib, the axes
    location (and size) is specified in normalized figure
    coordinates. This may not be ideal for images that need to be
    displayed with a given aspect ratio.  For example, displaying
    images of the same size with some fixed padding between them cannot
    be easily done in matplotlib. ImageGrid is used in such cases.
    It can also reserve room for colorbars next to the images.
    """

    _defaultCbarAxesClass = CbarAxes

    def __init__(self, fig,
                 rect,
                 nrows_ncols,
                 ngrids=None,
                 direction="row",
                 axes_pad=0.02,
                 add_all=True,
                 share_all=False,
                 aspect=True,
                 label_mode="L",
                 cbar_mode=None,
                 cbar_location="right",
                 cbar_pad=None,
                 cbar_size="5%",
                 cbar_set_cax=True,
                 axes_class=None,
                 ):
        """
        Build an :class:`ImageGrid` instance with a grid nrows*ncols
        :class:`~matplotlib.axes.Axes` in
        :class:`~matplotlib.figure.Figure` *fig* with
        *rect=[left, bottom, width, height]* (in
        :class:`~matplotlib.figure.Figure` coordinates) or
        the subplot position code (e.g., "121").

        Optional keyword arguments:

          ================  ========  =========================================
          Keyword           Default   Description
          ================  ========  =========================================
          direction         "row"     [ "row" | "column" ]
          axes_pad          0.02      float| pad between axes given in inches
                                      or tuple-like of floats,
                                      (horizontal padding, vertical padding)
          add_all           True      bool
          share_all         False     bool
          aspect            True      bool
          label_mode        "L"       [ "L" | "1" | "all" ]
          cbar_mode         None      [ "each" | "single" | "edge" ]
          cbar_location     "right"   [ "left" | "right" | "bottom" | "top" ]
          cbar_pad          None      defaults to the horizontal (left/right)
                                      or vertical (bottom/top) axes_pad
          cbar_size         "5%"
          cbar_set_cax      True      bool
          axes_class        None      a subclass of
                                      :class:`~matplotlib.axes.Axes`, or a
                                      (class, kwargs) pair
          ================  ========  =========================================

        *cbar_set_cax* : if True, each axes in the grid gets a ``cax``
          attribute bound to its associated colorbar axes.

        Raises
        ------
        ValueError
            If *ngrids* is not in ``1..nrows*ncols``, if *direction* is
            neither "row" nor "column", or if *rect* has an unsupported
            form.
        """
        self._nrows, self._ncols = nrows_ncols

        if ngrids is None:
            ngrids = self._nrows * self._ncols
        elif not 0 < ngrids <= self._nrows * self._ncols:
            # Same valid range as Grid: at least one axes, at most one per
            # grid cell (a full grid is allowed).
            raise ValueError(
                "ngrids must satisfy 0 < ngrids <= nrows * ncols "
                "(got {!r} for a {}x{} grid)".format(
                    ngrids, self._nrows, self._ncols))

        self.ngrids = ngrids

        axes_pad = _extend_axes_pad(axes_pad)
        self._axes_pad = axes_pad

        self._colorbar_mode = cbar_mode
        self._colorbar_location = cbar_location
        if cbar_pad is None:
            # Default to the axes padding along the direction in which the
            # colorbar is attached.
            if cbar_location in ("left", "right"):
                self._colorbar_pad = axes_pad[0]
            else:
                self._colorbar_pad = axes_pad[1]
        else:
            self._colorbar_pad = cbar_pad

        self._colorbar_size = cbar_size

        self._init_axes_pad(axes_pad)

        if direction not in ["column", "row"]:
            raise ValueError('direction must be "row" or "column", '
                             'got {!r}'.format(direction))

        self._direction = direction

        if axes_class is None:
            axes_class = self._defaultLocatableAxesClass
            axes_class_args = {}
        elif isinstance(axes_class, type) and issubclass(axes_class,
                                                         maxes.Axes):
            # A bare Axes subclass.  It is a class, so an isinstance() check
            # against maxes.Axes would always be False here.
            axes_class_args = {}
        else:
            # Otherwise expect a (class, constructor-kwargs) pair.
            axes_class, axes_class_args = axes_class

        self.axes_all = []
        self.axes_column = [[] for _ in range(self._ncols)]
        self.axes_row = [[] for _ in range(self._nrows)]

        self.cbar_axes = []

        # The size lists are filled in later by _update_locators.
        h = []
        v = []
        if (isinstance(rect, six.string_types + (SubplotSpec,)) or
                cbook.is_numlike(rect)):
            self._divider = SubplotDivider(fig, rect, horizontal=h, vertical=v,
                                           aspect=aspect)
        elif len(rect) == 3:
            kw = dict(horizontal=h, vertical=v, aspect=aspect)
            self._divider = SubplotDivider(fig, *rect, **kw)
        elif len(rect) == 4:
            self._divider = Divider(fig, rect, horizontal=h, vertical=v,
                                    aspect=aspect)
        else:
            raise ValueError(
                "rect must be a subplot position code, a SubplotSpec, "
                "3 subplot arguments or [left, bottom, width, height]; "
                "got {!r}".format(rect))

        rect = self._divider.get_position()

        # reference axes
        self._column_refax = [None for _ in range(self._ncols)]
        self._row_refax = [None for _ in range(self._nrows)]
        self._refax = None

        for i in range(self.ngrids):

            col, row = self._get_col_row(i)

            if share_all:
                if self.axes_all:
                    sharex = self.axes_all[0]
                    sharey = self.axes_all[0]
                else:
                    sharex = None
                    sharey = None
            else:
                sharex = self._column_refax[col]
                sharey = self._row_refax[row]

            ax = axes_class(fig, rect, sharex=sharex, sharey=sharey,
                            **axes_class_args)

            self.axes_all.append(ax)
            self.axes_column[col].append(ax)
            self.axes_row[row].append(ax)

            if share_all:
                if self._refax is None:
                    self._refax = ax
            if sharex is None:
                self._column_refax[col] = ax
            if sharey is None:
                self._row_refax[row] = ax

            # One colorbar axes per image axes.  Which ones are used and
            # shown depends on cbar_mode (see _update_locators).
            cax = self._defaultCbarAxesClass(
                fig, rect, orientation=self._colorbar_location)
            self.cbar_axes.append(cax)

        # Last axes of the first column: the lower-left one.
        self.axes_llc = self.axes_column[0][-1]

        self._update_locators()

        if add_all:
            for ax in self.axes_all + self.cbar_axes:
                fig.add_axes(ax)

        if cbar_set_cax:
            if self._colorbar_mode == "single":
                for ax in self.axes_all:
                    ax.cax = self.cbar_axes[0]
            elif self._colorbar_mode == "edge":
                # One colorbar per row (left/right) or per column
                # (bottom/top).
                for index, ax in enumerate(self.axes_all):
                    col, row = self._get_col_row(index)
                    if self._colorbar_location in ("left", "right"):
                        ax.cax = self.cbar_axes[row]
                    else:
                        ax.cax = self.cbar_axes[col]
            else:
                for ax, cax in zip(self.axes_all, self.cbar_axes):
                    ax.cax = cax

        self.set_label_mode(label_mode)

    def _update_locators(self):
        """
        Rebuild the divider's horizontal/vertical size lists.  Assign a
        locator to every image axes and to the colorbar axes in use.

        Colorbar and colorbar-pad cells are inserted next to the image cells
        according to ``cbar_mode`` and ``cbar_location``.  Colorbar axes not
        used by the current mode are hidden.
        """
        h = []
        v = []

        h_ax_pos = []
        h_cb_pos = []
        if (self._colorbar_mode == "single" and
                self._colorbar_location in ('left', 'bottom')):
            # A single left/bottom colorbar goes first in the size list.
            if self._colorbar_location == "left":
                sz = Size.Fraction(self._nrows, Size.AxesX(self.axes_llc))
                h.append(Size.from_any(self._colorbar_size, sz))
                h.append(Size.from_any(self._colorbar_pad, sz))
                locator = self._divider.new_locator(nx=0, ny=0, ny1=-1)
            elif self._colorbar_location == "bottom":
                sz = Size.Fraction(self._ncols, Size.AxesY(self.axes_llc))
                v.append(Size.from_any(self._colorbar_size, sz))
                v.append(Size.from_any(self._colorbar_pad, sz))
                locator = self._divider.new_locator(nx=0, nx1=-1, ny=0)
            for i in range(self.ngrids):
                self.cbar_axes[i].set_visible(False)
            self.cbar_axes[0].set_axes_locator(locator)
            self.cbar_axes[0].set_visible(True)

        for col, ax in enumerate(self.axes_row[0]):
            if h:
                h.append(self._horiz_pad_size)

            if ax:
                sz = Size.AxesX(ax, aspect="axes", ref_ax=self.axes_all[0])
            else:
                sz = Size.AxesX(self.axes_all[0],
                                aspect="axes", ref_ax=self.axes_all[0])

            if (self._colorbar_mode == "each" or
                    (self._colorbar_mode == 'edge' and
                        col == 0)) and self._colorbar_location == "left":
                h_cb_pos.append(len(h))
                h.append(Size.from_any(self._colorbar_size, sz))
                h.append(Size.from_any(self._colorbar_pad, sz))

            h_ax_pos.append(len(h))

            h.append(sz)

            if ((self._colorbar_mode == "each" or
                    (self._colorbar_mode == 'edge' and
                        col == self._ncols - 1)) and
                    self._colorbar_location == "right"):
                h.append(Size.from_any(self._colorbar_pad, sz))
                h_cb_pos.append(len(h))
                h.append(Size.from_any(self._colorbar_size, sz))

        v_ax_pos = []
        v_cb_pos = []
        # Rows are laid out bottom-up, hence the reversed order.
        for row, ax in enumerate(self.axes_column[0][::-1]):
            if v:
                v.append(self._vert_pad_size)

            if ax:
                sz = Size.AxesY(ax, aspect="axes", ref_ax=self.axes_all[0])
            else:
                sz = Size.AxesY(self.axes_all[0],
                                aspect="axes", ref_ax=self.axes_all[0])

            if (self._colorbar_mode == "each" or
                    (self._colorbar_mode == 'edge' and
                        row == 0)) and self._colorbar_location == "bottom":
                v_cb_pos.append(len(v))
                v.append(Size.from_any(self._colorbar_size, sz))
                v.append(Size.from_any(self._colorbar_pad, sz))

            v_ax_pos.append(len(v))
            v.append(sz)

            if ((self._colorbar_mode == "each" or
                    (self._colorbar_mode == 'edge' and
                        row == self._nrows - 1)) and
                    self._colorbar_location == "top"):
                v.append(Size.from_any(self._colorbar_pad, sz))
                v_cb_pos.append(len(v))
                v.append(Size.from_any(self._colorbar_size, sz))

        for i in range(self.ngrids):
            col, row = self._get_col_row(i)
            locator = self._divider.new_locator(nx=h_ax_pos[col],
                                                ny=v_ax_pos[self._nrows-1-row])
            self.axes_all[i].set_axes_locator(locator)

            if self._colorbar_mode == "each":
                if self._colorbar_location in ("right", "left"):
                    locator = self._divider.new_locator(
                        nx=h_cb_pos[col], ny=v_ax_pos[self._nrows - 1 - row])

                elif self._colorbar_location in ("top", "bottom"):
                    locator = self._divider.new_locator(
                        nx=h_ax_pos[col], ny=v_cb_pos[self._nrows - 1 - row])

                self.cbar_axes[i].set_axes_locator(locator)
            elif self._colorbar_mode == 'edge':
                if ((self._colorbar_location == 'left' and col == 0) or
                        (self._colorbar_location == 'right'
                         and col == self._ncols - 1)):
                    locator = self._divider.new_locator(
                        nx=h_cb_pos[0], ny=v_ax_pos[self._nrows - 1 - row])
                    self.cbar_axes[row].set_axes_locator(locator)
                elif ((self._colorbar_location == 'bottom' and
                       row == self._nrows - 1) or
                        (self._colorbar_location == 'top' and row == 0)):
                    locator = self._divider.new_locator(nx=h_ax_pos[col],
                                                        ny=v_cb_pos[0])
                    self.cbar_axes[col].set_axes_locator(locator)

        if self._colorbar_mode == "single":
            # A single right/top colorbar goes last in the size list.
            if self._colorbar_location == "right":
                sz = Size.Fraction(self._nrows, Size.AxesX(self.axes_llc))
                h.append(Size.from_any(self._colorbar_pad, sz))
                h.append(Size.from_any(self._colorbar_size, sz))
                locator = self._divider.new_locator(nx=-2, ny=0, ny1=-1)
            elif self._colorbar_location == "top":
                sz = Size.Fraction(self._ncols, Size.AxesY(self.axes_llc))
                v.append(Size.from_any(self._colorbar_pad, sz))
                v.append(Size.from_any(self._colorbar_size, sz))
                locator = self._divider.new_locator(nx=0, nx1=-1, ny=-2)
            if self._colorbar_location in ("right", "top"):
                for i in range(self.ngrids):
                    self.cbar_axes[i].set_visible(False)
                self.cbar_axes[0].set_axes_locator(locator)
                self.cbar_axes[0].set_visible(True)
        elif self._colorbar_mode == "each":
            for i in range(self.ngrids):
                self.cbar_axes[i].set_visible(True)
        elif self._colorbar_mode == "edge":
            if self._colorbar_location in ('right', 'left'):
                count = self._nrows
            else:
                count = self._ncols
            # There are only ngrids colorbar axes.  When the grid is
            # partially filled, there may be fewer than rows/columns.
            count = min(count, self.ngrids)
            for i, cax in enumerate(self.cbar_axes):
                cax.set_visible(i < count)
        else:
            # No colorbars: hide them and shrink them out of the way.
            for i in range(self.ngrids):
                self.cbar_axes[i].set_visible(False)
                self.cbar_axes[i].set_position([1., 1., 0.001, 0.001],
                                               which="active")

        self._divider.set_horizontal(h)
        self._divider.set_vertical(v)
768
769
# AxesGrid is an alias: it refers to the very same class object as ImageGrid.
AxesGrid = ImageGrid
771
772