1""" 2The Slicer classes. 3 4The main purpose of these classes is to have auto adjust of axes size to 5the data with different layout of cuts. 6""" 7 8import collections.abc 9import numbers 10from distutils.version import LooseVersion 11 12import matplotlib 13import matplotlib.pyplot as plt 14import numpy as np 15import warnings 16from matplotlib import cm as mpl_cm 17from matplotlib import (colors, 18 lines, 19 transforms, 20 ) 21from matplotlib.colorbar import ColorbarBase 22from matplotlib.font_manager import FontProperties 23from matplotlib.patches import FancyArrow 24from mpl_toolkits.axes_grid1.anchored_artists import AnchoredSizeBar 25from scipy import sparse, stats 26 27from . import cm, glass_brain 28from .edge_detect import _edge_map 29from .find_cuts import find_xyz_cut_coords, find_cut_slices 30from .. import _utils 31from ..image import new_img_like 32from ..image.resampling import (get_bounds, reorder_img, coord_transform, 33 get_mask_bounds) 34from nilearn.image import get_data 35 36 37############################################################################### 38# class BaseAxes 39############################################################################### 40 41class BaseAxes(object): 42 """ An MPL axis-like object that displays a 2D view of 3D volumes 43 """ 44 45 def __init__(self, ax, direction, coord): 46 """ An MPL axis-like object that displays a cut of 3D volumes 47 48 Parameters 49 ---------- 50 ax : A MPL axes instance 51 The axes in which the plots will be drawn. 52 53 direction : {'x', 'y', 'z'} 54 The directions of the view. 55 56 coord : float 57 The coordinate along the direction of the cut. 58 59 """ 60 self.ax = ax 61 self.direction = direction 62 self.coord = coord 63 self._object_bounds = list() 64 self.shape = None 65 66 def transform_to_2d(self, data, affine): 67 raise NotImplementedError("'transform_to_2d' needs to be implemented " 68 "in derived classes'") 69 70 def add_object_bounds(self, bounds): 71 """Ensures that axes get rescaled when adding object bounds 72 73 """ 74 old_object_bounds = self.get_object_bounds() 75 self._object_bounds.append(bounds) 76 new_object_bounds = self.get_object_bounds() 77 78 if new_object_bounds != old_object_bounds: 79 self.ax.axis(self.get_object_bounds()) 80 81 def draw_2d(self, data_2d, data_bounds, bounding_box, 82 type='imshow', **kwargs): 83 # kwargs messaging 84 kwargs['origin'] = 'upper' 85 86 if self.direction == 'y': 87 (xmin, xmax), (_, _), (zmin, zmax) = data_bounds 88 (xmin_, xmax_), (_, _), (zmin_, zmax_) = bounding_box 89 elif self.direction in 'xlr': 90 (_, _), (xmin, xmax), (zmin, zmax) = data_bounds 91 (_, _), (xmin_, xmax_), (zmin_, zmax_) = bounding_box 92 elif self.direction == 'z': 93 (xmin, xmax), (zmin, zmax), (_, _) = data_bounds 94 (xmin_, xmax_), (zmin_, zmax_), (_, _) = bounding_box 95 else: 96 raise ValueError('Invalid value for direction %s' % 97 self.direction) 98 ax = self.ax 99 # Here we need to do a copy to avoid having the image changing as 100 # we change the data 101 im = getattr(ax, type)(data_2d.copy(), 102 extent=(xmin, xmax, zmin, zmax), 103 **kwargs) 104 105 self.add_object_bounds((xmin_, xmax_, zmin_, zmax_)) 106 self.shape = data_2d.T.shape 107 108 # The bounds of the object do not take into account a possible 109 # inversion of the axis. As such, we check that the axis is properly 110 # inverted when direction is left 111 if self.direction == 'l' and not (ax.get_xlim()[0] > ax.get_xlim()[1]): 112 ax.invert_xaxis() 113 114 return im 115 116 def get_object_bounds(self): 117 """ Return the bounds of the objects on this axes. 118 """ 119 if len(self._object_bounds) == 0: 120 # Nothing plotted yet 121 return -.01, .01, -.01, .01 122 xmins, xmaxs, ymins, ymaxs = np.array(self._object_bounds).T 123 xmax = max(xmaxs.max(), xmins.max()) 124 xmin = min(xmins.min(), xmaxs.min()) 125 ymax = max(ymaxs.max(), ymins.max()) 126 ymin = min(ymins.min(), ymaxs.min()) 127 128 return xmin, xmax, ymin, ymax 129 130 def draw_left_right(self, size, bg_color, **kwargs): 131 if self.direction in 'xlr': 132 return 133 ax = self.ax 134 ax.text(.1, .95, 'L', 135 transform=ax.transAxes, 136 horizontalalignment='left', 137 verticalalignment='top', 138 size=size, 139 bbox=dict(boxstyle="square,pad=0", 140 ec=bg_color, fc=bg_color, alpha=1), 141 **kwargs) 142 143 ax.text(.9, .95, 'R', 144 transform=ax.transAxes, 145 horizontalalignment='right', 146 verticalalignment='top', 147 size=size, 148 bbox=dict(boxstyle="square,pad=0", ec=bg_color, fc=bg_color), 149 **kwargs) 150 151 def draw_scale_bar(self, bg_color, size=5.0, units='cm', 152 fontproperties=None, frameon=False, loc=4, pad=.1, 153 borderpad=.5, sep=5, size_vertical=0, label_top=False, 154 color='black', fontsize=None, **kwargs): 155 """ Adds a scale bar annotation to the display 156 157 Parameters 158 ---------- 159 bgcolor : matplotlib color: str or (r, g, b) value 160 The background color of the scale bar annotation. 161 162 size : float, optional 163 Horizontal length of the scale bar, given in `units`. 164 Default=5.0. 165 166 units : str, optional 167 Physical units of the scale bar (`'cm'` or `'mm'`). 168 Default='cm'. 169 170 fontproperties : ``matplotlib.font_manager.FontProperties`` or dict, optional 171 Font properties for the label text. 172 173 frameon : Boolean, optional 174 Whether the scale bar is plotted with a border. Default=False. 175 176 loc : int, optional 177 Location of this scale bar. Valid location codes are documented 178 `here <https://matplotlib.org/mpl_toolkits/axes_grid/\ 179 api/anchored_artists_api.html#mpl_toolkits.axes_grid1.\ 180 anchored_artists.AnchoredSizeBar>`__. 181 Default=4. 182 183 pad : int of float, optional 184 Padding around the label and scale bar, in fraction of the font 185 size. Default=0.1. 186 187 borderpad : int or float, optional 188 Border padding, in fraction of the font size. Default=0.5. 189 190 sep : int or float, optional 191 Separation between the label and the scale bar, in points. 192 Default=5. 193 194 size_vertical : int or float, optional 195 Vertical length of the size bar, given in `units`. Default=0. 196 197 label_top : bool, optional 198 If True, the label will be over the scale bar. Default=False. 199 200 color : str, optional 201 Color for the scale bar and label. Default='black'. 202 203 fontsize : int, optional 204 Label font size (overwrites the size passed in through the 205 ``fontproperties`` argument). 206 207 **kwargs : 208 Keyworded arguments to pass to 209 ``matplotlib.offsetbox.AnchoredOffsetbox``. 210 211 """ 212 axis = self.ax 213 fontproperties = fontproperties or FontProperties() 214 if fontsize: 215 fontproperties.set_size(fontsize) 216 width_mm = size 217 if units == 'cm': 218 width_mm *= 10 219 220 anchor_size_bar = AnchoredSizeBar( 221 axis.transData, 222 width_mm, 223 '%g%s' % (size, units), 224 fontproperties=fontproperties, 225 frameon=frameon, 226 loc=loc, 227 pad=pad, 228 borderpad=borderpad, 229 sep=sep, 230 size_vertical=size_vertical, 231 label_top=label_top, 232 color=color, 233 **kwargs) 234 235 if frameon: 236 anchor_size_bar.patch.set_facecolor(bg_color) 237 anchor_size_bar.patch.set_edgecolor('none') 238 axis.add_artist(anchor_size_bar) 239 240 def draw_position(self, size, bg_color, **kwargs): 241 raise NotImplementedError("'draw_position' should be implemented " 242 "in derived classes") 243 244 245############################################################################### 246# class CutAxes 247############################################################################### 248 249class CutAxes(BaseAxes): 250 """ An MPL axis-like object that displays a cut of 3D volumes 251 """ 252 def transform_to_2d(self, data, affine): 253 """ Cut the 3D volume into a 2D slice 254 255 Parameters 256 ---------- 257 data : 3D ndarray 258 The 3D volume to cut. 259 260 affine : 4x4 ndarray 261 The affine of the volume. 262 263 """ 264 coords = [0, 0, 0] 265 coords['xyz'.index(self.direction)] = self.coord 266 x_map, y_map, z_map = [int(np.round(c)) for c in 267 coord_transform(coords[0], 268 coords[1], 269 coords[2], 270 np.linalg.inv(affine))] 271 if self.direction == 'y': 272 cut = np.rot90(data[:, y_map, :]) 273 elif self.direction == 'x': 274 cut = np.rot90(data[x_map, :, :]) 275 elif self.direction == 'z': 276 cut = np.rot90(data[:, :, z_map]) 277 else: 278 raise ValueError('Invalid value for direction %s' % 279 self.direction) 280 return cut 281 282 def draw_position(self, size, bg_color, decimals=False, **kwargs): 283 if decimals: 284 text = '%s=%.{}f'.format(decimals) 285 coord = float(self.coord) 286 else: 287 text = '%s=%i' 288 coord = self.coord 289 ax = self.ax 290 ax.text(0, 0, text % (self.direction, coord), 291 transform=ax.transAxes, 292 horizontalalignment='left', 293 verticalalignment='bottom', 294 size=size, 295 bbox=dict(boxstyle="square,pad=0", 296 ec=bg_color, fc=bg_color, alpha=1), 297 **kwargs) 298 299 300def _get_index_from_direction(direction): 301 """Returns numerical index from direction 302 """ 303 directions = ['x', 'y', 'z'] 304 try: 305 # l and r are subcases of x 306 if direction in 'lr': 307 index = 0 308 else: 309 index = directions.index(direction) 310 except ValueError: 311 message = ( 312 '{0} is not a valid direction. ' 313 "Allowed values are 'l', 'r', 'x', 'y' and 'z'").format(direction) 314 raise ValueError(message) 315 return index 316 317 318def _coords_3d_to_2d(coords_3d, direction, return_direction=False): 319 """Project 3d coordinates into 2d ones given the direction of a cut 320 """ 321 index = _get_index_from_direction(direction) 322 dimensions = [0, 1, 2] 323 dimensions.pop(index) 324 325 if return_direction: 326 return coords_3d[:, dimensions], coords_3d[:, index] 327 328 return coords_3d[:, dimensions] 329 330 331############################################################################### 332# class GlassBrainAxes 333############################################################################### 334 335class GlassBrainAxes(BaseAxes): 336 """An MPL axis-like object that displays a 2D projection of 3D 337 volumes with a schematic view of the brain. 338 339 """ 340 def __init__(self, ax, direction, coord, plot_abs=True, **kwargs): 341 super(GlassBrainAxes, self).__init__(ax, direction, coord) 342 self._plot_abs = plot_abs 343 if ax is not None: 344 object_bounds = glass_brain.plot_brain_schematics(ax, 345 direction, 346 **kwargs) 347 self.add_object_bounds(object_bounds) 348 349 def transform_to_2d(self, data, affine): 350 """ Returns the maximum of the absolute value of the 3D volume 351 along an axis. 352 353 Parameters 354 ---------- 355 data : 3D ndarray 356 The 3D volume. 357 358 affine : 4x4 ndarray 359 The affine of the volume. 360 361 """ 362 if self.direction in 'xlr': 363 max_axis = 0 364 else: 365 max_axis = '.yz'.index(self.direction) 366 367 # set unselected brain hemisphere activations to 0 368 369 if self.direction == 'l': 370 x_center, _, _, _ = np.dot(np.linalg.inv(affine), 371 np.array([0, 0, 0, 1])) 372 data_selection = data[:int(x_center), :, :] 373 elif self.direction == 'r': 374 x_center, _, _, _ = np.dot(np.linalg.inv(affine), 375 np.array([0, 0, 0, 1])) 376 data_selection = data[int(x_center):, :, :] 377 else: 378 data_selection = data 379 380 # We need to make sure data_selection is not empty in the x axis 381 # This should be the case since we expect images in MNI space 382 if data_selection.shape[0] == 0: 383 data_selection = data 384 385 if not self._plot_abs: 386 # get the shape of the array we are projecting to 387 new_shape = list(data.shape) 388 del new_shape[max_axis] 389 390 # generate a 3D indexing array that points to max abs value in the 391 # current projection 392 a1, a2 = np.indices(new_shape) 393 inds = [a1, a2] 394 inds.insert(max_axis, np.abs(data_selection).argmax(axis=max_axis)) 395 396 # take the values where the absolute value of the projection 397 # is the highest 398 maximum_intensity_data = data_selection[tuple(inds)] 399 else: 400 maximum_intensity_data = np.abs(data_selection).max(axis=max_axis) 401 402 # This work around can be removed bumping matplotlib > 2.1.0. See #1815 403 # in nilearn for the invention of this work around 404 405 if self.direction == 'l' and data_selection.min() is np.ma.masked and \ 406 not (self.ax.get_xlim()[0] > self.ax.get_xlim()[1]): 407 self.ax.invert_xaxis() 408 409 return np.rot90(maximum_intensity_data) 410 411 def draw_position(self, size, bg_color, **kwargs): 412 # It does not make sense to draw crosses for the position of 413 # the cuts since we are taking the max along one axis 414 pass 415 416 def _add_markers(self, marker_coords, marker_color, marker_size, **kwargs): 417 """Plot markers 418 419 In the case of 'l' and 'r' directions (for hemispheric projections), 420 markers in the coordinate x == 0 are included in both hemispheres. 421 422 """ 423 marker_coords_2d = _coords_3d_to_2d(marker_coords, self.direction) 424 xdata, ydata = marker_coords_2d.T 425 426 # Allow markers only in their respective hemisphere when appropriate 427 if self.direction in 'lr': 428 if not isinstance(marker_color, str) and \ 429 not isinstance(marker_color, np.ndarray): 430 marker_color = np.asarray(marker_color) 431 relevant_coords = [] 432 xcoords, ycoords, zcoords = marker_coords.T 433 for cidx, xc in enumerate(xcoords): 434 if self.direction == 'r' and xc >= 0: 435 relevant_coords.append(cidx) 436 elif self.direction == 'l' and xc <= 0: 437 relevant_coords.append(cidx) 438 xdata = xdata[relevant_coords] 439 ydata = ydata[relevant_coords] 440 # if marker_color is string for example 'red' or 'blue', then 441 # we pass marker_color as it is to matplotlib scatter without 442 # making any selection in 'l' or 'r' color. 443 # More likely that user wants to display all nodes to be in 444 # same color. 445 if not isinstance(marker_color, str) and \ 446 len(marker_color) != 1: 447 marker_color = marker_color[relevant_coords] 448 449 if not isinstance(marker_size, numbers.Number): 450 marker_size = np.asarray(marker_size)[relevant_coords] 451 452 defaults = {'marker': 'o', 453 'zorder': 1000} 454 for k, v in defaults.items(): 455 kwargs.setdefault(k, v) 456 457 self.ax.scatter(xdata, ydata, s=marker_size, 458 c=marker_color, **kwargs) 459 460 def _add_lines(self, line_coords, line_values, cmap, 461 vmin=None, vmax=None, directed=False, **kwargs): 462 """Plot lines 463 464 Parameters 465 ---------- 466 line_coords : list of numpy arrays of shape (2, 3) 467 3d coordinates of lines start points and end points. 468 469 line_values : array_like 470 Values of the lines. 471 472 cmap : colormap 473 Colormap used to map line_values to a color. 474 475 vmin, vmax : float, optional 476 If not None, either or both of these values will be used to 477 as the minimum and maximum values to color lines. If None are 478 supplied the maximum absolute value within the given threshold 479 will be used as minimum (multiplied by -1) and maximum 480 coloring levels. 481 482 directed : boolean, optional 483 Add arrows instead of lines if set to True. Use this when plotting 484 directed graphs for example. Default=False. 485 486 kwargs : dict 487 Additional arguments to pass to matplotlib Line2D. 488 489 """ 490 # colormap for colorbar 491 self.cmap = cmap 492 if vmin is None and vmax is None: 493 abs_line_values_max = np.abs(line_values).max() 494 vmin = -abs_line_values_max 495 vmax = abs_line_values_max 496 elif vmin is None: 497 if vmax > 0: 498 vmin = -vmax 499 else: 500 raise ValueError( 501 "If vmax is set to a non-positive number " 502 "then vmin needs to be specified" 503 ) 504 elif vmax is None: 505 if vmin < 0: 506 vmax = -vmin 507 else: 508 raise ValueError( 509 "If vmin is set to a non-negative number " 510 "then vmax needs to be specified" 511 ) 512 norm = colors.Normalize(vmin=vmin, 513 vmax=vmax) 514 # normalization useful for colorbar 515 self.norm = norm 516 abs_norm = colors.Normalize(vmin=0, 517 vmax=vmax) 518 value_to_color = plt.cm.ScalarMappable(norm=norm, cmap=cmap).to_rgba 519 520 # Allow lines only in their respective hemisphere when appropriate 521 if self.direction in 'lr': 522 relevant_lines = [] 523 for lidx, line in enumerate(line_coords): 524 if self.direction == 'r': 525 if line[0, 0] >= 0 and line[1, 0] >= 0: 526 relevant_lines.append(lidx) 527 elif self.direction == 'l': 528 if line[0, 0] < 0 and line[1, 0] < 0: 529 relevant_lines.append(lidx) 530 line_coords = np.array(line_coords)[relevant_lines] 531 line_values = line_values[relevant_lines] 532 533 for start_end_point_3d, line_value in zip( 534 line_coords, line_values): 535 start_end_point_2d = _coords_3d_to_2d(start_end_point_3d, 536 self.direction) 537 538 color = value_to_color(line_value) 539 abs_line_value = abs(line_value) 540 linewidth = 1 + 2 * abs_norm(abs_line_value) 541 # Hacky way to put the strongest connections on top of the weakest 542 # note sign does not matter hence using 'abs' 543 zorder = 10 + 10 * abs_norm(abs_line_value) 544 this_kwargs = {'color': color, 'linewidth': linewidth, 545 'zorder': zorder} 546 # kwargs should have priority over this_kwargs so that the 547 # user can override the default logic 548 this_kwargs.update(kwargs) 549 xdata, ydata = start_end_point_2d.T 550 # If directed is True, add an arrow 551 if directed: 552 dx = xdata[1] - xdata[0] 553 dy = ydata[1] - ydata[0] 554 # Hack to avoid empty arrows to crash with 555 # matplotlib versions older than 3.1 556 # This can be removed once support for 557 # matplotlib pre 3.1 has been dropped. 558 if dx == 0 and dy == 0: 559 arrow = FancyArrow(xdata[0], ydata[0], 560 dx, dy) 561 else: 562 arrow = FancyArrow(xdata[0], ydata[0], 563 dx, dy, 564 length_includes_head=True, 565 width=linewidth, 566 head_width=3*linewidth, 567 **this_kwargs) 568 self.ax.add_patch(arrow) 569 # Otherwise a line 570 else: 571 line = lines.Line2D(xdata, ydata, **this_kwargs) 572 self.ax.add_line(line) 573 574 575############################################################################### 576# class BaseSlicer 577############################################################################### 578 579class BaseSlicer(object): 580 """ The main purpose of these class is to have auto adjust of axes size 581 to the data with different layout of cuts. 582 583 """ 584 # This actually encodes the figsize for only one axe 585 _default_figsize = [2.2, 2.6] 586 _axes_class = CutAxes 587 588 def __init__(self, cut_coords, axes=None, black_bg=False, 589 brain_color=(0.5, 0.5, 0.5), **kwargs): 590 """ Create 3 linked axes for plotting orthogonal cuts. 591 592 Parameters 593 ---------- 594 cut_coords : 3 tuple of ints 595 The cut position, in world space. 596 597 axes : matplotlib axes object, optional 598 The axes that will be subdivided in 3. 599 600 black_bg : boolean, optional 601 If True, the background of the figure will be put to 602 black. If you wish to save figures with a black background, 603 you will need to pass "facecolor='k', edgecolor='k'" 604 to matplotlib.pyplot.savefig. Default=False. 605 606 brain_color : tuple, optional 607 The brain color to use as the background color (e.g., for 608 transparent colorbars). 609 Default=(0.5, 0.5, 0.5) 610 611 """ 612 self.cut_coords = cut_coords 613 if axes is None: 614 axes = plt.axes((0., 0., 1., 1.)) 615 axes.axis('off') 616 self.frame_axes = axes 617 axes.set_zorder(1) 618 bb = axes.get_position() 619 self.rect = (bb.x0, bb.y0, bb.x1, bb.y1) 620 self._black_bg = black_bg 621 self._brain_color = brain_color 622 self._colorbar = False 623 self._colorbar_width = 0.05 * bb.width 624 self._colorbar_margin = dict(left=0.25 * bb.width, 625 right=0.02 * bb.width, 626 top=0.05 * bb.height, 627 bottom=0.05 * bb.height) 628 self._init_axes(**kwargs) 629 630 @staticmethod 631 def find_cut_coords(img=None, threshold=None, cut_coords=None): 632 # Implement this as a staticmethod or a classmethod when 633 # subclassing 634 raise NotImplementedError 635 636 @classmethod 637 def init_with_figure(cls, img, threshold=None, 638 cut_coords=None, figure=None, axes=None, 639 black_bg=False, leave_space=False, colorbar=False, 640 brain_color=(0.5, 0.5, 0.5), **kwargs): 641 "Initialize the slicer with an image" 642 # deal with "fake" 4D images 643 if img is not None and img is not False: 644 img = _utils.check_niimg_3d(img) 645 646 cut_coords = cls.find_cut_coords(img, threshold, cut_coords) 647 648 if isinstance(axes, plt.Axes) and figure is None: 649 figure = axes.figure 650 651 if not isinstance(figure, plt.Figure): 652 # Make sure that we have a figure 653 figsize = cls._default_figsize[:] 654 655 # Adjust for the number of axes 656 figsize[0] *= len(cut_coords) 657 658 # Make space for the colorbar 659 if colorbar: 660 figsize[0] += .7 661 662 facecolor = 'k' if black_bg else 'w' 663 664 if leave_space: 665 figsize[0] += 3.4 666 figure = plt.figure(figure, figsize=figsize, 667 facecolor=facecolor) 668 if isinstance(axes, plt.Axes): 669 assert axes.figure is figure, ("The axes passed are not " 670 "in the figure") 671 672 if axes is None: 673 axes = [0., 0., 1., 1.] 674 if leave_space: 675 axes = [0.3, 0, .7, 1.] 676 if isinstance(axes, collections.abc.Sequence): 677 axes = figure.add_axes(axes) 678 # People forget to turn their axis off, or to set the zorder, and 679 # then they cannot see their slicer 680 axes.axis('off') 681 return cls(cut_coords, axes, black_bg, brain_color, **kwargs) 682 683 def title(self, text, x=0.01, y=0.99, size=15, color=None, bgcolor=None, 684 alpha=1, **kwargs): 685 """ Write a title to the view. 686 687 Parameters 688 ---------- 689 text : string 690 The text of the title. 691 692 x : float, optional 693 The horizontal position of the title on the frame in 694 fraction of the frame width. Default=0.01. 695 696 y : float, optional 697 The vertical position of the title on the frame in 698 fraction of the frame height. Default=0.99. 699 700 size : integer, optional 701 The size of the title text. Default=15. 702 703 color : matplotlib color specifier, optional 704 The color of the font of the title. 705 706 bgcolor : matplotlib color specifier, optional 707 The color of the background of the title. 708 709 alpha : float, optional 710 The alpha value for the background. Default=1. 711 712 kwargs : 713 Extra keyword arguments are passed to matplotlib's text 714 function. 715 716 """ 717 if color is None: 718 color = 'k' if self._black_bg else 'w' 719 if bgcolor is None: 720 bgcolor = 'w' if self._black_bg else 'k' 721 if hasattr(self, '_cut_displayed'): 722 # Adapt to the case of mosaic plotting 723 if isinstance(self.cut_coords, dict): 724 first_axe = self._cut_displayed[-1] 725 first_axe = (first_axe, self.cut_coords[first_axe][0]) 726 else: 727 first_axe = self._cut_displayed[0] 728 else: 729 first_axe = self.cut_coords[0] 730 ax = self.axes[first_axe].ax 731 ax.text(x, y, text, 732 transform=self.frame_axes.transAxes, 733 horizontalalignment='left', 734 verticalalignment='top', 735 size=size, color=color, 736 bbox=dict(boxstyle="square,pad=.3", 737 ec=bgcolor, fc=bgcolor, alpha=alpha), 738 zorder=1000, 739 **kwargs) 740 ax.set_zorder(1000) 741 742 def add_overlay(self, img, threshold=1e-6, colorbar=False, **kwargs): 743 """ Plot a 3D map in all the views. 744 745 Parameters 746 ----------- 747 img : Niimg-like object 748 See http://nilearn.github.io/manipulating_images/input_output.html 749 If it is a masked array, only the non-masked part will be plotted. 750 751 threshold : Int or Float or None, optional 752 If None is given, the maps are not thresholded. 753 If a number is given, it is used to threshold the maps: 754 values below the threshold (in absolute value) are 755 plotted as transparent. Default=1e-6. 756 757 colorbar : boolean, optional 758 If True, display a colorbar on the right of the plots. 759 Default=False. 760 761 kwargs : 762 Extra keyword arguments are passed to imshow. 763 764 """ 765 if colorbar and self._colorbar: 766 raise ValueError("This figure already has an overlay with a " 767 "colorbar.") 768 else: 769 self._colorbar = colorbar 770 771 img = _utils.check_niimg_3d(img) 772 773 # Make sure that add_overlay shows consistent default behavior 774 # with plot_stat_map 775 kwargs.setdefault('interpolation', 'nearest') 776 ims = self._map_show(img, type='imshow', threshold=threshold, **kwargs) 777 778 # `ims` can be empty in some corner cases, look at test_img_plotting.test_outlier_cut_coords. 779 if colorbar and ims: 780 self._show_colorbar(ims[0].cmap, ims[0].norm, threshold) 781 782 plt.draw_if_interactive() 783 784 def add_contours(self, img, threshold=1e-6, filled=False, **kwargs): 785 """ Contour a 3D map in all the views. 786 787 Parameters 788 ----------- 789 img : Niimg-like object 790 See http://nilearn.github.io/manipulating_images/input_output.html 791 Provides image to plot. 792 793 threshold : Int or Float or None, optional 794 If None is given, the maps are not thresholded. 795 If a number is given, it is used to threshold the maps, 796 values below the threshold (in absolute value) are plotted 797 as transparent. Default=1e-6. 798 799 filled : boolean, optional 800 If filled=True, contours are displayed with color fillings. 801 Default=False. 802 803 kwargs : 804 Extra keyword arguments are passed to contour, see the 805 documentation of pylab.contour and see pylab.contourf documentation 806 for arguments related to contours with fillings. 807 Useful, arguments are typical "levels", which is a 808 list of values to use for plotting a contour or contour 809 fillings (if filled=True), and 810 "colors", which is one color or a list of colors for 811 these contours. 812 813 Notes 814 ----- 815 If colors are not specified, default coloring choices 816 (from matplotlib) for contours and contour_fillings can be 817 different. 818 819 """ 820 if not filled: 821 threshold = None 822 self._map_show(img, type='contour', threshold=threshold, **kwargs) 823 if filled: 824 if 'levels' in kwargs: 825 levels = kwargs['levels'] 826 if len(levels) <= 1: 827 # contour fillings levels should be given as (lower, upper). 828 levels.append(np.inf) 829 830 self._map_show(img, type='contourf', threshold=threshold, **kwargs) 831 832 plt.draw_if_interactive() 833 834 def _map_show(self, img, type='imshow', 835 resampling_interpolation='continuous', 836 threshold=None, **kwargs): 837 # In the special case where the affine of img is not diagonal, 838 # the function `reorder_img` will trigger a resampling 839 # of the provided image with a continuous interpolation 840 # since this is the default value here. In the special 841 # case where this image is binary, such as when this function 842 # is called from `add_contours`, continuous interpolation 843 # does not make sense and we turn to nearest interpolation instead. 844 if _utils.niimg._is_binary_niimg(img): 845 img = reorder_img(img, resample='nearest') 846 else: 847 img = reorder_img(img, resample=resampling_interpolation) 848 threshold = float(threshold) if threshold is not None else None 849 850 if threshold is not None: 851 data = _utils.niimg._safe_get_data(img, ensure_finite=True) 852 if threshold == 0: 853 data = np.ma.masked_equal(data, 0, copy=False) 854 else: 855 data = np.ma.masked_inside(data, -threshold, threshold, 856 copy=False) 857 img = new_img_like(img, data, img.affine) 858 859 affine = img.affine 860 data = _utils.niimg._safe_get_data(img, ensure_finite=True) 861 data_bounds = get_bounds(data.shape, affine) 862 (xmin, xmax), (ymin, ymax), (zmin, zmax) = data_bounds 863 864 xmin_, xmax_, ymin_, ymax_, zmin_, zmax_ = \ 865 xmin, xmax, ymin, ymax, zmin, zmax 866 867 # Compute tight bounds 868 if type in ('contour', 'contourf'): 869 # Define a pseudo threshold to have a tight bounding box 870 if 'levels' in kwargs: 871 thr = 0.9 * np.min(np.abs(kwargs['levels'])) 872 else: 873 thr = 1e-6 874 not_mask = np.logical_or(data > thr, data < -thr) 875 xmin_, xmax_, ymin_, ymax_, zmin_, zmax_ = \ 876 get_mask_bounds(new_img_like(img, not_mask, affine)) 877 elif hasattr(data, 'mask') and isinstance(data.mask, np.ndarray): 878 not_mask = np.logical_not(data.mask) 879 xmin_, xmax_, ymin_, ymax_, zmin_, zmax_ = \ 880 get_mask_bounds(new_img_like(img, not_mask, affine)) 881 882 data_2d_list = [] 883 for display_ax in self.axes.values(): 884 try: 885 data_2d = display_ax.transform_to_2d(data, affine) 886 except IndexError: 887 # We are cutting outside the indices of the data 888 data_2d = None 889 890 data_2d_list.append(data_2d) 891 892 if kwargs.get('vmin') is None: 893 kwargs['vmin'] = np.ma.min([d.min() for d in data_2d_list 894 if d is not None]) 895 if kwargs.get('vmax') is None: 896 kwargs['vmax'] = np.ma.max([d.max() for d in data_2d_list 897 if d is not None]) 898 899 bounding_box = (xmin_, xmax_), (ymin_, ymax_), (zmin_, zmax_) 900 ims = [] 901 to_iterate_over = zip(self.axes.values(), data_2d_list) 902 for display_ax, data_2d in to_iterate_over: 903 if data_2d is not None and data_2d.min() is not np.ma.masked: 904 # If data_2d is completely masked, then there is nothing to 905 # plot. Hence, no point to do imshow(). Moreover, we see 906 # problem came up with matplotlib 2.1.0 (issue #9280) when 907 # data is completely masked or with numpy < 1.14 908 # (issue #4595). This work around can be removed when bumping 909 # matplotlib version above 2.1.0 910 im = display_ax.draw_2d(data_2d, data_bounds, bounding_box, 911 type=type, **kwargs) 912 ims.append(im) 913 return ims 914 915 def _show_colorbar(self, cmap, norm, threshold=None): 916 """Displays the colorbar. 917 918 Parameters 919 ---------- 920 cmap : a matplotlib colormap 921 The colormap used. 922 923 norm : a matplotlib.colors.Normalize object 924 This object is typically found as the 'norm' attribute of an 925 matplotlib.image.AxesImage. 926 927 threshold : float or None, optional 928 The absolute value at which the colorbar is thresholded. 929 930 """ 931 if threshold is None: 932 offset = 0 933 else: 934 offset = threshold 935 if offset > norm.vmax: 936 offset = norm.vmax 937 938 # create new axis for the colorbar 939 figure = self.frame_axes.figure 940 _, y0, x1, y1 = self.rect 941 height = y1 - y0 942 x_adjusted_width = self._colorbar_width / len(self.axes) 943 x_adjusted_margin = self._colorbar_margin['right'] / len(self.axes) 944 lt_wid_top_ht = [x1 - (x_adjusted_width + x_adjusted_margin), 945 y0 + self._colorbar_margin['top'], 946 x_adjusted_width, 947 height - (self._colorbar_margin['top'] + 948 self._colorbar_margin['bottom'])] 949 self._colorbar_ax = figure.add_axes(lt_wid_top_ht) 950 if LooseVersion(matplotlib.__version__) >= LooseVersion("1.6"): 951 self._colorbar_ax.set_facecolor('w') 952 else: 953 self._colorbar_ax.set_axis_bgcolor('w') 954 955 our_cmap = mpl_cm.get_cmap(cmap) 956 # edge case where the data has a single value 957 # yields a cryptic matplotlib error message 958 # when trying to plot the color bar 959 nb_ticks = 5 if norm.vmin != norm.vmax else 1 960 ticks = np.linspace(norm.vmin, norm.vmax, nb_ticks) 961 bounds = np.linspace(norm.vmin, norm.vmax, our_cmap.N) 962 963 # some colormap hacking 964 cmaplist = [our_cmap(i) for i in range(our_cmap.N)] 965 transparent_start = int(norm(-offset, clip=True) * (our_cmap.N - 1)) 966 transparent_stop = int(norm(offset, clip=True) * (our_cmap.N - 1)) 967 for i in range(transparent_start, transparent_stop): 968 cmaplist[i] = self._brain_color + (0.,) # transparent 969 if norm.vmin == norm.vmax: # len(np.unique(data)) == 1 ? 970 return 971 else: 972 our_cmap = colors.LinearSegmentedColormap.from_list( 973 'Custom cmap', cmaplist, our_cmap.N) 974 975 self._cbar = ColorbarBase( 976 self._colorbar_ax, ticks=ticks, norm=norm, 977 orientation='vertical', cmap=our_cmap, boundaries=bounds, 978 spacing='proportional', format='%.2g') 979 self._cbar.ax.set_facecolor(self._brain_color) 980 981 self._colorbar_ax.yaxis.tick_left() 982 tick_color = 'w' if self._black_bg else 'k' 983 outline_color = 'w' if self._black_bg else 'k' 984 985 for tick in self._colorbar_ax.yaxis.get_ticklabels(): 986 tick.set_color(tick_color) 987 self._colorbar_ax.yaxis.set_tick_params(width=0) 988 self._cbar.outline.set_edgecolor(outline_color) 989 990 def add_edges(self, img, color='r'): 991 """ Plot the edges of a 3D map in all the views. 992 993 Parameters 994 ---------- 995 img : Niimg-like object 996 See http://nilearn.github.io/manipulating_images/input_output.html 997 The 3D map to be plotted. 998 If it is a masked array, only the non-masked part will be plotted. 999 1000 color : matplotlib color: string or (r, g, b) value 1001 The color used to display the edge map. 1002 Default='r'. 1003 1004 """ 1005 img = reorder_img(img, resample='continuous') 1006 data = get_data(img) 1007 affine = img.affine 1008 single_color_cmap = colors.ListedColormap([color]) 1009 data_bounds = get_bounds(data.shape, img.affine) 1010 1011 # For each ax, cut the data and plot it 1012 for display_ax in self.axes.values(): 1013 try: 1014 data_2d = display_ax.transform_to_2d(data, affine) 1015 edge_mask = _edge_map(data_2d) 1016 except IndexError: 1017 # We are cutting outside the indices of the data 1018 continue 1019 display_ax.draw_2d(edge_mask, data_bounds, data_bounds, 1020 type='imshow', cmap=single_color_cmap) 1021 1022 plt.draw_if_interactive() 1023 1024 def add_markers(self, marker_coords, marker_color='r', marker_size=30, 1025 **kwargs): 1026 """Add markers to the plot. 1027 1028 Parameters 1029 ---------- 1030 marker_coords : array of size (n_markers, 3) 1031 Coordinates of the markers to plot. For each slice, only markers 1032 that are 2 millimeters away from the slice are plotted. 1033 1034 marker_color : pyplot compatible color or list of shape (n_markers,), optional 1035 List of colors for each marker that can be string or matplotlib colors. 1036 Default='r'. 1037 1038 marker_size : single float or list of shape (n_markers,), optional 1039 Size in pixel for each marker. Default=30. 1040 1041 """ 1042 defaults = {'marker': 'o', 1043 'zorder': 1000} 1044 marker_coords = np.asanyarray(marker_coords) 1045 for k, v in defaults.items(): 1046 kwargs.setdefault(k, v) 1047 1048 for display_ax in self.axes.values(): 1049 direction = display_ax.direction 1050 coord = display_ax.coord 1051 marker_coords_2d, third_d = _coords_3d_to_2d( 1052 marker_coords, direction, return_direction=True) 1053 xdata, ydata = marker_coords_2d.T 1054 # Allow markers only in their respective hemisphere when appropriate 1055 marker_color_ = marker_color 1056 if direction in ('lr'): 1057 if (not isinstance(marker_color, str) and 1058 not isinstance(marker_color, np.ndarray)): 1059 marker_color_ = np.asarray(marker_color) 1060 xcoords, ycoords, zcoords = marker_coords.T 1061 if direction == 'r': 1062 relevant_coords = (xcoords >= 0) 1063 elif direction == 'l': 1064 relevant_coords = (xcoords <= 0) 1065 xdata = xdata[relevant_coords] 1066 ydata = ydata[relevant_coords] 1067 if (not isinstance(marker_color, str) and 1068 len(marker_color) != 1): 1069 marker_color_ = marker_color_[relevant_coords] 1070 # Check if coord has integer represents a cut in direction 1071 # to follow the heuristic. If no foreground image is given 1072 # coordinate is empty or None. This case is valid for plotting 1073 # markers on glass brain without any foreground image. 1074 if isinstance(coord, numbers.Number): 1075 # Heuristic that plots only markers that are 2mm away 1076 # from the current slice. 1077 # XXX: should we keep this heuristic? 1078 mask = np.abs(third_d - coord) <= 2. 1079 xdata = xdata[mask] 1080 ydata = ydata[mask] 1081 display_ax.ax.scatter(xdata, ydata, s=marker_size, 1082 c=marker_color_, **kwargs) 1083 1084 def annotate(self, left_right=True, positions=True, scalebar=False, 1085 size=12, scale_size=5.0, scale_units='cm', scale_loc=4, 1086 decimals=0, **kwargs): 1087 """Add annotations to the plot. 1088 1089 Parameters 1090 ---------- 1091 left_right : boolean, optional 1092 If left_right is True, annotations indicating which side 1093 is left and which side is right are drawn. Default=True. 1094 1095 positions : boolean, optional 1096 If positions is True, annotations indicating the 1097 positions of the cuts are drawn. Default=True. 1098 1099 scalebar : boolean, optional 1100 If ``True``, cuts are annotated with a reference scale bar. 1101 For finer control of the scale bar, please check out 1102 the draw_scale_bar method on the axes in "axes" attribute of 1103 this object. Default=False. 1104 1105 size : integer, optional 1106 The size of the text used. Default=12. 1107 1108 scale_size : number, optional 1109 The length of the scalebar, in units of scale_units. 1110 Default=5.0. 1111 1112 scale_units : {'cm', 'mm'}, optional 1113 The units for the scalebar. Default='cm'. 1114 1115 scale_loc : integer, optional 1116 The positioning for the scalebar. Default=4. 1117 Valid location codes are: 1118 1119 - 'upper right' : 1 1120 - 'upper left' : 2 1121 - 'lower left' : 3 1122 - 'lower right' : 4 1123 - 'right' : 5 1124 - 'center left' : 6 1125 - 'center right' : 7 1126 - 'lower center' : 8 1127 - 'upper center' : 9 1128 - 'center' : 10 1129 1130 decimals : integer, optional 1131 Number of decimal places on slice position annotation. If zero, 1132 the slice position is integer without decimal point. 1133 Default=0. 1134 1135 kwargs : 1136 Extra keyword arguments are passed to matplotlib's text 1137 function. 1138 1139 """ 1140 kwargs = kwargs.copy() 1141 if 'color' not in kwargs: 1142 if self._black_bg: 1143 kwargs['color'] = 'w' 1144 else: 1145 kwargs['color'] = 'k' 1146 1147 bg_color = ('k' if self._black_bg else 'w') 1148 1149 if left_right: 1150 for display_axis in self.axes.values(): 1151 display_axis.draw_left_right(size=size, bg_color=bg_color, 1152 **kwargs) 1153 1154 if positions: 1155 for display_axis in self.axes.values(): 1156 display_axis.draw_position(size=size, bg_color=bg_color, 1157 decimals=decimals, 1158 **kwargs) 1159 1160 if scalebar: 1161 axes = self.axes.values() 1162 for display_axis in axes: 1163 display_axis.draw_scale_bar(bg_color=bg_color, 1164 fontsize=size, 1165 size=scale_size, 1166 units=scale_units, 1167 loc=scale_loc, 1168 **kwargs) 1169 1170 def close(self): 1171 """ Close the figure. This is necessary to avoid leaking memory. 1172 """ 1173 plt.close(self.frame_axes.figure.number) 1174 1175 def savefig(self, filename, dpi=None): 1176 """ Save the figure to a file 1177 1178 Parameters 1179 ---------- 1180 filename : string 1181 The file name to save to. Its extension determines the 1182 file type, typically '.png', '.svg' or '.pdf'. 1183 1184 dpi : None or scalar, optional 1185 The resolution in dots per inch. 1186 1187 """ 1188 facecolor = edgecolor = 'k' if self._black_bg else 'w' 1189 self.frame_axes.figure.savefig(filename, dpi=dpi, 1190 facecolor=facecolor, 1191 edgecolor=edgecolor) 1192 1193 1194############################################################################### 1195# class OrthoSlicer 1196############################################################################### 1197 1198class OrthoSlicer(BaseSlicer): 1199 """ A class to create 3 linked axes for plotting orthogonal 1200 cuts of 3D maps. 1201 1202 Attributes 1203 ---------- 1204 axes : dictionary of axes 1205 The 3 axes used to plot each view. 1206 1207 frame_axes : axes 1208 The axes framing the whole set of views. 1209 1210 Notes 1211 ----- 1212 The extent of the different axes are adjusted to fit the data 1213 best in the viewing area. 1214 1215 """ 1216 _cut_displayed = 'yxz' 1217 _axes_class = CutAxes 1218 1219 @classmethod 1220 def find_cut_coords(cls, img=None, threshold=None, cut_coords=None): 1221 "Instantiate the slicer and find cut coordinates" 1222 if cut_coords is None: 1223 if img is None or img is False: 1224 cut_coords = (0, 0, 0) 1225 else: 1226 cut_coords = find_xyz_cut_coords( 1227 img, activation_threshold=threshold) 1228 cut_coords = [cut_coords['xyz'.find(c)] 1229 for c in sorted(cls._cut_displayed)] 1230 return cut_coords 1231 1232 def _init_axes(self, **kwargs): 1233 cut_coords = self.cut_coords 1234 if len(cut_coords) != len(self._cut_displayed): 1235 raise ValueError('The number cut_coords passed does not' 1236 ' match the display_mode') 1237 x0, y0, x1, y1 = self.rect 1238 facecolor = 'k' if self._black_bg else 'w' 1239 # Create our axes: 1240 self.axes = dict() 1241 for index, direction in enumerate(self._cut_displayed): 1242 fh = self.frame_axes.get_figure() 1243 ax = fh.add_axes([0.3 * index * (x1 - x0) + x0, y0, 1244 .3 * (x1 - x0), y1 - y0], aspect='equal') 1245 if LooseVersion(matplotlib.__version__) >= LooseVersion("1.6"): 1246 ax.set_facecolor(facecolor) 1247 else: 1248 ax.set_axis_bgcolor(facecolor) 1249 1250 ax.axis('off') 1251 coord = self.cut_coords[ 1252 sorted(self._cut_displayed).index(direction)] 1253 display_ax = self._axes_class(ax, direction, coord, **kwargs) 1254 self.axes[direction] = display_ax 1255 ax.set_axes_locator(self._locator) 1256 1257 if self._black_bg: 1258 for ax in self.axes.values(): 1259 ax.ax.imshow(np.zeros((2, 2, 3)), 1260 extent=[-5000, 5000, -5000, 5000], 1261 zorder=-500, aspect='equal') 1262 1263 # To have a black background in PDF, we need to create a 1264 # patch in black for the background 1265 self.frame_axes.imshow(np.zeros((2, 2, 3)), 1266 extent=[-5000, 5000, -5000, 5000], 1267 zorder=-500, aspect='auto') 1268 self.frame_axes.set_zorder(-1000) 1269 1270 def _locator(self, axes, renderer): 1271 """ The locator function used by matplotlib to position axes. 1272 Here we put the logic used to adjust the size of the axes. 1273 1274 """ 1275 x0, y0, x1, y1 = self.rect 1276 width_dict = dict() 1277 # A dummy axes, for the situation in which we are not plotting 1278 # all three (x, y, z) cuts 1279 dummy_ax = self._axes_class(None, None, None) 1280 width_dict[dummy_ax.ax] = 0 1281 display_ax_dict = self.axes 1282 1283 if self._colorbar: 1284 adjusted_width = self._colorbar_width / len(self.axes) 1285 right_margin = self._colorbar_margin['right'] / len(self.axes) 1286 ticks_margin = self._colorbar_margin['left'] / len(self.axes) 1287 x1 = x1 - (adjusted_width + ticks_margin + right_margin) 1288 1289 for display_ax in display_ax_dict.values(): 1290 bounds = display_ax.get_object_bounds() 1291 if not bounds: 1292 # This happens if the call to _map_show was not 1293 # successful. As it happens asynchronously (during a 1294 # refresh of the figure) we capture the problem and 1295 # ignore it: it only adds a non informative traceback 1296 bounds = [0, 1, 0, 1] 1297 xmin, xmax, ymin, ymax = bounds 1298 width_dict[display_ax.ax] = (xmax - xmin) 1299 1300 total_width = float(sum(width_dict.values())) 1301 for ax, width in width_dict.items(): 1302 width_dict[ax] = width / total_width * (x1 - x0) 1303 1304 direction_ax = [] 1305 for d in self._cut_displayed: 1306 direction_ax.append(display_ax_dict.get(d, dummy_ax).ax) 1307 left_dict = dict() 1308 for idx, ax in enumerate(direction_ax): 1309 left_dict[ax] = x0 1310 for prev_ax in direction_ax[:idx]: 1311 left_dict[ax] += width_dict[prev_ax] 1312 1313 return transforms.Bbox([[left_dict[axes], y0], 1314 [left_dict[axes] + width_dict[axes], y1]]) 1315 1316 def draw_cross(self, cut_coords=None, **kwargs): 1317 """ Draw a crossbar on the plot to show where the cut is 1318 performed. 1319 1320 Parameters 1321 ---------- 1322 cut_coords : 3-tuple of floats, optional 1323 The position of the cross to draw. If none is passed, the 1324 ortho_slicer's cut coordinates are used. 1325 1326 kwargs : 1327 Extra keyword arguments are passed to axhline 1328 1329 """ 1330 if cut_coords is None: 1331 cut_coords = self.cut_coords 1332 coords = dict() 1333 for direction in 'xyz': 1334 coord = None 1335 if direction in self._cut_displayed: 1336 coord = cut_coords[ 1337 sorted(self._cut_displayed).index(direction)] 1338 coords[direction] = coord 1339 x, y, z = coords['x'], coords['y'], coords['z'] 1340 1341 kwargs = kwargs.copy() 1342 if 'color' not in kwargs: 1343 if self._black_bg: 1344 kwargs['color'] = '.8' 1345 else: 1346 kwargs['color'] = 'k' 1347 1348 if 'y' in self.axes: 1349 ax = self.axes['y'].ax 1350 if x is not None: 1351 ax.axvline(x, ymin=.05, ymax=.95, **kwargs) 1352 if z is not None: 1353 ax.axhline(z, **kwargs) 1354 1355 if 'x' in self.axes: 1356 ax = self.axes['x'].ax 1357 if y is not None: 1358 ax.axvline(y, ymin=.05, ymax=.95, **kwargs) 1359 if z is not None: 1360 ax.axhline(z, xmax=.95, **kwargs) 1361 1362 if 'z' in self.axes: 1363 ax = self.axes['z'].ax 1364 if x is not None: 1365 ax.axvline(x, ymin=.05, ymax=.95, **kwargs) 1366 if y is not None: 1367 ax.axhline(y, **kwargs) 1368 1369 1370############################################################################### 1371# class TiledSlicer 1372############################################################################### 1373 1374class TiledSlicer(BaseSlicer): 1375 """ A class to create 3 axes for plotting orthogonal 1376 cuts of 3D maps, organized in a 2x2 grid. 1377 1378 Attributes 1379 ---------- 1380 axes : dictionary of axes 1381 The 3 axes used to plot each view. 1382 1383 frame_axes : axes 1384 The axes framing the whole set of views. 1385 1386 Notes 1387 ----- 1388 The extent of the different axes are adjusted to fit the data 1389 best in the viewing area. 1390 1391 """ 1392 _cut_displayed = 'yxz' 1393 _axes_class = CutAxes 1394 _default_figsize = [2.0, 6.0] 1395 1396 @classmethod 1397 def find_cut_coords(cls, img=None, threshold=None, cut_coords=None): 1398 """Instantiate the slicer and find cut coordinates. 1399 1400 Parameters 1401 ---------- 1402 img : 3D Nifti1Image 1403 The brain map. 1404 1405 threshold : float, optional 1406 The lower threshold to the positive activation. If None, the 1407 activation threshold is computed using the 80% percentile of 1408 the absolute value of the map. 1409 1410 cut_coords : list of float, optional 1411 xyz world coordinates of cuts. 1412 1413 Returns 1414 ------- 1415 cut_coords : list of float 1416 xyz world coordinates of cuts. 1417 1418 """ 1419 if cut_coords is None: 1420 if img is None or img is False: 1421 cut_coords = (0, 0, 0) 1422 else: 1423 cut_coords = find_xyz_cut_coords( 1424 img, activation_threshold=threshold) 1425 cut_coords = [cut_coords['xyz'.find(c)] 1426 for c in sorted(cls._cut_displayed)] 1427 1428 return cut_coords 1429 1430 def _find_initial_axes_coord(self, index): 1431 """Find coordinates for initial axes placement for xyz cuts. 1432 1433 Parameters 1434 ---------- 1435 index : int 1436 Index corresponding to current cut 'x', 'y' or 'z'. 1437 1438 Returns 1439 ------- 1440 [coord1, coord2, coord3, coord4] : list of int 1441 x0, y0, x1, y1 coordinates used by matplotlib 1442 to position axes in figure. 1443 1444 """ 1445 rect_x0, rect_y0, rect_x1, rect_y1 = self.rect 1446 1447 if index == 0: 1448 coord1 = rect_x1 - rect_x0 1449 coord2 = 0.5 * (rect_y1 - rect_y0) + rect_y0 1450 coord3 = 0.5 * (rect_x1 - rect_x0) + rect_x0 1451 coord4 = rect_y1 - rect_y0 1452 elif index == 1: 1453 coord1 = 0.5 * (rect_x1 - rect_x0) + rect_x0 1454 coord2 = 0.5 * (rect_y1 - rect_y0) + rect_y0 1455 coord3 = rect_x1 - rect_x0 1456 coord4 = rect_y1 - rect_y0 1457 elif index == 2: 1458 coord1 = rect_x1 - rect_x0 1459 coord2 = rect_y1 - rect_y0 1460 coord3 = 0.5 * (rect_x1 - rect_x0) + rect_x0 1461 coord4 = 0.5 * (rect_y1 - rect_y0) + rect_y0 1462 return [coord1, coord2, coord3, coord4] 1463 1464 def _init_axes(self, **kwargs): 1465 """Initializes and places axes for display of 'xyz' cuts. 1466 1467 Parameters 1468 ---------- 1469 kwargs : 1470 additional arguments to pass to self._axes_class 1471 1472 """ 1473 cut_coords = self.cut_coords 1474 if len(cut_coords) != len(self._cut_displayed): 1475 raise ValueError('The number cut_coords passed does not' 1476 ' match the display_mode') 1477 1478 facecolor = 'k' if self._black_bg else 'w' 1479 1480 self.axes = dict() 1481 for index, direction in enumerate(self._cut_displayed): 1482 fh = self.frame_axes.get_figure() 1483 axes_coords = self._find_initial_axes_coord(index) 1484 ax = fh.add_axes(axes_coords, aspect='equal') 1485 1486 if LooseVersion(matplotlib.__version__) >= LooseVersion("1.6"): 1487 ax.set_facecolor(facecolor) 1488 else: 1489 ax.set_axis_bgcolor(facecolor) 1490 1491 ax.axis('off') 1492 coord = self.cut_coords[ 1493 sorted(self._cut_displayed).index(direction)] 1494 display_ax = self._axes_class(ax, direction, coord, **kwargs) 1495 self.axes[direction] = display_ax 1496 ax.set_axes_locator(self._locator) 1497 1498 def _adjust_width_height(self, width_dict, height_dict, 1499 rect_x0, rect_y0, rect_x1, rect_y1): 1500 """Adjusts absolute image width and height to ratios. 1501 1502 Parameters 1503 ---------- 1504 width_dict : dict 1505 Width of image cuts displayed in axes. 1506 1507 height_dict : dict 1508 Height of image cuts displayed in axes. 1509 1510 rect_x0, rect_y0, rect_x1, rect_y1 : float 1511 Matplotlib figure boundaries. 1512 1513 Returns 1514 ------- 1515 width_dict : dict 1516 Width ratios of image cuts for optimal positioning of axes. 1517 1518 height_dict : dict 1519 Height ratios of image cuts for optimal positioning of axes. 1520 1521 """ 1522 total_height = 0 1523 total_width = 0 1524 1525 if 'y' in self.axes: 1526 ax = self.axes['y'].ax 1527 total_height = total_height + height_dict[ax] 1528 total_width = total_width + width_dict[ax] 1529 1530 if 'x' in self.axes: 1531 ax = self.axes['x'].ax 1532 total_width = total_width + width_dict[ax] 1533 1534 if 'z' in self.axes: 1535 ax = self.axes['z'].ax 1536 total_height = total_height + height_dict[ax] 1537 1538 for ax, width in width_dict.items(): 1539 width_dict[ax] = width / total_width * (rect_x1 - rect_x0) 1540 1541 for ax, height in height_dict.items(): 1542 height_dict[ax] = height / total_height * (rect_y1 - rect_y0) 1543 1544 return (width_dict, height_dict) 1545 1546 def _find_axes_coord(self, rel_width_dict, rel_height_dict, 1547 rect_x0, rect_y0, rect_x1, rect_y1): 1548 """"Find coordinates for initial axes placement for xyz cuts. 1549 1550 Parameters 1551 ---------- 1552 rel_width_dict : dict 1553 Width ratios of image cuts for optimal positioning of axes. 1554 1555 rel_height_dict : dict 1556 Height ratios of image cuts for optimal positioning of axes. 1557 1558 rect_x0, rect_y0, rect_x1, rect_y1 : float 1559 Matplotlib figure boundaries. 1560 1561 Returns 1562 ------- 1563 coord1, coord2, coord3, coord4 : dict 1564 x0, y0, x1, y1 coordinates per axes used by matplotlib 1565 to position axes in figure. 1566 1567 """ 1568 coord1 = dict() 1569 coord2 = dict() 1570 coord3 = dict() 1571 coord4 = dict() 1572 1573 if 'y' in self.axes: 1574 ax = self.axes['y'].ax 1575 coord1[ax] = rect_x0 1576 coord2[ax] = (rect_y1) - rel_height_dict[ax] 1577 coord3[ax] = rect_x0 + rel_width_dict[ax] 1578 coord4[ax] = rect_y1 1579 1580 if 'x' in self.axes: 1581 ax = self.axes['x'].ax 1582 coord1[ax] = (rect_x1) - rel_width_dict[ax] 1583 coord2[ax] = (rect_y1) - rel_height_dict[ax] 1584 coord3[ax] = rect_x1 1585 coord4[ax] = rect_y1 1586 1587 if 'z' in self.axes: 1588 ax = self.axes['z'].ax 1589 coord1[ax] = rect_x0 1590 coord2[ax] = rect_y0 1591 coord3[ax] = rect_x0 + rel_width_dict[ax] 1592 coord4[ax] = rect_y0 + rel_height_dict[ax] 1593 1594 return(coord1, coord2, coord3, coord4) 1595 1596 def _locator(self, axes, renderer): 1597 """ The locator function used by matplotlib to position axes. 1598 Here we put the logic used to adjust the size of the axes. 1599 1600 """ 1601 rect_x0, rect_y0, rect_x1, rect_y1 = self.rect 1602 1603 # image width and height 1604 width_dict = dict() 1605 height_dict = dict() 1606 1607 # A dummy axes, for the situation in which we are not plotting 1608 # all three (x, y, z) cuts 1609 dummy_ax = self._axes_class(None, None, None) 1610 width_dict[dummy_ax.ax] = 0 1611 height_dict[dummy_ax.ax] = 0 1612 display_ax_dict = self.axes 1613 1614 if self._colorbar: 1615 adjusted_width = self._colorbar_width / len(self.axes) 1616 right_margin = self._colorbar_margin['right'] / len(self.axes) 1617 ticks_margin = self._colorbar_margin['left'] / len(self.axes) 1618 rect_x1 = rect_x1 - (adjusted_width + ticks_margin + right_margin) 1619 1620 for display_ax in display_ax_dict.values(): 1621 bounds = display_ax.get_object_bounds() 1622 if not bounds: 1623 # This happens if the call to _map_show was not 1624 # successful. As it happens asynchronously (during a 1625 # refresh of the figure) we capture the problem and 1626 # ignore it: it only adds a non informative traceback 1627 bounds = [0, 1, 0, 1] 1628 xmin, xmax, ymin, ymax = bounds 1629 width_dict[display_ax.ax] = (xmax - xmin) 1630 height_dict[display_ax.ax] = (ymax - ymin) 1631 1632 # relative image height and width 1633 rel_width_dict, rel_height_dict = self._adjust_width_height( 1634 width_dict, height_dict, 1635 rect_x0, rect_y0, rect_x1, rect_y1) 1636 1637 direction_ax = [] 1638 for d in self._cut_displayed: 1639 direction_ax.append(display_ax_dict.get(d, dummy_ax).ax) 1640 1641 coord1, coord2, coord3, coord4 = self._find_axes_coord( 1642 rel_width_dict, rel_height_dict, 1643 rect_x0, rect_y0, rect_x1, rect_y1) 1644 1645 return transforms.Bbox([[coord1[axes], coord2[axes]], 1646 [coord3[axes], coord4[axes]]]) 1647 1648 def draw_cross(self, cut_coords=None, **kwargs): 1649 """Draw a crossbar on the plot to show where the cut is performed. 1650 1651 Parameters 1652 ---------- 1653 cut_coords : 3-tuple of floats, optional 1654 The position of the cross to draw. If none is passed, the 1655 ortho_slicer's cut coordinates are used. 1656 1657 kwargs : 1658 Extra keyword arguments are passed to axhline 1659 1660 """ 1661 if cut_coords is None: 1662 cut_coords = self.cut_coords 1663 coords = dict() 1664 for direction in 'xyz': 1665 coord_ = None 1666 if direction in self._cut_displayed: 1667 sorted_cuts = sorted(self._cut_displayed) 1668 index = sorted_cuts.index(direction) 1669 coord_ = cut_coords[index] 1670 coords[direction] = coord_ 1671 x, y, z = coords['x'], coords['y'], coords['z'] 1672 1673 kwargs = kwargs.copy() 1674 if 'color' not in kwargs: 1675 try: 1676 kwargs['color'] = '.8' if self._black_bg else 'k' 1677 except KeyError: 1678 pass 1679 1680 if 'y' in self.axes: 1681 ax = self.axes['y'].ax 1682 if x is not None: 1683 ax.axvline(x, **kwargs) 1684 if z is not None: 1685 ax.axhline(z, **kwargs) 1686 1687 if 'x' in self.axes: 1688 ax = self.axes['x'].ax 1689 if y is not None: 1690 ax.axvline(y, **kwargs) 1691 if z is not None: 1692 ax.axhline(z, **kwargs) 1693 1694 if 'z' in self.axes: 1695 ax = self.axes['z'].ax 1696 if x is not None: 1697 ax.axvline(x, **kwargs) 1698 if y is not None: 1699 ax.axhline(y, **kwargs) 1700 1701############################################################################### 1702# class BaseStackedSlicer 1703############################################################################### 1704 1705class BaseStackedSlicer(BaseSlicer): 1706 """ A class to create linked axes for plotting stacked 1707 cuts of 2D maps. 1708 1709 Attributes 1710 ---------- 1711 axes : dictionary of axes 1712 The axes used to plot each view. 1713 1714 frame_axes : axes 1715 The axes framing the whole set of views. 1716 1717 Notes 1718 ----- 1719 The extent of the different axes are adjusted to fit the data 1720 best in the viewing area. 1721 1722 """ 1723 @classmethod 1724 def find_cut_coords(cls, img=None, threshold=None, cut_coords=None): 1725 "Instantiate the slicer and find cut coordinates" 1726 if cut_coords is None: 1727 cut_coords = 7 1728 1729 if img is None or img is False: 1730 bounds = ((-40, 40), (-30, 30), (-30, 75)) 1731 lower, upper = bounds['xyz'.index(cls._direction)] 1732 cut_coords = np.linspace(lower, upper, cut_coords).tolist() 1733 else: 1734 if (not isinstance(cut_coords, collections.abc.Sequence) and 1735 isinstance(cut_coords, numbers.Number)): 1736 cut_coords = find_cut_slices(img, 1737 direction=cls._direction, 1738 n_cuts=cut_coords) 1739 1740 return cut_coords 1741 1742 def _init_axes(self, **kwargs): 1743 x0, y0, x1, y1 = self.rect 1744 # Create our axes: 1745 self.axes = dict() 1746 fraction = 1. / len(self.cut_coords) 1747 for index, coord in enumerate(self.cut_coords): 1748 coord = float(coord) 1749 fh = self.frame_axes.get_figure() 1750 ax = fh.add_axes([fraction * index * (x1 - x0) + x0, y0, 1751 fraction * (x1 - x0), y1 - y0]) 1752 ax.axis('off') 1753 display_ax = self._axes_class(ax, self._direction, 1754 coord, **kwargs) 1755 self.axes[coord] = display_ax 1756 ax.set_axes_locator(self._locator) 1757 1758 if self._black_bg: 1759 for ax in self.axes.values(): 1760 ax.ax.imshow(np.zeros((2, 2, 3)), 1761 extent=[-5000, 5000, -5000, 5000], 1762 zorder=-500, aspect='equal') 1763 1764 # To have a black background in PDF, we need to create a 1765 # patch in black for the background 1766 self.frame_axes.imshow(np.zeros((2, 2, 3)), 1767 extent=[-5000, 5000, -5000, 5000], 1768 zorder=-500, aspect='auto') 1769 self.frame_axes.set_zorder(-1000) 1770 1771 def _locator(self, axes, renderer): 1772 """ The locator function used by matplotlib to position axes. 1773 Here we put the logic used to adjust the size of the axes. 1774 1775 """ 1776 x0, y0, x1, y1 = self.rect 1777 width_dict = dict() 1778 display_ax_dict = self.axes 1779 1780 if self._colorbar: 1781 adjusted_width = self._colorbar_width / len(self.axes) 1782 right_margin = self._colorbar_margin['right'] / len(self.axes) 1783 ticks_margin = self._colorbar_margin['left'] / len(self.axes) 1784 x1 = x1 - (adjusted_width + right_margin + ticks_margin) 1785 1786 for display_ax in display_ax_dict.values(): 1787 bounds = display_ax.get_object_bounds() 1788 if not bounds: 1789 # This happens if the call to _map_show was not 1790 # successful. As it happens asynchronously (during a 1791 # refresh of the figure) we capture the problem and 1792 # ignore it: it only adds a non informative traceback 1793 bounds = [0, 1, 0, 1] 1794 xmin, xmax, ymin, ymax = bounds 1795 width_dict[display_ax.ax] = (xmax - xmin) 1796 total_width = float(sum(width_dict.values())) 1797 for ax, width in width_dict.items(): 1798 width_dict[ax] = width / total_width * (x1 - x0) 1799 left_dict = dict() 1800 left = float(x0) 1801 for coord, display_ax in display_ax_dict.items(): 1802 left_dict[display_ax.ax] = left 1803 this_width = width_dict[display_ax.ax] 1804 left += this_width 1805 return transforms.Bbox([[left_dict[axes], y0], 1806 [left_dict[axes] + width_dict[axes], y1]]) 1807 1808 def draw_cross(self, cut_coords=None, **kwargs): 1809 """ Draw a crossbar on the plot to show where the cut is 1810 performed. 1811 1812 Parameters 1813 ---------- 1814 cut_coords : 3-tuple of floats, optional 1815 The position of the cross to draw. If none is passed, the 1816 ortho_slicer's cut coordinates are used. 1817 1818 kwargs : 1819 Extra keyword arguments are passed to axhline 1820 1821 """ 1822 return 1823 1824 1825class XSlicer(BaseStackedSlicer): 1826 _direction = 'x' 1827 _default_figsize = [2.6, 2.3] 1828 1829 1830class YSlicer(BaseStackedSlicer): 1831 _direction = 'y' 1832 _default_figsize = [2.2, 2.3] 1833 1834 1835class ZSlicer(BaseStackedSlicer): 1836 _direction = 'z' 1837 _default_figsize = [2.2, 2.3] 1838 1839 1840class XZSlicer(OrthoSlicer): 1841 _cut_displayed = 'xz' 1842 1843 1844class YXSlicer(OrthoSlicer): 1845 _cut_displayed = 'yx' 1846 1847 1848class YZSlicer(OrthoSlicer): 1849 _cut_displayed = 'yz' 1850 1851 1852class MosaicSlicer(BaseSlicer): 1853 """ A class to create 3 axes for plotting cuts of 3D maps, 1854 in multiple rows and columns. 1855 1856 Attributes 1857 ---------- 1858 axes : dictionary of axes 1859 The 3 axes used to plot multiple views. 1860 1861 frame_axes : axes 1862 The axes framing the whole set of views. 1863 1864 """ 1865 _cut_displayed = 'yxz' 1866 _axes_class = CutAxes 1867 _default_figsize = [11.1, 7.2] 1868 1869 @classmethod 1870 def find_cut_coords(cls, img=None, threshold=None, cut_coords=None): 1871 """Instantiate the slicer and find cut coordinates for mosaic plotting. 1872 1873 Parameters 1874 ---------- 1875 img : 3D Nifti1Image, optional 1876 The brain image. 1877 1878 threshold : float, optional 1879 The lower threshold to the positive activation. If None, the 1880 activation threshold is computed using the 80% percentile of 1881 the absolute value of the map. 1882 1883 cut_coords : list/tuple of 3 floats, integer, optional 1884 xyz world coordinates of cuts. If cut_coords 1885 are not provided, 7 coordinates of cuts are automatically 1886 calculated. 1887 1888 Returns 1889 ------- 1890 cut_coords : dict 1891 xyz world coordinates of cuts in a direction. Each key 1892 denotes the direction. 1893 """ 1894 if cut_coords is None: 1895 cut_coords = 7 1896 1897 if (not isinstance(cut_coords, collections.abc.Sequence) and 1898 isinstance(cut_coords, numbers.Number)): 1899 cut_coords = [cut_coords] * 3 1900 cut_coords = cls._find_cut_coords(img, cut_coords, 1901 cls._cut_displayed) 1902 else: 1903 if len(cut_coords) != len(cls._cut_displayed): 1904 raise ValueError('The number cut_coords passed does not' 1905 ' match the display_mode. Mosaic plotting ' 1906 'expects tuple of length 3.' ) 1907 cut_coords = [cut_coords['xyz'.find(c)] 1908 for c in sorted(cls._cut_displayed)] 1909 cut_coords = cls._find_cut_coords(img, cut_coords, 1910 cls._cut_displayed) 1911 return cut_coords 1912 1913 @staticmethod 1914 def _find_cut_coords(img, cut_coords, cut_displayed): 1915 """ Find slicing positions along a given axis. 1916 1917 Helper function to find_cut_coords. 1918 1919 Parameters 1920 ---------- 1921 img : 3D Nifti1Image 1922 The brain image. 1923 1924 cut_coords : list/tuple of 3 floats, integer, optional 1925 xyz world coordinates of cuts. 1926 1927 cut_displayed : str 1928 Sectional directions 'yxz' 1929 1930 Returns 1931 ------- 1932 cut_coords : 1D array of length specified in n_cuts 1933 The computed cut_coords. 1934 """ 1935 coords = dict() 1936 if img is None or img is False: 1937 bounds = ((-40, 40), (-30, 30), (-30, 75)) 1938 for direction, n_cuts in zip(sorted(cut_displayed), 1939 cut_coords): 1940 lower, upper = bounds['xyz'.index(direction)] 1941 coords[direction] = np.linspace(lower, upper, 1942 n_cuts).tolist() 1943 else: 1944 for direction, n_cuts in zip(sorted(cut_displayed), 1945 cut_coords): 1946 coords[direction] = find_cut_slices(img, direction=direction, 1947 n_cuts=n_cuts) 1948 return coords 1949 1950 def _init_axes(self, **kwargs): 1951 """Initializes and places axes for display of 'xyz' multiple cuts. 1952 1953 Parameters 1954 ---------- 1955 kwargs: 1956 additional arguments to pass to self._axes_class 1957 1958 """ 1959 if not isinstance(self.cut_coords, dict): 1960 self.cut_coords = self.find_cut_coords(cut_coords=self.cut_coords) 1961 1962 if len(self.cut_coords) != len(self._cut_displayed): 1963 raise ValueError('The number cut_coords passed does not' 1964 ' match the mosaic mode') 1965 x0, y0, x1, y1 = self.rect 1966 1967 # Create our axes: 1968 self.axes = dict() 1969 # portions for main axes 1970 fraction = y1 / len(self.cut_coords) 1971 height = fraction 1972 for index, direction in enumerate(self._cut_displayed): 1973 coords = self.cut_coords[direction] 1974 # portions allotment for each of 'x', 'y', 'z' coordinate 1975 fraction_c = 1. / len(coords) 1976 fh = self.frame_axes.get_figure() 1977 indices = [x0, fraction * index * (y1 - y0) + y0, 1978 x1, fraction * (y1 - y0)] 1979 ax = fh.add_axes(indices) 1980 ax.axis('off') 1981 this_x0, this_y0, this_x1, this_y1 = indices 1982 for index_c, coord in enumerate(coords): 1983 coord = float(coord) 1984 fh_c = ax.get_figure() 1985 # indices for each sub axes within main axes 1986 indices = [fraction_c * index_c * (this_x1 - this_x0) + this_x0, 1987 this_y0, 1988 fraction_c * (this_x1 - this_x0), 1989 height] 1990 ax = fh_c.add_axes(indices) 1991 ax.axis('off') 1992 display_ax = self._axes_class(ax, direction, 1993 coord, **kwargs) 1994 self.axes[(direction, coord)] = display_ax 1995 ax.set_axes_locator(self._locator) 1996 1997 def _locator(self, axes, renderer): 1998 """ The locator function used by matplotlib to position axes. 1999 Here we put the logic used to adjust the size of the axes. 2000 """ 2001 x0, y0, x1, y1 = self.rect 2002 display_ax_dict = self.axes 2003 2004 if self._colorbar: 2005 adjusted_width = self._colorbar_width / len(self.axes) 2006 right_margin = self._colorbar_margin['right'] / len(self.axes) 2007 ticks_margin = self._colorbar_margin['left'] / len(self.axes) 2008 x1 = x1 - (adjusted_width + right_margin + ticks_margin) 2009 2010 # capture widths for each axes for anchoring Bbox 2011 width_dict = dict() 2012 for direction in self._cut_displayed: 2013 this_width = dict() 2014 for display_ax in display_ax_dict.values(): 2015 if direction == display_ax.direction: 2016 bounds = display_ax.get_object_bounds() 2017 if not bounds: 2018 # This happens if the call to _map_show was not 2019 # successful. As it happens asynchronously (during a 2020 # refresh of the figure) we capture the problem and 2021 # ignore it: it only adds a non informative traceback 2022 bounds = [0, 1, 0, 1] 2023 xmin, xmax, ymin, ymax = bounds 2024 this_width[display_ax.ax] = (xmax - xmin) 2025 total_width = float(sum(this_width.values())) 2026 for ax, w in this_width.items(): 2027 width_dict[ax] = w / total_width * (x1 - x0) 2028 2029 left_dict = dict() 2030 # bottom positions in Bbox according to cuts 2031 bottom_dict = dict() 2032 # fraction is divided by the cut directions 'y', 'x', 'z' 2033 fraction = y1 / len(self._cut_displayed) 2034 height_dict = dict() 2035 for index, direction in enumerate(self._cut_displayed): 2036 left = float(x0) 2037 this_height = fraction + fraction * index 2038 for coord, display_ax in display_ax_dict.items(): 2039 if direction == display_ax.direction: 2040 left_dict[display_ax.ax] = left 2041 this_width = width_dict[display_ax.ax] 2042 left += this_width 2043 bottom_dict[display_ax.ax] = fraction * index * (y1 - y0) 2044 height_dict[display_ax.ax] = this_height 2045 return transforms.Bbox([[left_dict[axes], bottom_dict[axes]], 2046 [left_dict[axes] + width_dict[axes], 2047 height_dict[axes]]]) 2048 2049 2050 def draw_cross(self, cut_coords=None, **kwargs): 2051 """ Draw a crossbar on the plot to show where the cut is 2052 performed. 2053 2054 Parameters 2055 ---------- 2056 cut_coords: 3-tuple of floats, optional 2057 The position of the cross to draw. If none is passed, the 2058 ortho_slicer's cut coordinates are used. 2059 kwargs: 2060 Extra keyword arguments are passed to axhline 2061 """ 2062 return 2063 2064 2065SLICERS = dict(ortho=OrthoSlicer, 2066 tiled=TiledSlicer, 2067 mosaic=MosaicSlicer, 2068 xz=XZSlicer, 2069 yz=YZSlicer, 2070 yx=YXSlicer, 2071 x=XSlicer, 2072 y=YSlicer, 2073 z=ZSlicer) 2074 2075 2076class OrthoProjector(OrthoSlicer): 2077 """A class to create linked axes for plotting orthogonal projections 2078 of 3D maps. 2079 2080 """ 2081 _axes_class = GlassBrainAxes 2082 2083 @classmethod 2084 def find_cut_coords(cls, img=None, threshold=None, cut_coords=None): 2085 return (None, ) * len(cls._cut_displayed) 2086 2087 def draw_cross(self, cut_coords=None, **kwargs): 2088 # It does not make sense to draw crosses for the position of 2089 # the cuts since we are taking the max along one axis 2090 pass 2091 2092 def add_graph(self, adjacency_matrix, node_coords, 2093 node_color='auto', node_size=50, 2094 edge_cmap=cm.bwr, 2095 edge_vmin=None, edge_vmax=None, 2096 edge_threshold=None, 2097 edge_kwargs=None, node_kwargs=None, colorbar=False, 2098 ): 2099 """Plot undirected graph on each of the axes 2100 2101 Parameters 2102 ---------- 2103 adjacency_matrix : numpy array of shape (n, n) 2104 Represents the edges strengths of the graph. 2105 The matrix can be symmetric which will result in 2106 an undirected graph, or not symmetric which will 2107 result in a directed graph. 2108 2109 node_coords : numpy array_like of shape (n, 3) 2110 3d coordinates of the graph nodes in world space. 2111 2112 node_color : color or sequence of colors, optional 2113 Color(s) of the nodes. Default='auto'. 2114 2115 node_size : scalar or array_like, optional 2116 Size(s) of the nodes in points^2. Default=50. 2117 2118 edge_cmap : colormap, optional 2119 Colormap used for representing the strength of the edges. 2120 Default=cm.bwr. 2121 2122 edge_vmin, edge_vmax : float, optional 2123 If not None, either or both of these values will be used to 2124 as the minimum and maximum values to color edges. If None are 2125 supplied the maximum absolute value within the given threshold 2126 will be used as minimum (multiplied by -1) and maximum 2127 coloring levels. 2128 2129 edge_threshold : str or number, optional 2130 If it is a number only the edges with a value greater than 2131 edge_threshold will be shown. 2132 If it is a string it must finish with a percent sign, 2133 e.g. "25.3%", and only the edges with a abs(value) above 2134 the given percentile will be shown. 2135 2136 edge_kwargs : dict, optional 2137 Will be passed as kwargs for each edge matlotlib Line2D. 2138 2139 node_kwargs : dict 2140 Will be passed as kwargs to the plt.scatter call that plots all 2141 the nodes in one go. 2142 2143 """ 2144 # set defaults 2145 if edge_kwargs is None: 2146 edge_kwargs = {} 2147 if node_kwargs is None: 2148 node_kwargs = {} 2149 if isinstance(node_color, str) and node_color == 'auto': 2150 nb_nodes = len(node_coords) 2151 node_color = mpl_cm.Set2(np.linspace(0, 1, nb_nodes)) 2152 node_coords = np.asarray(node_coords) 2153 2154 # decompress input matrix if sparse 2155 if sparse.issparse(adjacency_matrix): 2156 adjacency_matrix = adjacency_matrix.toarray() 2157 2158 # make the lines below well-behaved 2159 adjacency_matrix = np.nan_to_num(adjacency_matrix) 2160 2161 # safety checks 2162 if 's' in node_kwargs: 2163 raise ValueError("Please use 'node_size' and not 'node_kwargs' " 2164 "to specify node sizes") 2165 if 'c' in node_kwargs: 2166 raise ValueError("Please use 'node_color' and not 'node_kwargs' " 2167 "to specify node colors") 2168 2169 adjacency_matrix_shape = adjacency_matrix.shape 2170 if (len(adjacency_matrix_shape) != 2 or 2171 adjacency_matrix_shape[0] != adjacency_matrix_shape[1]): 2172 raise ValueError( 2173 "'adjacency_matrix' is supposed to have shape (n, n)." 2174 ' Its shape was {0}'.format(adjacency_matrix_shape)) 2175 2176 node_coords_shape = node_coords.shape 2177 if len(node_coords_shape) != 2 or node_coords_shape[1] != 3: 2178 message = ( 2179 "Invalid shape for 'node_coords'. You passed an " 2180 "'adjacency_matrix' of shape {0} therefore " 2181 "'node_coords' should be a array with shape ({0[0]}, 3) " 2182 'while its shape was {1}').format(adjacency_matrix_shape, 2183 node_coords_shape) 2184 2185 raise ValueError(message) 2186 2187 if isinstance(node_color, (list, np.ndarray)) and len(node_color) != 1: 2188 if len(node_color) != node_coords_shape[0]: 2189 raise ValueError( 2190 "Mismatch between the number of nodes ({0}) " 2191 "and and the number of node colors ({1})." 2192 .format(node_coords_shape[0], len(node_color))) 2193 2194 if node_coords_shape[0] != adjacency_matrix_shape[0]: 2195 raise ValueError( 2196 "Shape mismatch between 'adjacency_matrix' " 2197 "and 'node_coords'" 2198 "'adjacency_matrix' shape is {0}, 'node_coords' shape is {1}" 2199 .format(adjacency_matrix_shape, node_coords_shape)) 2200 2201 # If the adjacency matrix is not symmetric, give a warning 2202 symmetric = True 2203 if not np.allclose(adjacency_matrix, adjacency_matrix.T, rtol=1e-3): 2204 symmetric = False 2205 warnings.warn(("'adjacency_matrix' is not symmetric. " 2206 "A directed graph will be plotted.")) 2207 2208 # For a masked array, masked values are replaced with zeros 2209 if hasattr(adjacency_matrix, 'mask'): 2210 if not (adjacency_matrix.mask == adjacency_matrix.mask.T).all(): 2211 symmetric = False 2212 warnings.warn(("'adjacency_matrix' was masked with " 2213 "a non symmetric mask. A directed " 2214 "graph will be plotted.")) 2215 adjacency_matrix = adjacency_matrix.filled(0) 2216 2217 if edge_threshold is not None: 2218 if symmetric: 2219 # Keep a percentile of edges with the highest absolute 2220 # values, so only need to look at the covariance 2221 # coefficients below the diagonal 2222 lower_diagonal_indices = np.tril_indices_from(adjacency_matrix, 2223 k=-1) 2224 lower_diagonal_values = adjacency_matrix[ 2225 lower_diagonal_indices] 2226 edge_threshold = _utils.param_validation.check_threshold( 2227 edge_threshold, np.abs(lower_diagonal_values), 2228 stats.scoreatpercentile, 'edge_threshold') 2229 else: 2230 edge_threshold = _utils.param_validation.check_threshold( 2231 edge_threshold, np.abs(adjacency_matrix.ravel()), 2232 stats.scoreatpercentile, 'edge_threshold') 2233 2234 adjacency_matrix = adjacency_matrix.copy() 2235 threshold_mask = np.abs(adjacency_matrix) < edge_threshold 2236 adjacency_matrix[threshold_mask] = 0 2237 2238 if symmetric: 2239 lower_triangular_adjacency_matrix = np.tril(adjacency_matrix, k=-1) 2240 non_zero_indices = lower_triangular_adjacency_matrix.nonzero() 2241 else: 2242 non_zero_indices = adjacency_matrix.nonzero() 2243 2244 line_coords = [node_coords[list(index)] 2245 for index in zip(*non_zero_indices)] 2246 2247 adjacency_matrix_values = adjacency_matrix[non_zero_indices] 2248 for ax in self.axes.values(): 2249 ax._add_markers(node_coords, node_color, node_size, **node_kwargs) 2250 if line_coords: 2251 ax._add_lines(line_coords, adjacency_matrix_values, edge_cmap, 2252 vmin=edge_vmin, vmax=edge_vmax, directed=(not symmetric), 2253 **edge_kwargs) 2254 # To obtain the brain left view, we simply invert the x axis 2255 if ax.direction == 'l' and not (ax.ax.get_xlim()[0] > ax.ax.get_xlim()[1]): 2256 ax.ax.invert_xaxis() 2257 2258 if colorbar: 2259 self._colorbar = colorbar 2260 self._show_colorbar(ax.cmap, ax.norm, threshold=edge_threshold) 2261 2262 plt.draw_if_interactive() 2263 2264 2265class XProjector(OrthoProjector): 2266 _cut_displayed = 'x' 2267 _default_figsize = [2.6, 2.3] 2268 2269 2270class YProjector(OrthoProjector): 2271 _cut_displayed = 'y' 2272 _default_figsize = [2.2, 2.3] 2273 2274 2275class ZProjector(OrthoProjector): 2276 _cut_displayed = 'z' 2277 _default_figsize = [2.2, 2.3] 2278 2279 2280class XZProjector(OrthoProjector): 2281 _cut_displayed = 'xz' 2282 2283 2284class YXProjector(OrthoProjector): 2285 _cut_displayed = 'yx' 2286 2287 2288class YZProjector(OrthoProjector): 2289 _cut_displayed = 'yz' 2290 2291 2292class LYRZProjector(OrthoProjector): 2293 _cut_displayed = 'lyrz' 2294 2295 2296class LZRYProjector(OrthoProjector): 2297 _cut_displayed = 'lzry' 2298 2299 2300class LZRProjector(OrthoProjector): 2301 _cut_displayed = 'lzr' 2302 2303 2304class LYRProjector(OrthoProjector): 2305 _cut_displayed = 'lyr' 2306 2307 2308class LRProjector(OrthoProjector): 2309 _cut_displayed = 'lr' 2310 2311 2312class LProjector(OrthoProjector): 2313 _cut_displayed = 'l' 2314 _default_figsize = [2.6, 2.3] 2315 2316 2317class RProjector(OrthoProjector): 2318 _cut_displayed = 'r' 2319 _default_figsize = [2.6, 2.3] 2320 2321 2322PROJECTORS = dict(ortho=OrthoProjector, 2323 xz=XZProjector, 2324 yz=YZProjector, 2325 yx=YXProjector, 2326 x=XProjector, 2327 y=YProjector, 2328 z=ZProjector, 2329 lzry=LZRYProjector, 2330 lyrz=LYRZProjector, 2331 lyr=LYRProjector, 2332 lzr=LZRProjector, 2333 lr=LRProjector, 2334 l=LProjector, 2335 r=RProjector) 2336 2337 2338def get_create_display_fun(display_mode, class_dict): 2339 try: 2340 return class_dict[display_mode].init_with_figure 2341 except KeyError: 2342 message = ('{0} is not a valid display_mode. ' 2343 'Valid options are {1}').format( 2344 display_mode, sorted(class_dict.keys())) 2345 raise ValueError(message) 2346 2347 2348def get_slicer(display_mode): 2349 "Internal function to retrieve a slicer" 2350 return get_create_display_fun(display_mode, SLICERS) 2351 2352 2353def get_projector(display_mode): 2354 "Internal function to retrieve a projector" 2355 return get_create_display_fun(display_mode, PROJECTORS) 2356