from numbers import Number
import functools

import numpy as np

import matplotlib as mpl
from matplotlib import _api
from matplotlib.gridspec import SubplotSpec

from .axes_divider import Size, SubplotDivider, Divider
from .mpl_axes import Axes


def _tick_only(ax, bottom_on, left_on):
    """
    Show or hide the ticklabels and label of *ax*'s bottom and left axis.

    Despite the parameter names, a True value selects "ticks only" for that
    side: its ticklabels and axis label are hidden.  False shows them.
    """
    bottom_off = not bottom_on
    left_off = not left_on
    ax.axis["bottom"].toggle(ticklabels=bottom_off, label=bottom_off)
    ax.axis["left"].toggle(ticklabels=left_off, label=left_off)


class CbarAxesBase:
    """
    Mixin for the axes holding the colorbars of an `ImageGrid`.

    *orientation* is the side ("left", "right", "bottom" or "top") on which
    the colorbar sits; only the axis on that side is shown.
    """

    def __init__(self, *args, orientation, **kwargs):
        self.orientation = orientation
        self._default_label_on = True
        self._locator = None  # deprecated.
        super().__init__(*args, **kwargs)

    def colorbar(self, mappable, *, ticks=None, **kwargs):
        """
        Draw a colorbar for *mappable* in this axes and return it.

        The colorbar is horizontal for the "top" and "bottom" orientations
        and vertical otherwise.  *ticks* and any other *kwargs* are passed
        to `matplotlib.colorbar.Colorbar`.
        """
        if self.orientation in ["top", "bottom"]:
            orientation = "horizontal"
        else:
            orientation = "vertical"

        cb = mpl.colorbar.Colorbar(
            self, mappable, orientation=orientation, ticks=ticks, **kwargs)
        self._cbid = mappable.colorbar_cid  # deprecated in 3.3.
        self._locator = cb.locator  # deprecated in 3.3.

        self._config_axes()
        return cb

    cbid = _api.deprecate_privatize_attribute(
        "3.3", alternative="mappable.colorbar_cid")
    locator = _api.deprecate_privatize_attribute(
        "3.3", alternative=".colorbar().locator")

    def _config_axes(self):
        """
        Disable navigation and hide every axis except the one on the
        colorbar side, which is shown or hidden per ``_default_label_on``.
        """
        ax = self
        ax.set_navigate(False)
        ax.axis[:].toggle(all=False)
        b = self._default_label_on
        ax.axis[self.orientation].toggle(all=b)

    def toggle_label(self, b):
        """Show (True) or hide (False) the colorbar ticklabels and label."""
        # Remembered so that _config_axes (run again by colorbar() and
        # cla()) keeps honoring the choice.
        self._default_label_on = b
        axis = self.axis[self.orientation]
        axis.toggle(ticklabels=b, label=b)

    def cla(self):
        # Clearing resets the axis artists; re-apply the colorbar layout.
        super().cla()
        self._config_axes()


class CbarAxes(CbarAxesBase, Axes):
    """Colorbar axes based on the axes_grid1 `Axes` class."""


class Grid:
    """
    A grid of Axes.

    In Matplotlib, the axes location (and size) is specified in normalized
    figure coordinates. This may not be ideal for images that need to be
    displayed with a given aspect ratio; for example, it is difficult to
    display multiple images of the same size with some fixed padding between
    them. AxesGrid can be used in such case.
    """

    _defaultAxesClass = Axes

    @_api.delete_parameter("3.3", "add_all")
    def __init__(self, fig,
                 rect,
                 nrows_ncols,
                 ngrids=None,
                 direction="row",
                 axes_pad=0.02,
                 add_all=True,
                 share_all=False,
                 share_x=True,
                 share_y=True,
                 label_mode="L",
                 axes_class=None,
                 *,
                 aspect=False,
                 ):
        """
        Parameters
        ----------
        fig : `.Figure`
            The parent figure.
        rect : (float, float, float, float) or int
            The axes position, as a ``(left, bottom, width, height)`` tuple or
            as a three-digit subplot position code (e.g., "121").
        nrows_ncols : (int, int)
            Number of rows and columns in the grid.
        ngrids : int or None, default: None
            If not None, only the first *ngrids* axes in the grid (counted in
            *direction* order) are created.  Must satisfy
            ``0 < ngrids <= nrows * ncols``.
        direction : {"row", "column"}, default: "row"
            Whether axes are created in row-major ("row by row") or
            column-major order ("column by column").  This also affects the
            order in which axes are accessed using indexing (``grid[index]``).
        axes_pad : float or (float, float), default: 0.02
            Padding or (horizontal padding, vertical padding) between axes, in
            inches.
        add_all : bool, default: True
            Whether to add the axes to the figure using `.Figure.add_axes`.
            This parameter is deprecated.
        share_all : bool, default: False
            Whether all axes share their x- and y-axis.  Overrides *share_x*
            and *share_y*.
        share_x : bool, default: True
            Whether all axes of a column share their x-axis.
        share_y : bool, default: True
            Whether all axes of a row share their y-axis.
        label_mode : {"L", "1", "all"}, default: "L"
            Determines which axes will get tick labels:

            - "L": All axes on the left column get vertical tick labels;
              all axes on the bottom row get horizontal tick labels.
            - "1": Only the bottom left axes is labelled.
            - "all": all axes are labelled.

        axes_class : subclass of `matplotlib.axes.Axes`, default: None
        aspect : bool, default: False
            Whether the axes aspect ratio follows the aspect ratio of the data
            limits.

        Raises
        ------
        ValueError
            If *ngrids* is out of range.
        TypeError
            If *rect* is neither a subplot specification nor a sequence of
            3 or 4 elements.
        """
        self._nrows, self._ncols = nrows_ncols

        if ngrids is None:
            ngrids = self._nrows * self._ncols
        elif not 0 < ngrids <= self._nrows * self._ncols:
            raise ValueError(
                "ngrids must be positive and not larger than nrows*ncols")

        self.ngrids = ngrids

        self._horiz_pad_size, self._vert_pad_size = map(
            Size.Fixed, np.broadcast_to(axes_pad, 2))

        _api.check_in_list(["column", "row"], direction=direction)
        self._direction = direction

        if axes_class is None:
            axes_class = self._defaultAxesClass
        elif isinstance(axes_class, (list, tuple)):
            cls, kwargs = axes_class
            axes_class = functools.partial(cls, **kwargs)

        kw = dict(horizontal=[], vertical=[], aspect=aspect)
        if isinstance(rect, (str, Number, SubplotSpec)):
            self._divider = SubplotDivider(fig, rect, **kw)
        elif len(rect) == 3:
            self._divider = SubplotDivider(fig, *rect, **kw)
        elif len(rect) == 4:
            self._divider = Divider(fig, rect, **kw)
        else:
            raise TypeError("Incorrect rect format")

        rect = self._divider.get_position()

        # axes_array[row, col]; slots beyond *ngrids* stay None.
        axes_array = np.full((self._nrows, self._ncols), None, dtype=object)
        for i in range(self.ngrids):
            col, row = self._get_col_row(i)
            if share_all:
                sharex = sharey = axes_array[0, 0]
            else:
                sharex = axes_array[0, col] if share_x else None
                sharey = axes_array[row, 0] if share_y else None
            axes_array[row, col] = axes_class(
                fig, rect, sharex=sharex, sharey=sharey)
        # Flatten in creation order (row-major for direction="row",
        # column-major for direction="column") so that axes_all[i] is the
        # axes created at self._get_col_row(i), as _init_locators (and the
        # colorbar assignment in ImageGrid) assume.  In that order the empty
        # (None) slots come last, so keeping the first *ngrids* entries
        # keeps exactly the created axes.
        order = "C" if self._direction == "row" else "F"
        self.axes_all = axes_array.ravel(order=order).tolist()[:self.ngrids]
        # Column- and row-wise views; these keep None for empty slots.
        self.axes_column = axes_array.T.tolist()
        self.axes_row = axes_array.tolist()
        # Lowest existing axes of the first column; the bottom-left slot
        # itself may be empty when ngrids < nrows * ncols.  axes_array[0, 0]
        # always exists, so the filtered list is never empty.
        self.axes_llc = [
            ax for ax in self.axes_column[0] if ax is not None][-1]

        self._init_locators()

        if add_all:
            for ax in self.axes_all:
                fig.add_axes(ax)

        self.set_label_mode(label_mode)

    def _init_locators(self):
        # Build the divider columns/rows (equal scaled cells separated by the
        # fixed pads) and attach a cell locator to every created axes.

        h = []
        h_ax_pos = []
        for _ in range(self._ncols):
            if h:
                h.append(self._horiz_pad_size)
            h_ax_pos.append(len(h))
            sz = Size.Scaled(1)
            h.append(sz)

        v = []
        v_ax_pos = []
        for _ in range(self._nrows):
            if v:
                v.append(self._vert_pad_size)
            v_ax_pos.append(len(v))
            sz = Size.Scaled(1)
            v.append(sz)

        for i in range(self.ngrids):
            col, row = self._get_col_row(i)
            # The divider counts rows from the bottom; grid rows from the top.
            locator = self._divider.new_locator(
                nx=h_ax_pos[col], ny=v_ax_pos[self._nrows - 1 - row])
            self.axes_all[i].set_axes_locator(locator)

        self._divider.set_horizontal(h)
        self._divider.set_vertical(v)

    def _get_col_row(self, n):
        """Return the (col, row) of the *n*-th axes in *direction* order."""
        if self._direction == "column":
            col, row = divmod(n, self._nrows)
        else:
            row, col = divmod(n, self._ncols)

        return col, row

    # Good to propagate __len__ if we have __getitem__
    def __len__(self):
        """Return the number of axes in the grid (*ngrids*)."""
        return len(self.axes_all)

    def __getitem__(self, i):
        """Return the *i*-th axes in *direction* order."""
        return self.axes_all[i]

    def get_geometry(self):
        """
        Return the number of rows and columns of the grid as (nrows, ncols).
        """
        return self._nrows, self._ncols

    def set_axes_pad(self, axes_pad):
        """
        Set the padding between the axes.

        Parameters
        ----------
        axes_pad : (float, float)
            The padding (horizontal pad, vertical pad) in inches.
        """
        self._horiz_pad_size.fixed_size = axes_pad[0]
        self._vert_pad_size.fixed_size = axes_pad[1]

    def get_axes_pad(self):
        """
        Return the axes padding.

        Returns
        -------
        hpad, vpad
            Padding (horizontal pad, vertical pad) in inches.
        """
        return (self._horiz_pad_size.fixed_size,
                self._vert_pad_size.fixed_size)

    def set_aspect(self, aspect):
        """Set the aspect of the SubplotDivider."""
        self._divider.set_aspect(aspect)

    def get_aspect(self):
        """Return the aspect of the SubplotDivider."""
        return self._divider.get_aspect()

    def set_label_mode(self, mode):
        """
        Define which axes have tick labels.

        Parameters
        ----------
        mode : {"L", "1", "all"}
            The label mode:

            - "L": All axes on the left column get vertical tick labels;
              all axes on the bottom row get horizontal tick labels.  In an
              incomplete grid, the lowest existing axes of each column gets
              the horizontal tick labels.
            - "1": Only the bottom left axes is labelled.
            - "all": all axes are labelled.

            Any other value leaves the labels unchanged.
        """
        if mode == "all":
            for ax in self.axes_all:
                _tick_only(ax, False, False)
        elif mode == "L":
            for col, axes in enumerate(self.axes_column):
                # Ignore empty slots (possible when ngrids < nrows * ncols).
                axes = [ax for ax in axes if ax is not None]
                if not axes:
                    continue
                # Only the first column keeps its y ticklabels.
                hide_left = col != 0
                # Only the lowest axes of each column keeps its x ticklabels.
                for ax in axes[:-1]:
                    _tick_only(ax, bottom_on=True, left_on=hide_left)
                _tick_only(axes[-1], bottom_on=False, left_on=hide_left)
        elif mode == "1":
            for ax in self.axes_all:
                _tick_only(ax, bottom_on=True, left_on=True)

            ax = self.axes_llc
            _tick_only(ax, bottom_on=False, left_on=False)

    def get_divider(self):
        """Return the divider laying out the grid."""
        return self._divider

    def set_axes_locator(self, locator):
        """Set the locator of the grid's divider."""
        self._divider.set_locator(locator)

    def get_axes_locator(self):
        """Return the locator of the grid's divider."""
        return self._divider.get_locator()

    def get_vsize_hsize(self):
        """Return the divider's vertical and horizontal sizes."""
        return self._divider.get_vsize_hsize()


class ImageGrid(Grid):
    # docstring inherited

    _defaultCbarAxesClass = CbarAxes

    @_api.delete_parameter("3.3", "add_all")
    def __init__(self, fig,
                 rect,
                 nrows_ncols,
                 ngrids=None,
                 direction="row",
                 axes_pad=0.02,
                 add_all=True,
                 share_all=False,
                 aspect=True,
                 label_mode="L",
                 cbar_mode=None,
                 cbar_location="right",
                 cbar_pad=None,
                 cbar_size="5%",
                 cbar_set_cax=True,
                 axes_class=None,
                 ):
        """
        Parameters
        ----------
        fig : `.Figure`
            The parent figure.
        rect : (float, float, float, float) or int
            The axes position, as a ``(left, bottom, width, height)`` tuple or
            as a three-digit subplot position code (e.g., "121").
        nrows_ncols : (int, int)
            Number of rows and columns in the grid.
        ngrids : int or None, default: None
            If not None, only the first *ngrids* axes in the grid (counted in
            *direction* order) are created.
        direction : {"row", "column"}, default: "row"
            Whether axes are created in row-major ("row by row") or
            column-major order ("column by column").  This also affects the
            order in which axes are accessed using indexing (``grid[index]``).
        axes_pad : float or (float, float), default: 0.02
            Padding or (horizontal padding, vertical padding) between axes, in
            inches.
        add_all : bool, default: True
            Whether to add the axes to the figure using `.Figure.add_axes`.
            This parameter is deprecated.
        share_all : bool, default: False
            Whether all axes share their x- and y-axis.
        aspect : bool, default: True
            Whether the axes aspect ratio follows the aspect ratio of the data
            limits.
        label_mode : {"L", "1", "all"}, default: "L"
            Determines which axes will get tick labels:

            - "L": All axes on the left column get vertical tick labels;
              all axes on the bottom row get horizontal tick labels.
            - "1": Only the bottom left axes is labelled.
            - "all": all axes are labelled.

        cbar_mode : {"each", "single", "edge", None}, default: None
            Whether to create a colorbar for "each" axes, a "single" colorbar
            for the entire grid, colorbars only for axes on the "edge"
            determined by *cbar_location*, or no colorbars.  The colorbars are
            stored in the :attr:`cbar_axes` attribute.
        cbar_location : {"left", "right", "bottom", "top"}, default: "right"
        cbar_pad : float, default: None
            Padding between the image axes and the colorbar axes.  If None,
            the axes padding along the colorbar direction is used (the
            horizontal pad for "left"/"right", the vertical pad otherwise).
        cbar_size : size specification (see `.Size.from_any`), default: "5%"
            Colorbar size.
        cbar_set_cax : bool, default: True
            If True, each axes in the grid has a *cax* attribute that is bound
            to associated *cbar_axes*.
        axes_class : subclass of `matplotlib.axes.Axes`, default: None
        """
        self._colorbar_mode = cbar_mode
        self._colorbar_location = cbar_location
        self._colorbar_pad = cbar_pad
        self._colorbar_size = cbar_size
        # The colorbar axes are created in _init_locators().

        if add_all:
            super().__init__(
                fig, rect, nrows_ncols, ngrids,
                direction=direction, axes_pad=axes_pad,
                share_all=share_all, share_x=True, share_y=True, aspect=aspect,
                label_mode=label_mode, axes_class=axes_class)
        else:  # Only show deprecation in that case.
            super().__init__(
                fig, rect, nrows_ncols, ngrids,
                direction=direction, axes_pad=axes_pad, add_all=add_all,
                share_all=share_all, share_x=True, share_y=True, aspect=aspect,
                label_mode=label_mode, axes_class=axes_class)

        if add_all:
            for ax in self.cbar_axes:
                fig.add_axes(ax)

        if cbar_set_cax:
            if self._colorbar_mode == "single":
                for ax in self.axes_all:
                    ax.cax = self.cbar_axes[0]
            elif self._colorbar_mode == "edge":
                for index, ax in enumerate(self.axes_all):
                    col, row = self._get_col_row(index)
                    if self._colorbar_location in ("left", "right"):
                        ax.cax = self.cbar_axes[row]
                    else:
                        ax.cax = self.cbar_axes[col]
            else:
                for ax, cax in zip(self.axes_all, self.cbar_axes):
                    ax.cax = cax

    def _init_locators(self):
        # Slightly abusing this method to inject colorbar creation into init.

        if self._colorbar_pad is None:
            # horizontal or vertical arrangement?
            if self._colorbar_location in ("left", "right"):
                self._colorbar_pad = self._horiz_pad_size.fixed_size
            else:
                self._colorbar_pad = self._vert_pad_size.fixed_size
        self.cbar_axes = [
            self._defaultCbarAxesClass(
                self.axes_all[0].figure, self._divider.get_position(),
                orientation=self._colorbar_location)
            for _ in range(self.ngrids)]

        cb_mode = self._colorbar_mode
        cb_location = self._colorbar_location

        h = []
        v = []

        h_ax_pos = []
        h_cb_pos = []
        # A single colorbar on the left/bottom goes before all grid cells.
        if cb_mode == "single" and cb_location in ("left", "bottom"):
            if cb_location == "left":
                sz = self._nrows * Size.AxesX(self.axes_llc)
                h.append(Size.from_any(self._colorbar_size, sz))
                h.append(Size.from_any(self._colorbar_pad, sz))
                locator = self._divider.new_locator(nx=0, ny=0, ny1=-1)
            elif cb_location == "bottom":
                sz = self._ncols * Size.AxesY(self.axes_llc)
                v.append(Size.from_any(self._colorbar_size, sz))
                v.append(Size.from_any(self._colorbar_pad, sz))
                locator = self._divider.new_locator(nx=0, nx1=-1, ny=0)
            for i in range(self.ngrids):
                self.cbar_axes[i].set_visible(False)
            self.cbar_axes[0].set_axes_locator(locator)
            self.cbar_axes[0].set_visible(True)

        # Horizontal layout: one cell per column (sized after the top-row
        # axes, or the first axes for an empty slot), plus per-column
        # colorbar cells for the "each"/"edge" modes.
        for col, ax in enumerate(self.axes_row[0]):
            if h:
                h.append(self._horiz_pad_size)

            if ax is not None:
                sz = Size.AxesX(ax, aspect="axes", ref_ax=self.axes_all[0])
            else:
                sz = Size.AxesX(self.axes_all[0],
                                aspect="axes", ref_ax=self.axes_all[0])

            if (cb_location == "left"
                    and (cb_mode == "each"
                         or (cb_mode == "edge" and col == 0))):
                h_cb_pos.append(len(h))
                h.append(Size.from_any(self._colorbar_size, sz))
                h.append(Size.from_any(self._colorbar_pad, sz))

            h_ax_pos.append(len(h))
            h.append(sz)

            if (cb_location == "right"
                    and (cb_mode == "each"
                         or (cb_mode == "edge" and col == self._ncols - 1))):
                h.append(Size.from_any(self._colorbar_pad, sz))
                h_cb_pos.append(len(h))
                h.append(Size.from_any(self._colorbar_size, sz))

        # Vertical layout, built bottom-up (the divider counts from the
        # bottom), sized after the first-column axes.
        v_ax_pos = []
        v_cb_pos = []
        for row, ax in enumerate(self.axes_column[0][::-1]):
            if v:
                v.append(self._vert_pad_size)

            if ax is not None:
                sz = Size.AxesY(ax, aspect="axes", ref_ax=self.axes_all[0])
            else:
                sz = Size.AxesY(self.axes_all[0],
                                aspect="axes", ref_ax=self.axes_all[0])

            if (cb_location == "bottom"
                    and (cb_mode == "each"
                         or (cb_mode == "edge" and row == 0))):
                v_cb_pos.append(len(v))
                v.append(Size.from_any(self._colorbar_size, sz))
                v.append(Size.from_any(self._colorbar_pad, sz))

            v_ax_pos.append(len(v))
            v.append(sz)

            if (cb_location == "top"
                    and (cb_mode == "each"
                         or (cb_mode == "edge" and row == self._nrows - 1))):
                v.append(Size.from_any(self._colorbar_pad, sz))
                v_cb_pos.append(len(v))
                v.append(Size.from_any(self._colorbar_size, sz))

        # Indices of the colorbar axes that received a locator ("edge" mode).
        # Each index is a row or column of an existing axes, hence < ngrids.
        edge_cbar_indices = set()
        for i in range(self.ngrids):
            col, row = self._get_col_row(i)
            locator = self._divider.new_locator(nx=h_ax_pos[col],
                                                ny=v_ax_pos[self._nrows-1-row])
            self.axes_all[i].set_axes_locator(locator)

            if cb_mode == "each":
                if cb_location in ("right", "left"):
                    locator = self._divider.new_locator(
                        nx=h_cb_pos[col], ny=v_ax_pos[self._nrows - 1 - row])

                elif cb_location in ("top", "bottom"):
                    locator = self._divider.new_locator(
                        nx=h_ax_pos[col], ny=v_cb_pos[self._nrows - 1 - row])

                self.cbar_axes[i].set_axes_locator(locator)
            elif cb_mode == "edge":
                if (cb_location == "left" and col == 0
                        or cb_location == "right" and col == self._ncols - 1):
                    locator = self._divider.new_locator(
                        nx=h_cb_pos[0], ny=v_ax_pos[self._nrows - 1 - row])
                    self.cbar_axes[row].set_axes_locator(locator)
                    edge_cbar_indices.add(row)
                elif (cb_location == "bottom" and row == self._nrows - 1
                      or cb_location == "top" and row == 0):
                    locator = self._divider.new_locator(nx=h_ax_pos[col],
                                                        ny=v_cb_pos[0])
                    self.cbar_axes[col].set_axes_locator(locator)
                    edge_cbar_indices.add(col)

        if cb_mode == "single":
            # A single colorbar on the right/top goes after all grid cells.
            if cb_location == "right":
                sz = self._nrows * Size.AxesX(self.axes_llc)
                h.append(Size.from_any(self._colorbar_pad, sz))
                h.append(Size.from_any(self._colorbar_size, sz))
                locator = self._divider.new_locator(nx=-2, ny=0, ny1=-1)
            elif cb_location == "top":
                sz = self._ncols * Size.AxesY(self.axes_llc)
                v.append(Size.from_any(self._colorbar_pad, sz))
                v.append(Size.from_any(self._colorbar_size, sz))
                locator = self._divider.new_locator(nx=0, nx1=-1, ny=-2)
            if cb_location in ("right", "top"):
                for i in range(self.ngrids):
                    self.cbar_axes[i].set_visible(False)
                self.cbar_axes[0].set_axes_locator(locator)
                self.cbar_axes[0].set_visible(True)
        elif cb_mode == "each":
            for i in range(self.ngrids):
                self.cbar_axes[i].set_visible(True)
        elif cb_mode == "edge":
            # Show exactly the edge colorbars that were positioned above; for
            # a complete grid these are the first nrows (left/right) or
            # ncols (top/bottom) colorbar axes.
            for i, cax in enumerate(self.cbar_axes):
                cax.set_visible(i in edge_cbar_indices)
        else:
            for i in range(self.ngrids):
                self.cbar_axes[i].set_visible(False)
                self.cbar_axes[i].set_position([1., 1., 0.001, 0.001],
                                               which="active")

        self._divider.set_horizontal(h)
        self._divider.set_vertical(v)


AxesGrid = ImageGrid